├── .gitignore ├── LICENSE ├── README.md ├── assets ├── example_image1.jpg ├── example_image2.jpg ├── example_video.mp4 └── overview.png ├── eval ├── mmbench │ └── evaluate_mmbench.py ├── mme │ ├── README.md │ ├── calculation.py │ └── eval.py ├── mmmu │ ├── answer_dict_val.json │ ├── data_utils.py │ ├── eval_utils.py │ ├── evaluate_mmmu.py │ ├── evaluate_mmmu_cot.py │ └── main_eval_only.py ├── mvbench │ └── evaluate_mvbench.py ├── scienceqa │ └── evaluate_scienceqa.py ├── seed │ ├── calculation.py │ └── evaluate_seed.py └── vqa │ ├── convert_gqa_for_eval.py │ ├── evaluate_vqa.py │ ├── infographicsvqa_eval.py │ └── textvqa_eval.py ├── evaluate.sh ├── evaluate_launch.sh ├── requirements.txt └── utils └── preprocess.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 102 | #poetry.lock 103 | 104 | # pdm 105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 106 | #pdm.lock 107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 108 | # in version control. 109 | # https://pdm.fming.dev/#use-with-ide 110 | .pdm.toml 111 | 112 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 113 | __pypackages__/ 114 | 115 | # Celery stuff 116 | celerybeat-schedule 117 | celerybeat.pid 118 | 119 | # SageMath parsed files 120 | *.sage.py 121 | 122 | # Environments 123 | .env 124 | .venv 125 | env/ 126 | venv/ 127 | ENV/ 128 | env.bak/ 129 | venv.bak/ 130 | 131 | # Spyder project settings 132 | .spyderproject 133 | .spyproject 134 | 135 | # Rope project settings 136 | .ropeproject 137 | 138 | # mkdocs documentation 139 | /site 140 | 141 | # mypy 142 | .mypy_cache/ 143 | .dmypy.json 144 | dmypy.json 145 | 146 | # Pyre type checker 147 | .pyre/ 148 | 149 | # pytype static type analyzer 150 | .pytype/ 151 | 152 | # Cython debug symbols 153 | cython_debug/ 154 | 155 | # PyCharm 156 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 157 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 158 | # and can be added to the global gitignore or merged into this file. For a more nuclear 159 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 160 | #.idea/ 161 | 162 | .idea/ 163 | .vscode/ 164 | 165 | data 166 | results/ 167 | eval/ 168 | ckpts/ -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 OpenGVLab 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Progressive Visual Token Compression (PVC) 2 | 3 | ![Static Badge](https://img.shields.io/badge/CVPR-2025-red) 4 | [![Static Badge](https://img.shields.io/badge/arXiv-2412.09613-green)](https://arxiv.org/abs/2412.09613) 5 | [![Static Badge](https://img.shields.io/badge/🤗 HuggingFace-checkpoint-blue)](https://huggingface.co/OpenGVLab/PVC-InternVL2-8B) 6 | 7 | **[CVPR 2025]** [**PVC: Progressive Visual Token Compression for Unified Image and Video Processing in Large Vision-Language Models**](https://arxiv.org/abs/2412.09613) 8 | 9 | We introduce the **Progressive Visual Token Compression (PVC)** in large vision-language models (VLMs), which unifies the visual inputs as videos and progressively compresses vision tokens across video frames. 
Our PVC achieves: 10 | 11 | * Preservation of spatial details and temporal dynamics for both images and videos. 12 | * Effective reduction of the tokens used for each video frame and image tile. 13 | * SoTA performance on various video benchmarks, including long-video and fine-grained short-video tasks. 14 | * No performance loss on image benchmarks, especially on detail-sensitive tasks. 15 | 16 |
17 | ![PVC overview](assets/overview.png)
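As a rough, illustrative calculation (ours, not part of the original README), the per-frame token counts reported in the tables below — 256 visual tokens per frame/tile for InternVL2-8B versus 64 for PVC — translate into the following budget for the 64-frame video used in the Usage example:

```python
# Back-of-the-envelope visual-token budget for a 64-frame video
# (per-frame token counts taken from the benchmark tables below).
frames = 64
baseline_tokens = frames * 256  # InternVL2-8B: 256 tokens per frame
pvc_tokens = frames * 64        # PVC-InternVL2-8B: 64 tokens per frame
print(baseline_tokens, pvc_tokens)  # 16384 vs. 4096 visual tokens fed to the LLM
```

The same 4x reduction applies per image tile on the image benchmarks.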
19 | 20 | ## 📈 Results 21 | 22 | Our implementation is based on the [InternVL2](https://github.com/OpenGVLab/InternVL) model and is referred to as **PVCInternVL2**. 23 | 24 | ### Video Understanding Benchmarks 25 | 26 | | Model | LLaVA-OneVision-7B | Qwen2-VL-7B | InternVL2-8B | PVCInternVL2-8B
🤗 [link](https://huggingface.co/OpenGVLab/PVC-InternVL2-8B) | 27 | | :--------------: | :--: | :--: | :--: | :--: | 28 | | \# token/frame | 196 | - | 256 | 64 | 29 | | | | | | | 30 | | MVBench | 56.7 | 67.0 | 66.4 | 73.8 | 31 | | VideoMME w/o-sub | 58.2 | 63.3 | 54.0 | 64.1 | 32 | | VideoMME w-sub | 61.5 | 69.0 | 56.9 | 69.7 | 33 | | MLVU | 64.7 | - | 52.0 | 72.4 | 34 | | LongVideoBench | 56.5 | - | - | 59.2 | 35 | | NextQA | 79.4 | - | - | 82.0 | 36 | | Egoschema | 60.1 | 66.7 | 55.0 | 59.6 | 37 | | PercepTest | 57.1 | 62.3 | 52.0 | 68.4 | 38 | | AcNet-QA | 56.6 | - | - | 57.1 | 39 | 40 | ### Image Understanding Benchmarks 41 | 42 | | Model | LLaVA-OneVision-7B | Qwen2-VL-7B | InternVL2-8B | PVCInternVL2-8B
🤗 [link](https://huggingface.co/OpenGVLab/PVC-InternVL2-8B) | 43 | | :--------------------: | :--: | :--: | :--: | :--: | 44 | | \# token/image tile | 729 | - | 256 | 64 | 45 | | | | | | | 46 | | AI2D<sub>test</sub> | 81.4 | 83.0 | 83.8 | 83.8 | 47 | | ChartQA<sub>test</sub> | 80.0 | 83.0 | 83.3 | 84.1 | 48 | | DocVQA<sub>test</sub> | 87.5 | 94.5 | 91.6 | 92.5 | 49 | | InfoVQA<sub>test</sub> | 68.8 | 76.5 | 74.8 | 75.0 | 50 | | SQA<sub>test</sub> | 96.0 | - | 97.1 | 97.7 | 51 | | TextVQA<sub>val</sub> | - | 84.3 | 77.4 | 80.0 | 52 | | MMB<sub>en-test</sub> | - | 83.0 | 81.7 | 83.9 | 53 | | MME<sub>sum</sub> | 1998 | 2327 | 2210 | 2282 | 54 | | MMMU<sub>val</sub> | 48.8 | 54.1 | 49.3 | 50.9 | 55 | | SEED<sub>I</sub> | 75.4 | - | 76.2 | 77.2 | 56 | | OCRBench | - | 866 | 794 | 807 | 57 | 58 | ## 🛠️ Usage 59 | 60 | You can use `pip install -r requirements.txt` to set up the environment. Please use `transformers>=4.37.2` to ensure the model works correctly. 61 | 62 | ```python 63 | import torch 64 | from transformers import AutoTokenizer, AutoModel 65 | from utils.preprocess import load_image, load_video 66 | 67 | path = 'OpenGVLab/PVC-InternVL2-8B' 68 | model = AutoModel.from_pretrained( 69 | path, 70 | torch_dtype=torch.bfloat16, 71 | low_cpu_mem_usage=True, 72 | trust_remote_code=True).eval().cuda() 73 | tokenizer = AutoTokenizer.from_pretrained(path, trust_remote_code=True, use_fast=False) 74 | generation_config = dict(max_new_tokens=1024, do_sample=True) 75 | 76 | # single-image conversation 77 | pixel_values = load_image('./assets/example_image1.jpg', max_num=12).to(torch.bfloat16).cuda() 78 | data_flag = torch.tensor([1], dtype=torch.long).cuda() 79 | 80 | question = '<image>\nWhat is in the image?' 81 | response = model.chat(tokenizer, pixel_values, question, generation_config, data_flag=data_flag) 82 | print(f'User: {question}\nAssistant: {response}') 83 | 84 | # multi-image conversation 85 | pixel_values1 = load_image('./assets/example_image1.jpg', max_num=12).to(torch.bfloat16).cuda() 86 | pixel_values2 = load_image('./assets/example_image2.jpg', max_num=12).to(torch.bfloat16).cuda() 87 | pixel_values = torch.cat((pixel_values1, pixel_values2), dim=0) 88 | data_flag = torch.tensor([2], dtype=torch.long).cuda() 89 | num_patches_list = [pixel_values1.shape[0], pixel_values2.shape[0]] 90 | 91 | question = 'Image-1: <image>\nImage-2: <image>\nWhat are the similarities and differences between these two images?' 92 | response = model.chat(tokenizer, pixel_values, question, generation_config, data_flag=data_flag, num_patches_list=num_patches_list) 93 | print(f'User: {question}\nAssistant: {response}') 94 | 95 | # video conversation 96 | pixel_values, num_patches_list = load_video('./assets/example_video.mp4', num_segments=64, max_num=1) 97 | pixel_values = pixel_values.to(torch.bfloat16).cuda() 98 | video_prefix = ''.join([f'Frame{i+1}: <image>\n' for i in range(len(num_patches_list))]) 99 | # Frame1: <image>\nFrame2: <image>\n...\nFrameN: <image>\n{question} 100 | data_flag = torch.tensor([3], dtype=torch.long).cuda() 101 | 102 | question = video_prefix + 'Describe this video in detail.' 103 | response = model.chat(tokenizer, pixel_values, question, generation_config, data_flag=data_flag, num_patches_list=num_patches_list) 104 | print(f'User: {question}\nAssistant: {response}') 105 | ``` 106 | 107 | ## 📊 Evaluation 108 | 109 | ### Image Benchmarks & MVBench 110 | 111 | **Prepare data:** Please follow the [InternVL evaluation data preparation guide](https://internvl.readthedocs.io/en/latest/get_started/eval_data_preparation.html) to prepare the data for evaluation.
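As a quick sanity check that the data landed where the bundled scripts expect it, the snippet below probes two of the files hard-coded in `eval/mmbench/evaluate_mmbench.py`; it assumes you launch evaluation from the repository root, and the full layout comes from the preparation guide linked above, so treat this only as an illustrative sketch:

```python
import os

# Two of the paths referenced verbatim by eval/mmbench/evaluate_mmbench.py; the other
# tasks expect analogous folders under data/ (see the preparation guide above).
expected = [
    'data/mmbench/mmbench_dev_en_20231003.tsv',
    'data/mmbench/mmbench_test_en_20231003.tsv',
]
for p in expected:
    print(f"{p}: {'found' if os.path.exists(p) else 'missing'}")
```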
112 | 113 | **Run evaluation:** use the following command to start the evaluation: 114 | 115 | ```bash 116 | bash evaluate_launch.sh 117 | ``` 118 | 119 | Currently supported tasks: `vqa-ai2d-test`, `vqa-chartqa-test`, `vqa-docvqa-val`, `vqa-docvqa-test`, `vqa-infovqa-val`, `vqa-infovqa-test`, `scienceqa`, `mme`, `mmbench-dev-en`, `mmbench-test-en`, `mmmu-val`, `seed`, `mvbench`. 120 | 121 | For image benchmarks and MVBench, we use the evaluation codebase of InternVL2. Refer to [here](https://internvl.readthedocs.io/en/latest/internvl2.0/evaluation.html#) for more details. 122 | 123 | ## 📅 TODO List 124 | 125 | * [X] release model and checkpoint 126 | * [ ] release evaluation code 127 | * [ ] release training code 128 | 129 | ## 🖊️ Citation 130 | 131 | If you find this work helpful in your research, please consider citing: 132 | 133 | ```bibtex 134 | @article{yang2024pvc, 135 | title={PVC: Progressive Visual Token Compression for Unified Image and Video Processing in Large Vision-Language Models}, 136 | author={Yang, Chenyu and Dong, Xuan and Zhu, Xizhou and Su, Weijie and Wang, Jiahao and Tian, Hao and Chen, Zhe and Wang, Wenhai and Lu, Lewei and and Dai, Jifeng}, 137 | journal={arXiv preprint arXiv:2412.09613}, 138 | year={2024} 139 | } 140 | ``` 141 | 142 | ## 📃 License 143 | 144 | This project is released under the [MIT license](LICENSE). Parts of this project contain code and models from other sources, which are subject to their respective licenses. 145 | -------------------------------------------------------------------------------- /assets/example_image1.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OpenGVLab/PVC/b400817bb19f171175408fe458c951968001d491/assets/example_image1.jpg -------------------------------------------------------------------------------- /assets/example_image2.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OpenGVLab/PVC/b400817bb19f171175408fe458c951968001d491/assets/example_image2.jpg -------------------------------------------------------------------------------- /assets/example_video.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OpenGVLab/PVC/b400817bb19f171175408fe458c951968001d491/assets/example_video.mp4 -------------------------------------------------------------------------------- /assets/overview.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OpenGVLab/PVC/b400817bb19f171175408fe458c951968001d491/assets/overview.png -------------------------------------------------------------------------------- /eval/mmbench/evaluate_mmbench.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import base64 3 | import itertools 4 | import json 5 | import os 6 | import random 7 | import time 8 | from functools import partial 9 | from io import BytesIO 10 | 11 | import pandas as pd 12 | import torch 13 | from utils.preprocess import build_transform, dynamic_preprocess 14 | from PIL import Image 15 | from torch.utils.data import Dataset 16 | from tqdm import tqdm 17 | from transformers import AutoTokenizer, AutoModel 18 | 19 | ds_collections = { 20 | 'mmbench_dev_20230712': { 21 | 'root': 'data/mmbench/mmbench_dev_20230712.tsv', 22 | 'max_new_tokens': 100, 23 | 'min_new_tokens': 1, 24 | 'type': 'dev', 25 | 'language': 'en' 26 
| }, 27 | 'mmbench_dev_cn_20231003': { 28 | 'root': 'data/mmbench/mmbench_dev_cn_20231003.tsv', 29 | 'max_new_tokens': 100, 30 | 'min_new_tokens': 1, 31 | 'type': 'dev', 32 | 'language': 'cn' 33 | }, 34 | 'mmbench_dev_en_20231003': { 35 | 'root': 'data/mmbench/mmbench_dev_en_20231003.tsv', 36 | 'max_new_tokens': 100, 37 | 'min_new_tokens': 1, 38 | 'type': 'dev', 39 | 'language': 'en' 40 | }, 41 | 'mmbench_test_cn_20231003': { 42 | 'root': 'data/mmbench/mmbench_test_cn_20231003.tsv', 43 | 'max_new_tokens': 100, 44 | 'min_new_tokens': 1, 45 | 'type': 'test', 46 | 'language': 'cn' 47 | }, 48 | 'mmbench_test_en_20231003': { 49 | 'root': 'data/mmbench/mmbench_test_en_20231003.tsv', 50 | 'max_new_tokens': 100, 51 | 'min_new_tokens': 1, 52 | 'type': 'test', 53 | 'language': 'en' 54 | }, 55 | 'ccbench_dev_cn': { 56 | 'root': 'data/mmbench/CCBench_legacy.tsv', 57 | 'max_new_tokens': 100, 58 | 'min_new_tokens': 1, 59 | 'type': 'dev', 60 | 'language': 'cn' 61 | } 62 | } 63 | 64 | 65 | def collate_fn(batches, tokenizer): 66 | pixel_values = torch.cat([_['pixel_values'] for _ in batches], dim=0) 67 | split_sizes = torch.cat([_['split_sizes'] for _ in batches]) 68 | questions = [_['question'] for _ in batches] 69 | answers = [_['answer'] for _ in batches] 70 | indexes = [_['index'] for _ in batches] 71 | options = [_['option'] for _ in batches] 72 | return pixel_values, split_sizes, questions, answers, indexes, options 73 | 74 | 75 | class MMBenchDataset(torch.utils.data.Dataset): 76 | 77 | def __init__(self, root, prompt, language, input_size=224, dynamic_image_size=False, 78 | use_thumbnail=False, max_num=6): 79 | self.df = pd.read_csv(root, sep='\t') 80 | self.prompt = prompt 81 | self.language = language 82 | self.input_size = input_size 83 | self.dynamic_image_size = dynamic_image_size 84 | self.use_thumbnail = use_thumbnail 85 | self.max_num = max_num 86 | self.transform = build_transform(input_size=input_size) 87 | 88 | def __len__(self): 89 | return len(self.df) 90 | 91 | def __getitem__(self, idx): 92 | index = self.df.iloc[idx]['index'] 93 | image = self.df.iloc[idx]['image'] 94 | question = self.df.iloc[idx]['question'] 95 | answer = self.df.iloc[idx]['answer'] if 'answer' in self.df.iloc[0].keys() else None 96 | # catetory = self.df.iloc[idx]['category'] 97 | # l2_catetory = self.df.iloc[idx]['l2-category'] 98 | 99 | image = Image.open(BytesIO(base64.b64decode(image))).convert('RGB') 100 | if self.dynamic_image_size: 101 | images = dynamic_preprocess(image, image_size=self.input_size, 102 | use_thumbnail=self.use_thumbnail, 103 | max_num=self.max_num) 104 | else: 105 | images = [image] 106 | pixel_values = [self.transform(image) for image in images] 107 | pixel_values = torch.stack(pixel_values) 108 | 109 | option_candidate = ['A', 'B', 'C', 'D', 'E'] 110 | options = { 111 | cand: self.load_from_df(idx, cand) 112 | for cand in option_candidate 113 | if self.load_from_df(idx, cand) is not None 114 | } 115 | 116 | hint = self.load_from_df(idx, 'hint') 117 | if hint is not None: 118 | question = hint + '\n' + question 119 | for key, item in options.items(): 120 | question += f'\n{key}. 
{item}' 121 | if self.language == 'cn': 122 | question = question + '\n' + self.prompt['cn'] 123 | else: 124 | question = question + '\n' + self.prompt['en'] 125 | 126 | return { 127 | 'question': question, 128 | 'pixel_values': pixel_values, 129 | 'split_sizes': torch.LongTensor((pixel_values.shape[0], )), 130 | 'answer': answer, 131 | 'index': index, 132 | 'option': options 133 | } 134 | 135 | def load_from_df(self, idx, key): 136 | if key in self.df.iloc[idx] and not pd.isna(self.df.iloc[idx][key]): 137 | return self.df.iloc[idx][key] 138 | else: 139 | return None 140 | 141 | 142 | class InferenceSampler(torch.utils.data.sampler.Sampler): 143 | 144 | def __init__(self, size): 145 | self._size = int(size) 146 | assert size > 0 147 | self._rank = torch.distributed.get_rank() 148 | self._world_size = torch.distributed.get_world_size() 149 | self._local_indices = self._get_local_indices(size, self._world_size, self._rank) 150 | 151 | @staticmethod 152 | def _get_local_indices(total_size, world_size, rank): 153 | shard_size = total_size // world_size 154 | left = total_size % world_size 155 | shard_sizes = [shard_size + int(r < left) for r in range(world_size)] 156 | 157 | begin = sum(shard_sizes[:rank]) 158 | end = min(sum(shard_sizes[:rank + 1]), total_size) 159 | return range(begin, end) 160 | 161 | def __iter__(self): 162 | yield from self._local_indices 163 | 164 | def __len__(self): 165 | return len(self._local_indices) 166 | 167 | 168 | def post_process(pred, option): 169 | pred = pred.strip() 170 | option_candidate = list(option.keys()) 171 | if len(pred) == 1: 172 | return pred 173 | elif len(pred) != 1 and pred[0] in option_candidate: 174 | return pred[0] 175 | elif len(pred) != 1 and pred[0] not in option_candidate: 176 | for k, v in option.items(): 177 | if v in pred: 178 | return k 179 | 180 | return pred 181 | 182 | 183 | def evaluate_chat_model(): 184 | random.seed(args.seed) 185 | 186 | for ds_name in args.datasets: 187 | dataset = MMBenchDataset( 188 | root=ds_collections[ds_name]['root'], 189 | prompt=prompt, 190 | language=ds_collections[ds_name]['language'], 191 | input_size=image_size, 192 | dynamic_image_size=args.dynamic, 193 | use_thumbnail=use_thumbnail, 194 | max_num=args.max_num 195 | ) 196 | dataloader = torch.utils.data.DataLoader( 197 | dataset=dataset, 198 | sampler=InferenceSampler(len(dataset)), 199 | batch_size=args.batch_size, 200 | num_workers=args.num_workers, 201 | pin_memory=True, 202 | drop_last=False, 203 | collate_fn=partial(collate_fn, tokenizer=tokenizer), 204 | ) 205 | 206 | outputs = [] 207 | for _, (pixel_values, split_sizes, questions, answers, indexes, options) in tqdm(enumerate(dataloader)): 208 | pixel_values = pixel_values.to(torch.bfloat16).cuda() 209 | generation_config = dict( 210 | num_beams=args.num_beams, 211 | max_new_tokens=ds_collections[ds_name]['max_new_tokens'], 212 | min_new_tokens=ds_collections[ds_name]['min_new_tokens'], 213 | do_sample=True if args.temperature > 0 else False, 214 | temperature=args.temperature, 215 | ) 216 | pred = model.chat( 217 | tokenizer=tokenizer, 218 | pixel_values=pixel_values, 219 | split_sizes=split_sizes, 220 | question=questions[0], 221 | generation_config=generation_config 222 | ) 223 | preds = [post_process(pred, options[0])] 224 | 225 | for question, pred, answer, index in zip(questions, preds, answers, indexes): 226 | outputs.append({ 227 | 'question': question, 228 | 'answer': pred, 229 | 'gt_answers': answer, 230 | 'index': int(index) 231 | }) 232 | 233 | torch.distributed.barrier() 234 | 
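        # Each rank serializes its local predictions to a JSON string; all_gather_object
        # then collects one string per rank so rank 0 can merge them and write the results file.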
235 | world_size = torch.distributed.get_world_size() 236 | merged_outputs = [None for _ in range(world_size)] 237 | torch.distributed.all_gather_object(merged_outputs, json.dumps(outputs)) 238 | 239 | merged_outputs = [json.loads(_) for _ in merged_outputs] 240 | merged_outputs = [_ for _ in itertools.chain.from_iterable(merged_outputs)] 241 | 242 | if torch.distributed.get_rank() == 0: 243 | 244 | print(f'Evaluating {ds_name} ...') 245 | time_prefix = time.strftime('%y%m%d%H%M%S', time.localtime()) 246 | results_file = f'{ds_name}_{time_prefix}.xlsx' 247 | output_path = os.path.join(args.out_dir, results_file) 248 | df = pd.read_table(ds_collections[ds_name]['root']) 249 | cur_df = df.copy() 250 | if 'mmbench' in ds_name: 251 | cur_df = cur_df.drop(columns=['hint', 'category', 'source', 'image', 'comment', 'l2-category']) 252 | cur_df.insert(6, 'prediction', None) 253 | else: 254 | cur_df = cur_df.drop(columns=['category', 'image']) 255 | cur_df.insert(8, 'prediction', None) 256 | for item in merged_outputs: 257 | cur_df.loc[df['index'] == item['index'], 'prediction'] = item['answer'] 258 | 259 | cur_df.to_excel(output_path, index=False, engine='openpyxl') 260 | print('Results saved to {}'.format(output_path)) 261 | 262 | 263 | if __name__ == '__main__': 264 | parser = argparse.ArgumentParser() 265 | parser.add_argument('--checkpoint', type=str, default='') 266 | parser.add_argument('--datasets', type=str, default='mmbench_dev_20230712') 267 | parser.add_argument('--batch-size', type=int, default=1) 268 | parser.add_argument('--num-workers', type=int, default=1) 269 | parser.add_argument('--num-beams', type=int, default=5) 270 | parser.add_argument('--temperature', type=float, default=0.0) 271 | parser.add_argument('--out-dir', type=str, default='results') 272 | parser.add_argument('--seed', type=int, default=0) 273 | parser.add_argument('--dynamic', action='store_true') 274 | parser.add_argument('--max-num', type=int, default=6) 275 | parser.add_argument('--load-in-8bit', action='store_true') 276 | parser.add_argument('--auto', action='store_true') 277 | args = parser.parse_args() 278 | 279 | if not os.path.exists(args.out_dir): 280 | os.makedirs(args.out_dir, exist_ok=True) 281 | 282 | args.datasets = args.datasets.split(',') 283 | print('datasets:', args.datasets) 284 | assert args.batch_size == 1, 'Only batch size 1 is supported' 285 | 286 | torch.distributed.init_process_group( 287 | backend='nccl', 288 | world_size=int(os.getenv('WORLD_SIZE', '1')), 289 | rank=int(os.getenv('RANK', '0')), 290 | ) 291 | 292 | torch.cuda.set_device(int(os.getenv('LOCAL_RANK', 0))) 293 | 294 | if args.auto: 295 | os.environ['CUDA_LAUNCH_BLOCKING'] = '1' 296 | kwargs = {'device_map': 'auto'} if args.auto else {} 297 | tokenizer = AutoTokenizer.from_pretrained(args.checkpoint, trust_remote_code=True, use_fast=False) 298 | model = AutoModel.from_pretrained( 299 | args.checkpoint, trust_remote_code=True, low_cpu_mem_usage=True, torch_dtype=torch.bfloat16, 300 | load_in_8bit=args.load_in_8bit, **kwargs).eval() 301 | if not args.load_in_8bit and not args.auto: 302 | model = model.cuda() 303 | image_size = model.config.force_image_size or model.config.vision_config.image_size 304 | use_thumbnail = model.config.use_thumbnail 305 | 306 | total_params = sum(p.numel() for p in model.parameters()) / 1e9 307 | if total_params > 20 or args.dynamic: 308 | args.num_beams = 1 309 | print(f'[test] total_params: {total_params}B, use num_beams: {args.num_beams}') 310 | else: 311 | print(f'[test] total_params: 
{total_params}B') 312 | print(f'[test] image_size: {image_size}') 313 | print(f'[test] template: {model.config.template}') 314 | print(f'[test] dynamic_image_size: {args.dynamic}') 315 | print(f'[test] use_thumbnail: {use_thumbnail}') 316 | print(f'[test] max_num: {args.max_num}') 317 | 318 | prompt = { 319 | 'en': "Answer with the option's letter from the given choices directly.", 320 | 'cn': '请直接回答选项字母。' 321 | } 322 | evaluate_chat_model() 323 | -------------------------------------------------------------------------------- /eval/mme/README.md: -------------------------------------------------------------------------------- 1 | # This is an automated calculation script for the acc, acc+, and score. 2 | 3 | # You can directly run "python3 calculation.py" to get the evaluation results of LaVIN. 4 | 5 | # In order to get the statistical results of your model: 6 | 7 | (1) Fill all the files in "Your_Results", adding your model's responses: 8 | Each file in "Your_Results" consists of: 9 | Image_Name + "\\t" + Question + "\\t" + Ground_Truth_Answer + "\\n" 10 | 11 | You need to add the responses of your model as: 12 | Image_Name + "\\t" + Question + "\\t" + Ground_Truth_Answer + "\\t" + Your_Response + "\\n" 13 | 14 | Note: if your responses contain "\\n", please delet it. For each question, your response can only be in one line, not across lines! 15 | 16 | (2) run "python3 calculation.py --results_dir ./Your_Results" 17 | -------------------------------------------------------------------------------- /eval/mme/calculation.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | 4 | from sklearn.metrics import (accuracy_score, confusion_matrix, precision_score, 5 | recall_score) 6 | 7 | parser = argparse.ArgumentParser() 8 | parser.add_argument('--results_dir', default='./LaVIN', type=str) 9 | 10 | eval_type_dict = { 11 | 'Perception': ['existence', 'count', 'position', 'color', 'posters', 'celebrity', 'scene', 'landmark', 'artwork', 'OCR'], 12 | 'Cognition': ['commonsense_reasoning', 'numerical_calculation', 'text_translation', 'code_reasoning'] 13 | } 14 | 15 | 16 | class calculate_metrics: 17 | def divide_chunks(self, l, n=2): 18 | # looping till length l 19 | for i in range(0, len(l), n): 20 | yield l[i:i + n] 21 | 22 | return 23 | 24 | def parse_pred_ans(self, pred_ans): 25 | pred_label = None 26 | if pred_ans in ['yes', 'no']: 27 | pred_label = pred_ans 28 | else: 29 | prefix_pred_ans = pred_ans[:4] 30 | 31 | if 'yes' in prefix_pred_ans: 32 | pred_label = 'yes' 33 | elif 'no' in prefix_pred_ans: 34 | pred_label = 'no' 35 | else: 36 | pred_label = 'other' 37 | 38 | return pred_label 39 | 40 | def compute_metric(self, gts, preds): 41 | assert len(gts) == len(preds) 42 | 43 | label_map = { 44 | 'yes': 1, 45 | 'no': 0, 46 | 'other': -1, 47 | } 48 | 49 | gts = [label_map[x] for x in gts] 50 | preds = [label_map[x] for x in preds] 51 | 52 | acc = accuracy_score(gts, preds) 53 | 54 | clean_gts = [] 55 | clean_preds = [] 56 | other_num = 0 57 | for gt, pred in zip(gts, preds): 58 | if pred == -1: 59 | other_num += 1 60 | continue 61 | clean_gts.append(gt) 62 | clean_preds.append(pred) 63 | 64 | conf_mat = confusion_matrix(clean_gts, clean_preds, labels=[1,0]) 65 | precision = precision_score(clean_gts, clean_preds, average='binary') 66 | recall = recall_score(clean_gts, clean_preds, average='binary') 67 | tp, fn = conf_mat[0] 68 | fp, tn = conf_mat[1] 69 | 70 | metric_dict = dict() 71 | metric_dict = { 72 | 'TP': tp, 73 | 'FN': 
fn, 74 | 'TN': tn, 75 | 'FP': fp, 76 | 'precision': precision, 77 | 'recall': recall, 78 | 'other_num': other_num, 79 | 'acc': acc, 80 | } 81 | 82 | return metric_dict 83 | 84 | def process_result(self, results_dir): 85 | 86 | model_score_dict = dict() 87 | for eval_type, task_name_list in eval_type_dict.items(): 88 | print('===========', eval_type, '===========') 89 | 90 | scores = 0 91 | task_score_dict = dict() 92 | 93 | for task_name in task_name_list: 94 | 95 | task_txt = os.path.join(results_dir, task_name + '.txt') 96 | lines = open(task_txt, 'r').readlines() 97 | chunk_lines = list(self.divide_chunks(lines)) # one image corresponds to two questions 98 | 99 | img_num = len(chunk_lines) 100 | task_other_ans_num = 0 101 | task_score = 0 102 | acc_plus_correct_num = 0 103 | gts = [] 104 | preds = [] 105 | 106 | for img_items in chunk_lines: 107 | assert len(img_items) == 2 108 | img_correct_num = 0 109 | 110 | for img_item in img_items: 111 | try: 112 | img_name, question, gt_ans, pred_ans = img_item.split('\t') 113 | except: 114 | print(img_item) 115 | continue 116 | gt_ans = gt_ans.lower() 117 | pred_ans = pred_ans.lower() 118 | 119 | assert gt_ans in ['yes', 'no'] # gt can only be yes or no. 120 | 121 | pred_ans = self.parse_pred_ans(pred_ans) 122 | assert pred_ans in ['yes', 'no', 'other'] 123 | 124 | gts.append(gt_ans) 125 | preds.append(pred_ans) 126 | 127 | if gt_ans == pred_ans: 128 | img_correct_num += 1 129 | 130 | if pred_ans not in ['yes', 'no']: 131 | task_other_ans_num += 1 132 | 133 | if img_correct_num == 2: 134 | acc_plus_correct_num += 1 135 | 136 | # cal TP precision acc, etc. 137 | metric_dict = self.compute_metric(gts, preds) 138 | acc_plus = acc_plus_correct_num / img_num 139 | metric_dict['acc_plus'] = acc_plus 140 | 141 | for k, v in metric_dict.items(): 142 | if k in ['acc', 'acc_plus']: 143 | task_score += v*100 144 | 145 | task_score_dict[task_name] = task_score 146 | 147 | scores += task_score 148 | 149 | print('total score:', scores, '\n') 150 | for task_name, score in task_score_dict.items(): 151 | print('\t', task_name, ' score:', score) 152 | print('\n') 153 | 154 | return 155 | 156 | 157 | if __name__ == '__main__': 158 | cal = calculate_metrics() 159 | 160 | args = parser.parse_args() 161 | results_dir = args.results_dir 162 | cal.process_result(results_dir) 163 | -------------------------------------------------------------------------------- /eval/mme/eval.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import re 4 | 5 | import torch 6 | from utils.preprocess import build_transform, dynamic_preprocess 7 | from PIL import Image 8 | from tqdm import tqdm 9 | from transformers import AutoTokenizer, AutoModel 10 | 11 | 12 | def load_image(image_file, input_size=224): 13 | image = Image.open(image_file).convert('RGB') 14 | transform = build_transform(input_size=input_size) 15 | if args.dynamic: 16 | images = dynamic_preprocess(image, image_size=input_size, 17 | use_thumbnail=use_thumbnail, 18 | max_num=args.max_num) 19 | else: 20 | images = [image] 21 | pixel_values = [transform(image) for image in images] 22 | pixel_values = torch.stack(pixel_values) 23 | return pixel_values 24 | 25 | 26 | def post_processing(response): 27 | response = response.replace('\n', '').replace('不是', 'No').replace('是', 'Yes').replace('否', 'No') 28 | response = response.lower().replace('true', 'yes').replace('false', 'no') 29 | pattern = re.compile(r'[\u4e00-\u9fa5]') 30 | response = re.sub(pattern, '', 
response) 31 | return response 32 | 33 | 34 | if __name__ == '__main__': 35 | parser = argparse.ArgumentParser() 36 | parser.add_argument('--checkpoint', type=str, default='') 37 | parser.add_argument('--root', type=str, default='./Your_Results') 38 | parser.add_argument('--num-beams', type=int, default=5) 39 | parser.add_argument('--top-k', type=int, default=50) 40 | parser.add_argument('--top-p', type=float, default=0.9) 41 | parser.add_argument('--sample', type=bool, default=False) 42 | parser.add_argument('--dynamic', action='store_true') 43 | parser.add_argument('--max-num', type=int, default=6) 44 | parser.add_argument('--load-in-8bit', action='store_true') 45 | parser.add_argument('--auto', action='store_true') 46 | args = parser.parse_args() 47 | 48 | if args.auto: 49 | os.environ['CUDA_LAUNCH_BLOCKING'] = '1' 50 | kwargs = {'device_map': 'auto'} if args.auto else {} 51 | prompt = 'Answer the question using a single word or phrase.' 52 | tokenizer = AutoTokenizer.from_pretrained(args.checkpoint, trust_remote_code=True, use_fast=False) 53 | model = AutoModel.from_pretrained( 54 | args.checkpoint, trust_remote_code=True, low_cpu_mem_usage=True, torch_dtype=torch.bfloat16, 55 | load_in_8bit=args.load_in_8bit, **kwargs).eval() 56 | if not args.load_in_8bit and not args.auto: 57 | model = model.cuda() 58 | image_size = model.config.force_image_size or model.config.vision_config.image_size 59 | use_thumbnail = model.config.use_thumbnail 60 | 61 | total_params = sum(p.numel() for p in model.parameters()) / 1e9 62 | if total_params > 20 or args.dynamic: 63 | args.num_beams = 1 64 | print(f'[test] total_params: {total_params}B, use num_beams: {args.num_beams}') 65 | else: 66 | print(f'[test] total_params: {total_params}B') 67 | print(f'[test] image_size: {image_size}') 68 | print(f'[test] template: {model.config.template}') 69 | print(f'[test] dynamic_image_size: {args.dynamic}') 70 | print(f'[test] use_thumbnail: {use_thumbnail}') 71 | print(f'[test] max_num: {args.max_num}') 72 | 73 | output = os.path.join(args.checkpoint, 'mme') 74 | os.makedirs(output, exist_ok=True) 75 | 76 | for filename in os.listdir(args.root): 77 | fin = open(os.path.join(args.root, filename), 'r', encoding='utf-8') 78 | fout = open(os.path.join(output, filename), 'w', encoding='utf-8') 79 | lines = fin.readlines() 80 | filename = filename.replace('.txt', '') 81 | for line in tqdm(lines): 82 | img, question, gt = line.strip().split('\t') 83 | question = question + ' ' + prompt 84 | img_path = os.path.join('../../data/mme/MME_Benchmark_release_version', filename, img) 85 | assert os.path.exists(img_path), img_path 86 | pixel_values = load_image(img_path, image_size).cuda().to(torch.bfloat16) 87 | generation_config = dict( 88 | do_sample=args.sample, 89 | top_k=args.top_k, 90 | top_p=args.top_p, 91 | num_beams=args.num_beams, 92 | max_new_tokens=20, 93 | eos_token_id=tokenizer.eos_token_id, 94 | ) 95 | response = model.chat( 96 | tokenizer=tokenizer, 97 | pixel_values=pixel_values, 98 | split_sizes=[pixel_values.shape[0]], 99 | question=question, 100 | generation_config=generation_config 101 | ) 102 | response = post_processing(response) 103 | print(img, question, gt, response, sep='\t', file=fout) 104 | fin.close() 105 | fout.close() 106 | -------------------------------------------------------------------------------- /eval/mmmu/data_utils.py: -------------------------------------------------------------------------------- 1 | """Utils for data load, save, and process (e.g., prompt construction)""" 2 | 3 | import 
json 4 | import os 5 | import re 6 | 7 | import yaml 8 | 9 | DOMAIN_CAT2SUB_CAT = { 10 | 'Art and Design': ['Art', 'Art_Theory', 'Design', 'Music'], 11 | 'Business': ['Accounting', 'Economics', 'Finance', 'Manage', 'Marketing'], 12 | 'Science': ['Biology', 'Chemistry', 'Geography', 'Math', 'Physics', ], 13 | 'Health and Medicine': ['Basic_Medical_Science', 'Clinical_Medicine', 'Diagnostics_and_Laboratory_Medicine', 14 | 'Pharmacy', 'Public_Health'], 15 | 'Humanities and Social Science': ['History', 'Literature', 'Sociology', 'Psychology'], 16 | 'Tech and Engineering': ['Agriculture', 'Architecture_and_Engineering', 'Computer_Science', 'Electronics', 17 | 'Energy_and_Power', 'Materials', 'Mechanical_Engineering'], 18 | } 19 | 20 | CAT_SHORT2LONG = { 21 | 'acc': 'Accounting', 22 | 'agri': 'Agriculture', 23 | 'arch': 'Architecture_and_Engineering', 24 | 'art': 'Art', 25 | 'art_theory': 'Art_Theory', 26 | 'bas_med': 'Basic_Medical_Science', 27 | 'bio': 'Biology', 28 | 'chem': 'Chemistry', 29 | 'cli_med': 'Clinical_Medicine', 30 | 'cs': 'Computer_Science', 31 | 'design': 'Design', 32 | 'diag_med': 'Diagnostics_and_Laboratory_Medicine', 33 | 'econ': 'Economics', 34 | 'elec': 'Electronics', 35 | 'ep': 'Energy_and_Power', 36 | 'fin': 'Finance', 37 | 'geo': 'Geography', 38 | 'his': 'History', 39 | 'liter': 'Literature', 40 | 'manage': 'Manage', 41 | 'mark': 'Marketing', 42 | 'mate': 'Materials', 43 | 'math': 'Math', 44 | 'mech': 'Mechanical_Engineering', 45 | 'music': 'Music', 46 | 'phar': 'Pharmacy', 47 | 'phys': 'Physics', 48 | 'psy': 'Psychology', 49 | 'pub_health': 'Public_Health', 50 | 'socio': 'Sociology' 51 | } 52 | 53 | 54 | # DATA SAVING 55 | def save_json(filename, ds): 56 | with open(filename, 'w') as f: 57 | json.dump(ds, f, indent=4) 58 | 59 | 60 | def get_multi_choice_info(options): 61 | """ 62 | Given the list of options for multiple choice question 63 | Return the index2ans and all_choices 64 | """ 65 | 66 | start_chr = 'A' 67 | all_choices = [] 68 | index2ans = {} 69 | for i, option in enumerate(options): 70 | index2ans[chr(ord(start_chr) + i)] = option 71 | all_choices.append(chr(ord(start_chr) + i)) 72 | 73 | return index2ans, all_choices 74 | 75 | 76 | def load_yaml(file_path): 77 | with open(file_path, 'r') as stream: 78 | try: 79 | yaml_dict = yaml.safe_load(stream) 80 | except yaml.YAMLError as exc: 81 | print(exc) 82 | 83 | return yaml_dict 84 | 85 | 86 | def parse_img_path(text): 87 | matches = re.findall("", text) 88 | return matches 89 | 90 | 91 | def process_single_sample(data): 92 | question = data['question'] 93 | o_imgs_paths = [] 94 | for option in data['options']: 95 | current_o_imgs_paths = parse_img_path(option) 96 | for img_path in current_o_imgs_paths: 97 | o_imgs_paths.append(img_path) 98 | images = [data['image_1'], data['image_2'], data['image_3'], data['image_4'], 99 | data['image_5'], data['image_6'], data['image_7']] 100 | return {'id': data['id'], 'question': question, 'options': data['options'], 'answer': data['answer'], 101 | 'image': images, 'question_type': data['question_type']} 102 | 103 | 104 | # DATA SAVING 105 | def save_json(filename, ds): 106 | with open(filename, 'w') as f: 107 | json.dump(ds, f, indent=4) 108 | 109 | 110 | def save_jsonl(filename, data): 111 | """ 112 | Save a dictionary of data to a JSON Lines file with the filename as key and caption as value. 113 | 114 | Args: 115 | filename (str): The path to the file where the data should be saved. 
116 | data (dict): The dictionary containing the data to save where key is the image path and value is the caption. 117 | """ 118 | with open(filename, 'w', encoding='utf-8') as f: 119 | for img_path, caption in data.items(): 120 | # Extract the base filename without the extension 121 | base_filename = os.path.basename(img_path) 122 | # Create a JSON object with the filename as the key and caption as the value 123 | json_record = json.dumps({base_filename: caption}, ensure_ascii=False) 124 | # Write the JSON object to the file, one per line 125 | f.write(json_record + '\n') 126 | 127 | 128 | def save_args(args, path_dir): 129 | argsDict = args.__dict__ 130 | with open(path_dir + 'setting.txt', 'w') as f: 131 | f.writelines('------------------ start ------------------' + '\n') 132 | for eachArg, value in argsDict.items(): 133 | f.writelines(eachArg + ' : ' + str(value) + '\n') 134 | f.writelines('------------------- end -------------------') 135 | 136 | 137 | # DATA PROCESSING 138 | def construct_prompt(sample, config): 139 | question = sample['question'] 140 | options = eval(sample['options']) 141 | example = '' 142 | if sample['question_type'] == 'multiple-choice': 143 | start_chr = 'A' 144 | prediction_range = [] 145 | index2ans = {} 146 | for option in options: 147 | prediction_range.append(start_chr) 148 | example += f'({start_chr}) {option}\n' 149 | index2ans[start_chr] = option 150 | start_chr = chr(ord(start_chr) + 1) 151 | empty_prompt_sample_structure = config['multi_choice_example_format'] 152 | empty_prompt = empty_prompt_sample_structure.format(question, example) 153 | res_dict = {} 154 | res_dict['index2ans'] = index2ans 155 | res_dict['correct_choice'] = sample['answer'] 156 | res_dict['all_choices'] = prediction_range 157 | res_dict['empty_prompt'] = empty_prompt 158 | if config['task_instructions']: 159 | res_dict['final_input_prompt'] = config['task_instructions'].strip() + '\n\n' + empty_prompt 160 | else: 161 | res_dict['final_input_prompt'] = empty_prompt 162 | 163 | res_dict['gt_content'] = options[ord(sample['answer'].upper()) - ord('A')] 164 | else: 165 | empty_prompt_sample_structure = config['short_ans_example_format'] 166 | empty_prompt = empty_prompt_sample_structure.format(question) 167 | res_dict = {} 168 | res_dict['empty_prompt'] = empty_prompt 169 | if config['task_instructions']: 170 | res_dict['final_input_prompt'] = config['task_instructions'].strip() + '\n\n' + empty_prompt 171 | else: 172 | res_dict['final_input_prompt'] = empty_prompt 173 | res_dict['gt_content'] = sample['answer'] 174 | 175 | res_dict.update(sample) 176 | return res_dict 177 | -------------------------------------------------------------------------------- /eval/mmmu/eval_utils.py: -------------------------------------------------------------------------------- 1 | """Response Parsing and Evaluation for various models""" 2 | import random 3 | import re 4 | from typing import Dict 5 | 6 | random.seed(42) 7 | import numpy as np 8 | 9 | 10 | # ----------- Process Multi-choice ------------- 11 | def parse_multi_choice_response(response, all_choices, index2ans): 12 | """ 13 | Parse the prediction from the generated response. 14 | Return the predicted index e.g., A, B, C, D. 
15 | """ 16 | for char in [',', '.', '!', '?', ';', ':', "'"]: 17 | response = response.strip(char) 18 | response = ' ' + response + ' ' # add space to avoid partial match 19 | 20 | index_ans = True 21 | ans_with_brack = False 22 | candidates = [] 23 | for choice in all_choices: # e.g., (A) (B) (C) (D) 24 | if f'({choice})' in response: 25 | candidates.append(choice) 26 | ans_with_brack = True 27 | 28 | if len(candidates) == 0: 29 | for choice in all_choices: # e.g., A B C D 30 | if f' {choice} ' in response: 31 | candidates.append(choice) 32 | 33 | # if all above doesn't get candidates, check if the content is larger than 5 tokens and try to parse the example 34 | if len(candidates) == 0 and len(response.split()) > 5: 35 | for index, ans in index2ans.items(): 36 | if ans.lower() in response.lower(): 37 | candidates.append(index) 38 | index_ans = False # it's content ans. 39 | 40 | if len(candidates) == 0: # still not get answer, randomly choose one. 41 | pred_index = random.choice(all_choices) 42 | elif len(candidates) > 1: 43 | start_indexes = [] 44 | if index_ans: 45 | if ans_with_brack: 46 | for can in candidates: 47 | index = response.rfind(f'({can})') 48 | start_indexes.append(index) # -1 will be ignored anyway 49 | # start_indexes = [generated_response.index(f'({can})') for can in candidates] 50 | else: 51 | for can in candidates: 52 | index = response.rfind(f' {can} ') 53 | start_indexes.append(index) 54 | else: 55 | for can in candidates: 56 | index = response.lower().rfind(index2ans[can].lower()) 57 | start_indexes.append(index) 58 | # get the last one 59 | pred_index = candidates[np.argmax(start_indexes)] 60 | else: # if only one candidate, use it. 61 | pred_index = candidates[0] 62 | 63 | return pred_index 64 | 65 | 66 | # ----------- Process Open ------------- 67 | def check_is_number(string): 68 | """ 69 | Check if the given string a number. 70 | """ 71 | try: 72 | float(string.replace(',', '')) 73 | return True 74 | except ValueError: 75 | # check if there's comma inside 76 | return False 77 | 78 | 79 | def normalize_str(string): 80 | """ 81 | Normalize the str to lower case and make them float numbers if possible. 82 | """ 83 | # check if characters in the string 84 | 85 | # if number, numerize it. 86 | string = string.strip() 87 | 88 | is_number = check_is_number(string) 89 | 90 | if is_number: 91 | string = string.replace(',', '') 92 | string = float(string) 93 | # leave 2 decimal 94 | string = round(string, 2) 95 | return [string] 96 | else: # it's likely to be a string 97 | # lower it 98 | string = string.lower() 99 | if len(string) == 1: 100 | return [' ' + string, string + ' '] # avoid trivial matches 101 | return [string] 102 | 103 | 104 | def extract_numbers(string): 105 | """ 106 | Exact all forms of numbers from a string with regex. 
107 | """ 108 | # Pattern for numbers with commas 109 | pattern_commas = r'-?\b\d{1,3}(?:,\d{3})+\b' 110 | # Pattern for scientific notation 111 | pattern_scientific = r'-?\d+(?:\.\d+)?[eE][+-]?\d+' 112 | # Pattern for simple numbers without commas 113 | pattern_simple = r'-?(?:\d+\.\d+|\.\d+|\d+\b)(?![eE][+-]?\d+)(?![,\d])' 114 | 115 | # Extract numbers with commas 116 | numbers_with_commas = re.findall(pattern_commas, string) 117 | # Extract numbers in scientific notation 118 | numbers_scientific = re.findall(pattern_scientific, string) 119 | # Extract simple numbers without commas 120 | numbers_simple = re.findall(pattern_simple, string) 121 | 122 | # Combine all extracted numbers 123 | all_numbers = numbers_with_commas + numbers_scientific + numbers_simple 124 | return all_numbers 125 | 126 | 127 | def parse_open_response(response): 128 | """ 129 | Parse the prediction from the generated response. 130 | Return a list of predicted strings or numbers. 131 | """ 132 | 133 | # content = content.strip("\n").strip(".").strip(" ") 134 | def get_key_subresponses(response): 135 | key_responses = [] 136 | response = response.strip().strip('.').lower() 137 | sub_responses = re.split(r'\.\s(?=[A-Z])|\n', response) 138 | indicators_of_keys = ['could be ', 'so ', 'is ', 139 | 'thus ', 'therefore ', 'final ', 'answer ', 'result '] 140 | key_responses = [] 141 | for index, resp in enumerate(sub_responses): 142 | # if last one, accept it's an equation (the entire response can be just one sentence with equation) 143 | if index == len(sub_responses) - 1: 144 | indicators_of_keys.extend(['=']) 145 | shortest_key_response = None # the shortest response that may contain the answer (tail part of the response) 146 | for indicator in indicators_of_keys: 147 | if indicator in resp: 148 | if not shortest_key_response: 149 | shortest_key_response = resp.split(indicator)[-1].strip() 150 | else: 151 | if len(resp.split(indicator)[-1].strip()) < len(shortest_key_response): 152 | shortest_key_response = resp.split(indicator)[-1].strip() 153 | # key_responses.append(resp.split(indicator)[1].strip()) 154 | 155 | if shortest_key_response: 156 | # and it's not trivial 157 | if shortest_key_response.strip() not in [':', ',', '.', '!', '?', ';', ':', "'"]: 158 | key_responses.append(shortest_key_response) 159 | if len(key_responses) == 0: # did not found any 160 | return [response] 161 | return key_responses 162 | 163 | # pdb.set_trace() 164 | key_responses = get_key_subresponses(response) 165 | 166 | pred_list = key_responses.copy() # keep the original string response 167 | for resp in key_responses: 168 | pred_list.extend(extract_numbers(resp)) 169 | 170 | tmp_pred_list = [] 171 | for i in range(len(pred_list)): 172 | tmp_pred_list.extend(normalize_str(pred_list[i])) 173 | pred_list = tmp_pred_list 174 | 175 | # remove duplicates 176 | pred_list = list(set(pred_list)) 177 | 178 | return pred_list 179 | 180 | 181 | # ----------- Evaluation ------------- 182 | 183 | def eval_multi_choice(gold_i, pred_i): 184 | """ 185 | Evaluate a multiple choice instance. 
186 | """ 187 | correct = False 188 | # only they are exactly the same, we consider it as correct 189 | if isinstance(gold_i, list): 190 | for answer in gold_i: 191 | if answer == pred_i: 192 | correct = True 193 | break 194 | else: # gold_i is a string 195 | if gold_i == pred_i: 196 | correct = True 197 | return correct 198 | 199 | 200 | def eval_open(gold_i, pred_i): 201 | """ 202 | Evaluate an open question instance 203 | """ 204 | correct = False 205 | if isinstance(gold_i, list): 206 | # use float to avoid trivial matches 207 | norm_answers = [] 208 | for answer in gold_i: 209 | norm_answers.extend(normalize_str(answer)) 210 | else: 211 | norm_answers = normalize_str(gold_i) 212 | for pred in pred_i: # pred is already normalized in parse response phase 213 | if isinstance(pred, str): # if it's a string, then find if ans in the pred_i 214 | for norm_ans in norm_answers: 215 | # only see if the string answer in the string pred 216 | if isinstance(norm_ans, str) and norm_ans in pred: 217 | if not correct: 218 | correct = True 219 | break 220 | else: # it's a float number 221 | if pred in norm_answers: 222 | if not correct: 223 | correct = True 224 | break 225 | return correct 226 | 227 | 228 | # ----------- Batch Evaluation ------------- 229 | def evaluate(samples): 230 | """ 231 | Batch evaluation for multiple choice and open questions. 232 | """ 233 | pred_correct = 0 234 | judge_dict = dict() 235 | for sample in samples: 236 | gold_i = sample['answer'] 237 | pred_i = sample['parsed_pred'] 238 | if sample['question_type'] == 'multiple-choice': 239 | correct = eval_multi_choice(gold_i, pred_i) 240 | else: # open question 241 | correct = eval_open(gold_i, pred_i) 242 | 243 | if correct: 244 | judge_dict[sample['id']] = 'Correct' 245 | pred_correct += 1 246 | else: 247 | judge_dict[sample['id']] = 'Wrong' 248 | 249 | if len(samples) == 0: 250 | return {'acc': 0} 251 | return judge_dict, {'acc': pred_correct / len(samples)} 252 | 253 | 254 | # ----------- Calculate Accuracy ------------- 255 | def calculate_ins_level_acc(results: Dict): 256 | """Calculate the instruction level accuracy for given Subject results""" 257 | acc = 0 258 | ins_num = 0 259 | for cat_results in results.values(): 260 | acc += cat_results['acc'] * cat_results['num_example'] 261 | ins_num += cat_results['num_example'] 262 | if ins_num == 0: 263 | return 0 264 | return acc / ins_num 265 | -------------------------------------------------------------------------------- /eval/mmmu/evaluate_mmmu.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import itertools 3 | import json 4 | import os 5 | import random 6 | import time 7 | from functools import partial 8 | 9 | import torch 10 | from data_utils import CAT_SHORT2LONG, process_single_sample 11 | from datasets import concatenate_datasets, load_dataset 12 | from utils.preprocess import build_transform, dynamic_preprocess 13 | from PIL import Image 14 | from torch.utils.data import Dataset 15 | from tqdm import tqdm 16 | from transformers import AutoTokenizer, AutoModel 17 | 18 | ds_collections = { 19 | 'MMMU_validation': { 20 | 'root': 'MMMU/MMMU', 21 | 'max_new_tokens': 10, 22 | 'min_new_tokens': 1, 23 | 'split': 'validation' 24 | }, 25 | 'MMMU_test': { 26 | 'root': 'MMMU/MMMU', 27 | 'max_new_tokens': 10, 28 | 'min_new_tokens': 1, 29 | 'split': 'test' 30 | }, 31 | 'MMMU_dev': { 32 | 'root': 'MMMU/MMMU', 33 | 'max_new_tokens': 10, 34 | 'min_new_tokens': 1, 35 | 'split': 'dev' 36 | }, 37 | } 38 | 39 | 40 | def 
collate_fn(batches, tokenizer): 41 | pixel_values = torch.cat([_['pixel_values'] for _ in batches], dim=0) 42 | questions = [_['question'] for _ in batches] 43 | answers = [_['answer'] for _ in batches] 44 | data_ids = [_['data_id'] for _ in batches] 45 | options = [_['option'] for _ in batches] 46 | return pixel_values, questions, answers, data_ids, options 47 | 48 | 49 | class MMMUDataset(torch.utils.data.Dataset): 50 | 51 | def __init__(self, root, split, prompt, input_size=224, dynamic_image_size=False, 52 | use_thumbnail=False, max_num=6): 53 | # run for each subject 54 | sub_dataset_list = [] 55 | for subject in tqdm(CAT_SHORT2LONG.values()): 56 | sub_dataset = load_dataset(root, subject, split=split, ) # cache_dir=os.path.join(os.getcwd(), 'data/MMMU/')) 57 | sub_dataset_list.append(sub_dataset) 58 | 59 | # merge all dataset 60 | self.data = concatenate_datasets(sub_dataset_list) 61 | self.prompt = prompt 62 | self.input_size = input_size 63 | self.dynamic_image_size = dynamic_image_size 64 | self.use_thumbnail = use_thumbnail 65 | self.max_num = max_num 66 | self.transform = build_transform(input_size=input_size) 67 | 68 | def __len__(self): 69 | return len(self.data) 70 | 71 | def __getitem__(self, idx): 72 | 73 | data = process_single_sample(self.data[idx]) 74 | data_id = data['id'] 75 | question = data['question'].strip() 76 | pil_images = data['image'] 77 | question_type = data['question_type'] 78 | 79 | choices = eval(data['options']) 80 | answer = data['answer'] if 'answer' in data else None 81 | 82 | choice_list = [] 83 | options = {} 84 | multiple_choices = ['A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J', 'K', 'L', 'M'] 85 | for i, c in enumerate(choices): 86 | choice_list.append('{}. {}'.format(multiple_choices[i], c.strip())) 87 | options[multiple_choices[i]] = c.strip() 88 | choice_txt = '\n'.join(choice_list) 89 | if self.dynamic_image_size: 90 | images = [] 91 | for idx, pil_image in enumerate(pil_images): 92 | if pil_image is not None: 93 | if idx == 0: 94 | pil_image = pil_image.resize((pil_image.width * 2, pil_image.height * 2), Image.BILINEAR) 95 | pil_image = dynamic_preprocess(pil_image, image_size=self.input_size, 96 | use_thumbnail=self.use_thumbnail, max_num=self.max_num) 97 | else: 98 | pil_image = dynamic_preprocess(pil_image, image_size=self.input_size, 99 | use_thumbnail=self.use_thumbnail, max_num=1) 100 | images += pil_image 101 | else: 102 | images = [pil_images[0]] 103 | pixel_values = [self.transform(image) for image in images] 104 | pixel_values = torch.stack(pixel_values) 105 | 106 | if len(choice_txt) > 0: 107 | question += '\n' + choice_txt 108 | question += '\n' + self.prompt[question_type] 109 | question = question.strip() 110 | 111 | return { 112 | 'question': question, 113 | 'pixel_values': pixel_values, 114 | 'answer': answer, 115 | 'option': options, 116 | 'data_id': data_id 117 | } 118 | 119 | 120 | class InferenceSampler(torch.utils.data.sampler.Sampler): 121 | 122 | def __init__(self, size): 123 | self._size = int(size) 124 | assert size > 0 125 | self._rank = torch.distributed.get_rank() 126 | self._world_size = torch.distributed.get_world_size() 127 | self._local_indices = self._get_local_indices(size, self._world_size, self._rank) 128 | 129 | @staticmethod 130 | def _get_local_indices(total_size, world_size, rank): 131 | shard_size = total_size // world_size 132 | left = total_size % world_size 133 | shard_sizes = [shard_size + int(r < left) for r in range(world_size)] 134 | 135 | begin = sum(shard_sizes[:rank]) 136 | end = 
min(sum(shard_sizes[:rank + 1]), total_size) 137 | return range(begin, end) 138 | 139 | def __iter__(self): 140 | yield from self._local_indices 141 | 142 | def __len__(self): 143 | return len(self._local_indices) 144 | 145 | 146 | def post_process(pred, option): 147 | pred = pred.strip() 148 | option_candidate = list(option.keys()) 149 | if len(pred) == 1: 150 | return pred 151 | elif len(pred) != 1 and pred[0] in option_candidate: 152 | return pred[0] 153 | elif len(pred) != 1 and pred[0] not in option_candidate: 154 | for k, v in option.items(): 155 | if v in pred: 156 | return k 157 | 158 | return pred 159 | 160 | 161 | def evaluate_chat_model(): 162 | prompt = { 163 | 'multiple-choice': "Answer with the option's letter from the given choices directly.", 164 | 'open': 'Answer the question using a single word or phrase.' 165 | } 166 | random.seed(args.seed) 167 | 168 | for ds_name in args.datasets: 169 | dataset = MMMUDataset( 170 | root=ds_collections[ds_name]['root'], 171 | split=ds_collections[ds_name]['split'], 172 | prompt=prompt, 173 | input_size=image_size, 174 | dynamic_image_size=args.dynamic, 175 | use_thumbnail=use_thumbnail, 176 | max_num=args.max_num 177 | ) 178 | dataloader = torch.utils.data.DataLoader( 179 | dataset=dataset, 180 | sampler=InferenceSampler(len(dataset)), 181 | batch_size=args.batch_size, 182 | num_workers=args.num_workers, 183 | pin_memory=True, 184 | drop_last=False, 185 | collate_fn=partial(collate_fn, tokenizer=tokenizer), 186 | ) 187 | 188 | outputs = [] 189 | for _, (pixel_values, questions, answers, data_ids, options) in tqdm(enumerate(dataloader)): 190 | pixel_values = pixel_values.to(torch.bfloat16).cuda() 191 | generation_config = dict( 192 | num_beams=args.num_beams, 193 | max_new_tokens=ds_collections[ds_name]['max_new_tokens'], 194 | min_new_tokens=ds_collections[ds_name]['min_new_tokens'], 195 | do_sample=True if args.temperature > 0 else False, 196 | temperature=args.temperature, 197 | ) 198 | pred = model.chat( 199 | tokenizer=tokenizer, 200 | pixel_values=pixel_values, 201 | question=questions[0], 202 | generation_config=generation_config 203 | ) 204 | if len(options[0]) == 0: 205 | preds = [pred] 206 | else: 207 | preds = [post_process(pred, options[0])] 208 | 209 | for question, pred, answer, data_id in zip(questions, preds, answers, data_ids): 210 | outputs.append({ 211 | 'question': question, 212 | 'answer': pred, 213 | 'gt_answers': answer, 214 | 'data_id': data_id 215 | }) 216 | 217 | torch.distributed.barrier() 218 | 219 | world_size = torch.distributed.get_world_size() 220 | merged_outputs = [None for _ in range(world_size)] 221 | torch.distributed.all_gather_object(merged_outputs, json.dumps(outputs)) 222 | 223 | merged_outputs = [json.loads(_) for _ in merged_outputs] 224 | merged_outputs = [_ for _ in itertools.chain.from_iterable(merged_outputs)] 225 | 226 | if torch.distributed.get_rank() == 0: 227 | 228 | print(f'Evaluating {ds_name} ...') 229 | time_prefix = time.strftime('%y%m%d%H%M%S', time.localtime()) 230 | results_file = f'{ds_name}_{time_prefix}.json' 231 | output_path = os.path.join(args.out_dir, results_file) 232 | outputs = {} 233 | for item in merged_outputs: 234 | outputs[item['data_id']] = item['answer'] 235 | with open(output_path, 'w') as f: 236 | json.dump(outputs, f, indent=4) 237 | print('Results saved to {}'.format(output_path)) 238 | if ds_collections[ds_name]['split'] == 'validation': 239 | print('Evaluating ...') 240 | cmd = f'python eval/mmmu/main_eval_only.py ' \ 241 | f'--output_path {output_path} ' \ 
242 | f'--answer_path eval/mmmu/answer_dict_val.json' 243 | print(cmd) 244 | os.system(cmd) 245 | time_prefix = time.strftime('%y%m%d%H%M%S', time.localtime()) 246 | results_file = f'{ds_name}_{time_prefix}.jsonl' 247 | output_path = os.path.join(args.out_dir, results_file) 248 | writer = open(output_path, 'w') 249 | for item in merged_outputs: 250 | writer.write(json.dumps(item) + '\n') 251 | writer.close() 252 | print('Results saved to {}'.format(output_path)) 253 | 254 | 255 | if __name__ == '__main__': 256 | parser = argparse.ArgumentParser() 257 | parser.add_argument('--checkpoint', type=str, default='') 258 | parser.add_argument('--datasets', type=str, default='MMMU_dev') 259 | parser.add_argument('--batch-size', type=int, default=1) 260 | parser.add_argument('--num-workers', type=int, default=1) 261 | parser.add_argument('--num-beams', type=int, default=5) 262 | parser.add_argument('--temperature', type=float, default=0.0) 263 | parser.add_argument('--out-dir', type=str, default='results') 264 | parser.add_argument('--seed', type=int, default=0) 265 | parser.add_argument('--dynamic', action='store_true') 266 | parser.add_argument('--max-num', type=int, default=6) 267 | parser.add_argument('--load-in-8bit', action='store_true') 268 | parser.add_argument('--auto', action='store_true') 269 | args = parser.parse_args() 270 | 271 | if not os.path.exists(args.out_dir): 272 | os.makedirs(args.out_dir) 273 | 274 | args.datasets = args.datasets.split(',') 275 | print('datasets:', args.datasets) 276 | assert args.batch_size == 1, 'Only batch size 1 is supported' 277 | 278 | torch.distributed.init_process_group( 279 | backend='nccl', 280 | world_size=int(os.getenv('WORLD_SIZE', '1')), 281 | rank=int(os.getenv('RANK', '0')), 282 | ) 283 | 284 | torch.cuda.set_device(int(os.getenv('LOCAL_RANK', 0))) 285 | 286 | if args.auto: 287 | os.environ['CUDA_LAUNCH_BLOCKING'] = '1' 288 | kwargs = {'device_map': 'auto'} if args.auto else {} 289 | tokenizer = AutoTokenizer.from_pretrained(args.checkpoint, trust_remote_code=True, use_fast=False) 290 | model = AutoModel.from_pretrained( 291 | args.checkpoint, trust_remote_code=True, low_cpu_mem_usage=True, torch_dtype=torch.bfloat16, 292 | load_in_8bit=args.load_in_8bit, **kwargs).eval() 293 | if not args.load_in_8bit and not args.auto: 294 | model = model.cuda() 295 | image_size = model.config.force_image_size or model.config.vision_config.image_size 296 | use_thumbnail = model.config.use_thumbnail 297 | 298 | total_params = sum(p.numel() for p in model.parameters()) / 1e9 299 | if total_params > 20 or args.dynamic: 300 | args.num_beams = 1 301 | print(f'[test] total_params: {total_params}B, use num_beams: {args.num_beams}') 302 | else: 303 | print(f'[test] total_params: {total_params}B') 304 | print(f'[test] image_size: {image_size}') 305 | print(f'[test] template: {model.config.template}') 306 | print(f'[test] dynamic_image_size: {args.dynamic}') 307 | print(f'[test] use_thumbnail: {use_thumbnail}') 308 | print(f'[test] max_num: {args.max_num}') 309 | 310 | evaluate_chat_model() 311 | -------------------------------------------------------------------------------- /eval/mmmu/evaluate_mmmu_cot.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import itertools 3 | import json 4 | import os 5 | import random 6 | import time 7 | from functools import partial 8 | 9 | import torch 10 | from data_utils import CAT_SHORT2LONG, process_single_sample 11 | from datasets import concatenate_datasets, load_dataset 
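# [Editor's note: illustrative sketch, not part of the original file]
# Every evaluation script in this repo merges per-rank results the same way:
# each rank JSON-encodes its `outputs` list, torch.distributed.all_gather_object
# collects the strings, and rank 0 decodes and flattens them. The snippet below
# reproduces that merge step without any distributed setup, using two
# hypothetical per-rank output lists.
import itertools
import json

per_rank_outputs = [
    [{'data_id': 'validation_Math_1', 'answer': 'A'}],  # pretend rank 0
    [{'data_id': 'validation_Math_2', 'answer': 'C'}],  # pretend rank 1
]
gathered = [json.dumps(o) for o in per_rank_outputs]    # what all_gather_object collects
merged = list(itertools.chain.from_iterable(json.loads(s) for s in gathered))
assert [m['data_id'] for m in merged] == ['validation_Math_1', 'validation_Math_2']
# [End of editor's note]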
12 | from utils.preprocess import build_transform, dynamic_preprocess 13 | from PIL import Image 14 | from torch.utils.data import Dataset 15 | from tqdm import tqdm 16 | from transformers import AutoTokenizer, AutoModel 17 | 18 | ds_collections = { 19 | 'MMMU_validation': { 20 | 'root': 'MMMU/MMMU', 21 | 'max_new_tokens': 1000, 22 | 'min_new_tokens': 1, 23 | 'split': 'validation' 24 | }, 25 | 'MMMU_test': { 26 | 'root': 'MMMU/MMMU', 27 | 'max_new_tokens': 1000, 28 | 'min_new_tokens': 1, 29 | 'split': 'test' 30 | }, 31 | 'MMMU_dev': { 32 | 'root': 'MMMU/MMMU', 33 | 'max_new_tokens': 1000, 34 | 'min_new_tokens': 1, 35 | 'split': 'dev' 36 | }, 37 | } 38 | 39 | 40 | def collate_fn(batches, tokenizer): 41 | pixel_values = torch.cat([_['pixel_values'] for _ in batches], dim=0) 42 | questions = [_['question'] for _ in batches] 43 | answers = [_['answer'] for _ in batches] 44 | data_ids = [_['data_id'] for _ in batches] 45 | options = [_['option'] for _ in batches] 46 | return pixel_values, questions, answers, data_ids, options 47 | 48 | 49 | class MMMUDataset(torch.utils.data.Dataset): 50 | 51 | def __init__(self, root, split, prompt, input_size=224, dynamic_image_size=False, 52 | use_thumbnail=False, max_num=6): 53 | # run for each subject 54 | sub_dataset_list = [] 55 | for subject in tqdm(CAT_SHORT2LONG.values()): 56 | sub_dataset = load_dataset(root, subject, split=split, cache_dir=os.path.join(os.getcwd(), 'data/MMMU/')) 57 | sub_dataset_list.append(sub_dataset) 58 | 59 | # merge all dataset 60 | self.data = concatenate_datasets(sub_dataset_list) 61 | self.prompt = prompt 62 | self.input_size = input_size 63 | self.dynamic_image_size = dynamic_image_size 64 | self.use_thumbnail = use_thumbnail 65 | self.max_num = max_num 66 | self.transform = build_transform(input_size=input_size) 67 | 68 | def __len__(self): 69 | return len(self.data) 70 | 71 | def __getitem__(self, idx): 72 | 73 | data = process_single_sample(self.data[idx]) 74 | data_id = data['id'] 75 | question = data['question'].strip() 76 | pil_images = data['image'] 77 | question_type = data['question_type'] 78 | 79 | choices = eval(data['options']) 80 | answer = data['answer'] if 'answer' in data else None 81 | 82 | choice_list = [] 83 | options = {} 84 | multiple_choices = ['A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J', 'K', 'L', 'M'] 85 | for i, c in enumerate(choices): 86 | choice_list.append('{}. 
{}'.format(multiple_choices[i], c.strip())) 87 | options[multiple_choices[i]] = c.strip() 88 | choice_txt = '\n'.join(choice_list) 89 | if self.dynamic_image_size: 90 | images = [] 91 | for idx, pil_image in enumerate(pil_images): 92 | if pil_image is not None: 93 | if idx == 0: 94 | pil_image = pil_image.resize((pil_image.width * 2, pil_image.height * 2), Image.BILINEAR) 95 | pil_image = dynamic_preprocess(pil_image, image_size=self.input_size, 96 | use_thumbnail=self.use_thumbnail, max_num=self.max_num) 97 | else: 98 | pil_image = dynamic_preprocess(pil_image, image_size=self.input_size, 99 | use_thumbnail=self.use_thumbnail, max_num=1) 100 | images += pil_image 101 | else: 102 | images = [pil_images[0]] 103 | pixel_values = [self.transform(image) for image in images] 104 | pixel_values = torch.stack(pixel_values) 105 | 106 | if len(choice_txt) > 0: 107 | question += '\n' + choice_txt 108 | question += '\n' + self.prompt[question_type] 109 | question = question.strip() 110 | 111 | return { 112 | 'question': question, 113 | 'pixel_values': pixel_values, 114 | 'answer': answer, 115 | 'option': options, 116 | 'data_id': data_id 117 | } 118 | 119 | 120 | class InferenceSampler(torch.utils.data.sampler.Sampler): 121 | 122 | def __init__(self, size): 123 | self._size = int(size) 124 | assert size > 0 125 | self._rank = torch.distributed.get_rank() 126 | self._world_size = torch.distributed.get_world_size() 127 | self._local_indices = self._get_local_indices(size, self._world_size, self._rank) 128 | 129 | @staticmethod 130 | def _get_local_indices(total_size, world_size, rank): 131 | shard_size = total_size // world_size 132 | left = total_size % world_size 133 | shard_sizes = [shard_size + int(r < left) for r in range(world_size)] 134 | 135 | begin = sum(shard_sizes[:rank]) 136 | end = min(sum(shard_sizes[:rank + 1]), total_size) 137 | return range(begin, end) 138 | 139 | def __iter__(self): 140 | yield from self._local_indices 141 | 142 | def __len__(self): 143 | return len(self._local_indices) 144 | 145 | 146 | def post_process(pred, option): 147 | pred = pred.split('"Answer":')[-1].strip() 148 | pred = pred.strip() 149 | option_candidate = list(option.keys()) 150 | if len(pred) == 1: 151 | return pred 152 | elif len(pred) != 1 and pred[0] in option_candidate: 153 | return pred[0] 154 | elif len(pred) != 1 and pred[0] not in option_candidate: 155 | for k, v in option.items(): 156 | if v in pred: 157 | return k 158 | 159 | return pred 160 | 161 | 162 | def evaluate_chat_model(): 163 | prompt = { 164 | 'multiple-choice': 'Please provide a detailed and step-by-step explanation, structured with sections labeled Thought, Step-by-Step Solution, and Answer. If you are uncertain of the correct answer, guess the most likely one.', 165 | 'open': 'Please provide a detailed and step-by-step explanation, structured with sections labeled Thought, Step-by-Step Solution, and Answer. If you are uncertain of the correct answer, guess the most likely one.' 
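# [Editor's note: illustrative sketch, not part of the original file]
# The CoT prompt above asks for sections labeled Thought, Step-by-Step Solution,
# and Answer, and post_process (defined above) keeps only the text after the
# last '"Answer":' marker before matching it to an option letter. A hypothetical
# model reply and the extraction step it would go through:
reply = '"Thought": ...\n"Step-by-Step Solution": ...\n"Answer": B'
extracted = reply.split('"Answer":')[-1].strip()
assert extracted == 'B'   # a single letter is then returned as-is by post_process
# [End of editor's note]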
166 | } 167 | random.seed(args.seed) 168 | 169 | for ds_name in args.datasets: 170 | dataset = MMMUDataset( 171 | root=ds_collections[ds_name]['root'], 172 | split=ds_collections[ds_name]['split'], 173 | prompt=prompt, 174 | input_size=image_size, 175 | dynamic_image_size=args.dynamic, 176 | use_thumbnail=use_thumbnail, 177 | max_num=args.max_num 178 | ) 179 | dataloader = torch.utils.data.DataLoader( 180 | dataset=dataset, 181 | sampler=InferenceSampler(len(dataset)), 182 | batch_size=args.batch_size, 183 | num_workers=args.num_workers, 184 | pin_memory=True, 185 | drop_last=False, 186 | collate_fn=partial(collate_fn, tokenizer=tokenizer), 187 | ) 188 | 189 | outputs = [] 190 | for _, (pixel_values, questions, answers, data_ids, options) in tqdm(enumerate(dataloader)): 191 | pixel_values = pixel_values.to(torch.bfloat16).cuda() 192 | generation_config = dict( 193 | num_beams=args.num_beams, 194 | max_new_tokens=ds_collections[ds_name]['max_new_tokens'], 195 | min_new_tokens=ds_collections[ds_name]['min_new_tokens'], 196 | do_sample=True if args.temperature > 0 else False, 197 | temperature=args.temperature, 198 | ) 199 | pred = model.chat( 200 | tokenizer=tokenizer, 201 | pixel_values=pixel_values, 202 | question=questions[0], 203 | generation_config=generation_config 204 | ) 205 | original_pred = pred 206 | if len(options[0]) == 0: 207 | preds = [pred] 208 | else: 209 | preds = [post_process(pred, options[0])] 210 | 211 | for question, pred, answer, data_id in zip(questions, preds, answers, data_ids): 212 | outputs.append({ 213 | 'question': question, 214 | 'answer': pred, 215 | 'answer_original': original_pred, 216 | 'gt_answers': answer, 217 | 'data_id': data_id 218 | }) 219 | 220 | torch.distributed.barrier() 221 | 222 | world_size = torch.distributed.get_world_size() 223 | merged_outputs = [None for _ in range(world_size)] 224 | torch.distributed.all_gather_object(merged_outputs, json.dumps(outputs)) 225 | 226 | merged_outputs = [json.loads(_) for _ in merged_outputs] 227 | merged_outputs = [_ for _ in itertools.chain.from_iterable(merged_outputs)] 228 | 229 | if torch.distributed.get_rank() == 0: 230 | 231 | print(f'Evaluating {ds_name} ...') 232 | time_prefix = time.strftime('%y%m%d%H%M%S', time.localtime()) 233 | results_file = f'{ds_name}_{time_prefix}.json' 234 | output_path = os.path.join(args.out_dir, results_file) 235 | outputs = {} 236 | for item in merged_outputs: 237 | outputs[item['data_id']] = item['answer'] 238 | with open(output_path, 'w') as f: 239 | json.dump(outputs, f, indent=4) 240 | print('Results saved to {}'.format(output_path)) 241 | if ds_collections[ds_name]['split'] == 'validation': 242 | print('Evaluating ...') 243 | cmd = f'python eval/mmmu/main_eval_only.py ' \ 244 | f'--output_path {output_path} ' \ 245 | f'--answer_path eval/mmmu/answer_dict_val.json' 246 | print(cmd) 247 | os.system(cmd) 248 | time_prefix = time.strftime('%y%m%d%H%M%S', time.localtime()) 249 | results_file = f'{ds_name}_{time_prefix}.jsonl' 250 | output_path = os.path.join(args.out_dir, results_file) 251 | writer = open(output_path, 'w') 252 | for item in merged_outputs: 253 | writer.write(json.dumps(item) + '\n') 254 | writer.close() 255 | print('Results saved to {}'.format(output_path)) 256 | 257 | 258 | if __name__ == '__main__': 259 | parser = argparse.ArgumentParser() 260 | parser.add_argument('--checkpoint', type=str, default='') 261 | parser.add_argument('--datasets', type=str, default='MMMU_dev') 262 | parser.add_argument('--batch-size', type=int, default=1) 263 | 
parser.add_argument('--num-workers', type=int, default=1) 264 | parser.add_argument('--num-beams', type=int, default=5) 265 | parser.add_argument('--temperature', type=float, default=0.0) 266 | parser.add_argument('--out-dir', type=str, default='results') 267 | parser.add_argument('--seed', type=int, default=0) 268 | parser.add_argument('--dynamic', action='store_true') 269 | parser.add_argument('--max-num', type=int, default=6) 270 | parser.add_argument('--load-in-8bit', action='store_true') 271 | parser.add_argument('--auto', action='store_true') 272 | args = parser.parse_args() 273 | 274 | if not os.path.exists(args.out_dir): 275 | os.makedirs(args.out_dir) 276 | 277 | args.datasets = args.datasets.split(',') 278 | print('datasets:', args.datasets) 279 | assert args.batch_size == 1, 'Only batch size 1 is supported' 280 | 281 | torch.distributed.init_process_group( 282 | backend='nccl', 283 | world_size=int(os.getenv('WORLD_SIZE', '1')), 284 | rank=int(os.getenv('RANK', '0')), 285 | ) 286 | 287 | torch.cuda.set_device(int(os.getenv('LOCAL_RANK', 0))) 288 | 289 | if args.auto: 290 | os.environ['CUDA_LAUNCH_BLOCKING'] = '1' 291 | kwargs = {'device_map': 'auto'} if args.auto else {} 292 | tokenizer = AutoTokenizer.from_pretrained(args.checkpoint, trust_remote_code=True, use_fast=False) 293 | model = AutoModel.from_pretrained( 294 | args.checkpoint, trust_remote_code=True, low_cpu_mem_usage=True, torch_dtype=torch.bfloat16, 295 | load_in_8bit=args.load_in_8bit, **kwargs).eval() 296 | if not args.load_in_8bit and not args.auto: 297 | model = model.cuda() 298 | image_size = model.config.force_image_size or model.config.vision_config.image_size 299 | use_thumbnail = model.config.use_thumbnail 300 | 301 | total_params = sum(p.numel() for p in model.parameters()) / 1e9 302 | if total_params > 20 or args.dynamic: 303 | args.num_beams = 1 304 | print(f'[test] total_params: {total_params}B, use num_beams: {args.num_beams}') 305 | else: 306 | print(f'[test] total_params: {total_params}B') 307 | print(f'[test] image_size: {image_size}') 308 | print(f'[test] template: {model.config.template}') 309 | print(f'[test] dynamic_image_size: {args.dynamic}') 310 | print(f'[test] use_thumbnail: {use_thumbnail}') 311 | print(f'[test] max_num: {args.max_num}') 312 | 313 | evaluate_chat_model() 314 | -------------------------------------------------------------------------------- /eval/mmmu/main_eval_only.py: -------------------------------------------------------------------------------- 1 | """Parse and Evalate""" 2 | import json 3 | from argparse import ArgumentParser 4 | 5 | from data_utils import CAT_SHORT2LONG, DOMAIN_CAT2SUB_CAT, save_json 6 | from eval_utils import (calculate_ins_level_acc, evaluate, 7 | parse_multi_choice_response, parse_open_response) 8 | 9 | if __name__ == '__main__': 10 | 11 | parser = ArgumentParser() 12 | parser.add_argument('--output_path', type=str, default='./example_outputs/qwen_vl/total_val_output.json', 13 | help='The path to model output file.') 14 | parser.add_argument('--answer_path', type=str, default='./answer_dict_val.json', help='Answer file path.') 15 | args = parser.parse_args() 16 | 17 | output_dict = json.load(open(args.output_path)) 18 | answer_dict = json.load(open(args.answer_path)) 19 | 20 | # group by category 21 | output_dict_w_cat = {} 22 | for data_id, parsed_pred in output_dict.items(): 23 | category = '_'.join(data_id.split('_')[1:-1]) 24 | if category not in output_dict_w_cat: 25 | output_dict_w_cat.update({category: {}}) 26 | 
output_dict_w_cat[category].update({data_id: parsed_pred}) 27 | 28 | # group by category 29 | answer_dict_w_cat = {} 30 | for data_id, parsed_pred in answer_dict.items(): 31 | category = '_'.join(data_id.split('_')[1:-1]) 32 | if category not in answer_dict_w_cat: 33 | answer_dict_w_cat.update({category: {}}) 34 | answer_dict_w_cat[category].update({data_id: parsed_pred}) 35 | 36 | evaluation_result = {} 37 | 38 | for category in CAT_SHORT2LONG.values(): 39 | print('Evaluating: {}'.format(category)) 40 | # get cat_outputs and cat_answers 41 | try: 42 | cat_outputs = output_dict_w_cat[category] 43 | cat_answers = answer_dict_w_cat[category] 44 | except KeyError: 45 | print('Skipping {} for not found'.format(category)) 46 | continue 47 | 48 | exampels_to_eval = [] 49 | for data_id, parsed_pred in cat_outputs.items(): 50 | question_type = cat_answers[data_id]['question_type'] 51 | if question_type != 'multiple-choice': 52 | parsed_pred = parse_open_response(parsed_pred) # mainly for type consistency (make it number, etc.) 53 | else: 54 | parsed_pred = parsed_pred 55 | 56 | exampels_to_eval.append({ 57 | 'id': data_id, 58 | 'question_type': question_type, 59 | 'answer': cat_answers[data_id]['ground_truth'], 60 | 'parsed_pred': parsed_pred 61 | }) 62 | 63 | judge_dict, metric_dict = evaluate(exampels_to_eval) 64 | metric_dict.update({'num_example': len(exampels_to_eval)}) 65 | 66 | evaluation_result[category] = metric_dict 67 | 68 | printable_results = {} 69 | # pdb.set_trace() 70 | # add domain Subject 71 | for domain, in_domain_cats in DOMAIN_CAT2SUB_CAT.items(): 72 | in_domain_cat_results = {} 73 | for cat_name in in_domain_cats: # use the order in DOMAIN_CAT2SUB_CAT 74 | if cat_name in evaluation_result.keys(): 75 | in_domain_cat_results[cat_name] = evaluation_result[cat_name] 76 | else: 77 | pass 78 | in_domain_ins_acc = calculate_ins_level_acc(in_domain_cat_results) 79 | in_domain_data_num = sum([cat_results['num_example'] for cat_results in in_domain_cat_results.values()]) 80 | printable_results['Overall-' + domain] = {'num': int(in_domain_data_num), 81 | 'acc': round(in_domain_ins_acc, 3) 82 | } 83 | # add sub category 84 | for cat_name, cat_results in in_domain_cat_results.items(): 85 | printable_results[cat_name] = {'num': int(cat_results['num_example']), 86 | 'acc': round(cat_results['acc'], 3) 87 | } 88 | 89 | # table.append(["-----------------------------", "-----", "----"]) 90 | all_ins_acc = calculate_ins_level_acc(evaluation_result) 91 | printable_results['Overall'] = { 92 | 'num': sum([cat_results['num_example'] for cat_results in evaluation_result.values()]), 93 | 'acc': round(all_ins_acc, 3)} 94 | 95 | print(printable_results) 96 | -------------------------------------------------------------------------------- /eval/mvbench/evaluate_mvbench.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import itertools 3 | import json 4 | import os 5 | import random 6 | import time 7 | from functools import partial 8 | 9 | import cv2 10 | import imageio 11 | import numpy as np 12 | import torch 13 | from decord import VideoReader, cpu 14 | from utils.preprocess import build_transform, dynamic_preprocess 15 | from PIL import Image 16 | from torch.utils.data import Dataset 17 | from tqdm import tqdm 18 | from transformers import AutoTokenizer, AutoModel 19 | 20 | data_list = { 21 | 'Action Sequence': ('action_sequence.json', './data/MVBench/video/star/Charades_v1_480/', 'video', True), 22 | # has start & end 23 | 'Action 
Prediction': ('action_prediction.json', './data/MVBench/video/star/Charades_v1_480/', 'video', True), 24 | # has start & end 25 | 'Action Antonym': ('action_antonym.json', './data/MVBench/video/ssv2_video/', 'video', False), 26 | 'Fine-grained Action': ( 27 | 'fine_grained_action.json', './data/MVBench/video/Moments_in_Time_Raw/videos/', 'video', False), 28 | 'Unexpected Action': ('unexpected_action.json', './data/MVBench/video/FunQA_test/test/', 'video', False), 29 | 'Object Existence': ('object_existence.json', './data/MVBench/video/clevrer/video_validation/', 'video', False), 30 | 'Object Interaction': ('object_interaction.json', './data/MVBench/video/star/Charades_v1_480/', 'video', True), 31 | # has start & end 32 | 'Object Shuffle': ('object_shuffle.json', './data/MVBench/video/perception/videos/', 'video', False), 33 | 'Moving Direction': ('moving_direction.json', './data/MVBench/video/clevrer/video_validation/', 'video', False), 34 | 'Action Localization': ('action_localization.json', './data/MVBench/video/sta/sta_video/', 'video', True), 35 | # has start & end 36 | 'Scene Transition': ('scene_transition.json', './data/MVBench/video/scene_qa/video/', 'video', False), 37 | 'Action Count': ('action_count.json', './data/MVBench/video/perception/videos/', 'video', False), 38 | 'Moving Count': ('moving_count.json', './data/MVBench/video/clevrer/video_validation/', 'video', False), 39 | 'Moving Attribute': ('moving_attribute.json', './data/MVBench/video/clevrer/video_validation/', 'video', False), 40 | 'State Change': ('state_change.json', './data/MVBench/video/perception/videos/', 'video', False), 41 | 'Fine-grained Pose': ('fine_grained_pose.json', './data/MVBench/video/nturgbd/', 'video', False), 42 | 'Character Order': ('character_order.json', './data/MVBench/video/perception/videos/', 'video', False), 43 | 'Egocentric Navigation': ('egocentric_navigation.json', './data/MVBench/video/vlnqa/', 'video', False), 44 | 'Episodic Reasoning': ('episodic_reasoning.json', './data/MVBench/video/tvqa/frames_fps3_hq/', 'frame', True), 45 | # has start & end, read frame 46 | 'Counterfactual Inference': ( 47 | 'counterfactual_inference.json', './data/MVBench/video/clevrer/video_validation/', 'video', False), 48 | } 49 | 50 | data_dir = './data/MVBench/json' 51 | 52 | 53 | def collate_fn(batches, tokenizer): 54 | pixel_values = torch.cat([_['pixel_values'] for _ in batches], dim=0) 55 | split_sizes = torch.cat([_['split_sizes'] for _ in batches]) 56 | data_flag = torch.cat([_['data_flag'] for _ in batches]) 57 | questions = [_['question'] for _ in batches] 58 | answers = [_['answer'] for _ in batches] 59 | num_patches_lists = [_['num_patches_list'] for _ in batches] 60 | task_types = [_['task_type'] for _ in batches] 61 | return pixel_values, split_sizes, data_flag, questions, answers, num_patches_lists, task_types 62 | 63 | 64 | class MVBenchDataset(torch.utils.data.Dataset): 65 | 66 | def __init__(self, data_dir, data_list, prompt, question_prompt, num_segments=16, input_size=224, 67 | dynamic_image_size=False, use_thumbnail=False, max_num=6): 68 | self.data_list = [] 69 | for k, v in data_list.items(): 70 | with open(os.path.join(data_dir, v[0]), 'r') as f: 71 | json_data = json.load(f) 72 | for data in json_data: 73 | self.data_list.append({ 74 | 'task_type': k, 75 | 'prefix': v[1], 76 | 'data_type': v[2], 77 | 'bound': v[3], 78 | 'data': data 79 | }) 80 | self.decord_method = { 81 | 'video': self.read_video, 82 | 'gif': self.read_gif, 83 | 'frame': self.read_frame, 84 | } 85 | self.prompt = 
prompt 86 | self.question_prompt = question_prompt 87 | self.input_size = input_size 88 | self.num_segments = num_segments 89 | self.dynamic_image_size = dynamic_image_size 90 | self.use_thumbnail = use_thumbnail 91 | self.max_num = max_num 92 | self.transform = build_transform(input_size=input_size) 93 | 94 | def __len__(self): 95 | return len(self.data_list) 96 | 97 | def __str__(self): 98 | len_list = {} 99 | option_list = {} 100 | for data in self.data_list: 101 | if data['task_type'] not in len_list: 102 | len_list[data['task_type']] = 0 103 | len_list[data['task_type']] += 1 104 | if data['task_type'] not in option_list: 105 | option_list[data['task_type']] = 0 106 | option_list[data['task_type']] += len(data['data']['candidates']) 107 | 108 | correct = 0 109 | total = 0 110 | res = f'There are {len(self.data_list)} videos as follow:\n' 111 | for k, v in len_list.items(): 112 | correct += len_list[k] 113 | total += option_list[k] 114 | res += f'{v} for {k} ({option_list[k]} options => {len_list[k] / option_list[k] * 100:.2f}%)\n' 115 | correct = correct + 1 / option_list[k] 116 | res += f'Total random accuracy: {correct / total * 100:.2f}%' 117 | return res.rstrip() 118 | 119 | def get_index(self, bound, fps, max_frame, first_idx=0): 120 | if bound: 121 | start, end = bound[0], bound[1] 122 | else: 123 | start, end = -100000, 100000 124 | start_idx = max(first_idx, round(start * fps)) 125 | end_idx = min(round(end * fps), max_frame) 126 | seg_size = float(end_idx - start_idx) / self.num_segments 127 | frame_indices = np.array([ 128 | int(start_idx + (seg_size / 2) + np.round(seg_size * idx)) 129 | for idx in range(self.num_segments) 130 | ]) 131 | return frame_indices 132 | 133 | def read_video(self, video_path, bound=None): 134 | vr = VideoReader(video_path, ctx=cpu(0), num_threads=1) 135 | max_frame = len(vr) - 1 136 | fps = float(vr.get_avg_fps()) 137 | 138 | images_group = list() 139 | frame_indices = self.get_index(bound, fps, max_frame, first_idx=0) 140 | for frame_index in frame_indices: 141 | img = Image.fromarray(vr[frame_index].asnumpy()) 142 | images_group.append(img) 143 | 144 | return images_group 145 | 146 | def read_gif(self, video_path, bound=None, fps=25): 147 | gif = imageio.get_reader(video_path) 148 | max_frame = len(gif) - 1 149 | 150 | images_group = list() 151 | frame_indices = self.get_index(bound, fps, max_frame, first_idx=0) 152 | for index, frame in enumerate(gif): 153 | if index in frame_indices: 154 | img = cv2.cvtColor(frame, cv2.COLOR_RGBA2RGB) 155 | img = Image.fromarray(img) 156 | images_group.append(img) 157 | 158 | return images_group 159 | 160 | def read_frame(self, video_path, bound=None, fps=3): 161 | max_frame = len(os.listdir(video_path)) 162 | images_group = list() 163 | frame_indices = self.get_index(bound, fps, max_frame, first_idx=1) # frame_idx starts from 1 164 | for frame_index in frame_indices: 165 | img = Image.open(os.path.join(video_path, f'{frame_index:05d}.jpg')) 166 | images_group.append(img) 167 | 168 | return images_group 169 | 170 | def qa_template(self, data): 171 | question = f"Question: {data['question']}\n" 172 | question += 'Options:\n' 173 | answer = data['answer'] 174 | answer_idx = -1 175 | for idx, c in enumerate(data['candidates']): 176 | question += f"({chr(ord('A') + idx)}) {c}\n" 177 | if c == answer: 178 | answer_idx = idx 179 | question = question.rstrip() 180 | answer = f"({chr(ord('A') + answer_idx)}) {answer}" 181 | return question, answer 182 | 183 | def __getitem__(self, idx): 184 | decord_method = 
self.decord_method[self.data_list[idx]['data_type']] 185 | bound = None 186 | if self.data_list[idx]['bound']: 187 | bound = ( 188 | self.data_list[idx]['data']['start'], 189 | self.data_list[idx]['data']['end'], 190 | ) 191 | video_path = os.path.join(self.data_list[idx]['prefix'], self.data_list[idx]['data']['video']) 192 | image_list = decord_method(video_path, bound) 193 | special_tokens = '\n'.join(['Frame{}:'.format(i + 1) for i in range(len(image_list))]) 194 | 195 | question, answer = self.qa_template(self.data_list[idx]['data']) 196 | question = special_tokens + '\n' + self.prompt + '\n' + question + self.question_prompt 197 | 198 | raw_images = [] 199 | num_patches_list = [] 200 | pixel_values = [] 201 | for image in image_list: 202 | raw_images.append(image) 203 | if self.dynamic_image_size: 204 | patches = dynamic_preprocess(image, image_size=self.input_size, 205 | use_thumbnail=self.use_thumbnail, 206 | max_num=self.max_num) 207 | else: 208 | patches = [image] 209 | num_patches_list.append(len(patches)) 210 | pixel_values.extend([self.transform(patch) for patch in patches]) 211 | 212 | pixel_values = torch.stack(pixel_values) 213 | 214 | return { 215 | 'question': question, 216 | 'pixel_values': pixel_values, 217 | 'split_sizes': torch.LongTensor((pixel_values.shape[0], )), 218 | 'data_flag': torch.LongTensor((3, )), 219 | 'answer': answer, 220 | 'num_patches_list': num_patches_list, 221 | 'task_type': self.data_list[idx]['task_type'] 222 | } 223 | 224 | 225 | class InferenceSampler(torch.utils.data.sampler.Sampler): 226 | 227 | def __init__(self, size): 228 | self._size = int(size) 229 | assert size > 0 230 | self._rank = torch.distributed.get_rank() 231 | self._world_size = torch.distributed.get_world_size() 232 | self._local_indices = self._get_local_indices(size, self._world_size, self._rank) 233 | 234 | @staticmethod 235 | def _get_local_indices(total_size, world_size, rank): 236 | shard_size = total_size // world_size 237 | left = total_size % world_size 238 | shard_sizes = [shard_size + int(r < left) for r in range(world_size)] 239 | 240 | begin = sum(shard_sizes[:rank]) 241 | end = min(sum(shard_sizes[:rank + 1]), total_size) 242 | return range(begin, end) 243 | 244 | def __iter__(self): 245 | yield from self._local_indices 246 | 247 | def __len__(self): 248 | return len(self._local_indices) 249 | 250 | 251 | def check_ans(pred, gt): 252 | flag = False 253 | pred = pred.replace('Answer: ', '') 254 | 255 | pred_list = pred.lower().split(' ') 256 | pred_option, pred_content = pred_list[0], ' '.join(pred_list[1:]) 257 | gt_list = gt.lower().split(' ') 258 | gt_option, gt_content = gt_list[0], ' '.join(gt_list[1:]) 259 | if gt_content[-1] == '.': 260 | gt_content = gt_content[:-1] 261 | 262 | if pred_option.replace('.', '') in gt_option: 263 | flag = True 264 | elif gt_option in pred_option: 265 | flag = True 266 | 267 | return flag 268 | 269 | 270 | def evaluate_chat_model(): 271 | random.seed(args.seed) 272 | prompt = 'Carefully watch the video and pay attention to the cause and sequence of events, the detail and movement of objects, and the action and pose of persons. Based on your observations, select the best option that accurately addresses the question.\n' 273 | question_prompt = '\nOnly give the best option.' 
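# [Editor's note: illustrative sketch, not part of the original file]
# check_ans above compares only the leading option token of the prediction and
# the ground truth (the ground truth is formatted as "(X) text" by qa_template).
# A self-contained copy of that comparison, applied to hypothetical strings:
def _check_ans(pred, gt):
    pred = pred.replace('Answer: ', '')
    pred_option = pred.lower().split(' ')[0]
    gt_option = gt.lower().split(' ')[0]
    return pred_option.replace('.', '') in gt_option or gt_option in pred_option

assert _check_ans('(C) The person sat down.', '(C) Sitting down.')
assert not _check_ans('(A) Standing up.', '(C) Sitting down.')
# [End of editor's note]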
274 | 275 | vid_dataset = MVBenchDataset( 276 | data_dir, data_list, 277 | prompt=prompt, 278 | question_prompt=question_prompt, 279 | num_segments=args.num_segments, 280 | input_size=image_size, 281 | dynamic_image_size=args.dynamic, 282 | use_thumbnail=use_thumbnail, 283 | max_num=args.max_num) 284 | dataloader = torch.utils.data.DataLoader( 285 | dataset=vid_dataset, 286 | sampler=InferenceSampler(len(vid_dataset)), 287 | batch_size=args.batch_size, 288 | num_workers=args.num_workers, 289 | pin_memory=True, 290 | drop_last=False, 291 | collate_fn=partial(collate_fn, tokenizer=tokenizer), 292 | ) 293 | 294 | outputs = [] 295 | for _, (pixel_values, split_sizes, data_flag, questions, answers, num_patches_lists, task_types) in tqdm(enumerate(dataloader)): 296 | pixel_values = pixel_values.to(torch.bfloat16).cuda() 297 | generation_config = dict( 298 | num_beams=args.num_beams, 299 | max_new_tokens=1000, 300 | min_new_tokens=1, 301 | do_sample=True if args.temperature > 0 else False, 302 | temperature=args.temperature, 303 | ) 304 | pred = model.chat( 305 | tokenizer=tokenizer, 306 | pixel_values=pixel_values, 307 | split_sizes=split_sizes, 308 | data_flag=data_flag, 309 | num_patches_list=num_patches_lists[0], 310 | question=questions[0], 311 | generation_config=generation_config 312 | ) 313 | outputs.append({ 314 | 'question': questions[0], 315 | 'pred': pred, 316 | 'gt': answers[0], 317 | 'task_type': task_types[0], 318 | }) 319 | torch.distributed.barrier() 320 | 321 | world_size = torch.distributed.get_world_size() 322 | merged_outputs = [None for _ in range(world_size)] 323 | torch.distributed.all_gather_object(merged_outputs, json.dumps(outputs)) 324 | 325 | merged_outputs = [json.loads(_) for _ in merged_outputs] 326 | merged_outputs = [_ for _ in itertools.chain.from_iterable(merged_outputs)] 327 | 328 | if torch.distributed.get_rank() == 0: 329 | 330 | print(f'Evaluating MVBench ...') 331 | correct, total, acc_dict = 0, 0, {} 332 | for item in merged_outputs: 333 | task_type = item['task_type'] 334 | pred = item['pred'] 335 | gt = item['gt'] 336 | if task_type not in acc_dict: 337 | acc_dict[task_type] = [0, 0] # correct, total 338 | acc_dict[task_type][1] += 1 339 | total += 1 340 | 341 | if check_ans(pred, gt): 342 | acc_dict[task_type][0] += 1 343 | correct += 1 344 | 345 | final_res = {} 346 | for k, v in acc_dict.items(): 347 | final_res[k] = v[0] / v[1] * 100 348 | final_res['Avg'] = correct / total * 100 349 | print(final_res) 350 | 351 | time_prefix = time.strftime('%y%m%d%H%M%S', time.localtime()) 352 | results_file = f'MVBench_{time_prefix}' 353 | output_path = os.path.join(args.out_dir, results_file) 354 | with open(f'{output_path}.json', 'w') as f: 355 | json.dump(outputs, f) 356 | with open(f'{output_path}_result_final.json', 'w') as f: 357 | json.dump(final_res, f) 358 | print('Results saved to {}'.format(output_path)) 359 | 360 | 361 | if __name__ == '__main__': 362 | parser = argparse.ArgumentParser() 363 | parser.add_argument('--checkpoint', type=str, default='') 364 | parser.add_argument('--datasets', type=str, default='mvbench') 365 | parser.add_argument('--batch-size', type=int, default=1) 366 | parser.add_argument('--num-workers', type=int, default=1) 367 | parser.add_argument('--num-beams', type=int, default=5) 368 | parser.add_argument('--temperature', type=float, default=0.0) 369 | parser.add_argument('--out-dir', type=str, default='results') 370 | parser.add_argument('--seed', type=int, default=0) 371 | parser.add_argument('--dynamic', action='store_true') 
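# [Editor's note: illustrative sketch, not part of the original file]
# In the rank-0 block above, acc_dict keeps [correct, total] per task type and
# 'Avg' is the micro-average over all samples (correct / total), not the mean
# of the per-task percentages. A standalone rerun on hypothetical counts:
acc_dict = {'Action Count': [1, 2], 'Scene Transition': [3, 3]}   # [correct, total]
final_res = {k: v[0] / v[1] * 100 for k, v in acc_dict.items()}
final_res['Avg'] = sum(v[0] for v in acc_dict.values()) / sum(v[1] for v in acc_dict.values()) * 100
print(final_res)   # {'Action Count': 50.0, 'Scene Transition': 100.0, 'Avg': 80.0}
# [End of editor's note]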
372 | parser.add_argument('--max-num', type=int, default=6) 373 | parser.add_argument('--load-in-8bit', action='store_true') 374 | parser.add_argument('--auto', action='store_true') 375 | parser.add_argument('--num_segments', type=int, default=16) 376 | args = parser.parse_args() 377 | 378 | if not os.path.exists(args.out_dir): 379 | os.makedirs(args.out_dir, exist_ok=True) 380 | 381 | args.datasets = args.datasets.split(',') 382 | print('datasets:', args.datasets) 383 | assert args.batch_size == 1, 'Only batch size 1 is supported' 384 | 385 | torch.distributed.init_process_group( 386 | backend='nccl', 387 | world_size=int(os.getenv('WORLD_SIZE', '1')), 388 | rank=int(os.getenv('RANK', '0')), 389 | ) 390 | 391 | torch.cuda.set_device(int(os.getenv('LOCAL_RANK', 0))) 392 | 393 | if args.auto: 394 | os.environ['CUDA_LAUNCH_BLOCKING'] = '1' 395 | kwargs = {'device_map': 'auto'} if args.auto else {} 396 | tokenizer = AutoTokenizer.from_pretrained(args.checkpoint, trust_remote_code=True, use_fast=False) 397 | model = AutoModel.from_pretrained( 398 | args.checkpoint, trust_remote_code=True, low_cpu_mem_usage=True, torch_dtype=torch.bfloat16, 399 | load_in_8bit=args.load_in_8bit, **kwargs).eval() 400 | if not args.load_in_8bit and not args.auto: 401 | model = model.cuda() 402 | image_size = model.config.force_image_size or model.config.vision_config.image_size 403 | use_thumbnail = model.config.use_thumbnail 404 | 405 | total_params = sum(p.numel() for p in model.parameters()) / 1e9 406 | if total_params > 20 or args.dynamic: 407 | args.num_beams = 1 408 | print(f'[test] total_params: {total_params}B, use num_beams: {args.num_beams}') 409 | else: 410 | print(f'[test] total_params: {total_params}B') 411 | print(f'[test] image_size: {image_size}') 412 | print(f'[test] template: {model.config.template}') 413 | print(f'[test] dynamic_image_size: {args.dynamic}') 414 | print(f'[test] use_thumbnail: {use_thumbnail}') 415 | print(f'[test] max_num: {args.max_num}') 416 | 417 | evaluate_chat_model() 418 | -------------------------------------------------------------------------------- /eval/scienceqa/evaluate_scienceqa.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import itertools 3 | import json 4 | import os 5 | import random 6 | import time 7 | from functools import partial 8 | 9 | import torch 10 | from utils.preprocess import build_transform, dynamic_preprocess 11 | from PIL import Image 12 | from torch.utils.data import Dataset 13 | from tqdm import tqdm 14 | from transformers import AutoTokenizer, AutoModel 15 | 16 | ds_collections = { 17 | 'sqa_test': { 18 | 'root': 'data/scienceqa/scienceqa_test_img.jsonl', 19 | 'max_new_tokens': 100, 20 | 'min_new_tokens': 1, 21 | }, 22 | } 23 | 24 | 25 | def collate_fn(batches, tokenizer): 26 | pixel_values = torch.cat([_['pixel_values'] for _ in batches], dim=0) 27 | questions = [_['question'] for _ in batches] 28 | answers = [_['answer'] for _ in batches] 29 | image_paths = [_['image_path'] for _ in batches] 30 | options = [_['option'] for _ in batches] 31 | return pixel_values, questions, answers, image_paths, options 32 | 33 | 34 | class ScienceQADataset(torch.utils.data.Dataset): 35 | 36 | def __init__(self, root, prompt, input_size=224, dynamic_image_size=False, 37 | use_thumbnail=False, max_num=6): 38 | f = open(root, 'r', encoding='utf-8') 39 | self.data = [json.loads(line) for line in f.readlines()] 40 | self.prompt = prompt 41 | self.input_size = input_size 42 | self.dynamic_image_size = 
dynamic_image_size 43 | self.use_thumbnail = use_thumbnail 44 | self.max_num = max_num 45 | self.transform = build_transform(input_size=input_size) 46 | 47 | def __len__(self): 48 | return len(self.data) 49 | 50 | def __getitem__(self, idx): 51 | data = self.data[idx] 52 | image_path = data['image'] 53 | hint = data['hint'] if data['hint'] else None 54 | question = data['question'] 55 | 56 | choices = data['choices'] 57 | answer = data['answer'] 58 | choice_list = [] 59 | 60 | options = {} 61 | multiple_choices = ['A', 'B', 'C', 'D', 'E'] 62 | for i, c in enumerate(choices): 63 | choice_list.append('{}. {}'.format(multiple_choices[i], c)) 64 | options[multiple_choices[i]] = c 65 | choice_txt = '\n'.join(choice_list) 66 | 67 | image = Image.open(image_path).convert('RGB') 68 | if self.dynamic_image_size: 69 | images = dynamic_preprocess(image, image_size=self.input_size, 70 | use_thumbnail=self.use_thumbnail, 71 | max_num=self.max_num) 72 | else: 73 | images = [image] 74 | pixel_values = [self.transform(image) for image in images] 75 | pixel_values = torch.stack(pixel_values) 76 | 77 | if hint is not None: 78 | question = hint + '\n' + question 79 | question += '\n' + choice_txt 80 | question += '\n' + self.prompt 81 | 82 | return { 83 | 'question': question, 84 | 'pixel_values': pixel_values, 85 | 'answer': multiple_choices[answer], 86 | 'image_path': image_path, 87 | 'option': options 88 | } 89 | 90 | 91 | class InferenceSampler(torch.utils.data.sampler.Sampler): 92 | 93 | def __init__(self, size): 94 | self._size = int(size) 95 | assert size > 0 96 | self._rank = torch.distributed.get_rank() 97 | self._world_size = torch.distributed.get_world_size() 98 | self._local_indices = self._get_local_indices(size, self._world_size, self._rank) 99 | 100 | @staticmethod 101 | def _get_local_indices(total_size, world_size, rank): 102 | shard_size = total_size // world_size 103 | left = total_size % world_size 104 | shard_sizes = [shard_size + int(r < left) for r in range(world_size)] 105 | 106 | begin = sum(shard_sizes[:rank]) 107 | end = min(sum(shard_sizes[:rank + 1]), total_size) 108 | return range(begin, end) 109 | 110 | def __iter__(self): 111 | yield from self._local_indices 112 | 113 | def __len__(self): 114 | return len(self._local_indices) 115 | 116 | 117 | def post_process(pred, option): 118 | pred = pred.strip() 119 | option_candidate = list(option.keys()) 120 | if len(pred) == 1: 121 | return pred 122 | elif len(pred) != 1 and pred[0] in option_candidate: 123 | return pred[0] 124 | elif len(pred) != 1 and pred[0] not in option_candidate: 125 | for k, v in option.items(): 126 | if v in pred: 127 | return k 128 | 129 | return pred 130 | 131 | 132 | def evaluate_chat_model(): 133 | prompt = "Answer with the option's letter from the given choices directly." 
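# [Editor's note: illustrative sketch, not part of the original file]
# post_process above maps a free-form reply back to an option letter: a single
# character is returned as-is, a longer reply starting with a valid letter is
# truncated to that letter, and otherwise the option whose text occurs in the
# reply wins. A condensed, self-contained copy of that logic on hypothetical
# options:
def _post_process(pred, option):
    pred = pred.strip()
    if len(pred) == 1:
        return pred
    if pred[0] in option:
        return pred[0]
    for k, v in option.items():
        if v in pred:
            return k
    return pred

options = {'A': 'a rock', 'B': 'a magnet'}
assert _post_process('B', options) == 'B'
assert _post_process('B. a magnet', options) == 'B'
assert _post_process('The object is a magnet.', options) == 'B'
# [End of editor's note]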
134 | random.seed(args.seed) 135 | 136 | for ds_name in args.datasets: 137 | dataset = ScienceQADataset( 138 | root=ds_collections[ds_name]['root'], 139 | prompt=prompt, 140 | input_size=image_size, 141 | dynamic_image_size=args.dynamic, 142 | use_thumbnail=use_thumbnail, 143 | max_num=args.max_num 144 | ) 145 | dataloader = torch.utils.data.DataLoader( 146 | dataset=dataset, 147 | sampler=InferenceSampler(len(dataset)), 148 | batch_size=args.batch_size, 149 | num_workers=args.num_workers, 150 | pin_memory=True, 151 | drop_last=False, 152 | collate_fn=partial(collate_fn, tokenizer=tokenizer), 153 | ) 154 | 155 | outputs = [] 156 | for _, (pixel_values, questions, answers, image_paths, options) in tqdm(enumerate(dataloader)): 157 | pixel_values = pixel_values.to(torch.bfloat16).cuda() 158 | generation_config = dict( 159 | num_beams=args.num_beams, 160 | max_new_tokens=ds_collections[ds_name]['max_new_tokens'], 161 | min_new_tokens=ds_collections[ds_name]['min_new_tokens'], 162 | do_sample=True if args.temperature > 0 else False, 163 | temperature=args.temperature, 164 | ) 165 | pred = model.chat( 166 | tokenizer=tokenizer, 167 | pixel_values=pixel_values, 168 | question=questions[0], 169 | generation_config=generation_config 170 | ) 171 | preds = [post_process(pred, options[0])] 172 | 173 | for question, pred, answer, image_path in zip(questions, preds, answers, image_paths): 174 | outputs.append({ 175 | 'question': question, 176 | 'answer': pred, 177 | 'gt_answers': answer, 178 | 'image_path': image_path 179 | }) 180 | 181 | torch.distributed.barrier() 182 | 183 | world_size = torch.distributed.get_world_size() 184 | merged_outputs = [None for _ in range(world_size)] 185 | torch.distributed.all_gather_object(merged_outputs, json.dumps(outputs)) 186 | 187 | merged_outputs = [json.loads(_) for _ in merged_outputs] 188 | merged_outputs = [_ for _ in itertools.chain.from_iterable(merged_outputs)] 189 | 190 | if torch.distributed.get_rank() == 0: 191 | 192 | print(f'Evaluating {ds_name} ...') 193 | time_prefix = time.strftime('%y%m%d%H%M%S', time.localtime()) 194 | results_file = f'{ds_name}_{time_prefix}.jsonl' 195 | output_path = os.path.join(args.out_dir, results_file) 196 | with open(output_path, 'w') as f: 197 | for output in merged_outputs: 198 | f.write(json.dumps(output) + '\n') 199 | print('Results saved to {}'.format(output_path)) 200 | cnt = 0 201 | for item in merged_outputs: 202 | if item['answer'] == item['gt_answers']: 203 | cnt += 1 204 | print(f'Acc@1: {cnt / len(merged_outputs)}') 205 | 206 | 207 | if __name__ == '__main__': 208 | parser = argparse.ArgumentParser() 209 | parser.add_argument('--checkpoint', type=str, default='') 210 | parser.add_argument('--datasets', type=str, default='sqa_test') 211 | parser.add_argument('--batch-size', type=int, default=1) 212 | parser.add_argument('--num-workers', type=int, default=1) 213 | parser.add_argument('--num-beams', type=int, default=5) 214 | parser.add_argument('--temperature', type=float, default=0.0) 215 | parser.add_argument('--out-dir', type=str, default='results') 216 | parser.add_argument('--seed', type=int, default=0) 217 | parser.add_argument('--dynamic', action='store_true') 218 | parser.add_argument('--max-num', type=int, default=6) 219 | parser.add_argument('--load-in-8bit', action='store_true') 220 | parser.add_argument('--auto', action='store_true') 221 | args = parser.parse_args() 222 | 223 | if not os.path.exists(args.out_dir): 224 | os.makedirs(args.out_dir) 225 | 226 | args.datasets = args.datasets.split(',') 227 
| print('datasets:', args.datasets) 228 | assert args.batch_size == 1, 'Only batch size 1 is supported' 229 | 230 | torch.distributed.init_process_group( 231 | backend='nccl', 232 | world_size=int(os.getenv('WORLD_SIZE', '1')), 233 | rank=int(os.getenv('RANK', '0')), 234 | ) 235 | 236 | torch.cuda.set_device(int(os.getenv('LOCAL_RANK', 0))) 237 | 238 | if args.auto: 239 | os.environ['CUDA_LAUNCH_BLOCKING'] = '1' 240 | kwargs = {'device_map': 'auto'} if args.auto else {} 241 | tokenizer = AutoTokenizer.from_pretrained(args.checkpoint, trust_remote_code=True, use_fast=False) 242 | model = AutoModel.from_pretrained( 243 | args.checkpoint, trust_remote_code=True, low_cpu_mem_usage=True, torch_dtype=torch.bfloat16, 244 | load_in_8bit=args.load_in_8bit, **kwargs).eval() 245 | if not args.load_in_8bit and not args.auto: 246 | model = model.cuda() 247 | image_size = model.config.force_image_size or model.config.vision_config.image_size 248 | use_thumbnail = model.config.use_thumbnail 249 | 250 | total_params = sum(p.numel() for p in model.parameters()) / 1e9 251 | if total_params > 20 or args.dynamic: 252 | args.num_beams = 1 253 | print(f'[test] total_params: {total_params}B, use num_beams: {args.num_beams}') 254 | else: 255 | print(f'[test] total_params: {total_params}B') 256 | print(f'[test] image_size: {image_size}') 257 | print(f'[test] template: {model.config.template}') 258 | print(f'[test] dynamic_image_size: {args.dynamic}') 259 | print(f'[test] use_thumbnail: {use_thumbnail}') 260 | 261 | evaluate_chat_model() 262 | -------------------------------------------------------------------------------- /eval/seed/calculation.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | import os 4 | 5 | argparse = argparse.ArgumentParser() 6 | argparse.add_argument('--image_result_file', type=str, default='') 7 | argparse.add_argument('--anno_path', type=str, default='data/SEED/SEED-Bench.json') 8 | 9 | args = argparse.parse_args() 10 | image_result_file = args.image_result_file 11 | anno_path = args.anno_path 12 | 13 | assert image_result_file.endswith('.jsonl') 14 | 15 | 16 | def is_integer_string(s): 17 | try: 18 | int(s) 19 | return True 20 | except ValueError: 21 | return False 22 | 23 | 24 | def filter_questions(data, task='all'): 25 | if task == 'image': 26 | return [q for q in data if 1 <= q['question_type_id'] <= 9] 27 | elif task == 'video': 28 | return [q for q in data if 10 <= q['question_type_id'] <= 12] 29 | elif task == 'all': 30 | return data 31 | elif is_integer_string(task): 32 | return [q for q in data if q['question_type_id'] == int(task)] 33 | else: 34 | raise ValueError(f'Invalid task: {task}') 35 | 36 | 37 | if __name__ == '__main__': 38 | 39 | qa_anno = json.load(open(anno_path, 'rb')) 40 | if 'questions' in qa_anno.keys(): 41 | question_type = qa_anno['question_type'] 42 | question_id_type = {v: k for k, v in question_type.items()} 43 | qa_anno = qa_anno['questions'] 44 | 45 | qa_anno = filter_questions(qa_anno, 'all') 46 | print(f'length: {len(qa_anno)}') 47 | 48 | with open(image_result_file, 'r') as f: 49 | 50 | image_result = [json.loads(line) for line in f.readlines()] 51 | 52 | results = [] 53 | 54 | results.extend(image_result) 55 | 56 | qa_id_anno = {} 57 | for item in qa_anno: 58 | question_id = str(item['question_id']) 59 | qa_id_anno[question_id] = item 60 | 61 | type_counts = {k: [] for k, v in question_id_type.items()} 62 | 63 | for item in results: 64 | pred, gt, question_id = 
item['prediction'], item['answer'], item['question_id'] 65 | question_id = str(question_id) 66 | question_type = qa_id_anno[question_id]['question_type_id'] 67 | data_type = qa_id_anno[question_id]['data_type'] 68 | gt = qa_id_anno[question_id]['answer'] 69 | if len(pred) != 1: 70 | pred = pred[0] 71 | if pred == gt: 72 | type_counts[question_type].append(1) 73 | else: 74 | type_counts[question_type].append(0) 75 | 76 | print('Accuracy for each data type:') 77 | total_count, image_count, video_count = 0, 0, 0 78 | total_correct, image_correct, video_correct = 0, 0, 0 79 | for data_type_id, result in type_counts.items(): 80 | accuracy = sum(result) / len(result) * 100 81 | data_type = question_id_type[data_type_id] 82 | print(f'Data type {data_type}: {accuracy:.2f}%') 83 | 84 | total_count += len(result) 85 | total_correct += sum(result) 86 | if data_type_id >= 1 and data_type_id <= 9: 87 | image_count += len(result) 88 | image_correct += sum(result) 89 | else: 90 | video_count += len(result) 91 | video_correct += sum(result) 92 | 93 | total_accuracy = total_correct / total_count * 100 94 | image_accuracy = image_correct / image_count * 100 95 | video_accuracy = video_correct / video_count * 100 96 | 97 | print(f'Total accuracy: {total_accuracy:.2f}%') 98 | print(f'Image accuracy: {image_accuracy:.2f}%') 99 | print(f'Video accuracy: {video_accuracy:.2f}%') 100 | -------------------------------------------------------------------------------- /eval/seed/evaluate_seed.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import itertools 3 | import json 4 | import os 5 | import random 6 | import time 7 | from functools import partial 8 | 9 | import torch 10 | from utils.preprocess import build_transform, dynamic_preprocess 11 | from PIL import Image 12 | from torch.utils.data import Dataset 13 | from tqdm import tqdm 14 | from transformers import AutoTokenizer, AutoModel 15 | 16 | ds_collections = { 17 | 'SEEDv1': { 18 | 'root': 'data/SEED/', 19 | 'annotation': 'data/SEED/seed.jsonl', 20 | 'max_new_tokens': 100, 21 | 'min_new_tokens': 1, 22 | }, 23 | } 24 | 25 | 26 | def collate_fn(batches, tokenizer): 27 | pixel_values = torch.cat([_['pixel_values'] for _ in batches], dim=0) 28 | questions = [_['question'] for _ in batches] 29 | answers = [_['answer'] for _ in batches] 30 | indexes = [_['index'] for _ in batches] 31 | return pixel_values, questions, answers, indexes 32 | 33 | 34 | class MultipleChoiceDataset(torch.utils.data.Dataset): 35 | 36 | def __init__(self, root, annotation, input_size=224, dynamic_image_size=False, 37 | use_thumbnail=False, max_num=6): 38 | f = open(annotation, 'r', encoding='utf-8') 39 | self.data = [json.loads(line) for line in f.readlines()] 40 | self.root = root 41 | self.input_size = input_size 42 | self.dynamic_image_size = dynamic_image_size 43 | self.use_thumbnail = use_thumbnail 44 | self.max_num = max_num 45 | self.transform = build_transform(input_size=input_size) 46 | 47 | def __len__(self): 48 | return len(self.data) 49 | 50 | def __getitem__(self, idx): 51 | data = self.data[idx] 52 | question = data['text'] 53 | image_path = os.path.join(self.root, data['image']) 54 | image = Image.open(image_path).convert('RGB') 55 | if self.dynamic_image_size: 56 | images = dynamic_preprocess(image, image_size=self.input_size, 57 | use_thumbnail=self.use_thumbnail, 58 | max_num=self.max_num) 59 | else: 60 | images = [image] 61 | pixel_values = [self.transform(image) for image in images] 62 | pixel_values = 
torch.stack(pixel_values) 63 | answer = data['answer'] if 'answer' in data else None 64 | return { 65 | 'question': question, 66 | 'pixel_values': pixel_values, 67 | 'answer': answer, 68 | 'index': data['question_id'], 69 | } 70 | 71 | 72 | class InferenceSampler(torch.utils.data.sampler.Sampler): 73 | 74 | def __init__(self, size): 75 | self._size = int(size) 76 | assert size > 0 77 | self._rank = torch.distributed.get_rank() 78 | self._world_size = torch.distributed.get_world_size() 79 | self._local_indices = self._get_local_indices(size, self._world_size, self._rank) 80 | 81 | @staticmethod 82 | def _get_local_indices(total_size, world_size, rank): 83 | shard_size = total_size // world_size 84 | left = total_size % world_size 85 | shard_sizes = [shard_size + int(r < left) for r in range(world_size)] 86 | 87 | begin = sum(shard_sizes[:rank]) 88 | end = min(sum(shard_sizes[:rank + 1]), total_size) 89 | return range(begin, end) 90 | 91 | def __iter__(self): 92 | yield from self._local_indices 93 | 94 | def __len__(self): 95 | return len(self._local_indices) 96 | 97 | 98 | def post_process(pred, option): 99 | pred = pred.strip() 100 | option_candidate = list(option.keys()) 101 | if len(pred) == 1: 102 | return pred 103 | elif len(pred) != 1 and pred[0] in option_candidate: 104 | return pred[0] 105 | elif len(pred) != 1 and pred[0] not in option_candidate: 106 | for k, v in option.items(): 107 | if v in pred: 108 | return k 109 | 110 | return pred 111 | 112 | 113 | def evaluate_chat_model(): 114 | random.seed(args.seed) 115 | 116 | for ds_name in args.datasets: 117 | dataset = MultipleChoiceDataset( 118 | root=ds_collections[ds_name]['root'], 119 | annotation=ds_collections[ds_name]['annotation'], 120 | input_size=image_size, 121 | dynamic_image_size=args.dynamic, 122 | use_thumbnail=use_thumbnail, 123 | max_num=args.max_num 124 | ) 125 | dataloader = torch.utils.data.DataLoader( 126 | dataset=dataset, 127 | sampler=InferenceSampler(len(dataset)), 128 | batch_size=args.batch_size, 129 | num_workers=args.num_workers, 130 | pin_memory=True, 131 | drop_last=False, 132 | collate_fn=partial(collate_fn, tokenizer=tokenizer), 133 | ) 134 | 135 | outputs = [] 136 | for _, (pixel_values, questions, answers, indexes) in enumerate(tqdm(dataloader)): 137 | pixel_values = pixel_values.to(torch.bfloat16).cuda() 138 | generation_config = dict( 139 | num_beams=args.num_beams, 140 | max_new_tokens=ds_collections[ds_name]['max_new_tokens'], 141 | min_new_tokens=ds_collections[ds_name]['min_new_tokens'], 142 | do_sample=True if args.temperature > 0 else False, 143 | temperature=args.temperature, 144 | ) 145 | pred = model.chat( 146 | tokenizer=tokenizer, 147 | pixel_values=pixel_values, 148 | question=questions[0], 149 | generation_config=generation_config 150 | ) 151 | preds = [pred] 152 | 153 | for question, pred, answer, index in zip(questions, preds, answers, indexes): 154 | outputs.append({ 155 | 'question_id': index, 156 | 'question': question, 157 | 'prediction': pred, 158 | 'answer': answer, 159 | }) 160 | 161 | torch.distributed.barrier() 162 | 163 | world_size = torch.distributed.get_world_size() 164 | merged_outputs = [None for _ in range(world_size)] 165 | torch.distributed.all_gather_object(merged_outputs, json.dumps(outputs)) 166 | 167 | merged_outputs = [json.loads(_) for _ in merged_outputs] 168 | merged_outputs = [_ for _ in itertools.chain.from_iterable(merged_outputs)] 169 | 170 | if torch.distributed.get_rank() == 0: 171 | 172 | print(f'Evaluating {ds_name} ...') 173 | time_prefix = 
time.strftime('%y%m%d%H%M%S', time.localtime()) 174 | results_file = f'{ds_name}_{time_prefix}.jsonl' 175 | output_path = os.path.join(args.out_dir, results_file) 176 | writer = open(output_path, 'w') 177 | 178 | results = [] 179 | for item in merged_outputs: 180 | writer.write(json.dumps(item) + '\n') 181 | answer = item['answer'] 182 | prediction = item['prediction'] 183 | if prediction == answer: 184 | results.append(1) 185 | else: 186 | results.append(0) 187 | writer.close() 188 | print('Results saved to {}'.format(output_path)) 189 | print(f'Acc@1: {sum(results) / len(results)}') 190 | cmd = f'python eval/seed/calculation.py --image_result_file {output_path}' 191 | os.system(cmd) 192 | 193 | 194 | if __name__ == '__main__': 195 | parser = argparse.ArgumentParser() 196 | parser.add_argument('--checkpoint', type=str, default='') 197 | parser.add_argument('--datasets', type=str, default='SEEDv1') 198 | parser.add_argument('--batch-size', type=int, default=1) 199 | parser.add_argument('--num-workers', type=int, default=1) 200 | parser.add_argument('--num-beams', type=int, default=5) 201 | parser.add_argument('--temperature', type=float, default=0.0) 202 | parser.add_argument('--out-dir', type=str, default='results') 203 | parser.add_argument('--seed', type=int, default=0) 204 | parser.add_argument('--dynamic', action='store_true') 205 | parser.add_argument('--max-num', type=int, default=6) 206 | parser.add_argument('--load-in-8bit', action='store_true') 207 | parser.add_argument('--auto', action='store_true') 208 | args = parser.parse_args() 209 | 210 | if not os.path.exists(args.out_dir): 211 | os.makedirs(args.out_dir) 212 | 213 | args.datasets = args.datasets.split(',') 214 | print('datasets:', args.datasets) 215 | assert args.batch_size == 1, 'Only batch size 1 is supported' 216 | 217 | torch.distributed.init_process_group( 218 | backend='nccl', 219 | world_size=int(os.getenv('WORLD_SIZE', '1')), 220 | rank=int(os.getenv('RANK', '0')), 221 | ) 222 | 223 | torch.cuda.set_device(int(os.getenv('LOCAL_RANK', 0))) 224 | 225 | if args.auto: 226 | os.environ['CUDA_LAUNCH_BLOCKING'] = '1' 227 | kwargs = {'device_map': 'auto'} if args.auto else {} 228 | tokenizer = AutoTokenizer.from_pretrained(args.checkpoint, trust_remote_code=True, use_fast=False) 229 | model = AutoModel.from_pretrained( 230 | args.checkpoint, trust_remote_code=True, low_cpu_mem_usage=True, torch_dtype=torch.bfloat16, 231 | load_in_8bit=args.load_in_8bit, **kwargs).eval() 232 | if not args.load_in_8bit and not args.auto: 233 | model = model.cuda() 234 | image_size = model.config.force_image_size or model.config.vision_config.image_size 235 | use_thumbnail = model.config.use_thumbnail 236 | 237 | total_params = sum(p.numel() for p in model.parameters()) / 1e9 238 | if total_params > 20 or args.dynamic: 239 | args.num_beams = 1 240 | print(f'[test] total_params: {total_params}B, use num_beams: {args.num_beams}') 241 | else: 242 | print(f'[test] total_params: {total_params}B') 243 | print(f'[test] image_size: {image_size}') 244 | print(f'[test] template: {model.config.template}') 245 | print(f'[test] dynamic_image_size: {args.dynamic}') 246 | print(f'[test] use_thumbnail: {use_thumbnail}') 247 | 248 | evaluate_chat_model() 249 | -------------------------------------------------------------------------------- /eval/vqa/convert_gqa_for_eval.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | 4 | parser = argparse.ArgumentParser() 5 | 
parser.add_argument('--src', type=str) 6 | parser.add_argument('--dst', type=str) 7 | args = parser.parse_args() 8 | 9 | all_answers = [] 10 | data = json.load(open(args.src)) 11 | for res in data: 12 | question_id = res['questionId'] 13 | answer = res['answer'].rstrip('.').lower() 14 | all_answers.append({'questionId': question_id, 'prediction': answer}) 15 | 16 | with open(args.dst, 'w') as f: 17 | json.dump(all_answers, f) 18 | -------------------------------------------------------------------------------- /eval/vqa/evaluate_vqa.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import itertools 3 | import json 4 | import os 5 | import random 6 | import subprocess 7 | import time 8 | from functools import partial 9 | from typing import Optional 10 | 11 | import torch 12 | from utils.preprocess import build_transform, dynamic_preprocess 13 | from PIL import Image 14 | from textvqa_eval import TextVQAAccuracyEvaluator 15 | from tqdm import tqdm 16 | from transformers import AutoTokenizer, AutoModel 17 | 18 | ds_collections = { 19 | 'vqav2_val': { 20 | 'train': 'data/vqav2/vqav2_train.jsonl', 21 | 'test': 'data/vqav2/vqav2_val.jsonl', 22 | 'question': 'data/vqav2/v2_OpenEnded_mscoco_val2014_questions.json', 23 | 'annotation': 'data/vqav2/v2_mscoco_val2014_annotations.json', 24 | 'metric': 'vqa_score', 25 | 'max_new_tokens': 10, 26 | }, 27 | 'vqav2_testdev': { 28 | 'train': 'data/vqav2/vqav2_train.jsonl', 29 | 'test': 'data/vqav2/vqav2_testdev.jsonl', 30 | 'metric': None, 31 | 'max_new_tokens': 10, 32 | }, 33 | 'okvqa_val': { 34 | 'train': 'data/okvqa/okvqa_train.jsonl', 35 | 'test': 'data/okvqa/okvqa_val.jsonl', 36 | 'question': 'data/okvqa/OpenEnded_mscoco_val2014_questions.json', 37 | 'annotation': 'data/okvqa/mscoco_val2014_annotations.json', 38 | 'metric': 'vqa_score', 39 | 'max_new_tokens': 10, 40 | }, 41 | 'textvqa_val': { 42 | 'train': 'data/textvqa/textvqa_train.jsonl', 43 | 'test': 'data/textvqa/textvqa_val.jsonl', 44 | 'question': 'data/textvqa/textvqa_val_questions.json', 45 | 'annotation': 'data/textvqa/textvqa_val_annotations.json', 46 | 'metric': 'vqa_score', 47 | 'max_new_tokens': 10, 48 | }, 49 | 'textvqa_val_ocr': { 50 | 'train': 'data/textvqa/textvqa_train.jsonl', 51 | 'test': 'data/textvqa/textvqa_val_llava.jsonl', 52 | 'question': 'data/textvqa/textvqa_val_questions.json', 53 | 'annotation': 'data/textvqa/textvqa_val_annotations.json', 54 | 'metric': 'vqa_score', 55 | 'max_new_tokens': 10, 56 | }, 57 | 'vizwiz_val': { 58 | 'train': 'data/vizwiz/vizwiz_train.jsonl', 59 | 'test': 'data/vizwiz/vizwiz_val.jsonl', 60 | 'question': 'data/vizwiz/vizwiz_val_questions.json', 61 | 'annotation': 'data/vizwiz/vizwiz_val_annotations.json', 62 | 'metric': 'vqa_score', 63 | 'max_new_tokens': 10, 64 | }, 65 | 'vizwiz_test': { 66 | 'train': 'data/vizwiz/vizwiz_train.jsonl', 67 | 'test': 'data/vizwiz/vizwiz_test.jsonl', 68 | 'metric': None, 69 | 'max_new_tokens': 10, 70 | }, 71 | 'docvqa_val': { 72 | 'train': 'data/docvqa/train.jsonl', 73 | 'test': 'data/docvqa/val.jsonl', 74 | 'annotation': 'data/docvqa/val/val_v1.0.json', 75 | 'metric': 'anls', 76 | 'max_new_tokens': 100, 77 | }, 78 | 'docvqa_test': { 79 | 'train': 'data/docvqa/train.jsonl', 80 | 'test': 'data/docvqa/test.jsonl', 81 | 'metric': None, 82 | 'max_new_tokens': 100, 83 | }, 84 | 'chartqa_test_human': { 85 | 'train': 'data/chartqa/train_human.jsonl', 86 | 'test': 'data/chartqa/test_human.jsonl', 87 | 'metric': 'relaxed_accuracy', 88 | 'max_new_tokens': 100, 89 | }, 
90 | 'chartqa_test_augmented': { 91 | 'train': 'data/chartqa/train_augmented.jsonl', 92 | 'test': 'data/chartqa/test_augmented.jsonl', 93 | 'metric': 'relaxed_accuracy', 94 | 'max_new_tokens': 100, 95 | }, 96 | 'gqa_testdev': { 97 | 'train': 'data/gqa/train.jsonl', 98 | 'test': 'data/gqa/test_balanced.jsonl', 99 | 'metric': 'accuracy', 100 | 'max_new_tokens': 10, 101 | }, 102 | 'gqa_testdev_llava': { 103 | 'train': 'data/gqa/train.jsonl', 104 | 'test': 'data/gqa/llava_gqa_testdev_balanced_qwen_format.jsonl', 105 | 'metric': 'accuracy', 106 | 'max_new_tokens': 10, 107 | }, 108 | 'ocrvqa_val': { 109 | 'train': 'data/ocrvqa/ocrvqa_train.jsonl', 110 | 'test': 'data/ocrvqa/ocrvqa_val.jsonl', 111 | 'metric': 'accuracy', 112 | 'max_new_tokens': 100, 113 | }, 114 | 'ocrvqa_test': { 115 | 'train': 'data/ocrvqa/ocrvqa_train.jsonl', 116 | 'test': 'data/ocrvqa/ocrvqa_test.jsonl', 117 | 'metric': 'accuracy', 118 | 'max_new_tokens': 100, 119 | }, 120 | 'ai2diagram_test': { 121 | 'train': 'data/ai2diagram/train.jsonl', 122 | 'test': 'data/ai2diagram/test_vlmevalkit.jsonl', 123 | 'metric': 'accuracy', 124 | 'max_new_tokens': 10, 125 | }, 126 | 'infographicsvqa_val': { 127 | 'train': 'data/infographicsvqa/train.jsonl', 128 | 'test': 'data/infographicsvqa/val.jsonl', 129 | 'annotation': 'data/infographicsvqa/infographicsVQA_val_v1.0_withQT.json', 130 | 'metric': 'anls', 131 | 'max_new_tokens': 100, 132 | }, 133 | 'infographicsvqa_test': { 134 | 'train': 'data/infographicsvqa/train.jsonl', 135 | 'test': 'data/infographicsvqa/test.jsonl', 136 | 'annotation': 'data/infographicsvqa/infographicsVQA_test_v1.0.json', 137 | 'metric': None, 138 | 'max_new_tokens': 100, 139 | } 140 | } 141 | 142 | 143 | # https://github.com/google-research/pix2struct/blob/main/pix2struct/metrics.py#L81 144 | def relaxed_correctness(target: str, 145 | prediction: str, 146 | max_relative_change: float = 0.05) -> bool: 147 | """Calculates relaxed correctness. 148 | 149 | The correctness tolerates certain error ratio defined by max_relative_change. 150 | See https://arxiv.org/pdf/2203.10244.pdf, end of section 5.1: 151 | “Following Methani et al. (2020), we use a relaxed accuracy measure for the 152 | numeric answers to allow a minor inaccuracy that may result from the automatic 153 | data extraction process. We consider an answer to be correct if it is within 154 | 5% of the gold answer. For non-numeric answers, we still need an exact match 155 | to consider an answer to be correct.” 156 | 157 | Args: 158 | target: Target string. 159 | prediction: Predicted string. 160 | max_relative_change: Maximum relative change. 161 | 162 | Returns: 163 | Whether the prediction was correct given the specified tolerance. 164 | """ 165 | 166 | def _to_float(text: str) -> Optional[float]: 167 | try: 168 | if text.endswith('%'): 169 | # Convert percentages to floats. 
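# e.g. '12.5%' is parsed as 0.125, so a prediction of '12%' against a target of
# '12.5%' still counts as correct under the 5% relative tolerance applied below.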
170 | return float(text.rstrip('%')) / 100.0 171 | else: 172 | return float(text) 173 | except ValueError: 174 | return None 175 | 176 | prediction_float = _to_float(prediction) 177 | target_float = _to_float(target) 178 | if prediction_float is not None and target_float: 179 | relative_change = abs(prediction_float - 180 | target_float) / abs(target_float) 181 | return relative_change <= max_relative_change 182 | else: 183 | return prediction.lower() == target.lower() 184 | 185 | 186 | def evaluate_relaxed_accuracy(entries): 187 | scores = [] 188 | for elem in entries: 189 | if isinstance(elem['annotation'], str): 190 | elem['annotation'] = [elem['annotation']] 191 | score = max([ 192 | relaxed_correctness(elem['answer'].strip(), ann) 193 | for ann in elem['annotation'] 194 | ]) 195 | scores.append(score) 196 | return sum(scores) / len(scores) 197 | 198 | 199 | def evaluate_exact_match_accuracy(entries): 200 | scores = [] 201 | for elem in entries: 202 | if isinstance(elem['annotation'], str): 203 | elem['annotation'] = [elem['annotation']] 204 | score = max([ 205 | (1.0 if 206 | (elem['answer'].strip().lower() == ann.strip().lower()) else 0.0) 207 | for ann in elem['annotation'] 208 | ]) 209 | scores.append(score) 210 | return sum(scores) / len(scores) 211 | 212 | 213 | def collate_fn(batches, tokenizer): 214 | pixel_values = torch.cat([_['pixel_values'] for _ in batches], dim=0) 215 | questions = [_['question'] for _ in batches] 216 | question_ids = [_['question_id'] for _ in batches] 217 | annotations = [_['annotation'] for _ in batches] 218 | 219 | return pixel_values, questions, question_ids, annotations 220 | 221 | 222 | class VQADataset(torch.utils.data.Dataset): 223 | 224 | def __init__(self, train, test, prompt, few_shot, input_size=224, dynamic_image_size=False, 225 | use_thumbnail=False, max_num=6): 226 | self.test = open(test).readlines() 227 | self.prompt = prompt 228 | self.input_size = input_size 229 | self.dynamic_image_size = dynamic_image_size 230 | self.use_thumbnail = use_thumbnail 231 | self.few_shot = few_shot 232 | self.max_num = max_num 233 | if few_shot > 0: 234 | self.train = open(train).readlines() 235 | self.transform = build_transform(input_size=input_size) 236 | 237 | def __len__(self): 238 | return len(self.test) 239 | 240 | def __getitem__(self, idx): 241 | data = json.loads(self.test[idx].strip()) 242 | image, question, question_id, annotation = data['image'], data[ 243 | 'question'], data['question_id'], data.get('answer', None) 244 | 245 | few_shot_prompt = '' 246 | if self.few_shot > 0: 247 | few_shot_samples = random.sample(self.train, self.few_shot) 248 | for sample in few_shot_samples: 249 | sample = json.loads(sample.strip()) 250 | few_shot_prompt += self.prompt.format( 251 | sample['image'], 252 | sample['question']) + f" {sample['answer']}" 253 | 254 | image = Image.open(image).convert('RGB') 255 | if self.dynamic_image_size: 256 | images = dynamic_preprocess(image, image_size=self.input_size, 257 | use_thumbnail=self.use_thumbnail, 258 | max_num=self.max_num) 259 | else: 260 | images = [image] 261 | pixel_values = [self.transform(image) for image in images] 262 | pixel_values = torch.stack(pixel_values) 263 | if len(self.prompt) != 0: 264 | question = question + ' ' + self.prompt 265 | return { 266 | 'question_id': question_id, 267 | 'question': question, 268 | 'pixel_values': pixel_values, 269 | 'annotation': annotation 270 | } 271 | 272 | 273 | class InferenceSampler(torch.utils.data.sampler.Sampler): 274 | 275 | def __init__(self, size): 276 | 
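# Each rank evaluates a contiguous, near-equal shard of the dataset and the
# per-rank outputs are merged afterwards with all_gather_object; e.g. with
# size=10 and world_size=4 the shard sizes are 3, 3, 2 and 2.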
self._size = int(size) 277 | assert size > 0 278 | self._rank = torch.distributed.get_rank() 279 | self._world_size = torch.distributed.get_world_size() 280 | self._local_indices = self._get_local_indices(size, self._world_size, self._rank) 281 | 282 | @staticmethod 283 | def _get_local_indices(total_size, world_size, rank): 284 | shard_size = total_size // world_size 285 | left = total_size % world_size 286 | shard_sizes = [shard_size + int(r < left) for r in range(world_size)] 287 | 288 | begin = sum(shard_sizes[:rank]) 289 | end = min(sum(shard_sizes[:rank + 1]), total_size) 290 | return range(begin, end) 291 | 292 | def __iter__(self): 293 | yield from self._local_indices 294 | 295 | def __len__(self): 296 | return len(self._local_indices) 297 | 298 | 299 | def post_process(response): 300 | response = response.strip().split('.')[0].split( 301 | ',')[0].split('!')[0].lower() 302 | if 'is ' in response: 303 | response = response.split('is ')[1] 304 | if 'are ' in response: 305 | response = response.split('are ')[1] 306 | if 'a ' in response: 307 | response = response.split('a ')[1] 308 | if 'an ' in response: 309 | response = response.split('an ')[1] 310 | if 'the ' in response: 311 | response = response.split('the ')[1] 312 | if ' of' in response: 313 | response = response.split(' of')[0] 314 | response = response.strip() 315 | return response 316 | 317 | 318 | def evaluate_chat_model(): 319 | base_prompt = 'Answer the question using a single word or phrase.' 320 | vizwiz_prompt = "When the provided information is insufficient, respond with 'Unanswerable'. " 321 | # infovqa_prompt = 'Answer the question directly.' 322 | infovqa_prompt = 'Answer the question using a single word or phrase.' 323 | ai2d_prompt = '' 324 | random.seed(args.seed) 325 | summaries = [] 326 | 327 | for ds_name in args.datasets: 328 | if 'vizwiz' in ds_name: 329 | input_prompt = vizwiz_prompt + base_prompt 330 | elif 'ai2d' in ds_name: 331 | input_prompt = ai2d_prompt 332 | elif 'infographicsvqa' in ds_name: 333 | input_prompt = infovqa_prompt 334 | else: 335 | input_prompt = base_prompt 336 | 337 | dataset = VQADataset( 338 | train=ds_collections[ds_name]['train'], 339 | test=ds_collections[ds_name]['test'], 340 | prompt=input_prompt, 341 | few_shot=args.few_shot, 342 | input_size=image_size, 343 | dynamic_image_size=args.dynamic, 344 | use_thumbnail=use_thumbnail, 345 | max_num=args.max_num 346 | ) 347 | dataloader = torch.utils.data.DataLoader( 348 | dataset=dataset, 349 | sampler=InferenceSampler(len(dataset)), 350 | batch_size=args.batch_size, 351 | num_workers=args.num_workers, 352 | pin_memory=True, 353 | drop_last=False, 354 | collate_fn=partial(collate_fn, tokenizer=tokenizer), 355 | ) 356 | 357 | outputs = [] 358 | for _, (pixel_values, questions, question_ids, annotations) in tqdm(enumerate(dataloader)): 359 | pixel_values = pixel_values.to(torch.bfloat16).cuda() 360 | generation_config = dict( 361 | num_beams=args.num_beams, 362 | max_new_tokens=ds_collections[ds_name]['max_new_tokens'], 363 | min_new_tokens=1, 364 | do_sample=True if args.temperature > 0 else False, 365 | temperature=args.temperature, 366 | ) 367 | pred = model.chat( 368 | tokenizer=tokenizer, 369 | pixel_values=pixel_values, 370 | question=questions[0], 371 | generation_config=generation_config 372 | ) 373 | answers = [pred] 374 | 375 | for question, question_id, answer, annotation in zip(questions, question_ids, answers, annotations): 376 | if ds_name in ['vqav2_val', 'vqav2_testdev', 'okvqa_val', 'textvqa_val', 377 | 'vizwiz_val', 
'textvqa_val_ocr']: 378 | outputs.append({ 379 | 'question': question, 380 | 'question_id': question_id, 381 | 'answer': answer, 382 | }) 383 | elif ds_name in ['docvqa_val', 'infographicsvqa_val', 'gqa_testdev', 'ocrvqa_val', 384 | 'ocrvqa_test', 'gqa_testdev_llava', 'infographicsvqa_test',]: 385 | outputs.append({ 386 | 'question': question, 387 | 'questionId': question_id, 388 | 'answer': answer, 389 | 'annotation': annotation, 390 | }) 391 | elif ds_name in ['ai2diagram_test']: 392 | outputs.append({ 393 | 'question': question, 394 | 'image': question_id, 395 | 'answer': answer, 396 | 'annotation': annotation, 397 | }) 398 | elif ds_name in ['chartqa_test_human', 'chartqa_test_augmented']: 399 | outputs.append({ 400 | 'question': question, 401 | 'answer': answer, 402 | 'annotation': annotation, 403 | }) 404 | elif ds_name in ['docvqa_test']: 405 | outputs.append({ 406 | 'questionId': question_id, 407 | 'answer': answer, 408 | }) 409 | elif ds_name in ['vizwiz_test']: 410 | outputs.append({ 411 | 'image': question_id.replace('data/vizwiz/test/', ''), 412 | 'answer': answer, 413 | }) 414 | else: 415 | raise NotImplementedError 416 | 417 | torch.distributed.barrier() 418 | 419 | world_size = torch.distributed.get_world_size() 420 | merged_outputs = [None for _ in range(world_size)] 421 | torch.distributed.all_gather_object(merged_outputs, json.dumps(outputs)) 422 | 423 | merged_outputs = [json.loads(_) for _ in merged_outputs] 424 | merged_outputs = [_ for _ in itertools.chain.from_iterable(merged_outputs)] 425 | 426 | if torch.distributed.get_rank() == 0: 427 | print(f'Evaluating {ds_name} ...') 428 | time_prefix = time.strftime('%y%m%d%H%M%S', time.localtime()) 429 | results_file = f'{ds_name}_{time_prefix}.json' 430 | results_file = os.path.join(args.out_dir, results_file) 431 | json.dump(merged_outputs, open(results_file, 'w')) 432 | print('Results saved to {}'.format(results_file)) 433 | 434 | if ds_collections[ds_name]['metric'] == 'vqa_score': 435 | evaluator = TextVQAAccuracyEvaluator() 436 | annotation = json.load(open(ds_collections[ds_name]['annotation'], 'r'))['annotations'] 437 | question_id2answers = {} 438 | for item in annotation: 439 | question_id = item['question_id'] 440 | answers = [answer['answer'] for answer in item['answers']] 441 | question_id2answers[question_id] = answers 442 | for item in merged_outputs: 443 | item['pred_answer'] = item['answer'] 444 | item['gt_answers'] = question_id2answers[item['question_id']] 445 | accuracy = evaluator.eval_pred_list(merged_outputs) 446 | print(ds_name, accuracy) 447 | summaries.append([args.checkpoint, ds_name, accuracy]) 448 | 449 | elif ds_collections[ds_name]['metric'] == 'anls': 450 | json.dump(merged_outputs, 451 | open(results_file, 'w'), 452 | ensure_ascii=False) 453 | print('python eval/vqa/infographicsvqa_eval.py -g ' + 454 | ds_collections[ds_name]['annotation'] + ' -s ' + 455 | results_file) 456 | os.system('python eval/vqa/infographicsvqa_eval.py -g ' + 457 | ds_collections[ds_name]['annotation'] + ' -s ' + 458 | results_file) 459 | elif ds_collections[ds_name]['metric'] == 'relaxed_accuracy': 460 | relaxed_accuracy = evaluate_relaxed_accuracy(merged_outputs) 461 | print(ds_name, {'relaxed_accuracy': relaxed_accuracy}) 462 | summaries.append([ds_name, {'relaxed_accuracy': relaxed_accuracy}]) 463 | elif ds_collections[ds_name]['metric'] == 'accuracy': 464 | if 'gqa' in ds_name: 465 | dst_file = './data/gqa/testdev_balanced_predictions.json' 466 | print('python eval/vqa/convert_gqa_for_eval.py --src ' + 467 | 
results_file + ' --dst ' + dst_file) 468 | python_path = 'python' 469 | os.system(python_path + ' eval/vqa/convert_gqa_for_eval.py --src ' + 470 | results_file + ' --dst ' + dst_file) 471 | command = f'cd ./data/gqa/ && {python_path} eval.py --tier testdev_balanced && cd ../../' 472 | print(command) 473 | accuracy = subprocess.check_output(command, shell=True, universal_newlines=True) 474 | else: 475 | accuracy = {'accuracy': evaluate_exact_match_accuracy(merged_outputs)} 476 | print(ds_name, accuracy) 477 | summaries.append([args.checkpoint, ds_name, accuracy]) 478 | 479 | torch.distributed.barrier() 480 | 481 | out_path = '_'.join(args.checkpoint.split('/')[-2:]) 482 | writer = open(os.path.join(args.out_dir, f'{out_path}.txt'), 'a') 483 | print(f"write results to file {os.path.join(args.out_dir, f'{out_path}.txt')}") 484 | for summary in summaries: 485 | print(summary) 486 | writer.write(f'{summary}\n') 487 | writer.close() 488 | 489 | 490 | if __name__ == '__main__': 491 | 492 | parser = argparse.ArgumentParser() 493 | parser.add_argument('--checkpoint', type=str, default='') 494 | parser.add_argument('--datasets', type=str, 495 | default='okvqa_val,textvqa_val,vizwiz_val,ai2diagram_test,gqa_testdev_llava') 496 | parser.add_argument('--batch-size', type=int, default=1) 497 | parser.add_argument('--num-workers', type=int, default=1) 498 | parser.add_argument('--num-beams', type=int, default=5) 499 | parser.add_argument('--temperature', type=float, default=0.0) 500 | parser.add_argument('--out-dir', type=str, default='results') 501 | parser.add_argument('--few-shot', type=int, default=0) 502 | parser.add_argument('--seed', type=int, default=0) 503 | parser.add_argument('--dynamic', action='store_true') 504 | parser.add_argument('--max-num', type=int, default=6) 505 | parser.add_argument('--load-in-8bit', action='store_true') 506 | parser.add_argument('--auto', action='store_true') 507 | args = parser.parse_args() 508 | 509 | if not os.path.exists(args.out_dir): 510 | os.makedirs(args.out_dir) 511 | 512 | args.datasets = args.datasets.split(',') 513 | print('datasets:', args.datasets) 514 | assert args.batch_size == 1, 'Only batch size 1 is supported' 515 | 516 | torch.distributed.init_process_group( 517 | backend='nccl', 518 | world_size=int(os.getenv('WORLD_SIZE', '1')), 519 | rank=int(os.getenv('RANK', '0')), 520 | ) 521 | 522 | torch.cuda.set_device(int(os.getenv('LOCAL_RANK', 0))) 523 | 524 | if args.auto: 525 | os.environ['CUDA_LAUNCH_BLOCKING'] = '1' 526 | kwargs = {'device_map': 'auto'} if args.auto else {} 527 | tokenizer = AutoTokenizer.from_pretrained(args.checkpoint, trust_remote_code=True, use_fast=False) 528 | model = AutoModel.from_pretrained( 529 | args.checkpoint, trust_remote_code=True, low_cpu_mem_usage=True, torch_dtype=torch.bfloat16, 530 | load_in_8bit=args.load_in_8bit, **kwargs).eval() 531 | if not args.load_in_8bit and not args.auto: 532 | model = model.cuda() 533 | image_size = model.config.force_image_size or model.config.vision_config.image_size 534 | use_thumbnail = model.config.use_thumbnail 535 | 536 | total_params = sum(p.numel() for p in model.parameters()) / 1e9 537 | if total_params > 20 or args.dynamic: 538 | args.num_beams = 1 539 | print(f'[test] total_params: {total_params}B, use num_beams: {args.num_beams}') 540 | else: 541 | print(f'[test] total_params: {total_params}B') 542 | print(f'[test] image_size: {image_size}') 543 | print(f'[test] template: {model.config.template}') 544 | print(f'[test] dynamic_image_size: {args.dynamic}') 545 | 
print(f'[test] use_thumbnail: {use_thumbnail}') 546 | 547 | evaluate_chat_model() 548 | -------------------------------------------------------------------------------- /eval/vqa/infographicsvqa_eval.py: -------------------------------------------------------------------------------- 1 | # This file can be downloaded from: https://www.docvqa.org/datasets/infographicvqa and https://rrc.cvc.uab.es/?ch=17&com=introduction 2 | 3 | import argparse 4 | import json 5 | import os 6 | 7 | question_ids_to_exclude = [] 8 | 9 | # answer_types = {'image span': 'Image-Span', 'question span': 'Question-Span', 'multiple spans': 'Multi-Span', 'non span': 'None span', 'list': 'List'} 10 | answer_types = {'image span': 'Image-Span', 'question span': 'Question-Span', 'multiple spans': 'Multi-Span', 11 | 'non span': 'None span'} 12 | evidence_types = {'table/list': 'Table/list', 'textual': 'Text', 'photo/pciture/visual_objects': 'Visual/Layout', 13 | 'figure': 'Figure', 'map': 'Map'} 14 | reasoning_requirements = {'comparison': 'Sorting', 'arithmetic': 'Arithmetic', 'counting': 'Counting'} 15 | 16 | 17 | def save_json(file_path, data): 18 | with open(file_path, 'w+') as json_file: 19 | json.dump(data, json_file) 20 | 21 | 22 | def levenshtein_distance(s1, s2): 23 | if len(s1) > len(s2): 24 | s1, s2 = s2, s1 25 | 26 | distances = range(len(s1) + 1) 27 | for i2, c2 in enumerate(s2): 28 | distances_ = [i2 + 1] 29 | for i1, c1 in enumerate(s1): 30 | if c1 == c2: 31 | distances_.append(distances[i1]) 32 | else: 33 | distances_.append(1 + min((distances[i1], distances[i1 + 1], distances_[-1]))) 34 | distances = distances_ 35 | return distances[-1] 36 | 37 | 38 | def validate_data(gtFilePath, submFilePath): 39 | """ 40 | Method validate_data: validates that all files in the results folder are correct (have the correct name contents). 41 | Validates also that there are no missing files in the folder. 42 | If some error is detected, the method raises the error 43 | """ 44 | 45 | gtJson = json.load(open(gtFilePath, 'rb')) 46 | submJson = json.load(open(submFilePath, 'rb')) 47 | 48 | if 'data' not in gtJson: 49 | raise Exception('The GT file is not valid (no data key)') 50 | 51 | if 'dataset_name' not in gtJson: 52 | raise Exception('The GT file is not valid (no dataset_name key)') 53 | 54 | if isinstance(submJson, list) is False: 55 | raise Exception('The Det file is not valid (root item must be an array)') 56 | 57 | if len(submJson) != len(gtJson['data']): 58 | raise Exception('The Det file is not valid (invalid number of answers. Expected:' + str( 59 | len(gtJson['data'])) + ' Found:' + str(len(submJson)) + ')') 60 | 61 | gtQuestions = sorted([r['questionId'] for r in gtJson['data']]) 62 | res_id_to_index = {int(r['questionId']): ix for ix, r in enumerate(submJson)} 63 | detQuestions = sorted([r['questionId'] for r in submJson]) 64 | 65 | if ((gtQuestions == detQuestions) is False): 66 | raise Exception('The Det file is not valid. Question IDs must match GT') 67 | 68 | for gtObject in gtJson['data']: 69 | 70 | try: 71 | q_id = int(gtObject['questionId']) 72 | res_ix = res_id_to_index[q_id] 73 | 74 | except: 75 | raise Exception('The Det file is not valid. Question ' + str(gtObject['questionId']) + ' not present') 76 | 77 | else: 78 | detObject = submJson[res_ix] 79 | 80 | # if detObject['questionId'] != gtObject['questionId'] : 81 | # raise Exception("Answer #" + str(i) + " not valid (invalid question ID. 
Expected:" + str(gtObject['questionId']) + "Found:" + detObject['questionId'] + ")") 82 | 83 | if 'answer' not in detObject: 84 | raise Exception('Question ' + str(gtObject['questionId']) + ' not valid (no answer key)') 85 | 86 | if isinstance(detObject['answer'], list) is True: 87 | raise Exception( 88 | 'Question ' + str(gtObject['questionId']) + ' not valid (answer key has to be a single string)') 89 | 90 | 91 | def evaluate_method(gtFilePath, submFilePath, evaluationParams): 92 | """ 93 | Method evaluate_method: evaluate method and returns the results 94 | Results. Dictionary with the following values: 95 | - method (required) Global method metrics. Ex: { 'Precision':0.8,'Recall':0.9 } 96 | - samples (optional) Per sample metrics. Ex: {'sample1' : { 'Precision':0.8,'Recall':0.9 } , 'sample2' : { 'Precision':0.8,'Recall':0.9 } 97 | """ 98 | 99 | show_scores_per_answer_type = evaluationParams.answer_types 100 | 101 | gtJson = json.load(open(gtFilePath, 'rb')) 102 | submJson = json.load(open(submFilePath, 'rb')) 103 | 104 | res_id_to_index = {int(r['questionId']): ix for ix, r in enumerate(submJson)} 105 | 106 | perSampleMetrics = {} 107 | 108 | totalScore = 0 109 | row = 0 110 | 111 | if show_scores_per_answer_type: 112 | answerTypeTotalScore = {x: 0 for x in answer_types.keys()} 113 | answerTypeNumQuestions = {x: 0 for x in answer_types.keys()} 114 | 115 | evidenceTypeTotalScore = {x: 0 for x in evidence_types.keys()} 116 | evidenceTypeNumQuestions = {x: 0 for x in evidence_types.keys()} 117 | 118 | reasoningTypeTotalScore = {x: 0 for x in reasoning_requirements.keys()} 119 | reasoningTypeNumQuestions = {x: 0 for x in reasoning_requirements.keys()} 120 | 121 | for gtObject in gtJson['data']: 122 | 123 | q_id = int(gtObject['questionId']) 124 | res_ix = res_id_to_index[q_id] 125 | detObject = submJson[res_ix] 126 | 127 | if q_id in question_ids_to_exclude: 128 | question_result = 0 129 | info = 'Question EXCLUDED from the result' 130 | 131 | else: 132 | info = '' 133 | values = [] 134 | for answer in gtObject['answers']: 135 | # preprocess both the answers - gt and prediction 136 | gt_answer = ' '.join(answer.strip().lower().split()) 137 | det_answer = ' '.join(detObject['answer'].strip().lower().split()) 138 | 139 | # dist = levenshtein_distance(answer.lower(), detObject['answer'].lower()) 140 | dist = levenshtein_distance(gt_answer, det_answer) 141 | length = max(len(answer.upper()), len(detObject['answer'].upper())) 142 | values.append(0.0 if length == 0 else float(dist) / float(length)) 143 | 144 | question_result = 1 - min(values) 145 | 146 | if (question_result < evaluationParams.anls_threshold): 147 | question_result = 0 148 | 149 | totalScore += question_result 150 | 151 | if show_scores_per_answer_type: 152 | for q_type in gtObject['answer_type']: 153 | answerTypeTotalScore[q_type] += question_result 154 | answerTypeNumQuestions[q_type] += 1 155 | 156 | for q_type in gtObject['evidence']: 157 | evidenceTypeTotalScore[q_type] += question_result 158 | evidenceTypeNumQuestions[q_type] += 1 159 | 160 | for q_type in gtObject['operation/reasoning']: 161 | reasoningTypeTotalScore[q_type] += question_result 162 | reasoningTypeNumQuestions[q_type] += 1 163 | 164 | perSampleMetrics[str(gtObject['questionId'])] = { 165 | 'score': question_result, 166 | 'question': gtObject['question'], 167 | 'gt': gtObject['answers'], 168 | 'det': detObject['answer'], 169 | 'info': info 170 | } 171 | row = row + 1 172 | 173 | methodMetrics = { 174 | 'score': 0 if len(gtJson['data']) == 0 else totalScore / 
(len(gtJson['data']) - len(question_ids_to_exclude)) 175 | } 176 | 177 | answer_types_scores = {} 178 | evidence_types_scores = {} 179 | operation_types_scores = {} 180 | 181 | if show_scores_per_answer_type: 182 | for a_type, ref in answer_types.items(): 183 | answer_types_scores[ref] = 0 if len(gtJson['data']) == 0 else answerTypeTotalScore[a_type] / ( 184 | answerTypeNumQuestions[a_type]) 185 | 186 | for e_type, ref in evidence_types.items(): 187 | evidence_types_scores[ref] = 0 if len(gtJson['data']) == 0 else evidenceTypeTotalScore[e_type] / ( 188 | evidenceTypeNumQuestions[e_type]) 189 | 190 | for r_type, ref in reasoning_requirements.items(): 191 | operation_types_scores[ref] = 0 if len(gtJson['data']) == 0 else reasoningTypeTotalScore[r_type] / ( 192 | reasoningTypeNumQuestions[r_type]) 193 | 194 | resDict = { 195 | 'result': methodMetrics, 196 | 'scores_by_types': {'answer_types': answer_types_scores, 'evidence_types': evidence_types_scores, 197 | 'operation_types': operation_types_scores}, 198 | 'per_sample_result': perSampleMetrics 199 | } 200 | 201 | return resDict 202 | 203 | 204 | def display_results(results, show_answer_types): 205 | print('\nOverall ANLS: {:2.4f}'.format(results['result']['score'])) 206 | 207 | if show_answer_types: 208 | print('\nAnswer types:') 209 | for a_type in answer_types.values(): 210 | print('\t{:12s} {:2.4f}'.format(a_type, results['scores_by_types']['answer_types'][a_type])) 211 | 212 | print('\nEvidence types:') 213 | for e_type in evidence_types.values(): 214 | print('\t{:12s} {:2.4f}'.format(e_type, results['scores_by_types']['evidence_types'][e_type])) 215 | 216 | print('\nOperation required:') 217 | for r_type in reasoning_requirements.values(): 218 | print('\t{:12s} {:2.4f}'.format(r_type, results['scores_by_types']['operation_types'][r_type])) 219 | 220 | 221 | if __name__ == '__main__': 222 | parser = argparse.ArgumentParser(description='InfographVQA evaluation script.') 223 | 224 | parser.add_argument('-g', '--ground_truth', type=str, help='Path of the Ground Truth file.', required=True) 225 | parser.add_argument('-s', '--submission_file', type=str, help="Path of your method's results file.", required=True) 226 | 227 | parser.add_argument('-t', '--anls_threshold', type=float, default=0.5, 228 | help='ANLS threshold to use (See Scene-Text VQA paper for more info.).', required=False) 229 | parser.add_argument('-a', '--answer_types', type=bool, default=False, 230 | help='Score break down by answer types (special gt file required).', required=False) 231 | parser.add_argument('-o', '--output', type=str, 232 | help="Path to a directory where to copy the file 'results.json' that contains per-sample results.", 233 | required=False) 234 | 235 | args = parser.parse_args() 236 | 237 | # Validate the format of ground truth and submission files. 
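# validate_data checks IDs and format; evaluate_method then computes ANLS:
# per question, score = 1 - min normalized Levenshtein distance over the GT
# answers, zeroed when it falls below --anls_threshold (0.5 by default),
# and finally averaged over all non-excluded questions.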
238 | validate_data(args.ground_truth, args.submission_file) 239 | 240 | # Evaluate method 241 | results = evaluate_method(args.ground_truth, args.submission_file, args) 242 | 243 | display_results(results, args.answer_types) 244 | 245 | if args.output: 246 | output_dir = args.output 247 | 248 | if not os.path.exists(output_dir): 249 | os.makedirs(output_dir) 250 | 251 | resultsOutputname = os.path.join(output_dir, 'results.json') 252 | save_json(resultsOutputname, results) 253 | 254 | print('All results including per-sample result has been correctly saved!') 255 | -------------------------------------------------------------------------------- /eval/vqa/textvqa_eval.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # copied from https://github.com/haotian-liu/LLaVA/blob/main/llava/eval/m4c_evaluator.py 3 | import re 4 | 5 | from tqdm import tqdm 6 | 7 | 8 | class EvalAIAnswerProcessor: 9 | """ 10 | Processes an answer similar to Eval AI 11 | copied from 12 | https://github.com/facebookresearch/mmf/blob/c46b3b3391275b4181567db80943473a89ab98ab/pythia/tasks/processors.py#L897 13 | """ 14 | 15 | CONTRACTIONS = { 16 | 'aint': "ain't", 17 | 'arent': "aren't", 18 | 'cant': "can't", 19 | 'couldve': "could've", 20 | 'couldnt': "couldn't", 21 | "couldn'tve": "couldn't've", 22 | "couldnt've": "couldn't've", 23 | 'didnt': "didn't", 24 | 'doesnt': "doesn't", 25 | 'dont': "don't", 26 | 'hadnt': "hadn't", 27 | "hadnt've": "hadn't've", 28 | "hadn'tve": "hadn't've", 29 | 'hasnt': "hasn't", 30 | 'havent': "haven't", 31 | 'hed': "he'd", 32 | "hed've": "he'd've", 33 | "he'dve": "he'd've", 34 | 'hes': "he's", 35 | 'howd': "how'd", 36 | 'howll': "how'll", 37 | 'hows': "how's", 38 | "Id've": "I'd've", 39 | "I'dve": "I'd've", 40 | 'Im': "I'm", 41 | 'Ive': "I've", 42 | 'isnt': "isn't", 43 | 'itd': "it'd", 44 | "itd've": "it'd've", 45 | "it'dve": "it'd've", 46 | 'itll': "it'll", 47 | "let's": "let's", 48 | 'maam': "ma'am", 49 | 'mightnt': "mightn't", 50 | "mightnt've": "mightn't've", 51 | "mightn'tve": "mightn't've", 52 | 'mightve': "might've", 53 | 'mustnt': "mustn't", 54 | 'mustve': "must've", 55 | 'neednt': "needn't", 56 | 'notve': "not've", 57 | 'oclock': "o'clock", 58 | 'oughtnt': "oughtn't", 59 | "ow's'at": "'ow's'at", 60 | "'ows'at": "'ow's'at", 61 | "'ow'sat": "'ow's'at", 62 | 'shant': "shan't", 63 | "shed've": "she'd've", 64 | "she'dve": "she'd've", 65 | "she's": "she's", 66 | 'shouldve': "should've", 67 | 'shouldnt': "shouldn't", 68 | "shouldnt've": "shouldn't've", 69 | "shouldn'tve": "shouldn't've", 70 | "somebody'd": 'somebodyd', 71 | "somebodyd've": "somebody'd've", 72 | "somebody'dve": "somebody'd've", 73 | 'somebodyll': "somebody'll", 74 | 'somebodys': "somebody's", 75 | 'someoned': "someone'd", 76 | "someoned've": "someone'd've", 77 | "someone'dve": "someone'd've", 78 | 'someonell': "someone'll", 79 | 'someones': "someone's", 80 | 'somethingd': "something'd", 81 | "somethingd've": "something'd've", 82 | "something'dve": "something'd've", 83 | 'somethingll': "something'll", 84 | 'thats': "that's", 85 | 'thered': "there'd", 86 | "thered've": "there'd've", 87 | "there'dve": "there'd've", 88 | 'therere': "there're", 89 | 'theres': "there's", 90 | 'theyd': "they'd", 91 | "theyd've": "they'd've", 92 | "they'dve": "they'd've", 93 | 'theyll': "they'll", 94 | 'theyre': "they're", 95 | 'theyve': "they've", 96 | 'twas': "'twas", 97 | 'wasnt': "wasn't", 98 | "wed've": "we'd've", 99 | "we'dve": "we'd've", 100 | 
'weve': "we've", 101 | 'werent': "weren't", 102 | 'whatll': "what'll", 103 | 'whatre': "what're", 104 | 'whats': "what's", 105 | 'whatve': "what've", 106 | 'whens': "when's", 107 | 'whered': "where'd", 108 | 'wheres': "where's", 109 | 'whereve': "where've", 110 | 'whod': "who'd", 111 | "whod've": "who'd've", 112 | "who'dve": "who'd've", 113 | 'wholl': "who'll", 114 | 'whos': "who's", 115 | 'whove': "who've", 116 | 'whyll': "why'll", 117 | 'whyre': "why're", 118 | 'whys': "why's", 119 | 'wont': "won't", 120 | 'wouldve': "would've", 121 | 'wouldnt': "wouldn't", 122 | "wouldnt've": "wouldn't've", 123 | "wouldn'tve": "wouldn't've", 124 | 'yall': "y'all", 125 | "yall'll": "y'all'll", 126 | "y'allll": "y'all'll", 127 | "yall'd've": "y'all'd've", 128 | "y'alld've": "y'all'd've", 129 | "y'all'dve": "y'all'd've", 130 | 'youd': "you'd", 131 | "youd've": "you'd've", 132 | "you'dve": "you'd've", 133 | 'youll': "you'll", 134 | 'youre': "you're", 135 | 'youve': "you've", 136 | } 137 | 138 | NUMBER_MAP = { 139 | 'none': '0', 140 | 'zero': '0', 141 | 'one': '1', 142 | 'two': '2', 143 | 'three': '3', 144 | 'four': '4', 145 | 'five': '5', 146 | 'six': '6', 147 | 'seven': '7', 148 | 'eight': '8', 149 | 'nine': '9', 150 | 'ten': '10', 151 | } 152 | ARTICLES = ['a', 'an', 'the'] 153 | PERIOD_STRIP = re.compile(r'(?!<=\d)(\.)(?!\d)') 154 | COMMA_STRIP = re.compile(r'(?<=\d)(\,)+(?=\d)') 155 | PUNCTUATIONS = [ 156 | ';', 157 | r'/', 158 | '[', 159 | ']', 160 | '"', 161 | '{', 162 | '}', 163 | '(', 164 | ')', 165 | '=', 166 | '+', 167 | '\\', 168 | '_', 169 | '-', 170 | '>', 171 | '<', 172 | '@', 173 | '`', 174 | ',', 175 | '?', 176 | '!', 177 | ] 178 | 179 | def __init__(self, *args, **kwargs): 180 | pass 181 | 182 | def word_tokenize(self, word): 183 | word = word.lower() 184 | word = word.replace(',', '').replace('?', '').replace("'s", " 's") 185 | return word.strip() 186 | 187 | def process_punctuation(self, in_text): 188 | out_text = in_text 189 | for p in self.PUNCTUATIONS: 190 | if (p + ' ' in in_text or ' ' + p in in_text) or ( 191 | re.search(self.COMMA_STRIP, in_text) is not None 192 | ): 193 | out_text = out_text.replace(p, '') 194 | else: 195 | out_text = out_text.replace(p, ' ') 196 | out_text = self.PERIOD_STRIP.sub('', out_text, re.UNICODE) 197 | return out_text 198 | 199 | def process_digit_article(self, in_text): 200 | out_text = [] 201 | temp_text = in_text.lower().split() 202 | for word in temp_text: 203 | word = self.NUMBER_MAP.setdefault(word, word) 204 | if word not in self.ARTICLES: 205 | out_text.append(word) 206 | else: 207 | pass 208 | for word_id, word in enumerate(out_text): 209 | if word in self.CONTRACTIONS: 210 | out_text[word_id] = self.CONTRACTIONS[word] 211 | out_text = ' '.join(out_text) 212 | return out_text 213 | 214 | def __call__(self, item): 215 | item = self.word_tokenize(item) 216 | item = item.replace('\n', ' ').replace('\t', ' ').strip() 217 | item = self.process_punctuation(item) 218 | item = self.process_digit_article(item) 219 | return item 220 | 221 | 222 | class TextVQAAccuracyEvaluator: 223 | def __init__(self): 224 | self.answer_processor = EvalAIAnswerProcessor() 225 | 226 | def _compute_answer_scores(self, raw_answers): 227 | """ 228 | compute the accuracy (soft score) of human answers 229 | """ 230 | answers = [self.answer_processor(a) for a in raw_answers] 231 | assert len(answers) == 10 232 | gt_answers = list(enumerate(answers)) 233 | unique_answers = set(answers) 234 | unique_answer_scores = {} 235 | 236 | for unique_answer in unique_answers: 237 | accs = 
[] 238 | for gt_answer in gt_answers: 239 | other_answers = [item for item in gt_answers if item != gt_answer] 240 | matching_answers = [ 241 | item for item in other_answers if item[1] == unique_answer 242 | ] 243 | acc = min(1, float(len(matching_answers)) / 3) 244 | accs.append(acc) 245 | unique_answer_scores[unique_answer] = sum(accs) / len(accs) 246 | 247 | return unique_answer_scores 248 | 249 | def eval_pred_list(self, pred_list): 250 | pred_scores = [] 251 | for entry in tqdm(pred_list): 252 | pred_answer = self.answer_processor(entry['pred_answer']) 253 | unique_answer_scores = self._compute_answer_scores(entry['gt_answers']) 254 | score = unique_answer_scores.get(pred_answer, 0.0) 255 | pred_scores.append(score) 256 | 257 | accuracy = sum(pred_scores) / len(pred_scores) 258 | return accuracy 259 | 260 | 261 | class STVQAAccuracyEvaluator: 262 | def __init__(self): 263 | self.answer_processor = EvalAIAnswerProcessor() 264 | 265 | def eval_pred_list(self, pred_list): 266 | pred_scores = [] 267 | for entry in pred_list: 268 | pred_answer = self.answer_processor(entry['pred_answer']) 269 | gts = [self.answer_processor(a) for a in entry['gt_answers']] 270 | score = 1.0 if pred_answer in gts else 0.0 271 | pred_scores.append(score) 272 | 273 | accuracy = sum(pred_scores) / len(pred_scores) 274 | return accuracy 275 | 276 | 277 | class STVQAANLSEvaluator: 278 | def __init__(self): 279 | import editdistance # install with `pip install editdistance` 280 | 281 | self.get_edit_distance = editdistance.eval 282 | 283 | def get_anls(self, s1, s2): 284 | s1 = s1.lower().strip() 285 | s2 = s2.lower().strip() 286 | iou = 1 - self.get_edit_distance(s1, s2) / max(len(s1), len(s2)) 287 | anls = iou if iou >= 0.5 else 0.0 288 | return anls 289 | 290 | def eval_pred_list(self, pred_list): 291 | pred_scores = [] 292 | for entry in pred_list: 293 | anls = max( 294 | self.get_anls(entry['pred_answer'], gt) for gt in entry['gt_answers'] 295 | ) 296 | pred_scores.append(anls) 297 | 298 | accuracy = sum(pred_scores) / len(pred_scores) 299 | return accuracy 300 | 301 | 302 | class TextCapsBleu4Evaluator: 303 | def __init__(self): 304 | # The following script requires Java 1.8.0 and pycocotools installed. 305 | # The pycocoevalcap can be installed with pip as 306 | # pip install git+https://github.com/ronghanghu/coco-caption.git@python23 307 | # Original pycocoevalcap code is at https://github.com/tylin/coco-caption 308 | # but has no python3 support yet. 309 | try: 310 | from pycocoevalcap.bleu.bleu import Bleu 311 | from pycocoevalcap.tokenizer.ptbtokenizer import PTBTokenizer 312 | except ModuleNotFoundError: 313 | print( 314 | 'Please install pycocoevalcap module using ' 315 | 'pip install git+https://github.com/ronghanghu/coco-caption.git@python23' # noqa 316 | ) 317 | raise 318 | 319 | self.tokenizer = PTBTokenizer() 320 | self.scorer = Bleu(4) 321 | 322 | def eval_pred_list(self, pred_list): 323 | # Create reference and hypotheses captions. 
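# PTBTokenizer expects, for both references (gts) and hypotheses (res), a dict
# keyed by sample id whose values are lists of {'caption': str} entries.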
324 | gts = {} 325 | res = {} 326 | for idx, entry in enumerate(pred_list): 327 | gts[idx] = [{'caption': a} for a in entry['gt_answers']] 328 | res[idx] = [{'caption': entry['pred_answer']}] 329 | 330 | gts = self.tokenizer.tokenize(gts) 331 | res = self.tokenizer.tokenize(res) 332 | score, _ = self.scorer.compute_score(gts, res) 333 | 334 | bleu4 = score[3] # score is (Bleu-1, Bleu-2, Bleu-3, Bleu-4) 335 | return bleu4 336 | -------------------------------------------------------------------------------- /evaluate.sh: -------------------------------------------------------------------------------- 1 | set -x 2 | 3 | CHECKPOINT=${1} 4 | DATASET=${2} 5 | # CHECKPOINT="$(pwd)/${CHECKPOINT}" 6 | CHECKPOINT=${CHECKPOINT} 7 | export PYTHONPATH="$(pwd):${PYTHONPATH}" 8 | echo "CHECKPOINT: ${CHECKPOINT}" 9 | 10 | MASTER_PORT=${MASTER_PORT:-63669} 11 | PORT=${PORT:-63665} 12 | GPUS=${GPUS:-8} 13 | GPUS_PER_NODE=${GPUS_PER_NODE:-8} 14 | NODES=$((GPUS / GPUS_PER_NODE)) 15 | export MASTER_PORT=${MASTER_PORT} 16 | export PORT=${PORT} 17 | 18 | # Save original arguments 19 | ARGS=("$@") 20 | 21 | # Parse options 22 | while [[ $# -gt 0 ]]; do 23 | case "$1" in 24 | --auto) 25 | GPUS=1 26 | shift 27 | ;; 28 | *) 29 | shift 30 | ;; 31 | esac 32 | done 33 | echo "GPUS: ${GPUS}" 34 | 35 | # Check whether CHECKPOINT ends with a trailing slash; if so, strip it 36 | if [[ "${CHECKPOINT}" == */ ]]; then 37 | CHECKPOINT="${CHECKPOINT%/}" 38 | fi 39 | 40 | # Directory to save result files 41 | mkdir -p "$CHECKPOINT/eval" 42 | 43 | if [ ${DATASET} == "mme" ]; then 44 | cd eval/mme/ 45 | DIRNAME=`basename ${CHECKPOINT}` 46 | python eval.py --checkpoint ${CHECKPOINT} "${ARGS[@]:2}" 47 | python calculation.py --results_dir ${CHECKPOINT}/mme 2>&1 | tee -a ${CHECKPOINT}/mme/result_final.txt 48 | cd ../../ 49 | fi 50 | 51 | if [ ${DATASET} == "caption" ]; then 52 | torchrun \ 53 | --nnodes=1 \ 54 | --node_rank=0 \ 55 | --master_addr=127.0.0.1 \ 56 | --nproc_per_node=${GPUS} \ 57 | --master_port=${MASTER_PORT} \ 58 | eval/caption/evaluate_caption.py --checkpoint ${CHECKPOINT} --out-dir ${CHECKPOINT}/eval "${ARGS[@]:2}" 59 | fi 60 | 61 | if [ ${DATASET} == "caption-coco" ]; then 62 | torchrun \ 63 | --nnodes=1 \ 64 | --node_rank=0 \ 65 | --master_addr=127.0.0.1 \ 66 | --nproc_per_node=${GPUS} \ 67 | --master_port=${MASTER_PORT} \ 68 | eval/caption/evaluate_caption.py --checkpoint ${CHECKPOINT} --out-dir ${CHECKPOINT}/eval --datasets coco "${ARGS[@]:2}" 69 | fi 70 | 71 | if [ ${DATASET} == "caption-flickr30k" ]; then 72 | torchrun \ 73 | --nnodes=1 \ 74 | --node_rank=0 \ 75 | --master_addr=127.0.0.1 \ 76 | --nproc_per_node=${GPUS} \ 77 | --master_port=${MASTER_PORT} \ 78 | eval/caption/evaluate_caption.py --checkpoint ${CHECKPOINT} --out-dir ${CHECKPOINT}/eval --datasets flickr30k "${ARGS[@]:2}" 79 | fi 80 | 81 | if [ ${DATASET} == "caption-nocaps" ]; then 82 | torchrun \ 83 | --nnodes=1 \ 84 | --node_rank=0 \ 85 | --master_addr=127.0.0.1 \ 86 | --nproc_per_node=${GPUS} \ 87 | --master_port=${MASTER_PORT} \ 88 | eval/caption/evaluate_caption.py --checkpoint ${CHECKPOINT} --out-dir ${CHECKPOINT}/eval --datasets nocaps "${ARGS[@]:2}" 89 | fi 90 | 91 | if [ ${DATASET} == "vqa" ]; then 92 | torchrun \ 93 | --nnodes=1 \ 94 | --node_rank=0 \ 95 | --master_addr=127.0.0.1 \ 96 | --nproc_per_node=${GPUS} \ 97 | --master_port=${MASTER_PORT} \ 98 | eval/vqa/evaluate_vqa.py --checkpoint ${CHECKPOINT} --out-dir ${CHECKPOINT}/eval "${ARGS[@]:2}" 99 | fi 100 | 101 | if [ ${DATASET} == "vqa-okvqa-val" ]; then 102 | torchrun \ 103 | --nnodes=1 \ 104 | --node_rank=0 \ 105 | 
--master_addr=127.0.0.1 \ 106 | --nproc_per_node=${GPUS} \ 107 | --master_port=${MASTER_PORT} \ 108 | eval/vqa/evaluate_vqa.py --checkpoint ${CHECKPOINT} --out-dir ${CHECKPOINT}/eval --datasets okvqa_val "${ARGS[@]:2}" 109 | fi 110 | 111 | if [ ${DATASET} == "vqa-textvqa-val" ]; then 112 | torchrun \ 113 | --nnodes=1 \ 114 | --node_rank=0 \ 115 | --master_addr=127.0.0.1 \ 116 | --nproc_per_node=${GPUS} \ 117 | --master_port=${MASTER_PORT} \ 118 | eval/vqa/evaluate_vqa.py --checkpoint ${CHECKPOINT} --out-dir ${CHECKPOINT}/eval --datasets textvqa_val "${ARGS[@]:2}" 119 | fi 120 | 121 | if [ ${DATASET} == "vqa-textvqa-val-ocr" ]; then 122 | torchrun \ 123 | --nnodes=1 \ 124 | --node_rank=0 \ 125 | --master_addr=127.0.0.1 \ 126 | --nproc_per_node=${GPUS} \ 127 | --master_port=${MASTER_PORT} \ 128 | eval/vqa/evaluate_vqa.py --checkpoint ${CHECKPOINT} --out-dir ${CHECKPOINT}/eval --datasets textvqa_val_ocr "${ARGS[@]:2}" 129 | fi 130 | 131 | if [ ${DATASET} == "vqa-vizwiz-val" ]; then 132 | torchrun \ 133 | --nnodes=1 \ 134 | --node_rank=0 \ 135 | --master_addr=127.0.0.1 \ 136 | --nproc_per_node=${GPUS} \ 137 | --master_port=${MASTER_PORT} \ 138 | eval/vqa/evaluate_vqa.py --checkpoint ${CHECKPOINT} --out-dir ${CHECKPOINT}/eval --datasets vizwiz_val "${ARGS[@]:2}" 139 | fi 140 | 141 | if [ ${DATASET} == "vqa-vizwiz-test" ]; then 142 | torchrun \ 143 | --nnodes=1 \ 144 | --node_rank=0 \ 145 | --master_addr=127.0.0.1 \ 146 | --nproc_per_node=${GPUS} \ 147 | --master_port=${MASTER_PORT} \ 148 | eval/vqa/evaluate_vqa.py --checkpoint ${CHECKPOINT} --out-dir ${CHECKPOINT}/eval --datasets vizwiz_test "${ARGS[@]:2}" 149 | fi 150 | 151 | if [ ${DATASET} == "vqa-vqav2-testdev" ]; then 152 | torchrun \ 153 | --nnodes=1 \ 154 | --node_rank=0 \ 155 | --master_addr=127.0.0.1 \ 156 | --nproc_per_node=${GPUS} \ 157 | --master_port=${MASTER_PORT} \ 158 | eval/vqa/evaluate_vqa.py --checkpoint ${CHECKPOINT} --out-dir ${CHECKPOINT}/eval --datasets vqav2_testdev "${ARGS[@]:2}" 159 | fi 160 | 161 | if [ ${DATASET} == "vqa-ai2d-test" ]; then 162 | torchrun \ 163 | --nnodes=1 \ 164 | --node_rank=0 \ 165 | --master_addr=127.0.0.1 \ 166 | --nproc_per_node=${GPUS} \ 167 | --master_port=${MASTER_PORT} \ 168 | eval/vqa/evaluate_vqa.py --checkpoint ${CHECKPOINT} --out-dir ${CHECKPOINT}/eval --datasets ai2diagram_test "${ARGS[@]:2}" 169 | fi 170 | 171 | if [ ${DATASET} == "vqa-vqav2-val" ]; then 172 | torchrun \ 173 | --nnodes=1 \ 174 | --node_rank=0 \ 175 | --master_addr=127.0.0.1 \ 176 | --nproc_per_node=${GPUS} \ 177 | --master_port=${MASTER_PORT} \ 178 | eval/vqa/evaluate_vqa.py --checkpoint ${CHECKPOINT} --out-dir ${CHECKPOINT}/eval --datasets vqav2_val "${ARGS[@]:2}" 179 | fi 180 | 181 | if [ ${DATASET} == "vqa-gqa-testdev" ]; then 182 | torchrun \ 183 | --nnodes=1 \ 184 | --node_rank=0 \ 185 | --master_addr=127.0.0.1 \ 186 | --nproc_per_node=${GPUS} \ 187 | --master_port=${MASTER_PORT} \ 188 | eval/vqa/evaluate_vqa.py --checkpoint ${CHECKPOINT} --out-dir ${CHECKPOINT}/eval --datasets gqa_testdev_llava "${ARGS[@]:2}" 189 | fi 190 | 191 | if [ ${DATASET} == "vqa-docvqa-val" ]; then 192 | torchrun \ 193 | --nnodes=1 \ 194 | --node_rank=0 \ 195 | --master_addr=127.0.0.1 \ 196 | --nproc_per_node=${GPUS} \ 197 | --master_port=${MASTER_PORT} \ 198 | eval/vqa/evaluate_vqa.py --checkpoint ${CHECKPOINT} --out-dir ${CHECKPOINT}/eval --datasets docvqa_val "${ARGS[@]:2}" 199 | fi 200 | 201 | if [ ${DATASET} == "vqa-docvqa-test" ]; then 202 | torchrun \ 203 | --nnodes=1 \ 204 | --node_rank=0 \ 205 | --master_addr=127.0.0.1 \ 206 | 
--nproc_per_node=${GPUS} \ 207 | --master_port=${MASTER_PORT} \ 208 | eval/vqa/evaluate_vqa.py --checkpoint ${CHECKPOINT} --out-dir ${CHECKPOINT}/eval --datasets docvqa_test "${ARGS[@]:2}" 209 | fi 210 | 211 | if [ ${DATASET} == "vqa-chartqa-test" ]; then 212 | torchrun \ 213 | --nnodes=1 \ 214 | --node_rank=0 \ 215 | --master_addr=127.0.0.1 \ 216 | --nproc_per_node=${GPUS} \ 217 | --master_port=${MASTER_PORT} \ 218 | eval/vqa/evaluate_vqa.py --checkpoint ${CHECKPOINT} --out-dir ${CHECKPOINT}/eval --datasets chartqa_test_human,chartqa_test_augmented "${ARGS[@]:2}" 219 | fi 220 | 221 | if [ ${DATASET} == "vqa-infovqa-val" ]; then 222 | torchrun \ 223 | --nnodes=1 \ 224 | --node_rank=0 \ 225 | --master_addr=127.0.0.1 \ 226 | --nproc_per_node=${GPUS} \ 227 | --master_port=${MASTER_PORT} \ 228 | eval/vqa/evaluate_vqa.py --checkpoint ${CHECKPOINT} --out-dir ${CHECKPOINT}/eval --datasets infographicsvqa_val "${ARGS[@]:2}" 229 | fi 230 | 231 | if [ ${DATASET} == "vqa-infovqa-test" ]; then 232 | torchrun \ 233 | --nnodes=1 \ 234 | --node_rank=0 \ 235 | --master_addr=127.0.0.1 \ 236 | --nproc_per_node=${GPUS} \ 237 | --master_port=${MASTER_PORT} \ 238 | eval/vqa/evaluate_vqa.py --checkpoint ${CHECKPOINT} --out-dir ${CHECKPOINT}/eval --datasets infographicsvqa_test "${ARGS[@]:2}" 239 | fi 240 | 241 | if [ ${DATASET} == "vqa-chartqa-test-human" ]; then 242 | torchrun \ 243 | --nnodes=1 \ 244 | --node_rank=0 \ 245 | --master_addr=127.0.0.1 \ 246 | --nproc_per_node=${GPUS} \ 247 | --master_port=${MASTER_PORT} \ 248 | eval/vqa/evaluate_vqa.py --checkpoint ${CHECKPOINT} --out-dir ${CHECKPOINT}/eval --datasets chartqa_test_human "${ARGS[@]:2}" 249 | fi 250 | 251 | if [ ${DATASET} == "vqa-chartqa-test-augmented" ]; then 252 | torchrun \ 253 | --nnodes=1 \ 254 | --node_rank=0 \ 255 | --master_addr=127.0.0.1 \ 256 | --nproc_per_node=${GPUS} \ 257 | --master_port=${MASTER_PORT} \ 258 | eval/vqa/evaluate_vqa.py --checkpoint ${CHECKPOINT} --out-dir ${CHECKPOINT}/eval --datasets chartqa_test_augmented "${ARGS[@]:2}" 259 | fi 260 | 261 | if [ ${DATASET} == "vqa-ocrvqa-val" ]; then 262 | torchrun \ 263 | --nnodes=1 \ 264 | --node_rank=0 \ 265 | --master_addr=127.0.0.1 \ 266 | --nproc_per_node=${GPUS} \ 267 | --master_port=${MASTER_PORT} \ 268 | eval/vqa/evaluate_vqa.py --checkpoint ${CHECKPOINT} --out-dir ${CHECKPOINT}/eval --datasets ocrvqa_val "${ARGS[@]:2}" 269 | fi 270 | 271 | if [ ${DATASET} == "vqa-ocrvqa-test" ]; then 272 | torchrun \ 273 | --nnodes=1 \ 274 | --node_rank=0 \ 275 | --master_addr=127.0.0.1 \ 276 | --nproc_per_node=${GPUS} \ 277 | --master_port=${MASTER_PORT} \ 278 | eval/vqa/evaluate_vqa.py --checkpoint ${CHECKPOINT} --out-dir ${CHECKPOINT}/eval --datasets ocrvqa_test "${ARGS[@]:2}" 279 | fi 280 | 281 | if [ ${DATASET} == "refcoco" ]; then 282 | torchrun \ 283 | --nnodes=1 \ 284 | --node_rank=0 \ 285 | --master_addr=127.0.0.1 \ 286 | --nproc_per_node=${GPUS} \ 287 | --master_port=${MASTER_PORT} \ 288 | eval/refcoco/evaluate_grounding.py --checkpoint ${CHECKPOINT} --out-dir ${CHECKPOINT}/eval "${ARGS[@]:2}" 289 | fi 290 | 291 | if [ ${DATASET} == "refcoco-val" ]; then 292 | torchrun \ 293 | --nnodes=1 \ 294 | --node_rank=0 \ 295 | --master_addr=127.0.0.1 \ 296 | --nproc_per_node=${GPUS} \ 297 | --master_port=${MASTER_PORT} \ 298 | eval/refcoco/evaluate_grounding.py --checkpoint ${CHECKPOINT} --out-dir ${CHECKPOINT}/eval --datasets refcoco_val "${ARGS[@]:2}" 299 | fi 300 | 301 | if [ ${DATASET} == "llava-bench" ]; then 302 | rm -rf results/llava_bench_results_review.jsonl 303 | python 
eval/llava_bench/evaluate_llava_bench.py --checkpoint ${CHECKPOINT} "${ARGS[@]:2}" 304 | python -u eval/llava_bench/eval_gpt_review_bench.py \ 305 | --question data/llava-bench-in-the-wild/questions.jsonl \ 306 | --context data/llava-bench-in-the-wild/context.jsonl \ 307 | --rule eval/llava_bench/rule.json \ 308 | --answer-list \ 309 | data/llava-bench-in-the-wild/answers_gpt4.jsonl \ 310 | results/llava_bench_results.jsonl \ 311 | --output \ 312 | results/llava_bench_results_review.jsonl 313 | python -u eval/llava_bench/summarize_gpt_review.py -f results/llava_bench_results_review.jsonl 314 | fi 315 | 316 | if [ ${DATASET} == "pope" ]; then 317 | torchrun \ 318 | --nnodes=1 \ 319 | --node_rank=0 \ 320 | --master_addr=127.0.0.1 \ 321 | --nproc_per_node=${GPUS} \ 322 | --master_port=${MASTER_PORT} \ 323 | eval/pope/evaluate_pope.py --checkpoint ${CHECKPOINT} --out-dir ${CHECKPOINT}/eval --datasets pope "${ARGS[@]:2}" 324 | fi 325 | 326 | if [ ${DATASET} == "tiny_lvlm" ]; then 327 | torchrun \ 328 | --nnodes=1 \ 329 | --node_rank=0 \ 330 | --master_addr=127.0.0.1 \ 331 | --nproc_per_node=${GPUS} \ 332 | --master_port=${MASTER_PORT} \ 333 | eval/tiny_lvlm/evaluate_lvlm.py --checkpoint ${CHECKPOINT} --out-dir ${CHECKPOINT}/eval --datasets updated_datasets "${ARGS[@]:2}" 334 | fi 335 | 336 | if [ ${DATASET} == "mmvet" ]; then 337 | python eval/mmvet/evaluate_mmvet.py --checkpoint ${CHECKPOINT} --out-dir ${CHECKPOINT}/eval --datasets mmvet "${ARGS[@]:2}" 338 | fi 339 | 340 | if [ ${DATASET} == "cmmmu" ]; then 341 | CUDA_VISIBLE_DEVICES=0 python eval/cmmmu/evaluate_cmmmu.py --checkpoint ${CHECKPOINT} --out-dir ${CHECKPOINT}/eval --datasets art_and_design "${ARGS[@]:2}" & 342 | CUDA_VISIBLE_DEVICES=1 python eval/cmmmu/evaluate_cmmmu.py --checkpoint ${CHECKPOINT} --out-dir ${CHECKPOINT}/eval --datasets business "${ARGS[@]:2}" & 343 | CUDA_VISIBLE_DEVICES=2 python eval/cmmmu/evaluate_cmmmu.py --checkpoint ${CHECKPOINT} --out-dir ${CHECKPOINT}/eval --datasets health_and_medicine "${ARGS[@]:2}" & 344 | CUDA_VISIBLE_DEVICES=3 python eval/cmmmu/evaluate_cmmmu.py --checkpoint ${CHECKPOINT} --out-dir ${CHECKPOINT}/eval --datasets humanities_and_social_sciences "${ARGS[@]:2}" & 345 | CUDA_VISIBLE_DEVICES=4 python eval/cmmmu/evaluate_cmmmu.py --checkpoint ${CHECKPOINT} --out-dir ${CHECKPOINT}/eval --datasets science "${ARGS[@]:2}" & 346 | CUDA_VISIBLE_DEVICES=5 python eval/cmmmu/evaluate_cmmmu.py --checkpoint ${CHECKPOINT} --out-dir ${CHECKPOINT}/eval --datasets technology_and_engineering "${ARGS[@]:2}" & 347 | wait 348 | fi 349 | 350 | if [ ${DATASET} == "mmbench-dev-en" ]; then 351 | torchrun \ 352 | --nnodes=1 \ 353 | --node_rank=0 \ 354 | --master_addr=127.0.0.1 \ 355 | --nproc_per_node=${GPUS} \ 356 | --master_port=${MASTER_PORT} \ 357 | eval/mmbench/evaluate_mmbench.py --checkpoint ${CHECKPOINT} --out-dir ${CHECKPOINT}/eval --datasets mmbench_dev_20230712 "${ARGS[@]:2}" 358 | fi 359 | 360 | if [ ${DATASET} == "mmbench-dev-cn" ]; then 361 | torchrun \ 362 | --nnodes=1 \ 363 | --node_rank=0 \ 364 | --master_addr=127.0.0.1 \ 365 | --nproc_per_node=${GPUS} \ 366 | --master_port=${MASTER_PORT} \ 367 | eval/mmbench/evaluate_mmbench.py --checkpoint ${CHECKPOINT} --out-dir ${CHECKPOINT}/eval --datasets mmbench_dev_cn_20231003 "${ARGS[@]:2}" 368 | fi 369 | 370 | if [ ${DATASET} == "mmbench-test-en" ]; then 371 | torchrun \ 372 | --nnodes=1 \ 373 | --node_rank=0 \ 374 | --master_addr=127.0.0.1 \ 375 | --nproc_per_node=${GPUS} \ 376 | --master_port=${MASTER_PORT} \ 377 | eval/mmbench/evaluate_mmbench.py \ 378 | 
--checkpoint ${CHECKPOINT} \ 379 | --out-dir ${CHECKPOINT}/eval \ 380 | --datasets mmbench_test_en_20231003 "${ARGS[@]:2}" 381 | fi 382 | 383 | if [ ${DATASET} == "mmbench-test-cn" ]; then 384 | torchrun \ 385 | --nnodes=1 \ 386 | --node_rank=0 \ 387 | --master_addr=127.0.0.1 \ 388 | --nproc_per_node=${GPUS} \ 389 | --master_port=${MASTER_PORT} \ 390 | eval/mmbench/evaluate_mmbench.py --checkpoint ${CHECKPOINT} --out-dir ${CHECKPOINT}/eval --datasets mmbench_test_cn_20231003 "${ARGS[@]:2}" 391 | fi 392 | 393 | if [ ${DATASET} == "ccbench-dev" ]; then 394 | torchrun \ 395 | --nnodes=1 \ 396 | --node_rank=0 \ 397 | --master_addr=127.0.0.1 \ 398 | --nproc_per_node=${GPUS} \ 399 | --master_port=${MASTER_PORT} \ 400 | eval/mmbench/evaluate_mmbench.py --checkpoint ${CHECKPOINT} --out-dir ${CHECKPOINT}/eval --datasets ccbench_dev_cn "${ARGS[@]:2}" 401 | fi 402 | 403 | if [ ${DATASET} == "scienceqa" ]; then 404 | torchrun \ 405 | --nnodes=1 \ 406 | --node_rank=0 \ 407 | --master_addr=127.0.0.1 \ 408 | --nproc_per_node=${GPUS} \ 409 | --master_port=${MASTER_PORT} \ 410 | eval/scienceqa/evaluate_scienceqa.py --checkpoint ${CHECKPOINT} --out-dir ${CHECKPOINT}/eval --datasets sqa_test "${ARGS[@]:2}" 411 | fi 412 | 413 | 414 | if [ ${DATASET} == "mmmu-dev" ]; then 415 | torchrun \ 416 | --nnodes=1 \ 417 | --node_rank=0 \ 418 | --master_addr=127.0.0.1 \ 419 | --nproc_per_node=${GPUS} \ 420 | --master_port=${MASTER_PORT} \ 421 | eval/mmmu/evaluate_mmmu.py --checkpoint ${CHECKPOINT} --out-dir ${CHECKPOINT}/eval --datasets MMMU_dev "${ARGS[@]:2}" 422 | fi 423 | 424 | if [ ${DATASET} == "mmmu-val" ]; then 425 | torchrun \ 426 | --nnodes=1 \ 427 | --node_rank=0 \ 428 | --master_addr=127.0.0.1 \ 429 | --nproc_per_node=${GPUS} \ 430 | --master_port=${MASTER_PORT} \ 431 | eval/mmmu/evaluate_mmmu.py --checkpoint ${CHECKPOINT} --out-dir ${CHECKPOINT}/eval --datasets MMMU_validation "${ARGS[@]:2}" 432 | fi 433 | 434 | if [ ${DATASET} == "mmmu-test" ]; then 435 | torchrun \ 436 | --nnodes=1 \ 437 | --node_rank=0 \ 438 | --master_addr=127.0.0.1 \ 439 | --nproc_per_node=${GPUS} \ 440 | --master_port=${MASTER_PORT} \ 441 | eval/mmmu/evaluate_mmmu.py --checkpoint ${CHECKPOINT} --out-dir ${CHECKPOINT}/eval --datasets MMMU_test "${ARGS[@]:2}" 442 | fi 443 | 444 | if [ ${DATASET} == "mmmu-dev-cot" ]; then 445 | torchrun \ 446 | --nnodes=1 \ 447 | --node_rank=0 \ 448 | --master_addr=127.0.0.1 \ 449 | --nproc_per_node=${GPUS} \ 450 | --master_port=${MASTER_PORT} \ 451 | eval/mmmu/evaluate_mmmu_cot.py --checkpoint ${CHECKPOINT} --out-dir ${CHECKPOINT}/eval --datasets MMMU_dev "${ARGS[@]:2}" 452 | fi 453 | 454 | if [ ${DATASET} == "mmmu-val-cot" ]; then 455 | torchrun \ 456 | --nnodes=1 \ 457 | --node_rank=0 \ 458 | --master_addr=127.0.0.1 \ 459 | --nproc_per_node=${GPUS} \ 460 | --master_port=${MASTER_PORT} \ 461 | eval/mmmu/evaluate_mmmu_cot.py --checkpoint ${CHECKPOINT} --out-dir ${CHECKPOINT}/eval --datasets MMMU_validation "${ARGS[@]:2}" 462 | fi 463 | 464 | if [ ${DATASET} == "mmmu-test-cot" ]; then 465 | torchrun \ 466 | --nnodes=1 \ 467 | --node_rank=0 \ 468 | --master_addr=127.0.0.1 \ 469 | --nproc_per_node=${GPUS} \ 470 | --master_port=${MASTER_PORT} \ 471 | eval/mmmu/evaluate_mmmu_cot.py --checkpoint ${CHECKPOINT} --out-dir ${CHECKPOINT}/eval --datasets MMMU_test "${ARGS[@]:2}" 472 | fi 473 | 474 | 475 | if [ ${DATASET} == "mmvp" ]; then 476 | torchrun \ 477 | --nnodes=1 \ 478 | --node_rank=0 \ 479 | --master_addr=127.0.0.1 \ 480 | --nproc_per_node=${GPUS} \ 481 | --master_port=${MASTER_PORT} \ 482 | 
eval/mmvp/evaluate_mmvp.py --checkpoint ${CHECKPOINT} --out-dir ${CHECKPOINT}/eval --datasets MMVP "${ARGS[@]:2}" 483 | fi 484 | 485 | 486 | if [ ${DATASET} == "mathvista-testmini" ]; then 487 | torchrun \ 488 | --nnodes=1 \ 489 | --node_rank=0 \ 490 | --master_addr=127.0.0.1 \ 491 | --nproc_per_node=${GPUS} \ 492 | --master_port=${MASTER_PORT} \ 493 | eval/mathvista/evaluate_mathvista.py --checkpoint ${CHECKPOINT} --out-dir ${CHECKPOINT}/eval --datasets MathVista_testmini "${ARGS[@]:2}" 494 | fi 495 | 496 | 497 | if [ ${DATASET} == "mathvista-test" ]; then 498 | torchrun \ 499 | --nnodes=1 \ 500 | --node_rank=0 \ 501 | --master_addr=127.0.0.1 \ 502 | --nproc_per_node=${GPUS} \ 503 | --master_port=${MASTER_PORT} \ 504 | eval/mathvista/evaluate_mathvista.py --checkpoint ${CHECKPOINT} --out-dir ${CHECKPOINT}/eval --datasets MathVista_test "${ARGS[@]:2}" 505 | fi 506 | 507 | if [ ${DATASET} == "seed" ]; then 508 | torchrun \ 509 | --nnodes=1 \ 510 | --node_rank=0 \ 511 | --master_addr=127.0.0.1 \ 512 | --nproc_per_node=${GPUS} \ 513 | --master_port=${MASTER_PORT} \ 514 | eval/seed/evaluate_seed.py --checkpoint ${CHECKPOINT} --out-dir ${CHECKPOINT}/eval --datasets SEEDv1 "${ARGS[@]:2}" 515 | fi 516 | 517 | if [ ${DATASET} == "mvbench" ]; then 518 | torchrun \ 519 | --nnodes=1 \ 520 | --node_rank=0 \ 521 | --master_addr=127.0.0.1 \ 522 | --nproc_per_node=${GPUS} \ 523 | --master_port=${MASTER_PORT} \ 524 | eval/mvbench/evaluate_mvbench.py \ 525 | --checkpoint ${CHECKPOINT} \ 526 | --out-dir ${CHECKPOINT}/eval \ 527 | "${ARGS[@]:2}" 528 | fi 529 | -------------------------------------------------------------------------------- /evaluate_launch.sh: -------------------------------------------------------------------------------- 1 | set -x 2 | 3 | CHECKPOINT=${1} 4 | TASK=${2} 5 | ARGS=("$@") 6 | 7 | if [ "${TASK}" == "vqa-chartqa-test" ]; then 8 | sh evaluate.sh ${CHECKPOINT} ${TASK} --dynamic --max-num 12 "${ARGS[@]:3}" 9 | elif [ "${TASK}" == "vqa-infovqa-val" -o "${TASK}" == "vqa-infovqa-test" ]; then 10 | sh evaluate.sh ${CHECKPOINT} ${TASK} --dynamic --max-num 24 "${ARGS[@]:3}" 11 | elif [ "${TASK}" == "vqa-docvqa-val" -o "${TASK}" == "vqa-docvqa-test" ]; then 12 | sh evaluate.sh ${CHECKPOINT} ${TASK} --dynamic --max-num 18 "${ARGS[@]:3}" 13 | elif [ "${TASK}" == "mvbench" ]; then 14 | sh evaluate.sh ${CHECKPOINT} ${TASK} --num_segments 96 "${ARGS[@]:3}" 15 | else 16 | sh evaluate.sh ${CHECKPOINT} ${TASK} --dynamic --max-num 6 "${ARGS[@]:3}" 17 | fi 18 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | accelerate 2 | bitsandbytes==0.41.0 3 | decord 4 | deepspeed==0.13.5 5 | einops==0.6.1 6 | einops-exts==0.0.4 7 | huggingface_hub 8 | imageio 9 | numpy 10 | opencv-python 11 | orjson 12 | peft>=0.4.0 13 | pycocoevalcap 14 | pyyaml 15 | scikit-learn>=1.2.2 16 | scipy 17 | sentencepiece==0.1.99 18 | shortuuid 19 | tensorboardX 20 | termcolor 21 | timm==0.9.12 22 | tokenizers==0.15.1 23 | torch>=2 24 | torchvision>=0.15 25 | tqdm 26 | transformers==4.37.2 27 | yacs -------------------------------------------------------------------------------- /utils/preprocess.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torchvision.transforms as T 4 | from decord import VideoReader, cpu 5 | from PIL import Image 6 | from torchvision.transforms.functional import InterpolationMode 7 | 8 | 
9 | IMAGENET_MEAN = (0.485, 0.456, 0.406) 10 | IMAGENET_STD = (0.229, 0.224, 0.225) 11 | 12 | 13 | def build_transform(input_size): 14 | MEAN, STD = IMAGENET_MEAN, IMAGENET_STD 15 | transform = T.Compose([ 16 | T.Lambda(lambda img: img.convert('RGB') if img.mode != 'RGB' else img), 17 | T.Resize((input_size, input_size), interpolation=InterpolationMode.BICUBIC), 18 | T.ToTensor(), 19 | T.Normalize(mean=MEAN, std=STD) 20 | ]) 21 | return transform 22 | 23 | 24 | def find_closest_aspect_ratio(aspect_ratio, target_ratios, width, height, image_size): 25 | best_ratio_diff = float('inf') 26 | best_ratio = (1, 1) 27 | area = width * height 28 | for ratio in target_ratios: 29 | target_aspect_ratio = ratio[0] / ratio[1] 30 | ratio_diff = abs(aspect_ratio - target_aspect_ratio) 31 | if ratio_diff < best_ratio_diff: 32 | best_ratio_diff = ratio_diff 33 | best_ratio = ratio 34 | elif ratio_diff == best_ratio_diff: 35 | if area > 0.5 * image_size * image_size * ratio[0] * ratio[1]: 36 | best_ratio = ratio 37 | return best_ratio 38 | 39 | 40 | def dynamic_preprocess(image, min_num=1, max_num=12, image_size=448, use_thumbnail=False): 41 | orig_width, orig_height = image.size 42 | aspect_ratio = orig_width / orig_height 43 | 44 | # calculate the existing image aspect ratio 45 | target_ratios = set( 46 | (i, j) for n in range(min_num, max_num + 1) for i in range(1, n + 1) for j in range(1, n + 1) if 47 | i * j <= max_num and i * j >= min_num) 48 | target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1]) 49 | 50 | # find the closest aspect ratio to the target 51 | target_aspect_ratio = find_closest_aspect_ratio( 52 | aspect_ratio, target_ratios, orig_width, orig_height, image_size) 53 | 54 | # calculate the target width and height 55 | target_width = image_size * target_aspect_ratio[0] 56 | target_height = image_size * target_aspect_ratio[1] 57 | blocks = target_aspect_ratio[0] * target_aspect_ratio[1] 58 | 59 | # resize the image 60 | resized_img = image.resize((target_width, target_height)) 61 | processed_images = [] 62 | for i in range(blocks): 63 | box = ( 64 | (i % (target_width // image_size)) * image_size, 65 | (i // (target_width // image_size)) * image_size, 66 | ((i % (target_width // image_size)) + 1) * image_size, 67 | ((i // (target_width // image_size)) + 1) * image_size 68 | ) 69 | # split the image 70 | split_img = resized_img.crop(box) 71 | processed_images.append(split_img) 72 | assert len(processed_images) == blocks 73 | if use_thumbnail and len(processed_images) != 1: 74 | thumbnail_img = image.resize((image_size, image_size)) 75 | processed_images.append(thumbnail_img) 76 | return processed_images 77 | 78 | 79 | def load_image(image_file, input_size=448, max_num=12): 80 | image = Image.open(image_file).convert('RGB') 81 | transform = build_transform(input_size=input_size) 82 | images = dynamic_preprocess(image, image_size=input_size, use_thumbnail=True, max_num=max_num) 83 | pixel_values = [transform(image) for image in images] 84 | pixel_values = torch.stack(pixel_values) 85 | return pixel_values 86 | 87 | 88 | def get_index(bound, fps, max_frame, first_idx=0, num_segments=32): 89 | if bound: 90 | start, end = bound[0], bound[1] 91 | else: 92 | start, end = -100000, 100000 93 | start_idx = max(first_idx, round(start * fps)) 94 | end_idx = min(round(end * fps), max_frame) 95 | seg_size = float(end_idx - start_idx) / num_segments 96 | frame_indices = np.array([ 97 | int(start_idx + (seg_size / 2) + np.round(seg_size * idx)) 98 | for idx in range(num_segments) 99 | ]) 100 | return 
frame_indices 101 | 102 | 103 | def load_video(video_path, bound=None, input_size=448, max_num=1, num_segments=32): 104 | vr = VideoReader(video_path, ctx=cpu(0), num_threads=1) 105 | max_frame = len(vr) - 1 106 | fps = float(vr.get_avg_fps()) 107 | 108 | pixel_values_list, num_patches_list = [], [] 109 | transform = build_transform(input_size=input_size) 110 | frame_indices = get_index(bound, fps, max_frame, first_idx=0, num_segments=num_segments) 111 | for frame_index in frame_indices: 112 | img = Image.fromarray(vr[frame_index].asnumpy()).convert('RGB') 113 | img = dynamic_preprocess(img, image_size=input_size, use_thumbnail=True, max_num=max_num) 114 | pixel_values = [transform(tile) for tile in img] 115 | pixel_values = torch.stack(pixel_values) 116 | num_patches_list.append(pixel_values.shape[0]) 117 | pixel_values_list.append(pixel_values) 118 | pixel_values = torch.cat(pixel_values_list) 119 | return pixel_values, num_patches_list 120 | --------------------------------------------------------------------------------
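The evaluation entry point is `sh evaluate_launch.sh <checkpoint> <task>`: `evaluate_launch.sh` selects task-specific tiling and frame-sampling arguments (`--max-num`, `--num_segments`) and forwards any remaining flags to `evaluate.sh`, which dispatches to the per-benchmark scripts under `eval/`. For standalone inference, the helpers in `utils/preprocess.py` prepare image tiles and sampled video frames. The snippet below is a minimal usage sketch rather than the repository's documented interface: it assumes the released checkpoint exposes an InternVL-style `chat()` method through `trust_remote_code=True`, and the prompt format and generation settings are illustrative assumptions.

```python
# Minimal usage sketch. Assumptions: the checkpoint exposes an InternVL-style
# `chat()` method via `trust_remote_code=True`; the prompt format and
# generation settings are illustrative, not the repository's documented defaults.
import torch
from transformers import AutoModel, AutoTokenizer

from utils.preprocess import load_image, load_video

path = 'OpenGVLab/PVC-InternVL2-8B'  # or a local checkpoint directory
model = AutoModel.from_pretrained(
    path, torch_dtype=torch.bfloat16, trust_remote_code=True).eval().cuda()
tokenizer = AutoTokenizer.from_pretrained(path, trust_remote_code=True, use_fast=False)
generation_config = dict(max_new_tokens=512, do_sample=False)

# Image: load_image() tiles the input into at most `max_num` 448x448 sub-images
# (plus a thumbnail), mirroring the --max-num values set in evaluate_launch.sh.
pixel_values = load_image('assets/example_image1.jpg', max_num=12).to(torch.bfloat16).cuda()
question = '<image>\nDescribe this image in detail.'
print(model.chat(tokenizer, pixel_values, question, generation_config))

# Video: load_video() samples `num_segments` frames and also returns the
# per-frame patch counts, which are passed as `num_patches_list`.
pixel_values, num_patches_list = load_video('assets/example_video.mp4', num_segments=32, max_num=1)
pixel_values = pixel_values.to(torch.bfloat16).cuda()
video_prefix = ''.join(f'Frame{i + 1}: <image>\n' for i in range(len(num_patches_list)))
question = video_prefix + 'Describe this video in detail.'
print(model.chat(tokenizer, pixel_values, question, generation_config,
                 num_patches_list=num_patches_list))
```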