├── images
├── deer.png
├── deer-ds.png
├── deer-main.png
├── deer-res.png
├── WechatIMG.jpeg
└── deer-qwen3.png
├── utils
├── __pycache__
│ ├── grader.cpython-310.pyc
│ ├── parser.cpython-310.pyc
│ ├── utils.cpython-310.pyc
│ ├── examples.cpython-310.pyc
│ ├── data_loader.cpython-310.pyc
│ └── math_normalization.cpython-310.pyc
├── data_loader.py
├── math_normalization.py
├── utils.py
├── parser.py
└── grader.py
├── prompts
└── qwen-instruct
│ ├── __pycache__
│ ├── aime.cpython-310.pyc
│ ├── amc.cpython-310.pyc
│ ├── gpqa.cpython-310.pyc
│ ├── gsm8k.cpython-310.pyc
│ ├── math.cpython-310.pyc
│ ├── aime25.cpython-310.pyc
│ ├── minerva.cpython-310.pyc
│ └── olympiadbench.cpython-310.pyc
│ ├── aime.py
│ ├── amc.py
│ ├── gpqa.py
│ ├── gsm8k.py
│ ├── math.py
│ ├── aime25.py
│ ├── minerva.py
│ └── olympiadbench.py
├── bashes
├── bash-vanilla-deer.sh
├── bash-check-correct.sh
├── bash-vllm-deer-qwen3.sh
└── bash-vllm-deer.sh
├── requirements.txt
├── LICENSE
├── README.md
├── check.py
├── vanilla_deer.py
├── data
├── aime25
│ └── test.jsonl
└── amc
│ └── test.jsonl
├── vllm-deer-qwen3.py
└── vllm-deer.py
/images/deer.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/iie-ycx/DEER/HEAD/images/deer.png
--------------------------------------------------------------------------------
/images/deer-ds.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/iie-ycx/DEER/HEAD/images/deer-ds.png
--------------------------------------------------------------------------------
/images/deer-main.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/iie-ycx/DEER/HEAD/images/deer-main.png
--------------------------------------------------------------------------------
/images/deer-res.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/iie-ycx/DEER/HEAD/images/deer-res.png
--------------------------------------------------------------------------------
/images/WechatIMG.jpeg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/iie-ycx/DEER/HEAD/images/WechatIMG.jpeg
--------------------------------------------------------------------------------
/images/deer-qwen3.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/iie-ycx/DEER/HEAD/images/deer-qwen3.png
--------------------------------------------------------------------------------
/utils/__pycache__/grader.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/iie-ycx/DEER/HEAD/utils/__pycache__/grader.cpython-310.pyc
--------------------------------------------------------------------------------
/utils/__pycache__/parser.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/iie-ycx/DEER/HEAD/utils/__pycache__/parser.cpython-310.pyc
--------------------------------------------------------------------------------
/utils/__pycache__/utils.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/iie-ycx/DEER/HEAD/utils/__pycache__/utils.cpython-310.pyc
--------------------------------------------------------------------------------
/utils/__pycache__/examples.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/iie-ycx/DEER/HEAD/utils/__pycache__/examples.cpython-310.pyc
--------------------------------------------------------------------------------
/utils/__pycache__/data_loader.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/iie-ycx/DEER/HEAD/utils/__pycache__/data_loader.cpython-310.pyc
--------------------------------------------------------------------------------
/utils/__pycache__/math_normalization.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/iie-ycx/DEER/HEAD/utils/__pycache__/math_normalization.cpython-310.pyc
--------------------------------------------------------------------------------
/prompts/qwen-instruct/__pycache__/aime.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/iie-ycx/DEER/HEAD/prompts/qwen-instruct/__pycache__/aime.cpython-310.pyc
--------------------------------------------------------------------------------
/prompts/qwen-instruct/__pycache__/amc.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/iie-ycx/DEER/HEAD/prompts/qwen-instruct/__pycache__/amc.cpython-310.pyc
--------------------------------------------------------------------------------
/prompts/qwen-instruct/__pycache__/gpqa.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/iie-ycx/DEER/HEAD/prompts/qwen-instruct/__pycache__/gpqa.cpython-310.pyc
--------------------------------------------------------------------------------
/prompts/qwen-instruct/__pycache__/gsm8k.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/iie-ycx/DEER/HEAD/prompts/qwen-instruct/__pycache__/gsm8k.cpython-310.pyc
--------------------------------------------------------------------------------
/prompts/qwen-instruct/__pycache__/math.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/iie-ycx/DEER/HEAD/prompts/qwen-instruct/__pycache__/math.cpython-310.pyc
--------------------------------------------------------------------------------
/prompts/qwen-instruct/__pycache__/aime25.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/iie-ycx/DEER/HEAD/prompts/qwen-instruct/__pycache__/aime25.cpython-310.pyc
--------------------------------------------------------------------------------
/prompts/qwen-instruct/__pycache__/minerva.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/iie-ycx/DEER/HEAD/prompts/qwen-instruct/__pycache__/minerva.cpython-310.pyc
--------------------------------------------------------------------------------
/prompts/qwen-instruct/__pycache__/olympiadbench.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/iie-ycx/DEER/HEAD/prompts/qwen-instruct/__pycache__/olympiadbench.cpython-310.pyc
--------------------------------------------------------------------------------
/prompts/qwen-instruct/aime.py:
--------------------------------------------------------------------------------
# Prompt settings for this benchmark: a chain-of-thought system instruction,
# no few-shot exemplars, and the question passed through unchanged.
system_prompt = r"Please reason step by step, and put your final answer within \boxed{}."

few_shot_prompt = ""

question_format = "{question}"
--------------------------------------------------------------------------------
/prompts/qwen-instruct/amc.py:
--------------------------------------------------------------------------------
# Prompt settings for this benchmark: a chain-of-thought system instruction,
# no few-shot exemplars, and the question passed through unchanged.
system_prompt = r"Please reason step by step, and put your final answer within \boxed{}."

few_shot_prompt = ""

question_format = "{question}"
--------------------------------------------------------------------------------
/prompts/qwen-instruct/gpqa.py:
--------------------------------------------------------------------------------
# Prompt settings for this benchmark: a chain-of-thought system instruction,
# no few-shot exemplars, and the question passed through unchanged.
system_prompt = r"Please reason step by step, and put your final answer within \boxed{}."

few_shot_prompt = ""

question_format = "{question}"
--------------------------------------------------------------------------------
/prompts/qwen-instruct/gsm8k.py:
--------------------------------------------------------------------------------
# Prompt settings for this benchmark: a chain-of-thought system instruction,
# no few-shot exemplars, and the question passed through unchanged.
system_prompt = r"Please reason step by step, and put your final answer within \boxed{}."

few_shot_prompt = ""

question_format = "{question}"
--------------------------------------------------------------------------------
/prompts/qwen-instruct/math.py:
--------------------------------------------------------------------------------
# Prompt settings for this benchmark: a chain-of-thought system instruction,
# no few-shot exemplars, and the question passed through unchanged.
system_prompt = r"Please reason step by step, and put your final answer within \boxed{}."

few_shot_prompt = ""

question_format = "{question}"
--------------------------------------------------------------------------------
/prompts/qwen-instruct/aime25.py:
--------------------------------------------------------------------------------
# Prompt settings for this benchmark: a chain-of-thought system instruction,
# no few-shot exemplars, and the question passed through unchanged.
system_prompt = r"Please reason step by step, and put your final answer within \boxed{}."

few_shot_prompt = ""

question_format = "{question}"
--------------------------------------------------------------------------------
/prompts/qwen-instruct/minerva.py:
--------------------------------------------------------------------------------
# Prompt settings for this benchmark: a chain-of-thought system instruction,
# no few-shot exemplars, and the question passed through unchanged.
system_prompt = r"Please reason step by step, and put your final answer within \boxed{}."

few_shot_prompt = ""

question_format = "{question}"
--------------------------------------------------------------------------------
/prompts/qwen-instruct/olympiadbench.py:
--------------------------------------------------------------------------------
# Prompt settings for this benchmark: a chain-of-thought system instruction,
# no few-shot exemplars, and the question passed through unchanged.
system_prompt = r"Please reason step by step, and put your final answer within \boxed{}."

few_shot_prompt = ""

question_format = "{question}"
--------------------------------------------------------------------------------
/bashes/bash-vanilla-deer.sh:
--------------------------------------------------------------------------------
#!/usr/bin/env bash
# Run DEER with the HuggingFace Transformers backend (vanilla_deer.py).
# The Python entry point and the relative model path live at the repository
# root, so switch there first; the script then works from any directory,
# e.g. `bash ./bashes/bash-vanilla-deer.sh`.
set -euo pipefail
cd "$(dirname "${BASH_SOURCE[0]}")/.."

CUDA_VISIBLE_DEVICES='1' \
python vanilla_deer.py \
    --model_name_or_path ./DeepSeek-R1-Distill-Qwen-14B \
    --threshold 0.95 \
    --max_len 16384 \
    --dataset math
--------------------------------------------------------------------------------
/bashes/bash-check-correct.sh:
--------------------------------------------------------------------------------
#!/usr/bin/env bash
# Grade a DEER generation file with check.py.
# check.py reads its dataset from ./data (its --data_dir default), which lives
# at the repository root, so switch there first; the script then works from
# any directory, e.g. `bash ./bashes/bash-check-correct.sh`.
set -euo pipefail
cd "$(dirname "${BASH_SOURCE[0]}")/.."

python check.py \
    --model_name_or_path "./DeepSeek-R1-Distill-Qwen-14B" \
    --data_name "math" \
    --generation_path "./outputs/DeepSeek-R1-Distill-Qwen-14B/math/greedy_p0.95_ratio0.9_len16385_temperature0.0_run_time1_no_thinking0_rep0_points1.jsonl"
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | # common
2 | vllm<=0.6.1
3 | tqdm
4 | datasets
5 | torch
6 | transformers
7 | python_dateutil
8 | flash_attn
9 |
10 | # math_eval
11 | sympy==1.12
12 | antlr4-python3-runtime==4.11.1 # ! The version needs to be compatible with sympy.
13 | word2number
14 | Pebble
15 | timeout-decorator
16 | latex2sympy2==1.9.1
--------------------------------------------------------------------------------
/bashes/bash-vllm-deer-qwen3.sh:
--------------------------------------------------------------------------------
#!/usr/bin/env bash
# Run DEER with the vLLM backend for Qwen3 models (vllm-deer-qwen3.py).
# All relative paths below (model, ./data/, ./outputs) are resolved from the
# repository root, so switch there first; the script then works from any
# directory, e.g. `bash ./bashes/bash-vllm-deer-qwen3.sh`.
set -euo pipefail
cd "$(dirname "${BASH_SOURCE[0]}")/.."

CUDA_VISIBLE_DEVICES=1 python vllm-deer-qwen3.py \
    --model_name_or_path "./Qwen3-4B" \
    --dataset_dir "./data/" \
    --output_path "./outputs" \
    --dataset "math" \
    --threshold 0.95 \
    --max_generated_tokens 16000 \
    --think_ratio 0.8 \
    --policy avg2 \
    --batch_size 2000 \
    --dtype bfloat16 \
    --gpu-memory-utilization 0.9
--------------------------------------------------------------------------------
/bashes/bash-vllm-deer.sh:
--------------------------------------------------------------------------------
#!/usr/bin/env bash
# Run DEER with the vLLM backend (vllm-deer.py) on the `math` dataset.
# All relative paths below (model, ./data/, ./outputs) are resolved from the
# repository root, so switch there first; the script then works from any
# directory, e.g. `bash ./bashes/bash-vllm-deer.sh`.
set -euo pipefail
cd "$(dirname "${BASH_SOURCE[0]}")/.."

CUDA_VISIBLE_DEVICES=1 python vllm-deer.py \
    --model_name_or_path "./DeepSeek-R1-Distill-Qwen-14B" \
    --dataset_dir "./data/" \
    --output_path "./outputs" \
    --dataset "math" \
    --threshold 0.95 \
    --max_generated_tokens 16000 \
    --policy avg1 \
    --think_ratio 0.6 \
    --batch_size 2000 \
    --dtype bfloat16 \
    --gpu-memory-utilization 0.9
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2025 chenxuYang
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/utils/data_loader.py:
--------------------------------------------------------------------------------
1 | import os
2 | import json
3 | import random
4 | from datasets import load_dataset, Dataset, concatenate_datasets
5 | from utils.utils import load_jsonl, lower_keys
6 |
def load_data(data_name, split, data_dir='./data'):
    """Load evaluation examples for ``data_name`` / ``split``.

    If ``{data_dir}/{data_name}/{split}.jsonl`` exists it is read directly.
    Otherwise the raw dataset is fetched (HuggingFace hub or auxiliary local
    files under ``data_dir``), every example's keys are lower-cased, and the
    result is cached to that jsonl path for subsequent calls.

    Args:
        data_name: Dataset identifier (e.g. "math", "gsm8k", "mawps").
        split: Split name; also used as the cached file's stem.
        data_dir: Root directory holding per-dataset folders.

    Returns:
        List of example dicts, each with an ``idx`` key, sorted by ``idx``.

    Raises:
        NotImplementedError: if no cached file exists and ``data_name`` is
            not one of the supported remote/auxiliary sources.
    """
    data_file = f"{data_dir}/{data_name}/{split}.jsonl"
    if os.path.exists(data_file):
        examples = list(load_jsonl(data_file))
    else:
        dataset = _load_source_dataset(data_name, split, data_dir)
        examples = [lower_keys(example) for example in dataset]
        os.makedirs(f"{data_dir}/{data_name}", exist_ok=True)
        Dataset.from_list(examples).to_json(data_file)

    # add 'idx' in the first column (guarded so an empty dataset does not
    # raise IndexError)
    if examples and 'idx' not in examples[0]:
        examples = [{'idx': i, **example} for i, example in enumerate(examples)]

    # sort by 'idx' (no deduplication is performed)
    return sorted(examples, key=lambda x: x['idx'])


def _load_source_dataset(data_name, split, data_dir):
    """Fetch the uncached dataset for ``data_name``; raise NotImplementedError if unknown."""
    if data_name == "math":
        return load_dataset("competition_math", split=split, name="main", cache_dir=f"{data_dir}/temp")
    if data_name == "theorem-qa":
        return load_dataset("wenhu/TheoremQA", split=split)
    if data_name == "gsm8k":
        return load_dataset(data_name, split=split)
    if data_name == "gsm-hard":
        return load_dataset("reasoning-machines/gsm-hard", split="train")
    if data_name == "svamp":
        # evaluate on training set + test set
        train_set = load_dataset("ChilleD/SVAMP", split="train")
        return concatenate_datasets([train_set, load_dataset("ChilleD/SVAMP", split="test")])
    if data_name == "asdiv":
        dataset = load_dataset("EleutherAI/asdiv", split="validation")
        return dataset.filter(lambda x: ";" not in x['answer'])  # remove multi-answer examples
    if data_name == "mawps":
        examples = []
        # four sub-tasks; a distinct loop name keeps the caller's data_name intact
        for sub_task in ["singleeq", "singleop", "addsub", "multiarith"]:
            sub_examples = list(load_jsonl(f"{data_dir}/mawps/{sub_task}.jsonl"))
            for example in sub_examples:
                example['type'] = sub_task
            examples.extend(sub_examples)
        return Dataset.from_list(examples)
    if data_name == "finqa":
        dataset = load_dataset("dreamerdeo/finqa", split=split, name="main")
        return dataset.select(random.sample(range(len(dataset)), 1000))
    if data_name == "tabmwp":
        with open(f"{data_dir}/tabmwp/tabmwp_{split}.json", "r") as f:
            data_dict = json.load(f)
        dataset = Dataset.from_list(list(data_dict.values()))
        return dataset.select(random.sample(range(len(dataset)), 1000))
    if data_name == "bbh":
        examples = []
        for task in ["reasoning_about_colored_objects", "penguins_in_a_table",
                     "date_understanding", "repeat_copy_logic", "object_counting"]:
            with open(f"{data_dir}/bbh/bbh/{task}.json", "r") as f:
                sub_examples = json.load(f)["examples"]
            for example in sub_examples:
                example['type'] = task
            examples.extend(sub_examples)
        return Dataset.from_list(examples)
    raise NotImplementedError(data_name)
--------------------------------------------------------------------------------
/utils/math_normalization.py:
--------------------------------------------------------------------------------
1 | # Part of the code is modified from the code snippets provided in "Solving Quantitative Reasoning Problems with Language Models" by Lewkowycz et al.
2 | import pdb
3 | import re
4 | import sympy
5 | import threading
6 | from sympy.parsing.latex import parse_latex
7 | from .parser import strip_string
8 |
# Literal (before, after) string replacements applied, in order, by
# normalize_final_answer() before any regex processing.
SUBSTITUTIONS = [
    ('an ', ''), ('a ', ''), ('.$', '$'), ('\\$', ''), (r'\ ', ''), ('\\%', '%'),
    (' ', ''), ('mbox', 'text'), (',\\text{and}', ','),
    ('\\text{and}', ','), ('\\text{m}', '\\text{}')
]
# Substrings deleted outright by normalize_final_answer(): unit words, filler
# nouns and LaTeX decorations (degree marks, spacing commands, empty \text{}).
REMOVED_EXPRESSIONS = [
    'square', 'ways', 'integers', 'dollars', 'mph', 'inches', 'ft',
    'hours', 'km', 'units', '\\ldots', 'sue', 'points', 'feet',
    'minutes', 'digits', 'cents', 'degrees', 'cm', 'gm', 'pounds',
    'meters', 'meals', 'edges', 'students', 'childrentickets', 'multiples',
    '\\text{s}', '\\text{.}', '\\text{\ns}', '\\text{}^2',
    '\\text{}^3', '\\text{\n}', '\\text{}', r'\mathrm{th}',
    r'^\circ', r'^{\circ}', r'\;', r',\!', '{,}', '"', '\\dots'
]
23 |
def is_integer(s):
    """Return True if ``int(s)`` succeeds, otherwise False.

    Non-convertible inputs (e.g. ``"1.5"``, ``None``, ``float('inf')``)
    yield False instead of propagating TypeError / OverflowError.
    """
    try:
        int(s)
    except (TypeError, ValueError, OverflowError):
        return False
    return True
30 |
def normalize_final_answer(final_answer: str) -> str:
    r"""Canonicalize a final answer string so equivalent answers compare equal.

    Steps, in order:
      1. keep only the text after the last '=';
      2. apply SUBSTITUTIONS, then delete every REMOVED_EXPRESSIONS entry;
      3. unwrap $...$, \text{}, \textbf{}, \overline{} and \boxed{};
      4. expand shorthand TeX (\fracab -> \frac{a}{b}, \sqrta -> \sqrt{a});
      5. drop '$', strip thousands separators from pure digit strings, and
         strip a trailing '.0', '.00' or '%' from integers;
      6. lower-case a single choice letter A-G.
    """
    answer = str(final_answer).split('=')[-1]

    for old, new in SUBSTITUTIONS:
        answer = answer.replace(old, new)
    for junk in REMOVED_EXPRESSIONS:
        answer = answer.replace(junk, '')

    # Regex rewrites, applied strictly in this order.
    rewrites = (
        # Extract answer that is in LaTeX math, is bold, is surrounded by a box, etc.
        (r'(.*?)(\$)(.*?)(\$)(.*)', '$\\3$'),
        (r'(\\text\{)(.*?)(\})', '\\2'),
        (r'(\\textbf\{)(.*?)(\})', '\\2'),
        (r'(\\overline\{)(.*?)(\})', '\\2'),
        (r'(\\boxed\{)(.*)(\})', '\\2'),
        # \fracab -> \frac{a}{b};  \fracabc -> \frac{a}{b}c;  \frac{abc}{bef} unchanged
        (r'(frac)([^{])(.)', 'frac{\\2}{\\3}'),
        # \sqrta -> \sqrt{a};  \sqrtab -> \sqrt{a}b
        (r'(sqrt)([^{])', 'sqrt{\\2}'),
    )
    for pattern, replacement in rewrites:
        answer = re.sub(pattern, replacement, answer)
    answer = answer.replace('$', '')

    # 100,000 -> 100000
    if answer.replace(',', '').isdigit():
        answer = answer.replace(',', '')
    # 3.0 -> 3, 3.00 -> 3, 50% -> 50 (checked sequentially on the running value)
    for suffix in ('.0', '.00', '%'):
        if answer.endswith(suffix) and answer[:-len(suffix)].isdigit():
            answer = answer[:-len(suffix)]
    # A -> a
    if answer.lower() in {'a', 'b', 'c', 'd', 'e', 'f', 'g'}:
        answer = answer.lower()
    return answer
75 |
def check_sympy_equivalence(formatted_target_str, formatted_prediction_str):
    """Return True if two LaTeX answer strings denote the same value.

    Both strings are parsed with sympy's ``parse_latex``. If either cannot be
    parsed, the result falls back to exact string equality. Otherwise the
    difference of the two expressions is simplified and compared with 0; any
    sympy error during that step counts as "not equivalent".

    Only ``Exception`` subclasses are caught, so KeyboardInterrupt and
    SystemExit still propagate.
    """
    try:
        target_expr = parse_latex(formatted_target_str)
        prediction_expr = parse_latex(formatted_prediction_str)
    except Exception:
        return formatted_target_str == formatted_prediction_str

    try:
        return sympy.simplify(target_expr - prediction_expr) == 0
    except Exception:
        return False
