60 |
61 |
62 | ## 1. Introduction
63 |
64 | DeepSeekMath is initialized with [DeepSeek-Coder-v1.5 7B](https://huggingface.co/deepseek-ai/deepseek-coder-7b-base-v1.5) and continues pre-training on math-related tokens sourced from Common Crawl, together with natural language and code data, for 500B tokens. DeepSeekMath 7B achieves an impressive score of **51.7%** on the competition-level MATH benchmark without relying on external toolkits or voting techniques, approaching the performance level of Gemini-Ultra and GPT-4. For research purposes, we release [checkpoints](#4-model-downloads) of the base, instruct, and RL models to the public.
65 |
66 |
67 |
68 |
69 |
70 | ## 2. Evaluation Results
71 |
72 | ### DeepSeekMath-Base 7B
73 |
74 | We conduct a comprehensive assessment of the mathematical capabilities of DeepSeekMath-Base 7B. It focuses on three abilities: producing self-contained mathematical solutions without relying on external tools, solving math problems using tools, and conducting formal theorem proving. Beyond mathematics, we also provide a more general profile of the base model, including its performance in natural language understanding, reasoning, and programming.
75 |
76 | - **Mathematical problem solving with step-by-step reasoning**
77 |
78 |
79 |
80 |
81 |
82 | - **Mathematical problem solving with tool use**
83 |
84 |
85 |
86 |
87 |
88 | - **Natural Language Understanding, Reasoning, and Code**
89 |
90 |
91 |
92 |
93 | The evaluation results from the tables above can be summarized as follows:
94 | - **Superior Mathematical Reasoning:** On the competition-level MATH dataset, DeepSeekMath-Base 7B outperforms existing open-source base models by more than 10% in absolute terms through few-shot chain-of-thought prompting, and also surpasses Minerva 540B.
95 | - **Strong Tool Use Ability:** Continuing pre-training from DeepSeekCoder-Base-7B-v1.5 enables DeepSeekMath-Base 7B to solve and prove mathematical problems more effectively by writing programs.
96 | - **Comparable Reasoning and Coding Performance:** DeepSeekMath-Base 7B achieves performance in reasoning and coding that is comparable to that of DeepSeekCoder-Base-7B-v1.5.
97 |
98 | ### DeepSeekMath-Instruct and -RL 7B
99 |
100 | DeepSeekMath-Instruct 7B is a math instruction-tuned model derived from DeepSeekMath-Base 7B. DeepSeekMath-RL 7B is trained on top of DeepSeekMath-Instruct 7B using our proposed Group Relative Policy Optimization (GRPO) algorithm.
101 |
102 | We evaluate mathematical performance both without and with tool use on 4 quantitative reasoning benchmarks in English and Chinese. As shown in the table, DeepSeekMath-Instruct 7B demonstrates strong step-by-step reasoning performance. DeepSeekMath-RL 7B approaches an accuracy of 60% on MATH with tool use, surpassing all existing open-source models.
103 |
104 |
105 |
106 |
107 |
108 |
109 | ## 3. Data Collection
110 |
111 | - Step 1: Select [OpenWebMath](https://arxiv.org/pdf/2310.06786.pdf), a collection of high-quality mathematical web texts, as our initial seed corpus for training a FastText model.
112 | - Step 2: Use the FastText model to retrieve mathematical web pages from the deduplicated Common Crawl database.
113 | - Step 3: Identify potential math-related domains through statistical analysis.
114 | - Step 4: Manually annotate URLs within these identified domains that are associated with mathematical content.
115 | - Step 5: Add web pages that are linked to these annotated URLs but not yet collected to the seed corpus. Then return to Step 1, repeating for four iterations in total.
116 |
117 |
118 |
119 |
120 |
121 |
122 | After four iterations of data collection, we end up with **35.5M** mathematical web pages, totaling **120B** tokens.
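The five steps above form a loop, which can be sketched in Python as below. This is a minimal illustration under stated assumptions, not the authors' pipeline code. The callables `train_classifier`, `find_math_domains`, `annotate_urls` and `fetch_linked_pages` are hypothetical placeholders. In the real pipeline, Step 1 trains a FastText model and Step 4 is manual annotation. Pages are modeled as plain strings for simplicity.

```python
from typing import Callable, Iterable, Set

def collect_math_corpus(
    seed_corpus: Set[str],
    train_classifier: Callable[[Set[str]], Callable[[str], bool]],  # hypothetical
    crawl_pages: Iterable[str],
    find_math_domains: Callable[[Set[str]], Set[str]],  # hypothetical
    annotate_urls: Callable[[Set[str]], Set[str]],  # hypothetical
    fetch_linked_pages: Callable[[Set[str]], Set[str]],  # hypothetical
    n_iterations: int = 4,
) -> Set[str]:
    """Sketch of the iterative recall loop; returns the last iteration's recall."""
    crawl = list(crawl_pages)
    seed = set(seed_corpus)
    retrieved: Set[str] = set()
    for _ in range(n_iterations):
        # Step 1: train a page classifier on the current seed corpus.
        is_math = train_classifier(seed)
        # Step 2: recall mathematical pages from the (deduplicated) crawl.
        retrieved = {page for page in crawl if is_math(page)}
        # Step 3: identify candidate math-related domains from the recalled pages.
        domains = find_math_domains(retrieved)
        # Step 4: annotate math-related URLs within those domains.
        urls = annotate_urls(domains)
        # Step 5: add linked pages that were not yet collected to the seed corpus.
        seed |= fetch_linked_pages(urls) - retrieved - seed
    return retrieved
```

Passing the stages in as callables keeps the loop structure separate from the concrete classifier and annotation process.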
123 |
124 | ## 4. Model Downloads
125 |
126 | We release DeepSeekMath 7B, including the base, instruct, and RL models, to the public to support a broader and more diverse range of research within both academic and commercial communities. Please **note** that the use of this model is subject to the terms outlined in the [License section](#6-license). Commercial usage is permitted under these terms.
127 |
128 | ### Huggingface
129 |
130 | | Model | Sequence Length | Download |
131 | | :----------------------- | :-------------: | :----------------------------------------------------------: |
132 | | DeepSeekMath-Base 7B | 4096 | 🤗 [HuggingFace](https://huggingface.co/deepseek-ai/deepseek-math-7b-base) |
133 | | DeepSeekMath-Instruct 7B | 4096 | 🤗 [HuggingFace](https://huggingface.co/deepseek-ai/deepseek-math-7b-instruct) |
134 | | DeepSeekMath-RL 7B | 4096 | 🤗 [HuggingFace](https://huggingface.co/deepseek-ai/deepseek-math-7b-rl) |
135 |
136 | ## 5. Quick Start
137 |
138 | You can directly employ [Huggingface's Transformers](https://github.com/huggingface/transformers) for model inference.
139 |
140 | **Text Completion**
141 |
142 | ```python
143 | import torch
144 | from transformers import AutoTokenizer, AutoModelForCausalLM, GenerationConfig
145 |
146 | model_name = "deepseek-ai/deepseek-math-7b-base"
147 | tokenizer = AutoTokenizer.from_pretrained(model_name)
148 | model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.bfloat16, device_map="auto")
149 | model.generation_config = GenerationConfig.from_pretrained(model_name)
150 | model.generation_config.pad_token_id = model.generation_config.eos_token_id
151 |
152 | text = "The integral of x^2 from 0 to 2 is"
153 | inputs = tokenizer(text, return_tensors="pt")
154 | outputs = model.generate(**inputs.to(model.device), max_new_tokens=100)
155 |
156 | result = tokenizer.decode(outputs[0], skip_special_tokens=True)
157 | print(result)
158 | ```
159 |
160 | **Chat Completion**
161 |
162 | ```python
163 | import torch
164 | from transformers import AutoTokenizer, AutoModelForCausalLM, GenerationConfig
165 |
166 | model_name = "deepseek-ai/deepseek-math-7b-instruct"
167 | tokenizer = AutoTokenizer.from_pretrained(model_name)
168 | model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.bfloat16, device_map="auto")
169 | model.generation_config = GenerationConfig.from_pretrained(model_name)
170 | model.generation_config.pad_token_id = model.generation_config.eos_token_id
171 |
172 | messages = [
173 | {"role": "user", "content": "what is the integral of x^2 from 0 to 2?\nPlease reason step by step, and put your final answer within \\boxed{}."}
174 | ]
175 | input_tensor = tokenizer.apply_chat_template(messages, add_generation_prompt=True, return_tensors="pt")
176 | outputs = model.generate(input_tensor.to(model.device), max_new_tokens=100)
177 |
178 | result = tokenizer.decode(outputs[0][input_tensor.shape[1]:], skip_special_tokens=True)
179 | print(result)
180 | ```
181 |
182 | Instead of using the provided `apply_chat_template` function, you can also build the input manually by following the sample template below. Note that `messages` should be replaced by your input.
183 |
184 | ```
185 | User: {messages[0]['content']}
186 |
187 | Assistant: {messages[1]['content']}<|end▁of▁sentence|>User: {messages[2]['content']}
188 |
189 | Assistant:
190 | ```
191 |
192 | **Note:** By default (`add_special_tokens=True`), our tokenizer automatically adds a `bos_token` (`<|begin▁of▁sentence|>`) before the input text. Additionally, since the system prompt is not compatible with this version of our models, we DO NOT RECOMMEND including the system prompt in your input.
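The hypothetical `build_prompt` helper below renders a `messages` list into the sample template above. It is a minimal sketch, not an official utility, and it assumes that turns alternate and start with a user turn. In line with the note above, it rejects `system` messages and does not add `<|begin▁of▁sentence|>`, because the tokenizer prepends that token by default.

```python
from typing import Dict, List

# End-of-sentence token string taken from the sample template above.
EOS = "<|end▁of▁sentence|>"

def build_prompt(messages: List[Dict[str, str]]) -> str:
    """Render user/assistant turns into the plain-text chat template.

    Hypothetical helper: assumes alternating turns starting with a user turn.
    The tokenizer adds the BOS token itself, so it is not added here.
    """
    parts = []
    for msg in messages:
        role, content = msg["role"], msg["content"]
        if role == "user":
            parts.append(f"User: {content}\n\n")
        elif role == "assistant":
            # A completed assistant turn is closed with the EOS token.
            parts.append(f"Assistant: {content}{EOS}")
        else:
            # System prompts are not recommended for these models.
            raise ValueError(f"unsupported role: {role!r}")
    # Leave the final assistant turn open for the model to complete.
    parts.append("Assistant:")
    return "".join(parts)
```

The resulting string can then be tokenized as in the text-completion example, e.g. `tokenizer(build_prompt(messages), return_tensors="pt")`.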
193 |
194 | ❗❗❗ **Please use the following chain-of-thought prompts to test DeepSeekMath-Instruct and DeepSeekMath-RL:**
195 |
196 | - English questions: **{question}\nPlease reason step by step, and put your final answer within \\boxed{}.**
197 |
198 | - Chinese questions: **{question}\n请通过逐步推理来解答问题,并把最终答案放置于\\boxed{}中。**
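Both prompts ask the model to put its final answer in `\boxed{}`. The answer can therefore be recovered by matching braces after the last `\boxed{`. Below is a minimal sketch. `last_boxed_answer` is a hypothetical helper and is simpler than the extraction logic shipped in `evaluation/data_processing/answer_extraction.py`.

```python
from typing import Optional

def last_boxed_answer(text: str) -> Optional[str]:
    """Return the contents of the last \\boxed{...} in text, or None."""
    marker = "\\boxed{"
    start = text.rfind(marker)
    if start == -1:
        return None
    begin = start + len(marker)
    depth = 1  # already inside the opening brace of \boxed{
    for pos in range(begin, len(text)):
        if text[pos] == "{":
            depth += 1
        elif text[pos] == "}":
            depth -= 1
            if depth == 0:
                return text[begin:pos]
    return None  # unbalanced braces: no complete answer
```

Consider the integral question from the chat example. A response ending in `$\boxed{\frac{8}{3}}$` yields `\frac{8}{3}`, because nested braces are tracked rather than stopping at the first `}`.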
199 |
200 |
201 | ## 6. License
202 | This code repository is licensed under the MIT License. The use of DeepSeekMath models is subject to the Model License. DeepSeekMath supports commercial use.
203 |
204 | See the [LICENSE-CODE](LICENSE-CODE) and [LICENSE-MODEL](LICENSE-MODEL) for more details.
205 |
206 | ## 7. Citation
207 |
208 | ```
209 | @misc{deepseek-math,
210 | author = {Zhihong Shao and Peiyi Wang and Qihao Zhu and Runxin Xu and Junxiao Song and Mingchuan Zhang and Y.K. Li and Y. Wu and Daya Guo},
211 | title = {DeepSeekMath: Pushing the Limits of Mathematical Reasoning in Open Language Models},
212 | journal = {CoRR},
213 | volume = {abs/2402.03300},
214 | year = {2024},
215 | url = {https://arxiv.org/abs/2402.03300},
216 | }
217 | ```
218 |
219 |
220 | ## 8. Contact
221 |
222 | If you have any questions, please raise an issue or contact us at [service@deepseek.com](mailto:service@deepseek.com).
223 |
--------------------------------------------------------------------------------
/cog.yaml:
--------------------------------------------------------------------------------
1 | # Configuration for Cog ⚙️
2 | # Reference: https://github.com/replicate/cog/blob/main/docs/yaml.md
3 |
4 | build:
5 | gpu: true
6 | python_version: "3.11"
7 | python_packages:
8 | - torch==2.0.1
9 | - torchvision==0.15.2
10 | - transformers==4.37.2
11 | - accelerate==0.27.0
12 | - hf_transfer
13 |
14 | # predict.py defines how predictions are run on your model
15 | predict: "replicate/predict.py:Predictor"
16 |
--------------------------------------------------------------------------------
/evaluation/README.md:
--------------------------------------------------------------------------------
1 | ## 1. Introduction
2 |
3 | We provide a test script for both zero-shot and few-shot evaluation on the mathematical reasoning benchmarks used in our paper.
4 |
5 | ## 2. Setup
6 |
7 | First, configure the `prefix` in `environment.yml`, then run the following command:
8 | ```
9 | conda env create -f environment.yml
10 | ```
11 |
12 | ## 3. Evaluation
13 |
14 | For chain-of-thought evaluation of DeepSeekMath-Instruct and DeepSeekMath-RL, our script (see `def markup_question()` in `run_subset_parallel.py`) processes each question as follows:
15 | * English questions: `{question}\nPlease reason step by step, and put your final answer within \\boxed{}.`
16 | * Chinese questions: `{question}\n请通过逐步推理来解答问题,并把最终答案放置于\\boxed{}中。`
17 |
18 | For tool-integrated reasoning, we process each question as follows:
19 | * English questions: `{question}\nPlease integrate natural language reasoning with programs to solve the problem above, and put your final answer within \\boxed{}.`
20 | * Chinese questions: `{question}\n请结合自然语言和Python程序语言来解答问题,并把最终答案放置于\\boxed{}中。`
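The question markup described above can be sketched as follows. `markup_question_sketch` is a hypothetical stand-in, not the repository's `markup_question()`. It assumes the `language` values (`en`, `zh`) and task names (`cot`, `tool`) used in `configs/zero_shot_test_configs.json`. It also assumes that the `\\` in the prompts above is a Python string escape, so each prompt contains a single backslash.

```python
# Suffixes copied from the prompt formats listed above, keyed by (task, language).
SUFFIXES = {
    ("cot", "en"): "\nPlease reason step by step, and put your final answer within \\boxed{}.",
    ("cot", "zh"): "\n请通过逐步推理来解答问题,并把最终答案放置于\\boxed{}中。",
    ("tool", "en"): "\nPlease integrate natural language reasoning with programs to solve the problem above, and put your final answer within \\boxed{}.",
    ("tool", "zh"): "\n请结合自然语言和Python程序语言来解答问题,并把最终答案放置于\\boxed{}中。",
}

def markup_question_sketch(question: str, language: str, task: str) -> str:
    """Append the task- and language-specific instruction to a question."""
    try:
        suffix = SUFFIXES[(task, language)]
    except KeyError:
        raise ValueError(f"unsupported task/language: {task!r}/{language!r}") from None
    return question + suffix
```

Failing loudly on an unknown pair avoids silently sending an unformatted question to the model.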
21 |
22 | We provide an example of testing DeepSeekMath-Base 7B on 8 GPUs.
23 |
24 | If you wish to use a different model or dataset, you can modify the configs in `submit_eval_jobs.py` and `configs/*test_configs.json`.
25 |
26 | ```
27 | python submit_eval_jobs.py --n-gpus 8
28 | ```
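Each entry in `configs/*test_configs.json` maps a dataset name to a set of fields. The shipped entries use `test_path`, `language`, `tasks`, `process_fn`, `answer_extraction_fn` and `eval_fn`, and few-shot entries also include `few_shot_prompt`. Before adding a dataset, a quick check such as the hypothetical `find_config_problems` below can catch missing fields. It accepts only the task and language values that appear in the shipped configs, although the scripts may support others. It does not verify that the named functions exist.

```python
from typing import Any, Dict, List

# Fields present in every entry of the shipped zero-shot and few-shot configs.
REQUIRED_KEYS = {"test_path", "language", "tasks", "process_fn", "answer_extraction_fn", "eval_fn"}
# Values observed in the shipped configs (an assumption, not an exhaustive list).
KNOWN_TASKS = {"cot", "pal", "tool"}
KNOWN_LANGUAGES = {"en", "zh"}

def find_config_problems(configs: Dict[str, Dict[str, Any]], few_shot: bool = False) -> List[str]:
    """Return human-readable problems found in a parsed config mapping."""
    required = REQUIRED_KEYS | ({"few_shot_prompt"} if few_shot else set())
    problems = []
    for name, entry in configs.items():
        missing = required - set(entry)
        if missing:
            problems.append(f"{name}: missing {sorted(missing)}")
        unknown = set(entry.get("tasks", [])) - KNOWN_TASKS
        if unknown:
            problems.append(f"{name}: unknown tasks {sorted(unknown)}")
        if entry.get("language") not in KNOWN_LANGUAGES:
            problems.append(f"{name}: unexpected language {entry.get('language')!r}")
    return problems
```

For example, `find_config_problems(json.load(open("configs/few_shot_test_configs.json")), few_shot=True)` should return an empty list for the shipped few-shot config.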
29 |
30 | Wait for all processes to finish, then run the following command to aggregate the results from all processes:
31 |
32 | ```
33 | python summarize_results.py [--eval-atp]
34 | ```
35 | where the option `--eval-atp` will invoke `unsafe_score_minif2f_isabelle.py` to evaluate the informal-to-formal proving results. Please make sure you have set up the [PISA](https://github.com/wellecks/lm-evaluation-harness/blob/minif2f-isabelle/docs/isabelle_setup.md) server before using this option.
36 |
37 | A summary of all evaluation results will be saved as `evaluation_results.json`.
38 |
39 | ## 4. Model Outputs
40 |
41 | We provide all model outputs in `outputs.zip`.
42 |
--------------------------------------------------------------------------------
/evaluation/configs/few_shot_test_configs.json:
--------------------------------------------------------------------------------
1 | {
2 | "gsm8k-cot-test": {
3 | "test_path": "datasets/gsm8k/test.jsonl",
4 | "language": "en",
5 | "tasks": ["cot"],
6 | "process_fn": "process_gsm8k_test",
7 | "answer_extraction_fn": "extract_gsm_few_shot_cot_answer",
8 | "eval_fn": "eval_last_single_answer",
9 | "few_shot_prompt": "CoTGSMPrompt"
10 | },
11 | "gsm8k-pal-test": {
12 | "test_path": "datasets/gsm8k/test.jsonl",
13 | "language": "en",
14 | "tasks": ["pal"],
15 | "process_fn": "process_gsm8k_test",
16 | "answer_extraction_fn": "placeholder",
17 | "eval_fn": "eval_last_single_answer",
18 | "few_shot_prompt": "PALGSMPrompt"
19 | },
20 | "math-cot-test": {
21 | "test_path": "datasets/math/test.jsonl",
22 | "language": "en",
23 | "tasks": ["cot"],
24 | "process_fn": "process_math_test",
25 | "answer_extraction_fn": "extract_math_few_shot_cot_answer",
26 | "eval_fn": "eval_math",
27 | "few_shot_prompt": "MinervaMathPrompt"
28 | },
29 | "math-pal-test": {
30 | "test_path": "datasets/math/test.jsonl",
31 | "language": "en",
32 | "tasks": ["pal"],
33 | "process_fn": "process_math_test",
34 | "answer_extraction_fn": "placeholder",
35 | "eval_fn": "eval_math",
36 | "few_shot_prompt": "PALMathPrompt"
37 | },
38 | "math_sat": {
39 | "test_path": "datasets/sat/test.jsonl",
40 | "language": "en",
41 | "tasks": ["cot"],
42 | "process_fn": "process_math_sat",
43 | "answer_extraction_fn": "extract_sat_few_shot_answer",
44 | "eval_fn": "eval_math_sat",
45 | "few_shot_prompt": "CoTSATPrompt"
46 | },
47 | "OCWCourses": {
48 | "test_path": "datasets/ocw/test.jsonl",
49 | "language": "en",
50 | "tasks": ["cot"],
51 | "process_fn": "process_ocwcourses",
52 | "answer_extraction_fn": "extract_ocwcourses_few_shot_answer",
53 | "eval_fn": "eval_ocwcourses",
54 | "few_shot_prompt": "OCWCoursesPrompt"
55 | },
56 | "MMLU-STEM-test": {
57 | "test_path": "datasets/mmlu_stem/test.jsonl",
58 | "language": "en",
59 | "tasks": ["cot"],
60 | "process_fn": "process_mmlu_stem",
61 | "answer_extraction_fn": "extract_mmlu_stem",
62 | "eval_fn": "eval_mmlu_stem",
63 | "few_shot_prompt": "MMLUSTEMPrompt"
64 | },
65 | "miniF2F-Isabelle-valid": {
66 | "test_path": "datasets/minif2f/validation.jsonl",
67 | "language": "en",
68 | "tasks": ["cot"],
69 | "process_fn": "process_minif2f_isabelle",
70 | "answer_extraction_fn": "extract_minif2f_isabelle",
71 | "eval_fn": "eval_minif2f_isabelle",
72 | "few_shot_prompt": "MiniF2FIsabellePrompt"
73 | },
74 | "miniF2F-Isabelle-test": {
75 | "test_path": "datasets/minif2f/test.jsonl",
76 | "language": "en",
77 | "tasks": ["cot"],
78 | "process_fn": "process_minif2f_isabelle",
79 | "answer_extraction_fn": "extract_minif2f_isabelle",
80 | "eval_fn": "eval_minif2f_isabelle",
81 | "few_shot_prompt": "MiniF2FIsabellePrompt"
82 | },
83 | "cmath-cot-test": {
84 | "test_path": "datasets/cmath/test.jsonl",
85 | "language": "zh",
86 | "tasks": ["cot"],
87 | "process_fn": "process_cmath",
88 | "answer_extraction_fn": "extract_cmath_few_shot_test",
89 | "eval_fn": "eval_last_single_answer",
90 | "few_shot_prompt": "CoTCMATHPrompt"
91 | },
92 | "agieval-gaokao-mathcloze-cot-test": {
93 | "test_path": "datasets/agieval/gaokao-mathcloze.jsonl",
94 | "language": "zh",
95 | "tasks": ["cot"],
96 | "process_fn": "process_agieval_gaokao_math_cloze",
97 | "answer_extraction_fn": "extract_agieval_gaokao_mathcloze_few_shot_cot_test",
98 | "eval_fn": "eval_agieval_gaokao_math_cloze",
99 | "few_shot_prompt": "CoTGaoKaoMathClozePrompt"
100 | },
101 | "agieval-gaokao-mathqa-cot-test": {
102 | "test_path": "datasets/agieval/gaokao-mathqa.jsonl",
103 | "language": "zh",
104 | "tasks": ["cot"],
105 | "process_fn": "process_agieval_gaokao_mathqa_few_shot_cot_test",
106 | "answer_extraction_fn": "extract_agieval_gaokao_mathqa_few_shot_cot_test",
107 | "eval_fn": "eval_agieval_gaokao_mathqa",
108 | "few_shot_prompt": "CoTGaoKaoMathQAPrompt"
109 | }
110 | }
--------------------------------------------------------------------------------
/evaluation/configs/zero_shot_test_configs.json:
--------------------------------------------------------------------------------
1 | {
2 | "gsm8k-test": {
3 | "test_path": "datasets/gsm8k/test.jsonl",
4 | "language": "en",
5 | "tasks": ["tool", "cot"],
6 | "process_fn": "process_gsm8k_test",
7 | "answer_extraction_fn": "extract_last_single_answer",
8 | "eval_fn": "eval_last_single_answer"
9 | },
10 | "math-test": {
11 | "test_path": "datasets/math/test.jsonl",
12 | "language": "en",
13 | "tasks": ["tool", "cot"],
14 | "process_fn": "process_math_test",
15 | "answer_extraction_fn": "extract_math_answer",
16 | "eval_fn": "eval_math"
17 | },
18 | "mgsm-zh": {
19 | "test_path": "datasets/mgsm_zh/mgsm_zh.jsonl",
20 | "language": "zh",
21 | "tasks": ["tool", "cot"],
22 | "process_fn": "process_mgsm_zh",
23 | "answer_extraction_fn": "extract_last_single_answer",
24 | "eval_fn": "eval_last_single_answer"
25 | },
26 | "cmath": {
27 | "test_path": "datasets/cmath/test.jsonl",
28 | "language": "zh",
29 | "tasks": ["tool", "cot"],
30 | "process_fn": "process_cmath",
31 | "answer_extraction_fn": "extract_last_single_answer",
32 | "eval_fn": "eval_last_single_answer"
33 | }
34 | }
--------------------------------------------------------------------------------
/evaluation/data_processing/answer_extraction.py:
--------------------------------------------------------------------------------
1 | import re
2 | import regex
3 |
4 | def _fix_fracs(string):
5 | substrs = string.split("\\frac")
6 | new_str = substrs[0]
7 | if len(substrs) > 1:
8 | substrs = substrs[1:]
9 | for substr in substrs:
10 | new_str += "\\frac"
11 | if len(substr) > 0 and substr[0] == "{":
12 | new_str += substr
13 | else:
14 | try:
15 | assert len(substr) >= 2
16 | except:
17 | return string
18 | a = substr[0]
19 | b = substr[1]
20 | if b != "{":
21 | if len(substr) > 2:
22 | post_substr = substr[2:]
23 | new_str += "{" + a + "}{" + b + "}" + post_substr
24 | else:
25 | new_str += "{" + a + "}{" + b + "}"
26 | else:
27 | if len(substr) > 2:
28 | post_substr = substr[2:]
29 | new_str += "{" + a + "}" + b + post_substr
30 | else:
31 | new_str += "{" + a + "}" + b
32 | string = new_str
33 | return string
34 |
35 |
36 | def _fix_a_slash_b(string):
37 | if len(string.split("/")) != 2:
38 | return string
39 | a = string.split("/")[0]
40 | b = string.split("/")[1]
41 | try:
42 | if "sqrt" not in a:
43 | a = int(a)
44 | if "sqrt" not in b:
45 | b = int(b)
46 | assert string == "{}/{}".format(a, b)
47 | new_string = "\\frac{" + str(a) + "}{" + str(b) + "}"
48 | return new_string
49 | except:
50 | return string
51 |
52 |
53 | def _fix_sqrt(string):
54 | _string = re.sub(r"\\sqrt(-?[0-9.a-zA-Z]+)", r"\\sqrt{\1}", string)
55 | _string = re.sub(r"\\sqrt\s+(\w+)$", r"\\sqrt{\1}", _string)
56 | return _string
57 |
58 |
59 | def _fix_tan(string):
60 | _string = re.sub(r"\\tan(-?[0-9.a-zA-Z]+)", r"\\tan{\1}", string)
61 | _string = re.sub(r"\\tan\s+(\w+)$", r"\\tan{\1}", _string)
62 | return _string
63 |
64 |
65 | def strip_string(string):
66 | string = str(string).strip()
67 | # linebreaks
68 | string = string.replace("\n", "")
69 |
70 | # right "."
71 | string = string.rstrip(".")
72 |
73 | # remove inverse spaces
74 | string = string.replace("\\!", "")
75 | # string = string.replace("\\ ", "")
76 |
77 | # replace \\ with \
78 | # string = string.replace("\\\\", "\\")
79 | # string = string.replace("\\\\", "\\")
80 |
81 | if string.startswith("\\text{") and string.endswith("}"):
82 | string = string.split("{", 1)[1][:-1]
83 |
84 | # replace tfrac and dfrac with frac
85 | string = string.replace("tfrac", "frac")
86 | string = string.replace("dfrac", "frac")
87 | string = string.replace("cfrac", "frac")
88 |
89 | # remove \left and \right
90 | string = string.replace("\\left", "")
91 | string = string.replace("\\right", "")
92 |
93 | # Remove unit: miles, dollars if after is not none
94 | _string = re.sub(r"\\text{.*?}$", "", string).strip()
95 | if _string != "" and _string != string:
96 | # print("Warning: unit not removed: '{}' -> '{}'".format(string, _string))
97 | string = _string
98 |
99 | # Remove circ (degrees)
100 | string = string.replace("^{\\circ}", "").strip()
101 | string = string.replace("^\\circ", "").strip()
102 |
103 | string = regex.sub(r"\{(c|m)?m\}(\^(2|3))?", "", string).strip()
104 | string = regex.sub(r"p\.m\.$", "", string).strip()
105 | string = regex.sub(r"(\d)\s*t$", r"\1", string).strip()
106 |
107 | # remove dollar signs
108 | string = string.replace("\\$", "")
109 | string = string.replace("$", "")
110 |
111 | # string = string.replace("\\text", "")
112 | string = string.replace("x\\in", "")
113 |
114 | # remove percentage
115 | string = string.replace("\\%", "%")
116 | string = string.replace("\%", "%")
117 | # string = string.replace("%", "")
118 |
119 | # " 0." equivalent to " ." and "{0." equivalent to "{." Alternatively, add "0" if "." is the start of the string
120 | string = string.replace(" .", " 0.")
121 | string = string.replace("{.", "{0.")
122 |
123 | # cdot
124 | string = string.replace("\\cdot", "")
125 |
126 | # inf
127 | string = string.replace("infinity", "\\infty")
128 | if "\\infty" not in string:
129 | string = string.replace("inf", "\\infty")
130 | string = string.replace("+\\infty", "\\infty")
131 |
132 | # and
133 | # string = string.replace("and", "")
134 | string = string.replace("\\mathbf", "")
135 | string = string.replace("\\mathrm", "")
136 |
137 | # use regex to remove \mbox{...}
138 | string = re.sub(r"\\mbox{.*?}", "", string)
139 |
140 | # quote
141 | string = string.replace("'", "")
142 | string = string.replace("\"", "")
143 |
144 | # i, j
145 | if "j" in string and "i" not in string:
146 | string = string.replace("j", "i")
147 |
148 | # replace a.000b where b is not number or b is end, with ab, use regex
149 | string = re.sub(r"(\d+)\.0+([^\d])", r"\1\2", string)
150 | string = re.sub(r"(\d+)\.0+$", r"\1", string)
151 |
152 | # if empty, return empty string
153 | if len(string) == 0:
154 | return string
155 | if string[0] == ".":
156 | string = "0" + string
157 |
158 | # to consider: get rid of e.g. "k = " or "q = " at beginning
159 | # if len(string.split("=")) == 2:
160 | # if len(string.split("=")[0]) <= 2:
161 | # string = string.split("=")[1]
162 |
163 | string = _fix_sqrt(string)
164 | string = _fix_tan(string)
165 | string = string.replace(" ", "")
166 |
167 | # \frac1b or \frac12 --> \frac{1}{b} and \frac{1}{2}, etc. Even works with \frac1{72} (but not \frac{72}1). Also does a/b --> \\frac{a}{b}
168 | string = _fix_fracs(string)
169 |
170 | # NOTE: X/Y changed to \frac{X}{Y} in dataset, but in simple cases fix in case the model output is X/Y
171 | string = _fix_a_slash_b(string)
172 |
173 | string = regex.sub(r"(\\|,|\.)+$", "", string)
174 |
175 | return string
176 |
177 | def extract_boxed_answers(text):
178 | answers = []
179 | for piece in text.split('boxed{')[1:]:
180 | n = 0
181 | for i in range(len(piece)):
182 | if piece[i] == '{':
183 | n += 1
184 | elif piece[i] == '}':
185 | n -= 1
186 | if n < 0:
187 | if i + 1 < len(piece) and piece[i + 1] == '%':
188 | answers.append(piece[: i + 1])
189 | else:
190 | answers.append(piece[:i])
191 | break
192 | return answers
193 |
194 | def extract_program_output(pred_str):
195 | """
196 | extract output between the last ```output\n...\n```
197 | """
198 | if "```output" not in pred_str:
199 | return ""
200 | if '```output' in pred_str:
201 | pred_str = pred_str.split('```output')[-1]
202 | if '```' in pred_str:
203 | pred_str = pred_str.split('```')[0]
204 | output = pred_str.strip()
205 | return output
206 |
207 | def extract_answer(pred_str, exhaust=False):
208 | pred = []
209 | if 'final answer is $' in pred_str and '$. I hope' in pred_str:
210 | tmp = pred_str.split('final answer is $', 1)[1]
211 | pred = [tmp.split('$. I hope', 1)[0].strip()]
212 | elif 'boxed' in pred_str:
213 | pred = extract_boxed_answers(pred_str)
214 | elif ('he answer is' in pred_str):
215 | pred = [pred_str.split('he answer is')[-1].strip()]
216 | else:
217 | program_output = extract_program_output(pred_str)
218 | if program_output != "":
219 | # fall back to program
220 | pred.append(program_output)
221 | else: # use the last number
222 | pattern = r'-?\d*\.?\d+'
223 | ans = re.findall(pattern, pred_str.replace(",", ""))
224 | if(len(ans) >= 1):
225 | ans = ans[-1]
226 | else:
227 | ans = ''
228 | if ans:
229 | pred.append(ans)
230 |
231 | # multiple line
232 | _pred = []
233 | for ans in pred:
234 | ans = ans.strip().split("\n")[0]
235 | ans = ans.lstrip(":")
236 | ans = ans.rstrip(".")
237 | ans = ans.rstrip("/")
238 | ans = strip_string(ans)
239 | _pred.append(ans)
240 | if exhaust:
241 | return _pred
242 | else:
243 | return _pred[-1] if _pred else ""
244 |
245 | def extract_math_answer(question, reasoning, task):
246 | answer = []
247 | for ans in extract_answer(reasoning, exhaust=True):
248 | if 'separated by commas' in question and all(ch not in ans for ch in '()[]'):
249 | answer.extend([a.strip() for a in ans.split(",")])
250 | elif regex.search(r"\\text\{\s*and\s*\}", ans):
251 | answer.extend([a.strip() for a in regex.sub(r"\\text\{\s*and\s*\}", "[SEP]", ans).split("[SEP]")])
252 | else:
253 | answer.append(ans.strip())
254 | return answer
255 |
256 | def extract_math_few_shot_cot_answer(question, reasoning, task):
257 | if 'Problem:' in reasoning:
258 | reasoning = reasoning.split("Problem:", 1)[0]
259 | return extract_math_answer(question, reasoning, task)
260 |
261 | def extract_last_single_answer(question, reasoning, task):
262 | return extract_answer(reasoning, exhaust=False)
263 |
264 | def extract_gsm_few_shot_cot_answer(question, reasoning, task):
265 | if 'Q: ' in reasoning:
266 | reasoning = reasoning.split("Q: ", 1)[0]
267 | pred = [s for s in regex.findall(r'-?\d+\.?\d*', reasoning)]
268 | if pred:
269 | return pred[-1]
270 | else:
271 | return "[invalid]"
272 |
273 | def extract_agieval_gaokao_mathcloze_few_shot_cot_test(question, reasoning, task):
274 | if '问题 ' in reasoning:
275 | reasoning = reasoning.split("问题 ", 1)[0]
276 | if '答案是' in reasoning:
277 | ans = reasoning.split('答案是', 1)[1].strip()
278 | ans = ans.split("\n")[0].strip()
279 | ans = [ans.strip("$")]
280 | else:
281 | ans = ['placeholder']
282 | return ans
283 |
284 | def extract_agieval_gaokao_mathqa_few_shot_cot_test(question, reasoning, task):
285 | if '问题 ' in reasoning:
286 | reasoning = reasoning.split("问题 ", 1)[0]
287 | if '答案是' in reasoning:
288 | ans = reasoning.split('答案是', 1)[1].strip()
289 | ans = ans.split("\n")[0].strip()
290 | else:
291 | ans = 'placeholder'
292 | return ans
293 |
294 | def extract_sat_few_shot_answer(question, reasoning, task):
295 | if 'Problem:' in reasoning:
296 | reasoning = reasoning.split("Problem:", 1)[0]
297 | patt = regex.search(r"the final answer is \(?(?P<ans>[abcd])\)?", reasoning.lower())
298 | if patt is not None:
299 | return patt.group('ans').upper()
300 | return 'placeholder'
301 |
302 | def extract_ocwcourses_few_shot_answer(question, reasoning, task):
303 | if 'Problem:' in reasoning:
304 | reasoning = reasoning.split("Problem:", 1)[0]
305 | patt = regex.search(r"final answer is (?P<ans>.*)\. I hope it is correct.", reasoning)
306 | if patt is None:
307 | pred = "[invalid]"
308 | print(f"DEBUG >>>\n{reasoning}", flush=True)
309 | else:
310 | pred = patt.group('ans')
311 | return pred
312 |
313 | def extract_mmlu_stem(question, reasoning, task):
314 | if 'Problem:' in reasoning:
315 | reasoning = reasoning.split("Problem:", 1)[0]
316 | return extract_sat_few_shot_answer(question, reasoning, task)
317 |
318 | def extract_minif2f_isabelle(question, reasoning, task):
319 | if 'Informal:' in reasoning:
320 | reasoning = reasoning.split("Informal:", 1)[0]
321 | return reasoning.strip()
322 |
323 | def extract_cmath_few_shot_test(question, reasoning, task):
324 | if '问题:' in reasoning:
325 | reasoning = reasoning.split("问题:", 1)[0]
326 | if '答案是' in reasoning:
327 | ans = reasoning.split('答案是', 1)[1].strip()
328 | ans = ans.split("\n")[0]
329 | ans = ans.strip(":")
330 | ans = ans.strip("。")
331 | try:
332 | ans = [s for s in regex.findall(r'-?\d+\.?\d*', ans)][-1]
333 | except:
334 | print(f"DEBUG CMATH: {reasoning}", flush=True)
335 | ans = "[invalid]"
336 | else:
337 | ans = extract_last_single_answer(question, reasoning, task)
338 | return ans
339 |
--------------------------------------------------------------------------------
/evaluation/data_processing/process_utils.py:
--------------------------------------------------------------------------------
1 | import regex
2 |
3 | from data_processing.answer_extraction import extract_math_answer, strip_string
4 |
5 | def process_gsm8k_test(item):
6 | sample = {
7 | 'dataset': 'gsm8k-cot',
8 | 'id': item['id'],
9 | 'messages': [
10 | {'role': 'user', 'content': item['question']},
11 | {'role': 'assistant', 'content': regex.sub(r"<<[^<>]*>>", "", item['cot']) + "\nSo the answer is $\\boxed{" + item['answer'].strip() + "}$."}
12 | ],
13 | 'answer': item['answer'].replace(',', '')
14 | }
15 | yield sample
16 |
17 | def process_math_test(item):
18 | question = item["problem"]
19 | try:
20 | answer = extract_math_answer(question, item['solution'], task="cot")
21 | except:
22 | return
23 | sample = {
24 | "dataset": "math-cot",
25 | "id": item['id'],
26 | "level": item["level"],
27 | "type": item["type"],
28 | "category": item["category"],
29 | "messages": [
30 | {"role": "user", "content": question},
31 | {"role": "assistant", "content": "\n".join(regex.split(r"(?<=\.) (?=[A-Z])", item["solution"]))}
32 | ],
33 | "answer": answer
34 | }
35 | yield sample
36 |
37 | def process_math_sat(item):
38 | options = item['options'].strip()
39 | assert 'A' == options[0]
40 | options = '(' + options
41 | for ch in 'BCDEFG':
42 | if f' {ch}) ' in options:
43 | options = regex.sub(f' {ch}\) ', f" ({ch}) ", options)
44 | question = f"{item['question'].strip()}\nWhat of the following is the right choice? Explain your answer.\n{options.strip()}"
45 | messages = [
46 | {'role': 'user', 'content': question},
47 | {'role': 'assistant', 'content': item['Answer']}
48 | ]
49 | item = {
50 | 'dataset': 'math_sat',
51 | 'id': item['id'],
52 | 'language': 'en',
53 | 'messages': messages,
54 | 'answer': item['Answer'],
55 | }
56 | yield item
57 |
58 | def process_ocwcourses(item):
59 | messages = [
60 | {'role': 'user', 'content': item['problem'].strip()},
61 | {'role': 'assistant', 'content': item['solution'].strip()}
62 | ]
63 | item = {
64 | "dataset": "OCWCourses",
65 | "id": item['id'],
66 | "language": "en",
67 | "messages": messages,
68 | "answer": item['answer']
69 | }
70 | yield item
71 |
72 | def process_mmlu_stem(item):
73 | options = item['options']
74 | for i, (label, option) in enumerate(zip('ABCD', options)):
75 | options[i] = f"({label}) {str(option).strip()}"
76 | options = ", ".join(options)
77 | question = f"{item['question'].strip()}\nWhat of the following is the right choice? Explain your answer.\n{options}"
78 | messages = [
79 | {'role': 'user', 'content': question},
80 | {'role': 'assistant', 'content': item['answer']}
81 | ]
82 | item = {
83 | "dataset": "MMLU-STEM",
84 | "id": item['id'],
85 | "language": "en",
86 | "messages": messages,
87 | "answer": item['answer']
88 | }
89 | yield item
90 |
91 | def process_mgsm_zh(item):
92 | item['answer'] = item['answer'].replace(',', '')
93 | yield item
94 |
95 | def process_cmath(item):
96 | item = {
97 | 'dataset': 'cmath',
98 | 'id': item['id'],
99 | 'grade': item['grade'],
100 | 'reasoning_step': item['reasoning_step'],
101 | 'messages': [
102 | {'role': 'user', 'content': item['question'].strip()},
103 | {'role': 'assistant', 'content': ''}
104 | ],
105 | 'answer': item['golden'].strip().replace(",", "")
106 | }
107 | yield item
108 |
109 | def process_agieval_gaokao_math_cloze(item):
110 | item = {
111 | 'dataset': 'agieval-gaokao-math-cloze',
112 | 'id': item['id'],
113 | 'messages': [
114 | {'role': 'user', 'content': item['question'].strip()},
115 | {'role': 'assistant', 'content': ''}
116 | ],
117 | 'answer': [strip_string(ans) for ans in item['answer'].strip().split(";")]
118 | }
119 | yield item
120 |
121 | def process_agieval_gaokao_mathqa(item):
122 | question = item['question'].strip()
123 | options = []
124 | for option in item['options']:
125 | option = option.strip()
126 | assert option[0] == '('
127 | assert option[2] == ')'
128 | assert option[1] in 'ABCD'
129 | option = f"{option[1]}: {option[3:].strip()}"
130 | options.append(option.strip())
131 | question = f"{question}\n{options}"
132 | item = {
133 | 'dataset': 'agieval-gaokao-mathqa',
134 | 'id': item['id'],
135 | 'messages': [
136 | {'role': 'user', 'content': question},
137 | {'role': 'assistant', 'content': ''}
138 | ],
139 | "answer": item['label']
140 | }
141 | yield item
142 |
143 | def process_agieval_gaokao_mathqa_few_shot_cot_test(item):
144 | question = item['question'].strip().rstrip('\\')
145 | options = " ".join([opt.strip() for opt in item['options']])
146 | question = f"{question}\n从以下选项中选择: {options}"
147 | item = {
148 | 'dataset': 'agieval-gaokao-mathqa',
149 | 'id': item['id'],
150 | 'messages': [
151 | {'role': 'user', 'content': question},
152 | {'role': 'assistant', 'content': ''}
153 | ],
154 | "answer": item['label']
155 | }
156 | yield item
157 |
158 | def process_minif2f_isabelle(item):
159 | question = f"(*### Problem\n\n{item['informal_statement'].strip()}\n\n### Solution\n\n{item['informal_proof'].strip()} *)\n\nFormal:\n{item['formal_statement'].strip()}"
160 | item = {
161 | 'dataset': 'minif2f-isabelle',
162 | 'id': item['id'],
163 | 'messages': [
164 | {'role': 'user', 'content': question},
165 | {'role': 'assistant', 'content': ''}
166 | ],
167 | "answer": "placeholder"
168 | }
169 | yield item
170 |
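The prompt builders above do two small things that are easy to get subtly wrong. `process_mmlu_stem` relabels options in place, which mutates `item['options']`. `process_agieval_gaokao_mathqa` validates the `(X)text` option shape with `assert`, and Python skips `assert` statements under `python -O`. The sketch below restates both steps as pure functions under those assumptions. The names `format_mmlu_question`, `normalize_gaokao_option` and `MMLU_INSTRUCTION` are hypothetical, not part of the repository. It also assumes exactly four MMLU options, as the `zip('ABCD', ...)` in the original implies.

```python
from typing import List, Sequence

# Copied verbatim from process_mmlu_stem so generated prompts stay identical.
MMLU_INSTRUCTION = "What of the following is the right choice? Explain your answer."


def format_mmlu_question(question: str, options: Sequence[object]) -> str:
    """Build the MMLU-STEM user prompt without mutating the caller's options.

    Assumes four options labeled A-D, as in process_mmlu_stem.
    """
    labeled: List[str] = [
        f"({label}) {str(option).strip()}" for label, option in zip("ABCD", options)
    ]
    return f"{question.strip()}\n{MMLU_INSTRUCTION}\n{', '.join(labeled)}"


def normalize_gaokao_option(option: str) -> str:
    """Turn "(B) text" into "B: text"; raise instead of asserting on bad input."""
    option = option.strip()
    if len(option) < 3 or option[0] != "(" or option[2] != ")" or option[1] not in "ABCD":
        raise ValueError(f"unexpected option format: {option!r}")
    return f"{option[1]}: {option[3:].strip()}"


if __name__ == "__main__":
    opts = ["2", 3, "4", "5"]
    print(format_mmlu_question(" What is 1+1? ", opts))
    print(opts)  # the caller's list is left unchanged
    print(normalize_gaokao_option(" (B) x = 3 "))
```

Raising `ValueError` keeps the format check active in optimized runs, and returning new strings avoids surprising side effects when the same loaded item is reused.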
--------------------------------------------------------------------------------
/evaluation/environment.yml:
--------------------------------------------------------------------------------
1 | name: vllm020
2 | channels:
3 | - pytorch
4 | - nvidia
5 | - defaults
6 | dependencies:
7 | - _libgcc_mutex=0.1=main
8 | - _openmp_mutex=5.1=1_gnu
9 | - blas=1.0=mkl
10 | - brotli-python=1.0.9=py310h6a678d5_7
11 | - bzip2=1.0.8=h7b6447c_0
12 | - ca-certificates=2023.08.22=h06a4308_0
13 | - certifi=2023.7.22=py310h06a4308_0
14 | - cffi=1.15.1=py310h5eee18b_3
15 | - charset-normalizer=2.0.4=pyhd3eb1b0_0
16 | - cryptography=41.0.3=py310hdda0065_0
17 | - cuda-cudart=11.7.99=0
18 | - cuda-cupti=11.7.101=0
19 | - cuda-libraries=11.7.1=0
20 | - cuda-nvrtc=11.7.99=0
21 | - cuda-nvtx=11.7.91=0
22 | - cuda-runtime=11.7.1=0
23 | - ffmpeg=4.3=hf484d3e_0
24 | - filelock=3.9.0=py310h06a4308_0
25 | - freetype=2.12.1=h4a9f257_0
26 | - giflib=5.2.1=h5eee18b_3
27 | - gmp=6.2.1=h295c915_3
28 | - gmpy2=2.1.2=py310heeb90bb_0
29 | - gnutls=3.6.15=he1e5248_0
30 | - idna=3.4=py310h06a4308_0
31 | - intel-openmp=2023.1.0=hdb19cb5_46306
32 | - jinja2=3.1.2=py310h06a4308_0
33 | - jpeg=9e=h5eee18b_1
34 | - lame=3.100=h7b6447c_0
35 | - lcms2=2.12=h3be6417_0
36 | - ld_impl_linux-64=2.38=h1181459_1
37 | - lerc=3.0=h295c915_0
38 | - libcublas=11.10.3.66=0
39 | - libcufft=10.7.2.124=h4fbf590_0
40 | - libcufile=1.8.1.2=0
41 | - libcurand=10.3.4.101=0
42 | - libcusolver=11.4.0.1=0
43 | - libcusparse=11.7.4.91=0
44 | - libdeflate=1.17=h5eee18b_1
45 | - libffi=3.4.4=h6a678d5_0
46 | - libgcc-ng=11.2.0=h1234567_1
47 | - libgomp=11.2.0=h1234567_1
48 | - libiconv=1.16=h7f8727e_2
49 | - libidn2=2.3.4=h5eee18b_0
50 | - libnpp=11.7.4.75=0
51 | - libnvjpeg=11.8.0.2=0
52 | - libpng=1.6.39=h5eee18b_0
53 | - libstdcxx-ng=11.2.0=h1234567_1
54 | - libtasn1=4.19.0=h5eee18b_0
55 | - libtiff=4.5.1=h6a678d5_0
56 | - libunistring=0.9.10=h27cfd23_0
57 | - libuuid=1.41.5=h5eee18b_0
58 | - libwebp=1.3.2=h11a3e52_0
59 | - libwebp-base=1.3.2=h5eee18b_0
60 | - lz4-c=1.9.4=h6a678d5_0
61 | - markupsafe=2.1.1=py310h7f8727e_0
62 | - mkl=2023.1.0=h213fc3f_46344
63 | - mkl-service=2.4.0=py310h5eee18b_1
64 | - mkl_fft=1.3.8=py310h5eee18b_0
65 | - mkl_random=1.2.4=py310hdb19cb5_0
66 | - mpc=1.1.0=h10f8cd9_1
67 | - mpfr=4.0.2=hb69a4c5_1
68 | - mpmath=1.3.0=py310h06a4308_0
69 | - ncurses=6.4=h6a678d5_0
70 | - nettle=3.7.3=hbbd107a_1
71 | - networkx=3.1=py310h06a4308_0
72 | - numpy=1.26.0=py310h5f9d8c6_0
73 | - numpy-base=1.26.0=py310hb5e798b_0
74 | - openh264=2.1.1=h4ff587b_0
75 | - openjpeg=2.4.0=h3ad879b_0
76 | - openssl=3.0.12=h7f8727e_0
77 | - pillow=10.0.1=py310ha6cbd5a_0
78 | - pip=23.3=py310h06a4308_0
79 | - pycparser=2.21=pyhd3eb1b0_0
80 | - pyopenssl=23.2.0=py310h06a4308_0
81 | - pysocks=1.7.1=py310h06a4308_0
82 | - python=3.10.13=h955ad1f_0
83 | - pytorch=2.0.1=py3.10_cuda11.7_cudnn8.5.0_0
84 | - pytorch-cuda=11.7=h778d358_5
85 | - pytorch-mutex=1.0=cuda
86 | - readline=8.2=h5eee18b_0
87 | - requests=2.31.0=py310h06a4308_0
88 | - setuptools=68.0.0=py310h06a4308_0
89 | - sqlite=3.41.2=h5eee18b_0
90 | - tbb=2021.8.0=hdb19cb5_0
91 | - tk=8.6.12=h1ccaba5_0
92 | - torchaudio=2.0.2=py310_cu117
93 | - torchtriton=2.0.0=py310
94 | - torchvision=0.15.2=py310_cu117
95 | - urllib3=1.26.18=py310h06a4308_0
96 | - wheel=0.41.2=py310h06a4308_0
97 | - xz=5.4.2=h5eee18b_0
98 | - zlib=1.2.13=h5eee18b_0
99 | - zstd=1.5.5=hc292b87_0
100 | - pip:
101 | - absl-py==2.0.0
102 | - accelerate==0.21.0
103 | - aiofiles==23.2.1
104 | - aiohttp==3.8.6
105 | - aiosignal==1.3.1
106 | - altair==5.1.2
107 | - antlr4-python3-runtime==4.11.0
108 | - anyio==3.7.1
109 | - appdirs==1.4.4
110 | - asttokens==2.4.1
111 | - async-timeout==4.0.3
112 | - attrs==23.1.0
113 | - bitarray==2.8.3
114 | - bitsandbytes==0.41.2.post2
115 | - blessed==1.20.0
116 | - blinker==1.7.0
117 | - cachetools==5.3.2
118 | - clarabel==0.6.0
119 | - click==8.1.7
120 | - contourpy==1.2.0
121 | - cvxpy==1.4.1
122 | - cycler==0.12.1
123 | - datasets==2.14.5
124 | - decorator==5.1.1
125 | - deepspeed==0.11.0
126 | - dill==0.3.7
127 | - distro==1.8.0
128 | - docker-pycreds==0.4.0
129 | - ecos==2.0.12
130 | - einops==0.7.0
131 | - evaluate==0.4.1
132 | - exceptiongroup==1.1.3
133 | - executing==2.0.1
134 | - fastapi==0.104.1
135 | - ffmpy==0.3.1
136 | - fire==0.5.0
137 | - flash-attn==2.0.1
138 | - flask==3.0.0
139 | - fonttools==4.44.3
140 | - frozenlist==1.4.0
141 | - fsspec==2023.6.0
142 | - func-timeout==4.3.5
143 | - gitdb==4.0.11
144 | - gitpython==3.1.40
145 | - google-auth==2.25.2
146 | - google-auth-oauthlib==1.2.0
147 | - gpustat==1.1.1
148 | - gradio==3.50.2
149 | - gradio-client==0.6.1
150 | - grpcio==1.59.2
151 | - h11==0.14.0
152 | - hjson==3.1.0
153 | - httpcore==1.0.2
154 | - httptools==0.6.1
155 | - httpx==0.25.1
156 | - huggingface-hub==0.19.3
157 | - importlib-resources==6.1.1
158 | - ipython==8.17.2
159 | - itsdangerous==2.1.2
160 | - jedi==0.19.1
161 | - joblib==1.3.2
162 | - jsonlines==4.0.0
163 | - jsonschema==4.19.2
164 | - jsonschema-specifications==2023.11.1
165 | - kiwisolver==1.4.5
166 | - markdown==3.5.1
167 | - markdown-it-py==3.0.0
168 | - matplotlib==3.8.1
169 | - matplotlib-inline==0.1.6
170 | - mdurl==0.1.2
171 | - mecab-python3==1.0.8
172 | - msgpack==1.0.7
173 | - multidict==6.0.4
174 | - multiprocess==0.70.15
175 | - ninja==1.11.1.1
176 | - nltk==3.8.1
177 | - nvidia-ml-py==12.535.133
178 | - oauthlib==3.2.2
179 | - openai==1.3.0
180 | - orjson==3.9.10
181 | - osqp==0.6.3
182 | - packaging==23.2
183 | - pandas==2.1.3
184 | - parso==0.8.3
185 | - pebble==5.0.3
186 | - peft==0.6.2
187 | - pexpect==4.8.0
188 | - prompt-toolkit==3.0.41
189 | - protobuf==4.23.4
190 | - psutil==5.9.6
191 | - ptyprocess==0.7.0
192 | - pure-eval==0.2.2
193 | - py-cpuinfo==9.0.0
194 | - pyarrow==14.0.1
195 | - pyasn1==0.5.1
196 | - pyasn1-modules==0.3.0
197 | - pybind11==2.11.1
198 | - pydantic==1.10.13
199 | - pydub==0.25.1
200 | - pygments==2.16.1
201 | - pylatexenc==2.10
202 | - pyparsing==3.1.1
203 | - python-dateutil==2.8.2
204 | - python-dotenv==1.0.0
205 | - python-multipart==0.0.6
206 | - pytz==2023.3.post1
207 | - pyyaml==6.0.1
208 | - qdldl==0.1.7.post0
209 | - ray==2.6.3
210 | - referencing==0.31.0
211 | - regex==2023.10.3
212 | - requests-oauthlib==1.3.1
213 | - responses==0.18.0
214 | - rich==13.7.0
215 | - rouge-score==0.1.2
216 | - rpds-py==0.12.0
217 | - rsa==4.9
218 | - safetensors==0.4.0
219 | - scipy==1.11.3
220 | - scs==3.2.4
221 | - semantic-version==2.10.0
222 | - sentencepiece==0.1.99
223 | - sentry-sdk==1.35.0
224 | - setproctitle==1.3.3
225 | - six==1.16.0
226 | - smmap==5.0.1
227 | - sniffio==1.3.0
228 | - stack-data==0.6.3
229 | - starlette==0.27.0
230 | - sympy==1.12
231 | - tensorboard==2.15.1
232 | - tensorboard-data-server==0.7.2
233 | - termcolor==2.3.0
234 | - tiktoken==0.5.1
235 | - timeout-decorator==0.5.0
236 | - tokenizers==0.15.0
237 | - toolz==0.12.0
238 | - tqdm==4.66.1
239 | - traitlets==5.13.0
240 | - transformers==4.35.2
241 | - typing-extensions==4.8.0
242 | - tzdata==2023.3
243 | - unidic-lite==1.0.8
244 | - uvicorn==0.24.0.post1
245 | - uvloop==0.19.0
246 | - vllm==0.2.0
247 | - wandb==0.16.0
248 | - watchfiles==0.21.0
249 | - wcwidth==0.2.10
250 | - websockets==11.0.3
251 | - werkzeug==3.0.1
252 | - xformers==0.0.22
253 | - xxhash==3.4.1
254 | - yarl==1.9.2
255 | - zstandard==0.22.0
256 | prefix:
257 |
--------------------------------------------------------------------------------
/evaluation/eval/eval_script.py:
--------------------------------------------------------------------------------
1 | import regex
2 | from copy import deepcopy
3 | from eval.eval_utils import math_equal
4 | from eval.ocwcourses_eval_utils import normalize_numeric, numeric_equality, normalize_symbolic_equation, SymbolicMathMixin
5 |
6 | def is_correct(item, pred_key='prediction', prec=1e-3):
7 | pred = item[pred_key]
8 | ans = item['answer']
9 | if isinstance(pred, list) and isinstance(ans, list):
10 | pred_matched = set()
11 | ans_matched = set()
12 | for i in range(len(pred)):
13 | for j in range(len(ans)):
14 | item_cpy = deepcopy(item)
15 | item_cpy.update({
16 | pred_key: pred[i],
17 | 'answer': ans[j]
18 | })
19 | if is_correct(item_cpy, pred_key=pred_key, prec=prec):
20 | pred_matched.add(i)
21 | ans_matched.add(j)
25 | return len(pred_matched) == len(pred) and len(ans_matched) == len(ans)
26 | elif isinstance(pred, str) and isinstance(ans, str):
27 | if '\\cup' in pred and '\\cup' in ans:
28 | item = deepcopy(item)
29 | item.update({
30 | pred_key: pred.split('\\cup'),
31 | 'answer': ans.split('\\cup'),
32 | })
33 | return is_correct(item, pred_key=pred_key, prec=prec)
34 | else:
35 | label = False
36 | try:
37 | label = abs(float(regex.sub(r',', '', str(pred))) - float(regex.sub(r',', '', str(ans)))) < prec
38 | except:
39 | pass
40 | label = label or (ans and pred == ans) or math_equal(pred, ans)
41 | return label
42 | else:
43 | print(item, flush=True)
44 | raise NotImplementedError()
45 |
46 | def eval_math(item, pred_key='prediction', prec=1e-3):
47 | pred = item[pred_key]
48 | if pred_key == 'program_output' and isinstance(pred, str):
49 | pred = [pred]
50 | ans = item['answer']
51 | if isinstance(pred, list) and isinstance(ans, list):
52 | # for some questions in MATH, `reference` repeats answers
53 | _ans = []
54 | for a in ans:
55 | if a not in _ans:
56 | _ans.append(a)
57 | ans = _ans
58 | # some predictions for MATH questions also repeat answers
59 | _pred = []
60 | for a in pred:
61 | if a not in _pred:
62 | _pred.append(a)
63 | # some predictions mistakenly box non-answer strings
64 | pred = _pred[-len(ans):]
65 |
66 | item.update({
67 | pred_key: pred,
68 | 'answer': ans
69 | })
70 | return is_correct(item, pred_key=pred_key, prec=prec)
71 |
72 | def eval_last_single_answer(item, pred_key='prediction', prec=1e-3):
73 | for key in [pred_key, 'answer']:
74 | assert isinstance(item[key], str), f"{key} = `{item[key]}` is not a str"
75 | return is_correct(item, pred_key=pred_key, prec=prec)
76 |
77 | def eval_agieval_gaokao_math_cloze(item, pred_key='prediction', prec=1e-3):
78 | if pred_key == 'program_output' and isinstance(item[pred_key], str):
79 | item[pred_key] = [item[pred_key]]
80 | for key in [pred_key, 'answer']:
81 | assert isinstance(item[key], list), f"{key} = `{item[key]}` is not a list"
82 | pred = item[pred_key]
83 | ans = item['answer']
84 | _pred = []
85 | for p in pred:
86 | p = p + ";"
87 | while p:
88 | left_brackets = 0
89 | for i in range(len(p)):
90 | if p[i] == ';' or (p[i] == ',' and left_brackets == 0):
91 | _p, p = p[:i].strip(), p[i + 1:].strip()
92 | if _p not in _pred:
93 | _pred.append(_p)
94 | break
95 | elif p[i] in '([{':
96 | left_brackets += 1
97 | elif p[i] in ')]}':
98 | left_brackets -= 1
99 | pred = _pred[-len(ans):]
100 | if len(pred) == len(ans):
101 | for p, a in zip(pred, ans):
102 | item.update({
103 | pred_key: p,
104 | 'answer': a,
105 | })
106 | if not is_correct(item, pred_key=pred_key, prec=prec):
107 | return False
108 | return True
109 | else:
110 | return False
111 |
112 | def eval_agieval_gaokao_mathqa(item, pred_key='prediction', prec=1e-3):
113 | if pred_key == 'program_output' and isinstance(item[pred_key], str):
114 | item[pred_key] = [item[pred_key]]
115 | pred_str = " ".join(item[pred_key])
116 | ans = item['answer']
117 | tag = None
118 | idx = -1
119 | for t in 'ABCD':
120 | if t in pred_str and pred_str.index(t) > idx:
121 | tag = t
122 | idx = pred_str.index(t)
123 | return tag == ans
124 |
125 | def eval_math_sat(item, pred_key='prediction', prec=1e-3):
126 | for key in [pred_key, 'answer']:
127 | assert isinstance(item[key], str), f"{key} = `{item[key]}` is not a str"
128 | return item[pred_key].lower() == item['answer'].lower()
129 |
130 | def eval_mmlu_stem(item, pred_key='prediction', prec=1e-3):
131 | return eval_math_sat(item, pred_key=pred_key, prec=prec)
132 |
133 | def eval_ocwcourses(item, pred_key='prediction', prec=1e-3):
134 | INVALID_ANSWER = "[invalidanswer]"
135 | for key in [pred_key, 'answer']:
136 | assert isinstance(item[key], str), f"{key} = `{item[key]}` is not a str"
137 | pred = item[pred_key]
138 | ans = item['answer']
139 |
140 | try:
141 | float(ans)
142 | normalize_fn = normalize_numeric
143 | is_equiv = numeric_equality
144 | answer_type = "numeric"
145 | except ValueError:
146 | if "=" in ans:
147 | normalize_fn = normalize_symbolic_equation
148 | is_equiv = lambda x, y: x==y
149 | answer_type = "equation"
150 | else:
151 | normalize_fn = SymbolicMathMixin().normalize_tex
152 | is_equiv = SymbolicMathMixin().is_tex_equiv
153 | answer_type = "expression"
154 |
155 | correct_answer = normalize_fn(ans)
156 |
157 | unnormalized_answer = pred if pred else INVALID_ANSWER
158 | model_answer = normalize_fn(unnormalized_answer)
159 |
160 | if unnormalized_answer == INVALID_ANSWER:
161 | acc = 0
162 | elif model_answer == INVALID_ANSWER:
163 | acc = 0
164 | elif is_equiv(model_answer, correct_answer):
165 | acc = 1
166 | else:
167 | acc = 0
168 |
169 | return acc
170 |
171 | def eval_minif2f_isabelle(item, pred_key='prediction', prec=1e-3):
172 | return True
173 |
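Several scorers above share three list-handling ideas: `is_correct` accepts two answer lists only when every prediction matches some reference and every reference is matched; `eval_math` deduplicates both lists while keeping first-seen order; and `eval_agieval_gaokao_math_cloze` splits a prediction on `;` always and on `,` only outside brackets. The sketch below isolates those ideas with the element comparison passed in as `eq` (the original recurses into `is_correct`). The helper names are hypothetical. One deliberate deviation: `split_top_level` drops empty pieces, whereas the original would keep one `""`. It also omits the `eval_math` step that keeps only the last `len(ans)` predictions.

```python
from typing import Callable, List, Sequence, TypeVar

T = TypeVar("T")


def dedupe_keep_order(items: Sequence[T]) -> List[T]:
    """Drop repeated entries but keep first-seen order."""
    out: List[T] = []
    for x in items:
        if x not in out:
            out.append(x)
    return out


def lists_match(pred: Sequence[str], ans: Sequence[str],
                eq: Callable[[str, str], bool]) -> bool:
    """Every prediction must equal some answer and every answer must be hit."""
    pred_hit, ans_hit = set(), set()
    for i, p in enumerate(pred):
        for j, a in enumerate(ans):
            if eq(p, a):
                pred_hit.add(i)
                ans_hit.add(j)
    return len(pred_hit) == len(pred) and len(ans_hit) == len(ans)


def split_top_level(text: str) -> List[str]:
    """Split on ';' always and on ',' only outside (), [] and {}.

    As in eval_agieval_gaokao_math_cloze, bracket depth restarts at 0 after
    every split; unlike it, empty pieces are dropped.
    """
    parts: List[str] = []
    depth = 0
    current: List[str] = []
    for ch in text + ";":  # the trailing ';' flushes the last piece
        if ch == ";" or (ch == "," and depth == 0):
            piece = "".join(current).strip()
            if piece and piece not in parts:
                parts.append(piece)
            current, depth = [], 0
            continue
        if ch in "([{":
            depth += 1
        elif ch in ")]}":
            depth -= 1
        current.append(ch)
    return parts
```

For example, `split_top_level("(1,2), 3; x")` keeps the interval `(1,2)` intact and yields three pieces, which `lists_match` can then compare against the reference list in any order.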
--------------------------------------------------------------------------------
/evaluation/eval/eval_utils.py:
--------------------------------------------------------------------------------
1 | import multiprocessing
2 | from math import isclose
3 | import numpy as np
4 | from typing import Union, Any, Dict
5 |
6 | from sympy import simplify, N
7 | from sympy.parsing.sympy_parser import parse_expr
8 | from sympy.parsing.latex import parse_latex
9 | import re
10 | import regex
11 |
12 | from data_processing.answer_extraction import extract_answer, extract_program_output, strip_string
13 |
14 | def extract_program(result: str, last_only=True):
15 | """
16 | extract the program after "```python", and before "```"
17 | """
18 | program = ""
19 | start = False
20 | for line in result.split("\n"):
21 | if line.startswith("```python"):
22 | if last_only:
23 | program = "" # only extract the last program
24 | else:
25 | program += "\n# ========\n"
26 | start = True
27 | elif line.startswith("```"):
28 | start = False
29 | elif start:
30 | program += line + "\n"
31 | return program
32 |
33 |
34 | def parse_ground_truth(example: Dict[str, Any], data_name):
35 | if 'gt_cot' in example:
36 | return example['gt_cot'], strip_string(example['gt'])
37 |
38 | # parse ground truth
39 | if data_name in ["math", 'ocw']:
40 | gt_cot = example['solution']
41 | gt_ans = extract_answer(gt_cot)
42 | elif data_name == "gsm8k":
43 | gt_cot, gt_ans = example['answer'].split("####")
44 | elif data_name == "gsm-hard":
45 | gt_cot, gt_ans = example['code'], example['target']
46 | elif data_name == "svamp":
47 | gt_cot, gt_ans = example['Equation'], example['Answer']
48 | elif data_name == "asdiv":
49 | gt_cot = example['formula']
50 | gt_ans = re.sub(r"\(.*?\)", "", example['answer'])
51 | elif data_name == "mawps":
52 | gt_cot, gt_ans = None, example['target']
53 | elif data_name == "tabmwp":
54 | gt_cot = example['solution']
55 | gt_ans = example['answer']
56 | if example['ans_type'] in ['integer_number', 'decimal_number']:
57 | if '/' in gt_ans:
58 | gt_ans = int(gt_ans.split('/')[0]) / int(gt_ans.split('/')[1])
59 | elif ',' in gt_ans:
60 | gt_ans = float(gt_ans.replace(',', ''))
61 | elif '%' in gt_ans:
62 | gt_ans = float(gt_ans.split('%')[0]) / 100
63 | else:
64 | gt_ans = float(gt_ans)
65 | elif data_name == "bbh":
66 | gt_cot, gt_ans = None, example['target']
67 | else:
68 | raise NotImplementedError(data_name)
69 | # post process
70 | gt_cot = str(gt_cot).strip()
71 | gt_ans = strip_string(gt_ans)
72 | return gt_cot, gt_ans
73 |
74 |
75 | def parse_question(example, data_name):
76 | question = ""
77 | if data_name == "asdiv":
78 | question = f"{example['body'].strip()} {example['question'].strip()}"
79 | elif data_name == "svamp":
80 | body = example["Body"].strip()
81 | if not body.endswith("."):
82 | body = body + "."
83 | question = f'{body} {example["Question"].strip()}'
84 | elif data_name == "tabmwp":
85 | title_str = f'regarding "{example["table_title"]}" ' if example['table_title'] else ""
86 | question = f'Read the following table {title_str}and answer a question:\n'
87 | question += f'{example["table"]}\n{example["question"]}'
88 | if example['choices']:
89 | question += f' Please select from the following options: {example["choices"]}'
90 | else:
91 | for key in ['question', 'problem', 'Question', 'input']:
92 | if key in example:
93 | question = example[key]
94 | break
95 | assert question != ""
96 | return question.strip()
97 |
98 |
99 | def run_execute(executor, result, prompt_type, execute=False):
100 | if not result or result == 'error':
101 | return None, None
102 | report = None
103 |
104 | if "program_only" in prompt_type:
105 | prediction = extract_program_output(result)
106 | elif prompt_type in ["pot", "pal"] and execute:
107 | code = extract_program(result)
108 | prediction, report = executor.apply(code)
109 | else:
110 | prediction = extract_answer(result)
111 |
112 | prediction = strip_string(prediction)
113 | return prediction, report
114 |
115 |
116 | def parse_digits(num):
117 | # format: 234.23 || 23%
118 | num = regex.sub(',', '', str(num))
119 | try:
120 | return float(num)
121 | except:
122 | if num.endswith('%'):
123 | num = num[:-1]
124 | if num.endswith('\\'):
125 | num = num[:-1]
126 | try:
127 | return float(num) / 100
128 | except:
129 | pass
130 | return None
131 |
132 | def is_digit(num):
133 | # paired with parse_digits
134 | return parse_digits(num) is not None
135 |
136 |
137 | def normalize_prediction(prediction):
138 | try: # 1. numerical equal
139 | if is_digit(prediction):
140 | prediction = np.round(float(str(prediction).replace(",", "")), 6)
141 | return str(prediction)
142 | except:
143 | pass
144 |
145 | # 2. symbolic equal
146 | prediction = str(prediction).strip()
147 |
148 | ## deal with [], (), {}
149 | brackets = []
150 | while prediction.startswith("[") and prediction.endswith("]") or (prediction.startswith("(") and prediction.endswith(")")):
151 | brackets.append(prediction[0])  # remember the stripped bracket so it can be restored below
152 | prediction = prediction[1:-1]
153 | if brackets and ',' in prediction:
154 | pred_parts = [normalize_prediction(part) for part in prediction.split(",")]
155 | prediction = ",".join(pred_parts)
156 |
157 | if brackets:
158 | for b in reversed(brackets):
159 | if b == '[':
160 | prediction = '[' + prediction + ']'
161 | else:
162 | assert b == '('
163 | prediction = '(' + prediction + ')'
164 |
165 | def _parse(s):
166 | for f in [parse_latex, parse_expr]:
167 | try:
168 | return f(s)
169 | except:
170 | pass
171 | return s
172 |
173 | prediction = str(_parse(prediction))  # _parse may return a sympy object; .replace below needs text
174 |
175 | for s in ['{', "}", "(", ")"]:
176 | prediction = prediction.replace(s, "")
177 |
178 | return prediction
179 |
180 |
181 | def math_equal(prediction: Union[bool, float, str],
182 | reference: Union[float, str],
183 | include_percentage: bool = True,
184 | is_close: bool = True,
185 | timeout: bool = False,
186 | ) -> bool:
187 | """
188 | Exact match of math if and only if:
189 | 1. numerical equal: both can convert to float and are equal
190 | 2. symbolic equal: both can convert to sympy expression and are equal
191 | """
192 | if str(prediction) == str(reference):
193 | return True
194 |
195 | try: # 1. numerical equal
196 | if is_digit(prediction) and is_digit(reference):
197 | prediction = parse_digits(prediction)
198 | reference = parse_digits(reference)
199 | # number questions
200 | if include_percentage:
201 | gt_result = [reference / 100, reference, reference * 100]
202 | else:
203 | gt_result = [reference]
204 | for item in gt_result:
205 | try:
206 | if is_close:
207 | if isclose(item, prediction, abs_tol=1e-3):
208 | return True
209 | else:
210 | if item == prediction:
211 | return True
212 | except Exception:
213 | continue
214 | return False
215 | except:
216 | pass
217 |
218 | if not prediction and prediction not in [0, False]:
219 | return False
220 |
221 | # 2. symbolic equal
222 | reference = str(reference).strip()
223 | prediction = str(prediction).strip()
224 |
225 | if regex.match(r'(\(|\[).+(\)|\])', prediction) is not None and regex.match(r'(\(|\[).+(\)|\])', reference) is not None:
226 | pred_parts = prediction[1:-1].split(",")
227 | ref_parts = reference[1:-1].split(",")
228 | if len(pred_parts) == len(ref_parts):
229 | if all([math_equal(pred_parts[i], ref_parts[i], include_percentage, is_close) for i in range(len(pred_parts))]):
230 | return True
231 |
232 | if (prediction.startswith("\\begin{pmatrix}") or prediction.startswith("\\begin{bmatrix}")) and (prediction.endswith("\\end{pmatrix}") or prediction.endswith("\\end{bmatrix}")) and \
233 | (reference.startswith("\\begin{pmatrix}") or reference.startswith("\\begin{bmatrix}")) and (reference.endswith("\\end{pmatrix}") or reference.endswith("\\end{bmatrix}")):
234 | pred_lines = [line.strip() for line in prediction[len("\\begin{pmatrix}"): -len("\\end{pmatrix}")].split("\\\\") if line.strip()]
235 | ref_lines = [line.strip() for line in reference[len("\\begin{pmatrix}"): -len("\\end{pmatrix}")].split("\\\\") if line.strip()]
236 | matched = True
237 | if len(pred_lines) == len(ref_lines):
238 | for pred_line, ref_line in zip(pred_lines, ref_lines):
239 | pred_parts = pred_line.split("&")
240 | ref_parts = ref_line.split("&")
241 | if len(pred_parts) == len(ref_parts):
242 | if not all([math_equal(pred_parts[i], ref_parts[i], include_percentage, is_close) for i in range(len(pred_parts))]):
243 | matched = False
244 | break
245 | else:
246 | matched = False
247 | if not matched:
248 | break
249 | else:
250 | matched = False
251 | if matched:
252 | return True
253 |
254 | if prediction.count('=') == 1 and reference.count('=') == 1:
255 | pred = prediction.split('=')
256 | pred = f"{pred[0].strip()} - ({pred[1].strip()})"
257 | ref = reference.split('=')
258 | ref = f"{ref[0].strip()} - ({ref[1].strip()})"
259 | if symbolic_equal(pred, ref) or symbolic_equal(f"-({pred})", ref):
260 | return True
261 | elif prediction.count('=') == 1 and len(prediction.split('=')[0].strip()) <= 2 and '=' not in reference:
262 | if math_equal(prediction.split('=')[1], reference, include_percentage, is_close):
263 | return True
264 | elif reference.count('=') == 1 and len(reference.split('=')[0].strip()) <= 2 and '=' not in prediction:
265 | if math_equal(prediction, reference.split('=')[1], include_percentage, is_close):
266 | return True
267 |
268 | # symbolic equal with sympy
269 | if timeout:
270 | if call_with_timeout(symbolic_equal_process, prediction, reference):
271 | return True
272 | else:
273 | if symbolic_equal(prediction, reference):
274 | return True
275 |
276 | return False
277 |
278 |
279 | def math_equal_process(param):
280 | return math_equal(param[-2], param[-1])
281 |
282 |
283 | def symbolic_equal(a, b):
284 | def _parse(s):
285 | for f in [parse_latex, parse_expr]:
286 | try:
287 | return f(s)
288 | except:
289 | pass
290 | return s
291 | a = _parse(a)
292 | b = _parse(b)
293 |
294 | try:
295 | if simplify(a-b) == 0:
296 | return True
297 | except:
298 | pass
299 |
300 | try:
301 | if isclose(N(a), N(b), abs_tol=1e-3):
302 | return True
303 | except:
304 | pass
305 | return False
306 |
307 |
308 | def symbolic_equal_process(a, b, output_queue):
309 | result = symbolic_equal(a, b)
310 | output_queue.put(result)
311 |
312 |
313 | def call_with_timeout(func, *args, timeout=1, **kwargs):
314 | output_queue = multiprocessing.Queue()
315 | process_args = args + (output_queue,)
316 | process = multiprocessing.Process(target=func, args=process_args, kwargs=kwargs)
317 | process.start()
318 | process.join(timeout)
319 |
320 | if process.is_alive():
321 | process.terminate()
322 | process.join()
323 | return False
324 |
325 | return output_queue.get()
326 |
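The first step of `math_equal` is purely numeric: `parse_digits` strips thousands separators and turns `23%` or `23\%` into `0.23`, and the reference is then compared against `ref/100`, `ref` and `ref*100` with an absolute tolerance of `1e-3`. A standalone, standard-library sketch of just that step follows. `parse_number` and `numeric_match` are hypothetical names; the symbolic branches (tuples, matrices, equations, sympy) are left out.

```python
import re
from math import isclose
from typing import Optional


def parse_number(text: object) -> Optional[float]:
    """Parse "1,234.5", "23%" or "23\\%" into a float; return None otherwise."""
    s = re.sub(",", "", str(text))
    try:
        return float(s)
    except ValueError:
        pass
    if s.endswith("%"):
        s = s[:-1]
        if s.endswith("\\"):  # LaTeX-escaped percent sign, e.g. 23\%
            s = s[:-1]
        try:
            return float(s) / 100
        except ValueError:
            pass
    return None


def numeric_match(pred: object, ref: object,
                  include_percentage: bool = True, abs_tol: float = 1e-3) -> bool:
    """Numeric step of math_equal: accept ref, ref/100 or ref*100 within abs_tol."""
    p, r = parse_number(pred), parse_number(ref)
    if p is None or r is None:
        return False
    candidates = [r / 100, r, r * 100] if include_percentage else [r]
    return any(isclose(c, p, abs_tol=abs_tol) for c in candidates)
```

Note the leniency this buys: with `include_percentage=True`, a prediction of `12` is accepted for a reference of `0.12`, which forgives percent/fraction confusion but can also accept answers off by a factor of 100.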
--------------------------------------------------------------------------------
/evaluation/eval/ocwcourses_eval_utils.py:
--------------------------------------------------------------------------------
1 | import re
2 | import numpy as np
3 | import sympy
4 | from sympy.core.sympify import SympifyError
5 | from sympy.parsing.latex import parse_latex
6 |
7 | import signal
8 |
9 | INVALID_ANSWER = "[invalidanswer]"
10 |
11 | class timeout:
12 | def __init__(self, seconds=1, error_message="Timeout"):
13 | self.seconds = seconds
14 | self.error_message = error_message
15 |
16 | def handle_timeout(self, signum, frame):
17 | raise TimeoutError(self.error_message)
18 |
19 | def __enter__(self):
20 | signal.signal(signal.SIGALRM, self.handle_timeout)
21 | signal.alarm(self.seconds)
22 |
23 | def __exit__(self, type, value, traceback):
24 | signal.alarm(0)
25 |
26 | def normalize_numeric(s):
27 | if s is None:
28 | return None
29 | for unit in [  # longer strings before shorter ones they contain, so e.g. "m/s" cannot consume "\text{ m/s}"
30 | "eV",
31 | " \\mathrm{~kg} \\cdot \\mathrm{m} / \\mathrm{s}",
32 | " kg m/s",
33 | "kg*m/s",
34 | "kg",
35 | " \\text{ m/s}",
36 | "\\text{ m/s}",
37 | " \\mathrm{m/s}",
38 | "m/s",
39 | "m / s",
40 | "m s^{-1}",
41 | "g/mole",
42 | "g/mol",
43 | "\\mathrm{~g} / \\mathrm{mol}",
44 | "\\mathrm{~g}",
45 | "W",
46 | "erg/s",
47 | "years",
48 | "year",
49 | "cm",
50 | ]:
51 | s = s.replace(unit, "")
52 | s = s.strip()
53 | for maybe_unit in ["m", "s", "cm"]:
54 | s = s.replace("\\mathrm{" + maybe_unit + "}", "")
55 | s = s.replace("\\mathrm{~" + maybe_unit + "}", "")
56 | s = s.strip()
57 | s = s.strip("$")
58 | try:
59 | return float(eval(s))
60 | except:
61 | try:
62 | expr = parse_latex(s)
63 | if expr.is_number:
64 | return float(expr)
65 | return INVALID_ANSWER
66 | except:
67 | return INVALID_ANSWER
68 |
69 | def numeric_equality(n1, n2, threshold=0.01):
70 | if n1 is None or n2 is None:
71 | return False
72 | if np.isclose(n1, 0) or np.isclose(n2, 0) or np.isclose(n1 - n2, 0):
73 | return np.abs(n1 - n2) <= threshold * (np.abs(n1) + np.abs(n2)) / 2
74 | else:
75 | return np.isclose(n1, n2)
76 |
77 | def normalize_symbolic_equation(s):
78 | if not isinstance(s, str):
79 | return INVALID_ANSWER
80 | if s.startswith("\\["):
81 | s = s[2:]
82 | if s.endswith("\\]"):
83 | s = s[:-2]
84 | s = s.replace("\\left(", "(")
85 | s = s.replace("\\right)", ")")
86 | s = s.replace("\\\\", "\\")
87 | if s.startswith("$") or s.endswith("$"):
88 | s = s.strip("$")
89 | try:
90 | maybe_expression = parse_latex(s)
91 | if not isinstance(maybe_expression, sympy.core.relational.Equality):
92 | # we have an expression, not an equation
93 | return INVALID_ANSWER
94 | else:
95 | return maybe_expression
96 | except:
97 | return INVALID_ANSWER
98 |
99 | class SymbolicMathMixin:
100 | """
101 | Methods useful for parsing mathematical expressions from text and determining equivalence of expressions.
102 | """
103 |
104 | SUBSTITUTIONS = [  # used by normalize_tex
105 | ("an ", ""),
106 | ("a ", ""),
107 | (".$", "$"),
108 | ("\\$", ""),
109 | (r"\ ", ""),
110 | (" ", ""),
111 | ("mbox", "text"),
112 | (",\\text{and}", ","),
113 | ("\\text{and}", ","),
114 | ("\\text{m}", "\\text{}"),
115 | ]
116 | REMOVED_EXPRESSIONS = [  # used by normalize_tex
117 | "square",
118 | "ways",
119 | "integers",
120 | "dollars",
121 | "mph",
122 | "inches",
123 | "ft",
124 | "hours",
125 | "km",
126 | "units",
127 | "\\ldots",
128 | "sue",
129 | "points",
130 | "feet",
131 | "minutes",
132 | "digits",
133 | "cents",
134 | "degrees",
135 | "cm",
136 | "gm",
137 | "pounds",
138 | "meters",
139 | "meals",
140 | "edges",
141 | "students",
142 | "childrentickets",
143 | "multiples",
144 | "\\text{s}",
145 | "\\text{.}",
146 | "\\text{\ns}",
147 | "\\text{}^2",
148 | "\\text{}^3",
149 | "\\text{\n}",
150 | "\\text{}",
151 | r"\mathrm{th}",
152 | r"^\circ",
153 | r"^{\circ}",
154 | r"\;",
155 | r",\!",
156 | "{,}",
157 | '"',
158 | "\\dots",
159 | ]
160 |
161 | def normalize_tex(self, final_answer: str) -> str:
162 | """
163 | Normalizes a string representing a mathematical expression.
164 | Used as a preprocessing step before parsing methods.
165 |
166 | Copied character for character from appendix D of Lewkowycz et al. (2022)
167 | """
168 | final_answer = final_answer.split("=")[-1]
169 |
170 | for before, after in self.SUBSTITUTIONS:
171 | final_answer = final_answer.replace(before, after)
172 | for expr in self.REMOVED_EXPRESSIONS:
173 | final_answer = final_answer.replace(expr, "")
174 |
175 | # Extract answer that is in LaTeX math, is bold,
176 | # is surrounded by a box, etc.
177 | final_answer = re.sub(r"(.*?)(\$)(.*?)(\$)(.*)", "$\\3$", final_answer)
178 | final_answer = re.sub(r"(\\text\{)(.*?)(\})", "\\2", final_answer)
179 | final_answer = re.sub(r"(\\textbf\{)(.*?)(\})", "\\2", final_answer)
180 | final_answer = re.sub(r"(\\overline\{)(.*?)(\})", "\\2", final_answer)
181 | final_answer = re.sub(r"(\\boxed\{)(.*)(\})", "\\2", final_answer)
182 |
183 | # Normalize shorthand TeX:
184 | # \fracab -> \frac{a}{b}
185 | # \frac{abc}{bef} -> \frac{abc}{bef}
186 | # \fracabc -> \frac{a}{b}c
187 | # \sqrta -> \sqrt{a}
188 | # \sqrtab -> \sqrt{a}b
189 | final_answer = re.sub(r"(frac)([^{])(.)", "frac{\\2}{\\3}", final_answer)
190 | final_answer = re.sub(r"(sqrt)([^{])", "sqrt{\\2}", final_answer)
191 | final_answer = final_answer.replace("$", "")
192 |
193 | # Normalize 100,000 -> 100000
194 | if final_answer.replace(",", "").isdigit():
195 | final_answer = final_answer.replace(",", "")
196 |
197 | return final_answer
198 |
199 | def parse_tex(self, text: str, time_limit: int = 5) -> sympy.Basic:
200 | """
201 | Wrapper around `sympy.parsing.latex.parse_latex` that outputs a SymPy expression.
202 | Typically, you want to apply `normalize_tex` as a preprocessing step.
203 | """
204 | try:
205 | with timeout(seconds=time_limit):
206 | parsed = parse_latex(text)
207 | except (
208 | # general error handling: there is a long tail of possible sympy/other
209 | # errors we would like to catch
210 | Exception
211 | ) as e:
212 | print(f"failed to parse {text} with exception {e}")
213 | return None
214 |
215 | return parsed
216 |
217 | def is_exp_equiv(self, x1: sympy.Basic, x2: sympy.Basic, time_limit=5) -> bool:
218 | """
219 | Determines whether two sympy expressions are equal.
220 | """
221 | try:
222 | with timeout(seconds=time_limit):
223 | try:
224 | diff = x1 - x2
225 | except (SympifyError, ValueError, TypeError) as e:
226 | print(
227 | f"Couldn't subtract {x1} and {x2} with exception {e}"
228 | )
229 | return False
230 |
231 | try:
232 | if sympy.simplify(diff) == 0:
233 | return True
234 | else:
235 | return False
236 | except (SympifyError, ValueError, TypeError) as e:
237 | print(f"Failed to simplify {x1}-{x2} with {e}")
238 | return False
239 | except TimeoutError as e:
240 | print(f"Timed out comparing {x1} and {x2}")
241 | return False
242 | except Exception as e:
243 | print(f"failed on unrecognized exception {e}")
244 | return False
245 |
246 | def is_tex_equiv(self, x1: str, x2: str, time_limit=5) -> bool:
247 | """
248 | Determines whether two (ideally normalized using `normalize_text`) TeX expressions are equal.
249 |
250 | Does so by first checking for string exact-match, then falls back on sympy-equivalence,
251 | following the (Lewkowycz et al. 2022) methodology.
252 | """
253 | if x1 == x2:
254 | # don't resort to sympy if we have full string match, post-normalization
255 | return True
256 | parsed_x2 = self.parse_tex(x2)
257 | # compare with None: a valid parse such as sympy's 0 is falsy
258 | if parsed_x2 is None:
259 | # if our reference fails to parse into a Sympy object,
260 | # we forgo parsing + checking our generated answer.
261 | return False
262 | # otherwise fall back on sympy equivalence; an unparseable x1 yields False
263 | return self.is_exp_equiv(self.parse_tex(x1), parsed_x2, time_limit=time_limit)
264 |
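The control flow of `is_tex_equiv` (exact string match first, then a semantic fallback, giving up when the reference does not parse) can be illustrated without SymPy. The sketch below plainly swaps `parse_latex` plus `simplify` for Python's `fractions.Fraction` over a tiny numeric subset (integers, decimals, `a/b`, `\frac{a}{b}`). It is not a replacement for symbolic checking. The helper names are hypothetical. It also shows why the `None` test matters: `Fraction(0)` is falsy, just like SymPy's zero.

```python
import re
from fractions import Fraction
from typing import Optional


def to_fraction(tex: str) -> Optional[Fraction]:
    """Parse a tiny TeX subset: 3, -0.5, 2/3, \\frac{2}{3}, -\\frac{1}{2}."""
    tex = tex.strip()
    m = re.fullmatch(r"(-?)\\frac\{(-?\d+)\}\{(-?\d+)\}", tex)
    try:
        if m:
            value = Fraction(int(m.group(2)), int(m.group(3)))
            return -value if m.group(1) else value
        return Fraction(tex)  # accepts "3", "-0.5", "2/3"
    except (ValueError, ZeroDivisionError):
        return None  # outside the supported subset


def is_tex_equiv(x1: str, x2: str) -> bool:
    if x1 == x2:  # exact match after normalization: no parsing needed
        return True
    parsed_x2 = to_fraction(x2)
    if parsed_x2 is None:  # unparseable reference: do not try the candidate
        return False
    parsed_x1 = to_fraction(x1)
    return parsed_x1 is not None and parsed_x1 == parsed_x2
```

For example, `is_tex_equiv("0", "0.0")` is `True` only because the reference check uses `is None`. With `if not parsed_x2` it would return `False`.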
--------------------------------------------------------------------------------
/evaluation/eval/python_executor.py:
--------------------------------------------------------------------------------
1 | import os
2 | import io
3 | from contextlib import redirect_stdout
4 | import pickle
5 | import regex
6 | import copy
7 | from typing import Any, Dict, Optional
8 | import multiprocess
9 | from pebble import ProcessPool
10 | from concurrent.futures import TimeoutError
11 | from functools import partial
12 | import traceback
13 | from timeout_decorator import timeout
14 |
15 | class GenericRuntime:
16 | GLOBAL_DICT = {}
17 | LOCAL_DICT = None
18 | HEADERS = []
19 | def __init__(self):
20 | self._global_vars = copy.copy(self.GLOBAL_DICT)
21 | self._local_vars = copy.copy(self.LOCAL_DICT) if self.LOCAL_DICT else None
22 |
23 | for c in self.HEADERS:
24 | self.exec_code(c)
25 |
26 | def exec_code(self, code_piece: str) -> None:
27 | if regex.search(r'(\s|^)?input\(', code_piece) or regex.search(r'(\s|^)?os.system\(', code_piece):
28 | raise RuntimeError()
29 | exec(code_piece, self._global_vars)
30 |
31 | def eval_code(self, expr: str) -> Any:
32 | return eval(expr, self._global_vars)
33 |
34 | def inject(self, var_dict: Dict[str, Any]) -> None:
35 | for k, v in var_dict.items():
36 | self._global_vars[k] = v
37 |
38 | @property
39 | def answer(self):
40 | return self._global_vars['answer']
41 |
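When none of `get_answer_from_stdout`, `answer_symbol` or `answer_expr` is set, `execute` runs every line except the last with `exec` and then evaluates the last line as an expression in the same persistent globals dictionary. A minimal sketch of that protocol follows. `MiniRuntime` and `run_last_line_as_answer` are hypothetical names, and the guard uses plain substring checks instead of the original regexes. Like the original, it is a heuristic filter, not a sandbox.

```python
from typing import Any, Dict, List, Sequence


class MiniRuntime:
    """Persistent-namespace runtime in the spirit of GenericRuntime (sketch)."""

    def __init__(self, headers: Sequence[str] = ()):
        self._global_vars: Dict[str, Any] = {}
        for header in headers:  # e.g. "import math", run once up front
            self.exec_code(header)

    def exec_code(self, code: str) -> None:
        # crude text-based guard against interactive input and shell calls
        if "input(" in code or "os.system(" in code:
            raise RuntimeError("blocked call")
        exec(code, self._global_vars)

    def eval_code(self, expr: str) -> Any:
        return eval(expr, self._global_vars)


def run_last_line_as_answer(runtime: MiniRuntime, lines: List[str]) -> Any:
    """Execute all lines but the last, then evaluate the last one as the answer."""
    runtime.exec_code("\n".join(lines[:-1]))
    return runtime.eval_code(lines[-1])
```

Usage: `run_last_line_as_answer(MiniRuntime(["import math"]), ["r = 2", "area = math.pi * r ** 2", "round(area, 2)"])` evaluates to `12.57`.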
42 | class PythonExecutor:
43 | def __init__(
44 | self,
45 | runtime: Optional[Any] = None,
46 | get_answer_symbol: Optional[str] = None,
47 | get_answer_expr: Optional[str] = None,
48 | get_answer_from_stdout: bool = False,
49 | ) -> None:
50 | self.runtime = runtime if runtime else GenericRuntime()
51 | self.answer_symbol = get_answer_symbol
52 | self.answer_expr = get_answer_expr
53 | self.get_answer_from_stdout = get_answer_from_stdout
54 |
55 | def process_generation_to_code(self, gens: str):
56 | batch_code = []
57 | for g in gens:
58 | multiline_comments = False
59 | code = []
60 | for line in g.split('\n'):
61 | strip_line = line.strip()
62 | if strip_line.startswith("#"):
63 | line = line.split("#", 1)[0] + "# comments"
64 | elif not multiline_comments and strip_line.startswith('"""') and strip_line.endswith('"""') and len(strip_line) >= 6:
65 | line = line.split('"""', 1)[0] + '"""comments"""'
66 | elif not multiline_comments and strip_line.startswith('"""'):
67 | multiline_comments = True
68 | elif multiline_comments and strip_line.endswith('"""'):
69 | multiline_comments = False
70 | line = ""
71 | if not multiline_comments:
72 | code.append(line)
73 | batch_code.append(code)
74 | return batch_code
75 |
76 | @staticmethod
77 | def execute(
78 | code,
79 | get_answer_from_stdout = None,
80 | runtime = None,
81 | answer_symbol = None,
82 | answer_expr = None,
83 | timeout_length = 10,
84 | ):
85 | try:
86 | if get_answer_from_stdout:
87 | program_io = io.StringIO()
88 | with redirect_stdout(program_io):
89 | timeout(timeout_length)(runtime.exec_code)('\n'.join(code))
90 | program_io.seek(0)
91 | result = "".join(program_io.readlines()) # [-1]
92 | elif answer_symbol:
93 | timeout(timeout_length)(runtime.exec_code)('\n'.join(code))
94 | result = runtime._global_vars[answer_symbol]
95 | elif answer_expr:
96 | timeout(timeout_length)(runtime.exec_code)('\n'.join(code))
97 | result = timeout(timeout_length)(runtime.eval_code)(answer_expr)
98 | else:
99 | timeout(timeout_length)(runtime.exec_code)('\n'.join(code[:-1]))
100 | result = timeout(timeout_length)(runtime.eval_code)(code[-1])
101 | concise_exec_info = ""
102 | exec_info = ""
103 | str(result)
104 | pickle.dumps(result) # serialization check
105 | except:
106 | # traceback.print_exc()
107 | result = ''
108 | concise_exec_info = traceback.format_exc().split('\n')[-2]
109 | exec_info = traceback.format_exc()
110 | if get_answer_from_stdout and 'exec(code_piece, self._global_vars)' in exec_info:
111 | exec_info = exec_info.split('exec(code_piece, self._global_vars)')[-1].strip()
112 | msg = []
113 | for line in exec_info.split("\n"):
114 | patt = regex.search(r'(?P<start>.*)File "(?P<file>.*)", line (?P<lno>\d+), (?P<end>.*)', line)
115 | if patt is not None:
116 | if '<module>' in patt.group('end'):
117 | continue
118 | fname = patt.group("file")
119 | if "site-packages" in fname:
120 | fname = f"site-packages{fname.split('site-packages', 1)[1]}"
121 | line = f'{patt.group("start")}File "{fname}", {patt.group("end")}'
122 | else:
123 | line = f'{patt.group("start")}{patt.group("end")}'
124 | else:
125 | patt = regex.search(r'(?P<start>.*)(?P<file>/.*site-packages/.*\.py)(?P<end>.*)', line)
126 | if patt is not None:
127 | line = f'{patt.group("start")}site-packages{patt.group("file").split("site-packages", 1)[1]}{patt.group("end")}'
128 | msg.append(line)
129 | exec_info = "\n".join(msg)
130 | return result, concise_exec_info, exec_info
131 |
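The traceback clean-up in the `except` branch of `execute` rewrites each `File "...", line N, ...` frame. The sketch below is a standalone version using the standard-library `re` module (which accepts the same `(?P<name>...)` syntax as `regex`). It assumes the skipped frames are the snippet's own `<module>` frames; the function name is hypothetical. Absolute `site-packages` paths are shortened and line numbers are dropped, which presumably keeps the messages fed back to the model stable across machines.

```python
import re
from typing import List

FRAME = re.compile(r'(?P<start>.*)File "(?P<file>.*)", line (?P<lno>\d+), (?P<end>.*)')
SITE_PATH = re.compile(r'(?P<start>.*)(?P<file>/.*site-packages/.*\.py)(?P<end>.*)')


def sanitize_traceback(exec_info: str) -> str:
    msg: List[str] = []
    for line in exec_info.split("\n"):
        frame = FRAME.search(line)
        if frame is not None:
            if "<module>" in frame.group("end"):
                continue  # drop top-level frames of the executed snippet
            fname = frame.group("file")
            if "site-packages" in fname:
                # keep only the path from "site-packages" on; drop the line number
                fname = "site-packages" + fname.split("site-packages", 1)[1]
                line = f'{frame.group("start")}File "{fname}", {frame.group("end")}'
            else:
                line = f'{frame.group("start")}{frame.group("end")}'
        else:
            path = SITE_PATH.search(line)
            if path is not None:
                tail = path.group("file").split("site-packages", 1)[1]
                line = f'{path.group("start")}site-packages{tail}{path.group("end")}'
        msg.append(line)
    return "\n".join(msg)
```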
132 | def apply(self, code):
133 | return self.batch_apply([code])[0]
134 |
135 | def batch_apply(self, batch_code):
136 | all_code_snippets = self.process_generation_to_code(batch_code)
137 | all_exec_results = []
138 | executor = partial(
139 | self.execute,
140 | get_answer_from_stdout=self.get_answer_from_stdout,
141 | runtime=self.runtime,
142 | answer_symbol=self.answer_symbol,
143 | answer_expr=self.answer_expr,
144 | timeout_length=10,
145 | )
146 | with ProcessPool(max_workers=multiprocess.cpu_count()) as pool:
147 | iterator = pool.map(executor, all_code_snippets, timeout=10).result()
148 |
149 | while True:
150 | try:
151 | result = next(iterator)
152 | all_exec_results.append(result)
153 | except StopIteration:
154 | break
155 | except TimeoutError as error:
156 | all_exec_results.append(("", "Timeout Error", "Timeout Error"))
157 | except Exception as error:
158 | print(error)
159 | exit()
160 |
161 | batch_results = []
162 | for code, (result, concise_exec_info, exec_info) in zip(all_code_snippets, all_exec_results):
163 | metadata = {'code': code, 'exec_result': result, 'concise_exec_info': concise_exec_info, 'exec_info': exec_info}
164 | batch_results.append((result, metadata))
165 | return batch_results
166 |
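`batch_apply` relies on `pebble.ProcessPool` so that a hung snippet can be killed after `timeout=10` and reported as `"Timeout Error"`. A stdlib-only way to get the same guarantee is to run each snippet in its own interpreter with `subprocess.run(..., timeout=...)`, which kills the child on expiry. The sketch below swaps in that technique. Each snippet gets a fresh interpreter, so no state is shared, and the result is read from stdout, as with `get_answer_from_stdout=True`. Function names are hypothetical.

```python
import subprocess
import sys
from concurrent.futures import ThreadPoolExecutor
from typing import List, Tuple


def run_snippet(code: str, timeout_s: float = 10.0) -> Tuple[str, str]:
    """Run one snippet in a fresh interpreter; return (stdout, concise_error)."""
    try:
        proc = subprocess.run(
            [sys.executable, "-c", code],
            capture_output=True, text=True, timeout=timeout_s,
        )
    except subprocess.TimeoutExpired:  # the child has already been killed
        return "", "Timeout Error"  # same marker batch_apply uses
    if proc.returncode != 0:
        # like concise_exec_info: keep only the last line of the traceback
        lines = [l for l in proc.stderr.splitlines() if l.strip()]
        return "", lines[-1] if lines else f"exit code {proc.returncode}"
    return proc.stdout.strip(), ""


def batch_run(snippets: List[str], timeout_s: float = 10.0, workers: int = 4) -> List[Tuple[str, str]]:
    # threads only wait on child processes, so a hung snippet cannot block the others
    with ThreadPoolExecutor(max_workers=workers) as pool:
        return list(pool.map(lambda c: run_snippet(c, timeout_s), snippets))
```

Usage: `batch_run(["print(6 * 7)", "while True: pass", "1 / 0"], timeout_s=2)` returns `("42", "")`, `("", "Timeout Error")` and `("", "ZeroDivisionError: division by zero")`.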
--------------------------------------------------------------------------------
/evaluation/eval/utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import tqdm
3 | from transformers import StoppingCriteria, GenerationConfig
4 |
5 | class KeyWordsCriteria(StoppingCriteria):
6 | def __init__(self, stop_id_sequences, tokenizer, prompt_length):
7 | assert isinstance(stop_id_sequences[0], list), "stop_id_sequences should be a list of list of ids"
8 | self.tokenizer = tokenizer
9 | self.stop_id_sequences = stop_id_sequences
10 | self.stop_sequences = [tokenizer.decode(sequence) for sequence in stop_id_sequences]
11 | print(f"stop sequences: {self.stop_sequences}", flush=True)
12 | self.prompt_length = prompt_length
13 |
14 | def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
15 | sequences_should_be_stopped = []
16 | for i in range(input_ids.shape[0]):
17 | ids = input_ids[i][self.prompt_length:].tolist()
18 | should_be_stopped = False
19 | for stop_ids, stop_sequence in zip(self.stop_id_sequences, self.stop_sequences):
20 | _ids = ids
21 | for j in range(len(_ids), 0, -1):
22 | s = self.tokenizer.decode(_ids[max(j - len(stop_ids) - 3, 0) :j])
23 | if s.endswith(stop_sequence):
24 | should_be_stopped = True
25 | break
26 | if should_be_stopped:
27 | break
28 | sequences_should_be_stopped.append(should_be_stopped)
29 | return all(sequences_should_be_stopped)
30 |
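`KeyWordsCriteria` decodes a short window of generated ids ending at every position and checks whether the decoded text ends with a stop sequence. Generation halts only when *all* rows in the batch have matched. The sketch below reproduces that check without `torch` or `transformers`. It uses a hypothetical toy vocabulary in place of a real tokenizer, and `ids` stands for the tokens after the prompt, mirroring `input_ids[i][self.prompt_length:]`.

```python
from typing import Callable, List, Sequence

# hypothetical toy vocabulary standing in for a real tokenizer
VOCAB = {0: "The", 1: " answer", 2: " is", 3: " 5", 4: ".", 5: "\n", 6: "Q", 7: ":"}


def decode(ids: Sequence[int]) -> str:
    return "".join(VOCAB[i] for i in ids)


def hits_stop_sequence(
    ids: Sequence[int],
    stop_id_sequences: List[List[int]],
    decode: Callable[[Sequence[int]], str],
) -> bool:
    """True if the decoded text ends with a stop sequence at some position."""
    for stop_ids in stop_id_sequences:
        stop_text = decode(stop_ids)
        for j in range(len(ids), 0, -1):
            # the +3 slack (from the original) widens the window, presumably so a
            # differently tokenized stop string still fits inside it
            window = ids[max(j - len(stop_ids) - 3, 0):j]
            if decode(window).endswith(stop_text):
                return True
    return False


def batch_should_stop(batch_ids, stop_id_sequences, decode) -> bool:
    # generation stops only once every sequence in the batch has hit a stop
    return all(hits_stop_sequence(ids, stop_id_sequences, decode) for ids in batch_ids)
```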
31 | @torch.no_grad()
32 | def generate_completions(model, tokenizer, prompts, batch_size=1, stop_id_sequences=None, end_of_generation_id_sequence=None, disable_tqdm=False, **generation_kwargs):
33 | generations = []
34 | finish_completion = []
35 | if not disable_tqdm:
36 | progress = tqdm.tqdm(total=len(prompts), desc="Generating Completions")
37 |
38 | if stop_id_sequences is not None:
39 | stop_sequences = [tokenizer.decode(stop_id_sequence) for stop_id_sequence in stop_id_sequences]
40 |
41 | if end_of_generation_id_sequence is not None:
42 | end_of_generation_sequence = tokenizer.decode(end_of_generation_id_sequence)
43 |
44 | num_return_sequences = generation_kwargs.get("num_return_sequences", 1)
45 | generation_kwargs['use_cache'] = True
46 | for i in range(0, len(prompts), batch_size):
47 | batch_prompts = prompts[i:i+batch_size]
48 | tokenized_prompts = tokenizer(batch_prompts, padding="longest", return_tensors="pt", add_special_tokens='chatglm2' in str(model.__class__))
49 | batch_input_ids = tokenized_prompts.input_ids
50 | attention_mask = tokenized_prompts.attention_mask
51 |
52 | if model.device.type == "cuda":
53 | batch_input_ids = batch_input_ids.cuda()
54 | attention_mask = attention_mask.cuda()
55 |
56 | batch_finish_completion = [False] * len(batch_prompts) * num_return_sequences
57 | try:
58 | batch_outputs = model.generate(
59 | input_ids=batch_input_ids,
60 | attention_mask=attention_mask,
61 | stopping_criteria=[KeyWordsCriteria(stop_id_sequences, tokenizer, batch_input_ids.size(1))] if stop_id_sequences else None,
62 | **generation_kwargs
63 | )
64 |
65 | # the stopping criteria are applied at the batch level, so the whole batch keeps generating until every example has stopped.
66 | # as a result, some outputs contain a stop sequence followed by extra text, which we need to remove.
67 | if stop_id_sequences:
68 | for output_idx in range(batch_outputs.shape[0]):
69 | finish = False
70 | for token_idx in range(batch_input_ids.shape[1], batch_outputs.shape[1]):
71 | if any(tokenizer.decode(batch_outputs[output_idx, token_idx: token_idx + len(stop_sequence) + 3]).startswith(stop_sequence) for stop_sequence in stop_sequences):
72 | if end_of_generation_id_sequence is not None and tokenizer.decode(batch_outputs[output_idx, token_idx: token_idx + len(end_of_generation_id_sequence) + 3]).startswith(end_of_generation_sequence):
73 | batch_finish_completion[output_idx] = True
74 | batch_outputs[output_idx, token_idx:] = tokenizer.pad_token_id
75 | break
76 |
77 | # remove the prompt from the output
78 | # we need to re-encode the prompt because we need to make sure the special tokens are treated the same way as in the outputs.
79 | # we no longer truncate the output token ids directly, because some tokenizers (e.g., llama) do not add a space token before the first token.
80 | # space is important for some tasks (e.g., code completion).
81 | batch_outputs = tokenizer.batch_decode(batch_outputs, skip_special_tokens=True)
82 | batch_prompts = tokenizer.batch_decode(batch_input_ids, skip_special_tokens=True)
83 | # duplicate the prompts to match the number of return sequences
84 | batch_prompts = [prompt for prompt in batch_prompts for _ in range(num_return_sequences)]
85 | batch_generations = [
86 | output[len(prompt):] for prompt, output in zip(batch_prompts, batch_outputs)
87 | ]
88 | except Exception as e:
89 | print("Error when generating completions for batch:")
90 | print(batch_prompts)
91 | print("Error message:")
92 | print(e)
93 | print("Use empty string as the completion.")
94 | batch_generations = [""] * len(batch_prompts) * num_return_sequences
95 |
96 | generations += batch_generations
97 | finish_completion += batch_finish_completion
98 |
99 | if not disable_tqdm:
100 | progress.update(len(batch_prompts)//num_return_sequences)
101 |
102 | assert len(generations) == len(prompts) * num_return_sequences, "number of generations should be equal to number of prompts * num_return_sequences"
103 | return generations, finish_completion
104 |
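After `model.generate`, `generate_completions` post-processes each row in two steps. It cuts the text at the first stop sequence, and it removes the prompt by slicing off the length of the re-decoded prompt string. The original cuts on token ids by overwriting everything from the stop position with `pad_token_id`, which the `skip_special_tokens=True` decode then drops when the pad token is a special token. The sketch below does the equivalent on decoded strings. It is a simplification with hypothetical names, and its boolean only reports whether a stop was found, not the `end_of_generation` check.

```python
from typing import List, Tuple


def strip_prompt(prompt: str, decoded_output: str) -> str:
    # the generation is whatever follows the (re-decoded) prompt text
    return decoded_output[len(prompt):]


def truncate_at_stop(text: str, stop_sequences: List[str]) -> Tuple[str, bool]:
    """Cut a completion at its earliest stop sequence; report whether one was found."""
    cut = len(text)
    for stop in stop_sequences:
        idx = text.find(stop)
        if idx != -1:
            cut = min(cut, idx)
    return text[:cut], cut < len(text)
```

Usage: for a prompt `"Q: 1+1?\nA:"` whose decoded output continues with `" ... The answer is 2.\n\nQ: 2+2?"`, stripping the prompt and truncating at `"\n\nQ:"` leaves only the answer to the first question.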
105 |
106 | @torch.no_grad()
107 | def get_next_word_predictions(model, tokenizer, prompts, candidate_token_ids=None, batch_size=1, return_token_predictions=False, disable_tqdm=False):
108 | predictions, probs = [], []
109 | if not disable_tqdm:
110 | progress = tqdm.tqdm(total=len(prompts), desc="Getting Predictions")
111 |
112 | for i in range(0, len(prompts), batch_size):
113 | batch_prompts = prompts[i: i+batch_size]
114 | tokenized_prompts = tokenizer(batch_prompts, padding="longest", return_tensors="pt", add_special_tokens=False)
115 | batch_input_ids = tokenized_prompts.input_ids
116 | attention_mask = tokenized_prompts.attention_mask
117 |
118 | if model.device.type == "cuda":
119 | batch_input_ids = batch_input_ids.cuda()
120 | attention_mask = attention_mask.cuda()
121 |
122 | batch_logits = model(input_ids=batch_input_ids, attention_mask=attention_mask).logits[:, -1, :]
123 | if candidate_token_ids is not None:
124 | batch_logits = batch_logits[:, candidate_token_ids]
125 | batch_probs = torch.softmax(batch_logits, dim=-1)
126 | batch_prediction_indices = torch.argmax(batch_probs, dim=-1)
127 | if return_token_predictions:
128 | if candidate_token_ids is not None:
129 | candidate_tokens = tokenizer.convert_ids_to_tokens(candidate_token_ids)
130 | batch_predictions = [candidate_tokens[idx] for idx in batch_prediction_indices]
131 | else:
132 | batch_predictions = tokenizer.convert_ids_to_tokens(batch_prediction_indices)
133 | predictions += batch_predictions
134 | else:
135 | predictions += batch_prediction_indices.tolist()
136 | probs += batch_probs.tolist()
137 |
138 | if not disable_tqdm:
139 | progress.update(len(batch_prompts))
140 |
141 | assert len(predictions) == len(prompts), "number of predictions should be equal to number of prompts"
142 | return predictions, probs
143 |
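`get_next_word_predictions` takes the last-position logits, optionally restricts them to `candidate_token_ids`, applies a softmax, and takes the argmax. One subtlety follows from slicing before `argmax`. When candidates are given and `return_token_predictions=False`, the returned index is a position *within the candidate list*, not a vocabulary id. A pure-Python sketch of one row follows (the function name is hypothetical).

```python
import math
from typing import List, Optional, Sequence, Tuple


def next_token_choice(
    logits: Sequence[float], candidate_ids: Optional[Sequence[int]] = None
) -> Tuple[int, List[float]]:
    """Softmax over (optionally) a candidate subset of the vocabulary, then argmax."""
    if candidate_ids is not None:
        logits = [logits[i] for i in candidate_ids]
    peak = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(x - peak) for x in logits]
    total = sum(exps)
    probs = [e / total for e in exps]
    best = max(range(len(probs)), key=probs.__getitem__)
    return best, probs
```

With logits `[1.0, 0.5, 2.0, 3.0]` and `candidate_ids=[0, 2]`, the result is index `1`, meaning vocabulary id `2`. Without candidates, the argmax is vocabulary id `3`.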
144 |
145 | @torch.no_grad()
146 | def score_completions(model, tokenizer, scoring_examples, disable_tqdm=False):
147 | '''
148 | Each scoring example is a dict, which contains the following keys:
149 | - prompt: the prompt to score
150 | - completions: a list of completions to score
151 | '''
152 |
153 | if not disable_tqdm:
154 | progress = tqdm.tqdm(total=len(scoring_examples), desc="Scoring Completions")
155 |
156 | # unroll the scoring examples
157 | unrolled_examples = []
158 | for scoring_example in scoring_examples:
159 | prompt = scoring_example["prompt"]
160 | for completion in scoring_example["completions"]:
161 | unrolled_examples.append({
162 | "prompt": prompt,
163 | "completion": completion
164 | })
165 |
166 | scores = []
167 | # currently we don't support batching, because we want to directly use the loss returned by the model to score each completion.
168 | for unrolled_example in unrolled_examples:
169 | encoded_example = encode_with_prompt_completion_format(unrolled_example, tokenizer, max_seq_length=None)
170 | # unsqueeze the batch dimension
171 | for key, value in encoded_example.items():
172 | encoded_example[key] = value.unsqueeze(0)
173 | if model.device.type == "cuda":
174 | encoded_example = {
175 | key: value.cuda() for key, value in encoded_example.items()
176 | }
177 | outputs = model(**encoded_example)
178 | loss = outputs.loss
179 | scores.append(-loss.item())
180 | if not disable_tqdm:
181 | progress.update(1)
182 |
183 | # roll up the scores
184 | rolled_up_scores = {}
185 | for unrolled_example, score in zip(unrolled_examples, scores):
186 | prompt = unrolled_example["prompt"]
187 | completion = unrolled_example["completion"]
188 | if prompt not in rolled_up_scores:
189 | rolled_up_scores[prompt] = {}
190 | rolled_up_scores[prompt][completion] = score
191 |
192 | return rolled_up_scores
193 |
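`score_completions` scores each (prompt, completion) pair as `-loss` and rolls the scores up into `{prompt: {completion: score}}`. Assuming the loss is the mean negative log-likelihood over the completion tokens (this depends on `encode_with_prompt_completion_format`, which is not shown here), the score is the mean per-token log-probability. The sketch below starts from hypothetical per-token log-probabilities instead of a model. Because the score is length-normalized, a longer completion is not penalized merely for having more tokens.

```python
from collections import defaultdict
from typing import Dict, List, Tuple


def roll_up_scores(
    scored: List[Tuple[str, str, List[float]]]
) -> Dict[str, Dict[str, float]]:
    """scored: (prompt, completion, per-token log-probs of the completion)."""
    rolled: Dict[str, Dict[str, float]] = defaultdict(dict)
    for prompt, completion, token_logprobs in scored:
        # -(mean NLL) == mean log-prob, mirroring score = -loss
        rolled[prompt][completion] = sum(token_logprobs) / len(token_logprobs)
    return dict(rolled)


def best_completion(rolled: Dict[str, Dict[str, float]], prompt: str) -> str:
    return max(rolled[prompt], key=rolled[prompt].get)
```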
194 |
195 |
196 | def load_hf_lm_and_tokenizer(
197 | model_name_or_path,
198 | tokenizer_name_or_path=None,
199 | device_map="auto",
200 | load_in_8bit=False,
201 | load_in_half=False,
202 | gptq_model=False,
203 | use_fast_tokenizer=True,
204 | padding_side="left",
205 | ):
206 |
207 | from transformers import AutoModelForCausalLM, AutoModel, AutoTokenizer
208 |
209 | if not tokenizer_name_or_path:
210 | tokenizer_name_or_path = model_name_or_path
211 |
212 | is_chatglm2 = 'chatglm2' in tokenizer_name_or_path.lower() or 'chatglm2' in model_name_or_path.lower()
213 | is_qwen = 'qwen' in tokenizer_name_or_path.lower() or 'qwen' in model_name_or_path.lower()
214 |
215 | if is_chatglm2 or is_qwen:
216 | tokenizer = AutoTokenizer.from_pretrained(tokenizer_name_or_path, trust_remote_code=True)
217 | if is_qwen:
218 | tokenizer.eos_token = '<|endoftext|>'
219 | tokenizer.eos_token_id = 151643
220 | tokenizer.pad_token = tokenizer.eos_token
221 | tokenizer.pad_token_id = tokenizer.eos_token_id
222 | else:
223 | tokenizer = AutoTokenizer.from_pretrained(tokenizer_name_or_path, trust_remote_code=True, use_fast=use_fast_tokenizer)
224 | # set padding side to left for batch generation
225 | tokenizer.padding_side = padding_side
226 | # set pad token to eos token if pad token is not set (as is the case for llama models)
227 | if tokenizer.pad_token is None:
228 | tokenizer.pad_token = tokenizer.eos_token
229 | tokenizer.pad_token_id = tokenizer.eos_token_id
230 |
231 | if gptq_model:
232 | from auto_gptq import AutoGPTQForCausalLM
233 | model_wrapper = AutoGPTQForCausalLM.from_quantized(
234 | model_name_or_path, device="cuda:0", use_triton=True
235 | )
236 | model = model_wrapper.model
237 | elif load_in_8bit:
238 | model = AutoModelForCausalLM.from_pretrained(
239 | model_name_or_path,
240 | device_map=device_map,
241 | load_in_8bit=True
242 | )
243 | else:
244 | kwargs = {}
245 | model_class = AutoModelForCausalLM
246 | if is_chatglm2:
247 | kwargs = {'trust_remote_code': True}
248 | model_class = AutoModel
249 | elif is_qwen:
250 | kwargs = {'trust_remote_code': True}
251 | if device_map:
252 | model = model_class.from_pretrained(model_name_or_path, device_map=device_map, **kwargs)
253 | else:
254 | model = model_class.from_pretrained(model_name_or_path, **kwargs)
255 | if torch.cuda.is_available():
256 | model = model.cuda()
257 | if is_qwen:
258 | model.generation_config = GenerationConfig.from_pretrained(model_name_or_path, trust_remote_code=True)
259 | model.generation_config.do_sample = False
260 | if not is_chatglm2 and not is_qwen and load_in_half:
261 | model = model.half()
262 | model.eval()
263 | return model, tokenizer
264 |
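`load_hf_lm_and_tokenizer` defaults to `padding_side="left"` and falls back to the EOS token for padding. Left padding puts every prompt's last real token in the final column. This matters for `generate_completions`, whose stop-trimming loop starts scanning every row at `batch_input_ids.shape[1]`, which is the start of the generated text only when all prompts end at the same column. The sketch below shows what left padding produces. `pad_id=2` is a hypothetical EOS id.

```python
from typing import List, Tuple


def left_pad(batch: List[List[int]], pad_id: int) -> Tuple[List[List[int]], List[List[int]]]:
    """Left-pad token-id lists to equal length and build the attention mask."""
    width = max(len(ids) for ids in batch)
    input_ids, attention_mask = [], []
    for ids in batch:
        n_pad = width - len(ids)
        input_ids.append([pad_id] * n_pad + ids)  # padding goes in front
        attention_mask.append([0] * n_pad + [1] * len(ids))  # 0 = ignore padding
    return input_ids, attention_mask
```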
--------------------------------------------------------------------------------
/evaluation/evaluation_results.json:
--------------------------------------------------------------------------------
1 | {
2 | "DeepSeekMath-Base": {
3 | "OCWCourses": {
4 | "cot": {
5 | "accuracy": 0.15441176470588236,
6 | "n_samples": 272
7 | },
8 | "tool": {
9 | "n_samples": 0
10 | }
11 | },
12 | "cmath-cot-test": {
13 | "cot": {
14 | "accuracy": 0.7167577413479053,
15 | "n_samples": 1098
16 | },
17 | "tool": {
18 | "n_samples": 0
19 | }
20 | },
21 | "miniF2F-Isabelle-test": {
22 | "cot": {
23 | "accuracy": 1.0,
24 | "n_samples": 244
25 | },
26 | "tool": {
27 | "n_samples": 0
28 | }
29 | },
30 | "gsm8k-cot-test": {
31 | "cot": {
32 | "accuracy": 0.6421531463229719,
33 | "n_samples": 1319
34 | },
35 | "tool": {
36 | "n_samples": 0
37 | }
38 | },
39 | "MMLU-STEM-test": {
40 | "cot": {
41 | "accuracy": 0.5646123260437376,
42 | "n_samples": 3018
43 | },
44 | "tool": {
45 | "n_samples": 0
46 | }
47 | },
48 | "agieval-gaokao-mathqa-cot-test": {
49 | "cot": {
50 | "accuracy": 0.35327635327635326,
51 | "n_samples": 351
52 | },
53 | "tool": {
54 | "n_samples": 0
55 | }
56 | },
57 | "agieval-gaokao-mathcloze-cot-test": {
58 | "cot": {
59 | "accuracy": 0.2033898305084746,
60 | "n_samples": 118
61 | },
62 | "tool": {
63 | "n_samples": 0
64 | }
65 | },
66 | "gsm8k-pal-test": {
67 | "cot": {
68 | "n_samples": 0
69 | },
70 | "tool": {
71 | "accuracy": 0.66868840030326,
72 | "n_samples": 1319
73 | }
74 | },
75 | "math_sat": {
76 | "cot": {
77 | "accuracy": 0.84375,
78 | "n_samples": 32
79 | },
80 | "tool": {
81 | "n_samples": 0
82 | }
83 | },
84 | "miniF2F-Isabelle-valid": {
85 | "cot": {
86 | "accuracy": 1.0,
87 | "n_samples": 244
88 | },
89 | "tool": {
90 | "n_samples": 0
91 | }
92 | },
93 | "math-pal-test": {
94 | "cot": {
95 | "n_samples": 0
96 | },
97 | "tool": {
98 | "accuracy": 0.3142,
99 | "n_samples": 5000
100 | }
101 | },
102 | "math-cot-test": {
103 | "cot": {
104 | "accuracy": 0.3618,
105 | "n_samples": 5000
106 | },
107 | "tool": {
108 | "n_samples": 0
109 | }
110 | }
111 | },
112 | "DeepSeekMath-RL": {
113 | "mgsm-zh": {
114 | "cot": {
115 | "accuracy": 0.796,
116 | "n_samples": 250
117 | },
118 | "tool": {
119 | "accuracy": 0.784,
120 | "program_accuracy": 0.776,
121 | "n_samples": 250
122 | }
123 | },
124 | "cmath": {
125 | "cot": {
126 | "accuracy": 0.8879781420765027,
127 | "n_samples": 1098
128 | },
129 | "tool": {
130 | "accuracy": 0.8761384335154827,
131 | "program_accuracy": 0.8570127504553734,
132 | "n_samples": 1098
133 | }
134 | },
135 | "math-test": {
136 | "cot": {
137 | "accuracy": 0.517,
138 | "n_samples": 5000
139 | },
140 | "tool": {
141 | "accuracy": 0.5878,
142 | "program_accuracy": 0.509,
143 | "n_samples": 5000
144 | }
145 | },
146 | "gsm8k-test": {
147 | "cot": {
148 | "accuracy": 0.8824867323730099,
149 | "n_samples": 1319
150 | },
151 | "tool": {
152 | "accuracy": 0.866565579984837,
153 | "program_accuracy": 0.868081880212282,
154 | "n_samples": 1319
155 | }
156 | }
157 | },
158 | "DeepSeekMath-Instruct": {
159 | "gsm8k-test": {
160 | "cot": {
161 | "accuracy": 0.8286580742987112,
162 | "n_samples": 1319
163 | },
164 | "tool": {
165 | "accuracy": 0.8369977255496588,
166 | "program_accuracy": 0.8332069749810462,
167 | "n_samples": 1319
168 | }
169 | },
170 | "math-test": {
171 | "cot": {
172 | "accuracy": 0.4682,
173 | "n_samples": 5000
174 | },
175 | "tool": {
176 | "accuracy": 0.575,
177 | "program_accuracy": 0.4664,
178 | "n_samples": 5000
179 | }
180 | },
181 | "cmath": {
182 | "cot": {
183 | "accuracy": 0.8460837887067395,
184 | "n_samples": 1098
185 | },
186 | "tool": {
187 | "accuracy": 0.843351548269581,
188 | "program_accuracy": 0.8214936247723132,
189 | "n_samples": 1098
190 | }
191 | },
192 | "mgsm-zh": {
193 | "cot": {
194 | "accuracy": 0.732,
195 | "n_samples": 250
196 | },
197 | "tool": {
198 | "accuracy": 0.72,
199 | "program_accuracy": 0.716,
200 | "n_samples": 250
201 | }
202 | }
203 | }
204 | }
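Each entry in `evaluation_results.json` stores an `accuracy` fraction and `n_samples` per model, dataset and mode (`cot` or `tool`). A mode with `n_samples: 0` was not evaluated. A small helper (hypothetical, not part of the repository) can flatten the file into percentages and correct-answer counts. For example, the RL model's `math-test` chain-of-thought accuracy of 0.517 over 5000 problems is the 51.7% quoted in the introduction, or 2585 correct answers.

```python
from typing import Dict, List, Tuple


def summarize(results: Dict) -> List[Tuple[str, str, str, float, int]]:
    """Flatten the results into (model, dataset, mode, accuracy %, #correct)."""
    rows = []
    for model, datasets in results.items():
        for dataset, modes in datasets.items():
            for mode, stats in modes.items():
                if stats.get("n_samples", 0) == 0:
                    continue  # mode not evaluated for this dataset
                acc = stats["accuracy"]
                rows.append((model, dataset, mode, round(100 * acc, 1), round(acc * stats["n_samples"])))
    return rows
```

Usage, assuming the path shown in this listing: `with open("evaluation/evaluation_results.json") as f: rows = summarize(json.load(f))` (after `import json`).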
--------------------------------------------------------------------------------
/evaluation/few_shot_prompts/__init__.py:
--------------------------------------------------------------------------------
1 | from .cot_minerva_math_4_shot import MinervaMathPrompt
2 | from .cot_gsm_8_shot import CoTGSMPrompt
3 | from .cot_math_sat_4_shot import CoTSATPrompt
4 | from .cot_mmlu_stem_4_shot import MMLUSTEMPrompt
5 | from .cot_ocwcourses_4_shot import OCWCoursesPrompt
6 | from .pal_gsm_8_shot import PALGSMPrompt
7 | from .pal_math_4_shot import PALMathPrompt
8 | from .minif2f_isabelle import MiniF2FIsabellePrompt
9 | from .cot_cmath_6_shot import CoTCMATHPrompt
10 | from .cot_gaokao_mathcloze_5_shot import CoTGaoKaoMathClozePrompt
11 | from .cot_gaokao_mathqa_5_shot import CoTGaoKaoMathQAPrompt
--------------------------------------------------------------------------------
/evaluation/few_shot_prompts/cot_cmath_6_shot.py:
--------------------------------------------------------------------------------
1 | from .few_shot_prompting import FewShotPrompting
2 |
3 | few_shot_prompt = """
4 | 问题:芳芳买了一本书有99页,看了90页,她还剩多少页没有看?
5 | 答案:还剩的没有看的页数=书的总页数-芳芳看了的页数,99-90=9。所以答案是:9。
6 |
7 |
8 | 问题:张师傅上午修了18把椅子,下午修了29把椅子,一天共修了多少把椅子?
9 | 答案:一天共修的椅子数量=上午修的椅子数量+下午修的椅子数量,18+29=47。所以答案是:47。
10 |
11 |
12 | 问题:小猴摘了84个桃子,平均分给6只猴子,每只猴子能吃到几个桃子?
13 | 答案:每只猴子能吃到的桃子数=总桃子数/猴子的数量,84/6=14。所以答案是:14。
14 |
15 |
16 | 问题:用面包机烤面包时,第一面烤2分钟,第二面只要烤1分钟,即烤一片面包需要3分钟,小勤的面包机一次只能放2片,他每天早上吃3片面包,至少需要烤多少分钟?
17 | 答案:可以先将两片面包放入面包机烤2分钟,再将其中一片拿出来,将第三片面包放进去,烤1分钟,这样第一片面包就烤好了,将第一片面包拿出来将第二片面包放进去,继续烤1分钟,于是第二片面包也烤好了将其拿出来,第三片面包再烤1分钟也就烤好了,一共是2+1+1=5。所以答案是:5。
18 |
19 |
20 | 问题:一组学生植树,每人栽6棵还剩4棵;如果其中3人各栽5棵,其余每人各栽7棵,正好栽完。这一组学生有多少人?
21 | 答案:假设学生的数量是x,每人栽6棵还剩4棵,也就是说树苗的数量=6x+4,又知道如果其中3人各栽5棵,其余每人各栽7棵,正好栽完,即6x+4=3*5+(x-3)*7,化简方程得到:x=10。所以答案是:10。
22 |
23 |
24 | 问题:某小学在“献爱心--为汶川地震区捐款”活动中,六年级五个班共捐款8000元,其中一班捐款1500元,二班比一班多捐款200元,三班捐款1600元,四班与五班捐款数之比是3:5.四班捐款多少元?
25 | 答案:一班捐款1500元,而二班比一班多捐200元,所以二班捐款1500+200=1700元,又知道六年级五个班一共捐款8000元,所以四班和五班捐款之和 = 一共捐款 - 一班和二班和三班捐款之和,即8000-1500-1700-1600=3200元,而题目说四班与五班捐款数之比是3:5,则四班捐款了3200/(3+5)*3=1200元。所以答案是:1200。
26 | """.strip()
27 |
28 | class CoTCMATHPrompt(FewShotPrompting):
29 | def __init__(self):
30 | super().__init__()
31 |
32 | def format_prompt(self, task_input, task_output):
33 | prompt = f"{few_shot_prompt}\n\n\n问题:{task_input}\n答案:{task_output}"
34 | return prompt.rstrip()
35 |
36 | def stop_words(self):
37 | return ["\n问题:"]
38 |
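Every CMATH exemplar above ends with the fixed phrase 所以答案是:N。 ("so the answer is: N."), and `stop_words` halts generation at the next `\n问题:` ("Question:"). The answer-parsing code is not part of this excerpt, so the extractor below is only a hypothetical sketch of how that phrase could be read back. It accepts either an ASCII or a full-width colon and takes the last occurrence.

```python
import re
from typing import Optional

# matches "所以答案是:47。" with either an ASCII ":" or a full-width ":"
ANSWER_RE = re.compile(r"所以答案是[::]\s*([^。\n]+)")


def extract_cmath_answer(completion: str) -> Optional[str]:
    """Return the text after the last answer marker, without the trailing '。'."""
    matches = ANSWER_RE.findall(completion)
    return matches[-1].strip() if matches else None
```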
--------------------------------------------------------------------------------
/evaluation/few_shot_prompts/cot_gaokao_mathcloze_5_shot.py:
--------------------------------------------------------------------------------
1 | from .few_shot_prompting import FewShotPrompting
2 |
3 | few_shot_prompt = """
4 | 问题 1. 设数列 $\\left\\{a_{n}\\right\\}$ 的前 $n$ 项和为 $S_{n}$, 且 $a_{1}=-1, a_{n+1}=S_{n+1} S_{n}$, 则 $S_{n}=(\\quad)$.
5 | 问题 1的解析: 让我们写出这个数列的前n项和:
6 | $S_n = a_1 + a_2 + ... + a_n$
7 | $S_n = -1 + (S_2 S_1) + (S_3 S_2) + ... + (S_{n+1} S_n)$
8 | $S_n = -1 + (S_2 S_1) + (S_3 S_2) + ... + (S_n S_{n-1}) + (S_{n+1} S_n)$
9 | $S_n = -1 + S_n (S_{n+1} - S_1)$
10 | $S_n - S_n S_{n+1} = -1 - S_n$
11 | $S_n (1 - S_{n+1}) = -1 - S_n$
12 | $S_n = -\\frac{1}{1 - S_{n+1}}$
13 | 因为这个数列后面的所有项都是0,我们可以看到对于所有$n\\geq 1$,$S_{n+1} = 0$。因此,我们有:
14 | $S_n = -\\frac{1}{1 - S_{n+1}} = -\\frac{1}{1 - 0} = -1$
15 | 这个数列前n项和的公式是$S_n = -\\frac{1}{n}$。
16 | 答案是 $-\\frac{1}{n}$
17 |
18 |
19 | 问题 2. 若 $\\left(x+\\frac{1}{x}\\right)^{n}$ 的展开式中第 3 项与第 7 项的二项式系数相等, 则该展开式中 $\\frac{1}{x^{2}}$ 的系数为 $(\\quad)$.
20 | 问题 2的解析: 由题意可得, $c_{n}^{2}=c_{n}^{6}$
21 | $\\therefore n=8$
22 | 展开式的通项 $T_{r+1}=C_8^r x^{8-r}\\left(\\frac{1}{x}\\right)^r=C_8^r x^{8-2 r}$
23 | 令 $8-2 r=-2$ 可得 $r=5$
24 | 此时系数为 $c_{8}^{5}=56$
25 | 答案是 56
26 |
27 |
28 | 问题 3. 函数 $\\mathrm{f}(\\mathrm{x})=\\sin (\\mathrm{x}+2 \\phi)-2 \\sin \\phi \\cos (\\mathrm{x}+\\phi)$ 的最大值为 $(\\quad)$.
29 | 问题 3的解析: 函数 $f(x)=\\sin (x+2 \\phi)-2 \\sin \\phi \\cos (x+\\phi)=\\sin [(x+\\phi)+\\phi]-$ $2 \\sin \\phi \\cos (x+\\phi)$
30 | $=\\sin (x+\\phi) \\cos \\phi+\\cos (x+\\phi) \\sin \\phi-2 \\sin \\phi \\cos (x+\\phi)=\\sin (x+\\phi) \\cos \\phi-\\cos$ $(x+\\phi) \\sin \\phi$ $=\\sin [(x+\\phi)-\\phi]=\\sin x$
31 | 故函数 $f(x)$ 的最大值为 1
32 | 答案是 1
33 |
34 |
35 | 问题 4. 已知向量 $\\vec{a}=(3,1), \\vec{b}=(1,0), \\vec{c}=\\vec{a}+k \\vec{b}$. 若 $\\vec{a} \\perp \\vec{c}$, 则 $k=(\\quad)$
36 | 问题 4的解析: $\\because \\vec{a}=(3,1), \\vec{b}=(1,0), \\therefore \\vec{c}=\\vec{a}+k \\vec{b}=(3+k, 1)$ ,
37 | $\\because \\vec{a} \\perp \\vec{c}, \\therefore \\vec{a} \\cdot \\vec{c}=3(3+k)+1 \\times 1=0$, 解得 $k=-\\frac{10}{3}$
38 | 答案是 $-\\frac{10}{3}$
39 |
40 |
41 | 问题 5. 设向量 $\\vec{a}, \\vec{b}$ 不平行, 向量 $\\lambda \\vec{a}+\\vec{b}$ 与 $\\vec{a}+2 \\vec{b}$ 平行, 则实数 $\\lambda=(\\quad)$.
42 | 问题 5的解析: $\\because$ 向量 $\\vec{a}, \\vec{b}$ 不平行, 向量 $\\lambda \\vec{a}+\\vec{b}$ 与 $\\vec{a}+2 \\vec{b}$ 平行,
43 | $\\therefore \\lambda \\vec{a}+\\vec{b}=t(\\vec{a}+2 \\vec{b})=t \\vec{a}+2 t \\vec{b}$
44 | $\\therefore\\left\\{\\begin{array}{c}\\lambda=\\mathrm{t} \\\\ 1=2 \\mathrm{t},\\end{array}\\right.$ 解得实数 $\\lambda=\\frac{1}{2}$.
45 | 答案是 $\\frac{1}{2}$
46 | """.strip()
47 |
48 | class CoTGaoKaoMathClozePrompt(FewShotPrompting):
49 | def __init__(self):
50 | super().__init__()
51 |
52 | def format_prompt(self, task_input, task_output):
53 | prompt = f"{few_shot_prompt}\n\n\n问题 6. {task_input}\n问题 6的解析: {task_output}"
54 | return prompt.rstrip()
55 |
56 | def stop_words(self):
57 | return ["\n问题 "]
58 |
--------------------------------------------------------------------------------
/evaluation/few_shot_prompts/cot_gaokao_mathqa_5_shot.py:
--------------------------------------------------------------------------------
1 | from .few_shot_prompting import FewShotPrompting
2 |
3 | few_shot_prompt = """
4 | 问题 1. 已知 $\\alpha, \\beta, \\gamma$ 是互不相同的锐角, 则在 $\\sin \\alpha \\cos \\beta, \\sin \\beta \\cos \\gamma, \\sin \\gamma \\cos \\alpha$ 三个值中, 大于 $\\frac{1}{2}$ 的个数的最大值是 ($\\quad$)
5 | 从以下选项中选择: (A)0 (B)1 (C)2 (D)3
6 | 问题 1的解析: 1. 如果 $\\alpha, \\beta, \\gamma$ 均小于 $60^\\circ$,那么他们的正弦值都小于 $\\frac{1}{2}$,因此三个值中不可能有大于 $\\frac{1}{2}$ 的值。
7 | 2. 如果有一个角大于 $60^\\circ$,假设为 $\\alpha$,那么对应的正弦值大于 $\\frac{1}{2}$。此时,由于三角形内角和为 $180^\\circ$,所以 $\\beta + \\gamma < 120^\\circ$。这意味着 $\\beta, \\gamma$ 的余弦值均大于 $\\frac{1}{2}$,所以此时 $\\sin \\alpha \\cos \\beta > \\frac{1}{2}, \\sin \\beta \\cos \\gamma > \\frac{1}{2}$。
8 | 3. 如果有两个角大于 $60^\\circ$,例如 $\\alpha$ 和 $\\beta$,那么由于三角形内角和为 $180^\\circ$,我们可以得到 $\\gamma < 60^\\circ$,此时 $\\sin \\gamma < \\frac{1}{2}$。由于 $\\alpha$ 和 $\\beta$
9 | 的余弦值都小于 $\\frac{1}{2}$,因此三个值中不可能有大于 $\\frac{1}{2}$ 的值。
10 | 4. 如果三个角都大于 $60^\\circ$,显然不符合题意。
11 | 综上所述,当有一个角大于 $60^\\circ$ 时,大于 $\\frac{1}{2}$ 的个数的最大值是 2。
12 | 答案是 C
13 |
14 |
15 | 问题 2. 正方体 $A B C D-A_{1} B_{1} C_{1} D_{1}$ 中, $B B_{1}$ 与平面 $A C D_{1}$ 所成角的余弦值为 ($\\qquad$)
16 | 从以下选项中选择: (A)$\\frac{\\sqrt{2}}{3}$ (B)$\\frac{\\sqrt{3}}{3}$ (C)$\\frac{2}{3}$ (D)$\\frac{\\sqrt{6}}{3}$
17 | 问题 2的解析: 设上下底面的中心分别为 $\\mathrm{O}_{1}, \\mathrm{O}$, 设正方体的棱长等于 1 , 则 $O_{1} O$ 与平面 $A C D_{1}$ 所成角就是 $B B_{1}$ 与平面 $A C D_{1}$ 所成角, 即 $\\angle O_{1} O D_{1}$,
18 | 直角三角形 $\\mathrm{OO}_{1} \\mathrm{D}_{1}$ 中, $\\cos \\angle \\mathrm{O}_{1} \\mathrm{OD}_{1}=\\frac{\\mathrm{O}_{1} \\mathrm{O}}{\\mathrm{OD}_{1}}=\\frac{\\frac{1}{\\sqrt{6}}}{2}=\\frac{\\sqrt{6}}{3}$.
19 | 答案是 C
20 |
21 |
22 | 问题 3. 设函数 $f(x)=\\left\\{\\begin{array}{ll}1+\\log _{2}(2-x), & x<1 \\\\ 2^{x-1}, & x \\geqslant 1,\\end{array}\\right.$ 则 $f(-2)+f\\left(\\log _{2} 12\\right)=$ ($\\qquad$)
23 | 从以下选项中选择: (A)3 (B)6 (C)9 (D)12
24 | 问题 3的解析: 首先,我们可以根据定义计算 $f(-2)$ 和 $f(\\log_2 12)$:
25 | $f(-2)=1+\\log_2(2-(-2))=1+\\log_2 4=3$
26 | $f(\\log_2 12)=2^{\\log_2 12-1}=6$
27 | 因此,$f(-2)+f(\\log_2 12)=3+6=9$。
28 | 答案是 C
29 |
30 |
31 | 问题 4. 已知函数 $f(x)=a x^{3}-3 x^{2}+1$, 若 $f(x)$ 存在唯一的零点 $x_{0}$, 且 $x_{0}>$ 0 , 则实数 $\\mathrm{a}$ 的取值范围是 ($\\qquad$)
32 | 从以下选项中选择: (A)$(1,+\\infty)$ (B)$(2,+\\infty)$ (C)$(-\\infty,-1)$ (D)$(-\\infty,-2)$
33 | 问题 4的解析: 首先,我们可以通过求出函数的导函数 $f'(x)$ 来判断函数在 $x>0$ 区间内的单调性。在这里,我们求出导函数 $f'(x)$ 为 $f'(x)=3ax^2-6x$。
34 | 然后,我们需要求出导函数 $f'(x)$ 的零点,以确定函数 $f(x)$ 在 $x>0$ 区间内的单调性。导函数 $f'(x)$ 的零点为 $x=0$ 和 $x=\\frac{2}{\\sqrt{a}}$。注意到 $x>0$,所以我们得到 $a<0$。此外,由于函数 $f(x)$ 在 $x=0$ 处的函数值为 $1$,因此不能有 $a=\\frac{4}{3}$。
35 | 综上所述,当 $a$ 的取值范围为 $a<-\\frac{4}{3}$ 时,函数 $f(x)$ 在 $x>0$ 区间内是单调递减的,此时存在唯一的零点 $x_0$。因此,答案为 $(-\\infty,-2)$。
36 | 答案是 D
37 |
38 |
39 | 问题 5. 设 $\\left\\{a_{n}\\right\\}$ 是公差不为 0 的无穷等差数列, 则“ $\\left\\{a_{n}\\right\\}$ 为递增数列”是“存在正整数 $N_{0}$, 当 $n>N_{0}$ 时, $a_{n}>0$ ”的 ($\\quad$)
40 | 从以下选项中选择: (A)充分而不必要条件 (B)必要而不充分条件 (C)充分必要条件 (D)既不充分也不必要条件
41 | 问题 5的解析: 首先,我们可以通过举例来判断该条件是充分还是必要条件。如果一个数列递增,那么它的公差一定大于 0,也就是存在正整数 $N_{0}$,当 $n>N_{0}$ 时,$a_{n}>0$。因此,“ $\\left\\{a_{n}\\right\\}$ 为递增数列”是“存在正整数 $N_{0}$, 当 $n>N_{0}$ 时, $a_{n}>0$ ”的必要条件。
42 | 接下来,我们需要判断是否充分。也就是说,如果存在正整数 $N_{0}$,当 $n>N_{0}$ 时,$a_{n}>0$,那么能否得出“ $\\left\\{a_{n}\\right\\}$ 为递增数列”这一结论。
43 | 答案是肯定的。因为如果 $a_{n}>0$,那么 $a_{n+1}-a_{n}>0$,即公差大于 0,因此该数列是递增的。因此,该条件是充分条件。
44 | 综上所述,选项为 (C) 充分必要条件。
45 | 答案是 C
46 | """.strip()
47 |
48 | class CoTGaoKaoMathQAPrompt(FewShotPrompting):
49 |     def __init__(self):
50 |         super().__init__()
51 |
52 |     def format_prompt(self, task_input, task_output):
53 |         prompt = f"{few_shot_prompt}\n\n\n问题 6. {task_input}\n问题 6的解析: {task_output}"
54 |         return prompt.rstrip()
55 |
56 |     def stop_words(self):
57 |         return ["\n问题 "]
58 |
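The worked answers in the Gaokao exemplars above can be sanity-checked numerically. The following is a minimal sketch using only the standard library; the helper names (`f_problem3`, `sign_change_zeros`, `problem4_zeros`) are hypothetical and not part of the repository, and the grid scan is a spot check over a few sample values of `a`, not a proof.

```python
import math


def f_problem3(x: float) -> float:
    # Piecewise function from Problem 3: 1 + log2(2 - x) if x < 1, else 2^(x - 1).
    return 1 + math.log2(2 - x) if x < 1 else 2 ** (x - 1)


def sign_change_zeros(g, lo=-10.0, hi=10.0, steps=20_000):
    """Approximate the simple real zeros of g on [lo, hi] by scanning for sign changes.

    A coarse check only: tangent (double) zeros and zeros outside [lo, hi] are missed.
    """
    zeros = []
    h = (hi - lo) / steps
    prev_x, prev_y = lo, g(lo)
    for i in range(1, steps + 1):
        x = lo + i * h
        y = g(x)
        if prev_y * y < 0:
            zeros.append((prev_x + x) / 2)
        prev_x, prev_y = x, y
    return zeros


def problem4_zeros(a: float):
    # Zeros of f(x) = a*x^3 - 3*x^2 + 1 from Problem 4.
    return sign_change_zeros(lambda x: a * x**3 - 3 * x**2 + 1)


if __name__ == "__main__":
    print(f_problem3(-2) + f_problem3(math.log2(12)))  # close to 9 (Problem 3, option C)
    # Problem 4: expect a single, positive zero only for the samples with a < -2.
    for a in (-3.0, -2.5, -1.0, 1.0):
        zs = problem4_zeros(a)
        print(a, len(zs), all(z > 0 for z in zs))
    # Problem 2 with edge length 1: cos(angle O1 O D1) = O1O / OD1 = 1 / (sqrt(6) / 2).
    print(1 / (math.sqrt(6) / 2), math.sqrt(6) / 3)
```

For `a = -1.0` and `a = 1.0` the scan finds three zeros, consistent with the requirement that the local minimum $f(2/a) = 1 - 4/a^2$ be positive.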
--------------------------------------------------------------------------------
/evaluation/few_shot_prompts/cot_gsm_8_shot.py:
--------------------------------------------------------------------------------
1 | from .few_shot_prompting import FewShotPrompting
2 |
3 | few_shot_prompt = """
4 | Q: There are 15 trees in the grove. Grove workers will plant trees in the grove today. After they are done, there will be 21 trees. How many trees did the grove workers plant today?
5 | A: There are 15 trees originally. Then there were 21 trees after some more were planted. So there must have been 21 - 15 = 6. The answer is 6.
6 |
7 |
8 | Q: If there are 3 cars in the parking lot and 2 more cars arrive, how many cars are in the parking lot?
9 | A: There are originally 3 cars. 2 more cars arrive. 3 + 2 = 5. The answer is 5.
10 |
11 |
12 | Q: Leah had 32 chocolates and her sister had 42. If they ate 35, how many pieces do they have left in total?
13 | A: Originally, Leah had 32 chocolates. Her sister had 42. So in total they had 32 + 42 = 74. After eating 35, they had 74 - 35 = 39. The answer is 39.
14 |
15 |
16 | Q: Jason had 20 lollipops. He gave Denny some lollipops. Now Jason has 12 lollipops. How many lollipops did Jason give to Denny?
17 | A: Jason started with 20 lollipops. Then he had 12 after giving some to Denny. So he gave Denny 20 - 12 = 8. The answer is 8.
18 |
19 |
20 | Q: Shawn has five toys. For Christmas, he got two toys each from his mom and dad. How many toys does he have now?
21 | A: Shawn started with 5 toys. If he got 2 toys each from his mom and dad, then that is 4 more toys. 5 + 4 = 9. The answer is 9.
22 |
23 |
24 | Q: There were nine computers in the server room. Five more computers were installed each day, from Monday to Thursday. How many computers are now in the server room?
25 | A: There were originally 9 computers. For each of 4 days, 5 more computers were added. So 5 * 4 = 20 computers were added. 9 + 20 is 29. The answer is 29.
26 |
27 |
28 | Q: Michael had 58 golf balls. On Tuesday, he lost 23 golf balls. On Wednesday, he lost 2 more. How many golf balls did he have at the end of Wednesday?
29 | A: Michael started with 58 golf balls. After losing 23 on Tuesday, he had 58 - 23 = 35. After losing 2 more, he had 35 - 2 = 33 golf balls. The answer is 33.
30 |
31 |
32 | Q: Olivia has $23. She bought five bagels for $3 each. How much money does she have left?
33 | A: Olivia had 23 dollars. 5 bagels for 3 dollars each will be 5 x 3 = 15 dollars. So she has 23 - 15 dollars left. 23 - 15 is 8. The answer is 8.
34 | """.strip()
35 |
36 | class CoTGSMPrompt(FewShotPrompting):
37 |     def __init__(self):
38 |         super().__init__()
39 |
40 |     def format_prompt(self, task_input, task_output):
41 |         prompt = f"{few_shot_prompt}\n\n\nQ: {task_input}\nA: {task_output}"
42 |         return prompt.rstrip()
43 |
44 |     def stop_words(self):
45 |         return ["\nQ:"]
46 |
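`CoTGSMPrompt` only builds the prompt and declares `\nQ:` as a stop word; answer parsing is not part of this file. As a minimal sketch, assuming a simple regex over the exemplars' "The answer is N." convention (the helper names and the regex are hypothetical, not the repository's parser), one round trip might look like this:

```python
import re
from typing import List, Optional


def format_gsm_prompt(few_shot: str, question: str, partial_answer: str = "") -> str:
    # Same layout as CoTGSMPrompt.format_prompt: exemplars, two blank lines, then Q/A.
    return f"{few_shot}\n\n\nQ: {question}\nA: {partial_answer}".rstrip()


def truncate_at_stop(completion: str, stop_words: List[str]) -> str:
    # Cut the completion at the earliest stop word, e.g. "\nQ:" when the model
    # starts inventing the next question.
    cut = len(completion)
    for word in stop_words:
        idx = completion.find(word)
        if idx != -1:
            cut = min(cut, idx)
    return completion[:cut]


def extract_gsm_answer(completion: str) -> Optional[str]:
    # Number after the last "The answer is", with thousands separators removed.
    matches = re.findall(r"The answer is\s*\$?(-?[\d,]*\.?\d+)", completion)
    return matches[-1].replace(",", "") if matches else None


if __name__ == "__main__":
    prompt = format_gsm_prompt("Q: 2 + 3?\nA: 2 + 3 = 5. The answer is 5.", "What is 23 - 15?")
    raw = " 23 - 15 = 8. The answer is 8.\n\n\nQ: A made-up follow-up question"
    answer = extract_gsm_answer(truncate_at_stop(raw, ["\nQ:"]))
    print(answer)  # 8
```

Truncating before parsing matters: without it, a hallucinated follow-up question and its "The answer is ..." line would become the last match.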
--------------------------------------------------------------------------------
/evaluation/few_shot_prompts/cot_math_sat_4_shot.py:
--------------------------------------------------------------------------------
1 | from .few_shot_prompting import FewShotPrompting
2 |
3 | few_shot_prompt = """
4 | Problem:
5 | Find the domain of the expression $\\frac{\\sqrt{x-2}}{\\sqrt{5-x}}$.
6 | Which of the following is the right choice? Explain your answer.
7 | (A) [-5,-2), (B) [2,5), (C) [-2,-5), (D) [5,2)
8 | Solution:
9 | The expressions inside each square root must be non-negative. Therefore, $x-2 \\ge 0$, so $x\\ge2$, and $5 - x \\ge 0$, so $x \\le 5$. Also, the denominator cannot be equal to zero, so $5-x>0$, which gives $x<5$.
10 | Therefore, the domain of the expression is $\\boxed{[2,5)}$.
11 | Final Answer: The final answer is (B). I hope it is correct.
12 |
13 | Problem:
14 | If $\\det \\mathbf{A} = 2$ and $\\det \\mathbf{B} = 12,$ then find $\\det (\\mathbf{A} \\mathbf{B}).$
15 | Which of the following is the right choice? Explain your answer.
16 | (A) 14, (B) 4, (C) 2, (D) 24
17 | Solution:
18 | We have that $\\det (\\mathbf{A} \\mathbf{B}) = (\\det \\mathbf{A})(\\det \\mathbf{B}) = (2)(12) = \\boxed{24}.$
19 | Final Answer: The final answer is (D). I hope it is correct.
20 |
21 | Problem:
22 | Terrell usually lifts two 20-pound weights 12 times. If he uses two 15-pound weights instead, how many times must Terrell lift them in order to lift the same total weight?
23 | Which of the following is the right choice? Explain your answer.
24 | (A) 12, (B) 20, (C) 16, (D) 15
25 | Solution:
26 | If Terrell lifts two 20-pound weights 12 times, he lifts a total of $2\\cdot 12\\cdot20=480$ pounds of weight. If he lifts two 15-pound weights instead for $n$ times, he will lift a total of $2\\cdot15\\cdot n=30n$ pounds of weight. Equating this to 480 pounds, we can solve for $n$: \\begin{align*}
27 | 30n&=480\\\\
28 | \\Rightarrow\\qquad n&=480/30=\\boxed{16}
29 | \\end{align*}
30 | Final Answer: The final answer is (C). I hope it is correct.
31 |
32 | Problem:
33 | If the system of equations
34 |
35 | \\begin{align*}
36 | 6x-4y&=a,\\\\
37 | 6y-9x &=b.
38 | \\end{align*}has a solution $(x, y)$ where $x$ and $y$ are both nonzero, find $\\frac{a}{b},$ assuming $b$ is
39 | nonzero.
40 | Which of the following is the right choice? Explain your answer.
41 | (A) $-\\frac{2}{3}$, (B) $\\frac{2}{3}$, (C) $\\frac{1}{3}$, (D) $\\frac{4}{9}$
42 | Solution:
43 | If we multiply the first equation by $-\\frac{3}{2}$, we obtain
44 | $$6y-9x=-\\frac{3}{2}a.$$Since we also know that $6y-9x=b$, we have
45 |
46 | $$-\\frac{3}{2}a=b\\Rightarrow\\frac{a}{b}=\\boxed{-\\frac{2}{3}}.$$
47 | Final Answer: The final answer is (A). I hope it is correct.
48 | """.strip()
49 |
50 | class CoTSATPrompt(FewShotPrompting):
51 |     def __init__(self):
52 |         super().__init__()
53 |
54 |     def format_prompt(self, task_input, task_output):
55 |         prompt = f"{few_shot_prompt}\n\nProblem:\n{task_input}\nSolution:\n{task_output}"
56 |         return prompt.rstrip()
57 |
58 |     def stop_words(self):
59 |         return ["\nProblem:"]
60 |
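The SAT exemplars end each solution with "Final Answer: The final answer is (X).", so a completion can be scored by pulling out the last option letter. A hedged sketch; `extract_choice` and its regex are hypothetical and cover only options (A) through (D), as used in the exemplars:

```python
import re
from typing import Optional

# Matches the exemplar convention "Final Answer: The final answer is (B). ..."
_CHOICE_RE = re.compile(r"The final answer is \(([A-D])\)")


def extract_choice(completion: str) -> Optional[str]:
    """Return the option letter from the last 'The final answer is (X)' in the completion."""
    matches = _CHOICE_RE.findall(completion)
    return matches[-1] if matches else None


if __name__ == "__main__":
    sample = "... so $x<5$.\nFinal Answer: The final answer is (B). I hope it is correct."
    print(extract_choice(sample))  # B
```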
--------------------------------------------------------------------------------
/evaluation/few_shot_prompts/cot_minerva_math_4_shot.py:
--------------------------------------------------------------------------------
1 | from .few_shot_prompting import FewShotPrompting
2 |
3 | few_shot_prompt = """Problem:
4 | Find the domain of the expression $\\frac{\\sqrt{x-2}}{\\sqrt{5-x}}$.
5 |
6 | Solution:
7 | The expressions inside each square root must be non-negative.
8 | Therefore, $x-2 \\ge 0$, so $x\\ge2$, and $5 - x \\ge 0$, so $x \\le 5$.
9 | Also, the denominator cannot be equal to zero, so $5-x>0$, which gives $x<5$.
10 | Therefore, the domain of the expression is $\\boxed{[2,5)}$.
11 | Final Answer: The final answer is $[2,5)$. I hope it is correct.
12 |
13 | Problem:
14 | If $\\det \\mathbf{A} = 2$ and $\\det \\mathbf{B} = 12,$ then find $\\det (\\mathbf{A} \\mathbf{B}).$
15 |
16 | Solution:
17 | We have that $\\det (\\mathbf{A} \\mathbf{B}) = (\\det \\mathbf{A})(\\det \\mathbf{B}) = (2)(12) = \\boxed{24}.$
18 | Final Answer: The final answer is $24$. I hope it is correct.
19 |
20 | Problem:
21 | Terrell usually lifts two 20-pound weights 12 times. If he uses two 15-pound weights instead, how many times must Terrell lift them in order to lift the same total weight?
22 |
23 | Solution:
24 | If Terrell lifts two 20-pound weights 12 times, he lifts a total of $2\\cdot 12\\cdot20=480$ pounds of weight. If he lifts two 15-pound weights instead for $n$ times, he will lift a total of $2\\cdot15\\cdot n=30n$ pounds of weight. Equating this to 480 pounds, we can solve for $n$: \\begin{align*}
25 | 30n&=480\\\\
26 | \\Rightarrow\\qquad n&=480/30=\\boxed{16}
27 | \\end{align*}
28 | Final Answer: The final answer is $16$. I hope it is correct.
29 |
30 | Problem:
31 | If the system of equations
32 |
33 | \\begin{align*}
34 | 6x-4y&=a,\\\\
35 | 6y-9x &=b.
36 | \\end{align*}has a solution $(x, y)$ where $x$ and $y$ are both nonzero, find $\\frac{a}{b},$ assuming $b$ is nonzero.
37 |
38 | Solution:
39 | If we multiply the first equation by $-\\frac{3}{2}$, we obtain
40 |
41 | $$6y-9x=-\\frac{3}{2}a.$$Since we also know that $6y-9x=b$, we have
42 |
43 | $$-\\frac{3}{2}a=b\\Rightarrow\\frac{a}{b}=\\boxed{-\\frac{2}{3}}.$$
44 | Final Answer: The final answer is $-\\frac{2}{3}$. I hope it is correct."""
45 |
46 | class MinervaMathPrompt(FewShotPrompting):
47 |     def __init__(self):
48 |         super().__init__()
49 |
50 |     def format_prompt(self, task_input, task_output):
51 |         prompt = f"{few_shot_prompt}\n\nProblem:\n{task_input}\n\nSolution:\n{task_output}"
52 |         return prompt.rstrip()
53 |
54 |     def stop_words(self):
55 |         return ["\nProblem:"]
56 |
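The Minerva-style exemplars mark each result with `\boxed{...}`, and answers such as `-\frac{2}{3}` contain nested braces, so a flat regex is not enough. Here is a minimal sketch of brace-matched extraction, assuming the last box in a completion holds the answer (`last_boxed_content` is a hypothetical helper, not the repository's answer parser):

```python
from typing import Optional


def last_boxed_content(text: str) -> Optional[str]:
    """Return the contents of the last \\boxed{...} in text, honoring nested braces."""
    start = text.rfind("\\boxed{")
    if start == -1:
        return None
    i = start + len("\\boxed{")
    depth = 1  # already inside the opening brace of \boxed{
    for j in range(i, len(text)):
        if text[j] == "{":
            depth += 1
        elif text[j] == "}":
            depth -= 1
            if depth == 0:
                return text[i:j]
    return None  # unbalanced braces: no complete box found


if __name__ == "__main__":
    solution = r"$$-\frac{3}{2}a=b\Rightarrow\frac{a}{b}=\boxed{-\frac{2}{3}}.$$"
    print(last_boxed_content(solution))  # -\frac{2}{3}
```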
--------------------------------------------------------------------------------
/evaluation/few_shot_prompts/cot_mmlu_stem_4_shot.py:
--------------------------------------------------------------------------------
1 | from .few_shot_prompting import FewShotPrompting
2 |
3 | few_shot_prompt = """Problem:
4 | Find the domain of the expression $\\frac{\\sqrt{x-2}}{\\sqrt{5-x}}$.
5 | Which of the following is the right choice? Explain your answer.
6 | (A) [-5,-2), (B) [2,5), (C) [-2,-5), (D) [5,2)
7 | Solution:
8 | The expressions inside each square root must be non-negative. Therefore, $x-2 \\ge 0$, so $x\\ge2$, and $5 - x \\ge 0$, so $x \\le 5$. Also, the denominator cannot be equal to zero, so $5-x>0$, which gives $x<5$.
9 | Therefore, the domain of the expression is $\\boxed{[2,5)}$.
10 | Final Answer: The final answer is (B). I hope it is correct.
11 |
12 | Problem:
13 | If $\\det \\mathbf{A} = 2$ and $\\det \\mathbf{B} = 12,$ then find $\\det (\\mathbf{A} \\mathbf{B}).$
14 | Which of the following is the right choice? Explain your answer.
15 | (A) 14, (B) 4, (C) 2, (D) 24
16 | Solution:
17 | We have that $\\det (\\mathbf{A} \\mathbf{B}) = (\\det \\mathbf{A})(\\det \\mathbf{B}) = (2)(12) = \\boxed{24}.$
18 | Final Answer: The final answer is (D). I hope it is correct.
19 |
20 | Problem:
21 | Terrell usually lifts two 20-pound weights 12 times. If he uses two 15-pound weights instead, how many times must Terrell lift them in order to lift the same total weight?
22 | Which of the following is the right choice? Explain your answer.
23 | (A) 12, (B) 20, (C) 16, (D) 15
24 | Solution:
25 | If Terrell lifts two 20-pound weights 12 times, he lifts a total of $2\\cdot 12\\cdot20=480$ pounds of weight. If he lifts two 15-pound weights instead for $n$ times, he will lift a total of $2\\cdot15\\cdot n=30n$ pounds of weight. Equating this to 480 pounds, we can solve for $n$: \\begin{align*}
26 | 30n&=480\\\\
27 | \\Rightarrow\\qquad n&=480/30=\\boxed{16}
28 | \\end{align*}
29 | Final Answer: The final answer is (C). I hope it is correct.
30 |
31 | Problem:
32 | If the system of equations
33 |
34 | \\begin{align*}
35 | 6x-4y&=a,\\\\
36 | 6y-9x &=b.
37 | \\end{align*}has a solution $(x, y)$ where $x$ and $y$ are both nonzero, find $\\frac{a}{b},$ assuming $b$ is
38 | nonzero.
39 | Which of the following is the right choice? Explain your answer.
40 | (A) $-\\frac{2}{3}$, (B) $\\frac{2}{3}$, (C) $\\frac{1}{3}$, (D) $\\frac{4}{9}$
41 | Solution:
42 | If we multiply the first equation by $-\\frac{3}{2}$, we obtain
43 | $$6y-9x=-\\frac{3}{2}a.$$Since we also know that $6y-9x=b$, we have
44 |
45 | $$-\\frac{3}{2}a=b\\Rightarrow\\frac{a}{b}=\\boxed{-\\frac{2}{3}}.$$
46 | Final Answer: The final answer is (A). I hope it is correct."""
47 |
48 | class MMLUSTEMPrompt(FewShotPrompting):
49 |     def __init__(self):
50 |         super().__init__()
51 |
52 |     def format_prompt(self, task_input, task_output):
53 |         prompt = f"{few_shot_prompt}\n\nProblem:\n{task_input}\nSolution:\n{task_output}"
54 |         return prompt.rstrip()
55 |
56 |     def stop_words(self):
57 |         return ["\nProblem:"]
58 |
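The four exemplars shared by the SAT and MMLU-STEM prompts rest on short arithmetic steps that can be re-checked exactly with `fractions.Fraction`. In this sketch, the matrices `A` and `B` are illustrative choices (the exemplar only fixes $\det A = 2$ and $\det B = 12$), and `det2`/`matmul2` are hypothetical helpers:

```python
from fractions import Fraction


def det2(m):
    # Determinant of a 2x2 matrix given as ((a, b), (c, d)).
    (a, b), (c, d) = m
    return a * d - b * c


def matmul2(p, q):
    # Product of two 2x2 matrices stored as nested tuples.
    return tuple(
        tuple(sum(p[i][k] * q[k][j] for k in range(2)) for j in range(2))
        for i in range(2)
    )


# Illustrative matrices (an assumption; the exemplar only fixes the determinants).
A = ((2, 0), (0, 1))   # det A = 2
B = ((3, 1), (0, 4))   # det B = 12
det_AB = det2(matmul2(A, B))  # det is multiplicative, so this is 2 * 12 = 24

# Terrell: 2 * 20 * 12 = 480 pounds; two 15-pound weights give 30 pounds per lift.
reps = Fraction(2 * 20 * 12, 2 * 15)  # 16

# Any nonzero solution (x, y) of 6x - 4y = a, 6y - 9x = b has b = -(3/2) a.
x, y = Fraction(1), Fraction(5)
a, b = 6 * x - 4 * y, 6 * y - 9 * x   # a = -14, b = 21
ratio = a / b                          # -2/3, option (A)

# Domain of sqrt(x - 2) / sqrt(5 - x): integers 2..4 qualify, 5 does not (zero denominator).
in_domain = [n for n in range(0, 7) if n - 2 >= 0 and 5 - n > 0]
```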
--------------------------------------------------------------------------------
/evaluation/few_shot_prompts/cot_ocwcourses_4_shot.py:
--------------------------------------------------------------------------------
1 | from .few_shot_prompting import FewShotPrompting
2 |
3 | few_shot_prompt = """
4 | Problem:
5 | Subproblem 0: What is the net charge of arginine in a solution of $\\mathrm{pH} 1.0$?
6 | Please format your answer as +n or -n.
7 | Solution:
8 | The answer is +2.
9 | Final answer: The final answer is $\\boxed{+2}$. I hope it is correct.
10 |
11 | Problem:
12 | Subproblem 0: Let $z = 1 + \\sqrt{3} i$. Find $a, b$ that satisfy the equation
13 | $z^4 = a + bi$. Express your answer as the ordered pair $(a,b)$.
14 | Solution:
15 | $z^{4}$ has argument $4 \\pi / 3$ and radius 16, so it's equal to $-8-8 \\sqrt{3} i$.
16 | Thus $a = -8, b = -8\\sqrt 3$, and our answer is $(-8, -8\\sqrt{3})$.
17 | Final answer: The final answer is $\\boxed{(-8, -8\\sqrt{3})}$. I hope it is correct.
18 |
19 | Problem:
20 | Preamble: For each Laplace Transform \\(Y(s)\\), find the function \\(y(t)\\):
21 | Subproblem 0:
22 | \\[Y(s)=\\frac{1}{(s+a)(s+b)}\\]
23 | Solution:
24 | We can simplify with partial fractions:
25 | \\[Y(s)=\\frac{1}{(s+a)(s+b)}=\\frac{C}{s+a}+\\frac{D}{s+b}\\]
26 | find the constants
27 | \\(C\\) and \\(D\\) by setting \\(s=-a\\) and \\(s=-b\\)
28 | \\[
29 | \\begin{aligned}
30 | \\frac{1}{(s+a)(s+b)} &=\\frac{C}{s+a}+\\frac{D}{s+b} \\\\
31 | 1 &=C(s+b)+D(s+a) \\\\
32 | C &=\\frac{1}{b-a} \\\\
33 | D &=\\frac{1}{a-b}
34 | \\end{aligned}
35 | \\]
36 | therefore
37 | \\[
38 | Y(s)=\\frac{1}{b-a} \\frac{1}{s+a}-\\frac{1}{b-a} \\frac{1}{s+b}
39 | \\]
40 | By looking up the inverse Laplace Transform of \\(\\frac{1}{s+b}\\), we find the total
41 | solution \\(y(t)\\)
42 | \\[
43 | y(t)=\\frac{e^{-a t}-e^{-b t}}{b-a}
44 | \\].
45 | Final answer: The final answer is $\\boxed{\\frac{e^{-a t}-e^{-b t}}{b-a}}$. I hope it is correct.
46 |
47 | Problem:
48 | Preamble: The following subproblems refer to the differential equation
49 | $\\ddot{x}+b \\dot{x}+x=0$.
50 | Subproblem 0: What is the characteristic polynomial $p(s)$ of
51 | $\\ddot{x}+b \\dot{x}+x=0$?
52 | Solution:
53 | The characteristic polynomial is $p(s)=s^{2}+b s+1$.
54 | Final answer: The final answer is $\\boxed{s^{2}+b s+1}$. I hope it is correct.
55 | """.strip()
56 |
57 | few_shot_prompt = """
58 | Problem:
59 | Subproblem 0: What is the net charge of arginine in a solution of $\\mathrm{pH} 1.0$?
60 | Please format your answer as +n or -n.
61 | Solution:
62 | The answer is +2.
63 | Final answer: The final answer is +2. I hope it is correct.
64 |
65 | Problem:
66 | Subproblem 0: Let $z = 1 + \\sqrt{3} i$. Find $a, b$ that satisfy the equation
67 | $z^4 = a + bi$. Express your answer as the ordered pair $(a,b)$.
68 | Solution:
69 | $z^{4}$ has argument $4 \\pi / 3$ and radius 16, so it's equal to $-8-8 \\sqrt{3} i$.
70 | Thus $a = -8, b = -8\\sqrt 3$, and our answer is $\\boxed{(-8, -8\\sqrt{3})}$.
71 | Final answer: The final answer is (-8, -8\\sqrt{3}). I hope it is correct.
72 |
73 | Problem:
74 | Preamble: For each Laplace Transform \\(Y(s)\\), find the function \\(y(t)\\):
75 | Subproblem 0:
76 | \\[Y(s)=\\frac{1}{(s+a)(s+b)}\\]
77 | Solution:
78 | We can simplify with partial fractions:
79 | \\[Y(s)=\\frac{1}{(s+a)(s+b)}=\\frac{C}{s+a}+\\frac{D}{s+b}\\]
80 | find the constants
81 | \\(C\\) and \\(D\\) by setting \\(s=-a\\) and \\(s=-b\\)
82 | \\[
83 | \\begin{aligned}
84 | \\frac{1}{(s+a)(s+b)} &=\\frac{C}{s+a}+\\frac{D}{s+b} \\\\
85 | 1 &=C(s+b)+D(s+a) \\\\
86 | C &=\\frac{1}{b-a} \\\\
87 | D &=\\frac{1}{a-b}
88 | \\end{aligned}
89 | \\]
90 | therefore
91 | \\[
92 | Y(s)=\\frac{1}{b-a} \\frac{1}{s+a}-\\frac{1}{b-a} \\frac{1}{s+b}
93 | \\]
94 | By looking up the inverse Laplace Transform of \\(\\frac{1}{s+b}\\), we find the total
95 | solution \\(y(t)\\)
96 | \\[
97 | y(t)=\\boxed{\\frac{1}{b-a}\\left(e^{-a t}-e^{-b t}\\right)}
98 | \\].
99 | Final answer: The final answer is \\[\\frac{1}{b-a}\\left(e^{-a t}-e^{-b t}\\right)\\]. I hope it is correct.
100 |
101 | Problem:
102 | Preamble: The following subproblems refer to the differential equation
103 | $\\ddot{x}+b \\dot{x}+x=0$.
104 | Subproblem 0: What is the characteristic polynomial $p(s)$ of
105 | $\\ddot{x}+b \\dot{x}+x=0$?
106 | Solution:
107 | The characteristic polynomial is $p(s)=\\boxed{s^{2}+b s+1}$.
108 | Final answer: The final answer is $s^{2}+b s+1$. I hope it is correct.
109 | """.strip()
110 |
111 | class OCWCoursesPrompt(FewShotPrompting):
112 |     def __init__(self):
113 |         super().__init__()
114 |
115 |     def format_prompt(self, task_input, task_output):
116 |         prompt = f"{few_shot_prompt}\n\nProblem:\n{task_input}\nSolution:\n{task_output}"
117 |         return prompt.rstrip()
118 |
119 |     def stop_words(self):
120 |         return ["\nProblem:"]
121 |
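`cot_ocwcourses_4_shot.py` assigns `few_shot_prompt` twice, so `OCWCoursesPrompt` uses only the second version, the variant whose solutions carry `\boxed{...}` and whose final-answer lines are unboxed. The sketch below re-checks two of its exemplars numerically: the value of $z^4$ for $z = 1 + \sqrt{3}i$, and the inverse Laplace transform $y(t)=\frac{e^{-at}-e^{-bt}}{b-a}$, by approximating $\int_0^\infty y(t)e^{-st}\,dt$ with the trapezoidal rule. The helper names and the truncation at `t_max = 40` are assumptions chosen for this illustration.

```python
import cmath
import math

# Exemplar 2: z = 1 + sqrt(3) i has modulus 2 and argument pi/3,
# so z^4 has modulus 16 and argument 4*pi/3.
z4 = (1 + math.sqrt(3) * 1j) ** 4
modulus, phase = cmath.polar(z4)  # phase lies in (-pi, pi], so 4*pi/3 appears as -2*pi/3


def partial_fraction_coeffs(a: float, b: float):
    # C and D in 1/((s+a)(s+b)) = C/(s+a) + D/(s+b); requires a != b.
    return 1 / (b - a), 1 / (a - b)


def y(t: float, a: float, b: float) -> float:
    # Exemplar 3's inverse transform: (e^{-a t} - e^{-b t}) / (b - a).
    return (math.exp(-a * t) - math.exp(-b * t)) / (b - a)


def laplace_numeric(f, s: float, t_max: float = 40.0, steps: int = 40_000) -> float:
    # Trapezoidal approximation of the integral of f(t) e^{-s t} over [0, t_max].
    h = t_max / steps
    total = 0.5 * (f(0.0) + f(t_max) * math.exp(-s * t_max))
    for i in range(1, steps):
        t = i * h
        total += f(t) * math.exp(-s * t)
    return total * h


if __name__ == "__main__":
    print(z4)  # approximately -8 - 13.856i, i.e. -8 - 8*sqrt(3) i
    a, b, s = 1.0, 2.0, 1.0
    print(laplace_numeric(lambda t: y(t, a, b), s), 1 / ((s + a) * (s + b)))  # both near 1/6
```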
--------------------------------------------------------------------------------
/evaluation/few_shot_prompts/few_shot_prompting.py:
--------------------------------------------------------------------------------
1 | class FewShotPrompting:
2 |     def __init__(self):
3 |         pass
4 |
5 |     def format_prompt(self, task_input, task_output):
6 |         pass
7 |
8 |     def stop_words(self):
9 |         pass
10 |
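`FewShotPrompting` defines its interface with `pass` bodies, so a subclass that forgets a method silently returns `None`. One possible alternative, sketched here as an assumption rather than the repository's design, is to declare the two methods abstract with `abc`. `AbstractFewShotPrompting`, `ToyQAPrompt`, and `MissingFormat` are hypothetical names:

```python
from abc import ABC, abstractmethod
from typing import List


class AbstractFewShotPrompting(ABC):
    # Hypothetical abstract counterpart of FewShotPrompting: same two methods, but a
    # subclass that forgets one cannot be instantiated (the original base class
    # would instead return None from the missing method).
    @abstractmethod
    def format_prompt(self, task_input: str, task_output: str) -> str: ...

    @abstractmethod
    def stop_words(self) -> List[str]: ...


class ToyQAPrompt(AbstractFewShotPrompting):
    # Hypothetical one-shot prompt that follows the CoTGSMPrompt layout.
    few_shot_prompt = "Q: What is 2 + 3?\nA: 2 + 3 = 5. The answer is 5."

    def format_prompt(self, task_input: str, task_output: str) -> str:
        return f"{self.few_shot_prompt}\n\n\nQ: {task_input}\nA: {task_output}".rstrip()

    def stop_words(self) -> List[str]:
        return ["\nQ:"]


class MissingFormat(AbstractFewShotPrompting):
    # Deliberately incomplete: format_prompt is not implemented.
    def stop_words(self) -> List[str]:
        return ["\nQ:"]


if __name__ == "__main__":
    prompter = ToyQAPrompt()
    print(prompter.format_prompt("What is 4 + 4?", ""))
    try:
        MissingFormat()
    except TypeError as err:
        print("rejected:", err)
```

The trade-off is that every subclass must implement both methods explicitly, which all prompt classes in this directory already do.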
--------------------------------------------------------------------------------
/evaluation/few_shot_prompts/minif2f_isabelle.py:
--------------------------------------------------------------------------------
1 | from .few_shot_prompting import FewShotPrompting
2 |
3 | few_shot_prompt = {
4 | 'numbertheory': """Informal:
5 | (*### Problem
6 |
7 | Find the minimum value of $\\frac{9x^2\\sin^2 x + 4}{x\\sin x}$ for $0 < x < \\pi$. Show that it is 12.
8 |
9 | ### Solution
10 |
11 | Let $y = x \\sin x$. It suffices to show that $12 \\leq \\frac{9y^2 + 4}{y}$.
12 | It is trivial to see that $y > 0$.
13 | Then one can multiply both sides by $y$ and it suffices to show $12y \\leq 9y^2 + 4$.
14 | This can be done by the sum of squares method.*)
15 |
16 | Formal:
17 | theorem aime_1983_p9:
18 | fixes x::real
19 | assumes "0<x" "x<pi"
20 | shows "12 \\<le> ((9 * (x^2 * (sin x)^2)) + 4) / (x * sin x)"
21 | proof -
22 | (* Let $y = x \\sin x$. *)
23 | define y where "y=x * sin x"
24 | (* It suffices to show that $12 \\leq \\frac{9y^2 + 4}{y}$. *)
25 | have "12 \\<le> (9 * y^2 + 4) / y"
26 | proof -
27 | (* It is trivial to see that $y > 0$. *)
28 | have c0: "y > 0"
29 | sledgehammer
30 | (* Then one can multiply both sides by $y$ and it suffices to show $12y \\leq 9y^2 + 4$. *)
31 | have "(9 * y^2 + 4) \\<ge> 12 * y"
32 | sledgehammer
33 | then show ?thesis
34 | sledgehammer
35 | qed
36 | then show ?thesis
37 | sledgehammer
38 | qed
39 |
40 |
41 |
42 | Informal:
43 | (*### Problem
44 |
45 | Find the greatest common factor of 180 and 168. Show that it is 12.
46 |
47 | ### Solution
48 |
49 | This is true by simple evaluation.*)
50 |
51 | Formal:
52 | theorem mathd_numbertheory_188:
53 | "gcd 180 168 = (12::nat)"
54 | sledgehammer
55 |
56 |
57 |
58 | Informal:
59 | (*### Problem
60 |
61 | Show that for positive integer n, 2 divides $4^n$.
62 |
63 | ### Solution
64 |
65 | Since n is positive, we can find a natural number m where $m+1=n$.
66 | Then we can show that 2 divides $4^{m+1}$. The conclusion thus follows.*)
67 |
68 | Formal:
69 | theorem numbertheory_2dvd4expn:
70 | fixes n :: nat
71 | assumes h0 : "n \\ 0"
72 | shows "(2::nat) dvd 4^n"
73 | proof -
74 | obtain m::nat where c0: "m+1=n"
75 | sledgehammer
76 | have "(2::nat) dvd 4^(m+1)" sledgehammer
77 | then show ?thesis unfolding c0 sledgehammer
78 | qed
79 |
80 |
81 |
82 | Informal:
83 | (*### Problem
84 |
85 | What is the remainder when $1 + 2 + 3 + 4 + \\dots + 9 + 10$ is divided by 9? Show that it is 1.
86 |
87 | ### Solution
88 |
89 | This is true by simple evaluation.*)
90 |
91 | Formal:
92 | theorem mathd_numbertheory_466:
93 | "(\\<Sum> k< 11. k) mod 9 = (1::nat)"
94 | sledgehammer
95 |
96 |
97 |
98 | Informal:
99 | (*### Problem
100 |
101 | If $321_{b}$ is equal to the base 10 integer 57, find $b$ given that $b>0$. Show that it is 4.
102 |
103 | ### Solution
104 |
105 | Converting $321_{b}$ to base 10 and setting it equal to 57, we find that \\begin{align*} 3(b^2)+2(b^1)+1(b^0)&=57
106 | \\\\ 3b^2+2b+1&=57
107 | \\\\\\Rightarrow\\qquad 3b^2+2b-56&=0
108 | \\\\\\Rightarrow\\qquad (3b+14)(b-4)&=0
109 | \\end{align*}This tells us that $b$ is either $-\\frac{14}{3}$ or $4$. We know that $b>0$, so $b=4$.*)
110 |
111 | Formal:
112 | theorem mathd_numbertheory_48:
113 | fixes b :: real
114 | assumes h0 : "00$, so $b=4$. *)
129 | then show ?thesis using h0 sledgehammer
130 | qed
131 |
132 |
133 |
134 | Informal:
135 | (*### Problem
136 |
137 | When Rachel divides her favorite number by 7, she gets a remainder of 5. What will the remainder be if she multiplies her favorite number by 5 and then divides by 7? Show that it is 4.
138 |
139 | ### Solution
140 |
141 | Let $n$ be Rachel's favorite number.
142 | Then $n \\equiv 5 \\pmod{7}$, so $5n \\equiv 5 \\cdot 5 \\equiv 25 \\equiv 4 \\pmod{7}$.
143 | *)
144 |
145 | Formal:
146 | theorem mathd_numbertheory_335:
147 | fixes n :: nat
148 | assumes h0 : "n mod 7 = 5"
149 | shows "(5 * n) mod 7 = 4"
150 | proof -
151 | (* Then $n \\equiv 5 \\pmod{7}$, so $5n \\equiv 5 \\cdot 5 \\equiv 25 \\equiv 4 \\pmod{7}$. *)
152 | have c0:"(5 * n) mod 7 = (5 * 5) mod 7" using h0
153 | sledgehammer
154 | then have "\\<dots> = 4" sledgehammer
155 | then have "(5 * n) mod 7 = 4" using c0 sledgehammer
156 | then show ?thesis sledgehammer
157 | qed
158 |
159 |
160 |
161 | Informal:
162 | (*### Problem
163 |
164 | What positive two-digit integer is exactly twice the sum of its digits? Show that it is 18.
165 |
166 | ### Solution
167 |
168 | We simplify $10a + b = 2(a+b)$ to get $8a = b$.
169 | Since $a$ is at least 1, $b$ is at least 8.
170 | We know $b$ is 8 since $8a = b$ and $a$ is a natural number.
171 | Hence $a$ is 1.
172 | The two-digit integer is hence $18$.
173 | *)
174 |
175 | Formal:
176 | theorem mathd_numbertheory_284:
177 | fixes a b :: nat
178 | assumes h0 : "1\\<le>a \\<and> a \\<le>9 \\<and> b \\<le>9"
179 | and h1 : "10 * a + b = 2 * (a+b)"
180 | shows "10 * a + b = 18"
181 | proof -
182 | (* We simplify $10a + b = 2(a+b)$ to get $8a = b$. *)
183 | have c0: "8 * a = b" using h1 sledgehammer
184 | (* Since $a$ is at least 1, $b$ is at least 8. *)
185 | hence "b \\<ge> 8" using h0 sledgehammer
186 | (* We know $b$ is 8 since $8a = b$ and $a$ is a natural number. *)
187 | hence c1:"b = 8" using h0 c0
188 | sledgehammer
189 | (* Hence $a$ is 1. *)
190 | hence "a = 1" using c0 sledgehammer
191 | (* The two-digit integer is hence $18$. *)
192 | then show ?thesis using c1 sledgehammer
193 | qed
194 |
195 |
196 |
197 | """.strip(),
198 | "other": """Informal:
199 | (*### Problem
200 |
201 | Find the minimum value of $\\frac{9x^2\\sin^2 x + 4}{x\\sin x}$ for $0 < x < \\pi$. Show that it is 12.
202 |
203 | ### Solution
204 |
205 | Let $y = x \\sin x$. It suffices to show that $12 \\leq \\frac{9y^2 + 4}{y}$.
206 | It is trivial to see that $y > 0$.
207 | Then one can multiply both sides by $y$ and it suffices to show $12y \\leq 9y^2 + 4$.
208 | This can be done by the sum of squares method.*)
209 |
210 | Formal:
211 | theorem aime_1983_p9:
212 | fixes x::real
213 | assumes "0<x" "x<pi"
214 | shows "12 \\<le> ((9 * (x^2 * (sin x)^2)) + 4) / (x * sin x)"
215 | proof -
216 | (* Let $y = x \\sin x$. *)
217 | define y where "y=x * sin x"
218 | (* It suffices to show that $12 \\leq \\frac{9y^2 + 4}{y}$. *)
219 | have "12 \\<le> (9 * y^2 + 4) / y"
220 | proof -
221 | (* It is trivial to see that $y > 0$. *)
222 | have c0: "y > 0"
223 | sledgehammer
224 | (* Then one can multiply both sides by $y$ and it suffices to show $12y \\leq 9y^2 + 4$. *)
225 | have "(9 * y^2 + 4) \\<ge> 12 * y"
226 | sledgehammer
227 | then show ?thesis
228 | sledgehammer
229 | qed
230 | then show ?thesis
231 | sledgehammer
232 | qed
233 |
234 |
235 |
236 | Informal:
237 | (*### Problem
238 |
239 | Show that for any four complex numbers a, b, c, and d, $(a-d)(a-c)(a-b) = -(((a^2 - a(b+c)) + bc) * d) + (a^2 - a(b+c) + bc) * a$.
240 |
241 | ### Solution
242 |
243 | We first see that $a^2 = a * a$ trivially.
244 | Unfolding this, the main equation holds true when terms are rearranged.*)
245 |
246 | Formal:
247 | theorem algebra_3rootspoly_amdtamctambeqnasqmbpctapcbtdpasqmbpctapcbta:
248 | fixes a b c d :: complex
249 | shows "(a-d) * (a-c) * (a-b) = -(((a^2 - (b+c) * a) + c * b) * d) + (a^2 - (b+c) * a + c * b) * a"
250 | proof -
251 | (* We first see that $a^2 = a * a$ trivially. *)
252 | have t0: "a^2 = a * a"
253 | using power2_eq_square
254 | sledgehammer
255 | (* Unfolding this, the main equation holds true when terms are rearranged. *)
256 | show ?thesis unfolding t0
257 | sledgehammer
258 | qed
259 |
260 |
261 |
262 | Informal:
263 | (*### Problem
264 |
265 | Find the greatest common factor of 180 and 168. Show that it is 12.
266 |
267 | ### Solution
268 |
269 | This is true by simple evaluation.*)
270 |
271 | Formal:
272 | theorem mathd_numbertheory_188:
273 | "gcd 180 168 = (12::nat)"
274 | sledgehammer
275 |
276 |
277 |
278 | Informal:
279 | (*### Problem
280 |
281 | For how many positive integers $n$ is $n^2 - 3n + 2$ a [[prime]] number?
282 |
283 | $\\mathrm{(A)}\\ \\text{none}
284 | \\qquad\\mathrm{(B)}\\ \\text{one}
285 | \\qquad\\mathrm{(C)}\\ \\text{two}
286 | \\qquad\\mathrm{(D)}\\ \\text{more\\ than\\ two,\\ but\\ finitely\\ many}
287 | \\qquad\\mathrm{(E)}\\ \\text{infinitely\\ many}$ Show that it is \\mathrm{(B)}\\ \\text{one}.
288 |
289 | ### Solution
290 |
291 | Factoring, we get $n^2 - 3n + 2 = (n-2)(n-1)$.
292 | Either $n-1$ or $n-2$ is odd, and the other is even.
293 | Their product must yield an even number.
294 | The only prime that is even is $2$, which is when $n$ is $3$ or $0$.
295 | Since $0$ is not a positive number, the answer is $\\mathrm{(B)}\\ \\text{one}$.*)
296 |
297 | Formal:
298 | theorem amc12b_2002_p3:
299 | fixes n ::nat
300 | assumes "n>0"
301 | and prime:"prime (n^2+2-3*n)"
302 | shows "n=3"
303 | proof -
304 | have "n>2"
305 | proof (rule ccontr)
306 | assume "\\<not> 2 < n"
307 | then have "n=1 \\<or> n=2" using \\<open>n>0\\<close> sledgehammer
308 | then show False using prime[THEN prime_gt_1_nat]
309 | sledgehammer
310 | qed
311 | (* Factoring, we get $n^2 - 3n + 2 = (n-2)(n-1)$. *)
312 | then have "n^2+2-3*n = (n-1) * (n-2)"
313 | unfolding power2_eq_square
314 | sledgehammer
315 | (* Either $n-1$ or $n-2$ is odd, and the other is even.
316 | Their product must yield an even number.
317 | The only prime that is even is $2$, which is when $n$ is $3$ or $0$.
318 | Since $0$ is not a positive number, the answer is $\\mathrm{(B)}\\ \\text{one}$.*)
319 | then have "prime ((n-1) * (n-2))"
320 | using prime sledgehammer
321 | then have "n-1=1 \\<or> n-2 = 1"
322 | using prime_product sledgehammer
323 | with \\<open>n>2\\<close>
324 | show "n=3" sledgehammer
325 | qed
326 |
327 |
328 |
329 | Informal:
330 | (*### Problem
331 |
332 | For a positive real number a, show that $10a\\leq 28a^2+1$.
333 |
334 | ### Solution
335 |
336 | It suffices to show $0\\leq 28a^2 - 10a + 1$.
337 | First, consider completing the square for $28a^2 - 10a$ and observe that $(a - \\frac{5}{28})^2 = a^2 - \\frac{10}{28}a + (5/28)^2$.
338 | Since $0\\leq (a - \\frac{5}{28})^2$, we have $0\\leq a^2 - \\frac{10}{28}a + (5/28)^2$.
339 | Multiplying by 28 and simplifying terms gives $0\\leq 28*a^2 - 10*a + (25/28)$.
340 | Since $25/28 < 1$, the result follows.*)
341 |
342 | Formal:
343 | theorem algebra_binomnegdiscrineq_10alt28asqp1:
344 | fixes a :: real
345 | shows "10 * a \\<le> 28 * a^2 + 1"
346 | proof -
347 | (* it suffices to show $0\\leq 28a^2 - 10a + 1$ *)
348 | have c0: "0 \\<le> 28*a^2 - 10*a + 1"
349 | proof -
350 | (* observe that $(a - \\frac{5}{28})^2 = a^2 - \\frac{10}{28}a + (5/28)^2$ *)
351 | have c1: "(a - (5/28))^2 = a^2 - 10/28*a + (5/28)^2"
352 | sledgehammer
353 | (* we have $0\\leq a^2 - \\frac{10}{28}a + (5/28)^2$ *)
354 | then have c2: "0 \\<le> a^2 - 10/28*a + (5/28)^2" using c1
355 | sledgehammer
356 | (* Multiplying by 28 and simplifying terms gives $0\\leq 28*a^2 - 10*a + (25/28)$ *)
357 | then have c3: "0 \\<le> 28*a^2 - 10*a + 28*((5/28)^2)" using c2
358 | sledgehammer
359 | then have c4: "0 \\<le> 28*a^2 - 10*a + 28*((5/28)*(5/28))" using c3
360 | sledgehammer
361 | then have c5: "0 \\<le> 28*a^2 - 10*a + (25/28)" using c4
362 | sledgehammer
363 | (* Since $25/28 < 1$, the result follows. *)
364 | then show ?thesis using c5
365 | sledgehammer
366 | qed
367 | then show ?thesis
368 | sledgehammer
369 | qed
370 |
371 |
372 |
373 | Informal:
374 | (*### Problem
375 |
376 | Show that for any complex number a, $(a-10)(a+11) = a^2 + a - 110$.
377 |
378 | ### Solution
379 |
380 | We first expand all terms of the left hand side to get $a^2 - 10a + 11a - 10*11$.
381 | This equals $a^2 + a - 10*11 = a^2 + a - 110$.*)
382 |
383 | Formal:
384 | theorem algebra_2rootsintpoly_am10tap11eqasqpam110:
385 | fixes a :: complex
386 | shows "(a-10) * (a+11) = a^2 + a -110"
387 | proof -
388 | (* We first expand all terms of the left hand side to get $a^2 - 10a + 11a - 10*11$. *)
389 | have "(a-10) * (a+11) = a^2 - 10*a + 11*a - 10 *11"
390 | sledgehammer
391 | (* This equals $a^2 + a - 10*11 = a^2 + a - 110$. *)
392 | also have "\\<dots> = a^2 + a - 10 * 11"
393 | sledgehammer
394 | also have "\\<dots> = a^2 + a - 110"
395 | sledgehammer
396 | finally show ?thesis
397 | sledgehammer
398 | qed
399 |
400 |
401 |
402 | """.strip()
403 | }
404 |
405 | class MiniF2FIsabellePrompt(FewShotPrompting):
406 |     def __init__(self):
407 |         super().__init__()
408 |
409 |     def format_prompt(self, task_input, task_output):
410 |         if 'numbertheory' in task_input.split("Formal:", 1)[1]:
411 |             tag = 'numbertheory'
412 |         else:
413 |             tag = 'other'
414 |         prompt = f"{few_shot_prompt[tag].strip()}\n\n\n\nInformal:\n{task_input.strip()}\n{task_output.strip()}"
415 |         return prompt.rstrip()
416 |
417 |     def stop_words(self):
418 |         return ["\nInformal:"]
418 | return ["\nInformal:"]
419 |
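`MiniF2FIsabellePrompt.format_prompt` chooses between the `'numbertheory'` and `'other'` exemplar sets by searching the text after the first `Formal:` marker, which, judging by the exemplars, begins with the theorem name (e.g. `mathd_numbertheory_188`). A minimal sketch of that selection and the prompt assembly follows. The function names are hypothetical, and unlike the original, which indexes `split("Formal:", 1)[1]` and raises `IndexError` when the marker is absent, this version falls back to `'other'`:

```python
from typing import Dict


def select_tag(task_input: str) -> str:
    # Same rule as MiniF2FIsabellePrompt.format_prompt: look for 'numbertheory' in the
    # text after the first "Formal:" marker. Deviation (an assumption, not the
    # repository's behavior): a missing marker falls back to 'other' instead of raising.
    _, marker, formal = task_input.partition("Formal:")
    return "numbertheory" if marker and "numbertheory" in formal else "other"


def build_isabelle_prompt(exemplars: Dict[str, str], task_input: str, task_output: str = "") -> str:
    # Exemplars, three blank lines, then the new informal/formal pair.
    tag = select_tag(task_input)
    return (
        f"{exemplars[tag].strip()}\n\n\n\nInformal:\n"
        f"{task_input.strip()}\n{task_output.strip()}"
    ).rstrip()
```

Restricting the search to the text after `Formal:` keeps a stray mention of "numbertheory" in the informal problem statement from switching exemplar sets.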
--------------------------------------------------------------------------------
/evaluation/few_shot_prompts/pal_gsm_8_shot.py:
--------------------------------------------------------------------------------
1 | from .few_shot_prompting import FewShotPrompting
2 |
3 | few_shot_prompt = '''
4 | Q: Olivia has $23. She bought five bagels for $3 each. How much money does she have left?
5 |
6 | # solution in Python:
7 |
8 |
9 | def solution():
10 |     """Olivia has $23. She bought five bagels for $3 each. How much money does she have left?"""
11 |     money_initial = 23
12 |     bagels = 5
13 |     bagel_cost = 3
14 |     money_spent = bagels * bagel_cost
15 |     money_left = money_initial - money_spent
16 |     result = money_left
17 |     return result
18 |
19 |
20 |
21 |
22 |
23 | Q: Michael had 58 golf balls. On Tuesday, he lost 23 golf balls. On Wednesday, he lost 2 more. How many golf balls did he have at the end of Wednesday?
24 |
25 | # solution in Python:
26 |
27 |
28 | def solution():
29 |     """Michael had 58 golf balls. On Tuesday, he lost 23 golf balls. On Wednesday, he lost 2 more. How many golf balls did he have at the end of Wednesday?"""
30 |     golf_balls_initial = 58
31 |     golf_balls_lost_tuesday = 23
32 |     golf_balls_lost_wednesday = 2
33 |     golf_balls_left = golf_balls_initial - golf_balls_lost_tuesday - golf_balls_lost_wednesday
34 |     result = golf_balls_left
35 |     return result
36 |
37 |
38 |
39 |
40 |
41 | Q: There were nine computers in the server room. Five more computers were installed each day, from Monday to Thursday. How many computers are now in the server room?
42 |
43 | # solution in Python:
44 |
45 |
46 | def solution():
47 |     """There were nine computers in the server room. Five more computers were installed each day, from Monday to Thursday. How many computers are now in the server room?"""
48 |     computers_initial = 9
49 |     computers_per_day = 5
50 |     num_days = 4  # 4 days from Monday to Thursday
51 |     computers_added = computers_per_day * num_days
52 |     computers_total = computers_initial + computers_added
53 |     result = computers_total
54 |     return result
55 |
56 |
57 |
58 |
59 |
60 | Q: Shawn has five toys. For Christmas, he got two toys each from his mom and dad. How many toys does he have now?
61 |
62 | # solution in Python:
63 |
64 |
65 | def solution():
66 | """Shawn has five toys. For Christmas, he got two toys each from his mom and dad. How many toys does he have now?"""
67 | toys_initial = 5
68 | mom_toys = 2
69 | dad_toys = 2
70 | total_received = mom_toys + dad_toys
71 | total_toys = toys_initial + total_received
72 | result = total_toys
73 | return result
74 |
75 |
76 |
77 |
78 |
79 | Q: Jason had 20 lollipops. He gave Denny some lollipops. Now Jason has 12 lollipops. How many lollipops did Jason give to Denny?
80 |
81 | # solution in Python:
82 |
83 |
84 | def solution():
85 | """Jason had 20 lollipops. He gave Denny some lollipops. Now Jason has 12 lollipops. How many lollipops did Jason give to Denny?"""
86 | jason_lollipops_initial = 20
87 | jason_lollipops_after = 12
88 | denny_lollipops = jason_lollipops_initial - jason_lollipops_after
89 | result = denny_lollipops
90 | return result
91 |
92 |
93 |
94 |
95 |
96 | Q: Leah had 32 chocolates and her sister had 42. If they ate 35, how many pieces do they have left in total?
97 |
98 | # solution in Python:
99 |
100 |
101 | def solution():
102 | """Leah had 32 chocolates and her sister had 42. If they ate 35, how many pieces do they have left in total?"""
103 | leah_chocolates = 32
104 | sister_chocolates = 42
105 | total_chocolates = leah_chocolates + sister_chocolates
106 | chocolates_eaten = 35
107 | chocolates_left = total_chocolates - chocolates_eaten
108 | result = chocolates_left
109 | return result
110 |
111 |
112 |
113 |
114 |
115 | Q: If there are 3 cars in the parking lot and 2 more cars arrive, how many cars are in the parking lot?
116 |
117 | # solution in Python:
118 |
119 |
120 | def solution():
121 | """If there are 3 cars in the parking lot and 2 more cars arrive, how many cars are in the parking lot?"""
122 | cars_initial = 3
123 | cars_arrived = 2
124 | total_cars = cars_initial + cars_arrived
125 | result = total_cars
126 | return result
127 |
128 |
129 |
130 |
131 |
132 | Q: There are 15 trees in the grove. Grove workers will plant trees in the grove today. After they are done, there will be 21 trees. How many trees did the grove workers plant today?
133 |
134 | # solution in Python:
135 |
136 |
137 | def solution():
138 | """There are 15 trees in the grove. Grove workers will plant trees in the grove today. After they are done, there will be 21 trees. How many trees did the grove workers plant today?"""
139 | trees_initial = 15
140 | trees_after = 21
141 | trees_added = trees_after - trees_initial
142 | result = trees_added
143 | return result
144 | '''.strip()
145 |
146 | class PALGSMPrompt(FewShotPrompting):
147 | def __init__(self):
148 | super().__init__()
149 |
150 | def format_prompt(self, task_input, task_output):
151 | prompt = f"{few_shot_prompt}\n\n\n\n\n\nQ: {task_input}\n\n# solution in Python:"
152 | return prompt.rstrip()
153 |
154 | def stop_words(self):
155 | return ["\nQ:", "Q: "]
156 |
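The PAL GSM prompt depends on two mechanisms. Generation halts at the stop words `\nQ:` / `Q: ` before the model starts inventing the next question. The evaluator then calls `solution()`: run_pal_eval.py builds `PythonExecutor(get_answer_expr='solution()')` for this prompt. The sketch below replays that path in plain Python under stated assumptions. `build_prompt`, `truncate_at_stop`, and `run_solution` are hypothetical helpers, and the completion is made up. A bare `exec` with no sandbox or timeout stands in for the repo's `PythonExecutor`, whose internals are not shown here.

```python
from typing import List

STOP_WORDS = ["\nQ:", "Q: "]  # same strings as PALGSMPrompt.stop_words()


def build_prompt(few_shot_prompt: str, question: str) -> str:
    # Same layout as PALGSMPrompt.format_prompt: six newlines separate the shots.
    prompt = f"{few_shot_prompt}\n\n\n\n\n\nQ: {question}\n\n# solution in Python:"
    return prompt.rstrip()


def truncate_at_stop(text: str, stop_words: List[str]) -> str:
    # Cut the completion at the earliest occurrence of any stop word.
    cut = len(text)
    for word in stop_words:
        idx = text.find(word)
        if idx != -1:
            cut = min(cut, idx)
    return text[:cut]


def run_solution(code: str):
    # Execute the generated program, then call solution() (no sandbox, no timeout).
    namespace = {}
    exec(code, namespace)
    return namespace["solution"]()


completion = (
    "\n\n\ndef solution():\n"
    "    apples = 10\n"
    "    eaten = 3\n"
    "    return apples - eaten\n"
    "\n\n\n\n\nQ: a follow-up question the stop words should remove"
)
code = truncate_at_stop(completion, STOP_WORDS)
print(run_solution(code))  # 7
```

A real executor also needs a timeout and process isolation, because generated programs can loop forever or touch the file system.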
--------------------------------------------------------------------------------
/evaluation/few_shot_prompts/pal_math_4_shot.py:
--------------------------------------------------------------------------------
1 | from .few_shot_prompting import FewShotPrompting
2 |
3 | few_shot_prompt = """Problem:
4 | Find the value of $x$ that satisfies $\\frac{\\sqrt{3x+5}}{\\sqrt{6x+5}}=\\frac{\\sqrt{5}}{3}$. Express your answer as a common fraction.
5 |
6 | You are an expert programmer. Solve the above mathematical problem by writing a Python program. Express your answer as a numeric type or a SymPy object.
7 | ```
8 | # Initialize x
9 | x = symbols('x')
10 |
11 | # Define the equation
12 | equation = Eq(sqrt(3*x + 5)/sqrt(6*x + 5), sqrt(5)/3)
13 |
14 | # Solve for x
15 | answer = solve(equation, x)
16 | ```
17 | The imports required for this program are
18 | ```
19 | from sympy import symbols, Eq, solve, sqrt
20 | ```
21 | I hope my solution is correct.
22 |
23 | Problem:
24 | If $\\det \\mathbf{A} = 2$ and $\\det \\mathbf{B} = 12,$ then find $\\det (\\mathbf{A} \\mathbf{B}).$
25 |
26 | You are an expert programmer. Solve the above mathematical problem by writing a Python program. Express your answer as a numeric type or a SymPy object.
27 | ```
28 | # Given det(A) = 2 and det(B) = 12
29 | det_A = 2
30 | det_B = 12
31 |
32 | # Use the property det(AB) = det(A)*det(B)
33 | det_AB = det_A * det_B
34 |
35 | answer = det_AB
36 | ```
37 | The imports required for this program are
38 | ```
39 |
40 | ```
41 | I hope my solution is correct.
42 |
43 | Problem:
44 | Terrell usually lifts two 20-pound weights 12 times. If he uses two 15-pound weights instead, how many times must Terrell lift them in order to lift the same total weight?
45 |
46 | You are an expert programmer. Solve the above mathematical problem by writing a Python program. Express your answer as a numeric type or a SymPy object.
47 | ```
48 | # Calculate the total weight lifted initially, which is 2*20*12 pounds
49 | total_weight = 2 * 20 * 12
50 |
51 | # Since Terrell lifts two 15-pound weights, divide the total weight by 2 * 15
52 | repetitions = total_weight / (2*15)
53 |
54 | answer = repetitions
55 | ```
56 | The imports required for this program are
57 | ```
58 |
59 | ```
60 | I hope my solution is correct.
61 |
62 | Problem:
63 | If Anna flips 8 coins, what is the probability that she gets more heads than tails?
64 |
65 | You are an expert programmer. Solve the above mathematical problem by writing a Python program. Express your answer as a numeric type or a SymPy object.
66 | ```
67 | # There are 2**8 possible outcomes
68 | n = 8
69 | total_outcomes = 2 ** n
70 |
71 | # There are binom(n, k) ways to get k heads
72 | favorable_outcomes = 0
73 | for k in range((n // 2) + 1, n + 1):
74 | favorable_outcomes += math.comb(n, k)
75 |
76 | probability = favorable_outcomes / total_outcomes
77 |
78 | answer = probability
79 | ```
80 | The imports required for this program are
81 | ```
82 | import math
83 | ```
84 | I hope my solution is correct.
85 |
86 | Problem:
87 | Evaluate $\\left\\lceil3\\left(6-\\frac12\\right)\\right\\rceil$.
88 |
89 | You are an expert programmer. Solve the above mathematical problem by writing a Python program. Express your answer as a numeric type or a SymPy object.
90 | ```
91 | # Calculate 3 * (6 - 1/2)
92 | result = 3 * (6 - 0.5)
93 |
94 | # Apply the ceiling function
95 | ceiling_result = math.ceil(result)
96 |
97 | answer = ceiling_result
98 | ```
99 | The imports required for this program are
100 | ```
101 | import math
102 | ```
103 | I hope my solution is correct."""
104 |
105 | class PALMathPrompt(FewShotPrompting):
106 | def __init__(self):
107 | super().__init__()
108 |
109 | def format_prompt(self, task_input, task_output):
110 | prompt = f"{few_shot_prompt}\n\nProblem:\n{task_input}\n\nYou are an expert programmer. Solve the above mathematical problem by writing a Python program. Express your answer as a numeric type or a SymPy object.\n{task_output}"
111 | return prompt.rstrip()
112 |
113 | def stop_words(self):
114 | return ["\nProblem:", "Problem:"]
115 |
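Each demonstration in this prompt ends with two fenced blocks, the program followed by its imports, so a well-formed completion contains exactly four triple-backtick fences. run_pal_eval.py relies on that shape. Splitting on the fence yields five segments, the imports (segment 3) are placed in front of the program (segment 1), and any other shape becomes `answer = '[invalid]'`. Below is a minimal re-implementation, assuming a made-up completion and using a bare `exec` in place of the repo's `PythonExecutor(get_answer_symbol='answer')`.

```python
FENCE = "`" * 3  # triple backtick, built at runtime so this snippet's own fence stays intact


def extract_pal_math_program(text: str) -> str:
    # Same rule as run_pal_eval.py: program fence + imports fence = 4 fences.
    if text.count(FENCE) == 4:
        segments = text.split(FENCE)
        # segments: [preamble, program, "The imports required ...", imports, epilogue]
        return f"{segments[3]}\n\n{segments[1]}"
    return "answer = '[invalid]'"


def run_program(code: str):
    # Execute the program and read its `answer` variable (no sandbox, no timeout).
    namespace = {}
    exec(code, namespace)
    return namespace["answer"]


completion = (
    f"\n{FENCE}\n"
    "result = 3 * (6 - 0.5)\n"
    "answer = math.ceil(result)\n"
    f"{FENCE}\n"
    "The imports required for this program are\n"
    f"{FENCE}\n"
    "import math\n"
    f"{FENCE}\n"
    "I hope my solution is correct."
)
print(run_program(extract_pal_math_program(completion)))  # 17
print(run_program(extract_pal_math_program("answer = 1")))  # [invalid]
```

Placing the imports first is what lets the demonstrations call `math.comb` or SymPy's `solve` without importing them inside the program block.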
--------------------------------------------------------------------------------
/evaluation/infer/run_cot_eval.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | import sys
4 | import_path = os.path.abspath(__file__)
5 | for _ in range(2):
6 | import_path = os.path.dirname(import_path)
7 | sys.path.append(import_path)
8 |
9 | from tqdm import tqdm
10 | import json
11 | from copy import deepcopy
12 | from vllm import LLM, SamplingParams
13 | from pebble import ProcessPool
14 | from concurrent.futures import TimeoutError
15 | import random
16 | from eval.utils import generate_completions, load_hf_lm_and_tokenizer
17 |
18 | from transformers import AutoTokenizer
19 |
20 | from data_processing.answer_extraction import *
21 | from eval.eval_script import *
22 | from few_shot_prompts import *
23 |
24 | def evaluate(eval_fn, tasks, _timeout=15):
25 | with ProcessPool() as pool:
26 | timeout_cnt = 0
27 | iterator = pool.map(eval_fn, tasks, timeout=_timeout).result()
28 | labels = []
29 | while True:
30 | try:
31 | labels.append(int(next(iterator)))
32 | except StopIteration:
33 | break
34 | except TimeoutError as error:
35 | labels.append(0)
36 | timeout_cnt += 1
37 | except Exception as error:
38 | print(getattr(error, 'traceback', repr(error)), flush=True)
39 | sys.exit(1)
40 | return labels, timeout_cnt
41 |
42 | def infer(args, test_data):
43 | global tokenizer
44 | if tokenizer is None:
45 | tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_name_or_path, trust_remote_code=True)
46 |
47 | if args.prompt_format == 'few_shot':
48 | assert args.few_shot_prompt is not None
49 | prompting = eval(args.few_shot_prompt)()
50 |
51 | prompts = []
52 | for example in test_data:
53 | prompt = ""
54 | if args.prompt_format == 'few_shot':
55 | prompt = prompting.format_prompt(example['messages'][-2]['content'], example['messages'][-1]['content'])
56 | else:
57 | for mess in example['messages']:
58 | if args.prompt_format == 'sft':
59 | if mess['role'] == 'user':
60 | prompt += f"{tokenizer.eos_token}User: {mess['content'].strip()}\n\nAssistant:"
61 | elif mess['role'] == 'assistant':
62 | prompt += mess['content'].rstrip()
63 | else:
64 | raise NotImplementedError()
65 | prompt = prompt.lstrip()
66 | if args.prompt_format == 'sft' and prompt.startswith(tokenizer.eos_token):
67 | prompt = prompt[len(tokenizer.eos_token):].lstrip()
68 | example['prompt'] = prompt
69 | prompts.append(prompt.lstrip())
70 |
71 | global model
72 | print("Loading model and tokenizer...")
73 | if args.use_vllm:
74 | if model is None:
75 | model = LLM(model=args.model_name_or_path, tokenizer=args.tokenizer_name_or_path, trust_remote_code=True, tensor_parallel_size=len(os.environ['CUDA_VISIBLE_DEVICES'].split(",")))
76 | eos_token = tokenizer.eos_token if tokenizer is not None and tokenizer.eos_token is not None else ''
77 | stop_words = [eos_token]
78 | if args.prompt_format == 'few_shot':
79 | stop_words.extend(prompting.stop_words())
80 | outputs = model.generate(prompts, SamplingParams(temperature=args.temperature, top_p=1.0, max_tokens=1024, n=1, stop=stop_words))
81 | outputs = sorted(outputs, key=lambda x: int(x.request_id)) # sort outputs by request_id
82 | outputs = [output.outputs[0].text for output in outputs]
83 | else:
84 | model, tokenizer = load_hf_lm_and_tokenizer(
85 | model_name_or_path=args.model_name_or_path,
86 | tokenizer_name_or_path=args.tokenizer_name_or_path,
87 | load_in_8bit=args.load_in_8bit,
88 | load_in_half=args.load_in_half,
89 | gptq_model=args.gptq
90 | )
91 |
92 | stop_id_sequences = []
93 | if tokenizer.eos_token_id is not None:
94 | stop_id_sequences = [[tokenizer.eos_token_id]]
95 | if args.prompt_format == 'few_shot':
96 | stop_id_sequences.extend([tokenizer.encode(word) for word in prompting.stop_words()])
97 | outputs, finish_completion = generate_completions(
98 | model=model,
99 | tokenizer=tokenizer,
100 | prompts=prompts,
101 | max_new_tokens=512,
102 | batch_size=args.eval_batch_size,
103 | stop_id_sequences=stop_id_sequences if stop_id_sequences else None,
104 | end_of_generation_id_sequence=[tokenizer.eos_token_id] if tokenizer.eos_token_id is not None else None
105 | )
106 |
107 | if args.complete_partial_output:
108 | model_outputs = [example['messages'][-1]['content'] + output for example, output in zip(test_data, outputs)]
109 | else:
110 | model_outputs = outputs
111 |
112 | predictions = [eval(args.answer_extraction_fn)(item['messages'][-2]['content'], output, task='cot') for item, output in tqdm(zip(test_data, model_outputs), desc="extract answer", total=len(model_outputs))]
113 | assert len(model_outputs) > 0, f"{len(model_outputs)}"
114 |
115 | results = []
116 | for example, output, pred in zip(test_data, model_outputs, predictions):
117 | item = deepcopy(example)
118 | item.update({
119 | 'model_output': output,
120 | 'prediction': pred,
121 | })
122 | results.append(item)
123 | return results
124 |
125 | def main(args):
126 | random.seed(42)
127 |
128 | print("Loading data...")
129 | test_data = []
130 | with open(os.path.join(args.data_dir, f"train.jsonl" if args.infer_train_set else f"test.jsonl")) as fin:
131 | for line in fin:
132 | example = json.loads(line)
133 | messages = example['messages']
134 | assert messages[-1]['role'] == 'assistant'
135 | if not args.complete_partial_output:
136 | example['reference'] = example.get('reference', '') or [mess['content'] for mess in messages if mess['role'] == 'assistant']
137 | for mess in messages:
138 | if mess['role'] == 'assistant':
139 | mess['content'] = ''
140 | example['messages'] = messages
141 | test_data.append(example)
142 |
143 | if args.max_num_examples and len(test_data) > args.max_num_examples:
144 | test_data = random.sample(test_data, args.max_num_examples)
145 |
146 | if args.n_subsets > 1:
147 | assert args.subset_id >= 0 and args.subset_id < args.n_subsets
148 | test_data = [item for i, item in enumerate(test_data) if i % args.n_subsets == args.subset_id]
149 |
150 | if not test_data:
151 | return
152 |
153 | if not os.path.exists(args.save_dir):
154 | os.makedirs(args.save_dir, exist_ok=True)
155 |
156 | results = infer(args, test_data)
157 |
158 | labels, eval_timeout_cnt = evaluate(eval(args.eval_fn), results)
159 | for item, label in zip(results, labels):
160 | item['accuracy'] = label
161 |
162 | print("Calculating accuracy...")
163 | acc = 0
164 | for item in results:
165 | acc += item['accuracy']
166 | print("output acc = {:.5f}".format(acc / len(results) * 100), flush=True)
167 |
168 | print(f"Timeout count >>> output eval = {eval_timeout_cnt}", flush=True)
169 |
170 | pred_fname = "predictions.json"
171 | if args.n_subsets > 1:
172 | pred_fname = f"predictions.{args.subset_id}.json"
173 | with open(os.path.join(args.save_dir, pred_fname), "w") as fout:
174 | json.dump(results, fout, ensure_ascii=True)
175 |
176 | metric_fname = "metrics.json"
177 | if args.n_subsets > 1:
178 | metric_fname = f"metrics.{args.subset_id}.json"
179 | with open(os.path.join(args.save_dir, metric_fname), "w") as fout:
180 | json.dump({
181 | "n_samples": len(results),
182 | "accuracy": sum(item['accuracy'] for item in results) / len(results),
183 | }, fout, indent=4)
184 |
185 | if __name__ == "__main__":
186 | parser = argparse.ArgumentParser()
187 | parser.add_argument("--data_dir", type=str, default="data/mgsm")
188 | parser.add_argument("--max_num_examples", type=int, default=None, help="maximum number of examples to evaluate.")
189 | parser.add_argument("--save_dir", type=str, default="results/mgsm")
190 | parser.add_argument("--model_name_or_path", type=str, default=None, help="if specified, we will load the model to generate the predictions.")
191 | parser.add_argument("--tokenizer_name_or_path", type=str, default=None, help="if specified, we will load the tokenizer from here.")
192 | parser.add_argument("--eval_batch_size", type=int, default=1, help="batch size for evaluation.")
193 | parser.add_argument("--load_in_8bit", action="store_true", help="load model in 8bit mode to reduce GPU memory usage.")
194 | parser.add_argument("--gptq", action="store_true", help="If given, we're evaluating a 4-bit quantized GPTQ model.")
195 | parser.add_argument("--use_vllm", action="store_true")
196 | parser.add_argument("--load_in_half", action='store_true')
197 | parser.add_argument("--infer_train_set", action="store_true")
198 | parser.add_argument("--n_subsets", type=int, default=1)
199 | parser.add_argument("--subset_id", type=int, default=0)
200 | parser.add_argument("--temperature", type=float, default=0.0)
201 | parser.add_argument("--repeat_id_start", type=int, default=0)
202 | parser.add_argument("--n_repeat_sampling", type=int, default=1)
203 | parser.add_argument("--complete_partial_output", action='store_true')
204 | parser.add_argument("--prompt_format", type=str, choices=['sft', 'few_shot'], default='sft')
205 | parser.add_argument("--few_shot_prompt", type=str, default=None)
206 | parser.add_argument("--answer_extraction_fn", type=str, required=True)
207 | parser.add_argument("--eval_fn", type=str, required=True)
208 | parser.add_argument("--gpus", type=str, default=None)
209 | args, unparsed_args = parser.parse_known_args()
210 | if args.gpus is not None:
211 | os.environ['CUDA_VISIBLE_DEVICES'] = args.gpus
212 |
213 | print(unparsed_args, flush=True)
214 |
215 | if 'math6' in args.data_dir:
216 | args.multi_turn = True
217 |
218 | # model and tokenizer start unset; infer() loads them lazily and reuses them across repeated trials.
219 | model = None
220 | tokenizer = None
221 | pool = None
222 | if args.n_repeat_sampling > 1 or args.repeat_id_start != 0:
223 | assert args.temperature > 0
224 | save_dir = args.save_dir
225 | for i in range(args.repeat_id_start, args.repeat_id_start + args.n_repeat_sampling):
226 | print(f"working on the {i} trials ...", flush=True)
227 | args.save_dir = os.path.join(save_dir, str(i))
228 | os.makedirs(args.save_dir, exist_ok=True)
229 | main(args)
230 | else:
231 | main(args)
232 |
233 | if pool is not None:
234 | pool.close()
235 |
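In the `sft` format, infer() in run_cot_eval.py renders a conversation as alternating `User:` / `Assistant:` turns. It prefixes every user turn with the tokenizer's EOS token and then removes the EOS in front of the first turn. Because main() blanks assistant contents before inference (unless `--complete_partial_output` is set), the prompt ends with an open `Assistant:` for the model to complete. run_pal_eval.py uses the same layout without the EOS separator. The standalone sketch below assumes `"</s>"` as a placeholder EOS token (the scripts read `tokenizer.eos_token`) and uses hypothetical message contents.

```python
def build_sft_prompt(messages, eos_token):
    prompt = ""
    for mess in messages:
        if mess["role"] == "user":
            # Every user turn starts with EOS and ends with an open "Assistant:".
            prompt += f"{eos_token}User: {mess['content'].strip()}\n\nAssistant:"
        elif mess["role"] == "assistant":
            prompt += mess["content"].rstrip()
        else:
            raise NotImplementedError(mess["role"])
    prompt = prompt.lstrip()
    # Drop the EOS in front of the first user turn.
    if prompt.startswith(eos_token):
        prompt = prompt[len(eos_token):].lstrip()
    return prompt


EOS = "</s>"  # placeholder; the real value comes from the tokenizer
single = [
    {"role": "user", "content": "What is 2 + 3?"},
    {"role": "assistant", "content": ""},  # blanked before inference
]
multi = [  # hypothetical multi-turn conversation
    {"role": "user", "content": "What is 2 + 3?"},
    {"role": "assistant", "content": "5"},
    {"role": "user", "content": "Double it."},
    {"role": "assistant", "content": ""},
]
print(repr(build_sft_prompt(single, EOS)))
# 'User: What is 2 + 3?\n\nAssistant:'
print(repr(build_sft_prompt(multi, EOS)))
# 'User: What is 2 + 3?\n\nAssistant:5</s>User: Double it.\n\nAssistant:'
```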
--------------------------------------------------------------------------------
/evaluation/infer/run_pal_eval.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | import sys
4 | import_path = os.path.abspath(__file__)
5 | for _ in range(2):
6 | import_path = os.path.dirname(import_path)
7 | sys.path.append(import_path)
8 |
9 | import json
10 | from copy import deepcopy
11 | from functools import partial
12 | from vllm import LLM, SamplingParams
13 | from pebble import ProcessPool
14 | from concurrent.futures import TimeoutError
15 | import random
16 | from eval.utils import generate_completions, load_hf_lm_and_tokenizer
17 | from eval.python_executor import PythonExecutor
18 |
19 | from transformers import AutoTokenizer
20 |
21 | from data_processing.answer_extraction import *
22 | from eval.eval_script import *
23 | from few_shot_prompts import *
24 |
25 | def evaluate(eval_fn, tasks, _timeout=15):
26 | with ProcessPool() as pool:
27 | timeout_cnt = 0
28 | iterator = pool.map(eval_fn, tasks, timeout=_timeout).result()
29 | labels = []
30 | while True:
31 | try:
32 | labels.append(int(next(iterator)))
33 | except StopIteration:
34 | break
35 | except TimeoutError as error:
36 | labels.append(0)
37 | timeout_cnt += 1
38 | except Exception as error:
39 | print(getattr(error, 'traceback', repr(error)), flush=True)
40 | sys.exit(1)
41 | return labels, timeout_cnt
42 |
43 | def main(args):
44 | random.seed(42)
45 |
46 | print("Loading data...")
47 | test_data = []
48 | with open(os.path.join(args.data_dir, f"train.jsonl" if args.infer_train_set else f"test.jsonl")) as fin:
49 | for line in fin:
50 | example = json.loads(line)
51 | messages = example['messages']
52 | assert len(messages) in [2, 3]
53 | assert messages[-1]['role'] == 'assistant'
54 | if not args.complete_partial_output:
55 | example['reference'] = example.get('reference', '') or messages[-1]['content']
56 | messages[-1]['content'] = ''
57 | example['messages'] = messages
58 | test_data.append(example)
59 |
60 | if args.max_num_examples and len(test_data) > args.max_num_examples:
61 | test_data = random.sample(test_data, args.max_num_examples)
62 |
63 | if args.n_subsets > 1:
64 | assert args.subset_id >= 0 and args.subset_id < args.n_subsets
65 | test_data = [item for i, item in enumerate(test_data) if i % args.n_subsets == args.subset_id]
66 |
67 | if not test_data:
68 | return
69 |
70 | if not os.path.exists(args.save_dir):
71 | os.makedirs(args.save_dir, exist_ok=True)
72 |
73 | if args.prompt_format == 'few_shot':
74 | assert args.few_shot_prompt is not None
75 | prompting = eval(args.few_shot_prompt)()
76 |
77 | prompts = []
78 | for example in test_data:
79 | prompt = ""
80 | if args.prompt_format == 'few_shot':
81 | prompt = prompting.format_prompt(example['messages'][-2]['content'], example['messages'][-1]['content'])
82 | else:
83 | for mess in example['messages']:
84 | if args.prompt_format == 'sft':
85 | if mess['role'] == 'user':
86 | prompt += f"User: {mess['content'].strip()}\n\nAssistant:"
87 | elif mess['role'] == 'assistant':
88 | prompt += mess['content'].strip()
89 | else:
90 | raise NotImplementedError()
91 | prompt = prompt.lstrip()
92 | example['prompt'] = prompt
93 | prompts.append(prompt.lstrip())
94 |
95 | global model, tokenizer
96 | if tokenizer is None:
97 | tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_name_or_path, trust_remote_code=True)
98 | print("Loading model and tokenizer...")
99 | if args.use_vllm:
100 | if model is None:
101 | model = LLM(model=args.model_name_or_path, tokenizer=args.tokenizer_name_or_path, trust_remote_code=True, tensor_parallel_size=len(os.environ['CUDA_VISIBLE_DEVICES'].split(",")))
102 | eos_token = tokenizer.eos_token if tokenizer is not None and tokenizer.eos_token is not None else ''
103 | stop_words = [eos_token]
104 | if args.prompt_format == 'few_shot':
105 | stop_words.extend(prompting.stop_words())
106 | outputs = model.generate(prompts, SamplingParams(temperature=args.temperature, top_p=1.0, max_tokens=1024, n=1, stop=stop_words))
107 | outputs = sorted(outputs, key=lambda x: int(x.request_id)) # sort outputs by request_id
108 | outputs = [output.outputs[0].text for output in outputs]
109 | else:
110 | model, tokenizer = load_hf_lm_and_tokenizer(
111 | model_name_or_path=args.model_name_or_path,
112 | tokenizer_name_or_path=args.tokenizer_name_or_path,
113 | load_in_8bit=args.load_in_8bit,
114 | load_in_half=args.load_in_half,
115 | gptq_model=args.gptq
116 | )
117 |
118 | stop_id_sequences = []
119 | if tokenizer.eos_token_id is not None:
120 | stop_id_sequences = [[tokenizer.eos_token_id]]
121 | if args.prompt_format == 'few_shot':
122 | stop_id_sequences.extend([tokenizer.encode(word) for word in prompting.stop_words()])
123 | outputs, finish_completion = generate_completions(
124 | model=model,
125 | tokenizer=tokenizer,
126 | prompts=prompts,
127 | max_new_tokens=512,
128 | batch_size=args.eval_batch_size,
129 | stop_id_sequences=stop_id_sequences if stop_id_sequences else None,
130 | end_of_generation_id_sequence=[tokenizer.eos_token_id] if tokenizer.eos_token_id is not None else None
131 | )
132 |
133 | if args.complete_partial_output:
134 | model_outputs = [example['messages'][-1]['content'] + output for example, output in zip(test_data, outputs)]
135 | else:
136 | model_outputs = outputs
137 |
138 | if 'PALGSMPrompt' in args.few_shot_prompt:
139 | executor = PythonExecutor(get_answer_expr='solution()')
140 | codes = model_outputs
141 | elif 'PALMathPrompt' in args.few_shot_prompt:
142 | executor = PythonExecutor(get_answer_symbol='answer')
143 | codes = []
144 | for text in model_outputs:
145 | if text.count("```") == 4:
146 | segments = text.split("```")
147 | assert len(segments) == 5
148 | code = f"{segments[3]}\n\n{segments[1]}"
149 | else:
150 | code = "answer = '[invalid]'"
151 | codes.append(code)
152 | else:
153 | raise NotImplementedError()
154 |
155 | predictions = []
156 | runtime_errors = []
157 | for pred, err in executor.batch_apply(codes):
158 | predictions.append(str(pred))
159 | runtime_errors.append(str(err['exec_info']).strip())
160 |
161 | assert len(model_outputs) > 0, f"{len(model_outputs)}"
162 |
163 | results = []
164 | for example, output, pred in zip(test_data, model_outputs, predictions):
165 | item = deepcopy(example)
166 | item.update({
167 | 'model_output': output,
168 | 'program_output': pred,
169 | })
170 | results.append(item)
171 |
172 | labels, eval_timeout_cnt = evaluate(partial(eval(args.eval_fn), pred_key='program_output'), results)
173 | for item, label in zip(results, labels):
174 | item['accuracy'] = label
175 |
176 | print("Calculating accuracy...")
177 | acc = 0
178 | for item in results:
179 | acc += item['accuracy']
180 | print("output acc = {:.5f}".format(acc / len(results) * 100), flush=True)
181 |
182 | print(f"Timeout count >>> output eval = {eval_timeout_cnt}", flush=True)
183 |
184 | pred_fname = "predictions.json"
185 | if args.n_subsets > 1:
186 | pred_fname = f"predictions.{args.subset_id}.json"
187 | with open(os.path.join(args.save_dir, pred_fname), "w") as fout:
188 | json.dump(results, fout, ensure_ascii=True)
189 |
190 | metric_fname = "metrics.json"
191 | if args.n_subsets > 1:
192 | metric_fname = f"metrics.{args.subset_id}.json"
193 | with open(os.path.join(args.save_dir, metric_fname), "w") as fout:
194 | json.dump({
195 | "n_samples": len(results),
196 | "accuracy": sum(item['accuracy'] for item in results) / len(results),
197 | }, fout, indent=4)
198 |
199 | if __name__ == "__main__":
200 | parser = argparse.ArgumentParser()
201 | parser.add_argument("--data_dir", type=str, default="data/mgsm")
202 | parser.add_argument("--max_num_examples", type=int, default=None, help="maximum number of examples to evaluate.")
203 | parser.add_argument("--save_dir", type=str, default="results/mgsm")
204 | parser.add_argument("--model_name_or_path", type=str, default=None, help="if specified, we will load the model to generate the predictions.")
205 | parser.add_argument("--tokenizer_name_or_path", type=str, default=None, help="if specified, we will load the tokenizer from here.")
206 | parser.add_argument("--eval_batch_size", type=int, default=1, help="batch size for evaluation.")
207 | parser.add_argument("--load_in_8bit", action="store_true", help="load model in 8bit mode to reduce GPU memory usage.")
208 | parser.add_argument("--gptq", action="store_true", help="If given, we're evaluating a 4-bit quantized GPTQ model.")
209 | parser.add_argument("--use_vllm", action="store_true")
210 | parser.add_argument("--load_in_half", action='store_true')
211 | parser.add_argument("--infer_train_set", action="store_true")
212 | parser.add_argument("--n_subsets", type=int, default=1)
213 | parser.add_argument("--subset_id", type=int, default=0)
214 | parser.add_argument("--temperature", type=float, default=0.0)
215 | parser.add_argument("--repeat_id_start", type=int, default=0)
216 | parser.add_argument("--n_repeat_sampling", type=int, default=1)
217 | parser.add_argument("--complete_partial_output", action='store_true')
218 | parser.add_argument("--prompt_format", type=str, choices=['sft', 'few_shot'], default='sft')
219 | parser.add_argument("--few_shot_prompt", type=str, default=None)
220 | parser.add_argument("--answer_extraction_fn", type=str, default=None)
221 | parser.add_argument("--eval_fn", type=str, required=True)
222 | parser.add_argument("--gpus", type=str, default=None)
223 | args, unparsed_args = parser.parse_known_args()
224 | if args.gpus is not None:
225 | os.environ['CUDA_VISIBLE_DEVICES'] = args.gpus
226 |
227 | print(unparsed_args, flush=True)
228 |
229 | model = None
230 | tokenizer = None
231 | pool = None
232 | if args.n_repeat_sampling > 1 or args.repeat_id_start != 0:
233 | assert args.temperature > 0
234 | save_dir = args.save_dir
235 | for i in range(args.repeat_id_start, args.repeat_id_start + args.n_repeat_sampling):
236 | print(f"working on the {i} trials ...", flush=True)
237 | args.save_dir = os.path.join(save_dir, str(i))
238 | os.makedirs(args.save_dir, exist_ok=True)
239 | main(args)
240 | else:
241 | main(args)
242 |
243 | if pool is not None:
244 | pool.close()
245 |
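Both evaluation scripts shard their data with the rule `i % n_subsets == subset_id`. Each process first calls `random.seed(42)`, so every process samples the same list when `--max_num_examples` truncates it. Each process then writes `metrics.{subset_id}.json`, and `do_parallel_sampling` in run_subset_parallel.py recombines those files as a sample-weighted mean. The sketch below uses hypothetical 0/1 labels to check two things: the shards partition the data, and the weighted mean recovers the unsharded accuracy. It simplifies the aggregation to the single `accuracy` key.

```python
def shard(items, n_subsets, subset_id):
    # Round-robin split: example i belongs to subset i % n_subsets.
    assert 0 <= subset_id < n_subsets
    return [item for i, item in enumerate(items) if i % n_subsets == subset_id]


def weighted_accuracy(per_subset):
    # Weight each subset's accuracy by its n_samples; max(..., 1) avoids division by zero.
    n_samples = sum(m["n_samples"] for m in per_subset)
    total = sum(m["accuracy"] * m["n_samples"] for m in per_subset)
    return {"n_samples": n_samples, "accuracy": total / max(n_samples, 1)}


labels = [1, 0, 1, 1, 0, 1, 1, 1, 0, 1]  # hypothetical per-example correctness
n_subsets = 4

shards = [shard(labels, n_subsets, sid) for sid in range(n_subsets)]
per_subset = [{"n_samples": len(s), "accuracy": sum(s) / len(s)} for s in shards]
merged = weighted_accuracy(per_subset)

print([len(s) for s in shards])  # [3, 3, 2, 2]
print(merged["n_samples"], round(merged["accuracy"], 6))  # 10 0.7
```

Weighting by `n_samples` matters because round-robin shards differ in size whenever the example count is not a multiple of `n_subsets`. A plain mean of per-shard accuracies would be biased toward the smaller shards.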
--------------------------------------------------------------------------------
/evaluation/outputs.zip:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/deepseek-ai/DeepSeek-Math/b8b0f8ce093d80bf8e9a641e44142f06d092c305/evaluation/outputs.zip
--------------------------------------------------------------------------------
/evaluation/run_subset_parallel.py:
--------------------------------------------------------------------------------
1 | import os
2 | import argparse
3 | from tqdm import tqdm
4 | from glob import glob
5 | import time
6 | import json
7 | import subprocess
8 |
9 | from utils import read_data
10 | from data_processing.process_utils import *
11 |
12 | _worker_num = int(os.environ.get('WORLD_SIZE', 1))
13 | _worker_id = int(os.environ.get('RANK', 0))
14 |
15 | def markup_question(args, item, language, src, task):
16 | for i in range(len(item['messages']) - 2, -1, -2):
17 | if language == 'zh':
18 | if task == 'cot':
19 | item['messages'][i]['content'] = f"{item['messages'][i]['content']}\n请通过逐步推理来解答问题,并把最终答案放置于" + "\\boxed{}中。"
20 | elif task == 'tool':
21 | item['messages'][i]['content'] = f"{item['messages'][i]['content']}\n请结合自然语言和Python程序语言来解答问题,并把最终答案放置于" + "\\boxed{}中。"
22 | else:
23 | pass
24 | elif language == 'en':
25 | if task == 'cot':
26 | item['messages'][i]['content'] = f"{item['messages'][i]['content']}\nPlease reason step by step, and put your final answer within " + "\\boxed{}."
27 | elif task == 'tool':
28 | item['messages'][i]['content'] = f"{item['messages'][i]['content']}\nPlease integrate natural language reasoning with programs to solve the problem above, and put your final answer within " + "\\boxed{}."
29 | else:
30 | pass
31 | return item
32 |
33 | def do_parallel_sampling(args, task, answer_extraction_fn, eval_fn, input_dir, output_dir, log_dir):
34 | if task == 'pal':
35 | code_fname = "run_pal_eval"
36 | elif task == 'cot':
37 | code_fname = "run_cot_eval"
38 | elif task == 'tool':
39 | code_fname = "run_tool_integrated_eval"
40 | else:
41 | raise NotImplementedError()
42 |
43 | n_procs = args.ngpus // args.ngpus_per_model
44 |
45 | gpus = [str(i) for i in range(args.ngpus)]
46 | gpu_groups = []
47 | for i in range(n_procs):
48 | gpu_groups.append(gpus[i * args.ngpus_per_model: (i + 1) * args.ngpus_per_model])
49 |
50 | global_n_procs = n_procs * _worker_num
51 |
52 | procs = []
53 | for pid, gpus in enumerate(gpu_groups):
54 | global_pid = n_procs * (args.rank or _worker_id) + pid
55 | logpath = os.path.join(log_dir, f"{global_pid}.log")
56 | f = open(logpath, "w")
57 | cmd = f"python infer/{code_fname}.py " \
58 | f"--data_dir {input_dir} " \
59 | f"--max_num_examples 100000000000000 " \
60 | f"--save_dir {output_dir} " \
61 | f"--model {args.model_path} " \
62 | f"--tokenizer {args.tokenizer_path or args.model_path} " \
63 | f"--eval_batch_size 1 " \
64 | f"--temperature {args.temperature} " \
65 | f"--repeat_id_start 0 " \
66 | f"--n_repeat_sampling {args.n_repeats} " \
67 | f"--n_subsets {global_n_procs} " \
68 | f"--prompt_format {args.prompt_format} " \
69 | f"--few_shot_prompt {args.few_shot_prompt} " \
70 | f"--answer_extraction_fn {answer_extraction_fn} " \
71 | f"--eval_fn {eval_fn} " \
72 | f"--subset_id {global_pid} " \
73 | f"--gpus {','.join(gpus)} "
74 | if args.use_vllm:
75 | cmd += " --use_vllm "
76 | if args.load_in_half:
77 | cmd += " --load_in_half "
78 | local_metric_path = os.path.join(output_dir, f"metrics.{global_pid}.json")
79 | if not args.overwrite and os.path.exists(local_metric_path) and read_data(local_metric_path)['n_samples'] > 0:
80 | continue
81 | procs.append((global_pid, subprocess.Popen(cmd.split(), stdout=f, stderr=f), f))
82 | for (global_pid, proc, f) in procs:
83 | print(f"Waiting for the {global_pid}th process to finish ...", flush=True)
84 | proc.wait()
85 | for (global_pid, proc, f) in procs:
86 | print(f"Closing the {global_pid}th process ...", flush=True)
87 | f.close()
88 |
89 | time.sleep(1)
90 |
91 | local_pids = [global_pid for (global_pid, _, _) in procs]
92 |
93 | agg_preds = []
94 | for fname in glob(os.path.join(output_dir, "predictions.*.json")):
95 | if any(os.path.basename(fname) == f"predictions.{pid}.json" for pid in local_pids):
96 | agg_preds.extend(read_data(fname))
97 |
98 | metrics = {}
99 | n_samples = 0
100 | for fname in glob(os.path.join(output_dir, "metrics.*.json")):
101 | if not any(os.path.basename(fname) == f"metrics.{pid}.json" for pid in local_pids):
102 | continue
103 | _metrics = read_data(fname)
104 | n_samples += _metrics['n_samples']
105 | for key, val in _metrics.items():
106 | if key != 'n_samples':
107 | metrics[key] = metrics.get(key, 0) + val * _metrics['n_samples']
108 | for key, val in metrics.items():
109 | metrics[key] = val / max(n_samples, 1)
110 |
111 | result_msg = f"n samples = {n_samples}"
112 | for key, val in metrics.items():
113 | result_msg += f"\n{key} = {val * 100}"
114 |
115 | metrics['n_samples'] = n_samples
116 |
117 | return metrics, agg_preds, result_msg
118 |
119 | def main():
120 | parser = argparse.ArgumentParser()
121 | parser.add_argument("--output-dir", type=str, default=None, help="defaults to `<model_path>_predictions`")
122 | parser.add_argument("--model-path", type=str, required=True)
123 | parser.add_argument("--tokenizer-path", type=str, default=None)
124 | parser.add_argument("--model-size", type=str, choices=['1b', '7b', '13b', '33b', '34b', '70b'], default="7b")
125 |
126 | parser.add_argument("--test-conf", type=str, default="configs/zero_shot_test_configs.json", help="path to testing data config file that maps from a source to its info")
127 | parser.add_argument("--ngpus", type=int, default=8)
128 | parser.add_argument("--overwrite", action='store_true')
129 | parser.add_argument("--temperature", type=float, default=0)
130 | parser.add_argument("--n-repeats", type=int, default=1)
131 | parser.add_argument("--use-vllm", action='store_true')
132 | parser.add_argument("--load_in_half", action='store_true')
133 |
134 | parser.add_argument("--prompt_format", type=str, default="sft")
135 | parser.add_argument("--few_shot_prompt", type=str, default=None)
136 |
137 | parser.add_argument("--no-markup-question", action='store_true')
138 |
139 | parser.add_argument("--rank", type=int, default=None)
140 | parser.add_argument("--seed", type=int, default=42)
141 | args, _ = parser.parse_known_args()
142 |
143 | print(f"Evaluating {args.model_path}", flush=True)
144 |
145 | if args.output_dir is None:
146 | args.output_dir = f"{args.model_path.rstrip('/')}_predictions"
147 |
148 | args.ngpus_per_model = 4 if args.model_size in ['70b', '33b', '34b'] else 1
149 | assert args.ngpus % args.ngpus_per_model == 0
150 |
151 | default_few_shot_prompt = args.few_shot_prompt
152 |
153 | test_conf = read_data(args.test_conf)
154 |
155 | for src, info in test_conf.items():
156 | if args.n_repeats > 1:
157 | _src = f"{src}/sample_logs"
158 | else:
159 | _src = f"{src}/infer_logs"
160 | if _worker_num > 1:
161 | _src = f"{_src}/{args.rank or _worker_id}"
162 | if args.prompt_format == 'few_shot':
163 | args.few_shot_prompt = info.get('few_shot_prompt', None) or default_few_shot_prompt
164 | for task in info['tasks']:
165 | fname = os.path.join(args.output_dir, _src, task, "test_data", "test.jsonl")
166 | input_dir = os.path.dirname(fname)
167 | os.makedirs(input_dir, exist_ok=True)
168 | metric_path = os.path.join(args.output_dir, _src, task, "samples", "metrics.json")
169 | if not args.overwrite and os.path.exists(metric_path) and read_data(metric_path)['n_samples'] > 0:
170 | continue
171 | with open(fname, "w") as file:
172 | data = read_data(info['test_path'])
173 | for i, sample in enumerate(tqdm(data, desc=f'processing {src}')):
174 | fn = eval(info['process_fn'])
175 | sample['id'] = sample.get('id', f"{src}-{i}")
176 | for j, item in enumerate(fn(sample)):
177 | item['dataset'] = src
178 | item['id'] = f"{src}-test-{i}-{j}"
179 | assert 'answer' in item
180 | if not args.no_markup_question:
181 | item = markup_question(args, item, info['language'], src, task)
182 | print(json.dumps(item), file=file, flush=True)
183 |
184 | output_dir = os.path.join(args.output_dir, _src, task, "samples")
185 | log_dir = os.path.join(args.output_dir, _src, task, "logs")
186 | os.makedirs(output_dir, exist_ok=True)
187 | os.makedirs(log_dir, exist_ok=True)
188 | metrics, agg_preds, result_msg = do_parallel_sampling(args, task, info['answer_extraction_fn'], info['eval_fn'], input_dir, output_dir, log_dir)
189 |
190 | os.makedirs(os.path.dirname(metric_path), exist_ok=True)
191 | json.dump(metrics, open(metric_path, "w"), indent=4)
192 | data_path = os.path.join(args.output_dir, _src, task, "samples", "predictions.json")
193 | os.makedirs(os.path.dirname(data_path), exist_ok=True)
194 | with open(data_path, "w") as file:
195 | json.dump(agg_preds, file, ensure_ascii=False)
196 | print(f"src = {src} | task = {task} >>>\n{result_msg}\n\n", flush=True)
197 |
198 | if __name__ == '__main__':
199 | main()
200 |
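The shard merge in `do_parallel_sampling` computes a sample-weighted mean: each shard's metric is multiplied by its `n_samples`, summed, and divided by the total sample count. A minimal sketch of that computation as a standalone function, assuming shard files keep the `metrics.<pid>.json` naming; it parses the shard id from the filename and compares it exactly, because a substring test such as `str(pid) in fname` would let pid 1 match `metrics.11.json`. `aggregate_shard_metrics` and `SHARD_RE` are hypothetical names, not part of the repository.

```python
import json
import os
import re
import tempfile

# Matches shard files named like "metrics.<pid>.json" and captures the pid.
SHARD_RE = re.compile(r"^metrics\.(\d+)\.json$")


def aggregate_shard_metrics(output_dir, local_pids):
    """Combine per-shard metric files into sample-weighted averages."""
    wanted = set(local_pids)
    weighted, n_samples = {}, 0
    for fname in sorted(os.listdir(output_dir)):
        match = SHARD_RE.match(fname)
        # Compare the parsed pid exactly, so pid 1 never picks up metrics.11.json.
        if match is None or int(match.group(1)) not in wanted:
            continue
        with open(os.path.join(output_dir, fname)) as f:
            shard = json.load(f)
        n = shard["n_samples"]
        n_samples += n
        for key, val in shard.items():
            if key != "n_samples":
                weighted[key] = weighted.get(key, 0.0) + val * n
    # Guard against division by zero when no shard matched.
    metrics = {key: val / max(n_samples, 1) for key, val in weighted.items()}
    metrics["n_samples"] = n_samples
    return metrics


if __name__ == "__main__":
    with tempfile.TemporaryDirectory() as d:
        for pid, (acc, n) in {1: (0.5, 10), 11: (1.0, 30)}.items():
            with open(os.path.join(d, f"metrics.{pid}.json"), "w") as f:
                json.dump({"accuracy": acc, "n_samples": n}, f)
        print(aggregate_shard_metrics(d, [1]))      # {'accuracy': 0.5, 'n_samples': 10}
        print(aggregate_shard_metrics(d, [1, 11]))  # {'accuracy': 0.875, 'n_samples': 40}
```

Weighting by `n_samples` keeps a small shard from counting as much as a large one; an unweighted mean of the two shards above would report 0.75 instead of 0.875.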
--------------------------------------------------------------------------------
/evaluation/submit_eval_jobs.py:
--------------------------------------------------------------------------------
1 | import os
2 | import argparse
3 |
4 | configs = [
5 | {
6 | 'output-dir': "outputs/DeepSeekMath-Base",
7 | 'model-path': "deepseek-ai/deepseek-math-7b-base",
8 | 'tokenizer-path': "deepseek-ai/deepseek-math-7b-base",
9 | 'model-size': "7b",
10 | 'overwrite': False,
11 | 'use-vllm': True,
12 | 'no-markup-question': True,
13 | 'test-conf': "configs/few_shot_test_configs.json",
14 | 'prompt_format': 'few_shot',
15 | 'expname': 'eval-deepseek-math-7b-base'
16 | },
17 | {
18 | 'output-dir': "outputs/DeepSeekMath-Instruct",
19 | 'model-path': "deepseek-ai/deepseek-math-7b-instruct",
20 | 'tokenizer-path': "deepseek-ai/deepseek-math-7b-instruct",
21 | 'model-size': "7b",
22 | 'overwrite': False,
23 | 'use-vllm': True,
24 | 'test-conf': "configs/zero_shot_test_configs.json",
25 | 'expname': 'eval-deepseek-math-7b-instruct'
26 | },
27 | {
28 | 'output-dir': "outputs/DeepSeekMath-RL",
29 | 'model-path': "deepseek-ai/deepseek-math-7b-rl",
30 | 'tokenizer-path': "deepseek-ai/deepseek-math-7b-rl",
31 | 'model-size': "7b",
32 | 'overwrite': False,
33 | 'use-vllm': True,
34 | 'test-conf': "configs/zero_shot_test_configs.json",
35 | 'expname': 'eval-deepseek-math-7b-rl'
36 | }
37 | ]
38 |
39 | base_conf, instruct_conf, rl_conf = configs
40 |
41 | def main():
42 | parser = argparse.ArgumentParser()
43 | parser.add_argument("--n-repeats", type=int, default=1)
44 | parser.add_argument("--temperature", type=float, default=0)
45 | parser.add_argument("--n-gpus", type=int, default=8)
46 | args = parser.parse_args()
47 |
48 | conf = base_conf # TODO: your conf here
49 | cmd = "python run_subset_parallel.py"
50 | for key, val in conf.items():
51 | if key in ('expname', 'test-conf'):  # --test-conf is appended explicitly below
52 | continue
53 | if isinstance(val, str):
54 | cmd += f" --{key} {val}"
55 | elif val:
56 | cmd += f" --{key}"
57 | cmd += f" --test-conf {conf['test-conf']}"
58 | cmd += f" --n-repeats {args.n_repeats}"
59 | cmd += f" --temperature {args.temperature}"
60 | cmd += f" --ngpus {args.n_gpus}"
61 | cmd += f" --rank {0} &"
62 | print(cmd, flush=True)
63 | os.system(cmd)
64 |
65 | if __name__ == '__main__':
66 | main()
67 |
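`submit_eval_jobs.py` turns a config dict into command-line flags: string values become `--key value`, truthy non-string values become bare `--key` switches, and falsy ones such as `'overwrite': False` are omitted. A minimal sketch of that mapping as a hypothetical `conf_to_argv` helper that returns an argument list instead of a shell string, so it could be passed to `subprocess.run` without a shell and values containing spaces stay intact. It does not reproduce the trailing `&` that backgrounds the original command.

```python
import shlex


def conf_to_argv(conf, skip=("expname",)):
    """Map a config dict to argv tokens the way submit_eval_jobs.py builds flags."""
    argv = []
    for key, val in conf.items():
        if key in skip:
            continue
        if isinstance(val, str):
            argv += [f"--{key}", val]   # string value -> "--key value"
        elif val:
            argv.append(f"--{key}")     # truthy non-string -> bare switch
        # falsy non-string values (e.g. 'overwrite': False) are dropped
    return argv


if __name__ == "__main__":
    conf = {
        "model-path": "deepseek-ai/deepseek-math-7b-base",
        "overwrite": False,
        "use-vllm": True,
        "expname": "eval-deepseek-math-7b-base",
    }
    argv = ["python", "run_subset_parallel.py"] + conf_to_argv(conf) + ["--n-repeats", "1"]
    # Printed with shell quoting for display only.
    print(shlex.join(argv))
```

Keeping the command as a list means the same tokens reach `run_subset_parallel.py` regardless of shell quoting rules.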
--------------------------------------------------------------------------------
/evaluation/summarize_results.py:
--------------------------------------------------------------------------------
1 | import os
2 | import json
3 | import argparse
4 |
5 | from glob import glob
6 | from copy import deepcopy
7 |
8 | def seek_metrics(path):
9 | if os.path.isdir(path):
10 | for subpath in glob(os.path.join(path, "*")):
11 | yield from seek_metrics(subpath)
12 | else:
13 | if "metrics.json" in path:
14 | yield path
15 |
16 | def seek_predictions(path):
17 | if os.path.isdir(path):
18 | for subpath in glob(os.path.join(path, "*")):
19 | yield from seek_predictions(subpath)
20 | else:
21 | if "predictions.json" in path:
22 | yield path
23 |
24 | def aggregate_metrics(paths):
25 | result = {}
26 | total = 0
27 | for path in paths:
28 | metric = json.load(open(path, "r"))
29 | n_samples = metric['n_samples']
30 | total += n_samples
31 | for key, val in metric.items():
32 | if key != 'n_samples':
33 | result[key] = result.get(key, 0) + val * n_samples
34 | for key, val in result.items():
35 | result[key] = val / total
36 | result['n_samples'] = total
37 | return result
38 |
39 | def aggregate_predictions(paths):
40 | data = []
41 | for path in paths:
42 | try:
43 | data.extend(json.load(open(path, "r")))
44 | except (OSError, ValueError):  # unreadable file or invalid JSON
45 | print(f"Skipping unreadable predictions file: {path}", flush=True)
46 | continue
47 | return data
48 |
49 | def main():
50 | parser = argparse.ArgumentParser()
51 | parser.add_argument("--dirname", type=str, default="outputs")
52 | parser.add_argument("--eval-atp", action='store_true')
53 | parser.add_argument("--isa-path", type=str, default="")
54 | parser.add_argument("--theory-file", type=str, default="")
55 | parser.add_argument("--working-dir", type=str, default="")
56 | args = parser.parse_args()
57 | model2dataset2task2metric = {}
58 | for model in os.listdir(args.dirname):
59 | model2dataset2task2metric[model] = {}
60 | subdir = os.path.join(args.dirname, model)
61 | for dataset in os.listdir(subdir):
62 | log_dir = os.path.join(subdir, dataset, "infer_logs")
63 | agg_dirname = os.path.join(subdir, dataset, "results")
64 | if not os.path.exists(log_dir):
65 | os.makedirs(log_dir, exist_ok=True)
66 | os.system(f"mv {subdir}/{dataset}/* {log_dir}")
67 | metric_paths = list(seek_metrics(log_dir))
68 | pred_paths = list(seek_predictions(log_dir))
69 | task2metric_paths = {'cot': [], 'tool': []}
70 | task2pred_paths = {'cot': [], 'tool': []}
71 | for path in metric_paths:
72 | if 'cot' in path:
73 | task2metric_paths['cot'].append(path)
74 | else:
75 | task2metric_paths['tool'].append(path)
76 | for path in pred_paths:
77 | if 'cot' in path:
78 | task2pred_paths['cot'].append(path)
79 | else:
80 | task2pred_paths['tool'].append(path)
81 | task2metric = {task: aggregate_metrics(paths) for task, paths in task2metric_paths.items()}
82 | task2pred = {task: aggregate_predictions(paths) for task, paths in task2pred_paths.items()}
83 | model2dataset2task2metric[model][dataset] = task2metric
84 |
85 | for task in task2metric:
86 | task_dirname = os.path.join(agg_dirname, task)
87 | os.makedirs(task_dirname, exist_ok=True)
88 | metric_path = os.path.join(task_dirname, "metrics.json")
89 | pred_path = os.path.join(task_dirname, "predictions.json")
90 | json.dump(task2metric[task], open(metric_path, "w"), indent=4)
91 | json.dump(task2pred[task], open(pred_path, "w"), indent=4)
92 | if 'minif2f' in dataset.lower() and 'isabelle' in dataset.lower() and task2pred[task] and args.eval_atp:
93 | eval_path = metric_path + ".eval"
94 | if os.path.exists(eval_path) and json.load(open(eval_path, "r")).get('n_samples', 0):
95 | model2dataset2task2metric[model][dataset][task] = json.load(open(eval_path, "r"))
96 | continue
97 | print(f"Running minif2f-isabelle evaluation on {dataset} ...", flush=True)
98 | print(f"Predictions >>> {pred_path}", flush=True)
99 | cmd = f"PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python python unsafe_score_minif2f_isabelle.py " \
100 | f"--isa-path {args.isa_path} " \
101 | f"--theory-file {args.theory_file} " \
102 | f"--working-dir {args.working_dir} " \
103 | f"--port 9000 " \
104 | f"--output {pred_path} "
105 | os.system(cmd)
106 |
107 | json.dump(model2dataset2task2metric, open("evaluation_results.json", "w"), indent=4, ensure_ascii=False)
108 |
109 | if __name__ == '__main__':
110 | main()
111 |
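`seek_metrics` and `seek_predictions` walk a directory tree recursively and yield every path that contains the target filename as a substring. A minimal sketch of the same search built on `os.walk`, using an exact basename comparison so that neighbours such as `metrics.json.eval` are not picked up; `find_files` is a hypothetical name, and sorting is added only to make the output order deterministic.

```python
import os
import tempfile


def find_files(root, filename):
    """Yield paths under root whose basename equals filename, in sorted order."""
    for dirpath, dirnames, filenames in os.walk(root):
        dirnames.sort()  # in-place sort makes os.walk descend in a deterministic order
        for name in sorted(filenames):
            if name == filename:
                yield os.path.join(dirpath, name)


if __name__ == "__main__":
    with tempfile.TemporaryDirectory() as root:
        for sub in ("cot", "tool"):
            os.makedirs(os.path.join(root, sub, "samples"))
            open(os.path.join(root, sub, "samples", "metrics.json"), "w").close()
        # An exact match skips this file; a substring test would not.
        open(os.path.join(root, "cot", "samples", "metrics.json.eval"), "w").close()
        for path in find_files(root, "metrics.json"):
            print(os.path.relpath(path, root))
```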
--------------------------------------------------------------------------------
/evaluation/unsafe_score_minif2f_isabelle.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import logging
3 | import json
4 | import sys
5 | import os
6 | import time
7 | from tqdm import tqdm
8 | import traceback
9 |
10 | class Checker(object):
11 | """A modified version of the Draft, Sketch, Prove proof-checking client.
12 | (https://github.com/albertqjiang/draft_sketch_prove/blob/main/autoformalization/checker.py)
13 |
14 | This checker supports Isabelle2022 via PISA
15 | (https://albertqjiang.github.io/Portal-to-ISAbelle/).
16 |
17 | It supports checking a miniF2F-style proof via `check`.
18 |
19 | Finally, it replaces `sledgehammer` with a call to `normalhammer`.
20 | """
21 | def __init__(self, working_dir, isa_path, theory_file, port=9000):
22 | try:
23 | sys.path.append(os.environ['PISA_PATH'])
24 | from pisa_client import initialise_env
25 | self.initialise_env = initialise_env
26 | except Exception as e:
27 | traceback.print_exc()
28 | print(e)
29 | print("Set $PISA_PATH to /yourpath/to/Portal-to-ISAbelle/src/main/python")
30 |
31 | self.working_dir = working_dir
32 | self.isa_path = isa_path
33 | self.theory_file = theory_file
34 | self.port = port
35 |
36 | def _initialize(self):
37 | env = self.initialise_env(
38 | self.port,
39 | isa_path=self.isa_path,
40 | theory_file_path=self.theory_file,
41 | working_directory=self.working_dir
42 | )
43 | return env
44 |
45 | def _exit(self, env):
46 | try:
47 | env.post('exit')
48 | except:
49 | print("env.post('exit') timed out")
50 | pass
51 | os.system("ps aux | grep Isabelle2022/contrib | awk '{print $2}' | xargs kill -9 > /dev/null 2>&1")
52 | os.system("ps aux | grep poly | awk '{print $2}' | xargs kill -9 > /dev/null 2>&1")
53 |
54 | def _parse_output(self, obs):
55 | """Parse the sledgehammer output, otherwise return an empty string"""
56 | if '<hammer>' in obs:
57 | output = obs.split('<hammer>')[0]
58 | else:
59 | output = ''
60 | return output
61 |
62 | def _run_step(self, step, i, tls_name, env):
63 | obs, reward, done, metadata = env.step_to_top_level_state(
64 | action=step,
65 | tls_name=tls_name,
66 | new_name='default_%d' % i
67 | )
68 | error = None
69 | if 'error:' in obs or 'Step error' in obs or 'Unknown error' in obs:
70 | error = obs
71 | return obs, reward, done, metadata, error
72 |
73 | def _run_sledgehammer(self, step, i, tls_name, env):
74 | # First try heuristics
75 | for heuristic in [
76 | 'by auto', 'by simp', 'by blast', 'by fastforce',
77 | 'by force', 'by eval', 'by presburger', 'by sos',
78 | 'by arith', 'by linarith', 'by (auto simp: field_simps)'
79 | ]:
80 | step_ = step.replace('normalhammer', heuristic)
81 | obs, reward, done, metadata, error = self._run_step(step_, i, tls_name, env)
82 | if error is None:
83 | obs = '%s %s' % (heuristic, obs)
84 | return obs, reward, done, metadata, error
85 | # Try sledgehammer
86 | out = self._run_step(step, i, tls_name, env)
87 | return out
88 |
89 | def check(self, statement_and_proof):
90 | # Initialize environment
91 | env = self._initialize()
92 | env.initialise()
93 |
94 | # Wrap and parse theorem
95 | theory = Checker.wrap_theorem(statement_and_proof)
96 | steps = Checker.get_parsed(env, theory)
97 |
98 | result = self._check(env, steps)
99 | return result
100 |
101 | def _check(self, env, steps):
102 | done = False
103 | reason = ''
104 | success = False
105 | step_results = []
106 | tls_name = 'default'
107 | for i, step in enumerate(steps):
108 | try:
109 | time0 = time.time()
110 | if 'normalhammer' in step:
111 | obs, reward, done, metadata, error = self._run_sledgehammer(step, i, tls_name, env)
112 | else:
113 | obs, reward, done, metadata, error = self._run_step(step, i, tls_name, env)
114 | step_time = time.time() - time0
115 | step_results.append(dict(
116 | index=i, step=step, output=self._parse_output(obs), step_time=step_time
117 | ))
118 | if error is not None:
119 | reason = error
120 | success = False
121 | done = False
122 | break
123 | except Exception:
124 | # Timeout - end the proof attempt
125 | success = False
126 | done = False
127 | reason = 'timeout (%d)' % len(step_results)
128 | step_results.append(dict(index=i, step=step, output=''))
129 | break
130 |
131 | # Change when successful
132 | tls_name = 'default_%d' % i
133 |
134 | if done and reward == 1.0:
135 | success = True
136 |
137 | result = {
138 | 'success': success,
139 | 'reason': reason,
140 | 'num_steps': len(steps),
141 | 'last_step': len(step_results),
142 | 'step_results': step_results
143 | }
144 | # Exit environment
145 | self._exit(env)
146 | return result
147 |
148 | @staticmethod
149 | def wrap_theorem(theorem):
150 | return 'theory Interactive imports HOL.HOL Complex_Main "HOL-Library.Code_Target_Numeral" "HOL-Library.Sum_of_Squares" "Symmetric_Polynomials.Vieta" "HOL-Computational_Algebra.Computational_Algebra" "HOL-Number_Theory.Number_Theory" \n begin\n%s' % theorem
151 |
152 | @staticmethod
153 | def get_parsed(env, theory, tls_name='default'):
154 | # The parsing doesn't work well with `normalhammer`, so we replace
155 | # all hammer calls with sorry, then replace sorry to normalhammer after parsing.
156 | theory = theory.replace('sledgehammer', 'sorry')
157 | theory = theory.replace('normalhammer', 'sorry')
158 |
159 | steps = env.post(f"<parse text> ${theory}")
160 | steps = steps.split('<SEP>')
161 | steps = [s for s in steps if s.strip() != '']
162 | # remove '$' step and whitespace steps
163 | steps = [s for s in steps if s != '$' and s.strip() != '']
164 | steps = [s.replace('sorry', 'normalhammer') for s in steps]
165 | return steps
166 |
167 |
168 | def check_proof(formal_statement, proof, working_dir, isa_path, theory_file, port):
169 | checker = Checker(
170 | working_dir=working_dir,
171 | isa_path=isa_path,
172 | theory_file=theory_file,
173 | port=port
174 | )
175 | theorem_with_proof = f"{formal_statement}\n{proof}"
176 | result = checker.check(theorem_with_proof)
177 | return result
178 |
179 |
180 | def main(args):
181 | with open(args.output) as f:
182 | docs = json.load(f)
183 |
184 | if args.limit:
185 | limit = args.limit
186 | else:
187 | limit = len(docs)
188 |
189 | pass_at_1s = []
190 | pass_at_anys = []
191 | for i, doc in enumerate(tqdm(docs[:limit])):
192 | formal_statement = doc['messages'][-2]['content'].split("Formal:", 1)[1].strip()
193 | proofs = [doc['prediction'].strip()]
194 |
195 | pass_at_1 = 0
196 | pass_at_any = 0
197 | checked_proofs = []
198 | for j, proof in enumerate(proofs):
199 | result = check_proof(
200 | formal_statement=formal_statement,
201 | proof=proof,
202 | working_dir=args.working_dir,
203 | isa_path=args.isa_path,
204 | theory_file=args.theory_file,
205 | port=args.port
206 | )
207 |
208 | if result['success']:
209 | pass_at_any = 1
210 | if j == 0:
211 | pass_at_1 = 1
212 |
213 | checked_proofs.append({
214 | 'proof': proof,
215 | 'result': result
216 | })
217 |
218 | pass_at_1s.append(pass_at_1)
219 | pass_at_anys.append(pass_at_any)
220 |
221 | print(f"acc: {sum(pass_at_1s)} / {len(pass_at_1s)} = {sum(pass_at_1s) / max(len(pass_at_1s), 1)}", flush=True)
222 |
223 | doc['eval'] = {
224 | 'checked_proofs': checked_proofs,
225 | 'pass_at_1': pass_at_1,
226 | 'pass_at_any': pass_at_any
227 | }
228 |
229 | metrics = {
230 | "pass_at_1": sum(pass_at_1s) / max(len(pass_at_1s), 1),
231 | "pass_at_any": sum(pass_at_anys) / max(len(pass_at_anys), 1),
232 | "n_samples": len(pass_at_1s)
233 | }
234 |
235 | output_path = args.output + ".eval"
236 | metrics_path = os.path.join(os.path.dirname(args.output), "metrics.json.eval")
237 | json.dump(docs, open(output_path, "w"), indent=4)
238 | json.dump(metrics, open(metrics_path, "w"), indent=4)
239 |
240 |
241 | if __name__ == "__main__":
242 | logging.basicConfig(level=logging.INFO)
243 | logging.critical(
244 | "THIS PROGRAM EXECUTES UNTRUSTED MODEL GENERATED CODE. "
245 | "THERE HAS BEEN NO EFFORT TO AVOID OS AND NETWORK SIDE EFFECTS. "
246 | "USE WITH CAUTION."
247 | )
248 |
249 | parser = argparse.ArgumentParser("Unsafe script for scoring the minif2f_isabelle tasks")
250 |
251 | parser.add_argument(
252 | "--isa-path",
253 | type=str,
254 | help="path to Isabelle installation (see setup documentation), e.g. "
255 | "/path/to/Isabelle2022"
256 | )
257 | parser.add_argument(
258 | "--theory-file",
259 | type=str,
260 | help="path to Interactive.thy (see setup documentation), e.g. "
261 | "/path/to/Isabelle2022/src/HOL/Examples/Interactive.thy"
262 | )
263 | parser.add_argument(
264 | "--working-dir",
265 | type=str,
266 | help="path to Isabelle working directory (see setup documentation), e.g. "
267 | "/path/to/Isabelle2022/src/HOL/Examples"
268 | )
269 | parser.add_argument(
270 | "--port",
271 | type=int,
272 | default=9000,
273 | help="PISA server port (see setup documentation)"
274 | )
275 | parser.add_argument(
276 | "--output",
277 | type=str,
278 | help="path to output file from running miniF2F Isabelle tasks"
279 | )
280 | parser.add_argument(
281 | "--limit",
282 | type=int,
283 | default=None,
284 | help="for debugging purposes, max examples per task to process"
285 | )
286 |
287 | args = parser.parse_args()
288 | main(args)
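`Checker._run_sledgehammer` first substitutes each tactic in a fixed list (`by auto`, `by simp`, `by blast`, ...) for the `normalhammer` placeholder and keeps the first one that runs without error, only then falling back to the original step. A minimal, environment-free sketch of that control flow, assuming a `run_step` callable that returns `None` on success and an error string otherwise; it stands in for the PISA call, and the toy checker in the usage example is invented for illustration, not Isabelle behaviour.

```python
from typing import Callable, Optional, Sequence, Tuple

# Same tactic order as the heuristic list in Checker._run_sledgehammer.
HEURISTICS = [
    "by auto", "by simp", "by blast", "by fastforce",
    "by force", "by eval", "by presburger", "by sos",
    "by arith", "by linarith", "by (auto simp: field_simps)",
]


def try_heuristics(step: str,
                   run_step: Callable[[str], Optional[str]],
                   heuristics: Sequence[str] = HEURISTICS) -> Tuple[str, Optional[str]]:
    """Return (step actually used, error or None) for the first tactic that checks."""
    for tactic in heuristics:
        candidate = step.replace("normalhammer", tactic)
        if run_step(candidate) is None:  # None means the step checked cleanly
            return candidate, None
    # No heuristic closed the goal: fall back to the original step.
    return step, run_step(step)


if __name__ == "__main__":
    # Toy checker: only "by linarith" succeeds; everything else reports an error.
    def toy_run_step(s: str) -> Optional[str]:
        return None if s.endswith("by linarith") else f"error: {s}"

    print(try_heuristics("show ?thesis normalhammer", toy_run_step))
    # ('show ?thesis by linarith', None)
```

Trying cheap tactics first avoids a slow hammer call whenever a simple proof method already closes the goal.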
--------------------------------------------------------------------------------
/evaluation/utils.py:
--------------------------------------------------------------------------------
1 | import json
2 | import random
3 | import numpy as np
4 |
5 | def set_seed(seed):
6 | if seed >= 0:
7 | random.seed(seed)
8 | np.random.seed(seed)
9 |
10 | def shuffle(data, seed):
11 | if seed < 0:
12 | return data
13 | set_seed(seed)
14 | indices = list(range(len(data)))
15 | np.random.shuffle(indices)
16 | data = [data[i] for i in indices]
17 | return data
18 |
19 | def read_data(path):
20 | if path.endswith("json"):
21 | data = json.load(open(path, "r"))
22 | elif path.endswith("jsonl"):
23 | data = []
24 | with open(path, "r") as file:
25 | for line in file:
26 | line = json.loads(line)
27 | data.append(line)
28 | else:
29 | raise NotImplementedError()
30 | return data
31 |
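`shuffle` seeds the global `random` and `numpy` generators and then permutes indices, so its result is tied to process-wide state that other code may also touch. A minimal sketch of the same behaviour (seeded shuffle, negative seed keeps the original order) using a private `numpy.random.default_rng` generator instead; `shuffle_local` is a hypothetical name, and for a given seed its permutation differs from the legacy `np.random.seed` path, so it is not a drop-in replacement for reproducing existing runs.

```python
import numpy as np


def shuffle_local(data, seed):
    """Return a shuffled copy of data; a negative seed leaves the order unchanged."""
    if seed < 0:
        return list(data)
    rng = np.random.default_rng(seed)  # private generator, no global side effects
    indices = rng.permutation(len(data))
    return [data[i] for i in indices]


if __name__ == "__main__":
    items = list("abcdef")
    print(shuffle_local(items, 42) == shuffle_local(items, 42))  # True: same seed, same order
    print(shuffle_local(items, -1))  # ['a', 'b', 'c', 'd', 'e', 'f']
```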
--------------------------------------------------------------------------------
/images/badge.svg:
--------------------------------------------------------------------------------
1 |
2 |
--------------------------------------------------------------------------------
/images/base_results_1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/deepseek-ai/DeepSeek-Math/b8b0f8ce093d80bf8e9a641e44142f06d092c305/images/base_results_1.png
--------------------------------------------------------------------------------
/images/base_results_2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/deepseek-ai/DeepSeek-Math/b8b0f8ce093d80bf8e9a641e44142f06d092c305/images/base_results_2.png
--------------------------------------------------------------------------------
/images/base_results_3.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/deepseek-ai/DeepSeek-Math/b8b0f8ce093d80bf8e9a641e44142f06d092c305/images/base_results_3.png
--------------------------------------------------------------------------------
/images/corpus.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/deepseek-ai/DeepSeek-Math/b8b0f8ce093d80bf8e9a641e44142f06d092c305/images/corpus.png
--------------------------------------------------------------------------------
/images/data_pipeline.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/deepseek-ai/DeepSeek-Math/b8b0f8ce093d80bf8e9a641e44142f06d092c305/images/data_pipeline.png
--------------------------------------------------------------------------------
/images/instruct_results.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/deepseek-ai/DeepSeek-Math/b8b0f8ce093d80bf8e9a641e44142f06d092c305/images/instruct_results.png
--------------------------------------------------------------------------------
/images/logo.svg:
--------------------------------------------------------------------------------
1 |
23 |
--------------------------------------------------------------------------------
/images/math.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/deepseek-ai/DeepSeek-Math/b8b0f8ce093d80bf8e9a641e44142f06d092c305/images/math.png
--------------------------------------------------------------------------------
/images/qr.jpeg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/deepseek-ai/DeepSeek-Math/b8b0f8ce093d80bf8e9a641e44142f06d092c305/images/qr.jpeg
--------------------------------------------------------------------------------
/replicate/predict.py:
--------------------------------------------------------------------------------
1 | # Prediction interface for Cog ⚙️
2 | # https://github.com/replicate/cog/blob/main/docs/python.md
3 |
4 | import os
5 | import time
6 | from threading import Thread
7 | import torch
8 | from transformers import AutoTokenizer, AutoModelForCausalLM, GenerationConfig
9 | from transformers.generation.streamers import TextIteratorStreamer
10 | from cog import BasePredictor, Input, ConcatenateIterator
11 |
12 | # Enable faster download speed
13 | os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"
14 | CACHE_DIR = "model_cache"
15 |
16 |
17 | class Predictor(BasePredictor):
18 | def setup(self) -> None:
19 | """Load the model into memory to make running multiple predictions efficient"""
20 |
21 | model_name = "deepseek-ai/deepseek-math-7b-base"
22 | self.tokenizer = AutoTokenizer.from_pretrained(model_name, cache_dir=CACHE_DIR)
23 | self.model = AutoModelForCausalLM.from_pretrained(
24 | model_name,
25 | torch_dtype=torch.bfloat16,
26 | device_map="auto",
27 | cache_dir=CACHE_DIR,
28 | )
29 | self.model.generation_config = GenerationConfig.from_pretrained(
30 | model_name, cache_dir=CACHE_DIR
31 | )
32 | self.model.generation_config.pad_token_id = (
33 | self.model.generation_config.eos_token_id
34 | )
35 |
36 | def predict(
37 | self,
38 | text: str = Input(
39 | description="Input text.",
40 | default="The integral of x^2 from 0 to 2 is",
41 | ),
42 | max_new_tokens: int = Input(
43 | description="The maximum number of tokens to generate, ignoring the number of tokens in the prompt.",
44 | default=100,
45 | ),
46 | temperature: float = Input(
47 | description="The value used to modulate the next token probabilities.",
48 | default=1,
49 | ),
50 | top_k: int = Input(
51 | description="The number of highest probability vocabulary tokens to keep for top-k-filtering.",
52 | default=50,
53 | ),
54 | top_p: float = Input(
55 | description="If set to float < 1, only the smallest set of most probable tokens with probabilities that add up to top_p or higher are kept for generation.",
56 | default=0.9,
57 | ),
58 | ) -> ConcatenateIterator[str]:
59 | """Run a single prediction on the model"""
60 |
61 | inputs = self.tokenizer(text, return_tensors="pt")
62 | streamer = TextIteratorStreamer(
63 | self.tokenizer, skip_prompt=True, skip_special_tokens=True
64 | )
65 | with torch.inference_mode():
66 | thread = Thread(
67 | target=self.model.generate,
68 | kwargs=dict(
69 | **inputs.to(self.model.device),
70 | do_sample=True,
71 | temperature=temperature,
72 | top_p=top_p,
73 | top_k=top_k,
74 | max_new_tokens=max_new_tokens,
75 | streamer=streamer,
76 | use_cache=True
77 | ),
78 | )
79 | thread.start()
80 | for new_token in streamer:
81 | yield new_token
82 | thread.join()
83 |
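`predict` runs `model.generate` on a background thread and yields text as the `TextIteratorStreamer` receives it. A minimal stdlib-only sketch of that producer/consumer pattern, where a hypothetical `fake_generate` stands in for the model and a `queue.Queue` with an end-of-stream sentinel plays the streamer's role; this illustrates the pattern and is not the transformers implementation.

```python
import queue
import threading
from typing import Iterator

_END = object()  # sentinel marking the end of generation


def fake_generate(prompt: str, out: "queue.Queue") -> None:
    """Hypothetical stand-in for model.generate: pushes tokens, then the sentinel."""
    for token in ["The", " answer", " is", " 8/3"]:
        out.put(token)
    out.put(_END)


def stream(prompt: str) -> Iterator[str]:
    q: "queue.Queue" = queue.Queue()
    worker = threading.Thread(target=fake_generate, args=(prompt, q))
    worker.start()
    while True:
        item = q.get()  # blocks until the producer emits something
        if item is _END:
            break
        yield item
    worker.join()  # the producer is done once the sentinel has arrived


if __name__ == "__main__":
    print("".join(stream("The integral of x^2 from 0 to 2 is")))  # The answer is 8/3
```

If the producer raised before emitting the sentinel, this consumer would block forever; `TextIteratorStreamer` accepts a `timeout` argument for that situation.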
--------------------------------------------------------------------------------
/replicate/predict_instruct.py:
--------------------------------------------------------------------------------
1 | # Prediction interface for Cog ⚙️
2 | # https://github.com/replicate/cog/blob/main/docs/python.md
3 |
4 | import os
5 | import time
6 | from threading import Thread
7 | import torch
8 | from transformers import AutoTokenizer, AutoModelForCausalLM, GenerationConfig
9 | from transformers.generation.streamers import TextIteratorStreamer
10 | from cog import BasePredictor, Input, ConcatenateIterator
11 |
12 | # Enable faster download speed
13 | os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"
14 | CACHE_DIR = "model_cache"
15 |
16 |
17 | class Predictor(BasePredictor):
18 | def setup(self) -> None:
19 | """Load the model into memory to make running multiple predictions efficient"""
20 |
21 | model_name = "deepseek-ai/deepseek-math-7b-instruct"
22 | self.tokenizer = AutoTokenizer.from_pretrained(model_name, cache_dir=CACHE_DIR)
23 | self.model = AutoModelForCausalLM.from_pretrained(
24 | model_name,
25 | torch_dtype=torch.bfloat16,
26 | device_map="auto",
27 | cache_dir=CACHE_DIR,
28 | )
29 | self.model.generation_config = GenerationConfig.from_pretrained(
30 | model_name, cache_dir=CACHE_DIR
31 | )
32 | self.model.generation_config.pad_token_id = (
33 | self.model.generation_config.eos_token_id
34 | )
35 |
36 | def predict(
37 | self,
38 | text: str = Input(
39 | description="Input text.",
40 | default="what is the integral of x^2 from 0 to 2?\nPlease reason step by step, and put your final answer within \\boxed{}.",
41 | ),
42 | max_new_tokens: int = Input(
43 | description="The maximum number of tokens to generate, ignoring the number of tokens in the prompt.",
44 | default=100,
45 | ),
46 | temperature: float = Input(
47 | description="The value used to modulate the next token probabilities.",
48 | default=1,
49 | ),
50 | top_k: int = Input(
51 | description="The number of highest probability vocabulary tokens to keep for top-k-filtering.",
52 | default=50,
53 | ),
54 | top_p: float = Input(
55 | description="If set to float < 1, only the smallest set of most probable tokens with probabilities that add up to top_p or higher are kept for generation.",
56 | default=0.9,
57 | ),
58 | ) -> ConcatenateIterator[str]:
59 | """Run a single prediction on the model"""
60 |
61 | messages = [{"role": "user", "content": text}]
62 | input_tensor = self.tokenizer.apply_chat_template(
63 | messages, add_generation_prompt=True, return_tensors="pt"
64 | )
65 | streamer = TextIteratorStreamer(
66 | self.tokenizer, skip_prompt=True, skip_special_tokens=True
67 | )
68 |
69 | with torch.inference_mode():
70 | thread = Thread(
71 | target=self.model.generate,
72 | kwargs=dict(
73 | input_ids=input_tensor.to(self.model.device),
74 | do_sample=True,
75 | temperature=temperature,
76 | top_p=top_p,
77 | top_k=top_k,
78 | max_new_tokens=max_new_tokens,
79 | streamer=streamer,
80 | use_cache=True,
81 | ),
82 | )
83 | thread.start()
84 | for new_token in streamer:
85 | yield new_token
86 | thread.join()
87 |
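The instruct prompt asks the model to put its final answer within `\boxed{}`. A minimal sketch of pulling the last boxed answer out of a completion with brace counting, so nested braces such as `\boxed{\frac{8}{3}}` are kept whole; `last_boxed_answer` is a hypothetical helper for illustration, not the repository's `answer_extraction_fn`.

```python
from typing import Optional


def last_boxed_answer(text: str) -> Optional[str]:
    """Return the contents of the last \\boxed{...} in text, or None."""
    start = text.rfind("\\boxed{")
    if start == -1:
        return None
    i = start + len("\\boxed{")
    depth = 1  # already inside the opening brace of \boxed{
    for j in range(i, len(text)):
        if text[j] == "{":
            depth += 1
        elif text[j] == "}":
            depth -= 1
            if depth == 0:
                return text[i:j]
    return None  # unbalanced braces: no complete answer


if __name__ == "__main__":
    print(last_boxed_answer(r"so the integral is \boxed{\frac{8}{3}}."))  # \frac{8}{3}
```

Counting braces instead of using a non-greedy regex such as `\\boxed\{(.*?)\}` avoids cutting the answer at the first inner closing brace.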
--------------------------------------------------------------------------------