├── data ├── hungarian_exam │ ├── README.md │ └── test.jsonl ├── ocw │ └── README.md └── sat_math │ └── test.jsonl ├── requirements.txt ├── LICENSE ├── scripts ├── summarize_results.py ├── run_eval.sh └── run_eval_multi_gpus.py ├── prompts ├── pal │ ├── gsm8k.md │ └── math.md ├── cot │ ├── minerva_math.md │ ├── gsm8k.md │ ├── mmlu_stem.md │ ├── math.md │ ├── mathqa.md │ └── math_8shot.md └── tora │ ├── gsm8k.md │ └── math.md ├── .gitignore ├── data_loader.py ├── evaluate.py ├── utils.py ├── trajectory.py ├── python_executor.py ├── README.md ├── grader.py ├── model_utils.py ├── math_eval.py └── parser.py /data/hungarian_exam/README.md: -------------------------------------------------------------------------------- 1 | https://huggingface.co/datasets/keirp/hungarian_national_hs_finals_exam -------------------------------------------------------------------------------- /data/ocw/README.md: -------------------------------------------------------------------------------- 1 | MIT OpenCourseWare: 2 | - Solving Quantitative Reasoning Problems with Language Models. https://openreview.net/forum?id=IFXTZERXdM7 3 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | 2 | # common 3 | vllm 4 | tqdm 5 | datasets 6 | torch 7 | transformers 8 | python_dateutil 9 | 10 | # math_eval 11 | sympy==1.12 12 | antlr4-python3-runtime==4.11.1 # ! The version needs to be compatible with sympy. 
13 | word2number 14 | Pebble 15 | timeout-decorator 16 | git+https://github.com/ZubinGou/latex2sympy.git -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 Zhibin Gou 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /scripts/summarize_results.py: -------------------------------------------------------------------------------- 1 | import json 2 | import glob 3 | import argparse 4 | 5 | 6 | def main(): 7 | parser = argparse.ArgumentParser() 8 | parser.add_argument("--result_dir", type=str) 9 | parser.add_argument("--data_names", type=str, default="gsm8k,minerva_math,svamp,asdiv,mawps") 10 | parser.add_argument("--split", type=str, default="test") 11 | args = parser.parse_args() 12 | summarize_results(args.result_dir, args.data_names, args.split) 13 | 14 | 15 | def summarize_results(result_dir, data_names, split): 16 | data_list = data_names.split(',') 17 | 18 | # read the result 19 | results = [] 20 | for data_name in data_list: 21 | files = glob.glob(f"{result_dir}/{data_name}/{split}*metrics.json") 22 | assert len(files) == 1, f"Found {len(files)} files for {data_name}" 23 | with open(files[0], 'r') as f: 24 | metrics = json.load(f) 25 | results.append(metrics) 26 | 27 | data_list.append("avg") 28 | results.append({ 29 | "acc": sum([result["acc"] for result in results]) / len(results), 30 | }) 31 | 32 | # print all results 33 | pad = max([len(data_name) for data_name in data_list]) 34 | print("\t".join(data_name.ljust(pad, " ") for data_name in data_list)) 35 | print("\t".join([f"{result['acc']:.1f}".ljust(pad, " ") for result in results])) 36 | print(" & ".join([f"{result['acc']:.1f}" for result in results])) 37 | 38 | 39 | if __name__ == "__main__": 40 | main() 41 | -------------------------------------------------------------------------------- /prompts/pal/gsm8k.md: -------------------------------------------------------------------------------- 1 | Let's use python to solve math problems. 2 | 3 | Question: Olivia has $23. She bought five bagels for $3 each. How much money does she have left? 4 | 5 | ```python 6 | def solution(): 7 | """Olivia has $23. She bought five bagels for $3 each. 
How much money does she have left?""" 8 | money_initial = 23 9 | bagels = 5 10 | bagel_cost = 3 11 | money_spent = bagels * bagel_cost 12 | money_left = money_initial - money_spent 13 | result = money_left 14 | return result 15 | ``` 16 | 17 | --- 18 | 19 | Question: Michael had 58 golf balls. On tuesday, he lost 23 golf balls. On wednesday, he lost 2 more. How many golf balls did he have at the end of wednesday? 20 | 21 | ```python 22 | def solution(): 23 | """Michael had 58 golf balls. On tuesday, he lost 23 golf balls. On wednesday, he lost 2 more. How many golf balls did he have at the end of wednesday?""" 24 | golf_balls_initial = 58 25 | golf_balls_lost_tuesday = 23 26 | golf_balls_lost_wednesday = 2 27 | golf_balls_left = golf_balls_initial - golf_balls_lost_tuesday - golf_balls_lost_wednesday 28 | result = golf_balls_left 29 | return result 30 | ``` 31 | 32 | --- 33 | 34 | Question: There were nine computers in the server room. Five more computers were installed each day, from monday to thursday. How many computers are now in the server room? 35 | 36 | ```python 37 | def solution(): 38 | """There were nine computers in the server room. Five more computers were installed each day, from monday to thursday. 
How many computers are now in the server room?""" 39 | computers_initial = 9 40 | computers_per_day = 5 41 | num_days = 4 # 4 days between monday and thursday 42 | computers_added = computers_per_day * num_days 43 | computers_total = computers_initial + computers_added 44 | result = computers_total 45 | return result 46 | ``` 47 | 48 | --- -------------------------------------------------------------------------------- /scripts/run_eval.sh: -------------------------------------------------------------------------------- 1 | set -ex 2 | 3 | PROMPT_TYPE=$1 4 | MODEL_NAME_OR_PATH=$2 5 | 6 | # ======= Base Models ======= 7 | # PROMPT_TYPE="cot" # direct / cot / pal / tool-integrated 8 | # MODEL_NAME_OR_PATH=${HF_MODEL_DIR}/mistral/Mistral-7B-v0.1 9 | # MODEL_NAME_OR_PATH=${HF_MODEL_DIR}/llemma/llemma_7b 10 | # MODEL_NAME_OR_PATH=${HF_MODEL_DIR}/internlm/internlm2-math-base-7b 11 | # MODEL_NAME_OR_PATH=${HF_MODEL_DIR}/deepseek/deepseek-math-7b-base 12 | 13 | 14 | # ======= SFT Models ======= 15 | # PROMPT_TYPE="deepseek-math" # self-instruct / tora / wizard_zs / deepseek-math / kpmath 16 | # MODEL_NAME_OR_PATH=${HF_MODEL_DIR}/deepseek/deepseek-math-7b-rl 17 | # MODEL_NAME_OR_PATH=${HF_MODEL_DIR}/deepseek/deepseek-math-7b-instruct 18 | 19 | 20 | OUTPUT_DIR=${MODEL_NAME_OR_PATH}/math_eval 21 | DATA_NAMES="gsm8k,minerva_math" 22 | # DATA_NAMES="gsm8k,minerva_math,svamp,asdiv,mawps,tabmwp,mathqa,mmlu_stem,sat_math" 23 | SPLIT="test" 24 | NUM_TEST_SAMPLE=-1 25 | 26 | 27 | # single-gpu 28 | CUDA_VISIBLE_DEVICES=0 TOKENIZERS_PARALLELISM=false \ 29 | python3 -u math_eval.py \ 30 | --model_name_or_path ${MODEL_NAME_OR_PATH} \ 31 | --output_dir ${OUTPUT_DIR} \ 32 | --data_names ${DATA_NAMES} \ 33 | --split ${SPLIT} \ 34 | --prompt_type ${PROMPT_TYPE} \ 35 | --num_test_sample ${NUM_TEST_SAMPLE} \ 36 | --seed 0 \ 37 | --temperature 0 \ 38 | --n_sampling 1 \ 39 | --top_p 1 \ 40 | --start 0 \ 41 | --end -1 \ 42 | --use_vllm \ 43 | --save_outputs \ 44 | # --overwrite \ 45 | 46 | 47 | 
# multi-gpu 48 | # python3 scripts/run_eval_multi_gpus.py \ 49 | # --model_name_or_path $MODEL_NAME_OR_PATH \ 50 | # --output_dir $OUTPUT_DIR \ 51 | # --data_names ${DATA_NAMES} \ 52 | # --prompt_type "cot" \ 53 | # --temperature 0 \ 54 | # --use_vllm \ 55 | # --save_outputs \ 56 | # --available_gpus 0,1,2,3,4,5,6,7 \ 57 | # --gpus_per_model 1 \ 58 | # --overwrite 59 | -------------------------------------------------------------------------------- /prompts/cot/minerva_math.md: -------------------------------------------------------------------------------- 1 | Problem: 2 | Find the domain of the expression $\\frac{\\sqrt{x-2}}{\\sqrt{5-x}}$.} 3 | Solution: 4 | The expressions inside each square root must be non-negative. 5 | Therefore, $x-2 \\ge 0$, so $x\\ge2$, and $5 - x \\ge 0$, so $x \\le 5$. 6 | Also, the denominator cannot be equal to zero, so $5-x>0$, which gives $x<5$. 7 | Therefore, the domain of the expression is $\\boxed{[2,5)}$. 8 | Final Answer: The final answer is $[2,5)$. I hope it is correct. 9 | 10 | 11 | Problem: 12 | If $\\det \\mathbf{A} = 2$ and $\\det \\mathbf{B} = 12,$ then find $\\det (\\mathbf{A} \\mathbf{B}).$ 13 | Solution: 14 | We have that $\\det (\\mathbf{A} \\mathbf{B}) = (\\det \\mathbf{A})(\\det \\mathbf{B}) = (2)(12) = \\boxed{24}.$ 15 | Final Answer: The final answer is $24$. I hope it is correct. 16 | 17 | 18 | Problem: 19 | Terrell usually lifts two 20-pound weights 12 times. If he uses two 15-pound weights instead, how many times must Terrell lift them in order to lift the same total weight? 20 | Solution: 21 | If Terrell lifts two 20-pound weights 12 times, he lifts a total of $2\\cdot 12\\cdot20=480$ pounds of weight. If he lifts two 15-pound weights instead for $n$ times, he will lift a total of $2\\cdot15\\cdot n=30n$ pounds of weight. 
Equating this to 480 pounds, we can solve for $n$: \\begin{align*} 22 | 30n&=480\\\\ 23 | \\Rightarrow\\qquad n&=480/30=\\boxed{16} 24 | \\end{align*} 25 | Final Answer: The final answer is $16$. I hope it is correct. 26 | 27 | 28 | Problem: 29 | If the system of equations 30 | 31 | \\begin{align*} 32 | 6x-4y&=a,\\\\ 33 | 6y-9x &=b. 34 | \\end{align*}has a solution $(x, y)$ where $x$ and $y$ are both nonzero, find $\\frac{a}{b},$ assuming $b$ is nonzero. 35 | Solution: 36 | If we multiply the first equation by $-\\frac{3}{2}$, we obtain 37 | 38 | $$6y-9x=-\\frac{3}{2}a.$$Since we also know that $6y-9x=b$, we have 39 | 40 | $$-\\frac{3}{2}a=b\\Rightarrow\\frac{a}{b}=\\boxed{-\\frac{2}{3}}.$$ 41 | Final Answer: The final answer is $-\\frac{2}{3}$. I hope it is correct. -------------------------------------------------------------------------------- /prompts/cot/gsm8k.md: -------------------------------------------------------------------------------- 1 | Question: There are 15 trees in the grove. Grove workers will plant trees in the grove today. After they are done, there will be 21 trees. How many trees did the grove workers plant today? 2 | Answer: There are 15 trees originally. Then there were 21 trees after some more were planted. So there must have been 21 - 15 = 6. The answer is 6. 3 | 4 | 5 | Question: If there are 3 cars in the parking lot and 2 more cars arrive, how many cars are in the parking lot? 6 | Answer: There are originally 3 cars. 2 more cars arrive. 3 + 2 = 5. The answer is 5. 7 | 8 | 9 | Question: Leah had 32 chocolates and her sister had 42. If they ate 35, how many pieces do they have left in total? 10 | Answer: Originally, Leah had 32 chocolates. Her sister had 42. So in total they had 32 + 42 = 74. After eating 35, they had 74 - 35 = 39. The answer is 39. 11 | 12 | 13 | Question: Jason had 20 lollipops. He gave Denny some lollipops. Now Jason has 12 lollipops. How many lollipops did Jason give to Denny? 
14 | Answer: Jason started with 20 lollipops. Then he had 12 after giving some to Denny. So he gave Denny 20 - 12 = 8. The answer is 8. 15 | 16 | 17 | Question: Shawn has five toys. For Christmas, he got two toys each from his mom and dad. How many toys does he have now? 18 | Answer: Shawn started with 5 toys. If he got 2 toys each from his mom and dad, then that is 4 more toys. 5 + 4 = 9. The answer is 9. 19 | 20 | 21 | Question: There were nine computers in the server room. Five more computers were installed each day, from monday to thursday. How many computers are now in the server room? 22 | Answer: There were originally 9 computers. For each of 4 days, 5 more computers were added. So 5 * 4 = 20 computers were added. 9 + 20 is 29. The answer is 29. 23 | 24 | 25 | Question: Michael had 58 golf balls. On tuesday, he lost 23 golf balls. On wednesday, he lost 2 more. How many golf balls did he have at the end of wednesday? 26 | Answer: Michael started with 58 golf balls. After losing 23 on tuesday, he had 58 - 23 = 35. After losing 2 more, he had 35 - 2 = 33 golf balls. The answer is 33. 27 | 28 | 29 | Question: Olivia has $23. She bought five bagels for $3 each. How much money does she have left? 30 | Answer: Olivia had 23 dollars. 5 bagels for 3 dollars each will be 5 x 3 = 15 dollars. So she has 23 - 15 dollars left. 23 - 15 is 8. The answer is 8. -------------------------------------------------------------------------------- /prompts/cot/mmlu_stem.md: -------------------------------------------------------------------------------- 1 | Problem: 2 | Find the domain of the expression $\\frac{\\sqrt{x-2}}{\\sqrt{5-x}}$. 3 | What of the following is the right choice? Explain your answer. 4 | (A) [-5,-2), (B) [2,5), (C) [-2,-5), (D) [5,2) 5 | Solution: 6 | The expressions inside each square root must be non-negative. Therefore, $x-2 \\ge 0$, so $x\\ge2$, and $5 - x \\ge 0$, so $x \\le 5$. Also, the denominator cannot be equal to zero, so $5-x>0$, which gives $x<5$. 
7 | Therefore, the domain of the expression is $\\boxed{[2,5)}$. 8 | Final Answer: The final answer is (B). 9 | 10 | 11 | Problem: 12 | If $\\det \\mathbf{A} = 2$ and $\\det \\mathbf{B} = 12,$ then find $\\det (\\mathbf{A} \\mathbf{B}).$ 13 | What of the following is the right choice? Explain your answer. 14 | (A) 14, (B) 4, (C) 2, (D) 24 15 | Solution: 16 | We have that $\\det (\\mathbf{A} \\mathbf{B}) = (\\det \\mathbf{A})(\\det \\mathbf{B}) = (2)(12) = \\boxed{24}.$ 17 | Final Answer: The final answer is (D). 18 | 19 | 20 | Problem: 21 | Terrell usually lifts two 20-pound weights 12 times. If he uses two 15-pound weights instead, how many times must Terrell lift them in order to lift the same total weight? 22 | What of the following is the right choice? Explain your answer. 23 | (A) 12, (B) 20, (C) 16, (D) 15 24 | Solution: 25 | If Terrell lifts two 20-pound weights 12 times, he lifts a total of $2\\cdot 12\\cdot20=480$ pounds of weight. If he lifts two 15-pound weights instead for $n$ times, he will lift a total of $2\\cdot15\\cdot n=30n$ pounds of weight. Equating this to 480 pounds, we can solve for $n$: \\begin{align*} 26 | 30n&=480\\\\ 27 | \\Rightarrow\\qquad n&=480/30=\\boxed{16} 28 | \\end{align*} 29 | Final Answer: The final answer is (C). 30 | 31 | 32 | Problem: 33 | If the system of equations 34 | 35 | \\begin{align*} 36 | 6x-4y&=a,\\\\ 37 | 6y-9x &=b. 38 | \\end{align*}has a solution $(x, y)$ where $x$ and $y$ are both nonzero, find $\\frac{a}{b},$ assuming $b$ is 39 | nonzero. 40 | What of the following is the right choice? Explain your answer. 41 | (A) $-\\frac{2}{3}$, (B) $\\frac{2}{3}$, (C) $\\frac{1}{3}$, (D) $\\frac{4}{9}$ 42 | Solution: 43 | If we multiply the first equation by $-\\frac{3}{2}$, we obtain 44 | $$6y-9x=-\\frac{3}{2}a.$$Since we also know that $6y-9x=b$, we have 45 | 46 | $$-\\frac{3}{2}a=b\\Rightarrow\\frac{a}{b}=\\boxed{-\\frac{2}{3}}.$$ 47 | Final Answer: The final answer is (A). 
-------------------------------------------------------------------------------- /prompts/tora/gsm8k.md: -------------------------------------------------------------------------------- 1 | Integrate step-by-step reasoning and Python code to solve math problems using the following guidelines: 2 | 3 | - Analyze the question and write functions to solve the problem; the function should not take any arguments. 4 | - Present the final result in LaTeX using a `\boxed{}` without any units. 5 | - Utilize the `pi` symbol and `Rational`` from Sympy for $\pi$ and fractions, and simplify all fractions and square roots without converting them to decimal values. 6 | 7 | Here are some examples you may refer to: 8 | 9 | --- 10 | 11 | Question: Olivia has $23. She bought five bagels for $3 each. How much money does she have left? 12 | 13 | Solution: 14 | ```python 15 | def money_left(): 16 | money_initial = 23 17 | bagels = 5 18 | bagel_cost = 3 19 | money_spent = bagels * bagel_cost 20 | remaining_money = money_initial - money_spent 21 | return remaining_money 22 | 23 | remaining_money = money_left() 24 | print(remaining_money) 25 | ``` 26 | ```output 27 | 8 28 | ``` 29 | Olivia has $\boxed{8}$ dollars left. 30 | 31 | --- 32 | 33 | Question: Michael had 58 golf balls. On tuesday, he lost 23 golf balls. On wednesday, he lost 2 more. How many golf balls did he have at the end of wednesday? 34 | 35 | Solution: 36 | ```python 37 | def remaining_golf_balls(): 38 | golf_balls_initial = 58 39 | golf_balls_lost_tuesday = 23 40 | golf_balls_lost_wednesday = 2 41 | golf_balls_left = golf_balls_initial - golf_balls_lost_tuesday - golf_balls_lost_wednesday 42 | remaining_golf_balls = golf_balls_left 43 | return remaining_golf_balls 44 | 45 | answer = remaining_golf_balls() 46 | print(answer) 47 | ``` 48 | ```output 49 | 33 50 | ``` 51 | Michael had $\boxed{33}$ golf balls at the end of Wednesday. 52 | 53 | --- 54 | 55 | Question: There were nine computers in the server room. 
Five more computers were installed each day, from monday to thursday. How many computers are now in the server room? 56 | Solution: 57 | ```python 58 | def total_computers(): 59 | computers_initial = 9 60 | computers_per_day = 5 61 | num_days = 4 # 4 days between monday and thursday 62 | computers_added = computers_per_day * num_days 63 | computers_total = computers_initial + computers_added 64 | return computers_total 65 | 66 | total_computers = total_computers() 67 | print(total_computers) 68 | ``` 69 | ```output 70 | 29 71 | ``` 72 | There're $\boxed{29}$ computers in the server room. 73 | 74 | --- -------------------------------------------------------------------------------- /prompts/pal/math.md: -------------------------------------------------------------------------------- 1 | Let's use python to solve math problems. Display the final result in LaTeX. 2 | 3 | Question: Find the coefficient of $x^3$ when $3(x^2 - x^3+x) +3(x +2x^3- 3x^2 + 3x^5+x^3) -5(1+x-4x^3 - x^2)$ is simplifie. 4 | 5 | ```python 6 | from sympy import symbols, simplify 7 | 8 | def solution(): 9 | x = symbols('x') 10 | expr = 3*(x**2 - x**3 + x) + 3*(x + 2*x**3 - 3*x**2 + 3*x**5 + x**3) - 5*(1 + x - 4*x**3 - x**2) 11 | simplified_expr = simplify(expr) 12 | 13 | x3_coefficient = simplified_expr.as_coefficients_dict()[x**3] 14 | result = x3_coefficient 15 | return result 16 | ``` 17 | 18 | --- 19 | 20 | Question: The surface area of a sphere with radius $r$ is $4\pi r^2$. Including the area of its circular base, what is the total surface area of a hemisphere with radius 6 cm? Express your answer in terms of $\pi$. 
21 | 22 | ```python 23 | import math 24 | 25 | def solution(): 26 | radius = 6 27 | 28 | # Surface area of the hemisphere 29 | hemisphere_area = 2 * math.pi * radius**2 30 | 31 | # Area of the circular base 32 | base_area = math.pi * radius**2 33 | 34 | # Total surface area 35 | total_surface_area = hemisphere_area + base_area 36 | 37 | # Formatting the result in LaTeX 38 | result = r'{}\\pi'.format(total_surface_area / math.pi) 39 | return result 40 | ``` 41 | 42 | --- 43 | 44 | Question: Monica tosses a fair 6-sided die. If the roll is a prime number, then she wins that amount of dollars (so that, for example, if she rolls 3, then she wins 3 dollars). If the roll is composite, she wins nothing. Otherwise, she loses 3 dollars. What is the expected value of her winnings on one die toss? Express your answer as a dollar value to the nearest cent. 45 | 46 | ```python 47 | def solution(): 48 | # Probabilities of each outcome 49 | prime_prob = 1 / 6 50 | composite_prob = 1 / 3 51 | otherwise_prob = 1 / 6 52 | 53 | # Expected value of each outcome 54 | prime_expected_value = (2 * prime_prob) + (3 * prime_prob) + (5 * prime_prob) 55 | composite_expected_value = 0 * composite_prob 56 | otherwise_expected_value = -3 * otherwise_prob 57 | 58 | # Total expected value 59 | total_expected_value = prime_expected_value + composite_expected_value + otherwise_expected_value 60 | 61 | # Dollar value to the nearest cent 62 | result = "{:.2f}".format(total_expected_value) 63 | return result 64 | ``` 65 | 66 | --- 67 | 68 | Question: Given $\mathbf{a} = \begin{pmatrix} -7 \\ 0 \\ 1 \end{pmatrix}$ and $\mathbf{b} = \begin{pmatrix} 4 \\ 2 \\ -1 \end{pmatrix},$ find $\mathbf{a} - 3 \mathbf{b}.$ 69 | 70 | Solution: 71 | ```python 72 | import numpy as np 73 | 74 | def solution() 75 | a = np.array([-7, 0, 1]) 76 | b = np.array([4, 2, -1]) 77 | 78 | result = a - 3 * b 79 | 80 | result = r'\begin{{pmatrix}} {} \\ {} \\ {} \end{{pmatrix}}'.format(result[0], result[1], result[2]) 81 | return 
result 82 | ``` 83 | 84 | --- -------------------------------------------------------------------------------- /prompts/cot/math.md: -------------------------------------------------------------------------------- 1 | Problem: 2 | Kevin Kangaroo begins hopping on a number line at 0. He wants to get to 1, but he can hop only $\frac{1}{3}$ of the distance. Each hop tires him out so that he continues to hop $\frac{1}{3}$ of the remaining distance. How far has he hopped after five hops? Express your answer as a common fraction. 3 | Solution: 4 | Let's think step by step 5 | Kevin hops $1/3$ of the remaining distance with every hop. 6 | His first hop takes $1/3$ closer. 7 | For his second hop, he has $2/3$ left to travel, so he hops forward $(2/3)(1/3)$. 8 | For his third hop, he has $(2/3)^2$ left to travel, so he hops forward $(2/3)^2(1/3)$. 9 | In general, Kevin hops forward $(2/3)^{k-1}(1/3)$ on his $k$th hop. 10 | We want to find how far he has hopped after five hops. 11 | This is a finite geometric series with first term $1/3$, common ratio $2/3$, and five terms. 12 | Thus, Kevin has hopped $\frac{\frac{1}{3}\left(1-\left(\frac{2}{3}\right)^5\right)}{1-\frac{2}{3}} = \frac{211}{243}$. 13 | So the final answer is $\boxed{\frac{211}{243}}$. 14 | 15 | 16 | Problem: 17 | What is the area of the region defined by the equation $x^2+y^2 - 7 = 4y-14x+3$? 18 | Solution: 19 | Let's think step by step 20 | We rewrite the equation as $x^2 + 14x + y^2 - 4y = 10$ and then complete the square, 21 | resulting in $(x+7)^2-49 + (y-2)^2-4=10$, 22 | or $(x+7)^2+(y-2)^2=63$. 23 | This is the equation of a circle with center $(-7, 2)$ and radius $\sqrt{63},$ 24 | so the area of this region is $\pi r^2 = 63\pi$. 25 | So the final answer is $\boxed{63\pi}$. 26 | 27 | 28 | Problem: 29 | If $x^2+y^2=1$, what is the largest possible value of $|x|+|y|$? 
30 | Solution: 31 | Let's think step by step 32 | If $(x,y)$ lies on the circle, 33 | so does $(x,-y),$ $(-x,y),$ and $(-x,-y),$ (which all give the same value of $|x| + |y|$), 34 | so we can assume that $x \ge 0$ and $y \ge 0.$ 35 | Then $|x| + |y| = x + y.$ Squaring, we get 36 | \[(x + y)^2 = x^2 + 2xy + y^2 = 1 + 2xy.\] 37 | Note that $(x - y)^2 \ge 0.$ 38 | Expanding, we get $x^2 - 2xy + y^2 \ge 0,$ so $2xy \le x^2 + y^2 = 1.$ 39 | Hence,\[1 + 2xy \le 2,\]which means $x + y \le \sqrt{2}.$ 40 | Equality occurs when $x = y = \frac{1}{\sqrt{2}},$ 41 | so the maximum value of $|x| + |y|$ is $\sqrt{2}.$ 42 | So the final answer is $\boxed{\sqrt{2}}$. 43 | 44 | 45 | Problem: 46 | If $f(x)=\frac{ax+b}{cx+d}, abcd\not=0$ and $f(f(x))=x$ for all $x$ in the domain of $f$, what is the value of $a+d$? 47 | Solution: 48 | Let's think step by step 49 | The condition $f(f(x))=x$ means that $f$ is the inverse of itself, 50 | so its graph is symmetrical about the line $y = x$. 51 | With a rational function of this form, we will have two asymptotes: 52 | a vertical one at $x=-d/c$ if $cx+d$ does not divide $ax+b$, 53 | and a horizontal one at $y=a/c$, 54 | if we take the limit of $f(x)$ as $x$ goes to $\pm\infty$. 55 | In order for $f$ to be its own inverse, the intersection of the asymptotes must lie on the line $y=x$ 56 | so that it and its asymptotes reflect onto themselves. 57 | This means that $-d/c=a/c$, 58 | and therefore $-d=a$ and $a+d=0$. 59 | So the final answer is $\boxed{0}$. 60 | 61 | 62 | Problem: 63 | Expand $(2z^2 + 5z - 6)(3z^3 - 2z + 1)$.
64 | Solution: 65 | Let's think step by step 66 | $$\begin{array}{crrrrrrr} 67 | & & & 3z^3 & & -2z & + 1 & \\ 68 | \times & & & & 2z^2 & +5z & -6 \\ 69 | \cline{1-7}\rule{0pt}{0.17in} 70 | & & & -18z^3 & & +12z & -6 & \\ 71 | & & +15z^4 & & -10z^2 & +5z & & \\ 72 | + & 6z^5 & & -4z^3 & +2z^2 & & & \\ 73 | \cline{1-7}\rule{0pt}{0.17in} 74 | & 6z^5 & +15z^4 & -22z^3 & - 8z^2 &+17z & -6 & 75 | \end{array}$$ 76 | So the final answer is $\boxed{6z^5+15z^4-22z^3-8z^2+17z-6}$. -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | private/ 2 | 3 | # Byte-compiled / optimized / DLL files 4 | __pycache__/ 5 | *.py[cod] 6 | *$py.class 7 | 8 | # C extensions 9 | *.so 10 | 11 | # Distribution / packaging 12 | .Python 13 | build/ 14 | develop-eggs/ 15 | dist/ 16 | downloads/ 17 | eggs/ 18 | .eggs/ 19 | lib/ 20 | lib64/ 21 | parts/ 22 | sdist/ 23 | var/ 24 | wheels/ 25 | share/python-wheels/ 26 | *.egg-info/ 27 | .installed.cfg 28 | *.egg 29 | MANIFEST 30 | 31 | # PyInstaller 32 | # Usually these files are written by a python script from a template 33 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
34 | *.manifest 35 | *.spec 36 | 37 | # Installer logs 38 | pip-log.txt 39 | pip-delete-this-directory.txt 40 | 41 | # Unit test / coverage reports 42 | htmlcov/ 43 | .tox/ 44 | .nox/ 45 | .coverage 46 | .coverage.* 47 | .cache 48 | nosetests.xml 49 | coverage.xml 50 | *.cover 51 | *.py,cover 52 | .hypothesis/ 53 | .pytest_cache/ 54 | cover/ 55 | 56 | # Translations 57 | *.mo 58 | *.pot 59 | 60 | # Django stuff: 61 | *.log 62 | local_settings.py 63 | db.sqlite3 64 | db.sqlite3-journal 65 | 66 | # Flask stuff: 67 | instance/ 68 | .webassets-cache 69 | 70 | # Scrapy stuff: 71 | .scrapy 72 | 73 | # Sphinx documentation 74 | docs/_build/ 75 | 76 | # PyBuilder 77 | .pybuilder/ 78 | target/ 79 | 80 | # Jupyter Notebook 81 | .ipynb_checkpoints 82 | 83 | # IPython 84 | profile_default/ 85 | ipython_config.py 86 | 87 | # pyenv 88 | # For a library or package, you might want to ignore these files since the code is 89 | # intended to run in multiple environments; otherwise, check them in: 90 | # .python-version 91 | 92 | # pipenv 93 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 94 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 95 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 96 | # install all needed dependencies. 97 | #Pipfile.lock 98 | 99 | # poetry 100 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 101 | # This is especially recommended for binary packages to ensure reproducibility, and is more 102 | # commonly ignored for libraries. 103 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 104 | #poetry.lock 105 | 106 | # pdm 107 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 
108 | #pdm.lock 109 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 110 | # in version control. 111 | # https://pdm.fming.dev/#use-with-ide 112 | .pdm.toml 113 | 114 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 115 | __pypackages__/ 116 | 117 | # Celery stuff 118 | celerybeat-schedule 119 | celerybeat.pid 120 | 121 | # SageMath parsed files 122 | *.sage.py 123 | 124 | # Environments 125 | .env 126 | .venv 127 | env/ 128 | venv/ 129 | ENV/ 130 | env.bak/ 131 | venv.bak/ 132 | 133 | # Spyder project settings 134 | .spyderproject 135 | .spyproject 136 | 137 | # Rope project settings 138 | .ropeproject 139 | 140 | # mkdocs documentation 141 | /site 142 | 143 | # mypy 144 | .mypy_cache/ 145 | .dmypy.json 146 | dmypy.json 147 | 148 | # Pyre type checker 149 | .pyre/ 150 | 151 | # pytype static type analyzer 152 | .pytype/ 153 | 154 | # Cython debug symbols 155 | cython_debug/ 156 | 157 | # PyCharm 158 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 159 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 160 | # and can be added to the global gitignore or merged into this file. For a more nuclear 161 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 162 | #.idea/ 163 | -------------------------------------------------------------------------------- /prompts/cot/mathqa.md: -------------------------------------------------------------------------------- 1 | Problem: 2 | An old man distributed all the gold coins he had to his two sons into two different numbers such that the difference between the squares of the two numbers is 16 times the difference between the two numbers . how many coins did the old man have ? 3 | What of the following is the right choice? Explain your answer. 
4 | (A) 24, (B) 26, (C) 30, (D) 16, (E) 40 5 | Solution: 6 | Let's denote the number of coins received by the first son as $x$ and by the second son as $y$. 7 | According to the problem, the difference between the squares of the two numbers is 16 times the difference between the two numbers. This can be written as: 8 | 9 | $$x^2 - y^2 = 16(x - y)$$ 10 | 11 | Factor the left side as a difference of squares: 12 | 13 | $$(x + y)(x - y) = 16(x - y)$$ 14 | 15 | Since $x \neq y$, we can divide both sides by $(x - y)$: 16 | 17 | $$x + y = 16$$ 18 | 19 | The total number of coins is $x + y$. Therefore, the old man had 16 coins. 20 | Final Answer: The final answer is (D). 21 | 22 | 23 | Problem: 24 | The speed of a car increases by 2 kms after every one hour . if the distance travelling in the first one hour was 30 kms . what was the total distance traveled in 12 hours ? 25 | What of the following is the right choice? Explain your answer. 26 | (A) 252 kms, (B) 152 kms, (C) 492 kms, (D) 752 kms, (E) 152 kms 27 | Solution: 28 | The car's speed increases by 2 km/h each hour, starting at 30 km/h. The distance for each hour forms an arithmetic sequence with the first term \( a_1 \) being 30 km and the common difference \( d \) being 2 km. 29 | 30 | The total distance \( S_n \) after \( n \) terms is given by the formula: 31 | \[ S_n = \frac{n}{2} \times (2a_1 + (n - 1)d) \] 32 | 33 | For 12 hours (\( n = 12 \)), the total distance \( S_{12} \) is: 34 | \[ S_{12} = \frac{12}{2} \times (2 \times 30 + (12 - 1) \times 2) \] 35 | \[ S_{12} = 6 \times (60 + 22) \] 36 | \[ S_{12} = 6 \times 82 \] 37 | \[ S_{12} = 492 \text{ kms} \] 38 | 39 | Therefore, the total distance traveled in 12 hours is 492 kms. 40 | Final Answer: The final answer is (C). 41 | 42 | 43 | Problem: 44 | A reduction of 25 % in the price of oil enables a house wife to obtain 5 kgs more for rs . 600 , what is the reduced price for kg ? 45 | What of the following is the right choice? Explain your answer (a little concise).
46 | (A) 30, (B) 60, (C) 70, (D) 80, (E) 20 47 | Solution: 48 | Let's denote the original price per kg of oil as \( P \) rupees. After a 25% reduction, the new price per kg becomes \( 0.75P \). 49 | With the original price, the housewife could buy \( \frac{600}{P} \) kgs of oil. With the reduced price, she can now buy \( \frac{600}{0.75P} \) kgs of oil. 50 | According to the problem, at the reduced price she can buy 5 kgs more oil than she could at the original price. So we can set up the following equation: 51 | \[ \frac{600}{0.75P} - \frac{600}{P} = 5 \] 52 | To solve for \( P \), we rewrite \( \frac{600}{0.75P} \) as \( \frac{800}{P} \) and subtract the fractions: 53 | \[ \frac{800}{P} - \frac{600}{P} = 5 \] 54 | \[ \frac{200}{P} = 5 \] 55 | \[ P = \frac{200}{5} \] 56 | \[ P = 40 \] 57 | Now that we have the original price \( P \), we can find the reduced price by calculating 75% of \( P \): 58 | \[ 0.75P = 0.75 \times 40 = 30 \] 59 | So the reduced price per kg of oil is 30 rupees. 60 | Final Answer: The final answer is (A). 61 | 62 | 63 | Problem: 64 | A lady builds 9 cm length , 12 cm width , and 3 cm height box using 3 cubic cm cubes . what is the minimum number of cubes required to build the box ? 65 | What of the following is the right choice? Explain your answer. 66 | (A) 107, (B) 108, (C) 109, (D) 110, (E) 111 67 | Solution: 68 | To find the minimum number of 3 cubic cm cubes required to build the box, we need to calculate the volume of the box and then divide it by the volume of one cube. 69 | The volume of the box is given by the formula for the volume of a rectangular prism, which is length × width × height. So, the volume of the box is: 70 | Volume of the box = 9 cm × 12 cm × 3 cm = 324 cubic cm. 71 | Each cube has a volume of 3 cubic cm, as stated in the problem.
72 | To find the number of cubes needed, we divide the total volume of the box by the volume of one cube: 73 | Number of cubes = Volume of the box / Volume of one cube = 324 cubic cm / 3 cubic cm = 108. 74 | Therefore, the minimum number of cubes required to build the box is 108. 75 | Final Answer: The final answer is (B). -------------------------------------------------------------------------------- /data_loader.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import random 4 | from datasets import load_dataset, Dataset, concatenate_datasets 5 | from utils import load_jsonl, lower_keys 6 | 7 | 8 | def load_data(data_name, split, data_dir='./data'):  # load <data_dir>/<data_name>/<split>.jsonl, building and caching it from the source dataset if the file is missing 9 |     if data_name in ['minerva_math']: 10 |         data_name = 'math_oai' 11 |     data_file = f"{data_dir}/{data_name}/{split}.jsonl" 12 |     if os.path.exists(data_file): 13 |         examples = list(load_jsonl(data_file)) 14 |     else: 15 |         if data_name == "math": 16 |             dataset = load_dataset("competition_math", split=split, name="main", cache_dir=f"{data_dir}/temp") 17 |         elif data_name == "theorem_qa": 18 |             dataset = load_dataset("wenhu/TheoremQA", split=split) 19 |         elif data_name == "gsm8k": 20 |             dataset = load_dataset(data_name, split=split) 21 |         elif data_name == "gsm_hard": 22 |             dataset = load_dataset("reasoning-machines/gsm_hard", split="train") 23 |         elif data_name == "svamp": 24 |             # evaluate on training set + test set 25 |             dataset = load_dataset("ChilleD/SVAMP", split="train") 26 |             dataset = concatenate_datasets([dataset, load_dataset("ChilleD/SVAMP", split="test")]) 27 |         elif data_name == "asdiv": 28 |             dataset = load_dataset("EleutherAI/asdiv", split="validation") 29 |             dataset = dataset.filter(lambda x: ";" not in x['answer']) # remove multi-answer examples 30 |         elif data_name == "mawps": 31 |             examples = [] 32 |             # four sub-tasks; a separate loop variable keeps `data_name` intact for the cache directory below 33 |             for sub_name in ["singleeq", "singleop", "addsub", "multiarith"]: 34 |                 sub_examples =
list(load_jsonl(f"{data_dir}/mawps/{sub_name}.jsonl")) 35 |                 for example in sub_examples: 36 |                     example['type'] = sub_name 37 |                 examples.extend(sub_examples) 38 |             dataset = Dataset.from_list(examples) 39 |         elif data_name == "finqa": 40 |             dataset = load_dataset("dreamerdeo/finqa", split=split, name="main") 41 |             dataset = dataset.select(random.sample(range(len(dataset)), 1000)) 42 |         elif data_name == "tabmwp": 43 |             examples = [] 44 |             with open(f"{data_dir}/tabmwp/tabmwp_{split}.json", "r") as f: 45 |                 data_dict = json.load(f) 46 |                 examples.extend(data_dict.values()) 47 |             dataset = Dataset.from_list(examples) 48 |             dataset = dataset.select(random.sample(range(len(dataset)), 1000)) 49 |         elif data_name == "mathqa": 50 |             dataset = load_dataset("math_qa", split=split) 51 |             dataset = dataset.rename_column("category", "type") 52 |             dataset = dataset.select(random.sample(range(len(dataset)), 1000)) 53 |         elif data_name == "mmlu_stem": 54 |             dataset = load_dataset("hails/mmlu_no_train", 'all', split='test') 55 |             # only keep stem subjects 56 |             stem_subjects = ['abstract_algebra', 'astronomy', 'college_biology', 'college_chemistry', 57 |                             'college_computer_science', 'college_mathematics', 'college_physics', 'computer_security', 58 |                             'conceptual_physics', 'electrical_engineering', 'elementary_mathematics', 'high_school_biology', 59 |                             'high_school_chemistry', 'high_school_computer_science', 'high_school_mathematics', 60 |                             'high_school_physics', 'high_school_statistics', 'machine_learning'] 61 |             dataset = dataset.rename_column("subject", "type") 62 |             dataset = dataset.filter(lambda x: x['type'] in stem_subjects) 63 |         elif data_name == "bbh": 64 |             examples = [] 65 |             for sub_name in ["reasoning_about_colored_objects", "penguins_in_a_table",\ 66 |                             "date_understanding", "repeat_copy_logic", "object_counting"]: 67 |                 with open(f"{data_dir}/bbh/bbh/{sub_name}.json", "r") as f: 68 |                     sub_examples = json.load(f)["examples"] 69 |                 for example in sub_examples: 70 |                     example['type'] = sub_name 71 |
examples.extend(sub_examples) 72 |             dataset = Dataset.from_list(examples) 73 |         elif data_name == "hungarian_exam": 74 |             dataset = load_dataset("json", data_files=f"{data_dir}/hungarian_exam/{split}.jsonl", split="train")  # without split=, the json builder returns a DatasetDict keyed by split name 75 |         else: 76 |             raise NotImplementedError(data_name) 77 | 78 |         examples = list(dataset) 79 |         examples = [lower_keys(example) for example in examples] 80 |         dataset = Dataset.from_list(examples) 81 |         os.makedirs(f"{data_dir}/{data_name}", exist_ok=True) 82 |         dataset.to_json(data_file) 83 | 84 |     # add 'idx' in the first column 85 |     if examples and 'idx' not in examples[0]: 86 |         examples = [{'idx': i, **example} for i, example in enumerate(examples)] 87 | 88 |     # sort by idx (no deduplication is performed here) 89 |     examples = sorted(examples, key=lambda x: x['idx']) 90 |     return examples 91 | 92 | 93 | if __name__ == "__main__": 94 |     examples = load_data("mmlu_stem", "test") 95 | -------------------------------------------------------------------------------- /scripts/run_eval_multi_gpus.py: -------------------------------------------------------------------------------- 1 | """ 2 | This script is deprecated.
3 | """ 4 | import os 5 | import time 6 | import torch 7 | import subprocess 8 | import argparse 9 | from multiprocessing import Pool 10 | 11 | from summarize_results import summarize_results 12 | 13 | def parse_args(): 14 | parser = argparse.ArgumentParser() 15 | parser.add_argument("--model_name_or_path", default="gpt-4", type=str) 16 | parser.add_argument("--data_names", default="gsm8k,minerva_math,math,gsm_hard,svamp,tabmwp,asdiv,mawps", type=str) 17 | parser.add_argument("--output_dir", default="/mnt/project/tora/outputs", type=str) 18 | parser.add_argument("--prompt_type", default="tool-integrated", type=str, choices=["direct", "cot", "pal", "tool-integrated", "self-instruct", "self-instruct-boxed", "tora", "pal", "cot", "wizard_zs", "platypus_fs", "deepseek-math", "kpmath"]) 19 | parser.add_argument("--split", default="test", type=str) 20 | parser.add_argument("--num_test_sample", default=-1, type=int) # -1 for full data_name 21 | parser.add_argument("--seed", default=0, type=int) 22 | parser.add_argument("--start", default=0, type=int) 23 | parser.add_argument("--end", default=-1, type=int) 24 | parser.add_argument("--temperature", default=0, type=float) 25 | parser.add_argument("--n_sampling", default=1, type=int) 26 | parser.add_argument("--top_p", default=1, type=float) 27 | parser.add_argument("--gpus_per_model", default=1, type=int) # 2 for 70b 28 | parser.add_argument("--available_gpus", default=None, type=str) 29 | parser.add_argument("--split_data_over_gpus", action="store_true") 30 | parser.add_argument("--use_vllm", action="store_true") 31 | parser.add_argument("--save_outputs", action="store_true") 32 | parser.add_argument("--use_safetensors", action="store_true") 33 | parser.add_argument("--overwrite", action="store_true") 34 | args = parser.parse_args() 35 | return args 36 | 37 | args = parse_args() 38 | 39 | if not os.path.exists(args.model_name_or_path): 40 | raise FileNotFoundError(args.model_name_or_path) 41 | 42 | data_list = 
args.data_names.split(',') 43 | 44 | if args.available_gpus: 45 | available_gpus = args.available_gpus.split(',') 46 | else: 47 | available_gpus = [str(i) for i in range(torch.cuda.device_count())] 48 | 49 | start_end_list = [(args.start, args.end) for _ in range(len(data_list))] 50 | 51 | if args.split_data_over_gpus: 52 | assert len(data_list) == 1 53 | assert args.num_test_sample != -1 54 | num_gpus = len(available_gpus) 55 | data_list = [data_list[0] for _ in range(num_gpus)] 56 | num_test_sample_per_gpu = args.num_test_sample // num_gpus 57 | start_end_list = [(i * num_test_sample_per_gpu, (i+1) * num_test_sample_per_gpu if i != (num_gpus - 1 ) else args.num_test_sample) for i in range(num_gpus)] 58 | 59 | 60 | gpu_idx = 0 61 | scripts = [] 62 | for i, data_name in enumerate(data_list): 63 | if gpu_idx + args.gpus_per_model > len(available_gpus): 64 | print("No enough GPUs!") 65 | break 66 | 67 | start, end = start_end_list[i] 68 | cmd = f"sleep {gpu_idx * 3} && " \ 69 | f"CUDA_VISIBLE_DEVICES={','.join(available_gpus[gpu_idx:gpu_idx+args.gpus_per_model])} TOKENIZERS_PARALLELISM=false "\ 70 | "python3 -u math_eval.py " \ 71 | f"--model_name_or_path {args.model_name_or_path} " \ 72 | f"--data_name {data_name} " \ 73 | f"--output_dir {args.output_dir} " \ 74 | f"--split {args.split} " \ 75 | f"--prompt_type {args.prompt_type} " \ 76 | f"--num_test_sample {args.num_test_sample} " \ 77 | f"--seed {args.seed} " \ 78 | f"--temperature {args.temperature} " \ 79 | f"--n_sampling {args.n_sampling} " \ 80 | f"--top_p {args.top_p} " \ 81 | f"--start {start} " \ 82 | f"--end {end} " \ 83 | 84 | if args.use_vllm: 85 | cmd += "--use_vllm " 86 | if args.save_outputs: 87 | cmd += "--save_outputs " 88 | if args.use_safetensors: 89 | cmd += "--use_safetensors " 90 | if args.overwrite: 91 | cmd += "--overwrite" 92 | 93 | # cmd += " & " 94 | # print(cmd) 95 | # os.system(cmd) 96 | 97 | scripts.append(cmd) 98 | gpu_idx += args.gpus_per_model 99 | 100 | 101 | def run_process(cmd): 
102 | os.system(cmd) 103 | 104 | 105 | if __name__ == "__main__": 106 | 107 | for i, script in enumerate(scripts): 108 | print(f"Script {i}: {script}") 109 | 110 | pool = Pool() 111 | pool.map(run_process, scripts) 112 | 113 | pool.close() 114 | 115 | summarize_results(args.output_dir, args.data_names, args.split) 116 | 117 | 118 | # Usage: 119 | # model_name_or_path=./mistral/Mistral-7B-v0.1 120 | # python3 scripts/run_eval_multi_gpus.py \ 121 | # --model_name_or_path $model_name_or_path \ 122 | # --prompt_type "cot" \ 123 | # --save_outputs \ 124 | # --available_gpus 0,1,2,3 \ 125 | # --data_names gsm8k,minerva_math,svamp,asdiv \ 126 | # --use_vllm \ 127 | # --gpus_per_model 1 \ 128 | # --overwrite 129 | -------------------------------------------------------------------------------- /evaluate.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import numpy as np 3 | from tqdm import tqdm 4 | from pebble import ProcessPool 5 | from concurrent.futures import TimeoutError 6 | 7 | from grader import * 8 | 9 | from parser import * 10 | from utils import load_jsonl 11 | from python_executor import PythonExecutor 12 | 13 | 14 | def evaluate(data_name, prompt_type, samples: list=None, file_path: str=None, max_num_samples=None, execute=False): 15 | assert samples or file_path, "samples or file_path must be provided" 16 | if not samples: 17 | samples = list(load_jsonl(file_path)) 18 | # dedup by idx 19 | if 'idx' in samples[0]: 20 | samples = {sample['idx']: sample for sample in samples}.values() 21 | samples = sorted(samples, key=lambda x: x['idx']) 22 | else: 23 | samples = [dict(idx=idx, **sample) for idx, sample in enumerate(samples)] 24 | 25 | if max_num_samples: 26 | print(f"max_num_samples: {max_num_samples} / {len(samples)}") 27 | samples = samples[:max_num_samples] 28 | 29 | # parse gt 30 | for sample in samples: 31 | sample['gt_cot'], sample['gt'] = parse_ground_truth(sample, data_name) 32 | 33 | # execute 
34 | if ('pred' not in samples[0]) or execute: 35 | if "pal" in prompt_type: 36 | executor = PythonExecutor(get_answer_expr="solution()") 37 | else: 38 | executor = PythonExecutor(get_answer_from_stdout=True) 39 | 40 | for sample in tqdm(samples, desc="Execute"): 41 | sample['pred'] = [] 42 | sample['report'] = [] 43 | for code in sample['code']: 44 | pred, report = run_execute(executor, code, prompt_type, data_name, execute=True) 45 | sample['pred'].append(pred) 46 | sample['report'].append(report) 47 | 48 | params = [(idx, pred, sample['gt']) for idx, sample in enumerate(samples) for pred in sample['pred']] 49 | 50 | scores = [] 51 | timeout_cnt = 0 52 | 53 | with ProcessPool() as pool: 54 | future = pool.map(math_equal_process, params, timeout=3) 55 | iterator = future.result() 56 | with tqdm(total=len(samples), desc="Evaluate") as progress_bar: 57 | while True: 58 | try: 59 | result = next(iterator) 60 | scores.append(result) 61 | except StopIteration: 62 | break 63 | except TimeoutError as error: 64 | print(error) 65 | scores.append(False) 66 | timeout_cnt += 1 67 | except Exception as error: 68 | print(error.traceback) 69 | exit() 70 | progress_bar.update(1) 71 | 72 | idx = 0 73 | score_mat = [] 74 | for sample in samples: 75 | sample['score'] = scores[idx: idx+len(sample['pred'])] 76 | assert len(sample['score']) == len(sample['pred']) 77 | score_mat.append(sample['score']) 78 | idx += len(sample['pred']) 79 | 80 | max_len = max([len(s) for s in score_mat]) 81 | 82 | for i, s in enumerate(score_mat): 83 | if len(s) < max_len: 84 | score_mat[i] = s + [s[-1]] * (max_len - len(s)) # pad 85 | 86 | # output mean of each column of scores 87 | col_means= np.array(score_mat).mean(axis=0) 88 | mean_score = list(np.round(col_means * 100, decimals=1)) 89 | 90 | result_json = { 91 | "num_samples": len(samples), 92 | "num_scores": len(scores), 93 | "timeout_samples": timeout_cnt, 94 | "empty_samples": len([s for s in samples if not s['pred'][-1]]), 95 | "acc": 
mean_score[0] 96 | } 97 | 98 | # each type score 99 | if "type" in samples[0]: 100 | type_scores = {} 101 | for sample in samples: 102 | if sample['type'] not in type_scores: 103 | type_scores[sample['type']] = [] 104 | type_scores[sample['type']].append(sample['score'][-1]) 105 | type_scores = {k: np.round(np.array(v).mean() * 100, decimals=1) for k, v in type_scores.items()} 106 | type_scores = {k: v for k, v in sorted(type_scores.items(), key=lambda item: item[0])} 107 | result_json['type_acc'] = type_scores 108 | 109 | print(result_json) 110 | return samples, result_json 111 | 112 | 113 | def parse_args(): 114 | parser = argparse.ArgumentParser() 115 | parser.add_argument("--data_name", type=str, default="math") 116 | parser.add_argument("--prompt_type", type=str, default="tool-integrated") 117 | parser.add_argument("--file_path", type=str, default=None, required=True) 118 | parser.add_argument("--max_num_samples", type=int, default=None) 119 | parser.add_argument("--execute", action="store_true") 120 | args = parser.parse_args() 121 | return args 122 | 123 | if __name__ == "__main__": 124 | args = parse_args() 125 | evaluate(data_name=args.data_name, prompt_type=args.prompt_type, file_path=args.file_path, 126 | max_num_samples=args.max_num_samples, execute=args.execute) 127 | -------------------------------------------------------------------------------- /prompts/tora/math.md: -------------------------------------------------------------------------------- 1 | Integrate step-by-step reasoning and Python code to solve math problems using the following guidelines: 2 | 3 | - Analyze the question and write functions to solve the problem; the function should not take any arguments. 4 | - Present the final result in LaTeX using a `\boxed{}` without any units. 5 | - Utilize the `pi` symbol and `Rational`` from Sympy for $\pi$ and fractions, and simplify all fractions and square roots without converting them to decimal values. 
6 | 7 | Here are some examples you may refer to: 8 | 9 | --- 10 | 11 | Question: Convert the point $(0, -3 \sqrt{3}, 3)$ in rectangular coordinates to spherical coordinates. Enter your answer in the form $(\rho,\theta,\phi),$ where $\rho > 0,$ $0 \le \theta < 2 \pi,$ and $0 \le \phi \le \pi.$ 12 | 13 | Solution: 14 | To convert the point $(x, y, z)$ in rectangular coordinates to spherical coordinates $(\rho, \theta, \phi)$, we use the following formulas: 15 | 1. $\rho = \sqrt{x^2 + y^2 + z^2}$ 16 | 2. $\theta = \arctan{\frac{y}{x}}$ 17 | 3. $\phi = \arccos{\frac{z}{\rho}}$ 18 | 19 | Let's write a function: 20 | 21 | ```python 22 | from sympy import sqrt, atan2, acos, pi 23 | 24 | def rectangular_to_spherical(): 25 | x, y, z = 0, -3*sqrt(3), 3 26 | rho = sqrt(x**2 + y**2 + z**2) 27 | theta = atan2(y, x) 28 | phi = acos(z/rho) 29 | return rho, theta, phi 30 | 31 | spherical_coordinates = rectangular_to_spherical() 32 | print(spherical_coordinates) 33 | ``` 34 | ```output 35 | (6, -pi/2, pi/3) 36 | ``` 37 | 38 | The required form of the answer is $(\rho, \theta, \phi)$, where $\rho > 0$, $0 \le \theta < 2\pi$, and $0 \le \phi \le \pi$. Since our initial output has $\theta = -\pi/2$, which doesn't meet the condition for $\theta$, we add $2\pi$ to adjust its value. 39 | 40 | Hence, the spherical coordinates for the point $(0, -3\sqrt{3}, 3)$ are $\boxed{(6, \frac{3\pi}{2}, \frac{\pi}{3})}$. 41 | 42 | --- 43 | 44 | Question: Compute $1011_2 + 101_2 - 1100_2 + 1101_2$. Express your answer in base 2. 
45 | 46 | Solution: 47 | Let's write a single python function for binary computation: 48 | 49 | ```python 50 | def binary_sum_diff(): 51 |     # Convert binary to decimal 52 |     num1 = int("1011", 2) 53 |     num2 = int("101", 2) 54 |     num3 = int("1100", 2) 55 |     num4 = int("1101", 2) 56 | 57 |     # Perform the operations 58 |     result = num1 + num2 - num3 + num4 59 | 60 |     # Convert the result back to binary 61 |     result_binary = format(result, "b") 62 | 63 |     return result_binary 64 | 65 | result = binary_sum_diff() 66 | print(result) 67 | ``` 68 | ```output 69 | 10001 70 | ``` 71 | The answer expressed in base 2 is $\boxed{10001_2}$. 72 | 73 | --- 74 | 75 | Question: Find the matrix that corresponds to projecting onto the vector $\begin{pmatrix} 2 \\ -3 \end{pmatrix}.$ 76 | 77 | Solution: 78 | ```python 79 | from sympy import Matrix, simplify 80 | 81 | def projection_matrix(): 82 |     v = Matrix([2, -3]) 83 | 84 |     # Calculate the projection matrix 85 |     P = (v * v.T) / (v.T * v)[0] 86 | 87 |     # Simplify the elements 88 |     P = P.applyfunc(simplify) 89 | 90 |     return P 91 | 92 | result = projection_matrix() 93 | print(result) 94 | ``` 95 | ```output 96 | Matrix([[4/13, -6/13], [-6/13, 9/13]]) 97 | ``` 98 | 99 | The matrix is $\boxed{\begin{pmatrix} \frac{4}{13} & -\frac{6}{13} \\ -\frac{6}{13} & \frac{9}{13} \end{pmatrix}}.$ 100 | 101 | --- 102 | 103 | Question: Solve the inequality 104 | \[\frac{x^2 - 25}{x + 5} < 0.\] 105 | 106 | Solution: 107 | Let's use `sympy` to solve the inequality: 108 | 109 | ```python 110 | from sympy import symbols, simplify 111 | from sympy.solvers.inequalities import solve_univariate_inequality 112 | from sympy.core.relational import StrictLessThan 113 | 114 | def solve_inequality(): 115 |     x = symbols('x') 116 |     expression = (x**2 - 25) / (x + 5) 117 |     inequality = StrictLessThan(expression, 0) 118 |     solution = solve_univariate_inequality(inequality, x, relational=False) 119 |     simplified_solution = simplify(solution) 120 | 121 |     return simplified_solution 122 |
123 | result = solve_inequality() 124 | print(result) 125 | ``` 126 | ```output 127 | Union(Interval.open(-oo, -5), Interval.Lopen(-5, 5)) 128 | ``` 129 | The answer is $\boxed{(-\infty,-5)\cup(-5,5)}$ 130 | 131 | --- 132 | 133 | Question: In the figure, triangles $ABC$ and $BCD$ are equilateral triangles. What is the value of $AD \div BC$ when expressed in simplest radical form? 134 | 135 | [asy] 136 | draw((0,0)--(5,8.7)--(10,0)--cycle); 137 | draw((10,0)--(15,8.7)--(5,8.7)); 138 | label("$A$",(0,0),SW); 139 | label("$B$",(5,8.7),N); 140 | label("$C$",(10,0),SE); 141 | label("$D$",(15,8.7),NE); 142 | [/asy] 143 | 144 | Solution: 145 | ```python 146 | from sympy import Rational, sqrt, simplify 147 | 148 | def ad_divided_by_bc(): 149 | 150 | x = Rational(1, 1) # Side length of equilateral triangles 151 | 152 | ad_squared = 2 * x**2 * (1 + Rational(1, 2)) # Using the law of cosines with cos(2*pi/3) = -1/2 153 | ad = sqrt(ad_squared) 154 | 155 | bc = x # BC is the side length of the equilateral triangles 156 | 157 | simplified_ratio = simplify(ad / bc) 158 | 159 | return simplified_ratio 160 | 161 | result = ad_divided_by_bc() 162 | print(result) 163 | ``` 164 | ```output 165 | sqrt(3) 166 | ``` 167 | The value of $AD \div BC$ is $\boxed{\sqrt{3}}$. 168 | 169 | --- 170 | -------------------------------------------------------------------------------- /prompts/cot/math_8shot.md: -------------------------------------------------------------------------------- 1 | Question: Kevin Kangaroo begins hopping on a number line at 0. He wants to get to 1, but he can hop only $\frac{1}{3}$ of the distance. Each hop tires him out so that he continues to hop $\frac{1}{3}$ of the remaining distance. How far has he hopped after five hops? Express your answer as a common fraction. 2 | Answer: Let's think step by step 3 | Kevin hops $1/3$ of the remaining distance with every hop. 4 | His first hop takes $1/3$ closer. 
5 | For his second hop, he has $2/3$ left to travel, so he hops forward $(2/3)(1/3)$. 6 | For his third hop, he has $(2/3)^2$ left to travel, so he hops forward $(2/3)^2(1/3)$. 7 | In general, Kevin hops forward $(2/3)^{k-1}(1/3)$ on his $k$th hop. 8 | We want to find how far he has hopped after five hops. 9 | This is a finite geometric series with first term $1/3$, common ratio $2/3$, and five terms. 10 | Thus, Kevin has hopped $\frac{\frac{1}{3}\left(1-\left(\frac{2}{3}\right)^5\right)}{1-\frac{2}{3}} = \boxed{\frac{211}{243}}$. 11 | The answer is \frac{211}{243} 12 | 13 | 14 | Question: What is the area of the region defined by the equation $x^2+y^2 - 7 = 4y-14x+3$? 15 | Answer: Let's think step by step 16 | We rewrite the equation as $x^2 + 14x + y^2 - 4y = 10$ and then complete the square, 17 | resulting in $(x+7)^2-49 + (y-2)^2-4=10$, 18 | or $(x+7)^2+(y-2)^2=63$. 19 | This is the equation of a circle with center $(-7, 2)$ and radius $\sqrt{63},$ 20 | so the area of this region is $\pi r^2 = \boxed{63\pi}$. 21 | The answer is 63\pi 22 | 23 | 24 | Question: If $x^2+y^2=1$, what is the largest possible value of $|x|+|y|$? 25 | Answer: Let's think step by step 26 | If $(x,y)$ lies on the circle, 27 | so does $(x,-y),$ $(-x,y),$ and $(-x,-y),$ (which all give the same value of $|x| + |y|$), 28 | so we can assume that $x \ge 0$ and $y \ge 0.$ 29 | Then $|x| + |y| = x + y.$ Squaring, we get 30 | \[(x + y)^2 = x^2 + 2xy + y^2 = 1 + 2xy.\] 31 | Note that $(x - y)^2 \ge 0.$ 32 | Expanding, we get $x^2 - 2xy + y^2 \ge 0,$ so $2xy \le x^2 + y^2 = 1.$ 33 | Hence,\[1 + 2xy \le 2,\]which means $x + y \le \sqrt{2}.$ 34 | Equality occurs when $x = y = \frac{1}{\sqrt{2}},$ 35 | so the maximum value of $|x| + |y|$ is $\boxed{\sqrt{2}}.$ 36 | The answer is \sqrt{2} 37 | 38 | 39 | Question: If $f(x)=\frac{ax+b}{cx+d}, abcd\not=0$ and $f(f(x))=x$ for all $x$ in the domain of $f$, what is the value of $a+d$?
40 | Answer: Let's think step by step 41 | The condition $f(f(x))$ means that $f$ is the inverse of itself, 42 | so its graph is symmetrical about the line $y = x$. 43 | With a rational function of this form, we will have two asymptotes: 44 | a vertical one at $x=-d/c$ if $cx+d$ does not divide $ax+b$, 45 | and a horizontal one at $y=a/c$, 46 | if we take the limit of $f(x)$ as $x$ goes to $\pm\infty$. 47 | In order for $f$ to be its own inverse, the intersection of the asymptotes must lie on the line $y=x$ 48 | so that it and its asymptotes reflect onto themselves. 49 | This means that $-d/c=a/c$, 50 | and therefore $-d=a$ and $a+d=\boxed{0}$. 51 | The answer is 0 52 | 53 | 54 | Question: A math teacher requires Noelle to do one homework assignment for each of the first five homework points she wants to earn; for each of the next five homework points, she needs to do two homework assignments; and so on, so that to earn the $n^{\text{th}}$ homework point, she has to do $n\div5$ (rounded up) homework assignments. For example, when she has 11 points, it will take $12\div5=2.4\rightarrow3$ homework assignments to earn her $12^{\text{th}}$ point. What is the smallest number of homework assignments necessary to earn a total of 25 homework points? 55 | Answer: Let's think step by step 56 | Noelle only has to do 1 homework assignment to earn her first point, 57 | and the same is true for each of her first five points. 58 | She must then do 2 homework assignments to earn her sixth point, seventh point, and so on, up to her tenth point. 59 | Continuing, we see that Noelle must do a total of \[1+1+1+1+1+2+2+2+2+2+\dots+5+5+5+5+5\] homework assignments to earn 25 points. 60 | This sum may be rewritten as $5(1+2+3+4+5)=5(15)=\boxed{75}$. 61 | The answer is 75 62 | 63 | 64 | Question: The quadratic equation $x^2+mx+n=0$ has roots that are twice those of $x^2+px+m=0,$ and none of $m,$ $n,$ and $p$ is zero. 
What is the value of $n/p?$ 65 | Answer: Let's think step by step 66 | Let $r_1$ and $r_2$ be the roots of $x^2+px+m=0.$ 67 | Since the roots of $x^2+mx+n=0$ are $2r_1$ and $2r_2,$ we have the following relationships: \[ 68 | m=r_1 r_2,\quad n=4r_1 r_2,\quad p=-(r_1+r_2), \quad\text{and}\quad 69 | m=-2(r_1+r_2). 70 | \] So \[ 71 | n = 4m, \quad p = \frac{1}{2}m, 72 | \quad\text{and}\quad 73 | \frac{n}{p}=\frac{4m}{\frac{1}{2}m}=\boxed{8}. 74 | \] 75 | Alternatively, the roots of \[ 76 | \left(\frac{x}{2}\right)^2 + p\left(\frac{x}{2}\right) + m = 0 77 | \] are twice those of $x^2 + px + m = 0.$ 78 | Since the first equation is equivalent to $x^2 + 2px + 4m = 0,$ 79 | we have \[m = 2p \quad\text{and}\quad n = 4m, \quad\text{so}\quad \frac{n}{p} = \boxed{8}.\] 80 | The answer is 8 81 | 82 | 83 | Question: Expand $(2z^2 + 5z - 6)(3z^3 - 2z + 1)$. 84 | Answer: Let's think step by step 85 | $$\begin{array}{crrrrrrr} 86 | & & & 3z^3 & & -2z & + 1 & \\ 87 | \times & & & & 2z^2 & +5z & -6 \\ 88 | \cline{1-7}\rule{0pt}{0.17in} 89 | & & & -18z^3 & & +12z & -6 & \\ 90 | & & +15z^4 & & -10z^2 & +5z & & \\ 91 | + & 6z^5 & & -4z^3 & +2z^2 & & & \\ 92 | \cline{1-7}\rule{0pt}{0.17in} 93 | & 6z^5 & +15z^4 & -22z^3 & - 8z^2 &+17z & -6 & 94 | \end{array}$$ 95 | The answer is 6z^5+15z^4-22z^3-8z^2+17z-6 96 | 97 | 98 | Question: Find the mean of all solutions for $x$ when $x^3 + 3x^2 - 10x = 0$. 99 | Answer: Let's think step by step 100 | First, we factor the equation as $x(x^2 +3x - 10) = 0$. 101 | So, one solution is $x=0$ and the other two solutions are the solutions to $x^2 + 3x-10=0$. 102 | We could either factor the quadratic, or note that the sum of the solutions to this quadratic is $-(3/1)=-3$, 103 | so the mean of the three solutions to the original equation is $-3/3=\boxed{-1}$. 
104 | The answer is -1 -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import random 4 | import json 5 | import os 6 | import numpy as np 7 | from pathlib import Path 8 | from typing import Iterable, Union, Any 9 | 10 | 11 | def set_seed(seed: int = 42) -> None:  # seed numpy and the stdlib `random` module 12 |     np.random.seed(seed) 13 |     random.seed(seed) 14 |     os.environ["PYTHONHASHSEED"] = str(seed)  # NOTE: only affects child processes; the running interpreter's hash seed is fixed at startup 15 |     print(f"Random seed set as {seed}") 16 | 17 | 18 | def load_jsonl(file: Union[str, Path]) -> Iterable[Any]: 19 |     with open(file, "r", encoding="utf-8") as f: 20 |         for line in f: 21 |             try: 22 |                 yield json.loads(line) 23 |             except json.JSONDecodeError: 24 |                 print("Error in loading:", line) 25 |                 exit() 26 | 27 | 28 | def save_jsonl(samples, save_path): 29 |     # ensure the parent directory exists (dirname is "" for a bare filename, so fall back to ".") 30 |     folder = os.path.dirname(save_path) 31 |     os.makedirs(folder or ".", exist_ok=True) 32 | 33 |     with open(save_path, "w", encoding="utf-8") as f: 34 |         for sample in samples: 35 |             f.write(json.dumps(sample) + "\n") 36 |     print("Saved to", save_path) 37 | 38 | 39 | def lower_keys(example):  # return a copy of the dict with every key lower-cased 40 |     new_example = {} 41 |     for key, value in example.items(): 42 |         if key != key.lower(): 43 |             new_key = key.lower() 44 |             new_example[new_key] = value 45 |         else: 46 |             new_example[key] = value 47 |     return new_example 48 | 49 | 50 | def load_prompt(data_name, prompt_type): 51 |     if data_name in ['gsm_hard', 'svamp', 'tabmwp', 'asdiv', 'mawps']: 52 |         data_name = "gsm8k" 53 |     if data_name in ['math_oai', "hungarian_exam"]: 54 |         data_name = "math" 55 |     if data_name in ['sat_math']: 56 |         data_name = "mmlu_stem" 57 |     if prompt_type in ['platypus_fs']: 58 |         prompt_type = "cot" 59 |     if prompt_type in ['tool-integrated']: 60 |         prompt_type = "tora" 61 | 62 |     if prompt_type in ['cot', 'pal', 'tora']: 63 |         prompt_path = "./prompts/{}/{}.md".format(prompt_type, data_name) 64 |         if not os.path.exists(prompt_path): 65 |             prompt_path =
"./prompts/{}.md".format(prompt_type) 66 |         if os.path.exists(prompt_path): 67 |             with open(prompt_path, 'r', encoding='utf-8') as fp: 68 |                 prompt = fp.read().strip() + "\n\n\n" 69 |         else: 70 |             print(f"Error: prompt file {prompt_path} not found") 71 |             prompt = "" 72 |     else: 73 |         prompt = "" 74 |     return prompt 75 | 76 | def construct_prompt(example, data_name, args): 77 |     # Base models: few-shot demo prompt + question 78 |     if args.prompt_type in ["direct", "cot", "pal", "tool-integrated"]: 79 |         demo_prompt = load_prompt(data_name, args.prompt_type) 80 |         if args.prompt_type in ["direct", "cot"]: 81 |             if data_name in ["minerva_math", "math", "math_oai", "mmlu_stem", "sat_math", "mathqa", "hungarian_exam"]: 82 |                 context = f"Problem:\n{example['question']}\nSolution:" 83 |             else: 84 |                 context = f"Question: {example['question']}\nAnswer:" 85 |             full_prompt = demo_prompt + context 86 |         elif args.prompt_type == "pal": 87 |             context = f"Question: {example['question']}" 88 |             full_prompt = demo_prompt + context 89 |         elif args.prompt_type == "tool-integrated": 90 |             context = f"Question: {example['question']}\n\nSolution:" 91 |             full_prompt = demo_prompt + context 92 | 93 |     # SFT models 94 |     elif args.prompt_type in ['self-instruct', 'tora']: 95 |         full_prompt = f"<|user|>\n{example['question']}\n<|assistant|>\n" 96 |     elif args.prompt_type in ['self-instruct-boxed']: 97 |         full_prompt = f"<|user|>\n{example['question']}\nEnclose the final answer using \\boxed{{}}.\n<|assistant|>\n" 98 |     elif args.prompt_type == "wizard_zs": 99 |         full_prompt = ( 100 |             "Below is an instruction that describes a task. " 101 |             "Write a response that appropriately completes the request.\n\n" 102 |             "### Instruction:\n{instruction}\n\n### Response: Let's think step by step."
103 |         ) 104 |         full_prompt = full_prompt.format(instruction=example['question']) 105 |     elif args.prompt_type == "deepseek-math": 106 |         full_prompt = ( 107 |             "User: {instruction}\nPlease reason step by step, " 108 |             "and put your final answer within \\boxed{{}}.\n\nAssistant:" 109 |         ) 110 |         full_prompt = full_prompt.format(instruction=example['question']) 111 |     elif args.prompt_type == "kpmath": 112 |         full_prompt = ( 113 |             'User: Please reason step by step and put your final answer at the end ' 114 |             'with "The answer is: ".\n\n{instruction}\n\nAssistant:' 115 |         ) 116 |         full_prompt = full_prompt.format(instruction=example['question']) 117 |     else: 118 |         raise NotImplementedError(args.prompt_type) 119 |     return full_prompt 120 | 121 | key_map = { 122 |     "gt": "Ground Truth", 123 |     "pred": "Prediction", 124 |     "gt_cot": "Reference CoT", 125 |     "score": "Score", 126 | } 127 | 128 | def show_sample(sample, print_all_preds=False): 129 |     print("=="*20) 130 |     for key in ["idx", "type", "level", "dataset"]: 131 |         if key in sample: 132 |             # capitalize 133 |             print("{}: {}".format(key[0].upper() + key[1:], sample[key])) 134 |     print("Question:", repr(sample['question'])) 135 |     if 'code' in sample: 136 |         if print_all_preds: 137 |             for code, report in zip(sample['code'], sample['report']): 138 |                 print('-'*20) 139 |                 print("code:", code) 140 |                 print("Execution:", report) 141 |         else: 142 |             print("Solution:\n", sample['code'][0]) 143 |             print("Execution:", sample['report'][0]) 144 |     if 'pred' in sample: 145 |         print("Prediction:", repr(sample['pred'][0])) 146 |     for key in ["gt", "score", "unit", "gt_cot"]: 147 |         if key in sample: 148 |             _key = key_map.get(key, key) 149 |             print("{}: {}".format(_key, repr(sample[key]))) 150 |     print() 151 | -------------------------------------------------------------------------------- /trajectory.py: -------------------------------------------------------------------------------- 1 | import re 2 | """ 3 | trajectory: 4 | [ 5 | {"role": "rationale", "content": "..."}, 6 | {"role":
"program", "content": "..."}, 7 | {"role": "output", "content": "..."}, 8 | {"role": "rationale", "content": "..."}, 9 | ... 10 | ] 11 | """ 12 | 13 | def text_to_trajectory(traj_str: str) -> None: 14 | """ 15 | """ 16 | # parse the above interleaved string of raionale, program, output, raionale, program, output, ... 17 | # output a list of dict 18 | trajectory = [] 19 | cur_role = "rationale" 20 | cur_content = "" 21 | 22 | # print(traj_str) 23 | for i, line in enumerate(traj_str.split("\n")): 24 | if line == "```python": # program begin 25 | assert cur_role == "rationale" 26 | if cur_content: 27 | trajectory.append({"role": cur_role, "content": cur_content}) 28 | cur_content = "" 29 | cur_role = "program" 30 | elif cur_role == "program" and line == "```": # program end 31 | assert cur_content 32 | trajectory.append({"role": cur_role, "content": cur_content}) 33 | cur_content = "" 34 | cur_role = "output" 35 | elif cur_role == "output" and line.startswith("```output"): # output begin 36 | assert cur_content == "" 37 | elif cur_role == "output" and line == "```": # output end 38 | trajectory.append({"role": cur_role, "content": cur_content}) 39 | cur_content = "" 40 | cur_role = "rationale" 41 | else: # content 42 | cur_content += line 43 | if i < len(traj_str.split("\n")) - 1: 44 | cur_content += "\n" 45 | # the last content 46 | if cur_content: 47 | trajectory.append({"role": cur_role, "content": cur_content}) 48 | return trajectory 49 | 50 | 51 | def trajectory_to_text(trajectory: list) -> str: 52 | text = "" 53 | for item in trajectory: 54 | content = item["content"] 55 | if item["role"] == "program": 56 | content = f"```python\n{content}```\n" 57 | elif item["role"] == "output": 58 | content = f"```output\n{content}```\n" 59 | text += content 60 | return text 61 | 62 | 63 | def is_execution_success(output): 64 | error_key_words = ["error", "exception", "no algorithms", "no algorithms", "cannot", "nan", "..."] 65 | success = all([k not in output.lower() for k 
in error_key_words]) 66 | return success 67 | 68 | 69 | def extract_program(text:str=None, trajectory:list=None, last_only=False) -> str: 70 | assert text is not None or trajectory is not None, "Either text or trajectory should be provided." 71 | if trajectory is None: 72 | try: 73 | trajectory = text_to_trajectory(text) 74 | except: 75 | return "raise ValueError('Invalid trajectory')" 76 | 77 | program_list = [] 78 | import_lines = [] 79 | for i, item in enumerate(trajectory): 80 | if item["role"] == "program": 81 | cur_program = item["content"] 82 | if i < len(trajectory) - 1: 83 | assert trajectory[i+1]["role"] == "output" 84 | output = trajectory[i+1]["content"].strip() 85 | if is_execution_success(output): 86 | program_list.append(cur_program) 87 | else: 88 | # extract import lines only 89 | for line in cur_program.split("\n"): 90 | if line.startswith("import") or line.startswith("from"): 91 | import_lines.append(line) 92 | else: 93 | program_list.append(cur_program) 94 | # add import lines to the first program 95 | if len(program_list) == 0: 96 | program_list.append("") 97 | if len(import_lines) > 0: 98 | program_list[0] = "\n".join(import_lines) + "\n" + program_list[0] 99 | for i, program in enumerate(program_list[:-1]): 100 | program_list[i] = "\n".join([line for line in program.split("\n") if not line.strip().startswith("print(")]) 101 | 102 | if last_only: 103 | program = program_list[-1] 104 | else: 105 | program = "\n".join(program_list) 106 | return program 107 | 108 | 109 | def extract_program_output(pred_str, last_only=True): 110 | """ 111 | extract output between ```output\n...\n```, use regex, there might be multiple outputs, each output may have multiple lines 112 | """ 113 | outputs = re.findall(r"```output\n(.*?)\n```", pred_str, re.DOTALL) 114 | if last_only: 115 | return outputs[-1] if len(outputs) > 0 else "" 116 | else: 117 | return outputs 118 | 119 | 120 | def _test_str_to_trajectory(): 121 | traj_text = """We can use the Chinese 
Remainder Theorem to solve this system of congruences. The Chinese Remainder Theorem states that if $n_1, n_2, \ldots, n_k$ are pairwise coprime (i.e., the greatest common divisor of any two of them is 1), and if $a_1, a_2, \ldots, a_k$ are any integers, then there is a number $x$ that satisfies the system of congruences: 122 | 123 | $x \equiv a_1 \pmod{n_1}$ 124 | $x \equiv a_2 \pmod{n_2}$ 125 | $\vdots$ 126 | $x \equiv a_k \pmod{n_k}$ 127 | 128 | Moreover, any two solutions of this system are congruent modulo the product $n_1n_2 \cdots n_k$. 129 | 130 | Let's write a function to solve this: 131 | 132 | ```python 133 | from sympy import symbols, solve 134 | 135 | def smallest_x(): 136 | x = symbols('x') 137 | # Solve the system of congruences 138 | solution = solve([x - 7, x - 17], x, domain='Z') 139 | # Return the smallest positive solution 140 | return min([sol for sol in solution if sol >= 0]) 141 | 142 | result = smallest_x() 143 | print(result) 144 | ``` 145 | ```output 146 | ValueError: min() arg is an empty sequence 147 | ``` 148 | The error in the last example is due to the incorrect usage of the `solve` function. The `solve` function is used to solve equations, not congruences. To solve a system of congruences, we can use the `solve_congruence` function from the `sympy.ntheory.modular` module. 149 | 150 | Here is the corrected code: 151 | 152 | ```python 153 | from sympy.ntheory.modular import solve_congruence 154 | 155 | def smallest_x(): 156 | # Solve the system of congruences 157 | solution = solve_congruence(('x', 7, 12), ('x', 17, 11)) 158 | # Return the smallest positive solution 159 | return solution[0][0] 160 | 161 | result = smallest_x() 162 | print(result) 163 | ``` 164 | ```output 165 | good 166 | ``` 167 | ValueError: too many values to unpack (expected 2) 168 | The error in the last example is due to the incorrect usage of the `solve_congruence` function. 
The `solve_congruence` function returns a tuple of two elements: the solution and the modulus. We need to unpack these two values correctly. 169 | 170 | Here is the corrected code: 171 | 172 | ```python 173 | from sympy.ntheory.modular import solve_congruence 174 | 175 | def smallest_x(): 176 | # Solve the system of congruences 177 | solution, modulus = solve_congruence(('x', 7, 12), ('x', 17, 11)) 178 | # Return the smallest positive solution 179 | return solution 180 | 181 | result = smallest_x() 182 | print(result) 183 | ```""" 184 | 185 | import pprint 186 | trajectory = text_to_trajectory(traj_text) 187 | pprint.pprint(trajectory) 188 | 189 | text = trajectory_to_text(trajectory) 190 | assert text == traj_text 191 | 192 | # print(extract_program(traj_text)) 193 | 194 | 195 | if __name__ == "__main__": 196 | _test_str_to_trajectory() -------------------------------------------------------------------------------- /python_executor.py: -------------------------------------------------------------------------------- 1 | import os 2 | import io 3 | import regex 4 | import pickle 5 | import traceback 6 | import copy 7 | import datetime 8 | import dateutil.relativedelta 9 | import multiprocess 10 | from multiprocess import Pool 11 | from typing import Any, Dict, Optional 12 | from pebble import ProcessPool 13 | from tqdm import tqdm 14 | from concurrent.futures import TimeoutError 15 | from functools import partial 16 | from timeout_decorator import timeout 17 | from contextlib import redirect_stdout 18 | 19 | 20 | class GenericRuntime: 21 | GLOBAL_DICT = {} 22 | LOCAL_DICT = None 23 | HEADERS = [] 24 | def __init__(self): 25 | self._global_vars = copy.copy(self.GLOBAL_DICT) 26 | self._local_vars = copy.copy(self.LOCAL_DICT) if self.LOCAL_DICT else None 27 | 28 | for c in self.HEADERS: 29 | self.exec_code(c) 30 | 31 | def exec_code(self, code_piece: str) -> None: 32 | if regex.search(r'(\s|^)?input\(', code_piece): 33 | # regex.search(r'(\s|^)?os.', code_piece): 34 | 
raise RuntimeError() 35 | exec(code_piece, self._global_vars) 36 | 37 | # TODO: use: https://github.com/shroominic/codebox-api 38 | # @high safe exec in sandbox 39 | # byte_code = compile_restricted( 40 | # code_piece, 41 | # filename='', 42 | # mode='exec' 43 | # ) 44 | # print("global vars:", self._global_vars) 45 | # _print_ = PrintCollector 46 | # exec(byte_code, {'__builtins__': utility_builtins}, None) 47 | 48 | def eval_code(self, expr: str) -> Any: 49 | return eval(expr, self._global_vars) 50 | 51 | def inject(self, var_dict: Dict[str, Any]) -> None: 52 | for k, v in var_dict.items(): 53 | self._global_vars[k] = v 54 | 55 | @property 56 | def answer(self): 57 | return self._global_vars['answer'] 58 | 59 | class DateRuntime(GenericRuntime): 60 | GLOBAL_DICT = { 61 | 'datetime': datetime.datetime, 62 | 'timedelta': dateutil.relativedelta.relativedelta, 63 | 'relativedelta': dateutil.relativedelta.relativedelta 64 | } 65 | 66 | 67 | class CustomDict(dict): 68 | def __iter__(self): 69 | return list(super().__iter__()).__iter__() 70 | 71 | class ColorObjectRuntime(GenericRuntime): 72 | GLOBAL_DICT = {'dict': CustomDict} 73 | 74 | 75 | class PythonExecutor: 76 | def __init__( 77 | self, 78 | runtime: Optional[Any] = None, 79 | get_answer_symbol: Optional[str] = None, 80 | get_answer_expr: Optional[str] = None, 81 | get_answer_from_stdout: bool = False, 82 | timeout_length: int = 5, 83 | ) -> None: 84 | self.runtime = runtime if runtime else GenericRuntime() 85 | self.answer_symbol = get_answer_symbol 86 | self.answer_expr = get_answer_expr 87 | self.get_answer_from_stdout = get_answer_from_stdout 88 | self.pool = Pool(multiprocess.cpu_count()) 89 | self.timeout_length = timeout_length 90 | 91 | def process_generation_to_code(self, gens: str): 92 | return [g.split('\n') for g in gens] 93 | 94 | @staticmethod 95 | def execute( 96 | code, 97 | get_answer_from_stdout = None, 98 | runtime = None, 99 | answer_symbol = None, 100 | answer_expr = None, 101 | 
timeout_length = 10, 102 | ): 103 | try: 104 | if get_answer_from_stdout: 105 | program_io = io.StringIO() 106 | with redirect_stdout(program_io): 107 | timeout(timeout_length)(runtime.exec_code)('\n'.join(code)) 108 | program_io.seek(0) 109 | result = program_io.read() 110 | elif answer_symbol: 111 | timeout(timeout_length)(runtime.exec_code)('\n'.join(code)) 112 | result = runtime._global_vars[answer_symbol] 113 | elif answer_expr: 114 | timeout(timeout_length)(runtime.exec_code)('\n'.join(code)) 115 | result = timeout(timeout_length)(runtime.eval_code)(answer_expr) 116 | else: 117 | timeout(timeout_length)(runtime.exec_code)('\n'.join(code[:-1])) 118 | result = timeout(timeout_length)(runtime.eval_code)(code[-1]) 119 | report = "Done" 120 | str(result) 121 | pickle.dumps(result) # serialization check 122 | except: 123 | result = '' 124 | report = traceback.format_exc().split('\n')[-2] 125 | return result, report 126 | 127 | def apply(self, code): 128 | return self.batch_apply([code])[0] 129 | 130 | @staticmethod 131 | def truncate(s, max_length=400): 132 | half = max_length // 2 133 | if len(s) > max_length: 134 | s = s[:half] + "..." 
+ s[-half:] 135 | return s 136 | 137 | def batch_apply(self, batch_code): 138 | all_code_snippets = self.process_generation_to_code(batch_code) 139 | 140 | timeout_cnt = 0 141 | all_exec_results = [] 142 | with ProcessPool(max_workers=min(len(all_code_snippets), os.cpu_count())) as pool: 143 | executor = partial( 144 | self.execute, 145 | get_answer_from_stdout=self.get_answer_from_stdout, 146 | runtime=self.runtime, 147 | answer_symbol=self.answer_symbol, 148 | answer_expr=self.answer_expr, 149 | timeout_length=self.timeout_length, # this timeout not work 150 | ) 151 | future = pool.map(executor, all_code_snippets, timeout=self.timeout_length) 152 | iterator = future.result() 153 | 154 | if len(all_code_snippets) > 100: 155 | progress_bar = tqdm(total=len(all_code_snippets), desc="Execute") 156 | else: 157 | progress_bar = None 158 | 159 | while True: 160 | try: 161 | result = next(iterator) 162 | all_exec_results.append(result) 163 | except StopIteration: 164 | break 165 | except TimeoutError as error: 166 | print(error) 167 | all_exec_results.append(("", "Timeout Error")) 168 | timeout_cnt += 1 169 | except Exception as error: 170 | print(error) 171 | exit() 172 | if progress_bar is not None: 173 | progress_bar.update(1) 174 | 175 | if progress_bar is not None: 176 | progress_bar.close() 177 | 178 | batch_results = [] 179 | for code, (res, report) in zip(all_code_snippets, all_exec_results): 180 | # post processing 181 | res, report = str(res).strip(), str(report).strip() 182 | res, report = self.truncate(res), self.truncate(report) 183 | batch_results.append((res, report)) 184 | return batch_results 185 | 186 | 187 | def _test(): 188 | batch_code = [ 189 | """ 190 | from sympy import Matrix 191 | 192 | def null_space_basis(): 193 | # Define the matrix 194 | A = Matrix([[3, 3, -1, -6], [9, -1, -8, -1], [7, 4, -2, -9]]) 195 | 196 | # Compute the basis for the null space 197 | basis = A.nullspace() 198 | 199 | # Round the elements of the basis vectors to three 
decimal places 200 | basis_rounded = [v.evalf(3) for v in basis] 201 | 202 | return basis_rounded 203 | 204 | result = null_space_basis() 205 | print(result) 206 | """ 207 | ] 208 | 209 | executor = PythonExecutor(get_answer_from_stdout=True) 210 | predictions = executor.apply(batch_code[0]) 211 | print(predictions) 212 | 213 | 214 | if __name__ == '__main__': 215 | _test() -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # LLM Math Evaluation Harness 2 | 3 | A unified, precise, and extensible toolkit to benchmark LLMs on various mathematical tasks 🧮✨. 4 | 5 | > 🔴🚀 **Important Notice**: We've identified variances above 5% in results from diverse math evaluation frameworks. To ensure fair and standardized comparisons across research, our toolkit strives to harmonize evaluation methods, promoting consistent and reliable math evaluation. 6 | 7 | > 🌟 **In Practice**: Esteemed projects like [ToRA](https://github.com/microsoft/ToRA) (ICLR'24) and [DeepSeek-Coder](https://github.com/deepseek-ai/DeepSeek-Coder/tree/main/Evaluation/PAL-Math) have leveraged this suite! 8 | 9 | ### Features: 10 | 11 | - **Models**: Seamless compatibility with models from Hugging Face 🤗 and [vLLM](https://github.com/vllm-project/vllm). 12 | 13 | - **Datasets**: An extensive array of datasets including `minerva_math`, `math`, `math_oai`, `gsm8k`, `gsm_hard`, `svamp`, `asdiv`, `mawps`, `tabmwp`, `finqa`, `theorem_qa`, `bbh`, `mmlu_stem`, `sat_math`, `mathqa`, `hungarian_exam`. 14 | 15 | - **Prompts**: Diverse prompting paradigms, from Direct to Chain-of-Thought (CoT), Program-of-Thought (PoT/PAL), and [Tool-Integrated Reasoning (ToRA)](https://github.com/microsoft/ToRA). 
16 | 17 | 18 | ## 🚀 Getting Started 19 | 20 | ### ⚙️ Environment Setup 21 | 22 | #### Option 1: Conda 23 | 24 | ``` 25 | conda create -n math_eval python=3.10 26 | conda activate math_eval 27 | ``` 28 | 29 | #### Option 2: Docker 30 | 31 | We suggest using vLLM docker directly: 32 | 33 | ``` 34 | docker run --network host --cap-add=SYS_ADMIN --privileged -d \ 35 | --entrypoint '' --name vllm \ 36 | --runtime nvidia --gpus all \ 37 | --security-opt apparmor:unconfined \ 38 | --ipc=host --ulimit memlock=-1 --ulimit stack=67108864 \ 39 | -v /mnt:/mnt \ 40 | -p 8000:8000 \ 41 | vllm/vllm-openai:latest \ 42 | sleep infinity 43 | ``` 44 | 45 | ### Install 46 | 47 | ``` 48 | git clone https://github.com/ZubinGou/math-evaluation-harness.git 49 | cd math-evaluation-harness 50 | pip install -r requirements.txt 51 | ``` 52 | 53 | ### ⚖️ Evaluation 54 | 55 | 1. Configure model and data settings in `scripts/run_eval.sh`, and set the `PROMPT_TYPE` variable accordingly: 56 | - For base models, choose from: `direct`, `cot`, `pal`, or `tool-integrated`. 57 | - For SFT models, your options include: `tora`, `wizard_zs`, `deepseek-math`, etc. 58 | - To add new models, update the `construct_prompt` function in `utils.py` to include your new prompt template. 59 | 60 | 2. Run the script: 61 | 62 | ```bash 63 | bash scripts/run_eval.sh $PROMPT_TYPE $MODEL_NAME_OR_PATH 64 | ``` 65 | 66 | 67 | ## 📊 Results 68 | 69 | ### Base Models (CoT) 70 | 71 | > PROMPT_TYPE=cot 72 | 73 | | Model | Size | Data | Uniq. 
Token | Train Token | GSM8K | MATH[^1] | SVAMP | ASDiv | MAWPS | TAB[^2] | MQA | MMLU STEM | SAT | AVG | 74 | |---------------------------------------------------------------|--------------------------|--------|--------------|------------|-------|----------------|-------|-------|-------|-------|------|-----------|----------------|------| 75 | | **1-2B Base Models** | | | | | | | | | | | | | | | 76 | | [Tinyllama](https://huggingface.co/Tinyllama/Tinyllama-1.1B-intermediate-step-1431k-3T) | 1.1B | - | - | - | 2.9 | 3.2 | 11.0 | 18.1 | 20.4 | 12.5 | 14.6 | 16.1 | 21.9 | 13.4 | 77 | | [Phi-1.5](https://huggingface.co/microsoft/phi-1_5) | 1.3B | - | - | - | 32.4 | 4.2 | 43.4 | 53.1 | 66.2 | 24.4 | 14.3 | 21.8 | 18.8 | 31.0 | 78 | | [Qwen1.5](https://huggingface.co/Qwen/Qwen1.5-1.8B) | 1.8B | - | - | - | 36.1 | 6.8 | 48.5 | 63.6 | 79.0 | 29.2 | 25.1 | 31.3 | 40.6 | 40.0 | 79 | | [Gemma](https://huggingface.co/google/gemma-2b) | 2.0B | - | - | - | 18.8 | 11.4 | 38.0 | 56.6 | 72.5 | **36.9** | 26.8 | **34.4** | 50.0 | 38.4 | 80 | | DeepSeekLLM | 1.3B | OWM | 14B | 150B | 11.5 | 8.9 | - | - | - | - | - | 29.6 | 31.3 | - | 81 | | DeepSeekMath | 1.3B | - | 120B | 150B | 23.8 | 13.6 | - | - | - | - | - | 33.1 | **56.3** | - | 82 | | [Rho-Math](https://huggingface.co/microsoft/rho-math-1b-v0.1) | 1.1B | OWM | 14B | **30B** | **36.2** | **15.6** | **52.1** | **67.0** | **83.9** | 29.0 | **32.5** | 23.3 | 28.1 | **40.9** | 83 | | **>= 7B Base Models** | | | | | | | | | | | | | | | 84 | | [LLaMA-2](https://huggingface.co/meta-llama/Llama-2-7b-hf) | 7B | | - | - | 14.0 | 3.6 | 39.5 | 51.7 | 63.5 | 30.9 | 12.4 | 32.7 | 34.4 | 31.4 | 85 | | [Mistral](https://huggingface.co/mistralai/Mistral-7B-v0.1) | 7B | | - | - | 41.2 | 11.6 | 64.7 | 68.5 | 87.5 | 52.9 | 33.0 | 49.5 | 59.4 | 52.0 | 86 | | Minerva | 8B | - | 39B | 164B | 16.2 | 14.1 | - | - | - | - | - | 35.6 | - | - | 87 | | Minerva | 62B | - | 39B | 109B | 52.4 | 27.6 | - | - | - | - | - | 53.9 | - | - | 88 | | Minerva | 540B | 
- | 39B | 26B | 58.8 | 33.6 | - | - | - | - | - | **63.9** | - | - | 89 | | [LLemma](https://huggingface.co/EleutherAI/llemma_7b) | 7B | PPile | 55B | 200B | 38.8 | 17.2 | 56.1 | 69.1 | 82.4 | 48.7 | 41.0 | 45.4 | 59.4 | 50.9 | 90 | | [LLemma](https://huggingface.co/EleutherAI/llemma_34b) | 34B | PPile | 55B | 50B | 54.2 | 23.0 | 67.9 | 75.7 | 90.1 | 57.0 | 49.8 | 54.7 | 68.8 | 60.1 | 91 | | [Intern-Math](https://huggingface.co/internlm/internlm2-math-base-7b) | 7B | - | 31B | 125B | 41.8 | 14.4 | 61.6 | 66.8 | 83.7 | 50.0 | 57.3 | 24.8 | 37.5 | 48.7 | 92 | | [Intern-Math](https://huggingface.co/internlm/internlm2-math-base-20b) | 20B | - | 31B | 125B | 65.4 | 30.0 | 75.7 | 79.3 | **94.0** | 50.9 | 38.5 | 53.1 | 71.9 | 62.1 | 93 | | [DeepSeekMath](https://huggingface.co/deepseek-ai/deepseek-math-7b-base) | 7B | - | 120B | 500B | 64.1 | **34.2** | 74.0 | **83.9** | 92.4 | **63.4** | **62.4** | 56.4 | **84.4** | **68.4** | 94 | | [Rho-Math](https://huggingface.co/microsoft/rho-math-7b-v0.1) | 7B | OWM | 14B | **10.5B** | **66.9** | 31.0 | **77.8** | 79.0 | 93.9 | 49.9 | 58.7 | 54.6 | **84.4** | 66.2 | 95 | 96 | 97 | [^1]: We suggest utilizing the [OpenAI test subset](https://github.com/openai/prm800k) for evaluating MATH performance, since the original `MATH` test set has already been included in public training sets such as PRM800k. We use [minerva_math](/prompts/cot/minerva_math.md) prompt. 
98 | [^2]: Abbreviations: TAB = tabmwp, MQA = mathqa, SAT = sat_math 99 | 100 | 101 | ### SFT Models (Code Interpreter) 102 | 103 | > PROMPT_TYPE=tora 104 | 105 | | Model | Size | SFT Data | GSM8K | MATH | SVAMP | ASDiv | MAWPS | TAB | GSM-Hard | AVG | 106 | |------------------|------|----------|-------|------|-------|-------|-------|-----|----------|------| 107 | | GPT4-early (PAL) | - | - | 94.2 | 51.8 | 94.8 | 92.6 | 97.7 | 95.9| 77.6 | 86.4 | 108 | | MAmmoTH | 70B | MI-260k | 76.9 | 41.8 | 82.4 | - | - | - | - | - | 109 | | ToRA | 7B | ToRA-69k | 68.8 | 40.1 | 68.2 | 73.9 | 88.8 | 42.4 | 54.6 | 62.4 | 110 | | ToRA | 70B | ToRA-69k | 84.3 | 49.7 | 82.7 | 86.8 | 93.8 | 74.0 | 67.2 | 76.9 | 111 | | DeepSeekMath | 7B | ToRA-69k | 79.8 | 52.0 | 80.1 | 87.1 | 93.8 | 85.8 | 63.1 | 77.4 | 112 | | Rho-Math | 1B | ToRA-69k | 59.4 | 40.6 | 60.7 | 74.2 | 88.6 | 26.7 | 48.1 | 56.9 | 113 | | Rho-Math | 7B | ToRA-69k | 81.3 | 51.8 | 80.8 | 85.5 | 94.5 | 70.1 | 63.1 | 75.3 | 114 | 115 | 116 | ### SFT Models (CoT) 117 | 118 | > PROMPT_TYPE=deepseek-math 119 | 120 | | Size | Model | GSM8K | MATH | SVAMP | ASDiv | MAWPS | AVG | 121 | |----------|------------------------|:-----:|:--------:|:-----:|:-----:|:-----:|:-----:| 122 | | **7B** | DeepSeek-Math-Instruct | 82.4 | 45.8 | 83.5 | 90.1 | 95.7 | 79.5 | 123 | | | DeepSeek-Math-RL | 88.3 | 50.0 | 87.2 | 92.0 | 95.5 | 82.6 | 124 | 125 | 126 | ## 🍀 Contributing 127 | 128 | This project is still under active development. We welcome any contributions, including bug reports, feature requests, and pull requests. 
129 | 130 | 131 | ## ☕️ References 132 | 133 | - https://github.com/microsoft/ToRA 134 | - https://github.com/openai/prm800k 135 | - https://github.com/wellecks/lm-evaluation-harness 136 | - https://github.com/deepseek-ai/DeepSeek-Math 137 | -------------------------------------------------------------------------------- /data/hungarian_exam/test.jsonl: -------------------------------------------------------------------------------- 1 | {"problem": "Given are two sets: $A=\\{a ; b ; e ; g\\}$ and $B=\\{a ; b ; c ; d ; f\\}$.\n\nBy listing its elements, give the set $B \\backslash A$.", "idx": 0, "solution": ""} 2 | {"problem": "Bori, Krist\u00f3f and Marci are playing a role-playing card game. At the beginning of the game they each select one out of 10 role cards, without replacement. In how many different arrangements of the roles can the game begin?", "idx": 1, "solution": ""} 3 | {"problem": "Zita's salary has been raised from $275000$ Ft to $308000$ Ft. By what percentage has Zita's salary been raised?", "idx": 2, "solution": ""} 4 | {"problem": "In triangle $A B C \\overrightarrow{A B}=\\mathbf{b}, \\overrightarrow{A C}=\\mathbf{c}$. The midpoint of side $A B$ is point $F$, the midpoint of side $A C$ is $G$. Express vector $\\overrightarrow{F G}$ in terms of vectors $\\mathbf{b}$ and $\\mathbf{c}$. Explain your answer.", "idx": 3, "solution": ""} 5 | {"problem": "Give five positive numbers such that their median is 3 and their range is 7 .", "idx": 4, "solution": ""} 6 | {"problem": "Determine the decimal (base 10) value of the binary (base 2) number 101011.", "idx": 5, "solution": ""} 7 | {"problem": "It is known that $\\log_{2}(x)=5$. Give the value of $\\log_{2}(2x)$. Explain your answer.", "idx": 6, "solution": ""} 8 | {"problem": "List all integer values of $x$ for which both of the inequalities $-6 \\leq x \\leq 2$ and $-4 bool: 59 | """ 60 | Exact match of math if and only if: 61 | 1. numerical equal: both can convert to float and are equal 62 | 2. 
symbolic equal: both can convert to sympy expression and are equal 63 | """ 64 | # print("Judge:", prediction, reference) 65 | if str(prediction) == str(reference): 66 | return True 67 | 68 | try: # 1. numerical equal 69 | if is_digit(prediction) and is_digit(reference): 70 | prediction = parse_digits(prediction) 71 | reference = parse_digits(reference) 72 | # number questions 73 | if include_percentage: 74 | gt_result = [reference / 100, reference, reference * 100] 75 | else: 76 | gt_result = [reference] 77 | for item in gt_result: 78 | try: 79 | if is_close: 80 | if numeric_equal(prediction, item): 81 | return True 82 | else: 83 | if item == prediction: 84 | return True 85 | except Exception: 86 | continue 87 | return False 88 | except: 89 | pass 90 | 91 | if not prediction and prediction not in [0, False]: 92 | return False 93 | # print("try math_eval") 94 | 95 | # 2. symbolic equal 96 | reference = str(reference).strip() 97 | prediction = str(prediction).strip() 98 | 99 | ## pmatrix (amps) 100 | if "pmatrix" in prediction and not 'pmatrix' in reference: 101 | reference = str_to_pmatrix(reference) 102 | 103 | ## deal with [], (), {} 104 | pred_str, ref_str = prediction, reference 105 | if (prediction.startswith("[") and prediction.endswith("]") and not reference.startswith("(")) or \ 106 | (prediction.startswith("(") and prediction.endswith(")") and not reference.startswith("[")): 107 | pred_str = pred_str.strip("[]()") 108 | ref_str = ref_str.strip("[]()") 109 | for s in ['{', "}", "(", ")"]: 110 | ref_str = ref_str.replace(s, "") 111 | pred_str = pred_str.replace(s, "") 112 | if pred_str.lower() == ref_str.lower(): 113 | return True 114 | 115 | ## [a, b] vs. 
[c, d], return a==c and b==d 116 | if regex.match(r'(\(|\[).+(\)|\])', prediction) is not None and regex.match(r'(\(|\[).+(\)|\])', reference) is not None: 117 | pred_parts = prediction[1:-1].split(",") 118 | ref_parts = reference[1:-1].split(",") 119 | if len(pred_parts) == len(ref_parts): 120 | if all([math_equal(pred_parts[i], ref_parts[i], include_percentage, is_close) for i in range(len(pred_parts))]): 121 | return True 122 | if (prediction.startswith("\\begin{pmatrix}") or prediction.startswith("\\begin{bmatrix}")) and (prediction.endswith("\\end{pmatrix}") or prediction.endswith("\\end{bmatrix}")) and \ 123 | (reference.startswith("\\begin{pmatrix}") or reference.startswith("\\begin{bmatrix}")) and (reference.endswith("\\end{pmatrix}") or reference.endswith("\\end{bmatrix}")): 124 | pred_lines = [line.strip() for line in prediction[len("\\begin{pmatrix}"): -len("\\end{pmatrix}")].split("\\\\") if line.strip()] 125 | ref_lines = [line.strip() for line in reference[len("\\begin{pmatrix}"): -len("\\end{pmatrix}")].split("\\\\") if line.strip()] 126 | matched = True 127 | if len(pred_lines) == len(ref_lines): 128 | for pred_line, ref_line in zip(pred_lines, ref_lines): 129 | pred_parts = pred_line.split("&") 130 | ref_parts = ref_line.split("&") 131 | if len(pred_parts) == len(ref_parts): 132 | if not all([math_equal(pred_parts[i], ref_parts[i], include_percentage, is_close) for i in range(len(pred_parts))]): 133 | matched = False 134 | break 135 | else: 136 | matched = False 137 | if not matched: 138 | break 139 | else: 140 | matched = False 141 | if matched: 142 | return True 143 | 144 | if prediction.count('=') == 1 and reference.count('=') == 1: 145 | pred = prediction.split('=') 146 | pred = f"{pred[0].strip()} - ({pred[1].strip()})" 147 | ref = reference.split('=') 148 | ref = f"{ref[0].strip()} - ({ref[1].strip()})" 149 | if symbolic_equal(pred, ref) or symbolic_equal(f"-({pred})", ref): 150 | return True 151 | elif prediction.count('=') == 1 and 
len(prediction.split('=')[0].strip()) <= 2 and '=' not in reference: 152 | if math_equal(prediction.split('=')[1], reference, include_percentage, is_close): 153 | return True 154 | elif reference.count('=') == 1 and len(reference.split('=')[0].strip()) <= 2 and '=' not in prediction: 155 | if math_equal(prediction, reference.split('=')[1], include_percentage, is_close): 156 | return True 157 | 158 | # print("try final") 159 | # symbolic equal with sympy 160 | if timeout: 161 | if call_with_timeout(symbolic_equal_process, prediction, reference): 162 | return True 163 | else: 164 | if symbolic_equal(prediction, reference): 165 | return True 166 | 167 | return False 168 | 169 | 170 | def math_equal_process(param): 171 | return math_equal(param[-2], param[-1]) 172 | 173 | 174 | def numeric_equal(prediction: float, reference: float): 175 | # Note that relative tolerance has significant impact 176 | # on the result of the synthesized gsm_hard dataset 177 | # if reference.is_integer(): 178 | # return isclose(reference, round(prediction), abs_tol=1e-4) 179 | # else: 180 | # prediction = round(prediction, len(str(reference).split(".")[-1])) 181 | return isclose(reference, prediction, rel_tol=1e-4) 182 | 183 | 184 | def symbolic_equal(a, b): 185 | def _parse(s): 186 | for f in [parse_latex, parse_expr, latex2sympy]: 187 | try: 188 | return f(s.replace("\\\\", "\\")) 189 | except: 190 | try: 191 | return f(s) 192 | except: 193 | pass 194 | return s 195 | a = _parse(a) 196 | b = _parse(b) 197 | 198 | # direct equal 199 | try: 200 | if str(a) == str(b) or a == b: 201 | return True 202 | except: 203 | pass 204 | 205 | # print("try simplify") 206 | # simplify equal 207 | try: 208 | if a.equals(b) or simplify(a-b) == 0: 209 | return True 210 | except: 211 | pass 212 | 213 | # print("try equation") 214 | # equation equal 215 | try: 216 | if (abs(a.lhs - a.rhs)).equals(abs(b.lhs - b.rhs)): 217 | return True 218 | except: 219 | pass 220 | 221 | try: 222 | if 
numeric_equal(float(N(a)), float(N(b))): 223 | return True 224 | except: 225 | pass 226 | 227 | # matrix 228 | try: 229 | # if a and b are matrix 230 | if a.shape == b.shape: 231 | _a = a.applyfunc(lambda x: round(x, 3)) 232 | _b = b.applyfunc(lambda x: round(x, 3)) 233 | if _a.equals(_b): 234 | return True 235 | except: 236 | pass 237 | 238 | return False 239 | 240 | 241 | def symbolic_equal_process(a, b, output_queue): 242 | result = symbolic_equal(a, b) 243 | output_queue.put(result) 244 | 245 | 246 | def call_with_timeout(func, *args, timeout=1, **kwargs): 247 | output_queue = multiprocessing.Queue() 248 | process_args = args + (output_queue,) 249 | process = multiprocessing.Process(target=func, args=process_args, kwargs=kwargs) 250 | process.start() 251 | process.join(timeout) 252 | 253 | if process.is_alive(): 254 | process.terminate() 255 | process.join() 256 | return False 257 | 258 | return output_queue.get() 259 | 260 | 261 | def _test_math_equal(): 262 | # print(math_equal("0.0833333333333333", "\\frac{1}{12}")) 263 | # print(math_equal("(1,4.5)", "(1,\\frac{9}{2})")) 264 | # print(math_equal("\\frac{x}{7}+\\frac{2}{7}", "\\frac{x+2}{7}", timeout=True)) 265 | # print(math_equal("\\sec^2(y)", "\\tan^2(y)+1", timeout=True)) 266 | # print(math_equal("\\begin{pmatrix}-\\frac{7}{4}&-2\\\\4&\\frac{1}{4}\\end{pmatrix}", "(\\begin{pmatrix}-\\frac{7}{4}&-2\\\\4&\\frac{1}{4}\\\\\\end{pmatrix})", timeout=True)) 267 | 268 | # pred = '\\begin{pmatrix}\\frac{1}{3x^{2/3}}&0&0\\\\0&1&0\\\\-\\sin(x)&0&0\\end{pmatrix}' 269 | # gt = '(\\begin{pmatrix}\\frac{1}{3\\sqrt[3]{x}^2}&0&0\\\\0&1&0\\\\-\\sin(x)&0&0\\\\\\end{pmatrix})' 270 | 271 | # pred= '-\\frac{8x^2}{9(x^2-2)^{5/3}}+\\frac{2}{3(x^2-2)^{2/3}}' 272 | # gt= '-\\frac{2(x^2+6)}{9(x^2-2)\\sqrt[3]{x^2-2}^2}' 273 | 274 | # pred = '-34x-45y+20z-100=0' 275 | # gt = '34x+45y-20z+100=0' 276 | 277 | # pred = '\\frac{100}{3}' 278 | # gt = '33.3' 279 | 280 | # pred = 
'\\begin{pmatrix}0.290243531202435\\\\0.196008371385084\\\\-0.186381278538813\\end{pmatrix}' 281 | # gt = '(\\begin{pmatrix}0.29\\\\0.196\\\\-0.186\\\\\\end{pmatrix})' 282 | 283 | # pred = '\\frac{\\sqrt{\\sqrt{11}+\\sqrt{194}}}{2\\sqrt{33}+15}' 284 | # gt = '\\frac{\\sqrt{\\sqrt{11}+\\sqrt{194}}}{15+2\\sqrt{33}}' 285 | 286 | # pred = '(+5)(b+2)' 287 | # gt = '(a+5)(b+2)' 288 | 289 | # pred = '\\frac{1+\\sqrt{5}}{2}' 290 | # gt = '2' 291 | 292 | # pred = '\\frac{34}{16}+\\frac{\\sqrt{1358}}{16}', gt = '4' 293 | # pred = '1', gt = '1\\\\sqrt{19}' 294 | 295 | pred = '(0.6,2.6667]' 296 | gt = '(\\frac{3}{5},\\frac{8}{3}]' 297 | 298 | print(math_equal(pred, gt, timeout=True)) 299 | 300 | 301 | if __name__ == "__main__": 302 | _test_math_equal() 303 | 304 | 305 | -------------------------------------------------------------------------------- /model_utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | https://github.com/allenai/open-instruct 3 | """ 4 | import torch 5 | import tqdm 6 | from transformers import StoppingCriteria, StoppingCriteriaList 7 | 8 | 9 | class KeywordsStoppingCriteria(StoppingCriteria): 10 | def __init__(self, keywords_str, tokenizer): 11 | StoppingCriteria.__init__(self) 12 | self.current_context = [] 13 | self.tokenizer = tokenizer 14 | self.keywords_str = keywords_str 15 | def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool: 16 | if len(self.current_context) == 0: 17 | self.current_context = [[] for _ in range(input_ids.shape[0])] 18 | 19 | # self.current_context.append(input_ids[0][-1].item()) 20 | sequences_should_be_stopped = [] 21 | for i in range(input_ids.shape[0]): 22 | _id = input_ids[i][-1].item() 23 | self.current_context[i].append(_id) 24 | current_context = self.tokenizer.decode(self.current_context[i]) 25 | should_be_stopped = False 26 | for word in self.keywords_str: 27 | if word in current_context: 28 | should_be_stopped = True 29 | break 
30 | sequences_should_be_stopped.append(should_be_stopped) 31 | return all(sequences_should_be_stopped) 32 | 33 | 34 | class KeyWordsCriteriaTrunc(StoppingCriteria): 35 | def __init__(self, stop_id_sequences, prompt_length): 36 | assert isinstance(stop_id_sequences[0], list), "stop_id_sequences should be a list of list of ids" 37 | self.stop_sequences = stop_id_sequences 38 | self.prompt_length = prompt_length 39 | 40 | def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool: 41 | sequences_should_be_stopped = [] 42 | for i in range(input_ids.shape[0]): 43 | ids = input_ids[i][self.prompt_length:].tolist() 44 | should_be_stopped = False 45 | for stop_sequence in self.stop_sequences: 46 | if input_ids.shape[0] == 1: 47 | _ids = ids[-len(stop_sequence):] 48 | else: 49 | _ids = ids 50 | for j in range(len(_ids), 0, -len(stop_sequence)): 51 | if _ids[max(j - len(stop_sequence), 0): j] == stop_sequence: 52 | should_be_stopped = True 53 | break 54 | if should_be_stopped: 55 | break 56 | sequences_should_be_stopped.append(should_be_stopped) 57 | return all(sequences_should_be_stopped) 58 | 59 | 60 | class KeyWordsCriteria(StoppingCriteria): 61 | def __init__(self, stop_id_sequences): 62 | assert isinstance(stop_id_sequences[0], list), "stop_id_sequences should be a list of list of ids" 63 | self.stop_sequences = stop_id_sequences 64 | 65 | def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool: 66 | sequences_should_be_stopped = [] 67 | for i in range(input_ids.shape[0]): 68 | sequence_should_be_stopped = False 69 | for stop_sequence in self.stop_sequences: 70 | if input_ids[i][-len(stop_sequence):].tolist() == stop_sequence: 71 | sequence_should_be_stopped = True 72 | break 73 | sequences_should_be_stopped.append(sequence_should_be_stopped) 74 | return all(sequences_should_be_stopped) 75 | 76 | 77 | @torch.no_grad() 78 | def generate_completions(model, tokenizer, prompts, batch_size=1, 
stop_id_sequences=None, add_special_tokens=True, disable_tqdm=False, **generation_kwargs): 79 | generations = [] 80 | if not disable_tqdm: 81 | progress = tqdm.tqdm(total=len(prompts), desc="Generating Completions") 82 | 83 | num_return_sequences = generation_kwargs.get("num_return_sequences", 1) 84 | for i in range(0, len(prompts), batch_size): 85 | batch_prompts = prompts[i:i+batch_size] 86 | tokenized_prompts = tokenizer(batch_prompts, padding="longest", return_tensors="pt", add_special_tokens=add_special_tokens) 87 | batch_input_ids = tokenized_prompts.input_ids 88 | attention_mask = tokenized_prompts.attention_mask 89 | 90 | if model.device.type == "cuda": 91 | batch_input_ids = batch_input_ids.cuda() 92 | attention_mask = attention_mask.cuda() 93 | 94 | # try: 95 | stop_criteria = KeywordsStoppingCriteria(stop_id_sequences, tokenizer) 96 | batch_outputs = model.generate( 97 | input_ids=batch_input_ids, 98 | attention_mask=attention_mask, 99 | stopping_criteria=StoppingCriteriaList([stop_criteria]), 100 | # stopping_criteria=[KeyWordsCriteria(stop_id_sequences)] if stop_id_sequences else None, 101 | # stopping_criteria=[KeyWordsCriteriaTrunc(stop_id_sequences, batch_input_ids.size(1))] if stop_id_sequences else None, 102 | **generation_kwargs 103 | ) 104 | 105 | # the stopping criteria is applied at batch level, so if other examples are not stopped, the entire batch will continue to generate. 106 | # so some outputs still have the stop sequence, which we need to remove. 
107 | # if stop_id_sequences: 108 | # for output_idx in range(batch_outputs.shape[0]): 109 | # for token_idx in range(batch_input_ids.shape[1], batch_outputs.shape[1]): 110 | # if any(batch_outputs[output_idx, token_idx: token_idx+len(stop_sequence)].tolist() == stop_sequence for stop_sequence in stop_id_sequences): 111 | # batch_outputs[output_idx, token_idx:] = tokenizer.pad_token_id 112 | # break 113 | 114 | # remove the prompt from the output 115 | # we need to re-encode the prompt because we need to make sure the special tokens are treated the same way as in the outputs. 116 | # we changed our previous way of truncating the output token ids dicrectly because some tokenizer (e.g., llama) won't add space token before the first token. 117 | # space is important for some tasks (e.g., code completion). 118 | batch_outputs = tokenizer.batch_decode(batch_outputs, skip_special_tokens=True) 119 | batch_prompts = tokenizer.batch_decode(batch_input_ids, skip_special_tokens=True) 120 | # duplicate the prompts to match the number of return sequences 121 | batch_prompts = [prompt for prompt in batch_prompts for _ in range(num_return_sequences)] 122 | batch_generations = [ 123 | output[len(prompt):] for prompt, output in zip(batch_prompts, batch_outputs) 124 | ] 125 | 126 | # remove the remain stop sequence from the output. 
127 | for idx, prediction in enumerate(batch_generations): 128 | for stop_sequence in stop_id_sequences: 129 | batch_generations[idx] = prediction.split(stop_sequence)[0] 130 | 131 | generations += batch_generations 132 | 133 | if not disable_tqdm: 134 | progress.update(len(batch_prompts)//num_return_sequences) 135 | 136 | assert len(generations) == len(prompts) * num_return_sequences, "number of generations should be equal to number of prompts * num_return_sequences" 137 | return generations 138 | 139 | 140 | def load_hf_lm_and_tokenizer( 141 | model_name_or_path, 142 | tokenizer_name_or_path=None, 143 | device_map="auto", 144 | load_in_8bit=False, 145 | load_in_half=True, 146 | gptq_model=False, 147 | use_fast_tokenizer=False, 148 | padding_side="left", 149 | use_safetensors=False, 150 | ): 151 | import torch 152 | from transformers import AutoModelForCausalLM, AutoTokenizer 153 | 154 | if not tokenizer_name_or_path: 155 | tokenizer_name_or_path = model_name_or_path 156 | tokenizer = AutoTokenizer.from_pretrained(tokenizer_name_or_path, use_fast=use_fast_tokenizer, padding_side=padding_side, trust_remote_code=True) 157 | # tokenizer = AutoTokenizer.from_pretrained(tokenizer_name_or_path, legacy=False, use_fast=use_fast_tokenizer, padding_side=padding_side, trust_remote_code=True) 158 | 159 | # set pad token to eos token if pad token is not set 160 | if tokenizer.pad_token is None: 161 | if tokenizer.unk_token: 162 | tokenizer.pad_token = tokenizer.unk_token 163 | tokenizer.pad_token_id = tokenizer.unk_token_id 164 | elif tokenizer.eos_token: 165 | tokenizer.pad_token = tokenizer.eos_token 166 | tokenizer.pad_token_id = tokenizer.eos_token_id 167 | else: 168 | raise ValueError("You are using a new tokenizer without a pad token." 
169 | "This is not supported by this script.") 170 | 171 | # if tokenizer.pad_token is None: 172 | # tokenizer.pad_token = tokenizer.unk_token 173 | # tokenizer.pad_token_id = tokenizer.unk_token_id 174 | 175 | if gptq_model: 176 | from auto_gptq import AutoGPTQForCausalLM 177 | model_wrapper = AutoGPTQForCausalLM.from_quantized( 178 | model_name_or_path, device="cuda:0", use_triton=True 179 | ) 180 | model = model_wrapper.model 181 | elif load_in_8bit: 182 | model = AutoModelForCausalLM.from_pretrained( 183 | model_name_or_path, 184 | device_map=device_map, 185 | load_in_8bit=True 186 | ) 187 | else: 188 | # return "", tokenizer 189 | # defaul load in float16 190 | model = AutoModelForCausalLM.from_pretrained(model_name_or_path, 191 | torch_dtype=torch.float16, 192 | device_map=device_map, 193 | trust_remote_code=True, 194 | use_safetensors=use_safetensors) 195 | if torch.cuda.is_available(): 196 | model = model.cuda() 197 | if load_in_half: 198 | model = model.half() 199 | model.eval() 200 | return model, tokenizer 201 | 202 | 203 | def _test_generate_completions(): 204 | model_name_or_path = "../models/codellama_7b/v1-16k" 205 | llm, tokenizer = load_hf_lm_and_tokenizer( 206 | model_name_or_path=model_name_or_path, 207 | load_in_half=True, 208 | use_fast_tokenizer=True, 209 | use_safetensors=True, 210 | ) 211 | # some math word problems 212 | prompts = [ 213 | "---\n1+1=2\n---2+2=4\n---3+3=6\n---4+4=8\n---5+5=10\n---6+6=", 214 | "---\n1+1=2\n---12+12=24\n---3+3=6\n---12345+12345=", 215 | # "A train leaves Chicago at 7am and travels at 60mph. Another train leaves Chicago at 9am and travels at 80mph. When will the second train overtake the first?", 216 | # "The sum of two numbers is 10. The difference of the same two numbers is 4. 
What are the two numbers?", 217 | ] 218 | 219 | stop_sequences = ["\n\n\n", "---"] 220 | # Because many tokenizers will treat the word after space differently from the original word alone, 221 | # to be consistent, we add a space before tokenization and remove it after tokenization. 222 | # stop_id_sequences = [tokenizer.encode(" " + x, add_special_tokens=False)[1:] for x in stop_sequences] 223 | outputs = generate_completions( 224 | model=llm, 225 | tokenizer=tokenizer, 226 | prompts=prompts, 227 | max_new_tokens=128, 228 | batch_size=16, 229 | # stop_id_sequences=stop_id_sequences, 230 | stop_id_sequences=stop_sequences, 231 | ) 232 | print(outputs) 233 | 234 | if __name__ == "__main__": 235 | _test_generate_completions() -------------------------------------------------------------------------------- /math_eval.py: -------------------------------------------------------------------------------- 1 | import random 2 | import os 3 | import argparse 4 | import time 5 | from vllm import LLM, SamplingParams 6 | from datetime import datetime 7 | from tqdm import tqdm 8 | 9 | import torch 10 | from transformers import AutoTokenizer, AutoModelForCausalLM 11 | 12 | from evaluate import evaluate 13 | from utils import set_seed, load_jsonl, save_jsonl, construct_prompt 14 | from parser import * 15 | from trajectory import * 16 | from data_loader import load_data 17 | from python_executor import PythonExecutor 18 | from model_utils import load_hf_lm_and_tokenizer, generate_completions 19 | 20 | 21 | def parse_args(): 22 | parser = argparse.ArgumentParser() 23 | parser.add_argument("--data_names", default="gsm8k,math", type=str) 24 | parser.add_argument("--data_dir", default="./data", type=str) 25 | parser.add_argument("--model_name_or_path", default="gpt-4", type=str) 26 | parser.add_argument("--output_dir", default="./output", type=str) 27 | parser.add_argument("--prompt_type", default="tool-integrated", type=str) 28 | parser.add_argument("--split", default="test", 
type=str) 29 | parser.add_argument("--num_test_sample", default=-1, type=int) # -1 for full data 30 | parser.add_argument("--seed", default=0, type=int) 31 | parser.add_argument("--start", default=0, type=int) 32 | parser.add_argument("--end", default=-1, type=int) 33 | parser.add_argument("--temperature", default=0, type=float) 34 | parser.add_argument("--n_sampling", default=1, type=int) 35 | parser.add_argument("--top_p", default=1, type=float) 36 | parser.add_argument("--max_tokens_per_call", default=1024, type=int) 37 | parser.add_argument("--shuffle", action="store_true") 38 | parser.add_argument("--use_vllm", action="store_true") 39 | parser.add_argument("--save_outputs", action="store_true") 40 | parser.add_argument("--overwrite", action="store_true") 41 | parser.add_argument("--use_safetensors", action="store_true") 42 | args = parser.parse_args() 43 | args.top_p = 1 if args.temperature == 0 else args.top_p # top_p must be 1 when using greedy sampling (vllm) 44 | return args 45 | 46 | 47 | def prepare_data(data_name, args): 48 | examples = load_data(data_name, args.split, args.data_dir) 49 | 50 | # sample `num_test_sample` from dataset 51 | if args.num_test_sample > 0: 52 | examples = random.sample(examples, args.num_test_sample) 53 | 54 | # shuffle 55 | if args.shuffle: 56 | random.shuffle(examples, seed=datetime.now().timestamp()) 57 | 58 | # select start and end 59 | examples = examples[args.start:len(examples) if args.end == -1 else args.end] 60 | 61 | # get out_file name 62 | dt_string = datetime.now().strftime("%m-%d_%H-%M") 63 | model_name = "/".join(args.model_name_or_path.split("/")[-2:]) 64 | out_file_prefix = f'{args.split}_{args.prompt_type}_{args.num_test_sample}_seed{args.seed}_t{args.temperature}' 65 | # out_file = f'{args.output_dir}/{model_name}/{data_name}/{out_file_prefix}_s{args.start}_e{args.end}_{dt_string}.jsonl' 66 | out_file = f'{args.output_dir}/{data_name}/{out_file_prefix}_s{args.start}_e{args.end}.jsonl' 67 | 
os.makedirs(f'{args.output_dir}/{data_name}', exist_ok=True) 68 | 69 | # load all processed samples 70 | processed_samples = [] 71 | if not args.overwrite: 72 | processed_files = [f for f in os.listdir(f"{args.output_dir}/{data_name}/") if f.endswith(".jsonl") and f.startswith(out_file_prefix)] 73 | for f in processed_files: 74 | processed_samples.extend(list(load_jsonl(f"{args.output_dir}/{data_name}/{f}"))) 75 | 76 | # dedepulicate 77 | processed_samples = {sample['idx']: sample for sample in processed_samples} 78 | processed_idxs = list(processed_samples.keys()) 79 | processed_samples = list(processed_samples.values()) 80 | total_examples = len(examples) 81 | examples = [example for example in examples if example['idx'] not in processed_idxs] 82 | # print(f"Idx {args.start} - {args.end}: Remain {len(examples)}/{total_examples} samples.") 83 | return examples, processed_samples, out_file 84 | 85 | 86 | def setup(args): 87 | # load model 88 | available_gpus = os.environ['CUDA_VISIBLE_DEVICES'].split(',') 89 | if args.use_vllm: 90 | llm = LLM(model=args.model_name_or_path, tensor_parallel_size=len(available_gpus), trust_remote_code=True) 91 | tokenizer = None 92 | else: 93 | llm, tokenizer = load_hf_lm_and_tokenizer( 94 | model_name_or_path=args.model_name_or_path, 95 | load_in_half=True, 96 | use_fast_tokenizer=True, 97 | use_safetensors=args.use_safetensors, 98 | ) 99 | 100 | # infer & eval 101 | data_list = args.data_names.split(',') 102 | results = [] 103 | for data_name in data_list: 104 | results.append(main(llm, tokenizer, data_name, args)) 105 | 106 | # add "avg" result to data_list and results 107 | data_list.append("avg") 108 | results.append({ 109 | "acc": sum([result["acc"] for result in results]) / len(results), 110 | }) 111 | 112 | # print all results 113 | pad = max([len(data_name) for data_name in data_list]) 114 | print("\t".join(data_name.ljust(pad, " ") for data_name in data_list)) 115 | print("\t".join([f"{result['acc']:.1f}".ljust(pad, " ") for 
result in results])) 116 | 117 | 118 | def main(llm, tokenizer, data_name, args): 119 | examples, processed_samples, out_file = prepare_data(data_name, args) 120 | print("=" * 50) 121 | print("data:", data_name, " ,remain samples:", len(examples)) 122 | if len(examples) > 0: 123 | print(examples[0]) 124 | 125 | # init python executor 126 | if "pal" in args.prompt_type: 127 | executor = PythonExecutor(get_answer_expr='solution()') 128 | else: 129 | executor = PythonExecutor(get_answer_from_stdout=True) 130 | 131 | samples = [] 132 | for example in tqdm(examples, total=len(examples)): 133 | idx = example['idx'] 134 | 135 | # parse question and answer 136 | example['question'] = parse_question(example, data_name) 137 | gt_cot, gt_ans = parse_ground_truth(example, data_name) 138 | full_prompt = construct_prompt(example, data_name, args) 139 | 140 | if idx == args.start: 141 | print(full_prompt) 142 | 143 | sample = {'idx': idx, 'question': example['question'], 'gt_cot': gt_cot, 'gt': gt_ans, 'prompt': full_prompt} 144 | 145 | # add remain fields 146 | for key in ['level', 'type', 'unit', 'solution_type', 'choices', 'solution', 'ques_type', \ 147 | 'ans_type', 'answer_type', 'dataset', 'subfield', 'filed', 'theorem', 'answer']: 148 | if key in example: 149 | sample[key] = example[key] 150 | samples.append(sample) 151 | 152 | 153 | # repeat n times 154 | input_prompts = [sample['prompt'] for sample in samples for _ in range(args.n_sampling)] 155 | remain_prompts = input_prompts 156 | remain_prompts = [(i, prompt) for i, prompt in enumerate(remain_prompts)] 157 | end_prompts = [] 158 | 159 | max_func_call = 1 if args.prompt_type in ['cot', 'pal'] else 4 160 | 161 | # stop words TODO: make it more general 162 | stop_words = [""] 163 | 164 | if args.prompt_type in ['cot']: 165 | stop_words.extend(["\n\nQuestion:", "\n\nProblem:"]) 166 | if args.prompt_type in ['pal', 'tool-integrated', 'tora']: 167 | stop_words.extend(["\n\n---", "```output"]) 168 | elif args.prompt_type in 
['wizard_zs', 'platypus_fs']: 169 | stop_words.extend(["Instruction", "Response"]) 170 | print("Stop words:", stop_words) 171 | 172 | # start inference 173 | # measure time use 174 | start_time = time.time() 175 | for epoch in range(max_func_call): 176 | print("-" * 20, "Epoch", epoch) 177 | current_prompts = remain_prompts 178 | if len(current_prompts) == 0: 179 | break 180 | 181 | # get all outputs 182 | prompts = [item[1] for item in current_prompts] 183 | if args.use_vllm: 184 | outputs = llm.generate(prompts, SamplingParams( 185 | temperature=args.temperature, 186 | top_p=args.top_p, 187 | max_tokens=args.max_tokens_per_call, 188 | n=1, 189 | stop=stop_words, 190 | )) 191 | 192 | outputs = sorted(outputs, key=lambda x: int(x.request_id)) # sort outputs by request_id 193 | outputs = [output.outputs[0].text for output in outputs] 194 | else: 195 | outputs = generate_completions( 196 | model=llm, 197 | tokenizer=tokenizer, 198 | prompts=prompts, 199 | max_new_tokens=args.max_tokens_per_call, 200 | batch_size=16, 201 | stop_id_sequences=stop_words, 202 | ) 203 | 204 | assert len(outputs) == len(current_prompts) 205 | 206 | # process all outputs 207 | remain_prompts = [] 208 | remain_codes = [] 209 | for (i, query), output in zip(current_prompts, outputs): 210 | output = output.rstrip() 211 | query += output 212 | if args.prompt_type == "pal": 213 | remain_prompts.append((i, query)) 214 | if "```python" in output: 215 | output = extract_program(query) 216 | remain_codes.append(output) 217 | elif args.prompt_type == "cot": 218 | end_prompts.append((i, query)) 219 | elif ("boxed" not in output and output.endswith("```")): 220 | program = extract_program(query) 221 | remain_prompts.append((i, query)) 222 | remain_codes.append(program) 223 | else: 224 | end_prompts.append((i, query)) 225 | 226 | # execute the remain prompts 227 | remain_results = executor.batch_apply(remain_codes) 228 | for k in range(len(remain_prompts)): 229 | i, query = remain_prompts[k] 230 | res, 
report = remain_results[k] 231 | exec_result = res if res else report 232 | if "pal" in args.prompt_type: 233 | exec_result = "\\boxed{" + exec_result + "}" 234 | exec_result = f"\n```output\n{exec_result}\n```\n" 235 | query += exec_result 236 | # not end 237 | if epoch == max_func_call - 1: 238 | query += "\nReach max function call limit." 239 | remain_prompts[k] = (i, query) 240 | 241 | # unsolved samples 242 | print("Unsolved samples:", len(remain_prompts)) 243 | end_prompts.extend(remain_prompts) 244 | # sort by idx 245 | end_prompts = sorted(end_prompts, key=lambda x: x[0]) 246 | 247 | # remove input_prompt from end_prompt 248 | codes = [] 249 | assert len(input_prompts) == len(end_prompts) 250 | for i in range(len(input_prompts)): 251 | _, end_prompt = end_prompts[i] 252 | code = end_prompt.split(input_prompts[i])[-1].strip() 253 | codes.append(code) 254 | 255 | # extract preds 256 | results = [run_execute(executor, code, args.prompt_type, data_name) for code in codes] 257 | time_use = time.time() - start_time 258 | 259 | # put results back to examples 260 | all_samples = [] 261 | for i, sample in enumerate(samples): 262 | code = codes[i*args.n_sampling: (i+1)*args.n_sampling] 263 | result = results[i*args.n_sampling: (i+1)*args.n_sampling] 264 | preds = [item[0] for item in result] 265 | reports = [item[1] for item in result] 266 | 267 | sample.pop('prompt') 268 | sample.update({'code': code, 'pred': preds, 'report': reports}) 269 | all_samples.append(sample) 270 | 271 | # add processed samples 272 | all_samples.extend(processed_samples) 273 | all_samples, result_json = evaluate(samples=all_samples, data_name=data_name, prompt_type=args.prompt_type, execute=True) 274 | 275 | # save outputs 276 | if len(processed_samples) < len(all_samples) and args.save_outputs: 277 | save_jsonl(all_samples, out_file) 278 | 279 | result_json['time_use_in_second'] = time_use 280 | result_json['time_use_in_minite'] = f"{int(time_use // 60)}:{int(time_use % 60):02d}" 281 | 282 
| with open(out_file.replace(".jsonl", f"_{args.prompt_type}_metrics.json"), "w") as f: 283 | json.dump(result_json, f, indent=4) 284 | return result_json 285 | 286 | if __name__ == "__main__": 287 | args = parse_args() 288 | set_seed(args.seed) 289 | setup(args) 290 | -------------------------------------------------------------------------------- /data/sat_math/test.jsonl: -------------------------------------------------------------------------------- 1 | {"id": "0", "question": "The graph of the polynomial function $f$, where $y=f(x)$, has $x$-intercepts of $(-6,0)$ and $(6,0)$. Which of the following must be true?", "options": "A) $f(-6)=0$ B) $f(6)=-6$ C) $f(-6)=6$ D) $f(0)=-6$", "Answer": "A"} 2 | {"id": "1", "question": "$$\\begin{gathered} y=4 x+6 \\\\-5 x-y=21\\end{gathered}$$ What is the solution $(x, y)$ to the given system of equations?", "options": "A) $(-3,-6)$ B) $\\left(-\\frac{5}{3},-\\frac{2}{3}\\right)$ C) $(3,18)$ D) $(15,66)$", "Answer": "A"} 3 | {"id": "2", "question": "$\\lvert x-10 \\rvert = 0$ What are all the possible solutions to the given equation?", "options": "A) -10 B) 0 C) 10 D) -10 and 10", "Answer": "C"} 4 | {"id": "3", "question": "$$q=s(r-1)^2$$ The given equation relates the positive numbers $q, r$, and $s$. Which equation gives $r$ in terms of $q$ and $s$, when $r>1$?", "options": "A) $r=1+\\sqrt{\\frac{q}{s}}$ B) $r=1+\\frac{\\sqrt{q}}{s}$ C) $r=-1-\\sqrt{\\frac{q}{s}}$ D) $r=-1-\\frac{\\sqrt{q}}{s}$", "Answer": "A"} 5 | {"id": "4", "question": "In the relationship between variables $x$ and $y$, each increase of $1$ in the value of $x$ decreases the value of $y$ by 2. When $x=0$, $y=5$. Which equation represents this relationship?", "options": "A) $y=-\\frac{1}{2}x+5$ B) $y=-\\frac{1}{2}x-5$ C) $y=-2x-5$ D) $y=-2x+5$", "Answer": "D"} 6 | {"id": "5", "question": "An isosceles right triangle has a hypotenuse of length 4 inches. 
What is the perimeter, in inches, of this triangle?", "options": "A) $2\\sqrt{2}$ B) $4\\sqrt{2}$ C) $4+4\\sqrt{2}$ D) $4+8\\sqrt{2}$", "Answer": "C"} 7 | {"id": "6", "question": "How many solutions does the equation $4(x-2) = -2(x+4)$ have?", "options": "A) Zero B) Exactly one C) Exactly two D) Infinitely many", "Answer": "B"} 8 | {"id": "7", "question": "$R(t) = 1,830 - 790(2.71)^{-.18t}$ The function $R$ gives the predicted average rating, expressed as a number of points, in the German chess federation database for a player based on the number of years, $t$, the player has participated in professional chess tournaments. Which of the following represents the predicted average rating of a player who has just entered their first professional chess tournament?", "options": "A) $R(-0.18)$ B) $R(0)$ C) $R(790)$ D) $R(1,830)$", "Answer": "B"} 9 | {"id": "8", "question": "Alice took 60 minutes to complete a task on her first trial. The time it took Alice to complete the task decreased by 10% of the previous time for each additional trial. 
Approximately how many minutes will it take Alice to complete the task on her fifth trial?", "options": "A) 50 B) 42 C) 39 D) 35", "Answer": "C"} 10 | {"id": "9", "question": "$$ \\begin{aligned} & y<\\frac{2}{5} x+3 \\\\& y>\\frac{1}{2} x-6\\end{aligned}$$ In which of the following tables are all the values of $x$ and their corresponding values of $y$ solutions to the system of inequalities shown?", "options": "A) \\begin{tabular}{|r|r|} \\hline$x$ & $y$ \\\\\\hline-2 & -8 \\\\\\hline 0 & -4 \\\\\\hline 4 & 4 \\\\\\hline\\end{tabular} B) \\begin{tabular}{|c|c|}\\hline$x$ & $y$ \\\\\\hline-2 & -8 \\\\\\hline 0 & 4 \\\\\\hline 4 & 4 \\\\\\hline\\end{tabular} C) \\begin{tabular}{|r|r|}\\hline$x$ & $y$ \\\\\\hline-2 & 3 \\\\\\hline 0 & 2 \\\\\\hline 4 & -3 \\\\\\hline\\end{tabular} D) \\begin{tabular}{|r|r|}\\hline$x$ & $y$ \\\\\\hline-2 & 2 \\\\\\hline 0 & -3 \\\\\\hline 4 & 3 \\\\\\hline\\end{tabular}", "Answer": "D"} 11 | {"id": "10", "question": "Which of the following is equivalent to $(\\sqrt{32})(\\sqrt[5]{64})$?", "options": "A) $6\\left(\\sqrt[7]{2^5}\\right)$ B) $6\\left(\\sqrt[10]{2^7}\\right)$ C) $8\\left(\\sqrt[7]{2^5}\\right)$ D) $8\\left(\\sqrt[10]{2^7}\\right)$", "Answer": "D"} 12 | {"id": "11", "question": "An object has a mass of 3,300 milligrams. What is the mass of the object in grams? (1 gram = 1,000 milligrams)", "options": "A) 0.33 B) 3.30 C) 33.00 D) 330.00", "Answer": "B"} 13 | {"id": "12", "question": "On average, one square inch of human skin contains 650 sweat glands. A certain area of skin contains 1,170 sweat glands. Based on this information, which of the following is closest to the size of this area, in square inches?", "options": "A) 0.44 B) 0.56 C) 0.80 D) 1.80", "Answer": "D"} 14 | {"id": "13", "question": "The table give the heights, in feet, of 5 peaks in the Rocky Mountains and 5 peaks in the Appalachian Mountains. 
\\begin{tabular}{|l|l|l|l|l|} \\hline $\\begin{array}{l}\\text { Rocky } \\\\\\text { Mountain } \\\\\\text { Peak }\\end{array}$ & $\\begin{array}{l}\\text { Height } \\\\\\text { (in feet) }\\end{array}$ & $\\begin{array}{l}\\text { Appalachian } \\\\\\text { Mountain } \\\\\\text { Peak }\\end{array}$ & $\\begin{array}{l}\\text { Height } \\\\\\text { (in feet) }\\end{array}$ \\\\\\hline $\\begin{array}{l}\\text { Mount } \\\\\\text { Elbert }\\end{array}$ & 14,439 & $\\begin{array}{l}\\text { Mount } \\\\\\text { Mitchell }\\end{array}$ & 6,684 \\\\\\hline $\\begin{array}{l}\\text { Mount } \\\\\\text { Massive }\\end{array}$ & 14,429 & Mount Craig & 6,647 \\\\\\hline $\\begin{array}{l}\\text { Mount } \\\\\\text { Harvard }\\end{array}$ & 14,419 & $\\begin{array}{l}\\text { Clingman's } \\\\\\text { Dome }\\end{array}$ & 6,643 \\\\\\hline $\\begin{array}{l}\\text { Blanca } \\\\\\text { Peak }\\end{array}$ & 14,350 & $\\begin{array}{l}\\text { Mount } \\\\\\text { Guyot }\\end{array}$ & 6,621 \\\\\\hline $\\begin{array}{l}\\text { La Plata } \\\\\\text { Peak }\\end{array}$ & 14,343 & $\\begin{array}{l}\\text { Balsam } \\\\\\text { Cone }\\end{array}$ & 6,611 \\\\\\hline\\end{tabular} What is the height, in meters, of Blanca Peak? (Use 1 meter $=3.28$ feet)", "options": "A) 437.5 B) 4,375 C) 47,045 D) 47,068", "Answer": "B"} 15 | {"id": "14", "question": "The table give the heights, in feet, of 5 peaks in the Rocky Mountains and 5 peaks in the Appalachian Mountains. 
\\begin{tabular}{|l|l|l|l|l|} \\hline $\\begin{array}{l}\\text { Rocky } \\\\\\text { Mountain } \\\\\\text { Peak }\\end{array}$ & $\\begin{array}{l}\\text { Height } \\\\\\text { (in feet) }\\end{array}$ & $\\begin{array}{l}\\text { Appalachian } \\\\\\text { Mountain } \\\\\\text { Peak }\\end{array}$ & $\\begin{array}{l}\\text { Height } \\\\\\text { (in feet) }\\end{array}$ \\\\\\hline $\\begin{array}{l}\\text { Mount } \\\\\\text { Elbert }\\end{array}$ & 14,439 & $\\begin{array}{l}\\text { Mount } \\\\\\text { Mitchell }\\end{array}$ & 6,684 \\\\\\hline $\\begin{array}{l}\\text { Mount } \\\\\\text { Massive }\\end{array}$ & 14,429 & Mount Craig & 6,647 \\\\\\hline $\\begin{array}{l}\\text { Mount } \\\\\\text { Harvard }\\end{array}$ & 14,419 & $\\begin{array}{l}\\text { Clingman's } \\\\\\text { Dome }\\end{array}$ & 6,643 \\\\\\hline $\\begin{array}{l}\\text { Blanca } \\\\\\text { Peak }\\end{array}$ & 14,350 & $\\begin{array}{l}\\text { Mount } \\\\\\text { Guyot }\\end{array}$ & 6,621 \\\\\\hline $\\begin{array}{l}\\text { La Plata } \\\\\\text { Peak }\\end{array}$ & 14,343 & $\\begin{array}{l}\\text { Balsam } \\\\\\text { Cone }\\end{array}$ & 6,611 \\\\\\hline\\end{tabular} For the given Appalachian Mountain peaks, the height of the highest peak is approximately what percent greater than the height of the lowest peak?", "options": "A) $1.1 \\%$ B) $9.9 \\%$ C) $73.0 \\%$ D) $101.1 \\%$", "Answer": "A"} 16 | {"id": "15", "question": "Data set $A: 2,4,6,6,8,12$ Data set B: $2,4,6,6,8,12,26$ Two data sets are shown. 
Which statement best compares the medians of the data sets?", "options": "A) The median of data set A is greater than the median of data set $B$ B) The median of data set A is less than the median of data set B C) The medians of data sets A and B are equal D) There is not enough information to compare the medians", "Answer": "C"} 17 | {"id": "16", "question": "$$0.79 x+1.0 y=100$$ The mass of a solution of isopropanol and water is 100 grams. The given equation represents this situation, where $x$ is the volume of isopropanol, in cubic centimeters, and $y$ is the volume of water, in cubic centimeters. If the volume of isopropanol is 70 cubic centimeters, what is the approximate volume of water, in cubic centimeters?", "options": "A) 45 B) 55 C) 70 D) 79", "Answer": "A"} 18 | {"id": "17", "question": "There are 435 voting members of the US House of Representatives. If $b$ voting members are in favor of a certain bill, which expression represents the percentage of the voting members in favor of the bill?", "options": "A. $100\\left(\\frac{b}{435}\\right)$ B. $100\\left(\\frac{435}{b}\\right)$ C. $435\\left(\\frac{b}{100}\\right)$ D. $435(100 b)$", "Answer": "A"} 19 | {"id": "18", "question": "$$10(x+120)=120$$ Which of the following equations has the same solution as the given equation?", "options": "A) $x+120=12$ B) $x+120=130$ C) $x+12=12$ D) $x+12=120$", "Answer": "A"} 20 | {"id": "19", "question": "The given function $C$ models the annual soybean use in China, in millions of metric tons, between 1995 and 2014, where $x$ is the number of years after 1995. 
$$C(x)=4.3 x+19$$ According to the model, what is the best interpretation of 4.3 in this context?", "options": "A) Each year between 1995 and 2014, China used 4.3 million metric tons of soybeans B) Each year between 1995 and 2014, China's annual use of soybeans increased by 4.3 million metric tons C) China used 4.3 million metric tons of soybeans in 1995 D) China used a total of 4.3 million metric tons of soybeans between 1995 and 2014", "Answer": "B"} 21 | {"id": "20", "question": "$$ \\begin{gathered} C(x)=50,000+0.75 x \\\\ R(x)=4.75 x \\end{gathered}$$ The given function $C$ models the total cost (sum of fixed cost and variable cost), in dollars, of growing and harvesting $x$ bales of hay on a certain farm. The given function $R$ models the revenue, in dollars, earned from selling $x$ bales of hay. According to the function $R$, how many bales of hay would have to be sold to earn a revenue of $\\$1,425$?", "options": "A) 100 B) 300 C) 500 D) 1,000", "Answer": "B"} 22 | {"id": "21", "question": "$$ \\begin{gathered} C(x)=50,000+0.75 x \\\\ R(x)=4.75 x \\end{gathered}$$ The given function $C$ models the total cost (sum of fixed cost and variable cost), in dollars, of growing and harvesting $x$ bales of hay on a certain farm. The given function $R$ models the revenue, in dollars, earned from selling $x$ bales of hay. Which of the following inequalities models the number of bales of hay that must be sold to earn a profit of $\\$ 10,000$ or more? 
(profit $=$ revenue - cost)", "options": "A) $10,000 \\leq 4 x-50,000$ B) $10,000 \\geq 4 x-50,000$ C) $10,000 \\leq 4 x+50,000$ D) $10,000 \\geq 4 x+50,000$", "Answer": "A"} 23 | {"id": "22", "question": "Which expression is equivalent to $\\left(x^2+4\\right)^2+(x-2)(x+2) ?$", "options": "A) $x^4+x^2+20$ B) $x^4+5 x^2+16$ C) $x^4+9 x^2$ D) $x^4+9 x^2+12$", "Answer": "D"} 24 | {"id": "23", "question": "$$ \\begin{aligned} & y=4 x+1 \\\\ & y=4 x+3 \\end{aligned}$$ How many solutions does the given system of equations have?", "options": "A) Zero B) Exactly one C) Exactly two D) Infinitely many", "Answer": "A"} 25 | {"id": "24", "question": "$$ h(x)=3 x+3 $$ Which inequality represents all values of $x$ for which the graph of $y=h(x)$ in the $x y$-plane is above the $x$-axis?", "options": "A) $x<3$ B) $x<-1$ C) $x>-1$ D) $x>3$", "Answer": "C"} 26 | {"id": "25", "question": "Which quadratic equation has no real solutions?", "options": "A) $3 x^2-3=0$ B) $3 x^2+3 x=0$ C) $3 x^2+3 x+3=0$ D) $3 x^2-6 x+3=0$", "Answer": "C"} 27 | {"id": "26", "question": "In 1976, there were approximately 1,000 gray wolves in northern Minnesota. The number of gray wolves in northern Minnesota in 2008 was 190% greater than in 1976. Approximately how many gray wolves were in northern Minnesota in 2008?", "options": "A. 1,190 B. 1,900 C. 2,900 D. 19,000", "Answer": "C"} 28 | {"id": "27", "question": "When the quadratic function $f$ is graphed in the $x y$-plane, where $y=f(x)$, its vertex is $(-2,5)$. One of the $x$-intercepts of this graph is $\\left(-\\frac{7}{3}, 0\\right)$. What is the other $x$-intercept of the graph?", "options": "A. $\\left(-\\frac{13}{3}, 0\\right)$ B. $\\left(-\\frac{5}{3}, 0\\right)$ C. $\\left(\\frac{1}{3}, 0\\right)$ D. $\\left(\\frac{7}{3}, 0\\right)$", "Answer": "B"} 29 | {"id": "28", "question": "For an exponential function $g$, the value of $g(x)$ decreases by $20 \\%$ for each 1-unit increase in the value of $x$. 
If $g(2)=16$, which equation could define $g$ ?", "options": "A) $g(x)=16(0.8)^{x-2}$ B) $g(x)=16(0.8)^{x+2}$ C) $g(x)=16(0.2)^{x-2}$ D) $g(x)=16(0.2)^{x+2}$", "Answer": "A"} 30 | {"id": "29", "question": "Micha and Rana each selected a random sample of students at their school and asked how many soft drink servings each student had consumed the previous week. Micha estimated that the mean number of soft drink servings was 7.1, with an associated margin of error of 1.2. Rana estimated that the mean number of soft drink servings was 8.3, with an associated margin of error of 0.8. Assuming the margins of error were calculated in the same way, which of the following best explains why Rana obtained a smaller margin of error than Micha?", "options": "A. Rana's sample contained more students than Micha's sample contained. B. Rana's sample contained more students who drank soft drinks than Micha's sample contained. C. Rana's sample contained more students who drank exactly seven soft drink servings than Micha's sample contained. D. Rana's sample contained more students who drank exactly eight soft drink servings than Micha's sample contained.", "Answer": "A"} 31 | {"id": "30", "question": "A circle in the $x y$-plane has its center at $(-3,4)$ and the point $(-2,1)$ lies on the circle. Which equation represents this circle?", "options": "A) $(x-3)^2+(y+4)^2=\\sqrt{10}$ B) $(x+3)^2+(y-4)^2=\\sqrt{10}$ C) $(x-3)^2+(y+4)^2=10$ D) $(x+3)^2+(y-4)^2=10$", "Answer": "D"} 32 | {"id": "31", "question": "\\begin{tabular}{|c|c|} \\hline$x$ & $h(x)$ \\\\\\hline 2 & 0 \\\\\\hline 4 & 0 \\\\\\hline 6 & 8 \\\\\\hline \\end{tabular} For the quadratic function $h$, the table gives three values of $x$ and their corresponding values of $h(x)$. 
At what value of $x$ does $h$ reach its minimum?", "options": "A) -1 B) 0 C) 3 D) 4", "Answer": "C"} 33 | -------------------------------------------------------------------------------- /parser.py: -------------------------------------------------------------------------------- 1 | import re 2 | import regex 3 | import sympy 4 | from typing import TypeVar, Iterable, List, Union, Any, Dict 5 | from word2number import w2n 6 | from utils import * 7 | 8 | 9 | def _fix_fracs(string): 10 | substrs = string.split("\\frac") 11 | new_str = substrs[0] 12 | if len(substrs) > 1: 13 | substrs = substrs[1:] 14 | for substr in substrs: 15 | new_str += "\\frac" 16 | if len(substr) > 0 and substr[0] == "{": 17 | new_str += substr 18 | else: 19 | try: 20 | assert len(substr) >= 2 21 | except: 22 | return string 23 | a = substr[0] 24 | b = substr[1] 25 | if b != "{": 26 | if len(substr) > 2: 27 | post_substr = substr[2:] 28 | new_str += "{" + a + "}{" + b + "}" + post_substr 29 | else: 30 | new_str += "{" + a + "}{" + b + "}" 31 | else: 32 | if len(substr) > 2: 33 | post_substr = substr[2:] 34 | new_str += "{" + a + "}" + b + post_substr 35 | else: 36 | new_str += "{" + a + "}" + b 37 | string = new_str 38 | return string 39 | 40 | 41 | def _fix_a_slash_b(string): 42 | if len(string.split("/")) != 2: 43 | return string 44 | a = string.split("/")[0] 45 | b = string.split("/")[1] 46 | try: 47 | if "sqrt" not in a: 48 | a = int(a) 49 | if "sqrt" not in b: 50 | b = int(b) 51 | assert string == "{}/{}".format(a, b) 52 | new_string = "\\frac{" + str(a) + "}{" + str(b) + "}" 53 | return new_string 54 | except: 55 | return string 56 | 57 | 58 | def _fix_sqrt(string): 59 | _string = re.sub(r"\\sqrt(\w+)", r"\\sqrt{\1}", string) 60 | return _string 61 | 62 | 63 | def convert_word_number(text:str) -> str: 64 | try: 65 | text = str(w2n.word_to_num(text)) 66 | except: 67 | pass 68 | return text 69 | 70 | # units mainly from MathQA 71 | unit_texts = [ 72 | "east", "degree", "mph", "kmph", "ft", 
"m sqaure", " m east", "sq m", "deg", "mile", 73 | "q .", "monkey", "prime", "ratio", "profit of rs", "rd", "o", "gm", 74 | "p . m", "lb", "tile", "per", "dm", "lt", "gain", "ab", "way", "west", 75 | "a .", "b .", "c .", "d .", "e .", "f .", "g .", "h .", "t", "a", "h", 76 | "no change", "men", "soldier", "pie", "bc", "excess", "st", 77 | "inches", "noon", "percent", "by", "gal", "kmh", "c", "acre", "rise", 78 | "a . m", "th", "π r 2", "sq", "mark", "l", "toy", "coin", 79 | "sq . m", "gallon", "° f", "profit", "minw", "yr", "women", 80 | "feet", "am", "pm", "hr", "cu cm", "square", "v â € ™", "are", 81 | "rupee", "rounds", "cubic", "cc", "mtr", "s", "ohm", "number", 82 | "kmph", "day", "hour", "minute", "min", "second", "man", "woman", 83 | "sec", "cube", "mt", "sq inch", "mp", "∏ cm ³", "hectare", "more", 84 | "sec", "unit", "cu . m", "cm 2", "rs .", "rs", "kg", "g", "month", 85 | "km", "m", "cm", "mm", "apple", "liter", "loss", "yard", 86 | "pure", "year", "increase", "decrease", "d", "less", "Surface", 87 | "litre", "pi sq m", "s .", "metre", "meter", "inch", 88 | ] 89 | 90 | unit_texts.extend([t + "s" for t in unit_texts]) 91 | 92 | def strip_string(string): 93 | string = str(string).strip() 94 | # linebreaks 95 | string = string.replace("\n", "") 96 | 97 | # right "." 
98 | string = string.rstrip(".") 99 | 100 | # remove inverse spaces 101 | # replace \\ with \ 102 | string = string.replace("\\!", "") 103 | # string = string.replace("\\ ", "") 104 | # string = string.replace("\\\\", "\\") 105 | 106 | # matrix 107 | string = re.sub(r'\\begin\{array\}\{.*?\}', r'\\begin{pmatrix}', string) 108 | string = re.sub(r'\\end\{array\}', r'\\end{pmatrix}', string) 109 | string = string.replace("bmatrix", "pmatrix") 110 | 111 | 112 | # replace tfrac and dfrac with frac 113 | string = string.replace("tfrac", "frac") 114 | string = string.replace("dfrac", "frac") 115 | 116 | # remove \left and \right 117 | string = string.replace("\\left", "") 118 | string = string.replace("\\right", "") 119 | string = string.replace("\\{", "{") 120 | string = string.replace("\\}", "}") 121 | 122 | # Remove unit: miles, dollars if after is not none 123 | _string = re.sub(r"\\text{.*?}$", "", string).strip() 124 | if _string != "" and _string != string: 125 | # print("Warning: unit not removed: '{}' -> '{}'".format(string, _string)) 126 | string = _string 127 | 128 | # Remove unit: texts 129 | for _ in range(2): 130 | for unit_text in unit_texts: 131 | # use regex, the prefix should be either the start of the string or a non-alphanumeric character 132 | # the suffix should be either the end of the string or a non-alphanumeric character 133 | _string = re.sub(r"(^|\W)" + unit_text + r"($|\W)", r"\1\2", string) 134 | if _string != "": 135 | string = _string 136 | 137 | # Remove circ (degrees) 138 | string = string.replace("^{\\circ}", "") 139 | string = string.replace("^\\circ", "") 140 | 141 | # remove dollar signs 142 | string = string.replace("\\$", "") 143 | string = string.replace("$", "") 144 | 145 | # convert word number to digit 146 | string = convert_word_number(string) 147 | 148 | # replace "\\text{...}" to "..." 
149 | string = re.sub(r"\\text\{(.*?)\}", r"\1", string) 150 | for key in ['x=', 'y=', 'z=', 'x\\in', 'y\\in', 'z\\in', 'x\\to', 'y\\to', 'z\\to']: 151 | string = string.replace(key, "") 152 | string = string.replace("\\emptyset", r"{}") 153 | string = string.replace("(-\\infty,\\infty)", "\\mathbb{R}") 154 | 155 | # remove percentage 156 | string = string.replace("\\%", "") 157 | string = string.replace("\%", "") 158 | string = string.replace("%", "") 159 | 160 | # " 0." equivalent to " ." and "{0." equivalent to "{." Alternatively, add "0" if "." is the start of the string 161 | string = string.replace(" .", " 0.") 162 | string = string.replace("{.", "{0.") 163 | 164 | # cdot 165 | # string = string.replace("\\cdot", "") 166 | if string.startswith("{") and string.endswith("}") and string.isalnum() or \ 167 | string.startswith("(") and string.endswith(")") and string.isalnum() or \ 168 | string.startswith("[") and string.endswith("]") and string.isalnum(): 169 | string = string[1:-1] 170 | 171 | # inf 172 | string = string.replace("infinity", "\\infty") 173 | if "\\infty" not in string: 174 | string = string.replace("inf", "\\infty") 175 | string = string.replace("+\\inity", "\\infty") 176 | 177 | # and 178 | string = string.replace("and", "") 179 | string = string.replace("\\mathbf", "") 180 | 181 | # use regex to remove \mbox{...} 182 | string = re.sub(r"\\mbox{.*?}", "", string) 183 | 184 | # quote 185 | string.replace("'", "") 186 | string.replace("\"", "") 187 | 188 | # i, j 189 | if "j" in string and "i" not in string: 190 | string = string.replace("j", "i") 191 | 192 | # replace a.000b where b is not number or b is end, with ab, use regex 193 | string = re.sub(r"(\d+)\.0*([^\d])", r"\1\2", string) 194 | string = re.sub(r"(\d+)\.0*$", r"\1", string) 195 | 196 | # if empty, return empty string 197 | if len(string) == 0: 198 | return string 199 | if string[0] == ".": 200 | string = "0" + string 201 | 202 | # to consider: get rid of e.g. 
"k = " or "q = " at beginning 203 | if len(string.split("=")) == 2: 204 | if len(string.split("=")[0]) <= 2: 205 | string = string.split("=")[1] 206 | 207 | string = _fix_sqrt(string) 208 | string = string.replace(" ", "") 209 | 210 | # \frac1b or \frac12 --> \frac{1}{b} and \frac{1}{2}, etc. Even works with \frac1{72} (but not \frac{72}1). Also does a/b --> \\frac{a}{b} 211 | string = _fix_fracs(string) 212 | 213 | # NOTE: X/Y changed to \frac{X}{Y} in dataset, but in simple cases fix in case the model output is X/Y 214 | string = _fix_a_slash_b(string) 215 | 216 | return string 217 | 218 | 219 | def extract_multi_choice_answer(pred_str): 220 | # TODO: SFT models 221 | if 'Problem:' in pred_str: 222 | pred_str = pred_str.split("Problem:", 1)[0] 223 | pred_str = pred_str.replace("choice is", "answer is") 224 | patt = regex.search(r"answer is \(?(?P[abcde])\)?", pred_str.lower()) 225 | if patt is not None: 226 | return patt.group('ans').upper() 227 | return 'placeholder' 228 | 229 | 230 | def extract_answer(pred_str, data_name): 231 | if data_name in ["mmlu_stem", "sat_math", "mathqa"]: 232 | return extract_multi_choice_answer(pred_str) 233 | 234 | if 'final answer is $' in pred_str and '$. I hope' in pred_str: 235 | # minerva_math 236 | tmp = pred_str.split('final answer is $', 1)[1] 237 | pred = tmp.split('$. 
I hope', 1)[0].strip() 238 | elif 'boxed' in pred_str: 239 | ans = pred_str.split('boxed')[-1] 240 | if len(ans) == 0: 241 | return "" 242 | elif ans[0] == '{': 243 | stack = 1 244 | a = '' 245 | for c in ans[1:]: 246 | if (c == '{'): 247 | stack += 1 248 | a += c 249 | elif (c == '}'): 250 | stack -= 1 251 | if (stack == 0): break 252 | a += c 253 | else: 254 | a += c 255 | else: 256 | a = ans.split('$')[0].strip() 257 | pred = a 258 | elif ('he answer is' in pred_str): 259 | pred = pred_str.split('he answer is')[-1].strip() 260 | elif ('final answer is' in pred_str): 261 | pred = pred_str.split('final answer is')[-1].strip() 262 | # elif extract_program_output(pred_str) != "": 263 | # fall back to program 264 | # pred = extract_program_output(pred_str) 265 | else: # use the last number 266 | pattern = '-?\d*\.?\d+' 267 | pred = re.findall(pattern, pred_str.replace(",", "")) 268 | if(len(pred) >= 1): 269 | pred = pred[-1] 270 | else: pred = '' 271 | 272 | # multiple line 273 | # pred = pred.split("\n")[0] 274 | pred = re.sub(r"\n\s*", "", pred) 275 | if pred != "" and pred[0] == ":": 276 | pred = pred[1:] 277 | if pred != "" and pred[-1] == ".": 278 | pred = pred[:-1] 279 | if pred != "" and pred[-1] == "/": 280 | pred = pred[:-1] 281 | pred = strip_string(pred) 282 | return pred 283 | 284 | 285 | def parse_ground_truth(example: Dict[str, Any], data_name): 286 | if 'gt_cot' in example and 'gt' in example: 287 | if data_name in ["math", "math_oai", "ocw", "amps", "hungarian_exam"]: 288 | gt_ans = extract_answer(example['gt_cot'], data_name) 289 | else: 290 | gt_ans = strip_string(example['gt']) 291 | return example['gt_cot'], gt_ans 292 | 293 | # parse ground truth 294 | if data_name in ["math", "math_oai", "minerva_math", "ocw", "amps", "hungarian_exam"]: 295 | gt_cot = example['solution'] 296 | gt_ans = extract_answer(gt_cot, data_name) 297 | elif data_name in ['mathqa']: 298 | gt_cot = example['rationale'] 299 | gt_ans = example['correct'].upper() 300 | assert 
gt_ans in ['A', 'B', 'C', 'D', 'E'] 301 | elif data_name == "gsm8k": 302 | gt_cot, gt_ans = example['answer'].split("####") 303 | elif data_name == "gsm_hard": 304 | gt_cot, gt_ans = example['code'], example['target'] 305 | elif data_name == "svamp": 306 | gt_cot, gt_ans = example['Equation'], example['Answer'] 307 | elif data_name == "asdiv": 308 | gt_cot = example['formula'] 309 | gt_ans = re.sub(r"\(.*?\)", "", example['answer']) 310 | elif data_name == "mawps": 311 | gt_cot, gt_ans = None, example['target'] 312 | elif data_name == "tabmwp": 313 | gt_cot = example['solution'] 314 | gt_ans = example['answer'] 315 | if example['ans_type'] in ['integer_number', 'decimal_number']: 316 | if '/' in gt_ans: 317 | gt_ans = int(gt_ans.split('/')[0]) / int(gt_ans.split('/')[1]) 318 | elif ',' in gt_ans: 319 | gt_ans = float(gt_ans.replace(',', '')) 320 | elif '%' in gt_ans: 321 | gt_ans = float(gt_ans.split('%')[0]) / 100 322 | else: 323 | gt_ans = float(gt_ans) 324 | elif data_name == "bbh": 325 | gt_cot, gt_ans = None, example['target'] 326 | elif data_name == "theorem_qa": 327 | gt_cot, gt_ans = None, example['answer'] 328 | elif data_name == "mmlu_stem": 329 | abcd = 'ABCD' 330 | gt_cot, gt_ans = None, abcd[example['answer']] 331 | elif data_name == "sat_math": 332 | gt_cot, gt_ans = None, example['Answer'] 333 | else: 334 | raise NotImplementedError(f"`{data_name}`") 335 | # post process 336 | gt_cot = str(gt_cot).strip() 337 | gt_ans = strip_string(gt_ans) 338 | return gt_cot, gt_ans 339 | 340 | 341 | def parse_question(example, data_name): 342 | question = "" 343 | if data_name == "asdiv": 344 | question = f"{example['body'].strip()} {example['question'].strip()}" 345 | elif data_name == "svamp": 346 | body = example["Body"].strip() 347 | if not body.endswith("."): 348 | body = body + "." 
349 | question = f'{body} {example["Question"].strip()}' 350 | elif data_name == "tabmwp": 351 | title_str = f'regarding "{example["table_title"]}" ' if example['table_title'] else "" 352 | question = f'Read the following table {title_str}and answer a question:\n' 353 | question += f'{example["table"]}\n{example["question"]}' 354 | if example['choices']: 355 | question += f' Please select from the following options: {example["choices"]}' 356 | elif data_name == "theorem_qa": 357 | question = f"{example['question'].strip()}\nTheorem: {example['theorem_def'].strip()}" 358 | elif data_name == "mmlu_stem": 359 | options = example['choices'] 360 | assert len(options) == 4 361 | for i, (label, option) in enumerate(zip('ABCD', options)): 362 | options[i] = f"({label}) {str(option).strip()}" 363 | options = ", ".join(options) 364 | question = f"{example['question'].strip()}\nWhat of the following is the right choice? Explain your answer.\n{options}" 365 | elif data_name == "sat_math": 366 | options = example['options'].strip() 367 | assert 'A' == options[0] 368 | options = '(' + options 369 | for ch in 'BCD': 370 | if f' {ch}) ' in options: 371 | options = regex.sub(f' {ch}\) ', f" ({ch}) ", options) 372 | question = f"{example['question'].strip()}\nWhat of the following is the right choice? Explain your answer.\n{options.strip()}" 373 | elif data_name == "mathqa": 374 | example['problem'] = example['problem'][0].upper() + example['problem'][1:] 375 | options = example['options'].strip() 376 | if options[0] == '[': 377 | options = eval(options) 378 | options = ", ".join(options) 379 | assert 'a' == options[0], options 380 | for ch in 'abcde': 381 | if f'{ch} ) ' in options: 382 | options = regex.sub(f'{ch} \) {ch} \) ', f'{ch} ) ', options) 383 | options = regex.sub(f'{ch} \) ', f"({ch.upper()}) ", options) 384 | options = options.replace(' , ', ', ') 385 | question = f"{example['problem'].strip()}\nWhat of the following is the right choice? 
Explain your answer.\n{options.strip()}" 386 | else: 387 | for key in ['question', 'problem', 'Question', 'input']: 388 | if key in example: 389 | question = example[key] 390 | break 391 | assert question != "" 392 | # Yes or No question 393 | _, gt_ans = parse_ground_truth(example, data_name) 394 | gt_lower = gt_ans.lower() 395 | if gt_lower in ["true", "false"]: 396 | question += " (True or False)" 397 | if gt_lower in ["yes", "no"]: 398 | question += " (Yes or No)" 399 | return question.strip() 400 | 401 | 402 | def run_execute(executor, result, prompt_type, data_name, execute=False): 403 | if not result or result == 'error': 404 | return None, None 405 | report = None 406 | 407 | if "program_only" in prompt_type: 408 | prediction = extract_program_output(result) 409 | elif prompt_type in ["pot", "pal"] and execute: 410 | code = extract_program(result) 411 | prediction, report = executor.apply(code) 412 | else: 413 | prediction = extract_answer(result, data_name) 414 | 415 | prediction = strip_string(prediction) 416 | return prediction, report 417 | 418 | 419 | def _test_extract_answer(): 420 | text= """ 421 | The answer is $\\boxed{\left( 422 | \\begin{array}{ccc} 423 | -13 & 4 & -2 \\\\ 424 | 7 & 8 & -3 \\\\ 425 | 0 & 18 & -7 \\\\ 426 | 6 & 12 & 5 \\\\ 427 | \\end{array} 428 | \\right)}$. 429 | """ 430 | print(extract_answer(text, "math")) 431 | # should output a dict 432 | 433 | 434 | if __name__ == "__main__": 435 | _test_extract_answer() --------------------------------------------------------------------------------