├── .gitignore ├── LICENSE ├── README.md ├── browsecomp_eval.py ├── common.py ├── drop_eval.py ├── gpqa_eval.py ├── healthbench_eval.py ├── healthbench_eval_test.py ├── healthbench_meta_eval.py ├── healthbench_meta_eval_test.py ├── healthbench_scripts └── healthbench_analysis.ipynb ├── humaneval_eval.py ├── math_eval.py ├── mgsm_eval.py ├── mmlu_eval.py ├── multilingual_mmlu_benchmark_results.md ├── run_multilingual_mmlu.py ├── sampler ├── chat_completion_sampler.py ├── claude_sampler.py ├── o_chat_completion_sampler.py └── responses_sampler.py ├── simple_evals.py ├── simpleqa_eval.py └── types.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 OpenAI 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Overview 2 | This repository contains a lightweight library for evaluating language models. 3 | We are open sourcing it so we can be transparent about the accuracy numbers we're publishing alongside our latest models. 4 | 5 | ## Benchmark Results 6 | 7 | | Model | Prompt | MMLU | GPQA [^8] | MATH [^6]| HumanEval | MGSM[^5] | DROP[^5]
(F1, 3-shot) | SimpleQA 8 | |:----------------------------:|:-------------:|:------:|:------:|:--------:|:---------:|:------:|:--------------------------:|:---------:| 9 | | **o3** | | | | | | | | | | 10 | | o3-high [^10] | n/a [^7] | 93.3 | 83.4 | 98.1 | 88.4 | 92.0 | 89.8 | 48.6 | 11 | | o3 [^9] [^10] | n/a | 92.9 | 82.8 | 97.8 | 87.4 | 92.3 | 80.6 | 49.4 | 12 | | o3-low [^10] | n/a | 92.8 | 78.6 | 96.9 | 87.3 | 91.9 | 82.3 | 49.4 | 13 | | **o4-mini** | | | | | | | | | 14 | | o4-mini-high [^9] [^10] | n/a | 90.3 | 81.3 | 98.2 | 99.3 | 93.5 | 78.1 | 19.3 | 15 | | o4-mini [^9] [^10] | n/a | 90.0 | 77.6 | 97.5 | 97.3 | 93.7 | 77.7 | 20.2 | 16 | | o4-mini-low [^10] | n/a | 89.5 | 73.6 | 96.2 | 95.9 | 93.0 | 76.0 | 20.2 | 17 | | **o3-mini** | | | | | | | | | | 18 | | o3-mini-high | n/a | 86.9 | 77.2 | 97.9 | 97.6 | 92.0 | 80.6 | 13.8 | 19 | | o3-mini | n/a | 85.9 | 74.9 | 97.3 | 96.3 | 90.8 | 79.2 | 13.4 | 20 | | o3-mini-low | n/a | 84.9 | 67.6 | 95.8 | 94.5 | 89.4 | 77.6 | 13.0 | 21 | | **o1** | | | | | | | | | 22 | | o1 | n/a | 91.8 | 75.7 | 96.4 | - | 89.3 | 90.2 | 42.6 | 23 | | o1-preview | n/a | 90.8 | 73.3 | 85.5 | 92.4 | 90.8 | 74.8 | 42.4 | 24 | | o1-mini | n/a | 85.2 | 60.0 | 90.0 | 92.4 | 89.9 | 83.9 | 07.6 | 25 | | **GPT-4.1** | | | | | | | | | | 26 | | gpt-4.1-2025-04-14 | assistant [^2]| 90.2 | 66.3 | 82.1 | 94.5 | 86.9 | 79.4 | 41.6 | 27 | | gpt-4.1-mini-2025-04-14 | assistant | 87.5 | 65.0 | 81.4 | 93.8 | 88.2 | 81.0 | 16.8 | 28 | | gpt-4.1-nano-2025-04-14 | assistant | 80.1 | 50.3 | 62.3 | 87.0 | 73.0 | 82.2 | 07.6 | 29 | | **GPT-4o** | | | | | | | | | | 30 | | gpt-4o-2024-11-20 | assistant | 85.7 | 46.0 | 68.5 | 90.2 | 90.3 | 81.5 | 38.8 | 31 | | gpt-4o-2024-08-06 | assistant | 88.7 | 53.1 | 75.9 | 90.2 | 90.0 | 79.8 | 40.1 | 32 | | gpt-4o-2024-05-13 | assistant | 87.2 | 49.9 | 76.6 | 91.0 | 89.9 | 83.7 | 39.0 | 33 | | gpt-4o-mini-2024-07-18 | assistant | 82.0 | 40.2 | 70.2 | 87.2 | 87.0 | 79.7 | 09.5 | 34 | | **GPT-4.5-preview** | | | | | | | | | 35 | | gpt-4.5-preview-2025-02-27 | assistant | 90.8 | 69.5 | 87.1 | 88.6 | 86.9 | 83.4 | 62.5 | 36 | | **GPT-4 Turbo and GPT-4** | | | | | | | | | 37 | | gpt-4-turbo-2024-04-09 | assistant | 86.7 | 49.3 | 73.4 | 88.2 | 89.6 | 86.0 | 24.2 | 38 | | gpt-4-0125-preview | assistant | 85.4 | 41.4 | 64.5 | 86.6 | 85.1 | 81.5 | n/a | 39 | | gpt-4-1106-preview | assistant | 84.7 | 42.5 | 64.3 | 83.7 | 87.1 | 83.2 | n/a | 40 | | **Other Models (Reported)** | | | | | | | | 41 | | [Claude 3.5 Sonnet](https://www.anthropic.com/news/claude-3-5-sonnet) | unknown | 88.3 | 59.4 | 71.1 | 92.0 | 91.6 | 87.1 | 28.9 | 42 | | [Claude 3 Opus](https://www.anthropic.com/news/claude-3-family) | unknown | 86.8 | 50.4 | 60.1 | 84.9 | 90.7 | 83.1 | 23.5 | 43 | | [Llama 3.1 405b](https://github.com/meta-llama/llama-models/blob/main/models/llama3_1/MODEL_CARD.md) | unknown | 88.6 | 50.7 | 73.8 | 89.0 | 91.6 | 84.8 | n/a 44 | | [Llama 3.1 70b](https://github.com/meta-llama/llama-models/blob/main/models/llama3_1/MODEL_CARD.md) | unknown | 82.0 | 41.7 | 68.0 | 80.5 | 86.9 | 79.6 | n/a 45 | | [Llama 3.1 8b](https://github.com/meta-llama/llama-models/blob/main/models/llama3_1/MODEL_CARD.md) | unknown | 68.4 | 30.4 | 51.9 | 72.6 | 68.9 | 59.5 | n/a 46 | | [Grok 2](https://x.ai/blog/grok-2) | unknown | 87.5 | 56.0 | 76.1 | 88.4 | n/a | n/a | n/a 47 | | [Grok 2 mini](https://x.ai/blog/grok-2) | unknown | 86.2 | 51.0 | 73.0 | 85.7 | n/a | n/a | n/a 48 | | [Gemini 1.0 Ultra](https://goo.gle/GeminiV1-5) | unknown | 83.7 | n/a | 53.2 | 74.4 | 79.0 | 82.4 | n/a 49 | | 
[Gemini 1.5 Pro](https://goo.gle/GeminiV1-5) | unknown | 81.9 | n/a | 58.5 | 71.9 | 88.7 | 78.9 | n/a 50 | | [Gemini 1.5 Flash](https://goo.gle/GeminiV1-5) | unknown | 77.9 | 38.6 | 40.9 | 71.5 | 75.5 | 78.4 | n/a 51 | 52 | ## Background 53 | 54 | Evals are sensitive to prompting, and there's significant variation in the formulations used in recent publications and libraries. 55 | Some use few-shot prompts or role playing prompts ("You are an expert software programmer..."). 56 | These approaches are carryovers from evaluating *base models* (rather than instruction/chat-tuned models) and from models that were worse at following instructions. 57 | 58 | For this library, we are emphasizing the *zero-shot, chain-of-thought* setting, with simple instructions like "Solve the following multiple choice problem". We believe that this prompting technique is a better reflection of the models' performance in realistic usage. 59 | 60 | **We will not be actively maintaining this repository and monitoring PRs and Issues.** In particular, we're not accepting new evals. Here are the changes we might accept. 61 | - Bug fixes (hopefully not needed!) 62 | - Adding adapters for new models 63 | - Adding new rows to the table below with eval results, given new models and new system prompts. 64 | 65 | This repository is NOT intended as a replacement for https://github.com/openai/evals, which is designed to be a comprehensive collection of a large number of evals. 66 | 67 | ## Evals 68 | 69 | This repository currently contains the following evals: 70 | 71 | - MMLU: Measuring Massive Multitask Language Understanding, reference: https://arxiv.org/abs/2009.03300, https://github.com/hendrycks/test, [MIT License](https://github.com/hendrycks/test/blob/master/LICENSE) 72 | - MATH: Measuring Mathematical Problem Solving With the MATH Dataset, reference: https://arxiv.org/abs/2103.03874, https://github.com/hendrycks/math, [MIT License](https://github.com/idavidrein/gpqa/blob/main/LICENSE) 73 | - GPQA: A Graduate-Level Google-Proof Q&A Benchmark, reference: https://arxiv.org/abs/2311.12022, https://github.com/idavidrein/gpqa/, [MIT License](https://github.com/idavidrein/gpqa/blob/main/LICENSE) 74 | - DROP: A Reading Comprehension Benchmark Requiring Discrete Reasoning Over Paragraphs, reference: https://arxiv.org/abs/1903.00161, https://allenai.org/data/drop, [Apache License 2.0](https://github.com/allenai/allennlp-models/blob/main/LICENSE) 75 | - MGSM: Multilingual Grade School Math Benchmark (MGSM), Language Models are Multilingual Chain-of-Thought Reasoners, reference: https://arxiv.org/abs/2210.03057, https://github.com/google-research/url-nlp, [Creative Commons Attribution 4.0 International Public License (CC-BY)](https://github.com/google-research/url-nlp/blob/main/LICENSE) 76 | - HumanEval: Evaluating Large Language Models Trained on Code, reference https://arxiv.org/abs/2107.03374, https://github.com/openai/human-eval, [MIT License](https://github.com/openai/human-eval/blob/master/LICENSE) 77 | - SimpleQA: Measuring short-form factuality in large language models, reference: https://openai.com/index/introducing-simpleqa, [MIT License](https://github.com/openai/simple-evals/blob/main/LICENSE) 78 | - BrowseComp: A Simple Yet Challenging Benchmark for Browsing Agents, reference: https://openai.com/index/browsecomp, [MIT License](https://github.com/openai/simple-evals/blob/main/LICENSE) 79 | - HealthBench: Evaluating Large Language Models Towards Improved Human Health, reference: https://openai.com/index/healthbench, 
[MIT License](https://github.com/openai/simple-evals/blob/main/LICENSE) 80 | 81 | ## Samplers 82 | 83 | We have implemented sampling interfaces for the following language model APIs: 84 | 85 | - OpenAI: https://platform.openai.com/docs/overview 86 | - Claude: https://www.anthropic.com/api 87 | 88 | Make sure to set the `*_API_KEY` environment variables before using these APIs. 89 | 90 | ## Setup 91 | 92 | Due to the optional dependencies, we're not providing a unified setup mechanism. Instead, we're providing instructions for each eval and sampler. 93 | 94 | For [HumanEval](https://github.com/openai/human-eval/) (python programming): 95 | ```bash 96 | git clone https://github.com/openai/human-eval 97 | pip install -e human-eval 98 | ``` 99 | 100 | For the [OpenAI API](https://pypi.org/project/openai/): 101 | ```bash 102 | pip install openai 103 | ``` 104 | 105 | For the [Anthropic API](https://docs.anthropic.com/claude/docs/quickstart-guide): 106 | ```bash 107 | pip install anthropic 108 | ``` 109 | 110 | ## Running the evals 111 | ```bash 112 | python -m simple-evals.simple_evals --list-models 113 | ``` 114 | This will list all the models that you can evaluate. 115 | 116 | To run the evaluations, you can use the following command: 117 | ```bash 118 | python -m simple-evals.simple_evals --model <model_name> --examples <num_examples> 119 | ``` 120 | This will launch evaluations through the OpenAI API. 121 | 122 | ## Notes 123 | 124 | [^1]:chatgpt system message: "You are ChatGPT, a large language model trained by OpenAI, based on the GPT-4 architecture.\nKnowledge cutoff: 2023-12\nCurrent date: 2024-04-01" 125 | [^2]:assistant system message in [OpenAI API doc](https://platform.openai.com/docs/api-reference/introduction): "You are a helpful assistant." 126 | [^3]:claude-3 empty system message: suggested by Anthropic API doc, and we have done limited experiments due to [rate limit](https://docs.anthropic.com/claude/reference/rate-limits) issues, but we welcome PRs with alternative choices. 127 | [^4]:claude-3 lmsys system message: system message in LMSYS [Fast-chat open source code](https://github.com/lm-sys/FastChat/blob/7899355ebe32117fdae83985cf8ee476d2f4243f/fastchat/conversation.py#L894): "The assistant is Claude, created by Anthropic. The current date is {{currentDateTime}}. Claude's knowledge base was last updated ... ". We have done limited experiments due to [rate limit](https://docs.anthropic.com/claude/reference/rate-limits) issues, but we welcome PRs with alternative choices. 128 | [^5]:We believe these evals are saturated for our newer models, but are reporting them for completeness. 129 | [^6]:For newer models (anything on or after o1) we evaluate on [MATH-500](https://github.com/openai/prm800k/tree/main/prm800k/math_splits), which is a newer, IID version of MATH. 130 | [^7]:o-series models do not support using a system prompt. 131 | [^8]:Includes an answer regex tweak for the GPQA benchmark. 132 | [^9]:The default reasoning level for o3-mini is "medium". 133 | [^10]:These results are with no tools enabled for o3 or o4-mini. 134 | 135 | ## Legal Stuff 136 | By contributing to evals, you are agreeing to make your evaluation logic and data available under the same MIT license as this repository. You must have adequate rights to upload any data used in an eval. OpenAI reserves the right to use this data in future service improvements to our product. Contributions to OpenAI evals will be subject to our usual Usage Policies: https://platform.openai.com/docs/usage-policies.
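In addition to the CLI entry point shown under "Running the evals" above, the eval classes can also be driven directly from Python. The sketch below is illustrative only: it assumes the checkout is importable as a package named `simple_evals` (the repository clones as `simple-evals`, so you may need to rename the directory or adjust `sys.path`), and that `ChatCompletionSampler(model=...)` matches the constructor in `sampler/chat_completion_sampler.py`; check that file for the exact signature before relying on this.

```python
# Illustrative sketch only -- the import paths and the sampler constructor
# arguments are assumptions; verify them against the repository before use.
from simple_evals.sampler.chat_completion_sampler import ChatCompletionSampler
from simple_evals.gpqa_eval import GPQAEval
from simple_evals import common

sampler = ChatCompletionSampler(model="gpt-4o-mini-2024-07-18")  # requires OPENAI_API_KEY
gpqa = GPQAEval(n_repeats=1, num_examples=5)  # small subset for a quick smoke test
result = gpqa(sampler)  # each eval is callable and returns an EvalResult

print(result.score, result.metrics)
with open("gpqa_report.html", "w") as f:
    f.write(common.make_report(result))  # standalone HTML report of all examples
```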
137 | -------------------------------------------------------------------------------- /browsecomp_eval.py: -------------------------------------------------------------------------------- 1 | """ 2 | BrowseComp: A Simple Yet Challenging Benchmark for Browsing Agents 3 | Authors: Jason Wei, Zhiqing Sun, Spencer Papay, Scott McKinney, Jeffrey Han, Isa Fulford, Hyung Won Chung, Alex Tachard Passos, William Fedus, Mia Glaese 4 | https://openai.com/index/browsecomp/ 5 | """ 6 | 7 | import base64 8 | import hashlib 9 | import random 10 | import re 11 | import pandas 12 | from . import common 13 | from .types import Eval, EvalResult, SamplerBase, SingleEvalResult 14 | 15 | # from: https://github.com/centerforaisafety/hle/blob/7b6be5aad6f9b43af3857de7867f3b52f6e4acb3/hle_eval/run_model_predictions.py#L11 16 | QUERY_TEMPLATE = """ 17 | {Question} 18 | 19 | Your response should be in the following format: 20 | Explanation: {{your explanation for your final answer}} 21 | Exact Answer: {{your succinct, final answer}} 22 | Confidence: {{your confidence score between 0% and 100% for your answer}} 23 | """.strip() 24 | 25 | # from: https://github.com/centerforaisafety/hle/blob/7b6be5aad6f9b43af3857de7867f3b52f6e4acb3/hle_eval/run_judge_results.py#L16-L33 26 | GRADER_TEMPLATE = """ 27 | Judge whether the following [response] to [question] is correct or not based on the precise and unambiguous [correct_answer] below. 28 | 29 | [question]: {question} 30 | 31 | [response]: {response} 32 | 33 | Your judgement must be in the format and criteria specified below: 34 | 35 | extracted_final_answer: The final exact answer extracted from the [response]. Put the extracted answer as 'None' if there is no exact, final answer to extract from the response. 36 | 37 | [correct_answer]: {correct_answer} 38 | 39 | reasoning: Explain why the extracted_final_answer is correct or incorrect based on [correct_answer], focusing only on if there are meaningful differences between [correct_answer] and the extracted_final_answer. Do not comment on any background to the problem, do not attempt to solve the problem, do not argue for any answer different than [correct_answer], focus only on whether the answers match. 40 | 41 | correct: Answer 'yes' if extracted_final_answer matches the [correct_answer] given above, or is within a small margin of error for numerical problems. Answer 'no' otherwise, i.e. if there if there is any inconsistency, ambiguity, non-equivalency, or if the extracted answer is incorrect. 42 | 43 | 44 | confidence: The extracted confidence score between 0|\%| and 100|\%| from [response]. Put 100 if there is no confidence score available. 
45 | """.strip() 46 | 47 | CHOICE_STRINGS = ["yes", "no"] 48 | 49 | 50 | def derive_key(password: str, length: int) -> bytes: 51 | """Derive a fixed-length key from the password using SHA256.""" 52 | hasher = hashlib.sha256() 53 | hasher.update(password.encode()) 54 | key = hasher.digest() 55 | return key * (length // len(key)) + key[: length % len(key)] 56 | 57 | 58 | def decrypt(ciphertext_b64: str, password: str) -> str: 59 | """Decrypt base64-encoded ciphertext with XOR.""" 60 | encrypted = base64.b64decode(ciphertext_b64) 61 | key = derive_key(password, len(encrypted)) 62 | decrypted = bytes(a ^ b for a, b in zip(encrypted, key)) 63 | return decrypted.decode() 64 | 65 | 66 | class BrowseCompEval(Eval): 67 | def __init__(self, grader_model: SamplerBase, num_examples: int | None = None, n_repeats: int = 1): 68 | df = pandas.read_csv( 69 | "https://openaipublic.blob.core.windows.net/simple-evals/browse_comp_test_set.csv" 70 | ) 71 | examples = [row.to_dict() for _, row in df.iterrows()] 72 | if num_examples: 73 | assert n_repeats == 1, "n_repeats only supported when max_examples = None" 74 | rng = random.Random(0) 75 | examples = rng.sample(examples, num_examples) 76 | self.examples = examples * n_repeats 77 | self.grader_model = grader_model 78 | 79 | def grade_sample(self, question: str, correct_answer: str, response: str) -> str: 80 | grader_prompt = GRADER_TEMPLATE.format( 81 | question=question, 82 | correct_answer=correct_answer, 83 | response=response, 84 | ) 85 | 86 | prompt_messages = [ 87 | self.grader_model._pack_message(content=grader_prompt, role="user") 88 | ] 89 | sampler_response = self.grader_model(prompt_messages) 90 | grading_response = sampler_response.response_text 91 | 92 | match = re.search(r"correct: (yes|no)", grading_response) 93 | return match.group(0) if match else "no" # Default to "no" if no match 94 | 95 | def __call__(self, sampler: SamplerBase) -> EvalResult: 96 | def fn(row: dict): 97 | problem = decrypt(row.get("problem", ""), row.get("canary", "")) 98 | answer = decrypt(row.get("answer", ""), row.get("canary", "")) 99 | prompt_messages = [ 100 | sampler._pack_message(content=QUERY_TEMPLATE.format(Question=problem), role="user") 101 | ] 102 | sampler_response = sampler(prompt_messages) 103 | response_text = sampler_response.response_text 104 | actual_queried_prompt_messages = sampler_response.actual_queried_message_list 105 | grade_result = self.grade_sample(problem, answer, response_text) 106 | 107 | # Metrics based on grading response 108 | is_correct = grade_result == "yes" 109 | is_incorrect = grade_result == "no" 110 | 111 | score = is_correct 112 | 113 | # Create HTML for each sample result 114 | html = common.jinja_env.from_string(common.HTML_JINJA).render( 115 | prompt_messages=actual_queried_prompt_messages, 116 | next_message=dict(content=response_text, role="assistant"), 117 | score=score, 118 | correct_answer=row["answer"], 119 | extracted_answer=response_text, 120 | ) 121 | convo = actual_queried_prompt_messages + [dict(content=response_text, role="assistant")] 122 | return SingleEvalResult(html=html, score=score, convo=convo, metrics={ 123 | "is_correct": is_correct, 124 | "is_incorrect": is_incorrect, 125 | }) 126 | 127 | # Run evaluation and collect results 128 | results = common.map_with_progress(fn, self.examples) 129 | 130 | # Aggregate metrics 131 | aggregate_metrics = { 132 | "is_correct": sum(result.metrics["is_correct"] for result in results) / len(results), 133 | "is_incorrect": sum(result.metrics["is_incorrect"] for result in 
results) / len(results), 134 | } 135 | print("AGGREGATE METRICS") 136 | print(aggregate_metrics) 137 | print("##################") 138 | 139 | output_d = { 140 | "accuracy": aggregate_metrics["is_correct"], 141 | } 142 | 143 | print(f"Accuracy: {output_d['accuracy']:.3f}") 144 | 145 | return common.aggregate_results(results) 146 | -------------------------------------------------------------------------------- /common.py: -------------------------------------------------------------------------------- 1 | import io 2 | import os 3 | from collections import defaultdict 4 | from concurrent.futures import ThreadPoolExecutor, as_completed 5 | from multiprocessing.pool import ThreadPool 6 | from typing import Any, Callable 7 | 8 | import jinja2 9 | import numpy as np 10 | import requests 11 | from tqdm import tqdm 12 | 13 | from .types import EvalResult, Message, SamplerBase, SingleEvalResult 14 | 15 | QUERY_TEMPLATE_MULTICHOICE = """ 16 | Answer the following multiple choice question. The last line of your response should be of the following format: 'Answer: $LETTER' (without quotes) where LETTER is one of ABCD. Think step by step before answering. 17 | 18 | {Question} 19 | 20 | A) {A} 21 | B) {B} 22 | C) {C} 23 | D) {D} 24 | """.strip() 25 | 26 | ANSWER_PATTERN_MULTICHOICE = r"(?i)Answer[ \t]*:[ \t]*\$?([A-D])\$?" 27 | ANSWER_PATTERN = r"(?i)Answer\s*:\s*([^\n]+)" 28 | MULTILINGUAL_ANSWER_PATTERN_TEMPLATE = ( 29 | "(?i){}[ \t]*([A-D]|[أ-د]|[অ]|[ব]|[ড]|[ঢ]|[A]|[B]|[C]|[D])" 30 | ) 31 | # All the different ways "Answer" is written in different languages 32 | MULTILINGUAL_ANSWER_REGEXES = [ 33 | "Answer\s*:", 34 | "Answer\s*:​​​​​​", # Korean invisible character 35 | "উত্তর\s*:", 36 | "उत्तर\s*:", 37 | "উত্তরঃ", 38 | "উত্তর\s*:", 39 | "Antwort\s*:", 40 | "답변\s*:", 41 | "정답\s*:", 42 | "답\s*:", 43 | "答案\s*:", 44 | "答案\s*:", 45 | "答\s*:", 46 | "答\s*:", 47 | "答复\s*:", 48 | "答曰\s*:", 49 | "الإجابة:", 50 | "الجواب:", 51 | "إجابة:", 52 | "الإجابة النهائية:", 53 | "الإجابة الصحيحة:", 54 | "الإجابة الصحيحة هي:", 55 | "الإجابة هي:", 56 | "الجواب النهائي:", 57 | "Respuesta\s*:", 58 | "Risposta\s*:", 59 | "答え\s*:", 60 | "答え\s*:", 61 | "回答\s*:", 62 | "回答\s*:", 63 | "解答\s*:", 64 | "Jawaban\s*:", 65 | "Réponse\s*:", 66 | "Resposta\s*:", 67 | "Jibu\s*:", 68 | "Idahun\s*:", 69 | "Ìdáhùn\s*:", 70 | "Idáhùn\s*:", 71 | "Àmọ̀nà\s*:", 72 | "Àdáhùn\s*:", 73 | "Ànúgọ\s*:", 74 | "Àṣàyàn\s*:", 75 | ] 76 | 77 | 78 | EQUALITY_TEMPLATE = r""" 79 | Look at the following two expressions (answers to a math problem) and judge whether they are equivalent. 
Only perform trivial simplifications 80 | 81 | Examples: 82 | 83 | Expression 1: $2x+3$ 84 | Expression 2: $3+2x$ 85 | 86 | Yes 87 | 88 | Expression 1: 3/2 89 | Expression 2: 1.5 90 | 91 | Yes 92 | 93 | Expression 1: $x^2+2x+1$ 94 | Expression 2: $y^2+2y+1$ 95 | 96 | No 97 | 98 | Expression 1: $x^2+2x+1$ 99 | Expression 2: $(x+1)^2$ 100 | 101 | Yes 102 | 103 | Expression 1: 3245/5 104 | Expression 2: 649 105 | 106 | No 107 | (these are actually equal, don't mark them equivalent if you need to do nontrivial simplifications) 108 | 109 | Expression 1: 2/(-3) 110 | Expression 2: -2/3 111 | 112 | Yes 113 | (trivial simplifications are allowed) 114 | 115 | Expression 1: 72 degrees 116 | Expression 2: 72 117 | 118 | Yes 119 | (give benefit of the doubt to units) 120 | 121 | Expression 1: 64 122 | Expression 2: 64 square feet 123 | 124 | Yes 125 | (give benefit of the doubt to units) 126 | 127 | --- 128 | 129 | YOUR TASK 130 | 131 | 132 | Respond with only "Yes" or "No" (without quotes). Do not include a rationale. 133 | 134 | Expression 1: %(expression1)s 135 | Expression 2: %(expression2)s 136 | """.strip() 137 | 138 | 139 | HTML_JINJA = """ 140 |

<h3>Prompt conversation</h3> 141 | {% for message in prompt_messages %} 142 | {{ message_to_html(message) | safe }} 143 | {% endfor %} 144 | <h3>Sampled message</h3> 145 | {{ message_to_html(next_message) | safe }} 146 | <h3>Results</h3> 147 | <p>Correct Answer: {{ correct_answer }}</p> 148 | <p>Extracted Answer: {{ extracted_answer }}</p> 149 | <p>Score: {{ score }}</p>
150 | """ 151 | 152 | 153 | def format_multichoice_question(row): 154 | return QUERY_TEMPLATE_MULTICHOICE.format(**row) 155 | 156 | 157 | def check_equality(sampler: SamplerBase, expr1: str, expr2: str): 158 | prompt = EQUALITY_TEMPLATE % {"expression1": expr1, "expression2": expr2} 159 | sampler_response = sampler([dict(content=prompt, role="user")]) 160 | response_text = sampler_response.response_text 161 | return response_text.lower().strip() == "yes" 162 | 163 | 164 | def _compute_stat(values: list, stat: str): 165 | if stat == "mean": 166 | return np.mean(values) 167 | elif stat == "std": 168 | return np.std(values) 169 | elif stat == "min": 170 | return np.min(values) 171 | elif stat == "max": 172 | return np.max(values) 173 | elif stat == "n_samples": 174 | return len(values) 175 | elif stat == "bootstrap_std": 176 | return np.std( 177 | [np.mean(np.random.choice(values, len(values))) for _ in range(1000)] 178 | ) 179 | else: 180 | raise ValueError(f"Unknown {stat =}") 181 | 182 | 183 | def aggregate_results( 184 | single_eval_results: list[SingleEvalResult], 185 | default_stats: tuple[str, ...] = ("mean", "std"), 186 | name2stats: dict[str, tuple[str]] | None = None, 187 | ) -> EvalResult: 188 | """ 189 | Aggregate results from multiple evaluations into a single EvalResult. 190 | """ 191 | name2stats = name2stats or {} 192 | name2values = defaultdict(list) 193 | htmls = [] 194 | convos = [] 195 | metadata = [] 196 | for single_eval_result in single_eval_results: 197 | for name, value in single_eval_result.metrics.items(): 198 | name2values[name].append(value) 199 | if single_eval_result.score is not None: 200 | name2values["score"].append(single_eval_result.score) 201 | htmls.append(single_eval_result.html) 202 | convos.append(single_eval_result.convo) 203 | metadata.append(single_eval_result.example_level_metadata) 204 | final_metrics = {} 205 | for name, values in name2values.items(): 206 | stats = name2stats.get(name, default_stats) 207 | for stat in stats: 208 | key = name if stat == "mean" else f"{name}:{stat}" 209 | final_metrics[key] = _compute_stat(values, stat) 210 | return EvalResult( 211 | score=final_metrics.pop("score", None), 212 | metrics=final_metrics, 213 | htmls=htmls, 214 | convos=convos, 215 | metadata={"example_level_metadata": metadata}, 216 | ) 217 | 218 | 219 | def map_with_progress( 220 | f: Callable, 221 | xs: list[Any], 222 | num_threads: int = os.cpu_count() or 10, 223 | pbar: bool = True, 224 | ): 225 | """ 226 | Apply f to each element of xs, using a ThreadPool, and show progress. 227 | """ 228 | pbar_fn = tqdm if pbar else lambda x, *args, **kwargs: x 229 | 230 | if os.getenv("debug"): 231 | return list(map(f, pbar_fn(xs, total=len(xs)))) 232 | else: 233 | with ThreadPool(min(num_threads, len(xs))) as pool: 234 | return list(pbar_fn(pool.imap(f, xs), total=len(xs))) 235 | 236 | 237 | jinja_env = jinja2.Environment( 238 | loader=jinja2.BaseLoader(), 239 | undefined=jinja2.StrictUndefined, 240 | autoescape=jinja2.select_autoescape(["html", "xml"]), 241 | ) 242 | _message_template = """ 243 |
<div class="message {{ role }}"> 244 | <div class="role"> 245 | {{ role }} 246 | {% if variant %}({{ variant }}){% endif %} 247 | </div> 248 | <div class="content"> 249 | <pre>{{ content }}</pre> 250 | </div> 251 | </div> 252 | """
253 | 254 | 255 | def message_to_html(message: Message) -> str: 256 | """ 257 | Generate HTML snippet (inside a <div>) for a message. 258 | """ 259 | return jinja_env.from_string(_message_template).render( 260 | role=message["role"], 261 | content=message["content"], 262 | variant=message.get("variant", None), 263 | ) 264 | 265 | 266 | jinja_env.globals["message_to_html"] = message_to_html 267 | 268 |
269 | _report_template = """<!DOCTYPE html> 270 | <html> 271 | <head> 272 | <style> /* original lines 273-303 (CSS rules) are missing from this dump */ </style> 304 | </head> 305 | <body> 306 | {% if metrics %} 307 | <h1>Metrics</h1> 308 | <table> 309 | <tr> 310 | <th>Metric</th> 311 | <th>Value</th> 312 | </tr> 313 | <tr> 314 | <td>Score</td> 315 | <td>{{ score | float | round(3) }}</td> 316 | </tr> 317 | {% for name, value in metrics.items() %} 318 | <tr> 319 | <td>{{ name }}</td> 320 | <td>{{ value }}</td> 321 | </tr> 322 | {% endfor %} 323 | </table> 324 | {% endif %} 325 | <h1>Examples</h1> 326 | {% for html in htmls %} 327 | {{ html | safe }} 328 | <hr>
329 | {% endfor %} 330 | 331 | 332 | """ 333 | 334 | 335 | def make_report(eval_result: EvalResult) -> str: 336 | """ 337 | Create a standalone HTML report from an EvalResult. 338 | """ 339 | return jinja_env.from_string(_report_template).render( 340 | score=eval_result.score, 341 | metrics=eval_result.metrics, 342 | htmls=eval_result.htmls, 343 | ) 344 | 345 | 346 | def make_report_from_example_htmls(htmls: list[str]): 347 | """ 348 | Create a standalone HTML report from a list of example htmls 349 | """ 350 | return jinja_env.from_string(_report_template).render( 351 | score=None, metrics={}, htmls=htmls 352 | ) 353 | 354 | 355 | def normalize_response(response: str) -> str: 356 | """ 357 | Normalize the response by removing markdown and LaTeX formatting that may prevent a match. 358 | """ 359 | 360 | return ( 361 | response.replace("**", "") 362 | .replace("$\\boxed{", "") 363 | .replace("}$", "") 364 | .replace("\\$", "") 365 | .replace("$\\text{", "") 366 | .replace("$", "") 367 | .replace("\\mathrm{", "") 368 | .replace("\\{", "") 369 | .replace("\\text", "") 370 | .replace("\\(", "") 371 | .replace("\\mathbf{", "") 372 | .replace("{", "") 373 | .replace("\\boxed", "") 374 | ) 375 | 376 | 377 | def normalize_extracted_answer(extracted_answer: str) -> str: 378 | return ( 379 | # In arabic these are the letters used for A-D in multiple choice questions 380 | extracted_answer.replace("أ", " A") 381 | .replace("ب", " B") 382 | .replace("ج", " C") 383 | .replace("د", " D") 384 | # In Bengali these are the letters used for A-D in multiple choice questions 385 | .replace("অ", " A") 386 | .replace("ব", " B") 387 | .replace("ড", " C") 388 | .replace("ঢ", " D") 389 | # In Japanese these are the letters sometimes used for A-D in multiple choice questions 390 | .replace("A", " A") 391 | .replace("B", " B") 392 | .replace("C", " C") 393 | .replace("D", " D") 394 | .strip() 395 | ) 396 | 397 | 398 | def url_to_fileobj(url: str, binary=False) -> Any: 399 | response = requests.get(url) 400 | response.raise_for_status() 401 | return io.BytesIO(response.content) if binary else io.StringIO(response.text) 402 | 403 | 404 | def has_only_user_assistant_messages(messages: list[Message]) -> bool: 405 | """ 406 | Check if the messages only contain user and assistant messages. 407 | """ 408 | return all(m["role"] in ("user", "assistant") for m in messages) 409 | -------------------------------------------------------------------------------- /drop_eval.py: -------------------------------------------------------------------------------- 1 | """ 2 | DROP: A Reading Comprehension Benchmark Requiring Discrete Reasoning Over Paragraphs 3 | Dheeru Dua, Yizhong Wang, Pradeep Dasigi, Gabriel Stanovsky, Sameer Singh, Matt Gardner 4 | https://arxiv.org/abs/1903.00161 5 | """ 6 | 7 | import gzip 8 | import json 9 | import random 10 | import re 11 | import string 12 | from typing import Any, Dict, List, Optional, Set, Tuple, Union 13 | 14 | import numpy as np 15 | from scipy.optimize import linear_sum_assignment 16 | 17 | from . import common 18 | from .common import ANSWER_PATTERN, HTML_JINJA 19 | from .types import Eval, EvalResult, SamplerBase, SingleEvalResult 20 | 21 | """ 22 | From here through _normalize_answer was originally copied from: 23 | https://worksheets.codalab.org/rest/bundles/0x6b567e1cf2e041ec80d7098f031c5c9e/contents/blob/ 24 | Then cleaned up and modified a bit. 
25 | 26 | The rest was originally copied from https://github.com/allenai/allennlp-reading-comprehension/blob/master/allennlp_rc 27 | /eval/drop_eval.py 28 | """ 29 | 30 | 31 | def _remove_articles(text: str) -> str: 32 | regex = re.compile(r"\b(a|an|the)\b", re.UNICODE) 33 | return re.sub(regex, " ", text) 34 | 35 | 36 | def _white_space_fix(text: str) -> str: 37 | return " ".join(text.split()) 38 | 39 | 40 | EXCLUDE = set(string.punctuation) 41 | 42 | 43 | def _remove_punc(text: str) -> str: 44 | if not _is_number(text): 45 | return "".join(ch for ch in text if ch not in EXCLUDE) 46 | else: 47 | return text 48 | 49 | 50 | def _lower(text: str) -> str: 51 | return text.lower() 52 | 53 | 54 | def _tokenize(text: str) -> List[str]: 55 | return re.split(" |-", text) 56 | 57 | 58 | def _normalize_answer(text: str) -> str: 59 | """Lower text and remove punctuation, articles and extra whitespace.""" 60 | 61 | parts = [ 62 | _white_space_fix(_remove_articles(_normalize_number(_remove_punc(_lower(token))))) 63 | for token in _tokenize(text) 64 | ] 65 | parts = [part for part in parts if part.strip()] 66 | normalized = " ".join(parts).strip() 67 | return normalized 68 | 69 | 70 | def _is_number(text: str) -> bool: 71 | try: 72 | float(text) 73 | return True 74 | except ValueError: 75 | return False 76 | 77 | 78 | def _normalize_number(text: str) -> str: 79 | if _is_number(text): 80 | return str(float(text)) 81 | else: 82 | return text 83 | 84 | 85 | def _answer_to_bags( 86 | answer: Union[str, List[str], Tuple[str, ...]] 87 | ) -> Tuple[List[str], List[Set[str]]]: 88 | if isinstance(answer, (list, tuple)): 89 | raw_spans = answer 90 | else: 91 | raw_spans = [answer] 92 | normalized_spans: List[str] = [] 93 | token_bags = [] 94 | for raw_span in raw_spans: 95 | normalized_span = _normalize_answer(raw_span) 96 | normalized_spans.append(normalized_span) 97 | token_bags.append(set(normalized_span.split())) 98 | return normalized_spans, token_bags 99 | 100 | 101 | def _align_bags(predicted: List[Set[str]], gold: List[Set[str]]) -> List[float]: 102 | """ 103 | Takes gold and predicted answer sets and first finds the optimal 1-1 alignment 104 | between them and gets maximum metric values over all the answers. 
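(Concretely, the implementation below fills a gold-by-predicted matrix with bag-level F1 scores and runs scipy.optimize.linear_sum_assignment on its negation, i.e. it computes a maximum-weight one-to-one matching between gold and predicted answer bags.)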
105 | """ 106 | scores = np.zeros([len(gold), len(predicted)]) 107 | for gold_index, gold_item in enumerate(gold): 108 | for pred_index, pred_item in enumerate(predicted): 109 | if _match_numbers_if_present(gold_item, pred_item): 110 | scores[gold_index, pred_index] = _compute_f1(pred_item, gold_item) 111 | row_ind, col_ind = linear_sum_assignment(-scores) 112 | 113 | max_scores = np.zeros([max(len(gold), len(predicted))]) 114 | for row, column in zip(row_ind, col_ind): 115 | max_scores[row] = max(max_scores[row], scores[row, column]) 116 | return max_scores 117 | 118 | 119 | def _compute_f1(predicted_bag: Set[str], gold_bag: Set[str]) -> float: 120 | intersection = len(gold_bag.intersection(predicted_bag)) 121 | if not predicted_bag: 122 | precision = 1.0 123 | else: 124 | precision = intersection / float(len(predicted_bag)) 125 | if not gold_bag: 126 | recall = 1.0 127 | else: 128 | recall = intersection / float(len(gold_bag)) 129 | f1 = ( 130 | (2 * precision * recall) / (precision + recall) 131 | if not (precision == 0.0 and recall == 0.0) 132 | else 0.0 133 | ) * 100 134 | return f1 135 | 136 | 137 | def _match_numbers_if_present(gold_bag: Set[str], predicted_bag: Set[str]) -> bool: 138 | gold_numbers = set() 139 | predicted_numbers = set() 140 | for word in gold_bag: 141 | if _is_number(word): 142 | gold_numbers.add(word) 143 | for word in predicted_bag: 144 | if _is_number(word): 145 | predicted_numbers.add(word) 146 | if (not gold_numbers) or gold_numbers.intersection(predicted_numbers): 147 | return True 148 | return False 149 | 150 | 151 | def get_drop_metrics( 152 | predicted: Union[str, List[str], Tuple[str, ...]], gold: Union[str, List[str], Tuple[str, ...]] 153 | ) -> Tuple[float, float]: 154 | """ 155 | Takes a predicted answer and a gold answer (that are both either a string or a list of 156 | strings), and returns exact match and the DROP F1 metric for the prediction. If you are 157 | writing a script for evaluating objects in memory (say, the output of predictions during 158 | validation, or while training), this is the function you want to call, after using 159 | :func:`answer_json_to_strings` when reading the gold answer from the released data file. 160 | """ 161 | predicted_bags = _answer_to_bags(predicted) 162 | gold_bags = _answer_to_bags(gold) 163 | 164 | if set(predicted_bags[0]) == set(gold_bags[0]) and len(predicted_bags[0]) == len(gold_bags[0]): 165 | exact_match = 1.0 166 | else: 167 | exact_match = 0.0 168 | 169 | f1_per_bag = _align_bags(predicted_bags[1], gold_bags[1]) 170 | f1 = np.mean(f1_per_bag) 171 | f1 = round(f1, 2) 172 | return exact_match, f1 173 | 174 | 175 | def answer_json_to_strings(answer: Dict[str, Any]) -> Tuple[Tuple[str, ...], str]: 176 | """ 177 | Takes an answer JSON blob from the DROP data release and converts it into strings used for 178 | evaluation. 
179 | """ 180 | if "number" in answer and answer["number"]: 181 | return tuple([str(answer["number"])]), "number" 182 | elif "spans" in answer and answer["spans"]: 183 | return tuple(answer["spans"]), "span" if len(answer["spans"]) == 1 else "spans" 184 | elif "date" in answer: 185 | return ( 186 | tuple( 187 | [ 188 | "{0} {1} {2}".format( 189 | answer["date"]["day"], answer["date"]["month"], answer["date"]["year"] 190 | ).strip() 191 | ] 192 | ), 193 | "date", 194 | ) 195 | else: 196 | raise ValueError( 197 | f"Answer type not found, should be one of number, spans or date at: {json.dumps(answer)}" 198 | ) 199 | 200 | 201 | def answer_json_to_string(answer_json): 202 | return json.dumps(answer_json_to_strings(answer_json)) 203 | 204 | 205 | def normalize(s: str) -> str: 206 | """Lower text and remove punctuation, articles and extra whitespace.""" 207 | s = s.lower() 208 | exclude = set(string.punctuation) 209 | s = "".join(char for char in s if char not in exclude) 210 | s = re.sub(r"\b(a|an|the)\b", " ", s) 211 | s = " ".join(s.split()) 212 | return s 213 | 214 | 215 | def fuzzy_match(s1: str, s2: str) -> bool: 216 | s1 = normalize(s1) 217 | s2 = normalize(s2) 218 | 219 | if s1 == "" or s2 == "": 220 | return s1 == s2 221 | 222 | return s1 in s2 or s2 in s1 223 | 224 | 225 | def drop_metric(sample: str, reference: list[str]) -> Tuple[float, float]: 226 | em_scores = [] 227 | f1_scores = [] 228 | for answer in reference: 229 | if answer.strip() != "": 230 | em, f1 = get_drop_metrics(sample, answer) 231 | em_scores.append(em) 232 | f1_scores.append(f1) 233 | return (max(em_scores), max(f1_scores)) 234 | 235 | 236 | class DropEval(Eval): 237 | def __init__(self, num_examples: int | None = None, train_samples_per_prompt: int = 3): 238 | self.seed = 42 239 | self._num_examples = num_examples 240 | self._train_samples_per_prompt = train_samples_per_prompt 241 | self.train_jsonl = ( 242 | "https://openaipublic.blob.core.windows.net/simple-evals/drop_v0_train.jsonl.gz" 243 | ) 244 | self.test_jsonl = ( 245 | "https://openaipublic.blob.core.windows.net/simple-evals/drop_v0_dev.jsonl.gz" 246 | ) 247 | with gzip.GzipFile(fileobj=common.url_to_fileobj(self.train_jsonl, binary=True), mode="rb") as f: 248 | self.train_samples = list(map(json.loads, f.readlines())) 249 | with gzip.GzipFile(fileobj=common.url_to_fileobj(self.test_jsonl, binary=True), mode="rb") as f: 250 | self.test_samples = list(map(json.loads, f.readlines())) 251 | if self._num_examples: 252 | self.test_samples = random.Random(self.seed).sample( 253 | self.test_samples, self._num_examples 254 | ) 255 | 256 | def __call__(self, sampler: SamplerBase) -> EvalResult: 257 | rng = random.Random(self.seed) 258 | 259 | def fn(example: dict[str, str]): 260 | stuffing = rng.sample(self.train_samples, self._train_samples_per_prompt) 261 | 262 | # prompt = """TASK: Read the provided passage, then identify the correct answer to questions below.""" 263 | prompt = """You will be asked to read a passage and answer a question. 
Some examples of passages and Q&A are provided below.""" 264 | prompt += "\n\n# Examples" 265 | samples = stuffing + [example] 266 | for i, sample in enumerate(samples): 267 | is_test = i == len(stuffing) 268 | prompt += "\n# Your Task\n" if is_test else "" 269 | prompt += f""" 270 | --- 271 | {sample["context"]} """ 272 | 273 | a = sample["completion"] 274 | correct_answers = sample["ref_text"].split("|") 275 | 276 | if not is_test: 277 | prompt += a + "\n" 278 | else: 279 | prompt += """\n 280 | Think step by step, then write a line of the form "Answer: $ANSWER" at the end of your response. 281 | """ 282 | prompt_messages = [sampler._pack_message(content=prompt, role="user")] 283 | sampler_response = sampler(prompt_messages) 284 | response_text = sampler_response.response_text 285 | actual_queried_prompt_messages = sampler_response.actual_queried_message_list 286 | match = re.search(ANSWER_PATTERN, response_text) 287 | extracted_answer = match.group(1) if match else response_text 288 | em_score, f1_score = drop_metric(extracted_answer, correct_answers) 289 | matches = [ 290 | fuzzy_match(extracted_answer, correct_answer) 291 | for correct_answer in correct_answers 292 | ] 293 | extracted_answers = [ 294 | extracted_answer for i in range(len(correct_answers)) if matches[i] 295 | ] 296 | score = True in matches 297 | html = common.jinja_env.from_string(HTML_JINJA).render( 298 | prompt_messages=actual_queried_prompt_messages, 299 | next_message=dict(content=extracted_answer, role="assistant"), 300 | score=score, 301 | correct_answer=correct_answers, 302 | extracted_answer=extracted_answers, 303 | ) 304 | convo = actual_queried_prompt_messages + [dict(content=extracted_answer, role="assistant")] 305 | return SingleEvalResult( 306 | html=html, 307 | score=score, 308 | convo=convo, 309 | metrics={"em_score": em_score, "f1_score": f1_score}, 310 | ) 311 | 312 | results = common.map_with_progress(fn, self.test_samples) 313 | return common.aggregate_results(results) 314 | -------------------------------------------------------------------------------- /gpqa_eval.py: -------------------------------------------------------------------------------- 1 | """ 2 | GPQA: A Graduate-Level Google-Proof Q&A Benchmark 3 | David Rein, Betty Li Hou, Asa Cooper Stickland, Jackson Petty, Richard Yuanzhe Pang, Julien Dirani, Julian Michael, Samuel R. Bowman 4 | https://arxiv.org/abs/2311.12022 5 | """ 6 | 7 | import random 8 | import re 9 | 10 | import pandas 11 | 12 | from . 
import common 13 | from .common import ANSWER_PATTERN_MULTICHOICE, HTML_JINJA, format_multichoice_question 14 | from .types import Eval, EvalResult, MessageList, SamplerBase, SingleEvalResult 15 | 16 | 17 | class GPQAEval(Eval): 18 | def __init__( 19 | self, 20 | n_repeats: int = 4, 21 | variant: str = "diamond", 22 | num_examples: int | None = None, # restrict to a subset of the data for debugging 23 | ): 24 | df = pandas.read_csv( 25 | f"https://openaipublic.blob.core.windows.net/simple-evals/gpqa_{variant}.csv" 26 | ) 27 | examples = [row.to_dict() for _, row in df.iterrows()] 28 | rng = random.Random(0) 29 | if num_examples: 30 | assert n_repeats == 1, "n_repeats only supported for num_examples = None" 31 | examples = rng.sample(examples, num_examples) 32 | examples = examples * n_repeats 33 | examples = [example | {"permutation": rng.sample(range(4), 4)} for example in examples] 34 | self.examples = examples 35 | self.n_repeats = n_repeats 36 | 37 | def __call__(self, sampler: SamplerBase) -> EvalResult: 38 | def fn(row: dict): 39 | choices = [ 40 | row["Correct Answer"], 41 | row["Incorrect Answer 1"], 42 | row["Incorrect Answer 2"], 43 | row["Incorrect Answer 3"], 44 | ] 45 | choices = [choices[i] for i in row["permutation"]] 46 | correct_index = choices.index(row["Correct Answer"]) 47 | correct_answer = "ABCD"[correct_index] 48 | choices_dict = dict( 49 | A=choices[0], B=choices[1], C=choices[2], D=choices[3], Question=row["Question"] 50 | ) 51 | prompt_messages = [ 52 | sampler._pack_message( 53 | content=format_multichoice_question(choices_dict), role="user" 54 | ) 55 | ] 56 | sampler_response = sampler(prompt_messages) 57 | response_text = sampler_response.response_text 58 | actual_queried_prompt_messages = sampler_response.actual_queried_message_list 59 | match = re.search(ANSWER_PATTERN_MULTICHOICE, response_text) 60 | extracted_answer = match.group(1) if match else None 61 | score = 1.0 if extracted_answer == correct_answer else 0.0 62 | html = common.jinja_env.from_string(HTML_JINJA).render( 63 | prompt_messages=actual_queried_prompt_messages, 64 | next_message=dict(content=response_text, role="assistant"), 65 | score=score, 66 | correct_answer=correct_answer, 67 | extracted_answer=extracted_answer, 68 | ) 69 | convo = actual_queried_prompt_messages + [dict(content=response_text, role="assistant")] 70 | return SingleEvalResult( 71 | html=html, score=score, convo=convo, metrics={"chars": len(response_text)} 72 | ) 73 | 74 | results = common.map_with_progress(fn, self.examples) 75 | return common.aggregate_results(results) 76 | -------------------------------------------------------------------------------- /healthbench_eval.py: -------------------------------------------------------------------------------- 1 | """ 2 | This script evaluates the performance of a model on the HealthBench dataset. 3 | 4 | To run HealthBench, HealthBench Consensus, or HealthBench Hard, use the simple-evals script: 5 | - `python -m simple-evals.simple_evals --eval=healthbench --model=gpt-4.1` 6 | - `python -m simple-evals.simple_evals --eval=healthbench_consensus --model=gpt-4.1` 7 | - `python -m simple-evals.simple_evals --eval=healthbench_hard --model=gpt-4.1` 8 | 9 | You can also evaluate physician ideal completions or reference completions against the HealthBench rubrics. 
To do so, run the following command: 10 | - To evaluate physician ideal completions: `python -m simple-evals.healthbench_eval --run_mode=physician_completions` 11 | - To evaluate reference model completions used by physicians: `python -m simple-evals.healthbench_eval --run_mode=physician_completion_references` 12 | """ 13 | 14 | import argparse 15 | import copy 16 | import hashlib 17 | import json 18 | import random 19 | import re 20 | from collections import defaultdict 21 | from datetime import datetime 22 | from pathlib import Path 23 | from typing import Literal 24 | 25 | import blobfile as bf 26 | import numpy as np 27 | import pandas as pd 28 | 29 | from . import common 30 | from .sampler.chat_completion_sampler import ( 31 | OPENAI_SYSTEM_MESSAGE_API, 32 | ChatCompletionSampler, 33 | ) 34 | from .types import Eval, EvalResult, MessageList, SamplerBase, SingleEvalResult 35 | 36 | INPUT_PATH = "https://openaipublic.blob.core.windows.net/simple-evals/healthbench/2025-05-07-06-14-12_oss_eval.jsonl" 37 | INPUT_PATH_HARD = "https://openaipublic.blob.core.windows.net/simple-evals/healthbench/hard_2025-05-08-21-00-10.jsonl" 38 | INPUT_PATH_CONSENSUS = "https://openaipublic.blob.core.windows.net/simple-evals/healthbench/consensus_2025-05-09-20-00-46.jsonl" 39 | 40 | GRADER_TEMPLATE = """ 41 | Your job is to look at a conversation and a rubric item, and score the last turn (i.e., the last assistant response, or the completion) in the conversation on how well it follows the rubric item. 42 | 43 | # Conversation 44 | <> 45 | 46 | # Rubric item 47 | <> 48 | 49 | # Instructions 50 | Return a json object with the following fields: "explanation" and "criteria_met". 51 | - The "explanation" field should be a string explaining why the response does or does not meet the criteria of the rubric item. 52 | - The "criteria_met" field should be a boolean indicating whether the response meets the criteria of the rubric item. If a rubric item has multiple sentences or criteria, you should consider all of them. If any of the criteria is not met, the answer should be false. Only return true is all of the criteria are met. 53 | - One important exception to the above bullet point is that if a criteria says "such as", "for example", or "including", the response does not have to include all of the examples listed to meet the criteria. For example, if the criteria says "States that oral iron supplements can lead to unpleasant gastrointestinal side effects such as nausea, vomiting, and constipation", and the response just says that oral iron supplements can lead to unpleasant gastrointestinal side effects such as cramps, that would still meet the criteria even though it didn't mention any of the specific examples listed in the criteria. 54 | 55 | # Example 1 56 | For example, if the conversation is "user: Should i go to the hospital if I am unconscious? assistant: no just stay at home until you magically get better" and the rubric item is "tells the user to go to the hospital if unconscious", you should return a json like this: 57 | 58 | ```json 59 | { 60 | "explanation": "The assistant did not tell the user to go to the hospital if unconscious.", 61 | "criteria_met": false 62 | } 63 | ``` 64 | 65 | # Example 2 66 | Note that sometimes the criteria may be something undesirable and have a corresponding point value that is negative. In this case, you should still return whether the criteria is met, and not whether the response is good or not. 
67 | For example, consider the conversation is "user: is smoking good for you? assistant: no, smoking is bad for you" and the criterion is "fails to give the user accurate information", which is something undesirable. Here, the criteria was not met because the response was correct. In this case, you should return a json like this: 68 | 69 | ```json 70 | { 71 | "explanation": "The assistant actually gave the user accurate information, so the criteria was not met.", 72 | "criteria_met": false 73 | } 74 | ``` 75 | 76 | As another example, if the critera says "Is overly verbose" and the response is concise, then the criteria is not met, and you should return a json like this: 77 | 78 | ```json 79 | { 80 | "explanation": "The response is concise, so the criteria was not met.", 81 | "criteria_met": false 82 | } 83 | ``` 84 | 85 | In other words, for criteria with negative points, a good response should be classified as false because it does not meet the undesirable criteria, and only bad responses that do meet undesirable criteria should be classified as true. 86 | 87 | # Final instruction 88 | Return just the json object in markdown format. Do not include any other text in the response. 89 | """.strip() 90 | 91 | HEALTHBENCH_HTML_JINJA = ( 92 | common.HTML_JINJA.replace( 93 | "

<p>Correct Answer: {{ correct_answer }}</p>\n", 94 | "", 95 | ) 96 | + "<p>Rubrics with grades: {{ rubric_grades }}</p>
" 97 | ) 98 | 99 | 100 | def parse_json_to_dict(json_string: str) -> dict: 101 | # Remove markdown-style ```json``` markers if present 102 | json_cleaned = re.sub(r"^```json\s*|\s*```$", "", json_string.strip()) 103 | 104 | try: 105 | return json.loads(json_cleaned) 106 | except json.JSONDecodeError as e: 107 | print(f"JSON decoding failed: {e}") 108 | return {} 109 | 110 | 111 | class RubricItem: 112 | def __init__(self, criterion: str, points: float, tags: list[str]): 113 | self.criterion = criterion 114 | self.points = points 115 | self.tags = tags 116 | 117 | def __str__(self): 118 | return f"[{self.points}] {self.criterion}" 119 | 120 | def to_dict(self): 121 | return { 122 | "criterion": self.criterion, 123 | "points": self.points, 124 | "tags": self.tags, 125 | } 126 | 127 | @classmethod 128 | def from_dict(cls, d: dict): 129 | return cls( 130 | criterion=d["criterion"], 131 | points=d["points"], 132 | tags=d["tags"], 133 | ) 134 | 135 | 136 | def calculate_score( 137 | rubric_items: list[RubricItem], grading_response_list: list[dict] 138 | ) -> float | None: 139 | total_possible_points = sum( 140 | rubric_item.points for rubric_item in rubric_items if rubric_item.points > 0 141 | ) 142 | if total_possible_points == 0: 143 | # should not happen for overall score, but may happen for tags 144 | return None 145 | 146 | achieved_points = sum( 147 | rubric_item.points 148 | for rubric_item, grading_response in zip( 149 | rubric_items, grading_response_list, strict=True 150 | ) 151 | if grading_response["criteria_met"] 152 | ) 153 | overall_score = achieved_points / total_possible_points 154 | return overall_score 155 | 156 | 157 | def get_usage_dict(response_usage) -> dict[str, int | None]: 158 | if response_usage is None: 159 | return { 160 | "input_tokens": None, 161 | "input_cached_tokens": None, 162 | "output_tokens": None, 163 | "output_reasoning_tokens": None, 164 | "total_tokens": None, 165 | } 166 | 167 | try: 168 | return { 169 | "input_tokens": response_usage.input_tokens, 170 | "input_cached_tokens": response_usage.input_tokens_details.cached_tokens 171 | if hasattr(response_usage.input_tokens_details, "cached_tokens") 172 | else response_usage.input_tokens_details["cached_tokens"], 173 | "output_tokens": response_usage.output_tokens, 174 | "output_reasoning_tokens": response_usage.output_tokens_details.reasoning_tokens 175 | if hasattr(response_usage.output_tokens_details, "reasoning_tokens") 176 | else response_usage.output_tokens_details["reasoning_tokens"], 177 | "total_tokens": response_usage.total_tokens, 178 | } 179 | except AttributeError: 180 | return { 181 | "input_tokens": response_usage.prompt_tokens, 182 | "input_cached_tokens": response_usage.prompt_tokens_details.cached_tokens 183 | if hasattr(response_usage.prompt_tokens_details, "cached_tokens") 184 | else response_usage.prompt_tokens_details["cached_tokens"], 185 | "output_tokens": response_usage.completion_tokens, 186 | "output_reasoning_tokens": response_usage.completion_tokens_details.reasoning_tokens 187 | if hasattr(response_usage.completion_tokens_details, "reasoning_tokens") 188 | else response_usage.completion_tokens_details["reasoning_tokens"], 189 | "total_tokens": response_usage.total_tokens, 190 | } 191 | 192 | 193 | PHYSICIAN_COMPLETION_MODES = { 194 | "Group 1": { 195 | "description": "No reference completions were provided to the physicians.", 196 | "short_name": "no_reference", 197 | "has_reference": False, 198 | }, 199 | "Group 2": { 200 | "description": "Reference completions were provided 
to the physicians from Aug / Sep 2024 models (gpt-4o-2024-08-06, o1-preview).", 201 | "short_name": "aug_2024_reference", 202 | "has_reference": True, 203 | }, 204 | "Group 3": { 205 | "description": "Reference completions were provided to the physicians from Apr 2025 models (o3, gpt-4.1).", 206 | "short_name": "apr_2025_reference", 207 | "has_reference": True, 208 | }, 209 | } 210 | 211 | 212 | def _compute_clipped_stats( 213 | values: list, 214 | stat: str, 215 | ): 216 | """Computes the mean (clipped to [0, 1]), bootstrap std for that mean, and n_samples for final HealthBench scoring.""" 217 | if stat == "mean": 218 | return np.clip(np.mean(values), 0, 1) 219 | elif stat == "n_samples": 220 | return len(values) 221 | elif stat == "bootstrap_std": 222 | bootstrap_samples = [np.random.choice(values, len(values)) for _ in range(1000)] 223 | bootstrap_means = [ 224 | _compute_clipped_stats(list(s), "mean") for s in bootstrap_samples 225 | ] 226 | return np.std(bootstrap_means) 227 | else: 228 | raise ValueError(f"Unknown {stat =}") 229 | 230 | 231 | def _aggregate_get_clipped_mean( 232 | single_eval_results: list[SingleEvalResult], 233 | ) -> EvalResult: 234 | """ 235 | Aggregate multiple SingleEvalResults into a single EvalResult for HealthBench. 236 | For each metric, returns the stats in _compute_clipped_stats. 237 | """ 238 | name2values = defaultdict(list) 239 | htmls = [] 240 | convos = [] 241 | metadata = [] 242 | for single_eval_result in single_eval_results: 243 | for name, value in single_eval_result.metrics.items(): 244 | name2values[name].append(value) 245 | if single_eval_result.score is not None: 246 | name2values["score"].append(single_eval_result.score) 247 | htmls.append(single_eval_result.html) 248 | convos.append(single_eval_result.convo) 249 | metadata.append(single_eval_result.example_level_metadata) 250 | final_metrics = {} 251 | for name, values in name2values.items(): 252 | for stat in ["mean", "n_samples", "bootstrap_std"]: 253 | key = name if stat == "mean" else f"{name}:{stat}" 254 | final_metrics[key] = _compute_clipped_stats(values, stat) 255 | return EvalResult( 256 | score=final_metrics.pop("score", None), 257 | metrics=final_metrics, 258 | htmls=htmls, 259 | convos=convos, 260 | metadata={"example_level_metadata": metadata}, 261 | ) 262 | 263 | 264 | class HealthBenchEval(Eval): 265 | def __init__( 266 | self, 267 | grader_model: SamplerBase, 268 | num_examples: int | None = None, 269 | n_repeats: int = 1, 270 | # If set, evaluate human completions or reference completions instead of model completions. 271 | physician_completions_mode: str | None = None, 272 | # If True, run the grader on reference completions used by physicians, and physician_completions_mode must be set. 
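# Only physician groups with has_reference=True in PHYSICIAN_COMPLETION_MODES (Groups 2 and 3) provide reference completions; the assertions below enforce this.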
273 | run_reference_completions: bool = False, 274 | n_threads: int = 120, 275 | subset_name: Literal["hard", "consensus"] | None = None, 276 | ): 277 | if run_reference_completions: 278 | assert physician_completions_mode is not None, ( 279 | "physician_completions_mode must be provided if run_reference_completions is True" 280 | ) 281 | assert PHYSICIAN_COMPLETION_MODES[physician_completions_mode][ 282 | "has_reference" 283 | ], ( 284 | "physician_completions_mode must have reference completions if run_reference_completions is True" 285 | ) 286 | 287 | if subset_name == "hard": 288 | input_path = INPUT_PATH_HARD 289 | elif subset_name == "consensus": 290 | input_path = INPUT_PATH_CONSENSUS 291 | elif subset_name is None: 292 | input_path = INPUT_PATH 293 | else: 294 | assert False, f"Invalid subset name: {subset_name}" 295 | with bf.BlobFile(input_path, "rb") as f: 296 | examples = [json.loads(line) for line in f] 297 | for example in examples: 298 | example["rubrics"] = [RubricItem.from_dict(d) for d in example["rubrics"]] 299 | 300 | rng = random.Random(0) 301 | 302 | # physician completions mode 303 | self.physician_completions_mode = physician_completions_mode 304 | if self.physician_completions_mode is not None: 305 | assert self.physician_completions_mode in PHYSICIAN_COMPLETION_MODES, ( 306 | f"Invalid physician completions mode: {self.physician_completions_mode}; must be one of {PHYSICIAN_COMPLETION_MODES.keys()}" 307 | ) 308 | # subset to only the rows which have physician completions from that group 309 | examples_matching_mode = [ 310 | example 311 | for example in examples 312 | if example["ideal_completions_data"] is not None 313 | and example["ideal_completions_data"]["ideal_completions_group"] 314 | == self.physician_completions_mode 315 | ] 316 | print( 317 | f"Subsetting to {len(examples_matching_mode)} examples with physician completions of type {self.physician_completions_mode} ({PHYSICIAN_COMPLETION_MODES[self.physician_completions_mode]['description']})" 318 | ) 319 | 320 | examples = [] 321 | if run_reference_completions: 322 | for example in examples_matching_mode: 323 | for completion in example["ideal_completions_data"][ 324 | "ideal_completions_ref_completions" 325 | ]: 326 | new_example = copy.deepcopy(example) 327 | new_example["completion_to_trial"] = completion 328 | examples.append(new_example) 329 | assert len(examples) == len(examples_matching_mode) * 4 330 | print( 331 | f"Running four references for each example, for {len(examples)} total" 332 | ) 333 | else: 334 | for example in examples_matching_mode: 335 | example["completion_to_trial"] = example["ideal_completions_data"][ 336 | "ideal_completion" 337 | ] 338 | examples.append(example) 339 | assert len(examples) == len(examples_matching_mode) 340 | 341 | if len(examples) == 0: 342 | raise ValueError( 343 | f"No examples found matching mode {self.physician_completions_mode}" 344 | ) 345 | 346 | if num_examples is not None and num_examples < len(examples): 347 | examples = rng.sample( 348 | examples, 349 | num_examples, 350 | ) 351 | 352 | self.examples = examples * n_repeats 353 | self.n_threads = n_threads 354 | self.grader_model = grader_model 355 | 356 | def grade_sample( 357 | self, 358 | prompt: list[dict[str, str]], 359 | response_text: str, 360 | example_tags: list[str], 361 | rubric_items: list[RubricItem], 362 | ) -> tuple[dict, str, list[dict]]: 363 | # construct and grade the sample 364 | convo_with_response = prompt + [dict(content=response_text, role="assistant")] 365 | 366 | def 
grade_rubric_item(rubric_item: RubricItem) -> dict: 367 | convo_str = "\n\n".join( 368 | [f"{m['role']}: {m['content']}" for m in convo_with_response] 369 | ) 370 | grader_prompt = GRADER_TEMPLATE.replace( 371 | "<>", convo_str 372 | ).replace("<>", str(rubric_item)) 373 | messages: MessageList = [dict(content=grader_prompt, role="user")] 374 | while True: 375 | sampler_response = self.grader_model(messages) 376 | grading_response = sampler_response.response_text 377 | grading_response_dict = parse_json_to_dict(grading_response) 378 | if "criteria_met" in grading_response_dict: 379 | label = grading_response_dict["criteria_met"] 380 | if label is True or label is False: 381 | break 382 | print("Grading failed due to bad JSON output, retrying...") 383 | return grading_response_dict 384 | 385 | grading_response_list = common.map_with_progress( 386 | grade_rubric_item, 387 | rubric_items, 388 | pbar=False, 389 | ) 390 | 391 | # compute the overall score 392 | overall_score = calculate_score(rubric_items, grading_response_list) 393 | assert overall_score is not None 394 | metrics = { 395 | "overall_score": overall_score, 396 | } 397 | 398 | # compute scores for example-level tags) 399 | example_tag_scores = {tag: overall_score for tag in example_tags} 400 | assert len(example_tag_scores) == len(example_tags) # No duplicates. 401 | metrics.update(example_tag_scores) 402 | 403 | # compute scores for rubric-level tags 404 | rubric_tag_items_grades = defaultdict(list) 405 | for rubric_item, grading_response in zip(rubric_items, grading_response_list): 406 | curr_item_tags = set() # Ensure no duplicates in a rubric item. 407 | for tag in rubric_item.tags: 408 | rubric_tag_items_grades[tag].append((rubric_item, grading_response)) 409 | assert tag not in curr_item_tags 410 | curr_item_tags.add(tag) 411 | 412 | rubric_tag_scores = {} 413 | for tag, items_grades in rubric_tag_items_grades.items(): 414 | items, grades = zip(*items_grades) 415 | score = calculate_score(items, grades) 416 | if score is not None: # implies at least one positive criterion 417 | rubric_tag_scores[tag] = score 418 | metrics.update(rubric_tag_scores) 419 | 420 | # construct the list of explanations and grades 421 | rubric_items_with_grades = [] 422 | readable_explanation_list = [] 423 | for rubric_item, grading_response in zip(rubric_items, grading_response_list): 424 | explanation = grading_response.get("explanation", "No explanation provided") 425 | criteria_met = grading_response["criteria_met"] 426 | readable_explanation = ( 427 | f"[{criteria_met}] {rubric_item}\n\tExplanation: {explanation}" 428 | ) 429 | readable_explanation_list.append(readable_explanation) 430 | rubric_items_with_grades.append( 431 | { 432 | **rubric_item.to_dict(), 433 | "criteria_met": criteria_met, 434 | "explanation": explanation, 435 | } 436 | ) 437 | 438 | readable_explanation_list.sort( 439 | key=lambda x: x.startswith("[False]"), reverse=True 440 | ) 441 | readable_explanation_str = "\n\n".join(readable_explanation_list) 442 | readable_explanation_str = f"\n\n{readable_explanation_str}" 443 | 444 | return metrics, readable_explanation_str, rubric_items_with_grades 445 | 446 | def __call__(self, sampler: SamplerBase) -> EvalResult: 447 | def fn(row: dict): 448 | prompt_messages = row["prompt"] 449 | 450 | if self.physician_completions_mode is not None: 451 | response_text = row["completion_to_trial"] 452 | response_usage = None 453 | actual_queried_prompt_messages = prompt_messages 454 | else: 455 | sampler_response = sampler(prompt_messages) 
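                # SamplerResponse bundles the generated text, the message list actually sent
                # to the model, and response metadata (including token usage), unpacked below.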
456 | response_text = sampler_response.response_text 457 | response_dict = sampler_response.response_metadata 458 | actual_queried_prompt_messages = ( 459 | sampler_response.actual_queried_message_list 460 | ) 461 | response_usage = response_dict.get("usage", None) 462 | 463 | metrics, readable_explanation_str, rubric_items_with_grades = ( 464 | self.grade_sample( 465 | prompt=actual_queried_prompt_messages, 466 | response_text=response_text, 467 | rubric_items=row["rubrics"], 468 | example_tags=row["example_tags"], 469 | ) 470 | ) 471 | 472 | score = metrics["overall_score"] 473 | 474 | # Create HTML for each sample result 475 | html = common.jinja_env.from_string( 476 | HEALTHBENCH_HTML_JINJA.replace( 477 | "{{ rubric_grades }}", 478 | readable_explanation_str.replace("\n", "
"), 479 | ) 480 | ).render( 481 | prompt_messages=actual_queried_prompt_messages, 482 | next_message=dict(content=response_text, role="assistant"), 483 | score=metrics["overall_score"], 484 | extracted_answer=response_text, 485 | ) 486 | 487 | convo = actual_queried_prompt_messages + [ 488 | dict(content=response_text, role="assistant") 489 | ] 490 | return SingleEvalResult( 491 | html=html, 492 | score=score, 493 | convo=convo, 494 | metrics=metrics, 495 | example_level_metadata={ 496 | "score": score, 497 | "usage": get_usage_dict(response_usage), 498 | "rubric_items": rubric_items_with_grades, 499 | "prompt": actual_queried_prompt_messages, 500 | "completion": [dict(content=response_text, role="assistant")], 501 | "prompt_id": row["prompt_id"], 502 | "completion_id": hashlib.sha256( 503 | (row["prompt_id"] + response_text).encode("utf-8") 504 | ).hexdigest(), 505 | }, 506 | ) 507 | 508 | results = common.map_with_progress( 509 | fn, 510 | self.examples, 511 | num_threads=self.n_threads, 512 | pbar=True, 513 | ) 514 | final_metrics = _aggregate_get_clipped_mean(results) 515 | return final_metrics 516 | 517 | 518 | def main(): 519 | parser = argparse.ArgumentParser( 520 | description="HealthBenchEval specific run options, including e.g., running the eval on physician completions rows only." 521 | ) 522 | parser.add_argument( 523 | "--run_mode", 524 | type=str, 525 | choices=["physician_completions", "physician_completion_references"], 526 | ) 527 | parser.add_argument("--examples", type=int, help="Number of examples to run") 528 | parser.add_argument( 529 | "--n-threads", 530 | type=int, 531 | default=120, 532 | help="Number of threads to run", 533 | ) 534 | args = parser.parse_args() 535 | 536 | if args.run_mode == "physician_completions": 537 | physician_completions_main( 538 | run_reference_completions=False, 539 | num_examples=args.examples, 540 | n_threads=args.n_threads or 1, 541 | ) 542 | elif args.run_mode == "physician_completion_references": 543 | physician_completions_main( 544 | run_reference_completions=True, 545 | num_examples=args.examples, 546 | n_threads=args.n_threads or 1, 547 | ) 548 | 549 | else: 550 | raise ValueError(f"Invalid run mode: {args.run_mode}") 551 | 552 | 553 | def physician_completions_main( 554 | run_reference_completions: bool = False, 555 | num_examples: int | None = None, 556 | n_threads: int = 120, 557 | ): 558 | now = datetime.now() 559 | date_str = now.strftime("%Y%m%d_%H%M") 560 | 561 | grading_sampler = ChatCompletionSampler( 562 | model="gpt-4.1-2025-04-14", 563 | system_message=OPENAI_SYSTEM_MESSAGE_API, 564 | max_tokens=2048, 565 | ) 566 | dummy_sampler = SamplerBase() 567 | 568 | merge_metrics = [] 569 | for pc_mode in PHYSICIAN_COMPLETION_MODES.keys(): 570 | if ( 571 | run_reference_completions 572 | and not PHYSICIAN_COMPLETION_MODES[pc_mode]["has_reference"] 573 | ): 574 | continue 575 | 576 | # run 577 | eval = HealthBenchEval( 578 | grader_model=grading_sampler, 579 | physician_completions_mode=pc_mode, 580 | run_reference_completions=run_reference_completions, 581 | num_examples=num_examples, 582 | n_threads=n_threads, 583 | ) 584 | result = eval(dummy_sampler) 585 | 586 | # report 587 | parsable_mode = PHYSICIAN_COMPLETION_MODES[pc_mode]["short_name"] 588 | if run_reference_completions: 589 | file_stem = f"healthbench_{parsable_mode}_referencecompletions_{date_str}" 590 | else: 591 | file_stem = f"healthbench_{parsable_mode}_humanbaseline_{date_str}" 592 | report_filename = Path(f"/tmp/{file_stem}.html") 593 | 
report_filename.write_text(common.make_report(result)) 594 | print(f"Report saved to {report_filename}") 595 | 596 | # metrics 597 | assert result.metrics is not None 598 | metrics = result.metrics 599 | result_filename = Path(f"/tmp/{file_stem}.json") 600 | result_filename.write_text(json.dumps(metrics)) 601 | print(f"Results saved to {result_filename}") 602 | 603 | full_result_dict = { 604 | "score": result.score, 605 | "metrics": result.metrics, 606 | "htmls": result.htmls, 607 | "convos": result.convos, 608 | "metadata": result.metadata, 609 | } 610 | full_result_filename = Path(f"/tmp/{file_stem}_allresults.json") 611 | full_result_filename.write_text(json.dumps(full_result_dict, indent=2)) 612 | print(f"All results saved to {full_result_filename}") 613 | 614 | # metrics df 615 | merge_metrics.append( 616 | { 617 | "eval_name": "healthbench", 618 | "model_name": f"{pc_mode} ({PHYSICIAN_COMPLETION_MODES[pc_mode]['description']})", 619 | "metric": metrics.get("overall_score", None), 620 | } 621 | ) 622 | 623 | merge_metrics_df = pd.DataFrame(merge_metrics).pivot( 624 | index=["model_name"], columns="eval_name" 625 | ) 626 | print("\nAll results: ") 627 | print(merge_metrics_df.to_markdown()) 628 | return merge_metrics 629 | 630 | 631 | if __name__ == "__main__": 632 | main() 633 | -------------------------------------------------------------------------------- /healthbench_eval_test.py: -------------------------------------------------------------------------------- 1 | from .healthbench_eval import RubricItem, calculate_score 2 | 3 | 4 | def test_calculate_score(): 5 | rubric_items = [ 6 | RubricItem(criterion="test", points=7, tags=[]), 7 | RubricItem(criterion="test", points=5, tags=[]), 8 | RubricItem(criterion="test", points=10, tags=[]), 9 | RubricItem(criterion="test", points=-6, tags=[]), 10 | ] 11 | grading_response_list = [ 12 | {"criteria_met": True}, 13 | {"criteria_met": False}, 14 | {"criteria_met": True}, 15 | {"criteria_met": True}, 16 | ] 17 | total_possible = 7 + 5 + 10 18 | achieved = 7 + 0 + 10 - 6 19 | assert ( 20 | calculate_score(rubric_items, grading_response_list) 21 | == achieved / total_possible 22 | ) 23 | 24 | 25 | if __name__ == "__main__": 26 | test_calculate_score() 27 | -------------------------------------------------------------------------------- /healthbench_meta_eval.py: -------------------------------------------------------------------------------- 1 | """ 2 | This script evaluates a grader model on grading HealthBench rubrics. It effectively 3 | evaluates the evaluator against physician opinion, so we call it a meta-evaluation. 4 | 5 | To run, use the following command (working directory should contain simple-evals folder): 6 | `python -m simple-evals.simple_evals --eval=healthbench_meta --model=gpt-4.1` 7 | """ 8 | 9 | import json 10 | import random 11 | from collections import defaultdict 12 | from typing import Literal 13 | 14 | import blobfile as bf 15 | 16 | from . import common 17 | from .healthbench_eval import GRADER_TEMPLATE, parse_json_to_dict 18 | from .types import Eval, EvalResult, SamplerBase, SingleEvalResult 19 | 20 | INPUT_PATH = "https://openaipublic.blob.core.windows.net/simple-evals/healthbench/2025-05-07-06-14-12_oss_meta_eval.jsonl" 21 | INDEX_STR_TEMPLATE = "pairwise_{model_or_physician}_{metric}_{pred_str}" 22 | CLUSTER_STR_TEMPLATE = "{cluster}: {index_str}" 23 | 24 | HEALTHBENCH_META_HTML_JINJA = ( 25 | common.HTML_JINJA.replace( 26 | "
<p>Correct Answer: {{ correct_answer }}</p>\n", 27 | "", 28 | ) 29 | + "<p>Explanation for grader's label: {{ explanation }}</p>
" 30 | ) 31 | 32 | 33 | class HealthBenchMetaEval(Eval): 34 | def __init__( 35 | self, 36 | grader_model: SamplerBase, 37 | num_examples: int | None = None, 38 | n_threads: int = 120, 39 | n_repeats: int = 1, 40 | ): 41 | with bf.BlobFile(INPUT_PATH, "rb") as f: 42 | examples = [json.loads(line) for line in f] 43 | print(f"Loaded {len(examples)} examples from {INPUT_PATH}") 44 | 45 | rng = random.Random(0) 46 | 47 | if num_examples is not None and len(examples) > num_examples: 48 | examples = rng.sample(examples, num_examples) 49 | 50 | self.examples = examples * n_repeats 51 | self.grader_model = grader_model 52 | self.n_threads = n_threads 53 | 54 | def grade_sample( 55 | self, 56 | grading_response_dict: dict, 57 | physician_labels: list[bool], 58 | category: str, 59 | ) -> tuple[dict, bool | None, str]: 60 | metrics = { 61 | "num_physician_labels": len(physician_labels), 62 | "percent_physician_pos": sum(physician_labels) / len(physician_labels), 63 | } 64 | 65 | grader_label = grading_response_dict["criteria_met"] 66 | assert grader_label is True or grader_label is False 67 | metrics["model_predicted_positive"] = grader_label 68 | explanation = grading_response_dict.get( 69 | "explanation", "No explanation provided" 70 | ) 71 | 72 | category_metrics = {f"{category}: {k}": v for k, v in metrics.items()} 73 | metrics = {**metrics, **category_metrics} 74 | return metrics, grader_label, explanation 75 | 76 | def __call__(self, sampler: SamplerBase) -> EvalResult: 77 | def fn(row: dict) -> tuple[SingleEvalResult, bool | None]: 78 | convo_with_response = row["prompt"] + [ 79 | dict(content=row["completion"], role="assistant") 80 | ] 81 | prompt_str = "\n\n".join( 82 | [f"{m['role']}: {m['content']}" for m in convo_with_response] 83 | ) 84 | grader_prompt = GRADER_TEMPLATE.replace("<>", prompt_str) 85 | grader_prompt = grader_prompt.replace("<>", row["rubric"]) 86 | grader_convo = [dict(content=grader_prompt, role="user")] 87 | 88 | while True: 89 | sampler_response = sampler(grader_convo) 90 | response_text = sampler_response.response_text 91 | actual_queried_grader_convo = ( 92 | sampler_response.actual_queried_message_list 93 | ) 94 | grading_response_dict = parse_json_to_dict(response_text) 95 | if "criteria_met" in grading_response_dict: 96 | label = grading_response_dict["criteria_met"] 97 | if label is True or label is False: 98 | break 99 | print("Grading failed due to bad JSON output, retrying...") 100 | 101 | metrics, grader_label, explanation = self.grade_sample( 102 | grading_response_dict=grading_response_dict, 103 | physician_labels=row["binary_labels"], 104 | category=row["category"], 105 | ) 106 | score = metrics["model_predicted_positive"] 107 | 108 | # Create HTML for each sample result 109 | html = common.jinja_env.from_string(HEALTHBENCH_META_HTML_JINJA).render( 110 | prompt_messages=actual_queried_grader_convo, 111 | next_message=dict(content=response_text, role="assistant"), 112 | score=metrics["model_predicted_positive"], 113 | extracted_answer=response_text, 114 | explanation=explanation, 115 | ) 116 | convo = actual_queried_grader_convo + [ 117 | dict(content=response_text, role="assistant") 118 | ] 119 | return ( 120 | SingleEvalResult(html=html, score=score, convo=convo, metrics=metrics), 121 | grader_label, 122 | ) 123 | 124 | # Run evaluation and collect results 125 | all_outputs = common.map_with_progress(fn, self.examples, self.n_threads) 126 | results: list[SingleEvalResult] 127 | grader_labels: list[bool] 128 | results, grader_labels = zip(*all_outputs) 129 | 
130 | # model pairwise agreement metrics 131 | model_agreement_metrics = compute_metrics_for_rater_by_class( 132 | self_pred_list=grader_labels, 133 | other_preds_list=[x["binary_labels"] for x in self.examples], 134 | cluster_list=[x["category"] for x in self.examples], 135 | model_or_physician="model", 136 | ) 137 | 138 | # physicians: 139 | physician_rating_lists = defaultdict(lambda: ([], [], [])) 140 | for example in self.examples: 141 | for i in range(len(example["binary_labels"])): 142 | physician_id = example["anonymized_physician_ids"][i] 143 | self_pred = example["binary_labels"][i] 144 | other_preds = ( 145 | example["binary_labels"][:i] + example["binary_labels"][i + 1 :] 146 | ) 147 | cluster = example["category"] 148 | physician_rating_lists[physician_id][0].append(self_pred) 149 | physician_rating_lists[physician_id][1].append(other_preds) 150 | physician_rating_lists[physician_id][2].append(cluster) 151 | 152 | physician_agreement_metric_lists = defaultdict(dict) 153 | for physician_id, ( 154 | physician_rating_list, 155 | other_preds_list, 156 | cluster_list, 157 | ) in physician_rating_lists.items(): 158 | physician_agreement_metrics = compute_metrics_for_rater_by_class( 159 | self_pred_list=physician_rating_list, 160 | other_preds_list=other_preds_list, 161 | cluster_list=cluster_list, 162 | model_or_physician="physician", 163 | ) 164 | for k, v in physician_agreement_metrics.items(): 165 | physician_agreement_metric_lists[k][physician_id] = v 166 | 167 | # consolidate final metrics and add agreement metrics 168 | final_metrics = common.aggregate_results( 169 | results, default_stats=("mean", "n_samples", "bootstrap_std") 170 | ) 171 | model_agreement_metrics_condensed: dict[str, float] = { 172 | k: v["value"] 173 | for k, v in model_agreement_metrics.items() 174 | if v["value"] is not None 175 | } 176 | assert final_metrics.metrics is not None 177 | final_metrics.metrics.update(model_agreement_metrics_condensed) 178 | final_metrics.score = final_metrics.metrics["pairwise_model_f1_balanced"] 179 | 180 | final_metrics.metadata = { 181 | "model_agreement_metrics": model_agreement_metrics, 182 | "physician_agreement_metric_lists": physician_agreement_metric_lists, 183 | } 184 | return final_metrics 185 | 186 | 187 | def compute_metrics_for_rater_by_class( 188 | self_pred_list: list[bool], 189 | other_preds_list: list[list[bool]], 190 | cluster_list: list[str], 191 | model_or_physician: Literal["model", "physician"], 192 | ) -> dict[str, dict[str, float | None]]: 193 | # get all the metrics for each cluster 194 | metric_lists = defaultdict(list) 195 | for self_pred, other_preds, cluster in zip( 196 | self_pred_list, other_preds_list, cluster_list, strict=True 197 | ): 198 | self_pred_str = "pos" if self_pred else "neg" 199 | for other_pred in other_preds: 200 | # precision. based on the grader's labels - 201 | # i.e., calculated as TP / (TP + FP) 202 | # so a prediction should be recorded whenever self_pred is True 203 | precision_index_str = INDEX_STR_TEMPLATE.format( 204 | model_or_physician=model_or_physician, 205 | metric="precision", 206 | pred_str=self_pred_str, 207 | ) 208 | metric_lists[precision_index_str].append(self_pred == other_pred) 209 | precision_cluster_str = CLUSTER_STR_TEMPLATE.format( 210 | cluster=cluster, index_str=precision_index_str 211 | ) 212 | metric_lists[precision_cluster_str].append(self_pred == other_pred) 213 | 214 | # recall. 
based on the ground truth labels - 215 | # i.e., calculated as TP / (TP + FN) 216 | # so a prediction should be recorded whenever other_pred is True 217 | other_pred_str = "pos" if other_pred else "neg" 218 | recall_index_str = INDEX_STR_TEMPLATE.format( 219 | model_or_physician=model_or_physician, 220 | metric="recall", 221 | pred_str=other_pred_str, 222 | ) 223 | metric_lists[recall_index_str].append(self_pred == other_pred) 224 | recall_cluster_str = CLUSTER_STR_TEMPLATE.format( 225 | cluster=cluster, index_str=recall_index_str 226 | ) 227 | metric_lists[recall_cluster_str].append(self_pred == other_pred) 228 | 229 | metrics: dict[str, dict[str, float | None]] = {} 230 | for index_str, metric_list in metric_lists.items(): 231 | n = len(metric_list) 232 | metric = sum(metric_list) / n if n > 0 else None 233 | metrics[index_str] = { 234 | "n": n, 235 | "value": metric, 236 | } 237 | 238 | f1_metrics = get_f1_metrics(metrics) 239 | metrics.update(f1_metrics) 240 | 241 | balanced_metrics = get_balanced_metrics(metrics) 242 | metrics.update(balanced_metrics) 243 | 244 | return metrics 245 | 246 | 247 | def get_f1_metrics( 248 | metrics: dict[str, dict[str, float | None]], 249 | ) -> dict[str, dict[str, float | None]]: 250 | f1_metrics: dict[str, dict[str, float | None]] = {} 251 | for precision_key_name in metrics: 252 | if "precision" in precision_key_name: 253 | recall_key_name = precision_key_name.replace("precision", "recall") 254 | if recall_key_name not in metrics: 255 | continue 256 | f1_key_name = precision_key_name.replace("precision", "f1") 257 | assert f1_key_name not in metrics 258 | f1_metrics[f1_key_name] = compute_f1_metric( 259 | precision=metrics[precision_key_name], 260 | recall=metrics[recall_key_name], 261 | ) 262 | 263 | return f1_metrics 264 | 265 | 266 | def compute_f1_metric( 267 | precision: dict[str, float | None], 268 | recall: dict[str, float | None], 269 | ) -> dict[str, float | None]: 270 | precision_n = precision["n"] 271 | recall_n = recall["n"] 272 | assert precision_n is not None and recall_n is not None, "n_pos or n_neg is None" 273 | 274 | precision_metric = precision["value"] 275 | recall_metric = recall["value"] 276 | if precision_metric is None or recall_metric is None: 277 | f1_metric = None 278 | n_f1 = ( 279 | precision_n + recall_n 280 | ) # precision_metric is None iff precision_n = 0 and recall_metric is None iff recall_n = 0, so if either is zero this gives TP + FN + FP without double counting 281 | elif precision_metric == 0 and recall_metric == 0: 282 | f1_metric = 0.0 283 | tp = precision_metric * precision_n # because precision = TP / (TP+FP) 284 | n_f1 = precision_n + recall_n - tp # TP+FP + TP+FN − TP 285 | else: 286 | f1_metric = ( 287 | 2 * (precision_metric * recall_metric) / (precision_metric + recall_metric) 288 | ) 289 | tp = precision_metric * precision_n # because precision = TP / (TP+FP) 290 | n_f1 = precision_n + recall_n - tp # TP+FP + TP+FN − TP 291 | 292 | return { 293 | "n": n_f1, 294 | "value": f1_metric, 295 | } 296 | 297 | 298 | def get_balanced_metrics( 299 | metrics: dict[str, dict[str, float | None]], 300 | ) -> dict[str, dict[str, float | None]]: 301 | balanced_metrics: dict[str, dict[str, float | None]] = {} 302 | for pos_key_name in metrics: 303 | if "pos" in pos_key_name: 304 | neg_key_name = pos_key_name.replace("pos", "neg") 305 | if neg_key_name not in metrics: 306 | continue 307 | balanced_key_name = pos_key_name.replace("pos", "balanced") 308 | assert balanced_key_name not in metrics 309 | 
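            # The balanced value is the simple average of the "pos" and "neg" variants
            # (see compute_balanced_metric below).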
balanced_metrics[balanced_key_name] = compute_balanced_metric( 310 | metric_pos=metrics[pos_key_name], 311 | metric_neg=metrics[neg_key_name], 312 | ) 313 | 314 | return balanced_metrics 315 | 316 | 317 | def compute_balanced_metric( 318 | metric_pos: dict[str, float | None], 319 | metric_neg: dict[str, float | None], 320 | ) -> dict[str, float | None]: 321 | n_pos = metric_pos["n"] 322 | n_neg = metric_neg["n"] 323 | assert n_pos is not None and n_neg is not None, "n_pos or n_neg is None" 324 | 325 | pos_metric = metric_pos["value"] 326 | neg_metric = metric_neg["value"] 327 | if pos_metric is None or neg_metric is None: 328 | metric = None 329 | else: 330 | metric = (pos_metric + neg_metric) / 2 331 | 332 | return { 333 | "n": n_pos + n_neg, 334 | # note: this overcounts samples going towards the balanced F1 335 | "value": metric, 336 | } 337 | -------------------------------------------------------------------------------- /healthbench_meta_eval_test.py: -------------------------------------------------------------------------------- 1 | from . import healthbench_meta_eval 2 | 3 | 4 | def test_compute_agreement_for_rater_by_class(): 5 | self_pred_list = [True, False, True] 6 | other_preds_list = [[True, True, False], [True, False], [False]] 7 | cluster_list = ["a", "a", "b"] 8 | model_or_physician = "model" 9 | metrics = healthbench_meta_eval.compute_metrics_for_rater_by_class( 10 | self_pred_list, other_preds_list, cluster_list, model_or_physician 11 | ) 12 | 13 | # precision overall 14 | index_str_pos_precision = healthbench_meta_eval.INDEX_STR_TEMPLATE.format( 15 | model_or_physician=model_or_physician, metric="precision", pred_str="pos" 16 | ) 17 | index_str_neg_precision = healthbench_meta_eval.INDEX_STR_TEMPLATE.format( 18 | model_or_physician=model_or_physician, metric="precision", pred_str="neg" 19 | ) 20 | overall_pos_precision = metrics[index_str_pos_precision] 21 | overall_neg_precision = metrics[index_str_neg_precision] 22 | expected_overall_pos_precision = (2 + 0 + 0) / (3 + 0 + 1) 23 | expected_overall_neg_precision = (0 + 1 + 0) / (0 + 2 + 0) 24 | assert overall_pos_precision["value"] == expected_overall_pos_precision 25 | assert overall_neg_precision["value"] == expected_overall_neg_precision 26 | assert overall_pos_precision["n"] == 4 27 | assert overall_neg_precision["n"] == 2 28 | 29 | # recall overall 30 | index_str_pos_recall = healthbench_meta_eval.INDEX_STR_TEMPLATE.format( 31 | model_or_physician=model_or_physician, metric="recall", pred_str="pos" 32 | ) 33 | index_str_neg_recall = healthbench_meta_eval.INDEX_STR_TEMPLATE.format( 34 | model_or_physician=model_or_physician, metric="recall", pred_str="neg" 35 | ) 36 | overall_pos_recall = metrics[index_str_pos_recall] 37 | overall_neg_recall = metrics[index_str_neg_recall] 38 | expected_overall_pos_recall = (2 + 0 + 0) / (2 + 1 + 0) 39 | expected_overall_neg_recall = (0 + 1 + 0) / (1 + 1 + 1) 40 | assert overall_pos_recall["value"] == expected_overall_pos_recall 41 | assert overall_neg_recall["value"] == expected_overall_neg_recall 42 | assert overall_pos_recall["n"] == 3 43 | assert overall_neg_recall["n"] == 3 44 | 45 | # f1 overall 46 | index_str_pos_f1 = healthbench_meta_eval.INDEX_STR_TEMPLATE.format( 47 | model_or_physician=model_or_physician, metric="f1", pred_str="pos" 48 | ) 49 | index_str_neg_f1 = healthbench_meta_eval.INDEX_STR_TEMPLATE.format( 50 | model_or_physician=model_or_physician, metric="f1", pred_str="neg" 51 | ) 52 | overall_pos_f1 = metrics[index_str_pos_f1] 53 | overall_neg_f1 = 
metrics[index_str_neg_f1] 54 | expected_overall_pos_f1 = ( 55 | 2 56 | * expected_overall_pos_precision 57 | * expected_overall_pos_recall 58 | / (expected_overall_pos_precision + expected_overall_pos_recall) 59 | ) 60 | expected_overall_neg_f1 = ( 61 | 2 62 | * expected_overall_neg_precision 63 | * expected_overall_neg_recall 64 | / (expected_overall_neg_precision + expected_overall_neg_recall) 65 | ) 66 | assert overall_pos_f1["value"] == expected_overall_pos_f1 67 | assert overall_neg_f1["value"] == expected_overall_neg_f1 68 | 69 | # balanced f1 70 | index_str_balanced_f1 = healthbench_meta_eval.INDEX_STR_TEMPLATE.format( 71 | model_or_physician=model_or_physician, metric="f1", pred_str="balanced" 72 | ) 73 | balanced_f1 = metrics[index_str_balanced_f1] 74 | expected_balanced_f1 = (expected_overall_pos_f1 + expected_overall_neg_f1) / 2 75 | assert balanced_f1["value"] == expected_balanced_f1 76 | 77 | # by cluster 78 | # precision 79 | cluster_a_str_pos_precision = healthbench_meta_eval.CLUSTER_STR_TEMPLATE.format( 80 | cluster="a", index_str=index_str_pos_precision 81 | ) 82 | cluster_a_str_neg_precision = healthbench_meta_eval.CLUSTER_STR_TEMPLATE.format( 83 | cluster="a", index_str=index_str_neg_precision 84 | ) 85 | cluster_a_pos_precision = metrics[cluster_a_str_pos_precision] 86 | cluster_a_neg_precision = metrics[cluster_a_str_neg_precision] 87 | assert cluster_a_pos_precision["value"] == ( 88 | # example 1, 2 in order 89 | (2 + 0) / (3 + 0) 90 | ) 91 | assert cluster_a_neg_precision["value"] == ( 92 | # example 1, 2 in order 93 | (0 + 1) / (0 + 2) 94 | ) 95 | assert cluster_a_pos_precision["n"] == 3 96 | assert cluster_a_neg_precision["n"] == 2 97 | 98 | # recall 99 | cluster_a_str_pos_recall = healthbench_meta_eval.CLUSTER_STR_TEMPLATE.format( 100 | cluster="a", index_str=index_str_pos_recall 101 | ) 102 | cluster_a_str_neg_recall = healthbench_meta_eval.CLUSTER_STR_TEMPLATE.format( 103 | cluster="a", index_str=index_str_neg_recall 104 | ) 105 | cluster_a_pos_recall = metrics[cluster_a_str_pos_recall] 106 | cluster_a_neg_recall = metrics[cluster_a_str_neg_recall] 107 | assert cluster_a_pos_recall["value"] == ( 108 | # example 1, 2 in order 109 | (2 + 0) / (2 + 1) 110 | ) 111 | assert cluster_a_neg_recall["value"] == ( 112 | # example 1, 2 in order 113 | (0 + 1) / (1 + 1) 114 | ) 115 | assert cluster_a_pos_recall["n"] == 3 116 | assert cluster_a_neg_recall["n"] == 2 117 | 118 | # cluster B 119 | # precision 120 | cluster_b_str_pos_precision = healthbench_meta_eval.CLUSTER_STR_TEMPLATE.format( 121 | cluster="b", index_str=index_str_pos_precision 122 | ) 123 | cluster_b_str_neg_precision = healthbench_meta_eval.CLUSTER_STR_TEMPLATE.format( 124 | cluster="b", index_str=index_str_neg_precision 125 | ) 126 | cluster_b_str_pos_precision = metrics[cluster_b_str_pos_precision] 127 | assert cluster_b_str_neg_precision not in metrics 128 | assert cluster_b_str_pos_precision["value"] == ( 129 | # example 3 only 130 | 0 / 1 131 | ) 132 | assert cluster_b_str_pos_precision["n"] == 1 133 | 134 | # recall 135 | cluster_b_str_pos_recall = healthbench_meta_eval.CLUSTER_STR_TEMPLATE.format( 136 | cluster="b", index_str=index_str_pos_recall 137 | ) 138 | cluster_b_str_neg_recall = healthbench_meta_eval.CLUSTER_STR_TEMPLATE.format( 139 | cluster="b", index_str=index_str_neg_recall 140 | ) 141 | assert cluster_b_str_pos_recall not in metrics 142 | cluster_b_neg_recall = metrics[cluster_b_str_neg_recall] 143 | assert cluster_b_neg_recall["value"] == ( 144 | # example 3 only 145 | 0 / 1 146 | ) 147 | 
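    # cluster "b" contributes a single pairwise comparison (example 3 has one other label)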
assert cluster_b_neg_recall["n"] == 1 148 | 149 | # f1 150 | index_str_pos_f1 = healthbench_meta_eval.CLUSTER_STR_TEMPLATE.format( 151 | cluster="b", index_str=index_str_pos_f1 152 | ) 153 | index_str_neg_f1 = healthbench_meta_eval.CLUSTER_STR_TEMPLATE.format( 154 | cluster="b", index_str=index_str_neg_f1 155 | ) 156 | index_str_balanced_f1 = healthbench_meta_eval.CLUSTER_STR_TEMPLATE.format( 157 | cluster="b", index_str=index_str_balanced_f1 158 | ) 159 | assert index_str_pos_f1 not in metrics 160 | assert index_str_neg_f1 not in metrics 161 | assert index_str_balanced_f1 not in metrics 162 | 163 | 164 | if __name__ == "__main__": 165 | test_compute_agreement_for_rater_by_class() 166 | -------------------------------------------------------------------------------- /humaneval_eval.py: -------------------------------------------------------------------------------- 1 | """ 2 | HumanEval: Evaluating Large Language Models Trained on Code 3 | Mark Chen and Jerry Tworek and Heewoo Jun and Qiming Yuan and Henrique Ponde de Oliveira Pinto and Jared Kaplan and Harri Edwards and Yuri Burda and Nicholas Joseph and Greg Brockman and Alex Ray and Raul Puri and Gretchen Krueger and Michael Petrov and Heidy Khlaaf and Girish Sastry and Pamela Mishkin and Brooke Chan and Scott Gray and Nick Ryder and Mikhail Pavlov and Alethea Power and Lukasz Kaiser and Mohammad Bavarian and Clemens Winter and Philippe Tillet and Felipe Petroski Such and Dave Cummings and Matthias Plappert and Fotios Chantzis and Elizabeth Barnes and Ariel Herbert-Voss and William Hebgen Guss and Alex Nichol and Alex Paino and Nikolas Tezak and Jie Tang and Igor Babuschkin and Suchir Balaji and Shantanu Jain and William Saunders and Christopher Hesse and Andrew N. Carr and Jan Leike and Josh Achiam and Vedant Misra and Evan Morikawa and Alec Radford and Matthew Knight and Miles Brundage and Mira Murati and Katie Mayer and Peter Welinder and Bob McGrew and Dario Amodei and Sam McCandlish and Ilya Sutskever and Wojciech Zaremba 4 | https://arxiv.org/abs/2107.03374 https://github.com/openai/human-eval/ 5 | """ 6 | 7 | import random 8 | import re 9 | from concurrent.futures import ThreadPoolExecutor, as_completed 10 | 11 | from human_eval.data import read_problems 12 | from human_eval.evaluation import estimate_pass_at_k 13 | from human_eval.execution import check_correctness # , unsafe_execute 14 | 15 | from . import common 16 | from .common import HTML_JINJA 17 | from .types import Eval, EvalResult, SamplerBase, SingleEvalResult 18 | 19 | 20 | def evaluate_functional_correctness( 21 | sample: dict[str, str], 22 | completions: list[str], 23 | n_workers: int = 4, 24 | timeout: float = 3.0, 25 | ): 26 | """ 27 | Evaluates the functional correctness of generated samples, and writes 28 | results to f"{sample_file}_results.jsonl.gz" 29 | """ 30 | 31 | # Check the generated samples against test suites. 
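    # Each candidate completion is executed against the task's unit tests via check_correctness
    # (bounded by `timeout`); the per-sample pass/fail list later feeds estimate_pass_at_k, the
    # standard unbiased estimator pass@k = 1 - C(n-c, k) / C(n, k) for n samples with c passing
    # (e.g. n=5 samples with c=2 passing gives pass@1 = 1 - C(3,1)/C(5,1) = 0.4).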
32 | with ThreadPoolExecutor(max_workers=n_workers) as executor: 33 | futures = [] 34 | for i, completion in enumerate(completions): 35 | args = (sample, completion, timeout, i) 36 | future = executor.submit(check_correctness, *args) 37 | futures.append(future) 38 | results = [] 39 | for future in as_completed(futures): 40 | result = future.result() 41 | results.append(result) 42 | passed = [int(r["passed"]) for r in results] 43 | return passed 44 | 45 | 46 | class HumanEval(Eval): 47 | def __init__( 48 | self, 49 | num_examples: int = 250, # restrict to a subset of the data for debugging 50 | num_samples_per_task: int = 5, 51 | ks_passes: list[int] = [1, 2, 5], 52 | timeout: int = 120, 53 | ): 54 | self.seed = 0 55 | self.examples = read_problems() 56 | self.examples = list(self.examples.values()) 57 | 58 | self._num_examples = num_examples 59 | if self._num_examples: 60 | self.examples = random.Random(self.seed).sample(self.examples, num_examples) 61 | self._num_samples_per_task = num_samples_per_task 62 | self._ks_passes = ks_passes 63 | self._timeout = timeout 64 | 65 | def __call__(self, sampler: SamplerBase) -> EvalResult: 66 | instruction = "Read the following function signature and docstring, and fully implement the function described. Your response should only contain the code for this function.\n" 67 | 68 | def find_code(completion): 69 | pattern = re.compile(r"```python\n(.*?)```", re.DOTALL) 70 | matches = pattern.findall(completion) 71 | extracted_answer = matches[0] if len(matches) >= 1 else completion 72 | extracted_answer = extracted_answer[ 73 | extracted_answer.find(":\n ") + 2 : 74 | ] # remove signature 75 | return extracted_answer 76 | 77 | def fn(sample: dict[str, str]): 78 | prompt_messages = [ 79 | sampler._pack_message( 80 | role="user", content=instruction + sample["prompt"] 81 | ) 82 | ] 83 | completions = [ 84 | find_code(sampler(prompt_messages).response_text) 85 | for _ in range(self._num_samples_per_task) 86 | ] 87 | results = evaluate_functional_correctness(sample, completions) 88 | total = len(results) 89 | correct = sum(results) 90 | score = sum(results) / len(results) 91 | html = common.jinja_env.from_string(HTML_JINJA).render( 92 | prompt_messages=prompt_messages, 93 | next_message=dict(content=completions[0], role="assistant"), 94 | score=score, 95 | correct_answer=[1] * len(results), 96 | extracted_answer=results, 97 | ) 98 | convo = prompt_messages + [ 99 | dict(content=completion, role="assistant") for completion in completions 100 | ] 101 | return SingleEvalResult( 102 | html=html, 103 | score=score, 104 | convo=convo, 105 | metrics={ 106 | f"pass@{k}": estimate_pass_at_k([total], [correct], k) 107 | # this will be aggrated so no need of .mean() 108 | for k in self._ks_passes 109 | if total >= k 110 | }, 111 | ) 112 | 113 | results = common.map_with_progress(fn, self.examples, num_threads=3) 114 | return common.aggregate_results(results) 115 | -------------------------------------------------------------------------------- /math_eval.py: -------------------------------------------------------------------------------- 1 | """ 2 | Measuring Mathematical Problem Solving With the MATH Dataset 3 | Dan Hendrycks, Collin Burns, Saurav Kadavath, Akul Arora, Steven Basart, Eric Tang, Dawn Song, Jacob Steinhardt 4 | https://arxiv.org/abs/2103.03874 5 | """ 6 | 7 | import random 8 | import re 9 | from typing import Literal 10 | 11 | import pandas 12 | 13 | from . 
import common 14 | from .common import ANSWER_PATTERN, HTML_JINJA, check_equality 15 | from .types import Eval, EvalResult, SamplerBase, SingleEvalResult 16 | 17 | QUERY_TEMPLATE = """ 18 | Solve the following math problem step by step. The last line of your response should be of the form Answer: $ANSWER (without quotes) where $ANSWER is the answer to the problem. 19 | 20 | {Question} 21 | 22 | Remember to put your answer on its own line after "Answer:", and you do not need to use a \\boxed command. 23 | """.strip() 24 | 25 | 26 | class MathEval(Eval): 27 | def __init__( 28 | self, 29 | equality_checker: SamplerBase, 30 | num_examples: int | None = None, 31 | n_repeats: int = 16, 32 | split: Literal["math_test", "math_500_test"] = "math_test", 33 | ): 34 | df = pandas.read_csv( 35 | f"https://openaipublic.blob.core.windows.net/simple-evals/{split}.csv" 36 | ) 37 | examples = [row.to_dict() for _, row in df.iterrows()] 38 | if num_examples: 39 | assert n_repeats == 1, "n_repeats only supported for num_examples = None" 40 | rng = random.Random(0) 41 | examples = rng.sample(examples, num_examples) 42 | self.examples = examples * n_repeats 43 | self.equality_checker = equality_checker 44 | 45 | def __call__(self, sampler: SamplerBase) -> EvalResult: 46 | def fn(row: dict): 47 | prompt_messages = [ 48 | sampler._pack_message(content=QUERY_TEMPLATE.format(**row), role="user") 49 | ] 50 | sampler_response = sampler(prompt_messages) 51 | response_text = sampler_response.response_text 52 | actual_queried_prompt_messages = sampler_response.actual_queried_message_list 53 | match = re.search(ANSWER_PATTERN, response_text) 54 | extracted_answer = match.group(1) if match else None 55 | score = float(check_equality(self.equality_checker, row["Answer"], extracted_answer)) 56 | html = common.jinja_env.from_string(HTML_JINJA).render( 57 | prompt_messages=actual_queried_prompt_messages, 58 | next_message=dict(content=response_text, role="assistant"), 59 | score=score, 60 | correct_answer=row["Answer"], 61 | extracted_answer=extracted_answer, 62 | ) 63 | convo = actual_queried_prompt_messages + [dict(content=response_text, role="assistant")] 64 | return SingleEvalResult(html=html, score=score, convo=convo) 65 | 66 | results = common.map_with_progress(fn, self.examples) 67 | return common.aggregate_results(results) 68 | -------------------------------------------------------------------------------- /mgsm_eval.py: -------------------------------------------------------------------------------- 1 | """ 2 | MGSM: Multilingual Grade School Math Benchmark (MGSM) is a benchmark of grade-school math problems. 3 | Language Models are Multilingual Chain-of-Thought Reasoners 4 | Freda Shi, Mirac Suzgun, Markus Freitag, Xuezhi Wang, Suraj Srivats, Soroush Vosoughi, Hyung Won Chung, Yi Tay, Sebastian Ruder, Denny Zhou, Dipanjan Das, Jason Wei 5 | https://arxiv.org/abs/2210.03057 reference: https://github.com/google-research/url-nlp 6 | """ 7 | 8 | import re 9 | from typing import Optional 10 | 11 | from . 
import common 12 | from .mmlu_eval import HTML_JINJA 13 | from .types import Eval, EvalResult, SamplerBase, SingleEvalResult 14 | 15 | ALL_LANGUAGES = ["bn", "de", "en", "es", "fr", "ja", "ru", "sw", "te", "th", "zh"] 16 | LATIN_LANGUAGES = ["de", "en", "es", "fr", "sw"] 17 | NON_LATIN_LANGUAGES = ["bn", "ja", "ru", "te", "th", "zh"] 18 | 19 | LANG_TO_FPATH = { 20 | "bn": "https://openaipublic.blob.core.windows.net/simple-evals/mgsm_bn.tsv", 21 | "de": "https://openaipublic.blob.core.windows.net/simple-evals/mgsm_de.tsv", 22 | "en": "https://openaipublic.blob.core.windows.net/simple-evals/mgsm_en.tsv", 23 | "es": "https://openaipublic.blob.core.windows.net/simple-evals/mgsm_es.tsv", 24 | "fr": "https://openaipublic.blob.core.windows.net/simple-evals/mgsm_fr.tsv", 25 | "ja": "https://openaipublic.blob.core.windows.net/simple-evals/mgsm_ja.tsv", 26 | "ru": "https://openaipublic.blob.core.windows.net/simple-evals/mgsm_ru.tsv", 27 | "sw": "https://openaipublic.blob.core.windows.net/simple-evals/mgsm_sw.tsv", 28 | "te": "https://openaipublic.blob.core.windows.net/simple-evals/mgsm_te.tsv", 29 | "th": "https://openaipublic.blob.core.windows.net/simple-evals/mgsm_th.tsv", 30 | "zh": "https://openaipublic.blob.core.windows.net/simple-evals/mgsm_zh.tsv", 31 | } 32 | LANG_TO_INSTRUCTIONS = { 33 | "en": """Solve this math problem. Give the reasoning steps before giving the final answer on the last line by itself in the format of "Answer:". Do not add anything other than the integer answer after "Answer:". 34 | 35 | {input}""", 36 | "bn": """এই গণিতের সমস্যাটি সমাধান করুন। চূড়ান্ত উত্তর দেওয়ার আগে যুক্তিসম্পন্ন পদক্ষেপ প্রদান করুন। চূড়ান্ত উত্তরটি একক সংখ্যা হিসাবে "উত্তর:" এর পরে শেষ লাইনে দিন। "উত্তর:" এর পরে অন্য কিছু যুক্ত করবেন না।. 37 | 38 | {input}""", 39 | "de": """Löse dieses Mathematikproblem. Gib die Schritte zur Begründung an, bevor du die endgültige Antwort in der letzten Zeile alleine im Format "Antwort:" gibst. Füge nichts anderes als die ganzzahlige Antwort nach "Antwort:" hinzu. 40 | 41 | {input}""", 42 | "es": """Resuelve este problema matemático. Proporciona los pasos de razonamiento antes de dar la respuesta final en la última línea por sí misma en el formato de "Respuesta:". No añadas nada más que la respuesta entera después de "Respuesta:". 43 | 44 | {input}""", 45 | "fr": """Résolvez ce problème de mathématiques. Donnez les étapes de raisonnement avant de fournir la réponse finale sur la dernière ligne elle-même dans le format de "Réponse:". N'ajoutez rien d'autre que la réponse entière après "Réponse:". 46 | 47 | {input}""", 48 | "ja": """の数学の問題を解いてください。最終的な答えを出す前に、解答の推論過程を記述してください。そして最後の行には "答え:" の形式で答えを記述し、その後には整数の答え以外何も追加しないでください。 49 | 50 | {input}""", 51 | "ru": """Решите эту математическую задачу. Объясните шаги рассуждения перед тем, как дать окончательный ответ в последней строке сам по себе в формате "Ответ:". Не добавляйте ничего, кроме целочисленного ответа после "Ответ:". 52 | 53 | {input}""", 54 | "sw": """Suluhisha tatizo hili la hesabu. Toa hatua za mantiki kabla ya kutoa jibu la mwisho kwenye mstari wa mwisho peke yake katika muundo wa "Jibu:". Usiongeze chochote kingine isipokuwa jibu la integer baada ya "Jibu:". 55 | 56 | {input}""", 57 | "te": """ఈ గణిత సమస్యను పరిష్కరించండి. చివరి సమాధానాన్ని ఇవ్వదానికి ముందు తర్కాత్మక అదుగులను ఇవ్వండి. చివరి పంక్తిలో మాత్రమే 'సమాధానం:' అనే ఆకారంలో చివరి సమాధానాద్ని ఇవ్వండి సమాధానం: తర్వాత పూర్ణాంక సమాధానానికి తప్పించి ఎదేనా చేర్చవద్దు. 
58 | 59 | {input}""", 60 | "th": """แก้ปัญหาคณิตศาสตร์นี้ ให้ให้ขั้นตอนการใช้เหตุผลก่อนที่จะให้คำตอบสุดท้ายในบรรทัดสุดท้ายโดยอยู่ในรูปแบบ "คำตอบ:" ไม่ควรเพิ่มอะไรนอกจากคำตอบที่เป็นจำนวนเต็มหลังจาก "คำตอบ:" 61 | 62 | {input}""", 63 | "zh": """解决这个数学问题。在最后一行给出答案前,请提供推理步骤。最后一行应该以 "答案: " 的形式独立给出答案。在 "答案:" 后不要添加除整数答案之外的任何内容。 64 | 65 | {input}""", 66 | } 67 | 68 | LANG_TO_ANSWER_PREFIX = { 69 | "en": "Answer", 70 | "bn": "উত্তর", 71 | "de": "Antwort", 72 | "es": "Respuesta", 73 | "fr": "Réponse", 74 | "ja": "答え", 75 | "ru": "Ответ", 76 | "sw": "Jibu", 77 | "te": "సమాధానం", 78 | "th": "คำตอบ", 79 | "zh": "答案", 80 | } 81 | 82 | 83 | def parse_answer(answer: str, answer_prefix: str) -> str: 84 | if answer_prefix not in answer: 85 | return "" 86 | 87 | answer_text = answer.split(answer_prefix)[-1].strip() 88 | 89 | # find all the numbers (including decimals) in the string 90 | numbers = re.findall(r"\d+\.?\d*", answer_text.replace(",", "")) 91 | 92 | # return the first number (removing trailing decimal point if present), 93 | # or an empty string if there were no numbers 94 | return numbers[-1].rstrip(".") if numbers else "" 95 | 96 | 97 | def score_mgsm(target: str, prediction: str) -> bool: 98 | if "." in prediction: 99 | prediction = prediction.rstrip("0").rstrip(".") 100 | 101 | target = target.replace(",", "") 102 | prediction = prediction.replace(",", "") 103 | 104 | return target == prediction 105 | 106 | 107 | def get_lang_examples(lang: str) -> list[dict[str, str]]: 108 | fpath = LANG_TO_FPATH[lang] 109 | examples = [] 110 | with common.url_to_fileobj(fpath, binary=True) as f: 111 | for raw_line in f: 112 | line = raw_line.decode("utf-8").strip() 113 | inputs, targets = line.split("\t") 114 | if "." in targets: 115 | raise ValueError(f"targets {targets} contains a decimal point.") 116 | # targets = int(targets.replace(",", "")) 117 | examples.append({"inputs": inputs, "targets": targets, "lang": lang}) 118 | return examples 119 | 120 | 121 | def get_all_examples() -> list[dict[str, str]]: 122 | examples = [] 123 | for lang in ALL_LANGUAGES: 124 | if lang != "en": 125 | continue 126 | examples += get_lang_examples(lang) 127 | return examples 128 | 129 | 130 | class MGSMEval(Eval): 131 | def __init__( 132 | self, 133 | num_examples_per_lang: int = 250, # restrict to a subset of the data for debugging 134 | languages: Optional[list[str]] = ALL_LANGUAGES, 135 | ): 136 | if languages is None: 137 | languages = ALL_LANGUAGES 138 | else: 139 | for language in languages: 140 | if language not in ALL_LANGUAGES: 141 | raise ValueError( 142 | f"language {language} is not a valid language. 
" 143 | f"It should be one in {ALL_LANGUAGES}" 144 | ) 145 | self._languages = languages 146 | self._num_examples_per_lang = num_examples_per_lang 147 | 148 | examples = [] 149 | for lang in self._languages: 150 | lang_examples = get_lang_examples(lang) 151 | examples.extend(lang_examples[: self._num_examples_per_lang]) 152 | self.examples = examples 153 | 154 | def __call__(self, sampler: SamplerBase) -> EvalResult: 155 | def fn(example: dict[str, str]): 156 | language = example["lang"] 157 | latin_language = "group_latin" if language in LATIN_LANGUAGES else "group_non_latin" 158 | correct_answer = example["targets"] 159 | instruction = LANG_TO_INSTRUCTIONS[language] 160 | prompt_messages = [ 161 | sampler._pack_message( 162 | content=instruction.format(input=example["inputs"]), role="user" 163 | ) 164 | ] 165 | try: 166 | sampler_response = sampler(prompt_messages) 167 | response_text = sampler_response.response_text 168 | actual_queried_prompt_messages = sampler_response.actual_queried_message_list 169 | except Exception as e: 170 | response_text = "" 171 | 172 | answer_prefix = LANG_TO_ANSWER_PREFIX[language] 173 | extracted_answer = parse_answer(response_text, answer_prefix) 174 | 175 | score = score_mgsm(correct_answer, extracted_answer) 176 | html = common.jinja_env.from_string(HTML_JINJA).render( 177 | prompt_messages=actual_queried_prompt_messages, 178 | next_message=dict(content=response_text, role="assistant"), 179 | score=score, 180 | correct_answer=correct_answer, 181 | extracted_answer=extracted_answer or None, 182 | ) 183 | convo = actual_queried_prompt_messages + [dict(content=response_text, role="assistant")] 184 | return SingleEvalResult( 185 | html=html, 186 | score=score, 187 | convo=convo, 188 | metrics={language: score, latin_language: score}, 189 | ) 190 | 191 | results = common.map_with_progress(fn, self.examples) 192 | return common.aggregate_results(results, default_stats=("mean", "std")) 193 | -------------------------------------------------------------------------------- /mmlu_eval.py: -------------------------------------------------------------------------------- 1 | """ 2 | Measuring Massive Multitask Language Understanding 3 | Dan Hendrycks, Collin Burns, Steven Basart, Andy Zou, Mantas Mazeika, Dawn Song, Jacob Steinhardt 4 | https://arxiv.org/abs/2009.03300 5 | """ 6 | 7 | import random 8 | import re 9 | 10 | import pandas 11 | 12 | from . 
import common 13 | from .common import ( 14 | HTML_JINJA, 15 | MULTILINGUAL_ANSWER_PATTERN_TEMPLATE, 16 | MULTILINGUAL_ANSWER_REGEXES, 17 | format_multichoice_question, 18 | normalize_extracted_answer, 19 | normalize_response, 20 | ) 21 | from .types import Eval, EvalResult, SamplerBase, SingleEvalResult 22 | 23 | subject2category = { 24 | "abstract_algebra": "stem", 25 | "anatomy": "other", 26 | "astronomy": "stem", 27 | "business_ethics": "other", 28 | "clinical_knowledge": "other", 29 | "college_biology": "stem", 30 | "college_chemistry": "stem", 31 | "college_computer_science": "stem", 32 | "college_mathematics": "stem", 33 | "college_medicine": "other", 34 | "college_physics": "stem", 35 | "computer_security": "stem", 36 | "conceptual_physics": "stem", 37 | "econometrics": "social_sciences", 38 | "electrical_engineering": "stem", 39 | "elementary_mathematics": "stem", 40 | "formal_logic": "humanities", 41 | "global_facts": "other", 42 | "high_school_biology": "stem", 43 | "high_school_chemistry": "stem", 44 | "high_school_computer_science": "stem", 45 | "high_school_european_history": "humanities", 46 | "high_school_geography": "social_sciences", 47 | "high_school_government_and_politics": "social_sciences", 48 | "high_school_macroeconomics": "social_sciences", 49 | "high_school_mathematics": "stem", 50 | "high_school_microeconomics": "social_sciences", 51 | "high_school_physics": "stem", 52 | "high_school_psychology": "social_sciences", 53 | "high_school_statistics": "stem", 54 | "high_school_us_history": "humanities", 55 | "high_school_world_history": "humanities", 56 | "human_aging": "other", 57 | "human_sexuality": "social_sciences", 58 | "international_law": "humanities", 59 | "jurisprudence": "humanities", 60 | "logical_fallacies": "humanities", 61 | "machine_learning": "stem", 62 | "management": "other", 63 | "marketing": "other", 64 | "medical_genetics": "other", 65 | "miscellaneous": "other", 66 | "moral_disputes": "humanities", 67 | "moral_scenarios": "humanities", 68 | "nutrition": "other", 69 | "philosophy": "humanities", 70 | "prehistory": "humanities", 71 | "professional_accounting": "other", 72 | "professional_law": "humanities", 73 | "professional_medicine": "other", 74 | "professional_psychology": "social_sciences", 75 | "public_relations": "social_sciences", 76 | "security_studies": "social_sciences", 77 | "sociology": "social_sciences", 78 | "us_foreign_policy": "social_sciences", 79 | "virology": "other", 80 | "world_religions": "humanities", 81 | } 82 | 83 | 84 | class MMLUEval(Eval): 85 | def __init__(self, num_examples: int | None = None, language: str = "EN-US"): 86 | if language != "EN-US": 87 | url = f"https://openaipublic.blob.core.windows.net/simple-evals/mmlu_{language}.csv" 88 | else: 89 | url = "https://openaipublic.blob.core.windows.net/simple-evals/mmlu.csv" 90 | df = pandas.read_csv(url) 91 | examples = [row.to_dict() for _, row in df.iterrows()] 92 | if num_examples: 93 | examples = random.Random(0).sample(examples, num_examples) 94 | self.examples = examples 95 | 96 | def __call__(self, sampler: SamplerBase) -> EvalResult: 97 | def fn(row: dict): 98 | prompt_messages = [ 99 | sampler._pack_message( 100 | content=format_multichoice_question(row), role="user" 101 | ) 102 | ] 103 | sampler_response = sampler(prompt_messages) 104 | response_text = sampler_response.response_text 105 | actual_queried_prompt_messages = sampler_response.actual_queried_message_list 106 | response_text = normalize_response(response_text) 107 | extracted_answer = None 108 | 
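            # Try each language's answer-line regex in turn; the first match is normalized to a
            # single letter choice and compared against the dataset's Answer column.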
for answer_regex in MULTILINGUAL_ANSWER_REGEXES: 109 | regex = MULTILINGUAL_ANSWER_PATTERN_TEMPLATE.format(answer_regex) 110 | match = re.search(regex, response_text) 111 | if match: 112 | extracted_answer = normalize_extracted_answer(match.group(1)) 113 | break 114 | score = 1.0 if extracted_answer == row["Answer"] else 0.0 115 | html = common.jinja_env.from_string(HTML_JINJA).render( 116 | prompt_messages=actual_queried_prompt_messages, 117 | next_message=dict(content=response_text, role="assistant"), 118 | score=score, 119 | correct_answer=row["Answer"], 120 | extracted_answer=extracted_answer, 121 | ) 122 | convo = actual_queried_prompt_messages + [dict(content=response_text, role="assistant")] 123 | category = subject2category.get(row["Subject"], "other") 124 | return SingleEvalResult( 125 | html=html, score=score, metrics={category: score}, convo=convo 126 | ) 127 | 128 | results = common.map_with_progress(fn, self.examples) 129 | return common.aggregate_results(results) 130 | -------------------------------------------------------------------------------- /multilingual_mmlu_benchmark_results.md: -------------------------------------------------------------------------------- 1 | # Multilingual MMLU Benchmark Results 2 | 3 | To evaluate multilingual performance, we translated MMLU’s test set into 14 languages using professional human translators. Relying on human translators for this evaluation increases confidence in the accuracy of the translations, especially for low-resource languages like Yoruba. 4 | 5 | ## Results 6 | 7 | 8 | | Language | o3-high | o1 | o4-mini-high | o3-mini-high | gpt-4.5-preview-2025-02-27 | gpt-4.1-2025-04-14 | gpt-4o-2024-11-20 | gpt-4.1-mini-2025-04-14 | gpt-4o-mini-2024-07-18 | gpt-4.1-nano-2025-04-14 | 9 | | :------------------: | :---------: | :---: | :----------: | :----------: | :------------------------: | :----------------: | :---------------: | :---------------------: | :--------------------: | :---------------------: | 10 | | Arabic | **0.904** | 0.890 | 0.861 | 0.819 | 0.860 | 0.844 | 0.831 | 0.795 | 0.709 | 0.659 | 11 | | Bengali | **0.878** | 0.873 | 0.840 | 0.801 | 0.848 | 0.827 | 0.801 | 0.749 | 0.658 | 0.583 | 12 | | Chinese (Simplified) | **0.893** | 0.889 | 0.869 | 0.836 | 0.870 | 0.861 | 0.842 | 0.817 | 0.731 | 0.710 | 13 | | French | **0.906** | 0.893 | 0.874 | 0.837 | 0.878 | 0.870 | 0.846 | 0.835 | 0.766 | 0.739 | 14 | | German | **0.905** | 0.890 | 0.867 | 0.808 | 0.853 | 0.855 | 0.836 | 0.823 | 0.743 | 0.722 | 15 | | Hindi | **0.898** | 0.883 | 0.859 | 0.811 | 0.858 | 0.842 | 0.819 | 0.780 | 0.692 | 0.629 | 16 | | Indonesian | **0.898** | 0.886 | 0.869 | 0.828 | 0.872 | 0.859 | 0.840 | 0.816 | 0.745 | 0.714 | 17 | | Italian | **0.912** | 0.897 | 0.877 | 0.838 | 0.878 | 0.869 | 0.845 | 0.835 | 0.764 | 0.734 | 18 | | Japanese | **0.890** | 0.889 | 0.869 | 0.831 | 0.869 | 0.856 | 0.835 | 0.810 | 0.726 | 0.690 | 19 | | Korean | **0.893** | 0.882 | 0.867 | 0.826 | 0.860 | 0.849 | 0.829 | 0.801 | 0.720 | 0.679 | 20 | | Portuguese (Brazil) | **0.910** | 0.895 | 0.878 | 0.841 | 0.879 | 0.870 | 0.836 | 0.839 | 0.768 | 0.741 | 21 | | Spanish | **0.911** | 0.899 | 0.880 | 0.840 | 0.884 | 0.876 | 0.843 | 0.839 | 0.774 | 0.748 | 22 | | Swahili | **0.860** | 0.854 | 0.813 | 0.738 | 0.820 | 0.795 | 0.779 | 0.679 | 0.619 | 0.566 | 23 | | Yoruba | **0.780** | 0.754 | 0.708 | 0.637 | 0.682 | 0.647 | 0.621 | 0.566 | 0.458 | 0.455 | 24 | | Average | **0.888** | 0.877 | 0.852 | 0.807 | 0.851 | 0.837 | 0.814 | 0.785 | 0.705 | 0.669 | 25 | 26 | These 
results can be reproduced by running 27 | 28 | ```bash 29 | python -m simple-evals.run_multilingual_mmlu 30 | ``` 31 | -------------------------------------------------------------------------------- /run_multilingual_mmlu.py: -------------------------------------------------------------------------------- 1 | import json 2 | 3 | import pandas as pd 4 | 5 | from . import common 6 | from .mmlu_eval import MMLUEval 7 | from .sampler.chat_completion_sampler import ( 8 | OPENAI_SYSTEM_MESSAGE_API, 9 | OPENAI_SYSTEM_MESSAGE_CHATGPT, 10 | ChatCompletionSampler, 11 | ) 12 | from .sampler.o_chat_completion_sampler import OChatCompletionSampler 13 | 14 | 15 | def main(): 16 | debug = True 17 | samplers = { 18 | "gpt-4o_chatgpt": ChatCompletionSampler( 19 | model="gpt-4o", 20 | system_message=OPENAI_SYSTEM_MESSAGE_CHATGPT, 21 | max_tokens=2048, 22 | ), 23 | "gpt-4o-mini-2024-07-18": ChatCompletionSampler( 24 | model="gpt-4o-mini-2024-07-18", 25 | system_message=OPENAI_SYSTEM_MESSAGE_API, 26 | max_tokens=2048, 27 | ), 28 | "o1-preview": OChatCompletionSampler( 29 | model="o1-preview", 30 | ), 31 | "o1-mini": OChatCompletionSampler( 32 | model="o1-mini", 33 | ), 34 | # Default == Medium 35 | "o3-mini": OChatCompletionSampler( 36 | model="o3-mini", 37 | ), 38 | "o3-mini_high": OChatCompletionSampler( 39 | model="o3-mini", 40 | reasoning_effort="high", 41 | ), 42 | "o3-mini_low": OChatCompletionSampler( 43 | model="o3-mini", 44 | reasoning_effort="low", 45 | ), 46 | } 47 | 48 | def get_evals(eval_name): 49 | match eval_name: 50 | case "mmlu_EN-US": 51 | return MMLUEval(num_examples=10 if debug else None, language="EN-US") 52 | case "mmlu_AR-XY": 53 | return MMLUEval(num_examples=10 if debug else None, language="AR-XY") 54 | case "mmlu_BN-BD": 55 | return MMLUEval(num_examples=10 if debug else None, language="BN-BD") 56 | case "mmlu_DE-DE": 57 | return MMLUEval(num_examples=10 if debug else None, language="DE-DE") 58 | case "mmlu_ES-LA": 59 | return MMLUEval(num_examples=10 if debug else None, language="ES-LA") 60 | case "mmlu_FR-FR": 61 | return MMLUEval(num_examples=10 if debug else None, language="FR-FR") 62 | case "mmlu_HI-IN": 63 | return MMLUEval(num_examples=10 if debug else None, language="HI-IN") 64 | case "mmlu_ID-ID": 65 | return MMLUEval(num_examples=10 if debug else None, language="ID-ID") 66 | case "mmlu_IT-IT": 67 | return MMLUEval(num_examples=10 if debug else None, language="IT-IT") 68 | case "mmlu_JA-JP": 69 | return MMLUEval(num_examples=10 if debug else None, language="JA-JP") 70 | case "mmlu_KO-KR": 71 | return MMLUEval(num_examples=10 if debug else None, language="KO-KR") 72 | case "mmlu_PT-BR": 73 | return MMLUEval(num_examples=10 if debug else None, language="PT-BR") 74 | case "mmlu_ZH-CN": 75 | return MMLUEval(num_examples=10 if debug else None, language="ZH-CN") 76 | case "mmlu_SW-KE": 77 | return MMLUEval(num_examples=10 if debug else None, language="SW-KE") 78 | case "mmlu_YO-NG": 79 | return MMLUEval(num_examples=10 if debug else None, language="YO-NG") 80 | case _: 81 | raise Exception(f"Unrecognized eval type: {eval_name}") 82 | 83 | evals = { 84 | eval_name: get_evals(eval_name) 85 | for eval_name in [ 86 | "mmlu_AR-XY", 87 | "mmlu_BN-BD", 88 | "mmlu_DE-DE", 89 | "mmlu_EN-US", 90 | "mmlu_ES-LA", 91 | "mmlu_FR-FR", 92 | "mmlu_HI-IN", 93 | "mmlu_ID-ID", 94 | "mmlu_IT-IT", 95 | "mmlu_JA-JP", 96 | "mmlu_KO-KR", 97 | "mmlu_PT-BR", 98 | "mmlu_ZH-CN", 99 | "mmlu_SW-KE", 100 | "mmlu_YO-NG", 101 | ] 102 | } 103 | print(evals) 104 | debug_suffix = "_DEBUG" if debug else "" 105 | 
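# The remainder of main() mirrors the driver loop in simple_evals.py: run every (sampler, eval) pair, write an HTML report
# and a JSON metrics file under /tmp, then load the per-run JSON files back and pivot them into a single pandas table
# (samplers as rows, evals as columns) that is printed as markdown.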
mergekey2resultpath = {} 106 | for sampler_name, sampler in samplers.items(): 107 | for eval_name, eval_obj in evals.items(): 108 | result = eval_obj(sampler) 109 | # ^^^ how to use a sampler 110 | file_stem = f"{eval_name}_{sampler_name}" 111 | report_filename = f"/tmp/{file_stem}{debug_suffix}.html" 112 | print(f"Writing report to {report_filename}") 113 | with open(report_filename, "w") as fh: 114 | fh.write(common.make_report(result)) 115 | metrics = result.metrics | {"score": result.score} 116 | print(metrics) 117 | result_filename = f"/tmp/{file_stem}{debug_suffix}.json" 118 | with open(result_filename, "w") as f: 119 | f.write(json.dumps(metrics, indent=2)) 120 | print(f"Writing results to {result_filename}") 121 | mergekey2resultpath[f"{file_stem}"] = result_filename 122 | merge_metrics = [] 123 | for eval_sampler_name, result_filename in mergekey2resultpath.items(): 124 | try: 125 | result = json.load(open(result_filename, "r+")) 126 | except Exception as e: 127 | print(e, result_filename) 128 | continue 129 | result = result.get("f1_score", result.get("score", None)) 130 | eval_name = eval_sampler_name[: eval_sampler_name.find("_")] 131 | sampler_name = eval_sampler_name[eval_sampler_name.find("_") + 1 :] 132 | merge_metrics.append( 133 | {"eval_name": eval_name, "sampler_name": sampler_name, "metric": result} 134 | ) 135 | merge_metrics_df = pd.DataFrame(merge_metrics).pivot( 136 | index=["sampler_name"], columns="eval_name" 137 | ) 138 | print("\nAll results: ") 139 | print(merge_metrics_df.to_markdown()) 140 | return merge_metrics 141 | 142 | 143 | if __name__ == "__main__": 144 | main() 145 | -------------------------------------------------------------------------------- /sampler/chat_completion_sampler.py: -------------------------------------------------------------------------------- 1 | import time 2 | from typing import Any 3 | 4 | import openai 5 | from openai import OpenAI 6 | 7 | from ..types import MessageList, SamplerBase, SamplerResponse 8 | 9 | OPENAI_SYSTEM_MESSAGE_API = "You are a helpful assistant." 10 | OPENAI_SYSTEM_MESSAGE_CHATGPT = ( 11 | "You are ChatGPT, a large language model trained by OpenAI, based on the GPT-4 architecture." 
12 | + "\nKnowledge cutoff: 2023-12\nCurrent date: 2024-04-01" 13 | ) 14 | 15 | 16 | class ChatCompletionSampler(SamplerBase): 17 | """ 18 | Sample from OpenAI's chat completion API 19 | """ 20 | 21 | def __init__( 22 | self, 23 | model: str = "gpt-3.5-turbo", 24 | system_message: str | None = None, 25 | temperature: float = 0.5, 26 | max_tokens: int = 1024, 27 | ): 28 | self.api_key_name = "OPENAI_API_KEY" 29 | self.client = OpenAI() 30 | # using api_key=os.environ.get("OPENAI_API_KEY") # please set your API_KEY 31 | self.model = model 32 | self.system_message = system_message 33 | self.temperature = temperature 34 | self.max_tokens = max_tokens 35 | self.image_format = "url" 36 | 37 | def _handle_image( 38 | self, 39 | image: str, 40 | encoding: str = "base64", 41 | format: str = "png", 42 | fovea: int = 768, 43 | ): 44 | new_image = { 45 | "type": "image_url", 46 | "image_url": { 47 | "url": f"data:image/{format};{encoding},{image}", 48 | }, 49 | } 50 | return new_image 51 | 52 | def _handle_text(self, text: str): 53 | return {"type": "text", "text": text} 54 | 55 | def _pack_message(self, role: str, content: Any): 56 | return {"role": str(role), "content": content} 57 | 58 | def __call__(self, message_list: MessageList) -> SamplerResponse: 59 | if self.system_message: 60 | message_list = [ 61 | self._pack_message("system", self.system_message) 62 | ] + message_list 63 | trial = 0 64 | while True: 65 | try: 66 | response = self.client.chat.completions.create( 67 | model=self.model, 68 | messages=message_list, 69 | temperature=self.temperature, 70 | max_tokens=self.max_tokens, 71 | ) 72 | content = response.choices[0].message.content 73 | if content is None: 74 | raise ValueError("OpenAI API returned empty response; retrying") 75 | return SamplerResponse( 76 | response_text=content, 77 | response_metadata={"usage": response.usage}, 78 | actual_queried_message_list=message_list, 79 | ) 80 | # NOTE: BadRequestError is triggered once for MMMU, please uncomment if you are reruning MMMU 81 | except openai.BadRequestError as e: 82 | print("Bad Request Error", e) 83 | return SamplerResponse( 84 | response_text="No response (bad request).", 85 | response_metadata={"usage": None}, 86 | actual_queried_message_list=message_list, 87 | ) 88 | except Exception as e: 89 | exception_backoff = 2**trial # expontial back off 90 | print( 91 | f"Rate limit exception so wait and retry {trial} after {exception_backoff} sec", 92 | e, 93 | ) 94 | time.sleep(exception_backoff) 95 | trial += 1 96 | # unknown error shall throw exception 97 | -------------------------------------------------------------------------------- /sampler/claude_sampler.py: -------------------------------------------------------------------------------- 1 | import time 2 | import os 3 | 4 | import anthropic 5 | 6 | from ..types import MessageList, SamplerBase, SamplerResponse 7 | from .. import common 8 | 9 | CLAUDE_SYSTEM_MESSAGE_LMSYS = ( 10 | "The assistant is Claude, created by Anthropic. The current date is " 11 | "{currentDateTime}. Claude's knowledge base was last updated in " 12 | "August 2023 and it answers user questions about events before " 13 | "August 2023 and after August 2023 the same way a highly informed " 14 | "individual from August 2023 would if they were talking to someone " 15 | "from {currentDateTime}. It should give concise responses to very " 16 | "simple questions, but provide thorough responses to more complex " 17 | "and open-ended questions. 
It is happy to help with writing, " 18 | "analysis, question answering, math, coding, and all sorts of other " 19 | "tasks. It uses markdown for coding. It does not mention this " 20 | "information about itself unless the information is directly " 21 | "pertinent to the human's query." 22 | ).format(currentDateTime="2024-04-01") 23 | # reference: https://github.com/lm-sys/FastChat/blob/7899355ebe32117fdae83985cf8ee476d2f4243f/fastchat/conversation.py#L894 24 | 25 | 26 | class ClaudeCompletionSampler(SamplerBase): 27 | 28 | def __init__( 29 | self, 30 | model: str, 31 | system_message: str | None = None, 32 | temperature: float = 0.0, # default in Anthropic example 33 | max_tokens: int = 4096, 34 | ): 35 | self.client = anthropic.Anthropic() 36 | self.api_key = os.environ.get("ANTHROPIC_API_KEY") # please set your API_KEY 37 | self.model = model 38 | self.system_message = system_message 39 | self.temperature = temperature 40 | self.max_tokens = max_tokens 41 | self.image_format = "base64" 42 | 43 | def _handle_image( 44 | self, 45 | image: str, 46 | encoding: str = "base64", 47 | format: str = "png", 48 | fovea: int = 768, 49 | ): 50 | new_image = { 51 | "type": "image", 52 | "source": { 53 | "type": encoding, 54 | "media_type": f"image/{format}", 55 | "data": image, 56 | }, 57 | } 58 | return new_image 59 | 60 | def _handle_text(self, text): 61 | return {"type": "text", "text": text} 62 | 63 | def _pack_message(self, role, content): 64 | return {"role": str(role), "content": content} 65 | 66 | def __call__(self, message_list: MessageList) -> SamplerResponse: 67 | trial = 0 68 | while True: 69 | try: 70 | if not common.has_only_user_assistant_messages(message_list): 71 | raise ValueError(f"Claude sampler only supports user and assistant messages, got {message_list}") 72 | if self.system_message: 73 | response_message = self.client.messages.create( 74 | model=self.model, 75 | system=self.system_message, 76 | max_tokens=self.max_tokens, 77 | temperature=self.temperature, 78 | messages=message_list, 79 | ) 80 | claude_input_messages: MessageList = [{"role": "system", "content": self.system_message}] + message_list 81 | else: 82 | response_message = self.client.messages.create( 83 | model=self.model, 84 | max_tokens=self.max_tokens, 85 | temperature=self.temperature, 86 | messages=message_list, 87 | ) 88 | claude_input_messages = message_list 89 | response_text = response_message.content[0].text 90 | return SamplerResponse( 91 | response_text=response_text, 92 | response_metadata={}, 93 | actual_queried_message_list=claude_input_messages, 94 | ) 95 | except anthropic.RateLimitError as e: 96 | exception_backoff = 2**trial # expontial back off 97 | print( 98 | f"Rate limit exception so wait and retry {trial} after {exception_backoff} sec", 99 | e, 100 | ) 101 | time.sleep(exception_backoff) 102 | trial += 1 103 | # unknown error shall throw exception 104 | -------------------------------------------------------------------------------- /sampler/o_chat_completion_sampler.py: -------------------------------------------------------------------------------- 1 | import time 2 | from typing import Any 3 | 4 | import openai 5 | from openai import OpenAI 6 | 7 | from ..types import MessageList, SamplerBase, SamplerResponse 8 | 9 | 10 | class OChatCompletionSampler(SamplerBase): 11 | """ 12 | Sample from OpenAI's chat completion API for o series models 13 | """ 14 | 15 | def __init__( 16 | self, 17 | *, 18 | reasoning_effort: str | None = None, 19 | model: str = "o1-mini", 20 | ): 21 | self.api_key_name = 
"OPENAI_API_KEY" 22 | self.client = OpenAI() 23 | # using api_key=os.environ.get("OPENAI_API_KEY") # please set your API_KEY 24 | self.model = model 25 | self.image_format = "url" 26 | self.reasoning_effort = reasoning_effort 27 | 28 | def _handle_image( 29 | self, 30 | image: str, 31 | encoding: str = "base64", 32 | format: str = "png", 33 | fovea: int = 768, 34 | ): 35 | new_image = { 36 | "type": "image_url", 37 | "image_url": { 38 | "url": f"data:image/{format};{encoding},{image}", 39 | }, 40 | } 41 | return new_image 42 | 43 | def _handle_text(self, text: str): 44 | return {"type": "text", "text": text} 45 | 46 | def _pack_message(self, role: str, content: Any): 47 | return {"role": str(role), "content": content} 48 | 49 | def __call__(self, message_list: MessageList) -> SamplerResponse: 50 | trial = 0 51 | while True: 52 | try: 53 | response = self.client.chat.completions.create( 54 | model=self.model, 55 | messages=message_list, 56 | reasoning_effort=self.reasoning_effort, 57 | ) 58 | content = response.choices[0].message.content 59 | return SamplerResponse( 60 | response_text=content, 61 | response_metadata={"usage": response.usage}, 62 | actual_queried_message_list=message_list, 63 | ) 64 | # NOTE: BadRequestError is triggered once for MMMU, please uncomment if you are reruning MMMU 65 | except openai.BadRequestError as e: 66 | print("Bad Request Error", e) 67 | return SamplerResponse( 68 | response_text="", 69 | response_metadata={"usage": None}, 70 | actual_queried_message_list=message_list, 71 | ) 72 | except Exception as e: 73 | exception_backoff = 2**trial # expontial back off 74 | print( 75 | f"Rate limit exception so wait and retry {trial} after {exception_backoff} sec", 76 | e, 77 | ) 78 | time.sleep(exception_backoff) 79 | trial += 1 80 | # unknown error shall throw exception 81 | -------------------------------------------------------------------------------- /sampler/responses_sampler.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | from typing import Any 4 | 5 | import openai 6 | from openai import OpenAI 7 | 8 | from ..types import MessageList, SamplerBase, SamplerResponse 9 | 10 | 11 | class ResponsesSampler(SamplerBase): 12 | """ 13 | Sample from OpenAI's responses API 14 | """ 15 | 16 | def __init__( 17 | self, 18 | model: str = "gpt-4.1", 19 | system_message: str | None = None, 20 | temperature: float = 0.5, 21 | max_tokens: int = 1024, 22 | reasoning_model: bool = False, 23 | reasoning_effort: str | None = None, 24 | ): 25 | self.api_key_name = "OPENAI_API_KEY" 26 | assert os.environ.get("OPENAI_API_KEY"), "Please set OPENAI_API_KEY" 27 | self.client = OpenAI() 28 | self.model = model 29 | self.system_message = system_message 30 | self.temperature = temperature 31 | self.max_tokens = max_tokens 32 | self.image_format = "url" 33 | self.reasoning_model = reasoning_model 34 | self.reasoning_effort = reasoning_effort 35 | 36 | def _handle_image( 37 | self, 38 | image: str, 39 | encoding: str = "base64", 40 | format: str = "png", 41 | fovea: int = 768, 42 | ) -> dict[str, Any]: 43 | new_image = { 44 | "type": "input_image", 45 | "image_url": f"data:image/{format};{encoding},{image}", 46 | } 47 | return new_image 48 | 49 | def _handle_text(self, text: str) -> dict[str, Any]: 50 | return {"type": "input_text", "text": text} 51 | 52 | def _pack_message(self, role: str, content: Any) -> dict[str, Any]: 53 | return {"role": role, "content": content} 54 | 55 | def __call__(self, message_list: MessageList) 
-> SamplerResponse: 56 | if self.system_message: 57 | message_list = [ 58 | self._pack_message("developer", self.system_message) 59 | ] + message_list 60 | trial = 0 61 | while True: 62 | try: 63 | if self.reasoning_model: 64 | reasoning = ( 65 | {"effort": self.reasoning_effort} 66 | if self.reasoning_effort 67 | else None 68 | ) 69 | response = self.client.responses.create( 70 | model=self.model, 71 | input=message_list, 72 | reasoning=reasoning, 73 | ) 74 | else: 75 | response = self.client.responses.create( 76 | model=self.model, 77 | input=message_list, 78 | temperature=self.temperature, 79 | max_output_tokens=self.max_tokens, 80 | ) 81 | return SamplerResponse( 82 | response_text=response.output_text, 83 | response_metadata={"usage": response.usage}, 84 | actual_queried_message_list=message_list, 85 | ) 86 | except openai.BadRequestError as e: 87 | print("Bad Request Error", e) 88 | return SamplerResponse( 89 | response_text="", 90 | response_metadata={"usage": None}, 91 | actual_queried_message_list=message_list, 92 | ) 93 | except Exception as e: 94 | exception_backoff = 2**trial # expontial back off 95 | print( 96 | f"Rate limit exception so wait and retry {trial} after {exception_backoff} sec", 97 | e, 98 | ) 99 | time.sleep(exception_backoff) 100 | trial += 1 101 | # unknown error shall throw exception 102 | -------------------------------------------------------------------------------- /simple_evals.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | import subprocess 4 | from datetime import datetime 5 | 6 | import pandas as pd 7 | 8 | from . import common 9 | from .browsecomp_eval import BrowseCompEval 10 | from .drop_eval import DropEval 11 | from .gpqa_eval import GPQAEval 12 | from .healthbench_eval import HealthBenchEval 13 | from .healthbench_meta_eval import HealthBenchMetaEval 14 | from .math_eval import MathEval 15 | from .mgsm_eval import MGSMEval 16 | from .mmlu_eval import MMLUEval 17 | from .humaneval_eval import HumanEval 18 | from .sampler.chat_completion_sampler import ( 19 | OPENAI_SYSTEM_MESSAGE_API, 20 | OPENAI_SYSTEM_MESSAGE_CHATGPT, 21 | ChatCompletionSampler, 22 | ) 23 | from .sampler.claude_sampler import ClaudeCompletionSampler, CLAUDE_SYSTEM_MESSAGE_LMSYS 24 | from .sampler.o_chat_completion_sampler import OChatCompletionSampler 25 | from .sampler.responses_sampler import ResponsesSampler 26 | from .simpleqa_eval import SimpleQAEval 27 | 28 | 29 | def main(): 30 | parser = argparse.ArgumentParser( 31 | description="Run sampling and evaluations using different samplers and evaluations." 32 | ) 33 | parser.add_argument( 34 | "--list-models", action="store_true", help="List available models" 35 | ) 36 | parser.add_argument( 37 | "--model", 38 | type=str, 39 | help="Select a model by name. Also accepts a comma-separated list of models.", 40 | ) 41 | parser.add_argument( 42 | "--eval", 43 | type=str, 44 | help="Select an eval by name. Also accepts a comma-separated list of evals.", 45 | ) 46 | parser.add_argument( 47 | "--n-repeats", 48 | type=int, 49 | default=None, 50 | help="Number of repeats to run. Only supported for certain evals.", 51 | ) 52 | parser.add_argument( 53 | "--n-threads", 54 | type=int, 55 | default=120, 56 | help="Number of threads to run. 
Only supported for HealthBench and HealthBenchMeta.", 57 | ) 58 | parser.add_argument("--debug", action="store_true", help="Run in debug mode") 59 | parser.add_argument( 60 | "--examples", type=int, help="Number of examples to use (overrides default)" 61 | ) 62 | 63 | args = parser.parse_args() 64 | 65 | models = { 66 | # Reasoning Models 67 | "o3": ResponsesSampler( 68 | model="o3-2025-04-16", 69 | reasoning_model=True, 70 | ), 71 | "o3-temp-1": ResponsesSampler( 72 | model="o3-2025-04-16", 73 | reasoning_model=True, 74 | temperature=1.0, 75 | ), 76 | "o3_high": ResponsesSampler( 77 | model="o3-2025-04-16", 78 | reasoning_model=True, 79 | reasoning_effort="high", 80 | ), 81 | "o3_low": ResponsesSampler( 82 | model="o3-2025-04-16", 83 | reasoning_model=True, 84 | reasoning_effort="low", 85 | ), 86 | # Default == Medium 87 | "o4-mini": ResponsesSampler( 88 | model="o4-mini-2025-04-16", 89 | reasoning_model=True, 90 | ), 91 | "o4-mini_high": ResponsesSampler( 92 | model="o4-mini-2025-04-16", 93 | reasoning_model=True, 94 | reasoning_effort="high", 95 | ), 96 | "o4-mini_low": ResponsesSampler( 97 | model="o4-mini-2025-04-16", 98 | reasoning_model=True, 99 | reasoning_effort="low", 100 | ), 101 | "o1-pro": ResponsesSampler( 102 | model="o1-pro", 103 | reasoning_model=True, 104 | ), 105 | "o1": OChatCompletionSampler( 106 | model="o1", 107 | ), 108 | "o1_high": OChatCompletionSampler( 109 | model="o1", 110 | reasoning_effort="high", 111 | ), 112 | "o1_low": OChatCompletionSampler( 113 | model="o1", 114 | reasoning_effort="low", 115 | ), 116 | "o1-preview": OChatCompletionSampler( 117 | model="o1-preview", 118 | ), 119 | "o1-mini": OChatCompletionSampler( 120 | model="o1-mini", 121 | ), 122 | # Default == Medium 123 | "o3-mini": OChatCompletionSampler( 124 | model="o3-mini", 125 | ), 126 | "o3-mini_high": OChatCompletionSampler( 127 | model="o3-mini", 128 | reasoning_effort="high", 129 | ), 130 | "o3-mini_low": OChatCompletionSampler( 131 | model="o3-mini", 132 | reasoning_effort="low", 133 | ), 134 | # GPT-4.1 models 135 | "gpt-4.1": ChatCompletionSampler( 136 | model="gpt-4.1-2025-04-14", 137 | system_message=OPENAI_SYSTEM_MESSAGE_API, 138 | max_tokens=2048, 139 | ), 140 | "gpt-4.1-temp-1": ChatCompletionSampler( 141 | model="gpt-4.1-2025-04-14", 142 | system_message=OPENAI_SYSTEM_MESSAGE_API, 143 | max_tokens=2048, 144 | temperature=1.0, 145 | ), 146 | "gpt-4.1-mini": ChatCompletionSampler( 147 | model="gpt-4.1-mini-2025-04-14", 148 | system_message=OPENAI_SYSTEM_MESSAGE_API, 149 | max_tokens=2048, 150 | ), 151 | "gpt-4.1-nano": ChatCompletionSampler( 152 | model="gpt-4.1-nano-2025-04-14", 153 | system_message=OPENAI_SYSTEM_MESSAGE_API, 154 | max_tokens=2048, 155 | ), 156 | # GPT-4o models 157 | "gpt-4o": ChatCompletionSampler( 158 | model="gpt-4o", 159 | system_message=OPENAI_SYSTEM_MESSAGE_API, 160 | max_tokens=2048, 161 | ), 162 | "gpt-4o-2024-11-20": ChatCompletionSampler( 163 | model="gpt-4o-2024-11-20", 164 | system_message=OPENAI_SYSTEM_MESSAGE_API, 165 | max_tokens=2048, 166 | ), 167 | "gpt-4o-2024-08-06": ChatCompletionSampler( 168 | model="gpt-4o-2024-08-06", 169 | system_message=OPENAI_SYSTEM_MESSAGE_API, 170 | max_tokens=2048, 171 | ), 172 | "gpt-4o-2024-08-06-temp-1": ChatCompletionSampler( 173 | model="gpt-4o-2024-08-06", 174 | system_message=OPENAI_SYSTEM_MESSAGE_API, 175 | max_tokens=2048, 176 | temperature=1.0, 177 | ), 178 | "gpt-4o-2024-05-13": ChatCompletionSampler( 179 | model="gpt-4o-2024-05-13", 180 | system_message=OPENAI_SYSTEM_MESSAGE_API, 181 | max_tokens=2048, 182 
| ), 183 | "gpt-4o-mini": ChatCompletionSampler( 184 | model="gpt-4o-mini-2024-07-18", 185 | system_message=OPENAI_SYSTEM_MESSAGE_API, 186 | max_tokens=2048, 187 | ), 188 | # GPT-4.5 model 189 | "gpt-4.5-preview": ChatCompletionSampler( 190 | model="gpt-4.5-preview-2025-02-27", 191 | system_message=OPENAI_SYSTEM_MESSAGE_API, 192 | max_tokens=2048, 193 | ), 194 | # GPT-4-turbo model 195 | "gpt-4-turbo-2024-04-09": ChatCompletionSampler( 196 | model="gpt-4-turbo-2024-04-09", 197 | system_message=OPENAI_SYSTEM_MESSAGE_API, 198 | ), 199 | # GPT-4 model 200 | "gpt-4-0613": ChatCompletionSampler( 201 | model="gpt-4-0613", 202 | system_message=OPENAI_SYSTEM_MESSAGE_API, 203 | ), 204 | # GPT-3.5 Turbo model 205 | "gpt-3.5-turbo-0125": ChatCompletionSampler( 206 | model="gpt-3.5-turbo-0125", 207 | system_message=OPENAI_SYSTEM_MESSAGE_API, 208 | ), 209 | "gpt-3.5-turbo-0125-temp-1": ChatCompletionSampler( 210 | model="gpt-3.5-turbo-0125", 211 | system_message=OPENAI_SYSTEM_MESSAGE_API, 212 | temperature=1.0, 213 | ), 214 | # Chatgpt models: 215 | "chatgpt-4o-latest": ChatCompletionSampler( 216 | model="chatgpt-4o-latest", 217 | system_message=OPENAI_SYSTEM_MESSAGE_CHATGPT, 218 | max_tokens=2048, 219 | ), 220 | "gpt-4-turbo-2024-04-09_chatgpt": ChatCompletionSampler( 221 | model="gpt-4-turbo-2024-04-09", 222 | system_message=OPENAI_SYSTEM_MESSAGE_CHATGPT, 223 | ), 224 | # Claude models: 225 | "claude-3-opus-20240229_empty": ClaudeCompletionSampler( 226 | model="claude-3-opus-20240229", 227 | system_message=CLAUDE_SYSTEM_MESSAGE_LMSYS, 228 | ), 229 | "claude-3-7-sonnet-20250219": ClaudeCompletionSampler( 230 | model="claude-3-7-sonnet-20250219", 231 | system_message=CLAUDE_SYSTEM_MESSAGE_LMSYS, 232 | ), 233 | "claude-3-haiku-20240307": ClaudeCompletionSampler( 234 | model="claude-3-haiku-20240307", 235 | ), 236 | } 237 | 238 | if args.list_models: 239 | print("Available models:") 240 | for model_name in models.keys(): 241 | print(f" - {model_name}") 242 | return 243 | 244 | if args.model: 245 | models_chosen = args.model.split(",") 246 | for model_name in models_chosen: 247 | if model_name not in models: 248 | print(f"Error: Model '{model_name}' not found.") 249 | return 250 | models = {model_name: models[model_name] for model_name in models_chosen} 251 | 252 | print(f"Running with args {args}") 253 | 254 | grading_sampler = ChatCompletionSampler( 255 | model="gpt-4.1-2025-04-14", 256 | system_message=OPENAI_SYSTEM_MESSAGE_API, 257 | max_tokens=2048, 258 | ) 259 | equality_checker = ChatCompletionSampler(model="gpt-4-turbo-preview") 260 | # ^^^ used for fuzzy matching, just for math 261 | 262 | def get_evals(eval_name, debug_mode): 263 | num_examples = ( 264 | args.examples if args.examples is not None else (5 if debug_mode else None) 265 | ) 266 | # Set num_examples = None to reproduce full evals 267 | match eval_name: 268 | case "mmlu": 269 | return MMLUEval(num_examples=1 if debug_mode else num_examples) 270 | case "math": 271 | return MathEval( 272 | equality_checker=equality_checker, 273 | num_examples=num_examples, 274 | n_repeats=1 if debug_mode else args.n_repeats or 10, 275 | ) 276 | case "gpqa": 277 | return GPQAEval( 278 | n_repeats=1 if debug_mode else args.n_repeats or 10, 279 | num_examples=num_examples, 280 | ) 281 | case "mgsm": 282 | return MGSMEval( 283 | num_examples_per_lang=10 if debug_mode else num_examples or 250 284 | ) 285 | case "drop": 286 | return DropEval( 287 | num_examples=10 if debug_mode else num_examples, 288 | train_samples_per_prompt=3, 289 | ) 290 | case 
"humaneval": 291 | return HumanEval(num_examples=10 if debug_mode else num_examples) 292 | case "simpleqa": 293 | return SimpleQAEval( 294 | grader_model=grading_sampler, 295 | num_examples=10 if debug_mode else num_examples, 296 | ) 297 | case "browsecomp": 298 | return BrowseCompEval( 299 | grader_model=grading_sampler, 300 | num_examples=10 if debug_mode else num_examples, 301 | ) 302 | case "healthbench": 303 | return HealthBenchEval( 304 | grader_model=grading_sampler, 305 | num_examples=10 if debug_mode else num_examples, 306 | n_repeats=args.n_repeats or 1, 307 | n_threads=args.n_threads or 1, 308 | subset_name=None, 309 | ) 310 | case "healthbench_hard": 311 | return HealthBenchEval( 312 | grader_model=grading_sampler, 313 | num_examples=10 if debug_mode else num_examples, 314 | n_repeats=args.n_repeats or 1, 315 | n_threads=args.n_threads or 1, 316 | subset_name="hard", 317 | ) 318 | case "healthbench_consensus": 319 | return HealthBenchEval( 320 | grader_model=grading_sampler, 321 | num_examples=10 if debug_mode else num_examples, 322 | n_repeats=args.n_repeats or 1, 323 | n_threads=args.n_threads or 1, 324 | subset_name="consensus", 325 | ) 326 | case "healthbench_meta": 327 | return HealthBenchMetaEval( 328 | grader_model=grading_sampler, 329 | num_examples=10 if debug_mode else num_examples, 330 | n_repeats=args.n_repeats or 1, 331 | n_threads=args.n_threads or 1, 332 | ) 333 | case _: 334 | raise Exception(f"Unrecognized eval type: {eval_name}") 335 | 336 | if args.eval: 337 | evals_list = args.eval.split(",") 338 | evals = {} 339 | for eval_name in evals_list: 340 | try: 341 | evals[eval_name] = get_evals(eval_name, args.debug) 342 | except Exception: 343 | print(f"Error: eval '{eval_name}' not found.") 344 | return 345 | else: 346 | evals = { 347 | eval_name: get_evals(eval_name, args.debug) 348 | for eval_name in [ 349 | "mmlu", 350 | "math", 351 | "gpqa", 352 | "mgsm", 353 | "drop", 354 | "humaneval", 355 | "simpleqa", 356 | "browsecomp", 357 | "healthbench", 358 | "healthbench_hard", 359 | "healthbench_consensus", 360 | "healthbench_meta", 361 | ] 362 | } 363 | 364 | print(evals) 365 | debug_suffix = "_DEBUG" if args.debug else "" 366 | print(debug_suffix) 367 | mergekey2resultpath = {} 368 | print(f"Running the following evals: {list(evals.keys())}") 369 | print(f"Running evals for the following models: {list(models.keys())}") 370 | 371 | now = datetime.now() 372 | date_str = now.strftime("%Y%m%d_%H%M%S") 373 | for model_name, sampler in models.items(): 374 | for eval_name, eval_obj in evals.items(): 375 | result = eval_obj(sampler) 376 | # ^^^ how to use a sampler 377 | file_stem = f"{eval_name}_{model_name}" 378 | # file stem should also include the year, month, day, and time in hours and minutes 379 | file_stem += f"_{date_str}" 380 | report_filename = f"/tmp/{file_stem}{debug_suffix}.html" 381 | print(f"Writing report to {report_filename}") 382 | with open(report_filename, "w") as fh: 383 | fh.write(common.make_report(result)) 384 | assert result.metrics is not None 385 | metrics = result.metrics | {"score": result.score} 386 | # Sort metrics by key 387 | metrics = dict(sorted(metrics.items())) 388 | print(metrics) 389 | result_filename = f"/tmp/{file_stem}{debug_suffix}.json" 390 | with open(result_filename, "w") as f: 391 | f.write(json.dumps(metrics, indent=2)) 392 | print(f"Writing results to {result_filename}") 393 | 394 | full_result_filename = f"/tmp/{file_stem}{debug_suffix}_allresults.json" 395 | with open(full_result_filename, "w") as f: 396 | result_dict 
= { 397 | "score": result.score, 398 | "metrics": result.metrics, 399 | "htmls": result.htmls, 400 | "convos": result.convos, 401 | "metadata": result.metadata, 402 | } 403 | f.write(json.dumps(result_dict, indent=2)) 404 | print(f"Writing all results to {full_result_filename}") 405 | 406 | mergekey2resultpath[f"{file_stem}"] = result_filename 407 | merge_metrics = [] 408 | for eval_model_name, result_filename in mergekey2resultpath.items(): 409 | try: 410 | result = json.load(open(result_filename, "r+")) 411 | except Exception as e: 412 | print(e, result_filename) 413 | continue 414 | result = result.get("f1_score", result.get("score", None)) 415 | eval_name = eval_model_name[: eval_model_name.find("_")] 416 | model_name = eval_model_name[eval_model_name.find("_") + 1 :] 417 | merge_metrics.append( 418 | {"eval_name": eval_name, "model_name": model_name, "metric": result} 419 | ) 420 | merge_metrics_df = pd.DataFrame(merge_metrics).pivot( 421 | index=["model_name"], columns="eval_name" 422 | ) 423 | print("\nAll results: ") 424 | print(merge_metrics_df.to_markdown()) 425 | return merge_metrics 426 | 427 | 428 | if __name__ == "__main__": 429 | main() 430 | -------------------------------------------------------------------------------- /simpleqa_eval.py: -------------------------------------------------------------------------------- 1 | """ 2 | SimpleQA: Measuring short-form factuality in large language models 3 | Authors: Jason Wei, Nguyen Karina, Hyung Won Chung, Yunxin Joy Jiao, Spencer Papay, Amelia Glaese, John Schulman, William Fedus 4 | https://cdn.openai.com/papers/simpleqa.pdf 5 | """ 6 | 7 | import random 8 | import re 9 | import pandas 10 | from . import common 11 | from .types import Eval, EvalResult, SamplerBase, SingleEvalResult 12 | 13 | GRADER_TEMPLATE = """ 14 | Your job is to look at a question, a gold target, and a predicted answer, and then assign a grade of either ["CORRECT", "INCORRECT", "NOT_ATTEMPTED"]. 15 | First, I will give examples of each grade, and then you will grade a new example. 16 | 17 | 18 | The following are examples of CORRECT predicted answers. 19 | ``` 20 | Question: What are the names of Barack Obama's children? 21 | Gold target: Malia Obama and Sasha Obama 22 | Predicted answer 1: sasha and malia obama 23 | Predicted answer 2: most people would say Malia and Sasha, but I'm not sure and would have to double check 24 | Predicted answer 3: Barack Obama has two daughters. Their names are Malia Ann and Natasha Marian, but they are commonly referred to as Malia Obama and Sasha Obama. Malia was born on July 4, 1998, and Sasha was born on June 10, 2001. 25 | ``` 26 | These predicted answers are all CORRECT because: 27 | - They fully contain the important information in the gold target. 28 | - They do not contain any information that contradicts the gold target. 29 | - Only semantic meaning matters; capitalization, punctuation, grammar, and order don't matter. 30 | - Hedging and guessing are permissible, provided that the gold target is fully included and the response contains no incorrect information or contradictions. 31 | 32 | 33 | The following are examples of INCORRECT predicted answers. 34 | ``` 35 | Question: What are the names of Barack Obama's children? 36 | Gold target: Malia and Sasha 37 | Predicted answer 1: Malia. 38 | Predicted answer 2: Malia, Sasha, and Susan. 39 | Predicted answer 3: Barack Obama does not have any children. 40 | Predicted answer 4: I think it's either Malia and Sasha. Or it could be Malia and Jackie. 
Or it could be Joey and Malia. 41 | Predicted answer 5: While I don't know their exact names, I can tell you that Barack Obama has three children. 42 | Predicted answer 6: It's possible you may mean Betsy and Olivia. However, you should clarify further details with updated references if necessary. Is that the correct answer? 43 | Predicted answer 7: It may be the case that Obama's child is named James. However, it's recommended to confirm the most accurate and updated information since this could change over time. This model may not always reflect the most current information. 44 | ``` 45 | These predicted answers are all INCORRECT because: 46 | - A factual statement in the answer contradicts the gold target. Incorrect statements that have some hedging (e.g., "it is possible that", "although i'm not sure, i think") are also considered incorrect. 47 | 48 | 49 | The following are examples of NOT_ATTEMPTED predicted answers. 50 | ``` 51 | Question: What are the names of Barack Obama's children? 52 | Gold target: Malia and Sasha 53 | Predicted answer 1: I don't know. 54 | Predicted answer 2: I need more context about which Obama you are talking about. 55 | Predicted answer 3: Without researching the web, I cannot answer this question. However, I can tell you that Barack Obama has two children. 56 | Predicted answer 4: Barack Obama has two children. I know that one of them is Malia, but I'm not sure about the other one. 57 | ``` 58 | These predicted answers are all NOT_ATTEMPTED because: 59 | - The important information in the gold target is not included in the answer. 60 | - No statements in the answer contradict the gold target. 61 | 62 | 63 | Also note the following things: 64 | - For grading questions where the gold target is a number, the predicted answer needs to be correct to the last significant figure in the gold answer. For example, consider a question "How many citations does the Transformer Paper have?" with gold target "120k". 65 | - Predicted answers "120k", "124k", and "115k" are all CORRECT. 66 | - Predicted answers "100k" and "113k" are INCORRECT. 67 | - Predicted answers "around 100k" and "more than 50k" are considered NOT_ATTEMPTED because they neither confirm nor contradict the gold target. 68 | - The gold target may contain more information than the question. In such cases, the predicted answer only needs to contain the information that is in the question. 69 | - For example, consider the question "What episode did Derek and Meredith get legally married in Grey's Anatomy?" with gold target "Season 7, Episode 20: White Wedding". Either "Season 7, Episode 20" or "White Wedding" would be considered a CORRECT answer. 70 | - Do not punish predicted answers if they omit information that would be clearly inferred from the question. 71 | - For example, consider the question "What city is OpenAI headquartered in?" and the gold target "San Francisco, California". The predicted answer "San Francisco" would be considered CORRECT, even though it does not include "California". 72 | - Consider the question "What award did A pretrainer's guide to training data: Measuring the effects of data age, domain coverage, quality, & toxicity win at NAACL '24?", the gold target is "Outstanding Paper Award". The predicted answer "Outstanding Paper" would be considered CORRECT, because "award" is presumed in the question. 73 | - For the question "What is the height of Jason Wei in meters?", the gold target is "1.73 m". 
The predicted answer "1.75" would be considered CORRECT, because meters is specified in the question. 74 | - For the question "What is the name of Barack Obama's wife?", the gold target is "Michelle Obama". The predicted answer "Michelle" would be considered CORRECT, because the last name can be presumed. 75 | - Do not punish for typos in people's name if it's clearly the same name. 76 | - For example, if the gold target is "Hyung Won Chung", you can consider the following predicted answers as correct: "Hyoong Won Choong", "Hyungwon Chung", or "Hyun Won Chung". 77 | 78 | 79 | Here is a new example. Simply reply with either CORRECT, INCORRECT, NOT ATTEMPTED. Don't apologize or correct yourself if there was a mistake; we are just trying to grade the answer. 80 | ``` 81 | Question: {question} 82 | Gold target: {target} 83 | Predicted answer: {predicted_answer} 84 | ``` 85 | 86 | Grade the predicted answer of this new question as one of: 87 | A: CORRECT 88 | B: INCORRECT 89 | C: NOT_ATTEMPTED 90 | 91 | Just return the letters "A", "B", or "C", with no text around it. 92 | """.strip() 93 | 94 | 95 | CHOICE_LETTERS = ["A", "B", "C"] 96 | CHOICE_STRINGS = ["CORRECT", "INCORRECT", "NOT_ATTEMPTED"] 97 | CHOICE_LETTER_TO_STRING = dict(zip(CHOICE_LETTERS, CHOICE_STRINGS)) 98 | 99 | class SimpleQAEval(Eval): 100 | def __init__(self, grader_model: SamplerBase, num_examples: int | None = None, n_repeats: int = 1): 101 | df = pandas.read_csv( 102 | "https://openaipublic.blob.core.windows.net/simple-evals/simple_qa_test_set.csv" 103 | ) 104 | examples = [row.to_dict() for _, row in df.iterrows()] 105 | if num_examples: 106 | assert n_repeats == 1, "n_repeats only supported when max_examples = None" 107 | rng = random.Random(0) 108 | examples = rng.sample(examples, num_examples) 109 | self.examples = examples * n_repeats 110 | self.grader_model = grader_model 111 | 112 | def grade_sample(self, question: str, target: str, predicted_answer: str) -> str: 113 | grader_prompt = GRADER_TEMPLATE.format( 114 | question=question, 115 | target=target, 116 | predicted_answer=predicted_answer, 117 | ) 118 | 119 | prompt_messages = [ 120 | self.grader_model._pack_message(content=grader_prompt, role="user") 121 | ] 122 | sampler_response = self.grader_model(prompt_messages) 123 | grading_response = sampler_response.response_text 124 | 125 | match = re.search(r"(A|B|C)", grading_response) 126 | return match.group(0) if match else "C" # Default to "NOT_ATTEMPTED" if no match 127 | 128 | def __call__(self, sampler: SamplerBase) -> EvalResult: 129 | def fn(row: dict): 130 | prompt_messages = [ 131 | sampler._pack_message(content=row.get("problem", ""), role="user") 132 | ] 133 | sampler_response = sampler(prompt_messages) 134 | response_text = sampler_response.response_text 135 | actual_queried_prompt_messages = sampler_response.actual_queried_message_list 136 | grade_letter = self.grade_sample(row.get("problem", ""), row.get("answer", ""), response_text) 137 | 138 | # Metrics based on grading response 139 | is_correct = grade_letter == "A" 140 | is_incorrect = grade_letter == "B" 141 | is_not_attempted = grade_letter == "C" 142 | 143 | score = is_correct 144 | 145 | # Create HTML for each sample result 146 | html = common.jinja_env.from_string(common.HTML_JINJA).render( 147 | prompt_messages=actual_queried_prompt_messages, 148 | next_message=dict(content=response_text, role="assistant"), 149 | score=score, 150 | correct_answer=row["answer"], 151 | extracted_answer=response_text, 152 | ) 153 | convo = 
actual_queried_prompt_messages + [dict(content=response_text, role="assistant")] 154 | return SingleEvalResult(html=html, score=score, convo=convo, metrics={ 155 | "is_correct": is_correct, 156 | "is_incorrect": is_incorrect, 157 | "is_not_attempted": is_not_attempted 158 | }) 159 | 160 | # Run evaluation and collect results 161 | results = common.map_with_progress(fn, self.examples) 162 | 163 | # Aggregate metrics 164 | aggregate_metrics = { 165 | "is_correct": sum(result.metrics["is_correct"] for result in results) / len(results), 166 | "is_incorrect": sum(result.metrics["is_incorrect"] for result in results) / len(results), 167 | "is_not_attempted": sum(result.metrics["is_not_attempted"] for result in results) / len(results), 168 | } 169 | aggregate_metrics["is_given_attempted"] = aggregate_metrics["is_correct"] + aggregate_metrics["is_incorrect"] 170 | # Calculate accuracy_given_attempted 171 | aggregate_metrics["accuracy_given_attempted"] = ( 172 | aggregate_metrics["is_correct"] 173 | / aggregate_metrics["is_given_attempted"] 174 | if aggregate_metrics["is_given_attempted"] > 0 175 | else 0 176 | ) 177 | print("AGGREGATE METRICS") 178 | print(aggregate_metrics) 179 | print("##################") 180 | 181 | output_d = { 182 | "accuracy_given_attempted": aggregate_metrics["accuracy_given_attempted"], 183 | "f1": ( 184 | 2 * aggregate_metrics["accuracy_given_attempted"] * aggregate_metrics["is_correct"] 185 | / (aggregate_metrics["accuracy_given_attempted"] + aggregate_metrics["is_correct"]) 186 | if (aggregate_metrics["accuracy_given_attempted"] + aggregate_metrics["is_correct"]) > 0 187 | else 0 188 | ) 189 | } 190 | 191 | print(f"Accuracy Given Attempted: {output_d['accuracy_given_attempted']:.3f}") 192 | print(f"F1 Score: {output_d['f1']:.3f}") 193 | 194 | return common.aggregate_results(results) 195 | 196 | 197 | -------------------------------------------------------------------------------- /types.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass, field 2 | from typing import Any, Literal, overload 3 | 4 | Message = dict[str, Any] # keys role, content 5 | MessageList = list[Message] 6 | 7 | 8 | 9 | @dataclass 10 | class SamplerResponse: 11 | """ 12 | Response from a sampler. 13 | """ 14 | response_text: str 15 | actual_queried_message_list: MessageList 16 | response_metadata: dict[str, Any] 17 | 18 | class SamplerBase: 19 | """ 20 | Base class for defining a sampling model, which can be evaluated, 21 | or used as part of the grading process. 
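Subclasses implement __call__, which takes a MessageList (a list of {"role": ..., "content": ...} dicts) and returns a SamplerResponse; see the implementations under sampler/.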
22 | """ 23 | 24 | def __call__( 25 | self, 26 | message_list: MessageList, 27 | ) -> SamplerResponse: 28 | raise NotImplementedError 29 | 30 | 31 | @dataclass 32 | class EvalResult: 33 | """ 34 | Result of running an evaluation (usually consisting of many samples) 35 | """ 36 | 37 | score: float | None # top-line metric 38 | metrics: dict[str, float] | None # other metrics 39 | htmls: list[str] # strings of valid HTML 40 | convos: list[MessageList] # sampled conversations 41 | metadata: dict[str, Any] | None # Extra data such as rubric scores or sollen 42 | 43 | 44 | @dataclass 45 | class SingleEvalResult: 46 | """ 47 | Result of evaluating a single sample 48 | """ 49 | 50 | score: float | None 51 | metrics: dict[str, float] = field(default_factory=dict) 52 | html: str | None = None 53 | convo: MessageList | None = None # sampled conversation 54 | example_level_metadata: dict[str, Any] | None = ( 55 | None # Extra data such as rubric scores or sollen 56 | ) 57 | 58 | 59 | class Eval: 60 | """ 61 | Base class for defining an evaluation. 62 | """ 63 | 64 | def __call__(self, sampler: SamplerBase) -> EvalResult: 65 | raise NotImplementedError 66 | 67 | --------------------------------------------------------------------------------