├── .gitignore ├── LICENSE-CODE ├── LICENSE-MODEL ├── README.md ├── cog.yaml ├── evaluation ├── README.md ├── configs │ ├── few_shot_test_configs.json │ └── zero_shot_test_configs.json ├── data_processing │ ├── answer_extraction.py │ └── process_utils.py ├── datasets │ ├── agieval │ │ ├── gaokao-mathcloze.jsonl │ │ └── gaokao-mathqa.jsonl │ ├── cmath │ │ └── test.jsonl │ ├── gsm8k │ │ └── test.jsonl │ ├── math │ │ └── test.jsonl │ ├── mgsm_zh │ │ └── mgsm_zh.jsonl │ ├── minif2f │ │ ├── test.jsonl │ │ └── validation.jsonl │ ├── mmlu_stem │ │ └── test.jsonl │ ├── ocw │ │ └── test.jsonl │ └── sat │ │ └── test.jsonl ├── environment.yml ├── eval │ ├── eval_script.py │ ├── eval_utils.py │ ├── ocwcourses_eval_utils.py │ ├── python_executor.py │ └── utils.py ├── evaluation_results.json ├── few_shot_prompts │ ├── __init__.py │ ├── cot_cmath_6_shot.py │ ├── cot_gaokao_mathcloze_5_shot.py │ ├── cot_gaokao_mathqa_5_shot.py │ ├── cot_gsm_8_shot.py │ ├── cot_math_sat_4_shot.py │ ├── cot_minerva_math_4_shot.py │ ├── cot_mmlu_stem_4_shot.py │ ├── cot_ocwcourses_4_shot.py │ ├── few_shot_prompting.py │ ├── minif2f_isabelle.py │ ├── pal_gsm_8_shot.py │ └── pal_math_4_shot.py ├── infer │ ├── run_cot_eval.py │ ├── run_pal_eval.py │ └── run_tool_integrated_eval.py ├── outputs.zip ├── run_subset_parallel.py ├── submit_eval_jobs.py ├── summarize_results.py ├── unsafe_score_minif2f_isabelle.py └── utils.py ├── images ├── badge.svg ├── base_results_1.png ├── base_results_2.png ├── base_results_3.png ├── corpus.png ├── data_pipeline.png ├── instruct_results.png ├── logo.svg ├── math.png └── qr.jpeg └── replicate ├── predict.py └── predict_instruct.py /.gitignore: -------------------------------------------------------------------------------- 1 | *__pycache__ 2 | -------------------------------------------------------------------------------- /LICENSE-CODE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 DeepSeek 4 | 
5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 |
7 | DeepSeek LLM 8 |
9 |
10 |
11 | 12 | 13 | Homepage 14 | 15 | 16 | Chat 17 | 18 | 19 | Hugging Face 20 | 21 | Replicate 22 |
23 | 24 |
25 | 26 | 27 | Discord 28 | 29 | 30 | Wechat 31 | 32 | 33 | Twitter Follow 34 | 35 | 36 |
37 | 38 |
39 | 40 | 41 | Code License 42 | 43 | 44 | Model License 45 | 46 |
47 | 48 | 49 |

50 | Model Download | 51 | Evaluation Results | 52 | Quick Start | 53 | License | 54 | Citation 55 |

56 | 57 |

58 | Paper Link👁️ 59 |

60 | 61 | 62 | ## 1. Introduction 63 | 64 | DeepSeekMath is initialized with [DeepSeek-Coder-v1.5 7B](https://huggingface.co/deepseek-ai/deepseek-coder-7b-base-v1.5) and continues pre-training on math-related tokens sourced from Common Crawl, together with natural language and code data for 500B tokens. DeepSeekMath 7B has achieved an impressive score of **51.7%** on the competition-level MATH benchmark without relying on external toolkits and voting techniques, approaching the performance level of Gemini-Ultra and GPT-4. For research purposes, we release [checkpoints](#4-model-downloads) of base, instruct, and RL models to the public. 65 | 66 |

67 | table 68 |

69 | 70 | ## 2. Evaluation Results 71 | 72 | ### DeepSeekMath-Base 7B 73 | 74 | We conduct a comprehensive assessment of the mathematical capabilities of DeepSeekMath-Base 7B, focusing on its ability to produce self-contained mathematical solutions without relying on external tools, solve math problems using tools, and conduct formal theorem proving. Beyond mathematics, we also provide a more general profile of the base model, including its performance of natural language understanding, reasoning, and programming skills. 75 | 76 | - **Mathematical problem solving with step-by-step reasoning** 77 | 78 |

79 | table 80 |

81 | 82 | - **Mathematical problem solving with tool use** 83 | 84 |

85 | table 86 |

87 | 88 | - **Natural Language Understanding, Reasoning, and Code** 89 |

90 | table 91 |

92 | 93 | The evaluation results from the tables above can be summarized as follows: 94 | - **Superior Mathematical Reasoning:** On the competition-level MATH dataset, DeepSeekMath-Base 7B outperforms existing open-source base models by more than 10% in absolute terms through few-shot chain-of-thought prompting, and also surpasses Minerva 540B. 95 | - **Strong Tool Use Ability:** Continuing pre-training from DeepSeekCoder-Base-7B-v1.5 enables DeepSeekMath-Base 7B to more effectively solve and prove mathematical problems by writing programs. 96 | - **Comparable Reasoning and Coding Performance:** DeepSeekMath-Base 7B achieves performance in reasoning and coding that is comparable to that of DeepSeekCoder-Base-7B-v1.5. 97 | 98 | ### DeepSeekMath-Instruct and -RL 7B 99 | 100 | DeepSeekMath-Instruct 7B is an instruction-tuned model for mathematics derived from DeepSeekMath-Base 7B, while DeepSeekMath-RL 7B is trained on the foundation of DeepSeekMath-Instruct 7B, utilizing our proposed Group Relative Policy Optimization (GRPO) algorithm. 101 | 102 | We evaluate mathematical performance both without and with tool use, on 4 quantitative reasoning benchmarks in English and Chinese. As shown in the table below, DeepSeekMath-Instruct 7B demonstrates strong performance in step-by-step reasoning, and DeepSeekMath-RL 7B approaches an accuracy of 60% on MATH with tool use, surpassing all existing open-source models. 103 | 104 |

105 | table 106 |

107 | 108 | 109 | ## 3. Data Collection 110 | 111 | - Step 1: Select [OpenWebMath](https://arxiv.org/pdf/2310.06786.pdf), a collection of high-quality mathematical web texts, as our initial seed corpus for training a FastText model. 112 | - Step 2: Use the FastText model to retrieve mathematical web pages from the deduplicated Common Crawl database. 113 | - Step 3: Identify potential math-related domains through statistical analysis. 114 | - Step 4: Manually annotate URLs within these identified domains that are associated with mathematical content. 115 | - Step 5: Add web pages linked to these annotated URLs, but not yet collected, to the seed corpus. Return to Step 1 and repeat until four iterations have been completed. 116 | 117 | 118 |

119 | table 120 |

121 | 122 | After four iterations of data collection, we end up with **35.5M** mathematical web pages, totaling **120B** tokens. 123 | 124 | ## 4. Model Downloads 125 | 126 | We release DeepSeekMath 7B, including base, instruct and RL models, to the public to support a broader and more diverse range of research within both academic and commercial communities. Please **note** that the use of this model is subject to the terms outlined in the [License section](#6-license). Commercial usage is permitted under these terms. 127 | 128 | ### Huggingface 129 | 130 | | Model | Sequence Length | Download | 131 | | :----------------------- | :-------------: | :----------------------------------------------------------: | 132 | | DeepSeekMath-Base 7B | 4096 | 🤗 [HuggingFace](https://huggingface.co/deepseek-ai/deepseek-math-7b-base) | 133 | | DeepSeekMath-Instruct 7B | 4096 | 🤗 [HuggingFace](https://huggingface.co/deepseek-ai/deepseek-math-7b-instruct) | 134 | | DeepSeekMath-RL 7B | 4096 | 🤗 [HuggingFace](https://huggingface.co/deepseek-ai/deepseek-math-7b-rl) | 135 | 136 | ## 5. Quick Start 137 | 138 | You can directly employ [Huggingface's Transformers](https://github.com/huggingface/transformers) for model inference. 
139 | 140 | **Text Completion** 141 | 142 | ```python 143 | import torch 144 | from transformers import AutoTokenizer, AutoModelForCausalLM, GenerationConfig 145 | 146 | model_name = "deepseek-ai/deepseek-math-7b-base" 147 | tokenizer = AutoTokenizer.from_pretrained(model_name) 148 | model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.bfloat16, device_map="auto") 149 | model.generation_config = GenerationConfig.from_pretrained(model_name) 150 | model.generation_config.pad_token_id = model.generation_config.eos_token_id 151 | 152 | text = "The integral of x^2 from 0 to 2 is" 153 | inputs = tokenizer(text, return_tensors="pt") 154 | outputs = model.generate(**inputs.to(model.device), max_new_tokens=100) 155 | 156 | result = tokenizer.decode(outputs[0], skip_special_tokens=True) 157 | print(result) 158 | ``` 159 | 160 | **Chat Completion** 161 | 162 | ```python 163 | import torch 164 | from transformers import AutoTokenizer, AutoModelForCausalLM, GenerationConfig 165 | 166 | model_name = "deepseek-ai/deepseek-math-7b-instruct" 167 | tokenizer = AutoTokenizer.from_pretrained(model_name) 168 | model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.bfloat16, device_map="auto") 169 | model.generation_config = GenerationConfig.from_pretrained(model_name) 170 | model.generation_config.pad_token_id = model.generation_config.eos_token_id 171 | 172 | messages = [ 173 | {"role": "user", "content": "what is the integral of x^2 from 0 to 2?\nPlease reason step by step, and put your final answer within \\boxed{}."} 174 | ] 175 | input_tensor = tokenizer.apply_chat_template(messages, add_generation_prompt=True, return_tensors="pt") 176 | outputs = model.generate(input_tensor.to(model.device), max_new_tokens=100) 177 | 178 | result = tokenizer.decode(outputs[0][input_tensor.shape[1]:], skip_special_tokens=True) 179 | print(result) 180 | ``` 181 | 182 | If you prefer not to use the provided `apply_chat_template` function, you can also interact 
with our model following the sample template. Note that `messages` should be replaced by your input. 183 | 184 | ``` 185 | User: {messages[0]['content']} 186 | 187 | Assistant: {messages[1]['content']}<|end▁of▁sentence|>User: {messages[2]['content']} 188 | 189 | Assistant: 190 | ``` 191 | 192 | **Note:** By default (`add_special_tokens=True`), our tokenizer automatically adds a `bos_token` (`<|begin▁of▁sentence|>`) before the input text. Additionally, since the system prompt is not compatible with this version of our models, we DO NOT RECOMMEND including the system prompt in your input. 193 | 194 | ❗❗❗ **Please use a chain-of-thought prompt to test DeepSeekMath-Instruct and DeepSeekMath-RL:** 195 | 196 | - English questions: **{question}\nPlease reason step by step, and put your final answer within \\boxed{}.** 197 | 198 | - Chinese questions: **{question}\n请通过逐步推理来解答问题,并把最终答案放置于\\boxed{}中。** 199 | 200 | 201 | ## 6. License 202 | This code repository is licensed under the MIT License. The use of DeepSeekMath models is subject to the Model License. DeepSeekMath supports commercial use. 203 | 204 | See the [LICENSE-CODE](LICENSE-CODE) and [LICENSE-MODEL](LICENSE-MODEL) for more details. 205 | 206 | ## 7. Citation 207 | 208 | ``` 209 | @misc{deepseek-math, 210 | author = {Zhihong Shao and Peiyi Wang and Qihao Zhu and Runxin Xu and Junxiao Song and Mingchuan Zhang and Y.K. Li and Y. Wu and Daya Guo}, 211 | title = {DeepSeekMath: Pushing the Limits of Mathematical Reasoning in Open Language Models}, 212 | journal = {CoRR}, 213 | volume = {abs/2402.03300}, 214 | year = {2024}, 215 | url = {https://arxiv.org/abs/2402.03300}, 216 | } 217 | ``` 218 | 219 | 220 | ## 8. Contact 221 | 222 | If you have any questions, please raise an issue or contact us at [service@deepseek.com](mailto:service@deepseek.com). 
223 | -------------------------------------------------------------------------------- /cog.yaml: -------------------------------------------------------------------------------- 1 | # Configuration for Cog ⚙️ 2 | # Reference: https://github.com/replicate/cog/blob/main/docs/yaml.md 3 | 4 | build: 5 | gpu: true 6 | python_version: "3.11" 7 | python_packages: 8 | - torch==2.0.1 9 | - torchvision==0.15.2 10 | - transformers==4.37.2 11 | - accelerate==0.27.0 12 | - hf_transfer 13 | 14 | # predict.py defines how predictions are run on your model 15 | predict: "replicate/predict.py:Predictor" 16 | -------------------------------------------------------------------------------- /evaluation/README.md: -------------------------------------------------------------------------------- 1 | ## 1. Introduction 2 | 3 | We provide a test script for both zero-shot and few-shot evaluation on mathematical reasoning benchmarks used in our paper. 4 | 5 | ## 2. Setup 6 | 7 | First configure the `prefix` in `environment.yml` and then run the following command 8 | ``` 9 | conda env create -f environment.yml 10 | ``` 11 | 12 | ## 3. Evaluation 13 | 14 | For chain-of-thought evaluation of DeepSeekMath-Instruct and DeepSeekMath-RL, our script (see `def markup_question()` in `run_subset_parallel.py`) processes each question as follows: 15 | * English questions: `{question}\nPlease reason step by step, and put your final answer within \\boxed{}.` 16 | * Chinese questions: `{question}\n请通过逐步推理来解答问题,并把最终答案放置于\\boxed{}中。` 17 | 18 | For tool-integrated reasoning, we process each question as follows: 19 | * English questions: `{question}\nPlease integrate natural language reasoning with programs to solve the problem above, and put your final answer within \\boxed{}.` 20 | * Chinese questions: `{question}\n请结合自然语言和Python程序语言来解答问题,并把最终答案放置于\\boxed{}中。` 21 | 22 | We provide an example of testing the DeepSeekMath-Base 7B using 8 GPUs. 
23 | 24 | If you wish to use a different model or dataset, you can modify the configs in `submit_eval_jobs.py` and `configs/*test_configs.json` 25 | 26 | ``` 27 | python submit_eval_jobs.py --n-gpus 8 28 | ``` 29 | 30 | Wait for all processes to finish, and then run the following command to aggregate results from all processes 31 | 32 | ``` 33 | python summarize_results.py [--eval-atp] 34 | ``` 35 | where the option `--eval-atp` will invoke `unsafe_score_minif2f_isabelle.py` to evaluate the informal-to-formal proving results. Please make sure you have set up the [PISA](https://github.com/wellecks/lm-evaluation-harness/blob/minif2f-isabelle/docs/isabelle_setup.md) server before using this option. 36 | 37 | A summary of all evaluation results will be saved as `evaluation_results.json` 38 | 39 | ## 4. Model Outputs 40 | 41 | We provide all model outputs in `outputs.zip`. 42 | -------------------------------------------------------------------------------- /evaluation/configs/few_shot_test_configs.json: -------------------------------------------------------------------------------- 1 | { 2 | "gsm8k-cot-test": { 3 | "test_path": "datasets/gsm8k/test.jsonl", 4 | "language": "en", 5 | "tasks": ["cot"], 6 | "process_fn": "process_gsm8k_test", 7 | "answer_extraction_fn": "extract_gsm_few_shot_cot_answer", 8 | "eval_fn": "eval_last_single_answer", 9 | "few_shot_prompt": "CoTGSMPrompt" 10 | }, 11 | "gsm8k-pal-test": { 12 | "test_path": "datasets/gsm8k/test.jsonl", 13 | "language": "en", 14 | "tasks": ["pal"], 15 | "process_fn": "process_gsm8k_test", 16 | "answer_extraction_fn": "placeholder", 17 | "eval_fn": "eval_last_single_answer", 18 | "few_shot_prompt": "PALGSMPrompt" 19 | }, 20 | "math-cot-test": { 21 | "test_path": "datasets/math/test.jsonl", 22 | "language": "en", 23 | "tasks": ["cot"], 24 | "process_fn": "process_math_test", 25 | "answer_extraction_fn": "extract_math_few_shot_cot_answer", 26 | "eval_fn": "eval_math", 27 | "few_shot_prompt": "MinervaMathPrompt" 28 | 
}, 29 | "math-pal-test": { 30 | "test_path": "datasets/math/test.jsonl", 31 | "language": "en", 32 | "tasks": ["pal"], 33 | "process_fn": "process_math_test", 34 | "answer_extraction_fn": "placeholder", 35 | "eval_fn": "eval_math", 36 | "few_shot_prompt": "PALMathPrompt" 37 | }, 38 | "math_sat": { 39 | "test_path": "datasets/sat/test.jsonl", 40 | "language": "en", 41 | "tasks": ["cot"], 42 | "process_fn": "process_math_sat", 43 | "answer_extraction_fn": "extract_sat_few_shot_answer", 44 | "eval_fn": "eval_math_sat", 45 | "few_shot_prompt": "CoTSATPrompt" 46 | }, 47 | "OCWCourses": { 48 | "test_path": "datasets/ocw/test.jsonl", 49 | "language": "en", 50 | "tasks": ["cot"], 51 | "process_fn": "process_ocwcourses", 52 | "answer_extraction_fn": "extract_ocwcourses_few_shot_answer", 53 | "eval_fn": "eval_ocwcourses", 54 | "few_shot_prompt": "OCWCoursesPrompt" 55 | }, 56 | "MMLU-STEM-test": { 57 | "test_path": "datasets/mmlu_stem/test.jsonl", 58 | "language": "en", 59 | "tasks": ["cot"], 60 | "process_fn": "process_mmlu_stem", 61 | "answer_extraction_fn": "extract_mmlu_stem", 62 | "eval_fn": "eval_mmlu_stem", 63 | "few_shot_prompt": "MMLUSTEMPrompt" 64 | }, 65 | "miniF2F-Isabelle-valid": { 66 | "test_path": "datasets/minif2f/validation.jsonl", 67 | "language": "en", 68 | "tasks": ["cot"], 69 | "process_fn": "process_minif2f_isabelle", 70 | "answer_extraction_fn": "extract_minif2f_isabelle", 71 | "eval_fn": "eval_minif2f_isabelle", 72 | "few_shot_prompt": "MiniF2FIsabellePrompt" 73 | }, 74 | "miniF2F-Isabelle-test": { 75 | "test_path": "datasets/minif2f/test.jsonl", 76 | "language": "en", 77 | "tasks": ["cot"], 78 | "process_fn": "process_minif2f_isabelle", 79 | "answer_extraction_fn": "extract_minif2f_isabelle", 80 | "eval_fn": "eval_minif2f_isabelle", 81 | "few_shot_prompt": "MiniF2FIsabellePrompt" 82 | }, 83 | "cmath-cot-test": { 84 | "test_path": "datasets/cmath/test.jsonl", 85 | "language": "zh", 86 | "tasks": ["cot"], 87 | "process_fn": "process_cmath", 88 | 
"answer_extraction_fn": "extract_cmath_few_shot_test", 89 | "eval_fn": "eval_last_single_answer", 90 | "few_shot_prompt": "CoTCMATHPrompt" 91 | }, 92 | "agieval-gaokao-mathcloze-cot-test": { 93 | "test_path": "datasets/agieval/gaokao-mathcloze.jsonl", 94 | "language": "zh", 95 | "tasks": ["cot"], 96 | "process_fn": "process_agieval_gaokao_math_cloze", 97 | "answer_extraction_fn": "extract_agieval_gaokao_mathcloze_few_shot_cot_test", 98 | "eval_fn": "eval_agieval_gaokao_math_cloze", 99 | "few_shot_prompt": "CoTGaoKaoMathClozePrompt" 100 | }, 101 | "agieval-gaokao-mathqa-cot-test": { 102 | "test_path": "datasets/agieval/gaokao-mathqa.jsonl", 103 | "language": "zh", 104 | "tasks": ["cot"], 105 | "process_fn": "process_agieval_gaokao_mathqa_few_shot_cot_test", 106 | "answer_extraction_fn": "extract_agieval_gaokao_mathqa_few_shot_cot_test", 107 | "eval_fn": "eval_agieval_gaokao_mathqa", 108 | "few_shot_prompt": "CoTGaoKaoMathQAPrompt" 109 | } 110 | } -------------------------------------------------------------------------------- /evaluation/configs/zero_shot_test_configs.json: -------------------------------------------------------------------------------- 1 | { 2 | "gsm8k-test": { 3 | "test_path": "datasets/gsm8k/test.jsonl", 4 | "language": "en", 5 | "tasks": ["tool", "cot"], 6 | "process_fn": "process_gsm8k_test", 7 | "answer_extraction_fn": "extract_last_single_answer", 8 | "eval_fn": "eval_last_single_answer" 9 | }, 10 | "math-test": { 11 | "test_path": "datasets/math/test.jsonl", 12 | "language": "en", 13 | "tasks": ["tool", "cot"], 14 | "process_fn": "process_math_test", 15 | "answer_extraction_fn": "extract_math_answer", 16 | "eval_fn": "eval_math" 17 | }, 18 | "mgsm-zh": { 19 | "test_path": "datasets/mgsm_zh/mgsm_zh.jsonl", 20 | "language": "zh", 21 | "tasks": ["tool", "cot"], 22 | "process_fn": "process_mgsm_zh", 23 | "answer_extraction_fn": "extract_last_single_answer", 24 | "eval_fn": "eval_last_single_answer" 25 | }, 26 | "cmath": { 27 | "test_path": 
"datasets/cmath/test.jsonl", 28 | "language": "zh", 29 | "tasks": ["tool", "cot"], 30 | "process_fn": "process_cmath", 31 | "answer_extraction_fn": "extract_last_single_answer", 32 | "eval_fn": "eval_last_single_answer" 33 | } 34 | } -------------------------------------------------------------------------------- /evaluation/data_processing/answer_extraction.py: -------------------------------------------------------------------------------- 1 | import re 2 | import regex 3 | 4 | def _fix_fracs(string): 5 | substrs = string.split("\\frac") 6 | new_str = substrs[0] 7 | if len(substrs) > 1: 8 | substrs = substrs[1:] 9 | for substr in substrs: 10 | new_str += "\\frac" 11 | if len(substr) > 0 and substr[0] == "{": 12 | new_str += substr 13 | else: 14 | try: 15 | assert len(substr) >= 2 16 | except: 17 | return string 18 | a = substr[0] 19 | b = substr[1] 20 | if b != "{": 21 | if len(substr) > 2: 22 | post_substr = substr[2:] 23 | new_str += "{" + a + "}{" + b + "}" + post_substr 24 | else: 25 | new_str += "{" + a + "}{" + b + "}" 26 | else: 27 | if len(substr) > 2: 28 | post_substr = substr[2:] 29 | new_str += "{" + a + "}" + b + post_substr 30 | else: 31 | new_str += "{" + a + "}" + b 32 | string = new_str 33 | return string 34 | 35 | 36 | def _fix_a_slash_b(string): 37 | if len(string.split("/")) != 2: 38 | return string 39 | a = string.split("/")[0] 40 | b = string.split("/")[1] 41 | try: 42 | if "sqrt" not in a: 43 | a = int(a) 44 | if "sqrt" not in b: 45 | b = int(b) 46 | assert string == "{}/{}".format(a, b) 47 | new_string = "\\frac{" + str(a) + "}{" + str(b) + "}" 48 | return new_string 49 | except: 50 | return string 51 | 52 | 53 | def _fix_sqrt(string): 54 | _string = re.sub(r"\\sqrt(-?[0-9.a-zA-Z]+)", r"\\sqrt{\1}", string) 55 | _string = re.sub(r"\\sqrt\s+(\w+)$", r"\\sqrt{\1}", _string) 56 | return _string 57 | 58 | 59 | def _fix_tan(string): 60 | _string = re.sub(r"\\tan(-?[0-9.a-zA-Z]+)", r"\\tan{\1}", string) 61 | _string = re.sub(r"\\tan\s+(\w+)$", 
r"\\tan{\1}", _string) 62 | return _string 63 | 64 | 65 | def strip_string(string): 66 | string = str(string).strip() 67 | # linebreaks 68 | string = string.replace("\n", "") 69 | 70 | # right "." 71 | string = string.rstrip(".") 72 | 73 | # remove inverse spaces 74 | string = string.replace("\\!", "") 75 | # string = string.replace("\\ ", "") 76 | 77 | # replace \\ with \ 78 | # string = string.replace("\\\\", "\\") 79 | # string = string.replace("\\\\", "\\") 80 | 81 | if string.startswith("\\text{") and string.endswith("}"): 82 | string = string.split("{", 1)[1][:-1] 83 | 84 | # replace tfrac and dfrac with frac 85 | string = string.replace("tfrac", "frac") 86 | string = string.replace("dfrac", "frac") 87 | string = string.replace("cfrac", "frac") 88 | 89 | # remove \left and \right 90 | string = string.replace("\\left", "") 91 | string = string.replace("\\right", "") 92 | 93 | # Remove unit: miles, dollars if after is not none 94 | _string = re.sub(r"\\text{.*?}$", "", string).strip() 95 | if _string != "" and _string != string: 96 | # print("Warning: unit not removed: '{}' -> '{}'".format(string, _string)) 97 | string = _string 98 | 99 | # Remove circ (degrees) 100 | string = string.replace("^{\\circ}", "").strip() 101 | string = string.replace("^\\circ", "").strip() 102 | 103 | string = regex.sub(r"\{(c|m)?m\}(\^(2|3))?", "", string).strip() 104 | string = regex.sub(r"p\.m\.$", "", string).strip() 105 | string = regex.sub(r"(\d)\s*t$", r"\1", string).strip() 106 | 107 | # remove dollar signs 108 | string = string.replace("\\$", "") 109 | string = string.replace("$", "") 110 | 111 | # string = string.replace("\\text", "") 112 | string = string.replace("x\\in", "") 113 | 114 | # remove percentage 115 | string = string.replace("\\%", "%") 116 | string = string.replace("\%", "%") 117 | # string = string.replace("%", "") 118 | 119 | # " 0." equivalent to " ." and "{0." equivalent to "{." Alternatively, add "0" if "." 
is the start of the string 120 | string = string.replace(" .", " 0.") 121 | string = string.replace("{.", "{0.") 122 | 123 | # cdot 124 | string = string.replace("\\cdot", "") 125 | 126 | # inf 127 | string = string.replace("infinity", "\\infty") 128 | if "\\infty" not in string: 129 | string = string.replace("inf", "\\infty") 130 | string = string.replace("+\\inity", "\\infty") 131 | 132 | # and 133 | # string = string.replace("and", "") 134 | string = string.replace("\\mathbf", "") 135 | string = string.replace("\\mathrm", "") 136 | 137 | # use regex to remove \mbox{...} 138 | string = re.sub(r"\\mbox{.*?}", "", string) 139 | 140 | # quote 141 | string.replace("'", "") 142 | string.replace("\"", "") 143 | 144 | # i, j 145 | if "j" in string and "i" not in string: 146 | string = string.replace("j", "i") 147 | 148 | # replace a.000b where b is not number or b is end, with ab, use regex 149 | string = re.sub(r"(\d+)\.0+([^\d])", r"\1\2", string) 150 | string = re.sub(r"(\d+)\.0+$", r"\1", string) 151 | 152 | # if empty, return empty string 153 | if len(string) == 0: 154 | return string 155 | if string[0] == ".": 156 | string = "0" + string 157 | 158 | # to consider: get rid of e.g. "k = " or "q = " at beginning 159 | # if len(string.split("=")) == 2: 160 | # if len(string.split("=")[0]) <= 2: 161 | # string = string.split("=")[1] 162 | 163 | string = _fix_sqrt(string) 164 | string = _fix_tan(string) 165 | string = string.replace(" ", "") 166 | 167 | # \frac1b or \frac12 --> \frac{1}{b} and \frac{1}{2}, etc. Even works with \frac1{72} (but not \frac{72}1). 
def extract_boxed_answers(text):
    r"""Return the contents of every ``boxed{...}`` group found in *text*.

    The text is split on the literal ``boxed{``; for each following piece the
    braces are counted until the one closing the box is reached.  Pieces
    whose box is never closed contribute nothing.

    Returns:
        A list of strings, in order of appearance (possibly empty).
    """
    answers = []
    for piece in text.split('boxed{')[1:]:
        depth = 0
        for idx, ch in enumerate(piece):
            if ch == '{':
                depth += 1
            elif ch == '}':
                depth -= 1
                if depth < 0:
                    # This brace closes the box itself.
                    # NOTE(review): when a '%' directly follows the box, the
                    # slice below keeps the closing '}' (e.g. "50}").  That
                    # looks unintended but is preserved here - confirm.
                    if piece[idx + 1: idx + 2] == '%':
                        answers.append(piece[: idx + 1])
                    else:
                        answers.append(piece[:idx])
                    break
    return answers


