├── .gitignore ├── LICENSE ├── README.md ├── browse-data.ipynb ├── evaluate.py ├── gpt_evaluate.py ├── infer_correction.py ├── infer_critique.py ├── infer_critique_lookback.py ├── prompts ├── correction.txt ├── critique.txt ├── gpt_evaluate.txt ├── lookback_synthesize.txt └── lookback_visual-query.txt ├── requirements.txt ├── src_evaluation ├── CLEVR_evaluation.py ├── EmbSpatial_evaluation.py ├── FigureQA_evaluation.py ├── GQA_evaluation.py ├── HallusionBench_evaluation.py ├── MMMU_evaluation.py ├── MMVet_evaluation.py ├── MathVision_Evaluation.py ├── MathVista_evaluation.py ├── POPE_evaluation.py ├── PlotQA_evaluation.py ├── SceMQA_evaluation.py ├── ScienceQA_evaluation.py ├── TallyQA_evaluation.py ├── VQA_evaluation.py ├── VSR_evaluation.py ├── WeMathEvaluation.py └── evaluate.py ├── static ├── examples.1.jpg └── teaser.jpg └── utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | .idea/ 2 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 PlusLab 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 
14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # VISCO 2 | 3 | **Benchmarking Fine-Grained Critique and Correction Towards Self-Improvement in Visual Reasoning** 4 | 5 | [🌐 Project](https://visco-benchmark.github.io/) | [🤗 Dataset](https://huggingface.co/datasets/uclanlp/VISCO) | [📖 Paper](https://arxiv.org/abs/2412.02172) 6 | 7 |

🎉Accepted to CVPR 2025!🎉

