├── vlm_model ├── internvl_v2pe │ ├── __init__.py │ └── internvl │ │ ├── __init__.py │ │ ├── model │ │ ├── __init__.py │ │ └── internvl_chat │ │ │ ├── __init__.py │ │ │ ├── flash_attention.py │ │ │ ├── configuration_intern_vit.py │ │ │ └── configuration_internvl_chat.py │ │ ├── serve │ │ ├── __init__.py │ │ ├── constants.py │ │ ├── mm_utils.py │ │ └── utils.py │ │ ├── train │ │ ├── __init__.py │ │ └── constants.py │ │ ├── patch │ │ ├── llama_rmsnorm_monkey_patch.py │ │ ├── __init__.py │ │ ├── train_dataloader_patch.py │ │ ├── llama_packed_training_patch.py │ │ ├── qwen2_packed_training_patch.py │ │ ├── train_sampler_patch.py │ │ ├── pad_data_collator.py │ │ └── internlm2_packed_training_patch.py │ │ └── dist_utils.py ├── nvila_2b_ef8fa9c8 │ ├── __init__.py │ ├── model_utils_packing.py │ ├── loss.py │ ├── distributed.py │ ├── constants.py │ ├── configuration_vila.py │ ├── media_encoder.py │ ├── media.py │ └── conversation.py ├── nvila_8b_e2481b0c │ ├── __init__.py │ ├── model_utils_packing.py │ ├── loss.py │ ├── distributed.py │ ├── constants.py │ ├── configuration_vila.py │ ├── media_encoder.py │ └── media.py ├── qwen2_with_prefill │ └── __init__.py ├── mplug_owl3.py ├── minicpm.py ├── qwen_vl.py ├── __init__.py └── model_utils.py ├── .gitignore ├── assets ├── overview_page.jpg └── comparison_page.jpg ├── scripts ├── eval_gpt4_summ.sh ├── download_text_data.sh └── download_image_data.sh ├── configs ├── text_rag_all.yaml ├── text_docqa_all.yaml ├── summ_all.yaml ├── vrag_all.yaml ├── vh_all.yaml ├── docqa_all.yaml ├── icl_all.yaml ├── mm_niah_text_all.yaml └── mm_niah_image_all.yaml ├── requirements.txt ├── LICENSE.txt ├── figure_scripts ├── 13_vh_multi_difficulty.py ├── 13_vh_difficulty.py ├── 1_main_most_models.py ├── 14_needle_modal_difficulty.py ├── 2_main_full_models.py ├── 6_correlation_NIAH_all.py ├── 5_correlation_NIAH_most.py ├── 8_correlation_all_datasets.py ├── 4_main_full_models_split.py ├── 11_task_diffculty.py ├── 7_correlation_all_categories.py ├── 10_plot_NIAH_distribution.py ├── 16_docqa_pie_figure.py ├── 9_heatmap_by_depth.py ├── 3_main_full_models_by_category.py ├── 18_internvl2_V2PE.py ├── 18_qwen2_5_yarn.py ├── 12_metrics_for_summ.py ├── 15_rag_modal_difficulty.py └── 15_docqa_modal_difficulty.py └── arguments.py /vlm_model/internvl_v2pe/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /vlm_model/nvila_2b_ef8fa9c8/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /vlm_model/nvila_8b_e2481b0c/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /vlm_model/qwen2_with_prefill/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /vlm_model/internvl_v2pe/internvl/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /vlm_model/internvl_v2pe/internvl/model/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- 
/vlm_model/internvl_v2pe/internvl/serve/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /vlm_model/internvl_v2pe/internvl/train/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | .ipynb_checkpoints 2 | .idea 3 | .DS_Store 4 | __pycache__ -------------------------------------------------------------------------------- /assets/overview_page.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/EdinburghNLP/MMLongBench/HEAD/assets/overview_page.jpg -------------------------------------------------------------------------------- /assets/comparison_page.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/EdinburghNLP/MMLongBench/HEAD/assets/comparison_page.jpg -------------------------------------------------------------------------------- /scripts/eval_gpt4_summ.sh: -------------------------------------------------------------------------------- 1 | for i in {0..3}; do python scripts/eval_gpt4_summ.py --num_shards 4 --shard_idx $i & done 2 | -------------------------------------------------------------------------------- /scripts/download_text_data.sh: -------------------------------------------------------------------------------- 1 | wget https://huggingface.co/datasets/ZhaoweiWang/MMLongBench/resolve/main/0_mmlb_data.tar.gz 2 | # or 3 | # huggingface-cli download ZhaoweiWang/MMLongBench 0_mmlb_data.tar.gz --local-dir ./ --repo-type dataset 4 | 5 | tar -xzvf 0_mmlb_data.tar.gz -------------------------------------------------------------------------------- /configs/text_rag_all.yaml: -------------------------------------------------------------------------------- 1 | input_max_length: 8192,16384,32768,65536,131072 2 | datasets: triviaqa,triviaqa,triviaqa,triviaqa,triviaqa 3 | generation_max_length: 128,128,128,128,128 4 | test_files: vrag/viquae_K8_dep6.jsonl,vrag/viquae_K16_dep6.jsonl,vrag/viquae_K32_dep6.jsonl,vrag/viquae_K64_dep6.jsonl,vrag/viquae_K128_dep6.jsonl 5 | use_chat_template: true 6 | max_test_samples: 100 7 | -------------------------------------------------------------------------------- /configs/text_docqa_all.yaml: -------------------------------------------------------------------------------- 1 | input_max_length: 8192,16384,32768,65536,131072 2 | datasets: text_doc,text_doc,text_doc,text_doc,text_doc 3 | generation_max_length: 128,128,128,128,128 4 | test_files: documentQA/text_mmlongdoc_K8.jsonl,documentQA/text_mmlongdoc_K16.jsonl,documentQA/text_mmlongdoc_K32.jsonl,documentQA/text_mmlongdoc_K64.jsonl,documentQA/text_mmlongdoc_K128.jsonl 5 | use_chat_template: true 6 | max_test_samples: 100 7 | -------------------------------------------------------------------------------- /vlm_model/internvl_v2pe/internvl/serve/constants.py: -------------------------------------------------------------------------------- 1 | CONTROLLER_HEART_BEAT_EXPIRATION = 30 2 | WORKER_HEART_BEAT_INTERVAL = 15 3 | 4 | LOGDIR = '.' 
5 | 6 | # Model Constants 7 | IGNORE_INDEX = -100 8 | IMAGE_TOKEN_INDEX = -200 9 | DEFAULT_IMAGE_TOKEN = '<image>' 10 | DEFAULT_IMAGE_PATCH_TOKEN = '<im_patch>' 11 | DEFAULT_IM_START_TOKEN = '<im_start>' 12 | DEFAULT_IM_END_TOKEN = '<im_end>' 13 | IMAGE_PLACEHOLDER = '<image-placeholder>' 14 | -------------------------------------------------------------------------------- /configs/summ_all.yaml: -------------------------------------------------------------------------------- 1 | input_max_length: 8192,16384,32768,65536,131072,8192,16384,32768,65536,131072 2 | datasets: gov-report,gov-report,gov-report,gov-report,gov-report,multi-lexsum,multi-lexsum,multi-lexsum,multi-lexsum,multi-lexsum 3 | generation_max_length: 384,384,384,384,384,384,384,384,384,384 4 | test_files: summ/gov_K8.jsonl,summ/gov_K16.jsonl,summ/gov_K32.jsonl,summ/gov_K64.jsonl,summ/gov_K128.jsonl,summ/lexsum_K8.jsonl,summ/lexsum_K16.jsonl,summ/lexsum_K32.jsonl,summ/lexsum_K64.jsonl,summ/lexsum_K128.jsonl 5 | use_chat_template: true 6 | max_test_samples: 100 7 | -------------------------------------------------------------------------------- /configs/vrag_all.yaml: -------------------------------------------------------------------------------- 1 | input_max_length: 8192,16384,32768,65536,131072,8192,16384,32768,65536,131072 2 | datasets: infoseek,infoseek,infoseek,infoseek,infoseek,viquae,viquae,viquae,viquae,viquae 3 | generation_max_length: 128,128,128,128,128,128,128,128,128,128 4 | test_files: vrag/infoseek_K8_dep3.jsonl,vrag/infoseek_K16_dep3.jsonl,vrag/infoseek_K32_dep3.jsonl,vrag/infoseek_K64_dep3.jsonl,vrag/infoseek_K128_dep3.jsonl,vrag/viquae_K8_dep6.jsonl,vrag/viquae_K16_dep6.jsonl,vrag/viquae_K32_dep6.jsonl,vrag/viquae_K64_dep6.jsonl,vrag/viquae_K128_dep6.jsonl 5 | use_chat_template: true 6 | max_test_samples: 100 7 | -------------------------------------------------------------------------------- /vlm_model/internvl_v2pe/internvl/train/constants.py: -------------------------------------------------------------------------------- 1 | IMG_CONTEXT_TOKEN = '<IMG_CONTEXT>' 2 | IMG_START_TOKEN = '<img>' 3 | IMG_END_TOKEN = '</img>' 4 | 5 | QUAD_START_TOKEN = '<quad>' 6 | QUAD_END_TOKEN = '</quad>' 7 | 8 | REF_START_TOKEN = '<ref>' 9 | REF_END_TOKEN = '</ref>' 10 | 11 | BOX_START_TOKEN = '<box>' 12 | BOX_END_TOKEN = '</box>' 13 | IMAGENET_MEAN = (0.485, 0.456, 0.406) 14 | IMAGENET_STD = (0.229, 0.224, 0.225) 15 | CLIP_MEAN = (0.4814546, 0.4578275, 0.40821073) 16 | CLIP_STD = (0.2686295, 0.2613025, 0.2757711) 17 | SIGLIP_MEAN = (0.5, 0.5, 0.5) 18 | SIGLIP_STD = (0.5, 0.5, 0.5) 19 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | rouge_score==0.1.2 2 | transformers==4.49.0 # Gemma3 needs 4.50 3 | datasets==3.1.0 4 | accelerate==1.2.0 5 | peft==0.15.0 6 | flash_attn==2.7.4.post1 7 | torch==2.6.0 8 | torchvision==0.21.0 9 | qwen_vl_utils==0.0.10 # for qwen2 10 | transformers-stream-generator==0.0.4 11 | autoawq==0.2.7.post3 # qwen2-awq 12 | backoff # internvl, phi 4 13 | timm # internvl 14 | sentencepiece # internvl 15 | einops # phi, NVila 16 | openai==1.73.0 # openai and anthropic 17 | google-genai # google gemini 18 | git+https://github.com/bfshi/scaling_on_scales.git # NVila 19 | imageio # for v2pe 20 | decord # for v2pe 21 | -------------------------------------------------------------------------------- /vlm_model/internvl_v2pe/internvl/model/internvl_chat/__init__.py: -------------------------------------------------------------------------------- 1 | # 
-------------------------------------------------------- 2 | # InternVL 3 | # Copyright (c) 2023 OpenGVLab 4 | # Licensed under The MIT License [see LICENSE for details] 5 | # -------------------------------------------------------- 6 | 7 | from .configuration_intern_vit import InternVisionConfig 8 | from .configuration_internvl_chat import InternVLChatConfig 9 | from .modeling_intern_vit import InternVisionModel 10 | from .modeling_internvl_chat import InternVLChatModel 11 | 12 | __all__ = ['InternVisionConfig', 'InternVisionModel', 13 | 'InternVLChatConfig', 'InternVLChatModel'] 14 | -------------------------------------------------------------------------------- /configs/vh_all.yaml: -------------------------------------------------------------------------------- 1 | input_max_length: 8192,16384,32768,65536,131072,8192,16384,32768,65536,131072 2 | datasets: vh_single,vh_single,vh_single,vh_single,vh_single,vh_multi,vh_multi,vh_multi,vh_multi,vh_multi 3 | generation_max_length: 128,128,128,128,128,128,128,128,128,128 4 | test_files: NIAH/vh_single_test_1000_K8_dep6.jsonl,NIAH/vh_single_test_1000_K16_dep6.jsonl,NIAH/vh_single_test_1000_K32_dep6.jsonl,NIAH/vh_single_test_1000_K64_dep6.jsonl,NIAH/vh_single_test_1000_K128_dep6.jsonl,NIAH/vh_multi_test_1000_K8_dep3.jsonl,NIAH/vh_multi_test_1000_K16_dep3.jsonl,NIAH/vh_multi_test_1000_K32_dep3.jsonl,NIAH/vh_multi_test_1000_K64_dep3.jsonl,NIAH/vh_multi_test_1000_K128_dep3.jsonl 5 | use_chat_template: true 6 | max_test_samples: 100 7 | -------------------------------------------------------------------------------- /scripts/download_image_data.sh: -------------------------------------------------------------------------------- 1 | # download images 2 | for file in 1_vrag_image.tar.gz 2_vh_image.tar.gz 2_mm-niah_image.tar.gz 3_icl_image.tar.gz 4_summ_image.tar.gz 5_docqa_image.tar.gz; do 3 | wget -c https://huggingface.co/datasets/ZhaoweiWang/MMLongBench/resolve/main/$file 4 | done 5 | # or 6 | #for file in 1_vrag_image.tar.gz 2_vh_image.tar.gz 2_mm-niah_image.tar.gz 3_icl_image.tar.gz 4_summ_image.tar.gz 5_docqa_image.tar.gz; do 7 | # huggingface-cli download ZhaoweiWang/MMLongBench $file --local-dir ./ --repo-type dataset 8 | #done 9 | 10 | # decompress images 11 | for file in 1_vrag_image.tar.gz 2_vh_image.tar.gz 2_mm-niah_image.tar.gz 3_icl_image.tar.gz 4_summ_image.tar.gz 5_docqa_image.tar.gz; do 12 | tar -xzvf "$file" 13 | done -------------------------------------------------------------------------------- /vlm_model/internvl_v2pe/internvl/patch/llama_rmsnorm_monkey_patch.py: -------------------------------------------------------------------------------- 1 | import transformers 2 | 3 | 4 | def replace_llama_rmsnorm_with_fused_rmsnorm(): 5 | try: 6 | from functools import partial 7 | 8 | from apex.normalization import FusedRMSNorm 9 | LlamaRMSNorm = partial(FusedRMSNorm, eps=1e-6) # noqa 10 | transformers.models.llama.modeling_llama.LlamaRMSNorm = LlamaRMSNorm 11 | print('Discovered apex.normalization.FusedRMSNorm - will use it instead of LlamaRMSNorm') 12 | except ImportError: 13 | # using the normal LlamaRMSNorm 14 | pass 15 | except Exception: 16 | print('discovered apex but it failed to load, falling back to LlamaRMSNorm') 17 | pass 18 | -------------------------------------------------------------------------------- /configs/docqa_all.yaml: -------------------------------------------------------------------------------- 1 | input_max_length: 8192,16384,32768,65536,131072,8192,16384,32768,65536,131072,8192,16384,32768,65536,131072 2 | 
datasets: longdocurl,longdocurl,longdocurl,longdocurl,longdocurl,mmlongdoc,mmlongdoc,mmlongdoc,mmlongdoc,mmlongdoc,slidevqa,slidevqa,slidevqa,slidevqa,slidevqa 3 | generation_max_length: 128,128,128,128,128,128,128,128,128,128,128,128,128,128,128 4 | test_files: documentQA/longdocurl_K8.jsonl,documentQA/longdocurl_K16.jsonl,documentQA/longdocurl_K32.jsonl,documentQA/longdocurl_K64.jsonl,documentQA/longdocurl_K128.jsonl,documentQA/mmlongdoc_K8.jsonl,documentQA/mmlongdoc_K16.jsonl,documentQA/mmlongdoc_K32.jsonl,documentQA/mmlongdoc_K64.jsonl,documentQA/mmlongdoc_K128.jsonl,documentQA/slidevqa_K8.jsonl,documentQA/slidevqa_K16.jsonl,documentQA/slidevqa_K32.jsonl,documentQA/slidevqa_K64.jsonl,documentQA/slidevqa_K128.jsonl 5 | use_chat_template: true 6 | max_test_samples: 100 7 | -------------------------------------------------------------------------------- /configs/icl_all.yaml: -------------------------------------------------------------------------------- 1 | input_max_length: 8192,16384,32768,65536,131072,8192,16384,32768,65536,131072,8192,16384,32768,65536,131072,8192,16384,32768,65536,131072 2 | datasets: cars196,cars196,cars196,cars196,cars196,food101,food101,food101,food101,food101,inat2021,inat2021,inat2021,inat2021,inat2021,sun397,sun397,sun397,sun397,sun397 3 | generation_max_length: 128,128,128,128,128,128,128,128,128,128,128,128,128,128,128,128,128,128,128,128 4 | test_files: ICL/cars196_K8.json,ICL/cars196_K16.json,ICL/cars196_K32.json,ICL/cars196_K64.json,ICL/cars196_K128.json,ICL/food101_K8.json,ICL/food101_K16.json,ICL/food101_K32.json,ICL/food101_K64.json,ICL/food101_K128.json,ICL/inat2021_K8.json,ICL/inat2021_K16.json,ICL/inat2021_K32.json,ICL/inat2021_K64.json,ICL/inat2021_K128.json,ICL/sun397_K8.json,ICL/sun397_K16.json,ICL/sun397_K32.json,ICL/sun397_K64.json,ICL/sun397_K128.json 5 | use_chat_template: true 6 | max_test_samples: 100 7 | -------------------------------------------------------------------------------- /LICENSE.txt: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2021 HKUST-KnowComp 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
-------------------------------------------------------------------------------- /configs/mm_niah_text_all.yaml: -------------------------------------------------------------------------------- 1 | input_max_length: 8192,16384,32768,65536,131072,8192,16384,32768,65536,131072,8192,16384,32768,65536,131072 2 | datasets: mm_niah_retrieval-text,mm_niah_retrieval-text,mm_niah_retrieval-text,mm_niah_retrieval-text,mm_niah_retrieval-text,mm_niah_counting-text,mm_niah_counting-text,mm_niah_counting-text,mm_niah_counting-text,mm_niah_counting-text,mm_niah_reasoning-text,mm_niah_reasoning-text,mm_niah_reasoning-text,mm_niah_reasoning-text,mm_niah_reasoning-text 3 | generation_max_length: 128,128,128,128,128,128,128,128,128,128,128,128,128,128,128 4 | test_files: NIAH/retrieval-text_test_K8_dep6.jsonl,NIAH/retrieval-text_test_K16_dep6.jsonl,NIAH/retrieval-text_test_K32_dep6.jsonl,NIAH/retrieval-text_test_K64_dep6.jsonl,NIAH/retrieval-text_test_K128_dep6.jsonl,NIAH/counting-text_test_K8_dep3.jsonl,NIAH/counting-text_test_K16_dep3.jsonl,NIAH/counting-text_test_K32_dep3.jsonl,NIAH/counting-text_test_K64_dep3.jsonl,NIAH/counting-text_test_K128_dep3.jsonl,NIAH/reasoning-text_test_K8_dep3.jsonl,NIAH/reasoning-text_test_K16_dep3.jsonl,NIAH/reasoning-text_test_K32_dep3.jsonl,NIAH/reasoning-text_test_K64_dep3.jsonl,NIAH/reasoning-text_test_K128_dep3.jsonl 5 | use_chat_template: true 6 | max_test_samples: 50 7 | -------------------------------------------------------------------------------- /vlm_model/internvl_v2pe/internvl/patch/__init__.py: -------------------------------------------------------------------------------- 1 | from .internlm2_packed_training_patch import replace_internlm2_attention_class 2 | from .llama2_flash_attn_monkey_patch import replace_llama2_attn_with_flash_attn 3 | from .llama_flash_attn_monkey_patch import replace_llama_attn_with_flash_attn 4 | from .llama_packed_training_patch import replace_llama_attention_class 5 | from .llama_rmsnorm_monkey_patch import \ 6 | replace_llama_rmsnorm_with_fused_rmsnorm 7 | from .pad_data_collator import concat_pad_data_collator, pad_data_collator 8 | from .qwen2_packed_training_patch import replace_qwen2_attention_class 9 | from .train_dataloader_patch import replace_train_dataloader 10 | from .train_sampler_patch import replace_train_sampler 11 | 12 | __all__ = ['replace_llama_attn_with_flash_attn', 13 | 'replace_llama_rmsnorm_with_fused_rmsnorm', 14 | 'replace_llama2_attn_with_flash_attn', 15 | 'replace_train_sampler', 16 | 'replace_train_dataloader', 17 | 'replace_internlm2_attention_class', 18 | 'replace_qwen2_attention_class', 19 | 'replace_llama_attention_class', 20 | 'pad_data_collator', 21 | 'concat_pad_data_collator'] 22 | -------------------------------------------------------------------------------- /configs/mm_niah_image_all.yaml: -------------------------------------------------------------------------------- 1 | input_max_length: 8192,16384,32768,65536,131072,8192,16384,32768,65536,131072,8192,16384,32768,65536,131072 2 | datasets: mm_niah_retrieval-image,mm_niah_retrieval-image,mm_niah_retrieval-image,mm_niah_retrieval-image,mm_niah_retrieval-image,mm_niah_counting-image,mm_niah_counting-image,mm_niah_counting-image,mm_niah_counting-image,mm_niah_counting-image,mm_niah_reasoning-image,mm_niah_reasoning-image,mm_niah_reasoning-image,mm_niah_reasoning-image,mm_niah_reasoning-image 3 | generation_max_length: 128,128,128,128,128,128,128,128,128,128,128,128,128,128,128 4 | test_files: 
NIAH/retrieval-image_test_K8_dep6.jsonl,NIAH/retrieval-image_test_K16_dep6.jsonl,NIAH/retrieval-image_test_K32_dep6.jsonl,NIAH/retrieval-image_test_K64_dep6.jsonl,NIAH/retrieval-image_test_K128_dep6.jsonl,NIAH/counting-image_test_K8_dep3.jsonl,NIAH/counting-image_test_K16_dep3.jsonl,NIAH/counting-image_test_K32_dep3.jsonl,NIAH/counting-image_test_K64_dep3.jsonl,NIAH/counting-image_test_K128_dep3.jsonl,NIAH/reasoning-image_test_K8_dep6.jsonl,NIAH/reasoning-image_test_K16_dep6.jsonl,NIAH/reasoning-image_test_K32_dep6.jsonl,NIAH/reasoning-image_test_K64_dep6.jsonl,NIAH/reasoning-image_test_K128_dep6.jsonl 5 | use_chat_template: true 6 | max_test_samples: 50 7 | -------------------------------------------------------------------------------- /vlm_model/nvila_2b_ef8fa9c8/model_utils_packing.py: -------------------------------------------------------------------------------- 1 | from importlib import import_module 2 | from typing import Tuple 3 | 4 | import torch 5 | import transformers 6 | from torch import nn 7 | from torch.nn import functional as F 8 | 9 | __all__ = ["patch"] 10 | 11 | 12 | def _get_unpad_data(attention_mask: torch.Tensor, *args, **kwargs) -> Tuple[torch.Tensor, torch.Tensor, int]: 13 | if hasattr(_get_unpad_data, "seqlens_in_batch"): 14 | seqlens_in_batch = _get_unpad_data.seqlens_in_batch 15 | else: 16 | seqlens_in_batch = torch.sum(attention_mask, dim=1) 17 | 18 | indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten() 19 | max_seqlen_in_batch = seqlens_in_batch.max().item() 20 | cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0)) 21 | return indices, cu_seqlens, max_seqlen_in_batch 22 | 23 | 24 | def set_seqlens_in_batch(seqlens_in_batch: torch.Tensor) -> None: 25 | _get_unpad_data.seqlens_in_batch = seqlens_in_batch 26 | 27 | 28 | def patch(model: nn.Module) -> None: 29 | if transformers.__version__ < "4.43.0": 30 | m = import_module(model.__module__) 31 | if not hasattr(m, "_get_unpad_data"): 32 | raise ValueError(f"Module {m} does not have function '_get_unpad_data' for packing") 33 | m._get_unpad_data = _get_unpad_data 34 | else: 35 | transformers.modeling_flash_attention_utils._get_unpad_data = _get_unpad_data 36 | -------------------------------------------------------------------------------- /vlm_model/nvila_8b_e2481b0c/model_utils_packing.py: -------------------------------------------------------------------------------- 1 | from importlib import import_module 2 | from typing import Tuple 3 | 4 | import torch 5 | import transformers 6 | from torch import nn 7 | from torch.nn import functional as F 8 | 9 | __all__ = ["patch"] 10 | 11 | 12 | def _get_unpad_data(attention_mask: torch.Tensor, *args, **kwargs) -> Tuple[torch.Tensor, torch.Tensor, int]: 13 | if hasattr(_get_unpad_data, "seqlens_in_batch"): 14 | seqlens_in_batch = _get_unpad_data.seqlens_in_batch 15 | else: 16 | seqlens_in_batch = torch.sum(attention_mask, dim=1) 17 | 18 | indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten() 19 | max_seqlen_in_batch = seqlens_in_batch.max().item() 20 | cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0)) 21 | return indices, cu_seqlens, max_seqlen_in_batch 22 | 23 | 24 | def set_seqlens_in_batch(seqlens_in_batch: torch.Tensor) -> None: 25 | _get_unpad_data.seqlens_in_batch = seqlens_in_batch 26 | 27 | 28 | def patch(model: nn.Module) -> None: 29 | if transformers.__version__ < "4.43.0": 30 | m = import_module(model.__module__) 31 | if not hasattr(m, 
"_get_unpad_data"): 32 | raise ValueError(f"Module {m} does not have function '_get_unpad_data' for packing") 33 | m._get_unpad_data = _get_unpad_data 34 | else: 35 | transformers.modeling_flash_attention_utils._get_unpad_data = _get_unpad_data 36 | -------------------------------------------------------------------------------- /vlm_model/nvila_2b_ef8fa9c8/loss.py: -------------------------------------------------------------------------------- 1 | from typing import List, Union 2 | 3 | import torch 4 | from torch.nn.functional import cross_entropy 5 | 6 | from .constants import IGNORE_INDEX 7 | 8 | __all__ = ["soft_cross_entropy"] 9 | 10 | 11 | def soft_cross_entropy( 12 | outputs: torch.Tensor, 13 | targets: torch.Tensor, 14 | soft_tokens: Union[torch.Tensor, List[int]], 15 | std: float = 1, 16 | ignore_index: int = IGNORE_INDEX, 17 | ) -> torch.Tensor: 18 | # Remove last token from outputs and first token from targets 19 | outputs = outputs[..., :-1, :].contiguous() 20 | targets = targets[..., 1:].contiguous() 21 | 22 | # Flatten outputs and targets 23 | targets = targets.view(-1) 24 | outputs = outputs.view(targets.size(0), -1) 25 | 26 | # Remove outputs and targets with ignore_index 27 | indices = targets != ignore_index 28 | outputs = outputs[indices] 29 | targets = targets[indices] 30 | 31 | # Convert soft token IDs to tensor 32 | if isinstance(soft_tokens, list): 33 | soft_tokens = torch.tensor(soft_tokens).to(targets) 34 | 35 | # Calculate loss for non-soft tokens 36 | indices = torch.isin(targets, soft_tokens, invert=True) 37 | loss = cross_entropy(outputs[indices], targets[indices], reduction="sum") 38 | 39 | # Calculate loss for soft tokens 40 | indices = torch.isin(targets, soft_tokens) 41 | targets_indices = torch.zeros_like(outputs[indices]) 42 | for k, target in enumerate(targets[indices]): 43 | dist = torch.exp(-((target - soft_tokens) ** 2) / (2 * std**2)) 44 | targets_indices[k][soft_tokens] = dist / dist.sum() 45 | loss += cross_entropy(outputs[indices], targets_indices, reduction="sum") 46 | 47 | # Return average loss 48 | return loss / targets.size(0) 49 | -------------------------------------------------------------------------------- /vlm_model/nvila_8b_e2481b0c/loss.py: -------------------------------------------------------------------------------- 1 | from typing import List, Union 2 | 3 | import torch 4 | from torch.nn.functional import cross_entropy 5 | 6 | from .constants import IGNORE_INDEX 7 | 8 | __all__ = ["soft_cross_entropy"] 9 | 10 | 11 | def soft_cross_entropy( 12 | outputs: torch.Tensor, 13 | targets: torch.Tensor, 14 | soft_tokens: Union[torch.Tensor, List[int]], 15 | std: float = 1, 16 | ignore_index: int = IGNORE_INDEX, 17 | ) -> torch.Tensor: 18 | # Remove last token from outputs and first token from targets 19 | outputs = outputs[..., :-1, :].contiguous() 20 | targets = targets[..., 1:].contiguous() 21 | 22 | # Flatten outputs and targets 23 | targets = targets.view(-1) 24 | outputs = outputs.view(targets.size(0), -1) 25 | 26 | # Remove outputs and targets with ignore_index 27 | indices = targets != ignore_index 28 | outputs = outputs[indices] 29 | targets = targets[indices] 30 | 31 | # Convert soft token IDs to tensor 32 | if isinstance(soft_tokens, list): 33 | soft_tokens = torch.tensor(soft_tokens).to(targets) 34 | 35 | # Calculate loss for non-soft tokens 36 | indices = torch.isin(targets, soft_tokens, invert=True) 37 | loss = cross_entropy(outputs[indices], targets[indices], reduction="sum") 38 | 39 | # Calculate loss for soft tokens 40 | 
indices = torch.isin(targets, soft_tokens) 41 | targets_indices = torch.zeros_like(outputs[indices]) 42 | for k, target in enumerate(targets[indices]): 43 | dist = torch.exp(-((target - soft_tokens) ** 2) / (2 * std**2)) 44 | targets_indices[k][soft_tokens] = dist / dist.sum() 45 | loss += cross_entropy(outputs[indices], targets_indices, reduction="sum") 46 | 47 | # Return average loss 48 | return loss / targets.size(0) 49 | -------------------------------------------------------------------------------- /vlm_model/nvila_2b_ef8fa9c8/distributed.py: -------------------------------------------------------------------------------- 1 | import os 2 | import warnings 3 | from typing import Any, List, Optional 4 | 5 | from torch import distributed as dist 6 | 7 | __all__ = [ 8 | "init", 9 | "is_initialized", 10 | "size", 11 | "rank", 12 | "local_size", 13 | "local_rank", 14 | "is_main", 15 | "barrier", 16 | "gather", 17 | "all_gather", 18 | ] 19 | 20 | 21 | def init() -> None: 22 | if "RANK" not in os.environ: 23 | warnings.warn("Environment variable `RANK` is not set. Skipping distributed initialization.") 24 | return 25 | dist.init_process_group(backend="nccl", init_method="env://") 26 | 27 | 28 | def is_initialized() -> bool: 29 | return dist.is_initialized() 30 | 31 | 32 | def size() -> int: 33 | return int(os.environ.get("WORLD_SIZE", 1)) 34 | 35 | 36 | def rank() -> int: 37 | return int(os.environ.get("RANK", 0)) 38 | 39 | 40 | def local_size() -> int: 41 | return int(os.environ.get("LOCAL_WORLD_SIZE", 1)) 42 | 43 | 44 | def local_rank() -> int: 45 | return int(os.environ.get("LOCAL_RANK", 0)) 46 | 47 | 48 | def is_main() -> bool: 49 | return rank() == 0 50 | 51 | 52 | def barrier() -> None: 53 | dist.barrier() 54 | 55 | 56 | def gather(obj: Any, dst: int = 0) -> Optional[List[Any]]: 57 | if not is_initialized(): 58 | return [obj] 59 | if is_main(): 60 | objs = [None for _ in range(size())] 61 | dist.gather_object(obj, objs, dst=dst) 62 | return objs 63 | else: 64 | dist.gather_object(obj, dst=dst) 65 | return None 66 | 67 | 68 | def all_gather(obj: Any) -> List[Any]: 69 | if not is_initialized(): 70 | return [obj] 71 | objs = [None for _ in range(size())] 72 | dist.all_gather_object(objs, obj) 73 | return objs 74 | -------------------------------------------------------------------------------- /vlm_model/nvila_8b_e2481b0c/distributed.py: -------------------------------------------------------------------------------- 1 | import os 2 | import warnings 3 | from typing import Any, List, Optional 4 | 5 | from torch import distributed as dist 6 | 7 | __all__ = [ 8 | "init", 9 | "is_initialized", 10 | "size", 11 | "rank", 12 | "local_size", 13 | "local_rank", 14 | "is_main", 15 | "barrier", 16 | "gather", 17 | "all_gather", 18 | ] 19 | 20 | 21 | def init() -> None: 22 | if "RANK" not in os.environ: 23 | warnings.warn("Environment variable `RANK` is not set. 
Skipping distributed initialization.") 24 | return 25 | dist.init_process_group(backend="nccl", init_method="env://") 26 | 27 | 28 | def is_initialized() -> bool: 29 | return dist.is_initialized() 30 | 31 | 32 | def size() -> int: 33 | return int(os.environ.get("WORLD_SIZE", 1)) 34 | 35 | 36 | def rank() -> int: 37 | return int(os.environ.get("RANK", 0)) 38 | 39 | 40 | def local_size() -> int: 41 | return int(os.environ.get("LOCAL_WORLD_SIZE", 1)) 42 | 43 | 44 | def local_rank() -> int: 45 | return int(os.environ.get("LOCAL_RANK", 0)) 46 | 47 | 48 | def is_main() -> bool: 49 | return rank() == 0 50 | 51 | 52 | def barrier() -> None: 53 | dist.barrier() 54 | 55 | 56 | def gather(obj: Any, dst: int = 0) -> Optional[List[Any]]: 57 | if not is_initialized(): 58 | return [obj] 59 | if is_main(): 60 | objs = [None for _ in range(size())] 61 | dist.gather_object(obj, objs, dst=dst) 62 | return objs 63 | else: 64 | dist.gather_object(obj, dst=dst) 65 | return None 66 | 67 | 68 | def all_gather(obj: Any) -> List[Any]: 69 | if not is_initialized(): 70 | return [obj] 71 | objs = [None for _ in range(size())] 72 | dist.all_gather_object(objs, obj) 73 | return objs 74 | -------------------------------------------------------------------------------- /vlm_model/internvl_v2pe/internvl/patch/train_dataloader_patch.py: -------------------------------------------------------------------------------- 1 | import datasets 2 | import torch 3 | import transformers 4 | from torch.utils.data import DataLoader 5 | from transformers.trainer import is_datasets_available, seed_worker 6 | 7 | 8 | def get_train_dataloader(self) -> DataLoader: 9 | """ 10 | Returns the training [`~torch.utils.data.DataLoader`]. 11 | 12 | Will use no sampler if `train_dataset` does not implement `__len__`, a random sampler (adapted to distributed 13 | training if necessary) otherwise. 14 | 15 | Subclass and override this method if you want to inject some custom behavior. 
16 | """ 17 | if self.train_dataset is None: 18 | raise ValueError('Trainer: training requires a train_dataset.') 19 | 20 | train_dataset = self.train_dataset 21 | data_collator = self.data_collator 22 | if is_datasets_available() and isinstance(train_dataset, datasets.Dataset): 23 | train_dataset = self._remove_unused_columns(train_dataset, description='training') 24 | else: 25 | data_collator = self._get_collator_with_removed_columns(data_collator, description='training') 26 | 27 | dataloader_params = { 28 | 'batch_size': self._train_batch_size, 29 | 'collate_fn': data_collator, 30 | 'num_workers': self.args.dataloader_num_workers, 31 | 'pin_memory': self.args.dataloader_pin_memory, 32 | 'persistent_workers': self.args.dataloader_persistent_workers, 33 | } 34 | 35 | if not isinstance(train_dataset, torch.utils.data.IterableDataset): 36 | dataloader_params['sampler'] = self._get_train_sampler() 37 | dataloader_params['drop_last'] = self.args.dataloader_drop_last 38 | dataloader_params['worker_init_fn'] = seed_worker 39 | 40 | if self.args.use_packed_ds: 41 | return DataLoader(train_dataset, **dataloader_params) 42 | return self.accelerator.prepare(DataLoader(train_dataset, **dataloader_params)) 43 | 44 | 45 | def replace_train_dataloader(): 46 | transformers.Trainer.get_train_dataloader = get_train_dataloader 47 | print('Replace train dataloader!!') 48 | -------------------------------------------------------------------------------- /vlm_model/nvila_2b_ef8fa9c8/constants.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 NVIDIA CORPORATION & AFFILIATES 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # 15 | # SPDX-License-Identifier: Apache-2.0 16 | 17 | CONTROLLER_HEART_BEAT_EXPIRATION = 30 18 | WORKER_HEART_BEAT_INTERVAL = 15 19 | 20 | LOGDIR = "." 21 | 22 | # Model Constants 23 | IGNORE_INDEX = -100 24 | DEFAULT_IMAGE_TOKEN = "" 25 | 26 | SENTINEL_TOKEN = "" 27 | MEDIA_TOKENS = { 28 | "image": "", 29 | "video": "", 30 | } 31 | # 32 | # TODO(ligeng): need to discuss with Zhijian for the following tokens for different models. 
33 | """ 34 | 151643: AddedToken("<|endoftext|>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True), 35 | 151644: AddedToken("<|im_start|>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True), 36 | 151645: AddedToken("<|im_end|>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True), 37 | 151646: AddedToken("[BOS]", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True), 38 | 151647: AddedToken("[PAD]", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True), 39 | 151648: AddedToken("", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True), 40 | 151649: AddedToken("", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True), 41 | 151650: AddedToken("", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True), 42 | """ 43 | NUM_EXTRA_TOKENS = 8 44 | -------------------------------------------------------------------------------- /vlm_model/nvila_8b_e2481b0c/constants.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 NVIDIA CORPORATION & AFFILIATES 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # 15 | # SPDX-License-Identifier: Apache-2.0 16 | 17 | CONTROLLER_HEART_BEAT_EXPIRATION = 30 18 | WORKER_HEART_BEAT_INTERVAL = 15 19 | 20 | LOGDIR = "." 21 | 22 | # Model Constants 23 | IGNORE_INDEX = -100 24 | DEFAULT_IMAGE_TOKEN = "" 25 | 26 | SENTINEL_TOKEN = "" 27 | MEDIA_TOKENS = { 28 | "image": "", 29 | "video": "", 30 | } 31 | # 32 | # TODO(ligeng): need to discuss with Zhijian for the following tokens for different models. 
33 | """ 34 | 151643: AddedToken("<|endoftext|>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True), 35 | 151644: AddedToken("<|im_start|>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True), 36 | 151645: AddedToken("<|im_end|>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True), 37 | 151646: AddedToken("[BOS]", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True), 38 | 151647: AddedToken("[PAD]", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True), 39 | 151648: AddedToken("", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True), 40 | 151649: AddedToken("", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True), 41 | 151650: AddedToken("", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True), 42 | """ 43 | NUM_EXTRA_TOKENS = 8 44 | -------------------------------------------------------------------------------- /figure_scripts/13_vh_multi_difficulty.py: -------------------------------------------------------------------------------- 1 | from utils import * 2 | import matplotlib 3 | 4 | # plot specific ones in a row, for formatting in the paper 5 | lf_df = process_df(all_df) 6 | length_datasets = ["VH-Multi"] 7 | 8 | ncols = 2 9 | nrows = (len(length_datasets) - 1) // ncols + 1 10 | 11 | custom_cmap = LinearSegmentedColormap.from_list('custom_cmap', ['#FFFFFF', '#4A895B']) 12 | fig, ax = plt.subplots(ncols=ncols, nrows=nrows, sharey=True, sharex=False) 13 | fig.set_size_inches((ncols * 6, nrows * 5)) 14 | plt.rc('axes', unicode_minus=False) 15 | plt.rcParams.update({'axes.unicode_minus': False}) 16 | 17 | base_index_order = [ 18 | "GPT-4o", "Gemini-2.5-Pro", 19 | 'Qwen2.5-VL-7B-Inst', 'Qwen2.5-VL-32B-Inst', 'Qwen2.5-VL-72B-Inst', 20 | "Gemma3-4B", "Gemma3-12B", "Gemma3-27B", 21 | ] 22 | 23 | for i, dataset in enumerate(length_datasets): 24 | if nrows > 1: 25 | a = ax[i // ncols][i % ncols] 26 | else: 27 | a = ax[i] 28 | 29 | tdf = lf_df[lf_df.input_max_length > 4096] 30 | tdf = tdf.pivot_table(index="Model", columns="input_max_length", values=dataset) 31 | tdf = tdf.reindex(base_index_order) 32 | 33 | final_index_order = list(base_index_order) 34 | 35 | sns_g = sns.heatmap( 36 | tdf, annot=True, cmap=custom_cmap, fmt=".1f", yticklabels=True, 37 | ax=a, annot_kws={"fontsize": 22}, 38 | cbar=False 39 | ) 40 | sns_g.set_title(dataset, fontsize=22) 41 | 42 | sns_g.set_ylabel("") 43 | sns_g.set_xlabel("") 44 | 45 | written_index = [x.replace("-Inst", '') for x in final_index_order] 46 | sns_g.set_yticklabels(written_index, size=18, fontweight='bold') 47 | xticks_map = {"8192": '8k', "16384": '16k', "32768": '32k', "65536":'64k', "131072":'128k'} 48 | sns_g.set_xticklabels([xticks_map[st.get_text()] for st in sns_g.get_xticklabels()], size=22) 49 | 50 | # idx, start, end 51 | a.hlines([2, 5], 0, 6, color="0.95", linestyle="-", linewidth=3) 52 | 53 | [fig.delaxes(a) for a in ax.flatten() if not a.has_data()] 54 | 55 | plt.tight_layout() 56 | file_path = os.path.join(project_root, f"figures/13_vh_multi_difficulty.pdf") 57 | plt.savefig(file_path, dpi=500, format="pdf") 58 | plt.show() -------------------------------------------------------------------------------- /figure_scripts/13_vh_difficulty.py: -------------------------------------------------------------------------------- 1 | from utils import * 2 | import matplotlib 3 | 4 | # plot specific ones in a row, for formatting in the 
paper 5 | lf_df = process_df(all_df) 6 | length_datasets = ["VH-Single", "VH-Multi"] 7 | 8 | ncols = 2 9 | nrows = (len(length_datasets) - 1) // ncols + 1 10 | 11 | custom_cmap = LinearSegmentedColormap.from_list('custom_cmap', ['#FFFFFF', '#4A895B']) 12 | fig, ax = plt.subplots(ncols=ncols, nrows=nrows, sharey=True, sharex=False) 13 | fig.set_size_inches((ncols * 6 + 2, nrows * 5)) 14 | plt.rc('axes', unicode_minus=False) 15 | plt.rcParams.update({'axes.unicode_minus': False}) 16 | 17 | base_index_order = [ 18 | "GPT-4o", "Gemini-2.5-Pro", 19 | 'Qwen2.5-VL-7B-Inst', 'Qwen2.5-VL-32B-Inst', 'Qwen2.5-VL-72B-Inst', 20 | "Gemma3-4B", "Gemma3-12B", "Gemma3-27B", 21 | ] 22 | 23 | for i, dataset in enumerate(length_datasets): 24 | if nrows > 1: 25 | a = ax[i // ncols][i % ncols] 26 | else: 27 | a = ax[i] 28 | 29 | tdf = lf_df[lf_df.input_max_length > 4096] 30 | tdf = tdf.pivot_table(index="Model", columns="input_max_length", values=dataset) 31 | tdf = tdf.reindex(base_index_order) 32 | 33 | final_index_order = list(base_index_order) 34 | 35 | sns_g = sns.heatmap( 36 | tdf, annot=True, cmap=custom_cmap, fmt=".1f", yticklabels=True, 37 | ax=a, annot_kws={"fontsize": 22}, 38 | cbar=False 39 | ) 40 | sns_g.set_title(dataset, fontsize=26) 41 | 42 | sns_g.set_ylabel("") 43 | sns_g.set_xlabel("") 44 | 45 | written_index = [x.replace("-Inst", '') for x in final_index_order] 46 | sns_g.set_yticklabels(written_index, size=18, fontweight='bold') 47 | xticks_map = {"8192": '8k', "16384": '16k', "32768": '32k', "65536":'64k', "131072":'128k'} 48 | sns_g.set_xticklabels([xticks_map[st.get_text()] for st in sns_g.get_xticklabels()], size=22) 49 | 50 | # idx, start, end 51 | a.hlines([2, 5], 0, 6, color="0.95", linestyle="-", linewidth=3) 52 | 53 | [fig.delaxes(a) for a in ax.flatten() if not a.has_data()] 54 | 55 | plt.tight_layout() 56 | plt.subplots_adjust(left=0.18, wspace=0.28) 57 | file_path = os.path.join(project_root, f"figures/13_vh_difficulty.pdf") 58 | plt.savefig(file_path, dpi=500, format="pdf") 59 | plt.show() -------------------------------------------------------------------------------- /figure_scripts/1_main_most_models.py: -------------------------------------------------------------------------------- 1 | from utils import * 2 | 3 | main_table_datasets = [ 4 | "VRAG", 5 | "NIAH", 6 | 'ICL', 7 | "Summ", 8 | "DocVQA", 9 | "Ours" 10 | ] 11 | 12 | lf_df = process_df(all_df) 13 | length_datasets = main_table_datasets 14 | 15 | ncols = 3 16 | nrows = (len(length_datasets) - 1) // ncols + 1 17 | 18 | custom_cmap = LinearSegmentedColormap.from_list('custom_cmap', ['#FFFFFF', '#4A895B']) 19 | fig, ax = plt.subplots(ncols=ncols, nrows=nrows, sharey=True, sharex=False) 20 | fig.set_size_inches((ncols * 8, 26)) # helmet has 22 models for 26 height, we have 23 models 21 | 22 | plt.rc('axes', unicode_minus=False) 23 | plt.rcParams.update({'axes.unicode_minus': False}) 24 | 25 | for i, dataset in enumerate(length_datasets): 26 | if nrows > 1: 27 | a = ax[i // ncols][i % ncols] 28 | else: 29 | a = ax[i] 30 | 31 | new_index = main_table_models 32 | 33 | tdf = lf_df[lf_df.input_max_length > 4096] 34 | tdf = tdf.pivot_table(index="Model", columns="input_max_length", values=dataset) 35 | tdf = tdf.reindex(new_index) 36 | 37 | # process the scores 38 | annot_matrix = tdf.copy() 39 | tdf = tdf.applymap(lambda x: x if not pd.isna(x) else 0) 40 | annot_matrix = annot_matrix.applymap(lambda x: "N/A" if pd.isna(x) else f"{x:.1f}") 41 | 42 | 43 | sns_g = sns.heatmap( 44 | tdf, annot=annot_matrix, cmap=custom_cmap, 
fmt="", yticklabels=True, 45 | ax=a, annot_kws={"fontsize": 23.5}, 46 | cbar=False 47 | ) 48 | sns_g.set_title(dataset if dataset != "Ours" else "Avg.", fontsize=34) 49 | 50 | sns_g.set_ylabel("") 51 | sns_g.set_xlabel("") 52 | 53 | new_index = [x.replace("-Inst", '') for x in new_index] 54 | sns_g.set_yticklabels(new_index, size=28) 55 | 56 | xticks = ['8k', '16k', '32k', '64k', '128k'] 57 | sns_g.set_xticklabels(xticks, size=28) 58 | 59 | for idx in [6, 10, 14, 17, 19]: 60 | a.axhline(y=idx, color="white", linestyle="-", linewidth=4) 61 | 62 | [fig.delaxes(a) for a in ax.flatten() if not a.has_data()] 63 | 64 | 65 | plt.tight_layout() 66 | plt.subplots_adjust(left=0.17, wspace=0.15) 67 | plt.savefig(os.path.join(project_root, f"figures/1_results_length_main.pdf"), dpi=500, format="pdf") 68 | plt.show() -------------------------------------------------------------------------------- /figure_scripts/14_needle_modal_difficulty.py: -------------------------------------------------------------------------------- 1 | from utils import * 2 | import matplotlib 3 | 4 | # plot specific ones in a row, for formatting in the paper 5 | lf_df = process_df(all_df) 6 | length_datasets = ['MM-NIAH-Ret (T)', 'MM-NIAH-Ret (I)', 'MM-NIAH-Count (T)', 'MM-NIAH-Count (I)', 'MM-NIAH-Reason (T)', 'MM-NIAH-Reason (I)'] 7 | 8 | ncols = 3 9 | nrows = (len(length_datasets) - 1) // ncols + 1 10 | 11 | custom_cmap = LinearSegmentedColormap.from_list('custom_cmap', ['#FFFFFF', '#4A895B']) 12 | fig, ax = plt.subplots(ncols=ncols, nrows=nrows, sharey=True, sharex=False) 13 | fig.set_size_inches((ncols * 8, nrows * 5)) 14 | plt.rc('axes', unicode_minus=False) 15 | plt.rcParams.update({'axes.unicode_minus': False}) 16 | 17 | base_index_order = [ 18 | "GPT-4o", "Gemini-2.5-Pro", 19 | 'Qwen2.5-VL-32B-Inst', 'Qwen2.5-VL-72B-Inst', 20 | "Ovis2-16B", "Ovis2-34B", 21 | "Gemma3-12B", "Gemma3-27B", 22 | ] 23 | 24 | for i, dataset in enumerate(length_datasets): 25 | if nrows > 1: 26 | a = ax[i // ncols][i % ncols] 27 | else: 28 | a = ax[i] 29 | 30 | tdf = lf_df[lf_df.input_max_length > 4096] 31 | tdf = tdf.pivot_table(index="Model", columns="input_max_length", values=dataset) 32 | tdf = tdf.reindex(base_index_order) 33 | 34 | final_index_order = list(base_index_order) 35 | 36 | sns_g = sns.heatmap( 37 | tdf, annot=True, cmap=custom_cmap, fmt=".1f", yticklabels=True, 38 | ax=a, annot_kws={"fontsize": 22}, 39 | cbar=False 40 | ) 41 | sns_g.set_title(dataset, fontsize=30) 42 | 43 | sns_g.set_ylabel("") 44 | sns_g.set_xlabel("") 45 | 46 | written_index = [x.replace("-Inst", '') for x in final_index_order] 47 | sns_g.set_yticklabels(written_index, size=26) 48 | xticks_map = {"8192": '8k', "16384": '16k', "32768": '32k', "65536":'64k', "131072":'128k'} 49 | sns_g.set_xticklabels([xticks_map[st.get_text()] for st in sns_g.get_xticklabels()], size=26) 50 | 51 | # idx, start, end 52 | a.hlines([2], 0, 6, color="0.95", linestyle="-", linewidth=3) 53 | 54 | [fig.delaxes(a) for a in ax.flatten() if not a.has_data()] 55 | 56 | plt.tight_layout() 57 | plt.subplots_adjust(left=0.15, wspace=0.3) 58 | file_path = os.path.join(project_root, f"figures/14_needle_modal_difficulty.pdf") 59 | plt.savefig(file_path, dpi=500, format="pdf") 60 | plt.show() -------------------------------------------------------------------------------- /figure_scripts/2_main_full_models.py: -------------------------------------------------------------------------------- 1 | from utils import * 2 | 3 | 4 | # Figure: Full models 5 | 6 | 7 | main_table_datasets = [ 8 | "VRAG", 
9 | "NIAH", 10 | 'ICL', 11 | "Summ", 12 | "DocVQA", 13 | "Ours" 14 | ] 15 | 16 | # plot specific ones in a row, for formatting in the paper 17 | lf_df = process_df(all_df) 18 | length_datasets = main_table_datasets 19 | 20 | ncols = 3 21 | nrows = (len(length_datasets) - 1) // ncols + 1 22 | custom_cmap = LinearSegmentedColormap.from_list('custom_cmap', ['#FFFFFF', '#4A895B']) 23 | 24 | fig, ax = plt.subplots(ncols=ncols, nrows=nrows, sharey=True, sharex=False) 25 | fig.set_size_inches((ncols * 8, 40)) # helmet has 45 models for 40 height, we have 46 models 26 | 27 | plt.rc('axes', unicode_minus=False) 28 | plt.rcParams.update({'axes.unicode_minus': False}) 29 | 30 | for i, dataset in enumerate(length_datasets): 31 | if nrows > 1: 32 | a = ax[i // ncols][i % ncols] 33 | else: 34 | a = ax[i] 35 | 36 | new_index = full_table_models 37 | 38 | tdf = lf_df[lf_df.input_max_length > 4096] 39 | tdf = tdf.pivot_table(index="Model", columns="input_max_length", values=dataset) 40 | tdf = tdf.reindex(new_index) 41 | 42 | # process the scores 43 | annot_matrix = tdf.copy() 44 | tdf = tdf.applymap(lambda x: x if not pd.isna(x) else 0) 45 | annot_matrix = annot_matrix.applymap(lambda x: "N/A" if pd.isna(x) else f"{x:.1f}") 46 | 47 | sns_g = sns.heatmap( 48 | tdf, annot=annot_matrix, cmap=custom_cmap, fmt="", yticklabels=True, 49 | ax=a, annot_kws={"fontsize": 23.5}, 50 | cbar=False 51 | ) 52 | sns_g.set_title(dataset if dataset != "Ours" else "Avg.", fontsize=34) 53 | 54 | sns_g.set_ylabel("") 55 | sns_g.set_xlabel("") 56 | 57 | new_index = [x.replace("-Inst", '') for x in new_index] 58 | sns_g.set_yticklabels(new_index, size=28) 59 | 60 | xticks = ['8k', '16k', '32k', '64k', '128k'] 61 | sns_g.set_xticklabels(xticks, size=28) 62 | 63 | for idx in [6, 13, 27, 33, 36, 40, 43, 45]: 64 | a.axhline(y=idx, color="white", linestyle="-", linewidth=4) 65 | 66 | [fig.delaxes(a) for a in ax.flatten() if not a.has_data()] 67 | 68 | plt.tight_layout() 69 | plt.subplots_adjust(left=0.17, wspace=0.30) 70 | figure_path = os.path.join(project_root, f"figures/2_results_length_full.pdf") 71 | plt.savefig(figure_path, dpi=500, format="pdf") 72 | plt.show() 73 | -------------------------------------------------------------------------------- /figure_scripts/6_correlation_NIAH_all.py: -------------------------------------------------------------------------------- 1 | from utils import * 2 | 3 | # break into different lengths and then plot the pairwise heatmap 4 | lf_df = process_df(all_df) 5 | 6 | cmap = LinearSegmentedColormap.from_list("custom_cmap", ["#ed4d6e", '#DD9380', '#DEA683', '#CFCC86', "#0CD79F"]) 7 | datasets2 = ['VRAG', 'ICL', 'Summ', 'DocVQA'] 8 | datasets1 = ['VH-Single', 'VH-Multi', 'MM-NIAH-Ret (T)', 'MM-NIAH-Ret (I)', 'MM-NIAH-Ret', 'MM-NIAH-Count (T)', 'MM-NIAH-Count (I)', 'MM-NIAH-Count', 9 | 'MM-NIAH-Reason (T)', 'MM-NIAH-Reason (I)', 'MM-NIAH-Reason', 'NIAH'] 10 | 11 | all_corr = {} 12 | lengths = [131072] 13 | fig, ax = plt.subplots(figsize=(2.5 + 0.7 * (len(datasets2) + 1), 0.7 * len(datasets1)), nrows=1, ncols=len(lengths)) 14 | 15 | for i, (l, d) in enumerate(lf_df.groupby("input_max_length", sort=False)): 16 | if l not in lengths: 17 | continue 18 | i = lengths.index(l) 19 | spearmans = [] 20 | for d1 in datasets1: 21 | for d2 in datasets2: 22 | x = d[d[d1].notnull() & d[d2].notnull()] 23 | m1 = x[d1] 24 | m2 = x[d2] 25 | 26 | if len(m1) < 2 and len(m2) < 2: 27 | continue 28 | 29 | rho, p = stats.spearmanr(m1, m2) 30 | spearmans.append({"dataset 1": d1, "dataset 2": d2, "correlation": rho}) 31 | 
32 | all_corr[l] = {"spearman": pd.DataFrame(spearmans)} 33 | for j, (name, table) in enumerate(all_corr[l].items()): 34 | hm = table.pivot_table(index="dataset 1", columns="dataset 2", values="correlation", sort=False) 35 | a = ax[i] if len(lengths) > 1 else ax 36 | hm["Avg"] = hm.mean(axis=1) 37 | def fmt(x): 38 | return "0.00" if abs(x) < 1e-2 else f"{x:.2f}" 39 | annots = np.vectorize(fmt)(hm) 40 | 41 | sns_g = sns.heatmap(hm, annot=annots, ax=a, cbar=False, cmap=cmap, annot_kws={"fontsize": 13.5}, fmt="") 42 | 43 | sns_g.set_ylabel("") 44 | sns_g.set_xlabel("") 45 | 46 | sns_g.set_yticklabels(datasets1, size=18) 47 | sns_g.set_xticklabels(sns_g.get_xticklabels(), size=18) 48 | 49 | a.tick_params(axis='x', rotation=45) 50 | a.tick_params(axis='y', rotation=0) 51 | 52 | plt.setp(a.get_xticklabels(), ha="right", rotation_mode="anchor") 53 | plt.tight_layout() 54 | figure_path = os.path.join(project_root, "figures/6_correlation_NIAH_all.pdf") 55 | plt.savefig(figure_path, dpi=300, format="pdf") 56 | plt.show() -------------------------------------------------------------------------------- /figure_scripts/5_correlation_NIAH_most.py: -------------------------------------------------------------------------------- 1 | from utils import * 2 | 3 | # break into different lengths and then plot the pairwise heatmap 4 | lf_df = process_df(all_df) #, chosen_models=main_table_models) 5 | # synthetic analysis 6 | datasets2 = [ 7 | 'VRAG', 'ICL', 'Summ', 'DocVQA' 8 | ] 9 | datasets1 = ["VH-Single", "VH-Multi", 'MM-NIAH-Ret', 'MM-NIAH-Count', 'MM-NIAH-Reason'] 10 | 11 | cmap = LinearSegmentedColormap.from_list("custom_cmap", ["#ed4d6e", '#DD9380', '#DEA683', '#CFCC86', "#0CD79F"]) 12 | 13 | all_corr = {} 14 | lengths = [131072] 15 | fig, ax = plt.subplots(figsize=(4.7, 4.8), nrows=1, ncols=len(lengths)) 16 | 17 | for i, (l, d) in enumerate(lf_df.groupby("input_max_length", sort=False)): 18 | if l not in lengths: 19 | continue 20 | i = lengths.index(l) 21 | 22 | spearmans = [] 23 | for d1 in datasets1: 24 | for d2 in datasets2: 25 | x = d[d[d1].notnull() & d[d2].notnull()] 26 | m1 = x[d1] 27 | m2 = x[d2] 28 | 29 | if len(m1) < 2 and len(m2) < 2: 30 | continue 31 | 32 | rho, p = stats.spearmanr(m1, m2) 33 | spearmans.append({"dataset 1": d1, "dataset 2": d2, "correlation": rho}) 34 | 35 | all_corr[l] = {"spearman": pd.DataFrame(spearmans)} 36 | for j, (name, table) in enumerate(all_corr[l].items()): 37 | hm = table.pivot_table(index="dataset 1", columns="dataset 2", values="correlation", sort=False) 38 | a = ax[i] if len(lengths) > 1 else ax 39 | # hm["Avg"] = hm.mean(axis=1) 40 | def fmt(x): 41 | return "0.00" if abs(x) < 1e-2 else f"{x:.2f}" 42 | 43 | annots = np.vectorize(fmt)(hm) 44 | sns_g = sns.heatmap(hm, annot=annots, ax=a, cbar=False, cmap=cmap, annot_kws={"fontsize": 13.5}, fmt="") 45 | 46 | sns_g.set_ylabel("") 47 | sns_g.set_xlabel("") 48 | 49 | ylabels = [n.replace("MM-", "") for n in datasets1] 50 | sns_g.set_yticklabels(ylabels, size=14) 51 | sns_g.set_xticklabels(sns_g.get_xticklabels(), size=14) 52 | 53 | a.tick_params(axis='x', rotation=45) 54 | a.tick_params(axis='y', rotation=0) 55 | 56 | # for idx in [1, 2, 3, 4, 5]: 57 | # a.axhline(idx, color="white", linestyle="-", linewidth=3) 58 | 59 | plt.setp(a.get_xticklabels(), ha="right", rotation_mode="anchor") 60 | plt.tight_layout() 61 | figure_path = os.path.join(project_root, "figures/5_correlation_NIAH_most.pdf") 62 | plt.savefig(figure_path, dpi=300, format="pdf") 63 | plt.show() 
-------------------------------------------------------------------------------- /figure_scripts/8_correlation_all_datasets.py: -------------------------------------------------------------------------------- 1 | from utils import * 2 | 3 | # break into different lengths and then plot the pairwise heatmap 4 | lf_df = process_df(all_df) 5 | datasets = ["InfoSeek", "ViQuAE", "VRAG", 6 | "VH-Single", "VH-Multi", 'MM-NIAH-Ret', "MM-NIAH-Count", "MM-NIAH-Reason", 'NIAH', 7 | "Stanford Cars", "Food101", "SUN397", "Inat2021", 'ICL', 8 | "GovReport", "Multi-LexSum", 'Summ', 9 | "MMLongBench-Doc", "LongDocURL", "SlideVQA", 'DocVQA', 10 | 'Ours'] 11 | 12 | datasets1, datasets2 = datasets, datasets 13 | 14 | cmap = LinearSegmentedColormap.from_list("custom_cmap", ["#ed4d6e", '#DD9380', '#DEA683', '#CFCC86', "#0CD79F"]) 15 | 16 | all_corr = {} 17 | lengths = [131072] 18 | fig, ax = plt.subplots(figsize=(len(datasets), len(datasets)), nrows=1, ncols=len(lengths)) 19 | 20 | for i, (l, d) in enumerate(lf_df.groupby("input_max_length", sort=False)): 21 | if l not in lengths: 22 | continue 23 | i = lengths.index(l) 24 | spearmans = [] 25 | for d1 in datasets1: 26 | for d2 in datasets2: 27 | x = d[d[d1].notnull() & d[d2].notnull()] 28 | m1 = x[d1] 29 | m2 = x[d2] 30 | 31 | if len(m1) < 2 and len(m2) < 2: 32 | continue 33 | 34 | rho, p = stats.spearmanr(m1, m2) 35 | spearmans.append({"dataset 1": d1, "dataset 2": d2, "correlation": rho}) 36 | 37 | all_corr[l] = {"spearman": pd.DataFrame(spearmans)} 38 | for j, (name, table) in enumerate(all_corr[l].items()): 39 | 40 | hm = table.pivot_table(index="dataset 1", columns="dataset 2", values="correlation", sort=False) 41 | a = ax[i] if len(lengths) > 1 else ax 42 | def fmt(x): 43 | return "0.00" if abs(x) < 1e-2 else f"{x:.2f}" 44 | annots = np.vectorize(fmt)(hm) 45 | sns_g = sns.heatmap(hm, annot=annots, ax=a, cbar=False, cmap=cmap, annot_kws={"fontsize": 19}, fmt="") 46 | 47 | sns_g.set_ylabel("") 48 | sns_g.set_xlabel("") 49 | 50 | t = datasets 51 | sns_g.set_yticklabels(t, size=24) 52 | sns_g.set_xticklabels(t, size=24) 53 | 54 | a.tick_params(axis='x', rotation=45) 55 | a.tick_params(axis='y', rotation=0) 56 | 57 | for idx in [3, 9, 14, 17, 21]: 58 | a.axvline(x=idx, color="white", linestyle="-", linewidth=3) 59 | a.axhline(y=idx, color="white", linestyle="-", linewidth=3) 60 | 61 | plt.setp(a.get_xticklabels(), ha="right", rotation_mode="anchor") 62 | plt.tight_layout() 63 | figure_path = os.path.join(project_root, "figures/8_correlation_all_dataset.pdf") 64 | plt.savefig(figure_path, dpi=300, format="pdf") 65 | plt.show() -------------------------------------------------------------------------------- /figure_scripts/4_main_full_models_split.py: -------------------------------------------------------------------------------- 1 | from utils import * 2 | 3 | 4 | # Figure: Full models 5 | 6 | 7 | main_table_datasets = [ 8 | "VRAG", 9 | "NIAH", 10 | 'ICL', 11 | "Summ", 12 | "DocVQA", 13 | "Ours" 14 | ] 15 | 16 | # plot specific ones in a row, for formatting in the paper 17 | lf_df = process_df(all_df) 18 | length_datasets = main_table_datasets 19 | 20 | tasks_per_row = 3 21 | total_tasks = len(length_datasets) 22 | rows_needed = (total_tasks + tasks_per_row - 1) // tasks_per_row 23 | custom_cmap = LinearSegmentedColormap.from_list('custom_cmap', ['#FFFFFF', '#4A895B']) 24 | 25 | for row in range(rows_needed): 26 | start_idx = row * tasks_per_row 27 | end_idx = min((row + 1) * tasks_per_row, total_tasks) 28 | current_datasets = 
length_datasets[start_idx:end_idx] 29 | 30 | fig, ax = plt.subplots(ncols=tasks_per_row, nrows=1, sharey=True, sharex=False) 31 | fig.set_size_inches((tasks_per_row * 7.5, 20.5)) # helmet has 45 models for height 20, we have 47 models 32 | 33 | plt.rc('axes', unicode_minus=False) 34 | plt.rcParams.update({'axes.unicode_minus': False}) 35 | 36 | for i, dataset in enumerate(current_datasets): 37 | a = ax[i] 38 | 39 | new_index = full_table_models 40 | 41 | tdf = lf_df[lf_df.input_max_length > 4096] 42 | tdf = tdf.pivot_table(index="Model", columns="input_max_length", values=dataset) 43 | tdf = tdf.reindex(new_index) 44 | 45 | # process the scores 46 | annot_matrix = tdf.copy() 47 | tdf = tdf.applymap(lambda x: x if not pd.isna(x) else 0) 48 | annot_matrix = annot_matrix.applymap(lambda x: "N/A" if pd.isna(x) else f"{x:.1f}") 49 | 50 | sns_g = sns.heatmap( 51 | tdf, annot=annot_matrix, cmap=custom_cmap, fmt="", yticklabels=True, 52 | ax=a, annot_kws={"fontsize": 23.5}, 53 | cbar=False 54 | ) 55 | sns_g.set_title(dataset if dataset != "Ours" else "Avg.", fontsize=34) 56 | 57 | sns_g.set_ylabel("") 58 | sns_g.set_xlabel("") 59 | 60 | new_index = [x.replace("-Inst", '') for x in new_index] 61 | sns_g.set_yticklabels(new_index, size=28) 62 | 63 | xticks = ['8k', '16k', '32k', '64k', '128k'] 64 | sns_g.set_xticklabels(xticks, size=28) 65 | 66 | for idx in [6, 13, 27, 33, 36, 40, 43, 45]: 67 | a.axhline(y=idx, color="white", linestyle="-", linewidth=4) 68 | 69 | [fig.delaxes(a) for a in np.atleast_1d(ax).flatten() if not a.has_data()] 70 | 71 | plt.tight_layout() 72 | 73 | figure_path = os.path.join(project_root, f"figures/4_results_length_full_row{row+1}.pdf") 74 | plt.savefig(figure_path, dpi=500, format="pdf") 75 | plt.close(fig) 76 | 77 | 78 | -------------------------------------------------------------------------------- /vlm_model/mplug_owl3.py: -------------------------------------------------------------------------------- 1 | from transformers import AutoConfig, AutoTokenizer, AutoModel 2 | from .model_utils import LLM 3 | 4 | 5 | from transformers.cache_utils import DynamicCache 6 | if not hasattr(DynamicCache, 'get_max_length'): 7 | # set get_max_length as the alias of get_max_cache_shape 8 | DynamicCache.get_max_length = DynamicCache.get_max_cache_shape 9 | print("Added get_max_length method to DynamicCache for backward compatibility") 10 | 11 | from PIL import Image 12 | 13 | 14 | class mPLUGOwl3Model(LLM): 15 | def __init__( 16 | self, 17 | model_name, 18 | temperature=0.9, 19 | top_p=0.9, 20 | max_length=32768, 21 | generation_max_length=2048, 22 | generation_min_length=0, 23 | do_sample=True, 24 | stop_newline=False, 25 | use_chat_template=False, 26 | **kwargs, 27 | ): 28 | super().__init__( 29 | model_name, 30 | temperature=temperature, 31 | top_p=top_p, 32 | max_length=max_length, 33 | generation_max_length=generation_max_length, 34 | generation_min_length=generation_min_length, 35 | do_sample=do_sample, 36 | stop_newline=stop_newline, 37 | use_chat_template=use_chat_template, 38 | ) 39 | 40 | 41 | self.tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True) 42 | self.model = AutoModel.from_pretrained(model_name, trust_remote_code=True) 43 | self.config = AutoConfig.from_pretrained(model_name, trust_remote_code=True) 44 | # print(self.config) 45 | 46 | 47 | if self.tokenizer.pad_token is None: 48 | self.tokenizer.pad_token = self.tokenizer.eos_token 49 | self.tokenizer.pad_token_id = self.tokenizer.eos_token_id 50 | self.tokenizer.truncation_side = "left" 51 
| self.tokenizer.padding_side = "left" 52 | 53 | # self.processor = AutoProcessor.from_pretrained(model_name, trust_remote_code=True) 54 | self.processor = self.model.init_processor(self.tokenizer) 55 | 56 | 57 | def format_chat(self, text, image_list, system_prompt): 58 | formatted_text = text.replace("", "<|image|>") 59 | messages = [{"role": "user", "content": formatted_text},{"role": "assistant", "content": ""}] 60 | return messages 61 | 62 | 63 | def prepare_inputs(self, test_item, data): 64 | text = data["user_template"].format(**test_item) 65 | image_list = test_item["image_list"] 66 | # import pdb; pdb.set_trace() 67 | 68 | # Convert file paths to PIL Images 69 | pil_images = [Image.open(img_path).convert("RGB") for img_path in image_list] 70 | 71 | messages = self.format_chat(text, image_list, data["system_template"]) 72 | inputs = self.processor( 73 | messages, 74 | images=pil_images, 75 | return_tensors="pt" 76 | ) 77 | return inputs -------------------------------------------------------------------------------- /vlm_model/nvila_2b_ef8fa9c8/configuration_vila.py: -------------------------------------------------------------------------------- 1 | import json 2 | import math 3 | import os 4 | import os.path as osp 5 | from copy import deepcopy 6 | from threading import Thread 7 | from typing import List, Optional 8 | 9 | import torch 10 | import torchvision 11 | from PIL import Image 12 | from transformers import ( 13 | AutoProcessor, 14 | PretrainedConfig, 15 | PreTrainedModel, 16 | Qwen2Config, 17 | Qwen2ForCausalLM, 18 | Qwen2PreTrainedModel, 19 | TextIteratorStreamer, 20 | ) 21 | 22 | 23 | class VILAConfig(PretrainedConfig): 24 | model_type = "vila" 25 | keys_to_ignore_at_inference = ["past_key_values"] 26 | 27 | def __init__( 28 | self, 29 | llm_cfg=None, 30 | vision_tower_cfg=None, 31 | mm_projector_cfg=None, 32 | architectures=None, 33 | resume_path=None, 34 | hidden_size=None, 35 | mm_hidden_size=None, 36 | image_aspect_ratio=None, 37 | num_video_frames=None, 38 | fps=None, 39 | mm_vision_select_layer=None, 40 | mm_vision_select_feature=None, 41 | mm_use_im_start_end=False, 42 | mm_use_im_patch_token=False, 43 | mm_projector_lr=None, 44 | vision_tower_lr=None, 45 | vision_resolution=None, 46 | interpolate_mode=None, 47 | s2=None, 48 | dynamic_s2=None, 49 | s2_scales=None, 50 | s2_max_split_size=None, 51 | s2_resize_output_to_scale_idx=0, 52 | min_tiles: Optional[int] = 1, 53 | max_tiles: Optional[int] = 12, 54 | num_time_tokens=None, 55 | time_token_format=None, 56 | image_encoder: str = '{"_target_": "llava.model.encoders.BasicImageEncoder"}', 57 | video_encoder: str = '{"_target_": "llava.model.encoders.BasicVideoEncoder"}', 58 | **kwargs, 59 | ): 60 | super().__init__(**kwargs) 61 | 62 | self.architectures = architectures 63 | self.llm_cfg = llm_cfg 64 | self.vision_tower_cfg = vision_tower_cfg 65 | self.mm_projector_cfg = mm_projector_cfg 66 | self.resume_path = resume_path 67 | 68 | self.hidden_size = hidden_size 69 | self.mm_hidden_size = mm_hidden_size 70 | self.image_aspect_ratio = image_aspect_ratio 71 | self.num_video_frames = num_video_frames 72 | self.fps = fps 73 | self.mm_vision_select_layer = mm_vision_select_layer 74 | self.mm_vision_select_feature = mm_vision_select_feature 75 | self.mm_use_im_start_end = mm_use_im_start_end 76 | self.mm_use_im_patch_token = mm_use_im_patch_token 77 | self.mm_projector_lr = mm_projector_lr 78 | self.vision_tower_lr = vision_tower_lr 79 | self.vision_resolution = vision_resolution 80 | self.interpolate_mode = 
interpolate_mode 81 | self.s2 = s2 82 | self.dynamic_s2 = dynamic_s2 83 | self.s2_scales = s2_scales 84 | self.s2_max_split_size = s2_max_split_size 85 | self.s2_resize_output_to_scale_idx = s2_resize_output_to_scale_idx 86 | self.min_tiles = min_tiles 87 | self.max_tiles = max_tiles 88 | self.num_time_tokens = num_time_tokens 89 | self.time_token_format = time_token_format 90 | 91 | self.image_encoder = image_encoder 92 | self.video_encoder = video_encoder 93 | 94 | -------------------------------------------------------------------------------- /vlm_model/nvila_8b_e2481b0c/configuration_vila.py: -------------------------------------------------------------------------------- 1 | import json 2 | import math 3 | import os 4 | import os.path as osp 5 | from copy import deepcopy 6 | from threading import Thread 7 | from typing import List, Optional 8 | 9 | import torch 10 | import torchvision 11 | from PIL import Image 12 | from transformers import ( 13 | AutoProcessor, 14 | PretrainedConfig, 15 | PreTrainedModel, 16 | Qwen2Config, 17 | Qwen2ForCausalLM, 18 | Qwen2PreTrainedModel, 19 | TextIteratorStreamer, 20 | ) 21 | 22 | 23 | class VILAConfig(PretrainedConfig): 24 | model_type = "vila" 25 | keys_to_ignore_at_inference = ["past_key_values"] 26 | 27 | def __init__( 28 | self, 29 | llm_cfg=None, 30 | vision_tower_cfg=None, 31 | mm_projector_cfg=None, 32 | architectures=None, 33 | resume_path=None, 34 | hidden_size=None, 35 | mm_hidden_size=None, 36 | image_aspect_ratio=None, 37 | num_video_frames=None, 38 | fps=None, 39 | mm_vision_select_layer=None, 40 | mm_vision_select_feature=None, 41 | mm_use_im_start_end=False, 42 | mm_use_im_patch_token=False, 43 | mm_projector_lr=None, 44 | vision_tower_lr=None, 45 | vision_resolution=None, 46 | interpolate_mode=None, 47 | s2=None, 48 | dynamic_s2=None, 49 | s2_scales=None, 50 | s2_max_split_size=None, 51 | s2_resize_output_to_scale_idx=0, 52 | min_tiles: Optional[int] = 1, 53 | max_tiles: Optional[int] = 12, 54 | num_time_tokens=None, 55 | time_token_format=None, 56 | image_encoder: str = '{"_target_": "llava.model.encoders.BasicImageEncoder"}', 57 | video_encoder: str = '{"_target_": "llava.model.encoders.BasicVideoEncoder"}', 58 | **kwargs, 59 | ): 60 | super().__init__(**kwargs) 61 | 62 | self.architectures = architectures 63 | self.llm_cfg = llm_cfg 64 | self.vision_tower_cfg = vision_tower_cfg 65 | self.mm_projector_cfg = mm_projector_cfg 66 | self.resume_path = resume_path 67 | 68 | self.hidden_size = hidden_size 69 | self.mm_hidden_size = mm_hidden_size 70 | self.image_aspect_ratio = image_aspect_ratio 71 | self.num_video_frames = num_video_frames 72 | self.fps = fps 73 | self.mm_vision_select_layer = mm_vision_select_layer 74 | self.mm_vision_select_feature = mm_vision_select_feature 75 | self.mm_use_im_start_end = mm_use_im_start_end 76 | self.mm_use_im_patch_token = mm_use_im_patch_token 77 | self.mm_projector_lr = mm_projector_lr 78 | self.vision_tower_lr = vision_tower_lr 79 | self.vision_resolution = vision_resolution 80 | self.interpolate_mode = interpolate_mode 81 | self.s2 = s2 82 | self.dynamic_s2 = dynamic_s2 83 | self.s2_scales = s2_scales 84 | self.s2_max_split_size = s2_max_split_size 85 | self.s2_resize_output_to_scale_idx = s2_resize_output_to_scale_idx 86 | self.min_tiles = min_tiles 87 | self.max_tiles = max_tiles 88 | self.num_time_tokens = num_time_tokens 89 | self.time_token_format = time_token_format 90 | 91 | self.image_encoder = image_encoder 92 | self.video_encoder = video_encoder 93 | 94 | 
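Note: the two configuration_vila.py files above (nvila_2b_ef8fa9c8 and nvila_8b_e2481b0c) define the same VILAConfig schema. A minimal sketch of how such a config can be constructed and round-tripped through the Hugging Face Auto classes is shown below; the override values and the local output directory are illustrative assumptions, not values shipped with the NVILA checkpoints, and the import assumes the repository root is on PYTHONPATH.

from transformers import AutoConfig
from vlm_model.nvila_8b_e2481b0c.configuration_vila import VILAConfig

# direct construction with a few overrides; unspecified fields keep the defaults above
cfg = VILAConfig(image_aspect_ratio="dynamic", min_tiles=1, max_tiles=12)
cfg.save_pretrained("./vila_config_demo")  # writes config.json with model_type="vila"

# registering the type lets AutoConfig resolve "vila" locally without trust_remote_code;
# skip this step if "vila" is already registered in your environment
AutoConfig.register("vila", VILAConfig)
reloaded = AutoConfig.from_pretrained("./vila_config_demo")
assert reloaded.max_tiles == 12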
-------------------------------------------------------------------------------- /figure_scripts/11_task_diffculty.py: -------------------------------------------------------------------------------- 1 | from utils import * 2 | import matplotlib 3 | 4 | # plot specific ones in a row, for formatting in the paper 5 | lf_df = process_df(all_df) 6 | length_datasets = ["ICL", "VRAG", "NIAH", "DocVQA"] 7 | 8 | ncols = 4 9 | nrows = (len(length_datasets) - 1) // ncols + 1 10 | 11 | custom_cmap = LinearSegmentedColormap.from_list('custom_cmap', ['#FFFFFF', '#4A895B']) 12 | fig, ax = plt.subplots(ncols=ncols, nrows=nrows, sharey=True, sharex=False) 13 | fig.set_size_inches((ncols * 6, nrows * 5)) 14 | plt.rc('axes', unicode_minus=False) 15 | plt.rcParams.update({'axes.unicode_minus': False}) 16 | 17 | diff_pairs = [ 18 | ('InternVL3-38B', 'InternVL3-8B', 'Diff (8B$\\rightarrow$38B)'), 19 | ('Gemma3-27B', 'Gemma3-12B', 'Diff (12B$\\rightarrow$27B)') 20 | ] 21 | 22 | base_index_order = [ 23 | "GPT-4o", "Gemini-2.5-Pro", 24 | 'InternVL3-8B', 'InternVL3-38B', 25 | "Gemma3-12B", "Gemma3-27B", 26 | ] 27 | 28 | for i, dataset in enumerate(length_datasets): 29 | if nrows > 1: 30 | a = ax[i // ncols][i % ncols] 31 | else: 32 | a = ax[i] 33 | 34 | tdf = lf_df[lf_df.input_max_length > 4096] 35 | tdf = tdf.pivot_table(index="Model", columns="input_max_length", values=dataset) 36 | tdf = tdf.reindex(base_index_order) 37 | 38 | final_index_order = list(base_index_order) 39 | diff_data_to_add = {} 40 | 41 | for model_large, model_small, diff_name in diff_pairs: 42 | if model_large in tdf.index and model_small in tdf.index: 43 | diff_series = tdf.loc[model_large] - tdf.loc[model_small] 44 | diff_data_to_add[diff_name] = diff_series 45 | 46 | # insert the diff 47 | insert_pos = final_index_order.index(model_large) + 1 48 | final_index_order.insert(insert_pos, diff_name) 49 | 50 | tdf = tdf.reindex(final_index_order) 51 | 52 | for diff_name, diff_series in diff_data_to_add.items(): 53 | tdf.loc[diff_name] = diff_series 54 | 55 | annot_matrix = tdf.copy() 56 | for diff_name in diff_data_to_add.keys(): 57 | annot_matrix.loc[diff_name] = annot_matrix.loc[diff_name].apply(lambda x: f"{x:+.1f}") 58 | for idx in tdf.index: 59 | if idx not in diff_data_to_add: 60 | annot_matrix.loc[idx] = annot_matrix.loc[idx].apply(lambda x: f"{x:.1f}") 61 | 62 | sns_g = sns.heatmap( 63 | tdf, annot=annot_matrix, cmap=custom_cmap, fmt="", yticklabels=True, 64 | ax=a, annot_kws={"fontsize": 22}, 65 | cbar=False 66 | ) 67 | sns_g.set_title(dataset, fontsize=34) 68 | 69 | sns_g.set_ylabel("") 70 | sns_g.set_xlabel("") 71 | 72 | # a.set_yticklabels(ax[0].get_yticklabels() if (i%ncols) == 0 else [], size = 16) 73 | # sns_g.set_yticklabels(sns_g.get_yticklabels(), size = 16) 74 | # sns_g.set_xticklabels(sns_g.get_xticklabels(), size = 24) 75 | 76 | sns_g.set_yticklabels(final_index_order, size=26) 77 | xticks_map = {"8192": '8k', "16384": '16k', "32768": '32k', "65536":'64k', "131072":'128k'} 78 | sns_g.set_xticklabels([xticks_map[st.get_text()] for st in sns_g.get_xticklabels()], size=28) 79 | 80 | # idx, start, end 81 | a.hlines([2, 5], 0, 6, color="0.95", linestyle="-", linewidth=3) 82 | # a.vlines([5, 1, 3], [0, 5, 7], [5, 7, 11], color="bisque", linestyle="--", linewidth=3) 83 | 84 | [fig.delaxes(a) for a in ax.flatten() if not a.has_data()] 85 | 86 | plt.tight_layout() 87 | file_path = os.path.join(project_root, f"figures/11_results_length_select.pdf") 88 | plt.savefig(file_path, dpi=500, format="pdf") 89 | plt.show() 
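Note: the script above injects "Diff" rows (large-model score minus small-model score) into the pivot table before drawing the heatmap, and annotates those rows with an explicit sign. The same pattern in isolation, on toy numbers; the model names come from the script's diff_pairs, but the scores are placeholders, not benchmark results.

import pandas as pd

# toy pivot table: rows are models, columns are input lengths (placeholder values)
tdf = pd.DataFrame(
    {8192: [55.0, 61.5], 131072: [40.2, 52.8]},
    index=["InternVL3-8B", "InternVL3-38B"],
)

diff = tdf.loc["InternVL3-38B"] - tdf.loc["InternVL3-8B"]
order = ["InternVL3-8B", "InternVL3-38B", "Diff (8B->38B)"]
tdf = tdf.reindex(order)          # adds an empty row under the new label
tdf.loc["Diff (8B->38B)"] = diff  # fill it with the per-length gains

# diff rows are rendered with a sign, plain score rows without one
annot = tdf.copy().astype(object)
annot.loc["Diff (8B->38B)"] = diff.map(lambda x: f"{x:+.1f}")
for idx in ["InternVL3-8B", "InternVL3-38B"]:
    annot.loc[idx] = tdf.loc[idx].map(lambda x: f"{x:.1f}")
print(annot)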
-------------------------------------------------------------------------------- /vlm_model/internvl_v2pe/internvl/model/internvl_chat/flash_attention.py: -------------------------------------------------------------------------------- 1 | # https://github.com/Dao-AILab/flash-attention/blob/v0.2.8/flash_attn/flash_attention.py 2 | import torch 3 | import torch.nn as nn 4 | from einops import rearrange 5 | 6 | try: # v1 7 | from flash_attn.flash_attn_interface import \ 8 | flash_attn_unpadded_qkvpacked_func 9 | except: # v2 10 | from flash_attn.flash_attn_interface import flash_attn_varlen_qkvpacked_func as flash_attn_unpadded_qkvpacked_func 11 | 12 | from flash_attn.bert_padding import pad_input, unpad_input 13 | 14 | 15 | class FlashAttention(nn.Module): 16 | """Implement the scaled dot product attention with softmax. 17 | Arguments 18 | --------- 19 | softmax_scale: The temperature to use for the softmax attention. 20 | (default: 1/sqrt(d_keys) where d_keys is computed at 21 | runtime) 22 | attention_dropout: The dropout rate to apply to the attention 23 | (default: 0.0) 24 | """ 25 | 26 | def __init__(self, softmax_scale=None, attention_dropout=0.0, device=None, dtype=None): 27 | super().__init__() 28 | self.softmax_scale = softmax_scale 29 | self.dropout_p = attention_dropout 30 | 31 | def forward(self, qkv, key_padding_mask=None, causal=False, cu_seqlens=None, 32 | max_s=None, need_weights=False): 33 | """Implements the multihead softmax attention. 34 | Arguments 35 | --------- 36 | qkv: The tensor containing the query, key, and value. (B, S, 3, H, D) if key_padding_mask is None 37 | if unpadded: (nnz, 3, h, d) 38 | key_padding_mask: a bool tensor of shape (B, S) 39 | """ 40 | assert not need_weights 41 | assert qkv.dtype in [torch.float16, torch.bfloat16] 42 | assert qkv.is_cuda 43 | 44 | if cu_seqlens is None: 45 | batch_size = qkv.shape[0] 46 | seqlen = qkv.shape[1] 47 | if key_padding_mask is None: 48 | qkv = rearrange(qkv, 'b s ... -> (b s) ...') 49 | max_s = seqlen 50 | cu_seqlens = torch.arange(0, (batch_size + 1) * seqlen, step=seqlen, dtype=torch.int32, 51 | device=qkv.device) 52 | output = flash_attn_unpadded_qkvpacked_func( 53 | qkv, cu_seqlens, max_s, self.dropout_p if self.training else 0.0, 54 | softmax_scale=self.softmax_scale, causal=causal 55 | ) 56 | output = rearrange(output, '(b s) ... 
-> b s ...', b=batch_size) 57 | else: 58 | nheads = qkv.shape[-2] 59 | x = rearrange(qkv, 'b s three h d -> b s (three h d)') 60 | x_unpad, indices, cu_seqlens, max_s = unpad_input(x, key_padding_mask) 61 | x_unpad = rearrange(x_unpad, 'nnz (three h d) -> nnz three h d', three=3, h=nheads) 62 | output_unpad = flash_attn_unpadded_qkvpacked_func( 63 | x_unpad, cu_seqlens, max_s, self.dropout_p if self.training else 0.0, 64 | softmax_scale=self.softmax_scale, causal=causal 65 | ) 66 | output = rearrange(pad_input(rearrange(output_unpad, 'nnz h d -> nnz (h d)'), 67 | indices, batch_size, seqlen), 68 | 'b s (h d) -> b s h d', h=nheads) 69 | else: 70 | assert max_s is not None 71 | output = flash_attn_unpadded_qkvpacked_func( 72 | qkv, cu_seqlens, max_s, self.dropout_p if self.training else 0.0, 73 | softmax_scale=self.softmax_scale, causal=causal 74 | ) 75 | 76 | return output, None 77 | -------------------------------------------------------------------------------- /figure_scripts/7_correlation_all_categories.py: -------------------------------------------------------------------------------- 1 | from utils import * 2 | 3 | # break into different lengths and then plot the pairwise heatmap 4 | datasets = ['VRAG', 'NIAH', 'ICL', 'Summ', 'DocVQA'] 5 | lf_df = process_df(all_df) 6 | 7 | datasets1, datasets2 = datasets, datasets 8 | 9 | cmap = LinearSegmentedColormap.from_list("custom_cmap", ["#ed4d6e", '#DD9380', '#DEA683', '#CFCC86', "#0CD79F"]) 10 | 11 | all_corr = {} 12 | lengths = [131072] 13 | total_columns = len(datasets) + 1 14 | fig, ax = plt.subplots(figsize=(total_columns, len(datasets) + 0.1), nrows=1, ncols=len(lengths)) 15 | 16 | for i, (l, d) in enumerate(lf_df.groupby("input_max_length", sort=False)): 17 | if l not in lengths: 18 | continue 19 | i = lengths.index(l) 20 | spearmans = [] 21 | for d1 in datasets1: 22 | for d2 in datasets2: 23 | x = d[d[d1].notnull() & d[d2].notnull()] 24 | m1 = x[d1] 25 | m2 = x[d2] 26 | 27 | if len(m1) < 2 and len(m2) < 2: 28 | continue 29 | 30 | rho, p = stats.spearmanr(m1, m2) 31 | spearmans.append({"dataset 1": d1, "dataset 2": d2, "correlation": rho}) 32 | 33 | all_corr[l] = {"spearman": pd.DataFrame(spearmans)} 34 | for j, (name, table) in enumerate(all_corr[l].items()): 35 | hm = table.pivot_table(index="dataset 1", columns="dataset 2", values="correlation", sort=False) 36 | 37 | avg_corr = {} 38 | std_corr = {} 39 | for dataset in hm.index: 40 | corrs = [hm.loc[dataset, col] for col in hm.columns if col != dataset] 41 | avg_corr[dataset] = np.mean(corrs) 42 | std_corr[dataset] = np.std(corrs) 43 | hm['Avg'] = pd.Series(avg_corr) 44 | 45 | annot_matrix = hm.copy() 46 | for dataset in avg_corr: 47 | avg = avg_corr[dataset] 48 | std = std_corr[dataset] 49 | avg_str = f"{avg:.3f}"[1:] 50 | std_str = f"{std:.2f}" 51 | annot_matrix.loc[dataset, "Avg"] = f"${avg_str}_{{{std_str}}}$" 52 | 53 | annot_matrix[datasets] = annot_matrix[datasets].applymap(lambda x: f"{x:.2f}") 54 | 55 | # compress the outlier values for better visualization. 
56 | # otherwise we cannot tell the difference 57 | max_value = hm[abs(hm - 1) > 1e-6].max().max() 58 | hm[abs(hm - 1.0) < 1e-6] = min(max_value + 0.03, 1) 59 | 60 | tmp_min_value = hm.min().min() 61 | min_value = hm[abs(hm - tmp_min_value) > 1e-6].min().min() 62 | hm[abs(hm - tmp_min_value) < 1e-6] = min(min_value - 0.03, 1) 63 | 64 | 65 | a = ax[i] if len(lengths) > 1 else ax 66 | import matplotlib.colors as mcolors 67 | sns_g = sns.heatmap(hm, annot=False, ax=a, cbar=False, cmap=cmap, norm=mcolors.PowerNorm(gamma=0.5, vmin=hm.min().min(), vmax=hm.max().max())) 68 | for i in range(len(annot_matrix.index)): 69 | for j in range(len(annot_matrix.columns)): 70 | text_value = annot_matrix.iloc[i, j] 71 | if annot_matrix.columns[j] == "Avg": 72 | a.text(j + 0.5, i + 0.5, text_value, 73 | ha="center", va="center", color="black", fontsize=13) 74 | else: 75 | a.text(j + 0.5, i + 0.5, text_value, 76 | ha="center", va="center", color="black", fontsize=13.5) 77 | 78 | sns_g.set_ylabel("") 79 | sns_g.set_xlabel("") 80 | 81 | t = datasets 82 | sns_g.set_yticklabels(t, size=16) 83 | sns_g.set_xticklabels(t + ["Avg$_{std}$"], size=16) 84 | 85 | a.tick_params(axis='x', rotation=45) 86 | a.tick_params(axis='y', rotation=0) 87 | 88 | # a.axvline(x=5, color="white", linestyle="-", linewidth=1.5) 89 | 90 | plt.setp(a.get_xticklabels(), ha="right", rotation_mode="anchor") 91 | plt.tight_layout() 92 | figure_path = os.path.join(project_root, "figures/7_correlation_all_category.pdf") 93 | plt.savefig(figure_path, dpi=300, format="pdf") 94 | plt.show() -------------------------------------------------------------------------------- /figure_scripts/10_plot_NIAH_distribution.py: -------------------------------------------------------------------------------- 1 | from utils import * 2 | import matplotlib.patches as patches 3 | 4 | lf_df = process_df(all_df) 5 | dname = "DocVQA" 6 | 7 | more_avgs = { 8 | "Others": ["DocVQA"], # 'VRAG', 'ICL', 'Summ', "DocQA" 9 | } 10 | 11 | for k, v in more_avgs.items(): 12 | lf_df[k] = lf_df[v].mean(axis=1) 13 | 14 | melted_df = lf_df.melt(id_vars=['input_max_length', "Model", dname]) 15 | 16 | # plot correlation across datasets and lengths with one specific dataset 17 | tdf = melted_df[melted_df.input_max_length==131072] 18 | # tdf = tdf[tdf.dataset_simple.isin(['NIAH S Essay', 'NIAH MK Needle', 'HotpotQA'])] 19 | g = sns.relplot( 20 | tdf.rename({"input_max_length": "length", "dataset_simple": "dataset"}, axis=1), 21 | x="value", y=dname, col="dataset", row="length", 22 | facet_kws={'sharey': False, 'sharex': False}, 23 | hue="Model", markers=True, legend=False, col_order=['MM-NIAH-Ret', 'MM-NIAH-Count', 'MM-NIAH-Reason'], 24 | s=100, 25 | ) 26 | 27 | 28 | def annotate(data, **kws): 29 | data = data[data["value"].notna() & data[dname].notna()] 30 | ax = plt.gca() 31 | if len(data) > 1: 32 | ax.text(.05, .95, 'n={}'.format(len(data)), transform=ax.transAxes, fontsize=15) 33 | rho, p = stats.spearmanr(data['value'], data[dname]) 34 | ax.text(.05, .88, 'Spearman $\\rho$={:.2f}'.format(rho), transform=ax.transAxes, fontsize=15) 35 | # ax.text(.05, .81, 'p={:.2g}'.format(p), transform=ax.transAxes, fontsize=15) 36 | title_info = ax.get_title().split("dataset =")[1].strip() 37 | if title_info == 'MM-NIAH-Reason': 38 | x_start, x_end = 0, 40 39 | y_start, y_end = 0, 60 40 | x_offset = 18 41 | y_offset = 2 42 | elif title_info == 'MM-NIAH-Count': 43 | x_start, x_end = 0, 30 44 | y_start, y_end = 0, 60 45 | x_offset = 18 46 | y_offset = -1 47 | elif title_info == 'MM-NIAH-Ret': 48 | 
x_start, x_end = 0, 50 49 | y_start, y_end = 0, 60 50 | x_offset = 20 51 | y_offset = -0.5 52 | 53 | width, height = x_end - x_start, y_end - y_start 54 | 55 | rect = patches.Rectangle( 56 | (x_start, y_start), width, height, 57 | linewidth=0, edgecolor='none', facecolor='lightcoral', 58 | alpha=0.2, zorder=0 # Put it behind data points 59 | ) 60 | ax.add_patch(rect) 61 | 62 | # 2. Find points within the X range [0, 40] 63 | points_in_range = data[(data['value'] >= 0) & (data['value'] <= x_end)].copy() 64 | max_point = points_in_range.loc[points_in_range[dname].idxmax()] 65 | x_coord = max_point['value'] 66 | y_coord = max_point[dname] 67 | 68 | ax.axhline(y=y_coord, xmin=0, xmax=x_coord / ax.get_xlim()[1], 69 | linestyle='--', color='dimgrey', alpha=0.6, linewidth=2) 70 | 71 | ax.annotate(text=f"{y_coord:.1f}", xy=(x_coord, y_coord), xytext=(x_coord + x_offset, y_coord + y_offset), 72 | fontsize=14, color='dimgrey', fontweight='normal', 73 | arrowprops=dict( 74 | arrowstyle="->", # Arrow style 75 | color='dimgrey', # Arrow color 76 | lw=1.5, # Slightly thicker arrow 77 | connectionstyle = "arc3,rad=0.2" 78 | ), 79 | bbox=dict(boxstyle="round,pad=0.3", fc='bisque', ec="black", lw=0.5, alpha=0.6) 80 | ) 81 | 82 | print(ax.get_title()) 83 | dataset = ax.get_title().split("= ")[-1] 84 | ax.set_title("") 85 | ax.set_xlabel(dataset, fontsize=22) 86 | ax.set_ylabel(dname, fontsize=22) 87 | ax.set_yticklabels(ax.get_yticklabels(), size = 20) 88 | ax.set_xticklabels(ax.get_xticklabels(), size = 20) 89 | 90 | g.map_dataframe(annotate) 91 | # g.fig.suptitle(f"Correlation with {dname}", fontsize=24, y=1.05) 92 | plt.tight_layout() 93 | fname = "figures/10_correlation_recall_others_dist.pdf" 94 | fname = os.path.join(project_root, fname) 95 | print(fname) 96 | plt.savefig(fname, dpi=300, format="pdf") 97 | plt.show() -------------------------------------------------------------------------------- /figure_scripts/16_docqa_pie_figure.py: -------------------------------------------------------------------------------- 1 | import matplotlib.pyplot as plt 2 | import numpy as np 3 | 4 | 5 | def create_adaptive_pie(ax, data, colors, threshold=0.05, title=None, font_size=12): 6 | large_data = {} 7 | small_data = {} 8 | 9 | for key, value in data.items(): 10 | if value >= threshold: 11 | large_data[key] = value 12 | else: 13 | small_data[key] = value 14 | 15 | all_keys = list(data.keys()) 16 | all_values = list(data.values()) 17 | all_colors = [colors[i % len(colors)] for i in range(len(data))] 18 | 19 | wedges, _ = ax.pie( 20 | all_values, 21 | labels=None, 22 | colors=all_colors, 23 | startangle=90, 24 | wedgeprops={'edgecolor': 'w', 'linewidth': 1} 25 | ) 26 | 27 | for i, key in enumerate(all_keys): 28 | if all_values[i] >= threshold: 29 | angle = wedges[i].theta1 + (wedges[i].theta2 - wedges[i].theta1) / 2 30 | angle_rad = np.deg2rad(angle) 31 | 32 | r = 0.6 33 | x = r * np.cos(angle_rad) 34 | y = r * np.sin(angle_rad) 35 | 36 | percentage = f"{all_values[i]:.1%}" 37 | ax.text(x, y, percentage, ha='center', va='center', 38 | fontweight='bold', fontsize=font_size) 39 | 40 | for i, key in enumerate(all_keys): 41 | angle = wedges[i].theta1 + (wedges[i].theta2 - wedges[i].theta1) / 2 42 | angle_rad = np.deg2rad(angle) 43 | 44 | if all_values[i] < threshold: 45 | 46 | label_text = f"{key}\n({all_values[i]:.1%})" 47 | else: 48 | label_text = f"{key}" 49 | 50 | r = 1.1 51 | x = r * np.cos(angle_rad) 52 | y = r * np.sin(angle_rad) 53 | 54 | if key == "Pure-text" and "LongDocURL" in title: 55 | x += 0.2 56 | y -= 
0.4 57 | 58 | ha = "center" 59 | if angle > 90 and angle < 270: 60 | ha = "right" 61 | elif angle < 90 or angle > 270: 62 | ha = "left" 63 | 64 | ax.text(x, y, label_text, ha=ha, va='center', fontweight='bold', fontsize=font_size) 65 | 66 | if title: 67 | ax.set_title(title, fontsize=font_size + 2) 68 | 69 | ax.set_aspect('equal') 70 | 71 | 72 | mmlongdoc_sources = { 73 | 'Pure-text': 0.256, 74 | 'Layout': 0.106, 75 | 'Table': 0.205, 76 | 'Figure': 0.285, 77 | 'Chart': 0.161, 78 | } 79 | 80 | mmlongdoc_formats = { 81 | 'String': 0.172, 82 | 'Integer': 0.341, 83 | 'Float': 0.135, 84 | 'List': 0.120, 85 | 'None': 0.232, 86 | } 87 | 88 | longdocurl_sources = { 89 | 'Pure-Text': 0.450, 90 | 'Layout': 0.272, 91 | 'Table': 0.372, 92 | 'Figure': 0.208, 93 | 'Others': 0.002 94 | } 95 | 96 | longdocurl_formats = { 97 | 'String': 0.261, 98 | 'Integer': 0.341, 99 | 'Float': 0.152, 100 | 'List': 0.239, 101 | 'None': 0.008 102 | } 103 | 104 | colors1 = plt.cm.Pastel1(np.linspace(0, 1, 5)) 105 | colors2 = plt.cm.Pastel2(np.linspace(0, 1, 5)) 106 | 107 | fig, axes = plt.subplots(1, 4, figsize=(20, 5)) 108 | pie_font = 12 109 | 110 | create_adaptive_pie(axes[0], mmlongdoc_sources, colors1, threshold=0.10, 111 | title='MMLB-Doc Answer Sources', font_size=14) 112 | 113 | create_adaptive_pie(axes[1], mmlongdoc_formats, colors2, threshold=0.10, 114 | title='MMLB-Doc Answer Format', font_size=14) 115 | 116 | create_adaptive_pie(axes[2], longdocurl_sources, colors1, threshold=0.10, 117 | title='LongDocURL Answer Sources', font_size=14) 118 | 119 | create_adaptive_pie(axes[3], longdocurl_formats, colors2, threshold=0.10, 120 | title='LongDocURL Answer Format', font_size=14) 121 | for ax in axes: 122 | ax.set_aspect('equal') 123 | 124 | plt.tight_layout() 125 | project_root = "/home/zhaowei.wang/vl-longbench" 126 | import os 127 | figure_path = os.path.join(project_root, 'figures/16_docqa_dist.pdf') 128 | plt.savefig(figure_path, format='pdf', dpi=300, bbox_inches='tight') 129 | plt.show() -------------------------------------------------------------------------------- /figure_scripts/9_heatmap_by_depth.py: -------------------------------------------------------------------------------- 1 | from utils import * 2 | 3 | # plot the heatmap by depth for each model, note that this takes quite some time to generate... 
4 | # use your dataset name to filter datasets_configs 5 | # the tasks that need depth: viquae, vh_single, mm_niah_retrieval-text, mm_niah_retrieval-image, mm_niah_reasoning-image 6 | dataset_name_list = ["viquae", "vh_single", "mm_niah_retrieval-text", "mm_niah_retrieval-image", "mm_niah_reasoning-image"] # 7 | 8 | for dataset_name in dataset_name_list: 9 | print(f"Processing {dataset_name} now") 10 | datasets_configs = [config for config in dataset_configs if config["dataset"] == dataset_name] 11 | 12 | depth_dfs = [] 13 | ncols = 6 14 | 15 | for i, model in enumerate(models_configs): 16 | args = arguments() 17 | depths = [] 18 | for dataset in datasets_configs: 19 | args.update(dataset) 20 | args.update(model) 21 | 22 | depth = args.get_metric_by_depth() 23 | if depth is None: 24 | continue 25 | for d in depth: 26 | d["input_length"] = args.input_max_length 27 | d["depth"] = math.ceil(d["depth"] * 10) / 10 28 | depths += depth 29 | print('good') 30 | if len(depths) == 0: 31 | continue 32 | depths = pd.DataFrame(depths) 33 | depth_dfs.append((model, depths)) 34 | 35 | fig, ax = plt.subplots(nrows=(len(depth_dfs) - 1) // ncols + 1, ncols=ncols, sharey=False, sharex=False) 36 | fig.set_size_inches((ncols * 5, len(ax) * 4.25)) 37 | cmap = LinearSegmentedColormap.from_list("custom_cmap", ["#F0496E", "#EBB839", "#0CD79F"]) 38 | 39 | vmin = min([d[1]['metric'].min() for d in depth_dfs]) 40 | vmax = max([d[1]['metric'].max() for d in depth_dfs]) 41 | for i, (config, depths) in enumerate(tqdm(depth_dfs, "drawing models")): 42 | a = ax[i // ncols][i % ncols] 43 | pivot_table = depths.pivot(index="depth", columns="input_length", values="metric") 44 | 45 | need_annot_matrix = [] 46 | for col in [8192, 16384, 32768, 65536, 131072]: 47 | if col not in pivot_table.columns: 48 | pivot_table[col] = vmin # use vmin as default 49 | need_annot_matrix.append(col) 50 | 51 | if need_annot_matrix: 52 | annot_matrix = pivot_table.copy() 53 | annot_matrix = annot_matrix.applymap(lambda x: f"{x:.2f}") 54 | for col in need_annot_matrix: 55 | annot_matrix[col] = "N/A" 56 | sns_g = sns.heatmap( 57 | pivot_table, cmap=cmap, vmin=vmin, vmax=vmax, cbar_kws={"label": "Score"}, ax=a, annot=annot_matrix, 58 | cbar=False, annot_kws={"fontsize": 16}, fmt="") 59 | else: 60 | sns_g = sns.heatmap( 61 | pivot_table, cmap=cmap, vmin=vmin, vmax=vmax, cbar_kws={"label": "Score"}, ax=a, annot=True, 62 | cbar=False, annot_kws={"fontsize": 16} 63 | ) 64 | m = config['model'] 65 | 66 | idx = {'8k': 1, '10k': 1, '16k': 2, '32k': 3, '64k': 4, '80k': 4, '128k': 5, '200k': 5, '1m': 5}[model_lengths[model_name_replace[config['model']]]] 67 | a.axvline(idx, color="white", linestyle="--", linewidth=4) 68 | 69 | num_columns = pivot_table.shape[1] 70 | a.set_xlim(0, num_columns + 0.5) 71 | 72 | sns_g.set_title(model_name_replace.get(m, m), fontsize=28) 73 | 74 | xticks = {'4096': '4k', '8192': '8k', '16384': '16k', '32768': '32k', '65536': '64k', '131072': '128k'} 75 | xticks = [xticks[x.get_text()] for x in a.get_xticklabels()] 76 | a.set_xticklabels(xticks, size=24) 77 | 78 | ytick_labels = pivot_table.index.astype(str).tolist() 79 | a.set_yticklabels(ytick_labels, size=24, rotation=0) 80 | 81 | if i % ncols == 0: 82 | a.set_ylabel("Depth", size=28) 83 | else: 84 | a.set_ylabel("") 85 | a.set_xlabel("") 86 | 87 | [fig.delaxes(a) for i, a in enumerate(ax.flatten()) if not a.has_data()] 88 | plt.tight_layout() 89 | figure_path = os.path.join(project_root, f"figures/9_depths_{dataset_name}.pdf") 90 | plt.savefig(figure_path, dpi=300, 
format="pdf") 91 | plt.show() -------------------------------------------------------------------------------- /vlm_model/nvila_2b_ef8fa9c8/media_encoder.py: -------------------------------------------------------------------------------- 1 | from functools import partial 2 | from typing import Any, Dict, List, Optional 3 | 4 | import torch 5 | from torch import nn 6 | 7 | 8 | class BaseEncoder(nn.Module): 9 | def __init__(self, parent: nn.Module) -> None: 10 | super().__init__() 11 | self._parent = [parent] 12 | 13 | @property 14 | def parent(self) -> nn.Module: 15 | return self._parent[0] 16 | 17 | 18 | class BasicImageEncoder(BaseEncoder): 19 | def __init__( 20 | self, 21 | parent: torch.nn.Module, 22 | start_tokens: Optional[str] = None, 23 | end_tokens: Optional[str] = "\n", 24 | ) -> None: 25 | super().__init__(parent) 26 | self.start_tokens = start_tokens 27 | self.end_tokens = end_tokens 28 | 29 | def embed_tokens(self, tokens: Optional[str]) -> Optional[torch.Tensor]: 30 | if tokens is None: 31 | return None 32 | token_ids = self.parent.tokenizer(tokens).input_ids 33 | token_ids = torch.tensor(token_ids, device=self.parent.device) 34 | return self.parent.llm.model.embed_tokens(token_ids) 35 | 36 | def _process_features( 37 | self, 38 | features: torch.Tensor, 39 | start_token_embeds: Optional[torch.Tensor], 40 | end_token_embeds: Optional[torch.Tensor], 41 | ) -> torch.Tensor: 42 | if start_token_embeds is not None: 43 | features = torch.cat([start_token_embeds, features], dim=0) 44 | if end_token_embeds is not None: 45 | features = torch.cat([features, end_token_embeds], dim=0) 46 | return features 47 | 48 | def forward(self, images: List[torch.Tensor], config: Dict[str, Any]) -> List[torch.Tensor]: 49 | images = torch.stack(images, dim=0) 50 | features = self.parent.encode_images(images, block_sizes=config.get("block_sizes")) 51 | process_features = partial( 52 | self._process_features, 53 | start_token_embeds=self.embed_tokens(self.start_tokens), 54 | end_token_embeds=self.embed_tokens(self.end_tokens), 55 | ) 56 | return [process_features(f) for f in features] 57 | 58 | 59 | class BasicVideoEncoder(BaseEncoder): 60 | def __init__( 61 | self, 62 | parent: torch.nn.Module, 63 | start_tokens: Optional[str] = None, 64 | end_tokens: Optional[str] = "\n", 65 | ) -> None: 66 | super().__init__(parent) 67 | self.start_tokens = start_tokens 68 | self.end_tokens = end_tokens 69 | 70 | def embed_tokens(self, tokens: Optional[str]) -> Optional[torch.Tensor]: 71 | if tokens is None: 72 | return None 73 | token_ids = self.parent.tokenizer(tokens).input_ids 74 | token_ids = torch.tensor(token_ids, device=self.parent.device) 75 | return self.parent.llm.model.embed_tokens(token_ids) 76 | 77 | def _process_features( 78 | self, 79 | features: torch.Tensor, 80 | start_token_embeds: Optional[torch.Tensor], 81 | end_token_embeds: Optional[torch.Tensor], 82 | ) -> torch.Tensor: 83 | if start_token_embeds is not None: 84 | start_embeds = torch.stack([start_token_embeds] * features.shape[0], dim=0) 85 | features = torch.cat([start_embeds, features], dim=1) 86 | if end_token_embeds is not None: 87 | end_embeds = torch.stack([end_token_embeds] * features.shape[0], dim=0) 88 | features = torch.cat([features, end_embeds], dim=1) 89 | return features.flatten(0, 1) 90 | 91 | def forward(self, videos: List[torch.Tensor], config: Dict[str, Any]) -> List[torch.Tensor]: 92 | num_frames = [video.shape[0] for video in videos] 93 | images = torch.cat(videos, dim=0) 94 | features = self.parent.encode_images(images) 
95 | features = torch.split(features, num_frames) 96 | process_features = partial( 97 | self._process_features, 98 | start_token_embeds=self.embed_tokens(self.start_tokens), 99 | end_token_embeds=self.embed_tokens(self.end_tokens), 100 | ) 101 | return [process_features(f) for f in features] 102 | -------------------------------------------------------------------------------- /vlm_model/nvila_8b_e2481b0c/media_encoder.py: -------------------------------------------------------------------------------- 1 | from functools import partial 2 | from typing import Any, Dict, List, Optional 3 | 4 | import torch 5 | from torch import nn 6 | 7 | 8 | class BaseEncoder(nn.Module): 9 | def __init__(self, parent: nn.Module) -> None: 10 | super().__init__() 11 | self._parent = [parent] 12 | 13 | @property 14 | def parent(self) -> nn.Module: 15 | return self._parent[0] 16 | 17 | 18 | class BasicImageEncoder(BaseEncoder): 19 | def __init__( 20 | self, 21 | parent: torch.nn.Module, 22 | start_tokens: Optional[str] = None, 23 | end_tokens: Optional[str] = "\n", 24 | ) -> None: 25 | super().__init__(parent) 26 | self.start_tokens = start_tokens 27 | self.end_tokens = end_tokens 28 | 29 | def embed_tokens(self, tokens: Optional[str]) -> Optional[torch.Tensor]: 30 | if tokens is None: 31 | return None 32 | token_ids = self.parent.tokenizer(tokens).input_ids 33 | token_ids = torch.tensor(token_ids, device=self.parent.device) 34 | return self.parent.llm.model.embed_tokens(token_ids) 35 | 36 | def _process_features( 37 | self, 38 | features: torch.Tensor, 39 | start_token_embeds: Optional[torch.Tensor], 40 | end_token_embeds: Optional[torch.Tensor], 41 | ) -> torch.Tensor: 42 | if start_token_embeds is not None: 43 | features = torch.cat([start_token_embeds, features], dim=0) 44 | if end_token_embeds is not None: 45 | features = torch.cat([features, end_token_embeds], dim=0) 46 | return features 47 | 48 | def forward(self, images: List[torch.Tensor], config: Dict[str, Any]) -> List[torch.Tensor]: 49 | images = torch.stack(images, dim=0) 50 | features = self.parent.encode_images(images, block_sizes=config.get("block_sizes")) 51 | process_features = partial( 52 | self._process_features, 53 | start_token_embeds=self.embed_tokens(self.start_tokens), 54 | end_token_embeds=self.embed_tokens(self.end_tokens), 55 | ) 56 | return [process_features(f) for f in features] 57 | 58 | 59 | class BasicVideoEncoder(BaseEncoder): 60 | def __init__( 61 | self, 62 | parent: torch.nn.Module, 63 | start_tokens: Optional[str] = None, 64 | end_tokens: Optional[str] = "\n", 65 | ) -> None: 66 | super().__init__(parent) 67 | self.start_tokens = start_tokens 68 | self.end_tokens = end_tokens 69 | 70 | def embed_tokens(self, tokens: Optional[str]) -> Optional[torch.Tensor]: 71 | if tokens is None: 72 | return None 73 | token_ids = self.parent.tokenizer(tokens).input_ids 74 | token_ids = torch.tensor(token_ids, device=self.parent.device) 75 | return self.parent.llm.model.embed_tokens(token_ids) 76 | 77 | def _process_features( 78 | self, 79 | features: torch.Tensor, 80 | start_token_embeds: Optional[torch.Tensor], 81 | end_token_embeds: Optional[torch.Tensor], 82 | ) -> torch.Tensor: 83 | if start_token_embeds is not None: 84 | start_embeds = torch.stack([start_token_embeds] * features.shape[0], dim=0) 85 | features = torch.cat([start_embeds, features], dim=1) 86 | if end_token_embeds is not None: 87 | end_embeds = torch.stack([end_token_embeds] * features.shape[0], dim=0) 88 | features = torch.cat([features, end_embeds], dim=1) 89 | return 
features.flatten(0, 1) 90 | 91 | def forward(self, videos: List[torch.Tensor], config: Dict[str, Any]) -> List[torch.Tensor]: 92 | num_frames = [video.shape[0] for video in videos] 93 | images = torch.cat(videos, dim=0) 94 | features = self.parent.encode_images(images) 95 | features = torch.split(features, num_frames) 96 | process_features = partial( 97 | self._process_features, 98 | start_token_embeds=self.embed_tokens(self.start_tokens), 99 | end_token_embeds=self.embed_tokens(self.end_tokens), 100 | ) 101 | return [process_features(f) for f in features] 102 | -------------------------------------------------------------------------------- /figure_scripts/3_main_full_models_by_category.py: -------------------------------------------------------------------------------- 1 | from utils import * 2 | 3 | 4 | # plot figures on each dataset by category 5 | 6 | category_to_datasets = { 7 | "VRAG": ["InfoSeek", "ViQuAE"], 8 | "NIAH": ["VH-Single", "VH-Multi", 'MM-NIAH-Ret (T)', "MM-NIAH-Ret (I)", 9 | "MM-NIAH-Count (T)", "MM-NIAH-Count (I)", "MM-NIAH-Reason (T)", "MM-NIAH-Reason (I)"], 10 | "ICL": ["Stanford Cars", "Food101", "SUN397", "Inat2021"], 11 | "Summ": ["GovReport", "Multi-LexSum"], 12 | "DocVQA": ["MMLongBench-Doc", "LongDocURL", "SlideVQA"] 13 | } 14 | 15 | category_group = [["VRAG", "Summ"], "NIAH", "ICL", "DocVQA"] # ["VRAG", "Summ"], "Recall", "ICL", "DocQA" 16 | 17 | # plot specific ones in a row, for formatting in the paper 18 | lf_df = process_df(all_df) 19 | custom_cmap = LinearSegmentedColormap.from_list('custom_cmap', ['#FFFFFF', '#4A895B']) 20 | 21 | for cur_cate_list in category_group: 22 | if isinstance(cur_cate_list, str): 23 | cur_cate_list = [cur_cate_list] 24 | cur_dataset_list = [dataset_name for cate_name in cur_cate_list for dataset_name in category_to_datasets[cate_name]] 25 | cur_title = " and ".join(cur_cate_list) 26 | length_datasets = cur_dataset_list 27 | 28 | ncols = min(4, len(length_datasets)) 29 | nrows = (len(length_datasets) - 1) // ncols + 1 30 | if ncols == 3: 31 | col_width = 8 32 | else: 33 | col_width = 6 34 | 35 | fig, ax = plt.subplots(ncols=ncols, nrows=nrows, sharey=True, sharex=False) 36 | fig.set_size_inches((ncols * col_width, 20 * nrows)) # helmet has 45 models for 20 height, we have 47 models 37 | 38 | plt.rc('axes', unicode_minus=False) 39 | plt.rcParams.update({'axes.unicode_minus': False}) 40 | 41 | for i, dataset in enumerate(length_datasets): 42 | if nrows > 1: 43 | a = ax[i // ncols][i % ncols] 44 | else: 45 | a = ax[i] 46 | 47 | new_index = full_table_models 48 | 49 | tdf = lf_df[lf_df.input_max_length > 4096] 50 | tdf = tdf.pivot_table(index="Model", columns="input_max_length", values=dataset) 51 | tdf = tdf.reindex(new_index) 52 | 53 | # process the scores 54 | annot_matrix = tdf.copy() 55 | tdf = tdf.applymap(lambda x: x if not pd.isna(x) else 0) 56 | annot_matrix = annot_matrix.applymap(lambda x: "N/A" if pd.isna(x) else f"{x:.1f}") 57 | 58 | sns_g = sns.heatmap( 59 | tdf, annot=annot_matrix, cmap=custom_cmap, fmt="", yticklabels=True, 60 | ax=a, annot_kws={"fontsize": 23.5}, 61 | cbar=False 62 | ) 63 | sns_g.set_title(dataset if dataset != "Ours" else "Avg.", fontsize=34) 64 | 65 | sns_g.set_ylabel("") 66 | sns_g.set_xlabel("") 67 | 68 | new_index = [x.replace("-Inst", '') for x in new_index] 69 | sns_g.set_yticklabels(new_index, size=28) 70 | 71 | xticks = ['8k', '16k', '32k', '64k', '128k'] 72 | sns_g.set_xticklabels(xticks, size=28) 73 | 74 | for idx in [6, 13, 27, 33, 36, 40, 43, 45]: 75 | a.axhline(y=idx, color="white", 
linestyle="-", linewidth=4) 76 | 77 | if len(cur_cate_list) > 1: 78 | fig.tight_layout() 79 | col_divider = 2 80 | assert nrows == 1 81 | left_ax = ax[col_divider - 1] 82 | right_ax = ax[col_divider] 83 | left_pos = left_ax.get_position() 84 | right_pos = right_ax.get_position() 85 | 86 | line_pos = (left_pos.x1 + right_pos.x0) / 2 + 0.001 87 | 88 | fig.add_artist(plt.Line2D([line_pos, line_pos], [0, 1], 89 | transform=fig.transFigure, 90 | color='black', 91 | linestyle='--', 92 | linewidth=5)) 93 | 94 | [fig.delaxes(a) for a in ax.flatten() if not a.has_data()] 95 | 96 | plt.tight_layout() 97 | if ncols == 3: 98 | plt.subplots_adjust(left=0.17, wspace=0.30) 99 | cur_file_title = cur_title.replace(" ", "-") 100 | figure_path = os.path.join(project_root, f"figures/3_category_{cur_file_title}_length_full.pdf") 101 | plt.savefig(figure_path, dpi=500, format="pdf") 102 | plt.show() 103 | -------------------------------------------------------------------------------- /vlm_model/internvl_v2pe/internvl/dist_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import socket 3 | import subprocess 4 | from datetime import timedelta 5 | 6 | import deepspeed 7 | import torch 8 | import torch.multiprocessing as mp 9 | from torch import distributed as dist 10 | 11 | timeout = timedelta(minutes=60) 12 | 13 | 14 | def _find_free_port(): 15 | # Copied from https://github.com/facebookresearch/detectron2/blob/main/detectron2/engine/launch.py # noqa: E501 16 | sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) 17 | # Binding to port 0 will cause the OS to find an available port for us 18 | sock.bind(('', 0)) 19 | port = sock.getsockname()[1] 20 | sock.close() 21 | # NOTE: there is still a chance the port could be taken by other processes. 
22 | return port 23 | 24 | 25 | def _is_free_port(port): 26 | ips = socket.gethostbyname_ex(socket.gethostname())[-1] 27 | ips.append('localhost') 28 | with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: 29 | return all(s.connect_ex((ip, port)) != 0 for ip in ips) 30 | 31 | 32 | def init_dist(launcher, backend='nccl', **kwargs): 33 | if mp.get_start_method(allow_none=True) is None: 34 | mp.set_start_method('spawn') 35 | if launcher == 'pytorch': 36 | _init_dist_pytorch(backend, **kwargs) 37 | elif launcher == 'mpi': 38 | _init_dist_mpi(backend, **kwargs) 39 | elif launcher == 'slurm': 40 | _init_dist_slurm(backend, **kwargs) 41 | else: 42 | raise ValueError(f'Invalid launcher type: {launcher}') 43 | 44 | 45 | def _init_dist_pytorch(backend, **kwargs): 46 | # TODO: use local_rank instead of rank % num_gpus 47 | rank = int(os.environ['RANK']) 48 | num_gpus = torch.cuda.device_count() 49 | torch.cuda.set_device(rank % num_gpus) 50 | # dist.init_process_group(backend=backend, **kwargs) 51 | deepspeed.init_distributed(dist_backend=backend) 52 | 53 | 54 | def _init_dist_mpi(backend, **kwargs): 55 | local_rank = int(os.environ['OMPI_COMM_WORLD_LOCAL_RANK']) 56 | torch.cuda.set_device(local_rank) 57 | if 'MASTER_PORT' not in os.environ: 58 | # 29500 is torch.distributed default port 59 | os.environ['MASTER_PORT'] = '29500' 60 | if 'MASTER_ADDR' not in os.environ: 61 | raise KeyError('The environment variable MASTER_ADDR is not set') 62 | os.environ['WORLD_SIZE'] = os.environ['OMPI_COMM_WORLD_SIZE'] 63 | os.environ['RANK'] = os.environ['OMPI_COMM_WORLD_RANK'] 64 | dist.init_process_group(backend=backend, **kwargs) 65 | 66 | 67 | def _init_dist_slurm(backend, port=None): 68 | """Initialize slurm distributed training environment. 69 | 70 | If argument ``port`` is not specified, then the master port will be system 71 | environment variable ``MASTER_PORT``. If ``MASTER_PORT`` is not in system 72 | environment variable, then a default port ``29500`` will be used. 73 | 74 | Args: 75 | backend (str): Backend of torch.distributed. 76 | port (int, optional): Master port. Defaults to None. 
77 | """ 78 | proc_id = int(os.environ['SLURM_PROCID']) 79 | ntasks = int(os.environ['SLURM_NTASKS']) 80 | node_list = os.environ['SLURM_NODELIST'] 81 | num_gpus = torch.cuda.device_count() 82 | torch.cuda.set_device(proc_id % num_gpus) 83 | addr = subprocess.getoutput( 84 | f'scontrol show hostname {node_list} | head -n1') 85 | # specify master port 86 | if port is not None: 87 | os.environ['MASTER_PORT'] = str(port) 88 | elif 'MASTER_PORT' in os.environ: 89 | pass # use MASTER_PORT in the environment variable 90 | else: 91 | # if torch.distributed default port(29500) is available 92 | # then use it, else find a free port 93 | if _is_free_port(29500): 94 | os.environ['MASTER_PORT'] = '29500' 95 | else: 96 | os.environ['MASTER_PORT'] = str(_find_free_port()) 97 | # use MASTER_ADDR in the environment variable if it already exists 98 | if 'MASTER_ADDR' not in os.environ: 99 | os.environ['MASTER_ADDR'] = addr 100 | os.environ['WORLD_SIZE'] = str(ntasks) 101 | os.environ['LOCAL_RANK'] = str(proc_id % num_gpus) 102 | os.environ['RANK'] = str(proc_id) 103 | # dist.init_process_group(backend=backend, timeout=timeout) 104 | deepspeed.init_distributed(dist_backend=backend) 105 | -------------------------------------------------------------------------------- /vlm_model/internvl_v2pe/internvl/serve/mm_utils.py: -------------------------------------------------------------------------------- 1 | import base64 2 | from io import BytesIO 3 | 4 | import torch 5 | from PIL import Image 6 | from transformers import StoppingCriteria 7 | 8 | from .constants import IMAGE_TOKEN_INDEX 9 | 10 | 11 | def load_image_from_base64(image): 12 | return Image.open(BytesIO(base64.b64decode(image))) 13 | 14 | 15 | def expand2square(pil_img, background_color): 16 | width, height = pil_img.size 17 | if width == height: 18 | return pil_img 19 | elif width > height: 20 | result = Image.new(pil_img.mode, (width, width), background_color) 21 | result.paste(pil_img, (0, (width - height) // 2)) 22 | return result 23 | else: 24 | result = Image.new(pil_img.mode, (height, height), background_color) 25 | result.paste(pil_img, ((height - width) // 2, 0)) 26 | return result 27 | 28 | 29 | def process_images(images, image_processor, model_cfg): 30 | image_aspect_ratio = getattr(model_cfg, 'image_aspect_ratio', None) 31 | new_images = [] 32 | if image_aspect_ratio == 'pad': 33 | for image in images: 34 | image = expand2square(image, tuple(int(x*255) for x in image_processor.image_mean)) 35 | image = image_processor.preprocess(image, return_tensors='pt')['pixel_values'][0] 36 | new_images.append(image) 37 | else: 38 | return image_processor(images, return_tensors='pt')['pixel_values'] 39 | if all(x.shape == new_images[0].shape for x in new_images): 40 | new_images = torch.stack(new_images, dim=0) 41 | return new_images 42 | 43 | 44 | def tokenizer_image_token(prompt, tokenizer, image_token_index=IMAGE_TOKEN_INDEX, 45 | num_image_tokens=None, return_tensors=None): 46 | prompt_chunks = [tokenizer(chunk).input_ids for chunk in prompt.split('')] 47 | 48 | def insert_separator(X, sep): 49 | return [ele for sublist in zip(X, [sep]*len(X)) for ele in sublist][:-1] 50 | 51 | input_ids = [] 52 | offset = 0 53 | if len(prompt_chunks) > 0 and len(prompt_chunks[0]) > 0 and prompt_chunks[0][0] == tokenizer.bos_token_id: 54 | offset = 1 55 | input_ids.append(prompt_chunks[0][0]) 56 | 57 | for x in insert_separator(prompt_chunks, [image_token_index] * (offset + num_image_tokens)): 58 | input_ids.extend(x[offset:]) 59 | 60 | if return_tensors is not 
None: 61 | if return_tensors == 'pt': 62 | return torch.tensor(input_ids, dtype=torch.long) 63 | raise ValueError(f'Unsupported tensor type: {return_tensors}') 64 | return input_ids 65 | 66 | 67 | def get_model_name_from_path(model_path): 68 | model_path = model_path.strip('/') 69 | model_paths = model_path.split('/') 70 | if model_paths[-1].startswith('checkpoint-'): 71 | return model_paths[-2] + '_' + model_paths[-1] 72 | else: 73 | return model_paths[-1] 74 | 75 | 76 | class KeywordsStoppingCriteria(StoppingCriteria): 77 | def __init__(self, keywords, tokenizer, input_ids): 78 | self.keywords = keywords 79 | self.keyword_ids = [] 80 | self.max_keyword_len = 0 81 | for keyword in keywords: 82 | cur_keyword_ids = tokenizer(keyword).input_ids 83 | if len(cur_keyword_ids) > 1 and cur_keyword_ids[0] == tokenizer.bos_token_id: 84 | cur_keyword_ids = cur_keyword_ids[1:] 85 | if len(cur_keyword_ids) > self.max_keyword_len: 86 | self.max_keyword_len = len(cur_keyword_ids) 87 | self.keyword_ids.append(torch.tensor(cur_keyword_ids)) 88 | self.tokenizer = tokenizer 89 | self.start_len = input_ids.shape[1] 90 | 91 | def __call__(self, output_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool: 92 | assert output_ids.shape[0] == 1, 'Only support batch size 1 (yet)' # TODO 93 | offset = min(output_ids.shape[1] - self.start_len, self.max_keyword_len) 94 | self.keyword_ids = [keyword_id.to(output_ids.device) for keyword_id in self.keyword_ids] 95 | for keyword_id in self.keyword_ids: 96 | if (output_ids[0, -keyword_id.shape[0]:] == keyword_id).all(): 97 | return True 98 | outputs = self.tokenizer.batch_decode(output_ids[:, -offset:], skip_special_tokens=True)[0] 99 | for keyword in self.keywords: 100 | if keyword in outputs: 101 | return True 102 | return False 103 | -------------------------------------------------------------------------------- /vlm_model/internvl_v2pe/internvl/serve/utils.py: -------------------------------------------------------------------------------- 1 | import datetime 2 | import logging 3 | import logging.handlers 4 | import os 5 | import sys 6 | 7 | import requests 8 | 9 | from .constants import LOGDIR 10 | 11 | server_error_msg = '**NETWORK ERROR DUE TO HIGH TRAFFIC. PLEASE REGENERATE OR REFRESH THIS PAGE.**' 12 | moderation_msg = 'YOUR INPUT VIOLATES OUR CONTENT MODERATION GUIDELINES. PLEASE TRY AGAIN.' 
13 | 14 | handler = None 15 | 16 | 17 | def build_logger(logger_name, logger_filename): 18 | global handler 19 | 20 | formatter = logging.Formatter( 21 | fmt='%(asctime)s | %(levelname)s | %(name)s | %(message)s', 22 | datefmt='%Y-%m-%d %H:%M:%S', 23 | ) 24 | 25 | # Set the format of root handlers 26 | if not logging.getLogger().handlers: 27 | logging.basicConfig(level=logging.INFO) 28 | logging.getLogger().handlers[0].setFormatter(formatter) 29 | 30 | # Redirect stdout and stderr to loggers 31 | stdout_logger = logging.getLogger('stdout') 32 | stdout_logger.setLevel(logging.INFO) 33 | sl = StreamToLogger(stdout_logger, logging.INFO) 34 | sys.stdout = sl 35 | 36 | stderr_logger = logging.getLogger('stderr') 37 | stderr_logger.setLevel(logging.ERROR) 38 | sl = StreamToLogger(stderr_logger, logging.ERROR) 39 | sys.stderr = sl 40 | 41 | # Get logger 42 | logger = logging.getLogger(logger_name) 43 | logger.setLevel(logging.INFO) 44 | 45 | # Add a file handler for all loggers 46 | if handler is None: 47 | os.makedirs(LOGDIR, exist_ok=True) 48 | filename = os.path.join(LOGDIR, logger_filename) 49 | handler = logging.handlers.TimedRotatingFileHandler( 50 | filename, when='D', utc=True, encoding='UTF-8') 51 | handler.setFormatter(formatter) 52 | 53 | for name, item in logging.root.manager.loggerDict.items(): 54 | if isinstance(item, logging.Logger): 55 | item.addHandler(handler) 56 | 57 | return logger 58 | 59 | 60 | class StreamToLogger(object): 61 | """ 62 | Fake file-like stream object that redirects writes to a logger instance. 63 | """ 64 | def __init__(self, logger, log_level=logging.INFO): 65 | self.terminal = sys.stdout 66 | self.logger = logger 67 | self.log_level = log_level 68 | self.linebuf = '' 69 | 70 | def __getattr__(self, attr): 71 | return getattr(self.terminal, attr) 72 | 73 | def write(self, buf): 74 | temp_linebuf = self.linebuf + buf 75 | self.linebuf = '' 76 | for line in temp_linebuf.splitlines(True): 77 | # From the io.TextIOWrapper docs: 78 | # On output, if newline is None, any '\n' characters written 79 | # are translated to the system default line separator. 80 | # By default sys.stdout.write() expects '\n' newlines and then 81 | # translates them so this is still cross platform. 82 | if line[-1] == '\n': 83 | self.logger.log(self.log_level, line.rstrip()) 84 | else: 85 | self.linebuf += line 86 | 87 | def flush(self): 88 | if self.linebuf != '': 89 | self.logger.log(self.log_level, self.linebuf.rstrip()) 90 | self.linebuf = '' 91 | 92 | 93 | def disable_torch_init(): 94 | """ 95 | Disable the redundant torch default initialization to accelerate model creation. 96 | """ 97 | import torch 98 | setattr(torch.nn.Linear, 'reset_parameters', lambda self: None) 99 | setattr(torch.nn.LayerNorm, 'reset_parameters', lambda self: None) 100 | 101 | 102 | def violates_moderation(text): 103 | """ 104 | Check whether the text violates OpenAI moderation API. 
105 | """ 106 | url = 'https://api.openai.com/v1/moderations' 107 | headers = {'Content-Type': 'application/json', 108 | 'Authorization': 'Bearer ' + os.environ['OPENAI_API_KEY']} 109 | text = text.replace('\n', '') 110 | data = '{' + '"input": ' + f'"{text}"' + '}' 111 | data = data.encode('utf-8') 112 | try: 113 | ret = requests.post(url, headers=headers, data=data, timeout=5) 114 | flagged = ret.json()['results'][0]['flagged'] 115 | except requests.exceptions.RequestException as e: 116 | flagged = False 117 | except KeyError as e: 118 | flagged = False 119 | 120 | return flagged 121 | 122 | 123 | def pretty_print_semaphore(semaphore): 124 | if semaphore is None: 125 | return 'None' 126 | return f'Semaphore(value={semaphore._value}, locked={semaphore.locked()})' 127 | -------------------------------------------------------------------------------- /vlm_model/internvl_v2pe/internvl/patch/llama_packed_training_patch.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from flash_attn.flash_attn_interface import flash_attn_varlen_func 3 | from transformers.models.llama.modeling_llama import (LLAMA_ATTENTION_CLASSES, 4 | LlamaFlashAttention2) 5 | 6 | 7 | # Modified from transformers.models.llama.modeling_llama.LlamaFlashAttention2 8 | class LlamaFlashAttention2ForPackedTraining(LlamaFlashAttention2): 9 | 10 | def _flash_attention_forward( 11 | self, 12 | query_states, 13 | key_states, 14 | value_states, 15 | attention_mask, 16 | query_length, 17 | dropout=0.0, 18 | softmax_scale=None, 19 | use_sliding_windows=False, 20 | ): 21 | """ 22 | Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token 23 | first unpad the input, then computes the attention scores and pad the final attention scores. 24 | 25 | Args: 26 | query_states (`torch.Tensor`): 27 | Input query states to be passed to Flash Attention API 28 | key_states (`torch.Tensor`): 29 | Input key states to be passed to Flash Attention API 30 | value_states (`torch.Tensor`): 31 | Input value states to be passed to Flash Attention API 32 | attention_mask (`torch.Tensor`): 33 | The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the 34 | position of padding tokens and 1 for the position of non-padding tokens. 35 | dropout (`int`, *optional*): 36 | Attention dropout 37 | softmax_scale (`float`, *optional*): 38 | The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim) 39 | use_sliding_windows (`bool`, *optional*): 40 | Whether to activate sliding window attention. 41 | """ 42 | assert query_states.size(0) == key_states.size(0) == value_states.size(0) == 1 43 | query_states = query_states.squeeze(0) 44 | key_states = key_states.squeeze(0) 45 | value_states = value_states.squeeze(0) 46 | cu_seqlens = attention_mask.squeeze(0) 47 | 48 | with torch.no_grad(): 49 | max_seqlen = max([ 50 | cu_seqlens[idx+1] - cu_seqlens[idx] 51 | for idx in range(cu_seqlens.size(0) - 1) 52 | ]).item() 53 | 54 | if not self._flash_attn_uses_top_left_mask: 55 | causal = self.is_causal 56 | else: 57 | # TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1. For details, please see the comment in LlamaFlashAttention2 __init__. 58 | causal = self.is_causal and query_length != 1 59 | 60 | # Decide whether to use SWA or not by layer index. 
61 | if use_sliding_windows and self.layer_idx >= self.config.max_window_layers: 62 | use_sliding_windows = False 63 | 64 | if not use_sliding_windows: 65 | attn_output = flash_attn_varlen_func( 66 | q=query_states, 67 | k=key_states, 68 | v=value_states, 69 | cu_seqlens_q=cu_seqlens, 70 | cu_seqlens_k=cu_seqlens, 71 | max_seqlen_q=max_seqlen, 72 | max_seqlen_k=max_seqlen, 73 | dropout_p=dropout, 74 | softmax_scale=softmax_scale, 75 | causal=causal, 76 | ) 77 | else: 78 | attn_output = flash_attn_varlen_func( 79 | q=query_states, 80 | k=key_states, 81 | v=value_states, 82 | cu_seqlens_q=cu_seqlens, 83 | cu_seqlens_k=cu_seqlens, 84 | max_seqlen_q=max_seqlen, 85 | max_seqlen_k=max_seqlen, 86 | dropout_p=dropout, 87 | softmax_scale=softmax_scale, 88 | causal=causal, 89 | window_size=(self.config.sliding_window, self.config.sliding_window), 90 | ) 91 | 92 | query_states = query_states.unsqueeze(0) 93 | key_states = key_states.unsqueeze(0) 94 | value_states = value_states.unsqueeze(0) 95 | return attn_output 96 | 97 | 98 | def replace_llama_attention_class(): 99 | LLAMA_ATTENTION_CLASSES['flash_attention_2'] = LlamaFlashAttention2ForPackedTraining 100 | print('Replace LLAMA_ATTENTION_CLASSES to support packed training!!') 101 | -------------------------------------------------------------------------------- /vlm_model/internvl_v2pe/internvl/patch/qwen2_packed_training_patch.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from flash_attn.flash_attn_interface import flash_attn_varlen_func 3 | from transformers.models.qwen2.modeling_qwen2 import (QWEN2_ATTENTION_CLASSES, 4 | Qwen2FlashAttention2) 5 | 6 | 7 | # Modified from transformers.models.qwen2.modeling_qwen2.Qwen2FlashAttention2 8 | class Qwen2FlashAttention2ForPackedTraining(Qwen2FlashAttention2): 9 | 10 | def _flash_attention_forward( 11 | self, 12 | query_states, 13 | key_states, 14 | value_states, 15 | attention_mask, 16 | query_length, 17 | dropout=0.0, 18 | softmax_scale=None, 19 | use_sliding_windows=False, 20 | ): 21 | """ 22 | Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token 23 | first unpad the input, then computes the attention scores and pad the final attention scores. 24 | 25 | Args: 26 | query_states (`torch.Tensor`): 27 | Input query states to be passed to Flash Attention API 28 | key_states (`torch.Tensor`): 29 | Input key states to be passed to Flash Attention API 30 | value_states (`torch.Tensor`): 31 | Input value states to be passed to Flash Attention API 32 | attention_mask (`torch.Tensor`): 33 | The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the 34 | position of padding tokens and 1 for the position of non-padding tokens. 35 | dropout (`int`, *optional*): 36 | Attention dropout 37 | softmax_scale (`float`, *optional*): 38 | The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim) 39 | use_sliding_windows (`bool`, *optional*): 40 | Whether to activate sliding window attention. 
41 | """ 42 | assert query_states.size(0) == key_states.size(0) == value_states.size(0) == 1 43 | query_states = query_states.squeeze(0) 44 | key_states = key_states.squeeze(0) 45 | value_states = value_states.squeeze(0) 46 | cu_seqlens = attention_mask.squeeze(0) 47 | 48 | with torch.no_grad(): 49 | max_seqlen = max([ 50 | cu_seqlens[idx+1] - cu_seqlens[idx] 51 | for idx in range(cu_seqlens.size(0) - 1) 52 | ]).item() 53 | 54 | if not self._flash_attn_uses_top_left_mask: 55 | causal = self.is_causal 56 | else: 57 | # TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1. For details, please see the comment in LlamaFlashAttention2 __init__. 58 | causal = self.is_causal and query_length != 1 59 | 60 | # Decide whether to use SWA or not by layer index. 61 | if use_sliding_windows and self.layer_idx >= self.config.max_window_layers: 62 | use_sliding_windows = False 63 | 64 | if not use_sliding_windows: 65 | attn_output = flash_attn_varlen_func( 66 | q=query_states, 67 | k=key_states, 68 | v=value_states, 69 | cu_seqlens_q=cu_seqlens, 70 | cu_seqlens_k=cu_seqlens, 71 | max_seqlen_q=max_seqlen, 72 | max_seqlen_k=max_seqlen, 73 | dropout_p=dropout, 74 | softmax_scale=softmax_scale, 75 | causal=causal, 76 | ) 77 | else: 78 | attn_output = flash_attn_varlen_func( 79 | q=query_states, 80 | k=key_states, 81 | v=value_states, 82 | cu_seqlens_q=cu_seqlens, 83 | cu_seqlens_k=cu_seqlens, 84 | max_seqlen_q=max_seqlen, 85 | max_seqlen_k=max_seqlen, 86 | dropout_p=dropout, 87 | softmax_scale=softmax_scale, 88 | causal=causal, 89 | window_size=(self.config.sliding_window, self.config.sliding_window), 90 | ) 91 | 92 | query_states = query_states.unsqueeze(0) 93 | key_states = key_states.unsqueeze(0) 94 | value_states = value_states.unsqueeze(0) 95 | return attn_output 96 | 97 | 98 | def replace_qwen2_attention_class(): 99 | QWEN2_ATTENTION_CLASSES['flash_attention_2'] = Qwen2FlashAttention2ForPackedTraining 100 | print('Replace QWEN2_ATTENTION_CLASSES to support packed training!!') 101 | -------------------------------------------------------------------------------- /vlm_model/nvila_2b_ef8fa9c8/media.py: -------------------------------------------------------------------------------- 1 | import glob 2 | import os 3 | from collections import defaultdict 4 | from typing import Any, Dict, List, Optional, Union 5 | 6 | import cv2 7 | import numpy as np 8 | import PIL 9 | import PIL.Image 10 | import requests 11 | from transformers import PretrainedConfig 12 | 13 | # from llava.constants import MEDIA_TOKENS 14 | # from llava.media import Image, Video 15 | # from llava.utils import make_list 16 | # from llava.utils.logging import logger 17 | 18 | MEDIA_TOKENS = { 19 | "image": "", 20 | "video": "", 21 | } 22 | 23 | 24 | class Media: 25 | pass 26 | 27 | 28 | class File(Media): 29 | def __init__(self, path: str) -> None: 30 | self.path = path 31 | 32 | 33 | class Image(File): 34 | pass 35 | 36 | 37 | class Video(File): 38 | pass 39 | 40 | 41 | def make_list(obj: Any) -> List: 42 | return obj if isinstance(obj, list) else [obj] 43 | 44 | 45 | def _extract_image(image: Union[Image, PIL.Image.Image]) -> PIL.Image.Image: 46 | if isinstance(image, Image): 47 | if image.path.startswith("http://") or image.path.startswith("https://"): 48 | image = PIL.Image.open(requests.get(image.path, stream=True).raw) 49 | else: 50 | image = PIL.Image.open(image.path) 51 | return image 52 | 53 | 54 | def _load_video(video_path: str, *, num_frames: int) -> List[PIL.Image.Image]: 55 | # Load video 
frames from a directory 56 | if os.path.isdir(video_path): 57 | frame_paths = sorted(glob.glob(os.path.join(video_path, "*"))) 58 | indices = np.round(np.linspace(0, len(frame_paths) - 1, num_frames)).astype(int) 59 | return [PIL.Image.open(frame_paths[index]) for index in indices] 60 | 61 | # Load video frames from a video file 62 | vidcap = cv2.VideoCapture(video_path) 63 | 64 | # Find the last frame as frame count might not be accurate 65 | frame_count = int(vidcap.get(cv2.CAP_PROP_FRAME_COUNT)) 66 | while frame_count > 0: 67 | vidcap.set(cv2.CAP_PROP_POS_FRAMES, frame_count - 1) 68 | if vidcap.grab(): 69 | break 70 | frame_count -= 1 71 | else: 72 | raise ValueError(f"Video '{video_path}' has no frames.") 73 | 74 | # Extract frames uniformly 75 | indices = np.round(np.linspace(0, frame_count - 1, num_frames)).astype(int) 76 | frames = {} 77 | for index in indices: 78 | if index in frames: 79 | continue 80 | vidcap.set(cv2.CAP_PROP_POS_FRAMES, index) 81 | success, frame = vidcap.read() 82 | if not success: 83 | print(f"Failed to read frame {index} from video '{video_path}'. Skipped.") 84 | continue 85 | frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) 86 | frames[index] = PIL.Image.fromarray(frame) 87 | return [frames[index] for index in indices if index in frames] 88 | 89 | 90 | def _extract_video(video: Video, config: PretrainedConfig) -> List[PIL.Image.Image]: 91 | num_frames = config.num_video_frames 92 | if getattr(config, "fps") != 0: 93 | print("Extracting frames from video with specified FPS is not supported yet. Ignored.") 94 | 95 | frames = _load_video(video.path, num_frames=num_frames) 96 | return frames 97 | 98 | 99 | def extract_media( 100 | messages: List[Dict[str, Any]], 101 | config: Optional[PretrainedConfig] = None, 102 | draft: bool = False, 103 | ) -> Dict[str, List[Any]]: 104 | media = defaultdict(list) 105 | for message in messages: 106 | text = "" 107 | for part in make_list(message["value"]): 108 | if isinstance(part, str): 109 | for token in MEDIA_TOKENS.values(): 110 | if token in part: 111 | print(f"Media token '{token}' found in text: '{part}'. 
Removed.") 112 | part = part.replace(token, "").strip() 113 | text += part 114 | elif isinstance(part, (Image, PIL.Image.Image)): 115 | if draft: 116 | media["image"].append(part) 117 | else: 118 | media["image"].append(_extract_image(part)) 119 | text += MEDIA_TOKENS["image"] 120 | elif isinstance(part, Video): 121 | if draft: 122 | media["video"].append(part) 123 | else: 124 | media["video"].append(_extract_video(part, config)) 125 | text += MEDIA_TOKENS["video"] 126 | else: 127 | raise ValueError(f"Unsupported prompt part type: {type(part)}") 128 | message["value"] = text 129 | return media 130 | -------------------------------------------------------------------------------- /vlm_model/nvila_8b_e2481b0c/media.py: -------------------------------------------------------------------------------- 1 | import glob 2 | import os 3 | from collections import defaultdict 4 | from typing import Any, Dict, List, Optional, Union 5 | 6 | import cv2 7 | import numpy as np 8 | import PIL 9 | import PIL.Image 10 | import requests 11 | from transformers import PretrainedConfig 12 | 13 | # from llava.constants import MEDIA_TOKENS 14 | # from llava.media import Image, Video 15 | # from llava.utils import make_list 16 | # from llava.utils.logging import logger 17 | 18 | MEDIA_TOKENS = { 19 | "image": "", 20 | "video": "", 21 | } 22 | 23 | 24 | class Media: 25 | pass 26 | 27 | 28 | class File(Media): 29 | def __init__(self, path: str) -> None: 30 | self.path = path 31 | 32 | 33 | class Image(File): 34 | pass 35 | 36 | 37 | class Video(File): 38 | pass 39 | 40 | 41 | def make_list(obj: Any) -> List: 42 | return obj if isinstance(obj, list) else [obj] 43 | 44 | 45 | def _extract_image(image: Union[Image, PIL.Image.Image]) -> PIL.Image.Image: 46 | if isinstance(image, Image): 47 | if image.path.startswith("http://") or image.path.startswith("https://"): 48 | image = PIL.Image.open(requests.get(image.path, stream=True).raw) 49 | else: 50 | image = PIL.Image.open(image.path) 51 | return image 52 | 53 | 54 | def _load_video(video_path: str, *, num_frames: int) -> List[PIL.Image.Image]: 55 | # Load video frames from a directory 56 | if os.path.isdir(video_path): 57 | frame_paths = sorted(glob.glob(os.path.join(video_path, "*"))) 58 | indices = np.round(np.linspace(0, len(frame_paths) - 1, num_frames)).astype(int) 59 | return [PIL.Image.open(frame_paths[index]) for index in indices] 60 | 61 | # Load video frames from a video file 62 | vidcap = cv2.VideoCapture(video_path) 63 | 64 | # Find the last frame as frame count might not be accurate 65 | frame_count = int(vidcap.get(cv2.CAP_PROP_FRAME_COUNT)) 66 | while frame_count > 0: 67 | vidcap.set(cv2.CAP_PROP_POS_FRAMES, frame_count - 1) 68 | if vidcap.grab(): 69 | break 70 | frame_count -= 1 71 | else: 72 | raise ValueError(f"Video '{video_path}' has no frames.") 73 | 74 | # Extract frames uniformly 75 | indices = np.round(np.linspace(0, frame_count - 1, num_frames)).astype(int) 76 | frames = {} 77 | for index in indices: 78 | if index in frames: 79 | continue 80 | vidcap.set(cv2.CAP_PROP_POS_FRAMES, index) 81 | success, frame = vidcap.read() 82 | if not success: 83 | print(f"Failed to read frame {index} from video '{video_path}'. 
Skipped.") 84 | continue 85 | frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) 86 | frames[index] = PIL.Image.fromarray(frame) 87 | return [frames[index] for index in indices if index in frames] 88 | 89 | 90 | def _extract_video(video: Video, config: PretrainedConfig) -> List[PIL.Image.Image]: 91 | num_frames = config.num_video_frames 92 | if getattr(config, "fps") != 0: 93 | print("Extracting frames from video with specified FPS is not supported yet. Ignored.") 94 | 95 | frames = _load_video(video.path, num_frames=num_frames) 96 | return frames 97 | 98 | 99 | def extract_media( 100 | messages: List[Dict[str, Any]], 101 | config: Optional[PretrainedConfig] = None, 102 | draft: bool = False, 103 | ) -> Dict[str, List[Any]]: 104 | media = defaultdict(list) 105 | for message in messages: 106 | text = "" 107 | for part in make_list(message["value"]): 108 | if isinstance(part, str): 109 | for token in MEDIA_TOKENS.values(): 110 | if token in part: 111 | print(f"Media token '{token}' found in text: '{part}'. Removed.") 112 | part = part.replace(token, "").strip() 113 | text += part 114 | elif isinstance(part, (Image, PIL.Image.Image)): 115 | if draft: 116 | media["image"].append(part) 117 | else: 118 | media["image"].append(_extract_image(part)) 119 | text += MEDIA_TOKENS["image"] 120 | elif isinstance(part, Video): 121 | if draft: 122 | media["video"].append(part) 123 | else: 124 | media["video"].append(_extract_video(part, config)) 125 | text += MEDIA_TOKENS["video"] 126 | else: 127 | raise ValueError(f"Unsupported prompt part type: {type(part)}") 128 | message["value"] = text 129 | return media 130 | -------------------------------------------------------------------------------- /vlm_model/minicpm.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from .model_utils import LLM 3 | 4 | from PIL import Image 5 | from transformers import AutoModel, AutoTokenizer 6 | 7 | 8 | class MiniCPMModel(LLM): 9 | def __init__( 10 | self, 11 | model_name, 12 | temperature=0.9, 13 | top_p=0.9, 14 | max_length=32768, 15 | generation_max_length=2048, 16 | generation_min_length=0, 17 | do_sample=True, 18 | stop_newline=False, 19 | use_chat_template=False, 20 | **kwargs, 21 | ): 22 | super().__init__( 23 | model_name, 24 | temperature=temperature, 25 | top_p=top_p, 26 | max_length=max_length, 27 | generation_max_length=generation_max_length, 28 | generation_min_length=generation_min_length, 29 | do_sample=do_sample, 30 | stop_newline=stop_newline, 31 | use_chat_template=use_chat_template, 32 | ) 33 | 34 | model_kwargs = {} 35 | model_kwargs["offload_state_dict"] = kwargs.get("offload_state_dict", False) 36 | model_kwargs["attn_implementation"] = kwargs.get("attn_implementation", "flash_attention_2") 37 | self.max_length = max_length 38 | self.processor = AutoTokenizer.from_pretrained(model_name, use_fast=True, trust_remote_code=True) 39 | tokenizer = self.processor 40 | if tokenizer.pad_token is None: 41 | tokenizer.pad_token = tokenizer.eos_token 42 | tokenizer.pad_token_id = tokenizer.eos_token_id 43 | tokenizer.truncation_side = "left" # we truncate elder history than recent one 44 | tokenizer.padding_side = "left" # batch generation needs left padding 45 | 46 | self.model = AutoModel.from_pretrained( 47 | model_name, 48 | torch_dtype=kwargs.get("torch_dtype", torch.bfloat16), 49 | device_map="auto", 50 | trust_remote_code=True, 51 | **model_kwargs 52 | ) 53 | 54 | if kwargs.get("torch_compile", True): 55 | self.model = torch.compile(self.model) 56 | 57 | # 
use the default if possible, append if necessary 58 | stop_token_ids = self.model.generation_config.eos_token_id 59 | stop_token_ids = [stop_token_ids] if not isinstance(stop_token_ids, list) else stop_token_ids 60 | if stop_newline: 61 | stop = list(set(["\n", "Ċ", "ĊĊ", "<0x0A>"])) 62 | stop_token_ids = list( 63 | set([tokenizer.convert_tokens_to_ids(stop_token) for stop_token in stop] + stop_token_ids)) 64 | if tokenizer.unk_token_id is not None and tokenizer.unk_token_id in stop_token_ids: 65 | stop_token_ids.remove(tokenizer.unk_token_id) 66 | stop_token_ids = [x for x in stop_token_ids if x is not None] 67 | self.stop_token_ids = stop_token_ids 68 | self.device = self.model.device 69 | 70 | def format_chat(self, text, image_list, system_prompt): 71 | new_content = [Image.open(image).convert('RGB') for image in image_list] 72 | text_content = text.replace("", "") 73 | new_content.append(text_content) 74 | messages = [{"role": "user", "content": new_content}] 75 | return messages 76 | 77 | 78 | def prepare_inputs(self, test_item, data): 79 | text = data["user_template"].format(**test_item) 80 | image_list = test_item["image_list"] 81 | inputs = self.format_chat(text, image_list, data["system_template"]) 82 | 83 | return inputs 84 | 85 | 86 | @torch.no_grad() 87 | def generate(self, inputs=None, prompt=None, **kwargs): 88 | text = self.model.chat( 89 | image=None, 90 | msgs=inputs, 91 | tokenizer=self.processor, 92 | max_new_tokens=self.generation_max_length, 93 | min_new_tokens=self.generation_min_length, 94 | do_sample=self.do_sample, 95 | temperature=self.temperature if self.do_sample else None, 96 | top_p=self.top_p if self.do_sample else None, 97 | top_k=None, 98 | eos_token_id=self.stop_token_ids, 99 | pad_token_id=self.processor.pad_token_id, 100 | ) 101 | 102 | save_prompt = [i if isinstance(i, str) else "" for i in inputs[0]["content"]] 103 | save_prompt = " ".join(save_prompt) 104 | if len(save_prompt) > 6000: 105 | save_prompt = save_prompt[:2000] + " " + save_prompt[-2000:] 106 | return { 107 | "output": text, 108 | "input_len": -1, 109 | "output_len": -1, 110 | "input_text": save_prompt, 111 | } -------------------------------------------------------------------------------- /figure_scripts/18_internvl2_V2PE.py: -------------------------------------------------------------------------------- 1 | from utils import * 2 | 3 | dataset_to_metrics["gov-report"] = ["rougeLsum_f1"] 4 | dataset_to_metrics["multi-lexsum"] = ["rougeLsum_f1"] 5 | custom_avgs["Summ"] = ['gov-report rougeLsum_f1', 'multi-lexsum rougeLsum_f1'] 6 | 7 | models_configs = [ 8 | {"model": "InternVL2-2B", "use_chat_template": True, "training_length": 8192}, 9 | {"model": "V2PE-256K_16", "use_chat_template": True, "training_length": 8192}, 10 | {"model": "V2PE-256K", "use_chat_template": True, "training_length": 8192}, 11 | {"model": "V2PE-256K_256", "use_chat_template": True, "training_length": 8192}, 12 | ] 13 | 14 | for model in models_configs: 15 | model["output_dir"] = f"/home/zhaowei.wang/data_dir/mmlb_result_backup/analysis_models/{model['model']}" 16 | 17 | models_configs[0]["output_dir"] = f"/home/zhaowei.wang/data_dir/mmlb_result/InternVL2-2B" 18 | 19 | model_name_replace.update({ 20 | "V2PE-256K_16": "V2PE (16)", 21 | "V2PE-256K": "V2PE (64)", 22 | "V2PE-256K_256": "V2PE (256)" 23 | }) 24 | 25 | curr_table_models = [ 26 | "InternVL2-2B", 27 | "V2PE (16)", 28 | "V2PE (64)", 29 | "V2PE (256)", 30 | ] 31 | 32 | new_dfs = [] 33 | 34 | for m_idx, model in enumerate(models_configs): 35 | args = arguments() 
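    # Sketch of one row appended to new_dfs below (the keys mirror the loop; the values
    # here are hypothetical, only for illustration):
    #     {"model": "V2PE-256K", "input_max_length": 32768,
    #      "metric name": "rougeLsum_f1", "metric": 23.4,
    #      "dataset_simple": "Summ rougeLsum_f1",
    #      "test_data": "gov-report-test-32768"}
    # These rows are what the pivot_table calls further down index by Model and
    # input_max_length.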
36 | for dataset in dataset_configs: 37 | args.update(dataset) 38 | args.update(model) 39 | 40 | # parse the metrics 41 | if dataset["dataset"] in {"multi-lexsum", "gov-report"}: 42 | path = args.get_path() 43 | path = path.replace("-gpt4eval_o.json", ".json") + ".score" 44 | with open(path) as f: 45 | results = json.load(f) 46 | metric = {"rougeLsum_f1": results["rougeLsum_f1"]} 47 | else: 48 | metric = args.get_averaged_metric() 49 | dsimple, mnames = args.get_metric_name() 50 | 51 | if metric is None: 52 | print("failed:", args.get_path()) # will be np.nan when using DataFrame 53 | continue 54 | for k, m in metric.items(): 55 | new_dfs.append({**asdict(args), **model, 56 | "metric name": k, "metric": m, 57 | "dataset_simple": dsimple + " " + k, 58 | "test_data": f"{args.dataset}-{args.test_name}-{args.input_max_length}" 59 | }) 60 | 61 | all_dfs = new_dfs 62 | all_df = pd.DataFrame(all_dfs) 63 | 64 | # Figure: positional extrploation methods 65 | main_table_datasets = [ 66 | "VRAG", 67 | "NIAH", 68 | 'ICL', 69 | "Summ", 70 | "DocVQA", 71 | "Ours" 72 | ] 73 | 74 | # plot specific ones in a row, for formatting in the paper 75 | lf_df = process_df(all_df) 76 | length_datasets = main_table_datasets 77 | 78 | ncols = 3 79 | nrows = (len(length_datasets) - 1) // ncols + 1 80 | custom_cmap = LinearSegmentedColormap.from_list('custom_cmap', ['#FFFFFF', '#4A895B']) 81 | 82 | fig, ax = plt.subplots(ncols=ncols, nrows=nrows, sharey=True, sharex=False) 83 | fig.set_size_inches((ncols * 8, 7)) # helmet has 45 models for 40 height, we have 46 models 84 | 85 | plt.rc('axes', unicode_minus=False) 86 | plt.rcParams.update({'axes.unicode_minus': False}) 87 | 88 | for i, dataset in enumerate(length_datasets): 89 | if nrows > 1: 90 | a = ax[i // ncols][i % ncols] 91 | else: 92 | a = ax[i] 93 | 94 | new_index = curr_table_models 95 | 96 | tdf = lf_df[lf_df.input_max_length > 4096] 97 | tdf = tdf.pivot_table(index="Model", columns="input_max_length", values=dataset) 98 | tdf = tdf.reindex(new_index) 99 | 100 | # process the scores 101 | annot_matrix = tdf.copy() 102 | tdf = tdf.applymap(lambda x: x if not pd.isna(x) else 0) 103 | annot_matrix = annot_matrix.applymap(lambda x: "N/A" if pd.isna(x) else f"{x:.1f}") 104 | 105 | sns_g = sns.heatmap( 106 | tdf, annot=annot_matrix, cmap=custom_cmap, fmt="", yticklabels=True, 107 | ax=a, annot_kws={"fontsize": 23.5}, 108 | cbar=False 109 | ) 110 | sns_g.set_title(dataset if dataset != "Ours" else "Avg.", fontsize=34) 111 | 112 | sns_g.set_ylabel("") 113 | sns_g.set_xlabel("") 114 | 115 | new_index = [x.replace("-Inst", '') for x in new_index] 116 | new_index = ["$\\diamond$ w/ Yarn" if "(Y)" in x else x for x in new_index ] 117 | 118 | sns_g.set_yticklabels(new_index, size=28) 119 | 120 | xticks = ['8k', '16k', '32k', '64k', '128k'] 121 | sns_g.set_xticklabels(xticks, size=28) 122 | 123 | # for idx in [6, 13, 27, 33, 36, 40, 43, 45]: 124 | # a.axhline(y=idx, color="white", linestyle="-", linewidth=4) 125 | 126 | [fig.delaxes(a) for a in ax.flatten() if not a.has_data()] 127 | 128 | plt.tight_layout() 129 | plt.subplots_adjust(left=0.17, wspace=0.15) 130 | figure_path = os.path.join(project_root, f"figures/18_internvl2_v2pe.pdf") 131 | plt.savefig(figure_path, dpi=500, format="pdf") 132 | plt.show() 133 | -------------------------------------------------------------------------------- /figure_scripts/18_qwen2_5_yarn.py: -------------------------------------------------------------------------------- 1 | from utils import * 2 | 3 | 4 | models_configs = [ 5 | {"model": 
"Qwen2.5-VL-3B-Instruct_yarn", "use_chat_template": True, "training_length": 32768}, 6 | {"model": "Qwen2.5-VL-7B-Instruct_yarn", "use_chat_template": True, "training_length": 32768}, 7 | {"model": 'Qwen2.5-VL-32B-Instruct_yarn', "use_chat_template": True, "training_length": 32768}, 8 | {"model": 'Qwen2.5-VL-72B-Instruct-AWQ_yarn', "use_chat_template": True, "training_length": 32768}, 9 | 10 | {"model": "V2PE-256K_16", "use_chat_template": True, "training_length": 8192}, 11 | {"model": "V2PE-256K", "use_chat_template": True, "training_length": 8192}, 12 | {"model": "V2PE-256K_256", "use_chat_template": True, "training_length": 8192}, 13 | ] 14 | 15 | for model in models_configs: 16 | model["output_dir"] = f"/home/zhaowei.wang/data_dir/mmlb_result_backup/analysis_models/{model['model']}" 17 | 18 | model_name_replace.update({ 19 | "Qwen2.5-VL-3B-Instruct_yarn": 'Qwen2.5-VL-3B-Inst (Y)', 20 | "Qwen2.5-VL-7B-Instruct_yarn": 'Qwen2.5-VL-7B-Inst (Y)', 21 | 'Qwen2.5-VL-32B-Instruct_yarn': 'Qwen2.5-VL-32B-Inst (Y)', 22 | 'Qwen2.5-VL-72B-Instruct-AWQ_yarn': 'Qwen2.5-VL-72B-Inst (Y)', 23 | "V2PE-256K_16": "V2PE (16)", 24 | "V2PE-256K": "V2PE (64)", 25 | "V2PE-256K_256": "V2PE (256)" 26 | }) 27 | 28 | curr_table_models = [ 29 | "Qwen2.5-VL-3B-Inst", 30 | "Qwen2.5-VL-3B-Inst (Y)", 31 | "Qwen2.5-VL-7B-Inst", 32 | "Qwen2.5-VL-7B-Inst (Y)", 33 | 'Qwen2.5-VL-32B-Inst', 34 | "Qwen2.5-VL-32B-Inst (Y)", 35 | 'Qwen2.5-VL-72B-Inst', 36 | "Qwen2.5-VL-72B-Inst (Y)", 37 | # "V2PE (16)", 38 | # "V2PE (64)", 39 | # "V2PE (256)", 40 | ] 41 | 42 | new_dfs = [] 43 | 44 | for m_idx, model in enumerate(models_configs): 45 | args = arguments() 46 | for dataset in dataset_configs: 47 | args.update(dataset) 48 | args.update(model) 49 | 50 | # parse the metrics 51 | metric = args.get_averaged_metric() 52 | dsimple, mnames = args.get_metric_name() 53 | 54 | if metric is None: 55 | print("failed:", args.get_path()) # will be np.nan when using DataFrame 56 | continue 57 | for k, m in metric.items(): 58 | new_dfs.append({**asdict(args), **model, 59 | "metric name": k, "metric": m, 60 | "dataset_simple": dsimple + " " + k, 61 | "test_data": f"{args.dataset}-{args.test_name}-{args.input_max_length}" 62 | }) 63 | 64 | all_dfs = dfs + new_dfs 65 | all_df = pd.DataFrame(all_dfs) 66 | 67 | # Figure: positional extrploation methods 68 | main_table_datasets = [ 69 | "VRAG", 70 | "NIAH", 71 | 'ICL', 72 | "Summ", 73 | "DocVQA", 74 | "Ours" 75 | ] 76 | 77 | # plot specific ones in a row, for formatting in the paper 78 | lf_df = process_df(all_df) 79 | length_datasets = main_table_datasets 80 | 81 | ncols = 3 82 | nrows = (len(length_datasets) - 1) // ncols + 1 83 | custom_cmap = LinearSegmentedColormap.from_list('custom_cmap', ['#FFFFFF', '#4A895B']) 84 | 85 | fig, ax = plt.subplots(ncols=ncols, nrows=nrows, sharey=True, sharex=False) 86 | fig.set_size_inches((ncols * 8, 11)) # helmet has 45 models for 40 height, we have 46 models 87 | 88 | plt.rc('axes', unicode_minus=False) 89 | plt.rcParams.update({'axes.unicode_minus': False}) 90 | 91 | for i, dataset in enumerate(length_datasets): 92 | if nrows > 1: 93 | a = ax[i // ncols][i % ncols] 94 | else: 95 | a = ax[i] 96 | 97 | new_index = curr_table_models 98 | 99 | tdf = lf_df[lf_df.input_max_length > 4096] 100 | tdf = tdf.pivot_table(index="Model", columns="input_max_length", values=dataset) 101 | tdf = tdf.reindex(new_index) 102 | 103 | # process the scores 104 | annot_matrix = tdf.copy() 105 | tdf = tdf.applymap(lambda x: x if not pd.isna(x) else 0) 106 | annot_matrix = 
annot_matrix.applymap(lambda x: "N/A" if pd.isna(x) else f"{x:.1f}") 107 | 108 | sns_g = sns.heatmap( 109 | tdf, annot=annot_matrix, cmap=custom_cmap, fmt="", yticklabels=True, 110 | ax=a, annot_kws={"fontsize": 23.5}, 111 | cbar=False 112 | ) 113 | sns_g.set_title(dataset if dataset != "Ours" else "Avg.", fontsize=34) 114 | 115 | sns_g.set_ylabel("") 116 | sns_g.set_xlabel("") 117 | 118 | new_index = [x.replace("-Inst", '') for x in new_index] 119 | new_index = ["$\\diamond$ w/ Yarn" if "(Y)" in x else x for x in new_index ] 120 | 121 | sns_g.set_yticklabels(new_index, size=28) 122 | 123 | xticks = ['8k', '16k', '32k', '64k', '128k'] 124 | sns_g.set_xticklabels(xticks, size=28) 125 | 126 | # for idx in [6, 13, 27, 33, 36, 40, 43, 45]: 127 | # a.axhline(y=idx, color="white", linestyle="-", linewidth=4) 128 | 129 | [fig.delaxes(a) for a in ax.flatten() if not a.has_data()] 130 | 131 | plt.tight_layout() 132 | plt.subplots_adjust(left=0.17, wspace=0.15) 133 | figure_path = os.path.join(project_root, f"figures/18_qwen2_5_yarn.pdf") 134 | plt.savefig(figure_path, dpi=500, format="pdf") 135 | plt.show() 136 | -------------------------------------------------------------------------------- /vlm_model/internvl_v2pe/internvl/patch/train_sampler_patch.py: -------------------------------------------------------------------------------- 1 | from typing import List, Optional 2 | 3 | import torch 4 | import transformers 5 | from torch.utils.data import Dataset, Sampler 6 | from transformers.tokenization_utils_base import BatchEncoding 7 | from transformers.trainer import (LengthGroupedSampler, RandomSampler, 8 | has_length) 9 | from transformers.trainer_pt_utils import logger 10 | 11 | 12 | # copy from https://github.com/haotian-liu/LLaVA/blob/main/llava/train/llava_trainer.py#L38 13 | def split_to_even_chunks(indices, lengths, num_chunks): 14 | """ 15 | Split a list of indices into `chunks` chunks of roughly equal lengths. 16 | """ 17 | 18 | if len(indices) % num_chunks != 0: 19 | return [indices[i::num_chunks] for i in range(num_chunks)] 20 | 21 | num_indices_per_chunk = len(indices) // num_chunks 22 | 23 | chunks = [[] for _ in range(num_chunks)] 24 | chunks_lengths = [0 for _ in range(num_chunks)] 25 | for index in indices: 26 | shortest_chunk = chunks_lengths.index(min(chunks_lengths)) 27 | chunks[shortest_chunk].append(index) 28 | chunks_lengths[shortest_chunk] += lengths[index] 29 | if len(chunks[shortest_chunk]) == num_indices_per_chunk: 30 | chunks_lengths[shortest_chunk] = float('inf') 31 | 32 | return chunks 33 | 34 | 35 | # copy from https://github.com/haotian-liu/LLaVA/blob/main/llava/train/llava_trainer.py#L88 36 | def get_length_grouped_indices(lengths, batch_size, world_size, generator=None, merge=True): 37 | # We need to use torch for the random part as a distributed sampler will set the random seed for torch. 
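    # Worked example (illustrative values): with batch_size=2 and world_size=2 the
    # megabatch size is 4. For per-sample lengths [7, 3, 9, 2, 5, 8, 1, 6], a random
    # permutation is cut into megabatches of 4 indices, each megabatch is sorted by
    # length (longest first), and split_to_even_chunks balances it across the 2 ranks:
    # a megabatch with lengths [9, 8, 7, 3] becomes per-rank chunks [9, 3] and [8, 7],
    # so both ranks see a similar number of tokens in that step.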
38 | indices = torch.randperm(len(lengths), generator=generator) 39 | megabatch_size = world_size * batch_size 40 | megabatches = [indices[i : i + megabatch_size].tolist() for i in range(0, len(lengths), megabatch_size)] 41 | megabatches = [sorted(megabatch, key=lambda i: lengths[i], reverse=True) for megabatch in megabatches] 42 | megabatches = [split_to_even_chunks(megabatch, lengths, world_size) for megabatch in megabatches] 43 | 44 | return [i for megabatch in megabatches for batch in megabatch for i in batch] 45 | 46 | 47 | # modified from https://github.com/haotian-liu/LLaVA/blob/main/llava/train/llava_trainer.py#L99 48 | class LengthGroupedSampler(Sampler): 49 | r""" 50 | Sampler that samples indices in a way that groups together features of the dataset of roughly the same length while 51 | keeping a bit of randomness. 52 | """ 53 | 54 | def __init__( 55 | self, 56 | batch_size: int, 57 | world_size: int, 58 | dataset: Optional[Dataset] = None, 59 | lengths: Optional[List[int]] = None, 60 | model_input_name: Optional[str] = None, 61 | generator=None, 62 | ): 63 | if dataset is None and lengths is None: 64 | raise ValueError('One of dataset and lengths must be provided.') 65 | 66 | self.batch_size = batch_size 67 | if lengths is None: 68 | model_input_name = model_input_name if model_input_name is not None else 'input_ids' 69 | if ( 70 | not (isinstance(dataset[0], dict) or isinstance(dataset[0], BatchEncoding)) 71 | or model_input_name not in dataset[0] 72 | ): 73 | raise ValueError( 74 | 'Can only automatically infer lengths for datasets whose items are dictionaries with an ' 75 | f"'{model_input_name}' key." 76 | ) 77 | lengths = [len(feature[model_input_name]) for feature in dataset] 78 | elif isinstance(lengths, torch.Tensor): 79 | logger.info( 80 | 'If lengths is a torch.Tensor, LengthGroupedSampler will be slow. Converting lengths to List[int]...' 81 | ) 82 | lengths = lengths.tolist() 83 | self.world_size = world_size 84 | self.lengths = lengths 85 | self.generator = generator 86 | 87 | def __len__(self): 88 | return len(self.lengths) 89 | 90 | def __iter__(self): 91 | indices = get_length_grouped_indices(self.lengths, self.batch_size, self.world_size, generator=self.generator) 92 | return iter(indices) 93 | 94 | 95 | # patch trainer 96 | def _get_train_sampler(self) -> Optional[torch.utils.data.Sampler]: 97 | if self.train_dataset is None or not has_length(self.train_dataset): 98 | return None 99 | # Build the sampler. 
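    # Intended usage (assumed, not shown in this file): call replace_train_sampler()
    # before constructing transformers.Trainer with group_by_length=True; the per-sample
    # token counts are then gathered from each sub-dataset's `length` attribute below and
    # handed to the LengthGroupedSampler above.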
100 | if self.args.group_by_length: 101 | lengths = [] 102 | for dataset in self.train_dataset.datasets: 103 | lengths = lengths + dataset.length 104 | model_input_name = self.tokenizer.model_input_names[0] if self.tokenizer is not None else None 105 | return LengthGroupedSampler( 106 | self.args.train_batch_size, 107 | world_size=self.args.world_size * self.args.gradient_accumulation_steps, 108 | # self.args.train_batch_size * self.args.gradient_accumulation_steps, 109 | dataset=self.train_dataset, 110 | lengths=lengths, 111 | model_input_name=model_input_name, 112 | ) 113 | else: 114 | return RandomSampler(self.train_dataset) 115 | 116 | 117 | def replace_train_sampler(): 118 | transformers.Trainer._get_train_sampler = _get_train_sampler 119 | print('Replace train sampler!!') 120 | -------------------------------------------------------------------------------- /vlm_model/qwen_vl.py: -------------------------------------------------------------------------------- 1 | import re 2 | import torch 3 | from copy import deepcopy 4 | from .model_utils import LLM 5 | from transformers import AutoModelForCausalLM, AutoTokenizer 6 | 7 | 8 | class QwenVLModel(LLM): 9 | def __init__( 10 | self, 11 | model_name, 12 | temperature=0.9, 13 | top_p=0.9, 14 | max_length=32768, 15 | generation_max_length=2048, 16 | generation_min_length=0, 17 | do_sample=True, 18 | stop_newline=False, 19 | use_chat_template=False, 20 | **kwargs, 21 | ): 22 | super().__init__( 23 | model_name, 24 | temperature=temperature, 25 | top_p=top_p, 26 | max_length=max_length, 27 | generation_max_length=generation_max_length, 28 | generation_min_length=generation_min_length, 29 | do_sample=do_sample, 30 | stop_newline=stop_newline, 31 | use_chat_template=use_chat_template, 32 | ) 33 | 34 | model_kwargs = {} 35 | model_kwargs["offload_state_dict"] = kwargs.get("offload_state_dict", False) 36 | model_kwargs["use_flash_attn"] = True 37 | # Qwen-VL usese flash-attention by default 38 | self.max_length = max_length 39 | self.tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True) 40 | 41 | if self.tokenizer.pad_token is None: 42 | self.tokenizer.pad_token = self.tokenizer.eos_token 43 | self.tokenizer.pad_token_id = self.tokenizer.eos_token_id 44 | self.tokenizer.truncation_side = "left" # we truncate elder history than recent one 45 | self.tokenizer.padding_side = "left" # batch generation needs left padding 46 | 47 | torch_dtype = kwargs.get("torch_dtype", torch.bfloat16) 48 | if torch_dtype == "int8": 49 | raise ValueError("dtype doesn't support int8") 50 | elif torch_dtype == "int4": 51 | # model_kwargs["load_in_4bit"] = True 52 | # Please use Qwen-VL-Chat-Int4 53 | assert model_name == "Qwen/Qwen-VL-Chat-Int4" 54 | torch_dtype = None 55 | 56 | self.model = AutoModelForCausalLM.from_pretrained( 57 | model_name, 58 | torch_dtype=torch_dtype, 59 | device_map="auto", 60 | trust_remote_code=True, 61 | **model_kwargs 62 | ) 63 | 64 | # if kwargs.get("torch_compile", True): 65 | # self.model = torch.compile(self.model) 66 | 67 | # use the default if possible, append if necessary 68 | stop_token_ids = self.model.generation_config.eos_token_id 69 | stop_token_ids = [stop_token_ids] if not isinstance(stop_token_ids, list) else stop_token_ids 70 | if stop_newline: 71 | stop = list(set(["\n", "Ċ", "ĊĊ", "<0x0A>"])) 72 | stop_token_ids = list( 73 | set([self.tokenizer.convert_tokens_to_ids(stop_token) for stop_token in stop] + stop_token_ids)) 74 | if self.tokenizer.unk_token_id is not None and self.tokenizer.unk_token_id 
in stop_token_ids: 75 | stop_token_ids.remove(self.tokenizer.unk_token_id) 76 | stop_token_ids = [x for x in stop_token_ids if x is not None] 77 | self.stop_token_ids = stop_token_ids 78 | self.device = self.model.device 79 | self.processor = self.tokenizer 80 | 81 | def format_chat(self, text, image_list, system_prompt): 82 | content = re.split(r'()', text) 83 | image_idx, new_content = 0, [] 84 | for c in content: 85 | if c == "": 86 | new_content.append({ 87 | "image": image_list[image_idx] 88 | }) 89 | image_idx += 1 90 | else: 91 | new_content.append({ 92 | "text": c 93 | }) 94 | assert image_idx == len(image_list) 95 | return new_content 96 | 97 | 98 | def prepare_inputs(self, test_item, data): 99 | text = data["user_template"].format(**test_item) 100 | image_list = test_item["image_list"] 101 | messages = self.format_chat(text, image_list, data["system_template"]) 102 | 103 | inputs = self.tokenizer.from_list_format(messages) 104 | 105 | return inputs 106 | 107 | 108 | @torch.no_grad() 109 | def generate(self, inputs=None, prompt=None, **kwargs): 110 | input_len = -1 111 | generation_config = deepcopy(self.model.generation_config) 112 | generation_config.max_new_tokens = self.generation_max_length 113 | generation_config.do_sample = self.do_sample 114 | generation_config.temperature = self.temperature if self.do_sample else None 115 | generation_config.top_p = self.top_p if self.do_sample else None 116 | generation_config.top_k = None 117 | generation_config.eos_token_id = self.stop_token_ids 118 | 119 | text, history = self.model.chat(self.tokenizer, query=inputs, history=None, 120 | generation_config=generation_config) 121 | 122 | save_prompt = inputs 123 | if len(save_prompt) > 6000: 124 | save_prompt = save_prompt[:2000] + " " + save_prompt[-2000:] 125 | return { 126 | "output": text, 127 | "input_len": input_len, 128 | "output_len": -1, 129 | "input_text": save_prompt, 130 | } -------------------------------------------------------------------------------- /arguments.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import yaml 3 | import ast 4 | import os 5 | 6 | def parse_arguments(): 7 | parser = argparse.ArgumentParser(description="evaluation on downstream tasks") 8 | parser.add_argument("--config", type=str, default=None, help="path to config file") 9 | 10 | # model setting 11 | parser.add_argument("--model_name_or_path", type=str, default=None) 12 | parser.add_argument("--use_vllm", action="store_true", help="whether to use vllm engine") 13 | parser.add_argument("--attn_implementation", type=str, default=None, help="Implementation of self-attention. 
None means using the default (flash_attention_2 for most models).")
14 |
15 |     # data paths
16 |     parser.add_argument("--datasets", type=str, default=None)
17 |     parser.add_argument("--test_file_root", type=str, default=None)
18 |     parser.add_argument("--image_file_root", type=str, default=None)
19 |     parser.add_argument("--test_files", type=str, default=None)
20 |     parser.add_argument("--output_dir", type=str, default=None, help="path to save the predictions")
21 |     parser.add_argument("--overwrite", action="store_true", help="whether to overwrite the saved file")
22 |     parser.add_argument("--max_test_samples", type=int, default=None)
23 |     parser.add_argument("--num_workers", type=int, default=32)
24 |     parser.add_argument("--preprocessing_num_workers", type=int, default=8)
25 |
26 |     # evaluation settings
27 |     parser.add_argument("--input_max_length", type=str, default='8192', help="the maximum number of tokens of the input; we truncate the end of the context; can be separated by comma to match the specified datasets")
28 |     parser.add_argument("--test_length", type=str, default="4,8,16,32,64,128", help="list of the lengths to be tested.")
29 |
30 |     # generation settings
31 |     parser.add_argument("--do_sample", type=ast.literal_eval, choices=[True, False], default=False, help="whether to use sampling (false is greedy), overwrites temperature")
32 |     parser.add_argument("--generation_max_length", type=str, default='10', help="max number of tokens to generate, can be separated by comma to match the specified datasets")
33 |     parser.add_argument("--generation_min_length", type=int, default=0, help="min number of tokens to generate")
34 |     parser.add_argument("--temperature", type=float, default=1.0, help="generation temperature")
35 |     parser.add_argument("--top_p", type=float, default=1.0, help="top-p parameter for nucleus sampling")
36 |     parser.add_argument("--stop_newline", type=ast.literal_eval, choices=[True, False], default=False, help="whether to stop generation at newline")
37 |     parser.add_argument("--do_prefill", action="store_true", help="prefill the context to save memory")
38 |
39 |     # model specific settings
40 |     parser.add_argument("--seed", type=int, default=42, help="random seed")
41 |     parser.add_argument("--no_cuda", action="store_true", help="disable cuda")
42 |     parser.add_argument("--no_bf16", action="store_true", help="disable bf16 and use fp32")
43 |     parser.add_argument("--load_in_8bit", action="store_true", help="int8 mode")
44 |     parser.add_argument("--no_torch_compile", action="store_true", help="disable torch.compile")
45 |     parser.add_argument("--use_chat_template", type=ast.literal_eval, choices=[True, False], default=True, help="whether to use chat template")
46 |     parser.add_argument("--rope_theta", type=int, default=None, help="override rope theta")
47 |     parser.add_argument("--use_yarn", action="store_true", help="enable the YaRN RoPE extension")
48 |     parser.add_argument("--do_image_splitting", type=str, choices=["True", "False", "None"], default="None", help="whether to use image splitting for Idefics2 and Mantis (True, False, or None to use model default)")
49 |     parser.add_argument("--offload_state_dict", action="store_true", help="load the model with state dict offloading")
50 |     parser.add_argument("--image_resize", type=float, default=None, help="Image scaling factor, where 1.0 means original size and 0.5 means half the original size")
51 |     parser.add_argument("--max_image_num", type=int, default=None, help="the max image number for models with dynamic cropping (e.g., internvl1.5/2/2.5, phi3/3.5)")
52 |
parser.add_argument("--vision_batch_size", type=int, default=None, help="the batch size for Pixtral's and Ovis2's vision tower since its implementation has O(N^2) memory cost (N is the image number)") 53 | parser.add_argument("--api_sleep", type=int, default=None, help="the sleep time for API models after each call") 54 | parser.add_argument("--max_image_size", type=int, default=None, help="Max image size for Gemini to prevent over resizing and splitting") 55 | parser.add_argument("--image_detail", type=str, choices=["high", "low", "auto"], default="auto", help="Image detail for OpenAI models") 56 | parser.add_argument("--batch_size", type=int, default=4, help="inference batch size. This is only effective for OpenAI models now!") 57 | parser.add_argument("--v2pe_step", type=int, default=64, help="the increment size for visual tokens in V2PE") 58 | 59 | # misc 60 | parser.add_argument("--debug", action="store_true", help="for debugging") 61 | parser.add_argument("--count_tokens", action="store_true", help="instead of running generation, just count the number of tokens (only for HF models not API)") 62 | parser.add_argument("--dry_run", action="store_true", help="Test the data loading speed.") 63 | 64 | args = parser.parse_args() 65 | config = yaml.safe_load(open(args.config)) if args.config is not None else {} 66 | parser.set_defaults(**config) 67 | args = parser.parse_args() 68 | 69 | if args.output_dir is None: 70 | args.output_dir = f"output/{os.path.basename(args.model_name_or_path)}" 71 | 72 | if args.rope_theta is not None: 73 | args.output_dir = args.output_dir + f"-override-rope{args.rope_theta}" 74 | 75 | return args -------------------------------------------------------------------------------- /figure_scripts/12_metrics_for_summ.py: -------------------------------------------------------------------------------- 1 | from utils import * 2 | 3 | # plot specific ones in a row, for formatting in the paper 4 | lf_df = process_df(all_df) 5 | dataset_names = ["GovReport", "Multi-LexSum"] 6 | length_datasets = {"gov-report rougeLsum_f1": "ROUGE-L", 7 | "GovReport": "GPT-4o Eval", 8 | "multi-lexsum rougeLsum_f1": "ROUGE-L", 9 | "Multi-LexSum": "GPT-4o Eval"} 10 | 11 | summ_dataset_rouge = {"gov-report": "rougeLsum_f1", 12 | "multi-lexsum": "rougeLsum_f1"} 13 | 14 | dataset_to_metrics["gov-report"] = ["rougeLsum_f1"] 15 | dataset_to_metrics["multi-lexsum"] = ["rougeLsum_f1"] 16 | 17 | rouge_dfs = [] 18 | 19 | for m_idx, model in enumerate(models_configs): 20 | args = arguments() 21 | for dataset in dataset_configs: 22 | if dataset["dataset"] not in summ_dataset_rouge: 23 | continue 24 | args.update(dataset) 25 | args.update(model) 26 | 27 | # parse the metrics 28 | metric = args.get_averaged_metric() 29 | dsimple, mnames = args.get_metric_name() 30 | 31 | if metric is None: 32 | print("failed:", args.get_path()) # will be np.nan when using DataFrame 33 | continue 34 | for k, m in metric.items(): 35 | rouge_dfs.append({**asdict(args), **model, 36 | "metric name": k, "metric": m, 37 | "dataset_simple": dsimple + " " + k, 38 | "test_data": f"{args.dataset}-{args.test_name}-{args.input_max_length}" 39 | }) 40 | 41 | rouge_df = pd.DataFrame(rouge_dfs) 42 | import utils 43 | utils.custom_avgs = {} 44 | rouge_df = process_df(rouge_df) 45 | 46 | lf_df = pd.merge( 47 | lf_df, 48 | rouge_df, 49 | on=["input_max_length", "Model"], 50 | how="outer", 51 | suffixes=('', '_rouge') 52 | ) 53 | # duplicated_cols = lf_df.columns.duplicated() 54 | # lf_df = lf_df.loc[:, ~duplicated_cols] 55 | 56 | new_index 
= [ 57 | 'GPT-4o', 58 | 'InternVL2.5-4B', 'InternVL2.5-8B', 59 | 'InternVL3-2B', 'InternVL3-14B', 60 | 'Gemma3-4B', 'Gemma3-12B', 'Gemma3-27B', 61 | ] 62 | 63 | ncols = 4 64 | nrows = (len(length_datasets) - 1) // ncols + 1 65 | 66 | fig, ax = plt.subplots(ncols=ncols, nrows=nrows, sharey=True, sharex=False) 67 | fig.set_size_inches((ncols * 6, nrows * 6)) 68 | plt.subplots_adjust(left=0.15, top=0.83, right=0.99, wspace=0.07) 69 | 70 | title_y_pos = 0.95 71 | left_pos = (ax[0].get_position().x0 + ax[1].get_position().x1) / 2 72 | right_pos = (ax[2].get_position().x0 + ax[3].get_position().x1) / 2 73 | fig.text(left_pos, title_y_pos, 'GovReport', fontsize=30, ha='center', va='center', fontweight='bold') 74 | fig.text(right_pos, title_y_pos, 'Multi-LexSum', fontsize=30, ha='center', va='center', fontweight='bold') 75 | 76 | line_y = title_y_pos - 0.03 77 | axes_row = ax if nrows == 1 else ax[0, :] 78 | transform = fig.transFigure 79 | fig.add_artist(plt.Line2D([axes_row[0].get_position().x0, axes_row[1].get_position().x1], 80 | [line_y, line_y], color='black', linewidth=1.0, transform=transform)) 81 | fig.add_artist(plt.Line2D([axes_row[2].get_position().x0, axes_row[3].get_position().x1], 82 | [line_y, line_y], color='black', linewidth=1.0, transform=transform)) 83 | 84 | 85 | 86 | 87 | cmap = LinearSegmentedColormap.from_list("custom_cmap", ["#ed4d6e", '#DD9380', '#DEA683', '#CFCC86', "#0CD79F"]) 88 | 89 | group_min_list = [float('inf')] * len(length_datasets) 90 | group_max_list = [float('-inf')] * len(length_datasets) 91 | 92 | for d_idx, dataset_col in enumerate(length_datasets.keys()): 93 | tdf = lf_df[lf_df.input_max_length > 4096] 94 | tdf = tdf.pivot_table(index="Model", columns="input_max_length", values=dataset_col) 95 | tdf = tdf.reindex(new_index) 96 | 97 | current_min = tdf.min().min() 98 | current_max = tdf.max().max() 99 | 100 | group_min_list[d_idx] = min(group_min_list[d_idx], current_min) 101 | group_max_list[d_idx] = max(group_max_list[d_idx], current_max) 102 | 103 | group_size = 2 104 | for g_idx in range(0, len(length_datasets), group_size): 105 | group_min = min(group_min_list[g_idx: g_idx + group_size]) 106 | group_max = max(group_max_list[g_idx: g_idx + group_size]) 107 | for _ in range(g_idx, g_idx + group_size): 108 | group_min_list[_] = group_min 109 | group_max_list[_] = group_max 110 | 111 | for i, (dataset_col, dataset_title) in enumerate(length_datasets.items()): 112 | if nrows > 1: 113 | a = ax[i // ncols][i % ncols] 114 | else: 115 | a = ax[i] 116 | 117 | tdf = lf_df[lf_df.input_max_length > 4096] 118 | tdf = tdf.pivot_table(index="Model", columns="input_max_length", values=dataset_col) 119 | tdf = tdf.reindex(new_index) 120 | 121 | import matplotlib.colors as mcolors 122 | sns_g = sns.heatmap( 123 | tdf, annot=True, cmap=cmap, fmt=".1f", yticklabels=True, 124 | ax=a, annot_kws={"fontsize": 21}, vmax=group_max_list[i], vmin=group_min_list[i], 125 | cbar=False, norm=mcolors.PowerNorm(gamma=0.5, vmin=group_min_list[i], vmax=group_max_list[i]) 126 | ) 127 | sns_g.set_title(dataset_title, fontsize=30) 128 | 129 | sns_g.set_ylabel("") 130 | sns_g.set_xlabel("") 131 | 132 | written_index = [x.replace("-Inst", '') for x in tdf.index.tolist()] 133 | 134 | sns_g.set_yticklabels(written_index, size=29) 135 | xticks = ['8k', '16k', '32k', '64k', '128k'] 136 | sns_g.set_xticklabels(xticks, size=29) 137 | for idx in [1, 3, 5]: 138 | a.axhline(y=idx, color="white", linestyle="-", linewidth=4) 139 | 140 | [fig.delaxes(a) for a in ax.flatten() if not a.has_data()] 141 | 
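# Note on the colour scaling above (worked example with illustrative numbers):
# PowerNorm(gamma=0.5) maps a score x to ((x - vmin) / (vmax - vmin)) ** 0.5 before
# colouring, so with vmin=0 and vmax=40 a score of 10 sits at sqrt(0.25) = 0.5 of the
# colour range rather than 0.25, which stretches the resolution among low scores.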
142 | # plt.tight_layout() 143 | figure_path = os.path.join(project_root, f"figures/12_results_length_model_eval.pdf") 144 | plt.savefig(figure_path, dpi=500, format="pdf") 145 | plt.show() -------------------------------------------------------------------------------- /vlm_model/__init__.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | def load_LLM(args): 5 | kwargs = {} 6 | if "gpt" in args.model_name_or_path or "claude" in args.model_name_or_path: # we use gpt compatible api 7 | from .openai_model import OpenAIModel 8 | model_cls = OpenAIModel 9 | elif "gemini" in args.model_name_or_path: 10 | from .gemini import GeminiModel 11 | model_cls = GeminiModel 12 | else: 13 | # HF models 14 | lower_model_name = args.model_name_or_path.lower().split("/")[-1] 15 | if "qwen2-vl" in lower_model_name or "qwen2.5-vl" in lower_model_name or "qvq-" in lower_model_name: 16 | from .qwen2_vl import Qwen2VLModel 17 | model_cls = Qwen2VLModel 18 | elif "qwen-vl" in lower_model_name: # too many tokens for image 19 | from .qwen_vl import QwenVLModel 20 | model_cls = QwenVLModel 21 | elif "llama-3.2" in lower_model_name: # too many tokens for image 22 | from .mllama import MLlamaModel 23 | model_cls = MLlamaModel 24 | elif "deepseek-vl2" in lower_model_name: # hard code integration 25 | from .deepseek_vl2 import DeepseekVL2Model 26 | model_cls = DeepseekVL2Model 27 | elif "idefics" in lower_model_name: 28 | from .idefics import IdeficsModel 29 | model_cls = IdeficsModel 30 | if args.do_image_splitting != "None": 31 | # assert "idefics2" in lower_model_name 32 | kwargs["do_image_splitting"] = args.do_image_splitting == "True" 33 | if args.vision_batch_size is not None: 34 | kwargs["vision_batch_size"] = args.vision_batch_size 35 | elif "phi-4" in lower_model_name: 36 | from .phi4 import Phi4Model 37 | model_cls = Phi4Model 38 | elif "phi-3" in lower_model_name: 39 | from .phi3 import Phi3Model 40 | model_cls = Phi3Model 41 | elif "v2pe" in lower_model_name: 42 | from .internv2pe import InternV2PEModel 43 | if args.vision_batch_size is not None: 44 | kwargs["vision_batch_size"] = args.vision_batch_size 45 | if args.v2pe_step is not None: 46 | kwargs["v2pe_step"] = args.v2pe_step 47 | model_cls = InternV2PEModel 48 | elif "internvl" in lower_model_name: 49 | from .internvl import InternVLModel 50 | if args.vision_batch_size is not None: 51 | kwargs["vision_batch_size"] = args.vision_batch_size 52 | model_cls = InternVLModel 53 | elif "pixtral" in lower_model_name: 54 | from .pixtral import PixtralModel 55 | if args.vision_batch_size is not None: 56 | kwargs["vision_batch_size"] = args.vision_batch_size 57 | model_cls = PixtralModel 58 | elif "nvila" in lower_model_name: 59 | from .nvila import NVILAModel 60 | model_cls = NVILAModel 61 | elif "llava-onevision" in lower_model_name: # too many tokens for image 62 | from .llava_onevision import LlavaOneVisionModel 63 | model_cls = LlavaOneVisionModel 64 | elif "gemma-3" in lower_model_name: 65 | from .gemma3 import Gemma3VLModel 66 | model_cls = Gemma3VLModel 67 | elif "minicpm" in lower_model_name: # only support one GPU inference 68 | from .minicpm import MiniCPMModel 69 | model_cls = MiniCPMModel 70 | elif "ovis2" in lower_model_name: 71 | from .ovis2 import Ovis2Model 72 | if args.vision_batch_size is not None: 73 | kwargs["vision_batch_size"] = args.vision_batch_size 74 | model_cls = Ovis2Model 75 | elif "qwen2.5" in lower_model_name: 76 | # we already checked the qwen2.5-vl and qwen2-vl above, 
this is for qwen2.5 LLMs
77 |             # Note that TextOnlyModel supports most text-only LLMs, not just Qwen2.5 LLMs
78 |             from .text_only_model import TextOnlyModel
79 |             model_cls = TextOnlyModel
80 |
81 |         elif "mplug-owl3" in lower_model_name:  # too many tokens for images
82 |             from .mplug_owl3 import mPLUGOwl3Model
83 |             model_cls = mPLUGOwl3Model
84 |
85 |     if args.no_torch_compile:
86 |         kwargs["torch_compile"] = False
87 |     if args.no_bf16:
88 |         kwargs["torch_dtype"] = torch.float16
89 |     if args.load_in_8bit:
90 |         kwargs["load_in_8bit"] = True
91 |     if args.rope_theta is not None:
92 |         kwargs["rope_theta"] = args.rope_theta
93 |     if args.use_yarn:
94 |         kwargs["use_yarn"] = True
95 |     if args.offload_state_dict:
96 |         kwargs["offload_state_dict"] = True
97 |     if args.do_prefill:
98 |         kwargs["do_prefill"] = args.do_prefill
99 |     if args.attn_implementation is not None:
100 |         kwargs["attn_implementation"] = args.attn_implementation
101 |     if args.image_resize:
102 |         kwargs["image_resize"] = args.image_resize
103 |     if args.max_image_num is not None:
104 |         kwargs["max_image_num"] = args.max_image_num
105 |     if args.max_image_size is not None:
106 |         kwargs["max_image_size"] = args.max_image_size
107 |     if args.api_sleep is not None:
108 |         kwargs["api_sleep"] = args.api_sleep
109 |     if args.image_detail != "auto":
110 |         kwargs["image_detail"] = args.image_detail
111 |
112 |     model = model_cls(
113 |         args.model_name_or_path,
114 |         temperature=args.temperature,
115 |         top_p=args.top_p,
116 |         max_length=args.input_max_length,
117 |         generation_max_length=args.generation_max_length,
118 |         generation_min_length=args.generation_min_length,
119 |         do_sample=args.do_sample,
120 |         stop_newline=args.stop_newline,
121 |         use_chat_template=args.use_chat_template,
122 |         **kwargs,
123 |     )
124 |
125 |     return model
--------------------------------------------------------------------------------
/vlm_model/internvl_v2pe/internvl/model/internvl_chat/configuration_intern_vit.py:
--------------------------------------------------------------------------------
1 | # --------------------------------------------------------
2 | # InternVL
3 | # Copyright (c) 2023 OpenGVLab
4 | # Licensed under The MIT License [see LICENSE for details]
5 | # --------------------------------------------------------
6 | import os
7 | from typing import Union
8 |
9 | from transformers.configuration_utils import PretrainedConfig
10 | from transformers.utils import logging
11 |
12 | logger = logging.get_logger(__name__)
13 |
14 |
15 | class InternVisionConfig(PretrainedConfig):
16 |     r"""
17 |     This is the configuration class to store the configuration of a [`InternVisionModel`]. It is used to
18 |     instantiate a vision encoder according to the specified arguments, defining the model architecture.
19 |
20 |     Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
21 |     documentation from [`PretrainedConfig`] for more information.
22 |
23 |     Args:
24 |         num_channels (`int`, *optional*, defaults to 3):
25 |             Number of color channels in the input images (e.g., 3 for RGB).
26 |         patch_size (`int`, *optional*, defaults to 14):
27 |             The size (resolution) of each patch.
28 |         image_size (`int`, *optional*, defaults to 224):
29 |             The size (resolution) of each image.
30 |         qkv_bias (`bool`, *optional*, defaults to `False`):
31 |             Whether to add a bias to the queries and values in the self-attention layers.
32 |         hidden_size (`int`, *optional*, defaults to 3200):
33 |             Dimensionality of the encoder layers and the pooler layer.
34 | num_attention_heads (`int`, *optional*, defaults to 25): 35 | Number of attention heads for each attention layer in the Transformer encoder. 36 | intermediate_size (`int`, *optional*, defaults to 12800): 37 | Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder. 38 | qk_normalization (`bool`, *optional*, defaults to `True`): 39 | Whether to normalize the queries and keys in the self-attention layers. 40 | num_hidden_layers (`int`, *optional*, defaults to 48): 41 | Number of hidden layers in the Transformer encoder. 42 | use_flash_attn (`bool`, *optional*, defaults to `True`): 43 | Whether to use flash attention mechanism. 44 | hidden_act (`str` or `function`, *optional*, defaults to `"gelu"`): 45 | The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`, 46 | `"relu"`, `"selu"` and `"gelu_new"` ``"gelu"` are supported. 47 | layer_norm_eps (`float`, *optional*, defaults to 1e-6): 48 | The epsilon used by the layer normalization layers. 49 | dropout (`float`, *optional*, defaults to 0.0): 50 | The dropout probability for all fully connected layers in the embeddings, encoder, and pooler. 51 | drop_path_rate (`float`, *optional*, defaults to 0.0): 52 | Dropout rate for stochastic depth. 53 | attention_dropout (`float`, *optional*, defaults to 0.0): 54 | The dropout ratio for the attention probabilities. 55 | initializer_range (`float`, *optional*, defaults to 0.02): 56 | The standard deviation of the truncated_normal_initializer for initializing all weight matrices. 57 | initializer_factor (`float`, *optional*, defaults to 0.1): 58 | A factor for layer scale. 59 | """ 60 | 61 | model_type = 'intern_vit_6b' 62 | 63 | def __init__( 64 | self, 65 | num_channels=3, 66 | patch_size=14, 67 | image_size=224, 68 | qkv_bias=False, 69 | hidden_size=3200, 70 | num_attention_heads=25, 71 | intermediate_size=12800, 72 | qk_normalization=True, 73 | num_hidden_layers=48, 74 | use_flash_attn=True, 75 | hidden_act='gelu', 76 | norm_type='rms_norm', 77 | layer_norm_eps=1e-6, 78 | dropout=0.0, 79 | drop_path_rate=0.0, 80 | attention_dropout=0.0, 81 | initializer_range=0.02, 82 | initializer_factor=0.1, 83 | **kwargs, 84 | ): 85 | super().__init__(**kwargs) 86 | 87 | self.hidden_size = hidden_size 88 | self.intermediate_size = intermediate_size 89 | self.dropout = dropout 90 | self.drop_path_rate = drop_path_rate 91 | self.num_hidden_layers = num_hidden_layers 92 | self.num_attention_heads = num_attention_heads 93 | self.num_channels = num_channels 94 | self.patch_size = patch_size 95 | self.image_size = image_size 96 | self.initializer_range = initializer_range 97 | self.initializer_factor = initializer_factor 98 | self.attention_dropout = attention_dropout 99 | self.layer_norm_eps = layer_norm_eps 100 | self.hidden_act = hidden_act 101 | self.norm_type = norm_type 102 | self.qkv_bias = qkv_bias 103 | self.qk_normalization = qk_normalization 104 | self.use_flash_attn = use_flash_attn 105 | 106 | @classmethod 107 | def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs) -> 'PretrainedConfig': 108 | config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs) 109 | 110 | if 'vision_config' in config_dict: 111 | config_dict = config_dict['vision_config'] 112 | 113 | if 'model_type' in config_dict and hasattr(cls, 'model_type') and config_dict['model_type'] != cls.model_type: 114 | logger.warning( 115 | f"You are using a model of type {config_dict['model_type']} to 
instantiate a model of type " 116 | f'{cls.model_type}. This is not supported for all configurations of models and can yield errors.' 117 | ) 118 | 119 | return cls.from_dict(config_dict, **kwargs) 120 | -------------------------------------------------------------------------------- /vlm_model/internvl_v2pe/internvl/model/internvl_chat/configuration_internvl_chat.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # InternVL 3 | # Copyright (c) 2023 OpenGVLab 4 | # Licensed under The MIT License [see LICENSE for details] 5 | # -------------------------------------------------------- 6 | 7 | import copy 8 | 9 | from internvl.model.internlm2.configuration_internlm2 import InternLM2Config 10 | from transformers import LlamaConfig, Qwen2Config 11 | from transformers.configuration_utils import PretrainedConfig 12 | from transformers.utils import logging 13 | 14 | from .configuration_intern_vit import InternVisionConfig 15 | 16 | logger = logging.get_logger(__name__) 17 | 18 | 19 | class InternVLChatConfig(PretrainedConfig): 20 | model_type = 'internvl_chat' 21 | is_composition = True 22 | 23 | def __init__( 24 | self, 25 | vision_config=None, 26 | llm_config=None, 27 | use_backbone_lora=0, 28 | use_llm_lora=0, 29 | pad2square=False, 30 | select_layer=-1, 31 | force_image_size=None, 32 | downsample_ratio=0.5, 33 | template=None, 34 | dynamic_image_size=False, 35 | use_thumbnail=False, 36 | ps_version='v1', 37 | dynamic_max_patch=False, 38 | min_dynamic_patch=1, 39 | max_dynamic_patch=6, 40 | min_num_frame=4, 41 | max_num_frame=20, 42 | compress_seq=False, 43 | attn_type=None, 44 | group_list=None, 45 | chunk_num=1, 46 | interaction=True, 47 | rope_pos_id_version='default', 48 | rope_pos_id_stride=None, 49 | img_emb_down_sample_ratio=None, 50 | **kwargs): 51 | super().__init__(**kwargs) 52 | 53 | if vision_config is None: 54 | vision_config = {} 55 | logger.info('vision_config is None. Initializing the InternVisionConfig with default values.') 56 | 57 | if llm_config is None: 58 | llm_config = {} 59 | logger.info('llm_config is None. 
Initializing the LlamaConfig config with default values (`LlamaConfig`).') 60 | 61 | self.vision_config = InternVisionConfig(**vision_config) 62 | if llm_config['architectures'][0] == 'LlamaForCausalLM': 63 | self.llm_config = LlamaConfig(**llm_config) 64 | elif llm_config['architectures'][0] == 'InternLM2ForCausalLM': 65 | self.llm_config = InternLM2Config(**llm_config) 66 | elif llm_config['architectures'][0] == 'Qwen2ForCausalLM': 67 | self.llm_config = Qwen2Config(**llm_config) 68 | else: 69 | raise ValueError('Unsupported architecture: {}'.format(llm_config['architectures'][0])) 70 | 71 | self.use_backbone_lora = use_backbone_lora 72 | self.use_llm_lora = use_llm_lora 73 | self.pad2square = pad2square 74 | self.select_layer = select_layer 75 | self.force_image_size = force_image_size 76 | self.downsample_ratio = downsample_ratio 77 | self.template = template 78 | self.dynamic_image_size = dynamic_image_size 79 | self.use_thumbnail = use_thumbnail 80 | self.ps_version = ps_version # pixel shuffle version 81 | self.min_dynamic_patch = min_dynamic_patch 82 | self.max_dynamic_patch = max_dynamic_patch 83 | self.min_num_frame = min_num_frame 84 | self.max_num_frame = max_num_frame 85 | self.compress_seq = compress_seq 86 | self.attn_type=attn_type 87 | self.group_list = group_list 88 | self.chunk_num = chunk_num 89 | self.interaction = interaction 90 | self.rope_pos_id_version = rope_pos_id_version 91 | self.rope_pos_id_stride = rope_pos_id_stride 92 | self.img_emb_down_sample_ratio = img_emb_down_sample_ratio 93 | self.dynamic_max_patch = dynamic_max_patch 94 | logger.info(f'vision_select_layer: {self.select_layer}') 95 | logger.info(f'ps_version: {self.ps_version}') 96 | logger.info(f'dynamic_max_patch: {self.dynamic_max_patch}') 97 | logger.info(f'min_dynamic_patch: {self.min_dynamic_patch}') 98 | logger.info(f'max_dynamic_patch: {self.max_dynamic_patch}') 99 | logger.info(f'img_emb_down_sample_ratio: {self.img_emb_down_sample_ratio}') 100 | 101 | def to_dict(self): 102 | """ 103 | Serializes this instance to a Python dictionary. Override the default [`~PretrainedConfig.to_dict`]. 
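The nested `vision_config` and `llm_config` sub-configurations are themselves serialized through their own `to_dict()` methods before being placed in the output dictionary.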
104 | 105 | Returns: 106 | `Dict[str, any]`: Dictionary of all the attributes that make up this configuration instance, 107 | """ 108 | output = copy.deepcopy(self.__dict__) 109 | output['vision_config'] = self.vision_config.to_dict() 110 | output['llm_config'] = self.llm_config.to_dict() 111 | output['model_type'] = self.__class__.model_type 112 | output['use_backbone_lora'] = self.use_backbone_lora 113 | output['use_llm_lora'] = self.use_llm_lora 114 | output['pad2square'] = self.pad2square 115 | output['select_layer'] = self.select_layer 116 | output['force_image_size'] = self.force_image_size 117 | output['downsample_ratio'] = self.downsample_ratio 118 | output['template'] = self.template 119 | output['dynamic_image_size'] = self.dynamic_image_size 120 | output['use_thumbnail'] = self.use_thumbnail 121 | output['ps_version'] = self.ps_version 122 | output['min_dynamic_patch'] = self.min_dynamic_patch 123 | output['max_dynamic_patch'] = self.max_dynamic_patch 124 | output['dynamic_max_patch'] = self.dynamic_max_patch 125 | output['rope_pos_id_version'] = self.rope_pos_id_version 126 | output['rope_pos_id_stride'] = self.rope_pos_id_stride 127 | output['img_emb_down_sample_ratio'] = self.img_emb_down_sample_ratio 128 | output['min_num_frame'] = self.min_num_frame 129 | output['max_num_frame'] = self.max_num_frame 130 | 131 | return output 132 | -------------------------------------------------------------------------------- /vlm_model/internvl_v2pe/internvl/patch/pad_data_collator.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | 4 | IGNORE_INDEX = -100 5 | 6 | 7 | def pad_data_collator(features, pad_id=0): 8 | 9 | first = features[0] 10 | batch = {} 11 | 12 | batch_lens = [feat['input_ids'].shape for feat in features] 13 | max_item_length = max(batch_lens)[0] 14 | for idx in range(len(features)): 15 | feat = features[idx] 16 | temp_input_ids = torch.LongTensor([pad_id] * max_item_length) 17 | temp_input_ids[:feat['input_ids'].shape[0]] = feat['input_ids'] 18 | feat['input_ids'] = temp_input_ids 19 | temp_labels = torch.LongTensor([IGNORE_INDEX] * max_item_length) 20 | temp_labels[:feat['labels'].shape[0]] = feat['labels'] 21 | feat['labels'] = temp_labels 22 | feat['attention_mask'] = feat['input_ids'].ne(pad_id) 23 | 24 | # Special handling for labels. 25 | # Ensure that tensor is created with the correct type 26 | # (it should be automatically the case, but let's make sure of it.) 27 | if 'label' in first and first['label'] is not None: 28 | label = first['label'].item() if isinstance(first['label'], torch.Tensor) else first['label'] 29 | dtype = torch.long if isinstance(label, int) else torch.float 30 | batch['labels'] = torch.tensor([f['label'] for f in features], dtype=dtype) 31 | elif 'label_ids' in first and first['label_ids'] is not None: 32 | if isinstance(first['label_ids'], torch.Tensor): 33 | batch['labels'] = torch.stack([f['label_ids'] for f in features]) 34 | else: 35 | dtype = torch.long if isinstance(first['label_ids'][0], int) else torch.float 36 | batch['labels'] = torch.tensor([f['label_ids'] for f in features], dtype=dtype) 37 | 38 | # Handling of all other possible keys. 39 | # Again, we will use the first element to figure out which key/values are not None for this model. 
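# Tensor values are stacked along a new batch dimension; numpy arrays are stacked with np.stack and converted via torch.tensor; any other non-string value is collected into a new tensor. String-valued or None-valued keys are skipped.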
40 | for k, v in first.items(): 41 | if k not in ('label', 'label_ids') and v is not None and not isinstance(v, str): 42 | if isinstance(v, torch.Tensor): 43 | batch[k] = torch.stack([f[k] for f in features]) 44 | elif isinstance(v, np.ndarray): 45 | batch[k] = torch.tensor(np.stack([f[k] for f in features])) 46 | else: 47 | batch[k] = torch.tensor([f[k] for f in features]) 48 | return batch 49 | 50 | 51 | def concat_pad_data_collator(features, max_item_length=None, pad_id=0): 52 | 53 | first = features[0] 54 | batch = {} 55 | batch_lens = [feat['input_ids'].shape for feat in features] 56 | max_item_length = max_item_length or max(batch_lens)[0] 57 | for idx in range(len(features)): 58 | feat = features[idx] 59 | temp_input_ids = torch.LongTensor([pad_id] * max_item_length) 60 | temp_input_ids[:feat['input_ids'].shape[0]] = feat['input_ids'] 61 | feat['input_ids'] = temp_input_ids 62 | temp_labels = torch.LongTensor([IGNORE_INDEX] * max_item_length) 63 | temp_labels[:feat['labels'].shape[0]] = feat['labels'] 64 | feat['labels'] = temp_labels 65 | feat['attention_mask'] = feat['input_ids'].ne(pad_id) 66 | 67 | if 'position_ids' in feat: 68 | 69 | if isinstance(feat['position_ids'],list): 70 | temp_position_ids = [pad_id] * max_item_length 71 | temp_position_ids[:len(feat['position_ids'])] = feat['position_ids'] 72 | 73 | else: 74 | temp_position_ids = [pad_id] * max_item_length 75 | temp_position_ids[:feat['position_ids'].shape[0]] = feat['position_ids'] 76 | feat['position_ids'] = temp_position_ids 77 | 78 | if 'loss_weight' in feat: 79 | temp_loss_weight = torch.FloatTensor([pad_id] * max_item_length) 80 | temp_loss_weight[:feat['loss_weight'].shape[0]] = feat['loss_weight'] 81 | feat['loss_weight'] = temp_loss_weight 82 | 83 | # Special handling for labels. 84 | # Ensure that tensor is created with the correct type 85 | # (it should be automatically the case, but let's make sure of it.) 86 | if 'label' in first and first['label'] is not None: 87 | label = first['label'].item() if isinstance(first['label'], torch.Tensor) else first['label'] 88 | dtype = torch.long if isinstance(label, int) else torch.float 89 | batch['labels'] = torch.tensor([f['label'] for f in features], dtype=dtype) 90 | elif 'label_ids' in first and first['label_ids'] is not None: 91 | if isinstance(first['label_ids'], torch.Tensor): 92 | batch['labels'] = torch.stack([f['label_ids'] for f in features]) 93 | else: 94 | dtype = torch.long if isinstance(first['label_ids'][0], int) else torch.float 95 | batch['labels'] = torch.tensor([f['label_ids'] for f in features], dtype=dtype) 96 | 97 | # Handling of all other possible keys. 98 | # Again, we will use the first element to figure out which key/values are not None for this model. 
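# Same per-key handling as pad_data_collator above, except that 'pixel_values' and 'image_flags' are concatenated along dim 0 (each sample can contribute a different number of image patches) and 'position_ids' is converted back to a Python list at the end.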
99 | for k, v in first.items(): 100 | if k not in ('label', 'label_ids', 'pixel_values', 'image_flags') and \ 101 | v is not None and not isinstance(v, str): 102 | if isinstance(v, torch.Tensor): 103 | batch[k] = torch.stack([f[k] for f in features]) 104 | elif isinstance(v, np.ndarray): 105 | batch[k] = torch.tensor(np.stack([f[k] for f in features])) 106 | else: 107 | batch[k] = torch.tensor([f[k] for f in features]) 108 | if k in ('pixel_values', 'image_flags'): 109 | if isinstance(v, torch.Tensor): 110 | batch[k] = torch.concat([f[k] for f in features]) 111 | elif isinstance(v, np.ndarray): 112 | batch[k] = torch.concat(np.stack([f[k] for f in features])) 113 | else: 114 | batch[k] = torch.concat([f[k] for f in features]) 115 | if k=='position_ids': 116 | batch[k]=list(batch[k].numpy()) 117 | return batch 118 | -------------------------------------------------------------------------------- /figure_scripts/15_rag_modal_difficulty.py: -------------------------------------------------------------------------------- 1 | from utils import * 2 | import matplotlib 3 | 4 | # 5 | config_files = ["configs/text_rag_all.yaml"] 6 | 7 | config_files = [os.path.join(project_root, config) for config in config_files] 8 | 9 | print(os.getcwd()) 10 | dataset_configs = [] 11 | for file in config_files: 12 | c = yaml.safe_load(open(file)) 13 | 14 | if isinstance(c["generation_max_length"], int): 15 | c["generation_max_length"] = ",".join([str(c["generation_max_length"])] * len(c["datasets"].split(","))) 16 | for d, t, l, g in zip( 17 | c['datasets'].split(','), c['test_files'].split(','), c['input_max_length'].split(','), 18 | c['generation_max_length'].split(',')): 19 | dataset_configs.append( 20 | {"dataset": d, "test_name": os.path.basename(os.path.splitext(t)[0]), "input_max_length": int(l), 21 | "generation_max_length": int(g), "max_test_samples": c['max_test_samples'], 22 | 'use_chat_template': c['use_chat_template']}) 23 | print(dataset_configs) 24 | 25 | models_configs = [ 26 | {"model": "Qwen2.5-VL-3B-Instruct", "use_chat_template": True, "training_length": 32768}, 27 | {"model": "Qwen2.5-VL-7B-Instruct", "use_chat_template": True, "training_length": 32768}, 28 | {"model": 'Qwen2.5-VL-32B-Instruct', "use_chat_template": True, "training_length": 32768}, 29 | 30 | {"model": "Qwen2.5-3B-Instruct_yarn", "use_chat_template": True, "training_length": 131072}, 31 | {"model": "Qwen2.5-7B-Instruct_yarn", "use_chat_template": True, "training_length": 131072}, 32 | {"model": 'Qwen2.5-32B-Instruct_yarn', "use_chat_template": True, "training_length": 131072}, 33 | 34 | {"model": "gemma-3-4b-it", "use_chat_template": True, "training_length": 131072}, 35 | {"model": "gemma-3-12b-it", "use_chat_template": True, "training_length": 131072}, 36 | {"model": "gemma-3-27b-it", "use_chat_template": True, "training_length": 131072}, 37 | ] 38 | result_dir = "/home/zhaowei.wang/data_dir/mmlb_result_backup/analysis_models" 39 | for model in models_configs: 40 | model["output_dir"] = f"{result_dir}/{model['model']}" 41 | 42 | new_dfs = [] 43 | dataset_to_metrics["triviaqa"] = ["sub_em"] 44 | 45 | for m_idx, model in enumerate(models_configs): 46 | args = arguments() 47 | for dataset in dataset_configs: 48 | args.update(dataset) 49 | args.update(model) 50 | 51 | # parse the metrics 52 | metric = args.get_averaged_metric() 53 | dsimple, mnames = args.get_metric_name() 54 | 55 | if metric is None: 56 | print("failed:", args.get_path()) # will be np.nan when using DataFrame 57 | continue 58 | for k, m in metric.items(): 59 | 
new_dfs.append({**asdict(args), **model, 60 | "metric name": k, "metric": m, 61 | "dataset_simple": dsimple + " " + k, 62 | "test_data": f"{args.dataset}-{args.test_name}-{args.input_max_length}" 63 | }) 64 | 65 | new_df = pd.DataFrame(new_dfs) 66 | 67 | import utils 68 | chosen_models = [mc["model"] for mc in models_configs] 69 | utils.custom_avgs = {} 70 | triviaqa_df = process_df(new_df, chosen_models=chosen_models) 71 | viquae_df = process_df(all_df, chosen_models=[m for m in chosen_models if "yarn" not in m]) 72 | 73 | # align the df(s) 74 | triviaqa_df.rename(columns={'triviaqa sub_em': 'score'}, inplace=True) 75 | triviaqa_df['Model'] = triviaqa_df['Model'] + '_triviaqa' 76 | viquae_df = viquae_df[['Model', 'input_max_length', 'ViQuAE']] 77 | viquae_df.rename(columns={'ViQuAE': 'score'}, inplace=True) 78 | lf_df = pd.concat([triviaqa_df, viquae_df]) 79 | lf_df.reset_index(drop=True, inplace=True) 80 | 81 | # plot specific ones in a row, for formatting in the paper 82 | length_datasets = ['score'] 83 | 84 | ncols = 2 85 | nrows = (len(length_datasets) - 1) // ncols + 1 86 | 87 | custom_cmap = LinearSegmentedColormap.from_list('custom_cmap', ['#FFFFFF', '#4A895B']) 88 | fig, ax = plt.subplots(ncols=ncols, nrows=nrows, sharey=True, sharex=False) 89 | fig.set_size_inches((ncols * 6, nrows * 5)) 90 | plt.rc('axes', unicode_minus=False) 91 | plt.rcParams.update({'axes.unicode_minus': False}) 92 | 93 | base_index_order = [ 94 | 'Qwen2.5-VL-7B-Inst', 'Qwen2.5-VL-7B-Inst_triviaqa', 'Qwen2.5-7B-Instruct_yarn_triviaqa', 95 | 'Qwen2.5-VL-32B-Inst', 'Qwen2.5-VL-32B-Inst_triviaqa', 'Qwen2.5-32B-Instruct_yarn_triviaqa', 96 | "Gemma3-27B", "Gemma3-27B_triviaqa", 97 | ] # "Gemma3-12B", "Gemma3-12B_triviaqa", 98 | 99 | for i, dataset in enumerate(length_datasets): 100 | if nrows > 1: 101 | a = ax[i // ncols][i % ncols] 102 | elif ncols > 1: 103 | a = ax[i] 104 | else: 105 | a = ax 106 | 107 | tdf = lf_df[lf_df.input_max_length > 4096] 108 | tdf = tdf.pivot_table(index="Model", columns="input_max_length", values=dataset) 109 | tdf = tdf.reindex(base_index_order) 110 | 111 | sns_g = sns.heatmap( 112 | tdf, annot=True, cmap=custom_cmap, fmt=".1f", yticklabels=True, 113 | ax=a, annot_kws={"fontsize": 22}, 114 | cbar=False 115 | ) 116 | sns_g.set_title("ViQuAE", fontsize=22) 117 | 118 | sns_g.set_ylabel("") 119 | sns_g.set_xlabel("") 120 | # "Gemma3-12B", '$\\diamond$ w/ name', 121 | written_index = ['Qwen2.5-VL-7B', '$\\diamond$ w/ name', '$\\diamond$ w/ LLM', 122 | 'Qwen2.5-VL-32B', '$\\diamond$ w/ name', '$\\diamond$ w/ LLM', 123 | "Gemma3-27B", '$\\diamond$ w/ name',] 124 | sns_g.set_yticklabels(written_index, size=18, fontweight='bold') 125 | xticks_map = {"8192": '8k', "16384": '16k', "32768": '32k', "65536":'64k', "131072":'128k'} 126 | sns_g.set_xticklabels([xticks_map[st.get_text()] for st in sns_g.get_xticklabels()], size=22) 127 | 128 | # idx, start, end 129 | a.hlines([3, 6, 8], 0, 6, color="0.95", linestyle="-", linewidth=3) 130 | 131 | [fig.delaxes(a) for a in ax.flatten() if not a.has_data()] 132 | 133 | plt.tight_layout() 134 | file_path = os.path.join(project_root, f"figures/15_text_rag_difficulty.pdf") 135 | plt.savefig(file_path, dpi=500, format="pdf") 136 | plt.show() -------------------------------------------------------------------------------- /vlm_model/model_utils.py: -------------------------------------------------------------------------------- 1 | import re 2 | import io 3 | import time 4 | import base64 5 | from PIL import Image 6 | 7 | import logging 8 | 
logging.basicConfig(format='%(asctime)s - %(levelname)s - %(name)s - %(message)s', 9 | datefmt='%m/%d/%Y %H:%M:%S') 10 | logger = logging.getLogger(__name__) 11 | logger.setLevel(logging.INFO) 12 | 13 | 14 | def resize_image(image_list, image_resize): 15 | new_image_list = [] 16 | for img in image_list: 17 | width, height = img.size 18 | new_width = int(width * image_resize) 19 | new_height = int(height * image_resize) 20 | img = img.resize((new_width, new_height), Image.Resampling.LANCZOS) 21 | 22 | new_image_list.append(img) 23 | return new_image_list 24 | 25 | 26 | def resize_image_max_size(image_list, max_image_size): 27 | new_image_list = [] 28 | for img in image_list: 29 | width, height = img.size 30 | if width <= max_image_size and height <= max_image_size: 31 | new_image_list.append(img) 32 | continue 33 | 34 | if width > height: 35 | new_width = max_image_size 36 | new_height = min(int(max_image_size / width * height), max_image_size) 37 | else: 38 | new_height = max_image_size 39 | new_width = min(int(max_image_size / height * width), max_image_size) 40 | 41 | img = img.resize((new_width, new_height), Image.Resampling.LANCZOS) 42 | 43 | new_image_list.append(img) 44 | 45 | return new_image_list 46 | 47 | 48 | def image_to_io(image: Image.Image, format: str = 'PNG') -> io.BytesIO: 49 | img_io = io.BytesIO() 50 | image.save(img_io, format=format) 51 | img_io.seek(0) 52 | return img_io 53 | 54 | 55 | def encode_image_base64(pil_image, format="PNG"): 56 | buffer = io.BytesIO() 57 | pil_image.save(buffer, format=format) 58 | img_str = base64.b64encode(buffer.getvalue()).decode("utf-8") 59 | return img_str 60 | 61 | 62 | def truncate_images(text, image_list, max_image_num=None): 63 | """ 64 | keep the last max_image_num images in the example. Truncate image_list and remove beginning marker in text. 
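For example, with five images and max_image_num=3, the first two image markers in the text are replaced with empty strings and only the last three images are returned.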
65 | 66 | Args: 67 | text (str): query with `<image>` placeholders 68 | image_list (list): list of image path (or PIL.Image) 69 | max_image_num (int, optional): Max number of kept images 70 | 71 | Returns: 72 | tuple: revised text and image_list 73 | """ 74 | if max_image_num is None or len(image_list) <= max_image_num: 75 | return text, image_list 76 | 77 | segments = re.split(r'(<image>)', text) 78 | 79 | # compute the number of images to remove 80 | keep_count = max_image_num 81 | remove_count = len(image_list) - keep_count 82 | 83 | # count the image markers 84 | image_tags_count = segments.count('<image>') 85 | 86 | # sanity check 87 | assert image_tags_count == len(image_list), f"Warning: Number of <image> tags ({image_tags_count}) doesn't match image_list length ({len(image_list)})" 88 | 89 | # build new text 90 | new_segments = [] 91 | removed = 0 92 | for segment in segments: 93 | if segment == '<image>' and removed < remove_count: 94 | # replace the marker with "" 95 | new_segments.append('') 96 | removed += 1 97 | else: 98 | new_segments.append(segment) 99 | 100 | # join all segments 101 | new_text = ''.join(new_segments) 102 | 103 | # only keep the last keep_count images 104 | new_image_list = image_list[-keep_count:] 105 | 106 | return new_text, new_image_list 107 | 108 | 109 | def format_chat(text, image_list, system_prompt): 110 | content = re.split(r'(<image>)', text) 111 | image_idx, new_content = 0, [] 112 | for c in content: 113 | if c == "<image>": 114 | new_content.append({ 115 | "type": "image", 116 | "image": image_list[image_idx] 117 | }) 118 | image_idx += 1 119 | else: 120 | new_content.append({ 121 | "type": "text", 122 | "text": c 123 | }) 124 | assert image_idx == len(image_list) 125 | messages = [{"role": "user", "content": new_content}, 126 | {"role": "assistant", "content": system_prompt}] 127 | return messages 128 | 129 | 130 | def call_api(func, limit: int=5, pause: int=10): 131 | """ 132 | Call the API function with retries and rate limit handling. 133 | TODO: more error handling?
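Example (illustrative): output = call_api(lambda: client.generate(prompt), limit=5, pause=10), where `client.generate` stands for any API call wrapped in a zero-argument callable; transient failures such as rate limits are logged and the call is retried after sleeping `pause` seconds.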
134 | """ 135 | count = 0 136 | while True: 137 | try: 138 | output = func() 139 | break 140 | except Exception as e: 141 | logger.info(f"Exception while using api: {e}") 142 | msg = str(e).lower() 143 | 144 | if "rate limit" in msg or "rate_limit" in msg or "quota" in msg or "429" in msg or ("overloaded" in msg and count >= limit): 145 | logger.info(f"Rate limit exceeded, waiting {pause} secs and retrying...") 146 | count += 1 147 | if count < limit: 148 | logger.info(f"Encountered error {e}, retrying...") 149 | time.sleep(pause) 150 | else: 151 | logger.info("Skipping generation due to unknown error") 152 | raise e 153 | return output 154 | 155 | 156 | class LLM: 157 | def __init__( 158 | self, 159 | model_name, 160 | temperature=0.9, 161 | top_p=0.9, 162 | max_length=32768, 163 | generation_max_length=2048, 164 | generation_min_length=0, 165 | do_sample=True, 166 | stop_newline=False, 167 | use_chat_template=False, 168 | ): 169 | self.model_name = model_name 170 | self.temperature = temperature 171 | self.top_p = top_p 172 | self.max_length = max_length 173 | self.generation_max_length = generation_max_length 174 | self.generation_min_length = generation_min_length 175 | self.do_sample = do_sample 176 | self.use_chat_template = use_chat_template 177 | self.stops = None 178 | if stop_newline: 179 | self.stops = ["\n", "\n\n"] 180 | 181 | def prepare_inputs(self, test_item, data): 182 | raise NotImplementedError("prepare_inputs not implemented for LLM") 183 | 184 | def generate(self, inputs=None, prompt=None, **kwargs): 185 | raise NotImplementedError("generate not implemented for LLM") 186 | -------------------------------------------------------------------------------- /figure_scripts/15_docqa_modal_difficulty.py: -------------------------------------------------------------------------------- 1 | from utils import * 2 | import matplotlib 3 | 4 | # 5 | config_files = ["configs/text_docqa_all.yaml", "configs/docqa_all.yaml"] 6 | 7 | config_files = [os.path.join(project_root, config) for config in config_files] 8 | 9 | print(os.getcwd()) 10 | dataset_configs = [] 11 | for file in config_files: 12 | c = yaml.safe_load(open(file)) 13 | 14 | if isinstance(c["generation_max_length"], int): 15 | c["generation_max_length"] = ",".join([str(c["generation_max_length"])] * len(c["datasets"].split(","))) 16 | for d, t, l, g in zip( 17 | c['datasets'].split(','), c['test_files'].split(','), c['input_max_length'].split(','), 18 | c['generation_max_length'].split(',')): 19 | if "mmlongdoc" in t: 20 | dataset_configs.append( 21 | {"dataset": d, "test_name": os.path.basename(os.path.splitext(t)[0]), "input_max_length": int(l), 22 | "generation_max_length": int(g), "max_test_samples": c['max_test_samples'], 23 | 'use_chat_template': c['use_chat_template']}) 24 | print(dataset_configs) 25 | 26 | models_configs = [ 27 | {"model": "Qwen2.5-VL-3B-Instruct", "use_chat_template": True, "training_length": 32768}, 28 | {"model": "Qwen2.5-VL-7B-Instruct", "use_chat_template": True, "training_length": 32768}, 29 | {"model": 'Qwen2.5-VL-32B-Instruct', "use_chat_template": True, "training_length": 32768}, 30 | 31 | {"model": "Qwen2.5-3B-Instruct_yarn", "use_chat_template": True, "training_length": 131072}, 32 | {"model": "Qwen2.5-7B-Instruct_yarn", "use_chat_template": True, "training_length": 131072}, 33 | {"model": 'Qwen2.5-32B-Instruct_yarn', "use_chat_template": True, "training_length": 131072}, 34 | 35 | {"model": "gemma-3-4b-it", "use_chat_template": True, "training_length": 131072}, 36 | {"model": 
"gemma-3-12b-it", "use_chat_template": True, "training_length": 131072}, 37 | {"model": "gemma-3-27b-it", "use_chat_template": True, "training_length": 131072}, 38 | ] 39 | result_dir = "/home/zhaowei.wang/data_dir/mmlb_result_backup/analysis_models" 40 | for model in models_configs: 41 | model["output_dir"] = f"{result_dir}/{model['model']}" 42 | 43 | new_dfs = [] 44 | metrics = ["doc_qa", "text_score", "mm_score"] 45 | dataset_to_metrics["text_doc"] = metrics 46 | dataset_to_metrics["mmlongdoc"] = metrics 47 | 48 | for m_idx, model in enumerate(models_configs): 49 | args = arguments() 50 | for dataset in dataset_configs: 51 | args.update(dataset) 52 | args.update(model) 53 | 54 | # parse the metrics 55 | metric = args.get_averaged_metric() 56 | dsimple, mnames = args.get_metric_name() 57 | 58 | if metric is None: 59 | print("failed:", args.get_path()) # will be np.nan when using DataFrame 60 | continue 61 | for k, m in metric.items(): 62 | new_dfs.append({**asdict(args), **model, 63 | "metric name": k, "metric": m, 64 | "dataset_simple": dsimple + " " + k, 65 | "test_data": f"{args.dataset}-{args.test_name}-{args.input_max_length}" 66 | }) 67 | 68 | new_df = pd.DataFrame(new_dfs) 69 | 70 | import utils 71 | chosen_models = [mc["model"] for mc in models_configs] 72 | utils.custom_avgs = {} 73 | utils.dataset_name_replace = {} 74 | all_doc_df = process_df(new_df, chosen_models=chosen_models) 75 | 76 | 77 | text_doc_metric_list = [f"text_doc {m}" for m in metrics] 78 | text_doc_df = all_doc_df[['model', 'input_max_length'] + text_doc_metric_list].copy() 79 | text_doc_df['model'] = text_doc_df['model'] + '_text' 80 | for curr_metric in metrics: 81 | text_doc_df.rename(columns={f'text_doc {curr_metric}': f'mmlongdoc {curr_metric}'}, inplace=True) 82 | 83 | doc_metric_list = [f"mmlongdoc {m}" for m in metrics] 84 | doc_df = all_doc_df[['model', 'input_max_length'] + doc_metric_list].copy() 85 | doc_df = doc_df[~doc_df['model'].str.contains('yarn')] 86 | 87 | lf_df = pd.concat([text_doc_df, doc_df]) 88 | lf_df.reset_index(drop=True, inplace=True) 89 | 90 | # plot specific ones in a row, for formatting in the paper 91 | length_datasets = doc_metric_list 92 | title_list = ["MMLB-Doc (All)", "Text-Pure Cases", "Vision-Needed Cases"] # "Layout/Table/Chart/Image"] 93 | 94 | ncols = 3 95 | nrows = (len(length_datasets) - 1) // ncols + 1 96 | 97 | custom_cmap = LinearSegmentedColormap.from_list('custom_cmap', ['#FFFFFF', '#4A895B']) 98 | fig, ax = plt.subplots(ncols=ncols, nrows=nrows, sharey=True, sharex=False) 99 | fig.set_size_inches((20, nrows * 5)) 100 | plt.rc('axes', unicode_minus=False) 101 | plt.rcParams.update({'axes.unicode_minus': False}) 102 | 103 | base_index_order = [ 104 | 'Qwen2.5-VL-7B-Inst', 'Qwen2.5-VL-7B-Inst_text', 'Qwen2.5-7B-Instruct_yarn_text', 105 | 'Qwen2.5-VL-32B-Inst', 'Qwen2.5-VL-32B-Inst_text', 'Qwen2.5-32B-Instruct_yarn_text', 106 | "Gemma3-27B", "Gemma3-27B_text", 107 | ] # "Gemma3-12B", "Gemma3-12B_triviaqa", 108 | 109 | for i, dataset in enumerate(length_datasets): 110 | if nrows > 1: 111 | a = ax[i // ncols][i % ncols] 112 | elif ncols > 1: 113 | a = ax[i] 114 | else: 115 | a = ax 116 | 117 | tdf = lf_df[lf_df.input_max_length > 4096] 118 | tdf = tdf.pivot_table(index="model", columns="input_max_length", values=dataset) 119 | tdf = tdf.reindex(base_index_order) 120 | 121 | sns_g = sns.heatmap( 122 | tdf, annot=True, cmap=custom_cmap, fmt=".1f", yticklabels=True, 123 | ax=a, annot_kws={"fontsize": 22}, 124 | cbar=False 125 | ) 126 | sns_g.set_title(title_list[i], 
fontsize=28) 127 | 128 | sns_g.set_ylabel("") 129 | sns_g.set_xlabel("") 130 | # "Gemma3-12B", '$\\diamond$ w/ name', 131 | written_index = ['Qwen2.5-VL-7B', '$\\diamond$ w/ OCR', '$\\diamond$ w/ LLM', 132 | 'Qwen2.5-VL-32B', '$\\diamond$ w/ OCR', '$\\diamond$ w/ LLM', 133 | "Gemma3-27B", '$\\diamond$ w/ OCR',] 134 | sns_g.set_yticklabels(written_index, size=26) 135 | xticks_map = {"8192": '8k', "16384": '16k', "32768": '32k', "65536":'64k', "131072":'128k'} 136 | sns_g.set_xticklabels([xticks_map[st.get_text()] for st in sns_g.get_xticklabels()], size=26) 137 | 138 | # idx, start, end 139 | a.hlines([3, 6, 8], 0, 6, color="0.95", linestyle="-", linewidth=3) 140 | 141 | [fig.delaxes(a) for a in ax.flatten() if not a.has_data()] 142 | 143 | plt.tight_layout() 144 | plt.subplots_adjust(wspace=0.20) 145 | file_path = os.path.join(project_root, f"figures/15_text_doc_difficulty.pdf") 146 | plt.savefig(file_path, dpi=500, format="pdf") 147 | plt.show() -------------------------------------------------------------------------------- /vlm_model/internvl_v2pe/internvl/patch/internlm2_packed_training_patch.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from typing import List, Optional, Tuple, Union 3 | from einops import rearrange 4 | from flash_attn.flash_attn_interface import flash_attn_varlen_func 5 | import internvl.model 6 | import transformers 7 | import torch.distributed as dist 8 | from internvl.model.internlm2.modeling_internlm2 import ( 9 | INTERNLM2_ATTENTION_CLASSES, InternLM2FlashAttention2, 10 | apply_rotary_pos_emb) 11 | from internvl.model.internlm2.configuration_internlm2 import InternLM2Config 12 | from torch import nn 13 | # from ring_flash_attn.zigzag_ring_flash_attn import zigzag_ring_flash_attn_func 14 | from ring_flash_attn.zigzag_ring_flash_attn_varlen import zigzag_ring_flash_attn_varlen_func 15 | 16 | 17 | 18 | # Modified from internvl.model.internlm2.modeling_internlm2.InternLM2FlashAttention2 19 | class InternLM2FlashAttention2ForPackedTraining(InternLM2FlashAttention2): 20 | 21 | def _flash_attention_forward( 22 | self, query_states, key_states, value_states, attention_mask, query_length, dropout=0.0, softmax_scale=None 23 | ): 24 | """ 25 | Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token 26 | first unpad the input, then computes the attention scores and pad the final attention scores. 27 | 28 | Args: 29 | query_states (`torch.Tensor`): 30 | Input query states to be passed to Flash Attention API 31 | key_states (`torch.Tensor`): 32 | Input key states to be passed to Flash Attention API 33 | value_states (`torch.Tensor`): 34 | Input value states to be passed to Flash Attention API 35 | attention_mask (`torch.Tensor`): 36 | rename from cu_seqlens to keep compatability - (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths 37 | of the sequences in the batch. 38 | dropout (`int`, *optional*): 39 | Attention dropout 40 | softmax_scale (`float`, *optional*): 41 | The scaling of QK^T before applying softmax. 
Default to 1 / sqrt(head_dim) 42 | """ 43 | assert query_states.size(0) == key_states.size(0) == value_states.size(0) == 1 44 | query_states = query_states.squeeze(0) 45 | key_states = key_states.squeeze(0) 46 | value_states = value_states.squeeze(0) 47 | cu_seqlens = attention_mask.squeeze(0) 48 | with torch.no_grad(): 49 | max_seqlen = max([ 50 | cu_seqlens[idx+1] - cu_seqlens[idx] 51 | for idx in range(cu_seqlens.size(0) - 1) 52 | ]).item() 53 | 54 | # Contains at least one padding token in the sequence 55 | causal = self.is_causal and query_length != 1 56 | attn_output = flash_attn_varlen_func( 57 | q=query_states, 58 | k=key_states, 59 | v=value_states, 60 | cu_seqlens_q=cu_seqlens, 61 | cu_seqlens_k=cu_seqlens, 62 | max_seqlen_q=max_seqlen, 63 | max_seqlen_k=max_seqlen, 64 | dropout_p=dropout, 65 | softmax_scale=softmax_scale, 66 | causal=causal, 67 | ) 68 | if torch.isnan(attn_output).any(): 69 | print("causal", causal) 70 | input() 71 | raise ValueError('Attention output contains NaN values') 72 | query_states = query_states.unsqueeze(0) 73 | key_states = key_states.unsqueeze(0) 74 | value_states = value_states.unsqueeze(0) 75 | return attn_output 76 | class InternLM2RingAttention2ForPackedTraining(InternLM2FlashAttention2): 77 | def _flash_attention_forward( 78 | self, query_states, key_states, value_states, attention_mask, query_length, dropout=0.0, softmax_scale=None,group=None 79 | ): 80 | """ 81 | Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token 82 | first unpad the input, then computes the attention scores and pad the final attention scores. 83 | 84 | Args: 85 | query_states (`torch.Tensor`): 86 | Input query states to be passed to Flash Attention API 87 | key_states (`torch.Tensor`): 88 | Input key states to be passed to Flash Attention API 89 | value_states (`torch.Tensor`): 90 | Input value states to be passed to Flash Attention API 91 | attention_mask (`torch.Tensor`): 92 | rename from cu_seqlens to keep compatability - (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths 93 | of the sequences in the batch. 94 | dropout (`int`, *optional*): 95 | Attention dropout 96 | softmax_scale (`float`, *optional*): 97 | The scaling of QK^T before applying softmax. 
Default to 1 / sqrt(head_dim) 98 | """ 99 | assert query_states.size(0) == key_states.size(0) == value_states.size(0) == 1 100 | query_states = query_states.squeeze(0) 101 | key_states = key_states.squeeze(0) 102 | value_states = value_states.squeeze(0) 103 | cu_seqlens = attention_mask.squeeze(0) 104 | with torch.no_grad(): 105 | max_seqlen = max([ 106 | cu_seqlens[idx+1] - cu_seqlens[idx] 107 | for idx in range(cu_seqlens.size(0) - 1) 108 | ]).item() 109 | # Contains at least one padding token in the sequence 110 | causal = self.is_causal and query_length != 1 111 | attn_output = zigzag_ring_flash_attn_varlen_func( 112 | q=query_states, 113 | k=key_states, 114 | v=value_states, 115 | cu_seqlens=cu_seqlens, 116 | max_seqlen=max_seqlen, 117 | dropout_p=dropout, 118 | softmax_scale=softmax_scale, 119 | causal=causal, 120 | group=group 121 | ) 122 | if torch.isnan(attn_output).any(): 123 | print("causal", causal) 124 | raise ValueError('Attention output contains NaN values') 125 | query_states = query_states.unsqueeze(0) 126 | key_states = key_states.unsqueeze(0) 127 | value_states = value_states.unsqueeze(0) 128 | return attn_output 129 | 130 | 131 | def replace_internlm2_attention_class(attn_type='packed'): 132 | 133 | if attn_type=='packed': 134 | INTERNLM2_ATTENTION_CLASSES['flash_attention_2'] = InternLM2FlashAttention2ForPackedTraining 135 | elif attn_type=='ring': 136 | print("replacing to ring attn") 137 | INTERNLM2_ATTENTION_CLASSES['flash_attention_2'] = InternLM2RingAttention2ForPackedTraining 138 | else: 139 | raise NotImplementedError() 140 | print('Replace INTERNLM2_ATTENTION_CLASSES to support packed training!!') 141 | 142 | 143 | 144 | 145 | -------------------------------------------------------------------------------- /vlm_model/nvila_2b_ef8fa9c8/conversation.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 NVIDIA CORPORATION & AFFILIATES 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # 15 | # SPDX-License-Identifier: Apache-2.0 16 | # This file is modified from https://github.com/haotian-liu/LLaVA/ 17 | 18 | import dataclasses 19 | from enum import Enum, auto 20 | from typing import List 21 | 22 | # from llava.utils.logging import logger 23 | 24 | 25 | class SeparatorStyle(Enum): 26 | """Different separator style.""" 27 | 28 | AUTO = auto() 29 | TWO = auto() 30 | MPT = auto() 31 | PLAIN = auto() 32 | LLAMA_3 = auto() 33 | 34 | 35 | @dataclasses.dataclass 36 | class Conversation: 37 | """A class that keeps all conversation history.""" 38 | 39 | system: str 40 | roles: List[str] 41 | messages: List[List[str]] 42 | sep_style: SeparatorStyle = SeparatorStyle.AUTO 43 | sep: str = "###" 44 | sep2: str = None 45 | version: str = "Unknown" 46 | 47 | def get_prompt(self): 48 | messages = self.messages 49 | if len(messages) > 0 and type(messages[0][1]) is tuple: 50 | messages = self.messages.copy() 51 | init_role, init_msg = messages[0].copy() 52 | init_msg = init_msg[0].replace("<image>", "").strip() 53 | messages[0] = (init_role, "<image>\n" + init_msg) 54 | 55 | if self.sep_style == SeparatorStyle.TWO: 56 | seps = [self.sep, self.sep2] 57 | ret = self.system + seps[0] 58 | for i, (role, message) in enumerate(messages): 59 | if message: 60 | if type(message) is tuple: 61 | message, _, _ = message 62 | ret += role + ": " + message + seps[i % 2] 63 | else: 64 | ret += role + ":" 65 | elif self.sep_style == SeparatorStyle.LLAMA_3: 66 | ret = self.system + self.sep 67 | for rid, (role, message) in enumerate(messages): 68 | if message: 69 | if type(message) is tuple: 70 | message = message[0] 71 | sep = self.sep if rid < len(messages) - 1 else self.sep2 72 | ret += role + message + sep 73 | else: 74 | ret += role 75 | elif self.sep_style == SeparatorStyle.MPT: 76 | ret = self.system + self.sep 77 | for role, message in messages: 78 | if message: 79 | if type(message) is tuple: 80 | message, _, _ = message 81 | ret += role + message + self.sep 82 | else: 83 | ret += role 84 | elif self.sep_style == SeparatorStyle.PLAIN: 85 | seps = [self.sep, self.sep2] 86 | ret = self.system 87 | for i, (role, message) in enumerate(messages): 88 | if message: 89 | if type(message) is tuple: 90 | message, _, _ = message 91 | ret += message + seps[i % 2] 92 | else: 93 | ret += "" 94 | else: 95 | raise ValueError(f"Invalid style: {self.sep_style}") 96 | 97 | return ret 98 | 99 | def append_message(self, role, message): 100 | self.messages.append([role, message]) 101 | 102 | def copy(self): 103 | return Conversation( 104 | system=self.system, 105 | roles=self.roles, 106 | messages=[[x, y] for x, y in self.messages], 107 | sep_style=self.sep_style, 108 | sep=self.sep, 109 | sep2=self.sep2, 110 | version=self.version, 111 | ) 112 | 113 | 114 | conv_auto = Conversation( 115 | system="", 116 | roles=("", ""), 117 | messages=(), 118 | sep_style=SeparatorStyle.AUTO, 119 | sep="\n", 120 | ) 121 | 122 | conv_vicuna_v1 = Conversation( 123 | system="A chat between a curious user and an artificial intelligence assistant. 
" 124 | "The assistant gives helpful, detailed, and polite answers to the user's questions.", 125 | roles=("USER", "ASSISTANT"), 126 | version="v1", 127 | messages=(), 128 | sep_style=SeparatorStyle.TWO, 129 | sep=" ", 130 | sep2="", 131 | ) 132 | 133 | conv_llava_plain = Conversation( 134 | system="", 135 | roles=("", ""), 136 | messages=(), 137 | sep_style=SeparatorStyle.PLAIN, 138 | sep="\n", 139 | ) 140 | 141 | hermes_2 = Conversation( 142 | system="<|im_start|>system\nAnswer the questions.", 143 | roles=("<|im_start|>user\n", "<|im_start|>assistant\n"), 144 | sep_style=SeparatorStyle.MPT, 145 | sep="<|im_end|>", 146 | messages=(), 147 | version="hermes-2", 148 | ) 149 | 150 | # Template added by Yukang. Note (kentang-mit@): sep is <|eot_id|> for official template. 151 | llama_3_chat = Conversation( 152 | system="<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\nYou are a helpful language and vision assistant. " 153 | "You are able to understand the visual content that the user provides, " 154 | "and assist the user with a variety of tasks using natural language.", 155 | roles=("<|start_header_id|>user<|end_header_id|>\n\n", "<|start_header_id|>assistant<|end_header_id|>\n\n"), 156 | version="llama_v3", 157 | messages=(), 158 | sep_style=SeparatorStyle.LLAMA_3, 159 | sep="<|eot_id|>", 160 | sep2="<|end_of_text|>", 161 | ) 162 | 163 | 164 | default_conversation = conv_auto 165 | conv_templates = { 166 | "auto": conv_auto, 167 | "hermes-2": hermes_2, 168 | "llama_3": llama_3_chat, 169 | "v1": conv_vicuna_v1, 170 | "vicuna_v1": conv_vicuna_v1, 171 | "plain": conv_llava_plain, 172 | } 173 | 174 | 175 | CONVERSATION_MODE_MAPPING = { 176 | "vila1.5-3b": "vicuna_v1", 177 | "vila1.5-8b": "llama_3", 178 | "vila1.5-13b": "vicuna_v1", 179 | "vila1.5-40b": "hermes-2", 180 | "llama-3": "llama_3", 181 | "llama3": "llama_3", 182 | } 183 | 184 | 185 | def auto_set_conversation_mode(model_name_or_path: str) -> str: 186 | global default_conversation 187 | for k, v in CONVERSATION_MODE_MAPPING.items(): 188 | if k in model_name_or_path.lower(): 189 | print(f"Setting conversation mode to `{v}` based on model name/path `{model_name_or_path}`.") 190 | default_conversation = conv_templates[v] 191 | return 192 | --------------------------------------------------------------------------------