├── images
│   ├── deer.png
│   ├── deer-ds.png
│   ├── deer-main.png
│   ├── deer-res.png
│   ├── WechatIMG.jpeg
│   └── deer-qwen3.png
├── utils
│   ├── data_loader.py
│   ├── math_normalization.py
│   ├── utils.py
│   ├── parser.py
│   └── grader.py
├── prompts
│   └── qwen-instruct
│       ├── aime.py
│       ├── amc.py
│       ├── gpqa.py
│       ├── gsm8k.py
│       ├── math.py
│       ├── aime25.py
│       ├── minerva.py
│       └── olympiadbench.py
├── bashes
│   ├── bash-vanilla-deer.sh
│   ├── bash-check-correct.sh
│   ├── bash-vllm-deer-qwen3.sh
│   └── bash-vllm-deer.sh
├── requirements.txt
├── LICENSE
├── README.md
├── check.py
├── vanilla_deer.py
├── data
│   ├── aime25
│   │   └── test.jsonl
│   └── amc
│       └── test.jsonl
├── vllm-deer-qwen3.py
└── vllm-deer.py

/images/deer.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/iie-ycx/DEER/HEAD/images/deer.png
--------------------------------------------------------------------------------
/images/deer-ds.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/iie-ycx/DEER/HEAD/images/deer-ds.png
--------------------------------------------------------------------------------
/images/deer-main.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/iie-ycx/DEER/HEAD/images/deer-main.png
--------------------------------------------------------------------------------
/images/deer-res.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/iie-ycx/DEER/HEAD/images/deer-res.png
--------------------------------------------------------------------------------
/images/WechatIMG.jpeg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/iie-ycx/DEER/HEAD/images/WechatIMG.jpeg
--------------------------------------------------------------------------------
/images/deer-qwen3.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/iie-ycx/DEER/HEAD/images/deer-qwen3.png
-------------------------------------------------------------------------------- /prompts/qwen-instruct/aime.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | system_prompt = "Please reason step by step, and put your final answer within \\boxed{}." 4 | 5 | few_shot_prompt = "" 6 | 7 | question_format = """{question}""" -------------------------------------------------------------------------------- /prompts/qwen-instruct/amc.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | system_prompt = "Please reason step by step, and put your final answer within \\boxed{}." 4 | 5 | few_shot_prompt = "" 6 | 7 | question_format = """{question}""" -------------------------------------------------------------------------------- /prompts/qwen-instruct/gpqa.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | system_prompt = "Please reason step by step, and put your final answer within \\boxed{}." 4 | 5 | few_shot_prompt = "" 6 | 7 | question_format = """{question}""" -------------------------------------------------------------------------------- /prompts/qwen-instruct/gsm8k.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | system_prompt = "Please reason step by step, and put your final answer within \\boxed{}." 4 | 5 | few_shot_prompt = "" 6 | 7 | question_format = """{question}""" -------------------------------------------------------------------------------- /prompts/qwen-instruct/math.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | system_prompt = "Please reason step by step, and put your final answer within \\boxed{}." 4 | 5 | few_shot_prompt = "" 6 | 7 | question_format = """{question}""" -------------------------------------------------------------------------------- /prompts/qwen-instruct/aime25.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | system_prompt = "Please reason step by step, and put your final answer within \\boxed{}." 4 | 5 | few_shot_prompt = "" 6 | 7 | question_format = """{question}""" -------------------------------------------------------------------------------- /prompts/qwen-instruct/minerva.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | system_prompt = "Please reason step by step, and put your final answer within \\boxed{}." 4 | 5 | few_shot_prompt = "" 6 | 7 | question_format = """{question}""" -------------------------------------------------------------------------------- /prompts/qwen-instruct/olympiadbench.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | system_prompt = "Please reason step by step, and put your final answer within \\boxed{}." 
4 | 5 | few_shot_prompt = "" 6 | 7 | question_format = """{question}""" -------------------------------------------------------------------------------- /bashes/bash-vanilla-deer.sh: -------------------------------------------------------------------------------- 1 | 2 | 3 | CUDA_VISIBLE_DEVICES='1' \ 4 | python ../vanilla_deer.py \ 5 | --model_name_or_path ./DeepSeek-R1-Distill-Qwen-14B \ 6 | --threshold 0.95 \ 7 | --max_len 16384 \ 8 | --dataset math \ 9 | -------------------------------------------------------------------------------- /bashes/bash-check-correct.sh: -------------------------------------------------------------------------------- 1 | 2 | python ../check.py \ 3 | --model_name_or_path "./DeepSeek-R1-Distill-Qwen-14B" \ 4 | --data_name "math" \ 5 | --generation_path "./outputs/DeepSeek-R1-Distill-Qwen-14B/math/greedy_p0.95_ratio0.9_len16385_temperature0.0_run_time1_no_thinking0_rep0_points1.jsonl" \ 6 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | # common 2 | vllm<=0.6.1 3 | tqdm 4 | datasets 5 | torch 6 | transformers 7 | python_dateutil 8 | flash_attn 9 | 10 | # math_eval 11 | sympy==1.12 12 | antlr4-python3-runtime==4.11.1 # ! The version needs to be compatible with sympy. 13 | word2number 14 | Pebble 15 | timeout-decorator 16 | latex2sympy2==1.9.1 -------------------------------------------------------------------------------- /bashes/bash-vllm-deer-qwen3.sh: -------------------------------------------------------------------------------- 1 | 2 | 3 | CUDA_VISIBLE_DEVICES=1 python ../vllm-deer-qwen3.py \ 4 | --model_name_or_path "./Qwen3-4B" \ 5 | --dataset_dir "./data/" \ 6 | --output_path "./outputs" \ 7 | --dataset "math" \ 8 | --threshold 0.95 \ 9 | --max_generated_tokens 16000 \ 10 | --think_ratio 0.8 \ 11 | --policy avg2 \ 12 | --batch_size 2000 \ 13 | --dtype bfloat16 \ 14 | --gpu-memory-utilization 0.9 \ 15 | -------------------------------------------------------------------------------- /bashes/bash-vllm-deer.sh: -------------------------------------------------------------------------------- 1 | 2 | 3 | CUDA_VISIBLE_DEVICES=1 python ../vllm-deer.py \ 4 | --model_name_or_path "./DeepSeek-R1-Distill-Qwen-14B" \ 5 | --dataset_dir "./data/" \ 6 | --output_path "./outputs" \ 7 | --dataset "math" \ 8 | --threshold 0.95 \ 9 | --max_generated_tokens 16000 \ 10 | --policy avg1 \ 11 | --think_ratio 0.6 \ 12 | --batch_size 2000 \ 13 | --dtype bfloat16 \ 14 | --gpu-memory-utilization 0.9 \ 15 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2025 chenxuYang 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 
14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /utils/data_loader.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import random 4 | from datasets import load_dataset, Dataset, concatenate_datasets 5 | from utils.utils import load_jsonl, lower_keys 6 | 7 | def load_data(data_name, split, data_dir='./data'): 8 | data_file = f"{data_dir}/{data_name}/{split}.jsonl" 9 | if os.path.exists(data_file): 10 | examples = list(load_jsonl(data_file)) 11 | else: 12 | if data_name == "math": 13 | dataset = load_dataset("competition_math", split=split, name="main", cache_dir=f"{data_dir}/temp") 14 | elif data_name == "theorem-qa": 15 | dataset = load_dataset("wenhu/TheoremQA", split=split) 16 | elif data_name == "gsm8k": 17 | dataset = load_dataset(data_name, split=split) 18 | elif data_name == "gsm-hard": 19 | dataset = load_dataset("reasoning-machines/gsm-hard", split="train") 20 | elif data_name == "svamp": 21 | # evaluate on training set + test set 22 | dataset = load_dataset("ChilleD/SVAMP", split="train") 23 | dataset = concatenate_datasets([dataset, load_dataset("ChilleD/SVAMP", split="test")]) 24 | elif data_name == "asdiv": 25 | dataset = load_dataset("EleutherAI/asdiv", split="validation") 26 | dataset = dataset.filter(lambda x: ";" not in x['answer']) # remove multi-answer examples 27 | elif data_name == "mawps": 28 | examples = [] 29 | # four sub-tasks; use a separate loop variable so the outer data_name is not clobbered 30 | for sub_name in ["singleeq", "singleop", "addsub", "multiarith"]: 31 | sub_examples = list(load_jsonl(f"{data_dir}/mawps/{sub_name}.jsonl")) 32 | for example in sub_examples: 33 | example['type'] = sub_name 34 | examples.extend(sub_examples) 35 | dataset = Dataset.from_list(examples) 36 | elif data_name == "finqa": 37 | dataset = load_dataset("dreamerdeo/finqa", split=split, name="main") 38 | dataset = dataset.select(random.sample(range(len(dataset)), 1000)) 39 | elif data_name == "tabmwp": 40 | examples = [] 41 | with open(f"{data_dir}/tabmwp/tabmwp_{split}.json", "r") as f: 42 | data_dict = json.load(f) 43 | examples.extend(data_dict.values()) 44 | dataset = Dataset.from_list(examples) 45 | dataset = dataset.select(random.sample(range(len(dataset)), 1000)) 46 | elif data_name == "bbh": 47 | examples = [] 48 | for sub_name in ["reasoning_about_colored_objects", "penguins_in_a_table",\ 49 | "date_understanding", "repeat_copy_logic", "object_counting"]: 50 | with open(f"{data_dir}/bbh/bbh/{sub_name}.json", "r") as f: 51 | sub_examples = json.load(f)["examples"] 52 | for example in sub_examples: 53 | example['type'] = sub_name 54 | examples.extend(sub_examples) 55 | dataset = Dataset.from_list(examples) 56 | else: 57 | raise NotImplementedError(data_name) 58 | 59 | examples = list(dataset) 60 | examples = [lower_keys(example) for example in examples] 61 | dataset = Dataset.from_list(examples) 62 | os.makedirs(f"{data_dir}/{data_name}", exist_ok=True) 63 | dataset.to_json(data_file) 64 | 65 | # add 'idx' in the first column 66 | if 'idx' 
not in examples[0]: 67 | examples = [{'idx': i, **example} for i, example in enumerate(examples)] 68 | 69 | # sort by idx 70 | examples = sorted(examples, key=lambda x: x['idx']) 71 | return examples -------------------------------------------------------------------------------- /utils/math_normalization.py: -------------------------------------------------------------------------------- 1 | # Part of the code is modified from the code snippets provided in "Solving Quantitative Reasoning Problems with Language Models" by Lewkowycz et al. 2 | import pdb 3 | import re 4 | import sympy 5 | import threading 6 | from sympy.parsing.latex import parse_latex 7 | from .parser import strip_string 8 | 9 | SUBSTITUTIONS = [ 10 | ('an ', ''), ('a ', ''), ('.$', '$'), ('\\$', ''), (r'\ ', ''), ('\%', '%'), 11 | (' ', ''), ('mbox', 'text'), (',\\text{and}', ','), 12 | ('\\text{and}', ','), ('\\text{m}', '\\text{}') 13 | ] 14 | REMOVED_EXPRESSIONS = [ 15 | 'square', 'ways', 'integers', 'dollars', 'mph', 'inches', 'ft', 16 | 'hours', 'km', 'units', '\\ldots', 'sue', 'points', 'feet', 17 | 'minutes', 'digits', 'cents', 'degrees', 'cm', 'gm', 'pounds', 18 | 'meters', 'meals', 'edges', 'students', 'childrentickets', 'multiples', 19 | '\\text{s}', '\\text{.}', '\\text{\ns}', '\\text{}^2', 20 | '\\text{}^3', '\\text{\n}', '\\text{}', r'\mathrm{th}', 21 | r'^\circ', r'^{\circ}', r'\;', r',\!', '{,}', '"', '\\dots' 22 | ] 23 | 24 | def is_integer(s): 25 | try: 26 | int(s) 27 | return True 28 | except ValueError: 29 | return False 30 | 31 | def normalize_final_answer(final_answer: str) -> str: 32 | """Normalize a final answer to a quantitative reasoning question.""" 33 | final_answer = str(final_answer).split('=')[-1] 34 | 35 | for before, after in SUBSTITUTIONS: 36 | final_answer = final_answer.replace(before, after) 37 | for expr in REMOVED_EXPRESSIONS: 38 | final_answer = final_answer.replace(expr, '') 39 | 40 | # Extract answer that is in LaTeX math, is bold, 41 | # is surrounded by a box, etc. 
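# e.g. 'The answer is $\\frac{3}{2}$.' keeps only '$\\frac{3}{2}$' via the first
# substitution below, and '\\boxed{42}' is unwrapped to '42' by the \\boxed rule.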
42 | final_answer = re.sub(r'(.*?)(\$)(.*?)(\$)(.*)', '$\\3$', final_answer) 43 | final_answer = re.sub(r'(\\text\{)(.*?)(\})', '\\2', final_answer) 44 | final_answer = re.sub(r'(\\textbf\{)(.*?)(\})', '\\2', final_answer) 45 | final_answer = re.sub(r'(\\overline\{)(.*?)(\})', '\\2', final_answer) 46 | final_answer = re.sub(r'(\\boxed\{)(.*)(\})', '\\2', final_answer) 47 | 48 | # Normalize shorthand TeX: 49 | # \fracab -> \frac{a}{b} 50 | # \frac{abc}{bef} -> \frac{abc}{bef} 51 | # \fracabc -> \frac{a}{b}c 52 | # \sqrta -> \sqrt{a} 53 | # \sqrtab -> sqrt{a}b 54 | final_answer = re.sub( 55 | r'(frac)([^{])(.)', 'frac{\\2}{\\3}', final_answer) 56 | final_answer = re.sub( 57 | r'(sqrt)([^{])', 'sqrt{\\2}', final_answer) 58 | final_answer = final_answer.replace('$', '') 59 | 60 | # Normalize 100,000 -> 100000 61 | if final_answer.replace(',', '').isdigit(): 62 | final_answer = final_answer.replace(',', '') 63 | # 3.0 -> 3 64 | if final_answer.endswith(".0") and final_answer[:-2].isdigit(): 65 | final_answer = final_answer[:-2] 66 | # 3.00 -> 3 67 | if final_answer.endswith(".00") and final_answer[:-3].isdigit(): 68 | final_answer = final_answer[:-3] 69 | if final_answer.endswith("%") and final_answer[:-1].isdigit(): 70 | final_answer = final_answer[:-1] 71 | # A -> a 72 | if final_answer.lower() in ['a', 'b', 'c', 'd', 'e', 'f', 'g']: 73 | final_answer = final_answer.lower() 74 | return final_answer 75 | 76 | def check_sympy_equivalence(formatted_target_str, formatted_prediction_str): 77 | flag = False 78 | try: 79 | target_expr = parse_latex(formatted_target_str) 80 | except: 81 | target_expr = formatted_target_str 82 | flag = True 83 | 84 | try: 85 | prediction_expr = parse_latex(formatted_prediction_str) 86 | except: 87 | prediction_expr = formatted_prediction_str 88 | flag = True 89 | 90 | if flag == True: 91 | return formatted_target_str == formatted_prediction_str 92 | 93 | try: 94 | return sympy.simplify(target_expr - prediction_expr) == 0 95 | except: 96 | return False 97 | 98 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # DEER 🦌: Dynamic Early Exit in Reasoning Models 2 | [![arXiv](https://img.shields.io/badge/arXiv-2504.15895-b31b1b.svg)](https://arxiv.org/abs/2504.15895) 3 | [![License: MIT](https://img.shields.io/badge/License-MIT-yellow.svg)](https://opensource.org/licenses/MIT) 4 | [![Python](https://img.shields.io/badge/Python-3.8%2B-blue)](https://www.python.org/) 5 | [![HuggingFace](https://img.shields.io/badge/HuggingFace-Transformers-orange)](https://huggingface.co/) 6 | [![vLLM](https://img.shields.io/badge/vLLM-Efficient%20LLM%20Inference-green)](https://github.com/vllm-project/vllm) 7 | 8 | This is the repository of our paper: [Dynamic Early Exit in Reasoning Models](https://arxiv.org/abs/2504.15895). 9 | 10 | 11 |

12 | 13 | **DEER** monitors model behavior at potential reasoning transition points and dynamically terminates generation of the next reasoning chain when the model shows high confidence in a trial answer. It is consistently effective across 11 cutting-edge reasoning LLMs of different families and sizes, reducing CoT sequence length by **19.1% - 80.1%** on average while improving accuracy by **0.3% - 5.0%**. 14 | 15 | --- 16 | 17 | ## 🔥 **Latest Updates** 18 | - **[2025/05/20]** Released DEER code for mathematical reasoning tasks (HuggingFace & vLLM). 19 | - **[Coming Soon]** DEER for code generation tasks & Branch-Parallel Decoding Acceleration. 20 | 21 | --- 22 | 23 | ## 🎯 Key Results 24 | Results on 11 reasoning models with 16k token budgets. "Acc" denotes accuracy, "Tok" denotes token count, and "CR" denotes compression rate. 25 |

26 | 27 | 28 | Experimental results, presented as bar charts. 29 | 30 |

31 | 32 |

33 | 34 | --- 35 | 36 | 37 | ## 🚀 Quick Start 38 | ### 1. Installation 39 | ```bash 40 | git clone https://github.com/yourusername/DEER.git 41 | cd DEER 42 | pip install -r requirements.txt 43 | ``` 44 | 45 | ### 2. DEER on vLLM (Recommended) 46 | For efficiency, we recommend reproducing our results with the **vLLM**-based implementation. 47 | 48 | #### For Most Reasoning Models 49 | ```bash 50 | CUDA_VISIBLE_DEVICES=1 python ../vllm-deer.py \ 51 | --model_name_or_path "./DeepSeek-R1-Distill-Qwen-14B" \ 52 | --dataset_dir "./data/" \ 53 | --output_path "./outputs" \ 54 | --dataset "math" \ 55 | --threshold 0.95 \ 56 | --max_generated_tokens 16000 \ 57 | --think_ratio 0.6 \ 58 | --batch_size 2000 \ 59 | --policy avg1 \ 60 | --dtype bfloat16 \ 61 | --gpu-memory-utilization 0.9 \ 62 | ``` 63 | or run: 64 | ```bash 65 | bash ./bashes/bash-vllm-deer.sh 66 | ``` 67 | 68 | 69 | #### For Qwen3 Models 70 | 71 | ```bash 72 | CUDA_VISIBLE_DEVICES=1 python ../vllm-deer-qwen3.py \ 73 | --model_name_or_path "./Qwen3-4B" \ 74 | --dataset_dir "./data/" \ 75 | --output_path "./outputs" \ 76 | --dataset "math" \ 77 | --threshold 0.95 \ 78 | --max_generated_tokens 16000 \ 79 | --think_ratio 0.8 \ 80 | --batch_size 2000 \ 81 | --dtype bfloat16 \ 82 | --policy avg2 \ 83 | --gpu-memory-utilization 0.9 \ 84 | ``` 85 | or run: 86 | ```bash 87 | bash ./bashes/bash-vllm-deer-qwen3.sh 88 | ``` 89 | In our experiments, we found that Qwen3-series models tend to be over-confident in their answer-confidence estimates, so we made two modifications to the DEER implementation for them (see the sketch at the end of this README): 90 | - The calculation of answer confidence was changed from the arithmetic mean to the geometric mean. 91 | - An additional condition must be satisfied for early exit: the model must generate `</think>` after the trial answer. 92 | 93 | ### 3. DEER on Transformers 94 | 95 | For inference using HuggingFace Transformers (without vLLM), run: 96 | ```bash 97 | bash ./bashes/bash-vanilla-deer.sh 98 | ``` 99 | 100 | 101 | ## 📊 Evaluation 102 | 103 | DEER currently supports evaluation on 7 reasoning benchmarks. The rule-based evaluation for these benchmarks is based on the code implementation from the project [LIMO](https://github.com/GAIR-NLP/LIMO/tree/main). 104 | 105 | 106 | ```bash 107 | python ../check.py \ 108 | --model_name_or_path "./DeepSeek-R1-Distill-Qwen-14B" \ 109 | --data_name "math" \ 110 | --generation_path "your_output.jsonl" \ 111 | ``` 112 | or run: 113 | ```bash 114 | bash ./bashes/bash-check-correct.sh 115 | ``` 116 | 117 | 118 | 119 | ## 📜 Citation 120 | If you use DEER in your research, please cite our paper: 121 | ```bibtex 122 | @article{yang2025dynamic, 123 | title={Dynamic Early Exit in Reasoning Models}, 124 | author={Yang, Chenxu and Si, Qingyi and Duan, Yongjie and Zhu, Zheliang and Zhu, Chenyu and Lin, Zheng and Cao, Li and Wang, Weiping}, 125 | journal={arXiv preprint arXiv:2504.15895}, 126 | year={2025} 127 | } 128 | ``` 129 | ## 💬 Community 130 | 131 | Join our WeChat group for discussions: 132 |

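To make the `--policy` options above concrete, here is a minimal, self-contained sketch (not code from this repo) of the early-exit test DEER applies to a trial answer. The helper name and probability values are hypothetical; the `avg1`/`avg2` naming follows the flags in the bash scripts, where `avg2` is the geometric-mean variant used for Qwen3:

```python
import math

def answer_confidence(token_probs, policy="avg1"):
    """Aggregate per-token probabilities of the greedy trial answer."""
    if policy == "avg1":
        # arithmetic mean (used for most reasoning models)
        return sum(token_probs) / len(token_probs)
    # "avg2": geometric mean, which penalizes any low-probability token more
    # strongly and so counteracts Qwen3's over-confidence
    return math.exp(sum(math.log(p) for p in token_probs) / len(token_probs))

token_probs = [0.99, 0.97, 0.70]  # hypothetical trial-answer token probabilities
for policy in ("avg1", "avg2"):
    conf = answer_confidence(token_probs, policy)
    verdict = "exit thinking early" if conf > 0.95 else "continue reasoning"
    print(f"{policy}: confidence={conf:.3f} -> {verdict}")
```

With the same token probabilities, the geometric mean (0.876) is lower than the arithmetic mean (0.887), so `avg2` is the stricter exit criterion.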
133 | -------------------------------------------------------------------------------- /check.py: -------------------------------------------------------------------------------- 1 | import json 2 | from transformers import AutoTokenizer 3 | 4 | import re 5 | import importlib.util 6 | import os 7 | import argparse 8 | 9 | import random 10 | import time 11 | from datetime import datetime 12 | from tqdm import tqdm 13 | from utils.utils import set_seed, load_jsonl, save_jsonl, construct_prompt 14 | from utils.parser import * 15 | from utils.data_loader import load_data 16 | from utils.math_normalization import * 17 | from utils.grader import * 18 | import pickle 19 | from math import comb 20 | import pdb 21 | 22 | 23 | def parse_list(arg): 24 | return arg.split(',') 25 | 26 | def save_completions(completions, filepath): 27 | with open(filepath, 'wb') as file: 28 | pickle.dump(completions, file) 29 | 30 | def parse_args(): 31 | parser = argparse.ArgumentParser() 32 | parser.add_argument('--model_name_or_path', type=str, default="./", help="model dir") 33 | parser.add_argument('--n_sampling', type=int, default=1, help="n for sampling") 34 | parser.add_argument("--k", type=int, default=1, help="Value of k for pass@k calculation") 35 | parser.add_argument("--data_dir", default="./data", type=str) 36 | parser.add_argument('--data_name', type=str, default="math", help='identify how to extract answer') 37 | parser.add_argument("--split", default="test", type=str) 38 | parser.add_argument("--generation_path", default="test", type=str) 39 | 40 | parser.add_argument("--prompt_type", default="qwen-base", type=str) 41 | 42 | args = parser.parse_args() 43 | 44 | 45 | 46 | return args 47 | 48 | def get_conversation_prompt_by_messages(tokenizer, messages): 49 | text = tokenizer.apply_chat_template( 50 | messages, 51 | tokenize=False, 52 | add_generation_prompt=True 53 | ) 54 | return text 55 | 56 | def get_three_prompt(prompt_type, data_name): 57 | file_path = os.path.join(".", "prompts", prompt_type, f"{data_name}.py") 58 | if not os.path.exists(file_path): 59 | raise FileNotFoundError(f"File not found: {file_path}") 60 | 61 | spec = importlib.util.spec_from_file_location("dynamic_module", file_path) 62 | module = importlib.util.module_from_spec(spec) 63 | spec.loader.exec_module(module) 64 | 65 | if hasattr(module, 'system_prompt'): 66 | system_prompt = module.system_prompt 67 | else: 68 | raise AttributeError(f"'system_prompt' not found in {file_path}") 69 | 70 | if hasattr(module, 'few_shot_prompt'): 71 | few_shot_prompt = module.few_shot_prompt 72 | else: 73 | raise AttributeError(f"'few_shot_prompt' not found in {file_path}") 74 | 75 | if hasattr(module, 'question_format'): 76 | question_format = module.question_format 77 | else: 78 | raise AttributeError(f"'question_format' not found in {file_path}") 79 | 80 | return system_prompt, few_shot_prompt, question_format 81 | 82 | def read_jsonl(file_path): 83 | 84 | data = [] 85 | with open(file_path, 'r', encoding='utf-8') as f: 86 | for line in f: 87 | 88 | json_obj = json.loads(line.strip()) 89 | data.append(json_obj) 90 | return data 91 | 92 | 93 | 94 | 95 | 96 | def infer(args): 97 | # load the tokenizer inside infer, not at module level: `args` only exists after parse_args() runs in __main__ 98 | tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path) 99 | examples = load_data(args.data_name, args.split, args.data_dir) 100 | file_outputs = read_jsonl(args.generation_path) 101 | print("llm generate done") 102 | print(len(file_outputs)) 103 | 104 | pass_at_k_list = [] 105 | k = args.k 106 | 107 | correct_cnt = 0 108 | for i in 
tqdm(range(len(file_outputs)), "check correct..."): 109 | d = examples[i] 110 | gt_cot, gt_ans = parse_ground_truth(d, args.data_name) 111 | generated_responses = file_outputs[i]['generated_responses'] 112 | 113 | 114 | generated_answers = [extract_answer(generated_response, args.data_name) for generated_response in generated_responses] 115 | is_correct_list = [check_is_correct(generated_answer, gt_ans) for generated_answer in generated_answers] 116 | is_correct = any(is_correct_list) 117 | if is_correct: 118 | #print(i) 119 | correct_cnt += 1 120 | file_outputs[i]['generated_answers'] = generated_answers 121 | file_outputs[i]['gold_answer'] = gt_ans 122 | file_outputs[i]['is_correct'] = is_correct 123 | file_outputs[i]['answers_correctness'] = is_correct_list 124 | 125 | if len(is_correct_list) > 1: 126 | correct_answers = sum(is_correct_list) 127 | n = len(generated_answers) 128 | if correct_answers > 0: 129 | if n - correct_answers < k: 130 | pass_at_k = 1 131 | else: 132 | pass_at_k = 1 - (comb(n - correct_answers, k) / comb(n, k)) 133 | pass_at_k_list.append(pass_at_k) 134 | else: 135 | pass_at_k_list.append(0) 136 | 137 | 138 | print(f"correct cnt / total cnt: {correct_cnt}/{len(file_outputs)}") 139 | print(f"Acc: {correct_cnt / len(file_outputs):.4f}") 140 | 141 | if pass_at_k_list: 142 | average_pass_at_k = sum(pass_at_k_list) / len(pass_at_k_list) 143 | print(f"Pass@{k}: {sum(pass_at_k_list)}/{len(pass_at_k_list)} = {average_pass_at_k:.4f}") 144 | else: 145 | print(f"Pass@1: {correct_cnt}/{len(file_outputs)} = {correct_cnt / len(file_outputs):.4f}") 146 | 147 | 148 | 149 | response_length = [] 150 | token_num = [] 151 | 152 | 153 | 154 | test_num = len(file_outputs) 155 | 156 | for data in file_outputs: 157 | response_length.append(len(data['generated_responses'][0].split())) 158 | tokens_response_len = len(tokenizer(data['generated_responses'][0])['input_ids']) 159 | token_num.append(tokens_response_len) 160 | 161 | 162 | avg_response_length = sum(response_length) / test_num 163 | avg_token_num = sum(token_num) / test_num 164 | 165 | print("length:", avg_response_length) 166 | print('token_num:', avg_token_num) 167 | 168 | 169 | if __name__ == "__main__": 170 | args = parse_args() 171 | infer(args) -------------------------------------------------------------------------------- /utils/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import random 4 | 5 | 6 | import numpy as np 7 | from pathlib import Path 8 | from typing import Iterable, Union, Any 9 | 10 | from utils.examples import get_examples 11 | 12 | 13 | def set_seed(seed: int = 42) -> None: 14 | np.random.seed(seed) 15 | random.seed(seed) 16 | os.environ["PYTHONHASHSEED"] = str(seed) 17 | print(f"Random seed set as {seed}") 18 | 19 | 20 | def load_jsonl(file: Union[str, Path]) -> Iterable[Any]: 21 | with open(file, "r", encoding="utf-8") as f: 22 | for line in f: 23 | try: 24 | yield json.loads(line) 25 | except: 26 | print("Error in loading:", line) 27 | exit() 28 | 29 | 30 | def save_jsonl(samples, save_path): 31 | # ensure path 32 | folder = os.path.dirname(save_path) 33 | os.makedirs(folder, exist_ok=True) 34 | 35 | with open(save_path, "w", encoding="utf-8") as f: 36 | for sample in samples: 37 | f.write(json.dumps(sample, ensure_ascii=False) + "\n") 38 | print("Saved to", save_path) 39 | 40 | 41 | def lower_keys(example): 42 | new_example = {} 43 | for key, value in 
example.items(): 44 | if key != key.lower(): 45 | new_key = key.lower() 46 | new_example[new_key] = value 47 | else: 48 | new_example[key] = value 49 | return new_example 50 | 51 | 52 | EXAMPLES = get_examples() 53 | 54 | 55 | def load_prompt(data_name, prompt_type, num_shots): 56 | if not num_shots: 57 | return [] 58 | 59 | if data_name in ["gsm_hard", "svamp", "tabmwp", "asdiv", "mawps"]: 60 | data_name = "gsm8k" 61 | if data_name in ["math_oai", "hungarian_exam", "math-oai", "aime24", "amc23"]: 62 | data_name = "math" 63 | if data_name in ["sat_math"]: 64 | data_name = "mmlu_stem" 65 | if data_name in [ 66 | "gaokao2024_I", 67 | "gaokao2024_II", 68 | "gaokao_math_qa", 69 | "gaokao2024_mix", 70 | "cn_middle_school", 71 | ]: 72 | data_name = "gaokao" 73 | 74 | if prompt_type in ["tool-integrated"]: 75 | prompt_type = "tora" 76 | 77 | return EXAMPLES[data_name][:num_shots] 78 | 79 | 80 | PROMPT_TEMPLATES = { 81 | "direct": ("Question: {input}\nAnswer: ", "{output}", "\n\n"), 82 | "cot": ("Question: {input}\nAnswer: ", "{output}", "\n\n\n"), 83 | "pal": ("Question: {input}\n\n", "{output}", "\n---\n"), 84 | "tool-integrated": ("Question: {input}\n\nSolution:\n", "{output}", "\n---\n"), 85 | "self-instruct": ("<|user|>\n{input}\n<|assistant|>\n", "{output}", "\n"), 86 | "tora": ("<|user|>\n{input}\n<|assistant|>\n", "{output}", "\n"), 87 | "wizard_zs": ( 88 | "### Instruction:\n{input}\n\n### Response: Let's think step by step.", 89 | "{output}", 90 | "\n\n\n", 91 | ), 92 | "platypus_fs": ( 93 | "### Instruction:\n{input}\n\n### Response:\n", 94 | "{output}", 95 | "\n\n\n", 96 | ), 97 | "deepseek-math": ( 98 | "User: {input}\nPlease reason step by step, " 99 | "and put your final answer within \\boxed{{}}.\n\nAssistant:", 100 | "{output}", 101 | "\n\n\n", 102 | ), 103 | "kpmath": ( 104 | "User: Please reason step by step and put your final answer at the end " 105 | 'with "The answer is: ".\n\n{input}\n\nAssistant:', 106 | "{output}", "\n\n\n", 107 | ), 108 | "jiuzhang": ( 109 | "## Question\n{input}\n\n## Solution\n", 110 | "{output}", 111 | "\n\n\n", 112 | ), 113 | "jiuzhang_tora": ( 114 | "## Question\n{input}\n\n## Code Solution\n", 115 | "{output}", 116 | "\n\n\n", 117 | ), 118 | "jiuzhang_nl": ( 119 | "## Question\n{input}\n\n## Natural Language Solution\n", 120 | "{output}", 121 | "\n\n\n", 122 | ), 123 | "mmiqc": ( 124 | 'Please solve the following problem and put your answer at the end with "The answer is: ".\n\n{input}\n\n', 125 | "{output}", 126 | "\n\n\n", 127 | ), 128 | "abel": ( 129 | "Question:\n{input}\nAnswer:\nLet's think step by step.\n", 130 | "{output}", 131 | "\n\n", 132 | ), 133 | "shepherd": ("{input}\n", "{output}", "\n\n\n"), 134 | "qwen-boxed": ( 135 | "<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n" 136 | "<|im_start|>user\n{input}\nPlease reason step by step, and put your final answer within \\boxed{{}}.<|im_end|>\n" 137 | "<|im_start|>assistant\n", 138 | "{output}", 139 | "\n\n", 140 | ), 141 | "qwen25-math-cot": ( 142 | "<|im_start|>system\nPlease reason step by step, and put your final answer within \\boxed{{}}.<|im_end|>\n" 143 | "<|im_start|>user\n{input}<|im_end|>\n" 144 | "<|im_start|>assistant\n", 145 | "{output}", 146 | "\n\n", 147 | ), 148 | "mathstral": ( 149 | "{input}\nPlease reason step by step, and put your final answer within \\boxed{{}}.", 150 | "{output}", 151 | "\n\n", 152 | ), 153 | "internlm-math-fs": ("Question:{input}\nAnswer:", "{output}", "\n"), 154 | "internlm-math-chat": ( 155 | "<|im_start|>user\n{input}<|im_end|>\n" 
"<|im_start|>assistant\n", 156 | "{output}", 157 | "\n\n", 158 | ), 159 | "mistral": ( 160 | "[INST] {input}[/INST]", 161 | "{output}", 162 | "\n\n", 163 | ), 164 | "numina": ("### Problem: {input}\n### Solution:", " {output}", "\n\n"), 165 | } 166 | 167 | 168 | def construct_prompt(example, data_name, args): 169 | if args.adapt_few_shot and data_name in [ 170 | "gaokao2024_I", 171 | "gaokao2024_II", 172 | "gaokao_math_qa", 173 | "gaokao2024_mix", 174 | "cn_middle_school", 175 | ]: 176 | demos = load_prompt(data_name, args.prompt_type, 5) 177 | else: 178 | demos = load_prompt(data_name, args.prompt_type, args.num_shots) 179 | prompt_type = args.prompt_type 180 | if prompt_type == "platypus_fs": 181 | prompt_type = "cot" 182 | if prompt_type == "tool-integrated": 183 | prompt_type = "tora" 184 | 185 | prompt_temp = PROMPT_TEMPLATES[args.prompt_type] 186 | 187 | splitter = prompt_temp[2] 188 | input_template, output_template, splitter = ( 189 | prompt_temp[0], 190 | prompt_temp[1], 191 | prompt_temp[2], 192 | ) 193 | if args.prompt_type == "qwen25-math-cot": 194 | # Hotfix to support putting all demos into a single turn 195 | demo_prompt = splitter.join([q + "\n" + a for q, a in demos]) 196 | else: 197 | demo_prompt = splitter.join( 198 | [ 199 | input_template.format(input=q) + output_template.format(output=a) 200 | for q, a in demos 201 | ] 202 | ) 203 | context = input_template.format(input=example["question"]) 204 | if len(demo_prompt) == 0 or ( 205 | args.adapt_few_shot and example["gt_ans"] not in ["A", "B", "C", "D", "E"] 206 | ): 207 | full_prompt = context 208 | else: 209 | if args.prompt_type == "qwen25-math-cot": 210 | # Hotfix to supportting put all demos into a single turn 211 | full_prompt = demo_prompt + splitter + example["question"] 212 | full_prompt = input_template.format(input=full_prompt) 213 | else: 214 | full_prompt = demo_prompt + splitter + context 215 | 216 | if args.prompt_type == "platypus_fs": 217 | full_prompt_temp = ( 218 | "Below is an instruction that describes a task. " 219 | "Write a response that appropriately completes the request.\n\n" 220 | "### Instruction:\n{instruction}\n\n### Response:\n" 221 | ) 222 | full_prompt = full_prompt_temp.format(instruction=full_prompt) 223 | 224 | if prompt_type == "tora": 225 | full_prompt = ( 226 | """Integrate step-by-step reasoning and Python code to solve math problems using the following guidelines: 227 | 228 | - Analyze the question and write functions to solve the problem; the function should not take any arguments. 229 | - Present the final result in LaTeX using a `\boxed{}` without any units. 230 | - Utilize the `pi` symbol and `Rational`` from Sympy for $\pi$ and fractions, and simplify all fractions and square roots without converting them to decimal values. 231 | 232 | Here are some examples you may refer to: 233 | 234 | --- 235 | 236 | """ 237 | + full_prompt 238 | ) 239 | 240 | return full_prompt.strip(" ") # important! 
241 | 242 | 243 | key_map = { 244 | "gt": "Ground Truth", 245 | "pred": "Prediction", 246 | "gt_cot": "Reference CoT", 247 | "score": "Score", 248 | } 249 | 250 | 251 | def show_sample(sample, print_all_preds=False): 252 | print("==" * 20) 253 | for key in ["idx", "type", "level", "dataset"]: 254 | if key in sample: 255 | # capitalize 256 | print("{}: {}".format(key[0].upper() + key[1:], sample[key])) 257 | print("Question:", repr(sample["question"])) 258 | if "code" in sample: 259 | if print_all_preds: 260 | for code in sample["code"]: 261 | print("-" * 20) 262 | print("code:", code) 263 | print("Execution:", sample["report"]) 264 | else: 265 | print("Solution:\n", sample["code"][0]) 266 | print("Execution:", sample["report"][0]) 267 | if "pred" in sample: 268 | print("Prediction:", repr(sample["pred"][0])) 269 | for key in ["gt", "score", "unit", "gt_cot"]: 270 | if key in sample: 271 | _key = key_map.get(key, key) 272 | print("{}: {}".format(_key, repr(sample[key]))) 273 | print() 274 | -------------------------------------------------------------------------------- /vanilla_deer.py: -------------------------------------------------------------------------------- 1 | 2 | import warnings 3 | warnings.filterwarnings("ignore") # Ignore all warnings 4 | from transformers import AutoModelForCausalLM, AutoTokenizer 5 | import pdb 6 | import torch 7 | import torch.nn.functional as F 8 | import json 9 | import re 10 | from tqdm import tqdm 11 | import argparse 12 | import os 13 | from transformers.cache_utils import DynamicCache 14 | from copy import deepcopy 15 | 16 | def append_jsonl(data, file_path): 17 | """ 18 | Append results from list to a .jsonl file. 19 | 20 | Parameters: 21 | data (list): List containing Python dictionaries to append. 22 | file_path (str): Target .jsonl file path. 23 | """ 24 | with open(file_path, 'a', encoding='utf-8') as f: 25 | for item in data: 26 | json_line = json.dumps(item, ensure_ascii=False) 27 | f.write(json_line + '\n') 28 | 29 | def write_jsonl(data, file_path): 30 | """ 31 | Write results from list to a .jsonl file. 32 | 33 | Parameters: 34 | data (list): List containing Python dictionaries to write. 35 | file_path (str): Output .jsonl file path. 36 | """ 37 | with open(file_path, 'w', encoding='utf-8') as f: 38 | for item in data: 39 | # Convert each dictionary to JSON string and write to file 40 | json_line = json.dumps(item, ensure_ascii=False) 41 | f.write(json_line + '\n') 42 | 43 | def read_jsonl(file_path): 44 | """ 45 | Read .jsonl file and return a list of dictionaries. 46 | 47 | Parameters: 48 | file_path (str): Path to .jsonl file. 49 | 50 | Returns: 51 | data (list): List containing all JSON objects. 
52 | """ 53 | data = [] 54 | with open(file_path, 'r', encoding='utf-8') as f: 55 | for line in f: 56 | # Parse JSON object from each line 57 | json_obj = json.loads(line.strip()) 58 | data.append(json_obj) 59 | return data 60 | 61 | def calcu_max_probs_w_kv(model, pred_input_ids, kv_cache, tokenizer, method=1): 62 | list1 = tokenizer([' }', '}', '}.', '}.\n', '}\\', '}}', ')}', ')}.', ')}\n', '</think>'])['input_ids'] # pieces that can end a boxed trial answer 63 | stop_ids = sum(list1, []) 64 | total_steps = 0 65 | if method == 0: 66 | total_prob_max = 1.0 # method 0: product of per-token max probs 67 | else: 68 | total_prob_max = 0.0 # otherwise: running sum, averaged below 69 | 70 | pred_tokens = [] 71 | last_token = -1 72 | 73 | backup_cache = deepcopy(kv_cache) 74 | 75 | with torch.no_grad(): 76 | while last_token not in stop_ids: 77 | 78 | if last_token == -1: 79 | output_dicts = model(input_ids=pred_input_ids, past_key_values=backup_cache) 80 | else: 81 | output_dicts = model(input_ids=torch.tensor([last_token]).unsqueeze(0).to(pred_input_ids.device), past_key_values=backup_cache) 82 | logits = output_dicts['logits'][0][-1] 83 | past_key_values = output_dicts['past_key_values'] 84 | probs = F.softmax(logits, dim=-1) 85 | 86 | max_value, max_index = torch.max(probs, dim=0) 87 | 88 | 89 | if last_token == -1: 90 | total_prob_max = total_prob_max 91 | else: 92 | 93 | if method == 0: 94 | total_prob_max *= max_value 95 | else: 96 | total_prob_max += max_value 97 | 98 | pred_tokens.append(max_index) 99 | last_token = max_index 100 | total_steps += 1 101 | if total_steps > 20: 102 | break 103 | 104 | 105 | if method != 0: 106 | total_prob_max = (total_prob_max - max_value) / (total_steps - 2) if total_steps > 2 else max_value # guard against division by zero on very short answers 107 | 108 | del backup_cache, past_key_values 109 | torch.cuda.empty_cache() 110 | 111 | return total_prob_max.item() 112 | 113 | def parse_args(): 114 | parser = argparse.ArgumentParser() 115 | parser.add_argument('--model_name_or_path', type=str, default="/mnt/data1/ycx/DeepSeek-R1-Distill-Qwen-32B", help="model directory") 116 | parser.add_argument('--threshold', type=float, default=0.95) 117 | parser.add_argument('--max_len', type=int, default=16384) 118 | parser.add_argument('--dataset', type=str, default='math') 119 | parser.add_argument('--output_path', type=str, default='./outputs') 120 | parser.add_argument('--log', type=bool, default=False) 121 | parser.add_argument('--think_ratio', type=float, default=0.9) 122 | 123 | args = parser.parse_args() 124 | return args 125 | 126 | args = parse_args() 127 | 128 | think_len = int(args.max_len * args.think_ratio) 129 | answer_len = args.max_len - think_len 130 | 131 | # model 132 | model_name = args.model_name_or_path 133 | model = AutoModelForCausalLM.from_pretrained( 134 | model_name, 135 | torch_dtype="bfloat16", 136 | device_map="auto" 137 | ) 138 | tokenizer = AutoTokenizer.from_pretrained(model_name) 139 | 140 | # data 141 | sys_prompts = ['Please reason step by step, and put your final answer within \\boxed{}.'] 142 | questions_json = read_jsonl('./data/' + args.dataset + '/test.jsonl') 143 | 144 | os.makedirs(f'{args.output_path}/{model_name}/{args.dataset}', exist_ok=True) 145 | output_list = [] 146 | 147 | for i in tqdm(range(0, len(questions_json))): 148 | output_dict = {} 149 | sys_prompt = sys_prompts[0] 150 | 151 | prompt = questions_json[i]['problem'] #+ 'start ' * 30000 152 | answer = str(questions_json[i]['answer']) 153 | 154 | messages = [ 155 | {"role": "system", "content": sys_prompt}, 156 | {"role": "user", "content": prompt} 157 | ] 158 | 159 | text = tokenizer.apply_chat_template( 160 | messages, 161 | tokenize=False, 162 | add_generation_prompt=True 163 | ) 
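    # DEER loop (per question): alternately (1) decode thinking tokens until the
    # model emits a transition word ("Wait") or an end-of-thinking marker, then
    # (2) probe a trial answer by force-decoding the answer prompt and scoring
    # its confidence with calcu_max_probs_w_kv; exit thinking once that
    # confidence clears --threshold, the model stops on its own, or the
    # thinking budget is nearly exhausted.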
164 | 165 | model_inputs = tokenizer([text], return_tensors="pt").to(model.device) 166 | input_ids = model_inputs["input_ids"] 167 | input_length = len(input_ids[0]) 168 | 169 | last_token_ids = tokenizer("**", add_special_tokens=False)["input_ids"] + tokenizer("</think>", add_special_tokens=False)["input_ids"] + [tokenizer.eos_token_id] 170 | continue_ids = tokenizer("Wait", add_special_tokens=False)["input_ids"] 171 | stop_ids = continue_ids + last_token_ids 172 | 173 | answer_prompt_ids = tokenizer("\n**Final Answer**\n\nThe final answer is \\boxed", add_special_tokens=False)["input_ids"] 174 | answer_ids = tokenizer(answer, add_special_tokens=False)["input_ids"] 175 | 176 | past_key_values = DynamicCache() 177 | first_round = True 178 | too_long = False 179 | while True: 180 | if first_round: 181 | 182 | generated_dicts = model.generate( 183 | input_ids, 184 | #max_new_tokens=200, 185 | max_new_tokens=think_len-len(input_ids[0]), 186 | do_sample=False, 187 | eos_token_id=stop_ids, 188 | return_dict_in_generate=True, 189 | output_logits=True, 190 | tokenizer=tokenizer, 191 | past_key_values=past_key_values, 192 | ) 193 | 194 | else: 195 | generated_dicts = model.generate( 196 | input_ids, 197 | max_new_tokens=think_len-len(input_ids[0]), 198 | do_sample=False, 199 | return_dict_in_generate=True, 200 | output_logits=True, 201 | tokenizer=tokenizer, 202 | eos_token_id=stop_ids, 203 | past_key_values=past_key_values, 204 | ) 205 | 206 | generated_ids = [ 207 | output_ids[len(input_ids):-1] for input_ids, output_ids in zip(input_ids, generated_dicts['sequences']) 208 | ] 209 | response = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0] 210 | 211 | logits = generated_dicts['logits'][-1] 212 | probs = F.softmax(logits, dim=-1)[0] 213 | 214 | max_value, max_index = torch.max(probs, dim=0) 215 | 216 | if max_index in last_token_ids: 217 | real_stop = 1 218 | else: 219 | real_stop = 0 220 | 221 | pred_input_ids = torch.cat((input_ids, generated_ids[0].unsqueeze(0)), dim=1) 222 | 223 | if len(pred_input_ids[0]) >= think_len - 100: 224 | too_long = True 225 | 226 | pred_prob = calcu_max_probs_w_kv(model, torch.tensor(answer_prompt_ids).to(generated_ids[0].device).unsqueeze(0), past_key_values, tokenizer, 1) 227 | 228 | torch.cuda.empty_cache() 229 | 230 | if pred_prob > args.threshold or real_stop or too_long: 231 | input_ids = torch.cat((pred_input_ids, torch.tensor(tokenizer('\n\n\n')['input_ids']).to(generated_ids[0].device).unsqueeze(0)), dim=1) # early exit: append a separator, then decode the final answer 232 | 233 | generated_dicts = model.generate( 234 | input_ids, 235 | max_new_tokens=answer_len, 236 | do_sample=False, 237 | return_dict_in_generate=True, 238 | past_key_values=past_key_values, 239 | ) 240 | 241 | generated_ids = [ 242 | output_ids[len(input_ids):-1] for input_ids, output_ids in zip(input_ids, generated_dicts['sequences']) 243 | ] 244 | final_output_ids = torch.cat((input_ids[0], generated_ids[0]), dim=-1) 245 | response = tokenizer.batch_decode([final_output_ids[input_length:]], skip_special_tokens=True)[0] 246 | 247 | if args.log: 248 | log_file_path = args.output_path + "/log" + str(args.threshold) + ".txt" 249 | with open(log_file_path, "a") as file: 250 | file.write(response + "\n") 251 | 252 | break 253 | 254 | else: 255 | tmp = torch.cat((generated_ids[0], torch.tensor(continue_ids).to(generated_ids[0].device)), dim=0) 256 | input_ids = torch.cat((input_ids, tmp.unsqueeze(0)), dim=1) # not confident yet: append "Wait" and continue thinking 257 | torch.cuda.empty_cache() 258 | 259 | output_dict['question'] = questions_json[i]['problem'] 260 | 
output_dict['generated_responses'] = [response] 261 | output_dict['gold_answer'] = questions_json[i]['answer'] 262 | 263 | append_jsonl([output_dict], args.output_path + '/' + model_name + '/' + args.dataset + '/greedy_p' + str(args.threshold) + '_len' + str(args.max_len) + '.jsonl') -------------------------------------------------------------------------------- /data/aime25/test.jsonl: -------------------------------------------------------------------------------- 1 | {"problem": "Six points $ A, B, C, D, E, $ and $ F $ lie in a straight line in that order. Suppose that $ G $ is a point not on the line and that $ AC = 26 $, $ BD = 22 $, $ CE = 31 $, $ DF = 33 $, $ AF = 73 $, $ CG = 40 $, and $ DG = 30 $. Find the area of $ \\triangle BGE $.", "answer": "468"} 2 | {"problem": "Find the sum of all positive integers $ n $ such that $ n + 2 $ divides the product $ 3(n + 3)(n^2 + 9) $.", "answer": "49"} 3 | {"problem": "Four unit squares form a $2 \\times 2$ grid. Each of the 12 unit line segments forming the sides of the squares is colored either red or blue in such a way that each unit square has 2 red sides and 2 blue sides. Find the number of such colorings.", "answer": "82"} 4 | {"problem": "The product $ \\prod_{k=4}^{63} \\frac{\\log_k(5^{k^2-1})}{\\log_{k+1}(5^{k^2-4})} = \\frac{\\log_4(5^{15})}{\\log_5(5^{12})} \\cdot \\frac{\\log_5(5^{24})}{\\log_6(5^{21})} \\cdot \\frac{\\log_6(5^{35})}{\\log_7(5^{32})} \\cdots \\frac{\\log_{63}(5^{3968})}{\\log_{64}(5^{3965})} $ is equal to $ \\frac{m}{n} $, where $ m $ and $ n $ are relatively prime positive integers. Find $ m + n $.", "answer": "106"} 5 | {"problem": "Suppose $ \\triangle ABC $ has angles $ \\angle BAC = 84^\\circ $, $ \\angle ABC = 60^\\circ $, and $ \\angle ACB = 36^\\circ $. Let $ D, E, $ and $ F $ be the midpoints of sides $ \\overline{BC} $, $ \\overline{AC} $, and $ \\overline{AB} $, respectively. The circumcircle of $ \\triangle DEF $ intersects $ \\overline{BD} $, $ \\overline{AE} $, and $ \\overline{AF} $ at points $ G, H, $ and $ J $, respectively. The points $ G, D, E, H, J, $ and $ F $ divide the circumcircle of $ \\triangle DEF $ into six minor arcs, as shown. Find $ \\widehat{DE} + 2 \\cdot \\widehat{HJ} + 3 \\cdot \\widehat{FG} $, where the arcs are measured in degrees.", "answer": "336^\\circ"} 6 | {"problem": "Circle $\\omega_1$ with radius 6 centered at point $A$ is internally tangent at point $B$ to circle $\\omega_2$ with radius 15. Points $C$ and $D$ lie on $\\omega_2$ such that $\\overline{BC}$ is a diameter of $\\omega_2$ and $\\overline{BC} \\perp \\overline{AD}$. The rectangle $EFGH$ is inscribed in $\\omega_1$ such that $\\overline{EF} \\perp \\overline{BC}$, $C$ is closer to $\\overline{GH}$ than to $\\overline{EF}$, and $D$ is closer to $\\overline{FG}$ than to $\\overline{EH}$, as shown. Triangles $\\triangle DGF$ and $\\triangle CHG$ have equal areas. The area of rectangle $EFGH$ is $\\frac{m}{n}$, where $m$ and $n$ are relatively prime positive integers. Find $m + n$.", "answer": "293"} 7 | {"problem": "Let $ A $ be the set of positive integer divisors of 2025. Let $ B $ be a randomly selected subset of $ A $. The probability that $ B $ is a nonempty set with the property that the least common multiple of its elements is 2025 is $ \\frac{m}{n} $, where $ m $ and $ n $ are relatively prime positive integers. 
Find $ m + n $.", "answer": "237"} 8 | {"problem": "From an unlimited supply of 1-cent coins, 10-cent coins, and 25-cent coins, Silas wants to find a collection of coins that has a total value of $ N $ cents, where $ N $ is a positive integer. He uses the so-called **greedy algorithm**, successively choosing the coin of greatest value that does not cause the value of his collection to exceed $ N $. For example, to get 42 cents, Silas will choose a 25-cent coin, then a 10-cent coin, then 7 1-cent coins. However, this collection of 9 coins uses more coins than necessary to get a total of 42 cents; indeed, choosing 4 10-cent coins and 2 1-cent coins achieves the same total value with only 6 coins.\n\nIn general, the greedy algorithm succeeds for a given $ N $ if no other collection of 1-cent, 10-cent, and 25-cent coins gives a total value of $ N $ cents using strictly fewer coins than the collection given by the greedy algorithm. Find the number of values of $ N $ between 1 and 1000 inclusive for which the greedy algorithm succeeds.", "answer": "610"} 9 | {"problem": "There are $ n $ values of $ x $ in the interval $ 0 < x < 2\\pi $ where $ f(x) = \\sin(7\\pi \\cdot \\sin(5x)) = 0 $. For $ t $ of these $ n $ values of $ x $, the graph of $ y = f(x) $ is tangent to the $ x $-axis. Find $ n + t $.", "answer": "149"} 10 | {"problem": "Sixteen chairs are arranged in a row. Eight people each select a chair in which to sit so that no person sits next to two other people. Let $ N $ be the number of subsets of 16 chairs that could be selected. Find the remainder when $ N $ is divided by 1000.", "answer": "907"} 11 | {"problem": "Let $ S $ be the set of vertices of a regular 24-gon. Find the number of ways to draw 12 segments of equal lengths so that each vertex in $ S $ is an endpoint of exactly one of the 12 segments.", "answer": "113"} 12 | {"problem": "Let $ A_1A_2 \\ldots A_{11} $ be an 11-sided non-convex simple polygon with the following properties:\n* The area of $ A_iA_1A_{i+1} $ is 1 for each $ 2 \\leq i \\leq 10 $,\n* $ \\cos(\\angle A_iA_1A_{i+1}) = \\frac{12}{13} $ for each $ 2 \\leq i \\leq 10 $,\n* The perimeter of $ A_1A_2 \\ldots A_{11} $ is 20.\nIf $ A_1A_2 + A_1A_{11} $ can be expressed as $ \\frac{m\\sqrt{n} - p}{q} $ for positive integers $ m, n, p, q $ with $ n $ squarefree and no prime divides all of $ m, p, q$, find $ m + n + p + q $.", "answer": "19"} 13 | {"problem": "Let the sequence of rationals $ x_1, x_2, \\ldots $ be defined such that $ x_1 = \\frac{25}{11} $ and\n$ x_{k+1} = \\frac{1}{3} \\left( x_k + \\frac{1}{x_k} - 1 \\right). $\n$ x_{2025} $ can be expressed as $ \\frac{m}{n} $ for relatively prime positive integers $ m $ and $ n $. Find the remainder when $ m + n $ is divided by 1000.", "answer": "248"} 14 | {"problem": "Let $ \\triangle ABC $ be a right triangle with $ \\angle A = 90^\\circ $ and $ BC = 38 $. There exist points $ K $ and $ L $ inside the triangle such that $ AK = AL = BK = CL = KL = 14. $ The area of the quadrilateral $ BKLC $ can be expressed as $ n \\sqrt{3} $ for some positive integer $ n $. Find $ n $.", "answer": "104"} 15 | {"problem": "There are exactly three positive real numbers $ k $ such that the function\n$ f(x) = \\frac{(x - 18)(x - 72)(x - 98)(x - k)}{x} $\ndefined over the positive real numbers achieves its minimum value at exactly two positive real numbers $ x $. 
Find the sum of these three values of $ k $.", "answer": "240"} 16 | {"problem": "Find the sum of all integer bases $b>9$ for which $17_{b}$ is a divisor of $97_{b}$.", "answer": "70"} 17 | {"problem": "On $\\triangle ABC$ points $A,D,E$, and $B$ lie that order on side $\\overline{AB}$ with $AD=4, DE=16$, and $EB=8$. Points $A,F,G$, and $C$ lie in that order on side $\\overline{AC}$ with $AF=13, FG=52$, and $GC=26$. Let $M$ be the reflection of $D$ through $F$, and let $N$ be the reflection of $G$ through $E$. Quadrilateral $DEGF$ has area 288. Find the area of heptagon $AFNBCEM$.", "answer": "588"} 18 | {"problem": "The 9 members of a baseball team went to an ice cream parlor after their game. Each player had a singlescoop cone of chocolate, vanilla, or strawberry ice cream. At least one player chose each flavor, and the number of players who chose chocolate was greater than the number of players who chose vanilla, which was greater than the number of players who chose strawberry. Let $N$ be the number of different assignments of flavors to players that meet these conditions. Find the remainder when $N$ is divided by 1000.", "answer": "16"} 19 | {"problem": "Find the number of ordered pairs $(x,y)$, where both $x$ and $y$ are integers between $-100$ and $100$, inclusive, such that $12x^{2}-xy-6y^{2}=0$.", "answer": "117"} 20 | {"problem": "There are $8!=40320$ eight-digit positive integers that use each of the digits $1,2,3,4,5,6,7,8$ exactly once. Let $N$ be the number of these integers that are divisible by 22. Find the difference between $N$ and 2025.", "answer": "279"} 21 | {"problem": "An isosceles trapezoid has an inscribed circle tangent to each of its four sides. The radius of the circle is 3, and the area of the trapezoid is 72. Let the parallel sides of the trapezoid have lengths $r$ and $s$, with $r \\neq s$. Find $r^{2}+s^{2}$.", "answer": "504"} 22 | {"problem": "The twelve letters $A,B,C,D,E,F,G,H,I,J,K$, and $L$ are randomly grouped into six pairs of letters. The two letters in each pair are placed next to each other in alphabetical order to form six two-letter words, and those six words are listed alphabetically. For example, a possible result is $AB,CJ,DG,EK,FL,HI$. The probability that the last word listed contains $G$ is $\\frac{m}{n}$, where $m$ and $n$ are relatively prime positive integers. Find $m+n$.", "answer": "821"} 23 | {"problem": "Let $k$ be real numbers such that the system $|25+20i-z|=5$ and $|z-4-k|=|z-3i-k|$ has exactly one complex solution $z$. The sum of all possible values of $k$ can be written as $\\frac{m}{n}$, where $m$ and $n$ are relatively prime positive integers. Find $m+n$. Here $i=\\sqrt{-1}$.", "answer": "77"} 24 | {"problem": "The parabola with equation $y=x^{2}-4$ is rotated $60^{\\circ}$ counterclockwise around the origin. The unique point in the fourth quadrant where the original parabola and its image intersect has $y$-coordinate $\\frac{a-\\sqrt{b}}{c}$, where $a$, $b$, and $c$ are positive integers, and $a$ and $c$ are relatively prime. Find $a+b+c$.", "answer": "62"} 25 | {"problem": "The 27 cells of a $3\\times9$ grid are filled in using the numbers 1 through 9 so that each row contains 9 different numbers, and each of the three $3\\times3$ blocks heavily outlined in the example below contains 9 different numbers, as in the first three rows of a Sudoku puzzle. 
\n | 4 | 2 | 8 | 9 | 6 | 3 | 1 | 7 | 5 | \n | 3 | 7 | 9 | 5 | 2 | 1 | 6 | 8 | 4 | \n | 5 | 6 | 1 | 8 | 4 | 7 | 9 | 2 | 3 | \n The number of different ways to fill such a grid can be written as $p^a\\cdot q^b\\cdot r^c\\cdot s^d$, where $p,q,r,$ and $s$ are distinct prime numbers and $a,b,c,$ and $d$ are positive integers. Find $p\\cdot a+q\\cdot b+r\\cdot c+s\\cdot d$.", "answer": "81"} 26 | {"problem": "A piecewise linear periodic function is defined by $f(x)=\\begin{cases}x&\\text{if }x\\in[-1,1)\\\\2-x&\\text{if }x\\in[1,3)\\end{cases}$ and $f(x+4)=f(x)$ for all real numbers $x$. The graph of $f(x)$ has the sawtooth pattern. The parabola $x=34y^2$ intersects the graph of $f(x)$ at finitely many points. The sum of the $y$-coordinates of these intersection points can be expressed in the form $\\frac{a+b\\sqrt{c}}{d}$, where $a,b,c,$ and $d$ are positive integers, $a,b,$ and $d$ have greatest common divisor equal to 1, and $c$ is not divisible by the square of any prime. Find $a+b+c+d$.", "answer": "259"} 27 | {"problem": "The set of points in 3-dimensional coordinate space that lie in the plane $x+y+z=75$ whose coordinates satisfy the inequalities $x-yz<y-zx<z-xy$ -------------------------------------------------------------------------------- /utils/parser.py: -------------------------------------------------------------------------------- 11 | def _fix_fracs(string): 12 | substrs = string.split("\\frac") 13 | new_str = substrs[0] 14 | if len(substrs) > 1: 15 | substrs = substrs[1:] 16 | for substr in substrs: 17 | new_str += "\\frac" 18 | if len(substr) > 0 and substr[0] == "{": 19 | new_str += substr 20 | else: 21 | try: 22 | assert len(substr) >= 2 23 | except: 24 | return string 25 | a = substr[0] 26 | b = substr[1] 27 | if b != "{": 28 | if len(substr) > 2: 29 | post_substr = substr[2:] 30 | new_str += "{" + a + "}{" + b + "}" + post_substr 31 | else: 32 | new_str += "{" + a + "}{" + b + "}" 33 | else: 34 | if len(substr) > 2: 35 | post_substr = substr[2:] 36 | new_str += "{" + a + "}" + b + post_substr 37 | else: 38 | new_str += "{" + a + "}" + b 39 | string = new_str 40 | return string 41 | 42 | 43 | def _fix_a_slash_b(string): 44 | if len(string.split("/")) != 2: 45 | return string 46 | a = string.split("/")[0] 47 | b = string.split("/")[1] 48 | try: 49 | if "sqrt" not in a: 50 | a = int(a) 51 | if "sqrt" not in b: 52 | b = int(b) 53 | assert string == "{}/{}".format(a, b) 54 | new_string = "\\frac{" + str(a) + "}{" + str(b) + "}" 55 | return new_string 56 | except: 57 | return string 58 | 59 | 60 | def _fix_sqrt(string): 61 | _string = re.sub(r"\\sqrt(\w+)", r"\\sqrt{\1}", string) 62 | return _string 63 | 64 | 65 | def convert_word_number(text: str) -> str: 66 | try: 67 | text = str(w2n.word_to_num(text)) 68 | except: 69 | pass 70 | return text 71 | 72 | 73 | # units mainly from MathQA 74 | unit_texts = [ 75 | "east", 76 | "degree", 77 | "mph", 78 | "kmph", 79 | "ft", 80 | "m sqaure", 81 | " m east", 82 | "sq m", 83 | "deg", 84 | "mile", 85 | "q .", 86 | "monkey", 87 | "prime", 88 | "ratio", 89 | "profit of rs", 90 | "rd", 91 | "o", 92 | "gm", 93 | "p . m", 94 | "lb", 95 | "tile", 96 | "per", 97 | "dm", 98 | "lt", 99 | "gain", 100 | "ab", 101 | "way", 102 | "west", 103 | "a .", 104 | "b .", 105 | "c .", 106 | "d .", 107 | "e .", 108 | "f .", 109 | "g .", 110 | "h .", 111 | "t", 112 | "a", 113 | "h", 114 | "no change", 115 | "men", 116 | "soldier", 117 | "pie", 118 | "bc", 119 | "excess", 120 | "st", 121 | "inches", 122 | "noon", 123 | "percent", 124 | "by", 125 | "gal", 126 | "kmh", 127 | "c", 128 | "acre", 129 | "rise", 130 | "a . m", 131 | "th", 132 | "π r 2", 133 | "sq", 134 | "mark", 135 | "l", 136 | "toy", 137 | "coin", 138 | "sq . 
m", 139 | "gallon", 140 | "° f", 141 | "profit", 142 | "minw", 143 | "yr", 144 | "women", 145 | "feet", 146 | "am", 147 | "pm", 148 | "hr", 149 | "cu cm", 150 | "square", 151 | "v â € ™", 152 | "are", 153 | "rupee", 154 | "rounds", 155 | "cubic", 156 | "cc", 157 | "mtr", 158 | "s", 159 | "ohm", 160 | "number", 161 | "kmph", 162 | "day", 163 | "hour", 164 | "minute", 165 | "min", 166 | "second", 167 | "man", 168 | "woman", 169 | "sec", 170 | "cube", 171 | "mt", 172 | "sq inch", 173 | "mp", 174 | "∏ cm ³", 175 | "hectare", 176 | "more", 177 | "sec", 178 | "unit", 179 | "cu . m", 180 | "cm 2", 181 | "rs .", 182 | "rs", 183 | "kg", 184 | "g", 185 | "month", 186 | "km", 187 | "m", 188 | "cm", 189 | "mm", 190 | "apple", 191 | "liter", 192 | "loss", 193 | "yard", 194 | "pure", 195 | "year", 196 | "increase", 197 | "decrease", 198 | "d", 199 | "less", 200 | "Surface", 201 | "litre", 202 | "pi sq m", 203 | "s .", 204 | "metre", 205 | "meter", 206 | "inch", 207 | ] 208 | 209 | unit_texts.extend([t + "s" for t in unit_texts]) 210 | 211 | 212 | def strip_string(string): 213 | string = str(string).strip() 214 | # linebreaks 215 | string = string.replace("\n", "") 216 | 217 | # right "." 218 | string = string.rstrip(".") 219 | 220 | # remove inverse spaces 221 | # replace \\ with \ 222 | string = string.replace("\\!", "") 223 | # string = string.replace("\\ ", "") 224 | # string = string.replace("\\\\", "\\") 225 | 226 | # matrix 227 | string = re.sub(r"\\begin\{array\}\{.*?\}", r"\\begin{pmatrix}", string) 228 | string = re.sub(r"\\end\{array\}", r"\\end{pmatrix}", string) 229 | string = string.replace("bmatrix", "pmatrix") 230 | 231 | # replace tfrac and dfrac with frac 232 | string = string.replace("tfrac", "frac") 233 | string = string.replace("dfrac", "frac") 234 | string = ( 235 | string.replace("\\neq", "\\ne") 236 | .replace("\\leq", "\\le") 237 | .replace("\\geq", "\\ge") 238 | ) 239 | 240 | # remove \left and \right 241 | string = string.replace("\\left", "") 242 | string = string.replace("\\right", "") 243 | string = string.replace("\\{", "{") 244 | string = string.replace("\\}", "}") 245 | 246 | # Remove unit: miles, dollars if after is not none 247 | _string = re.sub(r"\\text{.*?}$", "", string).strip() 248 | if _string != "" and _string != string: 249 | # print("Warning: unit not removed: '{}' -> '{}'".format(string, _string)) 250 | string = _string 251 | 252 | for unit_text in unit_texts: 253 | # use regex, the prefix should be either the start of the string or a non-alphanumeric character 254 | # the suffix should be either the end of the string or a non-alphanumeric character 255 | _string = re.sub(r"(^|\W)" + unit_text + r"($|\W)", r"\1\2", string) 256 | if _string != "": 257 | string = _string 258 | 259 | # Remove circ (degrees) 260 | string = string.replace("^{\\circ}", "") 261 | string = string.replace("^\\circ", "") 262 | 263 | # remove dollar signs 264 | string = string.replace("\\$", "") 265 | string = string.replace("$", "") 266 | string = string.replace("\\(", "").replace("\\)", "") 267 | 268 | # convert word number to digit 269 | string = convert_word_number(string) 270 | 271 | # replace "\\text{...}" to "..." 
272 | string = re.sub(r"\\text\{(.*?)\}", r"\1", string) 273 | for key in ["x=", "y=", "z=", "x\\in", "y\\in", "z\\in", "x\\to", "y\\to", "z\\to"]: 274 | string = string.replace(key, "") 275 | string = string.replace("\\emptyset", r"{}") 276 | string = string.replace("(-\\infty,\\infty)", "\\mathbb{R}") 277 | 278 | # remove percentage 279 | string = string.replace("\\%", "") 280 | string = string.replace("\%", "") 281 | string = string.replace("%", "") 282 | 283 | months = r"\b(January|February|March|April|May|June|July|August|September|October|November|December)\b" 284 | string = re.sub(months, "", string, flags=re.IGNORECASE) 285 | 286 | # " 0." equivalent to " ." and "{0." equivalent to "{." Alternatively, add "0" if "." is the start of the string 287 | string = string.replace(" .", " 0.") 288 | string = string.replace("{.", "{0.") 289 | 290 | # cdot 291 | # string = string.replace("\\cdot", "") 292 | if ( 293 | string.startswith("{") 294 | and string.endswith("}") 295 | and string.isalnum() 296 | or string.startswith("(") 297 | and string.endswith(")") 298 | and string.isalnum() 299 | or string.startswith("[") 300 | and string.endswith("]") 301 | and string.isalnum() 302 | ): 303 | string = string[1:-1] 304 | 305 | # inf 306 | string = string.replace("infinity", "\\infty") 307 | if "\\infty" not in string: 308 | string = string.replace("inf", "\\infty") 309 | string = string.replace("+\\inity", "\\infty") 310 | 311 | # and 312 | string = string.replace("and", "") 313 | string = string.replace("\\mathbf", "") 314 | 315 | # use regex to remove \mbox{...} 316 | string = re.sub(r"\\mbox{.*?}", "", string) 317 | 318 | # quote 319 | string = string.replace("'", "") 320 | string = string.replace('"', "") 321 | 322 | # i, j 323 | if "j" in string and "i" not in string: 324 | string = string.replace("j", "i") 325 | 326 | # replace a.000b where b is not number or b is end, with ab, use regex 327 | string = re.sub(r"(\d+)\.0*([^\d])", r"\1\2", string) 328 | string = re.sub(r"(\d+)\.0*$", r"\1", string) 329 | 330 | # if empty, return empty string 331 | if len(string) == 0: 332 | return string 333 | if string[0] == ".": 334 | string = "0" + string 335 | 336 | # to consider: get rid of e.g. "k = " or "q = " at beginning 337 | if len(string.split("=")) == 2: 338 | if len(string.split("=")[0]) <= 2: 339 | string = string.split("=")[1] 340 | 341 | string = _fix_sqrt(string) 342 | string = string.replace(" ", "") 343 | 344 | # \frac1b or \frac12 --> \frac{1}{b} and \frac{1}{2}, etc. Even works with \frac1{72} (but not \frac{72}1). Also does a/b --> \\frac{a}{b} 345 | string = _fix_fracs(string) 346 | 347 | # NOTE: X/Y changed to \frac{X}{Y} in dataset, but in simple cases fix in case the model output is X/Y 348 | string = _fix_a_slash_b(string) 349 | 350 | return string 351 | 352 | 353 | def extract_multi_choice_answer(pred_str): 354 | # TODO: SFT models 355 | if "Problem:" in pred_str: 356 | pred_str = pred_str.split("Problem:", 1)[0] 357 | pred_str = pred_str.replace("choice is", "answer is") 358 | patt = regex.search(r"answer is \(?(?P<ans>[abcde])\)?", pred_str.lower()) 359 | if patt is not None: 360 | return patt.group("ans").upper() 361 | return "placeholder" 362 | 363 | 364 | direct_answer_trigger_for_fewshot = ("choice is", "answer is") 365 | 366 | 367 | def choice_answer_clean(pred: str): 368 | pred = pred.strip("\n") 369 | 370 | # Determine if this is ICL, if so, use \n\n to split the first chunk. 
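# ("ICL" = in-context learning: a few-shot prompt repeats the "answer is" / "choice is" trigger once per exemplar, so seeing a trigger more than once means exemplars were echoed, and only the first "\n\n"-separated chunk is kept.)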
371 | ICL = False 372 | for trigger in direct_answer_trigger_for_fewshot: 373 | if pred.count(trigger) > 1: 374 | ICL = True 375 | if ICL: 376 | pred = pred.split("\n\n")[0] 377 | 378 | # Split the trigger to find the answer. 379 | preds = re.split("|".join(direct_answer_trigger_for_fewshot), pred) 380 | if len(preds) > 1: 381 | answer_flag = True 382 | pred = preds[-1] 383 | else: 384 | answer_flag = False 385 | 386 | pred = pred.strip("\n").rstrip(".").rstrip("/").strip(" ").lstrip(":") 387 | 388 | # Clean the answer based on the dataset 389 | tmp = re.findall(r"\b(A|B|C|D|E)\b", pred.upper()) 390 | if tmp: 391 | pred = tmp 392 | else: 393 | pred = [pred.strip().strip(".")] 394 | 395 | if len(pred) == 0: 396 | pred = "" 397 | else: 398 | if answer_flag: 399 | # choose the first element in list ... 400 | pred = pred[0] 401 | else: 402 | # choose the last element in list 403 | pred = pred[-1] 404 | 405 | # Remove the period at the end, again! 406 | pred = pred.rstrip(".").rstrip("/") 407 | 408 | return pred 409 | 410 | 411 | def find_box(pred_str: str): 412 | ans = pred_str.split("boxed")[-1] 413 | if not ans: 414 | return "" 415 | if ans[0] == "{": 416 | stack = 1 417 | a = "" 418 | for c in ans[1:]: 419 | if c == "{": 420 | stack += 1 421 | a += c 422 | elif c == "}": 423 | stack -= 1 424 | if stack == 0: 425 | break 426 | a += c 427 | else: 428 | a += c 429 | else: 430 | a = ans.split("$")[0].strip() 431 | return a 432 | 433 | 434 | def clean_units(pred_str: str): 435 | """Clean the units in the number.""" 436 | 437 | def convert_pi_to_number(code_string): 438 | code_string = code_string.replace("\\pi", "π") 439 | # Replace \pi or π not preceded by a digit or } with 3.14 440 | code_string = re.sub(r"(?<![\d}])\\?π", "3.14", code_string) 441 | # Replace π preceded by a digit with *3.14, e.g., "3π" -> "3*3.14" 442 | code_string = re.sub(r"(\d)(\\?π)", r"\1*3.14", code_string) 443 | # Handle cases where π is within braces or followed by a multiplication symbol 444 | # This replaces "{π}" with "3.14" directly and "3*π" with "3*3.14" 445 | code_string = re.sub(r"\{(\\?π)\}", "3.14", code_string) 446 | code_string = re.sub(r"\*(\\?π)", "*3.14", code_string) 447 | return code_string 448 | 449 | pred_str = convert_pi_to_number(pred_str) 450 | pred_str = pred_str.replace("%", "/100") 451 | pred_str = pred_str.replace("$", "") 452 | pred_str = pred_str.replace("¥", "") 453 | pred_str = pred_str.replace("°C", "") 454 | pred_str = pred_str.replace(" C", "") 455 | pred_str = pred_str.replace("°", "") 456 | return pred_str 457 | 458 | 459 | def extract_theoremqa_answer(pred: str, answer_flag: bool = True): 460 | if any([option in pred.lower() for option in ["yes", "true"]]): 461 | pred = "True" 462 | elif any([option in pred.lower() for option in ["no", "false"]]): 463 | pred = "False" 464 | elif any( 465 | [ 466 | option in pred.lower() 467 | for option in ["(a)", "(b)", "(c)", "(d)", "(e)", "(f)"] 468 | ] 469 | ): 470 | pass 471 | else: 472 | # Some of the models somehow get used to boxed output from pre-training 473 | if "boxed" in pred: 474 | pred = find_box(pred) 475 | 476 | if answer_flag: 477 | # Extract the numbers out of the string 478 | pred = pred.split("=")[-1].strip() 479 | pred = clean_units(pred) 480 | try: 481 | tmp = str(latex2sympy(pred)) 482 | pred = str(eval(tmp)) 483 | except Exception: 484 | if re.match(r"-?[\d\.]+\s\D+$", pred): 485 | pred = pred.split(" ")[0] 486 | elif re.match(r"-?[\d\.]+\s[^\s]+$", pred): 487 | pred = pred.split(" ")[0] 488 | else: 489 | # desperate search over the last number 490 | preds = re.findall(r"-?\d*\.?\d+", pred) 491 | if len(preds) >= 1: 492 | pred = 
preds[-1] 493 | else: 494 | pred = "" 495 | 496 | return pred 497 | 498 | 499 | def extract_answer(pred_str, use_last_number=True): 500 | pred_str = pred_str.replace("\u043a\u0438", "") 501 | 502 | pred = "" 503 | 504 | if "boxed" in pred_str: 505 | ans = pred_str.split("boxed")[-1] 506 | if len(ans) == 0: 507 | return "" 508 | elif ans[0] == "{": 509 | stack = 1 510 | a = "" 511 | for c in ans[1:]: 512 | if c == "{": 513 | stack += 1 514 | a += c 515 | elif c == "}": 516 | stack -= 1 517 | if stack == 0: 518 | break 519 | a += c 520 | else: 521 | a += c 522 | else: 523 | a = ans.split("$")[0].strip() 524 | pred = a 525 | 526 | # multiple line 527 | # pred = pred.split("\n")[0] 528 | pred = re.sub(r"\n\s*", "", pred) 529 | if pred != "" and pred[0] == ":": 530 | pred = pred[1:] 531 | if pred != "" and pred[-1] == ".": 532 | pred = pred[:-1] 533 | if pred != "" and pred[-1] == "/": 534 | pred = pred[:-1] 535 | # pred = strip_string(pred) 536 | return pred 537 | 538 | 539 | STRIP_EXCEPTIONS = ["carp_en", "minerva_math"] 540 | 541 | 542 | def parse_ground_truth(example: Dict[str, Any], data_name): 543 | if "answer" in example: 544 | return None, str(example["answer"]) 545 | else: 546 | return None, None 547 | 548 | 549 | def parse_question(example, data_name=None): 550 | question = "" 551 | for key in ["question", "problem", "Question", "input"]: 552 | if key in example: 553 | question = example[key] 554 | break 555 | 556 | return question.strip() 557 | 558 | 559 | def run_execute(executor, result, prompt_type, data_name, execute=False): 560 | if not result or result == "error": 561 | return None, None 562 | report = None 563 | 564 | if "program_only" in prompt_type: 565 | prediction = extract_program_output(result) 566 | elif prompt_type in ["pot", "pal"] and execute: 567 | code = extract_program(result) 568 | prediction, report = executor.apply(code) 569 | else: 570 | prediction = extract_answer(result, data_name) 571 | 572 | # prediction = strip_string(prediction, skip_unit=data_name == "carp_en") 573 | prediction = strip_string(prediction) 574 | return prediction, report 575 | 576 | 577 | def _test_extract_answer(): 578 | pass 579 | 580 | 581 | if __name__ == "__main__": 582 | _test_extract_answer() 583 | -------------------------------------------------------------------------------- /utils/grader.py: -------------------------------------------------------------------------------- 1 | """ 2 | This logic is largely copied from the Hendrycks' MATH release (math_equivalence), and borrowed from: 3 | - https://github.com/microsoft/ProphetNet/tree/master/CRITIC 4 | - https://github.com/openai/prm800k 5 | - https://github.com/microsoft/ToRA/blob/main/src/eval/grader.py 6 | - https://github.com/deepseek-ai/DeepSeek-Math/blob/main/evaluation/eval/eval_utils.py 7 | """ 8 | 9 | import re 10 | import regex 11 | import multiprocessing 12 | from math import isclose 13 | from typing import Union 14 | from collections import defaultdict 15 | 16 | from sympy import simplify, N 17 | from sympy.parsing.sympy_parser import parse_expr 18 | from sympy.parsing.latex import parse_latex 19 | from latex2sympy2 import latex2sympy 20 | 21 | from .parser import strip_string 22 | # from parser import choice_answer_clean, strip_string 23 | 24 | from .math_normalization import check_sympy_equivalence 25 | 26 | import signal 27 | from concurrent.futures import ThreadPoolExecutor 28 | 29 | def timeout_handler(signum, frame): 30 | raise TimeoutError("Function execution timed out") 31 | 32 | 33 | def 
choice_answer_clean(pred: str): 34 | pred = pred.strip("\n").rstrip(".").rstrip("/").strip(" ").lstrip(":") 35 | # Clean the answer based on the dataset 36 | tmp = re.findall(r"\b(A|B|C|D|E)\b", pred.upper()) 37 | if tmp: 38 | pred = tmp 39 | else: 40 | pred = [pred.strip().strip(".")] 41 | pred = pred[-1] 42 | # Remove the period at the end, again! 43 | pred = pred.rstrip(".").rstrip("/") 44 | return pred 45 | 46 | 47 | def parse_digits(num): 48 | num = regex.sub(",", "", str(num)) 49 | try: 50 | return float(num) 51 | except: 52 | if num.endswith("%"): 53 | num = num[:-1] 54 | if num.endswith("\\"): 55 | num = num[:-1] 56 | try: 57 | return float(num) / 100 58 | except: 59 | pass 60 | return None 61 | 62 | 63 | def is_digit(num): 64 | # paired with parse_digits 65 | return parse_digits(num) is not None 66 | 67 | 68 | def str_to_pmatrix(input_str): 69 | input_str = input_str.strip() 70 | matrix_str = re.findall(r"\{.*,.*\}", input_str) 71 | pmatrix_list = [] 72 | 73 | for m in matrix_str: 74 | m = m.strip("{}") 75 | pmatrix = r"\begin{pmatrix}" + m.replace(",", "\\") + r"\end{pmatrix}" 76 | pmatrix_list.append(pmatrix) 77 | 78 | return ", ".join(pmatrix_list) 79 | 80 | 81 | single_choice_patterns = [ 82 | r"^\(A\)", r"^\(B\)", r"^\(C\)", r"^\(D\)", r"^\(E\)", # (A) (B) (C) (D) (E) 83 | r"^A\.", r"^B\.", r"^C\.", r"^D\.", r"^E\.", # A. B. C. D. E. 84 | r"^A\)", r"^B\)", r"^C\)", r"^D\)", r"^E\)", # A) B) C) D) E) 85 | r"^\*\*A\*\*", r"^\*\*B\*\*", r"^\*\*C\*\*", r"^\*\*D\*\*", r"^\*\*E\*\*", # **A** **B** **C** **D** **E** 86 | r"^A:", r"^B:", r"^C:", r"^D:", r"^E:", # A: B: C: D: E: 87 | ] 88 | 89 | 90 | def math_equal( 91 | prediction: Union[bool, float, str], 92 | reference: Union[float, str], 93 | include_percentage: bool = True, 94 | is_close: bool = True, 95 | timeout: bool = True, 96 | depth: int = 0, 97 | max_depth: int = 5 98 | ) -> bool: 99 | """ 100 | Exact match of math if and only if: 101 | 1. numerical equal: both can convert to float and are equal 102 | 2. 
symbolic equal: both can convert to sympy expression and are equal 103 | """ 104 | 105 | if depth > max_depth: 106 | return False 107 | 108 | 109 | if prediction is None or reference is None: 110 | return False 111 | if str(prediction.strip().lower()) == str(reference.strip().lower()): 112 | return True 113 | if ( 114 | reference in ["A", "B", "C", "D", "E"] 115 | and choice_answer_clean(prediction) == reference 116 | ): 117 | return True 118 | 119 | for pattern in single_choice_patterns: 120 | if regex.match(pattern, prediction): 121 | # Remove the pattern from the beginning of the prediction and strip the result 122 | prediction_cleaned = regex.sub(pattern, "", prediction, count=1).strip() 123 | # Recursively call math_equal to check if the cleaned prediction matches the reference 124 | if math_equal(prediction_cleaned, reference, include_percentage, is_close, timeout=timeout, depth=depth+1, max_depth=max_depth): 125 | return True 126 | 127 | if "," in prediction and "," in reference: 128 | # Split on commas and strip whitespace from each part 129 | pred_parts = [part.strip() for part in prediction.split(",")] 130 | ref_parts = [part.strip() for part in reference.split(",")] 131 | 132 | if len(pred_parts) == len(ref_parts): 133 | # Sort both lists, then compare element by element, using math_equal recursively to test equality 134 | pred_parts_sorted = sorted(pred_parts) 135 | ref_parts_sorted = sorted(ref_parts) 136 | 137 | if all( 138 | math_equal(pred_parts_sorted[i], ref_parts_sorted[i], include_percentage, is_close, timeout=timeout, depth=depth+1, max_depth=max_depth) 139 | for i in range(len(pred_parts_sorted)) 140 | ): 141 | return True 142 | 143 | 144 | 145 | try: # 1. numerical equal 146 | if is_digit(prediction) and is_digit(reference): 147 | prediction = parse_digits(prediction) 148 | reference = parse_digits(reference) 149 | # number questions 150 | if include_percentage: 151 | gt_result = [reference / 100, reference, reference * 100] 152 | else: 153 | gt_result = [reference] 154 | for item in gt_result: 155 | try: 156 | if is_close: 157 | if numeric_equal(prediction, item): 158 | return True 159 | else: 160 | if item == prediction: 161 | return True 162 | except Exception: 163 | continue 164 | return False 165 | except: 166 | pass 167 | 168 | if not prediction and prediction not in [0, False]: 169 | return False 170 | 171 | # 2. symbolic equal 172 | reference = str(reference).strip() 173 | prediction = str(prediction).strip() 174 | 175 | ## pmatrix (amps) 176 | if "pmatrix" in prediction and not "pmatrix" in reference: 177 | reference = str_to_pmatrix(reference) 178 | 179 | ## deal with [], (), {} 180 | pred_str, ref_str = prediction, reference 181 | if ( 182 | prediction.startswith("[") 183 | and prediction.endswith("]") 184 | and not reference.startswith("(") 185 | ) or ( 186 | prediction.startswith("(") 187 | and prediction.endswith(")") 188 | and not reference.startswith("[") 189 | ): 190 | pred_str = pred_str.strip("[]()") 191 | ref_str = ref_str.strip("[]()") 192 | for s in ["{", "}", "(", ")"]: 193 | ref_str = ref_str.replace(s, "") 194 | pred_str = pred_str.replace(s, "") 195 | if pred_str.lower() == ref_str.lower(): 196 | return True 197 | 198 | 199 | ## unordered [a, b] vs. 
[c, d] 200 | # if ( 201 | # regex.match(r"(\(|\[).+(\)|\])", prediction) is not None 202 | # and regex.match(r"(\(|\[).+(\)|\])", reference) is not None 203 | # ): 204 | # pred_parts = prediction[1:-1].split(",") 205 | # ref_parts = reference[1:-1].split(",") 206 | 207 | # if len(pred_parts) == len(ref_parts): 208 | # # Sort the elements of both lists, then compare element by element, using math_equal recursively to test equality 209 | # pred_parts_sorted = sorted(pred_parts, key=lambda x: x.strip()) 210 | # ref_parts_sorted = sorted(ref_parts, key=lambda x: x.strip()) 211 | 212 | # if all( 213 | # [ 214 | # math_equal(pred_parts_sorted[i], ref_parts_sorted[i], include_percentage, is_close, timeout=timeout) 215 | # for i in range(len(pred_parts_sorted)) 216 | # ] 217 | # ): 218 | # return True 219 | 220 | 221 | ## [a, b] vs. [c, d], return a==c and b==d 222 | if ( 223 | regex.match(r"(\(|\[).+(\)|\])", prediction) is not None 224 | and regex.match(r"(\(|\[).+(\)|\])", reference) is not None 225 | ): 226 | pred_parts = prediction[1:-1].split(",") 227 | ref_parts = reference[1:-1].split(",") 228 | if len(pred_parts) == len(ref_parts): 229 | if all( 230 | [ 231 | math_equal( 232 | pred_parts[i], ref_parts[i], include_percentage, is_close, timeout=timeout, depth=depth+1, max_depth=max_depth 233 | ) 234 | for i in range(len(pred_parts)) 235 | ] 236 | ): 237 | return True 238 | if ( 239 | ( 240 | prediction.startswith("\\begin{pmatrix}") 241 | or prediction.startswith("\\begin{bmatrix}") 242 | ) 243 | and ( 244 | prediction.endswith("\\end{pmatrix}") 245 | or prediction.endswith("\\end{bmatrix}") 246 | ) 247 | and ( 248 | reference.startswith("\\begin{pmatrix}") 249 | or reference.startswith("\\begin{bmatrix}") 250 | ) 251 | and ( 252 | reference.endswith("\\end{pmatrix}") or reference.endswith("\\end{bmatrix}") 253 | ) 254 | ): 255 | pred_lines = [ 256 | line.strip() 257 | for line in prediction[ 258 | len("\\begin{pmatrix}") : -len("\\end{pmatrix}") 259 | ].split("\\\\") 260 | if line.strip() 261 | ] 262 | ref_lines = [ 263 | line.strip() 264 | for line in reference[ 265 | len("\\begin{pmatrix}") : -len("\\end{pmatrix}") 266 | ].split("\\\\") 267 | if line.strip() 268 | ] 269 | matched = True 270 | if len(pred_lines) == len(ref_lines): 271 | for pred_line, ref_line in zip(pred_lines, ref_lines): 272 | pred_parts = pred_line.split("&") 273 | ref_parts = ref_line.split("&") 274 | if len(pred_parts) == len(ref_parts): 275 | if not all( 276 | [ 277 | math_equal( 278 | pred_parts[i], 279 | ref_parts[i], 280 | include_percentage, 281 | is_close, 282 | timeout=timeout, 283 | depth=depth+1, 284 | max_depth=max_depth 285 | ) 286 | for i in range(len(pred_parts)) 287 | ] 288 | ): 289 | matched = False 290 | break 291 | else: 292 | matched = False 293 | if not matched: 294 | break 295 | else: 296 | matched = False 297 | if matched: 298 | return True 299 | 300 | if prediction.count("=") == 1 and reference.count("=") == 1: 301 | pred = prediction.split("=") 302 | pred = f"{pred[0].strip()} - ({pred[1].strip()})" 303 | ref = reference.split("=") 304 | ref = f"{ref[0].strip()} - ({ref[1].strip()})" 305 | if symbolic_equal(pred, ref) or symbolic_equal(f"-({pred})", ref): 306 | return True 307 | elif ( 308 | prediction.count("=") == 1 309 | and len(prediction.split("=")[0].strip()) <= 2 310 | and "=" not in reference 311 | ): 312 | if math_equal( 313 | prediction.split("=")[1], reference, include_percentage, is_close, timeout=timeout, depth=depth+1, max_depth=max_depth 314 | ): 315 | return True 316 | elif ( 317 | reference.count("=") == 1 318 | and 
len(reference.split("=")[0].strip()) <= 2 319 | and "=" not in prediction 320 | ): 321 | if math_equal( 322 | prediction, reference.split("=")[1], include_percentage, is_close, timeout=timeout, depth=depth+1, max_depth=max_depth 323 | ): 324 | return True 325 | 326 | if timeout: 327 | if call_with_timeout(symbolic_equal_process, prediction, reference): 328 | return True 329 | # try: 330 | # if call_with_timeout(symbolic_equal, prediction, reference, timeout=1): 331 | # return True 332 | # except TimeoutError: 333 | # return False 334 | else: 335 | if symbolic_equal(prediction, reference): 336 | return True 337 | 338 | return False 339 | 340 | 341 | def math_equal_process(param): 342 | return math_equal(param[-2], param[-1]) 343 | 344 | 345 | def numeric_equal(prediction: float, reference: float): 346 | # Note that relative tolerance has significant impact 347 | # on the result of the synthesized GSM-Hard dataset 348 | # if reference.is_integer(): 349 | # return isclose(reference, round(prediction), abs_tol=1e-4) 350 | # else: 351 | # prediction = round(prediction, len(str(reference).split(".")[-1])) 352 | 353 | # return isclose(reference, prediction, rel_tol=1e-4) 354 | return isclose(reference, prediction, abs_tol=1e-4) 355 | 356 | 357 | def symbolic_equal(a, b): 358 | def _parse(s): 359 | for f in [parse_latex, parse_expr, latex2sympy]: 360 | try: 361 | return f(s.replace("\\\\", "\\")) 362 | except: 363 | try: 364 | return f(s) 365 | except: 366 | pass 367 | return s 368 | 369 | a = _parse(a) 370 | b = _parse(b) 371 | 372 | # direct equal 373 | try: 374 | if str(a) == str(b) or a == b: 375 | return True 376 | except: 377 | pass 378 | 379 | # simplify equal 380 | try: 381 | if a.equals(b) or simplify(a - b) == 0: 382 | return True 383 | except: 384 | pass 385 | 386 | # equation equal 387 | try: 388 | if (abs(a.lhs - a.rhs)).equals(abs(b.lhs - b.rhs)): 389 | return True 390 | except: 391 | pass 392 | 393 | try: 394 | if numeric_equal(float(N(a)), float(N(b))): 395 | return True 396 | except: 397 | pass 398 | 399 | # matrix 400 | try: 401 | # if a and b are matrix 402 | if a.shape == b.shape: 403 | _a = a.applyfunc(lambda x: round(x, 3)) 404 | _b = b.applyfunc(lambda x: round(x, 3)) 405 | if _a.equals(_b): 406 | return True 407 | except: 408 | pass 409 | 410 | return False 411 | 412 | 413 | def symbolic_equal_process(a, b, output_queue): 414 | result = symbolic_equal(a, b) 415 | output_queue.put(result) 416 | 417 | 418 | def call_with_timeout(func, *args, timeout=3, **kwargs): 419 | output_queue = multiprocessing.Queue() 420 | process_args = args + (output_queue,) 421 | process = multiprocessing.Process(target=func, args=process_args, kwargs=kwargs) 422 | process.start() 423 | process.join(timeout) 424 | 425 | if process.is_alive(): 426 | process.terminate() 427 | process.join() 428 | return False 429 | 430 | return output_queue.get() 431 | 432 | # def call_with_timeout(func, *args, timeout=1, **kwargs): 433 | # # Register the signal function handler 434 | # signal.signal(signal.SIGALRM, timeout_handler) 435 | # # Set the alarm 436 | # signal.alarm(timeout) 437 | 438 | # try: 439 | # result = func(*args, **kwargs) 440 | # signal.alarm(0) # Disable the alarm if function completes in time 441 | # return result 442 | # except TimeoutError: 443 | # return False 444 | # finally: 445 | # # Ensure the alarm is disabled 446 | # signal.alarm(0) 447 | 448 | # def call_with_timeout(func, *args, timeout=1, **kwargs): 449 | # with ThreadPoolExecutor(max_workers=1) as executor: 450 | # future = 
executor.submit(func, *args, **kwargs) 451 | # try: 452 | # result = future.result(timeout=timeout) # Wait for result with a timeout 453 | # return result 454 | # except TimeoutError: 455 | # return False # Timeout occurred 456 | 457 | 458 | 459 | def check_is_correct(pred, gt, timeout=True): 460 | return math_equal(strip_string(pred), strip_string(gt), timeout=timeout) 461 | 462 | 463 | def math_equal_simple(pred, gt): 464 | pred = strip_string(pred) 465 | gt = strip_string(gt) 466 | flag = False 467 | 468 | try: 469 | pred_expr = latex2sympy(pred) 470 | except: 471 | pred_expr = pred 472 | flag = True 473 | 474 | try: 475 | gt_expr = latex2sympy(gt) 476 | except: 477 | gt_expr = gt 478 | flag = True 479 | 480 | if flag == True: 481 | return pred == gt 482 | 483 | try: 484 | if abs(N(pred_expr) - N(gt_expr)) <= 1e-5: 485 | return True 486 | except: 487 | return False 488 | 489 | return False 490 | 491 | 492 | def check_is_correct_simple(pred, gt, timeout=True): 493 | if timeout: 494 | return call_with_timeout(math_equal_simple, pred, gt, timeout=1) 495 | else: 496 | return math_equal_simple(pred, gt) 497 | 498 | def _test_math_equal(): 499 | 500 | # gt = "\\begin{pmatrix} -10 \\\\ 6 \\end{pmatrix}" 501 | # pred = "\\begin{pmatrix}-10\\\\6\\end{pmatrix}" 502 | 503 | # gt = "(6, -\\frac{3}{8})" 504 | # pred = "\left( 6, -\\frac{3}{8} \\right)" 505 | 506 | # print(math_equal(strip_string(pred), strip_string(gt), timeout=False)) 507 | 508 | s = "(A) 3" 509 | print(choice_answer_clean(s)) 510 | 511 | 512 | 513 | 514 | if __name__ == "__main__": 515 | _test_math_equal() 516 | -------------------------------------------------------------------------------- /data/amc/test.jsonl: -------------------------------------------------------------------------------- 1 | {"id":0,"problem":"Cities $A$ and $B$ are $45$ miles apart. Alicia lives in $A$ and Beth lives in $B$. Alicia bikes towards $B$ at 18 miles per hour. Leaving at the same time, Beth bikes toward $A$ at 12 miles per hour. How many miles from City $A$ will they be when they meet?","answer":27.0,"url":"https:\/\/artofproblemsolving.com\/wiki\/index.php\/2023_AMC_12A_Problems\/Problem_1","question":"Cities $A$ and $B$ are $45$ miles apart. Alicia lives in $A$ and Beth lives in $B$. Alicia bikes towards $B$ at 18 miles per hour. Leaving at the same time, Beth bikes toward $A$ at 12 miles per hour. How many miles from City $A$ will they be when they meet?"} 2 | {"id":1,"problem":"Positive real numbers $x$ and $y$ satisfy $y^3=x^2$ and $(y-x)^2=4y^2$. What is $x+y$?","answer":36.0,"url":"https:\/\/artofproblemsolving.com\/wiki\/index.php\/2023_AMC_12A_Problems\/Problem_10","question":"Positive real numbers $x$ and $y$ satisfy $y^3=x^2$ and $(y-x)^2=4y^2$. What is $x+y$?"} 3 | {"id":2,"problem":"What is the degree measure of the acute angle formed by lines with slopes $2$ and $\\frac{1}{3}$?","answer":45.0,"url":"https:\/\/artofproblemsolving.com\/wiki\/index.php\/2023_AMC_12A_Problems\/Problem_11","question":"What is the degree measure of the acute angle formed by lines with slopes $2$ and $\\frac{1}{3}$?"} 4 | {"id":3,"problem":"What is the value of\n\\[2^3 - 1^3 + 4^3 - 3^3 + 6^3 - 5^3 + \\dots + 18^3 - 17^3?\\]","answer":3159.0,"url":"https:\/\/artofproblemsolving.com\/wiki\/index.php\/2023_AMC_12A_Problems\/Problem_12","question":"What is the value of\n\\[2^3 - 1^3 + 4^3 - 3^3 + 6^3 - 5^3 + \\dots + 18^3 - 17^3?\\]"} 5 | {"id":4,"problem":"In a table tennis tournament every participant played every other participant exactly once. 
Although there were twice as many right-handed players as left-handed players, the number of games won by left-handed players was $40\\%$ more than the number of games won by right-handed players. (There were no ties and no ambidextrous players.) What is the total number of games played?","answer":36.0,"url":"https:\/\/artofproblemsolving.com\/wiki\/index.php\/2023_AMC_12A_Problems\/Problem_13","question":"In a table tennis tournament every participant played every other participant exactly once. Although there were twice as many right-handed players as left-handed players, the number of games won by left-handed players was $40\\%$ more than the number of games won by right-handed players. (There were no ties and no ambidextrous players.) What is the total number of games played?"} 6 | {"id":5,"problem":"How many complex numbers satisfy the equation $z^5=\\overline{z}$, where $\\overline{z}$ is the conjugate of the complex number $z$?","answer":7.0,"url":"https:\/\/artofproblemsolving.com\/wiki\/index.php\/2023_AMC_12A_Problems\/Problem_14","question":"How many complex numbers satisfy the equation $z^5=\\overline{z}$, where $\\overline{z}$ is the conjugate of the complex number $z$?"} 7 | {"id":7,"problem":"Consider the set of complex numbers $z$ satisfying $|1+z+z^{2}|=4$. The maximum value of the imaginary part of $z$ can be written in the form $\\tfrac{\\sqrt{m}}{n}$, where $m$ and $n$ are relatively prime positive integers. What is $m+n$?","answer":21.0,"url":"https:\/\/artofproblemsolving.com\/wiki\/index.php\/2023_AMC_12A_Problems\/Problem_16","question":"Consider the set of complex numbers $z$ satisfying $|1+z+z^{2}|=4$. The maximum value of the imaginary part of $z$ can be written in the form $\\tfrac{\\sqrt{m}}{n}$, where $m$ and $n$ are relatively prime positive integers. What is $m+n$?"} 8 | {"id":8,"problem":"Flora the frog starts at 0 on the number line and makes a sequence of jumps to the right. In any one jump, independent of previous jumps, Flora leaps a positive integer distance $m$ with probability $\\frac{1}{2^m}$.\nWhat is the probability that Flora will eventually land at 10? Write the answer as a simplified fraction $\\frac{m}{n}$, find $m+n$","answer":3.0,"url":"https:\/\/artofproblemsolving.com\/wiki\/index.php\/2023_AMC_12A_Problems\/Problem_17","question":"Flora the frog starts at 0 on the number line and makes a sequence of jumps to the right. In any one jump, independent of previous jumps, Flora leaps a positive integer distance $m$ with probability $\\frac{1}{2^m}$.\nWhat is the probability that Flora will eventually land at 10? Write the answer as a simplified fraction $\\frac{m}{n}$, find $m+n$"} 9 | {"id":10,"problem":"What is the product of all solutions to the equation\n\\[\\log_{7x}2023\\cdot \\log_{289x}2023=\\log_{2023x}2023\\]","answer":1.0,"url":"https:\/\/artofproblemsolving.com\/wiki\/index.php\/2023_AMC_12A_Problems\/Problem_19","question":"What is the product of all solutions to the equation\n\\[\\log_{7x}2023\\cdot \\log_{289x}2023=\\log_{2023x}2023\\]"} 10 | {"id":11,"problem":"The weight of $\\frac{1}{3}$ of a large pizza together with $3 \\frac{1}{2}$ cups of orange slices is the same as the weight of $\\frac{3}{4}$ of a large pizza together with $\\frac{1}{2}$ cup of orange slices. A cup of orange slices weighs $\\frac{1}{4}$ of a pound. What is the weight, in pounds, of a large pizza? The answer can be written in the form $\\frac{m}{n}$, where $m$ and $n$ are relatively prime positive integers. 
What is $m-n$?","answer":4.0,"url":"https:\/\/artofproblemsolving.com\/wiki\/index.php\/2023_AMC_12A_Problems\/Problem_2","question":"The weight of $\\frac{1}{3}$ of a large pizza together with $3 \\frac{1}{2}$ cups of orange slices is the same as the weight of $\\frac{3}{4}$ of a large pizza together with $\\frac{1}{2}$ cup of orange slices. A cup of orange slices weighs $\\frac{1}{4}$ of a pound. What is the weight, in pounds, of a large pizza? The answer can be written in the form $\\frac{m}{n}$, where $m$ and $n$ are relatively prime positive integers. What is $m-n$?"} 11 | {"id":12,"problem":"Rows 1, 2, 3, 4, and 5 of a triangular array of integers are shown below.\n1\n1 1\n1 3 1\n1 5 5 1\n1 7 11 7 1\nEach row after the first row is formed by placing a 1 at each end of the row, and each interior entry is 1 greater than the sum of the two numbers diagonally above it in the previous row. What is the units digits of the sum of the 2023 numbers in the 2023rd row?","answer":5.0,"url":"https:\/\/artofproblemsolving.com\/wiki\/index.php\/2023_AMC_12A_Problems\/Problem_20","question":"Rows 1, 2, 3, 4, and 5 of a triangular array of integers are shown below.\n1\n1 1\n1 3 1\n1 5 5 1\n1 7 11 7 1\nEach row after the first row is formed by placing a 1 at each end of the row, and each interior entry is 1 greater than the sum of the two numbers diagonally above it in the previous row. What is the units digits of the sum of the 2023 numbers in the 2023rd row?"} 12 | {"id":13,"problem":"If $A$ and $B$ are vertices of a polyhedron, define the distance $d(A,B)$ to be the minimum number of edges of the polyhedron one must traverse in order to connect $A$ and $B$. For example, if $\\overline{AB}$ is an edge of the polyhedron, then $d(A, B) = 1$, but if $\\overline{AC}$ and $\\overline{CB}$ are edges and $\\overline{AB}$ is not an edge, then $d(A, B) = 2$. Let $Q$, $R$, and $S$ be randomly chosen distinct vertices of a regular icosahedron (regular polyhedron made up of 20 equilateral triangles). Find the probability that $d(Q, R) > d(R, S)$. The final answer can be written in the form $\\frac{m}{n}$, where $m$ and $n$ are relatively prime positive integers. What is $m+n$?","answer":29.0,"url":"https:\/\/artofproblemsolving.com\/wiki\/index.php\/2023_AMC_12A_Problems\/Problem_21","question":"If $A$ and $B$ are vertices of a polyhedron, define the distance $d(A,B)$ to be the minimum number of edges of the polyhedron one must traverse in order to connect $A$ and $B$. For example, if $\\overline{AB}$ is an edge of the polyhedron, then $d(A, B) = 1$, but if $\\overline{AC}$ and $\\overline{CB}$ are edges and $\\overline{AB}$ is not an edge, then $d(A, B) = 2$. Let $Q$, $R$, and $S$ be randomly chosen distinct vertices of a regular icosahedron (regular polyhedron made up of 20 equilateral triangles). Find the probability that $d(Q, R) > d(R, S)$. The final answer can be written in the form $\\frac{m}{n}$, where $m$ and $n$ are relatively prime positive integers. What is $m+n$?"} 13 | {"id":14,"problem":"Let $f$ be the unique function defined on the positive integers such that \\[\\sum_{d\\mid n}d\\cdot f\\left(\\frac{n}{d}\\right)=1\\] for all positive integers $n$. What is $f(2023)$?","answer":96.0,"url":"https:\/\/artofproblemsolving.com\/wiki\/index.php\/2023_AMC_12A_Problems\/Problem_22","question":"Let $f$ be the unique function defined on the positive integers such that \\[\\sum_{d\\mid n}d\\cdot f\\left(\\frac{n}{d}\\right)=1\\] for all positive integers $n$. 
What is $f(2023)$?"} 14 | {"id":15,"problem":"How many ordered pairs of positive real numbers $(a,b)$ satisfy the equation\n\\[(1+2a)(2+2b)(2a+b) = 32ab?\\]","answer":1.0,"url":"https:\/\/artofproblemsolving.com\/wiki\/index.php\/2023_AMC_12A_Problems\/Problem_23","question":"How many ordered pairs of positive real numbers $(a,b)$ satisfy the equation\n\\[(1+2a)(2+2b)(2a+b) = 32ab?\\]"} 15 | {"id":16,"problem":"Let $K$ be the number of sequences $A_1$, $A_2$, $\\dots$, $A_n$ such that $n$ is a positive integer less than or equal to $10$, each $A_i$ is a subset of $\\{1, 2, 3, \\dots, 10\\}$, and $A_{i-1}$ is a subset of $A_i$ for each $i$ between $2$ and $n$, inclusive. For example, $\\{\\}$, $\\{5, 7\\}$, $\\{2, 5, 7\\}$, $\\{2, 5, 7\\}$, $\\{2, 5, 6, 7, 9\\}$ is one such sequence, with $n = 5$.What is the remainder when $K$ is divided by $10$?","answer":5.0,"url":"https:\/\/artofproblemsolving.com\/wiki\/index.php\/2023_AMC_12A_Problems\/Problem_24","question":"Let $K$ be the number of sequences $A_1$, $A_2$, $\\dots$, $A_n$ such that $n$ is a positive integer less than or equal to $10$, each $A_i$ is a subset of $\\{1, 2, 3, \\dots, 10\\}$, and $A_{i-1}$ is a subset of $A_i$ for each $i$ between $2$ and $n$, inclusive. For example, $\\{\\}$, $\\{5, 7\\}$, $\\{2, 5, 7\\}$, $\\{2, 5, 7\\}$, $\\{2, 5, 6, 7, 9\\}$ is one such sequence, with $n = 5$.What is the remainder when $K$ is divided by $10$?"} 16 | {"id":17,"problem":"There is a unique sequence of integers $a_1, a_2, \\cdots a_{2023}$ such that\n\\[\\tan2023x = \\frac{a_1 \\tan x + a_3 \\tan^3 x + a_5 \\tan^5 x + \\cdots + a_{2023} \\tan^{2023} x}{1 + a_2 \\tan^2 x + a_4 \\tan^4 x \\cdots + a_{2022} \\tan^{2022} x}\\]whenever $\\tan 2023x$ is defined. What is $a_{2023}?$","answer":-1.0,"url":"https:\/\/artofproblemsolving.com\/wiki\/index.php\/2023_AMC_12A_Problems\/Problem_25","question":"There is a unique sequence of integers $a_1, a_2, \\cdots a_{2023}$ such that\n\\[\\tan2023x = \\frac{a_1 \\tan x + a_3 \\tan^3 x + a_5 \\tan^5 x + \\cdots + a_{2023} \\tan^{2023} x}{1 + a_2 \\tan^2 x + a_4 \\tan^4 x \\cdots + a_{2022} \\tan^{2022} x}\\]whenever $\\tan 2023x$ is defined. What is $a_{2023}?$"} 17 | {"id":18,"problem":"How many positive perfect squares less than $2023$ are divisible by $5$?","answer":8.0,"url":"https:\/\/artofproblemsolving.com\/wiki\/index.php\/2023_AMC_12A_Problems\/Problem_3","question":"How many positive perfect squares less than $2023$ are divisible by $5$?"} 18 | {"id":19,"problem":"How many digits are in the base-ten representation of $8^5 \\cdot 5^{10} \\cdot 15^5$?","answer":18.0,"url":"https:\/\/artofproblemsolving.com\/wiki\/index.php\/2023_AMC_12A_Problems\/Problem_4","question":"How many digits are in the base-ten representation of $8^5 \\cdot 5^{10} \\cdot 15^5$?"} 19 | {"id":20,"problem":"Janet rolls a standard $6$-sided die $4$ times and keeps a running total of the numbers she rolls. What is the probability that at some point, her running total will equal $3$? The final answer can be written in the form $\\frac{m}{n}$, where $m$ and $n$ are relatively prime positive integers. What is $m+n$?","answer":265.0,"url":"https:\/\/artofproblemsolving.com\/wiki\/index.php\/2023_AMC_12A_Problems\/Problem_5","question":"Janet rolls a standard $6$-sided die $4$ times and keeps a running total of the numbers she rolls. What is the probability that at some point, her running total will equal $3$? The final answer can be written in the form $\\frac{m}{n}$, where $m$ and $n$ are relatively prime positive integers. 
What is $m+n$?"} 20 | {"id":21,"problem":"Points $A$ and $B$ lie on the graph of $y=\\log_{2}x$. The midpoint of $\\overline{AB}$ is $(6, 2)$. What is the positive difference between the $x$-coordinates of $A$ and $B$? The final answer can be written in the form $m \\sqrt{n}$, where $m$ and $n$ are relatively prime positive integers. What is $m+n$?","answer":9.0,"url":"https:\/\/artofproblemsolving.com\/wiki\/index.php\/2023_AMC_12A_Problems\/Problem_6","question":"Points $A$ and $B$ lie on the graph of $y=\\log_{2}x$. The midpoint of $\\overline{AB}$ is $(6, 2)$. What is the positive difference between the $x$-coordinates of $A$ and $B$? The final answer can be written in the form $m \\sqrt{n}$, where $m$ and $n$ are relatively prime positive integers. What is $m+n$?"} 21 | {"id":22,"problem":"A digital display shows the current date as an $8$-digit integer consisting of a $4$-digit year, followed by a $2$-digit month, followed by a $2$-digit date within the month. For example, Arbor Day this year is displayed as 20230428. For how many dates in $2023$ will each digit appear an even number of times in the 8-digital display for that date?","answer":9.0,"url":"https:\/\/artofproblemsolving.com\/wiki\/index.php\/2023_AMC_12A_Problems\/Problem_7","question":"A digital display shows the current date as an $8$-digit integer consisting of a $4$-digit year, followed by a $2$-digit month, followed by a $2$-digit date within the month. For example, Arbor Day this year is displayed as 20230428. For how many dates in $2023$ will each digit appear an even number of times in the 8-digital display for that date?"} 22 | {"id":23,"problem":"Maureen is keeping track of the mean of her quiz scores this semester. If Maureen scores an $11$ on the next quiz, her mean will increase by $1$. If she scores an $11$ on each of the next three quizzes, her mean will increase by $2$. What is the mean of her quiz scores currently?","answer":7.0,"url":"https:\/\/artofproblemsolving.com\/wiki\/index.php\/2023_AMC_12A_Problems\/Problem_8","question":"Maureen is keeping track of the mean of her quiz scores this semester. If Maureen scores an $11$ on the next quiz, her mean will increase by $1$. If she scores an $11$ on each of the next three quizzes, her mean will increase by $2$. What is the mean of her quiz scores currently?"} 23 | {"id":25,"problem":"Mrs. Jones is pouring orange juice into four identical glasses for her four sons. She fills the first three glasses completely but runs out of juice when the fourth glass is only $\\frac{1}{3}$ full. What fraction of a glass must Mrs. Jones pour from each of the first three glasses into the fourth glass so that all four glasses will have the same amount of juice? The final answer can be written in the form $\\frac{m}{n}$, where $m$ and $n$ are relatively prime positive integers. What is $m+n$?","answer":7.0,"url":"https:\/\/artofproblemsolving.com\/wiki\/index.php\/2023_AMC_12B_Problems\/Problem_1","question":"Mrs. Jones is pouring orange juice into four identical glasses for her four sons. She fills the first three glasses completely but runs out of juice when the fourth glass is only $\\frac{1}{3}$ full. What fraction of a glass must Mrs. Jones pour from each of the first three glasses into the fourth glass so that all four glasses will have the same amount of juice? The final answer can be written in the form $\\frac{m}{n}$, where $m$ and $n$ are relatively prime positive integers. 
What is $m+n$?"} 24 | {"id":26,"problem":"In the $xy$-plane, a circle of radius $4$ with center on the positive $x$-axis is tangent to the $y$-axis at the origin, and a circle with radius $10$ with center on the positive $y$-axis is tangent to the $x$-axis at the origin. What is the slope of the line passing through the two points at which these circles intersect? The final answer can be written in the form $\\frac{m}{n}$, where $m$ and $n$ are relatively prime positive integers. What is $m+n$?","answer":7.0,"url":"https:\/\/artofproblemsolving.com\/wiki\/index.php\/2023_AMC_12B_Problems\/Problem_10","question":"In the $xy$-plane, a circle of radius $4$ with center on the positive $x$-axis is tangent to the $y$-axis at the origin, and a circle with radius $10$ with center on the positive $y$-axis is tangent to the $x$-axis at the origin. What is the slope of the line passing through the two points at which these circles intersect? The final answer can be written in the form $\\frac{m}{n}$, where $m$ and $n$ are relatively prime positive integers. What is $m+n$?"} 25 | {"id":27,"problem":"Calculate the maximum area of an isosceles trapezoid that has legs of length $1$ and one base twice as long as the other. The final answer can be written in the form $\\frac{m}{n}$, where $m$ and $n$ are relatively prime positive integers. What is $m^2+n^2$?","answer":13.0,"url":"https:\/\/artofproblemsolving.com\/wiki\/index.php\/2023_AMC_12B_Problems\/Problem_11","question":"Calculate the maximum area of an isosceles trapezoid that has legs of length $1$ and one base twice as long as the other. The final answer can be written in the form $\\frac{m}{n}$, where $m$ and $n$ are relatively prime positive integers. What is $m^2+n^2$?"} 26 | {"id":28,"problem":"For complex number $u = a+bi$ and $v = c+di$ (where $i=\\sqrt{-1}$), define the binary operation\n$u \\otimes v = ac + bdi$\nSuppose $z$ is a complex number such that $z\\otimes z = z^{2}+40$. What is $|z|^2$?","answer":50.0,"url":"https:\/\/artofproblemsolving.com\/wiki\/index.php\/2023_AMC_12B_Problems\/Problem_12","question":"For complex number $u = a+bi$ and $v = c+di$ (where $i=\\sqrt{-1}$), define the binary operation\n$u \\otimes v = ac + bdi$\nSuppose $z$ is a complex number such that $z\\otimes z = z^{2}+40$. What is $|z|^2$?"} 27 | {"id":29,"problem":"A rectangular box $P$ has distinct edge lengths $a$, $b$, and $c$. The sum of the lengths of all $12$ edges of $P$ is $13$, the areas of all $6$ faces of $P$ is $\\frac{11}{2}$, and the volume of $P$ is $\\frac{1}{2}$. Find the length of the longest interior diagonal connecting two vertices of $P$. The final answer can be written in the form $\\frac{m}{n}$, where $m$ and $n$ are relatively prime positive integers. What is $m+n$?","answer":13.0,"url":"https:\/\/artofproblemsolving.com\/wiki\/index.php\/2023_AMC_12B_Problems\/Problem_13","question":"A rectangular box $P$ has distinct edge lengths $a$, $b$, and $c$. The sum of the lengths of all $12$ edges of $P$ is $13$, the areas of all $6$ faces of $P$ is $\\frac{11}{2}$, and the volume of $P$ is $\\frac{1}{2}$. Find the length of the longest interior diagonal connecting two vertices of $P$. The final answer can be written in the form $\\frac{m}{n}$, where $m$ and $n$ are relatively prime positive integers. 
What is $m+n$?"} 28 | {"id":30,"problem":"For how many ordered pairs $(a,b)$ of integers does the polynomial $x^3+ax^2+bx+6$ have $3$ distinct integer roots?","answer":5.0,"url":"https:\/\/artofproblemsolving.com\/wiki\/index.php\/2023_AMC_12B_Problems\/Problem_14","question":"For how many ordered pairs $(a,b)$ of integers does the polynomial $x^3+ax^2+bx+6$ have $3$ distinct integer roots?"} 29 | {"id":32,"problem":"In the state of Coinland, coins have values $6,10,$ and $15$ cents. Suppose $x$ is the value in cents of the most expensive item in Coinland that cannot be purchased using these coins with exact change. What is the sum of the digits of $x?$","answer":11.0,"url":"https:\/\/artofproblemsolving.com\/wiki\/index.php\/2023_AMC_12B_Problems\/Problem_16","question":"In the state of Coinland, coins have values $6,10,$ and $15$ cents. Suppose $x$ is the value in cents of the most expensive item in Coinland that cannot be purchased using these coins with exact change. What is the sum of the digits of $x?$"} 30 | {"id":33,"problem":"Triangle $ABC$ has side lengths in arithmetic progression, and the smallest side has length $6.$ If the triangle has an angle of $120^\\circ,$ Find the area of $ABC$. The final answer can be simplified in the form $m \\sqrt{n}$, where $m$ and $n$ are positive integers and $n$ without square factore. What is $m+n$?","answer":18.0,"url":"https:\/\/artofproblemsolving.com\/wiki\/index.php\/2023_AMC_12B_Problems\/Problem_17","question":"Triangle $ABC$ has side lengths in arithmetic progression, and the smallest side has length $6.$ If the triangle has an angle of $120^\\circ,$ Find the area of $ABC$. The final answer can be simplified in the form $m \\sqrt{n}$, where $m$ and $n$ are positive integers and $n$ without square factore. What is $m+n$?"} 31 | {"id":36,"problem":"Carlos went to a sports store to buy running shoes. Running shoes were on sale, with prices reduced by $20\\%$ on every pair of shoes. Carlos also knew that he had to pay a $7.5\\%$ sales tax on the discounted price. He had $$43$ dollars. What is the original (before discount) price of the most expensive shoes he could afford to buy? ","answer":50.0,"url":"https:\/\/artofproblemsolving.com\/wiki\/index.php\/2023_AMC_12B_Problems\/Problem_2","question":"Carlos went to a sports store to buy running shoes. Running shoes were on sale, with prices reduced by $20\\%$ on every pair of shoes. Carlos also knew that he had to pay a $7.5\\%$ sales tax on the discounted price. He had $$43$ dollars. What is the original (before discount) price of the most expensive shoes he could afford to buy? "} 32 | {"id":40,"problem":"When $n$ standard six-sided dice are rolled, the product of the numbers rolled can be any of $936$ possible values. What is $n$?","answer":11.0,"url":"https:\/\/artofproblemsolving.com\/wiki\/index.php\/2023_AMC_12B_Problems\/Problem_23","question":"When $n$ standard six-sided dice are rolled, the product of the numbers rolled can be any of $936$ possible values. 
What is $n$?"} 33 | {"id":41,"problem":"Suppose that $a$, $b$, $c$ and $d$ are positive integers satisfying all of the following relations.\n\\[abcd=2^6\\cdot 3^9\\cdot 5^7\\]\n\\[\\text{lcm}(a,b)=2^3\\cdot 3^2\\cdot 5^3\\]\n\\[\\text{lcm}(a,c)=2^3\\cdot 3^3\\cdot 5^3\\]\n\\[\\text{lcm}(a,d)=2^3\\cdot 3^3\\cdot 5^3\\]\n\\[\\text{lcm}(b,c)=2^1\\cdot 3^3\\cdot 5^2\\]\n\\[\\text{lcm}(b,d)=2^2\\cdot 3^3\\cdot 5^2\\]\n\\[\\text{lcm}(c,d)=2^2\\cdot 3^3\\cdot 5^2\\]\nWhat is $\\text{gcd}(a,b,c,d)$?","answer":3.0,"url":"https:\/\/artofproblemsolving.com\/wiki\/index.php\/2023_AMC_12B_Problems\/Problem_24","question":"Suppose that $a$, $b$, $c$ and $d$ are positive integers satisfying all of the following relations.\n\\[abcd=2^6\\cdot 3^9\\cdot 5^7\\]\n\\[\\text{lcm}(a,b)=2^3\\cdot 3^2\\cdot 5^3\\]\n\\[\\text{lcm}(a,c)=2^3\\cdot 3^3\\cdot 5^3\\]\n\\[\\text{lcm}(a,d)=2^3\\cdot 3^3\\cdot 5^3\\]\n\\[\\text{lcm}(b,c)=2^1\\cdot 3^3\\cdot 5^2\\]\n\\[\\text{lcm}(b,d)=2^2\\cdot 3^3\\cdot 5^2\\]\n\\[\\text{lcm}(c,d)=2^2\\cdot 3^3\\cdot 5^2\\]\nWhat is $\\text{gcd}(a,b,c,d)$?"} 34 | {"id":43,"problem":"A $3-4-5$ right triangle is inscribed in circle $A$, and a $5-12-13$ right triangle is inscribed in circle $B$. Find the ratio of the area of circle $A$ to the area of circle $B$. The final answer can be written in the form $\\frac{m}{n}$, where $m$ and $n$ are relatively prime positive integers. What is $m+n$?","answer":194.0,"url":"https:\/\/artofproblemsolving.com\/wiki\/index.php\/2023_AMC_12B_Problems\/Problem_3","question":"A $3-4-5$ right triangle is inscribed in circle $A$, and a $5-12-13$ right triangle is inscribed in circle $B$. Find the ratio of the area of circle $A$ to the area of circle $B$. The final answer can be written in the form $\\frac{m}{n}$, where $m$ and $n$ are relatively prime positive integers. What is $m+n$?"} 35 | {"id":44,"problem":"Jackson's paintbrush makes a narrow strip with a width of $6.5$ millimeters. Jackson has enough paint to make a strip $25$ meters long. How many square centimeters of paper could Jackson cover with paint?","answer":1625.0,"url":"https:\/\/artofproblemsolving.com\/wiki\/index.php\/2023_AMC_12B_Problems\/Problem_4","question":"Jackson's paintbrush makes a narrow strip with a width of $6.5$ millimeters. Jackson has enough paint to make a strip $25$ meters long. How many square centimeters of paper could Jackson cover with paint?"} 36 | {"id":45,"problem":"You are playing a game. A $2 \\times 1$ rectangle covers two adjacent squares (oriented either horizontally or vertically) of a $3 \\times 3$ grid of squares, but you are not told which two squares are covered. Your goal is to find at least one square that is covered by the rectangle. A \"turn\" consists of you guessing a square, after which you are told whether that square is covered by the hidden rectangle. What is the minimum number of turns you need to ensure that at least one of your guessed squares is covered by the rectangle?","answer":4.0,"url":"https:\/\/artofproblemsolving.com\/wiki\/index.php\/2023_AMC_12B_Problems\/Problem_5","question":"You are playing a game. A $2 \\times 1$ rectangle covers two adjacent squares (oriented either horizontally or vertically) of a $3 \\times 3$ grid of squares, but you are not told which two squares are covered. Your goal is to find at least one square that is covered by the rectangle. A \"turn\" consists of you guessing a square, after which you are told whether that square is covered by the hidden rectangle. 
What is the minimum number of turns you need to ensure that at least one of your guessed squares is covered by the rectangle?"} 37 | {"id":46,"problem":"When the roots of the polynomial \n\\[P(x) = (x-1)^1 (x-2)^2 (x-3)^3 \\cdot \\cdot \\cdot (x-10)^{10}\\]\nare removed from the number line, what remains is the union of $11$ disjoint open intervals. On how many of these intervals is $P(x)$ positive?","answer":6.0,"url":"https:\/\/artofproblemsolving.com\/wiki\/index.php\/2023_AMC_12B_Problems\/Problem_6","question":"When the roots of the polynomial \n\\[P(x) = (x-1)^1 (x-2)^2 (x-3)^3 \\cdot \\cdot \\cdot (x-10)^{10}\\]\nare removed from the number line, what remains is the union of $11$ disjoint open intervals. On how many of these intervals is $P(x)$ positive?"} 38 | {"id":47,"problem":"For how many integers $n$ does the expression\\[\\sqrt{\\frac{\\log (n^2) - (\\log n)^2}{\\log n - 3}}\\]represent a real number, where log denotes the base $10$ logarithm?","answer":901.0,"url":"https:\/\/artofproblemsolving.com\/wiki\/index.php\/2023_AMC_12B_Problems\/Problem_7","question":"For how many integers $n$ does the expression\\[\\sqrt{\\frac{\\log (n^2) - (\\log n)^2}{\\log n - 3}}\\]represent a real number, where log denotes the base $10$ logarithm?"} 39 | {"id":48,"problem":"How many nonempty subsets $B$ of ${0, 1, 2, 3, \\cdots, 12}$ have the property that the number of elements in $B$ is equal to the least element of $B$? For example, $B = {4, 6, 8, 11}$ satisfies the condition.","answer":144.0,"url":"https:\/\/artofproblemsolving.com\/wiki\/index.php\/2023_AMC_12B_Problems\/Problem_8","question":"How many nonempty subsets $B$ of ${0, 1, 2, 3, \\cdots, 12}$ have the property that the number of elements in $B$ is equal to the least element of $B$? For example, $B = {4, 6, 8, 11}$ satisfies the condition."} 40 | {"id":49,"problem":"What is the area of the region in the coordinate plane defined by\n$| | x | - 1 | + | | y | - 1 | \\le 1$?","answer":8.0,"url":"https:\/\/artofproblemsolving.com\/wiki\/index.php\/2023_AMC_12B_Problems\/Problem_9","question":"What is the area of the region in the coordinate plane defined by\n$| | x | - 1 | + | | y | - 1 | \\le 1$?"} 41 | -------------------------------------------------------------------------------- /vllm-deer-qwen3.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | warnings.filterwarnings("ignore") # Ignore all warnings 3 | 4 | import os 5 | import json 6 | import time 7 | import argparse 8 | import sys 9 | import torch 10 | import torch.nn.functional as F 11 | from vllm.outputs import CompletionOutput 12 | from typing import Any, Dict, List 13 | from nltk import ngrams 14 | from collections import Counter 15 | 16 | from transformers import AutoTokenizer 17 | from tqdm import tqdm 18 | from vllm import LLM, SamplingParams 19 | import pdb 20 | 21 | import math 22 | import numpy as np 23 | import random 24 | 25 | def set_seeds(seed=42): 26 | # Set Python built-in random seed 27 | random.seed(seed) 28 | 29 | # Set NumPy random seed 30 | np.random.seed(seed) 31 | 32 | # Set PyTorch CPU random seed 33 | torch.manual_seed(seed) 34 | 35 | # If using GPU (especially CUDA) 36 | if torch.cuda.is_available(): 37 | torch.cuda.manual_seed(seed) # Set seed for current GPU 38 | torch.cuda.manual_seed_all(seed) # Also effective for multi-GPU 39 | 40 | # For better reproducibility, enable cudnn determinism mode 41 | torch.backends.cudnn.deterministic = True 42 | torch.backends.cudnn.benchmark = False 43 | 
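# (Note: deterministic cudnn mode trades speed for reproducibility. This helper only seeds Python, NumPy, and torch;
# vLLM's sampling has its own seeding, e.g. the per-request seed field of SamplingParams in recent vLLM versions.)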
44 | # Optional: Set generator (for DataLoader with multi-threading) 45 | g = torch.Generator() 46 | g.manual_seed(seed) 47 | 48 | 49 | 50 | 51 | 52 | def append_jsonl(data, file_path): 53 | """Append results in the list to a .jsonl file.""" 54 | os.makedirs(os.path.dirname(file_path), exist_ok=True) 55 | with open(file_path, 'a', encoding='utf-8') as f: 56 | for item in data: 57 | f.write(json.dumps(item, ensure_ascii=False) + '\n') 58 | 59 | def write_jsonl(data, file_path): 60 | """Write results in the list to a .jsonl file.""" 61 | os.makedirs(os.path.dirname(file_path), exist_ok=True) 62 | with open(file_path, 'w', encoding='utf-8') as f: 63 | for item in data: 64 | f.write(json.dumps(item, ensure_ascii=False) + '\n') 65 | 66 | def read_jsonl(file_path): 67 | """Read .jsonl file and return a list of dictionaries.""" 68 | data = [] 69 | if not os.path.exists(file_path): 70 | print(f"Warning: Dataset file not found at {file_path}") 71 | return data 72 | with open(file_path, 'r', encoding='utf-8') as f: 73 | for line in f: 74 | data.append(json.loads(line.strip())) 75 | return data 76 | 77 | 78 | def seq_rep_n(last_thinking, cur_thinking, rep, n=1): 79 | # Crude repetition detector: increment rep when the previous and current thinking chunks cover exactly the same set of word n-grams. 80 | 81 | pred = last_thinking 82 | target = cur_thinking 83 | 84 | pred_tokens = pred.split(' ') 85 | target_token = target.split(' ') 86 | 87 | ngs_pred = [ng for ng in ngrams(pred_tokens, n)] 88 | ngs_know = [ng for ng in ngrams(target_token, n)] 89 | intersection = list(set(ngs_pred) & set(ngs_know)) 90 | overlap_num = len(intersection) 91 | 92 | 93 | if overlap_num == len(ngs_pred) and overlap_num == len(ngs_know): 94 | rep += 1 95 | 96 | 97 | return rep 98 | 99 | # Function to calculate average max probability, mimicking Transformers version logic 100 | def calculate_average_max_prob_from_logprobs(logprobs_list, policy='avg2') -> float: 101 | """ 102 | Calculate average max token probability from logprobs list in vLLM CompletionOutput. 103 | Compute from the second generated token to the last generated token.
104 | policy: min, avg1: arithmetic mean, avg2: geometric mean 105 | """ 106 | 107 | 108 | 109 | num_tokens = len(logprobs_list) 110 | start_index = 1 111 | end_index = num_tokens 112 | 113 | if num_tokens < 1: 114 | print("Too few tokens to calculate valid average.") 115 | return 0.0 116 | 117 | total_prob_sum = 0.0 118 | log_prob_sum = 0.0 # For geometric mean 119 | count_for_average = 0 120 | min_prob = 1.0 121 | 122 | for i in range(start_index, end_index): 123 | # Ensure index is valid and corresponding logprobs entry is not empty 124 | if i < len(logprobs_list) and logprobs_list[i]: 125 | try: 126 | logprob_obj = list(logprobs_list[i].values())[0] 127 | # Ensure object has .logprob attribute 128 | if hasattr(logprob_obj, 'logprob'): 129 | prob = torch.exp(torch.tensor(logprob_obj.logprob)).item() 130 | if prob < min_prob: 131 | min_prob = prob 132 | #print(prob) 133 | #print(list(logprobs_list[i].values())[0]) 134 | total_prob_sum += prob 135 | log_prob_sum += math.log(max(prob, 1e-10)) 136 | count_for_average += 1 137 | else: 138 | print(f"Warning: Object at logprobs_list[{i}] doesn't have '.logprob' attribute.") 139 | except (IndexError, KeyError, AttributeError) as e: 140 | print(f"Warning: Unable to process logprobs at logprobs_list[{i}]: {e}") 141 | else: 142 | print(f"Warning: logprobs_list[{i}] is empty or invalid.") 143 | # Calculate average 144 | if policy == 'min': 145 | result = min_prob 146 | elif policy == 'avg1': 147 | result = total_prob_sum / count_for_average if count_for_average else 0.0 # guard against division by zero when no usable logprobs were collected 148 | elif policy == 'avg2': 149 | result = math.exp(log_prob_sum / count_for_average) if count_for_average else 0.0 150 | if (list(logprobs_list[-1].values())[-1].decoded_token) == '</think>': # only trust the confidence if the trial answer closed the thinking block 151 | return result 152 | else: 153 | return 0.0 154 | 155 | 156 | def parse_args(): 157 | parser = argparse.ArgumentParser() 158 | parser.add_argument('--model_name_or_path', type=str, default="./DeepSeek-R1-Distill-Qwen-14B/") 159 | parser.add_argument('--dataset_dir', type=str, default="./data/") 160 | parser.add_argument("--dtype", type=str, default="bfloat16") 161 | parser.add_argument("--max-model-len", "--model-context-len", type=int, default=40000, dest="model_context_len") # max-model-len for vllm, should be longer than max_generated_tokens. 162 | parser.add_argument("--gpu-memory-utilization", type=float, default=0.9) 163 | parser.add_argument("--trust-remote-code", action="store_true") 164 | parser.add_argument("--run_time", type=int, default=1) 165 | parser.add_argument("--no_thinking", type=int, default=0) # Calculate the answer confidence at the very beginning of the reasoning process and attempt to exit early. 166 | parser.add_argument("--rep", type=int, default=0) # Exit early when repetition occurs, but it remains to be implemented. (TODO) 167 | parser.add_argument("--points", type=int, default=1) # 1: 'Wait' as thinking transition point. 0: 'Alternatively' as thinking transition point. 168 | parser.add_argument("--af", type=int, default=0) # answer forcing at end of sequence 169 | parser.add_argument("--max_judge_steps", type=int, default=10) # Limit the maximum number of answer attempts to save time cost. 170 | parser.add_argument('--policy', type=str, default="avg2") # Strategy for calculating answer confidence 171 | 172 | parser.add_argument('--threshold', type=float, default=0.95) # The answer confidence threshold used to determine early exit.
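# --- Illustrative sketch (not part of the original file) ---
# How the three --policy options above score a trial answer, assuming hypothetical
# per-token probabilities p = [0.99, 0.97, 0.98] read out of the vLLM logprobs:
#   min  -> min(p)                 = 0.97
#   avg1 -> sum(p) / len(p)        = 0.98   (arithmetic mean)
#   avg2 -> exp(mean(log p_i))     ~ 0.98   (geometric mean, the default here)
# The main loop exits the thinking phase once the chosen statistic exceeds --threshold.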
173 | parser.add_argument('--max_generated_tokens', '--max-len', type=int, default=16384, dest="max_len") # total token budget 174 | parser.add_argument('--dataset', type=str, default='math') # dataset name 175 | parser.add_argument('--output_path', type=str, default='./outputs') # output path 176 | parser.add_argument('--think_ratio', type=float, default=0.9, help="Ratio of thinking phase to max generated tokens") # Ratio of thinking phase to max generated tokens 177 | parser.add_argument('--batch_size', type=int, default=2000) # vllm batch size, set it to a value above the number of samples in the dataset. 178 | parser.add_argument('--temperature', type=float, default=0.0) 179 | parser.add_argument('--top_p', type=float, default=1.0) 180 | 181 | # Hardcoded 20 182 | parser.add_argument('--prob_check_max_tokens', type=int, default=20, help="Max tokens for probability check phase") # Max tokens for answer inducing 183 | 184 | args = parser.parse_args() 185 | return args 186 | 187 | def main(): 188 | args = parse_args() 189 | args.model_context_len = args.max_len + 8000 190 | print(f"Using vLLM LLM object for direct inference (batch processing)") 191 | print(f"Model path: {args.model_name_or_path}") 192 | print(f"Dataset: {args.dataset}") 193 | print(f"Early exit probability threshold: {args.threshold}") 194 | print(f"Max total generated tokens: {args.max_len}") 195 | print(f"Thinking phase ratio: {args.think_ratio}") 196 | print(f"Batch size: {args.batch_size}") 197 | print(f"Max tokens for probability check phase: {args.prob_check_max_tokens}") 198 | 199 | print("\nInitializing vLLM LLM engine...") 200 | available_gpus = os.environ['CUDA_VISIBLE_DEVICES'].split(',') 201 | try: 202 | llm_engine = LLM( 203 | model=args.model_name_or_path, 204 | tensor_parallel_size=len(available_gpus), 205 | dtype=args.dtype, 206 | max_model_len=args.max_len + 8000, 207 | gpu_memory_utilization=args.gpu_memory_utilization, 208 | trust_remote_code=True, 209 | ) 210 | print("vLLM LLM engine initialized successfully.") 211 | 212 | tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path, trust_remote_code=args.trust_remote_code) 213 | print(f"Successfully loaded tokenizer: {args.model_name_or_path}") 214 | if tokenizer.pad_token is None: 215 | if tokenizer.eos_token is not None: 216 | tokenizer.pad_token = tokenizer.eos_token 217 | else: 218 | tokenizer.add_special_tokens({'pad_token': '[PAD]'}) 219 | print("Warning: Model has no pad_token or eos_token. 
Added custom [PAD] token.") 220 | 221 | print(f"Tokenizer using pad_token_id: {tokenizer.pad_token_id}") 222 | 223 | 224 | except Exception as e: 225 | print(f"Failed to initialize vLLM LLM engine or load tokenizer: {e}") 226 | sys.exit(1) 227 | 228 | sys_prompt = 'Please reason step by step, and put your final answer within \\boxed{}.' 229 | dataset_path = f'{args.dataset_dir}/{args.dataset}/test.jsonl' 230 | try: 231 | questions_json = read_jsonl(dataset_path) 232 | if not questions_json: 233 | print(f"Error: No questions loaded from {dataset_path}.") 234 | sys.exit(1) 235 | print(f"Successfully loaded dataset: {dataset_path}, total {len(questions_json)} questions") 236 | except Exception as e: 237 | print(f"Failed to load dataset: {e}") 238 | sys.exit(1) 239 | 240 | model_dir_name = os.path.basename(os.path.normpath(args.model_name_or_path)) 241 | output_dir = f'{args.output_path}/{model_dir_name}/{args.dataset}' 242 | os.makedirs(output_dir, exist_ok=True) 243 | output_file = f'{output_dir}/greedy_p{str(args.threshold)}_ratio{str(args.think_ratio)}_len{str(args.max_len)}_temperature{str(args.temperature)}_run_time{args.run_time}_no_thinking{args.no_thinking}_rep{args.rep}_points{args.points}_policy{args.policy}.jsonl' 244 | 245 | print(f"\nStarting processing, total questions: {len(questions_json)}") 246 | start_time = time.time() 247 | 248 | questions_state = {} # Dictionary to store processing state for each question 249 | last_token_strs = ["</think>"] # Strings marking end of thinking 250 | if args.points == 1: 251 | continue_str = "Wait" # String appended to sequence end to indicate continued thinking 252 | else: 253 | continue_str = "Alternatively" # String appended to sequence end to indicate continued thinking 254 | 255 | answer_prompt_str = "\n**Final Answer**\n\\boxed" # Prompt string to guide answer generation 256 | if 'gpqa' in args.dataset: 257 | answer_prompt_str = "\n**Final Answer**\nI believe the final answer, rather than the option, is \\boxed" 258 | 259 | # Get token IDs for stop conditions and strings to append 260 | last_token_ids = [] 261 | for s in last_token_strs: 262 | ids = tokenizer.encode(s, add_special_tokens=False) 263 | if ids: last_token_ids.extend(ids) 264 | last_token_ids = list(set(last_token_ids)) # Remove duplicate IDs 265 | 266 | continue_ids = tokenizer.encode(continue_str, add_special_tokens=False) 267 | if not continue_ids: 268 | print(f"Warning: Unable to tokenize continue string '{continue_str}'. This may affect logic.")
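# --- Illustrative sketch (not part of the original file) ---
# The strings above drive DEER's per-question state machine. For a hypothetical
# question, one round of the loop below looks like:
#   'think'          -> generate until continue_str ("Wait"), "</think>", or EOS is emitted
#   'prob_check_gen' -> append answer_prompt_str and induce a short trial \boxed{} answer
#   exit decision    -> confidence > --threshold: switch to 'needs_answer';
#                       otherwise append continue_str and return to 'think'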
269 | 270 | # Stop tokens for thinking phase generation 271 | generation_stop_tokens = [continue_str] + last_token_strs + [tokenizer.eos_token] 272 | pred_prob_stop_tokens = ['</think>'] # stop the trial answer once the model closes the thinking block 273 | 274 | answer_stop_tokens = [tokenizer.eos_token] 275 | 276 | 277 | # Max token limit for thinking phase 278 | think_limit_tokens = int(args.max_len * args.think_ratio) 279 | 280 | for i, question_data in enumerate(questions_json): 281 | 282 | messages = [ 283 | {"role": "system", "content": sys_prompt}, 284 | {"role": "user", "content": question_data['problem']} 285 | ] 286 | formatted_prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True) 287 | if args.no_thinking == 1: 288 | questions_state[i] = { 289 | 'question_data': question_data, 290 | 'state': 'needs_prob_check', # try to exit at the very beginning of reasoning 291 | 'formatted_prompt': formatted_prompt, 292 | 'current_full_sequence': formatted_prompt, 293 | 'generated_thinking_history': "", 294 | 'generated_thinking_last_trun': "nsss, ssssa, wrtt, yyy, sss", # dummy seed text so the first repetition check cannot match 295 | 'generated_answer_history': "", 296 | 'pred_prob': 0.0, 297 | 'too_long': 0, 298 | 'rep_end': 0, 299 | 'high_prob': 0, 300 | 'regular_end': 0, 301 | 'thinking_steps': 0, 302 | 'output_dict': {}, 303 | 'error_message': None, 304 | 'question_index': i 305 | } 306 | else: 307 | questions_state[i] = { 308 | 'question_data': question_data, 309 | 'state': 'needs_thought_chunk', 310 | 'formatted_prompt': formatted_prompt, 311 | 'current_full_sequence': formatted_prompt, 312 | 'generated_thinking_history': "", 313 | 'generated_thinking_last_trun': "nsss, ssssa, wrtt, yyy, sss", # dummy seed text so the first repetition check cannot match 314 | 'generated_answer_history': "", 315 | 'pred_prob': 0.0, 316 | 'too_long': 0, 317 | 'rep_end': 0, 318 | 'high_prob': 0, 319 | 'regular_end': 0, 320 | 'thinking_steps': 0, 321 | 'output_dict': {}, 322 | 'error_message': None, 323 | 'question_index': i 324 | } 325 | 326 | active_questions_indices = list(questions_state.keys()) # List of currently processing question indices 327 | pbar = tqdm(total=len(questions_json), desc="Processing questions") 328 | 329 | print("\nRunning a simple test generation...") 330 | try: 331 | test_outputs = llm_engine.generate(["Hello, world!"], SamplingParams(max_tokens=10, temperature=args.temperature), use_tqdm=False) 332 | if test_outputs and test_outputs[0].outputs: 333 | test_generated_text = test_outputs[0].outputs[0].text 334 | print(f"Test generation successful. 
Output: '{test_generated_text.strip()}'") 335 | else: 336 | print("Simple test generation failed: LLM generate returned no output.") 337 | except Exception as e: 338 | print(f"Simple test generation failed: {e}") 339 | 340 | # Main processing loop: continue while there are active questions 341 | while active_questions_indices: # indexes [0,1,2,...,n] 342 | batch_prompts = [] # Current batch prompts 343 | batch_sampling_params = [] # Current batch sampling parameters 344 | batch_request_info = [] # Store (question index, step type) for output processing 345 | 346 | current_batch_count = 0 347 | # Create copy of active_questions_indices to allow modifying original list during iteration 348 | current_active_indices_for_batching = active_questions_indices[:] 349 | 350 | # Build current batch 351 | for q_idx in current_active_indices_for_batching: 352 | if current_batch_count >= args.batch_size: 353 | break 354 | 355 | state = questions_state[q_idx] 356 | if state['state'] in ['finished', 'error']: 357 | continue 358 | 359 | prompt_for_batch = None 360 | sampling_params_for_batch = None 361 | step_type = None # 'think', 'prob_check_gen', 'answer' 362 | 363 | try: 364 | # --- Determine prompt and parameters based on state --- 365 | current_full_sequence_tokens = tokenizer.encode(state['current_full_sequence'], add_special_tokens=False) 366 | current_full_sequence_len = len(current_full_sequence_tokens) 367 | current_generated_thinking_tokens = len(tokenizer.encode(state['generated_thinking_history'], add_special_tokens=False)) 368 | initial_prompt_len = len(tokenizer.encode(state['formatted_prompt'], add_special_tokens=False)) 369 | 370 | # Check context window limit before adding any new content 371 | remaining_context_window = args.model_context_len - current_full_sequence_len 372 | # check window, skip it 373 | if remaining_context_window <= 0: 374 | state['state'] = 'error' 375 | state['error_message'] = f"Exceeded model context window limit ({args.model_context_len}). Current sequence length: {current_full_sequence_len}" 376 | state['output_dict'] = {'question': state['question_data']['problem'], 'generated_responses': [state['generated_thinking_history'] + state['generated_answer_history'] + "\n" + state['error_message']], 'gold_answer': state['question_data']['answer']} 377 | print(f"\nQuestion {q_idx}: {state['error_message']}") 378 | if q_idx in active_questions_indices: 379 | active_questions_indices.remove(q_idx) 380 | pbar.update(1) 381 | continue 382 | 383 | # Initial state 384 | if state['state'] == 'needs_thought_chunk': 385 | # Calculate max tokens for this generation chunk, considering thinking limit and context window 386 | max_new_tokens_for_thought = min( 387 | think_limit_tokens - current_generated_thinking_tokens, # thinking budget 388 | remaining_context_window 389 | ) 390 | if max_new_tokens_for_thought <= 0: 391 | state['state'] = 'needs_answer' 392 | print(f"\nQuestion {q_idx}: Reached thinking limit ({current_generated_thinking_tokens}/{think_limit_tokens}). 
Switching to answer generation.") 393 | state['too_long'] = 1 394 | continue 395 | 396 | 397 | 398 | prompt_for_batch = state['current_full_sequence'] 399 | if state['thinking_steps'] < args.max_judge_steps: 400 | sampling_params_for_batch = SamplingParams( 401 | max_tokens=max_new_tokens_for_thought, 402 | temperature=args.temperature, 403 | top_p=args.top_p, 404 | stop=generation_stop_tokens 405 | ) 406 | else: 407 | sampling_params_for_batch = SamplingParams( 408 | max_tokens=max_new_tokens_for_thought, 409 | temperature=args.temperature, 410 | top_p=args.top_p, 411 | stop=last_token_strs 412 | ) 413 | step_type = 'think' 414 | 415 | # Probability-check state 416 | elif state['state'] == 'needs_prob_check': 417 | # Build prompt for probability check generation: 418 | prompt_for_prob_check = state['current_full_sequence'] + answer_prompt_str 419 | prob_check_prompt_len = len(tokenizer.encode(prompt_for_prob_check, add_special_tokens=False)) 420 | required_space_for_prob_check = prob_check_prompt_len + args.prob_check_max_tokens 421 | 422 | if required_space_for_prob_check > args.model_context_len: 423 | state['state'] = 'error' 424 | state['error_message'] = f"Probability check generation prompt exceeds context window ({args.model_context_len}). Estimated length: {required_space_for_prob_check}" 425 | state['output_dict'] = {'question': state['question_data']['problem'], 'generated_responses': [state['generated_thinking_history'] + state['generated_answer_history'] + "\n" + state['error_message']], 'gold_answer': state['question_data']['answer']} 426 | print(f"\nQuestion {q_idx}: {state['error_message']}") 427 | if q_idx in active_questions_indices: 428 | active_questions_indices.remove(q_idx) 429 | pbar.update(1) 430 | continue 431 | 432 | 433 | # Parameters for generating *prediction sequence* in probability check phase 434 | prompt_for_batch = prompt_for_prob_check 435 | sampling_params_for_batch = SamplingParams( 436 | max_tokens=args.prob_check_max_tokens, 437 | #temperature=args.temperature, # Greedy decoding 438 | stop=pred_prob_stop_tokens, # Only predict content inside \boxed{} 439 | logprobs=1, 440 | ) 441 | step_type = 'prob_check_gen' 442 | 443 | 444 | elif state['state'] == 'needs_answer': 445 | # Build final answer prompt 446 | 447 | if state['too_long'] == 1: 448 | final_answer_prompt = state['formatted_prompt'] + state['generated_thinking_history'] + '\n</think>\n\n' 449 | state['generated_thinking_history'] = state['generated_thinking_history'] + '\n</think>\n\n' 450 | 451 | else: 452 | final_answer_prompt = state['formatted_prompt'] + state['generated_thinking_history'] + '\n</think>\n\n'#+ answer_prompt_str 453 | state['generated_thinking_history'] = state['generated_thinking_history'] + '\n</think>\n\n' 454 | 455 | len_final_answer_prompt = len(tokenizer.encode(final_answer_prompt, add_special_tokens=False)) 456 | total_tokens_before_answer_prompt = current_generated_thinking_tokens 457 | # Calculate remaining total budget 458 | remaining_total_budget = args.max_len - total_tokens_before_answer_prompt 459 | max_new_tokens_answer = min( 460 | remaining_total_budget, 461 | args.model_context_len - len_final_answer_prompt 462 | ) 463 | 464 | if max_new_tokens_answer <= 0: 465 | 466 | state['state'] = 'error' 467 | state['output_dict'] = {'question': state['question_data']['problem'], 'generated_responses': [state['generated_thinking_history'] + "\nSkipped answer generation due to length limit."], 'gold_answer': state['question_data']['answer'], 'too_long': state['too_long'], 'thinking_steps': 
state['thinking_steps']} 468 | print(f"\nQuestion {q_idx}: Skipped answer generation due to length limit.") 469 | 470 | if q_idx in active_questions_indices: 471 | active_questions_indices.remove(q_idx) 472 | pbar.update(1) 473 | continue 474 | else: 475 | prompt_for_batch = final_answer_prompt 476 | sampling_params_for_batch = SamplingParams( 477 | max_tokens=max_new_tokens_answer, 478 | temperature=args.temperature, 479 | stop=answer_stop_tokens, 480 | top_p=args.top_p, 481 | ) 482 | step_type = 'answer' 483 | 484 | 485 | 486 | 487 | 488 | elif state['state'] == 'answer_forcing': 489 | final_answer_prompt = state['formatted_prompt'] + state['generated_thinking_history'] + state['generated_answer_history'] + answer_prompt_str 490 | state['generated_thinking_history'] = state['generated_thinking_history'] + state['generated_answer_history'] + answer_prompt_str 491 | 492 | prompt_for_batch = final_answer_prompt 493 | sampling_params_for_batch = SamplingParams( 494 | max_tokens=100, 495 | temperature=args.temperature, 496 | stop=answer_stop_tokens, 497 | top_p=args.top_p, 498 | ) 499 | step_type = 'answer_exit' 500 | 501 | # If execution reaches here and prompt_for_batch is None, there may be state logic issues 502 | # This check generally shouldn't trigger 503 | if prompt_for_batch is None: 504 | state['state'] = 'error' 505 | state['error_message'] = f"Internal error: No prompt generated for state {state['state']}." 506 | state['output_dict'] = {'question': state['question_data']['problem'], 'generated_responses': [state['generated_thinking_history'] + state['generated_answer_history'] + "\n" + state['error_message']], 'gold_answer': state['question_data']['answer']} 507 | print(f"\nQuestion {q_idx}: {state['error_message']}") 508 | if q_idx in active_questions_indices: 509 | active_questions_indices.remove(q_idx) 510 | pbar.update(1) 511 | continue 512 | 513 | batch_prompts.append(prompt_for_batch) 514 | batch_sampling_params.append(sampling_params_for_batch) 515 | batch_request_info.append((q_idx, step_type)) 516 | current_batch_count += 1 517 | 518 | except Exception as e: 519 | state['state'] = 'error' 520 | state['error_message'] = f"Error preparing batch request for state '{state['state']}': {e}" 521 | state['output_dict'] = {'question': state['question_data']['problem'], 'generated_responses': [state['generated_thinking_history'] + state['generated_answer_history'] + "\n" + state['error_message']], 'gold_answer': state['question_data']['answer']} 522 | print(f"\nQuestion {q_idx}: {state['error_message']}") 523 | if q_idx in active_questions_indices: 524 | active_questions_indices.remove(q_idx) 525 | pbar.update(1) 526 | 527 | if not batch_prompts: 528 | # May occur when all remaining active questions have switched to finished/error state or were skipped. 529 | all_stuck = True 530 | for q_idx in active_questions_indices: 531 | if questions_state[q_idx]['state'] not in ['finished', 'error']: 532 | all_stuck = False 533 | break 534 | if all_stuck: 535 | print("Warning: Batch generated no requests. All remaining questions are completed or in error state.") 536 | break 537 | else: 538 | print("Error: Active questions remain but no batch requests generated. Possible logic error.") 539 | for q_idx in list(active_questions_indices): 540 | state = questions_state[q_idx] 541 | if state['state'] not in ['finished', 'error']: 542 | state['state'] = 'error' 543 | state['error_message'] = "Processing aborted: Unable to generate request in batch loop." 
544 | state['output_dict'] = {'question': state['question_data']['problem'], 'generated_responses': [state['generated_thinking_history'] + state['generated_answer_history'] + "\n" + state['error_message']], 'gold_answer': state['question_data']['answer']} 545 | print(f"\nQuestion {q_idx}: Marked as error due to processing abort.") 546 | if q_idx in active_questions_indices: # Should be, but double-checking 547 | active_questions_indices.remove(q_idx) 548 | pbar.update(1) 549 | break 550 | 551 | batch_outputs = llm_engine.generate(batch_prompts, batch_sampling_params, use_tqdm=False) 552 | 553 | 554 | # --- Process batch outputs --- 555 | for i, output in enumerate(batch_outputs): 556 | q_idx, step_type = batch_request_info[i] 557 | state = questions_state[q_idx] 558 | 559 | if state['state'] in ['finished', 'error']: 560 | continue 561 | 562 | try: 563 | if not output.outputs: # skip, exception handling 564 | # vLLM returned empty output, possibly due to prompt issues, length limits or other internal errors 565 | error_msg = f"vLLM returned empty output for request {output.request_id} (question {q_idx}, step {step_type})." 566 | if hasattr(output, 'error') and output.error: 567 | error_msg += f" vLLM error: {output.error}" 568 | current_full_sequence_len = len(tokenizer.encode(state['current_full_sequence'], add_special_tokens=False)) 569 | initial_prompt_len = len(tokenizer.encode(state['formatted_prompt'], add_special_tokens=False)) 570 | current_generated_thinking_tokens = len(tokenizer.encode(state['generated_thinking_history'], add_special_tokens=False)) 571 | 572 | 573 | if step_type == 'think': 574 | max_new = min(think_limit_tokens - current_generated_thinking_tokens, args.model_context_len - current_full_sequence_len) 575 | error_msg += f" State: think, attempting to generate {max_new} tokens." 576 | if max_new <= 0: error_msg += " Note: Calculated max new tokens <= 0." 577 | if args.model_context_len - current_full_sequence_len <= 0: error_msg += " Note: Already exceeded context window." 578 | if think_limit_tokens - current_generated_thinking_tokens <= 0: error_msg += " Note: Reached thinking token limit." 579 | 580 | elif step_type == 'prob_check_gen': 581 | prob_check_prompt = state['current_full_sequence'] + answer_prompt_str 582 | prob_check_prompt_len = len(tokenizer.encode(prob_check_prompt, add_special_tokens=False)) 583 | error_msg += f" State: prob_check_gen, prompt length: {prob_check_prompt_len}, attempting to generate {args.prob_check_max_tokens} tokens." 584 | if prob_check_prompt_len + args.prob_check_max_tokens > args.model_context_len: error_msg += " Note: Prompt+generation exceeds context window limit." 585 | 586 | elif step_type == 'answer': 587 | final_answer_prompt = state['formatted_prompt'] + state['generated_thinking_history'] + answer_prompt_str 588 | len_final_answer_prompt = len(tokenizer.encode(final_answer_prompt, add_special_tokens=False)) 589 | total_tokens_before_answer_prompt = initial_prompt_len + current_generated_thinking_tokens 590 | max_new_answer = min(args.max_len - total_tokens_before_answer_prompt, args.model_context_len - len_final_answer_prompt) 591 | error_msg += f" State: answer, prompt length: {len_final_answer_prompt}, attempting to generate {max_new_answer} tokens." 592 | if max_new_answer <= 0: error_msg += " Note: Calculated max new tokens <= 0." 593 | if args.model_context_len - len_final_answer_prompt <= 0: error_msg += " Note: Already exceeded context window." 
594 | if args.max_len - total_tokens_before_answer_prompt <= 0: error_msg += " Note: Reached total token limit." 595 | 596 | raise ValueError(error_msg) 597 | 598 | completion_output = output.outputs[0] 599 | generated_text = completion_output.text 600 | generated_ids = completion_output.token_ids 601 | last_token_id = generated_ids[-1] 602 | 603 | rep = seq_rep_n(state['generated_thinking_last_trun'], generated_text, state['rep_end']) 604 | state['rep_end'] = rep 605 | 606 | 607 | if step_type == 'think': 608 | if state['rep_end'] >= 3 and args.rep == 1: 609 | state['state'] = 'needs_answer' 610 | state['generated_thinking_history'] += generated_text 611 | state['current_full_sequence'] = state['formatted_prompt'] + state['generated_thinking_history'] 612 | #state['rep_end'] = 1 613 | elif last_token_id in last_token_ids: 614 | state['state'] = 'needs_answer' 615 | state['generated_thinking_history'] += generated_text 616 | state['current_full_sequence'] = state['formatted_prompt'] + state['generated_thinking_history'] 617 | state['regular_end'] = 1 618 | else: 619 | # Append generated thinking chunk 620 | state['generated_thinking_history'] += generated_text 621 | state['generated_thinking_last_trun'] = generated_text 622 | state['current_full_sequence'] = state['formatted_prompt'] + state['generated_thinking_history'] 623 | state['state'] = 'needs_prob_check' 624 | state['thinking_steps'] += 1 625 | 626 | elif step_type == 'prob_check_gen': 627 | # Get logprobs for probability calculation. 628 | if completion_output.logprobs: 629 | state['pred_prob'] = calculate_average_max_prob_from_logprobs(completion_output.logprobs, args.policy) 630 | 631 | else: 632 | print(f"Warning: No logprobs returned for prob_check_gen for question {q_idx}. Setting pred_prob to 0.0.") 633 | state['pred_prob'] = 0.0 634 | 635 | # Recalculate current thinking history token length before making decision 636 | current_generated_thinking_tokens = len(tokenizer.encode(state['generated_thinking_history'], add_special_tokens=False)) 637 | thinking_limit_reached = current_generated_thinking_tokens >= think_limit_tokens - 50 638 | 639 | if state['pred_prob'] > args.threshold or thinking_limit_reached: # exit when confident enough, or when the thinking budget is nearly exhausted 640 | # Probability high enough or reached thinking limit, switch to answer phase 641 | state['state'] = 'needs_answer' 642 | if thinking_limit_reached: 643 | print(f"\nQuestion {q_idx}: Reached thinking limit during probability check ({current_generated_thinking_tokens}/{think_limit_tokens}). Switching to answer phase.") 644 | state['too_long'] = 1 645 | else: 646 | print(f"\nQuestion {q_idx}: Reached early exit threshold ({state['pred_prob']:.4f} > {args.threshold}). Switching to answer phase.") 647 | state['high_prob'] = 1 648 | else: 649 | # Probability not high enough, need more thinking 650 | state['state'] = 'needs_thought_chunk' 651 | if not state['current_full_sequence'].strip().endswith(continue_str) and state['thinking_steps'] != 0: 652 | state['current_full_sequence'] += continue_str 653 | state['generated_thinking_history'] += continue_str 654 | print(f"\nQuestion {q_idx}: Early exit threshold not reached ({state['pred_prob']:.4f} <= {args.threshold}), thinking history length ({current_generated_thinking_tokens}/{think_limit_tokens}). 
Appending '{continue_str}' and continuing thinking.") 655 | 656 | elif step_type == 'answer': 657 | state['generated_answer_history'] += (generated_text) 658 | if last_token_id != tokenizer.eos_token_id and args.af == 1: 659 | state['state'] = 'answer_forcing' 660 | else: 661 | state['current_full_sequence'] = state['formatted_prompt'] + state['generated_thinking_history'] + state['generated_answer_history'] 662 | state['state'] = 'finished' 663 | final_response_text = state['generated_thinking_history'] + state['generated_answer_history'] 664 | state['output_dict'] = {'question': state['question_data']['problem'], 'generated_responses': [final_response_text], 'gold_answer': state['question_data']['answer'], 'too_long': state['too_long'], 'thinking_steps': state['thinking_steps'], 'rep_end': state['rep_end'], 'high_prob': state['high_prob'],'regular_end': state['regular_end']} 665 | if q_idx in active_questions_indices: 666 | active_questions_indices.remove(q_idx) 667 | pbar.update(1) 668 | 669 | elif step_type == 'answer_exit': 670 | state['current_full_sequence'] = state['formatted_prompt'] + state['generated_thinking_history'] + generated_text 671 | state['state'] = 'finished' 672 | final_response_text = state['generated_thinking_history'] + generated_text 673 | state['output_dict'] = {'question': state['question_data']['problem'], 'generated_responses': [final_response_text], 'gold_answer': state['question_data']['answer'], 'too_long': state['too_long'], 'thinking_steps': state['thinking_steps'], 'rep_end': state['rep_end'], 'high_prob': state['high_prob'],'regular_end': state['regular_end']} 674 | if q_idx in active_questions_indices: 675 | active_questions_indices.remove(q_idx) 676 | pbar.update(1) 677 | 678 | except Exception as e: 679 | print(f"\nError processing batch results for question {q_idx} step '{step_type}': {e}") 680 | state['state'] = 'error' 681 | state['error_message'] = f"Error processing batch results for step '{step_type}': {e}" 682 | state['output_dict'] = {'question': state['question_data']['problem'], 'generated_responses': [state['generated_thinking_history'] + state['generated_answer_history'] + "\nError: " + state['error_message']], 'gold_answer': state['question_data']['answer']} 683 | 684 | if q_idx in active_questions_indices: 685 | active_questions_indices.remove(q_idx) 686 | pbar.update(1) 687 | 688 | 689 | pbar.close() 690 | final_results = [state['output_dict'] for state in questions_state.values() if state['state'] in ['finished', 'error']] 691 | 692 | # Create a mapping from problem text to original index for sorting 693 | problem_to_index = {item['problem']: i for i, item in enumerate(questions_json)} 694 | # Use get method to handle cases where problem text might not be found (though it shouldn't happen) 695 | final_results.sort(key=lambda x: problem_to_index.get(x['question'], len(questions_json))) 696 | 697 | print("\nAll questions processed, saving results...") 698 | try: 699 | write_jsonl(final_results, output_file) 700 | except Exception as e: 701 | print(f"Error saving results to {output_file}: {e}") 702 | 703 | end_time = time.time() 704 | elapsed_time = end_time - start_time 705 | print(f"Evaluation completed! 
Attempted to process {len(questions_json)} questions in total, successfully recorded {len(final_results)} results, took {elapsed_time:.2f} seconds") 706 | print(f"Results saved to: {output_file}") 707 | 708 | if __name__ == "__main__": 709 | #set_seeds(42) 710 | main() -------------------------------------------------------------------------------- /vllm-deer.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | warnings.filterwarnings("ignore") # Ignore all warnings 3 | 4 | import os 5 | import json 6 | import time 7 | import argparse 8 | import sys 9 | import torch 10 | import torch.nn.functional as F 11 | from vllm.outputs import CompletionOutput 12 | from typing import Any, Dict, List 13 | from nltk import ngrams 14 | from collections import Counter 15 | 16 | from transformers import AutoTokenizer 17 | from tqdm import tqdm 18 | from vllm import LLM, SamplingParams 19 | import pdb 20 | 21 | import math 22 | import numpy as np 23 | import random 24 | 25 | def set_seeds(seed=42): 26 | # Set Python built-in random seed 27 | random.seed(seed) 28 | 29 | # Set NumPy random seed 30 | np.random.seed(seed) 31 | 32 | # Set PyTorch CPU random seed 33 | torch.manual_seed(seed) 34 | 35 | # If using GPU (especially CUDA) 36 | if torch.cuda.is_available(): 37 | torch.cuda.manual_seed(seed) # Set seed for current GPU 38 | torch.cuda.manual_seed_all(seed) # Also effective for multi-GPU 39 | 40 | # For better reproducibility, enable cudnn determinism mode 41 | torch.backends.cudnn.deterministic = True 42 | torch.backends.cudnn.benchmark = False 43 | 44 | # Optional: Set generator (for DataLoader with multi-threading) 45 | g = torch.Generator() 46 | g.manual_seed(seed) 47 | 48 | 49 | 50 | 51 | 52 | def append_jsonl(data, file_path): 53 | """Append results in the list to a .jsonl file.""" 54 | os.makedirs(os.path.dirname(file_path), exist_ok=True) 55 | with open(file_path, 'a', encoding='utf-8') as f: 56 | for item in data: 57 | f.write(json.dumps(item, ensure_ascii=False) + '\n') 58 | 59 | def write_jsonl(data, file_path): 60 | """Write results in the list to a .jsonl file.""" 61 | os.makedirs(os.path.dirname(file_path), exist_ok=True) 62 | with open(file_path, 'w', encoding='utf-8') as f: 63 | for item in data: 64 | f.write(json.dumps(item, ensure_ascii=False) + '\n') 65 | 66 | def read_jsonl(file_path): 67 | """Read .jsonl file and return a list of dictionaries.""" 68 | data = [] 69 | if not os.path.exists(file_path): 70 | print(f"Warning: Dataset file not found at {file_path}") 71 | return data 72 | with open(file_path, 'r', encoding='utf-8') as f: 73 | for line in f: 74 | data.append(json.loads(line.strip())) 75 | return data 76 | 77 | 78 | def seq_rep_n(last_thinking, cur_thinking, rep, n=1): 79 | 80 | 81 | pred = last_thinking 82 | target = cur_thinking 83 | 84 | pred_tokens = pred.split(' ') 85 | target_token = target.split(' ') 86 | 87 | ngs_pred = [ng for ng in ngrams(pred_tokens, n)] 88 | ngs_know = [ng for ng in ngrams(target_token, n)] 89 | intersection = list(set(ngs_pred) & set(ngs_know)) 90 | overlap_num = len(intersection) 91 | 92 | 93 | if overlap_num == len(ngs_pred) and overlap_num == len(ngs_know): 94 | rep += 1 95 | 96 | 97 | return rep 98 | 99 | # Function to calculate average max probability, mimicking Transformers version logic 100 | def calculate_average_max_prob_from_logprobs(logprobs_list, policy='avg2') -> float: 101 | """ 102 | Calculate average max token probability from logprobs list in vLLM CompletionOutput. 
103 | Compute from the second generated token to the last generated token. 104 | policy: min, avg1: arithmetic mean, avg2: geometric mean 105 | """ 106 | 107 | num_tokens = len(logprobs_list) 108 | start_index = 1 109 | end_index = num_tokens 110 | 111 | if num_tokens < 1: 112 | print("Too few tokens to calculate valid average.") 113 | return 0.0 114 | 115 | total_prob_sum = 0.0 116 | log_prob_sum = 0.0 # For geometric mean 117 | count_for_average = 0 118 | min_prob = 1.0 119 | 120 | for i in range(start_index, end_index): 121 | # Ensure index is valid and corresponding logprobs entry is not empty 122 | if i < len(logprobs_list) and logprobs_list[i]: 123 | try: 124 | logprob_obj = list(logprobs_list[i].values())[0] 125 | # Ensure object has .logprob attribute 126 | if hasattr(logprob_obj, 'logprob'): 127 | prob = torch.exp(torch.tensor(logprob_obj.logprob)).item() 128 | if prob < min_prob: 129 | min_prob = prob 130 | #print(prob) 131 | #print(list(logprobs_list[i].values())[0]) 132 | total_prob_sum += prob 133 | log_prob_sum += math.log(max(prob, 1e-10)) 134 | count_for_average += 1 135 | else: 136 | print(f"Warning: Object at logprobs_list[{i}] doesn't have '.logprob' attribute.") 137 | except (IndexError, KeyError, AttributeError) as e: 138 | print(f"Warning: Unable to process logprobs at logprobs_list[{i}]: {e}") 139 | else: 140 | print(f"Warning: logprobs_list[{i}] is empty or invalid.") 141 | # Calculate average 142 | if policy == 'min': 143 | result = min_prob 144 | elif policy == 'avg1': 145 | result = total_prob_sum / count_for_average if count_for_average else 0.0 # guard against division by zero when no usable logprobs were collected 146 | elif policy == 'avg2': 147 | result = math.exp(log_prob_sum / count_for_average) if count_for_average else 0.0 148 | 149 | return result 150 | 151 | 152 | 153 | def parse_args(): 154 | parser = argparse.ArgumentParser() 155 | parser.add_argument('--model_name_or_path', type=str, default="./DeepSeek-R1-Distill-Qwen-14B/") 156 | parser.add_argument('--dataset_dir', type=str, default="./data/") 157 | parser.add_argument("--dtype", type=str, default="bfloat16") 158 | parser.add_argument("--max-model-len", "--model-context-len", type=int, default=40000, dest="model_context_len") # max-model-len for vllm, should be longer than max_generated_tokens. 159 | parser.add_argument("--gpu-memory-utilization", type=float, default=0.9) 160 | parser.add_argument("--trust-remote-code", action="store_true") 161 | parser.add_argument("--run_time", type=int, default=1) 162 | parser.add_argument("--no_thinking", type=int, default=0) # Calculate the answer confidence at the very beginning of the reasoning process and attempt to exit early. 163 | parser.add_argument("--rep", type=int, default=0) # Exit early when repetition occurs, but it remains to be implemented. (TODO) 164 | parser.add_argument("--points", type=int, default=1) # 1: 'Wait' as thinking transition point. 0: 'Alternatively' as thinking transition point. 165 | parser.add_argument("--af", type=int, default=0) # answer forcing at end of sequence 166 | parser.add_argument("--max_judge_steps", type=int, default=10) # Limit the maximum number of answer attempts to save time cost. 167 | parser.add_argument('--policy', type=str, default="avg1") # Strategy for calculating answer confidence 168 | 169 | parser.add_argument('--threshold', type=float, default=0.95) # The answer confidence threshold used to determine early exit.
170 | parser.add_argument('--max_generated_tokens', '--max-len', type=int, default=16384, dest="max_len") # total token budget 171 | parser.add_argument('--dataset', type=str, default='math') # dataset name 172 | parser.add_argument('--output_path', type=str, default='./outputs') # output path 173 | parser.add_argument('--think_ratio', type=float, default=0.9, help="Ratio of thinking phase to max generated tokens") # Ratio of thinking phase to max generated tokens 174 | parser.add_argument('--batch_size', type=int, default=2000) # vllm batch size, set it to a value above the number of samples in the dataset. 175 | parser.add_argument('--temperature', type=float, default=0.0) 176 | parser.add_argument('--top_p', type=float, default=1.0) 177 | 178 | # Hardcoded 20 179 | parser.add_argument('--prob_check_max_tokens', type=int, default=20, help="Max tokens for probability check phase") # Max tokens for answer inducing 180 | 181 | args = parser.parse_args() 182 | return args 183 | 184 | def main(): 185 | args = parse_args() 186 | args.model_context_len = args.max_len + 8000 187 | print(f"Using vLLM LLM object for direct inference (batch processing)") 188 | print(f"Model path: {args.model_name_or_path}") 189 | print(f"Dataset: {args.dataset}") 190 | print(f"Early exit probability threshold: {args.threshold}") 191 | print(f"Max total generated tokens: {args.max_len}") 192 | print(f"Thinking phase ratio: {args.think_ratio}") 193 | print(f"Batch size: {args.batch_size}") 194 | print(f"Max tokens for probability check phase: {args.prob_check_max_tokens}") 195 | 196 | print("\nInitializing vLLM LLM engine...") 197 | available_gpus = os.environ['CUDA_VISIBLE_DEVICES'].split(',') 198 | try: 199 | llm_engine = LLM( 200 | model=args.model_name_or_path, 201 | tensor_parallel_size=len(available_gpus), 202 | dtype=args.dtype, 203 | max_model_len=args.max_len + 8000, 204 | gpu_memory_utilization=args.gpu_memory_utilization, 205 | trust_remote_code=True, 206 | ) 207 | print("vLLM LLM engine initialized successfully.") 208 | 209 | tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path, trust_remote_code=args.trust_remote_code) 210 | print(f"Successfully loaded tokenizer: {args.model_name_or_path}") 211 | if tokenizer.pad_token is None: 212 | if tokenizer.eos_token is not None: 213 | tokenizer.pad_token = tokenizer.eos_token 214 | else: 215 | tokenizer.add_special_tokens({'pad_token': '[PAD]'}) 216 | print("Warning: Model has no pad_token or eos_token. 
Added custom [PAD] token.") 217 | 218 | print(f"Tokenizer using pad_token_id: {tokenizer.pad_token_id}") 219 | 220 | 221 | except Exception as e: 222 | print(f"Failed to initialize vLLM LLM engine or load tokenizer: {e}") 223 | sys.exit(1) 224 | 225 | sys_prompt = 'Please reason step by step, and put your final answer within \\boxed{}.' 226 | dataset_path = f'{args.dataset_dir}/{args.dataset}/test.jsonl' 227 | try: 228 | questions_json = read_jsonl(dataset_path) 229 | if not questions_json: 230 | print(f"Error: No questions loaded from {dataset_path}.") 231 | sys.exit(1) 232 | print(f"Successfully loaded dataset: {dataset_path}, total {len(questions_json)} questions") 233 | except Exception as e: 234 | print(f"Failed to load dataset: {e}") 235 | sys.exit(1) 236 | 237 | model_dir_name = os.path.basename(os.path.normpath(args.model_name_or_path)) 238 | output_dir = f'{args.output_path}/{model_dir_name}/{args.dataset}' 239 | os.makedirs(output_dir, exist_ok=True) 240 | output_file = f'{output_dir}/greedy_p{str(args.threshold)}_ratio{str(args.think_ratio)}_len{str(args.max_len)}_temperature{str(args.temperature)}_run_time{args.run_time}_no_thinking{args.no_thinking}_rep{args.rep}_points{args.points}_policy{args.policy}.jsonl' 241 | 242 | print(f"\nStarting processing, total questions: {len(questions_json)}") 243 | start_time = time.time() 244 | 245 | questions_state = {} # Dictionary to store processing state for each question 246 | last_token_strs = ["</think>"] # Strings marking end of thinking 247 | if args.points == 1: 248 | continue_str = "Wait" # String appended to sequence end to indicate continued thinking 249 | else: 250 | continue_str = "Alternatively" # String appended to sequence end to indicate continued thinking 251 | 252 | answer_prompt_str = "\n**Final Answer**\n\\boxed" # Prompt string to guide answer generation 253 | if 'gpqa' in args.dataset: 254 | answer_prompt_str = "\n**Final Answer**\nI believe the final answer, rather than the option, is \\boxed" 255 | 256 | # Get token IDs for stop conditions and strings to append 257 | last_token_ids = [] 258 | for s in last_token_strs: 259 | ids = tokenizer.encode(s, add_special_tokens=False) 260 | if ids: last_token_ids.extend(ids) 261 | last_token_ids = list(set(last_token_ids)) # Remove duplicate IDs 262 | 263 | continue_ids = tokenizer.encode(continue_str, add_special_tokens=False) 264 | if not continue_ids: 265 | print(f"Warning: Unable to tokenize continue string '{continue_str}'. This may affect logic.") 266 | 267 | # Stop tokens for thinking phase generation 268 | generation_stop_tokens = [continue_str] + last_token_strs + [tokenizer.eos_token] 269 | pred_prob_stop_tokens = [' }', '}\n', '}\n\n', '}.', '}.\n', '}\\', '}}', ')}', ')}.', ')}\n'] # where \boxed{} ends. Used to stop the model from predicting intermediate answers.
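# --- Illustrative sketch (not part of the original file) ---
# Shape of a single probability-check request, assuming a hypothetical partial
# chain of thought already stored in current_full_sequence. The trial answer is
# induced, stopped at one of the '}'-style strings above, and its logprobs scored:
#   trial_prompt = current_full_sequence + answer_prompt_str
#   params = SamplingParams(max_tokens=args.prob_check_max_tokens, stop=pred_prob_stop_tokens, logprobs=1)
#   out = llm_engine.generate([trial_prompt], [params], use_tqdm=False)[0].outputs[0]
#   confidence = calculate_average_max_prob_from_logprobs(out.logprobs, args.policy)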
270 | 271 | answer_stop_tokens = [tokenizer.eos_token] 272 | 273 | 274 | # Max token limit for thinking phase 275 | think_limit_tokens = int(args.max_len * args.think_ratio) 276 | 277 | for i, question_data in enumerate(questions_json): 278 | 279 | messages = [ 280 | {"role": "system", "content": sys_prompt}, 281 | {"role": "user", "content": question_data['problem']} 282 | ] 283 | formatted_prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True) 284 | if args.no_thinking == 1: 285 | questions_state[i] = { 286 | 'question_data': question_data, 287 | 'state': 'needs_prob_check', # try to exit at the very beginning of reasoning 288 | 'formatted_prompt': formatted_prompt, 289 | 'current_full_sequence': formatted_prompt, 290 | 'generated_thinking_history': "", 291 | 'generated_thinking_last_trun': "nsss, ssssa, wrtt, yyy, sss", # dummy seed text so the first repetition check cannot match 292 | 'generated_answer_history': "", 293 | 'pred_prob': 0.0, 294 | 'too_long': 0, 295 | 'rep_end': 0, 296 | 'high_prob': 0, 297 | 'regular_end': 0, 298 | 'thinking_steps': 0, 299 | 'output_dict': {}, 300 | 'error_message': None, 301 | 'question_index': i 302 | } 303 | else: 304 | questions_state[i] = { 305 | 'question_data': question_data, 306 | 'state': 'needs_thought_chunk', 307 | 'formatted_prompt': formatted_prompt, 308 | 'current_full_sequence': formatted_prompt, 309 | 'generated_thinking_history': "", 310 | 'generated_thinking_last_trun': "nsss, ssssa, wrtt, yyy, sss", # dummy seed text so the first repetition check cannot match 311 | 'generated_answer_history': "", 312 | 'pred_prob': 0.0, 313 | 'too_long': 0, 314 | 'rep_end': 0, 315 | 'high_prob': 0, 316 | 'regular_end': 0, 317 | 'thinking_steps': 0, 318 | 'output_dict': {}, 319 | 'error_message': None, 320 | 'question_index': i 321 | } 322 | 323 | active_questions_indices = list(questions_state.keys()) # List of currently processing question indices 324 | pbar = tqdm(total=len(questions_json), desc="Processing questions") 325 | 326 | print("\nRunning a simple test generation...") 327 | try: 328 | test_outputs = llm_engine.generate(["Hello, world!"], SamplingParams(max_tokens=10, temperature=args.temperature), use_tqdm=False) 329 | if test_outputs and test_outputs[0].outputs: 330 | test_generated_text = test_outputs[0].outputs[0].text 331 | print(f"Test generation successful. 
Output: '{test_generated_text.strip()}'") 332 | else: 333 | print("Simple test generation failed: LLM generate returned no output.") 334 | except Exception as e: 335 | print(f"Simple test generation failed: {e}") 336 | 337 | # Main processing loop: continue while there are active questions 338 | while active_questions_indices: # indexes [0,1,2,...,n] 339 | batch_prompts = [] # Current batch prompts 340 | batch_sampling_params = [] # Current batch sampling parameters 341 | batch_request_info = [] # Store (question index, step type) for output processing 342 | 343 | current_batch_count = 0 344 | # Create copy of active_questions_indices to allow modifying original list during iteration 345 | current_active_indices_for_batching = active_questions_indices[:] 346 | 347 | # Build current batch 348 | for q_idx in current_active_indices_for_batching: 349 | if current_batch_count >= args.batch_size: 350 | break 351 | 352 | state = questions_state[q_idx] 353 | if state['state'] in ['finished', 'error']: 354 | continue 355 | 356 | prompt_for_batch = None 357 | sampling_params_for_batch = None 358 | step_type = None # 'think', 'prob_check_gen', 'answer' 359 | 360 | try: 361 | # --- Determine prompt and parameters based on state --- 362 | current_full_sequence_tokens = tokenizer.encode(state['current_full_sequence'], add_special_tokens=False) 363 | current_full_sequence_len = len(current_full_sequence_tokens) 364 | current_generated_thinking_tokens = len(tokenizer.encode(state['generated_thinking_history'], add_special_tokens=False)) 365 | initial_prompt_len = len(tokenizer.encode(state['formatted_prompt'], add_special_tokens=False)) 366 | 367 | # Check context window limit before adding any new content 368 | remaining_context_window = args.model_context_len - current_full_sequence_len 369 | # check window, skip it 370 | if remaining_context_window <= 0: 371 | state['state'] = 'error' 372 | state['error_message'] = f"Exceeded model context window limit ({args.model_context_len}). Current sequence length: {current_full_sequence_len}" 373 | state['output_dict'] = {'question': state['question_data']['problem'], 'generated_responses': [state['generated_thinking_history'] + state['generated_answer_history'] + "\n" + state['error_message']], 'gold_answer': state['question_data']['answer']} 374 | print(f"\nQuestion {q_idx}: {state['error_message']}") 375 | if q_idx in active_questions_indices: 376 | active_questions_indices.remove(q_idx) 377 | pbar.update(1) 378 | continue 379 | 380 | # Initial state 381 | if state['state'] == 'needs_thought_chunk': 382 | # Calculate max tokens for this generation chunk, considering thinking limit and context window 383 | max_new_tokens_for_thought = min( 384 | think_limit_tokens - current_generated_thinking_tokens, # thinking budget 385 | remaining_context_window 386 | ) 387 | if max_new_tokens_for_thought <= 0: 388 | state['state'] = 'needs_answer' 389 | print(f"\nQuestion {q_idx}: Reached thinking limit ({current_generated_thinking_tokens}/{think_limit_tokens}). 
Switching to answer generation.") 390 | state['too_long'] = 1 391 | continue 392 | 393 | 394 | 395 | prompt_for_batch = state['current_full_sequence'] 396 | if state['thinking_steps'] < args.max_judge_steps: 397 | sampling_params_for_batch = SamplingParams( 398 | max_tokens=max_new_tokens_for_thought, 399 | temperature=args.temperature, 400 | top_p=args.top_p, 401 | stop=generation_stop_tokens 402 | ) 403 | else: 404 | sampling_params_for_batch = SamplingParams( 405 | max_tokens=max_new_tokens_for_thought, 406 | temperature=args.temperature, 407 | top_p=args.top_p, 408 | stop=last_token_strs 409 | ) 410 | step_type = 'think' 411 | 412 | # Probability-check state 413 | elif state['state'] == 'needs_prob_check': 414 | # Build prompt for probability check generation: 415 | prompt_for_prob_check = state['current_full_sequence'] + answer_prompt_str 416 | prob_check_prompt_len = len(tokenizer.encode(prompt_for_prob_check, add_special_tokens=False)) 417 | required_space_for_prob_check = prob_check_prompt_len + args.prob_check_max_tokens 418 | 419 | if required_space_for_prob_check > args.model_context_len: 420 | state['state'] = 'error' 421 | state['error_message'] = f"Probability check generation prompt exceeds context window ({args.model_context_len}). Estimated length: {required_space_for_prob_check}" 422 | state['output_dict'] = {'question': state['question_data']['problem'], 'generated_responses': [state['generated_thinking_history'] + state['generated_answer_history'] + "\n" + state['error_message']], 'gold_answer': state['question_data']['answer']} 423 | print(f"\nQuestion {q_idx}: {state['error_message']}") 424 | if q_idx in active_questions_indices: 425 | active_questions_indices.remove(q_idx) 426 | pbar.update(1) 427 | continue 428 | 429 | 430 | # Parameters for generating *prediction sequence* in probability check phase 431 | prompt_for_batch = prompt_for_prob_check 432 | sampling_params_for_batch = SamplingParams( 433 | max_tokens=args.prob_check_max_tokens, 434 | #temperature=args.temperature, # Greedy decoding 435 | stop=pred_prob_stop_tokens, # Only predict content inside \boxed{} 436 | logprobs=1, 437 | ) 438 | step_type = 'prob_check_gen' 439 | 440 | 441 | elif state['state'] == 'needs_answer': 442 | # Build final answer prompt 443 | 444 | if state['too_long'] == 1: 445 | final_answer_prompt = state['formatted_prompt'] + state['generated_thinking_history'] + '\n</think>\n\n' 446 | state['generated_thinking_history'] = state['generated_thinking_history'] + '\n</think>\n\n' 447 | 448 | else: 449 | final_answer_prompt = state['formatted_prompt'] + state['generated_thinking_history'] + '\n</think>\n\n'#+ answer_prompt_str 450 | state['generated_thinking_history'] = state['generated_thinking_history'] + '\n</think>\n\n' 451 | 452 | len_final_answer_prompt = len(tokenizer.encode(final_answer_prompt, add_special_tokens=False)) 453 | total_tokens_before_answer_prompt = current_generated_thinking_tokens 454 | # Calculate remaining total budget 455 | remaining_total_budget = args.max_len - total_tokens_before_answer_prompt 456 | max_new_tokens_answer = min( 457 | remaining_total_budget, 458 | args.model_context_len - len_final_answer_prompt 459 | ) 460 | 461 | if max_new_tokens_answer <= 0: 462 | 463 | state['state'] = 'error' 464 | state['output_dict'] = {'question': state['question_data']['problem'], 'generated_responses': [state['generated_thinking_history'] + "\nSkipped answer generation due to length limit."], 'gold_answer': state['question_data']['answer'], 'too_long': state['too_long'], 'thinking_steps': 
state['thinking_steps']} 465 | print(f"\nQuestion {q_idx}: Skipped answer generation due to length limit.") 466 | 467 | if q_idx in active_questions_indices: 468 | active_questions_indices.remove(q_idx) 469 | pbar.update(1) 470 | continue 471 | else: 472 | prompt_for_batch = final_answer_prompt 473 | sampling_params_for_batch = SamplingParams( 474 | max_tokens=max_new_tokens_answer, 475 | temperature=args.temperature, 476 | stop=answer_stop_tokens, 477 | top_p=args.top_p, 478 | ) 479 | step_type = 'answer' 480 | 481 | 482 | 483 | 484 | 485 | elif state['state'] == 'answer_forcing': 486 | final_answer_prompt = state['formatted_prompt'] + state['generated_thinking_history'] + state['generated_answer_history'] + answer_prompt_str 487 | state['generated_thinking_history'] = state['generated_thinking_history'] + state['generated_answer_history'] + answer_prompt_str 488 | 489 | prompt_for_batch = final_answer_prompt 490 | sampling_params_for_batch = SamplingParams( 491 | max_tokens=100, 492 | temperature=args.temperature, 493 | stop=answer_stop_tokens, 494 | top_p=args.top_p, 495 | ) 496 | step_type = 'answer_exit' 497 | 498 | # If execution reaches here and prompt_for_batch is None, there may be state logic issues 499 | # This check generally shouldn't trigger 500 | if prompt_for_batch is None: 501 | state['state'] = 'error' 502 | state['error_message'] = f"Internal error: No prompt generated for state {state['state']}." 503 | state['output_dict'] = {'question': state['question_data']['problem'], 'generated_responses': [state['generated_thinking_history'] + state['generated_answer_history'] + "\n" + state['error_message']], 'gold_answer': state['question_data']['answer']} 504 | print(f"\nQuestion {q_idx}: {state['error_message']}") 505 | if q_idx in active_questions_indices: 506 | active_questions_indices.remove(q_idx) 507 | pbar.update(1) 508 | continue 509 | 510 | batch_prompts.append(prompt_for_batch) 511 | batch_sampling_params.append(sampling_params_for_batch) 512 | batch_request_info.append((q_idx, step_type)) 513 | current_batch_count += 1 514 | 515 | except Exception as e: 516 | state['state'] = 'error' 517 | state['error_message'] = f"Error preparing batch request for state '{state['state']}': {e}" 518 | state['output_dict'] = {'question': state['question_data']['problem'], 'generated_responses': [state['generated_thinking_history'] + state['generated_answer_history'] + "\n" + state['error_message']], 'gold_answer': state['question_data']['answer']} 519 | print(f"\nQuestion {q_idx}: {state['error_message']}") 520 | if q_idx in active_questions_indices: 521 | active_questions_indices.remove(q_idx) 522 | pbar.update(1) 523 | 524 | if not batch_prompts: 525 | # May occur when all remaining active questions have switched to finished/error state or were skipped. 526 | all_stuck = True 527 | for q_idx in active_questions_indices: 528 | if questions_state[q_idx]['state'] not in ['finished', 'error']: 529 | all_stuck = False 530 | break 531 | if all_stuck: 532 | print("Warning: Batch generated no requests. All remaining questions are completed or in error state.") 533 | break 534 | else: 535 | print("Error: Active questions remain but no batch requests generated. Possible logic error.") 536 | for q_idx in list(active_questions_indices): 537 | state = questions_state[q_idx] 538 | if state['state'] not in ['finished', 'error']: 539 | state['state'] = 'error' 540 | state['error_message'] = "Processing aborted: Unable to generate request in batch loop." 
        if not batch_prompts:
            # Can happen when every remaining active question has just switched to a
            # finished/error state or was skipped this pass.
            all_stuck = True
            for q_idx in active_questions_indices:
                if questions_state[q_idx]['state'] not in ['finished', 'error']:
                    all_stuck = False
                    break
            if all_stuck:
                print("Warning: Batch generated no requests. All remaining questions are completed or in an error state.")
                break
            else:
                print("Error: Active questions remain but no batch requests were generated. Possible logic error.")
                for q_idx in list(active_questions_indices):
                    state = questions_state[q_idx]
                    if state['state'] not in ['finished', 'error']:
                        state['state'] = 'error'
                        state['error_message'] = "Processing aborted: Unable to generate request in batch loop."
                        state['output_dict'] = {'question': state['question_data']['problem'], 'generated_responses': [state['generated_thinking_history'] + state['generated_answer_history'] + "\n" + state['error_message']], 'gold_answer': state['question_data']['answer']}
                        print(f"\nQuestion {q_idx}: Marked as error due to processing abort.")
                    if q_idx in active_questions_indices:  # should always hold; defensive check
                        active_questions_indices.remove(q_idx)
                        pbar.update(1)
                break

        batch_outputs = llm_engine.generate(batch_prompts, batch_sampling_params, use_tqdm=False)
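
        # Added note: vLLM's LLM.generate returns one RequestOutput per input prompt,
        # aligned with the input order, so indexing batch_request_info by the output
        # position below recovers the matching (q_idx, step_type).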

        # --- Process batch outputs ---
        for i, output in enumerate(batch_outputs):
            q_idx, step_type = batch_request_info[i]
            state = questions_state[q_idx]

            if state['state'] in ['finished', 'error']:
                continue

            try:
                if not output.outputs:
                    # vLLM returned an empty output, possibly due to prompt issues,
                    # length limits, or other internal errors; build a diagnostic
                    # message and raise.
                    error_msg = f"vLLM returned empty output for request {output.request_id} (question {q_idx}, step {step_type})."
                    if hasattr(output, 'error') and output.error:
                        error_msg += f" vLLM error: {output.error}"
                    current_full_sequence_len = len(tokenizer.encode(state['current_full_sequence'], add_special_tokens=False))
                    initial_prompt_len = len(tokenizer.encode(state['formatted_prompt'], add_special_tokens=False))
                    current_generated_thinking_tokens = len(tokenizer.encode(state['generated_thinking_history'], add_special_tokens=False))

                    if step_type == 'think':
                        max_new = min(think_limit_tokens - current_generated_thinking_tokens, args.model_context_len - current_full_sequence_len)
                        error_msg += f" State: think, attempting to generate {max_new} tokens."
                        if max_new <= 0: error_msg += " Note: Calculated max new tokens <= 0."
                        if args.model_context_len - current_full_sequence_len <= 0: error_msg += " Note: Already exceeded context window."
                        if think_limit_tokens - current_generated_thinking_tokens <= 0: error_msg += " Note: Reached thinking token limit."

                    elif step_type == 'prob_check_gen':
                        prob_check_prompt = state['current_full_sequence'] + answer_prompt_str
                        prob_check_prompt_len = len(tokenizer.encode(prob_check_prompt, add_special_tokens=False))
                        error_msg += f" State: prob_check_gen, prompt length: {prob_check_prompt_len}, attempting to generate {args.prob_check_max_tokens} tokens."
                        if prob_check_prompt_len + args.prob_check_max_tokens > args.model_context_len: error_msg += " Note: Prompt + generation exceeds the context window limit."

                    elif step_type == 'answer':
                        final_answer_prompt = state['formatted_prompt'] + state['generated_thinking_history'] + answer_prompt_str
                        len_final_answer_prompt = len(tokenizer.encode(final_answer_prompt, add_special_tokens=False))
                        total_tokens_before_answer_prompt = initial_prompt_len + current_generated_thinking_tokens
                        max_new_answer = min(args.max_len - total_tokens_before_answer_prompt, args.model_context_len - len_final_answer_prompt)
                        error_msg += f" State: answer, prompt length: {len_final_answer_prompt}, attempting to generate {max_new_answer} tokens."
                        if max_new_answer <= 0: error_msg += " Note: Calculated max new tokens <= 0."
                        if args.model_context_len - len_final_answer_prompt <= 0: error_msg += " Note: Already exceeded context window."
                        if args.max_len - total_tokens_before_answer_prompt <= 0: error_msg += " Note: Reached total token limit."

                    raise ValueError(error_msg)

                completion_output = output.outputs[0]
                generated_text = completion_output.text
                generated_ids = completion_output.token_ids
                last_token_id = generated_ids[-1] if generated_ids else None  # guard against zero-token completions

                # Update the n-gram repetition counter against the previous thinking chunk.
                rep = seq_rep_n(state['generated_thinking_last_trun'], generated_text, state['rep_end'])
                state['rep_end'] = rep

                if step_type == 'think':
                    if state['rep_end'] >= 3 and args.rep == 1:
                        # Repetition detected: stop thinking and move to the answer phase.
                        state['state'] = 'needs_answer'
                        state['generated_thinking_history'] += generated_text
                        state['current_full_sequence'] = state['formatted_prompt'] + state['generated_thinking_history']
                    elif last_token_id in last_token_ids:
                        # The model ended its thinking naturally.
                        state['state'] = 'needs_answer'
                        state['generated_thinking_history'] += generated_text
                        state['current_full_sequence'] = state['formatted_prompt'] + state['generated_thinking_history']
                        state['regular_end'] = 1
                    else:
                        # Append the generated thinking chunk and schedule a confidence probe.
                        state['generated_thinking_history'] += generated_text
                        state['generated_thinking_last_trun'] = generated_text
                        state['current_full_sequence'] = state['formatted_prompt'] + state['generated_thinking_history']
                        state['state'] = 'needs_prob_check'
                    state['thinking_steps'] += 1
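
                # Added note: a minimal sketch of what the confidence score below might
                # compute, assuming the 'average' policy and vLLM's logprobs structure
                # (the real calculate_average_max_prob_from_logprobs is defined earlier
                # in this file):
                #   token_probs = [math.exp(max(pos.values(), key=lambda l: l.logprob).logprob)
                #                  for pos in completion_output.logprobs]
                #   pred_prob = sum(token_probs) / len(token_probs)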
                elif step_type == 'prob_check_gen':
                    # Use the logprobs of the trial answer to compute the confidence score.
                    if completion_output.logprobs:
                        state['pred_prob'] = calculate_average_max_prob_from_logprobs(completion_output.logprobs, args.policy)
                    else:
                        print(f"Warning: No logprobs returned for prob_check_gen for question {q_idx}. Setting pred_prob to 0.0.")
                        state['pred_prob'] = 0.0

                    # Recalculate the thinking-history token length before deciding.
                    current_generated_thinking_tokens = len(tokenizer.encode(state['generated_thinking_history'], add_special_tokens=False))
                    thinking_limit_reached = current_generated_thinking_tokens >= think_limit_tokens - 50

                    if state['pred_prob'] > args.threshold or thinking_limit_reached:
                        # Confidence is high enough, or the thinking budget is (nearly)
                        # exhausted: switch to the answer phase.
                        state['state'] = 'needs_answer'
                        if thinking_limit_reached:
                            print(f"\nQuestion {q_idx}: Reached thinking limit ({current_generated_thinking_tokens}/{think_limit_tokens}). Switching to answer phase.")
                            state['too_long'] = 1
                        else:
                            print(f"\nQuestion {q_idx}: Reached early exit threshold ({state['pred_prob']:.4f} > {args.threshold}). Switching to answer phase.")
                            state['high_prob'] = 1
                    else:
                        # Confidence too low: continue thinking.
                        state['state'] = 'needs_thought_chunk'
                        if not state['current_full_sequence'].strip().endswith(continue_str) and state['thinking_steps'] != 0:
                            state['current_full_sequence'] += continue_str
                            state['generated_thinking_history'] += continue_str
                        print(f"\nQuestion {q_idx}: Early exit threshold not reached ({state['pred_prob']:.4f} <= {args.threshold}), thinking history length ({current_generated_thinking_tokens}/{think_limit_tokens}). Appending '{continue_str}' and continuing thinking.")

                elif step_type == 'answer':
                    state['generated_answer_history'] += generated_text
                    if last_token_id != tokenizer.eos_token_id and args.af == 1:
                        # No EOS: trigger answer forcing.
                        state['state'] = 'answer_forcing'
                    else:
                        state['current_full_sequence'] = state['formatted_prompt'] + state['generated_thinking_history'] + state['generated_answer_history']
                        state['state'] = 'finished'
                        final_response_text = state['generated_thinking_history'] + state['generated_answer_history']
                        state['output_dict'] = {'question': state['question_data']['problem'], 'generated_responses': [final_response_text], 'gold_answer': state['question_data']['answer'], 'too_long': state['too_long'], 'thinking_steps': state['thinking_steps'], 'rep_end': state['rep_end'], 'high_prob': state['high_prob'], 'regular_end': state['regular_end']}
                        if q_idx in active_questions_indices:
                            active_questions_indices.remove(q_idx)
                            pbar.update(1)

                elif step_type == 'answer_exit':
                    state['current_full_sequence'] = state['formatted_prompt'] + state['generated_thinking_history'] + generated_text
                    state['state'] = 'finished'
                    final_response_text = state['generated_thinking_history'] + generated_text
                    state['output_dict'] = {'question': state['question_data']['problem'], 'generated_responses': [final_response_text], 'gold_answer': state['question_data']['answer'], 'too_long': state['too_long'], 'thinking_steps': state['thinking_steps'], 'rep_end': state['rep_end'], 'high_prob': state['high_prob'], 'regular_end': state['regular_end']}
                    if q_idx in active_questions_indices:
                        active_questions_indices.remove(q_idx)
                        pbar.update(1)

            except Exception as e:
                print(f"\nError processing batch results for question {q_idx} step '{step_type}': {e}")
                state['state'] = 'error'
                state['error_message'] = f"Error processing batch results for step '{step_type}': {e}"
                state['output_dict'] = {'question': state['question_data']['problem'], 'generated_responses': [state['generated_thinking_history'] + state['generated_answer_history'] + "\nError: " + state['error_message']], 'gold_answer': state['question_data']['answer']}

                if q_idx in active_questions_indices:
                    active_questions_indices.remove(q_idx)
                    pbar.update(1)
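
    # Added note: the loop above leaves every question in 'finished' or 'error', and
    # both paths populate state['output_dict']; after closing the progress bar we
    # collect one record per question and restore the original dataset order.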

    pbar.close()
    final_results = [state['output_dict'] for state in questions_state.values() if state['state'] in ['finished', 'error']]

    # Map problem text back to its original index so results are saved in dataset order.
    problem_to_index = {item['problem']: i for i, item in enumerate(questions_json)}
    # Use .get so any unmatched problem text (which shouldn't occur) sorts last.
    final_results.sort(key=lambda x: problem_to_index.get(x['question'], len(questions_json)))

    print("\nAll questions processed, saving results...")
    try:
        write_jsonl(final_results, output_file)
    except Exception as e:
        print(f"Error saving results to {output_file}: {e}")

    end_time = time.time()
    elapsed_time = end_time - start_time
    print(f"Evaluation completed! Attempted to process {len(questions_json)} questions in total, successfully recorded {len(final_results)} results, took {elapsed_time:.2f} seconds.")
    print(f"Results saved to: {output_file}")

if __name__ == "__main__":
    # set_seeds(42)
    main()
--------------------------------------------------------------------------------