def extract_program_output(pred_str):
    """Return the text inside the last ```output ... ``` fence of *pred_str*.

    Returns an empty string when there is no ```output marker.  If the last
    fence is never closed, everything after the marker is returned.  The
    result is stripped of surrounding whitespace.
    """
    marker = "```output"
    if marker not in pred_str:
        return ""
    tail = pred_str.split(marker)[-1]
    # split() returns the whole tail when no closing fence is present.
    return tail.split("```")[0].strip()


def extract_answer(pred_str, exhaust=False):
    r"""Pull the predicted answer(s) out of a model response.

    Sources are tried in this order and the first that applies wins:

    1. the text between ``final answer is $`` and ``$. I hope``;
    2. every ``boxed{...}`` group (when the response contains ``boxed``);
    3. whatever follows the last ``he answer is``;
    4. the last ```output block;
    5. the last number in the response (commas removed first).

    Each candidate is cut at its first line break, stripped of a leading
    ``:`` and trailing ``.`` / ``/``, and passed through ``strip_string``.

    Args:
        pred_str: The raw model response.
        exhaust: When true, return the list of all candidates; otherwise
            return only the last one.

    Returns:
        A list of normalized answers if *exhaust* is true, else a single
        string (``""`` when nothing was found).
    """
    pred = []
    if 'final answer is $' in pred_str and '$. I hope' in pred_str:
        tmp = pred_str.split('final answer is $', 1)[1]
        pred = [tmp.split('$. I hope', 1)[0].strip()]
    elif 'boxed' in pred_str:
        pred = extract_boxed_answers(pred_str)
    elif 'he answer is' in pred_str:
        pred = [pred_str.split('he answer is')[-1].strip()]
    else:
        program_output = extract_program_output(pred_str)
        if program_output != "":
            # fall back to program
            pred.append(program_output)
        else:
            # use the last number; raw string so \d and \. are regex escapes
            numbers = re.findall(r'-?\d*\.?\d+', pred_str.replace(",", ""))
            if numbers:
                pred.append(numbers[-1])

    cleaned = []
    for ans in pred:
        # keep only the first line of a multi-line candidate
        ans = ans.strip().split("\n")[0]
        ans = ans.lstrip(":")
        ans = ans.rstrip(".")
        ans = ans.rstrip("/")
        cleaned.append(strip_string(ans))

    if exhaust:
        return cleaned
    return cleaned[-1] if cleaned else ""
def extract_math_answer(question, reasoning, task):
    r"""Extract the list of answers for a MATH-style question.

    Every candidate returned by ``extract_answer(reasoning, exhaust=True)``
    is expanded as follows:

    * if the question says ``separated by commas`` and the candidate holds
      none of ``()[]``, it is split on commas;
    * otherwise, if it contains ``\text{and}`` (inner whitespace allowed),
      it is split on that;
    * otherwise it is kept as a single answer.

    Args:
        question: The problem statement.
        reasoning: The model response (or reference solution).
        task: Unused here; accepted so all extractors share one signature.

    Returns:
        A flat list of stripped answer strings.
    """
    and_joiner = r"\\text\{\s*and\s*\}"
    answers = []
    for ans in extract_answer(reasoning, exhaust=True):
        if 'separated by commas' in question and all(ch not in ans for ch in '()[]'):
            answers.extend(part.strip() for part in ans.split(","))
        elif re.search(and_joiner, ans):
            # Split on the joiner directly; going through a "[SEP]"
            # placeholder would also split answers containing that literal.
            answers.extend(part.strip() for part in re.split(and_joiner, ans))
        else:
            answers.append(ans.strip())
    return answers


def extract_math_few_shot_cot_answer(question, reasoning, task):
    """Like ``extract_math_answer`` but ignores text from ``Problem:`` on.

    Everything from the first ``Problem:`` in *reasoning* is discarded
    before extraction.
    """
    reasoning = reasoning.split("Problem:", 1)[0]
    return extract_math_answer(question, reasoning, task)


def extract_last_single_answer(question, reasoning, task):
    """Return the last answer ``extract_answer`` finds in *reasoning*.

    *question* and *task* are unused.
    """
    return extract_answer(reasoning, exhaust=False)


def extract_gsm_few_shot_cot_answer(question, reasoning, task):
    """Return the last number before the first ``Q: `` in *reasoning*.

    Numbers are matched as an optional minus sign, digits, an optional dot
    and optional further digits, so a sentence-ending period directly after
    a number is part of the match (``"7."``).

    Returns:
        The matched text, or ``"[invalid]"`` if there is no number.
    """
    reasoning = reasoning.split("Q: ", 1)[0]
    numbers = re.findall(r'-?\d+\.?\d*', reasoning)
    return numbers[-1] if numbers else "[invalid]"
def _gaokao_answer_line(reasoning):
    """Return the text after the first '答案是' up to the end of its line.

    Anything from the first '问题 ' onwards is ignored.  Returns ``None``
    when the remaining text does not contain '答案是'.
    """
    reasoning = reasoning.split("问题 ", 1)[0]
    if '答案是' not in reasoning:
        return None
    tail = reasoning.split('答案是', 1)[1].strip()
    return tail.split("\n")[0].strip()


def extract_agieval_gaokao_mathcloze_few_shot_cot_test(question, reasoning, task):
    """Extract the cloze answer that follows '答案是'.

    Surrounding ``$`` characters are removed from the answer.

    Returns:
        A one-element list holding the answer, or ``['placeholder']`` when
        no '答案是' is present.  *question* and *task* are unused.
    """
    line = _gaokao_answer_line(reasoning)
    if line is None:
        return ['placeholder']
    return [line.strip("$")]


def extract_agieval_gaokao_mathqa_few_shot_cot_test(question, reasoning, task):
    """Extract the multiple-choice answer that follows '答案是'.

    Returns:
        The answer line as a string, or ``'placeholder'`` when no '答案是'
        is present.  *question* and *task* are unused.
    """
    line = _gaokao_answer_line(reasoning)
    return 'placeholder' if line is None else line


def extract_sat_few_shot_answer(question, reasoning, task):
    """Extract the chosen option letter from a few-shot SAT response.

    Looks (case-insensitively) for ``the final answer is`` followed by one
    of a-d, optionally in parentheses, in the text before the first
    ``Problem:``.

    Returns:
        The upper-cased letter, or ``'placeholder'`` if nothing matches.
        *question* and *task* are unused.
    """
    reasoning = reasoning.split("Problem:", 1)[0]
    match = re.search(r"the final answer is \(?(?P<ans>[abcd])\)?", reasoning.lower())
    if match is None:
        return 'placeholder'
    return match.group('ans').upper()


def extract_ocwcourses_few_shot_answer(question, reasoning, task):
    """Extract the answer from ``final answer is X. I hope it is correct.``.

    Only the text before the first ``Problem:`` is searched.  When the
    sentence is missing, the searched text is printed for debugging.

    Returns:
        The captured answer, or ``"[invalid]"`` if the sentence is absent.
        *question* and *task* are unused.
    """
    reasoning = reasoning.split("Problem:", 1)[0]
    match = re.search(r"final answer is (?P<ans>.*)\. I hope it is correct.", reasoning)
    if match is None:
        print(f"DEBUG >>>\n{reasoning}", flush=True)
        return "[invalid]"
    return match.group('ans')
def extract_mmlu_stem(question, reasoning, task):
    """Extract the option letter from a few-shot MMLU-STEM response.

    Drops everything from the first ``Problem:`` onwards and delegates to
    ``extract_sat_few_shot_answer``.
    """
    reasoning = reasoning.split("Problem:", 1)[0]
    return extract_sat_few_shot_answer(question, reasoning, task)


def extract_minif2f_isabelle(question, reasoning, task):
    """Return the generated text before the first ``Informal:``, stripped.

    *question* and *task* are unused.
    """
    return reasoning.split("Informal:", 1)[0].strip()


def extract_cmath_few_shot_test(question, reasoning, task):
    """Extract the numeric answer from a few-shot CMATH response.

    Text from the first '问题:' onwards is ignored.  If the remainder
    contains '答案是', the last number on the line following it is
    returned; if that line holds no number the text is printed for
    debugging and ``"[invalid]"`` is returned.  Without '答案是' the
    function falls back to ``extract_last_single_answer``.
    """
    reasoning = reasoning.split("问题:", 1)[0]
    if '答案是' not in reasoning:
        return extract_last_single_answer(question, reasoning, task)

    ans = reasoning.split('答案是', 1)[1].strip()
    ans = ans.split("\n")[0]
    ans = ans.strip(":")
    ans = ans.strip("。")
    numbers = re.findall(r'-?\d+\.?\d*', ans)
    # An explicit emptiness check replaces a bare ``except`` around
    # ``[...][-1]``, which would also have hidden unrelated errors.
    if not numbers:
        print(f"DEBUG CMATH: {reasoning}", flush=True)
        return "[invalid]"
    return numbers[-1]
def process_math_sat(item):
    """Convert one raw SAT-math record into a chat-style evaluation sample.

    The raw ``options`` string must start with the label ``A`` (for example
    ``"A) 1 B) 2 C) 3"``). Every label is rewritten to the parenthesised form
    (``"(A) 1 (B) 2 (C) 3"``) and appended to the question text.

    Args:
        item: Mapping with the keys ``'id'``, ``'question'``, ``'options'``
            and ``'Answer'``.

    Yields:
        One dict with the keys ``'dataset'``, ``'id'``, ``'language'``,
        ``'messages'`` and ``'answer'``.

    Raises:
        AssertionError: If the options string does not start with ``A``.
        IndexError: If the options string is empty after stripping.
    """
    options = item['options'].strip()
    assert 'A' == options[0]
    options = '(' + options
    for ch in 'BCDEFG':
        # Plain literal replacement: no regex needed, and it avoids the invalid
        # "\)" escape the non-raw f-string pattern used to contain.
        options = options.replace(f' {ch}) ', f' ({ch}) ')
    question = f"{item['question'].strip()}\nWhat of the following is the right choice? Explain your answer.\n{options.strip()}"
    messages = [
        {'role': 'user', 'content': question},
        {'role': 'assistant', 'content': item['Answer']}
    ]
    item = {
        'dataset': 'math_sat',
        'id': item['id'],
        'language': 'en',
        'messages': messages,
        'answer': item['Answer'],
    }
    yield item
def process_agieval_gaokao_mathqa(item):
    """Build the evaluation sample for one AGIEval Gaokao-MathQA record.

    Each option must have the form ``"(X)text"`` with ``X`` in ``ABCD``; it is
    rewritten to ``"X: text"``. The list of rewritten options is formatted
    directly into the prompt after the question, so it appears as the list's
    string representation.

    Args:
        item: Mapping with the keys ``'id'``, ``'question'``, ``'options'``
            (iterable of strings) and ``'label'``.

    Yields:
        One dict with the keys ``'dataset'``, ``'id'``, ``'messages'`` and
        ``'answer'``; the assistant message is left empty.

    Raises:
        AssertionError: If an option does not match the ``"(X)..."`` shape.
    """
    stem = item['question'].strip()
    rewritten = []
    for raw_option in item['options']:
        text = raw_option.strip()
        assert text[0] == '('
        assert text[2] == ')'
        assert text[1] in 'ABCD'
        label = text[1]
        body = text[3:].strip()
        rewritten.append(f"{label}: {body}".strip())
    prompt = f"{stem}\n{rewritten}"
    user_turn = {'role': 'user', 'content': prompt}
    assistant_turn = {'role': 'assistant', 'content': ''}
    yield {
        'dataset': 'agieval-gaokao-mathqa',
        'id': item['id'],
        'messages': [user_turn, assistant_turn],
        "answer": item['label'],
    }
def process_minif2f_isabelle(item):
    """Build the evaluation sample for one miniF2F (Isabelle) record.

    The prompt places the informal statement and informal proof inside an
    Isabelle ``(* ... *)`` comment, followed by ``Formal:`` and the formal
    statement. All three fields are stripped of surrounding whitespace.

    Args:
        item: Mapping with the keys ``'id'``, ``'informal_statement'``,
            ``'informal_proof'`` and ``'formal_statement'``.

    Yields:
        One dict with the keys ``'dataset'``, ``'id'``, ``'messages'`` and
        ``'answer'``; the answer is the fixed string ``"placeholder"``.
    """
    statement = item['informal_statement'].strip()
    proof = item['informal_proof'].strip()
    formal = item['formal_statement'].strip()
    prompt = f"(*### Problem\n\n{statement}\n\n### Solution\n\n{proof} *)\n\nFormal:\n{formal}"
    sample = {
        'dataset': 'minif2f-isabelle',
        'id': item['id'],
        'messages': [
            {'role': 'user', 'content': prompt},
            {'role': 'assistant', 'content': ''},
        ],
        "answer": "placeholder",
    }
    yield sample
jpeg=9e=h5eee18b_1 34 | - lame=3.100=h7b6447c_0 35 | - lcms2=2.12=h3be6417_0 36 | - ld_impl_linux-64=2.38=h1181459_1 37 | - lerc=3.0=h295c915_0 38 | - libcublas=11.10.3.66=0 39 | - libcufft=10.7.2.124=h4fbf590_0 40 | - libcufile=1.8.1.2=0 41 | - libcurand=10.3.4.101=0 42 | - libcusolver=11.4.0.1=0 43 | - libcusparse=11.7.4.91=0 44 | - libdeflate=1.17=h5eee18b_1 45 | - libffi=3.4.4=h6a678d5_0 46 | - libgcc-ng=11.2.0=h1234567_1 47 | - libgomp=11.2.0=h1234567_1 48 | - libiconv=1.16=h7f8727e_2 49 | - libidn2=2.3.4=h5eee18b_0 50 | - libnpp=11.7.4.75=0 51 | - libnvjpeg=11.8.0.2=0 52 | - libpng=1.6.39=h5eee18b_0 53 | - libstdcxx-ng=11.2.0=h1234567_1 54 | - libtasn1=4.19.0=h5eee18b_0 55 | - libtiff=4.5.1=h6a678d5_0 56 | - libunistring=0.9.10=h27cfd23_0 57 | - libuuid=1.41.5=h5eee18b_0 58 | - libwebp=1.3.2=h11a3e52_0 59 | - libwebp-base=1.3.2=h5eee18b_0 60 | - lz4-c=1.9.4=h6a678d5_0 61 | - markupsafe=2.1.1=py310h7f8727e_0 62 | - mkl=2023.1.0=h213fc3f_46344 63 | - mkl-service=2.4.0=py310h5eee18b_1 64 | - mkl_fft=1.3.8=py310h5eee18b_0 65 | - mkl_random=1.2.4=py310hdb19cb5_0 66 | - mpc=1.1.0=h10f8cd9_1 67 | - mpfr=4.0.2=hb69a4c5_1 68 | - mpmath=1.3.0=py310h06a4308_0 69 | - ncurses=6.4=h6a678d5_0 70 | - nettle=3.7.3=hbbd107a_1 71 | - networkx=3.1=py310h06a4308_0 72 | - numpy=1.26.0=py310h5f9d8c6_0 73 | - numpy-base=1.26.0=py310hb5e798b_0 74 | - openh264=2.1.1=h4ff587b_0 75 | - openjpeg=2.4.0=h3ad879b_0 76 | - openssl=3.0.12=h7f8727e_0 77 | - pillow=10.0.1=py310ha6cbd5a_0 78 | - pip=23.3=py310h06a4308_0 79 | - pycparser=2.21=pyhd3eb1b0_0 80 | - pyopenssl=23.2.0=py310h06a4308_0 81 | - pysocks=1.7.1=py310h06a4308_0 82 | - python=3.10.13=h955ad1f_0 83 | - pytorch=2.0.1=py3.10_cuda11.7_cudnn8.5.0_0 84 | - pytorch-cuda=11.7=h778d358_5 85 | - pytorch-mutex=1.0=cuda 86 | - readline=8.2=h5eee18b_0 87 | - requests=2.31.0=py310h06a4308_0 88 | - setuptools=68.0.0=py310h06a4308_0 89 | - sqlite=3.41.2=h5eee18b_0 90 | - tbb=2021.8.0=hdb19cb5_0 91 | - tk=8.6.12=h1ccaba5_0 92 | - 
torchaudio=2.0.2=py310_cu117 93 | - torchtriton=2.0.0=py310 94 | - torchvision=0.15.2=py310_cu117 95 | - urllib3=1.26.18=py310h06a4308_0 96 | - wheel=0.41.2=py310h06a4308_0 97 | - xz=5.4.2=h5eee18b_0 98 | - zlib=1.2.13=h5eee18b_0 99 | - zstd=1.5.5=hc292b87_0 100 | - pip: 101 | - absl-py==2.0.0 102 | - accelerate==0.21.0 103 | - aiofiles==23.2.1 104 | - aiohttp==3.8.6 105 | - aiosignal==1.3.1 106 | - altair==5.1.2 107 | - antlr4-python3-runtime==4.11.0 108 | - anyio==3.7.1 109 | - appdirs==1.4.4 110 | - asttokens==2.4.1 111 | - async-timeout==4.0.3 112 | - attrs==23.1.0 113 | - bitarray==2.8.3 114 | - bitsandbytes==0.41.2.post2 115 | - blessed==1.20.0 116 | - blinker==1.7.0 117 | - cachetools==5.3.2 118 | - clarabel==0.6.0 119 | - click==8.1.7 120 | - contourpy==1.2.0 121 | - cvxpy==1.4.1 122 | - cycler==0.12.1 123 | - datasets==2.14.5 124 | - decorator==5.1.1 125 | - deepspeed==0.11.0 126 | - dill==0.3.7 127 | - distro==1.8.0 128 | - docker-pycreds==0.4.0 129 | - ecos==2.0.12 130 | - einops==0.7.0 131 | - evaluate==0.4.1 132 | - exceptiongroup==1.1.3 133 | - executing==2.0.1 134 | - fastapi==0.104.1 135 | - ffmpy==0.3.1 136 | - fire==0.5.0 137 | - flash-attn==2.0.1 138 | - flask==3.0.0 139 | - fonttools==4.44.3 140 | - frozenlist==1.4.0 141 | - fsspec==2023.6.0 142 | - func-timeout==4.3.5 143 | - gitdb==4.0.11 144 | - gitpython==3.1.40 145 | - google-auth==2.25.2 146 | - google-auth-oauthlib==1.2.0 147 | - gpustat==1.1.1 148 | - gradio==3.50.2 149 | - gradio-client==0.6.1 150 | - grpcio==1.59.2 151 | - h11==0.14.0 152 | - hjson==3.1.0 153 | - httpcore==1.0.2 154 | - httptools==0.6.1 155 | - httpx==0.25.1 156 | - huggingface-hub==0.19.3 157 | - importlib-resources==6.1.1 158 | - ipython==8.17.2 159 | - itsdangerous==2.1.2 160 | - jedi==0.19.1 161 | - joblib==1.3.2 162 | - jsonlines==4.0.0 163 | - jsonschema==4.19.2 164 | - jsonschema-specifications==2023.11.1 165 | - kiwisolver==1.4.5 166 | - markdown==3.5.1 167 | - markdown-it-py==3.0.0 168 | - matplotlib==3.8.1 
169 | - matplotlib-inline==0.1.6 170 | - mdurl==0.1.2 171 | - mecab-python3==1.0.8 172 | - msgpack==1.0.7 173 | - multidict==6.0.4 174 | - multiprocess==0.70.15 175 | - ninja==1.11.1.1 176 | - nltk==3.8.1 177 | - nvidia-ml-py==12.535.133 178 | - oauthlib==3.2.2 179 | - openai==1.3.0 180 | - orjson==3.9.10 181 | - osqp==0.6.3 182 | - packaging==23.2 183 | - pandas==2.1.3 184 | - parso==0.8.3 185 | - pebble==5.0.3 186 | - peft==0.6.2 187 | - pexpect==4.8.0 188 | - prompt-toolkit==3.0.41 189 | - protobuf==4.23.4 190 | - psutil==5.9.6 191 | - ptyprocess==0.7.0 192 | - pure-eval==0.2.2 193 | - py-cpuinfo==9.0.0 194 | - pyarrow==14.0.1 195 | - pyasn1==0.5.1 196 | - pyasn1-modules==0.3.0 197 | - pybind11==2.11.1 198 | - pydantic==1.10.13 199 | - pydub==0.25.1 200 | - pygments==2.16.1 201 | - pylatexenc==2.10 202 | - pyparsing==3.1.1 203 | - python-dateutil==2.8.2 204 | - python-dotenv==1.0.0 205 | - python-multipart==0.0.6 206 | - pytz==2023.3.post1 207 | - pyyaml==6.0.1 208 | - qdldl==0.1.7.post0 209 | - ray==2.6.3 210 | - referencing==0.31.0 211 | - regex==2023.10.3 212 | - requests-oauthlib==1.3.1 213 | - responses==0.18.0 214 | - rich==13.7.0 215 | - rouge-score==0.1.2 216 | - rpds-py==0.12.0 217 | - rsa==4.9 218 | - safetensors==0.4.0 219 | - scipy==1.11.3 220 | - scs==3.2.4 221 | - semantic-version==2.10.0 222 | - sentencepiece==0.1.99 223 | - sentry-sdk==1.35.0 224 | - setproctitle==1.3.3 225 | - six==1.16.0 226 | - smmap==5.0.1 227 | - sniffio==1.3.0 228 | - stack-data==0.6.3 229 | - starlette==0.27.0 230 | - sympy==1.12 231 | - tensorboard==2.15.1 232 | - tensorboard-data-server==0.7.2 233 | - termcolor==2.3.0 234 | - tiktoken==0.5.1 235 | - timeout-decorator==0.5.0 236 | - tokenizers==0.15.0 237 | - toolz==0.12.0 238 | - tqdm==4.66.1 239 | - traitlets==5.13.0 240 | - transformers==4.35.2 241 | - typing-extensions==4.8.0 242 | - tzdata==2023.3 243 | - unidic-lite==1.0.8 244 | - uvicorn==0.24.0.post1 245 | - uvloop==0.19.0 246 | - vllm==0.2.0 247 | - wandb==0.16.0 
def is_correct(item, pred_key='prediction', prec=1e-3):
    """Decide whether ``item[pred_key]`` matches ``item['answer']``.

    * list vs. list: every prediction must match at least one answer and every
      answer must be matched by at least one prediction (order-insensitive).
    * str vs. str: if both contain ``\\cup`` they are split on it and compared
      as lists. Otherwise they match when they are numerically within ``prec``
      (commas ignored), when they are identical non-empty strings, or when
      ``math_equal`` accepts them.

    The caller's ``item`` is never modified; recursion works on shallow copies
    in which only the prediction and answer entries are replaced.

    Args:
        item: Record holding the prediction under ``pred_key`` and the
            reference under ``'answer'``.
        pred_key: Key of the prediction inside ``item``.
        prec: Absolute tolerance for the numeric comparison.

    Returns:
        ``True`` on a match, otherwise the (falsy) result of ``math_equal``.

    Raises:
        NotImplementedError: If prediction and answer are neither both lists
            nor both strings. The record is printed first.
    """
    pred = item[pred_key]
    ans = item['answer']
    if isinstance(pred, list) and isinstance(ans, list):
        pred_matched = set()
        ans_matched = set()
        for i, single_pred in enumerate(pred):
            for j, single_ans in enumerate(ans):
                # A shallow copy is enough: only the two compared entries are
                # replaced, and nothing downstream mutates the other values.
                pair = dict(item)
                pair.update({pred_key: single_pred, 'answer': single_ans})
                if is_correct(pair, pred_key=pred_key, prec=prec):
                    pred_matched.add(i)
                    ans_matched.add(j)
        return len(pred_matched) == len(pred) and len(ans_matched) == len(ans)

    if isinstance(pred, str) and isinstance(ans, str):
        if '\\cup' in pred and '\\cup' in ans:
            parts = dict(item)
            parts.update({
                pred_key: pred.split('\\cup'),
                'answer': ans.split('\\cup'),
            })
            return is_correct(parts, pred_key=pred_key, prec=prec)
        try:
            if abs(float(pred.replace(',', '')) - float(ans.replace(',', ''))) < prec:
                return True
        except ValueError:
            # Not plain numbers; fall through to the string/symbolic checks.
            pass
        return bool(ans and pred == ans) or math_equal(pred, ans)

    print(item, flush=True)
    raise NotImplementedError()
