├── .idea
│   └── .gitignore
├── LICENSE
├── README.md
├── __init__.py
├── assets
│   ├── SampleCases.png
│   ├── image.png
│   ├── model_result.png
│   └── statistic.png
├── evaluation
│   └── metrics_eval.py
├── model
│   ├── cambrian1.py
│   ├── eaglex5.py
│   ├── gemini_infer.py
│   ├── gpt4o_infer.py
│   ├── internvl.py
│   ├── llava_next.py
│   ├── llava_ov.py
│   └── qwen.py
├── model_inference.sh
├── requirements.txt
└── utils
    ├── __init__.py
    ├── infer_utils.py
    └── path_utils.py
/.idea/.gitignore:
--------------------------------------------------------------------------------
1 | # Default ignored files
2 | /shelf/
3 | /workspace.xml
4 | # Editor-based HTTP Client requests
5 | /httpRequests/
6 | # Datasource local storage ignored files
7 | /dataSources/
8 | /dataSources.local.xml
9 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 |
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner. For the purposes of this definition, "submitted"
54 | means any form of electronic, verbal, or written communication sent
55 | to the Licensor or its representatives, including but not limited to
56 | communication on electronic mailing lists, source code control systems,
57 | and issue tracking systems that are managed by, or on behalf of, the
58 | Licensor for the purpose of discussing and improving the Work, but
59 | excluding communication that is conspicuously marked or otherwise
60 | designated in writing by the copyright owner as "Not a Contribution."
61 |
62 | "Contributor" shall mean Licensor and any individual or Legal Entity
63 | on behalf of whom a Contribution has been received by Licensor and
64 | subsequently incorporated within the Work.
65 |
66 | 2. Grant of Copyright License. Subject to the terms and conditions of
67 | this License, each Contributor hereby grants to You a perpetual,
68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 | copyright license to reproduce, prepare Derivative Works of,
70 | publicly display, publicly perform, sublicense, and distribute the
71 | Work and such Derivative Works in Source or Object form.
72 |
73 | 3. Grant of Patent License. Subject to the terms and conditions of
74 | this License, each Contributor hereby grants to You a perpetual,
75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 | (except as stated in this section) patent license to make, have made,
77 | use, offer to sell, sell, import, and otherwise transfer the Work,
78 | where such license applies only to those patent claims licensable
79 | by such Contributor that are necessarily infringed by their
80 | Contribution(s) alone or by combination of their Contribution(s)
81 | with the Work to which such Contribution(s) was submitted. If You
82 | institute patent litigation against any entity (including a
83 | cross-claim or counterclaim in a lawsuit) alleging that the Work
84 | or a Contribution incorporated within the Work constitutes direct
85 | or contributory patent infringement, then any patent licenses
86 | granted to You under this License for that Work shall terminate
87 | as of the date such litigation is filed.
88 |
89 | 4. Redistribution. You may reproduce and distribute copies of the
90 | Work or Derivative Works thereof in any medium, with or without
91 | modifications, and in Source or Object form, provided that You
92 | meet the following conditions:
93 |
94 | (a) You must give any other recipients of the Work or
95 | Derivative Works a copy of this License; and
96 |
97 | (b) You must cause any modified files to carry prominent notices
98 | stating that You changed the files; and
99 |
100 | (c) You must retain, in the Source form of any Derivative Works
101 | that You distribute, all copyright, patent, trademark, and
102 | attribution notices from the Source form of the Work,
103 | excluding those notices that do not pertain to any part of
104 | the Derivative Works; and
105 |
106 | (d) If the Work includes a "NOTICE" text file as part of its
107 | distribution, then any Derivative Works that You distribute must
108 | include a readable copy of the attribution notices contained
109 | within such NOTICE file, excluding those notices that do not
110 | pertain to any part of the Derivative Works, in at least one
111 | of the following places: within a NOTICE text file distributed
112 | as part of the Derivative Works; within the Source form or
113 | documentation, if provided along with the Derivative Works; or,
114 | within a display generated by the Derivative Works, if and
115 | wherever such third-party notices normally appear. The contents
116 | of the NOTICE file are for informational purposes only and
117 | do not modify the License. You may add Your own attribution
118 | notices within Derivative Works that You distribute, alongside
119 | or as an addendum to the NOTICE text from the Work, provided
120 | that such additional attribution notices cannot be construed
121 | as modifying the License.
122 |
123 | You may add Your own copyright statement to Your modifications and
124 | may provide additional or different license terms and conditions
125 | for use, reproduction, or distribution of Your modifications, or
126 | for any such Derivative Works as a whole, provided Your use,
127 | reproduction, and distribution of the Work otherwise complies with
128 | the conditions stated in this License.
129 |
130 | 5. Submission of Contributions. Unless You explicitly state otherwise,
131 | any Contribution intentionally submitted for inclusion in the Work
132 | by You to the Licensor shall be under the terms and conditions of
133 | this License, without any additional terms or conditions.
134 | Notwithstanding the above, nothing herein shall supersede or modify
135 | the terms of any separate license agreement you may have executed
136 | with Licensor regarding such Contributions.
137 |
138 | 6. Trademarks. This License does not grant permission to use the trade
139 | names, trademarks, service marks, or product names of the Licensor,
140 | except as required for reasonable and customary use in describing the
141 | origin of the Work and reproducing the content of the NOTICE file.
142 |
143 | 7. Disclaimer of Warranty. Unless required by applicable law or
144 | agreed to in writing, Licensor provides the Work (and each
145 | Contributor provides its Contributions) on an "AS IS" BASIS,
146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 | implied, including, without limitation, any warranties or conditions
148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 | PARTICULAR PURPOSE. You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
178 | APPENDIX: How to apply the Apache License to your work.
179 |
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "[]"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 |
189 | Copyright [yyyy] [name of copyright owner]
190 |
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 |
195 | http://www.apache.org/licenses/LICENSE-2.0
196 |
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
202 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # 🎨 ColorBench
2 |
3 | [**📖 Paper**](https://arxiv.org/abs/2504.10514) | [**🤗 Dataset**](https://huggingface.co/datasets/umd-zhou-lab/ColorBench)
4 |
5 |
6 |
7 |
8 |
9 | This repo contains the official evaluation code and dataset for the paper ["ColorBench: Can VLMs See and Understand the Colorful World? A Comprehensive Benchmark for Color Perception, Reasoning, and Robustness"](https://arxiv.org/abs/2504.10514).
10 | Chinese Version: [[知乎]](https://zhuanlan.zhihu.com/p/1895794713593885012)
11 |
12 |
13 | ## Highlights
14 | - 🔥 **More than 5,800 image-text questions** covering diverse application scenarios and practical challenges for VLM evaluation.
15 | - 🔥 **3 categories and 11 tasks** for evaluating various color-centric capabilities, including **Perception**, **Reasoning**, and **Robustness**.
16 |
17 | ## Findings
18 | - 🔎 **The scaling law for the language and vision parts:** The scaling law still holds for color understanding but is much weaker and depends mainly on the language model part.
19 | - 🔎 **Absolute performance gap between different sizes of models:** The absolute performances of different VLMs are relatively low, and the gaps between different models (open-source vs. proprietary, small vs. large) are not large.
20 | - 🔎 **Introducing reasoning steps:** Adding reasoning steps can still improve VLMs' performance on ColorBench tasks, even for color robustness.
21 | - 🔎 **The effect of color clues:** Color clues are indeed leveraged by VLMs in most of the tasks. However, in the color illusion and mimicry tasks, colors might mislead VLMs into giving wrong answers, and converting colorful images to grayscale can improve accuracy.
22 |
23 |
24 | ## Dataset Introduction
25 |
26 |
27 |
28 |
29 | ColorBench is the first benchmark explicitly designed to comprehensively evaluate the color understanding capabilities of VLMs across three key dimensions: **Color Perception**, **Color Reasoning**, and **Color Robustness**.
30 | This benchmark consists of 1,448 instances and more than 5,800 image-text questions spanning 11 diverse tasks (Color Recognition, Color Extraction, Object Recognition, Color Proportion, Color Comparison, Color Counting, Object Counting, Color Illusion, Color Mimicry, Color Blindness, and Color Robustness).
31 | For the Color Perception and Color Reasoning categories, each instance contains an image, a question, and multiple-choice (3 to 6) options, with only one correct answer.
32 | For Color Robustness, each instance consists of 10 multiple-choice image-text questions including a seed image and 9 edited images with color changes.
33 |
34 |
35 |
36 |
37 |
38 | ## VLMs' Result
39 | We conduct an extensive evaluation of 32 vision-language models (VLMs) spanning a range of large language model (LLM) sizes and architectures. Our evaluation includes state-of-the-art models such as GPT-4o, Gemini-2-flash, LLaVA-OV, LLaVA-NEXT, Cambrian-1, InternVL2, Qwen2.5-VL, and Eagle. This selection covers a diverse set of architectures, including both proprietary and open-source models, enabling a comprehensive assessment of their reasoning capabilities under different computational constraints.
40 |
41 |
42 |
43 |
44 | ## Evaluation Pipeline
45 | We provide detailed instructions for evaluation as follows.
46 |
47 | ### Environment
48 | Install the packages necessary for running the VLMs.
49 | ```bash
50 | conda create -n colorbench python=3.11
51 | conda activate colorbench
52 |
53 | pip3 install -r requirements.txt
54 | pip install flash-attn==2.7.3 --no-build-isolation
55 | ```
56 |
57 | ### View Dataset
58 | We release ColorBench on Hugging Face; it contains more than 5,800 image-text pairs. You can download and view the dataset with the following commands:
59 |
60 | ```python
61 | from datasets import load_dataset
62 |
63 | dataset = load_dataset("umd-zhou-lab/ColorBench", "test")
64 |
65 | # Evaluation samples
66 | print(dataset["test"][0])
67 | ```
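
Each sample is a dictionary. For reference, the fields consumed by the inference and evaluation scripts in this repo (see [model](model/) and [metrics_eval.py](evaluation/metrics_eval.py)) can be inspected as follows (a minimal sketch):

```python
sample = dataset["test"][0]

# Question text, answer options, and ground-truth option letter
print(sample["prompt"], sample["choices"], sample["answer"])

# Category ("Perception", "Reasoning", or "Robustness"), task name, and source file
print(sample["type"], sample["task"], sample["filename"])

# The image itself is loaded as a PIL image
print(sample["image"].size)
```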
68 |
69 | ### Inference with Models
70 | Inference code for several VLMs is provided in [model](model/). You can use the script [model_inference.sh](model_inference.sh) to run inference on our benchmark.
71 |
72 | ```bash
73 | bash model_inference.sh
74 | ```
75 | Before running the script, modify the necessary folder paths and API keys in [model_inference.sh](model_inference.sh):
76 |
77 | ```bash
78 | ROOT_DIR="PATH/TO/ROOT_DIR" # Needed only if using json for model inference
79 | RESULT_DIR="PATH/TO/RESULT_DIR" # Path to save the model inference results
80 | GEMINI_API_KEY="YOUR_API_KEY"
81 | GPT4O_API_KEY="YOUR_API_KEY"
82 | ```
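
Each model script can also be invoked directly. The flags below follow the argument parsers defined in [model](model/); treat this as a sketch rather than the exact commands used in model_inference.sh:

```bash
# Proprietary model via API (add --use_cot for chain-of-thought prompting)
python3 model/gpt4o_infer.py --save_dir=$RESULT_DIR --datatype=dataset --api_key=$GPT4O_API_KEY

# Open-source model downloaded from Hugging Face
python3 model/cambrian1.py --save_dir=$RESULT_DIR --modeltype=cambrian_8b --datatype=dataset
```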
83 |
84 | The cache folder for models and the dataset can be modified in [path_utils.py](utils/path_utils.py). If not changed, the cache folder defaults to the home directory:
85 |
86 | ```python
87 | CACHE_DIR = "YOUR_HF_CACHE_FOLDER"
88 | ```
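
Each inference script loads this value via `set_root_folder()` and exports it to the Hugging Face cache environment variables before any model is downloaded, as in this excerpt mirroring the top of the scripts in [model](model/):

```python
import os
from utils.path_utils import set_root_folder

# Route all Hugging Face downloads (models and datasets) to CACHE_DIR
CACHE_DIR = set_root_folder()
os.environ["HF_HOME"] = CACHE_DIR
os.environ["HF_DATASETS_CACHE"] = CACHE_DIR
os.environ["TRANSFORMERS_CACHE"] = CACHE_DIR
```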
89 |
90 | ### Evaluation Results
91 | We provide the script [metrics_eval.py](evaluation/metrics_eval.py) to compute evaluation metrics from the inference results. You can run the following command to get the final result:
92 |
93 | ```bash
94 | python3 evaluation/metrics_eval.py --result_dir=RESULT_DIR --save_dir=SAVE_DIR
95 | ```
96 | The final result will be saved in the folder `SAVE_DIR`.
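
The saved `inference.csv` contains one row per model, with per-task accuracies plus aggregate columns. A minimal sketch for inspecting it (column names follow [metrics_eval.py](evaluation/metrics_eval.py)):

```python
import pandas as pd

df = pd.read_csv("SAVE_DIR/inference.csv")  # replace SAVE_DIR with your folder
# Aggregate metrics written by evaluation/metrics_eval.py
print(df[["model_type", "Perception Acc", "Reasoning Acc", "Overall Acc", "Color Robustness"]])
```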
97 |
98 | ## Citation
99 |
100 | ```bibtex
101 | @misc{liang2025colorbenchvlmsunderstandcolorful,
102 | title={ColorBench: Can VLMs See and Understand the Colorful World? A Comprehensive Benchmark for Color Perception, Reasoning, and Robustness},
103 | author={Yijun Liang and Ming Li and Chenrui Fan and Ziyue Li and Dang Nguyen and Kwesi Cobbina and Shweta Bhardwaj and Jiuhai Chen and Fuxiao Liu and Tianyi Zhou},
104 | year={2025},
105 | eprint={2504.10514},
106 | archivePrefix={arXiv},
107 | primaryClass={cs.CV},
108 | url={https://arxiv.org/abs/2504.10514},
109 | }
110 | ```
111 |
--------------------------------------------------------------------------------
/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tianyi-lab/ColorBench/809101f19b4acea4a8a986c6c12d644dfdb67b16/__init__.py
--------------------------------------------------------------------------------
/assets/SampleCases.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tianyi-lab/ColorBench/809101f19b4acea4a8a986c6c12d644dfdb67b16/assets/SampleCases.png
--------------------------------------------------------------------------------
/assets/image.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tianyi-lab/ColorBench/809101f19b4acea4a8a986c6c12d644dfdb67b16/assets/image.png
--------------------------------------------------------------------------------
/assets/model_result.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tianyi-lab/ColorBench/809101f19b4acea4a8a986c6c12d644dfdb67b16/assets/model_result.png
--------------------------------------------------------------------------------
/assets/statistic.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tianyi-lab/ColorBench/809101f19b4acea4a8a986c6c12d644dfdb67b16/assets/statistic.png
--------------------------------------------------------------------------------
/evaluation/metrics_eval.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | sys.path.insert(0, os.getcwd())
4 |
5 | from utils.infer_utils import *
6 | import json
7 | import pickle
8 | import numpy as np
9 | import pandas as pd
10 | from typing import List, Tuple, Dict, Optional
11 | from argparse import ArgumentParser
12 |
13 |
14 | def process_PandR(list_models: List, result_dir: str, list_tasks_PR: List, ):
15 |
16 | dict_cate_task = dict()
17 | dict_model_cnt = dict()
18 | list_acc = []
19 | for model_name in list_models:
20 | dict_model_cnt[model_name] = dict() # save acc for each task
21 | json_path = os.path.join(result_dir, f"{model_name}.json")
22 | with open(json_path, 'r') as f:
23 | dict_res = json.load(f)
24 |
25 | for sample_idx, img_meta in dict_res.items():
26 | category = img_meta['type']
27 | if category == 'Robustness':
28 | continue
29 |
30 | task = img_meta['task']
31 | if task not in dict_cate_task:
32 | dict_cate_task[task] = category
33 |
34 | # stores correct / incorrect cnt for each task
35 | if task not in dict_model_cnt[model_name]:
36 | dict_model_cnt[model_name][task] = [0, 0] # [correct, incorrect]
37 |
38 | options = img_meta['choices']
39 | gt_ans = img_meta['answer'].replace('(', '').replace(')', '').lower()
40 | model_ans = img_meta['model_ans']
41 |
42 | if 'cot' in model_name: # for gpt / gemini cot
43 | model_ans_new = extract_letter_cot(model_ans)
44 | if_correct = (model_ans_new == gt_ans)
45 | elif 'gpt' in model_name or 'gemini' in model_name: # for gpt / gemini
46 | model_ans_new = extract_letter(model_ans)
47 | if_correct = (model_ans_new == gt_ans)
48 | else: # for open sourced model
49 | model_ans_new, if_correct, find_res = parse_res(model_ans=model_ans, options=options, gt_ans=gt_ans)
50 |
51 | if if_correct:
52 | # correct
53 | dict_model_cnt[model_name][task][0] += 1
54 | else:
55 | # incorrect
56 | dict_model_cnt[model_name][task][1] += 1
57 |
58 | # print model acc:
59 | dict_acc = {key: [sum(item), item[0], item[0]/sum(item)] for key, item in dict_model_cnt[model_name].items() if sum(item) > 0}
60 | for task, (sum_cnt, cor_cnt, acc) in dict_acc.items():
61 | list_acc.append([model_name, task, sum_cnt, cor_cnt, np.round(acc, 6)])
62 |
63 | dict_formated = {key: [0]*len(list_tasks_PR) for key in list_models}
64 | dict_formated_cor = {key: [0]*len(list_tasks_PR) for key in list_models}
65 | dict_formated_sum = {key: [0]*len(list_tasks_PR) for key in list_models}
66 |
67 | for item_meta in list_acc:
68 | model_name, task, sum_cnt, cor_cnt, acc = item_meta
69 | t_idx = list_tasks_PR.index(task)
70 | dict_formated[model_name][t_idx] = acc
71 | dict_formated_cor[model_name][t_idx] = cor_cnt
72 | dict_formated_sum[model_name][t_idx] = sum_cnt
73 |
74 | # calculate perception / reasoning / overall acc
75 | for model_name in dict_formated_sum.keys():
76 | percept_cnt = [0, 0]
77 | reasoning_cnt = [0, 0]
78 | overall_cnt = [0, 0]
79 |
80 | for task_id, task in enumerate(list_tasks_PR):
81 | category = dict_cate_task[task]
82 | cor_cnt = dict_formated_cor[model_name][task_id]
83 | all_cnt = dict_formated_sum[model_name][task_id]
84 | overall_cnt[0] += cor_cnt
85 | overall_cnt[1] += all_cnt
86 | if category.lower() == 'perception':
87 | percept_cnt[0] += cor_cnt
88 | percept_cnt[1] += all_cnt
89 | if category.lower() == 'reasoning':
90 | reasoning_cnt[0] += cor_cnt
91 | reasoning_cnt[1] += all_cnt
92 |
93 | try:
94 | perception_acc = np.round(percept_cnt[0]/percept_cnt[1], 6)
95 | reasoning_acc = np.round(reasoning_cnt[0]/reasoning_cnt[1], 6)
96 | overall_acc = np.round(overall_cnt[0]/overall_cnt[1], 6)
97 |         except ZeroDivisionError:
98 | perception_acc, reasoning_acc, overall_acc = 0, 0, 0
99 | dict_formated[model_name].extend([perception_acc, reasoning_acc, overall_acc])
100 | return dict_formated
101 |
102 |
103 | def process_robustness(list_models: List, result_dir: str, dict_formated: Dict, ):
104 | # Count robustness
105 | dict_id_newres = dict()
106 | dict_model_cnt = dict()
107 | list_rob = []
108 | for model_name in list_models:
109 | dict_model_cnt[model_name] = [0, 0] # save cnt for each model
110 | dict_id_newres[model_name] = dict() # save cnt for each model
111 | dict_correct = dict()
112 | json_path = os.path.join(result_dir, f"{model_name}.json")
113 | with open(json_path, 'r') as f:
114 | dict_res = json.load(f)
115 |
116 | for sample_idx, img_meta in dict_res.items():
117 | category = img_meta['type']
118 | if category != 'Robustness':
119 | continue
120 |
121 | if_ori = False
122 | img_name = img_meta["filename"].split('/')[-1].split('.')[0]
123 | if '_' not in img_name:
124 | # original image
125 | if_ori = True
126 | img_id = int(img_name)
127 | else:
128 | # recolored image
129 | img_id = int(img_name.split('_')[0])
130 |
131 | options = img_meta['choices']
132 | gt_ans = img_meta['answer'].replace('(', '').replace(')', '').lower()
133 | model_ans = img_meta['model_ans']
134 |
135 | if img_id not in dict_id_newres[model_name]:
136 | dict_id_newres[model_name][img_id] = [gt_ans, None, []] # [gt answer, result for original image, list of results for recolored image]
137 | dict_correct[img_id] = [False, []] # [correct / not for original image, list of bool]
138 |
139 | if 'cot' in model_name: # for gpt / gemini cot
140 | model_ans_new = extract_letter_cot(model_ans)
141 | if_correct = (model_ans_new == gt_ans)
142 | elif 'gpt' in model_name or 'gemini' in model_name: # for gpt / gemini
143 | model_ans_new = extract_letter(model_ans)
144 | if_correct = (model_ans_new == gt_ans)
145 | else: # for open sourced model
146 | model_ans_new, if_correct, find_res = parse_res(model_ans=model_ans, options=options, gt_ans=gt_ans)
147 |
148 | if if_ori:
149 | dict_id_newres[model_name][img_id][1] = model_ans_new
150 | dict_correct[img_id][0] = if_correct
151 | else:
152 | dict_id_newres[model_name][img_id][2].append(model_ans_new)
153 | dict_correct[img_id][1].append(if_correct)
154 |
155 | # cnt robust answers
156 | for img_id, list_res in dict_id_newres[model_name].items():
157 | gt_ans, ori_ans, list_new_ans = list_res
158 | ori_bool, list_new_bool = dict_correct[img_id]
159 | if ori_bool and False not in list_new_bool:
160 | # robust
161 | dict_model_cnt[model_name][0] += 1
162 | else:
163 | # not
164 | dict_model_cnt[model_name][1] += 1
165 |
166 | # print model robust:
167 | dict_robust = {key: [sum(item), item[0], item[0]/sum(item)] for key, item in dict_model_cnt.items() if sum(item) > 0}
168 | for model_name, (sum_cnt, rob_cnt, robustness) in dict_robust.items():
169 | list_rob.append([model_name, sum_cnt, rob_cnt, np.round(robustness, 6)])
170 |
171 | for item_meta in list_rob:
172 | model_name, sum_cnt, rob_cnt, robustness = item_meta
173 | dict_formated[model_name].append(robustness)
174 |
175 | return dict_formated
176 |
177 |
178 | if __name__ == '__main__':
179 | parser = ArgumentParser()
180 | parser.add_argument("--result_dir", type=str, default="RESULT_DIR")
181 | parser.add_argument("--save_dir", type=str, default="SAVE_DIR")
182 | args = parser.parse_args()
183 |
184 | result_dir = args.result_dir
185 | save_dir = args.save_dir
186 |
187 | list_tasks_PR = ["Color Recognition", "Color Extraction", "Object Recognition", "Color Proportion", "Color Comparison", "Color Counting", "Object Counting", "Color Illusion", "Color Mimicry", "Color Blindness",]
188 |
189 | # Load model inference results
190 | list_jsons = os.listdir(result_dir)
191 | list_models = [item.split('.')[0] for item in list_jsons if 'json' in item]
192 |
193 | # Count acc for each task
194 | dict_formated = process_PandR(list_models=list_models, result_dir=result_dir, list_tasks_PR=list_tasks_PR)
195 | dict_formated = process_robustness(list_models=list_models, result_dir=result_dir, dict_formated=dict_formated, )
196 |
197 | # Save to csv
198 | df_result = pd.DataFrame(dict_formated,).T
199 | df_result.columns = list_tasks_PR + ['Perception Acc', 'Reasoning Acc', 'Overall Acc'] + ['Color Robustness']
200 | df_result = df_result.reset_index().rename(columns={'index': 'model_type',})
201 | df_result.to_csv(os.path.join(save_dir, 'inference.csv'), index=False)
202 |
--------------------------------------------------------------------------------
/model/cambrian1.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import os
3 |
4 | sys.path.insert(0, os.getcwd())
5 | from utils.path_utils import *
6 |
7 | CACHE_DIR = set_root_folder()
8 |
9 | os.environ["HF_HOME"] = CACHE_DIR
10 | os.environ["HF_DATASETS_CACHE"] = CACHE_DIR
11 | os.environ["HF_MODULES_CACHE"] = CACHE_DIR
12 | os.environ["TRANSFORMERS_CACHE"] = CACHE_DIR
13 |
14 | from utils.infer_utils import *
15 | from cambrian.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
16 | from cambrian.conversation import conv_templates
17 | from cambrian.model.builder import load_pretrained_model
18 | from cambrian.mm_utils import tokenizer_image_token, process_images, get_model_name_from_path
19 |
20 | import traceback
21 | from PIL import Image
22 | from typing import Optional
23 | import torch
24 | import warnings
25 | import re
26 | import copy
27 | import json
28 | from tqdm import tqdm
29 | from argparse import ArgumentParser
30 | import numpy as np
31 |
32 | warnings.filterwarnings("ignore")
33 | device = "cuda" if torch.cuda.is_available() else "cpu"
34 | torch.manual_seed(53)
35 |
36 | base_prompt = "You'll be given an image, an instruction and some options. You have to select the correct one. Do not explain your reasoning. Answer with only the letter that corresponds to the correct option. Do not repeat the entire answer. Answer with the option's letter from the given choices directly.\n"
37 | cot_prompt = "USER: You'll be given an image, an instruction and some options. You have to select the correct one. \nThink step by step before answering. Then conclude with the letter that corresponds to the correct option. Make sure the option letter is in the parentheses like (X). Do not include ( or ) in the response except for the answer.\n"
38 |
39 |
40 | def load_models(model_path: str, device: str, load_quantized: bool = False, ):
41 | model_path = os.path.expanduser(model_path)
42 | model_name = get_model_name_from_path(model_path)
43 | tokenizer, model, image_processor, context_len = load_pretrained_model(model_path, None, model_name)
44 |
45 | # find option-related tokens:
46 | vocab = tokenizer.get_vocab()
47 | tokens_with_ids, tokens_cluster = find_token_mappings(vocab)
48 |
49 | return model, image_processor, tokenizer, tokens_with_ids, tokens_cluster
50 |
51 |
52 | def load_image(datatype: str, data: Dict, image_processor, model_config):
53 | if datatype != 'json':
54 | image = data[f"image"].convert("RGB")
55 | else:
56 | image = Image.open(data[f"img_path"]).convert("RGB")
57 |
58 | image_size = [image.size]
59 | image_tensor = process_images([image], image_processor, model_config)
60 | return image_size, image_tensor
61 |
62 |
63 |
64 | def prepare_prompt(d_prompt: str, model_config, conv_mode: str, m_method: Optional[str]=None):
65 | if m_method is None:
66 | prompt = base_prompt + d_prompt
67 | else:
68 | # chain of thoughts
69 | prompt = cot_prompt + d_prompt
70 |
71 | if model_config.mm_use_im_start_end:
72 |         prompt = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN + '\n' + prompt  # keep the base/CoT instruction prepended above
73 | else:
74 | prompt = DEFAULT_IMAGE_TOKEN + '\n' + prompt
75 |
76 | conv = conv_templates[conv_mode].copy()
77 | conv.append_message(conv.roles[0], prompt)
78 | conv.append_message(conv.roles[1], None)
79 | prompt = conv.get_prompt()
80 | return prompt
81 |
82 |
83 | def prepare_model_input(prompt: str, tokenizer, device: str):
84 | input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).to(device)
85 | return input_ids
86 |
87 |
88 | def process_output(generation_output, tokenizer, input_ids, update_ans_ids: bool = False):
89 | # replied answer
90 | outputs = generation_output.sequences[0].detach().cpu()
91 | model_answer = tokenizer.decode(outputs, skip_special_tokens=True)
92 | logits = generation_output.scores[0][0] # shape: |V|
93 | return model_answer, logits
94 |
95 |
96 | if __name__ == '__main__':
97 | parser = ArgumentParser()
98 | parser.add_argument("--root_dir", type=str, default="ROOT_DIR")
99 | parser.add_argument("--save_dir", type=str, default="SAVE_DIR")
100 | parser.add_argument("--modeltype", type=str, default="cambrian_3b")
101 | parser.add_argument("--datatype", type=str, default="dataset")
102 | parser.add_argument("--load_quantized", type=bool, default=True)
103 | args = parser.parse_args()
104 |
105 | root_dir = args.root_dir
106 | save_dir = args.save_dir
107 | load_quantized = args.load_quantized
108 | modeltype = args.modeltype
109 | datatype = args.datatype
110 | m_method = None # None for fast-thinking, 'CoT' for slow-thinking
111 |
112 |     # define model
113 | model_path = "nyu-visionx/cambrian-8b"
114 | conv_mode = "llama_3"
115 | if modeltype =='cambrian_3b':
116 | model_path = "nyu-visionx/cambrian-phi3-3b"
117 | conv_mode = "phi3"
118 | elif modeltype == 'cambrian_8b':
119 | model_path = "nyu-visionx/cambrian-8b"
120 | conv_mode = "llama_3"
121 | elif modeltype =='cambrian_13b':
122 | model_path = "nyu-visionx/cambrian-13b"
123 | conv_mode = "vicuna_v1"
124 | elif modeltype == 'cambrian_34b':
125 | model_path = "nyu-visionx/cambrian-34b"
126 | conv_mode = "chatml_direct"
127 | print(f"Evaluating model: {model_path}")
128 | os.makedirs(save_dir, exist_ok=True)
129 |
130 | #############################
131 | # load model & tokenizer
132 | model, image_processor, tokenizer, tokens_with_ids, tokens_cluster = load_models(model_path=model_path, device=device, load_quantized=load_quantized)
133 |
134 | #############################
135 | # load data
136 | eval_dataset = load_data(data_type=datatype, root_dir=root_dir, )
137 |
138 | #############################
139 | # Start inference
140 | dict_result = dict()
141 | for i, data in enumerate(tqdm(eval_dataset)):
142 | try:
143 | # load image
144 | image_size, image = load_image(datatype=datatype, data=data, image_processor=image_processor, model_config=model.config)
145 |
146 | # prepare prompt
147 | prompt = prepare_prompt(d_prompt=data["prompt"], model_config=model.config, conv_mode=conv_mode, m_method=m_method)
148 |
149 | # tokenize input
150 | input_ids = prepare_model_input(prompt=prompt, tokenizer=tokenizer, device=device)
151 |
152 | # inference
153 | with torch.no_grad():
154 | generation_output = model.generate(input_ids, images=image, image_sizes=image_size, num_beams=1, do_sample=False, temperature=0, max_new_tokens=1, return_dict_in_generate=True, output_scores=True, )
155 |
156 | ####################
157 | # Process answer
158 | model_answer, logits = process_output(generation_output, tokenizer, input_ids=input_ids, update_ans_ids=True)
159 |
160 | # calculate probs within options
161 | probs, logits_options, dict_option_prob = calculate_probs(logits=logits, list_options=data['choices'], tokens_with_ids=tokens_with_ids, tokens_cluster=tokens_cluster)
162 |
163 | dict_result[i] = copy.deepcopy(data)
164 | if 'image' in dict_result[i]:
165 | del dict_result[i]['image']
166 |
167 | dict_result[i]["model_ans"] = model_answer
168 |
169 | except Exception as e:
170 | print(e)
171 | print("skipping", i)
172 | torch.cuda.empty_cache()
173 | traceback.print_exc()
174 | sys.exit(-1)
175 |
176 | # save results to json
177 | write_file = os.path.join(save_dir, f"{modeltype}.json")
178 | print(f"write to file {write_file}")
179 | with open(write_file, "w") as f:
180 | json.dump(dict_result, f, indent=4)
181 |
--------------------------------------------------------------------------------
/model/eaglex5.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import os
3 |
4 | sys.path.insert(0, os.getcwd())
5 | from utils.path_utils import *
6 |
7 | CACHE_DIR = set_root_folder()
8 |
9 | os.environ["HF_HOME"] = CACHE_DIR
10 | os.environ["HF_DATASETS_CACHE"] = CACHE_DIR
11 | os.environ["HF_MODULES_CACHE"] = CACHE_DIR
12 | os.environ["TRANSFORMERS_CACHE"] = CACHE_DIR
13 |
14 | from utils.infer_utils import *
15 | from eagle import conversation as conversation_lib
16 | from eagle.constants import DEFAULT_IMAGE_TOKEN
17 | from eagle.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
18 | from eagle.conversation import conv_templates, SeparatorStyle
19 | from eagle.model.builder import load_pretrained_model
20 | from eagle.utils import disable_torch_init
21 | from eagle.mm_utils import tokenizer_image_token, get_model_name_from_path, process_images, KeywordsStoppingCriteria
22 | from transformers import TextIteratorStreamer
23 | from threading import Thread
24 |
25 | import traceback
26 | from PIL import Image
27 | import torch
28 | import warnings
29 | import json
30 | import copy
31 | from tqdm import tqdm
32 | from argparse import ArgumentParser
33 | import numpy as np
34 |
35 | warnings.filterwarnings("ignore")
36 | device = "cuda" if torch.cuda.is_available() else "cpu"
37 | torch.manual_seed(53)
38 |
39 | base_prompt = "You'll be given an image, an instruction and some options. You have to select the correct one. Do not explain your reasoning. Answer with only the letter that corresponds to the correct option. Do not repeat the entire answer. \n"
40 | cot_prompt = "USER: You'll be given an image, an instruction and some options. You have to select the correct one. \nThink step by step before answering. Then conclude with the letter that corresponds to the correct option. Make sure the option letter is in the parentheses like (X). Do not include ( or ) in the response except for the answer. \n"
41 |
42 |
43 | def load_models(model_path: str, device: str, load_quantized: bool = False, ):
44 | model_name = get_model_name_from_path(model_path)
45 | tokenizer, model, image_processor, context_len = load_pretrained_model(model_path, None, model_name, False, False)
46 |
47 | # find option-related tokens:
48 | vocab = tokenizer.get_vocab()
49 | tokens_with_ids, tokens_cluster = find_token_mappings(vocab)
50 |
51 | return model, tokenizer, image_processor, context_len, tokens_with_ids, tokens_cluster
52 |
53 |
54 | def load_image(datatype: str, data: Dict, ):
55 | if datatype != 'json':
56 | image = data[f"image"].convert("RGB")
57 | else:
58 | image = Image.open(data[f"img_path"]).convert("RGB")
59 |
60 | return image
61 |
62 |
63 | def prepare_prompt(d_prompt: str, model, conv_mode: str, m_method: Optional[str]=None):
64 | if m_method is None:
65 | input_prompt = base_prompt + d_prompt
66 | else:
67 | # chain of thoughts
68 | input_prompt = cot_prompt + d_prompt
69 |
70 | if model.config.mm_use_im_start_end:
71 | input_prompt = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN + '\n' + input_prompt
72 | else:
73 | input_prompt = DEFAULT_IMAGE_TOKEN + '\n' + input_prompt
74 |
75 | conv = conv_templates[conv_mode].copy()
76 | conv.append_message(conv.roles[0], input_prompt)
77 | conv.append_message(conv.roles[1], None)
78 | prompt = conv.get_prompt()
79 | return prompt
80 |
81 |
82 | def prepare_model_input(prompt: str, image, model, tokenizer, image_processor, device: str):
83 | image_tensor = process_images([image], image_processor, model.config)[0]
84 | input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt')
85 |
86 | input_ids = input_ids.to(device=device, non_blocking=True)
87 | image_tensor = image_tensor.to(dtype=torch.float16, device=device, non_blocking=True)
88 |
89 | return input_ids, image_tensor
90 |
91 |
92 | def process_output(generation_output, tokenizer, update_ans_ids: bool = False, m_method=None):
93 | # replied answer
94 | outputs = generation_output.sequences[0].detach().cpu()
95 | decode_res = tokenizer.batch_decode(outputs, skip_special_tokens=True)
96 | model_answer = decode_res[0]
97 | logits = generation_output.scores[0][0] # shape: |V|
98 | if model_answer.lower() not in ('a', 'b', 'c', 'd', 'e') and len(decode_res) > 1:
99 | model_answer = ''.join(decode_res)
100 | if ' ' not in model_answer:
101 | model_answer = ' '.join(decode_res)
102 | if m_method is None:
103 | model_answer = unify_ans(model_answer)
104 | return model_answer, logits
105 |
106 |
107 | if __name__ == '__main__':
108 | parser = ArgumentParser()
109 | parser.add_argument("--root_dir", type=str, default="ROOT_DIR")
110 | parser.add_argument("--save_dir", type=str, default="SAVE_DIR")
111 | parser.add_argument("--modeltype", type=str, default="eaglex5_7b")
112 | parser.add_argument("--datatype", type=str, default="dataset")
113 | parser.add_argument("--load_quantized", type=bool, default=True)
114 | args = parser.parse_args()
115 |
116 | root_dir = args.root_dir
117 | save_dir = args.save_dir
118 | load_quantized = args.load_quantized
119 | modeltype = args.modeltype
120 | datatype = args.datatype
121 | m_method = None # None for fast-thinking, 'CoT' for slow-thinking
122 |
123 | model_path = "NVEagle/Eagle-X5-7B"
124 | if modeltype == 'eaglex5_7b':
125 | model_path = "NVEagle/Eagle-X5-7B"
126 | conv_mode = "vicuna_v1"
127 | elif modeltype == 'eaglex4_8b':
128 | model_path = "NVEagle/Eagle-X4-8B-Plus"
129 | conv_mode = "llama3"
130 | elif modeltype == 'eaglex4_13b':
131 | model_path = "NVEagle/Eagle-X4-13B-Plus"
132 | conv_mode = "vicuna_v1"
133 | elif modeltype == 'eaglex5_34b':
134 | model_path = "NVEagle/Eagle-X5-34B-Plus"
135 | conv_mode = "yi_34b_chatml_direct"
136 | print(f"Evaluating model: {model_path}")
137 | os.makedirs(save_dir, exist_ok=True)
138 |
139 | #############################
140 | # load model & tokenizer
141 | model, tokenizer, image_processor, context_len, tokens_with_ids, tokens_cluster = load_models(model_path=model_path, device=device, load_quantized=load_quantized)
142 |
143 | #############################
144 | # load data
145 | eval_dataset = load_data(data_type=datatype, root_dir=root_dir, )
146 |
147 | #############################
148 | # Start inference
149 | dict_result = dict()
150 | for i, data in enumerate(tqdm(eval_dataset)):
151 | try:
152 | # load image
153 | image = load_image(datatype=datatype, data=data, )
154 |
155 | # prepare prompt
156 | prompt = prepare_prompt(d_prompt=data["prompt"], model=model, conv_mode=conv_mode, m_method=m_method)
157 |
158 | # tokenize input
159 | input_ids, image_tensor = prepare_model_input(image=image, prompt=prompt, model=model, image_processor=image_processor, tokenizer=tokenizer, device=device)
160 |
161 | # inference
162 | with torch.no_grad():
163 | generation_output = model.generate(input_ids.unsqueeze(0), images=image_tensor.unsqueeze(0), image_sizes=[image.size], min_length=1, do_sample=False, use_cache=True, temperature=0, max_new_tokens=1, return_dict_in_generate=True, output_scores=True, )
164 |
165 | ####################
166 | # Process answer
167 | model_answer, logits = process_output(generation_output, tokenizer, update_ans_ids=True)
168 |
169 | # calculate probs within options
170 | probs, logits_options, dict_option_prob = calculate_probs(logits=logits, list_options=data['choices'], tokens_with_ids=tokens_with_ids, tokens_cluster=tokens_cluster)
171 |
172 | dict_result[i] = copy.deepcopy(data)
173 | if 'image' in dict_result[i]:
174 | del dict_result[i]['image']
175 |
176 | dict_result[i]["model_ans"] = model_answer
177 |
178 | except Exception as e:
179 | print(e)
180 | print("skipping", i)
181 | torch.cuda.empty_cache()
182 | traceback.print_exc()
183 | sys.exit(-1)
184 |
185 | # save results to json
186 | write_file = os.path.join(save_dir, f"{modeltype}.json")
187 | print(f"write to file {write_file}")
188 | with open(write_file, "w") as f:
189 | json.dump(dict_result, f, indent=4)
190 |
--------------------------------------------------------------------------------
/model/gemini_infer.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import os
3 |
4 | sys.path.insert(0, os.getcwd())
5 | from utils.path_utils import *
6 |
7 | CACHE_DIR = set_root_folder()
8 |
9 | os.environ["HF_HOME"] = CACHE_DIR
10 | os.environ["HF_DATASETS_CACHE"] = CACHE_DIR
11 | os.environ["HF_MODULES_CACHE"] = CACHE_DIR
12 | os.environ["TRANSFORMERS_CACHE"] = CACHE_DIR
13 |
14 | from utils.infer_utils import *
15 | from google import genai
16 | import json
17 | import argparse
18 | import os
19 | import time
20 | import copy
21 | import re
22 | from tqdm import tqdm
23 | from PIL import Image
24 |
25 |
26 | # Function to load the image
27 | def load_image(datatype: str, data):
28 | """Loads an image file using PIL."""
29 | if datatype != 'json':
30 | image = data[f"image"].convert("RGB")
31 | else:
32 | image = Image.open(data[f"img_path"]).convert("RGB")
33 |
34 | return image
35 |
36 |
37 | def ask_gemini_about_image(api_key, datatype: str, data, use_cot=False, max_retries=3):
38 | """Sends an image and a question to Gemini for visual question answering with error handling."""
39 | image = load_image(datatype=datatype, data=data, )
40 | if image is None:
41 | return "Error: Image file missing or invalid"
42 |
43 | client = genai.Client(api_key=api_key)
44 | model = "gemini-2.0-flash" # Using the flash version for faster response
45 |
46 | # Modify the question based on whether CoT is enabled
47 | if use_cot:
48 | # Chain of Thought prompt
49 | modified_question = data['prompt'] + "\nThink step by step before answering. Then conclude with the letter that corresponds to the correct option. Make sure the option letter is in the parentheses like (X)."
50 | else:
51 | # Original direct prompt
52 | modified_question = data['prompt'] + "\nAnswer with only the letter that corresponds to the correct option. Do not repeat the entire answer. Do not explain your reasoning."
53 |
54 | for attempt in range(max_retries):
55 | try:
56 | response = client.models.generate_content(
57 | model=model,
58 | contents=[modified_question, image]
59 | )
60 | return response.text.strip()
61 |
62 | except Exception as e:
63 | print(f"Error on attempt {attempt+1}: {e}")
64 | time.sleep(2) # Wait before retrying
65 |
66 | return "Error: Failed after multiple attempts"
67 |
68 |
69 | def process_json(api_key, datatype, eval_dataset, save_dir, use_cot=False):
70 | """Processes a JSON file to get Gemini answers and compute accuracy."""
71 |
72 | # Create output directory if it doesn't exist
73 | os.makedirs(save_dir, exist_ok=True)
74 | if use_cot:
75 | modeltype = 'gemini_cot'
76 | else:
77 | modeltype = 'gemini'
78 |
79 | dict_result = dict()
80 | for i, data in enumerate(tqdm(eval_dataset)):
81 |
82 | # Process with CoT or without based on the parameter
83 | if use_cot:
84 | print(f"Processing ID: {i} with CoT")
85 |
86 | gemini_answer = ask_gemini_about_image(api_key, datatype, data, use_cot=True)
87 | print(f"Gemini CoT Answer: {gemini_answer}")
88 | else:
89 | print(f"Processing ID: {i} with CoT")
90 | gemini_answer = ask_gemini_about_image(api_key, datatype, data, use_cot=False)
91 | print(f"Gemini Answer: {gemini_answer}")
92 |
93 | dict_result[i] = copy.deepcopy(data)
94 | if 'image' in dict_result[i]:
95 | del dict_result[i]['image']
96 |
97 | dict_result[i]["model_ans"] = gemini_answer
98 |
99 | # Save the updated JSON file
100 | write_file = os.path.join(save_dir, f"{modeltype}.json")
101 | print(f"write to file {write_file}")
102 | with open(write_file, "w") as file:
103 |         json.dump(dict_result, file, indent=4)  # write all accumulated results, not just the last sample
104 |
105 |
106 | if __name__ == "__main__":
107 | parser = argparse.ArgumentParser(description="Process images and questions using Gemini.")
108 |
109 | parser.add_argument("--root_dir", type=str, default="ROOT_DIR")
110 | parser.add_argument("--save_dir", type=str, default="SAVE_DIR")
111 | parser.add_argument("--datatype", type=str, default="dataset")
112 | parser.add_argument("--api_key", type=str, default='', help="Gemini API key")
113 | parser.add_argument("--use_cot", action="store_true", help="Use Chain of Thought reasoning")
114 |
115 | args = parser.parse_args()
116 |
117 | root_dir = args.root_dir
118 | save_dir = args.save_dir
119 | datatype = args.datatype
120 | api_key = args.api_key
121 | use_cot = args.use_cot
122 |
123 | #############################
124 | # load data
125 | eval_dataset = load_data(data_type=datatype, root_dir=root_dir, )
126 |
127 | #############################
128 | # Start inference
129 | process_json(api_key, datatype, eval_dataset, save_dir, use_cot)
130 |
--------------------------------------------------------------------------------
/model/gpt4o_infer.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import os
3 |
4 | sys.path.insert(0, os.getcwd())
5 | from utils.path_utils import *
6 |
7 | CACHE_DIR = set_root_folder()
8 |
9 | os.environ["HF_HOME"] = CACHE_DIR
10 | os.environ["HF_DATASETS_CACHE"] = CACHE_DIR
11 | os.environ["HF_MODULES_CACHE"] = CACHE_DIR
12 | os.environ["TRANSFORMERS_CACHE"] = CACHE_DIR
13 |
14 | from utils.infer_utils import *
15 | from openai import OpenAI
16 | import json
17 | import base64
18 | from io import BytesIO
19 | import argparse
20 | import os
21 | import copy
22 | import time
23 | import re
24 | from tqdm import tqdm
25 | from PIL import Image
26 |
27 |
28 | # Function to load the image
29 | def encode_image(datatype: str, data):
30 | """Loads an image file using PIL."""
31 | if datatype != 'json':
32 | image = data[f"image"].convert("RGB")
33 | buffered = BytesIO()
34 | image.save(buffered, format="JPEG")
35 | img_str = base64.b64encode(buffered.getvalue()).decode("utf-8")
36 | else:
37 | with open(data[f"img_path"], "rb") as image_file:
38 | img_str = base64.b64encode(image_file.read()).decode("utf-8")
39 |
40 | return img_str
41 |
42 |
43 | def ask_gpt4o_about_image(api_key, datatype: str, data, use_cot=False, max_retries=3):
44 | """Sends an image and a question to GPT-4o for visual question answering with error handling."""
45 | base64_image = encode_image(datatype, data)
46 | if base64_image is None:
47 | return "Error: Image file missing"
48 |
49 | client = OpenAI(api_key=api_key)
50 |
51 | # Modify the question based on whether CoT is enabled
52 | if use_cot:
53 | # Chain of Thought prompt - MODIFIED as requested
54 | modified_question = data['prompt'] + "\nThink step by step before answering. Then conclude with the letter that corresponds to the correct option. Make sure the option letter is in the parentheses like (X)."
55 | else:
56 | # Original direct prompt
57 | modified_question = data['prompt'] + "\nAnswer with only the letter that corresponds to the correct option. Do not repeat the entire answer. Do not explain your reasoning."
58 |
59 | for attempt in range(max_retries):
60 | try:
61 | response = client.chat.completions.create(
62 | model="gpt-4o",
63 | messages=[
64 | {
65 | "role": "user",
66 | "content": [
67 | {"type": "text", "text": modified_question},
68 | {"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{base64_image}"}},
69 | ],
70 | }
71 | ],
72 | )
73 | return response.choices[0].message.content.strip()
74 |
75 | except Exception as e:
76 | print(f"Error on attempt {attempt+1}: {e}")
77 | time.sleep(2) # Wait before retrying
78 |
79 | return "Error: Failed after multiple attempts"
80 |
81 |
82 | def process_json(api_key, datatype, eval_dataset, save_dir, use_cot=False):
83 | """Processes a JSON file to get GPT-4o answers and compute accuracy."""
84 |
85 | # Create output directory if it doesn't exist
86 | os.makedirs(save_dir, exist_ok=True)
87 | if use_cot:
88 | modeltype = 'gpt4o_cot'
89 | else:
90 | modeltype = 'gpt4o'
91 |
92 | dict_result = dict()
93 | for i, data in enumerate(tqdm(eval_dataset)):
94 |
95 | # Process with CoT or without based on the parameter
96 | if use_cot:
97 | print(f"Processing ID: {i} with CoT")
98 |
99 | gpt4o_answer = ask_gpt4o_about_image(api_key, datatype, data, use_cot=True)
100 | print(f"GPT-4o CoT Answer: {gpt4o_answer}")
101 | else:
102 | print(f"Processing ID: {i}")
103 |
104 | gpt4o_answer = ask_gpt4o_about_image(api_key, datatype, data, use_cot=False)
105 | print(f"GPT-4o Answer: {gpt4o_answer}")
106 |
107 | dict_result[i] = copy.deepcopy(data)
108 | if 'image' in dict_result[i]:
109 | del dict_result[i]['image']
110 |
111 | dict_result[i]["model_ans"] = gpt4o_answer
112 |
113 | # Save the updated JSON file
114 | write_file = os.path.join(save_dir, f"{modeltype}.json")
115 | print(f"write to file {write_file}")
116 | with open(write_file, "w") as file:
117 |         json.dump(dict_result, file, indent=4)  # write all accumulated results, not just the last sample
118 |
119 |
120 | if __name__ == "__main__":
121 | parser = argparse.ArgumentParser(description="Process images and questions using GPT-4o.")
122 |
123 | parser.add_argument("--root_dir", type=str, default="ROOT_DIR")
124 | parser.add_argument("--save_dir", type=str, default="SAVE_DIR")
125 | parser.add_argument("--datatype", type=str, default="dataset")
126 | parser.add_argument("--api_key", type=str, default='', help="OpenAI API key")
127 | parser.add_argument("--use_cot", action="store_true", help="Use Chain of Thought reasoning")
128 |
129 | args = parser.parse_args()
130 |
131 | root_dir = args.root_dir
132 | save_dir = args.save_dir
133 | datatype = args.datatype
134 | api_key = args.api_key
135 | use_cot = args.use_cot
136 |
137 | #############################
138 | # load data
139 | eval_dataset = load_data(data_type=datatype, root_dir=root_dir, )
140 |
141 | #############################
142 | # Start inference
143 | process_json(api_key, datatype, eval_dataset, save_dir, use_cot)
--------------------------------------------------------------------------------
/model/internvl.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import os
3 |
4 | sys.path.insert(0, os.getcwd())
5 | from utils.path_utils import *
6 |
7 | CACHE_DIR = set_root_folder()
8 |
9 | os.environ["HF_HOME"] = CACHE_DIR
10 | os.environ["HF_DATASETS_CACHE"] = CACHE_DIR
11 | os.environ["HF_MODULES_CACHE"] = CACHE_DIR
12 | os.environ["TRANSFORMERS_CACHE"] = CACHE_DIR
13 |
14 | from utils.infer_utils import *
15 | from transformers import AutoTokenizer, AutoModel
16 | import traceback
17 | from PIL import Image
18 | import torch
19 | import torchvision.transforms as T
20 | from torchvision.transforms.functional import InterpolationMode
21 | import warnings
22 | import json
23 | import re
24 | import math
25 | import copy
26 | from tqdm import tqdm
27 | from argparse import ArgumentParser
28 |
29 | warnings.filterwarnings("ignore")
30 |
31 | device = "cuda" if torch.cuda.is_available() else "cpu"
32 | torch.manual_seed(53)
33 | IMAGENET_MEAN = (0.485, 0.456, 0.406)
34 | IMAGENET_STD = (0.229, 0.224, 0.225)
35 | generation_config = dict(max_new_tokens=1024, do_sample=False)
36 | IMG_START_TOKEN = '<img>'
37 | IMG_END_TOKEN = '</img>'
38 | IMG_CONTEXT_TOKEN = '<IMG_CONTEXT>'
39 | base_prompt = "Answer with only the letter that corresponds to the correct option. Do not repeat the entire answer. Do not explain your reasoning."
40 | cot_prompt = "Think step by step before answering. Then conclude with the letter that corresponds to the correct option. Make sure the option letter is in the parentheses like (X). Do not include ( or ) in the response except for the answer.\n"
41 |
42 |
43 | def build_transform(input_size):
44 | MEAN, STD = IMAGENET_MEAN, IMAGENET_STD
45 | transform = T.Compose([T.Lambda(lambda img: img.convert('RGB') if img.mode != 'RGB' else img), T.Resize((input_size, input_size), interpolation=InterpolationMode.BICUBIC), T.ToTensor(), T.Normalize(mean=MEAN, std=STD)])
46 | return transform
47 |
48 |
49 | def find_closest_aspect_ratio(aspect_ratio, target_ratios, width, height, image_size):
50 | best_ratio_diff = float('inf')
51 | best_ratio = (1, 1)
52 | area = width * height
53 | for ratio in target_ratios:
54 | target_aspect_ratio = ratio[0] / ratio[1]
55 | ratio_diff = abs(aspect_ratio - target_aspect_ratio)
56 | if ratio_diff < best_ratio_diff:
57 | best_ratio_diff = ratio_diff
58 | best_ratio = ratio
59 | elif ratio_diff == best_ratio_diff:
60 | if area > 0.5 * image_size * image_size * ratio[0] * ratio[1]:
61 | best_ratio = ratio
62 | return best_ratio
63 |
64 |
65 | def dynamic_preprocess(image, min_num=1, max_num=12, image_size=448, use_thumbnail=False):
66 | orig_width, orig_height = image.size
67 | aspect_ratio = orig_width / orig_height
68 |
69 | # calculate the existing image aspect ratio
70 | target_ratios = set((i, j) for n in range(min_num, max_num + 1) for i in range(1, n + 1) for j in range(1, n + 1) if
71 | i * j <= max_num and i * j >= min_num)
72 | target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1])
73 |
74 | # find the closest aspect ratio to the target
75 | target_aspect_ratio = find_closest_aspect_ratio(aspect_ratio, target_ratios, orig_width, orig_height, image_size)
76 |
77 | # calculate the target width and height
78 | target_width = image_size * target_aspect_ratio[0]
79 | target_height = image_size * target_aspect_ratio[1]
80 | blocks = target_aspect_ratio[0] * target_aspect_ratio[1]
81 |
82 | # resize the image
83 | resized_img = image.resize((target_width, target_height))
84 | processed_images = []
85 | for i in range(blocks):
86 | box = ((i % (target_width // image_size)) * image_size, (i // (target_width // image_size)) * image_size,
87 | ((i % (target_width // image_size)) + 1) * image_size,
88 | ((i // (target_width // image_size)) + 1) * image_size)
89 | # split the image
90 | split_img = resized_img.crop(box)
91 | processed_images.append(split_img)
92 | assert len(processed_images) == blocks
93 | if use_thumbnail and len(processed_images) != 1:
94 | thumbnail_img = image.resize((image_size, image_size))
95 | processed_images.append(thumbnail_img)
96 | return processed_images
97 |
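# Worked example (hypothetical input): a 1344x896 image with image_size=448 has aspect ratio 1.5,
# so find_closest_aspect_ratio picks the (3, 2) grid; dynamic_preprocess resizes to 1344x896,
# crops six 448x448 tiles, and appends one 448x448 thumbnail because use_thumbnail=True,
# returning 7 images in total.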
98 |
99 | def load_internvl_image(image_cv2=None, image_file=None, input_size=448, max_num=12):
100 | if image_file is not None:
101 | image = Image.open(image_file).convert('RGB')
102 | if image_cv2 is not None:
103 | image = image_cv2
104 |
105 | transform = build_transform(input_size=input_size)
106 | images = dynamic_preprocess(image, image_size=input_size, use_thumbnail=True, max_num=max_num)
107 | pixel_values = [transform(image) for image in images]
108 | pixel_values = torch.stack(pixel_values)
109 | return pixel_values
110 |
111 |
112 | def load_image(datatype: str, data: Dict, ):
113 | if datatype != 'json':
114 | image = data[f"image"].convert("RGB")
115 | pixel_values = load_internvl_image(image_cv2=data[f"image"].convert("RGB"), max_num=12).to(
116 | torch.bfloat16).cuda()
117 | else:
118 | image = Image.open(data[f"img_path"]).convert("RGB")
119 | pixel_values = load_internvl_image(image_file=data[f"img_path"], max_num=12).to(torch.bfloat16).cuda()
120 | return image, pixel_values
121 |
122 |
123 | def split_model(model_name):
124 | device_map = {}
125 | world_size = torch.cuda.device_count()
126 | num_layers = {
127 | 'InternVL2-1B': 24, 'InternVL2-2B': 24, 'InternVL2-4B': 32, 'InternVL2-8B': 32,
128 | 'InternVL2-26B': 48, 'InternVL2-40B': 60, 'InternVL2-Llama3-76B': 80,
129 | 'InternVL2_5-1B': 24, 'InternVL2_5-2B': 24, 'InternVL2_5-4B': 36, 'InternVL2_5-8B': 32,
130 | 'InternVL2_5-26B': 48, 'InternVL2_5-38B': 64, 'InternVL2_5-78B': 80}[model_name]
131 | # Since the first GPU will be used for ViT, treat it as half a GPU.
132 | num_layers_per_gpu = math.ceil(num_layers / (world_size - 0.5))
133 | num_layers_per_gpu = [num_layers_per_gpu] * world_size
134 | num_layers_per_gpu[0] = math.ceil(num_layers_per_gpu[0] * 0.5)
135 | layer_cnt = 0
136 | for i, num_layer in enumerate(num_layers_per_gpu):
137 | for j in range(num_layer):
138 | device_map[f'language_model.model.layers.{layer_cnt}'] = i
139 | layer_cnt += 1
140 | device_map['vision_model'] = 0
141 | device_map['mlp1'] = 0
142 | device_map['language_model.model.tok_embeddings'] = 0
143 | device_map['language_model.model.embed_tokens'] = 0
144 | device_map['language_model.output'] = 0
145 | device_map['language_model.model.norm'] = 0
146 | device_map['language_model.model.rotary_emb'] = 0
147 | device_map['language_model.lm_head'] = 0
148 | device_map[f'language_model.model.layers.{num_layers - 1}'] = 0
149 |
150 | return device_map
151 |
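# Sketch of the resulting device_map (e.g., 2 GPUs and InternVL2-26B with 48 layers):
# num_layers_per_gpu = ceil(48 / 1.5) = 32, halved to 16 for GPU 0, so GPU 0 hosts the ViT,
# embeddings, lm_head and language layers 0-15 (plus the final layer, remapped back to GPU 0),
# while GPU 1 hosts layers 16-47.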
152 |
153 | def load_models(model_path: str, device: str, load_quantized: bool = False, ):
154 | if model_path not in ('OpenGVLab/InternVL2-Llama3-76B', "OpenGVLab/InternVL2_5-78B"):
155 | model = AutoModel.from_pretrained(model_path, torch_dtype=torch.bfloat16, load_in_8bit=load_quantized, low_cpu_mem_usage=True, use_flash_attn=True, trust_remote_code=True, cache_dir=CACHE_DIR).to(device).eval()
156 | else:
157 | device_map = split_model(model_path.split('/')[-1])
158 |
159 | model = AutoModel.from_pretrained(model_path, torch_dtype=torch.bfloat16, device_map=device_map, load_in_8bit=True, low_cpu_mem_usage=True, use_flash_attn=True, trust_remote_code=True, cache_dir=CACHE_DIR).eval()
160 |
161 | tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True, use_fast=False)
162 | img_context_token_id = tokenizer.convert_tokens_to_ids(IMG_CONTEXT_TOKEN)
163 | model.img_context_token_id = img_context_token_id
164 |
165 | # find option-related tokens:
166 | vocab = tokenizer.get_vocab()
167 | tokens_with_ids, tokens_cluster = find_token_mappings(vocab)
168 |
169 | return model, tokenizer, tokens_with_ids, tokens_cluster
170 |
171 |
172 | if __name__ == '__main__':
173 | parser = ArgumentParser()
174 | parser.add_argument("--root_dir", type=str, default="ROOT_DIR")
175 | parser.add_argument("--save_dir", type=str, default="SAVE_DIR")
176 | parser.add_argument("--modeltype", type=str, default="internvl2_1b")
177 | parser.add_argument("--datatype", type=str, default="dataset")
178 | parser.add_argument("--load_quantized", type=bool, default=False)
179 | args = parser.parse_args()
180 |
181 | root_dir = args.root_dir
182 | save_dir = args.save_dir
183 | load_quantized = args.load_quantized
184 | modeltype = args.modeltype
185 | datatype = args.datatype
186 | m_method = None # None for fast-thinking, 'CoT' for slow-thinking
187 |
188 | model_path = "OpenGVLab/InternVL2-8B"
189 | if modeltype == 'internvl2_1b':
190 | model_path = "OpenGVLab/InternVL2-1B"
191 | elif modeltype == 'internvl2_2b':
192 | model_path = "OpenGVLab/InternVL2-2B"
193 | elif modeltype == 'internvl2_4b':
194 | model_path = "OpenGVLab/InternVL2-4B"
195 | elif modeltype == 'internvl2_8b':
196 | model_path = "OpenGVLab/InternVL2-8B"
197 | elif modeltype == 'internvl2_26b':
198 | model_path = "OpenGVLab/InternVL2-26B"
199 | elif modeltype == 'internvl2_40b':
200 | model_path = "OpenGVLab/InternVL2-40B"
201 | elif modeltype == 'internvl2_72b':
202 | model_path = "OpenGVLab/InternVL2-Llama3-76B"
203 | elif modeltype == 'internvl25_1b':
204 | model_path = "OpenGVLab/InternVL2_5-1B"
205 | elif modeltype == 'internvl25_2b':
206 | model_path = "OpenGVLab/InternVL2_5-2B"
207 | elif modeltype == 'internvl25_4b':
208 | model_path = "OpenGVLab/InternVL2_5-4B"
209 | elif modeltype == 'internvl25_8b':
210 | model_path = "OpenGVLab/InternVL2_5-8B"
211 | elif modeltype == 'internvl25_26b':
212 | model_path = "OpenGVLab/InternVL2_5-26B"
213 | elif modeltype == 'internvl25_38b':
214 | model_path = "OpenGVLab/InternVL2_5-38B"
215 | elif modeltype == 'internvl25_72b':
216 | model_path = "OpenGVLab/InternVL2_5-78B"
217 | print(f"Evaluating model: {model_path}")
218 | os.makedirs(save_dir, exist_ok=True)
219 |
220 | #############################
221 | # load model & tokenizer
222 | model, tokenizer, tokens_with_ids, tokens_cluster = load_models(model_path=model_path, device=device, load_quantized=load_quantized)
223 |
224 | #############################
225 | # load data
226 | eval_dataset = load_data(data_type=datatype, root_dir=root_dir, )
227 |
228 | #############################
229 | # Start inference
230 | dict_result = dict()
231 | for i, data in enumerate(tqdm(eval_dataset)):
232 | try:
233 | # load image
234 | image, pixel_values = load_image(datatype=datatype, data=data, )
235 |
236 | if m_method is None:
237 | prompt = base_prompt + data["prompt"]
238 | else:
239 | # chain of thoughts
240 | prompt = cot_prompt + data["prompt"]
241 |
242 | model_answer, _ = model.chat(tokenizer, pixel_values, prompt, generation_config, history=None, return_history=True)
243 | model_answer = unify_ans(model_answer)
244 |
245 | dict_result[i] = copy.deepcopy(data)
246 | if 'image' in dict_result[i]:
247 | del dict_result[i]['image']
248 |
249 | dict_result[i]["model_ans"] = model_answer
250 |
251 | except Exception as e:
252 | print(e)
253 | print("skipping", i)
254 | torch.cuda.empty_cache()
255 | traceback.print_exc()
256 | sys.exit(-1)
257 |
258 | # save results to json
259 | write_file = os.path.join(save_dir, f"{modeltype}.json")
260 | print(f"write to file {write_file}")
261 | with open(write_file, "w") as f:
262 | json.dump(dict_result, f, indent=4)
263 |
--------------------------------------------------------------------------------
/model/llava_next.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import os
3 |
4 | sys.path.insert(0, os.getcwd())
5 | from utils.path_utils import *
6 |
7 | CACHE_DIR = set_root_folder()
8 |
9 | os.environ["HF_HOME"] = CACHE_DIR
10 | os.environ["HF_DATASETS_CACHE"] = CACHE_DIR
11 | os.environ["HF_MODULES_CACHE"] = CACHE_DIR
12 | os.environ["TRANSFORMERS_CACHE"] = CACHE_DIR
13 |
14 | from utils.infer_utils import *
15 | import torch
16 | from PIL import Image
17 | import json
18 | import copy
19 | import traceback
20 | from tqdm import tqdm
21 | from transformers import LlavaNextProcessor, LlavaNextForConditionalGeneration
22 | from argparse import ArgumentParser
23 | import warnings
24 |
25 | warnings.filterwarnings("ignore")
26 |
27 | device = "cuda" if torch.cuda.is_available() else "cpu"
28 | torch.manual_seed(53)
29 | base_prompt = "USER: You'll be given an image, an instruction and some options. You have to select the correct one. Do not explain your reasoning. Answer with only the letter that corresponds to the correct option. \n"
30 | cot_prompt = "USER: You'll be given an image, an instruction and some options. You have to select the correct one. Only one option is correct. \nThink step by step before answering. Then conclude with the letter that corresponds to the correct option. Make sure the option letter is in the parentheses like (X). Do not include ( or ) in the response except for the answer.\n"
31 |
32 |
33 | def load_models(model_path: str, device: str, load_quantized: bool = False, ):
34 | model = LlavaNextForConditionalGeneration.from_pretrained(model_path, low_cpu_mem_usage=True, device_map="auto", torch_dtype=torch.float16, ).eval()
35 |
36 | processor = LlavaNextProcessor.from_pretrained(model_path, cache_dir=CACHE_DIR)
37 | processor.tokenizer.padding_side = "left"
38 | tokenizer = processor.tokenizer
39 |
40 | # find option-related tokens:
41 | vocab = tokenizer.get_vocab()
42 | tokens_with_ids, tokens_cluster = find_token_mappings(vocab)
43 | return model, tokenizer, processor, tokens_with_ids, tokens_cluster
44 |
45 |
46 | def load_image(datatype: str, data: Dict, ):
47 | if datatype != 'json':
48 | image = data[f"image"].convert("RGB")
49 | else:
50 | image = Image.open(data[f"img_path"]).convert("RGB")
51 |
52 | return image
53 |
54 |
55 | def prepare_prompt(d_prompt: str, processor, m_method: Optional[str]=None):
56 | if m_method is None:
57 | prompt = base_prompt + d_prompt
58 | else:
59 | # chain of thoughts
60 | prompt = cot_prompt + d_prompt
61 |
62 | conversation = [
63 | {
64 | "role": "user",
65 | "content": [
66 | {"type": "text", "text": prompt},
67 | {"type": "image"},
68 | ],
69 | },
70 | ]
71 | prompt = processor.apply_chat_template(conversation, add_generation_prompt=True)
72 |
73 | return prompt
74 |
75 |
76 | def prepare_model_input(prompt: str, image, processor, device: str):
77 | inputs = processor(images=image, text=prompt, return_tensors="pt").to(device)
78 |
79 | return inputs
80 |
81 |
82 | def process_output(generation_output, tokenizer, processor, input_ids, m_method=None):
83 | # replied answer
84 | outputs = generation_output.sequences[0].detach().cpu()
85 | model_answer = processor.decode(outputs, skip_special_tokens=True) # output: str
86 | model_answer = model_answer.split("ASSISTANT: ")[-1].replace('[/INST]', '').strip()
87 | if len(model_answer) > 1:
88 | outputs_new = outputs[input_ids.shape[1]:]
89 | model_answer = processor.decode(outputs_new, skip_special_tokens=True) # output: str
90 |
91 | if model_answer.lower() not in ('a', 'b', 'c', 'd', 'e'):
92 | if m_method is None:
93 | model_answer = unify_ans(model_answer)
94 | logits = generation_output.scores[0][0] # shape: |V|
95 |
96 | return model_answer, logits
97 |
98 |
99 | if __name__ == '__main__':
100 |
101 | parser = ArgumentParser()
102 | parser.add_argument("--root_dir", type=str, default="ROOT_DIR")
103 | parser.add_argument("--save_dir", type=str, default="SAVE_DIR")
104 | parser.add_argument("--modeltype", type=str, default="llava_16_v_7b")
105 | parser.add_argument("--datatype", type=str, default="dataset")
106 | parser.add_argument("--load_quantized", type=bool, default=True)
107 | args = parser.parse_args()
108 |
109 | root_dir = args.root_dir
110 | save_dir = args.save_dir
111 | load_quantized = args.load_quantized
112 | modeltype = args.modeltype
113 | datatype = args.datatype
114 | m_method = None # None for fast-thinking, 'CoT' for slow-thinking
115 |
116 | # Load the model and processor
117 | model_path = 'llava-hf/llava-v1.6-vicuna-7b-hf'
118 | if modeltype == "llava_16_v_7b":
119 | model_path = 'llava-hf/llava-v1.6-vicuna-7b-hf'
120 | if modeltype == "llava_16_m_7b":
121 | model_path = 'llava-hf/llava-v1.6-mistral-7b-hf'
122 | elif modeltype == "llava_16_13b":
123 | model_path = 'llava-hf/llava-v1.6-vicuna-13b-hf'
124 | elif modeltype == "llava_16_34b":
125 | model_path = 'llava-hf/llava-v1.6-34b-hf'
126 | elif modeltype == "llava_16_72b":
127 | model_path = 'llava-hf/llava-next-72b-hf'
128 |
129 | print(f"Evaluating model: {model_path}")
130 | os.makedirs(save_dir, exist_ok=True)
131 |
132 | #############################
133 | # load model & tokenizer
134 | model, tokenizer, processor, tokens_with_ids, tokens_cluster = load_models(model_path=model_path, device=device, load_quantized=load_quantized)
135 |
136 | #############################
137 | # load data
138 | eval_dataset = load_data(data_type=datatype, root_dir=root_dir, )
139 |
140 | #############################
141 | # Start inference
142 | dict_result = {}
143 | for i, data in tqdm(enumerate(eval_dataset), total=len(eval_dataset)):
144 | try:
145 |
146 | # load image
147 | image = load_image(datatype=datatype, data=data, )
148 |
149 | # prepare prompt
150 | prompt = prepare_prompt(d_prompt=data["prompt"], processor=processor, m_method=m_method)
151 |
152 | # tokenize input
153 | inputs = prepare_model_input(prompt=prompt, image=image, processor=processor, device=model.device)
154 |
155 | with torch.no_grad():
156 | generation_output = model.generate(**inputs, do_sample=False, min_length=1, max_new_tokens=1024, return_dict_in_generate=True, output_scores=True, )
157 |
158 | ####################
159 | # Process answer
160 | model_answer, logits = process_output(generation_output, tokenizer, processor, input_ids=inputs['input_ids'], m_method=m_method)
161 |
162 | probs, logits_options, dict_option_prob = calculate_probs(logits=logits, list_options=data['choices'], tokens_with_ids=tokens_with_ids, tokens_cluster=tokens_cluster)
163 |
164 | dict_result[i] = copy.deepcopy(data)
165 | if 'image' in dict_result[i]:
166 | del dict_result[i]['image']
167 |
168 | dict_result[i]["model_ans"] = model_answer
169 |
170 | except Exception as e:
171 | print(e)
172 | print("skipping", i)
173 | torch.cuda.empty_cache()
174 | traceback.print_exc()
175 | sys.exit(-1)
176 |
177 | # Save results
178 | write_file = os.path.join(save_dir, f"{modeltype}.json")
179 | print(f"write to file {write_file}")
180 | with open(write_file, "w") as f:
181 | json.dump(dict_result, f, indent=4)
182 |
--------------------------------------------------------------------------------
/model/llava_ov.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import os
3 |
4 | sys.path.insert(0, os.getcwd())
5 | from utils.path_utils import *
6 |
7 | CACHE_DIR = set_root_folder()
8 |
9 | os.environ["HF_HOME"] = CACHE_DIR
10 | os.environ["HF_DATASETS_CACHE"] = CACHE_DIR
11 | os.environ["HF_MODULES_CACHE"] = CACHE_DIR
12 | os.environ["TRANSFORMERS_CACHE"] = CACHE_DIR
13 |
14 | from utils.infer_utils import *
15 | from llava.model.builder import load_pretrained_model
16 | from llava.mm_utils import (process_images, tokenizer_image_token, )
17 | from llava.constants import (IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, )
18 | from llava.conversation import conv_templates
19 | from PIL import Image
20 | from typing import Optional
21 | import copy
22 | import torch
23 | import traceback
24 | import warnings
25 | import json
26 | from tqdm import tqdm
27 | from argparse import ArgumentParser
28 |
29 | warnings.filterwarnings("ignore")
30 | device = "cuda" if torch.cuda.is_available() else "cpu"
31 | torch.manual_seed(53)
32 |
33 | conv_template = "qwen_1_5"
34 | base_prompt = "USER: You'll be given an image, an instruction and some options. You have to select the correct one. Do not explain your reasoning. Answer with only the letter that corresponds to the correct option.\n"
35 | cot_prompt = "USER: You'll be given an image, an instruction and some options. You have to select the correct one. \nThink step by step before answering. Then conclude with the letter that corresponds to the correct option. Make sure the option letter is in the parentheses like (X). Do not include ( or ) in the response except for the answer.\n"
36 |
37 |
38 | def load_models(model_path: str, device: str, load_quantized: bool = False, ):
39 |
40 | model_name = "llava_qwen"
41 | device_map = "auto"
42 | tokenizer, model, image_processor, max_length = load_pretrained_model(model_path, None, model_name, device_map=device_map, torch_dtype="bfloat16", )
43 | model.eval()
44 |
45 | # find option-related tokens:
46 | vocab = tokenizer.get_vocab()
47 | tokens_with_ids, tokens_cluster = find_token_mappings(vocab)
48 |
49 | return model, tokenizer, image_processor, tokens_with_ids, tokens_cluster
50 |
51 |
52 | def load_image(datatype: str, data: Dict, device: str, model, image_processor):
53 | if datatype != 'json':
54 | image = data[f"image"].convert("RGB")
55 | else:
56 | image = Image.open(data[f"img_path"]).convert("RGB")
57 |
58 | image_sizes = [image.size] # type: ignore
59 | image = process_images([image], image_processor, model.config)
60 | image = [img.to(dtype=torch.bfloat16, device=device) for img in image]
61 | return image, image_sizes
62 |
63 |
64 | def prepare_prompt(d_prompt: str, m_method: Optional[str]=None):
65 | if m_method is None:
66 | prompt = base_prompt + d_prompt
67 | else:
68 | # chain of thoughts
69 | prompt = cot_prompt + d_prompt
70 |
71 | prompt = DEFAULT_IMAGE_TOKEN + "\n" + prompt
72 | conv = copy.deepcopy(conv_templates[conv_template])
73 | conv.append_message(conv.roles[0], prompt)
74 | conv.append_message(conv.roles[1], None)
75 | prompt = conv.get_prompt()
76 | return prompt
77 |
78 |
79 | def prepare_model_input(prompt: str, tokenizer, device: str):
80 | input_ids = (tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt").unsqueeze(0).to(device))
81 | return input_ids
82 |
83 |
84 | def process_output(generation_output, tokenizer, update_ans_ids: bool = False, model_path='llava_ov_7b', m_method=None):
85 | # replied answer
86 | outputs = generation_output.sequences[0].detach().cpu()
87 | decode_res = tokenizer.batch_decode(outputs, skip_special_tokens=True)
88 | model_answer = decode_res[0]
89 | if model_path != 'llava_ov_72b':
90 | model_answer = model_answer.split("ASSISTANT: ")[-1][0].strip().lower()
91 | else:
92 | model_answer = model_answer.split("ASSISTANT: ")[-1].strip().lower()
93 | if model_answer.lower() not in ('a', 'b', 'c', 'd', 'e'):
94 | model_answer = ''.join(decode_res)
95 | if m_method is None:
96 | model_answer = unify_ans(model_answer)
97 | logits = generation_output.scores[0][0] # shape: |V|
98 | return model_answer, logits
99 |
100 |
101 | if __name__ == '__main__':
102 | parser = ArgumentParser()
103 | parser.add_argument("--root_dir", type=str, default="ROOT_DIR")
104 | parser.add_argument("--save_dir", type=str, default="SAVE_DIR")
105 | parser.add_argument("--modeltype", type=str, default="llava_ov_1b")
106 | parser.add_argument("--datatype", type=str, default="dataset")
107 | parser.add_argument("--load_quantized", type=bool, default=True)
108 | args = parser.parse_args()
109 |
110 | root_dir = args.root_dir
111 | save_dir = args.save_dir
112 | load_quantized = args.load_quantized
113 | modeltype = args.modeltype
114 | datatype = args.datatype
115 | m_method = None # None for fast-thinking, 'CoT' for slow-thinking
116 |
117 | model_path = "lmms-lab/llava-onevision-qwen2-7b-ov"
118 | if modeltype == 'llava_ov_7b':
119 | model_path = "lmms-lab/llava-onevision-qwen2-7b-ov"
120 | elif modeltype == 'llava_ov_1b':
121 | model_path = "lmms-lab/llava-onevision-qwen2-0.5b-ov"
122 | elif modeltype == 'llava_ov_72b':
123 | model_path = "lmms-lab/llava-onevision-qwen2-72b-ov-sft"
124 | print(f"Evaluating model: {model_path}")
125 | os.makedirs(save_dir, exist_ok=True)
126 |
127 | #############################
128 | # load model & tokenizer
129 | model, tokenizer, image_processor, tokens_with_ids, tokens_cluster = load_models(model_path=model_path, device=device, load_quantized=load_quantized)
130 |
131 | #############################
132 | # load data
133 | eval_dataset = load_data(data_type=datatype, root_dir=root_dir, )
134 |
135 | #############################
136 | # Start inference
137 | dict_result = dict()
138 | for i, data in enumerate(tqdm(eval_dataset)):
139 | try:
140 | # load image
141 | image, image_sizes = load_image(datatype=datatype, data=data, device=device, model=model, image_processor=image_processor)
142 |
143 | # prepare prompt
144 | prompt = prepare_prompt(d_prompt=data["prompt"], m_method=m_method)
145 |
146 | # tokenize input
147 | input_ids = prepare_model_input(prompt=prompt, tokenizer=tokenizer, device=device)
148 |
149 | with torch.no_grad():
150 | generation_output = model.generate(input_ids, images=image, image_sizes=image_sizes, do_sample=False, temperature=0, max_new_tokens=1024, return_dict_in_generate=True, output_scores=True, )
151 |
152 | ####################
153 | # Process answer
154 | model_answer, logits = process_output(generation_output, tokenizer, model_path=model_path, m_method=m_method)
155 |
156 | # calculate probs within options
157 | probs, logits_options, dict_option_prob = calculate_probs(logits=logits, list_options=data['choices'], tokens_with_ids=tokens_with_ids, tokens_cluster=tokens_cluster)
158 |
159 | dict_result[i] = copy.deepcopy(data)
160 | if 'image' in dict_result[i]:
161 | del dict_result[i]['image']
162 |
163 | dict_result[i]["model_ans"] = model_answer
164 |
165 | except Exception as e:
166 | print(e)
167 | print("skipping", i)
168 | torch.cuda.empty_cache()
169 | traceback.print_exc()
170 | sys.exit(-1)
171 |
172 | # save results to json
173 | write_file = os.path.join(save_dir, f"{modeltype}.json")
174 | print(f"write to file {write_file}")
175 | with open(write_file, "w") as f:
176 | json.dump(dict_result, f, indent=4)
177 |
--------------------------------------------------------------------------------
/model/qwen.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import os
3 |
4 | sys.path.insert(0, os.getcwd())
5 | from utils.path_utils import *
6 |
7 | CACHE_DIR = set_root_folder()
8 |
9 | os.environ["HF_HOME"] = CACHE_DIR
10 | os.environ["HF_DATASETS_CACHE"] = CACHE_DIR
11 | os.environ["HF_MODULES_CACHE"] = CACHE_DIR
12 | os.environ["TRANSFORMERS_CACHE"] = CACHE_DIR
13 |
14 | from utils.infer_utils import *
15 | from transformers import Qwen2_5_VLForConditionalGeneration, AutoProcessor
16 | from qwen_vl_utils import process_vision_info
17 | import traceback
18 | from PIL import Image
19 | import torch
20 | import warnings
21 | import copy
22 | import json
23 | from tqdm import tqdm
24 | from argparse import ArgumentParser
25 | from io import BytesIO
26 | from typing import Optional
27 | import base64
28 | warnings.filterwarnings("ignore")
29 | device = "cuda" if torch.cuda.is_available() else "cpu"
30 | torch.manual_seed(53)
31 |
32 | base_prompt = "You'll be given an image, an instruction and some options. You have to select the correct one. Do not explain your reasoning. Answer with only the letter that corresponds to the correct option. Do not repeat the entire answer. \n"
33 | cot_prompt = "USER: You'll be given an image, an instruction and some options. You have to select the correct one. \nThink step by step before answering. Then conclude with the letter that corresponds to the correct option. Make sure the option letter is in the parentheses like (X). Do not include ( or ) in the response except for the answer.\n"
34 |
35 |
36 | def load_models(model_path: str, ):
37 | model = Qwen2_5_VLForConditionalGeneration.from_pretrained(model_path, torch_dtype=torch.bfloat16, attn_implementation="flash_attention_2", device_map="auto", )
38 | processor = AutoProcessor.from_pretrained(model_path)
39 |
40 | # find option-related tokens:
41 | tokenizer = processor.tokenizer
42 | vocab = tokenizer.get_vocab()
43 | tokens_with_ids, tokens_cluster = find_token_mappings(vocab)
44 |
45 | return model, processor, tokens_with_ids, tokens_cluster
46 |
47 |
48 | def load_image(datatype: str, data: Dict, ):
49 | if datatype != 'json':
50 | image = data[f"image"].convert("RGB")
51 | else:
52 | image = Image.open(data[f"img_path"]).convert("RGB")
53 |
54 | buffered = BytesIO()
55 | image.save(buffered, format="JPEG")
56 | img_str = base64.b64encode(buffered.getvalue()).decode()
57 | return img_str
58 |
59 |
60 | def prepare_prompt(d_prompt: str, image, processor, m_method: Optional[str]=None):
61 | if m_method is None:
62 | prompt = base_prompt + d_prompt
63 | else:
64 | # chain of thoughts
65 | prompt = cot_prompt + d_prompt
66 | conversation = [{"role": "user", "content": [{"type": "image", "image": "data:image;base64,"+image, }, {"type": "text", "text": prompt}, ],} ]
67 |
68 | # Preparation for inference
69 | prompt = processor.apply_chat_template(conversation, tokenize=False, add_generation_prompt=True)
70 | return prompt, conversation
71 |
72 |
73 | def prepare_model_input(prompt: str, messages, processor, ):
74 | image_inputs, _ = process_vision_info(messages)
75 | inputs = processor(text=[prompt], images=image_inputs, padding=True, return_tensors="pt", )
76 | inputs = inputs.to("cuda")
77 | return inputs
78 |
79 |
80 | def process_output(generation_output, processor, input_ids, update_ans_ids: bool = False):
81 | # replied answer
82 | outputs = generation_output.sequences.detach().cpu()
83 |     outputs_trimmed = [out_ids[len(in_ids):] for in_ids, out_ids in zip(input_ids, outputs)]
84 |     model_answer = processor.tokenizer.batch_decode(outputs_trimmed, skip_special_tokens=True)[0]
85 | logits = generation_output.scores[0][0] # shape: |V|
86 | return model_answer, logits
87 |
88 |
89 | if __name__ == '__main__':
90 | parser = ArgumentParser()
91 | parser.add_argument("--root_dir", type=str, default="ROOT_DIR")
92 | parser.add_argument("--save_dir", type=str, default="SAVE_DIR")
93 | parser.add_argument("--modeltype", type=str, default="qwen25_3b")
94 | parser.add_argument("--datatype", type=str, default="dataset")
95 | parser.add_argument("--load_quantized", type=bool, default=True)
96 | args = parser.parse_args()
97 |
98 | root_dir = args.root_dir
99 | save_dir = args.save_dir
100 | load_quantized = args.load_quantized
101 | modeltype = args.modeltype
102 | datatype = args.datatype
103 | m_method = None # None for fast-thinking
104 |
105 | model_path = "Qwen/Qwen2.5-VL-7B-Instruct"
106 | if modeltype == 'qwen25_7b':
107 | model_path = "Qwen/Qwen2.5-VL-7B-Instruct"
108 | elif modeltype == 'qwen25_3b':
109 | model_path = "Qwen/Qwen2.5-VL-3B-Instruct"
110 | elif modeltype == 'qwen25_72b':
111 | model_path = "Qwen/Qwen2.5-VL-72B-Instruct"
112 | # model_path = "Qwen/Qwen2.5-VL-72B-Instruct-AWQ"
113 | print(f"Evaluating model: {model_path}")
114 | os.makedirs(save_dir, exist_ok=True)
115 |
116 | #############################
117 | # load model & tokenizer
118 | model, processor, tokens_with_ids, tokens_cluster = load_models(model_path=model_path, )
119 |
120 | #############################
121 | # load data
122 | eval_dataset = load_data(data_type=datatype, root_dir=root_dir, )
123 |
124 | #############################
125 | # Start inference
126 | dict_result = dict()
127 | for i, data in enumerate(tqdm(eval_dataset)):
128 | try:
129 | # load image
130 | image = load_image(datatype=datatype, data=data, )
131 |
132 | # prepare prompt
133 | prompt, conversation = prepare_prompt(d_prompt=data["prompt"], image=image, processor=processor, m_method=m_method)
134 | # tokenize input
135 | inputs = prepare_model_input(prompt=prompt, messages=conversation, processor=processor, )
136 |
137 | # inference
138 | with torch.no_grad():
139 | generation_output = model.generate(**inputs, min_length=1, do_sample=False, temperature=0, max_new_tokens=2048, return_dict_in_generate=True, output_scores=True, )
140 |
141 | ####################
142 | # Process answer
143 | model_answer, logits = process_output(generation_output, processor, input_ids=inputs.input_ids, update_ans_ids=True)
144 |
145 | # calculate probs within options
146 | probs, logits_options, dict_option_prob = calculate_probs(logits=logits, list_options=data['choices'], tokens_with_ids=tokens_with_ids, tokens_cluster=tokens_cluster)
147 |
148 | dict_result[i] = copy.deepcopy(data)
149 | if 'image' in dict_result[i]:
150 | del dict_result[i]['image']
151 |
152 | dict_result[i]["model_ans"] = model_answer
153 |
154 | except Exception as e:
155 | print(e)
156 | print("skipping", i)
157 | torch.cuda.empty_cache()
158 | traceback.print_exc()
159 | sys.exit(-1)
160 |
161 | # save results to json
162 | write_file = os.path.join(save_dir, f"{modeltype}.json")
163 | print(f"write to file {write_file}")
164 | with open(write_file, "w") as f:
165 | json.dump(dict_result, f, indent=4)
166 |
--------------------------------------------------------------------------------
/model_inference.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | ROOT_DIR="PATH/TO/ROOT_DIR"
4 | RESULT_DIR="PATH/TO/RESULT_DIR"
5 | GEMINI_API_KEY="xxxx"
6 | GPT4O_API_KEY="xxxx"
7 |
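# --datatype="dataset" loads the HuggingFace test split; use --datatype="json" with ROOT_DIR
# pointing at a local copy (expects all_data.json plus the referenced image files) instead.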
8 | python3 model/llava_ov.py --modeltype="llava_ov_1b" --root_dir="${ROOT_DIR}" --save_dir="${RESULT_DIR}" --datatype="dataset"
9 |
10 | python3 model/llava_next.py --modeltype="llava_16_v_7b" --root_dir="${ROOT_DIR}" --save_dir="${RESULT_DIR}" --datatype="dataset"
11 |
12 | python3 model/internvl.py --modeltype="internvl2_1b" --root_dir="${ROOT_DIR}" --save_dir="${RESULT_DIR}" --datatype="dataset"
13 |
14 | python3 model/eaglex5.py --modeltype="eaglex5_7b" --root_dir="${ROOT_DIR}" --save_dir="${RESULT_DIR}" --datatype="dataset"
15 |
16 | python3 model/cambrian1.py --modeltype="cambrian_3b" --root_dir="${ROOT_DIR}" --save_dir="${RESULT_DIR}" --datatype="dataset"
17 |
18 | python3 model/qwen.py --modeltype="qwen25_3b" --root_dir="${ROOT_DIR}" --save_dir="${RESULT_DIR}" --datatype="dataset"
19 |
20 | python3 model/gpt4o_infer.py --root_dir="${ROOT_DIR}" --save_dir="${RESULT_DIR}" --datatype="dataset" --api_key="${GPT4O_API_KEY}" --use_cot
21 |
22 | python3 model/gemini_infer.py --root_dir="${ROOT_DIR}" --save_dir="${RESULT_DIR}" --datatype="dataset" --api_key="${GEMINI_API_KEY}" --use_cot
23 |
24 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | pillow
2 | pandas
3 | numpy
4 | opencv-python
5 | matplotlib
6 | bert-score
7 | inflect
8 | datasets
9 | transformers
10 | seaborn
11 | pymatting
12 | scikit-image
13 | accelerate>=0.26.0
14 | torch
15 | torchvision
16 | torchaudio
17 | av
18 | open-clip-torch==2.24.0
19 | xformers==0.0.29.post2
20 | bitsandbytes==0.45.1
--------------------------------------------------------------------------------
/utils/__init__.py:
--------------------------------------------------------------------------------
1 | from .infer_utils import *
2 | from .path_utils import *
5 |
--------------------------------------------------------------------------------
/utils/infer_utils.py:
--------------------------------------------------------------------------------
1 | import json
2 | import re
3 | import torch
4 | import copy
5 | from typing import List, Dict, Tuple, Optional
6 | import torch.nn.functional as F
7 | from .path_utils import *
8 |
9 | from datasets import load_dataset
10 |
11 |
12 | # Define a function to check the token format
13 | def is_valid_token(token, ):
14 | # Check if the token matches the required format
15 | if len(token) == 1 and token.upper() in "ABCDEF": # Single letter like 'a' or 'A'
16 | return True, token.upper()
17 | if len(token) == 2 and token.startswith(" ") and token[1].upper() in "ABCDEF": # ' a' or ' A'
18 | return True, token[1].upper()
19 |     if len(token) == 2 and token.startswith("_") and token[1].upper() in "ABCDEF": # '_a' or '_A'
20 |         return True, token[1].upper()
21 |     if len(token) == 2 and token.startswith("▁") and token[1].upper() in "ABCDEF": # '▁a' or '▁A'
22 |         return True, token[1].upper()
23 |     if len(token) == 2 and token.startswith("(") and token[1].upper() in "ABCDEF": # '(a' or '(A'
24 |         return True, token[1].upper()
25 | if len(token) == 3 and token.startswith("(") and token.endswith(")") and token[1].upper() in "ABCDEF": # '(A)'
26 | return True, token[1].upper()
27 | return False, None
28 |
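# Accepted surface forms and their normalization, per the checks above:
# 'a' -> 'A', ' b' -> 'B', '_c' -> 'C', '▁d' -> 'D', '(e' -> 'E', '(F)' -> 'F';
# anything else yields (False, None).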
29 |
30 | def find_token_mappings(vocab_dict):
31 | # find option-related tokens:
32 | tokens_with_ids = dict()
33 | tokens_cluster = dict()
34 | for token, token_id in vocab_dict.items():
35 | res, token_format = is_valid_token(token)
36 | if res:
37 | tokens_with_ids[token] = token_id
38 | if token_format not in tokens_cluster:
39 | tokens_cluster[token_format] = []
40 | tokens_cluster[token_format].append(token)
41 |
42 | return tokens_with_ids, tokens_cluster
43 |
44 |
45 | def check_ans(pr_ans: str, gt_ans: str):
46 | if pr_ans.strip().lower() == gt_ans.strip().lower():
47 | return True
48 | else:
49 | return False
50 |
51 |
52 | def load_data(data_type: str, root_dir: str='', ):
53 | # load data
54 | if data_type not in ('json',):
55 | # hf dataset
56 | eval_dataset = load_dataset("umd-zhou-lab/ColorBench", split='test')
57 | else:
58 | # json
59 | with open(f"{root_dir}/all_data.json", 'r') as f:
60 | eval_dataset = json.load(f)
61 |
62 | # change image file path
63 | for i, data in enumerate(eval_dataset):
64 | img_path = os.path.join(root_dir, data['filename'])
65 | data['img_path'] = img_path
66 | return eval_dataset
67 |
68 |
69 | def calculate_probs(logits, list_options: List, tokens_with_ids: Dict, tokens_cluster: Dict):
70 | # calculate probs within options
71 | logits = logits.detach().cpu()
72 | options = [f"{chr(65 + opt_i)}" for opt_i, item in enumerate(list_options)]
73 |
74 | # Initialize a dictionary to store aggregated logits for each option
75 | aggregated_logits = {}
76 | dict_option_prob = {}
77 |
78 | for option in options:
79 | # Get all related formats of the option from tokens_cluster
80 | related_tokens = tokens_cluster.get(option, [])
81 |
82 | # Sum the logits of all formats of the option
83 | aggregated_logit = sum(logits[tokens_with_ids[token]] for token in related_tokens if token in tokens_with_ids)
84 | aggregated_logits[option] = aggregated_logit
85 | for token in related_tokens:
86 | dict_option_prob[token] = logits[tokens_with_ids[token]].detach().cpu().numpy().item()
87 |
88 | # Convert aggregated logits to a tensor
89 | logits_options = torch.tensor([aggregated_logits[option] for option in options])
90 | probs = F.softmax(logits_options, dim=0, ).detach().cpu().numpy()
91 |
92 | return probs.tolist(), logits_options.tolist(), dict_option_prob
93 |
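# Usage sketch (hypothetical choices; generation_output must come from generate() with output_scores=True):
#   tokens_with_ids, tokens_cluster = find_token_mappings(tokenizer.get_vocab())
#   probs, option_logits, per_token = calculate_probs(
#       logits=generation_output.scores[0][0],   # first-step logits over the vocabulary
#       list_options=["red", "green", "blue"],   # mapped to options A, B, C
#       tokens_with_ids=tokens_with_ids,
#       tokens_cluster=tokens_cluster,
#   )
# probs[0] is then the softmax mass of option A, summed over all of its surface forms
# ('A', ' A', '(A', '(A)', ...).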
94 |
95 | def unify_ans(answer: str, ):
96 |
97 | formated_answer = answer.replace('(', '').replace(')', '').lower()
98 | if formated_answer not in ('a', 'b', 'c', 'd', 'e', 'f'):
99 | # find the option letter
100 | match = re.search(r"\((a|b|c|d|e|f)\)", answer.lower())
101 | if match:
102 | formated_answer = match.group(0).replace('(', '').replace(')', '').lower()
103 |
104 | if formated_answer not in ('a', 'b', 'c', 'd', 'e', 'f'):
105 | # find the option letter
106 | match = re.search(r"([a-z])\) \d+", answer.lower())
107 | if match:
108 | formated_answer = match.group(1)
109 |
110 | return formated_answer
111 |
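# Example normalizations: '(B)' -> 'b', 'The answer is (c).' -> 'c', 'D' -> 'd';
# if no option letter is found, the lowercased input (with parentheses stripped) is returned as-is.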
112 |
113 | def check_answer(model_ans, gt_ans, ):
114 | gt_ans = gt_ans.replace('(', '').replace(')', '').lower()
115 | model_ans = unify_ans(model_ans, )
116 |
117 | if model_ans == gt_ans:
118 | return True, model_ans
119 | else:
120 | return False, model_ans
121 |
122 |
123 | def extract_letter_cot(answer):
124 | """Extracts the letter choice from an answer that's in parentheses like (X)."""
125 | # Look for the last occurrence of a pattern like (A), (B), etc.
126 | matches = re.findall(r"\(([A-Za-z])\)", answer.strip())
127 |     return matches[-1].lower() if matches else "" # Return the last letter found in parentheses, converted to lowercase
128 |
129 |
130 | def extract_letter(answer):
131 |     """Extracts the last letter choice from an answer, converted to lowercase (e.g., '(A)' → 'a', 'Selected: (C)' → 'c')."""
132 |     matches = re.findall(r"[A-Za-z]", answer.strip()) # Find all letters (uppercase & lowercase)
133 |     return matches[-1].lower() if matches else "" # Return the last letter found, converted to lowercase
134 |
135 |
136 | def parse_res(model_ans, options, gt_ans):
137 | str_opt = [str(item).lower() for item in options if item != '']
138 | check_res, model_ans_new = check_answer(model_ans.strip(), gt_ans)
139 | find_res = True
140 | if model_ans_new.lower() not in ('a', 'b', 'c', 'd', 'e'):
141 | if len(model_ans_new.lower().split(' ')) == 2 and model_ans_new.lower().split(' ')[0] in ('a', 'b', 'c', 'd', 'e'):
142 | model_ans_new = model_ans_new.lower().split(' ')[0]
143 | elif model_ans_new in options:
144 | ans_id = options.index(model_ans_new)
145 | model_ans_new = chr(65 + ans_id)
146 | elif len([item for item in str_opt if item in model_ans_new or item.replace(' ', '') in model_ans_new]) == 1:
147 | ans_id = [item for item in str_opt if item in model_ans_new or item.replace(' ', '') in model_ans_new][0]
148 | ans_id = str_opt.index(ans_id)
149 | model_ans_new = chr(65 + ans_id)
150 | else:
151 | find_res = False
152 | return model_ans_new, check_res, find_res
--------------------------------------------------------------------------------
/utils/path_utils.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 |
4 | def set_root_folder():
5 | CACHE_DIR = "YOUR_HF_CACHE_FOLDER"
6 | if not os.path.exists(CACHE_DIR):
7 | print(f"Not valid cache folder path: {CACHE_DIR}")
8 | CACHE_DIR = os.path.expanduser('~')
9 | print(f"Setting cache folder path to home directory: {CACHE_DIR}")
10 |
11 | return CACHE_DIR
12 |
--------------------------------------------------------------------------------