├── eval ├── __init__.py ├── tasks │ ├── __init__.py │ ├── docvqa.py │ ├── vqav2.py │ ├── chartqa.py │ ├── mathvista.py │ ├── mmmu.py │ └── mm_mt_bench.py ├── run.py ├── task.py ├── models.py └── metrics.py ├── .gitignore ├── requirements.txt └── README.md /eval/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__/ 2 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | datasets==3.0.0 2 | fire==0.6.0 3 | numpy==1.26.4 4 | openai==1.45.0 5 | pillow==10.4.0 6 | tqdm==4.66.5 7 | vllm==0.6.2 8 | -------------------------------------------------------------------------------- /eval/tasks/__init__.py: -------------------------------------------------------------------------------- 1 | from eval.tasks.mm_mt_bench import MultimodalMTBench 2 | from eval.tasks.vqav2 import VQAv2 3 | from eval.tasks.docvqa import DocVQA 4 | from eval.tasks.mmmu import MMMU 5 | from eval.tasks.mathvista import MathVista 6 | from eval.tasks.chartqa import ChartQA 7 | 8 | 9 | TASK_REGISTRY = { 10 | "mm_mt_bench": MultimodalMTBench, 11 | "vqav2": VQAv2, 12 | "docvqa": DocVQA, 13 | "mmmu": MMMU, 14 | "mathvista": MathVista, 15 | "chartqa": ChartQA, 16 | } 17 | 18 | 19 | def get_task(task_name): 20 | if task_name not in TASK_REGISTRY: 21 | raise ValueError(f"Did not recognize task name {task_name}") 22 | 23 | return TASK_REGISTRY[task_name]() 24 | -------------------------------------------------------------------------------- /eval/tasks/docvqa.py: -------------------------------------------------------------------------------- 1 | from typing import Any 2 | 3 | from datasets import load_dataset 4 | from eval.metrics import ANLS, Metric 5 | from eval.task import HuggingFaceEval, Interaction 6 | 7 | PROMPT = "Answer the question using a single word or phrase." 8 | 9 | 10 | class DocVQA(HuggingFaceEval): 11 | dataset_name = "lmms-lab/DocVQA" 12 | dataset_split = "validation" 13 | # DocVQA needs an extra config name. 14 | dataset_config = "DocVQA" 15 | 16 | @property 17 | def metric_fns(self) -> list[Metric]: 18 | return [ANLS()] 19 | 20 | def _to_interaction(self, row: Any): 21 | return Interaction( 22 | { 23 | "temperature": 0.0, 24 | "max_tokens": 10, 25 | "messages": [ 26 | { 27 | "role": "user", 28 | "content": [ 29 | {"type": "image", "image": row["image"]}, 30 | {"type": "text", "text": row["question"] + "\n" + PROMPT}, 31 | ], 32 | } 33 | ], 34 | }, 35 | reference_answer=row["answers"], 36 | ) 37 | 38 | def get_dataset(self): 39 | return load_dataset(self.dataset_name, self.dataset_config)[self.dataset_split] 40 | -------------------------------------------------------------------------------- /eval/tasks/vqav2.py: -------------------------------------------------------------------------------- 1 | from typing import Any 2 | 3 | from datasets import load_dataset 4 | 5 | from eval.metrics import VQAMatch, Metric 6 | from eval.task import HuggingFaceEval, Interaction 7 | 8 | PROMPT = """- Answer the question using a single word, number, or short phrase. Use as few words as possible. 9 | - If the answer is a number, report it as a number, i.e. 2, not Two, and only include the number without any unit. 
10 | - If the question is Yes/No, answer with Yes/No, and nothing else (no likely, unknown, etc.).
11 | - You cannot answer that the question is unanswerable. You must answer."""
12 | 
13 | 
14 | class VQAv2(HuggingFaceEval):
15 |     dataset_name = "HuggingFaceM4/VQAv2"
16 |     dataset_split = "validation"
17 | 
18 |     def _to_interaction(self, row: Any) -> Interaction:
19 |         return Interaction(
20 |             {
21 |                 "temperature": 0.0,
22 |                 "max_tokens": 10,
23 |                 "messages": [
24 |                     {
25 |                         "role": "user",
26 |                         "content": [
27 |                             {"type": "image", "image": row["image"]},
28 |                             {"type": "text", "text": row["question"] + "\n" + PROMPT},
29 |                         ],
30 |                     }
31 |                 ],
32 |             },
33 |             reference_answer=row["answers"],
34 |         )
35 | 
36 |     @property
37 |     def metric_fns(self) -> list[Metric]:
38 |         return [VQAMatch()]
39 | 
40 |     def load_eval(self):
41 |         for row in load_dataset(self.dataset_name)[self.dataset_split]:
42 |             self.interactions.append(self._to_interaction(row))
43 | 
--------------------------------------------------------------------------------
/eval/tasks/chartqa.py:
--------------------------------------------------------------------------------
1 | from typing import Any
2 | 
3 | from eval.metrics import (
4 |     Metric,
5 |     ExplicitPromptRelaxedCorrectness,
6 |     AnywhereInAnswerRelaxedCorrectness,
7 | )
8 | from eval.task import HuggingFaceEval, Interaction
9 | 
10 | 
11 | PROMPT = """Analyze the image and question carefully, using step-by-step reasoning.
12 | First, describe any image provided in detail. Then, present your reasoning. And finally your final answer in this format:
13 | Final Answer: <answer>
14 | where <answer> follows the following instructions:
15 | - <answer> should be a single phrase or number.
16 | - <answer> should not paraphrase or reformat the text in the image.
17 | - If <answer> is a ratio, it should be a decimal value like 0.25 instead of 1:4.
18 | - If the question is a Yes/No question, <answer> should be Yes/No.
19 | - If <answer> is a number, it should not contain any units.
20 | - If <answer> is a percentage, it should include a % sign.
21 | - If <answer> is an entity, it should include the full label from the graph.
22 | IMPORTANT: Remember to end your answer with Final Answer: <answer>."""
23 | 
24 | 
25 | class ChartQA(HuggingFaceEval):
26 |     dataset_name = "lmms-lab/ChartQA"
27 |     dataset_split = "test"
28 | 
29 |     def _to_interaction(self, row: dict[str, Any]) -> Interaction:
30 |         image = row["image"]
31 |         question = row["question"]
32 |         answer: list[str] = row["answer"]
33 | 
34 |         return Interaction(
35 |             {
36 |                 "temperature": 0.0,
37 |                 "max_tokens": 2048,
38 |                 "messages": [
39 |                     {
40 |                         "role": "user",
41 |                         "content": [
42 |                             {"type": "image", "image": image},
43 |                             {"type": "text", "text": question + "\n" + PROMPT},
44 |                         ],
45 |                     }
46 |                 ],
47 |             },
48 |             reference_answer=answer,
49 |         )
50 | 
51 |     @property
52 |     def metric_fns(self) -> list[Metric]:
53 |         return [
54 |             ExplicitPromptRelaxedCorrectness(),
55 |             AnywhereInAnswerRelaxedCorrectness(),
56 |         ]
57 | 
--------------------------------------------------------------------------------
/eval/tasks/mathvista.py:
--------------------------------------------------------------------------------
1 | from typing import Any
2 | 
3 | from eval.metrics import (
4 |     Metric,
5 |     ExplicitPromptRelaxedCorrectness,
6 |     AnywhereInAnswerRelaxedCorrectness,
7 | )
8 | from eval.task import HuggingFaceEval, Interaction
9 | 
10 | 
11 | PROMPT = """Analyze the image and question carefully, using step-by-step reasoning.
12 | First, describe any image provided in detail. Then, present your reasoning. And finally your final answer in this format:
13 | Final Answer: <answer>
14 | where <answer> is:
15 | - The single correct letter choice A, B, C, D, E, F, etc. when options are provided. Only include the letter.
16 | - Your direct answer if no options are given, as a single phrase or number.
17 | - If your answer is a number, only include the number without any unit.
18 | - If your answer is a word or phrase, do not paraphrase or reformat the text you see in the image.
19 | - You cannot answer that the question is unanswerable. You must either pick an option or provide a direct answer.
20 | IMPORTANT: Remember to end your answer with Final Answer: <answer>."""
21 | 
22 | 
23 | class MathVista(HuggingFaceEval):
24 |     dataset_name = "AI4Math/MathVista"
25 |     dataset_split = "testmini"
26 | 
27 |     def _to_interaction(self, row: dict[str, Any]) -> Interaction:
28 |         image = row["decoded_image"]
29 |         question = row["query"]
30 | 
31 |         if row["choices"]:
32 |             answer_index = row["choices"].index(row["answer"])
33 |             answer = chr(ord("A") + answer_index)
34 |         else:
35 |             answer = row["answer"]
36 | 
37 |         return Interaction(
38 |             {
39 |                 "temperature": 0.0,
40 |                 "max_tokens": 2048,
41 |                 "messages": [
42 |                     {
43 |                         "role": "user",
44 |                         "content": [
45 |                             {"type": "image", "image": image},
46 |                             {"type": "text", "text": question + "\n" + PROMPT},
47 |                         ],
48 |                     }
49 |                 ],
50 |             },
51 |             reference_answer=answer,
52 |         )
53 | 
54 |     @property
55 |     def metric_fns(self) -> list[Metric]:
56 |         return [
57 |             ExplicitPromptRelaxedCorrectness(),
58 |             AnywhereInAnswerRelaxedCorrectness(),
59 |         ]
60 | 
--------------------------------------------------------------------------------
/eval/run.py:
--------------------------------------------------------------------------------
1 | import json
2 | from pathlib import Path
3 | 
4 | import fire
5 | from eval.models import Model, VLLMModel
6 | from eval.tasks import get_task
7 | 
8 | 
9 | def evaluate(
10 |     model: Model,
11 |     eval_name: str,
12 |     output_dir: str | Path,
13 |     save_raw_output: bool = False,
14 | ):
15 |     """
16 |     Args:
17 |         model: A callable that takes a chat completion request and queries a model
18 |             to get a text response.
19 |         eval_name: Name of an eval to run.
20 |         output_dir: Directory where metrics (and raw outputs, if save_raw_output is set) are written.
21 |     """
22 |     eval_task = get_task(eval_name)
23 | 
24 |     output_dir = Path(output_dir)
25 |     output_dir.mkdir(exist_ok=True, parents=True)
26 | 
27 |     eval_task.load_eval()
28 |     eval_task.get_responses(model)
29 |     eval_task.compute_metrics()
30 | 
31 |     metrics_output = json.dumps(eval_task.aggregate_metrics(), indent=4)
32 |     with (output_dir / f"{eval_name}.json").open("w") as f:
33 |         f.write(metrics_output)
34 | 
35 |     print("=" * 80)
36 |     print("Metrics:")
37 |     print(metrics_output)
38 |     print("=" * 80)
39 | 
40 |     if save_raw_output:
41 |         raw_output = json.dumps(eval_task.save_question_answer(), indent=4)
42 |         with (output_dir / f"raw_output_{eval_name}.json").open("w") as f:
43 |             f.write(raw_output)
44 | 
45 | 
46 | def eval_vllm(
47 |     model_name: str,
48 |     url: str,
49 |     eval_name: str,
50 |     output_dir: str | Path,
51 |     save_raw_output: bool = False,
52 | ):
53 |     model = VLLMModel(model_name, url)
54 |     evaluate(model, eval_name, output_dir, save_raw_output)
55 | 
56 | 
57 | if __name__ == "__main__":
58 |     """Usage:
59 | 
60 |     Step 1: Host a model using vLLM
61 |     >> vllm serve mistralai/Pixtral-12B-2409 --config_format mistral --tokenizer_mode "mistral"
62 | 
63 |     Step 2: Evaluate hosted model.
64 |     >> python -m eval.run eval_vllm \
65 |         --model_name mistralai/Pixtral-12B-2409 \
66 |         --url http://0.0.0.0:8000 \
67 |         --output_dir ~/tmp \
68 |         --eval_name docvqa \
69 |         --save_raw_output
70 | 
71 |     To evaluate your own model, you can create a Model subclass which implements an
72 |     interface for returning a generated response given a chat completion request.
73 |     """
74 |     fire.Fire({"eval_vllm": eval_vllm})
75 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Mistral Evals
2 | 
3 | This repository contains code to run evals released by Mistral AI as well as standardized prompts, parsing and metrics computation for popular academic benchmarks.
4 | 
5 | ## Installation
6 | 
7 | ```
8 | pip install -r requirements.txt
9 | ```
10 | 
11 | ## Evals
12 | 
13 | We support the following evals in this repository:
14 | * `mm_mt_bench`: [MM-MT-Bench](https://huggingface.co/datasets/mistralai/MM-MT-Bench) is a multi-turn LLM-as-a-judge evaluation task released by Mistral AI that uses GPT-4o for judging model answers given reference answers.
15 | * `vqav2`: [VQAv2](https://huggingface.co/datasets/HuggingFaceM4/VQAv2)
16 | * `docvqa`: [DocVQA](https://huggingface.co/datasets/lmms-lab/DocVQA)
17 | * `mathvista`: [MathVista](https://huggingface.co/datasets/AI4Math/MathVista)
18 | * `mmmu`: [MMMU](https://huggingface.co/datasets/lmms-lab/MMMU)
19 | * `chartqa`: [ChartQA](https://github.com/vis-nlp/ChartQA)
20 | 
21 | ### Example usage:
22 | 
23 | **Step 1**: Host a model using vLLM
24 | 
25 | To install vLLM, follow the directions [here](https://docs.vllm.ai/en/latest/getting_started/installation.html).
26 | 
27 | ```
28 | >> vllm serve mistralai/Pixtral-12B-2409 --config_format mistral --tokenizer_mode "mistral"
29 | ```
30 | 
31 | **Step 2**: Evaluate hosted model.
32 | ```
33 | >> python -m eval.run eval_vllm \
34 |     --model_name mistralai/Pixtral-12B-2409 \
35 |     --url http://0.0.0.0:8000 \
36 |     --output_dir ~/tmp \
37 |     --eval_name "mm_mt_bench"
38 | ```
39 | 
40 | **NOTE**: Evaluating MM-MT-Bench requires calls to GPT-4o as a judge, hence you'll need
41 | to set the `OPENAI_API_KEY` environment variable for the eval to work.
42 | 
43 | For evaluating the other supported evals, see the **Evals** section.
44 | 
45 | #### Evaluating a non-vLLM model
46 | 
47 | To evaluate your own model, you can also create a `Model` class which implements a `__call__` method which takes as input a chat completion request and returns a string answer. Requests are provided in [vLLM API format](https://docs.vllm.ai/en/latest/models/vlm.html#openai-vision-api).
48 | 
49 | ```
50 | class CustomModel(Model):
51 | 
52 |     def __call__(self, request: dict[str, Any]):
53 |         # Your model code
54 |         ...
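        # NOTE (illustrative sketch): `request` follows the chat-completion format built
        # in eval/task.py: it carries "temperature", "max_tokens", and "messages", where
        # a user message's "content" is a list of {"type": "image", "image": <PIL image>}
        # and {"type": "text", "text": ...} chunks. Convert those chunks into whatever
        # your model expects, run inference, and return the text answer as a plain string.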
55 |         return answer
56 | ```
57 | 
58 | ## Usage
59 | 
60 | *You must not use this library or our models in a manner that infringes, misappropriates, or otherwise violates any third party’s rights, including intellectual property rights.*
61 | 
--------------------------------------------------------------------------------
/eval/tasks/mmmu.py:
--------------------------------------------------------------------------------
1 | from typing import Any
2 | 
3 | import ast
4 | import re
5 | 
6 | from PIL import Image
7 | 
8 | from eval.metrics import (
9 |     Metric,
10 |     ExplicitPromptRelaxedCorrectness,
11 |     AnywhereInAnswerRelaxedCorrectness,
12 | )
13 | from eval.task import HuggingFaceEval, Interaction
14 | 
15 | 
16 | PROMPT = """Analyze the image and question carefully, using step-by-step reasoning.
17 | First, describe any image provided in detail. Then, present your reasoning. And finally your final answer in this format:
18 | Final Answer: <answer>
19 | where <answer> is:
20 | - The single correct letter choice A, B, C, D, E, F, etc. when options are provided. Only include the letter.
21 | - Your direct answer if no options are given, as a single phrase or number.
22 | - If your answer is a number, only include the number without any unit.
23 | - If your answer is a word or phrase, do not paraphrase or reformat the text you see in the image.
24 | - You cannot answer that the question is unanswerable. You must either pick an option or provide a direct answer.
25 | IMPORTANT: Remember to end your answer with Final Answer: <answer>."""
26 | 
27 | 
28 | class MMMU(HuggingFaceEval):
29 |     dataset_name = "lmms-lab/MMMU"
30 |     dataset_split = "validation"
31 | 
32 |     def _to_interaction(self, row: dict[str, Any]) -> Interaction:
33 |         content_chunks: list[dict[str, str | Image.Image]] = []
34 | 
35 |         if row["question_type"] == "multiple-choice":
36 |             choices = ast.literal_eval(row["options"])
37 |             options = [chr(ord("A") + i) for i in range(len(choices))]
38 |             choices_str = "\n".join(
39 |                 [f"{option}. {choice}" for option, choice in zip(options, choices)]
40 |             )
41 |             question = f"{row['question']}\n{choices_str}"
42 |         else:
43 |             question = row["question"]
44 | 
45 |         # pattern to split string on <image 1>, <image 2>, etc.:
46 |         split_pattern = r"(<image \d+>)"
47 |         # pattern to extract integer number to get image
48 |         match_pattern = r"<image (\d+)>"
49 |         text_img_chunks = re.split(pattern=split_pattern, string=question)
50 |         text_img_chunks = [chunk for chunk in text_img_chunks if chunk.strip()]
51 | 
52 |         for chunk in text_img_chunks:
53 |             # check to see if img
54 |             match = re.fullmatch(match_pattern, chunk)
55 | 
56 |             # treating an image chunk
57 |             if match:
58 |                 img_id = int(match.group(1))  # ignore
59 |                 img = row[f"image_{img_id}"]
60 |                 content_chunks.append({"type": "image", "image": img})
61 |             else:
62 |                 content_chunks.append({"type": "text", "text": chunk})
63 | 
64 |         if content_chunks[-1]["type"] == "text":
65 |             assert isinstance(content_chunks[-1]["text"], str)
66 |             content_chunks[-1]["text"] += "\n" + PROMPT
67 |         else:
68 |             content_chunks.append({"type": "text", "text": PROMPT})
69 | 
70 |         answer = (
71 |             ast.literal_eval(row["answer"]) if "[" in row["answer"] else [row["answer"]]
72 |         )
73 | 
74 |         return Interaction(
75 |             {
76 |                 "temperature": 0.0,
77 |                 "max_tokens": 2048,
78 |                 "messages": [
79 |                     {
80 |                         "role": "user",
81 |                         "content": content_chunks,
82 |                     }
83 |                 ],
84 |             },
85 |             reference_answer=answer,
86 |         )
87 | 
88 |     @property
89 |     def metric_fns(self) -> list[Metric]:
90 |         return [
91 |             ExplicitPromptRelaxedCorrectness(),
92 |             AnywhereInAnswerRelaxedCorrectness(),
93 |         ]
94 | 
--------------------------------------------------------------------------------
/eval/task.py:
--------------------------------------------------------------------------------
1 | from typing import Any, Optional
2 | 
3 | import copy
4 | import dataclasses
5 | from abc import ABC, abstractmethod
6 | from concurrent.futures import Future, ThreadPoolExecutor, as_completed
7 | import numpy as np
8 | from datasets import load_dataset
9 | from tqdm import tqdm
10 | 
11 | from eval.metrics import Metric
12 | from eval.models import Model
13 | 
14 | 
15 | @dataclasses.dataclass
16 | class Interaction:
17 |     """A single round of interaction from a model given a chat completion request."""
18 | 
19 |     # vLLM compatible chat completion request
20 |     request: dict[str, Any]
21 | 
22 |     # Reference answer(s).
23 |     reference_answer: str | list[str]
24 | 
25 |     # Generated answer from model.
26 |     model_answer: Optional[str] = None
27 | 
28 |     # Computed metrics (filled in after model answers are generated).
29 |     metrics: dict[str, float] = dataclasses.field(default_factory=dict)
30 | 
31 |     # Extra metadata from dataset (e.g. category).
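    # (MM-MT-Bench additionally stores the turn index and the judge's written judgement here.)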
32 | meta: dict[str, Any] = dataclasses.field(default_factory=dict) 33 | 34 | 35 | class Eval(ABC): 36 | """Base class for an eval task.""" 37 | 38 | def __init__(self): 39 | self.interactions: list[Interaction] = [] 40 | 41 | @property 42 | def metric_fns(self) -> list[Metric]: 43 | """A list of metrics to compute for request-response pairs.""" 44 | raise NotImplementedError 45 | 46 | @abstractmethod 47 | def _to_interaction(self, row: Any): 48 | """Converts a row from eval dataset into Interaction object.""" 49 | raise NotImplementedError 50 | 51 | @abstractmethod 52 | def load_eval(self): 53 | """Loads dataset and applies transforms to get chat completion requests.""" 54 | raise NotImplementedError 55 | 56 | def get_responses(self, model: Model): 57 | """Queries model to get responses for each interaction.""" 58 | 59 | futures: dict[Future, Interaction] = {} 60 | with ThreadPoolExecutor(max_workers=8) as executor: 61 | for interaction in self.interactions: 62 | request = copy.deepcopy(interaction.request) 63 | futures[executor.submit(model, request)] = interaction 64 | 65 | interactions_w_model_ans = [] 66 | for future in tqdm( 67 | as_completed(futures), 68 | total=len(self.interactions), 69 | desc="Querying model", 70 | ): 71 | interaction = futures[future] 72 | interaction.model_answer = future.result() 73 | interactions_w_model_ans.append(interaction) 74 | self.interactions = interactions_w_model_ans 75 | 76 | def compute_metrics(self): 77 | """Computes metrics for each interaction.""" 78 | for interaction in tqdm(self.interactions): 79 | for metric in self.metric_fns: 80 | interaction.metrics[metric.name] = metric.score( 81 | interaction.model_answer, interaction.reference_answer 82 | ) 83 | 84 | def aggregate_metrics(self) -> dict[str, float]: 85 | """Aggregates metrics across all the interactions.""" 86 | overall_metrics: dict[str, float] = {} 87 | for metric in self.metric_fns: 88 | overall_metrics[metric.name] = np.mean( 89 | [interaction.metrics[metric.name] for interaction in self.interactions] 90 | ) # type: ignore 91 | return overall_metrics 92 | 93 | def save_question_answer(self) -> dict[str, dict[str, str] | str | list[str]]: 94 | """Save question and answer to a json file.""" 95 | 96 | result: dict[str, dict[str, str] | str | list[str]] = {} 97 | for i in range(len(self.interactions)): 98 | result[str(i)] = {"request": self.interactions[i].request, 99 | "model_answer": self.interactions[i].model_answer, 100 | "reference_answer": self.interactions[i].reference_answer} 101 | 102 | return result 103 | 104 | class HuggingFaceEval(Eval): 105 | """Evals hosted on hugging face for which datasets.load_dataset can be used.""" 106 | 107 | dataset_name: str 108 | dataset_split: str 109 | 110 | def get_dataset(self): 111 | return load_dataset(self.dataset_name)[self.dataset_split] 112 | 113 | def load_eval(self): 114 | """Loads dataset and applies transforms to get chat completion requests.""" 115 | for row in tqdm( 116 | self.get_dataset(), 117 | desc=f"Loading {self.dataset_name} [{self.dataset_split}]", 118 | ): 119 | self.interactions.append(self._to_interaction(row)) 120 | -------------------------------------------------------------------------------- /eval/models.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | import base64 3 | import copy 4 | import json 5 | import io 6 | import re 7 | import time 8 | from typing import Any 9 | 10 | import requests 11 | 12 | 13 | class Model(ABC): 14 | 
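    """Interface for a model under evaluation: maps a vLLM-style chat completion request dict to a text answer."""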
@abstractmethod 15 | def __call__(self, request: dict[str, Any]) -> str: 16 | raise NotImplementedError 17 | 18 | 19 | class VLLMModel(Model): 20 | """Evaluates a model hosted using vLLM.""" 21 | 22 | def __init__(self, model_name: str, url: str): 23 | self.model_name = model_name 24 | self.url = url 25 | self._wait_till_healthy() 26 | 27 | def _wait_till_healthy(self) -> bool: 28 | base_url = self.url 29 | # wait for server to be ready 30 | assert base_url is not None 31 | match = re.match(r"^http.*:\d+$", base_url) 32 | assert match is not None, base_url 33 | 34 | health_endpoint = f"{base_url}/health" 35 | timeout = 120 36 | t0 = time.time() 37 | print(f"Waiting for VLLM server to come online at {health_endpoint} ...") 38 | print(f"Timeout is {timeout}s") 39 | while time.time() - t0 < timeout: 40 | print(f"Waiting for server ({int(time.time() - t0)}s) ...") 41 | 42 | # Query the endpoint 43 | try: 44 | req = requests.get(health_endpoint) 45 | print("Server is up!") 46 | except Exception: 47 | # Ignore exception 48 | pass 49 | else: 50 | if ( 51 | req.status_code == 200 52 | and req.content == b"" 53 | or req.json() == {"status": "OK"} 54 | ): 55 | return True 56 | 57 | # Backoff 58 | time.sleep(5) 59 | 60 | raise RuntimeError( 61 | f"Server not up in {int(timeout / 60)} minutes, something is wrong" 62 | ) 63 | 64 | def _emplace_image(self, ccr: dict[str, Any]): 65 | """Replaces image message with base64 encoded image.""" 66 | ccr = copy.deepcopy(ccr) 67 | for m in ccr["messages"]: 68 | if isinstance(m["content"], list): 69 | for c in m["content"]: 70 | if c["type"] == "image": 71 | c["type"] = "image_url" 72 | image = c.pop("image") 73 | stream = io.BytesIO() 74 | im_format = image.format or "PNG" 75 | image.save(stream, format=im_format) 76 | im_b64 = base64.b64encode(stream.getvalue()).decode("ascii") 77 | c["image_url"] = { 78 | "url": f"data:image/{im_format.lower()};base64,{im_b64}" 79 | } 80 | return ccr 81 | 82 | def __call__(self, request: dict[str, Any]) -> str: 83 | headers = { 84 | "Content-Type": "application/json", 85 | "Accept": "application/json", 86 | } 87 | 88 | # Convert images to base64 strings so they can be serialized. 
89 | request_dict = self._emplace_image(request) 90 | 91 | # Retry 3 times with backoff 92 | max_retries = 3 93 | retries_left = max_retries 94 | backoff = 1.5 95 | request_dict["model"] = self.model_name 96 | while retries_left > 0: 97 | try: 98 | response = requests.post( 99 | f"{self.url}/v1/chat/completions", 100 | headers=headers, 101 | data=json.dumps(request_dict), 102 | ) 103 | 104 | if response.status_code != 200: 105 | response_json = json.dumps( 106 | json.loads(response.content.decode("utf-8")), indent=4 107 | ) 108 | raise ValueError( 109 | f"Request failed (code={response.status_code}):\n\nRESPONSE: {response_json}" 110 | ) 111 | 112 | completion_dict = response.json() 113 | assert completion_dict["choices"][0]["message"]["role"] == "assistant" 114 | return completion_dict["choices"][0]["message"]["content"] 115 | except Exception as e: 116 | print( 117 | f"Query to model failed, retrying ({max_retries - retries_left + 1} / {max_retries}): {e}", 118 | ) 119 | time.sleep(backoff) 120 | backoff *= 2 121 | retries_left -= 1 122 | # If querying server failed, raise an exception 123 | raise RuntimeError("Failed to get a response.") 124 | -------------------------------------------------------------------------------- /eval/tasks/mm_mt_bench.py: -------------------------------------------------------------------------------- 1 | import ast 2 | import json 3 | import re 4 | import time 5 | from collections import defaultdict 6 | from concurrent.futures import ThreadPoolExecutor, as_completed 7 | from dataclasses import dataclass 8 | from typing import Any, Sequence, Optional 9 | 10 | import openai 11 | from tqdm import tqdm 12 | from datasets import load_dataset 13 | import numpy as np 14 | 15 | from eval.task import HuggingFaceEval, Interaction 16 | 17 | JUDGES = frozenset( 18 | [ 19 | "gpt-4o-2024-05-13", 20 | ] 21 | ) 22 | DEFAULT_TEMPERATURE = 0.0 23 | DEFAULT_MAX_TOKENS = 4096 24 | BRACKET_SCORE_RE = re.compile(r"\[\[(\d+\.?\d*)\]\]") 25 | 26 | 27 | @dataclass 28 | class Judgement: 29 | judgement: str 30 | grade: float 31 | 32 | 33 | class MultimodalLLMJudge: 34 | API_MAX_RETRY: int = 3 35 | JUDGE_DEFAULT_TEMPERATURE: float = 0.0 36 | JUDGE_MAX_TOKENS: int = 2048 37 | SYSTEM_PROMPT: str = 'Please act as an impartial judge and evaluate the quality of the response provided by an AI assistant to the most recent question given the previous conversation as context. Your evaluation should consider correctness and helpfulness. You will be given a reference answer and the assistant\'s answer. Begin your evaluation by comparing the assistant\'s answer with the reference answer. Identify and correct any mistakes. Be as objective as possible. 
After providing your explanation, you must rate the response on a scale of 1 to 10 by strictly following this format: "[[rating]]", for example: "Rating: [[5]]".\n\n' 38 | 39 | def __init__(self, judge_name: str): 40 | self.judge_name = judge_name 41 | self.client = openai.OpenAI() 42 | 43 | def get_score(self, judgement: str) -> float: 44 | match = re.search(BRACKET_SCORE_RE, judgement) 45 | if match: 46 | rating = ast.literal_eval(match.groups()[0]) 47 | else: 48 | # Sometimes the judge fails to evaluate the generation 49 | rating = -1 50 | return rating 51 | 52 | def _add_or_append_chunk( 53 | self, prompt: list[dict[str, Any]], chunk: str | dict[str, Any] 54 | ): 55 | if isinstance(chunk, dict) and chunk["type"] == "image_url": 56 | return chunk 57 | 58 | text: str = chunk["text"] if isinstance(chunk, dict) else chunk 59 | assert isinstance(text, str) 60 | if prompt[-1]["type"] == "text": 61 | prompt[-1]["text"] += text 62 | else: 63 | prompt.append({"type": "text", "text": text}) 64 | 65 | def _replay_conversation( 66 | self, 67 | prompt: list[dict[str, Any]], 68 | questions: Sequence[str | Sequence[dict[str, Any]]], 69 | ref_answers: Sequence[str], 70 | final_answer: Optional[str] = None, 71 | ): 72 | for q, a in zip(questions, ref_answers): 73 | if isinstance(q, str): 74 | # Merge consecutive text blocks. 75 | self._add_or_append_chunk( 76 | prompt, f"### User:\n{q}\n\n### Reference answer:\n{a}\n\n" 77 | ) 78 | else: 79 | assert prompt[-1]["type"] == "text" 80 | prompt[-1]["text"] += "### User:\n" 81 | for sub_q in q: 82 | self._add_or_append_chunk(prompt, sub_q) 83 | self._add_or_append_chunk(prompt, f"\n\n### Reference answer:\n{a}\n\n") 84 | self._add_or_append_chunk( 85 | prompt, f"\n\n### Assistant's answer:\n{final_answer}\n\n" 86 | ) 87 | 88 | def _get_judge_prompt( 89 | self, 90 | questions: list[str | list[dict[str, Any]]], 91 | ref_answers: list[str], 92 | final_answer: str, 93 | ) -> list[dict[str, Any]]: 94 | # Each part of the prompt is either a string or an image. 95 | assert len(questions) == len(ref_answers) 96 | 97 | prompt: list[dict[str, Any]] = [ 98 | {"type": "text", "text": "<|The Start of Conversation with User|>\n\n"} 99 | ] 100 | self._replay_conversation(prompt, questions, ref_answers, final_answer) 101 | # Conversations always end in text answer from Assistant) 102 | assert prompt[-1]["type"] == "text" 103 | prompt[-1]["text"] += "<|The End of Conversation with User|>\n\n\n" 104 | 105 | return prompt 106 | 107 | def _query_judge(self, prompt): 108 | rating = -1.0 109 | judgement = "" 110 | n_trials = 0 111 | backoff = 1 112 | while True: 113 | try: 114 | response = self.client.chat.completions.create( 115 | model=self.judge_name, 116 | messages=[ 117 | { 118 | "role": "system", 119 | "content": self.SYSTEM_PROMPT, 120 | }, 121 | {"role": "user", "content": prompt}, 122 | ], 123 | max_tokens=self.JUDGE_MAX_TOKENS, 124 | temperature=self.JUDGE_DEFAULT_TEMPERATURE, 125 | ) 126 | judgement = response.choices[0].message.content 127 | rating = self.get_score(judgement) 128 | # If the score is -1 it means that we failed to get a score. 
129 | if rating != -1.0: 130 | return Judgement(judgement, rating) 131 | except Exception as e: 132 | n_trials += 1 133 | if n_trials < self.API_MAX_RETRY: 134 | print( 135 | f"Error {e} - retrying {n_trials}/{self.API_MAX_RETRY}", 136 | ) 137 | time.sleep(backoff) 138 | backoff *= 2 139 | else: 140 | raise e 141 | 142 | def get_judgement(self, interaction: Interaction): 143 | questions = [m for m in interaction.request["messages"] if m["role"] == "user"] 144 | ref_answers = [ 145 | m for m in interaction.request["messages"] if m["role"] == "assistant" 146 | ] + [interaction.reference_answer] 147 | assert interaction.model_answer is not None 148 | prompt = self._get_judge_prompt( 149 | questions, ref_answers, interaction.model_answer 150 | ) 151 | judgement = self._query_judge(prompt) 152 | interaction.meta["judgement"] = judgement.judgement 153 | interaction.metrics["score"] = judgement.grade 154 | return interaction 155 | 156 | 157 | def run_judge(judge_name: str, interactions: list[Interaction]): 158 | judge = MultimodalLLMJudge(judge_name) 159 | futures = [] 160 | graded_interactions = [] 161 | with ThreadPoolExecutor(max_workers=8) as executor: 162 | for interaction in tqdm(interactions): 163 | futures.append(executor.submit(judge.get_judgement, interaction)) 164 | 165 | for future in tqdm( 166 | as_completed(futures), total=len(interactions), desc="Querying judge" 167 | ): 168 | graded_interactions.append(future.result()) 169 | return graded_interactions 170 | 171 | 172 | class MultimodalMTBench(HuggingFaceEval): 173 | dataset_name = "mistralai/MM-MT-Bench" 174 | dataset_split = "eval" 175 | judge = "gpt-4o-2024-05-13" 176 | 177 | def _to_interaction(self, row: Any): 178 | # Unused for this class, but we need a concrete implementation. 179 | raise NotImplementedError 180 | 181 | def load_eval(self): 182 | ds = load_dataset(self.dataset_name)[self.dataset_split] 183 | for example in tqdm(ds, f"Loading {self.dataset_name} [{self.dataset_split}]"): 184 | messages = json.loads(example["conversation"]) 185 | image = example["image"] 186 | category = example["category"] 187 | for index in range(len(messages)): 188 | if index == 0: 189 | # Image is always the first chunk of first message. 
190 | assert messages[0]["content"][0]["type"] == "image" 191 | messages[0]["content"][0]["image"] = image 192 | 193 | if index % 2 == 0: 194 | assert messages[index]["role"] == "user" 195 | new_ccr = { 196 | "temperature": DEFAULT_TEMPERATURE, 197 | "max_tokens": DEFAULT_MAX_TOKENS, 198 | "messages": messages[: index + 1], 199 | } 200 | ref_answer: str = messages[index + 1]["content"] 201 | 202 | self.interactions.append( 203 | Interaction( 204 | request=new_ccr, 205 | reference_answer=ref_answer, 206 | meta={"category": category, "turn": index // 2}, 207 | ) 208 | ) 209 | 210 | def compute_metrics(self): 211 | self.interactions = run_judge(self.judge, self.interactions) 212 | 213 | def aggregate_metrics(self) -> dict[str, float]: 214 | category_scores = defaultdict(list) 215 | micro_average_score = float( 216 | np.mean([interaction.metrics["score"] for interaction in self.interactions]) 217 | ) 218 | for interaction in self.interactions: 219 | # TODO: rename to grade 220 | score = interaction.metrics["score"] 221 | category_scores[interaction.meta["category"]].append( 222 | score 223 | ) # average by question type 224 | category_scores[interaction.meta["turn"]].append(score) # average by turn 225 | category_averages = { 226 | f"{cat}_average": float(np.mean(v)) for cat, v in category_scores.items() 227 | } 228 | category_macro_average = float(np.mean(list(category_averages.values()))) 229 | return { 230 | "micro_average_score": micro_average_score, 231 | "macro_average_score": category_macro_average, 232 | **category_averages, 233 | } 234 | -------------------------------------------------------------------------------- /eval/metrics.py: -------------------------------------------------------------------------------- 1 | import re 2 | import string 3 | 4 | 5 | def _normalize_string(s): 6 | if (s.startswith('"') and s.endswith('"')) or ( 7 | s.startswith("'") and s.endswith("'") 8 | ): 9 | return s[1:-1] 10 | return s 11 | 12 | 13 | def _remove_end_punctuation(unnormalized_string: str) -> str: 14 | while ( 15 | unnormalized_string 16 | and ( 17 | unnormalized_string[-1] in string.punctuation 18 | or unnormalized_string[-1].isspace() 19 | ) 20 | and unnormalized_string[-1] != "%" 21 | ): 22 | unnormalized_string = unnormalized_string[:-1] 23 | return unnormalized_string 24 | 25 | 26 | class Metric: 27 | """Base class for metrics.""" 28 | 29 | @property 30 | def name(self) -> str: 31 | raise NotImplementedError 32 | 33 | def score(self, model_answer: str, reference_answer: str | list[str]) -> float: 34 | raise NotImplementedError 35 | 36 | 37 | class VQAMatch(Metric): 38 | """VQA match metric which gives partial score if less than 3 answers are matched.""" 39 | 40 | @property 41 | def name(self) -> str: 42 | return "vqa_match" 43 | 44 | def score(self, model_answer: str, reference_answer: str | list[str]) -> float: 45 | if not isinstance(reference_answer, list): 46 | reference_answer = [reference_answer] 47 | normalize_response_text: str = _normalize_string(model_answer) 48 | matching_answers = [ 49 | answer 50 | for answer in reference_answer 51 | if _normalize_string(answer) == normalize_response_text 52 | ] 53 | return min(1.0, float(len(matching_answers)) / 3) 54 | 55 | 56 | class ANLS(Metric): 57 | @property 58 | def name(self) -> str: 59 | return "anls" 60 | 61 | def _edit_distance_helper(self, s1: str, s2: str) -> float: 62 | if len(s1) > len(s2): 63 | s1, s2 = s2, s1 64 | distances = list(range(len(s1) + 1)) 65 | for i2, c2 in enumerate(s2): 66 | distance_list = [i2 + 1] 67 | for 
i1, c1 in enumerate(s1): 68 | if c1 == c2: 69 | distance_list.append(distances[i1]) 70 | else: 71 | distance_list.append( 72 | 1 + min((distances[i1], distances[i1 + 1], distance_list[-1])) 73 | ) 74 | distances = distance_list 75 | return distances[-1] 76 | 77 | def score(self, model_answer: str, reference_answer: str | list[str]) -> float: 78 | if not isinstance(reference_answer, list): 79 | reference_answer = [reference_answer] 80 | 81 | model_answer = " ".join(model_answer.strip().lower().split()) 82 | model_answer = _remove_end_punctuation(model_answer) 83 | 84 | min_value = float("inf") 85 | for ref in reference_answer: 86 | # Post-processing: Normalize spaces and remove punctuations. 87 | ref = " ".join(ref.strip().lower().split()) 88 | ref = _remove_end_punctuation(ref) 89 | 90 | # Compute edit distance 91 | dist = self._edit_distance_helper(ref, model_answer) 92 | length = max(len(ref), len(model_answer)) 93 | value = 0.0 if length == 0 else float(dist) / float(length) 94 | if value < min_value: 95 | min_value = value 96 | 97 | anls_threshold = 0.0 98 | output = 0.0 if 1 - min_value < anls_threshold else 1 - min_value 99 | return output 100 | 101 | 102 | class RelaxedCorrectness(Metric): 103 | """Relaxed correctness metrics. 104 | 105 | The correctness tolerates certain error ratio defined by max_relative_change. 106 | See https://arxiv.org/pdf/2203.10244.pdf, end of section 5.1: 107 | "Following Methani et al. (2020), we use a relaxed accuracy measure for the 108 | numeric answers to allow a minor inaccuracy that may result from the automatic 109 | data extraction process. We consider an answer to be correct if it is within 110 | 5% of the gold answer. For non-numeric answers, we still need an exact match 111 | to consider an answer to be correct." 
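    For example, a prediction of 0.24 against a gold answer of 0.25 (4% relative
    error) is scored 1.0, while 0.23 (8% relative error) is scored 0.0; a prediction
    of "24%" also matches a gold answer of 0.25 via the percent handling below.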
112 | """ 113 | 114 | def _relaxed_correctness( 115 | self, prediction: str, targets: list[str], max_relative_change: float = 0.05 116 | ) -> float: 117 | def _to_float(text: str) -> tuple[float | None, bool]: 118 | text = text.strip() 119 | is_percent = text.endswith("%") 120 | try: 121 | value = float(text.rstrip("%")) 122 | return value, is_percent 123 | except ValueError: 124 | return None, False 125 | 126 | def _is_letter(text: str) -> bool: 127 | return text.isalpha() and len(text) == 1 128 | 129 | def _preprocess_text(text: str) -> str: 130 | if not any(char.isdigit() for char in text): 131 | return _normalize_string(text) 132 | else: 133 | return _remove_end_punctuation(text).replace(",", "").replace("$", "") 134 | 135 | def calculate_relative_change(prediction: float, target: float) -> float: 136 | return abs(prediction - target) / max(abs(target), 1e-10) 137 | 138 | def _compare_numeric_values( 139 | prediction: float, target: float, max_relative_change: float 140 | ) -> float: 141 | relative_change = calculate_relative_change(prediction, target) 142 | return 1.0 if relative_change <= max_relative_change else 0.0 143 | 144 | def _compare_text_values(prediction: str, target: str) -> float: 145 | return 1.0 if prediction.lower() == target.lower() else 0.0 146 | 147 | def _to_decimal(value: float, is_percent: bool) -> float: 148 | return value / 100 if is_percent else value 149 | 150 | def _compare_numeric_with_percent( 151 | prediction: float, 152 | prediction_is_percent: bool, 153 | target: float, 154 | target_is_percent: bool, 155 | max_relative_change: float, 156 | ) -> float: 157 | # Compare as-is 158 | value = _compare_numeric_values(prediction, target, max_relative_change) 159 | 160 | # If not equal and one is percent, try other comparisons 161 | if value != 1.0 and (prediction_is_percent or target_is_percent): 162 | value = max( 163 | value, 164 | _compare_numeric_values( 165 | _to_decimal(prediction, prediction_is_percent), 166 | target, 167 | max_relative_change, 168 | ), 169 | _compare_numeric_values( 170 | prediction, 171 | _to_decimal(target, target_is_percent), 172 | max_relative_change, 173 | ), 174 | ) 175 | return value 176 | 177 | prediction = _preprocess_text(prediction) 178 | prediction_float, prediction_is_percent = _to_float(prediction) 179 | 180 | value_list = [] 181 | for target in targets: 182 | target = _preprocess_text(target) 183 | target_float, target_is_percent = _to_float(target) 184 | 185 | if prediction_float is not None and target_float is not None: 186 | # Compare as numeric values 187 | value = _compare_numeric_with_percent( 188 | prediction_float, 189 | prediction_is_percent, 190 | target_float, 191 | target_is_percent, 192 | max_relative_change, 193 | ) 194 | elif _is_letter(target) and len(prediction) > 0: 195 | # Compare as multiple choice options: take first letter from prediction 196 | value = 1.0 if prediction[0].lower() == target.lower() else 0.0 197 | else: 198 | # Compare as text values 199 | value = _compare_text_values(prediction, target) 200 | 201 | value_list.append(value) 202 | 203 | return max(value_list) 204 | 205 | def score(self, model_answer: str, reference_answer: str | list[str]) -> float: 206 | reference_answer = ( 207 | reference_answer 208 | if isinstance(reference_answer, list) 209 | else [reference_answer] 210 | ) 211 | return self._relaxed_correctness(model_answer, reference_answer) 212 | 213 | 214 | class ExplicitPromptRelaxedCorrectness(RelaxedCorrectness): 215 | """Relaxed correctness for explicit prompt.""" 216 | 
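    # Scores only the text that follows the last "Answer:" marker in the generation
    # (the "Final Answer: <answer>" line requested by the prompts in eval/tasks);
    # if no such marker is found, the sample is scored 0.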
217 | @property 218 | def name(self) -> str: 219 | return "explicit_prompt_relaxed_correctness" 220 | 221 | def _get_final_answer(self, generation: str) -> str: 222 | def _find_last_occurrence(pattern: str, string: str): 223 | return string.rfind(pattern) 224 | 225 | # Strip extraneous markdown around the answer: 226 | generation = re.sub(r"([aA]nswer)\**:\**", "\\1:", generation) 227 | 228 | final_answer_index = _find_last_occurrence("answer:", generation.lower()) 229 | 230 | if final_answer_index != -1: 231 | # Find the start of the answer (after "final answer:") 232 | start_index = final_answer_index + len("answer:") 233 | 234 | # Split the remaining text into lines 235 | lines = generation[start_index:].split("\n") 236 | 237 | # Find the first non-empty line 238 | final_answer = next((line.strip() for line in lines if line.strip()), "") 239 | 240 | # Remove any markdown formatting 241 | final_answer = re.sub(r"[*_\[\]\(\)]", "", final_answer) 242 | 243 | return final_answer 244 | else: 245 | return "" 246 | 247 | def score(self, model_answer: str, reference_answer: str | list[str]) -> float: 248 | parsed_model_answer = self._get_final_answer(model_answer) 249 | if not parsed_model_answer: 250 | # Parsing failed. 251 | return 0.0 252 | return super().score(parsed_model_answer, reference_answer) 253 | 254 | 255 | class AnywhereInAnswerRelaxedCorrectness(ExplicitPromptRelaxedCorrectness): 256 | """Falls back to handle cases where reference answer appears anywhere in generation. 257 | 258 | NOTE: This is an overly generous metric and is likely to falsely inflate scores. 259 | """ 260 | 261 | @property 262 | def name(self) -> str: 263 | return "anywhere_in_answer_relaxed_correctness" 264 | 265 | def score(self, model_answer: str, reference_answer: str | list[str]) -> float: 266 | reference_answer = ( 267 | reference_answer 268 | if isinstance(reference_answer, list) 269 | else [reference_answer] 270 | ) 271 | parsed_model_answer = self._get_final_answer(model_answer) 272 | if parsed_model_answer: 273 | return self._relaxed_correctness(parsed_model_answer, reference_answer) 274 | 275 | # Fallback: check if reference answer appears anywhere in the model answer. 276 | for ref in reference_answer: 277 | try: 278 | # Try to parse as a float 279 | number = float(ref) 280 | 281 | # Revert to int if it is actually an int. 282 | if int(number) == number: 283 | number = int(number) 284 | # Check if the number is in the model answer with commas (e.g. 1,000) 285 | if format(number, ",") in model_answer: 286 | return 1.0 287 | # Check if the number is in the model answer without commas (e.g. 1000) 288 | elif str(number) in model_answer: 289 | return 1.0 290 | elif str(number) + "%" in model_answer: 291 | return 1.0 292 | except ValueError: 293 | # Reference answer was a text string. We search for typical patterns 294 | # in the model answer. Note that directly searching for the reference 295 | # is not a good idea for letter-option choice questions, hence we look 296 | # for common patterns. This is still heuristic, and might have false 297 | # positives as well as false negatives. 298 | candidates = [] 299 | for ref in reference_answer: 300 | candidates.extend( 301 | [ 302 | f"is {ref}", 303 | f"was {ref}", 304 | f" {ref}.", 305 | f"are {ref}", 306 | f"\n\n{ref}", 307 | ] 308 | ) 309 | if any([c.lower() in model_answer for c in candidates]): 310 | return 1.0 311 | 312 | return 0 313 | --------------------------------------------------------------------------------