def eval_agieval_gaokao_math_cloze(item, pred_key='prediction', prec=1e-3):
    """Score a Gaokao math-cloze prediction against its list of answers.

    Each predicted string is split on ``;`` and on ``,`` outside of any
    ``()``, ``[]`` or ``{}``. The fragments are de-duplicated in order, and the
    last ``len(answer)`` fragments are compared position by position with the
    answers via ``is_correct``.

    Note that ``item`` is modified in place: a string ``program_output`` is
    wrapped in a list, and during comparison the prediction and answer entries
    are overwritten with the pair currently being checked.

    Args:
        item: Record with a list under ``pred_key`` and a list under ``'answer'``.
        pred_key: Key of the prediction inside ``item``.
        prec: Numeric tolerance forwarded to ``is_correct``.

    Returns:
        ``True`` if the fragment count equals the answer count and every pair
        matches, else ``False``.

    Raises:
        AssertionError: If the prediction or the answer is not a list.
    """
    if pred_key == 'program_output' and isinstance(item[pred_key], str):
        item[pred_key] = [item[pred_key]]
    for key in [pred_key, 'answer']:
        assert isinstance(item[key], list), f"{key} = `{item[key]}` is not a list"

    references = item['answer']
    fragments = []
    for raw in item[pred_key]:
        # The trailing ';' guarantees every pass of the scan finds a separator.
        remaining = raw + ";"
        while remaining:
            depth = 0
            for pos, ch in enumerate(remaining):
                if ch == ';' or (ch == ',' and depth == 0):
                    piece = remaining[:pos].strip()
                    remaining = remaining[pos + 1:].strip()
                    if piece not in fragments:
                        fragments.append(piece)
                    break
                if ch in '([{':
                    depth += 1
                elif ch in ')]}':
                    depth -= 1

    candidates = fragments[-len(references):]
    if len(candidates) != len(references):
        return False
    for candidate, reference in zip(candidates, references):
        item.update({
            pred_key: candidate,
            'answer': reference,
        })
        if not is_correct(item, pred_key=pred_key, prec=prec):
            return False
    return True
def eval_agieval_gaokao_mathqa(item, pred_key='prediction', prec=1e-3):
    """Score a Gaokao-MathQA prediction by the option letter it mentions.

    The prediction entries are joined with spaces. Among the letters ``A``-``D``
    present in that text, the one whose first occurrence lies furthest to the
    right is taken as the chosen option and compared with ``item['answer']``.

    A string ``program_output`` is wrapped in a list in place before joining.

    Args:
        item: Record with the prediction under ``pred_key`` and the reference
            letter under ``'answer'``.
        pred_key: Key of the prediction inside ``item``.
        prec: Unused; kept so all evaluators share one signature.

    Returns:
        ``True`` if the selected letter equals the answer, else ``False``.
    """
    if pred_key == 'program_output' and isinstance(item[pred_key], str):
        item[pred_key] = [item[pred_key]]
    text = " ".join(item[pred_key])
    expected = item['answer']
    # NOTE(review): letters are ranked by their FIRST occurrence, not their
    # last - confirm this is the intended tie-break before changing it.
    first_seen = [(text.find(letter), letter) for letter in 'ABCD']
    present = [entry for entry in first_seen if entry[0] >= 0]
    chosen = max(present)[1] if present else None
    return chosen == expected
def extract_program(result: str, last_only=True):
    """
    Collect the source lines found inside "```python" ... "```" fences.

    A line starting with "```python" opens a block and any other line starting
    with "```" closes it. With ``last_only`` set, opening a new block discards
    what was collected before, so only the final block is returned; otherwise
    each block is preceded by a "# ========" separator line. Every collected
    line is terminated with a newline. An unclosed block runs to the end of
    ``result``.
    """
    collected = []
    inside = False
    for line in result.split("\n"):
        if line.startswith("```python"):
            if last_only:
                collected.clear()  # only keep the last program
            else:
                collected.append("\n# ========\n")
            inside = True
            continue
        if line.startswith("```"):
            inside = False
            continue
        if inside:
            collected.append(line + "\n")
    return "".join(collected)
def parse_question(example, data_name):
    """Assemble the question text for ``example`` according to its dataset.

    * ``"asdiv"``: stripped ``body`` and ``question`` joined by a space.
    * ``"svamp"``: stripped ``Body`` (a period is appended if missing)
      followed by the stripped ``Question``.
    * ``"tabmwp"``: an instruction line (mentioning ``table_title`` when it is
      truthy), the table, the question and, when ``choices`` is truthy, the
      list of options.
    * anything else: the first of ``question``, ``problem``, ``Question``,
      ``input`` that is present in ``example``.

    Returns:
        The assembled question with surrounding whitespace removed.

    Raises:
        AssertionError: If the assembled question is the empty string.
    """
    if data_name == "asdiv":
        body = example['body'].strip()
        ask = example['question'].strip()
        question = f"{body} {ask}"
    elif data_name == "svamp":
        body = example["Body"].strip()
        if not body.endswith("."):
            body += "."
        ask = example["Question"].strip()
        question = f'{body} {ask}'
    elif data_name == "tabmwp":
        title = example['table_title']
        title_str = f'regarding "{title}" ' if title else ""
        pieces = [
            f'Read the following table {title_str}and answer a question:\n',
            f'{example["table"]}\n{example["question"]}',
        ]
        if example['choices']:
            pieces.append(f' Please select from the following options: {example["choices"]}')
        question = "".join(pieces)
    else:
        candidates = ('question', 'problem', 'Question', 'input')
        question = next((example[key] for key in candidates if key in example), "")
    assert question != ""
    return question.strip()
def parse_digits(num):
    """Parse a number written as plain digits or as a percentage.

    Thousands separators (``,``) are removed first. Supported forms are
    ``234.23`` / ``1,234`` (returned as-is) and ``23%`` / ``23\\%`` (returned
    divided by 100).

    Args:
        num: Any value; it is converted with ``str`` before parsing.

    Returns:
        The parsed ``float``, or ``None`` when ``num`` is not in one of the
        supported forms.
    """
    text = str(num).replace(',', '')
    try:
        return float(text)
    except ValueError:
        # Not a plain number; it may still be a percentage.
        pass
    if not text.endswith('%'):
        return None
    text = text[:-1]
    if text.endswith('\\'):
        # LaTeX writes the percent sign as "\%".
        text = text[:-1]
    try:
        return float(text) / 100
    except ValueError:
        return None


def is_digit(num):
    """Return ``True`` when ``parse_digits`` can parse ``num``."""
    return parse_digits(num) is not None
) -> bool: 187 | """ 188 | Exact match of math if and only if: 189 | 1. numerical equal: both can convert to float and are equal 190 | 2. symbolic equal: both can convert to sympy expression and are equal 191 | """ 192 | if str(prediction) == str(reference): 193 | return True 194 | 195 | try: # 1. numerical equal 196 | if is_digit(prediction) and is_digit(reference): 197 | prediction = parse_digits(prediction) 198 | reference = parse_digits(reference) 199 | # number questions 200 | if include_percentage: 201 | gt_result = [reference / 100, reference, reference * 100] 202 | else: 203 | gt_result = [reference] 204 | for item in gt_result: 205 | try: 206 | if is_close: 207 | if isclose(item, prediction, abs_tol=1e-3): 208 | return True 209 | else: 210 | if item == prediction: 211 | return True 212 | except Exception: 213 | continue 214 | return False 215 | except: 216 | pass 217 | 218 | if not prediction and prediction not in [0, False]: 219 | return False 220 | 221 | # 2. symbolic equal 222 | reference = str(reference).strip() 223 | prediction = str(prediction).strip() 224 | 225 | if regex.match(r'(\(|\[).+(\)|\])', prediction) is not None and regex.match(r'(\(|\[).+(\)|\])', reference) is not None: 226 | pred_parts = prediction[1:-1].split(",") 227 | ref_parts = reference[1:-1].split(",") 228 | if len(pred_parts) == len(ref_parts): 229 | if all([math_equal(pred_parts[i], ref_parts[i], include_percentage, is_close) for i in range(len(pred_parts))]): 230 | return True 231 | 232 | if (prediction.startswith("\\begin{pmatrix}") or prediction.startswith("\\begin{bmatrix}")) and (prediction.endswith("\\end{pmatrix}") or prediction.endswith("\\end{bmatrix}")) and \ 233 | (reference.startswith("\\begin{pmatrix}") or reference.startswith("\\begin{bmatrix}")) and (reference.endswith("\\end{pmatrix}") or reference.endswith("\\end{bmatrix}")): 234 | pred_lines = [line.strip() for line in prediction[len("\\begin{pmatrix}"): -len("\\end{pmatrix}")].split("\\\\") if line.strip()] 
235 | ref_lines = [line.strip() for line in reference[len("\\begin{pmatrix}"): -len("\\end{pmatrix}")].split("\\\\") if line.strip()] 236 | matched = True 237 | if len(pred_lines) == len(ref_lines): 238 | for pred_line, ref_line in zip(pred_lines, ref_lines): 239 | pred_parts = pred_line.split("&") 240 | ref_parts = ref_line.split("&") 241 | if len(pred_parts) == len(ref_parts): 242 | if not all([math_equal(pred_parts[i], ref_parts[i], include_percentage, is_close) for i in range(len(pred_parts))]): 243 | matched = False 244 | break 245 | else: 246 | matched = False 247 | if not matched: 248 | break 249 | else: 250 | matched = False 251 | if matched: 252 | return True 253 | 254 | if prediction.count('=') == 1 and reference.count('=') == 1: 255 | pred = prediction.split('=') 256 | pred = f"{pred[0].strip()} - ({pred[1].strip()})" 257 | ref = reference.split('=') 258 | ref = f"{ref[0].strip()} - ({ref[1].strip()})" 259 | if symbolic_equal(pred, ref) or symbolic_equal(f"-({pred})", ref): 260 | return True 261 | elif prediction.count('=') == 1 and len(prediction.split('=')[0].strip()) <= 2 and '=' not in reference: 262 | if math_equal(prediction.split('=')[1], reference, include_percentage, is_close): 263 | return True 264 | elif reference.count('=') == 1 and len(reference.split('=')[0].strip()) <= 2 and '=' not in prediction: 265 | if math_equal(prediction, reference.split('=')[1], include_percentage, is_close): 266 | return True 267 | 268 | # symbolic equal with sympy 269 | if timeout: 270 | if call_with_timeout(symbolic_equal_process, prediction, reference): 271 | return True 272 | else: 273 | if symbolic_equal(prediction, reference): 274 | return True 275 | 276 | return False 277 | 278 | 279 | def math_equal_process(param): 280 | return math_equal(param[-2], param[-1]) 281 | 282 | 283 | def symbolic_equal(a, b): 284 | def _parse(s): 285 | for f in [parse_latex, parse_expr]: 286 | try: 287 | return f(s) 288 | except: 289 | pass 290 | return s 291 | a = _parse(a) 292 
def normalize_numeric(s):
    """Turn an answer string with optional units into a float.

    A fixed list of unit spellings and ``$`` delimiters is removed, then the
    remainder is evaluated as a Python expression; if that fails it is parsed
    as LaTeX instead.

    Args:
        s: The answer string, or ``None``.

    Returns:
        ``None`` if ``s`` is ``None``; the numeric value as ``float`` on
        success; ``INVALID_ANSWER`` if the text cannot be reduced to a number.
    """
    if s is None:
        return None
    for unit in [
        "eV",
        " \\mathrm{~kg} \\cdot \\mathrm{m} / \\mathrm{s}",
        " kg m/s",
        "kg*m/s",
        "kg",
        "m/s",
        "m / s",
        "m s^{-1}",
        "\\text{ m/s}",
        " \\mathrm{m/s}",
        " \\text{ m/s}",
        "g/mole",
        "g/mol",
        "\\mathrm{~g}",
        "\\mathrm{~g} / \\mathrm{mol}",
        "W",
        "erg/s",
        "years",
        "year",
        "cm",
    ]:
        s = s.replace(unit, "")
        s = s.strip()
    for maybe_unit in ["m", "s", "cm"]:
        s = s.replace("\\mathrm{" + maybe_unit + "}", "")
        s = s.replace("\\mathrm{~" + maybe_unit + "}", "")
        s = s.strip()
    s = s.strip("$")
    try:
        # SECURITY NOTE(review): ``eval`` runs arbitrary Python. ``s`` can be
        # model-generated text, so this should only run in a sandboxed
        # evaluation environment; consider a restricted arithmetic parser.
        # The bare ``except`` is kept on purpose: evaluated text may raise
        # anything, including ``SystemExit``.
        return float(eval(s))
    except:
        try:
            expr = parse_latex(s)
            if expr.is_number:
                return float(expr)
            return INVALID_ANSWER
        except:
            return INVALID_ANSWER


def numeric_equality(n1, n2, threshold=0.01):
    """Compare two numbers for approximate equality.

    Exactly equal values always match. When either value or their difference
    is close to zero, the values match if they differ by less than
    ``threshold`` times the magnitude of their mean; otherwise ``np.isclose``
    with its default tolerances decides.

    The tolerance uses the magnitude of the mean: with a signed mean, two
    equal negative numbers (or two zeros) would be reported as different.

    Args:
        n1: First number, or ``None``.
        n2: Second number, or ``None``.
        threshold: Relative tolerance used in the near-zero branch.

    Returns:
        ``True`` if the numbers are considered equal, ``False`` otherwise or
        when either argument is ``None``.
    """
    if n1 is None or n2 is None:
        return False
    if n1 == n2:
        return True
    if np.isclose(n1, 0) or np.isclose(n2, 0) or np.isclose(n1 - n2, 0):
        return bool(np.abs(n1 - n2) < threshold * np.abs(n1 + n2) / 2)
    return bool(np.isclose(n1, n2))
