├── assets
│   ├── main_table.png
│   ├── method_NFT.jpg
│   ├── distribution_NFT.jpg
│   ├── main_compare_NFT.jpg
│   ├── val_acc_curve_NFT.jpg
│   └── algorithm_spectrum_NFT.jpg
├── requirements.txt
├── download_model.sh
├── download_data.sh
├── eval_local_7B.sh
├── eval_local_32B.sh
├── experience_maker.py
├── .gitignore
├── verifier.py
├── compute_acc.py
├── README.md
├── train_7B.sh
├── train_32B.sh
├── main_nft.py
├── config
│   └── nft_trainer.yaml
├── qwen_math_eval_toolkit
│   ├── utils.py
│   ├── grader.py
│   ├── parser.py
│   └── examples.py
├── README_VeRL.md
├── LICENSE
└── dp_actor.py
/assets/main_table.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NVlabs/NFT/HEAD/assets/main_table.png
--------------------------------------------------------------------------------
/assets/method_NFT.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NVlabs/NFT/HEAD/assets/method_NFT.jpg
--------------------------------------------------------------------------------
/assets/distribution_NFT.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NVlabs/NFT/HEAD/assets/distribution_NFT.jpg
--------------------------------------------------------------------------------
/assets/main_compare_NFT.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NVlabs/NFT/HEAD/assets/main_compare_NFT.jpg
--------------------------------------------------------------------------------
/assets/val_acc_curve_NFT.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NVlabs/NFT/HEAD/assets/val_acc_curve_NFT.jpg
--------------------------------------------------------------------------------
/assets/algorithm_spectrum_NFT.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NVlabs/NFT/HEAD/assets/algorithm_spectrum_NFT.jpg
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | # requirements.txt records the full set of dependencies for development
2 | accelerate
3 | codetiming
4 | datasets
5 | dill
6 | flash-attn
7 | hydra-core
8 | liger-kernel
9 | math-verify[antlr4_9_3]
10 | numpy
11 | pandas
12 | peft
13 | pyarrow>=15.0.0
14 | pybind11
15 | pylatexenc
16 | ray[default]
17 | tensordict<0.6
18 | torchdata
19 | transformers
20 | vllm<=0.6.3
21 | wandb
22 |
--------------------------------------------------------------------------------
/download_model.sh:
--------------------------------------------------------------------------------
1 | mkdir -p models
2 | # required only for 7B models
3 | huggingface-cli download Qwen/Qwen2.5-Math-7B --local-dir ./models/Qwen2.5-Math-7B --local-dir-use-symlinks False
4 |
5 | # required only for 32B models
6 | huggingface-cli download Qwen/Qwen2.5-32B --local-dir ./models/Qwen2.5-32B --local-dir-use-symlinks False
7 | # replace 32B tokenizer config to keep a consistent system prompt with 7B
8 | cp models/32b_tokenizer_config.json models/Qwen2.5-32B/tokenizer_config.json
9 |
10 |
--------------------------------------------------------------------------------
/download_data.sh:
--------------------------------------------------------------------------------
1 | HF_REPO_BASE="https://huggingface.co/datasets/ChenDRAG/VeRL_math_validation/resolve/main"
2 | mkdir -p data
3 | wget -O data/dapo-math-17k_10boxed.parquet "${HF_REPO_BASE}/dapo-math-17k_10boxed.parquet?download=true"
4 | wget -O data/aime-2024-boxed_w_answer.parquet "${HF_REPO_BASE}/aime-2024-boxed_w_answer.parquet?download=true"
5 | wget -O data/math500_boxed.parquet "${HF_REPO_BASE}/math500_boxed.parquet?download=true"
6 | wget -O data/minerva_math.parquet "${HF_REPO_BASE}/minerva_math.parquet?download=true"
7 | wget -O data/olympiadbench.parquet "${HF_REPO_BASE}/olympiadbench.parquet?download=true"
8 | wget -O data/aime2025_32_dapo_boxed_w_answer.parquet "${HF_REPO_BASE}/aime2025_32_dapo_boxed_w_answer.parquet?download=true"
9 | wget -O data/amc2023_32_dapo_boxed_w_answer.parquet "${HF_REPO_BASE}/amc2023_32_dapo_boxed_w_answer.parquet?download=true"
--------------------------------------------------------------------------------
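A quick sanity check of the downloaded files can be done with pandas. This is an illustrative sketch only; the column names (`prompt`, and `reward_model` with a `ground_truth` key) are inferred from the keys referenced in the training/eval scripts and in compute_acc.py, not from a documented schema.

```python
# Illustrative sketch: inspect a downloaded validation parquet (assumes pandas is installed).
import pandas as pd

df = pd.read_parquet("data/math500_boxed.parquet")
print(len(df), "rows")
print(df.columns.tolist())      # expected to include 'prompt' and 'reward_model'
row = df.iloc[0]
print(row["prompt"])            # chat-formatted question
print(row["reward_model"])      # dict holding the 'ground_truth' answer
```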
/eval_local_7B.sh:
--------------------------------------------------------------------------------
1 | pip install word2number
  2 | save_path=./NFT_7B
  3 | context_length=4096
  4 | max_prompt_length=512
  5 | n_samples=1
  6 | top_p=0.7
  7 | temperature=0.6
  8 | top_k=-1
9 | max_response_length=$(($context_length-$max_prompt_length))
10 |
11 |
12 | for data_name in aime-2024-boxed_w_answer math500_boxed minerva_math olympiadbench aime2025_32_dapo_boxed_w_answer amc2023_32_dapo_boxed_w_answer; do
13 | data_load_path=./data/$data_name.parquet; \
14 | data_save_path=${save_path}/${data_name}_max${context_length}_topp${top_p}topk${top_k}_temp${temperature}_@${n_samples}.parquet; \
15 | python -u -m verl.trainer.main_generation \
16 | trainer.nnodes=1 \
17 | trainer.n_gpus_per_node=8 \
18 | data.path=$data_load_path \
19 | data.prompt_key=prompt \
20 | data.n_samples=$n_samples \
21 | data.output_path=$data_save_path \
22 | model.path=$save_path/huggingface \
23 | +model.trust_remote_code=True \
24 | rollout.temperature=$temperature \
25 | rollout.top_k=$top_k \
26 | rollout.top_p=$top_p \
27 | rollout.prompt_length=$max_prompt_length \
28 | rollout.response_length=$max_response_length \
29 | rollout.tensor_model_parallel_size=4 \
30 | rollout.gpu_memory_utilization=0.8; \
31 | python3 -m compute_acc --input_path $data_save_path --verifier all
32 | done
--------------------------------------------------------------------------------
/eval_local_32B.sh:
--------------------------------------------------------------------------------
1 | pip install word2number
  2 | save_path=./NFT_32B
  3 | context_length=8192
  4 | max_prompt_length=1024
  5 | n_samples=1
  6 | top_p=0.7
  7 | temperature=1.0
  8 | top_k=-1
9 | max_response_length=$(($context_length-$max_prompt_length))
10 |
11 |
12 | for data_name in aime-2024-boxed_w_answer math500_boxed minerva_math olympiadbench aime2025_32_dapo_boxed_w_answer amc2023_32_dapo_boxed_w_answer; do
13 | data_load_path=./data/$data_name.parquet; \
14 | data_save_path=${save_path}/${data_name}_max${context_length}_topp${top_p}topk${top_k}_temp${temperature}_@${n_samples}.parquet; \
15 | python -u -m verl.trainer.main_generation \
16 | trainer.nnodes=1 \
17 | trainer.n_gpus_per_node=8 \
18 | data.path=$data_load_path \
19 | data.prompt_key=prompt \
20 | data.n_samples=$n_samples \
21 | data.output_path=$data_save_path \
22 | model.path=$save_path/huggingface \
23 | +model.trust_remote_code=True \
24 | rollout.temperature=$temperature \
25 | rollout.top_k=$top_k \
26 | rollout.top_p=$top_p \
27 | rollout.prompt_length=$max_prompt_length \
28 | rollout.response_length=$max_response_length \
29 | rollout.tensor_model_parallel_size=4 \
30 | rollout.gpu_memory_utilization=0.8; \
 31 |     python3 -m compute_acc --input_path $data_save_path --verifier all
32 | done
--------------------------------------------------------------------------------
/experience_maker.py:
--------------------------------------------------------------------------------
1 | from qwen_math_eval_toolkit.parser import extract_answer as qwen_extract_answer
2 | from qwen_math_eval_toolkit.grader import math_equal as qwen_math_equal
3 |
4 | from multiprocessing import Process, Queue
5 | def qwen_math_equal_subprocess(prediction, reference, timeout_seconds=10):
6 | def worker(q, prediction, reference):
7 | result = qwen_math_equal(prediction=prediction, reference=reference, timeout=False)
8 | q.put(result)
9 |
10 | q = Queue()
11 | p = Process(target=worker, args=(q, prediction, reference))
12 | p.start()
13 |
 14 |     # Handle timeouts
 15 |     p.join(timeout=timeout_seconds)  # wait for the worker to finish, at most timeout_seconds
 16 | 
 17 |     # If the process is still running, terminate it and return False
 18 |     if p.is_alive():
 19 |         p.terminate()
 20 |         p.join()  # make sure the process is fully cleaned up
 21 |         return False
 22 | 
 23 |     # The process finished normally; fetch its result
 24 |     try:
 25 |         return q.get_nowait()
 26 |     except Exception:
 27 |         return False
28 |
29 | import re
30 | def preprocess_box_response_for_qwen_prompt(sequence, answer):
 31 |     # strip everything before the assistant turn
 32 |     model_output = re.sub(r'^.*?<\|im_start\|>assistant', '<|im_start|>assistant', sequence, flags=re.DOTALL, count=1)
 33 |     stop_words = ["</s>", "<|im_end|>", "<|endoftext|>"]
 34 |     for stop_word in stop_words:
 35 |         if stop_word in model_output:
 36 |             model_output = model_output.split(stop_word)[0].strip()
 37 |     extract_answer = qwen_extract_answer(model_output, data_name="math")  # TODO: data_name is hard-coded for now
38 |
39 | if qwen_math_equal_subprocess(prediction=extract_answer, reference=answer):
40 | return 1.0
41 | else:
42 | return 0.0
43 |
--------------------------------------------------------------------------------
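For reference, an illustrative usage sketch of the reward function above (not part of the repository): it strips everything before the assistant turn, extracts the final `\boxed{}` answer with the Qwen toolkit, and compares it against the reference inside a sandboxed subprocess.

```python
# Illustrative only; run from the repo root so qwen_math_eval_toolkit is importable.
from experience_maker import preprocess_box_response_for_qwen_prompt

sequence = (
    "<|im_start|>user\nWhat is 6 * 7?<|im_end|>\n"
    "<|im_start|>assistant\nThe answer is \\boxed{42}.<|im_end|>"
)
print(preprocess_box_response_for_qwen_prompt(sequence, "42"))  # -> 1.0 (correct)
print(preprocess_box_response_for_qwen_prompt(sequence, "43"))  # -> 0.0 (incorrect)
```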
/.gitignore:
--------------------------------------------------------------------------------
1 |
2 | **/*.pt
3 | **/checkpoints
4 | **/wget-log
5 | **/_build/
6 | **/*.ckpt
7 | **/outputs
8 | **/*.tar.gz
9 | **/playground
10 | **/wandb
11 |
12 | # Byte-compiled / optimized / DLL files
13 | __pycache__/
14 | *.py[cod]
15 | *$py.class
16 | dataset/*
17 | tensorflow/my_graph/*
18 | .idea/
19 | # C extensions
20 | *.so
21 |
22 | # Distribution / packaging
23 | .Python
24 | env/
25 | build/
26 | develop-eggs/
27 | dist/
28 | downloads/
29 | eggs/
30 | .eggs/
31 | lib/
32 | lib64/
33 | parts/
34 | sdist/
35 | var/
36 | *.egg-info/
37 | .installed.cfg
38 | *.egg
39 |
40 | # PyInstaller
41 | # Usually these files are written by a python script from a template
42 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
43 | *.manifest
44 | *.spec
45 |
46 | # Installer logs
47 | pip-log.txt
48 | pip-delete-this-directory.txt
49 |
50 | # Unit test / coverage reports
51 | htmlcov/
52 | .tox/
53 | .coverage
54 | .coverage.*
55 | .cache
56 | nosetests.xml
57 | coverage.xml
58 | *,cover
59 | .hypothesis/
60 |
61 | # Translations
62 | *.mo
63 | *.pot
64 |
65 | # Django stuff:
66 | *.log
67 | local_settings.py
68 |
69 | # Flask stuff:
70 | instance/
71 | .webassets-cache
72 |
73 | # Scrapy stuff:
74 | .scrapy
75 |
76 | # Sphinx documentation
77 | docs/_build/
78 |
79 | # PyBuilder
80 | target/
81 |
82 | # IPython Notebook
83 | .ipynb_checkpoints
84 |
85 | # pyenv
86 | .python-version
87 |
88 | # celery beat schedule file
89 | celerybeat-schedule
90 |
91 | # dotenv
92 | .env
93 |
94 | # virtualenv
95 | venv/
96 | ENV/
97 |
98 | # Spyder project settings
99 | .spyderproject
100 |
101 | # Rope project settings
102 | .ropeproject
103 |
104 | # vscode
105 | .vscode
106 |
107 | # Mac
108 | .DS_Store
109 |
110 | # output logs
111 | tests/e2e/toy_examples/deepspeed/synchronous/output.txt
112 |
113 | # vim
114 | *.swp
115 |
116 | # ckpt
117 | *.lock
118 |
119 | # data
120 | *.parquet
121 |
122 |
123 | # local logs
124 | logs
125 | log
126 |
127 | ckpts/
128 | models/
--------------------------------------------------------------------------------
/verifier.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | # SPDX-License-Identifier: Apache-2.0
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | # Copyright 2024 Bytedance Ltd. and/or its affiliates
17 | #
18 | # Licensed under the Apache License, Version 2.0 (the "License");
19 | # you may not use this file except in compliance with the License.
20 | # You may obtain a copy of the License at
21 | #
22 | # http://www.apache.org/licenses/LICENSE-2.0
23 | #
24 | # Unless required by applicable law or agreed to in writing, software
25 | # distributed under the License is distributed on an "AS IS" BASIS,
26 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
27 | # See the License for the specific language governing permissions and
28 | # limitations under the License.
29 |
30 | from math_verify.metric import math_metric
31 | from math_verify.parser import LatexExtractionConfig, ExprExtractionConfig
32 |
 33 | def math_verify_compute_score(model_output: str, ground_truth: str) -> dict:
34 | verify_func = math_metric(
35 | gold_extraction_target=(LatexExtractionConfig(),),
36 | pred_extraction_target=(ExprExtractionConfig(), LatexExtractionConfig()),
37 | )
38 | ret_score = 0.
39 |
40 | # Wrap the ground truth in \boxed{} format for verification
41 | ground_truth_boxed = "\\boxed{" + ground_truth + "}"
42 | try:
43 | ret_score, _ = verify_func([ground_truth_boxed], [model_output])
44 | except Exception as e:
45 | print(e)
46 |
47 | return {
48 | "score": 1.0 if ret_score > 0.5 else -1.0,
49 | "acc": 1.0 if ret_score > 0.5 else 0.0,
50 | "pred": model_output,
51 | }
52 |
53 | def _default_compute_score(data_source, solution_str, ground_truth, extra_info=None):
54 | if data_source == 'openai/gsm8k':
55 | from verl.utils.reward_score import gsm8k
56 | res = gsm8k.compute_score(solution_str, ground_truth)
57 | elif data_source in ['lighteval/MATH', 'DigitalLearningGmbH/MATH-lighteval', 'math500', "math500_w_answer", "olympiadbench", "minerva_math"]:
58 | # from verl.utils.reward_score import math
59 | # res = math.compute_score(solution_str, ground_truth)
60 | # Use Math-Verify (https://github.com/huggingface/Math-Verify) for better evaluation accuracy
61 | from verl.utils.reward_score import math_verify
62 | res = math_verify.compute_score(solution_str, ground_truth)
63 | elif data_source == 'math_dapo':
64 | from verl.utils.reward_score import math_dapo
65 | res = math_dapo.compute_score(solution_str, ground_truth)
66 | elif data_source in ['math_dapo_boxed', 'amc_dapo_boxed', 'aime_2025_dapo_boxed', 'amc2023_dapo_boxed']:
67 | from verl.utils.reward_score import math_dapo
68 | res = math_dapo.compute_score(solution_str, ground_truth, strict_box_verify=True)
69 | elif data_source in [
70 | 'numina_aops_forum', 'numina_synthetic_math', 'numina_amc_aime', 'numina_synthetic_amc', 'numina_cn_k12',
71 | 'numina_olympiads'
72 | ]:
73 | from verl.utils.reward_score import prime_math
74 | res = prime_math.compute_score(solution_str, ground_truth)
75 | elif data_source in ['codecontests', 'apps', 'codeforces', 'taco']:
76 | from verl.utils.reward_score import prime_code
77 | res = prime_code.compute_score(solution_str, ground_truth, continuous=True)
78 | elif data_source in ['hiyouga/geometry3k']:
79 | from verl.utils.reward_score import geo3k
80 | res = geo3k.compute_score(solution_str, ground_truth)
81 | else:
82 | raise NotImplementedError
83 |
84 | if isinstance(res, dict):
85 | return res
86 | elif isinstance(res, (int, float, bool)):
87 | return float(res)
88 | else:
89 | return float(res[0])
90 |
--------------------------------------------------------------------------------
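An illustrative usage sketch of the Math-Verify scorer defined above (assumes `math-verify` from requirements.txt is installed; the exact output values are indicative):

```python
# Illustrative only: score a model output against a plain-text ground truth.
from verifier import math_verify_compute_score

result = math_verify_compute_score("Therefore the result is \\boxed{\\frac{1}{2}}.", "\\frac{1}{2}")
print(result)  # e.g. {'score': 1.0, 'acc': 1.0, 'pred': 'Therefore the result is \\boxed{\\frac{1}{2}}.'}
```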
/compute_acc.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | # SPDX-License-Identifier: Apache-2.0
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | import pandas as pd
17 | import json
18 | from experience_maker import preprocess_box_response_for_qwen_prompt
19 | from verl.utils.reward_score import prime_math, math_verify, math_dapo
20 | import argparse
21 |
22 |
23 | RESPONSE_COL = "responses"
24 | GROUND_TRUTH_COL = "reward_model"
25 | GROUND_TRUTH_KEY = "ground_truth"
26 |
27 |
28 | def calculate_accuracy(file_path: str, verifier_type: str) -> None:
29 | """Loads data, calculates accuracy, and adds verifier results to DataFrame."""
30 | try:
31 | df = pd.read_parquet(file_path)
32 | print(f"Loaded {len(df)} rows from {file_path}")
33 | except Exception as e:
34 | print(f"Error loading Parquet file '{file_path}': {e}")
35 | return
36 |
37 | total_correct = 0
38 | total_checked = 0
39 | all_results = []
40 |
41 | verifier = {
42 | "qwen": preprocess_box_response_for_qwen_prompt,
43 | "prime": lambda response, answer: prime_math.compute_score(response, answer)[
44 | "acc"
45 | ],
46 | "math": lambda response, answer: math_verify.compute_score(
47 | response, answer
48 | )["acc"],
49 | "dapo": lambda response, answer: math_dapo.compute_score(
50 | response, answer, strict_box_verify=True
51 | )["acc"],
52 | }[verifier_type]
53 |
54 | for index, row in df.iterrows():
55 | # 1. Get Ground Truth
56 | gt_data = row[GROUND_TRUTH_COL]
57 | if isinstance(gt_data, str):
58 | gt_dict = json.loads(gt_data)
59 | else:
60 | gt_dict = gt_data
61 | ground_truth_answer = str(gt_dict[GROUND_TRUTH_KEY])
62 |
63 | # 2. Get Responses (list)
64 | generated_responses = row[RESPONSE_COL]
65 |
66 | # 3. Compare each response
67 | row_results = []
68 | for response in generated_responses:
69 | try:
70 | is_correct = verifier(str(response), ground_truth_answer)
71 | except Exception as e:
72 | print(f"Warning: Skipping due to error: {e}")
73 | is_correct=-1
74 | is_correct = float(is_correct)
75 | row_results.append(is_correct)
76 | total_correct += is_correct
77 | total_checked += 1
78 | all_results.append(row_results)
79 |
80 | # Add result column to DataFrame
81 | df[f"{verifier_type}_results"] = all_results
82 |
83 | # Save updated DataFrame
84 | df.to_parquet(file_path)
85 | print(f"Updated DataFrame saved to: {file_path}")
86 |
87 | # Print Summary
88 | if total_checked > 0:
89 | accuracy = total_correct / total_checked
90 | print(f"\n--- Results --- for {verifier_type}")
91 | print(f"Total responses checked: {total_checked}")
92 | print(f"Total correct responses: {int(total_correct)}")
93 | print(f"Accuracy: {accuracy:.4f}")
94 | else:
95 | print("\n--- Results ---")
96 | print("No responses were checked.")
97 |
98 |
99 | if __name__ == "__main__":
100 | parser = argparse.ArgumentParser(
101 | description="Calculate accuracy from generated responses in a Parquet file."
102 | )
103 | parser.add_argument("--input_path", help="Path to the input Parquet file.")
104 | parser.add_argument(
105 | "--verifier", default="qwen", choices=["qwen", "prime", "math", "dapo", "all"]
106 | )
107 | args = parser.parse_args()
108 | if args.verifier == "all":
109 | for ven in ["qwen", "prime", "math", "dapo"]:
110 | calculate_accuracy(args.input_path, ven)
111 | else:
112 | calculate_accuracy(args.input_path, args.verifier)
--------------------------------------------------------------------------------
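The script expects a parquet file with a `responses` column (a list of generated answers per prompt) and a `reward_model` column holding the `ground_truth`, which is the layout produced by `verl.trainer.main_generation` in the eval scripts. A minimal sketch of that schema, with a hypothetical file name:

```python
# Build a tiny parquet matching the columns compute_acc.py reads, then score it (illustrative only).
import pandas as pd

df = pd.DataFrame({
    "responses": [["The answer is \\boxed{4}.", "I believe it is \\boxed{5}."]],
    "reward_model": [{"ground_truth": "4"}],
})
df.to_parquet("toy_eval.parquet")

# Then, from the repo root:
#   python3 -m compute_acc --input_path toy_eval.parquet --verifier qwen
```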
/README.md:
--------------------------------------------------------------------------------
  1 | # Negative-aware Fine-Tuning (NFT): Bridging Supervised Learning and Reinforcement Learning in Math Reasoning
  2 | 
  3 | [Paper](https://arxiv.org/pdf/2505.18116)
  4 | [Project Page](https://research.nvidia.com/labs/dir/Negative-aware-Fine-Tuning/)
  5 | [Dataset](https://huggingface.co/datasets/ChenDRAG/VeRL_math_validation)
  6 | [Model: NFT-32B](https://huggingface.co/nvidia/NFT-32B)
  7 | 
16 | NFT is a pure supervised learning method for improving LLMs' math-reasoning abilities with no external teachers.
17 |
18 | - As an SL method, NFT outperforms leading RL algorithms like GRPO and DAPO in 7B model experiments and performs similarly to DAPO in 32B settings.
 19 | - NFT enables direct optimization of LLMs on negative data, thereby significantly outperforming other SL baselines such as Rejection sampling Fine-Tuning (RFT).
20 | - NFT is equivalent to GRPO when training is strictly on-policy, despite their entirely different theoretical foundations.
21 |
 22 | NFT shows that self-reflective improvement is not an inherent privilege of RL algorithms. Rather, the current gap between SL and RL methods stems from how effectively they can leverage negative data.
23 |
24 | ## Algorithm Overview
25 |
 26 | NFT bridges reinforcement learning and supervised learning by leveraging negative feedback through supervision.
27 |
33 | The NFT pipeline consists of:
34 | 1. **Data Collection:** LLM generates answers to math questions, split into positive/negative based on correctness
35 | 2. **Implicit Negative Policy:** Constructs a policy to model negative answers using the same parameters as the positive policy
36 | 3. **Policy Optimization:** Both positive and negative answers optimize the LLM via supervised learning
37 |
42 | ## Experimental Results
43 |
44 | Comparison of NFT-7B with other zero-shot math models in the Qwen series.
 45 | 
52 | NFT performs competitively compared with other algorithms. We report avg@32 for AIME24, AIME25, and AMC23 and avg@1 for others.
 53 | 
59 | Validation accuracy curves showing NFT's ability to leverage negative data for continuous improvement.
 60 | 
67 | ## Evaluation
68 |
69 | ### Environment setup
70 | We use exactly the same environment configuration as the official DAPO codebase.
71 | ```bash
72 | pip install git+ssh://git@github.com/volcengine/verl.git@01ef7184821d0d7844796ec0ced17665c1f50673
73 | ```
74 |
75 | ### Benchmarking
76 |
 77 | Pretrained [7B](https://huggingface.co/nvidia/NFT-7B) and [32B](https://huggingface.co/nvidia/NFT-32B) models are available on Hugging Face.
78 |
 79 | We provide evaluation code integrated with the VeRL infrastructure.
 80 | 
 81 | Please refer to `eval_local_7B.sh` and `eval_local_32B.sh` for the evaluation scripts.
82 |
83 | ## Training
84 |
85 | ### Environment setup
86 | We use exactly the same environment configuration as the official DAPO codebase.
87 | ```bash
88 | pip install git+ssh://git@github.com/volcengine/verl.git@01ef7184821d0d7844796ec0ced17665c1f50673
89 | ```
90 |
91 | ### Datasets
 92 | We use the public [DAPO-Math-17k](https://huggingface.co/datasets/BytedTsinghua-SIA/DAPO-Math-17k) dataset for training and six public math benchmarks for validation. Download the pre-processed training and validation data with
93 | ```bash
94 | bash download_data.sh
95 | ```
96 |
97 | ### Base Model
98 | ```bash
99 | bash download_model.sh
100 | ```
101 |
102 | ### Starting Experiments
103 | Please see `train_7B.sh` and `train_32B.sh` for single-node launch scripts. Note that we run 7B experiments on 4×8 H100s and 32B experiments on 16×8 H100s. Please refer to the [VeRL](https://github.com/volcengine/verl/tree/gm-tyx/puffin/main) instructions for launching distributed jobs.
104 |
105 | Hyperparameters:
106 |
107 | - `neg_weight`: The weight of negative data in NFT's objective. Set to 1.0 for the default NFT config, to 0.0 for RFT (masking out all negative-data loss), or to -1.0 to recover the DAPO algorithm for comparison.
108 | - `normalize`: Controls the prompt weight in NFT's objective. Set to 0 to treat all questions equally; set to 1 (default) or 2 to prioritize harder questions. `normalize=1` matches the Dr. GRPO algorithm in on-policy training, while `normalize=2` matches standard GRPO.
109 |
110 | ## Acknowledgement
111 |
112 | We thank the [verl](https://github.com/volcengine/verl) team for providing the awesome open-source RL infrastructure.
113 |
114 | ## Citation
115 | If you find our project helpful, please consider citing
116 | ```bibtex
117 | @article{chen2025bridging,
118 | title = {Bridging Supervised Learning and Reinforcement Learning in Math Reasoning},
119 |   author = {Huayu Chen and Kaiwen Zheng and Qinsheng Zhang and Ganqu Cui and Yin Cui and Haotian Ye and Tsung-Yi Lin and Ming-Yu Liu and Jun Zhu and Haoxiang Wang},
120 | journal = {arXiv preprint arXiv:2505.18116},
121 | year = {2025}
122 | }
123 | ```
--------------------------------------------------------------------------------
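The hyperparameter switches documented in the README above are plain Hydra overrides passed by the launch scripts, so the algorithm variants can be reproduced by editing the corresponding variables at the top of `train_7B.sh` / `train_32B.sh`. An illustrative summary (values taken from the README and the scripts' comments):

```bash
# Algorithm variants via the NFT-specific overrides (edit the variables in train_7B.sh / train_32B.sh):
#   NFT (default):          neg_weight=1.0   normalize=1
#   RFT ablation:           neg_weight=0.0                  # negative-data loss masked out
#   DAPO baseline:          neg_weight=-1.0
#   GRPO-style weighting:   normalize=2
```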
/train_7B.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 | set -euxo pipefail
3 |
4 | # WANDB configuration
5 | export WANDB_API_KEY=xxxx
6 | export WANDB_ENTITY=xxxx
7 |
8 | project_name='NFT'
9 | exp_name='NFT-Qwen2.5-7B-Math-Test'
10 |
11 | adv_estimator=grpo
12 |
13 | kl_coef=0.0
14 | kl_loss_coef=0.0
15 |
16 | clip_ratio_low=0.2
17 | clip_ratio_high=0.28
18 |
19 | enable_overlong_buffer=False
20 | overlong_buffer_len=0
21 | overlong_penalty_factor=0.0
22 |
23 | enable_filter_groups=True
24 | filter_groups_metric=acc
25 | max_num_gen_batches=10
26 | train_prompt_bsz=512
27 | gen_prompt_bsz=$((train_prompt_bsz * 3))
28 | train_prompt_mini_bsz=32
29 | n_resp_per_prompt=16
30 |
31 | use_token_level_loss=True
32 |
33 | # Ray
34 | WORKING_DIR=${WORKING_DIR:-"${PWD}"}
35 | RUNTIME_ENV=${RUNTIME_ENV:-"${WORKING_DIR}/verl/trainer/runtime_env.yaml"}
36 | NNODES=${NNODES:-1}
37 | # Paths
38 | verl_workdir=.
39 | RAY_DATA_HOME=${RAY_DATA_HOME:-"${verl_workdir}"}
40 | MODEL_PATH=${MODEL_PATH:-"${RAY_DATA_HOME}/models/Qwen2.5-Math-7B"}
41 | CKPTS_DIR=${CKPTS_DIR:-"${RAY_DATA_HOME}/ckpts/${project_name}/${exp_name}"}
42 | TRAIN_FILE=${TRAIN_FILE:-"${RAY_DATA_HOME}/data/dapo-math-17k_10boxed.parquet"}
43 |
44 | # Algorithm
45 | ## Train
46 | max_prompt_length=$((512))
47 | max_response_length=$((4096-512))
48 | ## Validation
49 | val_top_k=-1 # 0 for HF rollout, -1 for vLLM rollout
50 |
51 | # NFT specific parameters
52 | neg_weight=1.0 # -1.0 for DAPO, 1.0 for NFT, 0.0 for RFT
53 | ratio_type="token"
54 | ppo_epoch=1
55 | bugged_dynamic_scale=0
56 | clamp_negative=1.0
57 | clamp_positive=0.0
58 | normalize=1
59 |
60 | # Performance Related Parameter
61 | sp_size=1
62 | use_dynamic_bsz=True
63 | actor_ppo_max_token_len=$((max_prompt_length + max_response_length))
64 | infer_ppo_max_token_len=$((max_prompt_length + max_response_length))
65 | gen_tp=4
66 | actor_offload=False
67 | ref_offload=True
68 |
69 | python3 -m main_nft \
70 | data.train_files="${TRAIN_FILE}" \
71 | data.val_files="['./data/aime-2024-boxed_w_answer.parquet', './data/math500_boxed.parquet', './data/minerva_math.parquet', './data/olympiadbench.parquet', './data/aime2025_32_dapo_boxed_w_answer.parquet', './data/amc2023_32_dapo_boxed_w_answer.parquet']" \
72 | data.prompt_key=prompt \
73 | data.truncation='left' \
74 | data.max_prompt_length=${max_prompt_length} \
75 | data.max_response_length=${max_response_length} \
76 | data.filter_overlong_prompts=True \
77 | data.gen_batch_size=${gen_prompt_bsz} \
78 | data.train_batch_size=${train_prompt_bsz} \
79 | actor_rollout_ref.rollout.n=${n_resp_per_prompt} \
80 | algorithm.adv_estimator=${adv_estimator} \
81 | algorithm.kl_ctrl.kl_coef=${kl_coef} \
82 | actor_rollout_ref.actor.kl_loss_coef=${kl_loss_coef} \
83 | actor_rollout_ref.actor.clip_ratio_low=${clip_ratio_low} \
84 | actor_rollout_ref.actor.clip_ratio_high=${clip_ratio_high} \
85 | algorithm.filter_groups.enable=${enable_filter_groups} \
86 | algorithm.filter_groups.max_num_gen_batches=${max_num_gen_batches} \
87 | algorithm.filter_groups.metric=${filter_groups_metric} \
88 | actor_rollout_ref.model.use_remove_padding=True \
89 | actor_rollout_ref.neg_weight=${neg_weight} \
90 | actor_rollout_ref.actor.ratio_type=${ratio_type} \
91 | actor_rollout_ref.actor.ppo_epochs=${ppo_epoch} \
92 | actor_rollout_ref.bugged_dynamic_scale=${bugged_dynamic_scale} \
93 | actor_rollout_ref.clamp_negative=${clamp_negative} \
94 | actor_rollout_ref.clamp_positive=${clamp_positive} \
95 | actor_rollout_ref.normalize=${normalize} \
96 | actor_rollout_ref.actor.use_dynamic_bsz=${use_dynamic_bsz} \
97 | actor_rollout_ref.ref.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \
98 | actor_rollout_ref.rollout.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \
99 | actor_rollout_ref.actor.ppo_max_token_len_per_gpu=${actor_ppo_max_token_len} \
100 | actor_rollout_ref.ref.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \
101 | actor_rollout_ref.rollout.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \
102 | actor_rollout_ref.model.path="${MODEL_PATH}" \
103 | +actor_rollout_ref.model.override_config.attention_dropout=0. \
104 | +actor_rollout_ref.model.override_config.embd_pdrop=0. \
105 | +actor_rollout_ref.model.override_config.resid_pdrop=0. \
106 | actor_rollout_ref.model.enable_gradient_checkpointing=True \
107 | actor_rollout_ref.actor.optim.lr=1e-6 \
108 | actor_rollout_ref.actor.optim.lr_warmup_steps=10 \
109 | actor_rollout_ref.actor.optim.weight_decay=0.1 \
110 | actor_rollout_ref.actor.ppo_mini_batch_size=${train_prompt_mini_bsz} \
111 | actor_rollout_ref.actor.fsdp_config.param_offload=${actor_offload} \
112 | actor_rollout_ref.actor.fsdp_config.optimizer_offload=${actor_offload} \
113 | actor_rollout_ref.actor.entropy_coeff=0 \
114 | actor_rollout_ref.actor.grad_clip=1.0 \
115 | actor_rollout_ref.actor.use_token_level_loss=${use_token_level_loss} \
116 | actor_rollout_ref.actor.ulysses_sequence_parallel_size=${sp_size} \
117 | actor_rollout_ref.rollout.gpu_memory_utilization=0.70 \
118 | actor_rollout_ref.rollout.tensor_model_parallel_size=${gen_tp} \
119 | actor_rollout_ref.rollout.enable_chunked_prefill=True \
120 | actor_rollout_ref.rollout.max_num_batched_tokens=$((max_prompt_length + max_response_length)) \
121 | actor_rollout_ref.rollout.val_kwargs.top_k="${val_top_k}" \
122 | actor_rollout_ref.rollout.val_kwargs.top_p=0.7 \
123 | actor_rollout_ref.rollout.val_kwargs.temperature=0.6 \
124 | actor_rollout_ref.rollout.val_kwargs.n=1 \
125 | actor_rollout_ref.rollout.val_kwargs.do_sample=True \
126 | actor_rollout_ref.ref.fsdp_config.param_offload=${ref_offload} \
127 | actor_rollout_ref.ref.ulysses_sequence_parallel_size=${sp_size} \
128 | actor_rollout_ref.actor.fsdp_config.fsdp_size=-1 \
129 | custom_reward_function.overlong_buffer.enable=${enable_overlong_buffer} \
130 | custom_reward_function.overlong_buffer.len=${overlong_buffer_len} \
131 | custom_reward_function.overlong_buffer.penalty_factor=${overlong_penalty_factor} \
132 | trainer.logger=['console','wandb'] \
133 | trainer.project_name="${project_name}" \
134 | trainer.experiment_name="${exp_name}" \
135 | trainer.n_gpus_per_node=8 \
136 | trainer.nnodes="${NNODES}" \
137 | +trainer.val_before_train=False \
138 | trainer.test_freq=2 \
139 | trainer.save_freq=2 \
140 | trainer.total_epochs=100 \
141 | trainer.default_local_dir="${CKPTS_DIR}" \
142 | trainer.resume_mode=auto
--------------------------------------------------------------------------------
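For multi-node runs (the README notes 4×8 H100s for 7B and 16×8 for 32B), a common Ray-based launch pattern is sketched below; the address, port, and node count are placeholders, and the authoritative procedure is in the VeRL documentation.

```bash
# Hypothetical multi-node launch sketch (applies to train_7B.sh and train_32B.sh alike).
# On the head node:
ray start --head --port=6379
# On each worker node:
ray start --address=<head_node_ip>:6379
# Back on the head node (the scripts read NNODES from the environment, default 1):
NNODES=4 bash train_7B.sh
```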
/train_32B.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 | set -euxo pipefail
3 |
4 | # WANDB configuration
5 | export WANDB_API_KEY=xxxx
6 | export WANDB_ENTITY=xxxx
7 |
8 | project_name='NFT'
9 | exp_name='NFT-Qwen2.5-32B-Test'
10 |
11 | adv_estimator=grpo
12 |
13 | kl_coef=0.0
14 | kl_loss_coef=0.0
15 |
16 | clip_ratio_low=0.2
17 | clip_ratio_high=0.28
18 |
19 | enable_filter_groups=True
20 | filter_groups_metric=acc
21 | max_num_gen_batches=10
22 | train_prompt_bsz=512
23 | gen_prompt_bsz=$((train_prompt_bsz * 3))
24 | train_prompt_mini_bsz=32
25 | n_resp_per_prompt=16
26 |
27 | use_token_level_loss=True
28 |
29 | # Ray
30 | WORKING_DIR=${WORKING_DIR:-"${PWD}"}
31 | RUNTIME_ENV=${RUNTIME_ENV:-"${WORKING_DIR}/verl/trainer/runtime_env.yaml"}
32 | NNODES=${NNODES:-1}
33 | # Paths
34 | verl_workdir=.
35 | RAY_DATA_HOME=${RAY_DATA_HOME:-"${verl_workdir}"}
36 | MODEL_PATH=${MODEL_PATH:-"${RAY_DATA_HOME}/models/Qwen2.5-32B"}
37 | CKPTS_DIR=${CKPTS_DIR:-"${RAY_DATA_HOME}/ckpts/${project_name}/${exp_name}"}
38 | TRAIN_FILE=${TRAIN_FILE:-"${RAY_DATA_HOME}/data/dapo-math-17k_10boxed.parquet"}
39 |
40 | # Algorithm
41 | ## Train
42 | ## Validation
43 | val_top_k=-1 # 0 for HF rollout, -1 for vLLM rollout
44 |
45 | # NFT specific parameters
46 | neg_weight=1.0 # -1.0 for DAPO, 1.0 for NFT, 0.0 for RFT (try 0.5 for 32B)
47 | ratio_type="token"
48 | ppo_epoch=1
49 | bugged_dynamic_scale=0
50 | clamp_negative=1.0 # 0.5
51 | clamp_positive=0.0
52 | normalize=1
53 | if [ "$neg_weight" = "-1.0" ]; then
54 | max_prompt_length=$((1024 * 2))
55 | max_response_length=$((1024 * 14))
56 | enable_overlong_buffer=True
57 | overlong_buffer_len=$((1024 * 4))
58 | overlong_penalty_factor=1.0
59 | else
60 | max_prompt_length=$((1024 * 2))
61 | max_response_length=$((1024 * 14))
62 | enable_overlong_buffer=False
63 | overlong_buffer_len=0
64 | overlong_penalty_factor=0.0
65 | fi
66 |
67 | # Performance Related Parameter
68 | sp_size=8
69 | use_dynamic_bsz=True
70 | actor_ppo_max_token_len=$((max_prompt_length + max_response_length))
71 | infer_ppo_max_token_len=$((max_prompt_length + max_response_length))
72 | gen_tp=4
73 | actor_offload=False
74 | ref_offload=True
75 |
76 | python3 -m main_nft \
77 | data.train_files="${TRAIN_FILE}" \
78 | data.val_files="['./data/aime-2024-boxed_w_answer.parquet', './data/math500_boxed.parquet', './data/minerva_math.parquet', './data/olympiadbench.parquet', './data/aime2025_32_dapo_boxed_w_answer.parquet', './data/amc2023_32_dapo_boxed_w_answer.parquet']" \
79 | data.prompt_key=prompt \
80 | data.truncation='left' \
81 | data.max_prompt_length=${max_prompt_length} \
82 | data.max_response_length=${max_response_length} \
83 | data.filter_overlong_prompts=True \
84 | data.gen_batch_size=${gen_prompt_bsz} \
85 | data.train_batch_size=${train_prompt_bsz} \
86 | actor_rollout_ref.rollout.n=${n_resp_per_prompt} \
87 | algorithm.adv_estimator=${adv_estimator} \
88 | algorithm.kl_ctrl.kl_coef=${kl_coef} \
89 | actor_rollout_ref.actor.kl_loss_coef=${kl_loss_coef} \
90 | actor_rollout_ref.actor.clip_ratio_low=${clip_ratio_low} \
91 | actor_rollout_ref.actor.clip_ratio_high=${clip_ratio_high} \
92 | algorithm.filter_groups.enable=${enable_filter_groups} \
93 | algorithm.filter_groups.max_num_gen_batches=${max_num_gen_batches} \
94 | algorithm.filter_groups.metric=${filter_groups_metric} \
95 | actor_rollout_ref.model.use_remove_padding=True \
96 | actor_rollout_ref.neg_weight=${neg_weight} \
97 | actor_rollout_ref.actor.ratio_type=${ratio_type} \
98 | actor_rollout_ref.actor.ppo_epochs=${ppo_epoch} \
99 | actor_rollout_ref.bugged_dynamic_scale=${bugged_dynamic_scale} \
100 | actor_rollout_ref.clamp_negative=${clamp_negative} \
101 | actor_rollout_ref.clamp_positive=${clamp_positive} \
102 | actor_rollout_ref.normalize=${normalize} \
103 | actor_rollout_ref.actor.use_dynamic_bsz=${use_dynamic_bsz} \
104 | actor_rollout_ref.ref.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \
105 | actor_rollout_ref.rollout.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \
106 | actor_rollout_ref.actor.ppo_max_token_len_per_gpu=${actor_ppo_max_token_len} \
107 | actor_rollout_ref.ref.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \
108 | actor_rollout_ref.rollout.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \
109 | actor_rollout_ref.model.path="${MODEL_PATH}" \
110 | +actor_rollout_ref.model.override_config.attention_dropout=0. \
111 | +actor_rollout_ref.model.override_config.embd_pdrop=0. \
112 | +actor_rollout_ref.model.override_config.resid_pdrop=0. \
113 | actor_rollout_ref.model.enable_gradient_checkpointing=True \
114 | actor_rollout_ref.actor.optim.lr=1e-6 \
115 | actor_rollout_ref.actor.optim.lr_warmup_steps=10 \
116 | actor_rollout_ref.actor.optim.weight_decay=0.1 \
117 | actor_rollout_ref.actor.ppo_mini_batch_size=${train_prompt_mini_bsz} \
118 | actor_rollout_ref.actor.fsdp_config.param_offload=${actor_offload} \
119 | actor_rollout_ref.actor.fsdp_config.optimizer_offload=${actor_offload} \
120 | actor_rollout_ref.actor.entropy_coeff=0 \
121 | actor_rollout_ref.actor.grad_clip=1.0 \
122 | actor_rollout_ref.actor.use_token_level_loss=${use_token_level_loss} \
123 | actor_rollout_ref.actor.ulysses_sequence_parallel_size=${sp_size} \
124 | actor_rollout_ref.rollout.gpu_memory_utilization=0.60 \
125 | actor_rollout_ref.rollout.tensor_model_parallel_size=${gen_tp} \
126 | actor_rollout_ref.rollout.enable_chunked_prefill=True \
127 | actor_rollout_ref.rollout.max_num_batched_tokens=$((max_prompt_length + max_response_length)) \
128 | actor_rollout_ref.rollout.val_kwargs.top_k="${val_top_k}" \
129 | actor_rollout_ref.rollout.val_kwargs.top_p=0.7 \
130 | actor_rollout_ref.rollout.val_kwargs.temperature=0.6 \
131 | actor_rollout_ref.rollout.val_kwargs.n=1 \
132 | actor_rollout_ref.rollout.val_kwargs.do_sample=True \
133 | actor_rollout_ref.ref.fsdp_config.param_offload=${ref_offload} \
134 | actor_rollout_ref.ref.ulysses_sequence_parallel_size=${sp_size} \
135 | actor_rollout_ref.actor.fsdp_config.fsdp_size=-1 \
136 | custom_reward_function.overlong_buffer.enable=${enable_overlong_buffer} \
137 | custom_reward_function.overlong_buffer.len=${overlong_buffer_len} \
138 | custom_reward_function.overlong_buffer.penalty_factor=${overlong_penalty_factor} \
139 | trainer.logger=['console','wandb'] \
140 | trainer.project_name="${project_name}" \
141 | trainer.experiment_name="${exp_name}" \
142 | trainer.n_gpus_per_node=8 \
143 | trainer.nnodes="${NNODES}" \
144 | +trainer.val_before_train=False \
145 | trainer.test_freq=2 \
146 | trainer.save_freq=2 \
147 | trainer.total_epochs=100 \
148 | trainer.default_local_dir="${CKPTS_DIR}" \
149 | trainer.resume_mode=auto
--------------------------------------------------------------------------------
/main_nft.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | # SPDX-License-Identifier: Apache-2.0
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | # Copyright 2024 Bytedance Ltd. and/or its affiliates
17 | #
18 | # Licensed under the Apache License, Version 2.0 (the "License");
19 | # you may not use this file except in compliance with the License.
20 | # You may obtain a copy of the License at
21 | #
22 | # http://www.apache.org/licenses/LICENSE-2.0
23 | #
24 | # Unless required by applicable law or agreed to in writing, software
25 | # distributed under the License is distributed on an "AS IS" BASIS,
26 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
27 | # See the License for the specific language governing permissions and
28 | # limitations under the License.
29 | """
 30 | Note that we keep this main entry point separate from ray_trainer, since ray_trainer is also used by other main scripts.
31 | """
32 |
33 | import ray
34 | import hydra
35 |
36 |
37 | def get_custom_reward_fn(config):
38 | import importlib.util, os
39 |
40 | reward_fn_config = config.get("custom_reward_function") or {}
41 | file_path = reward_fn_config.get("path")
42 | if not file_path:
43 | return None
44 |
45 | if not os.path.exists(file_path):
46 | raise FileNotFoundError(f"Reward function file '{file_path}' not found.")
47 |
48 | spec = importlib.util.spec_from_file_location("custom_module", file_path)
49 | module = importlib.util.module_from_spec(spec)
50 | try:
51 | spec.loader.exec_module(module)
52 | except Exception as e:
53 | raise RuntimeError(f"Error loading module from '{file_path}': {e}")
54 |
55 | function_name = reward_fn_config.get("name")
56 |
57 | if not hasattr(module, function_name):
58 | raise AttributeError(f"Reward function '{function_name}' not found in '{file_path}'.")
59 |
60 | print(f"using customized reward function '{function_name}' from '{file_path}'")
61 |
62 | return getattr(module, function_name)
63 |
64 |
65 | @hydra.main(config_path='config', config_name='nft_trainer', version_base=None)
66 | def main(config):
67 | run_nft(config)
68 |
69 |
70 | def run_nft(config) -> None:
71 |
72 | if not ray.is_initialized():
73 | # this is for local ray cluster
74 | ray.init(runtime_env={
75 | 'env_vars': {
76 | 'TOKENIZERS_PARALLELISM': 'true',
77 | 'NCCL_DEBUG': 'WARN',
78 | 'VLLM_LOGGING_LEVEL': 'WARN'
79 | }
80 | })
81 |
82 | ray.get(main_task.remote(config))
83 |
84 |
85 | @ray.remote(num_cpus=1) # please make sure main_task is not scheduled on head
86 | def main_task(config):
87 | import verl.workers.actor
88 | import verl.workers.actor.dp_actor
89 | import verl.utils.reward_score
90 | # override the original definition of DataParallelPPOActor
91 | import dp_actor
92 | verl.workers.actor.DataParallelPPOActor = dp_actor.DataParallelPPOActor
93 | verl.workers.actor.dp_actor.DataParallelPPOActor = dp_actor.DataParallelPPOActor
94 |
95 | import verl.utils.reward_score.math_verify
96 | import verifier
97 | verl.utils.reward_score._default_compute_score = verifier._default_compute_score
98 | verl.utils.reward_score.math_verify.compute_score = verifier.math_verify_compute_score
99 |
100 | from ray_nft_trainer import RayNFTTrainer
101 | from verl.utils.fs import copy_to_local
102 | # print initial config
103 | from pprint import pprint
104 | from omegaconf import OmegaConf
105 | pprint(OmegaConf.to_container(config, resolve=True)) # resolve=True will eval symbol values
106 | OmegaConf.resolve(config)
107 |
108 | # download the checkpoint from hdfs
109 | local_path = copy_to_local(config.actor_rollout_ref.model.path)
110 |
111 | # instantiate tokenizer
112 | from verl.utils import hf_tokenizer, hf_processor
113 | tokenizer = hf_tokenizer(local_path)
114 | processor = hf_processor(local_path, use_fast=True) # used for multimodal LLM, could be none
115 |
116 | # define worker classes
117 | if config.actor_rollout_ref.actor.strategy == 'fsdp':
118 | assert config.actor_rollout_ref.actor.strategy == config.critic.strategy
119 | from fsdp_workers import ActorRolloutRefWorker, CriticWorker
120 | from verl.single_controller.ray import RayWorkerGroup
121 | ray_worker_group_cls = RayWorkerGroup
122 |
123 | elif config.actor_rollout_ref.actor.strategy == 'megatron':
124 | assert config.actor_rollout_ref.actor.strategy == config.critic.strategy
125 | from verl.workers.megatron_workers import ActorRolloutRefWorker, CriticWorker
126 | from verl.single_controller.ray.megatron import NVMegatronRayWorkerGroup
127 | ray_worker_group_cls = NVMegatronRayWorkerGroup
128 |
129 | else:
130 | raise NotImplementedError
131 |
132 | from ray_nft_trainer import ResourcePoolManager, Role
133 |
134 | role_worker_mapping = {
135 | Role.ActorRollout: ray.remote(ActorRolloutRefWorker),
136 | Role.Critic: ray.remote(CriticWorker),
137 | Role.RefPolicy: ray.remote(ActorRolloutRefWorker)
138 | }
139 |
140 | global_pool_id = 'global_pool'
141 | resource_pool_spec = {
142 | global_pool_id: [config.trainer.n_gpus_per_node] * config.trainer.nnodes,
143 | }
144 | mapping = {
145 | Role.ActorRollout: global_pool_id,
146 | Role.Critic: global_pool_id,
147 | Role.RefPolicy: global_pool_id,
148 | }
149 |
150 | # we should adopt a multi-source reward function here
151 | # - for rule-based rm, we directly call a reward score
152 | # - for model-based rm, we call a model
153 | # - for code related prompt, we send to a sandbox if there are test cases
154 | # - finally, we combine all the rewards together
155 | # - The reward type depends on the tag of the data
156 | if config.reward_model.enable:
157 | if config.reward_model.strategy == 'fsdp':
158 | from fsdp_workers import RewardModelWorker
159 | elif config.reward_model.strategy == 'megatron':
160 | from verl.workers.megatron_workers import RewardModelWorker
161 | else:
162 | raise NotImplementedError
163 | role_worker_mapping[Role.RewardModel] = ray.remote(RewardModelWorker)
164 | mapping[Role.RewardModel] = global_pool_id
165 |
166 | reward_manager_name = config.reward_model.get("reward_manager", "naive")
167 | if reward_manager_name == 'naive':
168 | from verl.workers.reward_manager import NaiveRewardManager
169 | reward_manager_cls = NaiveRewardManager
170 | elif reward_manager_name == 'prime':
171 | from verl.workers.reward_manager import PrimeRewardManager
172 | reward_manager_cls = PrimeRewardManager
173 | else:
174 | raise NotImplementedError
175 |
176 | compute_score = get_custom_reward_fn(config)
177 | reward_fn = reward_manager_cls(tokenizer=tokenizer,
178 | num_examine=0,
179 | compute_score=compute_score,
180 | reward_fn_key=config.data.reward_fn_key,
181 | max_resp_len=config.data.max_response_length,
182 | overlong_buffer_cfg=config.custom_reward_function.overlong_buffer)
183 |
184 | # Note that we always use function-based RM for validation
185 | val_reward_fn = reward_manager_cls(tokenizer=tokenizer,
186 | num_examine=1,
187 | compute_score=compute_score,
188 | reward_fn_key=config.data.reward_fn_key,
189 | max_resp_len=config.data.max_response_length,
190 | overlong_buffer_cfg=config.custom_reward_function.overlong_buffer)
191 |
192 | resource_pool_manager = ResourcePoolManager(resource_pool_spec=resource_pool_spec, mapping=mapping)
193 |
194 | trainer = RayNFTTrainer(config=config,
195 | tokenizer=tokenizer,
196 | processor=processor,
197 | role_worker_mapping=role_worker_mapping,
198 | resource_pool_manager=resource_pool_manager,
199 | ray_worker_group_cls=ray_worker_group_cls,
200 | reward_fn=reward_fn,
201 | val_reward_fn=val_reward_fn)
202 | trainer.init_workers()
203 | trainer.fit()
204 |
205 |
206 | if __name__ == '__main__':
207 | main()
208 |
--------------------------------------------------------------------------------
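`get_custom_reward_fn` loads a reward function from `custom_reward_function.path` by the name given in `custom_reward_function.name` (default `compute_score` in `config/nft_trainer.yaml`). A minimal sketch of such a file; the file name and the toy matching rule are hypothetical, and the signature is assumed to mirror `_default_compute_score` in `verifier.py`:

```python
# my_reward.py (hypothetical) -- point custom_reward_function.path at this file.
def compute_score(data_source, solution_str, ground_truth, extra_info=None):
    """Toy rule-based reward: 1.0 if the reference answer appears verbatim in the solution."""
    return 1.0 if str(ground_truth) in solution_str else 0.0
```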
/config/nft_trainer.yaml:
--------------------------------------------------------------------------------
1 | data:
2 | tokenizer: null
3 | train_files: ~/data/rlhf/gsm8k/train.parquet
4 | val_files: ~/data/rlhf/gsm8k/test.parquet
5 | prompt_key: prompt
6 | reward_fn_key: data_source
7 | max_prompt_length: 512
8 | max_response_length: 512
9 | gen_batch_size: ${data.train_batch_size}
10 | train_batch_size: 1024
11 | val_batch_size: null # DEPRECATED: Validation datasets are sent to inference engines as a whole batch, which will schedule the memory themselves
12 | return_raw_input_ids: False # This should be set to true when the tokenizer between policy and rm differs
13 | return_raw_chat: False
14 | shuffle: True
 15 |   filter_overlong_prompts: False # for large-scale datasets, filtering overlong prompts can be time-consuming; you may disable this and set truncation='left' instead
16 | truncation: error
17 | image_key: images
18 |
19 | actor_rollout_ref:
20 | hybrid_engine: True
21 | bugged_dynamic_scale: 0
22 | neg_weight: -1.0
23 | clamp_positive: 0.0
24 | clamp_negative: 1.0
25 | normalize: 1
26 | model:
27 | path: ~/models/deepseek-llm-7b-chat
28 | external_lib: null
29 | override_config: { }
30 | enable_gradient_checkpointing: True
31 | use_remove_padding: False
32 | actor:
33 | strategy: fsdp # This is for backward-compatibility
34 | ppo_mini_batch_size: 256
35 | ppo_micro_batch_size: null # will be deprecated, use ppo_micro_batch_size_per_gpu
36 | ppo_micro_batch_size_per_gpu: null
37 | neg_weight: ${actor_rollout_ref.neg_weight}
38 | clamp_positive: ${actor_rollout_ref.clamp_positive}
39 | normalize: ${actor_rollout_ref.normalize}
40 | clamp_negative: ${actor_rollout_ref.clamp_negative}
41 | bugged_dynamic_scale: ${actor_rollout_ref.bugged_dynamic_scale}
42 | max_response_length: ${data.max_response_length}
43 | use_dynamic_bsz: False
44 | ratio_type: token
45 | ppo_max_token_len_per_gpu: 16384 # n * ${data.max_prompt_length} + ${data.max_response_length}
46 | grad_clip: 1.0
47 | # pg_losses2 = -advantages * torch.clamp(ratio, 1 - cliprange_low, 1 + cliprange_high)
48 | clip_ratio: 0.2 # default value if clip_ratio_low and clip_ratio_high are not specified
49 | clip_ratio_low: 0.2
50 | clip_ratio_high: 0.2
51 | use_token_level_loss: True
52 | entropy_coeff: 0.001
53 | use_kl_loss: False # True for GRPO
54 | use_torch_compile: True # False to disable torch compile
55 | kl_loss_coef: 0.001 # for grpo
56 | kl_loss_type: low_var_kl # for grpo
57 | ppo_epochs: 1
58 | shuffle: False
59 | ulysses_sequence_parallel_size: 1 # sp size
60 | optim:
61 | lr: 1e-6
62 | lr_warmup_steps: -1 # Prioritized. Negative values mean delegating to lr_warmup_steps_ratio.
63 | lr_warmup_steps_ratio: 0. # the total steps will be injected during runtime
64 | min_lr_ratio: null # only useful for warmup with cosine
65 | warmup_style: constant # select from constant/cosine
 66 |       total_training_steps: -1 # must be overridden by the program
67 | weight_decay: 0.01
68 | fsdp_config:
69 | wrap_policy:
70 | # transformer_layer_cls_to_wrap: None
71 | min_num_params: 0
72 | param_offload: False
73 | optimizer_offload: False
74 | fsdp_size: -1
75 | ref:
76 | fsdp_config:
77 | param_offload: False
78 | wrap_policy:
79 | # transformer_layer_cls_to_wrap: None
80 | min_num_params: 0
81 | log_prob_micro_batch_size: null # will be deprecated, use log_prob_micro_batch_size_per_gpu
82 | log_prob_micro_batch_size_per_gpu: null
83 | log_prob_use_dynamic_bsz: ${actor_rollout_ref.actor.use_dynamic_bsz}
84 | log_prob_max_token_len_per_gpu: ${actor_rollout_ref.actor.ppo_max_token_len_per_gpu}
85 | ulysses_sequence_parallel_size: ${actor_rollout_ref.actor.ulysses_sequence_parallel_size} # sp size
86 | rollout:
87 | name: vllm
88 | temperature: 1.0
89 | top_k: -1 # 0 for hf rollout, -1 for vllm rollout
90 | top_p: 1
91 | use_fire_sampling: False # https://arxiv.org/abs/2410.21236
 92 |     prompt_length: ${data.max_prompt_length} # not used by the open-source rollout
93 | response_length: ${data.max_response_length}
94 | # for vllm rollout
95 | dtype: bfloat16 # should align with FSDP
96 | gpu_memory_utilization: 0.5
97 | ignore_eos: False
98 | enforce_eager: True
99 | free_cache_engine: True
100 | load_format: dummy_dtensor
101 | tensor_model_parallel_size: 2
102 | max_num_batched_tokens: 8192
103 | max_model_len: null
104 | max_num_seqs: 1024
105 | log_prob_micro_batch_size: null # will be deprecated, use log_prob_micro_batch_size_per_gpu
106 | log_prob_micro_batch_size_per_gpu: null
107 | log_prob_use_dynamic_bsz: ${actor_rollout_ref.actor.use_dynamic_bsz}
108 | log_prob_max_token_len_per_gpu: ${actor_rollout_ref.actor.ppo_max_token_len_per_gpu}
109 | disable_log_stats: True
110 |     enable_chunked_prefill: True # may give higher throughput when set to True; when activated, please increase max_num_batched_tokens or decrease max_model_len.
111 | # for hf rollout
112 | do_sample: True
113 | # number of responses (i.e. num sample times)
114 | n: 1 # > 1 for grpo
115 | val_kwargs:
116 | # sampling parameters for validation
117 | top_k: -1 # 0 for hf rollout, -1 for vllm rollout
118 | top_p: 1.0
119 | temperature: 0.6
120 | n: 1
121 | do_sample: False # default eager for validation
122 |
123 | critic:
124 | strategy: fsdp
125 | optim:
126 | lr: 1e-5
127 | lr_warmup_steps_ratio: 0. # the total steps will be injected during runtime
128 | min_lr_ratio: null # only useful for warmup with cosine
129 | warmup_style: constant # select from constant/cosine
130 |     total_training_steps: -1 # must be overridden by the program
131 | weight_decay: 0.01
132 | model:
133 | path: ~/models/deepseek-llm-7b-chat
134 | tokenizer_path: ${actor_rollout_ref.model.path}
135 | override_config: { }
136 | external_lib: ${actor_rollout_ref.model.external_lib}
137 | enable_gradient_checkpointing: True
138 | use_remove_padding: False
139 | fsdp_config:
140 | param_offload: False
141 | optimizer_offload: False
142 | wrap_policy:
143 | # transformer_layer_cls_to_wrap: None
144 | min_num_params: 0
145 | fsdp_size: -1
146 | ppo_mini_batch_size: ${actor_rollout_ref.actor.ppo_mini_batch_size}
147 | ppo_micro_batch_size: null # will be deprecated, use ppo_micro_batch_size_per_gpu
148 | ppo_micro_batch_size_per_gpu: null
149 | forward_micro_batch_size: ${critic.ppo_micro_batch_size}
150 | forward_micro_batch_size_per_gpu: ${critic.ppo_micro_batch_size_per_gpu}
151 | use_dynamic_bsz: ${actor_rollout_ref.actor.use_dynamic_bsz}
152 | ppo_max_token_len_per_gpu: 32768 # (${actor_rollout_ref.actor.ppo_max_token_len_per_gpu}) * 2
153 | forward_max_token_len_per_gpu: ${critic.ppo_max_token_len_per_gpu}
154 | ulysses_sequence_parallel_size: 1 # sp size
155 | ppo_epochs: ${actor_rollout_ref.actor.ppo_epochs}
156 | shuffle: ${actor_rollout_ref.actor.shuffle}
157 | grad_clip: 1.0
158 | cliprange_value: 0.5
159 |
160 | reward_model:
161 | enable: False
162 | strategy: fsdp
163 | model:
164 | input_tokenizer: ${actor_rollout_ref.model.path} # set this to null if the chat template is identical
165 | path: ~/models/FsfairX-LLaMA3-RM-v0.1
166 | external_lib: ${actor_rollout_ref.model.external_lib}
167 | use_remove_padding: False
168 | fsdp_config:
169 | wrap_policy:
170 | min_num_params: 0
171 | param_offload: False
172 | fsdp_size: -1
173 | micro_batch_size: null # will be deprecated, use micro_batch_size_per_gpu
174 | micro_batch_size_per_gpu: null # set a number
175 | max_length: null
176 | ulysses_sequence_parallel_size: 1 # sp size
177 | use_dynamic_bsz: ${critic.use_dynamic_bsz}
178 | forward_max_token_len_per_gpu: ${critic.forward_max_token_len_per_gpu}
179 | reward_manager: naive
180 |
181 | custom_reward_function:
182 | path: null
183 | name: compute_score
184 | overlong_buffer:
185 | enable: False # We try to avoid forgetting to set enable
186 | len: 0
187 | penalty_factor: 0.0
188 | log: False
189 |
190 | algorithm:
191 | gamma: 1.0
192 | lam: 1.0
193 | adv_estimator: gae
194 | kl_penalty: kl # how to estimate kl divergence
195 | kl_ctrl:
196 | type: fixed
197 | kl_coef: 0.001
198 | filter_groups:
199 | enable: False # We try to avoid forgetting to set enable
200 | metric: null # acc / score / seq_reward / seq_final_reward / ...
201 | max_num_gen_batches: 0 # Non-positive values mean no upper limit
202 |
203 | trainer:
204 | val_only: False
205 | balance_batch: True
206 | total_epochs: 30
207 | total_training_steps: null
208 | project_name: verl_examples
209 | experiment_name: gsm8k
210 | logger: [ 'console', 'wandb' ]
211 | val_generations_to_log_to_wandb: 0
212 | nnodes: 1
213 | n_gpus_per_node: 8
214 | save_freq: -1
215 | # auto: find the last ckpt to resume. If can't find, start from scratch
216 |   resume_mode: auto # or disable, or resume_path together with resume_from_path
217 | resume_from_path: False
218 | test_freq: -1
219 | critic_warmup: 0
220 | default_hdfs_dir: null
221 | remove_previous_ckpt_in_save: True
222 | del_local_ckpt_after_load: False
223 | default_local_dir: checkpoints/${trainer.project_name}/${trainer.experiment_name}
224 |
--------------------------------------------------------------------------------
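All launch scripts override this YAML through Hydra's command-line syntax; as a reminder of the two forms in use (illustrative fragment only, the full argument lists live in the `.sh` scripts):

```bash
# key=value    overrides an entry that already exists in config/nft_trainer.yaml
# +key=value   adds a key that is absent from the YAML
#              (e.g. +trainer.val_before_train=False, +model.trust_remote_code=True)
python3 -m main_nft actor_rollout_ref.neg_weight=1.0 +trainer.val_before_train=False  # ... remaining overrides as in train_7B.sh
```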
/qwen_math_eval_toolkit/utils.py:
--------------------------------------------------------------------------------
  1 | import os
  2 | import json
  3 | import random
  4 | import numpy as np
  5 | from pathlib import Path
  6 | from typing import Iterable, Union, Any
9 |
10 | from qwen_math_eval_toolkit.examples import get_examples
11 |
12 |
13 | def set_seed(seed: int = 42) -> None:
14 | np.random.seed(seed)
15 | random.seed(seed)
16 | os.environ["PYTHONHASHSEED"] = str(seed)
17 | print(f"Random seed set as {seed}")
18 |
19 |
20 | def load_jsonl(file: Union[str, Path]) -> Iterable[Any]:
21 | with open(file, "r", encoding="utf-8") as f:
22 | for line in f:
23 | try:
24 | yield json.loads(line)
25 | except:
26 | print("Error in loading:", line)
27 | exit()
28 |
29 |
30 | def save_jsonl(samples, save_path):
31 | # ensure path
32 | folder = os.path.dirname(save_path)
33 | os.makedirs(folder, exist_ok=True)
34 |
35 | with open(save_path, "w", encoding="utf-8") as f:
36 | for sample in samples:
37 | f.write(json.dumps(sample, ensure_ascii=False) + "\n")
38 | print("Saved to", save_path)
39 |
40 |
41 | def lower_keys(example):
42 | new_example = {}
43 | for key, value in example.items():
44 | if key != key.lower():
45 | new_key = key.lower()
46 | new_example[new_key] = value
47 | else:
48 | new_example[key] = value
49 | return new_example
50 |
51 |
52 | EXAMPLES = get_examples()
53 |
54 |
55 | def load_prompt(data_name, prompt_type, num_shots):
56 | if not num_shots:
57 | return []
58 |
59 | if data_name in ["gsm_hard", "svamp", "tabmwp", "asdiv", "mawps"]:
60 | data_name = "gsm8k"
61 | if data_name in ["math_oai", "hungarian_exam", "math-oai", "aime24", "amc23"]:
62 | data_name = "math"
63 | if data_name in ["sat_math"]:
64 | data_name = "mmlu_stem"
65 | if data_name in [
66 | "gaokao2024_I",
67 | "gaokao2024_II",
68 | "gaokao_math_qa",
69 | "gaokao2024_mix",
70 | "cn_middle_school",
71 | ]:
72 | data_name = "gaokao"
73 |
74 | if prompt_type in ["tool-integrated"]:
75 | prompt_type = "tora"
76 |
77 | return EXAMPLES[data_name][:num_shots]
78 |
79 |
80 | PROMPT_TEMPLATES = {
81 | "direct": ("Question: {input}\nAnswer: ", "{output}", "\n\n"),
82 | "cot": ("Question: {input}\nAnswer: ", "{output}", "\n\n\n"),
83 | "pal": ("Question: {input}\n\n", "{output}", "\n---\n"),
84 | "tool-integrated": ("Question: {input}\n\nSolution:\n", "{output}", "\n---\n"),
85 | "self-instruct": ("<|user|>\n{input}\n<|assistant|>\n", "{output}", "\n"),
86 | "tora": ("<|user|>\n{input}\n<|assistant|>\n", "{output}", "\n"),
87 | "wizard_zs": (
88 | "### Instruction:\n{input}\n\n### Response: Let's think step by step.",
89 | "{output}",
90 | "\n\n\n",
91 | ),
92 | "platypus_fs": (
93 | "### Instruction:\n{input}\n\n### Response:\n",
94 | "{output}",
95 | "\n\n\n",
96 | ),
97 | "deepseek-math": (
98 | "User: {input}\nPlease reason step by step, "
99 | "and put your final answer within \\boxed{{}}.\n\nAssistant:",
100 | "{output}",
101 | "\n\n\n",
102 | ),
103 | "kpmath": (
104 | "User: Please reason step by step and put your final answer at the end "
105 | 'with "The answer is: ".\n\n{input}\n\nAssistant:',
106 | "{output}",
107 | ),
108 | "jiuzhang": (
109 | "## Question\n{input}\n\n## Solution\n",
110 | "{output}",
111 | "\n\n\n",
112 | ),
113 | "jiuzhang_tora": (
114 | "## Question\n{input}\n\n## Code Solution\n",
115 | "{output}",
116 | "\n\n\n",
117 | ),
118 | "jiuzhang_nl": (
119 | "## Question\n{input}\n\n## Natural Language Solution\n",
120 | "{output}",
121 | "\n\n\n",
122 | ),
123 | "mmiqc": (
124 | 'Please solve the following problem and put your answer at the end with "The answer is: ".\n\n{input}\n\n',
125 | "{output}",
126 | "\n\n\n",
127 | ),
128 | "abel": (
129 | "Question:\n{input}\nAnswer:\nLet's think step by step.\n",
130 | "{output}",
131 | "\n\n",
132 | ),
133 | "shepherd": ("{input}\n", "{output}", "\n\n\n"),
134 | "qwen-boxed": (
135 | "<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n"
136 | "<|im_start|>user\n{input}\nPlease reason step by step, and put your final answer within \\boxed{{}}.<|im_end|>\n"
137 | "<|im_start|>assistant\n",
138 | "{output}",
139 | "\n\n",
140 | ),
141 | "qwen25-math-cot": (
142 | "<|im_start|>system\nPlease reason step by step, and put your final answer within \\boxed{{}}.<|im_end|>\n"
143 | "<|im_start|>user\n{input}<|im_end|>\n"
144 | "<|im_start|>assistant\n",
145 | "{output}",
146 | "\n\n",
147 | ),
148 | "mathstral": (
149 | "{input}\nPlease reason step by step, and put your final answer within \\boxed{{}}.",
150 | "{output}",
151 | "\n\n",
152 | ),
153 | "internlm-math-fs": ("Question:{input}\nAnswer:", "{output}", "\n"),
154 | "internlm-math-chat": (
155 | "<|im_start|>user\n{input}<|im_end|>\n" "<|im_start|>assistant\n",
156 | "{output}",
157 | "\n\n",
158 | ),
159 | "mistral": (
160 | "[INST] {input}[/INST]",
161 | "{output}",
162 | "\n\n",
163 | ),
164 | "numina": ("### Problem: {input}\n### Solution:", " {output}", "\n\n"),
165 | "o1_cot": (
166 | '[Round 0] USER:\n{input}\nPlease reason step by step, and put your final answer within \\boxed{{}}. ASSISTANT:\n',
167 | "{output}",
168 | "\n\n"
169 | )
170 | }
171 |
172 |
173 | def construct_prompt(example, data_name, args):
174 | if args.adapt_few_shot and data_name in [
175 | "gaokao2024_I",
176 | "gaokao2024_II",
177 | "gaokao_math_qa",
178 | "gaokao2024_mix",
179 | "cn_middle_school",
180 | ]:
181 | demos = load_prompt(data_name, args.prompt_type, 5)
182 | else:
183 | demos = load_prompt(data_name, args.prompt_type, args.num_shots)
184 | prompt_type = args.prompt_type
185 | if prompt_type == "platypus_fs":
186 | prompt_type = "cot"
187 | if prompt_type == "tool-integrated":
188 | prompt_type = "tora"
189 |
190 | prompt_temp = PROMPT_TEMPLATES[args.prompt_type]
191 |
192 | splitter = prompt_temp[2]
193 | input_template, output_template, splitter = (
194 | prompt_temp[0],
195 | prompt_temp[1],
196 | prompt_temp[2],
197 | )
198 | if args.prompt_type == "qwen25-math-cot":
199 | # Hotfix to support putting all demos into a single turn
200 | demo_prompt = splitter.join([q + "\n" + a for q, a in demos])
201 | else:
202 | demo_prompt = splitter.join(
203 | [
204 | input_template.format(input=q) + output_template.format(output=a)
205 | for q, a in demos
206 | ]
207 | )
208 | context = input_template.format(input=example["question"])
209 | if len(demo_prompt) == 0 or (
210 | args.adapt_few_shot and example["gt_ans"] not in ["A", "B", "C", "D", "E"]
211 | ):
212 | full_prompt = context
213 | else:
214 | if args.prompt_type == "qwen25-math-cot":
215 |             # Hotfix to support putting all demos into a single turn
216 | full_prompt = demo_prompt + splitter + example["question"]
217 | full_prompt = input_template.format(input=full_prompt)
218 | else:
219 | full_prompt = demo_prompt + splitter + context
220 |
221 | if args.prompt_type == "platypus_fs":
222 | full_prompt_temp = (
223 | "Below is an instruction that describes a task. "
224 | "Write a response that appropriately completes the request.\n\n"
225 | "### Instruction:\n{instruction}\n\n### Response:\n"
226 | )
227 | full_prompt = full_prompt_temp.format(instruction=full_prompt)
228 |
229 | if prompt_type == "tora":
230 | full_prompt = (
231 | """Integrate step-by-step reasoning and Python code to solve math problems using the following guidelines:
232 |
233 | - Analyze the question and write functions to solve the problem; the function should not take any arguments.
234 | - Present the final result in LaTeX using a `\boxed{}` without any units.
235 | - Utilize the `pi` symbol and `Rational`` from Sympy for $\pi$ and fractions, and simplify all fractions and square roots without converting them to decimal values.
236 |
237 | Here are some examples you may refer to:
238 |
239 | ---
240 |
241 | """
242 | + full_prompt
243 | )
244 |
245 | return full_prompt.strip(" ") # important!
246 |
247 |
248 | key_map = {
249 | "gt": "Ground Truth",
250 | "pred": "Prediction",
251 | "gt_cot": "Reference CoT",
252 | "score": "Score",
253 | }
254 |
255 |
256 | def show_sample(sample, print_all_preds=False):
257 | print("==" * 20)
258 | for key in ["idx", "type", "level", "dataset"]:
259 | if key in sample:
260 | # capitalize
261 | print("{}: {}".format(key[0].upper() + key[1:], sample[key]))
262 | print("Question:", repr(sample["question"]))
263 | if "code" in sample:
264 | if print_all_preds:
265 | for code in sample["code"]:
266 | print("-" * 20)
267 | print("code:", code)
268 | print("Execution:", sample["report"])
269 | else:
270 | print("Solution:\n", sample["code"][0])
271 | print("Execution:", sample["report"][0])
272 | if "pred" in sample:
273 | print("Prediction:", repr(sample["pred"][0]))
274 | for key in ["gt", "score", "unit", "gt_cot"]:
275 | if key in sample:
276 | _key = key_map.get(key, key)
277 | print("{}: {}".format(_key, repr(sample[key])))
278 | print()
279 |
--------------------------------------------------------------------------------
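As a quick illustration of how the prompt utilities above fit together, here is a hypothetical zero-shot call to `construct_prompt` using the `qwen25-math-cot` template. The `SimpleNamespace` stands in for the real argparse namespace and the question is made up; with `num_shots=0` no demos are prepended, so the output is just the chat template wrapped around the question.

```python
# Hypothetical usage sketch of construct_prompt (zero-shot, qwen25-math-cot).
from types import SimpleNamespace

from qwen_math_eval_toolkit.utils import construct_prompt

args = SimpleNamespace(prompt_type="qwen25-math-cot", num_shots=0, adapt_few_shot=False)
example = {"question": "What is 2 + 2?"}

prompt = construct_prompt(example, data_name="math", args=args)
print(prompt)
# <|im_start|>system
# Please reason step by step, and put your final answer within \boxed{}.<|im_end|>
# <|im_start|>user
# What is 2 + 2?<|im_end|>
# <|im_start|>assistant
```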
/README_VeRL.md:
--------------------------------------------------------------------------------
1 | # verl: Volcano Engine Reinforcement Learning for LLM
2 |
3 | verl is a flexible, efficient and production-ready RL training library for large language models (LLMs).
4 |
5 | verl is the open-source version of **[HybridFlow: A Flexible and Efficient RLHF Framework](https://arxiv.org/abs/2409.19256v2)** paper.
6 |
7 | verl is flexible and easy to use with:
8 |
9 | - **Easy extension of diverse RL algorithms**: The Hybrid programming model combines the strengths of single-controller and multi-controller paradigms to enable flexible representation and efficient execution of complex Post-Training dataflows. This allows users to build RL dataflows in just a few lines of code.
10 |
11 | - **Seamless integration of existing LLM infra with modular APIs**: Decouples computation and data dependencies, enabling seamless integration with existing LLM frameworks, such as PyTorch FSDP, Megatron-LM and vLLM. Moreover, users can easily extend to other LLM training and inference frameworks.
12 |
13 | - **Flexible device mapping**: Supports various placement of models onto different sets of GPUs for efficient resource utilization and scalability across different cluster sizes.
14 |
15 | - Ready integration with popular HuggingFace models
16 |
17 |
18 | verl is fast with:
19 |
20 | - **State-of-the-art throughput**: By seamlessly integrating existing SOTA LLM training and inference frameworks, verl achieves high generation and training throughput.
21 |
22 | - **Efficient actor model resharding with 3D-HybridEngine**: Eliminates memory redundancy and significantly reduces communication overhead during transitions between training and generation phases.
23 |
24 |
25 |
26 |
27 |
28 |
29 |
30 | ## News
31 | - [2025/3] We will present verl(HybridFlow) at [EuroSys 2025](https://2025.eurosys.org/). See you in Rotterdam!
32 | - [2025/3] We will introduce the programming model of verl at the [vLLM Beijing Meetup](https://mp.weixin.qq.com/s/n77GibL2corAtQHtVEAzfg) on 3/16. See you in Beijing!
33 | - [2025/2] verl v0.2.0.post1 is released! See [release note](https://github.com/volcengine/verl/releases/) for details.
34 | - [2025/2] We presented verl in the [Bytedance/NVIDIA/Anyscale Ray Meetup](https://lu.ma/ji7atxux). See you in San Jose!
35 | - [2025/1] [Doubao-1.5-pro](https://team.doubao.com/zh/special/doubao_1_5_pro) is released with SOTA-level performance on LLM & VLM. The RL scaling preview model is trained using verl, reaching OpenAI O1-level performance on math benchmarks (70.0 pass@1 on AIME).
36 | - [2024/12] The team presented Post-training LLMs: From Algorithms to Infrastructure at NeurIPS 2024. [Slides](https://github.com/eric-haibin-lin/verl-data/tree/neurips) and [video](https://neurips.cc/Expo/Conferences/2024/workshop/100677) available.
37 | - [2024/12] verl is presented at Ray Forward 2024. Slides available [here](https://github.com/eric-haibin-lin/verl-community/blob/main/slides/Ray_Forward_2024_%E5%B7%AB%E9%94%A1%E6%96%8C.pdf).
38 | - [2024/10] verl is presented at Ray Summit. [Youtube video](https://www.youtube.com/watch?v=MrhMcXkXvJU&list=PLzTswPQNepXntmT8jr9WaNfqQ60QwW7-U&index=37) available.
39 | - [2024/08] HybridFlow (verl) is accepted to EuroSys 2025.
40 |
41 | ## Key Features
42 |
43 | - **FSDP** and **Megatron-LM** for training.
44 | - **vLLM** and **HF Transformers** for rollout generation, **SGLang** support coming soon.
45 | - Compatible with Hugging Face Transformers and Modelscope Hub.
46 | - Supervised fine-tuning.
47 | - Reinforcement learning with [PPO](examples/ppo_trainer/), [GRPO](examples/grpo_trainer/), [ReMax](examples/remax_trainer/), [Reinforce++](https://verl.readthedocs.io/en/latest/examples/config.html#algorithm), [RLOO](examples/rloo_trainer/), [PRIME](recipe/prime/), etc.
48 | - Support model-based reward and function-based reward (verifiable reward)
49 | - Support vision-language models (VLMs) and [multi-modal RL](examples/grpo_trainer/run_qwen2_5_vl-7b.sh)
50 | - Flash attention 2, [sequence packing](examples/ppo_trainer/run_qwen2-7b_seq_balance.sh), [sequence parallelism](examples/ppo_trainer/run_deepseek7b_llm_sp2.sh) support via DeepSpeed Ulysses, [LoRA](examples/sft/gsm8k/run_qwen_05_peft.sh), [Liger-kernel](examples/sft/gsm8k/run_qwen_05_sp2_liger.sh).
51 | - Scales up to 70B models and hundreds of GPUs.
52 | - Experiment tracking with wandb, swanlab, mlflow and tensorboard.
53 |
54 | ## Upcoming Features
55 | - Reward model training
56 | - DPO training
57 | - DeepSeek integration with Megatron v0.11
58 | - SGLang integration
59 |
60 | ## Getting Started
61 |
62 | **Quickstart:**
63 | - [Installation](https://verl.readthedocs.io/en/latest/start/install.html)
64 | - [Quickstart](https://verl.readthedocs.io/en/latest/start/quickstart.html)
65 | - [Programming Guide](https://verl.readthedocs.io/en/latest/hybrid_flow.html)
66 |
67 | **Running a PPO example step-by-step:**
68 | - Data and Reward Preparation
69 | - [Prepare Data for Post-Training](https://verl.readthedocs.io/en/latest/preparation/prepare_data.html)
70 | - [Implement Reward Function for Dataset](https://verl.readthedocs.io/en/latest/preparation/reward_function.html)
71 | - Understanding the PPO Example
72 | - [PPO Example Architecture](https://verl.readthedocs.io/en/latest/examples/ppo_code_architecture.html)
73 | - [Config Explanation](https://verl.readthedocs.io/en/latest/examples/config.html)
74 | - [Run GSM8K Example](https://verl.readthedocs.io/en/latest/examples/gsm8k_example.html)
75 |
76 | **Reproducible algorithm baselines:**
77 | - [PPO, GRPO, ReMax](https://verl.readthedocs.io/en/latest/experiment/ppo.html)
78 |
79 | **For code explanation and advance usage (extension):**
80 | - PPO Trainer and Workers
81 | - [PPO Ray Trainer](https://verl.readthedocs.io/en/latest/workers/ray_trainer.html)
82 | - [PyTorch FSDP Backend](https://verl.readthedocs.io/en/latest/workers/fsdp_workers.html)
83 | - [Megatron-LM Backend](https://verl.readthedocs.io/en/latest/index.html)
84 | - Advance Usage and Extension
85 | - [Ray API design tutorial](https://verl.readthedocs.io/en/latest/advance/placement.html)
86 | - [Extend to Other RL(HF) algorithms](https://verl.readthedocs.io/en/latest/advance/dpo_extension.html)
87 | - [Add Models with the FSDP Backend](https://verl.readthedocs.io/en/latest/advance/fsdp_extension.html)
88 | - [Add Models with the Megatron-LM Backend](https://verl.readthedocs.io/en/latest/advance/megatron_extension.html)
89 | - [Deployment using Separate GPU Resources](https://github.com/volcengine/verl/tree/main/examples/split_placement)
90 |
91 | **Blogs from the community**
92 | - [使用verl进行GRPO分布式强化学习训练最佳实践](https://www.volcengine.com/docs/6459/1463942)
93 | - [HybridFlow veRL 原文浅析](https://github.com/zhaochenyang20/Awesome-ML-SYS-Tutorial/blob/main/rlhf/verl/readme.md)
94 | - [最高提升20倍吞吐量!豆包大模型团队发布全新 RLHF 框架,现已开源!](https://team.doubao.com/en/blog/%E6%9C%80%E9%AB%98%E6%8F%90%E5%8D%8720%E5%80%8D%E5%90%9E%E5%90%90%E9%87%8F-%E8%B1%86%E5%8C%85%E5%A4%A7%E6%A8%A1%E5%9E%8B%E5%9B%A2%E9%98%9F%E5%8F%91%E5%B8%83%E5%85%A8%E6%96%B0-rlhf-%E6%A1%86%E6%9E%B6-%E7%8E%B0%E5%B7%B2%E5%BC%80%E6%BA%90)
95 |
96 | Check out this [Jupyter Notebook](https://github.com/volcengine/verl/tree/main/examples/ppo_trainer/verl_getting_started.ipynb) to get started with PPO training with a single 24GB L4 GPU (**FREE** GPU quota provided by [Lightning Studio](https://lightning.ai/hlin-verl/studios/verl-getting-started))!
97 |
98 | ## Performance Tuning Guide
99 | Performance is essential for on-policy RL algorithms. We provide a detailed performance tuning guide to help users tune their training setups. See [here](https://verl.readthedocs.io/en/latest/perf/perf_tuning.html) for more details.
100 |
101 | ## vLLM v0.7 integration preview
102 | We have released a testing version of veRL that supports vLLM>=0.7.0. Please refer to [this document](https://github.com/volcengine/verl/blob/main/docs/README_vllm0.7.md) for the installation guide and more information.
103 |
104 | ## Citation and acknowledgement
105 |
106 | If you find the project helpful, please cite:
107 | - [HybridFlow: A Flexible and Efficient RLHF Framework](https://arxiv.org/abs/2409.19256v2)
108 | - [A Framework for Training Large Language Models for Code Generation via Proximal Policy Optimization](https://i.cs.hku.hk/~cwu/papers/gmsheng-NL2Code24.pdf)
109 |
110 | ```tex
111 | @article{sheng2024hybridflow,
112 | title = {HybridFlow: A Flexible and Efficient RLHF Framework},
113 | author = {Guangming Sheng and Chi Zhang and Zilingfeng Ye and Xibin Wu and Wang Zhang and Ru Zhang and Yanghua Peng and Haibin Lin and Chuan Wu},
114 | year = {2024},
115 | journal = {arXiv preprint arXiv: 2409.19256}
116 | }
117 | ```
118 |
119 | verl is inspired by the design of Nemo-Aligner, Deepspeed-chat and OpenRLHF. The project is adopted and supported by Anyscale, Bytedance, LMSys.org, Shanghai AI Lab, Tsinghua University, UC Berkeley, UCLA, UIUC, University of Hong Kong, and many more.
120 |
121 | ## Awesome work using verl
122 | - [TinyZero](https://github.com/Jiayi-Pan/TinyZero): a reproduction of **DeepSeek R1 Zero** recipe for reasoning tasks
123 | - [PRIME](https://github.com/PRIME-RL/PRIME): Process reinforcement through implicit rewards
124 | - [RAGEN](https://github.com/ZihanWang314/ragen): a general-purpose reasoning **agent** training framework
125 | - [Logic-RL](https://github.com/Unakar/Logic-RL): a reproduction of DeepSeek R1 Zero on 2K Tiny Logic Puzzle Dataset.
126 | - [SkyThought](https://github.com/NovaSky-AI/SkyThought): RL training for Sky-T1-7B by NovaSky AI team.
127 | - [deepscaler](https://github.com/agentica-project/deepscaler): iterative context scaling with GRPO
128 | - [critic-rl](https://github.com/HKUNLP/critic-rl): LLM critics for code generation
129 | - [Easy-R1](https://github.com/hiyouga/EasyR1): **Multi-modal** RL training framework
130 | - [self-rewarding-reasoning-LLM](https://arxiv.org/pdf/2502.19613): self-rewarding and correction with **generative reward models**
131 | - [Search-R1](https://github.com/PeterGriffinJin/Search-R1): RL with reasoning and **searching (tool-call)** interleaved LLMs
132 | - [Code-R1](https://github.com/ganler/code-r1): Reproducing R1 for **Code** with Reliable Rewards
133 | - [DQO](https://arxiv.org/abs/2410.09302): Enhancing multi-step reasoning abilities of language models through direct Q-function optimization
134 | - [FIRE](https://arxiv.org/abs/2410.21236): Flaming-hot initiation with regular execution sampling for large language models
135 | - [ReSearch](https://github.com/Agent-RL/ReSearch): Learning to **Re**ason with **Search** for LLMs via Reinforcement Learning
136 | - [DeepRetrieval](https://github.com/pat-jj/DeepRetrieval): Let LLMs learn to **search** and **retrieve** desirable docs with RL
137 | - [cognitive-behaviors](https://github.com/kanishkg/cognitive-behaviors): Cognitive Behaviors that Enable Self-Improving Reasoners, or, Four Habits of Highly Effective STaRs
138 |
139 | ## Contribution Guide
140 | Contributions from the community are welcome! Please check out our [roadmap](https://github.com/volcengine/verl/issues/22) and [release plan](https://github.com/volcengine/verl/issues/354).
141 |
142 | ### Code formatting
143 | We use yapf (Google style) to enforce strict code formatting when reviewing PRs. To reformat your code locally, make sure you have installed the **latest** `yapf`:
144 | ```bash
145 | pip3 install yapf --upgrade
146 | ```
147 | Then, make sure you are at the top level of the verl repo and run:
148 | ```bash
149 | bash scripts/format.sh
150 | ```
151 | We are HIRING! Send us an [email](mailto:haibin.lin@bytedance.com) if you are interested in internship/FTE opportunities in MLSys/LLM reasoning/multimodal alignment.
152 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 |
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner. For the purposes of this definition, "submitted"
54 | means any form of electronic, verbal, or written communication sent
55 | to the Licensor or its representatives, including but not limited to
56 | communication on electronic mailing lists, source code control systems,
57 | and issue tracking systems that are managed by, or on behalf of, the
58 | Licensor for the purpose of discussing and improving the Work, but
59 | excluding communication that is conspicuously marked or otherwise
60 | designated in writing by the copyright owner as "Not a Contribution."
61 |
62 | "Contributor" shall mean Licensor and any individual or Legal Entity
63 | on behalf of whom a Contribution has been received by Licensor and
64 | subsequently incorporated within the Work.
65 |
66 | 2. Grant of Copyright License. Subject to the terms and conditions of
67 | this License, each Contributor hereby grants to You a perpetual,
68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 | copyright license to reproduce, prepare Derivative Works of,
70 | publicly display, publicly perform, sublicense, and distribute the
71 | Work and such Derivative Works in Source or Object form.
72 |
73 | 3. Grant of Patent License. Subject to the terms and conditions of
74 | this License, each Contributor hereby grants to You a perpetual,
75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 | (except as stated in this section) patent license to make, have made,
77 | use, offer to sell, sell, import, and otherwise transfer the Work,
78 | where such license applies only to those patent claims licensable
79 | by such Contributor that are necessarily infringed by their
80 | Contribution(s) alone or by combination of their Contribution(s)
81 | with the Work to which such Contribution(s) was submitted. If You
82 | institute patent litigation against any entity (including a
83 | cross-claim or counterclaim in a lawsuit) alleging that the Work
84 | or a Contribution incorporated within the Work constitutes direct
85 | or contributory patent infringement, then any patent licenses
86 | granted to You under this License for that Work shall terminate
87 | as of the date such litigation is filed.
88 |
89 | 4. Redistribution. You may reproduce and distribute copies of the
90 | Work or Derivative Works thereof in any medium, with or without
91 | modifications, and in Source or Object form, provided that You
92 | meet the following conditions:
93 |
94 | (a) You must give any other recipients of the Work or
95 | Derivative Works a copy of this License; and
96 |
97 | (b) You must cause any modified files to carry prominent notices
98 | stating that You changed the files; and
99 |
100 | (c) You must retain, in the Source form of any Derivative Works
101 | that You distribute, all copyright, patent, trademark, and
102 | attribution notices from the Source form of the Work,
103 | excluding those notices that do not pertain to any part of
104 | the Derivative Works; and
105 |
106 | (d) If the Work includes a "NOTICE" text file as part of its
107 | distribution, then any Derivative Works that You distribute must
108 | include a readable copy of the attribution notices contained
109 | within such NOTICE file, excluding those notices that do not
110 | pertain to any part of the Derivative Works, in at least one
111 | of the following places: within a NOTICE text file distributed
112 | as part of the Derivative Works; within the Source form or
113 | documentation, if provided along with the Derivative Works; or,
114 | within a display generated by the Derivative Works, if and
115 | wherever such third-party notices normally appear. The contents
116 | of the NOTICE file are for informational purposes only and
117 | do not modify the License. You may add Your own attribution
118 | notices within Derivative Works that You distribute, alongside
119 | or as an addendum to the NOTICE text from the Work, provided
120 | that such additional attribution notices cannot be construed
121 | as modifying the License.
122 |
123 | You may add Your own copyright statement to Your modifications and
124 | may provide additional or different license terms and conditions
125 | for use, reproduction, or distribution of Your modifications, or
126 | for any such Derivative Works as a whole, provided Your use,
127 | reproduction, and distribution of the Work otherwise complies with
128 | the conditions stated in this License.
129 |
130 | 5. Submission of Contributions. Unless You explicitly state otherwise,
131 | any Contribution intentionally submitted for inclusion in the Work
132 | by You to the Licensor shall be under the terms and conditions of
133 | this License, without any additional terms or conditions.
134 | Notwithstanding the above, nothing herein shall supersede or modify
135 | the terms of any separate license agreement you may have executed
136 | with Licensor regarding such Contributions.
137 |
138 | 6. Trademarks. This License does not grant permission to use the trade
139 | names, trademarks, service marks, or product names of the Licensor,
140 | except as required for reasonable and customary use in describing the
141 | origin of the Work and reproducing the content of the NOTICE file.
142 |
143 | 7. Disclaimer of Warranty. Unless required by applicable law or
144 | agreed to in writing, Licensor provides the Work (and each
145 | Contributor provides its Contributions) on an "AS IS" BASIS,
146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 | implied, including, without limitation, any warranties or conditions
148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 | PARTICULAR PURPOSE. You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
178 | APPENDIX: How to apply the Apache License to your work.
179 |
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "[]"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 |
189 | Copyright [yyyy] [name of copyright owner]
190 |
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 |
195 | http://www.apache.org/licenses/LICENSE-2.0
196 |
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
202 |
203 |
--------------------------------------------------------------------------------
/qwen_math_eval_toolkit/grader.py:
--------------------------------------------------------------------------------
1 | """
2 | This logic is largely copied from the Hendrycks' MATH release (math_equivalence), and borrowed from:
3 | - https://github.com/microsoft/ProphetNet/tree/master/CRITIC
4 | - https://github.com/openai/prm800k
5 | - https://github.com/microsoft/ToRA/blob/main/src/eval/grader.py
6 | - https://github.com/deepseek-ai/DeepSeek-Math/blob/main/evaluation/eval/eval_utils.py
7 | """
8 |
9 | import re
10 | import regex
11 | import multiprocessing
12 | from math import isclose
13 | from typing import Union
14 | from collections import defaultdict
15 |
16 | from sympy import simplify, N
17 | from sympy.parsing.sympy_parser import parse_expr
18 | from sympy.parsing.latex import parse_latex
19 | from latex2sympy2_extended import latex2sympy
20 |
21 | # from .parser import choice_answer_clean, strip_string
22 | # from parser import choice_answer_clean
23 |
24 |
25 | def choice_answer_clean(pred: str):
26 | pred = pred.strip("\n").rstrip(".").rstrip("/").strip(" ").lstrip(":")
27 | # Clean the answer based on the dataset
28 | tmp = re.findall(r"\b(A|B|C|D|E)\b", pred.upper())
29 | if tmp:
30 | pred = tmp
31 | else:
32 | pred = [pred.strip().strip(".")]
33 | pred = pred[-1]
34 | # Remove the period at the end, again!
35 | pred = pred.rstrip(".").rstrip("/")
36 | return pred
37 |
38 |
39 | def parse_digits(num):
40 | num = regex.sub(",", "", str(num))
41 | try:
42 | return float(num)
43 | except:
44 | if num.endswith("%"):
45 | num = num[:-1]
46 | if num.endswith("\\"):
47 | num = num[:-1]
48 | try:
49 | return float(num) / 100
50 | except:
51 | pass
52 | return None
53 |
54 |
55 | def is_digit(num):
56 | # paired with parse_digits
57 | return parse_digits(num) is not None
58 |
59 |
60 | def str_to_pmatrix(input_str):
61 | input_str = input_str.strip()
62 | matrix_str = re.findall(r"\{.*,.*\}", input_str)
63 | pmatrix_list = []
64 |
65 | for m in matrix_str:
66 | m = m.strip("{}")
67 | pmatrix = r"\begin{pmatrix}" + m.replace(",", "\\") + r"\end{pmatrix}"
68 | pmatrix_list.append(pmatrix)
69 |
70 | return ", ".join(pmatrix_list)
71 |
72 |
73 | def math_equal(
74 | prediction: Union[bool, float, str],
75 | reference: Union[float, str],
76 | include_percentage: bool = True,
77 | is_close: bool = True,
78 | timeout: bool = False,
79 | ) -> bool:
80 | """
81 | Exact match of math if and only if:
82 | 1. numerical equal: both can convert to float and are equal
83 | 2. symbolic equal: both can convert to sympy expression and are equal
84 | """
85 | # print("Judge:", prediction, reference)
86 | if prediction is None or reference is None:
87 | return False
88 | if str(prediction.strip().lower()) == str(reference.strip().lower()):
89 | return True
90 | if (
91 | reference in ["A", "B", "C", "D", "E"]
92 | and choice_answer_clean(prediction) == reference
93 | ):
94 | return True
95 |
96 | try: # 1. numerical equal
97 | if is_digit(prediction) and is_digit(reference):
98 | prediction = parse_digits(prediction)
99 | reference = parse_digits(reference)
100 | # number questions
101 | if include_percentage:
102 | gt_result = [reference / 100, reference, reference * 100]
103 | else:
104 | gt_result = [reference]
105 | for item in gt_result:
106 | try:
107 | if is_close:
108 | if numeric_equal(prediction, item):
109 | return True
110 | else:
111 | if item == prediction:
112 | return True
113 | except Exception:
114 | continue
115 | return False
116 | except:
117 | pass
118 |
119 | if not prediction and prediction not in [0, False]:
120 | return False
121 |
122 | # 2. symbolic equal
123 | reference = str(reference).strip()
124 | prediction = str(prediction).strip()
125 |
126 | ## pmatrix (amps)
127 | if "pmatrix" in prediction and not "pmatrix" in reference:
128 | reference = str_to_pmatrix(reference)
129 |
130 | ## deal with [], (), {}
131 | pred_str, ref_str = prediction, reference
132 | if (
133 | prediction.startswith("[")
134 | and prediction.endswith("]")
135 | and not reference.startswith("(")
136 | ) or (
137 | prediction.startswith("(")
138 | and prediction.endswith(")")
139 | and not reference.startswith("[")
140 | ):
141 | pred_str = pred_str.strip("[]()")
142 | ref_str = ref_str.strip("[]()")
143 | for s in ["{", "}", "(", ")"]:
144 | ref_str = ref_str.replace(s, "")
145 | pred_str = pred_str.replace(s, "")
146 | if pred_str.lower() == ref_str.lower():
147 | return True
148 |
149 | ## [a, b] vs. [c, d], return a==c and b==d
150 | if (
151 | regex.match(r"(\(|\[).+(\)|\])", prediction) is not None
152 | and regex.match(r"(\(|\[).+(\)|\])", reference) is not None
153 | ):
154 | pred_parts = prediction[1:-1].split(",")
155 | ref_parts = reference[1:-1].split(",")
156 | if len(pred_parts) == len(ref_parts):
157 | if all(
158 | [
159 | math_equal(
160 | pred_parts[i], ref_parts[i], include_percentage, is_close
161 | )
162 | for i in range(len(pred_parts))
163 | ]
164 | ):
165 | return True
166 | if (
167 | (
168 | prediction.startswith("\\begin{pmatrix}")
169 | or prediction.startswith("\\begin{bmatrix}")
170 | )
171 | and (
172 | prediction.endswith("\\end{pmatrix}")
173 | or prediction.endswith("\\end{bmatrix}")
174 | )
175 | and (
176 | reference.startswith("\\begin{pmatrix}")
177 | or reference.startswith("\\begin{bmatrix}")
178 | )
179 | and (
180 | reference.endswith("\\end{pmatrix}") or reference.endswith("\\end{bmatrix}")
181 | )
182 | ):
183 | pred_lines = [
184 | line.strip()
185 | for line in prediction[
186 | len("\\begin{pmatrix}") : -len("\\end{pmatrix}")
187 | ].split("\\\\")
188 | if line.strip()
189 | ]
190 | ref_lines = [
191 | line.strip()
192 | for line in reference[
193 | len("\\begin{pmatrix}") : -len("\\end{pmatrix}")
194 | ].split("\\\\")
195 | if line.strip()
196 | ]
197 | matched = True
198 | if len(pred_lines) == len(ref_lines):
199 | for pred_line, ref_line in zip(pred_lines, ref_lines):
200 | pred_parts = pred_line.split("&")
201 | ref_parts = ref_line.split("&")
202 | if len(pred_parts) == len(ref_parts):
203 | if not all(
204 | [
205 | math_equal(
206 | pred_parts[i],
207 | ref_parts[i],
208 | include_percentage,
209 | is_close,
210 | )
211 | for i in range(len(pred_parts))
212 | ]
213 | ):
214 | matched = False
215 | break
216 | else:
217 | matched = False
218 | if not matched:
219 | break
220 | else:
221 | matched = False
222 | if matched:
223 | return True
224 |
225 | if prediction.count("=") == 1 and reference.count("=") == 1:
226 | pred = prediction.split("=")
227 | pred = f"{pred[0].strip()} - ({pred[1].strip()})"
228 | ref = reference.split("=")
229 | ref = f"{ref[0].strip()} - ({ref[1].strip()})"
230 | if symbolic_equal(pred, ref) or symbolic_equal(f"-({pred})", ref):
231 | return True
232 | elif (
233 | prediction.count("=") == 1
234 | and len(prediction.split("=")[0].strip()) <= 2
235 | and "=" not in reference
236 | ):
237 | if math_equal(
238 | prediction.split("=")[1], reference, include_percentage, is_close
239 | ):
240 | return True
241 | elif (
242 | reference.count("=") == 1
243 | and len(reference.split("=")[0].strip()) <= 2
244 | and "=" not in prediction
245 | ):
246 | if math_equal(
247 | prediction, reference.split("=")[1], include_percentage, is_close
248 | ):
249 | return True
250 |
251 | # symbolic equal with sympy
252 | if timeout:
253 | if call_with_timeout(symbolic_equal_process, prediction, reference):
254 | return True
255 | else:
256 | if symbolic_equal(prediction, reference):
257 | return True
258 |
259 | return False
260 |
261 |
262 | def math_equal_process(param):
263 | return math_equal(param[-2], param[-1])
264 |
265 |
266 | def numeric_equal(prediction: float, reference: float):
267 | # Note that relative tolerance has significant impact
268 | # on the result of the synthesized GSM-Hard dataset
269 | # if reference.is_integer():
270 | # return isclose(reference, round(prediction), abs_tol=1e-4)
271 | # else:
272 | # prediction = round(prediction, len(str(reference).split(".")[-1]))
273 | return isclose(reference, prediction, rel_tol=1e-4)
274 |
275 |
276 | def symbolic_equal(a, b):
277 | def _parse(s):
278 | for f in [parse_latex, parse_expr, latex2sympy]:
279 | try:
280 | return f(s.replace("\\\\", "\\"))
281 | except:
282 | try:
283 | return f(s)
284 | except:
285 | pass
286 | return s
287 |
288 | a = _parse(a)
289 | b = _parse(b)
290 |
291 | # direct equal
292 | try:
293 | if str(a) == str(b) or a == b:
294 | return True
295 | except:
296 | pass
297 |
298 | # simplify equal
299 | try:
300 | if a.equals(b) or simplify(a - b) == 0:
301 | return True
302 | except:
303 | pass
304 |
305 | # equation equal
306 | try:
307 | if (abs(a.lhs - a.rhs)).equals(abs(b.lhs - b.rhs)):
308 | return True
309 | except:
310 | pass
311 |
312 | try:
313 | if numeric_equal(float(N(a)), float(N(b))):
314 | return True
315 | except:
316 | pass
317 |
318 | # matrix
319 | try:
320 | # if a and b are matrix
321 | if a.shape == b.shape:
322 | _a = a.applyfunc(lambda x: round(x, 3))
323 | _b = b.applyfunc(lambda x: round(x, 3))
324 | if _a.equals(_b):
325 | return True
326 | except:
327 | pass
328 |
329 | return False
330 |
331 |
332 | def symbolic_equal_process(a, b, output_queue):
333 | result = symbolic_equal(a, b)
334 | output_queue.put(result)
335 |
336 |
337 | def call_with_timeout(func, *args, timeout=1, **kwargs):
338 | output_queue = multiprocessing.Queue()
339 | process_args = args + (output_queue,)
340 | process = multiprocessing.Process(target=func, args=process_args, kwargs=kwargs)
341 | process.start()
342 | process.join(timeout)
343 |
344 | if process.is_alive():
345 | process.terminate()
346 | process.join()
347 | return False
348 |
349 | return output_queue.get()
350 |
351 | def _test_math_equal():
352 | # print(math_equal("0.0833333333333333", "\\frac{1}{12}"))
353 | # print(math_equal("(1,4.5)", "(1,\\frac{9}{2})"))
354 | # print(math_equal("\\frac{x}{7}+\\frac{2}{7}", "\\frac{x+2}{7}", timeout=True))
355 | # print(math_equal("\\sec^2(y)", "\\tan^2(y)+1", timeout=True))
356 | # print(math_equal("\\begin{pmatrix}-\\frac{7}{4}&-2\\\\4&\\frac{1}{4}\\end{pmatrix}", "(\\begin{pmatrix}-\\frac{7}{4}&-2\\\\4&\\frac{1}{4}\\\\\\end{pmatrix})", timeout=True))
357 |
358 | # pred = '\\begin{pmatrix}\\frac{1}{3x^{2/3}}&0&0\\\\0&1&0\\\\-\\sin(x)&0&0\\end{pmatrix}'
359 | # gt = '(\\begin{pmatrix}\\frac{1}{3\\sqrt[3]{x}^2}&0&0\\\\0&1&0\\\\-\\sin(x)&0&0\\\\\\end{pmatrix})'
360 |
361 | # pred= '-\\frac{8x^2}{9(x^2-2)^{5/3}}+\\frac{2}{3(x^2-2)^{2/3}}'
362 | # gt= '-\\frac{2(x^2+6)}{9(x^2-2)\\sqrt[3]{x^2-2}^2}'
363 |
364 | # pred = '-34x-45y+20z-100=0'
365 | # gt = '34x+45y-20z+100=0'
366 |
367 | # pred = '\\frac{100}{3}'
368 | # gt = '33.3'
369 |
370 | # pred = '\\begin{pmatrix}0.290243531202435\\\\0.196008371385084\\\\-0.186381278538813\\end{pmatrix}'
371 | # gt = '(\\begin{pmatrix}0.29\\\\0.196\\\\-0.186\\\\\\end{pmatrix})'
372 |
373 | # pred = '\\frac{\\sqrt{\\sqrt{11}+\\sqrt{194}}}{2\\sqrt{33}+15}'
374 | # gt = '\\frac{\\sqrt{\\sqrt{11}+\\sqrt{194}}}{15+2\\sqrt{33}}'
375 |
376 | # pred = '(+5)(b+2)'
377 | # gt = '(a+5)(b+2)'
378 |
379 | # pred = '\\frac{1+\\sqrt{5}}{2}'
380 | # gt = '2'
381 |
382 | # pred = '\\frac{34}{16}+\\frac{\\sqrt{1358}}{16}', gt = '4'
383 | # pred = '1', gt = '1\\\\sqrt{19}'
384 |
385 | # pred = "(0.6,2.6667]"
386 | # gt = "(\\frac{3}{5},\\frac{8}{3}]"
387 |
388 | gt = "x+2n+1"
389 | pred = "x+1"
390 |
391 | print(math_equal(pred, gt, timeout=True))
392 |
393 |
394 | if __name__ == "__main__":
395 | _test_math_equal()
396 |
--------------------------------------------------------------------------------
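Since the grader is self-contained, `math_equal` can also be called directly to spot-check answer matching. A small hypothetical example follows (the answers below are made up): the function first tries exact/choice matching, then numeric comparison with percentage scaling and a 1e-4 relative tolerance, and finally sympy-based symbolic comparison.

```python
# Hypothetical spot-check of the grader (requires the sympy/latex deps from requirements.txt).
from qwen_math_eval_toolkit.grader import math_equal

print(math_equal("0.5", "\\frac{1}{2}"))   # True: decimal vs. LaTeX fraction
print(math_equal("B", "B"))                # True: multiple-choice letter
print(math_equal("x+2", "2+x"))            # True: symbolic equivalence
print(math_equal("3.14", "2.71"))          # False: numbers differ beyond tolerance
# Pass timeout=True to guard slow sympy comparisons with a subprocess timeout.
```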
/dp_actor.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | # SPDX-License-Identifier: Apache-2.0
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | # Copyright 2024 Bytedance Ltd. and/or its affiliates
17 | #
18 | # Licensed under the Apache License, Version 2.0 (the "License");
19 | # you may not use this file except in compliance with the License.
20 | # You may obtain a copy of the License at
21 | #
22 | # http://www.apache.org/licenses/LICENSE-2.0
23 | #
24 | # Unless required by applicable law or agreed to in writing, software
25 | # distributed under the License is distributed on an "AS IS" BASIS,
26 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
27 | # See the License for the specific language governing permissions and
28 | # limitations under the License.
29 | """
30 | Single Process Actor
31 | """
32 |
33 | import itertools
34 | from typing import Iterable, Tuple
35 |
36 | import torch
37 | from torch import nn
38 | from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
39 |
40 | from verl import DataProto
41 | from verl.trainer.ppo import core_algos
42 | from verl.workers.actor import BasePPOActor
43 | from verl.utils.py_functional import append_to_dict
44 | from verl.utils.torch_functional import logprobs_from_logits, masked_mean
45 | from verl.utils.ulysses import ulysses_pad_and_slice_inputs, gather_outpus_and_unpad
46 | from verl.utils.seqlen_balancing import rearrange_micro_batches, get_reverse_idx
47 | import verl.utils.torch_functional as verl_F
48 |
49 | from flash_attn.bert_padding import pad_input, unpad_input, rearrange, index_first_axis
50 |
51 | __all__ = ['DataParallelPPOActor']
52 |
53 | def clamp_preserve_grad(x, min_val=None, max_val=None):
54 | clipped = torch.clamp(x, min=min_val, max=max_val)
55 | return x + (clipped - x).detach()
56 |
57 | class DataParallelPPOActor(BasePPOActor):
58 |
59 | def __init__(
60 | self,
61 | config,
62 | actor_module: nn.Module,
63 | actor_optimizer: torch.optim.Optimizer = None,
64 | ):
65 | """When optimizer is None, it is Reference Policy"""
66 | super().__init__(config)
67 | self.actor_module = actor_module
68 | self.actor_optimizer = actor_optimizer
69 | self.use_remove_padding = self.config.get('use_remove_padding', False)
70 | print(f'Actor use_remove_padding={self.use_remove_padding}')
71 | self.ulysses_sequence_parallel_size = self.config.ulysses_sequence_parallel_size
72 | self.use_ulysses_sp = self.ulysses_sequence_parallel_size > 1
73 |
74 | self.compute_entropy_from_logits = (
75 | torch.compile(verl_F.entropy_from_logits, dynamic=True)
76 | if self.config.get('use_torch_compile', True) # use torch compile by default
77 | else verl_F.entropy_from_logits)
78 |
79 | def _forward_micro_batch(self, micro_batch, temperature) -> Tuple[torch.Tensor, torch.Tensor]:
80 | """
81 | Returns:
82 | entropy: # (bs, response_len)
83 | log_probs: # (bs, response_len)
84 | """
85 | response_length = micro_batch['responses'].size(-1)
86 | multi_modal_inputs = {}
87 | if 'multi_modal_inputs' in micro_batch:
88 | for key in micro_batch['multi_modal_inputs'][0].keys():
89 | multi_modal_inputs[key] = torch.cat([inputs[key] for inputs in micro_batch['multi_modal_inputs']],
90 | dim=0)
91 |
92 | with torch.autocast(device_type='cuda', dtype=torch.bfloat16):
93 | input_ids = micro_batch['input_ids']
94 | batch_size, seqlen = input_ids.shape
95 | attention_mask = micro_batch['attention_mask']
96 | position_ids = micro_batch['position_ids']
97 | if position_ids.dim() == 3: # qwen2vl mrope
98 | position_ids = position_ids.transpose(0, 1) # (bsz, 3, seqlen) -> (3, bsz, seqlen)
99 |
100 | if self.use_remove_padding:
101 | input_ids_rmpad, indices, *_ = unpad_input(input_ids.unsqueeze(-1),
102 | attention_mask) # input_ids_rmpad (total_nnz, ...)
103 | input_ids_rmpad = input_ids_rmpad.transpose(0, 1) # (1, total_nnz)
104 |
105 | # unpad the position_ids to align the rotary
106 | if position_ids.dim() == 3:
107 | position_ids_rmpad = index_first_axis(rearrange(position_ids, "c b s ... -> (b s) c ..."),
108 | indices).transpose(0, 1).unsqueeze(
109 | 1) # (3, bsz, seqlen) -> (3, 1, bsz * seqlen)
110 | else:
111 | position_ids_rmpad = index_first_axis(rearrange(position_ids.unsqueeze(-1), "b s ... -> (b s) ..."),
112 | indices).transpose(0, 1)
113 |
114 | # for compute the log_prob
115 | input_ids_rmpad_rolled = torch.roll(input_ids_rmpad, shifts=-1, dims=1) # (1, total_nnz)
116 |
117 | # pad and slice the inputs if sp > 1
118 | if self.use_ulysses_sp:
119 | input_ids_rmpad, position_ids_rmpad, pad_size = ulysses_pad_and_slice_inputs(input_ids_rmpad, \
120 | position_ids_rmpad, \
121 | sp_size=self.ulysses_sequence_parallel_size)
122 | input_ids_rmpad_rolled, _, _ = ulysses_pad_and_slice_inputs(input_ids_rmpad_rolled, None,
123 | self.ulysses_sequence_parallel_size)
124 |
125 | input_ids_rmpad_rolled = input_ids_rmpad_rolled.squeeze(0) # ((total_nnz / sp) + pad)
126 |
127 | # only pass input_ids and position_ids to enable flash_attn_varlen
128 | output = self.actor_module(input_ids=input_ids_rmpad,
129 | attention_mask=None,
130 | position_ids=position_ids_rmpad,
131 | **multi_modal_inputs,
132 | use_cache=False) # prevent model thinks we are generating
133 | logits_rmpad = output.logits.squeeze(0) # (total_nnz, vocab_size)
134 |
135 | logits_rmpad.div_(temperature)
136 |
137 | # compute entropy
138 | entropy_rmpad = self.compute_entropy_from_logits(logits_rmpad) # ((total_nnz / sp) + pad)
139 |
140 | # if use_sp: ((total_nnz / sp) + pad) ; if not use_sp: (batch, seqlen)
141 | log_probs = logprobs_from_logits(logits=logits_rmpad, labels=input_ids_rmpad_rolled)
142 |
143 | # gather log_prob if sp > 1
144 | if self.use_ulysses_sp:
145 | # gather and unpad for the ulysses sp
146 | log_probs = gather_outpus_and_unpad(log_probs, gather_dim=0, unpad_dim=0, padding_size=pad_size)
147 | entropy_rmpad = gather_outpus_and_unpad(entropy_rmpad,
148 | gather_dim=0,
149 | unpad_dim=0,
150 | padding_size=pad_size)
151 | # pad back to (bsz, seqlen)
152 | full_entropy = pad_input(hidden_states=entropy_rmpad.unsqueeze(-1),
153 | indices=indices,
154 | batch=batch_size,
155 | seqlen=seqlen)
156 | full_log_probs = pad_input(hidden_states=log_probs.unsqueeze(-1),
157 | indices=indices,
158 | batch=batch_size,
159 | seqlen=seqlen)
160 |
161 | # only return response part:
162 | entropy = full_entropy.squeeze(-1)[:, -response_length - 1:-1] # (bsz, response_length)
163 | log_probs = full_log_probs.squeeze(-1)[:, -response_length - 1:-1] # (bsz, response_length)
164 |
165 | else: # not using rmpad and no ulysses sp
166 | output = self.actor_module(input_ids=input_ids,
167 | attention_mask=attention_mask,
168 | position_ids=position_ids,
169 | **multi_modal_inputs,
170 | use_cache=False) # prevent model thinks we are generating
171 | logits = output.logits
172 | logits.div_(temperature)
173 | logits = logits[:, -response_length - 1:-1, :] # (bsz, response_length, vocab_size)
174 | log_probs = logprobs_from_logits(logits, micro_batch['responses'])
175 | entropy = verl_F.entropy_from_logits(logits) # (bsz, response_length)
176 |
177 | return entropy, log_probs
178 |
179 | def _optimizer_step(self):
180 | assert self.config.grad_clip is not None
181 |
182 | if isinstance(self.actor_module, FSDP):
183 | grad_norm = self.actor_module.clip_grad_norm_(max_norm=self.config.grad_clip)
184 | else:
185 | grad_norm = torch.nn.utils.clip_grad_norm_(self.actor_module.parameters(), max_norm=self.config.grad_clip)
186 | self.actor_optimizer.step()
187 | return grad_norm
188 |
189 | def compute_log_prob(self, data: DataProto) -> torch.Tensor:
190 | """Compute the log probability of the responses given input_ids, attention_mask and position_ids
191 |
192 | Args:
193 | data (DataProto): a DataProto containing keys
194 |
195 | ``input_ids``: tensor of shape [batch_size, sequence_length]. torch.int64. Note that input_ids is the
196 | concatenation of prompt and response. Note that ``sequence_length = prompt_length + response_length``.
197 |
198 | ``attention_mask``: tensor of shape [batch_size, sequence_length]. torch.int64.
199 |
200 | ``position_ids``: tensor of shape [batch_size, sequence_length]. torch.int64.
201 |
202 | ``responses``: tensor of shape [batch_size, response_length]. torch.int64.
203 |
204 | Returns:
205 | torch.Tensor: the log_prob tensor
206 | """
207 | # set to eval
208 | self.actor_module.eval()
209 |
210 | micro_batch_size = data.meta_info['micro_batch_size']
211 |         temperature = data.meta_info['temperature']  # temperature must be in the data.meta_info to avoid silent error
212 | use_dynamic_bsz = data.meta_info['use_dynamic_bsz']
213 |
214 | select_keys = ['responses', 'input_ids', 'attention_mask', 'position_ids']
215 | batch = data.select(batch_keys=select_keys).batch
216 | has_multi_modal_inputs = 'multi_modal_inputs' in data.non_tensor_batch.keys()
217 |
218 | if has_multi_modal_inputs:
219 | num_micro_batches = data.batch.batch_size[0] // micro_batch_size
220 | non_tensor_select_keys = ['multi_modal_inputs']
221 | micro_batches = data.select(select_keys, non_tensor_select_keys).chunk(num_micro_batches)
222 | elif use_dynamic_bsz:
223 | # split using dynamic bsz
224 | max_token_len = data.meta_info['max_token_len'] * self.ulysses_sequence_parallel_size
225 | micro_batches, indices = rearrange_micro_batches(batch=batch, max_token_len=max_token_len)
226 | else:
227 | micro_batches = batch.split(micro_batch_size)
228 |
229 | log_probs_lst = []
230 | for micro_batch in micro_batches:
231 | if isinstance(micro_batch, DataProto):
232 | micro_batch = {**micro_batch.batch, **micro_batch.non_tensor_batch}
233 |
234 | with torch.no_grad():
235 | _, log_probs = self._forward_micro_batch(micro_batch, temperature=temperature)
236 | log_probs_lst.append(log_probs)
237 | log_probs = torch.concat(log_probs_lst, dim=0)
238 |
239 | if use_dynamic_bsz:
240 | indices = list(itertools.chain.from_iterable(indices))
241 | assert len(indices) == log_probs.size(0), f"{len(indices)} vs. {log_probs.size()}"
242 | revert_indices = torch.tensor(get_reverse_idx(indices), dtype=torch.long)
243 | log_probs = log_probs[revert_indices]
244 |
245 | return log_probs
246 |
247 | def update_policy(self, data: DataProto):
248 | # make sure we are in training mode
249 | self.actor_module.train()
250 |
251 |         temperature = data.meta_info['temperature']  # temperature must be in the data.meta_info to avoid silent error
252 |
253 | select_keys = ['responses', 'input_ids', 'attention_mask', 'position_ids', 'old_log_probs', 'advantages', "correct_rate", "correct_mask"]
254 | if self.config.use_kl_loss:
255 | select_keys.append('ref_log_prob')
256 | batch = data.select(batch_keys=select_keys).batch
257 | has_multi_modal_inputs = 'multi_modal_inputs' in data.non_tensor_batch.keys()
258 |
259 | # Split to make minibatch iterator for updating the actor
260 | # See PPO paper for details. https://arxiv.org/abs/1707.06347
261 | if has_multi_modal_inputs:
262 | num_mini_batches = data.batch.batch_size[0] // self.config.ppo_mini_batch_size
263 | non_tensor_select_keys = ['multi_modal_inputs']
264 | dataloader = data.select(select_keys, non_tensor_select_keys).chunk(num_mini_batches)
265 | else:
266 | dataloader = batch.split(self.config.ppo_mini_batch_size)
267 |
268 | metrics = {}
269 | for epoch in range(self.config.ppo_epochs):
270 | for batch_idx, data in enumerate(dataloader):
271 | # split batch into micro_batches
272 | mini_batch = data
273 | if has_multi_modal_inputs:
274 | self.gradient_accumulation = self.config.ppo_mini_batch_size // self.config.ppo_micro_batch_size_per_gpu
275 | num_micro_batches = mini_batch.batch.batch_size[0] // self.config.ppo_micro_batch_size_per_gpu
276 | micro_batches = data.select(select_keys, non_tensor_select_keys).chunk(num_micro_batches)
277 | elif self.config.use_dynamic_bsz:
278 | max_token_len = self.config.ppo_max_token_len_per_gpu * self.ulysses_sequence_parallel_size
279 | micro_batches, _ = rearrange_micro_batches(batch=mini_batch, max_token_len=max_token_len)
280 | else:
281 | self.gradient_accumulation = self.config.ppo_mini_batch_size // self.config.ppo_micro_batch_size_per_gpu
282 | # split batch into micro_batches
283 | micro_batches = mini_batch.split(self.config.ppo_micro_batch_size_per_gpu)
284 |
285 | self.actor_optimizer.zero_grad()
286 |
287 | for data in micro_batches:
288 |                     # Support all hardware
289 | if isinstance(data, DataProto):
290 | data = {**data.batch.to(torch.cuda.current_device()), **data.non_tensor_batch}
291 | else:
292 | data = data.to(torch.cuda.current_device()) # actor device is cpu when using offload
293 | responses = data['responses']
294 | response_length = responses.size(1)
295 | attention_mask = data['attention_mask']
296 | response_mask = attention_mask[:, -response_length:]
297 | old_log_prob = data['old_log_probs']
298 | advantages = data['advantages']
299 | correct_rate = data['correct_rate']
300 | correct_mask = data['correct_mask']
301 |                     assert correct_rate.dim() == 2, f"correct_rate.shape = {correct_rate.shape}"
302 |
303 | clip_ratio = self.config.clip_ratio
304 | clip_ratio_low = self.config.clip_ratio_low if self.config.clip_ratio_low is not None else clip_ratio
305 | clip_ratio_high = self.config.clip_ratio_high if self.config.clip_ratio_high is not None else clip_ratio
306 | entropy_coeff = self.config.entropy_coeff
307 | use_token_level_loss = self.config.use_token_level_loss
308 |
309 | # all return: (bsz, response_length)
310 | entropy, log_prob = self._forward_micro_batch(micro_batch=data, temperature=temperature)
311 | if self.config.neg_weight < 0.0:
312 | pg_loss, pg_clipfrac, ppo_kl = core_algos.compute_policy_loss(
313 | old_log_prob=old_log_prob,
314 | log_prob=log_prob,
315 | advantages=advantages,
316 | eos_mask=response_mask,
317 | cliprange=clip_ratio,
318 | cliprange_low=clip_ratio_low,
319 | cliprange_high=clip_ratio_high,
320 | use_token_level_loss=use_token_level_loss)
321 | else:
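                        # A reading of this branch (our interpretation as editors, not an official description):
                        # responses marked correct get a supervised log-ratio term (positive_loss below), while
                        # incorrect responses are penalized through an implicitly constructed positive distribution.
                        # With r = correct_rate and ratio = pi_theta / pi_old, the proxy below is
                        # (1 - r * ratio) / (1 - r), i.e. the normalized mass that this implicit negative
                        # policy assigns to the sampled incorrect response.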
322 | negative_approx_kl = log_prob - old_log_prob
323 | ppo_kl = verl_F.masked_mean(-negative_approx_kl, response_mask)
324 | pg_clipfrac = ppo_kl
325 | if self.config.ratio_type == "token":
326 | ratio = torch.exp((negative_approx_kl).clamp(max=10.0)) # (bsz, response_length)
327 | elif self.config.ratio_type == "sequence":
328 |                             ratio = torch.exp(((negative_approx_kl) * response_mask).sum(dim=-1).clamp(max=10.0))[:, None] # (bsz, 1)
329 | else:
330 | raise NotImplementedError
331 |                         assert ratio.dim() == 2, f"ratio.shape = {ratio.shape}"
332 | assert correct_mask.max().detach().cpu().item() <= 1.0, f" {correct_mask}"
333 | assert correct_mask.min().detach().cpu().item() >= 0.0, f" {correct_mask}"
334 | assert correct_rate.min().detach().cpu().item() < 1.0, f" {correct_rate}"
335 | assert correct_rate.max().detach().cpu().item() > 0.0, f" {correct_rate}"
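                        # Presumably clamped so that correct_rate and (1 - correct_rate) stay strictly
                        # positive, keeping the proxy normalization and the logs below finite.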
336 | correct_rate = torch.clamp(correct_rate, min=0.02, max=0.98)
337 | negative_rate = 1.0 - correct_rate
338 | if self.config.ratio_type == "token":
339 | proxy = 1.0 - correct_rate * ratio
340 | norm = 1.0 - correct_rate * 1.0
341 | proxy = proxy / norm
342 | negative_loss = - torch.log(clamp_preserve_grad(proxy, min_val=abs(self.config.clamp_negative)))
343 | positive_loss = - torch.log(clamp_preserve_grad(ratio, min_val=self.config.clamp_positive))
344 | pg_losses = (negative_loss * (1.0-correct_mask[:, None]) * self.config.neg_weight + correct_mask[:, None] * positive_loss)
345 | if self.config.normalize == 1:
346 | pg_losses = pg_losses * negative_rate
347 | elif self.config.normalize == 2:
348 | pg_losses = pg_losses * (negative_rate / correct_rate)**0.5
349 | else:
350 | pg_losses = pg_losses
351 | print(f"normalize config {self.config.normalize}")
352 | elif self.config.ratio_type == "sequence":
353 | ratio = ratio.squeeze(-1)
354 |                             assert ratio.dim() == 1, f"ratio.shape = {ratio.shape}"
355 |                             correct_rate = correct_rate[:, 0]
356 |                             assert correct_rate.dim() == 1, f"correct_rate.shape = {correct_rate.shape}"
357 |
358 | proxy = 1.0 - correct_rate * ratio
359 | norm = 1.0 - correct_rate * 1.0
360 | proxy = proxy / norm
361 | negative_loss = - torch.log(clamp_preserve_grad(proxy, min_val=abs(self.config.clamp_negative)))
362 | positive_loss = - torch.log(clamp_preserve_grad(ratio, min_val=self.config.clamp_positive))
363 | pg_losses = (negative_loss * (1.0 - correct_mask) * self.config.neg_weight + correct_mask * positive_loss)
364 | if self.config.normalize == 1:
365 | pg_losses = pg_losses * negative_rate # new update on 0414
366 | elif self.config.normalize == 2:
367 | pg_losses = pg_losses * (negative_rate / correct_rate)**0.5
368 | else:
369 | pg_losses = pg_losses
370 | # reduction
371 | if self.config.ratio_type == "token":
372 | pg_loss = torch.sum(pg_losses * response_mask, dim=1) / 1000.0
373 | pg_loss = torch.sum(pg_loss)
374 | elif self.config.ratio_type == "sequence":
375 | pg_loss = torch.sum(pg_losses)
376 |
377 | # compute entropy loss from entropy
378 | entropy_loss = verl_F.masked_mean(entropy, response_mask)
379 |
380 | # compute policy loss
381 | policy_loss = pg_loss - entropy_loss * entropy_coeff
382 |
383 | if self.config.use_kl_loss:
384 | ref_log_prob = data['ref_log_prob']
385 | # compute kl loss
386 | kld = core_algos.kl_penalty(logprob=log_prob,
387 | ref_logprob=ref_log_prob,
388 | kl_penalty=self.config.kl_loss_type)
389 | kl_loss = masked_mean(kld, response_mask)
390 |
391 | policy_loss = policy_loss + kl_loss * self.config.kl_loss_coef
392 | metrics['actor/kl_loss'] = kl_loss.detach().item()
393 | metrics['actor/kl_coef'] = self.config.kl_loss_coef
394 |
395 | if self.config.use_dynamic_bsz:
396 | # relative to the dynamic bsz
397 | if self.config.bugged_dynamic_scale == 1:
398 | loss = policy_loss * (len(data) / self.config.ppo_mini_batch_size)
399 | else:
400 | loss = policy_loss
401 | else:
402 | loss = policy_loss / self.gradient_accumulation
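                    # Our understanding: dividing by the gradient-accumulation factor (or scaling by
                    # the dynamic-bsz fraction above) keeps the accumulated gradient equivalent to a
                    # single update over the full PPO mini-batch.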
403 | loss.backward()
404 |
405 | data = {
406 | 'actor/entropy': entropy_loss.detach().item(),
407 | 'actor/pg_loss': pg_loss.detach().item(),
408 | 'actor/pg_clipfrac': pg_clipfrac.detach().item(),
409 | 'actor/ppo_kl': ppo_kl.detach().item(),
410 | }
411 | append_to_dict(metrics, data)
412 |
413 | grad_norm = self._optimizer_step()
414 | data = {'actor/grad_norm': grad_norm.detach().item()}
415 | append_to_dict(metrics, data)
416 | self.actor_optimizer.zero_grad()
417 | return metrics
418 |
--------------------------------------------------------------------------------
/qwen_math_eval_toolkit/parser.py:
--------------------------------------------------------------------------------
1 | import random
2 | import regex
3 | import re
4 | import sympy
5 | from latex2sympy2_extended import latex2sympy
6 | from typing import TypeVar, Iterable, List, Union, Any, Dict
7 | from word2number import w2n
8 | from qwen_math_eval_toolkit.utils import *
9 |
10 |
11 | def _fix_fracs(string):
12 | substrs = string.split("\\frac")
13 | new_str = substrs[0]
14 | if len(substrs) > 1:
15 | substrs = substrs[1:]
16 | for substr in substrs:
17 | new_str += "\\frac"
18 | if len(substr) > 0 and substr[0] == "{":
19 | new_str += substr
20 | else:
21 | try:
22 | assert len(substr) >= 2
23 | except:
24 | return string
25 | a = substr[0]
26 | b = substr[1]
27 | if b != "{":
28 | if len(substr) > 2:
29 | post_substr = substr[2:]
30 | new_str += "{" + a + "}{" + b + "}" + post_substr
31 | else:
32 | new_str += "{" + a + "}{" + b + "}"
33 | else:
34 | if len(substr) > 2:
35 | post_substr = substr[2:]
36 | new_str += "{" + a + "}" + b + post_substr
37 | else:
38 | new_str += "{" + a + "}" + b
39 | string = new_str
40 | return string
41 |
42 |
43 | def _fix_a_slash_b(string):
44 | if len(string.split("/")) != 2:
45 | return string
46 | a = string.split("/")[0]
47 | b = string.split("/")[1]
48 | try:
49 | if "sqrt" not in a:
50 | a = int(a)
51 | if "sqrt" not in b:
52 | b = int(b)
53 | assert string == "{}/{}".format(a, b)
54 | new_string = "\\frac{" + str(a) + "}{" + str(b) + "}"
55 | return new_string
56 | except:
57 | return string
58 |
59 |
60 | def _fix_sqrt(string):
61 | _string = re.sub(r"\\sqrt(\w+)", r"\\sqrt{\1}", string)
62 | return _string
63 |
64 |
65 | def convert_word_number(text: str) -> str:
66 | try:
67 | text = str(w2n.word_to_num(text))
68 | except:
69 | pass
70 | return text
71 |
72 |
73 | # units mainly from MathQA
74 | unit_texts = [
75 | "east",
76 | "degree",
77 | "mph",
78 | "kmph",
79 | "ft",
80 | "m sqaure",
81 | " m east",
82 | "sq m",
83 | "deg",
84 | "mile",
85 | "q .",
86 | "monkey",
87 | "prime",
88 | "ratio",
89 | "profit of rs",
90 | "rd",
91 | "o",
92 | "gm",
93 | "p . m",
94 | "lb",
95 | "tile",
96 | "per",
97 | "dm",
98 | "lt",
99 | "gain",
100 | "ab",
101 | "way",
102 | "west",
103 | "a .",
104 | "b .",
105 | "c .",
106 | "d .",
107 | "e .",
108 | "f .",
109 | "g .",
110 | "h .",
111 | "t",
112 | "a",
113 | "h",
114 | "no change",
115 | "men",
116 | "soldier",
117 | "pie",
118 | "bc",
119 | "excess",
120 | "st",
121 | "inches",
122 | "noon",
123 | "percent",
124 | "by",
125 | "gal",
126 | "kmh",
127 | "c",
128 | "acre",
129 | "rise",
130 | "a . m",
131 | "th",
132 | "π r 2",
133 | "sq",
134 | "mark",
135 | "l",
136 | "toy",
137 | "coin",
138 | "sq . m",
139 | "gallon",
140 | "° f",
141 | "profit",
142 | "minw",
143 | "yr",
144 | "women",
145 | "feet",
146 | "am",
147 | "pm",
148 | "hr",
149 | "cu cm",
150 | "square",
151 | "v â € ™",
152 | "are",
153 | "rupee",
154 | "rounds",
155 | "cubic",
156 | "cc",
157 | "mtr",
158 | "s",
159 | "ohm",
160 | "number",
161 | "kmph",
162 | "day",
163 | "hour",
164 | "minute",
165 | "min",
166 | "second",
167 | "man",
168 | "woman",
169 | "sec",
170 | "cube",
171 | "mt",
172 | "sq inch",
173 | "mp",
174 | "∏ cm ³",
175 | "hectare",
176 | "more",
177 | "sec",
178 | "unit",
179 | "cu . m",
180 | "cm 2",
181 | "rs .",
182 | "rs",
183 | "kg",
184 | "g",
185 | "month",
186 | "km",
187 | "m",
188 | "cm",
189 | "mm",
190 | "apple",
191 | "liter",
192 | "loss",
193 | "yard",
194 | "pure",
195 | "year",
196 | "increase",
197 | "decrease",
198 | "d",
199 | "less",
200 | "Surface",
201 | "litre",
202 | "pi sq m",
203 | "s .",
204 | "metre",
205 | "meter",
206 | "inch",
207 | ]
208 |
209 | unit_texts.extend([t + "s" for t in unit_texts])
210 |
211 |
212 | def strip_string(string, skip_unit=False):
213 | string = str(string).strip()
214 | # linebreaks
215 | string = string.replace("\n", "")
216 |
217 | # right "."
218 | string = string.rstrip(".")
219 |
220 | # remove inverse spaces
221 | # replace \\ with \
222 | string = string.replace("\\!", "")
223 | # string = string.replace("\\ ", "")
224 | # string = string.replace("\\\\", "\\")
225 |
226 | # matrix
227 | string = re.sub(r"\\begin\{array\}\{.*?\}", r"\\begin{pmatrix}", string)
228 | string = re.sub(r"\\end\{array\}", r"\\end{pmatrix}", string)
229 | string = string.replace("bmatrix", "pmatrix")
230 |
231 | # replace tfrac and dfrac with frac
232 | string = string.replace("tfrac", "frac")
233 | string = string.replace("dfrac", "frac")
234 | string = (
235 | string.replace("\\neq", "\\ne")
236 | .replace("\\leq", "\\le")
237 | .replace("\\geq", "\\ge")
238 | )
239 |
240 | # remove \left and \right
241 | string = string.replace("\\left", "")
242 | string = string.replace("\\right", "")
243 | string = string.replace("\\{", "{")
244 | string = string.replace("\\}", "}")
245 |
246 | # Remove unit: miles, dollars if after is not none
247 | _string = re.sub(r"\\text{.*?}$", "", string).strip()
248 | if _string != "" and _string != string:
249 | # print("Warning: unit not removed: '{}' -> '{}'".format(string, _string))
250 | string = _string
251 |
252 | if not skip_unit:
253 | # Remove unit: texts
254 | for _ in range(2):
255 | for unit_text in unit_texts:
256 | # use regex, the prefix should be either the start of the string or a non-alphanumeric character
257 | # the suffix should be either the end of the string or a non-alphanumeric character
258 | _string = re.sub(r"(^|\W)" + unit_text + r"($|\W)", r"\1\2", string)
259 | if _string != "":
260 | string = _string
261 |
262 | # Remove circ (degrees)
263 | string = string.replace("^{\\circ}", "")
264 | string = string.replace("^\\circ", "")
265 |
266 | # remove dollar signs
267 | string = string.replace("\\$", "")
268 | string = string.replace("$", "")
269 | string = string.replace("\\(", "").replace("\\)", "")
270 |
271 | # convert word number to digit
272 | string = convert_word_number(string)
273 |
274 | # replace "\\text{...}" to "..."
275 | string = re.sub(r"\\text\{(.*?)\}", r"\1", string)
276 | for key in ["x=", "y=", "z=", "x\\in", "y\\in", "z\\in", "x\\to", "y\\to", "z\\to"]:
277 | string = string.replace(key, "")
278 | string = string.replace("\\emptyset", r"{}")
279 | string = string.replace("(-\\infty,\\infty)", "\\mathbb{R}")
280 |
281 | # remove percentage
282 | string = string.replace("\\%", "")
283 | string = string.replace("\%", "")
284 | string = string.replace("%", "")
285 |
286 | # " 0." equivalent to " ." and "{0." equivalent to "{." Alternatively, add "0" if "." is the start of the string
287 | string = string.replace(" .", " 0.")
288 | string = string.replace("{.", "{0.")
289 |
290 | # cdot
291 | # string = string.replace("\\cdot", "")
292 | if (
293 | string.startswith("{")
294 | and string.endswith("}")
295 | and string.isalnum()
296 | or string.startswith("(")
297 | and string.endswith(")")
298 | and string.isalnum()
299 | or string.startswith("[")
300 | and string.endswith("]")
301 | and string.isalnum()
302 | ):
303 | string = string[1:-1]
304 |
305 | # inf
306 | string = string.replace("infinity", "\\infty")
307 | if "\\infty" not in string:
308 | string = string.replace("inf", "\\infty")
309 | string = string.replace("+\\inity", "\\infty")
310 |
311 | # and
312 | string = string.replace("and", "")
313 | string = string.replace("\\mathbf", "")
314 |
315 | # use regex to remove \mbox{...}
316 | string = re.sub(r"\\mbox{.*?}", "", string)
317 |
318 | # quote
319 |     string = string.replace("'", "")
320 |     string = string.replace('"', "")
321 |
322 | # i, j
323 | if "j" in string and "i" not in string:
324 | string = string.replace("j", "i")
325 |
326 | # replace a.000b where b is not number or b is end, with ab, use regex
327 | string = re.sub(r"(\d+)\.0*([^\d])", r"\1\2", string)
328 | string = re.sub(r"(\d+)\.0*$", r"\1", string)
329 |
330 | # if empty, return empty string
331 | if len(string) == 0:
332 | return string
333 | if string[0] == ".":
334 | string = "0" + string
335 |
336 | # to consider: get rid of e.g. "k = " or "q = " at beginning
337 | if len(string.split("=")) == 2:
338 | if len(string.split("=")[0]) <= 2:
339 | string = string.split("=")[1]
340 |
341 | string = _fix_sqrt(string)
342 | string = string.replace(" ", "")
343 |
344 | # \frac1b or \frac12 --> \frac{1}{b} and \frac{1}{2}, etc. Even works with \frac1{72} (but not \frac{72}1). Also does a/b --> \\frac{a}{b}
345 | string = _fix_fracs(string)
346 |
347 | # NOTE: X/Y changed to \frac{X}{Y} in dataset, but in simple cases fix in case the model output is X/Y
348 | string = _fix_a_slash_b(string)
349 |
350 | return string
351 |
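# Illustrative behaviour of strip_string (our own examples, not from the upstream toolkit):
#   strip_string("\\dfrac{1}{2}") -> "\\frac{1}{2}"   (dfrac/tfrac normalized to frac)
#   strip_string("50\\%")         -> "50"             (percent signs dropped)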
352 |
353 | def extract_multi_choice_answer(pred_str):
354 | # TODO: SFT models
355 | if "Problem:" in pred_str:
356 | pred_str = pred_str.split("Problem:", 1)[0]
357 | pred_str = pred_str.replace("choice is", "answer is")
358 |     patt = regex.search(r"answer is \(?(?P<ans>[abcde])\)?", pred_str.lower())
359 | if patt is not None:
360 | return patt.group("ans").upper()
361 | return "placeholder"
362 |
363 |
364 | direct_answer_trigger_for_fewshot = ("choice is", "answer is")
365 |
366 |
367 | def choice_answer_clean(pred: str):
368 | pred = pred.strip("\n")
369 |
370 | # Determine if this is ICL, if so, use \n\n to split the first chunk.
371 | ICL = False
372 | for trigger in direct_answer_trigger_for_fewshot:
373 | if pred.count(trigger) > 1:
374 | ICL = True
375 | if ICL:
376 | pred = pred.split("\n\n")[0]
377 |
378 | # Split the trigger to find the answer.
379 | preds = re.split("|".join(direct_answer_trigger_for_fewshot), pred)
380 | if len(preds) > 1:
381 | answer_flag = True
382 | pred = preds[-1]
383 | else:
384 | answer_flag = False
385 |
386 | pred = pred.strip("\n").rstrip(".").rstrip("/").strip(" ").lstrip(":")
387 |
388 | # Clean the answer based on the dataset
389 | tmp = re.findall(r"\b(A|B|C|D|E)\b", pred.upper())
390 | if tmp:
391 | pred = tmp
392 | else:
393 | pred = [pred.strip().strip(".")]
394 |
395 | if len(pred) == 0:
396 | pred = ""
397 | else:
398 | if answer_flag:
399 | # choose the first element in list ...
400 | pred = pred[0]
401 | else:
402 | # choose the last e
403 | pred = pred[-1]
404 |
405 | # Remove the period at the end, again!
406 | pred = pred.rstrip(".").rstrip("/")
407 |
408 | return pred
409 |
410 |
411 | def find_box(pred_str: str):
412 | ans = pred_str.split("boxed")[-1]
413 | if not ans:
414 | return ""
415 | if ans[0] == "{":
416 | stack = 1
417 | a = ""
418 | for c in ans[1:]:
419 | if c == "{":
420 | stack += 1
421 | a += c
422 | elif c == "}":
423 | stack -= 1
424 | if stack == 0:
425 | break
426 | a += c
427 | else:
428 | a += c
429 | else:
430 | a = ans.split("$")[0].strip()
431 | return a
432 |
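# Illustrative (our own example): find_box("so the answer is \\boxed{\\frac{1}{2}}.") -> "\\frac{1}{2}"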
433 |
434 | def clean_units(pred_str: str):
435 | """Clean the units in the number."""
436 |
437 | def convert_pi_to_number(code_string):
438 | code_string = code_string.replace("\\pi", "π")
439 | # Replace \pi or π not preceded by a digit or } with 3.14
440 |         code_string = re.sub(r"(?<![\d}])\\?π", "3.14", code_string)
441 |         # Replace a\pi or aπ where a is a digit with a*3.14, e.g., "3π" -> "3*3.14"
442 | code_string = re.sub(r"(\d)(\\?π)", r"\1*3.14", code_string)
443 | # Handle cases where π is within braces or followed by a multiplication symbol
444 | # This replaces "{π}" with "3.14" directly and "3*π" with "3*3.14"
445 | code_string = re.sub(r"\{(\\?π)\}", "3.14", code_string)
446 | code_string = re.sub(r"\*(\\?π)", "*3.14", code_string)
447 | return code_string
448 |
449 | pred_str = convert_pi_to_number(pred_str)
450 | pred_str = pred_str.replace("%", "/100")
451 | pred_str = pred_str.replace("$", "")
452 | pred_str = pred_str.replace("¥", "")
453 | pred_str = pred_str.replace("°C", "")
454 | pred_str = pred_str.replace(" C", "")
455 | pred_str = pred_str.replace("°", "")
456 | return pred_str
457 |
458 |
459 | def extract_theoremqa_answer(pred: str, answer_flag: bool = True):
460 | if any([option in pred.lower() for option in ["yes", "true"]]):
461 | pred = "True"
462 | elif any([option in pred.lower() for option in ["no", "false"]]):
463 | pred = "False"
464 | elif any(
465 | [
466 | option in pred.lower()
467 | for option in ["(a)", "(b)", "(c)", "(d)", "(e)", "(f)"]
468 | ]
469 | ):
470 | pass
471 | else:
472 | # Some of the models somehow get used to boxed output from pre-training
473 | if "boxed" in pred:
474 | pred = find_box(pred)
475 |
476 | if answer_flag:
477 | # Extract the numbers out of the string
478 | pred = pred.split("=")[-1].strip()
479 | pred = clean_units(pred)
480 | try:
481 | tmp = str(latex2sympy(pred))
482 | pred = str(eval(tmp))
483 | except Exception:
484 | if re.match(r"-?[\d\.]+\s\D+$", pred):
485 | pred = pred.split(" ")[0]
486 | elif re.match(r"-?[\d\.]+\s[^\s]+$", pred):
487 | pred = pred.split(" ")[0]
488 | else:
489 | # desparate search over the last number
490 | preds = re.findall(r"-?\d*\.?\d+", pred)
491 | if len(preds) >= 1:
492 | pred = preds[-1]
493 | else:
494 | pred = ""
495 |
496 | return pred
497 |
498 | # Key answer-extraction function
499 | def extract_answer(pred_str, data_name, use_last_number=True):
500 | pred_str = pred_str.replace("\u043a\u0438", "")
501 |     if data_name in ["mmlu_stem", "sat_math", "aqua", "gaokao2023"]: # multiple-choice datasets; not used for the math benchmarks
502 | # TODO check multiple choice
503 | return choice_answer_clean(pred_str)
504 |
505 | if "final answer is $" in pred_str and "$. I hope" in pred_str:
506 | # minerva_math
507 | tmp = pred_str.split("final answer is $", 1)[1]
508 | pred = tmp.split("$. I hope", 1)[0].strip()
509 | elif "boxed" in pred_str:
510 | ans = pred_str.split("boxed")[-1]
511 | if len(ans) == 0:
512 | return ""
513 | elif ans[0] == "{":
514 | stack = 1
515 | a = ""
516 | for c in ans[1:]:
517 | if c == "{":
518 | stack += 1
519 | a += c
520 | elif c == "}":
521 | stack -= 1
522 | if stack == 0:
523 | break
524 | a += c
525 | else:
526 | a += c
527 | else:
528 | a = ans.split("$")[0].strip()
529 | pred = a
530 | elif "he answer is" in pred_str:
531 | pred = pred_str.split("he answer is")[-1].strip()
532 | elif "final answer is" in pred_str:
533 | pred = pred_str.split("final answer is")[-1].strip()
534 | elif "答案是" in pred_str:
535 | # Handle Chinese few-shot multiple choice problem answer extraction
536 | pred = pred_str.split("答案是")[1].strip().split("\n\n")[0].strip()
537 | else: # use the last number
538 | if use_last_number:
539 |             pattern = r"-?\d*\.?\d+"
540 | pred = re.findall(pattern, pred_str.replace(",", ""))
541 | if len(pred) >= 1:
542 | pred = pred[-1]
543 | else:
544 | pred = ""
545 | else:
546 | pred = ""
547 |
548 | # choice answer
549 | if (
550 | data_name in ["sat_math", "aqua"]
551 | or "mmlu" in data_name
552 |     ): # not triggered for the math benchmarks
553 | tmp = re.findall(r"\b(A|B|C|D|E)\b", pred.upper())
554 | if tmp:
555 | pred = tmp[-1]
556 | else:
557 | pred = pred.strip().strip(".")
558 |
559 | # multiple line
560 | # pred = pred.split("\n")[0]
561 | pred = re.sub(r"\n\s*", "", pred)
562 | if pred != "" and pred[0] == ":":
563 | pred = pred[1:]
564 | if pred != "" and pred[-1] == ".":
565 | pred = pred[:-1]
566 | if pred != "" and pred[-1] == "/":
567 | pred = pred[:-1]
568 | pred = strip_string(pred, skip_unit=data_name in ["carp_en", "minerva_math"])
569 | return pred
570 |
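# Illustrative (our own example, not part of the upstream toolkit):
#   extract_answer("... so the final answer is \\boxed{42}.", "math") -> "42"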
571 |
572 | STRIP_EXCEPTIONS = ["carp_en", "minerva_math"]
573 |
574 |
575 | def parse_ground_truth(example: Dict[str, Any], data_name):
576 | if "gt_cot" in example and "gt" in example:
577 | if data_name in ["math"]:
578 | gt_ans = extract_answer(example["gt_cot"], data_name)
579 | elif data_name in STRIP_EXCEPTIONS:
580 | gt_ans = example["gt"]
581 | else:
582 | gt_ans = strip_string(example["gt"])
583 | return example["gt_cot"], gt_ans
584 |
585 | # parse ground truth
586 |     if data_name in ["math", "minerva_math", "math500"]: # key code path
587 | gt_cot = example["solution"]
588 | gt_ans = extract_answer(gt_cot, data_name)
589 | elif data_name == "gsm8k":
590 | gt_cot, gt_ans = example["answer"].split("####")
591 | elif data_name == "svamp":
592 | gt_cot, gt_ans = example["Equation"], example["Answer"]
593 | elif data_name == "asdiv":
594 | gt_cot = example["formula"]
595 | gt_ans = re.sub(r"\(.*?\)", "", example["answer"])
596 | elif data_name == "mawps":
597 | gt_cot, gt_ans = None, example["target"]
598 | elif data_name == "tabmwp":
599 | gt_cot = example["solution"]
600 | gt_ans = example["answer"]
601 | if example["ans_type"] in ["integer_number", "decimal_number"]:
602 | if "/" in gt_ans:
603 | gt_ans = int(gt_ans.split("/")[0]) / int(gt_ans.split("/")[1])
604 | elif "," in gt_ans:
605 | gt_ans = float(gt_ans.replace(",", ""))
606 | elif "%" in gt_ans:
607 | gt_ans = float(gt_ans.split("%")[0]) / 100
608 | else:
609 | gt_ans = float(gt_ans)
610 | elif data_name == "carp_en":
611 | gt_cot, gt_ans = example["steps"], example["answer"]
612 | elif data_name == "mmlu_stem":
613 | abcd = "ABCD"
614 | gt_cot, gt_ans = None, abcd[example["answer"]]
615 | elif data_name == "sat_math":
616 | gt_cot, gt_ans = None, example["Answer"]
617 | elif data_name == "aqua":
618 | gt_cot, gt_ans = None, example["correct"]
619 | elif data_name in ["gaokao2023en", "college_math", "gaokao_math_cloze"]:
620 | gt_cot, gt_ans = None, example["answer"].replace("$", "").strip()
621 | elif data_name == "gaokao_math_qa":
622 | gt_cot, gt_ans = None, example["label"]
623 | elif data_name in ["gaokao2024_mix", "cn_middle_school"]:
624 | if len(example["choice_answer"]) > 0:
625 | gt_cot, gt_ans = None, example["choice_answer"]
626 | else:
627 | gt_cot, gt_ans = None, example["answer"]
628 | elif data_name == "olympiadbench":
629 | gt_cot, gt_ans = None, example["final_answer"][0].strip("$")
630 | elif data_name in [
631 | "aime24",
632 | "amc23",
633 | "cmath",
634 | "gaokao2024_I",
635 | "gaokao2024_II",
636 | "imo2024",
637 | ]:
638 | gt_cot, gt_ans = None, example["answer"]
639 | else:
640 | raise NotImplementedError(f"`{data_name}`")
641 | # post process
642 | gt_cot = str(gt_cot).strip()
643 | if data_name not in STRIP_EXCEPTIONS:
644 | gt_ans = strip_string(gt_ans, skip_unit=data_name == "carp_en")
645 | else:
646 | gt_ans = (
647 | gt_ans.replace("\\neq", "\\ne")
648 | .replace("\\leq", "\\le")
649 | .replace("\\geq", "\\ge")
650 | )
651 | return gt_cot, gt_ans
652 |
653 |
654 | def parse_question(example, data_name):
655 | question = ""
656 | if data_name == "asdiv":
657 | question = f"{example['body'].strip()} {example['question'].strip()}"
658 | elif data_name == "svamp":
659 | body = example["Body"].strip()
660 | if not body.endswith("."):
661 | body = body + "."
662 | question = f'{body} {example["Question"].strip()}'
663 | elif data_name == "tabmwp":
664 | title_str = (
665 | f'regarding "{example["table_title"]}" ' if example["table_title"] else ""
666 | )
667 | question = f"Read the following table {title_str}and answer a question:\n"
668 | question += f'{example["table"]}\n{example["question"]}'
669 | if example["choices"]:
670 | question += (
671 | f' Please select from the following options: {example["choices"]}'
672 | )
673 | elif data_name == "carp_en":
674 | question = example["content"]
675 | elif data_name == "mmlu_stem":
676 | options = example["choices"]
677 | assert len(options) == 4
678 | for i, (label, option) in enumerate(zip("ABCD", options)):
679 | options[i] = f"({label}) {str(option).strip()}"
680 | options = " ".join(options)
681 | # question = f"{example['question'].strip()}\nWhat of the following is the right choice? Explain your answer.\n{options}"
682 | question = f"{example['question'].strip()}\nAnswer Choices: {options}"
683 | elif data_name == "sat_math":
684 | options = example["options"].strip()
685 | assert "A" == options[0]
686 | options = "(" + options
687 | for ch in "BCD":
688 | if f" {ch}) " in options:
689 | options = regex.sub(f" {ch}\) ", f" ({ch}) ", options)
690 | # question = f"{example['question'].strip()}\nWhat of the following is the right choice? Explain your answer.\n{options.strip()}"
691 | question = f"{example['question'].strip()}\nAnswer Choices: {options}"
692 | elif "aqua" in data_name:
693 | options = example["options"]
694 | choice = "(" + "(".join(options)
695 | choice = choice.replace("(", " (").replace(")", ") ").strip()
696 | choice = "\nAnswer Choices: " + choice
697 | question = example["question"].strip() + choice
698 | elif data_name == "gaokao_math_qa":
699 | options_dict = example["options"]
700 | options = []
701 | for key in options_dict:
702 | options.append(f"({key}) {options_dict[key]}")
703 | options = " ".join(options)
704 | question = f"{example['question'].strip()}\n选项: {options}"
705 | else:
706 | for key in ["question", "problem", "Question", "input"]:
707 | if key in example:
708 | question = example[key]
709 | break
710 | # assert question != ""
711 | # Yes or No question
712 | _, gt_ans = parse_ground_truth(example, data_name)
713 | if isinstance(gt_ans, str):
714 | gt_lower = gt_ans.lower()
715 | if gt_lower in ["true", "false"]:
716 | question += " (True or False)"
717 | if gt_lower in ["yes", "no"]:
718 | question += " (Yes or No)"
719 | return question.strip()
720 |
721 |
722 | def run_execute(executor, result, prompt_type, data_name, execute=False):
723 | if not result or result == "error":
724 | return None, None
725 | report = None
726 |
727 | if "program_only" in prompt_type:
728 | prediction = extract_program_output(result)
729 | elif prompt_type in ["pot", "pal"] and execute:
730 | code = extract_program(result)
731 | prediction, report = executor.apply(code)
732 | else:
733 | prediction = extract_answer(result, data_name)
734 |
735 | # prediction = strip_string(prediction, skip_unit=data_name == "carp_en")
736 | prediction = strip_string(prediction, skip_unit=data_name in STRIP_EXCEPTIONS)
737 | return prediction, report
738 |
739 |
740 | def _test_extract_answer():
741 | text = """
742 | This is still not equal to $0$, so we must have made another mistake.
743 |
744 | When we subtracted $7$ from $\frac{386}{64}$, we should have subtracted $7 \cdot 64$ from $386$, not the other way around. Let's correct that:
745 |
746 | \[\frac{386}{64} - 7 = \frac{386}{64} - \frac{7 \cdot 64}{1 \cdot 64} = \frac{386 - 448}{64} = \frac{-62}{64}.\]
747 |
748 | This is still not equal to $0$, so we must have made another mistake.
749 |
750 | When we subtracted $7$ from $\frac{386}{64}$, we should have subtracted $7 \cdot 64$ from $386$, not the other way around. Let's correct that:
751 |
752 | \[\frac{386}{64}
753 | """
754 | print(extract_answer(text, "math-oai", use_last_number=False))
755 | print(choice_answer_clean("\mathrm{(D)\}1,008,016"))
756 | # should print the extracted answer strings
757 |
758 |
759 | if __name__ == "__main__":
760 | _test_extract_answer()
761 |
--------------------------------------------------------------------------------
/qwen_math_eval_toolkit/examples.py:
--------------------------------------------------------------------------------
1 | import json
2 |
3 |
4 | def get_examples():
5 | examples = {}
6 | examples["gsm8k"] = [
7 | (
8 | "There are 15 trees in the grove. Grove workers will plant trees in the grove today. After they are done, there will be 21 trees. How many trees did the grove workers plant today?",
9 | "There are 15 trees originally. Then there were 21 trees after some more were planted. So there must have been 21 - 15 = 6. The answer is 6.",
10 | ),
11 | (
12 | "If there are 3 cars in the parking lot and 2 more cars arrive, how many cars are in the parking lot?",
13 | "There are originally 3 cars. 2 more cars arrive. 3 + 2 = 5. The answer is 5.",
14 | ),
15 | (
16 | "Leah had 32 chocolates and her sister had 42. If they ate 35, how many pieces do they have left in total?",
17 | "Originally, Leah had 32 chocolates. Her sister had 42. So in total they had 32 + 42 = 74. After eating 35, they had 74 - 35 = 39. The answer is 39.",
18 | ),
19 | (
20 | "Jason had 20 lollipops. He gave Denny some lollipops. Now Jason has 12 lollipops. How many lollipops did Jason give to Denny?",
21 | "Jason started with 20 lollipops. Then he had 12 after giving some to Denny. So he gave Denny 20 - 12 = 8. The answer is 8.",
22 | ),
23 | (
24 | "Shawn has five toys. For Christmas, he got two toys each from his mom and dad. How many toys does he have now?",
25 | "Shawn started with 5 toys. If he got 2 toys each from his mom and dad, then that is 4 more toys. 5 + 4 = 9. The answer is 9.",
26 | ),
27 | (
28 | "There were nine computers in the server room. Five more computers were installed each day, from monday to thursday. How many computers are now in the server room?",
29 | "There were originally 9 computers. For each of 4 days, 5 more computers were added. So 5 * 4 = 20 computers were added. 9 + 20 is 29. The answer is 29.",
30 | ),
31 | (
32 | "Michael had 58 golf balls. On tuesday, he lost 23 golf balls. On wednesday, he lost 2 more. How many golf balls did he have at the end of wednesday?",
33 | "Michael started with 58 golf balls. After losing 23 on tuesday, he had 58 - 23 = 35. After losing 2 more, he had 35 - 2 = 33 golf balls. The answer is 33.",
34 | ),
35 | (
36 | "Olivia has $23. She bought five bagels for $3 each. How much money does she have left?",
37 | "Olivia had 23 dollars. 5 bagels for 3 dollars each will be 5 x 3 = 15 dollars. So she has 23 - 15 dollars left. 23 - 15 is 8. The answer is 8.",
38 | ),
39 | ]
40 | examples["gsm8k-pal"] = [
41 | (
42 | "Olivia has $23. She bought five bagels for $3 each. How much money does she have left?",
43 | '```python\ndef solution():\n """Olivia has $23. She bought five bagels for $3 each. How much money does she have left?"""\n money_initial = 23\n bagels = 5\n bagel_cost = 3\n money_spent = bagels * bagel_cost\n money_left = money_initial - money_spent\n result = money_left\n return result\n```',
44 | ),
45 | (
46 | "Michael had 58 golf balls. On tuesday, he lost 23 golf balls. On wednesday, he lost 2 more. How many golf balls did he have at the end of wednesday?",
47 | '```python\ndef solution():\n """Michael had 58 golf balls. On tuesday, he lost 23 golf balls. On wednesday, he lost 2 more. How many golf balls did he have at the end of wednesday?"""\n golf_balls_initial = 58\n golf_balls_lost_tuesday = 23\n golf_balls_lost_wednesday = 2\n golf_balls_left = golf_balls_initial - golf_balls_lost_tuesday - golf_balls_lost_wednesday\n result = golf_balls_left\n return result\n```',
48 | ),
49 | (
50 | "There were nine computers in the server room. Five more computers were installed each day, from monday to thursday. How many computers are now in the server room?",
51 | '```python\ndef solution():\n """There were nine computers in the server room. Five more computers were installed each day, from monday to thursday. How many computers are now in the server room?"""\n computers_initial = 9\n computers_per_day = 5\n num_days = 4 # 4 days between monday and thursday\n computers_added = computers_per_day * num_days\n computers_total = computers_initial + computers_added\n result = computers_total\n return result\n```',
52 | ),
53 | ]
54 | examples["gsm8k-tora"] = [
55 | (
56 | "Olivia has $23. She bought five bagels for $3 each. How much money does she have left?",
57 | "```python\ndef money_left():\n money_initial = 23\n bagels = 5\n bagel_cost = 3\n money_spent = bagels * bagel_cost\n remaining_money = money_initial - money_spent\n return remaining_money\n \nremaining_money = money_left()\nprint(remaining_money)\n```\n```output\n8\n```\nOlivia has $\\boxed{8}$ dollars left.",
58 | ),
59 | (
60 | "Michael had 58 golf balls. On tuesday, he lost 23 golf balls. On wednesday, he lost 2 more. How many golf balls did he have at the end of wednesday?",
61 | "```python\ndef remaining_golf_balls():\n golf_balls_initial = 58\n golf_balls_lost_tuesday = 23\n golf_balls_lost_wednesday = 2\n golf_balls_left = golf_balls_initial - golf_balls_lost_tuesday - golf_balls_lost_wednesday\n remaining_golf_balls = golf_balls_left\n return remaining_golf_balls\n\nanswer = remaining_golf_balls() \nprint(answer)\n```\n```output\n33\n```\nMichael had $\\boxed{33}$ golf balls at the end of Wednesday.",
62 | ),
63 | (
64 | "There were nine computers in the server room. Five more computers were installed each day, from monday to thursday. How many computers are now in the server room?",
65 | "```python\ndef total_computers():\n computers_initial = 9\n computers_per_day = 5\n num_days = 4 # 4 days between monday and thursday\n computers_added = computers_per_day * num_days\n computers_total = computers_initial + computers_added\n return computers_total\n\ntotal_computers = total_computers()\nprint(total_computers)\n```\n```output\n29\n```\nThere're $\\boxed{29}$ computers in the server room.",
66 | ),
67 | ]
68 |
69 | examples["math"] = [
70 | (
71 | "Kevin Kangaroo begins hopping on a number line at 0. He wants to get to 1, but he can hop only $\\frac{1}{3}$ of the distance. Each hop tires him out so that he continues to hop $\\frac{1}{3}$ of the remaining distance. How far has he hopped after five hops? Express your answer as a common fraction.",
72 | "Let's think step by step\nKevin hops $1/3$ of the remaining distance with every hop.\nHis first hop takes $1/3$ closer.\nFor his second hop, he has $2/3$ left to travel, so he hops forward $(2/3)(1/3)$.\nFor his third hop, he has $(2/3)^2$ left to travel, so he hops forward $(2/3)^2(1/3)$.\nIn general, Kevin hops forward $(2/3)^{k-1}(1/3)$ on his $k$th hop.\nWe want to find how far he has hopped after five hops.\nThis is a finite geometric series with first term $1/3$, common ratio $2/3$, and five terms.\nThus, Kevin has hopped $\\frac{\\frac{1}{3}\\left(1-\\left(\\frac{2}{3}\\right)^5\\right)}{1-\\frac{2}{3}} = \\boxed{\\frac{211}{243}}$.\nThe answer is \\frac{211}{243}}",
73 | ),
74 | (
75 | "What is the area of the region defined by the equation $x^2+y^2 - 7 = 4y-14x+3$?",
76 | "Let's think step by step\nWe rewrite the equation as $x^2 + 14x + y^2 - 4y = 10$ and then complete the square,\nresulting in $(x+7)^2-49 + (y-2)^2-4=10$,\nor $(x+7)^2+(y-2)^2=63$.\nThis is the equation of a circle with center $(-7, 2)$ and radius $\\sqrt{63},$\nso the area of this region is $\\pi r^2 = \\boxed{63\\pi}$.\nThe answer is 63\\pi",
77 | ),
78 | (
79 | "If $x^2+y^2=1$, what is the largest possible value of $|x|+|y|$?",
80 | "Let's think step by step\nIf $(x,y)$ lies on the circle,\nso does $(x,-y),$ $(-x,-y),$ and $(-x,-y),$ (which all give the same value of $|x| + |y|$),\nso we can assume that $x \\ge 0$ and $y \\ge 0.$\nThen $|x| + |y| = x + y.$ Squaring, we get\n\\[(x + y)^2 = x^2 + 2xy + y^2 = 1 + 2xy.\\]\nNote that $(x - y)^2 \\ge 0.$\nExpanding, we get $x^2 - 2xy + y^2 \\ge 0,$ so $2xy \\le x^2 + y^2 = 1.$\nHence,\\[1 + 2xy \\le 2,\\]which means $x + y \\le \\sqrt{2}.$\nEquality occurs when $x = y = \\frac{1}{\\sqrt{2}},$\nso the maximum value of $|x| + |y|$ is $\\boxed{\\sqrt{2}}.$\nThe answer is \\sqrt{2}",
81 | ),
82 | (
83 | "If $f(x)=\\frac{ax+b}{cx+d}, abcd\\not=0$ and $f(f(x))=x$ for all $x$ in the domain of $f$, what is the value of $a+d$?",
84 | "Let's think step by step\nThe condition $f(f(x))$ means that $f$ is the inverse of itself,\nso its graph is symmetrical about the line $y = x$.\nWith a rational function of this form, we will have two asymptotes:\na vertical one at $x=-d/c$ if $cx+d$ does not divide $ax+b$,\nand a horizontal one at $y=a/c$,\nif we take the limit of $f(x)$ as $x$ goes to $\\pm\\infty$.\nIn order for $f$ to be its own inverse, the intersection of the asymptotes must lie on the line $y=x$\nso that it and its asymptotes reflect onto themselves.\nThis means that $-d/c=a/c$,\nand therefore $-d=a$ and $a+d=\\boxed{0}$.\nThe answer is 0",
85 | ),
86 | (
87 | "Expand $(2z^2 + 5z - 6)(3z^3 - 2z + 1)$.",
88 | "Let's think step by step\n$$\\begin{array}{crrrrrrr}\n& & & 3z^3 & & -2z & + 1 & \\\\\n\\times & & & & 2z^2 & +5z & -6 \\\\\n\\cline{1-7}\\rule{0pt}{0.17in}\n& & & -18z^3 & & +12z & -6 & \\\\\n& & +15z^4 & & -10z^2 & +5z & & \\\\\n+ & 6z^5 & & -4z^3 & +2z^2 & & & \\\\\n\\cline{1-7}\\rule{0pt}{0.17in}\n& 6z^5 & +15z^4 & -22z^3 & - 8z^2 &+17z & -6 &\n\\end{array}$$\nThe answer is 6z^5+15z^4-22z^3-8z^2+17z-6",
89 | ),
90 | ]
91 |
92 | examples["math_pal"] = [
93 | (
94 |             "Display the final result in LaTeX.\n\n Find the coefficient of $x^3$ when $3(x^2 - x^3+x) +3(x +2x^3- 3x^2 + 3x^5+x^3) -5(1+x-4x^3 - x^2)$ is simplified.",
95 | "```python\nfrom sympy import symbols, simplify\n\ndef solution():\n x = symbols('x')\n expr = 3*(x**2 - x**3 + x) + 3*(x + 2*x**3 - 3*x**2 + 3*x**5 + x**3) - 5*(1 + x - 4*x**3 - x**2)\n simplified_expr = simplify(expr)\n\n x3_coefficient = simplified_expr.as_coefficients_dict()[x**3]\n result = x3_coefficient\n return result\n```",
96 | ),
97 | (
98 | "The surface area of a sphere with radius $r$ is $4\\pi r^2$. Including the area of its circular base, what is the total surface area of a hemisphere with radius 6 cm? Express your answer in terms of $\\pi$.",
99 | "```python\nimport math\n\ndef solution():\n radius = 6\n\n # Surface area of the hemisphere\n hemisphere_area = 2 * math.pi * radius**2\n\n # Area of the circular base\n base_area = math.pi * radius**2\n\n # Total surface area\n total_surface_area = hemisphere_area + base_area\n\n # Formatting the result in LaTeX\n result = r'{}\\\\pi'.format(total_surface_area / math.pi)\n return result\n```",
100 | ),
101 | (
102 | "Monica tosses a fair 6-sided die. If the roll is a prime number, then she wins that amount of dollars (so that, for example, if she rolls 3, then she wins 3 dollars). If the roll is composite, she wins nothing. Otherwise, she loses 3 dollars. What is the expected value of her winnings on one die toss? Express your answer as a dollar value to the nearest cent.",
103 | '```python\ndef solution():\n # Probabilities of each outcome\n prime_prob = 1 / 6\n composite_prob = 1 / 3\n otherwise_prob = 1 / 6\n\n # Expected value of each outcome\n prime_expected_value = (2 * prime_prob) + (3 * prime_prob) + (5 * prime_prob)\n composite_expected_value = 0 * composite_prob\n otherwise_expected_value = -3 * otherwise_prob\n\n # Total expected value\n total_expected_value = prime_expected_value + composite_expected_value + otherwise_expected_value\n\n # Dollar value to the nearest cent\n result = "{:.2f}".format(total_expected_value)\n return result\n```',
104 | ),
105 | (
106 | "Given $\\mathbf{a} = \\begin{pmatrix} -7 \\\\ 0 \\\\ 1 \\end{pmatrix}$ and $\\mathbf{b} = \\begin{pmatrix} 4 \\\\ 2 \\\\ -1 \\end{pmatrix},$ find $\\mathbf{a} - 3 \\mathbf{b}.$",
107 |             "```python\nimport numpy as np\n\ndef solution():\n    a = np.array([-7, 0, 1])\n    b = np.array([4, 2, -1])\n\n    result = a - 3 * b\n\n    result = r'\\begin{{pmatrix}} {} \\\\ {} \\\\ {} \\end{{pmatrix}}'.format(result[0], result[1], result[2])\n    return result\n```",
108 | ),
109 | ]
110 | examples["math-tora"] = [
111 | (
112 | "Convert the point $(0, -3 \\sqrt{3}, 3)$ in rectangular coordinates to spherical coordinates. Enter your answer in the form $(\\rho,\\theta,\\phi),$ where $\\rho > 0,$ $0 \\le \\theta < 2 \\pi,$ and $0 \\le \\phi \\le \\pi.$",
113 | "To convert the point $(x, y, z)$ in rectangular coordinates to spherical coordinates $(\\rho, \\theta, \\phi)$, we use the following formulas:\n1. $\\rho = \\sqrt{x^2 + y^2 + z^2}$\n2. $\\theta = \\arctan{\\frac{y}{x}}$\n3. $\\phi = \\arccos{\\frac{z}{\\rho}}$\n\nLet's write a function:\n\n```python\nfrom sympy import sqrt, atan2, acos, pi\n\ndef rectangular_to_spherical():\n x, y, z = 0, -3*sqrt(3), 3\n rho = sqrt(x**2 + y**2 + z**2)\n theta = atan2(y, x)\n phi = acos(z/rho)\n return rho, theta, phi\n\nspherical_coordinates = rectangular_to_spherical()\nprint(spherical_coordinates)\n```\n```output\n(6, -pi/2, pi/3)\n```\n\nThe required form of the answer is $(\\rho, \\theta, \\phi)$, where $\\rho > 0$, $0 \\le \\theta < 2\\pi$, and $0 \\le \\phi \\le \\pi$. Since our initial output has $\\theta = -\\pi/2$, which doesn't meet the condition for $\\theta$, we add $2\\pi$ to adjust its value.\n\nHence, the spherical coordinates for the point $(0, -3\\sqrt{3}, 3)$ are $\\boxed{(6, \\frac{3\\pi}{2}, \\frac{\\pi}{3})}$.",
114 | ),
115 | (
116 | "Compute $1011_2 + 101_2 - 1100_2 + 1101_2$. Express your answer in base 2.",
117 | 'Let\'s write a single python function for binary computation:\n\n```python\ndef binary_sum_diff():\n # Convert binary to decimal\n num1 = int("1011", 2)\n num2 = int("101", 2)\n num3 = int("1100", 2)\n num4 = int("1101", 2)\n\n # Perform the operations\n result = num1 + num2 - num3 + num4\n\n # Convert the result back to binary\n result_binary = format(result, "b")\n\n return result_binary\n\nresult = binary_sum_diff()\nprint(result)\n```\n```output\n10001\n```\nThe answer expressed in base 2 is $\\boxed{10001_2}$.',
118 | ),
119 | (
120 | "Find the matrix that corresponds to projecting onto the vector $\\begin{pmatrix} 2 \\\\ -3 \\end{pmatrix}.$",
121 | "```python\nfrom sympy import Matrix, simplify\n\ndef projection_matrix():\n v = Matrix([2, -3])\n\n # Calculate the projection matrix\n P = (v * v.T) / (v.T * v)[0]\n\n # Simplify the elements\n P = P.applyfunc(simplify)\n\n return P\n\nresult = projection_matrix()\nprint(result)\n```\n```output\nMatrix([[4/13, -6/13], [-6/13, 9/13]])\n```\n\nThe matrix is $\\boxed{\\begin{pmatrix} \\frac{4}{13} & -\\frac{6}{13} \\ -\\frac{6}{13} & \\frac{9}{13} \\end{pmatrix}}.$",
122 | ),
123 | (
124 | "Solve the inequality\n\\[\\frac{x^2 - 25}{x + 5} < 0.\\]",
125 | "Let's use `sympy` to solve the inequality:\n\n```python\nfrom sympy import symbols, simplify\nfrom sympy.solvers.inequalities import solve_univariate_inequality\nfrom sympy.core.relational import StrictLessThan\n\ndef solve_inequality():\n x = symbols('x')\n expression = (x**2 - 25) / (x + 5)\n inequality = StrictLessThan(expression, 0)\n solution = solve_univariate_inequality(inequality, x, relational=False)\n simplified_solution = simplify(solution)\n \n return simplified_solution\n\nresult = solve_inequality()\nprint(result)\n```\n```output\nUnion(Interval.open(-oo, -5), Interval.Lopen(-5, 5))\n```\nThe answer is $\\boxed{(-\\infty,-5)\\cup(-5,5)}$",
126 | ),
127 | (
128 | 'In the figure, triangles $ABC$ and $BCD$ are equilateral triangles. What is the value of $AD \\div BC$ when expressed in simplest radical form?\n\n[asy]\ndraw((0,0)--(5,8.7)--(10,0)--cycle);\ndraw((10,0)--(15,8.7)--(5,8.7));\nlabel("$A$",(0,0),SW);\nlabel("$B$",(5,8.7),N);\nlabel("$C$",(10,0),SE);\nlabel("$D$",(15,8.7),NE);\n[/asy]',
129 | "```python\nfrom sympy import Rational, sqrt, simplify\n\ndef ad_divided_by_bc():\n\n x = Rational(1, 1) # Side length of equilateral triangles\n\n ad_squared = 2 * x**2 * (1 + Rational(1, 2)) # Using the law of cosines with cos(2*pi/3) = -1/2\n ad = sqrt(ad_squared)\n\n bc = x # BC is the side length of the equilateral triangles\n\n simplified_ratio = simplify(ad / bc)\n\n return simplified_ratio\n\nresult = ad_divided_by_bc()\nprint(result)\n```\n```output\nsqrt(3)\n```\nThe value of $AD \\div BC$ is $\\boxed{\\sqrt{3}}$.",
130 | ),
131 | ]
132 | examples["carp_en"] = [
133 | (
134 | "If $3 a ^ { m + 2 } b$ and $\\frac { 1 } { 2 } ab ^ { n - 1 }$ are similar terms, then $m + n$ is equal to.",
135 | "Let's think step by step\n$3 a ^ { m + 2 } b$ and $\\frac { 1 } { 2 } ab ^ { n - 1 }$ are like terms. We can obtain $m + 2 = 1$ and $n - 1 = 1$. Solving for $m$ and $n$, we get $m = - 1$ and $n = 2$. Therefore, $m + n = - 1 + 2 = 1$.\nThe answer is: 1",
136 | ),
137 | (
138 | "The solution to the equation $y - \\frac { y - 1 } { 2 } = - \\frac { y + 2 } { 5 }$ is ____ ?",
139 | "Let's think step by step\nTo eliminate the denominator, we have $10 y - 5 ( y - 1 ) = - 2 ( y + 2 )$. Expanding the brackets gives $10 y - 5 y + 5 = - 2 y - 4$. Rearranging terms gives $10 y - 5 y + 2 y = - 4 - 5$, which simplifies to $7 y = - 9$. Dividing both sides by 7 gives $y = - \\frac { 9 } { 7 }$.\nThe answer is: y = - \\frac { 9 } { 7 }",
140 | ),
141 | (
142 | "If $( m + 4 ) ^ 2 + | n - 3 | = 0$, then $\\frac { 1 } { 2 } m - n$ = ____?",
143 | "Let's think step by step\n$\\because ( m + 4 ) ^ 2 + | n - 3 | = 0$, $\\therefore m + 4 = 0$, $n - 3 = 0$, which means $m = - 4$, $n = 3$. Then the original expression equals $- 2 - 3 = - 5$.\nThe answer is: - 5",
144 | ),
145 | (
146 | "Given a quadratic equation in one variable $x$, $x ^ 2 + x + m = 0$, with one root being $x = 1$, what is the other root of this equation?",
147 | "Let's think step by step\nSuppose the quadratic equation in one variable about $x$ is $x ^ 2 + x + m = 0$, and $\\alpha$ is another real root of the equation. Since one real root of the quadratic equation in one variable about $x$ is $1$, we have $\\alpha + 1 = - 1$. Therefore, $\\alpha = - 2$.\nThe answer is: - 2",
148 | ),
149 | (
150 | "The parabola $y = - 5 { x } ^ 2 + 1$ is translated $2$ units upward and $1$ unit to the left, resulting in the parabola _____.",
151 | "Let's think step by step\nThe parabola $y = - 5 { x } ^ 2 + 1$ is first shifted upward by 2 units, resulting in $y = - 5 { x } ^ 2 + 3$. Then it is shifted left by 1 unit, resulting in $y = - 5 {( x + 1 )} ^ 2 + 3$.\nThe answer is: y = - 5 ( x + 1 ) ^ { 2 } + 3",
152 | ),
153 | (
154 | "If the radical $\\sqrt { x - 8 }$ is defined, then the range of real numbers for $x$ is ____ ?",
155 | "Let's think step by step\nSince the radical $\\sqrt { x - 8 }$ is defined, therefore $x - 8 \\ge 0$, which implies $x \\ge 8$.\nThe answer is: x \\ge 8",
156 | ),
157 | (
158 | "If $a ^ { m } \\times a ^ { 2 } = a ^ { 7 }$, then the value of $m$ is ____?",
159 | "Let's think step by step\nAccording to the multiplication rule of powers with the same base: when multiplying powers with the same base, keep the base the same and add the exponents. We have $m + 2 = 7$, so solving for $m$ gives $m = 5$.\nThe answer is: 5",
160 | ),
161 | (
162 | "If line segment $a$ and $b$ satisfy $\\frac { a } { b } = \\frac { 5 } { 2 }$, then the value of $\\frac { a - b } { b }$ is ____?",
163 | "Let's think step by step\n$\\because \\frac { a } { b } = \\frac { 5 } { 2 }$, $\\therefore$ we can assume $a = 5 k$, then $b = 2 k$, $\\therefore \\frac { a - b } { b } = \\frac { 5 k - 2 k } { 2 k } = \\frac { 3 } { 2 }$.\nThe answer is: \\frac { 3 } { 2 }",
164 | ),
165 | ]
166 |
167 | examples["minerva_math"] = [
168 | (
169 | "Find the domain of the expression $\\frac{\\sqrt{x-2}}{\\sqrt{5-x}}$.}",
170 | "The expressions inside each square root must be non-negative.\nTherefore, $x-2 \\ge 0$, so $x\\ge2$, and $5 - x \\ge 0$, so $x \\le 5$.\nAlso, the denominator cannot be equal to zero, so $5-x>0$, which gives $x<5$.\nTherefore, the domain of the expression is $\\boxed{[2,5)}$.",
171 | ),
172 | (
173 | "If $\\det \\mathbf{A} = 2$ and $\\det \\mathbf{B} = 12,$ then find $\\det (\\mathbf{A} \\mathbf{B}).$",
174 | "We have that $\\det (\\mathbf{A} \\mathbf{B}) = (\\det \\mathbf{A})(\\det \\mathbf{B}) = (2)(12) = \\boxed{24}.$",
175 | ),
176 | (
177 | "Terrell usually lifts two 20-pound weights 12 times. If he uses two 15-pound weights instead, how many times must Terrell lift them in order to lift the same total weight?",
178 | "If Terrell lifts two 20-pound weights 12 times, he lifts a total of $2\\cdot 12\\cdot20=480$ pounds of weight. If he lifts two 15-pound weights instead for $n$ times, he will lift a total of $2\\cdot15\\cdot n=30n$ pounds of weight. Equating this to 480 pounds, we can solve for $n$: \\begin{align*}\n30n&=480\\\\\\\n\\Rightarrow\\qquad n&=480/30=\\boxed{16}\n\\end{align*}",
179 | ),
180 | (
181 | "If the system of equations\n\n\\begin{align*}\n6x-4y&=a,\\\\\\\n6y-9x &=b.\n\\end{align*}has a solution $(x, y)$ where $x$ and $y$ are both nonzero, find $\\frac{a}{b},$ assuming $b$ is nonzero.",
182 | "If we multiply the first equation by $-\\frac{3}{2}$, we obtain\n\n$$6y-9x=-\\frac{3}{2}a.$$Since we also know that $6y-9x=b$, we have\n\n$$-\\frac{3}{2}a=b\\Rightarrow\\frac{a}{b}=\\boxed{-\\frac{2}{3}}.$$",
183 | ),
184 | ]
185 |
186 | examples["aqua"] = [
187 | (
188 | "John found that the average of 15 numbers is 40. If 10 is added to each number then the mean of the numbers is?\nAnswer Choices: (A) 50 (B) 45 (C) 65 (D) 78 (E) 64",
189 | "If 10 is added to each number, then the mean of the numbers also increases by 10. So the new mean would be 50. The answer is (A).",
190 | ),
191 | (
192 | "If a / b = 3/4 and 8a + 5b = 22,then find the value of a.\nAnswer Choices: (A) 1/2 (B) 3/2 (C) 5/2 (D) 4/2 (E) 7/2",
193 | "a / b = 3/4, then b = 4a / 3. So 8a + 5(4a / 3) = 22. This simplifies to 8a + 20a / 3 = 22, which means 44a / 3 = 22. So a is equal to 3/2. The answer is (B).",
194 | ),
195 | (
196 | "A person is traveling at 20 km/hr and reached his destiny in 2.5 hr then find the distance?\nAnswer Choices: (A) 53 km (B) 55 km (C) 52 km (D) 60 km (E) 50 km",
197 | "The distance that the person traveled would have been 20 km/hr * 2.5 hrs = 50 km. The answer is (E).",
198 | ),
199 | (
200 | "How many keystrokes are needed to type the numbers from 1 to 500?\nAnswer Choices: (A) 1156 (B) 1392 (C) 1480 (D) 1562 (E) 1788",
201 | "There are 9 one-digit numbers from 1 to 9. There are 90 two-digit numbers from 10 to 99. There are 401 three-digit numbers from 100 to 500. 9 + 90(2) + 401(3) = 1392. The answer is (B).",
202 | ),
203 | ]
204 | examples["sat_math"] = [
205 | (
206 | "If $\frac{x-1}{3}=k$ and $k=3$, what is the value of $x$ ? \nAnswer Choices: (A) 2 (B) 4 (C) 9 (D) 10",
207 |             "If k = 3, then x - 1 = 3 * 3, therefore, x - 1 = 9 and x = 10. The answer is D",
208 | ),
209 | (
210 | "For $i=\\sqrt{-1}$, what is the sum $(7+3 i)+(-8+9 i)$ ? \nAnswer Choices: (A) $-1+12 i$ (B) $-1-6 i$ (C) $15+12 i$ (D) $15-6 i$ 3",
211 |             "For (7+3 i)+(-8+9 i), the real part is 7 + (-8) = -1, the imaginary part is 3 i + 9 i = 12 i. The answer is A",
212 | ),
213 | (
214 | "On Saturday afternoon, Armand sent $m$ text messages each hour for 5 hours, and Tyrone sent $p$ text messages each hour for 4 hours. Which of the following represents the total number of messages sent by Armand and Tyrone on Saturday afternoon?\nAnswer Choices: (A) $9 m p$ (B) $20 m p$ (C) $5 m+4 p$ (D) $4 m+5 p$",
215 |             "Armand texts m messages each hour for 5 hours, which leads to 5m messages. Tyrone texts p messages each hour for 4 hours, which leads to 4p messages. The total is 5m + 4p. The answer is C.",
216 | ),
217 | (
218 | "$$\begin{array}{r}3 x+4 y=-23 \\2 y-x=-19\\end{array}$$What is the solution $(x, y)$ to the system of equations above?\nAnswer Choices: (A) $(-5,-2)$ (B) $(3,-8)$ (C) $(4,-6)$ (D) $(9,-6)$",
219 | "By solving this equation, we found that x = 3 and y = -8. The answer is B.",
220 | ),
221 | ]
222 | examples["mmlu_mathematics"] = [
223 | (
224 | "Simplify and write the result with a rational denominator: $$\\sqrt{\\sqrt[3]{\\sqrt{\frac{1}{729}}}}$$\nAnswer Choices: (A) \\frac{3\\sqrt{3}}{3} (B) \\frac{1}{3} (C) \\sqrt{3} (D) \\frac{\\sqrt{3}}{3}",
225 | "Factoring $729=3^6$ and combining the roots $\frac{1}{2}\frac{1}{3}\frac{1}{2}=\frac{1}{12}$, we get that $\\sqrt{\\sqrt[3]{\\sqrt{\frac{1}{729}}}}=\\left(\frac{1}{3^6}\right)^{\frac{1}{12}}=\frac{1}{3^{\frac{1}{2}}}=\frac{3}{\\sqrt{3}}$. The answer is (D).",
226 | ),
227 | (
228 | "Five thousand dollars compounded annually at an $x\\%$ interest rate takes six years to double. At the same interest rate, how many years will it take $\\$300$ to grow to $\\$9600$?\nAnswer Choices:(A) 12 (B) 1 (C) 30 (D) 5",
229 | "To go from $\\$300$ to $\\$9600$, the value must go up by a factor of $9600/300=32=2^5$. Since at this interest rate it takes six years for it to double, it will take $5*6=30$ years to grow to $\\$9600$. The answer is (C).",
230 | ),
231 | (
232 | "Ten students take a biology test and receive the following scores: 45, 55, 50, 70, 65, 80, 40, 90, 70, 85. What is the mean of the students’ test scores?\nAnswer Choices: (A) 55 (B) 60 (C) 62 (D) 65",
233 | "There are 10 students and the sum of their scores is $45 + 55 + 50 + 70 + 65 + 80 + 40 + 90 + 70 + 85 = 650$, the mean is $650/10=65$. The answer is (D).",
234 | ),
235 | (
236 | "The variable $x$ varies directly as the square of $y$, and $y$ varies directly as the cube of $z$. If $x$ equals $-16$ when $z$ equals 2, what is the value of $x$ when $z$ equals $\frac{1}{2}$?\nAnswer Choices: (A) -1 (B) 16 (C) -\frac{1}{256} (D) \\frac{1}{16}",
237 | "We know that $x \\propto y^2$ and $y \\propto z^3$, so $x = k z^6$ for some constant $k$. Plugging in for $x=-16$ and $z=2$, the constant value is $k=\frac{x}{z^6}=\frac{-16}{64}=-\frac{1}{4}$. So, when $z=\frac{1}{2}$, the value of $x$ is $x=kz^6=-\frac{1}{4}\frac{1}{2^6}=-\frac{1}{256}$. The answer is (C).",
238 | ),
239 | (
240 | "Joe was in charge of lights for a dance. The red light blinks every two seconds, the yellow light every three seconds, and the blue light every five seconds. If we include the very beginning and very end of the dance, how many times during a seven minute dance will all the lights come on at the same time? (Assume that all three lights blink simultaneously at the very beginning of the dance.)\nAnswer Choices: (A) 3 (B) 15 (C) 6 (D) 5",
241 | "The least common multiple of 2, 3 and 5 is 30, so during a 7 minute dance, all the three lights will come on at the same time $2*7+1=15$ times. The answer is (B).",
242 | ),
243 | ]
244 | examples["mmlu_physics"] = [
245 | (
246 | "A microwave oven is connected to an outlet, 120 V, and draws a current of 2 amps. At what rate is energy being used by the microwave oven?\nAnswer Choices: (A) 10 W (B) 30 W (C) 60 W (D) 240 W",
247 | "Rate of energy usage is known as power; in an dissipative electrical circuit, power is given by voltage times current. So in our case, the power is 120 V times 2 amps, or 240 W. The answer is (D).",
248 | ),
249 | (
250 | "A point charge, Q = +1 mC, is fixed at the origin. How much work is required to move a charge, Q = +8 µC, from the point (0, 4 meters) to the point (3 meters, 0)?\nAnswer Choices: (A) 3.5 J (B) 6.0 J (C) 22.5 J (D) 40 J",
251 | "To calculate the work required to move a charge from one location to another in a fixed electric field, it is enough to calculate the potential difference between the two locations. Here, the potential only depends on the distance between the charges; it’s $k q_1 q_2 / r$, where $k$ is Coulomb’s constant. Plugging in values $q_1 = $ 1 mC, $q_2 = 8 \\mu$ C, gives the answer as 5.992 J, which rounds to 6 J. The answer is (B).",
252 | ),
253 | (
254 | "Which of the following conditions will ensure that angular momentum is conserved? I. Conservation of linear momentum II. Zero net external force III. Zero net external torque.\nAnswer Choices: (A) I and II only (B) I and III only (C) II and III only (D) III only",
255 | "Torque is defined as the change in angular momentum; if there is zero external torque, angular momentum is conserved. The answer is (D).",
256 | ),
257 | (
258 | "A photocell of work function ϕ = 2eV is connected to a resistor in series. Light of frequency f = 1 × 10^15 Hz hits a metal plate of the photocell. If the power of the light is P = 100 W, what is the current through the resistor?\nAnswer Choices: (A) 2:00 AM (B) 6:00 AM (C) 12:00 AM (D) 24 A",
259 | "The only answer above which has units of current is D, 24 A. The answer is (D).",
260 | ),
261 | (
262 | "A pipe full of air is closed at one end. A standing wave is produced in the pipe, causing the pipe to sound a note. Which of the following is a correct statement about the wave’s properties at the closed end of the pipe?\nAnswer Choices: (A) The pressure is at a node, but the particle displacement is at an antinode. (B) The pressure is at an antinode, but the particle displacement is at a node. (C) The pressure and the particle displacement are both at nodes. (D) The pressure and the particle displacement are both at antinodes.",
263 | "At the closed end of the pipe, the particles cannot have any net displacement because the pipe closure stops them. So the particle displacement is at a node. This closure also causes the pressure to be maximal, i.e. an antinode. The answer is (B).",
264 | ),
265 | ]
266 | examples["mmlu_chemistry"] = [
267 | (
268 | "Which of the following is considered an acid anhydride?\nAnswer Choices: (A) HCl (B) H2SO3 (C) SO2 (D) Al(NO3)3",
269 | "An acid anhydride is a compound that is derived by removing water from an acid. The chemical formula for water is H2O, which means that we need to determine which of these options, when combined with H2O, forms an acid. SO2, or Sulfur dioxide, when combined with H2O, makes H2SO4, or sulfuric acid. The answer is (C).",
270 | ),
271 | (
272 | "Which of the following is expected to be a polar molecule?\nAnswer Choices: (A) PCl4F (B) BF3 (C) CO2 (D) Si(CH3)4",
273 | "A polar molecule is one that has a slightly positive charge on one end of the molecule and a slightly negative charge on the other end. Boron trifluoride (BF3) has Boron as the center atom and three fluorine atoms attached to it; it is trigonal planar and symmetric, so it is nonpolar. Carbon Dioxide (CO2) has Carbon as the central atom with double bonds to two Oxygen atoms - this is also symmetrical and therefore nonpolar. The same is the case for tetramethyl silane (SI(CH3)4), which is a Silicon atom surrounded by four methyl groups. The structure of PCL4F is that Phosphorus is the central atom, attached to four chlorines and one fluorine atom. This is asymmetrical, and therefore has a net dipole and is expected to be a polar molecule. The answer is (A).",
274 | ),
275 | (
276 | "From the solubility rules, which of the following is true?\nAnswer Choices: (A) All chlorides, bromides, and iodides are soluble (B) All sulfates are soluble (C) All hydroxides are soluble (D) All ammonium-containing compounds are soluble",
277 | "The chlorides, bromides, and iodides of lead, silver, and mercury are not soluble in water. This rules out (A). The sulfates of lead, barium, and calcium are not soluble in water, which rules out (B). The hydroxides of any metal besides sodium, potassium, ammonium, calcium, and barium are insoluble. This rules out (C). Typically ammonium ions indicate a soluble ionic substance. The answer is (D).",
278 | ),
279 | (
280 | "A new compound is synthesized and found to be a monoprotic acid with a molar mass of 248 g/mol. When 0.0050 mol of this acid are dissolved in 0.500 L of water, the pH is measured as 3.89. What is the pKa of this acid?\nAnswer Choices: (A) 3.89 (B) 7.78 (C) 5.78 (D) 2.33",
281 | "Recall that $[A] = [H^{+}]$. Here, this is equal to $$10^{-3.89}$. Then we have $K_{a} = $frac{[H^{+}][A^{-}]}{[HA]} = \\frac{10^{-3.89} \\cdot 10^{-3.89}}{10^{-2}}. The resulting exponent is $-3.89 + (-3.89) - (-2) = 5.78$, therefore $K_a = 10^{-5.78}$. The $pK_a$ is the negative log of $K_a$, which is equal to $5.78$. The answer is (C).",
282 | ),
283 | (
284 | "A solution contains 2.00 mole of acetic acid, CH3COOH, and 1.00 mole of calcium acetate, Ca(CH3COO)2. The solution is able to resist the addition of a small amount of strong acid or strong base with only minor changes in the pH of the solution. Larger quantities of strong acid or strong base can cause a significant change in pH. How many moles of nitric acid, HNO3, may be added before the pH begins to change significantly?\nAnswer Choices: (A) 0.500 mole (B) 1.00 mole (C) 2.00 mole (D) 3.00 mole",
285 | "We would like to compute the buffer capacity of this solution. First we write the equation for the ionization of the weak acid, in this case of acetic acid. $CH_{3}COOH (aq) + H_{2}O \rightarrow H_{3}O^{+} + CH3COO^{-}$. The conjugate base is therefore the acetate ion. The added strong acid, Nitric acid, will react with the conjugate base. Therefore the maximum amount of acid that can be added will be equal to the amount of acetate ion, or 2 moles. The answer is (C).",
286 | ),
287 | ]
288 | examples["mmlu_biology"] = [
289 | (
290 | "In animal cells, which of the following represents the most likely pathway that a secretory protein takes as it is synthesized in a cell?\nAnswer Choices: (A) Plasma membrane–Golgi apparatus–ribosome–secretory vesicle–rough ER (B) Ribosome–Golgi apparatus–rough ER–secretory vesicle–plasma membrane (C) Plasma membrane–Golgi apparatus–ribosome–secretory vesicle–rough ER (D) Ribosome–rough ER–Golgi apparatus–secretory vesicle–plasma membrane",
291 | "Protein synthesis starts at the ribosome, so we can eliminate (A) and (C). The ribosome is often in the endoplasmic reticulum and moves from there to the Golgi apparatus, where it is modified and packaged into a vesicle. The vesicle then floats to the plasma membrane and is secreted. The answer is (D).",
292 | ),
293 | (
294 | "A mutation in a bacterial enzyme changed a previously polar amino acid into a nonpolar amino acid. This amino acid was located at a site distant from the enzyme’s active site. How might this mutation alter the enzyme’s substrate specificity?\nAnswer Choices: (A) By changing the enzyme’s pH optimum (B) By changing the enzyme’s location in the cell (C) By changing the shape of the protein (D) An amino acid change away from the active site cannot alter the enzyme’s substrate specificity.",
295 | "A change in an amino acid leads to a change in the primary structure of the protein. A change in the primary structure may lead to a change in the secondary and the tertiary structure of the protein. A change in the tertiary structure means a change in the shape of the protein, so (C) has to be correct. Since the change does not affect the active site of the enzyme, we do not expect the activity of the enzyme to be affected. The answer is (C).",
296 | ),
297 | (
298 | "Which of the following is not a way to form recombinant DNA?\nAnswer Choices: (A) Translation (B) Conjugation (C) Specialized transduction (D) Transformation",
299 | "The introduction of foreign DNA or RNA into bacteria or eukaryotic cells is a common technique in molecular biology and scientific research. There are multiple ways foreign DNA can be introduced into cells including transformation, transduction, conjugation, and transfection. In contrast, (A) is not a way to form DNA: during translation the ribosomes synthesize proteins from RNA. The answer is (A).",
300 | ),
301 | (
302 | "Homologous structures are often cited as evidence for the process of natural selection. All of the following are examples of homologous structures EXCEPT\nAnswer Choices: (A) the wings of a bird and the wings of a bat (B) the flippers of a whale and the arms of a man (C) the pectoral fins of a porpoise and the flippers of a seal (D) the forelegs of an insect and the forelimbs of a dog",
303 | "Homologous structures are similar physical features in organisms that share a common ancestor but different functions. Comparisons (B) and (C) are clearly homologous because they share a common ancestor and the structures serve different purposes. Bat wings and birg wings are also homologous, while they are both wings, the forelimbs serve different purposes. Insects and dogs are very far ancestors since one is vertebrate while the other is invertebrate and the forelimbs serve the same purpose, so they are not homologous. The answer is (D).",
304 | ),
305 | (
306 | "Which of the following is not known to be involved in the control of cell division?\nAnswer Choices: (A) Cyclins (B) Protein kinases (C) Checkpoints (D) Fibroblast cells",
307 | "Normal cells move through the cell cycle in a regulated way. At the checkpoint stage, they use information about their own internal state and cues from the environment around them to decide whether to proceed with cell division. Cues like these act by changing the activity of core cell cycle regulators inside the cell. The most common regulators are cyclins and cyclin-dependent kinases. Fibroblast cells do not play any role in cell division. The answer is (D).",
308 | ),
309 | ]
310 | examples["mmlu_computer"] = [
311 | (
312 | "Which of the following is an example of the use of a device on the Internet of Things (IoT) ?\nAnswer Choices: (A) A car alerts a driver that it is about to hit an object. (B) A hiker uses a G P S watch to keep track of her position. (C) A refrigerator orders milk from an online delivery service when the milk in the refrigerator is almost gone. (D) A runner uses a watch with optical sensors to monitor his heart rate.",
313 | "The term Internet of Things (IoT) refers to common devices which are connected to the internet, enabling new functionality. Choice A is incorrect because it does not describe an internet connected device. In choice B, the watch is only described as having GPS functionality but no internet connectivity. Choice C describes a common device (a refrigerator) which has internet connectivity enabling new functionality (online ordering). Choice D does not mention internet connectivity for the watch, only optical sensors. The answer is (C).",
314 | ),
315 | (
316 | "Many Web browsers allow users to open anonymous windows. During a browsing session in an anonymous window, the browser does not record a browsing history or a list of downloaded files. When the anonymous window is exited, cookies created during the session are deleted. Which of the following statements about browsing sessions in an anonymous window is true?\nAnswer Choices: (A) The activities of a user browsing in an anonymous window will not be visible to people who monitor the user's network, such as the system administrator. (B) Items placed in a Web store's shopping cart for future purchase during the anonymous browsing session will not be saved on the user's computer. (C) A user will not be able to log in to e-mail or social media accounts during the anonymous browsing session. (D) A user browsing in an anonymous window will be protected from viruses launched from any web sites visited or files downloaded.",
317 | "Choice A is incorrect as it only describes network traffic, which an anonymous browser does not change. Choice B is correct as it correctly describes how an anonymous browser will prevent saving data on the user’s computer after the session is ended. Choice C is incorrect because an anonymous browser will not prevent logging in to email or social media accounts. Choice D is incorrect because an anonymous browser in itself performs no virus protection. The answer is (B).",
318 | ),
319 | (
320 | 'What is the output of "abc"[::-1] in Python 3? \nAnswer Choices: (A) Error (B) abc (C) cba (D) c',
321 | 'We know that the slicing operator [::-1] takes all of the elements in the string in reverse order, so we reverse the order of the string "abc", resulting in "cba". The answer is (C).',
322 | ),
323 | (
324 | 'In the program below, the initial value of X is 5 and the initial value of Y is 10.\nIF (X < 0){\n DISPLAY ("Foxtrot")\n} ELSE {\n IF (X > Y){\n DISPLAY ("Hotel")\n } ELSE {\n IF (Y > 0){\n DISPLAY ("November")\n } ELSE {\n DISPLAY ("Yankee")\n }\n }\n}\nWhat is displayed as a result of running the program?\nAnswer Choices: (A) Foxtrot (B) Hotel (C) November (D) Yankee',
325 | 'Because X has the value 5, the first conditional IF (X < 0) is false, so we move to the first ELSE clause. Because X is 5 and Y is 10, the second conditional IF (X > Y) is false, so we move to the following ELSE clause. Since Y is 10, the conditional IF (Y > 0) is true, so the command DISPLAY ("November") is executed. The answer is (C).',
326 | ),
327 | (
328 | "A list of numbers has n elements, indexed from 1 to n. The following algorithm is intended to display the number of elements in the list that have a value greater than 100. The algorithm uses the variables count and position. Steps 3 and 4 are missing.\n Step 1: Set count to 0 and position to 1.\n Step 2: If the value of the element at index position is greater than 100, increase the value of count by 1.\n Step 3: (missing step)\n Step 4: (missing step)\n Step 5: Display the value of count.\nWhich of the following could be used to replace steps 3 and 4 so that the algorithm works as intended?\nAnswer Choices: (A) Step 3: Increase the value of position by 1.\n Step 4: Repeat steps 2 and 3 until the value of count is greater than 100.\n(B) Step 3: Increase the value of position by 1.\n Step 4: Repeat steps 2 and 3 until the value of position is greater than n.\n(C) Step 3: Repeat step 2 until the value of count is greater than 100.\n Step 4: Increase the value of position by 1.\n(D) Step 3: Repeat step 2 until the value of position is greater than n.\n Step 4: Increase the value of count by 1.",
329 | "Choice A is incorrect, because its Step 4 has an incorrect termination condition, stopping when count is greater than 100. We need to stop after inspecting all elements in the list. Choice B is correct because it correctly increments both count and position, and correctly repeats these steps and terminates when all elements in the list have been inspected. Choice C is incorrect because it incorrectly increments the variable count until its value is greater than 100, regardless of the elements in the list. Choice D is incorrect because its step 3 does not increment the value of position, so it will repeat forever. The answer is (B).",
330 | ),
331 | ]
332 | # mammoth
333 | examples["mmlu_stem"] = [
334 | (
335 | "Simplify and write the result with a rational denominator: $$\\sqrt{\\sqrt[3]{\\sqrt{\frac{1}{729}}}}$$\nAnswer Choices: (A) \\frac{3\\sqrt{3}}{3} (B) \\frac{1}{3} (C) \\sqrt{3} (D) \\frac{\\sqrt{3}}{3}",
336 | "Factoring $729=3^6$ and combining the roots $\\frac{1}{2}\\frac{1}{3}\\frac{1}{2}=\\frac{1}{12}$, we get that $\\sqrt{\\sqrt[3]{\\sqrt{\frac{1}{729}}}}=\\left(\frac{1}{3^6}\right)^{\frac{1}{12}}=\frac{1}{3^{\frac{1}{2}}}=\frac{3}{\\sqrt{3}}$. The answer is (D).",
337 | ),
338 | (
339 | "In animal cells, which of the following represents the most likely pathway that a secretory protein takes as it is synthesized in a cell?\nAnswer Choices: (A) Plasma membrane–Golgi apparatus–ribosome–secretory vesicle–rough ER (B) Ribosome–Golgi apparatus–rough ER–secretory vesicle–plasma membrane (C) Plasma membrane–Golgi apparatus–ribosome–secretory vesicle–rough ER (D) Ribosome–rough ER–Golgi apparatus–secretory vesicle–plasma membrane",
340 | "Protein synthesis starts at the ribosome, so we can eliminate (A) and (C). The ribosome is often in the endoplasmic reticulum and moves from there to the Golgi apparatus, where it is modified and packaged into a vesicle. The vesicle then floats to the plasma membrane and is secreted. The answer is (D).",
341 | ),
342 | (
343 | "A microwave oven is connected to an outlet, 120 V, and draws a current of 2 amps. At what rate is energy being used by the microwave oven?\nAnswer Choices: (A) 10 W (B) 30 W (C) 60 W (D) 240 W",
344 | "Rate of energy usage is known as power; in an dissipative electrical circuit, power is given by voltage times current. So in our case, the power is 120 V times 2 amps, or 240 W. The answer is (D).",
345 | ),
346 | (
347 | "Which of the following is considered an acid anhydride?\nAnswer Choices: (A) HCl (B) H2SO3 (C) SO2 (D) Al(NO3)3",
348 | "An acid anhydride is a compound that is derived by removing water from an acid. The chemical formula for water is H2O, which means that we need to determine which of these options, when combined with H2O, forms an acid. SO2, or Sulfur dioxide, when combined with H2O, makes H2SO4, or sulfuric acid. The answer is (C).",
349 | ),
350 | (
351 | 'What is the output of "abc"[::-1] in Python 3? \nAnswer Choices: (A) Error (B) abc (C) cba (D) c',
352 | 'We know that the slicing operator [::-1] takes all of the elements in the string in reverse order, so we reverse the order of the string "abc", resulting in "cba". The answer is (C).',
353 | ),
354 | ]
355 | examples["gaokao"] = [
356 | (
357 | "已知 $\\alpha, \\beta, \\gamma$ 是互不相同的锐角, 则在 $\\sin \\alpha \\cos \\beta, \\sin \\beta \\cos \\gamma, \\sin \\gamma \\cos \\alpha$ 三个值中, 大于 $\\frac{1}{2}$ 的个数的最大值是 ( )\n从以下选项中选择:\n(A) 0\n(B) 1\n(C) 2\n(D) 3",
358 | "1. 如果 $\\alpha, \\beta, \\gamma$ 均小于 $60^\\circ$,那么他们的正弦值都小于 $\\frac{1}{2}$,因此三个值中不可能有大于 $\\frac{1}{2}$ 的值。\n2. 如果有一个角大于 $60^\\circ$,假设为 $\\alpha$,那么对应的正弦值大于 $\\frac{1}{2}$。此时,由于三角形内角和为 $180^\\circ$,所以 $\\beta + \\gamma < 120^\\circ$。这意味着 $\\beta, \\gamma$ 的余弦值均大于 $\\frac{1}{2}$,所以此时 $\\sin \\alpha \\cos \\beta > \\frac{1}{2}, \\sin \\beta \\cos \\gamma > \\frac{1}{2}$。\n3. 如果有两个角大于 $60^\\circ$,例如 $\\alpha$ 和 $\\beta$,那么由于三角形内角和为 $180^\\circ$,我们可以得到 $\\gamma < 60^\\circ$,此时 $\\sin \\gamma < \\frac{1}{2}$。由于 $\\alpha$ 和 $\\beta$ 的余弦值都小于 $\\frac{1}{2}$,因此三个值中不可能有大于 $\\frac{1}{2}$ 的值。\n4. 如果三个角都大于 $60^\\circ$,显然不符合题意。\n综上所述,当有一个角大于 $60^\\circ$ 时,大于 $\\frac{1}{2}$ 的个数的最大值是 2。\n答案是 C",
359 | ),
360 | (
361 | "正方体 $A B C D-A_{1} B_{1} C_{1} D_{1}$ 中, $B B_{1}$ 与平面 $A C D_{1}$ 所成角的余弦值为 ( )\n从以下选项中选择:\n(A) $\\frac{\\sqrt{2}}{3}$\n(B) $\\frac{\\sqrt{3}}{3}$\n(C) $\\frac{2}{3}$\n(D) $\\frac{\\sqrt{6}}{3}$",
362 | "设上下底面的中心分别为 $\\mathrm{O}_{1}, \\mathrm{O}$, 设正方体的棱长等于 1 , 则 $O_{1} O$ 与平面 $A C D_{1}$ 所成角就是 $B B_{1}$ 与平面 $A C D_{1}$ 所成角, 即 $\\angle O_{1} O D_{1}$, 直角三角形 $\\mathrm{OO}_{1} \\mathrm{D}_{1}$ 中, $\\cos \\angle \\mathrm{O}_{1} \\mathrm{OD}_{1}=\\frac{\\mathrm{O}_{1} \\mathrm{O}}{\\mathrm{OD}_{1}}=\\frac{\\frac{1}{\\sqrt{6}}}{2}=\\frac{\\sqrt{6}}{3}$. \n答案是 C",
363 | ),
364 | (
365 | "设函数 $f(x)=\\left\\{\\begin{array}{ll}1+\\log _{2}(2-x), & x<1 \\ 2^{x-1}, & x \\geqslant 1,\\end{array}\\right.$ 则 $f(-2)+f\\left(\\log _{2} 12\\right)=$ ( )\n从以下选项中选择:\n(A) 3\n(B) 6\n(C) 9\n(D) 12",
366 | "首先,我们可以根据定义计算 $f(-2)$ 和 $f(\\log_2 12)$:\n$f(-2)=1+\\log_2(2-(-2))=1+\\log_2 4=3$\n$f(\\log_2 12)=2^{\\log_2 12-1}=6$\n因此,$f(-2)+f(\\log_2 12)=3+6=9$。\n答案是 C",
367 | ),
368 | (
369 | "已知函数 $f(x)=a x^{3}-3 x^{2}+1$, 若 $f(x)$ 存在唯一的零点 $x_{0}$, 且 $x_{0}>$ 0 , 则实数 $\\mathrm{a}$ 的取值范围是 ( )\n从以下选项中选择:\n(A) $(1,+\\infty)$\n(B) $(2,+\\infty)$\n(C) $(-\\infty,-1)$\n(D) $(-\\infty,-2)$",
370 | "首先,我们可以通过求出函数的导函数 $f'(x)$ 来判断函数在 $x>0$ 区间内的单调性。在这里,我们求出导函数 $f'(x)$ 为 $f'(x)=3ax^2-6x$。\n然后,我们需要求出导函数 $f'(x)$ 的零点,以确定函数 $f(x)$ 在 $x>0$ 区间内的单调性。导函数 $f'(x)$ 的零点为 $x=0$ 和 $x=\\frac{2}{\\sqrt{a}}$。注意到 $x>0$,所以我们得到 $a<0$。此外,由于函数 $f(x)$ 在 $x=0$ 处的函数值为 $1$,因此不能有 $a=\\frac{4}{3}$。\n综上所述,当 $a$ 的取值范围为 $a<-\\frac{4}{3}$ 时,函数 $f(x)$ 在 $x>0$ 区间内是单调递减的,此时存在唯一的零点 $x_0$。因此,答案为 $(-\\infty,-2)$。\n答案是 D",
371 | ),
372 | (
373 | "设 $\\left\\{a_{n}\\right\\}$ 是公差不为 0 的无穷等差数列, 则“ $\\left\\{a_{n}\\right\\}$ 为递增数列”是“存在正整数 $N_{0}$, 当 $n>N_{0}$ 时, $a_{n}>0$ ”的 ( )\n从以下选项中选择:\n(A) 充分而不必要条件\n(B) 必要而不充分条件\n(C) 充分必要条件\n(D) 既不充分也不必要条件",
374 | "首先,我们可以通过举例来判断该条件是充分还是必要条件。如果一个数列递增,那么它的公差一定大于 0,也就是存在正整数 $N_{0}$,当 $n>N_{0}$ 时,$a_{n}>0$。因此,“ $\\left\\{a_{n}\\right\\}$ 为递增数列”是“存在正整数 $N_{0}$, 当 $n>N_{0}$ 时, $a_{n}>0$ ”的必要条件。\n接下来,我们需要判断是否充分。也就是说,如果存在正整数 $N_{0}$,当 $n>N_{0}$ 时,$a_{n}>0$,那么能否得出“ $\\left\\{a_{n}\\right\\}$ 为递增数列”这一结论。答案是肯定的。因为如果 $a_{n}>0$,那么 $a_{n+1}-a_{n}>0$,即公差大于 0,因此该数列是递增的。因此,该条件是充分条件。\n综上所述,选项为 (C) 充分必要条件。\n答案是 C",
375 | ),
376 | ]
377 |
378 | return examples
379 |
--------------------------------------------------------------------------------