├── .gitignore ├── LICENSE ├── README.md ├── assets ├── leaderboard-1.0.png ├── leaderboard-1.5-air.png ├── leaderboard-1.5-edge.png └── pipeline.png ├── demo ├── cat.jpg ├── demo.py └── einstein.jpg ├── requirements.txt ├── run.py ├── run.sh ├── setup.py └── vlmeval ├── __init__.py ├── api ├── __init__.py ├── base.py ├── gemini.py ├── gpt.py ├── gpt_int.py ├── hf_chat_model.py └── qwen_vl_api.py ├── config.py ├── evaluate ├── __init__.py ├── coco_eval.py ├── llavabench.py ├── mathvista_eval.py ├── misc.py ├── mmmu_eval │ ├── README.md │ ├── __init__.py │ ├── answer_dict_val.json │ ├── configs │ │ └── llava1.5.yaml │ ├── example_outputs │ │ ├── llava1.5_13b │ │ │ ├── Accounting │ │ │ │ └── output.json │ │ │ ├── Agriculture │ │ │ │ └── output.json │ │ │ ├── Architecture_and_Engineering │ │ │ │ └── output.json │ │ │ ├── Art │ │ │ │ └── output.json │ │ │ ├── Art_Theory │ │ │ │ └── output.json │ │ │ ├── Basic_Medical_Science │ │ │ │ └── output.json │ │ │ ├── Biology │ │ │ │ └── output.json │ │ │ ├── Chemistry │ │ │ │ └── output.json │ │ │ ├── Clinical_Medicine │ │ │ │ └── output.json │ │ │ ├── Computer_Science │ │ │ │ └── output.json │ │ │ ├── Design │ │ │ │ └── output.json │ │ │ ├── Diagnostics_and_Laboratory_Medicine │ │ │ │ └── output.json │ │ │ ├── Economics │ │ │ │ └── output.json │ │ │ ├── Electronics │ │ │ │ └── output.json │ │ │ ├── Energy_and_Power │ │ │ │ └── output.json │ │ │ ├── Finance │ │ │ │ └── output.json │ │ │ ├── Geography │ │ │ │ └── output.json │ │ │ ├── History │ │ │ │ └── output.json │ │ │ ├── Literature │ │ │ │ └── output.json │ │ │ ├── Manage │ │ │ │ └── output.json │ │ │ ├── Marketing │ │ │ │ └── output.json │ │ │ ├── Materials │ │ │ │ └── output.json │ │ │ ├── Math │ │ │ │ └── output.json │ │ │ ├── Mechanical_Engineering │ │ │ │ └── output.json │ │ │ ├── Music │ │ │ │ └── output.json │ │ │ ├── Pharmacy │ │ │ │ └── output.json │ │ │ ├── Physics │ │ │ │ └── output.json │ │ │ ├── Psychology │ │ │ │ └── output.json │ │ │ ├── Public_Health │ │ │ │ └── output.json │ │ │ ├── Sociology │ │ │ │ └── output.json │ │ │ └── total_val_output.json │ │ ├── llava1.5_13b_val.json │ │ └── qwen_vl │ │ │ ├── Accounting │ │ │ └── output.json │ │ │ ├── Agriculture │ │ │ └── output.json │ │ │ ├── Architecture_and_Engineering │ │ │ └── output.json │ │ │ ├── Art │ │ │ └── output.json │ │ │ ├── Art_Theory │ │ │ └── output.json │ │ │ ├── Basic_Medical_Science │ │ │ └── output.json │ │ │ ├── Biology │ │ │ └── output.json │ │ │ ├── Chemistry │ │ │ └── output.json │ │ │ ├── Clinical_Medicine │ │ │ └── output.json │ │ │ ├── Computer_Science │ │ │ └── output.json │ │ │ ├── Design │ │ │ └── output.json │ │ │ ├── Diagnostics_and_Laboratory_Medicine │ │ │ └── output.json │ │ │ ├── Economics │ │ │ └── output.json │ │ │ ├── Electronics │ │ │ └── output.json │ │ │ ├── Energy_and_Power │ │ │ └── output.json │ │ │ ├── Finance │ │ │ └── output.json │ │ │ ├── Geography │ │ │ └── output.json │ │ │ ├── History │ │ │ └── output.json │ │ │ ├── Literature │ │ │ └── output.json │ │ │ ├── Manage │ │ │ └── output.json │ │ │ ├── Marketing │ │ │ └── output.json │ │ │ ├── Materials │ │ │ └── output.json │ │ │ ├── Math │ │ │ └── output.json │ │ │ ├── Mechanical_Engineering │ │ │ └── output.json │ │ │ ├── Music │ │ │ └── output.json │ │ │ ├── Pharmacy │ │ │ └── output.json │ │ │ ├── Physics │ │ │ └── output.json │ │ │ ├── Psychology │ │ │ └── output.json │ │ │ ├── Public_Health │ │ │ └── output.json │ │ │ ├── Sociology │ │ │ └── output.json │ │ │ └── total_val_output.json │ ├── main_eval_only.py │ ├── 
main_parse_and_eval.py │ ├── mmmu_eval_script.py │ ├── print_results.py │ ├── run_llava.py │ └── utils │ │ ├── data_utils.py │ │ ├── eval_utils.py │ │ └── model_utils.py ├── mmvet_eval.py ├── multiple_choice.py ├── vqa_eval.py └── yes_or_no.py ├── inference.py ├── smp ├── __init__.py ├── file.py ├── lb.py ├── log.py ├── misc.py └── vlm.py ├── utils ├── __init__.py ├── custom_prompt.py ├── dataset.py ├── dataset_config.py ├── matching_util.py ├── mp_util.py └── xtuner_util.py └── vlm ├── __init__.py ├── hpt.py ├── hpt1_5.py └── modeling_siglip.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 102 | #poetry.lock 103 | 104 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 105 | __pypackages__/ 106 | 107 | # Celery stuff 108 | celerybeat-schedule 109 | celerybeat.pid 110 | 111 | # SageMath parsed files 112 | *.sage.py 113 | 114 | # Environments 115 | .env 116 | .venv 117 | env/ 118 | venv/ 119 | ENV/ 120 | env.bak/ 121 | venv.bak/ 122 | 123 | # Spyder project settings 124 | .spyderproject 125 | .spyproject 126 | 127 | # Rope project settings 128 | .ropeproject 129 | 130 | # mkdocs documentation 131 | /site 132 | 133 | # mypy 134 | .mypy_cache/ 135 | .dmypy.json 136 | dmypy.json 137 | 138 | # Pyre type checker 139 | .pyre/ 140 | 141 | # pytype static type analyzer 142 | .pytype/ 143 | 144 | # Cython debug symbols 145 | cython_debug/ 146 | 147 | # project-specific 148 | output/ 149 | debug*/ 150 | *.bak 151 | *.dir 152 | *.dat 153 | *.tsv 154 | *.gz 155 | 156 | cache/ 157 | 158 | # models 159 | *.onnx 160 | *.pth 161 | *.zip -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 
47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # HPT - Open Multimodal Large Language Models 2 | Hyper-Pretrained Transformers (HPT) is a novel multimodal LLM framework from HyperGAI, and has been trained for vision-language models that are capable of understanding both textual and visual inputs. HPT has achieved highly competitive results with state-of-the-art models on a variety of multimodal LLM benchmarks. 
This repository contains the open-source implementation of inference code to reproduce the evaluation results of HPT on different benchmarks. 3 | 4 | ## Release 5 | - [6/06] :fire: Releasing **HPT 1.5 Edge**, our latest open-source model tailored to edge devices. Despite its size (<5B), Edge demonstrates impressive capabilities while being extremely efficient. HPT 1.5 Edge is publicly available on [[HuggingFace Repository](https://huggingface.co/HyperGAI/HPT1_5-Edge)]. Please read our [[technical blog post](https://hypergai.com/blog/hpt-1-5-edge-towards-multimodal-llms-for-edge-devices)] for more details. 6 | - [5/03] **HPT 1.5 Air**, our best open-sourced 8B Multimodal LLM with [Llama 3](https://huggingface.co/blog/llama3). Built with Meta Llama 3, our hyper-capable HPT 1.5 Air packs a punch on real-world understanding and complex reasoning. HPT 1.5 Air achieves the best results among <10B models across a wide range of challenging benchmarks (MMMU, POPE, SEED-I, and more). HPT 1.5 Air is publicly available on [[HuggingFace Repository](https://huggingface.co/HyperGAI/HPT1_5-Air-Llama-3-8B-Instruct-multimodal)]. Please read our [[technical blog post](https://hypergai.com/blog/hpt-1-5-air-best-open-sourced-8b-multimodal-llm-with-llama-3)] for more details. 7 | - [3/16] **HPT 1.0 Air** is out, our most efficient model and a cost-effective solution capable of solving a wide range of vision-and-language tasks. HPT 1.0 Air is publicly available and achieves state-of-the-art results among all the open-source multimodal LLM models of similar or smaller sizes on the challenging MMMU benchmark. Please read our [[technical blog post](https://www.hypergai.com/blog/introducing-hpt-a-family-of-leading-multimodal-llms)] and [[HuggingFace Repository](https://huggingface.co/HyperGAI/HPT)] for more details. 8 | 9 | 10 | We release HPT 1.5 Edge as our latest open-source model tailored to edge devices. Despite its size (<5B), Edge demonstrates impressive capabilities while being extremely efficient. We release HPT 1.5 Edge publicly on Hugging Face and GitHub under the Apache 2.0 license. 11 | 12 | 13 | 14 | ## Table of Contents 15 | - [Overview of Model Architecture](#overview-of-model-architecture) 16 | - [Quick Start](#quick-start) 17 | - [Installation](#installation) 18 | - [Prepare the Model](#prepare-the-model) 19 | - [Demo](#demo) 20 | - [Evaluations](#evaluations) 21 | - [Benchmarks](#benchmarks) 22 | - [Pretrained Models Used](#pretrained-models-used) 23 | - [Disclaimer and Responsible Use](#disclaimer-and-responsible-use) 24 | - [Contact Us](#contact-us) 25 | - [License](#license) 26 | - [Acknowledgements](#acknowledgements) 27 | 28 | 29 | 30 | ## Overview of Model Architecture 31 | 32 |
33 | [Figure: HPT model architecture overview (assets/pipeline.png)] 34 |
35 |
36 | 37 | ## Quick Start 38 | 39 | ### Installation 40 | 41 | ``` 42 | pip install -r requirements.txt 43 | pip install -e . 44 | ``` 45 | 46 | ### Prepare the Model 47 | 48 | You can download the model weights from HF into your [Local Path] and set the `global_model_path` as your [Local Path] in the model [config file](./vlmeval/config.py#L24): 49 | ``` 50 | git lfs install 51 | git clone https://huggingface.co/HyperGAI/HPT1_5-Edge [Local Path] 52 | ``` 53 | 54 | You can also set other strategies in the [config file](./vlmeval/config.py#L24) that are different from our default settings. 55 | 56 | ### Demo 57 | 58 | After setting up the config file, launch the model demo for a quick trial: 59 | 60 | ``` 61 | python demo/demo.py --image_path [Image] --text [Text] --model [Config] 62 | ``` 63 | 64 | Example: 65 | 66 | ``` 67 | python demo/demo.py --image_path demo/einstein.jpg --text 'What is unusual about this image?' --model hpt-edge-1-5 68 | ``` 69 | 70 | ## Evaluations 71 | 72 | Launch the model for evaluation: 73 | 74 | ``` 75 | torchrun --nproc-per-node=8 run.py --data [Dataset] --model [Config] 76 | ``` 77 | 78 | Example for HPT 1.5 Edge: 79 | 80 | ``` 81 | torchrun --nproc-per-node=8 run.py --data MMMU_DEV_VAL --model hpt-edge-1-5 82 | ``` 83 | 84 | ## Benchmarks 85 | 86 | **For HPT 1.5 Edge** 87 | 88 |
89 | [Figure: HPT 1.5 Edge benchmark leaderboard (assets/leaderboard-1.5-edge.png)]
91 |
92 | 93 | - The majority of the results presented are taken from the models' original reports, while the others are from Phi-3-vision evaluations, which we mark with an asterisk (*). 94 | - The benchmark results of HPT 1.5 Air and HPT 1.0 are in the assets directory. 95 | 96 | 97 | ## Pretrained Models Used 98 | 99 | **HPT 1.5 Edge** 100 | 101 | - Pretrained LLM: [Phi-3-mini-4k-instruct](https://huggingface.co/microsoft/Phi-3-mini-4k-instruct) 102 | 103 | - Pretrained Visual Encoder: [siglip-so400m-patch14-384](https://huggingface.co/google/siglip-so400m-patch14-384) 104 | 105 | **HPT 1.5 Air** 106 | 107 | - Pretrained LLM: [Llama3-8B-Instruct](https://huggingface.co/meta-llama/Meta-Llama-3-8B-Instruct) 108 | 109 | - Pretrained Visual Encoder: [siglip-so400m-patch14-384](https://huggingface.co/google/siglip-so400m-patch14-384) 110 | 111 | **HPT 1.0 Air** 112 | 113 | - Pretrained LLM: [Yi-6B-Chat](https://huggingface.co/01-ai/Yi-6B) 114 | 115 | - Pretrained Visual Encoder: [clip-vit-large-patch14-336](https://huggingface.co/openai/clip-vit-large-patch14-336) 116 | 117 | ## Disclaimer and Responsible Use 118 | 119 | Note that HPT Air is a quick open release of our models to facilitate open, responsible AI research and community development. It does not include any moderation mechanism and provides no guarantees on its results. We hope to engage with the community to make the models respect guardrails well enough to allow practical adoption in real-world applications that require moderated outputs. 120 | 121 | ## Contact Us 122 | 123 | - Contact: HPT@hypergai.com 124 | - Follow us on [Twitter](https://twitter.com/hypergai). 125 | - Follow us on [LinkedIn](https://www.linkedin.com/company/hypergai/). 126 | - Visit our [website](https://www.hypergai.com/) to learn more about us. 127 | 128 | 129 | ## License 130 | 131 | This project is released under the [Apache 2.0 license](LICENSE). Parts of this project contain code and models from other sources, which are subject to their respective licenses; you must comply with those licenses if you want to use them for commercial purposes. 132 | 133 | ## Acknowledgements 134 | 135 | The evaluation code for running this demo was extended from the [VLMEvalKit project](https://github.com/open-compass/VLMEvalKit). 136 | We also thank [OpenAI](https://openai.com) for open-sourcing their visual encoder models, and [01.AI](https://www.01.ai), [Meta](https://www.meta.com/) and [Microsoft](https://www.microsoft.com/) for open-sourcing their large language models.
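For a quick end-to-end check after setup, the same interface used by `demo/demo.py` can also be called directly from Python. The snippet below is a minimal sketch, assuming the package has been installed with `pip install -e .` and `global_model_path` in the [config file](./vlmeval/config.py#L24) points to your local copy of the weights:

```
from vlmeval.config import supported_VLM

# Instantiate the model registered under the config name used throughout this README.
model = supported_VLM['hpt-edge-1-5']()

# demo/demo.py passes image_path as a list (argparse nargs='+'), so we do the same here.
response = model.generate(
    prompt='What is unusual about this image?',
    image_path=['demo/einstein.jpg'],
    dataset='demo',
)
print(response)
```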
137 | -------------------------------------------------------------------------------- /assets/leaderboard-1.0.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HyperGAI/HPT/b65b9a07660d163a3f426b1d79f289f89ab22a91/assets/leaderboard-1.0.png -------------------------------------------------------------------------------- /assets/leaderboard-1.5-air.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HyperGAI/HPT/b65b9a07660d163a3f426b1d79f289f89ab22a91/assets/leaderboard-1.5-air.png -------------------------------------------------------------------------------- /assets/leaderboard-1.5-edge.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HyperGAI/HPT/b65b9a07660d163a3f426b1d79f289f89ab22a91/assets/leaderboard-1.5-edge.png -------------------------------------------------------------------------------- /assets/pipeline.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HyperGAI/HPT/b65b9a07660d163a3f426b1d79f289f89ab22a91/assets/pipeline.png -------------------------------------------------------------------------------- /demo/cat.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HyperGAI/HPT/b65b9a07660d163a3f426b1d79f289f89ab22a91/demo/cat.jpg -------------------------------------------------------------------------------- /demo/demo.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from vlmeval.config import supported_VLM 3 | 4 | def parse_args(): 5 | parser = argparse.ArgumentParser() 6 | parser.add_argument('--image_path', type=str, nargs='+', required=True) 7 | parser.add_argument('--text', type=str, required=True) 8 | parser.add_argument("--model", type=str, required=True) 9 | parser.add_argument("--nproc", type=int, default=4, required=False) 10 | parser.add_argument("--verbose", action='store_true') 11 | args = parser.parse_args() 12 | return args 13 | 14 | def main(): 15 | args = parse_args() 16 | model_name = args.model 17 | model = supported_VLM[model_name]() 18 | text = args.text 19 | image_path = args.image_path 20 | response = model.generate(prompt=text, image_path=image_path, dataset='demo') 21 | 22 | print(response) 23 | 24 | 25 | if __name__ == '__main__': 26 | main() 27 | -------------------------------------------------------------------------------- /demo/einstein.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HyperGAI/HPT/b65b9a07660d163a3f426b1d79f289f89ab22a91/demo/einstein.jpg -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | numpy>=1.23.4 2 | openai==1.3.5 3 | requests 4 | tqdm 5 | pandas>=1.5.3 6 | gradio==4.15.0 7 | tiktoken 8 | rich 9 | portalocker 10 | timeout-decorator 11 | opencv-python==4.8.0.74 12 | pillow==10.2.0 13 | omegaconf 14 | matplotlib 15 | einops 16 | sentencepiece 17 | sty 18 | huggingface_hub 19 | visual_genome 20 | pycocoevalcap 21 | openpyxl 22 | seaborn 23 | tabulate 24 | xlsxwriter 25 | torch>=2.0.1 26 | typing_extensions==4.10.0 27 | transformers==4.37.0 28 | accelerate 29 | mmengine 
-------------------------------------------------------------------------------- /run.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.distributed as dist 3 | from vlmeval.smp import * 4 | from vlmeval.evaluate import COCO_eval, YOrN_eval, MMVet_eval, multiple_choice_eval, VQAEval, MathVista_eval, LLaVABench_eval 5 | from vlmeval.inference import infer_data_job, prefetch_acc 6 | from vlmeval.config import supported_VLM 7 | from vlmeval.utils import dataset_URLs, abbr2full 8 | from vlmeval.evaluate import mmmu_eval_func 9 | 10 | def parse_args(): 11 | parser = argparse.ArgumentParser() 12 | parser.add_argument('--data', type=str, nargs='+', required=True) 13 | parser.add_argument("--model", type=str, nargs='+', required=True) 14 | parser.add_argument("--work-dir", type=str, default='.', help="select the output directory") 15 | parser.add_argument("--mode", type=str, default='all', choices=['all', 'infer']) 16 | parser.add_argument("--nproc", type=int, default=4, help="Parallel API calling") 17 | parser.add_argument("--ignore", action='store_true', help="Ignore failed indices. ") 18 | parser.add_argument("--verbose", action='store_true') 19 | parser.add_argument("--prefetch", action='store_true') 20 | args = parser.parse_args() 21 | return args 22 | 23 | def main(): 24 | logger = get_logger('RUN') 25 | 26 | args = parse_args() 27 | assert len(args.data), "--data should be a list of data files" 28 | 29 | rank, world_size = get_rank_and_world_size() 30 | if world_size > 1: 31 | torch.cuda.set_device(rank) 32 | dist.init_process_group(backend='nccl', timeout=datetime.timedelta(seconds=5400)) 33 | 34 | for _, model_name in enumerate(args.model): 35 | model = None 36 | 37 | pred_root = osp.join(args.work_dir, 'output', model_name) 38 | os.makedirs(pred_root, exist_ok=True) 39 | 40 | for i, dataset_name in enumerate(args.data): 41 | if dataset_name not in dataset_URLs: 42 | dataset_name = abbr2full(dataset_name) 43 | 44 | if dataset_name not in dataset_URLs: 45 | logger.error(f'Unknown dataset: {dataset_name}. ') 46 | continue 47 | 48 | result_file = f'{pred_root}/{model_name}_{dataset_name}.xlsx' 49 | 50 | if model is None: 51 | model = model_name # which is only a name 52 | 53 | # CHECKER 54 | if dataset_name == 'CORE_MM': 55 | MULTI_IMG = getattr(supported_VLM[model_name].func, 'multi_generate', None) 56 | if MULTI_IMG is not None: 57 | logger.error(f'Model {model_name} does not support the `multi_generate` interface, which is required for testing CORE_MM, skip it. ') 58 | continue 59 | if args.mode == 'all': 60 | logger.error(f'Dataset {dataset_name} does not support `evaluation` now, will skip the evaluation. ') 61 | 62 | model = infer_data_job(model, work_dir=pred_root, model_name=model_name, dataset_name=dataset_name, verbose=args.verbose, api_nproc=args.nproc, ignore_failed=args.ignore) 63 | 64 | if dataset_name in ['MMBench_TEST_CN', 'MMBench_TEST_EN', "MMMU_TEST"]: 65 | if not MMBenchOfficialServer(): 66 | logger.error(f'Can not evaluate {dataset_name} on non-official servers, will skip the evaluation. 
') 67 | continue 68 | 69 | if rank == 0: 70 | time.sleep(3) 71 | res = None 72 | if listinstr(['SEEDBench_IMG', 'MMBench', 'CCBench', 'ScienceQA', 'AI2D'], dataset_name): 73 | res = prefetch_acc(result_file) 74 | else: 75 | logger.warning(f'{dataset_name} is not handled by prefetch score calculator') 76 | if res is not None: 77 | logger.info(f'{model_name} prefetching: ') 78 | logger.info(res) 79 | dump(res, result_file.replace('.xlsx', '_prefetch.xlsx')) 80 | 81 | if rank == 0 and args.mode == 'all': 82 | import pandas as pd 83 | import json 84 | if listinstr(['MMMU'], dataset_name): 85 | res_rec = pd.read_excel(result_file) 86 | recs = {} 87 | for id_, pred_ in zip(res_rec['id'], res_rec['prediction']): 88 | if 'dev_' in id_: 89 | continue 90 | recs[id_] = str(pred_) 91 | output_file = result_file.replace('.xlsx', '.json') 92 | json.dump(recs, open(output_file, 'w')) 93 | mmmu_eval_func(output_file) 94 | 95 | elif listinstr(['MMBench', 'CCBench', 'SEEDBench_IMG', 'ScienceQA', 'AI2D'], dataset_name): 96 | multiple_choice_eval(result_file, dataset=dataset_name, model='chatgpt-0613', nproc=args.nproc, verbose=args.verbose) 97 | elif listinstr(['MME', 'Hallusion'], dataset_name): 98 | YOrN_eval(result_file, model='chatgpt-0613', nproc=args.nproc, verbose=args.verbose, dataset=dataset_name) 99 | elif dataset_name == 'MMVet': 100 | MMVet_eval(result_file, model='gpt-4-turbo', nproc=args.nproc, verbose=args.verbose) 101 | elif listinstr(['COCO'], dataset_name): 102 | COCO_eval(result_file) 103 | elif listinstr(['OCRVQA', 'TextVQA', 'ChartQA', 'DocVQA'], dataset_name): 104 | VQAEval(result_file, dataset_name) 105 | elif listinstr(['MathVista'], dataset_name): 106 | MathVista_eval(result_file, model='gpt-4-turbo', nproc=args.nproc, verbose=args.verbose) 107 | elif listinstr(['LLaVABench'], dataset_name): 108 | LLaVABench_eval(result_file, model='gpt-4-turbo', nproc=args.nproc, verbose=args.verbose) 109 | else: 110 | logger.error(f'Dataset {dataset_name} is not handled by evaluator, will be skipped. ') 111 | 112 | if __name__ == '__main__': 113 | main() -------------------------------------------------------------------------------- /run.sh: -------------------------------------------------------------------------------- 1 | torchrun --nproc-per-node=8 run.py --data MMMU_DEV_VAL --model hpt-edge-1-5 2 | 3 | torchrun --nproc-per-node=8 run.py --data MMBench_DEV_EN --model hpt-edge-1-5 4 | 5 | torchrun --nproc-per-node=8 run.py --data SEEDBench_IMG --model hpt-edge-1-5 6 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | import re 2 | import sys 3 | from os.path import exists 4 | from setuptools import find_packages, setup 5 | 6 | def parse_requirements(fname='requirements.txt', with_version=True): 7 | """Parse the package dependencies listed in a requirements file but strips 8 | specific versioning information. 
9 | 10 | Args: 11 | fname (str): path to requirements file 12 | with_version (bool, default=False): if True include version specs 13 | 14 | Returns: 15 | List[str]: list of requirements items 16 | 17 | CommandLine: 18 | python -c "import setup; print(setup.parse_requirements())" 19 | """ 20 | 21 | require_fpath = fname 22 | 23 | def parse_line(line): 24 | """Parse information from a line in a requirements text file.""" 25 | if line.startswith('-r '): 26 | # Allow specifying requirements in other files 27 | target = line.split(' ')[1] 28 | for info in parse_require_file(target): 29 | yield info 30 | else: 31 | info = {'line': line} 32 | if line.startswith('-e '): 33 | info['package'] = line.split('#egg=')[1] 34 | elif '@git+' in line: 35 | info['package'] = line 36 | else: 37 | # Remove versioning from the package 38 | pat = '(' + '|'.join(['>=', '==', '>']) + ')' 39 | parts = re.split(pat, line, maxsplit=1) 40 | parts = [p.strip() for p in parts] 41 | 42 | info['package'] = parts[0] 43 | if len(parts) > 1: 44 | op, rest = parts[1:] 45 | if ';' in rest: 46 | # Handle platform specific dependencies 47 | # http://setuptools.readthedocs.io/en/latest/setuptools.html#declaring-platform-specific-dependencies 48 | version, platform_deps = map(str.strip, 49 | rest.split(';')) 50 | info['platform_deps'] = platform_deps 51 | else: 52 | version = rest # NOQA 53 | info['version'] = (op, version) 54 | yield info 55 | 56 | def parse_require_file(fpath): 57 | with open(fpath, 'r') as f: 58 | for line in f.readlines(): 59 | line = line.strip() 60 | if line and not line.startswith('#'): 61 | for info in parse_line(line): 62 | yield info 63 | 64 | def gen_packages_items(): 65 | if exists(require_fpath): 66 | for info in parse_require_file(require_fpath): 67 | parts = [info['package']] 68 | if with_version and 'version' in info: 69 | parts.extend(info['version']) 70 | if not sys.version.startswith('3.4'): 71 | # apparently package_deps are broken in 3.4 72 | platform_deps = info.get('platform_deps') 73 | if platform_deps is not None: 74 | parts.append(';' + platform_deps) 75 | item = ''.join(parts) 76 | yield item 77 | 78 | packages = list(gen_packages_items()) 79 | return packages 80 | 81 | 82 | with open('README.md') as f: 83 | readme = f.read() 84 | 85 | 86 | def do_setup(): 87 | setup( 88 | name='vlmeval', 89 | version='0.1.0', 90 | description='OpenCompass VLM Evaluation Kit', 91 | author="Haodong Duan", 92 | author_email='dhd.efz@gmail.com', 93 | maintainer='Haodong Duan', 94 | maintainer_email='dhd.efz@gmail.com', 95 | long_description=readme, 96 | long_description_content_type='text/markdown', 97 | cmdclass={}, 98 | install_requires=parse_requirements('requirements.txt'), 99 | setup_requires=[], 100 | python_requires='>=3.7.0', 101 | packages=find_packages(exclude=[ 102 | 'test*', 103 | 'paper_test*', 104 | ]), 105 | keywords=['AI', 'NLP', 'in-context learning'], 106 | entry_points={ 107 | "console_scripts": [] 108 | }, 109 | classifiers=[ 110 | 'Programming Language :: Python :: 3.7', 111 | 'Programming Language :: Python :: 3.8', 112 | 'Programming Language :: Python :: 3.9', 113 | 'Programming Language :: Python :: 3.10', 114 | 'Intended Audience :: Developers', 115 | 'Intended Audience :: Education', 116 | 'Intended Audience :: Science/Research', 117 | ]) 118 | 119 | 120 | if __name__ == '__main__': 121 | do_setup() 122 | -------------------------------------------------------------------------------- /vlmeval/__init__.py: 
-------------------------------------------------------------------------------- 1 | try: 2 | import torch 3 | except ImportError: 4 | pass 5 | 6 | from .smp import * 7 | from .api import * 8 | from .evaluate import * 9 | from .utils import * 10 | from .vlm import * 11 | from .config import * -------------------------------------------------------------------------------- /vlmeval/api/__init__.py: -------------------------------------------------------------------------------- 1 | from .gpt import OpenAIWrapper, GPT4V 2 | from .gpt_int import OpenAIWrapperInternal, GPT4V_Internal 3 | from .hf_chat_model import HFChatModel 4 | from .gemini import GeminiWrapper, GeminiProVision 5 | from .qwen_vl_api import QwenVLWrapper, QwenVLAPI 6 | 7 | __all__ = [ 8 | 'OpenAIWrapper', 'HFChatModel', 'OpenAIWrapperInternal', 'GeminiWrapper', 9 | 'GPT4V', 'GPT4V_Internal', 'GeminiProVision','QwenVLWrapper', 'QwenVLAPI' 10 | ] -------------------------------------------------------------------------------- /vlmeval/api/base.py: -------------------------------------------------------------------------------- 1 | import time 2 | import random as rd 3 | from abc import abstractmethod 4 | from ..smp import get_logger 5 | 6 | class BaseAPI: 7 | 8 | def __init__(self, 9 | retry=10, 10 | wait=3, 11 | system_prompt=None, 12 | verbose=True, 13 | fail_msg='Failed to obtain answer via API.', 14 | **kwargs): 15 | self.wait = wait 16 | self.retry = retry 17 | self.system_prompt = system_prompt 18 | self.kwargs = kwargs 19 | self.verbose = verbose 20 | self.fail_msg = fail_msg 21 | self.logger = get_logger('ChatAPI') 22 | if len(kwargs): 23 | self.logger.info(f'BaseAPI received the following kwargs: {kwargs}') 24 | self.logger.info(f'Will try to use them as kwargs for `generate`. ') 25 | 26 | @abstractmethod 27 | def generate_inner(self, inputs, **kwargs): 28 | self.logger.warning(f'For APIBase, generate_inner is an abstract method. 
') 29 | assert 0, 'generate_inner not defined' 30 | ret_code, answer, log = None, None, None 31 | # if ret_code is 0, means succeed 32 | return ret_code, answer, log 33 | 34 | def generate(self, inputs, **kwargs): 35 | input_type = None 36 | if isinstance(inputs, str): 37 | input_type = 'str' 38 | elif isinstance(inputs, list) and isinstance(inputs[0], str): 39 | input_type = 'strlist' 40 | elif isinstance(inputs, list) and isinstance(inputs[0], dict): 41 | input_type = 'dictlist' 42 | assert input_type is not None, input_type 43 | 44 | answer = None 45 | for i in range(self.retry): 46 | T = rd.random() * self.wait * 2 47 | time.sleep(T) 48 | try: 49 | ret_code, answer, log = self.generate_inner(inputs, **kwargs) 50 | if ret_code == 0 and self.fail_msg not in answer and answer != '': 51 | if self.verbose: 52 | print(answer) 53 | return answer 54 | elif self.verbose: 55 | self.logger.info(f"RetCode: {ret_code}\nAnswer: {answer}\nLog: {log}") 56 | except Exception as err: 57 | if self.verbose: 58 | self.logger.error(f'An error occured during try {i}:') 59 | self.logger.error(err) 60 | 61 | return self.fail_msg if answer in ['', None] else answer 62 | 63 | 64 | -------------------------------------------------------------------------------- /vlmeval/api/gemini.py: -------------------------------------------------------------------------------- 1 | from vlmeval.smp import * 2 | from vlmeval.api.base import BaseAPI 3 | 4 | headers = 'Content-Type: application/json' 5 | 6 | class GeminiWrapper(BaseAPI): 7 | 8 | is_api: bool = True 9 | 10 | def __init__(self, 11 | retry: int = 5, 12 | wait: int = 5, 13 | key: str = None, 14 | verbose: bool = True, 15 | temperature: float = 0.0, 16 | system_prompt: str = None, 17 | max_tokens: int = 1024, 18 | proxy: str = None, 19 | **kwargs): 20 | 21 | self.fail_msg = 'Failed to obtain answer via API. 
' 22 | self.max_tokens = max_tokens 23 | self.temperature = temperature 24 | if key is None: 25 | key = os.environ.get('GOOGLE_API_KEY', None) 26 | assert key is not None 27 | self.api_key = key 28 | if proxy is not None: 29 | proxy_set(proxy) 30 | super().__init__(wait=wait, retry=retry, system_prompt=system_prompt, verbose=verbose, **kwargs) 31 | 32 | @staticmethod 33 | def build_msgs(msgs_raw, system_prompt=None): 34 | msgs = cp.deepcopy(msgs_raw) 35 | assert len(msgs) % 2 == 1 36 | 37 | if system_prompt is not None: 38 | msgs[0] = [system_prompt, msgs[0]] 39 | ret = [] 40 | for i, msg in enumerate(msgs): 41 | role = 'user' if i % 2 == 0 else 'model' 42 | parts = msg if isinstance(msg, list) else [msg] 43 | ret.append(dict(role=role, parts=parts)) 44 | return ret 45 | 46 | def generate_inner(self, inputs, **kwargs) -> str: 47 | import google.generativeai as genai 48 | assert isinstance(inputs, str) or isinstance(inputs, list) 49 | pure_text = True 50 | if isinstance(inputs, list): 51 | for pth in inputs: 52 | if osp.exists(pth) or pth.startswith('http'): 53 | pure_text = False 54 | genai.configure(api_key=self.api_key) 55 | model = genai.GenerativeModel('gemini-pro') if pure_text else genai.GenerativeModel('gemini-pro-vision') 56 | if isinstance(inputs, str): 57 | messages = [inputs] if self.system_prompt is None else [self.system_prompt, inputs] 58 | elif pure_text: 59 | messages = self.build_msgs(inputs, self.system_prompt) 60 | else: 61 | messages = [] if self.system_prompt is None else [self.system_prompt] 62 | for s in inputs: 63 | if osp.exists(s): 64 | messages.append(Image.open(s)) 65 | elif s.startswith('http'): 66 | pth = download_file(s) 67 | messages.append(Image.open(pth)) 68 | shutil.remove(pth) 69 | else: 70 | messages.append(s) 71 | gen_config = dict(max_output_tokens=self.max_tokens, temperature=self.temperature) 72 | gen_config.update(self.kwargs) 73 | try: 74 | answer = model.generate_content(messages, generation_config=genai.types.GenerationConfig(**gen_config)).text 75 | return 0, answer, 'Succeeded! 
' 76 | except Exception as err: 77 | if self.verbose: 78 | self.logger.error(err) 79 | self.logger.error(f"The input messages are {inputs}.") 80 | 81 | return -1, '', '' 82 | 83 | 84 | 85 | class GeminiProVision(GeminiWrapper): 86 | 87 | def generate(self, image_path, prompt, dataset=None): 88 | return super(GeminiProVision, self).generate([image_path, prompt]) 89 | 90 | def multi_generate(self, image_paths, prompt, dataset=None): 91 | return super(GeminiProVision, self).generate(image_paths + [prompt]) 92 | 93 | def interleave_generate(self, ti_list, dataset=None): 94 | return super(GeminiProVision, self).generate(ti_list) -------------------------------------------------------------------------------- /vlmeval/api/gpt.py: -------------------------------------------------------------------------------- 1 | from ..smp import * 2 | import os, sys 3 | from .base import BaseAPI 4 | 5 | APIBASES = { 6 | 'OFFICIAL': "https://api.openai.com/v1/chat/completions", 7 | } 8 | 9 | 10 | def GPT_context_window(model): 11 | length_map = { 12 | 'gpt-4-1106-preview': 128000, 13 | 'gpt-4-vision-preview': 128000, 14 | 'gpt-4': 8192, 15 | 'gpt-4-32k': 32768, 16 | 'gpt-4-0613': 8192, 17 | 'gpt-4-32k-0613': 32768, 18 | 'gpt-3.5-turbo-1106': 16385, 19 | 'gpt-3.5-turbo': 4096, 20 | 'gpt-3.5-turbo-16k': 16385, 21 | 'gpt-3.5-turbo-instruct': 4096, 22 | 'gpt-3.5-turbo-0613': 4096, 23 | 'gpt-3.5-turbo-16k-0613': 16385, 24 | } 25 | if model in length_map: 26 | return length_map[model] 27 | else: 28 | return 4096 29 | 30 | class OpenAIWrapper(BaseAPI): 31 | 32 | is_api: bool = True 33 | 34 | def __init__(self, 35 | model: str = 'gpt-3.5-turbo-0613', 36 | retry: int = 5, 37 | wait: int = 5, 38 | key: str = None, 39 | verbose: bool = True, 40 | system_prompt: str = None, 41 | temperature: float = 0, 42 | timeout: int = 60, 43 | api_base: str = 'OFFICIAL', 44 | max_tokens: int = 1024, 45 | img_size: int = 512, 46 | img_detail: str = 'low', 47 | **kwargs): 48 | 49 | self.model = model 50 | self.cur_idx = 0 51 | self.fail_msg = 'Failed to obtain answer via API. ' 52 | self.max_tokens = max_tokens 53 | self.temperature = temperature 54 | 55 | openai_key = os.environ.get('OPENAI_API_KEY', None) if key is None else key 56 | self.openai_key = openai_key 57 | assert img_size > 0 or img_size == -1 58 | self.img_size = img_size 59 | assert img_detail in ['high', 'low'] 60 | self.img_detail = img_detail 61 | 62 | self.vision = False 63 | if model == 'gpt-4-vision-preview': 64 | self.vision = True 65 | self.timeout = timeout 66 | 67 | assert isinstance(openai_key, str) and openai_key.startswith('sk-'), f'Illegal openai_key {openai_key}. Please set the environment variable OPENAI_API_KEY to your openai key. ' 68 | super().__init__(wait=wait, retry=retry, system_prompt=system_prompt, verbose=verbose, **kwargs) 69 | 70 | if api_base in APIBASES: 71 | self.api_base = APIBASES[api_base] 72 | elif api_base.startswith('http'): 73 | self.api_base = api_base 74 | else: 75 | self.logger.error("Unknown API Base. ") 76 | sys.exit(-1) 77 | 78 | # inputs can be a lvl-2 nested list: [content1, content2, content3, ...] 
79 | # content can be a string or a list of image & text 80 | def prepare_inputs(self, inputs): 81 | input_msgs = [] 82 | if self.system_prompt is not None: 83 | input_msgs.append(dict(role='system', content=self.system_prompt)) 84 | if isinstance(inputs, str): 85 | input_msgs.append(dict(role='user', content=inputs)) 86 | return input_msgs 87 | assert isinstance(inputs, list) 88 | dict_flag = [isinstance(x, dict) for x in inputs] 89 | if np.all(dict_flag): 90 | input_msgs.extend(inputs) 91 | return input_msgs 92 | str_flag = [isinstance(x, str) for x in inputs] 93 | if np.all(str_flag): 94 | img_flag = [x.startswith('http') or osp.exists(x) for x in inputs] 95 | if np.any(img_flag): 96 | content_list = [] 97 | for fl, msg in zip(img_flag, inputs): 98 | if not fl: 99 | content_list.append(dict(type='text', text=msg)) 100 | elif msg.startswith('http'): 101 | content_list.append(dict(type='image_url', image_url={'url': msg, 'detail': self.img_detail})) 102 | elif osp.exists(msg): 103 | from PIL import Image 104 | img = Image.open(msg) 105 | b64 = encode_image_to_base64(img, target_size=self.img_size) 106 | img_struct = dict(url=f'data:image/jpeg;base64,{b64}', detail=self.img_detail) 107 | content_list.append(dict(type='image_url', image_url=img_struct)) 108 | input_msgs.append(dict(role='user', content=content_list)) 109 | return input_msgs 110 | else: 111 | roles = ['user', 'assistant'] if len(inputs) % 2 == 1 else ['assistant', 'user'] 112 | roles = roles * len(inputs) 113 | for role, msg in zip(roles, inputs): 114 | input_msgs.append(dict(role=role, content=msg)) 115 | return input_msgs 116 | raise NotImplemented("list of list prompt not implemented now. ") 117 | 118 | def generate_inner(self, inputs, **kwargs) -> str: 119 | input_msgs = self.prepare_inputs(inputs) 120 | temperature = kwargs.pop('temperature', self.temperature) 121 | max_tokens = kwargs.pop('max_tokens', self.max_tokens) 122 | 123 | context_window = GPT_context_window(self.model) 124 | max_tokens = min(max_tokens, context_window - self.get_token_len(inputs)) 125 | if 0 < max_tokens <= 100: 126 | self.logger.warning('Less than 100 tokens left, may exceed the context window with some additional meta symbols. ') 127 | if max_tokens <= 0: 128 | return 0, self.fail_msg + 'Input string longer than context window. ', 'Length Exceeded. 
' 129 | 130 | headers = {'Content-Type': 'application/json', 'Authorization': f'Bearer {self.openai_key}'} 131 | payload = dict( 132 | model=self.model, 133 | messages=input_msgs, 134 | max_tokens=max_tokens, 135 | n=1, 136 | temperature=temperature, 137 | **kwargs) 138 | response = requests.post(self.api_base, headers=headers, data=json.dumps(payload), timeout=self.timeout * 1.1) 139 | ret_code = response.status_code 140 | ret_code = 0 if (200 <= int(ret_code) < 300) else ret_code 141 | answer = self.fail_msg 142 | try: 143 | resp_struct = json.loads(response.text) 144 | answer = resp_struct['choices'][0]['message']['content'].strip() 145 | except: 146 | pass 147 | return ret_code, answer, response 148 | 149 | def get_token_len(self, inputs) -> int: 150 | import tiktoken 151 | enc = tiktoken.encoding_for_model(self.model) 152 | if isinstance(inputs, str): 153 | if inputs.startswith('http') or osp.exists(inputs): 154 | return 65 if self.img_detail == 'low' else 130 155 | else: 156 | return len(enc.encode(inputs)) 157 | elif isinstance(inputs, dict): 158 | assert 'content' in inputs 159 | return self.get_token_len(inputs['content']) 160 | assert isinstance(inputs, list) 161 | res = 0 162 | for item in inputs: 163 | res += self.get_token_len(item) 164 | return res 165 | 166 | class GPT4V(OpenAIWrapper): 167 | 168 | def generate(self, image_path, prompt, dataset=None): 169 | assert self.model == 'gpt-4-vision-preview' 170 | return super(GPT4V, self).generate([image_path, prompt]) 171 | 172 | def multi_generate(self, image_paths, prompt, dataset=None): 173 | assert self.model == 'gpt-4-vision-preview' 174 | return super(GPT4V, self).generate(image_paths + [prompt]) 175 | 176 | def interleave_generate(self, ti_list, dataset=None): 177 | assert self.model == 'gpt-4-vision-preview' 178 | return super(GPT4V, self).generate(ti_list) 179 | -------------------------------------------------------------------------------- /vlmeval/api/gpt_int.py: -------------------------------------------------------------------------------- 1 | import json 2 | import warnings 3 | import requests 4 | from ..smp import * 5 | from .gpt import GPT_context_window, OpenAIWrapper 6 | 7 | 8 | url = "http://ecs.sv.us.alles-apin.openxlab.org.cn/v1/openai/v2/text/chat" 9 | headers = { 10 | "Content-Type": "application/json" 11 | } 12 | 13 | class OpenAIWrapperInternal(OpenAIWrapper): 14 | 15 | is_api: bool = True 16 | 17 | def __init__(self, 18 | model: str = 'gpt-3.5-turbo-0613', 19 | retry: int = 5, 20 | wait: int = 3, 21 | verbose: bool = True, 22 | system_prompt: str = None, 23 | temperature: float = 0, 24 | timeout: int = 60, 25 | max_tokens: int = 1024, 26 | img_size: int = 512, 27 | img_detail: str = 'low', 28 | **kwargs): 29 | 30 | self.model = model 31 | if 'KEYS' in os.environ and osp.exists(os.environ['KEYS']): 32 | keys = load(os.environ['KEYS']) 33 | headers['alles-apin-token'] = keys.get('alles-apin-token', '') 34 | elif 'ALLES' in os.environ: 35 | headers['alles-apin-token'] = os.environ['ALLES'] 36 | self.headers = headers 37 | self.temperature = temperature 38 | self.timeout = timeout 39 | self.max_tokens = max_tokens 40 | 41 | assert img_size > 0 or img_size == -1 42 | self.img_size = img_size 43 | assert img_detail in ['high', 'low'] 44 | self.img_detail = img_detail 45 | 46 | self.vision = False 47 | if model == 'gpt-4-vision-preview': 48 | self.vision = True 49 | 50 | super(OpenAIWrapper, self).__init__( 51 | wait=wait, retry=retry, system_prompt=system_prompt, verbose=verbose, **kwargs) 52 | 53 | def 
generate_inner(self, inputs, **kwargs) -> str: 54 | input_msgs = self.prepare_inputs(inputs) 55 | 56 | temperature = kwargs.pop('temperature', self.temperature) 57 | max_tokens = kwargs.pop('max_tokens', self.max_tokens) 58 | 59 | # Held out 100 tokens as buffer 60 | context_window = GPT_context_window(self.model) 61 | max_tokens = min(max_tokens, context_window - self.get_token_len(inputs)) 62 | if 0 < max_tokens <= 100: 63 | print('Less than 100 tokens left, may exceed the context window with some additional meta symbols. ') 64 | if max_tokens <= 0: 65 | return 0, self.fail_msg + 'Input string longer than context window. ', 'Length Exceeded. ' 66 | 67 | payload = dict( 68 | model=self.model, 69 | messages=input_msgs, 70 | max_tokens=max_tokens, 71 | n=1, 72 | stop=None, 73 | timeout=self.timeout, 74 | temperature=temperature, 75 | **kwargs) 76 | 77 | response = requests.post(url, headers=headers, data=json.dumps(payload), timeout=self.timeout * 1.1) 78 | ret_code = response.status_code 79 | ret_code = 0 if (200 <= int(ret_code) < 300) else ret_code 80 | 81 | answer = self.fail_msg 82 | try: 83 | resp_struct = json.loads(response.text) 84 | assert resp_struct['msg'] == 'ok' and resp_struct['msgCode'] == '10000', resp_struct 85 | answer = resp_struct['data']['choices'][0]['message']['content'].strip() 86 | except: 87 | pass 88 | return ret_code, answer, response 89 | 90 | 91 | class GPT4V_Internal(OpenAIWrapperInternal): 92 | 93 | def generate(self, image_path, prompt, dataset=None): 94 | assert self.model == 'gpt-4-vision-preview' 95 | return super(GPT4V_Internal, self).generate([image_path, prompt]) 96 | 97 | def multi_generate(self, image_paths, prompt, dataset=None): 98 | assert self.model == 'gpt-4-vision-preview' 99 | return super(GPT4V_Internal, self).generate(image_paths + [prompt]) 100 | 101 | def interleave_generate(self, ti_list, dataset=None): 102 | assert self.model == 'gpt-4-vision-preview' 103 | return super(GPT4V_Internal, self).generate(ti_list) -------------------------------------------------------------------------------- /vlmeval/api/hf_chat_model.py: -------------------------------------------------------------------------------- 1 | import os, sys 2 | import os.path as osp 3 | import torch 4 | from ..smp import * 5 | 6 | def get_gpu_num(model_name): 7 | model_name = model_name.lower() 8 | kws = { 9 | 8: ['65b', '70b'], 10 | 4: ['30b', '33b', '35b', '40b'], 11 | 2: ['13b', '14b', '20b'], 12 | 1: ['6b', '7b', 'moss'], 13 | } 14 | for k in [8, 4, 2, 1]: 15 | for keyword in kws[k]: 16 | if keyword in model_name: 17 | return k 18 | return 8 19 | 20 | validated_llms = [ 21 | 'internlm/internlm-chat-7b', 'internlm/internlm-chat-7b-8k', 'internlm/internlm-chat-20b', 22 | 'Qwen/Qwen-7B-Chat', 'Qwen/Qwen-14B-Chat', 23 | 'THUDM/chatglm2-6b', 'THUDM/chatglm2-6b-32k', 'THUDM/chatglm3-6b', 'THUDM/chatglm3-6b-32k', 24 | 'baichuan-inc/Baichuan2-7B-Chat', 'baichuan-inc/Baichuan2-13B-Chat', 25 | 'lmsys/vicuna-7b-v1.5', 'lmsys/vicuna-13b-v1.5', 26 | 'meta-llama/Llama-2-7b-chat-hf' 27 | ] 28 | Auto_model = ['chatglm'] 29 | 30 | class HFChatModel: 31 | 32 | def _get_context_length(self, model, model_path): 33 | # By default, we use model.config.seq_length 34 | model_path = model_path.lower() 35 | if 'baichuan' in model_path: 36 | context_window = model.config.model_max_length 37 | elif 'internlm' in model_path or 'llama' in model_path: 38 | context_window = model.config.max_position_embeddings 39 | elif 'vicuna' in model_path: 40 | context_window = model.generation_config.max_length 41 | 
else: 42 | # chatglm & qwen 43 | context_window = model.config.seq_length 44 | return context_window 45 | 46 | def _get_context_length_robust(self, model, model_path): 47 | try: 48 | context_window = self._get_context_length(model, model_path) 49 | return context_window 50 | except: 51 | self.logger.critical( 52 | "Failed to extract context_window information from config / generation_config. " 53 | "Please read the above code and check if the logic works for you model path" 54 | ) 55 | raise NotImplementedError 56 | 57 | def __init__(self, 58 | model_path, 59 | system_prompt: str=None, 60 | **kwargs): 61 | 62 | self.logger = get_logger('HFChatModel') 63 | if 'vicuna' in model_path.lower(): 64 | try: 65 | from fastchat.model import get_conversation_template 66 | except: 67 | self.logger.critical("Please install fastchat first to use vicuna. ") 68 | sys.exit(-1) 69 | 70 | self.explicit_device = kwargs.pop('device', None) 71 | 72 | if self.explicit_device is None: 73 | # If CUDA_VISIBLE_DEVICES is not properly set 74 | if 'CUDA_VISIBLE_DEVICES' not in os.environ or os.environ['CUDA_VISIBLE_DEVICES'] in ['', '0,1,2,3,4,5,6,7']: 75 | num_gpu = get_gpu_num(model_path) 76 | gpu_offset = kwargs.pop('gpu_offset', 0) 77 | cuda_visible_devices = ','.join([str(i) for i in range(gpu_offset, gpu_offset+num_gpu)]) 78 | os.environ['CUDA_VISIBLE_DEVICES'] = cuda_visible_devices 79 | 80 | from transformers import AutoTokenizer, AutoModelForCausalLM, AutoModel 81 | from transformers.generation import GenerationConfig 82 | 83 | if model_path not in validated_llms: 84 | self.logger.warning(f"{model_path} not in validated LLMs, may have inference troubles. ") 85 | 86 | self.model_path = model_path 87 | if listinstr(Auto_model, model_path): 88 | LoadModel = AutoModel 89 | else: 90 | LoadModel = AutoModelForCausalLM 91 | 92 | assert osp.exists(model_path) or len(model_path.split('/')) == 2 93 | 94 | device = self.explicit_device if self.explicit_device else "auto" 95 | 96 | precision = {} 97 | if 'internlm-chat-7b' in model_path: 98 | precision = {'torch_dtype': torch.float16} 99 | elif 'internlm-chat-20b' in model_path: 100 | precision = {'torch_dtype': torch.bfloat16} 101 | 102 | self.tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True) 103 | model = LoadModel.from_pretrained(model_path, trust_remote_code=True, device_map='cpu', **precision) 104 | model = model.eval() 105 | 106 | if device != 'cpu': 107 | model = model.to(f'cuda:{device}' if isinstance(device, int) else 'cuda') 108 | try: 109 | model.generation_config = GenerationConfig.from_pretrained(model_path, trust_remote_code=True, device_map=device) 110 | except: 111 | pass 112 | 113 | torch.cuda.empty_cache() 114 | self.model = model 115 | self.context_length = self._get_context_length_robust(model=model, model_path=model_path) 116 | self.answer_buffer = 192 117 | self.system_prompt = system_prompt 118 | for k, v in kwargs.items(): 119 | self.logger.info(f'Following args are passed and will be used as generation hyper-paras (If not set specifically), {k}: {v}. 
') 120 | self.kwargs = kwargs 121 | 122 | def generate_str(self, input, **kwargs): 123 | if 'baichuan' in self.model_path.lower(): 124 | messages = [] 125 | messages.append({"role": "user", "content": input}) 126 | resp = self.model.chat(self.tokenizer, messages, **kwargs) 127 | elif 'vicuna' in self.model_path.lower(): 128 | from fastchat.model import get_conversation_template 129 | conv = get_conversation_template('vicuna') 130 | conv.append_message(conv.roles[0], input) 131 | conv.append_message(conv.roles[1], None) 132 | prompt = conv.get_prompt() 133 | inputs = self.tokenizer([prompt], return_tensors="pt") 134 | if torch.cuda.is_available(): 135 | for k in inputs: 136 | inputs[k] = inputs[k].cuda() 137 | 138 | params = dict(do_sample=True, temperature=0.7, repetition_penalty=1.0, max_new_tokens=512) 139 | params.update(self.kwargs) 140 | params.update(kwargs) 141 | outputs = self.model.generate(**inputs, **params) 142 | resp = self.tokenizer.decode(outputs[0][len(inputs["input_ids"][0]) :], skip_special_tokens=True, spaces_between_special_tokens=False) 143 | 144 | else: 145 | params = self.kwargs 146 | params.update(kwargs) 147 | resp, _ = self.model.chat(self.tokenizer, input, history=[], **params) 148 | 149 | return resp 150 | 151 | def length_ok(self, inputs): 152 | tot = len(self.tokenizer.encode(self.system_prompt)) if self.system_prompt is not None else 0 153 | for s in inputs: 154 | tot += len(self.tokenizer.encode(s)) 155 | return tot + self.answer_buffer < self.context_length 156 | 157 | def generate_list(self, full_inputs, offset=0, **kwargs): 158 | assert isinstance(full_inputs, list) 159 | 160 | inputs = full_inputs[offset:] 161 | if not self.length_ok(inputs): 162 | return self.generate_list(full_inputs, offset + 1, **kwargs) # drop the earliest turn and retry if the context would overflow 163 | 164 | model_path = self.model_path.lower() 165 | 166 | if sum([x in model_path for x in ['baichuan']]): 167 | input_msgs = [] 168 | if self.system_prompt is not None: 169 | input_msgs.append(dict(role='user', content=self.system_prompt)) 170 | if len(inputs): 171 | assert isinstance(inputs, list) and isinstance(inputs[0], str) 172 | roles = ['user', 'assistant'] if len(inputs) % 2 == 1 else ['assistant', 'user'] 173 | roles = roles * len(inputs) 174 | for role, msg in zip(roles, inputs): 175 | input_msgs.append(dict(role=role, content=msg)) 176 | response = self.model.chat(self.tokenizer, input_msgs) 177 | elif sum([x in model_path for x in ['vicuna']]): 178 | from fastchat.model import get_conversation_template 179 | conv = get_conversation_template('vicuna') 180 | assert isinstance(inputs, list) and isinstance(inputs[0], str) 181 | if len(inputs) % 2 == 1: 182 | if self.system_prompt is not None: 183 | conv.append_message(conv.roles[0], self.system_prompt) 184 | for i in range(len(inputs)//2): 185 | conv.append_message(conv.roles[0], inputs[2 * i]) 186 | conv.append_message(conv.roles[1], inputs[2 * i + 1]) 187 | else: 188 | assert self.system_prompt is not None 189 | conv.append_message(conv.roles[0], self.system_prompt) 190 | conv.append_message(conv.roles[1], inputs[0]) 191 | for i in range(len(inputs) // 2 - 1): 192 | conv.append_message(conv.roles[0], inputs[2 * i + 1]) 193 | conv.append_message(conv.roles[1], inputs[2 * i + 2]) 194 | conv.append_message(conv.roles[0], inputs[-1]) 195 | conv.append_message(conv.roles[1], None) 196 | prompt = conv.get_prompt() 197 | inputs = self.tokenizer([prompt], return_tensors="pt") 198 | if torch.cuda.is_available(): 199 | for k in inputs: 200 | inputs[k] = inputs[k].cuda() 201 | 202 | params = dict(do_sample=True,
temperature=0.7, repetition_penalty=1.0, max_new_tokens=512) 203 | params.update(self.kwargs) 204 | params.update(kwargs) 205 | 206 | outputs = self.model.generate(**inputs, **params) 207 | response = self.tokenizer.decode(outputs[0][len(inputs["input_ids"][0]) :], skip_special_tokens=True, spaces_between_special_tokens=False) 208 | response = response.lstrip('\n') 209 | else: 210 | # The default option, support internlm, chatglm, qwen 211 | history, msg = [], None 212 | if len(inputs) % 2 == 1: 213 | if self.system_prompt is not None: 214 | history = [(self.system_prompt, '')] 215 | for i in range(len(inputs)//2): 216 | history.append((inputs[2 * i], inputs[2 * i + 1])) 217 | else: 218 | assert self.system_prompt is not None 219 | history = [(self.system_prompt, inputs[0])] 220 | for i in range(len(inputs) // 2 - 1): 221 | history.append((inputs[2 * i + 1], inputs[2 * i + 2])) 222 | msg = inputs[-1] 223 | 224 | params = self.kwargs 225 | params.update(kwargs) 226 | response, _ = self.model.chat(self.tokenizer, msg, history=history, **params) 227 | 228 | return response, offset 229 | 230 | def generate(self, inputs, **kwargs): 231 | if isinstance(inputs, str): 232 | return self.generate_str(inputs, **kwargs) 233 | elif isinstance(inputs, list): 234 | return self.generate_list(inputs, **kwargs) -------------------------------------------------------------------------------- /vlmeval/api/qwen_vl_api.py: -------------------------------------------------------------------------------- 1 | from vlmeval.smp import * 2 | from vlmeval.api.base import BaseAPI 3 | 4 | class QwenVLWrapper(BaseAPI): 5 | 6 | is_api: bool = True 7 | 8 | def __init__(self, 9 | model: str = 'qwen-vl-plus', 10 | retry: int = 5, 11 | wait: int = 5, 12 | key: str = None, 13 | verbose: bool = True, 14 | temperature: float = 0.0, 15 | system_prompt: str = None, 16 | max_tokens: int = 1024, 17 | proxy: str = None, 18 | **kwargs): 19 | 20 | assert model in ['qwen-vl-plus', 'qwen-vl-max'] 21 | self.model = model 22 | import dashscope 23 | self.fail_msg = 'Failed to obtain answer via API. 
' 24 | self.max_tokens = max_tokens 25 | self.temperature = temperature 26 | if key is None: 27 | key = os.environ.get('DASHSCOPE_API_KEY', None) 28 | assert key is not None, "Please set the API Key (obtain it here: https://help.aliyun.com/zh/dashscope/developer-reference/vl-plus-quick-start)" 29 | dashscope.api_key = key 30 | if proxy is not None: 31 | proxy_set(proxy) 32 | super().__init__(wait=wait, retry=retry, system_prompt=system_prompt, verbose=verbose, **kwargs) 33 | 34 | @staticmethod 35 | def build_msgs(msgs_raw, system_prompt=None): 36 | msgs = cp.deepcopy(msgs_raw) 37 | ret = [] 38 | if system_prompt is not None: 39 | content = list(dict(text=system_prompt)) 40 | ret.append(dict(role='system', content=content)) 41 | content = [] 42 | for i, msg in enumerate(msgs): 43 | if osp.exists(msg): 44 | content.append(dict(image='file://' + msg)) 45 | elif msg.startswith('http'): 46 | content.append(dict(image=msg)) 47 | else: 48 | content.append(dict(text=msg)) 49 | ret.append(dict(role='user', content=content)) 50 | return ret 51 | 52 | def generate_inner(self, inputs, **kwargs) -> str: 53 | from dashscope import MultiModalConversation 54 | assert isinstance(inputs, str) or isinstance(inputs, list) 55 | pure_text = True 56 | if isinstance(inputs, list): 57 | for pth in inputs: 58 | if osp.exists(pth) or pth.startswith('http'): 59 | pure_text = False 60 | assert not pure_text 61 | messages = self.build_msgs(msgs_raw=inputs, system_prompt=self.system_prompt) 62 | gen_config = dict(max_output_tokens=self.max_tokens, temperature=self.temperature) 63 | gen_config.update(self.kwargs) 64 | try: 65 | response = MultiModalConversation.call(model=self.model, messages=messages) 66 | if self.verbose: 67 | print(response) 68 | answer = response.output.choices[0]['message']['content'][0]['text'] 69 | return 0, answer, 'Succeeded! ' 70 | except Exception as err: 71 | if self.verbose: 72 | self.logger.error(err) 73 | self.logger.error(f"The input messages are {inputs}.") 74 | 75 | return -1, '', '' 76 | 77 | class QwenVLAPI(QwenVLWrapper): 78 | 79 | def generate(self, image_path, prompt, dataset=None): 80 | return super(QwenVLAPI, self).generate([image_path, prompt]) 81 | 82 | def multi_generate(self, image_paths, prompt, dataset=None): 83 | return super(QwenVLAPI, self).generate(image_paths + [prompt]) 84 | 85 | def interleave_generate(self, ti_list, dataset=None): 86 | return super(QwenVLAPI, self).generate(ti_list) 87 | -------------------------------------------------------------------------------- /vlmeval/config.py: -------------------------------------------------------------------------------- 1 | from .vlm import * 2 | from .api import GPT4V, GPT4V_Internal, GeminiProVision, QwenVLAPI 3 | from functools import partial 4 | 5 | PandaGPT_ROOT = None 6 | MiniGPT4_ROOT = None 7 | TransCore_ROOT = None 8 | Yi_ROOT = None 9 | 10 | api_models = { 11 | 'GPT4V': partial(GPT4V, model='gpt-4-vision-preview', temperature=0, img_size=512, img_detail='low', retry=10), 12 | 'GPT4V_INT': partial(GPT4V_Internal, model='gpt-4-vision-preview', temperature=0, img_size=512, img_detail='low', retry=10), 13 | 'GPT4V_SHORT': partial( 14 | GPT4V, model='gpt-4-vision-preview', temperature=0, img_size=512, img_detail='low', retry=10, 15 | system_prompt="Please responde to the following question / request in a short reply. 
"), 16 | 'GPT4V_SHORT_INT': partial( 17 | GPT4V_Internal, model='gpt-4-vision-preview', temperature=0, img_size=512, img_detail='low', retry=10, 18 | system_prompt="Please responde to the following question / request in a short reply. "), 19 | 'GeminiProVision': partial(GeminiProVision, temperature=0, retry=10), 20 | 'QwenVLPlus': partial(QwenVLAPI, model='qwen-vl-plus', temperature=0, retry=10), 21 | 'QwenVLMax': partial(QwenVLAPI, model='qwen-vl-max', temperature=0, retry=10), 22 | } 23 | 24 | models = { 25 | 'hpt-air-mmmu': partial(HPT), 26 | 'hpt-air-mmbench': partial(HPT, vis_scale=392, is_crop=False), 27 | 'hpt-air-seed': partial(HPT, is_crop=False), 28 | 'hpt-air-demo': partial(HPT, vis_scale=392, is_crop=False), 29 | 'hpt-air-demo-local': partial(HPT, vis_scale=392, is_crop=False, global_model_path='../HPT_AIR_HF/'), 30 | 'hpt-air-1-5': partial(HPT1_5, global_model_path='HyperGAI/HPT1_5-Air-Llama-3-8B-Instruct-multimoda', vis_scale=448, prompt_template='llama3_chat'), 31 | 'hpt-edge-1-5': partial(HPT1_5, global_model_path='HyperGAI/HPT1_5-Edge', vis_scale=490, prompt_template='phi3_chat'), 32 | } 33 | 34 | supported_VLM = {} 35 | for model_set in [models, api_models]: 36 | supported_VLM.update(model_set) -------------------------------------------------------------------------------- /vlmeval/evaluate/__init__.py: -------------------------------------------------------------------------------- 1 | from .yes_or_no import default_rating, MME_rating, YOrN_eval 2 | from .mmvet_eval import MMVet_eval 3 | from .multiple_choice import multiple_choice_eval 4 | from .coco_eval import COCO_eval 5 | from .vqa_eval import VQAEval 6 | from .mathvista_eval import MathVista_eval 7 | from .llavabench import LLaVABench_eval 8 | from .misc import build_judge 9 | from .mmmu_eval import mmmu_eval_func -------------------------------------------------------------------------------- /vlmeval/evaluate/coco_eval.py: -------------------------------------------------------------------------------- 1 | from vlmeval.smp import * 2 | from pycocoevalcap.bleu.bleu import Bleu 3 | from pycocoevalcap.rouge.rouge import Rouge 4 | from pycocoevalcap.cider.cider import Cider 5 | 6 | 7 | class COCO_Caption_Scorer(): 8 | def __init__(self, ref, gt): 9 | self.ref = ref 10 | self.gt = gt 11 | print('setting up scorers...') 12 | self.scorers = [ 13 | (Bleu(4), ["Bleu_1", "Bleu_2", "Bleu_3", "Bleu_4"]), 14 | # (Meteor(), "METEOR"), # need java version 11.0.16+ 15 | (Rouge(), "ROUGE_L"), 16 | (Cider(), "CIDEr"), 17 | # (Spice(), "SPICE"), # need java version 11.0.16+ 18 | ] 19 | 20 | def compute_scores(self): 21 | total_scores = {} 22 | for scorer, method in self.scorers: 23 | print('computing %s score...' 
% (scorer.method())) 24 | score, scores = scorer.compute_score(self.gt, self.ref) 25 | if type(method) == list: 26 | for sc, scs, m in zip(score, scores, method): 27 | print("%s: %0.3f" % (m, sc * 100)) 28 | total_scores["Bleu"] = [x * 100 for x in score] 29 | else: 30 | print("%s: %0.3f" % (method, score * 100)) 31 | total_scores[method] = score * 100 32 | 33 | print('*****DONE*****') 34 | for key, value in total_scores.items(): 35 | print('{}:{}'.format(key, value)) 36 | return total_scores 37 | 38 | def COCO_eval(eval_file, nproc=4, verbose=False): 39 | logger = get_logger('Evaluation') 40 | 41 | data = load(eval_file) 42 | 43 | lt = len(data) 44 | lines = [data.iloc[i] for i in range(lt)] 45 | ref = {} 46 | gt = {} 47 | for i,(line) in enumerate(lines): 48 | ref[str(i)] = [str(line['prediction'])] 49 | gt[str(i)] = eval(line['answer']) 50 | 51 | scorer = COCO_Caption_Scorer(ref,gt) 52 | coco_caption_score_dict = scorer.compute_scores() 53 | 54 | score_pth = eval_file.replace('.xlsx','_score.json') 55 | dump(coco_caption_score_dict, score_pth) 56 | logger.info(f'COCO_eval successfully finished evaluating {eval_file}, results saved in {score_pth}') 57 | logger.info(f'Score: ') 58 | for key, value in coco_caption_score_dict.items(): 59 | logger.info('{}:{}'.format(key, value)) 60 | 61 | def parse_args(): 62 | parser = argparse.ArgumentParser(description="Inference LLM Answers. ") 63 | parser.add_argument("--data", type=str, help="The question set for inference, in excel / tsv / json format. ") 64 | parser.add_argument("--nproc", type=int, default=4) 65 | parser.add_argument("--verbose", action='store_true') 66 | args = parser.parse_args() 67 | return args 68 | 69 | if __name__ == '__main__': 70 | args = parse_args() 71 | COCO_eval(eval_file=args.data, nproc=args.nproc, verbose=args.verbose) 72 | 73 | -------------------------------------------------------------------------------- /vlmeval/evaluate/llavabench.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import numpy as np 3 | import pandas as pd 4 | import os.path as osp 5 | from vlmeval.evaluate.misc import build_judge 6 | from vlmeval.smp import get_logger, load, dump, defaultdict 7 | from vlmeval.utils import track_progress_rich 8 | 9 | rule_dict = { 10 | "llava_bench_conv": {"role": "Assistant", "prompt": "We would like to request your feedback on the performance of two AI assistants in response to the user question displayed above. The user asks the question on observing an image. For your reference, the visual content in the image is represented with a few sentences describing the image. \nPlease rate the helpfulness, relevance, accuracy, level of details of their responses. Each assistant receives an overall score on a scale of 1 to 10, where a higher score indicates better overall performance.\nPlease first output a single line containing only two values indicating the scores for Assistant 1 and 2, respectively. The two scores are separated by a space.\nIn the subsequent line, please provide a comprehensive explanation of your evaluation, avoiding any potential bias and ensuring that the order in which the responses were presented does not affect your judgment."}, 11 | "llava_bench_detail": {"role": "Assistant", "prompt": "We would like to request your feedback on the performance of two AI assistants in response to the user question displayed above. The user asks the question on observing an image. 
For your reference, the visual content in the image is represented with a few sentences describing the image. \nPlease rate the helpfulness, relevance, accuracy, level of details of their responses. Each assistant receives an overall score on a scale of 1 to 10, where a higher score indicates better overall performance.\nPlease first output a single line containing only two values indicating the scores for Assistant 1 and 2, respectively. The two scores are separated by a space.\nIn the subsequent line, please provide a comprehensive explanation of your evaluation, avoiding any potential bias and ensuring that the order in which the responses were presented does not affect your judgment."}, 12 | "llava_bench_complex": {"role": "Assistant", "prompt": "We would like to request your feedback on the performance of two AI assistants in response to the user question displayed above. The user asks the question on observing an image. For your reference, the visual content in the image is represented with a few sentences describing the image. \nPlease rate the helpfulness, relevance, accuracy, level of details of their responses. Each assistant receives an overall score on a scale of 1 to 10, where a higher score indicates better overall performance.\nPlease first output a single line containing only two values indicating the scores for Assistant 1 and 2, respectively. The two scores are separated by a space.\nIn the subsequent line, please provide a comprehensive explanation of your evaluation, avoiding any potential bias and ensuring that the order in which the responses were presented does not affect your judgment."} 13 | } 14 | 15 | def get_eval(judge, content): 16 | return judge.generate(content) 17 | 18 | def parse_score(review): 19 | logger = get_logger('Evaluation') 20 | try: 21 | score_pair = review.split('\n')[0] 22 | score_pair = score_pair.replace(',', ' ') 23 | sp = score_pair.split(' ') 24 | if len(sp) == 2: 25 | return [float(sp[0]), float(sp[1])] 26 | else: 27 | logger.error('error', review) 28 | return [-1, -1] 29 | except Exception as e: 30 | logger.error(e, 'error', review) 31 | return [-1, -1] 32 | 33 | def build_prompt(line): 34 | cap_str = line['caption'] 35 | question = line['question'] 36 | ans1 = line['gpt4_ans'] 37 | ans2 = line['prediction'] 38 | category = 'llava_bench_' + line['category'] 39 | rule = rule_dict[category] 40 | role, prompt = rule['role'], rule['prompt'] 41 | 42 | content = (f'[Context]\n{cap_str}\n\n' 43 | f'[Question]\n{question}\n\n' 44 | f'[{role} 1]\n{ans1}\n\n[End of {role} 1]\n\n' 45 | f'[{role} 2]\n{ans2}\n\n[End of {role} 2]\n\n' 46 | f'[System]\n{prompt}\n\n') 47 | return content 48 | 49 | def LLaVABench_atomeval(model, prompt): 50 | review = get_eval(model, prompt) 51 | scores = parse_score(review) 52 | return scores 53 | 54 | def LLaVABench_score(data): 55 | cates = ['overall'] + list(set(data['category'])) 56 | ret = defaultdict(list) 57 | 58 | 59 | for c in cates: 60 | ret['split'].append(c) 61 | sub = data[data['category'] == c] if c != 'overall' else data 62 | ret['Relative Score (main)'].append(np.mean(sub['score']) / np.mean(sub['gpt4_score']) * 100) 63 | ret['VLM Score'].append(np.mean(sub['score']) * 10) 64 | ret['GPT4 Score'].append(np.mean(sub['gpt4_score']) * 10) 65 | return pd.DataFrame(ret) 66 | 67 | def LLaVABench_eval(eval_file, model='gpt-4-0314', nproc=4, verbose=False): 68 | suffix = '.' 
+ eval_file.split('.')[-1] 69 | record_file = eval_file.replace(suffix, '_openai_result' + suffix) 70 | score_file = eval_file.replace(suffix, '_score.csv') 71 | 72 | if not osp.exists(record_file): 73 | data = load(eval_file) 74 | lines = [data.iloc[i] for i in range(len(data))] 75 | model = build_judge( 76 | model, temperature=0.2, retry=10, verbose=verbose, 77 | system_prompt='You are a helpful and precise assistant for checking the quality of the answer.') 78 | prompts = [build_prompt(line) for line in lines] 79 | tups = [(model, prompt) for prompt in prompts] 80 | scores = track_progress_rich(LLaVABench_atomeval, tups, nproc=nproc, chunksize=nproc) 81 | data['gpt4_score'] = [x[0] for x in scores] 82 | data['score'] = [x[1] for x in scores] 83 | dump(data, record_file) 84 | 85 | data = load(record_file) 86 | ret = LLaVABench_score(data).round(1) 87 | print(ret) 88 | dump(ret, score_file) 89 | return ret 90 | 91 | def parse_args(): 92 | parser = argparse.ArgumentParser(description="LLaVABench Evaluation. ") 93 | parser.add_argument("data", type=str, help="The question set for inference, in excel / tsv / json format. ") 94 | parser.add_argument( 95 | "--model", type=str, help="The LLM (GPT) used for inference. ", default="gpt-4-turbo", 96 | choices=['gpt-4-0613', 'gpt-4-turbo', 'chatgpt-1106', 'chatgpt-0613', 'gpt-4-0314']) 97 | parser.add_argument("--nproc", type=int, default=4) 98 | parser.add_argument("--verbose", action='store_true') 99 | args = parser.parse_args() 100 | return args 101 | 102 | if __name__ == '__main__': 103 | args = parse_args() 104 | LLaVABench_eval(eval_file=args.data, model=args.model, nproc=args.nproc, verbose=args.verbose) 105 | -------------------------------------------------------------------------------- /vlmeval/evaluate/mathvista_eval.py: -------------------------------------------------------------------------------- 1 | from vlmeval.evaluate.misc import build_judge 2 | from vlmeval.smp import * 3 | from vlmeval.utils import track_progress_rich 4 | from vlmeval.utils.matching_util import can_infer 5 | 6 | def get_gpt4_ICE(): 7 | example_1 = """ 8 | Hint: Please answer the question requiring an integer answer and provide the final value, e.g., 1, 2, 3, at the end.\n 9 | Question: Which number is missing?\n 10 | Model response: The number missing in the sequence is 14.\n 11 | Extracted answer: 14 12 | """ 13 | 14 | example_2 = """ 15 | Hint: Please answer the question requiring a floating-point number with one decimal place and provide the final value, e.g., 1.2, 1.3, 1.4, at the end.\n 16 | Question: What is the fraction of females facing the camera?\n 17 | Model response: The fraction of females facing the camera is 0.6, which means that six out of ten females in the group are facing the camera.\n 18 | Extracted answer: 0.6 19 | """ 20 | 21 | example_3 = """ 22 | Hint: Please answer the question requiring a floating-point number with two decimal places and provide the final value, e.g., 1.23, 1.34, 1.45, at the end.\n 23 | Question: How much money does Luca need to buy a sour apple candy and a butter-scotch candy? 
(Unit: $)\n 24 | Model response: Luca needs $1.45 to buy a sour apple candy and a butterscotch candy.\n 25 | Extracted answer: 1.45 26 | """ 27 | 28 | example_4 = """ 29 | Hint: Please answer the question requiring a Python list as an answer and provide the final list, e.g., [1, 2, 3], [1.2, 1.3, 1.4], at the end.\n 30 | Question: Between which two years does the line graph saw its maximum peak?\n 31 | Model response: The line graph saw its maximum peak between 2007 and 2008.\n 32 | Extracted answer: [2007, 2008] 33 | """ 34 | 35 | example_5 = """ 36 | Hint: Please answer the question and provide the correct option letter, e.g., A, B, C, D, at the end.\n 37 | Question: What fraction of the shape is blue?\n 38 | Choices: (A) 3/11 (B) 8/11 (C) 6/11 (D) 3/5\n 39 | Model response: The correct answer is (B) 8/11.\n 40 | Extracted answer: B 41 | """ 42 | return [example_1,example_2,example_3,example_4,example_5] 43 | 44 | 45 | def build_mathvista_gpt4_prompt(line): 46 | task_description = """ Please read the following example. Then extract the answer from the model response and type it at the end of the prompt.\n""" 47 | question = line['question'] 48 | prediction = str(line['prediction']) 49 | prompt = task_description 50 | examples = get_gpt4_ICE() 51 | for example in examples: 52 | prompt += example + '\n' 53 | prompt += question + '\n' 54 | prompt += 'Model respone: ' + prediction 55 | prompt += 'Extracted answer:' 56 | return prompt 57 | 58 | def list_to_dict(lst): 59 | return {chr(65 + i): val for i, val in enumerate(lst)} 60 | 61 | def post_check(line, prefetch=False): 62 | res = None 63 | ans = line['answer'] 64 | response = line['prediction'] if prefetch else line['res'] 65 | try: 66 | if line['question_type'] == 'multi_choice': 67 | ans = line['answer_option'] 68 | choices = list_to_dict(eval(line['choices'])) 69 | res = can_infer(response, choices) 70 | if prefetch: 71 | return res 72 | else: 73 | if line['answer_type'] == 'integer': 74 | res = int(response) 75 | ans = int(line['answer']) 76 | elif line['answer_type'] == 'float': 77 | res = float(response) 78 | ans = float(line['answer']) 79 | else: 80 | res = str(res) 81 | ans = str(ans) 82 | except ValueError: 83 | pass 84 | 85 | if res == ans: 86 | return res 87 | else: 88 | return False 89 | 90 | def MathVista_auxeval(model, line): 91 | prompt = build_mathvista_gpt4_prompt(line) 92 | log = '' 93 | retry = 5 94 | if post_check(line, prefetch=True): 95 | res = post_check(line, prefetch=True) 96 | return dict(log='Prefetch succeed', res=res) 97 | for i in range(retry): 98 | prediction = line['prediction'] 99 | res = model.generate(prompt, temperature=i * 0.5) 100 | if res is None: 101 | log += f'Try {i}: output is {prediction}, failed to parse.\n' 102 | else: 103 | log += 'Succeed' 104 | return dict(log=log, res= res) 105 | log += 'All 5 retries failed.\n' 106 | return dict(log=log, res='') 107 | 108 | def MathVista_acc(result_file): 109 | data = load(result_file) 110 | tot = defaultdict(lambda: 0) 111 | fetch = defaultdict(lambda: 0) 112 | hit = defaultdict(lambda: 0) 113 | lt = len(data) 114 | skill_list = [] 115 | for i in range(lt): 116 | item = data.iloc[i] 117 | index = item['index'] 118 | cate = item['task'] 119 | tot['Overall'] += 1 120 | try: 121 | skills = eval(item['skills']) 122 | except SyntaxError: 123 | skills = [item['skills']] 124 | for skill in skills: 125 | if skill not in skill_list: 126 | skill_list.append(skill) 127 | tot[skill] += 1 128 | tot[cate] += 1 129 | if item['log'] == 'Prefetch succeed': 130 | 
fetch['Overall'] += 1 131 | fetch[cate] += 1 132 | for skill in skills: 133 | fetch[skill] += 1 134 | if post_check(item, prefetch=False): 135 | hit['Overall'] += 1 136 | hit[cate] += 1 137 | for skill in skills: 138 | hit[skill] += 1 139 | 140 | res = defaultdict(list) 141 | for k in tot.keys(): 142 | res['Task&Skill'].append(k) 143 | res['tot'].append(tot[k]) 144 | res['prefetch'].append(fetch[k]) 145 | res['hit'].append(hit[k]) 146 | res['prefetch_rate'].append(fetch[k] / tot[k] * 100) 147 | res['acc'].append(hit[k] / tot[k] * 100) 148 | res = pd.DataFrame(res) 149 | return res 150 | 151 | def MathVista_eval(eval_file, model='gpt-4-turbo', nproc=4, verbose=False): 152 | logger = get_logger('Evaluation') 153 | 154 | suffix = eval_file.split('.')[-1] 155 | storage = eval_file.replace(f'.{suffix}', f'_{model}.xlsx') 156 | tmp_file = eval_file.replace(f'.{suffix}', f'_{model}.pkl') 157 | if osp.exists(storage): 158 | logger.warning(f"GPT scoring file {storage} already exists, will reuse it in MathVista_eval. ") 159 | else: 160 | data = load(eval_file) 161 | gpt_version = model 162 | model = build_judge(gpt_version, verbose=verbose, max_tokens=128, retry=10) 163 | 164 | lt = len(data) 165 | lines = [data.iloc[i] for i in range(lt)] 166 | tups = [(model, line) for line in lines] 167 | indices = [line['index'] for line in lines] 168 | 169 | ans = {} 170 | if osp.exists(tmp_file): 171 | ans = load(tmp_file) 172 | tups = [x for x, i in zip(tups, indices) if i not in ans] 173 | indices = [i for i in indices if i not in ans] 174 | 175 | if len(indices): 176 | new_results = track_progress_rich( 177 | MathVista_auxeval, tups, nproc=nproc, chunksize=nproc, 178 | keys=indices, save=tmp_file) 179 | ans = load(tmp_file) 180 | for k, v in zip(indices, new_results): 181 | assert k in ans 182 | assert ans[k]['log'] == v['log'] and ans[k]['res'] == v['res'] 183 | 184 | log_map, res_map = {}, {} 185 | all_inds = [line['index'] for line in lines] 186 | for k in all_inds: 187 | log_map[k] = ans[k]['log'] 188 | res_map[k] = ans[k]['res'] 189 | data['res'] = [res_map[idx] for idx in data['index']] 190 | data['log'] = [log_map[idx] for idx in data['index']] 191 | dump(data, storage) 192 | 193 | score = MathVista_acc(storage) 194 | score_pth = storage.replace('.xlsx','_score.csv') 195 | 196 | dump(score,score_pth) 197 | logger.info(f'MathVista_eval successfully finished evaluating {eval_file}, results saved in {score_pth}') 198 | logger.info(f'Score: ') 199 | logger.info(score) 200 | 201 | def parse_args(): 202 | parser = argparse.ArgumentParser(description="Inference LLM Answers. ") 203 | parser.add_argument("data", type=str, help="The question set for inference, in excel / tsv / json format. ") 204 | parser.add_argument( 205 | "--model", 206 | type=str, 207 | help="The LLM (GPT) used for inference. 
", 208 | default="gpt-4-turbo", 209 | choices=['gpt-4-0613', 'gpt-4-turbo', 'chatgpt-1106', 'chatgpt-0613']) 210 | parser.add_argument("--nproc", type=int, default=4) 211 | parser.add_argument("--verbose", action='store_true') 212 | args = parser.parse_args() 213 | return args 214 | 215 | if __name__ == '__main__': 216 | args = parse_args() 217 | MathVista_eval(eval_file=args.data, model=args.model, nproc=args.nproc, verbose=args.verbose) 218 | 219 | -------------------------------------------------------------------------------- /vlmeval/evaluate/misc.py: -------------------------------------------------------------------------------- 1 | import os 2 | from vlmeval.api import OpenAIWrapper, OpenAIWrapperInternal 3 | 4 | INTERNAL = os.environ.get('INTERNAL', 0) 5 | 6 | def build_judge(version, **kwargs): 7 | model_map = { 8 | 'gpt-4-turbo': 'gpt-4-1106-preview', 9 | 'gpt-4-0613': 'gpt-4-0613', 10 | 'gpt-4-0314': 'gpt-4-0314', 11 | 'chatgpt-1106': 'gpt-3.5-turbo-1106', 12 | 'chatgpt-0613': 'gpt-3.5-turbo-0613' 13 | } 14 | model_version = model_map[version] 15 | if INTERNAL: 16 | model = OpenAIWrapperInternal(model_version, **kwargs) 17 | else: 18 | model = OpenAIWrapper(model_version, **kwargs) 19 | return model -------------------------------------------------------------------------------- /vlmeval/evaluate/mmmu_eval/README.md: -------------------------------------------------------------------------------- 1 | # Evaluation Guidelines 2 | We provide detailed instructions for evaluation. 3 | To execute our evaluation script, please ensure that the structure of your model outputs is the same as ours. 4 | 5 | We provide two options: 6 | 1. Evaluation only: you can parse the response on your own and simply provide one file with all the final predictions. 7 | 2. Parse and evaluation: you can leave all the responses to us with the output formats shown below. 8 | 9 | ## Evaluation Only 10 | If you want to use your own parsing logic and *only provide the final answer*, you can use `main_eval_only.py`. 11 | 12 | You can provide all the outputs in *one file* in the following format: 13 | 14 | ``` 15 | { 16 | "validation_Accounting_1": "D", # strictly "A", "B", "C", "D" for multi-choice question 17 | "validation_Architecture_and_Engineering_14": "0.0", # any string response for open question. 18 | ... 19 | } 20 | ``` 21 | Then run eval_only with: 22 | ``` 23 | python main_eval_only.py --output_path ./example_outputs/llava1.5_13b/total_val_output.json 24 | ``` 25 | 26 | Please refer to [example output](https://github.com/MMMU-Benchmark/MMMU/blob/main/eval/example_outputs/llava1.5_13b/total_val_output.json) for a detailed prediction file form. 27 | 28 | 29 | ## Parse and Evaluation 30 | You can also provide response and run the `main_parse_and_eval.py` to use our answer parsing processing and evaluation pipeline as follows: 31 | 32 | ### Output folder structure 33 | 34 | ``` 35 | └── model_name 36 | ├── category_name (e.g., Accounting) 37 | │ ├── output.json 38 | └── category_name (e.g., Electronics) 39 | ├── output.json 40 | ... 41 | ``` 42 | 43 | ### Output file 44 | Each `output.json`` has a list of dict containing instances for evaluation (). 
45 | ``` 46 | [ 47 | { 48 | "id": "validation_Electronics_28", 49 | "question_type": "multiple-choice", 50 | "answer": "A", # given answer 51 | "all_choices": [ # create using `get_multi_choice_info` in 52 | "A", 53 | "B", 54 | "C", 55 | "D" 56 | ], 57 | "index2ans": { # create using `get_multi_choice_info` in 58 | "A": "75 + 13.3 cos(250t - 57.7°)V", 59 | "B": "75 + 23.3 cos(250t - 57.7°)V", 60 | "C": "45 + 3.3 cos(250t - 57.7°)V", 61 | "D": "95 + 13.3 cos(250t - 57.7°)V" 62 | }, 63 | "response": "B" # model response 64 | }, 65 | { 66 | "id": "validation_Electronics_29", 67 | "question_type": "short-answer", 68 | "answer": "30", # given answer 69 | "response": "36 watts" # model response 70 | }, 71 | ... 72 | ] 73 | ``` 74 | 75 | ### Evaluation 76 | ``` 77 | python main_parse_and_eval.py --path ./example_outputs/llava1.5_13b --subject ALL # all subjects 78 | 79 | # OR you can specify one subject for the evaluation 80 | 81 | python main_parse_and_eval.py --path ./example_outputs/llava1.5_13b --subject elec # short name for Electronics. use --help for all short names 82 | 83 | ``` 84 | 85 | `main_parse_and_eval.py` will generate `parsed_output.json` and `result.json` in each category subfolder, alongside the corresponding `output.json`: 86 | 87 | ``` 88 | ├── Accounting 89 | │ ├── output.json 90 | │ ├── parsed_output.json 91 | │ └── result.json 92 | └── Electronics 93 | ├── output.json 94 | ├── parsed_output.json 95 | └── result.json 96 | ... 97 | ``` 98 | 99 | ### Print Results 100 | You can print the results locally if you want (run `pip install tabulate` first if you haven't): 101 | ``` 102 | python print_results.py --path ./example_outputs/llava1.5_13b 103 | # Results may differ slightly due to the random selection used for failed responses 104 | ``` 105 | 106 | 107 | 108 | ### Run LLaVA 109 | If you want to reproduce the results of some of the models, please check `run_llava.py` as an example. 110 | 111 | After setting up the environment following the [LLaVA official repo](https://github.com/haotian-liu/LLaVA) and installing the `datasets` package from Hugging Face, you can run LLaVA via the following command: 112 | 113 | ``` 114 | CUDA_VISIBLE_DEVICES=0 nohup python run_llava.py \ 115 | --output_path example_outputs/llava1.5_13b_val.json \ 116 | --model_path liuhaotian/llava-v1.5-13b \ 117 | --config_path configs/llava1.5.yaml 118 | ``` 119 | 120 | Then you can evaluate the results via the first ("Evaluation Only") pipeline described above. 121 | -------------------------------------------------------------------------------- /vlmeval/evaluate/mmmu_eval/__init__.py: -------------------------------------------------------------------------------- 1 | from .mmmu_eval_script import mmmu_eval_func 2 | -------------------------------------------------------------------------------- /vlmeval/evaluate/mmmu_eval/configs/llava1.5.yaml: -------------------------------------------------------------------------------- 1 | task_instructions: 2 | - "" 3 | multi_choice_example_format: 4 | - "{} 5 | 6 | {} 7 | 8 | Answer with the option's letter from the given choices directly." 9 | 10 | short_ans_example_format: 11 | - "{} 12 | 13 | Answer the question using a single word or phrase."
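# Illustrative note (assumption, not asserted by this config): the "{}" placeholders above are
# presumably filled by construct_prompt in utils/data_utils.py -- the question text first and,
# for multi_choice_example_format, the enumerated options second -- yielding a prompt like:
#   "Which number is missing?
#
#    (A) 12  (B) 14  (C) 16  (D) 18
#
#    Answer with the option's letter from the given choices directly."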
14 | temperature: 15 | - 0 -------------------------------------------------------------------------------- /vlmeval/evaluate/mmmu_eval/example_outputs/llava1.5_13b/Electronics/output.json: -------------------------------------------------------------------------------- 1 | [ 2 | { 3 | "id": "validation_Electronics_1", 4 | "question_type": "multiple-choice", 5 | "answer": "A", 6 | "all_choices": [ 7 | "A", 8 | "B" 9 | ], 10 | "index2ans": { 11 | "A": "yes, saturation", 12 | "B": "no, not in saturation" 13 | }, 14 | "response": "B" 15 | }, 16 | { 17 | "id": "validation_Electronics_2", 18 | "question_type": "short-answer", 19 | "answer": "2.83", 20 | "response": "0.5" 21 | }, 22 | { 23 | "id": "validation_Electronics_3", 24 | "question_type": "multiple-choice", 25 | "answer": "C", 26 | "all_choices": [ 27 | "A", 28 | "B", 29 | "C", 30 | "D" 31 | ], 32 | "index2ans": { 33 | "A": "t + (1 / 10) (e^{-20t} - 1) V", 34 | "B": "t + (1 / 20) (e^{-10t} - 1) V", 35 | "C": "t + (1 / 10) (e^{-10t} - 1) V", 36 | "D": "t - (1 / 10) (e^{-10t} - 1) V" 37 | }, 38 | "response": "C" 39 | }, 40 | { 41 | "id": "validation_Electronics_4", 42 | "question_type": "short-answer", 43 | "answer": "8.4", 44 | "response": "1.5" 45 | }, 46 | { 47 | "id": "validation_Electronics_5", 48 | "question_type": "short-answer", 49 | "answer": "62.6", 50 | "response": "500" 51 | }, 52 | { 53 | "id": "validation_Electronics_6", 54 | "question_type": "short-answer", 55 | "answer": "71.6", 56 | "response": "5" 57 | }, 58 | { 59 | "id": "validation_Electronics_7", 60 | "question_type": "multiple-choice", 61 | "answer": "C", 62 | "all_choices": [ 63 | "A", 64 | "B", 65 | "C", 66 | "D" 67 | ], 68 | "index2ans": { 69 | "A": "$\\sqrt(2) cos[(1/3)t]$", 70 | "B": "$\\sqrt(3) cos[(2/3)t]$", 71 | "C": "$\\sqrt(2) cos[(2/3)t]$", 72 | "D": "$\\sqrt(3) cos[(4/3)t]$" 73 | }, 74 | "response": "A" 75 | }, 76 | { 77 | "id": "validation_Electronics_8", 78 | "question_type": "multiple-choice", 79 | "answer": "B", 80 | "all_choices": [ 81 | "A", 82 | "B", 83 | "C", 84 | "D" 85 | ], 86 | "index2ans": { 87 | "A": "(4 / \\pi) (sin \\pit + (1 / 2) sin 3\\pit + (1 / 4) sin 5\\pit + ....).", 88 | "B": "(4 / \\pi) (sin \\pit + (1 / 3) sin 3\\pit + (1 / 5) sin 5\\pit + ....).", 89 | "C": "(4 / \\pi) (sin \\pit + (1 / 2) sin 2\\pit + (1 / 4) sin 4\\pit + ....).", 90 | "D": "(4 / \\pi) (sin \\pit + (1 / 3) sin 2\\pit + (1 / 5) sin 4\\pit + ....)." 
91 | }, 92 | "response": "A" 93 | }, 94 | { 95 | "id": "validation_Electronics_9", 96 | "question_type": "multiple-choice", 97 | "answer": "C", 98 | "all_choices": [ 99 | "A", 100 | "B", 101 | "C", 102 | "D" 103 | ], 104 | "index2ans": { 105 | "A": "-2[sin t + (1 / 2) sin 25 + (1 / 4) sin 3t + ...]", 106 | "B": "-2[sin t + (1 / 2) sin 30 + (1 / 3) sin 3t + ...]", 107 | "C": "-2[sin t + (1 / 2) sin 25 + (1 / 3) sin 3t + ...]", 108 | "D": "-2[sin t + (1 / 3) sin 25 + (1 / 3) sin 3t + ...]" 109 | }, 110 | "response": "A" 111 | }, 112 | { 113 | "id": "validation_Electronics_10", 114 | "question_type": "multiple-choice", 115 | "answer": "A", 116 | "all_choices": [ 117 | "A", 118 | "B", 119 | "C", 120 | "D" 121 | ], 122 | "index2ans": { 123 | "A": "0.125 + j0.330", 124 | "B": "0.15 + j0.330", 125 | "C": "0.125 + j0.390", 126 | "D": "0.121 + j0.380" 127 | }, 128 | "response": "A" 129 | }, 130 | { 131 | "id": "validation_Electronics_11", 132 | "question_type": "short-answer", 133 | "answer": "0.3", 134 | "response": "0" 135 | }, 136 | { 137 | "id": "validation_Electronics_12", 138 | "question_type": "short-answer", 139 | "answer": "10", 140 | "response": "24" 141 | }, 142 | { 143 | "id": "validation_Electronics_13", 144 | "question_type": "multiple-choice", 145 | "answer": "C", 146 | "all_choices": [ 147 | "A", 148 | "B", 149 | "C", 150 | "D" 151 | ], 152 | "index2ans": { 153 | "A": "[(1 - e^{-s(T/2)}) / {s(1 - e^{-sT})}]", 154 | "B": "[(2 - e^{-s(T)}) / {s(1 - e^{-sT})}]", 155 | "C": "[(1 - e^{-s(T/2)}) / {s(1 - e^{-sT})}]", 156 | "D": "[(1 - e^{-s(T/3)}) / {s(1 - e^{-sT})}]" 157 | }, 158 | "response": "A" 159 | }, 160 | { 161 | "id": "validation_Electronics_14", 162 | "question_type": "multiple-choice", 163 | "answer": "C", 164 | "all_choices": [ 165 | "A", 166 | "B", 167 | "C", 168 | "D" 169 | ], 170 | "index2ans": { 171 | "A": "4t^2 [u(t) - u(t - 5)] + 20[u(t - 2) - u(t - 5)] + 15(t - 7)[u(t - 5) - u(t - 7)]", 172 | "B": "4t^2 [u(t) - u(t - 2)] + 20[u(t - 2) - u(t - 5)] + 15(t - 7)[u(t - 5) - u(t - 7)]", 173 | "C": "5t^2 [u(t) - u(t - 2)] + 20[u(t - 2) - u(t - 5)] + 15(t - 7)[u(t - 5) - u(t - 7)]", 174 | "D": "5t^2 [u(t) - u(t - 2)] + 20[u(t - 2) - u(t - 3)] + 15(t - 7)[u(t - 5) - u(t - 7)]" 175 | }, 176 | "response": "A" 177 | }, 178 | { 179 | "id": "validation_Electronics_15", 180 | "question_type": "multiple-choice", 181 | "answer": "A", 182 | "all_choices": [ 183 | "A", 184 | "B", 185 | "C", 186 | "D" 187 | ], 188 | "index2ans": { 189 | "A": "$[1 + sech^2 (sin t)] cos t$", 190 | "B": "$[1 + sech^2 (sin t)] sin t$", 191 | "C": "$[1 - sech^2 (sin t)] sin t$", 192 | "D": "$[1 - sech^2 (cos t)] sin t$" 193 | }, 194 | "response": "A" 195 | }, 196 | { 197 | "id": "validation_Electronics_16", 198 | "question_type": "short-answer", 199 | "answer": "20", 200 | "response": "100" 201 | }, 202 | { 203 | "id": "validation_Electronics_17", 204 | "question_type": "multiple-choice", 205 | "answer": "A", 206 | "all_choices": [ 207 | "A", 208 | "B", 209 | "C", 210 | "D" 211 | ], 212 | "index2ans": { 213 | "A": "10e^{-0.8t} V", 214 | "B": "-5e^{-0.8t} V", 215 | "C": "-2e^{-1t} V", 216 | "D": "-6e^{-2t} V" 217 | }, 218 | "response": "A" 219 | }, 220 | { 221 | "id": "validation_Electronics_18", 222 | "question_type": "multiple-choice", 223 | "answer": "B", 224 | "all_choices": [ 225 | "A", 226 | "B", 227 | "C", 228 | "D" 229 | ], 230 | "index2ans": { 231 | "A": "2 e^{-2t} u(t)", 232 | "B": "3 e^{-2t} u(t)", 233 | "C": "2.2 e^{-2t} u(t)", 234 | "D": "3 e^{-3t} u(t)" 235 | }, 236 | "response": "A" 237 | }, 238 | 
{ 239 | "id": "validation_Electronics_19", 240 | "question_type": "multiple-choice", 241 | "answer": "A", 242 | "all_choices": [ 243 | "A", 244 | "B", 245 | "C", 246 | "D" 247 | ], 248 | "index2ans": { 249 | "A": "-90 cos(t) V", 250 | "B": "-90 cos(2t) V", 251 | "C": "90 cos(2t) V", 252 | "D": "90 sin(2t) V" 253 | }, 254 | "response": "B" 255 | }, 256 | { 257 | "id": "validation_Electronics_20", 258 | "question_type": "multiple-choice", 259 | "answer": "C", 260 | "all_choices": [ 261 | "A", 262 | "B", 263 | "C", 264 | "D" 265 | ], 266 | "index2ans": { 267 | "A": "$3 \\pi x 10^-5 A$", 268 | "B": "$\\pi x 10^-5 A$", 269 | "C": "$2 \\pi x 10^-5 A$", 270 | "D": "$\\pi x 10^-4 A$" 271 | }, 272 | "response": "B" 273 | }, 274 | { 275 | "id": "validation_Electronics_21", 276 | "question_type": "short-answer", 277 | "answer": "0.9965", 278 | "response": "0" 279 | }, 280 | { 281 | "id": "validation_Electronics_22", 282 | "question_type": "multiple-choice", 283 | "answer": "C", 284 | "all_choices": [ 285 | "A", 286 | "B", 287 | "C", 288 | "D" 289 | ], 290 | "index2ans": { 291 | "A": "[70.7 cos (20t - 60^{\\circ})] u(t) V", 292 | "B": "[70.7 cos (10t - 45^{\\circ})] u(t) V", 293 | "C": "[70.7 cos (20t - 45^{\\circ})] u(t) V", 294 | "D": "[70.7 cos (20t - 90^{\\circ})] u(t) V" 295 | }, 296 | "response": "C" 297 | }, 298 | { 299 | "id": "validation_Electronics_23", 300 | "question_type": "multiple-choice", 301 | "answer": "A", 302 | "all_choices": [ 303 | "A", 304 | "B", 305 | "C", 306 | "D" 307 | ], 308 | "index2ans": { 309 | "A": "$C_0 sin wt cos t + 2w C_0 cos wt (1 + 0.5 sin t)$", 310 | "B": "$C_0 cos wt cos t + 2w C_0 sin wt (1 + 0.5 sin t)$", 311 | "C": "$C_0 sin wt cos t + 4w C_0 cos wt (1 + sin t)$", 312 | "D": "$C_0 sin wt cos t + 2w C_0 cos wt (1 - sin t)$" 313 | }, 314 | "response": "A" 315 | }, 316 | { 317 | "id": "validation_Electronics_24", 318 | "question_type": "short-answer", 319 | "answer": "551", 320 | "response": "100" 321 | }, 322 | { 323 | "id": "validation_Electronics_25", 324 | "question_type": "short-answer", 325 | "answer": "50", 326 | "response": "100" 327 | }, 328 | { 329 | "id": "validation_Electronics_26", 330 | "question_type": "short-answer", 331 | "answer": "6.333", 332 | "response": "100" 333 | }, 334 | { 335 | "id": "validation_Electronics_27", 336 | "question_type": "short-answer", 337 | "answer": "-120", 338 | "response": "0" 339 | }, 340 | { 341 | "id": "validation_Electronics_28", 342 | "question_type": "multiple-choice", 343 | "answer": "A", 344 | "all_choices": [ 345 | "A", 346 | "B", 347 | "C", 348 | "D" 349 | ], 350 | "index2ans": { 351 | "A": "75 + 13.3 cos(250t - 57.7\u00b0)V", 352 | "B": "75 + 23.3 cos(250t - 57.7\u00b0)V", 353 | "C": "45 + 3.3 cos(250t - 57.7\u00b0)V", 354 | "D": "95 + 13.3 cos(250t - 57.7\u00b0)V" 355 | }, 356 | "response": "B" 357 | }, 358 | { 359 | "id": "validation_Electronics_29", 360 | "question_type": "short-answer", 361 | "answer": "30", 362 | "response": "36 watts" 363 | }, 364 | { 365 | "id": "validation_Electronics_30", 366 | "question_type": "short-answer", 367 | "answer": "-141", 368 | "response": "0" 369 | } 370 | ] -------------------------------------------------------------------------------- /vlmeval/evaluate/mmmu_eval/main_eval_only.py: -------------------------------------------------------------------------------- 1 | """Parse and Evalate""" 2 | import os 3 | import json 4 | 5 | import pdb 6 | from argparse import ArgumentParser 7 | 8 | from utils.data_utils import save_json, CAT_SHORT2LONG, DOMAIN_CAT2SUB_CAT 9 | from 
utils.eval_utils import evaluate, parse_multi_choice_response, parse_open_response, calculate_ins_level_acc 10 | 11 | 12 | if __name__ == '__main__': 13 | 14 | parser = ArgumentParser() 15 | parser.add_argument('--output_path', type=str, default="./example_outputs/qwen_vl/total_val_output.json", help="The path to model output file.") 16 | parser.add_argument('--answer_path', type=str, default="./answer_dict_val.json", help="Answer file path.") 17 | args = parser.parse_args() 18 | 19 | output_dict = json.load(open(args.output_path)) 20 | answer_dict = json.load(open(args.answer_path)) 21 | 22 | # group by category 23 | output_dict_w_cat = {} 24 | for data_id, parsed_pred in output_dict.items(): 25 | category = "_".join(data_id.split("_")[1:-1]) 26 | if category not in output_dict_w_cat: 27 | output_dict_w_cat.update({category: {}}) 28 | output_dict_w_cat[category].update({data_id: parsed_pred}) 29 | 30 | # group by category 31 | answer_dict_w_cat = {} 32 | for data_id, parsed_pred in answer_dict.items(): 33 | category = "_".join(data_id.split("_")[1:-1]) 34 | if category not in answer_dict_w_cat: 35 | answer_dict_w_cat.update({category: {}}) 36 | answer_dict_w_cat[category].update({data_id: parsed_pred}) 37 | 38 | evaluation_result = {} 39 | 40 | for category in CAT_SHORT2LONG.values(): 41 | print("Evaluating: {}".format(category)) 42 | # get cat_outputs and cat_answers 43 | try: 44 | cat_outputs = output_dict_w_cat[category] 45 | cat_answers = answer_dict_w_cat[category] 46 | except KeyError: 47 | print("Skipping {} for not found".format(category)) 48 | continue 49 | 50 | exampels_to_eval = [] 51 | for data_id, parsed_pred in cat_outputs.items(): 52 | question_type = cat_answers[data_id]['question_type'] 53 | if question_type != 'multiple-choice': 54 | parsed_pred = parse_open_response(parsed_pred) # mainly for type consistency (make it number, etc.) 
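# Illustrative note (assumption): parse_open_response normalizes a free-form answer such as
# "36 watts" into a list of candidate values (extracting numbers where it can), so that
# evaluate() below can match open-ended predictions against the numeric or string ground truth.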
55 | else: 56 | parsed_pred = parsed_pred 57 | 58 | exampels_to_eval.append({ 59 | "id": data_id, 60 | "question_type": question_type, 61 | "answer": cat_answers[data_id]['ground_truth'], 62 | "parsed_pred": parsed_pred 63 | }) 64 | 65 | judge_dict, metric_dict = evaluate(exampels_to_eval) 66 | metric_dict.update({"num_example": len(exampels_to_eval)}) 67 | 68 | evaluation_result[category] = metric_dict 69 | 70 | printable_results = {} 71 | # pdb.set_trace() 72 | # add domain Subject 73 | for domain, in_domain_cats in DOMAIN_CAT2SUB_CAT.items(): 74 | in_domain_cat_results = {} 75 | for cat_name in in_domain_cats: # use the order in DOMAIN_CAT2SUB_CAT 76 | if cat_name in evaluation_result.keys(): 77 | in_domain_cat_results[cat_name] = evaluation_result[cat_name] 78 | else: 79 | pass 80 | in_domain_ins_acc = calculate_ins_level_acc(in_domain_cat_results) 81 | in_domain_data_num = sum([cat_results['num_example'] for cat_results in in_domain_cat_results.values()]) 82 | printable_results['Overall-' + domain] = {"num": int(in_domain_data_num), 83 | "acc": round(in_domain_ins_acc, 3) 84 | } 85 | # add sub category 86 | for cat_name, cat_results in in_domain_cat_results.items(): 87 | printable_results[cat_name] = {"num": int(cat_results['num_example']), 88 | "acc": round(cat_results['acc'], 3) 89 | } 90 | 91 | # table.append(["-----------------------------", "-----", "----"]) 92 | all_ins_acc = calculate_ins_level_acc(evaluation_result) 93 | printable_results['Overall'] = {"num": sum([cat_results['num_example'] for cat_results in evaluation_result.values()]), 94 | "acc": round(all_ins_acc, 3) 95 | } 96 | 97 | print(printable_results) 98 | 99 | -------------------------------------------------------------------------------- /vlmeval/evaluate/mmmu_eval/main_parse_and_eval.py: -------------------------------------------------------------------------------- 1 | """Parse and Evalate""" 2 | import os 3 | import json 4 | from argparse import ArgumentParser 5 | 6 | from utils.data_utils import save_json, CAT_SHORT2LONG 7 | from utils.eval_utils import evaluate, parse_multi_choice_response, parse_open_response 8 | 9 | 10 | if __name__ == '__main__': 11 | 12 | parser = ArgumentParser() 13 | parser.add_argument('--path', type=str, default="./example_outputs/llava1.5_13b", help="The path to model output directory.") 14 | parser.add_argument('--subject', nargs='+', 15 | help=f'The name of the mmmu sub-category. 
Availble: {CAT_SHORT2LONG.keys()} or ALL') 16 | 17 | args = parser.parse_args() 18 | if args.subject[0] == 'ALL': 19 | args.subject = CAT_SHORT2LONG.keys() 20 | 21 | ex_output_path = os.path.join(args.path) 22 | 23 | all_results = {} 24 | for cat_short in args.subject: 25 | category = CAT_SHORT2LONG[cat_short] 26 | print("Evaluating: {}".format(category)) 27 | if category not in os.listdir(ex_output_path): 28 | print("Skipping {} for not found".format(category)) 29 | else: 30 | cat_folder_path = os.path.join(ex_output_path, category) 31 | cat_outputs = json.load(open(os.path.join(cat_folder_path, 'output.json'))) 32 | # Evaluation 33 | eval_samples = [] 34 | for cat_output in cat_outputs: 35 | response = cat_output['response'] 36 | if cat_output['question_type'] == 'multiple-choice': 37 | all_choices = cat_output['all_choices'] 38 | index2ans = cat_output['index2ans'] 39 | parsed_pred = parse_multi_choice_response(response, all_choices, index2ans) 40 | eval_samples.append( 41 | { 42 | 'id': cat_output['id'], 43 | 'question_type': cat_output['question_type'], 44 | 'answer': cat_output['answer'], # the content in option, not answer index. 45 | 'response': response, 46 | 'parsed_pred': parsed_pred, 47 | 'index2ans': index2ans, 48 | } 49 | ) 50 | else: # open 51 | parsed_pred = parse_open_response(response) 52 | eval_samples.append( 53 | { 54 | 'id': cat_output['id'], 55 | 'question_type': cat_output['question_type'], 56 | 'answer': cat_output['answer'], 57 | 'response': response, 58 | 'parsed_pred': parsed_pred, 59 | } 60 | ) 61 | 62 | print("Num of valid samples: {}, Expected Num: {}".format(len(eval_samples), len(cat_outputs))) 63 | 64 | judge_dict, metric_dict = evaluate(eval_samples) 65 | metric_dict.update({"num_example": len(eval_samples)}) 66 | for eval_sample in eval_samples: 67 | eval_sample.update({"judge": judge_dict[eval_sample['id']]}) 68 | 69 | save_json(os.path.join(cat_folder_path, 'parsed_output.json'), eval_samples) 70 | save_json(os.path.join(cat_folder_path, 'result.json'), metric_dict) 71 | -------------------------------------------------------------------------------- /vlmeval/evaluate/mmmu_eval/mmmu_eval_script.py: -------------------------------------------------------------------------------- 1 | """Parse and Evalate""" 2 | import os 3 | import json 4 | 5 | import pdb 6 | 7 | from .utils.data_utils import save_json, CAT_SHORT2LONG, DOMAIN_CAT2SUB_CAT 8 | from .utils.eval_utils import evaluate, parse_multi_choice_response, parse_open_response, calculate_ins_level_acc 9 | 10 | 11 | def mmmu_eval_func(output_path): 12 | # 13 | answer_path = './vlmeval/evaluate/mmmu_eval/answer_dict_val.json' 14 | output_dict = json.load(open(output_path)) 15 | answer_dict = json.load(open(answer_path)) 16 | 17 | # group by category 18 | output_dict_w_cat = {} 19 | for data_id, parsed_pred in output_dict.items(): 20 | category = "_".join(data_id.split("_")[1:-1]) 21 | if category not in output_dict_w_cat: 22 | output_dict_w_cat.update({category: {}}) 23 | output_dict_w_cat[category].update({data_id: parsed_pred}) 24 | 25 | # group by category 26 | answer_dict_w_cat = {} 27 | for data_id, parsed_pred in answer_dict.items(): 28 | category = "_".join(data_id.split("_")[1:-1]) 29 | if category not in answer_dict_w_cat: 30 | answer_dict_w_cat.update({category: {}}) 31 | answer_dict_w_cat[category].update({data_id: parsed_pred}) 32 | 33 | evaluation_result = {} 34 | 35 | for category in CAT_SHORT2LONG.values(): 36 | print("Evaluating: {}".format(category)) 37 | # get cat_outputs and 
cat_answers 38 | try: 39 | cat_outputs = output_dict_w_cat[category] 40 | cat_answers = answer_dict_w_cat[category] 41 | except KeyError: 42 | print("Skipping {} for not found".format(category)) 43 | continue 44 | 45 | exampels_to_eval = [] 46 | for data_id, parsed_pred in cat_outputs.items(): 47 | question_type = cat_answers[data_id]['question_type'] 48 | if question_type != 'multiple-choice': 49 | parsed_pred = parse_open_response(parsed_pred) # mainly for type consistency (make it number, etc.) 50 | else: 51 | parsed_pred = parsed_pred 52 | 53 | exampels_to_eval.append({ 54 | "id": data_id, 55 | "question_type": question_type, 56 | "answer": cat_answers[data_id]['ground_truth'], 57 | "parsed_pred": parsed_pred 58 | }) 59 | 60 | judge_dict, metric_dict = evaluate(exampels_to_eval) 61 | metric_dict.update({"num_example": len(exampels_to_eval)}) 62 | 63 | evaluation_result[category] = metric_dict 64 | 65 | printable_results = {} 66 | # pdb.set_trace() 67 | # add domain Subject 68 | for domain, in_domain_cats in DOMAIN_CAT2SUB_CAT.items(): 69 | in_domain_cat_results = {} 70 | for cat_name in in_domain_cats: # use the order in DOMAIN_CAT2SUB_CAT 71 | if cat_name in evaluation_result.keys(): 72 | in_domain_cat_results[cat_name] = evaluation_result[cat_name] 73 | else: 74 | pass 75 | in_domain_ins_acc = calculate_ins_level_acc(in_domain_cat_results) 76 | in_domain_data_num = sum([cat_results['num_example'] for cat_results in in_domain_cat_results.values()]) 77 | printable_results['Overall-' + domain] = {"num": int(in_domain_data_num), 78 | "acc": round(in_domain_ins_acc, 3) 79 | } 80 | # add sub category 81 | for cat_name, cat_results in in_domain_cat_results.items(): 82 | printable_results[cat_name] = {"num": int(cat_results['num_example']), 83 | "acc": round(cat_results['acc'], 3) 84 | } 85 | 86 | # table.append(["-----------------------------", "-----", "----"]) 87 | all_ins_acc = calculate_ins_level_acc(evaluation_result) 88 | printable_results['Overall'] = {"num": sum([cat_results['num_example'] for cat_results in evaluation_result.values()]), 89 | "acc": round(all_ins_acc, 3) 90 | } 91 | 92 | print(printable_results) 93 | 94 | -------------------------------------------------------------------------------- /vlmeval/evaluate/mmmu_eval/print_results.py: -------------------------------------------------------------------------------- 1 | # Beautiful table to print results of all categories 2 | 3 | import os 4 | from typing import Dict 5 | import json 6 | import numpy as np 7 | from tabulate import tabulate 8 | 9 | from argparse import ArgumentParser 10 | 11 | from utils.data_utils import CAT_SHORT2LONG, DOMAIN_CAT2SUB_CAT 12 | 13 | from utils.eval_utils import calculate_ins_level_acc 14 | 15 | def main(): 16 | parser = ArgumentParser() 17 | parser.add_argument('--path', type=str, default="./example_outputs/blip2_flant5xxl", help="The path to output directory.") 18 | args = parser.parse_args() 19 | 20 | # load all results 21 | all_results = {} 22 | for cat_folder_name in os.listdir(args.path): 23 | if cat_folder_name in CAT_SHORT2LONG.values(): 24 | cat_folder_path = os.path.join(args.path, cat_folder_name) 25 | result_path = os.path.join(cat_folder_path, 'result.json') 26 | if os.path.exists(result_path): 27 | cat_results = json.load(open(result_path)) 28 | all_results[cat_folder_name] = cat_results 29 | 30 | # print results 31 | headers = ['Subject', 'Data Num', 'Acc'] 32 | table = [] 33 | 34 | # add domain Subject 35 | for domain, in_domain_cats in DOMAIN_CAT2SUB_CAT.items(): 36 | 
in_domain_cat_results = {} 37 | for cat_name in in_domain_cats: # use the order in DOMAIN_CAT2SUB_CAT 38 | if cat_name in all_results.keys(): 39 | in_domain_cat_results[cat_name] = all_results[cat_name] 40 | else: 41 | pass 42 | in_domain_ins_acc = calculate_ins_level_acc(in_domain_cat_results) 43 | in_domain_data_num = np.sum([cat_results['num_example'] for cat_results in in_domain_cat_results.values()]) 44 | table.append(['Overall-' + domain, int(in_domain_data_num), round(in_domain_ins_acc, 3)]) 45 | # add sub category 46 | for cat_name, cat_results in in_domain_cat_results.items(): 47 | table.append([cat_name, int(cat_results['num_example']), round(cat_results['acc'], 3)]) 48 | # table.append(["-----------------------------", "-----", "----"]) 49 | 50 | # table.append(["-----------------------------", "-----", "----"]) 51 | all_ins_acc = calculate_ins_level_acc(all_results) 52 | table.append(['Overall', np.sum([cat_results['num_example'] for cat_results in all_results.values()]), round(all_ins_acc, 3)]) 53 | 54 | print(tabulate(table, headers=headers, tablefmt='orgtbl')) 55 | 56 | 57 | if __name__ == '__main__': 58 | main() 59 | -------------------------------------------------------------------------------- /vlmeval/evaluate/mmmu_eval/run_llava.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import os 3 | import random 4 | 5 | import numpy as np 6 | from tqdm import tqdm 7 | 8 | from datasets import load_dataset, concatenate_datasets 9 | from llava.model.builder import load_pretrained_model 10 | from llava.mm_utils import get_model_name_from_path 11 | 12 | from argparse import ArgumentParser 13 | 14 | from utils.data_utils import load_yaml, construct_prompt, save_json, process_single_sample, CAT_SHORT2LONG 15 | from utils.model_utils import call_llava_engine_df, llava_image_processor 16 | from utils.eval_utils import parse_multi_choice_response, parse_open_response 17 | 18 | 19 | def run_model(args, samples, model, call_model_engine_fn=None, tokenizer=None, processor=None): 20 | out_samples = dict() 21 | with torch.no_grad(): 22 | for sample in tqdm(samples): 23 | response = call_model_engine_fn(args, sample, model, tokenizer, processor) 24 | 25 | if sample['question_type'] == 'multiple-choice': 26 | pred_ans = parse_multi_choice_response(response, sample['all_choices'], sample['index2ans']) 27 | else: # open question 28 | pred_ans = response 29 | out_samples[sample['id']] = pred_ans 30 | return out_samples 31 | 32 | def set_seed(seed_value): 33 | """ 34 | Set the seed for PyTorch (both CPU and CUDA), Python, and NumPy for reproducible results. 35 | 36 | :param seed_value: An integer value to be used as the seed. 37 | """ 38 | torch.manual_seed(seed_value) 39 | if torch.cuda.is_available(): 40 | torch.cuda.manual_seed(seed_value) 41 | torch.cuda.manual_seed_all(seed_value) # For multi-GPU setups 42 | random.seed(seed_value) 43 | np.random.seed(seed_value) 44 | torch.backends.cudnn.deterministic = True 45 | torch.backends.cudnn.benchmark = False 46 | 47 | def main(): 48 | parser = ArgumentParser() 49 | parser.add_argument('--output_path', type=str, default='llava1.5_13b_val.json', 50 | help='name of saved json') 51 | parser.add_argument('--config_path', type=str, default="configs/llava1.5.yaml") 52 | parser.add_argument('--data_path', type=str, default="MMMU/MMMU") # hf dataset path. 
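# For reference, running this script with the defaults declared in this parser is
# equivalent to the command below (shown purely as an illustration of the defaults,
# not as an additional supported entry point):
#   python run_llava.py --data_path MMMU/MMMU --split validation \
#       --config_path configs/llava1.5.yaml --output_path llava1.5_13b_val.json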
53 | parser.add_argument('--model_path', type=str, default="liuhaotian/llava-v1.5-13b") 54 | parser.add_argument('--split', type=str, default='validation') 55 | parser.add_argument('--seed', type=int, default=42) 56 | 57 | args = parser.parse_args() 58 | device = torch.device("cuda") if torch.cuda.is_available() else "cpu" 59 | set_seed(args.seed) 60 | 61 | print('llava_initializing...') 62 | processor = None 63 | call_model_engine = call_llava_engine_df 64 | vis_process_func = llava_image_processor 65 | 66 | # load config and process to one value 67 | args.config = load_yaml(args.config_path) 68 | for key, value in args.config.items(): 69 | if key != 'eval_params' and type(value) == list: 70 | assert len(value) == 1, 'key {} has more than one value'.format(key) 71 | args.config[key] = value[0] 72 | 73 | # run for each subject 74 | sub_dataset_list = [] 75 | for subject in CAT_SHORT2LONG.values(): 76 | sub_dataset = load_dataset(args.data_path, subject, split=args.split) 77 | sub_dataset_list.append(sub_dataset) 78 | 79 | # merge all dataset 80 | dataset = concatenate_datasets(sub_dataset_list) 81 | 82 | 83 | # load model 84 | model_name = get_model_name_from_path(args.model_path) 85 | tokenizer, model, vis_processors, _ = load_pretrained_model(args.model_path, None, 86 | model_name) 87 | 88 | samples = [] 89 | for sample in dataset: 90 | sample = process_single_sample(sample) 91 | 92 | sample = construct_prompt(sample, args.config) 93 | if sample['image']: 94 | sample['image'] = vis_process_func(sample['image'], vis_processors).to(device) 95 | samples.append(sample) 96 | 97 | # run ex 98 | out_samples = run_model(args, samples, model, call_model_engine, tokenizer, processor) 99 | 100 | save_json(args.output_path, out_samples) 101 | # metric_dict.update({"num_example": len(out_samples)}) 102 | # save_json(save_result_path, metric_dict) 103 | 104 | 105 | if __name__ == '__main__': 106 | main() 107 | 108 | -------------------------------------------------------------------------------- /vlmeval/evaluate/mmmu_eval/utils/data_utils.py: -------------------------------------------------------------------------------- 1 | """Utils for data load, save, and process (e.g., prompt construction)""" 2 | 3 | import os 4 | import json 5 | import yaml 6 | import re 7 | 8 | 9 | DOMAIN_CAT2SUB_CAT = { 10 | 'Art and Design': ['Art', 'Art_Theory', 'Design', 'Music'], 11 | 'Business': ['Accounting', 'Economics', 'Finance', 'Manage','Marketing'], 12 | 'Science': ['Biology', 'Chemistry', 'Geography', 'Math', 'Physics',], 13 | 'Health and Medicine': ['Basic_Medical_Science', 'Clinical_Medicine', 'Diagnostics_and_Laboratory_Medicine', 'Pharmacy', 'Public_Health'], 14 | 'Humanities and Social Science': ['History', 'Literature', 'Sociology', 'Psychology'], 15 | 'Tech and Engineering': ['Agriculture', 'Architecture_and_Engineering', 'Computer_Science', 'Electronics', 'Energy_and_Power', 'Materials', 'Mechanical_Engineering'], 16 | } 17 | 18 | 19 | CAT_SHORT2LONG = { 20 | 'acc': 'Accounting', 21 | 'agri': 'Agriculture', 22 | 'arch': 'Architecture_and_Engineering', 23 | 'art': 'Art', 24 | 'art_theory': 'Art_Theory', 25 | 'bas_med': 'Basic_Medical_Science', 26 | 'bio': 'Biology', 27 | 'chem': 'Chemistry', 28 | 'cli_med': 'Clinical_Medicine', 29 | 'cs': 'Computer_Science', 30 | 'design': 'Design', 31 | 'diag_med': 'Diagnostics_and_Laboratory_Medicine', 32 | 'econ': 'Economics', 33 | 'elec': 'Electronics', 34 | 'ep': 'Energy_and_Power', 35 | 'fin': 'Finance', 36 | 'geo': 'Geography', 37 | 'his': 'History', 38 | 'liter': 
'Literature', 39 | 'manage': 'Manage', 40 | 'mark': 'Marketing', 41 | 'mate': 'Materials', 42 | 'math': 'Math', 43 | 'mech': 'Mechanical_Engineering', 44 | 'music': 'Music', 45 | 'phar': 'Pharmacy', 46 | 'phys': 'Physics', 47 | 'psy': 'Psychology', 48 | 'pub_health': 'Public_Health', 49 | 'socio': 'Sociology' 50 | } 51 | 52 | # DATA SAVING 53 | def save_json(filename, ds): 54 | with open(filename, 'w') as f: 55 | json.dump(ds, f, indent=4) 56 | 57 | 58 | def get_multi_choice_info(options): 59 | """ 60 | Given the list of options for multiple choice question 61 | Return the index2ans and all_choices 62 | """ 63 | 64 | start_chr = 'A' 65 | all_choices = [] 66 | index2ans = {} 67 | for i, option in enumerate(options): 68 | index2ans[chr(ord(start_chr) + i)] = option 69 | all_choices.append(chr(ord(start_chr) + i)) 70 | 71 | return index2ans, all_choices 72 | 73 | def load_yaml(file_path): 74 | with open(file_path, 'r') as stream: 75 | try: 76 | yaml_dict = yaml.safe_load(stream) 77 | except yaml.YAMLError as exc: 78 | print(exc) 79 | 80 | return yaml_dict 81 | 82 | 83 | def parse_img_path(text): 84 | matches = re.findall("", text) 85 | return matches 86 | 87 | def process_single_sample(data): 88 | question = data['question'] 89 | o_imgs_paths = [] 90 | for option in data['options']: 91 | current_o_imgs_paths = parse_img_path(option) 92 | for img_path in current_o_imgs_paths: 93 | o_imgs_paths.append(img_path) 94 | 95 | if len(o_imgs_paths) > 1: # multiple images in options, used for random selection 96 | return {'id': data['id'], 'question': question, 'options': data['options'], 'answer': data['answer'], 97 | 'image': None, 'question_type': data['question_type']} 98 | else: 99 | return {'id': data['id'], 'question': question, 'options': data['options'], 'answer': data['answer'], 100 | 'image': data['image_1'], 'question_type': data['question_type']} 101 | 102 | 103 | # DATA SAVING 104 | def save_json(filename, ds): 105 | with open(filename, 'w') as f: 106 | json.dump(ds, f, indent=4) 107 | 108 | def save_jsonl(filename, data): 109 | """ 110 | Save a dictionary of data to a JSON Lines file with the filename as key and caption as value. 111 | 112 | Args: 113 | filename (str): The path to the file where the data should be saved. 114 | data (dict): The dictionary containing the data to save where key is the image path and value is the caption. 
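Example (illustrative values; keys are reduced to their base filenames via os.path.basename):
    save_jsonl('captions.jsonl', {'images/cat_001.jpg': 'a cat on a sofa'})
    # writes one JSON object per line, e.g. {"cat_001.jpg": "a cat on a sofa"}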
115 | """ 116 | with open(filename, 'w', encoding='utf-8') as f: 117 | for img_path, caption in data.items(): 118 | # Extract the base filename without the extension 119 | base_filename = os.path.basename(img_path) 120 | # Create a JSON object with the filename as the key and caption as the value 121 | json_record = json.dumps({base_filename: caption}, ensure_ascii=False) 122 | # Write the JSON object to the file, one per line 123 | f.write(json_record + '\n') 124 | 125 | def save_args(args, path_dir): 126 | argsDict = args.__dict__ 127 | with open(path_dir + 'setting.txt', 'w') as f: 128 | f.writelines('------------------ start ------------------' + '\n') 129 | for eachArg, value in argsDict.items(): 130 | f.writelines(eachArg + ' : ' + str(value) + '\n') 131 | f.writelines('------------------- end -------------------') 132 | 133 | 134 | 135 | # DATA PROCESSING 136 | def construct_prompt(sample, config): 137 | question = sample['question'] 138 | options = eval(sample['options']) 139 | example = "" 140 | if sample['question_type'] == 'multiple-choice': 141 | start_chr = 'A' 142 | prediction_range = [] 143 | index2ans = {} 144 | for option in options: 145 | prediction_range.append(start_chr) 146 | example += f"({start_chr}) {option}\n" 147 | index2ans[start_chr] = option 148 | start_chr = chr(ord(start_chr) + 1) 149 | empty_prompt_sample_structure = config['multi_choice_example_format'] 150 | empty_prompt = empty_prompt_sample_structure.format(question, example) 151 | res_dict = {} 152 | res_dict['index2ans'] = index2ans 153 | res_dict['correct_choice'] = sample['answer'] 154 | res_dict['all_choices'] = prediction_range 155 | res_dict['empty_prompt'] = empty_prompt 156 | if config['task_instructions']: 157 | res_dict['final_input_prompt'] = config['task_instructions'].strip() + '\n\n' + empty_prompt 158 | else: 159 | res_dict['final_input_prompt'] = empty_prompt 160 | 161 | res_dict['gt_content'] = options[ord(sample['answer'].upper()) - ord('A')] 162 | else: 163 | empty_prompt_sample_structure = config['short_ans_example_format'] 164 | empty_prompt = empty_prompt_sample_structure.format(question) 165 | res_dict = {} 166 | res_dict['empty_prompt'] = empty_prompt 167 | if config['task_instructions']: 168 | res_dict['final_input_prompt'] = config['task_instructions'].strip() + '\n\n' + empty_prompt 169 | else: 170 | res_dict['final_input_prompt'] = empty_prompt 171 | res_dict['gt_content'] = sample['answer'] 172 | 173 | res_dict.update(sample) 174 | return res_dict -------------------------------------------------------------------------------- /vlmeval/evaluate/mmmu_eval/utils/eval_utils.py: -------------------------------------------------------------------------------- 1 | """Response Parsing and Evaluation for various models""" 2 | from typing import Dict 3 | 4 | import re 5 | import random 6 | random.seed(42) 7 | import numpy as np 8 | 9 | # ----------- Process Multi-choice ------------- 10 | def parse_multi_choice_response(response, all_choices, index2ans): 11 | """ 12 | Parse the prediction from the generated response. 13 | Return the predicted index e.g., A, B, C, D. 
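Matching falls back in order: bracketed letters such as "(B)", then bare letters, then
(for responses longer than five tokens) the option text itself. If several candidates
remain, the one mentioned last wins; a random choice is made only when nothing matches.
Example (illustrative):
    parse_multi_choice_response("I think the answer is (B).", ['A', 'B', 'C', 'D'],
                                {'A': 'cat', 'B': 'dog', 'C': 'fox', 'D': 'owl'})
    # -> 'B'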
14 | """ 15 | for char in [',', '.', '!', '?', ';', ':', "'"]: 16 | response = response.strip(char) 17 | response = " " + response + " " # add space to avoid partial match 18 | 19 | index_ans = True 20 | ans_with_brack = False 21 | candidates = [] 22 | for choice in all_choices: # e.g., (A) (B) (C) (D) 23 | if f'({choice})' in response: 24 | candidates.append(choice) 25 | ans_with_brack = True 26 | 27 | if len(candidates) == 0: 28 | for choice in all_choices: # e.g., A B C D 29 | if f' {choice} ' in response: 30 | candidates.append(choice) 31 | 32 | # if all above doesn't get candidates, check if the content is larger than 5 tokens and try to parse the example 33 | if len(candidates) == 0 and len(response.split()) > 5: 34 | for index, ans in index2ans.items(): 35 | if ans.lower() in response.lower(): 36 | candidates.append(index) 37 | index_ans = False # it's content ans. 38 | 39 | if len(candidates) == 0: # still not get answer, randomly choose one. 40 | pred_index = random.choice(all_choices) 41 | elif len(candidates) > 1: 42 | start_indexes = [] 43 | if index_ans: 44 | if ans_with_brack: 45 | for can in candidates: 46 | index = response.rfind(f'({can})') 47 | start_indexes.append(index) # -1 will be ignored anyway 48 | # start_indexes = [generated_response.index(f'({can})') for can in candidates] 49 | else: 50 | for can in candidates: 51 | index = response.rfind(f" {can} ") 52 | start_indexes.append(index) 53 | else: 54 | for can in candidates: 55 | index = response.lower().rfind(index2ans[can].lower()) 56 | start_indexes.append(index) 57 | # get the last one 58 | pred_index = candidates[np.argmax(start_indexes)] 59 | else: # if only one candidate, use it. 60 | pred_index = candidates[0] 61 | 62 | return pred_index 63 | 64 | # ----------- Process Open ------------- 65 | def check_is_number(string): 66 | """ 67 | Check if the given string a number. 68 | """ 69 | try: 70 | float(string.replace(',', '')) 71 | return True 72 | except ValueError: 73 | # check if there's comma inside 74 | return False 75 | 76 | def normalize_str(string): 77 | """ 78 | Normalize the str to lower case and make them float numbers if possible. 79 | """ 80 | # check if characters in the string 81 | 82 | # if number, numerize it. 83 | string = string.strip() 84 | 85 | is_number = check_is_number(string) 86 | 87 | if is_number: 88 | string = string.replace(',', '') 89 | string = float(string) 90 | # leave 2 decimal 91 | string = round(string, 2) 92 | return [string] 93 | else: # it's likely to be a string 94 | # lower it 95 | string = string.lower() 96 | if len(string) == 1: 97 | return [" " + string, string + " "] # avoid trivial matches 98 | return [string] 99 | 100 | def extract_numbers(string): 101 | """ 102 | Exact all forms of numbers from a string with regex. 
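Comma-grouped integers (e.g. 1,234), scientific notation (e.g. 1.5e3) and plain
ints/decimals are collected by three separate patterns and concatenated in that order;
duplicates are only removed later, in parse_open_response.
Example (illustrative):
    extract_numbers("speeds of 3.5 m/s and 350 cm/s")
    # -> ['3.5', '350']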
103 | """ 104 | # Pattern for numbers with commas 105 | pattern_commas = r'-?\b\d{1,3}(?:,\d{3})+\b' 106 | # Pattern for scientific notation 107 | pattern_scientific = r'-?\d+(?:\.\d+)?[eE][+-]?\d+' 108 | # Pattern for simple numbers without commas 109 | pattern_simple = r'-?(?:\d+\.\d+|\.\d+|\d+\b)(?![eE][+-]?\d+)(?![,\d])' 110 | 111 | # Extract numbers with commas 112 | numbers_with_commas = re.findall(pattern_commas, string) 113 | # Extract numbers in scientific notation 114 | numbers_scientific = re.findall(pattern_scientific, string) 115 | # Extract simple numbers without commas 116 | numbers_simple = re.findall(pattern_simple, string) 117 | 118 | # Combine all extracted numbers 119 | all_numbers = numbers_with_commas + numbers_scientific + numbers_simple 120 | return all_numbers 121 | 122 | def parse_open_response(response): 123 | """ 124 | Parse the prediction from the generated response. 125 | Return a list of predicted strings or numbers. 126 | """ 127 | # content = content.strip("\n").strip(".").strip(" ") 128 | def get_key_subresponses(response): 129 | key_responses = [] 130 | response = response.strip().strip(".").lower() 131 | sub_responses = re.split(r'\.\s(?=[A-Z])|\n', response) 132 | indicators_of_keys = ['could be ', 'so ', 'is ', 133 | 'thus ', 'therefore ', 'final ', 'answer ', 'result '] 134 | key_responses = [] 135 | for index, resp in enumerate(sub_responses): 136 | # if last one, accept it's an equation (the entire response can be just one sentence with equation) 137 | if index == len(sub_responses) - 1: 138 | indicators_of_keys.extend(['=']) 139 | shortest_key_response = None # the shortest response that may contain the answer (tail part of the response) 140 | for indicator in indicators_of_keys: 141 | if indicator in resp: 142 | if not shortest_key_response: 143 | shortest_key_response = resp.split(indicator)[-1].strip() 144 | else: 145 | if len(resp.split(indicator)[-1].strip()) < len(shortest_key_response): 146 | shortest_key_response = resp.split(indicator)[-1].strip() 147 | # key_responses.append(resp.split(indicator)[1].strip()) 148 | 149 | if shortest_key_response: 150 | # and it's not trivial 151 | if shortest_key_response.strip() not in [":", ",", ".", "!", "?", ";", ":", "'"]: 152 | key_responses.append(shortest_key_response) 153 | if len(key_responses) == 0: # did not found any 154 | return [response] 155 | return key_responses 156 | # pdb.set_trace() 157 | key_responses = get_key_subresponses(response) 158 | 159 | pred_list = key_responses.copy() # keep the original string response 160 | for resp in key_responses: 161 | pred_list.extend(extract_numbers(resp)) 162 | 163 | tmp_pred_list = [] 164 | for i in range(len(pred_list)): 165 | tmp_pred_list.extend(normalize_str(pred_list[i])) 166 | pred_list = tmp_pred_list 167 | 168 | # remove duplicates 169 | pred_list = list(set(pred_list)) 170 | 171 | return pred_list 172 | 173 | # ----------- Evaluation ------------- 174 | 175 | def eval_multi_choice(gold_i, pred_i): 176 | """ 177 | Evaluate a multiple choice instance. 
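gold_i may be a single choice letter or a list of acceptable letters; the prediction is
counted as correct only on an exact match.
Example (illustrative): eval_multi_choice('B', 'B') -> True; eval_multi_choice(['A', 'C'], 'B') -> False.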
178 | """ 179 | correct = False 180 | # only they are exactly the same, we consider it as correct 181 | if isinstance(gold_i, list): 182 | for answer in gold_i: 183 | if answer == pred_i: 184 | correct = True 185 | break 186 | else: # gold_i is a string 187 | if gold_i == pred_i: 188 | correct = True 189 | return correct 190 | 191 | def eval_open(gold_i, pred_i): 192 | """ 193 | Evaluate an open question instance 194 | """ 195 | correct = False 196 | if isinstance(gold_i, list): 197 | # use float to avoid trivial matches 198 | norm_answers = [] 199 | for answer in gold_i: 200 | norm_answers.extend(normalize_str(answer)) 201 | else: 202 | norm_answers = normalize_str(gold_i) 203 | for pred in pred_i: # pred is already normalized in parse response phase 204 | if isinstance(pred, str): # if it's a string, then find if ans in the pred_i 205 | for norm_ans in norm_answers: 206 | # only see if the string answer in the string pred 207 | if isinstance(norm_ans, str) and norm_ans in pred: 208 | if not correct: 209 | correct = True 210 | break 211 | else: # it's a float number 212 | if pred in norm_answers: 213 | if not correct: 214 | correct = True 215 | break 216 | return correct 217 | 218 | # ----------- Batch Evaluation ------------- 219 | def evaluate(samples): 220 | """ 221 | Batch evaluation for multiple choice and open questions. 222 | """ 223 | pred_correct = 0 224 | judge_dict = dict() 225 | for sample in samples: 226 | gold_i = sample['answer'] 227 | pred_i = sample['parsed_pred'] 228 | if sample['question_type'] == 'multiple-choice': 229 | correct = eval_multi_choice(gold_i, pred_i) 230 | else: # open question 231 | correct = eval_open(gold_i, pred_i) 232 | 233 | if correct: 234 | judge_dict[sample['id']] = 'Correct' 235 | pred_correct += 1 236 | else: 237 | judge_dict[sample['id']] = 'Wrong' 238 | 239 | if len(samples) == 0: 240 | return {'acc': 0} 241 | return judge_dict, {'acc': pred_correct / len(samples)} 242 | 243 | 244 | 245 | # ----------- Calculate Accuracy ------------- 246 | def calculate_ins_level_acc(results: Dict): 247 | """Calculate the instruction level accuracy for given Subject results""" 248 | acc = 0 249 | ins_num = 0 250 | for cat_results in results.values(): 251 | acc += cat_results['acc'] * cat_results['num_example'] 252 | ins_num += cat_results['num_example'] 253 | if ins_num == 0: 254 | return 0 255 | return acc / ins_num 256 | 257 | -------------------------------------------------------------------------------- /vlmeval/evaluate/mmmu_eval/utils/model_utils.py: -------------------------------------------------------------------------------- 1 | from random import random 2 | import torch 3 | 4 | def call_llava_engine_df(args, sample, model, tokenizer=None, processor=None): 5 | from llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN 6 | from llava.conversation import conv_templates, SeparatorStyle 7 | 8 | def tokenizer_image_token(prompt, tokenizer, image_token_index=IMAGE_TOKEN_INDEX, return_tensors=None): 9 | prompt_chunks = [tokenizer(chunk).input_ids for chunk in prompt.split('')] 10 | 11 | def insert_separator(X, sep): 12 | return [ele for sublist in zip(X, [sep] * len(X)) for ele in sublist][:-1] 13 | 14 | input_ids = [] 15 | offset = 0 16 | if len(prompt_chunks) > 0 and len(prompt_chunks[0]) > 0 and prompt_chunks[0][0] == tokenizer.bos_token_id: 17 | offset = 1 18 | input_ids.append(prompt_chunks[0][0]) 19 | 20 | for x in insert_separator(prompt_chunks, [image_token_index] * (offset + 1)): 21 | 
input_ids.extend(x[offset:]) 22 | 23 | if return_tensors is not None: 24 | if return_tensors == 'pt': 25 | return torch.tensor(input_ids, dtype=torch.long) 26 | raise ValueError(f'Unsupported tensor type: {return_tensors}') 27 | return input_ids 28 | 29 | def deal_with_prompt(input_text, mm_use_im_start_end): 30 | qs = input_text 31 | if mm_use_im_start_end: 32 | qs = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN + '\n' + qs 33 | else: 34 | qs = DEFAULT_IMAGE_TOKEN + '\n' + qs 35 | return qs 36 | 37 | prompt = sample['final_input_prompt'] 38 | prompt = deal_with_prompt(prompt, model.config.mm_use_im_start_end) 39 | conv = conv_templates['vicuna_v1'].copy() 40 | conv.append_message(conv.roles[0], prompt) 41 | conv.append_message(conv.roles[1], None) 42 | prompt = conv.get_prompt() 43 | input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).cuda() 44 | image = sample['image'] 45 | if image is not None: 46 | output_ids = model.generate( 47 | input_ids, 48 | images=image.unsqueeze(0).half().cuda(), 49 | do_sample=True, 50 | temperature=1, 51 | top_p=None, 52 | num_beams=5, 53 | max_new_tokens=128, 54 | use_cache=True) 55 | 56 | input_token_len = input_ids.shape[1] 57 | n_diff_input_output = (input_ids != output_ids[:, :input_token_len]).sum().item() 58 | if n_diff_input_output > 0: 59 | print(f'[Warning] {n_diff_input_output} output_ids are not the same as the input_ids') 60 | response = tokenizer.batch_decode(output_ids[:, input_token_len:], skip_special_tokens=True)[0] 61 | else: # multiple images actually 62 | if sample['question_type'] == 'multiple-choice': 63 | all_choices = sample['all_choices'] 64 | response = random.choice(all_choices) 65 | else: 66 | response = 'INVALID GENERATION FOR MULTIPLE IMAGE INPUTS' 67 | 68 | return response 69 | 70 | 71 | def llava_image_processor(raw_image, vis_processors=None): 72 | image_tensor = vis_processors.preprocess(raw_image, return_tensors='pt')['pixel_values'][0] 73 | return image_tensor 74 | -------------------------------------------------------------------------------- /vlmeval/evaluate/mmvet_eval.py: -------------------------------------------------------------------------------- 1 | from vlmeval.evaluate.misc import build_judge 2 | from vlmeval.smp import * 3 | from vlmeval.utils import track_progress_rich 4 | 5 | def build_mmvet_gpt4_prompt(line): 6 | question = line['question'] 7 | gt = str(line['answer']) 8 | prediction = str(line['prediction']) 9 | prompt = """Compare the ground truth and prediction from AI models, to give a correctness score for the prediction. in the ground truth means it is totally right only when all elements in the ground truth are present in the prediction, and means it is totally right when any one element in the ground truth is present in the prediction. The correctness score is 0.0 (totally wrong), 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, or 1.0 (totally right). Just complete the last space of the correctness score. 10 | 11 | Question | Ground truth | Prediction | Correctness 12 | --- | --- | --- | --- 13 | What is x in the equation? | -1 -5 | x = 3 | 0.0 14 | What is x in the equation? | -1 -5 | x = -1 | 0.5 15 | What is x in the equation? | -1 -5 | x = -5 | 0.5 16 | What is x in the equation? | -1 -5 | x = -5 or 5 | 0.5 17 | What is x in the equation? | -1 -5 | x = -1 or x = -5 | 1.0 18 | Can you explain this meme? | This meme is poking fun at the fact that the names of the countries Iceland and Greenland are misleading. 
Despite its name, Iceland is known for its beautiful green landscapes, while Greenland is mostly covered in ice and snow. The meme is saying that the person has trust issues because the names of these countries do not accurately represent their landscapes. | The meme talks about Iceland and Greenland. It's pointing out that despite their names, Iceland is not very icy and Greenland isn't very green. | 0.4 19 | Can you explain this meme? | This meme is poking fun at the fact that the names of the countries Iceland and Greenland are misleading. Despite its name, Iceland is known for its beautiful green landscapes, while Greenland is mostly covered in ice and snow. The meme is saying that the person has trust issues because the names of these countries do not accurately represent their landscapes. | The meme is using humor to point out the misleading nature of Iceland's and Greenland's names. Iceland, despite its name, has lush green landscapes while Greenland is mostly covered in ice and snow. The text 'This is why I have trust issues' is a playful way to suggest that these contradictions can lead to distrust or confusion. The humor in this meme is derived from the unexpected contrast between the names of the countries and their actual physical characteristics. | 1.0 20 | """ 21 | gpt4_prompt = prompt + '\n' + ' | '.join([question, gt.replace("", " ").replace("", " "), prediction, ""]) 22 | return gpt4_prompt 23 | 24 | def MMVet_auxeval(model, line): 25 | def float_cvt(s): 26 | try: 27 | return float(s) 28 | except ValueError: 29 | return None 30 | 31 | prompt = build_mmvet_gpt4_prompt(line) 32 | log = '' 33 | retry = 5 34 | for i in range(retry): 35 | output = model.generate(prompt, temperature=i * 0.5) 36 | score = float_cvt(output) 37 | if score is None: 38 | log += f'Try {i}: output is {output}, failed to parse.\n' 39 | elif score < 0 or score > 1: 40 | log += f'Try {i}: output is {output}, invalid score: {score}.\n' 41 | else: 42 | log += 'Succeed' 43 | return dict(log=log, score=score) 44 | log += 'All 5 retries failed.\n' 45 | return dict(log=log, score=0.0) 46 | 47 | def MMVet_acc(result_file): 48 | data = load(result_file) 49 | tot = defaultdict(lambda: 0) 50 | score = defaultdict(lambda: 0) 51 | lt = len(data) 52 | cate2_list = [] 53 | for i in range(lt): 54 | item = data.iloc[i] 55 | cate = item['category'] 56 | cate2 = cate.replace(',','_') 57 | if cate2 not in cate2_list: 58 | cate2_list.append(cate2) 59 | grade = float(item['score']) 60 | cate_list = ['rec','ocr','know','gen','spat','math'] 61 | for capa in cate_list: 62 | if capa in cate: 63 | tot[capa] += 1 64 | score[capa] += grade 65 | tot['Overall'] += 1 66 | tot[cate2] += 1 67 | score['Overall'] += grade 68 | score[cate2] += grade 69 | 70 | res = defaultdict(list) 71 | res2 = defaultdict(list) 72 | cate_list.append('Overall') 73 | cate2_list.append('Overall') 74 | for k in cate_list: 75 | res['Category'].append(k) 76 | res['tot'].append(tot[k]) 77 | res['acc'].append(score[k] / tot[k] * 100) 78 | for v in cate2_list: 79 | res2['Category'].append(v) 80 | res2['tot'].append(tot[v]) 81 | res2['acc'].append(score[v] / tot[v] * 100) 82 | res = pd.DataFrame(res) 83 | res2 = pd.DataFrame(res2) 84 | return res, res2 85 | 86 | def MMVet_eval(eval_file, model='gpt-4-turbo', nproc=4, verbose=False): 87 | logger = get_logger('Evaluation') 88 | 89 | suffix = eval_file.split('.')[-1] 90 | storage = eval_file.replace(f'.{suffix}', f'_{model}.xlsx') 91 | tmp_file = eval_file.replace(f'.{suffix}', f'_{model}.pkl') 92 | if 
osp.exists(storage): 93 | logger.warning(f"GPT scoring file {storage} already exists, will reuse it in MMVet_eval. ") 94 | else: 95 | data = load(eval_file) 96 | gpt_version = model 97 | model = build_judge(gpt_version, verbose=verbose, max_tokens=3, retry=10) 98 | 99 | lt = len(data) 100 | lines = [data.iloc[i] for i in range(lt)] 101 | tups = [(model, line) for line in lines] 102 | indices = [line['index'] for line in lines] 103 | 104 | ans = {} 105 | if osp.exists(tmp_file): 106 | ans = load(tmp_file) 107 | tups = [x for x, i in zip(tups, indices) if i not in ans] 108 | indices = [i for i in indices if i not in ans] 109 | 110 | if len(indices): 111 | new_results = track_progress_rich( 112 | MMVet_auxeval, tups, nproc=nproc, chunksize=nproc, 113 | keys=indices, save=tmp_file) 114 | ans = load(tmp_file) 115 | for k, v in zip(indices, new_results): 116 | assert k in ans 117 | assert ans[k]['log'] == v['log'] and ans[k]['score'] == v['score'] 118 | 119 | log_map, score_map = {}, {} 120 | all_inds = [line['index'] for line in lines] 121 | for k in all_inds: 122 | log_map[k] = ans[k]['log'] 123 | score_map[k] = ans[k]['score'] 124 | data['score'] = [score_map[idx] for idx in data['index']] 125 | data['log'] = [log_map[idx] for idx in data['index']] 126 | dump(data, storage) 127 | 128 | score, score_fine = MMVet_acc(storage) 129 | score_pth = storage.replace('.xlsx', '_score.csv') 130 | score_fine_pth = storage.replace('.xlsx', '_score_fine.csv') 131 | 132 | dump(score, score_pth) 133 | dump(score_fine, score_fine_pth) 134 | logger.info(f'MMVet_eval successfully finished evaluating {eval_file}, results saved in {score_pth} and {score_fine_pth}') 135 | logger.info(f'Score: ') 136 | logger.info(score) 137 | 138 | def parse_args(): 139 | parser = argparse.ArgumentParser(description="Inference LLM Answers. ") 140 | parser.add_argument("data", type=str, help="The question set for inference, in excel / tsv / json format. ") 141 | parser.add_argument( 142 | "--model", 143 | type=str, 144 | help="The LLM (GPT) used for inference. 
", 145 | default="gpt-4-turbo", 146 | choices=['gpt-4-0613', 'gpt-4-turbo', 'chatgpt-1106', 'chatgpt-0613']) 147 | parser.add_argument("--nproc", type=int, default=4) 148 | parser.add_argument("--verbose", action='store_true') 149 | args = parser.parse_args() 150 | return args 151 | 152 | if __name__ == '__main__': 153 | args = parse_args() 154 | MMVet_eval(eval_file=args.data, model=args.model, nproc=args.nproc, verbose=args.verbose) 155 | -------------------------------------------------------------------------------- /vlmeval/evaluate/yes_or_no.py: -------------------------------------------------------------------------------- 1 | from vlmeval.evaluate.misc import build_judge 2 | from vlmeval.smp import * 3 | from vlmeval.utils import track_progress_rich 4 | 5 | INTERNAL = os.environ.get('INTERNAL', 0) 6 | 7 | def MME_rating(data_file): 8 | data = load(data_file) 9 | stats = defaultdict(dict) 10 | lt = len(data) 11 | for i in range(lt): 12 | item = data.iloc[i] 13 | category = item['category'] 14 | image_path = item['image_path'] 15 | score = item['score'] 16 | if image_path not in stats[category]: 17 | stats[category][image_path] = [] 18 | stats[category][image_path].append(score) 19 | 20 | def acc(key, mode='normal'): 21 | res = stats[key] 22 | values = [] 23 | for val in res.values(): 24 | if mode == 'normal': 25 | values.extend(val) 26 | elif mode == 'plus': 27 | values.append(val[0] * val[1]) 28 | return np.mean(values) * 100 29 | 30 | scores = {} 31 | for k in stats: 32 | scores[k] = acc(k) + acc(k, 'plus') 33 | 34 | super_cates = dict( 35 | perception=['OCR', 'artwork', 'celebrity', 'color', 'count', 'existence', 'landmark', 'position', 'posters', 'scene'], 36 | reasoning=['code_reasoning', 'commonsense_reasoning', 'numerical_calculation', 'text_translation'] 37 | ) 38 | 39 | ret = {} 40 | for sc, cate_list in super_cates.items(): 41 | base = 0 42 | for c in cate_list: 43 | base += scores[c] 44 | ret[sc] = base 45 | ret.update(scores) 46 | ret = d2df(ret) 47 | return ret 48 | 49 | def Hallusion_rating(data_file): 50 | def calc_fAcc(data): 51 | res = defaultdict(list) 52 | lt = len(data) 53 | for i in range(lt): 54 | line = data.iloc[i] 55 | res[f"{line['l2-category']}_{line['set_id']}_{line['figure_id']}"].append(line['score']) 56 | return np.mean([np.all(x) for x in res.values()]) * 100 57 | 58 | def calc_qAcc(data): 59 | res = defaultdict(list) 60 | lt = len(data) 61 | for i in range(lt): 62 | line = data.iloc[i] 63 | res[f"{line['l2-category']}_{line['set_id']}_{line['question_id']}"].append(line['score']) 64 | return np.mean([np.all(x) for x in res.values()]) * 100 65 | 66 | def calc_aAcc(data): 67 | return np.mean(data['score']) * 100 68 | 69 | data = load(data_file) 70 | data['set_id'] = [x.split('_')[3] for x in data['index']] 71 | data['figure_id'] = [x.split('_')[4] for x in data['index']] 72 | data['question_id'] = [x.split('_')[5] for x in data['index']] 73 | 74 | res = dict(split=[], aAcc=[], fAcc=[], qAcc=[]) 75 | res['split'].append('Overall') 76 | res['aAcc'].append(calc_aAcc(data)) 77 | res['fAcc'].append(calc_fAcc(data)) 78 | res['qAcc'].append(calc_qAcc(data)) 79 | 80 | if 'category' in data: 81 | cates = list(set(data['category'])) 82 | for c in cates: 83 | sub = data[data['category'] == c] 84 | res['split'].append(c) 85 | res['aAcc'].append(calc_aAcc(sub)) 86 | res['fAcc'].append(calc_fAcc(sub)) 87 | res['qAcc'].append(calc_qAcc(sub)) 88 | 89 | if 'l2-category' in data: 90 | cates = list(set(data['l2-category'])) 91 | for c in cates: 92 | sub = 
data[data['l2-category'] == c] 93 | res['split'].append(c) 94 | res['aAcc'].append(calc_aAcc(sub)) 95 | res['fAcc'].append(calc_fAcc(sub)) 96 | res['qAcc'].append(calc_qAcc(sub)) 97 | ret = pd.DataFrame(res) 98 | return ret 99 | 100 | def default_rating(data_file): 101 | data = load(data_file) 102 | res = {} 103 | res['Overall'] = np.mean(data['score']) * 100 104 | if 'category' in data: 105 | cates = list(set(data['category'])) 106 | cates = [c for c in cates if not pd.isna(c)] 107 | cates.sort() 108 | for c in cates: 109 | sub = data[data['category'] == c] 110 | res[c] = np.mean(sub['score']) * 100 111 | if 'l2-category' in data: 112 | cates = list(set(data['l2-category'])) 113 | cates = [c for c in cates if not pd.isna(c)] 114 | cates.sort() 115 | for c in cates: 116 | sub = data[data['l2-category'] == c] 117 | res[c] = np.mean(sub['score']) * 100 118 | ret = d2df(res) 119 | return ret 120 | 121 | def YOrN_match_prompt(line): 122 | tmpl = ( 123 | "You are an AI assistant who will help me to match an answer with two options of a question. " 124 | "The options are only Yes / No. " 125 | "You are provided with a question and an answer, and you need to find which option (Yes / No) is most similar to the answer. " 126 | "If the meaning of all options are significantly different from the answer, output Unknown. "\ 127 | "Your should output a single word among the following 3 choices: Yes, No, Unknown.\n" 128 | "Example 1: \n" 129 | "Question: Is the word in this image 'Hello'?\nAnswer: The word in this image is 'Hello'.\nYour output: Yes\n" 130 | "Example 2: \n" 131 | "Question: Is the word in this image 'Hello'?\nAnswer: The word in this image is not 'Hello'.\nYour output: No\n" 132 | "Example 3: \n" 133 | "Question: {}?\nAnswer: {}\nYour output: " 134 | ) 135 | return tmpl.format(line['question'], line['prediction']) 136 | 137 | def YOrN_Extraction(output): 138 | s = output.lower() 139 | words = process_punctuation(s).split() 140 | if 'yes' in words and 'no' not in words: 141 | return 'Yes' 142 | if 'yes' not in words and 'no' in words: 143 | return 'No' 144 | return 'Unknown' 145 | 146 | def YOrN_auxeval(model, line): 147 | prompt = YOrN_match_prompt(line) 148 | retry = 5 149 | for i in range(retry): 150 | output = model.generate(prompt, temperature=0.5 * i) 151 | ans = YOrN_Extraction(output) 152 | if ans != 'Unknown': 153 | return ans 154 | return 'Unknown' 155 | 156 | def YOrN_eval(eval_file, model='chatgpt-0613', nproc=4, verbose=False, dataset=None): 157 | logger = get_logger('Evaluation') 158 | data = load(eval_file) 159 | data['prediction'] = [str(x) for x in data['prediction']] 160 | storage = eval_file.replace('.xlsx', '_auxmatch.xlsx') 161 | tmp_file = eval_file.replace('.xlsx', '_tmp.pkl') 162 | 163 | if not osp.exists(storage): 164 | ans_map = {k: YOrN_Extraction(v) for k, v in zip(data['index'], data['prediction'])} 165 | if osp.exists(tmp_file): 166 | tmp = load(tmp_file) 167 | for k in tmp: 168 | if ans_map[k] == 'Unknown' and tmp[k] != 'Unknown': 169 | ans_map[k] = tmp[k] 170 | 171 | data['extracted'] = [ans_map[x] for x in data['index']] 172 | unknown = data[data['extracted'] == 'Unknown'] 173 | 174 | model_name = 'chatgpt-0613' 175 | 176 | if INTERNAL or gpt_key_set(): 177 | model = build_judge(model_name, verbose=verbose, retry=10) 178 | else: 179 | logger.error('OPENAI_API_KEY is not set properly, will use exact matching for evaluation') 180 | model = None 181 | 182 | if model is not None: 183 | lt = len(unknown) 184 | lines = [unknown.iloc[i] for i in range(lt)] 185 | 
tups = [(model, line) for line in lines] 186 | indices = list(unknown['index']) 187 | if len(tups): 188 | res = track_progress_rich(YOrN_auxeval, tups, nproc=nproc, chunksize=nproc, keys=indices, save=tmp_file) 189 | for k, v in zip(indices, res): 190 | ans_map[k] = v 191 | 192 | data['extracted'] = [ans_map[x] for x in data['index']] 193 | dump(data, storage) 194 | else: 195 | logger.warning(f"GPT matching file {storage} already exists, will reuse it in YOrN_eval. ") 196 | 197 | data = load(storage) 198 | data["score"] = (data["answer"] == data["extracted"]) 199 | dump(data, storage) 200 | 201 | if dataset is not None and listinstr(['MME'], dataset): 202 | score = MME_rating(storage) 203 | elif dataset is not None and listinstr(['Hallusion'], dataset): 204 | score = Hallusion_rating(storage) 205 | else: 206 | score = default_rating(storage) 207 | 208 | score_tgt = eval_file.replace('.xlsx', '_score.csv') 209 | dump(score, score_tgt) 210 | 211 | logger.info(f'YOrN_eval successfully finished evaluating {eval_file}, results saved in {score_tgt}') 212 | logger.info('Score: ') 213 | logger.info(score) 214 | return score 215 | 216 | def parse_args(): 217 | parser = argparse.ArgumentParser(description="Inference LLM Answers. ") 218 | parser.add_argument("data", type=str, help="The question set for inference, in excel / tsv / json format. ") 219 | parser.add_argument("--model", type=str, help="The LLM (GPT) used for inference. ", default="chatgpt-0613", choices=['chatgpt-0613']) 220 | parser.add_argument("--nproc", type=int, default=4) 221 | parser.add_argument("--dataset", type=str, default=None) 222 | parser.add_argument("--verbose", action='store_true') 223 | args = parser.parse_args() 224 | return args 225 | 226 | if __name__ == '__main__': 227 | args = parse_args() 228 | acc = YOrN_eval(eval_file=args.data, model=args.model, nproc=args.nproc, verbose=args.verbose, dataset=args.dataset) 229 | -------------------------------------------------------------------------------- /vlmeval/inference.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.distributed as dist 3 | import datetime 4 | from vlmeval.config import supported_VLM 5 | from vlmeval.utils import TSVDataset, track_progress_rich, split_MMMU 6 | from vlmeval.smp import * 7 | 8 | FAIL_MSG = 'Failed to obtain answer via API.' 
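# FAIL_MSG is used as a sentinel further down: when a partially finished *_supp.pkl is
# reloaded in infer_data_api, any cached answer containing this string is dropped so the
# sample is re-queried instead of being treated as done. A minimal sketch of that
# resume-and-filter pattern (names below are illustrative, not part of this module):
#
#     cached = load(out_file) if osp.exists(out_file) else {}
#     cached = {k: v for k, v in cached.items() if FAIL_MSG not in v}
#     pending = [i for i in indices if i not in cached]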
9 | 10 | def parse_args(): 11 | parser = argparse.ArgumentParser() 12 | parser.add_argument('--data', type=str, nargs='+', required=True) 13 | parser.add_argument("--model", type=str, nargs='+', required=True) 14 | parser.add_argument("--nproc", type=int, default=4, required=True) 15 | parser.add_argument("--verbose", action='store_true') 16 | args = parser.parse_args() 17 | return args 18 | 19 | # Only API model is accepted 20 | def infer_data_api(work_dir, model_name, dataset_name, index_set, api_nproc=4): 21 | rank, world_size = get_rank_and_world_size() 22 | assert rank == 0 and world_size == 1 23 | dataset = TSVDataset(dataset_name) 24 | data = dataset.data 25 | data = data[data['index'].isin(index_set)] 26 | 27 | model = supported_VLM[model_name]() if isinstance(model_name, str) else model_name 28 | is_api = getattr(model, 'is_api', False) 29 | assert is_api 30 | 31 | lt, indices = len(data), list(data['index']) 32 | structs = [dataset.build_prompt(data.iloc[i]) for i in range(lt)] 33 | 34 | out_file = f'{work_dir}/{model_name}_{dataset_name}_supp.pkl' 35 | res = {} 36 | if osp.exists(out_file): 37 | res = load(out_file) 38 | res = {k: v for k, v in res.items() if FAIL_MSG not in v} 39 | 40 | structs = [s for i, s in zip(indices, structs) if i not in res] 41 | indices = [i for i in indices if i not in res] 42 | 43 | gen_func = None 44 | if listinstr(['MMMU'], dataset_name): 45 | assert hasattr(model, 'interleave_generate') 46 | gen_func = model.interleave_generate 47 | structs = [dict(ti_list=split_MMMU(struct), dataset=dataset_name) for struct in structs] 48 | elif listinstr(['CORE_MM'], dataset_name): 49 | assert hasattr(model, 'multi_generate') 50 | gen_func = model.multi_generate 51 | structs = [dict(image_paths=struct['image'], prompt=struct['text'], dataset=dataset_name) for struct in structs] 52 | else: 53 | gen_func = model.generate 54 | structs = [dict(image_path=struct['image'], prompt=struct['text'], dataset=dataset_name) for struct in structs] 55 | 56 | inference_results = track_progress_rich( 57 | gen_func, structs, nproc=api_nproc, chunksize=api_nproc, save=out_file, keys=indices) 58 | 59 | res = load(out_file) 60 | for idx, text in zip(indices, inference_results): 61 | assert (res[idx] == text if idx in res else True) 62 | res[idx] = text 63 | return res 64 | 65 | def infer_data(model_name, dataset_name, out_file, verbose=False, api_nproc=4): 66 | res = {} 67 | if osp.exists(out_file): 68 | res = load(out_file) 69 | 70 | rank, world_size = get_rank_and_world_size() 71 | if rank == 0: 72 | dataset = TSVDataset(dataset_name) 73 | if world_size > 1: 74 | dist.barrier() 75 | dataset = TSVDataset(dataset_name) 76 | 77 | indices = list(range(rank, len(dataset), world_size)) 78 | lt = len(indices) 79 | data = dataset.data.iloc[indices] 80 | 81 | # If finished, will exit without building the model 82 | all_finished = True 83 | for i in range(lt): 84 | idx = data.iloc[i]['index'] 85 | if idx not in res: 86 | all_finished = False 87 | if all_finished: 88 | return 89 | data = data[~data['index'].isin(res)] 90 | lt = len(data) 91 | 92 | model = supported_VLM[model_name]() if isinstance(model_name, str) else model_name 93 | 94 | is_api = getattr(model, 'is_api', False) 95 | if is_api: 96 | assert world_size == 1 97 | lt, indices = len(data), list(data['index']) 98 | supp = infer_data_api(model_name=model_name, dataset_name=dataset_name, index_set=set(indices), api_nproc=api_nproc) 99 | for idx in indices: 100 | assert idx in supp 101 | res.update(supp) 102 | dump(res, out_file) 103 
| return model_name 104 | 105 | for i in tqdm(range(lt)): 106 | idx = data.iloc[i]['index'] 107 | if idx in res: 108 | continue 109 | 110 | if hasattr(model, 'use_custom_prompt') and model.use_custom_prompt(dataset_name): 111 | struct = model.build_prompt(data.iloc[i], dataset=dataset_name) 112 | else: 113 | struct = dataset.build_prompt(data.iloc[i]) 114 | 115 | if dataset_name in ['CORE_MM']: 116 | assert hasattr(model, 'multi_generate') 117 | response = model.multi_generate(prompt=struct['text'], image_paths=struct['image'], dataset=dataset_name) 118 | elif listinstr(['MMMU'], dataset_name): 119 | if hasattr(model, 'interleave_generate'): 120 | response = model.interleave_generate(ti_list=split_MMMU(struct), dataset=dataset_name) 121 | elif len(struct['image']) >= 1: 122 | response = model.generate(prompt=struct['text'], image_path=struct['image'], dataset=dataset_name, qtype=struct['qtype']) 123 | else: 124 | response = model.generate(prompt=struct['text'], image_path=struct['image'], dataset=dataset_name) 125 | torch.cuda.empty_cache() 126 | 127 | if verbose: 128 | print(response, flush=True) 129 | 130 | res[idx] = response 131 | if (i + 1) % 20 == 0: 132 | dump(res, out_file) 133 | 134 | dump(res, out_file) 135 | return model 136 | 137 | def prefetch_acc(result_file): 138 | data = load(result_file) 139 | from vlmeval.evaluate.multiple_choice import build_choices, can_infer 140 | tot = defaultdict(lambda: 0) 141 | match = defaultdict(lambda: 0) 142 | hit = defaultdict(lambda: 0) 143 | lt = len(data) 144 | for i in range(lt): 145 | item = data.iloc[i] 146 | cate = item['category'] 147 | tot['Overall'] += 1 148 | tot[cate] += 1 149 | choices = build_choices(item) 150 | matched = can_infer(item['prediction'], choices) 151 | if matched: 152 | match['Overall'] += 1 153 | match[cate] += 1 154 | if matched == item['answer']: 155 | hit['Overall'] += 1 156 | hit[cate] += 1 157 | res = defaultdict(list) 158 | for k in tot.keys(): 159 | res['Category'].append(k) 160 | res['tot'].append(tot[k]) 161 | res['match'].append(match[k]) 162 | res['hit'].append(hit[k]) 163 | res['match_rate'].append(match[k] / tot[k] * 100) 164 | if match[k] == 0: 165 | res['acc'].append(0) 166 | else: 167 | res['acc'].append(hit[k] / tot[k] * 100) 168 | res = pd.DataFrame(res) 169 | return res 170 | 171 | def infer_data_job(model, work_dir, model_name, dataset_name, verbose=False, api_nproc=4, ignore_failed=False): 172 | result_file = osp.join(work_dir, f'{model_name}_{dataset_name}.xlsx') 173 | rank, world_size = get_rank_and_world_size() 174 | tmpl = osp.join(work_dir, '{}' + f'{world_size}_{dataset_name}.pkl') 175 | out_file = tmpl.format(rank) 176 | 177 | if True: #not osp.exists(result_file): 178 | model = infer_data(model, dataset_name=dataset_name, out_file=out_file, verbose=verbose) 179 | if world_size > 1: 180 | dist.barrier() 181 | 182 | if rank == 0: 183 | data_all = {} 184 | for i in range(world_size): 185 | data_all.update(load(tmpl.format(i))) 186 | 187 | data = TSVDataset(dataset_name).data 188 | assert len(data_all) == len(data) 189 | data['prediction'] = [str(data_all[x]) for x in data['index']] 190 | data.pop('image') 191 | 192 | dump(data, result_file) 193 | for i in range(world_size): 194 | os.remove(tmpl.format(i)) 195 | return model 196 | -------------------------------------------------------------------------------- /vlmeval/smp/__init__.py: -------------------------------------------------------------------------------- 1 | from .file import * 2 | from .vlm import * 3 | from .misc import * 4 | 
from .log import * 5 | from .lb import * -------------------------------------------------------------------------------- /vlmeval/smp/file.py: -------------------------------------------------------------------------------- 1 | import json 2 | import pickle 3 | import pandas as pd 4 | import os 5 | import csv 6 | import hashlib 7 | import os.path as osp 8 | import time 9 | import numpy as np 10 | 11 | class NumpyEncoder(json.JSONEncoder): 12 | def default(self, obj): 13 | if isinstance(obj, (np.int_, np.intc, np.intp, np.int8, 14 | np.int16, np.int32, np.int64, np.uint8, 15 | np.uint16, np.uint32, np.uint64)): 16 | return int(obj) 17 | elif isinstance(obj, (np.float_, np.float16, np.float32, np.float64)): 18 | return float(obj) 19 | elif isinstance(obj, (np.complex_, np.complex64, np.complex128)): 20 | return {'real': obj.real, 'imag': obj.imag} 21 | elif isinstance(obj, (np.ndarray,)): 22 | return obj.tolist() 23 | elif isinstance(obj, (np.bool_)): 24 | return bool(obj) 25 | elif isinstance(obj, (np.void)): 26 | return None 27 | return json.JSONEncoder.default(self, obj) 28 | 29 | # LOAD & DUMP 30 | def dump(data, f, **kwargs): 31 | def dump_pkl(data, pth, **kwargs): 32 | pickle.dump(data, open(pth, 'wb')) 33 | 34 | def dump_json(data, pth, **kwargs): 35 | json.dump(data, open(pth, 'w'), indent=4, ensure_ascii=False, cls=NumpyEncoder) 36 | 37 | def dump_jsonl(data, f, **kwargs): 38 | lines = [json.dumps(x, ensure_ascii=False, cls=NumpyEncoder) for x in data] 39 | with open(f, 'w', encoding='utf8') as fout: 40 | fout.write('\n'.join(lines)) 41 | 42 | def dump_xlsx(data, f, **kwargs): 43 | data.to_excel(f, index=False, engine='xlsxwriter') 44 | 45 | def dump_csv(data, f, quoting=csv.QUOTE_ALL): 46 | data.to_csv(f, index=False, encoding='utf-8', quoting=quoting) 47 | 48 | def dump_tsv(data, f, quoting=csv.QUOTE_ALL): 49 | data.to_csv(f, sep='\t', index=False, encoding='utf-8', quoting=quoting) 50 | 51 | handlers = dict(pkl=dump_pkl, json=dump_json, jsonl=dump_jsonl, xlsx=dump_xlsx, csv=dump_csv, tsv=dump_tsv) 52 | suffix = f.split('.')[-1] 53 | return handlers[suffix](data, f, **kwargs) 54 | 55 | def load(f): 56 | def load_pkl(pth): 57 | return pickle.load(open(pth, 'rb')) 58 | 59 | def load_json(pth): 60 | return json.load(open(pth, 'r', encoding='utf-8')) 61 | 62 | def load_jsonl(f): 63 | lines = open(f, encoding='utf-8').readlines() 64 | lines = [x.strip() for x in lines] 65 | if lines[-1] == '': 66 | lines = lines[:-1] 67 | data = [json.loads(x) for x in lines] 68 | return data 69 | 70 | def load_xlsx(f): 71 | return pd.read_excel(f) 72 | 73 | def load_csv(f): 74 | return pd.read_csv(f) 75 | 76 | def load_tsv(f): 77 | return pd.read_csv(f, sep='\t') 78 | 79 | handlers = dict(pkl=load_pkl, json=load_json, jsonl=load_jsonl, xlsx=load_xlsx, csv=load_csv, tsv=load_tsv) 80 | suffix = f.split('.')[-1] 81 | return handlers[suffix](f) 82 | 83 | def download_file(url, filename=None): 84 | import urllib.request 85 | from tqdm import tqdm 86 | 87 | class DownloadProgressBar(tqdm): 88 | def update_to(self, b=1, bsize=1, tsize=None): 89 | if tsize is not None: 90 | self.total = tsize 91 | self.update(b * bsize - self.n) 92 | 93 | if filename is None: 94 | filename = url.split('/')[-1] 95 | 96 | with DownloadProgressBar(unit='B', unit_scale=True, 97 | miniters=1, desc=url.split('/')[-1]) as t: 98 | urllib.request.urlretrieve(url, filename=filename, reporthook=t.update_to) 99 | return filename 100 | 101 | def ls(dirname='.', match='', mode='all', level=1): 102 | if dirname == '.': 103 | ans = 
os.listdir(dirname) 104 | else: 105 | ans = [osp.join(dirname, x) for x in os.listdir(dirname)] 106 | assert mode in ['all', 'dir', 'file'] 107 | assert level >= 1 and isinstance(level, int) 108 | if level == 1: 109 | ans = [x for x in ans if match in x] 110 | if mode == 'dir': 111 | ans = [x for x in ans if osp.isdir(x)] 112 | elif mode == 'file': 113 | ans = [x for x in ans if not osp.isdir(x)] 114 | else: 115 | ans = [x for x in ans if osp.isdir(x)] 116 | res = [] 117 | for d in ans: 118 | res.extend(ls(d, match=match, mode=mode, level=level-1)) 119 | ans = res 120 | return ans 121 | 122 | def mrlines(fname, sp='\n'): 123 | f = open(fname).read().split(sp) 124 | while f != [] and f[-1] == '': 125 | f = f[:-1] 126 | return f 127 | 128 | def mwlines(lines, fname): 129 | with open(fname, 'w') as fout: 130 | fout.write('\n'.join(lines)) 131 | 132 | def md5(file_pth): 133 | with open(file_pth, 'rb') as f: 134 | hash = hashlib.new('md5') 135 | for chunk in iter(lambda: f.read(2**20), b''): 136 | hash.update(chunk) 137 | return str(hash.hexdigest()) 138 | 139 | def last_modified(pth): 140 | stamp = osp.getmtime(pth) 141 | m_ti = time.ctime(stamp) 142 | t_obj = time.strptime(m_ti) 143 | t = time.strftime('%Y%m%d%H%M%S', t_obj)[2:] 144 | return t 145 | -------------------------------------------------------------------------------- /vlmeval/smp/log.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | logger_initialized = {} 4 | 5 | def get_logger(name, log_file=None, log_level=logging.INFO, file_mode='w'): 6 | logger = logging.getLogger(name) 7 | if name in logger_initialized: 8 | return logger 9 | 10 | for logger_name in logger_initialized: 11 | if name.startswith(logger_name): 12 | return logger 13 | 14 | stream_handler = logging.StreamHandler() 15 | handlers = [stream_handler] 16 | 17 | try: 18 | import torch.distributed as dist 19 | if dist.is_available() and dist.is_initialized(): 20 | rank = dist.get_rank() 21 | else: 22 | rank = 0 23 | except ImportError: 24 | rank = 0 25 | 26 | if rank == 0 and log_file is not None: 27 | file_handler = logging.FileHandler(log_file, file_mode) 28 | handlers.append(file_handler) 29 | 30 | formatter = logging.Formatter( 31 | '%(asctime)s - %(name)s - %(levelname)s - %(message)s') 32 | for handler in handlers: 33 | handler.setFormatter(formatter) 34 | handler.setLevel(log_level) 35 | logger.addHandler(handler) 36 | 37 | if rank == 0: 38 | logger.setLevel(log_level) 39 | else: 40 | logger.setLevel(logging.ERROR) 41 | 42 | logger_initialized[name] = True 43 | return logger -------------------------------------------------------------------------------- /vlmeval/smp/misc.py: -------------------------------------------------------------------------------- 1 | # flake8: noqa: F401, F403 2 | import abc 3 | import argparse 4 | import csv 5 | import multiprocessing as mp 6 | import os 7 | import os.path as osp 8 | import copy as cp 9 | import random as rd 10 | import requests 11 | import shutil 12 | import subprocess 13 | import warnings 14 | import pandas as pd 15 | from collections import OrderedDict, defaultdict 16 | from multiprocessing import Pool, current_process 17 | from tqdm import tqdm 18 | import datetime 19 | import matplotlib.pyplot as plt 20 | import seaborn as sns 21 | from tabulate import tabulate_formats, tabulate 22 | from huggingface_hub import scan_cache_dir 23 | from sty import fg, bg, ef, rs 24 | 25 | def process_punctuation(inText): 26 | import re 27 | outText = inText 28 | punct = [ 
29 | ';', r'/', '[', ']', '"', '{', '}', '(', ')', '=', '+', '\\', '_', '-', 30 | '>', '<', '@', '`', ',', '?', '!' 31 | ] 32 | commaStrip = re.compile('(\d)(,)(\d)') # noqa: W605 33 | periodStrip = re.compile('(?!<=\d)(\.)(?!\d)') # noqa: W605 34 | for p in punct: 35 | if (p + ' ' in inText or ' ' + p in inText) or (re.search( 36 | commaStrip, inText) is not None): 37 | outText = outText.replace(p, '') 38 | else: 39 | outText = outText.replace(p, ' ') 40 | outText = periodStrip.sub('', outText, re.UNICODE) 41 | return outText 42 | 43 | def h2r(value): 44 | if value[0] == '#': 45 | value = value[1:] 46 | assert len(value) == 6 47 | return tuple(int(value[i:i + 2], 16) for i in range(0, 6, 2)) 48 | 49 | def r2h(rgb): 50 | return '#%02x%02x%02x' % rgb 51 | 52 | def colored(s, color): 53 | if isinstance(color, str): 54 | color = h2r(color) 55 | return fg(*color) + s + fg.rs 56 | 57 | def istype(s, type): 58 | if isinstance(s, type): 59 | return True 60 | try: 61 | return isinstance(eval(s), type) 62 | except Exception as _: 63 | return False 64 | 65 | def bincount(lst): 66 | bins = defaultdict(lambda: 0) 67 | for item in lst: 68 | bins[item] += 1 69 | return bins 70 | 71 | def get_cache_path(repo_id): 72 | hf_cache_info = scan_cache_dir() 73 | repos = list(hf_cache_info.repos) 74 | repo = None 75 | for r in repos: 76 | if r.repo_id == repo_id: 77 | repo = r 78 | break 79 | if repo is None: 80 | return None 81 | revs = list(repo.revisions) 82 | rev2keep, last_modified = None, 0 83 | for rev in revs: 84 | if rev.last_modified > last_modified: 85 | rev2keep, last_modified = rev, rev.last_modified 86 | if rev2keep is None: 87 | return None 88 | return str(rev2keep.snapshot_path) 89 | 90 | def proxy_set(s): 91 | import os 92 | for key in ['http_proxy', 'HTTP_PROXY', 'https_proxy', 'HTTPS_PROXY']: 93 | os.environ[key] = s 94 | 95 | def get_rank_and_world_size(): 96 | local_rank = int(os.environ.get("LOCAL_RANK", 0)) 97 | world_size = int(os.environ.get("WORLD_SIZE", 1)) 98 | return local_rank, world_size 99 | 100 | def splitlen(s, sym='/'): 101 | return len(s.split(sym)) 102 | 103 | def listinstr(lst, s): 104 | assert isinstance(lst, list) 105 | for item in lst: 106 | if item in s: 107 | return True 108 | return False 109 | 110 | def d2df(D): 111 | return pd.DataFrame({x: [D[x]] for x in D}) 112 | 113 | def cn_string(s): 114 | import re 115 | if re.search(u'[\u4e00-\u9fff]', s): 116 | return True 117 | return False 118 | 119 | try: 120 | import decord 121 | except ImportError: 122 | pass 123 | 124 | def timestr(second=True, minute=False): 125 | s = datetime.datetime.now().strftime('%Y%m%d%H%M%S')[2:] 126 | if second: 127 | return s 128 | elif minute: 129 | return s[:-2] 130 | else: 131 | return s[:-4] 132 | 133 | def dict_merge(dct, merge_dct): 134 | for k, _ in merge_dct.items(): 135 | if (k in dct and isinstance(dct[k], dict) and isinstance(merge_dct[k], dict)): #noqa 136 | dict_merge(dct[k], merge_dct[k]) 137 | else: 138 | dct[k] = merge_dct[k] 139 | 140 | def youtube_dl(idx): 141 | cmd = f'youtube-dl -f best -f mp4 "{idx}" -o {idx}.mp4' 142 | os.system(cmd) 143 | 144 | def run_command(cmd): 145 | if isinstance(cmd, str): 146 | cmd = cmd.split() 147 | return subprocess.check_output(cmd) 148 | -------------------------------------------------------------------------------- /vlmeval/smp/vlm.py: -------------------------------------------------------------------------------- 1 | import os, io 2 | import pandas as pd 3 | import numpy as np 4 | import string 5 | from uuid import uuid4 6 | import 
os.path as osp 7 | import base64 8 | from PIL import Image 9 | 10 | def mmqa_display(question): 11 | question = {k.lower(): v for k, v in question.items()} 12 | keys = list(question.keys()) 13 | keys = [k for k in keys if k not in ['index', 'image']] 14 | 15 | images = question['image'] 16 | if isinstance(images, str): 17 | images = [images] 18 | 19 | idx = question.pop('index', 'XXX') 20 | print(f'INDEX: {idx}') 21 | 22 | for im in images: 23 | image = decode_base64_to_image(im, target_size=512) 24 | display(image) 25 | 26 | for k in keys: 27 | try: 28 | if not pd.isna(question[k]): 29 | print(f'{k.upper()}. {question[k]}') 30 | except ValueError: 31 | if False in pd.isna(question[k]): 32 | print(f'{k.upper()}. {question[k]}') 33 | 34 | def encode_image_to_base64(img, target_size=-1): 35 | # if target_size == -1, will not do resizing 36 | # else, will set the max_size ot (target_size, target_size) 37 | if img.mode in ("RGBA", "P"): 38 | img = img.convert("RGB") 39 | tmp = osp.join('/tmp', str(uuid4()) + '.jpg') 40 | if target_size > 0: 41 | img.thumbnail((target_size, target_size)) 42 | img.save(tmp) 43 | with open(tmp, 'rb') as image_file: 44 | image_data = image_file.read() 45 | ret = base64.b64encode(image_data).decode('utf-8') 46 | os.remove(tmp) 47 | return ret 48 | 49 | def encode_image_file_to_base64(image_path, target_size=-1): 50 | image = Image.open(image_path) 51 | return encode_image_to_base64(image, target_size=target_size) 52 | 53 | def decode_base64_to_image(base64_string, target_size=-1): 54 | image_data = base64.b64decode(base64_string) 55 | image = Image.open(io.BytesIO(image_data)) 56 | if image.mode in ('RGBA', 'P'): 57 | image = image.convert('RGB') 58 | if target_size > 0: 59 | image.thumbnail((target_size, target_size)) 60 | return image 61 | 62 | def decode_base64_to_image_file(base64_string, image_path, target_size=-1): 63 | image = decode_base64_to_image(base64_string, target_size=target_size) 64 | image.save(image_path) 65 | 66 | def LMUDataRoot(): 67 | if 'LMUData' in os.environ and osp.exists(os.environ['LMUData']): 68 | return os.environ['LMUData'] 69 | home = osp.expanduser('~') 70 | root = osp.join(home, 'LMUData') 71 | os.makedirs(root, exist_ok=True) 72 | return root 73 | 74 | def build_option_str(option_dict): 75 | s = 'There are several options: \n' 76 | for c, content in option_dict.items(): 77 | if not pd.isna(content): 78 | s += f'{c}. 
{content}\n' 79 | return s 80 | 81 | def isimg(s): 82 | return osp.exists(s) or s.startswith('http') 83 | 84 | def read_ok(img_path): 85 | if not osp.exists(img_path): 86 | return False 87 | try: 88 | im = Image.open(img_path) 89 | assert im.size[0] > 0 and im.size[1] > 0 90 | return True 91 | except: 92 | return False 93 | 94 | def gpt_key_set(): 95 | openai_key = os.environ.get('OPENAI_API_KEY', None) 96 | return isinstance(openai_key, str) and openai_key.startswith('sk-') 97 | 98 | def apiok(wrapper): 99 | s = wrapper.generate("Hello!") 100 | return wrapper.fail_msg not in s 101 | 102 | def circular_pred(df, extract_func=None): 103 | if extract_func is None: 104 | extract_func = lambda x: x 105 | df = df.sort_values('index') 106 | from vlmeval.utils import can_infer_option 107 | shift = int(1e6) 108 | 109 | choices = [extract_func(x) for x in df['prediction']] 110 | pred_map = {i: c for i, c in zip(df['index'], choices)} 111 | flag_map = {i: True for i in pred_map if i < 1e6} 112 | valid_map = {i: True for i in pred_map if i < 1e6} 113 | for i in df['index']: 114 | if i >= shift and pred_map[i] and pred_map[i - shift]: 115 | if pred_map[i] not in list(string.ascii_uppercase) or pred_map[i - shift] not in list(string.ascii_uppercase): 116 | valid_map[i % shift] = False 117 | continue 118 | if (ord(pred_map[i]) - ord(pred_map[i - shift])) % 4 == 1: 119 | continue 120 | else: 121 | flag_map[i % shift] = False 122 | flag_map = {k: v for k, v in flag_map.items() if valid_map[k]} 123 | flags = list(flag_map.values()) 124 | return np.mean(flags) 125 | 126 | def MMBenchOfficialServer(): 127 | root = LMUDataRoot() 128 | for dataset in ['MMBench', 'MMBench_CN', 'MMBench_TEST_EN', 'MMBench_TEST_CN']: 129 | if osp.exists(f'{root}/{dataset}.tsv'): 130 | return True 131 | return False -------------------------------------------------------------------------------- /vlmeval/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .matching_util import can_infer, can_infer_option, can_infer_text 2 | from .mp_util import track_progress_rich 3 | from .custom_prompt import CustomPrompt 4 | from .dataset_config import dataset_URLs, img_root_map, DATASET_TYPE, abbr2full 5 | from .dataset import TSVDataset, split_MMMU 6 | from .xtuner_util import PROMPT_TEMPLATE, StopWordStoppingCriteria, expand2square, prepare_inputs_labels_for_multimodal, DEFAULT_IMAGE_TOKEN, IMAGE_TOKEN_INDEX 7 | 8 | 9 | __all__ = [ 10 | 'can_infer', 'can_infer_option', 'can_infer_text', 'track_progress_rich', 11 | 'TSVDataset', 'dataset_URLs', 'img_root_map', 'DATASET_TYPE', 'CustomPrompt', 12 | 'split_MMMU', 'abbr2full', 'expand2square', 'prepare_inputs_labels_for_multimodal', 13 | 'DEFAULT_IMAGE_TOKEN', 'IMAGE_TOKEN_INDEX', 'PROMPT_TEMPLATE', 'StopWordStoppingCriteria' 14 | ] 15 | -------------------------------------------------------------------------------- /vlmeval/utils/custom_prompt.py: -------------------------------------------------------------------------------- 1 | from ..smp import * 2 | from .dataset_config import img_root_map 3 | from abc import abstractmethod 4 | 5 | class CustomPrompt: 6 | 7 | @abstractmethod 8 | def use_custom_prompt(self, dataset): 9 | raise NotImplementedError 10 | 11 | @abstractmethod 12 | def build_prompt(self, line, dataset): 13 | raise NotImplementedError 14 | 15 | def dump_image(self, line, dataset): 16 | ROOT = LMUDataRoot() 17 | assert isinstance(dataset, str) 18 | img_root = osp.join(ROOT, 'images', img_root_map[dataset]) 19 | 
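        # dump_image materializes the base64-encoded image(s) carried in a benchmark TSV row
        # as real files under <LMUDataRoot>/images/<img_root_map[dataset]>, decoding only when
        # the target file is missing or unreadable, and returns the resulting path (or list of
        # paths) for the VLM wrappers to read.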
os.makedirs(img_root, exist_ok=True) 20 | if isinstance(line['image'], list): 21 | tgt_path = [] 22 | assert 'image_path' in line 23 | for img, im_name in zip(line['image'], line['image_path']): 24 | path = osp.join(img_root, im_name) 25 | if not read_ok(path): 26 | decode_base64_to_image_file(img, path) 27 | tgt_path.append(path) 28 | else: 29 | tgt_path = osp.join(img_root, f"{line['index']}.jpg") 30 | if not read_ok(tgt_path): 31 | decode_base64_to_image_file(line['image'], tgt_path) 32 | return tgt_path -------------------------------------------------------------------------------- /vlmeval/utils/dataset.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | import hashlib 3 | from ..smp import * 4 | from .dataset_config import dataset_URLs, dataset_md5_dict, DATASET_TYPE 5 | from .custom_prompt import CustomPrompt 6 | 7 | def isliststr(s): 8 | return (s[0] == '[') and (s[-1] == ']') 9 | 10 | def check_md5(data_path, dataset): 11 | try: 12 | with open(data_path, 'rb') as f: 13 | hash = hashlib.new('md5') 14 | for chunk in iter(lambda: f.read(2**20), b''): 15 | hash.update(chunk) 16 | if str(hash.hexdigest()) == dataset_md5_dict[dataset]: 17 | return True 18 | else: 19 | warnings.warn('this data file is incomplete, so it needs to be downloaded again.') 20 | return False 21 | except: 22 | return False 23 | 24 | def split_MMMU(struct): 25 | assert 'image' in struct and 'text' in struct 26 | text, images = struct['text'], struct['image'] 27 | text_segs = text.split('<image ') 28 | segs = [text_segs[0]] 29 | for i, seg in enumerate(text_segs): 30 | if i == 0: 31 | continue 32 | assert istype(seg[0], int) and seg[1] == '>' 33 | image_idx = int(seg[0]) - 1 34 | segs.append(images[image_idx]) 35 | segs.append(seg[2:]) 36 | return segs 37 | 38 | class TSVDataset(CustomPrompt): 39 | 40 | def __init__(self, dataset='MMBench', skip_noimg=True): 41 | 42 | self.data_root = LMUDataRoot() 43 | assert osp.exists(self.data_root) 44 | 45 | self.dataset = dataset 46 | self.dataset_type = DATASET_TYPE(dataset) 47 | 48 | url = dataset_URLs[dataset] 49 | file_name = url.split('/')[-1] 50 | data_path = osp.join(self.data_root, file_name) 51 | 52 | #print(md5(data_path)) 53 | #print(dataset_md5_dict[dataset]) 54 | if osp.exists(data_path) and md5(data_path) == dataset_md5_dict[dataset]: 55 | pass 56 | else: 57 | warnings.warn("The dataset tsv is not downloaded") 58 | download_file(url, data_path) 59 | 60 | data = load(data_path) 61 | self.skip_noimg = skip_noimg 62 | if skip_noimg: 63 | data = data[~pd.isna(data['image'])] 64 | 65 | # Prompt for Captioning 66 | if listinstr(['COCO'], dataset): 67 | data['question'] = ['Please describe this image in general. Directly provide the description, do not include prefix like "This image depicts". 
'] * len(data) 68 | 69 | data['index'] = [str(x) for x in data['index']] 70 | data['image'] = [str(x) for x in data['image']] 71 | 72 | image_map = {x: y for x, y in zip(data['index'], data['image'])} 73 | for k in image_map: 74 | if len(image_map[k]) <= 64: 75 | idx = image_map[k] 76 | assert idx in image_map and len(image_map[idx]) > 64 77 | image_map[k] = image_map[idx] 78 | 79 | data['image'] = [ 80 | eval(image_map[k]) if isliststr(image_map[k]) else image_map[k] 81 | for k in data['index'] 82 | ] 83 | if 'image_path' in data: 84 | data['image_path'] = [ 85 | eval(pths) if isliststr(pths) else pths for pths in data['image_path'] 86 | ] 87 | if np.all([istype(x, int) for x in data['index']]): 88 | data['index'] = [int(x) for x in data['index']] 89 | 90 | self.data = data 91 | 92 | def __len__(self): 93 | return len(self.data) 94 | 95 | def build_prompt(self, line, dataset=None): 96 | if dataset is None: 97 | dataset = self.dataset 98 | 99 | if isinstance(line, int): 100 | line = self.data.iloc[line] 101 | 102 | tgt_path = self.dump_image(line, dataset) 103 | 104 | prompt = line['question'] 105 | if DATASET_TYPE(dataset) == 'multi-choice': 106 | question = line['question'] 107 | options = { 108 | cand: line[cand] 109 | for cand in string.ascii_uppercase 110 | if cand in line and not pd.isna(line[cand]) 111 | } 112 | options_prompt = 'Options:\n' 113 | for key, item in options.items(): 114 | options_prompt += f'{key}. {item}\n' 115 | hint = line['hint'] if ('hint' in line and not pd.isna(line['hint'])) else None 116 | prompt = '' 117 | if hint is not None: 118 | prompt += f'Hint: {hint}\n' 119 | prompt += f'Question: {question}\n' 120 | if len(options): 121 | prompt += options_prompt 122 | prompt += 'Please select the correct answer from the options above. 
\n' 123 | 124 | return dict(image=tgt_path, text=prompt) 125 | 126 | def display(self, line): 127 | if isinstance(line, int): 128 | line = self.data.iloc[line] 129 | mmqa_display(line) 130 | 131 | -------------------------------------------------------------------------------- /vlmeval/utils/dataset_config.py: -------------------------------------------------------------------------------- 1 | from ..smp import listinstr 2 | 3 | dataset_URLs = { 4 | 'MMBench_DEV_EN': "https://opencompass.openxlab.space/utils/VLMEval/MMBench_DEV_EN.tsv", 5 | 'MMBench_TEST_EN': "https://opencompass.openxlab.space/utils/VLMEval/MMBench_TEST_EN.tsv", 6 | 'MMBench_DEV_CN': "https://opencompass.openxlab.space/utils/VLMEval/MMBench_DEV_CN.tsv", 7 | 'MMBench_TEST_CN': "https://opencompass.openxlab.space/utils/VLMEval/MMBench_TEST_CN.tsv", 8 | "MMBench": "https://opencompass.openxlab.space/utils/VLMEval/MMBench.tsv", # Link Invalid, Internal Only 9 | "MMBench_CN": "https://opencompass.openxlab.space/utils/VLMEval/MMBench_CN.tsv", # Link Invalid, Internal Only 10 | 'CCBench': "https://opencompass.openxlab.space/utils/VLMEval/CCBench.tsv", 11 | 'MME': "https://opencompass.openxlab.space/utils/VLMEval/MME.tsv", 12 | 'SEEDBench_IMG': "https://opencompass.openxlab.space/utils/VLMEval/SEEDBench_IMG.tsv", 13 | "CORE_MM": "https://opencompass.openxlab.space/utils/VLMEval/CORE_MM.tsv", 14 | "MMVet": "https://opencompass.openxlab.space/utils/VLMEval/MMVet.tsv", 15 | "COCO_VAL": "https://opencompass.openxlab.space/utils/VLMEval/COCO_VAL.tsv", 16 | "OCRVQA_TEST": "https://opencompass.openxlab.space/utils/VLMEval/OCRVQA_TEST.tsv", 17 | "OCRVQA_TESTCORE": "https://opencompass.openxlab.space/utils/VLMEval/OCRVQA_TESTCORE.tsv", 18 | 'TextVQA_VAL': "https://opencompass.openxlab.space/utils/VLMEval/TextVQA_VAL.tsv", 19 | "MMMU_DEV_VAL": "https://opencompass.openxlab.space/utils/VLMEval/MMMU_DEV_VAL.tsv", 20 | "MMMU_TEST": "https://opencompass.openxlab.space/utils/VLMEval/MMMU_TEST.tsv", 21 | "MathVista_MINI": "https://opencompass.openxlab.space/utils/VLMEval/MathVista_MINI.tsv", 22 | 'ChartQA_VALTEST_HUMAN': "https://opencompass.openxlab.space/utils/VLMEval/ChartQA_VALTEST_HUMAN.tsv", 23 | 'ScienceQA_VAL': "https://opencompass.openxlab.space/utils/VLMEval/ScienceQA_VAL.tsv", 24 | 'ScienceQA_TEST': "https://opencompass.openxlab.space/utils/VLMEval/ScienceQA_TEST.tsv", 25 | 'HallusionBench': "https://opencompass.openxlab.space/utils/VLMEval/HallusionBench.tsv", 26 | "DocVQA_VAL": "https://opencompass.openxlab.space/utils/VLMEval/DocVQA_VAL.tsv", 27 | 'AI2D_TEST': "https://opencompass.openxlab.space/utils/VLMEval/AI2D_TEST.tsv", 28 | "LLaVABench": "https://opencompass.openxlab.space/utils/VLMEval/LLaVABench.tsv", 29 | } 30 | 31 | dataset_md5_dict = { 32 | 'MMBench_DEV_EN': "b6caf1133a01c6bb705cf753bb527ed8", 33 | 'MMBench_TEST_EN': "6939fadb0ce626fefc0bdc9c64efc528", 34 | 'MMBench_DEV_CN': "08b8fc3324a5ed74155350f57be69fbd", 35 | 'MMBench_TEST_CN': "7e1239baf0ee4c8b513e19705a0f317e", 36 | "MMBench": "4115aea3383f3dd0083be6a633e0f820", # Link Invalid, Internal Only 37 | "MMBench_CN": "2e053ffc90ea598b1feae13c36dc13ee", # Link Invalid, Internal Only 38 | 'CCBench': "1de88b4257e7eee3f60b18d45eda6f07", 39 | 'MME': "b36b43c3f09801f5d368627fb92187c3", 40 | 'SEEDBench_IMG': "68017231464752261a2526d6ca3a10c0", 41 | "CORE_MM": "8a8da2f2232e79caf98415bfdf0a202d", 42 | "MMVet": "f400d7f513a585a0f218cbd6882e0671", 43 | 'COCO_VAL': "72a5079dead060269ac222c5aa5128af", 44 | 'OCRVQA_TEST': 'ca46a6d74b403e9d6c0b670f6fc00db9', 45 | 'OCRVQA_TESTCORE': 
'c5239fe77db8bdc1f2ad8e55e0d1fe97', 46 | 'TextVQA_VAL': 'b233b31f551bbf4056f2f955da3a92cd', 47 | 'MMMU_DEV_VAL': "521afc0f3bf341e6654327792781644d", 48 | 'MMMU_TEST' : "c19875d11a2d348d07e5eb4bdf33166d", 49 | 'MathVista_MINI': 'f199b98e178e5a2a20e7048f5dcb0464', 50 | 'ChartQA_VALTEST_HUMAN':'2c90a4133408a21d57fb2ea26f77bbfc', 51 | 'ScienceQA_VAL': '96320d05e142e585e7204e72affd29f3', 52 | 'ScienceQA_TEST': 'e42e9e00f9c59a80d8a5db35bc32b71f', 53 | 'HallusionBench': '0c23ac0dc9ef46832d7a24504f2a0c7c', 54 | "DocVQA_VAL": 'c911fdc5f4974513c112cc83a25c99d9', 55 | "AI2D_TEST": "0f593e0d1c7df9a3d69bf1f947e71975", 56 | "LLaVABench": "d382a093f749a697820d3dadd61c8428" 57 | } 58 | 59 | img_root_map = {k: k for k in dataset_URLs} 60 | img_root_map.update({ 61 | 'MMBench_DEV_EN': "MMBench", 62 | 'MMBench_TEST_EN': "MMBench", 63 | 'MMBench_DEV_CN': "MMBench", 64 | 'MMBench_TEST_CN': "MMBench", 65 | "MMBench_CN": "MMBench", # Link Invalid, Internal Only 66 | 'COCO_VAL':'COCO', 67 | 'OCRVQA_TEST': 'OCRVQA', 68 | 'OCRVQA_TESTCORE': 'OCRVQA', 69 | 'TextVQA_VAL': 'TextVQA', 70 | 'MMMU_DEV_VAL': 'MMMU', 71 | "MMMU_TEST": "MMMU", 72 | 'MathVista_MINI': 'MathVista', 73 | 'ChartQA_VALTEST_HUMAN': 'ChartQA', 74 | 'HallusionBench': 'Hallusion', 75 | 'DocVQA_VAL': 'DocVQA', 76 | }) 77 | 78 | assert set(dataset_URLs) == set(img_root_map) == set(dataset_md5_dict) 79 | 80 | def DATASET_TYPE(dataset): 81 | if listinstr(['mmbench', 'seedbench', 'ccbench', 'mmmu', 'scienceqa', 'ai2d'], dataset.lower()): 82 | return 'multi-choice' 83 | elif 'MME' in dataset: 84 | return 'Y/N' 85 | elif 'COCO' in dataset: 86 | return 'Caption' 87 | elif listinstr(['ocrvqa', 'textvqa', 'chartqa', 'mathvista', 'docvqa'], dataset.lower()): 88 | return 'VQA' 89 | else: 90 | return 'QA' 91 | 92 | def abbr2full(s): 93 | datasets = [x for x in img_root_map] 94 | ins = [s in d for d in datasets] 95 | if sum(ins) == 1: 96 | for d in datasets: 97 | if s in d: 98 | return d 99 | else: 100 | return None 101 | -------------------------------------------------------------------------------- /vlmeval/utils/matching_util.py: -------------------------------------------------------------------------------- 1 | import string 2 | import copy as cp 3 | import os 4 | from ..smp import * 5 | 6 | def can_infer_option(answer, choices): 7 | verbose = os.environ.get('VERBOSE', 0) 8 | # Choices is a dictionary 9 | if 'Failed to obtain answer via API' in answer: 10 | return False 11 | 12 | bard_err = [ 13 | "Sorry, I can't help with images of people yet.", 14 | "I can't process this file." 
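        # Refusal messages returned by the Bard API: if any of them appears in the answer,
        # can_infer_option short-circuits and returns the sentinel 'Z' instead of a choice
        # letter.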
15 | ] 16 | for err in bard_err: 17 | if err in answer: 18 | return 'Z' 19 | 20 | def count_choice(splits, choices, prefix='', suffix=''): 21 | cnt = 0 22 | for c in choices: 23 | if prefix + c + suffix in splits: 24 | cnt += 1 25 | return cnt 26 | 27 | answer_mod = cp.copy(answer) 28 | chars = '.()[],:;!*#{}' 29 | for c in chars: 30 | answer_mod = answer_mod.replace(c, ' ') 31 | 32 | splits = [x.strip() for x in answer_mod.split()] 33 | count = count_choice(splits, choices) 34 | 35 | if count == 1: 36 | for ch in choices: 37 | if 'A' in splits and len(splits) > 3 and verbose: 38 | logger = get_logger('Evaluation') 39 | logger.info(f'A might be a quantifier in the string: {answer}.') 40 | return False 41 | if ch in splits: 42 | return ch 43 | elif count == 0 and count_choice(splits, {'Z', ''}) == 1: 44 | return 'Z' 45 | return False 46 | 47 | def can_infer_text(answer, choices): 48 | answer = answer.lower() 49 | assert isinstance(choices, dict) 50 | for k in choices: 51 | assert k in string.ascii_uppercase 52 | choices[k] = str(choices[k]).lower() 53 | cands = [] 54 | for k in choices: 55 | if choices[k] in answer: 56 | cands.append(k) 57 | if len(cands) == 1: 58 | return cands[0] 59 | return False 60 | 61 | def can_infer(answer, choices): 62 | answer = str(answer) 63 | copt = can_infer_option(answer, choices) 64 | return copt if copt else can_infer_text(answer, choices) -------------------------------------------------------------------------------- /vlmeval/utils/mp_util.py: -------------------------------------------------------------------------------- 1 | from multiprocessing import Pool 2 | import os 3 | from typing import Callable, Iterable, Sized 4 | 5 | from rich.progress import (BarColumn, MofNCompleteColumn, Progress, Task, 6 | TaskProgressColumn, TextColumn, TimeRemainingColumn) 7 | from rich.text import Text 8 | import os.path as osp 9 | import portalocker 10 | from ..smp import load, dump 11 | 12 | 13 | class _Worker: 14 | """Function wrapper for ``track_progress_rich``""" 15 | 16 | def __init__(self, func) -> None: 17 | self.func = func 18 | 19 | def __call__(self, inputs): 20 | inputs, idx = inputs 21 | if not isinstance(inputs, (tuple, list, dict)): 22 | inputs = (inputs, ) 23 | 24 | if isinstance(inputs, dict): 25 | return self.func(**inputs), idx 26 | else: 27 | return self.func(*inputs), idx 28 | 29 | 30 | class _SkipFirstTimeRemainingColumn(TimeRemainingColumn): 31 | """Skip calculating remaining time for the first few times. 32 | 33 | Args: 34 | skip_times (int): The number of times to skip. Defaults to 0. 35 | """ 36 | 37 | def __init__(self, *args, skip_times=0, **kwargs): 38 | super().__init__(*args, **kwargs) 39 | self.skip_times = skip_times 40 | 41 | def render(self, task: Task) -> Text: 42 | """Show time remaining.""" 43 | if task.completed <= self.skip_times: 44 | return Text('-:--:--', style='progress.remaining') 45 | return super().render(task) 46 | 47 | 48 | def _tasks_with_index(tasks): 49 | """Add index to tasks.""" 50 | for idx, task in enumerate(tasks): 51 | yield task, idx 52 | 53 | def track_progress_rich(func: Callable, 54 | tasks: Iterable = tuple(), 55 | task_num: int = None, 56 | nproc: int = 1, 57 | chunksize: int = 1, 58 | description: str = 'Processing', 59 | save=None, keys=None, 60 | color: str = 'blue') -> list: 61 | """Track the progress of parallel task execution with a progress bar. The 62 | built-in :mod:`multiprocessing` module is used for process pools and tasks 63 | are done with :func:`Pool.map` or :func:`Pool.imap_unordered`. 
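    When ``save`` is given, the result of each finished task is additionally written
    into that file as a dict entry keyed by the matching element of ``keys`` (which
    must have the same length as ``tasks``), with ``portalocker`` guarding concurrent
    writes.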
64 | 65 | Args: 66 | func (callable): The function to be applied to each task. 67 | tasks (Iterable or Sized): A tuple of tasks. There are several cases 68 | for different format tasks: 69 | - When ``func`` accepts no arguments: tasks should be an empty 70 | tuple, and ``task_num`` must be specified. 71 | - When ``func`` accepts only one argument: tasks should be a tuple 72 | containing the argument. 73 | - When ``func`` accepts multiple arguments: tasks should be a 74 | tuple, with each element representing a set of arguments. 75 | If an element is a ``dict``, it will be parsed as a set of 76 | keyword-only arguments. 77 | Defaults to an empty tuple. 78 | task_num (int, optional): If ``tasks`` is an iterator which does not 79 | have length, the number of tasks can be provided by ``task_num``. 80 | Defaults to None. 81 | nproc (int): Process (worker) number, if nuproc is 1, 82 | use single process. Defaults to 1. 83 | chunksize (int): Refer to :class:`multiprocessing.Pool` for details. 84 | Defaults to 1. 85 | description (str): The description of progress bar. 86 | Defaults to "Process". 87 | color (str): The color of progress bar. Defaults to "blue". 88 | 89 | Examples: 90 | >>> import time 91 | 92 | >>> def func(x): 93 | ... time.sleep(1) 94 | ... return x**2 95 | >>> track_progress_rich(func, range(10), nproc=2) 96 | 97 | Returns: 98 | list: The task results. 99 | """ 100 | if save is not None: 101 | assert osp.exists(osp.dirname(save)) or osp.dirname(save) == '' 102 | if not osp.exists(save): 103 | dump({}, save) 104 | if keys is not None: 105 | assert len(keys) == len(tasks) 106 | 107 | if not callable(func): 108 | raise TypeError('func must be a callable object') 109 | if not isinstance(tasks, Iterable): 110 | raise TypeError( 111 | f'tasks must be an iterable object, but got {type(tasks)}') 112 | if isinstance(tasks, Sized): 113 | if len(tasks) == 0: 114 | if task_num is None: 115 | raise ValueError('If tasks is an empty iterable, ' 116 | 'task_num must be set') 117 | else: 118 | tasks = tuple(tuple() for _ in range(task_num)) 119 | else: 120 | if task_num is not None and task_num != len(tasks): 121 | raise ValueError('task_num does not match the length of tasks') 122 | task_num = len(tasks) 123 | 124 | if nproc <= 0: 125 | raise ValueError('nproc must be a positive number') 126 | 127 | skip_times = nproc * chunksize if nproc > 1 else 0 128 | prog_bar = Progress( 129 | TextColumn('{task.description}'), 130 | BarColumn(), 131 | _SkipFirstTimeRemainingColumn(skip_times=skip_times), 132 | MofNCompleteColumn(), 133 | TaskProgressColumn(show_speed=True), 134 | ) 135 | 136 | worker = _Worker(func) 137 | task_id = prog_bar.add_task( 138 | total=task_num, color=color, description=description) 139 | tasks = _tasks_with_index(tasks) 140 | 141 | # Use single process when nproc is 1, else use multiprocess. 
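    # Both branches below follow the same pattern: run the worker, optionally persist the
    # result into `save` under a portalocker lock (flushed and fsynced so finished results
    # reach disk immediately), then advance the progress bar. The multiprocess branch
    # consumes Pool.imap_unordered, so results arrive out of order and are written back
    # into task order at the end via the index attached by _tasks_with_index.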
142 | with prog_bar: 143 | if nproc == 1: 144 | results = [] 145 | for task in tasks: 146 | result, idx = worker(task) 147 | results.append(worker(task)[0]) 148 | if save is not None: 149 | with portalocker.Lock(save, timeout=5) as fh: 150 | ans = load(save) 151 | ans[keys[idx]] = result 152 | 153 | if os.environ.get('VERBOSE', True): 154 | print(keys[idx], result, flush=True) 155 | 156 | dump(ans, save) 157 | fh.flush() 158 | os.fsync(fh.fileno()) 159 | 160 | prog_bar.update(task_id, advance=1, refresh=True) 161 | else: 162 | with Pool(nproc) as pool: 163 | results = [] 164 | unordered_results = [] 165 | gen = pool.imap_unordered(worker, tasks, chunksize) 166 | try: 167 | for result in gen: 168 | result, idx = result 169 | unordered_results.append((result, idx)) 170 | 171 | if save is not None: 172 | with portalocker.Lock(save, timeout=5) as fh: 173 | ans = load(save) 174 | ans[keys[idx]] = result 175 | 176 | if os.environ.get('VERBOSE', False): 177 | print(keys[idx], result, flush=True) 178 | 179 | dump(ans, save) 180 | fh.flush() 181 | os.fsync(fh.fileno()) 182 | 183 | results.append(None) 184 | prog_bar.update(task_id, advance=1, refresh=True) 185 | except Exception as e: 186 | prog_bar.stop() 187 | raise e 188 | for result, idx in unordered_results: 189 | results[idx] = result 190 | return results -------------------------------------------------------------------------------- /vlmeval/vlm/__init__.py: -------------------------------------------------------------------------------- 1 | import torch 2 | torch.set_grad_enabled(False) 3 | torch.manual_seed(1234) 4 | 5 | from .hpt import HPT 6 | from .hpt1_5 import HPT1_5 7 | -------------------------------------------------------------------------------- /vlmeval/vlm/hpt1_5.py: -------------------------------------------------------------------------------- 1 | import os 2 | import os.path as osp 3 | import string 4 | import sys 5 | import warnings 6 | import re 7 | import torch.nn as nn 8 | 9 | import pandas as pd 10 | import torch 11 | from huggingface_hub import snapshot_download 12 | from PIL import Image 13 | from transformers import (AutoModel, AutoModelForCausalLM, AutoTokenizer, 14 | SiglipImageProcessor, 15 | GenerationConfig, StoppingCriteriaList) 16 | 17 | from ..smp import cn_string, get_cache_path 18 | from ..utils import DATASET_TYPE, CustomPrompt 19 | from .modeling_siglip import SiglipVisionModel 20 | 21 | def interpolate_pos_embed_siglip(model, new_size): 22 | pos_emb = model.vision_model.embeddings.position_embedding.weight.float() 23 | ori_size = int((pos_emb.shape[0])**0.5) 24 | dim = pos_emb.shape[1] 25 | print("Position interpolate from %dx%d to %dx%d" % (ori_size, ori_size, new_size, new_size)) 26 | pos_tokens = pos_emb 27 | pos_tokens = pos_tokens.reshape(-1, ori_size, ori_size, dim).permute(0, 3, 1, 2) 28 | pos_tokens = torch.nn.functional.interpolate( 29 | pos_tokens, size=(new_size, new_size), mode='bicubic', align_corners=False) 30 | pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2).squeeze(0) 31 | new_pos_embed = pos_tokens #torch.cat((extra_tokens, pos_tokens), dim=0) 32 | new_pos_embed = new_pos_embed.to(torch.float16) 33 | return torch.nn.Parameter(new_pos_embed) 34 | 35 | class HPT1_5(CustomPrompt): 36 | 37 | INSTALL_REQ = True 38 | 39 | def __init__(self, 40 | global_model_path='HyperGAI/HPT1_5-Air-Llama-3-8B-Instruct-multimodal', 41 | vis_scale=448, 42 | visual_select_layer=-2, 43 | prompt_template='llama3_chat', 44 | stop_words=[], 45 | torch_dtype=torch.float16): 46 | from vlmeval.utils import 
PROMPT_TEMPLATE, StopWordStoppingCriteria 47 | 48 | llm = AutoModelForCausalLM.from_pretrained(global_model_path, 49 | subfolder='llm', 50 | trust_remote_code=True, 51 | torch_dtype=torch_dtype, 52 | device_map='cpu') 53 | 54 | tokenizer = AutoTokenizer.from_pretrained(global_model_path, 55 | subfolder='llm', 56 | trust_remote_code=True, 57 | encode_special_tokens=True) 58 | print(f'Load LLM') 59 | 60 | # build visual_encoder 61 | image_size = vis_scale 62 | self.image_size = vis_scale 63 | 64 | image_processor = SiglipImageProcessor.from_pretrained(global_model_path, 65 | subfolder='visual_encoder', 66 | size={"height": image_size, "width": image_size}) 67 | 68 | visual_encoder = SiglipVisionModel.from_pretrained(global_model_path, 69 | subfolder='visual_encoder', 70 | torch_dtype=torch_dtype, device_map='cpu') 71 | 72 | patch_size = visual_encoder.vision_model.embeddings.patch_size 73 | num_positions = (image_size//patch_size)**2 74 | new_size = image_size//patch_size 75 | visual_encoder.vision_model.embeddings.num_patches = (image_size//patch_size)**2 76 | visual_encoder.vision_model.embeddings.num_positions = num_positions 77 | visual_encoder.vision_model.embeddings.position_ids = torch.arange(num_positions).expand((1, -1)) 78 | visual_encoder.vision_model.embeddings.position_embedding.weight = interpolate_pos_embed_siglip(visual_encoder, new_size) 79 | visual_encoder.config.image_size = image_size 80 | 81 | print(f'Load visual_encoder') 82 | 83 | 84 | projector = AutoModel.from_pretrained(global_model_path, 85 | subfolder='projector', 86 | trust_remote_code=True, 87 | torch_dtype=torch_dtype, 88 | ) 89 | print(f'Load projector') 90 | 91 | llm.eval() 92 | visual_encoder.eval() 93 | projector.eval() 94 | 95 | self.llm = llm.cuda() 96 | self.tokenizer = tokenizer 97 | self.visual_encoder = visual_encoder.cuda() 98 | self.image_processor = image_processor 99 | self.projector = projector.cuda() 100 | self.visual_select_layer = visual_select_layer 101 | if prompt_template is not None: 102 | self.prompt_template = PROMPT_TEMPLATE[prompt_template] 103 | stop_words += self.prompt_template.get('STOP_WORDS', []) 104 | else: 105 | self.prompt_template = None 106 | 107 | self.stop_criteria = StoppingCriteriaList() 108 | for word in stop_words: 109 | self.stop_criteria.append( 110 | StopWordStoppingCriteria(self.tokenizer, word)) 111 | 112 | def build_gen_config(self, dataset, qtype): 113 | gen_kwargs = dict(max_new_tokens=1024, 114 | do_sample=True, 115 | temperature=0.6, 116 | num_beams=5, 117 | top_p=0.9, 118 | eos_token_id=self.tokenizer.eos_token_id, 119 | pad_token_id=self.tokenizer.pad_token_id 120 | if self.tokenizer.pad_token_id is not None else 121 | self.tokenizer.eos_token_id) 122 | # For single word generation 123 | if (dataset is not None and DATASET_TYPE(dataset) in ['multi-choice', 'Y/N']): 124 | if qtype == '': 125 | gen_kwargs.update( 126 | dict(max_new_tokens=5, do_sample=False, num_beams=1)) 127 | elif qtype == 'open': 128 | gen_kwargs.update( 129 | dict(max_new_tokens=1024, do_sample=False, num_beams=1)) 130 | 131 | return GenerationConfig(**gen_kwargs) 132 | 133 | def use_custom_prompt(self, dataset): 134 | assert dataset is not None 135 | if DATASET_TYPE(dataset) == 'multi-choice': 136 | return True 137 | return False 138 | 139 | def build_prompt(self, line, dataset=None): 140 | assert self.use_custom_prompt(dataset) 141 | assert dataset is None or isinstance(dataset, str) 142 | tgt_path = self.dump_image(line, dataset) 143 | 144 | question = line['question'] 145 | hint = 
line['hint'] if ('hint' in line 146 | and not pd.isna(line['hint'])) else None 147 | 148 | if hint is not None: 149 | try: 150 | question = hint + '\n' + question 151 | except: 152 | print(hint) 153 | 154 | options = { 155 | cand: line[cand] 156 | for cand in string.ascii_uppercase 157 | if cand in line and not pd.isna(line[cand]) 158 | } 159 | for key, item in options.items(): 160 | question += f'\n{key}. {item}' 161 | qtype = '' 162 | if not cn_string(question): 163 | if 'question_type' in line.keys() and 'choice' not in line['question_type']: 164 | prompt = question + '\n' + ("Answer with succinct phrase or a single word, incorporating professional terms when necessary, to address inquiries regarding scene description, key elements identification, and potential activities in the provided image.") #("Answer the question using a single word or phrase.") 165 | qtype = 'open' 166 | else: 167 | # prompt = question + '\n' + ("Answer with the option's letter from the given choices directly.") 168 | prompt = question + '\n' "Answer the question with the correct option's letter from the given choices." 169 | else: 170 | prompt = question + '\n' + '请直接回答选项字母。' 171 | 172 | return {'image': tgt_path, 'text': prompt, 'qtype': qtype} 173 | 174 | def generate(self, image_path, prompt, dataset=None, qtype=''): 175 | from vlmeval.utils import expand2square, prepare_inputs_labels_for_multimodal, DEFAULT_IMAGE_TOKEN, IMAGE_TOKEN_INDEX 176 | 177 | if isinstance(image_path, list) and len(image_path) > 1 and prompt.count('') > 0 and prompt.count('') > 0: 178 | image_list = [Image.open(image_path_cur).convert('RGB') for image_path_cur in image_path] 179 | image1 = image_list[0] 180 | w1, h1 = image1.size 181 | 182 | image2 = image_list[1] 183 | w2, h2 = image2.size 184 | 185 | w_sum, h_sum = w1 + w2, h1 + h2 186 | if w_sum > h_sum: 187 | h2_new = int(h2 * w2 / w1) 188 | image2 =image2.resize((w1, h2_new)) 189 | image = Image.new('RGB', (w1, h1 + h2_new)) 190 | image.paste(image1, (0, 0)) 191 | image.paste(image2, (0, h1)) 192 | 193 | prompt = prompt.replace('', '') 194 | prompt = prompt.replace('', '') 195 | else: 196 | w2_new = int(w2 * h2 / h1) 197 | image2 = image2.resize((w2_new, h2)) 198 | image = Image.new('RGB', (w1 + w2_new, h2)) 199 | image.paste(image1, (0, 0)) 200 | image.paste(image2, (w1, 0)) 201 | 202 | prompt = prompt.replace('', '') 203 | prompt = prompt.replace('', '') 204 | 205 | image_size = 448 206 | image = image.resize((image_size, image_size), 3) 207 | elif isinstance(image_path, list): 208 | image = Image.open(image_path[0]).convert('RGB') 209 | else: 210 | image = Image.open(image_path).convert('RGB') 211 | 212 | image = self.image_processor.preprocess( 213 | image, return_tensors='pt')['pixel_values'][0] 214 | image = image.cuda().unsqueeze(0) 215 | visual_outputs = self.visual_encoder(image, output_hidden_states=True) 216 | # pixel_values = self.projector(visual_outputs.hidden_states[self.visual_select_layer][:, 1:]) 217 | pixel_values = self.projector(visual_outputs.hidden_states[self.visual_select_layer]) 218 | 219 | inputs = DEFAULT_IMAGE_TOKEN + '\n' + prompt 220 | 221 | if self.prompt_template: 222 | inputs = self.prompt_template['INSTRUCTION'].format(input=inputs) 223 | 224 | chunk_encode = [] 225 | for idx, chunk in enumerate(inputs.split(DEFAULT_IMAGE_TOKEN)): 226 | if idx == 0: 227 | cur_encode = self.tokenizer(chunk) 228 | else: 229 | cur_encode = self.tokenizer(chunk, add_special_tokens=False) 230 | chunk_encode.append(cur_encode) 231 | if len(chunk_encode) != 2: 232 | 
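            # The assembled input is split on DEFAULT_IMAGE_TOKEN, so a well-formed prompt
            # yields exactly two chunks (the text before and after the single image
            # placeholder); anything else is printed for debugging before the assert below
            # fails.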
print(prompt) 233 | 234 | assert len(chunk_encode) == 2 235 | ids = [] 236 | for idx, cur_chunk_encode in enumerate(chunk_encode): 237 | ids.extend(cur_chunk_encode['input_ids']) 238 | if idx != len(chunk_encode) - 1: 239 | ids.append(IMAGE_TOKEN_INDEX) 240 | ids = torch.tensor(ids).cuda().unsqueeze(0) 241 | mm_inputs = prepare_inputs_labels_for_multimodal( 242 | llm=self.llm, input_ids=ids, pixel_values=pixel_values) 243 | 244 | gen_config = self.build_gen_config(dataset, qtype) 245 | 246 | generate_output = self.llm.generate( 247 | **mm_inputs, 248 | generation_config=gen_config, 249 | streamer=None, 250 | bos_token_id=self.tokenizer.bos_token_id, 251 | stopping_criteria=self.stop_criteria) 252 | predict = self.tokenizer.decode(generate_output[0], 253 | skip_special_tokens=True).strip() 254 | return predict 255 | --------------------------------------------------------------------------------
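A minimal usage sketch of the HPT1_5 wrapper defined above, written under the assumptions noted in the comments (default checkpoint reachable, CUDA device present; the demo image and MMBench_DEV_EN are used purely for illustration):

# Minimal usage sketch (assumes the default HyperGAI checkpoint can be downloaded, a CUDA
# device is available, and vlmeval is installed; file and dataset names are illustrative only).
from vlmeval.vlm import HPT1_5
from vlmeval.utils import TSVDataset

model = HPT1_5()  # loads the LLM, the SigLIP visual encoder, and the projector

# Free-form question about a single image (demo image shipped with the repo).
print(model.generate('demo/einstein.jpg', 'Who is shown in this picture?'))

# Multiple-choice flow: build the dataset-specific prompt, then generate.
data = TSVDataset('MMBench_DEV_EN')
msg = model.build_prompt(data.data.iloc[0], dataset='MMBench_DEV_EN')
print(model.generate(msg['image'], msg['text'], dataset='MMBench_DEV_EN', qtype=msg['qtype']))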