├── vlmeval ├── evaluate │ ├── __init__.py │ ├── .ipynb_checkpoints │ │ ├── __init__-checkpoint.py │ │ └── misc-checkpoint.py │ ├── __pycache__ │ │ ├── misc.cpython-310.pyc │ │ ├── OCRBench.cpython-310.pyc │ │ ├── __init__.cpython-310.pyc │ │ ├── coco_eval.cpython-310.pyc │ │ ├── vqa_eval.cpython-310.pyc │ │ ├── yes_or_no.cpython-310.pyc │ │ ├── llavabench.cpython-310.pyc │ │ ├── mmvet_eval.cpython-310.pyc │ │ ├── mathvista_eval.cpython-310.pyc │ │ └── multiple_choice.cpython-310.pyc │ ├── misc.py │ └── multiple_choice.py ├── __pycache__ │ ├── config.cpython-310.pyc │ ├── __init__.cpython-310.pyc │ ├── __init__.cpython-38.pyc │ └── inference_multi.cpython-310.pyc ├── smp │ ├── __pycache__ │ │ ├── lb.cpython-310.pyc │ │ ├── vlm.cpython-38.pyc │ │ ├── file.cpython-310.pyc │ │ ├── file.cpython-38.pyc │ │ ├── log.cpython-310.pyc │ │ ├── misc.cpython-310.pyc │ │ ├── misc.cpython-38.pyc │ │ ├── vlm.cpython-310.pyc │ │ ├── __init__.cpython-38.pyc │ │ └── __init__.cpython-310.pyc │ ├── __init__.py │ ├── .ipynb_checkpoints │ │ ├── __init__-checkpoint.py │ │ ├── log-checkpoint.py │ │ ├── misc-checkpoint.py │ │ ├── vlm-checkpoint.py │ │ ├── file-checkpoint.py │ │ └── lb-checkpoint.py │ ├── log.py │ ├── misc.py │ ├── vlm.py │ ├── file.py │ └── lb.py ├── api │ ├── __pycache__ │ │ ├── base.cpython-310.pyc │ │ ├── gpt.cpython-310.pyc │ │ ├── gemini.cpython-310.pyc │ │ ├── gpt_int.cpython-310.pyc │ │ ├── __init__.cpython-310.pyc │ │ ├── qwen_api.cpython-310.pyc │ │ ├── qwen_vl_api.cpython-310.pyc │ │ └── hf_chat_model.cpython-310.pyc │ ├── __init__.py │ ├── base.py │ ├── qwen_api.py │ ├── qwen_vl_api.py │ ├── gemini.py │ ├── gpt_int.py │ ├── gpt.py │ └── hf_chat_model.py ├── vlm │ ├── __pycache__ │ │ ├── emu.cpython-310.pyc │ │ ├── cogvlm.cpython-310.pyc │ │ ├── idefics.cpython-310.pyc │ │ ├── llava.cpython-310.pyc │ │ ├── mmalaya.cpython-310.pyc │ │ ├── monkey.cpython-310.pyc │ │ ├── omnilmm.cpython-310.pyc │ │ ├── qwen_vl.cpython-310.pyc │ │ ├── yi_vl.cpython-310.pyc │ │ ├── __init__.cpython-310.pyc │ │ ├── minicpm_v.cpython-310.pyc │ │ ├── minigpt4.cpython-310.pyc │ │ ├── pandagpt.cpython-310.pyc │ │ ├── visualglm.cpython-310.pyc │ │ ├── xcomposer.cpython-310.pyc │ │ ├── instructblip.cpython-310.pyc │ │ ├── llava_xtuner.cpython-310.pyc │ │ ├── mplug_owl2.cpython-310.pyc │ │ ├── transcore_m.cpython-310.pyc │ │ ├── xcomposer2.cpython-310.pyc │ │ ├── internvl_chat.cpython-310.pyc │ │ ├── open_flamingo.cpython-310.pyc │ │ └── sharedcaptioner.cpython-310.pyc │ ├── __init__.py │ ├── .ipynb_checkpoints │ │ ├── __init__-checkpoint.py │ │ └── instructblip-checkpoint.py │ └── instructblip.py ├── utils │ ├── __pycache__ │ │ ├── dataset.cpython-310.pyc │ │ ├── debate.cpython-310.pyc │ │ ├── mp_util.cpython-310.pyc │ │ ├── __init__.cpython-310.pyc │ │ ├── base_prompt.cpython-310.pyc │ │ ├── custom_prompt.cpython-310.pyc │ │ ├── dataset_config.cpython-310.pyc │ │ └── matching_util.cpython-310.pyc │ ├── __init__.py │ ├── .ipynb_checkpoints │ │ ├── __init__-checkpoint.py │ │ ├── custom_prompt-checkpoint.py │ │ ├── matching_util-checkpoint.py │ │ ├── debate-checkpoint.py │ │ ├── dataset_config-checkpoint.py │ │ ├── base_prompt-checkpoint.py │ │ ├── mp_util-checkpoint.py │ │ └── dataset-checkpoint.py │ ├── custom_prompt.py │ ├── matching_util.py │ ├── debate.py │ ├── dataset_config.py │ ├── base_prompt.py │ ├── mp_util.py │ └── dataset.py ├── __init__.py ├── .ipynb_checkpoints │ ├── __init__-checkpoint.py │ ├── config-checkpoint.py │ └── inference_multi-checkpoint.py ├── config.py └── inference_multi.py ├── assets ├── 
Model1.png ├── overview.png └── framework.png ├── scripts └── debate.sh ├── requirements.txt ├── README.md ├── run.py └── setup.py /vlmeval/evaluate/__init__.py: -------------------------------------------------------------------------------- 1 | from .multiple_choice import multiple_choice_eval -------------------------------------------------------------------------------- /assets/Model1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thecharm/BDoG/HEAD/assets/Model1.png -------------------------------------------------------------------------------- /assets/overview.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thecharm/BDoG/HEAD/assets/overview.png -------------------------------------------------------------------------------- /assets/framework.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thecharm/BDoG/HEAD/assets/framework.png -------------------------------------------------------------------------------- /vlmeval/evaluate/.ipynb_checkpoints/__init__-checkpoint.py: -------------------------------------------------------------------------------- 1 | from .multiple_choice import multiple_choice_eval -------------------------------------------------------------------------------- /vlmeval/__pycache__/config.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thecharm/BDoG/HEAD/vlmeval/__pycache__/config.cpython-310.pyc -------------------------------------------------------------------------------- /vlmeval/smp/__pycache__/lb.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thecharm/BDoG/HEAD/vlmeval/smp/__pycache__/lb.cpython-310.pyc -------------------------------------------------------------------------------- /vlmeval/smp/__pycache__/vlm.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thecharm/BDoG/HEAD/vlmeval/smp/__pycache__/vlm.cpython-38.pyc -------------------------------------------------------------------------------- /vlmeval/__pycache__/__init__.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thecharm/BDoG/HEAD/vlmeval/__pycache__/__init__.cpython-310.pyc -------------------------------------------------------------------------------- /vlmeval/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thecharm/BDoG/HEAD/vlmeval/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /vlmeval/api/__pycache__/base.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thecharm/BDoG/HEAD/vlmeval/api/__pycache__/base.cpython-310.pyc -------------------------------------------------------------------------------- /vlmeval/api/__pycache__/gpt.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thecharm/BDoG/HEAD/vlmeval/api/__pycache__/gpt.cpython-310.pyc 
-------------------------------------------------------------------------------- /vlmeval/smp/__init__.py: -------------------------------------------------------------------------------- 1 | from .file import * 2 | from .vlm import * 3 | from .misc import * 4 | from .log import * 5 | from .lb import * -------------------------------------------------------------------------------- /vlmeval/smp/__pycache__/file.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thecharm/BDoG/HEAD/vlmeval/smp/__pycache__/file.cpython-310.pyc -------------------------------------------------------------------------------- /vlmeval/smp/__pycache__/file.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thecharm/BDoG/HEAD/vlmeval/smp/__pycache__/file.cpython-38.pyc -------------------------------------------------------------------------------- /vlmeval/smp/__pycache__/log.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thecharm/BDoG/HEAD/vlmeval/smp/__pycache__/log.cpython-310.pyc -------------------------------------------------------------------------------- /vlmeval/smp/__pycache__/misc.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thecharm/BDoG/HEAD/vlmeval/smp/__pycache__/misc.cpython-310.pyc -------------------------------------------------------------------------------- /vlmeval/smp/__pycache__/misc.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thecharm/BDoG/HEAD/vlmeval/smp/__pycache__/misc.cpython-38.pyc -------------------------------------------------------------------------------- /vlmeval/smp/__pycache__/vlm.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thecharm/BDoG/HEAD/vlmeval/smp/__pycache__/vlm.cpython-310.pyc -------------------------------------------------------------------------------- /vlmeval/vlm/__pycache__/emu.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thecharm/BDoG/HEAD/vlmeval/vlm/__pycache__/emu.cpython-310.pyc -------------------------------------------------------------------------------- /vlmeval/api/__pycache__/gemini.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thecharm/BDoG/HEAD/vlmeval/api/__pycache__/gemini.cpython-310.pyc -------------------------------------------------------------------------------- /vlmeval/api/__pycache__/gpt_int.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thecharm/BDoG/HEAD/vlmeval/api/__pycache__/gpt_int.cpython-310.pyc -------------------------------------------------------------------------------- /vlmeval/smp/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thecharm/BDoG/HEAD/vlmeval/smp/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /vlmeval/vlm/__pycache__/cogvlm.cpython-310.pyc: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/thecharm/BDoG/HEAD/vlmeval/vlm/__pycache__/cogvlm.cpython-310.pyc -------------------------------------------------------------------------------- /vlmeval/vlm/__pycache__/idefics.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thecharm/BDoG/HEAD/vlmeval/vlm/__pycache__/idefics.cpython-310.pyc -------------------------------------------------------------------------------- /vlmeval/vlm/__pycache__/llava.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thecharm/BDoG/HEAD/vlmeval/vlm/__pycache__/llava.cpython-310.pyc -------------------------------------------------------------------------------- /vlmeval/vlm/__pycache__/mmalaya.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thecharm/BDoG/HEAD/vlmeval/vlm/__pycache__/mmalaya.cpython-310.pyc -------------------------------------------------------------------------------- /vlmeval/vlm/__pycache__/monkey.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thecharm/BDoG/HEAD/vlmeval/vlm/__pycache__/monkey.cpython-310.pyc -------------------------------------------------------------------------------- /vlmeval/vlm/__pycache__/omnilmm.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thecharm/BDoG/HEAD/vlmeval/vlm/__pycache__/omnilmm.cpython-310.pyc -------------------------------------------------------------------------------- /vlmeval/vlm/__pycache__/qwen_vl.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thecharm/BDoG/HEAD/vlmeval/vlm/__pycache__/qwen_vl.cpython-310.pyc -------------------------------------------------------------------------------- /vlmeval/vlm/__pycache__/yi_vl.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thecharm/BDoG/HEAD/vlmeval/vlm/__pycache__/yi_vl.cpython-310.pyc -------------------------------------------------------------------------------- /vlmeval/api/__pycache__/__init__.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thecharm/BDoG/HEAD/vlmeval/api/__pycache__/__init__.cpython-310.pyc -------------------------------------------------------------------------------- /vlmeval/api/__pycache__/qwen_api.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thecharm/BDoG/HEAD/vlmeval/api/__pycache__/qwen_api.cpython-310.pyc -------------------------------------------------------------------------------- /vlmeval/evaluate/__pycache__/misc.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thecharm/BDoG/HEAD/vlmeval/evaluate/__pycache__/misc.cpython-310.pyc -------------------------------------------------------------------------------- /vlmeval/smp/__pycache__/__init__.cpython-310.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/thecharm/BDoG/HEAD/vlmeval/smp/__pycache__/__init__.cpython-310.pyc -------------------------------------------------------------------------------- /vlmeval/utils/__pycache__/dataset.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thecharm/BDoG/HEAD/vlmeval/utils/__pycache__/dataset.cpython-310.pyc -------------------------------------------------------------------------------- /vlmeval/utils/__pycache__/debate.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thecharm/BDoG/HEAD/vlmeval/utils/__pycache__/debate.cpython-310.pyc -------------------------------------------------------------------------------- /vlmeval/utils/__pycache__/mp_util.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thecharm/BDoG/HEAD/vlmeval/utils/__pycache__/mp_util.cpython-310.pyc -------------------------------------------------------------------------------- /vlmeval/vlm/__pycache__/__init__.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thecharm/BDoG/HEAD/vlmeval/vlm/__pycache__/__init__.cpython-310.pyc -------------------------------------------------------------------------------- /vlmeval/vlm/__pycache__/minicpm_v.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thecharm/BDoG/HEAD/vlmeval/vlm/__pycache__/minicpm_v.cpython-310.pyc -------------------------------------------------------------------------------- /vlmeval/vlm/__pycache__/minigpt4.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thecharm/BDoG/HEAD/vlmeval/vlm/__pycache__/minigpt4.cpython-310.pyc -------------------------------------------------------------------------------- /vlmeval/vlm/__pycache__/pandagpt.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thecharm/BDoG/HEAD/vlmeval/vlm/__pycache__/pandagpt.cpython-310.pyc -------------------------------------------------------------------------------- /vlmeval/vlm/__pycache__/visualglm.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thecharm/BDoG/HEAD/vlmeval/vlm/__pycache__/visualglm.cpython-310.pyc -------------------------------------------------------------------------------- /vlmeval/vlm/__pycache__/xcomposer.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thecharm/BDoG/HEAD/vlmeval/vlm/__pycache__/xcomposer.cpython-310.pyc -------------------------------------------------------------------------------- /vlmeval/__pycache__/inference_multi.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thecharm/BDoG/HEAD/vlmeval/__pycache__/inference_multi.cpython-310.pyc -------------------------------------------------------------------------------- /vlmeval/api/__pycache__/qwen_vl_api.cpython-310.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/thecharm/BDoG/HEAD/vlmeval/api/__pycache__/qwen_vl_api.cpython-310.pyc -------------------------------------------------------------------------------- /vlmeval/utils/__pycache__/__init__.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thecharm/BDoG/HEAD/vlmeval/utils/__pycache__/__init__.cpython-310.pyc -------------------------------------------------------------------------------- /vlmeval/vlm/__init__.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | torch.set_grad_enabled(False) 4 | torch.manual_seed(1234) 5 | from .instructblip import InstructBLIP 6 | -------------------------------------------------------------------------------- /vlmeval/vlm/__pycache__/instructblip.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thecharm/BDoG/HEAD/vlmeval/vlm/__pycache__/instructblip.cpython-310.pyc -------------------------------------------------------------------------------- /vlmeval/vlm/__pycache__/llava_xtuner.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thecharm/BDoG/HEAD/vlmeval/vlm/__pycache__/llava_xtuner.cpython-310.pyc -------------------------------------------------------------------------------- /vlmeval/vlm/__pycache__/mplug_owl2.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thecharm/BDoG/HEAD/vlmeval/vlm/__pycache__/mplug_owl2.cpython-310.pyc -------------------------------------------------------------------------------- /vlmeval/vlm/__pycache__/transcore_m.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thecharm/BDoG/HEAD/vlmeval/vlm/__pycache__/transcore_m.cpython-310.pyc -------------------------------------------------------------------------------- /vlmeval/vlm/__pycache__/xcomposer2.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thecharm/BDoG/HEAD/vlmeval/vlm/__pycache__/xcomposer2.cpython-310.pyc -------------------------------------------------------------------------------- /vlmeval/api/__pycache__/hf_chat_model.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thecharm/BDoG/HEAD/vlmeval/api/__pycache__/hf_chat_model.cpython-310.pyc -------------------------------------------------------------------------------- /vlmeval/evaluate/__pycache__/OCRBench.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thecharm/BDoG/HEAD/vlmeval/evaluate/__pycache__/OCRBench.cpython-310.pyc -------------------------------------------------------------------------------- /vlmeval/evaluate/__pycache__/__init__.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thecharm/BDoG/HEAD/vlmeval/evaluate/__pycache__/__init__.cpython-310.pyc -------------------------------------------------------------------------------- /vlmeval/evaluate/__pycache__/coco_eval.cpython-310.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/thecharm/BDoG/HEAD/vlmeval/evaluate/__pycache__/coco_eval.cpython-310.pyc -------------------------------------------------------------------------------- /vlmeval/evaluate/__pycache__/vqa_eval.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thecharm/BDoG/HEAD/vlmeval/evaluate/__pycache__/vqa_eval.cpython-310.pyc -------------------------------------------------------------------------------- /vlmeval/evaluate/__pycache__/yes_or_no.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thecharm/BDoG/HEAD/vlmeval/evaluate/__pycache__/yes_or_no.cpython-310.pyc -------------------------------------------------------------------------------- /vlmeval/utils/__pycache__/base_prompt.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thecharm/BDoG/HEAD/vlmeval/utils/__pycache__/base_prompt.cpython-310.pyc -------------------------------------------------------------------------------- /vlmeval/vlm/__pycache__/internvl_chat.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thecharm/BDoG/HEAD/vlmeval/vlm/__pycache__/internvl_chat.cpython-310.pyc -------------------------------------------------------------------------------- /vlmeval/vlm/__pycache__/open_flamingo.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thecharm/BDoG/HEAD/vlmeval/vlm/__pycache__/open_flamingo.cpython-310.pyc -------------------------------------------------------------------------------- /vlmeval/evaluate/__pycache__/llavabench.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thecharm/BDoG/HEAD/vlmeval/evaluate/__pycache__/llavabench.cpython-310.pyc -------------------------------------------------------------------------------- /vlmeval/evaluate/__pycache__/mmvet_eval.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thecharm/BDoG/HEAD/vlmeval/evaluate/__pycache__/mmvet_eval.cpython-310.pyc -------------------------------------------------------------------------------- /vlmeval/utils/__pycache__/custom_prompt.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thecharm/BDoG/HEAD/vlmeval/utils/__pycache__/custom_prompt.cpython-310.pyc -------------------------------------------------------------------------------- /vlmeval/utils/__pycache__/dataset_config.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thecharm/BDoG/HEAD/vlmeval/utils/__pycache__/dataset_config.cpython-310.pyc -------------------------------------------------------------------------------- /vlmeval/utils/__pycache__/matching_util.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thecharm/BDoG/HEAD/vlmeval/utils/__pycache__/matching_util.cpython-310.pyc -------------------------------------------------------------------------------- /vlmeval/vlm/__pycache__/sharedcaptioner.cpython-310.pyc: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/thecharm/BDoG/HEAD/vlmeval/vlm/__pycache__/sharedcaptioner.cpython-310.pyc
-------------------------------------------------------------------------------- /vlmeval/evaluate/__pycache__/mathvista_eval.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thecharm/BDoG/HEAD/vlmeval/evaluate/__pycache__/mathvista_eval.cpython-310.pyc
-------------------------------------------------------------------------------- /vlmeval/smp/.ipynb_checkpoints/__init__-checkpoint.py: -------------------------------------------------------------------------------- 1 | from .file import * 2 | from .vlm import * 3 | from .misc import * 4 | from .log import * 5 | from .lb import *
-------------------------------------------------------------------------------- /vlmeval/evaluate/__pycache__/multiple_choice.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thecharm/BDoG/HEAD/vlmeval/evaluate/__pycache__/multiple_choice.cpython-310.pyc
-------------------------------------------------------------------------------- /vlmeval/vlm/.ipynb_checkpoints/__init__-checkpoint.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | torch.set_grad_enabled(False) 4 | torch.manual_seed(1234) 5 | from .instructblip import InstructBLIP 6 |
-------------------------------------------------------------------------------- /vlmeval/__init__.py: -------------------------------------------------------------------------------- 1 | try: 2 | import torch 3 | except ImportError: 4 | pass 5 | 6 | from .smp import * 7 | from .evaluate import * 8 | from .utils import * 9 | from .api import * 10 | from .vlm import * 11 | from .config import *
-------------------------------------------------------------------------------- /vlmeval/.ipynb_checkpoints/__init__-checkpoint.py: -------------------------------------------------------------------------------- 1 | try: 2 | import torch 3 | except ImportError: 4 | pass 5 | 6 | from .smp import * 7 | from .evaluate import * 8 | from .utils import * 9 | from .api import * 10 | from .vlm import * 11 | from .config import *
-------------------------------------------------------------------------------- /scripts/debate.sh: -------------------------------------------------------------------------------- 1 | torchrun --nproc_per_node=2 run.py --data ScienceQA_TEST \ 2 | --model instructblip_13b \ 3 | --stage BDebate_kg_test \ 4 | --debate 2 \ 5 | --kg_init 6 |
-------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | numpy>=1.23.4 2 | requests 3 | tqdm 4 | pandas>=1.5.3 5 | gradio==4.15.0 6 | tiktoken 7 | rich 8 | portalocker 9 | torch>=2.0.1 10 | timeout-decorator 11 | transformers>=4.36.2 12 | opencv-python>=4.4.0.46 13 | typing_extensions==4.7.1 14 | pillow 15 | omegaconf 16 | matplotlib 17 | einops 18 | sentencepiece 19 | sty 20 | huggingface_hub 21 | visual_genome 22 | pycocoevalcap 23 | openpyxl 24 | seaborn 25 | tabulate 26 | xlsxwriter
-------------------------------------------------------------------------------- /vlmeval/api/__init__.py: -------------------------------------------------------------------------------- 1 | from .gpt import
OpenAIWrapper, GPT4V 2 | from .gpt_int import OpenAIWrapperInternal, GPT4V_Internal 3 | from .hf_chat_model import HFChatModel 4 | from .gemini import GeminiWrapper, GeminiProVision 5 | from .qwen_vl_api import QwenVLWrapper, QwenVLAPI 6 | from .qwen_api import QwenAPI 7 | 8 | __all__ = [ 9 | 'OpenAIWrapper', 'HFChatModel', 'OpenAIWrapperInternal', 'GeminiWrapper', 10 | 'GPT4V', 'GPT4V_Internal', 'GeminiProVision','QwenVLWrapper', 'QwenVLAPI', 11 | 'QwenAPI' 12 | ] -------------------------------------------------------------------------------- /vlmeval/config.py: -------------------------------------------------------------------------------- 1 | from .vlm import * 2 | from .api import GPT4V, GeminiProVision 3 | from functools import partial 4 | 5 | models = { 6 | 'instructblip_13b': partial(InstructBLIP, name='/code/BaseModel/VLM/instructblip-vicuna-13b'), 7 | } 8 | 9 | api_models = { 10 | 'GPT4V': partial(GPT4V, model='gpt-4-vision-preview', temperature=0, img_size=512, img_detail='low', retry=10), 11 | 'GeminiProVision': partial(GeminiProVision, temperature=0, retry=10), 12 | } 13 | 14 | supported_VLM = {} 15 | for model_set in [models, api_models]: 16 | supported_VLM.update(model_set) 17 | -------------------------------------------------------------------------------- /vlmeval/.ipynb_checkpoints/config-checkpoint.py: -------------------------------------------------------------------------------- 1 | from .vlm import * 2 | from .api import GPT4V, GeminiProVision 3 | from functools import partial 4 | 5 | models = { 6 | 'instructblip_13b': partial(InstructBLIP, name='/code/BaseModel/VLM/instructblip-vicuna-13b'), 7 | } 8 | 9 | api_models = { 10 | 'GPT4V': partial(GPT4V, model='gpt-4-vision-preview', temperature=0, img_size=512, img_detail='low', retry=10), 11 | 'GeminiProVision': partial(GeminiProVision, temperature=0, retry=10), 12 | } 13 | 14 | supported_VLM = {} 15 | for model_set in [models, api_models]: 16 | supported_VLM.update(model_set) 17 | -------------------------------------------------------------------------------- /vlmeval/evaluate/misc.py: -------------------------------------------------------------------------------- 1 | import os 2 | from vlmeval.api import OpenAIWrapper, OpenAIWrapperInternal 3 | 4 | INTERNAL = os.environ.get('INTERNAL', 0) 5 | 6 | def build_judge(version, **kwargs): 7 | model_map = { 8 | 'gpt-4-turbo': 'gpt-4-1106-preview', 9 | 'gpt-4-0613': 'gpt-4-0613', 10 | 'gpt-4-0314': 'gpt-4-0314', 11 | 'chatgpt-1106': 'gpt-3.5-turbo-1106', 12 | 'chatgpt-0613': 'gpt-3.5-turbo-0613' 13 | } 14 | model_version = model_map[version] 15 | if INTERNAL: 16 | model = OpenAIWrapperInternal(model_version, **kwargs) 17 | else: 18 | model = OpenAIWrapper(model_version, **kwargs) 19 | return model 20 | -------------------------------------------------------------------------------- /vlmeval/evaluate/.ipynb_checkpoints/misc-checkpoint.py: -------------------------------------------------------------------------------- 1 | import os 2 | from vlmeval.api import OpenAIWrapper, OpenAIWrapperInternal 3 | 4 | INTERNAL = os.environ.get('INTERNAL', 0) 5 | 6 | def build_judge(version, **kwargs): 7 | model_map = { 8 | 'gpt-4-turbo': 'gpt-4-1106-preview', 9 | 'gpt-4-0613': 'gpt-4-0613', 10 | 'gpt-4-0314': 'gpt-4-0314', 11 | 'chatgpt-1106': 'gpt-3.5-turbo-1106', 12 | 'chatgpt-0613': 'gpt-3.5-turbo-0613' 13 | } 14 | model_version = model_map[version] 15 | if INTERNAL: 16 | model = OpenAIWrapperInternal(model_version, **kwargs) 17 | else: 18 | model = OpenAIWrapper(model_version, 
**kwargs) 19 | return model 20 | -------------------------------------------------------------------------------- /vlmeval/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .matching_util import can_infer, can_infer_option, can_infer_text 2 | from .mp_util import track_progress_rich 3 | from .custom_prompt import CustomPrompt 4 | from .dataset_config import dataset_URLs, img_root_map, DATASET_TYPE, abbr2full 5 | from .dataset import TSVDataset, split_MMMU, init_prompt_multi 6 | from .base_prompt import create_one_example 7 | from .debate import Debate_VLM 8 | 9 | 10 | __all__ = [ 11 | 'can_infer', 'can_infer_option', 'can_infer_text', 'track_progress_rich', 12 | 'TSVDataset', 'dataset_URLs', 'img_root_map', 'DATASET_TYPE', 'CustomPrompt', 13 | 'split_MMMU', 'abbr2full', 'init_prompt_multi', 14 | 'create_one_example', 'Debate_VLM' 15 | ] -------------------------------------------------------------------------------- /vlmeval/utils/.ipynb_checkpoints/__init__-checkpoint.py: -------------------------------------------------------------------------------- 1 | from .matching_util import can_infer, can_infer_option, can_infer_text 2 | from .mp_util import track_progress_rich 3 | from .custom_prompt import CustomPrompt 4 | from .dataset_config import dataset_URLs, img_root_map, DATASET_TYPE, abbr2full 5 | from .dataset import TSVDataset, split_MMMU, init_prompt_multi 6 | from .base_prompt import create_one_example 7 | from .debate import Debate_VLM 8 | 9 | 10 | __all__ = [ 11 | 'can_infer', 'can_infer_option', 'can_infer_text', 'track_progress_rich', 12 | 'TSVDataset', 'dataset_URLs', 'img_root_map', 'DATASET_TYPE', 'CustomPrompt', 13 | 'split_MMMU', 'abbr2full', 'init_prompt_multi', 14 | 'create_one_example', 'Debate_VLM' 15 | ] -------------------------------------------------------------------------------- /vlmeval/utils/custom_prompt.py: -------------------------------------------------------------------------------- 1 | from ..smp import * 2 | from .dataset_config import img_root_map 3 | from abc import abstractmethod 4 | 5 | class CustomPrompt: 6 | 7 | @abstractmethod 8 | def use_custom_prompt(self, dataset): 9 | raise NotImplementedError 10 | 11 | @abstractmethod 12 | def build_prompt(self, line, dataset): 13 | raise NotImplementedError 14 | 15 | def dump_image(self, line, dataset): 16 | ROOT = LMUDataRoot() 17 | assert isinstance(dataset, str) 18 | img_root = osp.join(ROOT, 'images', img_root_map[dataset]) 19 | os.makedirs(img_root, exist_ok=True) 20 | if isinstance(line['image'], list): 21 | tgt_path = [] 22 | assert 'image_path' in line 23 | for img, im_name in zip(line['image'], line['image_path']): 24 | path = osp.join(img_root, im_name) 25 | if not read_ok(path): 26 | decode_base64_to_image_file(img, path) 27 | tgt_path.append(path) 28 | else: 29 | tgt_path = osp.join(img_root, f"{line['index']}.jpg") 30 | if not read_ok(tgt_path): 31 | decode_base64_to_image_file(line['image'], tgt_path) 32 | return tgt_path -------------------------------------------------------------------------------- /vlmeval/utils/.ipynb_checkpoints/custom_prompt-checkpoint.py: -------------------------------------------------------------------------------- 1 | from ..smp import * 2 | from .dataset_config import img_root_map 3 | from abc import abstractmethod 4 | 5 | class CustomPrompt: 6 | 7 | @abstractmethod 8 | def use_custom_prompt(self, dataset): 9 | raise NotImplementedError 10 | 11 | @abstractmethod 12 | def build_prompt(self, line, dataset): 13 
| raise NotImplementedError 14 | 15 | def dump_image(self, line, dataset): 16 | ROOT = LMUDataRoot() 17 | assert isinstance(dataset, str) 18 | img_root = osp.join(ROOT, 'images', img_root_map[dataset]) 19 | os.makedirs(img_root, exist_ok=True) 20 | if isinstance(line['image'], list): 21 | tgt_path = [] 22 | assert 'image_path' in line 23 | for img, im_name in zip(line['image'], line['image_path']): 24 | path = osp.join(img_root, im_name) 25 | if not read_ok(path): 26 | decode_base64_to_image_file(img, path) 27 | tgt_path.append(path) 28 | else: 29 | tgt_path = osp.join(img_root, f"{line['index']}.jpg") 30 | if not read_ok(tgt_path): 31 | decode_base64_to_image_file(line['image'], tgt_path) 32 | return tgt_path -------------------------------------------------------------------------------- /vlmeval/smp/log.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | logger_initialized = {} 4 | 5 | def get_logger(name, log_file=None, log_level=logging.INFO, file_mode='a'): 6 | logger = logging.getLogger(name) 7 | if name in logger_initialized: 8 | return logger 9 | 10 | for logger_name in logger_initialized: 11 | if name.startswith(logger_name): 12 | return logger 13 | 14 | stream_handler = logging.StreamHandler() 15 | handlers = [stream_handler] 16 | 17 | try: 18 | import torch.distributed as dist 19 | if dist.is_available() and dist.is_initialized(): 20 | rank = dist.get_rank() 21 | else: 22 | rank = 0 23 | except ImportError: 24 | rank = 0 25 | 26 | if rank == 0 and log_file is not None: 27 | file_handler = logging.FileHandler(log_file, file_mode) 28 | handlers.append(file_handler) 29 | 30 | formatter = logging.Formatter( 31 | '%(asctime)s - %(name)s - %(levelname)s - %(message)s') 32 | for handler in handlers: 33 | handler.setFormatter(formatter) 34 | handler.setLevel(log_level) 35 | logger.addHandler(handler) 36 | 37 | if rank == 0: 38 | logger.setLevel(log_level) 39 | else: 40 | logger.setLevel(logging.ERROR) 41 | 42 | logger_initialized[name] = True 43 | return logger -------------------------------------------------------------------------------- /vlmeval/smp/.ipynb_checkpoints/log-checkpoint.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | logger_initialized = {} 4 | 5 | def get_logger(name, log_file=None, log_level=logging.INFO, file_mode='a'): 6 | logger = logging.getLogger(name) 7 | if name in logger_initialized: 8 | return logger 9 | 10 | for logger_name in logger_initialized: 11 | if name.startswith(logger_name): 12 | return logger 13 | 14 | stream_handler = logging.StreamHandler() 15 | handlers = [stream_handler] 16 | 17 | try: 18 | import torch.distributed as dist 19 | if dist.is_available() and dist.is_initialized(): 20 | rank = dist.get_rank() 21 | else: 22 | rank = 0 23 | except ImportError: 24 | rank = 0 25 | 26 | if rank == 0 and log_file is not None: 27 | file_handler = logging.FileHandler(log_file, file_mode) 28 | handlers.append(file_handler) 29 | 30 | formatter = logging.Formatter( 31 | '%(asctime)s - %(name)s - %(levelname)s - %(message)s') 32 | for handler in handlers: 33 | handler.setFormatter(formatter) 34 | handler.setLevel(log_level) 35 | logger.addHandler(handler) 36 | 37 | if rank == 0: 38 | logger.setLevel(log_level) 39 | else: 40 | logger.setLevel(logging.ERROR) 41 | 42 | logger_initialized[name] = True 43 | return logger -------------------------------------------------------------------------------- /vlmeval/vlm/instructblip.py: 
-------------------------------------------------------------------------------- 1 | import torch 2 | from PIL import Image 3 | from abc import abstractproperty 4 | import os.path as osp 5 | import random 6 | import os, sys 7 | from ..smp import * 8 | from transformers import InstructBlipProcessor, InstructBlipForConditionalGeneration 9 | 10 | 11 | class InstructBLIP: 12 | 13 | INSTALL_REQ = True 14 | 15 | def __init__(self, name): 16 | 17 | model = InstructBlipForConditionalGeneration.from_pretrained(name, torch_dtype=torch.float16, device_map='cuda').eval() 18 | processor = InstructBlipProcessor.from_pretrained(name) 19 | self.processors = processor 20 | self.model = model 21 | 22 | def generate(self, prompt, image_path, dataset=None, max_length=100): 23 | raw_image = Image.open(image_path).convert('RGB') 24 | inputs = self.processors(images=raw_image, text=prompt, return_tensors="pt").to("cuda") 25 | try: 26 | outputs = self.model.generate( 27 | **inputs, 28 | do_sample=False, 29 | num_beams=5, 30 | max_length=max_length, 31 | min_length=1, 32 | top_p=0.9, 33 | repetition_penalty=1.5, 34 | length_penalty=1.0, 35 | temperature=1) 36 | generated_text = self.processors.batch_decode(outputs, skip_special_tokens=True)[0].strip() 37 | except Exception as e: 38 | generated_text = random.choice(['A','B','C','D','E']) 39 | print("Prompt_Error: ", e) 40 | 41 | 42 | return generated_text 43 | -------------------------------------------------------------------------------- /vlmeval/vlm/.ipynb_checkpoints/instructblip-checkpoint.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from PIL import Image 3 | from abc import abstractproperty 4 | import os.path as osp 5 | import random 6 | import os, sys 7 | from ..smp import * 8 | from transformers import InstructBlipProcessor, InstructBlipForConditionalGeneration 9 | 10 | 11 | class InstructBLIP: 12 | 13 | INSTALL_REQ = True 14 | 15 | def __init__(self, name): 16 | 17 | model = InstructBlipForConditionalGeneration.from_pretrained(name, torch_dtype=torch.float16, device_map='cuda').eval() 18 | processor = InstructBlipProcessor.from_pretrained(name) 19 | self.processors = processor 20 | self.model = model 21 | 22 | def generate(self, prompt, image_path, dataset=None, max_length=100): 23 | raw_image = Image.open(image_path).convert('RGB') 24 | inputs = self.processors(images=raw_image, text=prompt, return_tensors="pt").to("cuda") 25 | try: 26 | outputs = self.model.generate( 27 | **inputs, 28 | do_sample=False, 29 | num_beams=5, 30 | max_length=max_length, 31 | min_length=1, 32 | top_p=0.9, 33 | repetition_penalty=1.5, 34 | length_penalty=1.0, 35 | temperature=1) 36 | generated_text = self.processors.batch_decode(outputs, skip_special_tokens=True)[0].strip() 37 | except Exception as e: 38 | generated_text = random.choice(['A','B','C','D','E']) 39 | print("Prompt_Error: ", e) 40 | 41 | 42 | return generated_text 43 | -------------------------------------------------------------------------------- /vlmeval/utils/matching_util.py: -------------------------------------------------------------------------------- 1 | import string 2 | import copy as cp 3 | import os 4 | from ..smp import * 5 | 6 | def can_infer_option(answer, choices): 7 | verbose = os.environ.get('VERBOSE', 0) 8 | # Choices is a dictionary 9 | if 'Failed to obtain answer via API' in answer: 10 | return False 11 | 12 | reject_to_answer = [ 13 | "Sorry, I can't help with images of people yet.", 14 | "I can't process this file.", 15 | "I'm 
sorry, but without the image provided", 16 | "Cannot determine the answer" 17 | ] 18 | for err in reject_to_answer: 19 | if err in answer: 20 | return 'Z' 21 | 22 | def count_choice(splits, choices, prefix='', suffix=''): 23 | cnt = 0 24 | for c in choices: 25 | if prefix + c + suffix in splits: 26 | cnt += 1 27 | return cnt 28 | 29 | answer_mod = cp.copy(answer) 30 | chars = '.()[],:;!*#{}' 31 | for c in chars: 32 | answer_mod = answer_mod.replace(c, ' ') 33 | 34 | splits = [x.strip() for x in answer_mod.split()] 35 | count = count_choice(splits, choices) 36 | 37 | if count == 1: 38 | for ch in choices: 39 | if 'A' in splits and len(splits) > 3 and verbose: 40 | logger = get_logger('Evaluation') 41 | logger.info(f'A might be a quantifier in the string: {answer}.') 42 | return False 43 | if ch in splits: 44 | return ch 45 | elif count == 0 and count_choice(splits, {'Z', ''}) == 1: 46 | return 'Z' 47 | return False 48 | 49 | def can_infer_text(answer, choices): 50 | answer = answer.lower() 51 | assert isinstance(choices, dict) 52 | for k in choices: 53 | assert k in string.ascii_uppercase 54 | choices[k] = str(choices[k]).lower() 55 | cands = [] 56 | for k in choices: 57 | if choices[k] in answer: 58 | cands.append(k) 59 | if len(cands) == 1: 60 | return cands[0] 61 | return False 62 | 63 | def can_infer(answer, choices): 64 | answer = str(answer) 65 | copt = can_infer_option(answer, choices) 66 | return copt if copt else can_infer_text(answer, choices) -------------------------------------------------------------------------------- /vlmeval/utils/.ipynb_checkpoints/matching_util-checkpoint.py: -------------------------------------------------------------------------------- 1 | import string 2 | import copy as cp 3 | import os 4 | from ..smp import * 5 | 6 | def can_infer_option(answer, choices): 7 | verbose = os.environ.get('VERBOSE', 0) 8 | # Choices is a dictionary 9 | if 'Failed to obtain answer via API' in answer: 10 | return False 11 | 12 | reject_to_answer = [ 13 | "Sorry, I can't help with images of people yet.", 14 | "I can't process this file.", 15 | "I'm sorry, but without the image provided", 16 | "Cannot determine the answer" 17 | ] 18 | for err in reject_to_answer: 19 | if err in answer: 20 | return 'Z' 21 | 22 | def count_choice(splits, choices, prefix='', suffix=''): 23 | cnt = 0 24 | for c in choices: 25 | if prefix + c + suffix in splits: 26 | cnt += 1 27 | return cnt 28 | 29 | answer_mod = cp.copy(answer) 30 | chars = '.()[],:;!*#{}' 31 | for c in chars: 32 | answer_mod = answer_mod.replace(c, ' ') 33 | 34 | splits = [x.strip() for x in answer_mod.split()] 35 | count = count_choice(splits, choices) 36 | 37 | if count == 1: 38 | for ch in choices: 39 | if 'A' in splits and len(splits) > 3 and verbose: 40 | logger = get_logger('Evaluation') 41 | logger.info(f'A might be a quantifier in the string: {answer}.') 42 | return False 43 | if ch in splits: 44 | return ch 45 | elif count == 0 and count_choice(splits, {'Z', ''}) == 1: 46 | return 'Z' 47 | return False 48 | 49 | def can_infer_text(answer, choices): 50 | answer = answer.lower() 51 | assert isinstance(choices, dict) 52 | for k in choices: 53 | assert k in string.ascii_uppercase 54 | choices[k] = str(choices[k]).lower() 55 | cands = [] 56 | for k in choices: 57 | if choices[k] in answer: 58 | cands.append(k) 59 | if len(cands) == 1: 60 | return cands[0] 61 | return False 62 | 63 | def can_infer(answer, choices): 64 | answer = str(answer) 65 | copt = can_infer_option(answer, choices) 66 | return copt if copt else 
can_infer_text(answer, choices) -------------------------------------------------------------------------------- /vlmeval/api/base.py: -------------------------------------------------------------------------------- 1 | import time 2 | import random as rd 3 | from abc import abstractmethod 4 | from ..smp import get_logger 5 | 6 | class BaseAPI: 7 | 8 | def __init__(self, 9 | retry=10, 10 | wait=3, 11 | system_prompt=None, 12 | verbose=True, 13 | fail_msg='Failed to obtain answer via API.', 14 | **kwargs): 15 | self.wait = wait 16 | self.retry = retry 17 | self.system_prompt = system_prompt 18 | self.kwargs = kwargs 19 | self.verbose = verbose 20 | self.fail_msg = fail_msg 21 | self.logger = get_logger('ChatAPI') 22 | if len(kwargs): 23 | self.logger.info(f'BaseAPI received the following kwargs: {kwargs}') 24 | self.logger.info(f'Will try to use them as kwargs for `generate`. ') 25 | 26 | @abstractmethod 27 | def generate_inner(self, inputs, **kwargs): 28 | self.logger.warning(f'For APIBase, generate_inner is an abstract method. ') 29 | assert 0, 'generate_inner not defined' 30 | ret_code, answer, log = None, None, None 31 | # if ret_code is 0, means succeed 32 | return ret_code, answer, log 33 | 34 | def generate(self, inputs, **kwargs): 35 | input_type = None 36 | if isinstance(inputs, str): 37 | input_type = 'str' 38 | elif isinstance(inputs, list) and isinstance(inputs[0], str): 39 | input_type = 'strlist' 40 | elif isinstance(inputs, list) and isinstance(inputs[0], dict): 41 | input_type = 'dictlist' 42 | assert input_type is not None, input_type 43 | 44 | answer = None 45 | # a very small random delay [0s - 0.5s] 46 | T = rd.random() * 0.5 47 | time.sleep(T) 48 | 49 | for i in range(self.retry): 50 | try: 51 | ret_code, answer, log = self.generate_inner(inputs, **kwargs) 52 | if ret_code == 0 and self.fail_msg not in answer and answer != '': 53 | if self.verbose: 54 | print(answer) 55 | return answer 56 | elif self.verbose: 57 | self.logger.info(f"RetCode: {ret_code}\nAnswer: {answer}\nLog: {log}") 58 | except Exception as err: 59 | if self.verbose: 60 | self.logger.error(f'An error occured during try {i}:') 61 | self.logger.error(err) 62 | # delay before each retry 63 | T = rd.random() * self.wait * 2 64 | time.sleep(T) 65 | 66 | return self.fail_msg if answer in ['', None] else answer 67 | -------------------------------------------------------------------------------- /vlmeval/api/qwen_api.py: -------------------------------------------------------------------------------- 1 | from http import HTTPStatus 2 | import os 3 | from vlmeval.api.base import BaseAPI 4 | from vlmeval.smp import * 5 | 6 | class QwenAPI(BaseAPI): 7 | 8 | is_api: bool = True 9 | 10 | def __init__(self, 11 | model: str = 'qwen-max-1201', 12 | retry: int = 5, 13 | wait: int = 5, 14 | verbose: bool = True, 15 | seed: int = 2680, 16 | temperature: float = 0.0, 17 | system_prompt: str = None, 18 | key: str = None, 19 | max_tokens: int = 1024, 20 | proxy: str = None, 21 | **kwargs): 22 | 23 | assert model in ['qwen-turbo', 'qwen-plus', 'qwen-max', 'qwen-max-1201', 'qwen-max-longcontext'] 24 | self.model = model 25 | import dashscope 26 | self.fail_msg = 'Failed to obtain answer via API. 
' 27 | self.max_tokens = max_tokens 28 | self.temperature = temperature 29 | self.seed = seed 30 | if key is None: 31 | key = os.environ.get('DASHSCOPE_API_KEY', None) 32 | assert key is not None, "Please set the API Key (obtain it here: https://help.aliyun.com/zh/dashscope/developer-reference/vl-plus-quick-start)" 33 | dashscope.api_key = key 34 | if proxy is not None: 35 | proxy_set(proxy) 36 | super().__init__(wait=wait, retry=retry, system_prompt=system_prompt, verbose=verbose, **kwargs) 37 | 38 | @staticmethod 39 | def build_msgs(msgs_raw, system_prompt=None): 40 | msgs = cp.deepcopy(msgs_raw) 41 | ret = [] 42 | if system_prompt is not None: 43 | ret.append(dict(role='system', content=system_prompt)) 44 | for i, msg in enumerate(msgs): 45 | role = 'user' if i % 2 == 0 else 'assistant' 46 | ret.append(dict(role=role, content=msg)) 47 | return ret 48 | 49 | def generate_inner(self, inputs, **kwargs) -> str: 50 | from dashscope import MultiModalConversation 51 | assert isinstance(inputs, str) or isinstance(inputs, list) 52 | inputs = [inputs] if isinstance(inputs, str) else inputs 53 | messages = self.build_msgs(msgs_raw=inputs, system_prompt=self.system_prompt) 54 | 55 | import dashscope 56 | response = dashscope.Generation.call( 57 | model=self.model, 58 | messages=messages, 59 | seed=self.seed, 60 | temperature=self.temperature, 61 | max_tokens=self.max_tokens, 62 | result_format='message', # set the result to be "message" format. 63 | ) 64 | if response.status_code != HTTPStatus.OK: 65 | return -1, 'Error: Bad Response Statuse Code. ', f'The response status code is {response.status_code}. ' 66 | 67 | try: 68 | return 0, response['output']['choices'][0]['message']['content'].strip(), 'Succeeded! ' 69 | except Exception as err: 70 | return -1, f'Error: Failed to parse the response. {err}', response -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | 2 |
3 | 4 | # A Picture Is Worth a Graph: A Blueprint Debate Paradigm for Multimodal Reasoning 5 | 6 | [Changmeng Zheng](https://github.com/thecharm)<sup>1</sup>, [Dayong Liang](https://github.com/YongLD)<sup>2</sup>, [Wengyu Zhang](https://github.com/zhangwengyu999)<sup>1</sup>, [Xiao-Yong Wei](https://scholar.google.com/citations?user=8kxWTokAAAAJ&hl=en)<sup>*,1</sup>, [Tat-Seng Chua](https://scholar.google.com.sg/citations?user=Z9DWCBEAAAAJ&hl=en)<sup>3</sup>, [Qing Li](https://scholar.google.com/citations?user=D1LEg-YAAAAJ&hl=en)<sup>1</sup> 7 | 8 | 

<sup>1</sup>The Hong Kong Polytechnic University   <sup>2</sup>South China University of Technology   <sup>3</sup>National University of Singapore 9 |
*Corresponding author    10 |

11 | 12 | [![arXiv](https://img.shields.io/badge/Arxiv-2403.14972-AD1C18.svg?logo=arXiv)](https://arxiv.org/pdf/2403.14972) 13 | 14 | 
15 | 16 | ![figure1](assets/overview.png "BDoG") 17 | 18 | 19 | Blueprint Debate-on-Graph (BDoG) 20 | 21 | ## 🔥 News 22 | 23 | 🔥 __[2024.10]__ Our paper has been nominated for the Best Paper Award!\ 24 | 🔥 __[2024.07]__ The paper and code are released! 25 | 26 | ## 🚀 Method 27 | 28 | ![method](assets/Model1.png "method") 29 |
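The recommended entry point is the `run.py` launcher described in the QuickStart section below. To give a feel for how the pieces fit together, here is a minimal, illustrative Python sketch that uses only interfaces defined in this repository (`supported_VLM` from `vlmeval/config.py`, `InstructBLIP.generate` from `vlmeval/vlm/instructblip.py`, and `can_infer` from `vlmeval/utils/matching_util.py`); the image path, question, and options below are placeholders, and the sketch assumes the InstructBLIP weights path has already been set in `config.py`.

```python
# Illustrative sketch only -- the supported evaluation pipeline is run.py (see QuickStart below).
# Assumes vlmeval/config.py points 'instructblip_13b' at a local instructblip-vicuna-13b
# checkpoint and that 'example.jpg' is a placeholder image on disk.
from vlmeval.config import supported_VLM
from vlmeval.utils import can_infer

model = supported_VLM['instructblip_13b']()  # instantiates InstructBLIP from the configured path

choices = {'A': 'a solid', 'B': 'a liquid', 'C': 'a gas'}  # placeholder multiple-choice options
prompt = (
    'Question: Is water vapor a solid, a liquid, or a gas?\n'
    'Options: A. a solid B. a liquid C. a gas\n'
    "Answer with the option's letter from the given choices directly."
)

answer = model.generate(prompt, 'example.jpg')   # InstructBLIP.generate(prompt, image_path)
print(answer, '->', can_infer(answer, choices))  # map the free-form answer back to A/B/C (or False)
```

In practice, `run.py` wires these same components to the TSV datasets, the multi-round debate loop, and the evaluators.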
30 | 31 | ## 🏗️ QuickStart 32 | ### 1. Installation 33 | ```bash 34 | git clone https://github.com/thecharm/BDoG.git 35 | cd BDoG 36 | pip install -e . 37 | ``` 38 | ### 2. Download model weights 39 | Download the [model weights](https://huggingface.co/Salesforce/instructblip-vicuna-13b) and set the model path in the `BDoG/vlmeval/config.py` file (the `instructblip_13b` entry in the `models` dict). 40 | 41 | 42 | ### 3. Running 43 | ```bash 44 | torchrun --nproc_per_node=1 run.py --data ScienceQA_TEST \ 45 | --stage BDebate \ 46 | --debate 2 47 | ``` 48 | + `--data` 49 | + Supported datasets: `ScienceQA_TEST` and `MMBench_DEV_EN`. 50 | + `--stage` 51 | + Prompt type: `BDebate` (Blueprint Debate-on-Graph) or `ODebate` (debate without a graph). 52 | + `--debate` 53 | + Number of rounds for the debate. 54 | + `--kg_init` 55 | + (optional) Use a Gemini-generated graph as the initialization for the multi-round debate. 56 | + `--nproc_per_node=2` 57 | + (optional) Speed up the inference process if you have two GPUs. 58 | + `--openai` 59 | + (optional) Use an OpenAI API key to perform the final result validation. 60 | 61 | The results are saved in the `BDoG/results/instructblip_13b` folder. 62 | 63 | During this process, the datasets will be automatically downloaded to the `/root/LMUData/` directory. If you need to change the data storage path, please set `--lmudata` accordingly. 64 | 65 | ## ❤️ Acknowledgments 66 | - [VLMEvalKit](https://github.com/open-compass/VLMEvalKit): An open-source evaluation toolkit for large vision-language models (LVLMs). 67 | - [LLaVA](https://github.com/haotian-liu/LLaVA): A wonderful MLLM based on the Large Language and Vision Assistant. 68 | - [LAVIS](https://github.com/salesforce/LAVIS): An amazing open-source multimodality learning codebase. 69 | 70 | 71 | ## 📑 Citation 72 | 73 | If this repo is useful to you, please cite it using the following BibTeX. 74 | ```bibtex 75 | @inproceedings{zheng2024picture, 76 | title={A Picture Is Worth a Graph: A Blueprint Debate Paradigm for Multimodal Reasoning}, 77 | author={Zheng, Changmeng and Liang, Dayong and Zhang, Wengyu and Wei, Xiao-Yong and Chua, Tat-Seng and Li, Qing}, 78 | booktitle={Proceedings of the 32nd ACM International Conference on Multimedia}, 79 | pages={419--428}, 80 | year={2024} 81 | } 82 | ``` 83 |
-------------------------------------------------------------------------------- /vlmeval/api/qwen_vl_api.py: -------------------------------------------------------------------------------- 1 | from vlmeval.smp import * 2 | from vlmeval.api.base import BaseAPI 3 | 4 | class QwenVLWrapper(BaseAPI): 5 | 6 | is_api: bool = True 7 | 8 | def __init__(self, 9 | model: str = 'qwen-vl-plus', 10 | retry: int = 5, 11 | wait: int = 5, 12 | key: str = None, 13 | verbose: bool = True, 14 | temperature: float = 0.0, 15 | system_prompt: str = None, 16 | max_tokens: int = 1024, 17 | proxy: str = None, 18 | **kwargs): 19 | 20 | assert model in ['qwen-vl-plus', 'qwen-vl-max'] 21 | self.model = model 22 | import dashscope 23 | self.fail_msg = 'Failed to obtain answer via API. 
' 24 | self.max_tokens = max_tokens 25 | self.temperature = temperature 26 | if key is None: 27 | key = os.environ.get('DASHSCOPE_API_KEY', None) 28 | assert key is not None, "Please set the API Key (obtain it here: https://help.aliyun.com/zh/dashscope/developer-reference/vl-plus-quick-start)" 29 | dashscope.api_key = key 30 | if proxy is not None: 31 | proxy_set(proxy) 32 | super().__init__(wait=wait, retry=retry, system_prompt=system_prompt, verbose=verbose, **kwargs) 33 | 34 | @staticmethod 35 | def build_msgs(msgs_raw, system_prompt=None): 36 | msgs = cp.deepcopy(msgs_raw) 37 | ret = [] 38 | if system_prompt is not None: 39 | content = list(dict(text=system_prompt)) 40 | ret.append(dict(role='system', content=content)) 41 | content = [] 42 | for i, msg in enumerate(msgs): 43 | if osp.exists(msg): 44 | content.append(dict(image='file://' + msg)) 45 | elif msg.startswith('http'): 46 | content.append(dict(image=msg)) 47 | else: 48 | content.append(dict(text=msg)) 49 | ret.append(dict(role='user', content=content)) 50 | return ret 51 | 52 | def generate_inner(self, inputs, **kwargs) -> str: 53 | from dashscope import MultiModalConversation 54 | assert isinstance(inputs, str) or isinstance(inputs, list) 55 | pure_text = True 56 | if isinstance(inputs, list): 57 | for pth in inputs: 58 | if osp.exists(pth) or pth.startswith('http'): 59 | pure_text = False 60 | assert not pure_text 61 | messages = self.build_msgs(msgs_raw=inputs, system_prompt=self.system_prompt) 62 | gen_config = dict(max_output_tokens=self.max_tokens, temperature=self.temperature) 63 | gen_config.update(self.kwargs) 64 | try: 65 | response = MultiModalConversation.call(model=self.model, messages=messages) 66 | if self.verbose: 67 | print(response) 68 | answer = response.output.choices[0]['message']['content'][0]['text'] 69 | return 0, answer, 'Succeeded! ' 70 | except Exception as err: 71 | if self.verbose: 72 | self.logger.error(err) 73 | self.logger.error(f"The input messages are {inputs}.") 74 | 75 | return -1, '', '' 76 | 77 | class QwenVLAPI(QwenVLWrapper): 78 | 79 | def generate(self, image_path, prompt, dataset=None): 80 | return super(QwenVLAPI, self).generate([image_path, prompt]) 81 | 82 | def interleave_generate(self, ti_list, dataset=None): 83 | return super(QwenVLAPI, self).generate(ti_list) 84 | -------------------------------------------------------------------------------- /vlmeval/api/gemini.py: -------------------------------------------------------------------------------- 1 | from vlmeval.smp import * 2 | from vlmeval.api.base import BaseAPI 3 | 4 | headers = 'Content-Type: application/json' 5 | 6 | class GeminiWrapper(BaseAPI): 7 | 8 | is_api: bool = True 9 | 10 | def __init__(self, 11 | retry: int = 5, 12 | wait: int = 5, 13 | key: str = None, 14 | verbose: bool = True, 15 | temperature: float = 0.0, 16 | system_prompt: str = None, 17 | max_tokens: int = 1024, 18 | proxy: str = None, 19 | **kwargs): 20 | 21 | self.fail_msg = 'Failed to obtain answer via API. 
' 22 | self.max_tokens = max_tokens 23 | self.temperature = temperature 24 | if key is None: 25 | key = os.environ.get('GOOGLE_API_KEY', None) 26 | assert key is not None 27 | self.api_key = key 28 | if proxy is not None: 29 | proxy_set(proxy) 30 | super().__init__(wait=wait, retry=retry, system_prompt=system_prompt, verbose=verbose, **kwargs) 31 | 32 | @staticmethod 33 | def build_msgs(msgs_raw, system_prompt=None): 34 | msgs = cp.deepcopy(msgs_raw) 35 | assert len(msgs) % 2 == 1 36 | 37 | if system_prompt is not None: 38 | msgs[0] = [system_prompt, msgs[0]] 39 | ret = [] 40 | for i, msg in enumerate(msgs): 41 | role = 'user' if i % 2 == 0 else 'model' 42 | parts = msg if isinstance(msg, list) else [msg] 43 | ret.append(dict(role=role, parts=parts)) 44 | return ret 45 | 46 | def generate_inner(self, inputs, **kwargs) -> str: 47 | import google.generativeai as genai 48 | assert isinstance(inputs, str) or isinstance(inputs, list) 49 | pure_text = True 50 | if isinstance(inputs, list): 51 | for pth in inputs: 52 | if osp.exists(pth) or pth.startswith('http'): 53 | pure_text = False 54 | genai.configure(api_key=self.api_key) 55 | model = genai.GenerativeModel('gemini-pro') if pure_text else genai.GenerativeModel('gemini-pro-vision') 56 | if isinstance(inputs, str): 57 | messages = [inputs] if self.system_prompt is None else [self.system_prompt, inputs] 58 | elif pure_text: 59 | messages = self.build_msgs(inputs, self.system_prompt) 60 | else: 61 | messages = [] if self.system_prompt is None else [self.system_prompt] 62 | for s in inputs: 63 | if osp.exists(s): 64 | messages.append(Image.open(s)) 65 | elif s.startswith('http'): 66 | pth = download_file(s) 67 | messages.append(Image.open(pth).copy())  # .copy() forces the pixel data to load so the temp file can be deleted 68 | os.remove(pth)  # shutil has no remove(); use os.remove to clean up the downloaded file 69 | else: 70 | messages.append(s) 71 | gen_config = dict(max_output_tokens=self.max_tokens, temperature=self.temperature) 72 | gen_config.update(self.kwargs) 73 | try: 74 | answer = model.generate_content(messages, generation_config=genai.types.GenerationConfig(**gen_config)).text 75 | return 0, answer, 'Succeeded! 
' 76 | except Exception as err: 77 | if self.verbose: 78 | self.logger.error(err) 79 | self.logger.error(f"The input messages are {inputs}.") 80 | 81 | return -1, '', '' 82 | 83 | 84 | 85 | class GeminiProVision(GeminiWrapper): 86 | 87 | def generate(self, image_path, prompt, dataset=None): 88 | return super(GeminiProVision, self).generate([image_path, prompt]) 89 | 90 | def interleave_generate(self, ti_list, dataset=None): 91 | return super(GeminiProVision, self).generate(ti_list) -------------------------------------------------------------------------------- /run.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.distributed as dist 3 | from vlmeval.smp import * 4 | from vlmeval.evaluate import multiple_choice_eval 5 | from vlmeval.inference_multi import infer_data_job, prefetch_acc 6 | from vlmeval.config import supported_VLM 7 | from vlmeval.utils import dataset_URLs, abbr2full 8 | import json 9 | 10 | def parse_args(): 11 | parser = argparse.ArgumentParser() 12 | parser.add_argument('--data', type=str, nargs='+', required=True) 13 | parser.add_argument("--model", type=str, nargs='+', default=['instructblip_13b'], required=False)  # with nargs='+' the default must be a list; a bare string would be iterated character by character in main() 14 | parser.add_argument("--lmudata", type=str, default='', required=False) 15 | parser.add_argument("--openai", type=str, default='', required=False) 16 | parser.add_argument("--stage", type=str, default='BDebate', required=True) 17 | parser.add_argument("--nproc", type=int, default=4, help="Parallel API calling") 18 | parser.add_argument("--debate", type=int, default=2, required=True) 19 | parser.add_argument("--ignore", action='store_true', help="Ignore failed indices. ") 20 | parser.add_argument("--verbose", action='store_true') 21 | parser.add_argument("--prefetch", action='store_true') 22 | parser.add_argument("--kg_init", action='store_true') 23 | args = parser.parse_args() 24 | return args 25 | 26 | def main(): 27 | args = parse_args() 28 | assert len(args.data), "--data should be a list of data files" 29 | init_environ(args) 30 | 31 | rank, world_size = get_rank_and_world_size() 32 | if world_size > 1: 33 | torch.cuda.set_device(rank) 34 | dist.init_process_group(backend='nccl', timeout=datetime.timedelta(seconds=5400)) 35 | 36 | for _, model_name in enumerate(args.model): 37 | model = None 38 | pred_root = model_name 39 | 40 | for i, dataset_name in enumerate(args.data): 41 | os.makedirs(f'results/{model_name}/{dataset_name}', exist_ok=True) 42 | 43 | logger_file = "results/{}/{}/{}_{}_{}_log.txt".format(args.model[0], args.data[0], args.model[0],args.data[0], args.stage) 44 | logger = get_logger(name='Multi', log_file=logger_file) 45 | logger.info(f"####- Begin -####\n{args.model[0]}: {args.data[0]}-{args.stage}") 46 | 47 | if dataset_name not in dataset_URLs: 48 | dataset_name = abbr2full(dataset_name) 49 | 50 | if dataset_name not in dataset_URLs: 51 | logger.error(f'Unknown dataset: {dataset_name}. 
') 52 | continue 53 | 54 | result_file = f'results/{pred_root}/{dataset_name}/{model_name}_{dataset_name}_{args.stage}_DB{args.debate}.xlsx' 55 | 56 | if model is None: 57 | model = model_name # which is only a name 58 | 59 | model = infer_data_job(model, model_name=model_name, dataset_name=dataset_name, args=args, logger=logger, ignore_failed=args.ignore) 60 | 61 | if rank == 0: 62 | time.sleep(3) 63 | res = None 64 | if listinstr(['MMBench'], dataset_name): 65 | res = prefetch_acc(result_file) 66 | else: 67 | logger.warning(f'{dataset_name} is not handled by prefetch score calculator') 68 | if res is not None: 69 | logger.info(f'{model_name} prefetching: ') 70 | logger.info(res) 71 | dump(res, result_file.replace('.xlsx', '_prefetch.xlsx')) 72 | 73 | if listinstr(['MMBench','ScienceQA'], dataset_name): 74 | multiple_choice_eval(result_file, dataset=dataset_name, model='chatgpt-0613', nproc=args.nproc, verbose=args.verbose) 75 | else: 76 | logger.error(f'Dataset {dataset_name} is not handled by evaluator, will be skipped. ') 77 | 78 | if __name__ == '__main__': 79 | main() 80 | -------------------------------------------------------------------------------- /vlmeval/api/gpt_int.py: -------------------------------------------------------------------------------- 1 | import json 2 | import warnings 3 | import requests 4 | from ..smp import * 5 | from .gpt import GPT_context_window, OpenAIWrapper 6 | 7 | 8 | url = "http://ecs.sv.us.alles-apin.openxlab.org.cn/v1/openai/v2/text/chat" 9 | headers = { 10 | "Content-Type": "application/json" 11 | } 12 | 13 | class OpenAIWrapperInternal(OpenAIWrapper): 14 | 15 | is_api: bool = True 16 | 17 | def __init__(self, 18 | model: str = 'gpt-3.5-turbo-0613', 19 | retry: int = 5, 20 | wait: int = 3, 21 | verbose: bool = True, 22 | system_prompt: str = None, 23 | temperature: float = 0, 24 | timeout: int = 60, 25 | max_tokens: int = 1024, 26 | img_size: int = 512, 27 | img_detail: str = 'low', 28 | **kwargs): 29 | 30 | self.model = model 31 | if 'KEYS' in os.environ and osp.exists(os.environ['KEYS']): 32 | keys = load(os.environ['KEYS']) 33 | headers['alles-apin-token'] = keys.get('alles-apin-token', '') 34 | elif 'ALLES' in os.environ: 35 | headers['alles-apin-token'] = os.environ['ALLES'] 36 | self.headers = headers 37 | self.temperature = temperature 38 | self.timeout = timeout 39 | self.max_tokens = max_tokens 40 | 41 | assert img_size > 0 or img_size == -1 42 | self.img_size = img_size 43 | assert img_detail in ['high', 'low'] 44 | self.img_detail = img_detail 45 | 46 | self.vision = False 47 | if model == 'gpt-4-vision-preview': 48 | self.vision = True 49 | 50 | super(OpenAIWrapper, self).__init__( 51 | wait=wait, retry=retry, system_prompt=system_prompt, verbose=verbose, **kwargs) 52 | 53 | def generate_inner(self, inputs, **kwargs) -> str: 54 | input_msgs = self.prepare_inputs(inputs) 55 | 56 | temperature = kwargs.pop('temperature', self.temperature) 57 | max_tokens = kwargs.pop('max_tokens', self.max_tokens) 58 | 59 | # Held out 100 tokens as buffer 60 | context_window = GPT_context_window(self.model) 61 | max_tokens = min(max_tokens, context_window - self.get_token_len(inputs)) 62 | if 0 < max_tokens <= 100: 63 | print('Less than 100 tokens left, may exceed the context window with some additional meta symbols. ') 64 | if max_tokens <= 0: 65 | return 0, self.fail_msg + 'Input string longer than context window. ', 'Length Exceeded. 
' 66 | 67 | payload = dict( 68 | model=self.model, 69 | messages=input_msgs, 70 | max_tokens=max_tokens, 71 | n=1, 72 | stop=None, 73 | timeout=self.timeout, 74 | temperature=temperature, 75 | **kwargs) 76 | 77 | response = requests.post(url, headers=headers, data=json.dumps(payload), timeout=self.timeout * 1.1) 78 | ret_code = response.status_code 79 | ret_code = 0 if (200 <= int(ret_code) < 300) else ret_code 80 | 81 | answer = self.fail_msg 82 | try: 83 | resp_struct = json.loads(response.text) 84 | assert resp_struct['msg'] == 'ok' and resp_struct['msgCode'] == '10000', resp_struct 85 | answer = resp_struct['data']['choices'][0]['message']['content'].strip() 86 | except: 87 | pass 88 | return ret_code, answer, response 89 | 90 | 91 | class GPT4V_Internal(OpenAIWrapperInternal): 92 | 93 | def generate(self, image_path, prompt, dataset=None): 94 | assert self.model == 'gpt-4-vision-preview' 95 | return super(GPT4V_Internal, self).generate([image_path, prompt]) 96 | 97 | def interleave_generate(self, ti_list, dataset=None): 98 | assert self.model == 'gpt-4-vision-preview' 99 | return super(GPT4V_Internal, self).generate(ti_list) -------------------------------------------------------------------------------- /vlmeval/utils/debate.py: -------------------------------------------------------------------------------- 1 | from vlmeval.utils import init_prompt_multi 2 | 3 | def Debate_VLM(stage, model, struct, dataset_name, debate, kg_init, logger): 4 | if stage[:8] == "baseline": 5 | prompt_format = "IQ-A" 6 | prompt_G = init_prompt_multi(struct, prompt_format) 7 | response = model.generate(prompt=prompt_G, image_path=struct['image'], dataset=dataset_name) 8 | logger.info("########--G_A--######\nPrompt: {}\nGT: {} - ANS: {}".format(prompt_G, struct['text']['answer'], response)) 9 | 10 | elif stage[:7] == "ODebate": 11 | for debate_ in range(debate): 12 | logger.info("########--DEBATE{}--######".format(debate_)) 13 | 14 | prompt_format = "ODIM-S" if debate_==0 else "ODQIM-S" 15 | prompt_A = init_prompt_multi(struct, prompt_format) 16 | kg_aff = model.generate(prompt=prompt_A, image_path=struct['image'], dataset=dataset_name) 17 | 18 | prompt_format = "ONIM-S" if debate_==0 else "ONQIM-S" 19 | prompt_N = init_prompt_multi(struct, prompt_format) 20 | kg_neg = model.generate(prompt=prompt_N, image_path=struct['image'], dataset=dataset_name) 21 | 22 | struct['kg'] = [kg_aff.strip(), kg_neg.strip()] 23 | prompt_format = "OAGM-A" 24 | prompt_F = init_prompt_multi(struct, prompt_format) 25 | response = model.generate(prompt=prompt_F, image_path=struct['image'], dataset=dataset_name, max_length=20) 26 | 27 | logger.info("########--ANSWER-{}--######\n{}".format(debate_, response)) 28 | 29 | logger.info("\nGT:{}-ANS: {} - ".format(struct['text']['answer'], response)) 30 | 31 | elif stage[:7] == "BDebate": 32 | for debate_ in range(debate): 33 | logger.info("########--DEBATE{}--######".format(debate_)) 34 | if debate_ == 0: 35 | if kg_init: 36 | if struct['text']['kg'] != 'none': 37 | logger.info("#####---KG_IB---#####\n{}".format(struct['text']['kg'])) 38 | struct['kg'] = [struct['text']['kg'], struct['text']['kg']] 39 | else: 40 | prompt_format = "GKG-G" 41 | prompt_G = init_prompt_multi(struct, prompt_format) 42 | kg_base = model.generate(prompt=prompt_G, image_path=struct['image'], dataset=dataset_name) 43 | struct['kg'] = [kg_base, kg_base] 44 | logger.info("#####---KG_P---#####\n{}".format(prompt_G)) 45 | logger.info("#####---KG_B---#####\n{}".format(kg_base)) 46 | else: 47 | prompt_format = "GKG-G" 
48 | prompt_G = init_prompt_multi(struct, prompt_format) 49 | kg_base = model.generate(prompt=prompt_G, image_path=struct['image'], dataset=dataset_name) 50 | struct['kg'] = [kg_base, kg_base] 51 | logger.info("#####---KG_P---#####\n{}".format(prompt_G)) 52 | logger.info("#####---KG_B---#####\n{}".format(kg_base)) 53 | 54 | prompt_format = "KDQIM-G" 55 | prompt_A = init_prompt_multi(struct, prompt_format) 56 | kg_aff = model.generate(prompt=prompt_A, image_path=struct['image'], dataset=dataset_name) 57 | struct['kg'] = [kg_aff,struct['kg'][1]] 58 | 59 | prompt_format = "KNQIM-G" 60 | prompt_N = init_prompt_multi(struct, prompt_format) 61 | kg_neg = model.generate(prompt=prompt_N, image_path=struct['image'], dataset=dataset_name) 62 | 63 | struct['kg'] = [kg_aff, kg_neg] 64 | prompt_format = "KAGM-A" 65 | prompt_F = init_prompt_multi(struct, prompt_format) 66 | response = model.generate(prompt=prompt_F, image_path=struct['image'], dataset=dataset_name) 67 | 68 | logger.info("########--ANSWER-{}--######\n{}".format(debate_, response)) 69 | 70 | logger.info("\nGT:{}-ANS: {} - ".format(struct['text']['answer'], response)) 71 | 72 | else: 73 | assert stage == "BDebate", f"Please confirm if your 'stage' is set correctly.\nDebate Only: Begin with ODebate\nDebate with Blueprint:Begin with BDebate\nStage setting Now:{stage}" 74 | return response 75 | -------------------------------------------------------------------------------- /vlmeval/utils/.ipynb_checkpoints/debate-checkpoint.py: -------------------------------------------------------------------------------- 1 | from vlmeval.utils import init_prompt_multi 2 | 3 | def Debate_VLM(stage, model, struct, dataset_name, debate, kg_init, logger): 4 | if stage[:8] == "baseline": 5 | prompt_format = "IQ-A" 6 | prompt_G = init_prompt_multi(struct, prompt_format) 7 | response = model.generate(prompt=prompt_G, image_path=struct['image'], dataset=dataset_name) 8 | logger.info("########--G_A--######\nPrompt: {}\nGT: {} - ANS: {}".format(prompt_G, struct['text']['answer'], response)) 9 | 10 | elif stage[:7] == "ODebate": 11 | for debate_ in range(debate): 12 | logger.info("########--DEBATE{}--######".format(debate_)) 13 | 14 | prompt_format = "ODIM-S" if debate_==0 else "ODQIM-S" 15 | prompt_A = init_prompt_multi(struct, prompt_format) 16 | kg_aff = model.generate(prompt=prompt_A, image_path=struct['image'], dataset=dataset_name) 17 | 18 | prompt_format = "ONIM-S" if debate_==0 else "ONQIM-S" 19 | prompt_N = init_prompt_multi(struct, prompt_format) 20 | kg_neg = model.generate(prompt=prompt_N, image_path=struct['image'], dataset=dataset_name) 21 | 22 | struct['kg'] = [kg_aff.strip(), kg_neg.strip()] 23 | prompt_format = "OAGM-A" 24 | prompt_F = init_prompt_multi(struct, prompt_format) 25 | response = model.generate(prompt=prompt_F, image_path=struct['image'], dataset=dataset_name, max_length=20) 26 | 27 | logger.info("########--ANSWER-{}--######\n{}".format(debate_, response)) 28 | 29 | logger.info("\nGT:{}-ANS: {} - ".format(struct['text']['answer'], response)) 30 | 31 | elif stage[:7] == "BDebate": 32 | for debate_ in range(debate): 33 | logger.info("########--DEBATE{}--######".format(debate_)) 34 | if debate_ == 0: 35 | if kg_init: 36 | if struct['text']['kg'] != 'none': 37 | logger.info("#####---KG_IB---#####\n{}".format(struct['text']['kg'])) 38 | struct['kg'] = [struct['text']['kg'], struct['text']['kg']] 39 | else: 40 | prompt_format = "GKG-G" 41 | prompt_G = init_prompt_multi(struct, prompt_format) 42 | kg_base = model.generate(prompt=prompt_G, 
image_path=struct['image'], dataset=dataset_name) 43 | struct['kg'] = [kg_base, kg_base] 44 | logger.info("#####---KG_P---#####\n{}".format(prompt_G)) 45 | logger.info("#####---KG_B---#####\n{}".format(kg_base)) 46 | else: 47 | prompt_format = "GKG-G" 48 | prompt_G = init_prompt_multi(struct, prompt_format) 49 | kg_base = model.generate(prompt=prompt_G, image_path=struct['image'], dataset=dataset_name) 50 | struct['kg'] = [kg_base, kg_base] 51 | logger.info("#####---KG_P---#####\n{}".format(prompt_G)) 52 | logger.info("#####---KG_B---#####\n{}".format(kg_base)) 53 | 54 | prompt_format = "KDQIM-G" 55 | prompt_A = init_prompt_multi(struct, prompt_format) 56 | kg_aff = model.generate(prompt=prompt_A, image_path=struct['image'], dataset=dataset_name) 57 | struct['kg'] = [kg_aff,struct['kg'][1]] 58 | 59 | prompt_format = "KNQIM-G" 60 | prompt_N = init_prompt_multi(struct, prompt_format) 61 | kg_neg = model.generate(prompt=prompt_N, image_path=struct['image'], dataset=dataset_name) 62 | 63 | struct['kg'] = [kg_aff, kg_neg] 64 | prompt_format = "KAGM-A" 65 | prompt_F = init_prompt_multi(struct, prompt_format) 66 | response = model.generate(prompt=prompt_F, image_path=struct['image'], dataset=dataset_name) 67 | 68 | logger.info("########--ANSWER-{}--######\n{}".format(debate_, response)) 69 | 70 | logger.info("\nGT:{}-ANS: {} - ".format(struct['text']['answer'], response)) 71 | 72 | else: 73 | assert stage == "BDebate", f"Please confirm if your 'stage' is set correctly.\nDebate Only: Begin with ODebate\nDebate with Blueprint:Begin with BDebate\nStage setting Now:{stage}" 74 | return response 75 | -------------------------------------------------------------------------------- /vlmeval/smp/misc.py: -------------------------------------------------------------------------------- 1 | # flake8: noqa: F401, F403 2 | import abc 3 | import argparse 4 | import csv 5 | import multiprocessing as mp 6 | import os 7 | import os.path as osp 8 | import copy as cp 9 | import random as rd 10 | import requests 11 | import shutil 12 | import subprocess 13 | import warnings 14 | import pandas as pd 15 | from collections import OrderedDict, defaultdict 16 | from multiprocessing import Pool, current_process 17 | from tqdm import tqdm 18 | import datetime 19 | import matplotlib.pyplot as plt 20 | import seaborn as sns 21 | from tabulate import tabulate_formats, tabulate 22 | from huggingface_hub import scan_cache_dir 23 | from sty import fg, bg, ef, rs 24 | 25 | def process_punctuation(inText): 26 | import re 27 | outText = inText 28 | punct = [ 29 | ';', r'/', '[', ']', '"', '{', '}', '(', ')', '=', '+', '\\', '_', '-', 30 | '>', '<', '@', '`', ',', '?', '!' 
31 | ] 32 | commaStrip = re.compile('(\d)(,)(\d)') # noqa: W605 33 | periodStrip = re.compile('(?!<=\d)(\.)(?!\d)') # noqa: W605 34 | for p in punct: 35 | if (p + ' ' in inText or ' ' + p in inText) or (re.search( 36 | commaStrip, inText) is not None): 37 | outText = outText.replace(p, '') 38 | else: 39 | outText = outText.replace(p, ' ') 40 | outText = periodStrip.sub('', outText, re.UNICODE) 41 | return outText 42 | 43 | def h2r(value): 44 | if value[0] == '#': 45 | value = value[1:] 46 | assert len(value) == 6 47 | return tuple(int(value[i:i + 2], 16) for i in range(0, 6, 2)) 48 | 49 | def r2h(rgb): 50 | return '#%02x%02x%02x' % rgb 51 | 52 | def colored(s, color): 53 | if isinstance(color, str): 54 | if hasattr(fg, color): 55 | return getattr(fg, color) + s + fg.rs 56 | color = h2r(color) 57 | return fg(*color) + s + fg.rs 58 | 59 | def istype(s, type): 60 | if isinstance(s, type): 61 | return True 62 | try: 63 | return isinstance(eval(s), type) 64 | except Exception as _: 65 | return False 66 | 67 | def bincount(lst): 68 | bins = defaultdict(lambda: 0) 69 | for item in lst: 70 | bins[item] += 1 71 | return bins 72 | 73 | def get_cache_path(repo_id): 74 | hf_cache_info = scan_cache_dir() 75 | repos = list(hf_cache_info.repos) 76 | repo = None 77 | for r in repos: 78 | if r.repo_id == repo_id: 79 | repo = r 80 | break 81 | if repo is None: 82 | return None 83 | revs = list(repo.revisions) 84 | rev2keep, last_modified = None, 0 85 | for rev in revs: 86 | if rev.last_modified > last_modified: 87 | rev2keep, last_modified = rev, rev.last_modified 88 | if rev2keep is None: 89 | return None 90 | return str(rev2keep.snapshot_path) 91 | 92 | def proxy_set(s): 93 | import os 94 | for key in ['http_proxy', 'HTTP_PROXY', 'https_proxy', 'HTTPS_PROXY']: 95 | os.environ[key] = s 96 | 97 | def get_rank_and_world_size(): 98 | local_rank = int(os.environ.get("LOCAL_RANK", 0)) 99 | world_size = int(os.environ.get("WORLD_SIZE", 1)) 100 | return local_rank, world_size 101 | 102 | def splitlen(s, sym='/'): 103 | return len(s.split(sym)) 104 | 105 | def listinstr(lst, s): 106 | assert isinstance(lst, list) 107 | for item in lst: 108 | if item in s: 109 | return True 110 | return False 111 | 112 | def d2df(D): 113 | return pd.DataFrame({x: [D[x]] for x in D}) 114 | 115 | def cn_string(s): 116 | import re 117 | if re.search(u'[\u4e00-\u9fff]', s): 118 | return True 119 | return False 120 | 121 | try: 122 | import decord 123 | except ImportError: 124 | pass 125 | 126 | def timestr(second=True, minute=False): 127 | s = datetime.datetime.now().strftime('%Y%m%d%H%M%S')[2:] 128 | if second: 129 | return s 130 | elif minute: 131 | return s[:-2] 132 | else: 133 | return s[:-4] 134 | 135 | def dict_merge(dct, merge_dct): 136 | for k, _ in merge_dct.items(): 137 | if (k in dct and isinstance(dct[k], dict) and isinstance(merge_dct[k], dict)): #noqa 138 | dict_merge(dct[k], merge_dct[k]) 139 | else: 140 | dct[k] = merge_dct[k] 141 | 142 | def youtube_dl(idx): 143 | cmd = f'youtube-dl -f best -f mp4 "{idx}" -o {idx}.mp4' 144 | os.system(cmd) 145 | 146 | def run_command(cmd): 147 | if isinstance(cmd, str): 148 | cmd = cmd.split() 149 | return subprocess.check_output(cmd) 150 | -------------------------------------------------------------------------------- /vlmeval/smp/.ipynb_checkpoints/misc-checkpoint.py: -------------------------------------------------------------------------------- 1 | # flake8: noqa: F401, F403 2 | import abc 3 | import argparse 4 | import csv 5 | import multiprocessing as mp 6 | import os 7 | 
import os.path as osp 8 | import copy as cp 9 | import random as rd 10 | import requests 11 | import shutil 12 | import subprocess 13 | import warnings 14 | import pandas as pd 15 | from collections import OrderedDict, defaultdict 16 | from multiprocessing import Pool, current_process 17 | from tqdm import tqdm 18 | import datetime 19 | import matplotlib.pyplot as plt 20 | import seaborn as sns 21 | from tabulate import tabulate_formats, tabulate 22 | from huggingface_hub import scan_cache_dir 23 | from sty import fg, bg, ef, rs 24 | 25 | def process_punctuation(inText): 26 | import re 27 | outText = inText 28 | punct = [ 29 | ';', r'/', '[', ']', '"', '{', '}', '(', ')', '=', '+', '\\', '_', '-', 30 | '>', '<', '@', '`', ',', '?', '!' 31 | ] 32 | commaStrip = re.compile('(\d)(,)(\d)') # noqa: W605 33 | periodStrip = re.compile('(?!<=\d)(\.)(?!\d)') # noqa: W605 34 | for p in punct: 35 | if (p + ' ' in inText or ' ' + p in inText) or (re.search( 36 | commaStrip, inText) is not None): 37 | outText = outText.replace(p, '') 38 | else: 39 | outText = outText.replace(p, ' ') 40 | outText = periodStrip.sub('', outText, re.UNICODE) 41 | return outText 42 | 43 | def h2r(value): 44 | if value[0] == '#': 45 | value = value[1:] 46 | assert len(value) == 6 47 | return tuple(int(value[i:i + 2], 16) for i in range(0, 6, 2)) 48 | 49 | def r2h(rgb): 50 | return '#%02x%02x%02x' % rgb 51 | 52 | def colored(s, color): 53 | if isinstance(color, str): 54 | if hasattr(fg, color): 55 | return getattr(fg, color) + s + fg.rs 56 | color = h2r(color) 57 | return fg(*color) + s + fg.rs 58 | 59 | def istype(s, type): 60 | if isinstance(s, type): 61 | return True 62 | try: 63 | return isinstance(eval(s), type) 64 | except Exception as _: 65 | return False 66 | 67 | def bincount(lst): 68 | bins = defaultdict(lambda: 0) 69 | for item in lst: 70 | bins[item] += 1 71 | return bins 72 | 73 | def get_cache_path(repo_id): 74 | hf_cache_info = scan_cache_dir() 75 | repos = list(hf_cache_info.repos) 76 | repo = None 77 | for r in repos: 78 | if r.repo_id == repo_id: 79 | repo = r 80 | break 81 | if repo is None: 82 | return None 83 | revs = list(repo.revisions) 84 | rev2keep, last_modified = None, 0 85 | for rev in revs: 86 | if rev.last_modified > last_modified: 87 | rev2keep, last_modified = rev, rev.last_modified 88 | if rev2keep is None: 89 | return None 90 | return str(rev2keep.snapshot_path) 91 | 92 | def proxy_set(s): 93 | import os 94 | for key in ['http_proxy', 'HTTP_PROXY', 'https_proxy', 'HTTPS_PROXY']: 95 | os.environ[key] = s 96 | 97 | def get_rank_and_world_size(): 98 | local_rank = int(os.environ.get("LOCAL_RANK", 0)) 99 | world_size = int(os.environ.get("WORLD_SIZE", 1)) 100 | return local_rank, world_size 101 | 102 | def splitlen(s, sym='/'): 103 | return len(s.split(sym)) 104 | 105 | def listinstr(lst, s): 106 | assert isinstance(lst, list) 107 | for item in lst: 108 | if item in s: 109 | return True 110 | return False 111 | 112 | def d2df(D): 113 | return pd.DataFrame({x: [D[x]] for x in D}) 114 | 115 | def cn_string(s): 116 | import re 117 | if re.search(u'[\u4e00-\u9fff]', s): 118 | return True 119 | return False 120 | 121 | try: 122 | import decord 123 | except ImportError: 124 | pass 125 | 126 | def timestr(second=True, minute=False): 127 | s = datetime.datetime.now().strftime('%Y%m%d%H%M%S')[2:] 128 | if second: 129 | return s 130 | elif minute: 131 | return s[:-2] 132 | else: 133 | return s[:-4] 134 | 135 | def dict_merge(dct, merge_dct): 136 | for k, _ in merge_dct.items(): 137 | if (k in dct and 
isinstance(dct[k], dict) and isinstance(merge_dct[k], dict)): #noqa 138 | dict_merge(dct[k], merge_dct[k]) 139 | else: 140 | dct[k] = merge_dct[k] 141 | 142 | def youtube_dl(idx): 143 | cmd = f'youtube-dl -f best -f mp4 "{idx}" -o {idx}.mp4' 144 | os.system(cmd) 145 | 146 | def run_command(cmd): 147 | if isinstance(cmd, str): 148 | cmd = cmd.split() 149 | return subprocess.check_output(cmd) 150 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | import re 2 | import sys 3 | from os.path import exists 4 | from setuptools import find_packages, setup 5 | 6 | def parse_requirements(fname='requirements.txt', with_version=True): 7 | """Parse the package dependencies listed in a requirements file but strips 8 | specific versioning information. 9 | 10 | Args: 11 | fname (str): path to requirements file 12 | with_version (bool, default=False): if True include version specs 13 | 14 | Returns: 15 | List[str]: list of requirements items 16 | 17 | CommandLine: 18 | python -c "import setup; print(setup.parse_requirements())" 19 | """ 20 | 21 | require_fpath = fname 22 | 23 | def parse_line(line): 24 | """Parse information from a line in a requirements text file.""" 25 | if line.startswith('-r '): 26 | # Allow specifying requirements in other files 27 | target = line.split(' ')[1] 28 | for info in parse_require_file(target): 29 | yield info 30 | else: 31 | info = {'line': line} 32 | if line.startswith('-e '): 33 | info['package'] = line.split('#egg=')[1] 34 | elif '@git+' in line: 35 | info['package'] = line 36 | else: 37 | # Remove versioning from the package 38 | pat = '(' + '|'.join(['>=', '==', '>']) + ')' 39 | parts = re.split(pat, line, maxsplit=1) 40 | parts = [p.strip() for p in parts] 41 | 42 | info['package'] = parts[0] 43 | if len(parts) > 1: 44 | op, rest = parts[1:] 45 | if ';' in rest: 46 | # Handle platform specific dependencies 47 | # http://setuptools.readthedocs.io/en/latest/setuptools.html#declaring-platform-specific-dependencies 48 | version, platform_deps = map(str.strip, 49 | rest.split(';')) 50 | info['platform_deps'] = platform_deps 51 | else: 52 | version = rest # NOQA 53 | info['version'] = (op, version) 54 | yield info 55 | 56 | def parse_require_file(fpath): 57 | with open(fpath, 'r') as f: 58 | for line in f.readlines(): 59 | line = line.strip() 60 | if line and not line.startswith('#'): 61 | for info in parse_line(line): 62 | yield info 63 | 64 | def gen_packages_items(): 65 | if exists(require_fpath): 66 | for info in parse_require_file(require_fpath): 67 | parts = [info['package']] 68 | if with_version and 'version' in info: 69 | parts.extend(info['version']) 70 | if not sys.version.startswith('3.4'): 71 | # apparently package_deps are broken in 3.4 72 | platform_deps = info.get('platform_deps') 73 | if platform_deps is not None: 74 | parts.append(';' + platform_deps) 75 | item = ''.join(parts) 76 | yield item 77 | 78 | packages = list(gen_packages_items()) 79 | return packages 80 | 81 | 82 | with open('README.md') as f: 83 | readme = f.read() 84 | 85 | 86 | def do_setup(): 87 | setup( 88 | name='vlmeval', 89 | version='0.1.0', 90 | description='BDoG', 91 | author="thecharm", 92 | author_email='csczheng@comp.polyu.edu.hk', 93 | long_description=readme, 94 | long_description_content_type='text/markdown', 95 | cmdclass={}, 96 | install_requires=parse_requirements('requirements.txt'), 97 | setup_requires=[], 98 | python_requires='>=3.7.0', 99 | 
packages=find_packages(exclude=[ 100 | 'test*', 101 | 'paper_test*', 102 | ]), 103 | keywords=['AI', 'NLP', 'in-context learning'], 104 | entry_points={ 105 | "console_scripts": [] 106 | }, 107 | classifiers=[ 108 | 'Programming Language :: Python :: 3.7', 109 | 'Programming Language :: Python :: 3.8', 110 | 'Programming Language :: Python :: 3.9', 111 | 'Programming Language :: Python :: 3.10', 112 | 'Intended Audience :: Developers', 113 | 'Intended Audience :: Education', 114 | 'Intended Audience :: Science/Research', 115 | ]) 116 | 117 | 118 | if __name__ == '__main__': 119 | do_setup() 120 | -------------------------------------------------------------------------------- /vlmeval/smp/vlm.py: -------------------------------------------------------------------------------- 1 | import os, io 2 | import pandas as pd 3 | import numpy as np 4 | import string 5 | from uuid import uuid4 6 | import os.path as osp 7 | import base64 8 | from PIL import Image 9 | 10 | def mmqa_display(question): 11 | question = {k.lower(): v for k, v in question.items()} 12 | keys = list(question.keys()) 13 | keys = [k for k in keys if k not in ['index', 'image']] 14 | 15 | images = question['image'] 16 | if isinstance(images, str): 17 | images = [images] 18 | 19 | idx = question.pop('index', 'XXX') 20 | print(f'INDEX: {idx}') 21 | 22 | for im in images: 23 | image = decode_base64_to_image(im, target_size=512) 24 | display(image) 25 | 26 | for k in keys: 27 | try: 28 | if not pd.isna(question[k]): 29 | print(f'{k.upper()}. {question[k]}') 30 | except ValueError: 31 | if False in pd.isna(question[k]): 32 | print(f'{k.upper()}. {question[k]}') 33 | 34 | def encode_image_to_base64(img, target_size=-1): 35 | # if target_size == -1, will not do resizing 36 | # else, will set the max_size ot (target_size, target_size) 37 | if img.mode in ("RGBA", "P"): 38 | img = img.convert("RGB") 39 | tmp = osp.join('/tmp', str(uuid4()) + '.jpg') 40 | if target_size > 0: 41 | img.thumbnail((target_size, target_size)) 42 | img.save(tmp) 43 | with open(tmp, 'rb') as image_file: 44 | image_data = image_file.read() 45 | ret = base64.b64encode(image_data).decode('utf-8') 46 | os.remove(tmp) 47 | return ret 48 | 49 | def encode_image_file_to_base64(image_path, target_size=-1): 50 | image = Image.open(image_path) 51 | return encode_image_to_base64(image, target_size=target_size) 52 | 53 | def decode_base64_to_image(base64_string, target_size=-1): 54 | image_data = base64.b64decode(base64_string) 55 | image = Image.open(io.BytesIO(image_data)) 56 | if image.mode in ('RGBA', 'P'): 57 | image = image.convert('RGB') 58 | if target_size > 0: 59 | image.thumbnail((target_size, target_size)) 60 | return image 61 | 62 | def decode_base64_to_image_file(base64_string, image_path, target_size=-1): 63 | image = decode_base64_to_image(base64_string, target_size=target_size) 64 | image.save(image_path) 65 | 66 | def LMUDataRoot(): 67 | if 'LMUData' in os.environ and osp.exists(os.environ['LMUData']): 68 | return os.environ['LMUData'] 69 | home = osp.expanduser('~') 70 | root = osp.join(home, 'LMUData') 71 | os.makedirs(root, exist_ok=True) 72 | return root 73 | 74 | def init_environ(args): 75 | if len(args.lmudata): 76 | os.environ['LMUData']=args.lmudata 77 | if len(args.openai): 78 | os.environ['OPENAI_API_KEY']=args.openai 79 | 80 | def build_option_str(option_dict): 81 | s = 'There are several options: \n' 82 | for c, content in option_dict.items(): 83 | if not pd.isna(content): 84 | s += f'{c}. 
{content}\n' 85 | return s 86 | 87 | def isimg(s): 88 | return osp.exists(s) or s.startswith('http') 89 | 90 | def read_ok(img_path): 91 | if not osp.exists(img_path): 92 | return False 93 | try: 94 | im = Image.open(img_path) 95 | assert im.size[0] > 0 and im.size[1] > 0 96 | return True 97 | except: 98 | return False 99 | 100 | def gpt_key_set(): 101 | openai_key = os.environ.get('OPENAI_API_KEY', None) 102 | return isinstance(openai_key, str) and openai_key.startswith('sk-') 103 | 104 | def apiok(wrapper): 105 | s = wrapper.generate("Hello!") 106 | return wrapper.fail_msg not in s 107 | 108 | def circular_pred(df, extract_func=None): 109 | if extract_func is None: 110 | extract_func = lambda x: x 111 | df = df.sort_values('index') 112 | from vlmeval.utils import can_infer_option 113 | shift = int(1e6) 114 | 115 | choices = [extract_func(x) for x in df['prediction']] 116 | pred_map = {i: c for i, c in zip(df['index'], choices)} 117 | flag_map = {i: True for i in pred_map if i < 1e6} 118 | valid_map = {i: True for i in pred_map if i < 1e6} 119 | for i in df['index']: 120 | if i >= shift and pred_map[i] and pred_map[i - shift]: 121 | if pred_map[i] not in list(string.ascii_uppercase) or pred_map[i - shift] not in list(string.ascii_uppercase): 122 | valid_map[i % shift] = False 123 | continue 124 | if (ord(pred_map[i]) - ord(pred_map[i - shift])) % 4 == 1: 125 | continue 126 | else: 127 | flag_map[i % shift] = False 128 | flag_map = {k: v for k, v in flag_map.items() if valid_map[k]} 129 | flags = list(flag_map.values()) 130 | return np.mean(flags) 131 | 132 | def MMBenchOfficialServer(): 133 | root = LMUDataRoot() 134 | for dataset in ['MMBench', 'MMBench_CN', 'MMBench_TEST_EN', 'MMBench_TEST_CN']: 135 | if osp.exists(f'{root}/{dataset}.tsv'): 136 | return True 137 | return False -------------------------------------------------------------------------------- /vlmeval/smp/.ipynb_checkpoints/vlm-checkpoint.py: -------------------------------------------------------------------------------- 1 | import os, io 2 | import pandas as pd 3 | import numpy as np 4 | import string 5 | from uuid import uuid4 6 | import os.path as osp 7 | import base64 8 | from PIL import Image 9 | 10 | def mmqa_display(question): 11 | question = {k.lower(): v for k, v in question.items()} 12 | keys = list(question.keys()) 13 | keys = [k for k in keys if k not in ['index', 'image']] 14 | 15 | images = question['image'] 16 | if isinstance(images, str): 17 | images = [images] 18 | 19 | idx = question.pop('index', 'XXX') 20 | print(f'INDEX: {idx}') 21 | 22 | for im in images: 23 | image = decode_base64_to_image(im, target_size=512) 24 | display(image) 25 | 26 | for k in keys: 27 | try: 28 | if not pd.isna(question[k]): 29 | print(f'{k.upper()}. {question[k]}') 30 | except ValueError: 31 | if False in pd.isna(question[k]): 32 | print(f'{k.upper()}. 
{question[k]}') 33 | 34 | def encode_image_to_base64(img, target_size=-1): 35 | # if target_size == -1, will not do resizing 36 | # else, will set the max_size ot (target_size, target_size) 37 | if img.mode in ("RGBA", "P"): 38 | img = img.convert("RGB") 39 | tmp = osp.join('/tmp', str(uuid4()) + '.jpg') 40 | if target_size > 0: 41 | img.thumbnail((target_size, target_size)) 42 | img.save(tmp) 43 | with open(tmp, 'rb') as image_file: 44 | image_data = image_file.read() 45 | ret = base64.b64encode(image_data).decode('utf-8') 46 | os.remove(tmp) 47 | return ret 48 | 49 | def encode_image_file_to_base64(image_path, target_size=-1): 50 | image = Image.open(image_path) 51 | return encode_image_to_base64(image, target_size=target_size) 52 | 53 | def decode_base64_to_image(base64_string, target_size=-1): 54 | image_data = base64.b64decode(base64_string) 55 | image = Image.open(io.BytesIO(image_data)) 56 | if image.mode in ('RGBA', 'P'): 57 | image = image.convert('RGB') 58 | if target_size > 0: 59 | image.thumbnail((target_size, target_size)) 60 | return image 61 | 62 | def decode_base64_to_image_file(base64_string, image_path, target_size=-1): 63 | image = decode_base64_to_image(base64_string, target_size=target_size) 64 | image.save(image_path) 65 | 66 | def LMUDataRoot(): 67 | if 'LMUData' in os.environ and osp.exists(os.environ['LMUData']): 68 | return os.environ['LMUData'] 69 | home = osp.expanduser('~') 70 | root = osp.join(home, 'LMUData') 71 | os.makedirs(root, exist_ok=True) 72 | return root 73 | 74 | def init_environ(args): 75 | if len(args.lmudata): 76 | os.environ['LMUData']=args.lmudata 77 | if len(args.openai): 78 | os.environ['OPENAI_API_KEY']=args.openai 79 | 80 | def build_option_str(option_dict): 81 | s = 'There are several options: \n' 82 | for c, content in option_dict.items(): 83 | if not pd.isna(content): 84 | s += f'{c}. 
{content}\n' 85 | return s 86 | 87 | def isimg(s): 88 | return osp.exists(s) or s.startswith('http') 89 | 90 | def read_ok(img_path): 91 | if not osp.exists(img_path): 92 | return False 93 | try: 94 | im = Image.open(img_path) 95 | assert im.size[0] > 0 and im.size[1] > 0 96 | return True 97 | except: 98 | return False 99 | 100 | def gpt_key_set(): 101 | openai_key = os.environ.get('OPENAI_API_KEY', None) 102 | return isinstance(openai_key, str) and openai_key.startswith('sk-') 103 | 104 | def apiok(wrapper): 105 | s = wrapper.generate("Hello!") 106 | return wrapper.fail_msg not in s 107 | 108 | def circular_pred(df, extract_func=None): 109 | if extract_func is None: 110 | extract_func = lambda x: x 111 | df = df.sort_values('index') 112 | from vlmeval.utils import can_infer_option 113 | shift = int(1e6) 114 | 115 | choices = [extract_func(x) for x in df['prediction']] 116 | pred_map = {i: c for i, c in zip(df['index'], choices)} 117 | flag_map = {i: True for i in pred_map if i < 1e6} 118 | valid_map = {i: True for i in pred_map if i < 1e6} 119 | for i in df['index']: 120 | if i >= shift and pred_map[i] and pred_map[i - shift]: 121 | if pred_map[i] not in list(string.ascii_uppercase) or pred_map[i - shift] not in list(string.ascii_uppercase): 122 | valid_map[i % shift] = False 123 | continue 124 | if (ord(pred_map[i]) - ord(pred_map[i - shift])) % 4 == 1: 125 | continue 126 | else: 127 | flag_map[i % shift] = False 128 | flag_map = {k: v for k, v in flag_map.items() if valid_map[k]} 129 | flags = list(flag_map.values()) 130 | return np.mean(flags) 131 | 132 | def MMBenchOfficialServer(): 133 | root = LMUDataRoot() 134 | for dataset in ['MMBench', 'MMBench_CN', 'MMBench_TEST_EN', 'MMBench_TEST_CN']: 135 | if osp.exists(f'{root}/{dataset}.tsv'): 136 | return True 137 | return False -------------------------------------------------------------------------------- /vlmeval/smp/file.py: -------------------------------------------------------------------------------- 1 | import json 2 | import pickle 3 | import pandas as pd 4 | import os 5 | import csv 6 | import hashlib 7 | import os.path as osp 8 | import time 9 | import numpy as np 10 | 11 | class NumpyEncoder(json.JSONEncoder): 12 | def default(self, obj): 13 | if isinstance(obj, (np.int_, np.intc, np.intp, np.int8, 14 | np.int16, np.int32, np.int64, np.uint8, 15 | np.uint16, np.uint32, np.uint64)): 16 | return int(obj) 17 | elif isinstance(obj, (np.float_, np.float16, np.float32, np.float64)): 18 | return float(obj) 19 | elif isinstance(obj, (np.complex_, np.complex64, np.complex128)): 20 | return {'real': obj.real, 'imag': obj.imag} 21 | elif isinstance(obj, (np.ndarray,)): 22 | return obj.tolist() 23 | elif isinstance(obj, (np.bool_)): 24 | return bool(obj) 25 | elif isinstance(obj, (np.void)): 26 | return None 27 | return json.JSONEncoder.default(self, obj) 28 | 29 | # LOAD & DUMP 30 | def dump(data, f, **kwargs): 31 | def dump_pkl(data, pth, **kwargs): 32 | pickle.dump(data, open(pth, 'wb')) 33 | 34 | def dump_json(data, pth, **kwargs): 35 | json.dump(data, open(pth, 'w'), indent=4, ensure_ascii=False, cls=NumpyEncoder) 36 | 37 | def dump_jsonl(data, f, **kwargs): 38 | lines = [json.dumps(x, ensure_ascii=False, cls=NumpyEncoder) for x in data] 39 | with open(f, 'w', encoding='utf8') as fout: 40 | fout.write('\n'.join(lines)) 41 | 42 | def dump_xlsx(data, f, **kwargs): 43 | data.to_excel(f, index=False, engine='xlsxwriter') 44 | 45 | def dump_csv(data, f, quoting=csv.QUOTE_ALL): 46 | data.to_csv(f, index=False, encoding='utf-8', 
quoting=quoting) 47 | 48 | def dump_tsv(data, f, quoting=csv.QUOTE_ALL): 49 | data.to_csv(f, sep='\t', index=False, encoding='utf-8', quoting=quoting) 50 | 51 | handlers = dict(pkl=dump_pkl, json=dump_json, jsonl=dump_jsonl, xlsx=dump_xlsx, csv=dump_csv, tsv=dump_tsv) 52 | suffix = f.split('.')[-1] 53 | return handlers[suffix](data, f, **kwargs) 54 | 55 | def load(f): 56 | def load_pkl(pth): 57 | return pickle.load(open(pth, 'rb')) 58 | 59 | def load_json(pth): 60 | return json.load(open(pth, 'r', encoding='utf-8')) 61 | 62 | def load_jsonl(f): 63 | lines = open(f, encoding='utf-8').readlines() 64 | lines = [x.strip() for x in lines] 65 | if lines[-1] == '': 66 | lines = lines[:-1] 67 | data = [json.loads(x) for x in lines] 68 | return data 69 | 70 | def load_xlsx(f): 71 | return pd.read_excel(f) 72 | 73 | def load_csv(f): 74 | return pd.read_csv(f) 75 | 76 | def load_tsv(f): 77 | return pd.read_csv(f, sep='\t') 78 | 79 | handlers = dict(pkl=load_pkl, json=load_json, jsonl=load_jsonl, xlsx=load_xlsx, csv=load_csv, tsv=load_tsv) 80 | suffix = f.split('.')[-1] 81 | return handlers[suffix](f) 82 | 83 | def download_file(url, filename=None): 84 | import urllib.request 85 | from tqdm import tqdm 86 | 87 | class DownloadProgressBar(tqdm): 88 | def update_to(self, b=1, bsize=1, tsize=None): 89 | if tsize is not None: 90 | self.total = tsize 91 | self.update(b * bsize - self.n) 92 | 93 | if filename is None: 94 | filename = url.split('/')[-1] 95 | 96 | with DownloadProgressBar(unit='B', unit_scale=True, 97 | miniters=1, desc=url.split('/')[-1]) as t: 98 | urllib.request.urlretrieve(url, filename=filename, reporthook=t.update_to) 99 | return filename 100 | 101 | def ls(dirname='.', match='', mode='all', level=1): 102 | if dirname == '.': 103 | ans = os.listdir(dirname) 104 | else: 105 | ans = [osp.join(dirname, x) for x in os.listdir(dirname)] 106 | assert mode in ['all', 'dir', 'file'] 107 | assert level >= 1 and isinstance(level, int) 108 | if level == 1: 109 | ans = [x for x in ans if match in x] 110 | if mode == 'dir': 111 | ans = [x for x in ans if osp.isdir(x)] 112 | elif mode == 'file': 113 | ans = [x for x in ans if not osp.isdir(x)] 114 | else: 115 | ans = [x for x in ans if osp.isdir(x)] 116 | res = [] 117 | for d in ans: 118 | res.extend(ls(d, match=match, mode=mode, level=level-1)) 119 | ans = res 120 | return ans 121 | 122 | def mrlines(fname, sp='\n'): 123 | f = open(fname).read().split(sp) 124 | while f != [] and f[-1] == '': 125 | f = f[:-1] 126 | return f 127 | 128 | def mwlines(lines, fname): 129 | with open(fname, 'w') as fout: 130 | fout.write('\n'.join(lines)) 131 | 132 | def md5(file_pth): 133 | with open(file_pth, 'rb') as f: 134 | hash = hashlib.new('md5') 135 | for chunk in iter(lambda: f.read(2**20), b''): 136 | hash.update(chunk) 137 | return str(hash.hexdigest()) 138 | 139 | def last_modified(pth): 140 | stamp = osp.getmtime(pth) 141 | m_ti = time.ctime(stamp) 142 | t_obj = time.strptime(m_ti) 143 | t = time.strftime('%Y%m%d%H%M%S', t_obj)[2:] 144 | return t 145 | -------------------------------------------------------------------------------- /vlmeval/smp/.ipynb_checkpoints/file-checkpoint.py: -------------------------------------------------------------------------------- 1 | import json 2 | import pickle 3 | import pandas as pd 4 | import os 5 | import csv 6 | import hashlib 7 | import os.path as osp 8 | import time 9 | import numpy as np 10 | 11 | class NumpyEncoder(json.JSONEncoder): 12 | def default(self, obj): 13 | if isinstance(obj, (np.int_, np.intc, np.intp, 
np.int8, 14 | np.int16, np.int32, np.int64, np.uint8, 15 | np.uint16, np.uint32, np.uint64)): 16 | return int(obj) 17 | elif isinstance(obj, (np.float_, np.float16, np.float32, np.float64)): 18 | return float(obj) 19 | elif isinstance(obj, (np.complex_, np.complex64, np.complex128)): 20 | return {'real': obj.real, 'imag': obj.imag} 21 | elif isinstance(obj, (np.ndarray,)): 22 | return obj.tolist() 23 | elif isinstance(obj, (np.bool_)): 24 | return bool(obj) 25 | elif isinstance(obj, (np.void)): 26 | return None 27 | return json.JSONEncoder.default(self, obj) 28 | 29 | # LOAD & DUMP 30 | def dump(data, f, **kwargs): 31 | def dump_pkl(data, pth, **kwargs): 32 | pickle.dump(data, open(pth, 'wb')) 33 | 34 | def dump_json(data, pth, **kwargs): 35 | json.dump(data, open(pth, 'w'), indent=4, ensure_ascii=False, cls=NumpyEncoder) 36 | 37 | def dump_jsonl(data, f, **kwargs): 38 | lines = [json.dumps(x, ensure_ascii=False, cls=NumpyEncoder) for x in data] 39 | with open(f, 'w', encoding='utf8') as fout: 40 | fout.write('\n'.join(lines)) 41 | 42 | def dump_xlsx(data, f, **kwargs): 43 | data.to_excel(f, index=False, engine='xlsxwriter') 44 | 45 | def dump_csv(data, f, quoting=csv.QUOTE_ALL): 46 | data.to_csv(f, index=False, encoding='utf-8', quoting=quoting) 47 | 48 | def dump_tsv(data, f, quoting=csv.QUOTE_ALL): 49 | data.to_csv(f, sep='\t', index=False, encoding='utf-8', quoting=quoting) 50 | 51 | handlers = dict(pkl=dump_pkl, json=dump_json, jsonl=dump_jsonl, xlsx=dump_xlsx, csv=dump_csv, tsv=dump_tsv) 52 | suffix = f.split('.')[-1] 53 | return handlers[suffix](data, f, **kwargs) 54 | 55 | def load(f): 56 | def load_pkl(pth): 57 | return pickle.load(open(pth, 'rb')) 58 | 59 | def load_json(pth): 60 | return json.load(open(pth, 'r', encoding='utf-8')) 61 | 62 | def load_jsonl(f): 63 | lines = open(f, encoding='utf-8').readlines() 64 | lines = [x.strip() for x in lines] 65 | if lines[-1] == '': 66 | lines = lines[:-1] 67 | data = [json.loads(x) for x in lines] 68 | return data 69 | 70 | def load_xlsx(f): 71 | return pd.read_excel(f) 72 | 73 | def load_csv(f): 74 | return pd.read_csv(f) 75 | 76 | def load_tsv(f): 77 | return pd.read_csv(f, sep='\t') 78 | 79 | handlers = dict(pkl=load_pkl, json=load_json, jsonl=load_jsonl, xlsx=load_xlsx, csv=load_csv, tsv=load_tsv) 80 | suffix = f.split('.')[-1] 81 | return handlers[suffix](f) 82 | 83 | def download_file(url, filename=None): 84 | import urllib.request 85 | from tqdm import tqdm 86 | 87 | class DownloadProgressBar(tqdm): 88 | def update_to(self, b=1, bsize=1, tsize=None): 89 | if tsize is not None: 90 | self.total = tsize 91 | self.update(b * bsize - self.n) 92 | 93 | if filename is None: 94 | filename = url.split('/')[-1] 95 | 96 | with DownloadProgressBar(unit='B', unit_scale=True, 97 | miniters=1, desc=url.split('/')[-1]) as t: 98 | urllib.request.urlretrieve(url, filename=filename, reporthook=t.update_to) 99 | return filename 100 | 101 | def ls(dirname='.', match='', mode='all', level=1): 102 | if dirname == '.': 103 | ans = os.listdir(dirname) 104 | else: 105 | ans = [osp.join(dirname, x) for x in os.listdir(dirname)] 106 | assert mode in ['all', 'dir', 'file'] 107 | assert level >= 1 and isinstance(level, int) 108 | if level == 1: 109 | ans = [x for x in ans if match in x] 110 | if mode == 'dir': 111 | ans = [x for x in ans if osp.isdir(x)] 112 | elif mode == 'file': 113 | ans = [x for x in ans if not osp.isdir(x)] 114 | else: 115 | ans = [x for x in ans if osp.isdir(x)] 116 | res = [] 117 | for d in ans: 118 | res.extend(ls(d, match=match, 
mode=mode, level=level-1)) 119 | ans = res 120 | return ans 121 | 122 | def mrlines(fname, sp='\n'): 123 | f = open(fname).read().split(sp) 124 | while f != [] and f[-1] == '': 125 | f = f[:-1] 126 | return f 127 | 128 | def mwlines(lines, fname): 129 | with open(fname, 'w') as fout: 130 | fout.write('\n'.join(lines)) 131 | 132 | def md5(file_pth): 133 | with open(file_pth, 'rb') as f: 134 | hash = hashlib.new('md5') 135 | for chunk in iter(lambda: f.read(2**20), b''): 136 | hash.update(chunk) 137 | return str(hash.hexdigest()) 138 | 139 | def last_modified(pth): 140 | stamp = osp.getmtime(pth) 141 | m_ti = time.ctime(stamp) 142 | t_obj = time.strptime(m_ti) 143 | t = time.strftime('%Y%m%d%H%M%S', t_obj)[2:] 144 | return t 145 | -------------------------------------------------------------------------------- /vlmeval/utils/dataset_config.py: -------------------------------------------------------------------------------- 1 | from ..smp import listinstr 2 | 3 | dataset_URLs = { 4 | 'MMBench_DEV_EN': "https://opencompass.openxlab.space/utils/VLMEval/MMBench_DEV_EN.tsv", 5 | 'MMBench_TEST_EN': "https://opencompass.openxlab.space/utils/VLMEval/MMBench_TEST_EN.tsv", 6 | 'MMBench_DEV_CN': "https://opencompass.openxlab.space/utils/VLMEval/MMBench_DEV_CN.tsv", 7 | 'MMBench_TEST_CN': "https://opencompass.openxlab.space/utils/VLMEval/MMBench_TEST_CN.tsv", 8 | "MMBench": "https://opencompass.openxlab.space/utils/VLMEval/MMBench.tsv", # Link Invalid, Internal Only 9 | "MMBench_CN": "https://opencompass.openxlab.space/utils/VLMEval/MMBench_CN.tsv", # Link Invalid, Internal Only 10 | 'CCBench': "https://opencompass.openxlab.space/utils/VLMEval/CCBench.tsv", 11 | 'MME': "https://opencompass.openxlab.space/utils/VLMEval/MME.tsv", 12 | 'SEEDBench_IMG': "https://opencompass.openxlab.space/utils/VLMEval/SEEDBench_IMG.tsv", 13 | "CORE_MM": "https://opencompass.openxlab.space/utils/VLMEval/CORE_MM.tsv", 14 | "MMVet": "https://opencompass.openxlab.space/utils/VLMEval/MMVet.tsv", 15 | "COCO_VAL": "https://opencompass.openxlab.space/utils/VLMEval/COCO_VAL.tsv", 16 | "OCRVQA_TEST": "https://opencompass.openxlab.space/utils/VLMEval/OCRVQA_TEST.tsv", 17 | "OCRVQA_TESTCORE": "https://opencompass.openxlab.space/utils/VLMEval/OCRVQA_TESTCORE.tsv", 18 | 'TextVQA_VAL': "https://opencompass.openxlab.space/utils/VLMEval/TextVQA_VAL.tsv", 19 | "MMMU_DEV_VAL": "https://opencompass.openxlab.space/utils/VLMEval/MMMU_DEV_VAL.tsv", 20 | "MMMU_TEST": "https://opencompass.openxlab.space/utils/VLMEval/MMMU_TEST.tsv", 21 | "MathVista_MINI": "https://opencompass.openxlab.space/utils/VLMEval/MathVista_MINI.tsv", 22 | 'ChartQA_VALTEST_HUMAN': "https://opencompass.openxlab.space/utils/VLMEval/ChartQA_VALTEST_HUMAN.tsv", 23 | 'ScienceQA_VAL': "https://opencompass.openxlab.space/utils/VLMEval/ScienceQA_VAL.tsv", 24 | 'ScienceQA_TEST': "https://opencompass.openxlab.space/utils/VLMEval/ScienceQA_TEST.tsv", 25 | 'HallusionBench': "https://opencompass.openxlab.space/utils/VLMEval/HallusionBench.tsv", 26 | "DocVQA_VAL": "https://opencompass.openxlab.space/utils/VLMEval/DocVQA_VAL.tsv", 27 | 'AI2D_TEST': "https://opencompass.openxlab.space/utils/VLMEval/AI2D_TEST.tsv", 28 | "LLaVABench": "https://opencompass.openxlab.space/utils/VLMEval/LLaVABench.tsv", 29 | "OCRBench": 'https://opencompass.openxlab.space/utils/VLMEval/OCRBench.tsv', 30 | } 31 | 32 | dataset_md5_dict = { 33 | 'MMBench_DEV_EN': "b6caf1133a01c6bb705cf753bb527ed8", 34 | 'MMBench_TEST_EN': "6939fadb0ce626fefc0bdc9c64efc528", 35 | 'MMBench_DEV_CN': 
"08b8fc3324a5ed74155350f57be69fbd", 36 | 'MMBench_TEST_CN': "7e1239baf0ee4c8b513e19705a0f317e", 37 | "MMBench": "4115aea3383f3dd0083be6a633e0f820", # Link Invalid, Internal Only 38 | "MMBench_CN": "2e053ffc90ea598b1feae13c36dc13ee", # Link Invalid, Internal Only 39 | 'CCBench': "1de88b4257e7eee3f60b18d45eda6f07", 40 | 'MME': "b36b43c3f09801f5d368627fb92187c3", 41 | 'SEEDBench_IMG': "68017231464752261a2526d6ca3a10c0", 42 | "CORE_MM": "8a8da2f2232e79caf98415bfdf0a202d", 43 | "MMVet": "f400d7f513a585a0f218cbd6882e0671", 44 | 'COCO_VAL': "72a5079dead060269ac222c5aa5128af", 45 | 'OCRVQA_TEST': 'ca46a6d74b403e9d6c0b670f6fc00db9', 46 | 'OCRVQA_TESTCORE': 'c5239fe77db8bdc1f2ad8e55e0d1fe97', 47 | 'TextVQA_VAL': 'b233b31f551bbf4056f2f955da3a92cd', 48 | 'MMMU_DEV_VAL': "521afc0f3bf341e6654327792781644d", 49 | 'MMMU_TEST': "c19875d11a2d348d07e5eb4bdf33166d", 50 | 'MathVista_MINI': 'f199b98e178e5a2a20e7048f5dcb0464', 51 | 'ChartQA_VALTEST_HUMAN':'2c90a4133408a21d57fb2ea26f77bbfc', 52 | 'ScienceQA_VAL': '96320d05e142e585e7204e72affd29f3', 53 | 'ScienceQA_TEST': 'e42e9e00f9c59a80d8a5db35bc32b71f', 54 | 'HallusionBench': '0c23ac0dc9ef46832d7a24504f2a0c7c', 55 | "DocVQA_VAL": 'ee0d8ae5527439438d08e154ef65d735', 56 | "AI2D_TEST": "0f593e0d1c7df9a3d69bf1f947e71975", 57 | "LLaVABench": "d382a093f749a697820d3dadd61c8428", 58 | "OCRBench": 'e953d98a987cc6e26ef717b61260b778', 59 | } 60 | 61 | img_root_map = {k: k for k in dataset_URLs} 62 | img_root_map.update({ 63 | 'MMBench_DEV_EN': "MMBench", 64 | 'MMBench_TEST_EN': "MMBench", 65 | 'MMBench_DEV_CN': "MMBench", 66 | 'MMBench_TEST_CN': "MMBench", 67 | "MMBench_CN": "MMBench", # Link Invalid, Internal Only 68 | 'COCO_VAL':'COCO', 69 | 'OCRVQA_TEST': 'OCRVQA', 70 | 'OCRVQA_TESTCORE': 'OCRVQA', 71 | 'TextVQA_VAL': 'TextVQA', 72 | 'MMMU_DEV_VAL': 'MMMU', 73 | "MMMU_TEST": "MMMU", 74 | 'MathVista_MINI': 'MathVista', 75 | 'ChartQA_VALTEST_HUMAN': 'ChartQA', 76 | 'HallusionBench': 'Hallusion', 77 | 'DocVQA_VAL': 'DocVQA', 78 | "OCRBench": 'OCRBench', 79 | }) 80 | 81 | assert set(dataset_URLs) == set(img_root_map) == set(dataset_md5_dict) 82 | 83 | def DATASET_TYPE(dataset): 84 | dataset = dataset.lower() 85 | if listinstr(['mmbench', 'seedbench', 'ccbench', 'mmmu', 'scienceqa', 'ai2d'], dataset): 86 | return 'multi-choice' 87 | elif listinstr(['mme', 'hallusion'], dataset): 88 | return 'Y/N' 89 | elif 'coco' in dataset: 90 | return 'Caption' 91 | elif listinstr(['ocrvqa', 'textvqa', 'chartqa', 'mathvista', 'docvqa', 'llavabench', 'mmvet', 'OCRBench'], dataset): 92 | return 'VQA' 93 | else: 94 | return 'QA' 95 | 96 | def abbr2full(s): 97 | datasets = [x for x in img_root_map] 98 | ins = [s in d for d in datasets] 99 | if sum(ins) == 1: 100 | for d in datasets: 101 | if s in d: 102 | return d 103 | else: 104 | return None 105 | -------------------------------------------------------------------------------- /vlmeval/utils/.ipynb_checkpoints/dataset_config-checkpoint.py: -------------------------------------------------------------------------------- 1 | from ..smp import listinstr 2 | 3 | dataset_URLs = { 4 | 'MMBench_DEV_EN': "https://opencompass.openxlab.space/utils/VLMEval/MMBench_DEV_EN.tsv", 5 | 'MMBench_TEST_EN': "https://opencompass.openxlab.space/utils/VLMEval/MMBench_TEST_EN.tsv", 6 | 'MMBench_DEV_CN': "https://opencompass.openxlab.space/utils/VLMEval/MMBench_DEV_CN.tsv", 7 | 'MMBench_TEST_CN': "https://opencompass.openxlab.space/utils/VLMEval/MMBench_TEST_CN.tsv", 8 | "MMBench": "https://opencompass.openxlab.space/utils/VLMEval/MMBench.tsv", # Link Invalid, 
Internal Only 9 | "MMBench_CN": "https://opencompass.openxlab.space/utils/VLMEval/MMBench_CN.tsv", # Link Invalid, Internal Only 10 | 'CCBench': "https://opencompass.openxlab.space/utils/VLMEval/CCBench.tsv", 11 | 'MME': "https://opencompass.openxlab.space/utils/VLMEval/MME.tsv", 12 | 'SEEDBench_IMG': "https://opencompass.openxlab.space/utils/VLMEval/SEEDBench_IMG.tsv", 13 | "CORE_MM": "https://opencompass.openxlab.space/utils/VLMEval/CORE_MM.tsv", 14 | "MMVet": "https://opencompass.openxlab.space/utils/VLMEval/MMVet.tsv", 15 | "COCO_VAL": "https://opencompass.openxlab.space/utils/VLMEval/COCO_VAL.tsv", 16 | "OCRVQA_TEST": "https://opencompass.openxlab.space/utils/VLMEval/OCRVQA_TEST.tsv", 17 | "OCRVQA_TESTCORE": "https://opencompass.openxlab.space/utils/VLMEval/OCRVQA_TESTCORE.tsv", 18 | 'TextVQA_VAL': "https://opencompass.openxlab.space/utils/VLMEval/TextVQA_VAL.tsv", 19 | "MMMU_DEV_VAL": "https://opencompass.openxlab.space/utils/VLMEval/MMMU_DEV_VAL.tsv", 20 | "MMMU_TEST": "https://opencompass.openxlab.space/utils/VLMEval/MMMU_TEST.tsv", 21 | "MathVista_MINI": "https://opencompass.openxlab.space/utils/VLMEval/MathVista_MINI.tsv", 22 | 'ChartQA_VALTEST_HUMAN': "https://opencompass.openxlab.space/utils/VLMEval/ChartQA_VALTEST_HUMAN.tsv", 23 | 'ScienceQA_VAL': "https://opencompass.openxlab.space/utils/VLMEval/ScienceQA_VAL.tsv", 24 | 'ScienceQA_TEST': "https://opencompass.openxlab.space/utils/VLMEval/ScienceQA_TEST.tsv", 25 | 'HallusionBench': "https://opencompass.openxlab.space/utils/VLMEval/HallusionBench.tsv", 26 | "DocVQA_VAL": "https://opencompass.openxlab.space/utils/VLMEval/DocVQA_VAL.tsv", 27 | 'AI2D_TEST': "https://opencompass.openxlab.space/utils/VLMEval/AI2D_TEST.tsv", 28 | "LLaVABench": "https://opencompass.openxlab.space/utils/VLMEval/LLaVABench.tsv", 29 | "OCRBench": 'https://opencompass.openxlab.space/utils/VLMEval/OCRBench.tsv', 30 | } 31 | 32 | dataset_md5_dict = { 33 | 'MMBench_DEV_EN': "b6caf1133a01c6bb705cf753bb527ed8", 34 | 'MMBench_TEST_EN': "6939fadb0ce626fefc0bdc9c64efc528", 35 | 'MMBench_DEV_CN': "08b8fc3324a5ed74155350f57be69fbd", 36 | 'MMBench_TEST_CN': "7e1239baf0ee4c8b513e19705a0f317e", 37 | "MMBench": "4115aea3383f3dd0083be6a633e0f820", # Link Invalid, Internal Only 38 | "MMBench_CN": "2e053ffc90ea598b1feae13c36dc13ee", # Link Invalid, Internal Only 39 | 'CCBench': "1de88b4257e7eee3f60b18d45eda6f07", 40 | 'MME': "b36b43c3f09801f5d368627fb92187c3", 41 | 'SEEDBench_IMG': "68017231464752261a2526d6ca3a10c0", 42 | "CORE_MM": "8a8da2f2232e79caf98415bfdf0a202d", 43 | "MMVet": "f400d7f513a585a0f218cbd6882e0671", 44 | 'COCO_VAL': "72a5079dead060269ac222c5aa5128af", 45 | 'OCRVQA_TEST': 'ca46a6d74b403e9d6c0b670f6fc00db9', 46 | 'OCRVQA_TESTCORE': 'c5239fe77db8bdc1f2ad8e55e0d1fe97', 47 | 'TextVQA_VAL': 'b233b31f551bbf4056f2f955da3a92cd', 48 | 'MMMU_DEV_VAL': "521afc0f3bf341e6654327792781644d", 49 | 'MMMU_TEST': "c19875d11a2d348d07e5eb4bdf33166d", 50 | 'MathVista_MINI': 'f199b98e178e5a2a20e7048f5dcb0464', 51 | 'ChartQA_VALTEST_HUMAN':'2c90a4133408a21d57fb2ea26f77bbfc', 52 | 'ScienceQA_VAL': '96320d05e142e585e7204e72affd29f3', 53 | 'ScienceQA_TEST': 'e42e9e00f9c59a80d8a5db35bc32b71f', 54 | 'HallusionBench': '0c23ac0dc9ef46832d7a24504f2a0c7c', 55 | "DocVQA_VAL": 'ee0d8ae5527439438d08e154ef65d735', 56 | "AI2D_TEST": "0f593e0d1c7df9a3d69bf1f947e71975", 57 | "LLaVABench": "d382a093f749a697820d3dadd61c8428", 58 | "OCRBench": 'e953d98a987cc6e26ef717b61260b778', 59 | } 60 | 61 | img_root_map = {k: k for k in dataset_URLs} 62 | img_root_map.update({ 63 | 'MMBench_DEV_EN': "MMBench", 
64 | 'MMBench_TEST_EN': "MMBench", 65 | 'MMBench_DEV_CN': "MMBench", 66 | 'MMBench_TEST_CN': "MMBench", 67 | "MMBench_CN": "MMBench", # Link Invalid, Internal Only 68 | 'COCO_VAL':'COCO', 69 | 'OCRVQA_TEST': 'OCRVQA', 70 | 'OCRVQA_TESTCORE': 'OCRVQA', 71 | 'TextVQA_VAL': 'TextVQA', 72 | 'MMMU_DEV_VAL': 'MMMU', 73 | "MMMU_TEST": "MMMU", 74 | 'MathVista_MINI': 'MathVista', 75 | 'ChartQA_VALTEST_HUMAN': 'ChartQA', 76 | 'HallusionBench': 'Hallusion', 77 | 'DocVQA_VAL': 'DocVQA', 78 | "OCRBench": 'OCRBench', 79 | }) 80 | 81 | assert set(dataset_URLs) == set(img_root_map) == set(dataset_md5_dict) 82 | 83 | def DATASET_TYPE(dataset): 84 | dataset = dataset.lower() 85 | if listinstr(['mmbench', 'seedbench', 'ccbench', 'mmmu', 'scienceqa', 'ai2d'], dataset): 86 | return 'multi-choice' 87 | elif listinstr(['mme', 'hallusion'], dataset): 88 | return 'Y/N' 89 | elif 'coco' in dataset: 90 | return 'Caption' 91 | elif listinstr(['ocrvqa', 'textvqa', 'chartqa', 'mathvista', 'docvqa', 'llavabench', 'mmvet', 'ocrbench'], dataset): 92 | return 'VQA' 93 | else: 94 | return 'QA' 95 | 96 | def abbr2full(s): 97 | datasets = [x for x in img_root_map] 98 | ins = [s in d for d in datasets] 99 | if sum(ins) == 1: 100 | for d in datasets: 101 | if s in d: 102 | return d 103 | else: 104 | return None 105 | -------------------------------------------------------------------------------- /vlmeval/utils/base_prompt.py: -------------------------------------------------------------------------------- 1 | 2 | def create_one_example(format_, question, context, options, answer, knowledge, image_path): 3 | 4 | input_format, output_format = format_.split("-") 5 | 6 | aff_base = "You are a fellow debater from the AFFIRMATIVE side, You are more Emotional to think about problems." 7 | neg_base = "You are a fellow debater from the NEGATIVE side, You are more Rational in thinking about problems." 8 | 9 | kg_emo = f"Emotional Graph: {knowledge[0]}\n" if knowledge[0]!='none' else "" 10 | kg_rat = f"Rational Graph: {knowledge[1]}\n" if knowledge[1]!='none' else "" 11 | 12 | hint = f"Hint: {context}\n" if context != "none" else "" 13 | question_ = f"Question: {question}\n" 14 | option_ = f"Options:\n{options}" if options != "none" else "" 15 | answer_ = "Please select the correct answer from the options above, without any explanation." if options != "none" else "Answer directly." 16 | 17 | if input_format=="IQ": 18 | input = f"""{hint}{question_}{option_}{answer_}""" 19 | 20 | elif input_format == "QIM": 21 | input = f"""{hint}{question_}{option_}For the provided image and its associated question, generate a graph to answer the question. 22 | """ 23 | 24 | elif input_format == "GKG": 25 | input = f"""For the provided image and its associated question, generate a scene graph in JSON format that includes the following: 26 | 1. Objects, attributes, relationships that are more relevant to answering the question. 27 | 2. Objects are NO MORE THAN 3. 28 | {hint}{question_}""" 29 | 30 | #### ODebate_stage 31 | elif input_format == "ODIM": 32 | input = f"""{hint}{question_}You are a fellow debater from the AFFIRMATIVE side, You are more Emotional to think about problems. 33 | For the provided image and its associated question, Do not give the answer, But your solution and ideas to solve this question. 34 | """ 35 | 36 | elif input_format == "ONIM": 37 | input = f"""{hint}{question_}You are a fellow debater from the NEGATIVE side, You are more Rational in thinking about problems. 
38 | For the provided image and its associated question, Do not give the answer, But your solution and ideas to solve this problem. 39 | """ 40 | elif input_format == "ODQIM": 41 | input = f"""{hint}{question_}Debate Solution:{knowledge[1]}\nYou are a fellow debater from the AFFIRMATIVE side, You are more Emotional to think about problems. 42 | Based on the debate Solution of the question, Do not give the answer, But your Better solution and ideas to solve this problem. 43 | """ 44 | 45 | elif input_format == "ONQIM": 46 | input = f"""{hint}{question_}Debate Solution:{knowledge[0]}\nYou are a fellow debater from the NEGATIVE side, You are more Rational in thinking about problems. 47 | Based on the debate Solution of the question, Do not give the answer, But your Better solution and ideas to solve this problem. 48 | """ 49 | 50 | elif input_format == "OAGM": 51 | input = f"""You're good at summarizing and answering questions. \nEmotional Solution: {knowledge[0]}\nRational Solution: {knowledge[1]}\n{hint}{question_}{option_}{answer_} 52 | """ 53 | 54 | 55 | #### Debate_KG_stage 56 | elif input_format == "KDIM": 57 | input = f"""{aff_base} 58 | For the provided image and its associated question, Please give your solution and ideas to solve this problem, but do not give a final answer. Generate an updated graph from a different view based on the Debate Graph in JSON format that includes the following: 59 | 1. Objects, attributes, relationships that are more relevant to answering the question. 60 | 2. Objects are NO MORE THAN 3. 61 | {hint}{question_}""" 62 | 63 | elif input_format == "KNIM": 64 | input = f"""{neg_base} 65 | For the provided image and its associated question, Please give your solution and ideas to solve this problem, but do not give a final answer. Generate an updated graph from a different view based on the Debate Graph in JSON format that includes the following: 66 | 1. Objects, attributes, relationships that are more relevant to answering the question. 67 | 2. Objects are NO MORE THAN 3. 68 | {hint}{question_}""" 69 | 70 | elif input_format == "KDQIM": 71 | input = f"""{aff_base} 72 | For the provided image and its associated question, Please give your solution and ideas to solve this problem, but do not give a final answer. Generate an updated graph from a different view based on the Debate Graph in JSON format that includes the following: 73 | 1. Objects, attributes, relationships that are more relevant to answering the question. 74 | 2. Delete the irrelevant objects, attributes and relationships. 75 | {hint}{question_}{kg_rat} 76 | """ 77 | 78 | elif input_format == "KNQIM": 79 | input = f"""{neg_base} 80 | For the provided image and its associated question, Please give your solution and ideas to solve this problem, but do not give a final answer. Generate an updated graph from a different view based on the Debate Graph in JSON format that includes the following: 81 | 1. Objects, attributes, relationships that are more relevant to answering the question. 82 | 2. Delete the irrelevant objects, attributes and relationships. 
83 | {hint}{question_}{kg_emo} 84 | """ 85 | 86 | elif input_format == "KAGM": 87 | input = f"""You're good at summarizing and answering questions.{hint}{kg_emo}{kg_rat}Use the image and two debate Solution as context and answer the following question:\n{question_}{option_}{answer_} 88 | """ 89 | 90 | # Outputs 91 | if output_format == 'A': 92 | output = "Answer:" 93 | 94 | elif output_format == 'G': 95 | output = f"Solution: " 96 | 97 | text = input + output 98 | text = text.replace(" ", " ") 99 | if text.endswith("BECAUSE:"): 100 | text = text.replace("BECAUSE:", "").strip() 101 | return text 102 | -------------------------------------------------------------------------------- /vlmeval/utils/.ipynb_checkpoints/base_prompt-checkpoint.py: -------------------------------------------------------------------------------- 1 | 2 | def create_one_example(format_, question, context, options, answer, knowledge, image_path): 3 | 4 | input_format, output_format = format_.split("-") 5 | 6 | aff_base = "You are a fellow debater from the AFFIRMATIVE side, You are more Emotional to think about problems." 7 | neg_base = "You are a fellow debater from the NEGATIVE side, You are more Rational in thinking about problems." 8 | 9 | kg_emo = f"Graph: {knowledge[0]}\n" if knowledge[0]!='none' else "" 10 | kg_rat = f"Graph: {knowledge[1]}\n" if knowledge[1]!='none' else "" 11 | 12 | hint = f"Hint: {context}\n" if context != "none" else "" 13 | question_ = f"Question: {question}\n" 14 | option_ = f"Options:\n{options}" if options != "none" else "" 15 | answer_ = "Answer with the option's letter from the given choices directly." if options != "none" else "Answer directly." 16 | 17 | if input_format=="IQ": 18 | input = f"""{hint}{question_}{option_}{answer_}""" 19 | 20 | elif input_format == "QIM": 21 | input = f"""{hint}{question_}{option_}For the provided image and its associated question, generate a graph to answer the question. 22 | """ 23 | 24 | elif input_format == "GKG": 25 | input = f"""{hint}{question_}{option_}For the provided image and its associated question. generate a scene graph in JSON format that includes the following: 26 | 1. Obiects that are relevant to answering the question. 27 | 2. Obiect attributes that are relevant to answering the question. 28 | 3. Obiect relationships that are releyant to answering the question. 29 | """ 30 | 31 | #### ODebate_stage 32 | elif input_format == "ODIM": 33 | input = f"""{hint}{question_}You are a fellow debater from the AFFIRMATIVE side, You are more Emotional to think about problems. 34 | For the provided image and its associated question, Do not give the answer, But your solution and ideas to solve this question. 35 | """ 36 | 37 | elif input_format == "ONIM": 38 | input = f"""{hint}{question_}You are a fellow debater from the NEGATIVE side, You are more Rational in thinking about problems. 39 | For the provided image and its associated question, Do not give the answer, But your solution and ideas to solve this problem. 40 | """ 41 | elif input_format == "ODQIM": 42 | input = f"""{hint}{question_}Debate Solution:{knowledge[1]}\nYou are a fellow debater from the AFFIRMATIVE side, You are more Emotional to think about problems. 43 | Based on the debate Solution of the question, Do not give the answer, But your Better solution and ideas to solve this problem. 
44 | """ 45 | 46 | elif input_format == "ONQIM": 47 | input = f"""{hint}{question_}Debate Solution:{knowledge[0]}\nYou are a fellow debater from the NEGATIVE side, You are more Rational in thinking about problems. 48 | Based on the debate Solution of the question, Do not give the answer, But your Better solution and ideas to solve this problem. 49 | """ 50 | 51 | elif input_format == "OAGM": 52 | input = f"""You're good at summarizing and answering questions. \nEmotional Solution: {knowledge[0]}\nRational Solution: {knowledge[1]}\n{hint}{question_}{option_}{answer_} 53 | """ 54 | 55 | 56 | #### Debate_KG_stage 57 | elif input_format == "KDIM": 58 | input = f"""{aff_base} 59 | For the provided image and its associated question, Please give your solution and ideas to solve this problem, but do not give a final answer. Generate an updated graph from a different view based on the Debate Graph in JSON format that includes the following: 60 | 1. Objects, attributes, relationships that are more relevant to answering the question. 61 | 2. Delete the irrelevant objects, attributes and relationships. 62 | {hint}{question_}""" 63 | 64 | elif input_format == "KNIM": 65 | input = f"""{neg_base} 66 | For the provided image and its associated question, Please give your solution and ideas to solve this problem, but do not give a final answer. Generate an updated graph from a different view based on the Debate Graph in JSON format that includes the following: 67 | 1. Objects, attributes, relationships that are more relevant to answering the question. 68 | 2. Delete the irrelevant objects, attributes and relationships. 69 | {hint}{question_}""" 70 | 71 | elif input_format == "KDQIM": 72 | input = f"""{aff_base} 73 | Based on the debate Solution of the question, Please give your Better solution and ideas to solve this problem, but do not give a final answer. Generate an updated graph from a different view based on the Debate Graph in JSON format that includes the following: 74 | 1. Objects, attributes, relationships that are more relevant to answering the question. 75 | 2. Delete the irrelevant objects, attributes and relationships. 76 | {hint}{question_}{kg_rat} 77 | """ 78 | 79 | elif input_format == "KNQIM": 80 | input = f"""{neg_base} 81 | Based on the debate Solution of the question, Please give your Better solution and ideas to solve this problem, but do not give a final answer. Generate an updated graph from a different view based on the Debate Graph in JSON format that includes the following: 82 | 1. Objects, attributes, relationships that are more relevant to answering the question. 83 | 2. Delete the irrelevant objects, attributes and relationships. 
84 | {hint}{question_}{kg_emo} 85 | """ 86 | 87 | elif input_format == "KAGM": 88 | input = f"""{hint}{kg_emo}{kg_rat}{question_}{option_}{answer_} 89 | """ 90 | 91 | # Outputs 92 | if output_format == 'A': 93 | output = "Answer:" 94 | 95 | elif output_format == 'G': 96 | output = f"Graph: " 97 | 98 | text = input + output 99 | text = text.replace(" ", " ") 100 | if text.endswith("BECAUSE:"): 101 | text = text.replace("BECAUSE:", "").strip() 102 | return text 103 | -------------------------------------------------------------------------------- /vlmeval/smp/lb.py: -------------------------------------------------------------------------------- 1 | import json 2 | import pandas as pd 3 | from collections import defaultdict 4 | import gradio as gr 5 | import copy as cp 6 | import numpy as np 7 | from .misc import listinstr 8 | 9 | # CONSTANTS-URL 10 | URL = "https://github.com/thecharm/BDoG.git" 11 | VLMEVALKIT_README = 'https://raw.githubusercontent.com/thecharm/BDoG/main/README.md' 12 | # CONSTANTS-CITATION 13 | CITATION_BUTTON_TEXT = r"""@misc{zheng2024pictureworthgraphblueprint, 14 | title={A Picture Is Worth a Graph: Blueprint Debate on Graph for Multimodal Reasoning}, 15 | author={Changmeng Zheng and Dayong Liang and Wengyu Zhang and Xiao-Yong Wei and Tat-Seng Chua and Qing Li}, 16 | year={2024}, 17 | eprint={2403.14972}, 18 | archivePrefix={arXiv}, 19 | primaryClass={cs.AI}, 20 | url={https://arxiv.org/abs/2403.14972}, 21 | }""" 22 | CITATION_BUTTON_LABEL = "Copy the following snippet to cite these results" 23 | # CONSTANTS-TEXT 24 | LEADERBORAD_INTRODUCTION = "" 25 | # CONSTANTS-FIELDS 26 | META_FIELDS = ['Method', 'Parameters (B)', 'Language Model', 'Vision Model', 'OpenSource', 'Verified'] 27 | MAIN_FIELDS = ['MMBench_TEST_EN', 'MMBench_TEST_CN', 'CCBench', 'MME', 'SEEDBench_IMG', 'MMVet', 'MMMU_VAL', 'MathVista', 'HallusionBench', 'LLaVABench'] 28 | MMBENCH_FIELDS = ['MMBench_TEST_EN', 'MMBench_DEV_EN', 'MMBench_TEST_CN', 'MMBench_DEV_CN', 'CCBench'] 29 | MODEL_SIZE = ['<10B', '10B-20B', '20B-40B', '>40B', 'Unknown'] 30 | MODEL_TYPE = ['API', 'OpenSource', 'Proprietary'] 31 | 32 | LEADERBOARD_MD = { 33 | } 34 | 35 | from urllib.request import urlopen 36 | 37 | def load_results(): 38 | data = json.loads(urlopen(URL).read()) 39 | return data 40 | 41 | def nth_large(val, vals): 42 | return sum([1 for v in vals if v > val]) + 1 43 | 44 | def format_timestamp(timestamp): 45 | return timestamp[:2] + '.' + timestamp[2:4] + '.' 
+ timestamp[4:6] + ' ' + timestamp[6:8] + ':' + timestamp[8:10] + ':' + timestamp[10:12] 46 | 47 | def model_size_flag(sz, FIELDS): 48 | if pd.isna(sz) and 'Unknown' in FIELDS: 49 | return True 50 | if pd.isna(sz): 51 | return False 52 | if '<10B' in FIELDS and sz < 10: 53 | return True 54 | if '10B-20B' in FIELDS and sz >= 10 and sz < 20: 55 | return True 56 | if '20B-40B' in FIELDS and sz >= 20 and sz < 40: 57 | return True 58 | if '>40B' in FIELDS and sz >= 40: 59 | return True 60 | return False 61 | 62 | def model_type_flag(line, FIELDS): 63 | if 'OpenSource' in FIELDS and line['OpenSource'] == 'Yes': 64 | return True 65 | if 'API' in FIELDS and line['OpenSource'] == 'No' and line['Verified'] == 'Yes': 66 | return True 67 | if 'Proprietary' in FIELDS and line['OpenSource'] == 'No' and line['Verified'] == 'No': 68 | return True 69 | return False 70 | 71 | def BUILD_L1_DF(results, fields): 72 | res = defaultdict(list) 73 | for i, m in enumerate(results): 74 | item = results[m] 75 | meta = item['META'] 76 | for k in META_FIELDS: 77 | if k == 'Parameters (B)': 78 | param = meta['Parameters'] 79 | res[k].append(float(param.replace('B', '')) if param != '' else None) 80 | elif k == 'Method': 81 | name, url = meta['Method'] 82 | res[k].append(f'{name}') 83 | else: 84 | res[k].append(meta[k]) 85 | scores, ranks = [], [] 86 | for d in fields: 87 | res[d].append(item[d]['Overall']) 88 | if d == 'MME': 89 | scores.append(item[d]['Overall'] / 28) 90 | else: 91 | scores.append(item[d]['Overall']) 92 | ranks.append(nth_large(item[d]['Overall'], [x[d]['Overall'] for x in results.values()])) 93 | res['Avg Score'].append(round(np.mean(scores), 1)) 94 | res['Avg Rank'].append(round(np.mean(ranks), 2)) 95 | 96 | df = pd.DataFrame(res) 97 | df = df.sort_values('Avg Rank') 98 | 99 | check_box = {} 100 | check_box['essential'] = ['Method', 'Parameters (B)', 'Language Model', 'Vision Model'] 101 | check_box['required'] = ['Avg Score', 'Avg Rank'] 102 | check_box['all'] = check_box['required'] + ['OpenSource', 'Verified'] + fields 103 | type_map = defaultdict(lambda: 'number') 104 | type_map['Method'] = 'html' 105 | type_map['Language Model'] = type_map['Vision Model'] = type_map['OpenSource'] = type_map['Verified'] = 'str' 106 | check_box['type_map'] = type_map 107 | return df, check_box 108 | 109 | def BUILD_L2_DF(results, dataset): 110 | res = defaultdict(list) 111 | fields = list(list(results.values())[0][dataset].keys()) 112 | non_overall_fields = [x for x in fields if 'Overall' not in x] 113 | overall_fields = [x for x in fields if 'Overall' in x] 114 | if dataset == 'MME': 115 | non_overall_fields = [x for x in non_overall_fields if not listinstr(['Perception', 'Cognition'], x)] 116 | overall_fields = overall_fields + ['Perception', 'Cognition'] 117 | 118 | for m in results: 119 | item = results[m] 120 | meta = item['META'] 121 | for k in META_FIELDS: 122 | if k == 'Parameters (B)': 123 | param = meta['Parameters'] 124 | res[k].append(float(param.replace('B', '')) if param != '' else None) 125 | elif k == 'Method': 126 | name, url = meta['Method'] 127 | res[k].append(f'{name}') 128 | else: 129 | res[k].append(meta[k]) 130 | fields = [x for x in fields] 131 | 132 | for d in non_overall_fields: 133 | res[d].append(item[dataset][d]) 134 | for d in overall_fields: 135 | res[d].append(item[dataset][d]) 136 | 137 | df = pd.DataFrame(res) 138 | df = df.sort_values('Overall') 139 | df = df.iloc[::-1] 140 | 141 | check_box = {} 142 | check_box['essential'] = ['Method', 'Parameters (B)', 'Language Model', 
'Vision Model'] 143 | check_box['required'] = overall_fields 144 | check_box['all'] = non_overall_fields + overall_fields 145 | type_map = defaultdict(lambda: 'number') 146 | type_map['Method'] = 'html' 147 | type_map['Language Model'] = type_map['Vision Model'] = type_map['OpenSource'] = type_map['Verified'] = 'str' 148 | check_box['type_map'] = type_map 149 | return df, check_box -------------------------------------------------------------------------------- /vlmeval/smp/.ipynb_checkpoints/lb-checkpoint.py: -------------------------------------------------------------------------------- 1 | import json 2 | import pandas as pd 3 | from collections import defaultdict 4 | import gradio as gr 5 | import copy as cp 6 | import numpy as np 7 | from .misc import listinstr 8 | 9 | # CONSTANTS-URL 10 | URL = "https://github.com/thecharm/BDoG.git" 11 | VLMEVALKIT_README = 'https://raw.githubusercontent.com/thecharm/BDoG/main/README.md' 12 | # CONSTANTS-CITATION 13 | CITATION_BUTTON_TEXT = r"""@misc{zheng2024pictureworthgraphblueprint, 14 | title={A Picture Is Worth a Graph: Blueprint Debate on Graph for Multimodal Reasoning}, 15 | author={Changmeng Zheng and Dayong Liang and Wengyu Zhang and Xiao-Yong Wei and Tat-Seng Chua and Qing Li}, 16 | year={2024}, 17 | eprint={2403.14972}, 18 | archivePrefix={arXiv}, 19 | primaryClass={cs.AI}, 20 | url={https://arxiv.org/abs/2403.14972}, 21 | }""" 22 | CITATION_BUTTON_LABEL = "Copy the following snippet to cite these results" 23 | # CONSTANTS-TEXT 24 | LEADERBORAD_INTRODUCTION = "" 25 | # CONSTANTS-FIELDS 26 | META_FIELDS = ['Method', 'Parameters (B)', 'Language Model', 'Vision Model', 'OpenSource', 'Verified'] 27 | MAIN_FIELDS = ['MMBench_TEST_EN', 'MMBench_TEST_CN', 'CCBench', 'MME', 'SEEDBench_IMG', 'MMVet', 'MMMU_VAL', 'MathVista', 'HallusionBench', 'LLaVABench'] 28 | MMBENCH_FIELDS = ['MMBench_TEST_EN', 'MMBench_DEV_EN', 'MMBench_TEST_CN', 'MMBench_DEV_CN', 'CCBench'] 29 | MODEL_SIZE = ['<10B', '10B-20B', '20B-40B', '>40B', 'Unknown'] 30 | MODEL_TYPE = ['API', 'OpenSource', 'Proprietary'] 31 | 32 | LEADERBOARD_MD = { 33 | } 34 | 35 | from urllib.request import urlopen 36 | 37 | def load_results(): 38 | data = json.loads(urlopen(URL).read()) 39 | return data 40 | 41 | def nth_large(val, vals): 42 | return sum([1 for v in vals if v > val]) + 1 43 | 44 | def format_timestamp(timestamp): 45 | return timestamp[:2] + '.' + timestamp[2:4] + '.' 
+ timestamp[4:6] + ' ' + timestamp[6:8] + ':' + timestamp[8:10] + ':' + timestamp[10:12] 46 | 47 | def model_size_flag(sz, FIELDS): 48 | if pd.isna(sz) and 'Unknown' in FIELDS: 49 | return True 50 | if pd.isna(sz): 51 | return False 52 | if '<10B' in FIELDS and sz < 10: 53 | return True 54 | if '10B-20B' in FIELDS and sz >= 10 and sz < 20: 55 | return True 56 | if '20B-40B' in FIELDS and sz >= 20 and sz < 40: 57 | return True 58 | if '>40B' in FIELDS and sz >= 40: 59 | return True 60 | return False 61 | 62 | def model_type_flag(line, FIELDS): 63 | if 'OpenSource' in FIELDS and line['OpenSource'] == 'Yes': 64 | return True 65 | if 'API' in FIELDS and line['OpenSource'] == 'No' and line['Verified'] == 'Yes': 66 | return True 67 | if 'Proprietary' in FIELDS and line['OpenSource'] == 'No' and line['Verified'] == 'No': 68 | return True 69 | return False 70 | 71 | def BUILD_L1_DF(results, fields): 72 | res = defaultdict(list) 73 | for i, m in enumerate(results): 74 | item = results[m] 75 | meta = item['META'] 76 | for k in META_FIELDS: 77 | if k == 'Parameters (B)': 78 | param = meta['Parameters'] 79 | res[k].append(float(param.replace('B', '')) if param != '' else None) 80 | elif k == 'Method': 81 | name, url = meta['Method'] 82 | res[k].append(f'{name}') 83 | else: 84 | res[k].append(meta[k]) 85 | scores, ranks = [], [] 86 | for d in fields: 87 | res[d].append(item[d]['Overall']) 88 | if d == 'MME': 89 | scores.append(item[d]['Overall'] / 28) 90 | else: 91 | scores.append(item[d]['Overall']) 92 | ranks.append(nth_large(item[d]['Overall'], [x[d]['Overall'] for x in results.values()])) 93 | res['Avg Score'].append(round(np.mean(scores), 1)) 94 | res['Avg Rank'].append(round(np.mean(ranks), 2)) 95 | 96 | df = pd.DataFrame(res) 97 | df = df.sort_values('Avg Rank') 98 | 99 | check_box = {} 100 | check_box['essential'] = ['Method', 'Parameters (B)', 'Language Model', 'Vision Model'] 101 | check_box['required'] = ['Avg Score', 'Avg Rank'] 102 | check_box['all'] = check_box['required'] + ['OpenSource', 'Verified'] + fields 103 | type_map = defaultdict(lambda: 'number') 104 | type_map['Method'] = 'html' 105 | type_map['Language Model'] = type_map['Vision Model'] = type_map['OpenSource'] = type_map['Verified'] = 'str' 106 | check_box['type_map'] = type_map 107 | return df, check_box 108 | 109 | def BUILD_L2_DF(results, dataset): 110 | res = defaultdict(list) 111 | fields = list(list(results.values())[0][dataset].keys()) 112 | non_overall_fields = [x for x in fields if 'Overall' not in x] 113 | overall_fields = [x for x in fields if 'Overall' in x] 114 | if dataset == 'MME': 115 | non_overall_fields = [x for x in non_overall_fields if not listinstr(['Perception', 'Cognition'], x)] 116 | overall_fields = overall_fields + ['Perception', 'Cognition'] 117 | 118 | for m in results: 119 | item = results[m] 120 | meta = item['META'] 121 | for k in META_FIELDS: 122 | if k == 'Parameters (B)': 123 | param = meta['Parameters'] 124 | res[k].append(float(param.replace('B', '')) if param != '' else None) 125 | elif k == 'Method': 126 | name, url = meta['Method'] 127 | res[k].append(f'{name}') 128 | else: 129 | res[k].append(meta[k]) 130 | fields = [x for x in fields] 131 | 132 | for d in non_overall_fields: 133 | res[d].append(item[dataset][d]) 134 | for d in overall_fields: 135 | res[d].append(item[dataset][d]) 136 | 137 | df = pd.DataFrame(res) 138 | df = df.sort_values('Overall') 139 | df = df.iloc[::-1] 140 | 141 | check_box = {} 142 | check_box['essential'] = ['Method', 'Parameters (B)', 'Language Model', 
'Vision Model'] 143 | check_box['required'] = overall_fields 144 | check_box['all'] = non_overall_fields + overall_fields 145 | type_map = defaultdict(lambda: 'number') 146 | type_map['Method'] = 'html' 147 | type_map['Language Model'] = type_map['Vision Model'] = type_map['OpenSource'] = type_map['Verified'] = 'str' 148 | check_box['type_map'] = type_map 149 | return df, check_box -------------------------------------------------------------------------------- /vlmeval/inference_multi.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import torch 3 | import torch.distributed as dist 4 | import random 5 | import datetime 6 | from vlmeval.config import supported_VLM 7 | from vlmeval.utils import TSVDataset, track_progress_rich, split_MMMU, Debate_VLM 8 | from vlmeval.smp import * 9 | import logging 10 | import numpy as np 11 | 12 | FAIL_MSG = 'Failed to obtain answer via API.' 13 | 14 | def parse_args(): 15 | parser = argparse.ArgumentParser() 16 | parser.add_argument('--data', type=str, nargs='+', required=True) 17 | parser.add_argument("--model", type=str, nargs='+', required=True) 18 | parser.add_argument("--nproc", type=int, default=4, required=True) 19 | parser.add_argument("--debate", type=int, default=2, required=True) 20 | parser.add_argument("--verbose", action='store_true') 21 | args = parser.parse_args() 22 | return args 23 | 24 | def check_identical_elements_2(lst): 25 | return len(set(lst)) == 1 26 | 27 | def find_duplicates(lst): 28 | return list(set([x for x in lst if lst.count(x) > 1])) 29 | 30 | def infer_data(model_name, dataset_name, out_file, logger, kg_init, stage="base", debate=2, verbose=False, api_nproc=4): 31 | 32 | res = {} 33 | if osp.exists(out_file): 34 | res = load(out_file) 35 | 36 | # Dataset init 37 | rank, world_size = get_rank_and_world_size() 38 | if rank == 0: 39 | dataset = TSVDataset(dataset_name) 40 | if world_size > 1: 41 | dist.barrier() 42 | dataset = TSVDataset(dataset_name) 43 | indices = list(range(rank, len(dataset), world_size)) 44 | lt = len(indices) 45 | data = dataset.data.iloc[indices] 46 | 47 | # If finished, will exit without building the model 48 | all_finished = True 49 | for i in range(lt): 50 | idx = data.iloc[i]['index'] 51 | if idx not in res: 52 | all_finished = False 53 | if all_finished: 54 | return 55 | data = data[~data['index'].isin(res)] 56 | lt = len(data) 57 | 58 | model = supported_VLM[model_name]() if isinstance(model_name, str) else model_name 59 | 60 | for i in tqdm(range(lt)): 61 | idx = data.iloc[i]['index'] 62 | if idx in res: 63 | continue 64 | 65 | if stage != "base_line": 66 | struct = dataset.build_prompt_multi(data.iloc[i]) 67 | else: 68 | struct = dataset.build_prompt(data.iloc[i]) 69 | 70 | response = Debate_VLM(stage, model, struct, dataset_name, debate, kg_init, logger) 71 | torch.cuda.empty_cache() 72 | 73 | if verbose: 74 | print(response, flush=True) 75 | 76 | res[idx] = response 77 | if (i + 1) % 20 == 0: 78 | dump(res, out_file) 79 | 80 | dump(res, out_file) 81 | return model 82 | 83 | def prefetch_acc(result_file): 84 | data = load(result_file) 85 | from vlmeval.evaluate.multiple_choice import build_choices, can_infer 86 | tot = defaultdict(lambda: 0) 87 | match = defaultdict(lambda: 0) 88 | hit = defaultdict(lambda: 0) 89 | lt = len(data) 90 | for i in range(lt): 91 | item = data.iloc[i] 92 | cate = item['category'] 93 | tot['Overall'] += 1 94 | tot[cate] += 1 95 | choices = build_choices(item) 96 | matched = 
can_infer(item['prediction'], choices) 97 | if matched: 98 | match['Overall'] += 1 99 | match[cate] += 1 100 | if matched == item['answer']: 101 | hit['Overall'] += 1 102 | hit[cate] += 1 103 | res = defaultdict(list) 104 | for k in tot.keys(): 105 | res['Category'].append(k) 106 | res['tot'].append(tot[k]) 107 | res['match'].append(match[k]) 108 | res['hit'].append(hit[k]) 109 | res['match_rate'].append(match[k] / tot[k] * 100) 110 | if match[k] == 0: 111 | res['acc'].append(0) 112 | else: 113 | res['acc'].append(hit[k] / match[k] * 100) 114 | res = pd.DataFrame(res) 115 | return res 116 | 117 | def infer_data_job(model, model_name, dataset_name, args, logger, ignore_failed=False): 118 | 119 | result_ = f'results/{model_name}/{dataset_name}/' 120 | result_file = result_ + f'{model_name}_{dataset_name}_{args.stage}_DB{args.debate}.xlsx' 121 | rank, world_size = get_rank_and_world_size() 122 | tmpl = result_ + '{}' + f'{world_size}_{dataset_name}_{args.stage}_DB{args.debate}.pkl' 123 | out_file = tmpl.format(rank) 124 | 125 | if not osp.exists(result_file): 126 | model = infer_data(model, dataset_name=dataset_name, out_file=out_file, logger=logger, kg_init=args.kg_init, stage=args.stage, debate=args.debate, verbose=args.verbose) 127 | if world_size > 1: 128 | dist.barrier() 129 | 130 | if rank == 0: 131 | data_all = {} 132 | for i in range(world_size): 133 | data_all.update(load(tmpl.format(i))) 134 | 135 | data = TSVDataset(dataset_name).data 136 | print(len(data_all)) 137 | print(len(data)) 138 | assert len(data_all) == len(data) 139 | data['prediction'] = [str(data_all[x]) for x in data['index']] 140 | data.pop('image') 141 | 142 | dump(data, result_file) 143 | for i in range(world_size): 144 | os.remove(tmpl.format(i)) 145 | return model 146 | else: 147 | data = load(result_file) 148 | failed_set = [] 149 | data['prediction'] = [str(x) for x in data['prediction']] 150 | for idx, pred in zip(data['index'], data['prediction']): 151 | if FAIL_MSG in str(pred): 152 | failed_set.append(idx) 153 | if len(failed_set) and (not ignore_failed): 154 | print(f'{len(failed_set)} records failed in the original result file {result_file}. 
') 155 | assert rank == 0 and world_size == 1 156 | failed_set = set(failed_set) 157 | answer_map = {x: y for x, y in zip(data['index'], data['prediction'])} 158 | res = infer_data_api(model_name, dataset_name, failed_set, api_nproc=args.api_nproc) 159 | answer_map.update(res) 160 | data['prediction'] = [str(answer_map[x]) for x in data['index']] 161 | dump(data, result_file) 162 | return model_name 163 | 164 | def main(): 165 | logger = get_logger('Inference') 166 | 167 | args = parse_args() 168 | assert len(args.data), "--data should be a list of data files" 169 | 170 | rank, world_size = get_rank_and_world_size() 171 | if world_size > 1: 172 | torch.cuda.set_device(rank) 173 | dist.init_process_group(backend='nccl', timeout=datetime.timedelta(seconds=5400)) 174 | 175 | for _, model_name in enumerate(args.model): 176 | model = None 177 | os.makedirs(model_name, exist_ok=True) 178 | pred_root = model_name 179 | 180 | for i, dataset_name in enumerate(args.data): 181 | 182 | result_file = f'{pred_root}/{dataset_name}/{model_name}_{dataset_name}.xlsx' 183 | if model is None: 184 | model = model_name # which is only a name 185 | model = infer_data_job(model, model_name=model_name, dataset_name=dataset_name, verbose=args.verbose, api_nproc=args.nproc) 186 | 187 | if rank == 0 and listinstr(['MMBench','ScienceQA'], dataset_name): 188 | time.sleep(3) 189 | res = prefetch_acc(result_file) 190 | print(model_name, res) 191 | dump(res, result_file.replace('.xlsx', '_prefetch.xlsx')) 192 | 193 | if __name__ == '__main__': 194 | main() 195 | -------------------------------------------------------------------------------- /vlmeval/.ipynb_checkpoints/inference_multi-checkpoint.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import torch 3 | import torch.distributed as dist 4 | import random 5 | import datetime 6 | from vlmeval.config import supported_VLM 7 | from vlmeval.utils import TSVDataset, track_progress_rich, split_MMMU, Debate_VLM 8 | from vlmeval.smp import * 9 | import logging 10 | import numpy as np 11 | 12 | FAIL_MSG = 'Failed to obtain answer via API.' 
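# Usage sketch (illustrative only; the real launch command lives in scripts/debate.sh,
# and the dataset/model values below are placeholders, not fixed choices):
#
#   torchrun --nproc_per_node=2 run.py \
#       --data ScienceQA_TEST --model <a key of supported_VLM> \
#       --nproc 4 --debate 2 --verbose
#
# Each rank handles every world_size-th row of the TSV, dumps its answers to a
# per-rank .pkl file, and rank 0 merges them into a single .xlsx result file.
# infer_data_job() additionally reads args.stage and args.kg_init, which are
# expected to be supplied on top of the flags registered in parse_args().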
13 | 14 | def parse_args(): 15 | parser = argparse.ArgumentParser() 16 | parser.add_argument('--data', type=str, nargs='+', required=True) 17 | parser.add_argument("--model", type=str, nargs='+', required=True) 18 | parser.add_argument("--nproc", type=int, default=4, required=True) 19 | parser.add_argument("--debate", type=int, default=2, required=True) 20 | parser.add_argument("--verbose", action='store_true') 21 | args = parser.parse_args() 22 | return args 23 | 24 | def check_identical_elements_2(lst): 25 | return len(set(lst)) == 1 26 | 27 | def find_duplicates(lst): 28 | return list(set([x for x in lst if lst.count(x) > 1])) 29 | 30 | def infer_data(model_name, dataset_name, out_file, logger, kg_init, stage="base", debate=2, verbose=False, api_nproc=4): 31 | 32 | res = {} 33 | if osp.exists(out_file): 34 | res = load(out_file) 35 | 36 | # Dataset init 37 | rank, world_size = get_rank_and_world_size() 38 | if rank == 0: 39 | dataset = TSVDataset(dataset_name) 40 | if world_size > 1: 41 | dist.barrier() 42 | dataset = TSVDataset(dataset_name) 43 | indices = list(range(rank, len(dataset), world_size)) 44 | lt = len(indices) 45 | data = dataset.data.iloc[indices] 46 | 47 | # If finished, will exit without building the model 48 | all_finished = True 49 | for i in range(lt): 50 | idx = data.iloc[i]['index'] 51 | if idx not in res: 52 | all_finished = False 53 | if all_finished: 54 | return 55 | data = data[~data['index'].isin(res)] 56 | lt = len(data) 57 | 58 | model = supported_VLM[model_name]() if isinstance(model_name, str) else model_name 59 | 60 | for i in tqdm(range(lt)): 61 | idx = data.iloc[i]['index'] 62 | if idx in res: 63 | continue 64 | 65 | if stage != "base_line": 66 | struct = dataset.build_prompt_multi(data.iloc[i]) 67 | else: 68 | struct = dataset.build_prompt(data.iloc[i]) 69 | 70 | response = Debate_VLM(stage, model, struct, dataset_name, debate, kg_init, logger) 71 | torch.cuda.empty_cache() 72 | 73 | if verbose: 74 | print(response, flush=True) 75 | 76 | res[idx] = response 77 | if (i + 1) % 20 == 0: 78 | dump(res, out_file) 79 | 80 | dump(res, out_file) 81 | return model 82 | 83 | def prefetch_acc(result_file): 84 | data = load(result_file) 85 | from vlmeval.evaluate.multiple_choice import build_choices, can_infer 86 | tot = defaultdict(lambda: 0) 87 | match = defaultdict(lambda: 0) 88 | hit = defaultdict(lambda: 0) 89 | lt = len(data) 90 | for i in range(lt): 91 | item = data.iloc[i] 92 | cate = item['category'] 93 | tot['Overall'] += 1 94 | tot[cate] += 1 95 | choices = build_choices(item) 96 | matched = can_infer(item['prediction'], choices) 97 | if matched: 98 | match['Overall'] += 1 99 | match[cate] += 1 100 | if matched == item['answer']: 101 | hit['Overall'] += 1 102 | hit[cate] += 1 103 | res = defaultdict(list) 104 | for k in tot.keys(): 105 | res['Category'].append(k) 106 | res['tot'].append(tot[k]) 107 | res['match'].append(match[k]) 108 | res['hit'].append(hit[k]) 109 | res['match_rate'].append(match[k] / tot[k] * 100) 110 | if match[k] == 0: 111 | res['acc'].append(0) 112 | else: 113 | res['acc'].append(hit[k] / match[k] * 100) 114 | res = pd.DataFrame(res) 115 | return res 116 | 117 | def infer_data_job(model, model_name, dataset_name, args, logger, ignore_failed=False): 118 | 119 | result_ = f'results/{model_name}/{dataset_name}/' 120 | result_file = result_ + f'{model_name}_{dataset_name}_{args.stage}_DB{args.debate}.xlsx' 121 | rank, world_size = get_rank_and_world_size() 122 | tmpl = result_ + '{}' + 
f'{world_size}_{dataset_name}_{args.stage}_DB{args.debate}.pkl' 123 | out_file = tmpl.format(rank) 124 | 125 | if not osp.exists(result_file): 126 | model = infer_data(model, dataset_name=dataset_name, out_file=out_file, logger=logger, kg_init=args.kg_init, stage=args.stage, debate=args.debate, verbose=args.verbose) 127 | if world_size > 1: 128 | dist.barrier() 129 | 130 | if rank == 0: 131 | data_all = {} 132 | for i in range(world_size): 133 | data_all.update(load(tmpl.format(i))) 134 | 135 | data = TSVDataset(dataset_name).data 136 | print(len(data_all)) 137 | print(len(data)) 138 | assert len(data_all) == len(data) 139 | data['prediction'] = [str(data_all[x]) for x in data['index']] 140 | data.pop('image') 141 | 142 | dump(data, result_file) 143 | for i in range(world_size): 144 | os.remove(tmpl.format(i)) 145 | return model 146 | else: 147 | data = load(result_file) 148 | failed_set = [] 149 | data['prediction'] = [str(x) for x in data['prediction']] 150 | for idx, pred in zip(data['index'], data['prediction']): 151 | if FAIL_MSG in str(pred): 152 | failed_set.append(idx) 153 | if len(failed_set) and (not ignore_failed): 154 | print(f'{len(failed_set)} records failed in the original result file {result_file}. ') 155 | assert rank == 0 and world_size == 1 156 | failed_set = set(failed_set) 157 | answer_map = {x: y for x, y in zip(data['index'], data['prediction'])} 158 | res = infer_data_api(model_name, dataset_name, failed_set, api_nproc=args.api_nproc) 159 | answer_map.update(res) 160 | data['prediction'] = [str(answer_map[x]) for x in data['index']] 161 | dump(data, result_file) 162 | return model_name 163 | 164 | def main(): 165 | logger = get_logger('Inference') 166 | 167 | args = parse_args() 168 | assert len(args.data), "--data should be a list of data files" 169 | 170 | rank, world_size = get_rank_and_world_size() 171 | if world_size > 1: 172 | torch.cuda.set_device(rank) 173 | dist.init_process_group(backend='nccl', timeout=datetime.timedelta(seconds=5400)) 174 | 175 | for _, model_name in enumerate(args.model): 176 | model = None 177 | os.makedirs(model_name, exist_ok=True) 178 | pred_root = model_name 179 | 180 | for i, dataset_name in enumerate(args.data): 181 | 182 | result_file = f'{pred_root}/{dataset_name}/{model_name}_{dataset_name}.xlsx' 183 | if model is None: 184 | model = model_name # which is only a name 185 | model = infer_data_job(model, model_name=model_name, dataset_name=dataset_name, verbose=args.verbose, api_nproc=args.nproc) 186 | 187 | if rank == 0 and listinstr(['MMBench','ScienceQA'], dataset_name): 188 | time.sleep(3) 189 | res = prefetch_acc(result_file) 190 | print(model_name, res) 191 | dump(res, result_file.replace('.xlsx', '_prefetch.xlsx')) 192 | 193 | if __name__ == '__main__': 194 | main() 195 | -------------------------------------------------------------------------------- /vlmeval/utils/mp_util.py: -------------------------------------------------------------------------------- 1 | from multiprocessing import Pool 2 | import os 3 | from typing import Callable, Iterable, Sized 4 | 5 | from rich.progress import (BarColumn, MofNCompleteColumn, Progress, Task, 6 | TaskProgressColumn, TextColumn, TimeRemainingColumn) 7 | from rich.text import Text 8 | import os.path as osp 9 | import portalocker 10 | from ..smp import load, dump 11 | 12 | 13 | class _Worker: 14 | """Function wrapper for ``track_progress_rich``""" 15 | 16 | def __init__(self, func) -> None: 17 | self.func = func 18 | 19 | def __call__(self, inputs): 20 | inputs, idx = inputs 21 | 
if not isinstance(inputs, (tuple, list, dict)): 22 | inputs = (inputs, ) 23 | 24 | if isinstance(inputs, dict): 25 | return self.func(**inputs), idx 26 | else: 27 | return self.func(*inputs), idx 28 | 29 | 30 | class _SkipFirstTimeRemainingColumn(TimeRemainingColumn): 31 | """Skip calculating remaining time for the first few times. 32 | 33 | Args: 34 | skip_times (int): The number of times to skip. Defaults to 0. 35 | """ 36 | 37 | def __init__(self, *args, skip_times=0, **kwargs): 38 | super().__init__(*args, **kwargs) 39 | self.skip_times = skip_times 40 | 41 | def render(self, task: Task) -> Text: 42 | """Show time remaining.""" 43 | if task.completed <= self.skip_times: 44 | return Text('-:--:--', style='progress.remaining') 45 | return super().render(task) 46 | 47 | 48 | def _tasks_with_index(tasks): 49 | """Add index to tasks.""" 50 | for idx, task in enumerate(tasks): 51 | yield task, idx 52 | 53 | def track_progress_rich(func: Callable, 54 | tasks: Iterable = tuple(), 55 | task_num: int = None, 56 | nproc: int = 1, 57 | chunksize: int = 1, 58 | description: str = 'Processing', 59 | save=None, keys=None, 60 | color: str = 'blue') -> list: 61 | """Track the progress of parallel task execution with a progress bar. The 62 | built-in :mod:`multiprocessing` module is used for process pools and tasks 63 | are done with :func:`Pool.map` or :func:`Pool.imap_unordered`. 64 | 65 | Args: 66 | func (callable): The function to be applied to each task. 67 | tasks (Iterable or Sized): A tuple of tasks. There are several cases 68 | for different format tasks: 69 | - When ``func`` accepts no arguments: tasks should be an empty 70 | tuple, and ``task_num`` must be specified. 71 | - When ``func`` accepts only one argument: tasks should be a tuple 72 | containing the argument. 73 | - When ``func`` accepts multiple arguments: tasks should be a 74 | tuple, with each element representing a set of arguments. 75 | If an element is a ``dict``, it will be parsed as a set of 76 | keyword-only arguments. 77 | Defaults to an empty tuple. 78 | task_num (int, optional): If ``tasks`` is an iterator which does not 79 | have length, the number of tasks can be provided by ``task_num``. 80 | Defaults to None. 81 | nproc (int): Process (worker) number, if nuproc is 1, 82 | use single process. Defaults to 1. 83 | chunksize (int): Refer to :class:`multiprocessing.Pool` for details. 84 | Defaults to 1. 85 | description (str): The description of progress bar. 86 | Defaults to "Process". 87 | color (str): The color of progress bar. Defaults to "blue". 88 | 89 | Examples: 90 | >>> import time 91 | 92 | >>> def func(x): 93 | ... time.sleep(1) 94 | ... return x**2 95 | >>> track_progress_rich(func, range(10), nproc=2) 96 | 97 | Returns: 98 | list: The task results. 
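        Note:
            ``save`` and ``keys`` act as an optional on-disk result cache: when
            both are given, each finished task's result is written into ``save``
            under ``keys[idx]`` behind a ``portalocker`` file lock, so partial
            results survive an interrupted run.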
99 | """ 100 | if save is not None: 101 | assert osp.exists(osp.dirname(save)) or osp.dirname(save) == '' 102 | if not osp.exists(save): 103 | dump({}, save) 104 | if keys is not None: 105 | assert len(keys) == len(tasks) 106 | 107 | if not callable(func): 108 | raise TypeError('func must be a callable object') 109 | if not isinstance(tasks, Iterable): 110 | raise TypeError( 111 | f'tasks must be an iterable object, but got {type(tasks)}') 112 | if isinstance(tasks, Sized): 113 | if len(tasks) == 0: 114 | if task_num is None: 115 | raise ValueError('If tasks is an empty iterable, ' 116 | 'task_num must be set') 117 | else: 118 | tasks = tuple(tuple() for _ in range(task_num)) 119 | else: 120 | if task_num is not None and task_num != len(tasks): 121 | raise ValueError('task_num does not match the length of tasks') 122 | task_num = len(tasks) 123 | 124 | if nproc <= 0: 125 | raise ValueError('nproc must be a positive number') 126 | 127 | skip_times = nproc * chunksize if nproc > 1 else 0 128 | prog_bar = Progress( 129 | TextColumn('{task.description}'), 130 | BarColumn(), 131 | _SkipFirstTimeRemainingColumn(skip_times=skip_times), 132 | MofNCompleteColumn(), 133 | TaskProgressColumn(show_speed=True), 134 | ) 135 | 136 | worker = _Worker(func) 137 | task_id = prog_bar.add_task( 138 | total=task_num, color=color, description=description) 139 | tasks = _tasks_with_index(tasks) 140 | 141 | # Use single process when nproc is 1, else use multiprocess. 142 | with prog_bar: 143 | if nproc == 1: 144 | results = [] 145 | for task in tasks: 146 | result, idx = worker(task) 147 | results.append(worker(task)[0]) 148 | if save is not None: 149 | with portalocker.Lock(save, timeout=5) as fh: 150 | ans = load(save) 151 | ans[keys[idx]] = result 152 | 153 | if os.environ.get('VERBOSE', True): 154 | print(keys[idx], result, flush=True) 155 | 156 | dump(ans, save) 157 | fh.flush() 158 | os.fsync(fh.fileno()) 159 | 160 | prog_bar.update(task_id, advance=1, refresh=True) 161 | else: 162 | with Pool(nproc) as pool: 163 | results = [] 164 | unordered_results = [] 165 | gen = pool.imap_unordered(worker, tasks, chunksize) 166 | try: 167 | for result in gen: 168 | result, idx = result 169 | unordered_results.append((result, idx)) 170 | 171 | if save is not None: 172 | with portalocker.Lock(save, timeout=5) as fh: 173 | ans = load(save) 174 | ans[keys[idx]] = result 175 | 176 | if os.environ.get('VERBOSE', False): 177 | print(keys[idx], result, flush=True) 178 | 179 | dump(ans, save) 180 | fh.flush() 181 | os.fsync(fh.fileno()) 182 | 183 | results.append(None) 184 | prog_bar.update(task_id, advance=1, refresh=True) 185 | except Exception as e: 186 | prog_bar.stop() 187 | raise e 188 | for result, idx in unordered_results: 189 | results[idx] = result 190 | return results -------------------------------------------------------------------------------- /vlmeval/utils/.ipynb_checkpoints/mp_util-checkpoint.py: -------------------------------------------------------------------------------- 1 | from multiprocessing import Pool 2 | import os 3 | from typing import Callable, Iterable, Sized 4 | 5 | from rich.progress import (BarColumn, MofNCompleteColumn, Progress, Task, 6 | TaskProgressColumn, TextColumn, TimeRemainingColumn) 7 | from rich.text import Text 8 | import os.path as osp 9 | import portalocker 10 | from ..smp import load, dump 11 | 12 | 13 | class _Worker: 14 | """Function wrapper for ``track_progress_rich``""" 15 | 16 | def __init__(self, func) -> None: 17 | self.func = func 18 | 19 | def __call__(self, inputs): 
20 | inputs, idx = inputs 21 | if not isinstance(inputs, (tuple, list, dict)): 22 | inputs = (inputs, ) 23 | 24 | if isinstance(inputs, dict): 25 | return self.func(**inputs), idx 26 | else: 27 | return self.func(*inputs), idx 28 | 29 | 30 | class _SkipFirstTimeRemainingColumn(TimeRemainingColumn): 31 | """Skip calculating remaining time for the first few times. 32 | 33 | Args: 34 | skip_times (int): The number of times to skip. Defaults to 0. 35 | """ 36 | 37 | def __init__(self, *args, skip_times=0, **kwargs): 38 | super().__init__(*args, **kwargs) 39 | self.skip_times = skip_times 40 | 41 | def render(self, task: Task) -> Text: 42 | """Show time remaining.""" 43 | if task.completed <= self.skip_times: 44 | return Text('-:--:--', style='progress.remaining') 45 | return super().render(task) 46 | 47 | 48 | def _tasks_with_index(tasks): 49 | """Add index to tasks.""" 50 | for idx, task in enumerate(tasks): 51 | yield task, idx 52 | 53 | def track_progress_rich(func: Callable, 54 | tasks: Iterable = tuple(), 55 | task_num: int = None, 56 | nproc: int = 1, 57 | chunksize: int = 1, 58 | description: str = 'Processing', 59 | save=None, keys=None, 60 | color: str = 'blue') -> list: 61 | """Track the progress of parallel task execution with a progress bar. The 62 | built-in :mod:`multiprocessing` module is used for process pools and tasks 63 | are done with :func:`Pool.map` or :func:`Pool.imap_unordered`. 64 | 65 | Args: 66 | func (callable): The function to be applied to each task. 67 | tasks (Iterable or Sized): A tuple of tasks. There are several cases 68 | for different format tasks: 69 | - When ``func`` accepts no arguments: tasks should be an empty 70 | tuple, and ``task_num`` must be specified. 71 | - When ``func`` accepts only one argument: tasks should be a tuple 72 | containing the argument. 73 | - When ``func`` accepts multiple arguments: tasks should be a 74 | tuple, with each element representing a set of arguments. 75 | If an element is a ``dict``, it will be parsed as a set of 76 | keyword-only arguments. 77 | Defaults to an empty tuple. 78 | task_num (int, optional): If ``tasks`` is an iterator which does not 79 | have length, the number of tasks can be provided by ``task_num``. 80 | Defaults to None. 81 | nproc (int): Process (worker) number, if nuproc is 1, 82 | use single process. Defaults to 1. 83 | chunksize (int): Refer to :class:`multiprocessing.Pool` for details. 84 | Defaults to 1. 85 | description (str): The description of progress bar. 86 | Defaults to "Process". 87 | color (str): The color of progress bar. Defaults to "blue". 88 | 89 | Examples: 90 | >>> import time 91 | 92 | >>> def func(x): 93 | ... time.sleep(1) 94 | ... return x**2 95 | >>> track_progress_rich(func, range(10), nproc=2) 96 | 97 | Returns: 98 | list: The task results. 
99 | """ 100 | if save is not None: 101 | assert osp.exists(osp.dirname(save)) or osp.dirname(save) == '' 102 | if not osp.exists(save): 103 | dump({}, save) 104 | if keys is not None: 105 | assert len(keys) == len(tasks) 106 | 107 | if not callable(func): 108 | raise TypeError('func must be a callable object') 109 | if not isinstance(tasks, Iterable): 110 | raise TypeError( 111 | f'tasks must be an iterable object, but got {type(tasks)}') 112 | if isinstance(tasks, Sized): 113 | if len(tasks) == 0: 114 | if task_num is None: 115 | raise ValueError('If tasks is an empty iterable, ' 116 | 'task_num must be set') 117 | else: 118 | tasks = tuple(tuple() for _ in range(task_num)) 119 | else: 120 | if task_num is not None and task_num != len(tasks): 121 | raise ValueError('task_num does not match the length of tasks') 122 | task_num = len(tasks) 123 | 124 | if nproc <= 0: 125 | raise ValueError('nproc must be a positive number') 126 | 127 | skip_times = nproc * chunksize if nproc > 1 else 0 128 | prog_bar = Progress( 129 | TextColumn('{task.description}'), 130 | BarColumn(), 131 | _SkipFirstTimeRemainingColumn(skip_times=skip_times), 132 | MofNCompleteColumn(), 133 | TaskProgressColumn(show_speed=True), 134 | ) 135 | 136 | worker = _Worker(func) 137 | task_id = prog_bar.add_task( 138 | total=task_num, color=color, description=description) 139 | tasks = _tasks_with_index(tasks) 140 | 141 | # Use single process when nproc is 1, else use multiprocess. 142 | with prog_bar: 143 | if nproc == 1: 144 | results = [] 145 | for task in tasks: 146 | result, idx = worker(task) 147 | results.append(worker(task)[0]) 148 | if save is not None: 149 | with portalocker.Lock(save, timeout=5) as fh: 150 | ans = load(save) 151 | ans[keys[idx]] = result 152 | 153 | if os.environ.get('VERBOSE', True): 154 | print(keys[idx], result, flush=True) 155 | 156 | dump(ans, save) 157 | fh.flush() 158 | os.fsync(fh.fileno()) 159 | 160 | prog_bar.update(task_id, advance=1, refresh=True) 161 | else: 162 | with Pool(nproc) as pool: 163 | results = [] 164 | unordered_results = [] 165 | gen = pool.imap_unordered(worker, tasks, chunksize) 166 | try: 167 | for result in gen: 168 | result, idx = result 169 | unordered_results.append((result, idx)) 170 | 171 | if save is not None: 172 | with portalocker.Lock(save, timeout=5) as fh: 173 | ans = load(save) 174 | ans[keys[idx]] = result 175 | 176 | if os.environ.get('VERBOSE', False): 177 | print(keys[idx], result, flush=True) 178 | 179 | dump(ans, save) 180 | fh.flush() 181 | os.fsync(fh.fileno()) 182 | 183 | results.append(None) 184 | prog_bar.update(task_id, advance=1, refresh=True) 185 | except Exception as e: 186 | prog_bar.stop() 187 | raise e 188 | for result, idx in unordered_results: 189 | results[idx] = result 190 | return results -------------------------------------------------------------------------------- /vlmeval/api/gpt.py: -------------------------------------------------------------------------------- 1 | from ..smp import * 2 | import os, sys 3 | from .base import BaseAPI 4 | 5 | APIBASES = { 6 | 'OFFICIAL': "https://api.openai.com/v1/chat/completions", 7 | } 8 | 9 | 10 | def GPT_context_window(model): 11 | length_map = { 12 | 'gpt-4-1106-preview': 128000, 13 | 'gpt-4-vision-preview': 128000, 14 | 'gpt-4': 8192, 15 | 'gpt-4-32k': 32768, 16 | 'gpt-4-0613': 8192, 17 | 'gpt-4-32k-0613': 32768, 18 | 'gpt-3.5-turbo-1106': 16385, 19 | 'gpt-3.5-turbo': 4096, 20 | 'gpt-3.5-turbo-16k': 16385, 21 | 'gpt-3.5-turbo-instruct': 4096, 22 | 'gpt-3.5-turbo-0613': 4096, 23 | 
'gpt-3.5-turbo-16k-0613': 16385, 24 | } 25 | if model in length_map: 26 | return length_map[model] 27 | else: 28 | return 4096 29 | 30 | class OpenAIWrapper(BaseAPI): 31 | 32 | is_api: bool = True 33 | 34 | def __init__(self, 35 | model: str = 'gpt-3.5-turbo-0613', 36 | retry: int = 5, 37 | wait: int = 5, 38 | key: str = None, 39 | verbose: bool = True, 40 | system_prompt: str = None, 41 | temperature: float = 0, 42 | timeout: int = 60, 43 | api_base: str = 'OFFICIAL', 44 | max_tokens: int = 1024, 45 | img_size: int = 512, 46 | img_detail: str = 'low', 47 | **kwargs): 48 | 49 | self.model = model 50 | self.cur_idx = 0 51 | self.fail_msg = 'Failed to obtain answer via API. ' 52 | self.max_tokens = max_tokens 53 | self.temperature = temperature 54 | 55 | openai_key = os.environ.get('OPENAI_API_KEY', None) if key is None else key 56 | self.openai_key = openai_key 57 | assert img_size > 0 or img_size == -1 58 | self.img_size = img_size 59 | assert img_detail in ['high', 'low'] 60 | self.img_detail = img_detail 61 | 62 | self.vision = False 63 | if model == 'gpt-4-vision-preview': 64 | self.vision = True 65 | self.timeout = timeout 66 | 67 | assert isinstance(openai_key, str) and openai_key.startswith('sk-'), f'Illegal openai_key {openai_key}. Please set the environment variable OPENAI_API_KEY to your openai key. ' 68 | super().__init__(wait=wait, retry=retry, system_prompt=system_prompt, verbose=verbose, **kwargs) 69 | 70 | if api_base in APIBASES: 71 | self.api_base = APIBASES[api_base] 72 | elif api_base.startswith('http'): 73 | self.api_base = api_base 74 | else: 75 | self.logger.error("Unknown API Base. ") 76 | sys.exit(-1) 77 | 78 | if 'OPENAI_API_BASE' in os.environ: 79 | self.logger.error("Environment variable OPENAI_API_BASE is set. Will override the api_base arg. ") 80 | self.api_base = os.environ['OPENAI_API_BASE'] 81 | 82 | # inputs can be a lvl-2 nested list: [content1, content2, content3, ...] 
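    # Illustrative input shapes accepted by prepare_inputs() below (values are made up):
    #   "describe the image"                        -> a single user text turn
    #   ["img.jpg", "describe the image"]           -> one user turn with an image_url part
    #                                                  (base64-encoded, detail=self.img_detail)
    #                                                  followed by a text part
    #   ["q1", "a1", "q2"]                          -> alternating user/assistant turns
    #                                                  (starting with 'user' for odd lengths)
    #   [{'role': 'user', 'content': ...}, ...]     -> forwarded unchanged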
83 | # content can be a string or a list of image & text 84 | def prepare_inputs(self, inputs): 85 | input_msgs = [] 86 | if self.system_prompt is not None: 87 | input_msgs.append(dict(role='system', content=self.system_prompt)) 88 | if isinstance(inputs, str): 89 | input_msgs.append(dict(role='user', content=inputs)) 90 | return input_msgs 91 | assert isinstance(inputs, list) 92 | dict_flag = [isinstance(x, dict) for x in inputs] 93 | if np.all(dict_flag): 94 | input_msgs.extend(inputs) 95 | return input_msgs 96 | str_flag = [isinstance(x, str) for x in inputs] 97 | if np.all(str_flag): 98 | img_flag = [x.startswith('http') or osp.exists(x) for x in inputs] 99 | if np.any(img_flag): 100 | content_list = [] 101 | for fl, msg in zip(img_flag, inputs): 102 | if not fl: 103 | content_list.append(dict(type='text', text=msg)) 104 | elif msg.startswith('http'): 105 | content_list.append(dict(type='image_url', image_url={'url': msg, 'detail': self.img_detail})) 106 | elif osp.exists(msg): 107 | from PIL import Image 108 | img = Image.open(msg) 109 | b64 = encode_image_to_base64(img, target_size=self.img_size) 110 | img_struct = dict(url=f'data:image/jpeg;base64,{b64}', detail=self.img_detail) 111 | content_list.append(dict(type='image_url', image_url=img_struct)) 112 | input_msgs.append(dict(role='user', content=content_list)) 113 | return input_msgs 114 | else: 115 | roles = ['user', 'assistant'] if len(inputs) % 2 == 1 else ['assistant', 'user'] 116 | roles = roles * len(inputs) 117 | for role, msg in zip(roles, inputs): 118 | input_msgs.append(dict(role=role, content=msg)) 119 | return input_msgs 120 | raise NotImplementedError("list of list prompt is not implemented yet. ") 121 | 122 | def generate_inner(self, inputs, **kwargs) -> str: 123 | input_msgs = self.prepare_inputs(inputs) 124 | temperature = kwargs.pop('temperature', self.temperature) 125 | max_tokens = kwargs.pop('max_tokens', self.max_tokens) 126 | 127 | context_window = GPT_context_window(self.model) 128 | max_tokens = min(max_tokens, context_window - self.get_token_len(inputs)) 129 | if 0 < max_tokens <= 100: 130 | self.logger.warning('Less than 100 tokens left, may exceed the context window with some additional meta symbols. ') 131 | if max_tokens <= 0: 132 | return 0, self.fail_msg + 'Input string longer than context window. ', 'Length Exceeded. 
' 133 | 134 | headers = {'Content-Type': 'application/json', 'Authorization': f'Bearer {self.openai_key}'} 135 | payload = dict( 136 | model=self.model, 137 | messages=input_msgs, 138 | max_tokens=max_tokens, 139 | n=1, 140 | temperature=temperature, 141 | **kwargs) 142 | response = requests.post(self.api_base, headers=headers, data=json.dumps(payload), timeout=self.timeout * 1.1) 143 | ret_code = response.status_code 144 | ret_code = 0 if (200 <= int(ret_code) < 300) else ret_code 145 | answer = self.fail_msg 146 | try: 147 | resp_struct = json.loads(response.text) 148 | answer = resp_struct['choices'][0]['message']['content'].strip() 149 | except: 150 | pass 151 | return ret_code, answer, response 152 | 153 | def get_token_len(self, inputs) -> int: 154 | import tiktoken 155 | enc = tiktoken.encoding_for_model(self.model) 156 | if isinstance(inputs, str): 157 | if inputs.startswith('http') or osp.exists(inputs): 158 | return 65 if self.img_detail == 'low' else 130 159 | else: 160 | return len(enc.encode(inputs)) 161 | elif isinstance(inputs, dict): 162 | assert 'content' in inputs 163 | return self.get_token_len(inputs['content']) 164 | assert isinstance(inputs, list) 165 | res = 0 166 | for item in inputs: 167 | res += self.get_token_len(item) 168 | return res 169 | 170 | class GPT4V(OpenAIWrapper): 171 | 172 | def generate(self, image_path, prompt, dataset=None): 173 | assert self.model == 'gpt-4-vision-preview' 174 | return super(GPT4V, self).generate([image_path, prompt]) 175 | 176 | def interleave_generate(self, ti_list, dataset=None): 177 | assert self.model == 'gpt-4-vision-preview' 178 | return super(GPT4V, self).generate(ti_list) 179 | -------------------------------------------------------------------------------- /vlmeval/utils/dataset.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | import hashlib 3 | from ..smp import * 4 | from .dataset_config import dataset_URLs, dataset_md5_dict, img_root_map, DATASET_TYPE 5 | from .custom_prompt import CustomPrompt 6 | from .base_prompt import create_one_example 7 | import csv 8 | 9 | def isliststr(s): 10 | return (s[0] == '[') and (s[-1] == ']') 11 | 12 | def check_md5(data_path, dataset): 13 | try: 14 | with open(data_path, 'rb') as f: 15 | hash = hashlib.new('md5') 16 | for chunk in iter(lambda: f.read(2**20), b''): 17 | hash.update(chunk) 18 | if str(hash.hexdigest()) == dataset_md5_dict[dataset]: 19 | return True 20 | else: 21 | warnings.warn('this data file is incomplete, so it needs to be downloaded again.') 22 | return False 23 | except: 24 | return False 25 | 26 | 27 | def split_MMMU(struct): 28 | assert 'image' in struct and 'text' in struct 29 | text, images = struct['text'], struct['image'] 30 | text_segs = text.split('' 36 | image_idx = int(seg[0]) - 1 37 | segs.append(images[image_idx]) 38 | segs.append(seg[2:]) 39 | return segs 40 | 41 | def init_prompt_multi(struct, format_): 42 | question = struct['text']['question'] if 'question' in struct['text'] else 'none' 43 | context = struct['text']['hint'] if 'hint' in struct['text'] else 'none' 44 | options = struct['text']['options'] if 'options' in struct['text'] else 'none' 45 | answer = struct['debate_ans'] if 'debate_ans' in struct else 'none' 46 | knowledge = struct['kg'] if 'kg' in struct else 'none' 47 | image_path = struct['image'] if 'image' in struct else 'none' 48 | 49 | prompt = create_one_example(format_, question, context.replace("\n"," "), options, answer, knowledge, image_path) 50 | return prompt 51 
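# Illustrative sketch (the stage codes and field values here are assumptions rather than
# an exact trace of the debate driver): the `format_` argument of create_one_example()
# is an "<input_format>-<output_format>" pair, e.g. "GKG-G" for the initial scene graph,
# "KDIM-G"/"KNIM-G" for the affirmative/negative debaters, and "KAGM-A" for the final
# summarizing answer. A single call would look roughly like:
#
#   prompt = init_prompt_multi(
#       {'text': {'question': q, 'hint': hint, 'options': opts},
#        'kg': [emo_graph, rat_graph], 'image': img_path},
#       'KAGM-A')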
| 52 | class TSVDataset(CustomPrompt): 53 | 54 | def __init__(self, dataset='MMBench_DEV_EN', img_root=None, skip_noimg=True): 55 | 56 | self.data_root = LMUDataRoot() 57 | assert osp.exists(self.data_root) 58 | 59 | self.dataset = dataset 60 | self.dataset_type = DATASET_TYPE(dataset) 61 | 62 | url = dataset_URLs[dataset] 63 | file_name = url.split('/')[-1] 64 | data_path = osp.join(self.data_root, file_name) 65 | print(data_path) 66 | 67 | if osp.exists(data_path) and md5(data_path) == dataset_md5_dict[dataset]: 68 | print("Dataset is already downloaded: ", data_path) 69 | pass 70 | else: 71 | warnings.warn("The dataset tsv is not downloaded") 72 | download_file(url, data_path) 73 | 74 | data = load(data_path) 75 | if dataset=="ScienceQA_TEST": 76 | kg_file = "/code/BDoG/data/kg_init/scienceqa_test_kg_gpt4.json" 77 | kg_base = json.load(open(kg_file)) 78 | kg_base = list(kg_base.values()) 79 | elif dataset=="MMBench_DEV_EN": 80 | kg_file = "/code/BDoG/data/kg_init/MMBench_DEV_EN_s.json" 81 | kg_base = json.load(open(kg_file)) 82 | kg_base = list(kg_base.values()) 83 | else: 84 | kg_base = ['none' for i in range(len(data))] 85 | 86 | self.skip_noimg = skip_noimg 87 | if skip_noimg: 88 | data = data[~pd.isna(data['image'])] 89 | 90 | # Prompt for Captioning 91 | if listinstr(['COCO'], dataset): 92 | data['question'] = ['Please describe this image in general. Directly provide the description, do not include prefix like "This image depicts". '] * len(data) 93 | 94 | data['index'] = [str(x) for x in data['index']] 95 | data['image'] = [str(x) for x in data['image']] 96 | ## Add kg_init 97 | data['kg'] = kg_base 98 | 99 | image_map = {x: y for x, y in zip(data['index'], data['image'])} 100 | for k in image_map: 101 | if len(image_map[k]) <= 64: 102 | idx = image_map[k] 103 | assert idx in image_map and len(image_map[idx]) > 64 104 | image_map[k] = image_map[idx] 105 | 106 | data['image'] = [ 107 | eval(image_map[k]) if isliststr(image_map[k]) else image_map[k] 108 | for k in data['index'] 109 | ] 110 | if 'image_path' in data: 111 | data['image_path'] = [ 112 | eval(pths) if isliststr(pths) else pths for pths in data['image_path'] 113 | ] 114 | if np.all([istype(x, int) for x in data['index']]): 115 | data['index'] = [int(x) for x in data['index']] 116 | 117 | self.data = data 118 | 119 | img_root = img_root if img_root is not None else osp.join('images', img_root_map[dataset]) 120 | os.makedirs(img_root, exist_ok=True) 121 | self.img_root = img_root 122 | 123 | # self.set_file() 124 | # print("#### Save: /code/VLMEvalKit-main/data/CCBench_Fin.tsv") 125 | 126 | def __len__(self): 127 | return len(self.data) 128 | 129 | def set_file(self,data_in,data_out): 130 | dataset = self.dataset 131 | print("##START Processing: ", dataset) 132 | with open(data_in, 'r') as in_file: 133 | with open(data_out, 'w', newline='') as out_file: 134 | reader = csv.DictReader(in_file, delimiter='\t') 135 | fieldnames = reader.fieldnames 136 | writer = csv.DictWriter(out_file, fieldnames=fieldnames, delimiter='\t') 137 | 138 | writer.writeheader() 139 | for i, row in enumerate(tqdm(reader)): 140 | # Add a key-value pair (the dumped image path) to each row 141 | line = self.data.iloc[i] 142 | tgt_path = self.dump_image(line, dataset) 143 | row['image'] = tgt_path 144 | writer.writerow(row) 145 | print(writer.fieldnames) 146 | 147 | def build_prompt(self, line, dataset=None): 148 | if dataset is None: 149 | dataset = self.dataset 150 | 151 | if isinstance(line, int): 152 | line = self.data.iloc[line] 153 | 154 | tgt_path = self.dump_image(line, dataset) 155 | 156 | prompt = 
line['question'] 157 | if DATASET_TYPE(dataset) == 'multi-choice': 158 | question = line['question'] 159 | options = { 160 | cand: line[cand] 161 | for cand in string.ascii_uppercase 162 | if cand in line and not pd.isna(line[cand]) 163 | } 164 | options_prompt = 'Options:\n' 165 | for key, item in options.items(): 166 | options_prompt += f'{key}. {item}\n' 167 | hint = line['hint'] if ('hint' in line and not pd.isna(line['hint'])) else None 168 | prompt = '' 169 | if hint is not None: 170 | prompt += f'Hint: {hint}\n' 171 | prompt += f'Question: {question}\n' 172 | if len(options): 173 | prompt += options_prompt 174 | prompt += "Answer with the option's letter from the given choices directly.\n" 175 | 176 | return dict(image=tgt_path, text=prompt) 177 | 178 | 179 | def build_prompt_multi(self, line, dataset=None): 180 | if dataset is None: 181 | dataset = self.dataset 182 | 183 | if isinstance(line, int): 184 | line = self.data.iloc[line] 185 | 186 | tgt_path = self.dump_image(line, dataset) 187 | prompt = {} 188 | prompt['index'] = line['index'] 189 | prompt['answer'] = line['gpt4_ans'] if dataset == "LLaVABench" else line['answer'] 190 | prompt['question'] = line['question'] 191 | if 'kg' in line: 192 | prompt['kg'] = str(line['kg'])[:350] 193 | else: 194 | prompt['kg'] = 'none' 195 | if DATASET_TYPE(dataset) == 'multi-choice': 196 | options = { 197 | cand: line[cand] 198 | for cand in string.ascii_uppercase 199 | if cand in line and not pd.isna(line[cand]) 200 | } 201 | options_prompt = '' 202 | choise = [] 203 | for key, item in options.items(): 204 | options_prompt += f'{key}. {item}\n' 205 | choise.append(item) 206 | 207 | hint = line['hint'] if ('hint' in line and not pd.isna(line['hint'])) else None 208 | if dataset == "LLaVABench": 209 | hint = line['caption'] if ('caption' in line and not pd.isna(line['caption'])) else None 210 | 211 | if hint is not None: 212 | prompt['hint'] = f'{hint}' 213 | if len(options): 214 | prompt['options'] = options_prompt 215 | # prompt['options'] += 'Please select the correct answer from the options above.' 
216 | # prompt['options'] += "Answer with the option's letter from the given choices directly" 217 | prompt['choise'] = choise 218 | # print(tgt_path) 219 | return dict(image=tgt_path, text=prompt) 220 | 221 | def display(self, line): 222 | if isinstance(line, int): 223 | line = self.data.iloc[line] 224 | mmqa_display(line) 225 | 226 | -------------------------------------------------------------------------------- /vlmeval/api/hf_chat_model.py: -------------------------------------------------------------------------------- 1 | import os, sys 2 | import os.path as osp 3 | import torch 4 | from ..smp import * 5 | 6 | def get_gpu_num(model_name): 7 | model_name = model_name.lower() 8 | kws = { 9 | 8: ['65b', '70b'], 10 | 4: ['30b', '33b', '35b', '40b'], 11 | 2: ['13b', '14b', '20b'], 12 | 1: ['6b', '7b', 'moss'], 13 | } 14 | for k in [8, 4, 2, 1]: 15 | for keyword in kws[k]: 16 | if keyword in model_name: 17 | return k 18 | return 8 19 | 20 | validated_llms = [ 21 | 'internlm/internlm-chat-7b', 'internlm/internlm-chat-7b-8k', 'internlm/internlm-chat-20b', 22 | 'Qwen/Qwen-7B-Chat', 'Qwen/Qwen-14B-Chat', 23 | 'THUDM/chatglm2-6b', 'THUDM/chatglm2-6b-32k', 'THUDM/chatglm3-6b', 'THUDM/chatglm3-6b-32k', 24 | 'baichuan-inc/Baichuan2-7B-Chat', 'baichuan-inc/Baichuan2-13B-Chat', 25 | 'lmsys/vicuna-7b-v1.5', 'lmsys/vicuna-13b-v1.5', 26 | 'meta-llama/Llama-2-7b-chat-hf' 27 | ] 28 | Auto_model = ['chatglm'] 29 | 30 | class HFChatModel: 31 | 32 | def _get_context_length(self, model, model_path): 33 | # By default, we use model.config.seq_length 34 | model_path = model_path.lower() 35 | if 'baichuan' in model_path: 36 | context_window = model.config.model_max_length 37 | elif 'internlm' in model_path or 'llama' in model_path: 38 | context_window = model.config.max_position_embeddings 39 | elif 'vicuna' 
in model_path: 40 | context_window = model.generation_config.max_length 41 | else: 42 | # chatglm & qwen 43 | context_window = model.config.seq_length 44 | return context_window 45 | 46 | def _get_context_length_robust(self, model, model_path): 47 | try: 48 | context_window = self._get_context_length(model, model_path) 49 | return context_window 50 | except: 51 | self.logger.critical( 52 | "Failed to extract context_window information from config / generation_config. " 53 | "Please read the above code and check if the logic works for you model path" 54 | ) 55 | raise NotImplementedError 56 | 57 | def __init__(self, 58 | model_path, 59 | system_prompt: str=None, 60 | **kwargs): 61 | 62 | self.logger = get_logger('HFChatModel') 63 | if 'vicuna' in model_path.lower(): 64 | try: 65 | from fastchat.model import get_conversation_template 66 | except: 67 | self.logger.critical("Please install fastchat first to use vicuna. ") 68 | sys.exit(-1) 69 | 70 | self.explicit_device = kwargs.pop('device', None) 71 | 72 | if self.explicit_device is None: 73 | # If CUDA_VISIBLE_DEVICES is not properly set 74 | if 'CUDA_VISIBLE_DEVICES' not in os.environ or os.environ['CUDA_VISIBLE_DEVICES'] in ['', '0,1,2,3,4,5,6,7']: 75 | num_gpu = get_gpu_num(model_path) 76 | gpu_offset = kwargs.pop('gpu_offset', 0) 77 | cuda_visible_devices = ','.join([str(i) for i in range(gpu_offset, gpu_offset+num_gpu)]) 78 | os.environ['CUDA_VISIBLE_DEVICES'] = cuda_visible_devices 79 | 80 | from transformers import AutoTokenizer, AutoModelForCausalLM, AutoModel 81 | from transformers.generation import GenerationConfig 82 | 83 | if model_path not in validated_llms: 84 | self.logger.warning(f"{model_path} not in validated LLMs, may have inference troubles. ") 85 | 86 | self.model_path = model_path 87 | if listinstr(Auto_model, model_path): 88 | LoadModel = AutoModel 89 | else: 90 | LoadModel = AutoModelForCausalLM 91 | 92 | assert osp.exists(model_path) or len(model_path.split('/')) == 2 93 | 94 | device = self.explicit_device if self.explicit_device else "auto" 95 | 96 | precision = {} 97 | if 'internlm-chat-7b' in model_path: 98 | precision = {'torch_dtype': torch.float16} 99 | elif 'internlm-chat-20b' in model_path: 100 | precision = {'torch_dtype': torch.bfloat16} 101 | 102 | self.tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True) 103 | model = LoadModel.from_pretrained(model_path, trust_remote_code=True, device_map='cpu', **precision) 104 | model = model.eval() 105 | 106 | if device != 'cpu': 107 | model = model.to(f'cuda:{device}' if isinstance(device, int) else 'cuda') 108 | try: 109 | model.generation_config = GenerationConfig.from_pretrained(model_path, trust_remote_code=True, device_map=device) 110 | except: 111 | pass 112 | 113 | torch.cuda.empty_cache() 114 | self.model = model 115 | self.context_length = self._get_context_length_robust(model=model, model_path=model_path) 116 | self.answer_buffer = 192 117 | self.system_prompt = system_prompt 118 | for k, v in kwargs.items(): 119 | self.logger.info(f'Following args are passed and will be used as generation hyper-paras (If not set specifically), {k}: {v}. 
') 120 | self.kwargs = kwargs 121 | 122 | def generate_str(self, input, **kwargs): 123 | if 'baichuan' in self.model_path.lower(): 124 | messages=[] 125 | messages.append({"role": "user", "content": input}) 126 | resp= self.model.chat(self.tokenizer, messages, **kwargs) 127 | elif 'vicuna' in self.model_path.lower(): 128 | from fastchat.model import get_conversation_template 129 | conv = get_conversation_template('vicuna') 130 | conv.append_message(conv.roles[0], input) 131 | conv.append_message(conv.roles[1], None) 132 | prompt = conv.get_prompt() 133 | inputs = self.tokenizer([prompt], return_tensors="pt") 134 | if torch.cuda.is_available(): 135 | for k in inputs: 136 | inputs[k] = inputs[k].cuda() 137 | 138 | params = dict(do_sample=True, temperature=0.7, repetition_penalty=1.0, max_new_tokens=512) 139 | params.update(self.kwargs) 140 | params.update(kwargs) 141 | outputs = self.model.generate(**inputs, **params) 142 | resp = self.tokenizer.decode(outputs[0][len(inputs["input_ids"][0]) :], skip_special_tokens=True, spaces_between_special_tokens=False) 143 | 144 | else: 145 | params = self.kwargs 146 | params.update(kwargs) 147 | resp, _ = self.model.chat(self.tokenizer, input, history=[], **params) 148 | 149 | return resp 150 | 151 | def length_ok(self, inputs): 152 | tot = len(self.tokenizer.encode(self.system_prompt)) if self.system_prompt is not None else 0 153 | for s in inputs: 154 | tot += len(self.tokenizer.encode(s)) 155 | return tot + self.answer_buffer < self.context_length 156 | 157 | def generate_list(self, full_inputs, offset=0, **kwargs): 158 | assert isinstance(full_inputs, list) 159 | 160 | inputs = full_inputs[offset:] 161 | if not self.length_ok(inputs): 162 | return self.chat(full_inputs, offset + 1) 163 | 164 | model_path = self.model_path.lower() 165 | 166 | if sum([x in model_path for x in ['baichuan']]): 167 | input_msgs = [] 168 | if self.system_prompt is not None: 169 | input_msgs.append(dict(role='user', content=self.system_prompt)) 170 | if len(inputs): 171 | assert isinstance(inputs, list) and isinstance(inputs[0], str) 172 | roles = ['user', 'assistant'] if len(inputs) % 2 == 1 else ['assistant', 'user'] 173 | roles = roles * len(inputs) 174 | for role, msg in zip(roles, inputs): 175 | input_msgs.append(dict(role=role, content=msg)) 176 | response = self.model.chat(self.tokenizer, input_msgs) 177 | elif sum([x in model_path for x in ['vicuna']]): 178 | from fastchat.model import get_conversation_template 179 | conv = get_conversation_template('vicuna') 180 | assert isinstance(inputs, list) and isinstance(inputs[0], str) 181 | if len(inputs) % 2 == 1: 182 | if self.system_prompt is not None: 183 | conv.append_message(conv.roles[0], self.system_prompt) 184 | for i in range(len(inputs)//2): 185 | conv.append_message(conv.roles[0], inputs[2 * i]) 186 | conv.append_message(conv.roles[1], inputs[2 * i + 1]) 187 | else: 188 | assert self.system_prompt is not None 189 | conv.append_message(conv.roles[0], self.system_prompt) 190 | conv.append_message(conv.roles[1], inputs[0]) 191 | for i in range(len(inputs) // 2 - 1): 192 | conv.append_message(conv.roles[0], inputs[2 * i + 1]) 193 | conv.append_message(conv.roles[1], inputs[2 * i + 2]) 194 | conv.append_message(conv.roles[0], inputs[-1]) 195 | conv.append_message(conv.roles[1], None) 196 | prompt = conv.get_prompt() 197 | inputs = self.tokenizer([prompt], return_tensors="pt") 198 | if torch.cuda.is_available(): 199 | for k in inputs: 200 | inputs[k] = inputs[k].cuda() 201 | 202 | params = dict(do_sample=True, 
temperature=0.7, repetition_penalty=1.0, max_new_tokens=512) 203 | params.update(self.kwargs) 204 | params.update(kwargs) 205 | 206 | outputs = self.model.generate(**inputs, **params) 207 | response = self.tokenizer.decode(outputs[0][len(inputs["input_ids"][0]) :], skip_special_tokens=True, spaces_between_special_tokens=False) 208 | response = response.lstrip('\n') 209 | else: 210 | # The default option, support internlm, chatglm, qwen 211 | history, msg = [], None 212 | if len(inputs) % 2 == 1: 213 | if self.system_prompt is not None: 214 | history = [(self.system_prompt, '')] 215 | for i in range(len(inputs)//2): 216 | history.append((inputs[2 * i], inputs[2 * i + 1])) 217 | else: 218 | assert self.system_prompt is not None 219 | history = [(self.system_prompt, inputs[0])] 220 | for i in range(len(inputs) // 2 - 1): 221 | history.append((inputs[2 * i + 1], inputs[2 * i + 2])) 222 | msg = inputs[-1] 223 | 224 | params = self.kwargs 225 | params.update(kwargs) 226 | response, _ = self.model.chat(self.tokenizer, msg, history=history, **params) 227 | 228 | return response, offset 229 | 230 | def generate(self, inputs, **kwargs): 231 | if isinstance(inputs, str): 232 | return self.generate_str(inputs, **kwargs) 233 | elif isinstance(inputs, list): 234 | return self.generate_list(inputs, **kwargs) -------------------------------------------------------------------------------- /vlmeval/evaluate/multiple_choice.py: -------------------------------------------------------------------------------- 1 | import os.path as osp 2 | import pandas as pd 3 | from tqdm import tqdm 4 | from vlmeval.evaluate.misc import build_judge 5 | from vlmeval.utils import can_infer, track_progress_rich, TSVDataset 6 | from vlmeval.smp import * 7 | import numpy as np 8 | 9 | INTERNAL = os.environ.get('INTERNAL', 0) 10 | 11 | abbrs = { 12 | 'coarse_perception': 'CP', 13 | 'finegrained_perception (instance-level)': 'FP-S', 14 | 'finegrained_perception (cross-instance)': 'FP-C', 15 | 'logic_reasoning': 'LR', 16 | 'relation_reasoning': 'RR', 17 | 'attribute_reasoning': 'AR' 18 | } 19 | 20 | 21 | 22 | 23 | def MMMU_preproc(data): 24 | logger = get_logger('Evaluation') 25 | cnt = 0 26 | As, Bs, Ans = list(data['A']), list(data['B']), list(data['answer']) 27 | lt = len(data) 28 | for i in range(lt): 29 | if pd.isna(As[i]): 30 | As[i] = Ans[i] 31 | Bs[i] = 'Other Answers' 32 | cnt += 1 33 | logger.info(f'During MMMU_preproc in Evaluation, {cnt} open questions are re-formulated to multi-choice ones. 
') 34 | data['A'] = As 35 | data['B'] = Bs 36 | return data 37 | 38 | 39 | def report_acc(df): 40 | # assert group in [None, 'category', 'l2-category'] 41 | res = defaultdict(list) 42 | 43 | if 'split' in df: 44 | splits = list(set(df['split'])) 45 | res['split'] = splits 46 | else: 47 | df['split'] = ['none'] * len(df) 48 | res['split'] = ['none'] 49 | 50 | for group in [None, 'l2-category', 'category']: 51 | if group is None: 52 | res['Overall'] = [np.mean(df[df['split'] == sp]['hit']) for sp in res['split']] 53 | elif group not in df: 54 | continue 55 | else: 56 | abilities = list(set(df[group])) 57 | abilities.sort() 58 | for ab in abilities: 59 | ab_name = abbrs[ab] if ab in abbrs else ab 60 | sub_df = df[df[group] == ab] 61 | res[ab_name] = [np.mean(sub_df[sub_df['split'] == sp]['hit']) for sp in res['split']] 62 | return pd.DataFrame(res) 63 | 64 | 65 | def build_prompt(question, options, prediction): 66 | tmpl = ( 67 | 'You are an AI assistant who will help me to match ' 68 | 'an answer with several options of a single-choice question. ' 69 | 'You are provided with a question, several options, and an answer, ' 70 | 'and you need to find which option is most similar to the answer. ' 71 | 'If the meaning of all options are significantly different from the answer, output Z. ' 72 | 'Your should output a single uppercase character in A, B, C, D (if they are valid options), and Z. \n' 73 | 'Example 1: \n' 74 | 'Question: What is the main object in image?\nOptions: A. teddy bear B. rabbit C. cat D. dog\n' 75 | 'Answer: a cute teddy bear\nYour output: A\n' 76 | 'Example 2: \n' 77 | 'Question: What is the main object in image?\nOptions: A. teddy bear B. rabbit C. cat D. dog\n' 78 | 'Answer: Spider\nYour output: Z\n' 79 | 'Example 3: \n' 80 | 'Question: {}?\nOptions: {}\nAnswer: {}\nYour output: ' 81 | ) 82 | return tmpl.format(question, options, prediction) 83 | 84 | 85 | def build_prompt_cn(question, options, prediction): 86 | tmpl = ( 87 | '你是一个帮助我匹配答案与单选题中多个选项的 AI 助手。' 88 | '你会被提供:一个问题,多个选项,一个答案。你的任务是找到与答案意义最相近的选项。' 89 | '如果所有选项的意义都与答案显著不同,则输出 Z。' 90 | '你应该输出一个单个的大写字母,例如 A, B, C, D(如果它们是有效选项),或 Z。' 91 | '例 1:' 92 | '问题: 图中最主要的物体是什么?\n选项: A. 泰迪熊 B. 兔子 C. 猫 D. 狗\n答案: 一只可爱的泰迪熊\n输出: A\n' 93 | '例 2: \n' 94 | '问题: 图中最主要的物体是什么?\n选项: A. 泰迪熊 B. 兔子 C. 猫 D. 狗\n答案: 蜘蛛\n输出: Z\n' 95 | '例 3: \n' 96 | '问题: {}?\n选项: {}\n答案: {}\n输出: ' 97 | ) 98 | return tmpl.format(question, options, prediction) 99 | 100 | 101 | def build_choices(item): 102 | ret = {} 103 | for ch in string.ascii_uppercase: 104 | if ch in item and (not pd.isna(item[ch])): 105 | ret[ch] = item[ch] 106 | return ret 107 | 108 | 109 | def prefetch_answer(item): 110 | choices = build_choices(item) 111 | return can_infer(item['prediction'], choices) 112 | 113 | 114 | def extract_answer_from_item(model, item): 115 | logger = get_logger('Evaluation') 116 | # It will return: (pred, raw, llm_time) 117 | choices = build_choices(item) 118 | option_str = build_option_str(choices) 119 | 120 | if cn_string(item['question']): 121 | prompt = build_prompt_cn(item['question'], option_str, item['prediction']) 122 | else: 123 | prompt = build_prompt(item['question'], option_str, item['prediction']) 124 | retry = 3 125 | 126 | ret = can_infer(item['prediction'], choices) 127 | if ret: 128 | return dict(opt=ret, log=item['prediction']) 129 | 130 | while retry: 131 | ans = model.generate(prompt) 132 | if 'Failed to obtain answer via API' in ans: 133 | logger.warning('GPT API failed to answer. 
') 134 | else: 135 | ret = can_infer(ans, choices) 136 | if ret: 137 | return dict(opt=ret, log=ans) 138 | else: 139 | logger.warning(f'Output includes 0 / > 1 letter among candidates {set(choices)} and Z: {ans}') 140 | retry -= 1 141 | 142 | if retry == 0: 143 | options = list(choices) + ['Z'] if 'Z' not in choices else [] 144 | return dict(opt=rd.choice(options), log='Failed to predict, thus randomly generate one. ') 145 | 146 | 147 | def prefetch_sub_data(sub_data, answer_map, verbose=False): 148 | lt = len(sub_data) 149 | GT, PRED = [], [] 150 | for i in range(lt): 151 | item = sub_data.iloc[i] 152 | idx = item['index'] 153 | GT.append(answer_map[idx]) 154 | PRED.append(prefetch_answer(item)) 155 | if PRED[-1] and (GT[-1] != PRED[-1]): 156 | log = ( 157 | f'Failed in Prefetching Rolling {i}: Answer is {GT[-1]}, ' 158 | f"Prediction is {item['prediction']}, Pre-fetched is {PRED[-1]}. " 159 | ) 160 | return dict(hit=0, log=log) 161 | flag = True 162 | for g, p in zip(GT, PRED): 163 | if g != p: 164 | flag = False 165 | ret = (dict(hit=1, log='Succeed During Pre-fetching'), ) if flag else (None, ) 166 | ret = ret + (GT, PRED) if verbose else ret 167 | return ret if len(ret) > 1 else ret[0] 168 | 169 | 170 | def eval_sub_data(model, sub_data, answer_map): 171 | res, GT, PRED = prefetch_sub_data(sub_data, answer_map, verbose=True) 172 | if res is not None: 173 | return res 174 | 175 | lt = len(sub_data) 176 | log = '' 177 | for i in range(lt): 178 | if PRED[i]: 179 | log += f'Rolling {i} Matched.\n' 180 | else: 181 | res = extract_answer_from_item(model, sub_data.iloc[i]) 182 | opt, match_log = res['opt'], res['log'] 183 | PRED[i] = opt 184 | if PRED[i] != GT[i]: 185 | log += ( 186 | f"Failed in Rolling {i}: Answer is {GT[i]}; Prediction is {sub_data.iloc[i]['prediction']}; " 187 | f'Pre-fetched is {PRED[i]}; Match Log is {match_log}.\n' 188 | ) 189 | return dict(hit=0, log=log) 190 | else: 191 | log += ( 192 | f"Rolling {i}: Answer is {GT[i]}, Prediction is {sub_data.iloc[i]['prediction']}, " 193 | f'Pre-fetched is {PRED[i]}.\n' 194 | ) 195 | 196 | return dict(hit=1, log=log) 197 | 198 | 199 | def eval_data_groups(model, data_groups, answer_map, result, result_file, nproc=16): 200 | prefetched = [prefetch_sub_data(g, answer_map, verbose=False) for g in data_groups] 201 | remain = [] 202 | for dg, pf in zip(data_groups, prefetched): 203 | if pf: 204 | result[dg.iloc[0]['index'] % 1e6] = pf 205 | else: 206 | remain.append(dg) 207 | dump(result, result_file) 208 | tups = [(model, x, answer_map) for x in remain] 209 | keys = [x.iloc[0]['index'] % 1e6 for x in remain] 210 | if len(tups) == 0: 211 | return 212 | 213 | if model is None: 214 | logger = get_logger('Evaluation') 215 | logger.warning('Exact Matching mode, will not do GPT-based answer matching. 
') 216 | for k in keys: 217 | result[k] = dict( 218 | hit=0, log='Failed in Prefetch, no GPT-based answer matching under `exact_matching` policy.') 219 | dump(result, result_file) 220 | return 221 | 222 | res = track_progress_rich( 223 | eval_sub_data, 224 | tups, 225 | nproc=nproc, 226 | chunksize=nproc, 227 | save=result_file, 228 | keys=keys) 229 | result = load(result_file) 230 | for k, v in zip(keys, res): 231 | if k in result: 232 | assert result[k]['hit'] == v['hit'] and result[k]['log'] == v['log'] 233 | else: 234 | result[k] = v 235 | dump(result, result_file) 236 | 237 | 238 | def multiple_choice_eval(eval_file, dataset='default', model='chatgpt-0613', nproc=4, verbose=False): 239 | logger = get_logger('Evaluation') 240 | 241 | # assert dataset is not None 242 | if dataset == 'MMBench_TEST_CN': 243 | dataset = 'MMBench_CN' 244 | elif dataset == 'MMBench_TEST_EN': 245 | dataset = 'MMBench' 246 | 247 | # if listinstr(['mmbench', 'ccbench'], dataset.lower()): 248 | # data = load(eval_file) 249 | # data['index'] = [int(x) for x in data['index']] 250 | # dump(data, eval_file) 251 | 252 | rd.seed(2680) 253 | suffix = eval_file.split('.')[-1] 254 | assert model in ['chatgpt-0613', 'exact_matching', 'gpt-4-0125'] 255 | name_str_map = { 256 | 'chatgpt-0613': 'openai', 257 | 'gpt-4-0125': 'gpt4' 258 | } 259 | name_str = name_str_map[model] if model in name_str_map else model 260 | 261 | if model == 'exact_matching': 262 | model = None 263 | else: 264 | if INTERNAL or gpt_key_set(): 265 | model = build_judge(model, verbose=verbose, retry=10) 266 | else: 267 | logger.error('OPENAI_API_KEY is not set properly, will use exact matching for evaluation') 268 | model = None 269 | 270 | logger.info(f'Evaluating {eval_file}') 271 | result_file = eval_file.replace(f'.{suffix}', f'_{name_str}_result.pkl') 272 | result = {} 273 | if osp.exists(result_file): 274 | result = load(result_file) 275 | 276 | data = load(eval_file) 277 | data = data.sort_values(by='index') 278 | data['prediction'] = [str(x) for x in data['prediction']] 279 | for k in data.keys(): 280 | data[k.lower() if k not in list(string.ascii_uppercase) else k] = data.pop(k) 281 | 282 | if dataset != 'default': 283 | meta = TSVDataset(dataset).data 284 | else: 285 | logger.warning('Dataset is not provided, try to use the original `eval_file` as meta data. 
') 286 | meta = load(eval_file) 287 | assert 'category' in meta and 'index' in meta and 'answer' in meta, ( 288 | 'Essentail columns missing in the eval_file.') 289 | 290 | cate_map = {i: c for i, c in zip(meta['index'], meta['category'])} 291 | answer_map = {i: c for i, c in zip(meta['index'], meta['answer'])} 292 | l2_cate_map = {i: c for i, c in zip(meta['index'], meta['l2-category'])} if 'l2-category' in meta else None 293 | split_map = {i: c for i, c in zip(meta['index'], meta['split'])} if 'split' in meta else None 294 | 295 | if l2_cate_map is not None and np.all([pd.isna(x) for x in l2_cate_map.values()]): 296 | l2_cate_map = None 297 | if split_map is not None and np.all([pd.isna(x) for x in split_map.values()]): 298 | split_map = None 299 | 300 | if listinstr(['MMMU'], dataset): 301 | data = MMMU_preproc(data) 302 | answer_map = {k: (v if v in list(string.ascii_uppercase) else 'A') for k, v in answer_map.items()} 303 | 304 | data = data[data['index'].isin(answer_map)] 305 | data_main = data[data['index'] < int(1e6)] 306 | meta_idx_set = set(meta['index']) 307 | data_main = data_main[data_main['index'].isin(meta_idx_set)] 308 | 309 | lt = len(data_main) 310 | hit, tot = 0, 0 311 | 312 | data_groups = [] 313 | for i in tqdm(range(lt)): 314 | # Dealing with the normal part 315 | item_main = data_main.iloc[i] 316 | idx = item_main['index'] 317 | 318 | if idx in result: 319 | correct = result[idx]['hit'] 320 | assert correct in [0, 1] 321 | hit += correct 322 | tot += 1 323 | continue 324 | 325 | sub_data = data[data['index'] % int(1e6) == idx] 326 | data_groups.append(sub_data) 327 | 328 | if len(data_groups): 329 | eval_data_groups( 330 | model=model, 331 | data_groups=data_groups, 332 | answer_map=answer_map, 333 | nproc=nproc, 334 | result=result, 335 | result_file=result_file) 336 | 337 | tmp_pth = f'/tmp/{timestr()}.xlsx' 338 | dump(data_main, tmp_pth) 339 | data_main = load(tmp_pth) 340 | 341 | res = load(result_file) 342 | indices = data_main['index'] 343 | 344 | data_main['hit'] = [res[i]['hit'] for i in indices] 345 | data_main['log'] = [res[i]['log'] for i in indices] 346 | 347 | main_idx = data_main['index'] 348 | data_main['category'] = [cate_map[i] for i in main_idx] 349 | if l2_cate_map is not None: 350 | data_main['l2-category'] = [l2_cate_map[i] for i in main_idx] 351 | if split_map is not None: 352 | data_main['split'] = [split_map[i] for i in indices] 353 | 354 | # load split 355 | dump(data_main, eval_file.replace(f'.{suffix}', f'_{name_str}_result.{suffix}')) 356 | data_main = load(eval_file.replace(f'.{suffix}', f'_{name_str}_result.{suffix}')) 357 | 358 | acc = report_acc(data_main) 359 | score_file = eval_file.replace(f'.{suffix}', '_acc.csv') 360 | dump(acc, score_file) 361 | logger.info(f'multiple_choice_eval successfully finished evaluating {eval_file}, results saved in {score_file}') 362 | logger.info('Score: ') 363 | logger.info(acc) 364 | return acc 365 | 366 | 367 | def parse_args(): 368 | parser = argparse.ArgumentParser(description='Inference LLM Answers. ') 369 | parser.add_argument('data', type=str, help='The question set for inference, in excel / tsv / json format. ') 370 | parser.add_argument( 371 | '--model', 372 | type=str, 373 | help='The LLM (GPT) used for inference. 
', 374 | default='chatgpt-0613', 375 | choices=['chatgpt-0613', 'exact_matching', 'gpt-4-0125']) 376 | parser.add_argument( 377 | '--dataset', 378 | type=str, 379 | default='default', 380 | help='The dataset to evaluate') 381 | parser.add_argument('--nproc', type=int, default=6) 382 | parser.add_argument('--verbose', action='store_true') 383 | args = parser.parse_args() 384 | return args 385 | 386 | 387 | if __name__ == '__main__': 388 | args = parse_args() 389 | acc = multiple_choice_eval( 390 | eval_file=args.data, model=args.model, dataset=args.dataset, nproc=args.nproc, verbose=args.verbose) --------------------------------------------------------------------------------
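For reference, below is a minimal usage sketch of the `multiple_choice_eval` entry point defined in vlmeval/evaluate/multiple_choice.py above. It is not part of the repository: the prediction file path and dataset name are placeholder assumptions, and `model='exact_matching'` is chosen so that no OPENAI_API_KEY is needed (answers are then matched only by `can_infer`, with no GPT-based matching).

# Hypothetical example (not in the repo): evaluate a saved prediction file.
# The xlsx path below is a placeholder; the file must contain 'index' and
# 'prediction' columns, while 'category' and 'answer' come from the dataset TSV meta.
from vlmeval.evaluate.multiple_choice import multiple_choice_eval

acc = multiple_choice_eval(
    eval_file='outputs/MMBench_DEV_EN_predictions.xlsx',  # placeholder path
    dataset='MMBench_DEV_EN',
    model='exact_matching',  # skip GPT-based answer matching
    nproc=4,
)
print(acc)  # per-category accuracy DataFrame, also saved next to eval_file as *_acc.csv

The same run can be launched through the argparse entry point shown above, e.g. `python vlmeval/evaluate/multiple_choice.py outputs/MMBench_DEV_EN_predictions.xlsx --dataset MMBench_DEV_EN --model exact_matching`.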