├── .gitignore
├── LICENSE
├── README.md
├── assets
│   ├── leaderboard-1.0.png
│   ├── leaderboard-1.5-air.png
│   ├── leaderboard-1.5-edge.png
│   └── pipeline.png
├── demo
│   ├── cat.jpg
│   ├── demo.py
│   └── einstein.jpg
├── requirements.txt
├── run.py
├── run.sh
├── setup.py
└── vlmeval
    ├── __init__.py
    ├── api
    │   ├── __init__.py
    │   ├── base.py
    │   ├── gemini.py
    │   ├── gpt.py
    │   ├── gpt_int.py
    │   ├── hf_chat_model.py
    │   └── qwen_vl_api.py
    ├── config.py
    ├── evaluate
    │   ├── __init__.py
    │   ├── coco_eval.py
    │   ├── llavabench.py
    │   ├── mathvista_eval.py
    │   ├── misc.py
    │   ├── mmmu_eval
    │   │   ├── README.md
    │   │   ├── __init__.py
    │   │   ├── answer_dict_val.json
    │   │   ├── configs
    │   │   │   └── llava1.5.yaml
    │   │   ├── example_outputs
    │   │   │   ├── llava1.5_13b
    │   │   │   │   ├── Accounting
    │   │   │   │   │   └── output.json
    │   │   │   │   ├── Agriculture
    │   │   │   │   │   └── output.json
    │   │   │   │   ├── Architecture_and_Engineering
    │   │   │   │   │   └── output.json
    │   │   │   │   ├── Art
    │   │   │   │   │   └── output.json
    │   │   │   │   ├── Art_Theory
    │   │   │   │   │   └── output.json
    │   │   │   │   ├── Basic_Medical_Science
    │   │   │   │   │   └── output.json
    │   │   │   │   ├── Biology
    │   │   │   │   │   └── output.json
    │   │   │   │   ├── Chemistry
    │   │   │   │   │   └── output.json
    │   │   │   │   ├── Clinical_Medicine
    │   │   │   │   │   └── output.json
    │   │   │   │   ├── Computer_Science
    │   │   │   │   │   └── output.json
    │   │   │   │   ├── Design
    │   │   │   │   │   └── output.json
    │   │   │   │   ├── Diagnostics_and_Laboratory_Medicine
    │   │   │   │   │   └── output.json
    │   │   │   │   ├── Economics
    │   │   │   │   │   └── output.json
    │   │   │   │   ├── Electronics
    │   │   │   │   │   └── output.json
    │   │   │   │   ├── Energy_and_Power
    │   │   │   │   │   └── output.json
    │   │   │   │   ├── Finance
    │   │   │   │   │   └── output.json
    │   │   │   │   ├── Geography
    │   │   │   │   │   └── output.json
    │   │   │   │   ├── History
    │   │   │   │   │   └── output.json
    │   │   │   │   ├── Literature
    │   │   │   │   │   └── output.json
    │   │   │   │   ├── Manage
    │   │   │   │   │   └── output.json
    │   │   │   │   ├── Marketing
    │   │   │   │   │   └── output.json
    │   │   │   │   ├── Materials
    │   │   │   │   │   └── output.json
    │   │   │   │   ├── Math
    │   │   │   │   │   └── output.json
    │   │   │   │   ├── Mechanical_Engineering
    │   │   │   │   │   └── output.json
    │   │   │   │   ├── Music
    │   │   │   │   │   └── output.json
    │   │   │   │   ├── Pharmacy
    │   │   │   │   │   └── output.json
    │   │   │   │   ├── Physics
    │   │   │   │   │   └── output.json
    │   │   │   │   ├── Psychology
    │   │   │   │   │   └── output.json
    │   │   │   │   ├── Public_Health
    │   │   │   │   │   └── output.json
    │   │   │   │   ├── Sociology
    │   │   │   │   │   └── output.json
    │   │   │   │   └── total_val_output.json
    │   │   │   ├── llava1.5_13b_val.json
    │   │   │   └── qwen_vl
    │   │   │       ├── Accounting
    │   │   │       │   └── output.json
    │   │   │       ├── Agriculture
    │   │   │       │   └── output.json
    │   │   │       ├── Architecture_and_Engineering
    │   │   │       │   └── output.json
    │   │   │       ├── Art
    │   │   │       │   └── output.json
    │   │   │       ├── Art_Theory
    │   │   │       │   └── output.json
    │   │   │       ├── Basic_Medical_Science
    │   │   │       │   └── output.json
    │   │   │       ├── Biology
    │   │   │       │   └── output.json
    │   │   │       ├── Chemistry
    │   │   │       │   └── output.json
    │   │   │       ├── Clinical_Medicine
    │   │   │       │   └── output.json
    │   │   │       ├── Computer_Science
    │   │   │       │   └── output.json
    │   │   │       ├── Design
    │   │   │       │   └── output.json
    │   │   │       ├── Diagnostics_and_Laboratory_Medicine
    │   │   │       │   └── output.json
    │   │   │       ├── Economics
    │   │   │       │   └── output.json
    │   │   │       ├── Electronics
    │   │   │       │   └── output.json
    │   │   │       ├── Energy_and_Power
    │   │   │       │   └── output.json
    │   │   │       ├── Finance
    │   │   │       │   └── output.json
    │   │   │       ├── Geography
    │   │   │       │   └── output.json
    │   │   │       ├── History
    │   │   │       │   └── output.json
    │   │   │       ├── Literature
    │   │   │       │   └── output.json
    │   │   │       ├── Manage
    │   │   │       │   └── output.json
    │   │   │       ├── Marketing
    │   │   │       │   └── output.json
    │   │   │       ├── Materials
    │   │   │       │   └── output.json
    │   │   │       ├── Math
    │   │   │       │   └── output.json
    │   │   │       ├── Mechanical_Engineering
    │   │   │       │   └── output.json
    │   │   │       ├── Music
    │   │   │       │   └── output.json
    │   │   │       ├── Pharmacy
    │   │   │       │   └── output.json
    │   │   │       ├── Physics
    │   │   │       │   └── output.json
    │   │   │       ├── Psychology
    │   │   │       │   └── output.json
    │   │   │       ├── Public_Health
    │   │   │       │   └── output.json
    │   │   │       ├── Sociology
    │   │   │       │   └── output.json
    │   │   │       └── total_val_output.json
    │   │   ├── main_eval_only.py
    │   │   ├── main_parse_and_eval.py
    │   │   ├── mmmu_eval_script.py
    │   │   ├── print_results.py
    │   │   ├── run_llava.py
    │   │   └── utils
    │   │       ├── data_utils.py
    │   │       ├── eval_utils.py
    │   │       └── model_utils.py
    │   ├── mmvet_eval.py
    │   ├── multiple_choice.py
    │   ├── vqa_eval.py
    │   └── yes_or_no.py
    ├── inference.py
    ├── smp
    │   ├── __init__.py
    │   ├── file.py
    │   ├── lb.py
    │   ├── log.py
    │   ├── misc.py
    │   └── vlm.py
    ├── utils
    │   ├── __init__.py
    │   ├── custom_prompt.py
    │   ├── dataset.py
    │   ├── dataset_config.py
    │   ├── matching_util.py
    │   ├── mp_util.py
    │   └── xtuner_util.py
    └── vlm
        ├── __init__.py
        ├── hpt.py
        ├── hpt1_5.py
        └── modeling_siglip.py
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | share/python-wheels/
24 | *.egg-info/
25 | .installed.cfg
26 | *.egg
27 | MANIFEST
28 |
29 | # PyInstaller
30 | # Usually these files are written by a python script from a template
31 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
32 | *.manifest
33 | *.spec
34 |
35 | # Installer logs
36 | pip-log.txt
37 | pip-delete-this-directory.txt
38 |
39 | # Unit test / coverage reports
40 | htmlcov/
41 | .tox/
42 | .nox/
43 | .coverage
44 | .coverage.*
45 | .cache
46 | nosetests.xml
47 | coverage.xml
48 | *.cover
49 | *.py,cover
50 | .hypothesis/
51 | .pytest_cache/
52 | cover/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | .pybuilder/
76 | target/
77 |
78 | # Jupyter Notebook
79 | .ipynb_checkpoints
80 |
81 | # IPython
82 | profile_default/
83 | ipython_config.py
84 |
85 | # pyenv
86 | # For a library or package, you might want to ignore these files since the code is
87 | # intended to run in multiple environments; otherwise, check them in:
88 | # .python-version
89 |
90 | # pipenv
91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
94 | # install all needed dependencies.
95 | #Pipfile.lock
96 |
97 | # poetry
98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
99 | # This is especially recommended for binary packages to ensure reproducibility, and is more
100 | # commonly ignored for libraries.
101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
102 | #poetry.lock
103 |
104 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
105 | __pypackages__/
106 |
107 | # Celery stuff
108 | celerybeat-schedule
109 | celerybeat.pid
110 |
111 | # SageMath parsed files
112 | *.sage.py
113 |
114 | # Environments
115 | .env
116 | .venv
117 | env/
118 | venv/
119 | ENV/
120 | env.bak/
121 | venv.bak/
122 |
123 | # Spyder project settings
124 | .spyderproject
125 | .spyproject
126 |
127 | # Rope project settings
128 | .ropeproject
129 |
130 | # mkdocs documentation
131 | /site
132 |
133 | # mypy
134 | .mypy_cache/
135 | .dmypy.json
136 | dmypy.json
137 |
138 | # Pyre type checker
139 | .pyre/
140 |
141 | # pytype static type analyzer
142 | .pytype/
143 |
144 | # Cython debug symbols
145 | cython_debug/
146 |
147 | # project-specific
148 | output/
149 | debug*/
150 | *.bak
151 | *.dir
152 | *.dat
153 | *.tsv
154 | *.gz
155 |
156 | cache/
157 |
158 | # models
159 | *.onnx
160 | *.pth
161 | *.zip
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 |
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner. For the purposes of this definition, "submitted"
54 | means any form of electronic, verbal, or written communication sent
55 | to the Licensor or its representatives, including but not limited to
56 | communication on electronic mailing lists, source code control systems,
57 | and issue tracking systems that are managed by, or on behalf of, the
58 | Licensor for the purpose of discussing and improving the Work, but
59 | excluding communication that is conspicuously marked or otherwise
60 | designated in writing by the copyright owner as "Not a Contribution."
61 |
62 | "Contributor" shall mean Licensor and any individual or Legal Entity
63 | on behalf of whom a Contribution has been received by Licensor and
64 | subsequently incorporated within the Work.
65 |
66 | 2. Grant of Copyright License. Subject to the terms and conditions of
67 | this License, each Contributor hereby grants to You a perpetual,
68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 | copyright license to reproduce, prepare Derivative Works of,
70 | publicly display, publicly perform, sublicense, and distribute the
71 | Work and such Derivative Works in Source or Object form.
72 |
73 | 3. Grant of Patent License. Subject to the terms and conditions of
74 | this License, each Contributor hereby grants to You a perpetual,
75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 | (except as stated in this section) patent license to make, have made,
77 | use, offer to sell, sell, import, and otherwise transfer the Work,
78 | where such license applies only to those patent claims licensable
79 | by such Contributor that are necessarily infringed by their
80 | Contribution(s) alone or by combination of their Contribution(s)
81 | with the Work to which such Contribution(s) was submitted. If You
82 | institute patent litigation against any entity (including a
83 | cross-claim or counterclaim in a lawsuit) alleging that the Work
84 | or a Contribution incorporated within the Work constitutes direct
85 | or contributory patent infringement, then any patent licenses
86 | granted to You under this License for that Work shall terminate
87 | as of the date such litigation is filed.
88 |
89 | 4. Redistribution. You may reproduce and distribute copies of the
90 | Work or Derivative Works thereof in any medium, with or without
91 | modifications, and in Source or Object form, provided that You
92 | meet the following conditions:
93 |
94 | (a) You must give any other recipients of the Work or
95 | Derivative Works a copy of this License; and
96 |
97 | (b) You must cause any modified files to carry prominent notices
98 | stating that You changed the files; and
99 |
100 | (c) You must retain, in the Source form of any Derivative Works
101 | that You distribute, all copyright, patent, trademark, and
102 | attribution notices from the Source form of the Work,
103 | excluding those notices that do not pertain to any part of
104 | the Derivative Works; and
105 |
106 | (d) If the Work includes a "NOTICE" text file as part of its
107 | distribution, then any Derivative Works that You distribute must
108 | include a readable copy of the attribution notices contained
109 | within such NOTICE file, excluding those notices that do not
110 | pertain to any part of the Derivative Works, in at least one
111 | of the following places: within a NOTICE text file distributed
112 | as part of the Derivative Works; within the Source form or
113 | documentation, if provided along with the Derivative Works; or,
114 | within a display generated by the Derivative Works, if and
115 | wherever such third-party notices normally appear. The contents
116 | of the NOTICE file are for informational purposes only and
117 | do not modify the License. You may add Your own attribution
118 | notices within Derivative Works that You distribute, alongside
119 | or as an addendum to the NOTICE text from the Work, provided
120 | that such additional attribution notices cannot be construed
121 | as modifying the License.
122 |
123 | You may add Your own copyright statement to Your modifications and
124 | may provide additional or different license terms and conditions
125 | for use, reproduction, or distribution of Your modifications, or
126 | for any such Derivative Works as a whole, provided Your use,
127 | reproduction, and distribution of the Work otherwise complies with
128 | the conditions stated in this License.
129 |
130 | 5. Submission of Contributions. Unless You explicitly state otherwise,
131 | any Contribution intentionally submitted for inclusion in the Work
132 | by You to the Licensor shall be under the terms and conditions of
133 | this License, without any additional terms or conditions.
134 | Notwithstanding the above, nothing herein shall supersede or modify
135 | the terms of any separate license agreement you may have executed
136 | with Licensor regarding such Contributions.
137 |
138 | 6. Trademarks. This License does not grant permission to use the trade
139 | names, trademarks, service marks, or product names of the Licensor,
140 | except as required for reasonable and customary use in describing the
141 | origin of the Work and reproducing the content of the NOTICE file.
142 |
143 | 7. Disclaimer of Warranty. Unless required by applicable law or
144 | agreed to in writing, Licensor provides the Work (and each
145 | Contributor provides its Contributions) on an "AS IS" BASIS,
146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 | implied, including, without limitation, any warranties or conditions
148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 | PARTICULAR PURPOSE. You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
178 | APPENDIX: How to apply the Apache License to your work.
179 |
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "[]"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 |
189 | Copyright [yyyy] [name of copyright owner]
190 |
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 |
195 | http://www.apache.org/licenses/LICENSE-2.0
196 |
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
202 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # HPT - Open Multimodal Large Language Models
2 | Hyper-Pretrained Transformers (HPT) is a novel multimodal LLM framework from HyperGAI for training vision-language models that understand both textual and visual inputs. HPT achieves highly competitive results against state-of-the-art models on a variety of multimodal LLM benchmarks. This repository contains the open-source inference code to reproduce the evaluation results of HPT on different benchmarks.
3 |
4 | ## Release
5 | - [6/06] :fire: Releasing **HPT 1.5 Edge**, our latest open-source model tailored to edge devices. Despite its size (<5B), Edge demonstrates impressive capabilities while being extremely efficient. HPT 1.5 Edge is publicly available on [[HuggingFace Repository](https://huggingface.co/HyperGAI/HPT1_5-Edge)]. Please read our [[technical blog post](https://hypergai.com/blog/hpt-1-5-edge-towards-multimodal-llms-for-edge-devices)] for more details.
6 | - [5/03] **HPT 1.5 Air**, our best open-source 8B multimodal LLM built with [Llama 3](https://huggingface.co/blog/llama3). Built with Meta Llama 3, our highly capable HPT 1.5 Air packs a punch on real-world understanding and complex reasoning. HPT 1.5 Air achieves the best results among <10B models across a wide range of challenging benchmarks (MMMU, POPE, SEED-I, and more). HPT 1.5 Air is publicly available on [[HuggingFace Repository](https://huggingface.co/HyperGAI/HPT1_5-Air-Llama-3-8B-Instruct-multimodal)]. Please read our [[technical blog post](https://hypergai.com/blog/hpt-1-5-air-best-open-sourced-8b-multimodal-llm-with-llama-3)] for more details.
7 | - [3/16] **HPT 1.0 Air** is out, our most efficient model and a cost-effective solution capable of solving a wide range of vision-and-language tasks. HPT 1.0 Air is publicly available and achieves state-of-the-art results among all open-source multimodal LLMs of similar or smaller sizes on the challenging MMMU benchmark. Please read our [[technical blog post](https://www.hypergai.com/blog/introducing-hpt-a-family-of-leading-multimodal-llms)] and visit the [[HuggingFace Repository](https://huggingface.co/HyperGAI/HPT)] for more details.
8 |
9 |
10 | We release HPT 1.5 Edge as our latest open-source model tailored to edge devices. Despite its size (<5B), Edge demonstrates impressive capabilities while being extremely efficient. We release HPT 1.5 Edge publicly on Hugging Face and GitHub under the Apache 2.0 license.
11 |
12 |
13 |
14 | ## Table of Contents
15 | - [Overview of Model Architecture](#overview-of-model-architecture)
16 | - [Quick Start](#quick-start)
17 | - [Installation](#installation)
18 | - [Prepare the Model](#prepare-the-model)
19 | - [Demo](#demo)
20 | - [Evaluations](#evaluations)
21 | - [Benchmarks](#benchmarks)
22 | - [Pretrained Models Used](#pretrained-models-used)
23 | - [Disclaimer and Responsible Use](#disclaimer-and-responsible-use)
24 | - [Contact Us](#contact-us)
25 | - [License](#license)
26 | - [Acknowledgements](#acknowledgements)
27 |
28 |
29 |
30 | ## Overview of Model Architecture
31 |
32 |
33 | ![Overview of the HPT model architecture](assets/pipeline.png)
34 |
35 |
36 |
37 | ## Quick Start
38 |
39 | ### Installation
40 |
41 | ```
42 | pip install -r requirements.txt
43 | pip install -e .
44 | ```
45 |
46 | ### Prepare the Model
47 |
48 | You can download the model weights from Hugging Face into your [Local Path] and set `global_model_path` to your [Local Path] in the model [config file](./vlmeval/config.py#L24):
49 | ```
50 | git lfs install
51 | git clone https://huggingface.co/HyperGAI/HPT1_5-Edge [Local Path]
52 | ```
53 |
54 | You can also set other strategies in the [config file](./vlmeval/config.py#L24) that are different from our default settings.
55 |
56 | ### Demo
57 |
58 | After setting up the config file, launch the model demo for a quick trial:
59 |
60 | ```
61 | python demo/demo.py --image_path [Image] --text [Text] --model [Config]
62 | ```
63 |
64 | Example:
65 |
66 | ```
67 | python demo/demo.py --image_path demo/einstein.jpg --text 'What is unusual about this image?' --model hpt-edge-1-5
68 | ```
69 |
70 | ## Evaluations
71 |
72 | Launch the model for evaluation:
73 |
74 | ```
75 | torchrun --nproc-per-node=8 run.py --data [Dataset] --model [Config]
76 | ```
77 |
78 | Example for HPT 1.5 Edge:
79 |
80 | ```
81 | torchrun --nproc-per-node=8 run.py --data MMMU_DEV_VAL --model hpt-edge-1-5
82 | ```
83 |
84 | ## Benchmarks
85 |
86 | **For HPT 1.5 Edge**
87 |
88 |
89 | ![HPT 1.5 Edge benchmark results](assets/leaderboard-1.5-edge.png)
90 |
91 |
92 |
93 | - The majority of the results presented are taken from the models' original reports, while the others are from Phi-3-vision evaluations, which we mark with an asterisk (*).
94 | - The benchmark results for HPT 1.5 Air and HPT 1.0 Air are in the assets directory.
95 |
96 |
97 | ## Pretrained Models Used
98 |
99 | **HPT 1.5 Edge**
100 |
101 | - Pretrained LLM: [Phi-3-mini-4k-instruct](https://huggingface.co/microsoft/Phi-3-mini-4k-instruct)
102 |
103 | - Pretrained Visual Encoder: [siglip-so400m-patch14-384](https://huggingface.co/google/siglip-so400m-patch14-384)
104 |
105 | **HPT 1.5 Air**
106 |
107 | - Pretrained LLM: [Llama3-8B-Instruct](https://huggingface.co/meta-llama/Meta-Llama-3-8B-Instruct)
108 |
109 | - Pretrained Visual Encoder: [siglip-so400m-patch14-384](https://huggingface.co/google/siglip-so400m-patch14-384)
110 |
111 | **HPT 1.0 Air**
112 |
113 | - Pretrained LLM: [Yi-6B-Chat](https://huggingface.co/01-ai/Yi-6B)
114 |
115 | - Pretrained Visual Encoder: [clip-vit-large-patch14-336 ](https://huggingface.co/openai/clip-vit-large-patch14-336)
116 |
117 | ## Disclaimer and Responsible Use
118 |
119 | Note that HPT Air is a quick open release of our models to facilitate open, responsible AI research and community development. It does not have any moderation mechanism and provides no guarantees on its results. We hope to engage with the community to make the models respect guardrails, allowing practical adoption in real-world applications that require moderated outputs.
120 |
121 | ## Contact Us
122 |
123 | - Contact: HPT@hypergai.com
124 | - Follow us on [Twitter](https://twitter.com/hypergai).
125 | - Follow us on [Linkedin](https://www.linkedin.com/company/hypergai/).
126 | - Visit our [website](https://www.hypergai.com/) to learn more about us.
127 |
128 |
129 | ## License
130 |
131 | This project is released under the [Apache 2.0 license](LICENSE). Parts of this project contain code and models from other sources, which are subject to their respective licenses; you must comply with those licenses if you want to use them for commercial purposes.
132 |
133 | ## Acknowledgements
134 |
135 | The evaluation code for running this demo is extended from the [VLMEvalKit project](https://github.com/open-compass/VLMEvalKit).
136 | We also thank [OpenAI](https://openai.com) for open-sourcing their visual encoder models, and [01.AI](https://www.01.ai), [Meta](https://www.meta.com/) and [Microsoft](https://www.microsoft.com/) for open-sourcing their large language models.
137 |
--------------------------------------------------------------------------------
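A note on the "Prepare the Model" step above: the edit it asks for amounts to pointing `global_model_path` at the cloned weights. Since `vlmeval/config.py` is not included in this dump, the snippet below is only a sketch under that assumption; the variable name comes from the README, the surrounding structure does not.

```python
# vlmeval/config.py (hypothetical excerpt) -- only the variable name is taken
# from the README; the path shown is a placeholder for your [Local Path].
global_model_path = '/data/models/HPT1_5-Edge'  # produced by `git clone ... [Local Path]`
```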
/assets/leaderboard-1.0.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HyperGAI/HPT/b65b9a07660d163a3f426b1d79f289f89ab22a91/assets/leaderboard-1.0.png
--------------------------------------------------------------------------------
/assets/leaderboard-1.5-air.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HyperGAI/HPT/b65b9a07660d163a3f426b1d79f289f89ab22a91/assets/leaderboard-1.5-air.png
--------------------------------------------------------------------------------
/assets/leaderboard-1.5-edge.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HyperGAI/HPT/b65b9a07660d163a3f426b1d79f289f89ab22a91/assets/leaderboard-1.5-edge.png
--------------------------------------------------------------------------------
/assets/pipeline.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HyperGAI/HPT/b65b9a07660d163a3f426b1d79f289f89ab22a91/assets/pipeline.png
--------------------------------------------------------------------------------
/demo/cat.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HyperGAI/HPT/b65b9a07660d163a3f426b1d79f289f89ab22a91/demo/cat.jpg
--------------------------------------------------------------------------------
/demo/demo.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | from vlmeval.config import supported_VLM
3 |
4 | def parse_args():
5 | parser = argparse.ArgumentParser()
6 | parser.add_argument('--image_path', type=str, nargs='+', required=True)
7 | parser.add_argument('--text', type=str, required=True)
8 | parser.add_argument("--model", type=str, required=True)
9 | parser.add_argument("--nproc", type=int, default=4, required=False)
10 | parser.add_argument("--verbose", action='store_true')
11 | args = parser.parse_args()
12 | return args
13 |
14 | def main():
15 | args = parse_args()
16 | model_name = args.model
17 | model = supported_VLM[model_name]()
18 | text = args.text
19 | image_path = args.image_path
20 | response = model.generate(prompt=text, image_path=image_path, dataset='demo')
21 |
22 | print(response)
23 |
24 |
25 | if __name__ == '__main__':
26 | main()
27 |
--------------------------------------------------------------------------------
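For reference, the same flow demo.py implements can be driven programmatically. The sketch below assumes a config named `hpt-edge-1-5` is registered in `supported_VLM` and that the weights have been prepared as described in the README.

```python
from vlmeval.config import supported_VLM

# Instantiate the wrapped VLM by its config name (assumed to be 'hpt-edge-1-5').
model = supported_VLM['hpt-edge-1-5']()

# demo.py passes --image_path with nargs='+', so image_path is a list of paths.
response = model.generate(
    prompt='What is unusual about this image?',
    image_path=['demo/einstein.jpg'],
    dataset='demo',
)
print(response)
```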
/demo/einstein.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HyperGAI/HPT/b65b9a07660d163a3f426b1d79f289f89ab22a91/demo/einstein.jpg
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | numpy>=1.23.4
2 | openai==1.3.5
3 | requests
4 | tqdm
5 | pandas>=1.5.3
6 | gradio==4.15.0
7 | tiktoken
8 | rich
9 | portalocker
10 | timeout-decorator
11 | opencv-python==4.8.0.74
12 | pillow==10.2.0
13 | omegaconf
14 | matplotlib
15 | einops
16 | sentencepiece
17 | sty
18 | huggingface_hub
19 | visual_genome
20 | pycocoevalcap
21 | openpyxl
22 | seaborn
23 | tabulate
24 | xlsxwriter
25 | torch>=2.0.1
26 | typing_extensions==4.10.0
27 | transformers==4.37.0
28 | accelerate
29 | mmengine
--------------------------------------------------------------------------------
/run.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.distributed as dist
3 | from vlmeval.smp import *
4 | from vlmeval.evaluate import COCO_eval, YOrN_eval, MMVet_eval, multiple_choice_eval, VQAEval, MathVista_eval, LLaVABench_eval
5 | from vlmeval.inference import infer_data_job, prefetch_acc
6 | from vlmeval.config import supported_VLM
7 | from vlmeval.utils import dataset_URLs, abbr2full
8 | from vlmeval.evaluate import mmmu_eval_func
9 |
10 | def parse_args():
11 | parser = argparse.ArgumentParser()
12 | parser.add_argument('--data', type=str, nargs='+', required=True)
13 | parser.add_argument("--model", type=str, nargs='+', required=True)
14 | parser.add_argument("--work-dir", type=str, default='.', help="select the output directory")
15 | parser.add_argument("--mode", type=str, default='all', choices=['all', 'infer'])
16 | parser.add_argument("--nproc", type=int, default=4, help="Parallel API calling")
17 | parser.add_argument("--ignore", action='store_true', help="Ignore failed indices. ")
18 | parser.add_argument("--verbose", action='store_true')
19 | parser.add_argument("--prefetch", action='store_true')
20 | args = parser.parse_args()
21 | return args
22 |
23 | def main():
24 | logger = get_logger('RUN')
25 |
26 | args = parse_args()
27 | assert len(args.data), "--data should be a list of data files"
28 |
29 | rank, world_size = get_rank_and_world_size()
30 | if world_size > 1:
31 | torch.cuda.set_device(rank)
32 | dist.init_process_group(backend='nccl', timeout=datetime.timedelta(seconds=5400))
33 |
34 | for _, model_name in enumerate(args.model):
35 | model = None
36 |
37 | pred_root = osp.join(args.work_dir, 'output', model_name)
38 | os.makedirs(pred_root, exist_ok=True)
39 |
40 | for i, dataset_name in enumerate(args.data):
41 | if dataset_name not in dataset_URLs:
42 | dataset_name = abbr2full(dataset_name)
43 |
44 | if dataset_name not in dataset_URLs:
45 | logger.error(f'Unknown dataset: {dataset_name}. ')
46 | continue
47 |
48 | result_file = f'{pred_root}/{model_name}_{dataset_name}.xlsx'
49 |
50 | if model is None:
51 | model = model_name # which is only a name
52 |
53 | # CHECKER
54 | if dataset_name == 'CORE_MM':
55 | MULTI_IMG = getattr(supported_VLM[model_name].func, 'multi_generate', None)
56 |                 if MULTI_IMG is None:
57 | logger.error(f'Model {model_name} does not support the `multi_generate` interface, which is required for testing CORE_MM, skip it. ')
58 | continue
59 | if args.mode == 'all':
60 | logger.error(f'Dataset {dataset_name} does not support `evaluation` now, will skip the evaluation. ')
61 |
62 | model = infer_data_job(model, work_dir=pred_root, model_name=model_name, dataset_name=dataset_name, verbose=args.verbose, api_nproc=args.nproc, ignore_failed=args.ignore)
63 |
64 | if dataset_name in ['MMBench_TEST_CN', 'MMBench_TEST_EN', "MMMU_TEST"]:
65 | if not MMBenchOfficialServer():
66 | logger.error(f'Can not evaluate {dataset_name} on non-official servers, will skip the evaluation. ')
67 | continue
68 |
69 | if rank == 0:
70 | time.sleep(3)
71 | res = None
72 | if listinstr(['SEEDBench_IMG', 'MMBench', 'CCBench', 'ScienceQA', 'AI2D'], dataset_name):
73 | res = prefetch_acc(result_file)
74 | else:
75 | logger.warning(f'{dataset_name} is not handled by prefetch score calculator')
76 | if res is not None:
77 | logger.info(f'{model_name} prefetching: ')
78 | logger.info(res)
79 | dump(res, result_file.replace('.xlsx', '_prefetch.xlsx'))
80 |
81 | if rank == 0 and args.mode == 'all':
82 | import pandas as pd
83 | import json
84 | if listinstr(['MMMU'], dataset_name):
85 | res_rec = pd.read_excel(result_file)
86 | recs = {}
87 | for id_, pred_ in zip(res_rec['id'], res_rec['prediction']):
88 | if 'dev_' in id_:
89 | continue
90 | recs[id_] = str(pred_)
91 | output_file = result_file.replace('.xlsx', '.json')
92 | json.dump(recs, open(output_file, 'w'))
93 | mmmu_eval_func(output_file)
94 |
95 | elif listinstr(['MMBench', 'CCBench', 'SEEDBench_IMG', 'ScienceQA', 'AI2D'], dataset_name):
96 | multiple_choice_eval(result_file, dataset=dataset_name, model='chatgpt-0613', nproc=args.nproc, verbose=args.verbose)
97 | elif listinstr(['MME', 'Hallusion'], dataset_name):
98 | YOrN_eval(result_file, model='chatgpt-0613', nproc=args.nproc, verbose=args.verbose, dataset=dataset_name)
99 | elif dataset_name == 'MMVet':
100 | MMVet_eval(result_file, model='gpt-4-turbo', nproc=args.nproc, verbose=args.verbose)
101 | elif listinstr(['COCO'], dataset_name):
102 | COCO_eval(result_file)
103 | elif listinstr(['OCRVQA', 'TextVQA', 'ChartQA', 'DocVQA'], dataset_name):
104 | VQAEval(result_file, dataset_name)
105 | elif listinstr(['MathVista'], dataset_name):
106 | MathVista_eval(result_file, model='gpt-4-turbo', nproc=args.nproc, verbose=args.verbose)
107 | elif listinstr(['LLaVABench'], dataset_name):
108 | LLaVABench_eval(result_file, model='gpt-4-turbo', nproc=args.nproc, verbose=args.verbose)
109 | else:
110 | logger.error(f'Dataset {dataset_name} is not handled by evaluator, will be skipped. ')
111 |
112 | if __name__ == '__main__':
113 | main()
--------------------------------------------------------------------------------
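run.py decides whether to initialize NCCL based on `get_rank_and_world_size()`, which is imported from `vlmeval.smp` and not shown in this dump. Below is a plausible sketch of that helper, assuming it simply reads the environment variables `torchrun` exports; this is an assumption for illustration, not the repository's implementation.

```python
import os

def get_rank_and_world_size():
    # Assumed behavior: under `torchrun --nproc-per-node=8`, RANK and WORLD_SIZE
    # are set per process; with plain `python run.py` they are absent, so we fall
    # back to single-process defaults and run.py skips dist.init_process_group.
    rank = int(os.environ.get('RANK', 0))
    world_size = int(os.environ.get('WORLD_SIZE', 1))
    return rank, world_size
```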
/run.sh:
--------------------------------------------------------------------------------
1 | torchrun --nproc-per-node=8 run.py --data MMMU_DEV_VAL --model hpt-edge-1-5
2 |
3 | torchrun --nproc-per-node=8 run.py --data MMBench_DEV_EN --model hpt-edge-1-5
4 |
5 | torchrun --nproc-per-node=8 run.py --data SEEDBench_IMG --model hpt-edge-1-5
6 |
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | import re
2 | import sys
3 | from os.path import exists
4 | from setuptools import find_packages, setup
5 |
6 | def parse_requirements(fname='requirements.txt', with_version=True):
7 |     """Parse the package dependencies listed in a requirements file, optionally
8 |     stripping the specific versioning information.
9 |
10 | Args:
11 | fname (str): path to requirements file
12 |         with_version (bool, default=True): if True include version specs
13 |
14 | Returns:
15 | List[str]: list of requirements items
16 |
17 | CommandLine:
18 | python -c "import setup; print(setup.parse_requirements())"
19 | """
20 |
21 | require_fpath = fname
22 |
23 | def parse_line(line):
24 | """Parse information from a line in a requirements text file."""
25 | if line.startswith('-r '):
26 | # Allow specifying requirements in other files
27 | target = line.split(' ')[1]
28 | for info in parse_require_file(target):
29 | yield info
30 | else:
31 | info = {'line': line}
32 | if line.startswith('-e '):
33 | info['package'] = line.split('#egg=')[1]
34 | elif '@git+' in line:
35 | info['package'] = line
36 | else:
37 | # Remove versioning from the package
38 | pat = '(' + '|'.join(['>=', '==', '>']) + ')'
39 | parts = re.split(pat, line, maxsplit=1)
40 | parts = [p.strip() for p in parts]
41 |
42 | info['package'] = parts[0]
43 | if len(parts) > 1:
44 | op, rest = parts[1:]
45 | if ';' in rest:
46 | # Handle platform specific dependencies
47 | # http://setuptools.readthedocs.io/en/latest/setuptools.html#declaring-platform-specific-dependencies
48 | version, platform_deps = map(str.strip,
49 | rest.split(';'))
50 | info['platform_deps'] = platform_deps
51 | else:
52 | version = rest # NOQA
53 | info['version'] = (op, version)
54 | yield info
55 |
56 | def parse_require_file(fpath):
57 | with open(fpath, 'r') as f:
58 | for line in f.readlines():
59 | line = line.strip()
60 | if line and not line.startswith('#'):
61 | for info in parse_line(line):
62 | yield info
63 |
64 | def gen_packages_items():
65 | if exists(require_fpath):
66 | for info in parse_require_file(require_fpath):
67 | parts = [info['package']]
68 | if with_version and 'version' in info:
69 | parts.extend(info['version'])
70 | if not sys.version.startswith('3.4'):
71 | # apparently package_deps are broken in 3.4
72 | platform_deps = info.get('platform_deps')
73 | if platform_deps is not None:
74 | parts.append(';' + platform_deps)
75 | item = ''.join(parts)
76 | yield item
77 |
78 | packages = list(gen_packages_items())
79 | return packages
80 |
81 |
82 | with open('README.md') as f:
83 | readme = f.read()
84 |
85 |
86 | def do_setup():
87 | setup(
88 | name='vlmeval',
89 | version='0.1.0',
90 | description='OpenCompass VLM Evaluation Kit',
91 | author="Haodong Duan",
92 | author_email='dhd.efz@gmail.com',
93 | maintainer='Haodong Duan',
94 | maintainer_email='dhd.efz@gmail.com',
95 | long_description=readme,
96 | long_description_content_type='text/markdown',
97 | cmdclass={},
98 | install_requires=parse_requirements('requirements.txt'),
99 | setup_requires=[],
100 | python_requires='>=3.7.0',
101 | packages=find_packages(exclude=[
102 | 'test*',
103 | 'paper_test*',
104 | ]),
105 | keywords=['AI', 'NLP', 'in-context learning'],
106 | entry_points={
107 | "console_scripts": []
108 | },
109 | classifiers=[
110 | 'Programming Language :: Python :: 3.7',
111 | 'Programming Language :: Python :: 3.8',
112 | 'Programming Language :: Python :: 3.9',
113 | 'Programming Language :: Python :: 3.10',
114 | 'Intended Audience :: Developers',
115 | 'Intended Audience :: Education',
116 | 'Intended Audience :: Science/Research',
117 | ])
118 |
119 |
120 | if __name__ == '__main__':
121 | do_setup()
122 |
--------------------------------------------------------------------------------
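As a quick illustration of what `parse_requirements` produces, the checks below show how a few pins from requirements.txt pass through with the default `with_version=True` (run from the repo root); the function effectively mirrors the file, minus comments and blank lines.

```python
from setup import parse_requirements

reqs = parse_requirements('requirements.txt')
assert 'numpy>=1.23.4' in reqs   # '>=' specifier is re-attached when with_version=True
assert 'openai==1.3.5' in reqs   # exact pins pass through unchanged
assert 'requests' in reqs        # bare package names stay bare
```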
/vlmeval/__init__.py:
--------------------------------------------------------------------------------
1 | try:
2 | import torch
3 | except ImportError:
4 | pass
5 |
6 | from .smp import *
7 | from .api import *
8 | from .evaluate import *
9 | from .utils import *
10 | from .vlm import *
11 | from .config import *
--------------------------------------------------------------------------------
/vlmeval/api/__init__.py:
--------------------------------------------------------------------------------
1 | from .gpt import OpenAIWrapper, GPT4V
2 | from .gpt_int import OpenAIWrapperInternal, GPT4V_Internal
3 | from .hf_chat_model import HFChatModel
4 | from .gemini import GeminiWrapper, GeminiProVision
5 | from .qwen_vl_api import QwenVLWrapper, QwenVLAPI
6 |
7 | __all__ = [
8 | 'OpenAIWrapper', 'HFChatModel', 'OpenAIWrapperInternal', 'GeminiWrapper',
9 | 'GPT4V', 'GPT4V_Internal', 'GeminiProVision','QwenVLWrapper', 'QwenVLAPI'
10 | ]
--------------------------------------------------------------------------------
/vlmeval/api/base.py:
--------------------------------------------------------------------------------
1 | import time
2 | import random as rd
3 | from abc import abstractmethod
4 | from ..smp import get_logger
5 |
6 | class BaseAPI:
7 |
8 | def __init__(self,
9 | retry=10,
10 | wait=3,
11 | system_prompt=None,
12 | verbose=True,
13 | fail_msg='Failed to obtain answer via API.',
14 | **kwargs):
15 | self.wait = wait
16 | self.retry = retry
17 | self.system_prompt = system_prompt
18 | self.kwargs = kwargs
19 | self.verbose = verbose
20 | self.fail_msg = fail_msg
21 | self.logger = get_logger('ChatAPI')
22 | if len(kwargs):
23 | self.logger.info(f'BaseAPI received the following kwargs: {kwargs}')
24 | self.logger.info(f'Will try to use them as kwargs for `generate`. ')
25 |
26 | @abstractmethod
27 | def generate_inner(self, inputs, **kwargs):
28 |         self.logger.warning(f'For BaseAPI, generate_inner is an abstract method. ')
29 | assert 0, 'generate_inner not defined'
30 | ret_code, answer, log = None, None, None
31 | # if ret_code is 0, means succeed
32 | return ret_code, answer, log
33 |
34 | def generate(self, inputs, **kwargs):
35 | input_type = None
36 | if isinstance(inputs, str):
37 | input_type = 'str'
38 | elif isinstance(inputs, list) and isinstance(inputs[0], str):
39 | input_type = 'strlist'
40 | elif isinstance(inputs, list) and isinstance(inputs[0], dict):
41 | input_type = 'dictlist'
42 | assert input_type is not None, input_type
43 |
44 | answer = None
45 | for i in range(self.retry):
46 | T = rd.random() * self.wait * 2
47 | time.sleep(T)
48 | try:
49 | ret_code, answer, log = self.generate_inner(inputs, **kwargs)
50 | if ret_code == 0 and self.fail_msg not in answer and answer != '':
51 | if self.verbose:
52 | print(answer)
53 | return answer
54 | elif self.verbose:
55 | self.logger.info(f"RetCode: {ret_code}\nAnswer: {answer}\nLog: {log}")
56 | except Exception as err:
57 | if self.verbose:
58 |                     self.logger.error(f'An error occurred during try {i}:')
59 | self.logger.error(err)
60 |
61 | return self.fail_msg if answer in ['', None] else answer
62 |
63 |
64 |
--------------------------------------------------------------------------------
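To make the `BaseAPI` contract above concrete: a subclass only needs to implement `generate_inner`, returning `(ret_code, answer, log)` with `ret_code == 0` on success; `generate` then supplies the retry loop, random backoff, and the `fail_msg` fallback. The class below is a toy sketch, not part of the repository.

```python
from vlmeval.api.base import BaseAPI

class EchoAPI(BaseAPI):
    """Toy wrapper that 'answers' by echoing the prompt, to illustrate the contract."""

    def generate_inner(self, inputs, **kwargs):
        try:
            answer = f'You said: {inputs}'
            return 0, answer, 'Succeeded! '   # ret_code 0 means success
        except Exception as err:
            return -1, '', str(err)           # non-zero code triggers a retry

# api = EchoAPI(retry=3, wait=1, verbose=False)
# print(api.generate('hello'))                # -> 'You said: hello'
```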
/vlmeval/api/gemini.py:
--------------------------------------------------------------------------------
1 | from vlmeval.smp import *
2 | from vlmeval.api.base import BaseAPI
3 |
4 | headers = 'Content-Type: application/json'
5 |
6 | class GeminiWrapper(BaseAPI):
7 |
8 | is_api: bool = True
9 |
10 | def __init__(self,
11 | retry: int = 5,
12 | wait: int = 5,
13 | key: str = None,
14 | verbose: bool = True,
15 | temperature: float = 0.0,
16 | system_prompt: str = None,
17 | max_tokens: int = 1024,
18 | proxy: str = None,
19 | **kwargs):
20 |
21 | self.fail_msg = 'Failed to obtain answer via API. '
22 | self.max_tokens = max_tokens
23 | self.temperature = temperature
24 | if key is None:
25 | key = os.environ.get('GOOGLE_API_KEY', None)
26 | assert key is not None
27 | self.api_key = key
28 | if proxy is not None:
29 | proxy_set(proxy)
30 | super().__init__(wait=wait, retry=retry, system_prompt=system_prompt, verbose=verbose, **kwargs)
31 |
32 | @staticmethod
33 | def build_msgs(msgs_raw, system_prompt=None):
34 | msgs = cp.deepcopy(msgs_raw)
35 | assert len(msgs) % 2 == 1
36 |
37 | if system_prompt is not None:
38 | msgs[0] = [system_prompt, msgs[0]]
39 | ret = []
40 | for i, msg in enumerate(msgs):
41 | role = 'user' if i % 2 == 0 else 'model'
42 | parts = msg if isinstance(msg, list) else [msg]
43 | ret.append(dict(role=role, parts=parts))
44 | return ret
45 |
46 | def generate_inner(self, inputs, **kwargs) -> str:
47 | import google.generativeai as genai
48 | assert isinstance(inputs, str) or isinstance(inputs, list)
49 | pure_text = True
50 | if isinstance(inputs, list):
51 | for pth in inputs:
52 | if osp.exists(pth) or pth.startswith('http'):
53 | pure_text = False
54 | genai.configure(api_key=self.api_key)
55 | model = genai.GenerativeModel('gemini-pro') if pure_text else genai.GenerativeModel('gemini-pro-vision')
56 | if isinstance(inputs, str):
57 | messages = [inputs] if self.system_prompt is None else [self.system_prompt, inputs]
58 | elif pure_text:
59 | messages = self.build_msgs(inputs, self.system_prompt)
60 | else:
61 | messages = [] if self.system_prompt is None else [self.system_prompt]
62 | for s in inputs:
63 | if osp.exists(s):
64 | messages.append(Image.open(s))
65 | elif s.startswith('http'):
66 | pth = download_file(s)
67 | messages.append(Image.open(pth))
68 |                     os.remove(pth)
69 | else:
70 | messages.append(s)
71 | gen_config = dict(max_output_tokens=self.max_tokens, temperature=self.temperature)
72 | gen_config.update(self.kwargs)
73 | try:
74 | answer = model.generate_content(messages, generation_config=genai.types.GenerationConfig(**gen_config)).text
75 | return 0, answer, 'Succeeded! '
76 | except Exception as err:
77 | if self.verbose:
78 | self.logger.error(err)
79 | self.logger.error(f"The input messages are {inputs}.")
80 |
81 | return -1, '', ''
82 |
83 |
84 |
85 | class GeminiProVision(GeminiWrapper):
86 |
87 | def generate(self, image_path, prompt, dataset=None):
88 | return super(GeminiProVision, self).generate([image_path, prompt])
89 |
90 | def multi_generate(self, image_paths, prompt, dataset=None):
91 | return super(GeminiProVision, self).generate(image_paths + [prompt])
92 |
93 | def interleave_generate(self, ti_list, dataset=None):
94 | return super(GeminiProVision, self).generate(ti_list)
--------------------------------------------------------------------------------
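Usage sketch for `GeminiProVision` above. It assumes the `google-generativeai` package is installed (it is not pinned in requirements.txt) and that `GOOGLE_API_KEY` is set; the image path reuses the bundled demo asset.

```python
from vlmeval.api.gemini import GeminiProVision

gemini = GeminiProVision(temperature=0.0, max_tokens=512)
# A single image plus prompt goes through generate(); multi_generate() takes a list of images.
print(gemini.generate(image_path='demo/cat.jpg', prompt='Describe this image.'))
```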
/vlmeval/api/gpt.py:
--------------------------------------------------------------------------------
1 | from ..smp import *
2 | import os, sys
3 | from .base import BaseAPI
4 |
5 | APIBASES = {
6 | 'OFFICIAL': "https://api.openai.com/v1/chat/completions",
7 | }
8 |
9 |
10 | def GPT_context_window(model):
11 | length_map = {
12 | 'gpt-4-1106-preview': 128000,
13 | 'gpt-4-vision-preview': 128000,
14 | 'gpt-4': 8192,
15 | 'gpt-4-32k': 32768,
16 | 'gpt-4-0613': 8192,
17 | 'gpt-4-32k-0613': 32768,
18 | 'gpt-3.5-turbo-1106': 16385,
19 | 'gpt-3.5-turbo': 4096,
20 | 'gpt-3.5-turbo-16k': 16385,
21 | 'gpt-3.5-turbo-instruct': 4096,
22 | 'gpt-3.5-turbo-0613': 4096,
23 | 'gpt-3.5-turbo-16k-0613': 16385,
24 | }
25 | if model in length_map:
26 | return length_map[model]
27 | else:
28 | return 4096
29 |
30 | class OpenAIWrapper(BaseAPI):
31 |
32 | is_api: bool = True
33 |
34 | def __init__(self,
35 | model: str = 'gpt-3.5-turbo-0613',
36 | retry: int = 5,
37 | wait: int = 5,
38 | key: str = None,
39 | verbose: bool = True,
40 | system_prompt: str = None,
41 | temperature: float = 0,
42 | timeout: int = 60,
43 | api_base: str = 'OFFICIAL',
44 | max_tokens: int = 1024,
45 | img_size: int = 512,
46 | img_detail: str = 'low',
47 | **kwargs):
48 |
49 | self.model = model
50 | self.cur_idx = 0
51 | self.fail_msg = 'Failed to obtain answer via API. '
52 | self.max_tokens = max_tokens
53 | self.temperature = temperature
54 |
55 | openai_key = os.environ.get('OPENAI_API_KEY', None) if key is None else key
56 | self.openai_key = openai_key
57 | assert img_size > 0 or img_size == -1
58 | self.img_size = img_size
59 | assert img_detail in ['high', 'low']
60 | self.img_detail = img_detail
61 |
62 | self.vision = False
63 | if model == 'gpt-4-vision-preview':
64 | self.vision = True
65 | self.timeout = timeout
66 |
67 | assert isinstance(openai_key, str) and openai_key.startswith('sk-'), f'Illegal openai_key {openai_key}. Please set the environment variable OPENAI_API_KEY to your openai key. '
68 | super().__init__(wait=wait, retry=retry, system_prompt=system_prompt, verbose=verbose, **kwargs)
69 |
70 | if api_base in APIBASES:
71 | self.api_base = APIBASES[api_base]
72 | elif api_base.startswith('http'):
73 | self.api_base = api_base
74 | else:
75 | self.logger.error("Unknown API Base. ")
76 | sys.exit(-1)
77 |
78 | # inputs can be a lvl-2 nested list: [content1, content2, content3, ...]
79 | # content can be a string or a list of image & text
80 | def prepare_inputs(self, inputs):
81 | input_msgs = []
82 | if self.system_prompt is not None:
83 | input_msgs.append(dict(role='system', content=self.system_prompt))
84 | if isinstance(inputs, str):
85 | input_msgs.append(dict(role='user', content=inputs))
86 | return input_msgs
87 | assert isinstance(inputs, list)
88 | dict_flag = [isinstance(x, dict) for x in inputs]
89 | if np.all(dict_flag):
90 | input_msgs.extend(inputs)
91 | return input_msgs
92 | str_flag = [isinstance(x, str) for x in inputs]
93 | if np.all(str_flag):
94 | img_flag = [x.startswith('http') or osp.exists(x) for x in inputs]
95 | if np.any(img_flag):
96 | content_list = []
97 | for fl, msg in zip(img_flag, inputs):
98 | if not fl:
99 | content_list.append(dict(type='text', text=msg))
100 | elif msg.startswith('http'):
101 | content_list.append(dict(type='image_url', image_url={'url': msg, 'detail': self.img_detail}))
102 | elif osp.exists(msg):
103 | from PIL import Image
104 | img = Image.open(msg)
105 | b64 = encode_image_to_base64(img, target_size=self.img_size)
106 | img_struct = dict(url=f'data:image/jpeg;base64,{b64}', detail=self.img_detail)
107 | content_list.append(dict(type='image_url', image_url=img_struct))
108 | input_msgs.append(dict(role='user', content=content_list))
109 | return input_msgs
110 | else:
111 | roles = ['user', 'assistant'] if len(inputs) % 2 == 1 else ['assistant', 'user']
112 | roles = roles * len(inputs)
113 | for role, msg in zip(roles, inputs):
114 | input_msgs.append(dict(role=role, content=msg))
115 | return input_msgs
116 |         raise NotImplementedError("list of list prompt not implemented now. ")
117 |
118 | def generate_inner(self, inputs, **kwargs) -> str:
119 | input_msgs = self.prepare_inputs(inputs)
120 | temperature = kwargs.pop('temperature', self.temperature)
121 | max_tokens = kwargs.pop('max_tokens', self.max_tokens)
122 |
123 | context_window = GPT_context_window(self.model)
124 | max_tokens = min(max_tokens, context_window - self.get_token_len(inputs))
125 | if 0 < max_tokens <= 100:
126 | self.logger.warning('Less than 100 tokens left, may exceed the context window with some additional meta symbols. ')
127 | if max_tokens <= 0:
128 | return 0, self.fail_msg + 'Input string longer than context window. ', 'Length Exceeded. '
129 |
130 | headers = {'Content-Type': 'application/json', 'Authorization': f'Bearer {self.openai_key}'}
131 | payload = dict(
132 | model=self.model,
133 | messages=input_msgs,
134 | max_tokens=max_tokens,
135 | n=1,
136 | temperature=temperature,
137 | **kwargs)
138 | response = requests.post(self.api_base, headers=headers, data=json.dumps(payload), timeout=self.timeout * 1.1)
139 | ret_code = response.status_code
140 | ret_code = 0 if (200 <= int(ret_code) < 300) else ret_code
141 | answer = self.fail_msg
142 | try:
143 | resp_struct = json.loads(response.text)
144 | answer = resp_struct['choices'][0]['message']['content'].strip()
145 | except:
146 | pass
147 | return ret_code, answer, response
148 |
149 | def get_token_len(self, inputs) -> int:
150 | import tiktoken
151 | enc = tiktoken.encoding_for_model(self.model)
152 | if isinstance(inputs, str):
153 | if inputs.startswith('http') or osp.exists(inputs):
154 | return 65 if self.img_detail == 'low' else 130
155 | else:
156 | return len(enc.encode(inputs))
157 | elif isinstance(inputs, dict):
158 | assert 'content' in inputs
159 | return self.get_token_len(inputs['content'])
160 | assert isinstance(inputs, list)
161 | res = 0
162 | for item in inputs:
163 | res += self.get_token_len(item)
164 | return res
165 |
166 | class GPT4V(OpenAIWrapper):
167 |
168 | def generate(self, image_path, prompt, dataset=None):
169 | assert self.model == 'gpt-4-vision-preview'
170 | return super(GPT4V, self).generate([image_path, prompt])
171 |
172 | def multi_generate(self, image_paths, prompt, dataset=None):
173 | assert self.model == 'gpt-4-vision-preview'
174 | return super(GPT4V, self).generate(image_paths + [prompt])
175 |
176 | def interleave_generate(self, ti_list, dataset=None):
177 | assert self.model == 'gpt-4-vision-preview'
178 | return super(GPT4V, self).generate(ti_list)
179 |
--------------------------------------------------------------------------------
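Usage sketch for `OpenAIWrapper` above (requires `OPENAI_API_KEY` to be set to a valid `sk-...` key; the model name is just the class default). `generate` is inherited from `BaseAPI` and accepts a plain string, a list of strings (text and/or image paths), or a list of message dicts.

```python
from vlmeval.api.gpt import OpenAIWrapper

gpt = OpenAIWrapper(model='gpt-3.5-turbo-0613', temperature=0, max_tokens=256)
print(gpt.generate('Summarize the Apache 2.0 patent clause in one sentence.'))
```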
/vlmeval/api/gpt_int.py:
--------------------------------------------------------------------------------
1 | import json
2 | import warnings
3 | import requests
4 | from ..smp import *
5 | from .gpt import GPT_context_window, OpenAIWrapper
6 |
7 |
8 | url = "http://ecs.sv.us.alles-apin.openxlab.org.cn/v1/openai/v2/text/chat"
9 | headers = {
10 | "Content-Type": "application/json"
11 | }
12 |
13 | class OpenAIWrapperInternal(OpenAIWrapper):
14 |
15 | is_api: bool = True
16 |
17 | def __init__(self,
18 | model: str = 'gpt-3.5-turbo-0613',
19 | retry: int = 5,
20 | wait: int = 3,
21 | verbose: bool = True,
22 | system_prompt: str = None,
23 | temperature: float = 0,
24 | timeout: int = 60,
25 | max_tokens: int = 1024,
26 | img_size: int = 512,
27 | img_detail: str = 'low',
28 | **kwargs):
29 |
30 | self.model = model
31 | if 'KEYS' in os.environ and osp.exists(os.environ['KEYS']):
32 | keys = load(os.environ['KEYS'])
33 | headers['alles-apin-token'] = keys.get('alles-apin-token', '')
34 | elif 'ALLES' in os.environ:
35 | headers['alles-apin-token'] = os.environ['ALLES']
36 | self.headers = headers
37 | self.temperature = temperature
38 | self.timeout = timeout
39 | self.max_tokens = max_tokens
40 |
41 | assert img_size > 0 or img_size == -1
42 | self.img_size = img_size
43 | assert img_detail in ['high', 'low']
44 | self.img_detail = img_detail
45 |
46 | self.vision = False
47 | if model == 'gpt-4-vision-preview':
48 | self.vision = True
49 |
50 | super(OpenAIWrapper, self).__init__(
51 | wait=wait, retry=retry, system_prompt=system_prompt, verbose=verbose, **kwargs)
52 |
53 | def generate_inner(self, inputs, **kwargs) -> str:
54 | input_msgs = self.prepare_inputs(inputs)
55 |
56 | temperature = kwargs.pop('temperature', self.temperature)
57 | max_tokens = kwargs.pop('max_tokens', self.max_tokens)
58 |
59 | # Held out 100 tokens as buffer
60 | context_window = GPT_context_window(self.model)
61 | max_tokens = min(max_tokens, context_window - self.get_token_len(inputs))
62 | if 0 < max_tokens <= 100:
63 | print('Less than 100 tokens left, may exceed the context window with some additional meta symbols. ')
64 | if max_tokens <= 0:
65 | return 0, self.fail_msg + 'Input string longer than context window. ', 'Length Exceeded. '
66 |
67 | payload = dict(
68 | model=self.model,
69 | messages=input_msgs,
70 | max_tokens=max_tokens,
71 | n=1,
72 | stop=None,
73 | timeout=self.timeout,
74 | temperature=temperature,
75 | **kwargs)
76 |
77 | response = requests.post(url, headers=headers, data=json.dumps(payload), timeout=self.timeout * 1.1)
78 | ret_code = response.status_code
79 | ret_code = 0 if (200 <= int(ret_code) < 300) else ret_code
80 |
81 | answer = self.fail_msg
82 | try:
83 | resp_struct = json.loads(response.text)
84 | assert resp_struct['msg'] == 'ok' and resp_struct['msgCode'] == '10000', resp_struct
85 | answer = resp_struct['data']['choices'][0]['message']['content'].strip()
86 | except:
87 | pass
88 | return ret_code, answer, response
89 |
90 |
91 | class GPT4V_Internal(OpenAIWrapperInternal):
92 |
93 | def generate(self, image_path, prompt, dataset=None):
94 | assert self.model == 'gpt-4-vision-preview'
95 | return super(GPT4V_Internal, self).generate([image_path, prompt])
96 |
97 | def multi_generate(self, image_paths, prompt, dataset=None):
98 | assert self.model == 'gpt-4-vision-preview'
99 | return super(GPT4V_Internal, self).generate(image_paths + [prompt])
100 |
101 | def interleave_generate(self, ti_list, dataset=None):
102 | assert self.model == 'gpt-4-vision-preview'
103 | return super(GPT4V_Internal, self).generate(ti_list)
--------------------------------------------------------------------------------
/vlmeval/api/hf_chat_model.py:
--------------------------------------------------------------------------------
1 | import os, sys
2 | import os.path as osp
3 | import torch
4 | from ..smp import *
5 |
6 | def get_gpu_num(model_name):
7 | model_name = model_name.lower()
8 | kws = {
9 | 8: ['65b', '70b'],
10 | 4: ['30b', '33b', '35b', '40b'],
11 | 2: ['13b', '14b', '20b'],
12 | 1: ['6b', '7b', 'moss'],
13 | }
14 | for k in [8, 4, 2, 1]:
15 | for keyword in kws[k]:
16 | if keyword in model_name:
17 | return k
18 | return 8
19 |
20 | validated_llms = [
21 | 'internlm/internlm-chat-7b', 'internlm/internlm-chat-7b-8k', 'internlm/internlm-chat-20b',
22 | 'Qwen/Qwen-7B-Chat', 'Qwen/Qwen-14B-Chat',
23 | 'THUDM/chatglm2-6b', 'THUDM/chatglm2-6b-32k', 'THUDM/chatglm3-6b', 'THUDM/chatglm3-6b-32k',
24 | 'baichuan-inc/Baichuan2-7B-Chat', 'baichuan-inc/Baichuan2-13B-Chat',
25 | 'lmsys/vicuna-7b-v1.5', 'lmsys/vicuna-13b-v1.5',
26 | 'meta-llama/Llama-2-7b-chat-hf'
27 | ]
28 | Auto_model = ['chatglm']
29 |
30 | class HFChatModel:
31 |
32 | def _get_context_length(self, model, model_path):
33 | # By default, we use model.config.seq_length
34 | model_path = model_path.lower()
35 | if 'baichuan' in model_path:
36 | context_window = model.config.model_max_length
37 | elif 'internlm' in model_path or 'llama' in model_path:
38 | context_window = model.config.max_position_embeddings
39 | elif 'vicuna' in model_path:
40 | context_window = model.generation_config.max_length
41 | else:
42 | # chatglm & qwen
43 | context_window = model.config.seq_length
44 | return context_window
45 |
46 | def _get_context_length_robust(self, model, model_path):
47 | try:
48 | context_window = self._get_context_length(model, model_path)
49 | return context_window
50 | except:
51 | self.logger.critical(
52 | "Failed to extract context_window information from config / generation_config. "
53 |                 "Please read the above code and check whether the logic works for your model path. "
54 | )
55 | raise NotImplementedError
56 |
57 | def __init__(self,
58 | model_path,
59 | system_prompt: str=None,
60 | **kwargs):
61 |
62 | self.logger = get_logger('HFChatModel')
63 | if 'vicuna' in model_path.lower():
64 | try:
65 | from fastchat.model import get_conversation_template
66 | except:
67 | self.logger.critical("Please install fastchat first to use vicuna. ")
68 | sys.exit(-1)
69 |
70 | self.explicit_device = kwargs.pop('device', None)
71 |
72 | if self.explicit_device is None:
73 | # If CUDA_VISIBLE_DEVICES is not properly set
74 | if 'CUDA_VISIBLE_DEVICES' not in os.environ or os.environ['CUDA_VISIBLE_DEVICES'] in ['', '0,1,2,3,4,5,6,7']:
75 | num_gpu = get_gpu_num(model_path)
76 | gpu_offset = kwargs.pop('gpu_offset', 0)
77 | cuda_visible_devices = ','.join([str(i) for i in range(gpu_offset, gpu_offset+num_gpu)])
78 | os.environ['CUDA_VISIBLE_DEVICES'] = cuda_visible_devices
79 |
80 | from transformers import AutoTokenizer, AutoModelForCausalLM, AutoModel
81 | from transformers.generation import GenerationConfig
82 |
83 | if model_path not in validated_llms:
84 | self.logger.warning(f"{model_path} not in validated LLMs, may have inference troubles. ")
85 |
86 | self.model_path = model_path
87 | if listinstr(Auto_model, model_path):
88 | LoadModel = AutoModel
89 | else:
90 | LoadModel = AutoModelForCausalLM
91 |
92 | assert osp.exists(model_path) or len(model_path.split('/')) == 2
93 |
94 | device = self.explicit_device if self.explicit_device else "auto"
95 |
96 | precision = {}
97 | if 'internlm-chat-7b' in model_path:
98 | precision = {'torch_dtype': torch.float16}
99 | elif 'internlm-chat-20b' in model_path:
100 | precision = {'torch_dtype': torch.bfloat16}
101 |
102 | self.tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
103 | model = LoadModel.from_pretrained(model_path, trust_remote_code=True, device_map='cpu', **precision)
104 | model = model.eval()
105 |
106 | if device != 'cpu':
107 | model = model.to(f'cuda:{device}' if isinstance(device, int) else 'cuda')
108 | try:
109 | model.generation_config = GenerationConfig.from_pretrained(model_path, trust_remote_code=True, device_map=device)
110 | except:
111 | pass
112 |
113 | torch.cuda.empty_cache()
114 | self.model = model
115 | self.context_length = self._get_context_length_robust(model=model, model_path=model_path)
116 | self.answer_buffer = 192
117 | self.system_prompt = system_prompt
118 | for k, v in kwargs.items():
119 |             self.logger.info(f'The following argument will be used as a generation hyper-parameter (unless overridden at call time): {k}: {v}. ')
120 | self.kwargs = kwargs
121 |
122 | def generate_str(self, input, **kwargs):
123 | if 'baichuan' in self.model_path.lower():
124 | messages=[]
125 | messages.append({"role": "user", "content": input})
126 | resp= self.model.chat(self.tokenizer, messages, **kwargs)
127 | elif 'vicuna' in self.model_path.lower():
128 | from fastchat.model import get_conversation_template
129 | conv = get_conversation_template('vicuna')
130 | conv.append_message(conv.roles[0], input)
131 | conv.append_message(conv.roles[1], None)
132 | prompt = conv.get_prompt()
133 | inputs = self.tokenizer([prompt], return_tensors="pt")
134 | if torch.cuda.is_available():
135 | for k in inputs:
136 | inputs[k] = inputs[k].cuda()
137 |
138 | params = dict(do_sample=True, temperature=0.7, repetition_penalty=1.0, max_new_tokens=512)
139 | params.update(self.kwargs)
140 | params.update(kwargs)
141 | outputs = self.model.generate(**inputs, **params)
142 | resp = self.tokenizer.decode(outputs[0][len(inputs["input_ids"][0]) :], skip_special_tokens=True, spaces_between_special_tokens=False)
143 |
144 | else:
145 | params = self.kwargs
146 | params.update(kwargs)
147 | resp, _ = self.model.chat(self.tokenizer, input, history=[], **params)
148 |
149 | return resp
150 |
151 | def length_ok(self, inputs):
152 | tot = len(self.tokenizer.encode(self.system_prompt)) if self.system_prompt is not None else 0
153 | for s in inputs:
154 | tot += len(self.tokenizer.encode(s))
155 | return tot + self.answer_buffer < self.context_length
156 |
157 | def generate_list(self, full_inputs, offset=0, **kwargs):
158 | assert isinstance(full_inputs, list)
159 |
160 | inputs = full_inputs[offset:]
161 | if not self.length_ok(inputs):
162 |             return self.generate_list(full_inputs, offset + 1, **kwargs)
163 |
164 | model_path = self.model_path.lower()
165 |
166 | if sum([x in model_path for x in ['baichuan']]):
167 | input_msgs = []
168 | if self.system_prompt is not None:
169 | input_msgs.append(dict(role='user', content=self.system_prompt))
170 | if len(inputs):
171 | assert isinstance(inputs, list) and isinstance(inputs[0], str)
172 | roles = ['user', 'assistant'] if len(inputs) % 2 == 1 else ['assistant', 'user']
173 | roles = roles * len(inputs)
174 | for role, msg in zip(roles, inputs):
175 | input_msgs.append(dict(role=role, content=msg))
176 | response = self.model.chat(self.tokenizer, input_msgs)
177 | elif sum([x in model_path for x in ['vicuna']]):
178 | from fastchat.model import get_conversation_template
179 | conv = get_conversation_template('vicuna')
180 | assert isinstance(inputs, list) and isinstance(inputs[0], str)
181 | if len(inputs) % 2 == 1:
182 | if self.system_prompt is not None:
183 | conv.append_message(conv.roles[0], self.system_prompt)
184 | for i in range(len(inputs)//2):
185 | conv.append_message(conv.roles[0], inputs[2 * i])
186 | conv.append_message(conv.roles[1], inputs[2 * i + 1])
187 | else:
188 | assert self.system_prompt is not None
189 | conv.append_message(conv.roles[0], self.system_prompt)
190 | conv.append_message(conv.roles[1], inputs[0])
191 | for i in range(len(inputs) // 2 - 1):
192 | conv.append_message(conv.roles[0], inputs[2 * i + 1])
193 | conv.append_message(conv.roles[1], inputs[2 * i + 2])
194 | conv.append_message(conv.roles[0], inputs[-1])
195 | conv.append_message(conv.roles[1], None)
196 | prompt = conv.get_prompt()
197 | inputs = self.tokenizer([prompt], return_tensors="pt")
198 | if torch.cuda.is_available():
199 | for k in inputs:
200 | inputs[k] = inputs[k].cuda()
201 |
202 | params = dict(do_sample=True, temperature=0.7, repetition_penalty=1.0, max_new_tokens=512)
203 | params.update(self.kwargs)
204 | params.update(kwargs)
205 |
206 | outputs = self.model.generate(**inputs, **params)
207 | response = self.tokenizer.decode(outputs[0][len(inputs["input_ids"][0]) :], skip_special_tokens=True, spaces_between_special_tokens=False)
208 | response = response.lstrip('\n')
209 | else:
210 | # The default option, support internlm, chatglm, qwen
211 | history, msg = [], None
212 | if len(inputs) % 2 == 1:
213 | if self.system_prompt is not None:
214 | history = [(self.system_prompt, '')]
215 | for i in range(len(inputs)//2):
216 | history.append((inputs[2 * i], inputs[2 * i + 1]))
217 | else:
218 | assert self.system_prompt is not None
219 | history = [(self.system_prompt, inputs[0])]
220 | for i in range(len(inputs) // 2 - 1):
221 | history.append((inputs[2 * i + 1], inputs[2 * i + 2]))
222 | msg = inputs[-1]
223 |
224 | params = self.kwargs
225 | params.update(kwargs)
226 | response, _ = self.model.chat(self.tokenizer, msg, history=history, **params)
227 |
228 | return response, offset
229 |
230 | def generate(self, inputs, **kwargs):
231 | if isinstance(inputs, str):
232 | return self.generate_str(inputs, **kwargs)
233 | elif isinstance(inputs, list):
234 | return self.generate_list(inputs, **kwargs)
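
# Usage sketch (illustrative, not part of the original module): HFChatModel exposes a
# single generate() entry point that accepts either a plain string (routed to
# generate_str) or an alternating [user, assistant, ..., user] list (routed to
# generate_list, which returns a (response, offset) tuple). The 7B model below is one
# of the validated LLMs listed above; loading it downloads the checkpoint and needs a GPU.
if __name__ == '__main__':
    llm = HFChatModel('internlm/internlm-chat-7b', system_prompt='You are a helpful assistant.')
    print(llm.generate('Briefly explain what a context window is.'))
    response, _ = llm.generate(['Hi!', 'Hello! How can I help?', 'Name three primary colors.'])
    print(response)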
--------------------------------------------------------------------------------
/vlmeval/api/qwen_vl_api.py:
--------------------------------------------------------------------------------
1 | from vlmeval.smp import *
2 | from vlmeval.api.base import BaseAPI
3 |
4 | class QwenVLWrapper(BaseAPI):
5 |
6 | is_api: bool = True
7 |
8 | def __init__(self,
9 | model: str = 'qwen-vl-plus',
10 | retry: int = 5,
11 | wait: int = 5,
12 | key: str = None,
13 | verbose: bool = True,
14 | temperature: float = 0.0,
15 | system_prompt: str = None,
16 | max_tokens: int = 1024,
17 | proxy: str = None,
18 | **kwargs):
19 |
20 | assert model in ['qwen-vl-plus', 'qwen-vl-max']
21 | self.model = model
22 | import dashscope
23 | self.fail_msg = 'Failed to obtain answer via API. '
24 | self.max_tokens = max_tokens
25 | self.temperature = temperature
26 | if key is None:
27 | key = os.environ.get('DASHSCOPE_API_KEY', None)
28 | assert key is not None, "Please set the API Key (obtain it here: https://help.aliyun.com/zh/dashscope/developer-reference/vl-plus-quick-start)"
29 | dashscope.api_key = key
30 | if proxy is not None:
31 | proxy_set(proxy)
32 | super().__init__(wait=wait, retry=retry, system_prompt=system_prompt, verbose=verbose, **kwargs)
33 |
34 | @staticmethod
35 | def build_msgs(msgs_raw, system_prompt=None):
36 | msgs = cp.deepcopy(msgs_raw)
37 | ret = []
38 | if system_prompt is not None:
39 |             content = [dict(text=system_prompt)]
40 | ret.append(dict(role='system', content=content))
41 | content = []
42 | for i, msg in enumerate(msgs):
43 | if osp.exists(msg):
44 | content.append(dict(image='file://' + msg))
45 | elif msg.startswith('http'):
46 | content.append(dict(image=msg))
47 | else:
48 | content.append(dict(text=msg))
49 | ret.append(dict(role='user', content=content))
50 | return ret
51 |
52 | def generate_inner(self, inputs, **kwargs) -> str:
53 | from dashscope import MultiModalConversation
54 | assert isinstance(inputs, str) or isinstance(inputs, list)
55 | pure_text = True
56 | if isinstance(inputs, list):
57 | for pth in inputs:
58 | if osp.exists(pth) or pth.startswith('http'):
59 | pure_text = False
60 | assert not pure_text
61 | messages = self.build_msgs(msgs_raw=inputs, system_prompt=self.system_prompt)
62 | gen_config = dict(max_output_tokens=self.max_tokens, temperature=self.temperature)
63 | gen_config.update(self.kwargs)
64 | try:
65 | response = MultiModalConversation.call(model=self.model, messages=messages)
66 | if self.verbose:
67 | print(response)
68 | answer = response.output.choices[0]['message']['content'][0]['text']
69 | return 0, answer, 'Succeeded! '
70 | except Exception as err:
71 | if self.verbose:
72 | self.logger.error(err)
73 | self.logger.error(f"The input messages are {inputs}.")
74 |
75 | return -1, '', ''
76 |
77 | class QwenVLAPI(QwenVLWrapper):
78 |
79 | def generate(self, image_path, prompt, dataset=None):
80 | return super(QwenVLAPI, self).generate([image_path, prompt])
81 |
82 | def multi_generate(self, image_paths, prompt, dataset=None):
83 | return super(QwenVLAPI, self).generate(image_paths + [prompt])
84 |
85 | def interleave_generate(self, ti_list, dataset=None):
86 | return super(QwenVLAPI, self).generate(ti_list)
87 |
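# Usage sketch (illustrative, not part of the original module). Requires a valid
# DASHSCOPE_API_KEY in the environment (or a key passed explicitly); the image path
# below is a hypothetical placeholder.
if __name__ == '__main__':
    api = QwenVLAPI(model='qwen-vl-plus', verbose=False)
    # generate() takes a single image path (or URL) plus a text prompt.
    print(api.generate('path/to/example.jpg', 'Describe this image in one sentence.'))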
--------------------------------------------------------------------------------
/vlmeval/config.py:
--------------------------------------------------------------------------------
1 | from .vlm import *
2 | from .api import GPT4V, GPT4V_Internal, GeminiProVision, QwenVLAPI
3 | from functools import partial
4 |
5 | PandaGPT_ROOT = None
6 | MiniGPT4_ROOT = None
7 | TransCore_ROOT = None
8 | Yi_ROOT = None
9 |
10 | api_models = {
11 | 'GPT4V': partial(GPT4V, model='gpt-4-vision-preview', temperature=0, img_size=512, img_detail='low', retry=10),
12 | 'GPT4V_INT': partial(GPT4V_Internal, model='gpt-4-vision-preview', temperature=0, img_size=512, img_detail='low', retry=10),
13 | 'GPT4V_SHORT': partial(
14 | GPT4V, model='gpt-4-vision-preview', temperature=0, img_size=512, img_detail='low', retry=10,
15 |         system_prompt="Please respond to the following question / request in a short reply. "),
16 | 'GPT4V_SHORT_INT': partial(
17 | GPT4V_Internal, model='gpt-4-vision-preview', temperature=0, img_size=512, img_detail='low', retry=10,
18 |         system_prompt="Please respond to the following question / request in a short reply. "),
19 | 'GeminiProVision': partial(GeminiProVision, temperature=0, retry=10),
20 | 'QwenVLPlus': partial(QwenVLAPI, model='qwen-vl-plus', temperature=0, retry=10),
21 | 'QwenVLMax': partial(QwenVLAPI, model='qwen-vl-max', temperature=0, retry=10),
22 | }
23 |
24 | models = {
25 | 'hpt-air-mmmu': partial(HPT),
26 | 'hpt-air-mmbench': partial(HPT, vis_scale=392, is_crop=False),
27 | 'hpt-air-seed': partial(HPT, is_crop=False),
28 | 'hpt-air-demo': partial(HPT, vis_scale=392, is_crop=False),
29 | 'hpt-air-demo-local': partial(HPT, vis_scale=392, is_crop=False, global_model_path='../HPT_AIR_HF/'),
30 | 'hpt-air-1-5': partial(HPT1_5, global_model_path='HyperGAI/HPT1_5-Air-Llama-3-8B-Instruct-multimoda', vis_scale=448, prompt_template='llama3_chat'),
31 | 'hpt-edge-1-5': partial(HPT1_5, global_model_path='HyperGAI/HPT1_5-Edge', vis_scale=490, prompt_template='phi3_chat'),
32 | }
33 |
34 | supported_VLM = {}
35 | for model_set in [models, api_models]:
36 | supported_VLM.update(model_set)
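
# Usage sketch (illustrative, not part of the original module): supported_VLM maps a
# model name to a constructor (a functools.partial), so building and querying a model
# looks like the following. API-backed entries additionally need the corresponding API
# key; the image path is a hypothetical placeholder.
if __name__ == '__main__':
    model = supported_VLM['QwenVLPlus']()
    print(model.generate('path/to/example.jpg', 'What is shown in this image?'))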
--------------------------------------------------------------------------------
/vlmeval/evaluate/__init__.py:
--------------------------------------------------------------------------------
1 | from .yes_or_no import default_rating, MME_rating, YOrN_eval
2 | from .mmvet_eval import MMVet_eval
3 | from .multiple_choice import multiple_choice_eval
4 | from .coco_eval import COCO_eval
5 | from .vqa_eval import VQAEval
6 | from .mathvista_eval import MathVista_eval
7 | from .llavabench import LLaVABench_eval
8 | from .misc import build_judge
9 | from .mmmu_eval import mmmu_eval_func
--------------------------------------------------------------------------------
/vlmeval/evaluate/coco_eval.py:
--------------------------------------------------------------------------------
1 | from vlmeval.smp import *
2 | from pycocoevalcap.bleu.bleu import Bleu
3 | from pycocoevalcap.rouge.rouge import Rouge
4 | from pycocoevalcap.cider.cider import Cider
5 |
6 |
7 | class COCO_Caption_Scorer():
8 | def __init__(self, ref, gt):
9 | self.ref = ref
10 | self.gt = gt
11 | print('setting up scorers...')
12 | self.scorers = [
13 | (Bleu(4), ["Bleu_1", "Bleu_2", "Bleu_3", "Bleu_4"]),
14 | # (Meteor(), "METEOR"), # need java version 11.0.16+
15 | (Rouge(), "ROUGE_L"),
16 | (Cider(), "CIDEr"),
17 | # (Spice(), "SPICE"), # need java version 11.0.16+
18 | ]
19 |
20 | def compute_scores(self):
21 | total_scores = {}
22 | for scorer, method in self.scorers:
23 | print('computing %s score...' % (scorer.method()))
24 | score, scores = scorer.compute_score(self.gt, self.ref)
25 | if type(method) == list:
26 | for sc, scs, m in zip(score, scores, method):
27 | print("%s: %0.3f" % (m, sc * 100))
28 | total_scores["Bleu"] = [x * 100 for x in score]
29 | else:
30 | print("%s: %0.3f" % (method, score * 100))
31 | total_scores[method] = score * 100
32 |
33 | print('*****DONE*****')
34 | for key, value in total_scores.items():
35 | print('{}:{}'.format(key, value))
36 | return total_scores
37 |
38 | def COCO_eval(eval_file, nproc=4, verbose=False):
39 | logger = get_logger('Evaluation')
40 |
41 | data = load(eval_file)
42 |
43 | lt = len(data)
44 | lines = [data.iloc[i] for i in range(lt)]
45 | ref = {}
46 | gt = {}
47 |     for i, line in enumerate(lines):
48 | ref[str(i)] = [str(line['prediction'])]
49 | gt[str(i)] = eval(line['answer'])
50 |
51 | scorer = COCO_Caption_Scorer(ref,gt)
52 | coco_caption_score_dict = scorer.compute_scores()
53 |
54 | score_pth = eval_file.replace('.xlsx','_score.json')
55 | dump(coco_caption_score_dict, score_pth)
56 | logger.info(f'COCO_eval successfully finished evaluating {eval_file}, results saved in {score_pth}')
57 | logger.info(f'Score: ')
58 | for key, value in coco_caption_score_dict.items():
59 | logger.info('{}:{}'.format(key, value))
60 |
61 | def parse_args():
62 | parser = argparse.ArgumentParser(description="Inference LLM Answers. ")
63 | parser.add_argument("--data", type=str, help="The question set for inference, in excel / tsv / json format. ")
64 | parser.add_argument("--nproc", type=int, default=4)
65 | parser.add_argument("--verbose", action='store_true')
66 | args = parser.parse_args()
67 | return args
68 |
69 | if __name__ == '__main__':
70 | args = parse_args()
71 | COCO_eval(eval_file=args.data, nproc=args.nproc, verbose=args.verbose)
72 |
73 |
--------------------------------------------------------------------------------
/vlmeval/evaluate/llavabench.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import numpy as np
3 | import pandas as pd
4 | import os.path as osp
5 | from vlmeval.evaluate.misc import build_judge
6 | from vlmeval.smp import get_logger, load, dump, defaultdict
7 | from vlmeval.utils import track_progress_rich
8 |
9 | rule_dict = {
10 | "llava_bench_conv": {"role": "Assistant", "prompt": "We would like to request your feedback on the performance of two AI assistants in response to the user question displayed above. The user asks the question on observing an image. For your reference, the visual content in the image is represented with a few sentences describing the image. \nPlease rate the helpfulness, relevance, accuracy, level of details of their responses. Each assistant receives an overall score on a scale of 1 to 10, where a higher score indicates better overall performance.\nPlease first output a single line containing only two values indicating the scores for Assistant 1 and 2, respectively. The two scores are separated by a space.\nIn the subsequent line, please provide a comprehensive explanation of your evaluation, avoiding any potential bias and ensuring that the order in which the responses were presented does not affect your judgment."},
11 | "llava_bench_detail": {"role": "Assistant", "prompt": "We would like to request your feedback on the performance of two AI assistants in response to the user question displayed above. The user asks the question on observing an image. For your reference, the visual content in the image is represented with a few sentences describing the image. \nPlease rate the helpfulness, relevance, accuracy, level of details of their responses. Each assistant receives an overall score on a scale of 1 to 10, where a higher score indicates better overall performance.\nPlease first output a single line containing only two values indicating the scores for Assistant 1 and 2, respectively. The two scores are separated by a space.\nIn the subsequent line, please provide a comprehensive explanation of your evaluation, avoiding any potential bias and ensuring that the order in which the responses were presented does not affect your judgment."},
12 | "llava_bench_complex": {"role": "Assistant", "prompt": "We would like to request your feedback on the performance of two AI assistants in response to the user question displayed above. The user asks the question on observing an image. For your reference, the visual content in the image is represented with a few sentences describing the image. \nPlease rate the helpfulness, relevance, accuracy, level of details of their responses. Each assistant receives an overall score on a scale of 1 to 10, where a higher score indicates better overall performance.\nPlease first output a single line containing only two values indicating the scores for Assistant 1 and 2, respectively. The two scores are separated by a space.\nIn the subsequent line, please provide a comprehensive explanation of your evaluation, avoiding any potential bias and ensuring that the order in which the responses were presented does not affect your judgment."}
13 | }
14 |
15 | def get_eval(judge, content):
16 | return judge.generate(content)
17 |
18 | def parse_score(review):
19 | logger = get_logger('Evaluation')
20 | try:
21 | score_pair = review.split('\n')[0]
22 | score_pair = score_pair.replace(',', ' ')
23 | sp = score_pair.split(' ')
24 | if len(sp) == 2:
25 | return [float(sp[0]), float(sp[1])]
26 | else:
27 |             logger.error(f'Failed to parse scores from review: {review}')
28 | return [-1, -1]
29 | except Exception as e:
30 |         logger.error(f'{e}: failed to parse scores from review: {review}')
31 | return [-1, -1]
32 |
33 | def build_prompt(line):
34 | cap_str = line['caption']
35 | question = line['question']
36 | ans1 = line['gpt4_ans']
37 | ans2 = line['prediction']
38 | category = 'llava_bench_' + line['category']
39 | rule = rule_dict[category]
40 | role, prompt = rule['role'], rule['prompt']
41 |
42 | content = (f'[Context]\n{cap_str}\n\n'
43 | f'[Question]\n{question}\n\n'
44 | f'[{role} 1]\n{ans1}\n\n[End of {role} 1]\n\n'
45 | f'[{role} 2]\n{ans2}\n\n[End of {role} 2]\n\n'
46 | f'[System]\n{prompt}\n\n')
47 | return content
48 |
49 | def LLaVABench_atomeval(model, prompt):
50 | review = get_eval(model, prompt)
51 | scores = parse_score(review)
52 | return scores
53 |
54 | def LLaVABench_score(data):
55 | cates = ['overall'] + list(set(data['category']))
56 | ret = defaultdict(list)
57 |
58 |
59 | for c in cates:
60 | ret['split'].append(c)
61 | sub = data[data['category'] == c] if c != 'overall' else data
62 | ret['Relative Score (main)'].append(np.mean(sub['score']) / np.mean(sub['gpt4_score']) * 100)
63 | ret['VLM Score'].append(np.mean(sub['score']) * 10)
64 | ret['GPT4 Score'].append(np.mean(sub['gpt4_score']) * 10)
65 | return pd.DataFrame(ret)
66 |
67 | def LLaVABench_eval(eval_file, model='gpt-4-0314', nproc=4, verbose=False):
68 | suffix = '.' + eval_file.split('.')[-1]
69 | record_file = eval_file.replace(suffix, '_openai_result' + suffix)
70 | score_file = eval_file.replace(suffix, '_score.csv')
71 |
72 | if not osp.exists(record_file):
73 | data = load(eval_file)
74 | lines = [data.iloc[i] for i in range(len(data))]
75 | model = build_judge(
76 | model, temperature=0.2, retry=10, verbose=verbose,
77 | system_prompt='You are a helpful and precise assistant for checking the quality of the answer.')
78 | prompts = [build_prompt(line) for line in lines]
79 | tups = [(model, prompt) for prompt in prompts]
80 | scores = track_progress_rich(LLaVABench_atomeval, tups, nproc=nproc, chunksize=nproc)
81 | data['gpt4_score'] = [x[0] for x in scores]
82 | data['score'] = [x[1] for x in scores]
83 | dump(data, record_file)
84 |
85 | data = load(record_file)
86 | ret = LLaVABench_score(data).round(1)
87 | print(ret)
88 | dump(ret, score_file)
89 | return ret
90 |
91 | def parse_args():
92 | parser = argparse.ArgumentParser(description="LLaVABench Evaluation. ")
93 | parser.add_argument("data", type=str, help="The question set for inference, in excel / tsv / json format. ")
94 | parser.add_argument(
95 | "--model", type=str, help="The LLM (GPT) used for inference. ", default="gpt-4-turbo",
96 | choices=['gpt-4-0613', 'gpt-4-turbo', 'chatgpt-1106', 'chatgpt-0613', 'gpt-4-0314'])
97 | parser.add_argument("--nproc", type=int, default=4)
98 | parser.add_argument("--verbose", action='store_true')
99 | args = parser.parse_args()
100 | return args
101 |
102 | if __name__ == '__main__':
103 | args = parse_args()
104 | LLaVABench_eval(eval_file=args.data, model=args.model, nproc=args.nproc, verbose=args.verbose)
105 |
--------------------------------------------------------------------------------
/vlmeval/evaluate/mathvista_eval.py:
--------------------------------------------------------------------------------
1 | from vlmeval.evaluate.misc import build_judge
2 | from vlmeval.smp import *
3 | from vlmeval.utils import track_progress_rich
4 | from vlmeval.utils.matching_util import can_infer
5 |
6 | def get_gpt4_ICE():
7 | example_1 = """
8 | Hint: Please answer the question requiring an integer answer and provide the final value, e.g., 1, 2, 3, at the end.\n
9 | Question: Which number is missing?\n
10 | Model response: The number missing in the sequence is 14.\n
11 | Extracted answer: 14
12 | """
13 |
14 | example_2 = """
15 | Hint: Please answer the question requiring a floating-point number with one decimal place and provide the final value, e.g., 1.2, 1.3, 1.4, at the end.\n
16 | Question: What is the fraction of females facing the camera?\n
17 | Model response: The fraction of females facing the camera is 0.6, which means that six out of ten females in the group are facing the camera.\n
18 | Extracted answer: 0.6
19 | """
20 |
21 | example_3 = """
22 | Hint: Please answer the question requiring a floating-point number with two decimal places and provide the final value, e.g., 1.23, 1.34, 1.45, at the end.\n
23 | Question: How much money does Luca need to buy a sour apple candy and a butter-scotch candy? (Unit: $)\n
24 | Model response: Luca needs $1.45 to buy a sour apple candy and a butterscotch candy.\n
25 | Extracted answer: 1.45
26 | """
27 |
28 | example_4 = """
29 | Hint: Please answer the question requiring a Python list as an answer and provide the final list, e.g., [1, 2, 3], [1.2, 1.3, 1.4], at the end.\n
30 | Question: Between which two years does the line graph saw its maximum peak?\n
31 | Model response: The line graph saw its maximum peak between 2007 and 2008.\n
32 | Extracted answer: [2007, 2008]
33 | """
34 |
35 | example_5 = """
36 | Hint: Please answer the question and provide the correct option letter, e.g., A, B, C, D, at the end.\n
37 | Question: What fraction of the shape is blue?\n
38 | Choices: (A) 3/11 (B) 8/11 (C) 6/11 (D) 3/5\n
39 | Model response: The correct answer is (B) 8/11.\n
40 | Extracted answer: B
41 | """
42 | return [example_1,example_2,example_3,example_4,example_5]
43 |
44 |
45 | def build_mathvista_gpt4_prompt(line):
46 | task_description = """ Please read the following example. Then extract the answer from the model response and type it at the end of the prompt.\n"""
47 | question = line['question']
48 | prediction = str(line['prediction'])
49 | prompt = task_description
50 | examples = get_gpt4_ICE()
51 | for example in examples:
52 | prompt += example + '\n'
53 | prompt += question + '\n'
54 |     prompt += 'Model response: ' + prediction + '\n'
55 | prompt += 'Extracted answer:'
56 | return prompt
57 |
58 | def list_to_dict(lst):
59 | return {chr(65 + i): val for i, val in enumerate(lst)}
60 |
61 | def post_check(line, prefetch=False):
62 | res = None
63 | ans = line['answer']
64 | response = line['prediction'] if prefetch else line['res']
65 | try:
66 | if line['question_type'] == 'multi_choice':
67 | ans = line['answer_option']
68 | choices = list_to_dict(eval(line['choices']))
69 | res = can_infer(response, choices)
70 | if prefetch:
71 | return res
72 | else:
73 | if line['answer_type'] == 'integer':
74 | res = int(response)
75 | ans = int(line['answer'])
76 | elif line['answer_type'] == 'float':
77 | res = float(response)
78 | ans = float(line['answer'])
79 | else:
80 | res = str(res)
81 | ans = str(ans)
82 | except ValueError:
83 | pass
84 |
85 | if res == ans:
86 | return res
87 | else:
88 | return False
89 |
90 | def MathVista_auxeval(model, line):
91 | prompt = build_mathvista_gpt4_prompt(line)
92 | log = ''
93 | retry = 5
94 | if post_check(line, prefetch=True):
95 | res = post_check(line, prefetch=True)
96 | return dict(log='Prefetch succeed', res=res)
97 | for i in range(retry):
98 | prediction = line['prediction']
99 | res = model.generate(prompt, temperature=i * 0.5)
100 | if res is None:
101 | log += f'Try {i}: output is {prediction}, failed to parse.\n'
102 | else:
103 | log += 'Succeed'
104 |             return dict(log=log, res=res)
105 | log += 'All 5 retries failed.\n'
106 | return dict(log=log, res='')
107 |
108 | def MathVista_acc(result_file):
109 | data = load(result_file)
110 | tot = defaultdict(lambda: 0)
111 | fetch = defaultdict(lambda: 0)
112 | hit = defaultdict(lambda: 0)
113 | lt = len(data)
114 | skill_list = []
115 | for i in range(lt):
116 | item = data.iloc[i]
117 | index = item['index']
118 | cate = item['task']
119 | tot['Overall'] += 1
120 | try:
121 | skills = eval(item['skills'])
122 | except SyntaxError:
123 | skills = [item['skills']]
124 | for skill in skills:
125 | if skill not in skill_list:
126 | skill_list.append(skill)
127 | tot[skill] += 1
128 | tot[cate] += 1
129 | if item['log'] == 'Prefetch succeed':
130 | fetch['Overall'] += 1
131 | fetch[cate] += 1
132 | for skill in skills:
133 | fetch[skill] += 1
134 | if post_check(item, prefetch=False):
135 | hit['Overall'] += 1
136 | hit[cate] += 1
137 | for skill in skills:
138 | hit[skill] += 1
139 |
140 | res = defaultdict(list)
141 | for k in tot.keys():
142 | res['Task&Skill'].append(k)
143 | res['tot'].append(tot[k])
144 | res['prefetch'].append(fetch[k])
145 | res['hit'].append(hit[k])
146 | res['prefetch_rate'].append(fetch[k] / tot[k] * 100)
147 | res['acc'].append(hit[k] / tot[k] * 100)
148 | res = pd.DataFrame(res)
149 | return res
150 |
151 | def MathVista_eval(eval_file, model='gpt-4-turbo', nproc=4, verbose=False):
152 | logger = get_logger('Evaluation')
153 |
154 | suffix = eval_file.split('.')[-1]
155 | storage = eval_file.replace(f'.{suffix}', f'_{model}.xlsx')
156 | tmp_file = eval_file.replace(f'.{suffix}', f'_{model}.pkl')
157 | if osp.exists(storage):
158 | logger.warning(f"GPT scoring file {storage} already exists, will reuse it in MathVista_eval. ")
159 | else:
160 | data = load(eval_file)
161 | gpt_version = model
162 | model = build_judge(gpt_version, verbose=verbose, max_tokens=128, retry=10)
163 |
164 | lt = len(data)
165 | lines = [data.iloc[i] for i in range(lt)]
166 | tups = [(model, line) for line in lines]
167 | indices = [line['index'] for line in lines]
168 |
169 | ans = {}
170 | if osp.exists(tmp_file):
171 | ans = load(tmp_file)
172 | tups = [x for x, i in zip(tups, indices) if i not in ans]
173 | indices = [i for i in indices if i not in ans]
174 |
175 | if len(indices):
176 | new_results = track_progress_rich(
177 | MathVista_auxeval, tups, nproc=nproc, chunksize=nproc,
178 | keys=indices, save=tmp_file)
179 | ans = load(tmp_file)
180 | for k, v in zip(indices, new_results):
181 | assert k in ans
182 | assert ans[k]['log'] == v['log'] and ans[k]['res'] == v['res']
183 |
184 | log_map, res_map = {}, {}
185 | all_inds = [line['index'] for line in lines]
186 | for k in all_inds:
187 | log_map[k] = ans[k]['log']
188 | res_map[k] = ans[k]['res']
189 | data['res'] = [res_map[idx] for idx in data['index']]
190 | data['log'] = [log_map[idx] for idx in data['index']]
191 | dump(data, storage)
192 |
193 | score = MathVista_acc(storage)
194 | score_pth = storage.replace('.xlsx','_score.csv')
195 |
196 | dump(score,score_pth)
197 | logger.info(f'MathVista_eval successfully finished evaluating {eval_file}, results saved in {score_pth}')
198 | logger.info(f'Score: ')
199 | logger.info(score)
200 |
201 | def parse_args():
202 | parser = argparse.ArgumentParser(description="Inference LLM Answers. ")
203 | parser.add_argument("data", type=str, help="The question set for inference, in excel / tsv / json format. ")
204 | parser.add_argument(
205 | "--model",
206 | type=str,
207 | help="The LLM (GPT) used for inference. ",
208 | default="gpt-4-turbo",
209 | choices=['gpt-4-0613', 'gpt-4-turbo', 'chatgpt-1106', 'chatgpt-0613'])
210 | parser.add_argument("--nproc", type=int, default=4)
211 | parser.add_argument("--verbose", action='store_true')
212 | args = parser.parse_args()
213 | return args
214 |
215 | if __name__ == '__main__':
216 | args = parse_args()
217 | MathVista_eval(eval_file=args.data, model=args.model, nproc=args.nproc, verbose=args.verbose)
218 |
219 |
--------------------------------------------------------------------------------
/vlmeval/evaluate/misc.py:
--------------------------------------------------------------------------------
1 | import os
2 | from vlmeval.api import OpenAIWrapper, OpenAIWrapperInternal
3 |
4 | INTERNAL = os.environ.get('INTERNAL', 0)
5 |
6 | def build_judge(version, **kwargs):
7 | model_map = {
8 | 'gpt-4-turbo': 'gpt-4-1106-preview',
9 | 'gpt-4-0613': 'gpt-4-0613',
10 | 'gpt-4-0314': 'gpt-4-0314',
11 | 'chatgpt-1106': 'gpt-3.5-turbo-1106',
12 | 'chatgpt-0613': 'gpt-3.5-turbo-0613'
13 | }
14 | model_version = model_map[version]
15 | if INTERNAL:
16 | model = OpenAIWrapperInternal(model_version, **kwargs)
17 | else:
18 | model = OpenAIWrapper(model_version, **kwargs)
19 | return model
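
# Usage sketch (illustrative, not part of the original module): build_judge returns an
# OpenAIWrapper (or OpenAIWrapperInternal when INTERNAL is set), and the evaluators in
# this package query it through its generate() method. A valid OpenAI key must be
# configured for the wrapper.
if __name__ == '__main__':
    judge = build_judge('gpt-4-turbo', temperature=0.2, retry=10, verbose=False)
    print(judge.generate('Reply with the single word: ready'))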
--------------------------------------------------------------------------------
/vlmeval/evaluate/mmmu_eval/README.md:
--------------------------------------------------------------------------------
1 | # Evaluation Guidelines
2 | We provide detailed instructions for evaluation.
3 | To execute our evaluation script, please ensure that the structure of your model outputs is the same as ours.
4 |
5 | We provide two options:
6 | 1. Evaluation only: you can parse the response on your own and simply provide one file with all the final predictions.
7 | 2. Parse and evaluation: you can provide the raw model responses in the output format shown below, and our script will parse and evaluate them.
8 |
9 | ## Evaluation Only
10 | If you want to use your own parsing logic and *only provide the final answer*, you can use `main_eval_only.py`.
11 |
12 | You can provide all the outputs in *one file* in the following format:
13 |
14 | ```
15 | {
16 | "validation_Accounting_1": "D", # strictly "A", "B", "C", "D" for multi-choice question
17 | "validation_Architecture_and_Engineering_14": "0.0", # any string response for open question.
18 | ...
19 | }
20 | ```
21 | Then run eval_only with:
22 | ```
23 | python main_eval_only.py --output_path ./example_outputs/llava1.5_13b/total_val_output.json
24 | ```
25 |
26 | Please refer to the [example output](https://github.com/MMMU-Benchmark/MMMU/blob/main/eval/example_outputs/llava1.5_13b/total_val_output.json) for the detailed prediction file format; a short sketch for assembling such a file is shown below.
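
A minimal sketch for assembling such a file from your own parsed answers (the `preds` dict below is a hypothetical placeholder for your predictions, keyed by MMMU question id):

```
import json

# Hypothetical parsed predictions, keyed by MMMU question id.
preds = {
    "validation_Accounting_1": "D",                      # multiple-choice: the option letter
    "validation_Architecture_and_Engineering_14": "0.0", # open question: any string answer
}

# Write the single prediction file expected by main_eval_only.py.
with open("total_val_output.json", "w") as f:
    json.dump(preds, f, indent=2)
```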
27 |
28 |
29 | ## Parse and Evaluation
30 | You can also provide the raw responses and run `main_parse_and_eval.py` to use our answer parsing and evaluation pipeline as follows:
31 |
32 | ### Output folder structure
33 |
34 | ```
35 | └── model_name
36 | ├── category_name (e.g., Accounting)
37 | │ ├── output.json
38 | └── category_name (e.g., Electronics)
39 | ├── output.json
40 | ...
41 | ```
42 |
43 | ### Output file
44 | Each `output.json` contains a list of dicts, one per instance to evaluate:
45 | ```
46 | [
47 | {
48 | "id": "validation_Electronics_28",
49 | "question_type": "multiple-choice",
50 | "answer": "A", # given answer
51 |         "all_choices": [ # created using `get_multi_choice_info`
52 | "A",
53 | "B",
54 | "C",
55 | "D"
56 | ],
57 |         "index2ans": { # created using `get_multi_choice_info`
58 | "A": "75 + 13.3 cos(250t - 57.7°)V",
59 | "B": "75 + 23.3 cos(250t - 57.7°)V",
60 | "C": "45 + 3.3 cos(250t - 57.7°)V",
61 | "D": "95 + 13.3 cos(250t - 57.7°)V"
62 | },
63 | "response": "B" # model response
64 | },
65 | {
66 | "id": "validation_Electronics_29",
67 | "question_type": "short-answer",
68 | "answer": "30", # given answer
69 | "response": "36 watts" # model response
70 | },
71 | ...
72 | ]
73 | ```
74 |
75 | ### Evaluation
76 | ```
77 | python main_parse_and_eval.py --path ./example_outputs/llava1.5_13b --subject ALL # all subjects
78 |
79 | # OR you can specify a single subject for evaluation
80 |
81 | python main_parse_and_eval.py --path ./example_outputs/llava1.5_13b --subject elec # short name for Electronics; use --help to list all short names
82 |
83 | ```
84 |
85 | `main_parse_and_eval.py` will generate `parsed_output.json` and `result.json` in each category subfolder, alongside the corresponding `output.json`.
86 |
87 | ```
88 | ├── Accounting
89 | │ ├── output.json
90 | │ ├── parsed_output.json
91 | │ └── result.json
92 | └── Electronics
93 | ├── output.json
94 | ├── parsed_output.json
95 | └── result.json
96 | ...
97 | ```
98 |
99 | ### Print Results
100 | You can print the results locally. (Install `tabulate` first via `pip install tabulate` if you haven't already.)
101 | ```
102 | python print_results.py --path ./example_outputs/llava1.5_13b
103 | # Results may differ slightly due to the random selection applied to failed responses
104 | ```
105 |
106 |
107 |
108 | ## Run LLaVA
109 | If you want to reproduce the results of some of the models, see `run_llava.py` as an example.
110 |
111 | After setting up the environment following the [LLaVA official repo](https://github.com/haotian-liu/LLaVA) and installing the Hugging Face `datasets` package, you can run LLaVA via the following command:
112 |
113 | ```
114 | CUDA_VISIBLE_DEVICES=0 nohup python run_llava.py \
115 | --output_path example_outputs/llava1.5_13b_val.json \
116 | --model_path liuhaotian/llava-v1.5-13b \
117 | --config_path configs/llava1.5.yaml
118 | ```
119 |
120 | Then you can evaluate the results via the first pipeline (Evaluation Only) described above, as illustrated below.
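
Assuming the run above wrote its predictions to `example_outputs/llava1.5_13b_val.json` in the single-file format described under Evaluation Only, the command would look something like:

```
python main_eval_only.py --output_path example_outputs/llava1.5_13b_val.json
```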
121 |
--------------------------------------------------------------------------------
/vlmeval/evaluate/mmmu_eval/__init__.py:
--------------------------------------------------------------------------------
1 | from .mmmu_eval_script import mmmu_eval_func
2 |
--------------------------------------------------------------------------------
/vlmeval/evaluate/mmmu_eval/configs/llava1.5.yaml:
--------------------------------------------------------------------------------
1 | task_instructions:
2 | - ""
3 | multi_choice_example_format:
4 | - "{}
5 |
6 | {}
7 |
8 | Answer with the option's letter from the given choices directly."
9 |
10 | short_ans_example_format:
11 | - "{}
12 |
13 | Answer the question using a single word or phrase."
14 | temperature:
15 | - 0
--------------------------------------------------------------------------------
/vlmeval/evaluate/mmmu_eval/example_outputs/llava1.5_13b/Electronics/output.json:
--------------------------------------------------------------------------------
1 | [
2 | {
3 | "id": "validation_Electronics_1",
4 | "question_type": "multiple-choice",
5 | "answer": "A",
6 | "all_choices": [
7 | "A",
8 | "B"
9 | ],
10 | "index2ans": {
11 | "A": "yes, saturation",
12 | "B": "no, not in saturation"
13 | },
14 | "response": "B"
15 | },
16 | {
17 | "id": "validation_Electronics_2",
18 | "question_type": "short-answer",
19 | "answer": "2.83",
20 | "response": "0.5"
21 | },
22 | {
23 | "id": "validation_Electronics_3",
24 | "question_type": "multiple-choice",
25 | "answer": "C",
26 | "all_choices": [
27 | "A",
28 | "B",
29 | "C",
30 | "D"
31 | ],
32 | "index2ans": {
33 | "A": "t + (1 / 10) (e^{-20t} - 1) V",
34 | "B": "t + (1 / 20) (e^{-10t} - 1) V",
35 | "C": "t + (1 / 10) (e^{-10t} - 1) V",
36 | "D": "t - (1 / 10) (e^{-10t} - 1) V"
37 | },
38 | "response": "C"
39 | },
40 | {
41 | "id": "validation_Electronics_4",
42 | "question_type": "short-answer",
43 | "answer": "8.4",
44 | "response": "1.5"
45 | },
46 | {
47 | "id": "validation_Electronics_5",
48 | "question_type": "short-answer",
49 | "answer": "62.6",
50 | "response": "500"
51 | },
52 | {
53 | "id": "validation_Electronics_6",
54 | "question_type": "short-answer",
55 | "answer": "71.6",
56 | "response": "5"
57 | },
58 | {
59 | "id": "validation_Electronics_7",
60 | "question_type": "multiple-choice",
61 | "answer": "C",
62 | "all_choices": [
63 | "A",
64 | "B",
65 | "C",
66 | "D"
67 | ],
68 | "index2ans": {
69 | "A": "$\\sqrt(2) cos[(1/3)t]$",
70 | "B": "$\\sqrt(3) cos[(2/3)t]$",
71 | "C": "$\\sqrt(2) cos[(2/3)t]$",
72 | "D": "$\\sqrt(3) cos[(4/3)t]$"
73 | },
74 | "response": "A"
75 | },
76 | {
77 | "id": "validation_Electronics_8",
78 | "question_type": "multiple-choice",
79 | "answer": "B",
80 | "all_choices": [
81 | "A",
82 | "B",
83 | "C",
84 | "D"
85 | ],
86 | "index2ans": {
87 | "A": "(4 / \\pi) (sin \\pit + (1 / 2) sin 3\\pit + (1 / 4) sin 5\\pit + ....).",
88 | "B": "(4 / \\pi) (sin \\pit + (1 / 3) sin 3\\pit + (1 / 5) sin 5\\pit + ....).",
89 | "C": "(4 / \\pi) (sin \\pit + (1 / 2) sin 2\\pit + (1 / 4) sin 4\\pit + ....).",
90 | "D": "(4 / \\pi) (sin \\pit + (1 / 3) sin 2\\pit + (1 / 5) sin 4\\pit + ....)."
91 | },
92 | "response": "A"
93 | },
94 | {
95 | "id": "validation_Electronics_9",
96 | "question_type": "multiple-choice",
97 | "answer": "C",
98 | "all_choices": [
99 | "A",
100 | "B",
101 | "C",
102 | "D"
103 | ],
104 | "index2ans": {
105 | "A": "-2[sin t + (1 / 2) sin 25 + (1 / 4) sin 3t + ...]",
106 | "B": "-2[sin t + (1 / 2) sin 30 + (1 / 3) sin 3t + ...]",
107 | "C": "-2[sin t + (1 / 2) sin 25 + (1 / 3) sin 3t + ...]",
108 | "D": "-2[sin t + (1 / 3) sin 25 + (1 / 3) sin 3t + ...]"
109 | },
110 | "response": "A"
111 | },
112 | {
113 | "id": "validation_Electronics_10",
114 | "question_type": "multiple-choice",
115 | "answer": "A",
116 | "all_choices": [
117 | "A",
118 | "B",
119 | "C",
120 | "D"
121 | ],
122 | "index2ans": {
123 | "A": "0.125 + j0.330",
124 | "B": "0.15 + j0.330",
125 | "C": "0.125 + j0.390",
126 | "D": "0.121 + j0.380"
127 | },
128 | "response": "A"
129 | },
130 | {
131 | "id": "validation_Electronics_11",
132 | "question_type": "short-answer",
133 | "answer": "0.3",
134 | "response": "0"
135 | },
136 | {
137 | "id": "validation_Electronics_12",
138 | "question_type": "short-answer",
139 | "answer": "10",
140 | "response": "24"
141 | },
142 | {
143 | "id": "validation_Electronics_13",
144 | "question_type": "multiple-choice",
145 | "answer": "C",
146 | "all_choices": [
147 | "A",
148 | "B",
149 | "C",
150 | "D"
151 | ],
152 | "index2ans": {
153 | "A": "[(1 - e^{-s(T/2)}) / {s(1 - e^{-sT})}]",
154 | "B": "[(2 - e^{-s(T)}) / {s(1 - e^{-sT})}]",
155 | "C": "[(1 - e^{-s(T/2)}) / {s(1 - e^{-sT})}]",
156 | "D": "[(1 - e^{-s(T/3)}) / {s(1 - e^{-sT})}]"
157 | },
158 | "response": "A"
159 | },
160 | {
161 | "id": "validation_Electronics_14",
162 | "question_type": "multiple-choice",
163 | "answer": "C",
164 | "all_choices": [
165 | "A",
166 | "B",
167 | "C",
168 | "D"
169 | ],
170 | "index2ans": {
171 | "A": "4t^2 [u(t) - u(t - 5)] + 20[u(t - 2) - u(t - 5)] + 15(t - 7)[u(t - 5) - u(t - 7)]",
172 | "B": "4t^2 [u(t) - u(t - 2)] + 20[u(t - 2) - u(t - 5)] + 15(t - 7)[u(t - 5) - u(t - 7)]",
173 | "C": "5t^2 [u(t) - u(t - 2)] + 20[u(t - 2) - u(t - 5)] + 15(t - 7)[u(t - 5) - u(t - 7)]",
174 | "D": "5t^2 [u(t) - u(t - 2)] + 20[u(t - 2) - u(t - 3)] + 15(t - 7)[u(t - 5) - u(t - 7)]"
175 | },
176 | "response": "A"
177 | },
178 | {
179 | "id": "validation_Electronics_15",
180 | "question_type": "multiple-choice",
181 | "answer": "A",
182 | "all_choices": [
183 | "A",
184 | "B",
185 | "C",
186 | "D"
187 | ],
188 | "index2ans": {
189 | "A": "$[1 + sech^2 (sin t)] cos t$",
190 | "B": "$[1 + sech^2 (sin t)] sin t$",
191 | "C": "$[1 - sech^2 (sin t)] sin t$",
192 | "D": "$[1 - sech^2 (cos t)] sin t$"
193 | },
194 | "response": "A"
195 | },
196 | {
197 | "id": "validation_Electronics_16",
198 | "question_type": "short-answer",
199 | "answer": "20",
200 | "response": "100"
201 | },
202 | {
203 | "id": "validation_Electronics_17",
204 | "question_type": "multiple-choice",
205 | "answer": "A",
206 | "all_choices": [
207 | "A",
208 | "B",
209 | "C",
210 | "D"
211 | ],
212 | "index2ans": {
213 | "A": "10e^{-0.8t} V",
214 | "B": "-5e^{-0.8t} V",
215 | "C": "-2e^{-1t} V",
216 | "D": "-6e^{-2t} V"
217 | },
218 | "response": "A"
219 | },
220 | {
221 | "id": "validation_Electronics_18",
222 | "question_type": "multiple-choice",
223 | "answer": "B",
224 | "all_choices": [
225 | "A",
226 | "B",
227 | "C",
228 | "D"
229 | ],
230 | "index2ans": {
231 | "A": "2 e^{-2t} u(t)",
232 | "B": "3 e^{-2t} u(t)",
233 | "C": "2.2 e^{-2t} u(t)",
234 | "D": "3 e^{-3t} u(t)"
235 | },
236 | "response": "A"
237 | },
238 | {
239 | "id": "validation_Electronics_19",
240 | "question_type": "multiple-choice",
241 | "answer": "A",
242 | "all_choices": [
243 | "A",
244 | "B",
245 | "C",
246 | "D"
247 | ],
248 | "index2ans": {
249 | "A": "-90 cos(t) V",
250 | "B": "-90 cos(2t) V",
251 | "C": "90 cos(2t) V",
252 | "D": "90 sin(2t) V"
253 | },
254 | "response": "B"
255 | },
256 | {
257 | "id": "validation_Electronics_20",
258 | "question_type": "multiple-choice",
259 | "answer": "C",
260 | "all_choices": [
261 | "A",
262 | "B",
263 | "C",
264 | "D"
265 | ],
266 | "index2ans": {
267 | "A": "$3 \\pi x 10^-5 A$",
268 | "B": "$\\pi x 10^-5 A$",
269 | "C": "$2 \\pi x 10^-5 A$",
270 | "D": "$\\pi x 10^-4 A$"
271 | },
272 | "response": "B"
273 | },
274 | {
275 | "id": "validation_Electronics_21",
276 | "question_type": "short-answer",
277 | "answer": "0.9965",
278 | "response": "0"
279 | },
280 | {
281 | "id": "validation_Electronics_22",
282 | "question_type": "multiple-choice",
283 | "answer": "C",
284 | "all_choices": [
285 | "A",
286 | "B",
287 | "C",
288 | "D"
289 | ],
290 | "index2ans": {
291 | "A": "[70.7 cos (20t - 60^{\\circ})] u(t) V",
292 | "B": "[70.7 cos (10t - 45^{\\circ})] u(t) V",
293 | "C": "[70.7 cos (20t - 45^{\\circ})] u(t) V",
294 | "D": "[70.7 cos (20t - 90^{\\circ})] u(t) V"
295 | },
296 | "response": "C"
297 | },
298 | {
299 | "id": "validation_Electronics_23",
300 | "question_type": "multiple-choice",
301 | "answer": "A",
302 | "all_choices": [
303 | "A",
304 | "B",
305 | "C",
306 | "D"
307 | ],
308 | "index2ans": {
309 | "A": "$C_0 sin wt cos t + 2w C_0 cos wt (1 + 0.5 sin t)$",
310 | "B": "$C_0 cos wt cos t + 2w C_0 sin wt (1 + 0.5 sin t)$",
311 | "C": "$C_0 sin wt cos t + 4w C_0 cos wt (1 + sin t)$",
312 | "D": "$C_0 sin wt cos t + 2w C_0 cos wt (1 - sin t)$"
313 | },
314 | "response": "A"
315 | },
316 | {
317 | "id": "validation_Electronics_24",
318 | "question_type": "short-answer",
319 | "answer": "551",
320 | "response": "100"
321 | },
322 | {
323 | "id": "validation_Electronics_25",
324 | "question_type": "short-answer",
325 | "answer": "50",
326 | "response": "100"
327 | },
328 | {
329 | "id": "validation_Electronics_26",
330 | "question_type": "short-answer",
331 | "answer": "6.333",
332 | "response": "100"
333 | },
334 | {
335 | "id": "validation_Electronics_27",
336 | "question_type": "short-answer",
337 | "answer": "-120",
338 | "response": "0"
339 | },
340 | {
341 | "id": "validation_Electronics_28",
342 | "question_type": "multiple-choice",
343 | "answer": "A",
344 | "all_choices": [
345 | "A",
346 | "B",
347 | "C",
348 | "D"
349 | ],
350 | "index2ans": {
351 | "A": "75 + 13.3 cos(250t - 57.7\u00b0)V",
352 | "B": "75 + 23.3 cos(250t - 57.7\u00b0)V",
353 | "C": "45 + 3.3 cos(250t - 57.7\u00b0)V",
354 | "D": "95 + 13.3 cos(250t - 57.7\u00b0)V"
355 | },
356 | "response": "B"
357 | },
358 | {
359 | "id": "validation_Electronics_29",
360 | "question_type": "short-answer",
361 | "answer": "30",
362 | "response": "36 watts"
363 | },
364 | {
365 | "id": "validation_Electronics_30",
366 | "question_type": "short-answer",
367 | "answer": "-141",
368 | "response": "0"
369 | }
370 | ]
--------------------------------------------------------------------------------
/vlmeval/evaluate/mmmu_eval/main_eval_only.py:
--------------------------------------------------------------------------------
1 | """Parse and Evaluate"""
2 | import os
3 | import json
4 |
5 | import pdb
6 | from argparse import ArgumentParser
7 |
8 | from utils.data_utils import save_json, CAT_SHORT2LONG, DOMAIN_CAT2SUB_CAT
9 | from utils.eval_utils import evaluate, parse_multi_choice_response, parse_open_response, calculate_ins_level_acc
10 |
11 |
12 | if __name__ == '__main__':
13 |
14 | parser = ArgumentParser()
15 | parser.add_argument('--output_path', type=str, default="./example_outputs/qwen_vl/total_val_output.json", help="The path to model output file.")
16 | parser.add_argument('--answer_path', type=str, default="./answer_dict_val.json", help="Answer file path.")
17 | args = parser.parse_args()
18 |
19 | output_dict = json.load(open(args.output_path))
20 | answer_dict = json.load(open(args.answer_path))
21 |
22 | # group by category
23 | output_dict_w_cat = {}
24 | for data_id, parsed_pred in output_dict.items():
25 | category = "_".join(data_id.split("_")[1:-1])
26 | if category not in output_dict_w_cat:
27 | output_dict_w_cat.update({category: {}})
28 | output_dict_w_cat[category].update({data_id: parsed_pred})
29 |
30 | # group by category
31 | answer_dict_w_cat = {}
32 | for data_id, parsed_pred in answer_dict.items():
33 | category = "_".join(data_id.split("_")[1:-1])
34 | if category not in answer_dict_w_cat:
35 | answer_dict_w_cat.update({category: {}})
36 | answer_dict_w_cat[category].update({data_id: parsed_pred})
37 |
38 | evaluation_result = {}
39 |
40 | for category in CAT_SHORT2LONG.values():
41 | print("Evaluating: {}".format(category))
42 | # get cat_outputs and cat_answers
43 | try:
44 | cat_outputs = output_dict_w_cat[category]
45 | cat_answers = answer_dict_w_cat[category]
46 | except KeyError:
47 | print("Skipping {} for not found".format(category))
48 | continue
49 |
50 | exampels_to_eval = []
51 | for data_id, parsed_pred in cat_outputs.items():
52 | question_type = cat_answers[data_id]['question_type']
53 | if question_type != 'multiple-choice':
54 | parsed_pred = parse_open_response(parsed_pred) # mainly for type consistency (make it number, etc.)
55 | else:
56 | parsed_pred = parsed_pred
57 |
58 | exampels_to_eval.append({
59 | "id": data_id,
60 | "question_type": question_type,
61 | "answer": cat_answers[data_id]['ground_truth'],
62 | "parsed_pred": parsed_pred
63 | })
64 |
65 | judge_dict, metric_dict = evaluate(exampels_to_eval)
66 | metric_dict.update({"num_example": len(exampels_to_eval)})
67 |
68 | evaluation_result[category] = metric_dict
69 |
70 | printable_results = {}
71 | # pdb.set_trace()
72 | # add domain Subject
73 | for domain, in_domain_cats in DOMAIN_CAT2SUB_CAT.items():
74 | in_domain_cat_results = {}
75 | for cat_name in in_domain_cats: # use the order in DOMAIN_CAT2SUB_CAT
76 | if cat_name in evaluation_result.keys():
77 | in_domain_cat_results[cat_name] = evaluation_result[cat_name]
78 | else:
79 | pass
80 | in_domain_ins_acc = calculate_ins_level_acc(in_domain_cat_results)
81 | in_domain_data_num = sum([cat_results['num_example'] for cat_results in in_domain_cat_results.values()])
82 | printable_results['Overall-' + domain] = {"num": int(in_domain_data_num),
83 | "acc": round(in_domain_ins_acc, 3)
84 | }
85 | # add sub category
86 | for cat_name, cat_results in in_domain_cat_results.items():
87 | printable_results[cat_name] = {"num": int(cat_results['num_example']),
88 | "acc": round(cat_results['acc'], 3)
89 | }
90 |
91 | # table.append(["-----------------------------", "-----", "----"])
92 | all_ins_acc = calculate_ins_level_acc(evaluation_result)
93 | printable_results['Overall'] = {"num": sum([cat_results['num_example'] for cat_results in evaluation_result.values()]),
94 | "acc": round(all_ins_acc, 3)
95 | }
96 |
97 | print(printable_results)
98 |
99 |
--------------------------------------------------------------------------------
/vlmeval/evaluate/mmmu_eval/main_parse_and_eval.py:
--------------------------------------------------------------------------------
1 | """Parse and Evaluate"""
2 | import os
3 | import json
4 | from argparse import ArgumentParser
5 |
6 | from utils.data_utils import save_json, CAT_SHORT2LONG
7 | from utils.eval_utils import evaluate, parse_multi_choice_response, parse_open_response
8 |
9 |
10 | if __name__ == '__main__':
11 |
12 | parser = ArgumentParser()
13 | parser.add_argument('--path', type=str, default="./example_outputs/llava1.5_13b", help="The path to model output directory.")
14 | parser.add_argument('--subject', nargs='+',
15 |                         help=f'The name of the MMMU sub-category. Available: {CAT_SHORT2LONG.keys()} or ALL')
16 |
17 | args = parser.parse_args()
18 | if args.subject[0] == 'ALL':
19 | args.subject = CAT_SHORT2LONG.keys()
20 |
21 | ex_output_path = os.path.join(args.path)
22 |
23 | all_results = {}
24 | for cat_short in args.subject:
25 | category = CAT_SHORT2LONG[cat_short]
26 | print("Evaluating: {}".format(category))
27 | if category not in os.listdir(ex_output_path):
28 | print("Skipping {} for not found".format(category))
29 | else:
30 | cat_folder_path = os.path.join(ex_output_path, category)
31 | cat_outputs = json.load(open(os.path.join(cat_folder_path, 'output.json')))
32 | # Evaluation
33 | eval_samples = []
34 | for cat_output in cat_outputs:
35 | response = cat_output['response']
36 | if cat_output['question_type'] == 'multiple-choice':
37 | all_choices = cat_output['all_choices']
38 | index2ans = cat_output['index2ans']
39 | parsed_pred = parse_multi_choice_response(response, all_choices, index2ans)
40 | eval_samples.append(
41 | {
42 | 'id': cat_output['id'],
43 | 'question_type': cat_output['question_type'],
44 | 'answer': cat_output['answer'], # the content in option, not answer index.
45 | 'response': response,
46 | 'parsed_pred': parsed_pred,
47 | 'index2ans': index2ans,
48 | }
49 | )
50 | else: # open
51 | parsed_pred = parse_open_response(response)
52 | eval_samples.append(
53 | {
54 | 'id': cat_output['id'],
55 | 'question_type': cat_output['question_type'],
56 | 'answer': cat_output['answer'],
57 | 'response': response,
58 | 'parsed_pred': parsed_pred,
59 | }
60 | )
61 |
62 | print("Num of valid samples: {}, Expected Num: {}".format(len(eval_samples), len(cat_outputs)))
63 |
64 | judge_dict, metric_dict = evaluate(eval_samples)
65 | metric_dict.update({"num_example": len(eval_samples)})
66 | for eval_sample in eval_samples:
67 | eval_sample.update({"judge": judge_dict[eval_sample['id']]})
68 |
69 | save_json(os.path.join(cat_folder_path, 'parsed_output.json'), eval_samples)
70 | save_json(os.path.join(cat_folder_path, 'result.json'), metric_dict)
71 |
--------------------------------------------------------------------------------
/vlmeval/evaluate/mmmu_eval/mmmu_eval_script.py:
--------------------------------------------------------------------------------
1 | """Parse and Evaluate"""
2 | import os
3 | import json
4 |
5 | import pdb
6 |
7 | from .utils.data_utils import save_json, CAT_SHORT2LONG, DOMAIN_CAT2SUB_CAT
8 | from .utils.eval_utils import evaluate, parse_multi_choice_response, parse_open_response, calculate_ins_level_acc
9 |
10 |
11 | def mmmu_eval_func(output_path):
12 | #
13 | answer_path = './vlmeval/evaluate/mmmu_eval/answer_dict_val.json'
14 | output_dict = json.load(open(output_path))
15 | answer_dict = json.load(open(answer_path))
16 |
17 | # group by category
18 | output_dict_w_cat = {}
19 | for data_id, parsed_pred in output_dict.items():
20 | category = "_".join(data_id.split("_")[1:-1])
21 | if category not in output_dict_w_cat:
22 | output_dict_w_cat.update({category: {}})
23 | output_dict_w_cat[category].update({data_id: parsed_pred})
24 |
25 | # group by category
26 | answer_dict_w_cat = {}
27 | for data_id, parsed_pred in answer_dict.items():
28 | category = "_".join(data_id.split("_")[1:-1])
29 | if category not in answer_dict_w_cat:
30 | answer_dict_w_cat.update({category: {}})
31 | answer_dict_w_cat[category].update({data_id: parsed_pred})
32 |
33 | evaluation_result = {}
34 |
35 | for category in CAT_SHORT2LONG.values():
36 | print("Evaluating: {}".format(category))
37 | # get cat_outputs and cat_answers
38 | try:
39 | cat_outputs = output_dict_w_cat[category]
40 | cat_answers = answer_dict_w_cat[category]
41 | except KeyError:
42 | print("Skipping {} for not found".format(category))
43 | continue
44 |
45 | exampels_to_eval = []
46 | for data_id, parsed_pred in cat_outputs.items():
47 | question_type = cat_answers[data_id]['question_type']
48 | if question_type != 'multiple-choice':
49 | parsed_pred = parse_open_response(parsed_pred) # mainly for type consistency (make it number, etc.)
50 | else:
51 | parsed_pred = parsed_pred
52 |
53 | exampels_to_eval.append({
54 | "id": data_id,
55 | "question_type": question_type,
56 | "answer": cat_answers[data_id]['ground_truth'],
57 | "parsed_pred": parsed_pred
58 | })
59 |
60 | judge_dict, metric_dict = evaluate(exampels_to_eval)
61 | metric_dict.update({"num_example": len(exampels_to_eval)})
62 |
63 | evaluation_result[category] = metric_dict
64 |
65 | printable_results = {}
66 | # pdb.set_trace()
67 | # add domain Subject
68 | for domain, in_domain_cats in DOMAIN_CAT2SUB_CAT.items():
69 | in_domain_cat_results = {}
70 | for cat_name in in_domain_cats: # use the order in DOMAIN_CAT2SUB_CAT
71 | if cat_name in evaluation_result.keys():
72 | in_domain_cat_results[cat_name] = evaluation_result[cat_name]
73 | else:
74 | pass
75 | in_domain_ins_acc = calculate_ins_level_acc(in_domain_cat_results)
76 | in_domain_data_num = sum([cat_results['num_example'] for cat_results in in_domain_cat_results.values()])
77 | printable_results['Overall-' + domain] = {"num": int(in_domain_data_num),
78 | "acc": round(in_domain_ins_acc, 3)
79 | }
80 | # add sub category
81 | for cat_name, cat_results in in_domain_cat_results.items():
82 | printable_results[cat_name] = {"num": int(cat_results['num_example']),
83 | "acc": round(cat_results['acc'], 3)
84 | }
85 |
86 | # table.append(["-----------------------------", "-----", "----"])
87 | all_ins_acc = calculate_ins_level_acc(evaluation_result)
88 | printable_results['Overall'] = {"num": sum([cat_results['num_example'] for cat_results in evaluation_result.values()]),
89 | "acc": round(all_ins_acc, 3)
90 | }
91 |
92 | print(printable_results)
93 |
94 |
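# Usage sketch for mmmu_eval_func (illustrative only; the prediction file name and ids below are
# made up). The expected input is a flat JSON dict mapping MMMU sample ids of the form
# "validation_<Subject>_<idx>" to raw predictions: an option letter for multiple-choice questions,
# free-form text for open questions. Run from the repo root so the relative answer_dict_val.json
# path above resolves.
#
#   import json
#   preds = {
#       "validation_Math_1": "B",
#       "validation_Physics_3": "The answer is 9.8 m/s^2",
#   }
#   with open("my_model_mmmu_val.json", "w") as f:
#       json.dump(preds, f, indent=4)
#
#   from vlmeval.evaluate.mmmu_eval.mmmu_eval_script import mmmu_eval_func
#   mmmu_eval_func("my_model_mmmu_val.json")   # prints per-subject and per-domain accuracies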
--------------------------------------------------------------------------------
/vlmeval/evaluate/mmmu_eval/print_results.py:
--------------------------------------------------------------------------------
1 | # Beautiful table to print results of all categories
2 |
3 | import os
4 | from typing import Dict
5 | import json
6 | import numpy as np
7 | from tabulate import tabulate
8 |
9 | from argparse import ArgumentParser
10 |
11 | from utils.data_utils import CAT_SHORT2LONG, DOMAIN_CAT2SUB_CAT
12 |
13 | from utils.eval_utils import calculate_ins_level_acc
14 |
15 | def main():
16 | parser = ArgumentParser()
17 | parser.add_argument('--path', type=str, default="./example_outputs/blip2_flant5xxl", help="The path to output directory.")
18 | args = parser.parse_args()
19 |
20 | # load all results
21 | all_results = {}
22 | for cat_folder_name in os.listdir(args.path):
23 | if cat_folder_name in CAT_SHORT2LONG.values():
24 | cat_folder_path = os.path.join(args.path, cat_folder_name)
25 | result_path = os.path.join(cat_folder_path, 'result.json')
26 | if os.path.exists(result_path):
27 | cat_results = json.load(open(result_path))
28 | all_results[cat_folder_name] = cat_results
29 |
30 | # print results
31 | headers = ['Subject', 'Data Num', 'Acc']
32 | table = []
33 |
34 | # add domain Subject
35 | for domain, in_domain_cats in DOMAIN_CAT2SUB_CAT.items():
36 | in_domain_cat_results = {}
37 | for cat_name in in_domain_cats: # use the order in DOMAIN_CAT2SUB_CAT
38 | if cat_name in all_results.keys():
39 | in_domain_cat_results[cat_name] = all_results[cat_name]
40 | else:
41 | pass
42 | in_domain_ins_acc = calculate_ins_level_acc(in_domain_cat_results)
43 | in_domain_data_num = np.sum([cat_results['num_example'] for cat_results in in_domain_cat_results.values()])
44 | table.append(['Overall-' + domain, int(in_domain_data_num), round(in_domain_ins_acc, 3)])
45 | # add sub category
46 | for cat_name, cat_results in in_domain_cat_results.items():
47 | table.append([cat_name, int(cat_results['num_example']), round(cat_results['acc'], 3)])
48 | # table.append(["-----------------------------", "-----", "----"])
49 |
50 | # table.append(["-----------------------------", "-----", "----"])
51 | all_ins_acc = calculate_ins_level_acc(all_results)
52 | table.append(['Overall', np.sum([cat_results['num_example'] for cat_results in all_results.values()]), round(all_ins_acc, 3)])
53 |
54 | print(tabulate(table, headers=headers, tablefmt='orgtbl'))
55 |
56 |
57 | if __name__ == '__main__':
58 | main()
59 |
--------------------------------------------------------------------------------
/vlmeval/evaluate/mmmu_eval/run_llava.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import os
3 | import random
4 |
5 | import numpy as np
6 | from tqdm import tqdm
7 |
8 | from datasets import load_dataset, concatenate_datasets
9 | from llava.model.builder import load_pretrained_model
10 | from llava.mm_utils import get_model_name_from_path
11 |
12 | from argparse import ArgumentParser
13 |
14 | from utils.data_utils import load_yaml, construct_prompt, save_json, process_single_sample, CAT_SHORT2LONG
15 | from utils.model_utils import call_llava_engine_df, llava_image_processor
16 | from utils.eval_utils import parse_multi_choice_response, parse_open_response
17 |
18 |
19 | def run_model(args, samples, model, call_model_engine_fn=None, tokenizer=None, processor=None):
20 | out_samples = dict()
21 | with torch.no_grad():
22 | for sample in tqdm(samples):
23 | response = call_model_engine_fn(args, sample, model, tokenizer, processor)
24 |
25 | if sample['question_type'] == 'multiple-choice':
26 | pred_ans = parse_multi_choice_response(response, sample['all_choices'], sample['index2ans'])
27 | else: # open question
28 | pred_ans = response
29 | out_samples[sample['id']] = pred_ans
30 | return out_samples
31 |
32 | def set_seed(seed_value):
33 | """
34 | Set the seed for PyTorch (both CPU and CUDA), Python, and NumPy for reproducible results.
35 |
36 | :param seed_value: An integer value to be used as the seed.
37 | """
38 | torch.manual_seed(seed_value)
39 | if torch.cuda.is_available():
40 | torch.cuda.manual_seed(seed_value)
41 | torch.cuda.manual_seed_all(seed_value) # For multi-GPU setups
42 | random.seed(seed_value)
43 | np.random.seed(seed_value)
44 | torch.backends.cudnn.deterministic = True
45 | torch.backends.cudnn.benchmark = False
46 |
47 | def main():
48 | parser = ArgumentParser()
49 | parser.add_argument('--output_path', type=str, default='llava1.5_13b_val.json',
50 | help='name of saved json')
51 | parser.add_argument('--config_path', type=str, default="configs/llava1.5.yaml")
52 | parser.add_argument('--data_path', type=str, default="MMMU/MMMU") # hf dataset path.
53 | parser.add_argument('--model_path', type=str, default="liuhaotian/llava-v1.5-13b")
54 | parser.add_argument('--split', type=str, default='validation')
55 | parser.add_argument('--seed', type=int, default=42)
56 |
57 | args = parser.parse_args()
58 | device = torch.device("cuda") if torch.cuda.is_available() else "cpu"
59 | set_seed(args.seed)
60 |
61 | print('llava_initializing...')
62 | processor = None
63 | call_model_engine = call_llava_engine_df
64 | vis_process_func = llava_image_processor
65 |
66 | # load config and process to one value
67 | args.config = load_yaml(args.config_path)
68 | for key, value in args.config.items():
69 | if key != 'eval_params' and type(value) == list:
70 | assert len(value) == 1, 'key {} has more than one value'.format(key)
71 | args.config[key] = value[0]
72 |
73 | # run for each subject
74 | sub_dataset_list = []
75 | for subject in CAT_SHORT2LONG.values():
76 | sub_dataset = load_dataset(args.data_path, subject, split=args.split)
77 | sub_dataset_list.append(sub_dataset)
78 |
79 | # merge all dataset
80 | dataset = concatenate_datasets(sub_dataset_list)
81 |
82 |
83 | # load model
84 | model_name = get_model_name_from_path(args.model_path)
85 | tokenizer, model, vis_processors, _ = load_pretrained_model(args.model_path, None,
86 | model_name)
87 |
88 | samples = []
89 | for sample in dataset:
90 | sample = process_single_sample(sample)
91 |
92 | sample = construct_prompt(sample, args.config)
93 | if sample['image']:
94 | sample['image'] = vis_process_func(sample['image'], vis_processors).to(device)
95 | samples.append(sample)
96 |
97 | # run inference over all samples
98 | out_samples = run_model(args, samples, model, call_model_engine, tokenizer, processor)
99 |
100 | save_json(args.output_path, out_samples)
101 | # metric_dict.update({"num_example": len(out_samples)})
102 | # save_json(save_result_path, metric_dict)
103 |
104 |
105 | if __name__ == '__main__':
106 | main()
107 |
108 |
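# Typical invocation, using the argparse defaults defined above (paths are examples, adjust as needed):
#
#   python run_llava.py --data_path MMMU/MMMU --model_path liuhaotian/llava-v1.5-13b \
#       --config_path configs/llava1.5.yaml --split validation --output_path llava1.5_13b_val.json
#
# The saved file is the flat id -> prediction JSON that mmmu_eval_func consumes.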
--------------------------------------------------------------------------------
/vlmeval/evaluate/mmmu_eval/utils/data_utils.py:
--------------------------------------------------------------------------------
1 | """Utils for data load, save, and process (e.g., prompt construction)"""
2 |
3 | import os
4 | import json
5 | import yaml
6 | import re
7 |
8 |
9 | DOMAIN_CAT2SUB_CAT = {
10 | 'Art and Design': ['Art', 'Art_Theory', 'Design', 'Music'],
11 | 'Business': ['Accounting', 'Economics', 'Finance', 'Manage','Marketing'],
12 | 'Science': ['Biology', 'Chemistry', 'Geography', 'Math', 'Physics',],
13 | 'Health and Medicine': ['Basic_Medical_Science', 'Clinical_Medicine', 'Diagnostics_and_Laboratory_Medicine', 'Pharmacy', 'Public_Health'],
14 | 'Humanities and Social Science': ['History', 'Literature', 'Sociology', 'Psychology'],
15 | 'Tech and Engineering': ['Agriculture', 'Architecture_and_Engineering', 'Computer_Science', 'Electronics', 'Energy_and_Power', 'Materials', 'Mechanical_Engineering'],
16 | }
17 |
18 |
19 | CAT_SHORT2LONG = {
20 | 'acc': 'Accounting',
21 | 'agri': 'Agriculture',
22 | 'arch': 'Architecture_and_Engineering',
23 | 'art': 'Art',
24 | 'art_theory': 'Art_Theory',
25 | 'bas_med': 'Basic_Medical_Science',
26 | 'bio': 'Biology',
27 | 'chem': 'Chemistry',
28 | 'cli_med': 'Clinical_Medicine',
29 | 'cs': 'Computer_Science',
30 | 'design': 'Design',
31 | 'diag_med': 'Diagnostics_and_Laboratory_Medicine',
32 | 'econ': 'Economics',
33 | 'elec': 'Electronics',
34 | 'ep': 'Energy_and_Power',
35 | 'fin': 'Finance',
36 | 'geo': 'Geography',
37 | 'his': 'History',
38 | 'liter': 'Literature',
39 | 'manage': 'Manage',
40 | 'mark': 'Marketing',
41 | 'mate': 'Materials',
42 | 'math': 'Math',
43 | 'mech': 'Mechanical_Engineering',
44 | 'music': 'Music',
45 | 'phar': 'Pharmacy',
46 | 'phys': 'Physics',
47 | 'psy': 'Psychology',
48 | 'pub_health': 'Public_Health',
49 | 'socio': 'Sociology'
50 | }
51 |
52 | # DATA SAVING
53 | def save_json(filename, ds):
54 | with open(filename, 'w') as f:
55 | json.dump(ds, f, indent=4)
56 |
57 |
58 | def get_multi_choice_info(options):
59 | """
60 | Given the list of options for a multiple-choice question,
61 | return index2ans and all_choices.
62 | """
63 |
64 | start_chr = 'A'
65 | all_choices = []
66 | index2ans = {}
67 | for i, option in enumerate(options):
68 | index2ans[chr(ord(start_chr) + i)] = option
69 | all_choices.append(chr(ord(start_chr) + i))
70 |
71 | return index2ans, all_choices
72 |
73 | def load_yaml(file_path):
74 | with open(file_path, 'r') as stream:
75 | try:
76 | yaml_dict = yaml.safe_load(stream)
77 | except yaml.YAMLError as exc:
78 | print(exc)
79 |
80 | return yaml_dict
81 |
82 |
83 | def parse_img_path(text):
84 | matches = re.findall("
", text)
85 | return matches
86 |
87 | def process_single_sample(data):
88 | question = data['question']
89 | o_imgs_paths = []
90 | for option in data['options']:
91 | current_o_imgs_paths = parse_img_path(option)
92 | for img_path in current_o_imgs_paths:
93 | o_imgs_paths.append(img_path)
94 |
95 | if len(o_imgs_paths) > 1: # multiple images in options, used for random selection
96 | return {'id': data['id'], 'question': question, 'options': data['options'], 'answer': data['answer'],
97 | 'image': None, 'question_type': data['question_type']}
98 | else:
99 | return {'id': data['id'], 'question': question, 'options': data['options'], 'answer': data['answer'],
100 | 'image': data['image_1'], 'question_type': data['question_type']}
101 |
102 |
103 | # DATA SAVING
104 | def save_json(filename, ds):
105 | with open(filename, 'w') as f:
106 | json.dump(ds, f, indent=4)
107 |
108 | def save_jsonl(filename, data):
109 | """
110 | Save a dictionary of data to a JSON Lines file with the filename as key and caption as value.
111 |
112 | Args:
113 | filename (str): The path to the file where the data should be saved.
114 | data (dict): The dictionary containing the data to save where key is the image path and value is the caption.
115 | """
116 | with open(filename, 'w', encoding='utf-8') as f:
117 | for img_path, caption in data.items():
118 | # Extract the base filename without the extension
119 | base_filename = os.path.basename(img_path)
120 | # Create a JSON object with the filename as the key and caption as the value
121 | json_record = json.dumps({base_filename: caption}, ensure_ascii=False)
122 | # Write the JSON object to the file, one per line
123 | f.write(json_record + '\n')
124 |
125 | def save_args(args, path_dir):
126 | argsDict = args.__dict__
127 | with open(path_dir + 'setting.txt', 'w') as f:
128 | f.writelines('------------------ start ------------------' + '\n')
129 | for eachArg, value in argsDict.items():
130 | f.writelines(eachArg + ' : ' + str(value) + '\n')
131 | f.writelines('------------------- end -------------------')
132 |
133 |
134 |
135 | # DATA PROCESSING
136 | def construct_prompt(sample, config):
137 | question = sample['question']
138 | options = eval(sample['options'])
139 | example = ""
140 | if sample['question_type'] == 'multiple-choice':
141 | start_chr = 'A'
142 | prediction_range = []
143 | index2ans = {}
144 | for option in options:
145 | prediction_range.append(start_chr)
146 | example += f"({start_chr}) {option}\n"
147 | index2ans[start_chr] = option
148 | start_chr = chr(ord(start_chr) + 1)
149 | empty_prompt_sample_structure = config['multi_choice_example_format']
150 | empty_prompt = empty_prompt_sample_structure.format(question, example)
151 | res_dict = {}
152 | res_dict['index2ans'] = index2ans
153 | res_dict['correct_choice'] = sample['answer']
154 | res_dict['all_choices'] = prediction_range
155 | res_dict['empty_prompt'] = empty_prompt
156 | if config['task_instructions']:
157 | res_dict['final_input_prompt'] = config['task_instructions'].strip() + '\n\n' + empty_prompt
158 | else:
159 | res_dict['final_input_prompt'] = empty_prompt
160 |
161 | res_dict['gt_content'] = options[ord(sample['answer'].upper()) - ord('A')]
162 | else:
163 | empty_prompt_sample_structure = config['short_ans_example_format']
164 | empty_prompt = empty_prompt_sample_structure.format(question)
165 | res_dict = {}
166 | res_dict['empty_prompt'] = empty_prompt
167 | if config['task_instructions']:
168 | res_dict['final_input_prompt'] = config['task_instructions'].strip() + '\n\n' + empty_prompt
169 | else:
170 | res_dict['final_input_prompt'] = empty_prompt
171 | res_dict['gt_content'] = sample['answer']
172 |
173 | res_dict.update(sample)
174 | return res_dict
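# Illustrative sketch of construct_prompt on a hand-made sample (all values below are made up;
# the config keys mirror those used above and in configs/llava1.5.yaml). Note that 'options' is
# stored as the string repr of a list, which is why eval() is applied above.
#
#   sample = {
#       'question': 'Which planet is the largest?',
#       'options': "['Mars', 'Jupiter', 'Venus', 'Earth']",
#       'answer': 'B',
#       'question_type': 'multiple-choice',
#   }
#   config = {
#       'task_instructions': 'Answer with the option letter.',
#       'multi_choice_example_format': 'Question: {}\nOptions:\n{}\nAnswer:',
#       'short_ans_example_format': 'Question: {}\nAnswer:',
#   }
#   res = construct_prompt(sample, config)
#   # res['all_choices'] -> ['A', 'B', 'C', 'D']
#   # res['index2ans']   -> {'A': 'Mars', 'B': 'Jupiter', 'C': 'Venus', 'D': 'Earth'}
#   # res['gt_content']  -> 'Jupiter'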
--------------------------------------------------------------------------------
/vlmeval/evaluate/mmmu_eval/utils/eval_utils.py:
--------------------------------------------------------------------------------
1 | """Response Parsing and Evaluation for various models"""
2 | from typing import Dict
3 |
4 | import re
5 | import random
6 | random.seed(42)
7 | import numpy as np
8 |
9 | # ----------- Process Multi-choice -------------
10 | def parse_multi_choice_response(response, all_choices, index2ans):
11 | """
12 | Parse the prediction from the generated response.
13 | Return the predicted index e.g., A, B, C, D.
14 | """
15 | for char in [',', '.', '!', '?', ';', ':', "'"]:
16 | response = response.strip(char)
17 | response = " " + response + " " # add space to avoid partial match
18 |
19 | index_ans = True
20 | ans_with_brack = False
21 | candidates = []
22 | for choice in all_choices: # e.g., (A) (B) (C) (D)
23 | if f'({choice})' in response:
24 | candidates.append(choice)
25 | ans_with_brack = True
26 |
27 | if len(candidates) == 0:
28 | for choice in all_choices: # e.g., A B C D
29 | if f' {choice} ' in response:
30 | candidates.append(choice)
31 |
32 | # if all above doesn't get candidates, check if the content is larger than 5 tokens and try to parse the example
33 | if len(candidates) == 0 and len(response.split()) > 5:
34 | for index, ans in index2ans.items():
35 | if ans.lower() in response.lower():
36 | candidates.append(index)
37 | index_ans = False # it's content ans.
38 |
39 | if len(candidates) == 0: # still not get answer, randomly choose one.
40 | pred_index = random.choice(all_choices)
41 | elif len(candidates) > 1:
42 | start_indexes = []
43 | if index_ans:
44 | if ans_with_brack:
45 | for can in candidates:
46 | index = response.rfind(f'({can})')
47 | start_indexes.append(index) # -1 will be ignored anyway
48 | # start_indexes = [generated_response.index(f'({can})') for can in candidates]
49 | else:
50 | for can in candidates:
51 | index = response.rfind(f" {can} ")
52 | start_indexes.append(index)
53 | else:
54 | for can in candidates:
55 | index = response.lower().rfind(index2ans[can].lower())
56 | start_indexes.append(index)
57 | # get the last one
58 | pred_index = candidates[np.argmax(start_indexes)]
59 | else: # if only one candidate, use it.
60 | pred_index = candidates[0]
61 |
62 | return pred_index
63 |
64 | # ----------- Process Open -------------
65 | def check_is_number(string):
66 | """
67 | Check if the given string is a number.
68 | """
69 | try:
70 | float(string.replace(',', ''))
71 | return True
72 | except ValueError:
73 | # check if there's comma inside
74 | return False
75 |
76 | def normalize_str(string):
77 | """
78 | Normalize the str to lower case and make them float numbers if possible.
79 | """
80 | # strip whitespace, then check whether the string is numeric
81 |
82 | # if it is a number, convert it to a rounded float
83 | string = string.strip()
84 |
85 | is_number = check_is_number(string)
86 |
87 | if is_number:
88 | string = string.replace(',', '')
89 | string = float(string)
90 | # leave 2 decimal
91 | string = round(string, 2)
92 | return [string]
93 | else: # it's likely to be a string
94 | # lower it
95 | string = string.lower()
96 | if len(string) == 1:
97 | return [" " + string, string + " "] # avoid trivial matches
98 | return [string]
99 |
100 | def extract_numbers(string):
101 | """
102 | Extract all forms of numbers from a string with regex.
103 | """
104 | # Pattern for numbers with commas
105 | pattern_commas = r'-?\b\d{1,3}(?:,\d{3})+\b'
106 | # Pattern for scientific notation
107 | pattern_scientific = r'-?\d+(?:\.\d+)?[eE][+-]?\d+'
108 | # Pattern for simple numbers without commas
109 | pattern_simple = r'-?(?:\d+\.\d+|\.\d+|\d+\b)(?![eE][+-]?\d+)(?![,\d])'
110 |
111 | # Extract numbers with commas
112 | numbers_with_commas = re.findall(pattern_commas, string)
113 | # Extract numbers in scientific notation
114 | numbers_scientific = re.findall(pattern_scientific, string)
115 | # Extract simple numbers without commas
116 | numbers_simple = re.findall(pattern_simple, string)
117 |
118 | # Combine all extracted numbers
119 | all_numbers = numbers_with_commas + numbers_scientific + numbers_simple
120 | return all_numbers
121 |
122 | def parse_open_response(response):
123 | """
124 | Parse the prediction from the generated response.
125 | Return a list of predicted strings or numbers.
126 | """
127 | # content = content.strip("\n").strip(".").strip(" ")
128 | def get_key_subresponses(response):
129 | key_responses = []
130 | response = response.strip().strip(".").lower()
131 | sub_responses = re.split(r'\.\s(?=[A-Z])|\n', response)
132 | indicators_of_keys = ['could be ', 'so ', 'is ',
133 | 'thus ', 'therefore ', 'final ', 'answer ', 'result ']
134 | key_responses = []
135 | for index, resp in enumerate(sub_responses):
136 | # if it is the last sub-response, also accept an equation (the entire response can be a single sentence with an equation)
137 | if index == len(sub_responses) - 1:
138 | indicators_of_keys.extend(['='])
139 | shortest_key_response = None # the shortest response that may contain the answer (tail part of the response)
140 | for indicator in indicators_of_keys:
141 | if indicator in resp:
142 | if not shortest_key_response:
143 | shortest_key_response = resp.split(indicator)[-1].strip()
144 | else:
145 | if len(resp.split(indicator)[-1].strip()) < len(shortest_key_response):
146 | shortest_key_response = resp.split(indicator)[-1].strip()
147 | # key_responses.append(resp.split(indicator)[1].strip())
148 |
149 | if shortest_key_response:
150 | # and it's not trivial
151 | if shortest_key_response.strip() not in [":", ",", ".", "!", "?", ";", ":", "'"]:
152 | key_responses.append(shortest_key_response)
153 | if len(key_responses) == 0: # did not find any
154 | return [response]
155 | return key_responses
156 | # pdb.set_trace()
157 | key_responses = get_key_subresponses(response)
158 |
159 | pred_list = key_responses.copy() # keep the original string response
160 | for resp in key_responses:
161 | pred_list.extend(extract_numbers(resp))
162 |
163 | tmp_pred_list = []
164 | for i in range(len(pred_list)):
165 | tmp_pred_list.extend(normalize_str(pred_list[i]))
166 | pred_list = tmp_pred_list
167 |
168 | # remove duplicates
169 | pred_list = list(set(pred_list))
170 |
171 | return pred_list
172 |
173 | # ----------- Evaluation -------------
174 |
175 | def eval_multi_choice(gold_i, pred_i):
176 | """
177 | Evaluate a multiple choice instance.
178 | """
179 | correct = False
180 | # only they are exactly the same, we consider it as correct
181 | if isinstance(gold_i, list):
182 | for answer in gold_i:
183 | if answer == pred_i:
184 | correct = True
185 | break
186 | else: # gold_i is a string
187 | if gold_i == pred_i:
188 | correct = True
189 | return correct
190 |
191 | def eval_open(gold_i, pred_i):
192 | """
193 | Evaluate an open question instance
194 | """
195 | correct = False
196 | if isinstance(gold_i, list):
197 | # use float to avoid trivial matches
198 | norm_answers = []
199 | for answer in gold_i:
200 | norm_answers.extend(normalize_str(answer))
201 | else:
202 | norm_answers = normalize_str(gold_i)
203 | for pred in pred_i: # pred is already normalized in parse response phase
204 | if isinstance(pred, str): # if it's a string, then find if ans in the pred_i
205 | for norm_ans in norm_answers:
206 | # only see if the string answer in the string pred
207 | if isinstance(norm_ans, str) and norm_ans in pred:
208 | if not correct:
209 | correct = True
210 | break
211 | else: # it's a float number
212 | if pred in norm_answers:
213 | if not correct:
214 | correct = True
215 | break
216 | return correct
217 |
218 | # ----------- Batch Evaluation -------------
219 | def evaluate(samples):
220 | """
221 | Batch evaluation for multiple choice and open questions.
222 | """
223 | pred_correct = 0
224 | judge_dict = dict()
225 | for sample in samples:
226 | gold_i = sample['answer']
227 | pred_i = sample['parsed_pred']
228 | if sample['question_type'] == 'multiple-choice':
229 | correct = eval_multi_choice(gold_i, pred_i)
230 | else: # open question
231 | correct = eval_open(gold_i, pred_i)
232 |
233 | if correct:
234 | judge_dict[sample['id']] = 'Correct'
235 | pred_correct += 1
236 | else:
237 | judge_dict[sample['id']] = 'Wrong'
238 |
239 | if len(samples) == 0:
240 | return judge_dict, {'acc': 0}  # keep the (judge_dict, metric_dict) return shape for empty inputs
241 | return judge_dict, {'acc': pred_correct / len(samples)}
242 |
243 |
244 |
245 | # ----------- Calculate Accuracy -------------
246 | def calculate_ins_level_acc(results: Dict):
247 | """Calculate the instruction level accuracy for given Subject results"""
248 | acc = 0
249 | ins_num = 0
250 | for cat_results in results.values():
251 | acc += cat_results['acc'] * cat_results['num_example']
252 | ins_num += cat_results['num_example']
253 | if ins_num == 0:
254 | return 0
255 | return acc / ins_num
256 |
257 |
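# Small end-to-end sketch of the parsing + judging helpers above (responses are made up):
#
#   index2ans = {'A': '3', 'B': '5', 'C': '7'}
#   all_choices = ['A', 'B', 'C']
#   pred = parse_multi_choice_response("I think the answer is (B) 5.", all_choices, index2ans)
#   # pred == 'B'
#
#   open_pred = parse_open_response("So the result is 1,000 meters.")
#   # open_pred contains both the key substring and the normalized number, e.g. 1000.0
#
#   judge, metrics = evaluate([{'id': 'validation_Math_1', 'question_type': 'multiple-choice',
#                               'answer': 'B', 'parsed_pred': pred}])
#   # judge == {'validation_Math_1': 'Correct'}; metrics == {'acc': 1.0}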
--------------------------------------------------------------------------------
/vlmeval/evaluate/mmmu_eval/utils/model_utils.py:
--------------------------------------------------------------------------------
1 | import random  # random.choice is used in the multi-image fallback below
2 | import torch
3 |
4 | def call_llava_engine_df(args, sample, model, tokenizer=None, processor=None):
5 | from llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
6 | from llava.conversation import conv_templates, SeparatorStyle
7 |
8 | def tokenizer_image_token(prompt, tokenizer, image_token_index=IMAGE_TOKEN_INDEX, return_tensors=None):
9 | prompt_chunks = [tokenizer(chunk).input_ids for chunk in prompt.split('<image>')]
10 |
11 | def insert_separator(X, sep):
12 | return [ele for sublist in zip(X, [sep] * len(X)) for ele in sublist][:-1]
13 |
14 | input_ids = []
15 | offset = 0
16 | if len(prompt_chunks) > 0 and len(prompt_chunks[0]) > 0 and prompt_chunks[0][0] == tokenizer.bos_token_id:
17 | offset = 1
18 | input_ids.append(prompt_chunks[0][0])
19 |
20 | for x in insert_separator(prompt_chunks, [image_token_index] * (offset + 1)):
21 | input_ids.extend(x[offset:])
22 |
23 | if return_tensors is not None:
24 | if return_tensors == 'pt':
25 | return torch.tensor(input_ids, dtype=torch.long)
26 | raise ValueError(f'Unsupported tensor type: {return_tensors}')
27 | return input_ids
28 |
29 | def deal_with_prompt(input_text, mm_use_im_start_end):
30 | qs = input_text
31 | if mm_use_im_start_end:
32 | qs = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN + '\n' + qs
33 | else:
34 | qs = DEFAULT_IMAGE_TOKEN + '\n' + qs
35 | return qs
36 |
37 | prompt = sample['final_input_prompt']
38 | prompt = deal_with_prompt(prompt, model.config.mm_use_im_start_end)
39 | conv = conv_templates['vicuna_v1'].copy()
40 | conv.append_message(conv.roles[0], prompt)
41 | conv.append_message(conv.roles[1], None)
42 | prompt = conv.get_prompt()
43 | input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).cuda()
44 | image = sample['image']
45 | if image is not None:
46 | output_ids = model.generate(
47 | input_ids,
48 | images=image.unsqueeze(0).half().cuda(),
49 | do_sample=True,
50 | temperature=1,
51 | top_p=None,
52 | num_beams=5,
53 | max_new_tokens=128,
54 | use_cache=True)
55 |
56 | input_token_len = input_ids.shape[1]
57 | n_diff_input_output = (input_ids != output_ids[:, :input_token_len]).sum().item()
58 | if n_diff_input_output > 0:
59 | print(f'[Warning] {n_diff_input_output} output_ids are not the same as the input_ids')
60 | response = tokenizer.batch_decode(output_ids[:, input_token_len:], skip_special_tokens=True)[0]
61 | else: # multiple images actually
62 | if sample['question_type'] == 'multiple-choice':
63 | all_choices = sample['all_choices']
64 | response = random.choice(all_choices)
65 | else:
66 | response = 'INVALID GENERATION FOR MULTIPLE IMAGE INPUTS'
67 |
68 | return response
69 |
70 |
71 | def llava_image_processor(raw_image, vis_processors=None):
72 | image_tensor = vis_processors.preprocess(raw_image, return_tensors='pt')['pixel_values'][0]
73 | return image_tensor
74 |
--------------------------------------------------------------------------------
/vlmeval/evaluate/mmvet_eval.py:
--------------------------------------------------------------------------------
1 | from vlmeval.evaluate.misc import build_judge
2 | from vlmeval.smp import *
3 | from vlmeval.utils import track_progress_rich
4 |
5 | def build_mmvet_gpt4_prompt(line):
6 | question = line['question']
7 | gt = str(line['answer'])
8 | prediction = str(line['prediction'])
9 | prompt = """Compare the ground truth and prediction from AI models, to give a correctness score for the prediction. in the ground truth means it is totally right only when all elements in the ground truth are present in the prediction, and means it is totally right when any one element in the ground truth is present in the prediction. The correctness score is 0.0 (totally wrong), 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, or 1.0 (totally right). Just complete the last space of the correctness score.
10 |
11 | Question | Ground truth | Prediction | Correctness
12 | --- | --- | --- | ---
13 | What is x in the equation? | -1 <AND> -5 | x = 3 | 0.0
14 | What is x in the equation? | -1 <AND> -5 | x = -1 | 0.5
15 | What is x in the equation? | -1 <AND> -5 | x = -5 | 0.5
16 | What is x in the equation? | -1 <AND> -5 | x = -5 or 5 | 0.5
17 | What is x in the equation? | -1 <AND> -5 | x = -1 or x = -5 | 1.0
18 | Can you explain this meme? | This meme is poking fun at the fact that the names of the countries Iceland and Greenland are misleading. Despite its name, Iceland is known for its beautiful green landscapes, while Greenland is mostly covered in ice and snow. The meme is saying that the person has trust issues because the names of these countries do not accurately represent their landscapes. | The meme talks about Iceland and Greenland. It's pointing out that despite their names, Iceland is not very icy and Greenland isn't very green. | 0.4
19 | Can you explain this meme? | This meme is poking fun at the fact that the names of the countries Iceland and Greenland are misleading. Despite its name, Iceland is known for its beautiful green landscapes, while Greenland is mostly covered in ice and snow. The meme is saying that the person has trust issues because the names of these countries do not accurately represent their landscapes. | The meme is using humor to point out the misleading nature of Iceland's and Greenland's names. Iceland, despite its name, has lush green landscapes while Greenland is mostly covered in ice and snow. The text 'This is why I have trust issues' is a playful way to suggest that these contradictions can lead to distrust or confusion. The humor in this meme is derived from the unexpected contrast between the names of the countries and their actual physical characteristics. | 1.0
20 | """
21 | gpt4_prompt = prompt + '\n' + ' | '.join([question, gt.replace("<AND>", " <AND> ").replace("<OR>", " <OR> "), prediction, ""])
22 | return gpt4_prompt
23 |
24 | def MMVet_auxeval(model, line):
25 | def float_cvt(s):
26 | try:
27 | return float(s)
28 | except ValueError:
29 | return None
30 |
31 | prompt = build_mmvet_gpt4_prompt(line)
32 | log = ''
33 | retry = 5
34 | for i in range(retry):
35 | output = model.generate(prompt, temperature=i * 0.5)
36 | score = float_cvt(output)
37 | if score is None:
38 | log += f'Try {i}: output is {output}, failed to parse.\n'
39 | elif score < 0 or score > 1:
40 | log += f'Try {i}: output is {output}, invalid score: {score}.\n'
41 | else:
42 | log += 'Succeed'
43 | return dict(log=log, score=score)
44 | log += 'All 5 retries failed.\n'
45 | return dict(log=log, score=0.0)
46 |
47 | def MMVet_acc(result_file):
48 | data = load(result_file)
49 | tot = defaultdict(lambda: 0)
50 | score = defaultdict(lambda: 0)
51 | lt = len(data)
52 | cate2_list = []
53 | for i in range(lt):
54 | item = data.iloc[i]
55 | cate = item['category']
56 | cate2 = cate.replace(',','_')
57 | if cate2 not in cate2_list:
58 | cate2_list.append(cate2)
59 | grade = float(item['score'])
60 | cate_list = ['rec','ocr','know','gen','spat','math']
61 | for capa in cate_list:
62 | if capa in cate:
63 | tot[capa] += 1
64 | score[capa] += grade
65 | tot['Overall'] += 1
66 | tot[cate2] += 1
67 | score['Overall'] += grade
68 | score[cate2] += grade
69 |
70 | res = defaultdict(list)
71 | res2 = defaultdict(list)
72 | cate_list.append('Overall')
73 | cate2_list.append('Overall')
74 | for k in cate_list:
75 | res['Category'].append(k)
76 | res['tot'].append(tot[k])
77 | res['acc'].append(score[k] / tot[k] * 100)
78 | for v in cate2_list:
79 | res2['Category'].append(v)
80 | res2['tot'].append(tot[v])
81 | res2['acc'].append(score[v] / tot[v] * 100)
82 | res = pd.DataFrame(res)
83 | res2 = pd.DataFrame(res2)
84 | return res, res2
85 |
86 | def MMVet_eval(eval_file, model='gpt-4-turbo', nproc=4, verbose=False):
87 | logger = get_logger('Evaluation')
88 |
89 | suffix = eval_file.split('.')[-1]
90 | storage = eval_file.replace(f'.{suffix}', f'_{model}.xlsx')
91 | tmp_file = eval_file.replace(f'.{suffix}', f'_{model}.pkl')
92 | if osp.exists(storage):
93 | logger.warning(f"GPT scoring file {storage} already exists, will reuse it in MMVet_eval. ")
94 | else:
95 | data = load(eval_file)
96 | gpt_version = model
97 | model = build_judge(gpt_version, verbose=verbose, max_tokens=3, retry=10)
98 |
99 | lt = len(data)
100 | lines = [data.iloc[i] for i in range(lt)]
101 | tups = [(model, line) for line in lines]
102 | indices = [line['index'] for line in lines]
103 |
104 | ans = {}
105 | if osp.exists(tmp_file):
106 | ans = load(tmp_file)
107 | tups = [x for x, i in zip(tups, indices) if i not in ans]
108 | indices = [i for i in indices if i not in ans]
109 |
110 | if len(indices):
111 | new_results = track_progress_rich(
112 | MMVet_auxeval, tups, nproc=nproc, chunksize=nproc,
113 | keys=indices, save=tmp_file)
114 | ans = load(tmp_file)
115 | for k, v in zip(indices, new_results):
116 | assert k in ans
117 | assert ans[k]['log'] == v['log'] and ans[k]['score'] == v['score']
118 |
119 | log_map, score_map = {}, {}
120 | all_inds = [line['index'] for line in lines]
121 | for k in all_inds:
122 | log_map[k] = ans[k]['log']
123 | score_map[k] = ans[k]['score']
124 | data['score'] = [score_map[idx] for idx in data['index']]
125 | data['log'] = [log_map[idx] for idx in data['index']]
126 | dump(data, storage)
127 |
128 | score, score_fine = MMVet_acc(storage)
129 | score_pth = storage.replace('.xlsx', '_score.csv')
130 | score_fine_pth = storage.replace('.xlsx', '_score_fine.csv')
131 |
132 | dump(score, score_pth)
133 | dump(score_fine, score_fine_pth)
134 | logger.info(f'MMVet_eval successfully finished evaluating {eval_file}, results saved in {score_pth} and {score_fine_pth}')
135 | logger.info(f'Score: ')
136 | logger.info(score)
137 |
138 | def parse_args():
139 | parser = argparse.ArgumentParser(description="Inference LLM Answers. ")
140 | parser.add_argument("data", type=str, help="The question set for inference, in excel / tsv / json format. ")
141 | parser.add_argument(
142 | "--model",
143 | type=str,
144 | help="The LLM (GPT) used for inference. ",
145 | default="gpt-4-turbo",
146 | choices=['gpt-4-0613', 'gpt-4-turbo', 'chatgpt-1106', 'chatgpt-0613'])
147 | parser.add_argument("--nproc", type=int, default=4)
148 | parser.add_argument("--verbose", action='store_true')
149 | args = parser.parse_args()
150 | return args
151 |
152 | if __name__ == '__main__':
153 | args = parse_args()
154 | MMVet_eval(eval_file=args.data, model=args.model, nproc=args.nproc, verbose=args.verbose)
155 |
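# Sketch of the per-sample judging step with a stand-in judge (no API call; values are made up).
# MMVet_auxeval only needs an object exposing .generate(prompt, temperature=...) that returns the
# score string, so the real GPT judge from build_judge() can be swapped out for testing:
#
#   line = {'question': 'What is the capital of France?', 'answer': 'Paris', 'prediction': 'It is Paris.'}
#
#   class EchoJudge:
#       def generate(self, prompt, temperature=0.0):
#           return '1.0'   # pretend the judge graded the prediction as fully correct
#
#   MMVet_auxeval(EchoJudge(), line)   # -> {'log': 'Succeed', 'score': 1.0}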
--------------------------------------------------------------------------------
/vlmeval/evaluate/yes_or_no.py:
--------------------------------------------------------------------------------
1 | from vlmeval.evaluate.misc import build_judge
2 | from vlmeval.smp import *
3 | from vlmeval.utils import track_progress_rich
4 |
5 | INTERNAL = os.environ.get('INTERNAL', 0)
6 |
7 | def MME_rating(data_file):
8 | data = load(data_file)
9 | stats = defaultdict(dict)
10 | lt = len(data)
11 | for i in range(lt):
12 | item = data.iloc[i]
13 | category = item['category']
14 | image_path = item['image_path']
15 | score = item['score']
16 | if image_path not in stats[category]:
17 | stats[category][image_path] = []
18 | stats[category][image_path].append(score)
19 |
20 | def acc(key, mode='normal'):
21 | res = stats[key]
22 | values = []
23 | for val in res.values():
24 | if mode == 'normal':
25 | values.extend(val)
26 | elif mode == 'plus':
27 | values.append(val[0] * val[1])
28 | return np.mean(values) * 100
29 |
30 | scores = {}
31 | for k in stats:
32 | scores[k] = acc(k) + acc(k, 'plus')
33 |
34 | super_cates = dict(
35 | perception=['OCR', 'artwork', 'celebrity', 'color', 'count', 'existence', 'landmark', 'position', 'posters', 'scene'],
36 | reasoning=['code_reasoning', 'commonsense_reasoning', 'numerical_calculation', 'text_translation']
37 | )
38 |
39 | ret = {}
40 | for sc, cate_list in super_cates.items():
41 | base = 0
42 | for c in cate_list:
43 | base += scores[c]
44 | ret[sc] = base
45 | ret.update(scores)
46 | ret = d2df(ret)
47 | return ret
48 |
49 | def Hallusion_rating(data_file):
50 | def calc_fAcc(data):
51 | res = defaultdict(list)
52 | lt = len(data)
53 | for i in range(lt):
54 | line = data.iloc[i]
55 | res[f"{line['l2-category']}_{line['set_id']}_{line['figure_id']}"].append(line['score'])
56 | return np.mean([np.all(x) for x in res.values()]) * 100
57 |
58 | def calc_qAcc(data):
59 | res = defaultdict(list)
60 | lt = len(data)
61 | for i in range(lt):
62 | line = data.iloc[i]
63 | res[f"{line['l2-category']}_{line['set_id']}_{line['question_id']}"].append(line['score'])
64 | return np.mean([np.all(x) for x in res.values()]) * 100
65 |
66 | def calc_aAcc(data):
67 | return np.mean(data['score']) * 100
68 |
69 | data = load(data_file)
70 | data['set_id'] = [x.split('_')[3] for x in data['index']]
71 | data['figure_id'] = [x.split('_')[4] for x in data['index']]
72 | data['question_id'] = [x.split('_')[5] for x in data['index']]
73 |
74 | res = dict(split=[], aAcc=[], fAcc=[], qAcc=[])
75 | res['split'].append('Overall')
76 | res['aAcc'].append(calc_aAcc(data))
77 | res['fAcc'].append(calc_fAcc(data))
78 | res['qAcc'].append(calc_qAcc(data))
79 |
80 | if 'category' in data:
81 | cates = list(set(data['category']))
82 | for c in cates:
83 | sub = data[data['category'] == c]
84 | res['split'].append(c)
85 | res['aAcc'].append(calc_aAcc(sub))
86 | res['fAcc'].append(calc_fAcc(sub))
87 | res['qAcc'].append(calc_qAcc(sub))
88 |
89 | if 'l2-category' in data:
90 | cates = list(set(data['l2-category']))
91 | for c in cates:
92 | sub = data[data['l2-category'] == c]
93 | res['split'].append(c)
94 | res['aAcc'].append(calc_aAcc(sub))
95 | res['fAcc'].append(calc_fAcc(sub))
96 | res['qAcc'].append(calc_qAcc(sub))
97 | ret = pd.DataFrame(res)
98 | return ret
99 |
100 | def default_rating(data_file):
101 | data = load(data_file)
102 | res = {}
103 | res['Overall'] = np.mean(data['score']) * 100
104 | if 'category' in data:
105 | cates = list(set(data['category']))
106 | cates = [c for c in cates if not pd.isna(c)]
107 | cates.sort()
108 | for c in cates:
109 | sub = data[data['category'] == c]
110 | res[c] = np.mean(sub['score']) * 100
111 | if 'l2-category' in data:
112 | cates = list(set(data['l2-category']))
113 | cates = [c for c in cates if not pd.isna(c)]
114 | cates.sort()
115 | for c in cates:
116 | sub = data[data['l2-category'] == c]
117 | res[c] = np.mean(sub['score']) * 100
118 | ret = d2df(res)
119 | return ret
120 |
121 | def YOrN_match_prompt(line):
122 | tmpl = (
123 | "You are an AI assistant who will help me to match an answer with two options of a question. "
124 | "The options are only Yes / No. "
125 | "You are provided with a question and an answer, and you need to find which option (Yes / No) is most similar to the answer. "
126 | "If the meaning of all options are significantly different from the answer, output Unknown. "\
127 | "Your should output a single word among the following 3 choices: Yes, No, Unknown.\n"
128 | "Example 1: \n"
129 | "Question: Is the word in this image 'Hello'?\nAnswer: The word in this image is 'Hello'.\nYour output: Yes\n"
130 | "Example 2: \n"
131 | "Question: Is the word in this image 'Hello'?\nAnswer: The word in this image is not 'Hello'.\nYour output: No\n"
132 | "Example 3: \n"
133 | "Question: {}?\nAnswer: {}\nYour output: "
134 | )
135 | return tmpl.format(line['question'], line['prediction'])
136 |
137 | def YOrN_Extraction(output):
138 | s = output.lower()
139 | words = process_punctuation(s).split()
140 | if 'yes' in words and 'no' not in words:
141 | return 'Yes'
142 | if 'yes' not in words and 'no' in words:
143 | return 'No'
144 | return 'Unknown'
145 |
146 | def YOrN_auxeval(model, line):
147 | prompt = YOrN_match_prompt(line)
148 | retry = 5
149 | for i in range(retry):
150 | output = model.generate(prompt, temperature=0.5 * i)
151 | ans = YOrN_Extraction(output)
152 | if ans != 'Unknown':
153 | return ans
154 | return 'Unknown'
155 |
156 | def YOrN_eval(eval_file, model='chatgpt-0613', nproc=4, verbose=False, dataset=None):
157 | logger = get_logger('Evaluation')
158 | data = load(eval_file)
159 | data['prediction'] = [str(x) for x in data['prediction']]
160 | storage = eval_file.replace('.xlsx', '_auxmatch.xlsx')
161 | tmp_file = eval_file.replace('.xlsx', '_tmp.pkl')
162 |
163 | if not osp.exists(storage):
164 | ans_map = {k: YOrN_Extraction(v) for k, v in zip(data['index'], data['prediction'])}
165 | if osp.exists(tmp_file):
166 | tmp = load(tmp_file)
167 | for k in tmp:
168 | if ans_map[k] == 'Unknown' and tmp[k] != 'Unknown':
169 | ans_map[k] = tmp[k]
170 |
171 | data['extracted'] = [ans_map[x] for x in data['index']]
172 | unknown = data[data['extracted'] == 'Unknown']
173 |
174 | model_name = 'chatgpt-0613'
175 |
176 | if INTERNAL or gpt_key_set():
177 | model = build_judge(model_name, verbose=verbose, retry=10)
178 | else:
179 | logger.error('OPENAI_API_KEY is not set properly, will use exact matching for evaluation')
180 | model = None
181 |
182 | if model is not None:
183 | lt = len(unknown)
184 | lines = [unknown.iloc[i] for i in range(lt)]
185 | tups = [(model, line) for line in lines]
186 | indices = list(unknown['index'])
187 | if len(tups):
188 | res = track_progress_rich(YOrN_auxeval, tups, nproc=nproc, chunksize=nproc, keys=indices, save=tmp_file)
189 | for k, v in zip(indices, res):
190 | ans_map[k] = v
191 |
192 | data['extracted'] = [ans_map[x] for x in data['index']]
193 | dump(data, storage)
194 | else:
195 | logger.warning(f"GPT matching file {storage} already exists, will reuse it in YOrN_eval. ")
196 |
197 | data = load(storage)
198 | data["score"] = (data["answer"] == data["extracted"])
199 | dump(data, storage)
200 |
201 | if dataset is not None and listinstr(['MME'], dataset):
202 | score = MME_rating(storage)
203 | elif dataset is not None and listinstr(['Hallusion'], dataset):
204 | score = Hallusion_rating(storage)
205 | else:
206 | score = default_rating(storage)
207 |
208 | score_tgt = eval_file.replace('.xlsx', '_score.csv')
209 | dump(score, score_tgt)
210 |
211 | logger.info(f'YOrN_eval successfully finished evaluating {eval_file}, results saved in {score_tgt}')
212 | logger.info('Score: ')
213 | logger.info(score)
214 | return score
215 |
216 | def parse_args():
217 | parser = argparse.ArgumentParser(description="Inference LLM Answers. ")
218 | parser.add_argument("data", type=str, help="The question set for inference, in excel / tsv / json format. ")
219 | parser.add_argument("--model", type=str, help="The LLM (GPT) used for inference. ", default="chatgpt-0613", choices=['chatgpt-0613'])
220 | parser.add_argument("--nproc", type=int, default=4)
221 | parser.add_argument("--dataset", type=str, default=None)
222 | parser.add_argument("--verbose", action='store_true')
223 | args = parser.parse_args()
224 | return args
225 |
226 | if __name__ == '__main__':
227 | args = parse_args()
228 | acc = YOrN_eval(eval_file=args.data, model=args.model, nproc=args.nproc, verbose=args.verbose, dataset=args.dataset)
229 |
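# Quick sanity check of the exact-matching rule (made-up answers). Only rows that remain
# 'Unknown' after YOrN_Extraction are forwarded to the GPT judge in YOrN_eval:
#
#   YOrN_Extraction('Yes, the word is "Hello".')               # -> 'Yes'
#   YOrN_Extraction('The image does not show that, so no.')    # -> 'No'
#   YOrN_Extraction('It is hard to tell.')                     # -> 'Unknown'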
--------------------------------------------------------------------------------
/vlmeval/inference.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.distributed as dist
3 | import datetime
4 | from vlmeval.config import supported_VLM
5 | from vlmeval.utils import TSVDataset, track_progress_rich, split_MMMU
6 | from vlmeval.smp import *
7 |
8 | FAIL_MSG = 'Failed to obtain answer via API.'
9 |
10 | def parse_args():
11 | parser = argparse.ArgumentParser()
12 | parser.add_argument('--data', type=str, nargs='+', required=True)
13 | parser.add_argument("--model", type=str, nargs='+', required=True)
14 | parser.add_argument("--nproc", type=int, default=4, required=True)
15 | parser.add_argument("--verbose", action='store_true')
16 | args = parser.parse_args()
17 | return args
18 |
19 | # Only API model is accepted
20 | def infer_data_api(work_dir, model_name, dataset_name, index_set, api_nproc=4):
21 | rank, world_size = get_rank_and_world_size()
22 | assert rank == 0 and world_size == 1
23 | dataset = TSVDataset(dataset_name)
24 | data = dataset.data
25 | data = data[data['index'].isin(index_set)]
26 |
27 | model = supported_VLM[model_name]() if isinstance(model_name, str) else model_name
28 | is_api = getattr(model, 'is_api', False)
29 | assert is_api
30 |
31 | lt, indices = len(data), list(data['index'])
32 | structs = [dataset.build_prompt(data.iloc[i]) for i in range(lt)]
33 |
34 | out_file = f'{work_dir}/{model_name}_{dataset_name}_supp.pkl'
35 | res = {}
36 | if osp.exists(out_file):
37 | res = load(out_file)
38 | res = {k: v for k, v in res.items() if FAIL_MSG not in v}
39 |
40 | structs = [s for i, s in zip(indices, structs) if i not in res]
41 | indices = [i for i in indices if i not in res]
42 |
43 | gen_func = None
44 | if listinstr(['MMMU'], dataset_name):
45 | assert hasattr(model, 'interleave_generate')
46 | gen_func = model.interleave_generate
47 | structs = [dict(ti_list=split_MMMU(struct), dataset=dataset_name) for struct in structs]
48 | elif listinstr(['CORE_MM'], dataset_name):
49 | assert hasattr(model, 'multi_generate')
50 | gen_func = model.multi_generate
51 | structs = [dict(image_paths=struct['image'], prompt=struct['text'], dataset=dataset_name) for struct in structs]
52 | else:
53 | gen_func = model.generate
54 | structs = [dict(image_path=struct['image'], prompt=struct['text'], dataset=dataset_name) for struct in structs]
55 |
56 | inference_results = track_progress_rich(
57 | gen_func, structs, nproc=api_nproc, chunksize=api_nproc, save=out_file, keys=indices)
58 |
59 | res = load(out_file)
60 | for idx, text in zip(indices, inference_results):
61 | assert (res[idx] == text if idx in res else True)
62 | res[idx] = text
63 | return res
64 |
65 | def infer_data(model_name, dataset_name, out_file, verbose=False, api_nproc=4):
66 | res = {}
67 | if osp.exists(out_file):
68 | res = load(out_file)
69 |
70 | rank, world_size = get_rank_and_world_size()
71 | if rank == 0:
72 | dataset = TSVDataset(dataset_name)
73 | if world_size > 1:
74 | dist.barrier()
75 | dataset = TSVDataset(dataset_name)
76 |
77 | indices = list(range(rank, len(dataset), world_size))
78 | lt = len(indices)
79 | data = dataset.data.iloc[indices]
80 |
81 | # If finished, will exit without building the model
82 | all_finished = True
83 | for i in range(lt):
84 | idx = data.iloc[i]['index']
85 | if idx not in res:
86 | all_finished = False
87 | if all_finished:
88 | return
89 | data = data[~data['index'].isin(res)]
90 | lt = len(data)
91 |
92 | model = supported_VLM[model_name]() if isinstance(model_name, str) else model_name
93 |
94 | is_api = getattr(model, 'is_api', False)
95 | if is_api:
96 | assert world_size == 1
97 | lt, indices = len(data), list(data['index'])
98 | supp = infer_data_api(work_dir=osp.dirname(out_file), model_name=model_name, dataset_name=dataset_name, index_set=set(indices), api_nproc=api_nproc)  # work_dir is a required argument; derive it from out_file
99 | for idx in indices:
100 | assert idx in supp
101 | res.update(supp)
102 | dump(res, out_file)
103 | return model_name
104 |
105 | for i in tqdm(range(lt)):
106 | idx = data.iloc[i]['index']
107 | if idx in res:
108 | continue
109 |
110 | if hasattr(model, 'use_custom_prompt') and model.use_custom_prompt(dataset_name):
111 | struct = model.build_prompt(data.iloc[i], dataset=dataset_name)
112 | else:
113 | struct = dataset.build_prompt(data.iloc[i])
114 |
115 | if dataset_name in ['CORE_MM']:
116 | assert hasattr(model, 'multi_generate')
117 | response = model.multi_generate(prompt=struct['text'], image_paths=struct['image'], dataset=dataset_name)
118 | elif listinstr(['MMMU'], dataset_name):
119 | if hasattr(model, 'interleave_generate'):
120 | response = model.interleave_generate(ti_list=split_MMMU(struct), dataset=dataset_name)
121 | elif len(struct['image']) >= 1:
122 | response = model.generate(prompt=struct['text'], image_path=struct['image'], dataset=dataset_name, qtype=struct['qtype'])
123 | else:
124 | response = model.generate(prompt=struct['text'], image_path=struct['image'], dataset=dataset_name)
125 | torch.cuda.empty_cache()
126 |
127 | if verbose:
128 | print(response, flush=True)
129 |
130 | res[idx] = response
131 | if (i + 1) % 20 == 0:
132 | dump(res, out_file)
133 |
134 | dump(res, out_file)
135 | return model
136 |
137 | def prefetch_acc(result_file):
138 | data = load(result_file)
139 | from vlmeval.evaluate.multiple_choice import build_choices, can_infer
140 | tot = defaultdict(lambda: 0)
141 | match = defaultdict(lambda: 0)
142 | hit = defaultdict(lambda: 0)
143 | lt = len(data)
144 | for i in range(lt):
145 | item = data.iloc[i]
146 | cate = item['category']
147 | tot['Overall'] += 1
148 | tot[cate] += 1
149 | choices = build_choices(item)
150 | matched = can_infer(item['prediction'], choices)
151 | if matched:
152 | match['Overall'] += 1
153 | match[cate] += 1
154 | if matched == item['answer']:
155 | hit['Overall'] += 1
156 | hit[cate] += 1
157 | res = defaultdict(list)
158 | for k in tot.keys():
159 | res['Category'].append(k)
160 | res['tot'].append(tot[k])
161 | res['match'].append(match[k])
162 | res['hit'].append(hit[k])
163 | res['match_rate'].append(match[k] / tot[k] * 100)
164 | if match[k] == 0:
165 | res['acc'].append(0)
166 | else:
167 | res['acc'].append(hit[k] / tot[k] * 100)
168 | res = pd.DataFrame(res)
169 | return res
170 |
171 | def infer_data_job(model, work_dir, model_name, dataset_name, verbose=False, api_nproc=4, ignore_failed=False):
172 | result_file = osp.join(work_dir, f'{model_name}_{dataset_name}.xlsx')
173 | rank, world_size = get_rank_and_world_size()
174 | tmpl = osp.join(work_dir, '{}' + f'{world_size}_{dataset_name}.pkl')
175 | out_file = tmpl.format(rank)
176 |
177 | if True: #not osp.exists(result_file):
178 | model = infer_data(model, dataset_name=dataset_name, out_file=out_file, verbose=verbose)
179 | if world_size > 1:
180 | dist.barrier()
181 |
182 | if rank == 0:
183 | data_all = {}
184 | for i in range(world_size):
185 | data_all.update(load(tmpl.format(i)))
186 |
187 | data = TSVDataset(dataset_name).data
188 | assert len(data_all) == len(data)
189 | data['prediction'] = [str(data_all[x]) for x in data['index']]
190 | data.pop('image')
191 |
192 | dump(data, result_file)
193 | for i in range(world_size):
194 | os.remove(tmpl.format(i))
195 | return model
196 |
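# Sharding note: each rank writes work_dir/{rank}{world_size}_{dataset_name}.pkl (see tmpl above),
# e.g. a 2-GPU run produces 02_<dataset>.pkl and 12_<dataset>.pkl, and rank 0 merges them into
# {model_name}_{dataset_name}.xlsx before deleting the per-rank files.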
--------------------------------------------------------------------------------
/vlmeval/smp/__init__.py:
--------------------------------------------------------------------------------
1 | from .file import *
2 | from .vlm import *
3 | from .misc import *
4 | from .log import *
5 | from .lb import *
--------------------------------------------------------------------------------
/vlmeval/smp/file.py:
--------------------------------------------------------------------------------
1 | import json
2 | import pickle
3 | import pandas as pd
4 | import os
5 | import csv
6 | import hashlib
7 | import os.path as osp
8 | import time
9 | import numpy as np
10 |
11 | class NumpyEncoder(json.JSONEncoder):
12 | def default(self, obj):
13 | if isinstance(obj, (np.int_, np.intc, np.intp, np.int8,
14 | np.int16, np.int32, np.int64, np.uint8,
15 | np.uint16, np.uint32, np.uint64)):
16 | return int(obj)
17 | elif isinstance(obj, (np.float_, np.float16, np.float32, np.float64)):
18 | return float(obj)
19 | elif isinstance(obj, (np.complex_, np.complex64, np.complex128)):
20 | return {'real': obj.real, 'imag': obj.imag}
21 | elif isinstance(obj, (np.ndarray,)):
22 | return obj.tolist()
23 | elif isinstance(obj, (np.bool_)):
24 | return bool(obj)
25 | elif isinstance(obj, (np.void)):
26 | return None
27 | return json.JSONEncoder.default(self, obj)
28 |
29 | # LOAD & DUMP
30 | def dump(data, f, **kwargs):
31 | def dump_pkl(data, pth, **kwargs):
32 | pickle.dump(data, open(pth, 'wb'))
33 |
34 | def dump_json(data, pth, **kwargs):
35 | json.dump(data, open(pth, 'w'), indent=4, ensure_ascii=False, cls=NumpyEncoder)
36 |
37 | def dump_jsonl(data, f, **kwargs):
38 | lines = [json.dumps(x, ensure_ascii=False, cls=NumpyEncoder) for x in data]
39 | with open(f, 'w', encoding='utf8') as fout:
40 | fout.write('\n'.join(lines))
41 |
42 | def dump_xlsx(data, f, **kwargs):
43 | data.to_excel(f, index=False, engine='xlsxwriter')
44 |
45 | def dump_csv(data, f, quoting=csv.QUOTE_ALL):
46 | data.to_csv(f, index=False, encoding='utf-8', quoting=quoting)
47 |
48 | def dump_tsv(data, f, quoting=csv.QUOTE_ALL):
49 | data.to_csv(f, sep='\t', index=False, encoding='utf-8', quoting=quoting)
50 |
51 | handlers = dict(pkl=dump_pkl, json=dump_json, jsonl=dump_jsonl, xlsx=dump_xlsx, csv=dump_csv, tsv=dump_tsv)
52 | suffix = f.split('.')[-1]
53 | return handlers[suffix](data, f, **kwargs)
54 |
55 | def load(f):
56 | def load_pkl(pth):
57 | return pickle.load(open(pth, 'rb'))
58 |
59 | def load_json(pth):
60 | return json.load(open(pth, 'r', encoding='utf-8'))
61 |
62 | def load_jsonl(f):
63 | lines = open(f, encoding='utf-8').readlines()
64 | lines = [x.strip() for x in lines]
65 | if lines[-1] == '':
66 | lines = lines[:-1]
67 | data = [json.loads(x) for x in lines]
68 | return data
69 |
70 | def load_xlsx(f):
71 | return pd.read_excel(f)
72 |
73 | def load_csv(f):
74 | return pd.read_csv(f)
75 |
76 | def load_tsv(f):
77 | return pd.read_csv(f, sep='\t')
78 |
79 | handlers = dict(pkl=load_pkl, json=load_json, jsonl=load_jsonl, xlsx=load_xlsx, csv=load_csv, tsv=load_tsv)
80 | suffix = f.split('.')[-1]
81 | return handlers[suffix](f)
82 |
83 | def download_file(url, filename=None):
84 | import urllib.request
85 | from tqdm import tqdm
86 |
87 | class DownloadProgressBar(tqdm):
88 | def update_to(self, b=1, bsize=1, tsize=None):
89 | if tsize is not None:
90 | self.total = tsize
91 | self.update(b * bsize - self.n)
92 |
93 | if filename is None:
94 | filename = url.split('/')[-1]
95 |
96 | with DownloadProgressBar(unit='B', unit_scale=True,
97 | miniters=1, desc=url.split('/')[-1]) as t:
98 | urllib.request.urlretrieve(url, filename=filename, reporthook=t.update_to)
99 | return filename
100 |
101 | def ls(dirname='.', match='', mode='all', level=1):
102 | if dirname == '.':
103 | ans = os.listdir(dirname)
104 | else:
105 | ans = [osp.join(dirname, x) for x in os.listdir(dirname)]
106 | assert mode in ['all', 'dir', 'file']
107 | assert level >= 1 and isinstance(level, int)
108 | if level == 1:
109 | ans = [x for x in ans if match in x]
110 | if mode == 'dir':
111 | ans = [x for x in ans if osp.isdir(x)]
112 | elif mode == 'file':
113 | ans = [x for x in ans if not osp.isdir(x)]
114 | else:
115 | ans = [x for x in ans if osp.isdir(x)]
116 | res = []
117 | for d in ans:
118 | res.extend(ls(d, match=match, mode=mode, level=level-1))
119 | ans = res
120 | return ans
121 |
122 | def mrlines(fname, sp='\n'):
123 | f = open(fname).read().split(sp)
124 | while f != [] and f[-1] == '':
125 | f = f[:-1]
126 | return f
127 |
128 | def mwlines(lines, fname):
129 | with open(fname, 'w') as fout:
130 | fout.write('\n'.join(lines))
131 |
132 | def md5(file_pth):
133 | with open(file_pth, 'rb') as f:
134 | hash = hashlib.new('md5')
135 | for chunk in iter(lambda: f.read(2**20), b''):
136 | hash.update(chunk)
137 | return str(hash.hexdigest())
138 |
139 | def last_modified(pth):
140 | stamp = osp.getmtime(pth)
141 | m_ti = time.ctime(stamp)
142 | t_obj = time.strptime(m_ti)
143 | t = time.strftime('%Y%m%d%H%M%S', t_obj)[2:]
144 | return t
145 |
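# dump() and load() dispatch purely on the file suffix, so round-tripping a DataFrame only
# requires choosing an extension the handler tables know (pkl / json / jsonl / xlsx / csv / tsv):
#
#   import pandas as pd
#   df = pd.DataFrame({'index': [0, 1], 'prediction': ['yes', 'no']})
#   dump(df, '/tmp/example.tsv')      # routed to dump_tsv
#   df2 = load('/tmp/example.tsv')    # routed to load_tsv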
--------------------------------------------------------------------------------
/vlmeval/smp/log.py:
--------------------------------------------------------------------------------
1 | import logging
2 |
3 | logger_initialized = {}
4 |
5 | def get_logger(name, log_file=None, log_level=logging.INFO, file_mode='w'):
6 | logger = logging.getLogger(name)
7 | if name in logger_initialized:
8 | return logger
9 |
10 | for logger_name in logger_initialized:
11 | if name.startswith(logger_name):
12 | return logger
13 |
14 | stream_handler = logging.StreamHandler()
15 | handlers = [stream_handler]
16 |
17 | try:
18 | import torch.distributed as dist
19 | if dist.is_available() and dist.is_initialized():
20 | rank = dist.get_rank()
21 | else:
22 | rank = 0
23 | except ImportError:
24 | rank = 0
25 |
26 | if rank == 0 and log_file is not None:
27 | file_handler = logging.FileHandler(log_file, file_mode)
28 | handlers.append(file_handler)
29 |
30 | formatter = logging.Formatter(
31 | '%(asctime)s - %(name)s - %(levelname)s - %(message)s')
32 | for handler in handlers:
33 | handler.setFormatter(formatter)
34 | handler.setLevel(log_level)
35 | logger.addHandler(handler)
36 |
37 | if rank == 0:
38 | logger.setLevel(log_level)
39 | else:
40 | logger.setLevel(logging.ERROR)
41 |
42 | logger_initialized[name] = True
43 | return logger
--------------------------------------------------------------------------------
/vlmeval/smp/misc.py:
--------------------------------------------------------------------------------
1 | # flake8: noqa: F401, F403
2 | import abc
3 | import argparse
4 | import csv
5 | import multiprocessing as mp
6 | import os
7 | import os.path as osp
8 | import copy as cp
9 | import random as rd
10 | import requests
11 | import shutil
12 | import subprocess
13 | import warnings
14 | import pandas as pd
15 | from collections import OrderedDict, defaultdict
16 | from multiprocessing import Pool, current_process
17 | from tqdm import tqdm
18 | import datetime
19 | import matplotlib.pyplot as plt
20 | import seaborn as sns
21 | from tabulate import tabulate_formats, tabulate
22 | from huggingface_hub import scan_cache_dir
23 | from sty import fg, bg, ef, rs
24 |
25 | def process_punctuation(inText):
26 | import re
27 | outText = inText
28 | punct = [
29 | ';', r'/', '[', ']', '"', '{', '}', '(', ')', '=', '+', '\\', '_', '-',
30 | '>', '<', '@', '`', ',', '?', '!'
31 | ]
32 | commaStrip = re.compile('(\d)(,)(\d)') # noqa: W605
33 | periodStrip = re.compile('(?!<=\d)(\.)(?!\d)') # noqa: W605
34 | for p in punct:
35 | if (p + ' ' in inText or ' ' + p in inText) or (re.search(
36 | commaStrip, inText) is not None):
37 | outText = outText.replace(p, '')
38 | else:
39 | outText = outText.replace(p, ' ')
40 | outText = periodStrip.sub('', outText, re.UNICODE)
41 | return outText
42 |
43 | def h2r(value):
44 | if value[0] == '#':
45 | value = value[1:]
46 | assert len(value) == 6
47 | return tuple(int(value[i:i + 2], 16) for i in range(0, 6, 2))
48 |
49 | def r2h(rgb):
50 | return '#%02x%02x%02x' % rgb
51 |
52 | def colored(s, color):
53 | if isinstance(color, str):
54 | color = h2r(color)
55 | return fg(*color) + s + fg.rs
56 |
57 | def istype(s, type):
58 | if isinstance(s, type):
59 | return True
60 | try:
61 | return isinstance(eval(s), type)
62 | except Exception as _:
63 | return False
64 |
65 | def bincount(lst):
66 | bins = defaultdict(lambda: 0)
67 | for item in lst:
68 | bins[item] += 1
69 | return bins
70 |
71 | def get_cache_path(repo_id):
72 | hf_cache_info = scan_cache_dir()
73 | repos = list(hf_cache_info.repos)
74 | repo = None
75 | for r in repos:
76 | if r.repo_id == repo_id:
77 | repo = r
78 | break
79 | if repo is None:
80 | return None
81 | revs = list(repo.revisions)
82 | rev2keep, last_modified = None, 0
83 | for rev in revs:
84 | if rev.last_modified > last_modified:
85 | rev2keep, last_modified = rev, rev.last_modified
86 | if rev2keep is None:
87 | return None
88 | return str(rev2keep.snapshot_path)
89 |
90 | def proxy_set(s):
91 | import os
92 | for key in ['http_proxy', 'HTTP_PROXY', 'https_proxy', 'HTTPS_PROXY']:
93 | os.environ[key] = s
94 |
95 | def get_rank_and_world_size():
96 | local_rank = int(os.environ.get("LOCAL_RANK", 0))
97 | world_size = int(os.environ.get("WORLD_SIZE", 1))
98 | return local_rank, world_size
99 |
100 | def splitlen(s, sym='/'):
101 | return len(s.split(sym))
102 |
103 | def listinstr(lst, s):
104 | assert isinstance(lst, list)
105 | for item in lst:
106 | if item in s:
107 | return True
108 | return False
109 |
110 | def d2df(D):
111 | return pd.DataFrame({x: [D[x]] for x in D})
112 |
113 | def cn_string(s):
114 | import re
115 | if re.search(u'[\u4e00-\u9fff]', s):
116 | return True
117 | return False
118 |
119 | try:
120 | import decord
121 | except ImportError:
122 | pass
123 |
124 | def timestr(second=True, minute=False):
125 | s = datetime.datetime.now().strftime('%Y%m%d%H%M%S')[2:]
126 | if second:
127 | return s
128 | elif minute:
129 | return s[:-2]
130 | else:
131 | return s[:-4]
132 |
133 | def dict_merge(dct, merge_dct):
134 | for k, _ in merge_dct.items():
135 | if (k in dct and isinstance(dct[k], dict) and isinstance(merge_dct[k], dict)): #noqa
136 | dict_merge(dct[k], merge_dct[k])
137 | else:
138 | dct[k] = merge_dct[k]
139 |
140 | def youtube_dl(idx):
141 | cmd = f'youtube-dl -f best -f mp4 "{idx}" -o {idx}.mp4'
142 | os.system(cmd)
143 |
144 | def run_command(cmd):
145 | if isinstance(cmd, str):
146 | cmd = cmd.split()
147 | return subprocess.check_output(cmd)
148 |
--------------------------------------------------------------------------------
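
A few doctest-style checks for the small pure helpers above (illustrative only; importing the module also pulls in its plotting and tabulate dependencies):

    from vlmeval.smp.misc import h2r, r2h, listinstr, splitlen, bincount

    assert h2r('#ff8000') == (255, 128, 0)                    # hex string -> RGB tuple
    assert r2h((255, 128, 0)) == '#ff8000'                    # RGB tuple -> hex string
    assert listinstr(['mmbench', 'mmmu'], 'mmbench_dev_en')   # any-substring test
    assert splitlen('a/b/c') == 3                             # count of '/'-separated parts
    assert dict(bincount(['a', 'b', 'a'])) == {'a': 2, 'b': 1}
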
/vlmeval/smp/vlm.py:
--------------------------------------------------------------------------------
1 | import os, io
2 | import pandas as pd
3 | import numpy as np
4 | import string
5 | from uuid import uuid4
6 | import os.path as osp
7 | import base64
8 | from PIL import Image
9 |
10 | def mmqa_display(question):
11 | question = {k.lower(): v for k, v in question.items()}
12 | keys = list(question.keys())
13 | keys = [k for k in keys if k not in ['index', 'image']]
14 |
15 | images = question['image']
16 | if isinstance(images, str):
17 | images = [images]
18 |
19 | idx = question.pop('index', 'XXX')
20 | print(f'INDEX: {idx}')
21 |
22 | for im in images:
23 | image = decode_base64_to_image(im, target_size=512)
24 | display(image)
25 |
26 | for k in keys:
27 | try:
28 | if not pd.isna(question[k]):
29 | print(f'{k.upper()}. {question[k]}')
30 | except ValueError:
31 | if False in pd.isna(question[k]):
32 | print(f'{k.upper()}. {question[k]}')
33 |
34 | def encode_image_to_base64(img, target_size=-1):
35 | # if target_size == -1, will not do resizing
36 |     # else, will set the max size to (target_size, target_size)
37 | if img.mode in ("RGBA", "P"):
38 | img = img.convert("RGB")
39 | tmp = osp.join('/tmp', str(uuid4()) + '.jpg')
40 | if target_size > 0:
41 | img.thumbnail((target_size, target_size))
42 | img.save(tmp)
43 | with open(tmp, 'rb') as image_file:
44 | image_data = image_file.read()
45 | ret = base64.b64encode(image_data).decode('utf-8')
46 | os.remove(tmp)
47 | return ret
48 |
49 | def encode_image_file_to_base64(image_path, target_size=-1):
50 | image = Image.open(image_path)
51 | return encode_image_to_base64(image, target_size=target_size)
52 |
53 | def decode_base64_to_image(base64_string, target_size=-1):
54 | image_data = base64.b64decode(base64_string)
55 | image = Image.open(io.BytesIO(image_data))
56 | if image.mode in ('RGBA', 'P'):
57 | image = image.convert('RGB')
58 | if target_size > 0:
59 | image.thumbnail((target_size, target_size))
60 | return image
61 |
62 | def decode_base64_to_image_file(base64_string, image_path, target_size=-1):
63 | image = decode_base64_to_image(base64_string, target_size=target_size)
64 | image.save(image_path)
65 |
66 | def LMUDataRoot():
67 | if 'LMUData' in os.environ and osp.exists(os.environ['LMUData']):
68 | return os.environ['LMUData']
69 | home = osp.expanduser('~')
70 | root = osp.join(home, 'LMUData')
71 | os.makedirs(root, exist_ok=True)
72 | return root
73 |
74 | def build_option_str(option_dict):
75 | s = 'There are several options: \n'
76 | for c, content in option_dict.items():
77 | if not pd.isna(content):
78 | s += f'{c}. {content}\n'
79 | return s
80 |
81 | def isimg(s):
82 | return osp.exists(s) or s.startswith('http')
83 |
84 | def read_ok(img_path):
85 | if not osp.exists(img_path):
86 | return False
87 | try:
88 | im = Image.open(img_path)
89 | assert im.size[0] > 0 and im.size[1] > 0
90 | return True
91 | except:
92 | return False
93 |
94 | def gpt_key_set():
95 | openai_key = os.environ.get('OPENAI_API_KEY', None)
96 | return isinstance(openai_key, str) and openai_key.startswith('sk-')
97 |
98 | def apiok(wrapper):
99 | s = wrapper.generate("Hello!")
100 | return wrapper.fail_msg not in s
101 |
102 | def circular_pred(df, extract_func=None):
103 | if extract_func is None:
104 | extract_func = lambda x: x
105 | df = df.sort_values('index')
106 | from vlmeval.utils import can_infer_option
107 | shift = int(1e6)
108 |
109 | choices = [extract_func(x) for x in df['prediction']]
110 | pred_map = {i: c for i, c in zip(df['index'], choices)}
111 | flag_map = {i: True for i in pred_map if i < 1e6}
112 | valid_map = {i: True for i in pred_map if i < 1e6}
113 | for i in df['index']:
114 | if i >= shift and pred_map[i] and pred_map[i - shift]:
115 | if pred_map[i] not in list(string.ascii_uppercase) or pred_map[i - shift] not in list(string.ascii_uppercase):
116 | valid_map[i % shift] = False
117 | continue
118 | if (ord(pred_map[i]) - ord(pred_map[i - shift])) % 4 == 1:
119 | continue
120 | else:
121 | flag_map[i % shift] = False
122 | flag_map = {k: v for k, v in flag_map.items() if valid_map[k]}
123 | flags = list(flag_map.values())
124 | return np.mean(flags)
125 |
126 | def MMBenchOfficialServer():
127 | root = LMUDataRoot()
128 | for dataset in ['MMBench', 'MMBench_CN', 'MMBench_TEST_EN', 'MMBench_TEST_CN']:
129 | if osp.exists(f'{root}/{dataset}.tsv'):
130 | return True
131 | return False
--------------------------------------------------------------------------------
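
A small round-trip sketch for the base64 helpers above (illustrative; it assumes the repo's demo image is reachable from the working directory and that /tmp is writable):

    from vlmeval.smp.vlm import (encode_image_file_to_base64,
                                 decode_base64_to_image_file, read_ok)

    # Encode a JPEG to base64 (bounded to 512x512), decode it back to disk,
    # then confirm the written file opens as a valid image.
    b64 = encode_image_file_to_base64('demo/cat.jpg', target_size=512)
    decode_base64_to_image_file(b64, '/tmp/cat_roundtrip.jpg')
    assert read_ok('/tmp/cat_roundtrip.jpg')
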
/vlmeval/utils/__init__.py:
--------------------------------------------------------------------------------
1 | from .matching_util import can_infer, can_infer_option, can_infer_text
2 | from .mp_util import track_progress_rich
3 | from .custom_prompt import CustomPrompt
4 | from .dataset_config import dataset_URLs, img_root_map, DATASET_TYPE, abbr2full
5 | from .dataset import TSVDataset, split_MMMU
6 | from .xtuner_util import PROMPT_TEMPLATE, StopWordStoppingCriteria, expand2square, prepare_inputs_labels_for_multimodal, DEFAULT_IMAGE_TOKEN, IMAGE_TOKEN_INDEX
7 |
8 |
9 | __all__ = [
10 | 'can_infer', 'can_infer_option', 'can_infer_text', 'track_progress_rich',
11 | 'TSVDataset', 'dataset_URLs', 'img_root_map', 'DATASET_TYPE', 'CustomPrompt',
12 | 'split_MMMU', 'abbr2full', 'expand2square', 'prepare_inputs_labels_for_multimodal',
13 | 'DEFAULT_IMAGE_TOKEN', 'IMAGE_TOKEN_INDEX', 'PROMPT_TEMPLATE', 'StopWordStoppingCriteria'
14 | ]
15 |
--------------------------------------------------------------------------------
/vlmeval/utils/custom_prompt.py:
--------------------------------------------------------------------------------
1 | from ..smp import *
2 | from .dataset_config import img_root_map
3 | from abc import abstractmethod
4 |
5 | class CustomPrompt:
6 |
7 | @abstractmethod
8 | def use_custom_prompt(self, dataset):
9 | raise NotImplementedError
10 |
11 | @abstractmethod
12 | def build_prompt(self, line, dataset):
13 | raise NotImplementedError
14 |
15 | def dump_image(self, line, dataset):
16 | ROOT = LMUDataRoot()
17 | assert isinstance(dataset, str)
18 | img_root = osp.join(ROOT, 'images', img_root_map[dataset])
19 | os.makedirs(img_root, exist_ok=True)
20 | if isinstance(line['image'], list):
21 | tgt_path = []
22 | assert 'image_path' in line
23 | for img, im_name in zip(line['image'], line['image_path']):
24 | path = osp.join(img_root, im_name)
25 | if not read_ok(path):
26 | decode_base64_to_image_file(img, path)
27 | tgt_path.append(path)
28 | else:
29 | tgt_path = osp.join(img_root, f"{line['index']}.jpg")
30 | if not read_ok(tgt_path):
31 | decode_base64_to_image_file(line['image'], tgt_path)
32 | return tgt_path
--------------------------------------------------------------------------------
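
A sketch of how a model wrapper is expected to use this interface (the class name is hypothetical; the return format mirrors TSVDataset.build_prompt below):

    from vlmeval.utils import CustomPrompt, DATASET_TYPE

    class DummyPromptModel(CustomPrompt):
        # Opt in to custom prompts only for multiple-choice benchmarks.
        def use_custom_prompt(self, dataset):
            return DATASET_TYPE(dataset) == 'multi-choice'

        def build_prompt(self, line, dataset=None):
            # dump_image decodes base64 images to LMUData/images/{dataset} and
            # returns the path (or list of paths) to feed the model.
            tgt_path = self.dump_image(line, dataset)
            return dict(image=tgt_path, text=line['question'])
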
/vlmeval/utils/dataset.py:
--------------------------------------------------------------------------------
1 | import pandas as pd
2 | import hashlib
3 | from ..smp import *
4 | from .dataset_config import dataset_URLs, dataset_md5_dict, DATASET_TYPE
5 | from .custom_prompt import CustomPrompt
6 |
7 | def isliststr(s):
8 | return (s[0] == '[') and (s[-1] == ']')
9 |
10 | def check_md5(data_path, dataset):
11 | try:
12 | with open(data_path, 'rb') as f:
13 | hash = hashlib.new('md5')
14 | for chunk in iter(lambda: f.read(2**20), b''):
15 | hash.update(chunk)
16 | if str(hash.hexdigest()) == dataset_md5_dict[dataset]:
17 | return True
18 | else:
19 | warnings.warn('this data file is incomplete, so it needs to be downloaded again.')
20 | return False
21 | except:
22 | return False
23 |
24 | def split_MMMU(struct):
25 | assert 'image' in struct and 'text' in struct
26 | text, images = struct['text'], struct['image']
27 |     text_segs = text.split('<image ')
28 |     segs = [text_segs[0]]
29 |     for i, seg in enumerate(text_segs):
30 |         if i == 0:
31 |             continue
32 |         assert istype(seg[0], int) and seg[1] == '>'
33 |         image_idx = int(seg[0]) - 1
34 |         segs.append(images[image_idx])
35 |         segs.append(seg[2:])
36 |     return segs
37 |
38 | class TSVDataset(CustomPrompt):
39 |
40 | def __init__(self, dataset='MMBench', skip_noimg=True):
41 |
42 | self.data_root = LMUDataRoot()
43 | assert osp.exists(self.data_root)
44 |
45 | self.dataset = dataset
46 | self.dataset_type = DATASET_TYPE(dataset)
47 |
48 | url = dataset_URLs[dataset]
49 | file_name = url.split('/')[-1]
50 | data_path = osp.join(self.data_root, file_name)
51 |
52 | #print(md5(data_path))
53 | #print(dataset_md5_dict[dataset])
54 | if osp.exists(data_path) and md5(data_path) == dataset_md5_dict[dataset]:
55 | pass
56 | else:
57 |             warnings.warn("The dataset tsv file is missing or corrupted, will re-download it.")
58 | download_file(url, data_path)
59 |
60 | data = load(data_path)
61 | self.skip_noimg = skip_noimg
62 | if skip_noimg:
63 | data = data[~pd.isna(data['image'])]
64 |
65 | # Prompt for Captioning
66 | if listinstr(['COCO'], dataset):
67 | data['question'] = ['Please describe this image in general. Directly provide the description, do not include prefix like "This image depicts". '] * len(data)
68 |
69 | data['index'] = [str(x) for x in data['index']]
70 | data['image'] = [str(x) for x in data['image']]
71 |
72 | image_map = {x: y for x, y in zip(data['index'], data['image'])}
73 | for k in image_map:
74 | if len(image_map[k]) <= 64:
75 | idx = image_map[k]
76 | assert idx in image_map and len(image_map[idx]) > 64
77 | image_map[k] = image_map[idx]
78 |
79 | data['image'] = [
80 | eval(image_map[k]) if isliststr(image_map[k]) else image_map[k]
81 | for k in data['index']
82 | ]
83 | if 'image_path' in data:
84 | data['image_path'] = [
85 | eval(pths) if isliststr(pths) else pths for pths in data['image_path']
86 | ]
87 | if np.all([istype(x, int) for x in data['index']]):
88 | data['index'] = [int(x) for x in data['index']]
89 |
90 | self.data = data
91 |
92 | def __len__(self):
93 | return len(self.data)
94 |
95 | def build_prompt(self, line, dataset=None):
96 | if dataset is None:
97 | dataset = self.dataset
98 |
99 | if isinstance(line, int):
100 | line = self.data.iloc[line]
101 |
102 | tgt_path = self.dump_image(line, dataset)
103 |
104 | prompt = line['question']
105 | if DATASET_TYPE(dataset) == 'multi-choice':
106 | question = line['question']
107 | options = {
108 | cand: line[cand]
109 | for cand in string.ascii_uppercase
110 | if cand in line and not pd.isna(line[cand])
111 | }
112 | options_prompt = 'Options:\n'
113 | for key, item in options.items():
114 | options_prompt += f'{key}. {item}\n'
115 | hint = line['hint'] if ('hint' in line and not pd.isna(line['hint'])) else None
116 | prompt = ''
117 | if hint is not None:
118 | prompt += f'Hint: {hint}\n'
119 | prompt += f'Question: {question}\n'
120 | if len(options):
121 | prompt += options_prompt
122 | prompt += 'Please select the correct answer from the options above. \n'
123 |
124 | return dict(image=tgt_path, text=prompt)
125 |
126 | def display(self, line):
127 | if isinstance(line, int):
128 | line = self.data.iloc[line]
129 | mmqa_display(line)
130 |
131 |
--------------------------------------------------------------------------------
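
A short usage sketch for TSVDataset (illustrative; the first run downloads the benchmark tsv into the LMUData root and verifies its md5):

    from vlmeval.utils import TSVDataset

    data = TSVDataset('MMBench_DEV_EN')
    print(len(data))
    sample = data.build_prompt(0)   # dict with 'image' (path or list of paths) and 'text' (prompt string)
    print(sample['text'])
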
/vlmeval/utils/dataset_config.py:
--------------------------------------------------------------------------------
1 | from ..smp import listinstr
2 |
3 | dataset_URLs = {
4 | 'MMBench_DEV_EN': "https://opencompass.openxlab.space/utils/VLMEval/MMBench_DEV_EN.tsv",
5 | 'MMBench_TEST_EN': "https://opencompass.openxlab.space/utils/VLMEval/MMBench_TEST_EN.tsv",
6 | 'MMBench_DEV_CN': "https://opencompass.openxlab.space/utils/VLMEval/MMBench_DEV_CN.tsv",
7 | 'MMBench_TEST_CN': "https://opencompass.openxlab.space/utils/VLMEval/MMBench_TEST_CN.tsv",
8 | "MMBench": "https://opencompass.openxlab.space/utils/VLMEval/MMBench.tsv", # Link Invalid, Internal Only
9 | "MMBench_CN": "https://opencompass.openxlab.space/utils/VLMEval/MMBench_CN.tsv", # Link Invalid, Internal Only
10 | 'CCBench': "https://opencompass.openxlab.space/utils/VLMEval/CCBench.tsv",
11 | 'MME': "https://opencompass.openxlab.space/utils/VLMEval/MME.tsv",
12 | 'SEEDBench_IMG': "https://opencompass.openxlab.space/utils/VLMEval/SEEDBench_IMG.tsv",
13 | "CORE_MM": "https://opencompass.openxlab.space/utils/VLMEval/CORE_MM.tsv",
14 | "MMVet": "https://opencompass.openxlab.space/utils/VLMEval/MMVet.tsv",
15 | "COCO_VAL": "https://opencompass.openxlab.space/utils/VLMEval/COCO_VAL.tsv",
16 | "OCRVQA_TEST": "https://opencompass.openxlab.space/utils/VLMEval/OCRVQA_TEST.tsv",
17 | "OCRVQA_TESTCORE": "https://opencompass.openxlab.space/utils/VLMEval/OCRVQA_TESTCORE.tsv",
18 | 'TextVQA_VAL': "https://opencompass.openxlab.space/utils/VLMEval/TextVQA_VAL.tsv",
19 | "MMMU_DEV_VAL": "https://opencompass.openxlab.space/utils/VLMEval/MMMU_DEV_VAL.tsv",
20 | "MMMU_TEST": "https://opencompass.openxlab.space/utils/VLMEval/MMMU_TEST.tsv",
21 | "MathVista_MINI": "https://opencompass.openxlab.space/utils/VLMEval/MathVista_MINI.tsv",
22 | 'ChartQA_VALTEST_HUMAN': "https://opencompass.openxlab.space/utils/VLMEval/ChartQA_VALTEST_HUMAN.tsv",
23 | 'ScienceQA_VAL': "https://opencompass.openxlab.space/utils/VLMEval/ScienceQA_VAL.tsv",
24 | 'ScienceQA_TEST': "https://opencompass.openxlab.space/utils/VLMEval/ScienceQA_TEST.tsv",
25 | 'HallusionBench': "https://opencompass.openxlab.space/utils/VLMEval/HallusionBench.tsv",
26 | "DocVQA_VAL": "https://opencompass.openxlab.space/utils/VLMEval/DocVQA_VAL.tsv",
27 | 'AI2D_TEST': "https://opencompass.openxlab.space/utils/VLMEval/AI2D_TEST.tsv",
28 | "LLaVABench": "https://opencompass.openxlab.space/utils/VLMEval/LLaVABench.tsv",
29 | }
30 |
31 | dataset_md5_dict = {
32 | 'MMBench_DEV_EN': "b6caf1133a01c6bb705cf753bb527ed8",
33 | 'MMBench_TEST_EN': "6939fadb0ce626fefc0bdc9c64efc528",
34 | 'MMBench_DEV_CN': "08b8fc3324a5ed74155350f57be69fbd",
35 | 'MMBench_TEST_CN': "7e1239baf0ee4c8b513e19705a0f317e",
36 | "MMBench": "4115aea3383f3dd0083be6a633e0f820", # Link Invalid, Internal Only
37 | "MMBench_CN": "2e053ffc90ea598b1feae13c36dc13ee", # Link Invalid, Internal Only
38 | 'CCBench': "1de88b4257e7eee3f60b18d45eda6f07",
39 | 'MME': "b36b43c3f09801f5d368627fb92187c3",
40 | 'SEEDBench_IMG': "68017231464752261a2526d6ca3a10c0",
41 | "CORE_MM": "8a8da2f2232e79caf98415bfdf0a202d",
42 | "MMVet": "f400d7f513a585a0f218cbd6882e0671",
43 | 'COCO_VAL': "72a5079dead060269ac222c5aa5128af",
44 | 'OCRVQA_TEST': 'ca46a6d74b403e9d6c0b670f6fc00db9',
45 | 'OCRVQA_TESTCORE': 'c5239fe77db8bdc1f2ad8e55e0d1fe97',
46 | 'TextVQA_VAL': 'b233b31f551bbf4056f2f955da3a92cd',
47 | 'MMMU_DEV_VAL': "521afc0f3bf341e6654327792781644d",
48 | 'MMMU_TEST' : "c19875d11a2d348d07e5eb4bdf33166d",
49 | 'MathVista_MINI': 'f199b98e178e5a2a20e7048f5dcb0464',
50 | 'ChartQA_VALTEST_HUMAN':'2c90a4133408a21d57fb2ea26f77bbfc',
51 | 'ScienceQA_VAL': '96320d05e142e585e7204e72affd29f3',
52 | 'ScienceQA_TEST': 'e42e9e00f9c59a80d8a5db35bc32b71f',
53 | 'HallusionBench': '0c23ac0dc9ef46832d7a24504f2a0c7c',
54 | "DocVQA_VAL": 'c911fdc5f4974513c112cc83a25c99d9',
55 | "AI2D_TEST": "0f593e0d1c7df9a3d69bf1f947e71975",
56 | "LLaVABench": "d382a093f749a697820d3dadd61c8428"
57 | }
58 |
59 | img_root_map = {k: k for k in dataset_URLs}
60 | img_root_map.update({
61 | 'MMBench_DEV_EN': "MMBench",
62 | 'MMBench_TEST_EN': "MMBench",
63 | 'MMBench_DEV_CN': "MMBench",
64 | 'MMBench_TEST_CN': "MMBench",
65 | "MMBench_CN": "MMBench", # Link Invalid, Internal Only
66 | 'COCO_VAL':'COCO',
67 | 'OCRVQA_TEST': 'OCRVQA',
68 | 'OCRVQA_TESTCORE': 'OCRVQA',
69 | 'TextVQA_VAL': 'TextVQA',
70 | 'MMMU_DEV_VAL': 'MMMU',
71 | "MMMU_TEST": "MMMU",
72 | 'MathVista_MINI': 'MathVista',
73 | 'ChartQA_VALTEST_HUMAN': 'ChartQA',
74 | 'HallusionBench': 'Hallusion',
75 | 'DocVQA_VAL': 'DocVQA',
76 | })
77 |
78 | assert set(dataset_URLs) == set(img_root_map) == set(dataset_md5_dict)
79 |
80 | def DATASET_TYPE(dataset):
81 | if listinstr(['mmbench', 'seedbench', 'ccbench', 'mmmu', 'scienceqa', 'ai2d'], dataset.lower()):
82 | return 'multi-choice'
83 | elif 'MME' in dataset:
84 | return 'Y/N'
85 | elif 'COCO' in dataset:
86 | return 'Caption'
87 | elif listinstr(['ocrvqa', 'textvqa', 'chartqa', 'mathvista', 'docvqa'], dataset.lower()):
88 | return 'VQA'
89 | else:
90 | return 'QA'
91 |
92 | def abbr2full(s):
93 | datasets = [x for x in img_root_map]
94 | ins = [s in d for d in datasets]
95 | if sum(ins) == 1:
96 | for d in datasets:
97 | if s in d:
98 | return d
99 | else:
100 | return None
101 |
--------------------------------------------------------------------------------
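
A few checks that follow directly from the tables and helpers above (illustrative):

    from vlmeval.utils.dataset_config import DATASET_TYPE, abbr2full

    assert DATASET_TYPE('MMBench_DEV_EN') == 'multi-choice'
    assert DATASET_TYPE('COCO_VAL') == 'Caption'
    assert DATASET_TYPE('TextVQA_VAL') == 'VQA'
    assert abbr2full('CCBench') == 'CCBench'       # unique match resolves to the full name
    assert abbr2full('MMBench_DEV') is None        # ambiguous abbreviation (EN and CN) returns None
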
/vlmeval/utils/matching_util.py:
--------------------------------------------------------------------------------
1 | import string
2 | import copy as cp
3 | import os
4 | from ..smp import *
5 |
6 | def can_infer_option(answer, choices):
7 | verbose = os.environ.get('VERBOSE', 0)
8 | # Choices is a dictionary
9 | if 'Failed to obtain answer via API' in answer:
10 | return False
11 |
12 | bard_err = [
13 | "Sorry, I can't help with images of people yet.",
14 | "I can't process this file."
15 | ]
16 | for err in bard_err:
17 | if err in answer:
18 | return 'Z'
19 |
20 | def count_choice(splits, choices, prefix='', suffix=''):
21 | cnt = 0
22 | for c in choices:
23 | if prefix + c + suffix in splits:
24 | cnt += 1
25 | return cnt
26 |
27 | answer_mod = cp.copy(answer)
28 | chars = '.()[],:;!*#{}'
29 | for c in chars:
30 | answer_mod = answer_mod.replace(c, ' ')
31 |
32 | splits = [x.strip() for x in answer_mod.split()]
33 | count = count_choice(splits, choices)
34 |
35 | if count == 1:
36 | for ch in choices:
37 | if 'A' in splits and len(splits) > 3 and verbose:
38 | logger = get_logger('Evaluation')
39 | logger.info(f'A might be a quantifier in the string: {answer}.')
40 | return False
41 | if ch in splits:
42 | return ch
43 | elif count == 0 and count_choice(splits, {'Z', ''}) == 1:
44 | return 'Z'
45 | return False
46 |
47 | def can_infer_text(answer, choices):
48 | answer = answer.lower()
49 | assert isinstance(choices, dict)
50 | for k in choices:
51 | assert k in string.ascii_uppercase
52 | choices[k] = str(choices[k]).lower()
53 | cands = []
54 | for k in choices:
55 | if choices[k] in answer:
56 | cands.append(k)
57 | if len(cands) == 1:
58 | return cands[0]
59 | return False
60 |
61 | def can_infer(answer, choices):
62 | answer = str(answer)
63 | copt = can_infer_option(answer, choices)
64 | return copt if copt else can_infer_text(answer, choices)
--------------------------------------------------------------------------------
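
A sketch of the answer-matching behaviour above (illustrative): can_infer first tries to match an option letter, then falls back to matching the option text.

    from vlmeval.utils import can_infer

    choices = dict(A='cat', B='dog', C='bird', D='fish')
    assert can_infer('The answer is B.', choices) == 'B'             # matched via the option letter
    assert can_infer('It looks like a bird to me', choices) == 'C'   # matched via the option text
    assert can_infer('I am not sure.', choices) is False             # nothing inferable
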
/vlmeval/utils/mp_util.py:
--------------------------------------------------------------------------------
1 | from multiprocessing import Pool
2 | import os
3 | from typing import Callable, Iterable, Sized
4 |
5 | from rich.progress import (BarColumn, MofNCompleteColumn, Progress, Task,
6 | TaskProgressColumn, TextColumn, TimeRemainingColumn)
7 | from rich.text import Text
8 | import os.path as osp
9 | import portalocker
10 | from ..smp import load, dump
11 |
12 |
13 | class _Worker:
14 | """Function wrapper for ``track_progress_rich``"""
15 |
16 | def __init__(self, func) -> None:
17 | self.func = func
18 |
19 | def __call__(self, inputs):
20 | inputs, idx = inputs
21 | if not isinstance(inputs, (tuple, list, dict)):
22 | inputs = (inputs, )
23 |
24 | if isinstance(inputs, dict):
25 | return self.func(**inputs), idx
26 | else:
27 | return self.func(*inputs), idx
28 |
29 |
30 | class _SkipFirstTimeRemainingColumn(TimeRemainingColumn):
31 | """Skip calculating remaining time for the first few times.
32 |
33 | Args:
34 | skip_times (int): The number of times to skip. Defaults to 0.
35 | """
36 |
37 | def __init__(self, *args, skip_times=0, **kwargs):
38 | super().__init__(*args, **kwargs)
39 | self.skip_times = skip_times
40 |
41 | def render(self, task: Task) -> Text:
42 | """Show time remaining."""
43 | if task.completed <= self.skip_times:
44 | return Text('-:--:--', style='progress.remaining')
45 | return super().render(task)
46 |
47 |
48 | def _tasks_with_index(tasks):
49 | """Add index to tasks."""
50 | for idx, task in enumerate(tasks):
51 | yield task, idx
52 |
53 | def track_progress_rich(func: Callable,
54 | tasks: Iterable = tuple(),
55 | task_num: int = None,
56 | nproc: int = 1,
57 | chunksize: int = 1,
58 | description: str = 'Processing',
59 | save=None, keys=None,
60 | color: str = 'blue') -> list:
61 | """Track the progress of parallel task execution with a progress bar. The
62 | built-in :mod:`multiprocessing` module is used for process pools and tasks
63 |     are done with :func:`Pool.imap_unordered` (or a plain loop when ``nproc`` is 1).
64 |
65 | Args:
66 | func (callable): The function to be applied to each task.
67 | tasks (Iterable or Sized): A tuple of tasks. There are several cases
68 | for different format tasks:
69 | - When ``func`` accepts no arguments: tasks should be an empty
70 | tuple, and ``task_num`` must be specified.
71 | - When ``func`` accepts only one argument: tasks should be a tuple
72 | containing the argument.
73 | - When ``func`` accepts multiple arguments: tasks should be a
74 | tuple, with each element representing a set of arguments.
75 | If an element is a ``dict``, it will be parsed as a set of
76 | keyword-only arguments.
77 | Defaults to an empty tuple.
78 | task_num (int, optional): If ``tasks`` is an iterator which does not
79 | have length, the number of tasks can be provided by ``task_num``.
80 | Defaults to None.
81 |         nproc (int): Number of worker processes; if ``nproc`` is 1,
82 |             a single process is used. Defaults to 1.
83 | chunksize (int): Refer to :class:`multiprocessing.Pool` for details.
84 | Defaults to 1.
85 | description (str): The description of progress bar.
86 |             Defaults to "Processing".
87 | color (str): The color of progress bar. Defaults to "blue".
88 |
89 | Examples:
90 | >>> import time
91 |
92 | >>> def func(x):
93 | ... time.sleep(1)
94 | ... return x**2
95 | >>> track_progress_rich(func, range(10), nproc=2)
96 |
97 | Returns:
98 | list: The task results.
99 | """
100 | if save is not None:
101 | assert osp.exists(osp.dirname(save)) or osp.dirname(save) == ''
102 | if not osp.exists(save):
103 | dump({}, save)
104 | if keys is not None:
105 | assert len(keys) == len(tasks)
106 |
107 | if not callable(func):
108 | raise TypeError('func must be a callable object')
109 | if not isinstance(tasks, Iterable):
110 | raise TypeError(
111 | f'tasks must be an iterable object, but got {type(tasks)}')
112 | if isinstance(tasks, Sized):
113 | if len(tasks) == 0:
114 | if task_num is None:
115 | raise ValueError('If tasks is an empty iterable, '
116 | 'task_num must be set')
117 | else:
118 | tasks = tuple(tuple() for _ in range(task_num))
119 | else:
120 | if task_num is not None and task_num != len(tasks):
121 | raise ValueError('task_num does not match the length of tasks')
122 | task_num = len(tasks)
123 |
124 | if nproc <= 0:
125 | raise ValueError('nproc must be a positive number')
126 |
127 | skip_times = nproc * chunksize if nproc > 1 else 0
128 | prog_bar = Progress(
129 | TextColumn('{task.description}'),
130 | BarColumn(),
131 | _SkipFirstTimeRemainingColumn(skip_times=skip_times),
132 | MofNCompleteColumn(),
133 | TaskProgressColumn(show_speed=True),
134 | )
135 |
136 | worker = _Worker(func)
137 | task_id = prog_bar.add_task(
138 | total=task_num, color=color, description=description)
139 | tasks = _tasks_with_index(tasks)
140 |
141 | # Use single process when nproc is 1, else use multiprocess.
142 | with prog_bar:
143 | if nproc == 1:
144 | results = []
145 | for task in tasks:
146 | result, idx = worker(task)
147 |                 results.append(result)
148 | if save is not None:
149 | with portalocker.Lock(save, timeout=5) as fh:
150 | ans = load(save)
151 | ans[keys[idx]] = result
152 |
153 |                         if os.environ.get('VERBOSE', False):
154 | print(keys[idx], result, flush=True)
155 |
156 | dump(ans, save)
157 | fh.flush()
158 | os.fsync(fh.fileno())
159 |
160 | prog_bar.update(task_id, advance=1, refresh=True)
161 | else:
162 | with Pool(nproc) as pool:
163 | results = []
164 | unordered_results = []
165 | gen = pool.imap_unordered(worker, tasks, chunksize)
166 | try:
167 | for result in gen:
168 | result, idx = result
169 | unordered_results.append((result, idx))
170 |
171 | if save is not None:
172 | with portalocker.Lock(save, timeout=5) as fh:
173 | ans = load(save)
174 | ans[keys[idx]] = result
175 |
176 | if os.environ.get('VERBOSE', False):
177 | print(keys[idx], result, flush=True)
178 |
179 | dump(ans, save)
180 | fh.flush()
181 | os.fsync(fh.fileno())
182 |
183 | results.append(None)
184 | prog_bar.update(task_id, advance=1, refresh=True)
185 | except Exception as e:
186 | prog_bar.stop()
187 | raise e
188 | for result, idx in unordered_results:
189 | results[idx] = result
190 | return results
--------------------------------------------------------------------------------
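
A sketch of the save/keys mode, which is not covered by the docstring example above (illustrative; it assumes the smp dump/load helpers accept a .pkl path, as they are used elsewhere in the repo):

    from vlmeval.utils import track_progress_rich

    def square(x):
        return x * x

    if __name__ == '__main__':
        tasks = [1, 2, 3, 4]
        keys = [f'item_{i}' for i in tasks]
        # Besides returning results in task order, every finished result is written
        # under its key into results.pkl, so partial progress survives an interruption.
        results = track_progress_rich(square, tasks, nproc=2, keys=keys, save='results.pkl')
        print(results)   # [1, 4, 9, 16]
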
/vlmeval/vlm/__init__.py:
--------------------------------------------------------------------------------
1 | import torch
2 | torch.set_grad_enabled(False)
3 | torch.manual_seed(1234)
4 |
5 | from .hpt import HPT
6 | from .hpt1_5 import HPT1_5
7 |
--------------------------------------------------------------------------------
/vlmeval/vlm/hpt1_5.py:
--------------------------------------------------------------------------------
1 | import os
2 | import os.path as osp
3 | import string
4 | import sys
5 | import warnings
6 | import re
7 | import torch.nn as nn
8 |
9 | import pandas as pd
10 | import torch
11 | from huggingface_hub import snapshot_download
12 | from PIL import Image
13 | from transformers import (AutoModel, AutoModelForCausalLM, AutoTokenizer,
14 | SiglipImageProcessor,
15 | GenerationConfig, StoppingCriteriaList)
16 |
17 | from ..smp import cn_string, get_cache_path
18 | from ..utils import DATASET_TYPE, CustomPrompt
19 | from .modeling_siglip import SiglipVisionModel
20 |
21 | def interpolate_pos_embed_siglip(model, new_size):
22 | pos_emb = model.vision_model.embeddings.position_embedding.weight.float()
23 | ori_size = int((pos_emb.shape[0])**0.5)
24 | dim = pos_emb.shape[1]
25 | print("Position interpolate from %dx%d to %dx%d" % (ori_size, ori_size, new_size, new_size))
26 | pos_tokens = pos_emb
27 | pos_tokens = pos_tokens.reshape(-1, ori_size, ori_size, dim).permute(0, 3, 1, 2)
28 | pos_tokens = torch.nn.functional.interpolate(
29 | pos_tokens, size=(new_size, new_size), mode='bicubic', align_corners=False)
30 | pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2).squeeze(0)
31 | new_pos_embed = pos_tokens #torch.cat((extra_tokens, pos_tokens), dim=0)
32 | new_pos_embed = new_pos_embed.to(torch.float16)
33 | return torch.nn.Parameter(new_pos_embed)
34 |
35 | class HPT1_5(CustomPrompt):
36 |
37 | INSTALL_REQ = True
38 |
39 | def __init__(self,
40 | global_model_path='HyperGAI/HPT1_5-Air-Llama-3-8B-Instruct-multimodal',
41 | vis_scale=448,
42 | visual_select_layer=-2,
43 | prompt_template='llama3_chat',
44 | stop_words=[],
45 | torch_dtype=torch.float16):
46 | from vlmeval.utils import PROMPT_TEMPLATE, StopWordStoppingCriteria
47 |
48 | llm = AutoModelForCausalLM.from_pretrained(global_model_path,
49 | subfolder='llm',
50 | trust_remote_code=True,
51 | torch_dtype=torch_dtype,
52 | device_map='cpu')
53 |
54 | tokenizer = AutoTokenizer.from_pretrained(global_model_path,
55 | subfolder='llm',
56 | trust_remote_code=True,
57 | encode_special_tokens=True)
58 |         print('Loaded LLM')
59 |
60 | # build visual_encoder
61 | image_size = vis_scale
62 | self.image_size = vis_scale
63 |
64 | image_processor = SiglipImageProcessor.from_pretrained(global_model_path,
65 | subfolder='visual_encoder',
66 | size={"height": image_size, "width": image_size})
67 |
68 | visual_encoder = SiglipVisionModel.from_pretrained(global_model_path,
69 | subfolder='visual_encoder',
70 | torch_dtype=torch_dtype, device_map='cpu')
71 |
72 | patch_size = visual_encoder.vision_model.embeddings.patch_size
73 | num_positions = (image_size//patch_size)**2
74 | new_size = image_size//patch_size
75 | visual_encoder.vision_model.embeddings.num_patches = (image_size//patch_size)**2
76 | visual_encoder.vision_model.embeddings.num_positions = num_positions
77 | visual_encoder.vision_model.embeddings.position_ids = torch.arange(num_positions).expand((1, -1))
78 | visual_encoder.vision_model.embeddings.position_embedding.weight = interpolate_pos_embed_siglip(visual_encoder, new_size)
79 | visual_encoder.config.image_size = image_size
80 |
81 |         print('Loaded visual_encoder')
82 |
83 |
84 | projector = AutoModel.from_pretrained(global_model_path,
85 | subfolder='projector',
86 | trust_remote_code=True,
87 | torch_dtype=torch_dtype,
88 | )
89 |         print('Loaded projector')
90 |
91 | llm.eval()
92 | visual_encoder.eval()
93 | projector.eval()
94 |
95 | self.llm = llm.cuda()
96 | self.tokenizer = tokenizer
97 | self.visual_encoder = visual_encoder.cuda()
98 | self.image_processor = image_processor
99 | self.projector = projector.cuda()
100 | self.visual_select_layer = visual_select_layer
101 | if prompt_template is not None:
102 | self.prompt_template = PROMPT_TEMPLATE[prompt_template]
103 | stop_words += self.prompt_template.get('STOP_WORDS', [])
104 | else:
105 | self.prompt_template = None
106 |
107 | self.stop_criteria = StoppingCriteriaList()
108 | for word in stop_words:
109 | self.stop_criteria.append(
110 | StopWordStoppingCriteria(self.tokenizer, word))
111 |
112 | def build_gen_config(self, dataset, qtype):
113 | gen_kwargs = dict(max_new_tokens=1024,
114 | do_sample=True,
115 | temperature=0.6,
116 | num_beams=5,
117 | top_p=0.9,
118 | eos_token_id=self.tokenizer.eos_token_id,
119 | pad_token_id=self.tokenizer.pad_token_id
120 | if self.tokenizer.pad_token_id is not None else
121 | self.tokenizer.eos_token_id)
122 | # For single word generation
123 | if (dataset is not None and DATASET_TYPE(dataset) in ['multi-choice', 'Y/N']):
124 | if qtype == '':
125 | gen_kwargs.update(
126 | dict(max_new_tokens=5, do_sample=False, num_beams=1))
127 | elif qtype == 'open':
128 | gen_kwargs.update(
129 | dict(max_new_tokens=1024, do_sample=False, num_beams=1))
130 |
131 | return GenerationConfig(**gen_kwargs)
132 |
133 | def use_custom_prompt(self, dataset):
134 | assert dataset is not None
135 | if DATASET_TYPE(dataset) == 'multi-choice':
136 | return True
137 | return False
138 |
139 | def build_prompt(self, line, dataset=None):
140 | assert self.use_custom_prompt(dataset)
141 | assert dataset is None or isinstance(dataset, str)
142 | tgt_path = self.dump_image(line, dataset)
143 |
144 | question = line['question']
145 | hint = line['hint'] if ('hint' in line
146 | and not pd.isna(line['hint'])) else None
147 |
148 | if hint is not None:
149 | try:
150 | question = hint + '\n' + question
151 | except:
152 | print(hint)
153 |
154 | options = {
155 | cand: line[cand]
156 | for cand in string.ascii_uppercase
157 | if cand in line and not pd.isna(line[cand])
158 | }
159 | for key, item in options.items():
160 | question += f'\n{key}. {item}'
161 | qtype = ''
162 | if not cn_string(question):
163 | if 'question_type' in line.keys() and 'choice' not in line['question_type']:
164 | prompt = question + '\n' + ("Answer with succinct phrase or a single word, incorporating professional terms when necessary, to address inquiries regarding scene description, key elements identification, and potential activities in the provided image.") #("Answer the question using a single word or phrase.")
165 | qtype = 'open'
166 | else:
167 | # prompt = question + '\n' + ("Answer with the option's letter from the given choices directly.")
168 |                 prompt = question + '\n' + "Answer the question with the correct option's letter from the given choices."
169 | else:
170 | prompt = question + '\n' + '请直接回答选项字母。'
171 |
172 | return {'image': tgt_path, 'text': prompt, 'qtype': qtype}
173 |
174 | def generate(self, image_path, prompt, dataset=None, qtype=''):
175 | from vlmeval.utils import expand2square, prepare_inputs_labels_for_multimodal, DEFAULT_IMAGE_TOKEN, IMAGE_TOKEN_INDEX
176 |
177 |         if isinstance(image_path, list) and len(image_path) > 1 and prompt.count('<image 1>') > 0 and prompt.count('<image 2>') > 0:
178 | image_list = [Image.open(image_path_cur).convert('RGB') for image_path_cur in image_path]
179 | image1 = image_list[0]
180 | w1, h1 = image1.size
181 |
182 | image2 = image_list[1]
183 | w2, h2 = image2.size
184 |
185 | w_sum, h_sum = w1 + w2, h1 + h2
186 | if w_sum > h_sum:
187 | h2_new = int(h2 * w2 / w1)
188 |                 image2 = image2.resize((w1, h2_new))
189 | image = Image.new('RGB', (w1, h1 + h2_new))
190 | image.paste(image1, (0, 0))
191 | image.paste(image2, (0, h1))
192 |
193 |                 prompt = prompt.replace('<image 1>', '')
194 |                 prompt = prompt.replace('<image 2>', '')
195 | else:
196 | w2_new = int(w2 * h2 / h1)
197 | image2 = image2.resize((w2_new, h2))
198 | image = Image.new('RGB', (w1 + w2_new, h2))
199 | image.paste(image1, (0, 0))
200 | image.paste(image2, (w1, 0))
201 |
202 |                 prompt = prompt.replace('<image 1>', '')
203 |                 prompt = prompt.replace('<image 2>', '')
204 |
205 | image_size = 448
206 | image = image.resize((image_size, image_size), 3)
207 | elif isinstance(image_path, list):
208 | image = Image.open(image_path[0]).convert('RGB')
209 | else:
210 | image = Image.open(image_path).convert('RGB')
211 |
212 | image = self.image_processor.preprocess(
213 | image, return_tensors='pt')['pixel_values'][0]
214 | image = image.cuda().unsqueeze(0)
215 | visual_outputs = self.visual_encoder(image, output_hidden_states=True)
216 | # pixel_values = self.projector(visual_outputs.hidden_states[self.visual_select_layer][:, 1:])
217 | pixel_values = self.projector(visual_outputs.hidden_states[self.visual_select_layer])
218 |
219 | inputs = DEFAULT_IMAGE_TOKEN + '\n' + prompt
220 |
221 | if self.prompt_template:
222 | inputs = self.prompt_template['INSTRUCTION'].format(input=inputs)
223 |
224 | chunk_encode = []
225 | for idx, chunk in enumerate(inputs.split(DEFAULT_IMAGE_TOKEN)):
226 | if idx == 0:
227 | cur_encode = self.tokenizer(chunk)
228 | else:
229 | cur_encode = self.tokenizer(chunk, add_special_tokens=False)
230 | chunk_encode.append(cur_encode)
231 | if len(chunk_encode) != 2:
232 | print(prompt)
233 |
234 | assert len(chunk_encode) == 2
235 | ids = []
236 | for idx, cur_chunk_encode in enumerate(chunk_encode):
237 | ids.extend(cur_chunk_encode['input_ids'])
238 | if idx != len(chunk_encode) - 1:
239 | ids.append(IMAGE_TOKEN_INDEX)
240 | ids = torch.tensor(ids).cuda().unsqueeze(0)
241 | mm_inputs = prepare_inputs_labels_for_multimodal(
242 | llm=self.llm, input_ids=ids, pixel_values=pixel_values)
243 |
244 | gen_config = self.build_gen_config(dataset, qtype)
245 |
246 | generate_output = self.llm.generate(
247 | **mm_inputs,
248 | generation_config=gen_config,
249 | streamer=None,
250 | bos_token_id=self.tokenizer.bos_token_id,
251 | stopping_criteria=self.stop_criteria)
252 | predict = self.tokenizer.decode(generate_output[0],
253 | skip_special_tokens=True).strip()
254 | return predict
255 |
--------------------------------------------------------------------------------
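
A minimal end-to-end sketch for the HPT1_5 wrapper above (illustrative; it assumes a CUDA GPU, enough memory for the 8B model, and network access to pull the HyperGAI checkpoint on first use):

    from vlmeval.vlm import HPT1_5

    # Loads the LLM, SigLIP visual encoder and projector listed in __init__ and moves them to GPU.
    model = HPT1_5(vis_scale=448)
    answer = model.generate('demo/einstein.jpg', 'Describe this image in one sentence.')
    print(answer)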