102 | """ 103 | 104 | SUBSTITUTIONS = [ # used for text normalize 105 | ("an ", ""), 106 | ("a ", ""), 107 | (".$", "$"), 108 | ("\\$", ""), 109 | (r"\ ", ""), 110 | (" ", ""), 111 | ("mbox", "text"), 112 | (",\\text{and}", ","), 113 | ("\\text{and}", ","), 114 | ("\\text{m}", "\\text{}"), 115 | ] 116 | REMOVED_EXPRESSIONS = [ # used for text normalizer 117 | "square", 118 | "ways", 119 | "integers", 120 | "dollars", 121 | "mph", 122 | "inches", 123 | "ft", 124 | "hours", 125 | "km", 126 | "units", 127 | "\\ldots", 128 | "sue", 129 | "points", 130 | "feet", 131 | "minutes", 132 | "digits", 133 | "cents", 134 | "degrees", 135 | "cm", 136 | "gm", 137 | "pounds", 138 | "meters", 139 | "meals", 140 | "edges", 141 | "students", 142 | "childrentickets", 143 | "multiples", 144 | "\\text{s}", 145 | "\\text{.}", 146 | "\\text{\ns}", 147 | "\\text{}^2", 148 | "\\text{}^3", 149 | "\\text{\n}", 150 | "\\text{}", 151 | r"\mathrm{th}", 152 | r"^\circ", 153 | r"^{\circ}", 154 | r"\;", 155 | r",\!", 156 | "{,}", 157 | '"', 158 | "\\dots", 159 | ] 160 | 161 | def normalize_tex(self, final_answer: str) -> str: 162 | """ 163 | Normalizes a string representing a mathematical expression. 164 | Used as a preprocessing step before parsing methods. 165 | 166 | Copied character for character from appendix D of Lewkowycz et al. (2022) 167 | """ 168 | final_answer = final_answer.split("=")[-1] 169 | 170 | for before, after in self.SUBSTITUTIONS: 171 | final_answer = final_answer.replace(before, after) 172 | for expr in self.REMOVED_EXPRESSIONS: 173 | final_answer = final_answer.replace(expr, "") 174 | 175 | # Extract answer that is in LaTeX math, is bold, 176 | # is surrounded by a box, etc. 
177 | final_answer = re.sub(r"(.*?)(\$)(.*?)(\$)(.*)", "$\\3$", final_answer) 178 | final_answer = re.sub(r"(\\text\{)(.*?)(\})", "\\2", final_answer) 179 | final_answer = re.sub(r"(\\textbf\{)(.*?)(\})", "\\2", final_answer) 180 | final_answer = re.sub(r"(\\overline\{)(.*?)(\})", "\\2", final_answer) 181 | final_answer = re.sub(r"(\\boxed\{)(.*)(\})", "\\2", final_answer) 182 | 183 | # Normalize shorthand TeX: 184 | # \fracab -> \frac{a}{b} 185 | # \frac{abc}{bef} -> \frac{abc}{bef} 186 | # \fracabc -> \frac{a}{b}c 187 | # \sqrta -> \sqrt{a} 188 | # \sqrtab -> sqrt{a}b 189 | final_answer = re.sub(r"(frac)([^{])(.)", "frac{\\2}{\\3}", final_answer) 190 | final_answer = re.sub(r"(sqrt)([^{])", "sqrt{\\2}", final_answer) 191 | final_answer = final_answer.replace("$", "") 192 | 193 | # Normalize 100,000 -> 100000 194 | if final_answer.replace(",", "").isdigit(): 195 | final_answer = final_answer.replace(",", "") 196 | 197 | return final_answer 198 | 199 | def parse_tex(self, text: str, time_limit: int = 5) -> sympy.Basic: 200 | """ 201 | Wrapper around `sympy.parse_text` that outputs a SymPy expression. 202 | Typically, you want to apply `normalize_text` as a preprocessing step. 203 | """ 204 | try: 205 | with timeout(seconds=time_limit): 206 | parsed = parse_latex(text) 207 | except ( 208 | # general error handling: there is a long tail of possible sympy/other 209 | # errors we would like to catch 210 | Exception 211 | ) as e: 212 | print(f"failed to parse {text} with exception {e}") 213 | return None 214 | 215 | return parsed 216 | 217 | def is_exp_equiv(self, x1: sympy.Basic, x2: sympy.Basic, time_limit=5) -> bool: 218 | """ 219 | Determines whether two sympy expressions are equal. 
220 | """ 221 | try: 222 | with timeout(seconds=time_limit): 223 | try: 224 | diff = x1 - x2 225 | except (SympifyError, ValueError, TypeError) as e: 226 | print( 227 | f"Couldn't subtract {x1} and {x2} with exception {e}" 228 | ) 229 | return False 230 | 231 | try: 232 | if sympy.simplify(diff) == 0: 233 | return True 234 | else: 235 | return False 236 | except (SympifyError, ValueError, TypeError) as e: 237 | print(f"Failed to simplify {x1}-{x2} with {e}") 238 | return False 239 | except TimeoutError as e: 240 | print(f"Timed out comparing {x1} and {x2}") 241 | return False 242 | except Exception as e: 243 | print(f"failed on unrecognized exception {e}") 244 | return False 245 | 246 | def is_tex_equiv(self, x1: str, x2: str, time_limit=5) -> bool: 247 | """ 248 | Determines whether two (ideally normalized using `normalize_text`) TeX expressions are equal. 249 | 250 | Does so by first checking for string exact-match, then falls back on sympy-equivalence, 251 | following the (Lewkowycz et al. 2022) methodology. 252 | """ 253 | if x1 == x2: 254 | # don't resort to sympy if we have full string match, post-normalization 255 | return True 256 | else: 257 | return False 258 | parsed_x2 = self.parse_tex(x2) 259 | if not parsed_x2: 260 | # if our reference fails to parse into a Sympy object, 261 | # we forgo parsing + checking our generated answer. 
262 | return False 263 | return self.is_exp_equiv(self.parse_tex(x1), parsed_x2, time_limit=time_limit) 264 | -------------------------------------------------------------------------------- /evaluation/eval/python_executor.py: -------------------------------------------------------------------------------- 1 | import os 2 | import io 3 | from contextlib import redirect_stdout 4 | import pickle 5 | import regex 6 | import copy 7 | from typing import Any, Dict, Optional 8 | import multiprocess 9 | from pebble import ProcessPool 10 | from concurrent.futures import TimeoutError 11 | from functools import partial 12 | import traceback 13 | from timeout_decorator import timeout 14 | 15 | class GenericRuntime: 16 | GLOBAL_DICT = {} 17 | LOCAL_DICT = None 18 | HEADERS = [] 19 | def __init__(self): 20 | self._global_vars = copy.copy(self.GLOBAL_DICT) 21 | self._local_vars = copy.copy(self.LOCAL_DICT) if self.LOCAL_DICT else None 22 | 23 | for c in self.HEADERS: 24 | self.exec_code(c) 25 | 26 | def exec_code(self, code_piece: str) -> None: 27 | if regex.search(r'(\s|^)?input\(', code_piece) or regex.search(r'(\s|^)?os.system\(', code_piece): 28 | raise RuntimeError() 29 | exec(code_piece, self._global_vars) 30 | 31 | def eval_code(self, expr: str) -> Any: 32 | return eval(expr, self._global_vars) 33 | 34 | def inject(self, var_dict: Dict[str, Any]) -> None: 35 | for k, v in var_dict.items(): 36 | self._global_vars[k] = v 37 | 38 | @property 39 | def answer(self): 40 | return self._global_vars['answer'] 41 | 42 | class PythonExecutor: 43 | def __init__( 44 | self, 45 | runtime: Optional[Any] = None, 46 | get_answer_symbol: Optional[str] = None, 47 | get_answer_expr: Optional[str] = None, 48 | get_answer_from_stdout: bool = False, 49 | ) -> None: 50 | self.runtime = runtime if runtime else GenericRuntime() 51 | self.answer_symbol = get_answer_symbol 52 | self.answer_expr = get_answer_expr 53 | self.get_answer_from_stdout = get_answer_from_stdout 54 | 55 | def 
process_generation_to_code(self, gens: str): 56 | batch_code = [] 57 | for g in gens: 58 | multiline_comments = False 59 | code = [] 60 | for line in g.split('\n'): 61 | strip_line = line.strip() 62 | if strip_line.startswith("#"): 63 | line = line.split("#", 1)[0] + "# comments" 64 | elif not multiline_comments and strip_line.startswith('"""') and strip_line.endswith('"""') and len(strip_line) >= 6: 65 | line = line.split('"""', 1)[0] + '"""comments"""' 66 | elif not multiline_comments and strip_line.startswith('"""'): 67 | multiline_comments = True 68 | elif multiline_comments and strip_line.endswith('"""'): 69 | multiline_comments = False 70 | line = "" 71 | if not multiline_comments: 72 | code.append(line) 73 | batch_code.append(code) 74 | return batch_code 75 | 76 | @staticmethod 77 | def execute( 78 | code, 79 | get_answer_from_stdout = None, 80 | runtime = None, 81 | answer_symbol = None, 82 | answer_expr = None, 83 | timeout_length = 10, 84 | ): 85 | try: 86 | if get_answer_from_stdout: 87 | program_io = io.StringIO() 88 | with redirect_stdout(program_io): 89 | timeout(timeout_length)(runtime.exec_code)('\n'.join(code)) 90 | program_io.seek(0) 91 | result = "".join(program_io.readlines()) # [-1] 92 | elif answer_symbol: 93 | timeout(timeout_length)(runtime.exec_code)('\n'.join(code)) 94 | result = runtime._global_vars[answer_symbol] 95 | elif answer_expr: 96 | timeout(timeout_length)(runtime.exec_code)('\n'.join(code)) 97 | result = timeout(timeout_length)(runtime.eval_code)(answer_expr) 98 | else: 99 | timeout(timeout_length)(runtime.exec_code)('\n'.join(code[:-1])) 100 | result = timeout(timeout_length)(runtime.eval_code)(code[-1]) 101 | concise_exec_info = "" 102 | exec_info = "" 103 | str(result) 104 | pickle.dumps(result) # serialization check 105 | except: 106 | # traceback.print_exc() 107 | result = '' 108 | concise_exec_info = traceback.format_exc().split('\n')[-2] 109 | exec_info = traceback.format_exc() 110 | if get_answer_from_stdout and 
'exec(code_piece, self._global_vars)' in exec_info: 111 | exec_info = exec_info.split('exec(code_piece, self._global_vars)')[-1].strip() 112 | msg = [] 113 | for line in exec_info.split("\n"): 114 | patt = regex.search(r'(?P.*)File "(?P.*)", line (?P\d+), (?P.*)', line) 115 | if patt is not None: 116 | if '' in patt.group('end'): 117 | continue 118 | fname = patt.group("file") 119 | if "site-packages" in fname: 120 | fname = f"site-packages{fname.split('site-packages', 1)[1]}" 121 | line = f'{patt.group("start")}File "{fname}", {patt.group("end")}' 122 | else: 123 | line = f'{patt.group("start")}{patt.group("end")}' 124 | else: 125 | patt = regex.search(r'(?P.*)(?P/.*site-packages/.*\.py)(?P.*)', line) 126 | if patt is not None: 127 | line = f'{patt.group("start")}site-packages{patt.group("file").split("site-packages", 1)[1]}{patt.group("end")}' 128 | msg.append(line) 129 | exec_info = "\n".join(msg) 130 | return result, concise_exec_info, exec_info 131 | 132 | def apply(self, code): 133 | return self.batch_apply([code])[0] 134 | 135 | def batch_apply(self, batch_code): 136 | all_code_snippets = self.process_generation_to_code(batch_code) 137 | all_exec_results = [] 138 | executor = partial( 139 | self.execute, 140 | get_answer_from_stdout=self.get_answer_from_stdout, 141 | runtime=self.runtime, 142 | answer_symbol=self.answer_symbol, 143 | answer_expr=self.answer_expr, 144 | timeout_length=10, 145 | ) 146 | with ProcessPool(max_workers=multiprocess.cpu_count()) as pool: 147 | iterator = pool.map(executor, all_code_snippets, timeout=10).result() 148 | 149 | while True: 150 | try: 151 | result = next(iterator) 152 | all_exec_results.append(result) 153 | except StopIteration: 154 | break 155 | except TimeoutError as error: 156 | all_exec_results.append(("", "Timeout Error", "Timeout Error")) 157 | except Exception as error: 158 | print(error) 159 | exit() 160 | 161 | batch_results = [] 162 | for code, (result, concise_exec_info, exec_info) in zip(all_code_snippets, 
class KeyWordsCriteria(StoppingCriteria):
    """
    Stopping criterion that reports "stop" once every row of the batch has, in
    the tokens after the prompt, a decoded window ending with a stop string.
    """

    def __init__(self, stop_id_sequences, tokenizer, prompt_length):
        assert isinstance(stop_id_sequences[0], list), "stop_id_sequences should be a list of list of ids"
        self.tokenizer = tokenizer
        self.stop_id_sequences = stop_id_sequences
        # Stop strings are compared as decoded text, not as token ids.
        self.stop_sequences = [tokenizer.decode(id_seq) for id_seq in stop_id_sequences]
        print(f"stop sequences: {self.stop_sequences}", flush=True)
        self.prompt_length = prompt_length

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        # Every row is examined before the verdicts are combined.
        verdicts = [
            self._hit_stop_sequence(input_ids[row][self.prompt_length:].tolist())
            for row in range(input_ids.shape[0])
        ]
        return all(verdicts)

    def _hit_stop_sequence(self, generated_ids):
        # True as soon as some window of the generated ids decodes to text
        # ending with one of the stop strings.
        decode = self.tokenizer.decode
        for stop_ids, stop_text in zip(self.stop_id_sequences, self.stop_sequences):
            # The window is a few tokens wider than the stop sequence itself.
            window = len(stop_ids) + 3
            for end in range(len(generated_ids), 0, -1):
                if decode(generated_ids[max(end - window, 0):end]).endswith(stop_text):
                    return True
        return False
tqdm.tqdm(total=len(prompts), desc="Generating Completions") 37 | 38 | if stop_id_sequences is not None: 39 | stop_sequences = [tokenizer.decode(stop_id_sequence) for stop_id_sequence in stop_id_sequences] 40 | 41 | if end_of_generation_id_sequence is not None: 42 | end_of_generation_sequence = tokenizer.decode(end_of_generation_id_sequence) 43 | 44 | num_return_sequences = generation_kwargs.get("num_return_sequences", 1) 45 | generation_kwargs['use_cache'] = True 46 | for i in range(0, len(prompts), batch_size): 47 | batch_prompts = prompts[i:i+batch_size] 48 | tokenized_prompts = tokenizer(batch_prompts, padding="longest", return_tensors="pt", add_special_tokens='chatglm2' in str(model.__class__)) 49 | batch_input_ids = tokenized_prompts.input_ids 50 | attention_mask = tokenized_prompts.attention_mask 51 | 52 | if model.device.type == "cuda": 53 | batch_input_ids = batch_input_ids.cuda() 54 | attention_mask = attention_mask.cuda() 55 | 56 | batch_finish_completion = [False] * len(batch_prompts) * num_return_sequences 57 | try: 58 | batch_outputs = model.generate( 59 | input_ids=batch_input_ids, 60 | attention_mask=attention_mask, 61 | stopping_criteria=[KeyWordsCriteria(stop_id_sequences, tokenizer, batch_input_ids.size(1))] if stop_id_sequences else None, 62 | **generation_kwargs 63 | ) 64 | 65 | # the stopping criteria is applied at batch level, so if other examples are not stopped, the entire batch will continue to generate. 66 | # so some outputs still have the stop sequence, which we need to remove. 
67 | if stop_id_sequences: 68 | for output_idx in range(batch_outputs.shape[0]): 69 | finish = False 70 | for token_idx in range(batch_input_ids.shape[1], batch_outputs.shape[1]): 71 | if any(tokenizer.decode(batch_outputs[output_idx, token_idx: token_idx + len(stop_sequence) + 3]).startswith(stop_sequence) for stop_sequence in stop_sequences): 72 | if end_of_generation_id_sequence is not None and tokenizer.decode(batch_outputs[output_idx, token_idx: token_idx + len(end_of_generation_id_sequence) + 3]).startswith(end_of_generation_sequence): 73 | batch_finish_completion[output_idx] = True 74 | batch_outputs[output_idx, token_idx:] = tokenizer.pad_token_id 75 | break 76 | 77 | # remove the prompt from the output 78 | # we need to re-encode the prompt because we need to make sure the special tokens are treated the same way as in the outputs. 79 | # we changed our previous way of truncating the output token ids dicrectly because some tokenizer (e.g., llama) won't add space token before the first token. 80 | # space is important for some tasks (e.g., code completion). 
@torch.no_grad()
def get_next_word_predictions(model, tokenizer, prompts, candidate_token_ids=None, batch_size=1, return_token_predictions=False, disable_tqdm=False):
    """
    Predict the token that follows each prompt from the model's last-position logits.

    Args:
        model: callable returning an object with `.logits`; `model.device.type`
            decides whether inputs are moved with `.cuda()`.
        tokenizer: callable tokenizer; also used for `convert_ids_to_tokens`.
        prompts: list of prompt strings.
        candidate_token_ids: optional ids; when given, logits are restricted to
            these columns before the softmax.
        batch_size: number of prompts per forward pass.
        return_token_predictions: return token strings instead of indices.
        disable_tqdm: suppress the progress bar.

    Returns:
        `(predictions, probs)`. `predictions` holds one argmax per prompt: an
        index (into `candidate_token_ids` when that is given) or the matching
        token string. `probs` holds the softmax row for each prompt.
    """
    bar = None if disable_tqdm else tqdm.tqdm(total=len(prompts), desc="Getting Predictions")
    restrict = candidate_token_ids is not None
    all_predictions = []
    all_probs = []

    for start in range(0, len(prompts), batch_size):
        chunk = prompts[start:start + batch_size]
        encoded = tokenizer(chunk, padding="longest", return_tensors="pt", add_special_tokens=False)
        input_ids, mask = encoded.input_ids, encoded.attention_mask
        if model.device.type == "cuda":
            input_ids, mask = input_ids.cuda(), mask.cuda()

        # Only the logits at the final position of each row are used.
        last_logits = model(input_ids=input_ids, attention_mask=mask).logits[:, -1, :]
        if restrict:
            last_logits = last_logits[:, candidate_token_ids]
        chunk_probs = torch.softmax(last_logits, dim=-1)
        best = torch.argmax(chunk_probs, dim=-1)

        if not return_token_predictions:
            all_predictions.extend(best.tolist())
        elif restrict:
            names = tokenizer.convert_ids_to_tokens(candidate_token_ids)
            all_predictions.extend([names[idx] for idx in best])
        else:
            all_predictions.extend(tokenizer.convert_ids_to_tokens(best))
        all_probs.extend(chunk_probs.tolist())

        if bar is not None:
            bar.update(len(chunk))

    assert len(all_predictions) == len(prompts), "number of predictions should be equal to number of prompts"
    return all_predictions, all_probs
def load_hf_lm_and_tokenizer(
        model_name_or_path,
        tokenizer_name_or_path=None,
        device_map="auto",
        load_in_8bit=False,
        load_in_half=False,
        gptq_model=False,
        use_fast_tokenizer=True,
        padding_side="left",
    ):
    """
    Load a Hugging Face model and its tokenizer, returned as `(model, tokenizer)`.

    Args:
        model_name_or_path: model identifier or local path.
        tokenizer_name_or_path: tokenizer identifier or path; defaults to the model's.
        device_map: forwarded to `from_pretrained`; when falsy the model is
            loaded without it and moved with `.cuda()` if CUDA is available.
        load_in_8bit: load through `AutoModelForCausalLM` with `load_in_8bit=True`.
        load_in_half: call `.half()` on the model (skipped for chatglm2 and qwen).
        gptq_model: load through `auto_gptq` instead of transformers.
        use_fast_tokenizer: `use_fast` flag for tokenizers other than chatglm2/qwen.
        padding_side: padding side set on the tokenizer.

    chatglm2 and qwen checkpoints are recognised by name and get
    `trust_remote_code=True`; qwen additionally gets a fixed eos/pad token and
    a greedy generation config. The model is put in eval mode before returning.
    """

    from transformers import AutoModelForCausalLM, AutoModel, AutoTokenizer

    if not tokenizer_name_or_path:
        tokenizer_name_or_path = model_name_or_path

    # Both paths are matched case-insensitively. Only the tokenizer path used
    # to be lower-cased, so a model path such as "Qwen/..." went unrecognised
    # whenever a separate tokenizer path was supplied.
    names = (tokenizer_name_or_path.lower(), model_name_or_path.lower())
    is_chatglm2 = any('chatglm2' in name for name in names)
    is_qwen = any('qwen' in name for name in names)

    if is_chatglm2 or is_qwen:
        tokenizer = AutoTokenizer.from_pretrained(tokenizer_name_or_path, trust_remote_code=True)
        if is_qwen:
            tokenizer.eos_token = '<|endoftext|>'
            tokenizer.eos_token_id = 151643
            tokenizer.pad_token = tokenizer.eos_token
            tokenizer.pad_token_id = tokenizer.eos_token_id
    else:
        tokenizer = AutoTokenizer.from_pretrained(tokenizer_name_or_path, trust_remote_code=True, use_fast=use_fast_tokenizer)
    # set padding side to left for batch generation
    tokenizer.padding_side = padding_side
    # set pad token to eos token if pad token is not set (as is the case for llama models)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
        tokenizer.pad_token_id = tokenizer.eos_token_id

    if gptq_model:
        from auto_gptq import AutoGPTQForCausalLM
        model_wrapper = AutoGPTQForCausalLM.from_quantized(
            model_name_or_path, device="cuda:0", use_triton=True
        )
        model = model_wrapper.model
    elif load_in_8bit:
        model = AutoModelForCausalLM.from_pretrained(
            model_name_or_path,
            device_map=device_map,
            load_in_8bit=True
        )
    else:
        kwargs = {}
        model_class = AutoModelForCausalLM
        if is_chatglm2:
            kwargs = {'trust_remote_code': True}
            model_class = AutoModel
        elif is_qwen:
            kwargs = {'trust_remote_code': True}
        if device_map:
            model = model_class.from_pretrained(model_name_or_path, device_map=device_map, **kwargs)
        else:
            model = model_class.from_pretrained(model_name_or_path, **kwargs)
            if torch.cuda.is_available():
                model = model.cuda()
        if is_qwen:
            model.generation_config = GenerationConfig.from_pretrained(model_name_or_path, trust_remote_code=True)
            model.generation_config.do_sample = False
        if not is_chatglm2 and not is_qwen and load_in_half:
            model = model.half()
    model.eval()
    return model, tokenizer
"n_samples": 0 10 | } 11 | }, 12 | "cmath-cot-test": { 13 | "cot": { 14 | "accuracy": 0.7167577413479053, 15 | "n_samples": 1098 16 | }, 17 | "tool": { 18 | "n_samples": 0 19 | } 20 | }, 21 | "miniF2F-Isabelle-test": { 22 | "cot": { 23 | "accuracy": 1.0, 24 | "n_samples": 244 25 | }, 26 | "tool": { 27 | "n_samples": 0 28 | } 29 | }, 30 | "gsm8k-cot-test": { 31 | "cot": { 32 | "accuracy": 0.6421531463229719, 33 | "n_samples": 1319 34 | }, 35 | "tool": { 36 | "n_samples": 0 37 | } 38 | }, 39 | "MMLU-STEM-test": { 40 | "cot": { 41 | "accuracy": 0.5646123260437376, 42 | "n_samples": 3018 43 | }, 44 | "tool": { 45 | "n_samples": 0 46 | } 47 | }, 48 | "agieval-gaokao-mathqa-cot-test": { 49 | "cot": { 50 | "accuracy": 0.35327635327635326, 51 | "n_samples": 351 52 | }, 53 | "tool": { 54 | "n_samples": 0 55 | } 56 | }, 57 | "agieval-gaokao-mathcloze-cot-test": { 58 | "cot": { 59 | "accuracy": 0.2033898305084746, 60 | "n_samples": 118 61 | }, 62 | "tool": { 63 | "n_samples": 0 64 | } 65 | }, 66 | "gsm8k-pal-test": { 67 | "cot": { 68 | "n_samples": 0 69 | }, 70 | "tool": { 71 | "accuracy": 0.66868840030326, 72 | "n_samples": 1319 73 | } 74 | }, 75 | "math_sat": { 76 | "cot": { 77 | "accuracy": 0.84375, 78 | "n_samples": 32 79 | }, 80 | "tool": { 81 | "n_samples": 0 82 | } 83 | }, 84 | "miniF2F-Isabelle-valid": { 85 | "cot": { 86 | "accuracy": 1.0, 87 | "n_samples": 244 88 | }, 89 | "tool": { 90 | "n_samples": 0 91 | } 92 | }, 93 | "math-pal-test": { 94 | "cot": { 95 | "n_samples": 0 96 | }, 97 | "tool": { 98 | "accuracy": 0.3142, 99 | "n_samples": 5000 100 | } 101 | }, 102 | "math-cot-test": { 103 | "cot": { 104 | "accuracy": 0.3618, 105 | "n_samples": 5000 106 | }, 107 | "tool": { 108 | "n_samples": 0 109 | } 110 | } 111 | }, 112 | "DeepSeekMath-RL": { 113 | "mgsm-zh": { 114 | "cot": { 115 | "accuracy": 0.796, 116 | "n_samples": 250 117 | }, 118 | "tool": { 119 | "accuracy": 0.784, 120 | "program_accuracy": 0.776, 121 | "n_samples": 250 122 | } 123 | }, 124 | "cmath": { 125 
| "cot": { 126 | "accuracy": 0.8879781420765027, 127 | "n_samples": 1098 128 | }, 129 | "tool": { 130 | "accuracy": 0.8761384335154827, 131 | "program_accuracy": 0.8570127504553734, 132 | "n_samples": 1098 133 | } 134 | }, 135 | "math-test": { 136 | "cot": { 137 | "accuracy": 0.517, 138 | "n_samples": 5000 139 | }, 140 | "tool": { 141 | "accuracy": 0.5878, 142 | "program_accuracy": 0.509, 143 | "n_samples": 5000 144 | } 145 | }, 146 | "gsm8k-test": { 147 | "cot": { 148 | "accuracy": 0.8824867323730099, 149 | "n_samples": 1319 150 | }, 151 | "tool": { 152 | "accuracy": 0.866565579984837, 153 | "program_accuracy": 0.868081880212282, 154 | "n_samples": 1319 155 | } 156 | } 157 | }, 158 | "DeepSeekMath-Instruct": { 159 | "gsm8k-test": { 160 | "cot": { 161 | "accuracy": 0.8286580742987112, 162 | "n_samples": 1319 163 | }, 164 | "tool": { 165 | "accuracy": 0.8369977255496588, 166 | "program_accuracy": 0.8332069749810462, 167 | "n_samples": 1319 168 | } 169 | }, 170 | "math-test": { 171 | "cot": { 172 | "accuracy": 0.4682, 173 | "n_samples": 5000 174 | }, 175 | "tool": { 176 | "accuracy": 0.575, 177 | "program_accuracy": 0.4664, 178 | "n_samples": 5000 179 | } 180 | }, 181 | "cmath": { 182 | "cot": { 183 | "accuracy": 0.8460837887067395, 184 | "n_samples": 1098 185 | }, 186 | "tool": { 187 | "accuracy": 0.843351548269581, 188 | "program_accuracy": 0.8214936247723132, 189 | "n_samples": 1098 190 | } 191 | }, 192 | "mgsm-zh": { 193 | "cot": { 194 | "accuracy": 0.732, 195 | "n_samples": 250 196 | }, 197 | "tool": { 198 | "accuracy": 0.72, 199 | "program_accuracy": 0.716, 200 | "n_samples": 250 201 | } 202 | } 203 | } 204 | } -------------------------------------------------------------------------------- /evaluation/few_shot_prompts/__init__.py: -------------------------------------------------------------------------------- 1 | from .cot_minerva_math_4_shot import MinervaMathPrompt 2 | from .cot_gsm_8_shot import CoTGSMPrompt 3 | from .cot_math_sat_4_shot import 
CoTSATPrompt 4 | from .cot_mmlu_stem_4_shot import MMLUSTEMPrompt 5 | from .cot_ocwcourses_4_shot import OCWCoursesPrompt 6 | from .pal_gsm_8_shot import PALGSMPrompt 7 | from .pal_math_4_shot import PALMathPrompt 8 | from .minif2f_isabelle import MiniF2FIsabellePrompt 9 | from .cot_cmath_6_shot import CoTCMATHPrompt 10 | from .cot_gaokao_mathcloze_5_shot import CoTGaoKaoMathClozePrompt 11 | from .cot_gaokao_mathqa_5_shot import CoTGaoKaoMathQAPrompt -------------------------------------------------------------------------------- /evaluation/few_shot_prompts/cot_cmath_6_shot.py: -------------------------------------------------------------------------------- 1 | from .few_shot_prompting import FewShotPrompting 2 | 3 | few_shot_prompt = """ 4 | 问题:芳芳买了一本书有99页,看了90页,她还剩多少页没有看? 5 | 答案:还剩的没有看的页数=书的总页数-芳芳看了的页数,99-90=9。所以答案是:9。 6 | 7 | 8 | 问题:张师傅上午修了18把椅子,下午修了29把椅子,一天共修了多少把椅子? 9 | 答案:一天共修的椅子数量=上午修的椅子数量+下午修的椅子数量,18+29=47。所以答案是:47。 10 | 11 | 12 | 问题:小猴摘了84个桃子,平均分给6只猴子,每只猴子能吃到几个桃子? 13 | 答案:每只猴子能吃到的桃子数=总桃子数/猴子的数量,84/6=14。所以答案是:14。 14 | 15 | 16 | 问题:用面包机烤面包时,第一面烤2分钟,第二面只要烤1分钟,即烤一片面包需要3分钟,小勤的面包机一次只能放2片,他每天早上吃3片面包,至少需要烤多少分钟? 17 | 答案:可以现将两片面包放入面包机烤2分钟,再将其中一片拿出来,将第三片面包放进去,烤1分钟,这样第一片面包就烤好了,将第一片面包拿出来将第二片面包放进去,继续烤1分钟,于是第二片面包也烤好了将其拿出来,第三片面包再烤1分钟也就烤好了,一共是2+1+1=5。所以答案是:5。 18 | 19 | 20 | 问题:一组学生植树,每人栽6棵还剩4棵;如果其中3人各栽5棵,其余每人各栽7棵,正好栽完。这一组学生有多少人? 21 | 答案:假设学生的数量是x,每人栽6棵还剩4棵,也就是说树苗的数量=6x+4,又知道如果其中3人各栽5棵,其余每人各栽7棵,正好栽完,即6x+4=3*5+(x-3)*7,化简方程得到:x=10。所以答案是:10。 22 | 23 | 24 | 问题:某小学在“献爱心--为汶川地震区捐款”活动中,六年级五个班共捐款8000元,其中一班捐款1500元,二班比一班多捐款200元,三班捐款1600元,四班与五班捐款数之比是3:5.四班捐款多少元? 
class CoTCMATHPrompt(FewShotPrompting):
    """Few-shot chain-of-thought prompt for CMATH, built on the module's `few_shot_prompt`."""

    def __init__(self):
        super().__init__()

    def format_prompt(self, task_input, task_output):
        # The new question/answer pair is appended after the few-shot examples,
        # separated by two blank lines; trailing whitespace is removed.
        query = f"问题:{task_input}\n答案:{task_output}"
        return "\n\n\n".join((few_shot_prompt, query)).rstrip()

    def stop_words(self):
        # Generation should stop when the model begins a new question.
        return ["\n问题:"]
class CoTGaoKaoMathClozePrompt(FewShotPrompting):
    """Few-shot chain-of-thought prompt for Gaokao math cloze, built on the module's `few_shot_prompt`."""

    def __init__(self):
        super().__init__()

    def format_prompt(self, task_input, task_output):
        # The task is appended as problem 6, after the few-shot examples and
        # two blank lines; trailing whitespace is removed.
        query = f"问题 6. {task_input}\n问题 6的解析: {task_output}"
        return "\n\n\n".join((few_shot_prompt, query)).rstrip()

    def stop_words(self):
        # Generation should stop when the model begins another numbered problem.
        return ["\n问题 "]
