├── .gitignore
├── LICENSE
├── README.md
├── assets
│   ├── example_image1.jpg
│   ├── example_image2.jpg
│   ├── example_video.mp4
│   └── overview.png
├── eval
│   ├── mmbench
│   │   └── evaluate_mmbench.py
│   ├── mme
│   │   ├── README.md
│   │   ├── calculation.py
│   │   └── eval.py
│   ├── mmmu
│   │   ├── answer_dict_val.json
│   │   ├── data_utils.py
│   │   ├── eval_utils.py
│   │   ├── evaluate_mmmu.py
│   │   ├── evaluate_mmmu_cot.py
│   │   └── main_eval_only.py
│   ├── mvbench
│   │   └── evaluate_mvbench.py
│   ├── scienceqa
│   │   └── evaluate_scienceqa.py
│   ├── seed
│   │   ├── calculation.py
│   │   └── evaluate_seed.py
│   └── vqa
│       ├── convert_gqa_for_eval.py
│       ├── evaluate_vqa.py
│       ├── infographicsvqa_eval.py
│       └── textvqa_eval.py
├── evaluate.sh
├── evaluate_launch.sh
├── requirements.txt
└── utils
    └── preprocess.py
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | share/python-wheels/
24 | *.egg-info/
25 | .installed.cfg
26 | *.egg
27 | MANIFEST
28 |
29 | # PyInstaller
30 | # Usually these files are written by a python script from a template
31 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
32 | *.manifest
33 | *.spec
34 |
35 | # Installer logs
36 | pip-log.txt
37 | pip-delete-this-directory.txt
38 |
39 | # Unit test / coverage reports
40 | htmlcov/
41 | .tox/
42 | .nox/
43 | .coverage
44 | .coverage.*
45 | .cache
46 | nosetests.xml
47 | coverage.xml
48 | *.cover
49 | *.py,cover
50 | .hypothesis/
51 | .pytest_cache/
52 | cover/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | .pybuilder/
76 | target/
77 |
78 | # Jupyter Notebook
79 | .ipynb_checkpoints
80 |
81 | # IPython
82 | profile_default/
83 | ipython_config.py
84 |
85 | # pyenv
86 | # For a library or package, you might want to ignore these files since the code is
87 | # intended to run in multiple environments; otherwise, check them in:
88 | # .python-version
89 |
90 | # pipenv
91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
94 | # install all needed dependencies.
95 | #Pipfile.lock
96 |
97 | # poetry
98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
99 | # This is especially recommended for binary packages to ensure reproducibility, and is more
100 | # commonly ignored for libraries.
101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
102 | #poetry.lock
103 |
104 | # pdm
105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
106 | #pdm.lock
107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
108 | # in version control.
109 | # https://pdm.fming.dev/#use-with-ide
110 | .pdm.toml
111 |
112 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
113 | __pypackages__/
114 |
115 | # Celery stuff
116 | celerybeat-schedule
117 | celerybeat.pid
118 |
119 | # SageMath parsed files
120 | *.sage.py
121 |
122 | # Environments
123 | .env
124 | .venv
125 | env/
126 | venv/
127 | ENV/
128 | env.bak/
129 | venv.bak/
130 |
131 | # Spyder project settings
132 | .spyderproject
133 | .spyproject
134 |
135 | # Rope project settings
136 | .ropeproject
137 |
138 | # mkdocs documentation
139 | /site
140 |
141 | # mypy
142 | .mypy_cache/
143 | .dmypy.json
144 | dmypy.json
145 |
146 | # Pyre type checker
147 | .pyre/
148 |
149 | # pytype static type analyzer
150 | .pytype/
151 |
152 | # Cython debug symbols
153 | cython_debug/
154 |
155 | # PyCharm
156 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
157 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
158 | # and can be added to the global gitignore or merged into this file. For a more nuclear
159 | # option (not recommended) you can uncomment the following to ignore the entire idea folder.
160 | #.idea/
161 |
162 | .idea/
163 | .vscode/
164 |
165 | data
166 | results/
167 | eval/
168 | ckpts/
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2024 OpenGVLab
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Progressive Visual Token Compression (PVC)
2 |
3 |
4 | [arXiv](https://arxiv.org/abs/2412.09613)
5 | [🤗 HF Model](https://huggingface.co/OpenGVLab/PVC-InternVL2-8B)
6 |
7 | **[CVPR 2025]** [**PVC: Progressive Visual Token Compression for Unified Image and Video Processing in Large Vision-Language Models**](https://arxiv.org/abs/2412.09613)
8 |
9 | We introduce **Progressive Visual Token Compression (PVC)** for large vision-language models (VLMs), which unifies visual inputs as videos and progressively compresses vision tokens across video frames. PVC:
10 |
11 | * Preserves spatial details and temporal dynamics for both images and videos.
12 | * Effectively reduces the number of tokens used for each video frame and image tile.
13 | * Achieves SoTA performance on various video benchmarks, including long-video and fine-grained short-video tasks.
14 | * Incurs no performance loss on image benchmarks, especially on detail-sensitive tasks.
15 |
16 |
17 | ![PVC overview](assets/overview.png)
18 |
19 |
20 | ## 📈 Results
21 |
22 | Our implementation is based on the [InternVL2](https://github.com/OpenGVLab/InternVL) model and is referred to as **PVC-InternVL2**.
23 |
24 | ### Video Understanding Benchmarks
25 |
26 | | Model | LLaVA-OneVision-7B | Qwen2-VL-7B | InternVL2-8B | PVC-InternVL2-8B 🤗 [link](https://huggingface.co/OpenGVLab/PVC-InternVL2-8B) |
27 | | :--------------: | :--: | :--: | :--: | :--: |
28 | | \# tokens/frame | 196 | - | 256 | 64 |
29 | | | | | | |
30 | | MVBench | 56.7 | 67.0 | 66.4 | 73.8 |
31 | | Video-MME (w/o sub) | 58.2 | 63.3 | 54.0 | 64.1 |
32 | | Video-MME (w/ sub) | 61.5 | 69.0 | 56.9 | 69.7 |
33 | | MLVU | 64.7 | - | 52.0 | 72.4 |
34 | | LongVideoBench | 56.5 | - | - | 59.2 |
35 | | NExT-QA | 79.4 | - | - | 82.0 |
36 | | EgoSchema | 60.1 | 66.7 | 55.0 | 59.6 |
37 | | PerceptionTest | 57.1 | 62.3 | 52.0 | 68.4 |
38 | | ActivityNet-QA | 56.6 | - | - | 57.1 |
39 |
40 | ### Image Understanding Benchmarks
41 |
42 | | Model | LLaVA-OneVision-7B | Qwen2-VL-7B | InternVL2-8B | PVC-InternVL2-8B 🤗 [link](https://huggingface.co/OpenGVLab/PVC-InternVL2-8B) |
43 | | :--------------------: | :--: | :--: | :--: | :--: |
44 | | \# tokens/image tile | 729 | - | 256 | 64 |
45 | | | | | | |
46 | | AI2D (test) | 81.4 | 83.0 | 83.8 | 83.8 |
47 | | ChartQA (test) | 80.0 | 83.0 | 83.3 | 84.1 |
48 | | DocVQA (test) | 87.5 | 94.5 | 91.6 | 92.5 |
49 | | InfoVQA (test) | 68.8 | 76.5 | 74.8 | 75.0 |
50 | | ScienceQA (test) | 96.0 | - | 97.1 | 97.7 |
51 | | TextVQA (val) | - | 84.3 | 77.4 | 80.0 |
52 | | MMBench (EN-test) | - | 83.0 | 81.7 | 83.9 |
53 | | MME (sum) | 1998 | 2327 | 2210 | 2282 |
54 | | MMMU (val) | 48.8 | 54.1 | 49.3 | 50.9 |
55 | | SEED-Image | 75.4 | - | 76.2 | 77.2 |
56 | | OCRBench | - | 866 | 794 | 807 |
57 |
58 | ## 🛠️ Usage
59 |
60 | Set up the environment with `pip install -r requirements.txt`. `transformers>=4.37.2` is required for the model to load and run correctly.
61 |
62 | ```python
63 | import torch
64 | from transformers import AutoTokenizer, AutoModel
65 | from utils.preprocess import load_image, load_video
66 |
67 | path = 'OpenGVLab/PVC-InternVL2-8B'
68 | model = AutoModel.from_pretrained(
69 | path,
70 | torch_dtype=torch.bfloat16,
71 | low_cpu_mem_usage=True,
72 | trust_remote_code=True).eval().cuda()
73 | tokenizer = AutoTokenizer.from_pretrained(path, trust_remote_code=True, use_fast=False)
74 | generation_config = dict(max_new_tokens=1024, do_sample=True)
75 |
76 | # single-image conversation
77 | pixel_values = load_image('./assets/example_image1.jpg', max_num=12).to(torch.bfloat16).cuda()
78 | data_flag = torch.tensor([1], dtype=torch.long).cuda()
79 |
80 | question = '<image>\nWhat is in the image?'
81 | response = model.chat(tokenizer, pixel_values, question, generation_config, data_flag=data_flag)
82 | print(f'User: {question}\nAssistant: {response}')
83 |
84 | # multi-image conversation
85 | pixel_values1 = load_image('./assets/example_image1.jpg', max_num=12).to(torch.bfloat16).cuda()
86 | pixel_values2 = load_image('./assets/example_image2.jpg', max_num=12).to(torch.bfloat16).cuda()
87 | pixel_values = torch.cat((pixel_values1, pixel_values2), dim=0)
88 | data_flag = torch.tensor([2], dtype=torch.long).cuda()
89 | num_patches_list = [pixel_values1.shape[0], pixel_values2.shape[0]]
90 |
91 | question = 'Image-1: <image>\nImage-2: <image>\nWhat are the similarities and differences between these two images?'
92 | response = model.chat(tokenizer, pixel_values, question, generation_config, data_flag=data_flag, num_patches_list=num_patches_list)
93 | print(f'User: {question}\nAssistant: {response}')
94 |
95 | # video conversation
96 | pixel_values, num_patches_list = load_video('./assets/example_video.mp4', num_segments=64, max_num=1)
97 | pixel_values = pixel_values.to(torch.bfloat16).cuda()
98 | video_prefix = ''.join([f'Frame{i+1}: <image>\n' for i in range(len(num_patches_list))])
99 | # Frame1: <image>\nFrame2: <image>\n...\nFrameN: <image>\n{question}
100 | data_flag = torch.tensor([3], dtype=torch.long).cuda()
101 |
102 | question = video_prefix + 'Describe this video in detail.'
103 | response = model.chat(tokenizer, pixel_values, question, generation_config, data_flag=data_flag, num_patches_list=num_patches_list)
104 | print(f'User: {question}\nAssistant: {response}')
105 | ```
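
If the checkpoint does not fit on a single GPU, you can let Transformers shard it across the available GPUs. This is only a minimal sketch (it assumes `accelerate` is installed); the evaluation scripts in `eval/` expose the same mechanism through their `--auto` flag, which sets `device_map='auto'`:

```python
import torch
from transformers import AutoModel, AutoTokenizer

path = 'OpenGVLab/PVC-InternVL2-8B'
# device_map='auto' lets accelerate spread the layers over all visible GPUs
model = AutoModel.from_pretrained(
    path,
    torch_dtype=torch.bfloat16,
    low_cpu_mem_usage=True,
    trust_remote_code=True,
    device_map='auto').eval()
tokenizer = AutoTokenizer.from_pretrained(path, trust_remote_code=True, use_fast=False)
```

The rest of the usage example above stays the same; just make sure `pixel_values` is moved to the device that holds the model's input layers (with a sharded model this is usually `cuda:0`).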
106 |
107 | ## 📊 Evaluation
108 |
109 | ### Image Benchmarks & MVBench
110 |
111 | **Prepare data:** please follow [here](https://internvl.readthedocs.io/en/latest/get_started/eval_data_preparation.html) to prepare the data for evaluation.
112 |
113 | **Run evaluation:** use the following command to start the evaluation:
114 |
115 | ```bash
116 | bash evaluate_launch.sh
117 | ```
118 |
119 | Currently supported tasks: `vqa-ai2d-test`, `vqa-chartqa-test`, `vqa-docvqa-val`, `vqa-docvqa-test`, `vqa-infovqa-val`, `vqa-infovqa-test`, `scienceqa`, `mme`, `mmbench-dev-en`, `mmbench-test-en`, `mmmu-val`, `seed`, `mvbench`.
120 |
121 | For image benchmarks and MVBench, we use the evaluation codebase of InternVL2. Refer to [here](https://internvl.readthedocs.io/en/latest/internvl2.0/evaluation.html#) for more details.
122 |
123 | ## 📅 TODO List
124 |
125 | * [X] release model and checkpoint
126 | * [ ] release evaluation code
127 | * [ ] release training code
128 |
129 | ## 🖊️ Citation
130 |
131 | If you find this work helpful in your research, please consider citing:
132 |
133 | ```bibtex
134 | @article{yang2024pvc,
135 | title={PVC: Progressive Visual Token Compression for Unified Image and Video Processing in Large Vision-Language Models},
136 | author={Yang, Chenyu and Dong, Xuan and Zhu, Xizhou and Su, Weijie and Wang, Jiahao and Tian, Hao and Chen, Zhe and Wang, Wenhai and Lu, Lewei and Dai, Jifeng},
137 | journal={arXiv preprint arXiv:2412.09613},
138 | year={2024}
139 | }
140 | ```
141 |
142 | ## 📃 License
143 |
144 | This project is released under the [MIT license](LICENSE). Parts of this project contain code and models from other sources, which are subject to their respective licenses.
145 |
--------------------------------------------------------------------------------
/assets/example_image1.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/OpenGVLab/PVC/b400817bb19f171175408fe458c951968001d491/assets/example_image1.jpg
--------------------------------------------------------------------------------
/assets/example_image2.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/OpenGVLab/PVC/b400817bb19f171175408fe458c951968001d491/assets/example_image2.jpg
--------------------------------------------------------------------------------
/assets/example_video.mp4:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/OpenGVLab/PVC/b400817bb19f171175408fe458c951968001d491/assets/example_video.mp4
--------------------------------------------------------------------------------
/assets/overview.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/OpenGVLab/PVC/b400817bb19f171175408fe458c951968001d491/assets/overview.png
--------------------------------------------------------------------------------
/eval/mmbench/evaluate_mmbench.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import base64
3 | import itertools
4 | import json
5 | import os
6 | import random
7 | import time
8 | from functools import partial
9 | from io import BytesIO
10 |
11 | import pandas as pd
12 | import torch
13 | from utils.preprocess import build_transform, dynamic_preprocess
14 | from PIL import Image
15 | from torch.utils.data import Dataset
16 | from tqdm import tqdm
17 | from transformers import AutoTokenizer, AutoModel
18 |
19 | ds_collections = {
20 | 'mmbench_dev_20230712': {
21 | 'root': 'data/mmbench/mmbench_dev_20230712.tsv',
22 | 'max_new_tokens': 100,
23 | 'min_new_tokens': 1,
24 | 'type': 'dev',
25 | 'language': 'en'
26 | },
27 | 'mmbench_dev_cn_20231003': {
28 | 'root': 'data/mmbench/mmbench_dev_cn_20231003.tsv',
29 | 'max_new_tokens': 100,
30 | 'min_new_tokens': 1,
31 | 'type': 'dev',
32 | 'language': 'cn'
33 | },
34 | 'mmbench_dev_en_20231003': {
35 | 'root': 'data/mmbench/mmbench_dev_en_20231003.tsv',
36 | 'max_new_tokens': 100,
37 | 'min_new_tokens': 1,
38 | 'type': 'dev',
39 | 'language': 'en'
40 | },
41 | 'mmbench_test_cn_20231003': {
42 | 'root': 'data/mmbench/mmbench_test_cn_20231003.tsv',
43 | 'max_new_tokens': 100,
44 | 'min_new_tokens': 1,
45 | 'type': 'test',
46 | 'language': 'cn'
47 | },
48 | 'mmbench_test_en_20231003': {
49 | 'root': 'data/mmbench/mmbench_test_en_20231003.tsv',
50 | 'max_new_tokens': 100,
51 | 'min_new_tokens': 1,
52 | 'type': 'test',
53 | 'language': 'en'
54 | },
55 | 'ccbench_dev_cn': {
56 | 'root': 'data/mmbench/CCBench_legacy.tsv',
57 | 'max_new_tokens': 100,
58 | 'min_new_tokens': 1,
59 | 'type': 'dev',
60 | 'language': 'cn'
61 | }
62 | }
63 |
64 |
65 | def collate_fn(batches, tokenizer):
66 | pixel_values = torch.cat([_['pixel_values'] for _ in batches], dim=0)
67 | split_sizes = torch.cat([_['split_sizes'] for _ in batches])
68 | questions = [_['question'] for _ in batches]
69 | answers = [_['answer'] for _ in batches]
70 | indexes = [_['index'] for _ in batches]
71 | options = [_['option'] for _ in batches]
72 | return pixel_values, split_sizes, questions, answers, indexes, options
73 |
74 |
75 | class MMBenchDataset(torch.utils.data.Dataset):
76 |
77 | def __init__(self, root, prompt, language, input_size=224, dynamic_image_size=False,
78 | use_thumbnail=False, max_num=6):
79 | self.df = pd.read_csv(root, sep='\t')
80 | self.prompt = prompt
81 | self.language = language
82 | self.input_size = input_size
83 | self.dynamic_image_size = dynamic_image_size
84 | self.use_thumbnail = use_thumbnail
85 | self.max_num = max_num
86 | self.transform = build_transform(input_size=input_size)
87 |
88 | def __len__(self):
89 | return len(self.df)
90 |
91 | def __getitem__(self, idx):
92 | index = self.df.iloc[idx]['index']
93 | image = self.df.iloc[idx]['image']
94 | question = self.df.iloc[idx]['question']
95 | answer = self.df.iloc[idx]['answer'] if 'answer' in self.df.iloc[0].keys() else None
96 | # category = self.df.iloc[idx]['category']
97 | # l2_category = self.df.iloc[idx]['l2-category']
98 |
99 | image = Image.open(BytesIO(base64.b64decode(image))).convert('RGB')
100 | if self.dynamic_image_size:
101 | images = dynamic_preprocess(image, image_size=self.input_size,
102 | use_thumbnail=self.use_thumbnail,
103 | max_num=self.max_num)
104 | else:
105 | images = [image]
106 | pixel_values = [self.transform(image) for image in images]
107 | pixel_values = torch.stack(pixel_values)
108 |
109 | option_candidate = ['A', 'B', 'C', 'D', 'E']
110 | options = {
111 | cand: self.load_from_df(idx, cand)
112 | for cand in option_candidate
113 | if self.load_from_df(idx, cand) is not None
114 | }
115 |
116 | hint = self.load_from_df(idx, 'hint')
117 | if hint is not None:
118 | question = hint + '\n' + question
119 | for key, item in options.items():
120 | question += f'\n{key}. {item}'
121 | if self.language == 'cn':
122 | question = question + '\n' + self.prompt['cn']
123 | else:
124 | question = question + '\n' + self.prompt['en']
125 |
126 | return {
127 | 'question': question,
128 | 'pixel_values': pixel_values,
129 | 'split_sizes': torch.LongTensor((pixel_values.shape[0], )),
130 | 'answer': answer,
131 | 'index': index,
132 | 'option': options
133 | }
134 |
135 | def load_from_df(self, idx, key):
136 | if key in self.df.iloc[idx] and not pd.isna(self.df.iloc[idx][key]):
137 | return self.df.iloc[idx][key]
138 | else:
139 | return None
140 |
141 |
142 | class InferenceSampler(torch.utils.data.sampler.Sampler):
143 |
144 | def __init__(self, size):
145 | self._size = int(size)
146 | assert size > 0
147 | self._rank = torch.distributed.get_rank()
148 | self._world_size = torch.distributed.get_world_size()
149 | self._local_indices = self._get_local_indices(size, self._world_size, self._rank)
150 |
151 | @staticmethod
152 | def _get_local_indices(total_size, world_size, rank):
153 | shard_size = total_size // world_size
154 | left = total_size % world_size
155 | shard_sizes = [shard_size + int(r < left) for r in range(world_size)]
156 |
157 | begin = sum(shard_sizes[:rank])
158 | end = min(sum(shard_sizes[:rank + 1]), total_size)
159 | return range(begin, end)
160 |
161 | def __iter__(self):
162 | yield from self._local_indices
163 |
164 | def __len__(self):
165 | return len(self._local_indices)
166 |
167 |
168 | def post_process(pred, option):
169 | pred = pred.strip()
170 | option_candidate = list(option.keys())
171 | if len(pred) == 1:
172 | return pred
173 | elif len(pred) != 1 and pred[0] in option_candidate:
174 | return pred[0]
175 | elif len(pred) != 1 and pred[0] not in option_candidate:
176 | for k, v in option.items():
177 | if v in pred:
178 | return k
179 |
180 | return pred
181 |
182 |
183 | def evaluate_chat_model():
184 | random.seed(args.seed)
185 |
186 | for ds_name in args.datasets:
187 | dataset = MMBenchDataset(
188 | root=ds_collections[ds_name]['root'],
189 | prompt=prompt,
190 | language=ds_collections[ds_name]['language'],
191 | input_size=image_size,
192 | dynamic_image_size=args.dynamic,
193 | use_thumbnail=use_thumbnail,
194 | max_num=args.max_num
195 | )
196 | dataloader = torch.utils.data.DataLoader(
197 | dataset=dataset,
198 | sampler=InferenceSampler(len(dataset)),
199 | batch_size=args.batch_size,
200 | num_workers=args.num_workers,
201 | pin_memory=True,
202 | drop_last=False,
203 | collate_fn=partial(collate_fn, tokenizer=tokenizer),
204 | )
205 |
206 | outputs = []
207 | for _, (pixel_values, split_sizes, questions, answers, indexes, options) in tqdm(enumerate(dataloader)):
208 | pixel_values = pixel_values.to(torch.bfloat16).cuda()
209 | generation_config = dict(
210 | num_beams=args.num_beams,
211 | max_new_tokens=ds_collections[ds_name]['max_new_tokens'],
212 | min_new_tokens=ds_collections[ds_name]['min_new_tokens'],
213 | do_sample=True if args.temperature > 0 else False,
214 | temperature=args.temperature,
215 | )
216 | pred = model.chat(
217 | tokenizer=tokenizer,
218 | pixel_values=pixel_values,
219 | split_sizes=split_sizes,
220 | question=questions[0],
221 | generation_config=generation_config
222 | )
223 | preds = [post_process(pred, options[0])]
224 |
225 | for question, pred, answer, index in zip(questions, preds, answers, indexes):
226 | outputs.append({
227 | 'question': question,
228 | 'answer': pred,
229 | 'gt_answers': answer,
230 | 'index': int(index)
231 | })
232 |
233 | torch.distributed.barrier()
234 |
235 | world_size = torch.distributed.get_world_size()
236 | merged_outputs = [None for _ in range(world_size)]
237 | torch.distributed.all_gather_object(merged_outputs, json.dumps(outputs))
238 |
239 | merged_outputs = [json.loads(_) for _ in merged_outputs]
240 | merged_outputs = [_ for _ in itertools.chain.from_iterable(merged_outputs)]
241 |
242 | if torch.distributed.get_rank() == 0:
243 |
244 | print(f'Evaluating {ds_name} ...')
245 | time_prefix = time.strftime('%y%m%d%H%M%S', time.localtime())
246 | results_file = f'{ds_name}_{time_prefix}.xlsx'
247 | output_path = os.path.join(args.out_dir, results_file)
248 | df = pd.read_table(ds_collections[ds_name]['root'])
249 | cur_df = df.copy()
250 | if 'mmbench' in ds_name:
251 | cur_df = cur_df.drop(columns=['hint', 'category', 'source', 'image', 'comment', 'l2-category'])
252 | cur_df.insert(6, 'prediction', None)
253 | else:
254 | cur_df = cur_df.drop(columns=['category', 'image'])
255 | cur_df.insert(8, 'prediction', None)
256 | for item in merged_outputs:
257 | cur_df.loc[df['index'] == item['index'], 'prediction'] = item['answer']
258 |
259 | cur_df.to_excel(output_path, index=False, engine='openpyxl')
260 | print('Results saved to {}'.format(output_path))
261 |
262 |
263 | if __name__ == '__main__':
264 | parser = argparse.ArgumentParser()
265 | parser.add_argument('--checkpoint', type=str, default='')
266 | parser.add_argument('--datasets', type=str, default='mmbench_dev_20230712')
267 | parser.add_argument('--batch-size', type=int, default=1)
268 | parser.add_argument('--num-workers', type=int, default=1)
269 | parser.add_argument('--num-beams', type=int, default=5)
270 | parser.add_argument('--temperature', type=float, default=0.0)
271 | parser.add_argument('--out-dir', type=str, default='results')
272 | parser.add_argument('--seed', type=int, default=0)
273 | parser.add_argument('--dynamic', action='store_true')
274 | parser.add_argument('--max-num', type=int, default=6)
275 | parser.add_argument('--load-in-8bit', action='store_true')
276 | parser.add_argument('--auto', action='store_true')
277 | args = parser.parse_args()
278 |
279 | if not os.path.exists(args.out_dir):
280 | os.makedirs(args.out_dir, exist_ok=True)
281 |
282 | args.datasets = args.datasets.split(',')
283 | print('datasets:', args.datasets)
284 | assert args.batch_size == 1, 'Only batch size 1 is supported'
285 |
286 | torch.distributed.init_process_group(
287 | backend='nccl',
288 | world_size=int(os.getenv('WORLD_SIZE', '1')),
289 | rank=int(os.getenv('RANK', '0')),
290 | )
291 |
292 | torch.cuda.set_device(int(os.getenv('LOCAL_RANK', 0)))
293 |
294 | if args.auto:
295 | os.environ['CUDA_LAUNCH_BLOCKING'] = '1'
296 | kwargs = {'device_map': 'auto'} if args.auto else {}
297 | tokenizer = AutoTokenizer.from_pretrained(args.checkpoint, trust_remote_code=True, use_fast=False)
298 | model = AutoModel.from_pretrained(
299 | args.checkpoint, trust_remote_code=True, low_cpu_mem_usage=True, torch_dtype=torch.bfloat16,
300 | load_in_8bit=args.load_in_8bit, **kwargs).eval()
301 | if not args.load_in_8bit and not args.auto:
302 | model = model.cuda()
303 | image_size = model.config.force_image_size or model.config.vision_config.image_size
304 | use_thumbnail = model.config.use_thumbnail
305 |
306 | total_params = sum(p.numel() for p in model.parameters()) / 1e9
307 | if total_params > 20 or args.dynamic:
308 | args.num_beams = 1
309 | print(f'[test] total_params: {total_params}B, use num_beams: {args.num_beams}')
310 | else:
311 | print(f'[test] total_params: {total_params}B')
312 | print(f'[test] image_size: {image_size}')
313 | print(f'[test] template: {model.config.template}')
314 | print(f'[test] dynamic_image_size: {args.dynamic}')
315 | print(f'[test] use_thumbnail: {use_thumbnail}')
316 | print(f'[test] max_num: {args.max_num}')
317 |
318 | prompt = {
319 | 'en': "Answer with the option's letter from the given choices directly.",
320 | 'cn': '请直接回答选项字母。'
321 | }
322 | evaluate_chat_model()
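
# Illustrative launch (appended note): this script reads WORLD_SIZE / RANK / LOCAL_RANK from the
# environment, so a torchrun-style launcher is assumed; the checkpoint path is a placeholder.
#   torchrun --nproc_per_node=8 eval/mmbench/evaluate_mmbench.py \
#       --checkpoint <path/to/PVC-InternVL2-8B> --datasets mmbench_dev_en_20231003 --dynamic --out-dir results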
323 |
--------------------------------------------------------------------------------
/eval/mme/README.md:
--------------------------------------------------------------------------------
1 | # This is an automated calculation script for the acc, acc+, and score metrics.
2 |
3 | You can directly run `python3 calculation.py` to get the evaluation results of LaVIN.
4 |
5 | To get the statistical results of your own model:
6 |
7 | (1) Fill in all the files in "Your_Results" by adding your model's responses.
8 |     Each file in "Your_Results" consists of lines of the form:
9 |     Image_Name + "\t" + Question + "\t" + Ground_Truth_Answer + "\n"
10 |
11 |     You need to extend each line with your model's response:
12 |     Image_Name + "\t" + Question + "\t" + Ground_Truth_Answer + "\t" + Your_Response + "\n"
13 |
14 |     Note: if your response contains "\n", please delete it. Each response must stay on a single line, not across lines! (A minimal Python sketch of this step is appended below.)
15 |
16 | (2) Run `python3 calculation.py --results_dir ./Your_Results`
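
A minimal sketch of step (1) in Python. Everything here is illustrative: `Your_Results`, `Your_Results_with_responses`, and `my_model_answer` (whatever function runs your model on one image/question pair) are placeholders.

```python
import os

def add_responses(src_dir, dst_dir, my_model_answer):
    """Append the model response as a fourth tab-separated field to every line."""
    os.makedirs(dst_dir, exist_ok=True)
    for name in os.listdir(src_dir):
        with open(os.path.join(src_dir, name), 'r', encoding='utf-8') as fin, \
             open(os.path.join(dst_dir, name), 'w', encoding='utf-8') as fout:
            for line in fin:
                image_name, question, gt = line.rstrip('\n').split('\t')
                # responses must not contain newlines: keep each record on one line
                response = my_model_answer(image_name, question).replace('\n', ' ')
                fout.write('\t'.join([image_name, question, gt, response]) + '\n')

# add_responses('./Your_Results', './Your_Results_with_responses', my_model_answer)
```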
17 |
--------------------------------------------------------------------------------
/eval/mme/calculation.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 |
4 | from sklearn.metrics import (accuracy_score, confusion_matrix, precision_score,
5 | recall_score)
6 |
7 | parser = argparse.ArgumentParser()
8 | parser.add_argument('--results_dir', default='./LaVIN', type=str)
9 |
10 | eval_type_dict = {
11 | 'Perception': ['existence', 'count', 'position', 'color', 'posters', 'celebrity', 'scene', 'landmark', 'artwork', 'OCR'],
12 | 'Cognition': ['commonsense_reasoning', 'numerical_calculation', 'text_translation', 'code_reasoning']
13 | }
14 |
15 |
16 | class calculate_metrics:
17 | def divide_chunks(self, l, n=2):
18 | # looping till length l
19 | for i in range(0, len(l), n):
20 | yield l[i:i + n]
21 |
22 | return
23 |
24 | def parse_pred_ans(self, pred_ans):
25 | pred_label = None
26 | if pred_ans in ['yes', 'no']:
27 | pred_label = pred_ans
28 | else:
29 | prefix_pred_ans = pred_ans[:4]
30 |
31 | if 'yes' in prefix_pred_ans:
32 | pred_label = 'yes'
33 | elif 'no' in prefix_pred_ans:
34 | pred_label = 'no'
35 | else:
36 | pred_label = 'other'
37 |
38 | return pred_label
39 |
40 | def compute_metric(self, gts, preds):
41 | assert len(gts) == len(preds)
42 |
43 | label_map = {
44 | 'yes': 1,
45 | 'no': 0,
46 | 'other': -1,
47 | }
48 |
49 | gts = [label_map[x] for x in gts]
50 | preds = [label_map[x] for x in preds]
51 |
52 | acc = accuracy_score(gts, preds)
53 |
54 | clean_gts = []
55 | clean_preds = []
56 | other_num = 0
57 | for gt, pred in zip(gts, preds):
58 | if pred == -1:
59 | other_num += 1
60 | continue
61 | clean_gts.append(gt)
62 | clean_preds.append(pred)
63 |
64 | conf_mat = confusion_matrix(clean_gts, clean_preds, labels=[1,0])
65 | precision = precision_score(clean_gts, clean_preds, average='binary')
66 | recall = recall_score(clean_gts, clean_preds, average='binary')
67 | tp, fn = conf_mat[0]
68 | fp, tn = conf_mat[1]
69 |
70 | metric_dict = dict()
71 | metric_dict = {
72 | 'TP': tp,
73 | 'FN': fn,
74 | 'TN': tn,
75 | 'FP': fp,
76 | 'precision': precision,
77 | 'recall': recall,
78 | 'other_num': other_num,
79 | 'acc': acc,
80 | }
81 |
82 | return metric_dict
83 |
84 | def process_result(self, results_dir):
85 |
86 | model_score_dict = dict()
87 | for eval_type, task_name_list in eval_type_dict.items():
88 | print('===========', eval_type, '===========')
89 |
90 | scores = 0
91 | task_score_dict = dict()
92 |
93 | for task_name in task_name_list:
94 |
95 | task_txt = os.path.join(results_dir, task_name + '.txt')
96 | lines = open(task_txt, 'r').readlines()
97 | chunk_lines = list(self.divide_chunks(lines)) # one image corresponds to two questions
98 |
99 | img_num = len(chunk_lines)
100 | task_other_ans_num = 0
101 | task_score = 0
102 | acc_plus_correct_num = 0
103 | gts = []
104 | preds = []
105 |
106 | for img_items in chunk_lines:
107 | assert len(img_items) == 2
108 | img_correct_num = 0
109 |
110 | for img_item in img_items:
111 | try:
112 | img_name, question, gt_ans, pred_ans = img_item.split('\t')
113 | except:
114 | print(img_item)
115 | continue
116 | gt_ans = gt_ans.lower()
117 | pred_ans = pred_ans.lower()
118 |
119 | assert gt_ans in ['yes', 'no'] # gt can only be yes or no.
120 |
121 | pred_ans = self.parse_pred_ans(pred_ans)
122 | assert pred_ans in ['yes', 'no', 'other']
123 |
124 | gts.append(gt_ans)
125 | preds.append(pred_ans)
126 |
127 | if gt_ans == pred_ans:
128 | img_correct_num += 1
129 |
130 | if pred_ans not in ['yes', 'no']:
131 | task_other_ans_num += 1
132 |
133 | if img_correct_num == 2:
134 | acc_plus_correct_num += 1
135 |
136 | # cal TP precision acc, etc.
137 | metric_dict = self.compute_metric(gts, preds)
138 | acc_plus = acc_plus_correct_num / img_num
139 | metric_dict['acc_plus'] = acc_plus
140 |
141 | for k, v in metric_dict.items():
142 | if k in ['acc', 'acc_plus']:
143 | task_score += v*100
144 |
145 | task_score_dict[task_name] = task_score
146 |
147 | scores += task_score
148 |
149 | print('total score:', scores, '\n')
150 | for task_name, score in task_score_dict.items():
151 | print('\t', task_name, ' score:', score)
152 | print('\n')
153 |
154 | return
155 |
156 |
157 | if __name__ == '__main__':
158 | cal = calculate_metrics()
159 |
160 | args = parser.parse_args()
161 | results_dir = args.results_dir
162 | cal.process_result(results_dir)
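
# Note (appended for clarity): every MME image comes with two yes/no questions. For each task,
# task_score = 100 * acc + 100 * acc_plus, where acc is per-question accuracy and acc_plus is
# the fraction of images with BOTH questions answered correctly; the printed "total score"
# sums the task scores within an eval type (Perception or Cognition).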
163 |
--------------------------------------------------------------------------------
/eval/mme/eval.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | import re
4 |
5 | import torch
6 | from utils.preprocess import build_transform, dynamic_preprocess
7 | from PIL import Image
8 | from tqdm import tqdm
9 | from transformers import AutoTokenizer, AutoModel
10 |
11 |
12 | def load_image(image_file, input_size=224):
13 | image = Image.open(image_file).convert('RGB')
14 | transform = build_transform(input_size=input_size)
15 | if args.dynamic:
16 | images = dynamic_preprocess(image, image_size=input_size,
17 | use_thumbnail=use_thumbnail,
18 | max_num=args.max_num)
19 | else:
20 | images = [image]
21 | pixel_values = [transform(image) for image in images]
22 | pixel_values = torch.stack(pixel_values)
23 | return pixel_values
24 |
25 |
26 | def post_processing(response):
27 | response = response.replace('\n', '').replace('不是', 'No').replace('是', 'Yes').replace('否', 'No')
28 | response = response.lower().replace('true', 'yes').replace('false', 'no')
29 | pattern = re.compile(r'[\u4e00-\u9fa5]')
30 | response = re.sub(pattern, '', response)
31 | return response
32 |
33 |
34 | if __name__ == '__main__':
35 | parser = argparse.ArgumentParser()
36 | parser.add_argument('--checkpoint', type=str, default='')
37 | parser.add_argument('--root', type=str, default='./Your_Results')
38 | parser.add_argument('--num-beams', type=int, default=5)
39 | parser.add_argument('--top-k', type=int, default=50)
40 | parser.add_argument('--top-p', type=float, default=0.9)
41 | parser.add_argument('--sample', type=bool, default=False)
42 | parser.add_argument('--dynamic', action='store_true')
43 | parser.add_argument('--max-num', type=int, default=6)
44 | parser.add_argument('--load-in-8bit', action='store_true')
45 | parser.add_argument('--auto', action='store_true')
46 | args = parser.parse_args()
47 |
48 | if args.auto:
49 | os.environ['CUDA_LAUNCH_BLOCKING'] = '1'
50 | kwargs = {'device_map': 'auto'} if args.auto else {}
51 | prompt = 'Answer the question using a single word or phrase.'
52 | tokenizer = AutoTokenizer.from_pretrained(args.checkpoint, trust_remote_code=True, use_fast=False)
53 | model = AutoModel.from_pretrained(
54 | args.checkpoint, trust_remote_code=True, low_cpu_mem_usage=True, torch_dtype=torch.bfloat16,
55 | load_in_8bit=args.load_in_8bit, **kwargs).eval()
56 | if not args.load_in_8bit and not args.auto:
57 | model = model.cuda()
58 | image_size = model.config.force_image_size or model.config.vision_config.image_size
59 | use_thumbnail = model.config.use_thumbnail
60 |
61 | total_params = sum(p.numel() for p in model.parameters()) / 1e9
62 | if total_params > 20 or args.dynamic:
63 | args.num_beams = 1
64 | print(f'[test] total_params: {total_params}B, use num_beams: {args.num_beams}')
65 | else:
66 | print(f'[test] total_params: {total_params}B')
67 | print(f'[test] image_size: {image_size}')
68 | print(f'[test] template: {model.config.template}')
69 | print(f'[test] dynamic_image_size: {args.dynamic}')
70 | print(f'[test] use_thumbnail: {use_thumbnail}')
71 | print(f'[test] max_num: {args.max_num}')
72 |
73 | output = os.path.join(args.checkpoint, 'mme')
74 | os.makedirs(output, exist_ok=True)
75 |
76 | for filename in os.listdir(args.root):
77 | fin = open(os.path.join(args.root, filename), 'r', encoding='utf-8')
78 | fout = open(os.path.join(output, filename), 'w', encoding='utf-8')
79 | lines = fin.readlines()
80 | filename = filename.replace('.txt', '')
81 | for line in tqdm(lines):
82 | img, question, gt = line.strip().split('\t')
83 | question = question + ' ' + prompt
84 | img_path = os.path.join('../../data/mme/MME_Benchmark_release_version', filename, img)
85 | assert os.path.exists(img_path), img_path
86 | pixel_values = load_image(img_path, image_size).cuda().to(torch.bfloat16)
87 | generation_config = dict(
88 | do_sample=args.sample,
89 | top_k=args.top_k,
90 | top_p=args.top_p,
91 | num_beams=args.num_beams,
92 | max_new_tokens=20,
93 | eos_token_id=tokenizer.eos_token_id,
94 | )
95 | response = model.chat(
96 | tokenizer=tokenizer,
97 | pixel_values=pixel_values,
98 | split_sizes=[pixel_values.shape[0]],
99 | question=question,
100 | generation_config=generation_config
101 | )
102 | response = post_processing(response)
103 | print(img, question, gt, response, sep='\t', file=fout)
104 | fin.close()
105 | fout.close()
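
# Note (appended): each output file written above contains one line per question in the form
#   Image_Name \t Question \t Ground_Truth_Answer \t Model_Response
# which is the format expected by eval/mme/calculation.py (see eval/mme/README.md).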
106 |
--------------------------------------------------------------------------------
/eval/mmmu/data_utils.py:
--------------------------------------------------------------------------------
1 | """Utils for data load, save, and process (e.g., prompt construction)"""
2 |
3 | import json
4 | import os
5 | import re
6 |
7 | import yaml
8 |
9 | DOMAIN_CAT2SUB_CAT = {
10 | 'Art and Design': ['Art', 'Art_Theory', 'Design', 'Music'],
11 | 'Business': ['Accounting', 'Economics', 'Finance', 'Manage', 'Marketing'],
12 | 'Science': ['Biology', 'Chemistry', 'Geography', 'Math', 'Physics', ],
13 | 'Health and Medicine': ['Basic_Medical_Science', 'Clinical_Medicine', 'Diagnostics_and_Laboratory_Medicine',
14 | 'Pharmacy', 'Public_Health'],
15 | 'Humanities and Social Science': ['History', 'Literature', 'Sociology', 'Psychology'],
16 | 'Tech and Engineering': ['Agriculture', 'Architecture_and_Engineering', 'Computer_Science', 'Electronics',
17 | 'Energy_and_Power', 'Materials', 'Mechanical_Engineering'],
18 | }
19 |
20 | CAT_SHORT2LONG = {
21 | 'acc': 'Accounting',
22 | 'agri': 'Agriculture',
23 | 'arch': 'Architecture_and_Engineering',
24 | 'art': 'Art',
25 | 'art_theory': 'Art_Theory',
26 | 'bas_med': 'Basic_Medical_Science',
27 | 'bio': 'Biology',
28 | 'chem': 'Chemistry',
29 | 'cli_med': 'Clinical_Medicine',
30 | 'cs': 'Computer_Science',
31 | 'design': 'Design',
32 | 'diag_med': 'Diagnostics_and_Laboratory_Medicine',
33 | 'econ': 'Economics',
34 | 'elec': 'Electronics',
35 | 'ep': 'Energy_and_Power',
36 | 'fin': 'Finance',
37 | 'geo': 'Geography',
38 | 'his': 'History',
39 | 'liter': 'Literature',
40 | 'manage': 'Manage',
41 | 'mark': 'Marketing',
42 | 'mate': 'Materials',
43 | 'math': 'Math',
44 | 'mech': 'Mechanical_Engineering',
45 | 'music': 'Music',
46 | 'phar': 'Pharmacy',
47 | 'phys': 'Physics',
48 | 'psy': 'Psychology',
49 | 'pub_health': 'Public_Health',
50 | 'socio': 'Sociology'
51 | }
52 |
53 |
54 | # DATA SAVING
55 | def save_json(filename, ds):
56 | with open(filename, 'w') as f:
57 | json.dump(ds, f, indent=4)
58 |
59 |
60 | def get_multi_choice_info(options):
61 | """
62 | Given the list of options for multiple choice question
63 | Return the index2ans and all_choices
64 | """
65 |
66 | start_chr = 'A'
67 | all_choices = []
68 | index2ans = {}
69 | for i, option in enumerate(options):
70 | index2ans[chr(ord(start_chr) + i)] = option
71 | all_choices.append(chr(ord(start_chr) + i))
72 |
73 | return index2ans, all_choices
74 |
75 |
76 | def load_yaml(file_path):
77 | with open(file_path, 'r') as stream:
78 | try:
79 | yaml_dict = yaml.safe_load(stream)
80 | except yaml.YAMLError as exc:
81 | print(exc)
82 |
83 | return yaml_dict
84 |
85 |
86 | def parse_img_path(text):
87 | matches = re.findall("<img='(.*?)'>", text)
88 | return matches
89 |
90 |
91 | def process_single_sample(data):
92 | question = data['question']
93 | o_imgs_paths = []
94 | for option in data['options']:
95 | current_o_imgs_paths = parse_img_path(option)
96 | for img_path in current_o_imgs_paths:
97 | o_imgs_paths.append(img_path)
98 | images = [data['image_1'], data['image_2'], data['image_3'], data['image_4'],
99 | data['image_5'], data['image_6'], data['image_7']]
100 | return {'id': data['id'], 'question': question, 'options': data['options'], 'answer': data['answer'],
101 | 'image': images, 'question_type': data['question_type']}
102 |
103 |
104 | # DATA SAVING
105 | def save_json(filename, ds):
106 | with open(filename, 'w') as f:
107 | json.dump(ds, f, indent=4)
108 |
109 |
110 | def save_jsonl(filename, data):
111 | """
112 | Save a dictionary of data to a JSON Lines file with the filename as key and caption as value.
113 |
114 | Args:
115 | filename (str): The path to the file where the data should be saved.
116 | data (dict): The dictionary containing the data to save where key is the image path and value is the caption.
117 | """
118 | with open(filename, 'w', encoding='utf-8') as f:
119 | for img_path, caption in data.items():
120 | # Extract the base filename without the extension
121 | base_filename = os.path.basename(img_path)
122 | # Create a JSON object with the filename as the key and caption as the value
123 | json_record = json.dumps({base_filename: caption}, ensure_ascii=False)
124 | # Write the JSON object to the file, one per line
125 | f.write(json_record + '\n')
126 |
127 |
128 | def save_args(args, path_dir):
129 | argsDict = args.__dict__
130 | with open(path_dir + 'setting.txt', 'w') as f:
131 | f.writelines('------------------ start ------------------' + '\n')
132 | for eachArg, value in argsDict.items():
133 | f.writelines(eachArg + ' : ' + str(value) + '\n')
134 | f.writelines('------------------- end -------------------')
135 |
136 |
137 | # DATA PROCESSING
138 | def construct_prompt(sample, config):
139 | question = sample['question']
140 | options = eval(sample['options'])
141 | example = ''
142 | if sample['question_type'] == 'multiple-choice':
143 | start_chr = 'A'
144 | prediction_range = []
145 | index2ans = {}
146 | for option in options:
147 | prediction_range.append(start_chr)
148 | example += f'({start_chr}) {option}\n'
149 | index2ans[start_chr] = option
150 | start_chr = chr(ord(start_chr) + 1)
151 | empty_prompt_sample_structure = config['multi_choice_example_format']
152 | empty_prompt = empty_prompt_sample_structure.format(question, example)
153 | res_dict = {}
154 | res_dict['index2ans'] = index2ans
155 | res_dict['correct_choice'] = sample['answer']
156 | res_dict['all_choices'] = prediction_range
157 | res_dict['empty_prompt'] = empty_prompt
158 | if config['task_instructions']:
159 | res_dict['final_input_prompt'] = config['task_instructions'].strip() + '\n\n' + empty_prompt
160 | else:
161 | res_dict['final_input_prompt'] = empty_prompt
162 |
163 | res_dict['gt_content'] = options[ord(sample['answer'].upper()) - ord('A')]
164 | else:
165 | empty_prompt_sample_structure = config['short_ans_example_format']
166 | empty_prompt = empty_prompt_sample_structure.format(question)
167 | res_dict = {}
168 | res_dict['empty_prompt'] = empty_prompt
169 | if config['task_instructions']:
170 | res_dict['final_input_prompt'] = config['task_instructions'].strip() + '\n\n' + empty_prompt
171 | else:
172 | res_dict['final_input_prompt'] = empty_prompt
173 | res_dict['gt_content'] = sample['answer']
174 |
175 | res_dict.update(sample)
176 | return res_dict
177 |
--------------------------------------------------------------------------------
/eval/mmmu/eval_utils.py:
--------------------------------------------------------------------------------
1 | """Response Parsing and Evaluation for various models"""
2 | import random
3 | import re
4 | from typing import Dict
5 |
6 | random.seed(42)
7 | import numpy as np
8 |
9 |
10 | # ----------- Process Multi-choice -------------
11 | def parse_multi_choice_response(response, all_choices, index2ans):
12 | """
13 | Parse the prediction from the generated response.
14 | Return the predicted index e.g., A, B, C, D.
15 | """
16 | for char in [',', '.', '!', '?', ';', ':', "'"]:
17 | response = response.strip(char)
18 | response = ' ' + response + ' ' # add space to avoid partial match
19 |
20 | index_ans = True
21 | ans_with_brack = False
22 | candidates = []
23 | for choice in all_choices: # e.g., (A) (B) (C) (D)
24 | if f'({choice})' in response:
25 | candidates.append(choice)
26 | ans_with_brack = True
27 |
28 | if len(candidates) == 0:
29 | for choice in all_choices: # e.g., A B C D
30 | if f' {choice} ' in response:
31 | candidates.append(choice)
32 |
33 | # if all above doesn't get candidates, check if the content is larger than 5 tokens and try to parse the example
34 | if len(candidates) == 0 and len(response.split()) > 5:
35 | for index, ans in index2ans.items():
36 | if ans.lower() in response.lower():
37 | candidates.append(index)
38 | index_ans = False # it's content ans.
39 |
40 | if len(candidates) == 0: # still not get answer, randomly choose one.
41 | pred_index = random.choice(all_choices)
42 | elif len(candidates) > 1:
43 | start_indexes = []
44 | if index_ans:
45 | if ans_with_brack:
46 | for can in candidates:
47 | index = response.rfind(f'({can})')
48 | start_indexes.append(index) # -1 will be ignored anyway
49 | # start_indexes = [generated_response.index(f'({can})') for can in candidates]
50 | else:
51 | for can in candidates:
52 | index = response.rfind(f' {can} ')
53 | start_indexes.append(index)
54 | else:
55 | for can in candidates:
56 | index = response.lower().rfind(index2ans[can].lower())
57 | start_indexes.append(index)
58 | # get the last one
59 | pred_index = candidates[np.argmax(start_indexes)]
60 | else: # if only one candidate, use it.
61 | pred_index = candidates[0]
62 |
63 | return pred_index
64 |
65 |
66 | # ----------- Process Open -------------
67 | def check_is_number(string):
68 | """
69 | Check if the given string is a number.
70 | """
71 | try:
72 | float(string.replace(',', ''))
73 | return True
74 | except ValueError:
75 | # check if there's comma inside
76 | return False
77 |
78 |
79 | def normalize_str(string):
80 | """
81 | Normalize the str to lower case and make them float numbers if possible.
82 | """
83 | # check if characters in the string
84 |
85 | # if number, numerize it.
86 | string = string.strip()
87 |
88 | is_number = check_is_number(string)
89 |
90 | if is_number:
91 | string = string.replace(',', '')
92 | string = float(string)
93 | # leave 2 decimal
94 | string = round(string, 2)
95 | return [string]
96 | else: # it's likely to be a string
97 | # lower it
98 | string = string.lower()
99 | if len(string) == 1:
100 | return [' ' + string, string + ' '] # avoid trivial matches
101 | return [string]
102 |
103 |
104 | def extract_numbers(string):
105 | """
106 | Extract all forms of numbers from a string with regex.
107 | """
108 | # Pattern for numbers with commas
109 | pattern_commas = r'-?\b\d{1,3}(?:,\d{3})+\b'
110 | # Pattern for scientific notation
111 | pattern_scientific = r'-?\d+(?:\.\d+)?[eE][+-]?\d+'
112 | # Pattern for simple numbers without commas
113 | pattern_simple = r'-?(?:\d+\.\d+|\.\d+|\d+\b)(?![eE][+-]?\d+)(?![,\d])'
114 |
115 | # Extract numbers with commas
116 | numbers_with_commas = re.findall(pattern_commas, string)
117 | # Extract numbers in scientific notation
118 | numbers_scientific = re.findall(pattern_scientific, string)
119 | # Extract simple numbers without commas
120 | numbers_simple = re.findall(pattern_simple, string)
121 |
122 | # Combine all extracted numbers
123 | all_numbers = numbers_with_commas + numbers_scientific + numbers_simple
124 | return all_numbers
125 |
126 |
127 | def parse_open_response(response):
128 | """
129 | Parse the prediction from the generated response.
130 | Return a list of predicted strings or numbers.
131 | """
132 |
133 | # content = content.strip("\n").strip(".").strip(" ")
134 | def get_key_subresponses(response):
135 | key_responses = []
136 | response = response.strip().strip('.').lower()
137 | sub_responses = re.split(r'\.\s(?=[A-Z])|\n', response)
138 | indicators_of_keys = ['could be ', 'so ', 'is ',
139 | 'thus ', 'therefore ', 'final ', 'answer ', 'result ']
140 | key_responses = []
141 | for index, resp in enumerate(sub_responses):
142 | # if last one, accept it's an equation (the entire response can be just one sentence with equation)
143 | if index == len(sub_responses) - 1:
144 | indicators_of_keys.extend(['='])
145 | shortest_key_response = None # the shortest response that may contain the answer (tail part of the response)
146 | for indicator in indicators_of_keys:
147 | if indicator in resp:
148 | if not shortest_key_response:
149 | shortest_key_response = resp.split(indicator)[-1].strip()
150 | else:
151 | if len(resp.split(indicator)[-1].strip()) < len(shortest_key_response):
152 | shortest_key_response = resp.split(indicator)[-1].strip()
153 | # key_responses.append(resp.split(indicator)[1].strip())
154 |
155 | if shortest_key_response:
156 | # and it's not trivial
157 | if shortest_key_response.strip() not in [':', ',', '.', '!', '?', ';', ':', "'"]:
158 | key_responses.append(shortest_key_response)
159 | if len(key_responses) == 0: # did not find any
160 | return [response]
161 | return key_responses
162 |
163 | # pdb.set_trace()
164 | key_responses = get_key_subresponses(response)
165 |
166 | pred_list = key_responses.copy() # keep the original string response
167 | for resp in key_responses:
168 | pred_list.extend(extract_numbers(resp))
169 |
170 | tmp_pred_list = []
171 | for i in range(len(pred_list)):
172 | tmp_pred_list.extend(normalize_str(pred_list[i]))
173 | pred_list = tmp_pred_list
174 |
175 | # remove duplicates
176 | pred_list = list(set(pred_list))
177 |
178 | return pred_list
179 |
180 |
181 | # ----------- Evaluation -------------
182 |
183 | def eval_multi_choice(gold_i, pred_i):
184 | """
185 | Evaluate a multiple choice instance.
186 | """
187 | correct = False
188 | # only they are exactly the same, we consider it as correct
189 | if isinstance(gold_i, list):
190 | for answer in gold_i:
191 | if answer == pred_i:
192 | correct = True
193 | break
194 | else: # gold_i is a string
195 | if gold_i == pred_i:
196 | correct = True
197 | return correct
198 |
199 |
200 | def eval_open(gold_i, pred_i):
201 | """
202 | Evaluate an open question instance
203 | """
204 | correct = False
205 | if isinstance(gold_i, list):
206 | # use float to avoid trivial matches
207 | norm_answers = []
208 | for answer in gold_i:
209 | norm_answers.extend(normalize_str(answer))
210 | else:
211 | norm_answers = normalize_str(gold_i)
212 | for pred in pred_i: # pred is already normalized in parse response phase
213 | if isinstance(pred, str): # if it's a string, then find if ans in the pred_i
214 | for norm_ans in norm_answers:
215 | # only see if the string answer in the string pred
216 | if isinstance(norm_ans, str) and norm_ans in pred:
217 | if not correct:
218 | correct = True
219 | break
220 | else: # it's a float number
221 | if pred in norm_answers:
222 | if not correct:
223 | correct = True
224 | break
225 | return correct
226 |
227 |
228 | # ----------- Batch Evaluation -------------
229 | def evaluate(samples):
230 | """
231 | Batch evaluation for multiple choice and open questions.
232 | """
233 | pred_correct = 0
234 | judge_dict = dict()
235 | for sample in samples:
236 | gold_i = sample['answer']
237 | pred_i = sample['parsed_pred']
238 | if sample['question_type'] == 'multiple-choice':
239 | correct = eval_multi_choice(gold_i, pred_i)
240 | else: # open question
241 | correct = eval_open(gold_i, pred_i)
242 |
243 | if correct:
244 | judge_dict[sample['id']] = 'Correct'
245 | pred_correct += 1
246 | else:
247 | judge_dict[sample['id']] = 'Wrong'
248 |
249 | if len(samples) == 0:
250 | return {'acc': 0}
251 | return judge_dict, {'acc': pred_correct / len(samples)}
252 |
253 |
254 | # ----------- Calculate Accuracy -------------
255 | def calculate_ins_level_acc(results: Dict):
256 | """Calculate the instruction level accuracy for given Subject results"""
257 | acc = 0
258 | ins_num = 0
259 | for cat_results in results.values():
260 | acc += cat_results['acc'] * cat_results['num_example']
261 | ins_num += cat_results['num_example']
262 | if ins_num == 0:
263 | return 0
264 | return acc / ins_num
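
# --- illustrative usage (appended; not part of the original utilities) ---
# A quick, self-contained check of the parsing helpers above, with made-up responses.
if __name__ == '__main__':
    index2ans = {'A': 'cat', 'B': 'dog', 'C': 'bird'}
    all_choices = list(index2ans.keys())
    # multiple-choice: the option letter is picked out of a free-form answer
    letter = parse_multi_choice_response('The answer is (B) dog.', all_choices, index2ans)
    print(letter, eval_multi_choice('B', letter))  # B True
    # open-ended: candidate strings/numbers are extracted, then matched against the gold answer
    preds = parse_open_response('Therefore, the result is 42 meters.')
    print(eval_open('42', preds))  # True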
265 |
--------------------------------------------------------------------------------
/eval/mmmu/evaluate_mmmu.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import itertools
3 | import json
4 | import os
5 | import random
6 | import time
7 | from functools import partial
8 |
9 | import torch
10 | from data_utils import CAT_SHORT2LONG, process_single_sample
11 | from datasets import concatenate_datasets, load_dataset
12 | from utils.preprocess import build_transform, dynamic_preprocess
13 | from PIL import Image
14 | from torch.utils.data import Dataset
15 | from tqdm import tqdm
16 | from transformers import AutoTokenizer, AutoModel
17 |
18 | ds_collections = {
19 | 'MMMU_validation': {
20 | 'root': 'MMMU/MMMU',
21 | 'max_new_tokens': 10,
22 | 'min_new_tokens': 1,
23 | 'split': 'validation'
24 | },
25 | 'MMMU_test': {
26 | 'root': 'MMMU/MMMU',
27 | 'max_new_tokens': 10,
28 | 'min_new_tokens': 1,
29 | 'split': 'test'
30 | },
31 | 'MMMU_dev': {
32 | 'root': 'MMMU/MMMU',
33 | 'max_new_tokens': 10,
34 | 'min_new_tokens': 1,
35 | 'split': 'dev'
36 | },
37 | }
38 |
39 |
40 | def collate_fn(batches, tokenizer):
41 | pixel_values = torch.cat([_['pixel_values'] for _ in batches], dim=0)
42 | questions = [_['question'] for _ in batches]
43 | answers = [_['answer'] for _ in batches]
44 | data_ids = [_['data_id'] for _ in batches]
45 | options = [_['option'] for _ in batches]
46 | return pixel_values, questions, answers, data_ids, options
47 |
48 |
49 | class MMMUDataset(torch.utils.data.Dataset):
50 |
51 | def __init__(self, root, split, prompt, input_size=224, dynamic_image_size=False,
52 | use_thumbnail=False, max_num=6):
53 | # run for each subject
54 | sub_dataset_list = []
55 | for subject in tqdm(CAT_SHORT2LONG.values()):
56 | sub_dataset = load_dataset(root, subject, split=split, ) # cache_dir=os.path.join(os.getcwd(), 'data/MMMU/'))
57 | sub_dataset_list.append(sub_dataset)
58 |
59 | # merge all dataset
60 | self.data = concatenate_datasets(sub_dataset_list)
61 | self.prompt = prompt
62 | self.input_size = input_size
63 | self.dynamic_image_size = dynamic_image_size
64 | self.use_thumbnail = use_thumbnail
65 | self.max_num = max_num
66 | self.transform = build_transform(input_size=input_size)
67 |
68 | def __len__(self):
69 | return len(self.data)
70 |
71 | def __getitem__(self, idx):
72 |
73 | data = process_single_sample(self.data[idx])
74 | data_id = data['id']
75 | question = data['question'].strip()
76 | pil_images = data['image']
77 | question_type = data['question_type']
78 |
79 | choices = eval(data['options'])
80 | answer = data['answer'] if 'answer' in data else None
81 |
82 | choice_list = []
83 | options = {}
84 | multiple_choices = ['A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J', 'K', 'L', 'M']
85 | for i, c in enumerate(choices):
86 | choice_list.append('{}. {}'.format(multiple_choices[i], c.strip()))
87 | options[multiple_choices[i]] = c.strip()
88 | choice_txt = '\n'.join(choice_list)
89 | if self.dynamic_image_size:
90 | images = []
91 | for idx, pil_image in enumerate(pil_images):
92 | if pil_image is not None:
93 | if idx == 0:
94 | pil_image = pil_image.resize((pil_image.width * 2, pil_image.height * 2), Image.BILINEAR)
95 | pil_image = dynamic_preprocess(pil_image, image_size=self.input_size,
96 | use_thumbnail=self.use_thumbnail, max_num=self.max_num)
97 | else:
98 | pil_image = dynamic_preprocess(pil_image, image_size=self.input_size,
99 | use_thumbnail=self.use_thumbnail, max_num=1)
100 | images += pil_image
101 | else:
102 | images = [pil_images[0]]
103 | pixel_values = [self.transform(image) for image in images]
104 | pixel_values = torch.stack(pixel_values)
105 |
106 | if len(choice_txt) > 0:
107 | question += '\n' + choice_txt
108 | question += '\n' + self.prompt[question_type]
109 | question = question.strip()
110 |
111 | return {
112 | 'question': question,
113 | 'pixel_values': pixel_values,
114 | 'answer': answer,
115 | 'option': options,
116 | 'data_id': data_id
117 | }
118 |
119 |
120 | class InferenceSampler(torch.utils.data.sampler.Sampler):
121 |
122 | def __init__(self, size):
123 | self._size = int(size)
124 | assert size > 0
125 | self._rank = torch.distributed.get_rank()
126 | self._world_size = torch.distributed.get_world_size()
127 | self._local_indices = self._get_local_indices(size, self._world_size, self._rank)
128 |
129 | @staticmethod
130 | def _get_local_indices(total_size, world_size, rank):
131 | shard_size = total_size // world_size
132 | left = total_size % world_size
133 | shard_sizes = [shard_size + int(r < left) for r in range(world_size)]
134 |
135 | begin = sum(shard_sizes[:rank])
136 | end = min(sum(shard_sizes[:rank + 1]), total_size)
137 | return range(begin, end)
138 |
139 | def __iter__(self):
140 | yield from self._local_indices
141 |
142 | def __len__(self):
143 | return len(self._local_indices)
144 |
145 |
146 | def post_process(pred, option):
147 | pred = pred.strip()
148 | option_candidate = list(option.keys())
149 | if len(pred) == 1:
150 | return pred
151 | elif len(pred) != 1 and pred[0] in option_candidate:
152 | return pred[0]
153 | elif len(pred) != 1 and pred[0] not in option_candidate:
154 | for k, v in option.items():
155 | if v in pred:
156 | return k
157 |
158 | return pred
159 |
160 |
161 | def evaluate_chat_model():
162 | prompt = {
163 | 'multiple-choice': "Answer with the option's letter from the given choices directly.",
164 | 'open': 'Answer the question using a single word or phrase.'
165 | }
166 | random.seed(args.seed)
167 |
168 | for ds_name in args.datasets:
169 | dataset = MMMUDataset(
170 | root=ds_collections[ds_name]['root'],
171 | split=ds_collections[ds_name]['split'],
172 | prompt=prompt,
173 | input_size=image_size,
174 | dynamic_image_size=args.dynamic,
175 | use_thumbnail=use_thumbnail,
176 | max_num=args.max_num
177 | )
178 | dataloader = torch.utils.data.DataLoader(
179 | dataset=dataset,
180 | sampler=InferenceSampler(len(dataset)),
181 | batch_size=args.batch_size,
182 | num_workers=args.num_workers,
183 | pin_memory=True,
184 | drop_last=False,
185 | collate_fn=partial(collate_fn, tokenizer=tokenizer),
186 | )
187 |
188 | outputs = []
189 | for _, (pixel_values, questions, answers, data_ids, options) in tqdm(enumerate(dataloader)):
190 | pixel_values = pixel_values.to(torch.bfloat16).cuda()
191 | generation_config = dict(
192 | num_beams=args.num_beams,
193 | max_new_tokens=ds_collections[ds_name]['max_new_tokens'],
194 | min_new_tokens=ds_collections[ds_name]['min_new_tokens'],
195 | do_sample=True if args.temperature > 0 else False,
196 | temperature=args.temperature,
197 | )
198 | pred = model.chat(
199 | tokenizer=tokenizer,
200 | pixel_values=pixel_values,
201 | question=questions[0],
202 | generation_config=generation_config
203 | )
204 | if len(options[0]) == 0:
205 | preds = [pred]
206 | else:
207 | preds = [post_process(pred, options[0])]
208 |
209 | for question, pred, answer, data_id in zip(questions, preds, answers, data_ids):
210 | outputs.append({
211 | 'question': question,
212 | 'answer': pred,
213 | 'gt_answers': answer,
214 | 'data_id': data_id
215 | })
216 |
217 | torch.distributed.barrier()
218 |
219 | world_size = torch.distributed.get_world_size()
220 | merged_outputs = [None for _ in range(world_size)]
221 | torch.distributed.all_gather_object(merged_outputs, json.dumps(outputs))
222 |
223 | merged_outputs = [json.loads(_) for _ in merged_outputs]
224 | merged_outputs = [_ for _ in itertools.chain.from_iterable(merged_outputs)]
225 |
226 | if torch.distributed.get_rank() == 0:
227 |
228 | print(f'Evaluating {ds_name} ...')
229 | time_prefix = time.strftime('%y%m%d%H%M%S', time.localtime())
230 | results_file = f'{ds_name}_{time_prefix}.json'
231 | output_path = os.path.join(args.out_dir, results_file)
232 | outputs = {}
233 | for item in merged_outputs:
234 | outputs[item['data_id']] = item['answer']
235 | with open(output_path, 'w') as f:
236 | json.dump(outputs, f, indent=4)
237 | print('Results saved to {}'.format(output_path))
238 | if ds_collections[ds_name]['split'] == 'validation':
239 | print('Evaluating ...')
240 | cmd = f'python eval/mmmu/main_eval_only.py ' \
241 | f'--output_path {output_path} ' \
242 | f'--answer_path eval/mmmu/answer_dict_val.json'
243 | print(cmd)
244 | os.system(cmd)
245 | time_prefix = time.strftime('%y%m%d%H%M%S', time.localtime())
246 | results_file = f'{ds_name}_{time_prefix}.jsonl'
247 | output_path = os.path.join(args.out_dir, results_file)
248 | writer = open(output_path, 'w')
249 | for item in merged_outputs:
250 | writer.write(json.dumps(item) + '\n')
251 | writer.close()
252 | print('Results saved to {}'.format(output_path))
253 |
254 |
255 | if __name__ == '__main__':
256 | parser = argparse.ArgumentParser()
257 | parser.add_argument('--checkpoint', type=str, default='')
258 | parser.add_argument('--datasets', type=str, default='MMMU_dev')
259 | parser.add_argument('--batch-size', type=int, default=1)
260 | parser.add_argument('--num-workers', type=int, default=1)
261 | parser.add_argument('--num-beams', type=int, default=5)
262 | parser.add_argument('--temperature', type=float, default=0.0)
263 | parser.add_argument('--out-dir', type=str, default='results')
264 | parser.add_argument('--seed', type=int, default=0)
265 | parser.add_argument('--dynamic', action='store_true')
266 | parser.add_argument('--max-num', type=int, default=6)
267 | parser.add_argument('--load-in-8bit', action='store_true')
268 | parser.add_argument('--auto', action='store_true')
269 | args = parser.parse_args()
270 |
271 | if not os.path.exists(args.out_dir):
272 | os.makedirs(args.out_dir)
273 |
274 | args.datasets = args.datasets.split(',')
275 | print('datasets:', args.datasets)
276 | assert args.batch_size == 1, 'Only batch size 1 is supported'
277 |
278 | torch.distributed.init_process_group(
279 | backend='nccl',
280 | world_size=int(os.getenv('WORLD_SIZE', '1')),
281 | rank=int(os.getenv('RANK', '0')),
282 | )
283 |
284 | torch.cuda.set_device(int(os.getenv('LOCAL_RANK', 0)))
285 |
286 | if args.auto:
287 | os.environ['CUDA_LAUNCH_BLOCKING'] = '1'
288 | kwargs = {'device_map': 'auto'} if args.auto else {}
289 | tokenizer = AutoTokenizer.from_pretrained(args.checkpoint, trust_remote_code=True, use_fast=False)
290 | model = AutoModel.from_pretrained(
291 | args.checkpoint, trust_remote_code=True, low_cpu_mem_usage=True, torch_dtype=torch.bfloat16,
292 | load_in_8bit=args.load_in_8bit, **kwargs).eval()
293 | if not args.load_in_8bit and not args.auto:
294 | model = model.cuda()
295 | image_size = model.config.force_image_size or model.config.vision_config.image_size
296 | use_thumbnail = model.config.use_thumbnail
297 |
298 | total_params = sum(p.numel() for p in model.parameters()) / 1e9
299 | if total_params > 20 or args.dynamic:
300 | args.num_beams = 1
301 | print(f'[test] total_params: {total_params}B, use num_beams: {args.num_beams}')
302 | else:
303 | print(f'[test] total_params: {total_params}B')
304 | print(f'[test] image_size: {image_size}')
305 | print(f'[test] template: {model.config.template}')
306 | print(f'[test] dynamic_image_size: {args.dynamic}')
307 | print(f'[test] use_thumbnail: {use_thumbnail}')
308 | print(f'[test] max_num: {args.max_num}')
309 |
310 | evaluate_chat_model()
311 |
--------------------------------------------------------------------------------
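
The InferenceSampler defined in the script above shards the dataset contiguously across ranks, with the first total_size % world_size ranks receiving one extra sample. Below is a minimal standalone sketch of that arithmetic, assuming a hypothetical 10-sample dataset split over 4 ranks; the function name get_local_indices and the example sizes are illustrative only, not part of the repository.

# Standalone sketch of the shard arithmetic used by InferenceSampler._get_local_indices.
def get_local_indices(total_size, world_size, rank):
    shard_size = total_size // world_size
    left = total_size % world_size                      # ranks with r < left get one extra sample
    shard_sizes = [shard_size + int(r < left) for r in range(world_size)]
    begin = sum(shard_sizes[:rank])
    end = min(sum(shard_sizes[:rank + 1]), total_size)
    return range(begin, end)

if __name__ == '__main__':
    # 10 samples on 4 ranks -> shard sizes [3, 3, 2, 2]
    for rank in range(4):
        print(rank, list(get_local_indices(10, 4, rank)))
    # 0 [0, 1, 2]
    # 1 [3, 4, 5]
    # 2 [6, 7]
    # 3 [8, 9]
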
/eval/mmmu/evaluate_mmmu_cot.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import itertools
3 | import json
4 | import os
5 | import random
6 | import time
7 | from functools import partial
8 |
9 | import torch
10 | from data_utils import CAT_SHORT2LONG, process_single_sample
11 | from datasets import concatenate_datasets, load_dataset
12 | from utils.preprocess import build_transform, dynamic_preprocess
13 | from PIL import Image
14 | from torch.utils.data import Dataset
15 | from tqdm import tqdm
16 | from transformers import AutoTokenizer, AutoModel
17 |
18 | ds_collections = {
19 | 'MMMU_validation': {
20 | 'root': 'MMMU/MMMU',
21 | 'max_new_tokens': 1000,
22 | 'min_new_tokens': 1,
23 | 'split': 'validation'
24 | },
25 | 'MMMU_test': {
26 | 'root': 'MMMU/MMMU',
27 | 'max_new_tokens': 1000,
28 | 'min_new_tokens': 1,
29 | 'split': 'test'
30 | },
31 | 'MMMU_dev': {
32 | 'root': 'MMMU/MMMU',
33 | 'max_new_tokens': 1000,
34 | 'min_new_tokens': 1,
35 | 'split': 'dev'
36 | },
37 | }
38 |
39 |
40 | def collate_fn(batches, tokenizer):
41 | pixel_values = torch.cat([_['pixel_values'] for _ in batches], dim=0)
42 | questions = [_['question'] for _ in batches]
43 | answers = [_['answer'] for _ in batches]
44 | data_ids = [_['data_id'] for _ in batches]
45 | options = [_['option'] for _ in batches]
46 | return pixel_values, questions, answers, data_ids, options
47 |
48 |
49 | class MMMUDataset(torch.utils.data.Dataset):
50 |
51 | def __init__(self, root, split, prompt, input_size=224, dynamic_image_size=False,
52 | use_thumbnail=False, max_num=6):
53 | # run for each subject
54 | sub_dataset_list = []
55 | for subject in tqdm(CAT_SHORT2LONG.values()):
56 | sub_dataset = load_dataset(root, subject, split=split, cache_dir=os.path.join(os.getcwd(), 'data/MMMU/'))
57 | sub_dataset_list.append(sub_dataset)
58 |
59 | # merge all dataset
60 | self.data = concatenate_datasets(sub_dataset_list)
61 | self.prompt = prompt
62 | self.input_size = input_size
63 | self.dynamic_image_size = dynamic_image_size
64 | self.use_thumbnail = use_thumbnail
65 | self.max_num = max_num
66 | self.transform = build_transform(input_size=input_size)
67 |
68 | def __len__(self):
69 | return len(self.data)
70 |
71 | def __getitem__(self, idx):
72 |
73 | data = process_single_sample(self.data[idx])
74 | data_id = data['id']
75 | question = data['question'].strip()
76 | pil_images = data['image']
77 | question_type = data['question_type']
78 |
79 | choices = eval(data['options'])
80 | answer = data['answer'] if 'answer' in data else None
81 |
82 | choice_list = []
83 | options = {}
84 | multiple_choices = ['A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J', 'K', 'L', 'M']
85 | for i, c in enumerate(choices):
86 | choice_list.append('{}. {}'.format(multiple_choices[i], c.strip()))
87 | options[multiple_choices[i]] = c.strip()
88 | choice_txt = '\n'.join(choice_list)
89 | if self.dynamic_image_size:
90 | images = []
91 |             for image_idx, pil_image in enumerate(pil_images):
92 |                 if pil_image is not None:
93 |                     if image_idx == 0:
94 | pil_image = pil_image.resize((pil_image.width * 2, pil_image.height * 2), Image.BILINEAR)
95 | pil_image = dynamic_preprocess(pil_image, image_size=self.input_size,
96 | use_thumbnail=self.use_thumbnail, max_num=self.max_num)
97 | else:
98 | pil_image = dynamic_preprocess(pil_image, image_size=self.input_size,
99 | use_thumbnail=self.use_thumbnail, max_num=1)
100 | images += pil_image
101 | else:
102 | images = [pil_images[0]]
103 | pixel_values = [self.transform(image) for image in images]
104 | pixel_values = torch.stack(pixel_values)
105 |
106 | if len(choice_txt) > 0:
107 | question += '\n' + choice_txt
108 | question += '\n' + self.prompt[question_type]
109 | question = question.strip()
110 |
111 | return {
112 | 'question': question,
113 | 'pixel_values': pixel_values,
114 | 'answer': answer,
115 | 'option': options,
116 | 'data_id': data_id
117 | }
118 |
119 |
120 | class InferenceSampler(torch.utils.data.sampler.Sampler):
121 |
122 | def __init__(self, size):
123 | self._size = int(size)
124 | assert size > 0
125 | self._rank = torch.distributed.get_rank()
126 | self._world_size = torch.distributed.get_world_size()
127 | self._local_indices = self._get_local_indices(size, self._world_size, self._rank)
128 |
129 | @staticmethod
130 | def _get_local_indices(total_size, world_size, rank):
131 | shard_size = total_size // world_size
132 | left = total_size % world_size
133 | shard_sizes = [shard_size + int(r < left) for r in range(world_size)]
134 |
135 | begin = sum(shard_sizes[:rank])
136 | end = min(sum(shard_sizes[:rank + 1]), total_size)
137 | return range(begin, end)
138 |
139 | def __iter__(self):
140 | yield from self._local_indices
141 |
142 | def __len__(self):
143 | return len(self._local_indices)
144 |
145 |
146 | def post_process(pred, option):
147 | pred = pred.split('"Answer":')[-1].strip()
148 | pred = pred.strip()
149 | option_candidate = list(option.keys())
150 | if len(pred) == 1:
151 | return pred
152 | elif len(pred) != 1 and pred[0] in option_candidate:
153 | return pred[0]
154 | elif len(pred) != 1 and pred[0] not in option_candidate:
155 | for k, v in option.items():
156 | if v in pred:
157 | return k
158 |
159 | return pred
160 |
161 |
162 | def evaluate_chat_model():
163 | prompt = {
164 | 'multiple-choice': 'Please provide a detailed and step-by-step explanation, structured with sections labeled Thought, Step-by-Step Solution, and Answer. If you are uncertain of the correct answer, guess the most likely one.',
165 | 'open': 'Please provide a detailed and step-by-step explanation, structured with sections labeled Thought, Step-by-Step Solution, and Answer. If you are uncertain of the correct answer, guess the most likely one.'
166 | }
167 | random.seed(args.seed)
168 |
169 | for ds_name in args.datasets:
170 | dataset = MMMUDataset(
171 | root=ds_collections[ds_name]['root'],
172 | split=ds_collections[ds_name]['split'],
173 | prompt=prompt,
174 | input_size=image_size,
175 | dynamic_image_size=args.dynamic,
176 | use_thumbnail=use_thumbnail,
177 | max_num=args.max_num
178 | )
179 | dataloader = torch.utils.data.DataLoader(
180 | dataset=dataset,
181 | sampler=InferenceSampler(len(dataset)),
182 | batch_size=args.batch_size,
183 | num_workers=args.num_workers,
184 | pin_memory=True,
185 | drop_last=False,
186 | collate_fn=partial(collate_fn, tokenizer=tokenizer),
187 | )
188 |
189 | outputs = []
190 | for _, (pixel_values, questions, answers, data_ids, options) in tqdm(enumerate(dataloader)):
191 | pixel_values = pixel_values.to(torch.bfloat16).cuda()
192 | generation_config = dict(
193 | num_beams=args.num_beams,
194 | max_new_tokens=ds_collections[ds_name]['max_new_tokens'],
195 | min_new_tokens=ds_collections[ds_name]['min_new_tokens'],
196 | do_sample=True if args.temperature > 0 else False,
197 | temperature=args.temperature,
198 | )
199 | pred = model.chat(
200 | tokenizer=tokenizer,
201 | pixel_values=pixel_values,
202 | question=questions[0],
203 | generation_config=generation_config
204 | )
205 | original_pred = pred
206 | if len(options[0]) == 0:
207 | preds = [pred]
208 | else:
209 | preds = [post_process(pred, options[0])]
210 |
211 | for question, pred, answer, data_id in zip(questions, preds, answers, data_ids):
212 | outputs.append({
213 | 'question': question,
214 | 'answer': pred,
215 | 'answer_original': original_pred,
216 | 'gt_answers': answer,
217 | 'data_id': data_id
218 | })
219 |
220 | torch.distributed.barrier()
221 |
222 | world_size = torch.distributed.get_world_size()
223 | merged_outputs = [None for _ in range(world_size)]
224 | torch.distributed.all_gather_object(merged_outputs, json.dumps(outputs))
225 |
226 | merged_outputs = [json.loads(_) for _ in merged_outputs]
227 | merged_outputs = [_ for _ in itertools.chain.from_iterable(merged_outputs)]
228 |
229 | if torch.distributed.get_rank() == 0:
230 |
231 | print(f'Evaluating {ds_name} ...')
232 | time_prefix = time.strftime('%y%m%d%H%M%S', time.localtime())
233 | results_file = f'{ds_name}_{time_prefix}.json'
234 | output_path = os.path.join(args.out_dir, results_file)
235 | outputs = {}
236 | for item in merged_outputs:
237 | outputs[item['data_id']] = item['answer']
238 | with open(output_path, 'w') as f:
239 | json.dump(outputs, f, indent=4)
240 | print('Results saved to {}'.format(output_path))
241 | if ds_collections[ds_name]['split'] == 'validation':
242 | print('Evaluating ...')
243 | cmd = f'python eval/mmmu/main_eval_only.py ' \
244 | f'--output_path {output_path} ' \
245 | f'--answer_path eval/mmmu/answer_dict_val.json'
246 | print(cmd)
247 | os.system(cmd)
248 | time_prefix = time.strftime('%y%m%d%H%M%S', time.localtime())
249 | results_file = f'{ds_name}_{time_prefix}.jsonl'
250 | output_path = os.path.join(args.out_dir, results_file)
251 | writer = open(output_path, 'w')
252 | for item in merged_outputs:
253 | writer.write(json.dumps(item) + '\n')
254 | writer.close()
255 | print('Results saved to {}'.format(output_path))
256 |
257 |
258 | if __name__ == '__main__':
259 | parser = argparse.ArgumentParser()
260 | parser.add_argument('--checkpoint', type=str, default='')
261 | parser.add_argument('--datasets', type=str, default='MMMU_dev')
262 | parser.add_argument('--batch-size', type=int, default=1)
263 | parser.add_argument('--num-workers', type=int, default=1)
264 | parser.add_argument('--num-beams', type=int, default=5)
265 | parser.add_argument('--temperature', type=float, default=0.0)
266 | parser.add_argument('--out-dir', type=str, default='results')
267 | parser.add_argument('--seed', type=int, default=0)
268 | parser.add_argument('--dynamic', action='store_true')
269 | parser.add_argument('--max-num', type=int, default=6)
270 | parser.add_argument('--load-in-8bit', action='store_true')
271 | parser.add_argument('--auto', action='store_true')
272 | args = parser.parse_args()
273 |
274 | if not os.path.exists(args.out_dir):
275 | os.makedirs(args.out_dir)
276 |
277 | args.datasets = args.datasets.split(',')
278 | print('datasets:', args.datasets)
279 | assert args.batch_size == 1, 'Only batch size 1 is supported'
280 |
281 | torch.distributed.init_process_group(
282 | backend='nccl',
283 | world_size=int(os.getenv('WORLD_SIZE', '1')),
284 | rank=int(os.getenv('RANK', '0')),
285 | )
286 |
287 | torch.cuda.set_device(int(os.getenv('LOCAL_RANK', 0)))
288 |
289 | if args.auto:
290 | os.environ['CUDA_LAUNCH_BLOCKING'] = '1'
291 | kwargs = {'device_map': 'auto'} if args.auto else {}
292 | tokenizer = AutoTokenizer.from_pretrained(args.checkpoint, trust_remote_code=True, use_fast=False)
293 | model = AutoModel.from_pretrained(
294 | args.checkpoint, trust_remote_code=True, low_cpu_mem_usage=True, torch_dtype=torch.bfloat16,
295 | load_in_8bit=args.load_in_8bit, **kwargs).eval()
296 | if not args.load_in_8bit and not args.auto:
297 | model = model.cuda()
298 | image_size = model.config.force_image_size or model.config.vision_config.image_size
299 | use_thumbnail = model.config.use_thumbnail
300 |
301 | total_params = sum(p.numel() for p in model.parameters()) / 1e9
302 | if total_params > 20 or args.dynamic:
303 | args.num_beams = 1
304 | print(f'[test] total_params: {total_params}B, use num_beams: {args.num_beams}')
305 | else:
306 | print(f'[test] total_params: {total_params}B')
307 | print(f'[test] image_size: {image_size}')
308 | print(f'[test] template: {model.config.template}')
309 | print(f'[test] dynamic_image_size: {args.dynamic}')
310 | print(f'[test] use_thumbnail: {use_thumbnail}')
311 | print(f'[test] max_num: {args.max_num}')
312 |
313 | evaluate_chat_model()
314 |
--------------------------------------------------------------------------------
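
The chain-of-thought variant of post_process above first keeps only the text that follows a literal "Answer": marker before falling back to the usual letter / option-text matching. A small sketch of that behaviour; the model outputs and option texts below are made up for illustration.

# Same logic as post_process in evaluate_mmmu_cot.py, shown standalone with toy inputs.
def post_process_cot(pred, option):
    pred = pred.split('"Answer":')[-1].strip()
    option_candidate = list(option.keys())
    if len(pred) == 1:
        return pred
    elif len(pred) != 1 and pred[0] in option_candidate:
        return pred[0]
    elif len(pred) != 1 and pred[0] not in option_candidate:
        for k, v in option.items():
            if v in pred:
                return k
    return pred

options = {'A': 'Paris', 'B': 'Rome'}
print(post_process_cot('Thought: ... "Answer": B', options))    # -> B
print(post_process_cot('The final answer is Rome.', options))   # -> B (matched via option text)
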
/eval/mmmu/main_eval_only.py:
--------------------------------------------------------------------------------
1 | """Parse and Evalate"""
2 | import json
3 | from argparse import ArgumentParser
4 |
5 | from data_utils import CAT_SHORT2LONG, DOMAIN_CAT2SUB_CAT, save_json
6 | from eval_utils import (calculate_ins_level_acc, evaluate,
7 | parse_multi_choice_response, parse_open_response)
8 |
9 | if __name__ == '__main__':
10 |
11 | parser = ArgumentParser()
12 | parser.add_argument('--output_path', type=str, default='./example_outputs/qwen_vl/total_val_output.json',
13 | help='The path to model output file.')
14 | parser.add_argument('--answer_path', type=str, default='./answer_dict_val.json', help='Answer file path.')
15 | args = parser.parse_args()
16 |
17 | output_dict = json.load(open(args.output_path))
18 | answer_dict = json.load(open(args.answer_path))
19 |
20 |     # group model outputs by category
21 | output_dict_w_cat = {}
22 | for data_id, parsed_pred in output_dict.items():
23 | category = '_'.join(data_id.split('_')[1:-1])
24 | if category not in output_dict_w_cat:
25 | output_dict_w_cat.update({category: {}})
26 | output_dict_w_cat[category].update({data_id: parsed_pred})
27 |
28 |     # group ground-truth answers by category
29 | answer_dict_w_cat = {}
30 | for data_id, parsed_pred in answer_dict.items():
31 | category = '_'.join(data_id.split('_')[1:-1])
32 | if category not in answer_dict_w_cat:
33 | answer_dict_w_cat.update({category: {}})
34 | answer_dict_w_cat[category].update({data_id: parsed_pred})
35 |
36 | evaluation_result = {}
37 |
38 | for category in CAT_SHORT2LONG.values():
39 | print('Evaluating: {}'.format(category))
40 | # get cat_outputs and cat_answers
41 | try:
42 | cat_outputs = output_dict_w_cat[category]
43 | cat_answers = answer_dict_w_cat[category]
44 | except KeyError:
45 |             print('Skipping {}: no outputs found for this category'.format(category))
46 | continue
47 |
48 |         examples_to_eval = []
49 | for data_id, parsed_pred in cat_outputs.items():
50 | question_type = cat_answers[data_id]['question_type']
51 | if question_type != 'multiple-choice':
52 | parsed_pred = parse_open_response(parsed_pred) # mainly for type consistency (make it number, etc.)
53 | else:
54 |                 parsed_pred = parsed_pred  # multiple-choice predictions are kept as-is
55 |
56 |             examples_to_eval.append({
57 | 'id': data_id,
58 | 'question_type': question_type,
59 | 'answer': cat_answers[data_id]['ground_truth'],
60 | 'parsed_pred': parsed_pred
61 | })
62 |
63 |         judge_dict, metric_dict = evaluate(examples_to_eval)
64 |         metric_dict.update({'num_example': len(examples_to_eval)})
65 |
66 | evaluation_result[category] = metric_dict
67 |
68 | printable_results = {}
69 | # pdb.set_trace()
70 |     # aggregate results per domain (subject groups defined in DOMAIN_CAT2SUB_CAT)
71 | for domain, in_domain_cats in DOMAIN_CAT2SUB_CAT.items():
72 | in_domain_cat_results = {}
73 | for cat_name in in_domain_cats: # use the order in DOMAIN_CAT2SUB_CAT
74 | if cat_name in evaluation_result.keys():
75 | in_domain_cat_results[cat_name] = evaluation_result[cat_name]
76 | else:
77 | pass
78 | in_domain_ins_acc = calculate_ins_level_acc(in_domain_cat_results)
79 | in_domain_data_num = sum([cat_results['num_example'] for cat_results in in_domain_cat_results.values()])
80 | printable_results['Overall-' + domain] = {'num': int(in_domain_data_num),
81 | 'acc': round(in_domain_ins_acc, 3)
82 | }
83 | # add sub category
84 | for cat_name, cat_results in in_domain_cat_results.items():
85 | printable_results[cat_name] = {'num': int(cat_results['num_example']),
86 | 'acc': round(cat_results['acc'], 3)
87 | }
88 |
89 | # table.append(["-----------------------------", "-----", "----"])
90 | all_ins_acc = calculate_ins_level_acc(evaluation_result)
91 | printable_results['Overall'] = {
92 | 'num': sum([cat_results['num_example'] for cat_results in evaluation_result.values()]),
93 | 'acc': round(all_ins_acc, 3)}
94 |
95 | print(printable_results)
96 |
--------------------------------------------------------------------------------
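
main_eval_only.py derives the subject category by dropping the first (split) and last (question index) tokens of each data id. Assuming MMMU-style ids such as validation_Art_Theory_12 (the example ids below are illustrative), the parsing works as follows:

# Mirrors the category grouping in main_eval_only.py; the example ids are hypothetical.
def extract_category(data_id):
    return '_'.join(data_id.split('_')[1:-1])

print(extract_category('validation_Accounting_1'))   # -> Accounting
print(extract_category('validation_Art_Theory_12'))  # -> Art_Theory
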
/eval/mvbench/evaluate_mvbench.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import itertools
3 | import json
4 | import os
5 | import random
6 | import time
7 | from functools import partial
8 |
9 | import cv2
10 | import imageio
11 | import numpy as np
12 | import torch
13 | from decord import VideoReader, cpu
14 | from utils.preprocess import build_transform, dynamic_preprocess
15 | from PIL import Image
16 | from torch.utils.data import Dataset
17 | from tqdm import tqdm
18 | from transformers import AutoTokenizer, AutoModel
19 |
20 | data_list = {
21 | 'Action Sequence': ('action_sequence.json', './data/MVBench/video/star/Charades_v1_480/', 'video', True),
22 | # has start & end
23 | 'Action Prediction': ('action_prediction.json', './data/MVBench/video/star/Charades_v1_480/', 'video', True),
24 | # has start & end
25 | 'Action Antonym': ('action_antonym.json', './data/MVBench/video/ssv2_video/', 'video', False),
26 | 'Fine-grained Action': (
27 | 'fine_grained_action.json', './data/MVBench/video/Moments_in_Time_Raw/videos/', 'video', False),
28 | 'Unexpected Action': ('unexpected_action.json', './data/MVBench/video/FunQA_test/test/', 'video', False),
29 | 'Object Existence': ('object_existence.json', './data/MVBench/video/clevrer/video_validation/', 'video', False),
30 | 'Object Interaction': ('object_interaction.json', './data/MVBench/video/star/Charades_v1_480/', 'video', True),
31 | # has start & end
32 | 'Object Shuffle': ('object_shuffle.json', './data/MVBench/video/perception/videos/', 'video', False),
33 | 'Moving Direction': ('moving_direction.json', './data/MVBench/video/clevrer/video_validation/', 'video', False),
34 | 'Action Localization': ('action_localization.json', './data/MVBench/video/sta/sta_video/', 'video', True),
35 | # has start & end
36 | 'Scene Transition': ('scene_transition.json', './data/MVBench/video/scene_qa/video/', 'video', False),
37 | 'Action Count': ('action_count.json', './data/MVBench/video/perception/videos/', 'video', False),
38 | 'Moving Count': ('moving_count.json', './data/MVBench/video/clevrer/video_validation/', 'video', False),
39 | 'Moving Attribute': ('moving_attribute.json', './data/MVBench/video/clevrer/video_validation/', 'video', False),
40 | 'State Change': ('state_change.json', './data/MVBench/video/perception/videos/', 'video', False),
41 | 'Fine-grained Pose': ('fine_grained_pose.json', './data/MVBench/video/nturgbd/', 'video', False),
42 | 'Character Order': ('character_order.json', './data/MVBench/video/perception/videos/', 'video', False),
43 | 'Egocentric Navigation': ('egocentric_navigation.json', './data/MVBench/video/vlnqa/', 'video', False),
44 | 'Episodic Reasoning': ('episodic_reasoning.json', './data/MVBench/video/tvqa/frames_fps3_hq/', 'frame', True),
45 | # has start & end, read frame
46 | 'Counterfactual Inference': (
47 | 'counterfactual_inference.json', './data/MVBench/video/clevrer/video_validation/', 'video', False),
48 | }
49 |
50 | data_dir = './data/MVBench/json'
51 |
52 |
53 | def collate_fn(batches, tokenizer):
54 | pixel_values = torch.cat([_['pixel_values'] for _ in batches], dim=0)
55 | split_sizes = torch.cat([_['split_sizes'] for _ in batches])
56 | data_flag = torch.cat([_['data_flag'] for _ in batches])
57 | questions = [_['question'] for _ in batches]
58 | answers = [_['answer'] for _ in batches]
59 | num_patches_lists = [_['num_patches_list'] for _ in batches]
60 | task_types = [_['task_type'] for _ in batches]
61 | return pixel_values, split_sizes, data_flag, questions, answers, num_patches_lists, task_types
62 |
63 |
64 | class MVBenchDataset(torch.utils.data.Dataset):
65 |
66 | def __init__(self, data_dir, data_list, prompt, question_prompt, num_segments=16, input_size=224,
67 | dynamic_image_size=False, use_thumbnail=False, max_num=6):
68 | self.data_list = []
69 | for k, v in data_list.items():
70 | with open(os.path.join(data_dir, v[0]), 'r') as f:
71 | json_data = json.load(f)
72 | for data in json_data:
73 | self.data_list.append({
74 | 'task_type': k,
75 | 'prefix': v[1],
76 | 'data_type': v[2],
77 | 'bound': v[3],
78 | 'data': data
79 | })
80 | self.decord_method = {
81 | 'video': self.read_video,
82 | 'gif': self.read_gif,
83 | 'frame': self.read_frame,
84 | }
85 | self.prompt = prompt
86 | self.question_prompt = question_prompt
87 | self.input_size = input_size
88 | self.num_segments = num_segments
89 | self.dynamic_image_size = dynamic_image_size
90 | self.use_thumbnail = use_thumbnail
91 | self.max_num = max_num
92 | self.transform = build_transform(input_size=input_size)
93 |
94 | def __len__(self):
95 | return len(self.data_list)
96 |
97 | def __str__(self):
98 | len_list = {}
99 | option_list = {}
100 | for data in self.data_list:
101 | if data['task_type'] not in len_list:
102 | len_list[data['task_type']] = 0
103 | len_list[data['task_type']] += 1
104 | if data['task_type'] not in option_list:
105 | option_list[data['task_type']] = 0
106 | option_list[data['task_type']] += len(data['data']['candidates'])
107 |
108 | correct = 0
109 | total = 0
110 |         res = f'There are {len(self.data_list)} videos as follows:\n'
111 | for k, v in len_list.items():
112 | correct += len_list[k]
113 | total += option_list[k]
114 | res += f'{v} for {k} ({option_list[k]} options => {len_list[k] / option_list[k] * 100:.2f}%)\n'
115 | correct = correct + 1 / option_list[k]
116 | res += f'Total random accuracy: {correct / total * 100:.2f}%'
117 | return res.rstrip()
118 |
119 | def get_index(self, bound, fps, max_frame, first_idx=0):
120 | if bound:
121 | start, end = bound[0], bound[1]
122 | else:
123 | start, end = -100000, 100000
124 | start_idx = max(first_idx, round(start * fps))
125 | end_idx = min(round(end * fps), max_frame)
126 | seg_size = float(end_idx - start_idx) / self.num_segments
127 | frame_indices = np.array([
128 | int(start_idx + (seg_size / 2) + np.round(seg_size * idx))
129 | for idx in range(self.num_segments)
130 | ])
131 | return frame_indices
132 |
133 | def read_video(self, video_path, bound=None):
134 | vr = VideoReader(video_path, ctx=cpu(0), num_threads=1)
135 | max_frame = len(vr) - 1
136 | fps = float(vr.get_avg_fps())
137 |
138 | images_group = list()
139 | frame_indices = self.get_index(bound, fps, max_frame, first_idx=0)
140 | for frame_index in frame_indices:
141 | img = Image.fromarray(vr[frame_index].asnumpy())
142 | images_group.append(img)
143 |
144 | return images_group
145 |
146 | def read_gif(self, video_path, bound=None, fps=25):
147 | gif = imageio.get_reader(video_path)
148 | max_frame = len(gif) - 1
149 |
150 | images_group = list()
151 | frame_indices = self.get_index(bound, fps, max_frame, first_idx=0)
152 | for index, frame in enumerate(gif):
153 | if index in frame_indices:
154 | img = cv2.cvtColor(frame, cv2.COLOR_RGBA2RGB)
155 | img = Image.fromarray(img)
156 | images_group.append(img)
157 |
158 | return images_group
159 |
160 | def read_frame(self, video_path, bound=None, fps=3):
161 | max_frame = len(os.listdir(video_path))
162 | images_group = list()
163 | frame_indices = self.get_index(bound, fps, max_frame, first_idx=1) # frame_idx starts from 1
164 | for frame_index in frame_indices:
165 | img = Image.open(os.path.join(video_path, f'{frame_index:05d}.jpg'))
166 | images_group.append(img)
167 |
168 | return images_group
169 |
170 | def qa_template(self, data):
171 | question = f"Question: {data['question']}\n"
172 | question += 'Options:\n'
173 | answer = data['answer']
174 | answer_idx = -1
175 | for idx, c in enumerate(data['candidates']):
176 | question += f"({chr(ord('A') + idx)}) {c}\n"
177 | if c == answer:
178 | answer_idx = idx
179 | question = question.rstrip()
180 | answer = f"({chr(ord('A') + answer_idx)}) {answer}"
181 | return question, answer
182 |
183 | def __getitem__(self, idx):
184 | decord_method = self.decord_method[self.data_list[idx]['data_type']]
185 | bound = None
186 | if self.data_list[idx]['bound']:
187 | bound = (
188 | self.data_list[idx]['data']['start'],
189 | self.data_list[idx]['data']['end'],
190 | )
191 | video_path = os.path.join(self.data_list[idx]['prefix'], self.data_list[idx]['data']['video'])
192 | image_list = decord_method(video_path, bound)
193 | special_tokens = '\n'.join(['Frame{}:'.format(i + 1) for i in range(len(image_list))])
194 |
195 | question, answer = self.qa_template(self.data_list[idx]['data'])
196 | question = special_tokens + '\n' + self.prompt + '\n' + question + self.question_prompt
197 |
198 | raw_images = []
199 | num_patches_list = []
200 | pixel_values = []
201 | for image in image_list:
202 | raw_images.append(image)
203 | if self.dynamic_image_size:
204 | patches = dynamic_preprocess(image, image_size=self.input_size,
205 | use_thumbnail=self.use_thumbnail,
206 | max_num=self.max_num)
207 | else:
208 | patches = [image]
209 | num_patches_list.append(len(patches))
210 | pixel_values.extend([self.transform(patch) for patch in patches])
211 |
212 | pixel_values = torch.stack(pixel_values)
213 |
214 | return {
215 | 'question': question,
216 | 'pixel_values': pixel_values,
217 | 'split_sizes': torch.LongTensor((pixel_values.shape[0], )),
218 | 'data_flag': torch.LongTensor((3, )),
219 | 'answer': answer,
220 | 'num_patches_list': num_patches_list,
221 | 'task_type': self.data_list[idx]['task_type']
222 | }
223 |
224 |
225 | class InferenceSampler(torch.utils.data.sampler.Sampler):
226 |
227 | def __init__(self, size):
228 | self._size = int(size)
229 | assert size > 0
230 | self._rank = torch.distributed.get_rank()
231 | self._world_size = torch.distributed.get_world_size()
232 | self._local_indices = self._get_local_indices(size, self._world_size, self._rank)
233 |
234 | @staticmethod
235 | def _get_local_indices(total_size, world_size, rank):
236 | shard_size = total_size // world_size
237 | left = total_size % world_size
238 | shard_sizes = [shard_size + int(r < left) for r in range(world_size)]
239 |
240 | begin = sum(shard_sizes[:rank])
241 | end = min(sum(shard_sizes[:rank + 1]), total_size)
242 | return range(begin, end)
243 |
244 | def __iter__(self):
245 | yield from self._local_indices
246 |
247 | def __len__(self):
248 | return len(self._local_indices)
249 |
250 |
251 | def check_ans(pred, gt):
252 | flag = False
253 | pred = pred.replace('Answer: ', '')
254 |
255 | pred_list = pred.lower().split(' ')
256 | pred_option, pred_content = pred_list[0], ' '.join(pred_list[1:])
257 | gt_list = gt.lower().split(' ')
258 | gt_option, gt_content = gt_list[0], ' '.join(gt_list[1:])
259 | if gt_content[-1] == '.':
260 | gt_content = gt_content[:-1]
261 |
262 | if pred_option.replace('.', '') in gt_option:
263 | flag = True
264 | elif gt_option in pred_option:
265 | flag = True
266 |
267 | return flag
268 |
269 |
270 | def evaluate_chat_model():
271 | random.seed(args.seed)
272 | prompt = 'Carefully watch the video and pay attention to the cause and sequence of events, the detail and movement of objects, and the action and pose of persons. Based on your observations, select the best option that accurately addresses the question.\n'
273 | question_prompt = '\nOnly give the best option.'
274 |
275 | vid_dataset = MVBenchDataset(
276 | data_dir, data_list,
277 | prompt=prompt,
278 | question_prompt=question_prompt,
279 | num_segments=args.num_segments,
280 | input_size=image_size,
281 | dynamic_image_size=args.dynamic,
282 | use_thumbnail=use_thumbnail,
283 | max_num=args.max_num)
284 | dataloader = torch.utils.data.DataLoader(
285 | dataset=vid_dataset,
286 | sampler=InferenceSampler(len(vid_dataset)),
287 | batch_size=args.batch_size,
288 | num_workers=args.num_workers,
289 | pin_memory=True,
290 | drop_last=False,
291 | collate_fn=partial(collate_fn, tokenizer=tokenizer),
292 | )
293 |
294 | outputs = []
295 | for _, (pixel_values, split_sizes, data_flag, questions, answers, num_patches_lists, task_types) in tqdm(enumerate(dataloader)):
296 | pixel_values = pixel_values.to(torch.bfloat16).cuda()
297 | generation_config = dict(
298 | num_beams=args.num_beams,
299 | max_new_tokens=1000,
300 | min_new_tokens=1,
301 | do_sample=True if args.temperature > 0 else False,
302 | temperature=args.temperature,
303 | )
304 | pred = model.chat(
305 | tokenizer=tokenizer,
306 | pixel_values=pixel_values,
307 | split_sizes=split_sizes,
308 | data_flag=data_flag,
309 | num_patches_list=num_patches_lists[0],
310 | question=questions[0],
311 | generation_config=generation_config
312 | )
313 | outputs.append({
314 | 'question': questions[0],
315 | 'pred': pred,
316 | 'gt': answers[0],
317 | 'task_type': task_types[0],
318 | })
319 | torch.distributed.barrier()
320 |
321 | world_size = torch.distributed.get_world_size()
322 | merged_outputs = [None for _ in range(world_size)]
323 | torch.distributed.all_gather_object(merged_outputs, json.dumps(outputs))
324 |
325 | merged_outputs = [json.loads(_) for _ in merged_outputs]
326 | merged_outputs = [_ for _ in itertools.chain.from_iterable(merged_outputs)]
327 |
328 | if torch.distributed.get_rank() == 0:
329 |
330 |             print('Evaluating MVBench ...')
331 | correct, total, acc_dict = 0, 0, {}
332 | for item in merged_outputs:
333 | task_type = item['task_type']
334 | pred = item['pred']
335 | gt = item['gt']
336 | if task_type not in acc_dict:
337 | acc_dict[task_type] = [0, 0] # correct, total
338 | acc_dict[task_type][1] += 1
339 | total += 1
340 |
341 | if check_ans(pred, gt):
342 | acc_dict[task_type][0] += 1
343 | correct += 1
344 |
345 | final_res = {}
346 | for k, v in acc_dict.items():
347 | final_res[k] = v[0] / v[1] * 100
348 | final_res['Avg'] = correct / total * 100
349 | print(final_res)
350 |
351 | time_prefix = time.strftime('%y%m%d%H%M%S', time.localtime())
352 | results_file = f'MVBench_{time_prefix}'
353 | output_path = os.path.join(args.out_dir, results_file)
354 | with open(f'{output_path}.json', 'w') as f:
355 | json.dump(outputs, f)
356 | with open(f'{output_path}_result_final.json', 'w') as f:
357 | json.dump(final_res, f)
358 | print('Results saved to {}'.format(output_path))
359 |
360 |
361 | if __name__ == '__main__':
362 | parser = argparse.ArgumentParser()
363 | parser.add_argument('--checkpoint', type=str, default='')
364 | parser.add_argument('--datasets', type=str, default='mvbench')
365 | parser.add_argument('--batch-size', type=int, default=1)
366 | parser.add_argument('--num-workers', type=int, default=1)
367 | parser.add_argument('--num-beams', type=int, default=5)
368 | parser.add_argument('--temperature', type=float, default=0.0)
369 | parser.add_argument('--out-dir', type=str, default='results')
370 | parser.add_argument('--seed', type=int, default=0)
371 | parser.add_argument('--dynamic', action='store_true')
372 | parser.add_argument('--max-num', type=int, default=6)
373 | parser.add_argument('--load-in-8bit', action='store_true')
374 | parser.add_argument('--auto', action='store_true')
375 | parser.add_argument('--num_segments', type=int, default=16)
376 | args = parser.parse_args()
377 |
378 | if not os.path.exists(args.out_dir):
379 | os.makedirs(args.out_dir, exist_ok=True)
380 |
381 | args.datasets = args.datasets.split(',')
382 | print('datasets:', args.datasets)
383 | assert args.batch_size == 1, 'Only batch size 1 is supported'
384 |
385 | torch.distributed.init_process_group(
386 | backend='nccl',
387 | world_size=int(os.getenv('WORLD_SIZE', '1')),
388 | rank=int(os.getenv('RANK', '0')),
389 | )
390 |
391 | torch.cuda.set_device(int(os.getenv('LOCAL_RANK', 0)))
392 |
393 | if args.auto:
394 | os.environ['CUDA_LAUNCH_BLOCKING'] = '1'
395 | kwargs = {'device_map': 'auto'} if args.auto else {}
396 | tokenizer = AutoTokenizer.from_pretrained(args.checkpoint, trust_remote_code=True, use_fast=False)
397 | model = AutoModel.from_pretrained(
398 | args.checkpoint, trust_remote_code=True, low_cpu_mem_usage=True, torch_dtype=torch.bfloat16,
399 | load_in_8bit=args.load_in_8bit, **kwargs).eval()
400 | if not args.load_in_8bit and not args.auto:
401 | model = model.cuda()
402 | image_size = model.config.force_image_size or model.config.vision_config.image_size
403 | use_thumbnail = model.config.use_thumbnail
404 |
405 | total_params = sum(p.numel() for p in model.parameters()) / 1e9
406 | if total_params > 20 or args.dynamic:
407 | args.num_beams = 1
408 | print(f'[test] total_params: {total_params}B, use num_beams: {args.num_beams}')
409 | else:
410 | print(f'[test] total_params: {total_params}B')
411 | print(f'[test] image_size: {image_size}')
412 | print(f'[test] template: {model.config.template}')
413 | print(f'[test] dynamic_image_size: {args.dynamic}')
414 | print(f'[test] use_thumbnail: {use_thumbnail}')
415 | print(f'[test] max_num: {args.max_num}')
416 |
417 | evaluate_chat_model()
418 |
--------------------------------------------------------------------------------
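
MVBenchDataset.get_index above samples one frame from the midpoint of each of num_segments equal slices between the optional start/end bound. Below is a standalone sketch of that arithmetic; the clip length, fps, and segment count are example values only.

import numpy as np

# Same midpoint sampling as MVBenchDataset.get_index, shown without the class.
def get_index(bound, fps, max_frame, num_segments, first_idx=0):
    start, end = bound if bound else (-100000, 100000)
    start_idx = max(first_idx, round(start * fps))
    end_idx = min(round(end * fps), max_frame)
    seg_size = float(end_idx - start_idx) / num_segments
    return np.array([int(start_idx + seg_size / 2 + np.round(seg_size * idx))
                     for idx in range(num_segments)])

# A hypothetical 240-frame clip at 30 fps split into 4 segments:
print(get_index(None, fps=30, max_frame=239, num_segments=4))  # frames 29, 89, 149, 208
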
/eval/scienceqa/evaluate_scienceqa.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import itertools
3 | import json
4 | import os
5 | import random
6 | import time
7 | from functools import partial
8 |
9 | import torch
10 | from utils.preprocess import build_transform, dynamic_preprocess
11 | from PIL import Image
12 | from torch.utils.data import Dataset
13 | from tqdm import tqdm
14 | from transformers import AutoTokenizer, AutoModel
15 |
16 | ds_collections = {
17 | 'sqa_test': {
18 | 'root': 'data/scienceqa/scienceqa_test_img.jsonl',
19 | 'max_new_tokens': 100,
20 | 'min_new_tokens': 1,
21 | },
22 | }
23 |
24 |
25 | def collate_fn(batches, tokenizer):
26 | pixel_values = torch.cat([_['pixel_values'] for _ in batches], dim=0)
27 | questions = [_['question'] for _ in batches]
28 | answers = [_['answer'] for _ in batches]
29 | image_paths = [_['image_path'] for _ in batches]
30 | options = [_['option'] for _ in batches]
31 | return pixel_values, questions, answers, image_paths, options
32 |
33 |
34 | class ScienceQADataset(torch.utils.data.Dataset):
35 |
36 | def __init__(self, root, prompt, input_size=224, dynamic_image_size=False,
37 | use_thumbnail=False, max_num=6):
38 | f = open(root, 'r', encoding='utf-8')
39 | self.data = [json.loads(line) for line in f.readlines()]
40 | self.prompt = prompt
41 | self.input_size = input_size
42 | self.dynamic_image_size = dynamic_image_size
43 | self.use_thumbnail = use_thumbnail
44 | self.max_num = max_num
45 | self.transform = build_transform(input_size=input_size)
46 |
47 | def __len__(self):
48 | return len(self.data)
49 |
50 | def __getitem__(self, idx):
51 | data = self.data[idx]
52 | image_path = data['image']
53 | hint = data['hint'] if data['hint'] else None
54 | question = data['question']
55 |
56 | choices = data['choices']
57 | answer = data['answer']
58 | choice_list = []
59 |
60 | options = {}
61 | multiple_choices = ['A', 'B', 'C', 'D', 'E']
62 | for i, c in enumerate(choices):
63 | choice_list.append('{}. {}'.format(multiple_choices[i], c))
64 | options[multiple_choices[i]] = c
65 | choice_txt = '\n'.join(choice_list)
66 |
67 | image = Image.open(image_path).convert('RGB')
68 | if self.dynamic_image_size:
69 | images = dynamic_preprocess(image, image_size=self.input_size,
70 | use_thumbnail=self.use_thumbnail,
71 | max_num=self.max_num)
72 | else:
73 | images = [image]
74 | pixel_values = [self.transform(image) for image in images]
75 | pixel_values = torch.stack(pixel_values)
76 |
77 | if hint is not None:
78 | question = hint + '\n' + question
79 | question += '\n' + choice_txt
80 | question += '\n' + self.prompt
81 |
82 | return {
83 | 'question': question,
84 | 'pixel_values': pixel_values,
85 | 'answer': multiple_choices[answer],
86 | 'image_path': image_path,
87 | 'option': options
88 | }
89 |
90 |
91 | class InferenceSampler(torch.utils.data.sampler.Sampler):
92 |
93 | def __init__(self, size):
94 | self._size = int(size)
95 | assert size > 0
96 | self._rank = torch.distributed.get_rank()
97 | self._world_size = torch.distributed.get_world_size()
98 | self._local_indices = self._get_local_indices(size, self._world_size, self._rank)
99 |
100 | @staticmethod
101 | def _get_local_indices(total_size, world_size, rank):
102 | shard_size = total_size // world_size
103 | left = total_size % world_size
104 | shard_sizes = [shard_size + int(r < left) for r in range(world_size)]
105 |
106 | begin = sum(shard_sizes[:rank])
107 | end = min(sum(shard_sizes[:rank + 1]), total_size)
108 | return range(begin, end)
109 |
110 | def __iter__(self):
111 | yield from self._local_indices
112 |
113 | def __len__(self):
114 | return len(self._local_indices)
115 |
116 |
117 | def post_process(pred, option):
118 | pred = pred.strip()
119 | option_candidate = list(option.keys())
120 | if len(pred) == 1:
121 | return pred
122 | elif len(pred) != 1 and pred[0] in option_candidate:
123 | return pred[0]
124 | elif len(pred) != 1 and pred[0] not in option_candidate:
125 | for k, v in option.items():
126 | if v in pred:
127 | return k
128 |
129 | return pred
130 |
131 |
132 | def evaluate_chat_model():
133 | prompt = "Answer with the option's letter from the given choices directly."
134 | random.seed(args.seed)
135 |
136 | for ds_name in args.datasets:
137 | dataset = ScienceQADataset(
138 | root=ds_collections[ds_name]['root'],
139 | prompt=prompt,
140 | input_size=image_size,
141 | dynamic_image_size=args.dynamic,
142 | use_thumbnail=use_thumbnail,
143 | max_num=args.max_num
144 | )
145 | dataloader = torch.utils.data.DataLoader(
146 | dataset=dataset,
147 | sampler=InferenceSampler(len(dataset)),
148 | batch_size=args.batch_size,
149 | num_workers=args.num_workers,
150 | pin_memory=True,
151 | drop_last=False,
152 | collate_fn=partial(collate_fn, tokenizer=tokenizer),
153 | )
154 |
155 | outputs = []
156 | for _, (pixel_values, questions, answers, image_paths, options) in tqdm(enumerate(dataloader)):
157 | pixel_values = pixel_values.to(torch.bfloat16).cuda()
158 | generation_config = dict(
159 | num_beams=args.num_beams,
160 | max_new_tokens=ds_collections[ds_name]['max_new_tokens'],
161 | min_new_tokens=ds_collections[ds_name]['min_new_tokens'],
162 | do_sample=True if args.temperature > 0 else False,
163 | temperature=args.temperature,
164 | )
165 | pred = model.chat(
166 | tokenizer=tokenizer,
167 | pixel_values=pixel_values,
168 | question=questions[0],
169 | generation_config=generation_config
170 | )
171 | preds = [post_process(pred, options[0])]
172 |
173 | for question, pred, answer, image_path in zip(questions, preds, answers, image_paths):
174 | outputs.append({
175 | 'question': question,
176 | 'answer': pred,
177 | 'gt_answers': answer,
178 | 'image_path': image_path
179 | })
180 |
181 | torch.distributed.barrier()
182 |
183 | world_size = torch.distributed.get_world_size()
184 | merged_outputs = [None for _ in range(world_size)]
185 | torch.distributed.all_gather_object(merged_outputs, json.dumps(outputs))
186 |
187 | merged_outputs = [json.loads(_) for _ in merged_outputs]
188 | merged_outputs = [_ for _ in itertools.chain.from_iterable(merged_outputs)]
189 |
190 | if torch.distributed.get_rank() == 0:
191 |
192 | print(f'Evaluating {ds_name} ...')
193 | time_prefix = time.strftime('%y%m%d%H%M%S', time.localtime())
194 | results_file = f'{ds_name}_{time_prefix}.jsonl'
195 | output_path = os.path.join(args.out_dir, results_file)
196 | with open(output_path, 'w') as f:
197 | for output in merged_outputs:
198 | f.write(json.dumps(output) + '\n')
199 | print('Results saved to {}'.format(output_path))
200 | cnt = 0
201 | for item in merged_outputs:
202 | if item['answer'] == item['gt_answers']:
203 | cnt += 1
204 | print(f'Acc@1: {cnt / len(merged_outputs)}')
205 |
206 |
207 | if __name__ == '__main__':
208 | parser = argparse.ArgumentParser()
209 | parser.add_argument('--checkpoint', type=str, default='')
210 | parser.add_argument('--datasets', type=str, default='sqa_test')
211 | parser.add_argument('--batch-size', type=int, default=1)
212 | parser.add_argument('--num-workers', type=int, default=1)
213 | parser.add_argument('--num-beams', type=int, default=5)
214 | parser.add_argument('--temperature', type=float, default=0.0)
215 | parser.add_argument('--out-dir', type=str, default='results')
216 | parser.add_argument('--seed', type=int, default=0)
217 | parser.add_argument('--dynamic', action='store_true')
218 | parser.add_argument('--max-num', type=int, default=6)
219 | parser.add_argument('--load-in-8bit', action='store_true')
220 | parser.add_argument('--auto', action='store_true')
221 | args = parser.parse_args()
222 |
223 | if not os.path.exists(args.out_dir):
224 | os.makedirs(args.out_dir)
225 |
226 | args.datasets = args.datasets.split(',')
227 | print('datasets:', args.datasets)
228 | assert args.batch_size == 1, 'Only batch size 1 is supported'
229 |
230 | torch.distributed.init_process_group(
231 | backend='nccl',
232 | world_size=int(os.getenv('WORLD_SIZE', '1')),
233 | rank=int(os.getenv('RANK', '0')),
234 | )
235 |
236 | torch.cuda.set_device(int(os.getenv('LOCAL_RANK', 0)))
237 |
238 | if args.auto:
239 | os.environ['CUDA_LAUNCH_BLOCKING'] = '1'
240 | kwargs = {'device_map': 'auto'} if args.auto else {}
241 | tokenizer = AutoTokenizer.from_pretrained(args.checkpoint, trust_remote_code=True, use_fast=False)
242 | model = AutoModel.from_pretrained(
243 | args.checkpoint, trust_remote_code=True, low_cpu_mem_usage=True, torch_dtype=torch.bfloat16,
244 | load_in_8bit=args.load_in_8bit, **kwargs).eval()
245 | if not args.load_in_8bit and not args.auto:
246 | model = model.cuda()
247 | image_size = model.config.force_image_size or model.config.vision_config.image_size
248 | use_thumbnail = model.config.use_thumbnail
249 |
250 | total_params = sum(p.numel() for p in model.parameters()) / 1e9
251 | if total_params > 20 or args.dynamic:
252 | args.num_beams = 1
253 | print(f'[test] total_params: {total_params}B, use num_beams: {args.num_beams}')
254 | else:
255 | print(f'[test] total_params: {total_params}B')
256 | print(f'[test] image_size: {image_size}')
257 | print(f'[test] template: {model.config.template}')
258 | print(f'[test] dynamic_image_size: {args.dynamic}')
259 | print(f'[test] use_thumbnail: {use_thumbnail}')
260 |
261 | evaluate_chat_model()
262 |
--------------------------------------------------------------------------------
/eval/seed/calculation.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import json
3 | import os
4 |
5 | parser = argparse.ArgumentParser()
6 | parser.add_argument('--image_result_file', type=str, default='')
7 | parser.add_argument('--anno_path', type=str, default='data/SEED/SEED-Bench.json')
8 |
9 | args = parser.parse_args()
10 | image_result_file = args.image_result_file
11 | anno_path = args.anno_path
12 |
13 | assert image_result_file.endswith('.jsonl')
14 |
15 |
16 | def is_integer_string(s):
17 | try:
18 | int(s)
19 | return True
20 | except ValueError:
21 | return False
22 |
23 |
24 | def filter_questions(data, task='all'):
25 | if task == 'image':
26 | return [q for q in data if 1 <= q['question_type_id'] <= 9]
27 | elif task == 'video':
28 | return [q for q in data if 10 <= q['question_type_id'] <= 12]
29 | elif task == 'all':
30 | return data
31 | elif is_integer_string(task):
32 | return [q for q in data if q['question_type_id'] == int(task)]
33 | else:
34 | raise ValueError(f'Invalid task: {task}')
35 |
36 |
37 | if __name__ == '__main__':
38 |
39 | qa_anno = json.load(open(anno_path, 'rb'))
40 | if 'questions' in qa_anno.keys():
41 | question_type = qa_anno['question_type']
42 | question_id_type = {v: k for k, v in question_type.items()}
43 | qa_anno = qa_anno['questions']
44 |
45 | qa_anno = filter_questions(qa_anno, 'all')
46 | print(f'length: {len(qa_anno)}')
47 |
48 | with open(image_result_file, 'r') as f:
49 |
50 | image_result = [json.loads(line) for line in f.readlines()]
51 |
52 | results = []
53 |
54 | results.extend(image_result)
55 |
56 | qa_id_anno = {}
57 | for item in qa_anno:
58 | question_id = str(item['question_id'])
59 | qa_id_anno[question_id] = item
60 |
61 | type_counts = {k: [] for k, v in question_id_type.items()}
62 |
63 | for item in results:
64 | pred, gt, question_id = item['prediction'], item['answer'], item['question_id']
65 | question_id = str(question_id)
66 | question_type = qa_id_anno[question_id]['question_type_id']
67 | data_type = qa_id_anno[question_id]['data_type']
68 | gt = qa_id_anno[question_id]['answer']
69 | if len(pred) != 1:
70 | pred = pred[0]
71 | if pred == gt:
72 | type_counts[question_type].append(1)
73 | else:
74 | type_counts[question_type].append(0)
75 |
76 | print('Accuracy for each data type:')
77 | total_count, image_count, video_count = 0, 0, 0
78 | total_correct, image_correct, video_correct = 0, 0, 0
79 | for data_type_id, result in type_counts.items():
80 | accuracy = sum(result) / len(result) * 100
81 | data_type = question_id_type[data_type_id]
82 | print(f'Data type {data_type}: {accuracy:.2f}%')
83 |
84 | total_count += len(result)
85 | total_correct += sum(result)
86 | if data_type_id >= 1 and data_type_id <= 9:
87 | image_count += len(result)
88 | image_correct += sum(result)
89 | else:
90 | video_count += len(result)
91 | video_correct += sum(result)
92 |
93 | total_accuracy = total_correct / total_count * 100
94 | image_accuracy = image_correct / image_count * 100
95 | video_accuracy = video_correct / video_count * 100
96 |
97 | print(f'Total accuracy: {total_accuracy:.2f}%')
98 | print(f'Image accuracy: {image_accuracy:.2f}%')
99 | print(f'Video accuracy: {video_accuracy:.2f}%')
100 |
--------------------------------------------------------------------------------
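
Both filter_questions and the final accuracy report in calculation.py treat SEED question_type_id 1-9 as image tasks and 10-12 as video tasks. A tiny sketch of that split, using made-up records rather than real SEED annotations:

# Hypothetical SEED-style records; the id ranges mirror filter_questions above.
def filter_questions(data, task='all'):
    if task == 'image':
        return [q for q in data if 1 <= q['question_type_id'] <= 9]
    if task == 'video':
        return [q for q in data if 10 <= q['question_type_id'] <= 12]
    return data

records = [{'question_id': 101, 'question_type_id': 2},   # image-type question
           {'question_id': 102, 'question_type_id': 11}]  # video-type question
print([q['question_id'] for q in filter_questions(records, 'image')])  # -> [101]
print([q['question_id'] for q in filter_questions(records, 'video')])  # -> [102]
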
/eval/seed/evaluate_seed.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import itertools
3 | import json
4 | import os
5 | import random
6 | import time
7 | from functools import partial
8 |
9 | import torch
10 | from utils.preprocess import build_transform, dynamic_preprocess
11 | from PIL import Image
12 | from torch.utils.data import Dataset
13 | from tqdm import tqdm
14 | from transformers import AutoTokenizer, AutoModel
15 |
16 | ds_collections = {
17 | 'SEEDv1': {
18 | 'root': 'data/SEED/',
19 | 'annotation': 'data/SEED/seed.jsonl',
20 | 'max_new_tokens': 100,
21 | 'min_new_tokens': 1,
22 | },
23 | }
24 |
25 |
26 | def collate_fn(batches, tokenizer):
27 | pixel_values = torch.cat([_['pixel_values'] for _ in batches], dim=0)
28 | questions = [_['question'] for _ in batches]
29 | answers = [_['answer'] for _ in batches]
30 | indexes = [_['index'] for _ in batches]
31 | return pixel_values, questions, answers, indexes
32 |
33 |
34 | class MultipleChoiceDataset(torch.utils.data.Dataset):
35 |
36 | def __init__(self, root, annotation, input_size=224, dynamic_image_size=False,
37 | use_thumbnail=False, max_num=6):
38 | f = open(annotation, 'r', encoding='utf-8')
39 | self.data = [json.loads(line) for line in f.readlines()]
40 | self.root = root
41 | self.input_size = input_size
42 | self.dynamic_image_size = dynamic_image_size
43 | self.use_thumbnail = use_thumbnail
44 | self.max_num = max_num
45 | self.transform = build_transform(input_size=input_size)
46 |
47 | def __len__(self):
48 | return len(self.data)
49 |
50 | def __getitem__(self, idx):
51 | data = self.data[idx]
52 | question = data['text']
53 | image_path = os.path.join(self.root, data['image'])
54 | image = Image.open(image_path).convert('RGB')
55 | if self.dynamic_image_size:
56 | images = dynamic_preprocess(image, image_size=self.input_size,
57 | use_thumbnail=self.use_thumbnail,
58 | max_num=self.max_num)
59 | else:
60 | images = [image]
61 | pixel_values = [self.transform(image) for image in images]
62 | pixel_values = torch.stack(pixel_values)
63 | answer = data['answer'] if 'answer' in data else None
64 | return {
65 | 'question': question,
66 | 'pixel_values': pixel_values,
67 | 'answer': answer,
68 | 'index': data['question_id'],
69 | }
70 |
71 |
72 | class InferenceSampler(torch.utils.data.sampler.Sampler):
73 |
74 | def __init__(self, size):
75 | self._size = int(size)
76 | assert size > 0
77 | self._rank = torch.distributed.get_rank()
78 | self._world_size = torch.distributed.get_world_size()
79 | self._local_indices = self._get_local_indices(size, self._world_size, self._rank)
80 |
81 | @staticmethod
82 | def _get_local_indices(total_size, world_size, rank):
83 | shard_size = total_size // world_size
84 | left = total_size % world_size
85 | shard_sizes = [shard_size + int(r < left) for r in range(world_size)]
86 |
87 | begin = sum(shard_sizes[:rank])
88 | end = min(sum(shard_sizes[:rank + 1]), total_size)
89 | return range(begin, end)
90 |
91 | def __iter__(self):
92 | yield from self._local_indices
93 |
94 | def __len__(self):
95 | return len(self._local_indices)
96 |
97 |
98 | def post_process(pred, option):
99 | pred = pred.strip()
100 | option_candidate = list(option.keys())
101 | if len(pred) == 1:
102 | return pred
103 | elif len(pred) != 1 and pred[0] in option_candidate:
104 | return pred[0]
105 | elif len(pred) != 1 and pred[0] not in option_candidate:
106 | for k, v in option.items():
107 | if v in pred:
108 | return k
109 |
110 | return pred
111 |
112 |
113 | def evaluate_chat_model():
114 | random.seed(args.seed)
115 |
116 | for ds_name in args.datasets:
117 | dataset = MultipleChoiceDataset(
118 | root=ds_collections[ds_name]['root'],
119 | annotation=ds_collections[ds_name]['annotation'],
120 | input_size=image_size,
121 | dynamic_image_size=args.dynamic,
122 | use_thumbnail=use_thumbnail,
123 | max_num=args.max_num
124 | )
125 | dataloader = torch.utils.data.DataLoader(
126 | dataset=dataset,
127 | sampler=InferenceSampler(len(dataset)),
128 | batch_size=args.batch_size,
129 | num_workers=args.num_workers,
130 | pin_memory=True,
131 | drop_last=False,
132 | collate_fn=partial(collate_fn, tokenizer=tokenizer),
133 | )
134 |
135 | outputs = []
136 | for _, (pixel_values, questions, answers, indexes) in enumerate(tqdm(dataloader)):
137 | pixel_values = pixel_values.to(torch.bfloat16).cuda()
138 | generation_config = dict(
139 | num_beams=args.num_beams,
140 | max_new_tokens=ds_collections[ds_name]['max_new_tokens'],
141 | min_new_tokens=ds_collections[ds_name]['min_new_tokens'],
142 | do_sample=True if args.temperature > 0 else False,
143 | temperature=args.temperature,
144 | )
145 | pred = model.chat(
146 | tokenizer=tokenizer,
147 | pixel_values=pixel_values,
148 | question=questions[0],
149 | generation_config=generation_config
150 | )
151 | preds = [pred]
152 |
153 | for question, pred, answer, index in zip(questions, preds, answers, indexes):
154 | outputs.append({
155 | 'question_id': index,
156 | 'question': question,
157 | 'prediction': pred,
158 | 'answer': answer,
159 | })
160 |
161 | torch.distributed.barrier()
162 |
163 | world_size = torch.distributed.get_world_size()
164 | merged_outputs = [None for _ in range(world_size)]
165 | torch.distributed.all_gather_object(merged_outputs, json.dumps(outputs))
166 |
167 | merged_outputs = [json.loads(_) for _ in merged_outputs]
168 | merged_outputs = [_ for _ in itertools.chain.from_iterable(merged_outputs)]
169 |
170 | if torch.distributed.get_rank() == 0:
171 |
172 | print(f'Evaluating {ds_name} ...')
173 | time_prefix = time.strftime('%y%m%d%H%M%S', time.localtime())
174 | results_file = f'{ds_name}_{time_prefix}.jsonl'
175 | output_path = os.path.join(args.out_dir, results_file)
176 | writer = open(output_path, 'w')
177 |
178 | results = []
179 | for item in merged_outputs:
180 | writer.write(json.dumps(item) + '\n')
181 | answer = item['answer']
182 | prediction = item['prediction']
183 | if prediction == answer:
184 | results.append(1)
185 | else:
186 | results.append(0)
187 | writer.close()
188 | print('Results saved to {}'.format(output_path))
189 | print(f'Acc@1: {sum(results) / len(results)}')
190 | cmd = f'python eval/seed/calculation.py --image_result_file {output_path}'
191 | os.system(cmd)
192 |
193 |
194 | if __name__ == '__main__':
195 | parser = argparse.ArgumentParser()
196 | parser.add_argument('--checkpoint', type=str, default='')
197 | parser.add_argument('--datasets', type=str, default='SEEDv1')
198 | parser.add_argument('--batch-size', type=int, default=1)
199 | parser.add_argument('--num-workers', type=int, default=1)
200 | parser.add_argument('--num-beams', type=int, default=5)
201 | parser.add_argument('--temperature', type=float, default=0.0)
202 | parser.add_argument('--out-dir', type=str, default='results')
203 | parser.add_argument('--seed', type=int, default=0)
204 | parser.add_argument('--dynamic', action='store_true')
205 | parser.add_argument('--max-num', type=int, default=6)
206 | parser.add_argument('--load-in-8bit', action='store_true')
207 | parser.add_argument('--auto', action='store_true')
208 | args = parser.parse_args()
209 |
210 | if not os.path.exists(args.out_dir):
211 | os.makedirs(args.out_dir)
212 |
213 | args.datasets = args.datasets.split(',')
214 | print('datasets:', args.datasets)
215 | assert args.batch_size == 1, 'Only batch size 1 is supported'
216 |
217 | torch.distributed.init_process_group(
218 | backend='nccl',
219 | world_size=int(os.getenv('WORLD_SIZE', '1')),
220 | rank=int(os.getenv('RANK', '0')),
221 | )
222 |
223 | torch.cuda.set_device(int(os.getenv('LOCAL_RANK', 0)))
224 |
225 | if args.auto:
226 | os.environ['CUDA_LAUNCH_BLOCKING'] = '1'
227 | kwargs = {'device_map': 'auto'} if args.auto else {}
228 | tokenizer = AutoTokenizer.from_pretrained(args.checkpoint, trust_remote_code=True, use_fast=False)
229 | model = AutoModel.from_pretrained(
230 | args.checkpoint, trust_remote_code=True, low_cpu_mem_usage=True, torch_dtype=torch.bfloat16,
231 | load_in_8bit=args.load_in_8bit, **kwargs).eval()
232 | if not args.load_in_8bit and not args.auto:
233 | model = model.cuda()
234 | image_size = model.config.force_image_size or model.config.vision_config.image_size
235 | use_thumbnail = model.config.use_thumbnail
236 |
237 | total_params = sum(p.numel() for p in model.parameters()) / 1e9
238 | if total_params > 20 or args.dynamic:
239 | args.num_beams = 1
240 | print(f'[test] total_params: {total_params}B, use num_beams: {args.num_beams}')
241 | else:
242 | print(f'[test] total_params: {total_params}B')
243 | print(f'[test] image_size: {image_size}')
244 | print(f'[test] template: {model.config.template}')
245 | print(f'[test] dynamic_image_size: {args.dynamic}')
246 | print(f'[test] use_thumbnail: {use_thumbnail}')
247 |
248 | evaluate_chat_model()
249 |
--------------------------------------------------------------------------------
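The InferenceSampler above splits the test set across distributed ranks, handing the remainder samples to the lowest ranks so that every index is covered exactly once and no rank gets more than one extra item. The snippet below is a standalone sketch of that sharding; the sizes used (10 samples, 3 ranks) are invented purely for illustration.

# Standalone sketch of InferenceSampler's index sharding.
# The sizes (10 samples, 3 ranks) are invented for illustration.
def get_local_indices(total_size, world_size, rank):
    shard_size = total_size // world_size
    left = total_size % world_size
    shard_sizes = [shard_size + int(r < left) for r in range(world_size)]
    begin = sum(shard_sizes[:rank])
    end = min(sum(shard_sizes[:rank + 1]), total_size)
    return range(begin, end)

if __name__ == '__main__':
    for rank in range(3):
        print(rank, list(get_local_indices(10, 3, rank)))
    # 0 [0, 1, 2, 3]
    # 1 [4, 5, 6]
    # 2 [7, 8, 9]
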
/eval/vqa/convert_gqa_for_eval.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import json
3 |
4 | parser = argparse.ArgumentParser()
5 | parser.add_argument('--src', type=str)
6 | parser.add_argument('--dst', type=str)
7 | args = parser.parse_args()
8 |
9 | all_answers = []
10 | data = json.load(open(args.src))
11 | for res in data:
12 | question_id = res['questionId']
13 | answer = res['answer'].rstrip('.').lower()
14 | all_answers.append({'questionId': question_id, 'prediction': answer})
15 |
16 | with open(args.dst, 'w') as f:
17 | json.dump(all_answers, f)
18 |
--------------------------------------------------------------------------------
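convert_gqa_for_eval.py only reshapes the predictions written by evaluate_vqa.py into the format the official GQA eval.py expects: each answer is lowercased, a trailing period is stripped, and the field is renamed to 'prediction'. A toy in-memory version of that transform follows; the question id and answer are invented.

import json

# Invented example record in the evaluate_vqa.py output format.
src_records = [{'questionId': '201307251', 'answer': 'The Cat.'}]
converted = [{'questionId': r['questionId'],
              'prediction': r['answer'].rstrip('.').lower()}
             for r in src_records]
print(json.dumps(converted))  # [{"questionId": "201307251", "prediction": "the cat"}]
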
/eval/vqa/evaluate_vqa.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import itertools
3 | import json
4 | import os
5 | import random
6 | import subprocess
7 | import time
8 | from functools import partial
9 | from typing import Optional
10 |
11 | import torch
12 | from utils.preprocess import build_transform, dynamic_preprocess
13 | from PIL import Image
14 | from textvqa_eval import TextVQAAccuracyEvaluator
15 | from tqdm import tqdm
16 | from transformers import AutoTokenizer, AutoModel
17 |
18 | ds_collections = {
19 | 'vqav2_val': {
20 | 'train': 'data/vqav2/vqav2_train.jsonl',
21 | 'test': 'data/vqav2/vqav2_val.jsonl',
22 | 'question': 'data/vqav2/v2_OpenEnded_mscoco_val2014_questions.json',
23 | 'annotation': 'data/vqav2/v2_mscoco_val2014_annotations.json',
24 | 'metric': 'vqa_score',
25 | 'max_new_tokens': 10,
26 | },
27 | 'vqav2_testdev': {
28 | 'train': 'data/vqav2/vqav2_train.jsonl',
29 | 'test': 'data/vqav2/vqav2_testdev.jsonl',
30 | 'metric': None,
31 | 'max_new_tokens': 10,
32 | },
33 | 'okvqa_val': {
34 | 'train': 'data/okvqa/okvqa_train.jsonl',
35 | 'test': 'data/okvqa/okvqa_val.jsonl',
36 | 'question': 'data/okvqa/OpenEnded_mscoco_val2014_questions.json',
37 | 'annotation': 'data/okvqa/mscoco_val2014_annotations.json',
38 | 'metric': 'vqa_score',
39 | 'max_new_tokens': 10,
40 | },
41 | 'textvqa_val': {
42 | 'train': 'data/textvqa/textvqa_train.jsonl',
43 | 'test': 'data/textvqa/textvqa_val.jsonl',
44 | 'question': 'data/textvqa/textvqa_val_questions.json',
45 | 'annotation': 'data/textvqa/textvqa_val_annotations.json',
46 | 'metric': 'vqa_score',
47 | 'max_new_tokens': 10,
48 | },
49 | 'textvqa_val_ocr': {
50 | 'train': 'data/textvqa/textvqa_train.jsonl',
51 | 'test': 'data/textvqa/textvqa_val_llava.jsonl',
52 | 'question': 'data/textvqa/textvqa_val_questions.json',
53 | 'annotation': 'data/textvqa/textvqa_val_annotations.json',
54 | 'metric': 'vqa_score',
55 | 'max_new_tokens': 10,
56 | },
57 | 'vizwiz_val': {
58 | 'train': 'data/vizwiz/vizwiz_train.jsonl',
59 | 'test': 'data/vizwiz/vizwiz_val.jsonl',
60 | 'question': 'data/vizwiz/vizwiz_val_questions.json',
61 | 'annotation': 'data/vizwiz/vizwiz_val_annotations.json',
62 | 'metric': 'vqa_score',
63 | 'max_new_tokens': 10,
64 | },
65 | 'vizwiz_test': {
66 | 'train': 'data/vizwiz/vizwiz_train.jsonl',
67 | 'test': 'data/vizwiz/vizwiz_test.jsonl',
68 | 'metric': None,
69 | 'max_new_tokens': 10,
70 | },
71 | 'docvqa_val': {
72 | 'train': 'data/docvqa/train.jsonl',
73 | 'test': 'data/docvqa/val.jsonl',
74 | 'annotation': 'data/docvqa/val/val_v1.0.json',
75 | 'metric': 'anls',
76 | 'max_new_tokens': 100,
77 | },
78 | 'docvqa_test': {
79 | 'train': 'data/docvqa/train.jsonl',
80 | 'test': 'data/docvqa/test.jsonl',
81 | 'metric': None,
82 | 'max_new_tokens': 100,
83 | },
84 | 'chartqa_test_human': {
85 | 'train': 'data/chartqa/train_human.jsonl',
86 | 'test': 'data/chartqa/test_human.jsonl',
87 | 'metric': 'relaxed_accuracy',
88 | 'max_new_tokens': 100,
89 | },
90 | 'chartqa_test_augmented': {
91 | 'train': 'data/chartqa/train_augmented.jsonl',
92 | 'test': 'data/chartqa/test_augmented.jsonl',
93 | 'metric': 'relaxed_accuracy',
94 | 'max_new_tokens': 100,
95 | },
96 | 'gqa_testdev': {
97 | 'train': 'data/gqa/train.jsonl',
98 | 'test': 'data/gqa/test_balanced.jsonl',
99 | 'metric': 'accuracy',
100 | 'max_new_tokens': 10,
101 | },
102 | 'gqa_testdev_llava': {
103 | 'train': 'data/gqa/train.jsonl',
104 | 'test': 'data/gqa/llava_gqa_testdev_balanced_qwen_format.jsonl',
105 | 'metric': 'accuracy',
106 | 'max_new_tokens': 10,
107 | },
108 | 'ocrvqa_val': {
109 | 'train': 'data/ocrvqa/ocrvqa_train.jsonl',
110 | 'test': 'data/ocrvqa/ocrvqa_val.jsonl',
111 | 'metric': 'accuracy',
112 | 'max_new_tokens': 100,
113 | },
114 | 'ocrvqa_test': {
115 | 'train': 'data/ocrvqa/ocrvqa_train.jsonl',
116 | 'test': 'data/ocrvqa/ocrvqa_test.jsonl',
117 | 'metric': 'accuracy',
118 | 'max_new_tokens': 100,
119 | },
120 | 'ai2diagram_test': {
121 | 'train': 'data/ai2diagram/train.jsonl',
122 | 'test': 'data/ai2diagram/test_vlmevalkit.jsonl',
123 | 'metric': 'accuracy',
124 | 'max_new_tokens': 10,
125 | },
126 | 'infographicsvqa_val': {
127 | 'train': 'data/infographicsvqa/train.jsonl',
128 | 'test': 'data/infographicsvqa/val.jsonl',
129 | 'annotation': 'data/infographicsvqa/infographicsVQA_val_v1.0_withQT.json',
130 | 'metric': 'anls',
131 | 'max_new_tokens': 100,
132 | },
133 | 'infographicsvqa_test': {
134 | 'train': 'data/infographicsvqa/train.jsonl',
135 | 'test': 'data/infographicsvqa/test.jsonl',
136 | 'annotation': 'data/infographicsvqa/infographicsVQA_test_v1.0.json',
137 | 'metric': None,
138 | 'max_new_tokens': 100,
139 | }
140 | }
141 |
142 |
143 | # https://github.com/google-research/pix2struct/blob/main/pix2struct/metrics.py#L81
144 | def relaxed_correctness(target: str,
145 | prediction: str,
146 | max_relative_change: float = 0.05) -> bool:
147 | """Calculates relaxed correctness.
148 |
149 | The correctness tolerates certain error ratio defined by max_relative_change.
150 | See https://arxiv.org/pdf/2203.10244.pdf, end of section 5.1:
151 | “Following Methani et al. (2020), we use a relaxed accuracy measure for the
152 | numeric answers to allow a minor inaccuracy that may result from the automatic
153 | data extraction process. We consider an answer to be correct if it is within
154 | 5% of the gold answer. For non-numeric answers, we still need an exact match
155 | to consider an answer to be correct.”
156 |
157 | Args:
158 | target: Target string.
159 | prediction: Predicted string.
160 | max_relative_change: Maximum relative change.
161 |
162 | Returns:
163 | Whether the prediction was correct given the specified tolerance.
164 | """
165 |
166 | def _to_float(text: str) -> Optional[float]:
167 | try:
168 | if text.endswith('%'):
169 | # Convert percentages to floats.
170 | return float(text.rstrip('%')) / 100.0
171 | else:
172 | return float(text)
173 | except ValueError:
174 | return None
175 |
176 | prediction_float = _to_float(prediction)
177 | target_float = _to_float(target)
178 | if prediction_float is not None and target_float:
179 | relative_change = abs(prediction_float -
180 | target_float) / abs(target_float)
181 | return relative_change <= max_relative_change
182 | else:
183 | return prediction.lower() == target.lower()
184 |
185 |
186 | def evaluate_relaxed_accuracy(entries):
187 | scores = []
188 | for elem in entries:
189 | if isinstance(elem['annotation'], str):
190 | elem['annotation'] = [elem['annotation']]
191 | score = max([
192 | relaxed_correctness(elem['answer'].strip(), ann)
193 | for ann in elem['annotation']
194 | ])
195 | scores.append(score)
196 | return sum(scores) / len(scores)
197 |
198 |
199 | def evaluate_exact_match_accuracy(entries):
200 | scores = []
201 | for elem in entries:
202 | if isinstance(elem['annotation'], str):
203 | elem['annotation'] = [elem['annotation']]
204 | score = max([
205 | (1.0 if
206 | (elem['answer'].strip().lower() == ann.strip().lower()) else 0.0)
207 | for ann in elem['annotation']
208 | ])
209 | scores.append(score)
210 | return sum(scores) / len(scores)
211 |
212 |
213 | def collate_fn(batches, tokenizer):
214 | pixel_values = torch.cat([_['pixel_values'] for _ in batches], dim=0)
215 | questions = [_['question'] for _ in batches]
216 | question_ids = [_['question_id'] for _ in batches]
217 | annotations = [_['annotation'] for _ in batches]
218 |
219 | return pixel_values, questions, question_ids, annotations
220 |
221 |
222 | class VQADataset(torch.utils.data.Dataset):
223 |
224 | def __init__(self, train, test, prompt, few_shot, input_size=224, dynamic_image_size=False,
225 | use_thumbnail=False, max_num=6):
226 | self.test = open(test).readlines()
227 | self.prompt = prompt
228 | self.input_size = input_size
229 | self.dynamic_image_size = dynamic_image_size
230 | self.use_thumbnail = use_thumbnail
231 | self.few_shot = few_shot
232 | self.max_num = max_num
233 | if few_shot > 0:
234 | self.train = open(train).readlines()
235 | self.transform = build_transform(input_size=input_size)
236 |
237 | def __len__(self):
238 | return len(self.test)
239 |
240 | def __getitem__(self, idx):
241 | data = json.loads(self.test[idx].strip())
242 | image, question, question_id, annotation = data['image'], data[
243 | 'question'], data['question_id'], data.get('answer', None)
244 |
245 | few_shot_prompt = ''
246 | if self.few_shot > 0:
247 | few_shot_samples = random.sample(self.train, self.few_shot)
248 | for sample in few_shot_samples:
249 | sample = json.loads(sample.strip())
250 | few_shot_prompt += self.prompt.format(
251 | sample['image'],
252 | sample['question']) + f" {sample['answer']}"
253 |
254 | image = Image.open(image).convert('RGB')
255 | if self.dynamic_image_size:
256 | images = dynamic_preprocess(image, image_size=self.input_size,
257 | use_thumbnail=self.use_thumbnail,
258 | max_num=self.max_num)
259 | else:
260 | images = [image]
261 | pixel_values = [self.transform(image) for image in images]
262 | pixel_values = torch.stack(pixel_values)
263 | if len(self.prompt) != 0:
264 | question = question + ' ' + self.prompt
265 | return {
266 | 'question_id': question_id,
267 | 'question': question,
268 | 'pixel_values': pixel_values,
269 | 'annotation': annotation
270 | }
271 |
272 |
273 | class InferenceSampler(torch.utils.data.sampler.Sampler):
274 |
275 | def __init__(self, size):
276 | self._size = int(size)
277 | assert size > 0
278 | self._rank = torch.distributed.get_rank()
279 | self._world_size = torch.distributed.get_world_size()
280 | self._local_indices = self._get_local_indices(size, self._world_size, self._rank)
281 |
282 | @staticmethod
283 | def _get_local_indices(total_size, world_size, rank):
284 | shard_size = total_size // world_size
285 | left = total_size % world_size
286 | shard_sizes = [shard_size + int(r < left) for r in range(world_size)]
287 |
288 | begin = sum(shard_sizes[:rank])
289 | end = min(sum(shard_sizes[:rank + 1]), total_size)
290 | return range(begin, end)
291 |
292 | def __iter__(self):
293 | yield from self._local_indices
294 |
295 | def __len__(self):
296 | return len(self._local_indices)
297 |
298 |
299 | def post_process(response):
300 | response = response.strip().split('.')[0].split(
301 | ',')[0].split('!')[0].lower()
302 | if 'is ' in response:
303 | response = response.split('is ')[1]
304 | if 'are ' in response:
305 | response = response.split('are ')[1]
306 | if 'a ' in response:
307 | response = response.split('a ')[1]
308 | if 'an ' in response:
309 | response = response.split('an ')[1]
310 | if 'the ' in response:
311 | response = response.split('the ')[1]
312 | if ' of' in response:
313 | response = response.split(' of')[0]
314 | response = response.strip()
315 | return response
316 |
317 |
318 | def evaluate_chat_model():
319 | base_prompt = 'Answer the question using a single word or phrase.'
320 | vizwiz_prompt = "When the provided information is insufficient, respond with 'Unanswerable'. "
321 | # infovqa_prompt = 'Answer the question directly.'
322 | infovqa_prompt = 'Answer the question using a single word or phrase.'
323 | ai2d_prompt = ''
324 | random.seed(args.seed)
325 | summaries = []
326 |
327 | for ds_name in args.datasets:
328 | if 'vizwiz' in ds_name:
329 | input_prompt = vizwiz_prompt + base_prompt
330 | elif 'ai2d' in ds_name:
331 | input_prompt = ai2d_prompt
332 | elif 'infographicsvqa' in ds_name:
333 | input_prompt = infovqa_prompt
334 | else:
335 | input_prompt = base_prompt
336 |
337 | dataset = VQADataset(
338 | train=ds_collections[ds_name]['train'],
339 | test=ds_collections[ds_name]['test'],
340 | prompt=input_prompt,
341 | few_shot=args.few_shot,
342 | input_size=image_size,
343 | dynamic_image_size=args.dynamic,
344 | use_thumbnail=use_thumbnail,
345 | max_num=args.max_num
346 | )
347 | dataloader = torch.utils.data.DataLoader(
348 | dataset=dataset,
349 | sampler=InferenceSampler(len(dataset)),
350 | batch_size=args.batch_size,
351 | num_workers=args.num_workers,
352 | pin_memory=True,
353 | drop_last=False,
354 | collate_fn=partial(collate_fn, tokenizer=tokenizer),
355 | )
356 |
357 | outputs = []
358 | for _, (pixel_values, questions, question_ids, annotations) in tqdm(enumerate(dataloader)):
359 | pixel_values = pixel_values.to(torch.bfloat16).cuda()
360 | generation_config = dict(
361 | num_beams=args.num_beams,
362 | max_new_tokens=ds_collections[ds_name]['max_new_tokens'],
363 | min_new_tokens=1,
364 | do_sample=True if args.temperature > 0 else False,
365 | temperature=args.temperature,
366 | )
367 | pred = model.chat(
368 | tokenizer=tokenizer,
369 | pixel_values=pixel_values,
370 | question=questions[0],
371 | generation_config=generation_config
372 | )
373 | answers = [pred]
374 |
375 | for question, question_id, answer, annotation in zip(questions, question_ids, answers, annotations):
376 | if ds_name in ['vqav2_val', 'vqav2_testdev', 'okvqa_val', 'textvqa_val',
377 | 'vizwiz_val', 'textvqa_val_ocr']:
378 | outputs.append({
379 | 'question': question,
380 | 'question_id': question_id,
381 | 'answer': answer,
382 | })
383 | elif ds_name in ['docvqa_val', 'infographicsvqa_val', 'gqa_testdev', 'ocrvqa_val',
384 | 'ocrvqa_test', 'gqa_testdev_llava', 'infographicsvqa_test',]:
385 | outputs.append({
386 | 'question': question,
387 | 'questionId': question_id,
388 | 'answer': answer,
389 | 'annotation': annotation,
390 | })
391 | elif ds_name in ['ai2diagram_test']:
392 | outputs.append({
393 | 'question': question,
394 | 'image': question_id,
395 | 'answer': answer,
396 | 'annotation': annotation,
397 | })
398 | elif ds_name in ['chartqa_test_human', 'chartqa_test_augmented']:
399 | outputs.append({
400 | 'question': question,
401 | 'answer': answer,
402 | 'annotation': annotation,
403 | })
404 | elif ds_name in ['docvqa_test']:
405 | outputs.append({
406 | 'questionId': question_id,
407 | 'answer': answer,
408 | })
409 | elif ds_name in ['vizwiz_test']:
410 | outputs.append({
411 | 'image': question_id.replace('data/vizwiz/test/', ''),
412 | 'answer': answer,
413 | })
414 | else:
415 | raise NotImplementedError
416 |
417 | torch.distributed.barrier()
418 |
419 | world_size = torch.distributed.get_world_size()
420 | merged_outputs = [None for _ in range(world_size)]
421 | torch.distributed.all_gather_object(merged_outputs, json.dumps(outputs))
422 |
423 | merged_outputs = [json.loads(_) for _ in merged_outputs]
424 | merged_outputs = [_ for _ in itertools.chain.from_iterable(merged_outputs)]
425 |
426 | if torch.distributed.get_rank() == 0:
427 | print(f'Evaluating {ds_name} ...')
428 | time_prefix = time.strftime('%y%m%d%H%M%S', time.localtime())
429 | results_file = f'{ds_name}_{time_prefix}.json'
430 | results_file = os.path.join(args.out_dir, results_file)
431 | json.dump(merged_outputs, open(results_file, 'w'))
432 | print('Results saved to {}'.format(results_file))
433 |
434 | if ds_collections[ds_name]['metric'] == 'vqa_score':
435 | evaluator = TextVQAAccuracyEvaluator()
436 | annotation = json.load(open(ds_collections[ds_name]['annotation'], 'r'))['annotations']
437 | question_id2answers = {}
438 | for item in annotation:
439 | question_id = item['question_id']
440 | answers = [answer['answer'] for answer in item['answers']]
441 | question_id2answers[question_id] = answers
442 | for item in merged_outputs:
443 | item['pred_answer'] = item['answer']
444 | item['gt_answers'] = question_id2answers[item['question_id']]
445 | accuracy = evaluator.eval_pred_list(merged_outputs)
446 | print(ds_name, accuracy)
447 | summaries.append([args.checkpoint, ds_name, accuracy])
448 |
449 | elif ds_collections[ds_name]['metric'] == 'anls':
450 | json.dump(merged_outputs,
451 | open(results_file, 'w'),
452 | ensure_ascii=False)
453 | print('python eval/vqa/infographicsvqa_eval.py -g ' +
454 | ds_collections[ds_name]['annotation'] + ' -s ' +
455 | results_file)
456 | os.system('python eval/vqa/infographicsvqa_eval.py -g ' +
457 | ds_collections[ds_name]['annotation'] + ' -s ' +
458 | results_file)
459 | elif ds_collections[ds_name]['metric'] == 'relaxed_accuracy':
460 | relaxed_accuracy = evaluate_relaxed_accuracy(merged_outputs)
461 | print(ds_name, {'relaxed_accuracy': relaxed_accuracy})
462 | summaries.append([ds_name, {'relaxed_accuracy': relaxed_accuracy}])
463 | elif ds_collections[ds_name]['metric'] == 'accuracy':
464 | if 'gqa' in ds_name:
465 | dst_file = './data/gqa/testdev_balanced_predictions.json'
466 | print('python eval/vqa/convert_gqa_for_eval.py --src ' +
467 | results_file + ' --dst ' + dst_file)
468 | python_path = 'python'
469 | os.system(python_path + ' eval/vqa/convert_gqa_for_eval.py --src ' +
470 | results_file + ' --dst ' + dst_file)
471 | command = f'cd ./data/gqa/ && {python_path} eval.py --tier testdev_balanced && cd ../../'
472 | print(command)
473 | accuracy = subprocess.check_output(command, shell=True, universal_newlines=True)
474 | else:
475 | accuracy = {'accuracy': evaluate_exact_match_accuracy(merged_outputs)}
476 | print(ds_name, accuracy)
477 | summaries.append([args.checkpoint, ds_name, accuracy])
478 |
479 | torch.distributed.barrier()
480 |
481 | out_path = '_'.join(args.checkpoint.split('/')[-2:])
482 | writer = open(os.path.join(args.out_dir, f'{out_path}.txt'), 'a')
483 | print(f"write results to file {os.path.join(args.out_dir, f'{out_path}.txt')}")
484 | for summary in summaries:
485 | print(summary)
486 | writer.write(f'{summary}\n')
487 | writer.close()
488 |
489 |
490 | if __name__ == '__main__':
491 |
492 | parser = argparse.ArgumentParser()
493 | parser.add_argument('--checkpoint', type=str, default='')
494 | parser.add_argument('--datasets', type=str,
495 | default='okvqa_val,textvqa_val,vizwiz_val,ai2diagram_test,gqa_testdev_llava')
496 | parser.add_argument('--batch-size', type=int, default=1)
497 | parser.add_argument('--num-workers', type=int, default=1)
498 | parser.add_argument('--num-beams', type=int, default=5)
499 | parser.add_argument('--temperature', type=float, default=0.0)
500 | parser.add_argument('--out-dir', type=str, default='results')
501 | parser.add_argument('--few-shot', type=int, default=0)
502 | parser.add_argument('--seed', type=int, default=0)
503 | parser.add_argument('--dynamic', action='store_true')
504 | parser.add_argument('--max-num', type=int, default=6)
505 | parser.add_argument('--load-in-8bit', action='store_true')
506 | parser.add_argument('--auto', action='store_true')
507 | args = parser.parse_args()
508 |
509 | if not os.path.exists(args.out_dir):
510 | os.makedirs(args.out_dir)
511 |
512 | args.datasets = args.datasets.split(',')
513 | print('datasets:', args.datasets)
514 | assert args.batch_size == 1, 'Only batch size 1 is supported'
515 |
516 | torch.distributed.init_process_group(
517 | backend='nccl',
518 | world_size=int(os.getenv('WORLD_SIZE', '1')),
519 | rank=int(os.getenv('RANK', '0')),
520 | )
521 |
522 | torch.cuda.set_device(int(os.getenv('LOCAL_RANK', 0)))
523 |
524 | if args.auto:
525 | os.environ['CUDA_LAUNCH_BLOCKING'] = '1'
526 | kwargs = {'device_map': 'auto'} if args.auto else {}
527 | tokenizer = AutoTokenizer.from_pretrained(args.checkpoint, trust_remote_code=True, use_fast=False)
528 | model = AutoModel.from_pretrained(
529 | args.checkpoint, trust_remote_code=True, low_cpu_mem_usage=True, torch_dtype=torch.bfloat16,
530 | load_in_8bit=args.load_in_8bit, **kwargs).eval()
531 | if not args.load_in_8bit and not args.auto:
532 | model = model.cuda()
533 | image_size = model.config.force_image_size or model.config.vision_config.image_size
534 | use_thumbnail = model.config.use_thumbnail
535 |
536 | total_params = sum(p.numel() for p in model.parameters()) / 1e9
537 | if total_params > 20 or args.dynamic:
538 | args.num_beams = 1
539 | print(f'[test] total_params: {total_params}B, use num_beams: {args.num_beams}')
540 | else:
541 | print(f'[test] total_params: {total_params}B')
542 | print(f'[test] image_size: {image_size}')
543 | print(f'[test] template: {model.config.template}')
544 | print(f'[test] dynamic_image_size: {args.dynamic}')
545 | print(f'[test] use_thumbnail: {use_thumbnail}')
546 |
547 | evaluate_chat_model()
548 |
--------------------------------------------------------------------------------
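For the ChartQA splits, evaluate_vqa.py scores answers with the relaxed-accuracy rule from pix2struct reproduced above: numeric answers may deviate from the gold value by up to 5%, while everything else needs an exact, case-insensitive string match. The condensed restatement below uses invented inputs to show where that threshold lands.

# Condensed restatement of relaxed_correctness; the example values are invented.
def to_float(text):
    try:
        return float(text.rstrip('%')) / 100.0 if text.endswith('%') else float(text)
    except ValueError:
        return None

def relaxed_correctness(target, prediction, max_relative_change=0.05):
    t, p = to_float(target), to_float(prediction)
    if p is not None and t:
        return abs(p - t) / abs(t) <= max_relative_change
    return prediction.lower() == target.lower()

print(relaxed_correctness('100', '104'))    # True: 4% relative error
print(relaxed_correctness('100', '106'))    # False: 6% relative error
print(relaxed_correctness('12%', '0.121'))  # True: '12%' is parsed as 0.12
print(relaxed_correctness('cat', 'Cat'))    # True: exact match, case-insensitive
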
/eval/vqa/infographicsvqa_eval.py:
--------------------------------------------------------------------------------
1 | # This file can be downloaded from: https://www.docvqa.org/datasets/infographicvqa and https://rrc.cvc.uab.es/?ch=17&com=introduction
2 |
3 | import argparse
4 | import json
5 | import os
6 |
7 | question_ids_to_exclude = []
8 |
9 | # answer_types = {'image span': 'Image-Span', 'question span': 'Question-Span', 'multiple spans': 'Multi-Span', 'non span': 'None span', 'list': 'List'}
10 | answer_types = {'image span': 'Image-Span', 'question span': 'Question-Span', 'multiple spans': 'Multi-Span',
11 | 'non span': 'None span'}
12 | evidence_types = {'table/list': 'Table/list', 'textual': 'Text', 'photo/pciture/visual_objects': 'Visual/Layout',
13 | 'figure': 'Figure', 'map': 'Map'}
14 | reasoning_requirements = {'comparison': 'Sorting', 'arithmetic': 'Arithmetic', 'counting': 'Counting'}
15 |
16 |
17 | def save_json(file_path, data):
18 | with open(file_path, 'w+') as json_file:
19 | json.dump(data, json_file)
20 |
21 |
22 | def levenshtein_distance(s1, s2):
23 | if len(s1) > len(s2):
24 | s1, s2 = s2, s1
25 |
26 | distances = range(len(s1) + 1)
27 | for i2, c2 in enumerate(s2):
28 | distances_ = [i2 + 1]
29 | for i1, c1 in enumerate(s1):
30 | if c1 == c2:
31 | distances_.append(distances[i1])
32 | else:
33 | distances_.append(1 + min((distances[i1], distances[i1 + 1], distances_[-1])))
34 | distances = distances_
35 | return distances[-1]
36 |
37 |
38 | def validate_data(gtFilePath, submFilePath):
39 | """
40 |     Method validate_data: validates that all files in the results folder are correct (have the correct names and contents).
41 | Validates also that there are no missing files in the folder.
42 | If some error detected, the method raises the error
43 | """
44 |
45 | gtJson = json.load(open(gtFilePath, 'rb'))
46 | submJson = json.load(open(submFilePath, 'rb'))
47 |
48 | if 'data' not in gtJson:
49 | raise Exception('The GT file is not valid (no data key)')
50 |
51 | if 'dataset_name' not in gtJson:
52 | raise Exception('The GT file is not valid (no dataset_name key)')
53 |
54 | if isinstance(submJson, list) is False:
55 | raise Exception('The Det file is not valid (root item must be an array)')
56 |
57 | if len(submJson) != len(gtJson['data']):
58 | raise Exception('The Det file is not valid (invalid number of answers. Expected:' + str(
59 | len(gtJson['data'])) + ' Found:' + str(len(submJson)) + ')')
60 |
61 | gtQuestions = sorted([r['questionId'] for r in gtJson['data']])
62 | res_id_to_index = {int(r['questionId']): ix for ix, r in enumerate(submJson)}
63 | detQuestions = sorted([r['questionId'] for r in submJson])
64 |
65 | if ((gtQuestions == detQuestions) is False):
66 |         raise Exception('The Det file is not valid. Question IDs must match GT')
67 |
68 | for gtObject in gtJson['data']:
69 |
70 | try:
71 | q_id = int(gtObject['questionId'])
72 | res_ix = res_id_to_index[q_id]
73 |
74 | except:
75 | raise Exception('The Det file is not valid. Question ' + str(gtObject['questionId']) + ' not present')
76 |
77 | else:
78 | detObject = submJson[res_ix]
79 |
80 | # if detObject['questionId'] != gtObject['questionId'] :
81 | # raise Exception("Answer #" + str(i) + " not valid (invalid question ID. Expected:" + str(gtObject['questionId']) + "Found:" + detObject['questionId'] + ")")
82 |
83 | if 'answer' not in detObject:
84 | raise Exception('Question ' + str(gtObject['questionId']) + ' not valid (no answer key)')
85 |
86 | if isinstance(detObject['answer'], list) is True:
87 | raise Exception(
88 | 'Question ' + str(gtObject['questionId']) + ' not valid (answer key has to be a single string)')
89 |
90 |
91 | def evaluate_method(gtFilePath, submFilePath, evaluationParams):
92 | """
93 | Method evaluate_method: evaluate method and returns the results
94 | Results. Dictionary with the following values:
95 | - method (required) Global method metrics. Ex: { 'Precision':0.8,'Recall':0.9 }
96 |     - samples (optional) Per sample metrics. Ex: {'sample1' : { 'Precision':0.8,'Recall':0.9 } , 'sample2' : { 'Precision':0.8,'Recall':0.9 } }
97 | """
98 |
99 | show_scores_per_answer_type = evaluationParams.answer_types
100 |
101 | gtJson = json.load(open(gtFilePath, 'rb'))
102 | submJson = json.load(open(submFilePath, 'rb'))
103 |
104 | res_id_to_index = {int(r['questionId']): ix for ix, r in enumerate(submJson)}
105 |
106 | perSampleMetrics = {}
107 |
108 | totalScore = 0
109 | row = 0
110 |
111 | if show_scores_per_answer_type:
112 | answerTypeTotalScore = {x: 0 for x in answer_types.keys()}
113 | answerTypeNumQuestions = {x: 0 for x in answer_types.keys()}
114 |
115 | evidenceTypeTotalScore = {x: 0 for x in evidence_types.keys()}
116 | evidenceTypeNumQuestions = {x: 0 for x in evidence_types.keys()}
117 |
118 | reasoningTypeTotalScore = {x: 0 for x in reasoning_requirements.keys()}
119 | reasoningTypeNumQuestions = {x: 0 for x in reasoning_requirements.keys()}
120 |
121 | for gtObject in gtJson['data']:
122 |
123 | q_id = int(gtObject['questionId'])
124 | res_ix = res_id_to_index[q_id]
125 | detObject = submJson[res_ix]
126 |
127 | if q_id in question_ids_to_exclude:
128 | question_result = 0
129 | info = 'Question EXCLUDED from the result'
130 |
131 | else:
132 | info = ''
133 | values = []
134 | for answer in gtObject['answers']:
135 | # preprocess both the answers - gt and prediction
136 | gt_answer = ' '.join(answer.strip().lower().split())
137 | det_answer = ' '.join(detObject['answer'].strip().lower().split())
138 |
139 | # dist = levenshtein_distance(answer.lower(), detObject['answer'].lower())
140 | dist = levenshtein_distance(gt_answer, det_answer)
141 | length = max(len(answer.upper()), len(detObject['answer'].upper()))
142 | values.append(0.0 if length == 0 else float(dist) / float(length))
143 |
144 | question_result = 1 - min(values)
145 |
146 | if (question_result < evaluationParams.anls_threshold):
147 | question_result = 0
148 |
149 | totalScore += question_result
150 |
151 | if show_scores_per_answer_type:
152 | for q_type in gtObject['answer_type']:
153 | answerTypeTotalScore[q_type] += question_result
154 | answerTypeNumQuestions[q_type] += 1
155 |
156 | for q_type in gtObject['evidence']:
157 | evidenceTypeTotalScore[q_type] += question_result
158 | evidenceTypeNumQuestions[q_type] += 1
159 |
160 | for q_type in gtObject['operation/reasoning']:
161 | reasoningTypeTotalScore[q_type] += question_result
162 | reasoningTypeNumQuestions[q_type] += 1
163 |
164 | perSampleMetrics[str(gtObject['questionId'])] = {
165 | 'score': question_result,
166 | 'question': gtObject['question'],
167 | 'gt': gtObject['answers'],
168 | 'det': detObject['answer'],
169 | 'info': info
170 | }
171 | row = row + 1
172 |
173 | methodMetrics = {
174 | 'score': 0 if len(gtJson['data']) == 0 else totalScore / (len(gtJson['data']) - len(question_ids_to_exclude))
175 | }
176 |
177 | answer_types_scores = {}
178 | evidence_types_scores = {}
179 | operation_types_scores = {}
180 |
181 | if show_scores_per_answer_type:
182 | for a_type, ref in answer_types.items():
183 | answer_types_scores[ref] = 0 if len(gtJson['data']) == 0 else answerTypeTotalScore[a_type] / (
184 | answerTypeNumQuestions[a_type])
185 |
186 | for e_type, ref in evidence_types.items():
187 | evidence_types_scores[ref] = 0 if len(gtJson['data']) == 0 else evidenceTypeTotalScore[e_type] / (
188 | evidenceTypeNumQuestions[e_type])
189 |
190 | for r_type, ref in reasoning_requirements.items():
191 | operation_types_scores[ref] = 0 if len(gtJson['data']) == 0 else reasoningTypeTotalScore[r_type] / (
192 | reasoningTypeNumQuestions[r_type])
193 |
194 | resDict = {
195 | 'result': methodMetrics,
196 | 'scores_by_types': {'answer_types': answer_types_scores, 'evidence_types': evidence_types_scores,
197 | 'operation_types': operation_types_scores},
198 | 'per_sample_result': perSampleMetrics
199 | }
200 |
201 | return resDict
202 |
203 |
204 | def display_results(results, show_answer_types):
205 | print('\nOverall ANLS: {:2.4f}'.format(results['result']['score']))
206 |
207 | if show_answer_types:
208 | print('\nAnswer types:')
209 | for a_type in answer_types.values():
210 | print('\t{:12s} {:2.4f}'.format(a_type, results['scores_by_types']['answer_types'][a_type]))
211 |
212 | print('\nEvidence types:')
213 | for e_type in evidence_types.values():
214 | print('\t{:12s} {:2.4f}'.format(e_type, results['scores_by_types']['evidence_types'][e_type]))
215 |
216 | print('\nOperation required:')
217 | for r_type in reasoning_requirements.values():
218 | print('\t{:12s} {:2.4f}'.format(r_type, results['scores_by_types']['operation_types'][r_type]))
219 |
220 |
221 | if __name__ == '__main__':
222 |     parser = argparse.ArgumentParser(description='InfographicVQA evaluation script.')
223 |
224 | parser.add_argument('-g', '--ground_truth', type=str, help='Path of the Ground Truth file.', required=True)
225 | parser.add_argument('-s', '--submission_file', type=str, help="Path of your method's results file.", required=True)
226 |
227 | parser.add_argument('-t', '--anls_threshold', type=float, default=0.5,
228 | help='ANLS threshold to use (See Scene-Text VQA paper for more info.).', required=False)
229 | parser.add_argument('-a', '--answer_types', type=bool, default=False,
230 | help='Score break down by answer types (special gt file required).', required=False)
231 | parser.add_argument('-o', '--output', type=str,
232 | help="Path to a directory where to copy the file 'results.json' that contains per-sample results.",
233 | required=False)
234 |
235 | args = parser.parse_args()
236 |
237 | # Validate the format of ground truth and submission files.
238 | validate_data(args.ground_truth, args.submission_file)
239 |
240 | # Evaluate method
241 | results = evaluate_method(args.ground_truth, args.submission_file, args)
242 |
243 | display_results(results, args.answer_types)
244 |
245 | if args.output:
246 | output_dir = args.output
247 |
248 | if not os.path.exists(output_dir):
249 | os.makedirs(output_dir)
250 |
251 | resultsOutputname = os.path.join(output_dir, 'results.json')
252 | save_json(resultsOutputname, results)
253 |
254 |         print('All results, including per-sample results, have been correctly saved!')
255 |
--------------------------------------------------------------------------------
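The DocVQA/InfographicVQA script above computes ANLS: one minus the normalized Levenshtein distance to the closest ground-truth answer, zeroed out below the --anls_threshold (0.5 by default). The minimal sketch below shows the per-answer score on invented strings; unlike the full script it also normalizes whitespace before measuring the length.

def levenshtein(s1, s2):
    # Same dynamic-programming edit distance as levenshtein_distance above.
    if len(s1) > len(s2):
        s1, s2 = s2, s1
    distances = range(len(s1) + 1)
    for i2, c2 in enumerate(s2):
        row = [i2 + 1]
        for i1, c1 in enumerate(s1):
            if c1 == c2:
                row.append(distances[i1])
            else:
                row.append(1 + min(distances[i1], distances[i1 + 1], row[-1]))
        distances = row
    return distances[-1]

def anls(gt, pred, threshold=0.5):
    gt = ' '.join(gt.lower().split())
    pred = ' '.join(pred.lower().split())
    length = max(len(gt), len(pred))
    score = 1 - levenshtein(gt, pred) / length if length else 0.0
    return score if score >= threshold else 0.0

print(anls('pie chart', 'pie chart'))  # 1.0
print(anls('pie chart', 'pi chart'))   # ~0.89 (one deletion over 9 characters)
print(anls('pie chart', 'table'))      # 0.0 (similarity falls below the threshold)
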
/eval/vqa/textvqa_eval.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | # copied from https://github.com/haotian-liu/LLaVA/blob/main/llava/eval/m4c_evaluator.py
3 | import re
4 |
5 | from tqdm import tqdm
6 |
7 |
8 | class EvalAIAnswerProcessor:
9 | """
10 | Processes an answer similar to Eval AI
11 | copied from
12 | https://github.com/facebookresearch/mmf/blob/c46b3b3391275b4181567db80943473a89ab98ab/pythia/tasks/processors.py#L897
13 | """
14 |
15 | CONTRACTIONS = {
16 | 'aint': "ain't",
17 | 'arent': "aren't",
18 | 'cant': "can't",
19 | 'couldve': "could've",
20 | 'couldnt': "couldn't",
21 | "couldn'tve": "couldn't've",
22 | "couldnt've": "couldn't've",
23 | 'didnt': "didn't",
24 | 'doesnt': "doesn't",
25 | 'dont': "don't",
26 | 'hadnt': "hadn't",
27 | "hadnt've": "hadn't've",
28 | "hadn'tve": "hadn't've",
29 | 'hasnt': "hasn't",
30 | 'havent': "haven't",
31 | 'hed': "he'd",
32 | "hed've": "he'd've",
33 | "he'dve": "he'd've",
34 | 'hes': "he's",
35 | 'howd': "how'd",
36 | 'howll': "how'll",
37 | 'hows': "how's",
38 | "Id've": "I'd've",
39 | "I'dve": "I'd've",
40 | 'Im': "I'm",
41 | 'Ive': "I've",
42 | 'isnt': "isn't",
43 | 'itd': "it'd",
44 | "itd've": "it'd've",
45 | "it'dve": "it'd've",
46 | 'itll': "it'll",
47 | "let's": "let's",
48 | 'maam': "ma'am",
49 | 'mightnt': "mightn't",
50 | "mightnt've": "mightn't've",
51 | "mightn'tve": "mightn't've",
52 | 'mightve': "might've",
53 | 'mustnt': "mustn't",
54 | 'mustve': "must've",
55 | 'neednt': "needn't",
56 | 'notve': "not've",
57 | 'oclock': "o'clock",
58 | 'oughtnt': "oughtn't",
59 | "ow's'at": "'ow's'at",
60 | "'ows'at": "'ow's'at",
61 | "'ow'sat": "'ow's'at",
62 | 'shant': "shan't",
63 | "shed've": "she'd've",
64 | "she'dve": "she'd've",
65 | "she's": "she's",
66 | 'shouldve': "should've",
67 | 'shouldnt': "shouldn't",
68 | "shouldnt've": "shouldn't've",
69 | "shouldn'tve": "shouldn't've",
70 |         'somebodyd': "somebody'd",
71 | "somebodyd've": "somebody'd've",
72 | "somebody'dve": "somebody'd've",
73 | 'somebodyll': "somebody'll",
74 | 'somebodys': "somebody's",
75 | 'someoned': "someone'd",
76 | "someoned've": "someone'd've",
77 | "someone'dve": "someone'd've",
78 | 'someonell': "someone'll",
79 | 'someones': "someone's",
80 | 'somethingd': "something'd",
81 | "somethingd've": "something'd've",
82 | "something'dve": "something'd've",
83 | 'somethingll': "something'll",
84 | 'thats': "that's",
85 | 'thered': "there'd",
86 | "thered've": "there'd've",
87 | "there'dve": "there'd've",
88 | 'therere': "there're",
89 | 'theres': "there's",
90 | 'theyd': "they'd",
91 | "theyd've": "they'd've",
92 | "they'dve": "they'd've",
93 | 'theyll': "they'll",
94 | 'theyre': "they're",
95 | 'theyve': "they've",
96 | 'twas': "'twas",
97 | 'wasnt': "wasn't",
98 | "wed've": "we'd've",
99 | "we'dve": "we'd've",
100 | 'weve': "we've",
101 | 'werent': "weren't",
102 | 'whatll': "what'll",
103 | 'whatre': "what're",
104 | 'whats': "what's",
105 | 'whatve': "what've",
106 | 'whens': "when's",
107 | 'whered': "where'd",
108 | 'wheres': "where's",
109 | 'whereve': "where've",
110 | 'whod': "who'd",
111 | "whod've": "who'd've",
112 | "who'dve": "who'd've",
113 | 'wholl': "who'll",
114 | 'whos': "who's",
115 | 'whove': "who've",
116 | 'whyll': "why'll",
117 | 'whyre': "why're",
118 | 'whys': "why's",
119 | 'wont': "won't",
120 | 'wouldve': "would've",
121 | 'wouldnt': "wouldn't",
122 | "wouldnt've": "wouldn't've",
123 | "wouldn'tve": "wouldn't've",
124 | 'yall': "y'all",
125 | "yall'll": "y'all'll",
126 | "y'allll": "y'all'll",
127 | "yall'd've": "y'all'd've",
128 | "y'alld've": "y'all'd've",
129 | "y'all'dve": "y'all'd've",
130 | 'youd': "you'd",
131 | "youd've": "you'd've",
132 | "you'dve": "you'd've",
133 | 'youll': "you'll",
134 | 'youre': "you're",
135 | 'youve': "you've",
136 | }
137 |
138 | NUMBER_MAP = {
139 | 'none': '0',
140 | 'zero': '0',
141 | 'one': '1',
142 | 'two': '2',
143 | 'three': '3',
144 | 'four': '4',
145 | 'five': '5',
146 | 'six': '6',
147 | 'seven': '7',
148 | 'eight': '8',
149 | 'nine': '9',
150 | 'ten': '10',
151 | }
152 | ARTICLES = ['a', 'an', 'the']
153 | PERIOD_STRIP = re.compile(r'(?!<=\d)(\.)(?!\d)')
154 | COMMA_STRIP = re.compile(r'(?<=\d)(\,)+(?=\d)')
155 | PUNCTUATIONS = [
156 | ';',
157 | r'/',
158 | '[',
159 | ']',
160 | '"',
161 | '{',
162 | '}',
163 | '(',
164 | ')',
165 | '=',
166 | '+',
167 | '\\',
168 | '_',
169 | '-',
170 | '>',
171 | '<',
172 | '@',
173 | '`',
174 | ',',
175 | '?',
176 | '!',
177 | ]
178 |
179 | def __init__(self, *args, **kwargs):
180 | pass
181 |
182 | def word_tokenize(self, word):
183 | word = word.lower()
184 | word = word.replace(',', '').replace('?', '').replace("'s", " 's")
185 | return word.strip()
186 |
187 | def process_punctuation(self, in_text):
188 | out_text = in_text
189 | for p in self.PUNCTUATIONS:
190 | if (p + ' ' in in_text or ' ' + p in in_text) or (
191 | re.search(self.COMMA_STRIP, in_text) is not None
192 | ):
193 | out_text = out_text.replace(p, '')
194 | else:
195 | out_text = out_text.replace(p, ' ')
196 | out_text = self.PERIOD_STRIP.sub('', out_text, re.UNICODE)
197 | return out_text
198 |
199 | def process_digit_article(self, in_text):
200 | out_text = []
201 | temp_text = in_text.lower().split()
202 | for word in temp_text:
203 | word = self.NUMBER_MAP.setdefault(word, word)
204 | if word not in self.ARTICLES:
205 | out_text.append(word)
206 | else:
207 | pass
208 | for word_id, word in enumerate(out_text):
209 | if word in self.CONTRACTIONS:
210 | out_text[word_id] = self.CONTRACTIONS[word]
211 | out_text = ' '.join(out_text)
212 | return out_text
213 |
214 | def __call__(self, item):
215 | item = self.word_tokenize(item)
216 | item = item.replace('\n', ' ').replace('\t', ' ').strip()
217 | item = self.process_punctuation(item)
218 | item = self.process_digit_article(item)
219 | return item
220 |
221 |
222 | class TextVQAAccuracyEvaluator:
223 | def __init__(self):
224 | self.answer_processor = EvalAIAnswerProcessor()
225 |
226 | def _compute_answer_scores(self, raw_answers):
227 | """
228 | compute the accuracy (soft score) of human answers
229 | """
230 | answers = [self.answer_processor(a) for a in raw_answers]
231 | assert len(answers) == 10
232 | gt_answers = list(enumerate(answers))
233 | unique_answers = set(answers)
234 | unique_answer_scores = {}
235 |
236 | for unique_answer in unique_answers:
237 | accs = []
238 | for gt_answer in gt_answers:
239 | other_answers = [item for item in gt_answers if item != gt_answer]
240 | matching_answers = [
241 | item for item in other_answers if item[1] == unique_answer
242 | ]
243 | acc = min(1, float(len(matching_answers)) / 3)
244 | accs.append(acc)
245 | unique_answer_scores[unique_answer] = sum(accs) / len(accs)
246 |
247 | return unique_answer_scores
248 |
249 | def eval_pred_list(self, pred_list):
250 | pred_scores = []
251 | for entry in tqdm(pred_list):
252 | pred_answer = self.answer_processor(entry['pred_answer'])
253 | unique_answer_scores = self._compute_answer_scores(entry['gt_answers'])
254 | score = unique_answer_scores.get(pred_answer, 0.0)
255 | pred_scores.append(score)
256 |
257 | accuracy = sum(pred_scores) / len(pred_scores)
258 | return accuracy
259 |
260 |
261 | class STVQAAccuracyEvaluator:
262 | def __init__(self):
263 | self.answer_processor = EvalAIAnswerProcessor()
264 |
265 | def eval_pred_list(self, pred_list):
266 | pred_scores = []
267 | for entry in pred_list:
268 | pred_answer = self.answer_processor(entry['pred_answer'])
269 | gts = [self.answer_processor(a) for a in entry['gt_answers']]
270 | score = 1.0 if pred_answer in gts else 0.0
271 | pred_scores.append(score)
272 |
273 | accuracy = sum(pred_scores) / len(pred_scores)
274 | return accuracy
275 |
276 |
277 | class STVQAANLSEvaluator:
278 | def __init__(self):
279 | import editdistance # install with `pip install editdistance`
280 |
281 | self.get_edit_distance = editdistance.eval
282 |
283 | def get_anls(self, s1, s2):
284 | s1 = s1.lower().strip()
285 | s2 = s2.lower().strip()
286 | iou = 1 - self.get_edit_distance(s1, s2) / max(len(s1), len(s2))
287 | anls = iou if iou >= 0.5 else 0.0
288 | return anls
289 |
290 | def eval_pred_list(self, pred_list):
291 | pred_scores = []
292 | for entry in pred_list:
293 | anls = max(
294 | self.get_anls(entry['pred_answer'], gt) for gt in entry['gt_answers']
295 | )
296 | pred_scores.append(anls)
297 |
298 | accuracy = sum(pred_scores) / len(pred_scores)
299 | return accuracy
300 |
301 |
302 | class TextCapsBleu4Evaluator:
303 | def __init__(self):
304 | # The following script requires Java 1.8.0 and pycocotools installed.
305 | # The pycocoevalcap can be installed with pip as
306 | # pip install git+https://github.com/ronghanghu/coco-caption.git@python23
307 | # Original pycocoevalcap code is at https://github.com/tylin/coco-caption
308 | # but has no python3 support yet.
309 | try:
310 | from pycocoevalcap.bleu.bleu import Bleu
311 | from pycocoevalcap.tokenizer.ptbtokenizer import PTBTokenizer
312 | except ModuleNotFoundError:
313 | print(
314 | 'Please install pycocoevalcap module using '
315 | 'pip install git+https://github.com/ronghanghu/coco-caption.git@python23' # noqa
316 | )
317 | raise
318 |
319 | self.tokenizer = PTBTokenizer()
320 | self.scorer = Bleu(4)
321 |
322 | def eval_pred_list(self, pred_list):
323 | # Create reference and hypotheses captions.
324 | gts = {}
325 | res = {}
326 | for idx, entry in enumerate(pred_list):
327 | gts[idx] = [{'caption': a} for a in entry['gt_answers']]
328 | res[idx] = [{'caption': entry['pred_answer']}]
329 |
330 | gts = self.tokenizer.tokenize(gts)
331 | res = self.tokenizer.tokenize(res)
332 | score, _ = self.scorer.compute_score(gts, res)
333 |
334 | bleu4 = score[3] # score is (Bleu-1, Bleu-2, Bleu-3, Bleu-4)
335 | return bleu4
336 |
--------------------------------------------------------------------------------
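TextVQAAccuracyEvaluator above applies the VQAv2-style soft accuracy: each of the ten human answers is held out in turn, the prediction scores min(1, n / 3) where n is how many of the remaining nine answers match it, and the ten per-round scores are averaged. The sketch below illustrates this on an invented answer list and assumes all strings have already been normalized by EvalAIAnswerProcessor.

def soft_score(pred, human_answers):
    # VQAv2-style leave-one-out soft accuracy over 10 human answers.
    assert len(human_answers) == 10
    scores = []
    for i in range(10):
        others = human_answers[:i] + human_answers[i + 1:]
        matches = sum(a == pred for a in others)
        scores.append(min(1.0, matches / 3))
    return sum(scores) / 10

answers = ['cat'] * 3 + ['dog'] * 7
print(soft_score('cat', answers))  # 0.9
print(soft_score('dog', answers))  # 1.0
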
/evaluate.sh:
--------------------------------------------------------------------------------
1 | set -x
2 |
3 | CHECKPOINT=${1}
4 | DATASET=${2}
5 | # CHECKPOINT="$(pwd)/${CHECKPOINT}"
6 | CHECKPOINT=${CHECKPOINT}
7 | export PYTHONPATH="$(pwd):${PYTHONPATH}"
8 | echo "CHECKPOINT: ${CHECKPOINT}"
9 |
10 | MASTER_PORT=${MASTER_PORT:-63669}
11 | PORT=${PORT:-63665}
12 | GPUS=${GPUS:-8}
13 | GPUS_PER_NODE=${GPUS_PER_NODE:-8}
14 | NODES=$((GPUS / GPUS_PER_NODE))
15 | export MASTER_PORT=${MASTER_PORT}
16 | export PORT=${PORT}
17 |
18 | # Save original arguments
19 | ARGS=("$@")
20 |
21 | # Parse options
22 | while [[ $# -gt 0 ]]; do
23 | case "$1" in
24 | --auto)
25 | GPUS=1
26 | shift
27 | ;;
28 | *)
29 | shift
30 | ;;
31 | esac
32 | done
33 | echo "GPUS: ${GPUS}"
34 |
35 | # Check whether CHECKPOINT ends with a slash; if so, strip it
36 | if [[ "${CHECKPOINT}" == */ ]]; then
37 | CHECKPOINT="${CHECKPOINT%/}"
38 | fi
39 |
40 | # dir to save result files
41 | mkdir -p "$CHECKPOINT/eval"
42 |
43 | if [ ${DATASET} == "mme" ]; then
44 | cd eval/mme/
45 | DIRNAME=`basename ${CHECKPOINT}`
46 | python eval.py --checkpoint ${CHECKPOINT} "${ARGS[@]:2}"
47 | python calculation.py --results_dir ${CHECKPOINT}/mme 2>&1 | tee -a ${CHECKPOINT}/mme/result_final.txt
48 | cd ../../
49 | fi
50 |
51 | if [ ${DATASET} == "caption" ]; then
52 | torchrun \
53 | --nnodes=1 \
54 | --node_rank=0 \
55 | --master_addr=127.0.0.1 \
56 | --nproc_per_node=${GPUS} \
57 | --master_port=${MASTER_PORT} \
58 | eval/caption/evaluate_caption.py --checkpoint ${CHECKPOINT} --out-dir ${CHECKPOINT}/eval "${ARGS[@]:2}"
59 | fi
60 |
61 | if [ ${DATASET} == "caption-coco" ]; then
62 | torchrun \
63 | --nnodes=1 \
64 | --node_rank=0 \
65 | --master_addr=127.0.0.1 \
66 | --nproc_per_node=${GPUS} \
67 | --master_port=${MASTER_PORT} \
68 | eval/caption/evaluate_caption.py --checkpoint ${CHECKPOINT} --out-dir ${CHECKPOINT}/eval --datasets coco "${ARGS[@]:2}"
69 | fi
70 |
71 | if [ ${DATASET} == "caption-flickr30k" ]; then
72 | torchrun \
73 | --nnodes=1 \
74 | --node_rank=0 \
75 | --master_addr=127.0.0.1 \
76 | --nproc_per_node=${GPUS} \
77 | --master_port=${MASTER_PORT} \
78 | eval/caption/evaluate_caption.py --checkpoint ${CHECKPOINT} --out-dir ${CHECKPOINT}/eval --datasets flickr30k "${ARGS[@]:2}"
79 | fi
80 |
81 | if [ ${DATASET} == "caption-nocaps" ]; then
82 | torchrun \
83 | --nnodes=1 \
84 | --node_rank=0 \
85 | --master_addr=127.0.0.1 \
86 | --nproc_per_node=${GPUS} \
87 | --master_port=${MASTER_PORT} \
88 | eval/caption/evaluate_caption.py --checkpoint ${CHECKPOINT} --out-dir ${CHECKPOINT}/eval --datasets nocaps "${ARGS[@]:2}"
89 | fi
90 |
91 | if [ ${DATASET} == "vqa" ]; then
92 | torchrun \
93 | --nnodes=1 \
94 | --node_rank=0 \
95 | --master_addr=127.0.0.1 \
96 | --nproc_per_node=${GPUS} \
97 | --master_port=${MASTER_PORT} \
98 | eval/vqa/evaluate_vqa.py --checkpoint ${CHECKPOINT} --out-dir ${CHECKPOINT}/eval "${ARGS[@]:2}"
99 | fi
100 |
101 | if [ ${DATASET} == "vqa-okvqa-val" ]; then
102 | torchrun \
103 | --nnodes=1 \
104 | --node_rank=0 \
105 | --master_addr=127.0.0.1 \
106 | --nproc_per_node=${GPUS} \
107 | --master_port=${MASTER_PORT} \
108 | eval/vqa/evaluate_vqa.py --checkpoint ${CHECKPOINT} --out-dir ${CHECKPOINT}/eval --datasets okvqa_val "${ARGS[@]:2}"
109 | fi
110 |
111 | if [ ${DATASET} == "vqa-textvqa-val" ]; then
112 | torchrun \
113 | --nnodes=1 \
114 | --node_rank=0 \
115 | --master_addr=127.0.0.1 \
116 | --nproc_per_node=${GPUS} \
117 | --master_port=${MASTER_PORT} \
118 | eval/vqa/evaluate_vqa.py --checkpoint ${CHECKPOINT} --out-dir ${CHECKPOINT}/eval --datasets textvqa_val "${ARGS[@]:2}"
119 | fi
120 |
121 | if [ ${DATASET} == "vqa-textvqa-val-ocr" ]; then
122 | torchrun \
123 | --nnodes=1 \
124 | --node_rank=0 \
125 | --master_addr=127.0.0.1 \
126 | --nproc_per_node=${GPUS} \
127 | --master_port=${MASTER_PORT} \
128 | eval/vqa/evaluate_vqa.py --checkpoint ${CHECKPOINT} --out-dir ${CHECKPOINT}/eval --datasets textvqa_val_ocr "${ARGS[@]:2}"
129 | fi
130 |
131 | if [ ${DATASET} == "vqa-vizwiz-val" ]; then
132 | torchrun \
133 | --nnodes=1 \
134 | --node_rank=0 \
135 | --master_addr=127.0.0.1 \
136 | --nproc_per_node=${GPUS} \
137 | --master_port=${MASTER_PORT} \
138 | eval/vqa/evaluate_vqa.py --checkpoint ${CHECKPOINT} --out-dir ${CHECKPOINT}/eval --datasets vizwiz_val "${ARGS[@]:2}"
139 | fi
140 |
141 | if [ ${DATASET} == "vqa-vizwiz-test" ]; then
142 | torchrun \
143 | --nnodes=1 \
144 | --node_rank=0 \
145 | --master_addr=127.0.0.1 \
146 | --nproc_per_node=${GPUS} \
147 | --master_port=${MASTER_PORT} \
148 | eval/vqa/evaluate_vqa.py --checkpoint ${CHECKPOINT} --out-dir ${CHECKPOINT}/eval --datasets vizwiz_test "${ARGS[@]:2}"
149 | fi
150 |
151 | if [ ${DATASET} == "vqa-vqav2-testdev" ]; then
152 | torchrun \
153 | --nnodes=1 \
154 | --node_rank=0 \
155 | --master_addr=127.0.0.1 \
156 | --nproc_per_node=${GPUS} \
157 | --master_port=${MASTER_PORT} \
158 | eval/vqa/evaluate_vqa.py --checkpoint ${CHECKPOINT} --out-dir ${CHECKPOINT}/eval --datasets vqav2_testdev "${ARGS[@]:2}"
159 | fi
160 |
161 | if [ ${DATASET} == "vqa-ai2d-test" ]; then
162 | torchrun \
163 | --nnodes=1 \
164 | --node_rank=0 \
165 | --master_addr=127.0.0.1 \
166 | --nproc_per_node=${GPUS} \
167 | --master_port=${MASTER_PORT} \
168 | eval/vqa/evaluate_vqa.py --checkpoint ${CHECKPOINT} --out-dir ${CHECKPOINT}/eval --datasets ai2diagram_test "${ARGS[@]:2}"
169 | fi
170 |
171 | if [ ${DATASET} == "vqa-vqav2-val" ]; then
172 | torchrun \
173 | --nnodes=1 \
174 | --node_rank=0 \
175 | --master_addr=127.0.0.1 \
176 | --nproc_per_node=${GPUS} \
177 | --master_port=${MASTER_PORT} \
178 | eval/vqa/evaluate_vqa.py --checkpoint ${CHECKPOINT} --out-dir ${CHECKPOINT}/eval --datasets vqav2_val "${ARGS[@]:2}"
179 | fi
180 |
181 | if [ ${DATASET} == "vqa-gqa-testdev" ]; then
182 | torchrun \
183 | --nnodes=1 \
184 | --node_rank=0 \
185 | --master_addr=127.0.0.1 \
186 | --nproc_per_node=${GPUS} \
187 | --master_port=${MASTER_PORT} \
188 | eval/vqa/evaluate_vqa.py --checkpoint ${CHECKPOINT} --out-dir ${CHECKPOINT}/eval --datasets gqa_testdev_llava "${ARGS[@]:2}"
189 | fi
190 |
191 | if [ ${DATASET} == "vqa-docvqa-val" ]; then
192 | torchrun \
193 | --nnodes=1 \
194 | --node_rank=0 \
195 | --master_addr=127.0.0.1 \
196 | --nproc_per_node=${GPUS} \
197 | --master_port=${MASTER_PORT} \
198 | eval/vqa/evaluate_vqa.py --checkpoint ${CHECKPOINT} --out-dir ${CHECKPOINT}/eval --datasets docvqa_val "${ARGS[@]:2}"
199 | fi
200 |
201 | if [ ${DATASET} == "vqa-docvqa-test" ]; then
202 | torchrun \
203 | --nnodes=1 \
204 | --node_rank=0 \
205 | --master_addr=127.0.0.1 \
206 | --nproc_per_node=${GPUS} \
207 | --master_port=${MASTER_PORT} \
208 | eval/vqa/evaluate_vqa.py --checkpoint ${CHECKPOINT} --out-dir ${CHECKPOINT}/eval --datasets docvqa_test "${ARGS[@]:2}"
209 | fi
210 |
211 | if [ ${DATASET} == "vqa-chartqa-test" ]; then
212 | torchrun \
213 | --nnodes=1 \
214 | --node_rank=0 \
215 | --master_addr=127.0.0.1 \
216 | --nproc_per_node=${GPUS} \
217 | --master_port=${MASTER_PORT} \
218 | eval/vqa/evaluate_vqa.py --checkpoint ${CHECKPOINT} --out-dir ${CHECKPOINT}/eval --datasets chartqa_test_human,chartqa_test_augmented "${ARGS[@]:2}"
219 | fi
220 |
221 | if [ ${DATASET} == "vqa-infovqa-val" ]; then
222 | torchrun \
223 | --nnodes=1 \
224 | --node_rank=0 \
225 | --master_addr=127.0.0.1 \
226 | --nproc_per_node=${GPUS} \
227 | --master_port=${MASTER_PORT} \
228 | eval/vqa/evaluate_vqa.py --checkpoint ${CHECKPOINT} --out-dir ${CHECKPOINT}/eval --datasets infographicsvqa_val "${ARGS[@]:2}"
229 | fi
230 |
231 | if [ ${DATASET} == "vqa-infovqa-test" ]; then
232 | torchrun \
233 | --nnodes=1 \
234 | --node_rank=0 \
235 | --master_addr=127.0.0.1 \
236 | --nproc_per_node=${GPUS} \
237 | --master_port=${MASTER_PORT} \
238 | eval/vqa/evaluate_vqa.py --checkpoint ${CHECKPOINT} --out-dir ${CHECKPOINT}/eval --datasets infographicsvqa_test "${ARGS[@]:2}"
239 | fi
240 |
241 | if [ ${DATASET} == "vqa-chartqa-test-human" ]; then
242 | torchrun \
243 | --nnodes=1 \
244 | --node_rank=0 \
245 | --master_addr=127.0.0.1 \
246 | --nproc_per_node=${GPUS} \
247 | --master_port=${MASTER_PORT} \
248 | eval/vqa/evaluate_vqa.py --checkpoint ${CHECKPOINT} --out-dir ${CHECKPOINT}/eval --datasets chartqa_test_human "${ARGS[@]:2}"
249 | fi
250 |
251 | if [ ${DATASET} == "vqa-chartqa-test-augmented" ]; then
252 | torchrun \
253 | --nnodes=1 \
254 | --node_rank=0 \
255 | --master_addr=127.0.0.1 \
256 | --nproc_per_node=${GPUS} \
257 | --master_port=${MASTER_PORT} \
258 | eval/vqa/evaluate_vqa.py --checkpoint ${CHECKPOINT} --out-dir ${CHECKPOINT}/eval --datasets chartqa_test_augmented "${ARGS[@]:2}"
259 | fi
260 |
261 | if [ ${DATASET} == "vqa-ocrvqa-val" ]; then
262 | torchrun \
263 | --nnodes=1 \
264 | --node_rank=0 \
265 | --master_addr=127.0.0.1 \
266 | --nproc_per_node=${GPUS} \
267 | --master_port=${MASTER_PORT} \
268 | eval/vqa/evaluate_vqa.py --checkpoint ${CHECKPOINT} --out-dir ${CHECKPOINT}/eval --datasets ocrvqa_val "${ARGS[@]:2}"
269 | fi
270 |
271 | if [ ${DATASET} == "vqa-ocrvqa-test" ]; then
272 | torchrun \
273 | --nnodes=1 \
274 | --node_rank=0 \
275 | --master_addr=127.0.0.1 \
276 | --nproc_per_node=${GPUS} \
277 | --master_port=${MASTER_PORT} \
278 | eval/vqa/evaluate_vqa.py --checkpoint ${CHECKPOINT} --out-dir ${CHECKPOINT}/eval --datasets ocrvqa_test "${ARGS[@]:2}"
279 | fi
280 |
281 | if [ ${DATASET} == "refcoco" ]; then
282 | torchrun \
283 | --nnodes=1 \
284 | --node_rank=0 \
285 | --master_addr=127.0.0.1 \
286 | --nproc_per_node=${GPUS} \
287 | --master_port=${MASTER_PORT} \
288 | eval/refcoco/evaluate_grounding.py --checkpoint ${CHECKPOINT} --out-dir ${CHECKPOINT}/eval "${ARGS[@]:2}"
289 | fi
290 |
291 | if [ ${DATASET} == "refcoco-val" ]; then
292 | torchrun \
293 | --nnodes=1 \
294 | --node_rank=0 \
295 | --master_addr=127.0.0.1 \
296 | --nproc_per_node=${GPUS} \
297 | --master_port=${MASTER_PORT} \
298 | eval/refcoco/evaluate_grounding.py --checkpoint ${CHECKPOINT} --out-dir ${CHECKPOINT}/eval --datasets refcoco_val "${ARGS[@]:2}"
299 | fi
300 |
301 | if [ ${DATASET} == "llava-bench" ]; then
302 | rm -rf results/llava_bench_results_review.jsonl
303 | python eval/llava_bench/evaluate_llava_bench.py --checkpoint ${CHECKPOINT} "${ARGS[@]:2}"
304 | python -u eval/llava_bench/eval_gpt_review_bench.py \
305 | --question data/llava-bench-in-the-wild/questions.jsonl \
306 | --context data/llava-bench-in-the-wild/context.jsonl \
307 | --rule eval/llava_bench/rule.json \
308 | --answer-list \
309 | data/llava-bench-in-the-wild/answers_gpt4.jsonl \
310 | results/llava_bench_results.jsonl \
311 | --output \
312 | results/llava_bench_results_review.jsonl
313 | python -u eval/llava_bench/summarize_gpt_review.py -f results/llava_bench_results_review.jsonl
314 | fi
315 |
316 | if [ ${DATASET} == "pope" ]; then
317 | torchrun \
318 | --nnodes=1 \
319 | --node_rank=0 \
320 | --master_addr=127.0.0.1 \
321 | --nproc_per_node=${GPUS} \
322 | --master_port=${MASTER_PORT} \
323 | eval/pope/evaluate_pope.py --checkpoint ${CHECKPOINT} --out-dir ${CHECKPOINT}/eval --datasets pope "${ARGS[@]:2}"
324 | fi
325 |
326 | if [ ${DATASET} == "tiny_lvlm" ]; then
327 | torchrun \
328 | --nnodes=1 \
329 | --node_rank=0 \
330 | --master_addr=127.0.0.1 \
331 | --nproc_per_node=${GPUS} \
332 | --master_port=${MASTER_PORT} \
333 | eval/tiny_lvlm/evaluate_lvlm.py --checkpoint ${CHECKPOINT} --out-dir ${CHECKPOINT}/eval --datasets updated_datasets "${ARGS[@]:2}"
334 | fi
335 |
336 | if [ ${DATASET} == "mmvet" ]; then
337 | python eval/mmvet/evaluate_mmvet.py --checkpoint ${CHECKPOINT} --out-dir ${CHECKPOINT}/eval --datasets mmvet "${ARGS[@]:2}"
338 | fi
339 |
340 | if [ ${DATASET} == "cmmmu" ]; then
341 | CUDA_VISIBLE_DEVICES=0 python eval/cmmmu/evaluate_cmmmu.py --checkpoint ${CHECKPOINT} --out-dir ${CHECKPOINT}/eval --datasets art_and_design "${ARGS[@]:2}" &
342 | CUDA_VISIBLE_DEVICES=1 python eval/cmmmu/evaluate_cmmmu.py --checkpoint ${CHECKPOINT} --out-dir ${CHECKPOINT}/eval --datasets business "${ARGS[@]:2}" &
343 | CUDA_VISIBLE_DEVICES=2 python eval/cmmmu/evaluate_cmmmu.py --checkpoint ${CHECKPOINT} --out-dir ${CHECKPOINT}/eval --datasets health_and_medicine "${ARGS[@]:2}" &
344 | CUDA_VISIBLE_DEVICES=3 python eval/cmmmu/evaluate_cmmmu.py --checkpoint ${CHECKPOINT} --out-dir ${CHECKPOINT}/eval --datasets humanities_and_social_sciences "${ARGS[@]:2}" &
345 | CUDA_VISIBLE_DEVICES=4 python eval/cmmmu/evaluate_cmmmu.py --checkpoint ${CHECKPOINT} --out-dir ${CHECKPOINT}/eval --datasets science "${ARGS[@]:2}" &
346 | CUDA_VISIBLE_DEVICES=5 python eval/cmmmu/evaluate_cmmmu.py --checkpoint ${CHECKPOINT} --out-dir ${CHECKPOINT}/eval --datasets technology_and_engineering "${ARGS[@]:2}" &
347 | wait
348 | fi
349 |
350 | if [ ${DATASET} == "mmbench-dev-en" ]; then
351 | torchrun \
352 | --nnodes=1 \
353 | --node_rank=0 \
354 | --master_addr=127.0.0.1 \
355 | --nproc_per_node=${GPUS} \
356 | --master_port=${MASTER_PORT} \
357 | eval/mmbench/evaluate_mmbench.py --checkpoint ${CHECKPOINT} --out-dir ${CHECKPOINT}/eval --datasets mmbench_dev_20230712 "${ARGS[@]:2}"
358 | fi
359 |
360 | if [ ${DATASET} == "mmbench-dev-cn" ]; then
361 | torchrun \
362 | --nnodes=1 \
363 | --node_rank=0 \
364 | --master_addr=127.0.0.1 \
365 | --nproc_per_node=${GPUS} \
366 | --master_port=${MASTER_PORT} \
367 | eval/mmbench/evaluate_mmbench.py --checkpoint ${CHECKPOINT} --out-dir ${CHECKPOINT}/eval --datasets mmbench_dev_cn_20231003 "${ARGS[@]:2}"
368 | fi
369 |
370 | if [ ${DATASET} == "mmbench-test-en" ]; then
371 | torchrun \
372 | --nnodes=1 \
373 | --node_rank=0 \
374 | --master_addr=127.0.0.1 \
375 | --nproc_per_node=${GPUS} \
376 | --master_port=${MASTER_PORT} \
377 | eval/mmbench/evaluate_mmbench.py \
378 | --checkpoint ${CHECKPOINT} \
379 | --out-dir ${CHECKPOINT}/eval \
380 | --datasets mmbench_test_en_20231003 "${ARGS[@]:2}"
381 | fi
382 |
383 | if [ ${DATASET} == "mmbench-test-cn" ]; then
384 | torchrun \
385 | --nnodes=1 \
386 | --node_rank=0 \
387 | --master_addr=127.0.0.1 \
388 | --nproc_per_node=${GPUS} \
389 | --master_port=${MASTER_PORT} \
390 | eval/mmbench/evaluate_mmbench.py --checkpoint ${CHECKPOINT} --out-dir ${CHECKPOINT}/eval --datasets mmbench_test_cn_20231003 "${ARGS[@]:2}"
391 | fi
392 |
393 | if [ ${DATASET} == "ccbench-dev" ]; then
394 | torchrun \
395 | --nnodes=1 \
396 | --node_rank=0 \
397 | --master_addr=127.0.0.1 \
398 | --nproc_per_node=${GPUS} \
399 | --master_port=${MASTER_PORT} \
400 | eval/mmbench/evaluate_mmbench.py --checkpoint ${CHECKPOINT} --out-dir ${CHECKPOINT}/eval --datasets ccbench_dev_cn "${ARGS[@]:2}"
401 | fi
402 |
403 | if [ ${DATASET} == "scienceqa" ]; then
404 | torchrun \
405 | --nnodes=1 \
406 | --node_rank=0 \
407 | --master_addr=127.0.0.1 \
408 | --nproc_per_node=${GPUS} \
409 | --master_port=${MASTER_PORT} \
410 | eval/scienceqa/evaluate_scienceqa.py --checkpoint ${CHECKPOINT} --out-dir ${CHECKPOINT}/eval --datasets sqa_test "${ARGS[@]:2}"
411 | fi
412 |
413 |
414 | if [ ${DATASET} == "mmmu-dev" ]; then
415 | torchrun \
416 | --nnodes=1 \
417 | --node_rank=0 \
418 | --master_addr=127.0.0.1 \
419 | --nproc_per_node=${GPUS} \
420 | --master_port=${MASTER_PORT} \
421 | eval/mmmu/evaluate_mmmu.py --checkpoint ${CHECKPOINT} --out-dir ${CHECKPOINT}/eval --datasets MMMU_dev "${ARGS[@]:2}"
422 | fi
423 |
424 | if [ ${DATASET} == "mmmu-val" ]; then
425 | torchrun \
426 | --nnodes=1 \
427 | --node_rank=0 \
428 | --master_addr=127.0.0.1 \
429 | --nproc_per_node=${GPUS} \
430 | --master_port=${MASTER_PORT} \
431 | eval/mmmu/evaluate_mmmu.py --checkpoint ${CHECKPOINT} --out-dir ${CHECKPOINT}/eval --datasets MMMU_validation "${ARGS[@]:2}"
432 | fi
433 |
434 | if [ ${DATASET} == "mmmu-test" ]; then
435 | torchrun \
436 | --nnodes=1 \
437 | --node_rank=0 \
438 | --master_addr=127.0.0.1 \
439 | --nproc_per_node=${GPUS} \
440 | --master_port=${MASTER_PORT} \
441 | eval/mmmu/evaluate_mmmu.py --checkpoint ${CHECKPOINT} --out-dir ${CHECKPOINT}/eval --datasets MMMU_test "${ARGS[@]:2}"
442 | fi
443 |
444 | if [ ${DATASET} == "mmmu-dev-cot" ]; then
445 | torchrun \
446 | --nnodes=1 \
447 | --node_rank=0 \
448 | --master_addr=127.0.0.1 \
449 | --nproc_per_node=${GPUS} \
450 | --master_port=${MASTER_PORT} \
451 | eval/mmmu/evaluate_mmmu_cot.py --checkpoint ${CHECKPOINT} --out-dir ${CHECKPOINT}/eval --datasets MMMU_dev "${ARGS[@]:2}"
452 | fi
453 |
454 | if [ ${DATASET} == "mmmu-val-cot" ]; then
455 | torchrun \
456 | --nnodes=1 \
457 | --node_rank=0 \
458 | --master_addr=127.0.0.1 \
459 | --nproc_per_node=${GPUS} \
460 | --master_port=${MASTER_PORT} \
461 | eval/mmmu/evaluate_mmmu_cot.py --checkpoint ${CHECKPOINT} --out-dir ${CHECKPOINT}/eval --datasets MMMU_validation "${ARGS[@]:2}"
462 | fi
463 |
464 | if [ ${DATASET} == "mmmu-test-cot" ]; then
465 | torchrun \
466 | --nnodes=1 \
467 | --node_rank=0 \
468 | --master_addr=127.0.0.1 \
469 | --nproc_per_node=${GPUS} \
470 | --master_port=${MASTER_PORT} \
471 | eval/mmmu/evaluate_mmmu_cot.py --checkpoint ${CHECKPOINT} --out-dir ${CHECKPOINT}/eval --datasets MMMU_test "${ARGS[@]:2}"
472 | fi
473 |
474 |
475 | if [ ${DATASET} == "mmvp" ]; then
476 | torchrun \
477 | --nnodes=1 \
478 | --node_rank=0 \
479 | --master_addr=127.0.0.1 \
480 | --nproc_per_node=${GPUS} \
481 | --master_port=${MASTER_PORT} \
482 | eval/mmvp/evaluate_mmvp.py --checkpoint ${CHECKPOINT} --out-dir ${CHECKPOINT}/eval --datasets MMVP "${ARGS[@]:2}"
483 | fi
484 |
485 |
486 | if [ ${DATASET} == "mathvista-testmini" ]; then
487 | torchrun \
488 | --nnodes=1 \
489 | --node_rank=0 \
490 | --master_addr=127.0.0.1 \
491 | --nproc_per_node=${GPUS} \
492 | --master_port=${MASTER_PORT} \
493 | eval/mathvista/evaluate_mathvista.py --checkpoint ${CHECKPOINT} --out-dir ${CHECKPOINT}/eval --datasets MathVista_testmini "${ARGS[@]:2}"
494 | fi
495 |
496 |
497 | if [ ${DATASET} == "mathvista-test" ]; then
498 | torchrun \
499 | --nnodes=1 \
500 | --node_rank=0 \
501 | --master_addr=127.0.0.1 \
502 | --nproc_per_node=${GPUS} \
503 | --master_port=${MASTER_PORT} \
504 | eval/mathvista/evaluate_mathvista.py --checkpoint ${CHECKPOINT} --out-dir ${CHECKPOINT}/eval --datasets MathVista_test "${ARGS[@]:2}"
505 | fi
506 |
507 | if [ ${DATASET} == "seed" ]; then
508 | torchrun \
509 | --nnodes=1 \
510 | --node_rank=0 \
511 | --master_addr=127.0.0.1 \
512 | --nproc_per_node=${GPUS} \
513 | --master_port=${MASTER_PORT} \
514 | eval/seed/evaluate_seed.py --checkpoint ${CHECKPOINT} --out-dir ${CHECKPOINT}/eval --datasets SEEDv1 "${ARGS[@]:2}"
515 | fi
516 |
517 | if [ ${DATASET} == "mvbench" ]; then
518 | torchrun \
519 | --nnodes=1 \
520 | --node_rank=0 \
521 | --master_addr=127.0.0.1 \
522 | --nproc_per_node=${GPUS} \
523 | --master_port=${MASTER_PORT} \
524 | eval/mvbench/evaluate_mvbench.py \
525 | --checkpoint ${CHECKPOINT} \
526 | --out-dir ${CHECKPOINT}/eval \
527 | "${ARGS[@]:2}"
528 | fi
529 |
--------------------------------------------------------------------------------
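For reference, each branch above is selected purely by the second positional argument and forwards any remaining flags through "${ARGS[@]:2}". A direct invocation of, say, the mmmu-val branch might look like the sketch below; the checkpoint path is hypothetical, and GPUS/MASTER_PORT are assumed to be resolved near the top of evaluate.sh (outside this excerpt) or taken from the environment:

    # hypothetical checkpoint path; GPUS/MASTER_PORT assumed to be picked up by evaluate.sh
    GPUS=8 sh evaluate.sh work_dirs/my_checkpoint mmmu-val --dynamic --max-num 6
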
/evaluate_launch.sh:
--------------------------------------------------------------------------------
1 | set -x
2 |
3 | CHECKPOINT=${1}
4 | TASK=${2}
5 | ARGS=("$@")
6 |
7 | if [ "${TASK}" == "vqa-chartqa-test" ]; then
8 | sh evaluate.sh ${CHECKPOINT} ${TASK} --dynamic --max-num 12 "${ARGS[@]:3}"
9 | elif [ "${TASK}" == "vqa-infovqa-val" -o "${TASK}" == "vqa-infovqa-test" ]; then
10 | sh evaluate.sh ${CHECKPOINT} ${TASK} --dynamic --max-num 24 "${ARGS[@]:3}"
11 | elif [ "${TASK}" == "vqa-docvqa-val" -o "${TASK}" == "vqa-docvqa-test" ]; then
12 | sh evaluate.sh ${CHECKPOINT} ${TASK} --dynamic --max-num 18 "${ARGS[@]:3}"
13 | elif [ "${TASK}" == "mvbench" ]; then
14 | sh evaluate.sh ${CHECKPOINT} ${TASK} --num_segments 96 "${ARGS[@]:3}"
15 | else
16 | sh evaluate.sh ${CHECKPOINT} ${TASK} --dynamic --max-num 6 "${ARGS[@]:3}"
17 | fi
18 |
--------------------------------------------------------------------------------
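The launcher only chooses a per-task tiling budget (or a frame count for MVBench) and defers everything else to evaluate.sh. A usage sketch with a hypothetical checkpoint path:

    # ChartQA gets --max-num 12, InfographicVQA 24, DocVQA 18, MVBench 96 frames; other tasks fall back to --max-num 6
    sh evaluate_launch.sh work_dirs/my_checkpoint vqa-infovqa-val
    sh evaluate_launch.sh work_dirs/my_checkpoint mvbench
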
/requirements.txt:
--------------------------------------------------------------------------------
1 | accelerate
2 | bitsandbytes==0.41.0
3 | decord
4 | deepspeed==0.13.5
5 | einops==0.6.1
6 | einops-exts==0.0.4
7 | huggingface_hub
8 | imageio
9 | numpy
10 | opencv-python
11 | orjson
12 | peft>=0.4.0
13 | pycocoevalcap
14 | pyyaml
15 | scikit-learn>=1.2.2
16 | scipy
17 | sentencepiece==0.1.99
18 | shortuuid
19 | tensorboardX
20 | termcolor
21 | timm==0.9.12
22 | tokenizers==0.15.1
23 | torch>=2
24 | torchvision>=0.15
25 | tqdm
26 | transformers==4.37.2
27 | yacs
--------------------------------------------------------------------------------
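A minimal environment setup under these pins, assuming a recent Python (3.8+) and a CUDA-enabled build of the pinned torch on the target machine:

    python -m venv .venv
    source .venv/bin/activate
    pip install -r requirements.txt
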
/utils/preprocess.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import torchvision.transforms as T
4 | from decord import VideoReader, cpu
5 | from PIL import Image
6 | from torchvision.transforms.functional import InterpolationMode
7 |
8 |
9 | IMAGENET_MEAN = (0.485, 0.456, 0.406)
10 | IMAGENET_STD = (0.229, 0.224, 0.225)
11 |
12 |
13 | def build_transform(input_size):
14 | MEAN, STD = IMAGENET_MEAN, IMAGENET_STD
15 | transform = T.Compose([
16 | T.Lambda(lambda img: img.convert('RGB') if img.mode != 'RGB' else img),
17 | T.Resize((input_size, input_size), interpolation=InterpolationMode.BICUBIC),
18 | T.ToTensor(),
19 | T.Normalize(mean=MEAN, std=STD)
20 | ])
21 | return transform
22 |
23 |
24 | def find_closest_aspect_ratio(aspect_ratio, target_ratios, width, height, image_size):
25 | best_ratio_diff = float('inf')
26 | best_ratio = (1, 1)
27 | area = width * height
28 | for ratio in target_ratios:
29 | target_aspect_ratio = ratio[0] / ratio[1]
30 | ratio_diff = abs(aspect_ratio - target_aspect_ratio)
31 | if ratio_diff < best_ratio_diff:
32 | best_ratio_diff = ratio_diff
33 | best_ratio = ratio
34 | elif ratio_diff == best_ratio_diff:
35 |             if area > 0.5 * image_size * image_size * ratio[0] * ratio[1]:  # on ties, prefer the larger grid only when the image has enough pixels to fill it
36 | best_ratio = ratio
37 | return best_ratio
38 |
39 |
40 | def dynamic_preprocess(image, min_num=1, max_num=12, image_size=448, use_thumbnail=False):
41 | orig_width, orig_height = image.size
42 | aspect_ratio = orig_width / orig_height
43 |
44 |     # enumerate candidate tiling grids (i, j) with min_num <= i * j <= max_num tiles
45 | target_ratios = set(
46 | (i, j) for n in range(min_num, max_num + 1) for i in range(1, n + 1) for j in range(1, n + 1) if
47 | i * j <= max_num and i * j >= min_num)
48 | target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1])
49 |
50 | # find the closest aspect ratio to the target
51 | target_aspect_ratio = find_closest_aspect_ratio(
52 | aspect_ratio, target_ratios, orig_width, orig_height, image_size)
53 |
54 | # calculate the target width and height
55 | target_width = image_size * target_aspect_ratio[0]
56 | target_height = image_size * target_aspect_ratio[1]
57 | blocks = target_aspect_ratio[0] * target_aspect_ratio[1]
58 |
59 | # resize the image
60 | resized_img = image.resize((target_width, target_height))
61 | processed_images = []
62 | for i in range(blocks):
63 | box = (
64 | (i % (target_width // image_size)) * image_size,
65 | (i // (target_width // image_size)) * image_size,
66 | ((i % (target_width // image_size)) + 1) * image_size,
67 | ((i // (target_width // image_size)) + 1) * image_size
68 | )
69 | # split the image
70 | split_img = resized_img.crop(box)
71 | processed_images.append(split_img)
72 | assert len(processed_images) == blocks
73 | if use_thumbnail and len(processed_images) != 1:
74 | thumbnail_img = image.resize((image_size, image_size))
75 | processed_images.append(thumbnail_img)
76 | return processed_images
77 |
78 |
79 | def load_image(image_file, input_size=448, max_num=12):
80 | image = Image.open(image_file).convert('RGB')
81 | transform = build_transform(input_size=input_size)
82 | images = dynamic_preprocess(image, image_size=input_size, use_thumbnail=True, max_num=max_num)
83 | pixel_values = [transform(image) for image in images]
84 | pixel_values = torch.stack(pixel_values)
85 | return pixel_values
86 |
87 |
88 | def get_index(bound, fps, max_frame, first_idx=0, num_segments=32):
89 | if bound:
90 | start, end = bound[0], bound[1]
91 | else:
92 |         start, end = -100000, 100000  # no bound given: span the whole video
93 | start_idx = max(first_idx, round(start * fps))
94 | end_idx = min(round(end * fps), max_frame)
95 | seg_size = float(end_idx - start_idx) / num_segments
96 | frame_indices = np.array([
97 | int(start_idx + (seg_size / 2) + np.round(seg_size * idx))
98 | for idx in range(num_segments)
99 | ])
100 | return frame_indices
101 |
102 |
103 | def load_video(video_path, bound=None, input_size=448, max_num=1, num_segments=32):
104 | vr = VideoReader(video_path, ctx=cpu(0), num_threads=1)
105 | max_frame = len(vr) - 1
106 | fps = float(vr.get_avg_fps())
107 |
108 | pixel_values_list, num_patches_list = [], []
109 | transform = build_transform(input_size=input_size)
110 | frame_indices = get_index(bound, fps, max_frame, first_idx=0, num_segments=num_segments)
111 | for frame_index in frame_indices:
112 | img = Image.fromarray(vr[frame_index].asnumpy()).convert('RGB')
113 | img = dynamic_preprocess(img, image_size=input_size, use_thumbnail=True, max_num=max_num)
114 | pixel_values = [transform(tile) for tile in img]
115 | pixel_values = torch.stack(pixel_values)
116 | num_patches_list.append(pixel_values.shape[0])
117 | pixel_values_list.append(pixel_values)
118 | pixel_values = torch.cat(pixel_values_list)
119 | return pixel_values, num_patches_list
120 |
--------------------------------------------------------------------------------
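Taken together, load_image returns a (num_tiles, 3, input_size, input_size) tensor of ImageNet-normalized tiles (plus a global thumbnail whenever the image is split), and load_video additionally returns the per-frame tile counts. A minimal usage sketch with hypothetical file paths:

    from utils.preprocess import load_image, load_video

    # Single image: dynamic tiling into at most 12 tiles plus a global thumbnail.
    pixel_values = load_image('path/to/image.jpg', input_size=448, max_num=12)  # hypothetical path
    print(pixel_values.shape)  # (N, 3, 448, 448) with N <= 13

    # Video: 32 uniformly sampled frames, one 448x448 tile per frame by default (max_num=1).
    pixel_values, num_patches_list = load_video('path/to/video.mp4', num_segments=32)  # hypothetical path
    print(pixel_values.shape, num_patches_list[:4])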