├── vlmeval
│   ├── evaluate
│   │   ├── __init__.py
│   │   ├── .ipynb_checkpoints
│   │   │   ├── __init__-checkpoint.py
│   │   │   └── misc-checkpoint.py
│   │   ├── __pycache__
│   │   │   ├── misc.cpython-310.pyc
│   │   │   ├── OCRBench.cpython-310.pyc
│   │   │   ├── __init__.cpython-310.pyc
│   │   │   ├── coco_eval.cpython-310.pyc
│   │   │   ├── vqa_eval.cpython-310.pyc
│   │   │   ├── yes_or_no.cpython-310.pyc
│   │   │   ├── llavabench.cpython-310.pyc
│   │   │   ├── mmvet_eval.cpython-310.pyc
│   │   │   ├── mathvista_eval.cpython-310.pyc
│   │   │   └── multiple_choice.cpython-310.pyc
│   │   ├── misc.py
│   │   └── multiple_choice.py
│   ├── __pycache__
│   │   ├── config.cpython-310.pyc
│   │   ├── __init__.cpython-310.pyc
│   │   ├── __init__.cpython-38.pyc
│   │   └── inference_multi.cpython-310.pyc
│   ├── smp
│   │   ├── __pycache__
│   │   │   ├── lb.cpython-310.pyc
│   │   │   ├── vlm.cpython-38.pyc
│   │   │   ├── file.cpython-310.pyc
│   │   │   ├── file.cpython-38.pyc
│   │   │   ├── log.cpython-310.pyc
│   │   │   ├── misc.cpython-310.pyc
│   │   │   ├── misc.cpython-38.pyc
│   │   │   ├── vlm.cpython-310.pyc
│   │   │   ├── __init__.cpython-38.pyc
│   │   │   └── __init__.cpython-310.pyc
│   │   ├── __init__.py
│   │   ├── .ipynb_checkpoints
│   │   │   ├── __init__-checkpoint.py
│   │   │   ├── log-checkpoint.py
│   │   │   ├── misc-checkpoint.py
│   │   │   ├── vlm-checkpoint.py
│   │   │   ├── file-checkpoint.py
│   │   │   └── lb-checkpoint.py
│   │   ├── log.py
│   │   ├── misc.py
│   │   ├── vlm.py
│   │   ├── file.py
│   │   └── lb.py
│   ├── api
│   │   ├── __pycache__
│   │   │   ├── base.cpython-310.pyc
│   │   │   ├── gpt.cpython-310.pyc
│   │   │   ├── gemini.cpython-310.pyc
│   │   │   ├── gpt_int.cpython-310.pyc
│   │   │   ├── __init__.cpython-310.pyc
│   │   │   ├── qwen_api.cpython-310.pyc
│   │   │   ├── qwen_vl_api.cpython-310.pyc
│   │   │   └── hf_chat_model.cpython-310.pyc
│   │   ├── __init__.py
│   │   ├── base.py
│   │   ├── qwen_api.py
│   │   ├── qwen_vl_api.py
│   │   ├── gemini.py
│   │   ├── gpt_int.py
│   │   ├── gpt.py
│   │   └── hf_chat_model.py
│   ├── vlm
│   │   ├── __pycache__
│   │   │   ├── emu.cpython-310.pyc
│   │   │   ├── cogvlm.cpython-310.pyc
│   │   │   ├── idefics.cpython-310.pyc
│   │   │   ├── llava.cpython-310.pyc
│   │   │   ├── mmalaya.cpython-310.pyc
│   │   │   ├── monkey.cpython-310.pyc
│   │   │   ├── omnilmm.cpython-310.pyc
│   │   │   ├── qwen_vl.cpython-310.pyc
│   │   │   ├── yi_vl.cpython-310.pyc
│   │   │   ├── __init__.cpython-310.pyc
│   │   │   ├── minicpm_v.cpython-310.pyc
│   │   │   ├── minigpt4.cpython-310.pyc
│   │   │   ├── pandagpt.cpython-310.pyc
│   │   │   ├── visualglm.cpython-310.pyc
│   │   │   ├── xcomposer.cpython-310.pyc
│   │   │   ├── instructblip.cpython-310.pyc
│   │   │   ├── llava_xtuner.cpython-310.pyc
│   │   │   ├── mplug_owl2.cpython-310.pyc
│   │   │   ├── transcore_m.cpython-310.pyc
│   │   │   ├── xcomposer2.cpython-310.pyc
│   │   │   ├── internvl_chat.cpython-310.pyc
│   │   │   ├── open_flamingo.cpython-310.pyc
│   │   │   └── sharedcaptioner.cpython-310.pyc
│   │   ├── __init__.py
│   │   ├── .ipynb_checkpoints
│   │   │   ├── __init__-checkpoint.py
│   │   │   └── instructblip-checkpoint.py
│   │   └── instructblip.py
│   ├── utils
│   │   ├── __pycache__
│   │   │   ├── dataset.cpython-310.pyc
│   │   │   ├── debate.cpython-310.pyc
│   │   │   ├── mp_util.cpython-310.pyc
│   │   │   ├── __init__.cpython-310.pyc
│   │   │   ├── base_prompt.cpython-310.pyc
│   │   │   ├── custom_prompt.cpython-310.pyc
│   │   │   ├── dataset_config.cpython-310.pyc
│   │   │   └── matching_util.cpython-310.pyc
│   │   ├── __init__.py
│   │   ├── .ipynb_checkpoints
│   │   │   ├── __init__-checkpoint.py
│   │   │   ├── custom_prompt-checkpoint.py
│   │   │   ├── matching_util-checkpoint.py
│   │   │   ├── debate-checkpoint.py
│   │   │   ├── dataset_config-checkpoint.py
│   │   │   ├── base_prompt-checkpoint.py
│   │   │   ├── mp_util-checkpoint.py
│   │   │   └── dataset-checkpoint.py
│   │   ├── custom_prompt.py
│   │   ├── matching_util.py
│   │   ├── debate.py
│   │   ├── dataset_config.py
│   │   ├── base_prompt.py
│   │   ├── mp_util.py
│   │   └── dataset.py
│   ├── __init__.py
│   ├── .ipynb_checkpoints
│   │   ├── __init__-checkpoint.py
│   │   ├── config-checkpoint.py
│   │   └── inference_multi-checkpoint.py
│   ├── config.py
│   └── inference_multi.py
├── assets
│   ├── Model1.png
│   ├── overview.png
│   └── framework.png
├── scripts
│   └── debate.sh
├── requirements.txt
├── README.md
├── run.py
└── setup.py
/vlmeval/evaluate/__init__.py:
--------------------------------------------------------------------------------
1 | from .multiple_choice import multiple_choice_eval
--------------------------------------------------------------------------------
/assets/Model1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/thecharm/BDoG/HEAD/assets/Model1.png
--------------------------------------------------------------------------------
/assets/overview.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/thecharm/BDoG/HEAD/assets/overview.png
--------------------------------------------------------------------------------
/assets/framework.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/thecharm/BDoG/HEAD/assets/framework.png
--------------------------------------------------------------------------------
/vlmeval/evaluate/.ipynb_checkpoints/__init__-checkpoint.py:
--------------------------------------------------------------------------------
1 | from .multiple_choice import multiple_choice_eval
--------------------------------------------------------------------------------
/vlmeval/__pycache__/config.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/thecharm/BDoG/HEAD/vlmeval/__pycache__/config.cpython-310.pyc
--------------------------------------------------------------------------------
/vlmeval/smp/__pycache__/lb.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/thecharm/BDoG/HEAD/vlmeval/smp/__pycache__/lb.cpython-310.pyc
--------------------------------------------------------------------------------
/vlmeval/smp/__pycache__/vlm.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/thecharm/BDoG/HEAD/vlmeval/smp/__pycache__/vlm.cpython-38.pyc
--------------------------------------------------------------------------------
/vlmeval/__pycache__/__init__.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/thecharm/BDoG/HEAD/vlmeval/__pycache__/__init__.cpython-310.pyc
--------------------------------------------------------------------------------
/vlmeval/__pycache__/__init__.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/thecharm/BDoG/HEAD/vlmeval/__pycache__/__init__.cpython-38.pyc
--------------------------------------------------------------------------------
/vlmeval/api/__pycache__/base.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/thecharm/BDoG/HEAD/vlmeval/api/__pycache__/base.cpython-310.pyc
--------------------------------------------------------------------------------
/vlmeval/api/__pycache__/gpt.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/thecharm/BDoG/HEAD/vlmeval/api/__pycache__/gpt.cpython-310.pyc
--------------------------------------------------------------------------------
/vlmeval/smp/__init__.py:
--------------------------------------------------------------------------------
1 | from .file import *
2 | from .vlm import *
3 | from .misc import *
4 | from .log import *
5 | from .lb import *
--------------------------------------------------------------------------------
/vlmeval/smp/__pycache__/file.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/thecharm/BDoG/HEAD/vlmeval/smp/__pycache__/file.cpython-310.pyc
--------------------------------------------------------------------------------
/vlmeval/smp/__pycache__/file.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/thecharm/BDoG/HEAD/vlmeval/smp/__pycache__/file.cpython-38.pyc
--------------------------------------------------------------------------------
/vlmeval/smp/__pycache__/log.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/thecharm/BDoG/HEAD/vlmeval/smp/__pycache__/log.cpython-310.pyc
--------------------------------------------------------------------------------
/vlmeval/smp/__pycache__/misc.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/thecharm/BDoG/HEAD/vlmeval/smp/__pycache__/misc.cpython-310.pyc
--------------------------------------------------------------------------------
/vlmeval/smp/__pycache__/misc.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/thecharm/BDoG/HEAD/vlmeval/smp/__pycache__/misc.cpython-38.pyc
--------------------------------------------------------------------------------
/vlmeval/smp/__pycache__/vlm.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/thecharm/BDoG/HEAD/vlmeval/smp/__pycache__/vlm.cpython-310.pyc
--------------------------------------------------------------------------------
/vlmeval/vlm/__pycache__/emu.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/thecharm/BDoG/HEAD/vlmeval/vlm/__pycache__/emu.cpython-310.pyc
--------------------------------------------------------------------------------
/vlmeval/api/__pycache__/gemini.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/thecharm/BDoG/HEAD/vlmeval/api/__pycache__/gemini.cpython-310.pyc
--------------------------------------------------------------------------------
/vlmeval/api/__pycache__/gpt_int.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/thecharm/BDoG/HEAD/vlmeval/api/__pycache__/gpt_int.cpython-310.pyc
--------------------------------------------------------------------------------
/vlmeval/smp/__pycache__/__init__.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/thecharm/BDoG/HEAD/vlmeval/smp/__pycache__/__init__.cpython-38.pyc
--------------------------------------------------------------------------------
/vlmeval/vlm/__pycache__/cogvlm.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/thecharm/BDoG/HEAD/vlmeval/vlm/__pycache__/cogvlm.cpython-310.pyc
--------------------------------------------------------------------------------
/vlmeval/vlm/__pycache__/idefics.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/thecharm/BDoG/HEAD/vlmeval/vlm/__pycache__/idefics.cpython-310.pyc
--------------------------------------------------------------------------------
/vlmeval/vlm/__pycache__/llava.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/thecharm/BDoG/HEAD/vlmeval/vlm/__pycache__/llava.cpython-310.pyc
--------------------------------------------------------------------------------
/vlmeval/vlm/__pycache__/mmalaya.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/thecharm/BDoG/HEAD/vlmeval/vlm/__pycache__/mmalaya.cpython-310.pyc
--------------------------------------------------------------------------------
/vlmeval/vlm/__pycache__/monkey.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/thecharm/BDoG/HEAD/vlmeval/vlm/__pycache__/monkey.cpython-310.pyc
--------------------------------------------------------------------------------
/vlmeval/vlm/__pycache__/omnilmm.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/thecharm/BDoG/HEAD/vlmeval/vlm/__pycache__/omnilmm.cpython-310.pyc
--------------------------------------------------------------------------------
/vlmeval/vlm/__pycache__/qwen_vl.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/thecharm/BDoG/HEAD/vlmeval/vlm/__pycache__/qwen_vl.cpython-310.pyc
--------------------------------------------------------------------------------
/vlmeval/vlm/__pycache__/yi_vl.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/thecharm/BDoG/HEAD/vlmeval/vlm/__pycache__/yi_vl.cpython-310.pyc
--------------------------------------------------------------------------------
/vlmeval/api/__pycache__/__init__.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/thecharm/BDoG/HEAD/vlmeval/api/__pycache__/__init__.cpython-310.pyc
--------------------------------------------------------------------------------
/vlmeval/api/__pycache__/qwen_api.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/thecharm/BDoG/HEAD/vlmeval/api/__pycache__/qwen_api.cpython-310.pyc
--------------------------------------------------------------------------------
/vlmeval/evaluate/__pycache__/misc.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/thecharm/BDoG/HEAD/vlmeval/evaluate/__pycache__/misc.cpython-310.pyc
--------------------------------------------------------------------------------
/vlmeval/smp/__pycache__/__init__.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/thecharm/BDoG/HEAD/vlmeval/smp/__pycache__/__init__.cpython-310.pyc
--------------------------------------------------------------------------------
/vlmeval/utils/__pycache__/dataset.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/thecharm/BDoG/HEAD/vlmeval/utils/__pycache__/dataset.cpython-310.pyc
--------------------------------------------------------------------------------
/vlmeval/utils/__pycache__/debate.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/thecharm/BDoG/HEAD/vlmeval/utils/__pycache__/debate.cpython-310.pyc
--------------------------------------------------------------------------------
/vlmeval/utils/__pycache__/mp_util.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/thecharm/BDoG/HEAD/vlmeval/utils/__pycache__/mp_util.cpython-310.pyc
--------------------------------------------------------------------------------
/vlmeval/vlm/__pycache__/__init__.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/thecharm/BDoG/HEAD/vlmeval/vlm/__pycache__/__init__.cpython-310.pyc
--------------------------------------------------------------------------------
/vlmeval/vlm/__pycache__/minicpm_v.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/thecharm/BDoG/HEAD/vlmeval/vlm/__pycache__/minicpm_v.cpython-310.pyc
--------------------------------------------------------------------------------
/vlmeval/vlm/__pycache__/minigpt4.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/thecharm/BDoG/HEAD/vlmeval/vlm/__pycache__/minigpt4.cpython-310.pyc
--------------------------------------------------------------------------------
/vlmeval/vlm/__pycache__/pandagpt.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/thecharm/BDoG/HEAD/vlmeval/vlm/__pycache__/pandagpt.cpython-310.pyc
--------------------------------------------------------------------------------
/vlmeval/vlm/__pycache__/visualglm.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/thecharm/BDoG/HEAD/vlmeval/vlm/__pycache__/visualglm.cpython-310.pyc
--------------------------------------------------------------------------------
/vlmeval/vlm/__pycache__/xcomposer.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/thecharm/BDoG/HEAD/vlmeval/vlm/__pycache__/xcomposer.cpython-310.pyc
--------------------------------------------------------------------------------
/vlmeval/__pycache__/inference_multi.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/thecharm/BDoG/HEAD/vlmeval/__pycache__/inference_multi.cpython-310.pyc
--------------------------------------------------------------------------------
/vlmeval/api/__pycache__/qwen_vl_api.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/thecharm/BDoG/HEAD/vlmeval/api/__pycache__/qwen_vl_api.cpython-310.pyc
--------------------------------------------------------------------------------
/vlmeval/utils/__pycache__/__init__.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/thecharm/BDoG/HEAD/vlmeval/utils/__pycache__/__init__.cpython-310.pyc
--------------------------------------------------------------------------------
/vlmeval/vlm/__init__.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | torch.set_grad_enabled(False)
4 | torch.manual_seed(1234)
5 | from .instructblip import InstructBLIP
6 |
--------------------------------------------------------------------------------
/vlmeval/vlm/__pycache__/instructblip.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/thecharm/BDoG/HEAD/vlmeval/vlm/__pycache__/instructblip.cpython-310.pyc
--------------------------------------------------------------------------------
/vlmeval/vlm/__pycache__/llava_xtuner.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/thecharm/BDoG/HEAD/vlmeval/vlm/__pycache__/llava_xtuner.cpython-310.pyc
--------------------------------------------------------------------------------
/vlmeval/vlm/__pycache__/mplug_owl2.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/thecharm/BDoG/HEAD/vlmeval/vlm/__pycache__/mplug_owl2.cpython-310.pyc
--------------------------------------------------------------------------------
/vlmeval/vlm/__pycache__/transcore_m.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/thecharm/BDoG/HEAD/vlmeval/vlm/__pycache__/transcore_m.cpython-310.pyc
--------------------------------------------------------------------------------
/vlmeval/vlm/__pycache__/xcomposer2.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/thecharm/BDoG/HEAD/vlmeval/vlm/__pycache__/xcomposer2.cpython-310.pyc
--------------------------------------------------------------------------------
/vlmeval/api/__pycache__/hf_chat_model.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/thecharm/BDoG/HEAD/vlmeval/api/__pycache__/hf_chat_model.cpython-310.pyc
--------------------------------------------------------------------------------
/vlmeval/evaluate/__pycache__/OCRBench.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/thecharm/BDoG/HEAD/vlmeval/evaluate/__pycache__/OCRBench.cpython-310.pyc
--------------------------------------------------------------------------------
/vlmeval/evaluate/__pycache__/__init__.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/thecharm/BDoG/HEAD/vlmeval/evaluate/__pycache__/__init__.cpython-310.pyc
--------------------------------------------------------------------------------
/vlmeval/evaluate/__pycache__/coco_eval.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/thecharm/BDoG/HEAD/vlmeval/evaluate/__pycache__/coco_eval.cpython-310.pyc
--------------------------------------------------------------------------------
/vlmeval/evaluate/__pycache__/vqa_eval.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/thecharm/BDoG/HEAD/vlmeval/evaluate/__pycache__/vqa_eval.cpython-310.pyc
--------------------------------------------------------------------------------
/vlmeval/evaluate/__pycache__/yes_or_no.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/thecharm/BDoG/HEAD/vlmeval/evaluate/__pycache__/yes_or_no.cpython-310.pyc
--------------------------------------------------------------------------------
/vlmeval/utils/__pycache__/base_prompt.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/thecharm/BDoG/HEAD/vlmeval/utils/__pycache__/base_prompt.cpython-310.pyc
--------------------------------------------------------------------------------
/vlmeval/vlm/__pycache__/internvl_chat.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/thecharm/BDoG/HEAD/vlmeval/vlm/__pycache__/internvl_chat.cpython-310.pyc
--------------------------------------------------------------------------------
/vlmeval/vlm/__pycache__/open_flamingo.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/thecharm/BDoG/HEAD/vlmeval/vlm/__pycache__/open_flamingo.cpython-310.pyc
--------------------------------------------------------------------------------
/vlmeval/evaluate/__pycache__/llavabench.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/thecharm/BDoG/HEAD/vlmeval/evaluate/__pycache__/llavabench.cpython-310.pyc
--------------------------------------------------------------------------------
/vlmeval/evaluate/__pycache__/mmvet_eval.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/thecharm/BDoG/HEAD/vlmeval/evaluate/__pycache__/mmvet_eval.cpython-310.pyc
--------------------------------------------------------------------------------
/vlmeval/utils/__pycache__/custom_prompt.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/thecharm/BDoG/HEAD/vlmeval/utils/__pycache__/custom_prompt.cpython-310.pyc
--------------------------------------------------------------------------------
/vlmeval/utils/__pycache__/dataset_config.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/thecharm/BDoG/HEAD/vlmeval/utils/__pycache__/dataset_config.cpython-310.pyc
--------------------------------------------------------------------------------
/vlmeval/utils/__pycache__/matching_util.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/thecharm/BDoG/HEAD/vlmeval/utils/__pycache__/matching_util.cpython-310.pyc
--------------------------------------------------------------------------------
/vlmeval/vlm/__pycache__/sharedcaptioner.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/thecharm/BDoG/HEAD/vlmeval/vlm/__pycache__/sharedcaptioner.cpython-310.pyc
--------------------------------------------------------------------------------
/vlmeval/evaluate/__pycache__/mathvista_eval.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/thecharm/BDoG/HEAD/vlmeval/evaluate/__pycache__/mathvista_eval.cpython-310.pyc
--------------------------------------------------------------------------------
/vlmeval/smp/.ipynb_checkpoints/__init__-checkpoint.py:
--------------------------------------------------------------------------------
1 | from .file import *
2 | from .vlm import *
3 | from .misc import *
4 | from .log import *
5 | from .lb import *
--------------------------------------------------------------------------------
/vlmeval/evaluate/__pycache__/multiple_choice.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/thecharm/BDoG/HEAD/vlmeval/evaluate/__pycache__/multiple_choice.cpython-310.pyc
--------------------------------------------------------------------------------
/vlmeval/vlm/.ipynb_checkpoints/__init__-checkpoint.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | torch.set_grad_enabled(False)
4 | torch.manual_seed(1234)
5 | from .instructblip import InstructBLIP
6 |
--------------------------------------------------------------------------------
/vlmeval/__init__.py:
--------------------------------------------------------------------------------
1 | try:
2 | import torch
3 | except ImportError:
4 | pass
5 |
6 | from .smp import *
7 | from .evaluate import *
8 | from .utils import *
9 | from .api import *
10 | from .vlm import *
11 | from .config import *
--------------------------------------------------------------------------------
/vlmeval/.ipynb_checkpoints/__init__-checkpoint.py:
--------------------------------------------------------------------------------
1 | try:
2 | import torch
3 | except ImportError:
4 | pass
5 |
6 | from .smp import *
7 | from .evaluate import *
8 | from .utils import *
9 | from .api import *
10 | from .vlm import *
11 | from .config import *
--------------------------------------------------------------------------------
/scripts/debate.sh:
--------------------------------------------------------------------------------
1 | torchrun --nproc_per_node=2 run.py --data ScienceQA_TEST \
2 | --model instructblip_13b \
3 | --stage BDebate_kg_test \
4 | --debate 2 \
5 | --kg_init
6 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | json
2 | numpy>=1.23.4
3 | requests
4 | tqdm
5 | pandas>=1.5.3
6 | gradio==4.15.0
7 | tiktoken
8 | rich
9 | portalocker
10 | torch>=2.0.1
11 | timeout-decorator
12 | transformers>=4.36.2
13 | opencv-python>=4.4.0.46
14 | typing_extensions==4.7.1
15 | pillow
16 | omegaconf
17 | matplotlib
18 | einops
19 | sentencepiece
20 | sty
21 | huggingface_hub
22 | visual_genome
23 | pycocoevalcap
24 | openpyxl
25 | seaborn
26 | tabulate
27 | xlsxwriter
--------------------------------------------------------------------------------
/vlmeval/api/__init__.py:
--------------------------------------------------------------------------------
1 | from .gpt import OpenAIWrapper, GPT4V
2 | from .gpt_int import OpenAIWrapperInternal, GPT4V_Internal
3 | from .hf_chat_model import HFChatModel
4 | from .gemini import GeminiWrapper, GeminiProVision
5 | from .qwen_vl_api import QwenVLWrapper, QwenVLAPI
6 | from .qwen_api import QwenAPI
7 |
8 | __all__ = [
9 | 'OpenAIWrapper', 'HFChatModel', 'OpenAIWrapperInternal', 'GeminiWrapper',
10 | 'GPT4V', 'GPT4V_Internal', 'GeminiProVision','QwenVLWrapper', 'QwenVLAPI',
11 | 'QwenAPI'
12 | ]
--------------------------------------------------------------------------------
/vlmeval/config.py:
--------------------------------------------------------------------------------
1 | from .vlm import *
2 | from .api import GPT4V, GeminiProVision
3 | from functools import partial
4 |
5 | models = {
6 | 'instructblip_13b': partial(InstructBLIP, name='/code/BaseModel/VLM/instructblip-vicuna-13b'),
7 | }
8 |
9 | api_models = {
10 | 'GPT4V': partial(GPT4V, model='gpt-4-vision-preview', temperature=0, img_size=512, img_detail='low', retry=10),
11 | 'GeminiProVision': partial(GeminiProVision, temperature=0, retry=10),
12 | }
13 |
14 | supported_VLM = {}
15 | for model_set in [models, api_models]:
16 | supported_VLM.update(model_set)
17 |
--------------------------------------------------------------------------------
/vlmeval/.ipynb_checkpoints/config-checkpoint.py:
--------------------------------------------------------------------------------
1 | from .vlm import *
2 | from .api import GPT4V, GeminiProVision
3 | from functools import partial
4 |
5 | models = {
6 | 'instructblip_13b': partial(InstructBLIP, name='/code/BaseModel/VLM/instructblip-vicuna-13b'),
7 | }
8 |
9 | api_models = {
10 | 'GPT4V': partial(GPT4V, model='gpt-4-vision-preview', temperature=0, img_size=512, img_detail='low', retry=10),
11 | 'GeminiProVision': partial(GeminiProVision, temperature=0, retry=10),
12 | }
13 |
14 | supported_VLM = {}
15 | for model_set in [models, api_models]:
16 | supported_VLM.update(model_set)
17 |
--------------------------------------------------------------------------------
/vlmeval/evaluate/misc.py:
--------------------------------------------------------------------------------
1 | import os
2 | from vlmeval.api import OpenAIWrapper, OpenAIWrapperInternal
3 |
4 | INTERNAL = os.environ.get('INTERNAL', 0)
5 |
6 | def build_judge(version, **kwargs):
7 | model_map = {
8 | 'gpt-4-turbo': 'gpt-4-1106-preview',
9 | 'gpt-4-0613': 'gpt-4-0613',
10 | 'gpt-4-0314': 'gpt-4-0314',
11 | 'chatgpt-1106': 'gpt-3.5-turbo-1106',
12 | 'chatgpt-0613': 'gpt-3.5-turbo-0613'
13 | }
14 | model_version = model_map[version]
15 | if INTERNAL:
16 | model = OpenAIWrapperInternal(model_version, **kwargs)
17 | else:
18 | model = OpenAIWrapper(model_version, **kwargs)
19 | return model
20 |
--------------------------------------------------------------------------------
/vlmeval/evaluate/.ipynb_checkpoints/misc-checkpoint.py:
--------------------------------------------------------------------------------
1 | import os
2 | from vlmeval.api import OpenAIWrapper, OpenAIWrapperInternal
3 |
4 | INTERNAL = os.environ.get('INTERNAL', 0)
5 |
6 | def build_judge(version, **kwargs):
7 | model_map = {
8 | 'gpt-4-turbo': 'gpt-4-1106-preview',
9 | 'gpt-4-0613': 'gpt-4-0613',
10 | 'gpt-4-0314': 'gpt-4-0314',
11 | 'chatgpt-1106': 'gpt-3.5-turbo-1106',
12 | 'chatgpt-0613': 'gpt-3.5-turbo-0613'
13 | }
14 | model_version = model_map[version]
15 | if INTERNAL:
16 | model = OpenAIWrapperInternal(model_version, **kwargs)
17 | else:
18 | model = OpenAIWrapper(model_version, **kwargs)
19 | return model
20 |
--------------------------------------------------------------------------------
/vlmeval/utils/__init__.py:
--------------------------------------------------------------------------------
1 | from .matching_util import can_infer, can_infer_option, can_infer_text
2 | from .mp_util import track_progress_rich
3 | from .custom_prompt import CustomPrompt
4 | from .dataset_config import dataset_URLs, img_root_map, DATASET_TYPE, abbr2full
5 | from .dataset import TSVDataset, split_MMMU, init_prompt_multi
6 | from .base_prompt import create_one_example
7 | from .debate import Debate_VLM
8 |
9 |
10 | __all__ = [
11 | 'can_infer', 'can_infer_option', 'can_infer_text', 'track_progress_rich',
12 | 'TSVDataset', 'dataset_URLs', 'img_root_map', 'DATASET_TYPE', 'CustomPrompt',
13 | 'split_MMMU', 'abbr2full', 'init_prompt_multi',
14 | 'create_one_example', 'Debate_VLM'
15 | ]
--------------------------------------------------------------------------------
/vlmeval/utils/.ipynb_checkpoints/__init__-checkpoint.py:
--------------------------------------------------------------------------------
1 | from .matching_util import can_infer, can_infer_option, can_infer_text
2 | from .mp_util import track_progress_rich
3 | from .custom_prompt import CustomPrompt
4 | from .dataset_config import dataset_URLs, img_root_map, DATASET_TYPE, abbr2full
5 | from .dataset import TSVDataset, split_MMMU, init_prompt_multi
6 | from .base_prompt import create_one_example
7 | from .debate import Debate_VLM
8 |
9 |
10 | __all__ = [
11 | 'can_infer', 'can_infer_option', 'can_infer_text', 'track_progress_rich',
12 | 'TSVDataset', 'dataset_URLs', 'img_root_map', 'DATASET_TYPE', 'CustomPrompt',
13 | 'split_MMMU', 'abbr2full', 'init_prompt_multi',
14 | 'create_one_example', 'Debate_VLM'
15 | ]
--------------------------------------------------------------------------------
/vlmeval/utils/custom_prompt.py:
--------------------------------------------------------------------------------
1 | from ..smp import *
2 | from .dataset_config import img_root_map
3 | from abc import abstractmethod
4 |
5 | class CustomPrompt:
6 |
7 | @abstractmethod
8 | def use_custom_prompt(self, dataset):
9 | raise NotImplementedError
10 |
11 | @abstractmethod
12 | def build_prompt(self, line, dataset):
13 | raise NotImplementedError
14 |
15 | def dump_image(self, line, dataset):
16 | ROOT = LMUDataRoot()
17 | assert isinstance(dataset, str)
18 | img_root = osp.join(ROOT, 'images', img_root_map[dataset])
19 | os.makedirs(img_root, exist_ok=True)
20 | if isinstance(line['image'], list):
21 | tgt_path = []
22 | assert 'image_path' in line
23 | for img, im_name in zip(line['image'], line['image_path']):
24 | path = osp.join(img_root, im_name)
25 | if not read_ok(path):
26 | decode_base64_to_image_file(img, path)
27 | tgt_path.append(path)
28 | else:
29 | tgt_path = osp.join(img_root, f"{line['index']}.jpg")
30 | if not read_ok(tgt_path):
31 | decode_base64_to_image_file(line['image'], tgt_path)
32 | return tgt_path
--------------------------------------------------------------------------------
/vlmeval/utils/.ipynb_checkpoints/custom_prompt-checkpoint.py:
--------------------------------------------------------------------------------
1 | from ..smp import *
2 | from .dataset_config import img_root_map
3 | from abc import abstractmethod
4 |
5 | class CustomPrompt:
6 |
7 | @abstractmethod
8 | def use_custom_prompt(self, dataset):
9 | raise NotImplementedError
10 |
11 | @abstractmethod
12 | def build_prompt(self, line, dataset):
13 | raise NotImplementedError
14 |
15 | def dump_image(self, line, dataset):
16 | ROOT = LMUDataRoot()
17 | assert isinstance(dataset, str)
18 | img_root = osp.join(ROOT, 'images', img_root_map[dataset])
19 | os.makedirs(img_root, exist_ok=True)
20 | if isinstance(line['image'], list):
21 | tgt_path = []
22 | assert 'image_path' in line
23 | for img, im_name in zip(line['image'], line['image_path']):
24 | path = osp.join(img_root, im_name)
25 | if not read_ok(path):
26 | decode_base64_to_image_file(img, path)
27 | tgt_path.append(path)
28 | else:
29 | tgt_path = osp.join(img_root, f"{line['index']}.jpg")
30 | if not read_ok(tgt_path):
31 | decode_base64_to_image_file(line['image'], tgt_path)
32 | return tgt_path
--------------------------------------------------------------------------------
/vlmeval/smp/log.py:
--------------------------------------------------------------------------------
1 | import logging
2 |
3 | logger_initialized = {}
4 |
5 | def get_logger(name, log_file=None, log_level=logging.INFO, file_mode='a'):
6 | logger = logging.getLogger(name)
7 | if name in logger_initialized:
8 | return logger
9 |
10 | for logger_name in logger_initialized:
11 | if name.startswith(logger_name):
12 | return logger
13 |
14 | stream_handler = logging.StreamHandler()
15 | handlers = [stream_handler]
16 |
17 | try:
18 | import torch.distributed as dist
19 | if dist.is_available() and dist.is_initialized():
20 | rank = dist.get_rank()
21 | else:
22 | rank = 0
23 | except ImportError:
24 | rank = 0
25 |
26 | if rank == 0 and log_file is not None:
27 | file_handler = logging.FileHandler(log_file, file_mode)
28 | handlers.append(file_handler)
29 |
30 | formatter = logging.Formatter(
31 | '%(asctime)s - %(name)s - %(levelname)s - %(message)s')
32 | for handler in handlers:
33 | handler.setFormatter(formatter)
34 | handler.setLevel(log_level)
35 | logger.addHandler(handler)
36 |
37 | if rank == 0:
38 | logger.setLevel(log_level)
39 | else:
40 | logger.setLevel(logging.ERROR)
41 |
42 | logger_initialized[name] = True
43 | return logger
--------------------------------------------------------------------------------
/vlmeval/smp/.ipynb_checkpoints/log-checkpoint.py:
--------------------------------------------------------------------------------
1 | import logging
2 |
3 | logger_initialized = {}
4 |
5 | def get_logger(name, log_file=None, log_level=logging.INFO, file_mode='a'):
6 | logger = logging.getLogger(name)
7 | if name in logger_initialized:
8 | return logger
9 |
10 | for logger_name in logger_initialized:
11 | if name.startswith(logger_name):
12 | return logger
13 |
14 | stream_handler = logging.StreamHandler()
15 | handlers = [stream_handler]
16 |
17 | try:
18 | import torch.distributed as dist
19 | if dist.is_available() and dist.is_initialized():
20 | rank = dist.get_rank()
21 | else:
22 | rank = 0
23 | except ImportError:
24 | rank = 0
25 |
26 | if rank == 0 and log_file is not None:
27 | file_handler = logging.FileHandler(log_file, file_mode)
28 | handlers.append(file_handler)
29 |
30 | formatter = logging.Formatter(
31 | '%(asctime)s - %(name)s - %(levelname)s - %(message)s')
32 | for handler in handlers:
33 | handler.setFormatter(formatter)
34 | handler.setLevel(log_level)
35 | logger.addHandler(handler)
36 |
37 | if rank == 0:
38 | logger.setLevel(log_level)
39 | else:
40 | logger.setLevel(logging.ERROR)
41 |
42 | logger_initialized[name] = True
43 | return logger
--------------------------------------------------------------------------------
/vlmeval/vlm/instructblip.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from PIL import Image
3 | from abc import abstractproperty
4 | import os.path as osp
5 | import random
6 | import os, sys
7 | from ..smp import *
8 | from transformers import InstructBlipProcessor, InstructBlipForConditionalGeneration
9 |
10 |
11 | class InstructBLIP:
12 |
13 | INSTALL_REQ = True
14 |
15 | def __init__(self, name):
16 |
17 | model = InstructBlipForConditionalGeneration.from_pretrained(name, torch_dtype=torch.float16, device_map='cuda').eval()
18 | processor = InstructBlipProcessor.from_pretrained(name)
19 | self.processors = processor
20 | self.model = model
21 |
22 | def generate(self, prompt, image_path, dataset=None, max_length=100):
23 | raw_image = Image.open(image_path).convert('RGB')
24 | inputs = self.processors(images=raw_image, text=prompt, return_tensors="pt").to("cuda")
25 | try:
26 | outputs = self.model.generate(
27 | **inputs,
28 | do_sample=False,
29 | num_beams=5,
30 | max_length=max_length,
31 | min_length=1,
32 | top_p=0.9,
33 | repetition_penalty=1.5,
34 | length_penalty=1.0,
35 | temperature=1)
36 | generated_text = self.processors.batch_decode(outputs, skip_special_tokens=True)[0].strip()
37 | except Exception as e:
38 | generated_text = random.choice(['A','B','C','D','E'])
39 | print("Prompt_Error: ", e)
40 |
41 |
42 | return generated_text
43 |
--------------------------------------------------------------------------------
/vlmeval/vlm/.ipynb_checkpoints/instructblip-checkpoint.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from PIL import Image
3 | from abc import abstractproperty
4 | import os.path as osp
5 | import random
6 | import os, sys
7 | from ..smp import *
8 | from transformers import InstructBlipProcessor, InstructBlipForConditionalGeneration
9 |
10 |
11 | class InstructBLIP:
12 |
13 | INSTALL_REQ = True
14 |
15 | def __init__(self, name):
16 |
17 | model = InstructBlipForConditionalGeneration.from_pretrained(name, torch_dtype=torch.float16, device_map='cuda').eval()
18 | processor = InstructBlipProcessor.from_pretrained(name)
19 | self.processors = processor
20 | self.model = model
21 |
22 | def generate(self, prompt, image_path, dataset=None, max_length=100):
23 | raw_image = Image.open(image_path).convert('RGB')
24 | inputs = self.processors(images=raw_image, text=prompt, return_tensors="pt").to("cuda")
25 | try:
26 | outputs = self.model.generate(
27 | **inputs,
28 | do_sample=False,
29 | num_beams=5,
30 | max_length=max_length,
31 | min_length=1,
32 | top_p=0.9,
33 | repetition_penalty=1.5,
34 | length_penalty=1.0,
35 | temperature=1)
36 | generated_text = self.processors.batch_decode(outputs, skip_special_tokens=True)[0].strip()
37 | except Exception as e:
38 | generated_text = random.choice(['A','B','C','D','E'])
39 | print("Prompt_Error: ", e)
40 |
41 |
42 | return generated_text
43 |
--------------------------------------------------------------------------------
/vlmeval/utils/matching_util.py:
--------------------------------------------------------------------------------
1 | import string
2 | import copy as cp
3 | import os
4 | from ..smp import *
5 |
6 | def can_infer_option(answer, choices):
7 | verbose = os.environ.get('VERBOSE', 0)
8 | # Choices is a dictionary
9 | if 'Failed to obtain answer via API' in answer:
10 | return False
11 |
12 | reject_to_answer = [
13 | "Sorry, I can't help with images of people yet.",
14 | "I can't process this file.",
15 | "I'm sorry, but without the image provided",
16 | "Cannot determine the answer"
17 | ]
18 | for err in reject_to_answer:
19 | if err in answer:
20 | return 'Z'
21 |
22 | def count_choice(splits, choices, prefix='', suffix=''):
23 | cnt = 0
24 | for c in choices:
25 | if prefix + c + suffix in splits:
26 | cnt += 1
27 | return cnt
28 |
29 | answer_mod = cp.copy(answer)
30 | chars = '.()[],:;!*#{}'
31 | for c in chars:
32 | answer_mod = answer_mod.replace(c, ' ')
33 |
34 | splits = [x.strip() for x in answer_mod.split()]
35 | count = count_choice(splits, choices)
36 |
37 | if count == 1:
38 | for ch in choices:
39 | if 'A' in splits and len(splits) > 3 and verbose:
40 | logger = get_logger('Evaluation')
41 | logger.info(f'A might be a quantifier in the string: {answer}.')
42 | return False
43 | if ch in splits:
44 | return ch
45 | elif count == 0 and count_choice(splits, {'Z', ''}) == 1:
46 | return 'Z'
47 | return False
48 |
49 | def can_infer_text(answer, choices):
50 | answer = answer.lower()
51 | assert isinstance(choices, dict)
52 | for k in choices:
53 | assert k in string.ascii_uppercase
54 | choices[k] = str(choices[k]).lower()
55 | cands = []
56 | for k in choices:
57 | if choices[k] in answer:
58 | cands.append(k)
59 | if len(cands) == 1:
60 | return cands[0]
61 | return False
62 |
63 | def can_infer(answer, choices):
64 | answer = str(answer)
65 | copt = can_infer_option(answer, choices)
66 | return copt if copt else can_infer_text(answer, choices)
--------------------------------------------------------------------------------
/vlmeval/utils/.ipynb_checkpoints/matching_util-checkpoint.py:
--------------------------------------------------------------------------------
1 | import string
2 | import copy as cp
3 | import os
4 | from ..smp import *
5 |
6 | def can_infer_option(answer, choices):
7 | verbose = os.environ.get('VERBOSE', 0)
8 | # Choices is a dictionary
9 | if 'Failed to obtain answer via API' in answer:
10 | return False
11 |
12 | reject_to_answer = [
13 | "Sorry, I can't help with images of people yet.",
14 | "I can't process this file.",
15 | "I'm sorry, but without the image provided",
16 | "Cannot determine the answer"
17 | ]
18 | for err in reject_to_answer:
19 | if err in answer:
20 | return 'Z'
21 |
22 | def count_choice(splits, choices, prefix='', suffix=''):
23 | cnt = 0
24 | for c in choices:
25 | if prefix + c + suffix in splits:
26 | cnt += 1
27 | return cnt
28 |
29 | answer_mod = cp.copy(answer)
30 | chars = '.()[],:;!*#{}'
31 | for c in chars:
32 | answer_mod = answer_mod.replace(c, ' ')
33 |
34 | splits = [x.strip() for x in answer_mod.split()]
35 | count = count_choice(splits, choices)
36 |
37 | if count == 1:
38 | for ch in choices:
39 | if 'A' in splits and len(splits) > 3 and verbose:
40 | logger = get_logger('Evaluation')
41 | logger.info(f'A might be a quantifier in the string: {answer}.')
42 | return False
43 | if ch in splits:
44 | return ch
45 | elif count == 0 and count_choice(splits, {'Z', ''}) == 1:
46 | return 'Z'
47 | return False
48 |
49 | def can_infer_text(answer, choices):
50 | answer = answer.lower()
51 | assert isinstance(choices, dict)
52 | for k in choices:
53 | assert k in string.ascii_uppercase
54 | choices[k] = str(choices[k]).lower()
55 | cands = []
56 | for k in choices:
57 | if choices[k] in answer:
58 | cands.append(k)
59 | if len(cands) == 1:
60 | return cands[0]
61 | return False
62 |
63 | def can_infer(answer, choices):
64 | answer = str(answer)
65 | copt = can_infer_option(answer, choices)
66 | return copt if copt else can_infer_text(answer, choices)
--------------------------------------------------------------------------------
/vlmeval/api/base.py:
--------------------------------------------------------------------------------
1 | import time
2 | import random as rd
3 | from abc import abstractmethod
4 | from ..smp import get_logger
5 |
6 | class BaseAPI:
7 |
8 | def __init__(self,
9 | retry=10,
10 | wait=3,
11 | system_prompt=None,
12 | verbose=True,
13 | fail_msg='Failed to obtain answer via API.',
14 | **kwargs):
15 | self.wait = wait
16 | self.retry = retry
17 | self.system_prompt = system_prompt
18 | self.kwargs = kwargs
19 | self.verbose = verbose
20 | self.fail_msg = fail_msg
21 | self.logger = get_logger('ChatAPI')
22 | if len(kwargs):
23 | self.logger.info(f'BaseAPI received the following kwargs: {kwargs}')
24 | self.logger.info(f'Will try to use them as kwargs for `generate`. ')
25 |
26 | @abstractmethod
27 | def generate_inner(self, inputs, **kwargs):
28 | self.logger.warning(f'For APIBase, generate_inner is an abstract method. ')
29 | assert 0, 'generate_inner not defined'
30 | ret_code, answer, log = None, None, None
31 | # if ret_code is 0, means succeed
32 | return ret_code, answer, log
33 |
34 | def generate(self, inputs, **kwargs):
35 | input_type = None
36 | if isinstance(inputs, str):
37 | input_type = 'str'
38 | elif isinstance(inputs, list) and isinstance(inputs[0], str):
39 | input_type = 'strlist'
40 | elif isinstance(inputs, list) and isinstance(inputs[0], dict):
41 | input_type = 'dictlist'
42 | assert input_type is not None, input_type
43 |
44 | answer = None
45 | # a very small random delay [0s - 0.5s]
46 | T = rd.random() * 0.5
47 | time.sleep(T)
48 |
49 | for i in range(self.retry):
50 | try:
51 | ret_code, answer, log = self.generate_inner(inputs, **kwargs)
52 | if ret_code == 0 and self.fail_msg not in answer and answer != '':
53 | if self.verbose:
54 | print(answer)
55 | return answer
56 | elif self.verbose:
57 | self.logger.info(f"RetCode: {ret_code}\nAnswer: {answer}\nLog: {log}")
58 | except Exception as err:
59 | if self.verbose:
60 | self.logger.error(f'An error occurred during try {i}:')
61 | self.logger.error(err)
62 | # delay before each retry
63 | T = rd.random() * self.wait * 2
64 | time.sleep(T)
65 |
66 | return self.fail_msg if answer in ['', None] else answer
67 |
--------------------------------------------------------------------------------
/vlmeval/api/qwen_api.py:
--------------------------------------------------------------------------------
1 | from http import HTTPStatus
2 | import os
3 | from vlmeval.api.base import BaseAPI
4 | from vlmeval.smp import *
5 |
6 | class QwenAPI(BaseAPI):
7 |
8 | is_api: bool = True
9 |
10 | def __init__(self,
11 | model: str = 'qwen-max-1201',
12 | retry: int = 5,
13 | wait: int = 5,
14 | verbose: bool = True,
15 | seed: int = 2680,
16 | temperature: float = 0.0,
17 | system_prompt: str = None,
18 | key: str = None,
19 | max_tokens: int = 1024,
20 | proxy: str = None,
21 | **kwargs):
22 |
23 | assert model in ['qwen-turbo', 'qwen-plus', 'qwen-max', 'qwen-max-1201', 'qwen-max-longcontext']
24 | self.model = model
25 | import dashscope
26 | self.fail_msg = 'Failed to obtain answer via API. '
27 | self.max_tokens = max_tokens
28 | self.temperature = temperature
29 | self.seed = seed
30 | if key is None:
31 | key = os.environ.get('DASHSCOPE_API_KEY', None)
32 | assert key is not None, "Please set the API Key (obtain it here: https://help.aliyun.com/zh/dashscope/developer-reference/vl-plus-quick-start)"
33 | dashscope.api_key = key
34 | if proxy is not None:
35 | proxy_set(proxy)
36 | super().__init__(wait=wait, retry=retry, system_prompt=system_prompt, verbose=verbose, **kwargs)
37 |
38 | @staticmethod
39 | def build_msgs(msgs_raw, system_prompt=None):
40 | msgs = cp.deepcopy(msgs_raw)
41 | ret = []
42 | if system_prompt is not None:
43 | ret.append(dict(role='system', content=system_prompt))
44 | for i, msg in enumerate(msgs):
45 | role = 'user' if i % 2 == 0 else 'assistant'
46 | ret.append(dict(role=role, content=msg))
47 | return ret
48 |
49 | def generate_inner(self, inputs, **kwargs) -> str:
50 | from dashscope import MultiModalConversation
51 | assert isinstance(inputs, str) or isinstance(inputs, list)
52 | inputs = [inputs] if isinstance(inputs, str) else inputs
53 | messages = self.build_msgs(msgs_raw=inputs, system_prompt=self.system_prompt)
54 |
55 | import dashscope
56 | response = dashscope.Generation.call(
57 | model=self.model,
58 | messages=messages,
59 | seed=self.seed,
60 | temperature=self.temperature,
61 | max_tokens=self.max_tokens,
62 | result_format='message', # set the result to be "message" format.
63 | )
64 | if response.status_code != HTTPStatus.OK:
65 | return -1, 'Error: Bad Response Status Code. ', f'The response status code is {response.status_code}. '
66 |
67 | try:
68 | return 0, response['output']['choices'][0]['message']['content'].strip(), 'Succeeded! '
69 | except Exception as err:
70 | return -1, f'Error: Failed to parse the response. {err}', response
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 | # A Picture Is Worth a Graph: A Blueprint Debate Paradigm for Multimodal Reasoning
5 |
6 | [Changmeng Zheng](https://github.com/thecharm)<sup>1</sup>, [Dayong Liang](https://github.com/YongLD)<sup>2</sup>, [Wengyu Zhang](https://github.com/zhangwengyu999)<sup>1</sup>, [Xiao-Yong Wei](https://scholar.google.com/citations?user=8kxWTokAAAAJ&hl=en)<sup>*,1</sup>, [Tat-Seng Chua](https://scholar.google.com.sg/citations?user=Z9DWCBEAAAAJ&hl=en)<sup>3</sup>, [Qing Li](https://scholar.google.com/citations?user=D1LEg-YAAAAJ&hl=en)<sup>1</sup>
7 |
8 | <sup>1</sup>The Hong Kong Polytechnic University, <sup>2</sup>South China University of Technology, <sup>3</sup>National University of Singapore
9 | <sup>*</sup>Corresponding author
10 |
11 |
12 | [arXiv](https://arxiv.org/pdf/2403.14972)
13 |
14 |
15 |
16 | 
17 |
18 |
19 | Blueprint Debate-on-Graph (BDoG)
20 |
21 | ## 🔥News
22 |
23 | 🔥 __[2024.10]__ Our paper has been nominated for the Best Paper Award!\
24 | 🔥 __[2024.07]__ The paper and code are released!
25 |
26 | ## 🚀 Method
27 |
28 | 
29 |
30 |
31 | ## 🏗️ QuickStart
32 | ### 1. Installation
33 | ```bash
34 | git clone https://github.com/thecharm/BDoG.git
35 | cd BDoG
36 | pip install -e .
37 | ```
38 | ### 2. Download model weights
39 | Download the [model weights](https://huggingface.co/Salesforce/instructblip-vicuna-13b) and set the model path in the `BDoG/vlmeval/config.py` file (see the sketch below).
40 |
41 |
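For reference, the relevant entry in `vlmeval/config.py` is sketched below; the path is only a placeholder, so point `name` at wherever you stored the downloaded weights:

```python
# Sketch of the edit in BDoG/vlmeval/config.py. The path below is a
# placeholder -- replace it with your local instructblip-vicuna-13b folder.
models = {
    'instructblip_13b': partial(InstructBLIP, name='/path/to/instructblip-vicuna-13b'),
}
```
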
42 | ### 3. Running
43 | ```bash
44 | torchrun --nproc_per_node=1 run.py --data ScienceQA_TEST \
45 | --stage BDebate \
46 | --debate 2
47 | ```
48 | + `--data`
49 | + Dataset supported: `ScienceQA_TEST` and `MMBench_DEV_EN`.
50 | + `--stage`
51 | + Prompt Type: `BDebate` (Blueprint Debate on Graph) or `ODebate` (Debate without Graph).
52 | + `--debate`
53 | + Number of rounds for the debate.
54 | + `--kg_init`
55 | + (optional) Use Gemini Graph as the initialization for multi-round debates.
56 | + `--nproc_per_node=2`
57 | + (optional) Speed up the inference process if you have two GPUs.
58 | + `--openai`
59 | + (optional) Use the OpenAI API key to perform the final result validation.
60 |
61 | The results are saved in the `BDoG/results/instructblip_13b` folder.
62 |
63 | During this process, the datasets will be automatically downloaded to the `/root/LMUData/` directory. If you need to change the data storage path, set it with `--lmudata`.
64 |
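Putting the optional flags together, a combined run might look like the sketch below. The exact flag handling lives in `run.py`, which is not included here, so treat the values as illustrative:

```bash
# Illustrative: two GPUs, a Gemini-initialized blueprint graph, and a custom
# data root. Adjust the paths to your environment.
torchrun --nproc_per_node=2 run.py --data ScienceQA_TEST \
         --model instructblip_13b \
         --stage BDebate \
         --debate 2 \
         --kg_init \
         --lmudata /path/to/LMUData
```
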
65 | ## ❤️ Acknowledgments
66 | - [VLMEvalKit](https://github.com/open-compass/VLMEvalKit): An open-source evaluation toolkit of large vision-language models (LVLMs).
67 | - [LLaVA](https://github.com/haotian-liu/LLaVA): A wonderful MLLM based on the Large Language and Vision Assistant.
68 | - [LAVIS](https://github.com/salesforce/LAVIS): An amazing open-source multimodal learning codebase.
69 |
70 |
71 | ## 📑 Citation
72 |
73 | If this repo is useful to you, please cite using this BibTeX.
74 | ```bibtex
75 | @inproceedings{zheng2024picture,
76 | title={A Picture Is Worth a Graph: A Blueprint Debate Paradigm for Multimodal Reasoning},
77 | author={Zheng, Changmeng and Liang, Dayong and Zhang, Wengyu and Wei, Xiao-Yong and Chua, Tat-Seng and Li, Qing},
78 | booktitle={Proceedings of the 32nd ACM International Conference on Multimedia},
79 | pages={419--428},
80 | year={2024}
81 | }
82 | ```
83 |
--------------------------------------------------------------------------------
/vlmeval/api/qwen_vl_api.py:
--------------------------------------------------------------------------------
1 | from vlmeval.smp import *
2 | from vlmeval.api.base import BaseAPI
3 |
4 | class QwenVLWrapper(BaseAPI):
5 |
6 | is_api: bool = True
7 |
8 | def __init__(self,
9 | model: str = 'qwen-vl-plus',
10 | retry: int = 5,
11 | wait: int = 5,
12 | key: str = None,
13 | verbose: bool = True,
14 | temperature: float = 0.0,
15 | system_prompt: str = None,
16 | max_tokens: int = 1024,
17 | proxy: str = None,
18 | **kwargs):
19 |
20 | assert model in ['qwen-vl-plus', 'qwen-vl-max']
21 | self.model = model
22 | import dashscope
23 | self.fail_msg = 'Failed to obtain answer via API. '
24 | self.max_tokens = max_tokens
25 | self.temperature = temperature
26 | if key is None:
27 | key = os.environ.get('DASHSCOPE_API_KEY', None)
28 | assert key is not None, "Please set the API Key (obtain it here: https://help.aliyun.com/zh/dashscope/developer-reference/vl-plus-quick-start)"
29 | dashscope.api_key = key
30 | if proxy is not None:
31 | proxy_set(proxy)
32 | super().__init__(wait=wait, retry=retry, system_prompt=system_prompt, verbose=verbose, **kwargs)
33 |
34 | @staticmethod
35 | def build_msgs(msgs_raw, system_prompt=None):
36 | msgs = cp.deepcopy(msgs_raw)
37 | ret = []
38 | if system_prompt is not None:
39 | content = [dict(text=system_prompt)]
40 | ret.append(dict(role='system', content=content))
41 | content = []
42 | for i, msg in enumerate(msgs):
43 | if osp.exists(msg):
44 | content.append(dict(image='file://' + msg))
45 | elif msg.startswith('http'):
46 | content.append(dict(image=msg))
47 | else:
48 | content.append(dict(text=msg))
49 | ret.append(dict(role='user', content=content))
50 | return ret
51 |
52 | def generate_inner(self, inputs, **kwargs) -> str:
53 | from dashscope import MultiModalConversation
54 | assert isinstance(inputs, str) or isinstance(inputs, list)
55 | pure_text = True
56 | if isinstance(inputs, list):
57 | for pth in inputs:
58 | if osp.exists(pth) or pth.startswith('http'):
59 | pure_text = False
60 | assert not pure_text
61 | messages = self.build_msgs(msgs_raw=inputs, system_prompt=self.system_prompt)
62 | gen_config = dict(max_output_tokens=self.max_tokens, temperature=self.temperature)
63 | gen_config.update(self.kwargs)
64 | try:
65 | response = MultiModalConversation.call(model=self.model, messages=messages)
66 | if self.verbose:
67 | print(response)
68 | answer = response.output.choices[0]['message']['content'][0]['text']
69 | return 0, answer, 'Succeeded! '
70 | except Exception as err:
71 | if self.verbose:
72 | self.logger.error(err)
73 | self.logger.error(f"The input messages are {inputs}.")
74 |
75 | return -1, '', ''
76 |
77 | class QwenVLAPI(QwenVLWrapper):
78 |
79 | def generate(self, image_path, prompt, dataset=None):
80 | return super(QwenVLAPI, self).generate([image_path, prompt])
81 |
82 | def interleave_generate(self, ti_list, dataset=None):
83 | return super(QwenVLAPI, self).generate(ti_list)
84 |
--------------------------------------------------------------------------------
/vlmeval/api/gemini.py:
--------------------------------------------------------------------------------
1 | from vlmeval.smp import *
2 | from vlmeval.api.base import BaseAPI
3 |
4 | headers = 'Content-Type: application/json'
5 |
6 | class GeminiWrapper(BaseAPI):
7 |
8 | is_api: bool = True
9 |
10 | def __init__(self,
11 | retry: int = 5,
12 | wait: int = 5,
13 | key: str = None,
14 | verbose: bool = True,
15 | temperature: float = 0.0,
16 | system_prompt: str = None,
17 | max_tokens: int = 1024,
18 | proxy: str = None,
19 | **kwargs):
20 |
21 | self.fail_msg = 'Failed to obtain answer via API. '
22 | self.max_tokens = max_tokens
23 | self.temperature = temperature
24 | if key is None:
25 | key = os.environ.get('GOOGLE_API_KEY', None)
26 | assert key is not None
27 | self.api_key = key
28 | if proxy is not None:
29 | proxy_set(proxy)
30 | super().__init__(wait=wait, retry=retry, system_prompt=system_prompt, verbose=verbose, **kwargs)
31 |
32 | @staticmethod
33 | def build_msgs(msgs_raw, system_prompt=None):
34 | msgs = cp.deepcopy(msgs_raw)
35 | assert len(msgs) % 2 == 1
36 |
37 | if system_prompt is not None:
38 | msgs[0] = [system_prompt, msgs[0]]
39 | ret = []
40 | for i, msg in enumerate(msgs):
41 | role = 'user' if i % 2 == 0 else 'model'
42 | parts = msg if isinstance(msg, list) else [msg]
43 | ret.append(dict(role=role, parts=parts))
44 | return ret
45 |
46 | def generate_inner(self, inputs, **kwargs) -> str:
47 | import google.generativeai as genai
48 | assert isinstance(inputs, str) or isinstance(inputs, list)
49 | pure_text = True
50 | if isinstance(inputs, list):
51 | for pth in inputs:
52 | if osp.exists(pth) or pth.startswith('http'):
53 | pure_text = False
54 | genai.configure(api_key=self.api_key)
55 | model = genai.GenerativeModel('gemini-pro') if pure_text else genai.GenerativeModel('gemini-pro-vision')
56 | if isinstance(inputs, str):
57 | messages = [inputs] if self.system_prompt is None else [self.system_prompt, inputs]
58 | elif pure_text:
59 | messages = self.build_msgs(inputs, self.system_prompt)
60 | else:
61 | messages = [] if self.system_prompt is None else [self.system_prompt]
62 | for s in inputs:
63 | if osp.exists(s):
64 | messages.append(Image.open(s))
65 | elif s.startswith('http'):
66 | pth = download_file(s)
67 | messages.append(Image.open(pth))
68 |                     os.remove(pth)  # delete the temporary local copy of the remote image
69 | else:
70 | messages.append(s)
71 | gen_config = dict(max_output_tokens=self.max_tokens, temperature=self.temperature)
72 | gen_config.update(self.kwargs)
73 | try:
74 | answer = model.generate_content(messages, generation_config=genai.types.GenerationConfig(**gen_config)).text
75 | return 0, answer, 'Succeeded! '
76 | except Exception as err:
77 | if self.verbose:
78 | self.logger.error(err)
79 | self.logger.error(f"The input messages are {inputs}.")
80 |
81 | return -1, '', ''
82 |
83 |
84 |
85 | class GeminiProVision(GeminiWrapper):
86 |
87 | def generate(self, image_path, prompt, dataset=None):
88 | return super(GeminiProVision, self).generate([image_path, prompt])
89 |
90 | def interleave_generate(self, ti_list, dataset=None):
91 | return super(GeminiProVision, self).generate(ti_list)
--------------------------------------------------------------------------------
/run.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.distributed as dist
3 | from vlmeval.smp import *
4 | from vlmeval.evaluate import multiple_choice_eval
5 | from vlmeval.inference_multi import infer_data_job, prefetch_acc
6 | from vlmeval.config import supported_VLM
7 | from vlmeval.utils import dataset_URLs, abbr2full
8 | import json
9 |
10 | def parse_args():
11 | parser = argparse.ArgumentParser()
12 | parser.add_argument('--data', type=str, nargs='+', required=True)
13 |     parser.add_argument("--model", type=str, nargs='+', default=['instructblip_13b'], required=False)
14 | parser.add_argument("--lmudata", type=str, default='', required=False)
15 | parser.add_argument("--openai", type=str, default='', required=False)
16 | parser.add_argument("--stage", type=str, default='BDebate', required=True)
17 | parser.add_argument("--nproc", type=int, default=4, help="Parallel API calling")
18 | parser.add_argument("--debate", type=int, default=2, required=True)
19 | parser.add_argument("--ignore", action='store_true', help="Ignore failed indices. ")
20 | parser.add_argument("--verbose", action='store_true')
21 | parser.add_argument("--prefetch", action='store_true')
22 | parser.add_argument("--kg_init", action='store_true')
23 | args = parser.parse_args()
24 | return args
25 |
26 | def main():
27 | args = parse_args()
28 | assert len(args.data), "--data should be a list of data files"
29 | init_environ(args)
30 |
31 | rank, world_size = get_rank_and_world_size()
32 | if world_size > 1:
33 | torch.cuda.set_device(rank)
34 | dist.init_process_group(backend='nccl', timeout=datetime.timedelta(seconds=5400))
35 |
36 | for _, model_name in enumerate(args.model):
37 | model = None
38 | pred_root = model_name
39 |
40 | for i, dataset_name in enumerate(args.data):
41 | os.makedirs(f'results/{model_name}/{dataset_name}', exist_ok=True)
42 |
43 |             logger_file = "results/{}/{}/{}_{}_{}_log.txt".format(model_name, dataset_name, model_name, dataset_name, args.stage)
44 |             logger = get_logger(name='Multi', log_file=logger_file)
45 |             logger.info(f"####- Begin -####\n{model_name}: {dataset_name}-{args.stage}")
46 |
47 | if dataset_name not in dataset_URLs:
48 | dataset_name = abbr2full(dataset_name)
49 |
50 | if dataset_name not in dataset_URLs:
51 | logger.error(f'Unknown dataset: {dataset_name}. ')
52 | continue
53 |
54 | result_file = f'results/{pred_root}/{dataset_name}/{model_name}_{dataset_name}_{args.stage}_DB{args.debate}.xlsx'
55 |
56 | if model is None:
57 | model = model_name # which is only a name
58 |
59 | model = infer_data_job(model, model_name=model_name, dataset_name=dataset_name, args=args, logger=logger, ignore_failed=args.ignore)
60 |
61 | if rank == 0:
62 | time.sleep(3)
63 | res = None
64 | if listinstr(['MMBench'], dataset_name):
65 | res = prefetch_acc(result_file)
66 | else:
67 | logger.warning(f'{dataset_name} is not handled by prefetch score calculator')
68 | if res is not None:
69 | logger.info(f'{model_name} prefetching: ')
70 | logger.info(res)
71 | dump(res, result_file.replace('.xlsx', '_prefetch.xlsx'))
72 |
73 | if listinstr(['MMBench','ScienceQA'], dataset_name):
74 | multiple_choice_eval(result_file, dataset=dataset_name, model='chatgpt-0613', nproc=args.nproc, verbose=args.verbose)
75 | else:
76 | logger.error(f'Dataset {dataset_name} is not handled by evaluator, will be skipped. ')
77 |
78 | if __name__ == '__main__':
79 | main()
80 |
--------------------------------------------------------------------------------
/vlmeval/api/gpt_int.py:
--------------------------------------------------------------------------------
1 | import json
2 | import warnings
3 | import requests
4 | from ..smp import *
5 | from .gpt import GPT_context_window, OpenAIWrapper
6 |
7 |
8 | url = "http://ecs.sv.us.alles-apin.openxlab.org.cn/v1/openai/v2/text/chat"
9 | headers = {
10 | "Content-Type": "application/json"
11 | }
12 |
13 | class OpenAIWrapperInternal(OpenAIWrapper):
14 |
15 | is_api: bool = True
16 |
17 | def __init__(self,
18 | model: str = 'gpt-3.5-turbo-0613',
19 | retry: int = 5,
20 | wait: int = 3,
21 | verbose: bool = True,
22 | system_prompt: str = None,
23 | temperature: float = 0,
24 | timeout: int = 60,
25 | max_tokens: int = 1024,
26 | img_size: int = 512,
27 | img_detail: str = 'low',
28 | **kwargs):
29 |
30 | self.model = model
31 | if 'KEYS' in os.environ and osp.exists(os.environ['KEYS']):
32 | keys = load(os.environ['KEYS'])
33 | headers['alles-apin-token'] = keys.get('alles-apin-token', '')
34 | elif 'ALLES' in os.environ:
35 | headers['alles-apin-token'] = os.environ['ALLES']
36 | self.headers = headers
37 | self.temperature = temperature
38 | self.timeout = timeout
39 | self.max_tokens = max_tokens
40 |
41 | assert img_size > 0 or img_size == -1
42 | self.img_size = img_size
43 | assert img_detail in ['high', 'low']
44 | self.img_detail = img_detail
45 |
46 | self.vision = False
47 | if model == 'gpt-4-vision-preview':
48 | self.vision = True
49 |
50 | super(OpenAIWrapper, self).__init__(
51 | wait=wait, retry=retry, system_prompt=system_prompt, verbose=verbose, **kwargs)
52 |
53 | def generate_inner(self, inputs, **kwargs) -> str:
54 | input_msgs = self.prepare_inputs(inputs)
55 |
56 | temperature = kwargs.pop('temperature', self.temperature)
57 | max_tokens = kwargs.pop('max_tokens', self.max_tokens)
58 |
59 | # Held out 100 tokens as buffer
60 | context_window = GPT_context_window(self.model)
61 | max_tokens = min(max_tokens, context_window - self.get_token_len(inputs))
62 | if 0 < max_tokens <= 100:
63 | print('Less than 100 tokens left, may exceed the context window with some additional meta symbols. ')
64 | if max_tokens <= 0:
65 | return 0, self.fail_msg + 'Input string longer than context window. ', 'Length Exceeded. '
66 |
67 | payload = dict(
68 | model=self.model,
69 | messages=input_msgs,
70 | max_tokens=max_tokens,
71 | n=1,
72 | stop=None,
73 | timeout=self.timeout,
74 | temperature=temperature,
75 | **kwargs)
76 |
77 | response = requests.post(url, headers=headers, data=json.dumps(payload), timeout=self.timeout * 1.1)
78 | ret_code = response.status_code
79 | ret_code = 0 if (200 <= int(ret_code) < 300) else ret_code
80 |
81 | answer = self.fail_msg
82 | try:
83 | resp_struct = json.loads(response.text)
84 | assert resp_struct['msg'] == 'ok' and resp_struct['msgCode'] == '10000', resp_struct
85 | answer = resp_struct['data']['choices'][0]['message']['content'].strip()
86 | except:
87 | pass
88 | return ret_code, answer, response
89 |
90 |
91 | class GPT4V_Internal(OpenAIWrapperInternal):
92 |
93 | def generate(self, image_path, prompt, dataset=None):
94 | assert self.model == 'gpt-4-vision-preview'
95 | return super(GPT4V_Internal, self).generate([image_path, prompt])
96 |
97 | def interleave_generate(self, ti_list, dataset=None):
98 | assert self.model == 'gpt-4-vision-preview'
99 | return super(GPT4V_Internal, self).generate(ti_list)
--------------------------------------------------------------------------------
/vlmeval/utils/debate.py:
--------------------------------------------------------------------------------
1 | from vlmeval.utils import init_prompt_multi
2 |
3 | def Debate_VLM(stage, model, struct, dataset_name, debate, kg_init, logger):
4 | if stage[:8] == "baseline":
5 | prompt_format = "IQ-A"
6 | prompt_G = init_prompt_multi(struct, prompt_format)
7 | response = model.generate(prompt=prompt_G, image_path=struct['image'], dataset=dataset_name)
8 | logger.info("########--G_A--######\nPrompt: {}\nGT: {} - ANS: {}".format(prompt_G, struct['text']['answer'], response))
9 |
10 | elif stage[:7] == "ODebate":
11 | for debate_ in range(debate):
12 | logger.info("########--DEBATE{}--######".format(debate_))
13 |
14 | prompt_format = "ODIM-S" if debate_==0 else "ODQIM-S"
15 | prompt_A = init_prompt_multi(struct, prompt_format)
16 | kg_aff = model.generate(prompt=prompt_A, image_path=struct['image'], dataset=dataset_name)
17 |
18 | prompt_format = "ONIM-S" if debate_==0 else "ONQIM-S"
19 | prompt_N = init_prompt_multi(struct, prompt_format)
20 | kg_neg = model.generate(prompt=prompt_N, image_path=struct['image'], dataset=dataset_name)
21 |
22 | struct['kg'] = [kg_aff.strip(), kg_neg.strip()]
23 | prompt_format = "OAGM-A"
24 | prompt_F = init_prompt_multi(struct, prompt_format)
25 | response = model.generate(prompt=prompt_F, image_path=struct['image'], dataset=dataset_name, max_length=20)
26 |
27 | logger.info("########--ANSWER-{}--######\n{}".format(debate_, response))
28 |
29 | logger.info("\nGT:{}-ANS: {} - ".format(struct['text']['answer'], response))
30 |
31 | elif stage[:7] == "BDebate":
32 | for debate_ in range(debate):
33 | logger.info("########--DEBATE{}--######".format(debate_))
34 | if debate_ == 0:
35 | if kg_init:
36 | if struct['text']['kg'] != 'none':
37 | logger.info("#####---KG_IB---#####\n{}".format(struct['text']['kg']))
38 | struct['kg'] = [struct['text']['kg'], struct['text']['kg']]
39 | else:
40 | prompt_format = "GKG-G"
41 | prompt_G = init_prompt_multi(struct, prompt_format)
42 | kg_base = model.generate(prompt=prompt_G, image_path=struct['image'], dataset=dataset_name)
43 | struct['kg'] = [kg_base, kg_base]
44 | logger.info("#####---KG_P---#####\n{}".format(prompt_G))
45 | logger.info("#####---KG_B---#####\n{}".format(kg_base))
46 | else:
47 | prompt_format = "GKG-G"
48 | prompt_G = init_prompt_multi(struct, prompt_format)
49 | kg_base = model.generate(prompt=prompt_G, image_path=struct['image'], dataset=dataset_name)
50 | struct['kg'] = [kg_base, kg_base]
51 | logger.info("#####---KG_P---#####\n{}".format(prompt_G))
52 | logger.info("#####---KG_B---#####\n{}".format(kg_base))
53 |
54 | prompt_format = "KDQIM-G"
55 | prompt_A = init_prompt_multi(struct, prompt_format)
56 | kg_aff = model.generate(prompt=prompt_A, image_path=struct['image'], dataset=dataset_name)
57 | struct['kg'] = [kg_aff,struct['kg'][1]]
58 |
59 | prompt_format = "KNQIM-G"
60 | prompt_N = init_prompt_multi(struct, prompt_format)
61 | kg_neg = model.generate(prompt=prompt_N, image_path=struct['image'], dataset=dataset_name)
62 |
63 | struct['kg'] = [kg_aff, kg_neg]
64 | prompt_format = "KAGM-A"
65 | prompt_F = init_prompt_multi(struct, prompt_format)
66 | response = model.generate(prompt=prompt_F, image_path=struct['image'], dataset=dataset_name)
67 |
68 | logger.info("########--ANSWER-{}--######\n{}".format(debate_, response))
69 |
70 | logger.info("\nGT:{}-ANS: {} - ".format(struct['text']['answer'], response))
71 |
72 | else:
73 | assert stage == "BDebate", f"Please confirm if your 'stage' is set correctly.\nDebate Only: Begin with ODebate\nDebate with Blueprint:Begin with BDebate\nStage setting Now:{stage}"
74 | return response
75 |
--------------------------------------------------------------------------------
/vlmeval/utils/.ipynb_checkpoints/debate-checkpoint.py:
--------------------------------------------------------------------------------
1 | from vlmeval.utils import init_prompt_multi
2 |
3 | def Debate_VLM(stage, model, struct, dataset_name, debate, kg_init, logger):
4 | if stage[:8] == "baseline":
5 | prompt_format = "IQ-A"
6 | prompt_G = init_prompt_multi(struct, prompt_format)
7 | response = model.generate(prompt=prompt_G, image_path=struct['image'], dataset=dataset_name)
8 | logger.info("########--G_A--######\nPrompt: {}\nGT: {} - ANS: {}".format(prompt_G, struct['text']['answer'], response))
9 |
10 | elif stage[:7] == "ODebate":
11 | for debate_ in range(debate):
12 | logger.info("########--DEBATE{}--######".format(debate_))
13 |
14 | prompt_format = "ODIM-S" if debate_==0 else "ODQIM-S"
15 | prompt_A = init_prompt_multi(struct, prompt_format)
16 | kg_aff = model.generate(prompt=prompt_A, image_path=struct['image'], dataset=dataset_name)
17 |
18 | prompt_format = "ONIM-S" if debate_==0 else "ONQIM-S"
19 | prompt_N = init_prompt_multi(struct, prompt_format)
20 | kg_neg = model.generate(prompt=prompt_N, image_path=struct['image'], dataset=dataset_name)
21 |
22 | struct['kg'] = [kg_aff.strip(), kg_neg.strip()]
23 | prompt_format = "OAGM-A"
24 | prompt_F = init_prompt_multi(struct, prompt_format)
25 | response = model.generate(prompt=prompt_F, image_path=struct['image'], dataset=dataset_name, max_length=20)
26 |
27 | logger.info("########--ANSWER-{}--######\n{}".format(debate_, response))
28 |
29 | logger.info("\nGT:{}-ANS: {} - ".format(struct['text']['answer'], response))
30 |
31 | elif stage[:7] == "BDebate":
32 | for debate_ in range(debate):
33 | logger.info("########--DEBATE{}--######".format(debate_))
34 | if debate_ == 0:
35 | if kg_init:
36 | if struct['text']['kg'] != 'none':
37 | logger.info("#####---KG_IB---#####\n{}".format(struct['text']['kg']))
38 | struct['kg'] = [struct['text']['kg'], struct['text']['kg']]
39 | else:
40 | prompt_format = "GKG-G"
41 | prompt_G = init_prompt_multi(struct, prompt_format)
42 | kg_base = model.generate(prompt=prompt_G, image_path=struct['image'], dataset=dataset_name)
43 | struct['kg'] = [kg_base, kg_base]
44 | logger.info("#####---KG_P---#####\n{}".format(prompt_G))
45 | logger.info("#####---KG_B---#####\n{}".format(kg_base))
46 | else:
47 | prompt_format = "GKG-G"
48 | prompt_G = init_prompt_multi(struct, prompt_format)
49 | kg_base = model.generate(prompt=prompt_G, image_path=struct['image'], dataset=dataset_name)
50 | struct['kg'] = [kg_base, kg_base]
51 | logger.info("#####---KG_P---#####\n{}".format(prompt_G))
52 | logger.info("#####---KG_B---#####\n{}".format(kg_base))
53 |
54 | prompt_format = "KDQIM-G"
55 | prompt_A = init_prompt_multi(struct, prompt_format)
56 | kg_aff = model.generate(prompt=prompt_A, image_path=struct['image'], dataset=dataset_name)
57 | struct['kg'] = [kg_aff,struct['kg'][1]]
58 |
59 | prompt_format = "KNQIM-G"
60 | prompt_N = init_prompt_multi(struct, prompt_format)
61 | kg_neg = model.generate(prompt=prompt_N, image_path=struct['image'], dataset=dataset_name)
62 |
63 | struct['kg'] = [kg_aff, kg_neg]
64 | prompt_format = "KAGM-A"
65 | prompt_F = init_prompt_multi(struct, prompt_format)
66 | response = model.generate(prompt=prompt_F, image_path=struct['image'], dataset=dataset_name)
67 |
68 | logger.info("########--ANSWER-{}--######\n{}".format(debate_, response))
69 |
70 | logger.info("\nGT:{}-ANS: {} - ".format(struct['text']['answer'], response))
71 |
72 | else:
73 | assert stage == "BDebate", f"Please confirm if your 'stage' is set correctly.\nDebate Only: Begin with ODebate\nDebate with Blueprint:Begin with BDebate\nStage setting Now:{stage}"
74 | return response
75 |
--------------------------------------------------------------------------------
/vlmeval/smp/misc.py:
--------------------------------------------------------------------------------
1 | # flake8: noqa: F401, F403
2 | import abc
3 | import argparse
4 | import csv
5 | import multiprocessing as mp
6 | import os
7 | import os.path as osp
8 | import copy as cp
9 | import random as rd
10 | import requests
11 | import shutil
12 | import subprocess
13 | import warnings
14 | import pandas as pd
15 | from collections import OrderedDict, defaultdict
16 | from multiprocessing import Pool, current_process
17 | from tqdm import tqdm
18 | import datetime
19 | import matplotlib.pyplot as plt
20 | import seaborn as sns
21 | from tabulate import tabulate_formats, tabulate
22 | from huggingface_hub import scan_cache_dir
23 | from sty import fg, bg, ef, rs
24 |
25 | def process_punctuation(inText):
26 | import re
27 | outText = inText
28 | punct = [
29 | ';', r'/', '[', ']', '"', '{', '}', '(', ')', '=', '+', '\\', '_', '-',
30 | '>', '<', '@', '`', ',', '?', '!'
31 | ]
32 | commaStrip = re.compile('(\d)(,)(\d)') # noqa: W605
33 | periodStrip = re.compile('(?!<=\d)(\.)(?!\d)') # noqa: W605
34 | for p in punct:
35 | if (p + ' ' in inText or ' ' + p in inText) or (re.search(
36 | commaStrip, inText) is not None):
37 | outText = outText.replace(p, '')
38 | else:
39 | outText = outText.replace(p, ' ')
40 | outText = periodStrip.sub('', outText, re.UNICODE)
41 | return outText
42 |
43 | def h2r(value):
44 | if value[0] == '#':
45 | value = value[1:]
46 | assert len(value) == 6
47 | return tuple(int(value[i:i + 2], 16) for i in range(0, 6, 2))
48 |
49 | def r2h(rgb):
50 | return '#%02x%02x%02x' % rgb
51 |
52 | def colored(s, color):
53 | if isinstance(color, str):
54 | if hasattr(fg, color):
55 | return getattr(fg, color) + s + fg.rs
56 | color = h2r(color)
57 | return fg(*color) + s + fg.rs
58 |
59 | def istype(s, type):
60 | if isinstance(s, type):
61 | return True
62 | try:
63 | return isinstance(eval(s), type)
64 | except Exception as _:
65 | return False
66 |
67 | def bincount(lst):
68 | bins = defaultdict(lambda: 0)
69 | for item in lst:
70 | bins[item] += 1
71 | return bins
72 |
73 | def get_cache_path(repo_id):
74 | hf_cache_info = scan_cache_dir()
75 | repos = list(hf_cache_info.repos)
76 | repo = None
77 | for r in repos:
78 | if r.repo_id == repo_id:
79 | repo = r
80 | break
81 | if repo is None:
82 | return None
83 | revs = list(repo.revisions)
84 | rev2keep, last_modified = None, 0
85 | for rev in revs:
86 | if rev.last_modified > last_modified:
87 | rev2keep, last_modified = rev, rev.last_modified
88 | if rev2keep is None:
89 | return None
90 | return str(rev2keep.snapshot_path)
91 |
92 | def proxy_set(s):
93 | import os
94 | for key in ['http_proxy', 'HTTP_PROXY', 'https_proxy', 'HTTPS_PROXY']:
95 | os.environ[key] = s
96 |
97 | def get_rank_and_world_size():
98 | local_rank = int(os.environ.get("LOCAL_RANK", 0))
99 | world_size = int(os.environ.get("WORLD_SIZE", 1))
100 | return local_rank, world_size
101 |
102 | def splitlen(s, sym='/'):
103 | return len(s.split(sym))
104 |
105 | def listinstr(lst, s):
106 | assert isinstance(lst, list)
107 | for item in lst:
108 | if item in s:
109 | return True
110 | return False
111 |
112 | def d2df(D):
113 | return pd.DataFrame({x: [D[x]] for x in D})
114 |
115 | def cn_string(s):
116 | import re
117 | if re.search(u'[\u4e00-\u9fff]', s):
118 | return True
119 | return False
120 |
121 | try:
122 | import decord
123 | except ImportError:
124 | pass
125 |
126 | def timestr(second=True, minute=False):
127 | s = datetime.datetime.now().strftime('%Y%m%d%H%M%S')[2:]
128 | if second:
129 | return s
130 | elif minute:
131 | return s[:-2]
132 | else:
133 | return s[:-4]
134 |
135 | def dict_merge(dct, merge_dct):
136 | for k, _ in merge_dct.items():
137 | if (k in dct and isinstance(dct[k], dict) and isinstance(merge_dct[k], dict)): #noqa
138 | dict_merge(dct[k], merge_dct[k])
139 | else:
140 | dct[k] = merge_dct[k]
141 |
142 | def youtube_dl(idx):
143 | cmd = f'youtube-dl -f best -f mp4 "{idx}" -o {idx}.mp4'
144 | os.system(cmd)
145 |
146 | def run_command(cmd):
147 | if isinstance(cmd, str):
148 | cmd = cmd.split()
149 | return subprocess.check_output(cmd)
150 |
--------------------------------------------------------------------------------
/vlmeval/smp/.ipynb_checkpoints/misc-checkpoint.py:
--------------------------------------------------------------------------------
1 | # flake8: noqa: F401, F403
2 | import abc
3 | import argparse
4 | import csv
5 | import multiprocessing as mp
6 | import os
7 | import os.path as osp
8 | import copy as cp
9 | import random as rd
10 | import requests
11 | import shutil
12 | import subprocess
13 | import warnings
14 | import pandas as pd
15 | from collections import OrderedDict, defaultdict
16 | from multiprocessing import Pool, current_process
17 | from tqdm import tqdm
18 | import datetime
19 | import matplotlib.pyplot as plt
20 | import seaborn as sns
21 | from tabulate import tabulate_formats, tabulate
22 | from huggingface_hub import scan_cache_dir
23 | from sty import fg, bg, ef, rs
24 |
25 | def process_punctuation(inText):
26 | import re
27 | outText = inText
28 | punct = [
29 | ';', r'/', '[', ']', '"', '{', '}', '(', ')', '=', '+', '\\', '_', '-',
30 | '>', '<', '@', '`', ',', '?', '!'
31 | ]
32 | commaStrip = re.compile('(\d)(,)(\d)') # noqa: W605
33 | periodStrip = re.compile('(?!<=\d)(\.)(?!\d)') # noqa: W605
34 | for p in punct:
35 | if (p + ' ' in inText or ' ' + p in inText) or (re.search(
36 | commaStrip, inText) is not None):
37 | outText = outText.replace(p, '')
38 | else:
39 | outText = outText.replace(p, ' ')
40 | outText = periodStrip.sub('', outText, re.UNICODE)
41 | return outText
42 |
43 | def h2r(value):
44 | if value[0] == '#':
45 | value = value[1:]
46 | assert len(value) == 6
47 | return tuple(int(value[i:i + 2], 16) for i in range(0, 6, 2))
48 |
49 | def r2h(rgb):
50 | return '#%02x%02x%02x' % rgb
51 |
52 | def colored(s, color):
53 | if isinstance(color, str):
54 | if hasattr(fg, color):
55 | return getattr(fg, color) + s + fg.rs
56 | color = h2r(color)
57 | return fg(*color) + s + fg.rs
58 |
59 | def istype(s, type):
60 | if isinstance(s, type):
61 | return True
62 | try:
63 | return isinstance(eval(s), type)
64 | except Exception as _:
65 | return False
66 |
67 | def bincount(lst):
68 | bins = defaultdict(lambda: 0)
69 | for item in lst:
70 | bins[item] += 1
71 | return bins
72 |
73 | def get_cache_path(repo_id):
74 | hf_cache_info = scan_cache_dir()
75 | repos = list(hf_cache_info.repos)
76 | repo = None
77 | for r in repos:
78 | if r.repo_id == repo_id:
79 | repo = r
80 | break
81 | if repo is None:
82 | return None
83 | revs = list(repo.revisions)
84 | rev2keep, last_modified = None, 0
85 | for rev in revs:
86 | if rev.last_modified > last_modified:
87 | rev2keep, last_modified = rev, rev.last_modified
88 | if rev2keep is None:
89 | return None
90 | return str(rev2keep.snapshot_path)
91 |
92 | def proxy_set(s):
93 | import os
94 | for key in ['http_proxy', 'HTTP_PROXY', 'https_proxy', 'HTTPS_PROXY']:
95 | os.environ[key] = s
96 |
97 | def get_rank_and_world_size():
98 | local_rank = int(os.environ.get("LOCAL_RANK", 0))
99 | world_size = int(os.environ.get("WORLD_SIZE", 1))
100 | return local_rank, world_size
101 |
102 | def splitlen(s, sym='/'):
103 | return len(s.split(sym))
104 |
105 | def listinstr(lst, s):
106 | assert isinstance(lst, list)
107 | for item in lst:
108 | if item in s:
109 | return True
110 | return False
111 |
112 | def d2df(D):
113 | return pd.DataFrame({x: [D[x]] for x in D})
114 |
115 | def cn_string(s):
116 | import re
117 | if re.search(u'[\u4e00-\u9fff]', s):
118 | return True
119 | return False
120 |
121 | try:
122 | import decord
123 | except ImportError:
124 | pass
125 |
126 | def timestr(second=True, minute=False):
127 | s = datetime.datetime.now().strftime('%Y%m%d%H%M%S')[2:]
128 | if second:
129 | return s
130 | elif minute:
131 | return s[:-2]
132 | else:
133 | return s[:-4]
134 |
135 | def dict_merge(dct, merge_dct):
136 | for k, _ in merge_dct.items():
137 | if (k in dct and isinstance(dct[k], dict) and isinstance(merge_dct[k], dict)): #noqa
138 | dict_merge(dct[k], merge_dct[k])
139 | else:
140 | dct[k] = merge_dct[k]
141 |
142 | def youtube_dl(idx):
143 | cmd = f'youtube-dl -f best -f mp4 "{idx}" -o {idx}.mp4'
144 | os.system(cmd)
145 |
146 | def run_command(cmd):
147 | if isinstance(cmd, str):
148 | cmd = cmd.split()
149 | return subprocess.check_output(cmd)
150 |
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | import re
2 | import sys
3 | from os.path import exists
4 | from setuptools import find_packages, setup
5 |
6 | def parse_requirements(fname='requirements.txt', with_version=True):
7 |     """Parse the package dependencies listed in a requirements file, optionally
8 |     stripping specific versioning information.
9 |
10 | Args:
11 | fname (str): path to requirements file
12 |         with_version (bool, default=True): if True, include version specs
13 |
14 | Returns:
15 | List[str]: list of requirements items
16 |
17 | CommandLine:
18 | python -c "import setup; print(setup.parse_requirements())"
19 | """
20 |
21 | require_fpath = fname
22 |
23 | def parse_line(line):
24 | """Parse information from a line in a requirements text file."""
25 | if line.startswith('-r '):
26 | # Allow specifying requirements in other files
27 | target = line.split(' ')[1]
28 | for info in parse_require_file(target):
29 | yield info
30 | else:
31 | info = {'line': line}
32 | if line.startswith('-e '):
33 | info['package'] = line.split('#egg=')[1]
34 | elif '@git+' in line:
35 | info['package'] = line
36 | else:
37 | # Remove versioning from the package
38 | pat = '(' + '|'.join(['>=', '==', '>']) + ')'
39 | parts = re.split(pat, line, maxsplit=1)
40 | parts = [p.strip() for p in parts]
41 |
42 | info['package'] = parts[0]
43 | if len(parts) > 1:
44 | op, rest = parts[1:]
45 | if ';' in rest:
46 | # Handle platform specific dependencies
47 | # http://setuptools.readthedocs.io/en/latest/setuptools.html#declaring-platform-specific-dependencies
48 | version, platform_deps = map(str.strip,
49 | rest.split(';'))
50 | info['platform_deps'] = platform_deps
51 | else:
52 | version = rest # NOQA
53 | info['version'] = (op, version)
54 | yield info
55 |
56 | def parse_require_file(fpath):
57 | with open(fpath, 'r') as f:
58 | for line in f.readlines():
59 | line = line.strip()
60 | if line and not line.startswith('#'):
61 | for info in parse_line(line):
62 | yield info
63 |
64 | def gen_packages_items():
65 | if exists(require_fpath):
66 | for info in parse_require_file(require_fpath):
67 | parts = [info['package']]
68 | if with_version and 'version' in info:
69 | parts.extend(info['version'])
70 | if not sys.version.startswith('3.4'):
71 | # apparently package_deps are broken in 3.4
72 | platform_deps = info.get('platform_deps')
73 | if platform_deps is not None:
74 | parts.append(';' + platform_deps)
75 | item = ''.join(parts)
76 | yield item
77 |
78 | packages = list(gen_packages_items())
79 | return packages
80 |
81 |
82 | with open('README.md') as f:
83 | readme = f.read()
84 |
85 |
86 | def do_setup():
87 | setup(
88 | name='vlmeval',
89 | version='0.1.0',
90 | description='BDoG',
91 | author="thecharm",
92 | author_email='csczheng@comp.polyu.edu.hk',
93 | long_description=readme,
94 | long_description_content_type='text/markdown',
95 | cmdclass={},
96 | install_requires=parse_requirements('requirements.txt'),
97 | setup_requires=[],
98 | python_requires='>=3.7.0',
99 | packages=find_packages(exclude=[
100 | 'test*',
101 | 'paper_test*',
102 | ]),
103 | keywords=['AI', 'NLP', 'in-context learning'],
104 | entry_points={
105 | "console_scripts": []
106 | },
107 | classifiers=[
108 | 'Programming Language :: Python :: 3.7',
109 | 'Programming Language :: Python :: 3.8',
110 | 'Programming Language :: Python :: 3.9',
111 | 'Programming Language :: Python :: 3.10',
112 | 'Intended Audience :: Developers',
113 | 'Intended Audience :: Education',
114 | 'Intended Audience :: Science/Research',
115 | ])
116 |
117 |
118 | if __name__ == '__main__':
119 | do_setup()
120 |
--------------------------------------------------------------------------------
/vlmeval/smp/vlm.py:
--------------------------------------------------------------------------------
1 | import os, io
2 | import pandas as pd
3 | import numpy as np
4 | import string
5 | from uuid import uuid4
6 | import os.path as osp
7 | import base64
8 | from PIL import Image
9 |
10 | def mmqa_display(question):
11 | question = {k.lower(): v for k, v in question.items()}
12 | keys = list(question.keys())
13 | keys = [k for k in keys if k not in ['index', 'image']]
14 |
15 | images = question['image']
16 | if isinstance(images, str):
17 | images = [images]
18 |
19 | idx = question.pop('index', 'XXX')
20 | print(f'INDEX: {idx}')
21 |
22 | for im in images:
23 | image = decode_base64_to_image(im, target_size=512)
24 | display(image)
25 |
26 | for k in keys:
27 | try:
28 | if not pd.isna(question[k]):
29 | print(f'{k.upper()}. {question[k]}')
30 | except ValueError:
31 | if False in pd.isna(question[k]):
32 | print(f'{k.upper()}. {question[k]}')
33 |
34 | def encode_image_to_base64(img, target_size=-1):
35 | # if target_size == -1, will not do resizing
36 |     # else, will set the max size to (target_size, target_size)
37 | if img.mode in ("RGBA", "P"):
38 | img = img.convert("RGB")
39 | tmp = osp.join('/tmp', str(uuid4()) + '.jpg')
40 | if target_size > 0:
41 | img.thumbnail((target_size, target_size))
42 | img.save(tmp)
43 | with open(tmp, 'rb') as image_file:
44 | image_data = image_file.read()
45 | ret = base64.b64encode(image_data).decode('utf-8')
46 | os.remove(tmp)
47 | return ret
48 |
49 | def encode_image_file_to_base64(image_path, target_size=-1):
50 | image = Image.open(image_path)
51 | return encode_image_to_base64(image, target_size=target_size)
52 |
53 | def decode_base64_to_image(base64_string, target_size=-1):
54 | image_data = base64.b64decode(base64_string)
55 | image = Image.open(io.BytesIO(image_data))
56 | if image.mode in ('RGBA', 'P'):
57 | image = image.convert('RGB')
58 | if target_size > 0:
59 | image.thumbnail((target_size, target_size))
60 | return image
61 |
62 | def decode_base64_to_image_file(base64_string, image_path, target_size=-1):
63 | image = decode_base64_to_image(base64_string, target_size=target_size)
64 | image.save(image_path)
65 |
66 | def LMUDataRoot():
67 | if 'LMUData' in os.environ and osp.exists(os.environ['LMUData']):
68 | return os.environ['LMUData']
69 | home = osp.expanduser('~')
70 | root = osp.join(home, 'LMUData')
71 | os.makedirs(root, exist_ok=True)
72 | return root
73 |
74 | def init_environ(args):
75 | if len(args.lmudata):
76 | os.environ['LMUData']=args.lmudata
77 | if len(args.openai):
78 | os.environ['OPENAI_API_KEY']=args.openai
79 |
80 | def build_option_str(option_dict):
81 | s = 'There are several options: \n'
82 | for c, content in option_dict.items():
83 | if not pd.isna(content):
84 | s += f'{c}. {content}\n'
85 | return s
86 |
87 | def isimg(s):
88 | return osp.exists(s) or s.startswith('http')
89 |
90 | def read_ok(img_path):
91 | if not osp.exists(img_path):
92 | return False
93 | try:
94 | im = Image.open(img_path)
95 | assert im.size[0] > 0 and im.size[1] > 0
96 | return True
97 | except:
98 | return False
99 |
100 | def gpt_key_set():
101 | openai_key = os.environ.get('OPENAI_API_KEY', None)
102 | return isinstance(openai_key, str) and openai_key.startswith('sk-')
103 |
104 | def apiok(wrapper):
105 | s = wrapper.generate("Hello!")
106 | return wrapper.fail_msg not in s
107 |
108 | def circular_pred(df, extract_func=None):
109 | if extract_func is None:
110 | extract_func = lambda x: x
111 | df = df.sort_values('index')
112 | from vlmeval.utils import can_infer_option
113 | shift = int(1e6)
114 |
115 | choices = [extract_func(x) for x in df['prediction']]
116 | pred_map = {i: c for i, c in zip(df['index'], choices)}
117 | flag_map = {i: True for i in pred_map if i < 1e6}
118 | valid_map = {i: True for i in pred_map if i < 1e6}
119 | for i in df['index']:
120 | if i >= shift and pred_map[i] and pred_map[i - shift]:
121 | if pred_map[i] not in list(string.ascii_uppercase) or pred_map[i - shift] not in list(string.ascii_uppercase):
122 | valid_map[i % shift] = False
123 | continue
124 | if (ord(pred_map[i]) - ord(pred_map[i - shift])) % 4 == 1:
125 | continue
126 | else:
127 | flag_map[i % shift] = False
128 | flag_map = {k: v for k, v in flag_map.items() if valid_map[k]}
129 | flags = list(flag_map.values())
130 | return np.mean(flags)
131 |
132 | def MMBenchOfficialServer():
133 | root = LMUDataRoot()
134 | for dataset in ['MMBench', 'MMBench_CN', 'MMBench_TEST_EN', 'MMBench_TEST_CN']:
135 | if osp.exists(f'{root}/{dataset}.tsv'):
136 | return True
137 | return False
--------------------------------------------------------------------------------
/vlmeval/smp/.ipynb_checkpoints/vlm-checkpoint.py:
--------------------------------------------------------------------------------
1 | import os, io
2 | import pandas as pd
3 | import numpy as np
4 | import string
5 | from uuid import uuid4
6 | import os.path as osp
7 | import base64
8 | from PIL import Image
9 |
10 | def mmqa_display(question):
11 | question = {k.lower(): v for k, v in question.items()}
12 | keys = list(question.keys())
13 | keys = [k for k in keys if k not in ['index', 'image']]
14 |
15 | images = question['image']
16 | if isinstance(images, str):
17 | images = [images]
18 |
19 | idx = question.pop('index', 'XXX')
20 | print(f'INDEX: {idx}')
21 |
22 | for im in images:
23 | image = decode_base64_to_image(im, target_size=512)
24 | display(image)
25 |
26 | for k in keys:
27 | try:
28 | if not pd.isna(question[k]):
29 | print(f'{k.upper()}. {question[k]}')
30 | except ValueError:
31 | if False in pd.isna(question[k]):
32 | print(f'{k.upper()}. {question[k]}')
33 |
34 | def encode_image_to_base64(img, target_size=-1):
35 | # if target_size == -1, will not do resizing
36 |     # else, will set the max size to (target_size, target_size)
37 | if img.mode in ("RGBA", "P"):
38 | img = img.convert("RGB")
39 | tmp = osp.join('/tmp', str(uuid4()) + '.jpg')
40 | if target_size > 0:
41 | img.thumbnail((target_size, target_size))
42 | img.save(tmp)
43 | with open(tmp, 'rb') as image_file:
44 | image_data = image_file.read()
45 | ret = base64.b64encode(image_data).decode('utf-8')
46 | os.remove(tmp)
47 | return ret
48 |
49 | def encode_image_file_to_base64(image_path, target_size=-1):
50 | image = Image.open(image_path)
51 | return encode_image_to_base64(image, target_size=target_size)
52 |
53 | def decode_base64_to_image(base64_string, target_size=-1):
54 | image_data = base64.b64decode(base64_string)
55 | image = Image.open(io.BytesIO(image_data))
56 | if image.mode in ('RGBA', 'P'):
57 | image = image.convert('RGB')
58 | if target_size > 0:
59 | image.thumbnail((target_size, target_size))
60 | return image
61 |
62 | def decode_base64_to_image_file(base64_string, image_path, target_size=-1):
63 | image = decode_base64_to_image(base64_string, target_size=target_size)
64 | image.save(image_path)
65 |
66 | def LMUDataRoot():
67 | if 'LMUData' in os.environ and osp.exists(os.environ['LMUData']):
68 | return os.environ['LMUData']
69 | home = osp.expanduser('~')
70 | root = osp.join(home, 'LMUData')
71 | os.makedirs(root, exist_ok=True)
72 | return root
73 |
74 | def init_environ(args):
75 | if len(args.lmudata):
76 | os.environ['LMUData']=args.lmudata
77 | if len(args.openai):
78 | os.environ['OPENAI_API_KEY']=args.openai
79 |
80 | def build_option_str(option_dict):
81 | s = 'There are several options: \n'
82 | for c, content in option_dict.items():
83 | if not pd.isna(content):
84 | s += f'{c}. {content}\n'
85 | return s
86 |
87 | def isimg(s):
88 | return osp.exists(s) or s.startswith('http')
89 |
90 | def read_ok(img_path):
91 | if not osp.exists(img_path):
92 | return False
93 | try:
94 | im = Image.open(img_path)
95 | assert im.size[0] > 0 and im.size[1] > 0
96 | return True
97 | except:
98 | return False
99 |
100 | def gpt_key_set():
101 | openai_key = os.environ.get('OPENAI_API_KEY', None)
102 | return isinstance(openai_key, str) and openai_key.startswith('sk-')
103 |
104 | def apiok(wrapper):
105 | s = wrapper.generate("Hello!")
106 | return wrapper.fail_msg not in s
107 |
108 | def circular_pred(df, extract_func=None):
109 | if extract_func is None:
110 | extract_func = lambda x: x
111 | df = df.sort_values('index')
112 | from vlmeval.utils import can_infer_option
113 | shift = int(1e6)
114 |
115 | choices = [extract_func(x) for x in df['prediction']]
116 | pred_map = {i: c for i, c in zip(df['index'], choices)}
117 | flag_map = {i: True for i in pred_map if i < 1e6}
118 | valid_map = {i: True for i in pred_map if i < 1e6}
119 | for i in df['index']:
120 | if i >= shift and pred_map[i] and pred_map[i - shift]:
121 | if pred_map[i] not in list(string.ascii_uppercase) or pred_map[i - shift] not in list(string.ascii_uppercase):
122 | valid_map[i % shift] = False
123 | continue
124 | if (ord(pred_map[i]) - ord(pred_map[i - shift])) % 4 == 1:
125 | continue
126 | else:
127 | flag_map[i % shift] = False
128 | flag_map = {k: v for k, v in flag_map.items() if valid_map[k]}
129 | flags = list(flag_map.values())
130 | return np.mean(flags)
131 |
132 | def MMBenchOfficialServer():
133 | root = LMUDataRoot()
134 | for dataset in ['MMBench', 'MMBench_CN', 'MMBench_TEST_EN', 'MMBench_TEST_CN']:
135 | if osp.exists(f'{root}/{dataset}.tsv'):
136 | return True
137 | return False
--------------------------------------------------------------------------------
/vlmeval/smp/file.py:
--------------------------------------------------------------------------------
1 | import json
2 | import pickle
3 | import pandas as pd
4 | import os
5 | import csv
6 | import hashlib
7 | import os.path as osp
8 | import time
9 | import numpy as np
10 |
11 | class NumpyEncoder(json.JSONEncoder):
12 | def default(self, obj):
13 | if isinstance(obj, (np.int_, np.intc, np.intp, np.int8,
14 | np.int16, np.int32, np.int64, np.uint8,
15 | np.uint16, np.uint32, np.uint64)):
16 | return int(obj)
17 | elif isinstance(obj, (np.float_, np.float16, np.float32, np.float64)):
18 | return float(obj)
19 | elif isinstance(obj, (np.complex_, np.complex64, np.complex128)):
20 | return {'real': obj.real, 'imag': obj.imag}
21 | elif isinstance(obj, (np.ndarray,)):
22 | return obj.tolist()
23 | elif isinstance(obj, (np.bool_)):
24 | return bool(obj)
25 | elif isinstance(obj, (np.void)):
26 | return None
27 | return json.JSONEncoder.default(self, obj)
28 |
29 | # LOAD & DUMP
30 | def dump(data, f, **kwargs):
31 | def dump_pkl(data, pth, **kwargs):
32 | pickle.dump(data, open(pth, 'wb'))
33 |
34 | def dump_json(data, pth, **kwargs):
35 | json.dump(data, open(pth, 'w'), indent=4, ensure_ascii=False, cls=NumpyEncoder)
36 |
37 | def dump_jsonl(data, f, **kwargs):
38 | lines = [json.dumps(x, ensure_ascii=False, cls=NumpyEncoder) for x in data]
39 | with open(f, 'w', encoding='utf8') as fout:
40 | fout.write('\n'.join(lines))
41 |
42 | def dump_xlsx(data, f, **kwargs):
43 | data.to_excel(f, index=False, engine='xlsxwriter')
44 |
45 | def dump_csv(data, f, quoting=csv.QUOTE_ALL):
46 | data.to_csv(f, index=False, encoding='utf-8', quoting=quoting)
47 |
48 | def dump_tsv(data, f, quoting=csv.QUOTE_ALL):
49 | data.to_csv(f, sep='\t', index=False, encoding='utf-8', quoting=quoting)
50 |
51 | handlers = dict(pkl=dump_pkl, json=dump_json, jsonl=dump_jsonl, xlsx=dump_xlsx, csv=dump_csv, tsv=dump_tsv)
52 | suffix = f.split('.')[-1]
53 | return handlers[suffix](data, f, **kwargs)
54 |
55 | def load(f):
56 | def load_pkl(pth):
57 | return pickle.load(open(pth, 'rb'))
58 |
59 | def load_json(pth):
60 | return json.load(open(pth, 'r', encoding='utf-8'))
61 |
62 | def load_jsonl(f):
63 | lines = open(f, encoding='utf-8').readlines()
64 | lines = [x.strip() for x in lines]
65 | if lines[-1] == '':
66 | lines = lines[:-1]
67 | data = [json.loads(x) for x in lines]
68 | return data
69 |
70 | def load_xlsx(f):
71 | return pd.read_excel(f)
72 |
73 | def load_csv(f):
74 | return pd.read_csv(f)
75 |
76 | def load_tsv(f):
77 | return pd.read_csv(f, sep='\t')
78 |
79 | handlers = dict(pkl=load_pkl, json=load_json, jsonl=load_jsonl, xlsx=load_xlsx, csv=load_csv, tsv=load_tsv)
80 | suffix = f.split('.')[-1]
81 | return handlers[suffix](f)
82 |
83 | def download_file(url, filename=None):
84 | import urllib.request
85 | from tqdm import tqdm
86 |
87 | class DownloadProgressBar(tqdm):
88 | def update_to(self, b=1, bsize=1, tsize=None):
89 | if tsize is not None:
90 | self.total = tsize
91 | self.update(b * bsize - self.n)
92 |
93 | if filename is None:
94 | filename = url.split('/')[-1]
95 |
96 | with DownloadProgressBar(unit='B', unit_scale=True,
97 | miniters=1, desc=url.split('/')[-1]) as t:
98 | urllib.request.urlretrieve(url, filename=filename, reporthook=t.update_to)
99 | return filename
100 |
101 | def ls(dirname='.', match='', mode='all', level=1):
102 | if dirname == '.':
103 | ans = os.listdir(dirname)
104 | else:
105 | ans = [osp.join(dirname, x) for x in os.listdir(dirname)]
106 | assert mode in ['all', 'dir', 'file']
107 | assert level >= 1 and isinstance(level, int)
108 | if level == 1:
109 | ans = [x for x in ans if match in x]
110 | if mode == 'dir':
111 | ans = [x for x in ans if osp.isdir(x)]
112 | elif mode == 'file':
113 | ans = [x for x in ans if not osp.isdir(x)]
114 | else:
115 | ans = [x for x in ans if osp.isdir(x)]
116 | res = []
117 | for d in ans:
118 | res.extend(ls(d, match=match, mode=mode, level=level-1))
119 | ans = res
120 | return ans
121 |
122 | def mrlines(fname, sp='\n'):
123 | f = open(fname).read().split(sp)
124 | while f != [] and f[-1] == '':
125 | f = f[:-1]
126 | return f
127 |
128 | def mwlines(lines, fname):
129 | with open(fname, 'w') as fout:
130 | fout.write('\n'.join(lines))
131 |
132 | def md5(file_pth):
133 | with open(file_pth, 'rb') as f:
134 | hash = hashlib.new('md5')
135 | for chunk in iter(lambda: f.read(2**20), b''):
136 | hash.update(chunk)
137 | return str(hash.hexdigest())
138 |
139 | def last_modified(pth):
140 | stamp = osp.getmtime(pth)
141 | m_ti = time.ctime(stamp)
142 | t_obj = time.strptime(m_ti)
143 | t = time.strftime('%Y%m%d%H%M%S', t_obj)[2:]
144 | return t
145 |
--------------------------------------------------------------------------------
/vlmeval/smp/.ipynb_checkpoints/file-checkpoint.py:
--------------------------------------------------------------------------------
1 | import json
2 | import pickle
3 | import pandas as pd
4 | import os
5 | import csv
6 | import hashlib
7 | import os.path as osp
8 | import time
9 | import numpy as np
10 |
11 | class NumpyEncoder(json.JSONEncoder):
12 | def default(self, obj):
13 | if isinstance(obj, (np.int_, np.intc, np.intp, np.int8,
14 | np.int16, np.int32, np.int64, np.uint8,
15 | np.uint16, np.uint32, np.uint64)):
16 | return int(obj)
17 | elif isinstance(obj, (np.float_, np.float16, np.float32, np.float64)):
18 | return float(obj)
19 | elif isinstance(obj, (np.complex_, np.complex64, np.complex128)):
20 | return {'real': obj.real, 'imag': obj.imag}
21 | elif isinstance(obj, (np.ndarray,)):
22 | return obj.tolist()
23 | elif isinstance(obj, (np.bool_)):
24 | return bool(obj)
25 | elif isinstance(obj, (np.void)):
26 | return None
27 | return json.JSONEncoder.default(self, obj)
28 |
29 | # LOAD & DUMP
30 | def dump(data, f, **kwargs):
31 | def dump_pkl(data, pth, **kwargs):
32 | pickle.dump(data, open(pth, 'wb'))
33 |
34 | def dump_json(data, pth, **kwargs):
35 | json.dump(data, open(pth, 'w'), indent=4, ensure_ascii=False, cls=NumpyEncoder)
36 |
37 | def dump_jsonl(data, f, **kwargs):
38 | lines = [json.dumps(x, ensure_ascii=False, cls=NumpyEncoder) for x in data]
39 | with open(f, 'w', encoding='utf8') as fout:
40 | fout.write('\n'.join(lines))
41 |
42 | def dump_xlsx(data, f, **kwargs):
43 | data.to_excel(f, index=False, engine='xlsxwriter')
44 |
45 | def dump_csv(data, f, quoting=csv.QUOTE_ALL):
46 | data.to_csv(f, index=False, encoding='utf-8', quoting=quoting)
47 |
48 | def dump_tsv(data, f, quoting=csv.QUOTE_ALL):
49 | data.to_csv(f, sep='\t', index=False, encoding='utf-8', quoting=quoting)
50 |
51 | handlers = dict(pkl=dump_pkl, json=dump_json, jsonl=dump_jsonl, xlsx=dump_xlsx, csv=dump_csv, tsv=dump_tsv)
52 | suffix = f.split('.')[-1]
53 | return handlers[suffix](data, f, **kwargs)
54 |
55 | def load(f):
56 | def load_pkl(pth):
57 | return pickle.load(open(pth, 'rb'))
58 |
59 | def load_json(pth):
60 | return json.load(open(pth, 'r', encoding='utf-8'))
61 |
62 | def load_jsonl(f):
63 | lines = open(f, encoding='utf-8').readlines()
64 | lines = [x.strip() for x in lines]
65 | if lines[-1] == '':
66 | lines = lines[:-1]
67 | data = [json.loads(x) for x in lines]
68 | return data
69 |
70 | def load_xlsx(f):
71 | return pd.read_excel(f)
72 |
73 | def load_csv(f):
74 | return pd.read_csv(f)
75 |
76 | def load_tsv(f):
77 | return pd.read_csv(f, sep='\t')
78 |
79 | handlers = dict(pkl=load_pkl, json=load_json, jsonl=load_jsonl, xlsx=load_xlsx, csv=load_csv, tsv=load_tsv)
80 | suffix = f.split('.')[-1]
81 | return handlers[suffix](f)
82 |
83 | def download_file(url, filename=None):
84 | import urllib.request
85 | from tqdm import tqdm
86 |
87 | class DownloadProgressBar(tqdm):
88 | def update_to(self, b=1, bsize=1, tsize=None):
89 | if tsize is not None:
90 | self.total = tsize
91 | self.update(b * bsize - self.n)
92 |
93 | if filename is None:
94 | filename = url.split('/')[-1]
95 |
96 | with DownloadProgressBar(unit='B', unit_scale=True,
97 | miniters=1, desc=url.split('/')[-1]) as t:
98 | urllib.request.urlretrieve(url, filename=filename, reporthook=t.update_to)
99 | return filename
100 |
101 | def ls(dirname='.', match='', mode='all', level=1):
102 | if dirname == '.':
103 | ans = os.listdir(dirname)
104 | else:
105 | ans = [osp.join(dirname, x) for x in os.listdir(dirname)]
106 | assert mode in ['all', 'dir', 'file']
107 | assert level >= 1 and isinstance(level, int)
108 | if level == 1:
109 | ans = [x for x in ans if match in x]
110 | if mode == 'dir':
111 | ans = [x for x in ans if osp.isdir(x)]
112 | elif mode == 'file':
113 | ans = [x for x in ans if not osp.isdir(x)]
114 | else:
115 | ans = [x for x in ans if osp.isdir(x)]
116 | res = []
117 | for d in ans:
118 | res.extend(ls(d, match=match, mode=mode, level=level-1))
119 | ans = res
120 | return ans
121 |
122 | def mrlines(fname, sp='\n'):
123 | f = open(fname).read().split(sp)
124 | while f != [] and f[-1] == '':
125 | f = f[:-1]
126 | return f
127 |
128 | def mwlines(lines, fname):
129 | with open(fname, 'w') as fout:
130 | fout.write('\n'.join(lines))
131 |
132 | def md5(file_pth):
133 | with open(file_pth, 'rb') as f:
134 | hash = hashlib.new('md5')
135 | for chunk in iter(lambda: f.read(2**20), b''):
136 | hash.update(chunk)
137 | return str(hash.hexdigest())
138 |
139 | def last_modified(pth):
140 | stamp = osp.getmtime(pth)
141 | m_ti = time.ctime(stamp)
142 | t_obj = time.strptime(m_ti)
143 | t = time.strftime('%Y%m%d%H%M%S', t_obj)[2:]
144 | return t
145 |
--------------------------------------------------------------------------------
/vlmeval/utils/dataset_config.py:
--------------------------------------------------------------------------------
1 | from ..smp import listinstr
2 |
3 | dataset_URLs = {
4 | 'MMBench_DEV_EN': "https://opencompass.openxlab.space/utils/VLMEval/MMBench_DEV_EN.tsv",
5 | 'MMBench_TEST_EN': "https://opencompass.openxlab.space/utils/VLMEval/MMBench_TEST_EN.tsv",
6 | 'MMBench_DEV_CN': "https://opencompass.openxlab.space/utils/VLMEval/MMBench_DEV_CN.tsv",
7 | 'MMBench_TEST_CN': "https://opencompass.openxlab.space/utils/VLMEval/MMBench_TEST_CN.tsv",
8 | "MMBench": "https://opencompass.openxlab.space/utils/VLMEval/MMBench.tsv", # Link Invalid, Internal Only
9 | "MMBench_CN": "https://opencompass.openxlab.space/utils/VLMEval/MMBench_CN.tsv", # Link Invalid, Internal Only
10 | 'CCBench': "https://opencompass.openxlab.space/utils/VLMEval/CCBench.tsv",
11 | 'MME': "https://opencompass.openxlab.space/utils/VLMEval/MME.tsv",
12 | 'SEEDBench_IMG': "https://opencompass.openxlab.space/utils/VLMEval/SEEDBench_IMG.tsv",
13 | "CORE_MM": "https://opencompass.openxlab.space/utils/VLMEval/CORE_MM.tsv",
14 | "MMVet": "https://opencompass.openxlab.space/utils/VLMEval/MMVet.tsv",
15 | "COCO_VAL": "https://opencompass.openxlab.space/utils/VLMEval/COCO_VAL.tsv",
16 | "OCRVQA_TEST": "https://opencompass.openxlab.space/utils/VLMEval/OCRVQA_TEST.tsv",
17 | "OCRVQA_TESTCORE": "https://opencompass.openxlab.space/utils/VLMEval/OCRVQA_TESTCORE.tsv",
18 | 'TextVQA_VAL': "https://opencompass.openxlab.space/utils/VLMEval/TextVQA_VAL.tsv",
19 | "MMMU_DEV_VAL": "https://opencompass.openxlab.space/utils/VLMEval/MMMU_DEV_VAL.tsv",
20 | "MMMU_TEST": "https://opencompass.openxlab.space/utils/VLMEval/MMMU_TEST.tsv",
21 | "MathVista_MINI": "https://opencompass.openxlab.space/utils/VLMEval/MathVista_MINI.tsv",
22 | 'ChartQA_VALTEST_HUMAN': "https://opencompass.openxlab.space/utils/VLMEval/ChartQA_VALTEST_HUMAN.tsv",
23 | 'ScienceQA_VAL': "https://opencompass.openxlab.space/utils/VLMEval/ScienceQA_VAL.tsv",
24 | 'ScienceQA_TEST': "https://opencompass.openxlab.space/utils/VLMEval/ScienceQA_TEST.tsv",
25 | 'HallusionBench': "https://opencompass.openxlab.space/utils/VLMEval/HallusionBench.tsv",
26 | "DocVQA_VAL": "https://opencompass.openxlab.space/utils/VLMEval/DocVQA_VAL.tsv",
27 | 'AI2D_TEST': "https://opencompass.openxlab.space/utils/VLMEval/AI2D_TEST.tsv",
28 | "LLaVABench": "https://opencompass.openxlab.space/utils/VLMEval/LLaVABench.tsv",
29 | "OCRBench": 'https://opencompass.openxlab.space/utils/VLMEval/OCRBench.tsv',
30 | }
31 |
32 | dataset_md5_dict = {
33 | 'MMBench_DEV_EN': "b6caf1133a01c6bb705cf753bb527ed8",
34 | 'MMBench_TEST_EN': "6939fadb0ce626fefc0bdc9c64efc528",
35 | 'MMBench_DEV_CN': "08b8fc3324a5ed74155350f57be69fbd",
36 | 'MMBench_TEST_CN': "7e1239baf0ee4c8b513e19705a0f317e",
37 | "MMBench": "4115aea3383f3dd0083be6a633e0f820", # Link Invalid, Internal Only
38 | "MMBench_CN": "2e053ffc90ea598b1feae13c36dc13ee", # Link Invalid, Internal Only
39 | 'CCBench': "1de88b4257e7eee3f60b18d45eda6f07",
40 | 'MME': "b36b43c3f09801f5d368627fb92187c3",
41 | 'SEEDBench_IMG': "68017231464752261a2526d6ca3a10c0",
42 | "CORE_MM": "8a8da2f2232e79caf98415bfdf0a202d",
43 | "MMVet": "f400d7f513a585a0f218cbd6882e0671",
44 | 'COCO_VAL': "72a5079dead060269ac222c5aa5128af",
45 | 'OCRVQA_TEST': 'ca46a6d74b403e9d6c0b670f6fc00db9',
46 | 'OCRVQA_TESTCORE': 'c5239fe77db8bdc1f2ad8e55e0d1fe97',
47 | 'TextVQA_VAL': 'b233b31f551bbf4056f2f955da3a92cd',
48 | 'MMMU_DEV_VAL': "521afc0f3bf341e6654327792781644d",
49 | 'MMMU_TEST': "c19875d11a2d348d07e5eb4bdf33166d",
50 | 'MathVista_MINI': 'f199b98e178e5a2a20e7048f5dcb0464',
51 | 'ChartQA_VALTEST_HUMAN':'2c90a4133408a21d57fb2ea26f77bbfc',
52 | 'ScienceQA_VAL': '96320d05e142e585e7204e72affd29f3',
53 | 'ScienceQA_TEST': 'e42e9e00f9c59a80d8a5db35bc32b71f',
54 | 'HallusionBench': '0c23ac0dc9ef46832d7a24504f2a0c7c',
55 | "DocVQA_VAL": 'ee0d8ae5527439438d08e154ef65d735',
56 | "AI2D_TEST": "0f593e0d1c7df9a3d69bf1f947e71975",
57 | "LLaVABench": "d382a093f749a697820d3dadd61c8428",
58 | "OCRBench": 'e953d98a987cc6e26ef717b61260b778',
59 | }
60 |
61 | img_root_map = {k: k for k in dataset_URLs}
62 | img_root_map.update({
63 | 'MMBench_DEV_EN': "MMBench",
64 | 'MMBench_TEST_EN': "MMBench",
65 | 'MMBench_DEV_CN': "MMBench",
66 | 'MMBench_TEST_CN': "MMBench",
67 | "MMBench_CN": "MMBench", # Link Invalid, Internal Only
68 | 'COCO_VAL':'COCO',
69 | 'OCRVQA_TEST': 'OCRVQA',
70 | 'OCRVQA_TESTCORE': 'OCRVQA',
71 | 'TextVQA_VAL': 'TextVQA',
72 | 'MMMU_DEV_VAL': 'MMMU',
73 | "MMMU_TEST": "MMMU",
74 | 'MathVista_MINI': 'MathVista',
75 | 'ChartQA_VALTEST_HUMAN': 'ChartQA',
76 | 'HallusionBench': 'Hallusion',
77 | 'DocVQA_VAL': 'DocVQA',
78 | "OCRBench": 'OCRBench',
79 | })
80 |
81 | assert set(dataset_URLs) == set(img_root_map) == set(dataset_md5_dict)
82 |
83 | def DATASET_TYPE(dataset):
84 | dataset = dataset.lower()
85 | if listinstr(['mmbench', 'seedbench', 'ccbench', 'mmmu', 'scienceqa', 'ai2d'], dataset):
86 | return 'multi-choice'
87 | elif listinstr(['mme', 'hallusion'], dataset):
88 | return 'Y/N'
89 | elif 'coco' in dataset:
90 | return 'Caption'
91 |     elif listinstr(['ocrvqa', 'textvqa', 'chartqa', 'mathvista', 'docvqa', 'llavabench', 'mmvet', 'ocrbench'], dataset):
92 | return 'VQA'
93 | else:
94 | return 'QA'
95 |
96 | def abbr2full(s):
97 | datasets = [x for x in img_root_map]
98 | ins = [s in d for d in datasets]
99 | if sum(ins) == 1:
100 | for d in datasets:
101 | if s in d:
102 | return d
103 | else:
104 | return None
105 |
--------------------------------------------------------------------------------
/vlmeval/utils/base_prompt.py:
--------------------------------------------------------------------------------
1 |
2 | def create_one_example(format_, question, context, options, answer, knowledge, image_path):
3 |
4 | input_format, output_format = format_.split("-")
5 |
6 | aff_base = "You are a fellow debater from the AFFIRMATIVE side, You are more Emotional to think about problems."
7 | neg_base = "You are a fellow debater from the NEGATIVE side, You are more Rational in thinking about problems."
8 |
9 | kg_emo = f"Emotional Graph: {knowledge[0]}\n" if knowledge[0]!='none' else ""
10 | kg_rat = f"Rational Graph: {knowledge[1]}\n" if knowledge[1]!='none' else ""
11 |
12 | hint = f"Hint: {context}\n" if context != "none" else ""
13 | question_ = f"Question: {question}\n"
14 | option_ = f"Options:\n{options}" if options != "none" else ""
15 |     answer_ = "Please select the correct answer from the options above, without any explanation." if options != "none" else "Answer directly."
16 |
17 | if input_format=="IQ":
18 | input = f"""{hint}{question_}{option_}{answer_}"""
19 |
20 | elif input_format == "QIM":
21 | input = f"""{hint}{question_}{option_}For the provided image and its associated question, generate a graph to answer the question.
22 | """
23 |
24 | elif input_format == "GKG":
25 |         input = f"""For the provided image and its associated question, generate a scene graph in JSON format that includes the following:
26 | 1. Objects, attributes, relationships that are more relevant to answering the question.
27 | 2. Objects are NO MORE THAN 3.
28 | {hint}{question_}"""
29 |
30 | #### ODebate_stage
31 | elif input_format == "ODIM":
32 | input = f"""{hint}{question_}You are a fellow debater from the AFFIRMATIVE side, You are more Emotional to think about problems.
33 | For the provided image and its associated question, Do not give the answer, But your solution and ideas to solve this question.
34 | """
35 |
36 | elif input_format == "ONIM":
37 | input = f"""{hint}{question_}You are a fellow debater from the NEGATIVE side, You are more Rational in thinking about problems.
38 | For the provided image and its associated question, Do not give the answer, But your solution and ideas to solve this problem.
39 | """
40 | elif input_format == "ODQIM":
41 | input = f"""{hint}{question_}Debate Solution:{knowledge[1]}\nYou are a fellow debater from the AFFIRMATIVE side, You are more Emotional to think about problems.
42 | Based on the debate Solution of the question, Do not give the answer, But your Better solution and ideas to solve this problem.
43 | """
44 |
45 | elif input_format == "ONQIM":
46 | input = f"""{hint}{question_}Debate Solution:{knowledge[0]}\nYou are a fellow debater from the NEGATIVE side, You are more Rational in thinking about problems.
47 | Based on the debate Solution of the question, Do not give the answer, But your Better solution and ideas to solve this problem.
48 | """
49 |
50 | elif input_format == "OAGM":
51 | input = f"""You're good at summarizing and answering questions. \nEmotional Solution: {knowledge[0]}\nRational Solution: {knowledge[1]}\n{hint}{question_}{option_}{answer_}
52 | """
53 |
54 |
55 | #### Debate_KG_stage
56 | elif input_format == "KDIM":
57 | input = f"""{aff_base}
58 | For the provided image and its associated question, Please give your solution and ideas to solve this problem, but do not give a final answer. Generate an updated graph from a different view based on the Debate Graph in JSON format that includes the following:
59 | 1. Objects, attributes, relationships that are more relevant to answering the question.
60 | 2. Objects are NO MORE THAN 3.
61 | {hint}{question_}"""
62 |
63 | elif input_format == "KNIM":
64 | input = f"""{neg_base}
65 | For the provided image and its associated question, Please give your solution and ideas to solve this problem, but do not give a final answer. Generate an updated graph from a different view based on the Debate Graph in JSON format that includes the following:
66 | 1. Objects, attributes, relationships that are more relevant to answering the question.
67 | 2. Objects are NO MORE THAN 3.
68 | {hint}{question_}"""
69 |
70 | elif input_format == "KDQIM":
71 | input = f"""{aff_base}
72 | For the provided image and its associated question, Please give your solution and ideas to solve this problem, but do not give a final answer. Generate an updated graph from a different view based on the Debate Graph in JSON format that includes the following:
73 | 1. Objects, attributes, relationships that are more relevant to answering the question.
74 | 2. Delete the irrelevant objects, attributes and relationships.
75 | {hint}{question_}{kg_rat}
76 | """
77 |
78 | elif input_format == "KNQIM":
79 | input = f"""{neg_base}
80 | For the provided image and its associated question, Please give your solution and ideas to solve this problem, but do not give a final answer. Generate an updated graph from a different view based on the Debate Graph in JSON format that includes the following:
81 | 1. Objects, attributes, relationships that are more relevant to answering the question.
82 | 2. Delete the irrelevant objects, attributes and relationships.
83 | {hint}{question_}{kg_emo}
84 | """
85 |
86 | elif input_format == "KAGM":
87 | input = f"""You're good at summarizing and answering questions.{hint}{kg_emo}{kg_rat}Use the image and two debate Solution as context and answer the following question:\n{question_}{option_}{answer_}
88 | """
89 |
90 | # Outputs
91 | if output_format == 'A':
92 | output = "Answer:"
93 |
94 | elif output_format == 'G':
95 |         output = "Solution: "
96 |
97 | text = input + output
98 |     text = text.replace("  ", " ")  # collapse double spaces introduced by template concatenation
99 | if text.endswith("BECAUSE:"):
100 | text = text.replace("BECAUSE:", "").strip()
101 | return text
102 |
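# Minimal illustration of the format_ convention used above: format_ is '<INPUT>-<OUTPUT>',
# where the input code picks one of the prompt templates and the output code appends either
# 'Answer:' ('A') or 'Solution: ' ('G'). All argument values below are made-up placeholders.
#   create_one_example('IQ-A', 'What is in the image?', 'none',
#                      'A. a cat\nB. a dog', 'none', ['none', 'none'], 'img.jpg')
# roughly yields:
#   'Question: What is in the image?\nOptions:\nA. a cat\nB. a dogPlease select the correct
#    answer from the options above, without any explanation.Answer:'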
--------------------------------------------------------------------------------
/vlmeval/utils/.ipynb_checkpoints/base_prompt-checkpoint.py:
--------------------------------------------------------------------------------
1 |
2 | def create_one_example(format_, question, context, options, answer, knowledge, image_path):
3 |
4 | input_format, output_format = format_.split("-")
5 |
6 | aff_base = "You are a fellow debater from the AFFIRMATIVE side, You are more Emotional to think about problems."
7 | neg_base = "You are a fellow debater from the NEGATIVE side, You are more Rational in thinking about problems."
8 |
9 | kg_emo = f"Graph: {knowledge[0]}\n" if knowledge[0]!='none' else ""
10 | kg_rat = f"Graph: {knowledge[1]}\n" if knowledge[1]!='none' else ""
11 |
12 | hint = f"Hint: {context}\n" if context != "none" else ""
13 | question_ = f"Question: {question}\n"
14 | option_ = f"Options:\n{options}" if options != "none" else ""
15 | answer_ = "Answer with the option's letter from the given choices directly." if options != "none" else "Answer directly."
16 |
17 | if input_format=="IQ":
18 | input = f"""{hint}{question_}{option_}{answer_}"""
19 |
20 | elif input_format == "QIM":
21 | input = f"""{hint}{question_}{option_}For the provided image and its associated question, generate a graph to answer the question.
22 | """
23 |
24 | elif input_format == "GKG":
25 |         input = f"""{hint}{question_}{option_}For the provided image and its associated question, generate a scene graph in JSON format that includes the following:
26 | 1. Objects that are relevant to answering the question.
27 | 2. Object attributes that are relevant to answering the question.
28 | 3. Object relationships that are relevant to answering the question.
29 | """
30 |
31 | #### ODebate_stage
32 | elif input_format == "ODIM":
33 | input = f"""{hint}{question_}You are a fellow debater from the AFFIRMATIVE side, You are more Emotional to think about problems.
34 | For the provided image and its associated question, Do not give the answer, But your solution and ideas to solve this question.
35 | """
36 |
37 | elif input_format == "ONIM":
38 | input = f"""{hint}{question_}You are a fellow debater from the NEGATIVE side, You are more Rational in thinking about problems.
39 | For the provided image and its associated question, Do not give the answer, But your solution and ideas to solve this problem.
40 | """
41 | elif input_format == "ODQIM":
42 | input = f"""{hint}{question_}Debate Solution:{knowledge[1]}\nYou are a fellow debater from the AFFIRMATIVE side, You are more Emotional to think about problems.
43 | Based on the debate Solution of the question, Do not give the answer, But your Better solution and ideas to solve this problem.
44 | """
45 |
46 | elif input_format == "ONQIM":
47 | input = f"""{hint}{question_}Debate Solution:{knowledge[0]}\nYou are a fellow debater from the NEGATIVE side, You are more Rational in thinking about problems.
48 | Based on the debate Solution of the question, Do not give the answer, But your Better solution and ideas to solve this problem.
49 | """
50 |
51 | elif input_format == "OAGM":
52 | input = f"""You're good at summarizing and answering questions. \nEmotional Solution: {knowledge[0]}\nRational Solution: {knowledge[1]}\n{hint}{question_}{option_}{answer_}
53 | """
54 |
55 |
56 | #### Debate_KG_stage
57 | elif input_format == "KDIM":
58 | input = f"""{aff_base}
59 | For the provided image and its associated question, Please give your solution and ideas to solve this problem, but do not give a final answer. Generate an updated graph from a different view based on the Debate Graph in JSON format that includes the following:
60 | 1. Objects, attributes, relationships that are more relevant to answering the question.
61 | 2. Delete the irrelevant objects, attributes and relationships.
62 | {hint}{question_}"""
63 |
64 | elif input_format == "KNIM":
65 | input = f"""{neg_base}
66 | For the provided image and its associated question, Please give your solution and ideas to solve this problem, but do not give a final answer. Generate an updated graph from a different view based on the Debate Graph in JSON format that includes the following:
67 | 1. Objects, attributes, relationships that are more relevant to answering the question.
68 | 2. Delete the irrelevant objects, attributes and relationships.
69 | {hint}{question_}"""
70 |
71 | elif input_format == "KDQIM":
72 | input = f"""{aff_base}
73 | Based on the debate Solution of the question, Please give your Better solution and ideas to solve this problem, but do not give a final answer. Generate an updated graph from a different view based on the Debate Graph in JSON format that includes the following:
74 | 1. Objects, attributes, relationships that are more relevant to answering the question.
75 | 2. Delete the irrelevant objects, attributes and relationships.
76 | {hint}{question_}{kg_rat}
77 | """
78 |
79 | elif input_format == "KNQIM":
80 | input = f"""{neg_base}
81 | Based on the debate Solution of the question, Please give your Better solution and ideas to solve this problem, but do not give a final answer. Generate an updated graph from a different view based on the Debate Graph in JSON format that includes the following:
82 | 1. Objects, attributes, relationships that are more relevant to answering the question.
83 | 2. Delete the irrelevant objects, attributes and relationships.
84 | {hint}{question_}{kg_emo}
85 | """
86 |
87 | elif input_format == "KAGM":
88 | input = f"""{hint}{kg_emo}{kg_rat}{question_}{option_}{answer_}
89 | """
90 |
91 | # Outputs
92 | if output_format == 'A':
93 | output = "Answer:"
94 |
95 | elif output_format == 'G':
96 |         output = "Graph: "
97 |
98 | text = input + output
99 |     text = text.replace("  ", " ")
100 | if text.endswith("BECAUSE:"):
101 | text = text.replace("BECAUSE:", "").strip()
102 | return text
103 |
--------------------------------------------------------------------------------
/vlmeval/smp/lb.py:
--------------------------------------------------------------------------------
1 | import json
2 | import pandas as pd
3 | from collections import defaultdict
4 | import gradio as gr
5 | import copy as cp
6 | import numpy as np
7 | from .misc import listinstr
8 |
9 | # CONSTANTS-URL
10 | URL = "https://github.com/thecharm/BDoG.git"
11 | VLMEVALKIT_README = 'https://raw.githubusercontent.com/thecharm/BDoG/main/README.md'
12 | # CONSTANTS-CITATION
13 | CITATION_BUTTON_TEXT = r"""@misc{zheng2024pictureworthgraphblueprint,
14 | title={A Picture Is Worth a Graph: Blueprint Debate on Graph for Multimodal Reasoning},
15 | author={Changmeng Zheng and Dayong Liang and Wengyu Zhang and Xiao-Yong Wei and Tat-Seng Chua and Qing Li},
16 | year={2024},
17 | eprint={2403.14972},
18 | archivePrefix={arXiv},
19 | primaryClass={cs.AI},
20 | url={https://arxiv.org/abs/2403.14972},
21 | }"""
22 | CITATION_BUTTON_LABEL = "Copy the following snippet to cite these results"
23 | # CONSTANTS-TEXT
24 | LEADERBORAD_INTRODUCTION = ""
25 | # CONSTANTS-FIELDS
26 | META_FIELDS = ['Method', 'Parameters (B)', 'Language Model', 'Vision Model', 'OpenSource', 'Verified']
27 | MAIN_FIELDS = ['MMBench_TEST_EN', 'MMBench_TEST_CN', 'CCBench', 'MME', 'SEEDBench_IMG', 'MMVet', 'MMMU_VAL', 'MathVista', 'HallusionBench', 'LLaVABench']
28 | MMBENCH_FIELDS = ['MMBench_TEST_EN', 'MMBench_DEV_EN', 'MMBench_TEST_CN', 'MMBench_DEV_CN', 'CCBench']
29 | MODEL_SIZE = ['<10B', '10B-20B', '20B-40B', '>40B', 'Unknown']
30 | MODEL_TYPE = ['API', 'OpenSource', 'Proprietary']
31 |
32 | LEADERBOARD_MD = {
33 | }
34 |
35 | from urllib.request import urlopen
36 |
37 | def load_results():
38 | data = json.loads(urlopen(URL).read())
39 | return data
40 |
41 | def nth_large(val, vals):
42 | return sum([1 for v in vals if v > val]) + 1
43 |
44 | def format_timestamp(timestamp):
45 | return timestamp[:2] + '.' + timestamp[2:4] + '.' + timestamp[4:6] + ' ' + timestamp[6:8] + ':' + timestamp[8:10] + ':' + timestamp[10:12]
46 |
47 | def model_size_flag(sz, FIELDS):
48 | if pd.isna(sz) and 'Unknown' in FIELDS:
49 | return True
50 | if pd.isna(sz):
51 | return False
52 | if '<10B' in FIELDS and sz < 10:
53 | return True
54 | if '10B-20B' in FIELDS and sz >= 10 and sz < 20:
55 | return True
56 | if '20B-40B' in FIELDS and sz >= 20 and sz < 40:
57 | return True
58 | if '>40B' in FIELDS and sz >= 40:
59 | return True
60 | return False
61 |
62 | def model_type_flag(line, FIELDS):
63 | if 'OpenSource' in FIELDS and line['OpenSource'] == 'Yes':
64 | return True
65 | if 'API' in FIELDS and line['OpenSource'] == 'No' and line['Verified'] == 'Yes':
66 | return True
67 | if 'Proprietary' in FIELDS and line['OpenSource'] == 'No' and line['Verified'] == 'No':
68 | return True
69 | return False
70 |
71 | def BUILD_L1_DF(results, fields):
72 | res = defaultdict(list)
73 | for i, m in enumerate(results):
74 | item = results[m]
75 | meta = item['META']
76 | for k in META_FIELDS:
77 | if k == 'Parameters (B)':
78 | param = meta['Parameters']
79 | res[k].append(float(param.replace('B', '')) if param != '' else None)
80 | elif k == 'Method':
81 | name, url = meta['Method']
82 | res[k].append(f'{name}')
83 | else:
84 | res[k].append(meta[k])
85 | scores, ranks = [], []
86 | for d in fields:
87 | res[d].append(item[d]['Overall'])
88 | if d == 'MME':
89 |                 scores.append(item[d]['Overall'] / 28)  # MME 'Overall' is on a 0-2800 scale; /28 brings it to roughly the 0-100 range of the other benchmarks
90 | else:
91 | scores.append(item[d]['Overall'])
92 | ranks.append(nth_large(item[d]['Overall'], [x[d]['Overall'] for x in results.values()]))
93 | res['Avg Score'].append(round(np.mean(scores), 1))
94 | res['Avg Rank'].append(round(np.mean(ranks), 2))
95 |
96 | df = pd.DataFrame(res)
97 | df = df.sort_values('Avg Rank')
98 |
99 | check_box = {}
100 | check_box['essential'] = ['Method', 'Parameters (B)', 'Language Model', 'Vision Model']
101 | check_box['required'] = ['Avg Score', 'Avg Rank']
102 | check_box['all'] = check_box['required'] + ['OpenSource', 'Verified'] + fields
103 | type_map = defaultdict(lambda: 'number')
104 | type_map['Method'] = 'html'
105 | type_map['Language Model'] = type_map['Vision Model'] = type_map['OpenSource'] = type_map['Verified'] = 'str'
106 | check_box['type_map'] = type_map
107 | return df, check_box
108 |
109 | def BUILD_L2_DF(results, dataset):
110 | res = defaultdict(list)
111 | fields = list(list(results.values())[0][dataset].keys())
112 | non_overall_fields = [x for x in fields if 'Overall' not in x]
113 | overall_fields = [x for x in fields if 'Overall' in x]
114 | if dataset == 'MME':
115 | non_overall_fields = [x for x in non_overall_fields if not listinstr(['Perception', 'Cognition'], x)]
116 | overall_fields = overall_fields + ['Perception', 'Cognition']
117 |
118 | for m in results:
119 | item = results[m]
120 | meta = item['META']
121 | for k in META_FIELDS:
122 | if k == 'Parameters (B)':
123 | param = meta['Parameters']
124 | res[k].append(float(param.replace('B', '')) if param != '' else None)
125 | elif k == 'Method':
126 | name, url = meta['Method']
127 | res[k].append(f'{name}')
128 | else:
129 | res[k].append(meta[k])
130 | fields = [x for x in fields]
131 |
132 | for d in non_overall_fields:
133 | res[d].append(item[dataset][d])
134 | for d in overall_fields:
135 | res[d].append(item[dataset][d])
136 |
137 | df = pd.DataFrame(res)
138 | df = df.sort_values('Overall')
139 | df = df.iloc[::-1]
140 |
141 | check_box = {}
142 | check_box['essential'] = ['Method', 'Parameters (B)', 'Language Model', 'Vision Model']
143 | check_box['required'] = overall_fields
144 | check_box['all'] = non_overall_fields + overall_fields
145 | type_map = defaultdict(lambda: 'number')
146 | type_map['Method'] = 'html'
147 | type_map['Language Model'] = type_map['Vision Model'] = type_map['OpenSource'] = type_map['Verified'] = 'str'
148 | check_box['type_map'] = type_map
149 | return df, check_box
--------------------------------------------------------------------------------
/vlmeval/smp/.ipynb_checkpoints/lb-checkpoint.py:
--------------------------------------------------------------------------------
1 | import json
2 | import pandas as pd
3 | from collections import defaultdict
4 | import gradio as gr
5 | import copy as cp
6 | import numpy as np
7 | from .misc import listinstr
8 |
9 | # CONSTANTS-URL
10 | URL = "https://github.com/thecharm/BDoG.git"
11 | VLMEVALKIT_README = 'https://raw.githubusercontent.com/thecharm/BDoG/main/README.md'
12 | # CONSTANTS-CITATION
13 | CITATION_BUTTON_TEXT = r"""@misc{zheng2024pictureworthgraphblueprint,
14 | title={A Picture Is Worth a Graph: Blueprint Debate on Graph for Multimodal Reasoning},
15 | author={Changmeng Zheng and Dayong Liang and Wengyu Zhang and Xiao-Yong Wei and Tat-Seng Chua and Qing Li},
16 | year={2024},
17 | eprint={2403.14972},
18 | archivePrefix={arXiv},
19 | primaryClass={cs.AI},
20 | url={https://arxiv.org/abs/2403.14972},
21 | }"""
22 | CITATION_BUTTON_LABEL = "Copy the following snippet to cite these results"
23 | # CONSTANTS-TEXT
24 | LEADERBORAD_INTRODUCTION = ""
25 | # CONSTANTS-FIELDS
26 | META_FIELDS = ['Method', 'Parameters (B)', 'Language Model', 'Vision Model', 'OpenSource', 'Verified']
27 | MAIN_FIELDS = ['MMBench_TEST_EN', 'MMBench_TEST_CN', 'CCBench', 'MME', 'SEEDBench_IMG', 'MMVet', 'MMMU_VAL', 'MathVista', 'HallusionBench', 'LLaVABench']
28 | MMBENCH_FIELDS = ['MMBench_TEST_EN', 'MMBench_DEV_EN', 'MMBench_TEST_CN', 'MMBench_DEV_CN', 'CCBench']
29 | MODEL_SIZE = ['<10B', '10B-20B', '20B-40B', '>40B', 'Unknown']
30 | MODEL_TYPE = ['API', 'OpenSource', 'Proprietary']
31 |
32 | LEADERBOARD_MD = {
33 | }
34 |
35 | from urllib.request import urlopen
36 |
37 | def load_results():
38 | data = json.loads(urlopen(URL).read())
39 | return data
40 |
41 | def nth_large(val, vals):
42 | return sum([1 for v in vals if v > val]) + 1
43 |
44 | def format_timestamp(timestamp):
45 | return timestamp[:2] + '.' + timestamp[2:4] + '.' + timestamp[4:6] + ' ' + timestamp[6:8] + ':' + timestamp[8:10] + ':' + timestamp[10:12]
46 |
47 | def model_size_flag(sz, FIELDS):
48 | if pd.isna(sz) and 'Unknown' in FIELDS:
49 | return True
50 | if pd.isna(sz):
51 | return False
52 | if '<10B' in FIELDS and sz < 10:
53 | return True
54 | if '10B-20B' in FIELDS and sz >= 10 and sz < 20:
55 | return True
56 | if '20B-40B' in FIELDS and sz >= 20 and sz < 40:
57 | return True
58 | if '>40B' in FIELDS and sz >= 40:
59 | return True
60 | return False
61 |
62 | def model_type_flag(line, FIELDS):
63 | if 'OpenSource' in FIELDS and line['OpenSource'] == 'Yes':
64 | return True
65 | if 'API' in FIELDS and line['OpenSource'] == 'No' and line['Verified'] == 'Yes':
66 | return True
67 | if 'Proprietary' in FIELDS and line['OpenSource'] == 'No' and line['Verified'] == 'No':
68 | return True
69 | return False
70 |
71 | def BUILD_L1_DF(results, fields):
72 | res = defaultdict(list)
73 | for i, m in enumerate(results):
74 | item = results[m]
75 | meta = item['META']
76 | for k in META_FIELDS:
77 | if k == 'Parameters (B)':
78 | param = meta['Parameters']
79 | res[k].append(float(param.replace('B', '')) if param != '' else None)
80 | elif k == 'Method':
81 | name, url = meta['Method']
82 | res[k].append(f'{name}')
83 | else:
84 | res[k].append(meta[k])
85 | scores, ranks = [], []
86 | for d in fields:
87 | res[d].append(item[d]['Overall'])
88 | if d == 'MME':
89 | scores.append(item[d]['Overall'] / 28)
90 | else:
91 | scores.append(item[d]['Overall'])
92 | ranks.append(nth_large(item[d]['Overall'], [x[d]['Overall'] for x in results.values()]))
93 | res['Avg Score'].append(round(np.mean(scores), 1))
94 | res['Avg Rank'].append(round(np.mean(ranks), 2))
95 |
96 | df = pd.DataFrame(res)
97 | df = df.sort_values('Avg Rank')
98 |
99 | check_box = {}
100 | check_box['essential'] = ['Method', 'Parameters (B)', 'Language Model', 'Vision Model']
101 | check_box['required'] = ['Avg Score', 'Avg Rank']
102 | check_box['all'] = check_box['required'] + ['OpenSource', 'Verified'] + fields
103 | type_map = defaultdict(lambda: 'number')
104 | type_map['Method'] = 'html'
105 | type_map['Language Model'] = type_map['Vision Model'] = type_map['OpenSource'] = type_map['Verified'] = 'str'
106 | check_box['type_map'] = type_map
107 | return df, check_box
108 |
109 | def BUILD_L2_DF(results, dataset):
110 | res = defaultdict(list)
111 | fields = list(list(results.values())[0][dataset].keys())
112 | non_overall_fields = [x for x in fields if 'Overall' not in x]
113 | overall_fields = [x for x in fields if 'Overall' in x]
114 | if dataset == 'MME':
115 | non_overall_fields = [x for x in non_overall_fields if not listinstr(['Perception', 'Cognition'], x)]
116 | overall_fields = overall_fields + ['Perception', 'Cognition']
117 |
118 | for m in results:
119 | item = results[m]
120 | meta = item['META']
121 | for k in META_FIELDS:
122 | if k == 'Parameters (B)':
123 | param = meta['Parameters']
124 | res[k].append(float(param.replace('B', '')) if param != '' else None)
125 | elif k == 'Method':
126 | name, url = meta['Method']
127 | res[k].append(f'{name}')
128 | else:
129 | res[k].append(meta[k])
130 | fields = [x for x in fields]
131 |
132 | for d in non_overall_fields:
133 | res[d].append(item[dataset][d])
134 | for d in overall_fields:
135 | res[d].append(item[dataset][d])
136 |
137 | df = pd.DataFrame(res)
138 | df = df.sort_values('Overall')
139 | df = df.iloc[::-1]
140 |
141 | check_box = {}
142 | check_box['essential'] = ['Method', 'Parameters (B)', 'Language Model', 'Vision Model']
143 | check_box['required'] = overall_fields
144 | check_box['all'] = non_overall_fields + overall_fields
145 | type_map = defaultdict(lambda: 'number')
146 | type_map['Method'] = 'html'
147 | type_map['Language Model'] = type_map['Vision Model'] = type_map['OpenSource'] = type_map['Verified'] = 'str'
148 | check_box['type_map'] = type_map
149 | return df, check_box
--------------------------------------------------------------------------------
/vlmeval/inference_multi.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | import torch
3 | import torch.distributed as dist
4 | import random
5 | import datetime
6 | from vlmeval.config import supported_VLM
7 | from vlmeval.utils import TSVDataset, track_progress_rich, split_MMMU, Debate_VLM
8 | from vlmeval.smp import *
9 | import logging
10 | import numpy as np
11 |
12 | FAIL_MSG = 'Failed to obtain answer via API.'
13 |
14 | def parse_args():
15 | parser = argparse.ArgumentParser()
16 | parser.add_argument('--data', type=str, nargs='+', required=True)
17 | parser.add_argument("--model", type=str, nargs='+', required=True)
18 | parser.add_argument("--nproc", type=int, default=4, required=True)
19 | parser.add_argument("--debate", type=int, default=2, required=True)
20 | parser.add_argument("--verbose", action='store_true')
21 | args = parser.parse_args()
22 | return args
23 |
24 | def check_identical_elements_2(lst):
25 | return len(set(lst)) == 1
26 |
27 | def find_duplicates(lst):
28 | return list(set([x for x in lst if lst.count(x) > 1]))
29 |
30 | def infer_data(model_name, dataset_name, out_file, logger, kg_init, stage="base", debate=2, verbose=False, api_nproc=4):
31 |
32 | res = {}
33 | if osp.exists(out_file):
34 | res = load(out_file)
35 |
36 | # Dataset init
37 | rank, world_size = get_rank_and_world_size()
38 | if rank == 0:
39 | dataset = TSVDataset(dataset_name)
40 | if world_size > 1:
41 | dist.barrier()
42 | dataset = TSVDataset(dataset_name)
43 | indices = list(range(rank, len(dataset), world_size))
44 | lt = len(indices)
45 | data = dataset.data.iloc[indices]
46 |
47 | # If finished, will exit without building the model
48 | all_finished = True
49 | for i in range(lt):
50 | idx = data.iloc[i]['index']
51 | if idx not in res:
52 | all_finished = False
53 | if all_finished:
54 | return
55 | data = data[~data['index'].isin(res)]
56 | lt = len(data)
57 |
58 | model = supported_VLM[model_name]() if isinstance(model_name, str) else model_name
59 |
60 | for i in tqdm(range(lt)):
61 | idx = data.iloc[i]['index']
62 | if idx in res:
63 | continue
64 |
65 | if stage != "base_line":
66 | struct = dataset.build_prompt_multi(data.iloc[i])
67 | else:
68 | struct = dataset.build_prompt(data.iloc[i])
69 |
70 | response = Debate_VLM(stage, model, struct, dataset_name, debate, kg_init, logger)
71 | torch.cuda.empty_cache()
72 |
73 | if verbose:
74 | print(response, flush=True)
75 |
76 | res[idx] = response
77 | if (i + 1) % 20 == 0:
78 | dump(res, out_file)
79 |
80 | dump(res, out_file)
81 | return model
82 |
83 | def prefetch_acc(result_file):
84 | data = load(result_file)
85 | from vlmeval.evaluate.multiple_choice import build_choices, can_infer
86 | tot = defaultdict(lambda: 0)
87 | match = defaultdict(lambda: 0)
88 | hit = defaultdict(lambda: 0)
89 | lt = len(data)
90 | for i in range(lt):
91 | item = data.iloc[i]
92 | cate = item['category']
93 | tot['Overall'] += 1
94 | tot[cate] += 1
95 | choices = build_choices(item)
96 | matched = can_infer(item['prediction'], choices)
97 | if matched:
98 | match['Overall'] += 1
99 | match[cate] += 1
100 | if matched == item['answer']:
101 | hit['Overall'] += 1
102 | hit[cate] += 1
103 | res = defaultdict(list)
104 | for k in tot.keys():
105 | res['Category'].append(k)
106 | res['tot'].append(tot[k])
107 | res['match'].append(match[k])
108 | res['hit'].append(hit[k])
109 | res['match_rate'].append(match[k] / tot[k] * 100)
110 | if match[k] == 0:
111 | res['acc'].append(0)
112 | else:
113 | res['acc'].append(hit[k] / match[k] * 100)
114 | res = pd.DataFrame(res)
115 | return res
116 |
117 | def infer_data_job(model, model_name, dataset_name, args, logger, ignore_failed=False):
118 |
119 | result_ = f'results/{model_name}/{dataset_name}/'
120 | result_file = result_ + f'{model_name}_{dataset_name}_{args.stage}_DB{args.debate}.xlsx'
121 | rank, world_size = get_rank_and_world_size()
122 | tmpl = result_ + '{}' + f'{world_size}_{dataset_name}_{args.stage}_DB{args.debate}.pkl'
123 | out_file = tmpl.format(rank)
124 |
125 | if not osp.exists(result_file):
126 | model = infer_data(model, dataset_name=dataset_name, out_file=out_file, logger=logger, kg_init=args.kg_init, stage=args.stage, debate=args.debate, verbose=args.verbose)
127 | if world_size > 1:
128 | dist.barrier()
129 |
130 | if rank == 0:
131 | data_all = {}
132 | for i in range(world_size):
133 | data_all.update(load(tmpl.format(i)))
134 |
135 | data = TSVDataset(dataset_name).data
136 | print(len(data_all))
137 | print(len(data))
138 | assert len(data_all) == len(data)
139 | data['prediction'] = [str(data_all[x]) for x in data['index']]
140 | data.pop('image')
141 |
142 | dump(data, result_file)
143 | for i in range(world_size):
144 | os.remove(tmpl.format(i))
145 | return model
146 | else:
147 | data = load(result_file)
148 | failed_set = []
149 | data['prediction'] = [str(x) for x in data['prediction']]
150 | for idx, pred in zip(data['index'], data['prediction']):
151 | if FAIL_MSG in str(pred):
152 | failed_set.append(idx)
153 | if len(failed_set) and (not ignore_failed):
154 | print(f'{len(failed_set)} records failed in the original result file {result_file}. ')
155 | assert rank == 0 and world_size == 1
156 | failed_set = set(failed_set)
157 | answer_map = {x: y for x, y in zip(data['index'], data['prediction'])}
158 | res = infer_data_api(model_name, dataset_name, failed_set, api_nproc=args.api_nproc)
159 | answer_map.update(res)
160 | data['prediction'] = [str(answer_map[x]) for x in data['index']]
161 | dump(data, result_file)
162 | return model_name
163 |
164 | def main():
165 | logger = get_logger('Inference')
166 |
167 | args = parse_args()
168 | assert len(args.data), "--data should be a list of data files"
169 |
170 | rank, world_size = get_rank_and_world_size()
171 | if world_size > 1:
172 | torch.cuda.set_device(rank)
173 | dist.init_process_group(backend='nccl', timeout=datetime.timedelta(seconds=5400))
174 |
175 | for _, model_name in enumerate(args.model):
176 | model = None
177 | os.makedirs(model_name, exist_ok=True)
178 | pred_root = model_name
179 |
180 | for i, dataset_name in enumerate(args.data):
181 |
182 |             result_file = f'results/{model_name}/{dataset_name}/{model_name}_{dataset_name}_{args.stage}_DB{args.debate}.xlsx'  # keep in sync with the path written in infer_data_job
183 | if model is None:
184 | model = model_name # which is only a name
185 |             model = infer_data_job(model, model_name=model_name, dataset_name=dataset_name, args=args, logger=logger)
186 |
187 | if rank == 0 and listinstr(['MMBench','ScienceQA'], dataset_name):
188 | time.sleep(3)
189 | res = prefetch_acc(result_file)
190 | print(model_name, res)
191 | dump(res, result_file.replace('.xlsx', '_prefetch.xlsx'))
192 |
193 | if __name__ == '__main__':
194 | main()
195 |
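# Example launch (a sketch only: the model / dataset names and worker counts are placeholders,
# and --stage / --kg_init should be set to match the debate configuration you want to run):
#   torchrun --nproc_per_node=2 -m vlmeval.inference_multi \
#       --data ScienceQA_TEST --model instructblip_13b --nproc 4 --debate 2 --verbose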
--------------------------------------------------------------------------------
/vlmeval/utils/mp_util.py:
--------------------------------------------------------------------------------
1 | from multiprocessing import Pool
2 | import os
3 | from typing import Callable, Iterable, Sized
4 |
5 | from rich.progress import (BarColumn, MofNCompleteColumn, Progress, Task,
6 | TaskProgressColumn, TextColumn, TimeRemainingColumn)
7 | from rich.text import Text
8 | import os.path as osp
9 | import portalocker
10 | from ..smp import load, dump
11 |
12 |
13 | class _Worker:
14 | """Function wrapper for ``track_progress_rich``"""
15 |
16 | def __init__(self, func) -> None:
17 | self.func = func
18 |
19 | def __call__(self, inputs):
20 | inputs, idx = inputs
21 | if not isinstance(inputs, (tuple, list, dict)):
22 | inputs = (inputs, )
23 |
24 | if isinstance(inputs, dict):
25 | return self.func(**inputs), idx
26 | else:
27 | return self.func(*inputs), idx
28 |
29 |
30 | class _SkipFirstTimeRemainingColumn(TimeRemainingColumn):
31 | """Skip calculating remaining time for the first few times.
32 |
33 | Args:
34 | skip_times (int): The number of times to skip. Defaults to 0.
35 | """
36 |
37 | def __init__(self, *args, skip_times=0, **kwargs):
38 | super().__init__(*args, **kwargs)
39 | self.skip_times = skip_times
40 |
41 | def render(self, task: Task) -> Text:
42 | """Show time remaining."""
43 | if task.completed <= self.skip_times:
44 | return Text('-:--:--', style='progress.remaining')
45 | return super().render(task)
46 |
47 |
48 | def _tasks_with_index(tasks):
49 | """Add index to tasks."""
50 | for idx, task in enumerate(tasks):
51 | yield task, idx
52 |
53 | def track_progress_rich(func: Callable,
54 | tasks: Iterable = tuple(),
55 | task_num: int = None,
56 | nproc: int = 1,
57 | chunksize: int = 1,
58 | description: str = 'Processing',
59 | save=None, keys=None,
60 | color: str = 'blue') -> list:
61 | """Track the progress of parallel task execution with a progress bar. The
62 | built-in :mod:`multiprocessing` module is used for process pools and tasks
63 | are done with :func:`Pool.map` or :func:`Pool.imap_unordered`.
64 |
65 | Args:
66 | func (callable): The function to be applied to each task.
67 | tasks (Iterable or Sized): A tuple of tasks. There are several cases
68 | for different format tasks:
69 | - When ``func`` accepts no arguments: tasks should be an empty
70 | tuple, and ``task_num`` must be specified.
71 | - When ``func`` accepts only one argument: tasks should be a tuple
72 | containing the argument.
73 | - When ``func`` accepts multiple arguments: tasks should be a
74 | tuple, with each element representing a set of arguments.
75 | If an element is a ``dict``, it will be parsed as a set of
76 | keyword-only arguments.
77 | Defaults to an empty tuple.
78 | task_num (int, optional): If ``tasks`` is an iterator which does not
79 | have length, the number of tasks can be provided by ``task_num``.
80 | Defaults to None.
81 |         nproc (int): Process (worker) number, if nproc is 1,
82 | use single process. Defaults to 1.
83 | chunksize (int): Refer to :class:`multiprocessing.Pool` for details.
84 | Defaults to 1.
85 | description (str): The description of progress bar.
86 |             Defaults to "Processing".
87 | color (str): The color of progress bar. Defaults to "blue".
88 |
89 | Examples:
90 | >>> import time
91 |
92 | >>> def func(x):
93 | ... time.sleep(1)
94 | ... return x**2
95 | >>> track_progress_rich(func, range(10), nproc=2)
96 |
97 | Returns:
98 | list: The task results.
99 | """
100 | if save is not None:
101 | assert osp.exists(osp.dirname(save)) or osp.dirname(save) == ''
102 | if not osp.exists(save):
103 | dump({}, save)
104 | if keys is not None:
105 | assert len(keys) == len(tasks)
106 |
107 | if not callable(func):
108 | raise TypeError('func must be a callable object')
109 | if not isinstance(tasks, Iterable):
110 | raise TypeError(
111 | f'tasks must be an iterable object, but got {type(tasks)}')
112 | if isinstance(tasks, Sized):
113 | if len(tasks) == 0:
114 | if task_num is None:
115 | raise ValueError('If tasks is an empty iterable, '
116 | 'task_num must be set')
117 | else:
118 | tasks = tuple(tuple() for _ in range(task_num))
119 | else:
120 | if task_num is not None and task_num != len(tasks):
121 | raise ValueError('task_num does not match the length of tasks')
122 | task_num = len(tasks)
123 |
124 | if nproc <= 0:
125 | raise ValueError('nproc must be a positive number')
126 |
127 | skip_times = nproc * chunksize if nproc > 1 else 0
128 | prog_bar = Progress(
129 | TextColumn('{task.description}'),
130 | BarColumn(),
131 | _SkipFirstTimeRemainingColumn(skip_times=skip_times),
132 | MofNCompleteColumn(),
133 | TaskProgressColumn(show_speed=True),
134 | )
135 |
136 | worker = _Worker(func)
137 | task_id = prog_bar.add_task(
138 | total=task_num, color=color, description=description)
139 | tasks = _tasks_with_index(tasks)
140 |
141 | # Use single process when nproc is 1, else use multiprocess.
142 | with prog_bar:
143 | if nproc == 1:
144 | results = []
145 | for task in tasks:
146 | result, idx = worker(task)
147 |                     results.append(result)  # reuse the result computed above rather than calling func a second time
148 | if save is not None:
149 | with portalocker.Lock(save, timeout=5) as fh:
150 | ans = load(save)
151 | ans[keys[idx]] = result
152 |
153 | if os.environ.get('VERBOSE', True):
154 | print(keys[idx], result, flush=True)
155 |
156 | dump(ans, save)
157 | fh.flush()
158 | os.fsync(fh.fileno())
159 |
160 | prog_bar.update(task_id, advance=1, refresh=True)
161 | else:
162 | with Pool(nproc) as pool:
163 | results = []
164 | unordered_results = []
165 | gen = pool.imap_unordered(worker, tasks, chunksize)
166 | try:
167 | for result in gen:
168 | result, idx = result
169 | unordered_results.append((result, idx))
170 |
171 | if save is not None:
172 | with portalocker.Lock(save, timeout=5) as fh:
173 | ans = load(save)
174 | ans[keys[idx]] = result
175 |
176 | if os.environ.get('VERBOSE', False):
177 | print(keys[idx], result, flush=True)
178 |
179 | dump(ans, save)
180 | fh.flush()
181 | os.fsync(fh.fileno())
182 |
183 | results.append(None)
184 | prog_bar.update(task_id, advance=1, refresh=True)
185 | except Exception as e:
186 | prog_bar.stop()
187 | raise e
188 | for result, idx in unordered_results:
189 | results[idx] = result
190 | return results
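# Usage sketch of the save/keys checkpointing path (the file name and keys are placeholders):
#   def square(x):
#       return x ** 2
#   results = track_progress_rich(
#       square, tasks=list(range(4)), nproc=2,
#       save='squares.pkl', keys=[f'item_{i}' for i in range(4)])
# Each finished result is also dumped into 'squares.pkl' under its key, so partial progress
# survives an interrupted run.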
--------------------------------------------------------------------------------
/vlmeval/api/gpt.py:
--------------------------------------------------------------------------------
1 | from ..smp import *
2 | import os, sys
3 | from .base import BaseAPI
4 |
5 | APIBASES = {
6 | 'OFFICIAL': "https://api.openai.com/v1/chat/completions",
7 | }
8 |
9 |
10 | def GPT_context_window(model):
11 | length_map = {
12 | 'gpt-4-1106-preview': 128000,
13 | 'gpt-4-vision-preview': 128000,
14 | 'gpt-4': 8192,
15 | 'gpt-4-32k': 32768,
16 | 'gpt-4-0613': 8192,
17 | 'gpt-4-32k-0613': 32768,
18 | 'gpt-3.5-turbo-1106': 16385,
19 | 'gpt-3.5-turbo': 4096,
20 | 'gpt-3.5-turbo-16k': 16385,
21 | 'gpt-3.5-turbo-instruct': 4096,
22 | 'gpt-3.5-turbo-0613': 4096,
23 | 'gpt-3.5-turbo-16k-0613': 16385,
24 | }
25 | if model in length_map:
26 | return length_map[model]
27 | else:
28 | return 4096
29 |
30 | class OpenAIWrapper(BaseAPI):
31 |
32 | is_api: bool = True
33 |
34 | def __init__(self,
35 | model: str = 'gpt-3.5-turbo-0613',
36 | retry: int = 5,
37 | wait: int = 5,
38 | key: str = None,
39 | verbose: bool = True,
40 | system_prompt: str = None,
41 | temperature: float = 0,
42 | timeout: int = 60,
43 | api_base: str = 'OFFICIAL',
44 | max_tokens: int = 1024,
45 | img_size: int = 512,
46 | img_detail: str = 'low',
47 | **kwargs):
48 |
49 | self.model = model
50 | self.cur_idx = 0
51 | self.fail_msg = 'Failed to obtain answer via API. '
52 | self.max_tokens = max_tokens
53 | self.temperature = temperature
54 |
55 | openai_key = os.environ.get('OPENAI_API_KEY', None) if key is None else key
56 | self.openai_key = openai_key
57 | assert img_size > 0 or img_size == -1
58 | self.img_size = img_size
59 | assert img_detail in ['high', 'low']
60 | self.img_detail = img_detail
61 |
62 | self.vision = False
63 | if model == 'gpt-4-vision-preview':
64 | self.vision = True
65 | self.timeout = timeout
66 |
67 | assert isinstance(openai_key, str) and openai_key.startswith('sk-'), f'Illegal openai_key {openai_key}. Please set the environment variable OPENAI_API_KEY to your openai key. '
68 | super().__init__(wait=wait, retry=retry, system_prompt=system_prompt, verbose=verbose, **kwargs)
69 |
70 | if api_base in APIBASES:
71 | self.api_base = APIBASES[api_base]
72 | elif api_base.startswith('http'):
73 | self.api_base = api_base
74 | else:
75 | self.logger.error("Unknown API Base. ")
76 | sys.exit(-1)
77 |
78 | if 'OPENAI_API_BASE' in os.environ:
79 | self.logger.error("Environment variable OPENAI_API_BASE is set. Will override the api_base arg. ")
80 | self.api_base = os.environ['OPENAI_API_BASE']
81 |
82 | # inputs can be a lvl-2 nested list: [content1, content2, content3, ...]
83 | # content can be a string or a list of image & text
84 | def prepare_inputs(self, inputs):
85 | input_msgs = []
86 | if self.system_prompt is not None:
87 | input_msgs.append(dict(role='system', content=self.system_prompt))
88 | if isinstance(inputs, str):
89 | input_msgs.append(dict(role='user', content=inputs))
90 | return input_msgs
91 | assert isinstance(inputs, list)
92 | dict_flag = [isinstance(x, dict) for x in inputs]
93 | if np.all(dict_flag):
94 | input_msgs.extend(inputs)
95 | return input_msgs
96 | str_flag = [isinstance(x, str) for x in inputs]
97 | if np.all(str_flag):
98 | img_flag = [x.startswith('http') or osp.exists(x) for x in inputs]
99 | if np.any(img_flag):
100 | content_list = []
101 | for fl, msg in zip(img_flag, inputs):
102 | if not fl:
103 | content_list.append(dict(type='text', text=msg))
104 | elif msg.startswith('http'):
105 | content_list.append(dict(type='image_url', image_url={'url': msg, 'detail': self.img_detail}))
106 | elif osp.exists(msg):
107 | from PIL import Image
108 | img = Image.open(msg)
109 | b64 = encode_image_to_base64(img, target_size=self.img_size)
110 | img_struct = dict(url=f'data:image/jpeg;base64,{b64}', detail=self.img_detail)
111 | content_list.append(dict(type='image_url', image_url=img_struct))
112 | input_msgs.append(dict(role='user', content=content_list))
113 | return input_msgs
114 | else:
115 | roles = ['user', 'assistant'] if len(inputs) % 2 == 1 else ['assistant', 'user']
116 | roles = roles * len(inputs)
117 | for role, msg in zip(roles, inputs):
118 | input_msgs.append(dict(role=role, content=msg))
119 | return input_msgs
120 | raise NotImplementedError('list of list prompt not implemented now. ')
121 |
122 | def generate_inner(self, inputs, **kwargs) -> str:
123 | input_msgs = self.prepare_inputs(inputs)
124 | temperature = kwargs.pop('temperature', self.temperature)
125 | max_tokens = kwargs.pop('max_tokens', self.max_tokens)
126 |
127 | context_window = GPT_context_window(self.model)
128 | max_tokens = min(max_tokens, context_window - self.get_token_len(inputs))
129 | if 0 < max_tokens <= 100:
130 | self.logger.warning('Less than 100 tokens left, may exceed the context window with some additional meta symbols. ')
131 | if max_tokens <= 0:
132 | return 0, self.fail_msg + 'Input string longer than context window. ', 'Length Exceeded. '
133 |
134 | headers = {'Content-Type': 'application/json', 'Authorization': f'Bearer {self.openai_key}'}
135 | payload = dict(
136 | model=self.model,
137 | messages=input_msgs,
138 | max_tokens=max_tokens,
139 | n=1,
140 | temperature=temperature,
141 | **kwargs)
142 | response = requests.post(self.api_base, headers=headers, data=json.dumps(payload), timeout=self.timeout * 1.1)
143 | ret_code = response.status_code
144 | ret_code = 0 if (200 <= int(ret_code) < 300) else ret_code
145 | answer = self.fail_msg
146 | try:
147 | resp_struct = json.loads(response.text)
148 | answer = resp_struct['choices'][0]['message']['content'].strip()
149 | except:
150 | pass
151 | return ret_code, answer, response
152 |
153 | def get_token_len(self, inputs) -> int:
154 | import tiktoken
155 | enc = tiktoken.encoding_for_model(self.model)
156 | if isinstance(inputs, str):
157 | if inputs.startswith('http') or osp.exists(inputs):
158 | return 65 if self.img_detail == 'low' else 130
159 | else:
160 | return len(enc.encode(inputs))
161 | elif isinstance(inputs, dict):
162 | assert 'content' in inputs
163 | return self.get_token_len(inputs['content'])
164 | assert isinstance(inputs, list)
165 | res = 0
166 | for item in inputs:
167 | res += self.get_token_len(item)
168 | return res
169 |
170 | class GPT4V(OpenAIWrapper):
171 |
172 | def generate(self, image_path, prompt, dataset=None):
173 | assert self.model == 'gpt-4-vision-preview'
174 | return super(GPT4V, self).generate([image_path, prompt])
175 |
176 | def interleave_generate(self, ti_list, dataset=None):
177 | assert self.model == 'gpt-4-vision-preview'
178 | return super(GPT4V, self).generate(ti_list)
179 |
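
A hedged usage sketch of the wrappers above; it assumes OPENAI_API_KEY is exported, that BaseAPI.generate drives generate_inner with the retry/wait logic, and that demo.jpg is a hypothetical local image:

from vlmeval.api.gpt import OpenAIWrapper, GPT4V

# Text-only judge, as used by the evaluation scripts.
judge = OpenAIWrapper(model='gpt-3.5-turbo-0613', temperature=0)
print(judge.generate('Reply with a single word: OK'))

# Vision model: generate() wraps (image_path, prompt) into one user message whose
# content interleaves image_url and text parts, as built by prepare_inputs.
vlm = GPT4V(model='gpt-4-vision-preview', img_size=512, img_detail='low')
print(vlm.generate('demo.jpg', 'Describe the image in one sentence.'))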
--------------------------------------------------------------------------------
/vlmeval/utils/dataset.py:
--------------------------------------------------------------------------------
1 | import pandas as pd
2 | import hashlib
3 | from ..smp import *
4 | from .dataset_config import dataset_URLs, dataset_md5_dict, img_root_map, DATASET_TYPE
5 | from .custom_prompt import CustomPrompt
6 | from .base_prompt import create_one_example
7 | import csv
8 |
9 | def isliststr(s):
10 | return (s[0] == '[') and (s[-1] == ']')
11 |
12 | def check_md5(data_path, dataset):
13 | try:
14 | with open(data_path, 'rb') as f:
15 | hash = hashlib.new('md5')
16 | for chunk in iter(lambda: f.read(2**20), b''):
17 | hash.update(chunk)
18 | if str(hash.hexdigest()) == dataset_md5_dict[dataset]:
19 | return True
20 | else:
21 | warnings.warn('The data file is incomplete, so it needs to be downloaded again.')
22 | return False
23 | except:
24 | return False
25 |
26 |
27 | def split_MMMU(struct):
28 | assert 'image' in struct and 'text' in struct
29 | text, images = struct['text'], struct['image']
30 | text_segs = text.split('<image ')
31 | segs = [text_segs[0]]
32 | for i, seg in enumerate(text_segs):
33 | if i == 0:
34 | continue
35 | assert istype(seg[0], int) and seg[1] == '>'
36 | image_idx = int(seg[0]) - 1
37 | segs.append(images[image_idx])
38 | segs.append(seg[2:])
39 | return segs
40 |
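
For illustration only (and assuming the '<image N>' placeholder convention reconstructed above), split_MMMU interleaves the referenced images with the surrounding text:

struct = {
    'text': 'Compare <image 1> with <image 2> and pick the larger object.',
    'image': ['/path/a.jpg', '/path/b.jpg'],   # hypothetical paths
}
print(split_MMMU(struct))
# -> ['Compare ', '/path/a.jpg', ' with ', '/path/b.jpg', ' and pick the larger object.']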
41 | def init_prompt_multi(struct, format_):
42 | question = struct['text']['question'] if 'question' in struct['text'] else 'none'
43 | context = struct['text']['hint'] if 'hint' in struct['text'] else 'none'
44 | options = struct['text']['options'] if 'options' in struct['text'] else 'none'
45 | answer = struct['debate_ans'] if 'debate_ans' in struct else 'none'
46 | knowledge = struct['kg'] if 'kg' in struct else 'none'
47 | image_path = struct['image'] if 'image' in struct else 'none'
48 |
49 | prompt = create_one_example(format_, question, context.replace("\n"," "), options, answer, knowledge, image_path)
50 | return prompt
51 |
52 | class TSVDataset(CustomPrompt):
53 |
54 | def __init__(self, dataset='MMBench_DEV_EN', img_root=None, skip_noimg=True):
55 |
56 | self.data_root = LMUDataRoot()
57 | assert osp.exists(self.data_root)
58 |
59 | self.dataset = dataset
60 | self.dataset_type = DATASET_TYPE(dataset)
61 |
62 | url = dataset_URLs[dataset]
63 | file_name = url.split('/')[-1]
64 | data_path = osp.join(self.data_root, file_name)
65 | print(data_path)
66 |
67 | if osp.exists(data_path) and md5(data_path) == dataset_md5_dict[dataset]:
68 | print("Dataset is already downloaded: ", data_path)
69 | pass
70 | else:
71 | warnings.warn("The dataset tsv is not downloaded")
72 | download_file(url, data_path)
73 |
74 | data = load(data_path)
75 | if dataset=="ScienceQA_TEST":
76 | kg_file = "/code/BDoG/data/kg_init/scienceqa_test_kg_gpt4.json"
77 | kg_base = json.load(open(kg_file))
78 | kg_base = list(kg_base.values())
79 | elif dataset=="MMBench_DEV_EN":
80 | kg_file = "/code/BDoG/data/kg_init/MMBench_DEV_EN_s.json"
81 | kg_base = json.load(open(kg_file))
82 | kg_base = list(kg_base.values())
83 | else:
84 | kg_base = ['none' for i in range(len(data))]
85 |
86 | self.skip_noimg = skip_noimg
87 | if skip_noimg:
88 | data = data[~pd.isna(data['image'])]
89 |
90 | # Prompt for Captioning
91 | if listinstr(['COCO'], dataset):
92 | data['question'] = ['Please describe this image in general. Directly provide the description, do not include prefix like "This image depicts". '] * len(data)
93 |
94 | data['index'] = [str(x) for x in data['index']]
95 | data['image'] = [str(x) for x in data['image']]
96 | ## Add kg_init
97 | data['kg'] = kg_base
98 |
99 | image_map = {x: y for x, y in zip(data['index'], data['image'])}
100 | for k in image_map:
101 | if len(image_map[k]) <= 64:
102 | idx = image_map[k]
103 | assert idx in image_map and len(image_map[idx]) > 64
104 | image_map[k] = image_map[idx]
105 |
106 | data['image'] = [
107 | eval(image_map[k]) if isliststr(image_map[k]) else image_map[k]
108 | for k in data['index']
109 | ]
110 | if 'image_path' in data:
111 | data['image_path'] = [
112 | eval(pths) if isliststr(pths) else pths for pths in data['image_path']
113 | ]
114 | if np.all([istype(x, int) for x in data['index']]):
115 | data['index'] = [int(x) for x in data['index']]
116 |
117 | self.data = data
118 |
119 | img_root = img_root if img_root is not None else osp.join('images', img_root_map[dataset])
120 | os.makedirs(img_root, exist_ok=True)
121 | self.img_root = img_root
122 |
123 | # self.set_file()
124 | # print("#### Save: /code/VLMEvalKit-main/data/CCBench_Fin.tsv")
125 |
126 | def __len__(self):
127 | return len(self.data)
128 |
129 | def set_file(self, data_in, data_out):
130 | dataset = self.dataset
131 | print("##START Processing: ", dataset)
132 | with open(data_in, 'r') as in_file:
133 | with open(data_out, 'w', newline='') as out_file:
134 | reader = csv.DictReader(in_file, delimiter='\t')
135 | fieldnames = reader.fieldnames
136 | writer = csv.DictWriter(out_file, fieldnames=fieldnames, delimiter='\t')
137 |
138 | writer.writeheader()
139 | for i, row in enumerate(tqdm(reader)):
140 | # Add a key-value pair (the dumped image path) to each row
141 | line = self.data.iloc[i]
142 | tgt_path = self.dump_image(line, dataset)
143 | row['image'] = tgt_path
144 | writer.writerow(row)
145 | print(writer.fieldnames)
146 |
147 | def build_prompt(self, line, dataset=None):
148 | if dataset is None:
149 | dataset = self.dataset
150 |
151 | if isinstance(line, int):
152 | line = self.data.iloc[line]
153 |
154 | tgt_path = self.dump_image(line, dataset)
155 |
156 | prompt = line['question']
157 | if DATASET_TYPE(dataset) == 'multi-choice':
158 | question = line['question']
159 | options = {
160 | cand: line[cand]
161 | for cand in string.ascii_uppercase
162 | if cand in line and not pd.isna(line[cand])
163 | }
164 | options_prompt = 'Options:\n'
165 | for key, item in options.items():
166 | options_prompt += f'{key}. {item}\n'
167 | hint = line['hint'] if ('hint' in line and not pd.isna(line['hint'])) else None
168 | prompt = ''
169 | if hint is not None:
170 | prompt += f'Hint: {hint}\n'
171 | prompt += f'Question: {question}\n'
172 | if len(options):
173 | prompt += options_prompt
174 | prompt += "Answer with the option's letter from the given choices directly.\n"
175 |
176 | return dict(image=tgt_path, text=prompt)
177 |
178 |
179 | def build_prompt_multi(self, line, dataset=None):
180 | if dataset is None:
181 | dataset = self.dataset
182 |
183 | if isinstance(line, int):
184 | line = self.data.iloc[line]
185 |
186 | tgt_path = self.dump_image(line, dataset)
187 | prompt = {}
188 | prompt['index'] = line['index']
189 | prompt['answer'] = line['gpt4_ans'] if dataset == "LLaVABench" else line['answer']
190 | prompt['question'] = line['question']
191 | if 'kg' in line:
192 | prompt['kg'] = str(line['kg'])[:350]
193 | else:
194 | prompt['kg'] = 'none'
195 | if DATASET_TYPE(dataset) == 'multi-choice':
196 | options = {
197 | cand: line[cand]
198 | for cand in string.ascii_uppercase
199 | if cand in line and not pd.isna(line[cand])
200 | }
201 | options_prompt = ''
202 | choise = []
203 | for key, item in options.items():
204 | options_prompt += f'{key}. {item}\n'
205 | choise.append(item)
206 |
207 | hint = line['hint'] if ('hint' in line and not pd.isna(line['hint'])) else None
208 | if dataset == "LLaVABench":
209 | hint = line['caption'] if ('caption' in line and not pd.isna(line['caption'])) else None
210 |
211 | if hint is not None:
212 | prompt['hint'] = f'{hint}'
213 | if len(options):
214 | prompt['options'] = options_prompt
215 | # prompt['options'] += 'Please select the correct answer from the options above.'
216 | # prompt['options'] += "Answer with the option's letter from the given choices directly"
217 | prompt['choise'] = choise
218 | # print(tgt_path)
219 | return dict(image=tgt_path, text=prompt)
220 |
221 | def display(self, line):
222 | if isinstance(line, int):
223 | line = self.data.iloc[line]
224 | mmqa_display(line)
225 |
226 |
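
A brief usage sketch of TSVDataset; the dataset name is the constructor default, and note that the ScienceQA_TEST and MMBench_DEV_EN branches additionally expect the hardcoded kg_init JSON files under /code/BDoG/ to exist:

from vlmeval.utils.dataset import TSVDataset

data = TSVDataset('MMBench_DEV_EN')   # downloads the TSV into LMUDataRoot() if the md5 check fails
print(len(data))
sample = data.build_prompt(0)         # dict(image=<dumped image path(s)>, text=<hint/question/options prompt>)
multi = data.build_prompt_multi(0)    # text is a dict with 'index', 'question', 'options', 'choise', 'kg', ...
print(sample['text'])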
--------------------------------------------------------------------------------
/vlmeval/api/hf_chat_model.py:
--------------------------------------------------------------------------------
1 | import os, sys
2 | import os.path as osp
3 | import torch
4 | from ..smp import *
5 |
6 | def get_gpu_num(model_name):
7 | model_name = model_name.lower()
8 | kws = {
9 | 8: ['65b', '70b'],
10 | 4: ['30b', '33b', '35b', '40b'],
11 | 2: ['13b', '14b', '20b'],
12 | 1: ['6b', '7b', 'moss'],
13 | }
14 | for k in [8, 4, 2, 1]:
15 | for keyword in kws[k]:
16 | if keyword in model_name:
17 | return k
18 | return 8
19 |
20 | validated_llms = [
21 | 'internlm/internlm-chat-7b', 'internlm/internlm-chat-7b-8k', 'internlm/internlm-chat-20b',
22 | 'Qwen/Qwen-7B-Chat', 'Qwen/Qwen-14B-Chat',
23 | 'THUDM/chatglm2-6b', 'THUDM/chatglm2-6b-32k', 'THUDM/chatglm3-6b', 'THUDM/chatglm3-6b-32k',
24 | 'baichuan-inc/Baichuan2-7B-Chat', 'baichuan-inc/Baichuan2-13B-Chat',
25 | 'lmsys/vicuna-7b-v1.5', 'lmsys/vicuna-13b-v1.5',
26 | 'meta-llama/Llama-2-7b-chat-hf'
27 | ]
28 | Auto_model = ['chatglm']
29 |
30 | class HFChatModel:
31 |
32 | def _get_context_length(self, model, model_path):
33 | # By default, we use model.config.seq_length
34 | model_path = model_path.lower()
35 | if 'baichuan' in model_path:
36 | context_window = model.config.model_max_length
37 | elif 'internlm' in model_path or 'llama' in model_path:
38 | context_window = model.config.max_position_embeddings
39 | elif 'vicuna' in model_path:
40 | context_window = model.generation_config.max_length
41 | else:
42 | # chatglm & qwen
43 | context_window = model.config.seq_length
44 | return context_window
45 |
46 | def _get_context_length_robust(self, model, model_path):
47 | try:
48 | context_window = self._get_context_length(model, model_path)
49 | return context_window
50 | except:
51 | self.logger.critical(
52 | "Failed to extract context_window information from config / generation_config. "
53 | "Please read the above code and check if the logic works for your model path"
54 | )
55 | raise NotImplementedError
56 |
57 | def __init__(self,
58 | model_path,
59 | system_prompt: str=None,
60 | **kwargs):
61 |
62 | self.logger = get_logger('HFChatModel')
63 | if 'vicuna' in model_path.lower():
64 | try:
65 | from fastchat.model import get_conversation_template
66 | except:
67 | self.logger.critical("Please install fastchat first to use vicuna. ")
68 | sys.exit(-1)
69 |
70 | self.explicit_device = kwargs.pop('device', None)
71 |
72 | if self.explicit_device is None:
73 | # If CUDA_VISIBLE_DEVICES is not properly set
74 | if 'CUDA_VISIBLE_DEVICES' not in os.environ or os.environ['CUDA_VISIBLE_DEVICES'] in ['', '0,1,2,3,4,5,6,7']:
75 | num_gpu = get_gpu_num(model_path)
76 | gpu_offset = kwargs.pop('gpu_offset', 0)
77 | cuda_visible_devices = ','.join([str(i) for i in range(gpu_offset, gpu_offset+num_gpu)])
78 | os.environ['CUDA_VISIBLE_DEVICES'] = cuda_visible_devices
79 |
80 | from transformers import AutoTokenizer, AutoModelForCausalLM, AutoModel
81 | from transformers.generation import GenerationConfig
82 |
83 | if model_path not in validated_llms:
84 | self.logger.warning(f"{model_path} not in validated LLMs, may have inference troubles. ")
85 |
86 | self.model_path = model_path
87 | if listinstr(Auto_model, model_path):
88 | LoadModel = AutoModel
89 | else:
90 | LoadModel = AutoModelForCausalLM
91 |
92 | assert osp.exists(model_path) or len(model_path.split('/')) == 2
93 |
94 | device = self.explicit_device if self.explicit_device else "auto"
95 |
96 | precision = {}
97 | if 'internlm-chat-7b' in model_path:
98 | precision = {'torch_dtype': torch.float16}
99 | elif 'internlm-chat-20b' in model_path:
100 | precision = {'torch_dtype': torch.bfloat16}
101 |
102 | self.tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
103 | model = LoadModel.from_pretrained(model_path, trust_remote_code=True, device_map='cpu', **precision)
104 | model = model.eval()
105 |
106 | if device != 'cpu':
107 | model = model.to(f'cuda:{device}' if isinstance(device, int) else 'cuda')
108 | try:
109 | model.generation_config = GenerationConfig.from_pretrained(model_path, trust_remote_code=True, device_map=device)
110 | except:
111 | pass
112 |
113 | torch.cuda.empty_cache()
114 | self.model = model
115 | self.context_length = self._get_context_length_robust(model=model, model_path=model_path)
116 | self.answer_buffer = 192
117 | self.system_prompt = system_prompt
118 | for k, v in kwargs.items():
119 | self.logger.info(f'Following args are passed and will be used as generation hyper-parameters (if not set specifically): {k}: {v}. ')
120 | self.kwargs = kwargs
121 |
122 | def generate_str(self, input, **kwargs):
123 | if 'baichuan' in self.model_path.lower():
124 | messages = []
125 | messages.append({"role": "user", "content": input})
126 | resp = self.model.chat(self.tokenizer, messages, **kwargs)
127 | elif 'vicuna' in self.model_path.lower():
128 | from fastchat.model import get_conversation_template
129 | conv = get_conversation_template('vicuna')
130 | conv.append_message(conv.roles[0], input)
131 | conv.append_message(conv.roles[1], None)
132 | prompt = conv.get_prompt()
133 | inputs = self.tokenizer([prompt], return_tensors="pt")
134 | if torch.cuda.is_available():
135 | for k in inputs:
136 | inputs[k] = inputs[k].cuda()
137 |
138 | params = dict(do_sample=True, temperature=0.7, repetition_penalty=1.0, max_new_tokens=512)
139 | params.update(self.kwargs)
140 | params.update(kwargs)
141 | outputs = self.model.generate(**inputs, **params)
142 | resp = self.tokenizer.decode(outputs[0][len(inputs["input_ids"][0]) :], skip_special_tokens=True, spaces_between_special_tokens=False)
143 |
144 | else:
145 | params = self.kwargs
146 | params.update(kwargs)
147 | resp, _ = self.model.chat(self.tokenizer, input, history=[], **params)
148 |
149 | return resp
150 |
151 | def length_ok(self, inputs):
152 | tot = len(self.tokenizer.encode(self.system_prompt)) if self.system_prompt is not None else 0
153 | for s in inputs:
154 | tot += len(self.tokenizer.encode(s))
155 | return tot + self.answer_buffer < self.context_length
156 |
157 | def generate_list(self, full_inputs, offset=0, **kwargs):
158 | assert isinstance(full_inputs, list)
159 |
160 | inputs = full_inputs[offset:]
161 | if not self.length_ok(inputs):
162 | return self.generate_list(full_inputs, offset + 1, **kwargs)  # no self.chat exists; recurse with a larger offset to drop the oldest turn
163 |
164 | model_path = self.model_path.lower()
165 |
166 | if sum([x in model_path for x in ['baichuan']]):
167 | input_msgs = []
168 | if self.system_prompt is not None:
169 | input_msgs.append(dict(role='user', content=self.system_prompt))
170 | if len(inputs):
171 | assert isinstance(inputs, list) and isinstance(inputs[0], str)
172 | roles = ['user', 'assistant'] if len(inputs) % 2 == 1 else ['assistant', 'user']
173 | roles = roles * len(inputs)
174 | for role, msg in zip(roles, inputs):
175 | input_msgs.append(dict(role=role, content=msg))
176 | response = self.model.chat(self.tokenizer, input_msgs)
177 | elif sum([x in model_path for x in ['vicuna']]):
178 | from fastchat.model import get_conversation_template
179 | conv = get_conversation_template('vicuna')
180 | assert isinstance(inputs, list) and isinstance(inputs[0], str)
181 | if len(inputs) % 2 == 1:
182 | if self.system_prompt is not None:
183 | conv.append_message(conv.roles[0], self.system_prompt)
184 | for i in range(len(inputs)//2):
185 | conv.append_message(conv.roles[0], inputs[2 * i])
186 | conv.append_message(conv.roles[1], inputs[2 * i + 1])
187 | else:
188 | assert self.system_prompt is not None
189 | conv.append_message(conv.roles[0], self.system_prompt)
190 | conv.append_message(conv.roles[1], inputs[0])
191 | for i in range(len(inputs) // 2 - 1):
192 | conv.append_message(conv.roles[0], inputs[2 * i + 1])
193 | conv.append_message(conv.roles[1], inputs[2 * i + 2])
194 | conv.append_message(conv.roles[0], inputs[-1])
195 | conv.append_message(conv.roles[1], None)
196 | prompt = conv.get_prompt()
197 | inputs = self.tokenizer([prompt], return_tensors="pt")
198 | if torch.cuda.is_available():
199 | for k in inputs:
200 | inputs[k] = inputs[k].cuda()
201 |
202 | params = dict(do_sample=True, temperature=0.7, repetition_penalty=1.0, max_new_tokens=512)
203 | params.update(self.kwargs)
204 | params.update(kwargs)
205 |
206 | outputs = self.model.generate(**inputs, **params)
207 | response = self.tokenizer.decode(outputs[0][len(inputs["input_ids"][0]) :], skip_special_tokens=True, spaces_between_special_tokens=False)
208 | response = response.lstrip('\n')
209 | else:
210 | # The default option, support internlm, chatglm, qwen
211 | history, msg = [], None
212 | if len(inputs) % 2 == 1:
213 | if self.system_prompt is not None:
214 | history = [(self.system_prompt, '')]
215 | for i in range(len(inputs)//2):
216 | history.append((inputs[2 * i], inputs[2 * i + 1]))
217 | else:
218 | assert self.system_prompt is not None
219 | history = [(self.system_prompt, inputs[0])]
220 | for i in range(len(inputs) // 2 - 1):
221 | history.append((inputs[2 * i + 1], inputs[2 * i + 2]))
222 | msg = inputs[-1]
223 |
224 | params = self.kwargs
225 | params.update(kwargs)
226 | response, _ = self.model.chat(self.tokenizer, msg, history=history, **params)
227 |
228 | return response, offset
229 |
230 | def generate(self, inputs, **kwargs):
231 | if isinstance(inputs, str):
232 | return self.generate_str(inputs, **kwargs)
233 | elif isinstance(inputs, list):
234 | return self.generate_list(inputs, **kwargs)
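
A hedged sketch of driving HFChatModel; the model path is taken from the validated_llms list above and enough GPU memory is assumed:

from vlmeval.api.hf_chat_model import HFChatModel

llm = HFChatModel('internlm/internlm-chat-7b', system_prompt='You are a helpful assistant.')
# A plain string goes through generate_str ...
print(llm.generate('Briefly explain what a context window is.'))
# ... while a list of alternating user/assistant turns goes through generate_list, which
# returns (response, offset) and drops the oldest turns when the context would overflow.
reply, offset = llm.generate(['Hi!', 'Hello, how can I help?', 'Summarize our chat so far.'])
print(offset, reply)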
--------------------------------------------------------------------------------
/vlmeval/evaluate/multiple_choice.py:
--------------------------------------------------------------------------------
1 | import os.path as osp
2 | import pandas as pd
3 | from tqdm import tqdm
4 | from vlmeval.evaluate.misc import build_judge
5 | from vlmeval.utils import can_infer, track_progress_rich, TSVDataset
6 | from vlmeval.smp import *
7 | import numpy as np
8 |
9 | INTERNAL = os.environ.get('INTERNAL', 0)
10 |
11 | abbrs = {
12 | 'coarse_perception': 'CP',
13 | 'finegrained_perception (instance-level)': 'FP-S',
14 | 'finegrained_perception (cross-instance)': 'FP-C',
15 | 'logic_reasoning': 'LR',
16 | 'relation_reasoning': 'RR',
17 | 'attribute_reasoning': 'AR'
18 | }
19 |
20 |
21 |
22 |
23 | def MMMU_preproc(data):
24 | logger = get_logger('Evaluation')
25 | cnt = 0
26 | As, Bs, Ans = list(data['A']), list(data['B']), list(data['answer'])
27 | lt = len(data)
28 | for i in range(lt):
29 | if pd.isna(As[i]):
30 | As[i] = Ans[i]
31 | Bs[i] = 'Other Answers'
32 | cnt += 1
33 | logger.info(f'During MMMU_preproc in Evaluation, {cnt} open questions are re-formulated to multi-choice ones. ')
34 | data['A'] = As
35 | data['B'] = Bs
36 | return data
37 |
38 |
39 | def report_acc(df):
40 | # assert group in [None, 'category', 'l2-category']
41 | res = defaultdict(list)
42 |
43 | if 'split' in df:
44 | splits = list(set(df['split']))
45 | res['split'] = splits
46 | else:
47 | df['split'] = ['none'] * len(df)
48 | res['split'] = ['none']
49 |
50 | for group in [None, 'l2-category', 'category']:
51 | if group is None:
52 | res['Overall'] = [np.mean(df[df['split'] == sp]['hit']) for sp in res['split']]
53 | elif group not in df:
54 | continue
55 | else:
56 | abilities = list(set(df[group]))
57 | abilities.sort()
58 | for ab in abilities:
59 | ab_name = abbrs[ab] if ab in abbrs else ab
60 | sub_df = df[df[group] == ab]
61 | res[ab_name] = [np.mean(sub_df[sub_df['split'] == sp]['hit']) for sp in res['split']]
62 | return pd.DataFrame(res)
63 |
64 |
65 | def build_prompt(question, options, prediction):
66 | tmpl = (
67 | 'You are an AI assistant who will help me to match '
68 | 'an answer with several options of a single-choice question. '
69 | 'You are provided with a question, several options, and an answer, '
70 | 'and you need to find which option is most similar to the answer. '
71 | 'If the meaning of all options are significantly different from the answer, output Z. '
72 | 'You should output a single uppercase character in A, B, C, D (if they are valid options), and Z. \n'
73 | 'Example 1: \n'
74 | 'Question: What is the main object in image?\nOptions: A. teddy bear B. rabbit C. cat D. dog\n'
75 | 'Answer: a cute teddy bear\nYour output: A\n'
76 | 'Example 2: \n'
77 | 'Question: What is the main object in image?\nOptions: A. teddy bear B. rabbit C. cat D. dog\n'
78 | 'Answer: Spider\nYour output: Z\n'
79 | 'Example 3: \n'
80 | 'Question: {}?\nOptions: {}\nAnswer: {}\nYour output: '
81 | )
82 | return tmpl.format(question, options, prediction)
83 |
84 |
85 | def build_prompt_cn(question, options, prediction):
86 | tmpl = (
87 | '你是一个帮助我匹配答案与单选题中多个选项的 AI 助手。'
88 | '你会被提供:一个问题,多个选项,一个答案。你的任务是找到与答案意义最相近的选项。'
89 | '如果所有选项的意义都与答案显著不同,则输出 Z。'
90 | '你应该输出一个单个的大写字母,例如 A, B, C, D(如果它们是有效选项),或 Z。'
91 | '例 1:'
92 | '问题: 图中最主要的物体是什么?\n选项: A. 泰迪熊 B. 兔子 C. 猫 D. 狗\n答案: 一只可爱的泰迪熊\n输出: A\n'
93 | '例 2: \n'
94 | '问题: 图中最主要的物体是什么?\n选项: A. 泰迪熊 B. 兔子 C. 猫 D. 狗\n答案: 蜘蛛\n输出: Z\n'
95 | '例 3: \n'
96 | '问题: {}?\n选项: {}\n答案: {}\n输出: '
97 | )
98 | return tmpl.format(question, options, prediction)
99 |
100 |
101 | def build_choices(item):
102 | ret = {}
103 | for ch in string.ascii_uppercase:
104 | if ch in item and (not pd.isna(item[ch])):
105 | ret[ch] = item[ch]
106 | return ret
107 |
108 |
109 | def prefetch_answer(item):
110 | choices = build_choices(item)
111 | return can_infer(item['prediction'], choices)
112 |
113 |
114 | def extract_answer_from_item(model, item):
115 | logger = get_logger('Evaluation')
116 | # It will return: (pred, raw, llm_time)
117 | choices = build_choices(item)
118 | option_str = build_option_str(choices)
119 |
120 | if cn_string(item['question']):
121 | prompt = build_prompt_cn(item['question'], option_str, item['prediction'])
122 | else:
123 | prompt = build_prompt(item['question'], option_str, item['prediction'])
124 | retry = 3
125 |
126 | ret = can_infer(item['prediction'], choices)
127 | if ret:
128 | return dict(opt=ret, log=item['prediction'])
129 |
130 | while retry:
131 | ans = model.generate(prompt)
132 | if 'Failed to obtain answer via API' in ans:
133 | logger.warning('GPT API failed to answer. ')
134 | else:
135 | ret = can_infer(ans, choices)
136 | if ret:
137 | return dict(opt=ret, log=ans)
138 | else:
139 | logger.warning(f'Output includes 0 / > 1 letter among candidates {set(choices)} and Z: {ans}')
140 | retry -= 1
141 |
142 | if retry == 0:
143 | options = list(choices) + ['Z'] if 'Z' not in choices else []
144 | return dict(opt=rd.choice(options), log='Failed to predict, thus randomly generate one. ')
145 |
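
As an illustration of the matching flow above (the option texts and prediction string are made up): can_infer is tried first, and only when it cannot pin down a single letter is the judge model prompted with build_prompt / build_prompt_cn.

item = dict(question='What is the main object in the image?',
            A='teddy bear', B='rabbit', C='cat', D='dog',
            prediction='It looks like a cute teddy bear.')
print(prefetch_answer(item))   # 'A' if can_infer matches a single option, else a falsy value
# When prefetching fails, extract_answer_from_item(judge, item) falls back to the GPT judge
# and, after three failed attempts, to a random choice among the options and Z.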
146 |
147 | def prefetch_sub_data(sub_data, answer_map, verbose=False):
148 | lt = len(sub_data)
149 | GT, PRED = [], []
150 | for i in range(lt):
151 | item = sub_data.iloc[i]
152 | idx = item['index']
153 | GT.append(answer_map[idx])
154 | PRED.append(prefetch_answer(item))
155 | if PRED[-1] and (GT[-1] != PRED[-1]):
156 | log = (
157 | f'Failed in Prefetching Rolling {i}: Answer is {GT[-1]}, '
158 | f"Prediction is {item['prediction']}, Pre-fetched is {PRED[-1]}. "
159 | )
160 | return dict(hit=0, log=log)
161 | flag = True
162 | for g, p in zip(GT, PRED):
163 | if g != p:
164 | flag = False
165 | ret = (dict(hit=1, log='Succeed During Pre-fetching'), ) if flag else (None, )
166 | ret = ret + (GT, PRED) if verbose else ret
167 | return ret if len(ret) > 1 else ret[0]
168 |
169 |
170 | def eval_sub_data(model, sub_data, answer_map):
171 | res, GT, PRED = prefetch_sub_data(sub_data, answer_map, verbose=True)
172 | if res is not None:
173 | return res
174 |
175 | lt = len(sub_data)
176 | log = ''
177 | for i in range(lt):
178 | if PRED[i]:
179 | log += f'Rolling {i} Matched.\n'
180 | else:
181 | res = extract_answer_from_item(model, sub_data.iloc[i])
182 | opt, match_log = res['opt'], res['log']
183 | PRED[i] = opt
184 | if PRED[i] != GT[i]:
185 | log += (
186 | f"Failed in Rolling {i}: Answer is {GT[i]}; Prediction is {sub_data.iloc[i]['prediction']}; "
187 | f'Pre-fetched is {PRED[i]}; Match Log is {match_log}.\n'
188 | )
189 | return dict(hit=0, log=log)
190 | else:
191 | log += (
192 | f"Rolling {i}: Answer is {GT[i]}, Prediction is {sub_data.iloc[i]['prediction']}, "
193 | f'Pre-fetched is {PRED[i]}.\n'
194 | )
195 |
196 | return dict(hit=1, log=log)
197 |
198 |
199 | def eval_data_groups(model, data_groups, answer_map, result, result_file, nproc=16):
200 | prefetched = [prefetch_sub_data(g, answer_map, verbose=False) for g in data_groups]
201 | remain = []
202 | for dg, pf in zip(data_groups, prefetched):
203 | if pf:
204 | result[dg.iloc[0]['index'] % 1e6] = pf
205 | else:
206 | remain.append(dg)
207 | dump(result, result_file)
208 | tups = [(model, x, answer_map) for x in remain]
209 | keys = [x.iloc[0]['index'] % 1e6 for x in remain]
210 | if len(tups) == 0:
211 | return
212 |
213 | if model is None:
214 | logger = get_logger('Evaluation')
215 | logger.warning('Exact Matching mode, will not do GPT-based answer matching. ')
216 | for k in keys:
217 | result[k] = dict(
218 | hit=0, log='Failed in Prefetch, no GPT-based answer matching under `exact_matching` policy.')
219 | dump(result, result_file)
220 | return
221 |
222 | res = track_progress_rich(
223 | eval_sub_data,
224 | tups,
225 | nproc=nproc,
226 | chunksize=nproc,
227 | save=result_file,
228 | keys=keys)
229 | result = load(result_file)
230 | for k, v in zip(keys, res):
231 | if k in result:
232 | assert result[k]['hit'] == v['hit'] and result[k]['log'] == v['log']
233 | else:
234 | result[k] = v
235 | dump(result, result_file)
236 |
237 |
238 | def multiple_choice_eval(eval_file, dataset='default', model='chatgpt-0613', nproc=4, verbose=False):
239 | logger = get_logger('Evaluation')
240 |
241 | # assert dataset is not None
242 | if dataset == 'MMBench_TEST_CN':
243 | dataset = 'MMBench_CN'
244 | elif dataset == 'MMBench_TEST_EN':
245 | dataset = 'MMBench'
246 |
247 | # if listinstr(['mmbench', 'ccbench'], dataset.lower()):
248 | # data = load(eval_file)
249 | # data['index'] = [int(x) for x in data['index']]
250 | # dump(data, eval_file)
251 |
252 | rd.seed(2680)
253 | suffix = eval_file.split('.')[-1]
254 | assert model in ['chatgpt-0613', 'exact_matching', 'gpt-4-0125']
255 | name_str_map = {
256 | 'chatgpt-0613': 'openai',
257 | 'gpt-4-0125': 'gpt4'
258 | }
259 | name_str = name_str_map[model] if model in name_str_map else model
260 |
261 | if model == 'exact_matching':
262 | model = None
263 | else:
264 | if INTERNAL or gpt_key_set():
265 | model = build_judge(model, verbose=verbose, retry=10)
266 | else:
267 | logger.error('OPENAI_API_KEY is not set properly, will use exact matching for evaluation')
268 | model = None
269 |
270 | logger.info(f'Evaluating {eval_file}')
271 | result_file = eval_file.replace(f'.{suffix}', f'_{name_str}_result.pkl')
272 | result = {}
273 | if osp.exists(result_file):
274 | result = load(result_file)
275 |
276 | data = load(eval_file)
277 | data = data.sort_values(by='index')
278 | data['prediction'] = [str(x) for x in data['prediction']]
279 | for k in data.keys():
280 | data[k.lower() if k not in list(string.ascii_uppercase) else k] = data.pop(k)
281 |
282 | if dataset != 'default':
283 | meta = TSVDataset(dataset).data
284 | else:
285 | logger.warning('Dataset is not provided, try to use the original `eval_file` as meta data. ')
286 | meta = load(eval_file)
287 | assert 'category' in meta and 'index' in meta and 'answer' in meta, (
288 | 'Essential columns missing in the eval_file.')
289 |
290 | cate_map = {i: c for i, c in zip(meta['index'], meta['category'])}
291 | answer_map = {i: c for i, c in zip(meta['index'], meta['answer'])}
292 | l2_cate_map = {i: c for i, c in zip(meta['index'], meta['l2-category'])} if 'l2-category' in meta else None
293 | split_map = {i: c for i, c in zip(meta['index'], meta['split'])} if 'split' in meta else None
294 |
295 | if l2_cate_map is not None and np.all([pd.isna(x) for x in l2_cate_map.values()]):
296 | l2_cate_map = None
297 | if split_map is not None and np.all([pd.isna(x) for x in split_map.values()]):
298 | split_map = None
299 |
300 | if listinstr(['MMMU'], dataset):
301 | data = MMMU_preproc(data)
302 | answer_map = {k: (v if v in list(string.ascii_uppercase) else 'A') for k, v in answer_map.items()}
303 |
304 | data = data[data['index'].isin(answer_map)]
305 | data_main = data[data['index'] < int(1e6)]
306 | meta_idx_set = set(meta['index'])
307 | data_main = data_main[data_main['index'].isin(meta_idx_set)]
308 |
309 | lt = len(data_main)
310 | hit, tot = 0, 0
311 |
312 | data_groups = []
313 | for i in tqdm(range(lt)):
314 | # Dealing with the normal part
315 | item_main = data_main.iloc[i]
316 | idx = item_main['index']
317 |
318 | if idx in result:
319 | correct = result[idx]['hit']
320 | assert correct in [0, 1]
321 | hit += correct
322 | tot += 1
323 | continue
324 |
325 | sub_data = data[data['index'] % int(1e6) == idx]
326 | data_groups.append(sub_data)
327 |
328 | if len(data_groups):
329 | eval_data_groups(
330 | model=model,
331 | data_groups=data_groups,
332 | answer_map=answer_map,
333 | nproc=nproc,
334 | result=result,
335 | result_file=result_file)
336 |
337 | tmp_pth = f'/tmp/{timestr()}.xlsx'
338 | dump(data_main, tmp_pth)
339 | data_main = load(tmp_pth)
340 |
341 | res = load(result_file)
342 | indices = data_main['index']
343 |
344 | data_main['hit'] = [res[i]['hit'] for i in indices]
345 | data_main['log'] = [res[i]['log'] for i in indices]
346 |
347 | main_idx = data_main['index']
348 | data_main['category'] = [cate_map[i] for i in main_idx]
349 | if l2_cate_map is not None:
350 | data_main['l2-category'] = [l2_cate_map[i] for i in main_idx]
351 | if split_map is not None:
352 | data_main['split'] = [split_map[i] for i in indices]
353 |
354 | # load split
355 | dump(data_main, eval_file.replace(f'.{suffix}', f'_{name_str}_result.{suffix}'))
356 | data_main = load(eval_file.replace(f'.{suffix}', f'_{name_str}_result.{suffix}'))
357 |
358 | acc = report_acc(data_main)
359 | score_file = eval_file.replace(f'.{suffix}', '_acc.csv')
360 | dump(acc, score_file)
361 | logger.info(f'multiple_choice_eval successfully finished evaluating {eval_file}, results saved in {score_file}')
362 | logger.info('Score: ')
363 | logger.info(acc)
364 | return acc
365 |
366 |
367 | def parse_args():
368 | parser = argparse.ArgumentParser(description='Evaluate multiple-choice predictions with an optional LLM judge. ')
369 | parser.add_argument('data', type=str, help='The prediction file to evaluate, in excel / tsv / json format. ')
370 | parser.add_argument(
371 | '--model',
372 | type=str,
373 | help='The LLM (GPT) used for answer matching. ',
374 | default='chatgpt-0613',
375 | choices=['chatgpt-0613', 'exact_matching', 'gpt-4-0125'])
376 | parser.add_argument(
377 | '--dataset',
378 | type=str,
379 | default='default',
380 | help='The dataset to evaluate')
381 | parser.add_argument('--nproc', type=int, default=6)
382 | parser.add_argument('--verbose', action='store_true')
383 | args = parser.parse_args()
384 | return args
385 |
386 |
387 | if __name__ == '__main__':
388 | args = parse_args()
389 | acc = multiple_choice_eval(
390 | eval_file=args.data, model=args.model, dataset=args.dataset, nproc=args.nproc, verbose=args.verbose)
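
Programmatic usage (the prediction file name is hypothetical); the CLI entry point above accepts the same options, roughly: python vlmeval/evaluate/multiple_choice.py work_dirs/GPT4V_MMBench_DEV_EN.xlsx --dataset MMBench_DEV_EN --model chatgpt-0613 --nproc 6

from vlmeval.evaluate.multiple_choice import multiple_choice_eval

acc = multiple_choice_eval(
    'work_dirs/GPT4V_MMBench_DEV_EN.xlsx',   # hypothetical prediction file
    dataset='MMBench_DEV_EN',
    model='exact_matching',                  # skip the GPT judge entirely
    nproc=4)
print(acc)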
--------------------------------------------------------------------------------