97 |
98 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # DEER 🦌: Dynamic Early Exit in Reasoning Models
2 | [](https://arxiv.org/abs/2504.15895)
3 | [](https://opensource.org/licenses/MIT)
4 | [](https://www.python.org/)
5 | [](https://huggingface.co/)
6 | [](https://github.com/vllm-project/vllm)
7 |
8 | This is the repository of our paper: [Dynamic Early Exit in Reasoning Models](https://arxiv.org/abs/2504.15895).
9 |
10 |
11 |
12 |
13 | **DEER** monitors model behavior at potential reasoning transition points and dynamically terminates the next reasoning chain’s generation when the model exhibits high confidence in a trial answer. It is consistently effective on 11 cutting-edge reasoning LLMs of varying series and sizes, reducing the length of CoT sequences by an average of **19.1% - 80.1%** while improving accuracy by **0.3% - 5.0%**.
14 |
15 | ---
16 |
17 | ## 🔥 **Latest Updates**
18 | - **[2025/05/20]** Released DEER code for mathematical reasoning tasks (HuggingFace & vLLM).
19 | - **[Coming Soon]** DEER for code generation tasks & Branch-Parallel Decoding Acceleration.
20 |
21 | ---
22 |
23 | ## 🎯 Key Results
24 | Results on 11 reasoning models with 16k token budgets. "Acc" denotes accuracy, "Tok" denotes token count, and "CR" denotes compression rate.
25 |
26 |
27 |
28 | Experimental results presented in bar charts.
29 |
30 |
31 |
32 |
33 |
34 | ---
35 |
36 |
37 | ## 🚀 Quick Start
38 | ### 1. Installation
39 | ```bash
40 | git clone https://github.com/iie-ycx/DEER.git
41 | cd DEER
42 | pip install -r requirements.txt
43 | ```
44 |
45 | ### 2. DEER on vLLM (Recommended)
46 | Considering efficiency, we recommend reproducing the results using the code based on the **vLLM** framework.
47 |
48 | #### For Most Reasoning Models
49 | ```bash
50 | CUDA_VISIBLE_DEVICES=1 python vllm-deer.py \
51 | --model_name_or_path "./DeepSeek-R1-Distill-Qwen-14B" \
52 | --dataset_dir "./data/" \
53 | --output_path "./outputs" \
54 | --dataset "math" \
55 | --threshold 0.95 \
56 | --max_generated_tokens 16000 \
57 | --think_ratio 0.6 \
58 | --batch_size 2000 \
59 | --policy avg1 \
60 | --dtype bfloat16 \
61 | --gpu-memory-utilization 0.9
62 | ```
63 | or run:
64 | ```bash
65 | bash ./bashes/bash-vllm-deer.sh
66 | ```
67 |
68 |
69 | #### For Qwen3 Models
70 |
71 | ```bash
72 | CUDA_VISIBLE_DEVICES=1 python vllm-deer-qwen3.py \
73 | --model_name_or_path "./Qwen3-4B" \
74 | --dataset_dir "./data/" \
75 | --output_path "./outputs" \
76 | --dataset "math" \
77 | --threshold 0.95 \
78 | --max_generated_tokens 16000 \
79 | --think_ratio 0.8 \
80 | --batch_size 2000 \
81 | --dtype bfloat16 \
82 | --policy avg2 \
83 | --gpu-memory-utilization 0.9
84 | ```
85 | or run:
86 | ```bash
87 | bash ./bashes/bash-vllm-deer-qwen3.sh
88 | ```
89 | In our experiments, we found that Qwen3-series models tend to be overconfident when estimating answer confidence, so we made two modifications to the implementation for these models:
90 | - The calculation of answer confidence was changed from arithmetic mean to geometric mean.
91 | - An additional condition must be satisfied for early exit: the model must generate `</think>` after the trial answer.
92 |
93 | ### 3. DEER on Transformers
94 |
95 | For inference using HuggingFace Transformers (without vLLM), run:
96 | ```bash
97 | bash ./bashes/bash-vanilla-deer.sh
98 | ```
99 |
100 |
101 | ## 📊 Evaluation
102 |
103 | DEER currently supports evaluation on 7 reasoning benchmarks. The rule-based evaluation for these benchmarks is based on the code implementation from the project [LIMO](https://github.com/GAIR-NLP/LIMO/tree/main).
104 |
105 |
106 | ```bash
107 | python check.py \
108 | --model_name_or_path "./DeepSeek-R1-Distill-Qwen-14B" \
109 | --data_name "math" \
110 | --generation_path "your_output.jsonl"
111 | ```
112 | or run:
113 | ```bash
114 | bash ./bashes/bash-check-correct.sh
115 | ```
116 |
117 |
118 |
119 | ## 📜 Citation
120 | If you use DEER in your research, please cite our paper:
121 | ```bibtex
122 | @article{yang2025dynamic,
123 | title={Dynamic Early Exit in Reasoning Models},
124 | author={Yang, Chenxu and Si, Qingyi and Duan, Yongjie and Zhu, Zheliang and Zhu, Chenyu and Lin, Zheng and Cao, Li and Wang, Weiping},
125 | journal={arXiv preprint arXiv:2504.15895},
126 | year={2025}
127 | }
128 | ```
129 | ## 💬 Community
130 |
131 | Join our WeChat group for discussions:
132 |
133 |
--------------------------------------------------------------------------------
/check.py:
--------------------------------------------------------------------------------
1 | import json
2 | from transformers import AutoTokenizer
3 |
4 | import re
5 | import importlib.util
6 | import os
7 | import argparse
8 |
9 | import random
10 | import time
11 | from datetime import datetime
12 | from tqdm import tqdm
13 | from utils.utils import set_seed, load_jsonl, save_jsonl, construct_prompt
14 | from utils.parser import *
15 | from utils.data_loader import load_data
16 | from utils.math_normalization import *
17 | from utils.grader import *
18 | import pickle
19 | from math import comb
20 | import pdb
21 |
22 |
def parse_list(arg):
    """argparse ``type=`` helper: split a comma-separated string into a list of str."""
    separator = ','
    pieces = arg.split(separator)
    return pieces
25 |
def save_completions(completions, filepath):
    """Pickle ``completions`` to ``filepath``, overwriting any existing file."""
    with open(filepath, mode='wb') as handle:
        pickle.dump(completions, handle)
29 |
def parse_args():
    """Parse command-line options for the correctness checker.

    Returns:
        argparse.Namespace with model_name_or_path, n_sampling, k, data_dir,
        data_name, split, generation_path and prompt_type.
    """
    parser = argparse.ArgumentParser()
    add = parser.add_argument

    # model / sampling
    add('--model_name_or_path', type=str, default="./", help="model dir")
    add('--n_sampling', type=int, default=1, help="n for sampling")
    add("--k", type=int, default=1, help="Value of k for pass@k calculation")

    # data
    add("--data_dir", type=str, default="./data")
    add('--data_name', type=str, default="math", help='identify how to extract answer')
    add("--split", type=str, default="test")
    add("--generation_path", type=str, default="test")

    # prompting
    add("--prompt_type", type=str, default="qwen-base")

    return parser.parse_args()
47 |
def get_conversation_prompt_by_messages(tokenizer, messages):
    """Render chat ``messages`` to a prompt string with the tokenizer's chat template.

    The template is applied untokenized and with the generation prompt
    appended, so the returned text ends where the assistant turn begins.
    """
    options = {"tokenize": False, "add_generation_prompt": True}
    return tokenizer.apply_chat_template(messages, **options)
55 |
def get_three_prompt(prompt_type, data_name):
    """Load the prompt pieces defined in ``./prompts/{prompt_type}/{data_name}.py``.

    The file is executed as a module, which must define ``system_prompt``,
    ``few_shot_prompt`` and ``question_format``.

    Returns:
        Tuple ``(system_prompt, few_shot_prompt, question_format)``.

    Raises:
        FileNotFoundError: if the prompt file does not exist.
        AttributeError: if one of the three names is missing from it.
    """
    file_path = os.path.join(".", "prompts", prompt_type, f"{data_name}.py")
    if not os.path.exists(file_path):
        raise FileNotFoundError(f"File not found: {file_path}")

    spec = importlib.util.spec_from_file_location("dynamic_module", file_path)
    module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(module)

    pieces = []
    for attr_name in ('system_prompt', 'few_shot_prompt', 'question_format'):
        if not hasattr(module, attr_name):
            raise AttributeError(f"'{attr_name}' not found in {file_path}")
        pieces.append(getattr(module, attr_name))
    return tuple(pieces)
81 |
def read_jsonl(file_path):
    """Read a JSON Lines file into a list of parsed objects.

    Blank or whitespace-only lines are skipped instead of raising
    ``json.JSONDecodeError``.

    Raises:
        json.JSONDecodeError: if a non-blank line is not valid JSON.
    """
    data = []
    with open(file_path, 'r', encoding='utf-8') as f:
        for line in f:
            line = line.strip()
            if not line:
                continue
            data.append(json.loads(line))
    return data
91 |
92 |
93 |
def infer(args):
    """Grade a generation file against ground truth and print summary metrics.

    Reads ``args.generation_path`` (JSON Lines; each record has a
    ``generated_responses`` list) and aligns record ``i`` with example ``i``
    from ``load_data(args.data_name, args.split, args.data_dir)``. Prints the
    accuracy, pass@k (only when records hold more than one response), and the
    mean whitespace-word and tokenizer-token length of each first response.

    Each record is annotated in place with ``generated_answers``,
    ``gold_answer``, ``is_correct`` and ``answers_correctness``.

    Raises:
        ValueError: if the generation file has more records than the dataset.
    """
    # Loaded here, not at import time: ``args`` only exists after
    # parse_args() has run, so a module-level load raised NameError.
    tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path)
    examples = load_data(args.data_name, args.split, args.data_dir)
    file_outputs = read_jsonl(args.generation_path)

    print("llm generate done")
    print(len(file_outputs))

    if not file_outputs:
        print("No generations found; nothing to evaluate.")
        return
    if len(file_outputs) > len(examples):
        raise ValueError(
            f"{len(file_outputs)} generations but only {len(examples)} examples "
            f"in {args.data_name}/{args.split}"
        )

    k = args.k
    correct_cnt, pass_at_k_list = _grade_outputs(file_outputs, examples, args.data_name, k)
    total = len(file_outputs)

    print(f"correct cnt / total cnt: {correct_cnt}/{total}")
    print(f"Acc: {correct_cnt / total:.4f}")

    if pass_at_k_list:
        average_pass_at_k = sum(pass_at_k_list) / len(pass_at_k_list)
        print(f"Pass@{k}: {sum(pass_at_k_list)}/{len(pass_at_k_list)} = {average_pass_at_k:.4f}")
    else:
        print(f"Pass@1: {correct_cnt}/{total} = {correct_cnt / total:.4f}")

    _report_lengths(file_outputs, tokenizer)


def _pass_at_k(n, c, k):
    """Unbiased pass@k estimate given ``c`` correct answers among ``n`` samples."""
    if c == 0:
        return 0
    if n - c < k:
        return 1
    return 1 - (comb(n - c, k) / comb(n, k))


def _grade_outputs(file_outputs, examples, data_name, k):
    """Annotate each record with grading fields; return (correct count, pass@k list)."""
    correct_cnt = 0
    pass_at_k_list = []
    for record, example in tqdm(zip(file_outputs, examples), "check correct...", total=len(file_outputs)):
        _, gt_ans = parse_ground_truth(example, data_name)
        generated_answers = [extract_answer(response, data_name) for response in record['generated_responses']]
        is_correct_list = [check_is_correct(answer, gt_ans) for answer in generated_answers]
        is_correct = any(is_correct_list)
        if is_correct:
            correct_cnt += 1

        record['generated_answers'] = generated_answers
        record['gold_answer'] = gt_ans
        record['is_correct'] = is_correct
        record['answers_correctness'] = is_correct_list

        # pass@k is only meaningful when several responses were sampled
        if len(is_correct_list) > 1:
            pass_at_k_list.append(_pass_at_k(len(generated_answers), sum(is_correct_list), k))
    return correct_cnt, pass_at_k_list


def _report_lengths(file_outputs, tokenizer):
    """Print mean word count and token count of each record's first response."""
    first_responses = [record['generated_responses'][0] for record in file_outputs]
    test_num = len(first_responses)
    avg_response_length = sum(len(response.split()) for response in first_responses) / test_num
    avg_token_num = sum(len(tokenizer(response)['input_ids']) for response in first_responses) / test_num

    print("length:", avg_response_length)
    print('token_num:', avg_token_num)
167 |
168 |
if __name__ == "__main__":
    # Script entry point: parse CLI options, then grade the generation file.
    args = parse_args()
    infer(args)
--------------------------------------------------------------------------------
/utils/utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | import json
3 | import random
4 | import json
5 | import os
6 | import numpy as np
7 | from pathlib import Path
8 | from typing import Iterable, Union, Any
9 |
10 | from utils.examples import get_examples
11 |
12 |
def set_seed(seed: int = 42) -> None:
    """Seed NumPy's and Python's global RNGs and export ``PYTHONHASHSEED``.

    Note: setting ``PYTHONHASHSEED`` at runtime does not change hashing in
    the current interpreter; it only affects child processes started later.
    """
    for seeder in (np.random.seed, random.seed):
        seeder(seed)
    os.environ["PYTHONHASHSEED"] = str(seed)
    print(f"Random seed set as {seed}")
18 |
19 |
def load_jsonl(file: Union[str, Path]) -> Iterable[Any]:
    """Lazily yield one parsed JSON object per non-blank line of ``file``.

    Blank / whitespace-only lines are skipped. The ``yield`` sits outside the
    ``try`` so that closing the generator early (GeneratorExit) is not
    mistaken for a parse error.

    Raises:
        ValueError: if a line is not valid JSON; the message names the file
            and the 1-based line number, and the JSONDecodeError is chained.
    """
    with open(file, "r", encoding="utf-8") as f:
        for line_no, line in enumerate(f, start=1):
            if not line.strip():
                continue
            try:
                record = json.loads(line)
            except json.JSONDecodeError as err:
                print("Error in loading:", line)
                raise ValueError(f"Invalid JSON in {file} at line {line_no}") from err
            yield record
28 |
29 |
def save_jsonl(samples, save_path):
    """Write *samples* to *save_path* as JSON Lines (one object per line, UTF-8).

    Parent directories are created as needed; a bare filename with no
    directory part is written to the current working directory.
    """
    folder = os.path.dirname(save_path)
    # os.path.dirname("out.jsonl") == "" and os.makedirs("") raises
    # FileNotFoundError, so only create a directory when there is one.
    if folder:
        os.makedirs(folder, exist_ok=True)

    with open(save_path, "w", encoding="utf-8") as f:
        for sample in samples:
            f.write(json.dumps(sample, ensure_ascii=False) + "\n")
    print("Saved to", save_path)
39 |
40 |
def lower_keys(example):
    """Return a new dict with the same values as *example* and lower-cased keys.

    Key order follows *example*; if two keys collide after lower-casing, the
    later value wins.
    """
    return {key.lower(): value for key, value in example.items()}
50 |
51 |
# Few-shot demonstrations, loaded once at import time. load_prompt() indexes
# this by dataset name and slices the first `num_shots` entries; each entry is
# unpacked as a (question, answer) pair by construct_prompt().
EXAMPLES = get_examples()
53 |
54 |
# Dataset names that reuse the few-shot demonstrations of a canonical dataset.
_FEW_SHOT_DATA_ALIASES = {
    **dict.fromkeys(["gsm_hard", "svamp", "tabmwp", "asdiv", "mawps"], "gsm8k"),
    **dict.fromkeys(
        ["math_oai", "hungarian_exam", "math-oai", "aime24", "amc23"], "math"
    ),
    **dict.fromkeys(["sat_math"], "mmlu_stem"),
    **dict.fromkeys(
        [
            "gaokao2024_I",
            "gaokao2024_II",
            "gaokao_math_qa",
            "gaokao2024_mix",
            "cn_middle_school",
        ],
        "gaokao",
    ),
}


def load_prompt(data_name, prompt_type, num_shots):
    """Return the first *num_shots* few-shot demos for *data_name*.

    Args:
        data_name: Dataset name; aliases in _FEW_SHOT_DATA_ALIASES are mapped
            to the dataset whose demonstrations they share.
        prompt_type: Accepted for API compatibility; it does not affect which
            demos are returned.
        num_shots: Number of demos to return; a falsy value returns [].

    Raises:
        KeyError: if the (mapped) dataset has no entry in EXAMPLES.
    """
    if not num_shots:
        return []
    data_name = _FEW_SHOT_DATA_ALIASES.get(data_name, data_name)
    return EXAMPLES[data_name][:num_shots]
78 |
79 |
# Prompt formats keyed by prompt type. Every value is a 3-tuple
# (input_template, output_template, splitter): construct_prompt() fills the
# "{input}" / "{output}" placeholders with str.format (hence literal braces
# are doubled, e.g. "\\boxed{{}}") and joins few-shot demos with the splitter.
PROMPT_TEMPLATES = {
    "direct": ("Question: {input}\nAnswer: ", "{output}", "\n\n"),
    "cot": ("Question: {input}\nAnswer: ", "{output}", "\n\n\n"),
    "pal": ("Question: {input}\n\n", "{output}", "\n---\n"),
    "tool-integrated": ("Question: {input}\n\nSolution:\n", "{output}", "\n---\n"),
    "self-instruct": ("<|user|>\n{input}\n<|assistant|>\n", "{output}", "\n"),
    "tora": ("<|user|>\n{input}\n<|assistant|>\n", "{output}", "\n"),
    "wizard_zs": (
        "### Instruction:\n{input}\n\n### Response: Let's think step by step.",
        "{output}",
        "\n\n\n",
    ),
    "platypus_fs": (
        "### Instruction:\n{input}\n\n### Response:\n",
        "{output}",
        "\n\n\n",
    ),
    "deepseek-math": (
        "User: {input}\nPlease reason step by step, "
        "and put your final answer within \\boxed{{}}.\n\nAssistant:",
        "{output}",
        "\n\n\n",
    ),
    "kpmath": (
        "User: Please reason step by step and put your final answer at the end "
        'with "The answer is: ".\n\n{input}\n\nAssistant:',
        "{output}",
        # Splitter is required: construct_prompt() reads element [2]. Uses the
        # same separator as the similar User/Assistant "deepseek-math" format.
        "\n\n\n",
    ),
    "jiuzhang": (
        "## Question\n{input}\n\n## Solution\n",
        "{output}",
        "\n\n\n",
    ),
    "jiuzhang_tora": (
        "## Question\n{input}\n\n## Code Solution\n",
        "{output}",
        "\n\n\n",
    ),
    "jiuzhang_nl": (
        "## Question\n{input}\n\n## Natural Language Solution\n",
        "{output}",
        "\n\n\n",
    ),
    "mmiqc": (
        'Please solve the following problem and put your answer at the end with "The answer is: ".\n\n{input}\n\n',
        "{output}",
        "\n\n\n",
    ),
    "abel": (
        "Question:\n{input}\nAnswer:\nLet's think step by step.\n",
        "{output}",
        "\n\n",
    ),
    "shepherd": ("{input}\n", "{output}", "\n\n\n"),
    "qwen-boxed": (
        "<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n"
        "<|im_start|>user\n{input}\nPlease reason step by step, and put your final answer within \\boxed{{}}.<|im_end|>\n"
        "<|im_start|>assistant\n",
        "{output}",
        "\n\n",
    ),
    "qwen25-math-cot": (
        "<|im_start|>system\nPlease reason step by step, and put your final answer within \\boxed{{}}.<|im_end|>\n"
        "<|im_start|>user\n{input}<|im_end|>\n"
        "<|im_start|>assistant\n",
        "{output}",
        "\n\n",
    ),
    "mathstral": (
        "{input}\nPlease reason step by step, and put your final answer within \\boxed{{}}.",
        "{output}",
        "\n\n",
    ),
    "internlm-math-fs": ("Question:{input}\nAnswer:", "{output}", "\n"),
    "internlm-math-chat": (
        "<|im_start|>user\n{input}<|im_end|>\n" "<|im_start|>assistant\n",
        "{output}",
        "\n\n",
    ),
    "mistral": (
        "[INST] {input}[/INST]",
        "{output}",
        "\n\n",
    ),
    "numina": ("### Problem: {input}\n### Solution:", " {output}", "\n\n"),
}
166 |
167 |
def construct_prompt(example, data_name, args):
    """Assemble the full prompt string for one example.

    Args:
        example: Dict with a "question" key. "gt_ans" is also read when
            ``args.adapt_few_shot`` is set and few-shot demos are present.
        data_name: Dataset name, used to select few-shot demonstrations.
        args: Namespace providing ``prompt_type``, ``num_shots`` and
            ``adapt_few_shot``; ``prompt_type`` must be a PROMPT_TEMPLATES key.

    Returns:
        The prompt with leading/trailing spaces (not newlines) stripped.
    """
    if args.adapt_few_shot and data_name in [
        "gaokao2024_I",
        "gaokao2024_II",
        "gaokao_math_qa",
        "gaokao2024_mix",
        "cn_middle_school",
    ]:
        demos = load_prompt(data_name, args.prompt_type, 5)
    else:
        demos = load_prompt(data_name, args.prompt_type, args.num_shots)
    # Normalized prompt type; only consulted below to decide whether to prepend
    # the ToRA instructions. The template itself is looked up by the raw
    # args.prompt_type.
    prompt_type = args.prompt_type
    if prompt_type == "platypus_fs":
        prompt_type = "cot"
    if prompt_type == "tool-integrated":
        prompt_type = "tora"

    prompt_temp = PROMPT_TEMPLATES[args.prompt_type]
    input_template, output_template, splitter = (
        prompt_temp[0],
        prompt_temp[1],
        prompt_temp[2],
    )
    if args.prompt_type == "qwen25-math-cot":
        # Hotfix to support putting all demos into a single turn
        demo_prompt = splitter.join([q + "\n" + a for q, a in demos])
    else:
        demo_prompt = splitter.join(
            [
                input_template.format(input=q) + output_template.format(output=a)
                for q, a in demos
            ]
        )
    context = input_template.format(input=example["question"])
    if len(demo_prompt) == 0 or (
        args.adapt_few_shot and example["gt_ans"] not in ["A", "B", "C", "D", "E"]
    ):
        full_prompt = context
    else:
        if args.prompt_type == "qwen25-math-cot":
            # Hotfix to support putting all demos into a single turn
            full_prompt = demo_prompt + splitter + example["question"]
            full_prompt = input_template.format(input=full_prompt)
        else:
            full_prompt = demo_prompt + splitter + context

    if args.prompt_type == "platypus_fs":
        full_prompt_temp = (
            "Below is an instruction that describes a task. "
            "Write a response that appropriately completes the request.\n\n"
            "### Instruction:\n{instruction}\n\n### Response:\n"
        )
        full_prompt = full_prompt_temp.format(instruction=full_prompt)

    if prompt_type == "tora":
        # Raw string on purpose: in a normal literal the "\b" of "\boxed" is a
        # backspace character (\x08), which corrupts the instruction text.
        full_prompt = (
            r"""Integrate step-by-step reasoning and Python code to solve math problems using the following guidelines:

- Analyze the question and write functions to solve the problem; the function should not take any arguments.
- Present the final result in LaTeX using a `\boxed{}` without any units.
- Utilize the `pi` symbol and `Rational`` from Sympy for $\pi$ and fractions, and simplify all fractions and square roots without converting them to decimal values.

Here are some examples you may refer to:

---

"""
            + full_prompt
        )

    return full_prompt.strip(" ")  # important!
241 |
242 |
# Display labels used by show_sample() when printing these sample fields.
key_map = dict(
    gt="Ground Truth",
    pred="Prediction",
    gt_cot="Reference CoT",
    score="Score",
)
249 |
250 |
def show_sample(sample, print_all_preds=False):
    """Pretty-print one evaluation sample to stdout.

    Args:
        sample: Dict with at least a "question" key. Optional keys that are
            printed when present: "idx", "type", "level", "dataset", "code"
            (list of solutions) with "report" (execution reports, parallel to
            "code"), "pred" (list of predictions), "gt", "score", "unit",
            "gt_cot".
        print_all_preds: If True, print every solution in sample["code"]
            together with its own execution report; otherwise only the first.
    """
    print("==" * 20)
    for key in ["idx", "type", "level", "dataset"]:
        if key in sample:
            # capitalize
            print("{}: {}".format(key[0].upper() + key[1:], sample[key]))
    print("Question:", repr(sample["question"]))
    if "code" in sample:
        if print_all_preds:
            reports = sample.get("report", [])
            for idx, code in enumerate(sample["code"]):
                print("-" * 20)
                print("code:", code)
                # Print this solution's own report, not the whole report list.
                print("Execution:", reports[idx] if idx < len(reports) else None)
        else:
            print("Solution:\n", sample["code"][0])
            print("Execution:", sample["report"][0])
    if "pred" in sample:
        print("Prediction:", repr(sample["pred"][0]))
    for key in ["gt", "score", "unit", "gt_cot"]:
        if key in sample:
            _key = key_map.get(key, key)
            print("{}: {}".format(_key, repr(sample[key])))
    print()
274 |
--------------------------------------------------------------------------------
/vanilla_deer.py:
--------------------------------------------------------------------------------
1 |
2 | import warnings
3 | warnings.filterwarnings("ignore") # Ignore all warnings
4 | from transformers import AutoModelForCausalLM, AutoTokenizer
5 | import pdb
6 | import torch
7 | import torch.nn.functional as F
8 | import json
9 | import re
10 | from tqdm import tqdm
11 | import argparse
12 | import os
13 | from transformers.cache_utils import DynamicCache
14 | from copy import deepcopy
15 |
def append_jsonl(data, file_path):
    """Append every dict in *data* to *file_path* as one JSON line each.

    Parameters:
    data (list): Python dictionaries to serialize (non-ASCII kept as-is).
    file_path (str): Target .jsonl file; created if it does not exist.
    """
    with open(file_path, 'a', encoding='utf-8') as sink:
        sink.writelines(
            json.dumps(record, ensure_ascii=False) + '\n' for record in data
        )
28 |
def write_jsonl(data, file_path):
    """Overwrite *file_path* with the dicts in *data*, one JSON object per line.

    Parameters:
    data (list): Python dictionaries to serialize (non-ASCII kept as-is).
    file_path (str): Output .jsonl file; any existing content is replaced.
    """
    with open(file_path, 'w', encoding='utf-8') as sink:
        # One JSON document per line, newline-terminated.
        sink.writelines(
            json.dumps(record, ensure_ascii=False) + '\n' for record in data
        )
42 |
def read_jsonl(file_path):
    """
    Read a .jsonl file and return a list of dictionaries.

    Blank (whitespace-only) lines, e.g. a trailing empty line, are skipped
    instead of failing with a JSONDecodeError.

    Parameters:
    file_path (str): Path to .jsonl file.

    Returns:
    data (list): List containing all JSON objects, in file order.
    """
    data = []
    with open(file_path, 'r', encoding='utf-8') as f:
        for line in f:
            line = line.strip()
            if not line:
                continue
            # Parse JSON object from each line
            data.append(json.loads(line))
    return data
60 |
def calcu_max_probs_w_kv(model, pred_input_ids, kv_cache, tokenizer, method=1):
    """Score how confidently the model completes a trial answer.

    Greedily (argmax) decodes from `pred_input_ids` on top of a deep copy of
    `kv_cache`. Decoding stops once the argmax token is one of the
    closing-brace stop ids built below, or after 21 forward passes. The
    per-step max probabilities are then aggregated:

    * method == 0: product of the max probabilities from the second step
      onward, including the final (stopping) token.
    * otherwise: mean of the max probabilities from the second step up to,
      but excluding, the final step.

    The first step's probability (the token right after `pred_input_ids`)
    is never counted.

    Returns:
        The aggregated score as a Python number (via `.item()`).

    NOTE(review): edge cases when decoding stops after one or two steps.
    With method == 0 and a stop on the first step, `total_prob_max` is still
    the Python float 1.0, so `.item()` raises AttributeError. With
    method != 0, the divisor `total_steps - 2` is -1 or 0 (0 gives nan).
    Confirm whether these cases can occur in practice.
    """
    # Token ids of several closing-brace variants, flattened into one list.
    # NOTE(review): unlike the other tokenizer calls in this file, this one
    # keeps default special-token handling; if the tokenizer prepends a BOS id,
    # that id also becomes a stop id — confirm this is intended.
    list1 = tokenizer([' }', '}', '}.', '}.\n', '}\\', '}}', ')}', ')}.', ')}\n', ''])['input_ids']
    stop_ids = sum(list1, [])
    total_steps = 0
    if method == 0:
        total_prob_max = 1.0  # running product
    else:
        total_prob_max = 0.0  # running sum

    pred_tokens = []  # decoded argmax tokens (collected but not returned)
    last_token = -1  # sentinel: no token decoded yet

    # Probe on a copy so the caller's kv_cache is not advanced by these passes.
    backup_cache = deepcopy(kv_cache)

    with torch.no_grad():
        while last_token not in stop_ids:

            # First pass feeds the whole trial-answer prefix; later passes feed
            # only the previously chosen token.
            if last_token == -1:
                output_dicts = model(input_ids=pred_input_ids, past_key_values=backup_cache)
            else:
                output_dicts = model(input_ids=torch.tensor([last_token]).unsqueeze(0).to(pred_input_ids.device), past_key_values=backup_cache)
            # Logits at the last position of batch row 0 (assumes batch size 1 — TODO confirm).
            logits = output_dicts['logits'][0][-1]
            past_key_values = output_dicts['past_key_values']
            probs = F.softmax(logits, dim=-1)

            max_value, max_index = torch.max(probs, dim=0)


            # The first step's probability is deliberately not accumulated.
            if last_token == -1:
                total_prob_max = total_prob_max
            else:

                if method == 0:
                    total_prob_max *= max_value
                else:
                    total_prob_max += max_value

            pred_tokens.append(max_index)
            last_token = max_index
            total_steps += 1
            if total_steps > 20:
                break


    # Turn the sum into a mean, dropping the final step's probability:
    # total_steps - 1 values were accumulated, so total_steps - 2 remain.
    if method != 0:
        total_prob_max = (total_prob_max - max_value) / (total_steps - 2)

    del backup_cache, past_key_values
    torch.cuda.empty_cache()

    return total_prob_max.item()
112 |
def _str2bool(value):
    """Parse a command-line boolean such as "True", "false", "1" or "no"."""
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ("true", "1", "yes", "y", "t", "on"):
        return True
    if lowered in ("false", "0", "no", "n", "f", "off"):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


def parse_args():
    """Parse command-line options for the vanilla DEER early-exit run."""
    parser = argparse.ArgumentParser()
    parser.add_argument('--model_name_or_path', type=str, default="/mnt/data1/ycx/DeepSeek-R1-Distill-Qwen-32B", help="model directory")
    parser.add_argument('--threshold', type=float, default=0.95)
    parser.add_argument('--max_len', type=int, default=16384)
    parser.add_argument('--dataset', type=str, default='math')
    parser.add_argument('--output_path', type=str, default='./outputs')
    # `type=bool` turned any non-empty string (including "False") into True.
    # Parse the text explicitly; a bare `--log` also enables logging.
    parser.add_argument('--log', type=_str2bool, nargs='?', const=True, default=False)
    parser.add_argument('--think_ratio', type=float, default=0.9)

    args = parser.parse_args()
    return args
125 |
args = parse_args()

# Split the token budget: `think_ratio` of max_len for reasoning, the rest
# for the final answer.
think_len = int(args.max_len * args.think_ratio)
answer_len = args.max_len - think_len

# Model and tokenizer are loaded from the same path.
model_name = args.model_name_or_path
model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype="bfloat16", device_map="auto")
tokenizer = AutoTokenizer.from_pretrained(model_name)

# Evaluation questions and the directory that receives the results.
sys_prompts = ['Please reason step by step, and put your final answer within \\boxed{}.']
questions_json = read_jsonl(f'./data/{args.dataset}/test.jsonl')

os.makedirs(f'{args.output_path}/{model_name}/{args.dataset}', exist_ok=True)
output_list = []
146 |
for i in tqdm(range(len(questions_json))):
    sys_prompt = sys_prompts[0]
    prompt = questions_json[i]['problem']

    messages = [
        {"role": "system", "content": sys_prompt},
        {"role": "user", "content": prompt}
    ]

    text = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True
    )

    model_inputs = tokenizer([text], return_tensors="pt").to(model.device)
    input_ids = model_inputs["input_ids"]
    input_length = len(input_ids[0])  # prompt length; excluded from the saved response

    # A chunk whose final argmax token is one of these ends the thinking phase.
    # NOTE(review): tokenizer("") presumably yields no ids, so only "**" and EOS
    # are effective here; an end-of-thinking marker may have been intended — confirm.
    last_token_ids = tokenizer("**", add_special_tokens=False)["input_ids"] + tokenizer("", add_special_tokens=False)["input_ids"] + [tokenizer.eos_token_id]
    continue_ids = tokenizer("Wait", add_special_tokens=False)["input_ids"]
    stop_ids = continue_ids + last_token_ids

    # Trial-answer prefix used to probe the model's confidence after each chunk.
    answer_prompt_ids = tokenizer("\n**Final Answer**\n\nThe final answer is \\boxed", add_special_tokens=False)["input_ids"]

    past_key_values = DynamicCache()
    too_long = False
    while True:
        # Greedily generate the next reasoning chunk, stopping at "Wait", "**" or EOS.
        generated_dicts = model.generate(
            input_ids,
            max_new_tokens=think_len - len(input_ids[0]),
            do_sample=False,
            eos_token_id=stop_ids,
            return_dict_in_generate=True,
            output_logits=True,
            tokenizer=tokenizer,
            past_key_values=past_key_values,
        )

        # Newly generated tokens only, without the final (stop) token.
        generated_ids = [
            output_ids[len(row_ids):-1] for row_ids, output_ids in zip(input_ids, generated_dicts['sequences'])
        ]

        # Did the chunk end on an end-of-thinking token rather than "Wait"?
        logits = generated_dicts['logits'][-1]
        probs = F.softmax(logits, dim=-1)[0]
        _, max_index = torch.max(probs, dim=0)
        real_stop = max_index in last_token_ids

        pred_input_ids = torch.cat((input_ids, generated_ids[0].unsqueeze(0)), dim=1)

        if len(pred_input_ids[0]) >= think_len - 100:
            too_long = True

        pred_prob = calcu_max_probs_w_kv(model, torch.tensor(answer_prompt_ids).to(generated_ids[0].device).unsqueeze(0), past_key_values, tokenizer, 1)

        torch.cuda.empty_cache()

        if pred_prob > args.threshold or real_stop or too_long:
            # Exit thinking: append the separator and decode the final answer.
            # NOTE(review): this tokenizer call keeps default special-token
            # handling, unlike the calls above; if the tokenizer prepends BOS it
            # lands mid-sequence — confirm.
            input_ids = torch.cat((pred_input_ids, torch.tensor(tokenizer('\n\n\n')['input_ids']).to(generated_ids[0].device).unsqueeze(0)), dim=1)

            generated_dicts = model.generate(
                input_ids,
                max_new_tokens=answer_len,
                do_sample=False,
                return_dict_in_generate=True,
                past_key_values=past_key_values,
            )

            generated_ids = [
                output_ids[len(row_ids):-1] for row_ids, output_ids in zip(input_ids, generated_dicts['sequences'])
            ]
            final_output_ids = torch.cat((input_ids[0], generated_ids[0]), dim=-1)
            response = tokenizer.batch_decode([final_output_ids[input_length:]], skip_special_tokens=True)[0]

            if args.log:
                # The old path concatenated output_path + "./outputs/log...",
                # pointing into a non-existent "<output_path>./outputs" directory.
                log_file_path = f"{args.output_path}/log{args.threshold}.txt"
                with open(log_file_path, "a", encoding="utf-8") as file:
                    file.write(response + "\n")

            break

        # Not confident yet: append the chunk plus "Wait" and keep thinking.
        tmp = torch.cat((generated_ids[0], torch.tensor(continue_ids).to(generated_ids[0].device)), dim=0)
        input_ids = torch.cat((input_ids, tmp.unsqueeze(0)), dim=1)
        torch.cuda.empty_cache()

    output_dict = {
        'question': questions_json[i]['problem'],
        'generated_responses': [response],
        'gold_answer': questions_json[i]['answer'],
    }

    # Write into the `{output_path}/{model_name}/{dataset}` directory created
    # before the loop. The '/' after output_path is needed for relative model names.
    output_file = f'{args.output_path}/{model_name}/{args.dataset}/greedy_p{args.threshold}_len{args.max_len}.jsonl'
    append_jsonl([output_dict], output_file)
--------------------------------------------------------------------------------
/data/aime25/test.jsonl:
--------------------------------------------------------------------------------
1 | {"problem": "Six points $ A, B, C, D, E, $ and $ F $ lie in a straight line in that order. Suppose that $ G $ is a point not on the line and that $ AC = 26 $, $ BD = 22 $, $ CE = 31 $, $ DF = 33 $, $ AF = 73 $, $ CG = 40 $, and $ DG = 30 $. Find the area of $ \\triangle BGE $.", "answer": "468"}
2 | {"problem": "Find the sum of all positive integers $ n $ such that $ n + 2 $ divides the product $ 3(n + 3)(n^2 + 9) $.", "answer": "49"}
3 | {"problem": "Four unit squares form a $2 \\times 2$ grid. Each of the 12 unit line segments forming the sides of the squares is colored either red or blue in such a way that each unit square has 2 red sides and 2 blue sides. Find the number of such colorings.", "answer": "82"}
4 | {"problem": "The product $ \\prod_{k=4}^{63} \\frac{\\log_k(5^{k^2-1})}{\\log_{k+1}(5^{k^2-4})} = \\frac{\\log_4(5^{15})}{\\log_5(5^{12})} \\cdot \\frac{\\log_5(5^{24})}{\\log_6(5^{21})} \\cdot \\frac{\\log_6(5^{35})}{\\log_7(5^{32})} \\cdots \\frac{\\log_{63}(5^{3968})}{\\log_{64}(5^{3965})} $ is equal to $ \\frac{m}{n} $, where $ m $ and $ n $ are relatively prime positive integers. Find $ m + n $.", "answer": "106"}
5 | {"problem": "Suppose $ \\triangle ABC $ has angles $ \\angle BAC = 84^\\circ $, $ \\angle ABC = 60^\\circ $, and $ \\angle ACB = 36^\\circ $. Let $ D, E, $ and $ F $ be the midpoints of sides $ \\overline{BC} $, $ \\overline{AC} $, and $ \\overline{AB} $, respectively. The circumcircle of $ \\triangle DEF $ intersects $ \\overline{BD} $, $ \\overline{AE} $, and $ \\overline{AF} $ at points $ G, H, $ and $ J $, respectively. The points $ G, D, E, H, J, $ and $ F $ divide the circumcircle of $ \\triangle DEF $ into six minor arcs, as shown. Find $ \\widehat{DE} + 2 \\cdot \\widehat{HJ} + 3 \\cdot \\widehat{FG} $, where the arcs are measured in degrees.", "answer": "336^\\circ"}
6 | {"problem": "Circle $\\omega_1$ with radius 6 centered at point $A$ is internally tangent at point $B$ to circle $\\omega_2$ with radius 15. Points $C$ and $D$ lie on $\\omega_2$ such that $\\overline{BC}$ is a diameter of $\\omega_2$ and $\\overline{BC} \\perp \\overline{AD}$. The rectangle $EFGH$ is inscribed in $\\omega_1$ such that $\\overline{EF} \\perp \\overline{BC}$, $C$ is closer to $\\overline{GH}$ than to $\\overline{EF}$, and $D$ is closer to $\\overline{FG}$ than to $\\overline{EH}$, as shown. Triangles $\\triangle DGF$ and $\\triangle CHG$ have equal areas. The area of rectangle $EFGH$ is $\\frac{m}{n}$, where $m$ and $n$ are relatively prime positive integers. Find $m + n$.", "answer": "293"}
7 | {"problem": "Let $ A $ be the set of positive integer divisors of 2025. Let $ B $ be a randomly selected subset of $ A $. The probability that $ B $ is a nonempty set with the property that the least common multiple of its elements is 2025 is $ \\frac{m}{n} $, where $ m $ and $ n $ are relatively prime positive integers. Find $ m + n $.", "answer": "237"}
8 | {"problem": "From an unlimited supply of 1-cent coins, 10-cent coins, and 25-cent coins, Silas wants to find a collection of coins that has a total value of $ N $ cents, where $ N $ is a positive integer. He uses the so-called **greedy algorithm**, successively choosing the coin of greatest value that does not cause the value of his collection to exceed $ N $. For example, to get 42 cents, Silas will choose a 25-cent coin, then a 10-cent coin, then 7 1-cent coins. However, this collection of 9 coins uses more coins than necessary to get a total of 42 cents; indeed, choosing 4 10-cent coins and 2 1-cent coins achieves the same total value with only 6 coins.\n\nIn general, the greedy algorithm succeeds for a given $ N $ if no other collection of 1-cent, 10-cent, and 25-cent coins gives a total value of $ N $ cents using strictly fewer coins than the collection given by the greedy algorithm. Find the number of values of $ N $ between 1 and 1000 inclusive for which the greedy algorithm succeeds.", "answer": "610"}
9 | {"problem": "There are $ n $ values of $ x $ in the interval $ 0 < x < 2\\pi $ where $ f(x) = \\sin(7\\pi \\cdot \\sin(5x)) = 0 $. For $ t $ of these $ n $ values of $ x $, the graph of $ y = f(x) $ is tangent to the $ x $-axis. Find $ n + t $.", "answer": "149"}
10 | {"problem": "Sixteen chairs are arranged in a row. Eight people each select a chair in which to sit so that no person sits next to two other people. Let $ N $ be the number of subsets of 16 chairs that could be selected. Find the remainder when $ N $ is divided by 1000.", "answer": "907"}
11 | {"problem": "Let $ S $ be the set of vertices of a regular 24-gon. Find the number of ways to draw 12 segments of equal lengths so that each vertex in $ S $ is an endpoint of exactly one of the 12 segments.", "answer": "113"}
12 | {"problem": "Let $ A_1A_2 \\ldots A_{11} $ be an 11-sided non-convex simple polygon with the following properties:\n* The area of $ A_iA_1A_{i+1} $ is 1 for each $ 2 \\leq i \\leq 10 $,\n* $ \\cos(\\angle A_iA_1A_{i+1}) = \\frac{12}{13} $ for each $ 2 \\leq i \\leq 10 $,\n* The perimeter of $ A_1A_2 \\ldots A_{11} $ is 20.\nIf $ A_1A_2 + A_1A_{11} $ can be expressed as $ \\frac{m\\sqrt{n} - p}{q} $ for positive integers $ m, n, p, q $ with $ n $ squarefree and no prime divides all of $ m, p, q$, find $ m + n + p + q $.", "answer": "19"}
13 | {"problem": "Let the sequence of rationals $ x_1, x_2, \\ldots $ be defined such that $ x_1 = \\frac{25}{11} $ and\n$ x_{k+1} = \\frac{1}{3} \\left( x_k + \\frac{1}{x_k} - 1 \\right). $\n$ x_{2025} $ can be expressed as $ \\frac{m}{n} $ for relatively prime positive integers $ m $ and $ n $. Find the remainder when $ m + n $ is divided by 1000.", "answer": "248"}
14 | {"problem": "Let $ \\triangle ABC $ be a right triangle with $ \\angle A = 90^\\circ $ and $ BC = 38 $. There exist points $ K $ and $ L $ inside the triangle such that $ AK = AL = BK = CL = KL = 14. $ The area of the quadrilateral $ BKLC $ can be expressed as $ n \\sqrt{3} $ for some positive integer $ n $. Find $ n $.", "answer": "104"}
15 | {"problem": "There are exactly three positive real numbers $ k $ such that the function\n$ f(x) = \\frac{(x - 18)(x - 72)(x - 98)(x - k)}{x} $\ndefined over the positive real numbers achieves its minimum value at exactly two positive real numbers $ x $. Find the sum of these three values of $ k $.", "answer": "240"}
16 | {"problem": "Find the sum of all integer bases $b>9$ for which $17_{b}$ is a divisor of $97_{b}$.", "answer": "70"}
17 | {"problem": "On $\\triangle ABC$ points $A,D,E$, and $B$ lie that order on side $\\overline{AB}$ with $AD=4, DE=16$, and $EB=8$. Points $A,F,G$, and $C$ lie in that order on side $\\overline{AC}$ with $AF=13, FG=52$, and $GC=26$. Let $M$ be the reflection of $D$ through $F$, and let $N$ be the reflection of $G$ through $E$. Quadrilateral $DEGF$ has area 288. Find the area of heptagon $AFNBCEM$.", "answer": "588"}
18 | {"problem": "The 9 members of a baseball team went to an ice cream parlor after their game. Each player had a singlescoop cone of chocolate, vanilla, or strawberry ice cream. At least one player chose each flavor, and the number of players who chose chocolate was greater than the number of players who chose vanilla, which was greater than the number of players who chose strawberry. Let $N$ be the number of different assignments of flavors to players that meet these conditions. Find the remainder when $N$ is divided by 1000.", "answer": "16"}
19 | {"problem": "Find the number of ordered pairs $(x,y)$, where both $x$ and $y$ are integers between $-100$ and $100$, inclusive, such that $12x^{2}-xy-6y^{2}=0$.", "answer": "117"}
20 | {"problem": "There are $8!=40320$ eight-digit positive integers that use each of the digits $1,2,3,4,5,6,7,8$ exactly once. Let $N$ be the number of these integers that are divisible by 22. Find the difference between $N$ and 2025.", "answer": "279"}
21 | {"problem": "An isosceles trapezoid has an inscribed circle tangent to each of its four sides. The radius of the circle is 3, and the area of the trapezoid is 72. Let the parallel sides of the trapezoid have lengths $r$ and $s$, with $r \\neq s$. Find $r^{2}+s^{2}$.", "answer": "504"}
22 | {"problem": "The twelve letters $A,B,C,D,E,F,G,H,I,J,K$, and $L$ are randomly grouped into six pairs of letters. The two letters in each pair are placed next to each other in alphabetical order to form six two-letter words, and those six words are listed alphabetically. For example, a possible result is $AB,CJ,DG,EK,FL,HI$. The probability that the last word listed contains $G$ is $\\frac{m}{n}$, where $m$ and $n$ are relatively prime positive integers. Find $m+n$.", "answer": "821"}
23 | {"problem": "Let $k$ be real numbers such that the system $|25+20i-z|=5$ and $|z-4-k|=|z-3i-k|$ has exactly one complex solution $z$. The sum of all possible values of $k$ can be written as $\\frac{m}{n}$, where $m$ and $n$ are relatively prime positive integers. Find $m+n$. Here $i=\\sqrt{-1}$.", "answer": "77"}
24 | {"problem": "The parabola with equation $y=x^{2}-4$ is rotated $60^{\\circ}$ counterclockwise around the origin. The unique point in the fourth quadrant where the original parabola and its image intersect has $y$-coordinate $\\frac{a-\\sqrt{b}}{c}$, where $a$, $b$, and $c$ are positive integers, and $a$ and $c$ are relatively prime. Find $a+b+c$.", "answer": "62"}
25 | {"problem": "The 27 cells of a $3\\times9$ grid are filled in using the numbers 1 through 9 so that each row contains 9 different numbers, and each of the three $3\\times3$ blocks heavily outlined in the example below contains 9 different numbers, as in the first three rows of a Sudoku puzzle. \n | 4 | 2 | 8 | 9 | 6 | 3 | 1 | 7 | 5 | \n | 3 | 7 | 9 | 5 | 2 | 1 | 6 | 8 | 4 | \n | 5 | 6 | 1 | 8 | 4 | 7 | 9 | 2 | 3 | \n The number of different ways to fill such a grid can be written as $p^a\\cdot q^b\\cdot r^c\\cdot s^d$, where $p,q,r,$ and $s$ are distinct prime numbers and $a,b,c,$ and $d$ are positive integers. Find $p\\cdot a+q\\cdot b+r\\cdot c+s\\cdot d$.", "answer": "81"}
26 | {"problem": "A piecewise linear periodic function is defined by $f(x)=\\begin{cases}x&\\text{if }x\\in[-1,1)\\\\2-x&\\text{if }x\\in[1,3)\\end{cases}$ and $f(x+4)=f(x)$ for all real numbers $x$. The graph of $f(x)$ has the sawtooth pattern. The parabola $x=34y^2$ intersects the graph of $f(x)$ at finitely many points. The sum of the $y$-coordinates of these intersection points can be expressed in the form $\\frac{a+b\\sqrt{c}}{d}$, where $a,b,c,$ and $d$ are positive integers, $a,b,$ and $d$ have greatest common divisor equal to 1, and $c$ is not divisible by the square of any prime. Find $a+b+c+d$.", "answer": "259"}
27 | {"problem": "The set of points in 3-dimensional coordinate space that lie in the plane $x+y+z=75$ whose coordinates satisfy the inequalities $x-yz 1:
15 | substrs = substrs[1:]
16 | for substr in substrs:
17 | new_str += "\\frac"
18 | if len(substr) > 0 and substr[0] == "{":
19 | new_str += substr
20 | else:
21 | try:
22 | assert len(substr) >= 2
23 | except:
24 | return string
25 | a = substr[0]
26 | b = substr[1]
27 | if b != "{":
28 | if len(substr) > 2:
29 | post_substr = substr[2:]
30 | new_str += "{" + a + "}{" + b + "}" + post_substr
31 | else:
32 | new_str += "{" + a + "}{" + b + "}"
33 | else:
34 | if len(substr) > 2:
35 | post_substr = substr[2:]
36 | new_str += "{" + a + "}" + b + post_substr
37 | else:
38 | new_str += "{" + a + "}" + b
39 | string = new_str
40 | return string
41 |
42 |
43 | def _fix_a_slash_b(string):
44 | if len(string.split("/")) != 2:
45 | return string
46 | a = string.split("/")[0]
47 | b = string.split("/")[1]
48 | try:
49 | if "sqrt" not in a:
50 | a = int(a)
51 | if "sqrt" not in b:
52 | b = int(b)
53 | assert string == "{}/{}".format(a, b)
54 | new_string = "\\frac{" + str(a) + "}{" + str(b) + "}"
55 | return new_string
56 | except:
57 | return string
58 |
59 |
60 | def _fix_sqrt(string):
61 | _string = re.sub(r"\\sqrt(\w+)", r"\\sqrt{\1}", string)
62 | return _string
63 |
64 |
def convert_word_number(text: str) -> str:
    """Convert an English number phrase (e.g. "twenty one") to digits.

    Best effort: if `w2n.word_to_num` cannot parse *text*, the input is
    returned unchanged.
    """
    try:
        text = str(w2n.word_to_num(text))
    # Deliberately broad, since non-numeric text is the common case. Unlike a
    # bare `except:`, this no longer swallows KeyboardInterrupt / SystemExit.
    except Exception:
        pass
    return text
71 |
72 |
73 | # units mainly from MathQA
# Unit / filler words that strip_string() deletes from answers. Each entry is
# spliced unescaped into the regex r"(^|\W)" + unit_text + r"($|\W)", so
# characters such as "." act as regex wildcards. Entries are applied in list
# order.
# NOTE(review): "kmph" and "sec" appear twice, and "m sqaure" looks like a
# typo of "m square". They are left as-is because changing the list changes
# normalization results.
unit_texts = [
    "east",
    "degree",
    "mph",
    "kmph",
    "ft",
    "m sqaure",
    " m east",
    "sq m",
    "deg",
    "mile",
    "q .",
    "monkey",
    "prime",
    "ratio",
    "profit of rs",
    "rd",
    "o",
    "gm",
    "p . m",
    "lb",
    "tile",
    "per",
    "dm",
    "lt",
    "gain",
    "ab",
    "way",
    "west",
    "a .",
    "b .",
    "c .",
    "d .",
    "e .",
    "f .",
    "g .",
    "h .",
    "t",
    "a",
    "h",
    "no change",
    "men",
    "soldier",
    "pie",
    "bc",
    "excess",
    "st",
    "inches",
    "noon",
    "percent",
    "by",
    "gal",
    "kmh",
    "c",
    "acre",
    "rise",
    "a . m",
    "th",
    "π r 2",
    "sq",
    "mark",
    "l",
    "toy",
    "coin",
    "sq . m",
    "gallon",
    "° f",
    "profit",
    "minw",
    "yr",
    "women",
    "feet",
    "am",
    "pm",
    "hr",
    "cu cm",
    "square",
    "v â € ™",
    "are",
    "rupee",
    "rounds",
    "cubic",
    "cc",
    "mtr",
    "s",
    "ohm",
    "number",
    "kmph",  # duplicate of an earlier entry
    "day",
    "hour",
    "minute",
    "min",
    "second",
    "man",
    "woman",
    "sec",
    "cube",
    "mt",
    "sq inch",
    "mp",
    "∏ cm ³",
    "hectare",
    "more",
    "sec",  # duplicate of an earlier entry
    "unit",
    "cu . m",
    "cm 2",
    "rs .",
    "rs",
    "kg",
    "g",
    "month",
    "km",
    "m",
    "cm",
    "mm",
    "apple",
    "liter",
    "loss",
    "yard",
    "pure",
    "year",
    "increase",
    "decrease",
    "d",
    "less",
    "Surface",
    "litre",
    "pi sq m",
    "s .",
    "metre",
    "meter",
    "inch",
]

# Also match a naive plural of every entry (entry + "s"). This doubles the
# list, and duplicated entries are doubled too.
unit_texts.extend([t + "s" for t in unit_texts])
210 |
211 |
def strip_string(string):
    """Normalize a math answer so that equivalent answers compare equal.

    Steps, in order:
    - Remove line breaks, trailing dots and LaTeX noise (``\\!``,
      ``\\left``/``\\right``, escaped braces).
    - Unify matrix and fraction macros.
    - Drop a trailing ``\\text{...}`` unit and any unit word from ``unit_texts``.
    - Remove degree, dollar and percent signs, month names and quotes.
    - Convert number words and unwrap ``\\text{...}``.
    - Strip ``x=``-style prefixes.
    - Canonicalize infinity, trailing ``.000`` decimals, square roots and
      fractions.

    Args:
        string: The answer to normalize; any object is converted with ``str``.

    Returns:
        The normalized answer string (possibly empty).
    """
    string = str(string).strip()
    # linebreaks
    string = string.replace("\n", "")

    # right "."
    string = string.rstrip(".")

    # remove inverse spaces
    string = string.replace("\\!", "")

    # matrix: treat array / bmatrix environments as pmatrix
    string = re.sub(r"\\begin\{array\}\{.*?\}", r"\\begin{pmatrix}", string)
    string = re.sub(r"\\end\{array\}", r"\\end{pmatrix}", string)
    string = string.replace("bmatrix", "pmatrix")

    # replace tfrac and dfrac with frac
    string = string.replace("tfrac", "frac")
    string = string.replace("dfrac", "frac")
    string = (
        string.replace("\\neq", "\\ne")
        .replace("\\leq", "\\le")
        .replace("\\geq", "\\ge")
    )

    # remove \left and \right
    string = string.replace("\\left", "")
    string = string.replace("\\right", "")
    string = string.replace("\\{", "{")
    string = string.replace("\\}", "}")

    # Remove a trailing "\text{...}" unit (miles, dollars, ...) unless that
    # would leave nothing behind.
    _string = re.sub(r"\\text{.*?}$", "", string).strip()
    if _string != "" and _string != string:
        string = _string

    for unit_text in unit_texts:
        # The unit must be delimited by the start/end of the string or a
        # non-word character. re.escape is needed because several units
        # contain "." (e.g. "a . m", "rs ."), which would otherwise match any
        # character and delete unrelated text such as "a + m".
        _string = re.sub(
            r"(^|\W)" + re.escape(unit_text) + r"($|\W)", r"\1\2", string
        )
        if _string != "":
            string = _string

    # Remove circ (degrees)
    string = string.replace("^{\\circ}", "")
    string = string.replace("^\\circ", "")

    # remove dollar signs
    string = string.replace("\\$", "")
    string = string.replace("$", "")
    string = string.replace("\\(", "").replace("\\)", "")

    # convert word number to digit
    string = convert_word_number(string)

    # replace "\\text{...}" to "..."
    string = re.sub(r"\\text\{(.*?)\}", r"\1", string)
    for key in ["x=", "y=", "z=", "x\\in", "y\\in", "z\\in", "x\\to", "y\\to", "z\\to"]:
        string = string.replace(key, "")
    string = string.replace("\\emptyset", r"{}")
    string = string.replace("(-\\infty,\\infty)", "\\mathbb{R}")

    # remove percentage (escaped first, then bare)
    string = string.replace("\\%", "")
    string = string.replace("%", "")

    months = r"\b(January|February|March|April|May|June|July|August|September|October|November|December)\b"
    string = re.sub(months, "", string, flags=re.IGNORECASE)

    # " 0." equivalent to " ." and "{0." equivalent to "{." Alternatively, add "0" if "." is the start of the string
    string = string.replace(" .", " 0.")
    string = string.replace("{.", "{0.")

    # Unwrap a single alphanumeric token enclosed in {}, () or [] ("{5}" -> "5").
    # The interior must be tested: the whole string starts with a bracket and
    # therefore is never alphanumeric itself.
    if (
        len(string) >= 2
        and (string[0], string[-1]) in (("{", "}"), ("(", ")"), ("[", "]"))
        and string[1:-1].isalnum()
    ):
        string = string[1:-1]

    # inf
    string = string.replace("infinity", "\\infty")
    if "\\infty" not in string:
        string = string.replace("inf", "\\infty")
    string = string.replace("+\\infty", "\\infty")

    # and
    string = string.replace("and", "")
    string = string.replace("\\mathbf", "")

    # use regex to remove \mbox{...}
    string = re.sub(r"\\mbox{.*?}", "", string)

    # quotes: str.replace returns a new string, so the result must be kept
    string = string.replace("'", "")
    string = string.replace('"', "")

    # i, j: treat j as the imaginary unit when i is not used
    if "j" in string and "i" not in string:
        string = string.replace("j", "i")

    # replace a.000b where b is not number or b is end, with ab, use regex
    string = re.sub(r"(\d+)\.0*([^\d])", r"\1\2", string)
    string = re.sub(r"(\d+)\.0*$", r"\1", string)

    # if empty, return empty string
    if len(string) == 0:
        return string
    if string[0] == ".":
        string = "0" + string

    # to consider: get rid of e.g. "k = " or "q = " at beginning
    if len(string.split("=")) == 2:
        if len(string.split("=")[0]) <= 2:
            string = string.split("=")[1]

    string = _fix_sqrt(string)
    string = string.replace(" ", "")

    # \frac1b or \frac12 --> \frac{1}{b} and \frac{1}{2}, etc. Even works with \frac1{72} (but not \frac{72}1). Also does a/b --> \\frac{a}{b}
    string = _fix_fracs(string)

    # NOTE: X/Y changed to \frac{X}{Y} in dataset, but in simple cases fix in case the model output is X/Y
    string = _fix_a_slash_b(string)

    return string
351 |
352 |
def extract_multi_choice_answer(pred_str):
    """Extract a multiple-choice letter from a free-form model response.

    Text after a "Problem:" marker is discarded. "choice is" is treated as a
    synonym of "answer is" (case-insensitively). The first
    "answer is x" / "answer is (x)" with x in a-e is returned.

    Returns:
        The upper-case letter "A"-"E", or "placeholder" when none is found.
    """
    # TODO: SFT models
    if "Problem:" in pred_str:
        pred_str = pred_str.split("Problem:", 1)[0]
    # Lower-case before the synonym rewrite so "Choice is" is recognized too.
    text = pred_str.lower().replace("choice is", "answer is")
    # The group must be named "ans" -- "(?P[...])" is not a valid group.
    match = re.search(r"answer is \(?(?P<ans>[abcde])\)?", text)
    if match is not None:
        return match.group("ans").upper()
    return "placeholder"
362 |
363 |
# Phrases that introduce the final answer in few-shot style completions.
direct_answer_trigger_for_fewshot = (
    "choice is",
    "answer is",
)


def choice_answer_clean(pred: str):
    """Reduce a completion to a multiple-choice letter (or its bare answer).

    If any trigger phrase occurs more than once, the text is treated as a
    few-shot (ICL) continuation and only its first blank-line-separated chunk
    is kept. The text after the last trigger is then searched for standalone
    letters A-E (case-insensitively). The first one is returned when a
    trigger was present, the last one otherwise. When no letter is found,
    the cleaned text itself is returned.
    """
    text = pred.strip("\n")

    if any(text.count(phrase) > 1 for phrase in direct_answer_trigger_for_fewshot):
        text = text.split("\n\n")[0]

    # Split on the triggers to isolate the answer part.
    pieces = re.split("|".join(direct_answer_trigger_for_fewshot), text)
    saw_trigger = len(pieces) > 1
    if saw_trigger:
        text = pieces[-1]

    text = text.strip("\n").rstrip(".").rstrip("/").strip(" ").lstrip(":")

    letters = re.findall(r"\b(A|B|C|D|E)\b", text.upper())
    candidates = letters if letters else [text.strip().strip(".")]

    chosen = candidates[0] if saw_trigger else candidates[-1]

    # Remove the period at the end, again!
    return chosen.rstrip(".").rstrip("/")
409 |
410 |
def find_box(pred_str: str):
    """Return the content that follows the last "boxed" in ``pred_str``.

    If "boxed" is followed by "{", the brace-balanced contents are returned
    (an unclosed group yields everything after the "{"). Otherwise the text
    up to the first "$" is returned, stripped. Returns "" when nothing
    follows the last "boxed".
    """
    tail = pred_str.split("boxed")[-1]
    if not tail:
        return ""
    if not tail.startswith("{"):
        return tail.split("$")[0].strip()

    depth = 1
    collected = []
    for ch in tail[1:]:
        if ch == "{":
            depth += 1
        elif ch == "}":
            depth -= 1
            if depth == 0:
                break
        collected.append(ch)
    return "".join(collected)
432 |
433 |
def clean_units(pred_str: str):
    """Clean the units in the number.

    Replaces pi with 3.14, inserting "*" when it follows a digit or a closing
    brace. Turns "%" into "/100" and strips "$", "¥", "°C", " C" and "°",
    so the result can be evaluated numerically.
    """

    def convert_pi_to_number(code_string):
        code_string = code_string.replace("\\pi", "π")
        # Replace \pi or π not preceded by a digit or } with 3.14
        code_string = re.sub(r"(?<![\d}])\\?π", "3.14", code_string)
        # Replace instances where π is preceded by a digit but without a
        # multiplication symbol, e.g., "3π" -> "3*3.14"
        code_string = re.sub(r"(\d)(\\?π)", r"\1*3.14", code_string)
        # Same for π right after a closing brace, e.g. "\frac{1}{2}π";
        # the first rule deliberately skips it, so it must be handled here.
        code_string = re.sub(r"\}(\\?π)", "}*3.14", code_string)
        # Handle cases where π is within braces or followed by a multiplication symbol
        # This replaces "{π}" with "3.14" directly and "3*π" with "3*3.14"
        code_string = re.sub(r"\{(\\?π)\}", "3.14", code_string)
        code_string = re.sub(r"\*(\\?π)", "*3.14", code_string)
        return code_string

    pred_str = convert_pi_to_number(pred_str)
    pred_str = pred_str.replace("%", "/100")
    pred_str = pred_str.replace("$", "")
    pred_str = pred_str.replace("¥", "")
    pred_str = pred_str.replace("°C", "")
    pred_str = pred_str.replace(" C", "")
    pred_str = pred_str.replace("°", "")
    return pred_str
457 |
458 |
def extract_theoremqa_answer(pred: str, answer_flag: bool = True):
    """Normalize a TheoremQA-style prediction.

    - A standalone "yes"/"true" (any case) becomes "True".
    - A standalone "no"/"false" becomes "False".
    - Text mentioning an option "(a)".."(f)" is returned unchanged.
    - Otherwise the boxed content (if any) is used:
      - With ``answer_flag``, the right-hand side of the last "=" is cleaned
        of units and evaluated, falling back to the leading number.
      - Without it, the last number in the text is returned ("" if none).
    """
    lowered = pred.lower()
    # Whole-word matching: a plain substring test would treat words such as
    # "denominator", "not" or "know" as a "no" answer.
    if re.search(r"\b(yes|true)\b", lowered):
        pred = "True"
    elif re.search(r"\b(no|false)\b", lowered):
        pred = "False"
    elif any(
        option in lowered for option in ["(a)", "(b)", "(c)", "(d)", "(e)", "(f)"]
    ):
        pass
    else:
        # Some of the models somehow get used to boxed output from pre-training
        if "boxed" in pred:
            pred = find_box(pred)

        if answer_flag:
            # Extract the numbers out of the string
            pred = pred.split("=")[-1].strip()
            pred = clean_units(pred)
            try:
                tmp = str(latex2sympy(pred))
                # NOTE(review): eval() runs on text derived from model output.
                # Only acceptable for offline evaluation of trusted runs; do
                # not route untrusted input through this path.
                pred = str(eval(tmp))
            except Exception:
                if re.match(r"-?[\d\.]+\s\D+$", pred):
                    pred = pred.split(" ")[0]
                elif re.match(r"-?[\d\.]+\s[^\s]+$", pred):
                    pred = pred.split(" ")[0]
        else:
            # desperate search over the last number
            # NOTE(review): the dump lost indentation; this branch is taken to
            # pair with `if answer_flag` (no answer trigger found) -- confirm.
            preds = re.findall(r"-?\d*\.?\d+", pred)
            if len(preds) >= 1:
                pred = preds[-1]
            else:
                pred = ""

    return pred
497 |
498 |
def extract_answer(pred_str, use_last_number=True):
    """Extract the final answer from the last ``boxed`` expression.

    Args:
        pred_str: Model output text.
        use_last_number: Currently unused; kept so existing callers keep working.

    Returns:
        The content of the last ``\\boxed{...}`` (or ``\\boxed x$``), with
        line breaks removed and one leading ":", then one trailing "." and
        one trailing "/" stripped. Returns "" when the text has no "boxed".
    """
    # Drop the Cyrillic "ки" marker (presumably a step tag used by some
    # PRM-style outputs -- confirm against the data source).
    pred_str = pred_str.replace("\u043a\u0438", "")

    if "boxed" not in pred_str:
        return ""

    # Reuse the shared brace-matching helper instead of duplicating it.
    pred = find_box(pred_str)

    # multiple line
    pred = re.sub(r"\n\s*", "", pred)
    if pred.startswith(":"):
        pred = pred[1:]
    if pred.endswith("."):
        pred = pred[:-1]
    if pred.endswith("/"):
        pred = pred[:-1]
    return pred
537 |
538 |
# Dataset names kept as exceptions for answer stripping.
# NOTE(review): nothing visible in this chunk reads this list -- confirm use.
STRIP_EXCEPTIONS = [
    "carp_en",
    "minerva_math",
]
540 |
541 |
def parse_ground_truth(example: Dict[str, Any], data_name):
    """Return ``(None, answer)`` for a dataset example.

    The first element is always None. The second is ``str(example["answer"])``,
    or None when the example has no "answer" key. ``data_name`` is ignored.
    """
    if "answer" not in example:
        return None, None
    return None, str(example["answer"])
547 |
548 |
def parse_question(example, data_name=None):
    """Return the stripped question text of a dataset example.

    Uses the first key present among "question", "problem", "Question" and
    "input"; returns "" when none is present. ``data_name`` is unused.
    """
    candidate_keys = ("question", "problem", "Question", "input")
    text = next((example[key] for key in candidate_keys if key in example), "")
    return text.strip()
557 |
558 |
def run_execute(executor, result, prompt_type, data_name, execute=False):
    """Turn a raw model output into a normalized prediction.

    Args:
        executor: Object whose ``apply(code)`` returns ``(prediction, report)``;
            used only for "pot"/"pal" prompts when ``execute`` is true.
        result: Raw model output text.
        prompt_type: Prompt family name ("program_only", "pot", "pal", ...).
        data_name: Dataset name; not used by answer extraction.
        execute: Whether to run the extracted program for "pot"/"pal".

    Returns:
        ``(prediction, report)``, where prediction has been passed through
        ``strip_string``; ``(None, None)`` for empty or "error" results.
    """
    if not result or result == "error":
        return None, None
    report = None

    # NOTE(review): extract_program_output / extract_program are not defined
    # in the visible part of this module -- confirm they are imported.
    if "program_only" in prompt_type:
        prediction = extract_program_output(result)
    elif prompt_type in ["pot", "pal"] and execute:
        code = extract_program(result)
        prediction, report = executor.apply(code)
    else:
        # extract_answer's second positional parameter is use_last_number,
        # so data_name must not be passed to it.
        prediction = extract_answer(result)

    prediction = strip_string(prediction)
    return prediction, report
575 |
576 |
577 | def _test_extract_answer():
578 | pass
579 |
580 |
581 | if __name__ == "__main__":
582 | _test_extract_answer()
583 |
--------------------------------------------------------------------------------
/utils/grader.py:
--------------------------------------------------------------------------------
1 | """
2 | This logic is largely copied from the Hendrycks' MATH release (math_equivalence), and borrowed from:
3 | - https://github.com/microsoft/ProphetNet/tree/master/CRITIC
4 | - https://github.com/openai/prm800k
5 | - https://github.com/microsoft/ToRA/blob/main/src/eval/grader.py
6 | - https://github.com/deepseek-ai/DeepSeek-Math/blob/main/evaluation/eval/eval_utils.py
7 | """
8 |
9 | import re
10 | import regex
11 | import multiprocessing
12 | from math import isclose
13 | from typing import Union
14 | from collections import defaultdict
15 |
16 | from sympy import simplify, N
17 | from sympy.parsing.sympy_parser import parse_expr
18 | from sympy.parsing.latex import parse_latex
19 | from latex2sympy2 import latex2sympy
20 |
21 | from .parser import strip_string
22 | # from parser import choice_answer_clean, strip_string
23 |
24 | from .math_normalization import check_sympy_equivalence
25 |
26 | import signal
27 | from concurrent.futures import ThreadPoolExecutor
28 |
def timeout_handler(signum, frame):
    """Signal-handler callback that aborts the running call with TimeoutError."""
    message = "Function execution timed out"
    raise TimeoutError(message)
31 |
32 |
def choice_answer_clean(pred: str):
    """Return the last standalone choice letter A-E in ``pred``.

    When no letter is present, the text is returned with surrounding
    newlines, spaces, a leading ":" and trailing "." / "/" removed.
    """
    cleaned = pred.strip("\n").rstrip(".").rstrip("/").strip(" ").lstrip(":")
    # Clean the answer based on the dataset
    letters = re.findall(r"\b(A|B|C|D|E)\b", cleaned.upper())
    last = letters[-1] if letters else cleaned.strip().strip(".")
    # Remove the period at the end, again!
    return last.rstrip(".").rstrip("/")
45 |
46 |
def parse_digits(num):
    """Parse ``num`` as a float, tolerating thousands separators and percents.

    Commas are removed first. A trailing "%" (optionally LaTeX-escaped as
    "\\%") marks a percentage, and the value is divided by 100.

    Returns:
        The float value, or None when ``num`` is not numeric.
    """
    num = str(num).replace(",", "")
    try:
        return float(num)
    except ValueError:
        pass
    if num.endswith("%"):
        num = num[:-1]
        # LaTeX writes percent as "\%": drop the escaping backslash too.
        if num.endswith("\\"):
            num = num[:-1]
        try:
            return float(num) / 100
        except ValueError:
            pass
    return None


def is_digit(num):
    """Return True if ``parse_digits`` can interpret ``num`` as a number."""
    # paired with parse_digits
    return parse_digits(num) is not None
66 |
67 |
def str_to_pmatrix(input_str):
    r"""Convert brace-delimited comma lists such as "{1,2}" into pmatrices.

    Each "{...,...}" group becomes a column vector such as
    "\begin{pmatrix}1\\2\end{pmatrix}", using the LaTeX row break "\\" as
    the separator. Multiple groups are joined with ", ". Returns "" when no
    group is found.
    """
    input_str = input_str.strip()
    groups = re.findall(r"\{.*,.*\}", input_str)
    # Rows must be separated by the LaTeX row break (two backslash
    # characters); math_equal splits matrices on exactly that sequence.
    rows = (group.strip("{}").replace(",", "\\\\") for group in groups)
    return ", ".join(r"\begin{pmatrix}" + row + r"\end{pmatrix}" for row in rows)
79 |
80 |
# Regexes for a leading choice label such as "(A)", "A.", "A)", "**A**" or
# "A:". Grouped by label style, each style listed for letters A to E.
single_choice_patterns = [
    template.format(letter)
    for template in (r"^\({}\)", r"^{}\.", r"^{}\)", r"^\*\*{}\*\*", r"^{}:")
    for letter in "ABCDE"
]
88 |
89 |
def math_equal(
    prediction: Union[bool, float, str],
    reference: Union[float, str],
    include_percentage: bool = True,
    is_close: bool = True,
    timeout: bool = True,
    depth: int = 0,
    max_depth: int = 5
) -> bool:
    """
    Exact match of math if and only if:
    1. numerical equal: both can convert to float and are equal
    2. symbolic equal: both can convert to sympy expression and are equal

    Cheaper checks run first:
    - case-insensitive string equality;
    - multiple-choice letters, and choice prefixes such as "(A)" or "A.";
    - comma-separated lists, compared element-wise after sorting;
    - comparison with brackets/braces removed;
    - element-wise tuples/intervals and pmatrix/bmatrix matrices;
    - equations.

    Args:
        prediction: Model answer; non-string values are converted with ``str``.
        reference: Ground-truth answer; non-string values are converted with ``str``.
        include_percentage: Also accept ``reference / 100`` and ``reference * 100``.
        is_close: Compare numbers with ``numeric_equal`` instead of ``==``.
        timeout: Run the final symbolic comparison via ``call_with_timeout``.
        depth: Current recursion depth (internal).
        max_depth: Recursion limit; deeper calls return False.

    Returns:
        True if the two answers are judged equivalent.
    """
    if depth > max_depth:
        return False

    if prediction is None or reference is None:
        return False
    # The string checks below need str inputs; the annotation also allows
    # bool/float, which have no .strip().
    prediction = str(prediction)
    reference = str(reference)

    if prediction.strip().lower() == reference.strip().lower():
        return True
    if (
        reference in ["A", "B", "C", "D", "E"]
        and choice_answer_clean(prediction) == reference
    ):
        return True

    for pattern in single_choice_patterns:
        if regex.match(pattern, prediction):
            # Remove the pattern from the beginning of the prediction and strip the result
            prediction_cleaned = regex.sub(pattern, "", prediction, count=1).strip()
            # Recursively check whether the cleaned prediction matches the reference
            if math_equal(
                prediction_cleaned, reference, include_percentage, is_close,
                timeout=timeout, depth=depth + 1, max_depth=max_depth,
            ):
                return True

    if "," in prediction and "," in reference:
        # Split on commas and strip whitespace
        pred_parts = [part.strip() for part in prediction.split(",")]
        ref_parts = [part.strip() for part in reference.split(",")]

        if len(pred_parts) == len(ref_parts):
            # Sort both lists, then compare element-wise with math_equal
            # (generator: stops at the first mismatch).
            if all(
                math_equal(
                    p, r, include_percentage, is_close,
                    timeout=timeout, depth=depth + 1, max_depth=max_depth,
                )
                for p, r in zip(sorted(pred_parts), sorted(ref_parts))
            ):
                return True

    try:  # 1. numerical equal
        if is_digit(prediction) and is_digit(reference):
            prediction = parse_digits(prediction)
            reference = parse_digits(reference)
            # number questions
            if include_percentage:
                gt_result = [reference / 100, reference, reference * 100]
            else:
                gt_result = [reference]
            for item in gt_result:
                try:
                    if is_close:
                        if numeric_equal(prediction, item):
                            return True
                    else:
                        if item == prediction:
                            return True
                except Exception:
                    continue
            return False
    except Exception:
        pass

    if not prediction and prediction not in [0, False]:
        return False

    # 2. symbolic equal
    reference = str(reference).strip()
    prediction = str(prediction).strip()

    ## pmatrix (amps)
    if "pmatrix" in prediction and "pmatrix" not in reference:
        reference = str_to_pmatrix(reference)

    ## deal with [], (), {}
    pred_str, ref_str = prediction, reference
    if (
        prediction.startswith("[")
        and prediction.endswith("]")
        and not reference.startswith("(")
    ) or (
        prediction.startswith("(")
        and prediction.endswith(")")
        and not reference.startswith("[")
    ):
        pred_str = pred_str.strip("[]()")
        ref_str = ref_str.strip("[]()")
    for s in ["{", "}", "(", ")"]:
        ref_str = ref_str.replace(s, "")
        pred_str = pred_str.replace(s, "")
    if pred_str.lower() == ref_str.lower():
        return True

    ## [a, b] vs. [c, d], return a==c and b==d
    if (
        regex.match(r"(\(|\[).+(\)|\])", prediction) is not None
        and regex.match(r"(\(|\[).+(\)|\])", reference) is not None
    ):
        pred_parts = prediction[1:-1].split(",")
        ref_parts = reference[1:-1].split(",")
        if len(pred_parts) == len(ref_parts):
            if all(
                math_equal(
                    p, r, include_percentage, is_close,
                    timeout=timeout, depth=depth + 1, max_depth=max_depth,
                )
                for p, r in zip(pred_parts, ref_parts)
            ):
                return True

    matrix_begins = ("\\begin{pmatrix}", "\\begin{bmatrix}")
    matrix_ends = ("\\end{pmatrix}", "\\end{bmatrix}")
    if (
        prediction.startswith(matrix_begins)
        and prediction.endswith(matrix_ends)
        and reference.startswith(matrix_begins)
        and reference.endswith(matrix_ends)
    ):
        # The bmatrix markers have the same lengths as the pmatrix ones, so a
        # single slice strips either environment.
        inner = slice(len("\\begin{pmatrix}"), -len("\\end{pmatrix}"))
        pred_lines = [
            line.strip() for line in prediction[inner].split("\\\\") if line.strip()
        ]
        ref_lines = [
            line.strip() for line in reference[inner].split("\\\\") if line.strip()
        ]
        matched = len(pred_lines) == len(ref_lines)
        if matched:
            for pred_line, ref_line in zip(pred_lines, ref_lines):
                pred_parts = pred_line.split("&")
                ref_parts = ref_line.split("&")
                if len(pred_parts) != len(ref_parts) or not all(
                    math_equal(
                        p, r, include_percentage, is_close,
                        timeout=timeout, depth=depth + 1, max_depth=max_depth,
                    )
                    for p, r in zip(pred_parts, ref_parts)
                ):
                    matched = False
                    break
        if matched:
            return True

    if prediction.count("=") == 1 and reference.count("=") == 1:
        pred = prediction.split("=")
        pred = f"{pred[0].strip()} - ({pred[1].strip()})"
        ref = reference.split("=")
        ref = f"{ref[0].strip()} - ({ref[1].strip()})"
        if symbolic_equal(pred, ref) or symbolic_equal(f"-({pred})", ref):
            return True
    elif (
        prediction.count("=") == 1
        and len(prediction.split("=")[0].strip()) <= 2
        and "=" not in reference
    ):
        if math_equal(
            prediction.split("=")[1], reference, include_percentage, is_close,
            timeout=timeout, depth=depth + 1, max_depth=max_depth,
        ):
            return True
    elif (
        reference.count("=") == 1
        and len(reference.split("=")[0].strip()) <= 2
        and "=" not in prediction
    ):
        if math_equal(
            prediction, reference.split("=")[1], include_percentage, is_close,
            timeout=timeout, depth=depth + 1, max_depth=max_depth,
        ):
            return True

    if timeout:
        if call_with_timeout(symbolic_equal_process, prediction, reference):
            return True
    else:
        if symbolic_equal(prediction, reference):
            return True

    return False
339 |
340 |
def math_equal_process(param):
    """Worker entry point: compare the last two items of ``param`` with math_equal."""
    prediction, reference = param[-2], param[-1]
    return math_equal(prediction, reference)
343 |
344 |
def numeric_equal(prediction: float, reference: float, abs_tol: float = 1e-4):
    """Return True when two numbers agree within an absolute tolerance.

    Note that relative tolerance has significant impact on the result of the
    synthesized GSM-Hard dataset, so an absolute tolerance is used. The
    default relative tolerance of ``math.isclose`` still applies.

    Args:
        prediction: Predicted value.
        reference: Reference value.
        abs_tol: Absolute tolerance (default 1e-4).
    """
    return isclose(reference, prediction, abs_tol=abs_tol)
355 |
356 |
def symbolic_equal(a, b):
    """Return True if ``a`` and ``b`` are symbolically or numerically equal.

    Each input is parsed with the first of ``parse_latex``, ``parse_expr``
    and ``latex2sympy`` that succeeds. For each parser a copy with doubled
    backslashes collapsed is tried first, then the raw text. Unparseable
    inputs stay strings.

    Checks, in order:
    - string or object equality;
    - ``equals`` / ``simplify(a - b) == 0``;
    - equal ``|lhs - rhs|`` for equations;
    - numeric closeness of ``N(a)`` and ``N(b)``;
    - equality of matrices rounded to 3 decimals.

    Any check that raises is treated as "not equal".
    """

    def _parse(s):
        for f in [parse_latex, parse_expr, latex2sympy]:
            try:
                return f(s.replace("\\\\", "\\"))
            except Exception:
                try:
                    return f(s)
                except Exception:
                    pass
        return s

    a = _parse(a)
    b = _parse(b)

    # direct equal
    try:
        if str(a) == str(b) or a == b:
            return True
    except Exception:
        pass

    # simplify equal
    try:
        if a.equals(b) or simplify(a - b) == 0:
            return True
    except Exception:
        pass

    # equation equal
    try:
        if (abs(a.lhs - a.rhs)).equals(abs(b.lhs - b.rhs)):
            return True
    except Exception:
        pass

    # numeric equal
    try:
        if numeric_equal(float(N(a)), float(N(b))):
            return True
    except Exception:
        pass

    # matrix
    try:
        # if a and b are matrix
        if a.shape == b.shape:
            _a = a.applyfunc(lambda x: round(x, 3))
            _b = b.applyfunc(lambda x: round(x, 3))
            if _a.equals(_b):
                return True
    except Exception:
        pass

    return False
411 |
412 |
def symbolic_equal_process(a, b, output_queue):
    """Child-process target: put ``symbolic_equal(a, b)`` on ``output_queue``."""
    output_queue.put(symbolic_equal(a, b))
416 |
417 |
def call_with_timeout(func, *args, timeout=3, **kwargs):
    """Run ``func`` in a child process and return the result it reports.

    ``func`` is called as ``func(*args, output_queue, **kwargs)`` and must put
    its result on ``output_queue``.

    Returns:
        The value put on the queue. Returns False in three cases:
        - the child is still running after ``timeout`` seconds (it is then
          terminated);
        - the child exits with a non-zero code, e.g. because ``func`` raised;
        - the child exits without putting a result.
    """
    import queue  # only needed for queue.Empty

    output_queue = multiprocessing.Queue()
    process_args = args + (output_queue,)
    process = multiprocessing.Process(target=func, args=process_args, kwargs=kwargs)
    process.start()
    process.join(timeout)

    if process.is_alive():
        process.terminate()
        process.join()
        return False

    # A child that raised never puts a result; a blocking get() would hang
    # the caller forever.
    if process.exitcode != 0:
        return False
    try:
        return output_queue.get(timeout=1)
    except queue.Empty:
        return False
431 |
432 | # def call_with_timeout(func, *args, timeout=1, **kwargs):
433 | # # Register the signal function handler
434 | # signal.signal(signal.SIGALRM, timeout_handler)
435 | # # Set the alarm
436 | # signal.alarm(timeout)
437 |
438 | # try:
439 | # result = func(*args, **kwargs)
440 | # signal.alarm(0) # Disable the alarm if function completes in time
441 | # return result
442 | # except TimeoutError:
443 | # return False
444 | # finally:
445 | # # Ensure the alarm is disabled
446 | # signal.alarm(0)
447 |
448 | # def call_with_timeout(func, *args, timeout=1, **kwargs):
449 | # with ThreadPoolExecutor(max_workers=1) as executor:
450 | # future = executor.submit(func, *args, **kwargs)
451 | # try:
452 | # result = future.result(timeout=timeout) # Wait for result with a timeout
453 | # return result
454 | # except TimeoutError:
455 | # return False # Timeout occurred
456 |
457 |
458 |
def check_is_correct(pred, gt, timeout=True):
    """Normalize both answers with strip_string, then compare with math_equal."""
    normalized_pred = strip_string(pred)
    normalized_gt = strip_string(gt)
    return math_equal(normalized_pred, normalized_gt, timeout=timeout)
461 |
462 |
def math_equal_simple(pred, gt):
    """Normalize both answers, then compare their numeric values.

    If either normalized string cannot be parsed by ``latex2sympy``, the
    normalized strings are compared exactly. Otherwise the answers are equal
    when ``abs(N(pred) - N(gt)) <= 1e-5``. A comparison that raises counts as
    not equal.
    """
    pred = strip_string(pred)
    gt = strip_string(gt)

    try:
        pred_expr = latex2sympy(pred)
        gt_expr = latex2sympy(gt)
    except Exception:
        # Unparseable: fall back to exact comparison of the normalized text.
        return pred == gt

    try:
        if abs(N(pred_expr) - N(gt_expr)) <= 1e-5:
            return True
    except Exception:
        # e.g. non-numeric expressions whose comparison cannot be decided
        return False

    return False
490 |
491 |
def _math_equal_simple_process(pred, gt, output_queue):
    """Child-process target: put ``math_equal_simple(pred, gt)`` on the queue."""
    output_queue.put(math_equal_simple(pred, gt))


def check_is_correct_simple(pred, gt, timeout=True):
    """Compare answers with math_equal_simple, optionally in a 1-second subprocess.

    With ``timeout`` the comparison runs via ``call_with_timeout``, which
    appends an output queue to the positional arguments. The target is
    therefore a queue-aware wrapper; math_equal_simple itself accepts only
    two arguments, and the child would fail with TypeError.
    """
    if timeout:
        return call_with_timeout(_math_equal_simple_process, pred, gt, timeout=1)
    return math_equal_simple(pred, gt)
497 |
def _test_math_equal():
    """Ad-hoc manual check: print the cleaned form of a sample choice answer."""
    # Other pairs that have been checked with
    # math_equal(strip_string(pred), strip_string(gt), timeout=False):
    #   gt = "\\begin{pmatrix} -10 \\\\ 6 \\end{pmatrix}"
    #   pred = "\\begin{pmatrix}-10\\\\6\\end{pmatrix}"
    #   gt = "(6, -\\frac{3}{8})"
    #   pred = "\left( 6, -\\frac{3}{8} \\right)"
    sample = "(A) 3"
    print(choice_answer_clean(sample))


if __name__ == "__main__":
    _test_math_equal()
516 |
--------------------------------------------------------------------------------
/data/amc/test.jsonl:
--------------------------------------------------------------------------------
1 | {"id":0,"problem":"Cities $A$ and $B$ are $45$ miles apart. Alicia lives in $A$ and Beth lives in $B$. Alicia bikes towards $B$ at 18 miles per hour. Leaving at the same time, Beth bikes toward $A$ at 12 miles per hour. How many miles from City $A$ will they be when they meet?","answer":27.0,"url":"https:\/\/artofproblemsolving.com\/wiki\/index.php\/2023_AMC_12A_Problems\/Problem_1","question":"Cities $A$ and $B$ are $45$ miles apart. Alicia lives in $A$ and Beth lives in $B$. Alicia bikes towards $B$ at 18 miles per hour. Leaving at the same time, Beth bikes toward $A$ at 12 miles per hour. How many miles from City $A$ will they be when they meet?"}
2 | {"id":1,"problem":"Positive real numbers $x$ and $y$ satisfy $y^3=x^2$ and $(y-x)^2=4y^2$. What is $x+y$?","answer":36.0,"url":"https:\/\/artofproblemsolving.com\/wiki\/index.php\/2023_AMC_12A_Problems\/Problem_10","question":"Positive real numbers $x$ and $y$ satisfy $y^3=x^2$ and $(y-x)^2=4y^2$. What is $x+y$?"}
3 | {"id":2,"problem":"What is the degree measure of the acute angle formed by lines with slopes $2$ and $\\frac{1}{3}$?","answer":45.0,"url":"https:\/\/artofproblemsolving.com\/wiki\/index.php\/2023_AMC_12A_Problems\/Problem_11","question":"What is the degree measure of the acute angle formed by lines with slopes $2$ and $\\frac{1}{3}$?"}
4 | {"id":3,"problem":"What is the value of\n\\[2^3 - 1^3 + 4^3 - 3^3 + 6^3 - 5^3 + \\dots + 18^3 - 17^3?\\]","answer":3159.0,"url":"https:\/\/artofproblemsolving.com\/wiki\/index.php\/2023_AMC_12A_Problems\/Problem_12","question":"What is the value of\n\\[2^3 - 1^3 + 4^3 - 3^3 + 6^3 - 5^3 + \\dots + 18^3 - 17^3?\\]"}
5 | {"id":4,"problem":"In a table tennis tournament every participant played every other participant exactly once. Although there were twice as many right-handed players as left-handed players, the number of games won by left-handed players was $40\\%$ more than the number of games won by right-handed players. (There were no ties and no ambidextrous players.) What is the total number of games played?","answer":36.0,"url":"https:\/\/artofproblemsolving.com\/wiki\/index.php\/2023_AMC_12A_Problems\/Problem_13","question":"In a table tennis tournament every participant played every other participant exactly once. Although there were twice as many right-handed players as left-handed players, the number of games won by left-handed players was $40\\%$ more than the number of games won by right-handed players. (There were no ties and no ambidextrous players.) What is the total number of games played?"}
6 | {"id":5,"problem":"How many complex numbers satisfy the equation $z^5=\\overline{z}$, where $\\overline{z}$ is the conjugate of the complex number $z$?","answer":7.0,"url":"https:\/\/artofproblemsolving.com\/wiki\/index.php\/2023_AMC_12A_Problems\/Problem_14","question":"How many complex numbers satisfy the equation $z^5=\\overline{z}$, where $\\overline{z}$ is the conjugate of the complex number $z$?"}
7 | {"id":7,"problem":"Consider the set of complex numbers $z$ satisfying $|1+z+z^{2}|=4$. The maximum value of the imaginary part of $z$ can be written in the form $\\tfrac{\\sqrt{m}}{n}$, where $m$ and $n$ are relatively prime positive integers. What is $m+n$?","answer":21.0,"url":"https:\/\/artofproblemsolving.com\/wiki\/index.php\/2023_AMC_12A_Problems\/Problem_16","question":"Consider the set of complex numbers $z$ satisfying $|1+z+z^{2}|=4$. The maximum value of the imaginary part of $z$ can be written in the form $\\tfrac{\\sqrt{m}}{n}$, where $m$ and $n$ are relatively prime positive integers. What is $m+n$?"}
8 | {"id":8,"problem":"Flora the frog starts at 0 on the number line and makes a sequence of jumps to the right. In any one jump, independent of previous jumps, Flora leaps a positive integer distance $m$ with probability $\\frac{1}{2^m}$.\nWhat is the probability that Flora will eventually land at 10? Write the answer as a simplified fraction $\\frac{m}{n}$, find $m+n$","answer":3.0,"url":"https:\/\/artofproblemsolving.com\/wiki\/index.php\/2023_AMC_12A_Problems\/Problem_17","question":"Flora the frog starts at 0 on the number line and makes a sequence of jumps to the right. In any one jump, independent of previous jumps, Flora leaps a positive integer distance $m$ with probability $\\frac{1}{2^m}$.\nWhat is the probability that Flora will eventually land at 10? Write the answer as a simplified fraction $\\frac{m}{n}$, find $m+n$"}
9 | {"id":10,"problem":"What is the product of all solutions to the equation\n\\[\\log_{7x}2023\\cdot \\log_{289x}2023=\\log_{2023x}2023\\]","answer":1.0,"url":"https:\/\/artofproblemsolving.com\/wiki\/index.php\/2023_AMC_12A_Problems\/Problem_19","question":"What is the product of all solutions to the equation\n\\[\\log_{7x}2023\\cdot \\log_{289x}2023=\\log_{2023x}2023\\]"}
10 | {"id":11,"problem":"The weight of $\\frac{1}{3}$ of a large pizza together with $3 \\frac{1}{2}$ cups of orange slices is the same as the weight of $\\frac{3}{4}$ of a large pizza together with $\\frac{1}{2}$ cup of orange slices. A cup of orange slices weighs $\\frac{1}{4}$ of a pound. What is the weight, in pounds, of a large pizza? The answer can be written in the form $\\frac{m}{n}$, where $m$ and $n$ are relatively prime positive integers. What is $m-n$?","answer":4.0,"url":"https:\/\/artofproblemsolving.com\/wiki\/index.php\/2023_AMC_12A_Problems\/Problem_2","question":"The weight of $\\frac{1}{3}$ of a large pizza together with $3 \\frac{1}{2}$ cups of orange slices is the same as the weight of $\\frac{3}{4}$ of a large pizza together with $\\frac{1}{2}$ cup of orange slices. A cup of orange slices weighs $\\frac{1}{4}$ of a pound. What is the weight, in pounds, of a large pizza? The answer can be written in the form $\\frac{m}{n}$, where $m$ and $n$ are relatively prime positive integers. What is $m-n$?"}
11 | {"id":12,"problem":"Rows 1, 2, 3, 4, and 5 of a triangular array of integers are shown below.\n1\n1 1\n1 3 1\n1 5 5 1\n1 7 11 7 1\nEach row after the first row is formed by placing a 1 at each end of the row, and each interior entry is 1 greater than the sum of the two numbers diagonally above it in the previous row. What is the units digits of the sum of the 2023 numbers in the 2023rd row?","answer":5.0,"url":"https:\/\/artofproblemsolving.com\/wiki\/index.php\/2023_AMC_12A_Problems\/Problem_20","question":"Rows 1, 2, 3, 4, and 5 of a triangular array of integers are shown below.\n1\n1 1\n1 3 1\n1 5 5 1\n1 7 11 7 1\nEach row after the first row is formed by placing a 1 at each end of the row, and each interior entry is 1 greater than the sum of the two numbers diagonally above it in the previous row. What is the units digits of the sum of the 2023 numbers in the 2023rd row?"}
12 | {"id":13,"problem":"If $A$ and $B$ are vertices of a polyhedron, define the distance $d(A,B)$ to be the minimum number of edges of the polyhedron one must traverse in order to connect $A$ and $B$. For example, if $\\overline{AB}$ is an edge of the polyhedron, then $d(A, B) = 1$, but if $\\overline{AC}$ and $\\overline{CB}$ are edges and $\\overline{AB}$ is not an edge, then $d(A, B) = 2$. Let $Q$, $R$, and $S$ be randomly chosen distinct vertices of a regular icosahedron (regular polyhedron made up of 20 equilateral triangles). Find the probability that $d(Q, R) > d(R, S)$. The final answer can be written in the form $\\frac{m}{n}$, where $m$ and $n$ are relatively prime positive integers. What is $m+n$?","answer":29.0,"url":"https:\/\/artofproblemsolving.com\/wiki\/index.php\/2023_AMC_12A_Problems\/Problem_21","question":"If $A$ and $B$ are vertices of a polyhedron, define the distance $d(A,B)$ to be the minimum number of edges of the polyhedron one must traverse in order to connect $A$ and $B$. For example, if $\\overline{AB}$ is an edge of the polyhedron, then $d(A, B) = 1$, but if $\\overline{AC}$ and $\\overline{CB}$ are edges and $\\overline{AB}$ is not an edge, then $d(A, B) = 2$. Let $Q$, $R$, and $S$ be randomly chosen distinct vertices of a regular icosahedron (regular polyhedron made up of 20 equilateral triangles). Find the probability that $d(Q, R) > d(R, S)$. The final answer can be written in the form $\\frac{m}{n}$, where $m$ and $n$ are relatively prime positive integers. What is $m+n$?"}
13 | {"id":14,"problem":"Let $f$ be the unique function defined on the positive integers such that \\[\\sum_{d\\mid n}d\\cdot f\\left(\\frac{n}{d}\\right)=1\\] for all positive integers $n$. What is $f(2023)$?","answer":96.0,"url":"https:\/\/artofproblemsolving.com\/wiki\/index.php\/2023_AMC_12A_Problems\/Problem_22","question":"Let $f$ be the unique function defined on the positive integers such that \\[\\sum_{d\\mid n}d\\cdot f\\left(\\frac{n}{d}\\right)=1\\] for all positive integers $n$. What is $f(2023)$?"}
14 | {"id":15,"problem":"How many ordered pairs of positive real numbers $(a,b)$ satisfy the equation\n\\[(1+2a)(2+2b)(2a+b) = 32ab?\\]","answer":1.0,"url":"https:\/\/artofproblemsolving.com\/wiki\/index.php\/2023_AMC_12A_Problems\/Problem_23","question":"How many ordered pairs of positive real numbers $(a,b)$ satisfy the equation\n\\[(1+2a)(2+2b)(2a+b) = 32ab?\\]"}
15 | {"id":16,"problem":"Let $K$ be the number of sequences $A_1$, $A_2$, $\\dots$, $A_n$ such that $n$ is a positive integer less than or equal to $10$, each $A_i$ is a subset of $\\{1, 2, 3, \\dots, 10\\}$, and $A_{i-1}$ is a subset of $A_i$ for each $i$ between $2$ and $n$, inclusive. For example, $\\{\\}$, $\\{5, 7\\}$, $\\{2, 5, 7\\}$, $\\{2, 5, 7\\}$, $\\{2, 5, 6, 7, 9\\}$ is one such sequence, with $n = 5$.What is the remainder when $K$ is divided by $10$?","answer":5.0,"url":"https:\/\/artofproblemsolving.com\/wiki\/index.php\/2023_AMC_12A_Problems\/Problem_24","question":"Let $K$ be the number of sequences $A_1$, $A_2$, $\\dots$, $A_n$ such that $n$ is a positive integer less than or equal to $10$, each $A_i$ is a subset of $\\{1, 2, 3, \\dots, 10\\}$, and $A_{i-1}$ is a subset of $A_i$ for each $i$ between $2$ and $n$, inclusive. For example, $\\{\\}$, $\\{5, 7\\}$, $\\{2, 5, 7\\}$, $\\{2, 5, 7\\}$, $\\{2, 5, 6, 7, 9\\}$ is one such sequence, with $n = 5$.What is the remainder when $K$ is divided by $10$?"}
16 | {"id":17,"problem":"There is a unique sequence of integers $a_1, a_2, \\cdots a_{2023}$ such that\n\\[\\tan2023x = \\frac{a_1 \\tan x + a_3 \\tan^3 x + a_5 \\tan^5 x + \\cdots + a_{2023} \\tan^{2023} x}{1 + a_2 \\tan^2 x + a_4 \\tan^4 x \\cdots + a_{2022} \\tan^{2022} x}\\]whenever $\\tan 2023x$ is defined. What is $a_{2023}?$","answer":-1.0,"url":"https:\/\/artofproblemsolving.com\/wiki\/index.php\/2023_AMC_12A_Problems\/Problem_25","question":"There is a unique sequence of integers $a_1, a_2, \\cdots a_{2023}$ such that\n\\[\\tan2023x = \\frac{a_1 \\tan x + a_3 \\tan^3 x + a_5 \\tan^5 x + \\cdots + a_{2023} \\tan^{2023} x}{1 + a_2 \\tan^2 x + a_4 \\tan^4 x \\cdots + a_{2022} \\tan^{2022} x}\\]whenever $\\tan 2023x$ is defined. What is $a_{2023}?$"}
17 | {"id":18,"problem":"How many positive perfect squares less than $2023$ are divisible by $5$?","answer":8.0,"url":"https:\/\/artofproblemsolving.com\/wiki\/index.php\/2023_AMC_12A_Problems\/Problem_3","question":"How many positive perfect squares less than $2023$ are divisible by $5$?"}
18 | {"id":19,"problem":"How many digits are in the base-ten representation of $8^5 \\cdot 5^{10} \\cdot 15^5$?","answer":18.0,"url":"https:\/\/artofproblemsolving.com\/wiki\/index.php\/2023_AMC_12A_Problems\/Problem_4","question":"How many digits are in the base-ten representation of $8^5 \\cdot 5^{10} \\cdot 15^5$?"}
19 | {"id":20,"problem":"Janet rolls a standard $6$-sided die $4$ times and keeps a running total of the numbers she rolls. What is the probability that at some point, her running total will equal $3$? The final answer can be written in the form $\\frac{m}{n}$, where $m$ and $n$ are relatively prime positive integers. What is $m+n$?","answer":265.0,"url":"https:\/\/artofproblemsolving.com\/wiki\/index.php\/2023_AMC_12A_Problems\/Problem_5","question":"Janet rolls a standard $6$-sided die $4$ times and keeps a running total of the numbers she rolls. What is the probability that at some point, her running total will equal $3$? The final answer can be written in the form $\\frac{m}{n}$, where $m$ and $n$ are relatively prime positive integers. What is $m+n$?"}
20 | {"id":21,"problem":"Points $A$ and $B$ lie on the graph of $y=\\log_{2}x$. The midpoint of $\\overline{AB}$ is $(6, 2)$. What is the positive difference between the $x$-coordinates of $A$ and $B$? The final answer can be written in the form $m \\sqrt{n}$, where $m$ and $n$ are relatively prime positive integers. What is $m+n$?","answer":9.0,"url":"https:\/\/artofproblemsolving.com\/wiki\/index.php\/2023_AMC_12A_Problems\/Problem_6","question":"Points $A$ and $B$ lie on the graph of $y=\\log_{2}x$. The midpoint of $\\overline{AB}$ is $(6, 2)$. What is the positive difference between the $x$-coordinates of $A$ and $B$? The final answer can be written in the form $m \\sqrt{n}$, where $m$ and $n$ are relatively prime positive integers. What is $m+n$?"}
21 | {"id":22,"problem":"A digital display shows the current date as an $8$-digit integer consisting of a $4$-digit year, followed by a $2$-digit month, followed by a $2$-digit date within the month. For example, Arbor Day this year is displayed as 20230428. For how many dates in $2023$ will each digit appear an even number of times in the 8-digital display for that date?","answer":9.0,"url":"https:\/\/artofproblemsolving.com\/wiki\/index.php\/2023_AMC_12A_Problems\/Problem_7","question":"A digital display shows the current date as an $8$-digit integer consisting of a $4$-digit year, followed by a $2$-digit month, followed by a $2$-digit date within the month. For example, Arbor Day this year is displayed as 20230428. For how many dates in $2023$ will each digit appear an even number of times in the 8-digital display for that date?"}
22 | {"id":23,"problem":"Maureen is keeping track of the mean of her quiz scores this semester. If Maureen scores an $11$ on the next quiz, her mean will increase by $1$. If she scores an $11$ on each of the next three quizzes, her mean will increase by $2$. What is the mean of her quiz scores currently?","answer":7.0,"url":"https:\/\/artofproblemsolving.com\/wiki\/index.php\/2023_AMC_12A_Problems\/Problem_8","question":"Maureen is keeping track of the mean of her quiz scores this semester. If Maureen scores an $11$ on the next quiz, her mean will increase by $1$. If she scores an $11$ on each of the next three quizzes, her mean will increase by $2$. What is the mean of her quiz scores currently?"}
23 | {"id":25,"problem":"Mrs. Jones is pouring orange juice into four identical glasses for her four sons. She fills the first three glasses completely but runs out of juice when the fourth glass is only $\\frac{1}{3}$ full. What fraction of a glass must Mrs. Jones pour from each of the first three glasses into the fourth glass so that all four glasses will have the same amount of juice? The final answer can be written in the form $\\frac{m}{n}$, where $m$ and $n$ are relatively prime positive integers. What is $m+n$?","answer":7.0,"url":"https:\/\/artofproblemsolving.com\/wiki\/index.php\/2023_AMC_12B_Problems\/Problem_1","question":"Mrs. Jones is pouring orange juice into four identical glasses for her four sons. She fills the first three glasses completely but runs out of juice when the fourth glass is only $\\frac{1}{3}$ full. What fraction of a glass must Mrs. Jones pour from each of the first three glasses into the fourth glass so that all four glasses will have the same amount of juice? The final answer can be written in the form $\\frac{m}{n}$, where $m$ and $n$ are relatively prime positive integers. What is $m+n$?"}
24 | {"id":26,"problem":"In the $xy$-plane, a circle of radius $4$ with center on the positive $x$-axis is tangent to the $y$-axis at the origin, and a circle with radius $10$ with center on the positive $y$-axis is tangent to the $x$-axis at the origin. What is the slope of the line passing through the two points at which these circles intersect? The final answer can be written in the form $\\frac{m}{n}$, where $m$ and $n$ are relatively prime positive integers. What is $m+n$?","answer":7.0,"url":"https:\/\/artofproblemsolving.com\/wiki\/index.php\/2023_AMC_12B_Problems\/Problem_10","question":"In the $xy$-plane, a circle of radius $4$ with center on the positive $x$-axis is tangent to the $y$-axis at the origin, and a circle with radius $10$ with center on the positive $y$-axis is tangent to the $x$-axis at the origin. What is the slope of the line passing through the two points at which these circles intersect? The final answer can be written in the form $\\frac{m}{n}$, where $m$ and $n$ are relatively prime positive integers. What is $m+n$?"}
25 | {"id":27,"problem":"Calculate the maximum area of an isosceles trapezoid that has legs of length $1$ and one base twice as long as the other. The final answer can be written in the form $\\frac{m}{n}$, where $m$ and $n$ are relatively prime positive integers. What is $m^2+n^2$?","answer":13.0,"url":"https:\/\/artofproblemsolving.com\/wiki\/index.php\/2023_AMC_12B_Problems\/Problem_11","question":"Calculate the maximum area of an isosceles trapezoid that has legs of length $1$ and one base twice as long as the other. The final answer can be written in the form $\\frac{m}{n}$, where $m$ and $n$ are relatively prime positive integers. What is $m^2+n^2$?"}
26 | {"id":28,"problem":"For complex number $u = a+bi$ and $v = c+di$ (where $i=\\sqrt{-1}$), define the binary operation\n$u \\otimes v = ac + bdi$\nSuppose $z$ is a complex number such that $z\\otimes z = z^{2}+40$. What is $|z|^2$?","answer":50.0,"url":"https:\/\/artofproblemsolving.com\/wiki\/index.php\/2023_AMC_12B_Problems\/Problem_12","question":"For complex number $u = a+bi$ and $v = c+di$ (where $i=\\sqrt{-1}$), define the binary operation\n$u \\otimes v = ac + bdi$\nSuppose $z$ is a complex number such that $z\\otimes z = z^{2}+40$. What is $|z|^2$?"}
27 | {"id":29,"problem":"A rectangular box $P$ has distinct edge lengths $a$, $b$, and $c$. The sum of the lengths of all $12$ edges of $P$ is $13$, the sum of the areas of all $6$ faces of $P$ is $\\frac{11}{2}$, and the volume of $P$ is $\\frac{1}{2}$. Find the length of the longest interior diagonal connecting two vertices of $P$. The final answer can be written in the form $\\frac{m}{n}$, where $m$ and $n$ are relatively prime positive integers. What is $m+n$?","answer":13.0,"url":"https:\/\/artofproblemsolving.com\/wiki\/index.php\/2023_AMC_12B_Problems\/Problem_13","question":"A rectangular box $P$ has distinct edge lengths $a$, $b$, and $c$. The sum of the lengths of all $12$ edges of $P$ is $13$, the sum of the areas of all $6$ faces of $P$ is $\\frac{11}{2}$, and the volume of $P$ is $\\frac{1}{2}$. Find the length of the longest interior diagonal connecting two vertices of $P$. The final answer can be written in the form $\\frac{m}{n}$, where $m$ and $n$ are relatively prime positive integers. What is $m+n$?"}
28 | {"id":30,"problem":"For how many ordered pairs $(a,b)$ of integers does the polynomial $x^3+ax^2+bx+6$ have $3$ distinct integer roots?","answer":5.0,"url":"https:\/\/artofproblemsolving.com\/wiki\/index.php\/2023_AMC_12B_Problems\/Problem_14","question":"For how many ordered pairs $(a,b)$ of integers does the polynomial $x^3+ax^2+bx+6$ have $3$ distinct integer roots?"}
29 | {"id":32,"problem":"In the state of Coinland, coins have values $6,10,$ and $15$ cents. Suppose $x$ is the value in cents of the most expensive item in Coinland that cannot be purchased using these coins with exact change. What is the sum of the digits of $x?$","answer":11.0,"url":"https:\/\/artofproblemsolving.com\/wiki\/index.php\/2023_AMC_12B_Problems\/Problem_16","question":"In the state of Coinland, coins have values $6,10,$ and $15$ cents. Suppose $x$ is the value in cents of the most expensive item in Coinland that cannot be purchased using these coins with exact change. What is the sum of the digits of $x?$"}
30 | {"id":33,"problem":"Triangle $ABC$ has side lengths in arithmetic progression, and the smallest side has length $6.$ If the triangle has an angle of $120^\\circ,$ find the area of $ABC$. The final answer can be simplified in the form $m \\sqrt{n}$, where $m$ and $n$ are positive integers and $n$ is not divisible by the square of any prime. What is $m+n$?","answer":18.0,"url":"https:\/\/artofproblemsolving.com\/wiki\/index.php\/2023_AMC_12B_Problems\/Problem_17","question":"Triangle $ABC$ has side lengths in arithmetic progression, and the smallest side has length $6.$ If the triangle has an angle of $120^\\circ,$ find the area of $ABC$. The final answer can be simplified in the form $m \\sqrt{n}$, where $m$ and $n$ are positive integers and $n$ is not divisible by the square of any prime. What is $m+n$?"}
31 | {"id":36,"problem":"Carlos went to a sports store to buy running shoes. Running shoes were on sale, with prices reduced by $20\\%$ on every pair of shoes. Carlos also knew that he had to pay a $7.5\\%$ sales tax on the discounted price. He had $43$ dollars. What is the original (before discount) price of the most expensive shoes he could afford to buy? ","answer":50.0,"url":"https:\/\/artofproblemsolving.com\/wiki\/index.php\/2023_AMC_12B_Problems\/Problem_2","question":"Carlos went to a sports store to buy running shoes. Running shoes were on sale, with prices reduced by $20\\%$ on every pair of shoes. Carlos also knew that he had to pay a $7.5\\%$ sales tax on the discounted price. He had $43$ dollars. What is the original (before discount) price of the most expensive shoes he could afford to buy? "}
32 | {"id":40,"problem":"When $n$ standard six-sided dice are rolled, the product of the numbers rolled can be any of $936$ possible values. What is $n$?","answer":11.0,"url":"https:\/\/artofproblemsolving.com\/wiki\/index.php\/2023_AMC_12B_Problems\/Problem_23","question":"When $n$ standard six-sided dice are rolled, the product of the numbers rolled can be any of $936$ possible values. What is $n$?"}
33 | {"id":41,"problem":"Suppose that $a$, $b$, $c$ and $d$ are positive integers satisfying all of the following relations.\n\\[abcd=2^6\\cdot 3^9\\cdot 5^7\\]\n\\[\\text{lcm}(a,b)=2^3\\cdot 3^2\\cdot 5^3\\]\n\\[\\text{lcm}(a,c)=2^3\\cdot 3^3\\cdot 5^3\\]\n\\[\\text{lcm}(a,d)=2^3\\cdot 3^3\\cdot 5^3\\]\n\\[\\text{lcm}(b,c)=2^1\\cdot 3^3\\cdot 5^2\\]\n\\[\\text{lcm}(b,d)=2^2\\cdot 3^3\\cdot 5^2\\]\n\\[\\text{lcm}(c,d)=2^2\\cdot 3^3\\cdot 5^2\\]\nWhat is $\\text{gcd}(a,b,c,d)$?","answer":3.0,"url":"https:\/\/artofproblemsolving.com\/wiki\/index.php\/2023_AMC_12B_Problems\/Problem_24","question":"Suppose that $a$, $b$, $c$ and $d$ are positive integers satisfying all of the following relations.\n\\[abcd=2^6\\cdot 3^9\\cdot 5^7\\]\n\\[\\text{lcm}(a,b)=2^3\\cdot 3^2\\cdot 5^3\\]\n\\[\\text{lcm}(a,c)=2^3\\cdot 3^3\\cdot 5^3\\]\n\\[\\text{lcm}(a,d)=2^3\\cdot 3^3\\cdot 5^3\\]\n\\[\\text{lcm}(b,c)=2^1\\cdot 3^3\\cdot 5^2\\]\n\\[\\text{lcm}(b,d)=2^2\\cdot 3^3\\cdot 5^2\\]\n\\[\\text{lcm}(c,d)=2^2\\cdot 3^3\\cdot 5^2\\]\nWhat is $\\text{gcd}(a,b,c,d)$?"}
34 | {"id":43,"problem":"A $3-4-5$ right triangle is inscribed in circle $A$, and a $5-12-13$ right triangle is inscribed in circle $B$. Find the ratio of the area of circle $A$ to the area of circle $B$. The final answer can be written in the form $\\frac{m}{n}$, where $m$ and $n$ are relatively prime positive integers. What is $m+n$?","answer":194.0,"url":"https:\/\/artofproblemsolving.com\/wiki\/index.php\/2023_AMC_12B_Problems\/Problem_3","question":"A $3-4-5$ right triangle is inscribed in circle $A$, and a $5-12-13$ right triangle is inscribed in circle $B$. Find the ratio of the area of circle $A$ to the area of circle $B$. The final answer can be written in the form $\\frac{m}{n}$, where $m$ and $n$ are relatively prime positive integers. What is $m+n$?"}
35 | {"id":44,"problem":"Jackson's paintbrush makes a narrow strip with a width of $6.5$ millimeters. Jackson has enough paint to make a strip $25$ meters long. How many square centimeters of paper could Jackson cover with paint?","answer":1625.0,"url":"https:\/\/artofproblemsolving.com\/wiki\/index.php\/2023_AMC_12B_Problems\/Problem_4","question":"Jackson's paintbrush makes a narrow strip with a width of $6.5$ millimeters. Jackson has enough paint to make a strip $25$ meters long. How many square centimeters of paper could Jackson cover with paint?"}
36 | {"id":45,"problem":"You are playing a game. A $2 \\times 1$ rectangle covers two adjacent squares (oriented either horizontally or vertically) of a $3 \\times 3$ grid of squares, but you are not told which two squares are covered. Your goal is to find at least one square that is covered by the rectangle. A \"turn\" consists of you guessing a square, after which you are told whether that square is covered by the hidden rectangle. What is the minimum number of turns you need to ensure that at least one of your guessed squares is covered by the rectangle?","answer":4.0,"url":"https:\/\/artofproblemsolving.com\/wiki\/index.php\/2023_AMC_12B_Problems\/Problem_5","question":"You are playing a game. A $2 \\times 1$ rectangle covers two adjacent squares (oriented either horizontally or vertically) of a $3 \\times 3$ grid of squares, but you are not told which two squares are covered. Your goal is to find at least one square that is covered by the rectangle. A \"turn\" consists of you guessing a square, after which you are told whether that square is covered by the hidden rectangle. What is the minimum number of turns you need to ensure that at least one of your guessed squares is covered by the rectangle?"}
37 | {"id":46,"problem":"When the roots of the polynomial \n\\[P(x) = (x-1)^1 (x-2)^2 (x-3)^3 \\cdot \\cdot \\cdot (x-10)^{10}\\]\nare removed from the number line, what remains is the union of $11$ disjoint open intervals. On how many of these intervals is $P(x)$ positive?","answer":6.0,"url":"https:\/\/artofproblemsolving.com\/wiki\/index.php\/2023_AMC_12B_Problems\/Problem_6","question":"When the roots of the polynomial \n\\[P(x) = (x-1)^1 (x-2)^2 (x-3)^3 \\cdot \\cdot \\cdot (x-10)^{10}\\]\nare removed from the number line, what remains is the union of $11$ disjoint open intervals. On how many of these intervals is $P(x)$ positive?"}
38 | {"id":47,"problem":"For how many integers $n$ does the expression\\[\\sqrt{\\frac{\\log (n^2) - (\\log n)^2}{\\log n - 3}}\\]represent a real number, where log denotes the base $10$ logarithm?","answer":901.0,"url":"https:\/\/artofproblemsolving.com\/wiki\/index.php\/2023_AMC_12B_Problems\/Problem_7","question":"For how many integers $n$ does the expression\\[\\sqrt{\\frac{\\log (n^2) - (\\log n)^2}{\\log n - 3}}\\]represent a real number, where log denotes the base $10$ logarithm?"}
39 | {"id":48,"problem":"How many nonempty subsets $B$ of $\\{0, 1, 2, 3, \\cdots, 12\\}$ have the property that the number of elements in $B$ is equal to the least element of $B$? For example, $B = \\{4, 6, 8, 11\\}$ satisfies the condition.","answer":144.0,"url":"https:\/\/artofproblemsolving.com\/wiki\/index.php\/2023_AMC_12B_Problems\/Problem_8","question":"How many nonempty subsets $B$ of $\\{0, 1, 2, 3, \\cdots, 12\\}$ have the property that the number of elements in $B$ is equal to the least element of $B$? For example, $B = \\{4, 6, 8, 11\\}$ satisfies the condition."}
40 | {"id":49,"problem":"What is the area of the region in the coordinate plane defined by\n$| | x | - 1 | + | | y | - 1 | \\le 1$?","answer":8.0,"url":"https:\/\/artofproblemsolving.com\/wiki\/index.php\/2023_AMC_12B_Problems\/Problem_9","question":"What is the area of the region in the coordinate plane defined by\n$| | x | - 1 | + | | y | - 1 | \\le 1$?"}
41 |
--------------------------------------------------------------------------------
/vllm-deer-qwen3.py:
--------------------------------------------------------------------------------
1 | import warnings
2 | warnings.filterwarnings("ignore") # Ignore all warnings
3 |
4 | import os
5 | import json
6 | import time
7 | import argparse
8 | import sys
9 | import torch
10 | import torch.nn.functional as F
11 | from vllm.outputs import CompletionOutput
12 | from typing import Any, Dict, List
13 | from nltk import ngrams
14 | from collections import Counter
15 |
16 | from transformers import AutoTokenizer
17 | from tqdm import tqdm
18 | from vllm import LLM, SamplingParams
19 | import pdb
20 |
21 | import math
22 | import numpy as np
23 | import random
24 |
def set_seeds(seed=42):
    """Seed every random number generator used by this script.

    Seeds Python's ``random`` module, NumPy's global RNG and PyTorch (CPU and,
    when CUDA is available, every visible GPU), and puts cuDNN into
    deterministic mode for better reproducibility.

    Args:
        seed: Integer seed applied to every generator.

    Returns:
        A ``torch.Generator`` seeded with ``seed``, e.g. for a multi-worker
        DataLoader. The previous version created this generator and then
        discarded it. Callers that ignore the return value are unaffected.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)

    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)      # current device
        torch.cuda.manual_seed_all(seed)  # every device (multi-GPU)

    # Prefer reproducible cuDNN kernels over autotuned (possibly faster) ones.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    generator = torch.Generator()
    generator.manual_seed(seed)
    return generator
47 |
48 |
49 |
50 |
51 |
def append_jsonl(data, file_path):
    """Append each item of ``data`` as one JSON line to ``file_path``.

    Missing parent directories are created. A bare filename (no directory
    component) is written to the current working directory. Previously that
    case crashed, because ``os.makedirs('')`` raises ``FileNotFoundError``.

    Args:
        data: Iterable of JSON-serialisable objects, one per output line.
        file_path: Destination ``.jsonl`` path, opened in append mode.
    """
    parent = os.path.dirname(file_path)
    if parent:
        os.makedirs(parent, exist_ok=True)
    with open(file_path, 'a', encoding='utf-8') as f:
        for item in data:
            # ensure_ascii=False keeps non-ASCII text (e.g. LaTeX, CJK) readable.
            f.write(json.dumps(item, ensure_ascii=False) + '\n')
58 |
def write_jsonl(data, file_path):
    """Write each item of ``data`` as one JSON line, replacing ``file_path``.

    Missing parent directories are created. A bare filename (no directory
    component) is written to the current working directory. Previously that
    case crashed, because ``os.makedirs('')`` raises ``FileNotFoundError``.

    Args:
        data: Iterable of JSON-serialisable objects, one per output line.
        file_path: Destination ``.jsonl`` path. Any existing content is
            overwritten.
    """
    parent = os.path.dirname(file_path)
    if parent:
        os.makedirs(parent, exist_ok=True)
    with open(file_path, 'w', encoding='utf-8') as f:
        for item in data:
            # ensure_ascii=False keeps non-ASCII text (e.g. LaTeX, CJK) readable.
            f.write(json.dumps(item, ensure_ascii=False) + '\n')
65 |
def read_jsonl(file_path):
    """Read a ``.jsonl`` file and return its records as a list.

    Blank or whitespace-only lines (e.g. a trailing empty line) are skipped.
    Previously they were passed to ``json.loads`` and raised
    ``JSONDecodeError``.

    Args:
        file_path: Path to the ``.jsonl`` file.

    Returns:
        One parsed JSON value per non-blank line. If the file does not exist,
        a warning is printed and an empty list is returned.
    """
    if not os.path.exists(file_path):
        print(f"Warning: Dataset file not found at {file_path}")
        return []
    with open(file_path, 'r', encoding='utf-8') as f:
        # json.loads tolerates surrounding whitespace, so no explicit strip.
        return [json.loads(line) for line in f if line.strip()]
76 |
77 |
def seq_rep_n(last_thinking, cur_thinking, rep, n=1):
    """Bump a repetition counter when two thinking segments repeat each other.

    Both segments are split on single spaces into word n-grams. They count as
    a repetition when they contain exactly the same set of n-grams. In that
    case ``rep + 1`` is returned; otherwise ``rep`` is returned unchanged.

    The previous check compared the size of the n-gram *set* intersection with
    the length of the n-gram *lists*. A segment with any duplicated n-gram
    could therefore never be reported as repeated, even when both segments
    were identical. Comparing the sets directly fixes this, and gives the same
    answer whenever neither segment contains duplicates.

    Args:
        last_thinking: Text of the previous thinking segment.
        cur_thinking: Text of the current thinking segment.
        rep: Current repetition count.
        n: n-gram order (1 = single words).

    Returns:
        The updated repetition count.
    """
    def _ngram_set(text):
        # Consecutive, unpadded windows of n tokens (same tuples nltk.ngrams yields).
        tokens = text.split(' ')
        return set(zip(*(tokens[i:] for i in range(n))))

    if _ngram_set(last_thinking) == _ngram_set(cur_thinking):
        rep += 1
    return rep
98 |
99 | # Function to calculate average max probability, mimicking Transformers version logic
def calculate_average_max_prob_from_logprobs(logprobs_list, policy='avg2') -> float:
    """Aggregate per-token probabilities from vLLM logprobs into one confidence.

    Positions ``1 .. len(logprobs_list) - 1`` are scored: the first generated
    token is skipped and the last one is included. At each position, the first
    value of that position's logprob dict is read, and ``exp(.logprob)`` is
    folded into the aggregate selected by ``policy``.

    Args:
        logprobs_list: Logprobs from a vLLM ``CompletionOutput``. Holds one
            dict per generated token; its values expose ``.logprob`` and, at
            the last position, ``.decoded_token``. Empty or ``None`` positions
            are skipped with a warning.
        policy: ``'min'`` = minimum probability, ``'avg1'`` = arithmetic mean,
            ``'avg2'`` = geometric mean.

    Returns:
        The aggregated probability. Returns ``0.0`` in any of these cases:

        - no position could be scored (previously a ZeroDivisionError for the
          averages, and 1.0 for ``'min'``);
        - the last position is empty;
        - the decoded text of the last value at the last position is not ``''``.

    Raises:
        ValueError: If ``policy`` is not ``'min'``, ``'avg1'`` or ``'avg2'``.
            Previously an unknown policy led to an UnboundLocalError.
    """
    if policy not in ('min', 'avg1', 'avg2'):
        raise ValueError(f"Unknown policy {policy!r}; expected 'min', 'avg1' or 'avg2'.")

    num_tokens = len(logprobs_list)
    if num_tokens < 1:
        print("Too few tokens to calculate valid average.")
        return 0.0

    total_prob_sum = 0.0
    log_prob_sum = 0.0  # sum of log-probabilities, for the geometric mean
    count_for_average = 0
    min_prob = 1.0

    for i in range(1, num_tokens):
        entry = logprobs_list[i]
        if not entry:
            print(f"Warning: logprobs_list[{i}] is empty or invalid.")
            continue
        try:
            logprob_obj = list(entry.values())[0]
        except (IndexError, KeyError, AttributeError) as e:
            print(f"Warning: Unable to process logprobs at logprobs_list[{i}]: {e}")
            continue
        if not hasattr(logprob_obj, 'logprob'):
            print(f"Warning: Object at logprobs_list[{i}] doesn't have '.logprob' attribute.")
            continue
        # math.exp works directly on the Python float. This avoids allocating a
        # torch tensor per token, and the float32 rounding that came with it.
        prob = math.exp(logprob_obj.logprob)
        min_prob = min(min_prob, prob)
        total_prob_sum += prob
        log_prob_sum += math.log(max(prob, 1e-10))  # clamp so log(0) cannot occur
        count_for_average += 1

    if count_for_average == 0:
        # Nothing scorable (e.g. a single-token completion): report no confidence.
        print("Too few tokens to calculate valid average.")
        return 0.0

    if policy == 'min':
        result = min_prob
    elif policy == 'avg1':
        result = total_prob_sum / count_for_average
    else:  # 'avg2'
        result = math.exp(log_prob_sum / count_for_average)

    last_entry = logprobs_list[-1]
    if not last_entry:
        return 0.0
    # NOTE(review): the score is only returned when the last decoded token equals
    # this sentinel string; confirm it matches the stop string used for the probe.
    if list(last_entry.values())[-1].decoded_token == '':
        return result
    return 0.0
154 |
155 |
def parse_args(argv=None):
    """Parse command-line options for early-exit inference with vLLM.

    Args:
        argv: Optional list of argument strings. ``None`` (the default)
            parses ``sys.argv[1:]``, exactly as before. Passing a list lets
            tests and notebooks call this without touching ``sys.argv``.

    Returns:
        argparse.Namespace holding the options below. Two flags are stored
        under different names:

        - ``--max-model-len`` / ``--model-context-len`` -> ``model_context_len``
        - ``--max_generated_tokens`` / ``--max-len`` -> ``max_len``
    """
    parser = argparse.ArgumentParser()

    # Model / engine.
    parser.add_argument('--model_name_or_path', type=str, default="./DeepSeek-R1-Distill-Qwen-14B/",
                        help="Model directory or hub id.")
    parser.add_argument('--dataset_dir', type=str, default="./data/",
                        help="Directory containing <dataset>/test.jsonl.")
    parser.add_argument("--dtype", type=str, default="bfloat16",
                        help="Model dtype passed to vLLM.")
    parser.add_argument("--max-model-len", "--model-context-len", type=int, default=40000,
                        dest="model_context_len",
                        help="vLLM max_model_len; should be longer than --max_generated_tokens.")
    parser.add_argument("--gpu-memory-utilization", type=float, default=0.9,
                        help="Fraction of GPU memory vLLM may use.")
    parser.add_argument("--trust-remote-code", action="store_true",
                        help="Pass trust_remote_code=True when loading the tokenizer.")
    parser.add_argument("--run_time", type=int, default=1,
                        help="Run index (appears in the output file name).")

    # Early-exit behaviour.
    parser.add_argument("--no_thinking", type=int, default=0,
                        help="1: compute answer confidence at the very beginning of reasoning "
                             "and attempt to exit early.")
    parser.add_argument("--rep", type=int, default=0,
                        help="Exit early when repetition occurs (not implemented yet, TODO).")
    parser.add_argument("--points", type=int, default=1,
                        help="Thinking transition point: 1 = 'Wait', otherwise 'Alternatively'.")
    parser.add_argument("--af", type=int, default=0,
                        help="Answer forcing at end of sequence.")
    parser.add_argument("--max_judge_steps", type=int, default=10,
                        help="Maximum number of answer attempts (limits time cost).")
    parser.add_argument('--policy', type=str, default="avg2",
                        help="Answer-confidence aggregation: min, avg1 (arithmetic mean) "
                             "or avg2 (geometric mean).")
    parser.add_argument('--threshold', type=float, default=0.95,
                        help="Answer-confidence threshold that triggers early exit.")

    # Budgets, I/O and sampling.
    parser.add_argument('--max_generated_tokens', '--max-len', type=int, default=16384, dest="max_len",
                        help="Total generated-token budget.")
    parser.add_argument('--dataset', type=str, default='math',
                        help="Dataset name.")
    parser.add_argument('--output_path', type=str, default='./outputs',
                        help="Root directory for output files.")
    parser.add_argument('--think_ratio', type=float, default=0.9,
                        help="Ratio of thinking phase to max generated tokens")
    parser.add_argument('--batch_size', type=int, default=2000,
                        help="vLLM batch size; set it above the number of samples in the dataset.")
    parser.add_argument('--temperature', type=float, default=0.0,
                        help="Sampling temperature.")
    parser.add_argument('--top_p', type=float, default=1.0,
                        help="Nucleus-sampling top-p.")
    # Token budget for the answer-inducing probability check (default 20).
    parser.add_argument('--prob_check_max_tokens', type=int, default=20,
                        help="Max tokens for probability check phase")

    return parser.parse_args(argv)
186 |
187 | def main():
188 | args = parse_args()
189 | args.model_context_len = args.max_len + 8000
190 | print(f"Using vLLM LLM object for direct inference (batch processing)")
191 | print(f"Model path: {args.model_name_or_path}")
192 | print(f"Dataset: {args.dataset}")
193 | print(f"Early exit probability threshold: {args.threshold}")
194 | print(f"Max total generated tokens: {args.max_len}")
195 | print(f"Thinking phase ratio: {args.think_ratio}")
196 | print(f"Batch size: {args.batch_size}")
197 | print(f"Max tokens for probability check phase: {args.prob_check_max_tokens}")
198 |
199 | print("\nInitializing vLLM LLM engine...")
200 | available_gpus = os.environ['CUDA_VISIBLE_DEVICES'].split(',')
201 | try:
202 | llm_engine = LLM(
203 | model=args.model_name_or_path,
204 | tensor_parallel_size=len(available_gpus),
205 | dtype=args.dtype,
206 | max_model_len=args.max_len + 8000,
207 | gpu_memory_utilization=args.gpu_memory_utilization,
208 | trust_remote_code=True,
209 | )
210 | print("vLLM LLM engine initialized successfully.")
211 |
212 | tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path, trust_remote_code=args.trust_remote_code)
213 | print(f"Successfully loaded tokenizer: {args.model_name_or_path}")
214 | if tokenizer.pad_token is None:
215 | if tokenizer.eos_token is not None:
216 | tokenizer.pad_token = tokenizer.eos_token
217 | else:
218 | tokenizer.add_special_tokens({'pad_token': '[PAD]'})
219 | print("Warning: Model has no pad_token or eos_token. Added custom [PAD] token.")
220 |
221 | print(f"Tokenizer using pad_token_id: {tokenizer.pad_token_id}")
222 |
223 |
224 | except Exception as e:
225 | print(f"Failed to initialize vLLM LLM engine or load tokenizer: {e}")
226 | sys.exit(1)
227 |
228 | sys_prompt = ['Please reason step by step, and put your final answer within \\boxed{}.'][0]
229 | dataset_path = f'{args.dataset_dir}/{args.dataset}/test.jsonl'
230 | try:
231 | questions_json = read_jsonl(dataset_path)
232 | if not questions_json:
233 | print(f"Error: No questions loaded from {dataset_path}.")
234 | sys.exit(1)
235 | print(f"Successfully loaded dataset: {dataset_path}, total {len(questions_json)} questions")
236 | except Exception as e:
237 | print(f"Failed to load dataset: {e}")
238 | sys.exit(1)
239 |
240 | model_dir_name = os.path.basename(os.path.normpath(args.model_name_or_path))
241 | output_dir = f'{args.output_path}/{model_dir_name}/{args.dataset}'
242 | os.makedirs(output_dir, exist_ok=True)
243 | output_file = f'{output_dir}/greedy_p{str(args.threshold)}_ratio{str(args.think_ratio)}_len{str(args.max_len)}_temperature{str(args.temperature)}_run_time{args.run_time}_no_thinking{args.no_thinking}_rep{args.rep}_points{args.points}_policy{args.policy}.jsonl'
244 |
245 | print(f"\nStarting processing, total questions: {len(questions_json)}")
246 | start_time = time.time()
247 |
248 | questions_state = {} # Dictionary to store processing state for each question
249 | last_token_strs = [""] # Strings marking end of thinking
250 | if args.points == 1:
251 | continue_str = "Wait" # String appended to sequence end to indicate continued thinking
252 | else:
253 | continue_str = "Alternatively" # String appended to sequence end to indicate continued thinking
254 |
255 | answer_prompt_str = "\n**Final Answer**\n\\boxed" # Prompt string to guide answer generation
256 | if 'gpqa' in args.dataset:
257 | answer_prompt_str = "\n**Final Answer**\nI believe the final answer, rather than the option, is \\boxed"
258 |
259 | # Get token IDs for stop conditions and strings to append
260 | last_token_ids = []
261 | for s in last_token_strs:
262 | ids = tokenizer.encode(s, add_special_tokens=False)
263 | if ids: last_token_ids.extend(ids)
264 | last_token_ids = list(set(last_token_ids)) # Remove duplicate IDs
265 |
266 | continue_ids = tokenizer.encode(continue_str, add_special_tokens=False)
267 | if not continue_ids:
268 | print(f"Warning: Unable to tokenize continue string '{continue_str}'. This may affect logic.")
269 |
270 | # Stop tokens for thinking phase generation
271 | generation_stop_tokens = [continue_str] + last_token_strs + [tokenizer.eos_token]
272 | pred_prob_stop_tokens = ['']
273 |
274 | answer_stop_tokens = [tokenizer.eos_token]
275 |
276 |
277 | # Max token limit for thinking phase
278 | think_limit_tokens = int(args.max_len * args.think_ratio)
279 |
280 | for i, question_data in enumerate(questions_json):
281 |
282 | messages = [
283 | {"role": "system", "content": sys_prompt},
284 | {"role": "user", "content": question_data['problem']}
285 | ]
286 | formatted_prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
287 | if args.no_thinking == 1:
288 | questions_state[i] = {
289 | 'question_data': question_data,
290 | 'state': 'needs_prob_check', # try to exit ai begin
291 | 'formatted_prompt': formatted_prompt,
292 | 'current_full_sequence': formatted_prompt,
293 | 'generated_thinking_history': "",
294 | 'generated_thinking_last_trun': "nsss, ssssa, wrtt, yyy, sss",
295 | 'generated_answer_history': "",
296 | 'pred_prob': 0.0,
297 | 'too_long': 0,
298 | 'rep_end': 0,
299 | 'high_prob': 0,
300 | 'regular_end': 0,
301 | 'thinking_steps': 0,
302 | 'output_dict': {},
303 | 'error_message': None,
304 | 'question_index': i
305 | }
306 | else:
307 | questions_state[i] = {
308 | 'question_data': question_data,
309 | 'state': 'needs_thought_chunk',
310 | 'formatted_prompt': formatted_prompt,
311 | 'current_full_sequence': formatted_prompt,
312 | 'generated_thinking_history': "",
313 | 'generated_thinking_last_trun': "nsss, ssssa, wrtt, yyy, sss",
314 | 'generated_answer_history': "",
315 | 'pred_prob': 0.0,
316 | 'too_long': 0,
317 | 'rep_end': 0,
318 | 'high_prob': 0,
319 | 'regular_end': 0,
320 | 'thinking_steps': 0,
321 | 'output_dict': {},
322 | 'error_message': None,
323 | 'question_index': i
324 | }
325 |
326 | active_questions_indices = list(questions_state.keys()) # List of currently processing question indices
327 | pbar = tqdm(total=len(questions_json), desc="Processing questions")
328 |
329 | print("\nRunning a simple test generation...")
330 | try:
331 | test_outputs = llm_engine.generate(["Hello, world!"], SamplingParams(max_tokens=10, temperature=args.temperature), use_tqdm=False)
332 | if test_outputs and test_outputs[0].outputs:
333 | test_generated_text = test_outputs[0].outputs[0].text
334 | print(f"Test generation successful. Output: '{test_generated_text.strip()}'")
335 | else:
336 | print("Simple test generation failed: LLM generate returned no output.")
337 | except Exception as e:
338 | print(f"Simple test generation failed: {e}")
339 |
340 | # Main processing loop: continue while there are active questions
341 | while active_questions_indices: # indexes [0,1,2,...,n]
342 | batch_prompts = [] # Current batch prompts
343 | batch_sampling_params = [] # Current batch sampling parameters
344 | batch_request_info = [] # Store (question index, step type) for output processing
345 |
346 | current_batch_count = 0
347 | # Create copy of active_questions_indices to allow modifying original list during iteration
348 | current_active_indices_for_batching = active_questions_indices[:]
349 |
350 | # Build current batch
351 | for q_idx in current_active_indices_for_batching:
352 | if current_batch_count >= args.batch_size:
353 | break
354 |
355 | state = questions_state[q_idx]
356 | if state['state'] in ['finished', 'error']:
357 | continue
358 |
359 | prompt_for_batch = None
360 | sampling_params_for_batch = None
361 | step_type = None # 'think', 'prob_check_gen', 'answer'
362 |
363 | try:
364 | # --- Determine prompt and parameters based on state ---
365 | current_full_sequence_tokens = tokenizer.encode(state['current_full_sequence'], add_special_tokens=False)
366 | current_full_sequence_len = len(current_full_sequence_tokens)
367 | current_generated_thinking_tokens = len(tokenizer.encode(state['generated_thinking_history'], add_special_tokens=False))
368 | initial_prompt_len = len(tokenizer.encode(state['formatted_prompt'], add_special_tokens=False))
369 |
370 | # Check context window limit before adding any new content
371 | remaining_context_window = args.model_context_len - current_full_sequence_len
372 | # check window, skip it
373 | if remaining_context_window <= 0:
374 | state['state'] = 'error'
375 | state['error_message'] = f"Exceeded model context window limit ({args.model_context_len}). Current sequence length: {current_full_sequence_len}"
376 | state['output_dict'] = {'question': state['question_data']['problem'], 'generated_responses': [state['generated_thinking_history'] + state['generated_answer_history'] + "\n" + state['error_message']], 'gold_answer': state['question_data']['answer']}
377 | print(f"\nQuestion {q_idx}: {state['error_message']}")
378 | if q_idx in active_questions_indices:
379 | active_questions_indices.remove(q_idx)
380 | pbar.update(1)
381 | continue
382 |
383 | # Initial state
384 | if state['state'] == 'needs_thought_chunk':
385 | # Calculate max tokens for this generation chunk, considering thinking limit and context window
386 | max_new_tokens_for_thought = min(
387 | think_limit_tokens - current_generated_thinking_tokens, # thinking budget
388 | remaining_context_window
389 | )
390 | if max_new_tokens_for_thought <= 0:
391 | state['state'] = 'needs_answer'
392 | print(f"\nQuestion {q_idx}: Reached thinking limit ({current_generated_thinking_tokens}/{think_limit_tokens}). Switching to answer generation.")
393 | state['too_long'] = 1
394 | continue
395 |
396 |
397 |
398 | prompt_for_batch = state['current_full_sequence']
399 | if state['thinking_steps'] < args.max_judge_steps:
400 | sampling_params_for_batch = SamplingParams(
401 | max_tokens=max_new_tokens_for_thought,
402 | temperature=args.temperature,
403 | top_p=args.top_p,
404 | stop=generation_stop_tokens
405 | )
406 | else:
407 | sampling_params_for_batch = SamplingParams(
408 | max_tokens=max_new_tokens_for_thought,
409 | temperature=args.temperature,
410 | top_p=args.top_p,
411 | stop=last_token_strs
412 | )
413 | step_type = 'think'
414 |
415 | # Check state
416 | elif state['state'] == 'needs_prob_check':
417 | # Build prompt for probability check generation:
418 | prompt_for_prob_check = state['current_full_sequence'] + answer_prompt_str
419 | prob_check_prompt_len = len(tokenizer.encode(prompt_for_prob_check, add_special_tokens=False))
420 | required_space_for_prob_check = prob_check_prompt_len + args.prob_check_max_tokens
421 |
422 | if required_space_for_prob_check > args.model_context_len:
423 | state['state'] = 'error'
424 | state['error_message'] = f"Probability check generation prompt exceeds context window ({args.model_context_len}). Estimated length: {required_space_for_prob_check}"
425 | state['output_dict'] = {'question': state['question_data']['problem'], 'generated_responses': [state['generated_thinking_history'] + state['generated_answer_history'] + "\n" + state['error_message']], 'gold_answer': state['question_data']['answer']}
426 | print(f"\nQuestion {q_idx}: {state['error_message']}")
427 | if q_idx in active_questions_indices:
428 | active_questions_indices.remove(q_idx)
429 | pbar.update(1)
430 | continue
431 |
432 |
433 | # Parameters for generating *prediction sequence* in probability check phase
434 | prompt_for_batch = prompt_for_prob_check
435 | sampling_params_for_batch = SamplingParams(
436 | max_tokens=args.prob_check_max_tokens,
437 | #temperature=args.temperature, # Greedy decoding
438 | stop=pred_prob_stop_tokens, # Only predict content inside \boxed{}
439 | logprobs=1,
440 | )
441 | step_type = 'prob_check_gen'
442 |
443 |
444 | elif state['state'] == 'needs_answer':
445 | # Build final answer prompt
446 |
447 | if state['too_long'] == 1:
448 | final_answer_prompt = state['formatted_prompt'] + state['generated_thinking_history'] + '\n\n\n'
449 | state['generated_thinking_history'] = state['generated_thinking_history'] + '\n\n\n'
450 |
451 | else:
452 | final_answer_prompt = state['formatted_prompt'] + state['generated_thinking_history'] + '\n\n\n'#+ answer_prompt_str
453 | state['generated_thinking_history'] = state['generated_thinking_history'] + '\n\n\n'
454 |
455 | len_final_answer_prompt = len(tokenizer.encode(final_answer_prompt, add_special_tokens=False))
456 | total_tokens_before_answer_prompt = current_generated_thinking_tokens
457 | # Calculate remaining total budget
458 | remaining_total_budget = args.max_len - total_tokens_before_answer_prompt
459 | max_new_tokens_answer = min(
460 | remaining_total_budget,
461 | args.model_context_len - len_final_answer_prompt
462 | )
463 |
464 | if max_new_tokens_answer <= 0:
465 |
466 | state['state'] = 'error'
467 | state['output_dict'] = {'question': state['question_data']['problem'], 'generated_responses': [state['generated_thinking_history'] + "\nSkipped answer generation due to length limit."], 'gold_answer': state['question_data']['answer'], 'too_long': state['too_long'], 'thinking_steps': state['thinking_steps']}
468 | print(f"\nQuestion {q_idx}: Skipped answer generation due to length limit.")
469 |
470 | if q_idx in active_questions_indices:
471 | active_questions_indices.remove(q_idx)
472 | pbar.update(1)
473 | continue
474 | else:
475 | prompt_for_batch = final_answer_prompt
476 | sampling_params_for_batch = SamplingParams(
477 | max_tokens=max_new_tokens_answer,
478 | temperature=args.temperature,
479 | stop=answer_stop_tokens,
480 | top_p=args.top_p,
481 | )
482 | step_type = 'answer'
483 |
484 |
485 |
486 |
487 |
488 | elif state['state'] == 'answer_forcing':
489 | final_answer_prompt = state['formatted_prompt'] + state['generated_thinking_history'] + state['generated_answer_history'] + answer_prompt_str
490 | state['generated_thinking_history'] = state['generated_thinking_history'] + state['generated_answer_history'] + answer_prompt_str
491 |
492 | prompt_for_batch = final_answer_prompt
493 | sampling_params_for_batch = SamplingParams(
494 | max_tokens=100,
495 | temperature=args.temperature,
496 | stop=answer_stop_tokens,
497 | top_p=args.top_p,
498 | )
499 | step_type = 'answer_exit'
500 |
501 | # If execution reaches here and prompt_for_batch is None, there may be state logic issues
502 | # This check generally shouldn't trigger
503 | if prompt_for_batch is None:
504 | state['state'] = 'error'
505 | state['error_message'] = f"Internal error: No prompt generated for state {state['state']}."
506 | state['output_dict'] = {'question': state['question_data']['problem'], 'generated_responses': [state['generated_thinking_history'] + state['generated_answer_history'] + "\n" + state['error_message']], 'gold_answer': state['question_data']['answer']}
507 | print(f"\nQuestion {q_idx}: {state['error_message']}")
508 | if q_idx in active_questions_indices:
509 | active_questions_indices.remove(q_idx)
510 | pbar.update(1)
511 | continue
512 |
513 | batch_prompts.append(prompt_for_batch)
514 | batch_sampling_params.append(sampling_params_for_batch)
515 | batch_request_info.append((q_idx, step_type))
516 | current_batch_count += 1
517 |
518 | except Exception as e:
519 | state['state'] = 'error'
520 | state['error_message'] = f"Error preparing batch request for state '{state['state']}': {e}"
521 | state['output_dict'] = {'question': state['question_data']['problem'], 'generated_responses': [state['generated_thinking_history'] + state['generated_answer_history'] + "\n" + state['error_message']], 'gold_answer': state['question_data']['answer']}
522 | print(f"\nQuestion {q_idx}: {state['error_message']}")
523 | if q_idx in active_questions_indices:
524 | active_questions_indices.remove(q_idx)
525 | pbar.update(1)
526 |
527 | if not batch_prompts:
528 | # May occur when all remaining active questions have switched to finished/error state or were skipped.
529 | all_stuck = True
530 | for q_idx in active_questions_indices:
531 | if questions_state[q_idx]['state'] not in ['finished', 'error']:
532 | all_stuck = False
533 | break
534 | if all_stuck:
535 | print("Warning: Batch generated no requests. All remaining questions are completed or in error state.")
536 | break
537 | else:
538 | print("Error: Active questions remain but no batch requests generated. Possible logic error.")
539 | for q_idx in list(active_questions_indices):
540 | state = questions_state[q_idx]
541 | if state['state'] not in ['finished', 'error']:
542 | state['state'] = 'error'
543 | state['error_message'] = "Processing aborted: Unable to generate request in batch loop."
544 | state['output_dict'] = {'question': state['question_data']['problem'], 'generated_responses': [state['generated_thinking_history'] + state['generated_answer_history'] + "\n" + state['error_message']], 'gold_answer': state['question_data']['answer']}
545 | print(f"\nQuestion {q_idx}: Marked as error due to processing abort.")
546 | if q_idx in active_questions_indices: # Should be, but double-checking
547 | active_questions_indices.remove(q_idx)
548 | pbar.update(1)
549 | break
550 |
551 | batch_outputs = llm_engine.generate(batch_prompts, batch_sampling_params, use_tqdm=False)
552 |
553 |
554 | # --- Process batch outputs ---
555 | for i, output in enumerate(batch_outputs):
556 | q_idx, step_type = batch_request_info[i]
557 | state = questions_state[q_idx]
558 |
559 | if state['state'] in ['finished', 'error']:
560 | continue
561 |
562 | try:
563 | if not output.outputs: # skip, exception handling
564 | # vLLM returned empty output, possibly due to prompt issues, length limits or other internal errors
565 | error_msg = f"vLLM returned empty output for request {output.request_id} (question {q_idx}, step {step_type})."
566 | if hasattr(output, 'error') and output.error:
567 | error_msg += f" vLLM error: {output.error}"
568 | current_full_sequence_len = len(tokenizer.encode(state['current_full_sequence'], add_special_tokens=False))
569 | initial_prompt_len = len(tokenizer.encode(state['formatted_prompt'], add_special_tokens=False))
570 | current_generated_thinking_tokens = len(tokenizer.encode(state['generated_thinking_history'], add_special_tokens=False))
571 |
572 |
573 | if step_type == 'think':
574 | max_new = min(think_limit_tokens - current_generated_thinking_tokens, args.model_context_len - current_full_sequence_len)
575 | error_msg += f" State: think, attempting to generate {max_new} tokens."
576 | if max_new <= 0: error_msg += " Note: Calculated max new tokens <= 0."
577 | if args.model_context_len - current_full_sequence_len <= 0: error_msg += " Note: Already exceeded context window."
578 | if think_limit_tokens - current_generated_thinking_tokens <= 0: error_msg += " Note: Reached thinking token limit."
579 |
580 | elif step_type == 'prob_check_gen':
581 | prob_check_prompt = state['current_full_sequence'] + answer_prompt_str
582 | prob_check_prompt_len = len(tokenizer.encode(prob_check_prompt, add_special_tokens=False))
583 | error_msg += f" State: prob_check_gen, prompt length: {prob_check_prompt_len}, attempting to generate {args.prob_check_max_tokens} tokens."
584 | if prob_check_prompt_len + args.prob_check_max_tokens > args.model_context_len: error_msg += " Note: Prompt+generation exceeds context window limit."
585 |
586 | elif step_type == 'answer':
587 | final_answer_prompt = state['formatted_prompt'] + state['generated_thinking_history'] + answer_prompt_str
588 | len_final_answer_prompt = len(tokenizer.encode(final_answer_prompt, add_special_tokens=False))
589 | total_tokens_before_answer_prompt = initial_prompt_len + current_generated_thinking_tokens
590 | max_new_answer = min(args.max_len - total_tokens_before_answer_prompt, args.model_context_len - len_final_answer_prompt)
591 | error_msg += f" State: answer, prompt length: {len_final_answer_prompt}, attempting to generate {max_new_answer} tokens."
592 | if max_new_answer <= 0: error_msg += " Note: Calculated max new tokens <= 0."
593 | if args.model_context_len - len_final_answer_prompt <= 0: error_msg += " Note: Already exceeded context window."
594 | if args.max_len - total_tokens_before_answer_prompt <= 0: error_msg += " Note: Reached total token limit."
595 |
596 | raise ValueError(error_msg)
597 |
598 | completion_output = output.outputs[0]
599 | generated_text = completion_output.text
600 | generated_ids = completion_output.token_ids
601 | last_token_id = generated_ids[-1]
602 |
603 | rep = seq_rep_n(state['generated_thinking_last_trun'], generated_text, state['rep_end'])
604 | state['rep_end'] = rep
605 |
606 |
607 | if step_type == 'think':
608 | if state['rep_end'] >= 3 and args.rep == 1:
609 | state['state'] = 'needs_answer'
610 | state['generated_thinking_history'] += generated_text
611 | state['current_full_sequence'] = state['formatted_prompt'] + state['generated_thinking_history']
612 | #state['rep_end'] = 1
613 | elif last_token_id in last_token_ids:
614 | state['state'] = 'needs_answer'
615 | state['generated_thinking_history'] += generated_text
616 | state['current_full_sequence'] = state['formatted_prompt'] + state['generated_thinking_history']
617 | state['regular_end'] = 1
618 | else:
619 | # Append generated thinking chunk
620 | state['generated_thinking_history'] += generated_text
621 | state['generated_thinking_last_trun'] = generated_text
622 | state['current_full_sequence'] = state['formatted_prompt'] + state['generated_thinking_history']
623 | state['state'] = 'needs_prob_check'
624 | state['thinking_steps'] += 1
625 |
626 | elif step_type == 'prob_check_gen':
627 | # Get logprobs for probability calculation.
628 | if completion_output.logprobs:
629 | state['pred_prob'] = calculate_average_max_prob_from_logprobs(completion_output.logprobs, args.policy)
630 |
631 | else:
632 | print(f"Warning: No logprobs returned for prob_check_gen for question {q_idx}. Setting pred_prob to 0.0.")
633 | state['pred_prob'] = 0.0
634 |
635 | # Recalculate current thinking history token length before making decision
636 | current_generated_thinking_tokens = len(tokenizer.encode(state['generated_thinking_history'], add_special_tokens=False))
637 | thinking_limit_reached = current_generated_thinking_tokens >= think_limit_tokens - 50
638 |
639 | if state['pred_prob'] > args.threshold or thinking_limit_reached: # Third condition: already generated to
640 | # Probability high enough or reached thinking limit, switch to answer phase
641 | state['state'] = 'needs_answer'
642 | if thinking_limit_reached:
643 | print(f"\nQuestion {q_idx}: Actually reached thinking limit ({current_generated_thinking_tokens}/{think_limit_tokens}). Switching to answer phase.")
644 | state['too_long'] = 1
645 | else:
646 | print(f"\nQuestion {q_idx}: Reached early exit threshold ({state['pred_prob']:.4f} > {args.threshold}). Switching to answer phase.")
647 | state['high_prob'] = 1
648 | else:
649 | # Probability not high enough, need more thinking
650 | state['state'] = 'needs_thought_chunk'
651 | if not state['current_full_sequence'].strip().endswith(continue_str) and state['thinking_steps'] != 0:
652 | state['current_full_sequence'] += continue_str
653 | state['generated_thinking_history'] += continue_str
654 | print(f"\nQuestion {q_idx}: Early exit threshold not reached ({state['pred_prob']:.4f} <= {args.threshold}), thinking history length ({current_generated_thinking_tokens}/{think_limit_tokens}). Appending '{continue_str}' and continuing thinking.")
655 |
656 | elif step_type == 'answer':
657 | state['generated_answer_history'] += (generated_text)
658 | if last_token_id != tokenizer.eos_token_id and args.af == 1:
659 | state['state'] = 'answer_forcing'
660 | else:
661 | state['current_full_sequence'] = state['formatted_prompt'] + state['generated_thinking_history'] + state['generated_answer_history']
662 | state['state'] = 'finished'
663 | final_response_text = state['generated_thinking_history'] + state['generated_answer_history']
664 | state['output_dict'] = {'question': state['question_data']['problem'], 'generated_responses': [final_response_text], 'gold_answer': state['question_data']['answer'], 'too_long': state['too_long'], 'thinking_steps': state['thinking_steps'], 'rep_end': state['rep_end'], 'high_prob': state['high_prob'],'regular_end': state['regular_end']}
665 | if q_idx in active_questions_indices:
666 | active_questions_indices.remove(q_idx)
667 | pbar.update(1)
668 |
669 | elif step_type == 'answer_exit':
670 | state['current_full_sequence'] = state['formatted_prompt'] + state['generated_thinking_history'] + generated_text
671 | state['state'] = 'finished'
672 | final_response_text = state['generated_thinking_history'] + generated_text
673 | state['output_dict'] = {'question': state['question_data']['problem'], 'generated_responses': [final_response_text], 'gold_answer': state['question_data']['answer'], 'too_long': state['too_long'], 'thinking_steps': state['thinking_steps'], 'rep_end': state['rep_end'], 'high_prob': state['high_prob'],'regular_end': state['regular_end']}
674 | if q_idx in active_questions_indices:
675 | active_questions_indices.remove(q_idx)
676 | pbar.update(1)
677 |
678 | except Exception as e:
679 | print(f"\nError processing batch results for question {q_idx} step '{step_type}': {e}")
680 | state['state'] = 'error'
681 | state['error_message'] = f"Error processing batch results for step '{step_type}': {e}"
682 | state['output_dict'] = {'question': state['question_data']['problem'], 'generated_responses': [state['generated_thinking_history'] + state['generated_answer_history'] + "\nError: " + state['error_message']], 'gold_answer': state['question_data']['answer']}
683 |
684 | if q_idx in active_questions_indices:
685 | active_questions_indices.remove(q_idx)
686 | pbar.update(1)
687 |
688 |
689 | pbar.close()
690 | final_results = [state['output_dict'] for state in questions_state.values() if state['state'] in ['finished', 'error']]
691 |
692 | # Create a mapping from problem text to original index for sorting
693 | problem_to_index = {item['problem']: i for i, item in enumerate(questions_json)}
694 | # Use get method to handle cases where problem text might not be found (though it shouldn't happen)
695 | final_results.sort(key=lambda x: problem_to_index.get(x['question'], len(questions_json)))
696 |
697 | print("\nAll questions processed, saving results...")
698 | try:
699 | write_jsonl(final_results, output_file)
700 | except Exception as e:
701 | print(f"Error saving results to {output_file}: {e}")
702 |
703 | end_time = time.time()
704 | elapsed_time = end_time - start_time
705 | print(f"Evaluation completed! Attempted to process {len(questions_json)} questions in total, successfully recorded {len(final_results)} results, took {elapsed_time:.2f} seconds")
706 | print(f"Results saved to: {output_file}")
707 |
if __name__ == "__main__":
    # Script entry point. The RNG seeding call is deliberately left disabled;
    # uncomment the next line to enable it.
    # set_seeds(42)
    main()
--------------------------------------------------------------------------------
/vllm-deer.py:
--------------------------------------------------------------------------------
1 | import warnings
2 | warnings.filterwarnings("ignore") # Ignore all warnings
3 |
4 | import os
5 | import json
6 | import time
7 | import argparse
8 | import sys
9 | import torch
10 | import torch.nn.functional as F
11 | from vllm.outputs import CompletionOutput
12 | from typing import Any, Dict, List
13 | from nltk import ngrams
14 | from collections import Counter
15 |
16 | from transformers import AutoTokenizer
17 | from tqdm import tqdm
18 | from vllm import LLM, SamplingParams
19 | import pdb
20 |
21 | import math
22 | import numpy as np
23 | import random
24 |
def set_seeds(seed=42):
    """Seed the random number generators used by this script.

    Seeds Python's ``random`` module, NumPy's global RNG and PyTorch's CPU RNG,
    plus the CUDA RNGs when CUDA is available, and switches cuDNN into
    deterministic mode.

    Args:
        seed: Integer seed applied to every generator.

    Returns:
        A ``torch.Generator`` seeded with ``seed``. It can be passed to
        components that take an explicit generator (e.g. a DataLoader's
        ``generator=`` argument). It used to be created and then thrown away;
        returning it is backward-compatible because the old return was ``None``.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)

    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)      # current GPU
        torch.cuda.manual_seed_all(seed)  # every visible GPU

    # Trade cuDNN autotuning speed for run-to-run reproducibility.
    # (Setting these flags is harmless when CUDA is unavailable.)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    generator = torch.Generator()
    generator.manual_seed(seed)
    return generator
47 |
48 |
49 |
50 |
51 |
def append_jsonl(data, file_path):
    """Append each item of ``data`` as one JSON line to ``file_path``.

    Missing parent directories are created. A bare filename (no directory
    component) is written relative to the current working directory; calling
    ``os.makedirs('')`` for it would raise ``FileNotFoundError``, so directory
    creation is skipped in that case.

    Args:
        data: Iterable of JSON-serialisable objects.
        file_path: Destination ``.jsonl`` path; existing content is preserved.
    """
    parent_dir = os.path.dirname(file_path)
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)
    with open(file_path, 'a', encoding='utf-8') as f:
        for item in data:
            # ensure_ascii=False keeps non-ASCII text readable in the output.
            f.write(json.dumps(item, ensure_ascii=False) + '\n')
58 |
def write_jsonl(data, file_path):
    """Write each item of ``data`` as one JSON line to ``file_path``, overwriting it.

    Missing parent directories are created. A bare filename (no directory
    component) is written relative to the current working directory; calling
    ``os.makedirs('')`` for it would raise ``FileNotFoundError``, so directory
    creation is skipped in that case.

    Args:
        data: Iterable of JSON-serialisable objects.
        file_path: Destination ``.jsonl`` path; any existing content is replaced.
    """
    parent_dir = os.path.dirname(file_path)
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)
    with open(file_path, 'w', encoding='utf-8') as f:
        for item in data:
            # ensure_ascii=False keeps non-ASCII text readable in the output.
            f.write(json.dumps(item, ensure_ascii=False) + '\n')
65 |
def read_jsonl(file_path):
    """Read a ``.jsonl`` file and return a list of the decoded objects.

    Blank or whitespace-only lines (for example a trailing empty line) are
    skipped; passing them to ``json.loads`` would raise ``JSONDecodeError``.

    Args:
        file_path: Path to the ``.jsonl`` file.

    Returns:
        One decoded object per non-blank line, in file order. If the file does
        not exist, a warning is printed and an empty list is returned.
    """
    if not os.path.exists(file_path):
        print(f"Warning: Dataset file not found at {file_path}")
        return []

    records = []
    with open(file_path, 'r', encoding='utf-8') as f:
        for line in f:
            stripped = line.strip()
            if stripped:
                records.append(json.loads(stripped))
    return records
76 |
77 |
def seq_rep_n(last_thinking, cur_thinking, rep, n=1):
    """Update a repetition counter by comparing two text chunks by their n-grams.

    Both chunks are split on single spaces, and the n-grams of consecutive
    tokens are counted. The chunks count as a repetition when both produce at
    least one n-gram and contain exactly the same n-grams with the same
    multiplicities. In that case ``rep + 1`` is returned, otherwise ``rep``
    unchanged.

    Multiset equality is used because the earlier check compared a *set*
    intersection with *list* lengths. That meant identical chunks containing a
    repeated n-gram (e.g. the same word twice) were never flagged, while two
    chunks shorter than ``n`` tokens were always flagged.

    Args:
        last_thinking: Previously generated text chunk.
        cur_thinking: Newly generated text chunk.
        rep: Current repetition count.
        n: n-gram order (default 1: space-separated words).

    Returns:
        The updated repetition count.
    """
    def _ngram_counts(text):
        tokens = text.split(' ')
        # Same n-grams as nltk.ngrams(tokens, n) without padding.
        return Counter(zip(*(tokens[i:] for i in range(n))))

    last_counts = _ngram_counts(last_thinking)
    cur_counts = _ngram_counts(cur_thinking)

    if last_counts and cur_counts and last_counts == cur_counts:
        rep += 1
    return rep
98 |
99 | # Function to calculate average max probability, mimicking Transformers version logic
def calculate_average_max_prob_from_logprobs(logprobs_list, policy='avg2') -> float:
    """
    Aggregate per-token probabilities from a vLLM ``CompletionOutput.logprobs`` list.

    Positions 1 .. len(logprobs_list) - 1 are used: the first generated token is
    skipped and the last one is included. For each position, the first entry of
    its logprob mapping is taken and converted to a probability with ``exp``.

    policy: 'min'  -> smallest probability
            'avg1' -> arithmetic mean
            'avg2' -> geometric mean (each probability floored at 1e-10)

    NOTE(review): the caller requests ``logprobs=1``. Which token comes first in
    each mapping (sampled vs. top-1) is decided by vLLM; under greedy decoding
    the two coincide.

    Returns the aggregated probability, or 0.0 when no usable token probability
    is available (e.g. fewer than two generated tokens). 0.0 is the lowest
    possible confidence, so a check with no evidence can never trigger an early
    exit. Previously this case raised ZeroDivisionError for 'avg1'/'avg2' and
    returned 1.0 for 'min'.

    Raises ValueError if ``policy`` is not one of 'min', 'avg1', 'avg2'.
    """
    if policy not in ('min', 'avg1', 'avg2'):
        raise ValueError(f"Unknown policy '{policy}'. Expected one of: 'min', 'avg1', 'avg2'.")

    if len(logprobs_list) < 1:
        print("Too few tokens to calculate valid average.")
        return 0.0

    probs = []
    for i in range(1, len(logprobs_list)):  # index 0 (first generated token) is skipped
        entry = logprobs_list[i]
        if not entry:
            print(f"Warning: logprobs_list[{i}] is empty or invalid.")
            continue
        try:
            logprob_obj = list(entry.values())[0]
        except (IndexError, KeyError, AttributeError) as e:
            print(f"Warning: Unable to process logprobs at logprobs_list[{i}]: {e}")
            continue
        if not hasattr(logprob_obj, 'logprob'):
            print(f"Warning: Object at logprobs_list[{i}] doesn't have '.logprob' attribute.")
            continue
        probs.append(math.exp(logprob_obj.logprob))

    if not probs:
        print("Too few tokens to calculate valid average.")
        return 0.0

    if policy == 'min':
        # Capped at 1.0, matching the former running minimum that started at 1.0.
        return min(1.0, min(probs))
    if policy == 'avg1':
        return sum(probs) / len(probs)
    # 'avg2': geometric mean computed in log space; the floor avoids log(0).
    return math.exp(sum(math.log(max(p, 1e-10)) for p in probs) / len(probs))
150 |
151 |
152 |
def parse_args(argv=None):
    """Parse the command-line options for this script.

    Args:
        argv: Optional list of argument strings. ``None`` (the default) parses
            ``sys.argv[1:]`` exactly as before; passing a list allows
            programmatic use and testing.

    Returns:
        ``argparse.Namespace`` with the parsed options. ``--policy`` is limited
        to 'min', 'avg1' and 'avg2'. An invalid value therefore fails at
        start-up instead of surfacing as a per-question error inside the
        generation loop.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--model_name_or_path', type=str, default="./DeepSeek-R1-Distill-Qwen-14B/")
    parser.add_argument('--dataset_dir', type=str, default="./data/",
                        help="Root directory holding <dataset>/test.jsonl files.")
    parser.add_argument("--dtype", type=str, default="bfloat16", help="Model dtype passed to vLLM.")
    parser.add_argument("--max-model-len", "--model-context-len", type=int, default=40000,
                        dest="model_context_len",
                        help="max-model-len for vLLM; should be longer than --max-len. "
                             "NOTE: main() currently recomputes it as --max-len + 8000.")
    parser.add_argument("--gpu-memory-utilization", type=float, default=0.9)
    parser.add_argument("--trust-remote-code", action="store_true",
                        help="Pass trust_remote_code to the tokenizer loader.")
    parser.add_argument("--run_time", type=int, default=1,
                        help="Run index (appears in the output file name).")
    parser.add_argument("--no_thinking", type=int, default=0,
                        help="1: compute answer confidence at the very beginning of the "
                             "reasoning process and attempt to exit early.")
    parser.add_argument("--rep", type=int, default=0,
                        help="1: exit early when repetition occurs (experimental / TODO).")
    parser.add_argument("--points", type=int, default=1,
                        help="1: 'Wait' as thinking transition point; otherwise 'Alternatively'.")
    parser.add_argument("--af", type=int, default=0, help="1: answer forcing at end of sequence.")
    parser.add_argument("--max_judge_steps", type=int, default=10,
                        help="Maximum number of answer attempts, to save time cost.")
    parser.add_argument('--policy', type=str, default="avg1", choices=['min', 'avg1', 'avg2'],
                        help="Strategy for calculating answer confidence: min, "
                             "avg1 (arithmetic mean) or avg2 (geometric mean).")

    parser.add_argument('--threshold', type=float, default=0.95,
                        help="Answer confidence threshold used to determine early exit.")
    parser.add_argument('--max_generated_tokens', '--max-len', type=int, default=16384, dest="max_len",
                        help="Total generated-token budget.")
    parser.add_argument('--dataset', type=str, default='math', help="Dataset name.")
    parser.add_argument('--output_path', type=str, default='./outputs', help="Output root directory.")
    parser.add_argument('--think_ratio', type=float, default=0.9,
                        help="Ratio of thinking phase to max generated tokens")
    parser.add_argument('--batch_size', type=int, default=2000,
                        help="vLLM batch size; set it above the number of samples in the dataset.")
    parser.add_argument('--temperature', type=float, default=0.0)
    parser.add_argument('--top_p', type=float, default=1.0)

    # Max tokens generated when inducing an answer for the confidence check.
    parser.add_argument('--prob_check_max_tokens', type=int, default=20,
                        help="Max tokens for probability check phase")

    return parser.parse_args(argv)
183 |
def main():
    """Run DEER early-exit reasoning over one dataset with batched vLLM generation.

    Every question is driven through a small state machine. Each round, all
    pending requests (at most ``args.batch_size``) are sent to vLLM together:

    * ``needs_thought_chunk``: generate the next reasoning chunk. It stops at
      ``continue_str``, the end-of-thinking string(s), EOS, or the thinking
      budget (``int(args.max_len * args.think_ratio)`` tokens).
    * ``needs_prob_check``: append ``answer_prompt_str`` and generate a short
      trial answer with logprobs. If its confidence exceeds ``args.threshold``
      (or the thinking budget is nearly spent), move on to answering.
      Otherwise append ``continue_str`` and keep thinking.
    * ``needs_answer``: generate the final answer after the thinking history.
    * ``answer_forcing``: with ``args.af == 1``, an answer that did not end in
      EOS gets ``answer_prompt_str`` appended and up to 100 more tokens.
    * ``finished`` / ``error``: terminal states; the question's ``output_dict``
      row is kept.

    Rows for finished and failed questions are written as JSONL, in dataset
    order, to a file under ``args.output_path`` whose name encodes the run
    settings. The process exits with status 1 if the engine, tokenizer or
    dataset cannot be loaded.
    """
    args = parse_args()
    # Context window given to vLLM: the generation budget plus 8000 extra tokens
    # (presumably headroom for the prompt and the appended probe strings).
    args.model_context_len = args.max_len + 8000
    print("Using vLLM LLM object for direct inference (batch processing)")
    print(f"Model path: {args.model_name_or_path}")
    print(f"Dataset: {args.dataset}")
    print(f"Early exit probability threshold: {args.threshold}")
    print(f"Max total generated tokens: {args.max_len}")
    print(f"Thinking phase ratio: {args.think_ratio}")
    print(f"Batch size: {args.batch_size}")
    print(f"Max tokens for probability check phase: {args.prob_check_max_tokens}")

    print("\nInitializing vLLM LLM engine...")
    # Tensor parallelism spans every GPU listed in CUDA_VISIBLE_DEVICES. An unset
    # variable used to raise KeyError here; fall back to a single GPU instead.
    visible_devices = os.environ.get('CUDA_VISIBLE_DEVICES', '')
    available_gpus = [d for d in visible_devices.split(',') if d.strip()]
    try:
        llm_engine = LLM(
            model=args.model_name_or_path,
            tensor_parallel_size=max(1, len(available_gpus)),
            dtype=args.dtype,
            max_model_len=args.model_context_len,
            gpu_memory_utilization=args.gpu_memory_utilization,
            trust_remote_code=True,
        )
        print("vLLM LLM engine initialized successfully.")

        tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path, trust_remote_code=args.trust_remote_code)
        print(f"Successfully loaded tokenizer: {args.model_name_or_path}")
        if tokenizer.pad_token is None:
            if tokenizer.eos_token is not None:
                tokenizer.pad_token = tokenizer.eos_token
            else:
                tokenizer.add_special_tokens({'pad_token': '[PAD]'})
                print("Warning: Model has no pad_token or eos_token. Added custom [PAD] token.")
        print(f"Tokenizer using pad_token_id: {tokenizer.pad_token_id}")
    except Exception as e:
        print(f"Failed to initialize vLLM LLM engine or load tokenizer: {e}")
        sys.exit(1)

    sys_prompt = 'Please reason step by step, and put your final answer within \\boxed{}.'
    dataset_path = f'{args.dataset_dir}/{args.dataset}/test.jsonl'
    try:
        questions_json = read_jsonl(dataset_path)
    except Exception as e:
        print(f"Failed to load dataset: {e}")
        sys.exit(1)
    if not questions_json:
        print(f"Error: No questions loaded from {dataset_path}.")
        sys.exit(1)
    print(f"Successfully loaded dataset: {dataset_path}, total {len(questions_json)} questions")

    model_dir_name = os.path.basename(os.path.normpath(args.model_name_or_path))
    output_dir = f'{args.output_path}/{model_dir_name}/{args.dataset}'
    os.makedirs(output_dir, exist_ok=True)
    output_file = f'{output_dir}/greedy_p{str(args.threshold)}_ratio{str(args.think_ratio)}_len{str(args.max_len)}_temperature{str(args.temperature)}_run_time{args.run_time}_no_thinking{args.no_thinking}_rep{args.rep}_points{args.points}_policy{args.policy}.jsonl'

    print(f"\nStarting processing, total questions: {len(questions_json)}")
    start_time = time.time()

    # NOTE(review): an empty string is the only end-of-thinking marker here. It
    # encodes to no token ids and is not a usable stop string, so it is filtered
    # out of the stop lists below. It may have been meant to be the model's
    # end-of-thinking tag -- confirm.
    last_token_strs = [""]
    # String appended to the sequence to make the model continue thinking.
    continue_str = "Wait" if args.points == 1 else "Alternatively"
    # Prompt suffix asking the model for its (trial or forced) boxed answer.
    answer_prompt_str = "\n**Final Answer**\n\\boxed"
    if 'gpqa' in args.dataset:
        answer_prompt_str = "\n**Final Answer**\nI believe the final answer, rather than the option, is \\boxed"

    # Token ids that, as the last generated token, mark the end of thinking.
    last_token_ids = {
        token_id
        for s in last_token_strs
        for token_id in tokenizer.encode(s, add_special_tokens=False)
    }

    continue_ids = tokenizer.encode(continue_str, add_special_tokens=False)
    if not continue_ids:
        print(f"Warning: Unable to tokenize continue string '{continue_str}'. This may affect logic.")

    def _nonempty(strings):
        """Drop empty and None entries (e.g. a missing eos_token) from a stop-string list."""
        return [s for s in strings if s]

    # Thinking chunks stop at continue_str, at an end-of-thinking string, or at EOS.
    generation_stop_tokens = _nonempty([continue_str] + last_token_strs + [tokenizer.eos_token])
    # Used after args.max_judge_steps probes: the chunk is no longer cut at continue_str.
    thinking_end_stop_tokens = _nonempty(last_token_strs)
    # Where \boxed{...} ends; stops the trial answer right after its content.
    pred_prob_stop_tokens = [' }', '}\n', '}\n\n', '}.', '}.\n', '}\\', '}}', ')}', ')}.', ')}\n']
    answer_stop_tokens = _nonempty([tokenizer.eos_token])

    # Max token budget for the thinking phase.
    think_limit_tokens = int(args.max_len * args.think_ratio)

    # With --no_thinking 1 a question starts with a confidence probe instead of thinking.
    initial_state = 'needs_prob_check' if args.no_thinking == 1 else 'needs_thought_chunk'
    questions_state = {}  # question index -> per-question processing state
    for i, question_data in enumerate(questions_json):
        messages = [
            {"role": "system", "content": sys_prompt},
            {"role": "user", "content": question_data['problem']},
        ]
        formatted_prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
        questions_state[i] = {
            'question_data': question_data,
            'state': initial_state,
            'formatted_prompt': formatted_prompt,
            'current_full_sequence': formatted_prompt,
            'generated_thinking_history': "",
            # Placeholder "previous chunk" passed to seq_rep_n before any real chunk exists.
            'generated_thinking_last_trun': "nsss, ssssa, wrtt, yyy, sss",
            'generated_answer_history': "",
            'pred_prob': 0.0,
            'too_long': 0,      # set to 1 when the thinking budget is exhausted
            'rep_end': 0,       # counter returned by seq_rep_n
            'high_prob': 0,     # set to 1 on a confident early exit
            'regular_end': 0,   # set to 1 when thinking ends on an end-of-thinking token
            'thinking_steps': 0,
            'output_dict': {},
            'error_message': None,
            'question_index': i,
        }

    active_questions_indices = list(questions_state.keys())
    pbar = tqdm(total=len(questions_json), desc="Processing questions")

    print("\nRunning a simple test generation...")
    try:
        test_outputs = llm_engine.generate(["Hello, world!"], SamplingParams(max_tokens=10, temperature=args.temperature), use_tqdm=False)
        if test_outputs and test_outputs[0].outputs:
            test_generated_text = test_outputs[0].outputs[0].text
            print(f"Test generation successful. Output: '{test_generated_text.strip()}'")
        else:
            print("Simple test generation failed: LLM generate returned no output.")
    except Exception as e:
        print(f"Simple test generation failed: {e}")

    def _retire(q_idx):
        """Remove a question from the active set and advance the progress bar."""
        if q_idx in active_questions_indices:
            active_questions_indices.remove(q_idx)
            pbar.update(1)

    def _fail(q_idx, message, separator="\n", announce=True):
        """Mark a question as failed, keep its partial output plus `message`, and retire it."""
        state = questions_state[q_idx]
        state['state'] = 'error'
        state['error_message'] = message
        state['output_dict'] = {
            'question': state['question_data']['problem'],
            'generated_responses': [state['generated_thinking_history'] + state['generated_answer_history'] + separator + message],
            'gold_answer': state['question_data']['answer'],
        }
        if announce:
            print(f"\nQuestion {q_idx}: {message}")
        _retire(q_idx)

    def _finished_output(state, response_text):
        """Build the result row for a question that completed normally."""
        return {
            'question': state['question_data']['problem'],
            'generated_responses': [response_text],
            'gold_answer': state['question_data']['answer'],
            'too_long': state['too_long'],
            'thinking_steps': state['thinking_steps'],
            'rep_end': state['rep_end'],
            'high_prob': state['high_prob'],
            'regular_end': state['regular_end'],
        }

    def _empty_output_message(output, q_idx, step_type):
        """Describe a request for which vLLM returned no completions, with likely causes."""
        state = questions_state[q_idx]
        error_msg = f"vLLM returned empty output for request {output.request_id} (question {q_idx}, step {step_type})."
        if getattr(output, 'error', None):
            error_msg += f" vLLM error: {output.error}"
        current_full_sequence_len = len(tokenizer.encode(state['current_full_sequence'], add_special_tokens=False))
        current_generated_thinking_tokens = len(tokenizer.encode(state['generated_thinking_history'], add_special_tokens=False))

        if step_type == 'think':
            max_new = min(think_limit_tokens - current_generated_thinking_tokens, args.model_context_len - current_full_sequence_len)
            error_msg += f" State: think, attempting to generate {max_new} tokens."
            if max_new <= 0:
                error_msg += " Note: Calculated max new tokens <= 0."
            if args.model_context_len - current_full_sequence_len <= 0:
                error_msg += " Note: Already exceeded context window."
            if think_limit_tokens - current_generated_thinking_tokens <= 0:
                error_msg += " Note: Reached thinking token limit."
        elif step_type == 'prob_check_gen':
            prob_check_prompt_len = len(tokenizer.encode(state['current_full_sequence'] + answer_prompt_str, add_special_tokens=False))
            error_msg += f" State: prob_check_gen, prompt length: {prob_check_prompt_len}, attempting to generate {args.prob_check_max_tokens} tokens."
            if prob_check_prompt_len + args.prob_check_max_tokens > args.model_context_len:
                error_msg += " Note: Prompt+generation exceeds context window limit."
        elif step_type == 'answer':
            # Approximately mirrors the budget used when the answer request was built.
            # The separator is already part of the thinking history at this point.
            final_answer_prompt = state['formatted_prompt'] + state['generated_thinking_history']
            len_final_answer_prompt = len(tokenizer.encode(final_answer_prompt, add_special_tokens=False))
            remaining_total_budget = args.max_len - current_generated_thinking_tokens
            max_new_answer = min(remaining_total_budget, args.model_context_len - len_final_answer_prompt)
            error_msg += f" State: answer, prompt length: {len_final_answer_prompt}, attempting to generate {max_new_answer} tokens."
            if max_new_answer <= 0:
                error_msg += " Note: Calculated max new tokens <= 0."
            if args.model_context_len - len_final_answer_prompt <= 0:
                error_msg += " Note: Already exceeded context window."
            if remaining_total_budget <= 0:
                error_msg += " Note: Reached total token limit."
        return error_msg

    # Main processing loop: one batched vLLM call per round until no question is active.
    while active_questions_indices:
        batch_prompts = []
        batch_sampling_params = []
        batch_request_info = []  # (question index, step type) per queued request

        # Snapshot the states so an empty batch can be told apart from a stuck loop.
        # A round may queue nothing yet still advance questions, e.g. a question
        # that hit its thinking budget switches to 'needs_answer' without a request.
        states_before = {q: questions_state[q]['state'] for q in active_questions_indices}

        # Iterate over a copy: _fail/_retire remove entries from the active list.
        for q_idx in list(active_questions_indices):
            if len(batch_prompts) >= args.batch_size:
                break

            state = questions_state[q_idx]
            if state['state'] in ('finished', 'error'):
                continue

            prompt_for_batch = None
            sampling_params_for_batch = None
            step_type = None  # 'think', 'prob_check_gen', 'answer' or 'answer_exit'

            try:
                current_full_sequence_len = len(tokenizer.encode(state['current_full_sequence'], add_special_tokens=False))
                current_generated_thinking_tokens = len(tokenizer.encode(state['generated_thinking_history'], add_special_tokens=False))

                # Check the context window before adding any new content.
                remaining_context_window = args.model_context_len - current_full_sequence_len
                if remaining_context_window <= 0:
                    _fail(q_idx, f"Exceeded model context window limit ({args.model_context_len}). Current sequence length: {current_full_sequence_len}")
                    continue

                if state['state'] == 'needs_thought_chunk':
                    max_new_tokens_for_thought = min(
                        think_limit_tokens - current_generated_thinking_tokens,  # thinking budget
                        remaining_context_window,
                    )
                    if max_new_tokens_for_thought <= 0:
                        state['state'] = 'needs_answer'
                        print(f"\nQuestion {q_idx}: Reached thinking limit ({current_generated_thinking_tokens}/{think_limit_tokens}). Switching to answer generation.")
                        state['too_long'] = 1
                        continue

                    prompt_for_batch = state['current_full_sequence']
                    # After args.max_judge_steps probes, stop cutting chunks at continue_str.
                    if state['thinking_steps'] < args.max_judge_steps:
                        stop_strings = generation_stop_tokens
                    else:
                        stop_strings = thinking_end_stop_tokens
                    sampling_params_for_batch = SamplingParams(
                        max_tokens=max_new_tokens_for_thought,
                        temperature=args.temperature,
                        top_p=args.top_p,
                        stop=stop_strings,
                    )
                    step_type = 'think'

                elif state['state'] == 'needs_prob_check':
                    # Probe: ask for a boxed answer now and score its confidence.
                    prompt_for_prob_check = state['current_full_sequence'] + answer_prompt_str
                    prob_check_prompt_len = len(tokenizer.encode(prompt_for_prob_check, add_special_tokens=False))
                    required_space_for_prob_check = prob_check_prompt_len + args.prob_check_max_tokens
                    if required_space_for_prob_check > args.model_context_len:
                        _fail(q_idx, f"Probability check generation prompt exceeds context window ({args.model_context_len}). Estimated length: {required_space_for_prob_check}")
                        continue

                    prompt_for_batch = prompt_for_prob_check
                    # NOTE(review): no temperature is passed, so vLLM's SamplingParams
                    # default applies rather than args.temperature, although the original
                    # comment called this greedy decoding -- confirm which is intended.
                    sampling_params_for_batch = SamplingParams(
                        max_tokens=args.prob_check_max_tokens,
                        stop=pred_prob_stop_tokens,  # only predict the content inside \boxed{}
                        logprobs=1,
                    )
                    step_type = 'prob_check_gen'

                elif state['state'] == 'needs_answer':
                    # Early-exit and budget-exhausted questions are answered the same way:
                    # the thinking history followed by a blank-line separator.
                    state['generated_thinking_history'] += '\n\n\n'
                    final_answer_prompt = state['formatted_prompt'] + state['generated_thinking_history']
                    len_final_answer_prompt = len(tokenizer.encode(final_answer_prompt, add_special_tokens=False))
                    # The total budget counts thinking tokens measured before the separator.
                    max_new_tokens_answer = min(
                        args.max_len - current_generated_thinking_tokens,
                        args.model_context_len - len_final_answer_prompt,
                    )
                    if max_new_tokens_answer <= 0:
                        state['state'] = 'error'
                        state['output_dict'] = {
                            'question': state['question_data']['problem'],
                            'generated_responses': [state['generated_thinking_history'] + "\nSkipped answer generation due to length limit."],
                            'gold_answer': state['question_data']['answer'],
                            'too_long': state['too_long'],
                            'thinking_steps': state['thinking_steps'],
                        }
                        print(f"\nQuestion {q_idx}: Skipped answer generation due to length limit.")
                        _retire(q_idx)
                        continue

                    prompt_for_batch = final_answer_prompt
                    sampling_params_for_batch = SamplingParams(
                        max_tokens=max_new_tokens_answer,
                        temperature=args.temperature,
                        stop=answer_stop_tokens,
                        top_p=args.top_p,
                    )
                    step_type = 'answer'

                elif state['state'] == 'answer_forcing':
                    # Fold the unfinished answer into the history and force a boxed answer.
                    # Clear the answer buffer so it is not duplicated in an error row.
                    state['generated_thinking_history'] += state['generated_answer_history'] + answer_prompt_str
                    state['generated_answer_history'] = ""
                    prompt_for_batch = state['formatted_prompt'] + state['generated_thinking_history']
                    sampling_params_for_batch = SamplingParams(
                        max_tokens=100,
                        temperature=args.temperature,
                        stop=answer_stop_tokens,
                        top_p=args.top_p,
                    )
                    step_type = 'answer_exit'

                if prompt_for_batch is None:
                    # Unknown state. The message is built before _fail overwrites the state.
                    _fail(q_idx, f"Internal error: No prompt generated for state {state['state']}.")
                    continue

                batch_prompts.append(prompt_for_batch)
                batch_sampling_params.append(sampling_params_for_batch)
                batch_request_info.append((q_idx, step_type))

            except Exception as e:
                # The f-string is evaluated before _fail, so it names the real state.
                _fail(q_idx, f"Error preparing batch request for state '{state['state']}': {e}")

        if not batch_prompts:
            if any(questions_state[q]['state'] != before for q, before in states_before.items()):
                # Questions advanced without needing a request; build the next round.
                continue
            pending = [q for q in active_questions_indices if questions_state[q]['state'] not in ('finished', 'error')]
            if not pending:
                print("Warning: Batch generated no requests. All remaining questions are completed or in error state.")
                break
            print("Error: Active questions remain but no batch requests generated. Possible logic error.")
            for q_idx in pending:
                _fail(q_idx, "Processing aborted: Unable to generate request in batch loop.", announce=False)
                print(f"\nQuestion {q_idx}: Marked as error due to processing abort.")
            break

        batch_outputs = llm_engine.generate(batch_prompts, batch_sampling_params, use_tqdm=False)

        # Process the batch outputs; vLLM returns them in request order.
        for (q_idx, step_type), output in zip(batch_request_info, batch_outputs):
            state = questions_state[q_idx]
            if state['state'] in ('finished', 'error'):
                continue

            try:
                if not output.outputs:
                    raise ValueError(_empty_output_message(output, q_idx, step_type))

                completion_output = output.outputs[0]
                generated_text = completion_output.text
                generated_ids = completion_output.token_ids
                # An empty generation has no last token; avoid an IndexError.
                last_token_id = generated_ids[-1] if generated_ids else None

                state['rep_end'] = seq_rep_n(state['generated_thinking_last_trun'], generated_text, state['rep_end'])

                if step_type == 'think':
                    state['generated_thinking_history'] += generated_text
                    state['current_full_sequence'] = state['formatted_prompt'] + state['generated_thinking_history']
                    if state['rep_end'] >= 3 and args.rep == 1:
                        # seq_rep_n's counter reached 3 with --rep enabled: stop thinking.
                        state['state'] = 'needs_answer'
                    elif last_token_id in last_token_ids:
                        state['state'] = 'needs_answer'
                        state['regular_end'] = 1
                    else:
                        state['generated_thinking_last_trun'] = generated_text
                        state['state'] = 'needs_prob_check'
                        state['thinking_steps'] += 1

                elif step_type == 'prob_check_gen':
                    if completion_output.logprobs:
                        state['pred_prob'] = calculate_average_max_prob_from_logprobs(completion_output.logprobs, args.policy)
                    else:
                        print(f"Warning: No logprobs returned for prob_check_gen for question {q_idx}. Setting pred_prob to 0.0.")
                        state['pred_prob'] = 0.0

                    current_generated_thinking_tokens = len(tokenizer.encode(state['generated_thinking_history'], add_special_tokens=False))
                    # Within 50 tokens of the thinking budget counts as exhausted.
                    thinking_limit_reached = current_generated_thinking_tokens >= think_limit_tokens - 50

                    if state['pred_prob'] > args.threshold or thinking_limit_reached:
                        state['state'] = 'needs_answer'
                        if thinking_limit_reached:
                            print(f"\nQuestion {q_idx}: Actually reached thinking limit ({current_generated_thinking_tokens}/{think_limit_tokens}). Switching to answer phase.")
                            state['too_long'] = 1
                        else:
                            print(f"\nQuestion {q_idx}: Reached early exit threshold ({state['pred_prob']:.4f} > {args.threshold}). Switching to answer phase.")
                            state['high_prob'] = 1
                    else:
                        state['state'] = 'needs_thought_chunk'
                        # Re-open thinking with continue_str. This is skipped before any
                        # thinking chunk (thinking_steps == 0, the --no_thinking first probe).
                        if not state['current_full_sequence'].strip().endswith(continue_str) and state['thinking_steps'] != 0:
                            state['current_full_sequence'] += continue_str
                            state['generated_thinking_history'] += continue_str
                        print(f"\nQuestion {q_idx}: Early exit threshold not reached ({state['pred_prob']:.4f} <= {args.threshold}), thinking history length ({current_generated_thinking_tokens}/{think_limit_tokens}). Appending '{continue_str}' and continuing thinking.")

                elif step_type == 'answer':
                    state['generated_answer_history'] += generated_text
                    if last_token_id != tokenizer.eos_token_id and args.af == 1:
                        # Answer was cut off before EOS: force a boxed answer next round.
                        state['state'] = 'answer_forcing'
                    else:
                        state['current_full_sequence'] = state['formatted_prompt'] + state['generated_thinking_history'] + state['generated_answer_history']
                        state['state'] = 'finished'
                        state['output_dict'] = _finished_output(state, state['generated_thinking_history'] + state['generated_answer_history'])
                        _retire(q_idx)

                elif step_type == 'answer_exit':
                    state['current_full_sequence'] = state['formatted_prompt'] + state['generated_thinking_history'] + generated_text
                    state['state'] = 'finished'
                    state['output_dict'] = _finished_output(state, state['generated_thinking_history'] + generated_text)
                    _retire(q_idx)

            except Exception as e:
                print(f"\nError processing batch results for question {q_idx} step '{step_type}': {e}")
                _fail(q_idx, f"Error processing batch results for step '{step_type}': {e}", separator="\nError: ", announce=False)

    pbar.close()
    # questions_state is keyed by dataset position, so iterating the sorted keys keeps
    # the input order even when two questions share the same problem text.
    final_results = [
        questions_state[q]['output_dict']
        for q in sorted(questions_state)
        if questions_state[q]['state'] in ('finished', 'error')
    ]

    print("\nAll questions processed, saving results...")
    try:
        write_jsonl(final_results, output_file)
    except Exception as e:
        print(f"Error saving results to {output_file}: {e}")

    elapsed_time = time.time() - start_time
    print(f"Evaluation completed! Attempted to process {len(questions_json)} questions in total, successfully recorded {len(final_results)} results, took {elapsed_time:.2f} seconds")
    print(f"Results saved to: {output_file}")
704 |
if __name__ == "__main__":
    # Script entry point. Seeding via set_seeds(42) is currently disabled.
    main()
707 | main()
--------------------------------------------------------------------------------