├── assets
    ├── main_table.png
    ├── method_NFT.jpg
    ├── distribution_NFT.jpg
    ├── main_compare_NFT.jpg
    ├── val_acc_curve_NFT.jpg
    └── algorithm_spectrum_NFT.jpg
├── requirements.txt
├── download_model.sh
├── download_data.sh
├── eval_local_7B.sh
├── eval_local_32B.sh
├── experience_maker.py
├── .gitignore
├── verifier.py
├── compute_acc.py
├── README.md
├── train_7B.sh
├── train_32B.sh
├── main_nft.py
├── config
    └── nft_trainer.yaml
├── qwen_math_eval_toolkit
    ├── utils.py
    ├── grader.py
    ├── parser.py
    └── examples.py
├── README_VeRL.md
├── LICENSE
└── dp_actor.py


/assets/main_table.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NVlabs/NFT/HEAD/assets/main_table.png


--------------------------------------------------------------------------------
/assets/method_NFT.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NVlabs/NFT/HEAD/assets/method_NFT.jpg


--------------------------------------------------------------------------------
/assets/distribution_NFT.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NVlabs/NFT/HEAD/assets/distribution_NFT.jpg


--------------------------------------------------------------------------------
/assets/main_compare_NFT.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NVlabs/NFT/HEAD/assets/main_compare_NFT.jpg


--------------------------------------------------------------------------------
/assets/val_acc_curve_NFT.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NVlabs/NFT/HEAD/assets/val_acc_curve_NFT.jpg


--------------------------------------------------------------------------------
/assets/algorithm_spectrum_NFT.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NVlabs/NFT/HEAD/assets/algorithm_spectrum_NFT.jpg


--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
# requirements.txt records the full set of dependencies for development
accelerate
codetiming
datasets
dill
flash-attn
hydra-core
liger-kernel
math-verify[antlr4_9_3]
numpy
pandas
peft
pyarrow>=15.0.0
pybind11
pylatexenc
ray[default]
tensordict<0.6
torchdata
transformers
vllm<=0.6.3
wandb


--------------------------------------------------------------------------------
/download_model.sh:
--------------------------------------------------------------------------------
mkdir -p models
# required only for 7B models
huggingface-cli download Qwen/Qwen2.5-Math-7B --local-dir ./models/Qwen2.5-Math-7B --local-dir-use-symlinks False

# required only for 32B models
huggingface-cli download Qwen/Qwen2.5-32B --local-dir ./models/Qwen2.5-32B --local-dir-use-symlinks False
# replace 32B tokenizer config to keep a consistent system prompt with 7B
cp models/32b_tokenizer_config.json models/Qwen2.5-32B/tokenizer_config.json


--------------------------------------------------------------------------------
/download_data.sh:
--------------------------------------------------------------------------------
HF_REPO_BASE="https://huggingface.co/datasets/ChenDRAG/VeRL_math_validation/resolve/main"
mkdir -p data
wget -O data/dapo-math-17k_10boxed.parquet "${HF_REPO_BASE}/dapo-math-17k_10boxed.parquet?download=true"
wget -O data/aime-2024-boxed_w_answer.parquet "${HF_REPO_BASE}/aime-2024-boxed_w_answer.parquet?download=true"
wget -O data/math500_boxed.parquet "${HF_REPO_BASE}/math500_boxed.parquet?download=true"
wget -O data/minerva_math.parquet "${HF_REPO_BASE}/minerva_math.parquet?download=true"
wget -O data/olympiadbench.parquet "${HF_REPO_BASE}/olympiadbench.parquet?download=true"
wget -O data/aime2025_32_dapo_boxed_w_answer.parquet "${HF_REPO_BASE}/aime2025_32_dapo_boxed_w_answer.parquet?download=true"
wget -O data/amc2023_32_dapo_boxed_w_answer.parquet "${HF_REPO_BASE}/amc2023_32_dapo_boxed_w_answer.parquet?download=true"

--------------------------------------------------------------------------------
/eval_local_7B.sh:
--------------------------------------------------------------------------------
pip install word2number
save_path=./NFT_32B
context_length=8192
max_prompt_length=1024
n_samples=1
top_p=0.7
temperature=1.0
top_k=-1
max_response_length=$(($context_length-$max_prompt_length))


for data_name in aime-2024-boxed_w_answer math500_boxed minerva_math olympiadbench aime2025_32_dapo_boxed_w_answer amc2023_32_dapo_boxed_w_answer; do
    data_load_path=./data/$data_name.parquet; \
    data_save_path=${save_path}/${data_name}_max${context_length}_topp${top_p}topk${top_k}_temp${temperature}_@${n_samples}.parquet; \
    python -u -m verl.trainer.main_generation \
        trainer.nnodes=1 \
        trainer.n_gpus_per_node=8 \
        data.path=$data_load_path \
        data.prompt_key=prompt \
        data.n_samples=$n_samples \
        data.output_path=$data_save_path \
        model.path=$save_path/huggingface \
        +model.trust_remote_code=True \
        rollout.temperature=$temperature \
        rollout.top_k=$top_k \
        rollout.top_p=$top_p \
        rollout.prompt_length=$max_prompt_length \
        rollout.response_length=$max_response_length \
        rollout.tensor_model_parallel_size=4 \
        rollout.gpu_memory_utilization=0.8; \
    python3 -m compute_acc --input_path $data_save_path --verifier all
done
--------------------------------------------------------------------------------
/eval_local_32B.sh:
--------------------------------------------------------------------------------
pip install word2number
save_path=./NFT_7B
context_length=4096
max_prompt_length=512
n_samples=1
top_p=0.7
temperature=0.6
top_k=-1
max_response_length=$(($context_length-$max_prompt_length))


for data_name in aime-2024-boxed_w_answer math500_boxed minerva_math olympiadbench aime2025_32_dapo_boxed_w_answer amc2023_32_dapo_boxed_w_answer; do
    data_load_path=./data/$data_name.parquet; \
    data_save_path=${save_path}/${data_name}_max${context_length}_topp${top_p}topk${top_k}_temp${temperature}_@${n_samples}.parquet; \
    python -u -m verl.trainer.main_generation \
        trainer.nnodes=1 \
        trainer.n_gpus_per_node=8 \
        data.path=$data_load_path \
        data.prompt_key=prompt \
        data.n_samples=$n_samples \
        data.output_path=$data_save_path \
        model.path=$save_path/huggingface \
        +model.trust_remote_code=True \
        rollout.temperature=$temperature \
        rollout.top_k=$top_k \
        rollout.top_p=$top_p \
        rollout.prompt_length=$max_prompt_length \
        rollout.response_length=$max_response_length \
        rollout.tensor_model_parallel_size=4 \
        rollout.gpu_memory_utilization=0.8; \
    python3 -m compute_acc --input_path $data_save_path --verifier all;
done

--------------------------------------------------------------------------------
/experience_maker.py:
--------------------------------------------------------------------------------
import re
from multiprocessing import Process, Queue

from qwen_math_eval_toolkit.parser import extract_answer as qwen_extract_answer
from qwen_math_eval_toolkit.grader import math_equal as qwen_math_equal


def qwen_math_equal_subprocess(prediction, reference, timeout_seconds=10):
    """Run qwen_math_equal in a child process; return False if it times out."""
    def worker(q, prediction, reference):
        result = qwen_math_equal(prediction=prediction, reference=reference, timeout=False)
        q.put(result)

    q = Queue()
    p = Process(target=worker, args=(q, prediction, reference))
    p.start()

    # Wait for the process to finish, for at most timeout_seconds seconds.
    p.join(timeout=timeout_seconds)

    # If the process is still running, terminate it and return False.
    if p.is_alive():
        p.terminate()
        p.join()  # Make sure the process is fully cleaned up.
        return False

    # If the process finished normally, fetch the result.
    try:
        return q.get_nowait()
    except Exception:
        return False


def preprocess_box_response_for_qwen_prompt(sequence, answer):
    # Drop everything before the first assistant turn.
    model_output = re.sub(r'^.*?<\|im_start\|>assistant', '<|im_start|>assistant', sequence, flags=re.DOTALL, count=1)
    # Truncate the output at the first stop word, if any is present.
    stop_words = ["</s>", "<|im_end|>", "<|endoftext|>"]
    for stop_word in stop_words:
        if stop_word in model_output:
            model_output = model_output.split(stop_word)[0].strip()
    extract_answer = qwen_extract_answer(model_output, data_name="math")  # TODO: check the data_name, hard code here for now

    if qwen_math_equal_subprocess(prediction=extract_answer, reference=answer):
        return 1.0
    else:
        return 0.0

--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------

**/*.pt
**/checkpoints
**/wget-log
**/_build/
**/*.ckpt
**/outputs
**/*.tar.gz
**/playground
**/wandb

# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class
dataset/*
tensorflow/my_graph/*
.idea/
# C extensions
*.so

# Distribution / packaging
.Python
env/
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
*.egg-info/
.installed.cfg
*.egg

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*,cover
.hypothesis/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
target/

# IPython Notebook
.ipynb_checkpoints

# pyenv
.python-version

# celery beat schedule file
celerybeat-schedule

# dotenv
.env

# virtualenv
venv/
ENV/

# Spyder project settings
.spyderproject

# Rope project settings
.ropeproject

# vscode
.vscode

# Mac
.DS_Store

# output logs
tests/e2e/toy_examples/deepspeed/synchronous/output.txt

# vim
*.swp

# ckpt
*.lock

# data
*.parquet


# local logs
logs
log

ckpts/
models/

--------------------------------------------------------------------------------
/verifier.py:
--------------------------------------------------------------------------------
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from math_verify.metric import math_metric
from math_verify.parser import LatexExtractionConfig, ExprExtractionConfig

def math_verify_compute_score(model_output: str, ground_truth: str) -> dict:
    """Score a model output against the ground truth; returns a dict with score, acc, and pred."""
    verify_func = math_metric(
        gold_extraction_target=(LatexExtractionConfig(),),
        pred_extraction_target=(ExprExtractionConfig(), LatexExtractionConfig()),
    )
    ret_score = 0.

    # Wrap the ground truth in \boxed{} format for verification
    ground_truth_boxed = "\\boxed{" + ground_truth + "}"
    try:
        ret_score, _ = verify_func([ground_truth_boxed], [model_output])
    except Exception as e:
        print(e)

    return {
        "score": 1.0 if ret_score > 0.5 else -1.0,
        "acc": 1.0 if ret_score > 0.5 else 0.0,
        "pred": model_output,
    }

def _default_compute_score(data_source, solution_str, ground_truth, extra_info=None):
    if data_source == 'openai/gsm8k':
        from verl.utils.reward_score import gsm8k
        res = gsm8k.compute_score(solution_str, ground_truth)
    elif data_source in ['lighteval/MATH', 'DigitalLearningGmbH/MATH-lighteval', 'math500', "math500_w_answer", "olympiadbench", "minerva_math"]:
        # from verl.utils.reward_score import math
        # res = math.compute_score(solution_str, ground_truth)
        # Use Math-Verify (https://github.com/huggingface/Math-Verify) for better evaluation accuracy
        from verl.utils.reward_score import math_verify
        res = math_verify.compute_score(solution_str, ground_truth)
    elif data_source == 'math_dapo':
        from verl.utils.reward_score import math_dapo
        res = math_dapo.compute_score(solution_str, ground_truth)
    elif data_source in ['math_dapo_boxed', 'amc_dapo_boxed', 'aime_2025_dapo_boxed', 'amc2023_dapo_boxed']:
        from verl.utils.reward_score import math_dapo
        res = math_dapo.compute_score(solution_str, ground_truth, strict_box_verify=True)
    elif data_source in [
            'numina_aops_forum', 'numina_synthetic_math', 'numina_amc_aime', 'numina_synthetic_amc', 'numina_cn_k12',
            'numina_olympiads'
    ]:
        from verl.utils.reward_score import prime_math
        res = prime_math.compute_score(solution_str, ground_truth)
    elif data_source in ['codecontests', 'apps', 'codeforces', 'taco']:
        from verl.utils.reward_score import prime_code
        res = prime_code.compute_score(solution_str, ground_truth, continuous=True)
    elif data_source in ['hiyouga/geometry3k']:
        from verl.utils.reward_score import geo3k
        res = geo3k.compute_score(solution_str, ground_truth)
    else:
        raise NotImplementedError

    # Normalize the result: pass dicts through, convert scalars and tuples to float.
    if isinstance(res, dict):
        return res
    elif isinstance(res, (int, float, bool)):
        return float(res)
    else:
        return float(res[0])

--------------------------------------------------------------------------------
/compute_acc.py:
--------------------------------------------------------------------------------
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import pandas as pd
import json
from experience_maker import preprocess_box_response_for_qwen_prompt
from verl.utils.reward_score import prime_math, math_verify, math_dapo
import argparse


RESPONSE_COL = "responses"
GROUND_TRUTH_COL = "reward_model"
GROUND_TRUTH_KEY = "ground_truth"


def calculate_accuracy(file_path: str, verifier_type: str) -> None:
    """Loads data, calculates accuracy, and adds verifier results to DataFrame."""
    try:
        df = pd.read_parquet(file_path)
        print(f"Loaded {len(df)} rows from {file_path}")
    except Exception as e:
        print(f"Error loading Parquet file '{file_path}': {e}")
        return

    total_correct = 0
    total_checked = 0
    all_results = []

    # Map the verifier name to a callable: (response, answer) -> correctness.
    verifier = {
        "qwen": preprocess_box_response_for_qwen_prompt,
        "prime": lambda response, answer: prime_math.compute_score(response, answer)[
            "acc"
        ],
        "math": lambda response, answer: math_verify.compute_score(
            response, answer
        )["acc"],
        "dapo": lambda response, answer: math_dapo.compute_score(
            response, answer, strict_box_verify=True
        )["acc"],
    }[verifier_type]

    for index, row in df.iterrows():
        # 1. Get Ground Truth
        gt_data = row[GROUND_TRUTH_COL]
        if isinstance(gt_data, str):
            gt_dict = json.loads(gt_data)
        else:
            gt_dict = gt_data
        ground_truth_answer = str(gt_dict[GROUND_TRUTH_KEY])

        # 2. Get Responses (list)
        generated_responses = row[RESPONSE_COL]

        # 3. Compare each response
        row_results = []
        for response in generated_responses:
            try:
                is_correct = verifier(str(response), ground_truth_answer)
            except Exception as e:
                # Count a verifier failure as incorrect so it cannot lower total_correct below zero.
                print(f"Warning: Counting response as incorrect due to error: {e}")
                is_correct = 0.0
            is_correct = float(is_correct)
            row_results.append(is_correct)
            total_correct += is_correct
            total_checked += 1
        all_results.append(row_results)

    # Add result column to DataFrame
    df[f"{verifier_type}_results"] = all_results

    # Save updated DataFrame
    df.to_parquet(file_path)
    print(f"Updated DataFrame saved to: {file_path}")

    # Print Summary
    if total_checked > 0:
        accuracy = total_correct / total_checked
        print(f"\n--- Results --- for {verifier_type}")
        print(f"Total responses checked: {total_checked}")
        print(f"Total correct responses: {int(total_correct)}")
        print(f"Accuracy: {accuracy:.4f}")
    else:
        print("\n--- Results ---")
        print("No responses were checked.")


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Calculate accuracy from generated responses in a Parquet file."
    )
    parser.add_argument("--input_path", help="Path to the input Parquet file.")
    parser.add_argument(
        "--verifier", default="qwen", choices=["qwen", "prime", "math", "dapo", "all"]
    )
    args = parser.parse_args()
    if args.verifier == "all":
        for ven in ["qwen", "prime", "math", "dapo"]:
            calculate_accuracy(args.input_path, ven)
    else:
        calculate_accuracy(args.input_path, args.verifier)

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# Negative-aware Fine-Tuning (NFT): Bridging Supervised Learning and Reinforcement Learning in Math Reasoning

[![Paper](https://img.shields.io/badge/paper-5f16a8?style=for-the-badge&logo=arxiv&logoColor=white)](https://arxiv.org/pdf/2505.18116)
[![Blog](https://img.shields.io/badge/Blog-3858bf?style=for-the-badge&logo=homepage&logoColor=white)](https://research.nvidia.com/labs/dir/Negative-aware-Fine-Tuning/)
[![Dataset](https://img.shields.io/badge/Datasets-4d8cd8?style=for-the-badge&logo=huggingface&logoColor=white)](https://huggingface.co/datasets/ChenDRAG/VeRL_math_validation)
[![Weights](https://img.shields.io/badge/Model%20Weights-63cad3?style=for-the-badge&logo=huggingface&logoColor=white)](https://huggingface.co/nvidia/NFT-32B)

[Figure: seed logo]

NFT is a pure supervised learning method for improving LLMs' math-reasoning abilities with no external teachers.

- As an SL method, NFT outperforms leading RL algorithms like GRPO and DAPO in 7B model experiments and performs similarly to DAPO in 32B settings.
- NFT allows directly optimizing LLMs on negative data, thereby significantly outperforming other SL baselines such as Rejection sampling Fine-Tuning (RFT).
- NFT is equivalent to GRPO when training is strictly on-policy, despite their entirely different theoretical foundations.

NFT shows that self-reflective improvement is not exclusive to RL algorithms. Rather, the current gap between SL and RL methods stems from how effectively each leverages negative data.

## Algorithm Overview

NFT bridges reinforcement learning and supervised learning by leveraging negative feedback through supervision.


[Figure: NFT Method]

The NFT pipeline consists of:
1. **Data Collection:** The LLM generates answers to math questions, which are split into positive and negative sets based on correctness.
2. **Implicit Negative Policy:** A policy that models negative answers is constructed using the same parameters as the positive policy.
3. **Policy Optimization:** Both positive and negative answers are used to optimize the LLM via supervised learning.

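The data collection step can be illustrated with a minimal sketch. It assumes a verifier callback that decides correctness; `split_by_correctness` and `exact_match` are hypothetical names used only for illustration and are not functions from this repository, which relies on math-aware verifiers instead of string equality.

```python
from typing import Callable, List, Tuple


def split_by_correctness(
    answers: List[str],
    ground_truth: str,
    is_correct: Callable[[str, str], bool],
) -> Tuple[List[str], List[str]]:
    """Partition sampled answers into (positive, negative) lists."""
    positive: List[str] = []
    negative: List[str] = []
    for answer in answers:
        if is_correct(answer, ground_truth):
            positive.append(answer)
        else:
            negative.append(answer)
    return positive, negative


def exact_match(answer: str, ground_truth: str) -> bool:
    # Toy verifier (assumption): plain string equality after stripping whitespace.
    return answer.strip() == ground_truth.strip()


positive, negative = split_by_correctness(["42", "41", " 42 "], "42", exact_match)
print(positive)  # ['42', ' 42 ']
print(negative)  # ['41']
```

Both lists are kept: according to the pipeline description, the negative answers also contribute to the supervised objective.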

[Figure: Policy Distribution]

## Experimental Results

Comparison of NFT-7B with other zero-shot math models in the Qwen series.

[Figure: Performance Comparison]

NFT performs competitively with other algorithms. We report avg@32 for AIME24, AIME25, and AMC23, and avg@1 for the other benchmarks.

[Figure: Benchmarking]
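
As a rough illustration of the avg@k metric, the sketch below assumes it means the per-question accuracy averaged over k sampled responses, then averaged over questions. The README does not define the metric, so this reading is an assumption, and `avg_at_k` is a hypothetical helper, not part of this repository.

```python
from typing import List


def avg_at_k(results: List[List[float]]) -> float:
    """Average per-question accuracy, where each inner list holds k scores in {0.0, 1.0}."""
    # Skip questions with no sampled responses to avoid dividing by zero.
    per_question = [sum(scores) / len(scores) for scores in results if scores]
    if not per_question:
        return 0.0
    return sum(per_question) / len(per_question)


# Two questions with k=4 samples each: accuracies 0.75 and 0.25.
scores = [[1.0, 0.0, 1.0, 1.0], [0.0, 0.0, 0.0, 1.0]]
print(avg_at_k(scores))  # 0.5
```

With k=1 this reduces to plain accuracy over the benchmark.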

Validation accuracy curves show NFT's ability to leverage negative data for continuous improvement.

[Figure: Validation Accuracy]

## Evaluation

### Environment setup
We use exactly the same environment configuration as the official DAPO codebase.
```bash
pip install git+ssh://git@github.com/volcengine/verl.git@01ef7184821d0d7844796ec0ced17665c1f50673
```

### Benchmarking

Pretrained [7B](https://huggingface.co/nvidia/NFT-7B) and [32B](https://huggingface.co/nvidia/NFT-32B) models are available on Hugging Face.

The evaluation code is integrated with the VeRL infrastructure. Please refer to `eval_local_7B.sh` and `eval_local_32B.sh` for the evaluation scripts.

## Training

### Environment setup
We use exactly the same environment configuration as the official DAPO codebase.
```bash
pip install git+ssh://git@github.com/volcengine/verl.git@01ef7184821d0d7844796ec0ced17665c1f50673
```

### Datasets
We use the public dataset [DAPO-Math-17k](https://huggingface.co/datasets/BytedTsinghua-SIA/DAPO-Math-17k) for training, and six public math benchmarks for validation. Download the pre-sorted training and validation data with
```bash
bash download_data.sh
```

### Base Model
```bash
bash download_model.sh
```

### Starting Experiments
Please see `train_7B.sh` and `train_32B.sh` for single-node launch scripts. Note that we run 7B experiments using 4×8 H100s, and 32B experiments using 16×8 H100s. Please refer to the instructions of [VeRL](https://github.com/volcengine/verl/tree/gm-tyx/puffin/main) for launching distributed tasks.

Hyperparameters:

- `neg_weight`: The weight of negative data in NFT's objective. Set to 1.0 for the default NFT config. Set to 0.0 for RFT, which masks out the loss on all negative data. Set to -1.0 to run the DAPO algorithm for comparison.
- `normalize`: Controls the prompt weight in NFT's objective. Set to 0 so that all question data is treated equally.
  Set to 1 (default) or 2 to prioritize harder questions. `normalize=1` matches the Dr. GRPO algorithm in on-policy training, while `normalize=2` matches standard GRPO.

## Acknowledgement

We thank the [verl](https://github.com/volcengine/verl) team for providing the awesome open-source RL infrastructure.

## Citation
If you find our project helpful, please consider citing
```bibtex
@article{chen2025bridging,
  title = {Bridging Supervised Learning and Reinforcement Learning in Math Reasoning},
  author = {Huayu Chen and Kaiwen Zheng and Qinsheng Zhang and Ganqu Cui and Yin Cui and Haotian Ye and Tsung-Yi Lin and Ming-Yu Liu and Jun Zhu and Haoxiang Wang},
  journal = {arXiv preprint arXiv:2505.18116},
  year = {2025}
}
```

--------------------------------------------------------------------------------
/train_7B.sh:
--------------------------------------------------------------------------------
#!/usr/bin/env bash
set -euxo pipefail

# WANDB configuration
export WANDB_API_KEY=xxxx
export WANDB_ENTITY=xxxx

project_name='NFT'
exp_name='NFT-Qwen2.5-7B-Math-Test'

adv_estimator=grpo

kl_coef=0.0
kl_loss_coef=0.0

clip_ratio_low=0.2
clip_ratio_high=0.28

enable_overlong_buffer=False
overlong_buffer_len=0
overlong_penalty_factor=0.0

enable_filter_groups=True
filter_groups_metric=acc
max_num_gen_batches=10
train_prompt_bsz=512
gen_prompt_bsz=$((train_prompt_bsz * 3))
train_prompt_mini_bsz=32
n_resp_per_prompt=16

use_token_level_loss=True

# Ray
WORKING_DIR=${WORKING_DIR:-"${PWD}"}
RUNTIME_ENV=${RUNTIME_ENV:-"${WORKING_DIR}/verl/trainer/runtime_env.yaml"}
NNODES=${NNODES:-1}
# Paths
verl_workdir=.
RAY_DATA_HOME=${RAY_DATA_HOME:-"${verl_workdir}"}
MODEL_PATH=${MODEL_PATH:-"${RAY_DATA_HOME}/models/Qwen2.5-Math-7B"}
CKPTS_DIR=${CKPTS_DIR:-"${RAY_DATA_HOME}/ckpts/${project_name}/${exp_name}"}
TRAIN_FILE=${TRAIN_FILE:-"${RAY_DATA_HOME}/data/dapo-math-17k_10boxed.parquet"}

# Algorithm
## Train
max_prompt_length=$((512))
max_response_length=$((4096-512))
## Validation
val_top_k=-1 # 0 for HF rollout, -1 for vLLM rollout

# NFT specific parameters
neg_weight=1.0 # -1.0 for DAPO, 1.0 for NFT, 0.0 for RFT
ratio_type="token"
ppo_epoch=1
bugged_dynamic_scale=0
clamp_negative=1.0
clamp_positive=0.0
normalize=1

# Performance Related Parameter
sp_size=1
use_dynamic_bsz=True
actor_ppo_max_token_len=$((max_prompt_length + max_response_length))
infer_ppo_max_token_len=$((max_prompt_length + max_response_length))
gen_tp=4
actor_offload=False
ref_offload=True

python3 -m main_nft \
    data.train_files="${TRAIN_FILE}" \
    data.val_files="['./data/aime-2024-boxed_w_answer.parquet', './data/math500_boxed.parquet', './data/minerva_math.parquet', './data/olympiadbench.parquet', './data/aime2025_32_dapo_boxed_w_answer.parquet', './data/amc2023_32_dapo_boxed_w_answer.parquet']" \
    data.prompt_key=prompt \
    data.truncation='left' \
    data.max_prompt_length=${max_prompt_length} \
    data.max_response_length=${max_response_length} \
    data.filter_overlong_prompts=True \
    data.gen_batch_size=${gen_prompt_bsz} \
    data.train_batch_size=${train_prompt_bsz} \
    actor_rollout_ref.rollout.n=${n_resp_per_prompt} \
    algorithm.adv_estimator=${adv_estimator} \
    algorithm.kl_ctrl.kl_coef=${kl_coef} \
    actor_rollout_ref.actor.kl_loss_coef=${kl_loss_coef} \
    actor_rollout_ref.actor.clip_ratio_low=${clip_ratio_low} \
    actor_rollout_ref.actor.clip_ratio_high=${clip_ratio_high} \
    algorithm.filter_groups.enable=${enable_filter_groups} \
    algorithm.filter_groups.max_num_gen_batches=${max_num_gen_batches} \
    algorithm.filter_groups.metric=${filter_groups_metric} \
    actor_rollout_ref.model.use_remove_padding=True \
    actor_rollout_ref.neg_weight=${neg_weight} \
    actor_rollout_ref.actor.ratio_type=${ratio_type} \
    actor_rollout_ref.actor.ppo_epochs=${ppo_epoch} \
    actor_rollout_ref.bugged_dynamic_scale=${bugged_dynamic_scale} \
    actor_rollout_ref.clamp_negative=${clamp_negative} \
    actor_rollout_ref.clamp_positive=${clamp_positive} \
    actor_rollout_ref.normalize=${normalize} \
    actor_rollout_ref.actor.use_dynamic_bsz=${use_dynamic_bsz} \
    actor_rollout_ref.ref.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \
    actor_rollout_ref.rollout.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \
    actor_rollout_ref.actor.ppo_max_token_len_per_gpu=${actor_ppo_max_token_len} \
    actor_rollout_ref.ref.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \
    actor_rollout_ref.rollout.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \
    actor_rollout_ref.model.path="${MODEL_PATH}" \
    +actor_rollout_ref.model.override_config.attention_dropout=0. \
    +actor_rollout_ref.model.override_config.embd_pdrop=0. \
    +actor_rollout_ref.model.override_config.resid_pdrop=0. \
    actor_rollout_ref.model.enable_gradient_checkpointing=True \
    actor_rollout_ref.actor.optim.lr=1e-6 \
    actor_rollout_ref.actor.optim.lr_warmup_steps=10 \
    actor_rollout_ref.actor.optim.weight_decay=0.1 \
    actor_rollout_ref.actor.ppo_mini_batch_size=${train_prompt_mini_bsz} \
    actor_rollout_ref.actor.fsdp_config.param_offload=${actor_offload} \
    actor_rollout_ref.actor.fsdp_config.optimizer_offload=${actor_offload} \
    actor_rollout_ref.actor.entropy_coeff=0 \
    actor_rollout_ref.actor.grad_clip=1.0 \
    actor_rollout_ref.actor.use_token_level_loss=${use_token_level_loss} \
    actor_rollout_ref.actor.ulysses_sequence_parallel_size=${sp_size} \
    actor_rollout_ref.rollout.gpu_memory_utilization=0.70 \
    actor_rollout_ref.rollout.tensor_model_parallel_size=${gen_tp} \
    actor_rollout_ref.rollout.enable_chunked_prefill=True \
    actor_rollout_ref.rollout.max_num_batched_tokens=$((max_prompt_length + max_response_length)) \
    actor_rollout_ref.rollout.val_kwargs.top_k="${val_top_k}" \
    actor_rollout_ref.rollout.val_kwargs.top_p=0.7 \
    actor_rollout_ref.rollout.val_kwargs.temperature=0.6 \
    actor_rollout_ref.rollout.val_kwargs.n=1 \
    actor_rollout_ref.rollout.val_kwargs.do_sample=True \
    actor_rollout_ref.ref.fsdp_config.param_offload=${ref_offload} \
    actor_rollout_ref.ref.ulysses_sequence_parallel_size=${sp_size} \
    actor_rollout_ref.actor.fsdp_config.fsdp_size=-1 \
    custom_reward_function.overlong_buffer.enable=${enable_overlong_buffer} \
    custom_reward_function.overlong_buffer.len=${overlong_buffer_len} \
    custom_reward_function.overlong_buffer.penalty_factor=${overlong_penalty_factor} \
    trainer.logger=['console','wandb'] \
    trainer.project_name="${project_name}" \
    trainer.experiment_name="${exp_name}" \
    trainer.n_gpus_per_node=8 \
    trainer.nnodes="${NNODES}" \
    +trainer.val_before_train=False \
    trainer.test_freq=2 \
    trainer.save_freq=2 \
    trainer.total_epochs=100 \
    trainer.default_local_dir="${CKPTS_DIR}" \
    trainer.resume_mode=auto

--------------------------------------------------------------------------------
/train_32B.sh:
--------------------------------------------------------------------------------
#!/usr/bin/env bash
set -euxo pipefail

# WANDB configuration
export WANDB_API_KEY=xxxx
export WANDB_ENTITY=xxxx

project_name='NFT'
exp_name='NFT-Qwen2.5-32B-Test'

adv_estimator=grpo

kl_coef=0.0
kl_loss_coef=0.0

clip_ratio_low=0.2
clip_ratio_high=0.28

enable_filter_groups=True
filter_groups_metric=acc
max_num_gen_batches=10
train_prompt_bsz=512
gen_prompt_bsz=$((train_prompt_bsz * 3))
train_prompt_mini_bsz=32
n_resp_per_prompt=16

use_token_level_loss=True

# Ray
WORKING_DIR=${WORKING_DIR:-"${PWD}"}
RUNTIME_ENV=${RUNTIME_ENV:-"${WORKING_DIR}/verl/trainer/runtime_env.yaml"}
NNODES=${NNODES:-1}
# Paths
verl_workdir=.
RAY_DATA_HOME=${RAY_DATA_HOME:-"${verl_workdir}"}
MODEL_PATH=${MODEL_PATH:-"${RAY_DATA_HOME}/models/Qwen2.5-32B"}
CKPTS_DIR=${CKPTS_DIR:-"${RAY_DATA_HOME}/ckpts/${project_name}/${exp_name}"}
TRAIN_FILE=${TRAIN_FILE:-"${RAY_DATA_HOME}/data/dapo-math-17k_10boxed.parquet"}

# Algorithm
## Train
## Validation
val_top_k=-1 # 0 for HF rollout, -1 for vLLM rollout

# NFT specific parameters
neg_weight=1.0 # -1.0 for DAPO, 1.0 for NFT, 0.0 for RFT (try 0.5 for 32B)
ratio_type="token"
ppo_epoch=1
bugged_dynamic_scale=0
clamp_negative=1.0 # 0.5
clamp_positive=0.0
normalize=1
if [ "$neg_weight" = "-1.0" ]; then
    max_prompt_length=$((1024 * 2))
    max_response_length=$((1024 * 14))
    enable_overlong_buffer=True
    overlong_buffer_len=$((1024 * 4))
    overlong_penalty_factor=1.0
else
    max_prompt_length=$((1024 * 2))
    max_response_length=$((1024 * 14))
    enable_overlong_buffer=False
    overlong_buffer_len=0
    overlong_penalty_factor=0.0
fi

# Performance Related Parameter
sp_size=8
use_dynamic_bsz=True
actor_ppo_max_token_len=$((max_prompt_length + max_response_length))
infer_ppo_max_token_len=$((max_prompt_length + max_response_length))
gen_tp=4
actor_offload=False
ref_offload=True

python3 -m main_nft \
    data.train_files="${TRAIN_FILE}" \
    data.val_files="['./data/aime-2024-boxed_w_answer.parquet', './data/math500_boxed.parquet', './data/minerva_math.parquet', './data/olympiadbench.parquet', './data/aime2025_32_dapo_boxed_w_answer.parquet', './data/amc2023_32_dapo_boxed_w_answer.parquet']" \
    data.prompt_key=prompt \
    data.truncation='left' \
    data.max_prompt_length=${max_prompt_length} \
    data.max_response_length=${max_response_length} \
    data.filter_overlong_prompts=True \
    data.gen_batch_size=${gen_prompt_bsz} \
    data.train_batch_size=${train_prompt_bsz} \
86 | actor_rollout_ref.rollout.n=${n_resp_per_prompt} \ 87 | algorithm.adv_estimator=${adv_estimator} \ 88 | algorithm.kl_ctrl.kl_coef=${kl_coef} \ 89 | actor_rollout_ref.actor.kl_loss_coef=${kl_loss_coef} \ 90 | actor_rollout_ref.actor.clip_ratio_low=${clip_ratio_low} \ 91 | actor_rollout_ref.actor.clip_ratio_high=${clip_ratio_high} \ 92 | algorithm.filter_groups.enable=${enable_filter_groups} \ 93 | algorithm.filter_groups.max_num_gen_batches=${max_num_gen_batches} \ 94 | algorithm.filter_groups.metric=${filter_groups_metric} \ 95 | actor_rollout_ref.model.use_remove_padding=True \ 96 | actor_rollout_ref.neg_weight=${neg_weight} \ 97 | actor_rollout_ref.actor.ratio_type=${ratio_type} \ 98 | actor_rollout_ref.actor.ppo_epochs=${ppo_epoch} \ 99 | actor_rollout_ref.bugged_dynamic_scale=${bugged_dynamic_scale} \ 100 | actor_rollout_ref.clamp_negative=${clamp_negative} \ 101 | actor_rollout_ref.clamp_positive=${clamp_positive} \ 102 | actor_rollout_ref.normalize=${normalize} \ 103 | actor_rollout_ref.actor.use_dynamic_bsz=${use_dynamic_bsz} \ 104 | actor_rollout_ref.ref.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \ 105 | actor_rollout_ref.rollout.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \ 106 | actor_rollout_ref.actor.ppo_max_token_len_per_gpu=${actor_ppo_max_token_len} \ 107 | actor_rollout_ref.ref.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \ 108 | actor_rollout_ref.rollout.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \ 109 | actor_rollout_ref.model.path="${MODEL_PATH}" \ 110 | +actor_rollout_ref.model.override_config.attention_dropout=0. \ 111 | +actor_rollout_ref.model.override_config.embd_pdrop=0. \ 112 | +actor_rollout_ref.model.override_config.resid_pdrop=0. 
\ 113 | actor_rollout_ref.model.enable_gradient_checkpointing=True \ 114 | actor_rollout_ref.actor.optim.lr=1e-6 \ 115 | actor_rollout_ref.actor.optim.lr_warmup_steps=10 \ 116 | actor_rollout_ref.actor.optim.weight_decay=0.1 \ 117 | actor_rollout_ref.actor.ppo_mini_batch_size=${train_prompt_mini_bsz} \ 118 | actor_rollout_ref.actor.fsdp_config.param_offload=${actor_offload} \ 119 | actor_rollout_ref.actor.fsdp_config.optimizer_offload=${actor_offload} \ 120 | actor_rollout_ref.actor.entropy_coeff=0 \ 121 | actor_rollout_ref.actor.grad_clip=1.0 \ 122 | actor_rollout_ref.actor.use_token_level_loss=${use_token_level_loss} \ 123 | actor_rollout_ref.actor.ulysses_sequence_parallel_size=${sp_size} \ 124 | actor_rollout_ref.rollout.gpu_memory_utilization=0.60 \ 125 | actor_rollout_ref.rollout.tensor_model_parallel_size=${gen_tp} \ 126 | actor_rollout_ref.rollout.enable_chunked_prefill=True \ 127 | actor_rollout_ref.rollout.max_num_batched_tokens=$((max_prompt_length + max_response_length)) \ 128 | actor_rollout_ref.rollout.val_kwargs.top_k="${val_top_k}" \ 129 | actor_rollout_ref.rollout.val_kwargs.top_p=0.7 \ 130 | actor_rollout_ref.rollout.val_kwargs.temperature=0.6 \ 131 | actor_rollout_ref.rollout.val_kwargs.n=1 \ 132 | actor_rollout_ref.rollout.val_kwargs.do_sample=True \ 133 | actor_rollout_ref.ref.fsdp_config.param_offload=${ref_offload} \ 134 | actor_rollout_ref.ref.ulysses_sequence_parallel_size=${sp_size} \ 135 | actor_rollout_ref.actor.fsdp_config.fsdp_size=-1 \ 136 | custom_reward_function.overlong_buffer.enable=${enable_overlong_buffer} \ 137 | custom_reward_function.overlong_buffer.len=${overlong_buffer_len} \ 138 | custom_reward_function.overlong_buffer.penalty_factor=${overlong_penalty_factor} \ 139 | trainer.logger=['console','wandb'] \ 140 | trainer.project_name="${project_name}" \ 141 | trainer.experiment_name="${exp_name}" \ 142 | trainer.n_gpus_per_node=8 \ 143 | trainer.nnodes="${NNODES}" \ 144 | +trainer.val_before_train=False \ 145 | 
    trainer.test_freq=2 \
    trainer.save_freq=2 \
    trainer.total_epochs=100 \
    trainer.default_local_dir="${CKPTS_DIR}" \
    trainer.resume_mode=auto

--------------------------------------------------------------------------------
/main_nft.py:
--------------------------------------------------------------------------------
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Note that we don't combine the main with ray_trainer, as ray_trainer is used by other main entry points.
"""

import ray
import hydra


def get_custom_reward_fn(config):
    """Load the reward function named in `custom_reward_function`, or return None if no path is set."""
    import importlib.util
    import os

    reward_fn_config = config.get("custom_reward_function") or {}
    file_path = reward_fn_config.get("path")
    if not file_path:
        return None

    if not os.path.exists(file_path):
        raise FileNotFoundError(f"Reward function file '{file_path}' not found.")

    spec = importlib.util.spec_from_file_location("custom_module", file_path)
    module = importlib.util.module_from_spec(spec)
    try:
        spec.loader.exec_module(module)
    except Exception as e:
        # chain the original exception so the underlying traceback is preserved
        raise RuntimeError(f"Error loading module from '{file_path}': {e}") from e

    function_name = reward_fn_config.get("name")

    # guard against a missing name: hasattr() raises TypeError when given None
    if not function_name or not hasattr(module, function_name):
        raise AttributeError(f"Reward function '{function_name}' not found in '{file_path}'.")

    print(f"using customized reward function '{function_name}' from '{file_path}'")

    return getattr(module, function_name)


@hydra.main(config_path='config', config_name='nft_trainer', version_base=None)
def main(config):
    run_nft(config)


def run_nft(config) -> None:

    if not ray.is_initialized():
        # this is for local ray cluster
        ray.init(runtime_env={
            'env_vars': {
                'TOKENIZERS_PARALLELISM': 'true',
                'NCCL_DEBUG': 'WARN',
                'VLLM_LOGGING_LEVEL': 'WARN'
            }
        })

    ray.get(main_task.remote(config))


@ray.remote(num_cpus=1)  # please make sure main_task is not scheduled on head
def main_task(config):
    import verl.workers.actor
    import verl.workers.actor.dp_actor
    import verl.utils.reward_score
    # override the original definition of DataParallelPPOActor
    import dp_actor
    verl.workers.actor.DataParallelPPOActor = dp_actor.DataParallelPPOActor
    verl.workers.actor.dp_actor.DataParallelPPOActor = dp_actor.DataParallelPPOActor

    # override the default scoring functions with the local verifier
    import verl.utils.reward_score.math_verify
    import verifier
    verl.utils.reward_score._default_compute_score = verifier._default_compute_score
    verl.utils.reward_score.math_verify.compute_score = verifier.math_verify_compute_score

    from ray_nft_trainer import RayNFTTrainer
    from verl.utils.fs import copy_to_local
    # print initial config
    from pprint import pprint
    from omegaconf import OmegaConf
    pprint(OmegaConf.to_container(config, resolve=True))  # resolve=True will eval symbol values
    OmegaConf.resolve(config)

    # download the checkpoint from hdfs
    local_path = copy_to_local(config.actor_rollout_ref.model.path)

    # instantiate tokenizer
    from verl.utils import hf_tokenizer, hf_processor
    tokenizer = hf_tokenizer(local_path)
    processor = hf_processor(local_path, use_fast=True)  # used for multimodal LLMs, could be None

    # define worker classes
    if config.actor_rollout_ref.actor.strategy == 'fsdp':
        assert config.actor_rollout_ref.actor.strategy == config.critic.strategy
        from fsdp_workers import ActorRolloutRefWorker, CriticWorker
        from verl.single_controller.ray import RayWorkerGroup
        ray_worker_group_cls = RayWorkerGroup

    elif config.actor_rollout_ref.actor.strategy == 'megatron':
        assert config.actor_rollout_ref.actor.strategy == config.critic.strategy
        from verl.workers.megatron_workers import ActorRolloutRefWorker, CriticWorker
        from verl.single_controller.ray.megatron import NVMegatronRayWorkerGroup
        ray_worker_group_cls = NVMegatronRayWorkerGroup

    else:
        raise NotImplementedError

    from ray_nft_trainer import ResourcePoolManager, Role

    role_worker_mapping = {
        Role.ActorRollout: ray.remote(ActorRolloutRefWorker),
        Role.Critic: ray.remote(CriticWorker),
        Role.RefPolicy: ray.remote(ActorRolloutRefWorker)
    }

    global_pool_id = 'global_pool'
    resource_pool_spec = {
        global_pool_id: [config.trainer.n_gpus_per_node] * config.trainer.nnodes,
    }
    mapping = {
        Role.ActorRollout: global_pool_id,
        Role.Critic: global_pool_id,
        Role.RefPolicy: global_pool_id,
    }

    # we should adopt a multi-source reward function here
    # - for rule-based rm, we directly call a reward score
    # - for model-based rm, we call a model
    # - for code related prompt, we send to a sandbox if there are test cases
    # - finally, we combine all the rewards together
    # - The reward type depends on the tag of the data
    if config.reward_model.enable:
        if config.reward_model.strategy == 'fsdp':
            from fsdp_workers import RewardModelWorker
        elif config.reward_model.strategy == 'megatron':
            from verl.workers.megatron_workers import RewardModelWorker
        else:
            raise NotImplementedError
        role_worker_mapping[Role.RewardModel] = ray.remote(RewardModelWorker)
        mapping[Role.RewardModel] = global_pool_id

    reward_manager_name = config.reward_model.get("reward_manager", "naive")
    if reward_manager_name == 'naive':
        from verl.workers.reward_manager import NaiveRewardManager
        reward_manager_cls = NaiveRewardManager
    elif reward_manager_name == 'prime':
        from verl.workers.reward_manager import PrimeRewardManager
        reward_manager_cls = PrimeRewardManager
    else:
        raise NotImplementedError

    compute_score = get_custom_reward_fn(config)
    reward_fn = reward_manager_cls(tokenizer=tokenizer,
                                   num_examine=0,
                                   compute_score=compute_score,
                                   reward_fn_key=config.data.reward_fn_key,
                                   max_resp_len=config.data.max_response_length,
                                   overlong_buffer_cfg=config.custom_reward_function.overlong_buffer)

    # Note that we always use function-based RM for validation
    val_reward_fn = reward_manager_cls(tokenizer=tokenizer,
                                       num_examine=1,
                                       compute_score=compute_score,
                                       reward_fn_key=config.data.reward_fn_key,
                                       max_resp_len=config.data.max_response_length,
                                       overlong_buffer_cfg=config.custom_reward_function.overlong_buffer)

    resource_pool_manager = ResourcePoolManager(resource_pool_spec=resource_pool_spec, mapping=mapping)

    trainer = RayNFTTrainer(config=config,
                            tokenizer=tokenizer,
                            processor=processor,
                            role_worker_mapping=role_worker_mapping,
                            resource_pool_manager=resource_pool_manager,
                            ray_worker_group_cls=ray_worker_group_cls,
                            reward_fn=reward_fn,
                            val_reward_fn=val_reward_fn)
    trainer.init_workers()
    trainer.fit()


if __name__ == '__main__':
    main()

--------------------------------------------------------------------------------
/config/nft_trainer.yaml:
--------------------------------------------------------------------------------
data:
  tokenizer: null
  train_files: ~/data/rlhf/gsm8k/train.parquet
  val_files: ~/data/rlhf/gsm8k/test.parquet
  prompt_key: prompt
  reward_fn_key: data_source
  max_prompt_length: 512
  max_response_length: 512
  gen_batch_size: ${data.train_batch_size}
  train_batch_size: 1024
  val_batch_size: null # DEPRECATED: Validation datasets are sent to inference engines as a whole batch, which will schedule the memory themselves
  return_raw_input_ids: False # This should be set to true when the tokenizer between policy and rm differs
  return_raw_chat: False
  shuffle: True
  filter_overlong_prompts: False # for large-scale datasets, filtering overlong prompts could be time-consuming.
You should disable this and set `truncation='left' 16 | truncation: error 17 | image_key: images 18 | 19 | actor_rollout_ref: 20 | hybrid_engine: True 21 | bugged_dynamic_scale: 0 22 | neg_weight: -1.0 23 | clamp_positive: 0.0 24 | clamp_negative: 1.0 25 | normalize: 1 26 | model: 27 | path: ~/models/deepseek-llm-7b-chat 28 | external_lib: null 29 | override_config: { } 30 | enable_gradient_checkpointing: True 31 | use_remove_padding: False 32 | actor: 33 | strategy: fsdp # This is for backward-compatibility 34 | ppo_mini_batch_size: 256 35 | ppo_micro_batch_size: null # will be deprecated, use ppo_micro_batch_size_per_gpu 36 | ppo_micro_batch_size_per_gpu: null 37 | neg_weight: ${actor_rollout_ref.neg_weight} 38 | clamp_positive: ${actor_rollout_ref.clamp_positive} 39 | normalize: ${actor_rollout_ref.normalize} 40 | clamp_negative: ${actor_rollout_ref.clamp_negative} 41 | bugged_dynamic_scale: ${actor_rollout_ref.bugged_dynamic_scale} 42 | max_response_length: ${data.max_response_length} 43 | use_dynamic_bsz: False 44 | ratio_type: token 45 | ppo_max_token_len_per_gpu: 16384 # n * ${data.max_prompt_length} + ${data.max_response_length} 46 | grad_clip: 1.0 47 | # pg_losses2 = -advantages * torch.clamp(ratio, 1 - cliprange_low, 1 + cliprange_high) 48 | clip_ratio: 0.2 # default value if clip_ratio_low and clip_ratio_high are not specified 49 | clip_ratio_low: 0.2 50 | clip_ratio_high: 0.2 51 | use_token_level_loss: True 52 | entropy_coeff: 0.001 53 | use_kl_loss: False # True for GRPO 54 | use_torch_compile: True # False to disable torch compile 55 | kl_loss_coef: 0.001 # for grpo 56 | kl_loss_type: low_var_kl # for grpo 57 | ppo_epochs: 1 58 | shuffle: False 59 | ulysses_sequence_parallel_size: 1 # sp size 60 | optim: 61 | lr: 1e-6 62 | lr_warmup_steps: -1 # Prioritized. Negative values mean delegating to lr_warmup_steps_ratio. 63 | lr_warmup_steps_ratio: 0. 
# the total steps will be injected during runtime 64 | min_lr_ratio: null # only useful for warmup with cosine 65 | warmup_style: constant # select from constant/cosine 66 | total_training_steps: -1 # must be override by program 67 | weight_decay: 0.01 68 | fsdp_config: 69 | wrap_policy: 70 | # transformer_layer_cls_to_wrap: None 71 | min_num_params: 0 72 | param_offload: False 73 | optimizer_offload: False 74 | fsdp_size: -1 75 | ref: 76 | fsdp_config: 77 | param_offload: False 78 | wrap_policy: 79 | # transformer_layer_cls_to_wrap: None 80 | min_num_params: 0 81 | log_prob_micro_batch_size: null # will be deprecated, use log_prob_micro_batch_size_per_gpu 82 | log_prob_micro_batch_size_per_gpu: null 83 | log_prob_use_dynamic_bsz: ${actor_rollout_ref.actor.use_dynamic_bsz} 84 | log_prob_max_token_len_per_gpu: ${actor_rollout_ref.actor.ppo_max_token_len_per_gpu} 85 | ulysses_sequence_parallel_size: ${actor_rollout_ref.actor.ulysses_sequence_parallel_size} # sp size 86 | rollout: 87 | name: vllm 88 | temperature: 1.0 89 | top_k: -1 # 0 for hf rollout, -1 for vllm rollout 90 | top_p: 1 91 | use_fire_sampling: False # https://arxiv.org/abs/2410.21236 92 | prompt_length: ${data.max_prompt_length} # not use for opensource 93 | response_length: ${data.max_response_length} 94 | # for vllm rollout 95 | dtype: bfloat16 # should align with FSDP 96 | gpu_memory_utilization: 0.5 97 | ignore_eos: False 98 | enforce_eager: True 99 | free_cache_engine: True 100 | load_format: dummy_dtensor 101 | tensor_model_parallel_size: 2 102 | max_num_batched_tokens: 8192 103 | max_model_len: null 104 | max_num_seqs: 1024 105 | log_prob_micro_batch_size: null # will be deprecated, use log_prob_micro_batch_size_per_gpu 106 | log_prob_micro_batch_size_per_gpu: null 107 | log_prob_use_dynamic_bsz: ${actor_rollout_ref.actor.use_dynamic_bsz} 108 | log_prob_max_token_len_per_gpu: ${actor_rollout_ref.actor.ppo_max_token_len_per_gpu} 109 | disable_log_stats: True 110 | enable_chunked_prefill: True # 
may get higher throughput when set to True. When activated, Please increase max_num_batched_tokens or decrease max_model_len. 111 | # for hf rollout 112 | do_sample: True 113 | # number of responses (i.e. num sample times) 114 | n: 1 # > 1 for grpo 115 | val_kwargs: 116 | # sampling parameters for validation 117 | top_k: -1 # 0 for hf rollout, -1 for vllm rollout 118 | top_p: 1.0 119 | temperature: 0.6 120 | n: 1 121 | do_sample: False # default eager for validation 122 | 123 | critic: 124 | strategy: fsdp 125 | optim: 126 | lr: 1e-5 127 | lr_warmup_steps_ratio: 0. # the total steps will be injected during runtime 128 | min_lr_ratio: null # only useful for warmup with cosine 129 | warmup_style: constant # select from constant/cosine 130 | total_training_steps: -1 # must be override by program 131 | weight_decay: 0.01 132 | model: 133 | path: ~/models/deepseek-llm-7b-chat 134 | tokenizer_path: ${actor_rollout_ref.model.path} 135 | override_config: { } 136 | external_lib: ${actor_rollout_ref.model.external_lib} 137 | enable_gradient_checkpointing: True 138 | use_remove_padding: False 139 | fsdp_config: 140 | param_offload: False 141 | optimizer_offload: False 142 | wrap_policy: 143 | # transformer_layer_cls_to_wrap: None 144 | min_num_params: 0 145 | fsdp_size: -1 146 | ppo_mini_batch_size: ${actor_rollout_ref.actor.ppo_mini_batch_size} 147 | ppo_micro_batch_size: null # will be deprecated, use ppo_micro_batch_size_per_gpu 148 | ppo_micro_batch_size_per_gpu: null 149 | forward_micro_batch_size: ${critic.ppo_micro_batch_size} 150 | forward_micro_batch_size_per_gpu: ${critic.ppo_micro_batch_size_per_gpu} 151 | use_dynamic_bsz: ${actor_rollout_ref.actor.use_dynamic_bsz} 152 | ppo_max_token_len_per_gpu: 32768 # (${actor_rollout_ref.actor.ppo_max_token_len_per_gpu}) * 2 153 | forward_max_token_len_per_gpu: ${critic.ppo_max_token_len_per_gpu} 154 | ulysses_sequence_parallel_size: 1 # sp size 155 | ppo_epochs: ${actor_rollout_ref.actor.ppo_epochs} 156 | shuffle: 
${actor_rollout_ref.actor.shuffle} 157 | grad_clip: 1.0 158 | cliprange_value: 0.5 159 | 160 | reward_model: 161 | enable: False 162 | strategy: fsdp 163 | model: 164 | input_tokenizer: ${actor_rollout_ref.model.path} # set this to null if the chat template is identical 165 | path: ~/models/FsfairX-LLaMA3-RM-v0.1 166 | external_lib: ${actor_rollout_ref.model.external_lib} 167 | use_remove_padding: False 168 | fsdp_config: 169 | wrap_policy: 170 | min_num_params: 0 171 | param_offload: False 172 | fsdp_size: -1 173 | micro_batch_size: null # will be deprecated, use micro_batch_size_per_gpu 174 | micro_batch_size_per_gpu: null # set a number 175 | max_length: null 176 | ulysses_sequence_parallel_size: 1 # sp size 177 | use_dynamic_bsz: ${critic.use_dynamic_bsz} 178 | forward_max_token_len_per_gpu: ${critic.forward_max_token_len_per_gpu} 179 | reward_manager: naive 180 | 181 | custom_reward_function: 182 | path: null 183 | name: compute_score 184 | overlong_buffer: 185 | enable: False # We try to avoid forgetting to set enable 186 | len: 0 187 | penalty_factor: 0.0 188 | log: False 189 | 190 | algorithm: 191 | gamma: 1.0 192 | lam: 1.0 193 | adv_estimator: gae 194 | kl_penalty: kl # how to estimate kl divergence 195 | kl_ctrl: 196 | type: fixed 197 | kl_coef: 0.001 198 | filter_groups: 199 | enable: False # We try to avoid forgetting to set enable 200 | metric: null # acc / score / seq_reward / seq_final_reward / ... 201 | max_num_gen_batches: 0 # Non-positive values mean no upper limit 202 | 203 | trainer: 204 | val_only: False 205 | balance_batch: True 206 | total_epochs: 30 207 | total_training_steps: null 208 | project_name: verl_examples 209 | experiment_name: gsm8k 210 | logger: [ 'console', 'wandb' ] 211 | val_generations_to_log_to_wandb: 0 212 | nnodes: 1 213 | n_gpus_per_node: 8 214 | save_freq: -1 215 | # auto: find the last ckpt to resume. 
  # If no checkpoint is found, start from scratch
  resume_mode: auto # auto, or resume_path if resume_from_path is set
  resume_from_path: False
  test_freq: -1
  critic_warmup: 0
  default_hdfs_dir: null
  remove_previous_ckpt_in_save: True
  del_local_ckpt_after_load: False
  default_local_dir: checkpoints/${trainer.project_name}/${trainer.experiment_name}

--------------------------------------------------------------------------------
/qwen_math_eval_toolkit/utils.py:
--------------------------------------------------------------------------------
import json
import os
import random
import sys
import numpy as np
from pathlib import Path
from typing import Iterable, Union, Any

from qwen_math_eval_toolkit.examples import get_examples


def set_seed(seed: int = 42) -> None:
    np.random.seed(seed)
    random.seed(seed)
    os.environ["PYTHONHASHSEED"] = str(seed)
    print(f"Random seed set as {seed}")


def load_jsonl(file: Union[str, Path]) -> Iterable[Any]:
    with open(file, "r", encoding="utf-8") as f:
        for line in f:
            try:
                yield json.loads(line)
            except json.JSONDecodeError:
                # abort on the first malformed record, with a non-zero exit code
                print("Error in loading:", line)
                sys.exit(1)


def save_jsonl(samples, save_path):
    # ensure the parent folder exists (dirname is empty for a bare file name)
    folder = os.path.dirname(save_path)
    if folder:
        os.makedirs(folder, exist_ok=True)

    with open(save_path, "w", encoding="utf-8") as f:
        for sample in samples:
            f.write(json.dumps(sample, ensure_ascii=False) + "\n")
    print("Saved to", save_path)


def lower_keys(example):
    # return a copy of the example with every key lower-cased
    return {key.lower(): value for key, value in example.items()}


EXAMPLES = get_examples()


def load_prompt(data_name, prompt_type, num_shots):
    if not num_shots:
        return []

    if data_name in
["gsm_hard", "svamp", "tabmwp", "asdiv", "mawps"]: 60 | data_name = "gsm8k" 61 | if data_name in ["math_oai", "hungarian_exam", "math-oai", "aime24", "amc23"]: 62 | data_name = "math" 63 | if data_name in ["sat_math"]: 64 | data_name = "mmlu_stem" 65 | if data_name in [ 66 | "gaokao2024_I", 67 | "gaokao2024_II", 68 | "gaokao_math_qa", 69 | "gaokao2024_mix", 70 | "cn_middle_school", 71 | ]: 72 | data_name = "gaokao" 73 | 74 | if prompt_type in ["tool-integrated"]: 75 | prompt_type = "tora" 76 | 77 | return EXAMPLES[data_name][:num_shots] 78 | 79 | 80 | PROMPT_TEMPLATES = { 81 | "direct": ("Question: {input}\nAnswer: ", "{output}", "\n\n"), 82 | "cot": ("Question: {input}\nAnswer: ", "{output}", "\n\n\n"), 83 | "pal": ("Question: {input}\n\n", "{output}", "\n---\n"), 84 | "tool-integrated": ("Question: {input}\n\nSolution:\n", "{output}", "\n---\n"), 85 | "self-instruct": ("<|user|>\n{input}\n<|assistant|>\n", "{output}", "\n"), 86 | "tora": ("<|user|>\n{input}\n<|assistant|>\n", "{output}", "\n"), 87 | "wizard_zs": ( 88 | "### Instruction:\n{input}\n\n### Response: Let's think step by step.", 89 | "{output}", 90 | "\n\n\n", 91 | ), 92 | "platypus_fs": ( 93 | "### Instruction:\n{input}\n\n### Response:\n", 94 | "{output}", 95 | "\n\n\n", 96 | ), 97 | "deepseek-math": ( 98 | "User: {input}\nPlease reason step by step, " 99 | "and put your final answer within \\boxed{{}}.\n\nAssistant:", 100 | "{output}", 101 | "\n\n\n", 102 | ), 103 | "kpmath": ( 104 | "User: Please reason step by step and put your final answer at the end " 105 | 'with "The answer is: ".\n\n{input}\n\nAssistant:', 106 | "{output}", 107 | ), 108 | "jiuzhang": ( 109 | "## Question\n{input}\n\n## Solution\n", 110 | "{output}", 111 | "\n\n\n", 112 | ), 113 | "jiuzhang_tora": ( 114 | "## Question\n{input}\n\n## Code Solution\n", 115 | "{output}", 116 | "\n\n\n", 117 | ), 118 | "jiuzhang_nl": ( 119 | "## Question\n{input}\n\n## Natural Language Solution\n", 120 | "{output}", 121 | "\n\n\n", 122 | ), 123 | 
"mmiqc": ( 124 | 'Please solve the following problem and put your answer at the end with "The answer is: ".\n\n{input}\n\n', 125 | "{output}", 126 | "\n\n\n", 127 | ), 128 | "abel": ( 129 | "Question:\n{input}\nAnswer:\nLet's think step by step.\n", 130 | "{output}", 131 | "\n\n", 132 | ), 133 | "shepherd": ("{input}\n", "{output}", "\n\n\n"), 134 | "qwen-boxed": ( 135 | "<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n" 136 | "<|im_start|>user\n{input}\nPlease reason step by step, and put your final answer within \\boxed{{}}.<|im_end|>\n" 137 | "<|im_start|>assistant\n", 138 | "{output}", 139 | "\n\n", 140 | ), 141 | "qwen25-math-cot": ( 142 | "<|im_start|>system\nPlease reason step by step, and put your final answer within \\boxed{{}}.<|im_end|>\n" 143 | "<|im_start|>user\n{input}<|im_end|>\n" 144 | "<|im_start|>assistant\n", 145 | "{output}", 146 | "\n\n", 147 | ), 148 | "mathstral": ( 149 | "{input}\nPlease reason step by step, and put your final answer within \\boxed{{}}.", 150 | "{output}", 151 | "\n\n", 152 | ), 153 | "internlm-math-fs": ("Question:{input}\nAnswer:", "{output}", "\n"), 154 | "internlm-math-chat": ( 155 | "<|im_start|>user\n{input}<|im_end|>\n" "<|im_start|>assistant\n", 156 | "{output}", 157 | "\n\n", 158 | ), 159 | "mistral": ( 160 | "[INST] {input}[/INST]", 161 | "{output}", 162 | "\n\n", 163 | ), 164 | "numina": ("### Problem: {input}\n### Solution:", " {output}", "\n\n"), 165 | "o1_cot": ( 166 | '[Round 0] USER:\n{input}\nPlease reason step by step, and put your final answer within \\boxed{{}}. 
ASSISTANT:\n', 167 | "{output}", 168 | "\n\n" 169 | ) 170 | } 171 | 172 | 173 | def construct_prompt(example, data_name, args): 174 | if args.adapt_few_shot and data_name in [ 175 | "gaokao2024_I", 176 | "gaokao2024_II", 177 | "gaokao_math_qa", 178 | "gaokao2024_mix", 179 | "cn_middle_school", 180 | ]: 181 | demos = load_prompt(data_name, args.prompt_type, 5) 182 | else: 183 | demos = load_prompt(data_name, args.prompt_type, args.num_shots) 184 | prompt_type = args.prompt_type 185 | if prompt_type == "platypus_fs": 186 | prompt_type = "cot" 187 | if prompt_type == "tool-integrated": 188 | prompt_type = "tora" 189 | 190 | prompt_temp = PROMPT_TEMPLATES[args.prompt_type] 191 | 192 | splitter = prompt_temp[2] 193 | input_template, output_template, splitter = ( 194 | prompt_temp[0], 195 | prompt_temp[1], 196 | prompt_temp[2], 197 | ) 198 | if args.prompt_type == "qwen25-math-cot": 199 | # Hotfix to support putting all demos into a single turn 200 | demo_prompt = splitter.join([q + "\n" + a for q, a in demos]) 201 | else: 202 | demo_prompt = splitter.join( 203 | [ 204 | input_template.format(input=q) + output_template.format(output=a) 205 | for q, a in demos 206 | ] 207 | ) 208 | context = input_template.format(input=example["question"]) 209 | if len(demo_prompt) == 0 or ( 210 | args.adapt_few_shot and example["gt_ans"] not in ["A", "B", "C", "D", "E"] 211 | ): 212 | full_prompt = context 213 | else: 214 | if args.prompt_type == "qwen25-math-cot": 215 | # Hotfix to supportting put all demos into a single turn 216 | full_prompt = demo_prompt + splitter + example["question"] 217 | full_prompt = input_template.format(input=full_prompt) 218 | else: 219 | full_prompt = demo_prompt + splitter + context 220 | 221 | if args.prompt_type == "platypus_fs": 222 | full_prompt_temp = ( 223 | "Below is an instruction that describes a task. 
" 224 | "Write a response that appropriately completes the request.\n\n" 225 | "### Instruction:\n{instruction}\n\n### Response:\n" 226 | ) 227 | full_prompt = full_prompt_temp.format(instruction=full_prompt) 228 | 229 | if prompt_type == "tora": 230 | full_prompt = ( 231 | """Integrate step-by-step reasoning and Python code to solve math problems using the following guidelines: 232 | 233 | - Analyze the question and write functions to solve the problem; the function should not take any arguments. 234 | - Present the final result in LaTeX using a `\boxed{}` without any units. 235 | - Utilize the `pi` symbol and `Rational`` from Sympy for $\pi$ and fractions, and simplify all fractions and square roots without converting them to decimal values. 236 | 237 | Here are some examples you may refer to: 238 | 239 | --- 240 | 241 | """ 242 | + full_prompt 243 | ) 244 | 245 | return full_prompt.strip(" ") # important! 246 | 247 | 248 | key_map = { 249 | "gt": "Ground Truth", 250 | "pred": "Prediction", 251 | "gt_cot": "Reference CoT", 252 | "score": "Score", 253 | } 254 | 255 | 256 | def show_sample(sample, print_all_preds=False): 257 | print("==" * 20) 258 | for key in ["idx", "type", "level", "dataset"]: 259 | if key in sample: 260 | # capitalize 261 | print("{}: {}".format(key[0].upper() + key[1:], sample[key])) 262 | print("Question:", repr(sample["question"])) 263 | if "code" in sample: 264 | if print_all_preds: 265 | for code in sample["code"]: 266 | print("-" * 20) 267 | print("code:", code) 268 | print("Execution:", sample["report"]) 269 | else: 270 | print("Solution:\n", sample["code"][0]) 271 | print("Execution:", sample["report"][0]) 272 | if "pred" in sample: 273 | print("Prediction:", repr(sample["pred"][0])) 274 | for key in ["gt", "score", "unit", "gt_cot"]: 275 | if key in sample: 276 | _key = key_map.get(key, key) 277 | print("{}: {}".format(_key, repr(sample[key]))) 278 | print() 279 | 
-------------------------------------------------------------------------------- /README_VeRL.md: -------------------------------------------------------------------------------- 1 |

verl: Volcano Engine Reinforcement Learning for LLM

2 | 3 | verl is a flexible, efficient and production-ready RL training library for large language models (LLMs). 4 | 5 | verl is the open-source implementation of the **[HybridFlow: A Flexible and Efficient RLHF Framework](https://arxiv.org/abs/2409.19256v2)** paper. 6 | 7 | verl is flexible and easy to use with: 8 | 9 | - **Easy extension of diverse RL algorithms**: The hybrid programming model combines the strengths of single-controller and multi-controller paradigms to enable flexible representation and efficient execution of complex post-training dataflows. This allows users to build RL dataflows in a few lines of code. 10 | 11 | - **Seamless integration of existing LLM infra with modular APIs**: Decouples computation and data dependencies, enabling seamless integration with existing LLM frameworks, such as PyTorch FSDP, Megatron-LM and vLLM. Moreover, users can easily extend to other LLM training and inference frameworks. 12 | 13 | - **Flexible device mapping**: Supports various placements of models onto different sets of GPUs for efficient resource utilization and scalability across different cluster sizes. 14 | 15 | - Ready integration with popular Hugging Face models 16 | 17 | 18 | verl is fast with: 19 | 20 | - **State-of-the-art throughput**: By seamlessly integrating existing SOTA LLM training and inference frameworks, verl achieves high generation and training throughput. 21 | 22 | - **Efficient actor model resharding with 3D-HybridEngine**: Eliminates memory redundancy and significantly reduces communication overhead during transitions between training and generation phases. 23 | 24 |


29 | 30 | ## News 31 | - [2025/3] We will present verl(HybridFlow) at [EuroSys 2025](https://2025.eurosys.org/). See you in Rotterdam! 32 | - [2025/3] We will introduce the programming model of verl at the [vLLM Beijing Meetup](https://mp.weixin.qq.com/s/n77GibL2corAtQHtVEAzfg) on 3/16. See you in Beijing! 33 | - [2025/2] verl v0.2.0.post1 is released! See [release note](https://github.com/volcengine/verl/releases/) for details. 34 | - [2025/2] We presented verl in the [Bytedance/NVIDIA/Anyscale Ray Meetup](https://lu.ma/ji7atxux). See you in San Jose! 35 | - [2025/1] [Doubao-1.5-pro](https://team.doubao.com/zh/special/doubao_1_5_pro) is released with SOTA-level performance on LLM & VLM. The RL scaling preview model is trained using verl, reaching OpenAI O1-level performance on math benchmarks (70.0 pass@1 on AIME). 36 | - [2024/12] The team presented Post-training LLMs: From Algorithms to Infrastructure at NeurIPS 2024. [Slides](https://github.com/eric-haibin-lin/verl-data/tree/neurips) and [video](https://neurips.cc/Expo/Conferences/2024/workshop/100677) available. 37 | - [2024/12] verl is presented at Ray Forward 2024. Slides available [here](https://github.com/eric-haibin-lin/verl-community/blob/main/slides/Ray_Forward_2024_%E5%B7%AB%E9%94%A1%E6%96%8C.pdf). 38 | - [2024/10] verl is presented at Ray Summit. [Youtube video](https://www.youtube.com/watch?v=MrhMcXkXvJU&list=PLzTswPQNepXntmT8jr9WaNfqQ60QwW7-U&index=37) available. 39 | - [2024/08] HybridFlow (verl) is accepted to EuroSys 2025. 40 | 41 | ## Key Features 42 | 43 | - **FSDP** and **Megatron-LM** for training. 44 | - **vLLM** and **HF Transformers** for rollout generation, **SGLang** support coming soon. 45 | - Compatible with Hugging Face Transformers and Modelscope Hub. 46 | - Supervised fine-tuning. 
47 | - Reinforcement learning with [PPO](examples/ppo_trainer/), [GRPO](examples/grpo_trainer/), [ReMax](examples/remax_trainer/), [Reinforce++](https://verl.readthedocs.io/en/latest/examples/config.html#algorithm), [RLOO](examples/rloo_trainer/), [PRIME](recipe/prime/), etc. 48 | - Support model-based reward and function-based reward (verifiable reward) 49 | - Support vision-language models (VLMs) and [multi-modal RL](examples/grpo_trainer/run_qwen2_5_vl-7b.sh) 50 | - Flash attention 2, [sequence packing](examples/ppo_trainer/run_qwen2-7b_seq_balance.sh), [sequence parallelism](examples/ppo_trainer/run_deepseek7b_llm_sp2.sh) support via DeepSpeed Ulysses, [LoRA](examples/sft/gsm8k/run_qwen_05_peft.sh), [Liger-kernel](examples/sft/gsm8k/run_qwen_05_sp2_liger.sh). 51 | - Scales up to 70B models and hundreds of GPUs. 52 | - Experiment tracking with wandb, swanlab, mlflow and tensorboard. 53 | 54 | ## Upcoming Features 55 | - Reward model training 56 | - DPO training 57 | - DeepSeek integration with Megatron v0.11 58 | - SGLang integration 59 | 60 | ## Getting Started 61 | 62 | **Quickstart:** 63 | - [Installation](https://verl.readthedocs.io/en/latest/start/install.html) 64 | - [Quickstart](https://verl.readthedocs.io/en/latest/start/quickstart.html) 65 | - [Programming Guide](https://verl.readthedocs.io/en/latest/hybrid_flow.html) 66 | 67 | **Running a PPO example step-by-step:** 68 | - Data and Reward Preparation 69 | - [Prepare Data for Post-Training](https://verl.readthedocs.io/en/latest/preparation/prepare_data.html) 70 | - [Implement Reward Function for Dataset](https://verl.readthedocs.io/en/latest/preparation/reward_function.html) 71 | - Understanding the PPO Example 72 | - [PPO Example Architecture](https://verl.readthedocs.io/en/latest/examples/ppo_code_architecture.html) 73 | - [Config Explanation](https://verl.readthedocs.io/en/latest/examples/config.html) 74 | - [Run GSM8K Example](https://verl.readthedocs.io/en/latest/examples/gsm8k_example.html) 75 | 
76 | **Reproducible algorithm baselines:** 77 | - [PPO, GRPO, ReMax](https://verl.readthedocs.io/en/latest/experiment/ppo.html) 78 | 79 | **For code explanation and advanced usage (extension):** 80 | - PPO Trainer and Workers 81 | - [PPO Ray Trainer](https://verl.readthedocs.io/en/latest/workers/ray_trainer.html) 82 | - [PyTorch FSDP Backend](https://verl.readthedocs.io/en/latest/workers/fsdp_workers.html) 83 | - [Megatron-LM Backend](https://verl.readthedocs.io/en/latest/index.html) 84 | - Advanced Usage and Extension 85 | - [Ray API design tutorial](https://verl.readthedocs.io/en/latest/advance/placement.html) 86 | - [Extend to Other RL(HF) algorithms](https://verl.readthedocs.io/en/latest/advance/dpo_extension.html) 87 | - [Add Models with the FSDP Backend](https://verl.readthedocs.io/en/latest/advance/fsdp_extension.html) 88 | - [Add Models with the Megatron-LM Backend](https://verl.readthedocs.io/en/latest/advance/megatron_extension.html) 89 | - [Deployment using Separate GPU Resources](https://github.com/volcengine/verl/tree/main/examples/split_placement) 90 | 91 | **Blogs from the community** 92 | - [Best practices for distributed GRPO reinforcement learning training with verl](https://www.volcengine.com/docs/6459/1463942) 93 | - [A brief analysis of the HybridFlow veRL paper](https://github.com/zhaochenyang20/Awesome-ML-SYS-Tutorial/blob/main/rlhf/verl/readme.md) 94 | - [Up to 20x higher throughput! 
The Doubao LLM team releases a new RLHF framework, now open source!](https://team.doubao.com/en/blog/%E6%9C%80%E9%AB%98%E6%8F%90%E5%8D%8720%E5%80%8D%E5%90%9E%E5%90%90%E9%87%8F-%E8%B1%86%E5%8C%85%E5%A4%A7%E6%A8%A1%E5%9E%8B%E5%9B%A2%E9%98%9F%E5%8F%91%E5%B8%83%E5%85%A8%E6%96%B0-rlhf-%E6%A1%86%E6%9E%B6-%E7%8E%B0%E5%B7%B2%E5%BC%80%E6%BA%90) 95 | 96 | Check out this [Jupyter Notebook](https://github.com/volcengine/verl/tree/main/examples/ppo_trainer/verl_getting_started.ipynb) to get started with PPO training with a single 24GB L4 GPU (**FREE** GPU quota provided by [Lightning Studio](https://lightning.ai/hlin-verl/studios/verl-getting-started))! 
97 | 98 | ## Performance Tuning Guide 99 | Performance is essential for on-policy RL algorithms. We have written a detailed performance tuning guide to help users tune performance. See [here](https://verl.readthedocs.io/en/latest/perf/perf_tuning.html) for more details. 100 | 101 | ## vLLM v0.7 integration preview 102 | We have released a testing version of veRL that supports vLLM>=0.7.0. Please refer to [this document](https://github.com/volcengine/verl/blob/main/docs/README_vllm0.7.md) for the installation guide and more information. 103 | 104 | ## Citation and acknowledgement 105 | 106 | If you find the project helpful, please cite: 107 | - [HybridFlow: A Flexible and Efficient RLHF Framework](https://arxiv.org/abs/2409.19256v2) 108 | - [A Framework for Training Large Language Models for Code Generation via Proximal Policy Optimization](https://i.cs.hku.hk/~cwu/papers/gmsheng-NL2Code24.pdf) 109 | 110 | ```tex 111 | @article{sheng2024hybridflow, 112 | title = {HybridFlow: A Flexible and Efficient RLHF Framework}, 113 | author = {Guangming Sheng and Chi Zhang and Zilingfeng Ye and Xibin Wu and Wang Zhang and Ru Zhang and Yanghua Peng and Haibin Lin and Chuan Wu}, 114 | year = {2024}, 115 | journal = {arXiv preprint arXiv: 2409.19256} 116 | } 117 | ``` 118 | 119 | verl is inspired by the design of Nemo-Aligner, Deepspeed-chat and OpenRLHF. The project is adopted and supported by Anyscale, Bytedance, LMSys.org, Shanghai AI Lab, Tsinghua University, UC Berkeley, UCLA, UIUC, University of Hong Kong, and many more. 
120 | 121 | ## Awesome work using verl 122 | - [TinyZero](https://github.com/Jiayi-Pan/TinyZero): a reproduction of **DeepSeek R1 Zero** recipe for reasoning tasks 123 | - [PRIME](https://github.com/PRIME-RL/PRIME): Process reinforcement through implicit rewards 124 | - [RAGEN](https://github.com/ZihanWang314/ragen): a general-purpose reasoning **agent** training framework 125 | - [Logic-RL](https://github.com/Unakar/Logic-RL): a reproduction of DeepSeek R1 Zero on 2K Tiny Logic Puzzle Dataset. 126 | - [SkyThought](https://github.com/NovaSky-AI/SkyThought): RL training for Sky-T1-7B by NovaSky AI team. 127 | - [deepscaler](https://github.com/agentica-project/deepscaler): iterative context scaling with GRPO 128 | - [critic-rl](https://github.com/HKUNLP/critic-rl): LLM critics for code generation 129 | - [Easy-R1](https://github.com/hiyouga/EasyR1): **Multi-modal** RL training framework 130 | - [self-rewarding-reasoning-LLM](https://arxiv.org/pdf/2502.19613): self-rewarding and correction with **generative reward models** 131 | - [Search-R1](https://github.com/PeterGriffinJin/Search-R1): RL with reasoning and **searching (tool-call)** interleaved LLMs 132 | - [Code-R1](https://github.com/ganler/code-r1): Reproducing R1 for **Code** with Reliable Rewards 133 | - [DQO](https://arxiv.org/abs/2410.09302): Enhancing multi-Step reasoning abilities of language models through direct Q-function optimization 134 | - [FIRE](https://arxiv.org/abs/2410.21236): Flaming-hot initiation with regular execution sampling for large language models 135 | - [ReSearch](https://github.com/Agent-RL/ReSearch): Learning to **Re**ason with **Search** for LLMs via Reinforcement Learning 136 | - [DeepRetrieval](https://github.com/pat-jj/DeepRetrieval): Let LLMs learn to **search** and **retrieve** desirable docs with RL 137 | - [cognitive-behaviors](https://github.com/kanishkg/cognitive-behaviors): Cognitive Behaviors that Enable Self-Improving Reasoners, or, Four Habits of Highly Effective STaRs 
138 | 139 | ## Contribution Guide 140 | Contributions from the community are welcome! Please check out our [roadmap](https://github.com/volcengine/verl/issues/22) and [release plan](https://github.com/volcengine/verl/issues/354). 141 | 142 | ### Code formatting 143 | We use yapf (Google style) to enforce strict code formatting when reviewing PRs. To reformat your code locally, make sure you have installed the **latest** `yapf`: 144 | ```bash 145 | pip3 install yapf --upgrade 146 | ``` 147 | Then, make sure you are at the top level of the verl repo and run 148 | ```bash 149 | bash scripts/format.sh 150 | ``` 151 | We are HIRING! Send us an [email](mailto:haibin.lin@bytedance.com) if you are interested in internship/FTE opportunities in MLSys/LLM reasoning/multimodal alignment. 152 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 
22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. 
For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. 
If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. 
You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. 
(Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | 203 | -------------------------------------------------------------------------------- /qwen_math_eval_toolkit/grader.py: -------------------------------------------------------------------------------- 1 | """ 2 | This logic is largely copied from the Hendrycks' MATH release (math_equivalence), and borrowed from: 3 | - https://github.com/microsoft/ProphetNet/tree/master/CRITIC 4 | - https://github.com/openai/prm800k 5 | - https://github.com/microsoft/ToRA/blob/main/src/eval/grader.py 6 | - https://github.com/deepseek-ai/DeepSeek-Math/blob/main/evaluation/eval/eval_utils.py 7 | """ 8 | 9 | import re 10 | import regex 11 | import multiprocessing 12 | from math import isclose 13 | from typing import Union 14 | from collections import defaultdict 15 | 16 | from sympy import simplify, N 17 | from sympy.parsing.sympy_parser import parse_expr 18 | from sympy.parsing.latex import parse_latex 19 | from latex2sympy2_extended import latex2sympy 20 | 21 | # from .parser import choice_answer_clean, strip_string 
22 | # from parser import choice_answer_clean 23 | 24 | 25 | def choice_answer_clean(pred: str): 26 | pred = pred.strip("\n").rstrip(".").rstrip("/").strip(" ").lstrip(":") 27 | # Clean the answer based on the dataset 28 | tmp = re.findall(r"\b(A|B|C|D|E)\b", pred.upper()) 29 | if tmp: 30 | pred = tmp 31 | else: 32 | pred = [pred.strip().strip(".")] 33 | pred = pred[-1] 34 | # Remove the period at the end, again! 35 | pred = pred.rstrip(".").rstrip("/") 36 | return pred 37 | 38 | 39 | def parse_digits(num): 40 | num = regex.sub(",", "", str(num)) 41 | try: 42 | return float(num) 43 | except: 44 | if num.endswith("%"): 45 | num = num[:-1] 46 | if num.endswith("\\"): 47 | num = num[:-1] 48 | try: 49 | return float(num) / 100 50 | except: 51 | pass 52 | return None 53 | 54 | 55 | def is_digit(num): 56 | # paired with parse_digits 57 | return parse_digits(num) is not None 58 | 59 | 60 | def str_to_pmatrix(input_str): 61 | input_str = input_str.strip() 62 | matrix_str = re.findall(r"\{.*,.*\}", input_str) 63 | pmatrix_list = [] 64 | 65 | for m in matrix_str: 66 | m = m.strip("{}") 67 | pmatrix = r"\begin{pmatrix}" + m.replace(",", "\\") + r"\end{pmatrix}" 68 | pmatrix_list.append(pmatrix) 69 | 70 | return ", ".join(pmatrix_list) 71 | 72 | 73 | def math_equal( 74 | prediction: Union[bool, float, str], 75 | reference: Union[float, str], 76 | include_percentage: bool = True, 77 | is_close: bool = True, 78 | timeout: bool = False, 79 | ) -> bool: 80 | """ 81 | Exact match of math if and only if: 82 | 1. numerical equal: both can convert to float and are equal 83 | 2. 
symbolic equal: both can convert to sympy expression and are equal 84 | """ 85 | # print("Judge:", prediction, reference) 86 | if prediction is None or reference is None: 87 | return False 88 | # Convert to str first so float/bool inputs (allowed by the type hints) do not raise AttributeError 89 | if str(prediction).strip().lower() == str(reference).strip().lower(): return True 90 | if ( 91 | reference in ["A", "B", "C", "D", "E"] 92 | and choice_answer_clean(prediction) == reference 93 | ): 94 | return True 95 | 96 | try: # 1. numerical equal 97 | if is_digit(prediction) and is_digit(reference): 98 | prediction = parse_digits(prediction) 99 | reference = parse_digits(reference) 100 | # number questions 101 | if include_percentage: 102 | gt_result = [reference / 100, reference, reference * 100] 103 | else: 104 | gt_result = [reference] 105 | for item in gt_result: 106 | try: 107 | if is_close: 108 | if numeric_equal(prediction, item): 109 | return True 110 | else: 111 | if item == prediction: 112 | return True 113 | except Exception: 114 | continue 115 | return False 116 | except: 117 | pass 118 | 119 | if not prediction and prediction not in [0, False]: 120 | return False 121 | 122 | # 2. symbolic equal 123 | reference = str(reference).strip() 124 | prediction = str(prediction).strip() 125 | 126 | ## pmatrix (amps) 127 | if "pmatrix" in prediction and "pmatrix" not in reference: 128 | reference = str_to_pmatrix(reference) 129 | 130 | ## deal with [], (), {} 131 | pred_str, ref_str = prediction, reference 132 | if ( 133 | prediction.startswith("[") 134 | and prediction.endswith("]") 135 | and not reference.startswith("(") 136 | ) or ( 137 | prediction.startswith("(") 138 | and prediction.endswith(")") 139 | and not reference.startswith("[") 140 | ): 141 | pred_str = pred_str.strip("[]()") 142 | ref_str = ref_str.strip("[]()") 143 | for s in ["{", "}", "(", ")"]: 144 | ref_str = ref_str.replace(s, "") 145 | pred_str = pred_str.replace(s, "") 146 | if pred_str.lower() == ref_str.lower(): 147 | return True 148 | 149 | ## [a, b] vs. 
[c, d], return a==c and b==d 150 | if ( 151 | regex.match(r"(\(|\[).+(\)|\])", prediction) is not None 152 | and regex.match(r"(\(|\[).+(\)|\])", reference) is not None 153 | ): 154 | pred_parts = prediction[1:-1].split(",") 155 | ref_parts = reference[1:-1].split(",") 156 | if len(pred_parts) == len(ref_parts): 157 | if all( 158 | [ 159 | math_equal( 160 | pred_parts[i], ref_parts[i], include_percentage, is_close 161 | ) 162 | for i in range(len(pred_parts)) 163 | ] 164 | ): 165 | return True 166 | if ( 167 | ( 168 | prediction.startswith("\\begin{pmatrix}") 169 | or prediction.startswith("\\begin{bmatrix}") 170 | ) 171 | and ( 172 | prediction.endswith("\\end{pmatrix}") 173 | or prediction.endswith("\\end{bmatrix}") 174 | ) 175 | and ( 176 | reference.startswith("\\begin{pmatrix}") 177 | or reference.startswith("\\begin{bmatrix}") 178 | ) 179 | and ( 180 | reference.endswith("\\end{pmatrix}") or reference.endswith("\\end{bmatrix}") 181 | ) 182 | ): 183 | pred_lines = [ 184 | line.strip() 185 | for line in prediction[ 186 | len("\\begin{pmatrix}") : -len("\\end{pmatrix}") 187 | ].split("\\\\") 188 | if line.strip() 189 | ] 190 | ref_lines = [ 191 | line.strip() 192 | for line in reference[ 193 | len("\\begin{pmatrix}") : -len("\\end{pmatrix}") 194 | ].split("\\\\") 195 | if line.strip() 196 | ] 197 | matched = True 198 | if len(pred_lines) == len(ref_lines): 199 | for pred_line, ref_line in zip(pred_lines, ref_lines): 200 | pred_parts = pred_line.split("&") 201 | ref_parts = ref_line.split("&") 202 | if len(pred_parts) == len(ref_parts): 203 | if not all( 204 | [ 205 | math_equal( 206 | pred_parts[i], 207 | ref_parts[i], 208 | include_percentage, 209 | is_close, 210 | ) 211 | for i in range(len(pred_parts)) 212 | ] 213 | ): 214 | matched = False 215 | break 216 | else: 217 | matched = False 218 | if not matched: 219 | break 220 | else: 221 | matched = False 222 | if matched: 223 | return True 224 | 225 | if prediction.count("=") == 1 and reference.count("=") == 1: 
226 | pred = prediction.split("=") 227 | pred = f"{pred[0].strip()} - ({pred[1].strip()})" 228 | ref = reference.split("=") 229 | ref = f"{ref[0].strip()} - ({ref[1].strip()})" 230 | if symbolic_equal(pred, ref) or symbolic_equal(f"-({pred})", ref): 231 | return True 232 | elif ( 233 | prediction.count("=") == 1 234 | and len(prediction.split("=")[0].strip()) <= 2 235 | and "=" not in reference 236 | ): 237 | if math_equal( 238 | prediction.split("=")[1], reference, include_percentage, is_close 239 | ): 240 | return True 241 | elif ( 242 | reference.count("=") == 1 243 | and len(reference.split("=")[0].strip()) <= 2 244 | and "=" not in prediction 245 | ): 246 | if math_equal( 247 | prediction, reference.split("=")[1], include_percentage, is_close 248 | ): 249 | return True 250 | 251 | # symbolic equal with sympy 252 | if timeout: 253 | if call_with_timeout(symbolic_equal_process, prediction, reference): 254 | return True 255 | else: 256 | if symbolic_equal(prediction, reference): 257 | return True 258 | 259 | return False 260 | 261 | 262 | def math_equal_process(param): 263 | return math_equal(param[-2], param[-1]) 264 | 265 | 266 | def numeric_equal(prediction: float, reference: float): 267 | # Note that relative tolerance has significant impact 268 | # on the result of the synthesized GSM-Hard dataset 269 | # if reference.is_integer(): 270 | # return isclose(reference, round(prediction), abs_tol=1e-4) 271 | # else: 272 | # prediction = round(prediction, len(str(reference).split(".")[-1])) 273 | return isclose(reference, prediction, rel_tol=1e-4) 274 | 275 | 276 | def symbolic_equal(a, b): 277 | def _parse(s): 278 | for f in [parse_latex, parse_expr, latex2sympy]: 279 | try: 280 | return f(s.replace("\\\\", "\\")) 281 | except: 282 | try: 283 | return f(s) 284 | except: 285 | pass 286 | return s 287 | 288 | a = _parse(a) 289 | b = _parse(b) 290 | 291 | # direct equal 292 | try: 293 | if str(a) == str(b) or a == b: 294 | return True 295 | except: 296 | pass 297 | 
298 | # simplify equal 299 | try: 300 | if a.equals(b) or simplify(a - b) == 0: 301 | return True 302 | except: 303 | pass 304 | 305 | # equation equal 306 | try: 307 | if (abs(a.lhs - a.rhs)).equals(abs(b.lhs - b.rhs)): 308 | return True 309 | except: 310 | pass 311 | 312 | try: 313 | if numeric_equal(float(N(a)), float(N(b))): 314 | return True 315 | except: 316 | pass 317 | 318 | # matrix 319 | try: 320 | # if a and b are matrix 321 | if a.shape == b.shape: 322 | _a = a.applyfunc(lambda x: round(x, 3)) 323 | _b = b.applyfunc(lambda x: round(x, 3)) 324 | if _a.equals(_b): 325 | return True 326 | except: 327 | pass 328 | 329 | return False 330 | 331 | 332 | def symbolic_equal_process(a, b, output_queue): 333 | result = symbolic_equal(a, b) 334 | output_queue.put(result) 335 | 336 | 337 | def call_with_timeout(func, *args, timeout=1, **kwargs): 338 | output_queue = multiprocessing.Queue() 339 | process_args = args + (output_queue,) 340 | process = multiprocessing.Process(target=func, args=process_args, kwargs=kwargs) 341 | process.start() 342 | process.join(timeout) 343 | 344 | if process.is_alive(): 345 | process.terminate() 346 | process.join() 347 | return False 348 | 349 | return output_queue.get() 350 | 351 | def _test_math_equal(): 352 | # print(math_equal("0.0833333333333333", "\\frac{1}{12}")) 353 | # print(math_equal("(1,4.5)", "(1,\\frac{9}{2})")) 354 | # print(math_equal("\\frac{x}{7}+\\frac{2}{7}", "\\frac{x+2}{7}", timeout=True)) 355 | # print(math_equal("\\sec^2(y)", "\\tan^2(y)+1", timeout=True)) 356 | # print(math_equal("\\begin{pmatrix}-\\frac{7}{4}&-2\\\\4&\\frac{1}{4}\\end{pmatrix}", "(\\begin{pmatrix}-\\frac{7}{4}&-2\\\\4&\\frac{1}{4}\\\\\\end{pmatrix})", timeout=True)) 357 | 358 | # pred = '\\begin{pmatrix}\\frac{1}{3x^{2/3}}&0&0\\\\0&1&0\\\\-\\sin(x)&0&0\\end{pmatrix}' 359 | # gt = '(\\begin{pmatrix}\\frac{1}{3\\sqrt[3]{x}^2}&0&0\\\\0&1&0\\\\-\\sin(x)&0&0\\\\\\end{pmatrix})' 360 | 361 | # pred= 
'-\\frac{8x^2}{9(x^2-2)^{5/3}}+\\frac{2}{3(x^2-2)^{2/3}}' 362 | # gt= '-\\frac{2(x^2+6)}{9(x^2-2)\\sqrt[3]{x^2-2}^2}' 363 | 364 | # pred = '-34x-45y+20z-100=0' 365 | # gt = '34x+45y-20z+100=0' 366 | 367 | # pred = '\\frac{100}{3}' 368 | # gt = '33.3' 369 | 370 | # pred = '\\begin{pmatrix}0.290243531202435\\\\0.196008371385084\\\\-0.186381278538813\\end{pmatrix}' 371 | # gt = '(\\begin{pmatrix}0.29\\\\0.196\\\\-0.186\\\\\\end{pmatrix})' 372 | 373 | # pred = '\\frac{\\sqrt{\\sqrt{11}+\\sqrt{194}}}{2\\sqrt{33}+15}' 374 | # gt = '\\frac{\\sqrt{\\sqrt{11}+\\sqrt{194}}}{15+2\\sqrt{33}}' 375 | 376 | # pred = '(+5)(b+2)' 377 | # gt = '(a+5)(b+2)' 378 | 379 | # pred = '\\frac{1+\\sqrt{5}}{2}' 380 | # gt = '2' 381 | 382 | # pred = '\\frac{34}{16}+\\frac{\\sqrt{1358}}{16}', gt = '4' 383 | # pred = '1', gt = '1\\\\sqrt{19}' 384 | 385 | # pred = "(0.6,2.6667]" 386 | # gt = "(\\frac{3}{5},\\frac{8}{3}]" 387 | 388 | gt = "x+2n+1" 389 | pred = "x+1" 390 | 391 | print(math_equal(pred, gt, timeout=True)) 392 | 393 | 394 | if __name__ == "__main__": 395 | _test_math_equal() 396 | -------------------------------------------------------------------------------- /dp_actor.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # SPDX-License-Identifier: Apache-2.0 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | 16 | # Copyright 2024 Bytedance Ltd. and/or its affiliates 17 | # 18 | # Licensed under the Apache License, Version 2.0 (the "License"); 19 | # you may not use this file except in compliance with the License. 20 | # You may obtain a copy of the License at 21 | # 22 | # http://www.apache.org/licenses/LICENSE-2.0 23 | # 24 | # Unless required by applicable law or agreed to in writing, software 25 | # distributed under the License is distributed on an "AS IS" BASIS, 26 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 27 | # See the License for the specific language governing permissions and 28 | # limitations under the License. 29 | """ 30 | Single Process Actor 31 | """ 32 | 33 | import itertools 34 | from typing import Iterable, Tuple 35 | 36 | import torch 37 | from torch import nn 38 | from torch.distributed.fsdp import FullyShardedDataParallel as FSDP 39 | 40 | from verl import DataProto 41 | from verl.trainer.ppo import core_algos 42 | from verl.workers.actor import BasePPOActor 43 | from verl.utils.py_functional import append_to_dict 44 | from verl.utils.torch_functional import logprobs_from_logits, masked_mean 45 | from verl.utils.ulysses import ulysses_pad_and_slice_inputs, gather_outpus_and_unpad 46 | from verl.utils.seqlen_balancing import rearrange_micro_batches, get_reverse_idx 47 | import verl.utils.torch_functional as verl_F 48 | 49 | from flash_attn.bert_padding import pad_input, unpad_input, rearrange, index_first_axis 50 | 51 | __all__ = ['DataParallelPPOActor'] 52 | 53 | def clamp_preserve_grad(x, min_val=None, max_val=None): 54 | clipped = torch.clamp(x, min=min_val, max=max_val) 55 | return x + (clipped - x).detach() 56 | 57 | class DataParallelPPOActor(BasePPOActor): 58 | 59 | def __init__( 60 | self, 61 | config, 62 | actor_module: nn.Module, 63 | actor_optimizer: torch.optim.Optimizer = None, 64 | ): 65 | """When optimizer is None, it is Reference Policy""" 66 | super().__init__(config) 67 | self.actor_module = 
actor_module 68 | self.actor_optimizer = actor_optimizer 69 | self.use_remove_padding = self.config.get('use_remove_padding', False) 70 | print(f'Actor use_remove_padding={self.use_remove_padding}') 71 | self.ulysses_sequence_parallel_size = self.config.ulysses_sequence_parallel_size 72 | self.use_ulysses_sp = self.ulysses_sequence_parallel_size > 1 73 | 74 | self.compute_entropy_from_logits = ( 75 | torch.compile(verl_F.entropy_from_logits, dynamic=True) 76 | if self.config.get('use_torch_compile', True) # use torch compile by default 77 | else verl_F.entropy_from_logits) 78 | 79 | def _forward_micro_batch(self, micro_batch, temperature) -> Tuple[torch.Tensor, torch.Tensor]: 80 | """ 81 | Returns: 82 | entropy: # (bs, response_len) 83 | log_probs: # (bs, response_len) 84 | """ 85 | response_length = micro_batch['responses'].size(-1) 86 | multi_modal_inputs = {} 87 | if 'multi_modal_inputs' in micro_batch: 88 | for key in micro_batch['multi_modal_inputs'][0].keys(): 89 | multi_modal_inputs[key] = torch.cat([inputs[key] for inputs in micro_batch['multi_modal_inputs']], 90 | dim=0) 91 | 92 | with torch.autocast(device_type='cuda', dtype=torch.bfloat16): 93 | input_ids = micro_batch['input_ids'] 94 | batch_size, seqlen = input_ids.shape 95 | attention_mask = micro_batch['attention_mask'] 96 | position_ids = micro_batch['position_ids'] 97 | if position_ids.dim() == 3: # qwen2vl mrope 98 | position_ids = position_ids.transpose(0, 1) # (bsz, 3, seqlen) -> (3, bsz, seqlen) 99 | 100 | if self.use_remove_padding: 101 | input_ids_rmpad, indices, *_ = unpad_input(input_ids.unsqueeze(-1), 102 | attention_mask) # input_ids_rmpad (total_nnz, ...) 103 | input_ids_rmpad = input_ids_rmpad.transpose(0, 1) # (1, total_nnz) 104 | 105 | # unpad the position_ids to align the rotary 106 | if position_ids.dim() == 3: 107 | position_ids_rmpad = index_first_axis(rearrange(position_ids, "c b s ... 
-> (b s) c ..."), 108 | indices).transpose(0, 1).unsqueeze( 109 | 1) # (3, bsz, seqlen) -> (3, 1, bsz * seqlen) 110 | else: 111 | position_ids_rmpad = index_first_axis(rearrange(position_ids.unsqueeze(-1), "b s ... -> (b s) ..."), 112 | indices).transpose(0, 1) 113 | 114 | # for compute the log_prob 115 | input_ids_rmpad_rolled = torch.roll(input_ids_rmpad, shifts=-1, dims=1) # (1, total_nnz) 116 | 117 | # pad and slice the inputs if sp > 1 118 | if self.use_ulysses_sp: 119 | input_ids_rmpad, position_ids_rmpad, pad_size = ulysses_pad_and_slice_inputs(input_ids_rmpad, \ 120 | position_ids_rmpad, \ 121 | sp_size=self.ulysses_sequence_parallel_size) 122 | input_ids_rmpad_rolled, _, _ = ulysses_pad_and_slice_inputs(input_ids_rmpad_rolled, None, 123 | self.ulysses_sequence_parallel_size) 124 | 125 | input_ids_rmpad_rolled = input_ids_rmpad_rolled.squeeze(0) # ((total_nnz / sp) + pad) 126 | 127 | # only pass input_ids and position_ids to enable flash_attn_varlen 128 | output = self.actor_module(input_ids=input_ids_rmpad, 129 | attention_mask=None, 130 | position_ids=position_ids_rmpad, 131 | **multi_modal_inputs, 132 | use_cache=False) # prevent model thinks we are generating 133 | logits_rmpad = output.logits.squeeze(0) # (total_nnz, vocab_size) 134 | 135 | logits_rmpad.div_(temperature) 136 | 137 | # compute entropy 138 | entropy_rmpad = self.compute_entropy_from_logits(logits_rmpad) # ((total_nnz / sp) + pad) 139 | 140 | # if use_sp: ((total_nnz / sp) + pad) ; if not use_sp: (batch, seqlen) 141 | log_probs = logprobs_from_logits(logits=logits_rmpad, labels=input_ids_rmpad_rolled) 142 | 143 | # gather log_prob if sp > 1 144 | if self.use_ulysses_sp: 145 | # gather and unpad for the ulysses sp 146 | log_probs = gather_outpus_and_unpad(log_probs, gather_dim=0, unpad_dim=0, padding_size=pad_size) 147 | entropy_rmpad = gather_outpus_and_unpad(entropy_rmpad, 148 | gather_dim=0, 149 | unpad_dim=0, 150 | padding_size=pad_size) 151 | # pad back to (bsz, seqlen) 152 | 
full_entropy = pad_input(hidden_states=entropy_rmpad.unsqueeze(-1), 153 | indices=indices, 154 | batch=batch_size, 155 | seqlen=seqlen) 156 | full_log_probs = pad_input(hidden_states=log_probs.unsqueeze(-1), 157 | indices=indices, 158 | batch=batch_size, 159 | seqlen=seqlen) 160 | 161 | # only return response part: 162 | entropy = full_entropy.squeeze(-1)[:, -response_length - 1:-1] # (bsz, response_length) 163 | log_probs = full_log_probs.squeeze(-1)[:, -response_length - 1:-1] # (bsz, response_length) 164 | 165 | else: # not using rmpad and no ulysses sp 166 | output = self.actor_module(input_ids=input_ids, 167 | attention_mask=attention_mask, 168 | position_ids=position_ids, 169 | **multi_modal_inputs, 170 | use_cache=False) # prevent model thinks we are generating 171 | logits = output.logits 172 | logits.div_(temperature) 173 | logits = logits[:, -response_length - 1:-1, :] # (bsz, response_length, vocab_size) 174 | log_probs = logprobs_from_logits(logits, micro_batch['responses']) 175 | entropy = verl_F.entropy_from_logits(logits) # (bsz, response_length) 176 | 177 | return entropy, log_probs 178 | 179 | def _optimizer_step(self): 180 | assert self.config.grad_clip is not None 181 | 182 | if isinstance(self.actor_module, FSDP): 183 | grad_norm = self.actor_module.clip_grad_norm_(max_norm=self.config.grad_clip) 184 | else: 185 | grad_norm = torch.nn.utils.clip_grad_norm_(self.actor_module.parameters(), max_norm=self.config.grad_clip) 186 | self.actor_optimizer.step() 187 | return grad_norm 188 | 189 | def compute_log_prob(self, data: DataProto) -> torch.Tensor: 190 | """Compute the log probability of the responses given input_ids, attention_mask and position_ids 191 | 192 | Args: 193 | data (DataProto): a DataProto containing keys 194 | 195 | ``input_ids``: tensor of shape [batch_size, sequence_length]. torch.int64. Note that input_ids is the 196 | concatenation of prompt and response. Note that ``sequence_length = prompt_length + response_length``. 
197 | 198 | ``attention_mask``: tensor of shape [batch_size, sequence_length]. torch.int64. 199 | 200 | ``position_ids``: tensor of shape [batch_size, sequence_length]. torch.int64. 201 | 202 | ``responses``: tensor of shape [batch_size, response_length]. torch.int64. 203 | 204 | Returns: 205 | torch.Tensor: the log_prob tensor 206 | """ 207 | # set to eval 208 | self.actor_module.eval() 209 | 210 | micro_batch_size = data.meta_info['micro_batch_size'] 211 | temperature = data.meta_info['temperature'] # temperature must be in the data.meta_info to avoid silent errors 212 | use_dynamic_bsz = data.meta_info['use_dynamic_bsz'] 213 | 214 | select_keys = ['responses', 'input_ids', 'attention_mask', 'position_ids'] 215 | batch = data.select(batch_keys=select_keys).batch 216 | has_multi_modal_inputs = 'multi_modal_inputs' in data.non_tensor_batch.keys() 217 | 218 | if has_multi_modal_inputs: 219 | num_micro_batches = data.batch.batch_size[0] // micro_batch_size 220 | non_tensor_select_keys = ['multi_modal_inputs'] 221 | micro_batches = data.select(select_keys, non_tensor_select_keys).chunk(num_micro_batches) 222 | elif use_dynamic_bsz: 223 | # split using dynamic bsz 224 | max_token_len = data.meta_info['max_token_len'] * self.ulysses_sequence_parallel_size 225 | micro_batches, indices = rearrange_micro_batches(batch=batch, max_token_len=max_token_len) 226 | else: 227 | micro_batches = batch.split(micro_batch_size) 228 | 229 | log_probs_lst = [] 230 | for micro_batch in micro_batches: 231 | if isinstance(micro_batch, DataProto): 232 | micro_batch = {**micro_batch.batch, **micro_batch.non_tensor_batch} 233 | 234 | with torch.no_grad(): 235 | _, log_probs = self._forward_micro_batch(micro_batch, temperature=temperature) 236 | log_probs_lst.append(log_probs) 237 | log_probs = torch.concat(log_probs_lst, dim=0) 238 | 239 | if use_dynamic_bsz: 240 | indices = list(itertools.chain.from_iterable(indices)) 241 | assert len(indices) == log_probs.size(0), f"{len(indices)} vs. 
{log_probs.size()}" 242 | revert_indices = torch.tensor(get_reverse_idx(indices), dtype=torch.long) 243 | log_probs = log_probs[revert_indices] 244 | 245 | return log_probs 246 | 247 | def update_policy(self, data: DataProto): 248 | # make sure we are in training mode 249 | self.actor_module.train() 250 | 251 | temperature = data.meta_info['temperature'] # temperature must be in the data.meta_info to avoid silent errors 252 | 253 | select_keys = ['responses', 'input_ids', 'attention_mask', 'position_ids', 'old_log_probs', 'advantages', "correct_rate", "correct_mask"] 254 | if self.config.use_kl_loss: 255 | select_keys.append('ref_log_prob') 256 | batch = data.select(batch_keys=select_keys).batch 257 | has_multi_modal_inputs = 'multi_modal_inputs' in data.non_tensor_batch.keys() 258 | 259 | # Split to make minibatch iterator for updating the actor 260 | # See PPO paper for details. https://arxiv.org/abs/1707.06347 261 | if has_multi_modal_inputs: 262 | num_mini_batches = data.batch.batch_size[0] // self.config.ppo_mini_batch_size 263 | non_tensor_select_keys = ['multi_modal_inputs'] 264 | dataloader = data.select(select_keys, non_tensor_select_keys).chunk(num_mini_batches) 265 | else: 266 | dataloader = batch.split(self.config.ppo_mini_batch_size) 267 | 268 | metrics = {} 269 | for epoch in range(self.config.ppo_epochs): 270 | for batch_idx, data in enumerate(dataloader): 271 | # split batch into micro_batches 272 | mini_batch = data 273 | if has_multi_modal_inputs: 274 | self.gradient_accumulation = self.config.ppo_mini_batch_size // self.config.ppo_micro_batch_size_per_gpu 275 | num_micro_batches = mini_batch.batch.batch_size[0] // self.config.ppo_micro_batch_size_per_gpu 276 | micro_batches = data.select(select_keys, non_tensor_select_keys).chunk(num_micro_batches) 277 | elif self.config.use_dynamic_bsz: 278 | max_token_len = self.config.ppo_max_token_len_per_gpu * self.ulysses_sequence_parallel_size 279 | micro_batches, _ = 
rearrange_micro_batches(batch=mini_batch, max_token_len=max_token_len) 280 | else: 281 | self.gradient_accumulation = self.config.ppo_mini_batch_size // self.config.ppo_micro_batch_size_per_gpu 282 | # split batch into micro_batches 283 | micro_batches = mini_batch.split(self.config.ppo_micro_batch_size_per_gpu) 284 | 285 | self.actor_optimizer.zero_grad() 286 | 287 | for data in micro_batches: 288 | # Support all hardware 289 | if isinstance(data, DataProto): 290 | data = {**data.batch.to(torch.cuda.current_device()), **data.non_tensor_batch} 291 | else: 292 | data = data.to(torch.cuda.current_device()) # actor device is cpu when using offload 293 | responses = data['responses'] 294 | response_length = responses.size(1) 295 | attention_mask = data['attention_mask'] 296 | response_mask = attention_mask[:, -response_length:] 297 | old_log_prob = data['old_log_probs'] 298 | advantages = data['advantages'] 299 | correct_rate = data['correct_rate'] 300 | correct_mask = data['correct_mask'] 301 | assert correct_rate.dim() == 2, f"correct_rate.shape = {correct_rate.shape}" 302 | 303 | clip_ratio = self.config.clip_ratio 304 | clip_ratio_low = self.config.clip_ratio_low if self.config.clip_ratio_low is not None else clip_ratio 305 | clip_ratio_high = self.config.clip_ratio_high if self.config.clip_ratio_high is not None else clip_ratio 306 | entropy_coeff = self.config.entropy_coeff 307 | use_token_level_loss = self.config.use_token_level_loss 308 | 309 | # all return: (bsz, response_length) 310 | entropy, log_prob = self._forward_micro_batch(micro_batch=data, temperature=temperature) 311 | if self.config.neg_weight < 0.0: 312 | pg_loss, pg_clipfrac, ppo_kl = core_algos.compute_policy_loss( 313 | old_log_prob=old_log_prob, 314 | log_prob=log_prob, 315 | advantages=advantages, 316 | eos_mask=response_mask, 317 | cliprange=clip_ratio, 318 | cliprange_low=clip_ratio_low, 319 | cliprange_high=clip_ratio_high, 320 | use_token_level_loss=use_token_level_loss) 321 | else: 
322 | negative_approx_kl = log_prob - old_log_prob 323 | ppo_kl = verl_F.masked_mean(-negative_approx_kl, response_mask) 324 | pg_clipfrac = ppo_kl 325 | if self.config.ratio_type == "token": 326 | ratio = torch.exp((negative_approx_kl).clamp(max=10.0)) # (bsz, response_length) 327 | elif self.config.ratio_type == "sequence": 328 | ratio = torch.exp(((negative_approx_kl) * response_mask).sum(dim=-1).clamp(max=10.0))[:,None] # (bsz, 1) 329 | else: 330 | raise NotImplementedError 331 | assert ratio.dim() == 2, f"ratio.shape = {ratio.shape}" 332 | assert correct_mask.max().detach().cpu().item() <= 1.0, f" {correct_mask}" 333 | assert correct_mask.min().detach().cpu().item() >= 0.0, f" {correct_mask}" 334 | assert correct_rate.min().detach().cpu().item() < 1.0, f" {correct_rate}" 335 | assert correct_rate.max().detach().cpu().item() > 0.0, f" {correct_rate}" 336 | correct_rate = torch.clamp(correct_rate, min=0.02, max=0.98) 337 | negative_rate = 1.0 - correct_rate 338 | if self.config.ratio_type == "token": 339 | proxy = 1.0 - correct_rate * ratio 340 | norm = 1.0 - correct_rate * 1.0 341 | proxy = proxy / norm 342 | negative_loss = - torch.log(clamp_preserve_grad(proxy, min_val=abs(self.config.clamp_negative))) 343 | positive_loss = - torch.log(clamp_preserve_grad(ratio, min_val=self.config.clamp_positive)) 344 | pg_losses = (negative_loss * (1.0-correct_mask[:, None]) * self.config.neg_weight + correct_mask[:, None] * positive_loss) 345 | if self.config.normalize == 1: 346 | pg_losses = pg_losses * negative_rate 347 | elif self.config.normalize == 2: 348 | pg_losses = pg_losses * (negative_rate / correct_rate)**0.5 349 | else: 350 | pg_losses = pg_losses 351 | print(f"normalize config {self.config.normalize}") 352 | elif self.config.ratio_type == "sequence": 353 | ratio = ratio.squeeze(-1) 354 | assert ratio.dim() == 1, f"ratio.shape = {ratio.shape}" 355 | correct_rate = correct_rate[:, 0] 356 | assert correct_rate.dim() == 1, 
f"correct_rate.shape = {correct_rate.shape}" 357 | 358 | proxy = 1.0 - correct_rate * ratio 359 | norm = 1.0 - correct_rate * 1.0 360 | proxy = proxy / norm 361 | negative_loss = - torch.log(clamp_preserve_grad(proxy, min_val=abs(self.config.clamp_negative))) 362 | positive_loss = - torch.log(clamp_preserve_grad(ratio, min_val=self.config.clamp_positive)) 363 | pg_losses = (negative_loss * (1.0 - correct_mask) * self.config.neg_weight + correct_mask * positive_loss) 364 | if self.config.normalize == 1: 365 | pg_losses = pg_losses * negative_rate # new update on 0414 366 | elif self.config.normalize == 2: 367 | pg_losses = pg_losses * (negative_rate / correct_rate)**0.5 368 | else: 369 | pg_losses = pg_losses 370 | # reduction 371 | if self.config.ratio_type == "token": 372 | pg_loss = torch.sum(pg_losses * response_mask, dim=1) / 1000.0 373 | pg_loss = torch.sum(pg_loss) 374 | elif self.config.ratio_type == "sequence": 375 | pg_loss = torch.sum(pg_losses) 376 | 377 | # compute entropy loss from entropy 378 | entropy_loss = verl_F.masked_mean(entropy, response_mask) 379 | 380 | # compute policy loss 381 | policy_loss = pg_loss - entropy_loss * entropy_coeff 382 | 383 | if self.config.use_kl_loss: 384 | ref_log_prob = data['ref_log_prob'] 385 | # compute kl loss 386 | kld = core_algos.kl_penalty(logprob=log_prob, 387 | ref_logprob=ref_log_prob, 388 | kl_penalty=self.config.kl_loss_type) 389 | kl_loss = masked_mean(kld, response_mask) 390 | 391 | policy_loss = policy_loss + kl_loss * self.config.kl_loss_coef 392 | metrics['actor/kl_loss'] = kl_loss.detach().item() 393 | metrics['actor/kl_coef'] = self.config.kl_loss_coef 394 | 395 | if self.config.use_dynamic_bsz: 396 | # relative to the dynamic bsz 397 | if self.config.bugged_dynamic_scale == 1: 398 | loss = policy_loss * (len(data) / self.config.ppo_mini_batch_size) 399 | else: 400 | loss = policy_loss 401 | else: 402 | loss = policy_loss / self.gradient_accumulation 403 | loss.backward() 404 | 405 | data = { 
406 | 'actor/entropy': entropy_loss.detach().item(), 407 | 'actor/pg_loss': pg_loss.detach().item(), 408 | 'actor/pg_clipfrac': pg_clipfrac.detach().item(), 409 | 'actor/ppo_kl': ppo_kl.detach().item(), 410 | } 411 | append_to_dict(metrics, data) 412 | 413 | grad_norm = self._optimizer_step() 414 | data = {'actor/grad_norm': grad_norm.detach().item()} 415 | append_to_dict(metrics, data) 416 | self.actor_optimizer.zero_grad() 417 | return metrics 418 | -------------------------------------------------------------------------------- /qwen_math_eval_toolkit/parser.py: -------------------------------------------------------------------------------- 1 | import random 2 | import regex 3 | import re 4 | import sympy 5 | from latex2sympy2_extended import latex2sympy 6 | from typing import TypeVar, Iterable, List, Union, Any, Dict 7 | from word2number import w2n 8 | from qwen_math_eval_toolkit.utils import * 9 | 10 | 11 | def _fix_fracs(string): 12 | substrs = string.split("\\frac") 13 | new_str = substrs[0] 14 | if len(substrs) > 1: 15 | substrs = substrs[1:] 16 | for substr in substrs: 17 | new_str += "\\frac" 18 | if len(substr) > 0 and substr[0] == "{": 19 | new_str += substr 20 | else: 21 | try: 22 | assert len(substr) >= 2 23 | except: 24 | return string 25 | a = substr[0] 26 | b = substr[1] 27 | if b != "{": 28 | if len(substr) > 2: 29 | post_substr = substr[2:] 30 | new_str += "{" + a + "}{" + b + "}" + post_substr 31 | else: 32 | new_str += "{" + a + "}{" + b + "}" 33 | else: 34 | if len(substr) > 2: 35 | post_substr = substr[2:] 36 | new_str += "{" + a + "}" + b + post_substr 37 | else: 38 | new_str += "{" + a + "}" + b 39 | string = new_str 40 | return string 41 | 42 | 43 | def _fix_a_slash_b(string): 44 | if len(string.split("/")) != 2: 45 | return string 46 | a = string.split("/")[0] 47 | b = string.split("/")[1] 48 | try: 49 | if "sqrt" not in a: 50 | a = int(a) 51 | if "sqrt" not in b: 52 | b = int(b) 53 | assert string == "{}/{}".format(a, b) 54 | new_string 
= "\\frac{" + str(a) + "}{" + str(b) + "}" 55 | return new_string 56 | except: 57 | return string 58 | 59 | 60 | def _fix_sqrt(string): 61 | _string = re.sub(r"\\sqrt(\w+)", r"\\sqrt{\1}", string) 62 | return _string 63 | 64 | 65 | def convert_word_number(text: str) -> str: 66 | try: 67 | text = str(w2n.word_to_num(text)) 68 | except: 69 | pass 70 | return text 71 | 72 | 73 | # units mainly from MathQA 74 | unit_texts = [ 75 | "east", 76 | "degree", 77 | "mph", 78 | "kmph", 79 | "ft", 80 | "m sqaure", 81 | " m east", 82 | "sq m", 83 | "deg", 84 | "mile", 85 | "q .", 86 | "monkey", 87 | "prime", 88 | "ratio", 89 | "profit of rs", 90 | "rd", 91 | "o", 92 | "gm", 93 | "p . m", 94 | "lb", 95 | "tile", 96 | "per", 97 | "dm", 98 | "lt", 99 | "gain", 100 | "ab", 101 | "way", 102 | "west", 103 | "a .", 104 | "b .", 105 | "c .", 106 | "d .", 107 | "e .", 108 | "f .", 109 | "g .", 110 | "h .", 111 | "t", 112 | "a", 113 | "h", 114 | "no change", 115 | "men", 116 | "soldier", 117 | "pie", 118 | "bc", 119 | "excess", 120 | "st", 121 | "inches", 122 | "noon", 123 | "percent", 124 | "by", 125 | "gal", 126 | "kmh", 127 | "c", 128 | "acre", 129 | "rise", 130 | "a . m", 131 | "th", 132 | "π r 2", 133 | "sq", 134 | "mark", 135 | "l", 136 | "toy", 137 | "coin", 138 | "sq . m", 139 | "gallon", 140 | "° f", 141 | "profit", 142 | "minw", 143 | "yr", 144 | "women", 145 | "feet", 146 | "am", 147 | "pm", 148 | "hr", 149 | "cu cm", 150 | "square", 151 | "v â € ™", 152 | "are", 153 | "rupee", 154 | "rounds", 155 | "cubic", 156 | "cc", 157 | "mtr", 158 | "s", 159 | "ohm", 160 | "number", 161 | "kmph", 162 | "day", 163 | "hour", 164 | "minute", 165 | "min", 166 | "second", 167 | "man", 168 | "woman", 169 | "sec", 170 | "cube", 171 | "mt", 172 | "sq inch", 173 | "mp", 174 | "∏ cm ³", 175 | "hectare", 176 | "more", 177 | "sec", 178 | "unit", 179 | "cu . 
m", 180 | "cm 2", 181 | "rs .", 182 | "rs", 183 | "kg", 184 | "g", 185 | "month", 186 | "km", 187 | "m", 188 | "cm", 189 | "mm", 190 | "apple", 191 | "liter", 192 | "loss", 193 | "yard", 194 | "pure", 195 | "year", 196 | "increase", 197 | "decrease", 198 | "d", 199 | "less", 200 | "Surface", 201 | "litre", 202 | "pi sq m", 203 | "s .", 204 | "metre", 205 | "meter", 206 | "inch", 207 | ] 208 | 209 | unit_texts.extend([t + "s" for t in unit_texts]) 210 | 211 | 212 | def strip_string(string, skip_unit=False): 213 | string = str(string).strip() 214 | # linebreaks 215 | string = string.replace("\n", "") 216 | 217 | # right "." 218 | string = string.rstrip(".") 219 | 220 | # remove inverse spaces 221 | # replace \\ with \ 222 | string = string.replace("\\!", "") 223 | # string = string.replace("\\ ", "") 224 | # string = string.replace("\\\\", "\\") 225 | 226 | # matrix 227 | string = re.sub(r"\\begin\{array\}\{.*?\}", r"\\begin{pmatrix}", string) 228 | string = re.sub(r"\\end\{array\}", r"\\end{pmatrix}", string) 229 | string = string.replace("bmatrix", "pmatrix") 230 | 231 | # replace tfrac and dfrac with frac 232 | string = string.replace("tfrac", "frac") 233 | string = string.replace("dfrac", "frac") 234 | string = ( 235 | string.replace("\\neq", "\\ne") 236 | .replace("\\leq", "\\le") 237 | .replace("\\geq", "\\ge") 238 | ) 239 | 240 | # remove \left and \right 241 | string = string.replace("\\left", "") 242 | string = string.replace("\\right", "") 243 | string = string.replace("\\{", "{") 244 | string = string.replace("\\}", "}") 245 | 246 | # Remove unit: miles, dollars if after is not none 247 | _string = re.sub(r"\\text{.*?}$", "", string).strip() 248 | if _string != "" and _string != string: 249 | # print("Warning: unit not removed: '{}' -> '{}'".format(string, _string)) 250 | string = _string 251 | 252 | if not skip_unit: 253 | # Remove unit: texts 254 | for _ in range(2): 255 | for unit_text in unit_texts: 256 | # use regex, the prefix should be either the 
start of the string or a non-alphanumeric character 257 | # the suffix should be either the end of the string or a non-alphanumeric character 258 | _string = re.sub(r"(^|\W)" + unit_text + r"($|\W)", r"\1\2", string) 259 | if _string != "": 260 | string = _string 261 | 262 | # Remove circ (degrees) 263 | string = string.replace("^{\\circ}", "") 264 | string = string.replace("^\\circ", "") 265 | 266 | # remove dollar signs 267 | string = string.replace("\\$", "") 268 | string = string.replace("$", "") 269 | string = string.replace("\\(", "").replace("\\)", "") 270 | 271 | # convert word number to digit 272 | string = convert_word_number(string) 273 | 274 | # replace "\\text{...}" to "..." 275 | string = re.sub(r"\\text\{(.*?)\}", r"\1", string) 276 | for key in ["x=", "y=", "z=", "x\\in", "y\\in", "z\\in", "x\\to", "y\\to", "z\\to"]: 277 | string = string.replace(key, "") 278 | string = string.replace("\\emptyset", r"{}") 279 | string = string.replace("(-\\infty,\\infty)", "\\mathbb{R}") 280 | 281 | # remove percentage 282 | string = string.replace("\\%", "") 283 | string = string.replace("\%", "") 284 | string = string.replace("%", "") 285 | 286 | # " 0." equivalent to " ." and "{0." equivalent to "{." Alternatively, add "0" if "." 
is the start of the string 287 | string = string.replace(" .", " 0.") 288 | string = string.replace("{.", "{0.") 289 | 290 | # cdot 291 | # string = string.replace("\\cdot", "") 292 | if ( 293 | string.startswith("{") 294 | and string.endswith("}") 295 | and string.isalnum() 296 | or string.startswith("(") 297 | and string.endswith(")") 298 | and string.isalnum() 299 | or string.startswith("[") 300 | and string.endswith("]") 301 | and string.isalnum() 302 | ): 303 | string = string[1:-1] 304 | 305 | # inf 306 | string = string.replace("infinity", "\\infty") 307 | if "\\infty" not in string: 308 | string = string.replace("inf", "\\infty") 309 | string = string.replace("+\\infty", "\\infty") 310 | 311 | # and 312 | string = string.replace("and", "") 313 | string = string.replace("\\mathbf", "") 314 | 315 | # use regex to remove \mbox{...} 316 | string = re.sub(r"\\mbox{.*?}", "", string) 317 | 318 | # quote (str.replace returns a new string, so the result must be assigned) 319 | string = string.replace("'", "") 320 | string = string.replace('"', "") 321 | 322 | # i, j 323 | if "j" in string and "i" not in string: 324 | string = string.replace("j", "i") 325 | 326 | # replace a.000b where b is not number or b is end, with ab, use regex 327 | string = re.sub(r"(\d+)\.0*([^\d])", r"\1\2", string) 328 | string = re.sub(r"(\d+)\.0*$", r"\1", string) 329 | 330 | # if empty, return empty string 331 | if len(string) == 0: 332 | return string 333 | if string[0] == ".": 334 | string = "0" + string 335 | 336 | # to consider: get rid of e.g. "k = " or "q = " at beginning 337 | if len(string.split("=")) == 2: 338 | if len(string.split("=")[0]) <= 2: 339 | string = string.split("=")[1] 340 | 341 | string = _fix_sqrt(string) 342 | string = string.replace(" ", "") 343 | 344 | # \frac1b or \frac12 --> \frac{1}{b} and \frac{1}{2}, etc. Even works with \frac1{72} (but not \frac{72}1). 
Also does a/b --> \\frac{a}{b} 345 | string = _fix_fracs(string) 346 | 347 | # NOTE: X/Y changed to \frac{X}{Y} in dataset, but in simple cases fix in case the model output is X/Y 348 | string = _fix_a_slash_b(string) 349 | 350 | return string 351 | 352 | 353 | def extract_multi_choice_answer(pred_str): 354 | # TODO: SFT models 355 | if "Problem:" in pred_str: 356 | pred_str = pred_str.split("Problem:", 1)[0] 357 | pred_str = pred_str.replace("choice is", "answer is") 358 | patt = regex.search(r"answer is \(?(?P<ans>[abcde])\)?", pred_str.lower()) 359 | if patt is not None: 360 | return patt.group("ans").upper() 361 | return "placeholder" 362 | 363 | 364 | direct_answer_trigger_for_fewshot = ("choice is", "answer is") 365 | 366 | 367 | def choice_answer_clean(pred: str): 368 | pred = pred.strip("\n") 369 | 370 | # Determine if this is ICL, if so, use \n\n to split the first chunk. 371 | ICL = False 372 | for trigger in direct_answer_trigger_for_fewshot: 373 | if pred.count(trigger) > 1: 374 | ICL = True 375 | if ICL: 376 | pred = pred.split("\n\n")[0] 377 | 378 | # Split the trigger to find the answer. 379 | preds = re.split("|".join(direct_answer_trigger_for_fewshot), pred) 380 | if len(preds) > 1: 381 | answer_flag = True 382 | pred = preds[-1] 383 | else: 384 | answer_flag = False 385 | 386 | pred = pred.strip("\n").rstrip(".").rstrip("/").strip(" ").lstrip(":") 387 | 388 | # Clean the answer based on the dataset 389 | tmp = re.findall(r"\b(A|B|C|D|E)\b", pred.upper()) 390 | if tmp: 391 | pred = tmp 392 | else: 393 | pred = [pred.strip().strip(".")] 394 | 395 | if len(pred) == 0: 396 | pred = "" 397 | else: 398 | if answer_flag: 399 | # choose the first element in list ... 400 | pred = pred[0] 401 | else: 402 | # choose the last element in list ... 403 | pred = pred[-1] 404 | 405 | # Remove the period at the end, again! 
406 | pred = pred.rstrip(".").rstrip("/") 407 | 408 | return pred 409 | 410 | 411 | def find_box(pred_str: str): 412 | ans = pred_str.split("boxed")[-1] 413 | if not ans: 414 | return "" 415 | if ans[0] == "{": 416 | stack = 1 417 | a = "" 418 | for c in ans[1:]: 419 | if c == "{": 420 | stack += 1 421 | a += c 422 | elif c == "}": 423 | stack -= 1 424 | if stack == 0: 425 | break 426 | a += c 427 | else: 428 | a += c 429 | else: 430 | a = ans.split("$")[0].strip() 431 | return a 432 | 433 | 434 | def clean_units(pred_str: str): 435 | """Clean the units in the number.""" 436 | 437 | def convert_pi_to_number(code_string): 438 | code_string = code_string.replace("\\pi", "π") 439 | # Replace \pi or π not preceded by a digit or } with 3.14 440 | code_string = re.sub(r"(?<![\d}])\\?π", "3.14", code_string) 441 | # Replace π preceded by a digit but without a multiplication symbol, e.g. "3π" -> "3*3.14" 442 | code_string = re.sub(r"(\d)(\\?π)", r"\1*3.14", code_string) 443 | # Handle cases where π is within braces or followed by a multiplication symbol 444 | # This replaces "{π}" with "3.14" directly and "3*π" with "3*3.14" 445 | code_string = re.sub(r"\{(\\?π)\}", "3.14", code_string) 446 | code_string = re.sub(r"\*(\\?π)", "*3.14", code_string) 447 | return code_string 448 | 449 | pred_str = convert_pi_to_number(pred_str) 450 | pred_str = pred_str.replace("%", "/100") 451 | pred_str = pred_str.replace("$", "") 452 | pred_str = pred_str.replace("¥", "") 453 | pred_str = pred_str.replace("°C", "") 454 | pred_str = pred_str.replace(" C", "") 455 | pred_str = pred_str.replace("°", "") 456 | return pred_str 457 | 458 | 459 | def extract_theoremqa_answer(pred: str, answer_flag: bool = True): 460 | if any([option in pred.lower() for option in ["yes", "true"]]): 461 | pred = "True" 462 | elif any([option in pred.lower() for option in ["no", "false"]]): 463 | pred = "False" 464 | elif any( 465 | [ 466 | option in pred.lower() 467 | for option in ["(a)", "(b)", "(c)", "(d)", "(e)", "(f)"] 468 | ] 469 | ): 470 | pass 471 | else: 472 | # Some of the models somehow get used to boxed output from
pre-training 473 | if "boxed" in pred: 474 | pred = find_box(pred) 475 | 476 | if answer_flag: 477 | # Extract the numbers out of the string 478 | pred = pred.split("=")[-1].strip() 479 | pred = clean_units(pred) 480 | try: 481 | tmp = str(latex2sympy(pred)) 482 | pred = str(eval(tmp)) 483 | except Exception: 484 | if re.match(r"-?[\d\.]+\s\D+$", pred): 485 | pred = pred.split(" ")[0] 486 | elif re.match(r"-?[\d\.]+\s[^\s]+$", pred): 487 | pred = pred.split(" ")[0] 488 | else: 489 | # desperate search over the last number 490 | preds = re.findall(r"-?\d*\.?\d+", pred) 491 | if len(preds) >= 1: 492 | pred = preds[-1] 493 | else: 494 | pred = "" 495 | 496 | return pred 497 | 498 | # Key extraction function 499 | def extract_answer(pred_str, data_name, use_last_number=True): 500 | pred_str = pred_str.replace("\u043a\u0438", "") 501 | if data_name in ["mmlu_stem", "sat_math", "aqua", "gaokao2023"]: # skipped for math benchmarks 502 | # TODO check multiple choice 503 | return choice_answer_clean(pred_str) 504 | 505 | if "final answer is $" in pred_str and "$. I hope" in pred_str: 506 | # minerva_math 507 | tmp = pred_str.split("final answer is $", 1)[1] 508 | pred = tmp.split("$.
I hope", 1)[0].strip() 509 | elif "boxed" in pred_str: 510 | ans = pred_str.split("boxed")[-1] 511 | if len(ans) == 0: 512 | return "" 513 | elif ans[0] == "{": 514 | stack = 1 515 | a = "" 516 | for c in ans[1:]: 517 | if c == "{": 518 | stack += 1 519 | a += c 520 | elif c == "}": 521 | stack -= 1 522 | if stack == 0: 523 | break 524 | a += c 525 | else: 526 | a += c 527 | else: 528 | a = ans.split("$")[0].strip() 529 | pred = a 530 | elif "he answer is" in pred_str: 531 | pred = pred_str.split("he answer is")[-1].strip() 532 | elif "final answer is" in pred_str: 533 | pred = pred_str.split("final answer is")[-1].strip() 534 | elif "答案是" in pred_str: 535 | # Handle Chinese few-shot multiple choice problem answer extraction 536 | pred = pred_str.split("答案是")[1].strip().split("\n\n")[0].strip() 537 | else: # use the last number 538 | if use_last_number: 539 | pattern = r"-?\d*\.?\d+" 540 | pred = re.findall(pattern, pred_str.replace(",", "")) 541 | if len(pred) >= 1: 542 | pred = pred[-1] 543 | else: 544 | pred = "" 545 | else: 546 | pred = "" 547 | 548 | # choice answer 549 | if ( 550 | data_name in ["sat_math", "aqua"] 551 | or "mmlu" in data_name 552 | ): # false for math benchmarks 553 | tmp = re.findall(r"\b(A|B|C|D|E)\b", pred.upper()) 554 | if tmp: 555 | pred = tmp[-1] 556 | else: 557 | pred = pred.strip().strip(".") 558 | 559 | # multiple line 560 | # pred = pred.split("\n")[0] 561 | pred = re.sub(r"\n\s*", "", pred) 562 | if pred != "" and pred[0] == ":": 563 | pred = pred[1:] 564 | if pred != "" and pred[-1] == ".": 565 | pred = pred[:-1] 566 | if pred != "" and pred[-1] == "/": 567 | pred = pred[:-1] 568 | pred = strip_string(pred, skip_unit=data_name in ["carp_en", "minerva_math"]) 569 | return pred 570 | 571 | 572 | STRIP_EXCEPTIONS = ["carp_en", "minerva_math"] 573 | 574 | 575 | def parse_ground_truth(example: Dict[str, Any], data_name): 576 | if "gt_cot" in example and "gt" in example: 577 | if data_name in ["math"]: 578 | gt_ans =
extract_answer(example["gt_cot"], data_name) 579 | elif data_name in STRIP_EXCEPTIONS: 580 | gt_ans = example["gt"] 581 | else: 582 | gt_ans = strip_string(example["gt"]) 583 | return example["gt_cot"], gt_ans 584 | 585 | # parse ground truth 586 | if data_name in ["math", "minerva_math", "math500"]: # key code path 587 | gt_cot = example["solution"] 588 | gt_ans = extract_answer(gt_cot, data_name) 589 | elif data_name == "gsm8k": 590 | gt_cot, gt_ans = example["answer"].split("####") 591 | elif data_name == "svamp": 592 | gt_cot, gt_ans = example["Equation"], example["Answer"] 593 | elif data_name == "asdiv": 594 | gt_cot = example["formula"] 595 | gt_ans = re.sub(r"\(.*?\)", "", example["answer"]) 596 | elif data_name == "mawps": 597 | gt_cot, gt_ans = None, example["target"] 598 | elif data_name == "tabmwp": 599 | gt_cot = example["solution"] 600 | gt_ans = example["answer"] 601 | if example["ans_type"] in ["integer_number", "decimal_number"]: 602 | if "/" in gt_ans: 603 | gt_ans = int(gt_ans.split("/")[0]) / int(gt_ans.split("/")[1]) 604 | elif "," in gt_ans: 605 | gt_ans = float(gt_ans.replace(",", "")) 606 | elif "%" in gt_ans: 607 | gt_ans = float(gt_ans.split("%")[0]) / 100 608 | else: 609 | gt_ans = float(gt_ans) 610 | elif data_name == "carp_en": 611 | gt_cot, gt_ans = example["steps"], example["answer"] 612 | elif data_name == "mmlu_stem": 613 | abcd = "ABCD" 614 | gt_cot, gt_ans = None, abcd[example["answer"]] 615 | elif data_name == "sat_math": 616 | gt_cot, gt_ans = None, example["Answer"] 617 | elif data_name == "aqua": 618 | gt_cot, gt_ans = None, example["correct"] 619 | elif data_name in ["gaokao2023en", "college_math", "gaokao_math_cloze"]: 620 | gt_cot, gt_ans = None, example["answer"].replace("$", "").strip() 621 | elif data_name == "gaokao_math_qa": 622 | gt_cot, gt_ans = None, example["label"] 623 | elif data_name in ["gaokao2024_mix", "cn_middle_school"]: 624 | if len(example["choice_answer"]) > 0: 625 | gt_cot, gt_ans = None, example["choice_answer"]
626 | else: 627 | gt_cot, gt_ans = None, example["answer"] 628 | elif data_name == "olympiadbench": 629 | gt_cot, gt_ans = None, example["final_answer"][0].strip("$") 630 | elif data_name in [ 631 | "aime24", 632 | "amc23", 633 | "cmath", 634 | "gaokao2024_I", 635 | "gaokao2024_II", 636 | "imo2024", 637 | ]: 638 | gt_cot, gt_ans = None, example["answer"] 639 | else: 640 | raise NotImplementedError(f"`{data_name}`") 641 | # post process 642 | gt_cot = str(gt_cot).strip() 643 | if data_name not in STRIP_EXCEPTIONS: 644 | gt_ans = strip_string(gt_ans, skip_unit=data_name == "carp_en") 645 | else: 646 | gt_ans = ( 647 | gt_ans.replace("\\neq", "\\ne") 648 | .replace("\\leq", "\\le") 649 | .replace("\\geq", "\\ge") 650 | ) 651 | return gt_cot, gt_ans 652 | 653 | 654 | def parse_question(example, data_name): 655 | question = "" 656 | if data_name == "asdiv": 657 | question = f"{example['body'].strip()} {example['question'].strip()}" 658 | elif data_name == "svamp": 659 | body = example["Body"].strip() 660 | if not body.endswith("."): 661 | body = body + "." 662 | question = f'{body} {example["Question"].strip()}' 663 | elif data_name == "tabmwp": 664 | title_str = ( 665 | f'regarding "{example["table_title"]}" ' if example["table_title"] else "" 666 | ) 667 | question = f"Read the following table {title_str}and answer a question:\n" 668 | question += f'{example["table"]}\n{example["question"]}' 669 | if example["choices"]: 670 | question += ( 671 | f' Please select from the following options: {example["choices"]}' 672 | ) 673 | elif data_name == "carp_en": 674 | question = example["content"] 675 | elif data_name == "mmlu_stem": 676 | options = example["choices"] 677 | assert len(options) == 4 678 | for i, (label, option) in enumerate(zip("ABCD", options)): 679 | options[i] = f"({label}) {str(option).strip()}" 680 | options = " ".join(options) 681 | # question = f"{example['question'].strip()}\nWhat of the following is the right choice? 
Explain your answer.\n{options}" 682 | question = f"{example['question'].strip()}\nAnswer Choices: {options}" 683 | elif data_name == "sat_math": 684 | options = example["options"].strip() 685 | assert "A" == options[0] 686 | options = "(" + options 687 | for ch in "BCD": 688 | if f" {ch}) " in options: 689 | options = regex.sub(f" {ch}\) ", f" ({ch}) ", options) 690 | # question = f"{example['question'].strip()}\nWhat of the following is the right choice? Explain your answer.\n{options.strip()}" 691 | question = f"{example['question'].strip()}\nAnswer Choices: {options}" 692 | elif "aqua" in data_name: 693 | options = example["options"] 694 | choice = "(" + "(".join(options) 695 | choice = choice.replace("(", " (").replace(")", ") ").strip() 696 | choice = "\nAnswer Choices: " + choice 697 | question = example["question"].strip() + choice 698 | elif data_name == "gaokao_math_qa": 699 | options_dict = example["options"] 700 | options = [] 701 | for key in options_dict: 702 | options.append(f"({key}) {options_dict[key]}") 703 | options = " ".join(options) 704 | question = f"{example['question'].strip()}\n选项: {options}" 705 | else: 706 | for key in ["question", "problem", "Question", "input"]: 707 | if key in example: 708 | question = example[key] 709 | break 710 | # assert question != "" 711 | # Yes or No question 712 | _, gt_ans = parse_ground_truth(example, data_name) 713 | if isinstance(gt_ans, str): 714 | gt_lower = gt_ans.lower() 715 | if gt_lower in ["true", "false"]: 716 | question += " (True or False)" 717 | if gt_lower in ["yes", "no"]: 718 | question += " (Yes or No)" 719 | return question.strip() 720 | 721 | 722 | def run_execute(executor, result, prompt_type, data_name, execute=False): 723 | if not result or result == "error": 724 | return None, None 725 | report = None 726 | 727 | if "program_only" in prompt_type: 728 | prediction = extract_program_output(result) 729 | elif prompt_type in ["pot", "pal"] and execute: 730 | code = extract_program(result) 
731 | prediction, report = executor.apply(code) 732 | else: 733 | prediction = extract_answer(result, data_name) 734 | 735 | # prediction = strip_string(prediction, skip_unit=data_name == "carp_en") 736 | prediction = strip_string(prediction, skip_unit=data_name in STRIP_EXCEPTIONS) 737 | return prediction, report 738 | 739 | 740 | def _test_extract_answer(): 741 | text = """ 742 | This is still not equal to $0$, so we must have made another mistake. 743 | 744 | When we subtracted $7$ from $\frac{386}{64}$, we should have subtracted $7 \cdot 64$ from $386$, not the other way around. Let's correct that: 745 | 746 | \[\frac{386}{64} - 7 = \frac{386}{64} - \frac{7 \cdot 64}{1 \cdot 64} = \frac{386 - 448}{64} = \frac{-62}{64}.\] 747 | 748 | This is still not equal to $0$, so we must have made another mistake. 749 | 750 | When we subtracted $7$ from $\frac{386}{64}$, we should have subtracted $7 \cdot 64$ from $386$, not the other way around. Let's correct that: 751 | 752 | \[\frac{386}{64} 753 | """ 754 | print(extract_answer(text, "math-oai", use_last_number=False)) 755 | print(choice_answer_clean("\mathrm{(D)\}1,008,016")) 756 | # should output a dict 757 | 758 | 759 | if __name__ == "__main__": 760 | _test_extract_answer() 761 | -------------------------------------------------------------------------------- /qwen_math_eval_toolkit/examples.py: -------------------------------------------------------------------------------- 1 | import json 2 | 3 | 4 | def get_examples(): 5 | examples = {} 6 | examples["gsm8k"] = [ 7 | ( 8 | "There are 15 trees in the grove. Grove workers will plant trees in the grove today. After they are done, there will be 21 trees. How many trees did the grove workers plant today?", 9 | "There are 15 trees originally. Then there were 21 trees after some more were planted. So there must have been 21 - 15 = 6. 
The answer is 6.", 10 | ), 11 | ( 12 | "If there are 3 cars in the parking lot and 2 more cars arrive, how many cars are in the parking lot?", 13 | "There are originally 3 cars. 2 more cars arrive. 3 + 2 = 5. The answer is 5.", 14 | ), 15 | ( 16 | "Leah had 32 chocolates and her sister had 42. If they ate 35, how many pieces do they have left in total?", 17 | "Originally, Leah had 32 chocolates. Her sister had 42. So in total they had 32 + 42 = 74. After eating 35, they had 74 - 35 = 39. The answer is 39.", 18 | ), 19 | ( 20 | "Jason had 20 lollipops. He gave Denny some lollipops. Now Jason has 12 lollipops. How many lollipops did Jason give to Denny?", 21 | "Jason started with 20 lollipops. Then he had 12 after giving some to Denny. So he gave Denny 20 - 12 = 8. The answer is 8.", 22 | ), 23 | ( 24 | "Shawn has five toys. For Christmas, he got two toys each from his mom and dad. How many toys does he have now?", 25 | "Shawn started with 5 toys. If he got 2 toys each from his mom and dad, then that is 4 more toys. 5 + 4 = 9. The answer is 9.", 26 | ), 27 | ( 28 | "There were nine computers in the server room. Five more computers were installed each day, from monday to thursday. How many computers are now in the server room?", 29 | "There were originally 9 computers. For each of 4 days, 5 more computers were added. So 5 * 4 = 20 computers were added. 9 + 20 is 29. The answer is 29.", 30 | ), 31 | ( 32 | "Michael had 58 golf balls. On tuesday, he lost 23 golf balls. On wednesday, he lost 2 more. How many golf balls did he have at the end of wednesday?", 33 | "Michael started with 58 golf balls. After losing 23 on tuesday, he had 58 - 23 = 35. After losing 2 more, he had 35 - 2 = 33 golf balls. The answer is 33.", 34 | ), 35 | ( 36 | "Olivia has $23. She bought five bagels for $3 each. How much money does she have left?", 37 | "Olivia had 23 dollars. 5 bagels for 3 dollars each will be 5 x 3 = 15 dollars. So she has 23 - 15 dollars left. 23 - 15 is 8. 
The answer is 8.", 38 | ), 39 | ] 40 | examples["gsm8k-pal"] = [ 41 | ( 42 | "Olivia has $23. She bought five bagels for $3 each. How much money does she have left?", 43 | '```python\ndef solution():\n """Olivia has $23. She bought five bagels for $3 each. How much money does she have left?"""\n money_initial = 23\n bagels = 5\n bagel_cost = 3\n money_spent = bagels * bagel_cost\n money_left = money_initial - money_spent\n result = money_left\n return result\n```', 44 | ), 45 | ( 46 | "Michael had 58 golf balls. On tuesday, he lost 23 golf balls. On wednesday, he lost 2 more. How many golf balls did he have at the end of wednesday?", 47 | '```python\ndef solution():\n """Michael had 58 golf balls. On tuesday, he lost 23 golf balls. On wednesday, he lost 2 more. How many golf balls did he have at the end of wednesday?"""\n golf_balls_initial = 58\n golf_balls_lost_tuesday = 23\n golf_balls_lost_wednesday = 2\n golf_balls_left = golf_balls_initial - golf_balls_lost_tuesday - golf_balls_lost_wednesday\n result = golf_balls_left\n return result\n```', 48 | ), 49 | ( 50 | "There were nine computers in the server room. Five more computers were installed each day, from monday to thursday. How many computers are now in the server room?", 51 | '```python\ndef solution():\n """There were nine computers in the server room. Five more computers were installed each day, from monday to thursday. How many computers are now in the server room?"""\n computers_initial = 9\n computers_per_day = 5\n num_days = 4 # 4 days between monday and thursday\n computers_added = computers_per_day * num_days\n computers_total = computers_initial + computers_added\n result = computers_total\n return result\n```', 52 | ), 53 | ] 54 | examples["gsm8k-tora"] = [ 55 | ( 56 | "Olivia has $23. She bought five bagels for $3 each. 
How much money does she have left?", 57 | "```python\ndef money_left():\n money_initial = 23\n bagels = 5\n bagel_cost = 3\n money_spent = bagels * bagel_cost\n remaining_money = money_initial - money_spent\n return remaining_money\n \nremaining_money = money_left()\nprint(remaining_money)\n```\n```output\n8\n```\nOlivia has $\\boxed{8}$ dollars left.", 58 | ), 59 | ( 60 | "Michael had 58 golf balls. On tuesday, he lost 23 golf balls. On wednesday, he lost 2 more. How many golf balls did he have at the end of wednesday?", 61 | "```python\ndef remaining_golf_balls():\n golf_balls_initial = 58\n golf_balls_lost_tuesday = 23\n golf_balls_lost_wednesday = 2\n golf_balls_left = golf_balls_initial - golf_balls_lost_tuesday - golf_balls_lost_wednesday\n remaining_golf_balls = golf_balls_left\n return remaining_golf_balls\n\nanswer = remaining_golf_balls() \nprint(answer)\n```\n```output\n33\n```\nMichael had $\\boxed{33}$ golf balls at the end of Wednesday.", 62 | ), 63 | ( 64 | "There were nine computers in the server room. Five more computers were installed each day, from monday to thursday. How many computers are now in the server room?", 65 | "```python\ndef total_computers():\n computers_initial = 9\n computers_per_day = 5\n num_days = 4 # 4 days between monday and thursday\n computers_added = computers_per_day * num_days\n computers_total = computers_initial + computers_added\n return computers_total\n\ntotal_computers = total_computers()\nprint(total_computers)\n```\n```output\n29\n```\nThere're $\\boxed{29}$ computers in the server room.", 66 | ), 67 | ] 68 | 69 | examples["math"] = [ 70 | ( 71 | "Kevin Kangaroo begins hopping on a number line at 0. He wants to get to 1, but he can hop only $\\frac{1}{3}$ of the distance. Each hop tires him out so that he continues to hop $\\frac{1}{3}$ of the remaining distance. How far has he hopped after five hops? 
Express your answer as a common fraction.", 72 | "Let's think step by step\nKevin hops $1/3$ of the remaining distance with every hop.\nHis first hop takes $1/3$ closer.\nFor his second hop, he has $2/3$ left to travel, so he hops forward $(2/3)(1/3)$.\nFor his third hop, he has $(2/3)^2$ left to travel, so he hops forward $(2/3)^2(1/3)$.\nIn general, Kevin hops forward $(2/3)^{k-1}(1/3)$ on his $k$th hop.\nWe want to find how far he has hopped after five hops.\nThis is a finite geometric series with first term $1/3$, common ratio $2/3$, and five terms.\nThus, Kevin has hopped $\\frac{\\frac{1}{3}\\left(1-\\left(\\frac{2}{3}\\right)^5\\right)}{1-\\frac{2}{3}} = \\boxed{\\frac{211}{243}}$.\nThe answer is \\frac{211}{243}}", 73 | ), 74 | ( 75 | "What is the area of the region defined by the equation $x^2+y^2 - 7 = 4y-14x+3$?", 76 | "Let's think step by step\nWe rewrite the equation as $x^2 + 14x + y^2 - 4y = 10$ and then complete the square,\nresulting in $(x+7)^2-49 + (y-2)^2-4=10$,\nor $(x+7)^2+(y-2)^2=63$.\nThis is the equation of a circle with center $(-7, 2)$ and radius $\\sqrt{63},$\nso the area of this region is $\\pi r^2 = \\boxed{63\\pi}$.\nThe answer is 63\\pi", 77 | ), 78 | ( 79 | "If $x^2+y^2=1$, what is the largest possible value of $|x|+|y|$?", 80 | "Let's think step by step\nIf $(x,y)$ lies on the circle,\nso does $(x,-y),$ $(-x,-y),$ and $(-x,-y),$ (which all give the same value of $|x| + |y|$),\nso we can assume that $x \\ge 0$ and $y \\ge 0.$\nThen $|x| + |y| = x + y.$ Squaring, we get\n\\[(x + y)^2 = x^2 + 2xy + y^2 = 1 + 2xy.\\]\nNote that $(x - y)^2 \\ge 0.$\nExpanding, we get $x^2 - 2xy + y^2 \\ge 0,$ so $2xy \\le x^2 + y^2 = 1.$\nHence,\\[1 + 2xy \\le 2,\\]which means $x + y \\le \\sqrt{2}.$\nEquality occurs when $x = y = \\frac{1}{\\sqrt{2}},$\nso the maximum value of $|x| + |y|$ is $\\boxed{\\sqrt{2}}.$\nThe answer is \\sqrt{2}", 81 | ), 82 | ( 83 | "If $f(x)=\\frac{ax+b}{cx+d}, abcd\\not=0$ and $f(f(x))=x$ for all $x$ in the domain of $f$, 
what is the value of $a+d$?", 84 | "Let's think step by step\nThe condition $f(f(x))$ means that $f$ is the inverse of itself,\nso its graph is symmetrical about the line $y = x$.\nWith a rational function of this form, we will have two asymptotes:\na vertical one at $x=-d/c$ if $cx+d$ does not divide $ax+b$,\nand a horizontal one at $y=a/c$,\nif we take the limit of $f(x)$ as $x$ goes to $\\pm\\infty$.\nIn order for $f$ to be its own inverse, the intersection of the asymptotes must lie on the line $y=x$\nso that it and its asymptotes reflect onto themselves.\nThis means that $-d/c=a/c$,\nand therefore $-d=a$ and $a+d=\\boxed{0}$.\nThe answer is 0", 85 | ), 86 | ( 87 | "Expand $(2z^2 + 5z - 6)(3z^3 - 2z + 1)$.", 88 | "Let's think step by step\n$$\\begin{array}{crrrrrrr}\n& & & 3z^3 & & -2z & + 1 & \\\\\n\\times & & & & 2z^2 & +5z & -6 \\\\\n\\cline{1-7}\\rule{0pt}{0.17in}\n& & & -18z^3 & & +12z & -6 & \\\\\n& & +15z^4 & & -10z^2 & +5z & & \\\\\n+ & 6z^5 & & -4z^3 & +2z^2 & & & \\\\\n\\cline{1-7}\\rule{0pt}{0.17in}\n& 6z^5 & +15z^4 & -22z^3 & - 8z^2 &+17z & -6 &\n\\end{array}$$\nThe answer is 6z^5+15z^4-22z^3-8z^2+17z-6", 89 | ), 90 | ] 91 | 92 | examples["math_pal"] = [ 93 | ( 94 | "Display the final result in LaTeX.\n\n Find the coefficient of $x^3$ when $3(x^2 - x^3+x) +3(x +2x^3- 3x^2 + 3x^5+x^3) -5(1+x-4x^3 - x^2)$ is simplified.", 95 | "```python\nfrom sympy import symbols, simplify\n\ndef solution():\n x = symbols('x')\n expr = 3*(x**2 - x**3 + x) + 3*(x + 2*x**3 - 3*x**2 + 3*x**5 + x**3) - 5*(1 + x - 4*x**3 - x**2)\n simplified_expr = simplify(expr)\n\n x3_coefficient = simplified_expr.as_coefficients_dict()[x**3]\n result = x3_coefficient\n return result\n```", 96 | ), 97 | ( 98 | "The surface area of a sphere with radius $r$ is $4\\pi r^2$. Including the area of its circular base, what is the total surface area of a hemisphere with radius 6 cm?
Express your answer in terms of $\\pi$.", 99 | "```python\nimport math\n\ndef solution():\n radius = 6\n\n # Surface area of the hemisphere\n hemisphere_area = 2 * math.pi * radius**2\n\n # Area of the circular base\n base_area = math.pi * radius**2\n\n # Total surface area\n total_surface_area = hemisphere_area + base_area\n\n # Formatting the result in LaTeX\n result = r'{}\\\\pi'.format(total_surface_area / math.pi)\n return result\n```", 100 | ), 101 | ( 102 | "Monica tosses a fair 6-sided die. If the roll is a prime number, then she wins that amount of dollars (so that, for example, if she rolls 3, then she wins 3 dollars). If the roll is composite, she wins nothing. Otherwise, she loses 3 dollars. What is the expected value of her winnings on one die toss? Express your answer as a dollar value to the nearest cent.", 103 | '```python\ndef solution():\n # Probabilities of each outcome\n prime_prob = 1 / 6\n composite_prob = 1 / 3\n otherwise_prob = 1 / 6\n\n # Expected value of each outcome\n prime_expected_value = (2 * prime_prob) + (3 * prime_prob) + (5 * prime_prob)\n composite_expected_value = 0 * composite_prob\n otherwise_expected_value = -3 * otherwise_prob\n\n # Total expected value\n total_expected_value = prime_expected_value + composite_expected_value + otherwise_expected_value\n\n # Dollar value to the nearest cent\n result = "{:.2f}".format(total_expected_value)\n return result\n```', 104 | ), 105 | ( 106 | "Given $\\mathbf{a} = \\begin{pmatrix} -7 \\\\ 0 \\\\ 1 \\end{pmatrix}$ and $\\mathbf{b} = \\begin{pmatrix} 4 \\\\ 2 \\\\ -1 \\end{pmatrix},$ find $\\mathbf{a} - 3 \\mathbf{b}.$", 107 | "```python\nimport numpy as np\n\ndef solution():\n a = np.array([-7, 0, 1])\n b = np.array([4, 2, -1])\n\n result = a - 3 * b\n\n result = r'\\begin{{pmatrix}} {} \\\\ {} \\\\ {} \\end{{pmatrix}}'.format(result[0], result[1], result[2])\n return result\n```", 108 | ), 109 | ] 110 | examples["math-tora"] = [ 111 | ( 112 | "Convert the point $(0, -3 \\sqrt{3}, 3)$
in rectangular coordinates to spherical coordinates. Enter your answer in the form $(\\rho,\\theta,\\phi),$ where $\\rho > 0,$ $0 \\le \\theta < 2 \\pi,$ and $0 \\le \\phi \\le \\pi.$", 113 | "To convert the point $(x, y, z)$ in rectangular coordinates to spherical coordinates $(\\rho, \\theta, \\phi)$, we use the following formulas:\n1. $\\rho = \\sqrt{x^2 + y^2 + z^2}$\n2. $\\theta = \\arctan{\\frac{y}{x}}$\n3. $\\phi = \\arccos{\\frac{z}{\\rho}}$\n\nLet's write a function:\n\n```python\nfrom sympy import sqrt, atan2, acos, pi\n\ndef rectangular_to_spherical():\n x, y, z = 0, -3*sqrt(3), 3\n rho = sqrt(x**2 + y**2 + z**2)\n theta = atan2(y, x)\n phi = acos(z/rho)\n return rho, theta, phi\n\nspherical_coordinates = rectangular_to_spherical()\nprint(spherical_coordinates)\n```\n```output\n(6, -pi/2, pi/3)\n```\n\nThe required form of the answer is $(\\rho, \\theta, \\phi)$, where $\\rho > 0$, $0 \\le \\theta < 2\\pi$, and $0 \\le \\phi \\le \\pi$. Since our initial output has $\\theta = -\\pi/2$, which doesn't meet the condition for $\\theta$, we add $2\\pi$ to adjust its value.\n\nHence, the spherical coordinates for the point $(0, -3\\sqrt{3}, 3)$ are $\\boxed{(6, \\frac{3\\pi}{2}, \\frac{\\pi}{3})}$.", 114 | ), 115 | ( 116 | "Compute $1011_2 + 101_2 - 1100_2 + 1101_2$. 
Express your answer in base 2.", 117 | 'Let\'s write a single python function for binary computation:\n\n```python\ndef binary_sum_diff():\n # Convert binary to decimal\n num1 = int("1011", 2)\n num2 = int("101", 2)\n num3 = int("1100", 2)\n num4 = int("1101", 2)\n\n # Perform the operations\n result = num1 + num2 - num3 + num4\n\n # Convert the result back to binary\n result_binary = format(result, "b")\n\n return result_binary\n\nresult = binary_sum_diff()\nprint(result)\n```\n```output\n10001\n```\nThe answer expressed in base 2 is $\\boxed{10001_2}$.', 118 | ), 119 | ( 120 | "Find the matrix that corresponds to projecting onto the vector $\\begin{pmatrix} 2 \\\\ -3 \\end{pmatrix}.$", 121 | "```python\nfrom sympy import Matrix, simplify\n\ndef projection_matrix():\n v = Matrix([2, -3])\n\n # Calculate the projection matrix\n P = (v * v.T) / (v.T * v)[0]\n\n # Simplify the elements\n P = P.applyfunc(simplify)\n\n return P\n\nresult = projection_matrix()\nprint(result)\n```\n```output\nMatrix([[4/13, -6/13], [-6/13, 9/13]])\n```\n\nThe matrix is $\\boxed{\\begin{pmatrix} \\frac{4}{13} & -\\frac{6}{13} \\ -\\frac{6}{13} & \\frac{9}{13} \\end{pmatrix}}.$", 122 | ), 123 | ( 124 | "Solve the inequality\n\\[\\frac{x^2 - 25}{x + 5} < 0.\\]", 125 | "Let's use `sympy` to solve the inequality:\n\n```python\nfrom sympy import symbols, simplify\nfrom sympy.solvers.inequalities import solve_univariate_inequality\nfrom sympy.core.relational import StrictLessThan\n\ndef solve_inequality():\n x = symbols('x')\n expression = (x**2 - 25) / (x + 5)\n inequality = StrictLessThan(expression, 0)\n solution = solve_univariate_inequality(inequality, x, relational=False)\n simplified_solution = simplify(solution)\n \n return simplified_solution\n\nresult = solve_inequality()\nprint(result)\n```\n```output\nUnion(Interval.open(-oo, -5), Interval.Lopen(-5, 5))\n```\nThe answer is $\\boxed{(-\\infty,-5)\\cup(-5,5)}$", 126 | ), 127 | ( 128 | 'In the figure, triangles $ABC$ and $BCD$ are 
equilateral triangles. What is the value of $AD \\div BC$ when expressed in simplest radical form?\n\n[asy]\ndraw((0,0)--(5,8.7)--(10,0)--cycle);\ndraw((10,0)--(15,8.7)--(5,8.7));\nlabel("$A$",(0,0),SW);\nlabel("$B$",(5,8.7),N);\nlabel("$C$",(10,0),SE);\nlabel("$D$",(15,8.7),NE);\n[/asy]', 129 | "```python\nfrom sympy import Rational, sqrt, simplify\n\ndef ad_divided_by_bc():\n\n x = Rational(1, 1) # Side length of equilateral triangles\n\n ad_squared = 2 * x**2 * (1 + Rational(1, 2)) # Using the law of cosines with cos(2*pi/3) = -1/2\n ad = sqrt(ad_squared)\n\n bc = x # BC is the side length of the equilateral triangles\n\n simplified_ratio = simplify(ad / bc)\n\n return simplified_ratio\n\nresult = ad_divided_by_bc()\nprint(result)\n```\n```output\nsqrt(3)\n```\nThe value of $AD \\div BC$ is $\\boxed{\\sqrt{3}}$.", 130 | ), 131 | ] 132 | examples["carp_en"] = [ 133 | ( 134 | "If $3 a ^ { m + 2 } b$ and $\\frac { 1 } { 2 } ab ^ { n - 1 }$ are similar terms, then $m + n$ is equal to.", 135 | "Let's think step by step\n$3 a ^ { m + 2 } b$ and $\\frac { 1 } { 2 } ab ^ { n - 1 }$ are like terms. We can obtain $m + 2 = 1$ and $n - 1 = 1$. Solving for $m$ and $n$, we get $m = - 1$ and $n = 2$. Therefore, $m + n = - 1 + 2 = 1$.\nThe answer is: 1", 136 | ), 137 | ( 138 | "The solution to the equation $y - \\frac { y - 1 } { 2 } = - \\frac { y + 2 } { 5 }$ is ____ ?", 139 | "Let's think step by step\nTo eliminate the denominator, we have $10 y - 5 ( y - 1 ) = - 2 ( y + 2 )$. Expanding the brackets gives $10 y - 5 y + 5 = - 2 y - 4$. Rearranging terms gives $10 y - 5 y + 2 y = - 4 - 5$, which simplifies to $7 y = - 9$. Dividing both sides by 7 gives $y = - \\frac { 9 } { 7 }$.\nThe answer is: y = - \\frac { 9 } { 7 }", 140 | ), 141 | ( 142 | "If $( m + 4 ) ^ 2 + | n - 3 | = 0$, then $\\frac { 1 } { 2 } m - n$ = ____?", 143 | "Let's think step by step\n$\\because ( m + 4 ) ^ 2 + | n - 3 | = 0$, $\\therefore m + 4 = 0$, $n - 3 = 0$, which means $m = - 4$, $n = 3$. 
Then the original expression equals $- 2 - 3 = - 5$.\nThe answer is: - 5", 144 | ), 145 | ( 146 | "Given a quadratic equation in one variable $x$, $x ^ 2 + x + m = 0$, with one root being $x = 1$, what is the other root of this equation?", 147 | "Let's think step by step\nSuppose the quadratic equation in one variable about $x$ is $x ^ 2 + x + m = 0$, and $\\alpha$ is another real root of the equation. Since one real root of the quadratic equation in one variable about $x$ is $1$, we have $\\alpha + 1 = - 1$. Therefore, $\\alpha = - 2$.\nThe answer is: - 2", 148 | ), 149 | ( 150 | "The parabola $y = - 5 { x } ^ 2 + 1$ is translated $2$ units upward and $1$ unit to the left, resulting in the parabola _____.", 151 | "Let's think step by step\nThe parabola $y = - 5 { x } ^ 2 + 1$ is first shifted upward by 2 units, resulting in $y = - 5 { x } ^ 2 + 3$. Then it is shifted left by 1 unit, resulting in $y = - 5 {( x + 1 )} ^ 2 + 3$.\nThe answer is: y = - 5 ( x + 1 ) ^ { 2 } + 3", 152 | ), 153 | ( 154 | "If the radical $\\sqrt { x - 8 }$ is defined, then the range of real numbers for $x$ is ____ ?", 155 | "Let's think step by step\nSince the radical $\\sqrt { x - 8 }$ is defined, therefore $x - 8 \\ge 0$, which implies $x \\ge 8$.\nThe answer is: x \\ge 8", 156 | ), 157 | ( 158 | "If $a ^ { m } \\times a ^ { 2 } = a ^ { 7 }$, then the value of $m$ is ____?", 159 | "Let's think step by step\nAccording to the multiplication rule of powers with the same base: when multiplying powers with the same base, keep the base the same and add the exponents. 
We have $m + 2 = 7$, so solving for $m$ gives $m = 5$.\nThe answer is: 5", 160 | ), 161 | ( 162 | "If line segment $a$ and $b$ satisfy $\\frac { a } { b } = \\frac { 5 } { 2 }$, then the value of $\\frac { a - b } { b }$ is ____?", 163 | "Let's think step by step\n$\\because \\frac { a } { b } = \\frac { 5 } { 2 }$, $\\therefore$ we can assume $a = 5 k$, then $b = 2 k$, $\\therefore \\frac { a - b } { b } = \\frac { 5 k - 2 k } { 2 k } = \\frac { 3 } { 2 }$.\nThe answer is: \\frac { 3 } { 2 }", 164 | ), 165 | ] 166 | 167 | examples["minerva_math"] = [ 168 | ( 169 | "Find the domain of the expression $\\frac{\\sqrt{x-2}}{\\sqrt{5-x}}$.}", 170 | "The expressions inside each square root must be non-negative.\nTherefore, $x-2 \\ge 0$, so $x\\ge2$, and $5 - x \\ge 0$, so $x \\le 5$.\nAlso, the denominator cannot be equal to zero, so $5-x>0$, which gives $x<5$.\nTherefore, the domain of the expression is $\\boxed{[2,5)}$.", 171 | ), 172 | ( 173 | "If $\\det \\mathbf{A} = 2$ and $\\det \\mathbf{B} = 12,$ then find $\\det (\\mathbf{A} \\mathbf{B}).$", 174 | "We have that $\\det (\\mathbf{A} \\mathbf{B}) = (\\det \\mathbf{A})(\\det \\mathbf{B}) = (2)(12) = \\boxed{24}.$", 175 | ), 176 | ( 177 | "Terrell usually lifts two 20-pound weights 12 times. If he uses two 15-pound weights instead, how many times must Terrell lift them in order to lift the same total weight?", 178 | "If Terrell lifts two 20-pound weights 12 times, he lifts a total of $2\\cdot 12\\cdot20=480$ pounds of weight. If he lifts two 15-pound weights instead for $n$ times, he will lift a total of $2\\cdot15\\cdot n=30n$ pounds of weight. 
Equating this to 480 pounds, we can solve for $n$: \\begin{align*}\n30n&=480\\\\\\\n\\Rightarrow\\qquad n&=480/30=\\boxed{16}\n\\end{align*}", 179 | ), 180 | ( 181 | "If the system of equations\n\n\\begin{align*}\n6x-4y&=a,\\\\\\\n6y-9x &=b.\n\\end{align*}has a solution $(x, y)$ where $x$ and $y$ are both nonzero, find $\\frac{a}{b},$ assuming $b$ is nonzero.", 182 | "If we multiply the first equation by $-\\frac{3}{2}$, we obtain\n\n$$6y-9x=-\\frac{3}{2}a.$$Since we also know that $6y-9x=b$, we have\n\n$$-\\frac{3}{2}a=b\\Rightarrow\\frac{a}{b}=\\boxed{-\\frac{2}{3}}.$$", 183 | ), 184 | ] 185 | 186 | examples["aqua"] = [ 187 | ( 188 | "John found that the average of 15 numbers is 40. If 10 is added to each number then the mean of the numbers is?\nAnswer Choices: (A) 50 (B) 45 (C) 65 (D) 78 (E) 64", 189 | "If 10 is added to each number, then the mean of the numbers also increases by 10. So the new mean would be 50. The answer is (A).", 190 | ), 191 | ( 192 | "If a / b = 3/4 and 8a + 5b = 22,then find the value of a.\nAnswer Choices: (A) 1/2 (B) 3/2 (C) 5/2 (D) 4/2 (E) 7/2", 193 | "a / b = 3/4, then b = 4a / 3. So 8a + 5(4a / 3) = 22. This simplifies to 8a + 20a / 3 = 22, which means 44a / 3 = 22. So a is equal to 3/2. The answer is (B).", 194 | ), 195 | ( 196 | "A person is traveling at 20 km/hr and reached his destiny in 2.5 hr then find the distance?\nAnswer Choices: (A) 53 km (B) 55 km (C) 52 km (D) 60 km (E) 50 km", 197 | "The distance that the person traveled would have been 20 km/hr * 2.5 hrs = 50 km. The answer is (E).", 198 | ), 199 | ( 200 | "How many keystrokes are needed to type the numbers from 1 to 500?\nAnswer Choices: (A) 1156 (B) 1392 (C) 1480 (D) 1562 (E) 1788", 201 | "There are 9 one-digit numbers from 1 to 9. There are 90 two-digit numbers from 10 to 99. There are 401 three-digit numbers from 100 to 500. 9 + 90(2) + 401(3) = 1392. 
The answer is (B).", 202 | ), 203 | ] 204 | examples["sat_math"] = [ 205 | ( 206 | "If $\\frac{x-1}{3}=k$ and $k=3$, what is the value of $x$ ? \nAnswer Choices: (A) 2 (B) 4 (C) 9 (D) 10", 207 | "If k = 3, then x - 1 = 3 * 3, therefore, x - 1 = 9 and x = 10. The answer is D", 208 | ), 209 | ( 210 | "For $i=\\sqrt{-1}$, what is the sum $(7+3 i)+(-8+9 i)$ ? \nAnswer Choices: (A) $-1+12 i$ (B) $-1-6 i$ (C) $15+12 i$ (D) $15-6 i$", 211 | "For (7+3 i)+(-8+9 i), the real part is 7 + (-8) = -1, the imaginary part is 3 i + 9 i = 12 i. The answer is A", 212 | ), 213 | ( 214 | "On Saturday afternoon, Armand sent $m$ text messages each hour for 5 hours, and Tyrone sent $p$ text messages each hour for 4 hours. Which of the following represents the total number of messages sent by Armand and Tyrone on Saturday afternoon?\nAnswer Choices: (A) $9 m p$ (B) $20 m p$ (C) $5 m+4 p$ (D) $4 m+5 p$", 215 | "Armand texts m messages each hour for 5 hours, which leads to 5m messages. Tyrone texts p messages each hour for 4 hours, which leads to 4p messages. The total is 5m + 4p. The answer is C.", 216 | ), 217 | ( 218 | "$$\\begin{array}{r}3 x+4 y=-23 \\\\2 y-x=-19\\end{array}$$What is the solution $(x, y)$ to the system of equations above?\nAnswer Choices: (A) $(-5,-2)$ (B) $(3,-8)$ (C) $(4,-6)$ (D) $(9,-6)$", 219 | "By solving this system of equations, we find that x = 3 and y = -8. The answer is B.", 220 | ), 221 | ] 222 | examples["mmlu_mathematics"] = [ 223 | ( 224 | "Simplify and write the result with a rational denominator: $$\\sqrt{\\sqrt[3]{\\sqrt{\\frac{1}{729}}}}$$\nAnswer Choices: (A) \\frac{3\\sqrt{3}}{3} (B) \\frac{1}{3} (C) \\sqrt{3} (D) \\frac{\\sqrt{3}}{3}", 225 | "Factoring $729=3^6$ and combining the roots $\\frac{1}{2}\\frac{1}{3}\\frac{1}{2}=\\frac{1}{12}$, we get that $\\sqrt{\\sqrt[3]{\\sqrt{\\frac{1}{729}}}}=\\left(\\frac{1}{3^6}\\right)^{\\frac{1}{12}}=\\frac{1}{3^{\\frac{1}{2}}}=\\frac{\\sqrt{3}}{3}$.
The answer is (D).", 226 | ), 227 | ( 228 | "Five thousand dollars compounded annually at an $x\\%$ interest rate takes six years to double. At the same interest rate, how many years will it take $\\$300$ to grow to $\\$9600$?\nAnswer Choices: (A) 12 (B) 1 (C) 30 (D) 5", 229 | "To go from $\\$300$ to $\\$9600$, the value must go up by a factor of $9600/300=32=2^5$. Since at this interest rate it takes six years for it to double, it will take $5*6=30$ years to grow to $\\$9600$. The answer is (C).", 230 | ), 231 | ( 232 | "Ten students take a biology test and receive the following scores: 45, 55, 50, 70, 65, 80, 40, 90, 70, 85. What is the mean of the students’ test scores?\nAnswer Choices: (A) 55 (B) 60 (C) 62 (D) 65", 233 | "There are 10 students and the sum of their scores is $45 + 55 + 50 + 70 + 65 + 80 + 40 + 90 + 70 + 85 = 650$, the mean is $650/10=65$. The answer is (D).", 234 | ), 235 | ( 236 | "The variable $x$ varies directly as the square of $y$, and $y$ varies directly as the cube of $z$. If $x$ equals $-16$ when $z$ equals 2, what is the value of $x$ when $z$ equals $\\frac{1}{2}$?\nAnswer Choices: (A) -1 (B) 16 (C) -\\frac{1}{256} (D) \\frac{1}{16}", 237 | "We know that $x \\propto y^2$ and $y \\propto z^3$, so $x = k z^6$ for some constant $k$. Plugging in for $x=-16$ and $z=2$, the constant value is $k=\\frac{x}{z^6}=\\frac{-16}{64}=-\\frac{1}{4}$. So, when $z=\\frac{1}{2}$, the value of $x$ is $x=kz^6=-\\frac{1}{4}\\frac{1}{2^6}=-\\frac{1}{256}$. The answer is (C).", 238 | ), 239 | ( 240 | "Joe was in charge of lights for a dance. The red light blinks every two seconds, the yellow light every three seconds, and the blue light every five seconds. If we include the very beginning and very end of the dance, how many times during a seven minute dance will all the lights come on at the same time?
(Assume that all three lights blink simultaneously at the very beginning of the dance.)\nAnswer Choices: (A) 3 (B) 15 (C) 6 (D) 5", 241 | "The least common multiple of 2, 3 and 5 is 30, so all three lights come on together every 30 seconds, i.e. twice per minute. During a 7-minute dance, counting the very beginning, all three lights will come on at the same time $2*7+1=15$ times. The answer is (B).", 242 | ), 243 | ] 244 | examples["mmlu_physics"] = [ 245 | ( 246 | "A microwave oven is connected to an outlet, 120 V, and draws a current of 2 amps. At what rate is energy being used by the microwave oven?\nAnswer Choices: (A) 10 W (B) 30 W (C) 60 W (D) 240 W", 247 | "Rate of energy usage is known as power; in a dissipative electrical circuit, power is given by voltage times current. So in our case, the power is 120 V times 2 amps, or 240 W. The answer is (D).", 248 | ), 249 | ( 250 | "A point charge, Q = +1 mC, is fixed at the origin. How much work is required to move a charge, Q = +8 µC, from the point (0, 4 meters) to the point (3 meters, 0)?\nAnswer Choices: (A) 3.5 J (B) 6.0 J (C) 22.5 J (D) 40 J", 251 | "To calculate the work required to move a charge from one location to another in a fixed electric field, it is enough to calculate the potential difference between the two locations. Here, the potential only depends on the distance between the charges; it’s $k q_1 q_2 / r$, where $k$ is Coulomb’s constant. Plugging in values $q_1 = $ 1 mC, $q_2 = 8 \\mu$ C, gives the answer as 5.992 J, which rounds to 6 J. The answer is (B).", 252 | ), 253 | ( 254 | "Which of the following conditions will ensure that angular momentum is conserved? I. Conservation of linear momentum II. Zero net external force III. Zero net external torque.\nAnswer Choices: (A) I and II only (B) I and III only (C) II and III only (D) III only", 255 | "Torque is defined as the rate of change of angular momentum; if there is zero net external torque, angular momentum is conserved. The answer is (D).", 256 | ), 257 | ( 258 | "A photocell of work function ϕ = 2eV is connected to a resistor in series.
Light of frequency f = 1 × 10^15 Hz hits a metal plate of the photocell. If the power of the light is P = 100 W, what is the current through the resistor?\nAnswer Choices: (A) 2:00 AM (B) 6:00 AM (C) 12:00 AM (D) 24 A", 259 | "The only answer above which has units of current is D, 24 A. The answer is (D).", 260 | ), 261 | ( 262 | "A pipe full of air is closed at one end. A standing wave is produced in the pipe, causing the pipe to sound a note. Which of the following is a correct statement about the wave’s properties at the closed end of the pipe?\nAnswer Choices: (A) The pressure is at a node, but the particle displacement is at an antinode. (B) The pressure is at an antinode, but the particle displacement is at a node. (C) The pressure and the particle displacement are both at nodes. (D) The pressure and the particle displacement are both at antinodes.", 263 | "At the closed end of the pipe, the particles cannot have any net displacement because the pipe closure stops them. So the particle displacement is at a node. This closure also causes the pressure to be maximal, i.e. an antinode. The answer is (B).", 264 | ), 265 | ] 266 | examples["mmlu_chemistry"] = [ 267 | ( 268 | "Which of the following is considered an acid anhydride?\nAnswer Choices: (A) HCl (B) H2SO3 (C) SO2 (D) Al(NO3)3", 269 | "An acid anhydride is a compound that is derived by removing water from an acid. The chemical formula for water is H2O, which means that we need to determine which of these options, when combined with H2O, forms an acid. SO2, or Sulfur dioxide, when combined with H2O, makes H2SO4, or sulfuric acid. The answer is (C).", 270 | ), 271 | ( 272 | "Which of the following is expected to be a polar molecule?\nAnswer Choices: (A) PCl4F (B) BF3 (C) CO2 (D) Si(CH3)4", 273 | "A polar molecule is one that has a slightly positive charge on one end of the molecule and a slightly negative charge on the other end. 
Boron trifluoride (BF3) has Boron as the center atom and three fluorine atoms attached to it; it is trigonal planar and symmetric, so it is nonpolar. Carbon Dioxide (CO2) has Carbon as the central atom with double bonds to two Oxygen atoms - this is also symmetrical and therefore nonpolar. The same is the case for tetramethyl silane (Si(CH3)4), which is a Silicon atom surrounded by four methyl groups. The structure of PCl4F is that Phosphorus is the central atom, attached to four chlorines and one fluorine atom. This is asymmetrical, and therefore has a net dipole and is expected to be a polar molecule. The answer is (A).", 274 | ), 275 | ( 276 | "From the solubility rules, which of the following is true?\nAnswer Choices: (A) All chlorides, bromides, and iodides are soluble (B) All sulfates are soluble (C) All hydroxides are soluble (D) All ammonium-containing compounds are soluble", 277 | "The chlorides, bromides, and iodides of lead, silver, and mercury are not soluble in water. This rules out (A). The sulfates of lead, barium, and calcium are not soluble in water, which rules out (B). The hydroxides of any metal besides sodium, potassium, ammonium, calcium, and barium are insoluble. This rules out (C). Typically ammonium ions indicate a soluble ionic substance. The answer is (D).", 278 | ), 279 | ( 280 | "A new compound is synthesized and found to be a monoprotic acid with a molar mass of 248 g/mol. When 0.0050 mol of this acid are dissolved in 0.500 L of water, the pH is measured as 3.89. What is the pKa of this acid?\nAnswer Choices: (A) 3.89 (B) 7.78 (C) 5.78 (D) 2.33", 281 | "Recall that $[A^{-}] = [H^{+}]$. Here, this is equal to $10^{-3.89}$. Then we have $K_{a} = \\frac{[H^{+}][A^{-}]}{[HA]} = \\frac{10^{-3.89} \\cdot 10^{-3.89}}{10^{-2}}$. The resulting exponent is $-3.89 + (-3.89) - (-2) = -5.78$, therefore $K_a = 10^{-5.78}$. The $pK_a$ is the negative log of $K_a$, which is equal to $5.78$.
The answer is (C).", 282 | ), 283 | ( 284 | "A solution contains 2.00 mole of acetic acid, CH3COOH, and 1.00 mole of calcium acetate, Ca(CH3COO)2. The solution is able to resist the addition of a small amount of strong acid or strong base with only minor changes in the pH of the solution. Larger quantities of strong acid or strong base can cause a significant change in pH. How many moles of nitric acid, HNO3, may be added before the pH begins to change significantly?\nAnswer Choices: (A) 0.500 mole (B) 1.00 mole (C) 2.00 mole (D) 3.00 mole", 285 | "We would like to compute the buffer capacity of this solution. First we write the equation for the ionization of the weak acid, in this case of acetic acid. $CH_{3}COOH (aq) + H_{2}O \\rightarrow H_{3}O^{+} + CH_{3}COO^{-}$. The conjugate base is therefore the acetate ion. The added strong acid, nitric acid, will react with the conjugate base. Therefore the maximum amount of acid that can be added will be equal to the amount of acetate ion, or 2 moles. The answer is (C).", 286 | ), 287 | ] 288 | examples["mmlu_biology"] = [ 289 | ( 290 | "In animal cells, which of the following represents the most likely pathway that a secretory protein takes as it is synthesized in a cell?\nAnswer Choices: (A) Plasma membrane–Golgi apparatus–ribosome–secretory vesicle–rough ER (B) Ribosome–Golgi apparatus–rough ER–secretory vesicle–plasma membrane (C) Plasma membrane–Golgi apparatus–ribosome–secretory vesicle–rough ER (D) Ribosome–rough ER–Golgi apparatus–secretory vesicle–plasma membrane", 291 | "Protein synthesis starts at the ribosome, so we can eliminate (A) and (C). The ribosome is often in the endoplasmic reticulum and moves from there to the Golgi apparatus, where it is modified and packaged into a vesicle. The vesicle then floats to the plasma membrane and is secreted. The answer is (D).", 292 | ), 293 | ( 294 | "A mutation in a bacterial enzyme changed a previously polar amino acid into a nonpolar amino acid.
This amino acid was located at a site distant from the enzyme’s active site. How might this mutation alter the enzyme’s substrate specificity?\nAnswer Choices: (A) By changing the enzyme’s pH optimum (B) By changing the enzyme’s location in the cell (C) By changing the shape of the protein (D) An amino acid change away from the active site cannot alter the enzyme’s substrate specificity.", 295 | "A change in an amino acid leads to a change in the primary structure of the protein. A change in the primary structure may lead to a change in the secondary and the tertiary structure of the protein. A change in the tertiary structure means a change in the shape of the protein, so (C) has to be correct. Since the change does not affect the active site of the enzyme, we do not expect the activity of the enzyme to be affected. The answer is (C).", 296 | ), 297 | ( 298 | "Which of the following is not a way to form recombinant DNA?\nAnswer Choices: (A) Translation (B) Conjugation (C) Specialized transduction (D) Transformation", 299 | "The introduction of foreign DNA or RNA into bacteria or eukaryotic cells is a common technique in molecular biology and scientific research. There are multiple ways foreign DNA can be introduced into cells including transformation, transduction, conjugation, and transfection. In contrast, (A) is not a way to form DNA: during translation the ribosomes synthesize proteins from RNA. The answer is (A).", 300 | ), 301 | ( 302 | "Homologous structures are often cited as evidence for the process of natural selection. All of the following are examples of homologous structures EXCEPT\nAnswer Choices: (A) the wings of a bird and the wings of a bat (B) the flippers of a whale and the arms of a man (C) the pectoral fins of a porpoise and the flippers of a seal (D) the forelegs of an insect and the forelimbs of a dog", 303 | "Homologous structures are similar physical features in organisms that share a common ancestor ​​but different functions. 
Comparisons (B) and (C) are clearly homologous because they share a common ancestor and the structures serve different purposes. Bat wings and bird wings are also homologous; although they are both wings, the forelimbs serve different purposes. Insects and dogs are very distantly related, since one is an invertebrate while the other is a vertebrate, and the forelimbs serve the same purpose, so they are not homologous. The answer is (D).", 304 | ), 305 | ( 306 | "Which of the following is not known to be involved in the control of cell division?\nAnswer Choices: (A) Cyclins (B) Protein kinases (C) Checkpoints (D) Fibroblast cells", 307 | "Normal cells move through the cell cycle in a regulated way. At the checkpoint stage, they use information about their own internal state and cues from the environment around them to decide whether to proceed with cell division. Cues like these act by changing the activity of core cell cycle regulators inside the cell. The most common regulators are cyclins and cyclin-dependent kinases. Fibroblast cells do not play any role in cell division. The answer is (D).", 308 | ), 309 | ] 310 | examples["mmlu_computer"] = [ 311 | ( 312 | "Which of the following is an example of the use of a device on the Internet of Things (IoT) ?\nAnswer Choices: (A) A car alerts a driver that it is about to hit an object. (B) A hiker uses a GPS watch to keep track of her position. (C) A refrigerator orders milk from an online delivery service when the milk in the refrigerator is almost gone. (D) A runner uses a watch with optical sensors to monitor his heart rate.", 313 | "The term Internet of Things (IoT) refers to common devices which are connected to the internet, enabling new functionality. Choice A is incorrect because it does not describe an internet-connected device. In choice B, the watch is only described as having GPS functionality but no internet connectivity.
Choice C describes a common device (a refrigerator) which has internet connectivity enabling new functionality (online ordering). Choice D does not mention internet connectivity for the watch, only optical sensors. The answer is (C).", 314 | ), 315 | ( 316 | "Many Web browsers allow users to open anonymous windows. During a browsing session in an anonymous window, the browser does not record a browsing history or a list of downloaded files. When the anonymous window is exited, cookies created during the session are deleted. Which of the following statements about browsing sessions in an anonymous window is true?\nAnswer Choices: (A) The activities of a user browsing in an anonymous window will not be visible to people who monitor the user's network, such as the system administrator. (B) Items placed in a Web store's shopping cart for future purchase during the anonymous browsing session will not be saved on the user's computer. (C) A user will not be able to log in to e-mail or social media accounts during the anonymous browsing session. (D) A user browsing in an anonymous window will be protected from viruses launched from any web sites visited or files downloaded.", 317 | "Choice A is incorrect as it only describes network traffic, which an anonymous browser does not change. Choice B is correct as it correctly describes how an anonymous browser will prevent saving data on the user’s computer after the session is ended. Choice C is incorrect because an anonymous browser will not prevent logging in to email or social media accounts. Choice D is incorrect because an anonymous browser in itself performs no virus protection. The answer is (B).", 318 | ), 319 | ( 320 | 'What is the output of "abc"[::-1] in Python 3? \nAnswer Choices: (A) Error (B) abc (C) cba (D) c', 321 | 'We know that the slicing operator [::-1] takes all of the elements in the string in reverse order, so we reverse the order of the string "abc", resulting in "cba". 
The answer is (C).', 322 | ), 323 | ( 324 | 'In the program below, the initial value of X is 5 and the initial value of Y is 10.\nIF (X < 0){\n DISPLAY ("Foxtrot")\n} ELSE {\n IF (X > Y){\n DISPLAY ("Hotel")\n } ELSE {\n IF (Y > 0){\n DISPLAY ("November")\n } ELSE {\n DISPLAY ("Yankee")\n }\n }\n}\nWhat is displayed as a result of running the program?\nAnswer Choices: (A) Foxtrot (B) Hotel (C) November (D) Yankee', 325 | 'Because X has the value 5, the first conditional IF (X < 0) is false, so we move to the first ELSE clause. Because X is 5 and Y is 10, the second conditional IF (X > Y) is false, so we move to the following ELSE clause. Since Y is 10, the conditional IF (Y > 0) is true, so the command DISPLAY ("November") is executed. The answer is (C).', 326 | ), 327 | ( 328 | "A list of numbers has n elements, indexed from 1 to n. The following algorithm is intended to display the number of elements in the list that have a value greater than 100. The algorithm uses the variables count and position. 
Steps 3 and 4 are missing.\n Step 1: Set count to 0 and position to 1.\n Step 2: If the value of the element at index position is greater than 100, increase the value of count by 1.\n Step 3: (missing step)\n Step 4: (missing step)\n Step 5: Display the value of count.\nWhich of the following could be used to replace steps 3 and 4 so that the algorithm works as intended?\nAnswer Choices: (A) Step 3: Increase the value of position by 1.\n Step 4: Repeat steps 2 and 3 until the value of count is greater than 100.\n(B) Step 3: Increase the value of position by 1.\n Step 4: Repeat steps 2 and 3 until the value of position is greater than n.\n(C) Step 3: Repeat step 2 until the value of count is greater than 100.\n Step 4: Increase the value of position by 1.\n(D) Step 3: Repeat step 2 until the value of position is greater than n.\n Step 4: Increase the value of count by 1.", 329 | "Choice A is incorrect, because its Step 4 has an incorrect termination condition, stopping when count is greater than 100. We need to stop after inspecting all elements in the list. Choice B is correct because it correctly increments both count and position, and correctly repeats these steps and terminates when all elements in the list have been inspected. Choice C is incorrect because it incorrectly increments the variable count until its value is greater than 100, regardless of the elements in the list. Choice D is incorrect because its step 3 does not increment the value of position, so it will repeat forever. 
The answer is (B).", 330 | ), 331 | ] 332 | # mammoth 333 | examples["mmlu_stem"] = [ 334 | ( 335 | "Simplify and write the result with a rational denominator: $$\\sqrt{\\sqrt[3]{\\sqrt{\\frac{1}{729}}}}$$\nAnswer Choices: (A) \\frac{3\\sqrt{3}}{3} (B) \\frac{1}{3} (C) \\sqrt{3} (D) \\frac{\\sqrt{3}}{3}", 336 | "Factoring $729=3^6$ and combining the roots $\\frac{1}{2}\\frac{1}{3}\\frac{1}{2}=\\frac{1}{12}$, we get that $\\sqrt{\\sqrt[3]{\\sqrt{\\frac{1}{729}}}}=\\left(\\frac{1}{3^6}\\right)^{\\frac{1}{12}}=\\frac{1}{3^{\\frac{1}{2}}}=\\frac{\\sqrt{3}}{3}$. The answer is (D).", 337 | ), 338 | ( 339 | "In animal cells, which of the following represents the most likely pathway that a secretory protein takes as it is synthesized in a cell?\nAnswer Choices: (A) Plasma membrane–Golgi apparatus–ribosome–secretory vesicle–rough ER (B) Ribosome–Golgi apparatus–rough ER–secretory vesicle–plasma membrane (C) Plasma membrane–Golgi apparatus–ribosome–secretory vesicle–rough ER (D) Ribosome–rough ER–Golgi apparatus–secretory vesicle–plasma membrane", 340 | "Protein synthesis starts at the ribosome, so we can eliminate (A) and (C). The ribosome is often in the endoplasmic reticulum and moves from there to the Golgi apparatus, where it is modified and packaged into a vesicle. The vesicle then floats to the plasma membrane and is secreted. The answer is (D).", 341 | ), 342 | ( 343 | "A microwave oven is connected to an outlet, 120 V, and draws a current of 2 amps. At what rate is energy being used by the microwave oven?\nAnswer Choices: (A) 10 W (B) 30 W (C) 60 W (D) 240 W", 344 | "Rate of energy usage is known as power; in a dissipative electrical circuit, power is given by voltage times current. So in our case, the power is 120 V times 2 amps, or 240 W.
The answer is (D).", 345 | ), 346 | ( 347 | "Which of the following is considered an acid anhydride?\nAnswer Choices: (A) HCl (B) H2SO3 (C) SO2 (D) Al(NO3)3", 348 | "An acid anhydride is a compound that is derived by removing water from an acid. The chemical formula for water is H2O, which means that we need to determine which of these options, when combined with H2O, forms an acid. SO2, or Sulfur dioxide, when combined with H2O, makes H2SO4, or sulfuric acid. The answer is (C).", 349 | ), 350 | ( 351 | 'What is the output of "abc"[::-1] in Python 3? \nAnswer Choices: (A) Error (B) abc (C) cba (D) c', 352 | 'We know that the slicing operator [::-1] takes all of the elements in the string in reverse order, so we reverse the order of the string "abc", resulting in "cba". The answer is (C).', 353 | ), 354 | ] 355 | examples["gaokao"] = [ 356 | ( 357 | "已知 $\\alpha, \\beta, \\gamma$ 是互不相同的锐角, 则在 $\\sin \\alpha \\cos \\beta, \\sin \\beta \\cos \\gamma, \\sin \\gamma \\cos \\alpha$ 三个值中, 大于 $\\frac{1}{2}$ 的个数的最大值是 ( )\n从以下选项中选择:\n(A) 0\n(B) 1\n(C) 2\n(D) 3", 358 | "1. 如果 $\\alpha, \\beta, \\gamma$ 均小于 $60^\\circ$,那么他们的正弦值都小于 $\\frac{1}{2}$,因此三个值中不可能有大于 $\\frac{1}{2}$ 的值。\n2. 如果有一个角大于 $60^\\circ$,假设为 $\\alpha$,那么对应的正弦值大于 $\\frac{1}{2}$。此时,由于三角形内角和为 $180^\\circ$,所以 $\\beta + \\gamma < 120^\\circ$。这意味着 $\\beta, \\gamma$ 的余弦值均大于 $\\frac{1}{2}$,所以此时 $\\sin \\alpha \\cos \\beta > \\frac{1}{2}, \\sin \\beta \\cos \\gamma > \\frac{1}{2}$。\n3. 如果有两个角大于 $60^\\circ$,例如 $\\alpha$ 和 $\\beta$,那么由于三角形内角和为 $180^\\circ$,我们可以得到 $\\gamma < 60^\\circ$,此时 $\\sin \\gamma < \\frac{1}{2}$。由于 $\\alpha$ 和 $\\beta$ 的余弦值都小于 $\\frac{1}{2}$,因此三个值中不可能有大于 $\\frac{1}{2}$ 的值。\n4. 
如果三个角都大于 $60^\\circ$,显然不符合题意。\n综上所述,当有一个角大于 $60^\\circ$ 时,大于 $\\frac{1}{2}$ 的个数的最大值是 2。\n答案是 C", 359 | ), 360 | ( 361 | "正方体 $A B C D-A_{1} B_{1} C_{1} D_{1}$ 中, $B B_{1}$ 与平面 $A C D_{1}$ 所成角的余弦值为 ( )\n从以下选项中选择:\n(A) $\\frac{\\sqrt{2}}{3}$\n(B) $\\frac{\\sqrt{3}}{3}$\n(C) $\\frac{2}{3}$\n(D) $\\frac{\\sqrt{6}}{3}$", 362 | "设上下底面的中心分别为 $\\mathrm{O}_{1}, \\mathrm{O}$, 设正方体的棱长等于 1 , 则 $O_{1} O$ 与平面 $A C D_{1}$ 所成角就是 $B B_{1}$ 与平面 $A C D_{1}$ 所成角, 即 $\\angle O_{1} O D_{1}$, 直角三角形 $\\mathrm{OO}_{1} \\mathrm{D}_{1}$ 中, $\\cos \\angle \\mathrm{O}_{1} \\mathrm{OD}_{1}=\\frac{\\mathrm{O}_{1} \\mathrm{O}}{\\mathrm{OD}_{1}}=\\frac{\\frac{1}{\\sqrt{6}}}{2}=\\frac{\\sqrt{6}}{3}$. \n答案是 C", 363 | ), 364 | ( 365 | "设函数 $f(x)=\\left\\{\\begin{array}{ll}1+\\log _{2}(2-x), & x<1 \\ 2^{x-1}, & x \\geqslant 1,\\end{array}\\right.$ 则 $f(-2)+f\\left(\\log _{2} 12\\right)=$ ( )\n从以下选项中选择:\n(A) 3\n(B) 6\n(C) 9\n(D) 12", 366 | "首先,我们可以根据定义计算 $f(-2)$ 和 $f(\\log_2 12)$:\n$f(-2)=1+\\log_2(2-(-2))=1+\\log_2 4=3$\n$f(\\log_2 12)=2^{\\log_2 12-1}=6$\n因此,$f(-2)+f(\\log_2 12)=3+6=9$。\n答案是 C", 367 | ), 368 | ( 369 | "已知函数 $f(x)=a x^{3}-3 x^{2}+1$, 若 $f(x)$ 存在唯一的零点 $x_{0}$, 且 $x_{0}>$ 0 , 则实数 $\\mathrm{a}$ 的取值范围是 ( )\n从以下选项中选择:\n(A) $(1,+\\infty)$\n(B) $(2,+\\infty)$\n(C) $(-\\infty,-1)$\n(D) $(-\\infty,-2)$", 370 | "首先,我们可以通过求出函数的导函数 $f'(x)$ 来判断函数在 $x>0$ 区间内的单调性。在这里,我们求出导函数 $f'(x)$ 为 $f'(x)=3ax^2-6x$。\n然后,我们需要求出导函数 $f'(x)$ 的零点,以确定函数 $f(x)$ 在 $x>0$ 区间内的单调性。导函数 $f'(x)$ 的零点为 $x=0$ 和 $x=\\frac{2}{\\sqrt{a}}$。注意到 $x>0$,所以我们得到 $a<0$。此外,由于函数 $f(x)$ 在 $x=0$ 处的函数值为 $1$,因此不能有 $a=\\frac{4}{3}$。\n综上所述,当 $a$ 的取值范围为 $a<-\\frac{4}{3}$ 时,函数 $f(x)$ 在 $x>0$ 区间内是单调递减的,此时存在唯一的零点 $x_0$。因此,答案为 $(-\\infty,-2)$。\n答案是 D", 371 | ), 372 | ( 373 | "设 $\\left\\{a_{n}\\right\\}$ 是公差不为 0 的无穷等差数列, 则“ $\\left\\{a_{n}\\right\\}$ 为递增数列”是“存在正整数 $N_{0}$, 当 $n>N_{0}$ 时, $a_{n}>0$ ”的 ( )\n从以下选项中选择:\n(A) 充分而不必要条件\n(B) 必要而不充分条件\n(C) 充分必要条件\n(D) 既不充分也不必要条件", 374 | "首先,我们可以通过举例来判断该条件是充分还是必要条件。如果一个数列递增,那么它的公差一定大于 0,也就是存在正整数 $N_{0}$,当 
$n>N_{0}$ 时,$a_{n}>0$。因此,“ $\\left\\{a_{n}\\right\\}$ 为递增数列”是“存在正整数 $N_{0}$, 当 $n>N_{0}$ 时, $a_{n}>0$ ”的必要条件。\n接下来,我们需要判断是否充分。也就是说,如果存在正整数 $N_{0}$,当 $n>N_{0}$ 时,$a_{n}>0$,那么能否得出“ $\\left\\{a_{n}\\right\\}$ 为递增数列”这一结论。答案是肯定的。因为如果 $a_{n}>0$,那么 $a_{n+1}-a_{n}>0$,即公差大于 0,因此该数列是递增的。因此,该条件是充分条件。\n综上所述,选项为 (C) 充分必要条件。\n答案是 C", 375 | ), 376 | ] 377 | 378 | return examples 379 | --------------------------------------------------------------------------------