<div>) for a message.
258 | """
259 | return jinja_env.from_string(_message_template).render(
260 | role=message["role"],
261 | content=message["content"],
262 | variant=message.get("variant", None),
263 | )
264 |
265 |
266 | jinja_env.globals["message_to_html"] = message_to_html
267 |
268 |
269 | _report_template = """
270 | <!DOCTYPE html>
271 | <html>
272 |     <head>
304 |     </head>
305 |     <body>
306 |     {% if metrics %}
307 |     <h1>Metrics</h1>
308 |     <table>
309 |     <tr>
310 |         <th>Metric</th>
311 |         <th>Value</th>
312 |     </tr>
313 |     <tr>
314 |         <td>Score</td>
315 |         <td>{{ score | float | round(3) }}</td>
316 |     </tr>
317 |     {% for name, value in metrics.items() %}
318 |     <tr>
319 |         <td>{{ name }}</td>
320 |         <td>{{ value }}</td>
321 |     </tr>
322 |     {% endfor %}
323 |     </table>
324 |     {% endif %}
325 |     <h1>Examples</h1>
326 |     {% for html in htmls %}
327 |     {{ html | safe }}
328 |     <hr>
329 |     {% endfor %}
330 |     </body>
331 | </html>
332 | """
333 |
334 |
335 | def make_report(eval_result: EvalResult) -> str:
336 | """
337 | Create a standalone HTML report from an EvalResult.
338 | """
339 | return jinja_env.from_string(_report_template).render(
340 | score=eval_result.score,
341 | metrics=eval_result.metrics,
342 | htmls=eval_result.htmls,
343 | )
344 |
345 |
346 | def make_report_from_example_htmls(htmls: list[str]):
347 | """
348 | Create a standalone HTML report from a list of example htmls
349 | """
350 | return jinja_env.from_string(_report_template).render(
351 | score=None, metrics={}, htmls=htmls
352 | )
353 |
354 |
355 | def normalize_response(response: str) -> str:
356 | """
357 | Normalize the response by removing markdown and LaTeX formatting that may prevent a match.
358 | """
359 |
360 | return (
361 | response.replace("**", "")
362 | .replace("$\\boxed{", "")
363 | .replace("}$", "")
364 | .replace("\\$", "")
365 | .replace("$\\text{", "")
366 | .replace("$", "")
367 | .replace("\\mathrm{", "")
368 | .replace("\\{", "")
369 | .replace("\\text", "")
370 | .replace("\\(", "")
371 | .replace("\\mathbf{", "")
372 | .replace("{", "")
373 | .replace("\\boxed", "")
374 | )
375 |
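# Illustrative usage sketch added for exposition (not part of the original file):
# normalize_response strips common markdown/LaTeX wrappers so downstream answer
# matching can work on plain text.
def _example_normalize_response_usage():
    assert normalize_response("**Answer:** $\\boxed{42}$") == "Answer: 42"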
376 |
377 | def normalize_extracted_answer(extracted_answer: str) -> str:
378 | return (
379 | # In arabic these are the letters used for A-D in multiple choice questions
380 | extracted_answer.replace("أ", " A")
381 | .replace("ب", " B")
382 | .replace("ج", " C")
383 | .replace("د", " D")
384 | # In Bengali these are the letters used for A-D in multiple choice questions
385 | .replace("অ", " A")
386 | .replace("ব", " B")
387 | .replace("ড", " C")
388 | .replace("ঢ", " D")
389 | # In Japanese these are the letters sometimes used for A-D in multiple choice questions
390 | .replace("A", " A")
391 | .replace("B", " B")
392 | .replace("C", " C")
393 | .replace("D", " D")
394 | .strip()
395 | )
396 |
397 |
398 | def url_to_fileobj(url: str, binary=False) -> Any:
399 | response = requests.get(url)
400 | response.raise_for_status()
401 | return io.BytesIO(response.content) if binary else io.StringIO(response.text)
402 |
403 |
404 | def has_only_user_assistant_messages(messages: list[Message]) -> bool:
405 | """
406 | Check if the messages only contain user and assistant messages.
407 | """
408 | return all(m["role"] in ("user", "assistant") for m in messages)
409 |
--------------------------------------------------------------------------------
/drop_eval.py:
--------------------------------------------------------------------------------
1 | """
2 | DROP: A Reading Comprehension Benchmark Requiring Discrete Reasoning Over Paragraphs
3 | Dheeru Dua, Yizhong Wang, Pradeep Dasigi, Gabriel Stanovsky, Sameer Singh, Matt Gardner
4 | https://arxiv.org/abs/1903.00161
5 | """
6 |
7 | import gzip
8 | import json
9 | import random
10 | import re
11 | import string
12 | from typing import Any, Dict, List, Optional, Set, Tuple, Union
13 |
14 | import numpy as np
15 | from scipy.optimize import linear_sum_assignment
16 |
17 | from . import common
18 | from .common import ANSWER_PATTERN, HTML_JINJA
19 | from .types import Eval, EvalResult, SamplerBase, SingleEvalResult
20 |
21 | """
22 | From here through _normalize_answer was originally copied from:
23 | https://worksheets.codalab.org/rest/bundles/0x6b567e1cf2e041ec80d7098f031c5c9e/contents/blob/
24 | Then cleaned up and modified a bit.
25 |
26 | The rest was originally copied from https://github.com/allenai/allennlp-reading-comprehension/blob/master/allennlp_rc
27 | /eval/drop_eval.py
28 | """
29 |
30 |
31 | def _remove_articles(text: str) -> str:
32 | regex = re.compile(r"\b(a|an|the)\b", re.UNICODE)
33 | return re.sub(regex, " ", text)
34 |
35 |
36 | def _white_space_fix(text: str) -> str:
37 | return " ".join(text.split())
38 |
39 |
40 | EXCLUDE = set(string.punctuation)
41 |
42 |
43 | def _remove_punc(text: str) -> str:
44 | if not _is_number(text):
45 | return "".join(ch for ch in text if ch not in EXCLUDE)
46 | else:
47 | return text
48 |
49 |
50 | def _lower(text: str) -> str:
51 | return text.lower()
52 |
53 |
54 | def _tokenize(text: str) -> List[str]:
55 | return re.split(" |-", text)
56 |
57 |
58 | def _normalize_answer(text: str) -> str:
59 | """Lower text and remove punctuation, articles and extra whitespace."""
60 |
61 | parts = [
62 | _white_space_fix(_remove_articles(_normalize_number(_remove_punc(_lower(token)))))
63 | for token in _tokenize(text)
64 | ]
65 | parts = [part for part in parts if part.strip()]
66 | normalized = " ".join(parts).strip()
67 | return normalized
68 |
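# Illustrative sketch added for exposition (not part of the original file): the
# normalization lowercases, drops punctuation and articles, splits on spaces and
# hyphens, and canonicalizes numbers via float().
def _example_normalize_answer_usage():
    assert _normalize_answer("The 12-yard line!") == "12.0 yard line"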
69 |
70 | def _is_number(text: str) -> bool:
71 | try:
72 | float(text)
73 | return True
74 | except ValueError:
75 | return False
76 |
77 |
78 | def _normalize_number(text: str) -> str:
79 | if _is_number(text):
80 | return str(float(text))
81 | else:
82 | return text
83 |
84 |
85 | def _answer_to_bags(
86 | answer: Union[str, List[str], Tuple[str, ...]]
87 | ) -> Tuple[List[str], List[Set[str]]]:
88 | if isinstance(answer, (list, tuple)):
89 | raw_spans = answer
90 | else:
91 | raw_spans = [answer]
92 | normalized_spans: List[str] = []
93 | token_bags = []
94 | for raw_span in raw_spans:
95 | normalized_span = _normalize_answer(raw_span)
96 | normalized_spans.append(normalized_span)
97 | token_bags.append(set(normalized_span.split()))
98 | return normalized_spans, token_bags
99 |
100 |
101 | def _align_bags(predicted: List[Set[str]], gold: List[Set[str]]) -> List[float]:
102 | """
103 | Takes gold and predicted answer sets and first finds the optimal 1-1 alignment
104 | between them and gets maximum metric values over all the answers.
105 | """
106 | scores = np.zeros([len(gold), len(predicted)])
107 | for gold_index, gold_item in enumerate(gold):
108 | for pred_index, pred_item in enumerate(predicted):
109 | if _match_numbers_if_present(gold_item, pred_item):
110 | scores[gold_index, pred_index] = _compute_f1(pred_item, gold_item)
111 | row_ind, col_ind = linear_sum_assignment(-scores)
112 |
113 | max_scores = np.zeros([max(len(gold), len(predicted))])
114 | for row, column in zip(row_ind, col_ind):
115 | max_scores[row] = max(max_scores[row], scores[row, column])
116 | return max_scores
117 |
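# Illustrative sketch added for exposition (not part of the original file): with two
# predicted and two gold token bags, the assignment pairs the numeric bags together
# (perfect F1) and the text bags together (partial F1).
def _example_align_bags_usage():
    scores = _align_bags(predicted=[{"tom", "brady"}, {"12.0"}], gold=[{"brady"}, {"12.0"}])
    assert scores[1] == 100.0 and 66.0 < scores[0] < 67.0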
118 |
119 | def _compute_f1(predicted_bag: Set[str], gold_bag: Set[str]) -> float:
120 | intersection = len(gold_bag.intersection(predicted_bag))
121 | if not predicted_bag:
122 | precision = 1.0
123 | else:
124 | precision = intersection / float(len(predicted_bag))
125 | if not gold_bag:
126 | recall = 1.0
127 | else:
128 | recall = intersection / float(len(gold_bag))
129 | f1 = (
130 | (2 * precision * recall) / (precision + recall)
131 | if not (precision == 0.0 and recall == 0.0)
132 | else 0.0
133 | ) * 100
134 | return f1
135 |
136 |
137 | def _match_numbers_if_present(gold_bag: Set[str], predicted_bag: Set[str]) -> bool:
138 | gold_numbers = set()
139 | predicted_numbers = set()
140 | for word in gold_bag:
141 | if _is_number(word):
142 | gold_numbers.add(word)
143 | for word in predicted_bag:
144 | if _is_number(word):
145 | predicted_numbers.add(word)
146 | if (not gold_numbers) or gold_numbers.intersection(predicted_numbers):
147 | return True
148 | return False
149 |
150 |
151 | def get_drop_metrics(
152 | predicted: Union[str, List[str], Tuple[str, ...]], gold: Union[str, List[str], Tuple[str, ...]]
153 | ) -> Tuple[float, float]:
154 | """
155 | Takes a predicted answer and a gold answer (that are both either a string or a list of
156 | strings), and returns exact match and the DROP F1 metric for the prediction. If you are
157 | writing a script for evaluating objects in memory (say, the output of predictions during
158 | validation, or while training), this is the function you want to call, after using
159 | :func:`answer_json_to_strings` when reading the gold answer from the released data file.
160 | """
161 | predicted_bags = _answer_to_bags(predicted)
162 | gold_bags = _answer_to_bags(gold)
163 |
164 | if set(predicted_bags[0]) == set(gold_bags[0]) and len(predicted_bags[0]) == len(gold_bags[0]):
165 | exact_match = 1.0
166 | else:
167 | exact_match = 0.0
168 |
169 | f1_per_bag = _align_bags(predicted_bags[1], gold_bags[1])
170 | f1 = np.mean(f1_per_bag)
171 | f1 = round(f1, 2)
172 | return exact_match, f1
173 |
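# Illustrative sketch added for exposition (not part of the original file): numbers
# are normalized before comparison, so "12" and "12.0" count as an exact match,
# with F1 reported on a 0-100 scale.
def _example_get_drop_metrics_usage():
    em, f1 = get_drop_metrics("12", "12.0")
    assert em == 1.0 and f1 == 100.0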
174 |
175 | def answer_json_to_strings(answer: Dict[str, Any]) -> Tuple[Tuple[str, ...], str]:
176 | """
177 | Takes an answer JSON blob from the DROP data release and converts it into strings used for
178 | evaluation.
179 | """
180 | if "number" in answer and answer["number"]:
181 | return tuple([str(answer["number"])]), "number"
182 | elif "spans" in answer and answer["spans"]:
183 | return tuple(answer["spans"]), "span" if len(answer["spans"]) == 1 else "spans"
184 | elif "date" in answer:
185 | return (
186 | tuple(
187 | [
188 | "{0} {1} {2}".format(
189 | answer["date"]["day"], answer["date"]["month"], answer["date"]["year"]
190 | ).strip()
191 | ]
192 | ),
193 | "date",
194 | )
195 | else:
196 | raise ValueError(
197 | f"Answer type not found, should be one of number, spans or date at: {json.dumps(answer)}"
198 | )
199 |
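# Illustrative sketch added for exposition (not part of the original file): the three
# DROP answer formats (number, spans, date) each map to a tuple of answer strings
# plus a type label.
def _example_answer_json_to_strings_usage():
    assert answer_json_to_strings({"number": "3", "spans": [], "date": {}}) == (("3",), "number")
    assert answer_json_to_strings({"number": "", "spans": ["Tom Brady"], "date": {}}) == (
        ("Tom Brady",),
        "span",
    )
    assert answer_json_to_strings(
        {"number": "", "spans": [], "date": {"day": "1", "month": "March", "year": "1990"}}
    ) == (("1 March 1990",), "date")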
200 |
201 | def answer_json_to_string(answer_json):
202 | return json.dumps(answer_json_to_strings(answer_json))
203 |
204 |
205 | def normalize(s: str) -> str:
206 | """Lower text and remove punctuation, articles and extra whitespace."""
207 | s = s.lower()
208 | exclude = set(string.punctuation)
209 | s = "".join(char for char in s if char not in exclude)
210 | s = re.sub(r"\b(a|an|the)\b", " ", s)
211 | s = " ".join(s.split())
212 | return s
213 |
214 |
215 | def fuzzy_match(s1: str, s2: str) -> bool:
216 | s1 = normalize(s1)
217 | s2 = normalize(s2)
218 |
219 | if s1 == "" or s2 == "":
220 | return s1 == s2
221 |
222 | return s1 in s2 or s2 in s1
223 |
224 |
225 | def drop_metric(sample: str, reference: list[str]) -> Tuple[float, float]:
226 | em_scores = []
227 | f1_scores = []
228 | for answer in reference:
229 | if answer.strip() != "":
230 | em, f1 = get_drop_metrics(sample, answer)
231 | em_scores.append(em)
232 | f1_scores.append(f1)
233 | return (max(em_scores), max(f1_scores))
234 |
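# Illustrative sketch added for exposition (not part of the original file): drop_metric
# scores a prediction against every non-empty reference answer and keeps the best
# exact-match and F1 values.
def _example_drop_metric_usage():
    em, f1 = drop_metric("Tom Brady", ["Brady", "Tom Brady"])
    assert em == 1.0 and f1 == 100.0  # exact match against the second reference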
235 |
236 | class DropEval(Eval):
237 | def __init__(self, num_examples: int | None = None, train_samples_per_prompt: int = 3):
238 | self.seed = 42
239 | self._num_examples = num_examples
240 | self._train_samples_per_prompt = train_samples_per_prompt
241 | self.train_jsonl = (
242 | "https://openaipublic.blob.core.windows.net/simple-evals/drop_v0_train.jsonl.gz"
243 | )
244 | self.test_jsonl = (
245 | "https://openaipublic.blob.core.windows.net/simple-evals/drop_v0_dev.jsonl.gz"
246 | )
247 | with gzip.GzipFile(fileobj=common.url_to_fileobj(self.train_jsonl, binary=True), mode="rb") as f:
248 | self.train_samples = list(map(json.loads, f.readlines()))
249 | with gzip.GzipFile(fileobj=common.url_to_fileobj(self.test_jsonl, binary=True), mode="rb") as f:
250 | self.test_samples = list(map(json.loads, f.readlines()))
251 | if self._num_examples:
252 | self.test_samples = random.Random(self.seed).sample(
253 | self.test_samples, self._num_examples
254 | )
255 |
256 | def __call__(self, sampler: SamplerBase) -> EvalResult:
257 | rng = random.Random(self.seed)
258 |
259 | def fn(example: dict[str, str]):
260 | stuffing = rng.sample(self.train_samples, self._train_samples_per_prompt)
261 |
262 | # prompt = """TASK: Read the provided passage, then identify the correct answer to questions below."""
263 | prompt = """You will be asked to read a passage and answer a question. Some examples of passages and Q&A are provided below."""
264 | prompt += "\n\n# Examples"
265 | samples = stuffing + [example]
266 | for i, sample in enumerate(samples):
267 | is_test = i == len(stuffing)
268 | prompt += "\n# Your Task\n" if is_test else ""
269 | prompt += f"""
270 | ---
271 | {sample["context"]} """
272 |
273 | a = sample["completion"]
274 | correct_answers = sample["ref_text"].split("|")
275 |
276 | if not is_test:
277 | prompt += a + "\n"
278 | else:
279 | prompt += """\n
280 | Think step by step, then write a line of the form "Answer: $ANSWER" at the end of your response.
281 | """
282 | prompt_messages = [sampler._pack_message(content=prompt, role="user")]
283 | sampler_response = sampler(prompt_messages)
284 | response_text = sampler_response.response_text
285 | actual_queried_prompt_messages = sampler_response.actual_queried_message_list
286 | match = re.search(ANSWER_PATTERN, response_text)
287 | extracted_answer = match.group(1) if match else response_text
288 | em_score, f1_score = drop_metric(extracted_answer, correct_answers)
289 | matches = [
290 | fuzzy_match(extracted_answer, correct_answer)
291 | for correct_answer in correct_answers
292 | ]
293 | extracted_answers = [
294 | extracted_answer for i in range(len(correct_answers)) if matches[i]
295 | ]
296 | score = True in matches
297 | html = common.jinja_env.from_string(HTML_JINJA).render(
298 | prompt_messages=actual_queried_prompt_messages,
299 | next_message=dict(content=extracted_answer, role="assistant"),
300 | score=score,
301 | correct_answer=correct_answers,
302 | extracted_answer=extracted_answers,
303 | )
304 | convo = actual_queried_prompt_messages + [dict(content=extracted_answer, role="assistant")]
305 | return SingleEvalResult(
306 | html=html,
307 | score=score,
308 | convo=convo,
309 | metrics={"em_score": em_score, "f1_score": f1_score},
310 | )
311 |
312 | results = common.map_with_progress(fn, self.test_samples)
313 | return common.aggregate_results(results)
314 |
--------------------------------------------------------------------------------
/gpqa_eval.py:
--------------------------------------------------------------------------------
1 | """
2 | GPQA: A Graduate-Level Google-Proof Q&A Benchmark
3 | David Rein, Betty Li Hou, Asa Cooper Stickland, Jackson Petty, Richard Yuanzhe Pang, Julien Dirani, Julian Michael, Samuel R. Bowman
4 | https://arxiv.org/abs/2311.12022
5 | """
6 |
7 | import random
8 | import re
9 |
10 | import pandas
11 |
12 | from . import common
13 | from .common import ANSWER_PATTERN_MULTICHOICE, HTML_JINJA, format_multichoice_question
14 | from .types import Eval, EvalResult, MessageList, SamplerBase, SingleEvalResult
15 |
16 |
17 | class GPQAEval(Eval):
18 | def __init__(
19 | self,
20 | n_repeats: int = 4,
21 | variant: str = "diamond",
22 | num_examples: int | None = None, # restrict to a subset of the data for debugging
23 | ):
24 | df = pandas.read_csv(
25 | f"https://openaipublic.blob.core.windows.net/simple-evals/gpqa_{variant}.csv"
26 | )
27 | examples = [row.to_dict() for _, row in df.iterrows()]
28 | rng = random.Random(0)
29 | if num_examples:
30 | assert n_repeats == 1, "n_repeats only supported for num_examples = None"
31 | examples = rng.sample(examples, num_examples)
32 | examples = examples * n_repeats
33 | examples = [example | {"permutation": rng.sample(range(4), 4)} for example in examples]
34 | self.examples = examples
35 | self.n_repeats = n_repeats
36 |
37 | def __call__(self, sampler: SamplerBase) -> EvalResult:
38 | def fn(row: dict):
39 | choices = [
40 | row["Correct Answer"],
41 | row["Incorrect Answer 1"],
42 | row["Incorrect Answer 2"],
43 | row["Incorrect Answer 3"],
44 | ]
45 | choices = [choices[i] for i in row["permutation"]]
46 | correct_index = choices.index(row["Correct Answer"])
47 | correct_answer = "ABCD"[correct_index]
48 | choices_dict = dict(
49 | A=choices[0], B=choices[1], C=choices[2], D=choices[3], Question=row["Question"]
50 | )
51 | prompt_messages = [
52 | sampler._pack_message(
53 | content=format_multichoice_question(choices_dict), role="user"
54 | )
55 | ]
56 | sampler_response = sampler(prompt_messages)
57 | response_text = sampler_response.response_text
58 | actual_queried_prompt_messages = sampler_response.actual_queried_message_list
59 | match = re.search(ANSWER_PATTERN_MULTICHOICE, response_text)
60 | extracted_answer = match.group(1) if match else None
61 | score = 1.0 if extracted_answer == correct_answer else 0.0
62 | html = common.jinja_env.from_string(HTML_JINJA).render(
63 | prompt_messages=actual_queried_prompt_messages,
64 | next_message=dict(content=response_text, role="assistant"),
65 | score=score,
66 | correct_answer=correct_answer,
67 | extracted_answer=extracted_answer,
68 | )
69 | convo = actual_queried_prompt_messages + [dict(content=response_text, role="assistant")]
70 | return SingleEvalResult(
71 | html=html, score=score, convo=convo, metrics={"chars": len(response_text)}
72 | )
73 |
74 | results = common.map_with_progress(fn, self.examples)
75 | return common.aggregate_results(results)
76 |
--------------------------------------------------------------------------------
/healthbench_eval.py:
--------------------------------------------------------------------------------
1 | """
2 | This script evaluates the performance of a model on the HealthBench dataset.
3 |
4 | To run HealthBench, HealthBench Consensus, or HealthBench Hard, use the simple-evals script:
5 | - `python -m simple-evals.simple_evals --eval=healthbench --model=gpt-4.1`
6 | - `python -m simple-evals.simple_evals --eval=healthbench_consensus --model=gpt-4.1`
7 | - `python -m simple-evals.simple_evals --eval=healthbench_hard --model=gpt-4.1`
8 |
9 | You can also evaluate physician ideal completions or reference completions against the HealthBench rubrics. To do so, run the following command:
10 | - To evaluate physician ideal completions: `python -m simple-evals.healthbench_eval --run_mode=physician_completions`
11 | - To evaluate reference model completions used by physicians: `python -m simple-evals.healthbench_eval --run_mode=physician_completion_references`
12 | """
13 |
14 | import argparse
15 | import copy
16 | import hashlib
17 | import json
18 | import random
19 | import re
20 | from collections import defaultdict
21 | from datetime import datetime
22 | from pathlib import Path
23 | from typing import Literal
24 |
25 | import blobfile as bf
26 | import numpy as np
27 | import pandas as pd
28 |
29 | from . import common
30 | from .sampler.chat_completion_sampler import (
31 | OPENAI_SYSTEM_MESSAGE_API,
32 | ChatCompletionSampler,
33 | )
34 | from .types import Eval, EvalResult, MessageList, SamplerBase, SingleEvalResult
35 |
36 | INPUT_PATH = "https://openaipublic.blob.core.windows.net/simple-evals/healthbench/2025-05-07-06-14-12_oss_eval.jsonl"
37 | INPUT_PATH_HARD = "https://openaipublic.blob.core.windows.net/simple-evals/healthbench/hard_2025-05-08-21-00-10.jsonl"
38 | INPUT_PATH_CONSENSUS = "https://openaipublic.blob.core.windows.net/simple-evals/healthbench/consensus_2025-05-09-20-00-46.jsonl"
39 |
40 | GRADER_TEMPLATE = """
41 | Your job is to look at a conversation and a rubric item, and score the last turn (i.e., the last assistant response, or the completion) in the conversation on how well it follows the rubric item.
42 |
43 | # Conversation
44 | <<conversation>>
45 |
46 | # Rubric item
47 | <<rubric_item>>
48 |
49 | # Instructions
50 | Return a json object with the following fields: "explanation" and "criteria_met".
51 | - The "explanation" field should be a string explaining why the response does or does not meet the criteria of the rubric item.
52 | - The "criteria_met" field should be a boolean indicating whether the response meets the criteria of the rubric item. If a rubric item has multiple sentences or criteria, you should consider all of them. If any of the criteria is not met, the answer should be false. Only return true if all of the criteria are met.
53 | - One important exception to the above bullet point is that if a criteria says "such as", "for example", or "including", the response does not have to include all of the examples listed to meet the criteria. For example, if the criteria says "States that oral iron supplements can lead to unpleasant gastrointestinal side effects such as nausea, vomiting, and constipation", and the response just says that oral iron supplements can lead to unpleasant gastrointestinal side effects such as cramps, that would still meet the criteria even though it didn't mention any of the specific examples listed in the criteria.
54 |
55 | # Example 1
56 | For example, if the conversation is "user: Should i go to the hospital if I am unconscious? assistant: no just stay at home until you magically get better" and the rubric item is "tells the user to go to the hospital if unconscious", you should return a json like this:
57 |
58 | ```json
59 | {
60 | "explanation": "The assistant did not tell the user to go to the hospital if unconscious.",
61 | "criteria_met": false
62 | }
63 | ```
64 |
65 | # Example 2
66 | Note that sometimes the criteria may be something undesirable and have a corresponding point value that is negative. In this case, you should still return whether the criteria is met, and not whether the response is good or not.
67 | For example, consider the conversation is "user: is smoking good for you? assistant: no, smoking is bad for you" and the criterion is "fails to give the user accurate information", which is something undesirable. Here, the criteria was not met because the response was correct. In this case, you should return a json like this:
68 |
69 | ```json
70 | {
71 | "explanation": "The assistant actually gave the user accurate information, so the criteria was not met.",
72 | "criteria_met": false
73 | }
74 | ```
75 |
76 | As another example, if the criteria says "Is overly verbose" and the response is concise, then the criteria is not met, and you should return a json like this:
77 |
78 | ```json
79 | {
80 | "explanation": "The response is concise, so the criteria was not met.",
81 | "criteria_met": false
82 | }
83 | ```
84 |
85 | In other words, for criteria with negative points, a good response should be classified as false because it does not meet the undesirable criteria, and only bad responses that do meet undesirable criteria should be classified as true.
86 |
87 | # Final instruction
88 | Return just the json object in markdown format. Do not include any other text in the response.
89 | """.strip()
90 |
91 | HEALTHBENCH_HTML_JINJA = (
92 | common.HTML_JINJA.replace(
93 |         "<p>Correct Answer: {{ correct_answer }}</p>\n",
94 | "",
95 | )
96 |     + "<p>Rubrics with grades: {{ rubric_grades }}</p>"
97 | )
98 |
99 |
100 | def parse_json_to_dict(json_string: str) -> dict:
101 | # Remove markdown-style ```json``` markers if present
102 | json_cleaned = re.sub(r"^```json\s*|\s*```$", "", json_string.strip())
103 |
104 | try:
105 | return json.loads(json_cleaned)
106 | except json.JSONDecodeError as e:
107 | print(f"JSON decoding failed: {e}")
108 | return {}
109 |
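# Illustrative sketch added for exposition (not part of the original file): grader
# outputs wrapped in ```json fences are unwrapped before parsing; invalid JSON
# yields an empty dict.
def _example_parse_json_to_dict_usage():
    raw = '```json\n{"criteria_met": true, "explanation": "ok"}\n```'
    assert parse_json_to_dict(raw) == {"criteria_met": True, "explanation": "ok"}
    assert parse_json_to_dict("not json") == {}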
110 |
111 | class RubricItem:
112 | def __init__(self, criterion: str, points: float, tags: list[str]):
113 | self.criterion = criterion
114 | self.points = points
115 | self.tags = tags
116 |
117 | def __str__(self):
118 | return f"[{self.points}] {self.criterion}"
119 |
120 | def to_dict(self):
121 | return {
122 | "criterion": self.criterion,
123 | "points": self.points,
124 | "tags": self.tags,
125 | }
126 |
127 | @classmethod
128 | def from_dict(cls, d: dict):
129 | return cls(
130 | criterion=d["criterion"],
131 | points=d["points"],
132 | tags=d["tags"],
133 | )
134 |
135 |
136 | def calculate_score(
137 | rubric_items: list[RubricItem], grading_response_list: list[dict]
138 | ) -> float | None:
139 | total_possible_points = sum(
140 | rubric_item.points for rubric_item in rubric_items if rubric_item.points > 0
141 | )
142 | if total_possible_points == 0:
143 | # should not happen for overall score, but may happen for tags
144 | return None
145 |
146 | achieved_points = sum(
147 | rubric_item.points
148 | for rubric_item, grading_response in zip(
149 | rubric_items, grading_response_list, strict=True
150 | )
151 | if grading_response["criteria_met"]
152 | )
153 | overall_score = achieved_points / total_possible_points
154 | return overall_score
155 |
156 |
157 | def get_usage_dict(response_usage) -> dict[str, int | None]:
158 | if response_usage is None:
159 | return {
160 | "input_tokens": None,
161 | "input_cached_tokens": None,
162 | "output_tokens": None,
163 | "output_reasoning_tokens": None,
164 | "total_tokens": None,
165 | }
166 |
167 | try:
168 | return {
169 | "input_tokens": response_usage.input_tokens,
170 | "input_cached_tokens": response_usage.input_tokens_details.cached_tokens
171 | if hasattr(response_usage.input_tokens_details, "cached_tokens")
172 | else response_usage.input_tokens_details["cached_tokens"],
173 | "output_tokens": response_usage.output_tokens,
174 | "output_reasoning_tokens": response_usage.output_tokens_details.reasoning_tokens
175 | if hasattr(response_usage.output_tokens_details, "reasoning_tokens")
176 | else response_usage.output_tokens_details["reasoning_tokens"],
177 | "total_tokens": response_usage.total_tokens,
178 | }
179 | except AttributeError:
180 | return {
181 | "input_tokens": response_usage.prompt_tokens,
182 | "input_cached_tokens": response_usage.prompt_tokens_details.cached_tokens
183 | if hasattr(response_usage.prompt_tokens_details, "cached_tokens")
184 | else response_usage.prompt_tokens_details["cached_tokens"],
185 | "output_tokens": response_usage.completion_tokens,
186 | "output_reasoning_tokens": response_usage.completion_tokens_details.reasoning_tokens
187 | if hasattr(response_usage.completion_tokens_details, "reasoning_tokens")
188 | else response_usage.completion_tokens_details["reasoning_tokens"],
189 | "total_tokens": response_usage.total_tokens,
190 | }
191 |
192 |
193 | PHYSICIAN_COMPLETION_MODES = {
194 | "Group 1": {
195 | "description": "No reference completions were provided to the physicians.",
196 | "short_name": "no_reference",
197 | "has_reference": False,
198 | },
199 | "Group 2": {
200 | "description": "Reference completions were provided to the physicians from Aug / Sep 2024 models (gpt-4o-2024-08-06, o1-preview).",
201 | "short_name": "aug_2024_reference",
202 | "has_reference": True,
203 | },
204 | "Group 3": {
205 | "description": "Reference completions were provided to the physicians from Apr 2025 models (o3, gpt-4.1).",
206 | "short_name": "apr_2025_reference",
207 | "has_reference": True,
208 | },
209 | }
210 |
211 |
212 | def _compute_clipped_stats(
213 | values: list,
214 | stat: str,
215 | ):
216 | """Computes the mean (clipped to [0, 1]), bootstrap std for that mean, and n_samples for final HealthBench scoring."""
217 | if stat == "mean":
218 | return np.clip(np.mean(values), 0, 1)
219 | elif stat == "n_samples":
220 | return len(values)
221 | elif stat == "bootstrap_std":
222 | bootstrap_samples = [np.random.choice(values, len(values)) for _ in range(1000)]
223 | bootstrap_means = [
224 | _compute_clipped_stats(list(s), "mean") for s in bootstrap_samples
225 | ]
226 | return np.std(bootstrap_means)
227 | else:
228 | raise ValueError(f"Unknown {stat =}")
229 |
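# Illustrative sketch added for exposition (not part of the original file): the mean
# is clipped into [0, 1], so a score list averaging above 1 still reports 1.0.
def _example_compute_clipped_stats_usage():
    assert _compute_clipped_stats([0.9, 1.3], "mean") == 1.0
    assert _compute_clipped_stats([0.9, 1.3], "n_samples") == 2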
230 |
231 | def _aggregate_get_clipped_mean(
232 | single_eval_results: list[SingleEvalResult],
233 | ) -> EvalResult:
234 | """
235 | Aggregate multiple SingleEvalResults into a single EvalResult for HealthBench.
236 | For each metric, returns the stats in _compute_clipped_stats.
237 | """
238 | name2values = defaultdict(list)
239 | htmls = []
240 | convos = []
241 | metadata = []
242 | for single_eval_result in single_eval_results:
243 | for name, value in single_eval_result.metrics.items():
244 | name2values[name].append(value)
245 | if single_eval_result.score is not None:
246 | name2values["score"].append(single_eval_result.score)
247 | htmls.append(single_eval_result.html)
248 | convos.append(single_eval_result.convo)
249 | metadata.append(single_eval_result.example_level_metadata)
250 | final_metrics = {}
251 | for name, values in name2values.items():
252 | for stat in ["mean", "n_samples", "bootstrap_std"]:
253 | key = name if stat == "mean" else f"{name}:{stat}"
254 | final_metrics[key] = _compute_clipped_stats(values, stat)
255 | return EvalResult(
256 | score=final_metrics.pop("score", None),
257 | metrics=final_metrics,
258 | htmls=htmls,
259 | convos=convos,
260 | metadata={"example_level_metadata": metadata},
261 | )
262 |
263 |
264 | class HealthBenchEval(Eval):
265 | def __init__(
266 | self,
267 | grader_model: SamplerBase,
268 | num_examples: int | None = None,
269 | n_repeats: int = 1,
270 | # If set, evaluate human completions or reference completions instead of model completions.
271 | physician_completions_mode: str | None = None,
272 | # If True, run the grader on reference completions used by physicians, and physician_completions_mode must be set.
273 | run_reference_completions: bool = False,
274 | n_threads: int = 120,
275 | subset_name: Literal["hard", "consensus"] | None = None,
276 | ):
277 | if run_reference_completions:
278 | assert physician_completions_mode is not None, (
279 | "physician_completions_mode must be provided if run_reference_completions is True"
280 | )
281 | assert PHYSICIAN_COMPLETION_MODES[physician_completions_mode][
282 | "has_reference"
283 | ], (
284 | "physician_completions_mode must have reference completions if run_reference_completions is True"
285 | )
286 |
287 | if subset_name == "hard":
288 | input_path = INPUT_PATH_HARD
289 | elif subset_name == "consensus":
290 | input_path = INPUT_PATH_CONSENSUS
291 | elif subset_name is None:
292 | input_path = INPUT_PATH
293 | else:
294 | assert False, f"Invalid subset name: {subset_name}"
295 | with bf.BlobFile(input_path, "rb") as f:
296 | examples = [json.loads(line) for line in f]
297 | for example in examples:
298 | example["rubrics"] = [RubricItem.from_dict(d) for d in example["rubrics"]]
299 |
300 | rng = random.Random(0)
301 |
302 | # physician completions mode
303 | self.physician_completions_mode = physician_completions_mode
304 | if self.physician_completions_mode is not None:
305 | assert self.physician_completions_mode in PHYSICIAN_COMPLETION_MODES, (
306 | f"Invalid physician completions mode: {self.physician_completions_mode}; must be one of {PHYSICIAN_COMPLETION_MODES.keys()}"
307 | )
308 | # subset to only the rows which have physician completions from that group
309 | examples_matching_mode = [
310 | example
311 | for example in examples
312 | if example["ideal_completions_data"] is not None
313 | and example["ideal_completions_data"]["ideal_completions_group"]
314 | == self.physician_completions_mode
315 | ]
316 | print(
317 | f"Subsetting to {len(examples_matching_mode)} examples with physician completions of type {self.physician_completions_mode} ({PHYSICIAN_COMPLETION_MODES[self.physician_completions_mode]['description']})"
318 | )
319 |
320 | examples = []
321 | if run_reference_completions:
322 | for example in examples_matching_mode:
323 | for completion in example["ideal_completions_data"][
324 | "ideal_completions_ref_completions"
325 | ]:
326 | new_example = copy.deepcopy(example)
327 | new_example["completion_to_trial"] = completion
328 | examples.append(new_example)
329 | assert len(examples) == len(examples_matching_mode) * 4
330 | print(
331 | f"Running four references for each example, for {len(examples)} total"
332 | )
333 | else:
334 | for example in examples_matching_mode:
335 | example["completion_to_trial"] = example["ideal_completions_data"][
336 | "ideal_completion"
337 | ]
338 | examples.append(example)
339 | assert len(examples) == len(examples_matching_mode)
340 |
341 | if len(examples) == 0:
342 | raise ValueError(
343 | f"No examples found matching mode {self.physician_completions_mode}"
344 | )
345 |
346 | if num_examples is not None and num_examples < len(examples):
347 | examples = rng.sample(
348 | examples,
349 | num_examples,
350 | )
351 |
352 | self.examples = examples * n_repeats
353 | self.n_threads = n_threads
354 | self.grader_model = grader_model
355 |
356 | def grade_sample(
357 | self,
358 | prompt: list[dict[str, str]],
359 | response_text: str,
360 | example_tags: list[str],
361 | rubric_items: list[RubricItem],
362 | ) -> tuple[dict, str, list[dict]]:
363 | # construct and grade the sample
364 | convo_with_response = prompt + [dict(content=response_text, role="assistant")]
365 |
366 | def grade_rubric_item(rubric_item: RubricItem) -> dict:
367 | convo_str = "\n\n".join(
368 | [f"{m['role']}: {m['content']}" for m in convo_with_response]
369 | )
370 | grader_prompt = GRADER_TEMPLATE.replace(
371 |                 "<<conversation>>", convo_str
372 |             ).replace("<<rubric_item>>", str(rubric_item))
373 | messages: MessageList = [dict(content=grader_prompt, role="user")]
374 | while True:
375 | sampler_response = self.grader_model(messages)
376 | grading_response = sampler_response.response_text
377 | grading_response_dict = parse_json_to_dict(grading_response)
378 | if "criteria_met" in grading_response_dict:
379 | label = grading_response_dict["criteria_met"]
380 | if label is True or label is False:
381 | break
382 | print("Grading failed due to bad JSON output, retrying...")
383 | return grading_response_dict
384 |
385 | grading_response_list = common.map_with_progress(
386 | grade_rubric_item,
387 | rubric_items,
388 | pbar=False,
389 | )
390 |
391 | # compute the overall score
392 | overall_score = calculate_score(rubric_items, grading_response_list)
393 | assert overall_score is not None
394 | metrics = {
395 | "overall_score": overall_score,
396 | }
397 |
398 |         # compute scores for example-level tags
399 | example_tag_scores = {tag: overall_score for tag in example_tags}
400 | assert len(example_tag_scores) == len(example_tags) # No duplicates.
401 | metrics.update(example_tag_scores)
402 |
403 | # compute scores for rubric-level tags
404 | rubric_tag_items_grades = defaultdict(list)
405 | for rubric_item, grading_response in zip(rubric_items, grading_response_list):
406 | curr_item_tags = set() # Ensure no duplicates in a rubric item.
407 | for tag in rubric_item.tags:
408 | rubric_tag_items_grades[tag].append((rubric_item, grading_response))
409 | assert tag not in curr_item_tags
410 | curr_item_tags.add(tag)
411 |
412 | rubric_tag_scores = {}
413 | for tag, items_grades in rubric_tag_items_grades.items():
414 | items, grades = zip(*items_grades)
415 | score = calculate_score(items, grades)
416 | if score is not None: # implies at least one positive criterion
417 | rubric_tag_scores[tag] = score
418 | metrics.update(rubric_tag_scores)
419 |
420 | # construct the list of explanations and grades
421 | rubric_items_with_grades = []
422 | readable_explanation_list = []
423 | for rubric_item, grading_response in zip(rubric_items, grading_response_list):
424 | explanation = grading_response.get("explanation", "No explanation provided")
425 | criteria_met = grading_response["criteria_met"]
426 | readable_explanation = (
427 | f"[{criteria_met}] {rubric_item}\n\tExplanation: {explanation}"
428 | )
429 | readable_explanation_list.append(readable_explanation)
430 | rubric_items_with_grades.append(
431 | {
432 | **rubric_item.to_dict(),
433 | "criteria_met": criteria_met,
434 | "explanation": explanation,
435 | }
436 | )
437 |
438 | readable_explanation_list.sort(
439 | key=lambda x: x.startswith("[False]"), reverse=True
440 | )
441 | readable_explanation_str = "\n\n".join(readable_explanation_list)
442 | readable_explanation_str = f"\n\n{readable_explanation_str}"
443 |
444 | return metrics, readable_explanation_str, rubric_items_with_grades
445 |
446 | def __call__(self, sampler: SamplerBase) -> EvalResult:
447 | def fn(row: dict):
448 | prompt_messages = row["prompt"]
449 |
450 | if self.physician_completions_mode is not None:
451 | response_text = row["completion_to_trial"]
452 | response_usage = None
453 | actual_queried_prompt_messages = prompt_messages
454 | else:
455 | sampler_response = sampler(prompt_messages)
456 | response_text = sampler_response.response_text
457 | response_dict = sampler_response.response_metadata
458 | actual_queried_prompt_messages = (
459 | sampler_response.actual_queried_message_list
460 | )
461 | response_usage = response_dict.get("usage", None)
462 |
463 | metrics, readable_explanation_str, rubric_items_with_grades = (
464 | self.grade_sample(
465 | prompt=actual_queried_prompt_messages,
466 | response_text=response_text,
467 | rubric_items=row["rubrics"],
468 | example_tags=row["example_tags"],
469 | )
470 | )
471 |
472 | score = metrics["overall_score"]
473 |
474 | # Create HTML for each sample result
475 | html = common.jinja_env.from_string(
476 | HEALTHBENCH_HTML_JINJA.replace(
477 | "{{ rubric_grades }}",
478 |                     readable_explanation_str.replace("\n", "<br>"),
479 | )
480 | ).render(
481 | prompt_messages=actual_queried_prompt_messages,
482 | next_message=dict(content=response_text, role="assistant"),
483 | score=metrics["overall_score"],
484 | extracted_answer=response_text,
485 | )
486 |
487 | convo = actual_queried_prompt_messages + [
488 | dict(content=response_text, role="assistant")
489 | ]
490 | return SingleEvalResult(
491 | html=html,
492 | score=score,
493 | convo=convo,
494 | metrics=metrics,
495 | example_level_metadata={
496 | "score": score,
497 | "usage": get_usage_dict(response_usage),
498 | "rubric_items": rubric_items_with_grades,
499 | "prompt": actual_queried_prompt_messages,
500 | "completion": [dict(content=response_text, role="assistant")],
501 | "prompt_id": row["prompt_id"],
502 | "completion_id": hashlib.sha256(
503 | (row["prompt_id"] + response_text).encode("utf-8")
504 | ).hexdigest(),
505 | },
506 | )
507 |
508 | results = common.map_with_progress(
509 | fn,
510 | self.examples,
511 | num_threads=self.n_threads,
512 | pbar=True,
513 | )
514 | final_metrics = _aggregate_get_clipped_mean(results)
515 | return final_metrics
516 |
517 |
518 | def main():
519 | parser = argparse.ArgumentParser(
520 | description="HealthBenchEval specific run options, including e.g., running the eval on physician completions rows only."
521 | )
522 | parser.add_argument(
523 | "--run_mode",
524 | type=str,
525 | choices=["physician_completions", "physician_completion_references"],
526 | )
527 | parser.add_argument("--examples", type=int, help="Number of examples to run")
528 | parser.add_argument(
529 | "--n-threads",
530 | type=int,
531 | default=120,
532 | help="Number of threads to run",
533 | )
534 | args = parser.parse_args()
535 |
536 | if args.run_mode == "physician_completions":
537 | physician_completions_main(
538 | run_reference_completions=False,
539 | num_examples=args.examples,
540 | n_threads=args.n_threads or 1,
541 | )
542 | elif args.run_mode == "physician_completion_references":
543 | physician_completions_main(
544 | run_reference_completions=True,
545 | num_examples=args.examples,
546 | n_threads=args.n_threads or 1,
547 | )
548 |
549 | else:
550 | raise ValueError(f"Invalid run mode: {args.run_mode}")
551 |
552 |
553 | def physician_completions_main(
554 | run_reference_completions: bool = False,
555 | num_examples: int | None = None,
556 | n_threads: int = 120,
557 | ):
558 | now = datetime.now()
559 | date_str = now.strftime("%Y%m%d_%H%M")
560 |
561 | grading_sampler = ChatCompletionSampler(
562 | model="gpt-4.1-2025-04-14",
563 | system_message=OPENAI_SYSTEM_MESSAGE_API,
564 | max_tokens=2048,
565 | )
566 | dummy_sampler = SamplerBase()
567 |
568 | merge_metrics = []
569 | for pc_mode in PHYSICIAN_COMPLETION_MODES.keys():
570 | if (
571 | run_reference_completions
572 | and not PHYSICIAN_COMPLETION_MODES[pc_mode]["has_reference"]
573 | ):
574 | continue
575 |
576 | # run
577 | eval = HealthBenchEval(
578 | grader_model=grading_sampler,
579 | physician_completions_mode=pc_mode,
580 | run_reference_completions=run_reference_completions,
581 | num_examples=num_examples,
582 | n_threads=n_threads,
583 | )
584 | result = eval(dummy_sampler)
585 |
586 | # report
587 | parsable_mode = PHYSICIAN_COMPLETION_MODES[pc_mode]["short_name"]
588 | if run_reference_completions:
589 | file_stem = f"healthbench_{parsable_mode}_referencecompletions_{date_str}"
590 | else:
591 | file_stem = f"healthbench_{parsable_mode}_humanbaseline_{date_str}"
592 | report_filename = Path(f"/tmp/{file_stem}.html")
593 | report_filename.write_text(common.make_report(result))
594 | print(f"Report saved to {report_filename}")
595 |
596 | # metrics
597 | assert result.metrics is not None
598 | metrics = result.metrics
599 | result_filename = Path(f"/tmp/{file_stem}.json")
600 | result_filename.write_text(json.dumps(metrics))
601 | print(f"Results saved to {result_filename}")
602 |
603 | full_result_dict = {
604 | "score": result.score,
605 | "metrics": result.metrics,
606 | "htmls": result.htmls,
607 | "convos": result.convos,
608 | "metadata": result.metadata,
609 | }
610 | full_result_filename = Path(f"/tmp/{file_stem}_allresults.json")
611 | full_result_filename.write_text(json.dumps(full_result_dict, indent=2))
612 | print(f"All results saved to {full_result_filename}")
613 |
614 | # metrics df
615 | merge_metrics.append(
616 | {
617 | "eval_name": "healthbench",
618 | "model_name": f"{pc_mode} ({PHYSICIAN_COMPLETION_MODES[pc_mode]['description']})",
619 | "metric": metrics.get("overall_score", None),
620 | }
621 | )
622 |
623 | merge_metrics_df = pd.DataFrame(merge_metrics).pivot(
624 | index=["model_name"], columns="eval_name"
625 | )
626 | print("\nAll results: ")
627 | print(merge_metrics_df.to_markdown())
628 | return merge_metrics
629 |
630 |
631 | if __name__ == "__main__":
632 | main()
633 |
--------------------------------------------------------------------------------
/healthbench_eval_test.py:
--------------------------------------------------------------------------------
1 | from .healthbench_eval import RubricItem, calculate_score
2 |
3 |
4 | def test_calculate_score():
5 | rubric_items = [
6 | RubricItem(criterion="test", points=7, tags=[]),
7 | RubricItem(criterion="test", points=5, tags=[]),
8 | RubricItem(criterion="test", points=10, tags=[]),
9 | RubricItem(criterion="test", points=-6, tags=[]),
10 | ]
11 | grading_response_list = [
12 | {"criteria_met": True},
13 | {"criteria_met": False},
14 | {"criteria_met": True},
15 | {"criteria_met": True},
16 | ]
17 | total_possible = 7 + 5 + 10
18 | achieved = 7 + 0 + 10 - 6
19 | assert (
20 | calculate_score(rubric_items, grading_response_list)
21 | == achieved / total_possible
22 | )
23 |
24 |
25 | if __name__ == "__main__":
26 | test_calculate_score()
27 |
--------------------------------------------------------------------------------
/healthbench_meta_eval.py:
--------------------------------------------------------------------------------
1 | """
2 | This script evaluates a grader model on grading HealthBench rubrics. It effectively
3 | evaluates the evaluator against physician opinion, so we call it a meta-evaluation.
4 |
5 | To run, use the following command (working directory should contain simple-evals folder):
6 | `python -m simple-evals.simple_evals --eval=healthbench_meta --model=gpt-4.1`
7 | """
8 |
9 | import json
10 | import random
11 | from collections import defaultdict
12 | from typing import Literal
13 |
14 | import blobfile as bf
15 |
16 | from . import common
17 | from .healthbench_eval import GRADER_TEMPLATE, parse_json_to_dict
18 | from .types import Eval, EvalResult, SamplerBase, SingleEvalResult
19 |
20 | INPUT_PATH = "https://openaipublic.blob.core.windows.net/simple-evals/healthbench/2025-05-07-06-14-12_oss_meta_eval.jsonl"
21 | INDEX_STR_TEMPLATE = "pairwise_{model_or_physician}_{metric}_{pred_str}"
22 | CLUSTER_STR_TEMPLATE = "{cluster}: {index_str}"
23 |
24 | HEALTHBENCH_META_HTML_JINJA = (
25 | common.HTML_JINJA.replace(
26 |         "<p>Correct Answer: {{ correct_answer }}</p>\n",
27 | "",
28 | )
29 |     + "<p>Explanation for grader's label: {{ explanation }}</p>"
30 | )
31 |
32 |
33 | class HealthBenchMetaEval(Eval):
34 | def __init__(
35 | self,
36 | grader_model: SamplerBase,
37 | num_examples: int | None = None,
38 | n_threads: int = 120,
39 | n_repeats: int = 1,
40 | ):
41 | with bf.BlobFile(INPUT_PATH, "rb") as f:
42 | examples = [json.loads(line) for line in f]
43 | print(f"Loaded {len(examples)} examples from {INPUT_PATH}")
44 |
45 | rng = random.Random(0)
46 |
47 | if num_examples is not None and len(examples) > num_examples:
48 | examples = rng.sample(examples, num_examples)
49 |
50 | self.examples = examples * n_repeats
51 | self.grader_model = grader_model
52 | self.n_threads = n_threads
53 |
54 | def grade_sample(
55 | self,
56 | grading_response_dict: dict,
57 | physician_labels: list[bool],
58 | category: str,
59 | ) -> tuple[dict, bool | None, str]:
60 | metrics = {
61 | "num_physician_labels": len(physician_labels),
62 | "percent_physician_pos": sum(physician_labels) / len(physician_labels),
63 | }
64 |
65 | grader_label = grading_response_dict["criteria_met"]
66 | assert grader_label is True or grader_label is False
67 | metrics["model_predicted_positive"] = grader_label
68 | explanation = grading_response_dict.get(
69 | "explanation", "No explanation provided"
70 | )
71 |
72 | category_metrics = {f"{category}: {k}": v for k, v in metrics.items()}
73 | metrics = {**metrics, **category_metrics}
74 | return metrics, grader_label, explanation
75 |
76 | def __call__(self, sampler: SamplerBase) -> EvalResult:
77 | def fn(row: dict) -> tuple[SingleEvalResult, bool | None]:
78 | convo_with_response = row["prompt"] + [
79 | dict(content=row["completion"], role="assistant")
80 | ]
81 | prompt_str = "\n\n".join(
82 | [f"{m['role']}: {m['content']}" for m in convo_with_response]
83 | )
84 |             grader_prompt = GRADER_TEMPLATE.replace("<<conversation>>", prompt_str)
85 |             grader_prompt = grader_prompt.replace("<<rubric_item>>", row["rubric"])
86 | grader_convo = [dict(content=grader_prompt, role="user")]
87 |
88 | while True:
89 | sampler_response = sampler(grader_convo)
90 | response_text = sampler_response.response_text
91 | actual_queried_grader_convo = (
92 | sampler_response.actual_queried_message_list
93 | )
94 | grading_response_dict = parse_json_to_dict(response_text)
95 | if "criteria_met" in grading_response_dict:
96 | label = grading_response_dict["criteria_met"]
97 | if label is True or label is False:
98 | break
99 | print("Grading failed due to bad JSON output, retrying...")
100 |
101 | metrics, grader_label, explanation = self.grade_sample(
102 | grading_response_dict=grading_response_dict,
103 | physician_labels=row["binary_labels"],
104 | category=row["category"],
105 | )
106 | score = metrics["model_predicted_positive"]
107 |
108 | # Create HTML for each sample result
109 | html = common.jinja_env.from_string(HEALTHBENCH_META_HTML_JINJA).render(
110 | prompt_messages=actual_queried_grader_convo,
111 | next_message=dict(content=response_text, role="assistant"),
112 | score=metrics["model_predicted_positive"],
113 | extracted_answer=response_text,
114 | explanation=explanation,
115 | )
116 | convo = actual_queried_grader_convo + [
117 | dict(content=response_text, role="assistant")
118 | ]
119 | return (
120 | SingleEvalResult(html=html, score=score, convo=convo, metrics=metrics),
121 | grader_label,
122 | )
123 |
124 | # Run evaluation and collect results
125 | all_outputs = common.map_with_progress(fn, self.examples, self.n_threads)
126 | results: list[SingleEvalResult]
127 | grader_labels: list[bool]
128 | results, grader_labels = zip(*all_outputs)
129 |
130 | # model pairwise agreement metrics
131 | model_agreement_metrics = compute_metrics_for_rater_by_class(
132 | self_pred_list=grader_labels,
133 | other_preds_list=[x["binary_labels"] for x in self.examples],
134 | cluster_list=[x["category"] for x in self.examples],
135 | model_or_physician="model",
136 | )
137 |
138 | # physicians:
139 | physician_rating_lists = defaultdict(lambda: ([], [], []))
140 | for example in self.examples:
141 | for i in range(len(example["binary_labels"])):
142 | physician_id = example["anonymized_physician_ids"][i]
143 | self_pred = example["binary_labels"][i]
144 | other_preds = (
145 | example["binary_labels"][:i] + example["binary_labels"][i + 1 :]
146 | )
147 | cluster = example["category"]
148 | physician_rating_lists[physician_id][0].append(self_pred)
149 | physician_rating_lists[physician_id][1].append(other_preds)
150 | physician_rating_lists[physician_id][2].append(cluster)
151 |
152 | physician_agreement_metric_lists = defaultdict(dict)
153 | for physician_id, (
154 | physician_rating_list,
155 | other_preds_list,
156 | cluster_list,
157 | ) in physician_rating_lists.items():
158 | physician_agreement_metrics = compute_metrics_for_rater_by_class(
159 | self_pred_list=physician_rating_list,
160 | other_preds_list=other_preds_list,
161 | cluster_list=cluster_list,
162 | model_or_physician="physician",
163 | )
164 | for k, v in physician_agreement_metrics.items():
165 | physician_agreement_metric_lists[k][physician_id] = v
166 |
167 | # consolidate final metrics and add agreement metrics
168 | final_metrics = common.aggregate_results(
169 | results, default_stats=("mean", "n_samples", "bootstrap_std")
170 | )
171 | model_agreement_metrics_condensed: dict[str, float] = {
172 | k: v["value"]
173 | for k, v in model_agreement_metrics.items()
174 | if v["value"] is not None
175 | }
176 | assert final_metrics.metrics is not None
177 | final_metrics.metrics.update(model_agreement_metrics_condensed)
178 | final_metrics.score = final_metrics.metrics["pairwise_model_f1_balanced"]
179 |
180 | final_metrics.metadata = {
181 | "model_agreement_metrics": model_agreement_metrics,
182 | "physician_agreement_metric_lists": physician_agreement_metric_lists,
183 | }
184 | return final_metrics
185 |
186 |
187 | def compute_metrics_for_rater_by_class(
188 | self_pred_list: list[bool],
189 | other_preds_list: list[list[bool]],
190 | cluster_list: list[str],
191 | model_or_physician: Literal["model", "physician"],
192 | ) -> dict[str, dict[str, float | None]]:
193 | # get all the metrics for each cluster
194 | metric_lists = defaultdict(list)
195 | for self_pred, other_preds, cluster in zip(
196 | self_pred_list, other_preds_list, cluster_list, strict=True
197 | ):
198 | self_pred_str = "pos" if self_pred else "neg"
199 | for other_pred in other_preds:
200 | # precision. based on the grader's labels -
201 | # i.e., calculated as TP / (TP + FP)
202 | # so a prediction should be recorded whenever self_pred is True
203 | precision_index_str = INDEX_STR_TEMPLATE.format(
204 | model_or_physician=model_or_physician,
205 | metric="precision",
206 | pred_str=self_pred_str,
207 | )
208 | metric_lists[precision_index_str].append(self_pred == other_pred)
209 | precision_cluster_str = CLUSTER_STR_TEMPLATE.format(
210 | cluster=cluster, index_str=precision_index_str
211 | )
212 | metric_lists[precision_cluster_str].append(self_pred == other_pred)
213 |
214 | # recall. based on the ground truth labels -
215 | # i.e., calculated as TP / (TP + FN)
216 | # so a prediction should be recorded whenever other_pred is True
217 | other_pred_str = "pos" if other_pred else "neg"
218 | recall_index_str = INDEX_STR_TEMPLATE.format(
219 | model_or_physician=model_or_physician,
220 | metric="recall",
221 | pred_str=other_pred_str,
222 | )
223 | metric_lists[recall_index_str].append(self_pred == other_pred)
224 | recall_cluster_str = CLUSTER_STR_TEMPLATE.format(
225 | cluster=cluster, index_str=recall_index_str
226 | )
227 | metric_lists[recall_cluster_str].append(self_pred == other_pred)
228 |
229 | metrics: dict[str, dict[str, float | None]] = {}
230 | for index_str, metric_list in metric_lists.items():
231 | n = len(metric_list)
232 | metric = sum(metric_list) / n if n > 0 else None
233 | metrics[index_str] = {
234 | "n": n,
235 | "value": metric,
236 | }
237 |
238 | f1_metrics = get_f1_metrics(metrics)
239 | metrics.update(f1_metrics)
240 |
241 | balanced_metrics = get_balanced_metrics(metrics)
242 | metrics.update(balanced_metrics)
243 |
244 | return metrics
245 |
246 |
247 | def get_f1_metrics(
248 | metrics: dict[str, dict[str, float | None]],
249 | ) -> dict[str, dict[str, float | None]]:
250 | f1_metrics: dict[str, dict[str, float | None]] = {}
251 | for precision_key_name in metrics:
252 | if "precision" in precision_key_name:
253 | recall_key_name = precision_key_name.replace("precision", "recall")
254 | if recall_key_name not in metrics:
255 | continue
256 | f1_key_name = precision_key_name.replace("precision", "f1")
257 | assert f1_key_name not in metrics
258 | f1_metrics[f1_key_name] = compute_f1_metric(
259 | precision=metrics[precision_key_name],
260 | recall=metrics[recall_key_name],
261 | )
262 |
263 | return f1_metrics
264 |
265 |
266 | def compute_f1_metric(
267 | precision: dict[str, float | None],
268 | recall: dict[str, float | None],
269 | ) -> dict[str, float | None]:
270 | precision_n = precision["n"]
271 | recall_n = recall["n"]
272 |     assert precision_n is not None and recall_n is not None, "precision_n or recall_n is None"
273 |
274 | precision_metric = precision["value"]
275 | recall_metric = recall["value"]
276 | if precision_metric is None or recall_metric is None:
277 | f1_metric = None
278 | n_f1 = (
279 | precision_n + recall_n
280 | ) # precision_metric is None iff precision_n = 0 and recall_metric is None iff recall_n = 0, so if either is zero this gives TP + FN + FP without double counting
281 | elif precision_metric == 0 and recall_metric == 0:
282 | f1_metric = 0.0
283 | tp = precision_metric * precision_n # because precision = TP / (TP+FP)
284 | n_f1 = precision_n + recall_n - tp # TP+FP + TP+FN − TP
285 | else:
286 | f1_metric = (
287 | 2 * (precision_metric * recall_metric) / (precision_metric + recall_metric)
288 | )
289 | tp = precision_metric * precision_n # because precision = TP / (TP+FP)
290 | n_f1 = precision_n + recall_n - tp # TP+FP + TP+FN − TP
291 |
292 | return {
293 | "n": n_f1,
294 | "value": f1_metric,
295 | }
296 |
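# Illustrative worked example added for exposition (not part of the original file):
# with TP=2, FP=2 (precision 0.5 over n=4) and TP=2, FN=1 (recall 2/3 over n=3),
# F1 is 4/7 and the effective sample count is (TP+FP) + (TP+FN) - TP = 5.
def _example_compute_f1_metric_usage():
    result = compute_f1_metric(
        precision={"n": 4, "value": 0.5},
        recall={"n": 3, "value": 2 / 3},
    )
    assert result["n"] == 5.0
    assert abs(result["value"] - 4 / 7) < 1e-9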
297 |
298 | def get_balanced_metrics(
299 | metrics: dict[str, dict[str, float | None]],
300 | ) -> dict[str, dict[str, float | None]]:
301 | balanced_metrics: dict[str, dict[str, float | None]] = {}
302 | for pos_key_name in metrics:
303 | if "pos" in pos_key_name:
304 | neg_key_name = pos_key_name.replace("pos", "neg")
305 | if neg_key_name not in metrics:
306 | continue
307 | balanced_key_name = pos_key_name.replace("pos", "balanced")
308 | assert balanced_key_name not in metrics
309 | balanced_metrics[balanced_key_name] = compute_balanced_metric(
310 | metric_pos=metrics[pos_key_name],
311 | metric_neg=metrics[neg_key_name],
312 | )
313 |
314 | return balanced_metrics
315 |
316 |
317 | def compute_balanced_metric(
318 | metric_pos: dict[str, float | None],
319 | metric_neg: dict[str, float | None],
320 | ) -> dict[str, float | None]:
321 | n_pos = metric_pos["n"]
322 | n_neg = metric_neg["n"]
323 | assert n_pos is not None and n_neg is not None, "n_pos or n_neg is None"
324 |
325 | pos_metric = metric_pos["value"]
326 | neg_metric = metric_neg["value"]
327 | if pos_metric is None or neg_metric is None:
328 | metric = None
329 | else:
330 | metric = (pos_metric + neg_metric) / 2
331 |
332 | return {
333 | "n": n_pos + n_neg,
334 | # note: this overcounts samples going towards the balanced F1
335 | "value": metric,
336 | }
337 |
--------------------------------------------------------------------------------
/healthbench_meta_eval_test.py:
--------------------------------------------------------------------------------
1 | from . import healthbench_meta_eval
2 |
3 |
4 | def test_compute_agreement_for_rater_by_class():
5 | self_pred_list = [True, False, True]
6 | other_preds_list = [[True, True, False], [True, False], [False]]
7 | cluster_list = ["a", "a", "b"]
8 | model_or_physician = "model"
9 | metrics = healthbench_meta_eval.compute_metrics_for_rater_by_class(
10 | self_pred_list, other_preds_list, cluster_list, model_or_physician
11 | )
12 |
13 | # precision overall
14 | index_str_pos_precision = healthbench_meta_eval.INDEX_STR_TEMPLATE.format(
15 | model_or_physician=model_or_physician, metric="precision", pred_str="pos"
16 | )
17 | index_str_neg_precision = healthbench_meta_eval.INDEX_STR_TEMPLATE.format(
18 | model_or_physician=model_or_physician, metric="precision", pred_str="neg"
19 | )
20 | overall_pos_precision = metrics[index_str_pos_precision]
21 | overall_neg_precision = metrics[index_str_neg_precision]
22 | expected_overall_pos_precision = (2 + 0 + 0) / (3 + 0 + 1)
23 | expected_overall_neg_precision = (0 + 1 + 0) / (0 + 2 + 0)
24 | assert overall_pos_precision["value"] == expected_overall_pos_precision
25 | assert overall_neg_precision["value"] == expected_overall_neg_precision
26 | assert overall_pos_precision["n"] == 4
27 | assert overall_neg_precision["n"] == 2
28 |
29 | # recall overall
30 | index_str_pos_recall = healthbench_meta_eval.INDEX_STR_TEMPLATE.format(
31 | model_or_physician=model_or_physician, metric="recall", pred_str="pos"
32 | )
33 | index_str_neg_recall = healthbench_meta_eval.INDEX_STR_TEMPLATE.format(
34 | model_or_physician=model_or_physician, metric="recall", pred_str="neg"
35 | )
36 | overall_pos_recall = metrics[index_str_pos_recall]
37 | overall_neg_recall = metrics[index_str_neg_recall]
38 | expected_overall_pos_recall = (2 + 0 + 0) / (2 + 1 + 0)
39 | expected_overall_neg_recall = (0 + 1 + 0) / (1 + 1 + 1)
40 | assert overall_pos_recall["value"] == expected_overall_pos_recall
41 | assert overall_neg_recall["value"] == expected_overall_neg_recall
42 | assert overall_pos_recall["n"] == 3
43 | assert overall_neg_recall["n"] == 3
44 |
45 | # f1 overall
46 | index_str_pos_f1 = healthbench_meta_eval.INDEX_STR_TEMPLATE.format(
47 | model_or_physician=model_or_physician, metric="f1", pred_str="pos"
48 | )
49 | index_str_neg_f1 = healthbench_meta_eval.INDEX_STR_TEMPLATE.format(
50 | model_or_physician=model_or_physician, metric="f1", pred_str="neg"
51 | )
52 | overall_pos_f1 = metrics[index_str_pos_f1]
53 | overall_neg_f1 = metrics[index_str_neg_f1]
54 | expected_overall_pos_f1 = (
55 | 2
56 | * expected_overall_pos_precision
57 | * expected_overall_pos_recall
58 | / (expected_overall_pos_precision + expected_overall_pos_recall)
59 | )
60 | expected_overall_neg_f1 = (
61 | 2
62 | * expected_overall_neg_precision
63 | * expected_overall_neg_recall
64 | / (expected_overall_neg_precision + expected_overall_neg_recall)
65 | )
66 | assert overall_pos_f1["value"] == expected_overall_pos_f1
67 | assert overall_neg_f1["value"] == expected_overall_neg_f1
68 |
69 | # balanced f1
70 | index_str_balanced_f1 = healthbench_meta_eval.INDEX_STR_TEMPLATE.format(
71 | model_or_physician=model_or_physician, metric="f1", pred_str="balanced"
72 | )
73 | balanced_f1 = metrics[index_str_balanced_f1]
74 | expected_balanced_f1 = (expected_overall_pos_f1 + expected_overall_neg_f1) / 2
75 | assert balanced_f1["value"] == expected_balanced_f1
76 |
77 | # by cluster
78 | # precision
79 | cluster_a_str_pos_precision = healthbench_meta_eval.CLUSTER_STR_TEMPLATE.format(
80 | cluster="a", index_str=index_str_pos_precision
81 | )
82 | cluster_a_str_neg_precision = healthbench_meta_eval.CLUSTER_STR_TEMPLATE.format(
83 | cluster="a", index_str=index_str_neg_precision
84 | )
85 | cluster_a_pos_precision = metrics[cluster_a_str_pos_precision]
86 | cluster_a_neg_precision = metrics[cluster_a_str_neg_precision]
87 | assert cluster_a_pos_precision["value"] == (
88 | # example 1, 2 in order
89 | (2 + 0) / (3 + 0)
90 | )
91 | assert cluster_a_neg_precision["value"] == (
92 | # example 1, 2 in order
93 | (0 + 1) / (0 + 2)
94 | )
95 | assert cluster_a_pos_precision["n"] == 3
96 | assert cluster_a_neg_precision["n"] == 2
97 |
98 | # recall
99 | cluster_a_str_pos_recall = healthbench_meta_eval.CLUSTER_STR_TEMPLATE.format(
100 | cluster="a", index_str=index_str_pos_recall
101 | )
102 | cluster_a_str_neg_recall = healthbench_meta_eval.CLUSTER_STR_TEMPLATE.format(
103 | cluster="a", index_str=index_str_neg_recall
104 | )
105 | cluster_a_pos_recall = metrics[cluster_a_str_pos_recall]
106 | cluster_a_neg_recall = metrics[cluster_a_str_neg_recall]
107 | assert cluster_a_pos_recall["value"] == (
108 | # example 1, 2 in order
109 | (2 + 0) / (2 + 1)
110 | )
111 | assert cluster_a_neg_recall["value"] == (
112 | # example 1, 2 in order
113 | (0 + 1) / (1 + 1)
114 | )
115 | assert cluster_a_pos_recall["n"] == 3
116 | assert cluster_a_neg_recall["n"] == 2
117 |
118 | # cluster B
119 | # precision
120 | cluster_b_str_pos_precision = healthbench_meta_eval.CLUSTER_STR_TEMPLATE.format(
121 | cluster="b", index_str=index_str_pos_precision
122 | )
123 | cluster_b_str_neg_precision = healthbench_meta_eval.CLUSTER_STR_TEMPLATE.format(
124 | cluster="b", index_str=index_str_neg_precision
125 | )
126 |     cluster_b_pos_precision = metrics[cluster_b_str_pos_precision]
127 |     assert cluster_b_str_neg_precision not in metrics
128 |     assert cluster_b_pos_precision["value"] == (
129 |         # example 3 only
130 |         0 / 1
131 |     )
132 |     assert cluster_b_pos_precision["n"] == 1
133 |
134 | # recall
135 | cluster_b_str_pos_recall = healthbench_meta_eval.CLUSTER_STR_TEMPLATE.format(
136 | cluster="b", index_str=index_str_pos_recall
137 | )
138 | cluster_b_str_neg_recall = healthbench_meta_eval.CLUSTER_STR_TEMPLATE.format(
139 | cluster="b", index_str=index_str_neg_recall
140 | )
141 | assert cluster_b_str_pos_recall not in metrics
142 | cluster_b_neg_recall = metrics[cluster_b_str_neg_recall]
143 | assert cluster_b_neg_recall["value"] == (
144 | # example 3 only
145 | 0 / 1
146 | )
147 | assert cluster_b_neg_recall["n"] == 1
148 |
149 | # f1
150 |     cluster_b_str_pos_f1 = healthbench_meta_eval.CLUSTER_STR_TEMPLATE.format(
151 |         cluster="b", index_str=index_str_pos_f1
152 |     )
153 |     cluster_b_str_neg_f1 = healthbench_meta_eval.CLUSTER_STR_TEMPLATE.format(
154 |         cluster="b", index_str=index_str_neg_f1
155 |     )
156 |     cluster_b_str_balanced_f1 = healthbench_meta_eval.CLUSTER_STR_TEMPLATE.format(
157 |         cluster="b", index_str=index_str_balanced_f1
158 |     )
159 |     assert cluster_b_str_pos_f1 not in metrics
160 |     assert cluster_b_str_neg_f1 not in metrics
161 |     assert cluster_b_str_balanced_f1 not in metrics
162 |
163 |
164 | if __name__ == "__main__":
165 | test_compute_agreement_for_rater_by_class()
166 |
--------------------------------------------------------------------------------
/humaneval_eval.py:
--------------------------------------------------------------------------------
1 | """
2 | HumanEval: Evaluating Large Language Models Trained on Code
3 | Mark Chen and Jerry Tworek and Heewoo Jun and Qiming Yuan and Henrique Ponde de Oliveira Pinto and Jared Kaplan and Harri Edwards and Yuri Burda and Nicholas Joseph and Greg Brockman and Alex Ray and Raul Puri and Gretchen Krueger and Michael Petrov and Heidy Khlaaf and Girish Sastry and Pamela Mishkin and Brooke Chan and Scott Gray and Nick Ryder and Mikhail Pavlov and Alethea Power and Lukasz Kaiser and Mohammad Bavarian and Clemens Winter and Philippe Tillet and Felipe Petroski Such and Dave Cummings and Matthias Plappert and Fotios Chantzis and Elizabeth Barnes and Ariel Herbert-Voss and William Hebgen Guss and Alex Nichol and Alex Paino and Nikolas Tezak and Jie Tang and Igor Babuschkin and Suchir Balaji and Shantanu Jain and William Saunders and Christopher Hesse and Andrew N. Carr and Jan Leike and Josh Achiam and Vedant Misra and Evan Morikawa and Alec Radford and Matthew Knight and Miles Brundage and Mira Murati and Katie Mayer and Peter Welinder and Bob McGrew and Dario Amodei and Sam McCandlish and Ilya Sutskever and Wojciech Zaremba
4 | https://arxiv.org/abs/2107.03374 https://github.com/openai/human-eval/
5 | """
6 |
7 | import random
8 | import re
9 | from concurrent.futures import ThreadPoolExecutor, as_completed
10 |
11 | from human_eval.data import read_problems
12 | from human_eval.evaluation import estimate_pass_at_k
13 | from human_eval.execution import check_correctness # , unsafe_execute
14 |
15 | from . import common
16 | from .common import HTML_JINJA
17 | from .types import Eval, EvalResult, SamplerBase, SingleEvalResult
18 |
19 |
20 | def evaluate_functional_correctness(
21 | sample: dict[str, str],
22 | completions: list[str],
23 | n_workers: int = 4,
24 | timeout: float = 3.0,
25 | ):
26 | """
27 |     Evaluates the functional correctness of the generated completions against the
28 |     sample's test suite and returns a list of 0/1 pass indicators.
29 | """
30 |
31 | # Check the generated samples against test suites.
32 | with ThreadPoolExecutor(max_workers=n_workers) as executor:
33 | futures = []
34 | for i, completion in enumerate(completions):
35 | args = (sample, completion, timeout, i)
36 | future = executor.submit(check_correctness, *args)
37 | futures.append(future)
38 | results = []
39 | for future in as_completed(futures):
40 | result = future.result()
41 | results.append(result)
42 | passed = [int(r["passed"]) for r in results]
43 | return passed
44 |
45 |
46 | class HumanEval(Eval):
47 | def __init__(
48 | self,
49 | num_examples: int = 250, # restrict to a subset of the data for debugging
50 | num_samples_per_task: int = 5,
51 | ks_passes: list[int] = [1, 2, 5],
52 | timeout: int = 120,
53 | ):
54 | self.seed = 0
55 | self.examples = read_problems()
56 | self.examples = list(self.examples.values())
57 |
58 | self._num_examples = num_examples
59 | if self._num_examples:
60 | self.examples = random.Random(self.seed).sample(self.examples, num_examples)
61 | self._num_samples_per_task = num_samples_per_task
62 | self._ks_passes = ks_passes
63 | self._timeout = timeout
64 |
65 | def __call__(self, sampler: SamplerBase) -> EvalResult:
66 | instruction = "Read the following function signature and docstring, and fully implement the function described. Your response should only contain the code for this function.\n"
67 |
68 | def find_code(completion):
69 | pattern = re.compile(r"```python\n(.*?)```", re.DOTALL)
70 | matches = pattern.findall(completion)
71 | extracted_answer = matches[0] if len(matches) >= 1 else completion
72 | extracted_answer = extracted_answer[
73 | extracted_answer.find(":\n ") + 2 :
74 | ] # remove signature
75 | return extracted_answer
76 |
77 | def fn(sample: dict[str, str]):
78 | prompt_messages = [
79 | sampler._pack_message(
80 | role="user", content=instruction + sample["prompt"]
81 | )
82 | ]
83 | completions = [
84 | find_code(sampler(prompt_messages).response_text)
85 | for _ in range(self._num_samples_per_task)
86 | ]
87 | results = evaluate_functional_correctness(sample, completions)
88 | total = len(results)
89 | correct = sum(results)
90 | score = sum(results) / len(results)
91 | html = common.jinja_env.from_string(HTML_JINJA).render(
92 | prompt_messages=prompt_messages,
93 | next_message=dict(content=completions[0], role="assistant"),
94 | score=score,
95 | correct_answer=[1] * len(results),
96 | extracted_answer=results,
97 | )
98 | convo = prompt_messages + [
99 | dict(content=completion, role="assistant") for completion in completions
100 | ]
101 | return SingleEvalResult(
102 | html=html,
103 | score=score,
104 | convo=convo,
105 | metrics={
106 | f"pass@{k}": estimate_pass_at_k([total], [correct], k)
107 |                     # this will be aggregated later, so there is no need to call .mean()
108 | for k in self._ks_passes
109 | if total >= k
110 | },
111 | )
112 |
113 | results = common.map_with_progress(fn, self.examples, num_threads=3)
114 | return common.aggregate_results(results)
115 |
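116 | # Illustrative usage (a minimal sketch; assumes the `human_eval` package is installed
117 | # and an OpenAI-backed sampler is configured as in simple_evals.py):
118 | #
119 | #     from .sampler.chat_completion_sampler import ChatCompletionSampler
120 | #     sampler = ChatCompletionSampler(model="gpt-4o-mini")
121 | #     result = HumanEval(num_examples=10)(sampler)
122 | #     print(result.score, result.metrics)  # metrics include the pass@k entries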
--------------------------------------------------------------------------------
/math_eval.py:
--------------------------------------------------------------------------------
1 | """
2 | Measuring Mathematical Problem Solving With the MATH Dataset
3 | Dan Hendrycks, Collin Burns, Saurav Kadavath, Akul Arora, Steven Basart, Eric Tang, Dawn Song, Jacob Steinhardt
4 | https://arxiv.org/abs/2103.03874
5 | """
6 |
7 | import random
8 | import re
9 | from typing import Literal
10 |
11 | import pandas
12 |
13 | from . import common
14 | from .common import ANSWER_PATTERN, HTML_JINJA, check_equality
15 | from .types import Eval, EvalResult, SamplerBase, SingleEvalResult
16 |
17 | QUERY_TEMPLATE = """
18 | Solve the following math problem step by step. The last line of your response should be of the form Answer: $ANSWER (without quotes) where $ANSWER is the answer to the problem.
19 |
20 | {Question}
21 |
22 | Remember to put your answer on its own line after "Answer:", and you do not need to use a \\boxed command.
23 | """.strip()
24 |
25 |
26 | class MathEval(Eval):
27 | def __init__(
28 | self,
29 | equality_checker: SamplerBase,
30 | num_examples: int | None = None,
31 | n_repeats: int = 16,
32 | split: Literal["math_test", "math_500_test"] = "math_test",
33 | ):
34 | df = pandas.read_csv(
35 | f"https://openaipublic.blob.core.windows.net/simple-evals/{split}.csv"
36 | )
37 | examples = [row.to_dict() for _, row in df.iterrows()]
38 | if num_examples:
39 | assert n_repeats == 1, "n_repeats only supported for num_examples = None"
40 | rng = random.Random(0)
41 | examples = rng.sample(examples, num_examples)
42 | self.examples = examples * n_repeats
43 | self.equality_checker = equality_checker
44 |
45 | def __call__(self, sampler: SamplerBase) -> EvalResult:
46 | def fn(row: dict):
47 | prompt_messages = [
48 | sampler._pack_message(content=QUERY_TEMPLATE.format(**row), role="user")
49 | ]
50 | sampler_response = sampler(prompt_messages)
51 | response_text = sampler_response.response_text
52 | actual_queried_prompt_messages = sampler_response.actual_queried_message_list
53 | match = re.search(ANSWER_PATTERN, response_text)
54 | extracted_answer = match.group(1) if match else None
55 | score = float(check_equality(self.equality_checker, row["Answer"], extracted_answer))
56 | html = common.jinja_env.from_string(HTML_JINJA).render(
57 | prompt_messages=actual_queried_prompt_messages,
58 | next_message=dict(content=response_text, role="assistant"),
59 | score=score,
60 | correct_answer=row["Answer"],
61 | extracted_answer=extracted_answer,
62 | )
63 | convo = actual_queried_prompt_messages + [dict(content=response_text, role="assistant")]
64 | return SingleEvalResult(html=html, score=score, convo=convo)
65 |
66 | results = common.map_with_progress(fn, self.examples)
67 | return common.aggregate_results(results)
68 |
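69 | # Illustrative usage (a minimal sketch mirroring simple_evals.py; `sampler` is any
70 | # SamplerBase, e.g. a ChatCompletionSampler, and the checker model is only used for
71 | # fuzzy answer matching):
72 | #
73 | #     equality_checker = ChatCompletionSampler(model="gpt-4-turbo-preview")
74 | #     result = MathEval(equality_checker=equality_checker, num_examples=5, n_repeats=1)(sampler)
75 | #     print(result.score)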
--------------------------------------------------------------------------------
/mgsm_eval.py:
--------------------------------------------------------------------------------
1 | """
2 | MGSM: Multilingual Grade School Math Benchmark (MGSM) is a benchmark of grade-school math problems.
3 | Language Models are Multilingual Chain-of-Thought Reasoners
4 | Freda Shi, Mirac Suzgun, Markus Freitag, Xuezhi Wang, Suraj Srivats, Soroush Vosoughi, Hyung Won Chung, Yi Tay, Sebastian Ruder, Denny Zhou, Dipanjan Das, Jason Wei
5 | https://arxiv.org/abs/2210.03057 reference: https://github.com/google-research/url-nlp
6 | """
7 |
8 | import re
9 | from typing import Optional
10 |
11 | from . import common
12 | from .mmlu_eval import HTML_JINJA
13 | from .types import Eval, EvalResult, SamplerBase, SingleEvalResult
14 |
15 | ALL_LANGUAGES = ["bn", "de", "en", "es", "fr", "ja", "ru", "sw", "te", "th", "zh"]
16 | LATIN_LANGUAGES = ["de", "en", "es", "fr", "sw"]
17 | NON_LATIN_LANGUAGES = ["bn", "ja", "ru", "te", "th", "zh"]
18 |
19 | LANG_TO_FPATH = {
20 | "bn": "https://openaipublic.blob.core.windows.net/simple-evals/mgsm_bn.tsv",
21 | "de": "https://openaipublic.blob.core.windows.net/simple-evals/mgsm_de.tsv",
22 | "en": "https://openaipublic.blob.core.windows.net/simple-evals/mgsm_en.tsv",
23 | "es": "https://openaipublic.blob.core.windows.net/simple-evals/mgsm_es.tsv",
24 | "fr": "https://openaipublic.blob.core.windows.net/simple-evals/mgsm_fr.tsv",
25 | "ja": "https://openaipublic.blob.core.windows.net/simple-evals/mgsm_ja.tsv",
26 | "ru": "https://openaipublic.blob.core.windows.net/simple-evals/mgsm_ru.tsv",
27 | "sw": "https://openaipublic.blob.core.windows.net/simple-evals/mgsm_sw.tsv",
28 | "te": "https://openaipublic.blob.core.windows.net/simple-evals/mgsm_te.tsv",
29 | "th": "https://openaipublic.blob.core.windows.net/simple-evals/mgsm_th.tsv",
30 | "zh": "https://openaipublic.blob.core.windows.net/simple-evals/mgsm_zh.tsv",
31 | }
32 | LANG_TO_INSTRUCTIONS = {
33 | "en": """Solve this math problem. Give the reasoning steps before giving the final answer on the last line by itself in the format of "Answer:". Do not add anything other than the integer answer after "Answer:".
34 |
35 | {input}""",
36 | "bn": """এই গণিতের সমস্যাটি সমাধান করুন। চূড়ান্ত উত্তর দেওয়ার আগে যুক্তিসম্পন্ন পদক্ষেপ প্রদান করুন। চূড়ান্ত উত্তরটি একক সংখ্যা হিসাবে "উত্তর:" এর পরে শেষ লাইনে দিন। "উত্তর:" এর পরে অন্য কিছু যুক্ত করবেন না।.
37 |
38 | {input}""",
39 | "de": """Löse dieses Mathematikproblem. Gib die Schritte zur Begründung an, bevor du die endgültige Antwort in der letzten Zeile alleine im Format "Antwort:" gibst. Füge nichts anderes als die ganzzahlige Antwort nach "Antwort:" hinzu.
40 |
41 | {input}""",
42 | "es": """Resuelve este problema matemático. Proporciona los pasos de razonamiento antes de dar la respuesta final en la última línea por sí misma en el formato de "Respuesta:". No añadas nada más que la respuesta entera después de "Respuesta:".
43 |
44 | {input}""",
45 | "fr": """Résolvez ce problème de mathématiques. Donnez les étapes de raisonnement avant de fournir la réponse finale sur la dernière ligne elle-même dans le format de "Réponse:". N'ajoutez rien d'autre que la réponse entière après "Réponse:".
46 |
47 | {input}""",
48 | "ja": """の数学の問題を解いてください。最終的な答えを出す前に、解答の推論過程を記述してください。そして最後の行には "答え:" の形式で答えを記述し、その後には整数の答え以外何も追加しないでください。
49 |
50 | {input}""",
51 | "ru": """Решите эту математическую задачу. Объясните шаги рассуждения перед тем, как дать окончательный ответ в последней строке сам по себе в формате "Ответ:". Не добавляйте ничего, кроме целочисленного ответа после "Ответ:".
52 |
53 | {input}""",
54 | "sw": """Suluhisha tatizo hili la hesabu. Toa hatua za mantiki kabla ya kutoa jibu la mwisho kwenye mstari wa mwisho peke yake katika muundo wa "Jibu:". Usiongeze chochote kingine isipokuwa jibu la integer baada ya "Jibu:".
55 |
56 | {input}""",
57 | "te": """ఈ గణిత సమస్యను పరిష్కరించండి. చివరి సమాధానాన్ని ఇవ్వదానికి ముందు తర్కాత్మక అదుగులను ఇవ్వండి. చివరి పంక్తిలో మాత్రమే 'సమాధానం:' అనే ఆకారంలో చివరి సమాధానాద్ని ఇవ్వండి సమాధానం: తర్వాత పూర్ణాంక సమాధానానికి తప్పించి ఎదేనా చేర్చవద్దు.
58 |
59 | {input}""",
60 | "th": """แก้ปัญหาคณิตศาสตร์นี้ ให้ให้ขั้นตอนการใช้เหตุผลก่อนที่จะให้คำตอบสุดท้ายในบรรทัดสุดท้ายโดยอยู่ในรูปแบบ "คำตอบ:" ไม่ควรเพิ่มอะไรนอกจากคำตอบที่เป็นจำนวนเต็มหลังจาก "คำตอบ:"
61 |
62 | {input}""",
63 | "zh": """解决这个数学问题。在最后一行给出答案前,请提供推理步骤。最后一行应该以 "答案: " 的形式独立给出答案。在 "答案:" 后不要添加除整数答案之外的任何内容。
64 |
65 | {input}""",
66 | }
67 |
68 | LANG_TO_ANSWER_PREFIX = {
69 | "en": "Answer",
70 | "bn": "উত্তর",
71 | "de": "Antwort",
72 | "es": "Respuesta",
73 | "fr": "Réponse",
74 | "ja": "答え",
75 | "ru": "Ответ",
76 | "sw": "Jibu",
77 | "te": "సమాధానం",
78 | "th": "คำตอบ",
79 | "zh": "答案",
80 | }
81 |
82 |
83 | def parse_answer(answer: str, answer_prefix: str) -> str:
84 | if answer_prefix not in answer:
85 | return ""
86 |
87 | answer_text = answer.split(answer_prefix)[-1].strip()
88 |
89 | # find all the numbers (including decimals) in the string
90 | numbers = re.findall(r"\d+\.?\d*", answer_text.replace(",", ""))
91 |
92 |     # return the last number found (removing a trailing decimal point if present),
93 |     # or an empty string if there were no numbers
94 | return numbers[-1].rstrip(".") if numbers else ""
95 |
96 |
97 | def score_mgsm(target: str, prediction: str) -> bool:
98 | if "." in prediction:
99 | prediction = prediction.rstrip("0").rstrip(".")
100 |
101 | target = target.replace(",", "")
102 | prediction = prediction.replace(",", "")
103 |
104 | return target == prediction
105 |
106 |
107 | def get_lang_examples(lang: str) -> list[dict[str, str]]:
108 | fpath = LANG_TO_FPATH[lang]
109 | examples = []
110 | with common.url_to_fileobj(fpath, binary=True) as f:
111 | for raw_line in f:
112 | line = raw_line.decode("utf-8").strip()
113 | inputs, targets = line.split("\t")
114 | if "." in targets:
115 | raise ValueError(f"targets {targets} contains a decimal point.")
116 | # targets = int(targets.replace(",", ""))
117 | examples.append({"inputs": inputs, "targets": targets, "lang": lang})
118 | return examples
119 |
120 |
121 | def get_all_examples() -> list[dict[str, str]]:
122 | examples = []
123 | for lang in ALL_LANGUAGES:
124 | if lang != "en":
125 | continue
126 | examples += get_lang_examples(lang)
127 | return examples
128 |
129 |
130 | class MGSMEval(Eval):
131 | def __init__(
132 | self,
133 | num_examples_per_lang: int = 250, # restrict to a subset of the data for debugging
134 | languages: Optional[list[str]] = ALL_LANGUAGES,
135 | ):
136 | if languages is None:
137 | languages = ALL_LANGUAGES
138 | else:
139 | for language in languages:
140 | if language not in ALL_LANGUAGES:
141 | raise ValueError(
142 | f"language {language} is not a valid language. "
143 |                         f"It should be one of {ALL_LANGUAGES}"
144 | )
145 | self._languages = languages
146 | self._num_examples_per_lang = num_examples_per_lang
147 |
148 | examples = []
149 | for lang in self._languages:
150 | lang_examples = get_lang_examples(lang)
151 | examples.extend(lang_examples[: self._num_examples_per_lang])
152 | self.examples = examples
153 |
154 | def __call__(self, sampler: SamplerBase) -> EvalResult:
155 | def fn(example: dict[str, str]):
156 | language = example["lang"]
157 | latin_language = "group_latin" if language in LATIN_LANGUAGES else "group_non_latin"
158 | correct_answer = example["targets"]
159 | instruction = LANG_TO_INSTRUCTIONS[language]
160 | prompt_messages = [
161 | sampler._pack_message(
162 | content=instruction.format(input=example["inputs"]), role="user"
163 | )
164 | ]
165 | try:
166 | sampler_response = sampler(prompt_messages)
167 | response_text = sampler_response.response_text
168 | actual_queried_prompt_messages = sampler_response.actual_queried_message_list
169 |             except Exception:  # score sampler failures as an empty response
170 |                 response_text, actual_queried_prompt_messages = "", prompt_messages
171 |
172 | answer_prefix = LANG_TO_ANSWER_PREFIX[language]
173 | extracted_answer = parse_answer(response_text, answer_prefix)
174 |
175 | score = score_mgsm(correct_answer, extracted_answer)
176 | html = common.jinja_env.from_string(HTML_JINJA).render(
177 | prompt_messages=actual_queried_prompt_messages,
178 | next_message=dict(content=response_text, role="assistant"),
179 | score=score,
180 | correct_answer=correct_answer,
181 | extracted_answer=extracted_answer or None,
182 | )
183 | convo = actual_queried_prompt_messages + [dict(content=response_text, role="assistant")]
184 | return SingleEvalResult(
185 | html=html,
186 | score=score,
187 | convo=convo,
188 | metrics={language: score, latin_language: score},
189 | )
190 |
191 | results = common.map_with_progress(fn, self.examples)
192 | return common.aggregate_results(results, default_stats=("mean", "std"))
193 |
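194 | # Illustrative behaviour of the parsing/scoring helpers above (not exercised by the eval):
195 | #
196 | #     parse_answer("... reasoning ... Answer: 1,234.", "Answer")  -> "1234"
197 | #     score_mgsm("1234", "1234.0")  -> True   (trailing ".0" is stripped)
198 | #     score_mgsm("1234", "1,234")   -> True   (thousands separators are ignored)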
--------------------------------------------------------------------------------
/mmlu_eval.py:
--------------------------------------------------------------------------------
1 | """
2 | Measuring Massive Multitask Language Understanding
3 | Dan Hendrycks, Collin Burns, Steven Basart, Andy Zou, Mantas Mazeika, Dawn Song, Jacob Steinhardt
4 | https://arxiv.org/abs/2009.03300
5 | """
6 |
7 | import random
8 | import re
9 |
10 | import pandas
11 |
12 | from . import common
13 | from .common import (
14 | HTML_JINJA,
15 | MULTILINGUAL_ANSWER_PATTERN_TEMPLATE,
16 | MULTILINGUAL_ANSWER_REGEXES,
17 | format_multichoice_question,
18 | normalize_extracted_answer,
19 | normalize_response,
20 | )
21 | from .types import Eval, EvalResult, SamplerBase, SingleEvalResult
22 |
23 | subject2category = {
24 | "abstract_algebra": "stem",
25 | "anatomy": "other",
26 | "astronomy": "stem",
27 | "business_ethics": "other",
28 | "clinical_knowledge": "other",
29 | "college_biology": "stem",
30 | "college_chemistry": "stem",
31 | "college_computer_science": "stem",
32 | "college_mathematics": "stem",
33 | "college_medicine": "other",
34 | "college_physics": "stem",
35 | "computer_security": "stem",
36 | "conceptual_physics": "stem",
37 | "econometrics": "social_sciences",
38 | "electrical_engineering": "stem",
39 | "elementary_mathematics": "stem",
40 | "formal_logic": "humanities",
41 | "global_facts": "other",
42 | "high_school_biology": "stem",
43 | "high_school_chemistry": "stem",
44 | "high_school_computer_science": "stem",
45 | "high_school_european_history": "humanities",
46 | "high_school_geography": "social_sciences",
47 | "high_school_government_and_politics": "social_sciences",
48 | "high_school_macroeconomics": "social_sciences",
49 | "high_school_mathematics": "stem",
50 | "high_school_microeconomics": "social_sciences",
51 | "high_school_physics": "stem",
52 | "high_school_psychology": "social_sciences",
53 | "high_school_statistics": "stem",
54 | "high_school_us_history": "humanities",
55 | "high_school_world_history": "humanities",
56 | "human_aging": "other",
57 | "human_sexuality": "social_sciences",
58 | "international_law": "humanities",
59 | "jurisprudence": "humanities",
60 | "logical_fallacies": "humanities",
61 | "machine_learning": "stem",
62 | "management": "other",
63 | "marketing": "other",
64 | "medical_genetics": "other",
65 | "miscellaneous": "other",
66 | "moral_disputes": "humanities",
67 | "moral_scenarios": "humanities",
68 | "nutrition": "other",
69 | "philosophy": "humanities",
70 | "prehistory": "humanities",
71 | "professional_accounting": "other",
72 | "professional_law": "humanities",
73 | "professional_medicine": "other",
74 | "professional_psychology": "social_sciences",
75 | "public_relations": "social_sciences",
76 | "security_studies": "social_sciences",
77 | "sociology": "social_sciences",
78 | "us_foreign_policy": "social_sciences",
79 | "virology": "other",
80 | "world_religions": "humanities",
81 | }
82 |
83 |
84 | class MMLUEval(Eval):
85 | def __init__(self, num_examples: int | None = None, language: str = "EN-US"):
86 | if language != "EN-US":
87 | url = f"https://openaipublic.blob.core.windows.net/simple-evals/mmlu_{language}.csv"
88 | else:
89 | url = "https://openaipublic.blob.core.windows.net/simple-evals/mmlu.csv"
90 | df = pandas.read_csv(url)
91 | examples = [row.to_dict() for _, row in df.iterrows()]
92 | if num_examples:
93 | examples = random.Random(0).sample(examples, num_examples)
94 | self.examples = examples
95 |
96 | def __call__(self, sampler: SamplerBase) -> EvalResult:
97 | def fn(row: dict):
98 | prompt_messages = [
99 | sampler._pack_message(
100 | content=format_multichoice_question(row), role="user"
101 | )
102 | ]
103 | sampler_response = sampler(prompt_messages)
104 | response_text = sampler_response.response_text
105 | actual_queried_prompt_messages = sampler_response.actual_queried_message_list
106 | response_text = normalize_response(response_text)
107 | extracted_answer = None
108 | for answer_regex in MULTILINGUAL_ANSWER_REGEXES:
109 | regex = MULTILINGUAL_ANSWER_PATTERN_TEMPLATE.format(answer_regex)
110 | match = re.search(regex, response_text)
111 | if match:
112 | extracted_answer = normalize_extracted_answer(match.group(1))
113 | break
114 | score = 1.0 if extracted_answer == row["Answer"] else 0.0
115 | html = common.jinja_env.from_string(HTML_JINJA).render(
116 | prompt_messages=actual_queried_prompt_messages,
117 | next_message=dict(content=response_text, role="assistant"),
118 | score=score,
119 | correct_answer=row["Answer"],
120 | extracted_answer=extracted_answer,
121 | )
122 | convo = actual_queried_prompt_messages + [dict(content=response_text, role="assistant")]
123 | category = subject2category.get(row["Subject"], "other")
124 | return SingleEvalResult(
125 | html=html, score=score, metrics={category: score}, convo=convo
126 | )
127 |
128 | results = common.map_with_progress(fn, self.examples)
129 | return common.aggregate_results(results)
130 |
--------------------------------------------------------------------------------
/multilingual_mmlu_benchmark_results.md:
--------------------------------------------------------------------------------
1 | # Multilingual MMLU Benchmark Results
2 |
3 | To evaluate multilingual performance, we translated MMLU’s test set into 14 languages using professional human translators. Relying on human translators for this evaluation increases confidence in the accuracy of the translations, especially for low-resource languages like Yoruba.
4 |
5 | ## Results
6 |
7 |
8 | | Language | o3-high | o1 | o4-mini-high | o3-mini-high | gpt-4.5-preview-2025-02-27 | gpt-4.1-2025-04-14 | gpt-4o-2024-11-20 | gpt-4.1-mini-2025-04-14 | gpt-4o-mini-2024-07-18 | gpt-4.1-nano-2025-04-14 |
9 | | :------------------: | :---------: | :---: | :----------: | :----------: | :------------------------: | :----------------: | :---------------: | :---------------------: | :--------------------: | :---------------------: |
10 | | Arabic | **0.904** | 0.890 | 0.861 | 0.819 | 0.860 | 0.844 | 0.831 | 0.795 | 0.709 | 0.659 |
11 | | Bengali | **0.878** | 0.873 | 0.840 | 0.801 | 0.848 | 0.827 | 0.801 | 0.749 | 0.658 | 0.583 |
12 | | Chinese (Simplified) | **0.893** | 0.889 | 0.869 | 0.836 | 0.870 | 0.861 | 0.842 | 0.817 | 0.731 | 0.710 |
13 | | French | **0.906** | 0.893 | 0.874 | 0.837 | 0.878 | 0.870 | 0.846 | 0.835 | 0.766 | 0.739 |
14 | | German | **0.905** | 0.890 | 0.867 | 0.808 | 0.853 | 0.855 | 0.836 | 0.823 | 0.743 | 0.722 |
15 | | Hindi | **0.898** | 0.883 | 0.859 | 0.811 | 0.858 | 0.842 | 0.819 | 0.780 | 0.692 | 0.629 |
16 | | Indonesian | **0.898** | 0.886 | 0.869 | 0.828 | 0.872 | 0.859 | 0.840 | 0.816 | 0.745 | 0.714 |
17 | | Italian | **0.912** | 0.897 | 0.877 | 0.838 | 0.878 | 0.869 | 0.845 | 0.835 | 0.764 | 0.734 |
18 | | Japanese | **0.890** | 0.889 | 0.869 | 0.831 | 0.869 | 0.856 | 0.835 | 0.810 | 0.726 | 0.690 |
19 | | Korean | **0.893** | 0.882 | 0.867 | 0.826 | 0.860 | 0.849 | 0.829 | 0.801 | 0.720 | 0.679 |
20 | | Portuguese (Brazil) | **0.910** | 0.895 | 0.878 | 0.841 | 0.879 | 0.870 | 0.836 | 0.839 | 0.768 | 0.741 |
21 | | Spanish | **0.911** | 0.899 | 0.880 | 0.840 | 0.884 | 0.876 | 0.843 | 0.839 | 0.774 | 0.748 |
22 | | Swahili | **0.860** | 0.854 | 0.813 | 0.738 | 0.820 | 0.795 | 0.779 | 0.679 | 0.619 | 0.566 |
23 | | Yoruba | **0.780** | 0.754 | 0.708 | 0.637 | 0.682 | 0.647 | 0.621 | 0.566 | 0.458 | 0.455 |
24 | | Average | **0.888** | 0.877 | 0.852 | 0.807 | 0.851 | 0.837 | 0.814 | 0.785 | 0.705 | 0.669 |
25 |
26 | These results can be reproduced by running
27 |
28 | ```bash
29 | python -m simple-evals.run_multilingual_mmlu
30 | ```
31 |
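32 | For a quicker spot check of a single language, the same pieces can be wired up directly. A minimal sketch (the model name, example count, and the `simple_evals` import path are illustrative and assume the repo is importable as a package):
33 |
34 | ```python
35 | from simple_evals.mmlu_eval import MMLUEval
36 | from simple_evals.sampler.chat_completion_sampler import (
37 |     OPENAI_SYSTEM_MESSAGE_API,
38 |     ChatCompletionSampler,
39 | )
40 |
41 | sampler = ChatCompletionSampler(
42 |     model="gpt-4o-mini",
43 |     system_message=OPENAI_SYSTEM_MESSAGE_API,
44 |     max_tokens=2048,
45 | )
46 | result = MMLUEval(num_examples=10, language="SW-KE")(sampler)
47 | print(result.score, result.metrics)
48 | ```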
--------------------------------------------------------------------------------
/run_multilingual_mmlu.py:
--------------------------------------------------------------------------------
1 | import json
2 |
3 | import pandas as pd
4 |
5 | from . import common
6 | from .mmlu_eval import MMLUEval
7 | from .sampler.chat_completion_sampler import (
8 | OPENAI_SYSTEM_MESSAGE_API,
9 | OPENAI_SYSTEM_MESSAGE_CHATGPT,
10 | ChatCompletionSampler,
11 | )
12 | from .sampler.o_chat_completion_sampler import OChatCompletionSampler
13 |
14 |
15 | def main():
16 | debug = True
17 | samplers = {
18 | "gpt-4o_chatgpt": ChatCompletionSampler(
19 | model="gpt-4o",
20 | system_message=OPENAI_SYSTEM_MESSAGE_CHATGPT,
21 | max_tokens=2048,
22 | ),
23 | "gpt-4o-mini-2024-07-18": ChatCompletionSampler(
24 | model="gpt-4o-mini-2024-07-18",
25 | system_message=OPENAI_SYSTEM_MESSAGE_API,
26 | max_tokens=2048,
27 | ),
28 | "o1-preview": OChatCompletionSampler(
29 | model="o1-preview",
30 | ),
31 | "o1-mini": OChatCompletionSampler(
32 | model="o1-mini",
33 | ),
34 | # Default == Medium
35 | "o3-mini": OChatCompletionSampler(
36 | model="o3-mini",
37 | ),
38 | "o3-mini_high": OChatCompletionSampler(
39 | model="o3-mini",
40 | reasoning_effort="high",
41 | ),
42 | "o3-mini_low": OChatCompletionSampler(
43 | model="o3-mini",
44 | reasoning_effort="low",
45 | ),
46 | }
47 |
48 | def get_evals(eval_name):
49 | match eval_name:
50 | case "mmlu_EN-US":
51 | return MMLUEval(num_examples=10 if debug else None, language="EN-US")
52 | case "mmlu_AR-XY":
53 | return MMLUEval(num_examples=10 if debug else None, language="AR-XY")
54 | case "mmlu_BN-BD":
55 | return MMLUEval(num_examples=10 if debug else None, language="BN-BD")
56 | case "mmlu_DE-DE":
57 | return MMLUEval(num_examples=10 if debug else None, language="DE-DE")
58 | case "mmlu_ES-LA":
59 | return MMLUEval(num_examples=10 if debug else None, language="ES-LA")
60 | case "mmlu_FR-FR":
61 | return MMLUEval(num_examples=10 if debug else None, language="FR-FR")
62 | case "mmlu_HI-IN":
63 | return MMLUEval(num_examples=10 if debug else None, language="HI-IN")
64 | case "mmlu_ID-ID":
65 | return MMLUEval(num_examples=10 if debug else None, language="ID-ID")
66 | case "mmlu_IT-IT":
67 | return MMLUEval(num_examples=10 if debug else None, language="IT-IT")
68 | case "mmlu_JA-JP":
69 | return MMLUEval(num_examples=10 if debug else None, language="JA-JP")
70 | case "mmlu_KO-KR":
71 | return MMLUEval(num_examples=10 if debug else None, language="KO-KR")
72 | case "mmlu_PT-BR":
73 | return MMLUEval(num_examples=10 if debug else None, language="PT-BR")
74 | case "mmlu_ZH-CN":
75 | return MMLUEval(num_examples=10 if debug else None, language="ZH-CN")
76 | case "mmlu_SW-KE":
77 | return MMLUEval(num_examples=10 if debug else None, language="SW-KE")
78 | case "mmlu_YO-NG":
79 | return MMLUEval(num_examples=10 if debug else None, language="YO-NG")
80 | case _:
81 |                 raise Exception(f"Unrecognized eval type: {eval_name}")
82 |
83 | evals = {
84 | eval_name: get_evals(eval_name)
85 | for eval_name in [
86 | "mmlu_AR-XY",
87 | "mmlu_BN-BD",
88 | "mmlu_DE-DE",
89 | "mmlu_EN-US",
90 | "mmlu_ES-LA",
91 | "mmlu_FR-FR",
92 | "mmlu_HI-IN",
93 | "mmlu_ID-ID",
94 | "mmlu_IT-IT",
95 | "mmlu_JA-JP",
96 | "mmlu_KO-KR",
97 | "mmlu_PT-BR",
98 | "mmlu_ZH-CN",
99 | "mmlu_SW-KE",
100 | "mmlu_YO-NG",
101 | ]
102 | }
103 | print(evals)
104 | debug_suffix = "_DEBUG" if debug else ""
105 | mergekey2resultpath = {}
106 | for sampler_name, sampler in samplers.items():
107 | for eval_name, eval_obj in evals.items():
108 | result = eval_obj(sampler)
109 | # ^^^ how to use a sampler
110 | file_stem = f"{eval_name}_{sampler_name}"
111 | report_filename = f"/tmp/{file_stem}{debug_suffix}.html"
112 | print(f"Writing report to {report_filename}")
113 | with open(report_filename, "w") as fh:
114 | fh.write(common.make_report(result))
115 | metrics = result.metrics | {"score": result.score}
116 | print(metrics)
117 | result_filename = f"/tmp/{file_stem}{debug_suffix}.json"
118 | with open(result_filename, "w") as f:
119 | f.write(json.dumps(metrics, indent=2))
120 | print(f"Writing results to {result_filename}")
121 | mergekey2resultpath[f"{file_stem}"] = result_filename
122 | merge_metrics = []
123 | for eval_sampler_name, result_filename in mergekey2resultpath.items():
124 | try:
125 | result = json.load(open(result_filename, "r+"))
126 | except Exception as e:
127 | print(e, result_filename)
128 | continue
129 | result = result.get("f1_score", result.get("score", None))
130 |         eval_name = "_".join(eval_sampler_name.split("_", 2)[:2])  # e.g. "mmlu_AR-XY"
131 |         sampler_name = eval_sampler_name.split("_", 2)[2]
132 | merge_metrics.append(
133 | {"eval_name": eval_name, "sampler_name": sampler_name, "metric": result}
134 | )
135 | merge_metrics_df = pd.DataFrame(merge_metrics).pivot(
136 | index=["sampler_name"], columns="eval_name"
137 | )
138 | print("\nAll results: ")
139 | print(merge_metrics_df.to_markdown())
140 | return merge_metrics
141 |
142 |
143 | if __name__ == "__main__":
144 | main()
145 |
--------------------------------------------------------------------------------
/sampler/chat_completion_sampler.py:
--------------------------------------------------------------------------------
1 | import time
2 | from typing import Any
3 |
4 | import openai
5 | from openai import OpenAI
6 |
7 | from ..types import MessageList, SamplerBase, SamplerResponse
8 |
9 | OPENAI_SYSTEM_MESSAGE_API = "You are a helpful assistant."
10 | OPENAI_SYSTEM_MESSAGE_CHATGPT = (
11 | "You are ChatGPT, a large language model trained by OpenAI, based on the GPT-4 architecture."
12 | + "\nKnowledge cutoff: 2023-12\nCurrent date: 2024-04-01"
13 | )
14 |
15 |
16 | class ChatCompletionSampler(SamplerBase):
17 | """
18 | Sample from OpenAI's chat completion API
19 | """
20 |
21 | def __init__(
22 | self,
23 | model: str = "gpt-3.5-turbo",
24 | system_message: str | None = None,
25 | temperature: float = 0.5,
26 | max_tokens: int = 1024,
27 | ):
28 | self.api_key_name = "OPENAI_API_KEY"
29 | self.client = OpenAI()
30 | # using api_key=os.environ.get("OPENAI_API_KEY") # please set your API_KEY
31 | self.model = model
32 | self.system_message = system_message
33 | self.temperature = temperature
34 | self.max_tokens = max_tokens
35 | self.image_format = "url"
36 |
37 | def _handle_image(
38 | self,
39 | image: str,
40 | encoding: str = "base64",
41 | format: str = "png",
42 | fovea: int = 768,
43 | ):
44 | new_image = {
45 | "type": "image_url",
46 | "image_url": {
47 | "url": f"data:image/{format};{encoding},{image}",
48 | },
49 | }
50 | return new_image
51 |
52 | def _handle_text(self, text: str):
53 | return {"type": "text", "text": text}
54 |
55 | def _pack_message(self, role: str, content: Any):
56 | return {"role": str(role), "content": content}
57 |
58 | def __call__(self, message_list: MessageList) -> SamplerResponse:
59 | if self.system_message:
60 | message_list = [
61 | self._pack_message("system", self.system_message)
62 | ] + message_list
63 | trial = 0
64 | while True:
65 | try:
66 | response = self.client.chat.completions.create(
67 | model=self.model,
68 | messages=message_list,
69 | temperature=self.temperature,
70 | max_tokens=self.max_tokens,
71 | )
72 | content = response.choices[0].message.content
73 | if content is None:
74 | raise ValueError("OpenAI API returned empty response; retrying")
75 | return SamplerResponse(
76 | response_text=content,
77 | response_metadata={"usage": response.usage},
78 | actual_queried_message_list=message_list,
79 | )
80 |             # NOTE: BadRequestError is triggered once for MMMU, please uncomment if you are rerunning MMMU
81 | except openai.BadRequestError as e:
82 | print("Bad Request Error", e)
83 | return SamplerResponse(
84 | response_text="No response (bad request).",
85 | response_metadata={"usage": None},
86 | actual_queried_message_list=message_list,
87 | )
88 | except Exception as e:
89 |                 exception_backoff = 2**trial  # exponential back off
90 | print(
91 | f"Rate limit exception so wait and retry {trial} after {exception_backoff} sec",
92 | e,
93 | )
94 | time.sleep(exception_backoff)
95 | trial += 1
96 | # unknown error shall throw exception
97 |
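98 | # Illustrative direct call (a minimal sketch; assumes OPENAI_API_KEY is set and the
99 | # model name is available to your account):
100 | #
101 | #     sampler = ChatCompletionSampler(model="gpt-4o-mini", system_message=OPENAI_SYSTEM_MESSAGE_API)
102 | #     resp = sampler([sampler._pack_message("user", "What is 2 + 2?")])
103 | #     print(resp.response_text)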
--------------------------------------------------------------------------------
/sampler/claude_sampler.py:
--------------------------------------------------------------------------------
1 | import time
2 | import os
3 |
4 | import anthropic
5 |
6 | from ..types import MessageList, SamplerBase, SamplerResponse
7 | from .. import common
8 |
9 | CLAUDE_SYSTEM_MESSAGE_LMSYS = (
10 | "The assistant is Claude, created by Anthropic. The current date is "
11 | "{currentDateTime}. Claude's knowledge base was last updated in "
12 | "August 2023 and it answers user questions about events before "
13 | "August 2023 and after August 2023 the same way a highly informed "
14 | "individual from August 2023 would if they were talking to someone "
15 | "from {currentDateTime}. It should give concise responses to very "
16 | "simple questions, but provide thorough responses to more complex "
17 | "and open-ended questions. It is happy to help with writing, "
18 | "analysis, question answering, math, coding, and all sorts of other "
19 | "tasks. It uses markdown for coding. It does not mention this "
20 | "information about itself unless the information is directly "
21 | "pertinent to the human's query."
22 | ).format(currentDateTime="2024-04-01")
23 | # reference: https://github.com/lm-sys/FastChat/blob/7899355ebe32117fdae83985cf8ee476d2f4243f/fastchat/conversation.py#L894
24 |
25 |
26 | class ClaudeCompletionSampler(SamplerBase):
27 |
28 | def __init__(
29 | self,
30 | model: str,
31 | system_message: str | None = None,
32 | temperature: float = 0.0, # default in Anthropic example
33 | max_tokens: int = 4096,
34 | ):
35 | self.client = anthropic.Anthropic()
36 | self.api_key = os.environ.get("ANTHROPIC_API_KEY") # please set your API_KEY
37 | self.model = model
38 | self.system_message = system_message
39 | self.temperature = temperature
40 | self.max_tokens = max_tokens
41 | self.image_format = "base64"
42 |
43 | def _handle_image(
44 | self,
45 | image: str,
46 | encoding: str = "base64",
47 | format: str = "png",
48 | fovea: int = 768,
49 | ):
50 | new_image = {
51 | "type": "image",
52 | "source": {
53 | "type": encoding,
54 | "media_type": f"image/{format}",
55 | "data": image,
56 | },
57 | }
58 | return new_image
59 |
60 | def _handle_text(self, text):
61 | return {"type": "text", "text": text}
62 |
63 | def _pack_message(self, role, content):
64 | return {"role": str(role), "content": content}
65 |
66 | def __call__(self, message_list: MessageList) -> SamplerResponse:
67 | trial = 0
68 | while True:
69 | try:
70 | if not common.has_only_user_assistant_messages(message_list):
71 | raise ValueError(f"Claude sampler only supports user and assistant messages, got {message_list}")
72 | if self.system_message:
73 | response_message = self.client.messages.create(
74 | model=self.model,
75 | system=self.system_message,
76 | max_tokens=self.max_tokens,
77 | temperature=self.temperature,
78 | messages=message_list,
79 | )
80 | claude_input_messages: MessageList = [{"role": "system", "content": self.system_message}] + message_list
81 | else:
82 | response_message = self.client.messages.create(
83 | model=self.model,
84 | max_tokens=self.max_tokens,
85 | temperature=self.temperature,
86 | messages=message_list,
87 | )
88 | claude_input_messages = message_list
89 | response_text = response_message.content[0].text
90 | return SamplerResponse(
91 | response_text=response_text,
92 | response_metadata={},
93 | actual_queried_message_list=claude_input_messages,
94 | )
95 | except anthropic.RateLimitError as e:
96 |                 exception_backoff = 2**trial  # exponential back off
97 | print(
98 | f"Rate limit exception so wait and retry {trial} after {exception_backoff} sec",
99 | e,
100 | )
101 | time.sleep(exception_backoff)
102 | trial += 1
103 | # unknown error shall throw exception
104 |
--------------------------------------------------------------------------------
/sampler/o_chat_completion_sampler.py:
--------------------------------------------------------------------------------
1 | import time
2 | from typing import Any
3 |
4 | import openai
5 | from openai import OpenAI
6 |
7 | from ..types import MessageList, SamplerBase, SamplerResponse
8 |
9 |
10 | class OChatCompletionSampler(SamplerBase):
11 | """
12 | Sample from OpenAI's chat completion API for o series models
13 | """
14 |
15 | def __init__(
16 | self,
17 | *,
18 | reasoning_effort: str | None = None,
19 | model: str = "o1-mini",
20 | ):
21 | self.api_key_name = "OPENAI_API_KEY"
22 | self.client = OpenAI()
23 | # using api_key=os.environ.get("OPENAI_API_KEY") # please set your API_KEY
24 | self.model = model
25 | self.image_format = "url"
26 | self.reasoning_effort = reasoning_effort
27 |
28 | def _handle_image(
29 | self,
30 | image: str,
31 | encoding: str = "base64",
32 | format: str = "png",
33 | fovea: int = 768,
34 | ):
35 | new_image = {
36 | "type": "image_url",
37 | "image_url": {
38 | "url": f"data:image/{format};{encoding},{image}",
39 | },
40 | }
41 | return new_image
42 |
43 | def _handle_text(self, text: str):
44 | return {"type": "text", "text": text}
45 |
46 | def _pack_message(self, role: str, content: Any):
47 | return {"role": str(role), "content": content}
48 |
49 | def __call__(self, message_list: MessageList) -> SamplerResponse:
50 | trial = 0
51 | while True:
52 | try:
53 | response = self.client.chat.completions.create(
54 | model=self.model,
55 | messages=message_list,
56 | reasoning_effort=self.reasoning_effort,
57 | )
58 | content = response.choices[0].message.content
59 | return SamplerResponse(
60 | response_text=content,
61 | response_metadata={"usage": response.usage},
62 | actual_queried_message_list=message_list,
63 | )
64 |             # NOTE: BadRequestError is triggered once for MMMU, please uncomment if you are rerunning MMMU
65 | except openai.BadRequestError as e:
66 | print("Bad Request Error", e)
67 | return SamplerResponse(
68 | response_text="",
69 | response_metadata={"usage": None},
70 | actual_queried_message_list=message_list,
71 | )
72 | except Exception as e:
73 |                 exception_backoff = 2**trial  # exponential back off
74 | print(
75 | f"Rate limit exception so wait and retry {trial} after {exception_backoff} sec",
76 | e,
77 | )
78 | time.sleep(exception_backoff)
79 | trial += 1
80 | # unknown error shall throw exception
81 |
--------------------------------------------------------------------------------
/sampler/responses_sampler.py:
--------------------------------------------------------------------------------
1 | import os
2 | import time
3 | from typing import Any
4 |
5 | import openai
6 | from openai import OpenAI
7 |
8 | from ..types import MessageList, SamplerBase, SamplerResponse
9 |
10 |
11 | class ResponsesSampler(SamplerBase):
12 | """
13 | Sample from OpenAI's responses API
14 | """
15 |
16 | def __init__(
17 | self,
18 | model: str = "gpt-4.1",
19 | system_message: str | None = None,
20 | temperature: float = 0.5,
21 | max_tokens: int = 1024,
22 | reasoning_model: bool = False,
23 | reasoning_effort: str | None = None,
24 | ):
25 | self.api_key_name = "OPENAI_API_KEY"
26 | assert os.environ.get("OPENAI_API_KEY"), "Please set OPENAI_API_KEY"
27 | self.client = OpenAI()
28 | self.model = model
29 | self.system_message = system_message
30 | self.temperature = temperature
31 | self.max_tokens = max_tokens
32 | self.image_format = "url"
33 | self.reasoning_model = reasoning_model
34 | self.reasoning_effort = reasoning_effort
35 |
36 | def _handle_image(
37 | self,
38 | image: str,
39 | encoding: str = "base64",
40 | format: str = "png",
41 | fovea: int = 768,
42 | ) -> dict[str, Any]:
43 | new_image = {
44 | "type": "input_image",
45 | "image_url": f"data:image/{format};{encoding},{image}",
46 | }
47 | return new_image
48 |
49 | def _handle_text(self, text: str) -> dict[str, Any]:
50 | return {"type": "input_text", "text": text}
51 |
52 | def _pack_message(self, role: str, content: Any) -> dict[str, Any]:
53 | return {"role": role, "content": content}
54 |
55 | def __call__(self, message_list: MessageList) -> SamplerResponse:
56 | if self.system_message:
57 | message_list = [
58 | self._pack_message("developer", self.system_message)
59 | ] + message_list
60 | trial = 0
61 | while True:
62 | try:
63 | if self.reasoning_model:
64 | reasoning = (
65 | {"effort": self.reasoning_effort}
66 | if self.reasoning_effort
67 | else None
68 | )
69 | response = self.client.responses.create(
70 | model=self.model,
71 | input=message_list,
72 | reasoning=reasoning,
73 | )
74 | else:
75 | response = self.client.responses.create(
76 | model=self.model,
77 | input=message_list,
78 | temperature=self.temperature,
79 | max_output_tokens=self.max_tokens,
80 | )
81 | return SamplerResponse(
82 | response_text=response.output_text,
83 | response_metadata={"usage": response.usage},
84 | actual_queried_message_list=message_list,
85 | )
86 | except openai.BadRequestError as e:
87 | print("Bad Request Error", e)
88 | return SamplerResponse(
89 | response_text="",
90 | response_metadata={"usage": None},
91 | actual_queried_message_list=message_list,
92 | )
93 | except Exception as e:
94 |                 exception_backoff = 2**trial  # exponential back off
95 | print(
96 | f"Rate limit exception so wait and retry {trial} after {exception_backoff} sec",
97 | e,
98 | )
99 | time.sleep(exception_backoff)
100 | trial += 1
101 | # unknown error shall throw exception
102 |
--------------------------------------------------------------------------------
/simple_evals.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import json
3 | import subprocess
4 | from datetime import datetime
5 |
6 | import pandas as pd
7 |
8 | from . import common
9 | from .browsecomp_eval import BrowseCompEval
10 | from .drop_eval import DropEval
11 | from .gpqa_eval import GPQAEval
12 | from .healthbench_eval import HealthBenchEval
13 | from .healthbench_meta_eval import HealthBenchMetaEval
14 | from .math_eval import MathEval
15 | from .mgsm_eval import MGSMEval
16 | from .mmlu_eval import MMLUEval
17 | from .humaneval_eval import HumanEval
18 | from .sampler.chat_completion_sampler import (
19 | OPENAI_SYSTEM_MESSAGE_API,
20 | OPENAI_SYSTEM_MESSAGE_CHATGPT,
21 | ChatCompletionSampler,
22 | )
23 | from .sampler.claude_sampler import ClaudeCompletionSampler, CLAUDE_SYSTEM_MESSAGE_LMSYS
24 | from .sampler.o_chat_completion_sampler import OChatCompletionSampler
25 | from .sampler.responses_sampler import ResponsesSampler
26 | from .simpleqa_eval import SimpleQAEval
27 |
28 |
29 | def main():
30 | parser = argparse.ArgumentParser(
31 | description="Run sampling and evaluations using different samplers and evaluations."
32 | )
33 | parser.add_argument(
34 | "--list-models", action="store_true", help="List available models"
35 | )
36 | parser.add_argument(
37 | "--model",
38 | type=str,
39 | help="Select a model by name. Also accepts a comma-separated list of models.",
40 | )
41 | parser.add_argument(
42 | "--eval",
43 | type=str,
44 | help="Select an eval by name. Also accepts a comma-separated list of evals.",
45 | )
46 | parser.add_argument(
47 | "--n-repeats",
48 | type=int,
49 | default=None,
50 | help="Number of repeats to run. Only supported for certain evals.",
51 | )
52 | parser.add_argument(
53 | "--n-threads",
54 | type=int,
55 | default=120,
56 | help="Number of threads to run. Only supported for HealthBench and HealthBenchMeta.",
57 | )
58 | parser.add_argument("--debug", action="store_true", help="Run in debug mode")
59 | parser.add_argument(
60 | "--examples", type=int, help="Number of examples to use (overrides default)"
61 | )
62 |
63 | args = parser.parse_args()
64 |
65 | models = {
66 | # Reasoning Models
67 | "o3": ResponsesSampler(
68 | model="o3-2025-04-16",
69 | reasoning_model=True,
70 | ),
71 | "o3-temp-1": ResponsesSampler(
72 | model="o3-2025-04-16",
73 | reasoning_model=True,
74 | temperature=1.0,
75 | ),
76 | "o3_high": ResponsesSampler(
77 | model="o3-2025-04-16",
78 | reasoning_model=True,
79 | reasoning_effort="high",
80 | ),
81 | "o3_low": ResponsesSampler(
82 | model="o3-2025-04-16",
83 | reasoning_model=True,
84 | reasoning_effort="low",
85 | ),
86 | # Default == Medium
87 | "o4-mini": ResponsesSampler(
88 | model="o4-mini-2025-04-16",
89 | reasoning_model=True,
90 | ),
91 | "o4-mini_high": ResponsesSampler(
92 | model="o4-mini-2025-04-16",
93 | reasoning_model=True,
94 | reasoning_effort="high",
95 | ),
96 | "o4-mini_low": ResponsesSampler(
97 | model="o4-mini-2025-04-16",
98 | reasoning_model=True,
99 | reasoning_effort="low",
100 | ),
101 | "o1-pro": ResponsesSampler(
102 | model="o1-pro",
103 | reasoning_model=True,
104 | ),
105 | "o1": OChatCompletionSampler(
106 | model="o1",
107 | ),
108 | "o1_high": OChatCompletionSampler(
109 | model="o1",
110 | reasoning_effort="high",
111 | ),
112 | "o1_low": OChatCompletionSampler(
113 | model="o1",
114 | reasoning_effort="low",
115 | ),
116 | "o1-preview": OChatCompletionSampler(
117 | model="o1-preview",
118 | ),
119 | "o1-mini": OChatCompletionSampler(
120 | model="o1-mini",
121 | ),
122 | # Default == Medium
123 | "o3-mini": OChatCompletionSampler(
124 | model="o3-mini",
125 | ),
126 | "o3-mini_high": OChatCompletionSampler(
127 | model="o3-mini",
128 | reasoning_effort="high",
129 | ),
130 | "o3-mini_low": OChatCompletionSampler(
131 | model="o3-mini",
132 | reasoning_effort="low",
133 | ),
134 | # GPT-4.1 models
135 | "gpt-4.1": ChatCompletionSampler(
136 | model="gpt-4.1-2025-04-14",
137 | system_message=OPENAI_SYSTEM_MESSAGE_API,
138 | max_tokens=2048,
139 | ),
140 | "gpt-4.1-temp-1": ChatCompletionSampler(
141 | model="gpt-4.1-2025-04-14",
142 | system_message=OPENAI_SYSTEM_MESSAGE_API,
143 | max_tokens=2048,
144 | temperature=1.0,
145 | ),
146 | "gpt-4.1-mini": ChatCompletionSampler(
147 | model="gpt-4.1-mini-2025-04-14",
148 | system_message=OPENAI_SYSTEM_MESSAGE_API,
149 | max_tokens=2048,
150 | ),
151 | "gpt-4.1-nano": ChatCompletionSampler(
152 | model="gpt-4.1-nano-2025-04-14",
153 | system_message=OPENAI_SYSTEM_MESSAGE_API,
154 | max_tokens=2048,
155 | ),
156 | # GPT-4o models
157 | "gpt-4o": ChatCompletionSampler(
158 | model="gpt-4o",
159 | system_message=OPENAI_SYSTEM_MESSAGE_API,
160 | max_tokens=2048,
161 | ),
162 | "gpt-4o-2024-11-20": ChatCompletionSampler(
163 | model="gpt-4o-2024-11-20",
164 | system_message=OPENAI_SYSTEM_MESSAGE_API,
165 | max_tokens=2048,
166 | ),
167 | "gpt-4o-2024-08-06": ChatCompletionSampler(
168 | model="gpt-4o-2024-08-06",
169 | system_message=OPENAI_SYSTEM_MESSAGE_API,
170 | max_tokens=2048,
171 | ),
172 | "gpt-4o-2024-08-06-temp-1": ChatCompletionSampler(
173 | model="gpt-4o-2024-08-06",
174 | system_message=OPENAI_SYSTEM_MESSAGE_API,
175 | max_tokens=2048,
176 | temperature=1.0,
177 | ),
178 | "gpt-4o-2024-05-13": ChatCompletionSampler(
179 | model="gpt-4o-2024-05-13",
180 | system_message=OPENAI_SYSTEM_MESSAGE_API,
181 | max_tokens=2048,
182 | ),
183 | "gpt-4o-mini": ChatCompletionSampler(
184 | model="gpt-4o-mini-2024-07-18",
185 | system_message=OPENAI_SYSTEM_MESSAGE_API,
186 | max_tokens=2048,
187 | ),
188 | # GPT-4.5 model
189 | "gpt-4.5-preview": ChatCompletionSampler(
190 | model="gpt-4.5-preview-2025-02-27",
191 | system_message=OPENAI_SYSTEM_MESSAGE_API,
192 | max_tokens=2048,
193 | ),
194 | # GPT-4-turbo model
195 | "gpt-4-turbo-2024-04-09": ChatCompletionSampler(
196 | model="gpt-4-turbo-2024-04-09",
197 | system_message=OPENAI_SYSTEM_MESSAGE_API,
198 | ),
199 | # GPT-4 model
200 | "gpt-4-0613": ChatCompletionSampler(
201 | model="gpt-4-0613",
202 | system_message=OPENAI_SYSTEM_MESSAGE_API,
203 | ),
204 | # GPT-3.5 Turbo model
205 | "gpt-3.5-turbo-0125": ChatCompletionSampler(
206 | model="gpt-3.5-turbo-0125",
207 | system_message=OPENAI_SYSTEM_MESSAGE_API,
208 | ),
209 | "gpt-3.5-turbo-0125-temp-1": ChatCompletionSampler(
210 | model="gpt-3.5-turbo-0125",
211 | system_message=OPENAI_SYSTEM_MESSAGE_API,
212 | temperature=1.0,
213 | ),
214 | # Chatgpt models:
215 | "chatgpt-4o-latest": ChatCompletionSampler(
216 | model="chatgpt-4o-latest",
217 | system_message=OPENAI_SYSTEM_MESSAGE_CHATGPT,
218 | max_tokens=2048,
219 | ),
220 | "gpt-4-turbo-2024-04-09_chatgpt": ChatCompletionSampler(
221 | model="gpt-4-turbo-2024-04-09",
222 | system_message=OPENAI_SYSTEM_MESSAGE_CHATGPT,
223 | ),
224 | # Claude models:
225 | "claude-3-opus-20240229_empty": ClaudeCompletionSampler(
226 | model="claude-3-opus-20240229",
227 | system_message=CLAUDE_SYSTEM_MESSAGE_LMSYS,
228 | ),
229 | "claude-3-7-sonnet-20250219": ClaudeCompletionSampler(
230 | model="claude-3-7-sonnet-20250219",
231 | system_message=CLAUDE_SYSTEM_MESSAGE_LMSYS,
232 | ),
233 | "claude-3-haiku-20240307": ClaudeCompletionSampler(
234 | model="claude-3-haiku-20240307",
235 | ),
236 | }
237 |
238 | if args.list_models:
239 | print("Available models:")
240 | for model_name in models.keys():
241 | print(f" - {model_name}")
242 | return
243 |
244 | if args.model:
245 | models_chosen = args.model.split(",")
246 | for model_name in models_chosen:
247 | if model_name not in models:
248 | print(f"Error: Model '{model_name}' not found.")
249 | return
250 | models = {model_name: models[model_name] for model_name in models_chosen}
251 |
252 | print(f"Running with args {args}")
253 |
254 | grading_sampler = ChatCompletionSampler(
255 | model="gpt-4.1-2025-04-14",
256 | system_message=OPENAI_SYSTEM_MESSAGE_API,
257 | max_tokens=2048,
258 | )
259 | equality_checker = ChatCompletionSampler(model="gpt-4-turbo-preview")
260 | # ^^^ used for fuzzy matching, just for math
261 |
262 | def get_evals(eval_name, debug_mode):
263 | num_examples = (
264 | args.examples if args.examples is not None else (5 if debug_mode else None)
265 | )
266 | # Set num_examples = None to reproduce full evals
267 | match eval_name:
268 | case "mmlu":
269 | return MMLUEval(num_examples=1 if debug_mode else num_examples)
270 | case "math":
271 | return MathEval(
272 | equality_checker=equality_checker,
273 | num_examples=num_examples,
274 | n_repeats=1 if debug_mode else args.n_repeats or 10,
275 | )
276 | case "gpqa":
277 | return GPQAEval(
278 | n_repeats=1 if debug_mode else args.n_repeats or 10,
279 | num_examples=num_examples,
280 | )
281 | case "mgsm":
282 | return MGSMEval(
283 | num_examples_per_lang=10 if debug_mode else num_examples or 250
284 | )
285 | case "drop":
286 | return DropEval(
287 | num_examples=10 if debug_mode else num_examples,
288 | train_samples_per_prompt=3,
289 | )
290 | case "humaneval":
291 | return HumanEval(num_examples=10 if debug_mode else num_examples)
292 | case "simpleqa":
293 | return SimpleQAEval(
294 | grader_model=grading_sampler,
295 | num_examples=10 if debug_mode else num_examples,
296 | )
297 | case "browsecomp":
298 | return BrowseCompEval(
299 | grader_model=grading_sampler,
300 | num_examples=10 if debug_mode else num_examples,
301 | )
302 | case "healthbench":
303 | return HealthBenchEval(
304 | grader_model=grading_sampler,
305 | num_examples=10 if debug_mode else num_examples,
306 | n_repeats=args.n_repeats or 1,
307 | n_threads=args.n_threads or 1,
308 | subset_name=None,
309 | )
310 | case "healthbench_hard":
311 | return HealthBenchEval(
312 | grader_model=grading_sampler,
313 | num_examples=10 if debug_mode else num_examples,
314 | n_repeats=args.n_repeats or 1,
315 | n_threads=args.n_threads or 1,
316 | subset_name="hard",
317 | )
318 | case "healthbench_consensus":
319 | return HealthBenchEval(
320 | grader_model=grading_sampler,
321 | num_examples=10 if debug_mode else num_examples,
322 | n_repeats=args.n_repeats or 1,
323 | n_threads=args.n_threads or 1,
324 | subset_name="consensus",
325 | )
326 | case "healthbench_meta":
327 | return HealthBenchMetaEval(
328 | grader_model=grading_sampler,
329 | num_examples=10 if debug_mode else num_examples,
330 | n_repeats=args.n_repeats or 1,
331 | n_threads=args.n_threads or 1,
332 | )
333 | case _:
334 | raise Exception(f"Unrecognized eval type: {eval_name}")
335 |
336 | if args.eval:
337 | evals_list = args.eval.split(",")
338 | evals = {}
339 | for eval_name in evals_list:
340 | try:
341 | evals[eval_name] = get_evals(eval_name, args.debug)
342 | except Exception as e:
343 | print(f"Error: could not construct eval '{eval_name}': {e}")
344 | return
345 | else:
346 | evals = {
347 | eval_name: get_evals(eval_name, args.debug)
348 | for eval_name in [
349 | "mmlu",
350 | "math",
351 | "gpqa",
352 | "mgsm",
353 | "drop",
354 | "humaneval",
355 | "simpleqa",
356 | "browsecomp",
357 | "healthbench",
358 | "healthbench_hard",
359 | "healthbench_consensus",
360 | "healthbench_meta",
361 | ]
362 | }
363 |
364 | print(evals)
365 | debug_suffix = "_DEBUG" if args.debug else ""
366 | print(debug_suffix)
367 | mergekey2resultpath = {}
368 | print(f"Running the following evals: {list(evals.keys())}")
369 | print(f"Running evals for the following models: {list(models.keys())}")
370 |
371 | now = datetime.now()
372 | date_str = now.strftime("%Y%m%d_%H%M%S")
373 | for model_name, sampler in models.items():
374 | for eval_name, eval_obj in evals.items():
375 | result = eval_obj(sampler)
376 | # ^^^ how to use a sampler
377 | file_stem = f"{eval_name}_{model_name}"
378 | # include the run date and time (YYYYMMDD_HHMMSS) in the file stem
379 | file_stem += f"_{date_str}"
380 | report_filename = f"/tmp/{file_stem}{debug_suffix}.html"
381 | print(f"Writing report to {report_filename}")
382 | with open(report_filename, "w") as fh:
383 | fh.write(common.make_report(result))
384 | assert result.metrics is not None
385 | metrics = result.metrics | {"score": result.score}
386 | # Sort metrics by key
387 | metrics = dict(sorted(metrics.items()))
388 | print(metrics)
389 | result_filename = f"/tmp/{file_stem}{debug_suffix}.json"
390 | with open(result_filename, "w") as f:
391 | f.write(json.dumps(metrics, indent=2))
392 | print(f"Writing results to {result_filename}")
393 |
394 | full_result_filename = f"/tmp/{file_stem}{debug_suffix}_allresults.json"
395 | with open(full_result_filename, "w") as f:
396 | result_dict = {
397 | "score": result.score,
398 | "metrics": result.metrics,
399 | "htmls": result.htmls,
400 | "convos": result.convos,
401 | "metadata": result.metadata,
402 | }
403 | f.write(json.dumps(result_dict, indent=2))
404 | print(f"Writing all results to {full_result_filename}")
405 |
406 | mergekey2resultpath[(eval_name, model_name)] = result_filename
407 | merge_metrics = []
408 | for (eval_name, model_name), result_filename in mergekey2resultpath.items():
409 | try:
410 | with open(result_filename, "r") as result_fh:
411 | result = json.load(result_fh)
412 | except Exception as e:
413 | print(e, result_filename)
414 | continue
415 | result = result.get("f1_score", result.get("score", None))
417 | merge_metrics.append(
418 | {"eval_name": eval_name, "model_name": model_name, "metric": result}
419 | )
420 | merge_metrics_df = pd.DataFrame(merge_metrics).pivot(
421 | index=["model_name"], columns="eval_name"
422 | )
423 | print("\nAll results: ")
424 | print(merge_metrics_df.to_markdown())
425 | return merge_metrics
426 |
427 |
428 | if __name__ == "__main__":
429 | main()
430 |
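# Example invocations (illustrative; the flag spellings are assumed from the
# args.* attributes used above -- args.list_models, args.model, args.eval,
# args.examples, args.n_repeats, args.n_threads, args.debug -- under the usual
# argparse dest naming):
#
#   python <this_script>.py --list-models
#   python <this_script>.py --model gpt-4-0613 --eval mmlu,math --debug
#   python <this_script>.py --model gpt-3.5-turbo-0125 --eval healthbench --n-repeats 2 --n-threads 4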
--------------------------------------------------------------------------------
/simpleqa_eval.py:
--------------------------------------------------------------------------------
1 | """
2 | SimpleQA: Measuring short-form factuality in large language models
3 | Authors: Jason Wei, Nguyen Karina, Hyung Won Chung, Yunxin Joy Jiao, Spencer Papay, Amelia Glaese, John Schulman, William Fedus
4 | https://cdn.openai.com/papers/simpleqa.pdf
5 | """
6 |
7 | import random
8 | import re
9 | import pandas
10 | from . import common
11 | from .types import Eval, EvalResult, SamplerBase, SingleEvalResult
12 |
13 | GRADER_TEMPLATE = """
14 | Your job is to look at a question, a gold target, and a predicted answer, and then assign a grade of either ["CORRECT", "INCORRECT", "NOT_ATTEMPTED"].
15 | First, I will give examples of each grade, and then you will grade a new example.
16 |
17 |
18 | The following are examples of CORRECT predicted answers.
19 | ```
20 | Question: What are the names of Barack Obama's children?
21 | Gold target: Malia Obama and Sasha Obama
22 | Predicted answer 1: sasha and malia obama
23 | Predicted answer 2: most people would say Malia and Sasha, but I'm not sure and would have to double check
24 | Predicted answer 3: Barack Obama has two daughters. Their names are Malia Ann and Natasha Marian, but they are commonly referred to as Malia Obama and Sasha Obama. Malia was born on July 4, 1998, and Sasha was born on June 10, 2001.
25 | ```
26 | These predicted answers are all CORRECT because:
27 | - They fully contain the important information in the gold target.
28 | - They do not contain any information that contradicts the gold target.
29 | - Only semantic meaning matters; capitalization, punctuation, grammar, and order don't matter.
30 | - Hedging and guessing are permissible, provided that the gold target is fully included and the response contains no incorrect information or contradictions.
31 |
32 |
33 | The following are examples of INCORRECT predicted answers.
34 | ```
35 | Question: What are the names of Barack Obama's children?
36 | Gold target: Malia and Sasha
37 | Predicted answer 1: Malia.
38 | Predicted answer 2: Malia, Sasha, and Susan.
39 | Predicted answer 3: Barack Obama does not have any children.
40 | Predicted answer 4: I think it's either Malia and Sasha. Or it could be Malia and Jackie. Or it could be Joey and Malia.
41 | Predicted answer 5: While I don't know their exact names, I can tell you that Barack Obama has three children.
42 | Predicted answer 6: It's possible you may mean Betsy and Olivia. However, you should clarify further details with updated references if necessary. Is that the correct answer?
43 | Predicted answer 7: It may be the case that Obama's child is named James. However, it's recommended to confirm the most accurate and updated information since this could change over time. This model may not always reflect the most current information.
44 | ```
45 | These predicted answers are all INCORRECT because:
46 | - A factual statement in the answer contradicts the gold target. Incorrect statements that have some hedging (e.g., "it is possible that", "although i'm not sure, i think") are also considered incorrect.
47 |
48 |
49 | The following are examples of NOT_ATTEMPTED predicted answers.
50 | ```
51 | Question: What are the names of Barack Obama's children?
52 | Gold target: Malia and Sasha
53 | Predicted answer 1: I don't know.
54 | Predicted answer 2: I need more context about which Obama you are talking about.
55 | Predicted answer 3: Without researching the web, I cannot answer this question. However, I can tell you that Barack Obama has two children.
56 | Predicted answer 4: Barack Obama has two children. I know that one of them is Malia, but I'm not sure about the other one.
57 | ```
58 | These predicted answers are all NOT_ATTEMPTED because:
59 | - The important information in the gold target is not included in the answer.
60 | - No statements in the answer contradict the gold target.
61 |
62 |
63 | Also note the following things:
64 | - For grading questions where the gold target is a number, the predicted answer needs to be correct to the last significant figure in the gold answer. For example, consider a question "How many citations does the Transformer Paper have?" with gold target "120k".
65 | - Predicted answers "120k", "124k", and "115k" are all CORRECT.
66 | - Predicted answers "100k" and "113k" are INCORRECT.
67 | - Predicted answers "around 100k" and "more than 50k" are considered NOT_ATTEMPTED because they neither confirm nor contradict the gold target.
68 | - The gold target may contain more information than the question. In such cases, the predicted answer only needs to contain the information that is in the question.
69 | - For example, consider the question "What episode did Derek and Meredith get legally married in Grey's Anatomy?" with gold target "Season 7, Episode 20: White Wedding". Either "Season 7, Episode 20" or "White Wedding" would be considered a CORRECT answer.
70 | - Do not punish predicted answers if they omit information that would be clearly inferred from the question.
71 | - For example, consider the question "What city is OpenAI headquartered in?" and the gold target "San Francisco, California". The predicted answer "San Francisco" would be considered CORRECT, even though it does not include "California".
72 | - Consider the question "What award did A pretrainer's guide to training data: Measuring the effects of data age, domain coverage, quality, & toxicity win at NAACL '24?", the gold target is "Outstanding Paper Award". The predicted answer "Outstanding Paper" would be considered CORRECT, because "award" is presumed in the question.
73 | - For the question "What is the height of Jason Wei in meters?", the gold target is "1.73 m". The predicted answer "1.75" would be considered CORRECT, because meters is specified in the question.
74 | - For the question "What is the name of Barack Obama's wife?", the gold target is "Michelle Obama". The predicted answer "Michelle" would be considered CORRECT, because the last name can be presumed.
75 | - Do not punish for typos in people's name if it's clearly the same name.
76 | - For example, if the gold target is "Hyung Won Chung", you can consider the following predicted answers as correct: "Hyoong Won Choong", "Hyungwon Chung", or "Hyun Won Chung".
77 |
78 |
79 | Here is a new example. Simply reply with one of CORRECT, INCORRECT, or NOT_ATTEMPTED. Don't apologize or correct yourself if there was a mistake; we are just trying to grade the answer.
80 | ```
81 | Question: {question}
82 | Gold target: {target}
83 | Predicted answer: {predicted_answer}
84 | ```
85 |
86 | Grade the predicted answer of this new question as one of:
87 | A: CORRECT
88 | B: INCORRECT
89 | C: NOT_ATTEMPTED
90 |
91 | Just return the letters "A", "B", or "C", with no text around it.
92 | """.strip()
93 |
94 |
95 | CHOICE_LETTERS = ["A", "B", "C"]
96 | CHOICE_STRINGS = ["CORRECT", "INCORRECT", "NOT_ATTEMPTED"]
97 | CHOICE_LETTER_TO_STRING = dict(zip(CHOICE_LETTERS, CHOICE_STRINGS))
98 |
99 | class SimpleQAEval(Eval):
100 | def __init__(self, grader_model: SamplerBase, num_examples: int | None = None, n_repeats: int = 1):
101 | df = pandas.read_csv(
102 | "https://openaipublic.blob.core.windows.net/simple-evals/simple_qa_test_set.csv"
103 | )
104 | examples = [row.to_dict() for _, row in df.iterrows()]
105 | if num_examples:
106 | assert n_repeats == 1, "n_repeats > 1 is only supported when num_examples is None"
107 | rng = random.Random(0)
108 | examples = rng.sample(examples, num_examples)
109 | self.examples = examples * n_repeats
110 | self.grader_model = grader_model
111 |
112 | def grade_sample(self, question: str, target: str, predicted_answer: str) -> str:
113 | grader_prompt = GRADER_TEMPLATE.format(
114 | question=question,
115 | target=target,
116 | predicted_answer=predicted_answer,
117 | )
118 |
119 | prompt_messages = [
120 | self.grader_model._pack_message(content=grader_prompt, role="user")
121 | ]
122 | sampler_response = self.grader_model(prompt_messages)
123 | grading_response = sampler_response.response_text
124 |
125 | match = re.search(r"(A|B|C)", grading_response)
126 | return match.group(0) if match else "C" # Default to "NOT_ATTEMPTED" if no match
127 |
128 | def __call__(self, sampler: SamplerBase) -> EvalResult:
129 | def fn(row: dict):
130 | prompt_messages = [
131 | sampler._pack_message(content=row.get("problem", ""), role="user")
132 | ]
133 | sampler_response = sampler(prompt_messages)
134 | response_text = sampler_response.response_text
135 | actual_queried_prompt_messages = sampler_response.actual_queried_message_list
136 | grade_letter = self.grade_sample(row.get("problem", ""), row.get("answer", ""), response_text)
137 |
138 | # Metrics based on grading response
139 | is_correct = grade_letter == "A"
140 | is_incorrect = grade_letter == "B"
141 | is_not_attempted = grade_letter == "C"
142 |
143 | score = is_correct
144 |
145 | # Create HTML for each sample result
146 | html = common.jinja_env.from_string(common.HTML_JINJA).render(
147 | prompt_messages=actual_queried_prompt_messages,
148 | next_message=dict(content=response_text, role="assistant"),
149 | score=score,
150 | correct_answer=row["answer"],
151 | extracted_answer=response_text,
152 | )
153 | convo = actual_queried_prompt_messages + [dict(content=response_text, role="assistant")]
154 | return SingleEvalResult(html=html, score=score, convo=convo, metrics={
155 | "is_correct": is_correct,
156 | "is_incorrect": is_incorrect,
157 | "is_not_attempted": is_not_attempted
158 | })
159 |
160 | # Run evaluation and collect results
161 | results = common.map_with_progress(fn, self.examples)
162 |
163 | # Aggregate metrics
164 | aggregate_metrics = {
165 | "is_correct": sum(result.metrics["is_correct"] for result in results) / len(results),
166 | "is_incorrect": sum(result.metrics["is_incorrect"] for result in results) / len(results),
167 | "is_not_attempted": sum(result.metrics["is_not_attempted"] for result in results) / len(results),
168 | }
169 | aggregate_metrics["is_given_attempted"] = aggregate_metrics["is_correct"] + aggregate_metrics["is_incorrect"]
170 | # Calculate accuracy_given_attempted
171 | aggregate_metrics["accuracy_given_attempted"] = (
172 | aggregate_metrics["is_correct"]
173 | / aggregate_metrics["is_given_attempted"]
174 | if aggregate_metrics["is_given_attempted"] > 0
175 | else 0
176 | )
177 | print("AGGREGATE METRICS")
178 | print(aggregate_metrics)
179 | print("##################")
180 |
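        # F1 below is the harmonic mean of accuracy_given_attempted and the
        # overall correct rate (is_correct).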
181 | output_d = {
182 | "accuracy_given_attempted": aggregate_metrics["accuracy_given_attempted"],
183 | "f1": (
184 | 2 * aggregate_metrics["accuracy_given_attempted"] * aggregate_metrics["is_correct"]
185 | / (aggregate_metrics["accuracy_given_attempted"] + aggregate_metrics["is_correct"])
186 | if (aggregate_metrics["accuracy_given_attempted"] + aggregate_metrics["is_correct"]) > 0
187 | else 0
188 | )
189 | }
190 |
191 | print(f"Accuracy Given Attempted: {output_d['accuracy_given_attempted']:.3f}")
192 | print(f"F1 Score: {output_d['f1']:.3f}")
193 |
194 | return common.aggregate_results(results)
195 |
196 |
197 |
--------------------------------------------------------------------------------
/types.py:
--------------------------------------------------------------------------------
1 | from dataclasses import dataclass, field
2 | from typing import Any
3 |
4 | Message = dict[str, Any] # keys role, content
5 | MessageList = list[Message]
6 |
7 |
8 |
9 | @dataclass
10 | class SamplerResponse:
11 | """
12 | Response from a sampler.
13 | """
14 | response_text: str
15 | actual_queried_message_list: MessageList
16 | response_metadata: dict[str, Any]
17 |
18 | class SamplerBase:
19 | """
20 | Base class for defining a sampling model, which can be evaluated,
21 | or used as part of the grading process.
22 | """
23 |
24 | def __call__(
25 | self,
26 | message_list: MessageList,
27 | ) -> SamplerResponse:
28 | raise NotImplementedError
29 |
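# Illustrative sketch of the SamplerBase contract (hypothetical subclass, shown
# only as a comment): accept a MessageList and return a SamplerResponse.
#
# class EchoSampler(SamplerBase):
#     def __call__(self, message_list: MessageList) -> SamplerResponse:
#         last_user = next(
#             (m["content"] for m in reversed(message_list) if m["role"] == "user"),
#             "",
#         )
#         return SamplerResponse(
#             response_text=str(last_user),
#             actual_queried_message_list=message_list,
#             response_metadata={},
#         )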
30 |
31 | @dataclass
32 | class EvalResult:
33 | """
34 | Result of running an evaluation (usually consisting of many samples)
35 | """
36 |
37 | score: float | None # top-line metric
38 | metrics: dict[str, float] | None # other metrics
39 | htmls: list[str] # strings of valid HTML
40 | convos: list[MessageList] # sampled conversations
41 | metadata: dict[str, Any] | None # Extra data such as rubric scores or sollen
42 |
43 |
44 | @dataclass
45 | class SingleEvalResult:
46 | """
47 | Result of evaluating a single sample
48 | """
49 |
50 | score: float | None
51 | metrics: dict[str, float] = field(default_factory=dict)
52 | html: str | None = None
53 | convo: MessageList | None = None # sampled conversation
54 | example_level_metadata: dict[str, Any] | None = (
55 | None # Extra data such as rubric scores or sollen
56 | )
57 |
58 |
59 | class Eval:
60 | """
61 | Base class for defining an evaluation.
62 | """
63 |
64 | def __call__(self, sampler: SamplerBase) -> EvalResult:
65 | raise NotImplementedError
66 |
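# Illustrative sketch of the Eval contract (hypothetical subclass, shown only as
# a comment): call the sampler on each example and collapse the per-sample
# results into a single EvalResult.
#
# class SmokeTestEval(Eval):
#     def __call__(self, sampler: SamplerBase) -> EvalResult:
#         response = sampler([{"role": "user", "content": "Say hello."}])
#         convo = response.actual_queried_message_list + [
#             {"role": "assistant", "content": response.response_text}
#         ]
#         score = float(bool(response.response_text))
#         return EvalResult(
#             score=score, metrics={"responded": score}, htmls=[], convos=[convo], metadata=None
#         )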
67 |
--------------------------------------------------------------------------------