已知 $\\alpha, \\beta, \\gamma$ 是互不相同的锐角, 则在 $\\sin \\alpha \\cos \\beta, \\sin \\beta \\cos \\gamma, \\sin \\gamma \\cos \\alpha$ 三个值中, 大于 $\\frac{1}{2}$ 的个数的最大值是 ($\\quad$) 5 | 从以下选项中选择: (A)0 (B)1 (C)2 (D)3 6 | 问题 1的解析: 1. 如果 $\\alpha, \\beta, \\gamma$ 均小于 $60^\\circ$,那么他们的正弦值都小于 $\\frac{1}{2}$,因此三个值中不可能有大于 $\\frac{1}{2}$ 的值。 7 | 2. 如果有一个角大于 $60^\\circ$,假设为 $\\alpha$,那么对应的正弦值大于 $\\frac{1}{2}$。此时,由于三角形内角和为 $180^\\circ$,所以 $\\beta + \\gamma < 120^\\circ$。这意味着 $\\beta, \\gamma$ 的余弦值均大于 $\\frac{1}{2}$,所以此时 $\\sin \\alpha \\cos \\beta > \\frac{1}{2}, \\sin \\beta \\cos \\gamma > \\frac{1}{2}$。 8 | 3. 如果有两个角大于 $60^\\circ$,例如 $\\alpha$ 和 $\\beta$,那么由于三角形内角和为 $180^\\circ$,我们可以得到 $\\gamma < 60^\\circ$,此时 $\\sin \\gamma < \\frac{1}{2}$。由于 $\\alpha$ 和 $\\beta$ 9 | 的余弦值都小于 $\\frac{1}{2}$,因此三个值中不可能有大于 $\\frac{1}{2}$ 的值。 10 | 4. 如果三个角都大于 $60^\\circ$,显然不符合题意。 11 | 综上所述,当有一个角大于 $60^\\circ$ 时,大于 $\\frac{1}{2}$ 的个数的最大值是 2。 12 | 答案是 C 13 | 14 | 15 | 问题 2. 正方体 $A B C D-A_{1} B_{1} C_{1} D_{1}$ 中, $B B_{1}$ 与平面 $A C D_{1}$ 所成角的余弦值为 ($\\qquad$) 16 | 从以下选项中选择: (A)$\\frac{\\sqrt{2}}{3}$ (B)$\\frac{\\sqrt{3}}{3}$ (C)$\\frac{2}{3}$ (D)$\\frac{\\sqrt{6}}{3}$ 17 | 问题 2的解析: 设上下底面的中心分别为 $\\mathrm{O}_{1}, \\mathrm{O}$, 设正方体的棱长等于 1 , 则 $O_{1} O$ 与平面 $A C D_{1}$ 所成角就是 $B B_{1}$ 与平面 $A C D_{1}$ 所成角, 即 $\\angle O_{1} O D_{1}$, 18 | 直角三角形 $\\mathrm{OO}_{1} \\mathrm{D}_{1}$ 中, $\\cos \\angle \\mathrm{O}_{1} \\mathrm{OD}_{1}=\\frac{\\mathrm{O}_{1} \\mathrm{O}}{\\mathrm{OD}_{1}}=\\frac{\\frac{1}{\\sqrt{6}}}{2}=\\frac{\\sqrt{6}}{3}$. 19 | 答案是 C 20 | 21 | 22 | 问题 3. 设函数 $f(x)=\\left\\{\\begin{array}{ll}1+\\log _{2}(2-x), & x<1 \\ 2^{x-1}, & x \\geqslant 1,\\end{array}\\right.$ 则 $f(-2)+f\\left(\\log _{2} 12\\right)=$ ($\\qquad$) 23 | 从以下选项中选择: (A)3 (B)6 (C)9 (D)12 24 | 问题 3的解析: 首先,我们可以根据定义计算 $f(-2)$ 和 $f(\\log_2 12)$: 25 | $f(-2)=1+\\log_2(2-(-2))=1+\\log_2 4=3$ 26 | $f(\\log_2 12)=2^{\\log_2 12-1}=6$ 27 | 因此,$f(-2)+f(\\log_2 12)=3+6=9$。 28 | 答案是 C 29 | 30 | 31 | 问题 4. 
已知函数 $f(x)=a x^{3}-3 x^{2}+1$, 若 $f(x)$ 存在唯一的零点 $x_{0}$, 且 $x_{0}>$ 0 , 则实数 $\\mathrm{a}$ 的取值范围是 ($\\qquad$) 32 | 从以下选项中选择: (A)$(1,+\\infty)$ (B)$(2,+\\infty)$ (C)$(-\\infty,-1)$ (D)$(-\\infty,-2)$ 33 | 问题 4的解析: 首先,我们可以通过求出函数的导函数 $f'(x)$ 来判断函数在 $x>0$ 区间内的单调性。在这里,我们求出导函数 $f'(x)$ 为 $f'(x)=3ax^2-6x$。 34 | 然后,我们需要求出导函数 $f'(x)$ 的零点,以确定函数 $f(x)$ 在 $x>0$ 区间内的单调性。导函数 $f'(x)$ 的零点为 $x=0$ 和 $x=\\frac{2}{\\sqrt{a}}$。注意到 $x>0$,所以我们得到 $a<0$。此外,由于函数 $f(x)$ 在 $x=0$ 处的函数值为 $1$,因此不能有 $a=\\frac{4}{3}$。 35 | 综上所述,当 $a$ 的取值范围为 $a<-\\frac{4}{3}$ 时,函数 $f(x)$ 在 $x>0$ 区间内是单调递减的,此时存在唯一的零点 $x_0$。因此,答案为 $(-\\infty,-2)$。 36 | 答案是 D 37 | 38 | 39 | 问题 5. 设 $\\left\\{a_{n}\\right\\}$ 是公差不为 0 的无穷等差数列, 则“ $\\left\\{a_{n}\\right\\}$ 为递增数列”是“存在正整数 $N_{0}$, 当 $n>N_{0}$ 时, $a_{n}>0$ ”的 ($\\quad$) 40 | 从以下选项中选择: (A)充分而不必要条件 (B)必要而不充分条件 (C)充分必要条件 (D)既不充分也不必要条件 41 | 问题 5的解析: 首先,我们可以通过举例来判断该条件是充分还是必要条件。如果一个数列递增,那么它的公差一定大于 0,也就是存在正整数 $N_{0}$,当 $n>N_{0}$ 时,$a_{n}>0$。因此,“ $\\left\\{a_{n}\\right\\}$ 为递增数列”是“存在正整数 $N_{0}$, 当 $n>N_{0}$ 时, $a_{n}>0$ ”的必要条件。 42 | 接下来,我们需要判断是否充分。也就是说,如果存在正整数 $N_{0}$,当 $n>N_{0}$ 时,$a_{n}>0$,那么能否得出“ $\\left\\{a_{n}\\right\\}$ 为递增数列”这一结论。 43 | 答案是肯定的。因为如果 $a_{n}>0$,那么 $a_{n+1}-a_{n}>0$,即公差大于 0,因此该数列是递增的。因此,该条件是充分条件。 44 | 综上所述,选项为 (C) 充分必要条件。 45 | 答案是 C 46 | """.strip() 47 | 48 | class CoTGaoKaoMathQAPrompt(FewShotPrompting): 49 | def __init__(self): 50 | super().__init__() 51 | 52 | def format_prompt(self, task_input, task_output): 53 | prompt = f"{few_shot_prompt}\n\n\n问题 6. {task_input}\n问题 6的解析: {task_output}" 54 | return prompt.rstrip() 55 | 56 | def stop_words(self): 57 | return ["\n问题 "] 58 | -------------------------------------------------------------------------------- /evaluation/few_shot_prompts/cot_gsm_8_shot.py: -------------------------------------------------------------------------------- 1 | from .few_shot_prompting import FewShotPrompting 2 | 3 | few_shot_prompt = """ 4 | Q: There are 15 trees in the grove. Grove workers will plant trees in the grove today. 
After they are done, there will be 21 trees. How many trees did the grove workers plant today? 5 | A: There are 15 trees originally. Then there were 21 trees after some more were planted. So there must have been 21 - 15 = 6. The answer is 6. 6 | 7 | 8 | Q: If there are 3 cars in the parking lot and 2 more cars arrive, how many cars are in the parking lot? 9 | A: There are originally 3 cars. 2 more cars arrive. 3 + 2 = 5. The answer is 5. 10 | 11 | 12 | Q: Leah had 32 chocolates and her sister had 42. If they ate 35, how many pieces do they have left in total? 13 | A: Originally, Leah had 32 chocolates. Her sister had 42. So in total they had 32 + 42 = 74. After eating 35, they had 74 - 35 = 39. The answer is 39. 14 | 15 | 16 | Q: Jason had 20 lollipops. He gave Denny some lollipops. Now Jason has 12 lollipops. How many lollipops did Jason give to Denny? 17 | A: Jason started with 20 lollipops. Then he had 12 after giving some to Denny. So he gave Denny 20 - 12 = 8. The answer is 8. 18 | 19 | 20 | Q: Shawn has five toys. For Christmas, he got two toys each from his mom and dad. How many toys does he have now? 21 | A: Shawn started with 5 toys. If he got 2 toys each from his mom and dad, then that is 4 more toys. 5 + 4 = 9. The answer is 9. 22 | 23 | 24 | Q: There were nine computers in the server room. Five more computers were installed each day, from monday to thursday. How many computers are now in the server room? 25 | A: There were originally 9 computers. For each of 4 days, 5 more computers were added. So 5 * 4 = 20 computers were added. 9 + 20 is 29. The answer is 29. 26 | 27 | 28 | Q: Michael had 58 golf balls. On tuesday, he lost 23 golf balls. On wednesday, he lost 2 more. How many golf balls did he have at the end of wednesday? 29 | A: Michael started with 58 golf balls. After losing 23 on tuesday, he had 58 - 23 = 35. After losing 2 more, he had 35 - 2 = 33 golf balls. The answer is 33. 30 | 31 | 32 | Q: Olivia has $23. She bought five bagels for $3 each. 
How much money does she have left? 33 | A: Olivia had 23 dollars. 5 bagels for 3 dollars each will be 5 x 3 = 15 dollars. So she has 23 - 15 dollars left. 23 - 15 is 8. The answer is 8. 34 | """.strip() 35 | 36 | class CoTGSMPrompt(FewShotPrompting): 37 | def __init__(self): 38 | super().__init__() 39 | 40 | def format_prompt(self, task_input, task_output): 41 | prompt = f"{few_shot_prompt}\n\n\nQ: {task_input}\nA: {task_output}" 42 | return prompt.rstrip() 43 | 44 | def stop_words(self): 45 | return ["\nQ:"] 46 | -------------------------------------------------------------------------------- /evaluation/few_shot_prompts/cot_math_sat_4_shot.py: -------------------------------------------------------------------------------- 1 | from .few_shot_prompting import FewShotPrompting 2 | 3 | few_shot_prompt = """ 4 | Problem: 5 | Find the domain of the expression $\\frac{\\sqrt{x-2}}{\\sqrt{5-x}}$. 6 | What of the following is the right choice? Explain your answer. 7 | (A) [-5,-2), (B) [2,5), (C) [-2,-5), (D) [5,2) 8 | Solution: 9 | The expressions inside each square root must be non-negative. Therefore, $x-2 \\ge 0$, so $x\\ge2$, and $5 - x \\ge 0$, so $x \\le 5$. Also, the denominator cannot be equal to zero, so $5-x>0$, which gives $x<5$. 10 | Therefore, the domain of the expression is $\\boxed{[2,5)}$. 11 | Final Answer: The final answer is (B). I hope it is correct. 12 | 13 | Problem: 14 | If $\\det \\mathbf{A} = 2$ and $\\det \\mathbf{B} = 12,$ then find $\\det (\\mathbf{A} \\mathbf{B}).$ 15 | What of the following is the right choice? Explain your answer. 16 | (A) 14, (B) 4, (C) 2, (D) 24 17 | Solution: 18 | We have that $\\det (\\mathbf{A} \\mathbf{B}) = (\\det \\mathbf{A})(\\det \\mathbf{B}) = (2)(12) = \\boxed{24}.$ 19 | Final Answer: The final answer is (D). I hope it is correct. 20 | 21 | Problem: 22 | Terrell usually lifts two 20-pound weights 12 times. 
If he uses two 15-pound weights instead, how many times must Terrell lift them in order to lift the same total weight? 23 | What of the following is the right choice? Explain your answer. 24 | (A) 12, (B) 20, (C) 16, (D) 15 25 | Solution: 26 | If Terrell lifts two 20-pound weights 12 times, he lifts a total of $2\\cdot 12\\cdot20=480$ pounds of weight. If he lifts two 15-pound weights instead for $n$ times, he will lift a total of $2\\cdot15\\cdot n=30n$ pounds of weight. Equating this to 480 pounds, we can solve for $n$: \\begin{align*} 27 | 30n&=480\\\\ 28 | \\Rightarrow\\qquad n&=480/30=\\boxed{16} 29 | \\end{align*} 30 | Final Answer: The final answer is (C). I hope it is correct. 31 | 32 | Problem: 33 | If the system of equations 34 | 35 | \\begin{align*} 36 | 6x-4y&=a,\\\\ 37 | 6y-9x &=b. 38 | \\end{align*}has a solution $(x, y)$ where $x$ and $y$ are both nonzero, find $\\frac{a}{b},$ assuming $b$ is 39 | nonzero. 40 | What of the following is the right choice? Explain your answer. 41 | (A) $-\\frac{2}{3}$, (B) $\\frac{2}{3}$, (C) $\\frac{1}{3}$, (D) $\\frac{4}{9}$ 42 | Solution: 43 | If we multiply the first equation by $-\\frac{3}{2}$, we obtain 44 | $$6y-9x=-\\frac{3}{2}a.$$Since we also know that $6y-9x=b$, we have 45 | 46 | $$-\\frac{3}{2}a=b\\Rightarrow\\frac{a}{b}=\\boxed{-\\frac{2}{3}}.$$ 47 | Final Answer: The final answer is (A). I hope it is correct. 
48 | """.strip() 49 | 50 | class CoTSATPrompt(FewShotPrompting): 51 | def __init__(self): 52 | super().__init__() 53 | 54 | def format_prompt(self, task_input, task_output): 55 | prompt = f"{few_shot_prompt}\n\nProblem:\n{task_input}\nSolution:\n{task_output}" 56 | return prompt.rstrip() 57 | 58 | def stop_words(self): 59 | return ["\nProblem:"] 60 | -------------------------------------------------------------------------------- /evaluation/few_shot_prompts/cot_minerva_math_4_shot.py: -------------------------------------------------------------------------------- 1 | from .few_shot_prompting import FewShotPrompting 2 | 3 | few_shot_prompt = """Problem: 4 | Find the domain of the expression $\\frac{\\sqrt{x-2}}{\\sqrt{5-x}}$.} 5 | 6 | Solution: 7 | The expressions inside each square root must be non-negative. 8 | Therefore, $x-2 \\ge 0$, so $x\\ge2$, and $5 - x \\ge 0$, so $x \\le 5$. 9 | Also, the denominator cannot be equal to zero, so $5-x>0$, which gives $x<5$. 10 | Therefore, the domain of the expression is $\\boxed{[2,5)}$. 11 | Final Answer: The final answer is $[2,5)$. I hope it is correct. 12 | 13 | Problem: 14 | If $\\det \\mathbf{A} = 2$ and $\\det \\mathbf{B} = 12,$ then find $\\det (\\mathbf{A} \\mathbf{B}).$ 15 | 16 | Solution: 17 | We have that $\\det (\\mathbf{A} \\mathbf{B}) = (\\det \\mathbf{A})(\\det \\mathbf{B}) = (2)(12) = \\boxed{24}.$ 18 | Final Answer: The final answer is $24$. I hope it is correct. 19 | 20 | Problem: 21 | Terrell usually lifts two 20-pound weights 12 times. If he uses two 15-pound weights instead, how many times must Terrell lift them in order to lift the same total weight? 22 | 23 | Solution: 24 | If Terrell lifts two 20-pound weights 12 times, he lifts a total of $2\\cdot 12\\cdot20=480$ pounds of weight. If he lifts two 15-pound weights instead for $n$ times, he will lift a total of $2\\cdot15\\cdot n=30n$ pounds of weight. 
Equating this to 480 pounds, we can solve for $n$: \\begin{align*} 25 | 30n&=480\\\\ 26 | \\Rightarrow\\qquad n&=480/30=\\boxed{16} 27 | \\end{align*} 28 | Final Answer: The final answer is $16$. I hope it is correct. 29 | 30 | Problem: 31 | If the system of equations 32 | 33 | \\begin{align*} 34 | 6x-4y&=a,\\\\ 35 | 6y-9x &=b. 36 | \\end{align*}has a solution $(x, y)$ where $x$ and $y$ are both nonzero, find $\\frac{a}{b},$ assuming $b$ is nonzero. 37 | 38 | Solution: 39 | If we multiply the first equation by $-\\frac{3}{2}$, we obtain 40 | 41 | $$6y-9x=-\\frac{3}{2}a.$$Since we also know that $6y-9x=b$, we have 42 | 43 | $$-\\frac{3}{2}a=b\\Rightarrow\\frac{a}{b}=\\boxed{-\\frac{2}{3}}.$$ 44 | Final Answer: The final answer is $-\\frac{2}{3}$. I hope it is correct.""" 45 | 46 | class MinervaMathPrompt(FewShotPrompting): 47 | def __init__(self): 48 | super().__init__() 49 | 50 | def format_prompt(self, task_input, task_output): 51 | prompt = f"{few_shot_prompt}\n\nProblem:\n{task_input}\n\nSolution:\n{task_output}" 52 | return prompt.rstrip() 53 | 54 | def stop_words(self): 55 | return ["\nProblem:"] 56 | -------------------------------------------------------------------------------- /evaluation/few_shot_prompts/cot_mmlu_stem_4_shot.py: -------------------------------------------------------------------------------- 1 | from .few_shot_prompting import FewShotPrompting 2 | 3 | few_shot_prompt = """Problem: 4 | Find the domain of the expression $\\frac{\\sqrt{x-2}}{\\sqrt{5-x}}$. 5 | What of the following is the right choice? Explain your answer. 6 | (A) [-5,-2), (B) [2,5), (C) [-2,-5), (D) [5,2) 7 | Solution: 8 | The expressions inside each square root must be non-negative. Therefore, $x-2 \\ge 0$, so $x\\ge2$, and $5 - x \\ge 0$, so $x \\le 5$. Also, the denominator cannot be equal to zero, so $5-x>0$, which gives $x<5$. 9 | Therefore, the domain of the expression is $\\boxed{[2,5)}$. 10 | Final Answer: The final answer is (B). I hope it is correct. 
11 | 12 | Problem: 13 | If $\\det \\mathbf{A} = 2$ and $\\det \\mathbf{B} = 12,$ then find $\\det (\\mathbf{A} \\mathbf{B}).$ 14 | What of the following is the right choice? Explain your answer. 15 | (A) 14, (B) 4, (C) 2, (D) 24 16 | Solution: 17 | We have that $\\det (\\mathbf{A} \\mathbf{B}) = (\\det \\mathbf{A})(\\det \\mathbf{B}) = (2)(12) = \\boxed{24}.$ 18 | Final Answer: The final answer is (D). I hope it is correct. 19 | 20 | Problem: 21 | Terrell usually lifts two 20-pound weights 12 times. If he uses two 15-pound weights instead, how many times must Terrell lift them in order to lift the same total weight? 22 | What of the following is the right choice? Explain your answer. 23 | (A) 12, (B) 20, (C) 16, (D) 15 24 | Solution: 25 | If Terrell lifts two 20-pound weights 12 times, he lifts a total of $2\\cdot 12\\cdot20=480$ pounds of weight. If he lifts two 15-pound weights instead for $n$ times, he will lift a total of $2\\cdot15\\cdot n=30n$ pounds of weight. Equating this to 480 pounds, we can solve for $n$: \\begin{align*} 26 | 30n&=480\\\\ 27 | \\Rightarrow\\qquad n&=480/30=\\boxed{16} 28 | \\end{align*} 29 | Final Answer: The final answer is (C). I hope it is correct. 30 | 31 | Problem: 32 | If the system of equations 33 | 34 | \\begin{align*} 35 | 6x-4y&=a,\\\\ 36 | 6y-9x &=b. 37 | \\end{align*}has a solution $(x, y)$ where $x$ and $y$ are both nonzero, find $\\frac{a}{b},$ assuming $b$ is 38 | nonzero. 39 | What of the following is the right choice? Explain your answer. 40 | (A) $-\\frac{2}{3}$, (B) $\\frac{2}{3}$, (C) $\\frac{1}{3}$, (D) $\\frac{4}{9}$ 41 | Solution: 42 | If we multiply the first equation by $-\\frac{3}{2}$, we obtain 43 | $$6y-9x=-\\frac{3}{2}a.$$Since we also know that $6y-9x=b$, we have 44 | 45 | $$-\\frac{3}{2}a=b\\Rightarrow\\frac{a}{b}=\\boxed{-\\frac{2}{3}}.$$ 46 | Final Answer: The final answer is (A). 
I hope it is correct.""" 47 | 48 | class MMLUSTEMPrompt(FewShotPrompting): 49 | def __init__(self): 50 | super().__init__() 51 | 52 | def format_prompt(self, task_input, task_output): 53 | prompt = f"{few_shot_prompt}\n\nProblem:\n{task_input}\nSolution:\n{task_output}" 54 | return prompt.rstrip() 55 | 56 | def stop_words(self): 57 | return ["\nProblem:"] 58 | -------------------------------------------------------------------------------- /evaluation/few_shot_prompts/cot_ocwcourses_4_shot.py: -------------------------------------------------------------------------------- 1 | from .few_shot_prompting import FewShotPrompting 2 | 3 | few_shot_prompt = """ 4 | Problem: 5 | Subproblem 0: What is the net charge of arginine in a solution of $\\mathrm{pH} 1.0$? 6 | Please format your answer as +n or -n. 7 | Solution: 8 | The answer is +2. 9 | Final answer: The final answer is $\\boxed{+2}$. I hope it is correct. 10 | 11 | Problem: 12 | Subproblem 0: Let $z = 1 + \\sqrt{3} i$. Find $a, b$ that satisfy the equation 13 | $z^4 = a + bi$. Express your answer as the ordered pair $(a,b)$. 14 | Solution: 15 | $z^{4}$ has argument $4 \\pi / 3$ and radius 16 , so it's equal to $-8-8 \\sqrt{3} i$. 16 | Thus $a = -8, b = -8\\sqrt 3$, and our answer is $(-8, -8\\sqrt{3})$. 17 | Final answer: The final answer is $\\boxed{(-8, -8\\sqrt{3})}$. I hope it is correct. 
18 | 19 | Problem: 20 | Preamble: For each Laplace Transform \\(Y(s)\\), find the function \\(y(t)\\): 21 | Subproblem 0: 22 | \\[Y(s)=\\frac{1}{(s+a)(s+b)}\\] 23 | Solution: 24 | We can simplify with partial fractions: 25 | \\[Y(s)=\\frac{1}{(s+a)(s+b)}=\\frac{C}{s+a}+\\frac{D}{s+b}\\] 26 | find the constants 27 | \\(C\\) and \\(D\\) by setting \\(s=-a\\) and \\(s=-b\\) 28 | \\[ 29 | \\begin{aligned} 30 | \\frac{1}{(s+a)(s+b)} &=\\frac{C}{s+a}+\\frac{D}{s+b} \\\\ 31 | 1 &=C(s+b)+D(s+a) \\\\ 32 | C &=\\frac{1}{b-a} \\\\ 33 | D &=\\frac{1}{a-b} 34 | \\end{aligned} 35 | \\] 36 | therefore 37 | \\[ 38 | Y(s)=\\frac{1}{b-a} \\frac{1}{s+a}-\\frac{1}{b-a} \\frac{1}{s+b} 39 | \\] 40 | By looking up the inverse Laplace Transform of \\(\\frac{1}{s+b}\\), we find the total 41 | solution \\(y(t)\\) 42 | \\[ 43 | y(t)=\\frac{e^{-a t}-e^{-b t}}{b-a} 44 | \\]. 45 | Final answer: The final answer is $\\boxed{\\frac{e^{-a t}-e^{-b t}}{b-a}}$. I hope it is correct. 46 | 47 | Problem: 48 | Preamble: The following subproblems refer to the differential equation 49 | $\\ddot{x}+b \\dot{x}+x=0$. 50 | Subproblem 0: What is the characteristic polynomial $p(s)$ of 51 | $\\ddot{x}+b \\dot{x}+x=0$? 52 | Solution: 53 | The characteristic polynomial is $p(s)=s^{2}+b s+1$. 54 | Final answer: The final answer is $\\boxed{s^{2}+b s+1}$. I hope it is correct. 55 | """.strip() 56 | 57 | few_shot_prompt = """ 58 | Problem: 59 | Subproblem 0: What is the net charge of arginine in a solution of $\\mathrm{pH} 1.0$? 60 | Please format your answer as +n or -n. 61 | Solution: 62 | The answer is +2. 63 | Final answer: The final answer is +2. I hope it is correct. 64 | 65 | Problem: 66 | Subproblem 0: Let $z = 1 + \\sqrt{3} i$. Find $a, b$ that satisfy the equation 67 | $z^4 = a + bi$. Express your answer as the ordered pair $(a,b)$. 68 | Solution: 69 | $z^{4}$ has argument $4 \\pi / 3$ and radius 16 , so it's equal to $-8-8 \\sqrt{3} i$. 
70 | Thus $a = -8, b = -8\\sqrt 3$, and our answer is $\\boxed{(-8, -8\\sqrt{3})}$. 71 | Final answer: The final answer is (-8, -8\\sqrt{3}). I hope it is correct. 72 | 73 | Problem: 74 | Preamble: For each Laplace Transform \\(Y(s)\\), find the function \\(y(t)\\): 75 | Subproblem 0: 76 | \\[Y(s)=\\boxed{\\frac{1}{(s+a)(s+b)}}\\] 77 | Solution: 78 | We can simplify with partial fractions: 79 | \\[Y(s)=\\frac{1}{(s+a)(s+b)}=\\frac{C}{s+a}+\\frac{D}{s+b}\\] 80 | find the constants 81 | \\(C\\) and \\(D\\) by setting \\(s=-a\\) and \\(s=-b\\) 82 | \\[ 83 | \\begin{aligned} 84 | \\frac{1}{(s+a)(s+b)} &=\\frac{C}{s+a}+\\frac{D}{s+b} \\\\ 85 | 1 &=C(s+b)+D(s+a) \\\\ 86 | C &=\\frac{1}{b-a} \\\\ 87 | D &=\\frac{1}{a-b} 88 | \\end{aligned} 89 | \\] 90 | therefore 91 | \\[ 92 | Y(s)=\\frac{1}{b-a} \\frac{1}{s+a}-\\frac{1}{b-a} \\frac{1}{s+b} 93 | \\] 94 | By looking up the inverse Laplace Transform of \\(\\frac{1}{s+b}\\), we find the total 95 | solution \\(y(t)\\) 96 | \\[ 97 | y(t)=\\boxed{\\frac{1}{b-a}\\left(e^{-a t}-e^{-b t}\\right)} 98 | \\]. 99 | Final answer: The final answer is \\[\\frac{1}{b-a}\\left(e^{-a t}-e^{-b t}\\right)\\]. I hope it is correct. 100 | 101 | Problem: 102 | Preamble: The following subproblems refer to the differential equation 103 | $\\ddot{x}+b \\dot{x}+x=0$. 104 | Subproblem 0: What is the characteristic polynomial $p(s)$ of 105 | $\\ddot{x}+b \\dot{x}+x=0$? 106 | Solution: 107 | The characteristic polynomial is $p(s)=\\boxed{s^{2}+b s+1}$. 108 | Final answer: The final answer is $s^{2}+b s+1$. I hope it is correct. 
109 | """.strip() 110 | 111 | class OCWCoursesPrompt(FewShotPrompting): 112 | def __init__(self): 113 | super().__init__() 114 | 115 | def format_prompt(self, task_input, task_output): 116 | prompt = f"{few_shot_prompt}\n\nProblem:\n{task_input}\nSolution:\n{task_output}" 117 | return prompt.rstrip() 118 | 119 | def stop_words(self): 120 | return ["\nProblem:"] 121 | -------------------------------------------------------------------------------- /evaluation/few_shot_prompts/few_shot_prompting.py: -------------------------------------------------------------------------------- 1 | class FewShotPrompting: 2 | def __init__(self): 3 | pass 4 | 5 | def format_prompt(self, task_input, task_output): 6 | pass 7 | 8 | def stop_words(self): 9 | pass 10 | -------------------------------------------------------------------------------- /evaluation/few_shot_prompts/minif2f_isabelle.py: -------------------------------------------------------------------------------- 1 | from .few_shot_prompting import FewShotPrompting 2 | 3 | few_shot_prompt = { 4 | 'numbertheory': """Informal: 5 | (*### Problem 6 | 7 | Find the minimum value of $\\frac{9x^2\\sin^2 x + 4}{x\\sin x}$ for $0 < x < \\pi$. Show that it is 12. 8 | 9 | ### Solution 10 | 11 | Let $y = x \\sin x$. It suffices to show that $12 \\leq \\frac{9y^2 + 4}{y}. 12 | It is trivial to see that $y > 0$. 13 | Then one can multiply both sides by $y$ and it suffices to show $12y \\leq 9y^2 + 4$. 14 | This can be done by the sum of squares method.*) 15 | 16 | Formal: 17 | theorem aime_1983_p9: 18 | fixes x::real 19 | assumes "0<x" "x<pi" 20 | shows "12 \\<le> ((9 * (x^2 * (sin x)^2)) + 4) / (x * sin x)" 21 | proof - 22 | (* Let $y = x \\sin x$. *) 23 | define y where "y=x * sin x" 24 | (* It suffices to show that $12 \\leq \\frac{9y^2 + 4}{y}. *) 25 | have "12 \\<le> (9 * y^2 + 4) / y" 26 | proof - 27 | (* It is trivial to see that $y > 0$. 
*) 28 | have c0: "y > 0" 29 | sledgehammer 30 | (* Then one can multiply both sides by $y$ and it suffices to show $12y \\leq 9y^2 + 4$. *) 31 | have "(9 * y^2 + 4) \\<ge> 12 * y" 32 | sledgehammer 33 | then show ?thesis 34 | sledgehammer 35 | qed 36 | then show ?thesis 37 | sledgehammer 38 | qed 39 | 40 | 41 | 42 | Informal: 43 | (*### Problem 44 | 45 | Find the greatest common factor of 180 and 168. Show that it is 12. 46 | 47 | ### Solution 48 | 49 | This is true by simple evaluation.*) 50 | 51 | Formal: 52 | theorem mathd_numbertheory_188: 53 | "gcd 180 168 = (12::nat)" 54 | sledgehammer 55 | 56 | 57 | 58 | Informal: 59 | (*### Problem 60 | 61 | Show that for positive integer n, 2 divides $4^n$. 62 | 63 | ### Solution 64 | 65 | Since n is positive, we can find a natural number m where $m+1=n$. 66 | Then we can show that 2 divides $4^{m+1}$. The conclusion thus follows.*) 67 | 68 | Formal: 69 | theorem numbertheory_2dvd4expn: 70 | fixes n :: nat 71 | assumes h0 : "n \\<noteq> 0" 72 | shows "(2::nat) dvd 4^n" 73 | proof - 74 | obtain m::nat where c0: "m+1=n" 75 | sledgehammer 76 | have "(2::nat) dvd 4^(m+1)" sledgehammer 77 | then show ?thesis unfolding c0 sledgehammer 78 | qed 79 | 80 | 81 | 82 | Informal: 83 | (*### Problem 84 | 85 | What is the remainder when $1 + 2 + 3 + 4 + \\dots + 9 + 10$ is divided by 9? Show that it is 1. 86 | 87 | ### Solution 88 | 89 | This is true by simple evaluation.*) 90 | 91 | Formal: 92 | theorem mathd_numbertheory_466: 93 | "(\\<Sum> k< 11. k) mod 9 = (1::nat)" 94 | sledgehammer 95 | 96 | 97 | 98 | Informal: 99 | (*### Problem 100 | 101 | If $321_{b}$ is equal to the base 10 integer 57, find $b$ given that $b>0$. Show that it is 4. 
102 | 103 | ### Solution 104 | 105 | Converting $321_{b}$ to base 10 and setting it equal to 57, we find that \\begin{align*} 3(b^2)+2(b^1)+1(b^0)&=57 106 | \\\\ 3b^2+2b+1&=57 107 | \\\\\\Rightarrow\\qquad 3b^2+2b-56&=0 108 | \\\\\\Rightarrow\\qquad (3b+14)(b-4)&=0 109 | \\end{align*}This tells us that $b$ is either $-\\frac{14}{3}$ or $4$. We know that $b>0$, so $b=4$.*) 110 | 111 | Formal: 112 | theorem mathd_numbertheory_48: 113 | fixes b :: real 114 | assumes h0 : "0<b" 115 | and h1 : "3 * b^2 + 2 * b + 1 = 57" 116 | shows "b=4" 117 | proof - 118 | (* Converting $321_{b}$ to base 10 and setting it equal to 57, we find that \\begin{align*} 3(b^2)+2(b^1)+1(b^0)&=57 119 | \\\\ 3b^2+2b+1&=57 120 | \\\\\\Rightarrow\\qquad 3b^2+2b-56&=0 121 | \\\\\\Rightarrow\\qquad (3b+14)(b-4)&=0 122 | \\end{align*} *) 123 | have "0 = 3 * b^2 + 2 * b -56" using h1 sledgehammer 124 | also have "... = (3*b+14)*(b-4)" sledgehammer 125 | finally have "0 = (3*b+14)*(b-4)" sledgehammer 126 | (* This tells us that $b$ is either $-\\frac{14}{3}$ or $4$. *) 127 | then have "b = -14/3 ∨ b=4" sledgehammer 128 | (* We know that $b>0$, so $b=4$. *) 129 | then show ?thesis using h0 sledgehammer 130 | qed 131 | 132 | 133 | 134 | Informal: 135 | (*### Problem 136 | 137 | When Rachel divides her favorite number by 7, she gets a remainder of 5. What will the remainder be if she multiplies her favorite number by 5 and then divides by 7? Show that it is 4. 138 | 139 | ### Solution 140 | 141 | Let $n$ be Rachel's favorite number. 142 | Then $n \\equiv 5 \\pmod{7}$, so $5n \\equiv 5 \\cdot 5 \\equiv 25 \\equiv 4 \\pmod{7}$. 143 | *) 144 | 145 | Formal: 146 | theorem mathd_numbertheory_335: 147 | fixes n :: nat 148 | assumes h0 : "n mod 7 = 5" 149 | shows "(5 * n) mod 7 = 4" 150 | proof - 151 | (* Then $n \\equiv 5 \\pmod{7}$, so $5n \\equiv 5 \\cdot 5 \\equiv 25 \\equiv 4 \\pmod{7}$. *) 152 | have c0:"(5 * n) mod 7 = (5 * 5) mod 7" using h0 153 | sledgehammer 154 | then have "\\<dots> = 4" sledgehammer 155 | then have "(5 * n) mod 7 = 4" using c0 sledgehammer 156 | then show ?thesis sledgehammer 157 | qed 158 | 159 | 160 | 161 | Informal: 162 | (*### Problem 163 | 164 | What positive two-digit integer is exactly twice the sum of its digits? Show that it is 18. 165 | 166 | ### Solution 167 | 168 | We simplify $10a + b = 2(a+b)$ to get $8a = b$. 169 | Since $a$ is at least 1, $b$ is at least 8. 170 | We know $b$ is 8 since $8a = b$ and $a$ is a natural number. 171 | Hence $a$ is 1. 172 | The two-digit integer is hence $18$. 
173 | *) 174 | 175 | Formal: 176 | theorem mathd_numbertheory_284: 177 | fixes a b :: nat 178 | assumes h0 : "1\\<le>a \\<and> a \\<le>9 \\<and> b \\<le>9" 179 | and h1 : "10 * a + b = 2 * (a+b)" 180 | shows "10 * a + b = 18" 181 | proof - 182 | (* We simplify $10a + b = 2(a+b)$ to get $8a = b$. *) 183 | have c0: "8 * a = b" using h1 sledgehammer 184 | (* Since $a$ is at least 1, $b$ is at least 8. *) 185 | hence "b \\<ge> 8" using h0 sledgehammer 186 | (* We know $b$ is 8 since $8a = b$ and $a$ is a natural number. *) 187 | hence c1:"b = 8" using h0 c0 188 | sledgehammer 189 | (* Hence $a$ is 1. *) 190 | hence "a = 1" using c0 sledgehammer 191 | (* The two-digit integer is hence $18$. *) 192 | then show ?thesis using c1 sledgehammer 193 | qed 194 | 195 | 196 | 197 | """.strip(), 198 | "other": """Informal: 199 | (*### Problem 200 | 201 | Find the minimum value of $\\frac{9x^2\\sin^2 x + 4}{x\\sin x}$ for $0 < x < \\pi$. Show that it is 12. 202 | 203 | ### Solution 204 | 205 | Let $y = x \\sin x$. It suffices to show that $12 \\leq \\frac{9y^2 + 4}{y}. 206 | It is trivial to see that $y > 0$. 207 | Then one can multiply both sides by $y$ and it suffices to show $12y \\leq 9y^2 + 4$. 208 | This can be done by the sum of squares method.*) 209 | 210 | Formal: 211 | theorem aime_1983_p9: 212 | fixes x::real 213 | assumes "0<x" "x<pi" 214 | shows "12 \\<le> ((9 * (x^2 * (sin x)^2)) + 4) / (x * sin x)" 215 | proof - 216 | (* Let $y = x \\sin x$. *) 217 | define y where "y=x * sin x" 218 | (* It suffices to show that $12 \\leq \\frac{9y^2 + 4}{y}. *) 219 | have "12 \\<le> (9 * y^2 + 4) / y" 220 | proof - 221 | (* It is trivial to see that $y > 0$. *) 222 | have c0: "y > 0" 223 | sledgehammer 224 | (* Then one can multiply both sides by $y$ and it suffices to show $12y \\leq 9y^2 + 4$. 
*) 225 | have "(9 * y^2 + 4) \\<ge> 12 * y" 226 | sledgehammer 227 | then show ?thesis 228 | sledgehammer 229 | qed 230 | then show ?thesis 231 | sledgehammer 232 | qed 233 | 234 | 235 | 236 | Informal: 237 | (*### Problem 238 | 239 | Show that for any four complex numbers a, b, c, and d, $(a-d)(a-c)(a-b) = -(((a^2 - a(b+c)) + bc) * d) + (a^2 - a(b+c) + bc) * a$. 240 | 241 | ### Solution 242 | 243 | We first see that $a^2 = a * a$ trivially. 244 | Unfolding this, the main equation holds true when terms are rearranged.*) 245 | 246 | Formal: 247 | theorem algebra_3rootspoly_amdtamctambeqnasqmbpctapcbtdpasqmbpctapcbta: 248 | fixes a b c d :: complex 249 | shows "(a-d) * (a-c) * (a-b) = -(((a^2 - (b+c) * a) + c * b) * d) + (a^2 - (b+c) * a + c * b) * a" 250 | proof - 251 | (* We first see that $a^2 = a * a$ trivially. *) 252 | have t0: "a^2 = a * a" 253 | using power2_eq_square 254 | sledgehammer 255 | (* Unfolding this, the main equation holds true when terms are rearranged. *) 256 | show ?thesis unfolding t0 257 | sledgehammer 258 | qed 259 | 260 | 261 | 262 | Informal: 263 | (*### Problem 264 | 265 | Find the greatest common factor of 180 and 168. Show that it is 12. 266 | 267 | ### Solution 268 | 269 | This is true by simple evaluation.*) 270 | 271 | Formal: 272 | theorem mathd_numbertheory_188: 273 | "gcd 180 168 = (12::nat)" 274 | sledgehammer 275 | 276 | 277 | 278 | Informal: 279 | (*### Problem 280 | 281 | For how many positive integers $n$ is $n^2 - 3n + 2$ a [[prime]] number? 282 | 283 | $\\mathrm{(A)}\\ \\text{none} 284 | \\qquad\\mathrm{(B)}\\ \\text{one} 285 | \\qquad\\mathrm{(C)}\\ \\text{two} 286 | \\qquad\\mathrm{(D)}\\ \\text{more\\ than\\ two,\\ but\\ finitely\\ many} 287 | \\qquad\\mathrm{(E)}\\ \\text{infinitely\\ many}$ Show that it is \\mathrm{(B)}\\ \\text{one}. 288 | 289 | ### Solution 290 | 291 | Factoring, we get $n^2 - 3n + 2 = (n-2)(n-1)$. 292 | Either $n-1$ or $n-2$ is odd, and the other is even. 293 | Their product must yield an even number. 
294 | The only prime that is even is $2$, which is when $n$ is $3$ or $0$. 295 | Since $0$ is not a positive number, the answer is $\\mathrm{(B)}\\ \\text{one}$.*) 296 | 297 | Formal: 298 | theorem amc12b_2002_p3: 299 | fixes n ::nat 300 | assumes "n>0" 301 | and prime:"prime (n^2+2-3*n)" 302 | shows "n=3" 303 | proof - 304 | have "n>2" 305 | proof (rule ccontr) 306 | assume "\\<not> 2 < n" 307 | then have "n=1 \\<or> n=2" using \\<open>n>0\\<close> sledgehammer 308 | then show False using prime[THEN prime_gt_1_nat] 309 | sledgehammer 310 | qed 311 | (* Factoring, we get $n^2 - 3n + 2 = (n-2)(n-1)$. *) 312 | then have "n^2+2-3*n = (n-1) * (n-2)" 313 | unfolding power2_eq_square 314 | sledgehammer 315 | (* Either $n-1$ or $n-2$ is odd, and the other is even. 316 | Their product must yield an even number. 317 | The only prime that is even is $2$, which is when $n$ is $3$ or $0$. 318 | Since $0$ is not a positive number, the answer is $\\mathrm{(B)}\\ \\text{one}$.*) 319 | then have "prime ((n-1) * (n-2))" 320 | using prime sledgehammer 321 | then have "n-1=1 \\<or> n-2 = 1" 322 | using prime_product sledgehammer 323 | with \\<open>n>2\\<close> 324 | show "n=3" sledgehammer 325 | qed 326 | 327 | 328 | 329 | Informal: 330 | (*### Problem 331 | 332 | For a positive real number a, show that $10a\\leq 28a^2+1$. 333 | 334 | ### Solution 335 | 336 | It suffices to show $0\\leq 28a^2 - 10a + 1$. 337 | First, consider completing the square for $28a^2 - 10a$ and observe that $(a - \\frac{5}{28})^2 = a^2 - \\frac{10}{28}a + (5/28)^2$. 338 | Since $0\\leq (a - \\frac{5}{28})^2$, we have $0\\leq a^2 - \\frac{10}{28}a + (5/28)^2$. 339 | Multiplying by 28 and simplifying terms gives $0\\leq 28*a^2 - 10*a + (25/28)$. 
340 | Since $25/28 < 1$, the result follows.*) 341 | 342 | Formal: 343 | theorem algebra_binomnegdiscrineq_10alt28asqp1: 344 | fixes a :: real 345 | shows "10 * a \\<le> 28 * a^2 + 1" 346 | proof - 347 | (* it suffices to show $0\\leq 28a^2 - 10a + 1$ *) 348 | have c0: "0 \\<le> 28*a^2 - 10*a + 1" 349 | proof - 350 | (* observe that $(a - \\frac{5}{28})^2 = a^2 - \\frac{10}{28}a + (5/28)^2$ *) 351 | have c1: "(a - (5/28))^2 = a^2 - 10/28*a + (5/28)^2" 352 | sledgehammer 353 | (* we have $0\\leq a^2 - \\frac{10}{28}a + (5/28)^2$ *) 354 | then have c2: "0 \\<le> a^2 - 10/28*a + (5/28)^2" using c1 355 | sledgehammer 356 | (* Multiplying by 28 and simplifying terms gives $0\\leq 28*a^2 - 10*a + (25/28)$ *) 357 | then have c3: "0 \\<le> 28*a^2 - 10*a + 28*((5/28)^2)" using c2 358 | sledgehammer 359 | then have c4: "0 \\<le> 28*a^2 - 10*a + 28*((5/28)*(5/28))" using c3 360 | sledgehammer 361 | then have c5: "0 \\<le> 28*a^2 - 10*a + (25/28)" using c4 362 | sledgehammer 363 | (* Since $25/28 < 1$, the result follows. *) 364 | then show ?thesis using c5 365 | sledgehammer 366 | qed 367 | then show ?thesis 368 | sledgehammer 369 | qed 370 | 371 | 372 | 373 | Informal: 374 | (*### Problem 375 | 376 | Show that for any complex number a, $(a-10)(a+11) = a^2 + a - 110$. 377 | 378 | ### Solution 379 | 380 | We first expand all terms of the left hand side to get $a^2 - 10a + 11a - 10*11$. 381 | This equals $a^2 + a - 10*11 = a^2 + a - 110$.*) 382 | 383 | Formal: 384 | theorem algebra_2rootsintpoly_am10tap11eqasqpam110: 385 | fixes a :: complex 386 | shows "(a-10) * (a+11) = a^2 + a -110" 387 | proof - 388 | (* We first expand all terms of the left hand side to get $a^2 - 10a + 11a - 10*11$. *) 389 | have "(a-10) * (a+11) = a^2 - 10*a + 11*a - 10 *11" 390 | sledgehammer 391 | (* This equals $a^2 + a - 10*11 = a^2 + a - 110$. 
*) 392 | also have "\\<dots> = a^2 + a - 10 * 11" 393 | sledgehammer 394 | also have "\\<dots> = a^2 + a - 110" 395 | sledgehammer 396 | finally show ?thesis 397 | sledgehammer 398 | qed 399 | 400 | 401 | 402 | """.strip() 403 | } 404 | 405 | class MiniF2FIsabellePrompt(FewShotPrompting): 406 | def __init__(self): 407 | super().__init__() 408 | 409 | def format_prompt(self, task_input, task_output): 410 | if 'numbertheory' in task_input.split("Formal:", 1)[1]: 411 | tag = 'numbertheory' 412 | else: 413 | tag = 'other' 414 | prompt = f"{few_shot_prompt[tag].strip()}\n\n\n\nInformal:\n{task_input.strip()}\n{task_output.strip()}" 415 | return prompt.rstrip() 416 | 417 | def stop_words(self): 418 | return ["\nInformal:"] 419 | -------------------------------------------------------------------------------- /evaluation/few_shot_prompts/pal_gsm_8_shot.py: -------------------------------------------------------------------------------- 1 | from .few_shot_prompting import FewShotPrompting 2 | 3 | few_shot_prompt = ''' 4 | Q: Olivia has $23. She bought five bagels for $3 each. How much money does she have left? 5 | 6 | # solution in Python: 7 | 8 | 9 | def solution(): 10 | """Olivia has $23. She bought five bagels for $3 each. How much money does she have left?""" 11 | money_initial = 23 12 | bagels = 5 13 | bagel_cost = 3 14 | money_spent = bagels * bagel_cost 15 | money_left = money_initial - money_spent 16 | result = money_left 17 | return result 18 | 19 | 20 | 21 | 22 | 23 | Q: Michael had 58 golf balls. On tuesday, he lost 23 golf balls. On wednesday, he lost 2 more. How many golf balls did he have at the end of wednesday? 24 | 25 | # solution in Python: 26 | 27 | 28 | def solution(): 29 | """Michael had 58 golf balls. On tuesday, he lost 23 golf balls. On wednesday, he lost 2 more. 
How many golf balls did he have at the end of wednesday?""" 30 | golf_balls_initial = 58 31 | golf_balls_lost_tuesday = 23 32 | golf_balls_lost_wednesday = 2 33 | golf_balls_left = golf_balls_initial - golf_balls_lost_tuesday - golf_balls_lost_wednesday 34 | result = golf_balls_left 35 | return result 36 | 37 | 38 | 39 | 40 | 41 | Q: There were nine computers in the server room. Five more computers were installed each day, from monday to thursday. How many computers are now in the server room? 42 | 43 | # solution in Python: 44 | 45 | 46 | def solution(): 47 | """There were nine computers in the server room. Five more computers were installed each day, from monday to thursday. How many computers are now in the server room?""" 48 | computers_initial = 9 49 | computers_per_day = 5 50 | num_days = 4 # 4 days between monday and thursday 51 | computers_added = computers_per_day * num_days 52 | computers_total = computers_initial + computers_added 53 | result = computers_total 54 | return result 55 | 56 | 57 | 58 | 59 | 60 | Q: Shawn has five toys. For Christmas, he got two toys each from his mom and dad. How many toys does he have now? 61 | 62 | # solution in Python: 63 | 64 | 65 | def solution(): 66 | """Shawn has five toys. For Christmas, he got two toys each from his mom and dad. How many toys does he have now?""" 67 | toys_initial = 5 68 | mom_toys = 2 69 | dad_toys = 2 70 | total_received = mom_toys + dad_toys 71 | total_toys = toys_initial + total_received 72 | result = total_toys 73 | return result 74 | 75 | 76 | 77 | 78 | 79 | Q: Jason had 20 lollipops. He gave Denny some lollipops. Now Jason has 12 lollipops. How many lollipops did Jason give to Denny? 80 | 81 | # solution in Python: 82 | 83 | 84 | def solution(): 85 | """Jason had 20 lollipops. He gave Denny some lollipops. Now Jason has 12 lollipops. 
How many lollipops did Jason give to Denny?""" 86 | jason_lollipops_initial = 20 87 | jason_lollipops_after = 12 88 | denny_lollipops = jason_lollipops_initial - jason_lollipops_after 89 | result = denny_lollipops 90 | return result 91 | 92 | 93 | 94 | 95 | 96 | Q: Leah had 32 chocolates and her sister had 42. If they ate 35, how many pieces do they have left in total? 97 | 98 | # solution in Python: 99 | 100 | 101 | def solution(): 102 | """Leah had 32 chocolates and her sister had 42. If they ate 35, how many pieces do they have left in total?""" 103 | leah_chocolates = 32 104 | sister_chocolates = 42 105 | total_chocolates = leah_chocolates + sister_chocolates 106 | chocolates_eaten = 35 107 | chocolates_left = total_chocolates - chocolates_eaten 108 | result = chocolates_left 109 | return result 110 | 111 | 112 | 113 | 114 | 115 | Q: If there are 3 cars in the parking lot and 2 more cars arrive, how many cars are in the parking lot? 116 | 117 | # solution in Python: 118 | 119 | 120 | def solution(): 121 | """If there are 3 cars in the parking lot and 2 more cars arrive, how many cars are in the parking lot?""" 122 | cars_initial = 3 123 | cars_arrived = 2 124 | total_cars = cars_initial + cars_arrived 125 | result = total_cars 126 | return result 127 | 128 | 129 | 130 | 131 | 132 | Q: There are 15 trees in the grove. Grove workers will plant trees in the grove today. After they are done, there will be 21 trees. How many trees did the grove workers plant today? 133 | 134 | # solution in Python: 135 | 136 | 137 | def solution(): 138 | """There are 15 trees in the grove. Grove workers will plant trees in the grove today. After they are done, there will be 21 trees. 
# Demonstrations for program-aided (PAL) solving of MATH-style problems.
# Every demonstration has the same shape: the problem statement, an
# instruction line, a fenced program that stores its result in `answer`,
# and a second fenced block holding the imports that program needs.
# Each program must be runnable as written once its imports are placed in
# front of it, because the model imitates these examples verbatim; the third
# demonstration therefore assigns `repetitions` (the value it just computed)
# to `answer`.
few_shot_prompt = """Problem:
Find the value of $x$ that satisfies $\\frac{\\sqrt{3x+5}}{\\sqrt{6x+5}}=\\frac{\\sqrt{5}}{3}$. Express your answer as a common fraction.

You are an expert programmer. Solve the above mathematical problem by writing a Python program. Express your answer as a numeric type or a SymPy object.
```
# Initialize x
x = symbols('x')

# Define the equation
equation = Eq(sqrt(3*x + 5)/sqrt(6*x + 5), sqrt(5)/3)

# Solve for x
answer = solve(equation, x)
```
The imports required for this program are
```
from sympy import symbols, Eq, solve, sqrt
```
I hope my solution is correct.

Problem:
If $\\det \\mathbf{A} = 2$ and $\\det \\mathbf{B} = 12,$ then find $\\det (\\mathbf{A} \\mathbf{B}).$

You are an expert programmer. Solve the above mathematical problem by writing a Python program. Express your answer as a numeric type or a SymPy object.
```
# Given det(A) = 2 and det(B) = 12
det_A = 2
det_B = 12

# Use the property det(AB) = det(A)*det(B)
det_AB = det_A * det_B

answer = det_AB
```
The imports required for this program are
```

```
I hope my solution is correct.

Problem:
Terrell usually lifts two 20-pound weights 12 times. If he uses two 15-pound weights instead, how many times must Terrell lift them in order to lift the same total weight?

You are an expert programmer. Solve the above mathematical problem by writing a Python program. Express your answer as a numeric type or a SymPy object.
```
# Calculate the total weight lifted initially, which is 2*20*12 pounds
total_weight = 2 * 20 * 12

# Since Terrell lifts two 15-pound weights, divide the total weight by 2 * 15
repetitions = total_weight / (2*15)

answer = repetitions
```
The imports required for this program are
```

```
I hope my solution is correct.

Problem:
If Anna flips 8 coins, what is the probability that she gets more heads than tails?

You are an expert programmer. Solve the above mathematical problem by writing a Python program. Express your answer as a numeric type or a SymPy object.
```
# There are 2**8 possible outcomes
n = 8
total_outcomes = 2 ** n

# There are binom(n, k) ways to get k heads
favorable_outcomes = 0
for k in range((n // 2) + 1, n + 1):
    favorable_outcomes += math.comb(n, k)

probability = favorable_outcomes / total_outcomes

answer = probability
```
The imports required for this program are
```
import math
```
I hope my solution is correct.

Problem:
Evaluate $\\left\\lceil3\\left(6-\\frac12\\right)\\right\\rceil$.

You are an expert programmer. Solve the above mathematical problem by writing a Python program. Express your answer as a numeric type or a SymPy object.
```
# Calculate 3 * (6 - 1/2)
result = 3 * (6 - 0.5)

# Apply the ceiling function
ceiling_result = math.ceil(result)

answer = ceiling_result
```
The imports required for this program are
```
import math
```
I hope my solution is correct."""
def infer(args, test_data):
    """Generate one completion per example and extract its predicted answer.

    The module-level ``model`` and ``tokenizer`` globals act as a cache, so
    repeated calls (one per repeat-sampling trial) load the weights only once
    on both the vLLM and the Hugging Face path.

    Args:
        args: parsed command-line namespace. Reads ``tokenizer_name_or_path``,
            ``model_name_or_path``, ``prompt_format``, ``few_shot_prompt``,
            ``use_vllm``, ``temperature``, ``load_in_8bit``, ``load_in_half``,
            ``gptq``, ``eval_batch_size``, ``complete_partial_output`` and
            ``answer_extraction_fn``.
        test_data: list of example dicts holding a ``'messages'`` list. Each
            example additionally receives a ``'prompt'`` key, in place.

    Returns:
        A list of deep copies of the examples, each extended with
        ``'model_output'`` and ``'prediction'``.

    Raises:
        NotImplementedError: if a non-few-shot prompt format other than
            ``'sft'`` is requested.
        AssertionError: if few-shot prompting is requested without a prompt
            class, or if no outputs were produced.
    """
    global model, tokenizer
    if tokenizer is None:
        tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_name_or_path, trust_remote_code=True)

    if args.prompt_format == 'few_shot':
        assert args.few_shot_prompt is not None
        # NOTE(review): eval() of a command-line string; only safe because the
        # value comes from the operator's own test configuration.
        prompting = eval(args.few_shot_prompt)()

    prompts = []
    for example in test_data:
        prompt = ""
        if args.prompt_format == 'few_shot':
            prompt = prompting.format_prompt(example['messages'][-2]['content'], example['messages'][-1]['content'])
        else:
            for mess in example['messages']:
                if args.prompt_format == 'sft':
                    if mess['role'] == 'user':
                        prompt += f"{tokenizer.eos_token}User: {mess['content'].strip()}\n\nAssistant:"
                    elif mess['role'] == 'assistant':
                        prompt += mess['content'].rstrip()
                else:
                    raise NotImplementedError()
            prompt = prompt.lstrip()
            # Every user turn is prefixed with the EOS token; drop the one in
            # front of the very first turn.
            if args.prompt_format == 'sft' and prompt.startswith(tokenizer.eos_token):
                prompt = prompt[len(tokenizer.eos_token):].lstrip()
        example['prompt'] = prompt
        prompts.append(prompt.lstrip())

    print("Loading model and tokenizer...")
    if args.use_vllm:
        if model is None:
            model = LLM(model=args.model_name_or_path, tokenizer=args.tokenizer_name_or_path, trust_remote_code=True, tensor_parallel_size=len(os.environ['CUDA_VISIBLE_DEVICES'].split(",")))
        # NOTE(review): the fallback is an empty string, which then becomes a
        # stop word; confirm this is the intended literal.
        eos_token = tokenizer.eos_token if tokenizer is not None and tokenizer.eos_token is not None else ''
        stop_words = [eos_token]
        if args.prompt_format == 'few_shot':
            stop_words.extend(prompting.stop_words())
        outputs = model.generate(prompts, SamplingParams(temperature=args.temperature, top_p=1.0, max_tokens=1024, n=1, stop=stop_words))
        outputs = sorted(outputs, key=lambda x: int(x.request_id)) # sort outputs by request_id
        outputs = [output.outputs[0].text for output in outputs]
    else:
        # Load the weights once and reuse them on later calls instead of
        # reloading the whole model for every repeat-sampling trial.
        if model is None:
            model, tokenizer = load_hf_lm_and_tokenizer(
                model_name_or_path=args.model_name_or_path,
                tokenizer_name_or_path=args.tokenizer_name_or_path,
                load_in_8bit=args.load_in_8bit,
                load_in_half=args.load_in_half,
                gptq_model=args.gptq
            )

        stop_id_sequences = []
        if tokenizer.eos_token_id is not None:
            stop_id_sequences = [[tokenizer.eos_token_id]]
        if args.prompt_format == 'few_shot':
            stop_id_sequences.extend([tokenizer.encode(word) for word in prompting.stop_words()])
        # NOTE(review): this path allows 512 new tokens while the vLLM path
        # allows 1024; confirm the difference is intended.
        outputs, finish_completion = generate_completions(
            model=model,
            tokenizer=tokenizer,
            prompts=prompts,
            max_new_tokens=512,
            batch_size=args.eval_batch_size,
            stop_id_sequences=stop_id_sequences if stop_id_sequences else None,
            end_of_generation_id_sequence=[tokenizer.eos_token_id] if tokenizer.eos_token_id is not None else None
        )

    if args.complete_partial_output:
        # The prompt ended with a partial answer; put it back in front of the
        # continuation so the stored output is the complete answer.
        model_outputs = [example['messages'][-1]['content'] + output for example, output in zip(test_data, outputs)]
    else:
        model_outputs = outputs
    assert len(model_outputs) > 0, f"{len(model_outputs)}"

    # Resolve the extraction function once rather than once per example.
    extract_answer = eval(args.answer_extraction_fn)
    predictions = [extract_answer(item['messages'][-2]['content'], output, task='cot') for item, output in tqdm(zip(test_data, model_outputs), desc="extract answer", total=len(model_outputs))]

    results = []
    for example, output, pred in zip(test_data, model_outputs, predictions):
        item = deepcopy(example)
        item.update({
            'model_output': output,
            'prediction': pred,
        })
        results.append(item)
    return results