8 | 9 |
10 | 11 | Outline: 12 | * [Introduction](#introduction) 13 | * [Data](#data) 14 | * [Environment](#environment) 15 | * [Run models](#run-models) 16 | * [Run evaluation](#run-evaluation) 17 | 18 | ## Introduction 19 | 20 | VISCO is a benchmark for evaluating the critique and correction capabilities of LVLMs. VISCO contains: 21 | 1. **1645 pairs of questions and LVLM-generated answers**. Each answer includes a chain-of-thought with multiple reasoning steps. 22 | 2. **5604 step-wise annotations of critique**, showing whether each step is correct or incorrect, along with a natural language explanation if a step is incorrect. 23 | 24 | ![Overview](static/teaser.jpg) 25 | 26 | ## Data 27 | 28 | ![Example](static/examples.1.jpg) 29 | 30 | Please download our data from [here](https://huggingface.co/datasets/uclanlp/VISCO). Due to the variable data type, the huggingface dataset viewer does not work properly. Please use [browse-data.ipynb](browse-data.ipynb) to browse the data. 31 | 32 | Each line in `test.jsonl` corresponds to a data point, with a question, an LVLM-generated answer, and critique annotations for each step. For the data example shown above, its json representation is: 33 | ```json 34 | { 35 | "id": "test1||Reasoning||Math||MathVista||5", 36 | "question": "What is the center of symmetry of this function?\nChoices:\n(A) (0, 0)\n(B) (-1, 0)\n(C) (2, 0)", 37 | "image": IMAGE_BASE64_REPR, 38 | "label": "(0, 0)", 39 | "response": { 40 | "reasoning": [ 41 | "A function with a center of symmetry at a point (a, b) implies the graph of the function is symmetric when rotated 180 degrees about this point.", 42 | "Observing the graph, it appears that this symmetry exists around the point (2, 0).", 43 | "Points on opposite sides of (2, 0) have values equidistant from this center, confirming symmetry around this point." 
44 | ], 45 | "answer": "(2, 0)" 46 | }, 47 | "answer_correctness": false, 48 | "reasoning_correctness": [true, false, false], 49 | "reasoning_critic": [ 50 | ["", "", ""], 51 | ["The symmetry exists around (0, 0), not (2, 0).", SECOND_REFERENCE, THIRD_REFERENCE], 52 | ["Points on opposite sides of (2, 0) do not have values equidistant from this center.", SECOND_REFERENCE, THIRD_REFERENCE], 53 | ], 54 | "meta_data": { 55 | "critic_superskill": "Reasoning", 56 | "critic_skill": "Math", 57 | "src_dataset": "MathVista", 58 | "src_model": "GPT-4o", 59 | ...META_DATA_FOR_ORIGINAL_DATASET 60 | } 61 | } 62 | ``` 63 | 64 | Notes: 65 | * The field `response` is the answer generated by LVLMs. It includes a chain-of-thought (field `reasoning`) and the final answer (field `answer`). 66 | * Annotations for critique include three parts: the binary critique for final answer (`answer_correctness`), the binary critique for each step (`reasoning_correctness`), and the natural language critique for each step (`reasoning_critic`). 67 | * Note that for each step, we have three different references produced by three different annotators. All references are considered when doing the final evaluation. 68 | * Also note that we only provide natural language critiques for incorrect steps. 69 | 70 | ## Environment 71 | 72 | To install the minimal requirements: 73 | ```bash 74 | pip install -r requirements.txt 75 | ``` 76 | 77 | However, note that **this requirement does not include requirements for fast serving frameworks** such as vllm, lmdeploy and sglang. To install these packages, please first [install pytorch](https://pytorch.org/get-started/locally/), and then follow their documents to install their latest versions respectively. If you want to use multiple fast serving frameworks, it is recommended to maintain multiple environments, one for each fast serving framework, because they may have conflicts in dependencies. 
78 | 79 | ## Run models 80 | 81 | Download the data from [huggingface](https://huggingface.co/datasets/uclanlp/VISCO) and put `test.jsonl` under this directory. Then, use the following scripts: 82 | 83 | ### Critique 84 | 85 | Run `python infer_critique.py --input test.jsonl --output OUTPUT_FILE` 86 | 87 | * If you're using proprietary LVLMs such as OpenAI, Anthropic and Gemini models, use `--model XXX` to specify the model and use `--api_key` to provide your API key. The proprietary models we test include `gpt-4o-2024-08-06`, `claude-3-5-sonnet-20240620` and `gemini-1.5-pro`. 88 | * If you're using open LVLMs, you can locally launch an OpenAI-compatible server and then use the same script. In that case, specify the following arguments: set `--base_url` to your server URL, `--model` to your model name or `auto`, and `--api_key` to your API key. An example of how to launch an OpenAI-compatible server with vllm is [here](https://docs.vllm.ai/en/latest/serving/openai_compatible_server.html). 89 | * Alternatively, you can specify `--model XXX` and add the `--launch_locally BACKEND` argument, so the script will automatically launch a server and make requests to it. Supported backends include `lmdeploy`, `vllm` and `sglang`. Note that **this requires you to properly install the backend packages first**. The framework we use for evaluating each model is as follows: 90 | 91 | | Framework | Model(s) | 92 | |---|---| 93 | | vllm | Qwen2-VL, Molmo, Llama-3.2, NVLM | 94 | | lmdeploy | InternVL2, DeepSeek-VL, LLaVA-v1.5, LLaVA-v1.6, Qwen-VL, Prometheus-Vision | 95 | | sglang | LLaVA-OV, LLaVA-Critic | 96 | 97 | * If you want to use your own inference code, please rewrite `def infer` in `utils.py`. 98 | 99 | ### Critique with LookBack 100 | 101 | Run `python infer_critique_lookback.py --input test.jsonl --output OUTPUT_FILE`. The other arguments are the same as `infer_critique.py`.
Note that this script is slower and makes more API calls, so remember to monitor your API usage. 102 | 103 | ### Correction 104 | 105 | Run `python infer_correction.py --input test.jsonl --output OUTPUT_FILE` 106 | * For correction with human critique, use the argument `--critique human`. The script will use the critique annotations in `test.jsonl`. 107 | * For correction with model-generated critique, use the argument `--critique CRITIQUE_FILE`, where `CRITIQUE_FILE` is the output file generated by `infer_critique.py`. 108 | * By default, the correction script uses the full critique, including answer-level, step-level and explanation-level critique. If you only want to use coarser-grained critique, set `--critique_setting A` to only use answer-level binary critique, or set `--critique_setting AS` to only use answer-level and step-level binary critique. 109 | 110 | The other arguments are the same as `infer_critique.py`. 111 | 112 | ## Run evaluation 113 | 114 | ### Critique 115 | 116 | First, run LLM-assisted evaluation of explanation-level F1: 117 | ```bash 118 | python gpt_evaluate.py YOUR_OUTPUT_FILE --input test.jsonl 119 | ``` 120 | Remember to set the environment variable `OPENAI_API_KEY` so the script can access the OpenAI API. The evaluation results will be saved to a cache file `YOUR_OUTPUT_FILE.gpt_evaluate_cache`. 121 | 122 | Then, run `evaluate.py` to calculate the full set of metrics, including VISCore. 123 | ```bash 124 | python evaluate.py YOUR_OUTPUT_FILE --input test.jsonl --task critique 125 | ``` 126 | 127 | ### Correction 128 | ```bash 129 | python evaluate.py YOUR_OUTPUT_FILE --input test.jsonl --task correction 130 | ``` 131 | 132 | ## Citation 133 | Please cite our paper if this repository inspires your work!
134 | 135 | ``` 136 | @inproceedings{wu2025visco, 137 | title={Visco: Benchmarking fine-grained critique and correction towards self-improvement in visual reasoning}, 138 | author={Wu, Xueqing and Ding, Yuheng and Li, Bingxuan and Lu, Pan and Yin, Da and Chang, Kai-Wei and Peng, Nanyun}, 139 | booktitle={Proceedings of the Computer Vision and Pattern Recognition Conference}, 140 | pages={9527--9537}, 141 | year={2025} 142 | } 143 | ``` 144 | -------------------------------------------------------------------------------- /browse-data.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "nbformat": 4, 3 | "nbformat_minor": 0, 4 | "metadata": { 5 | "colab": { 6 | "provenance": [] 7 | }, 8 | "kernelspec": { 9 | "name": "python3", 10 | "display_name": "Python 3" 11 | }, 12 | "language_info": { 13 | "name": "python" 14 | } 15 | }, 16 | "cells": [ 17 | { 18 | "cell_type": "markdown", 19 | "source": [ 20 | "This is example code that reads and displays VISCO dataset.\n", 21 | "\n", 22 | "[🌐 Project](https://visco-benchmark.github.io/) | [🤗 Dataset](https://huggingface.co/datasets/uclanlp/VISCO) | [📖 Paper](https://arxiv.org/abs/2412.02172)\n", 23 | "\n", 24 | "
\n", 25 | "\n", 26 | "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/PlusLabNLP/VISCO/blob/main/browse-data.ipynb)" 27 | ], 28 | "metadata": { 29 | "id": "m-V84haHNoUh" 30 | } 31 | }, 32 | { 33 | "cell_type": "code", 34 | "source": [ 35 | "from huggingface_hub import hf_hub_download\n", 36 | "\n", 37 | "fname = hf_hub_download(repo_id=\"uclanlp/VISCO\", filename=\"test.jsonl\", repo_type='dataset')\n", 38 | "with open(fname, 'r') as f:\n", 39 | " lines = f.readlines()\n", 40 | "print(\"Read %d lines\" % len(lines))" 41 | ], 42 | "metadata": { 43 | "colab": { 44 | "base_uri": "https://localhost:8080/" 45 | }, 46 | "id": "RTgvWMnfN59G", 47 | "outputId": "f0be7b8b-72c1-465f-bdaa-81e048c910a7" 48 | }, 49 | "execution_count": 1, 50 | "outputs": [ 51 | { 52 | "output_type": "stream", 53 | "name": "stderr", 54 | "text": [ 55 | "/usr/local/lib/python3.10/dist-packages/huggingface_hub/utils/_auth.py:94: UserWarning: \n", 56 | "The secret `HF_TOKEN` does not exist in your Colab secrets.\n", 57 | "To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.\n", 58 | "You will be able to reuse this secret in all of your notebooks.\n", 59 | "Please note that authentication is recommended but still optional to access public models or datasets.\n", 60 | " warnings.warn(\n" 61 | ] 62 | }, 63 | { 64 | "output_type": "stream", 65 | "name": "stdout", 66 | "text": [ 67 | "Read 1645 lines\n" 68 | ] 69 | } 70 | ] 71 | }, 72 | { 73 | "cell_type": "code", 74 | "source": [ 75 | "import json\n", 76 | "import random\n", 77 | "from PIL import Image\n", 78 | "from io import BytesIO\n", 79 | "import base64\n", 80 | "import matplotlib.pyplot as plt\n", 81 | "\n", 82 | "index = 5\n", 83 | "print(\"Display data %d\\n\" % index)\n", 84 | "line = json.loads(lines[index])\n", 85 | "\n", 86 | "print(\"--- 
Question:\")\n", 87 | "print(line['question'])\n", 88 | "print()\n", 89 | "\n", 90 | "print(\"--- Label:\", line['label'])\n", 91 | "print()\n", 92 | "\n", 93 | "print(\"--- Model CoT:\")\n", 94 | "for i in range(len(line['response']['reasoning'])):\n", 95 | " print(\"{:d}. {:s}\".format(i + 1, line['response']['reasoning'][i]))\n", 96 | "print(\"--- Model answer:\", line['response']['answer'])\n", 97 | "print()\n", 98 | "\n", 99 | "print(\"--- Critique:\")\n", 100 | "for i in range(len(line['response']['reasoning'])):\n", 101 | " print(\"{:d}. {:s}\".format(i + 1, \"Correct\" if line['reasoning_correctness'][i] else \"Incorrect\"))\n", 102 | " for j in range(3): # three references for explanation\n", 103 | " print(\" - Explanation {:d}:\".format(j + 1), line['reasoning_critic'][i][j])\n", 104 | "print(\"Answer:\", \"Correct\" if line['answer_correctness']else \"Incorrect\")\n", 105 | "\n", 106 | "plt.imshow(Image.open(BytesIO(base64.b64decode(line['image']))))\n", 107 | "plt.show()" 108 | ], 109 | "metadata": { 110 | "colab": { 111 | "base_uri": "https://localhost:8080/", 112 | "height": 973 113 | }, 114 | "id": "ZA8ksb1SPazG", 115 | "outputId": "b166665f-ba99-4664-cc25-70c5ce4e3001" 116 | }, 117 | "execution_count": 2, 118 | "outputs": [ 119 | { 120 | "output_type": "stream", 121 | "name": "stdout", 122 | "text": [ 123 | "Display data 5\n", 124 | "\n", 125 | "--- Question:\n", 126 | "What is the center of symmetry of this function?\n", 127 | "Choices:\n", 128 | "(A) (0, 0)\n", 129 | "(B) (-1, 0)\n", 130 | "(C) (2, 0)\n", 131 | "\n", 132 | "--- Label: (0, 0)\n", 133 | "\n", 134 | "--- Model CoT:\n", 135 | "1. A function with a center of symmetry at a point (a, b) implies the graph of the function is symmetric when rotated 180 degrees about this point.\n", 136 | "2. Observing the graph, it appears that this symmetry exists around the point (2, 0).\n", 137 | "3. 
Points on opposite sides of (2, 0) have values equidistant from this center, confirming symmetry around this point.\n", 138 | "--- Model answer: (2, 0)\n", 139 | "\n", 140 | "--- Critique:\n", 141 | "1. Correct\n", 142 | " - Explanation 1: \n", 143 | " - Explanation 2: \n", 144 | " - Explanation 3: \n", 145 | "2. Incorrect\n", 146 | " - Explanation 1: The symmetry exists around (0, 0), not (2, 0).\n", 147 | " - Explanation 2: It's around (0,0)\n", 148 | " - Explanation 3: It is not symmetric around point (2, 0). Instead it obtains symmetry with point (0, 0).\n", 149 | "3. Incorrect\n", 150 | " - Explanation 1: Points on opposite sides of (2, 0) do not have values equidistant from this center.\n", 151 | " - Explanation 2: Values that are located on the left side of (2,0) do not already have values equidistant from (2,0) to the values on the right side of (2,0).\n", 152 | " - Explanation 3: Points on opposite sides of (2, 0) do not have equal distance to point (2, 0).\n", 153 | "Answer: Incorrect\n" 154 | ] 155 | }, 156 | { 157 | "output_type": "display_data", 158 | "data": { 159 | "text/plain": [ 160 | "
" 161 | ], 162 | "image/png": "\n" 163 | }, 164 | "metadata": {} 165 | } 166 | ] 167 | } 168 | ] 169 | } 170 | -------------------------------------------------------------------------------- /evaluate.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import copy 3 | import json 4 | import os 5 | 6 | import numpy as np 7 | import tabulate 8 | 9 | from gpt_evaluate import _calc_gpt_metrics 10 | from src_evaluation.evaluate import evaluate 11 | 12 | 13 | def _f1(gt, pred): 14 | assert len(gt) == len(pred) 15 | tp = sum(a is False and b is False for a, b in zip(gt, pred)) 16 | gt_pos = gt.count(False) 17 | pred_pos = pred.count(False) 18 | if tp == 0: 19 | return 0 20 | else: 21 | p = tp / pred_pos 22 | r = tp / gt_pos 23 | return 2 / (1 / p + 1 / r) 24 | 25 | 26 | def _evaluate_critic(data, responses, gpt_responses): 27 | gt_ans = [] 28 | pred_ans = [] 29 | gt_th = [] 30 | pred_th = [] 31 | refs_ex = [] 32 | sys_ex = [] 33 | assert len(data) == len(responses) 34 | for i in range(len(data)): 35 | assert len(data[i]['reasoning_critic']) == len(responses[i]['formatted']['reasoning_critic']) 36 | 37 | gt_ans.append(data[i]['answer_correctness']) 38 | pred_ans.append(responses[i]['formatted']['answer_correctness']) 39 | gt_th += data[i]['reasoning_correctness'] 40 | pred_th += responses[i]['formatted']['reasoning_correctness'] 41 | refs_ex += data[i]['reasoning_critic'] 42 | sys_ex += responses[i]['formatted']['reasoning_critic'] 43 | 44 | ret = { 45 | 'Ans. F1': _f1(gt_ans, pred_ans) * 100, 46 | 'Th. F1': _f1(gt_th, pred_th) * 100, 47 | } 48 | if gpt_responses is not None: 49 | ex_f1 = _calc_gpt_metrics(data, responses, gpt_responses) 50 | viscore = pow(ret['Ans. F1'] * ret['Th. F1'] * ex_f1, 1 / 3) 51 | ret['Ex. 
F1'] = ex_f1 52 | ret['VISCore'] = viscore 53 | return ret 54 | 55 | 56 | def evaluate_critique(data, responses, gpt_responses=None, do_print=True): 57 | if do_print: 58 | print("Format error: {:d} / {:d}\n".format(sum(r['format_error'] for r in responses), len(responses))) 59 | 60 | # Remove critic for steps predicted as correct: not necessary 61 | responses = copy.deepcopy(responses) 62 | for r in responses: 63 | for i in range(len(r['formatted']['reasoning_correctness'])): 64 | if r['formatted']['reasoning_correctness'][i]: 65 | r['formatted']['reasoning_critic'][i] = '' 66 | 67 | metrics = { 68 | 'Total': _evaluate_critic(data, responses, gpt_responses), 69 | 'Reasoning': _evaluate_critic( 70 | [data[i] for i in range(len(data)) if data[i]['meta_data']['critic_superskill'] == 'Reasoning'], 71 | [responses[i] for i in range(len(data)) if data[i]['meta_data']['critic_superskill'] == 'Reasoning'], 72 | [gpt_responses[i] for i in range(len(data)) if data[i]['meta_data']['critic_superskill'] == 'Reasoning'] 73 | if gpt_responses is not None else None, 74 | ), 75 | 'Perception': _evaluate_critic( 76 | [data[i] for i in range(len(data)) if data[i]['meta_data']['critic_superskill'] == 'Perception'], 77 | [responses[i] for i in range(len(data)) if data[i]['meta_data']['critic_superskill'] == 'Perception'], 78 | [gpt_responses[i] for i in range(len(data)) if data[i]['meta_data']['critic_superskill'] == 'Perception'] 79 | if gpt_responses is not None else None, 80 | ), 81 | } 82 | 83 | if do_print: 84 | KEYS = ['Ans. F1', 'Th. F1', ] 85 | if gpt_responses is not None: 86 | KEYS += ['Ex. 
F1', 'VISCore', ] 87 | print(tabulate.tabulate([[category, ] + [metrics[category][k] for k in KEYS] for category in metrics], 88 | headers=KEYS, floatfmt=[None, ] + ['.2f', ] * len(KEYS))) 89 | 90 | return metrics 91 | 92 | 93 | def evaluate_correction(data, responses, do_print=True): 94 | accuracy_pre = [] 95 | accuracy_post = [] 96 | for i in range(len(data)): 97 | accuracy_pre.append(not data[i]['id'].startswith("test1")) 98 | accuracy_post.append(evaluate(responses[i]['answer'], data[i]['label'], data[i]['meta_data'])) 99 | accuracy_pre = np.array(accuracy_pre) 100 | accuracy_post = np.array(accuracy_post) 101 | correction_score = accuracy_post[~accuracy_pre].mean() - (1 - accuracy_post)[accuracy_pre].mean() 102 | accuracy_post = np.mean(accuracy_post) 103 | if do_print: 104 | print("Accuracy = {:.2f}".format(accuracy_post * 100)) 105 | print("Correction score = {:.2f}".format(correction_score * 100)) 106 | return {'accuracy': accuracy_post, 'correction_score': correction_score} 107 | 108 | 109 | if __name__ == "__main__": 110 | parser = argparse.ArgumentParser() 111 | parser.add_argument('output') 112 | parser.add_argument('--task', default='critique', choices=['critique', 'correction', ]) 113 | parser.add_argument('--input', default='test.jsonl') 114 | args = parser.parse_args() 115 | 116 | with open(args.input) as f: 117 | data = [json.loads(line) for line in f] 118 | with open(args.output) as f: 119 | responses = [json.loads(line) for line in f] 120 | 121 | if args.task == 'critique': 122 | gpt_responses = None 123 | if os.path.exists(args.output + '.gpt_evaluate_cache'): 124 | with open(args.output + '.gpt_evaluate_cache') as f: 125 | gpt_responses = [json.loads(line) for line in f] 126 | evaluate_critique(data, responses, gpt_responses) 127 | else: 128 | assert args.task == 'correction' 129 | evaluate_correction(data, responses) 130 | -------------------------------------------------------------------------------- /gpt_evaluate.py: 
-------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | import os 4 | import random 5 | import time 6 | 7 | import numpy as np 8 | import tqdm 9 | from openai import OpenAI 10 | 11 | from utils import get_pool 12 | 13 | 14 | def func(obj): 15 | i, j, k, image, query = obj 16 | 17 | client = OpenAI() 18 | messages = [ 19 | { 20 | "role": "user", 21 | "content": [ 22 | {"type": "text", "text": query}, 23 | ], 24 | }, 25 | ] 26 | if image is not None: 27 | messages[0]['content'].append({ 28 | "type": "image_url", 29 | "image_url": { 30 | "url": f"data:image/jpeg;base64,{image}", 31 | }, 32 | }) 33 | 34 | model = 'gpt-4o-2024-08-06' 35 | try: 36 | completion = client.chat.completions.create( 37 | model=model, 38 | messages=messages, 39 | max_tokens=512, temperature=0.0, 40 | ) 41 | except: 42 | time.sleep(1) 43 | try: 44 | completion = client.chat.completions.create( 45 | model=model, 46 | messages=messages, 47 | max_tokens=512, temperature=0.0, 48 | ) 49 | except: 50 | completion = None 51 | 52 | if completion is None: 53 | print("Warning! gpt infer does not work") 54 | ret = "TODO" 55 | else: 56 | ret = completion.choices[0].message.content 57 | return i, j, k, ret 58 | 59 | 60 | def gpt_evaluate(data, responses): 61 | with open(os.path.join(os.path.dirname(__file__), 'prompts/gpt_evaluate.txt')) as f: 62 | PROMPT = f.read() 63 | 64 | def format_prompt(i, j, k): 65 | which_step = {1: "first", 2: "second", 3: "third", 4: "fourth", 5: "fifth"}[j + 1] 66 | question = data[i]['question'] 67 | cot = [] 68 | for j_ in range(len(data[i]['response']['reasoning'])): 69 | cot.append("{:d}. {:s}".format(j_ + 1, data[i]['response']['reasoning'][j_])) 70 | if not data[i]['reasoning_correctness'][j_]: 71 | cot.append(" - Ground truth critique: incorrect. {}".format(data[i]['reasoning_critic'][j_][k])) 72 | if j_ == j: 73 | cot.append(" - Critique to be evaluated: incorrect. 
{}".format( 74 | responses[i]['formatted']['reasoning_critic'][j_] 75 | )) 76 | break 77 | cot = '\n'.join(cot) 78 | return PROMPT.replace("{{{WHICH_STEP}}}", which_step).replace("{{{QUESTION}}}", question) \ 79 | .replace("{{{COT}}}", cot) 80 | 81 | queries = [] 82 | gpt_responses = [] 83 | assert len(data) == len(responses) 84 | for i in range(len(data)): 85 | assert len(data[i]['reasoning_critic']) == len(responses[i]['formatted']['reasoning_critic']) 86 | gpt_responses.append([[None, None, None] for _ in range(len(responses[i]['formatted']['reasoning_critic']))]) 87 | for j in range(len(responses[i]['formatted']['reasoning_critic'])): 88 | if data[i]['reasoning_correctness'][j] is False and \ 89 | responses[i]['formatted']['reasoning_correctness'][j] is False: 90 | for k in range(3): 91 | queries.append((i, j, k, None, format_prompt(i, j, k))) 92 | 93 | def parse_response(response): 94 | if response.lower().endswith(' incorrect') or response.lower().endswith(' incorrect.'): 95 | correct = False 96 | else: 97 | correct = True 98 | return {'response': response, 'correct': correct} 99 | 100 | random.seed(42) 101 | random.shuffle(queries) 102 | 103 | count = 0 104 | with get_pool(args.n_proc) as p: 105 | for i, j, k, response in tqdm.tqdm(p.imap(func, queries), total=len(queries)): 106 | gpt_responses[i][j][k] = parse_response(response) 107 | count += 1 108 | if count <= 5: 109 | print() 110 | print() 111 | print("\n--- Example prompt:", count) 112 | print(queries[count - 1][-1]) 113 | print("\n--- Example output:", count) 114 | print(response) 115 | print("\n--- Parsed correctness:", gpt_responses[i][j][k]['correct']) 116 | 117 | return gpt_responses 118 | 119 | 120 | def _calc_gpt_metrics(data, responses, gpt_responses): 121 | tp = 0 122 | tp_binary = 0 123 | gt_pos = 0 124 | pred_pos = 0 125 | assert len(data) == len(responses) 126 | for i in range(len(data)): 127 | assert len(data[i]['reasoning_critic']) == len(responses[i]['formatted']['reasoning_critic']) 128 
| for j in range(len(responses[i]['formatted']['reasoning_critic'])): 129 | if data[i]['reasoning_correctness'][j] is False and \ 130 | responses[i]['formatted']['reasoning_correctness'][j] is False: 131 | tp += np.mean([int(x['correct']) for x in gpt_responses[i][j]]) 132 | tp_binary += 1 133 | gt_pos += data[i]['reasoning_correctness'].count(False) 134 | pred_pos += responses[i]['formatted']['reasoning_correctness'].count(False) 135 | p = tp / pred_pos 136 | r = tp / gt_pos 137 | f1 = 2 / (1 / p + 1 / r) 138 | return f1 * 100 139 | 140 | 141 | def calc_gpt_metrics(data, responses, gpt_responses): 142 | reasoning_ids = [i for i in range(len(data)) if data[i]['meta_data']['critic_superskill'] == 'Reasoning'] 143 | perception_ids = [i for i in range(len(data)) if data[i]['meta_data']['critic_superskill'] == 'Perception'] 144 | return { 145 | 'Total': _calc_gpt_metrics(data, responses, gpt_responses), 146 | 'Reasoning': _calc_gpt_metrics( 147 | [data[i] for i in reasoning_ids], 148 | [responses[i] for i in reasoning_ids], 149 | [gpt_responses[i] for i in reasoning_ids] 150 | ), 151 | 'Perception': _calc_gpt_metrics( 152 | [data[i] for i in perception_ids], 153 | [responses[i] for i in perception_ids], 154 | [gpt_responses[i] for i in perception_ids] 155 | ), 156 | } 157 | 158 | 159 | def main(args): 160 | with open(args.input) as f: 161 | data = [json.loads(line) for line in f] 162 | 163 | with open(args.output) as f: 164 | responses = [json.loads(line) for line in f] 165 | 166 | if not os.path.exists(args.output + '.gpt_evaluate_cache'): 167 | gpt_eval_responses = gpt_evaluate(data, responses) 168 | assert not os.path.exists(args.output + '.gpt_evaluate_cache') 169 | with open(args.output + '.gpt_evaluate_cache', 'w') as f: 170 | for line in gpt_eval_responses: 171 | f.write(json.dumps(line) + '\n') 172 | else: 173 | with open(args.output + '.gpt_evaluate_cache') as f: 174 | gpt_eval_responses = [json.loads(line) for line in f] 175 | 176 | metrics = 
calc_gpt_metrics(data, responses, gpt_eval_responses) 177 | print(json.dumps(metrics, indent=2)) 178 | 179 | 180 | if __name__ == "__main__": 181 | parser = argparse.ArgumentParser() 182 | parser.add_argument('output') 183 | parser.add_argument('--input', default='test.jsonl') 184 | parser.add_argument('--n_proc', default=16, type=int) 185 | args = parser.parse_args() 186 | 187 | main(args) 188 | -------------------------------------------------------------------------------- /infer_correction.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import copy 3 | import json 4 | import os 5 | import random 6 | 7 | from evaluate import evaluate_correction 8 | from utils import launch_locally, infer, get_answer_format 9 | 10 | 11 | def format_prompt(prompt_base, item, critique, critique_setting): 12 | answer_format = get_answer_format(item) 13 | 14 | C_map = {True: 'correct.', False: 'incorrect.', None: 'unknown.'} 15 | 16 | reasoning = [] 17 | for i, thought in enumerate(item['response']['reasoning']): 18 | reasoning.append("{:d}. {:s}".format(i + 1, thought)) 19 | if critique_setting.startswith("AS"): 20 | c = C_map[critique['reasoning_correctness'][i]] 21 | if critique['reasoning_correctness'][i] is False: 22 | c = 'incorrect.' 23 | if critique_setting == 'ASE': 24 | c += ' Explanation: ' + ('unknown.' if critique['reasoning_critic'][i] is None 25 | else critique['reasoning_critic'][i]) 26 | reasoning.append(" - Critique: " + c) 27 | reasoning.append("{:d}. 
The final answer is: {}".format( 28 | len(item['response']['reasoning']) + 1, item['response']['answer'])) 29 | reasoning.append(" - Critique: " + C_map[critique['answer_correctness']]) 30 | reasoning = '\n'.join(reasoning) 31 | 32 | if critique_setting.startswith("AS"): 33 | critique_setting_text = "each reasoning step" 34 | else: 35 | critique_setting_text = "the answer" 36 | 37 | prompt = prompt_base.replace("{{{QUESTION}}}", item['question']).replace("{{{ANSWER_FORMAT}}}", answer_format). \ 38 | replace("{{{REASONING}}}", reasoning).replace("{{{CRITIQUE_SETTING}}}", critique_setting_text) 39 | if critique_setting == 'A': 40 | prompt = prompt.replace("{{{REASONING}}}", "the answer") 41 | else: 42 | prompt = prompt.replace("{{{REASONING}}}", "each reasoning step as well as the final answer") 43 | return prompt 44 | 45 | 46 | def format_response(response): 47 | response_orig = response 48 | response = response.replace('\_', '_').replace('\\', '\\\\') 49 | 50 | success = False 51 | try: 52 | response = json.loads(response.split('```json')[-1].split('```')[0]) 53 | assert isinstance(response, dict) 54 | success = True 55 | except: 56 | pass 57 | 58 | if not success: 59 | try: 60 | response = json.loads(response.split('``` json')[-1].split('```')[0]) 61 | assert isinstance(response, dict) 62 | success = True 63 | except: 64 | pass 65 | 66 | if not success: 67 | try: 68 | response = json.loads('{' + response.split('{')[-1].split('}')[0] + '}') 69 | assert isinstance(response, dict) 70 | except: 71 | response = {} 72 | 73 | response = {k.lower().strip(): v for k, v in response.items()} 74 | answer = str(response.get('answer', '')) 75 | return {'response': response_orig, 'answer': answer} 76 | 77 | 78 | def main(args): 79 | with open(args.input) as f: 80 | data = [json.loads(line) for line in f] 81 | if args.critique == 'human': 82 | critique = copy.deepcopy(data) 83 | for x in critique: 84 | for i in range(len(x['reasoning_critic'])): 85 | assert 
len(x['reasoning_critic'][i]) == 3 86 | x['reasoning_critic'][i] = random.choice(x['reasoning_critic'][i]) 87 | else: 88 | with open(args.critique) as f: 89 | critique = [json.loads(line)['formatted'] for line in f] 90 | 91 | prompt_fname = os.path.join(os.path.dirname(__file__), 'prompts/correction.txt') 92 | with open(prompt_fname) as f: 93 | PROMPT = f.read() 94 | 95 | prompts = [format_prompt(PROMPT, item, c, args.critique_setting) for item, c in zip(data, critique)] 96 | print("\n--- Example prompt") 97 | print(prompts[0]) 98 | images = [item['image'] for item in data] 99 | responses = infer(prompts, images, args) 100 | responses = [format_response(response) for response in responses] 101 | 102 | if args.output is not None: 103 | print("Save outputs to", args.output) 104 | os.makedirs(os.path.dirname(args.output), exist_ok=True) 105 | with open(args.output, 'w') as f: 106 | for r in responses: 107 | f.write(json.dumps(r) + '\n') 108 | 109 | evaluate_correction(data, responses) 110 | 111 | 112 | if __name__ == "__main__": 113 | parser = argparse.ArgumentParser() 114 | # model and inference parameters 115 | parser.add_argument('--model', default="gpt-4o-2024-08-06") # auto if we're using a locally served model 116 | 117 | # openai api-based 118 | parser.add_argument('--api_key', default='YOUR_API_KEY') 119 | parser.add_argument('--base_url', default=None) 120 | parser.add_argument('--n_proc', default=16, type=int) 121 | parser.add_argument('--launch_locally', default=None, choices=['lmdeploy', 'vllm', 'sglang']) 122 | 123 | # input output 124 | parser.add_argument('--critique', default='human') 125 | parser.add_argument('--critique_setting', default='ASE', choices=['A', 'AS', 'ASE', ]) 126 | parser.add_argument('--input', default='test.jsonl') 127 | parser.add_argument('--output', default=None) 128 | args = parser.parse_args() 129 | 130 | if args.launch_locally: 131 | process, port = launch_locally(args.launch_locally, args.model) 132 | args.model = 'auto' 133 | 
args.base_url = f'http://0.0.0.0:{port}/v1' 134 | 135 | try: 136 | main(args) 137 | finally: 138 | if args.launch_locally: 139 | process.kill() 140 | -------------------------------------------------------------------------------- /infer_critique.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | import os 4 | 5 | from evaluate import evaluate_critique 6 | from utils import infer, launch_locally 7 | 8 | 9 | def format_prompt(prompt_base, item): 10 | reasoning = "\n".join(["{:d}. {:s}".format(i + 1, x) for i, x in enumerate(item['response']['reasoning'])]) 11 | prompt = prompt_base.replace("{{{QUESTION}}}", item['question']). \ 12 | replace("{{{ANSWER}}}", str(item['response']['answer'])).replace("{{{REASONING}}}", reasoning) 13 | 14 | prompt_lines = prompt.splitlines() 15 | final_prompt_lines = [] 16 | for line in prompt_lines: 17 | if '{{{REPEAT_BY_N_STEP}}}' in line: 18 | for i in range(len(item['response']['reasoning'])): 19 | final_prompt_lines.append(line.replace('{{{REPEAT_BY_N_STEP}}}', str(i + 1))) 20 | else: 21 | final_prompt_lines.append(line) 22 | prompt = "\n".join(final_prompt_lines) 23 | 24 | return prompt 25 | 26 | 27 | def format_response(response, n_steps): 28 | if isinstance(response, list) or isinstance(response, tuple): 29 | response_history = response[:-1] 30 | response_orig = response = response[-1] 31 | else: 32 | response_history = None 33 | response_orig = response 34 | format_error = False 35 | 36 | if not isinstance(response, str): 37 | ret = {'response': response, 'formatted': None, 'format_error': True} 38 | if response_history is not None: 39 | ret['response_history'] = response_history 40 | return ret 41 | 42 | response = response.replace('\_', '_').replace('\\', '\\\\') 43 | try: 44 | response = response.split('```json')[-1].split('```')[0] 45 | response = json.loads(response) 46 | assert isinstance(response, dict) 47 | except: 48 | try: 49 | response = '{' + 
"{".join(response.split('{')[1:]) 50 | response = "}".join(response.split('}')[:-1]) + '}' 51 | response = json.loads(response) 52 | assert isinstance(response, dict) 53 | except: 54 | response = {} 55 | format_error = True 56 | 57 | def to_true_or_false(x): 58 | if isinstance(x, bool): 59 | return x 60 | elif isinstance(x, str): 61 | x = x.lower().strip() 62 | if x == 'correct' or x == 'yes' or x == 'true': 63 | return True 64 | elif x == 'incorrect' or x == 'no' or x == 'false': 65 | return False 66 | return None 67 | 68 | def process_dict_key(x): 69 | if isinstance(x, dict): 70 | return {k.strip().lower(): process_dict_key(v) for k, v in x.items()} 71 | return x 72 | 73 | response = process_dict_key(response) 74 | formatted = {} 75 | 76 | formatted['answer_correctness'] = None 77 | if 'answer_correctness' in response: 78 | if isinstance(response['answer_correctness'], dict): 79 | if 'correctness' in response['answer_correctness']: 80 | formatted['answer_correctness'] = to_true_or_false(response['answer_correctness']['correctness']) 81 | else: 82 | formatted['answer_correctness'] = to_true_or_false(response['answer_correctness']) 83 | if formatted['answer_correctness'] is None: 84 | format_error = True 85 | 86 | formatted['reasoning_correctness'] = [None for _ in range(n_steps)] 87 | formatted['reasoning_critic'] = [None for _ in range(n_steps)] 88 | for i in range(n_steps): 89 | if 'step_{:d}'.format(i + 1) in response: 90 | step_response = response['step_{:d}'.format(i + 1)] 91 | if isinstance(step_response, dict) and 'correctness' in step_response: 92 | formatted['reasoning_correctness'][i] = to_true_or_false(step_response['correctness']) 93 | if 'explanation' in step_response: 94 | formatted['reasoning_critic'][i] = str(step_response['explanation']) 95 | if formatted['reasoning_correctness'][i] is None or formatted['reasoning_critic'][i] is None: 96 | format_error = True 97 | 98 | ret = {'response': response_orig, 'formatted': formatted, 'format_error': 
format_error} 99 | if response_history is not None: 100 | ret['response_history'] = response_history 101 | return ret 102 | 103 | 104 | def main(args): 105 | with open(args.input) as f: 106 | data = [json.loads(line) for line in f] 107 | 108 | prompt_fname = os.path.join(os.path.dirname(__file__), 'prompts/critique.txt') 109 | with open(prompt_fname) as f: 110 | PROMPT = f.read() 111 | 112 | prompts = [format_prompt(PROMPT, item) for item in data] 113 | print("\n--- Example prompt") 114 | print(prompts[0]) 115 | images = [item['image'] for item in data] 116 | responses = infer(prompts, images, args) 117 | 118 | responses = [format_response(response, len(item['response']['reasoning'])) 119 | for response, item in zip(responses, data)] 120 | for i in range(5): 121 | print("\n--- Example parse:\n") 122 | print(responses[i]['response']) 123 | print("\n--->\n") 124 | print(json.dumps(responses[i]['formatted'], indent=2)) 125 | 126 | if args.output is not None: 127 | print("Save outputs to", args.output) 128 | os.makedirs(os.path.dirname(args.output), exist_ok=True) 129 | with open(args.output, 'w') as f: 130 | for r in responses: 131 | f.write(json.dumps(r) + '\n') 132 | 133 | evaluate_critique(data, responses) 134 | 135 | 136 | if __name__ == "__main__": 137 | parser = argparse.ArgumentParser() 138 | # model and inference parameters 139 | parser.add_argument('--model', default="gpt-4o-2024-08-06") # auto if we're using a locally served model 140 | 141 | # openai api-based 142 | parser.add_argument('--api_key', default='YOUR_API_KEY') 143 | parser.add_argument('--base_url', default=None) 144 | parser.add_argument('--n_proc', default=16, type=int) 145 | parser.add_argument('--launch_locally', default=None, choices=['lmdeploy', 'vllm', 'sglang']) 146 | 147 | # input output 148 | parser.add_argument('--input', default='test.jsonl') 149 | parser.add_argument('--output', default=None) 150 | args = parser.parse_args() 151 | 152 | if args.launch_locally: 153 | process, port = 
launch_locally(args.launch_locally, args.model) 154 | args.model = 'auto' 155 | args.base_url = f'http://0.0.0.0:{port}/v1' 156 | 157 | try: 158 | main(args) 159 | finally: 160 | if args.launch_locally: 161 | process.kill() 162 | -------------------------------------------------------------------------------- /infer_critique_lookback.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | import os 4 | 5 | import google.generativeai as genai 6 | import tqdm 7 | 8 | import utils 9 | from evaluate import evaluate_critique 10 | from infer_critique import format_response 11 | from utils import launch_locally, func, get_pool 12 | 13 | prompt_dir = os.path.join(os.path.dirname(__file__), 'prompts/') 14 | PROMPT_PROBLEM_SOLVER = "{{{QUESTION}}}\nThink step by step, and then provide your final answer." 15 | with open(os.path.join(prompt_dir, 'lookback_visual-query.txt')) as f: 16 | PROMPT_SCHEDULE_VISUAL_QUERY = f.read() 17 | with open(os.path.join(prompt_dir, 'lookback_synthesize.txt')) as f: 18 | PROMPT_SYNTHESIZE = f.read() 19 | 20 | 21 | def func_agent(obj): 22 | item, image = obj 23 | 24 | # problem solver 25 | ref_answer = func((image, item['question'])) 26 | 27 | def format_prompt(prompt, item): 28 | prompts = [] 29 | for i in range(len(item['response']['reasoning'])): 30 | reasoning = "\n".join(["{:d}. {:s}".format(j + 1, x) 31 | for j, x in enumerate(item['response']['reasoning'][:i + 1])]) 32 | which_step = {1: "first", 2: "second", 3: "third", 4: "fourth", 5: "fifth"}[i + 1] 33 | prompts.append( 34 | prompt.replace("{{{QUESTION}}}", item['question']). \ 35 | replace("{{{ANSWER}}}", str(item['response']['answer'])).replace("{{{REASONING}}}", reasoning). 
36 | replace("{{{WHICH_STEP}}}", which_step) 37 | ) 38 | return prompts 39 | 40 | def extract_verify_questions(text): 41 | if 'N/A' in text.strip(): 42 | return [] 43 | 44 | text = "\n" + text.strip() 45 | text = "\n1.".join(text.split("\n1.")[1:]) 46 | lines = text.splitlines() 47 | for i, line in enumerate(lines): 48 | if line.startswith("{:d}.".format(i + 1)): 49 | lines[i] = ".".join(line.split('.')[1:]) 50 | lines[i] = lines[i].strip() 51 | return lines 52 | 53 | # visual verification 54 | prompt = format_prompt(PROMPT_SCHEDULE_VISUAL_QUERY, item) 55 | visual_questions = [func((image, p)) for p in prompt] 56 | visual_questions = [extract_verify_questions(q) for q in visual_questions] 57 | visual_answers = [[func((image, pp + ' Answer briefly.')) for pp in p] for p in visual_questions] 58 | 59 | def format_prompt_synthesize(prompt_base): 60 | reasoning = [] 61 | for i, (r, q, a) in enumerate(zip(item['response']['reasoning'], visual_questions, visual_answers)): 62 | reasoning.append("{:d}. {:s}".format(i + 1, r)) 63 | reasoning = "\n".join(reasoning) 64 | 65 | visual_info = [] 66 | for q, a in zip(visual_questions, visual_answers): 67 | for qq, aa in zip(q, a): 68 | visual_info.append("* {} - {}".format(qq, aa)) 69 | visual_info = "\n".join(visual_info) 70 | if visual_info.strip() == '': 71 | visual_info = "N/A" 72 | 73 | prompt = prompt_base.replace("{{{QUESTION}}}", item['question']). \ 74 | replace("{{{ANSWER}}}", str(item['response']['answer'])).replace("{{{REASONING}}}", reasoning). 
\ 75 | replace("{{{REFERNCE_ANSWER}}}", ref_answer).replace("{{{VISUAL_INFO}}}", visual_info) 76 | 77 | prompt_lines = prompt.splitlines() 78 | final_prompt_lines = [] 79 | for line in prompt_lines: 80 | if '{{{REPEAT_BY_N_STEP}}}' in line: 81 | for i in range(len(item['response']['reasoning'])): 82 | final_prompt_lines.append(line.replace('{{{REPEAT_BY_N_STEP}}}', str(i + 1))) 83 | else: 84 | final_prompt_lines.append(line) 85 | prompt = "\n".join(final_prompt_lines) 86 | return prompt 87 | 88 | # synthesize 89 | prompt = format_prompt_synthesize(PROMPT_SYNTHESIZE) 90 | ret = func((image, prompt)) 91 | return ret, { 92 | 'ref_answer': ref_answer, 'visual_questions': visual_questions, 'visual_answers': visual_answers, 93 | } 94 | 95 | 96 | def infer(data, images): 97 | utils.args = args 98 | if args.model == "gemini-1.5-pro": 99 | genai.configure(api_key=args.api_key) 100 | responses = [] 101 | assert len(data) == len(images) 102 | with get_pool(args.n_proc) as p: 103 | for response, additional_info in tqdm.tqdm(p.imap(func_agent, zip(data, images)), total=len(images)): 104 | responses.append((response, additional_info)) 105 | if len(responses) <= 5: 106 | print("\n\n------------------------- Example output:", len(responses)) 107 | print(responses[-1][0]) 108 | print("\n--- Additional info:") 109 | print(json.dumps(additional_info, indent=2)) 110 | return responses 111 | 112 | 113 | data = [] 114 | 115 | 116 | def main(args): 117 | images = [item['image'] for item in data] 118 | responses_raw = infer(data, images) 119 | 120 | responses = [] 121 | for (response, additional_info), item in zip(responses_raw, data): 122 | response = format_response(response, len(item['response']['reasoning'])) 123 | response['additional_info'] = additional_info 124 | responses.append(response) 125 | 126 | if args.output is not None: 127 | print("Save outputs to", args.output) 128 | os.makedirs(os.path.dirname(args.output), exist_ok=True) 129 | with open(args.output, 'w') as f: 130 | 
for r in responses: 131 | f.write(json.dumps(r) + '\n') 132 | 133 | evaluate_critique(data, responses) 134 | 135 | 136 | if __name__ == "__main__": 137 | parser = argparse.ArgumentParser() 138 | # model and inference parameters 139 | parser.add_argument('--model', default="gpt-4o-2024-08-06") # auto if we're using a locally served model 140 | 141 | # openai api-based 142 | parser.add_argument('--api_key', default='YOUR_API_KEY') 143 | parser.add_argument('--base_url', default=None) 144 | parser.add_argument('--n_proc', default=16, type=int) 145 | parser.add_argument('--launch_locally', default=None, choices=['lmdeploy', 'vllm', 'sglang']) 146 | 147 | # input output 148 | parser.add_argument('--input', default='test.jsonl') 149 | parser.add_argument('--output', default=None) 150 | args = parser.parse_args() 151 | 152 | if args.launch_locally: 153 | process, port = launch_locally(args.launch_locally, args.model) 154 | args.model = 'auto' 155 | args.base_url = f'http://0.0.0.0:{port}/v1' 156 | 157 | with open(args.input) as f: 158 | data = [json.loads(line) for line in f] 159 | 160 | try: 161 | main(args) 162 | finally: 163 | if args.launch_locally: 164 | process.kill() 165 | -------------------------------------------------------------------------------- /prompts/correction.txt: -------------------------------------------------------------------------------- 1 | You are given an image, a question about the image, a multi-step reasoning process leading to an answer, and the critique for {{{CRITIQUE_SETTING}}}. Based on this information, think step by step and provide the correct answer. Your response should end with a json dictionary as follows: 2 | ```json 3 | {"answer": ANSWER} 4 | ``` 5 | ANSWER should be {{{ANSWER_FORMAT}}}. 
6 | 7 | # Question: {{{QUESTION}}} 8 | 9 | # Reasoning: 10 | {{{REASONING}}} -------------------------------------------------------------------------------- /prompts/critique.txt: -------------------------------------------------------------------------------- 1 | You are given an image, a question about the image, a reasoning process involving multiple steps, and a final answer. Evaluate the accuracy of each reasoning step and the final answer. For each reasoning step, evaluate whether it is correct or incorrect. If it is incorrect, briefly explain why. Then, evaluate the final answer as correct or incorrect. 2 | 3 | Your response should be a json in the following format: 4 | ```json 5 | { 6 | "step_{{{REPEAT_BY_N_STEP}}}": {"correctness": true or false, "explanation": "Write your explanation here"}, 7 | "answer_correctness": true or false 8 | } 9 | ``` 10 | 11 | --- 12 | 13 | Instructions for evaluating reasoning steps: 14 | * For each reasoning step, evaluate whether it is correct or incorrect based on the accuracy of the factual information and logical calculations it contains. 15 | * Evaluate each step in isolation. 16 | * You do not need to evaluate the importance of the step in achieving the correct final answer; focus solely on the correctness within that step itself. 17 | 18 | --- 19 | 20 | # Question: {{{QUESTION}}} 21 | 22 | # Reasoning: 23 | {{{REASONING}}} 24 | 25 | # Answer: {{{ANSWER}}} -------------------------------------------------------------------------------- /prompts/gpt_evaluate.txt: -------------------------------------------------------------------------------- 1 | Your task is to evaluate a **critique** of reasoning. 2 | 3 | You are given a question about an image, an incorrect chain-of-thought trying to answer the question, and a **critique** that explains why the {{{WHICH_STEP}}} step of the chain-of-thought is incorrect.
You are required to focus on the {{{WHICH_STEP}}} step, and analyze whether the critique correctly identifies the source of error. 4 | 5 | For reference, you will be provided with the ground truth critique for each individual step. Evaluate the critique by comparing against the ground truth. Focus on whether the critique accurately identifies the **core mistake**. If the critique addresses the core error and does not contain factual or logical errors, minor deviations or omissions in reasoning or explanation should be considered as correct. 6 | 7 | Think step by step, then provide your judgment. Your response should end with either: 8 | 9 | # Judgment: the critique is correct. 10 | 11 | or 12 | 13 | # Judgment: the critique is incorrect. 14 | 15 | --- Example 1 16 | 17 | # Question: Question: 26 18 | Data values represented by the bar labeled "10" in the histogram below fall into which range? 19 | 20 | A. 7.5 up to 12.5 21 | B. 7.25 up to 12.75 22 | C. 8.5 up to 11.5 23 | D. 8.75 up to 11.75 24 | E. 8.75 up to 11.25 25 | 26 | # Chain-of-thought: 27 | 1. The bar labeled '10' in the histogram corresponds to data values between 7.5 and 12.5. 28 | - Ground truth critique: incorrect. 7.5 and 12.5 are not boundaries of the interval. The boundaries are midway between the bar labels: left boundary: (7.5 + 10)/2 = 8.75; right boundary: (10 + 12.5)/2 = 11.25. 29 | - Critique to be evaluated: incorrect. The bar labeled '10' in the histogram actually corresponds to data values between 10.0 and 12.5. The question incorrectly stated that the range is from 7.5 to 12.5. 30 | 31 | # Evaluation: 32 | The core error is that step 1 incorrectly identifies the span as 7.5 - 12.5. Although the critique correctly identifies the core error, the critique contains a factual error that the span should be 10.0 - 12.5, while according to the ground truth critique, the span should be 8.75 - 11.25. Therefore, the critique is incorrect. 33 | 34 | # Judgment: the critique is incorrect.
35 | 36 | --- Example 2 37 | 38 | # Question: The small metallic object that is on the left side of the tiny rubber cylinder on the right side of the big shiny cylinder is what shape? 39 | 40 | # Chain-of-thought: 41 | 1. First, identify the large shiny cylinder, which is the gold object in the foreground. 42 | - Ground truth critique: correct. 43 | 2. To its right is a tiny rubber cylinder, which is turquoise. 44 | - Ground truth critique: correct. 45 | 3. To the left of this turquoise cylinder is a small metallic object, which appears to be a sphere. 46 | - Ground truth critique: incorrect. The small metallic object is a cube, not a sphere. 47 | 4. Therefore, the shape of the object is a sphere. 48 | - Ground truth critique: incorrect. The previous incorrect sentence leads to this incorrect sentence. 49 | - Critique to be evaluated: incorrect. The shape of the object is a cube, not a sphere. 50 | 51 | # Evaluation: 52 | Based on the ground truth, the error of step 4 originates from step 3's error. Specifically, the error is that the small metallic object is a cube, not a sphere. The critique correctly identifies that the error comes from the object being a cube, which aligns with the ground truth critique. 53 | 54 | # Judgment: the critique is correct. 55 | 56 | --- 57 | 58 | # Question: {{{QUESTION}}} 59 | 60 | # Chain-of-thought: 61 | {{{COT}}} -------------------------------------------------------------------------------- /prompts/lookback_synthesize.txt: -------------------------------------------------------------------------------- 1 | You are given an image, a question about the image, a reasoning process involving multiple steps, and a final answer. Evaluate the accuracy of each reasoning step and the final answer. For each reasoning step, evaluate whether it is correct or incorrect. If it is incorrect, briefly explain why. Then, evaluate the final answer as correct or incorrect.
2 | 3 | Your response should be a json in the following format: 4 | ```json 5 | { 6 | "step_{{{REPEAT_BY_N_STEP}}}": {"correctness": true or false, "explanation": "Write your explanation here"}, 7 | "answer_correctness": true or false 8 | } 9 | ``` 10 | 11 | --- 12 | 13 | Instructions for evaluating reasoning steps: 14 | * For each reasoning step, evaluate whether it is correct or incorrect based on the accuracy of the factual information and logical calculations it contains. 15 | * Evaluate each step in isolation. 16 | * You do not need to evaluate the importance of the step in achieving the correct final answer; focus solely on the correctness within that step itself. 17 | 18 | To help your evaluation, we provide the following additional information: 19 | 1. Question-answer pairs to verify the visual information. 20 | 2. A candidate answer, which MAY OR MAY NOT be correct. 21 | 22 | --- To be evaluated: 23 | 24 | # Question: {{{QUESTION}}} 25 | 26 | # Reasoning: 27 | {{{REASONING}}} 28 | 29 | # Answer: {{{ANSWER}}} 30 | 31 | --- Visual information 32 | 33 | {{{VISUAL_INFO}}} 34 | 35 | --- Reference answer 36 | 37 | {{{REFERNCE_ANSWER}}} -------------------------------------------------------------------------------- /prompts/lookback_visual-query.txt: -------------------------------------------------------------------------------- 1 | You are given an image and a reasoning process around this image. To evaluate the accuracy of the last step, you need to identify information from the image. List all questions necessary to verify against the image. 2 | 3 | Detailed instructions: 4 | * Focus only on the last reasoning step. No need to verify visual information from previous steps. 5 | * Each question should focus on verifying visual information from the image, without involving any reasoning. 6 | * Keep questions simple. Break down complex questions into smaller, independent ones. 
7 | * Ensure each question can be answered in isolation, without needing context from the reasoning process. 8 | * If the last step does not involve any information from the image, you can respond with N/A. 9 | 10 | Your response should be a numbered list as follows: 11 | 1. Question 1 12 | 2. Question 2 13 | ... 14 | 15 | --- Example input: 16 | 17 | 1. The cat is sitting on a cushion placed on the toilet seat. 18 | 2. The cat's body is oriented towards the camera, and its head is also facing the camera. 19 | 3. The toilet is directly behind the cat, and the cat is not showing any signs of turning away from it. 20 | 21 | --- Example output: 22 | 23 | 1. Is the toilet directly behind the cat? 24 | 2. Is the cat turning away from the toilet that is behind the cat? 25 | 26 | --- Input: 27 | 28 | {{{REASONING}}} 29 | 30 | --- Output: -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | anthropic>=0.36.1 2 | google-generativeai>=0.8.3 3 | openai>=1.41.0 4 | Levenshtein>=0.26.0 5 | word2number>=1.1 6 | tabulate 7 | -------------------------------------------------------------------------------- /src_evaluation/CLEVR_evaluation.py: -------------------------------------------------------------------------------- 1 | import string 2 | 3 | from word2number import w2n 4 | 5 | 6 | def safe_equal(prediction, answer): 7 | try: 8 | if prediction == answer: 9 | return True 10 | return False 11 | except Exception as e: 12 | print(e) 13 | return False 14 | 15 | 16 | def CLEVREvaluate(answer, label, meta_data): 17 | is_integer = False 18 | 19 | try: 20 | label = int(label) 21 | is_integer = True 22 | except: 23 | pass 24 | 25 | if is_integer: 26 | try: 27 | answer = int(answer) 28 | except: 29 | try: 30 | answer = w2n.word_to_num(''.join([a for a in answer if a not in string.punctuation])) 31 | answer = str(int(answer)) 32 | except: 33 | answer = None 
34 | return safe_equal(answer, label) 35 | 36 | else: 37 | try: 38 | translator = str.maketrans('', '', string.punctuation) 39 | answer = answer.translate(translator) 40 | 41 | answer = answer.split(" ")[0] 42 | 43 | answer = answer.lower() 44 | return safe_equal(answer, label) 45 | except: 46 | return False 47 | -------------------------------------------------------------------------------- /src_evaluation/EmbSpatial_evaluation.py: -------------------------------------------------------------------------------- 1 | import re 2 | 3 | from .MathVista_evaluation import get_most_similar 4 | 5 | 6 | def safe_equal(prediction, answer): 7 | try: 8 | if prediction == answer: 9 | return True 10 | return False 11 | except Exception as e: 12 | print(e) 13 | return False 14 | 15 | 16 | def EmbSpatial_evaluation(extraction, label, meta_data): 17 | choices = meta_data['answer_options'] 18 | 19 | # extract "A" from "(A) text" 20 | letter = re.findall(r'\(([a-zA-Z])\)', extraction) 21 | if len(letter) > 0: 22 | extraction = letter[0].upper() 23 | 24 | # also try to extract \"A\" from '"A"' 25 | letter = re.search(r'\"[a-zA-Z]\"', extraction) 26 | if letter: 27 | extraction = letter.group() 28 | 29 | options = [chr(ord('A') + i) for i in range(len(choices))] 30 | assert label in options 31 | 32 | if extraction not in options: 33 | # select the most similar option 34 | choice = get_most_similar(extraction, choices) 35 | extraction = options[choices.index(choice)] 36 | assert extraction in options 37 | 38 | return safe_equal(extraction, label) 39 | -------------------------------------------------------------------------------- /src_evaluation/FigureQA_evaluation.py: -------------------------------------------------------------------------------- 1 | from .VSR_evaluation import VSREvaluate 2 | 3 | 4 | def FigureQAEvaluate(answer, label, meta_data): 5 | if label == 0: 6 | label = 'False' 7 | else: 8 | assert label == 1 9 | label = 'True' 10 | return VSREvaluate(answer, label, 
meta_data) 11 | -------------------------------------------------------------------------------- /src_evaluation/GQA_evaluation.py: -------------------------------------------------------------------------------- 1 | import re 2 | 3 | 4 | def general_postprocessing(prediction): 5 | prediction = str(prediction) 6 | 7 | prediction = prediction.replace('\n', ' ') 8 | prediction = prediction.replace('\t', ' ') 9 | prediction = prediction.strip() 10 | prediction = prediction.lower() 11 | 12 | if prediction == 'true': 13 | prediction = 'yes' 14 | elif prediction == 'false': 15 | prediction = 'no' 16 | return prediction 17 | 18 | 19 | # For evaluation 20 | contractions = {"aint": "ain't", "arent": "aren't", "cant": "can't", "couldve": "could've", 21 | "couldnt": "couldn't", "couldn'tve": "couldn't've", "couldnt've": "couldn't've", 22 | "didnt": "didn't", "doesnt": "doesn't", "dont": "don't", "hadnt": "hadn't", 23 | "hadnt've": "hadn't've", "hadn'tve": "hadn't've", "hasnt": "hasn't", "havent": "haven't", 24 | "hed": "he'd", "hed've": "he'd've", "he'dve": "he'd've", "hes": "he's", "howd": "how'd", 25 | "howll": "how'll", "hows": "how's", "Id've": "I'd've", "I'dve": "I'd've", "Im": "I'm", 26 | "Ive": "I've", "isnt": "isn't", "itd": "it'd", "itd've": "it'd've", "it'dve": "it'd've", 27 | "itll": "it'll", "let's": "let's", "maam": "ma'am", "mightnt": "mightn't", 28 | "mightnt've": "mightn't've", "mightn'tve": "mightn't've", "mightve": "might've", 29 | "mustnt": "mustn't", "mustve": "must've", "neednt": "needn't", "notve": "not've", 30 | "oclock": "o'clock", "oughtnt": "oughtn't", "ow's'at": "'ow's'at", "'ows'at": "'ow's'at", 31 | "'ow'sat": "'ow's'at", "shant": "shan't", "shed've": "she'd've", "she'dve": "she'd've", 32 | "she's": "she's", "shouldve": "should've", "shouldnt": "shouldn't", 33 | "shouldnt've": "shouldn't've", "shouldn'tve": "shouldn't've", "somebody'd": "somebodyd", 34 | "somebodyd've": "somebody'd've", "somebody'dve": "somebody'd've", 35 | "somebodyll": 
"somebody'll", "somebodys": "somebody's", "someoned": "someone'd", 36 | "someoned've": "someone'd've", "someone'dve": "someone'd've", "someonell": "someone'll", 37 | "someones": "someone's", "somethingd": "something'd", "somethingd've": "something'd've", 38 | "something'dve": "something'd've", "somethingll": "something'll", "thats": "that's", 39 | "thered": "there'd", "thered've": "there'd've", "there'dve": "there'd've", 40 | "therere": "there're", "theres": "there's", "theyd": "they'd", "theyd've": "they'd've", 41 | "they'dve": "they'd've", "theyll": "they'll", "theyre": "they're", "theyve": "they've", 42 | "twas": "'twas", "wasnt": "wasn't", "wed've": "we'd've", "we'dve": "we'd've", 43 | "weve": "we've", "werent": "weren't", "whatll": "what'll", "whatre": "what're", 44 | "whats": "what's", "whatve": "what've", "whens": "when's", "whered": "where'd", 45 | "wheres": "where's", "whereve": "where've", "whod": "who'd", "whod've": "who'd've", 46 | "who'dve": "who'd've", "wholl": "who'll", "whos": "who's", "whove": "who've", 47 | "whyll": "why'll", "whyre": "why're", "whys": "why's", "wont": "won't", 48 | "wouldve": "would've", "wouldnt": "wouldn't", "wouldnt've": "wouldn't've", 49 | "wouldn'tve": "wouldn't've", "yall": "y'all", "yall'll": "y'all'll", "y'allll": "y'all'll", 50 | "yall'd've": "y'all'd've", "y'alld've": "y'all'd've", "y'all'dve": "y'all'd've", 51 | "youd": "you'd", "youd've": "you'd've", "you'dve": "you'd've", "youll": "you'll", 52 | "youre": "you're", "youve": "you've"} 53 | 54 | manualMap = {'none': '0', 55 | 'zero': '0', 56 | 'one': '1', 57 | 'two': '2', 58 | 'three': '3', 59 | 'four': '4', 60 | 'five': '5', 61 | 'six': '6', 62 | 'seven': '7', 63 | 'eight': '8', 64 | 'nine': '9', 65 | 'ten': '10' 66 | } 67 | articles = ['a', 68 | 'an', 69 | 'the' 70 | ] 71 | 72 | periodStrip = re.compile("(?!<=\d)(\.)(?!\d)") 73 | commaStrip = re.compile("(\d)(\,)(\d)") 74 | punct = [';', r"/", '[', ']', '"', '{', '}', 75 | '(', ')', '=', '+', '\\', '_', '-', 76 | '>', 
'<', '@', '`', ',', '?', '!'] 77 | 78 | max_words = 50 79 | 80 | 81 | def processPunctuation(inText): 82 | outText = inText 83 | for p in punct: 84 | if (p + ' ' in inText or ' ' + p in inText) or (re.search(commaStrip, inText) != None): 85 | outText = outText.replace(p, '') 86 | else: 87 | outText = outText.replace(p, ' ') 88 | outText = periodStrip.sub("", outText, re.UNICODE) 89 | return outText 90 | 91 | 92 | def processDigitArticle(inText): 93 | outText = [] 94 | tempText = inText.lower().split() 95 | for word in tempText: 96 | word = manualMap.setdefault(word, word) 97 | if word not in articles: 98 | outText.append(word) 99 | else: 100 | pass 101 | for wordId, word in enumerate(outText): 102 | if word in contractions: 103 | outText[wordId] = contractions[word] 104 | outText = ' '.join(outText) 105 | return outText 106 | 107 | 108 | def post_process(prediction, stem=True): 109 | """ 110 | Code from https://github.com/GT-Vision-Lab/VQA/blob/master/PythonEvaluationTools/vqaEvaluation/vqaEval.py, 111 | as indicated here https://okvqa.allenai.org/leaderboard.html 112 | :return: 113 | """ 114 | prediction = general_postprocessing(prediction) 115 | 116 | prediction = prediction.replace('\n', ' ') 117 | prediction = prediction.replace('\t', ' ') 118 | prediction = prediction.strip() 119 | prediction = processPunctuation(prediction) 120 | prediction = processDigitArticle(prediction) 121 | return prediction 122 | 123 | 124 | def GQAEvaluate(answer, label, meta_data): 125 | try: 126 | processed = post_process(answer) 127 | if processed == post_process(label): 128 | return True 129 | except: 130 | return False 131 | return False 132 | -------------------------------------------------------------------------------- /src_evaluation/HallusionBench_evaluation.py: -------------------------------------------------------------------------------- 1 | def HallusionBenchEvaluate(text, label, meta_data): 2 | # Only keep the first sentence 3 | if text.find('.') != -1: 4 | text = 
text.split('.')[0] 5 | 6 | text = text.replace(',', '') 7 | words = text.split(' ') 8 | if 'No' in words or 'not' in words or 'no' in words: 9 | answer = 'No' 10 | else: 11 | answer = 'Yes' 12 | 13 | if answer == label: 14 | return True 15 | else: 16 | return False 17 | -------------------------------------------------------------------------------- /src_evaluation/MMMU_evaluation.py: -------------------------------------------------------------------------------- 1 | """Response Parsing and Evaluation for various models""" 2 | 3 | import random 4 | import re 5 | 6 | random.seed(42) 7 | import numpy as np 8 | 9 | 10 | # ----------- Process Multi-choice ------------- 11 | def parse_multi_choice_response(response, all_choices, index2ans): 12 | """ 13 | Parse the prediction from the generated response. 14 | Return the predicted index e.g., A, B, C, D. 15 | """ 16 | for char in [',', '.', '!', '?', ';', ':', "'"]: 17 | response = response.strip(char) 18 | response = " " + response + " " # add space to avoid partial match 19 | 20 | index_ans = True 21 | ans_with_brack = False 22 | candidates = [] 23 | for choice in all_choices: # e.g., (A) (B) (C) (D) 24 | if f'({choice})' in response: 25 | candidates.append(choice) 26 | ans_with_brack = True 27 | 28 | if len(candidates) == 0: 29 | for choice in all_choices: # e.g., A B C D 30 | if f' {choice} ' in response: 31 | candidates.append(choice) 32 | 33 | # if all above doesn't get candidates, check if the content is larger than 5 tokens and try to parse the example 34 | if len(candidates) == 0 and len(response.split()) > 5: 35 | for index, ans in index2ans.items(): 36 | if ans.lower() in response.lower(): 37 | candidates.append(index) 38 | index_ans = False # it's content ans. 39 | 40 | if len(candidates) == 0: # still not get answer, randomly choose one. 
41 | pred_index = random.choice(all_choices) 42 | elif len(candidates) > 1: 43 | start_indexes = [] 44 | if index_ans: 45 | if ans_with_brack: 46 | for can in candidates: 47 | index = response.rfind(f'({can})') 48 | start_indexes.append(index) # -1 will be ignored anyway 49 | # start_indexes = [generated_response.index(f'({can})') for can in candidates] 50 | else: 51 | for can in candidates: 52 | index = response.rfind(f" {can} ") 53 | start_indexes.append(index) 54 | else: 55 | for can in candidates: 56 | index = response.lower().rfind(index2ans[can].lower()) 57 | start_indexes.append(index) 58 | # get the last one 59 | pred_index = candidates[np.argmax(start_indexes)] 60 | else: # if only one candidate, use it. 61 | pred_index = candidates[0] 62 | 63 | return pred_index 64 | 65 | 66 | # ----------- Process Open ------------- 67 | def check_is_number(string): 68 | """ 69 | Check if the given string a number. 70 | """ 71 | try: 72 | float(string.replace(',', '')) 73 | return True 74 | except ValueError: 75 | # check if there's comma inside 76 | return False 77 | 78 | 79 | def normalize_str(string): 80 | """ 81 | Normalize the str to lower case and make them float numbers if possible. 82 | """ 83 | # check if characters in the string 84 | 85 | # if number, numerize it. 86 | string = string.strip() 87 | 88 | is_number = check_is_number(string) 89 | 90 | if is_number: 91 | string = string.replace(',', '') 92 | string = float(string) 93 | # leave 2 decimal 94 | string = round(string, 2) 95 | return [string] 96 | else: # it's likely to be a string 97 | # lower it 98 | string = string.lower() 99 | if len(string) == 1: 100 | return [" " + string, string + " "] # avoid trivial matches 101 | return [string] 102 | 103 | 104 | def extract_numbers(string): 105 | """ 106 | Exact all forms of numbers from a string with regex. 
107 | """ 108 | # Pattern for numbers with commas 109 | pattern_commas = r'-?\b\d{1,3}(?:,\d{3})+\b' 110 | # Pattern for scientific notation 111 | pattern_scientific = r'-?\d+(?:\.\d+)?[eE][+-]?\d+' 112 | # Pattern for simple numbers without commas 113 | pattern_simple = r'-?(?:\d+\.\d+|\.\d+|\d+\b)(?![eE][+-]?\d+)(?![,\d])' 114 | 115 | # Extract numbers with commas 116 | numbers_with_commas = re.findall(pattern_commas, string) 117 | # Extract numbers in scientific notation 118 | numbers_scientific = re.findall(pattern_scientific, string) 119 | # Extract simple numbers without commas 120 | numbers_simple = re.findall(pattern_simple, string) 121 | 122 | # Combine all extracted numbers 123 | all_numbers = numbers_with_commas + numbers_scientific + numbers_simple 124 | return all_numbers 125 | 126 | 127 | def parse_open_response(response): 128 | """ 129 | Parse the prediction from the generated response. 130 | Return a list of predicted strings or numbers. 131 | """ 132 | # content = content.strip("\n").strip(".").strip(" ") 133 | 134 | key_responses = [response] 135 | 136 | pred_list = key_responses.copy() # keep the original string response 137 | for resp in key_responses: 138 | pred_list.extend(extract_numbers(resp)) 139 | 140 | tmp_pred_list = [] 141 | for i in range(len(pred_list)): 142 | tmp_pred_list.extend(normalize_str(pred_list[i])) 143 | pred_list = tmp_pred_list 144 | 145 | # remove duplicates 146 | pred_list = list(set(pred_list)) 147 | 148 | return pred_list 149 | 150 | 151 | # ----------- Evaluation ------------- 152 | 153 | def eval_multi_choice(gold_i, pred_i): 154 | """ 155 | Evaluate a multiple choice instance. 
156 | """ 157 | correct = False 158 | # only they are exactly the same, we consider it as correct 159 | if isinstance(gold_i, list): 160 | for answer in gold_i: 161 | if answer == pred_i: 162 | correct = True 163 | break 164 | else: # gold_i is a string 165 | if gold_i == pred_i: 166 | correct = True 167 | return correct 168 | 169 | 170 | def eval_open(gold_i, pred_i): 171 | """ 172 | Evaluate an open question instance 173 | """ 174 | correct = False 175 | if isinstance(gold_i, list): 176 | # use float to avoid trivial matches 177 | norm_answers = [] 178 | for answer in gold_i: 179 | norm_answers.extend(normalize_str(answer)) 180 | else: 181 | norm_answers = normalize_str(gold_i) 182 | for pred in pred_i: # pred is already normalized in parse response phase 183 | if isinstance(pred, str): # if it's a string, then find if ans in the pred_i 184 | for norm_ans in norm_answers: 185 | # only see if the string answer in the string pred 186 | if isinstance(norm_ans, str) and norm_ans in pred: 187 | if not correct: 188 | correct = True 189 | break 190 | else: # it's a float number 191 | if pred in norm_answers: 192 | if not correct: 193 | correct = True 194 | break 195 | return correct 196 | 197 | 198 | def MMMU_evaluate(answer, label, meta_data): 199 | """ 200 | Evaluation for multiple choice and open questions. 
201 | """ 202 | question_type = meta_data["question_type"] 203 | 204 | if question_type == 'multiple-choice': 205 | extraction = None 206 | 207 | letter = re.findall(r'\(([a-zA-Z])\)', answer) 208 | if len(letter) > 0: 209 | extraction = letter[0].upper() 210 | 211 | if extraction is None: 212 | letter = re.search(r'\"[a-zA-Z]\"', answer) 213 | if letter: 214 | extraction = letter.group() 215 | 216 | if extraction is None: # we don't have options, we can't match options, so we just extract first letter anyway 217 | letter = re.search(r'[a-zA-Z]', answer) 218 | if letter: 219 | extraction = letter.group() 220 | 221 | if extraction is None: 222 | extraction = answer 223 | return eval_multi_choice(label, extraction) 224 | 225 | else: # open question 226 | pred_i = parse_open_response(str(answer)) 227 | return eval_open(label, pred_i) 228 | -------------------------------------------------------------------------------- /src_evaluation/MMVet_evaluation.py: -------------------------------------------------------------------------------- 1 | import re 2 | 3 | 4 | def safe_equal(prediction, answer): 5 | try: 6 | if prediction == answer: 7 | return True 8 | return False 9 | except Exception as e: 10 | print(e) 11 | return False 12 | 13 | 14 | def has_numbers(inputString): 15 | return bool(re.search(r'\d', inputString)) 16 | 17 | 18 | def convert_string(input_string): 19 | # Define the regular expression pattern 20 | pattern = r'[^0-9.,/]' 21 | 22 | # Use re.sub() to replace unwanted characters 23 | result = re.sub(pattern, '', input_string) 24 | 25 | return result 26 | 27 | 28 | def MMVetEvaluate(answer, label, meta_data): 29 | labels = label.split("") 30 | 31 | for label in labels: 32 | if has_numbers(label): 33 | label = convert_string(label) 34 | if "," in label or "/" in label: 35 | continue 36 | 37 | try: 38 | if "." 
in label: 39 | precision = len(label.split(".")[1]) 40 | answer = str(round(float(convert_string(str(answer))), precision)) 41 | else: 42 | answer = str(round(float(convert_string(str(answer))))) 43 | except: 44 | continue 45 | 46 | if safe_equal(answer, label): 47 | return True 48 | 49 | else: 50 | if safe_equal(answer, label): 51 | return True 52 | 53 | return False -------------------------------------------------------------------------------- /src_evaluation/MathVision_Evaluation.py: -------------------------------------------------------------------------------- 1 | import re 2 | 3 | 4 | def safe_equal(prediction, answer): 5 | try: 6 | if prediction == answer: 7 | return True 8 | return False 9 | except Exception as e: 10 | print(e) 11 | return False 12 | 13 | 14 | def MathVisionEvaluate(answer, label, meta_data): 15 | if label[0] == "(": 16 | extracted_label = label[1:-1].split(",") 17 | 18 | try: 19 | if answer[0] == "(" and answer[-1] == ")": 20 | extracted_answer = answer[1:-1].split(",") 21 | 22 | extracted_answer = [str(round(float(e))) for e in extracted_answer] 23 | 24 | for i in range(len(extracted_answer)): 25 | if safe_equal(extracted_answer[i], extracted_label[i]) == False: 26 | return False 27 | 28 | return True 29 | except: 30 | return False 31 | 32 | try: 33 | float(answer) 34 | is_number = True 35 | except: 36 | is_number = False 37 | 38 | if is_number == True: 39 | if "." 
in label: 40 | extracted_answer = str(round(float(answer), 1)) 41 | else: 42 | extracted_answer = str(round(float(answer))) 43 | 44 | return safe_equal(extracted_answer, label) 45 | 46 | else: 47 | letter = re.search(r'[a-zA-Z]', answer) 48 | if letter: 49 | answer = letter.group() 50 | else: 51 | answer = None 52 | 53 | return safe_equal(answer, label) 54 | -------------------------------------------------------------------------------- /src_evaluation/MathVista_evaluation.py: -------------------------------------------------------------------------------- 1 | import re 2 | 3 | from Levenshtein import distance 4 | 5 | 6 | def get_most_similar(prediction, choices): 7 | """ 8 | Use the Levenshtein distance (or edit distance) to determine which of the choices is most similar to the given prediction 9 | """ 10 | distances = [distance(prediction, choice) for choice in choices] 11 | ind = distances.index(min(distances)) 12 | return choices[ind] 13 | # return min(choices, key=lambda choice: distance(prediction, choice)) 14 | 15 | 16 | def normalize_extracted_answer(extraction, choices, question_type, answer_type, precision): 17 | """ 18 | Normalize the extracted answer to match the answer type 19 | """ 20 | if question_type == 'multi_choice': 21 | # make sure the extraction is a string 22 | if isinstance(extraction, str): 23 | extraction = extraction.strip() 24 | else: 25 | try: 26 | extraction = str(extraction) 27 | except: 28 | extraction = "" 29 | 30 | # extract "A" from "(A) text" 31 | letter = re.findall(r'\(([a-zA-Z])\)', extraction) 32 | if len(letter) > 0: 33 | extraction = letter[0].upper() 34 | 35 | # also try to extract \"A\" from '"A"' 36 | letter = re.search(r'\"[a-zA-Z]\"', extraction) 37 | if letter: 38 | extraction = letter.group() 39 | 40 | options = [chr(ord('A') + i) for i in range(len(choices))] 41 | 42 | if extraction in options: 43 | # convert option letter to text, e.g. 
"A" -> "text" 44 | ind = options.index(extraction) 45 | extraction = choices[ind] 46 | else: 47 | # select the most similar option 48 | extraction = get_most_similar(extraction, choices) 49 | assert extraction in choices 50 | 51 | elif answer_type == 'integer': 52 | try: 53 | extraction = str(int(float(extraction))) 54 | except: 55 | extraction = None 56 | 57 | elif answer_type == 'float': 58 | try: 59 | extraction = str(round(float(extraction), int(precision))) 60 | except: 61 | extraction = None 62 | 63 | elif answer_type == 'list': 64 | try: 65 | extraction = str(extraction) 66 | except: 67 | extraction = None 68 | 69 | return extraction 70 | 71 | 72 | def safe_equal(prediction, answer): 73 | """ 74 | Check if the prediction is equal to the answer, even if they are of different types 75 | """ 76 | try: 77 | if prediction == answer: 78 | return True 79 | return False 80 | except Exception as e: 81 | print(e) 82 | return False 83 | 84 | 85 | def MathVistaEvaluate(pred, label, meta_data): 86 | normalized_answer = normalize_extracted_answer( 87 | pred, meta_data["choices"], meta_data["question_type"], meta_data["answer_type"], meta_data["precision"] 88 | ) 89 | correct = safe_equal(normalized_answer, label) 90 | return correct 91 | -------------------------------------------------------------------------------- /src_evaluation/POPE_evaluation.py: -------------------------------------------------------------------------------- 1 | def POPEEvaluate(text, label, meta_data): 2 | # Only keep the first sentence 3 | if text.find('.') != -1: 4 | text = text.split('.')[0] 5 | 6 | text = text.replace(',', '') 7 | words = text.split(' ') 8 | if 'No' in words or 'not' in words or 'no' in words: 9 | answer = 'no' 10 | else: 11 | answer = 'yes' 12 | 13 | if answer == label: 14 | return True 15 | else: 16 | return False 17 | -------------------------------------------------------------------------------- /src_evaluation/PlotQA_evaluation.py: 
# ---- /src_evaluation/PlotQA_evaluation.py ----
from .VSR_evaluation import VSREvaluate


def PlotQAEvaluate(answer, label, meta_data):
    """Score a PlotQA prediction against its label.

    Yes/No labels are mapped to "True"/"False" and delegated to
    ``VSREvaluate``.  Any other label is compared numerically when both
    sides parse as floats, otherwise as case-insensitive strings.

    :return: True when the prediction matches the label.
    """
    if label in ['Yes', 'No', ]:
        label = {'Yes': 'True', 'No': 'False'}[label]
        return VSREvaluate(answer, label, meta_data)
    try:
        label = float(label)
        answer = float(answer)
        return label == answer
    except (TypeError, ValueError, OverflowError):
        # Non-numeric label or answer: fall back to a string comparison.
        return str(label).lower() == str(answer).lower()


# ---- /src_evaluation/SceMQA_evaluation.py ----
import re


def safe_equal(prediction, answer):
    """Return whether prediction == answer, treating a failing comparison as a mismatch."""
    try:
        if prediction == answer:
            return True
        return False
    except Exception as e:
        print(e)
        return False


def SceMQA_evaluate(answer, label, meta_data):
    """Score a multiple-choice SceMQA response by extracting one choice letter.

    The letter is taken, in order of preference, from "(X)", from a quoted
    "X", or from the first ASCII letter anywhere in the response.
    """
    extraction = None

    letter = re.findall(r'\(([a-zA-Z])\)', answer)
    if len(letter) > 0:
        extraction = letter[0].upper()

    if extraction is None:
        # Capture only the letter: keeping the surrounding quotes made '"B"'
        # unequal to the label 'B', so quoted answers were always wrong.
        letter = re.search(r'"([a-zA-Z])"', answer)
        if letter:
            extraction = letter.group(1).upper()

    if extraction is None:  # we don't have options, we can't match options, so we just extract first letter anyway
        letter = re.search(r'[a-zA-Z]', answer)
        if letter:
            extraction = letter.group()

    return safe_equal(extraction, label)


# ---- /src_evaluation/ScienceQA_evaluation.py ----
import re

from .MathVista_evaluation import get_most_similar


def safe_equal(prediction, answer):
    """Return whether prediction == answer, treating a failing comparison as a mismatch."""
    try:
        if prediction == answer:
            return True
        return False
    except Exception as e:
        print(e)
        return False


def ScienceQA_evaluate(extraction, label, meta_data):
    """Score a ScienceQA multiple-choice response.

    A choice letter is pulled from "(X)" or a quoted "X".  If the result is
    not a valid option letter, the response is mapped to the option whose
    text is most similar according to ``get_most_similar``.

    :param extraction: raw model response.
    :param label: gold option letter ("A", "B", ...).
    :param meta_data: must contain ``choices``, the list of option texts.
    """
    choices = meta_data['choices']

    # extract "A" from "(A) text"
    letter = re.findall(r'\(([a-zA-Z])\)', extraction)
    if len(letter) > 0:
        extraction = letter[0].upper()

    # also try to extract A from '"A"'; capture only the letter so the quotes
    # do not defeat the option lookup below
    letter = re.search(r'"([a-zA-Z])"', extraction)
    if letter:
        extraction = letter.group(1).upper()

    options = [chr(ord('A') + i) for i in range(len(choices))]
    assert label in options

    if extraction not in options:
        # select the most similar option
        choice = get_most_similar(extraction, choices)
        extraction = options[choices.index(choice)]
    assert extraction in options

    return safe_equal(extraction, label)


# ---- /src_evaluation/TallyQA_evaluation.py ----
import string

from word2number import w2n


def safe_equal(prediction, answer):
    """
    Check if the prediction is equal to the answer, even if they are of different types
    """
    try:
        if prediction == answer:
            return True
        return False
    except Exception as e:
        print(e)
        return False


def TallyQAEvaluate(answer, label, meta_data):
    """Score a counting answer against an integer label.

    The response is normalised to an integer string.  It is first parsed as a
    float; failing that, ``w2n.word_to_num`` is applied to the
    punctuation-stripped text.  A response that cannot be parsed scores False.
    """
    try:
        answer = str(int(float(answer)))
    except (TypeError, ValueError, OverflowError):
        try:
            answer = w2n.word_to_num(''.join([a for a in answer if a not in string.punctuation]))
            answer = str(int(answer))
        except Exception:
            answer = None

    label = str(int(float(label)))
    return safe_equal(answer, label)


# ---- /src_evaluation/VQA_evaluation.py ----
import re


def general_postprocessing(prediction):
    """Stringify, flatten whitespace, lowercase, and map "true"/"false" to "yes"/"no"."""
    prediction = str(prediction)

    prediction = prediction.replace('\n', ' ')
    prediction = prediction.replace('\t', ' ')
    prediction = prediction.strip()
    prediction = prediction.lower()

    if prediction == 'true':
        prediction = 'yes'
    elif prediction == 'false':
        prediction = 'no'
    return prediction


# For evaluation (tables copied from the official VQA evaluation script)
contractions = {"aint": "ain't", "arent": "aren't", "cant": "can't", "couldve": "could've",
                "couldnt": "couldn't", "couldn'tve": "couldn't've", "couldnt've": "couldn't've",
                "didnt": "didn't", "doesnt": "doesn't", "dont": "don't", "hadnt": "hadn't",
                "hadnt've": "hadn't've", "hadn'tve": "hadn't've", "hasnt": "hasn't", "havent": "haven't",
                "hed": "he'd", "hed've": "he'd've", "he'dve": "he'd've", "hes": "he's", "howd": "how'd",
                "howll": "how'll", "hows": "how's", "Id've": "I'd've", "I'dve": "I'd've", "Im": "I'm",
                "Ive": "I've", "isnt": "isn't", "itd": "it'd", "itd've": "it'd've", "it'dve": "it'd've",
                "itll": "it'll", "let's": "let's", "maam": "ma'am", "mightnt": "mightn't",
                "mightnt've": "mightn't've", "mightn'tve": "mightn't've", "mightve": "might've",
                "mustnt": "mustn't", "mustve": "must've", "neednt": "needn't", "notve": "not've",
                "oclock": "o'clock", "oughtnt": "oughtn't", "ow's'at": "'ow's'at", "'ows'at": "'ow's'at",
                "'ow'sat": "'ow's'at", "shant": "shan't", "shed've": "she'd've", "she'dve": "she'd've",
                "she's": "she's", "shouldve": "should've", "shouldnt": "shouldn't",
                "shouldnt've": "shouldn't've", "shouldn'tve": "shouldn't've", "somebody'd": "somebodyd",
                "somebodyd've": "somebody'd've", "somebody'dve": "somebody'd've",
                "somebodyll": "somebody'll", "somebodys": "somebody's", "someoned": "someone'd",
                "someoned've": "someone'd've", "someone'dve": "someone'd've", "someonell": "someone'll",
                "someones": "someone's", "somethingd": "something'd", "somethingd've": "something'd've",
                "something'dve": "something'd've", "somethingll": "something'll", "thats": "that's",
                "thered": "there'd", "thered've": "there'd've", "there'dve": "there'd've",
                "therere": "there're", "theres": "there's", "theyd": "they'd", "theyd've": "they'd've",
                "they'dve": "they'd've", "theyll": "they'll", "theyre": "they're", "theyve": "they've",
                "twas": "'twas", "wasnt": "wasn't", "wed've": "we'd've", "we'dve": "we'd've",
                "weve": "we've", "werent": "weren't", "whatll": "what'll", "whatre": "what're",
                "whats": "what's", "whatve": "what've", "whens": "when's", "whered": "where'd",
                "wheres": "where's", "whereve": "where've", "whod": "who'd", "whod've": "who'd've",
                "who'dve": "who'd've", "wholl": "who'll", "whos": "who's", "whove": "who've",
                "whyll": "why'll", "whyre": "why're", "whys": "why's", "wont": "won't",
                "wouldve": "would've", "wouldnt": "wouldn't", "wouldnt've": "wouldn't've",
                "wouldn'tve": "wouldn't've", "yall": "y'all", "yall'll": "y'all'll", "y'allll": "y'all'll",
                "yall'd've": "y'all'd've", "y'alld've": "y'all'd've", "y'all'dve": "y'all'd've",
                "youd": "you'd", "youd've": "you'd've", "you'dve": "you'd've", "youll": "you'll",
                "youre": "you're", "youve": "you've"}

manualMap = {'none': '0',
             'zero': '0',
             'one': '1',
             'two': '2',
             'three': '3',
             'four': '4',
             'five': '5',
             'six': '6',
             'seven': '7',
             'eight': '8',
             'nine': '9',
             'ten': '10'
             }
articles = ['a',
            'an',
            'the'
            ]

# Raw strings: same patterns as before, without the invalid-escape warnings.
# NOTE: `(?!<=\d)` is a lookahead for the literal text "<=" plus a digit.  The
# next character must be "." for the match to continue, so the lookahead never
# fires.  It is kept verbatim from the official script for score parity; the
# intent was probably the lookbehind `(?<!\d)`.
periodStrip = re.compile(r"(?!<=\d)(\.)(?!\d)")
commaStrip = re.compile(r"(\d)(\,)(\d)")
punct = [';', r"/", '[', ']', '"', '{', '}',
         '(', ')', '=', '+', '\\', '_', '-',
         '>', '<', '@', '`', ',', '?', '!']

max_words = 50


def processPunctuation(inText):
    """Remove or space-out punctuation following the official VQA rules."""
    outText = inText
    for p in punct:
        # Delete the mark outright when it touches a space or the text contains
        # a digit,digit comma.  Otherwise replace it with a space.
        if (p + ' ' in inText or ' ' + p in inText) or commaStrip.search(inText) is not None:
            outText = outText.replace(p, '')
        else:
            outText = outText.replace(p, ' ')
    # The official script passes re.UNICODE positionally, where it lands in the
    # `count` slot, so at most 32 periods are removed.  Kept for parity, but
    # passed by keyword because positional `count` is deprecated since 3.13.
    outText = periodStrip.sub("", outText, count=re.UNICODE)
    return outText


def processDigitArticle(inText):
    """Lowercase, map number words to digits, drop articles, and restore contractions."""
    outText = []
    for word in inText.lower().split():
        # .get instead of .setdefault: the latter inserted every unseen word into
        # the shared manualMap, so the table grew on every call.
        word = manualMap.get(word, word)
        if word not in articles:
            outText.append(word)
    for wordId, word in enumerate(outText):
        if word in contractions:
            outText[wordId] = contractions[word]
    return ' '.join(outText)


def post_process(prediction, stem=True):
    """
    Code from https://github.com/GT-Vision-Lab/VQA/blob/master/PythonEvaluationTools/vqaEvaluation/vqaEval.py,
    as indicated here https://okvqa.allenai.org/leaderboard.html

    ``stem`` is accepted for backward compatibility and is unused.
    :return: the normalised answer string.
    """
    # general_postprocessing already flattens newlines/tabs and strips.
    prediction = general_postprocessing(prediction)
    prediction = processPunctuation(prediction)
    prediction = processDigitArticle(prediction)
    return prediction


def VQAEvaluate(answer, label, meta_data):
    """Compare a normalised answer against a list of reference answers.

    For TextVQA, at least three references must match the answer.  For other
    datasets, one matching reference is enough.  Any error during
    normalisation scores False.
    """
    if meta_data["src_dataset"] == "TextVQA":
        try:
            processed_answer = post_process(answer)
            correct_count = 0
            for reference in label:
                if processed_answer == post_process(reference):
                    correct_count += 1
                    if correct_count >= 3:
                        return True
        except Exception:
            return False
    else:
        try:
            processed_answer = post_process(answer)
            for reference in label:
                if processed_answer == post_process(reference):
                    return True
        except Exception:
            return False

    return False


# ---- /src_evaluation/VSR_evaluation.py ----
def VSREvaluate(text, label, meta_data):
    """Map a free-form response to "True"/"False" and compare it with ``label``.

    Only the first sentence is inspected.  Commas are ignored.  The response
    is "False" if that sentence contains any of the words False/false/No/not/no,
    and "True" otherwise.
    """
    # Only keep the first sentence; splitting a period-free text is a no-op.
    text = text.split('.')[0]

    words = text.replace(',', '').split(' ')
    negatives = ('False', 'false', 'No', 'not', 'no')
    answer = 'False' if any(neg in words for neg in negatives) else 'True'
    return answer == label


# ---- /src_evaluation/WeMathEvaluation.py ----
import re


def safe_equal(prediction, answer):
    """
    Check if the prediction is equal to the answer, even if they are of different types
    """
    try:
        if prediction == answer:
            return True
        return False
    except Exception as e:
        print(e)
        return False


def WeMathEvaluate(answer, label, meta_data):
    """Score a multiple-choice WeMath response by extracting one choice letter.

    The letter is taken, in order of preference, from "(X)", from a quoted
    "X", or from the first ASCII letter anywhere in the response.
    """
    extraction = None

    letter = re.findall(r'\(([a-zA-Z])\)', answer)
    if len(letter) > 0:
        extraction = letter[0].upper()

    if extraction is None:
        # Capture only the letter: keeping the surrounding quotes made '"B"'
        # unequal to the label 'B', so quoted answers were always wrong.
        letter = re.search(r'"([a-zA-Z])"', answer)
        if letter:
            extraction = letter.group(1).upper()

    if extraction is None:  # we don't have options, we can't match options, so we just extract first letter anyway
        letter = re.search(r'[a-zA-Z]', answer)
        if letter:
            extraction = letter.group()

    return safe_equal(extraction, label)


# ---- /src_evaluation/evaluate.py ----
import os

from src_evaluation.CLEVR_evaluation import CLEVREvaluate
from src_evaluation.EmbSpatial_evaluation import EmbSpatial_evaluation
from src_evaluation.FigureQA_evaluation import FigureQAEvaluate
from src_evaluation.GQA_evaluation import GQAEvaluate
from src_evaluation.HallusionBench_evaluation import HallusionBenchEvaluate
from src_evaluation.MMMU_evaluation import MMMU_evaluate
from src_evaluation.MMVet_evaluation import MMVetEvaluate
from src_evaluation.MathVision_Evaluation import MathVisionEvaluate
from src_evaluation.MathVista_evaluation import MathVistaEvaluate
from src_evaluation.POPE_evaluation import POPEEvaluate
from src_evaluation.PlotQA_evaluation import PlotQAEvaluate
from src_evaluation.SceMQA_evaluation import SceMQA_evaluate
from src_evaluation.ScienceQA_evaluation import ScienceQA_evaluate
from src_evaluation.TallyQA_evaluation import TallyQAEvaluate
from src_evaluation.VQA_evaluation import VQAEvaluate
from src_evaluation.VSR_evaluation import VSREvaluate
from src_evaluation.WeMathEvaluation import WeMathEvaluate

# Dataset name (meta_data['src_dataset']) -> scoring function.
_EVALUATION_METHODS = {
    "MathVista": MathVistaEvaluate,
    "POPE": POPEEvaluate,
    "HallusionBench": HallusionBenchEvaluate,
    "WeMath": WeMathEvaluate,
    "MathVision": MathVisionEvaluate,
    "MMVet": MMVetEvaluate,
    "MMMU": MMMU_evaluate,
    "ScienceQA": ScienceQA_evaluate,
    "SceMQA": SceMQA_evaluate,
    "EmbSpatial": EmbSpatial_evaluation,
    "TallyQA": TallyQAEvaluate,
    "VSR": VSREvaluate,
    "TextVQA": VQAEvaluate,
    "DocVQA": VQAEvaluate,
    "GQA": GQAEvaluate,
    "CLEVR": CLEVREvaluate,
    "ChartQA": VQAEvaluate,
    "FigureQA": FigureQAEvaluate,
    "PlotQA": PlotQAEvaluate,
}


def get_evaluation_method(dataset):
    """Return the scoring function registered for ``dataset``.

    :raises ValueError: if no scorer is registered under that name.
    """
    try:
        return _EVALUATION_METHODS[dataset]
    except (KeyError, TypeError):
        raise ValueError(f"Dataset not found: {dataset}") from None


def evaluate(pred, label, meta_data):
    """Score ``pred`` against ``label`` with the scorer for ``meta_data['src_dataset']``."""
    evaluate_func = get_evaluation_method(meta_data['src_dataset'])
    return evaluate_func(pred, label, meta_data)


# ---- /utils.py ----
import base64
import multiprocessing
import socket
import subprocess
import time
from io import BytesIO

import anthropic
import google.generativeai as genai
import torch
import tqdm
from PIL import Image
from openai import OpenAI

# Parsed command-line options; assigned by `infer` (and by `_init_worker` in pool workers).
args = None


def gemini_infer(image, query):
    """Query a Gemini model with one base64-encoded image and a text prompt.

    Makes up to two attempts, one second apart.  If both fail, returns the
    placeholder "TODO" so one bad request does not abort a batch.
    """
    def attempt():
        model = genai.GenerativeModel(model_name=args.model)
        image_ = Image.open(BytesIO(base64.b64decode(image)))
        response = model.generate_content([query, image_])
        return response.text

    try:
        return attempt()
    except Exception:
        time.sleep(1)
    try:
        return attempt()
    except Exception:
        print("Warning! gemini infer does not work")
        return "TODO"


def claude_func(image, query):
    """Query an Anthropic Claude model with one base64 image and a text prompt.

    The image is re-encoded as RGB JPEG before sending.  A failed request is
    retried once after a 60 s back-off.  If the retry also fails, "TODO" is
    returned (matching `gemini_infer` and `func`) instead of raising and
    aborting the whole `infer` run.
    """
    client = anthropic.Anthropic(api_key=args.api_key)

    image_data = base64.b64decode(image)
    with BytesIO(image_data) as img_buffer:
        img = Image.open(img_buffer).convert("RGB")
    with BytesIO() as output_buffer:
        img.save(output_buffer, format='JPEG')
        image_str = base64.b64encode(output_buffer.getvalue()).decode('utf8')

    messages = [
        {
            "role": "user",
            "content": [
                {
                    "type": "image",
                    "source": {
                        "type": "base64",
                        "media_type": "image/jpeg",
                        "data": image_str,
                    },
                },
                {
                    "type": "text",
                    "text": query,
                }
            ],
        }
    ]

    def create():
        return client.messages.create(
            model=args.model,
            max_tokens=512,
            messages=messages,
        )

    try:
        completion = create()
    except Exception as e:
        print("Error")
        print(e)
        time.sleep(60)
        try:
            completion = create()
        except Exception as retry_error:
            print("Warning! claude infer does not work")
            print(retry_error)
            return "TODO"

    return completion.content[0].text


def func(obj):
    """Run one inference request; this is the per-item worker used by `infer`.

    :param obj: ``(image, query)`` for a single turn, or
        ``(image, query, response2, query2)`` to append a prior assistant
        reply and a follow-up user turn (OpenAI-compatible path only).
    :return: the model's text reply, or "TODO" if the request failed twice.
    """
    if len(obj) == 2:
        image, query = obj
        response2 = query2 = None
    else:
        assert len(obj) == 4
        image, query, response2, query2 = obj

    # NOTE(review): the Gemini and Claude paths ignore response2/query2, so a
    # 4-tuple is answered as a single-turn request there. Confirm this is intended.
    if args.model.startswith("gemini"):
        return gemini_infer(image, query)
    elif args.model.startswith("claude"):
        return claude_func(image, query)

    client = OpenAI(api_key=args.api_key, base_url=args.base_url)
    messages = [
        {
            "role": "user",
            "content": [
                {"type": "text", "text": query},
            ],
        },
    ]
    if image is not None:
        messages[0]['content'].append({
            "type": "image_url",
            "image_url": {
                "url": f"data:image/jpeg;base64,{image}",
            },
        })

    if response2 is not None:
        assert query2 is not None
        messages += [{
            "role": "assistant",
            "content": [
                {"type": "text", "text": response2},
            ],
        }, {
            "role": "user",
            "content": [
                {"type": "text", "text": query2},
            ],
        }]
    else:
        assert query2 is None

    if args.model == 'auto':
        # Use the first model the server advertises.
        model = client.models.list().data[0].id
    else:
        model = args.model

    try:
        completion = client.chat.completions.create(
            model=model,
            messages=messages,
            temperature=0.7,
        )
    except Exception:
        time.sleep(1)
        try:
            completion = client.chat.completions.create(
                model=model,
                messages=messages,
                temperature=0.7,
            )
        except Exception as e:
            print("Warning! infer does not work")
            print("Error:")
            print(e)
            return "TODO"

    return completion.choices[0].message.content


def _init_worker(worker_args):
    """Pool initializer: install ``args`` (and Gemini credentials) in a worker.

    Needed when workers are spawned rather than forked.  A spawned worker
    re-imports this module, where ``args`` is None.
    """
    global args
    args = worker_args
    if args.model.startswith("gemini"):
        genai.configure(api_key=args.api_key)


def get_pool(n_proc):
    """Return a context-managed pool that exposes ``imap``.

    ``n_proc == 0`` gives an in-process stand-in whose ``imap`` is the builtin
    ``map``.  Otherwise returns a ``multiprocessing.Pool`` with ``n_proc``
    workers, each initialised with the current ``args`` once it has been set.
    """
    class DummyPool:
        imap = map

        def __enter__(self):
            return self

        def __exit__(self, type, value, traceback):
            pass

    if n_proc == 0:
        return DummyPool()
    if args is None:
        return multiprocessing.Pool(n_proc)
    return multiprocessing.Pool(n_proc, initializer=_init_worker, initargs=(args,))


def infer(queries, images, given_args):
    """Run `func` over ``zip(images, queries)`` and return the replies in order.

    Parallelism is controlled by ``given_args.n_proc`` (0 means in-process).
    The first five replies are printed as examples.
    """
    global args
    args = given_args

    if args.model.startswith("gemini"):
        genai.configure(api_key=args.api_key)

    responses = []
    assert len(images) == len(queries)
    with get_pool(args.n_proc) as p:
        for response in tqdm.tqdm(p.imap(func, zip(images, queries)), total=len(images)):
            responses.append(response)
            if len(responses) <= 5:
                print("\n--- Example output:", len(responses))
                print(responses[-1])

    return responses


def find_available_port():
    """Ask the OS for a free TCP port.

    The probe socket is closed before returning, so another process could
    claim the port before the server binds it.
    """
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
        s.bind(('0.0.0.0', 0))
        return s.getsockname()[1]


def launch_locally(backend, model):
    """Start a local OpenAI-compatible server for ``model`` and wait until it is ready.

    :param backend: 'lmdeploy', 'vllm' or 'sglang'.
    :return: ``(process, port)``.  ``process`` is the server's ``Popen`` handle.
    :raises RuntimeError: if the server process exits before it answers the
        models endpoint.
    """
    port = find_available_port()

    if backend == 'lmdeploy':
        cmd = ['lmdeploy', 'serve', 'api_server', model, '--server-port', str(port),
               '--tp', str(torch.cuda.device_count()), ]
        if 'prometheus' in model:
            cmd += ['--chat-template', 'llava-v1', ]
    elif backend == 'vllm':
        cmd = ['vllm', 'serve', model, '--port', str(port), '--dtype', 'auto', '--api-key', 'YOUR_API_KEY',
               '--trust-remote-code', '--tensor-parallel-size', str(torch.cuda.device_count()), ]
        if '3.2' in model or 'nvlm' in model.lower():
            cmd += ['--enforce-eager', '--max-num-seqs', '32', '--max_model_len', '40000', ]
    else:
        assert backend == 'sglang'
        cmd = ['python', '-m', 'sglang.launch_server', '--model-path', model, '--port', str(port),
               '--chat-template=chatml-llava', ]
        if '7b' in model.lower():
            tp = 1
            if 'llava-critic' in model:
                cmd += ['--tokenizer-path', 'lmms-lab/llava-onevision-qwen2-7b-ov', ]
        elif '11b' in model.lower() or '13b' in model.lower():
            tp = 2
        else:
            assert '72b' in model.lower()
            tp = 4
            if 'llava-critic' in model:
                cmd += ['--tokenizer-path', 'lmms-lab/llava-onevision-qwen2-72b-ov-sft', ]
        # Tensor parallelism of size tp, replicated data-parallel over the rest.
        assert torch.cuda.device_count() % tp == 0
        dp = torch.cuda.device_count() // tp
        cmd += ['--tp', str(tp), '--dp', str(dp), ]

    process = subprocess.Popen(cmd, stdout=subprocess.DEVNULL)

    while True:
        try:
            _ = OpenAI(api_key='YOUR_API_KEY', base_url=f'http://0.0.0.0:{port}/v1').models.list()
            print("> launched. Proceed")
            return process, port
        except Exception:
            # Not ready yet.  If the server has already died, fail fast instead
            # of polling forever.
            if process.poll() is not None:
                raise RuntimeError(
                    f"inference server exited with code {process.returncode} "
                    f"before becoming ready: {' '.join(cmd)}"
                ) from None
        time.sleep(5)


def get_answer_format(item):
    """Return the prompt phrase describing the expected answer format for ``item``.

    :raises ValueError: for an unknown ``item['meta_data']['src_dataset']``.
    """
    DEFAULT_ANSWER = "a single number, word or phrase"
    BINARY_ANSWER = "either \"Yes\" or \"No\""
    MULTI_CHOICE_ANSWER = 'in letter form of the choice selected, e.g., "A", "B", "C", "D"'
    INTEGER_ANSWER = "an integer number, e.g. 1, 2, 3"
    dataset = item['meta_data']['src_dataset']

    if dataset in ["VSR", "FigureQA", "POPE", "HallusionBench", ]:
        # a few datasets with only yes or no
        return BINARY_ANSWER

    elif dataset in ["WeMath", "ScienceQA", "SceMQA", "EmbSpatial", ]:
        # a few datasets completely multi-choice
        return MULTI_CHOICE_ANSWER

    elif dataset == "PlotQA":
        if item['label'] in ['Yes', 'No', ]:
            return BINARY_ANSWER
        else:
            return DEFAULT_ANSWER

    elif dataset == "MathVista":
        if item['meta_data']["question_type"] == "multi_choice":
            return MULTI_CHOICE_ANSWER
        elif item['meta_data']["answer_type"] == "integer":
            return INTEGER_ANSWER
        elif item['meta_data']["answer_type"] == "float":
            return "a decimal number, e.g., 1.23, 1.34, 1.45"
        else:
            assert item['meta_data']["answer_type"] == "list"
            return "a list, e.g., [1, 2, 3], [1.2, 1.3, 1.4]"

    elif dataset == "MMVet":
        return DEFAULT_ANSWER

    elif dataset == "MathVision":
        return "an integer number, e.g. -5, a decimal number, e.g. 3.5, or a coordinate, e.g. (1, 2)"

    elif dataset == "MMMU":
        if item["meta_data"]["question_type"] == "multiple-choice":
            return MULTI_CHOICE_ANSWER
        else:
            return DEFAULT_ANSWER

    elif dataset == "TallyQA":
        return INTEGER_ANSWER

    elif dataset in ["TextVQA", "DocVQA", ]:
        return "a word, a phrase, or a short concise sentence"

    elif dataset == "GQA":
        return DEFAULT_ANSWER

    elif dataset == "CLEVR":
        try:
            _ = int(item['label'])
            return INTEGER_ANSWER
        except (TypeError, ValueError):
            if item['label'] in ['yes', 'no', ]:
                return BINARY_ANSWER
            else:
                return DEFAULT_ANSWER

    elif dataset == "ChartQA":
        return DEFAULT_ANSWER

    else:
        raise ValueError(f"Dataset not found: {dataset}")