open(os.path.join(args.data_dir, f"train.jsonl" if args.infer_train_set else f"test.jsonl")) as fin: 131 | for line in fin: 132 | example = json.loads(line) 133 | messages = example['messages'] 134 | assert messages[-1]['role'] == 'assistant' 135 | if not args.complete_partial_output: 136 | example['reference'] = example.get('reference', '') or [mess['content'] for mess in messages if mess['role'] == 'assistant'] 137 | for mess in messages: 138 | if mess['role'] == 'assistant': 139 | mess['content'] = '' 140 | example['messages'] = messages 141 | test_data.append(example) 142 | 143 | if args.max_num_examples and len(test_data) > args.max_num_examples: 144 | test_data = random.sample(test_data, args.max_num_examples) 145 | 146 | if args.n_subsets > 1: 147 | assert args.subset_id >= 0 and args.subset_id < args.n_subsets 148 | test_data = [item for i, item in enumerate(test_data) if i % args.n_subsets == args.subset_id] 149 | 150 | if not test_data: 151 | return 152 | 153 | if not os.path.exists(args.save_dir): 154 | os.makedirs(args.save_dir, exist_ok=True) 155 | 156 | results = infer(args, test_data) 157 | 158 | labels, eval_timeout_cnt = evaluate(eval(args.eval_fn), results) 159 | for item, label in zip(results, labels): 160 | item['accuracy'] = label 161 | 162 | print("Calculating accuracy...") 163 | acc = 0 164 | for item in results: 165 | acc += item['accuracy'] 166 | print("output acc = {:.5f}".format(acc / len(results) * 100), flush=True) 167 | 168 | print(f"Timeout count >>> output eval = {eval_timeout_cnt}", flush=True) 169 | 170 | pred_fname = "predictions.json" 171 | if args.n_subsets > 1: 172 | pred_fname = f"predictions.{args.subset_id}.json" 173 | with open(os.path.join(args.save_dir, pred_fname), "w") as fout: 174 | json.dump(results, fout, ensure_ascii=True) 175 | 176 | metric_fname = "metrics.json" 177 | if args.n_subsets > 1: 178 | metric_fname = f"metrics.{args.subset_id}.json" 179 | with open(os.path.join(args.save_dir, metric_fname), "w") as 
fout: 180 | json.dump({ 181 | "n_samples": len(results), 182 | "accuracy": sum(item['accuracy'] for item in results) / len(results), 183 | }, fout, indent=4) 184 | 185 | if __name__ == "__main__": 186 | parser = argparse.ArgumentParser() 187 | parser.add_argument("--data_dir", type=str, default="data/mgsm") 188 | parser.add_argument("--max_num_examples", type=int, default=None, help="maximum number of examples to evaluate.") 189 | parser.add_argument("--save_dir", type=str, default="results/mgsm") 190 | parser.add_argument("--model_name_or_path", type=str, default=None, help="if specified, we will load the model to generate the predictions.") 191 | parser.add_argument("--tokenizer_name_or_path", type=str, default=None, help="if specified, we will load the tokenizer from here.") 192 | parser.add_argument("--eval_batch_size", type=int, default=1, help="batch size for evaluation.") 193 | parser.add_argument("--load_in_8bit", action="store_true", help="load model in 8bit mode, which will reduce memory and speed up inference.") 194 | parser.add_argument("--gptq", action="store_true", help="If given, we're evaluating a 4-bit quantized GPTQ model.") 195 | parser.add_argument("--use_vllm", action="store_true") 196 | parser.add_argument("--load_in_half", action='store_true') 197 | parser.add_argument("--infer_train_set", action="store_true") 198 | parser.add_argument("--n_subsets", type=int, default=1) 199 | parser.add_argument("--subset_id", type=int, default=0) 200 | parser.add_argument("--temperature", type=float, default=0.0) 201 | parser.add_argument("--repeat_id_start", type=int, default=0) 202 | parser.add_argument("--n_repeat_sampling", type=int, default=1) 203 | parser.add_argument("--complete_partial_output", action='store_true') 204 | parser.add_argument("--prompt_format", type=str, choices=['sft', 'few_shot'], default='sft') 205 | parser.add_argument("--few_shot_prompt", type=str, default=None) 206 | parser.add_argument("--answer_extraction_fn", type=str, 
def evaluate(eval_fn, tasks, _timeout=15):
    """Score every task in a pool of worker processes with a per-task time limit.

    Args:
        eval_fn: callable applied to each task; its result is converted with
            ``int()`` to obtain the label.
        tasks: iterable of task objects, each passed to ``eval_fn`` on its own.
        _timeout: number of seconds a single task may run before it is
            recorded as a failure.

    Returns:
        A tuple ``(labels, timeout_cnt)``: one integer label per task, in task
        order, with ``0`` recorded for tasks that timed out, and the number of
        tasks that timed out.

    Raises:
        SystemExit: with status 1 when scoring a task fails for any reason
            other than a timeout; the error is printed first.
    """
    labels = []
    timeout_cnt = 0
    with ProcessPool() as pool:
        iterator = pool.map(eval_fn, tasks, timeout=_timeout).result()
        while True:
            try:
                labels.append(int(next(iterator)))
            except StopIteration:
                break
            except TimeoutError:
                labels.append(0)
                timeout_cnt += 1
            except Exception as error:
                # Not every exception carries a `traceback` attribute (for
                # example one raised by int() in this process), so fall back
                # to the exception itself instead of failing with an
                # AttributeError that hides the real problem.
                print(getattr(error, 'traceback', repr(error)), flush=True)
                # Exit with a non-zero status so the failure is visible to
                # whoever launched this script.
                sys.exit(1)
    return labels, timeout_cnt
'user': 86 | prompt += f"User: {mess['content'].strip()}\n\nAssistant:" 87 | elif mess['role'] == 'assistant': 88 | prompt += mess['content'].strip() 89 | else: 90 | raise NotImplementedError() 91 | prompt = prompt.lstrip() 92 | example['prompt'] = prompt 93 | prompts.append(prompt.lstrip()) 94 | 95 | global model, tokenizer 96 | if tokenizer is None: 97 | tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_name_or_path, trust_remote_code=True) 98 | print("Loading model and tokenizer...") 99 | if args.use_vllm: 100 | if model is None: 101 | model = LLM(model=args.model_name_or_path, tokenizer=args.tokenizer_name_or_path, trust_remote_code=True, tensor_parallel_size=len(os.environ['CUDA_VISIBLE_DEVICES'].split(","))) 102 | eos_token = tokenizer.eos_token if tokenizer is not None and tokenizer.eos_token is not None else '' 103 | stop_words = [eos_token] 104 | if args.prompt_format == 'few_shot': 105 | stop_words.extend(prompting.stop_words()) 106 | outputs = model.generate(prompts, SamplingParams(temperature=args.temperature, top_p=1.0, max_tokens=1024, n=1, stop=stop_words)) 107 | outputs = sorted(outputs, key=lambda x: int(x.request_id)) # sort outputs by request_id 108 | outputs = [output.outputs[0].text for output in outputs] 109 | else: 110 | model, tokenizer = load_hf_lm_and_tokenizer( 111 | model_name_or_path=args.model_name_or_path, 112 | tokenizer_name_or_path=args.tokenizer_name_or_path, 113 | load_in_8bit=args.load_in_8bit, 114 | load_in_half=args.load_in_half, 115 | gptq_model=args.gptq 116 | ) 117 | 118 | stop_id_sequences = [] 119 | if tokenizer.eos_token_id is not None: 120 | stop_id_sequences = [[tokenizer.eos_token_id]] 121 | if args.prompt_format == 'few_shot': 122 | stop_id_sequences.extend([tokenizer.encode(word) for word in prompting.stop_words()]) 123 | outputs, finish_completion = generate_completions( 124 | model=model, 125 | tokenizer=tokenizer, 126 | prompts=prompts, 127 | max_new_tokens=512, 128 | batch_size=args.eval_batch_size, 129 | 
stop_id_sequences=stop_id_sequences if stop_id_sequences else None, 130 | end_of_generation_id_sequence=[tokenizer.eos_token_id] if tokenizer.eos_token_id is not None else None 131 | ) 132 | 133 | if args.complete_partial_output: 134 | model_outputs = [example['messages'][-1]['content'] + output for example, output in zip(test_data, outputs)] 135 | else: 136 | model_outputs = outputs 137 | 138 | if 'PALGSMPrompt' in args.few_shot_prompt: 139 | executor = PythonExecutor(get_answer_expr='solution()') 140 | codes = model_outputs 141 | elif 'PALMathPrompt' in args.few_shot_prompt: 142 | executor = PythonExecutor(get_answer_symbol='answer') 143 | codes = [] 144 | for text in model_outputs: 145 | if text.count("```") == 4: 146 | segments = text.split("```") 147 | assert len(segments) == 5 148 | code = f"{segments[3]}\n\n{segments[1]}" 149 | else: 150 | code = "answer = '[invalid]'" 151 | codes.append(code) 152 | else: 153 | raise NotImplementedError() 154 | 155 | predictions = [] 156 | runtime_errors = [] 157 | for pred, err in executor.batch_apply(codes): 158 | predictions.append(str(pred)) 159 | runtime_errors.append(str(err['exec_info']).strip()) 160 | 161 | assert len(model_outputs) > 0, f"{len(model_outputs)}" 162 | 163 | results = [] 164 | for example, output, pred in zip(test_data, model_outputs, predictions): 165 | item = deepcopy(example) 166 | item.update({ 167 | 'model_output': output, 168 | 'program_output': pred, 169 | }) 170 | results.append(item) 171 | 172 | labels, eval_timeout_cnt = evaluate(partial(eval(args.eval_fn), pred_key='program_output'), results) 173 | for item, label in zip(results, labels): 174 | item['accuracy'] = label 175 | 176 | print("Calculating accuracy...") 177 | acc = 0 178 | for item in results: 179 | acc += item['accuracy'] 180 | print("output acc = {:.5f}".format(acc / len(results) * 100), flush=True) 181 | 182 | print(f"Timeout count >>> output eval = {eval_timeout_cnt}", flush=True) 183 | 184 | pred_fname = "predictions.json" 185 
| if args.n_subsets > 1: 186 | pred_fname = f"predictions.{args.subset_id}.json" 187 | with open(os.path.join(args.save_dir, pred_fname), "w") as fout: 188 | json.dump(results, fout, ensure_ascii=True) 189 | 190 | metric_fname = "metrics.json" 191 | if args.n_subsets > 1: 192 | metric_fname = f"metrics.{args.subset_id}.json" 193 | with open(os.path.join(args.save_dir, metric_fname), "w") as fout: 194 | json.dump({ 195 | "n_samples": len(results), 196 | "accuracy": sum(item['accuracy'] for item in results) / len(results), 197 | }, fout, indent=4) 198 | 199 | if __name__ == "__main__": 200 | parser = argparse.ArgumentParser() 201 | parser.add_argument("--data_dir", type=str, default="data/mgsm") 202 | parser.add_argument("--max_num_examples", type=int, default=None, help="maximum number of examples to evaluate.") 203 | parser.add_argument("--save_dir", type=str, default="results/mgsm") 204 | parser.add_argument("--model_name_or_path", type=str, default=None, help="if specified, we will load the model to generate the predictions.") 205 | parser.add_argument("--tokenizer_name_or_path", type=str, default=None, help="if specified, we will load the tokenizer from here.") 206 | parser.add_argument("--eval_batch_size", type=int, default=1, help="batch size for evaluation.") 207 | parser.add_argument("--load_in_8bit", action="store_true", help="load model in 8bit mode, which will reduce memory and speed up inference.") 208 | parser.add_argument("--gptq", action="store_true", help="If given, we're evaluating a 4-bit quantized GPTQ model.") 209 | parser.add_argument("--use_vllm", action="store_true") 210 | parser.add_argument("--load_in_half", action='store_true') 211 | parser.add_argument("--infer_train_set", action="store_true") 212 | parser.add_argument("--n_subsets", type=int, default=1) 213 | parser.add_argument("--subset_id", type=int, default=0) 214 | parser.add_argument("--temperature", type=float, default=0.0) 215 | parser.add_argument("--repeat_id_start", type=int, 
default=0) 216 | parser.add_argument("--n_repeat_sampling", type=int, default=1) 217 | parser.add_argument("--complete_partial_output", action='store_true') 218 | parser.add_argument("--prompt_format", type=str, choices=['sft', 'few_shot'], default='sft') 219 | parser.add_argument("--few_shot_prompt", type=str, default=None) 220 | parser.add_argument("--answer_extraction_fn", type=str, default=None) 221 | parser.add_argument("--eval_fn", type=str, required=True) 222 | parser.add_argument("--gpus", type=str, default=None) 223 | args, unparsed_args = parser.parse_known_args() 224 | if args.gpus is not None: 225 | os.environ['CUDA_VISIBLE_DEVICES'] = args.gpus 226 | 227 | print(unparsed_args, flush=True) 228 | 229 | model = None 230 | tokenizer = None 231 | pool = None 232 | if args.n_repeat_sampling > 1 or args.repeat_id_start != 0: 233 | assert args.temperature > 0 234 | save_dir = args.save_dir 235 | for i in range(args.repeat_id_start, args.repeat_id_start + args.n_repeat_sampling): 236 | print(f"working on the {i} trials ...", flush=True) 237 | args.save_dir = os.path.join(save_dir, str(i)) 238 | os.makedirs(args.save_dir, exist_ok=True) 239 | main(args) 240 | else: 241 | main(args) 242 | 243 | if pool is not None: 244 | pool.close() 245 | -------------------------------------------------------------------------------- /evaluation/outputs.zip: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deepseek-ai/DeepSeek-Math/b8b0f8ce093d80bf8e9a641e44142f06d092c305/evaluation/outputs.zip -------------------------------------------------------------------------------- /evaluation/run_subset_parallel.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | from tqdm import tqdm 4 | from glob import glob 5 | import time 6 | import json 7 | import subprocess 8 | 9 | from utils import read_data 10 | from data_processing.process_utils import * 11 
| 12 | _worker_num = int(os.environ.get('WORLD_SIZE', 1)) 13 | _worker_id = int(os.environ.get('RANK', 0)) 14 | 15 | def markup_question(args, item, language, src, task): 16 | for i in range(len(item['messages']) - 2, -1, -2): 17 | if language == 'zh': 18 | if task == 'cot': 19 | item['messages'][i]['content'] = f"{item['messages'][i]['content']}\n请通过逐步推理来解答问题,并把最终答案放置于" + "\\boxed{}中。" 20 | elif task == 'tool': 21 | item['messages'][i]['content'] = f"{item['messages'][i]['content']}\n请结合自然语言和Python程序语言来解答问题,并把最终答案放置于" + "\\boxed{}中。" 22 | else: 23 | pass 24 | elif language == 'en': 25 | if task == 'cot': 26 | item['messages'][i]['content'] = f"{item['messages'][i]['content']}\nPlease reason step by step, and put your final answer within " + "\\boxed{}." 27 | elif task == 'tool': 28 | item['messages'][i]['content'] = f"{item['messages'][i]['content']}\nPlease integrate natural language reasoning with programs to solve the problem above, and put your final answer within " + "\\boxed{}." 29 | else: 30 | pass 31 | return item 32 | 33 | def do_parallel_sampling(args, task, answer_extraction_fn, eval_fn, input_dir, output_dir, log_dir): 34 | if task == 'pal': 35 | code_fname = "run_pal_eval" 36 | elif task == 'cot': 37 | code_fname = "run_cot_eval" 38 | elif task == 'tool': 39 | code_fname = "run_tool_integrated_eval" 40 | else: 41 | raise NotImplementedError() 42 | 43 | n_procs = args.ngpus // args.ngpus_per_model 44 | 45 | gpus = [str(i) for i in range(args.ngpus)] 46 | gpu_groups = [] 47 | for i in range(n_procs): 48 | gpu_groups.append(gpus[i * args.ngpus_per_model: (i + 1) * args.ngpus_per_model]) 49 | 50 | global_n_procs = n_procs * _worker_num 51 | 52 | procs = [] 53 | for pid, gpus in enumerate(gpu_groups): 54 | global_pid = n_procs * (args.rank or _worker_id) + pid 55 | logpath = os.path.join(log_dir, f"{global_pid}.log") 56 | f = open(logpath, "w") 57 | cmd = f"python infer/{code_fname}.py " \ 58 | f"--data_dir {input_dir} " \ 59 | f"--max_num_examples 
100000000000000 " \ 60 | f"--save_dir {output_dir} " \ 61 | f"--model {args.model_path} " \ 62 | f"--tokenizer {args.tokenizer_path or args.model_path} " \ 63 | f"--eval_batch_size 1 " \ 64 | f"--temperature {args.temperature} " \ 65 | f"--repeat_id_start 0 " \ 66 | f"--n_repeat_sampling {args.n_repeats} " \ 67 | f"--n_subsets {global_n_procs} " \ 68 | f"--prompt_format {args.prompt_format} " \ 69 | f"--few_shot_prompt {args.few_shot_prompt} " \ 70 | f"--answer_extraction_fn {answer_extraction_fn} " \ 71 | f"--eval_fn {eval_fn} " \ 72 | f"--subset_id {global_pid} " \ 73 | f"--gpus {','.join(gpus)} " 74 | if args.use_vllm: 75 | cmd += " --use_vllm " 76 | if args.load_in_half: 77 | cmd += " --load_in_half " 78 | local_metric_path = os.path.join(output_dir, f"metrics.{global_pid}.json") 79 | if not args.overwrite and os.path.exists(local_metric_path) and read_data(local_metric_path)['n_samples'] > 0: 80 | continue 81 | procs.append((global_pid, subprocess.Popen(cmd.split(), stdout=f, stderr=f), f)) 82 | for (global_pid, proc, f) in procs: 83 | print(f"Waiting for the {global_pid}th process to finish ...", flush=True) 84 | proc.wait() 85 | for (global_pid, proc, f) in procs: 86 | print(f"Closing the {global_pid}th process ...", flush=True) 87 | f.close() 88 | 89 | time.sleep(1) 90 | 91 | local_pids = [global_pid for (global_pid, _, _) in procs] 92 | 93 | agg_preds = [] 94 | for fname in glob(os.path.join(output_dir, "predictions.*.json")): 95 | if any(str(pid) in fname for pid in local_pids): 96 | agg_preds.extend(read_data(fname)) 97 | 98 | metrics = {} 99 | n_samples = 0 100 | for fname in glob(os.path.join(output_dir, "metrics.*.json")): 101 | if not any(str(pid) in fname for pid in local_pids): 102 | continue 103 | _metrics = read_data(fname) 104 | n_samples += _metrics['n_samples'] 105 | for key, val in _metrics.items(): 106 | if key != 'n_samples': 107 | metrics[key] = metrics.get(key, 0) + val * _metrics['n_samples'] 108 | for key, val in metrics.items(): 109 | 
def main():
    """Build the test sets for every configured benchmark and evaluate them.

    For each ``(source, task)`` pair in the ``--test-conf`` file this:

    1. writes the processed samples to ``<output>/<src>/.../test_data/test.jsonl``;
    2. calls ``do_parallel_sampling`` on that directory;
    3. stores the aggregated ``metrics.json`` and ``predictions.json`` under
       ``.../samples``.

    A task is skipped when its ``metrics.json`` already reports
    ``n_samples > 0`` and ``--overwrite`` is not given.

    Relies on names defined elsewhere in this module: ``read_data``,
    ``markup_question``, ``do_parallel_sampling``, ``_worker_num`` and
    ``_worker_id``, plus whatever ``process_fn`` names the config refers to.
    """
    parser = argparse.ArgumentParser()
    # Not required: when omitted the directory is derived from --model-path
    # below, which is what the help text has always promised.
    parser.add_argument("--output-dir", type=str, default=None, help="default to `model_path`_predictions")
    parser.add_argument("--model-path", type=str, required=True)
    parser.add_argument("--tokenizer-path", type=str, default=None)
    parser.add_argument("--model-size", type=str, choices=['1b', '7b', '13b', '33b', '34b', '70b'], default="7b")

    parser.add_argument("--test-conf", type=str, default="configs/zero_shot_test_configs.json", help="path to testing data config file that maps from a source to its info")
    parser.add_argument("--ngpus", type=int, default=8)
    parser.add_argument("--overwrite", action='store_true')
    parser.add_argument("--temperature", type=float, default=0)
    parser.add_argument("--n-repeats", type=int, default=1)
    parser.add_argument("--use-vllm", action='store_true')
    parser.add_argument("--load_in_half", action='store_true')

    parser.add_argument("--prompt_format", type=str, default="sft")
    parser.add_argument("--few_shot_prompt", type=str, default=None)

    parser.add_argument("--no-markup-question", action='store_true')

    parser.add_argument("--rank", type=int, default=None)
    parser.add_argument("--seed", type=int, default=42)
    # Unknown flags are tolerated on purpose (parse_known_args).
    args, _ = parser.parse_known_args()

    print(f"Evaluating {args.model_path}", flush=True)

    if args.output_dir is None:
        args.output_dir = f"{args.model_path.rstrip('/')}_predictions"

    # The three largest model sizes are given 4 GPUs per model replica.
    args.ngpus_per_model = 4 if args.model_size in ['70b', '33b', '34b'] else 1
    assert args.ngpus % args.ngpus_per_model == 0

    default_few_shot_prompt = args.few_shot_prompt

    test_conf = read_data(args.test_conf)

    for src, info in test_conf.items():
        if args.n_repeats > 1:
            _src = f"{src}/sample_logs"
        else:
            _src = f"{src}/infer_logs"
        if _worker_num > 1:
            # NOTE(review): `or` treats an explicit `--rank 0` as "unset" and
            # falls back to _worker_id; left as is because launchers may rely
            # on it -- confirm before changing.
            _src = f"{_src}/{args.rank or _worker_id}"
        if args.prompt_format == 'few_shot':
            args.few_shot_prompt = info.get('few_shot_prompt', None) or default_few_shot_prompt
        for task in info['tasks']:
            fname = os.path.join(args.output_dir, _src, task, "test_data", "test.jsonl")
            input_dir = os.path.dirname(fname)
            os.makedirs(input_dir, exist_ok=True)
            metric_path = os.path.join(args.output_dir, _src, task, "samples", "metrics.json")
            if not args.overwrite and os.path.exists(metric_path) and read_data(metric_path)['n_samples'] > 0:
                continue
            with open(fname, "w") as file:
                data = read_data(info['test_path'])
                # SECURITY: `process_fn` comes from the test-config file and is
                # eval'ed; only run configs you trust. Resolved once per task
                # instead of once per sample.
                fn = eval(info['process_fn'])
                for i, sample in enumerate(tqdm(data, desc=f'processing {src}')):
                    sample['id'] = sample.get('id', f"{src}-{i}")
                    for j, item in enumerate(fn(sample)):
                        item['dataset'] = src
                        item['id'] = f"{src}-test-{i}-{j}"
                        assert 'answer' in item
                        if not args.no_markup_question:
                            item = markup_question(args, item, info['language'], src, task)
                        print(json.dumps(item), file=file, flush=True)

            output_dir = os.path.join(args.output_dir, _src, task, "samples")
            log_dir = os.path.join(args.output_dir, _src, task, "logs")
            os.makedirs(output_dir, exist_ok=True)
            os.makedirs(log_dir, exist_ok=True)
            metrics, agg_preds, result_msg = do_parallel_sampling(args, task, info['answer_extraction_fn'], info['eval_fn'], input_dir, output_dir, log_dir)

            os.makedirs(os.path.dirname(metric_path), exist_ok=True)
            # Close the handle explicitly so the metrics are on disk before
            # anything else reads them.
            with open(metric_path, "w") as file:
                json.dump(metrics, file, indent=4)
            data_path = os.path.join(args.output_dir, _src, task, "samples", "predictions.json")
            os.makedirs(os.path.dirname(data_path), exist_ok=True)
            with open(data_path, "w") as file:
                json.dump(agg_preds, file, ensure_ascii=False)
            print(f"src = {src} | task = {task} >>>\n{result_msg}\n\n", flush=True)
# One launch configuration per released model. Keys are the CLI flags of
# run_subset_parallel.py (without the leading "--"); 'expname' is a label only
# and is never passed on the command line.
configs = [
    {
        'output-dir': "outputs/DeepSeekMath-Base",
        'model-path': "deepseek-ai/deepseek-math-7b-base",
        'tokenizer-path': "deepseek-ai/deepseek-math-7b-base",
        'model-size': "7b",
        'overwrite': False,
        'use-vllm': True,
        'no-markup-question': True,
        'test-conf': "configs/few_shot_test_configs.json",
        'prompt_format': 'few_shot',
        'expname': 'eval-deepseek-math-7b-base'
    },
    {
        'output-dir': "outputs/DeepSeekMath-Instruct",
        'model-path': "deepseek-ai/deepseek-math-7b-instruct",
        'tokenizer-path': "deepseek-ai/deepseek-math-7b-instruct",
        'model-size': "7b",
        'overwrite': False,
        'use-vllm': True,
        'test-conf': "configs/zero_shot_test_configs.json",
        'expname': 'eval-deepseek-math-7b-instruct'
    },
    {
        'output-dir': "outputs/DeepSeekMath-RL",
        'model-path': "deepseek-ai/deepseek-math-7b-rl",
        'tokenizer-path': "deepseek-ai/deepseek-math-7b-rl",
        'model-size': "7b",
        'overwrite': False,
        'use-vllm': True,
        'test-conf': "configs/zero_shot_test_configs.json",
        'expname': 'eval-deepseek-math-7b-rl'
    }
]

base_conf, instruct_conf, rl_conf = configs


def main():
    """Build the `run_subset_parallel.py` command line and launch it in the background.

    String-valued config entries become ``--key value``; truthy non-string
    entries become bare ``--key`` flags; falsy ones are omitted. The sampling
    options come from this script's own arguments.
    """
    import shlex  # local import: only needed to quote values for the shell

    parser = argparse.ArgumentParser()
    parser.add_argument("--n-repeats", type=int, default=1)
    parser.add_argument("--temperature", type=float, default=0)
    parser.add_argument("--n-gpus", type=int, default=8)
    args = parser.parse_args()

    conf = base_conf  # TODO: your conf here
    parts = ["python run_subset_parallel.py"]
    for key, val in conf.items():
        if key == 'expname':
            continue
        if isinstance(val, str):
            # Quote so that paths with spaces or shell metacharacters survive
            # os.system; plain values are left untouched by shlex.quote.
            parts.append(f"--{key} {shlex.quote(val)}")
        elif val:
            parts.append(f"--{key}")
    # 'test-conf' is already emitted by the loop above, so it is not repeated.
    parts.append(f"--n-repeats {args.n_repeats}")
    parts.append(f"--temperature {args.temperature}")
    parts.append(f"--ngpus {args.n_gpus}")
    # Trailing '&' puts the evaluation in the background.
    parts.append("--rank 0 &")
    cmd = " ".join(parts)
    print(cmd, flush=True)
    os.system(cmd)
def _seek_files(path, marker):
    """Yield every file under *path* whose full path contains *marker*."""
    if os.path.isdir(path):
        for subpath in glob(os.path.join(path, "*")):
            yield from _seek_files(subpath, marker)
    elif marker in path:
        yield path


def seek_metrics(path):
    """Recursively yield the paths under *path* that contain ``metrics.json``.

    The match is a substring test on the whole path, so names such as
    ``metrics.json.eval`` are yielded as well.
    """
    yield from _seek_files(path, "metrics.json")


def seek_predictions(path):
    """Recursively yield the paths under *path* that contain ``predictions.json``."""
    yield from _seek_files(path, "predictions.json")


def aggregate_metrics(paths):
    """Merge several metrics files into one sample-weighted average.

    Each file is a JSON object with an ``n_samples`` count plus numeric
    metrics. Every metric is weighted by its file's ``n_samples``.

    Returns a dict of averaged metrics plus the total ``n_samples``. With no
    paths the result is ``{'n_samples': 0}``; when all files report zero
    samples the metrics come out as 0 instead of raising ZeroDivisionError.
    """
    result = {}
    total = 0
    for path in paths:
        with open(path, "r") as file:
            metric = json.load(file)
        n_samples = metric['n_samples']
        total += n_samples
        for key, val in metric.items():
            if key != 'n_samples':
                result[key] = result.get(key, 0) + val * n_samples
    # Guard the denominator the same way do_parallel_sampling does.
    denominator = max(total, 1)
    result = {key: val / denominator for key, val in result.items()}
    result['n_samples'] = total
    return result


def aggregate_predictions(paths):
    """Concatenate the JSON lists stored at *paths*.

    Best effort: a file that cannot be read or parsed is reported by printing
    its path and is then skipped.
    """
    data = []
    for path in paths:
        try:
            with open(path, "r") as file:
                data.extend(json.load(file))
        except Exception:
            # Deliberately broad (unreadable, malformed, wrong shape), but no
            # longer swallows KeyboardInterrupt / SystemExit.
            print(path, flush=True)
            continue
    return data
def _read_scores(eval_path):
    """Return the parsed ``*.eval`` metrics if they cover at least one sample, else None."""
    if not os.path.exists(eval_path):
        return None
    with open(eval_path, "r") as file:
        scores = json.load(file)
    return scores if scores.get('n_samples', 0) else None


def main():
    """Aggregate per-task metrics and predictions for every model and dataset.

    Walks ``<dirname>/<model>/<dataset>/infer_logs``, splits the metrics and
    prediction files into ``cot`` (path contains "cot") and ``tool``
    (everything else), writes the aggregated files to
    ``<dataset>/results/<task>/`` and finally dumps all metrics to
    ``evaluation_results.json``.

    With ``--eval-atp``, miniF2F-Isabelle datasets are additionally scored by
    ``unsafe_score_minif2f_isabelle.py`` unless cached ``metrics.json.eval``
    scores already exist.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--dirname", type=str, default="outputs")
    parser.add_argument("--eval-atp", action='store_true')
    parser.add_argument("--isa-path", type=str, default="")
    parser.add_argument("--theory-file", type=str, default="")
    # Forwarded to the scorer below; it was read but never declared before,
    # which crashed the --eval-atp path with AttributeError.
    parser.add_argument("--working-dir", type=str, default="")
    args = parser.parse_args()

    model2dataset2task2metric = {}
    for model in os.listdir(args.dirname):
        model2dataset2task2metric[model] = {}
        subdir = os.path.join(args.dirname, model)
        for dataset in os.listdir(subdir):
            log_dir = os.path.join(subdir, dataset, "infer_logs")
            agg_dirname = os.path.join(subdir, dataset, "results")
            if not os.path.exists(log_dir):
                # Older layout: move everything of the dataset into infer_logs.
                os.makedirs(log_dir, exist_ok=True)
                os.system(f"mv {subdir}/{dataset}/* {log_dir}")
            metric_paths = list(seek_metrics(log_dir))
            pred_paths = list(seek_predictions(log_dir))
            task2metric_paths = {'cot': [], 'tool': []}
            task2pred_paths = {'cot': [], 'tool': []}
            for path in metric_paths:
                task2metric_paths['cot' if 'cot' in path else 'tool'].append(path)
            for path in pred_paths:
                task2pred_paths['cot' if 'cot' in path else 'tool'].append(path)
            task2metric = {task: aggregate_metrics(paths) for task, paths in task2metric_paths.items()}
            task2pred = {task: aggregate_predictions(paths) for task, paths in task2pred_paths.items()}
            model2dataset2task2metric[model][dataset] = task2metric

            for task in task2metric:
                task_dirname = os.path.join(agg_dirname, task)
                os.makedirs(task_dirname, exist_ok=True)
                metric_path = os.path.join(task_dirname, "metrics.json")
                pred_path = os.path.join(task_dirname, "predictions.json")
                # Close both files before the scorer subprocess reads them.
                with open(metric_path, "w") as file:
                    json.dump(task2metric[task], file, indent=4)
                with open(pred_path, "w") as file:
                    json.dump(task2pred[task], file, indent=4)
                if 'minif2f' in dataset.lower() and 'isabelle' in dataset.lower() and task2pred[task] and args.eval_atp:
                    eval_path = metric_path + ".eval"
                    cached = _read_scores(eval_path)
                    if cached is not None:
                        model2dataset2task2metric[model][dataset][task] = cached
                        continue
                    print(f"Running minif2f-isabelle evaluation on {dataset} ...", flush=True)
                    print(f"Predictions >>> {pred_path}", flush=True)
                    cmd = f"PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python python unsafe_score_minif2f_isabelle.py " \
                        f"--isa-path {args.isa_path} " \
                        f"--theory-file {args.theory_file} " \
                        f"--working-dir {args.working_dir} " \
                        f"--port 9000 " \
                        f"--output {pred_path} "
                    os.system(cmd)
                    # The scorer writes metrics.json.eval next to its input;
                    # report those scores right away instead of on a rerun.
                    fresh = _read_scores(eval_path)
                    if fresh is not None:
                        model2dataset2task2metric[model][dataset][task] = fresh

    with open("evaluation_results.json", "w") as file:
        json.dump(model2dataset2task2metric, file, indent=4, ensure_ascii=False)
class Checker(object):
    """A modified version of the Draft, Sketch, Prove proof-checking client.
    (https://github.com/albertqjiang/draft_sketch_prove/blob/main/autoformalization/checker.py)

    This checker supports Isabelle2022 via PISA
    (https://albertqjiang.github.io/Portal-to-ISAbelle/).

    It supports checking a miniF2F-style proof via `check`.

    Finally, it replaces `sledgehammer` with a call to `normalhammer`.
    """
    def __init__(self, working_dir, isa_path, theory_file, port=9000):
        """Store the Isabelle settings and import the PISA client.

        The client is looked up on ``$PISA_PATH``. If it cannot be imported a
        hint is printed and ``initialise_env`` stays ``None``; `_initialize`
        then raises.
        """
        # .get(): a missing variable should reach the hint below rather than
        # die with a bare KeyError.
        pisa_path = os.environ.get('PISA_PATH')
        if pisa_path:
            sys.path.append(pisa_path)
        self.initialise_env = None
        try:
            from pisa_client import initialise_env
            self.initialise_env = initialise_env
        except Exception as e:
            traceback.print_exc()
            print(e)
            print("Set $PISA_PATH to /yourpath/to/Portal-to-ISAbelle/src/main/python")

        self.working_dir = working_dir
        self.isa_path = isa_path
        self.theory_file = theory_file
        self.port = port

    def _initialize(self):
        """Create a PISA environment for the configured theory file."""
        if self.initialise_env is None:
            raise RuntimeError(
                "pisa_client is not importable; "
                "set $PISA_PATH to /yourpath/to/Portal-to-ISAbelle/src/main/python"
            )
        env = self.initialise_env(
            self.port,
            isa_path=self.isa_path,
            theory_file_path=self.theory_file,
            working_directory=self.working_dir
        )
        return env

    def _exit(self, env):
        """Ask the environment to exit, then force-kill leftover processes.

        WARNING: the kill commands match by process name and therefore hit
        every matching process visible to this user, not only our own.
        """
        try:
            env.post('exit')
        except Exception:
            print("env.post('exit') timed out")
        os.system("ps aux | grep Isabelle2022/contrib | awk '{print $2}' | xargs kill -9 > /dev/null 2>&1")
        os.system("ps aux | grep poly | awk '{print $2}' | xargs kill -9 > /dev/null 2>&1")

    def _parse_output(self, obs):
        """Parse the sledgehammer output, otherwise return an empty string"""
        # NOTE(review): `<hammer>` follows the upstream Draft, Sketch, Prove
        # checker; with an empty marker `str.split` raises ValueError.
        if '<hammer>' in obs:
            output = obs.split('<hammer>')[0]
        else:
            output = ''
        return output

    def _run_step(self, step, i, tls_name, env):
        """Apply *step* on top of state *tls_name*, producing state ``default_<i>``.

        Returns ``(obs, reward, done, metadata, error)``; ``error`` is the
        observation when it carries an error marker, else ``None``.
        """
        obs, reward, done, metadata = env.step_to_top_level_state(
            action=step,
            tls_name=tls_name,
            new_name='default_%d' % i
        )
        error = None
        if 'error:' in obs or 'Step error' in obs or 'Unknown error' in obs:
            error = obs
        return obs, reward, done, metadata, error

    def _run_sledgehammer(self, step, i, tls_name, env):
        """Discharge a `normalhammer` step: cheap tactics first, then the step itself."""
        # First try heuristics
        for heuristic in [
            'by auto', 'by simp', 'by blast', 'by fastforce',
            'by force', 'by eval', 'by presburger', 'by sos',
            'by arith', 'by linarith', 'by (auto simp: field_simps)'
        ]:
            step_ = step.replace('normalhammer', heuristic)
            obs, reward, done, metadata, error = self._run_step(step_, i, tls_name, env)
            if error is None:
                obs = '%s %s' % (heuristic, obs)
                return obs, reward, done, metadata, error
        # Try sledgehammer
        out = self._run_step(step, i, tls_name, env)
        return out

    def check(self, statement_and_proof):
        """Check a theorem statement together with its proof.

        Returns the result dict produced by `_check`.
        """
        # Initialize environment
        env = self._initialize()
        try:
            env.initialise()

            # Wrap and parse theorem
            theory = Checker.wrap_theorem(statement_and_proof)
            steps = Checker.get_parsed(env, theory)
        except BaseException:
            # Do not leave an Isabelle session behind when setup fails;
            # on the normal path `_check` shuts the environment down itself.
            self._exit(env)
            raise

        result = self._check(env, steps)
        return result

    def _check(self, env, steps):
        """Run *steps* in order and always shut the environment down afterwards.

        Stops at the first step that reports an error or raises. Returns a
        dict with ``success``, ``reason``, ``num_steps``, ``last_step`` and
        ``step_results``.
        """
        done = False
        reward = None
        reason = ''
        success = False
        step_results = []
        tls_name = 'default'
        try:
            for i, step in enumerate(steps):
                try:
                    time0 = time.time()
                    if 'normalhammer' in step:
                        obs, reward, done, metadata, error = self._run_sledgehammer(step, i, tls_name, env)
                    else:
                        obs, reward, done, metadata, error = self._run_step(step, i, tls_name, env)
                    step_time = time.time() - time0
                    step_results.append(dict(
                        index=i, step=step, output=self._parse_output(obs), step_time=step_time
                    ))
                    if error is not None:
                        reason = error
                        success = False
                        done = False
                        break
                except Exception:
                    # Timeout - end the proof attempt
                    success = False
                    done = False
                    reason = 'timeout (%d)' % len(step_results)
                    step_results.append(dict(index=i, step=step, output=''))
                    break

                # Change when successful
                tls_name = 'default_%d' % i

            if done and reward == 1.0:
                success = True

            result = {
                'success': success,
                'reason': reason,
                'num_steps': len(steps),
                'last_step': len(step_results),
                'step_results': step_results
            }
        finally:
            # Exit environment
            self._exit(env)
        return result

    @staticmethod
    def wrap_theorem(theorem):
        """Prefix *theorem* with the `Interactive` theory header and its imports."""
        return 'theory Interactive imports HOL.HOL Complex_Main "HOL-Library.Code_Target_Numeral" "HOL-Library.Sum_of_Squares" "Symmetric_Polynomials.Vieta" "HOL-Computational_Algebra.Computational_Algebra" "HOL-Number_Theory.Number_Theory" \n begin\n%s' % theorem

    @staticmethod
    def get_parsed(env, theory, tls_name='default'):
        """Split *theory* into proof steps using the PISA parser.

        Hammer calls come back as `normalhammer` steps.
        """
        # The parsing doesn't work well with `normalhammer`, so we replace
        # all hammer calls with sorry, then replace sorry to normalhammer after parsing.
        theory = theory.replace('sledgehammer', 'sorry')
        theory = theory.replace('normalhammer', 'sorry')

        # NOTE(review): `<parse text>` / `<SEP>` follow the upstream Draft,
        # Sketch, Prove checker; an empty separator makes `split` raise.
        steps = env.post(f"<parse text> ${theory}")
        steps = steps.split('<SEP>')
        # remove '$' step and whitespace steps
        steps = [s for s in steps if s != '$' and s.strip() != '']
        steps = [s.replace('sorry', 'normalhammer') for s in steps]
        return steps


def check_proof(formal_statement, proof, working_dir, isa_path, theory_file, port):
    """Check *proof* for *formal_statement* in a fresh `Checker` and return its result dict."""
    checker = Checker(
        working_dir=working_dir,
        isa_path=isa_path,
        theory_file=theory_file,
        port=port
    )
    theorem_with_proof = f"{formal_statement}\n{proof}"
    result = checker.check(theorem_with_proof)
    return result
# --- evaluation/unsafe_score_minif2f_isabelle.py (continued) ---

def main(args):
    """Score the miniF2F-Isabelle predictions stored at ``args.output``.

    Each document's formal statement is the text after ``"Formal:"`` in its
    second-to-last message; its ``prediction`` is checked as the proof.
    Writes the annotated documents to ``<output>.eval`` and the summary to
    ``metrics.json.eval`` in the same directory.
    """
    with open(args.output) as f:
        docs = json.load(f)

    if args.limit:
        limit = args.limit
    else:
        limit = len(docs)

    pass_at_1s = []
    pass_at_anys = []
    for i, doc in enumerate(tqdm(docs[:limit])):
        formal_statement = doc['messages'][-2]['content'].split("Formal:", 1)[1].strip()
        proofs = [doc['prediction'].strip()]

        pass_at_1 = 0
        pass_at_any = 0
        checked_proofs = []
        for j, proof in enumerate(proofs):
            result = check_proof(
                formal_statement=formal_statement,
                proof=proof,
                working_dir=args.working_dir,
                isa_path=args.isa_path,
                theory_file=args.theory_file,
                port=args.port
            )

            if result['success']:
                pass_at_any = 1
                if j == 0:
                    pass_at_1 = 1

            checked_proofs.append({
                'proof': proof,
                'result': result
            })

        pass_at_1s.append(pass_at_1)
        pass_at_anys.append(pass_at_any)

        print(f"acc: {sum(pass_at_1s)} / {len(pass_at_1s)} = {sum(pass_at_1s) / max(len(pass_at_1s), 1)}", flush=True)

        doc['eval'] = {
            'checked_proofs': checked_proofs,
            'pass_at_1': pass_at_1,
            'pass_at_any': pass_at_any
        }

    # Same guard as the progress line above: an empty prediction file yields
    # zero scores instead of ZeroDivisionError.
    n_scored = max(len(pass_at_1s), 1)
    metrics = {
        "pass_at_1": sum(pass_at_1s) / n_scored,
        "pass_at_any": sum(pass_at_anys) / n_scored,
        "n_samples": len(pass_at_1s)
    }

    output_path = args.output + ".eval"
    metrics_path = os.path.join(os.path.dirname(args.output), "metrics.json.eval")
    with open(output_path, "w") as f:
        json.dump(docs, f, indent=4)
    with open(metrics_path, "w") as f:
        json.dump(metrics, f, indent=4)


if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)
    logging.critical(
        "THIS PROGRAM EXECUTES UNTRUSTED MODEL GENERATED CODE."
        "THERE HAS BEEN NO EFFORT TO AVOID OS AND NETWORK SIDE EFFECTS."
        "USE WITH CAUTION."
    )

    parser = argparse.ArgumentParser("Unsafe script for scoring the minif2f_isabelle tasks")

    parser.add_argument(
        "--isa-path",
        type=str,
        help="path to Isabelle installation (see setup documentation), e.g. "
             "/path/to/Isabelle2022"
    )
    parser.add_argument(
        "--theory-file",
        type=str,
        help="path to Interactive.thy (see setup documentation), e.g. "
             "/path/to/Isabelle2022/src/HOL/Examples/Interactive.thy"
    )
    parser.add_argument(
        "--working-dir",
        type=str,
        help="path to Isabelle working directory (see setup documentation), e.g. "
             "/path/to/Isabelle2022/src/HOL/Examples"
    )
    parser.add_argument(
        "--port",
        type=int,
        default=9000,
        help="PISA server port (see setup documentation)"
    )
    parser.add_argument(
        "--output",
        type=str,
        help="path to output file from running miniF2F Isabelle tasks"
    )
    parser.add_argument(
        "--limit",
        type=int,
        default=None,
        help="for debugging purposes, max examples per task to process"
    )

    args = parser.parse_args()
    main(args)


# --- evaluation/utils.py ---

def set_seed(seed):
    """Seed Python's and NumPy's global RNGs; a seed <= 0 leaves them untouched."""
    if seed > 0:
        random.seed(seed)
        np.random.seed(seed)


def shuffle(data, seed):
    """Return *data* in shuffled order as a new list.

    A negative seed returns *data* itself, unshuffled. Otherwise `set_seed` is
    applied first, so a positive seed gives a reproducible order while seed 0
    shuffles with whatever state the global NumPy RNG is in.
    """
    if seed < 0:
        return data
    set_seed(seed)
    indices = list(range(len(data)))
    np.random.shuffle(indices)
    data = [data[i] for i in indices]
    return data


def read_data(path):
    """Load a ``.json`` file (any JSON value) or a ``.jsonl`` file (list of records).

    Blank lines in a JSONL file are skipped.

    Raises:
        NotImplementedError: if *path* ends with neither ``json`` nor ``jsonl``.
    """
    if path.endswith("json"):
        with open(path, "r") as file:
            data = json.load(file)
    elif path.endswith("jsonl"):
        data = []
        with open(path, "r") as file:
            for line in file:
                if not line.strip():
                    # Tolerate empty lines, e.g. a trailing newline at EOF.
                    continue
                data.append(json.loads(line))
    else:
        raise NotImplementedError(f"unsupported data file extension: {path}")
    return data
-------------------------------------------------------------------------------- /images/base_results_2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deepseek-ai/DeepSeek-Math/b8b0f8ce093d80bf8e9a641e44142f06d092c305/images/base_results_2.png -------------------------------------------------------------------------------- /images/base_results_3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deepseek-ai/DeepSeek-Math/b8b0f8ce093d80bf8e9a641e44142f06d092c305/images/base_results_3.png -------------------------------------------------------------------------------- /images/corpus.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deepseek-ai/DeepSeek-Math/b8b0f8ce093d80bf8e9a641e44142f06d092c305/images/corpus.png -------------------------------------------------------------------------------- /images/data_pipeline.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deepseek-ai/DeepSeek-Math/b8b0f8ce093d80bf8e9a641e44142f06d092c305/images/data_pipeline.png -------------------------------------------------------------------------------- /images/instruct_results.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deepseek-ai/DeepSeek-Math/b8b0f8ce093d80bf8e9a641e44142f06d092c305/images/instruct_results.png -------------------------------------------------------------------------------- /images/logo.svg: -------------------------------------------------------------------------------- 1 | 2 | 3 | Created with Pixso. 
class Predictor(BasePredictor):
    """Cog predictor that streams completions from DeepSeekMath 7B base."""

    def setup(self) -> None:
        """Load the model into memory to make running multiple predictions efficient"""
        checkpoint = "deepseek-ai/deepseek-math-7b-base"

        self.tokenizer = AutoTokenizer.from_pretrained(checkpoint, cache_dir=CACHE_DIR)
        self.model = AutoModelForCausalLM.from_pretrained(
            checkpoint,
            torch_dtype=torch.bfloat16,
            device_map="auto",
            cache_dir=CACHE_DIR,
        )

        # Pad with the end-of-sequence token id.
        generation_config = GenerationConfig.from_pretrained(checkpoint, cache_dir=CACHE_DIR)
        self.model.generation_config = generation_config
        self.model.generation_config.pad_token_id = self.model.generation_config.eos_token_id

    def predict(
        self,
        text: str = Input(
            description="Input text.",
            default="The integral of x^2 from 0 to 2 is",
        ),
        max_new_tokens: int = Input(
            description="The maximum numbers of tokens to generate, ignoring the number of tokens in the prompt.",
            default=100,
        ),
        temperature: float = Input(
            description="The value used to modulate the next token probabilities.",
            default=1,
        ),
        top_k: int = Input(
            description="The number of highest probability vocabulary tokens to keep for top-k-filtering.",
            default=50,
        ),
        top_p: float = Input(
            description="If set to float < 1, only the smallest set of most probable tokens with probabilities that add up to top_p or higher are kept for generation.",
            default=0.9,
        ),
    ) -> ConcatenateIterator[str]:
        """Run a single prediction on the model"""
        encoded = self.tokenizer(text, return_tensors="pt")
        token_stream = TextIteratorStreamer(
            self.tokenizer, skip_prompt=True, skip_special_tokens=True
        )

        with torch.inference_mode():
            generate_kwargs = dict(
                **encoded.to(self.model.device),
                do_sample=True,
                temperature=temperature,
                top_p=top_p,
                top_k=top_k,
                max_new_tokens=max_new_tokens,
                streamer=token_stream,
                use_cache=True
            )
            # Generation runs in a worker thread; this generator relays the
            # decoded pieces as the streamer produces them.
            worker = Thread(target=self.model.generate, kwargs=generate_kwargs)
            worker.start()
            yield from token_stream
            worker.join()
class Predictor(BasePredictor):
    """Cog predictor that streams chat completions from DeepSeekMath 7B instruct."""

    def setup(self) -> None:
        """Load the model into memory to make running multiple predictions efficient"""

        model_name = "deepseek-ai/deepseek-math-7b-instruct"
        self.tokenizer = AutoTokenizer.from_pretrained(model_name, cache_dir=CACHE_DIR)
        self.model = AutoModelForCausalLM.from_pretrained(
            model_name,
            torch_dtype=torch.bfloat16,
            device_map="auto",
            cache_dir=CACHE_DIR,
        )
        self.model.generation_config = GenerationConfig.from_pretrained(
            model_name, cache_dir=CACHE_DIR
        )
        # Pad with the end-of-sequence token id.
        self.model.generation_config.pad_token_id = (
            self.model.generation_config.eos_token_id
        )

    def predict(
        self,
        text: str = Input(
            description="Input text.",
            # The backslash is escaped: in a normal string literal "\b" is a
            # backspace character, which would corrupt the "\boxed{}" hint.
            default="what is the integral of x^2 from 0 to 2?\nPlease reason step by step, and put your final answer within \\boxed{}.",
        ),
        max_new_tokens: int = Input(
            description="The maximum numbers of tokens to generate, ignoring the number of tokens in the prompt.",
            default=100,
        ),
        temperature: float = Input(
            description="The value used to modulate the next token probabilities.",
            default=1,
        ),
        top_k: int = Input(
            description="The number of highest probability vocabulary tokens to keep for top-k-filtering.",
            default=50,
        ),
        top_p: float = Input(
            description="If set to float < 1, only the smallest set of most probable tokens with probabilities that add up to top_p or higher are kept for generation.",
            default=0.9,
        ),
    ) -> ConcatenateIterator[str]:
        """Run a single prediction on the model"""

        # Wrap the prompt as a single user turn and let the tokenizer's chat
        # template add the generation prompt.
        messages = [{"role": "user", "content": text}]
        input_tensor = self.tokenizer.apply_chat_template(
            messages, add_generation_prompt=True, return_tensors="pt"
        )
        streamer = TextIteratorStreamer(
            self.tokenizer, skip_prompt=True, skip_special_tokens=True
        )

        with torch.inference_mode():
            # Generation runs in a worker thread; this generator relays the
            # decoded pieces as the streamer produces them.
            thread = Thread(
                target=self.model.generate,
                kwargs=dict(
                    input_ids=input_tensor.to(self.model.device),
                    do_sample=True,
                    temperature=temperature,
                    top_p=top_p,
                    top_k=top_k,
                    max_new_tokens=max_new_tokens,
                    streamer=streamer,
                    use_cache=True,
                ),
            )
            thread.start()
            for new_token in streamer:
                yield new_token
            thread.join()