├── models ├── __init__.py ├── internvl │ ├── __init__.py │ ├── internvl.py │ └── conversation.py ├── builder.py ├── qwenvl.py └── llava_ov.py ├── tasks ├── __init__.py ├── builder.py ├── egoschema.py ├── mlvu.py ├── lvb.py ├── slidevqa.py ├── base.py ├── mpdocvqa.py ├── utils.py ├── videomme.py └── mmlbdoc.py ├── .gitignore ├── scripts ├── document │ ├── path.sh │ ├── annot_scores_internvl-8b_frames.sh │ └── eval_internvl-8b-max16_topk_frames.sh └── video │ ├── annot_scores_internvl-8b_frames.sh │ ├── path.sh │ └── eval_internvl-8b-max1_top32_frames.sh ├── utils.py ├── collect_results.py ├── LICENSE ├── run_model.py ├── README.md ├── eval_model.py └── frame_selection.py /models/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /tasks/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /models/internvl/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /models/builder.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | 3 | 4 | model_map = { 5 | 'internvl': ('.internvl.internvl', 'InternVL'), 6 | 'llava_ov': ('.llava_ov', 'LLaVAOneVision'), 7 | 'qwenvl': ('.qwenvl', 'QwenVL') 8 | } 9 | 10 | 11 | def build_model(model_name, model_path, generation_args, image_aspect_ratio=None, **kwargs): 12 | module_name, func_name = model_map[model_name] 13 | module = importlib.import_module(module_name, package=__package__) 14 | model_init = getattr(module, func_name) 15 | 16 | return model_init(model_path, generation_args, image_aspect_ratio=image_aspect_ratio, **kwargs) 17 | 18 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Python 2 | __pycache__ 3 | *.pyc 4 | *.egg-info 5 | dist 6 | 7 | # Log 8 | logs 9 | test_scripts 10 | *.log 11 | *.log.* 12 | # *.json 13 | # *.jsonl 14 | 15 | # Data 16 | !**/alpaca-data-conversation.json 17 | # Editor 18 | .idea 19 | *.swp 20 | .vscode 21 | 22 | # Other 23 | .DS_Store 24 | wandb 25 | output 26 | 27 | checkpoints 28 | project_checkpoints 29 | debug_checkpoints 30 | playground/data 31 | playground/cc3m_llava34b_cap 32 | ckpts* 33 | 34 | .ipynb_checkpoints 35 | chunyl_scripts 36 | *.ipynb 37 | 38 | # DevContainer 39 | !.devcontainer/* 40 | 41 | # Demo 42 | serve_images/ 43 | notebooks/ 44 | logs 45 | scripts/dist_* 46 | logs/ 47 | submissions/ 48 | # work_dirs 49 | -------------------------------------------------------------------------------- /scripts/document/path.sh: -------------------------------------------------------------------------------- 1 | root= 2 | output_root= 3 | 4 | if [[ "$DATASET" == "slidevqa" ]]; then 5 | # SPLIT="dev" 6 | SPLIT="test" 7 | doc_path=${root}/datasets/SlideVQA/${SPLIT}_doc.json 8 | visual_path=${root}/datasets/SlideVQA/images 9 | elif [[ "$DATASET" == "mmlbdoc" ]]; then 10 | SPLIT="old" 11 | doc_path=${root}/datasets/MMLongBench-Doc/data/${SPLIT}_doc.json 12 | visual_path=${root}/datasets/MMLongBench-Doc/data/images 13 | elif [[ "$DATASET" == "mpdocvqa" ]]; then 14 | # SPLIT="val" 15 | SPLIT="test" 16 | doc_path=${root}/datasets/MP-DocVQA/${SPLIT}_doc.json 17 | 
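# doc_path above is the question-list JSON for the chosen split; visual_path below is the root folder of the page images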
visual_path=${root}/datasets/MP-DocVQA/images 18 | fi 19 | 20 | -------------------------------------------------------------------------------- /tasks/builder.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | 3 | 4 | task_map = { 5 | 'nextqa': ('.base', 'VideoTask'), 6 | 'egoschema': ('.egoschema', 'EgoSchema'), 7 | 'lvb': ('.lvb', 'LongVideoBench'), 8 | 'videomme': ('.videomme', 'VideoMME'), 9 | 'mlvu': ('.mlvu', 'MLVU'), 10 | 'slidevqa': ('.slidevqa', 'SlideVQA'), 11 | 'mmlbdoc': ('.mmlbdoc', 'MMLongBenchDoc'), 12 | 'mpdocvqa': ('.mpdocvqa', 'MPDocVQA') 13 | } 14 | 15 | 16 | def build_task(dataset, split, **kwargs): 17 | if dataset is None: 18 | module_name = '.base' 19 | func_name = 'VideoTask' 20 | else: 21 | module_name, func_name = task_map[dataset] 22 | module = importlib.import_module(module_name, package=__package__) 23 | task_init = getattr(module, func_name) 24 | 25 | return task_init(dataset, split, **kwargs) -------------------------------------------------------------------------------- /scripts/video/annot_scores_internvl-8b_frames.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | DATASET=$1 4 | num_frames=$2 5 | CHUNKS=$3 6 | IDX=$4 7 | 8 | source "$(dirname "${BASH_SOURCE[0]}")"/path.sh 9 | 10 | filename=$(basename "$0") 11 | filename="${filename%.*}" 12 | output_dir=${output_root}/${DATASET}/$SPLIT/${filename}_${num_frames} 13 | 14 | python eval_model.py \ 15 | --num-chunks $CHUNKS \ 16 | --chunk-idx $IDX \ 17 | --doc-path $doc_path \ 18 | --visual-folder $visual_path \ 19 | --output-dir $output_dir \ 20 | --model "internvl" \ 21 | --model-path "${root}/ckpts/InternVL2-8B" \ 22 | --sample_frames $num_frames \ 23 | --input_frames 1 \ 24 | --selector_method "annot_scores_frames" \ 25 | --image-aspect-ratio 12 \ 26 | --dataset $DATASET \ 27 | --split $SPLIT \ 28 | --main-process 29 | 30 | 31 | -------------------------------------------------------------------------------- /scripts/document/annot_scores_internvl-8b_frames.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | DATASET=$1 4 | num_frames=$2 5 | CHUNKS=$3 6 | IDX=$4 7 | 8 | source "$(dirname "${BASH_SOURCE[0]}")"/path.sh 9 | 10 | filename=$(basename "$0") 11 | filename="${filename%.*}" 12 | output_dir=${output_root}/${DATASET}/$SPLIT/${filename}_${num_frames} 13 | 14 | python eval_model.py \ 15 | --num-chunks $CHUNKS \ 16 | --chunk-idx $IDX \ 17 | --doc-path $doc_path \ 18 | --visual-folder $visual_path \ 19 | --output-dir $output_dir \ 20 | --model "internvl" \ 21 | --model-path "${root}/ckpts/InternVL2-8B" \ 22 | --sample_frames $num_frames \ 23 | --input_frames 1 \ 24 | --selector_method "annot_scores_frames" \ 25 | --image-aspect-ratio 12 \ 26 | --dataset $DATASET \ 27 | --split $SPLIT \ 28 | --main-process 29 | 30 | 31 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import math 2 | import hashlib 3 | import numpy as np 4 | from decord import VideoReader, cpu 5 | 6 | 7 | def split_list(lst, n): 8 | """Split a list into n (roughly) equal-sized chunks""" 9 | chunk_size = math.ceil(len(lst) / n) # ceiling division 10 | return [lst[i:i+chunk_size] for i in range(0, len(lst), chunk_size)] 11 | 12 | 13 | def get_chunk(lst, n, k): 14 | chunks = split_list(lst, n) 15 | return chunks[k] 16 | 17 | 18 | def 
hash_string(input_string): 19 | hash_object = hashlib.sha1() 20 | hash_object.update(input_string.encode('utf-8')) 21 | hashed_string = hash_object.hexdigest() 22 | 23 | return hashed_string 24 | 25 | 26 | def frame_to_sec(frame_indices, visual_path): 27 | vr = VideoReader(visual_path, ctx=cpu(0), num_threads=1) 28 | 29 | # Get the frame rate (FPS) 30 | fps = vr.get_avg_fps() 31 | 32 | return [float(x)/fps for x in frame_indices] 33 | 34 | -------------------------------------------------------------------------------- /scripts/video/path.sh: -------------------------------------------------------------------------------- 1 | root= 2 | output_root= 3 | 4 | if [[ "$DATASET" == "egoschema" ]]; then 5 | SPLIT="full" 6 | doc_path=${root}/datasets/EgoSchema/${SPLIT}_doc_list.json 7 | visual_path=${root}/datasets/EgoSchema/Egochema_videos 8 | 9 | elif [[ "$DATASET" == "lvb" ]]; then 10 | SPLIT="val" 11 | doc_path=${root}/datasets/LongVideoBench/lvb_${SPLIT}_doc_list.json 12 | visual_path=${root}/datasets/LongVideoBench/videos 13 | 14 | elif [[ "$DATASET" == "videomme" ]]; then 15 | SPLIT="test" 16 | doc_path=${root}/datasets/Video-MME/${SPLIT}_doc_list.json 17 | visual_path=${root}/datasets/Video-MME/data 18 | 19 | elif [[ "$DATASET" == "nextqa" ]]; then 20 | SPLIT="val" 21 | doc_path=${root}/datasets/nextqa/mc_${SPLIT}_doc_list.json 22 | visual_path=${root}/datasets/nextqa/NExTVideo 23 | 24 | elif [[ "$DATASET" == "mlvu" ]]; then 25 | SPLIT="dev_mc" 26 | doc_path=${root}/datasets/MLVU/${SPLIT}_doc_list.json 27 | visual_path=${root}/datasets/MLVU/MLVU/video 28 | fi 29 | 30 | -------------------------------------------------------------------------------- /scripts/document/eval_internvl-8b-max16_topk_frames.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | DATASET=$1 4 | num_frames=$2 5 | 6 | source "$(dirname "${BASH_SOURCE[0]}")"/path.sh 7 | 8 | gpu_list="${CUDA_VISIBLE_DEVICES:-0}" 9 | IFS=',' read -ra GPULIST <<< "$gpu_list" 10 | 11 | CHUNKS=${#GPULIST[@]} 12 | 13 | filename=$(basename "$0") 14 | filename="${filename%.*}" 15 | output_dir=${output_root}/${DATASET}/$SPLIT/${filename}_${num_frames} 16 | 17 | 18 | for IDX in $(seq 0 $((CHUNKS-1))); do 19 | CUDA_VISIBLE_DEVICES=${GPULIST[$IDX]} python eval_model.py \ 20 | --doc-path $doc_path \ 21 | --visual-folder $visual_path \ 22 | --output-dir $output_dir \ 23 | --model "internvl" \ 24 | --model-path "${root}/ckpts/InternVL2-8B" \ 25 | --sample_frames -1 \ 26 | --input_frames $num_frames \ 27 | --selector_method "topk_frames" \ 28 | --image-aspect-ratio 16 \ 29 | --score-docs "${output_root}/${DATASET}/$SPLIT/annot_scores_internvl-8b_frames_-1.json" \ 30 | --dataset $DATASET \ 31 | --split $SPLIT \ 32 | --num-chunks $CHUNKS \ 33 | --chunk-idx $IDX & 34 | done 35 | 36 | wait 37 | 38 | python collect_results.py \ 39 | --result-path $output_dir \ 40 | --doc-path $doc_path \ 41 | --dataset $DATASET \ 42 | --split $SPLIT \ 43 | --eval 44 | -------------------------------------------------------------------------------- /scripts/video/eval_internvl-8b-max1_top32_frames.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | DATASET=$1 4 | num_frames=$2 5 | 6 | source "$(dirname "${BASH_SOURCE[0]}")"/path.sh 7 | 8 | gpu_list="${CUDA_VISIBLE_DEVICES:-0}" 9 | IFS=',' read -ra GPULIST <<< "$gpu_list" 10 | 11 | CHUNKS=${#GPULIST[@]} 12 | 13 | filename=$(basename "$0") 14 | filename="${filename%.*}" 15 | 
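# the script name (without extension) and the frame count become part of the output directory, e.g. ${output_root}/lvb/val/eval_internvl-8b-max1_top32_frames_256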
output_dir=${output_root}/${DATASET}/$SPLIT/${filename}_${num_frames} 16 | 17 | 18 | for IDX in $(seq 0 $((CHUNKS-1))); do 19 | CUDA_VISIBLE_DEVICES=${GPULIST[$IDX]} python eval_model.py \ 20 | --doc-path $doc_path \ 21 | --visual-folder $visual_path \ 22 | --output-dir $output_dir \ 23 | --model "internvl" \ 24 | --model-path "${root}/ckpts/InternVL2-8B" \ 25 | --sample_frames $num_frames \ 26 | --input_frames 32 \ 27 | --selector_method "topk_frames" \ 28 | --score-docs "${output_root}/${DATASET}/$SPLIT/annot_scores_internvl-8b_frames_256.json" \ 29 | --image-aspect-ratio 1 \ 30 | --dataset $DATASET \ 31 | --split $SPLIT \ 32 | --num-chunks $CHUNKS \ 33 | --chunk-idx $IDX & 34 | done 35 | 36 | wait 37 | 38 | python collect_results.py \ 39 | --result-path $output_dir \ 40 | --doc-path $doc_path \ 41 | --dataset $DATASET \ 42 | --split $SPLIT \ 43 | --eval 44 | -------------------------------------------------------------------------------- /tasks/egoschema.py: -------------------------------------------------------------------------------- 1 | from .base import VideoTask 2 | from .utils import get_multi_choice_info, parse_multi_choice_response 3 | import requests 4 | 5 | 6 | class EgoSchema(VideoTask): 7 | def __init__(self, dataset, split, **kwargs): 8 | super().__init__(dataset, split, **kwargs) 9 | 10 | assert self.split in ['subset', 'full'] 11 | 12 | def aggregate_results(self, docs, out_root): 13 | if self.split == "subset": 14 | return super().aggregate_results(docs, out_root) 15 | elif self.split == "full": 16 | out_dict = {} 17 | for doc in docs: 18 | pred = doc["pred"][0] 19 | index2ans, all_choices = get_multi_choice_info(doc) 20 | parsed_pred = parse_multi_choice_response(pred, all_choices, index2ans) 21 | if parsed_pred not in all_choices: 22 | res = -1 23 | else: 24 | res = all_choices.index(parsed_pred) 25 | out_dict[doc["id"]] = res 26 | 27 | url = "https://validation-server.onrender.com/api/upload/" 28 | headers = { 29 | "Content-Type": "application/json" 30 | } 31 | response = requests.post(url, headers=headers, json=out_dict) 32 | 33 | out_file = out_root + '.log' 34 | with open(out_file, 'a') as file: 35 | file.write(response.text) 36 | file.write('\n') 37 | 38 | -------------------------------------------------------------------------------- /collect_results.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | import os 4 | 5 | 6 | def main(args): 7 | docs = json.load(open(args.doc_path, "r")) 8 | output_path = args.result_path + '.json' 9 | if not os.path.exists(output_path): 10 | out_docs = [] 11 | for i, doc in enumerate(docs): 12 | if 'id' in doc: 13 | doc_id = doc['id'] 14 | else: 15 | doc_id = "%06d" % i 16 | res_path = os.path.join(args.result_path, "%s.json" % doc_id) 17 | try: 18 | doc = json.load(open(res_path, "r")) 19 | except Exception as error: 20 | print(res_path) 21 | raise error 22 | out_docs.append(doc) 23 | 24 | with open(output_path, "w") as f: 25 | json.dump(out_docs, f) 26 | else: 27 | out_docs = json.load(open(output_path, 'r')) 28 | 29 | if args.eval: 30 | from tasks.builder import build_task 31 | task = build_task(args.dataset, args.split) 32 | task.aggregate_results(out_docs, args.result_path) 33 | 34 | 35 | if __name__ == "__main__": 36 | parser = argparse.ArgumentParser() 37 | parser.add_argument("--result-path", type=str, required=True) 38 | parser.add_argument("--doc-path", type=str, required=True) 39 | parser.add_argument("--eval", action="store_true") 40 | 
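# --dataset and --split are only needed together with --eval; they select the task whose aggregate_results() computes the metrics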
parser.add_argument("--dataset", type=str, default=None) 41 | parser.add_argument("--split", type=str, default=None) 42 | args = parser.parse_args() 43 | 44 | main(args) 45 | -------------------------------------------------------------------------------- /tasks/mlvu.py: -------------------------------------------------------------------------------- 1 | from .base import VideoTask 2 | from .utils import OPTIONS 3 | from collections import defaultdict 4 | import logging 5 | 6 | 7 | def extract_characters_regex(s): 8 | s = s.strip() 9 | if ")" in s: 10 | index = s.index(")") 11 | pred = s[index - 1 : index] 12 | return pred 13 | else: 14 | return s 15 | 16 | 17 | class MLVU(VideoTask): 18 | def __init__(self, dataset, split, **kwargs): 19 | super().__init__(dataset, split, **kwargs) 20 | 21 | self.post_prompt = "\nOnly give the best option.\nBest Option: (" 22 | 23 | assert self.split in ['dev_mc', 'dev_gen'], "dev only now" 24 | 25 | def doc_to_prompt(self, doc): 26 | if self.split.endswith('_mc'): 27 | return self.doc_to_prompt_mc(doc) 28 | elif self.split.endswith('_gen'): 29 | return self.doc_to_prompt_gen(doc) 30 | 31 | def doc_to_prompt_mc(self, doc): 32 | prompt_assistant = self.prompt_assistant 33 | 34 | question = f"Question: {doc['question']}\n" 35 | question += "Options:\n" 36 | for idx, c in enumerate(doc['options']): 37 | question += f"({chr(ord('A') + idx)}) {c}\n" 38 | question = question.rstrip() 39 | 40 | prompt_user = question + self.post_prompt 41 | 42 | return (prompt_user, prompt_assistant) 43 | 44 | def doc_to_prompt_gen(self, doc): 45 | pass # TODO 46 | 47 | def aggregate_results(self, docs, out_root): 48 | if self.split in ['dev_gen', 'test_mc', 'test_gen']: 49 | raise NotImplementedError('') 50 | elif self.split == 'dev_mc': 51 | out_file = out_root + '.log' 52 | logging.basicConfig(filename=out_file, 53 | level=logging.INFO, 54 | format='%(asctime)s - %(levelname)s - %(message)s') 55 | acc_dict = defaultdict(list) 56 | for doc in docs: 57 | pred = doc["pred"][0] 58 | answer = OPTIONS[doc["answer"]] 59 | pred_ans = extract_characters_regex(pred) 60 | correct = int(answer == pred_ans) 61 | acc_dict[doc["type"]].append(correct) 62 | 63 | final_res = dict() 64 | total=0 65 | idx=0 66 | for k, v in acc_dict.items(): 67 | idx+=1 68 | final_res[k] = 100 * float(sum(v)) / len(v) 69 | total+=final_res[k] 70 | final_res['Avg'] = total /idx 71 | for k, v in final_res.items(): 72 | logging.info(f"{k} Acc: {v :.2f}%") 73 | 74 | -------------------------------------------------------------------------------- /tasks/lvb.py: -------------------------------------------------------------------------------- 1 | from .base import VideoTask 2 | from .utils import get_multi_choice_info, parse_multi_choice_response, mc_process_results 3 | import logging 4 | from collections import defaultdict 5 | 6 | 7 | CATEGORIES = [ 8 | 'E2O', 9 | 'E3E', 10 | 'O2E', 11 | 'O3O', 12 | 'S2A', 13 | 'S2E', 14 | 'S2O', 15 | 'SAA', 16 | 'SOS', 17 | 'SSS', 18 | 'T2A', 19 | 'T2E', 20 | 'T2O', 21 | 'T3E', 22 | 'T3O', 23 | 'TAA', 24 | 'TOS' 25 | ] 26 | 27 | 28 | class LongVideoBench(VideoTask): 29 | def __init__(self, dataset, split, subtitles=False, **kwargs): 30 | super().__init__(dataset, split, **kwargs) 31 | 32 | assert self.split in ['val', 'test'] 33 | 34 | self.subtitles = subtitles 35 | if self.subtitles: 36 | raise NotImplementedError('subtitles not implemented yet') 37 | # TODO: change prompt for subtitles 38 | 39 | def aggregate_results(self, docs, out_root): 40 | if self.split == "test": 41 | raise 
NotImplementedError('test set not implemented') 42 | # pred = results 43 | # index2ans, all_choices = get_multi_choice_info(doc) 44 | # parsed_pred = parse_multi_choice_response(pred, all_choices, index2ans) 45 | elif self.split == "val": 46 | out_file = out_root + '.log' 47 | logging.basicConfig(filename=out_file, 48 | level=logging.INFO, 49 | format='%(asctime)s - %(levelname)s - %(message)s') 50 | subset_to_eval_samples = defaultdict(list) 51 | for doc in docs: 52 | pred = doc["pred"][0] 53 | res = mc_process_results(doc, pred) 54 | correct = int(res["exact_match"]) 55 | 56 | subset_to_eval_samples[doc["question_category"]].append(correct) 57 | subset_to_eval_samples[doc["duration_group"]].append(correct) 58 | subset_to_eval_samples["overall"].append(correct) 59 | for subset in CATEGORIES: 60 | sub_eval_samples = subset_to_eval_samples[subset] 61 | total_correct = float(sum(sub_eval_samples)) 62 | total_answered = len(sub_eval_samples) 63 | logging.info(f"Evaluation on Question Categories: {subset}: {100 * total_correct / total_answered if total_answered > 0 else 0 : .1f}%") 64 | for subset in [15, 60, 600, 3600]: 65 | sub_eval_samples = subset_to_eval_samples[subset] 66 | total_correct = float(sum(sub_eval_samples)) 67 | total_answered = len(sub_eval_samples) 68 | logging.info(f"Evaluation on Duration Categories: {subset}: {100 * total_correct / total_answered if total_answered > 0 else 0 : .1f}%") 69 | for subset in ["overall"]: 70 | sub_eval_samples = subset_to_eval_samples[subset] 71 | total_correct = float(sum(sub_eval_samples)) 72 | total_answered = len(sub_eval_samples) 73 | logging.info(f"Overall: {100 * total_correct / total_answered if total_answered > 0 else 0 : .1f}%") 74 | -------------------------------------------------------------------------------- /tasks/slidevqa.py: -------------------------------------------------------------------------------- 1 | from .base import DocumentTask 2 | import string 3 | from collections import Counter 4 | import re 5 | 6 | 7 | WORD_NUMBER_MAP = {"zero": 0, "one": 1, "two": 2, "three": 3, "four": 4, 8 | "five": 5, "six": 6, "seven": 7, "eight": 8, 9 | "nine": 9, "ten": 10, "eleven": 11, "twelve": 12, 10 | "thirteen": 13, "fourteen": 14, "fifteen": 15, 11 | "sixteen": 16, "seventeen": 17, "eighteen": 18, "nineteen": 19} 12 | 13 | def normalize_answer(s, question): 14 | def remove_articles(text): 15 | regex = re.compile(r'\b(a|an|the)\b', re.UNICODE) 16 | return re.sub(regex, ' ', text) 17 | def white_space_fix(text): 18 | return ' '.join(text.split()) 19 | def remove_punc(text): 20 | exclude = set(string.punctuation) 21 | return ''.join(ch for ch in text if ch not in exclude) 22 | def lower(text): 23 | return text.lower() 24 | def yesno(text): 25 | if 'yes' == text[:3] or 'no' == text[:2]: 26 | text = text.split()[0] 27 | return text 28 | def replace_text(text): 29 | return text.replace('this is ', '').replace('it is ', '').replace('&', ',').replace('and', ',').replace('percent', '').replace('organisation', 'organization').replace('because of', '').replace('because', '').replace('due to', '').replace('hours', 'hrs').replace('minites', 'min') 30 | def word2number(text): 31 | words = text.split() 32 | return ' '.join([str(WORD_NUMBER_MAP[word]) if word in WORD_NUMBER_MAP else word for word in words]) 33 | def remove_unit(text, question): 34 | if 'how many' in question: 35 | idx = question.find('how many') 36 | unit = question[idx+len('how many'):].split()[0] 37 | text = text.replace(unit, '') 38 | if 'which' in question: 39 | idx = 
question.find('which') 40 | unit = question[idx+len('which'):].split()[0] 41 | text = text.replace(unit, '') 42 | return text 43 | return word2number(white_space_fix(yesno(remove_articles(remove_punc(remove_unit(replace_text(lower(s)), question)))))) 44 | 45 | 46 | class SlideVQA(DocumentTask): 47 | def __init__(self, dataset, split, **kwargs): 48 | super().__init__(dataset, split, **kwargs) 49 | 50 | assert self.split in ['dev', 'test', 'train'] 51 | 52 | def aggregate_results(self, docs, out_root): 53 | f1 = exact_match = 0 54 | precisions = {} 55 | recalls = {} 56 | ems = {} 57 | for doc in docs: 58 | qa_id = doc['id'] 59 | question = doc['question'] 60 | prediction = doc['pred'][0].strip() 61 | ground_truth = doc['answer'] 62 | prediction_tokens = normalize_answer(prediction, question).split() 63 | ground_truth_tokens = normalize_answer(ground_truth, question).split() 64 | common = Counter(prediction_tokens) & Counter(ground_truth_tokens) 65 | num_same = sum(common.values()) 66 | if num_same == 0: 67 | precisions[qa_id] = recalls[qa_id] = ems[qa_id] = 0 68 | continue 69 | precision = 1.0 * num_same / len(prediction_tokens) 70 | recall = 1.0 * num_same / len(ground_truth_tokens) 71 | f1 += (2 * precision * recall) / (precision + recall) 72 | exact_match += (prediction_tokens == ground_truth_tokens) 73 | precisions[qa_id] = precision 74 | recalls[qa_id] = recall 75 | ems[qa_id] = (prediction_tokens == ground_truth_tokens) 76 | exact_match = exact_match / len(docs) 77 | f1 = f1 / len(docs) 78 | 79 | out_file = out_root + '.log' 80 | with open(out_file, 'a') as file: 81 | file.write(f"EM: {exact_match*100}\nF1: {f1*100}") 82 | 83 | -------------------------------------------------------------------------------- /tasks/base.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2025, NVIDIA Corporation & Affiliates. All rights reserved. 2 | # 3 | # This work is made available under the Nvidia Source Code License-NC. 4 | # To view a copy of this license, visit 5 | # https://github.com/NVlabs/FRAG/blob/main/LICENSE 6 | 7 | import numpy as np 8 | from decord import VideoReader, cpu 9 | from PIL import Image 10 | 11 | from .utils import doc_to_text_mc, mc_process_results 12 | 13 | 14 | class VideoTask: 15 | def __init__(self, dataset, split, **kwargs): 16 | self.dataset = dataset 17 | self.split = split 18 | 19 | self.post_prompt = "\nAnswer with the option's letter from the given choices directly." 
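# prompt_assistant optionally seeds the assistant turn; None means no prefix (task subclasses such as MLVU override these prompts)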
20 | self.prompt_assistant = None 21 | 22 | self.result_key = "pred" 23 | 24 | def load_visual(self, visual_path, num_frames): 25 | vr = VideoReader(visual_path, ctx=cpu(0), num_threads=1) 26 | total_frames = len(vr) 27 | num_frames = min(num_frames, total_frames) 28 | if num_frames == 1: 29 | # get middle instead of first for single frame 30 | frame_indices = np.arange(0, total_frames, total_frames / 2).astype(int).tolist() 31 | frame_indices = frame_indices[1:] 32 | else: 33 | frame_indices = np.arange(0, total_frames, total_frames / num_frames).astype(int).tolist() 34 | 35 | frame_indices = [x for x in frame_indices if x < total_frames] 36 | 37 | frames = vr.get_batch(frame_indices).asnumpy() 38 | images = [Image.fromarray(frames[i]).convert('RGB') for i in range(len(frames))] 39 | 40 | return images, frame_indices 41 | 42 | def doc_to_prompt(self, doc): 43 | prompt_user = doc_to_text_mc(doc, {"post_prompt": self.post_prompt}) 44 | prompt_assistant = self.prompt_assistant 45 | 46 | return (prompt_user, prompt_assistant) 47 | 48 | def doc_to_visual_name(self, doc): 49 | return doc["video"] 50 | 51 | def process_results(self, doc, results): 52 | return results 53 | 54 | def aggregate_results(self, docs, out_root): 55 | out_file = out_root + '.log' 56 | 57 | cnt = 0 58 | correct = 0 59 | for doc in docs: 60 | pred = doc["pred"][0] 61 | res = mc_process_results(doc, pred) 62 | cnt += 1 63 | correct += int(res["exact_match"]) 64 | accuracy = float(correct) / cnt 65 | print(f"Accuracy: {accuracy}") 66 | with open(out_file, 'a') as file: 67 | file.write(f"Accuracy: {accuracy}\n") 68 | 69 | 70 | class DocumentTask: 71 | def __init__(self, dataset, split, **kwargs): 72 | self.dataset = dataset 73 | self.split = split 74 | 75 | # self.post_prompt = "\nAnswer the question concisely based on the provided images." 76 | self.post_prompt = "\nAnswer the question based on the provided images. Please make your response as concise as possible." 77 | self.prompt_assistant = None 78 | 79 | self.result_key = "pred" 80 | 81 | def load_visual(self, visual_paths, num_frames): 82 | assert isinstance(visual_paths, list) 83 | total_frames = len(visual_paths) 84 | if num_frames > 0: 85 | num_frames = min(num_frames, total_frames) 86 | frame_indices = np.arange(0, total_frames, total_frames / num_frames).astype(int).tolist() 87 | else: 88 | frame_indices = list(range(len(visual_paths))) 89 | frame_indices = [x for x in frame_indices if x < total_frames] 90 | images = [Image.open(visual_paths[i]).convert("RGB") for i in frame_indices] 91 | 92 | return images, frame_indices 93 | 94 | def doc_to_prompt(self, doc): 95 | prompt_assistant = self.prompt_assistant 96 | 97 | prompt_user = doc["question"].strip() + self.post_prompt 98 | 99 | return (prompt_user, prompt_assistant) 100 | 101 | def doc_to_visual_name(self, doc): 102 | return doc["deck_name"] 103 | 104 | def process_results(self, doc, results): 105 | return results 106 | 107 | def aggregate_results(self, docs, out_root): 108 | pass 109 | 110 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Copyright (c) 2025, NVIDIA Corporation & Affiliates. All rights reserved. 2 | 3 | Nvidia Source Code License-NC 4 | 5 | 6 | 1. Definition 7 | 8 | “Licensor” means any person or entity that distributes its Work. 
9 | 10 | “Work” means (a) the original work of authorship made available under this license, which may include software, 11 | documentation, or other files, and (b) any additions to or derivative works thereof that are made available 12 | under this license. 13 | 14 | The terms “reproduce,” “reproduction,” “derivative works,” and “distribution” have the meaning as provided under 15 | U.S. copyright law; provided, however, that for the purposes of this license, derivative works shall not include 16 | works that remain separable from, or merely link (or bind by name) to the interfaces of, the Work. 17 | 18 | Works are “made available” under this license by including in or with the Work either (a) a copyright notice 19 | referencing the applicability of this license to the Work, or (b) a copy of this license. 20 | 21 | 2. License Grant 22 | 23 | 2.1 Copyright Grant. Subject to the terms and conditions of this license, each Licensor grants to you a perpetual, 24 | worldwide, non-exclusive, royalty-free, copyright license to use, reproduce, prepare derivative works of, publicly 25 | display, publicly perform, sublicense and distribute its Work and any resulting derivative works in any form. 26 | 27 | 3. Limitations 28 | 29 | 3.1 Redistribution. You may reproduce or distribute the Work only if (a) you do so under this license, (b) you 30 | include a complete copy of this license with your distribution, and (c) you retain without modification any 31 | copyright, patent, trademark, or attribution notices that are present in the Work. 32 | 33 | 3.2 Derivative Works. You may specify that additional or different terms apply to the use, reproduction, and 34 | distribution of your derivative works of the Work (“Your Terms”) only if (a) Your Terms provide that the use 35 | limitation in Section 3.3 applies to your derivative works, and (b) you identify the specific derivative works 36 | that are subject to Your Terms. Notwithstanding Your Terms, this license (including the redistribution requirements 37 | in Section 3.1) will continue to apply to the Work itself. 38 | 39 | 3.3 Use Limitation. The Work and any derivative works thereof only may be used or intended for use non-commercially. 40 | Notwithstanding the foregoing, NVIDIA Corporation and its affiliates may use the Work and any derivative works 41 | commercially. As used herein, “non-commercially” means for research or evaluation purposes only. 42 | 43 | 3.4 Patent Claims. If you bring or threaten to bring a patent claim against any Licensor (including any claim, 44 | cross-claim or counterclaim in a lawsuit) to enforce any patents that you allege are infringed by any Work, then 45 | your rights under this license from such Licensor (including the grant in Section 2.1) will terminate immediately. 46 | 47 | 3.5 Trademarks. This license does not grant any rights to use any Licensor’s or its affiliates’ names, logos, or 48 | trademarks, except as necessary to reproduce the notices described in this license. 49 | 50 | 3.6 Termination. If you violate any term of this license, then your rights under this license (including the grant 51 | in Section 2.1) will terminate immediately. 52 | 53 | 4. Disclaimer of Warranty. 54 | 55 | THE WORK IS PROVIDED “AS IS” WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING 56 | WARRANTIES OR CONDITIONS OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE, TITLE OR NON-INFRINGEMENT. YOU 57 | BEAR THE RISK OF UNDERTAKING ANY ACTIVITIES UNDER THIS LICENSE. 58 | 59 | 5. 
Limitation of Liability. 60 | 61 | EXCEPT AS PROHIBITED BY APPLICABLE LAW, IN NO EVENT AND UNDER NO LEGAL THEORY, WHETHER IN TORT (INCLUDING 62 | NEGLIGENCE), CONTRACT, OR OTHERWISE SHALL ANY LICENSOR BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY DIRECT, INDIRECT, 63 | SPECIAL, INCIDENTAL, OR CONSEQUENTIAL DAMAGES ARISING OUT OF OR RELATED TO THIS LICENSE, THE USE OR INABILITY TO USE 64 | THE WORK (INCLUDING BUT NOT LIMITED TO LOSS OF GOODWILL, BUSINESS INTERRUPTION, LOST PROFITS OR DATA, COMPUTER 65 | FAILURE OR MALFUNCTION, OR ANY OTHER DAMAGES OR LOSSES), EVEN IF THE LICENSOR HAS BEEN ADVISED OF THE POSSIBILITY 66 | OF SUCH DAMAGES. 67 | -------------------------------------------------------------------------------- /models/qwenvl.py: -------------------------------------------------------------------------------- 1 | # Adopted from Qwen2-VL from https://github.com/QwenLM/Qwen2-VL. Below is the original copyright: 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from transformers import Qwen2VLForConditionalGeneration, AutoTokenizer, AutoProcessor 16 | from qwen_vl_utils import process_vision_info 17 | import torch 18 | 19 | 20 | class QwenVL: 21 | def __init__(self, model_path, generation_args, image_aspect_ratio=None): 22 | # load model and processor 23 | self.model= Qwen2VLForConditionalGeneration.from_pretrained( 24 | model_path, 25 | torch_dtype=torch.bfloat16, 26 | attn_implementation="flash_attention_2", 27 | device_map="auto" 28 | ) 29 | self.processor = AutoProcessor.from_pretrained(model_path) 30 | 31 | # set dynamic resolution param 32 | if image_aspect_ratio is not None: 33 | self.max_pixels = int(image_aspect_ratio) * 28 * 28 34 | else: 35 | self.max_pixels = None 36 | 37 | # get tokenizer from processor. 
needed to get option's id in frame selector 38 | self.tokenizer = self.processor.tokenizer 39 | 40 | # save generation_args to use in run_model 41 | self.generation_args = generation_args 42 | 43 | def run_model(self, images, message, output_scores=False, post_proc_func=None): 44 | message, prompt_assistant = message 45 | 46 | messages = [] 47 | 48 | user_content = [] 49 | for image in images: 50 | if self.max_pixels is not None: 51 | user_content.append({"type": "image", "image": image, "max_pixels": self.max_pixels}) 52 | else: 53 | user_content.append({"type": "image", "image": image}) 54 | user_content.append({"type": "text", "text": message}) 55 | user_message = { 56 | "role": "user", 57 | "content": user_content 58 | } 59 | messages.append(user_message) 60 | 61 | if prompt_assistant is not None: 62 | assistant_message = { 63 | "role": "assistant", 64 | "content": [{"type": "text", "text": prompt_assistant}] 65 | } 66 | messages.append(assistant_message) 67 | 68 | # Preparation for inference 69 | text = self.processor.apply_chat_template( 70 | messages, tokenize=False, add_generation_prompt=True 71 | ) 72 | image_inputs, video_inputs = process_vision_info(messages) 73 | inputs = self.processor( 74 | text=[text], 75 | images=image_inputs, 76 | videos=video_inputs, 77 | padding=True, 78 | return_tensors="pt", 79 | ) 80 | inputs = inputs.to("cuda") 81 | 82 | # Inference 83 | with torch.inference_mode(): 84 | generation_output = self.model.generate( 85 | **inputs, 86 | **self.generation_args, 87 | output_scores=output_scores, 88 | return_dict_in_generate=output_scores, 89 | ) 90 | 91 | if post_proc_func is not None: 92 | outputs = post_proc_func(generation_output) 93 | else: 94 | generated_ids_trimmed = [ 95 | out_ids[len(in_ids) :] for in_ids, out_ids in zip(inputs.input_ids, generation_output) 96 | ] 97 | outputs = self.processor.batch_decode( 98 | generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False 99 | ) 100 | outputs = outputs[0].strip() 101 | return outputs 102 | 103 | 104 | 105 | -------------------------------------------------------------------------------- /run_model.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2025, NVIDIA Corporation & Affiliates. All rights reserved. 2 | # 3 | # This work is made available under the Nvidia Source Code License-NC. 
4 | # To view a copy of this license, visit 5 | # https://github.com/NVlabs/FRAG/blob/main/LICENSE 6 | 7 | import argparse 8 | import os 9 | import json 10 | import numpy as np 11 | 12 | from tasks.utils import load_video, load_images 13 | 14 | from models.builder import build_model 15 | from frame_selection import FrameSelection 16 | from collections import defaultdict 17 | 18 | 19 | def main(args): 20 | generation_args = {"max_new_tokens": args.max_new_tokens, 21 | "temperature": args.temperature, 22 | "do_sample": args.do_sample} 23 | 24 | # answering LMM 25 | model_path = os.path.expanduser(args.model_path) 26 | model = build_model(args.model, 27 | model_path, 28 | generation_args, 29 | image_aspect_ratio=args.image_aspect_ratio) 30 | 31 | # scroing LMM 32 | if args.selector_model is None: 33 | selector_model = model 34 | else: 35 | selector_model_path = os.path.expanduser(args.selector_model_path) 36 | selector_model = build_model(args.selector_model, 37 | args.selector_model_path, 38 | generation_args, 39 | image_aspect_ratio=args.selector_image_aspect_ratio) 40 | 41 | selector = FrameSelection(selector_model, args.selector_method, args.input_frames, args.sample_frames) 42 | 43 | # load input 44 | if args.input_type == 'video': 45 | images, frame_indices = load_video(args.input_path, args.sample_frames) 46 | elif args.input_type == 'images': 47 | image_extensions = ('.jpg', '.jpeg', '.png', '.bmp', '.gif', '.tiff', '.webp') 48 | image_files = [file for file in os.listdir(args.input_path) if file.lower().endswith(image_extensions)] 49 | image_files = sorted(image_files) 50 | image_files = [os.path.join(args.input_path, file) for file in image_files] 51 | images, frame_indices = load_images(image_files, args.sample_frames) 52 | else: 53 | raise NotImplementedError(f"unknown input type: {args.input_type}") 54 | 55 | doc = {} 56 | doc['visual_path'] = args.input_path 57 | doc['question'] = args.query 58 | 59 | prompt = doc["question"].strip() 60 | if args.pre_prompt is not None: 61 | prompt = f"{args.pre_prompt}{prompt}" 62 | if args.post_prompt is not None: 63 | prompt = f"{prompt}{args.post_prompt}" 64 | doc['prompt'] = prompt 65 | 66 | selected_images_list, selected_indices_list, scores_dict = selector.select_frames(doc, images, frame_indices, return_scores=True) 67 | selected_images = selected_images_list[0] 68 | selected_indices = selected_indices_list[0] 69 | doc["selected_frames"] = selected_indices 70 | doc["scores"] = scores_dict 71 | 72 | outputs = model.run_model(selected_images, (prompt, None)) 73 | doc["pred"] = outputs 74 | print(f"FRAG: {outputs}") 75 | 76 | if isinstance(next(iter(doc["scores"])), tuple): 77 | doc["scores"] = {str(k): v for k, v in doc["scores"].items()} 78 | 79 | out_path = os.path.join(args.output_dir, "%s.json" % os.path.basename(doc['visual_path'])) 80 | with open(out_path, "w") as f: 81 | json.dump(doc, f) 82 | 83 | 84 | if __name__ == "__main__": 85 | parser = argparse.ArgumentParser() 86 | parser.add_argument("--output-dir", type=str, default="outputs") 87 | 88 | parser.add_argument("--max_new_tokens", type=int, default=128) 89 | parser.add_argument("--temperature", type=float, default=0.0) 90 | 91 | parser.add_argument("--model", type=str, default="llava_ov") 92 | parser.add_argument("--model-path", type=str, default="lmms-lab/llava-onevision-qwen2-7b-ov") 93 | parser.add_argument("--image-aspect-ratio", type=str, default="anyres_max_9") 94 | parser.add_argument("--selector-model", type=str, default=None) 95 | 
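# if --selector-model is left as None, the answering model above is also reused for scoring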
parser.add_argument("--selector-model-path", type=str, default=None) 96 | parser.add_argument("--selector-image-aspect-ratio", type=str, default="anyres_max_9") 97 | 98 | parser.add_argument("--sample_frames", type=int, default=64) 99 | parser.add_argument("--input_frames", type=int, default=1) 100 | parser.add_argument("--selector_method", type=str, default="topk_frames") 101 | 102 | parser.add_argument("--input-type", type=str, default="video") 103 | parser.add_argument("--input-path", type=str, default="") 104 | 105 | parser.add_argument("--query", type=str, default="") 106 | parser.add_argument("--pre-prompt", type=str, default=None) 107 | parser.add_argument("--post-prompt", type=str, default=None) 108 | 109 | args = parser.parse_args() 110 | 111 | os.makedirs(args.output_dir, exist_ok=True) 112 | args.do_sample = args.temperature > 0 113 | 114 | main(args) 115 | 116 | -------------------------------------------------------------------------------- /tasks/mpdocvqa.py: -------------------------------------------------------------------------------- 1 | # Adopted from https://github.com/rubenpt91/MP-DocVQA-Framework 2 | # Licensed under The MIT License. 3 | # See https://github.com/rubenpt91/MP-DocVQA-Framework/blob/master/LICENSE for details 4 | 5 | from .base import DocumentTask 6 | 7 | import editdistance 8 | import json 9 | 10 | 11 | class Evaluator: 12 | def __init__(self, case_sensitive=False): 13 | 14 | self.case_sensitive = case_sensitive 15 | self.get_edit_distance = editdistance.eval 16 | self.anls_threshold = 0.5 17 | 18 | self.total_accuracies = [] 19 | self.total_anls = [] 20 | 21 | self.best_accuracy = 0 22 | # self.best_anls = 0 23 | self.best_epoch = 0 24 | 25 | def get_metrics(self, gt_answers, preds, answer_types=None, update_global_metrics=True): 26 | answer_types = answer_types if answer_types is not None else ['string' for batch_idx in range(len(gt_answers))] 27 | batch_accuracy = [] 28 | batch_anls = [] 29 | for batch_idx in range(len(preds)): 30 | gt = [self._preprocess_str(gt_elm) for gt_elm in gt_answers[batch_idx]] 31 | pred = self._preprocess_str(preds[batch_idx]) 32 | 33 | batch_accuracy.append(self._calculate_accuracy(gt, pred, answer_types[batch_idx])) 34 | batch_anls.append(self._calculate_anls(gt, pred, answer_types[batch_idx])) 35 | 36 | # if accumulate_metrics: 37 | # self.total_accuracies.extend(batch_accuracy) 38 | # self.total_anls.extend(batch_anls) 39 | 40 | return {'accuracy': batch_accuracy, 'anls': batch_anls} 41 | 42 | def get_retrieval_metric(self, gt_answer_page, pred_answer_page): 43 | retrieval_precision = [1 if gt == pred else 0 for gt, pred in zip(gt_answer_page, pred_answer_page)] 44 | return retrieval_precision 45 | 46 | def update_global_metrics(self, accuracy, anls, current_epoch): 47 | if accuracy > self.best_accuracy: 48 | self.best_accuracy = accuracy 49 | self.best_epoch = current_epoch 50 | return True 51 | 52 | else: 53 | return False 54 | 55 | def _preprocess_str(self, string): 56 | if not self.case_sensitive: 57 | string = string.lower() 58 | 59 | return string.strip() 60 | 61 | def _calculate_accuracy(self, gt, pred, answer_type): 62 | 63 | if answer_type == 'not-answerable': 64 | return 1 if pred in ['', 'none', 'NA', None, []] else 0 65 | 66 | if pred == 'none' and answer_type != 'not-answerable': 67 | return 0 68 | 69 | for gt_elm in gt: 70 | if gt_elm == pred: 71 | return 1 72 | 73 | return 0 74 | 75 | def _calculate_anls(self, gt, pred, answer_type): 76 | if len(pred) == 0: 77 | return 0 78 | 79 | if answer_type == 
'not-answerable': 80 | return 1 if pred in ['', 'none', 'NA', None, []] else 0 81 | 82 | if pred == 'none' and answer_type != 'not-answerable': 83 | return 0 84 | 85 | answers_similarity = [1 - self.get_edit_distance(gt_elm, pred) / max(len(gt_elm), len(pred)) for gt_elm in gt] 86 | max_similarity = max(answers_similarity) 87 | 88 | anls = max_similarity if max_similarity >= self.anls_threshold else 0 89 | return anls 90 | 91 | 92 | class MPDocVQA(DocumentTask): 93 | def __init__(self, dataset, split, **kwargs): 94 | super().__init__(dataset, split, **kwargs) 95 | 96 | assert self.split in ['val', 'test', 'train'] 97 | 98 | def doc_to_visual_name(self, doc): 99 | return str(doc["id"]) # same doc_id could have different pages, so only same id gives same visual 100 | 101 | def aggregate_results(self, docs, out_root): 102 | if self.split == "val": 103 | evaluator = Evaluator(case_sensitive=False) 104 | 105 | total_accuracies = 0 106 | total_anls = 0 107 | 108 | for doc in docs: 109 | metric = evaluator.get_metrics([doc['answer']], [doc['pred'][0].strip()], doc.get('answer_type', None)) 110 | total_accuracies += sum(metric['accuracy']) 111 | total_anls += sum(metric['anls']) 112 | 113 | total_accuracies = total_accuracies/len(docs) 114 | total_anls = total_anls/len(docs) 115 | 116 | print(f"Acc: {total_accuracies*100}\nANLS: {total_anls*100}") 117 | out_file = out_root + '.log' 118 | with open(out_file, 'a') as file: 119 | file.write(f"Acc: {total_accuracies*100}\nANLS: {total_anls*100}\n") 120 | elif self.split == "test": 121 | out_all = [] 122 | for doc in docs: 123 | out = { 124 | "questionId": int(doc['id']), 125 | "answer": doc['pred'][0].strip(), 126 | "answer_page": "", 127 | } 128 | out_all.append(out) 129 | out_file = out_root + '_sub.json' 130 | with open(out_file, "w") as f: 131 | json.dump(out_all, f) 132 | 133 | -------------------------------------------------------------------------------- /models/llava_ov.py: -------------------------------------------------------------------------------- 1 | # Adopted from LLaVA-OneVision from https://github.com/LLaVA-VL/LLaVA-NeXT. Below is the original copyright: 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from llava.model.builder import load_pretrained_model 16 | from llava.mm_utils import get_model_name_from_path, process_images, tokenizer_image_token 17 | from llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN, IGNORE_INDEX 18 | from llava.conversation import conv_templates, SeparatorStyle 19 | 20 | import torch 21 | import copy 22 | import math 23 | 24 | 25 | def split_model(num_layers, gpu0_load=0.5): 26 | device_map = {} 27 | world_size = torch.cuda.device_count() 28 | 29 | # Since the first GPU will be used for ViT, treat it as half a GPU. 
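# e.g., num_layers=80, world_size=4, gpu0_load=0.5: ceil(80 / 3.5) = 23 layers per GPU, with GPU 0 capped at ceil(23 * 0.5) = 12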
30 | num_layers_per_gpu = math.ceil(num_layers / (world_size - (1 - gpu0_load))) 31 | num_layers_per_gpu = [num_layers_per_gpu] * world_size 32 | num_layers_per_gpu[0] = math.ceil(num_layers_per_gpu[0] * gpu0_load) 33 | layer_cnt = 0 34 | for i, num_layer in enumerate(num_layers_per_gpu): 35 | for j in range(num_layer): 36 | device_map[f'model.layers.{layer_cnt}'] = i 37 | layer_cnt += 1 38 | device_map['model.embed_tokens'] = 0 39 | device_map['model.norm'] = 0 40 | device_map['model.image_newline'] = 0 41 | device_map['model.vision_tower'] = 0 42 | device_map['model.vision_resampler'] = 0 43 | device_map['model.mm_projector'] = 0 44 | device_map['lm_head'] = 0 45 | device_map[f'model.layers.{num_layers - 1}'] = 0 46 | 47 | return device_map 48 | 49 | 50 | class LLaVAOneVision: 51 | def __init__(self, model_path, generation_args, image_aspect_ratio="anyres_max_9"): 52 | model_name = "llava_qwen" 53 | llava_model_args = { 54 | "multimodal": True, 55 | } 56 | overwrite_config = {} 57 | overwrite_config["image_aspect_ratio"] = image_aspect_ratio 58 | llava_model_args["overwrite_config"] = overwrite_config 59 | 60 | self.device = 'cuda' 61 | self.conv_template = "qwen_1_5" 62 | 63 | if "llava-onevision-qwen2-72b" in model_path: 64 | world_size = torch.cuda.device_count() 65 | if world_size < 4: 66 | gpu0_load = 0.5 67 | else: 68 | gpu0_load = 0.2 69 | 70 | device_map = split_model(80, gpu0_load=gpu0_load) 71 | else: 72 | device_map = 'auto' 73 | 74 | tokenizer, model, image_processor, max_length = load_pretrained_model(model_path, None, model_name, device_map=device_map, **llava_model_args) 75 | self.model = model 76 | self.tokenizer = tokenizer 77 | self.image_processor = image_processor 78 | self.generation_args = generation_args 79 | 80 | def run_model(self, images, message, output_scores=False, post_proc_func=None): 81 | message, prompt_assistant = message 82 | 83 | image_tensors = process_images(images, self.image_processor, self.model.config) 84 | image_tensors = [_image.to(dtype=torch.float16, device=self.device) for _image in image_tensors] 85 | 86 | image_sizes = [image.size for image in images] 87 | 88 | num_images = len(images) 89 | image_tokens = [DEFAULT_IMAGE_TOKEN] * num_images 90 | image_tokens = '\n'.join(image_tokens) 91 | 92 | question = f"{image_tokens}\n{message}" 93 | 94 | conv = copy.deepcopy(conv_templates[self.conv_template]) 95 | conv.append_message(conv.roles[0], question) 96 | conv.append_message(conv.roles[1], prompt_assistant) 97 | prompt_question = conv.get_prompt() 98 | 99 | input_ids = tokenizer_image_token(prompt_question, self.tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt").unsqueeze(0).to(self.device) 100 | 101 | # Generate response 102 | with torch.inference_mode(): 103 | generation_output = self.model.generate( 104 | input_ids, 105 | images=image_tensors, 106 | image_sizes=image_sizes, 107 | output_scores=output_scores, 108 | return_dict_in_generate=output_scores, 109 | **self.generation_args, 110 | ) 111 | 112 | if post_proc_func is not None: 113 | outputs = post_proc_func(generation_output) 114 | else: 115 | outputs = self.tokenizer.batch_decode(generation_output, skip_special_tokens=True)[0] 116 | return outputs 117 | -------------------------------------------------------------------------------- /tasks/utils.py: -------------------------------------------------------------------------------- 1 | # Adopted from lmms-eval from https://github.com/EvolvingLMMs-Lab/lmms-eval. 
Below is the original copyright: 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import random 16 | import string 17 | 18 | import numpy as np 19 | from decord import VideoReader, cpu 20 | from PIL import Image 21 | 22 | 23 | def load_video(visual_path, num_frames): 24 | vr = VideoReader(visual_path, ctx=cpu(0), num_threads=1) 25 | total_frames = len(vr) 26 | num_frames = min(num_frames, total_frames) 27 | if num_frames == 1: 28 | # get middle instead of first for single frame 29 | frame_indices = np.arange(0, total_frames, total_frames / 2).astype(int).tolist() 30 | frame_indices = frame_indices[1:] 31 | else: 32 | frame_indices = np.arange(0, total_frames, total_frames / num_frames).astype(int).tolist() 33 | 34 | frame_indices = [x for x in frame_indices if x < total_frames] 35 | 36 | frames = vr.get_batch(frame_indices).asnumpy() 37 | images = [Image.fromarray(frames[i]).convert('RGB') for i in range(len(frames))] 38 | 39 | return images, frame_indices 40 | 41 | 42 | def load_images(visual_paths, num_frames): 43 | assert isinstance(visual_paths, list) 44 | total_frames = len(visual_paths) 45 | if num_frames > 0: 46 | num_frames = min(num_frames, total_frames) 47 | frame_indices = np.arange(0, total_frames, total_frames / num_frames).astype(int).tolist() 48 | else: 49 | frame_indices = list(range(len(visual_paths))) 50 | frame_indices = [x for x in frame_indices if x < total_frames] 51 | images = [Image.open(visual_paths[i]).convert("RGB") for i in frame_indices] 52 | 53 | return images, frame_indices 54 | 55 | 56 | ################ Multi Choice ################ 57 | OPTIONS = string.ascii_uppercase 58 | 59 | 60 | def doc_to_text_mc(doc, model_specific_prompt_kwargs=None): 61 | if model_specific_prompt_kwargs is None: 62 | model_specific_prompt_kwargs = {} 63 | question = [doc["question"].strip()] 64 | options = doc["options"] 65 | for i, option in enumerate(options): 66 | question.append(f"{OPTIONS[i]}. {option.strip()}") 67 | question = "\n".join(question) 68 | if "pre_prompt" in model_specific_prompt_kwargs and model_specific_prompt_kwargs["pre_prompt"] != "": 69 | question = f"{model_specific_prompt_kwargs['pre_prompt']}{question}" 70 | if "post_prompt" in model_specific_prompt_kwargs and model_specific_prompt_kwargs["post_prompt"] != "": 71 | question = f"{question}{model_specific_prompt_kwargs['post_prompt']}" 72 | return question 73 | 74 | 75 | def mc_process_results(doc, results): 76 | pred = results 77 | index2ans, all_choices = get_multi_choice_info(doc) 78 | parsed_pred = parse_multi_choice_response(pred, all_choices, index2ans) 79 | return { 80 | "exact_match": parsed_pred == OPTIONS[doc["answer"]], 81 | } 82 | 83 | 84 | def parse_multi_choice_response(response, all_choices, index2ans): 85 | """ 86 | Parse the prediction from the generated response. 87 | Return the predicted index e.g., A, B, C, D. 
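Matching order: "(A)"-style answers, then bare letters, then "A."-style answers, then option text; falls back to a random choice if nothing matches.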
88 | https://github.com/MMMU-Benchmark/MMMU/blob/51ce7f3e829c16bb44bc5445782686b4c3508794/eval/eval_utils.py#L10 89 | """ 90 | for char in [",", ".", "!", "?", ";", ":", "'"]: 91 | response = response.strip(char) 92 | response = " " + response + " " # add space to avoid partial match 93 | 94 | index_ans = True 95 | ans_with_brack = False 96 | candidates = [] 97 | for choice in all_choices: # e.g., (A) (B) (C) (D) 98 | if f"({choice})" in response: 99 | candidates.append(choice) 100 | ans_with_brack = True 101 | 102 | if len(candidates) == 0: 103 | for choice in all_choices: # e.g., A B C D 104 | if f"{choice} " in response: 105 | candidates.append(choice) 106 | 107 | if len(candidates) == 0: 108 | for choice in all_choices: # e.g., A. B. C. D. 109 | if f"{choice}." in response: 110 | candidates.append(choice) 111 | 112 | # if all above doesn't get candidates, check if the content is larger than 5 tokens and try to parse the example 113 | if len(candidates) == 0 and len(response.split()) > 5: 114 | for index, ans in index2ans.items(): 115 | if ans.lower() in response.lower(): 116 | candidates.append(index) 117 | index_ans = False # it's content ans. 118 | 119 | if len(candidates) == 0: # still not get answer, randomly choose one. 120 | pred_index = random.choice(all_choices) 121 | elif len(candidates) > 1: 122 | start_indexes = [] 123 | if index_ans: 124 | if ans_with_brack: 125 | for can in candidates: 126 | index = response.rfind(f"({can})") 127 | start_indexes.append(index) # -1 will be ignored anyway 128 | # start_indexes = [generated_response.index(f'({can})') for can in candidates] 129 | else: 130 | for can in candidates: 131 | index = response.rfind(f" {can} ") 132 | start_indexes.append(index) 133 | else: 134 | for can in candidates: 135 | index = response.lower().rfind(index2ans[can].lower()) 136 | start_indexes.append(index) 137 | # get the last one 138 | pred_index = candidates[np.argmax(start_indexes)] 139 | else: # if only one candidate, use it. 140 | pred_index = candidates[0] 141 | 142 | return pred_index 143 | 144 | 145 | def get_multi_choice_info(doc): 146 | all_choices = [] 147 | index2ans = {} 148 | options = doc["options"] 149 | for i, option in enumerate(options): 150 | index2ans[OPTIONS[i]] = option.strip() 151 | all_choices.append(OPTIONS[i]) 152 | 153 | return index2ans, all_choices 154 | ################ Multi Choice ################ -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # FRAG: Frame Selection Augmented Generation 2 | 3 | [De-An Huang](https://ai.stanford.edu/~dahuang/), [Subhashree Radhakrishnan](), [Zhiding Yu](https://chrisding.github.io/), [Jan Kautz](https://jankautz.com/) 4 | 5 | [[`arXiv`](https://arxiv.org/abs/2504.17447)] [[`Project`]()] [[`BibTeX`](#Citation)] 6 | 7 | 8 | ## Contents 9 | - [Install](#install) 10 | - [Inference](#inference) 11 | - [Evaluation](#evaluation) 12 | 13 | 14 | ## Install 15 | 16 | The core of FRAG is zero-shot and has minimal dependencies. Follow the instructions below to install the base models and benchmarks. 17 | 18 | ### Data Loading 19 | ```Shell 20 | pip install decord 21 | ``` 22 | 23 | ### Models 24 | 25 | 1. **LLaVA-OneVision**: Follow the instructions [here](https://github.com/LLaVA-VL/LLaVA-NeXT) to install LLaVA-OneVision. Please make sure that you can run the examples [here](https://huggingface.co/lmms-lab/llava-onevision-qwen2-7b-ov#generation). 26 | 27 | 2. 
**InternVL-2**: If you have already installed LLaVA-OneVision, the dependencies should already work for InternVL-2. If you only want to use InternVL-2, follow the instructions [here](https://huggingface.co/OpenGVLab/InternVL2-8B#quick-start). Please make sure that you can run the examples [here](https://huggingface.co/OpenGVLab/InternVL2-8B#inference-with-transformers). 28 | 29 | 3. **Qwen2-VL**: Follow the instructions [here](https://huggingface.co/Qwen/Qwen2-VL-7B-Instruct#quickstart). Please make sure that you can run the multi image inference example. We recommend using `transformers==4.45.0` for Qwen2-VL. 30 | 31 | Note that you only need to install the models you would like to use. If you want to quickly try out FRAG, we recommend starting with InternVL-2 first, as it has fewer dependencies. 32 | 33 | 34 | ### Benchmarks 35 | 36 | 1. **MP-DocVQA**: 37 | ```Shell 38 | pip install editdistance 39 | ``` 40 | 41 | 2. **MMLongBench-Doc**: 42 | ```Shell 43 | pip install openai 44 | ``` 45 | 46 | ## Inference 47 | 48 | ### Video 49 | 50 | Video inference example: 51 | ```Shell 52 | python run_model.py \ 53 | --output-dir . \ 54 | --model "internvl" \ 55 | --model-path "OpenGVLab/InternVL2-8B" \ 56 | --image-aspect-ratio "1" \ 57 | --selector-model "internvl" \ 58 | --selector-model-path "OpenGVLab/InternVL2-8B" \ 59 | --selector-image-aspect-ratio "12" \ 60 | --sample_frames 64 \ 61 | --input_frames 8 \ 62 | --selector_method "topk_frames" \ 63 | --input-type "video" \ 64 | --input-path \ 65 | --query "Are there any fish in the video?" 66 | ``` 67 | This uses InternVL2-8B for both answering and scoring. Scoring uses a maximum of 12 tiles for dynamic resolution, while answering disables dynamic resolution (`--image-aspect-ratio "1"`). The example first uniformly samples 64 frames, and selects the top 8 frames for answering (FRAG-64-Top8). We use `--sample_frames 256` in the paper. For longer or shorter videos, around 1 fps for `--sample_frames` is a good starting point. The example also generates `.json`, which contains the inputs and outputs of the model. 68 | 69 | 70 | ### Document 71 | 72 | Document inference example: 73 | ```Shell 74 | python run_model.py \ 75 | --output-dir . \ 76 | --model "internvl" \ 77 | --model-path "OpenGVLab/InternVL2-8B" \ 78 | --image-aspect-ratio "16" \ 79 | --selector-model "internvl" \ 80 | --selector-model-path "OpenGVLab/InternVL2-8B" \ 81 | --selector-image-aspect-ratio "12" \ 82 | --sample_frames -1 \ 83 | --input_frames 1 \ 84 | --selector_method "topk_frames" \ 85 | --input-type "images" \ 86 | --input-path \ 87 | --query "What is the title of the paper?" 88 | ``` 89 | The main difference from the video example is `--input-type "images"`, which suggests that `--input-path` points to a folder containing images (from pages of a document). Our data loading function assumes that the image file names are sorted by the page order. Other differences include: `--image-aspect-ratio "16"` to use higher resolution for answering, `--sample_frames -1` to sample all the pages, and `--input_frames 1` to only select the Top-1 page for answering. 90 | 91 | 92 | ## Evaluation 93 | 94 | We provide example scripts for benchmark evaluation using InternVL2-8B. 95 | 96 | ### Video 97 | 98 | 0. Update Paths 99 | 100 | Update the dataset and output paths in `scripts/video/path.sh`. JSON files pointed by `$doc_path` can be downloaded [here](https://huggingface.co/datasets/deahuang/FRAG-Datasets). 
Follow the official download instructions for each dataset, and `$visual_path` should point to the root folder for videos. 101 | 102 | 103 | 1. Precompute FRAG Scores 104 | 105 | ```Shell 106 | bash scripts/video/annot_scores_internvl-8b_frames.sh $dataset $num_frames $CHUNKS $IDX 107 | ``` 108 | `$dataset` is the dataset name to evaluate. `$num_frames` is the number of frames to uniformly sample from the video for FRAG scoring. `$CHUNKS` and `$IDX` split the samples in the dataset into `$CHUNKS` chunks and compute scores only for the `$IDX` chunk. For example, to evaluate LongVideoBench with 256 sampled frames (as in the paper) with a single job: 109 | ```Shell 110 | bash scripts/video/annot_scores_internvl-8b_frames.sh lvb 256 1 0 111 | ``` 112 | Here, there is only 1 chunk, and the only `$IDX` is 0. Set `$CHUNKS` to `N` and `$IDX` in `[0, N)` to run `N` jobs for score computation. 113 | 114 | 2. Collect FRAG Scores 115 | 116 | The previous step computes FRAG scores for the videos in the dataset and saves them in separate files for easier parallelization. Now we collect all the FRAG scores into a single JSON file. Following the previous example, collect the FRAG scores using: 117 | ```Shell 118 | python collect_results.py \ 119 | --doc-path $root/datasets/LongVideoBench/lvb_val_doc_list.json \ 120 | --result-path $output_root/lvb/val/annot_scores_internvl-8b_frames_256 121 | ``` 122 | `$root` and `$output_root` are specified in `scripts/video/path.sh` in step 0. This should generate `$output_root/lvb/val/annot_scores_internvl-8b_frames_256.json`, which will be used in the next step. 123 | 124 | 3. Evaluate FRAG 125 | 126 | ```Shell 127 | bash scripts/video/eval_internvl-8b-max1_top32_frames.sh $dataset $num_frames 128 | ``` 129 | This script evaluates FRAG-Top32-N, where N is `$num_frames`. For LongVideoBench and 256 sampled frames: 130 | ```Shell 131 | bash scripts/video/eval_internvl-8b-max1_top32_frames.sh lvb 256 132 | ``` 133 | 134 | 135 | ### Document 136 | 137 | We use SlideVQA and InternVL2-8B as an example. The scripts are similar to the ones for videos. 138 | 139 | 140 | 0. Update Paths 141 | 142 | Update the dataset and output paths in `scripts/document/path.sh`. 143 | 144 | 1. Precompute FRAG Scores 145 | 146 | ```Shell 147 | bash scripts/document/annot_scores_internvl-8b_frames.sh slidevqa -1 1 0 148 | ``` 149 | The arguments are the same as in video step 1. Here, -1 means that all the pages are used, rather than uniformly sampling a subset. 150 | 151 | 2. Collect FRAG Scores 152 | 153 | ```Shell 154 | python collect_results.py \ 155 | --doc-path ${root}/datasets/SlideVQA/test_doc.json \ 156 | --result-path $output_root/slidevqa/test/annot_scores_internvl-8b_frames_-1 157 | ``` 158 | This should generate `$output_root/slidevqa/test/annot_scores_internvl-8b_frames_-1.json`, which will be used in the next step. 159 | 160 | 3. Evaluate FRAG 161 | 162 | ```Shell 163 | bash scripts/document/eval_internvl-8b-max16_topk_frames.sh $dataset $num_frames 164 | ``` 165 | This script evaluates FRAG by selecting the top `$num_frames` frames. Here, `$num_frames` is K rather than N in FRAG-TopK-N, because for documents we go through all the pages. For SlideVQA and Top-2 frames: 166 | ```Shell 167 | bash scripts/document/eval_internvl-8b-max16_topk_frames.sh slidevqa 2 168 | ``` 169 | 170 | ## License 171 | 172 | Copyright © 2025, NVIDIA Corporation. All rights reserved. 173 | 174 | This work is made available under the Nvidia Source Code License-NC.
Click [here](LICENSE) to view a copy of this license. 175 | 176 | For business inquiries, please visit our website and submit the form: [NVIDIA Research Licensing](https://www.nvidia.com/en-us/research/inquiries/). 177 | 178 | 179 | ## Citation 180 | 181 | If you find FRAG useful for your research and applications, please cite using this BibTeX: 182 | ```bibtex 183 | @article{huang2025frag, 184 | title={FRAG: Frame Selection Augmented Generation for Long Video and Long Document Understanding}, 185 | author={De-An Huang and Subhashree Radhakrishnan and Zhiding Yu and Jan Kautz}, 186 | journal={arXiv preprint arXiv:2504.17447}, 187 | year={2025} 188 | } 189 | ``` 190 | 191 | 192 | 193 | 194 | 195 | 196 | -------------------------------------------------------------------------------- /models/internvl/internvl.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # InternVL 3 | # Copyright (c) 2024 OpenGVLab 4 | # Licensed under The MIT License [see https://github.com/OpenGVLab/InternVL/blob/main/LICENSE for details] 5 | # -------------------------------------------------------- 6 | 7 | import os 8 | import numpy as np 9 | import math 10 | import torch 11 | import torchvision.transforms as T 12 | from torchvision.transforms.functional import InterpolationMode 13 | from transformers import AutoModel, AutoTokenizer 14 | from .conversation import get_conv_template 15 | 16 | 17 | IMAGENET_MEAN = (0.485, 0.456, 0.406) 18 | IMAGENET_STD = (0.229, 0.224, 0.225) 19 | 20 | IMG_START_TOKEN = '' 21 | IMG_END_TOKEN = '' 22 | IMG_CONTEXT_TOKEN = '' 23 | 24 | 25 | def build_transform(input_size): 26 | MEAN, STD = IMAGENET_MEAN, IMAGENET_STD 27 | transform = T.Compose([ 28 | T.Lambda(lambda img: img.convert('RGB') if img.mode != 'RGB' else img), 29 | T.Resize((input_size, input_size), interpolation=InterpolationMode.BICUBIC), 30 | T.ToTensor(), 31 | T.Normalize(mean=MEAN, std=STD) 32 | ]) 33 | return transform 34 | 35 | def find_closest_aspect_ratio(aspect_ratio, target_ratios, width, height, image_size): 36 | best_ratio_diff = float('inf') 37 | best_ratio = (1, 1) 38 | area = width * height 39 | for ratio in target_ratios: 40 | target_aspect_ratio = ratio[0] / ratio[1] 41 | ratio_diff = abs(aspect_ratio - target_aspect_ratio) 42 | if ratio_diff < best_ratio_diff: 43 | best_ratio_diff = ratio_diff 44 | best_ratio = ratio 45 | elif ratio_diff == best_ratio_diff: 46 | if area > 0.5 * image_size * image_size * ratio[0] * ratio[1]: 47 | best_ratio = ratio 48 | return best_ratio 49 | 50 | def dynamic_preprocess(image, min_num=1, max_num=12, image_size=448, use_thumbnail=False): 51 | orig_width, orig_height = image.size 52 | aspect_ratio = orig_width / orig_height 53 | 54 | # calculate the existing image aspect ratio 55 | target_ratios = set( 56 | (i, j) for n in range(min_num, max_num + 1) for i in range(1, n + 1) for j in range(1, n + 1) if 57 | i * j <= max_num and i * j >= min_num) 58 | target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1]) 59 | 60 | # find the closest aspect ratio to the target 61 | target_aspect_ratio = find_closest_aspect_ratio( 62 | aspect_ratio, target_ratios, orig_width, orig_height, image_size) 63 | 64 | # calculate the target width and height 65 | target_width = image_size * target_aspect_ratio[0] 66 | target_height = image_size * target_aspect_ratio[1] 67 | blocks = target_aspect_ratio[0] * target_aspect_ratio[1] 68 | 69 | # resize the image 70 | resized_img = 
image.resize((target_width, target_height)) 71 | processed_images = [] 72 | for i in range(blocks): 73 | box = ( 74 | (i % (target_width // image_size)) * image_size, 75 | (i // (target_width // image_size)) * image_size, 76 | ((i % (target_width // image_size)) + 1) * image_size, 77 | ((i // (target_width // image_size)) + 1) * image_size 78 | ) 79 | # split the image 80 | split_img = resized_img.crop(box) 81 | processed_images.append(split_img) 82 | assert len(processed_images) == blocks 83 | if use_thumbnail and len(processed_images) != 1: 84 | thumbnail_img = image.resize((image_size, image_size)) 85 | processed_images.append(thumbnail_img) 86 | return processed_images 87 | 88 | def process_image(image, input_size=448, max_num=12): 89 | transform = build_transform(input_size=input_size) 90 | images = dynamic_preprocess(image, image_size=input_size, use_thumbnail=True, max_num=max_num) 91 | pixel_values = [transform(image) for image in images] 92 | pixel_values = torch.stack(pixel_values) 93 | return pixel_values 94 | 95 | def split_model(model_name, gpu0_load=0.5): 96 | device_map = {} 97 | world_size = torch.cuda.device_count() 98 | num_layers = { 99 | 'InternVL2-1B': 24, 'InternVL2-2B': 24, 'InternVL2-4B': 32, 'InternVL2-8B': 32, 100 | 'InternVL2-26B': 48, 'InternVL2-40B': 60, 'InternVL2-Llama3-76B': 80}[model_name] 101 | # Since the first GPU will be used for ViT, treat it as half a GPU. 102 | num_layers_per_gpu = math.ceil(num_layers / (world_size - (1 - gpu0_load))) 103 | num_layers_per_gpu = [num_layers_per_gpu] * world_size 104 | num_layers_per_gpu[0] = math.ceil(num_layers_per_gpu[0] * gpu0_load) 105 | layer_cnt = 0 106 | for i, num_layer in enumerate(num_layers_per_gpu): 107 | for j in range(num_layer): 108 | device_map[f'language_model.model.layers.{layer_cnt}'] = i 109 | layer_cnt += 1 110 | device_map['vision_model'] = 0 111 | device_map['mlp1'] = 0 112 | device_map['language_model.model.tok_embeddings'] = 0 113 | device_map['language_model.model.embed_tokens'] = 0 114 | device_map['language_model.output'] = 0 115 | device_map['language_model.model.norm'] = 0 116 | device_map['language_model.lm_head'] = 0 117 | 118 | return device_map 119 | 120 | 121 | class InternVL: 122 | def __init__(self, model_path, generation_args, image_aspect_ratio=12): 123 | model_name = os.path.basename(model_path) 124 | max_num = int(image_aspect_ratio) 125 | 126 | # reduce gpu0 load for 76B, 8 gpus 127 | if torch.cuda.device_count() == 8 and model_name == 'InternVL2-Llama3-76B': 128 | gpu0_load = 0.0 129 | elif torch.cuda.device_count() == 8 and model_name == 'InternVL2-40B' and max_num < 6: 130 | gpu0_load = 0.1 131 | else: 132 | gpu0_load = 0.5 133 | 134 | device_map = split_model(model_name, gpu0_load=gpu0_load) 135 | self.model = AutoModel.from_pretrained( 136 | model_path, 137 | torch_dtype=torch.bfloat16, 138 | low_cpu_mem_usage=True, 139 | trust_remote_code=True, 140 | device_map=device_map) 141 | self.tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True, use_fast=False) 142 | self.generation_args = generation_args 143 | 144 | self.max_num = max_num 145 | img_context_token_id = self.tokenizer.convert_tokens_to_ids(IMG_CONTEXT_TOKEN) 146 | self.model.img_context_token_id = img_context_token_id 147 | 148 | def run_model(self, images, message, output_scores=False, post_proc_func=None): 149 | message, prompt_assistant = message 150 | 151 | pixel_values = [process_image(image, max_num=self.max_num) for image in images] 152 | num_patches_list = [x.size(0) for x in 
pixel_values] 153 | pixel_values = torch.cat(pixel_values, dim=0).to(torch.bfloat16).cuda() 154 | 155 | num_images = len(images) 156 | image_tokens = [] 157 | for i in range(num_images): 158 | image_tokens.append(f"Image-{i+1}: ") 159 | image_tokens = '\n'.join(image_tokens) 160 | question = f"{image_tokens}\n{message}" 161 | 162 | # from InternVLChatModel.chat 163 | assert pixel_values is None or len(pixel_values) == sum(num_patches_list) 164 | 165 | template = get_conv_template(self.model.template) 166 | template.system_message = self.model.system_message 167 | eos_token_id = self.tokenizer.convert_tokens_to_ids(template.sep) 168 | 169 | template.append_message(template.roles[0], question) 170 | template.append_message(template.roles[1], prompt_assistant) 171 | query = template.get_prompt() 172 | 173 | for num_patches in num_patches_list: 174 | image_tokens = IMG_START_TOKEN + IMG_CONTEXT_TOKEN * self.model.num_image_token * num_patches + IMG_END_TOKEN 175 | query = query.replace('', image_tokens, 1) 176 | 177 | model_inputs = self.tokenizer(query, return_tensors='pt') 178 | input_ids = model_inputs['input_ids'].cuda() 179 | attention_mask = model_inputs['attention_mask'].cuda() 180 | 181 | generation_config = self.generation_args.copy() 182 | generation_config['eos_token_id'] = eos_token_id 183 | generation_config['output_scores'] = output_scores 184 | generation_config['return_dict_in_generate'] = output_scores 185 | 186 | with torch.inference_mode(): 187 | generation_output = self.model.generate( 188 | pixel_values=pixel_values, 189 | input_ids=input_ids, 190 | attention_mask=attention_mask, 191 | **generation_config 192 | ) 193 | 194 | if post_proc_func is not None: 195 | outputs = post_proc_func(generation_output) 196 | else: 197 | outputs = self.tokenizer.batch_decode(generation_output, skip_special_tokens=True)[0] 198 | outputs = outputs.split(template.sep)[0].strip() 199 | 200 | return outputs 201 | 202 | 203 | -------------------------------------------------------------------------------- /eval_model.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2025, NVIDIA Corporation & Affiliates. All rights reserved. 2 | # 3 | # This work is made available under the Nvidia Source Code License-NC. 
4 | # To view a copy of this license, visit 5 | # https://github.com/NVlabs/FRAG/blob/main/LICENSE 6 | 7 | import argparse 8 | import os 9 | import json 10 | from tqdm import tqdm 11 | from collections import defaultdict 12 | import numpy as np 13 | import copy 14 | 15 | import torch 16 | from torch.utils.data import Dataset, DataLoader 17 | 18 | from PIL import Image 19 | from decord import VideoReader, cpu 20 | 21 | from utils import split_list, get_chunk 22 | from models.builder import build_model 23 | from frame_selection import FrameSelection 24 | 25 | from tasks.builder import build_task 26 | 27 | 28 | # Custom dataset class 29 | class CustomDataset(Dataset): 30 | def __init__(self, task, video_docs, visual_folder, num_frames): 31 | self.task = task 32 | self.video_docs = video_docs 33 | self.video_names = list(video_docs.keys()) 34 | self.visual_folder = visual_folder 35 | self.num_frames = num_frames 36 | 37 | def __getitem__(self, index): 38 | video_name = self.video_names[index] 39 | 40 | # load video 41 | video_path = os.path.join(self.visual_folder, video_name) 42 | doc = self.video_docs[video_name][0] 43 | if 'images' in doc: 44 | video_path = [os.path.join(self.visual_folder, x) for x in doc["images"]] 45 | images, frame_indices = self.task.load_visual(video_path, self.num_frames) 46 | 47 | docs = self.video_docs[video_name] 48 | prompts = [] 49 | for doc in docs: 50 | prompt = self.task.doc_to_prompt(doc) 51 | prompts.append(prompt) 52 | 53 | return docs, prompts, images, frame_indices 54 | 55 | def __len__(self): 56 | return len(self.video_names) 57 | 58 | 59 | def collate_fn(batch): 60 | docs, prompts, images, frame_indices = zip(*batch) 61 | docs = list(docs) 62 | prompts = list(prompts) 63 | images = list(images) 64 | frame_indices = list(frame_indices) 65 | return docs, prompts, images, frame_indices 66 | 67 | 68 | # DataLoader 69 | def create_data_loader(task, video_docs, visual_folder, num_frames, batch_size=1, num_workers=4): 70 | assert batch_size == 1, "batch_size must be 1" 71 | dataset = CustomDataset(task, video_docs, visual_folder, num_frames) 72 | data_loader = DataLoader(dataset, batch_size=batch_size, num_workers=num_workers, shuffle=False, collate_fn=collate_fn) 73 | return data_loader 74 | 75 | 76 | def eval_model(args): 77 | # task 78 | task = build_task(args.dataset, args.split, 79 | subtitles=args.subtitles, 80 | visual_folder=args.visual_folder) 81 | 82 | generation_args = {"max_new_tokens": args.max_new_tokens, 83 | "temperature": args.temperature, 84 | "do_sample": args.do_sample} 85 | 86 | # answering LMM 87 | model_path = os.path.expanduser(args.model_path) 88 | model = build_model(args.model, 89 | model_path, 90 | generation_args, 91 | image_aspect_ratio=args.image_aspect_ratio) 92 | 93 | # scroing LMM 94 | if args.selector_model is None: 95 | selector_model = model 96 | else: 97 | selector_model_path = os.path.expanduser(args.selector_model_path) 98 | selector_model = build_model(args.selector_model, 99 | args.selector_model_path, 100 | generation_args, 101 | image_aspect_ratio=args.selector_image_aspect_ratio) 102 | 103 | selector = FrameSelection(selector_model, args.selector_method, args.input_frames, args.sample_frames, 104 | score_docs=args.score_docs, sub_vid_cache=args.sub_vid_cache) 105 | 106 | docs = json.load(open(args.doc_path, "r")) 107 | 108 | # set id if doesn't exist 109 | for i, doc in enumerate(docs): 110 | if 'id' not in doc: 111 | doc["id"] = "%06d" % i 112 | 113 | # split by video 114 | video_docs = defaultdict(list) 115 | for 
doc in docs: 116 | video_name = task.doc_to_visual_name(doc) 117 | video_docs[video_name].append(doc) 118 | 119 | # chunk by videos 120 | video_names = sorted(list(video_docs.keys())) 121 | video_names = get_chunk(video_names, args.num_chunks, args.chunk_idx) 122 | video_docs_chunk = {} 123 | for video_name in video_names: 124 | video_docs_chunk[video_name] = video_docs[video_name] 125 | video_docs = video_docs_chunk 126 | 127 | # remove ones that are already done 128 | remaining_docs = defaultdict(list) 129 | for video_name in video_docs.keys(): 130 | vid_docs = video_docs[video_name] 131 | for doc in vid_docs: 132 | out_name = os.path.join(args.output_dir, "%s.json" % doc["id"]) 133 | if not os.path.exists(out_name): 134 | remaining_docs[video_name].append(doc) 135 | video_docs = remaining_docs 136 | 137 | data_loader = create_data_loader(task, video_docs, args.visual_folder, args.sample_frames) 138 | for docs, prompts, images, frame_indices in tqdm(data_loader, total=len(data_loader)): 139 | docs = docs[0] 140 | prompts = prompts[0] 141 | images = images[0] 142 | frame_indices = frame_indices[0] 143 | for doc, prompt in zip(docs, prompts): 144 | out_name = os.path.join(args.output_dir, "%s.json" % doc["id"]) 145 | if os.path.exists(out_name): 146 | continue 147 | 148 | if args.annot_scores: 149 | scores, proposal_indices = selector.annotate_scores(doc, images, frame_indices) 150 | 151 | if scores is not None: 152 | out = copy.deepcopy(doc) 153 | out["frames"] = proposal_indices 154 | out["scores"] = [float(x) for x in scores] 155 | assert len(out["frames"]) == len(out["scores"]) 156 | 157 | with open(out_name, "w") as f: 158 | json.dump(out, f) 159 | else: 160 | selected_images_list, selected_indices_list = selector.select_frames(doc, images, frame_indices) 161 | 162 | results = [] 163 | for selected_images, selected_indices in zip(selected_images_list, selected_indices_list): 164 | outputs = model.run_model(selected_images, prompt) 165 | results.append(task.process_results(doc, outputs)) 166 | 167 | out = copy.deepcopy(doc) 168 | out["frames"] = selected_indices_list 169 | out[task.result_key] = results 170 | assert len(out["frames"]) == len(out[task.result_key]) 171 | 172 | with open(out_name, "w") as f: 173 | json.dump(out, f) 174 | 175 | 176 | if __name__ == "__main__": 177 | parser = argparse.ArgumentParser() 178 | parser.add_argument("--doc-path", type=str, default="facebook/opt-350m") 179 | parser.add_argument("--visual-folder", type=str, default="") 180 | parser.add_argument("--num-chunks", type=int, default=1) 181 | parser.add_argument("--chunk-idx", type=int, default=0) 182 | parser.add_argument("--output-dir", type=str, default="outputs") 183 | 184 | parser.add_argument("--max_new_tokens", type=int, default=128) 185 | parser.add_argument("--temperature", type=float, default=0.0) 186 | 187 | parser.add_argument("--model", type=str, default="llava_ov") 188 | parser.add_argument("--model-path", type=str, default="lmms-lab/llava-onevision-qwen2-7b-ov") 189 | parser.add_argument("--image-aspect-ratio", type=str, default="anyres_max_9") 190 | parser.add_argument("--selector-model", type=str, default=None) 191 | parser.add_argument("--selector-model-path", type=str, default=None) 192 | parser.add_argument("--selector-image-aspect-ratio", type=str, default="anyres_max_9") 193 | 194 | # selection 195 | parser.add_argument("--sample_frames", type=int, default=64) 196 | parser.add_argument("--input_frames", type=int, default=1) 197 | parser.add_argument("--selector_method", type=str, 
default="topk") 198 | parser.add_argument("--score-docs", type=str, default=None) 199 | 200 | # resume 201 | parser.add_argument("--main-process", action='store_true') 202 | parser.add_argument("--sub-vid-cache", type=str, default=None) 203 | 204 | # task 205 | parser.add_argument("--dataset", type=str, default=None) 206 | parser.add_argument("--split", type=str, default=None) 207 | parser.add_argument("--subtitles", action='store_true') 208 | 209 | args = parser.parse_args() 210 | 211 | os.makedirs(args.output_dir, exist_ok=True) 212 | if args.sub_vid_cache is not None: 213 | os.makedirs(args.sub_vid_cache, exist_ok=True) 214 | args.do_sample = args.temperature > 0 215 | 216 | args.annot_scores = args.selector_method.startswith('annot_scores') 217 | 218 | eval_model(args) 219 | 220 | -------------------------------------------------------------------------------- /tasks/videomme.py: -------------------------------------------------------------------------------- 1 | from .base import VideoTask 2 | from .utils import doc_to_text_mc, OPTIONS 3 | 4 | import os 5 | import re 6 | import numpy as np 7 | from decord import VideoReader, cpu 8 | 9 | import logging 10 | 11 | 12 | VIDEO_TYPE = ["short", "medium", "long"] 13 | CATEGORIES = ["Knowledge", "Film & Television", "Sports Competition", "Artistic Performance", "Life Record", "Multilingual"] 14 | 15 | SUB_CATEGORIES = [ 16 | "Humanity & History", 17 | "Literature & Art", 18 | "Biology & Medicine", 19 | "Finance & Commerce", 20 | "Astronomy", 21 | "Geography", 22 | "Law", 23 | "Life Tip", 24 | "Technology", 25 | "Animation", 26 | "Movie & TV Show", 27 | "Documentary", 28 | "News Report", 29 | "Esports", 30 | "Basketball", 31 | "Football", 32 | "Athletics", 33 | "Other Sports", 34 | "Stage Play", 35 | "Magic Show", 36 | "Variety Show", 37 | "Acrobatics", 38 | "Handicraft", 39 | "Food", 40 | "Fashion", 41 | "Daily Life", 42 | "Travel", 43 | "Pet & Animal", 44 | "Exercise", 45 | "Multilingual", 46 | ] 47 | 48 | TASK_CATEGORIES = [ 49 | "Temporal Perception", 50 | "Spatial Perception", 51 | "Attribute Perception", 52 | "Action Recognition", 53 | "Object Recognition", 54 | "OCR Problems", 55 | "Counting Problem", 56 | "Temporal Reasoning", 57 | "Spatial Reasoning", 58 | "Action Reasoning", 59 | "Object Reasoning", 60 | "Information Synopsis", 61 | ] 62 | 63 | 64 | def parse_subtitle_time(time_str): 65 | h, m, s_ms = time_str.split(":") 66 | s, ms = s_ms.split(",") 67 | return int(h) * 3600 + int(m) * 60 + int(s) + int(ms) / 1000 68 | 69 | 70 | def load_subtitles(subtitle_path): 71 | subtitles = {} 72 | with open(subtitle_path, "r", encoding="utf-8") as file: 73 | content = file.read().split("\n\n") 74 | for section in content: 75 | if section.strip(): 76 | lines = section.split("\n") 77 | if len(lines) >= 3: 78 | time_range = lines[1].split(" --> ") 79 | start_time = parse_subtitle_time(time_range[0]) 80 | end_time = parse_subtitle_time(time_range[1]) 81 | text = " ".join(line for line in lines[2:]) 82 | subtitles[(start_time, end_time)] = text 83 | return subtitles 84 | 85 | 86 | def convert_time_to_frame(time_in_seconds, fps): 87 | return int(time_in_seconds * fps) 88 | 89 | 90 | def extract_subtitles(video_path, subtitle_path): 91 | vr = VideoReader(video_path, ctx=cpu(0), num_threads=1) 92 | total_frame = len(vr) 93 | fps = vr.get_avg_fps() 94 | subtitles = load_subtitles(subtitle_path) 95 | 96 | subtitle_frames = [] 97 | for (start_time, end_time), text in subtitles.items(): 98 | start_frame = convert_time_to_frame(start_time, fps) 99 | 
end_frame = convert_time_to_frame(end_time, fps) 100 | subtitle_frames.append((start_frame, end_frame, text)) 101 | 102 | return subtitle_frames, total_frame 103 | 104 | 105 | def extract_characters_regex(s): 106 | s = s.strip() 107 | answer_prefixes = [ 108 | "The best answer is", 109 | "The correct answer is", 110 | "The answer is", 111 | "The answer", 112 | "The best option is" "The correct option is", 113 | "Best answer:" "Best option:", 114 | ] 115 | for answer_prefix in answer_prefixes: 116 | s = s.replace(answer_prefix, "") 117 | 118 | if len(s.split()) > 10 and not re.search("[ABCD]", s): 119 | return "" 120 | 121 | matches = re.search(r"[ABCD]", s) 122 | if matches is None: 123 | return "" 124 | return matches[0] 125 | 126 | 127 | class VideoMME(VideoTask): 128 | def __init__(self, dataset, split, subtitles=False, visual_folder=None, **kwargs): 129 | super().__init__(dataset, split, **kwargs) 130 | 131 | self.post_prompt = "\nThe best answer is:" 132 | 133 | assert self.split in ['test'] 134 | 135 | self.subtitles = subtitles 136 | self.visual_folder = visual_folder 137 | 138 | def doc_to_prompt(self, doc): 139 | prompt_user = doc_to_text_mc(doc, {"post_prompt": self.post_prompt}) 140 | prompt_assistant = self.prompt_assistant 141 | 142 | option_prompt = "Select the best answer to the following multiple-choice question based on the video and the subtitles. Respond with only the letter (A, B, C, or D) of the correct option." 143 | prompt_user = option_prompt + '\n' + prompt_user 144 | 145 | if self.subtitles: 146 | video_path = os.path.join(self.visual_folder, doc["video"]) 147 | subtitle_folder = os.path.join(os.path.dirname(self.visual_folder), 'subtitle') 148 | video_name = os.path.splitext(os.path.basename(doc["video"]))[0] 149 | subtitle_path = os.path.join(subtitle_folder, video_name + '.srt') 150 | if os.path.exists(subtitle_path): 151 | subtitle = open(subtitle_path).readlines() 152 | 153 | frame_num = 32 # lmms-eval, videomme_w_subtitle.yaml 154 | subtitle_by_frame, total_frame = extract_subtitles(video_path, subtitle_path) 155 | uniform_sampled_frames = np.linspace(0, total_frame - 1, frame_num, dtype=int).tolist() 156 | 157 | subtitle_by_frame_idx = [] 158 | for frame_idx in uniform_sampled_frames: 159 | for idx, title in enumerate(subtitle_by_frame): 160 | if frame_idx < title[1] and frame_idx >= title[0]: 161 | subtitle_by_frame_idx.append(idx) 162 | subtitle_by_frame_idx = list(set(subtitle_by_frame_idx)) 163 | 164 | textlist = [] 165 | for idx in subtitle_by_frame_idx: 166 | pattern = r'(.*?)' 167 | raw_text = re.findall(pattern, subtitle_by_frame[idx][2]) 168 | try: 169 | textlist.append(raw_text[0]) 170 | except: 171 | continue 172 | subtitle_text = "\n".join(textlist) 173 | subtitle = subtitle_text 174 | else: 175 | subtitle = "No subtitles available" 176 | subtitles_prompt = "This video's subtitles are listed below:" 177 | 178 | prompt_user = subtitles_prompt + '\n' + subtitle + '\n' + prompt_user 179 | 180 | return (prompt_user, prompt_assistant) 181 | 182 | def aggregate_results(self, docs, out_root): 183 | out_file = out_root + '.log' 184 | logging.basicConfig(filename=out_file, 185 | level=logging.INFO, 186 | format='%(asctime)s - %(levelname)s - %(message)s') 187 | 188 | category2score = {} 189 | 190 | for video_type in VIDEO_TYPE: 191 | for category in CATEGORIES: 192 | for sub_category in SUB_CATEGORIES: 193 | for task_category in TASK_CATEGORIES: 194 | key = f"{video_type}_{category}_{sub_category}_{task_category}" 195 | category2score[key] = {"correct": 
0, "answered": 0} 196 | 197 | for doc in docs: 198 | video_type = doc["duration"] 199 | category = doc["domain"] 200 | sub_category = doc["sub_category"] 201 | task_category = doc["task_type"] 202 | key = f"{video_type}_{category}_{sub_category}_{task_category}" 203 | category2score[key]["answered"] += 1 204 | 205 | pred = doc["pred"][0] 206 | pred_ans = extract_characters_regex(pred) 207 | answer = OPTIONS[doc["answer"]] 208 | correct = int(pred_ans == answer) 209 | category2score[key]["correct"] += correct 210 | 211 | for video_type in VIDEO_TYPE: 212 | total_correct = 0 213 | total_answered = 0 214 | for k, v in category2score.items(): 215 | if video_type in k: 216 | total_correct += v["correct"] 217 | total_answered += v["answered"] 218 | logging.info(f"Evaluation on video Type: {video_type}: {100 * total_correct / total_answered if total_answered > 0 else 0 : .1f}%") 219 | 220 | for category in CATEGORIES: 221 | total_correct = 0 222 | total_answered = 0 223 | for k, v in category2score.items(): 224 | if category in k: 225 | total_correct += v["correct"] 226 | total_answered += v["answered"] 227 | logging.info(f"Evaluation on Categories: {category}: {100 * total_correct / total_answered if total_answered > 0 else 0 : .1f}%") 228 | 229 | for sub_cate in SUB_CATEGORIES: 230 | total_correct = 0 231 | total_answered = 0 232 | for k, v in category2score.items(): 233 | if sub_cate in k: 234 | total_correct += v["correct"] 235 | total_answered += v["answered"] 236 | logging.info(f"Evaluation on Video Sub Categories: {sub_cate}: {100 * total_correct / total_answered if total_answered > 0 else 0 : .1f}%") 237 | 238 | for task_cate in TASK_CATEGORIES: 239 | total_correct = 0 240 | total_answered = 0 241 | for k, v in category2score.items(): 242 | if task_cate in k: 243 | total_correct += v["correct"] 244 | total_answered += v["answered"] 245 | logging.info(f"Evaluation on Task Categories: {task_cate}: {100 * total_correct / total_answered if total_answered > 0 else 0 : .1f}%") 246 | 247 | total_correct = 0 248 | total_answered = 0 249 | for k, v in category2score.items(): 250 | total_correct += v["correct"] 251 | total_answered += v["answered"] 252 | logging.info(f"Overall Performance: {100 * total_correct / total_answered if total_answered > 0 else 0 : .1f}%") 253 | return 100 * total_correct / total_answered if total_answered > 0 else 0 254 | 255 | -------------------------------------------------------------------------------- /frame_selection.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2025, NVIDIA Corporation & Affiliates. All rights reserved. 2 | # 3 | # This work is made available under the Nvidia Source Code License-NC. 
4 | # To view a copy of this license, visit 5 | # https://github.com/NVlabs/FRAG/blob/main/LICENSE 6 | 7 | import torch 8 | import torch.nn.functional as F 9 | import numpy as np 10 | import json 11 | import einops 12 | 13 | from utils import split_list 14 | 15 | from utils import hash_string 16 | import os 17 | import math 18 | 19 | 20 | OPTIONS = ["A", "B"] 21 | 22 | 23 | class FrameSelection: 24 | def __init__(self, model, selector_method, output_frames, sample_frames, score_docs=None, sub_vid_cache=None): 25 | self.model = model 26 | self.output_frames = output_frames 27 | self.sample_frames = sample_frames 28 | self.selector_method = selector_method 29 | 30 | # TODO: score batch_size in options 31 | if hasattr(self.model, 'model_type') and self.model.model_type == 'clip': 32 | self.score_batch_size = 128 33 | else: 34 | self.score_batch_size = 1 35 | 36 | if score_docs is not None: 37 | score_docs = json.load(open(score_docs, 'r')) 38 | score_method = 'doc' 39 | self.score_docs = {} 40 | for doc in score_docs: 41 | self.score_docs[doc['id']] = doc 42 | else: 43 | score_method = 'model' 44 | 45 | if selector_method == 'uniform': 46 | assert output_frames == sample_frames 47 | self.proposal = None 48 | self.score_method = None 49 | self.process_score = None 50 | elif selector_method == 'topk_frames': 51 | self.proposal = 'frames' 52 | self.score_method = score_method 53 | self.process_score = 'topk_flatten' 54 | elif selector_method == 'topk_pairs': 55 | self.proposal = 'pairs' 56 | self.score_method = score_method 57 | self.process_score = 'topk_flatten' 58 | elif selector_method.startswith('annot_scores'): 59 | self.proposal = selector_method.replace('annot_scores_', '') 60 | self.score_method = score_method 61 | self.process_score = None 62 | 63 | self.select_id = self.model.tokenizer.convert_tokens_to_ids(OPTIONS[0]) 64 | 65 | self.sub_vid_cache = sub_vid_cache 66 | 67 | def generate_proposals(self, images, frame_indices): 68 | if self.proposal is None: 69 | return [images], [frame_indices] 70 | 71 | proposal_func = getattr(self, 'proposal_' + self.proposal) 72 | images = proposal_func(images) 73 | frame_indices = proposal_func(frame_indices) 74 | 75 | return images, frame_indices 76 | 77 | def proposal_frames(self, lst): 78 | return split_list(lst, len(lst)) 79 | 80 | def proposal_pairs(self, lst): 81 | return [[lst[i], lst[i + 1]] for i in range(len(lst) - 1)] 82 | 83 | def score_prompt(self, doc): 84 | 85 | if hasattr(self.model, 'model_type') and self.model.model_type == 'clip': 86 | return doc['question'].strip() 87 | 88 | prompt = f"Question: {doc['question'].strip()}\n" 89 | 90 | if self.proposal == 'frames': 91 | task_prompt = "Does the information within the image provide the necessary details to accurately answer the given question?\n" 92 | elif self.proposal == 'pairs': 93 | task_prompt = "Does the information within the images provide the necessary details to accurately answer the given question?\n" 94 | else: 95 | raise NotImplementedError 96 | 97 | post_prompt = f"{OPTIONS[0]}. yes\n{OPTIONS[1]}. no\n" 98 | post_prompt += "Answer with the option's letter from the given choices directly." 
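        # Note: the prompt built above casts frame relevance as a binary choice
        # ("A. yes" / "B. no"). compute_scores_model() below runs this prompt with
        # output_scores=True and takes the softmax probability of the OPTIONS[0]
        # ("A") token at the first generated step as the FRAG score for the
        # proposal (a single frame, or a frame pair for 'pairs' proposals).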
99 | 100 | return prompt + task_prompt + post_prompt 101 | 102 | def compute_scores_model(self, doc, images): 103 | prompt = self.score_prompt(doc) 104 | prompt = (prompt, None) 105 | 106 | def post_proc_func(x): 107 | logits = x['scores'][0] 108 | scores = F.softmax(logits, dim=-1)[:, self.select_id].detach() 109 | return scores 110 | 111 | scores = [] 112 | for proposal_images in images: 113 | score = self.model.run_model(proposal_images, prompt, output_scores=True, post_proc_func=post_proc_func) 114 | scores.append(score) 115 | scores = torch.cat(scores, dim=0) 116 | 117 | return scores 118 | 119 | def compute_scores_model_batch(self, doc, images): 120 | prompt = self.score_prompt(doc) 121 | prompt = (prompt, None) 122 | 123 | def post_proc_func(x): 124 | logits = x['scores'][0] 125 | scores = F.softmax(logits, dim=-1)[:, self.select_id].detach() 126 | return scores 127 | 128 | scores = [] 129 | iters = math.ceil(len(images) / self.score_batch_size) 130 | images_list = split_list(images, iters) 131 | for proposal_images in images_list: 132 | assert len(proposal_images[0]) == 1 133 | proposal_images = [img for imgs in proposal_images for img in imgs] 134 | score = self.model.run_model_batch(proposal_images, prompt, output_scores=True, post_proc_func=post_proc_func) 135 | scores.append(score) 136 | scores = torch.cat(scores, dim=0) 137 | 138 | return scores 139 | 140 | def compute_scores_model_cache(self, doc, images, frame_indices, cache_dir): 141 | input_string = "" 142 | input_string += json.dumps(doc, sort_keys=True) + '\n' 143 | input_string += str(self.model.model) 144 | input_string += f"\n{self.selector_method}_{self.output_frames}_{self.sample_frames}" 145 | 146 | cache_file = hash_string(input_string) 147 | cache_file = os.path.join(cache_dir, cache_file) 148 | 149 | prompt = self.score_prompt(doc) 150 | prompt = (prompt, None) 151 | 152 | def post_proc_func(x): 153 | logits = x['scores'][0] 154 | scores = F.softmax(logits, dim=-1)[:, self.select_id].detach() 155 | return scores 156 | 157 | # load from cache 158 | cache_dict = {} 159 | if os.path.exists(cache_file): 160 | with open(cache_file, "r") as file: 161 | for line in file: 162 | idx, score = line.strip().split(',') 163 | cache_dict[idx] = torch.Tensor([float(score)]).to(self.model.model.device) 164 | 165 | scores = [] 166 | 167 | # for proposal_images in images: 168 | for i, proposal_images in enumerate(images): 169 | idx = frame_indices[i] 170 | idx = ','.join(map(str, idx)) 171 | if idx in cache_dict: 172 | score = cache_dict[idx] 173 | else: 174 | score = self.model.run_model(proposal_images, prompt, output_scores=True, post_proc_func=post_proc_func) 175 | with open(cache_file, "a") as file: 176 | file.write(f"{idx},{float(score.cpu().numpy()[0])}\n") 177 | scores.append(score) 178 | 179 | assert len(scores) == len(images) 180 | scores = torch.cat(scores, dim=0) 181 | 182 | return scores 183 | 184 | def compute_scores_doc(self, doc, frame_indices): 185 | score_doc = self.score_docs[doc['id']] 186 | 187 | score_dict = {} 188 | for frames, scores in zip(score_doc['frames'], score_doc['scores']): 189 | score_dict[tuple(frames)] = scores 190 | 191 | scores = [] 192 | for proposal_indices in frame_indices: 193 | score = score_dict[tuple(proposal_indices)] 194 | scores.append(score) 195 | scores = torch.Tensor(scores).to(self.model.model.device) 196 | 197 | return scores 198 | 199 | def annotate_scores(self, doc, images, frame_indices): 200 | images, frame_indices = self.generate_proposals(images, frame_indices) 201 | 202 | 
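        # annotate_scores() is the precompute path used by the annot_scores_* scripts:
        # it scores every proposal once and returns (scores, frame_indices) so that
        # eval_model.py can dump them to per-question JSON files. Later answering runs
        # can reload these scores through --score-docs (see compute_scores_doc) instead
        # of re-running the scoring LMM.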
assert self.score_method == 'model' 203 | if self.sub_vid_cache is not None: 204 | scores = self.compute_scores_model_cache(doc, images, frame_indices, self.sub_vid_cache) 205 | elif self.score_batch_size > 1: 206 | scores = self.compute_scores_model_batch(doc, images) 207 | else: 208 | scores = self.compute_scores_model(doc, images) 209 | 210 | if scores is not None: 211 | scores = list(scores.cpu().numpy()) 212 | 213 | return scores, frame_indices 214 | 215 | def process_flatten(self, selected, images, frame_indices): 216 | selected_images = [] 217 | selected_indices = [] 218 | for s in selected: 219 | selected_images.append(images[s]) 220 | selected_indices.append(frame_indices[s]) 221 | 222 | # flatten 223 | selected_images = [img for imgs in selected_images for img in imgs] 224 | selected_indices = [idx for indices in selected_indices for idx in indices] 225 | 226 | # sort by index 227 | argsort = np.argsort(selected_indices) 228 | selected_images = [[selected_images[i] for i in argsort]] 229 | selected_indices = [[selected_indices[i] for i in argsort]] 230 | 231 | return selected_images, selected_indices 232 | 233 | def select_frames(self, doc, images, frame_indices, return_scores=False): 234 | images, frame_indices = self.generate_proposals(images, frame_indices) 235 | 236 | if self.score_method is None: 237 | return images, frame_indices 238 | elif self.score_method == 'model': 239 | if self.score_batch_size > 1: 240 | scores = self.compute_scores_model_batch(doc, images) 241 | else: 242 | scores = self.compute_scores_model(doc, images) 243 | elif self.score_method == 'doc': 244 | scores = self.compute_scores_doc(doc, frame_indices) 245 | 246 | if self.process_score == 'topk_flatten': 247 | assert self.output_frames % len(images[0]) == 0 248 | k = self.output_frames // len(images[0]) 249 | k = min(k, scores.shape[-1]) 250 | _, selected = torch.topk(scores, k=k, dim=-1) 251 | selected = selected.cpu().numpy() 252 | 253 | selected_images, selected_indices = self.process_flatten(selected, images, frame_indices) 254 | 255 | if return_scores: 256 | scores = list(scores.cpu().numpy()) 257 | scores = [float(x) for x in scores] 258 | frame_indices = [tuple(x) for x in frame_indices] 259 | if len(frame_indices[0]) == 1: 260 | frame_indices = [x[0] for x in frame_indices] 261 | scores_dict = dict(zip(frame_indices, scores)) 262 | 263 | return selected_images, selected_indices, scores_dict 264 | else: 265 | return selected_images, selected_indices 266 | 267 | -------------------------------------------------------------------------------- /tasks/mmlbdoc.py: -------------------------------------------------------------------------------- 1 | # Adopted from MMLongBench-Doc from https://github.com/mayubo2333/MMLongBench-Doc. Below is the original copyright: 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
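#
# Note on the evaluation flow in this file: answers are scored in two stages. An
# OpenAI model (self.model_name, "gpt-4o" by default) first extracts a short,
# typed answer from the model's free-form response using extract_prompt, and
# eval_score() then compares it to the ground truth (exact match for Int,
# tolerance-based match for Float, ANLS for most strings and lists).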
14 | 15 | from .base import DocumentTask 16 | import os 17 | 18 | import re 19 | from math import isclose 20 | from collections import defaultdict 21 | 22 | from openai import OpenAI 23 | client = OpenAI() 24 | 25 | 26 | extract_prompt = """ 27 | Given the question and analysis, you are tasked to extract answers with required formats from the free-form analysis. 28 | - Your extracted answers should be one of the following formats: (1) Integer, (2) Float, (3) String and (4) List. If you find the analysis the question can not be answered from the given documents, type "Not answerable". Exception: If the analysis only tells you that it can not read/understand the images or documents, type "Fail to answer". 29 | - Please make your response as concise as possible. Also note that your response should be formatted as below: 30 | ``` 31 | Extracted answer: [answer] 32 | Answer format: [answer format] 33 | ``` 34 | 35 | Please read the following example, then extract the answer from the model response and type it at the end of the prompt. 36 | 37 | --- 38 | Question: List the primary questions asked about the services in this report. 39 | Analysis: The primary questions asked about the services in the report for The Limes Residential Home are:\n\n1. Is the service safe?\n2. Is the service effective?\n3. Is the service caring?\n4. Is the service responsive?\n5. Is the service well-led? 40 | Extracted answer: ['Is the servife safe?', 'Is the service effective', 'Is the serve caring?', 'Is the service responsive?', 'Is the service well-led?'] 41 | Answer format: List 42 | 43 | --- 44 | Question: How many regulations of the HSCA 2008 are breached in all according to this report? 45 | Analysis: According to the report, the provider breached 10 Health and Social Care Act 2008 (Regulated Activities) Regulations in total. Here are the specifics:\n\n1. Regulation 13: Safeguarding service users from abuse and improper treatment\n2. Regulation 12: Safe care and treatment\n3. Regulation 18: Staffing\n4. Regulation 11: Need for consent\n5. Regulation 10: Dignity and respect\n6. Regulation 9: Person-centred care\n7. Regulation 17: Good governance\n8. Regulation 18 (CQC Registration Regulations 2009): Notification of other incidents\n9. Regulation 18: Failure to maintain an accurate and up-to-date care plan\n10. Regulation 11: Failure to implement the Mental Capacity Act 2005 code of practice effectively\n\nThese breaches involve issues concerning staffing, safeguarding, medicines management, dignity and respect, consent, care planning, governance, and failure to notify the CQC of incidents. 46 | Extracted answer: 10 47 | Answer format: Integer 48 | 49 | --- 50 | Question: According to the survey that is the percentage of Chinese who are paying more or about the same attention to politics after Trump's election? 51 | Analysis: The survey provided does not specify the percentage of Chinese individuals specifically who are paying more or about the same attention to politics after Trump's election. The report focuses primarily on American demographics and does not include specific details about the Chinese population in relation to this question. If you need information about a different demographic or a summary of the findings from the American demographic, I can certainly help with that! 52 | Extracted answer: Not answerable 53 | Answer format: String 54 | 55 | --- 56 | Question: How many quotations from male respondent over 50 years old are included in this report? 
57 | Analysis: The image you've provided appears to be a screenshot of a document with multiple charts. However, the text is too small and blurry to read accurately. If you can provide a clearer image or more context, I might be able to help you with your question. 58 | Extracted answer: Fail to answer 59 | Answer format: String 60 | 61 | --- 62 | """ 63 | 64 | def levenshtein_distance(s1, s2): 65 | if len(s1) > len(s2): 66 | s1, s2 = s2, s1 67 | 68 | distances = range(len(s1) + 1) 69 | for i2, c2 in enumerate(s2): 70 | distances_ = [i2 + 1] 71 | for i1, c1 in enumerate(s1): 72 | if c1 == c2: 73 | distances_.append(distances[i1]) 74 | else: 75 | distances_.append(1 + min((distances[i1], distances[i1 + 1], distances_[-1]))) 76 | distances = distances_ 77 | return distances[-1] 78 | 79 | 80 | def anls_compute(groundtruth, prediction, threshold=0.5): 81 | dist = levenshtein_distance(groundtruth, prediction) 82 | length = max(len(groundtruth.upper()), len(prediction.upper())) 83 | value = 0.0 if length == 0 else float(dist) / float(length) 84 | anls = 1.0 - value 85 | if anls<=threshold: 86 | anls = 0.0 87 | return anls 88 | 89 | 90 | def is_float_equal(reference, prediction, include_percentage: bool = False, is_close: float = False) -> bool: 91 | def get_precision(gt_ans: float) -> int: 92 | precision = 3 93 | if '.' in str(gt_ans): 94 | precision = len(str(gt_ans).split('.')[-1]) 95 | return precision 96 | 97 | reference = float(str(reference).strip().rstrip("%").strip()) 98 | try: 99 | prediction = float(str(prediction).strip().rstrip("%").strip()) 100 | except: 101 | return False 102 | 103 | if include_percentage: 104 | gt_result = [reference / 100, reference, reference * 100] 105 | else: 106 | gt_result = [reference] 107 | for item in gt_result: 108 | try: 109 | if is_close: 110 | if isclose(item, prediction, rel_tol=0.01): 111 | return True 112 | precision = max(min(get_precision(prediction), get_precision(item)), 2) 113 | if round(prediction, precision) == round(item, precision): 114 | return True 115 | except Exception: 116 | continue 117 | return False 118 | 119 | 120 | def get_clean_string(s): 121 | s = str(s).lower().strip() 122 | if s.endswith("mile"): 123 | s.rstrip("mile").strip() 124 | if s.endswith("miles"): 125 | s.rstrip("miles").strip() 126 | if s.endswith("million"): 127 | s.rstrip("million").strip() 128 | # remove parenthesis 129 | s = re.sub(r'\s*\([^)]*\)', '', s).strip() 130 | # remove quotes 131 | s = re.sub(r"^['\"]|['\"]$", "", s).strip() 132 | s = s.strip().lstrip("$").strip() 133 | s = s.strip().rstrip("%").strip() 134 | return s 135 | 136 | 137 | def is_exact_match(s): 138 | flag = False 139 | # Website 140 | if "https://" in s: 141 | flag = True 142 | # code file 143 | if s.endswith(".py") or s.endswith("ipynb"): 144 | flag = True 145 | if s.startswith("page"): 146 | flag = True 147 | # telephone number 148 | if re.fullmatch(r'\b\d+(-\d+|\s\d+)?\b', s): 149 | flag = True 150 | # time 151 | if "a.m." in s or "p.m." 
in s: 152 | flag = True 153 | # YYYY-MM-DD 154 | if re.fullmatch(r'\b\d{4}[-\s]\d{2}[-\s]\d{2}\b', s): 155 | flag = True 156 | # YYYY-MM 157 | if re.fullmatch(r'\b\d{4}[-\s]\d{2}\b', s): 158 | flag = True 159 | # Email address 160 | if re.fullmatch(r'[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}', s): 161 | flag = True 162 | return flag 163 | 164 | 165 | def isfloat(num): 166 | try: 167 | float(num) 168 | return True 169 | except ValueError: 170 | return False 171 | 172 | 173 | def eval_score(gt, pred, answer_type): 174 | if answer_type=="Int": 175 | try: 176 | gt, pred = int(gt), int(float(pred)) 177 | except: 178 | pred = "" 179 | score = (gt==pred) 180 | elif answer_type=="Float": 181 | try: 182 | gt = float(get_clean_string(str(gt))) 183 | pred = float(get_clean_string(str(pred))) 184 | except: 185 | pred = "" 186 | score = is_float_equal(gt, pred, include_percentage=True, is_close=True) 187 | elif answer_type in ["Str", "None"]: 188 | gt = get_clean_string(gt) 189 | pred = get_clean_string(pred) 190 | if is_exact_match(gt): 191 | score = (gt==pred) 192 | else: 193 | score = anls_compute(gt, pred) 194 | else: 195 | if isinstance(gt, str) and gt.startswith("["): 196 | gt = eval(gt) 197 | if not isinstance(gt, list): 198 | gt = [gt] 199 | if isinstance(pred, str) and pred.startswith("["): 200 | pred = eval(pred) 201 | if not isinstance(pred, list): 202 | pred = [pred] 203 | print(len(gt), len(pred)) 204 | if len(gt)!=len(pred): 205 | score = 0.0 206 | else: 207 | gt = sorted([get_clean_string(a) for a in gt]) 208 | pred = sorted([get_clean_string(a) for a in pred]) 209 | print(gt, pred) 210 | if isfloat(gt[0]) or is_exact_match(gt[0]): 211 | score = ("-".join(gt)=="-".join(pred)) 212 | else: 213 | score = min([anls_compute(gt_v, pred_v) for gt_v, pred_v in zip(gt, pred)]) 214 | 215 | return float(score) 216 | 217 | 218 | def eval_acc_and_f1(samples): 219 | evaluated_samples = [sample for sample in samples if "score" in sample] 220 | if not evaluated_samples: 221 | return 0.0, 0.0 222 | 223 | acc = sum([sample["score"] for sample in evaluated_samples])/len(evaluated_samples) 224 | try: 225 | recall = sum([sample["score"] for sample in evaluated_samples if sample["answer"]!="Not answerable"])/len([sample for sample in evaluated_samples if sample["answer"]!="Not answerable"]) 226 | precision = sum([sample["score"] for sample in evaluated_samples if sample["answer"]!="Not answerable"])/len([sample for sample in evaluated_samples if sample["pred"]!="Not answerable"]) 227 | f1 = 2*recall*precision/(recall+precision) if (recall+precision)>0.0 else 0.0 228 | except: 229 | f1 = 0.0 230 | 231 | return acc, f1 232 | 233 | 234 | def extract_answer(question, output, prompt, model_name="gpt-4o"): 235 | try: 236 | response = client.chat.completions.create( 237 | model=model_name, 238 | messages=[ 239 | { 240 | "role": "user", 241 | "content": prompt, 242 | }, 243 | { 244 | "role": "assistant", 245 | "content": "\n\nQuestion:{}\nAnalysis:{}\n".format(question, output) 246 | } 247 | ], 248 | temperature=0.0, 249 | max_tokens=256, 250 | top_p=1, 251 | frequency_penalty=0, 252 | presence_penalty=0 253 | ) 254 | response = response.choices[0].message.content 255 | except: 256 | response = "Failed" 257 | 258 | return response 259 | 260 | 261 | class MMLongBenchDoc(DocumentTask): 262 | def __init__(self, dataset, split, **kwargs): 263 | super().__init__(dataset, split, **kwargs) 264 | 265 | assert self.split in ['old', 'test'] 266 | self.post_prompt = "" 267 | self.thres = 0.4 268 | self.model_name = 
"gpt-4o" 269 | 270 | def doc_to_visual_name(self, doc): 271 | return doc["doc_id"] 272 | 273 | def aggregate_results(self, docs, out_root): 274 | os.makedirs(out_root + '_' + self.model_name, exist_ok=True) 275 | 276 | id_to_scores = {} 277 | score_path = os.getenv("SCORE_PATH") 278 | if score_path is not None: 279 | score_docs = json.load(open(score_path)) 280 | for doc in score_docs: 281 | id_to_scores[doc["id"]] = doc["scores"] 282 | 283 | for doc in docs: 284 | if doc["id"] in id_to_scores: 285 | scores = id_to_scores[doc["id"]] 286 | frames = doc["frames"][0] 287 | selected_scores = [scores[x] for x in frames] 288 | else: 289 | selected_scores = [1.0] 290 | 291 | if max(selected_scores) > self.thres: 292 | response = doc['pred'][0].strip() 293 | extracted_res_file = os.path.join(out_root + '_' + self.model_name, doc['id'] + '.txt') 294 | if os.path.exists(extracted_res_file): 295 | with open(extracted_res_file, 'r') as file: 296 | extracted_res = file.read() 297 | else: 298 | extracted_res = extract_answer(doc["question"], response, extract_prompt, model_name=self.model_name) 299 | with open(extracted_res_file, 'w') as file: 300 | file.write(extracted_res) 301 | try: 302 | pred_ans = extracted_res.split("Answer format:")[0].split("Extracted answer:")[1].strip() 303 | except: 304 | pred_ans = "Failed to extract" 305 | else: 306 | pred_ans = "Not answerable" 307 | doc["pred"] = pred_ans 308 | 309 | try: 310 | score = eval_score(doc['answer'], pred_ans, doc["answer_format"]) 311 | except: 312 | score = 0.0 313 | assert "score" not in doc 314 | doc["score"] = score 315 | 316 | acc, f1 = eval_acc_and_f1(docs) 317 | print(f"Acc: {acc*100}\nF1: {f1*100}") 318 | out_file = out_root + '.log' 319 | with open(out_file, 'a') as file: 320 | file.write(f"{self.model_name}\nAcc: {acc*100}\nF1: {f1*100}\n") 321 | -------------------------------------------------------------------------------- /models/internvl/conversation.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # InternVL 3 | # Copyright (c) 2024 OpenGVLab 4 | # Licensed under The MIT License [see https://github.com/OpenGVLab/InternVL/blob/main/LICENSE for details] 5 | # -------------------------------------------------------- 6 | 7 | """ 8 | Conversation prompt templates. 9 | 10 | We kindly request that you import fastchat instead of copying this file if you wish to use it. 11 | If you have changes in mind, please contribute back so the community can benefit collectively and continue to maintain these valuable templates. 
12 | """ 13 | 14 | import dataclasses 15 | from enum import IntEnum, auto 16 | from typing import Any, Dict, List, Tuple, Union 17 | 18 | 19 | class SeparatorStyle(IntEnum): 20 | """Separator styles.""" 21 | 22 | ADD_COLON_SINGLE = auto() 23 | ADD_COLON_TWO = auto() 24 | ADD_COLON_SPACE_SINGLE = auto() 25 | NO_COLON_SINGLE = auto() 26 | NO_COLON_TWO = auto() 27 | ADD_NEW_LINE_SINGLE = auto() 28 | LLAMA2 = auto() 29 | CHATGLM = auto() 30 | CHATML = auto() 31 | CHATINTERN = auto() 32 | DOLLY = auto() 33 | RWKV = auto() 34 | PHOENIX = auto() 35 | ROBIN = auto() 36 | FALCON_CHAT = auto() 37 | CHATGLM3 = auto() 38 | INTERNVL_ZH = auto() 39 | MPT = auto() 40 | 41 | 42 | @dataclasses.dataclass 43 | class Conversation: 44 | """A class that manages prompt templates and keeps all conversation history.""" 45 | 46 | # The name of this template 47 | name: str 48 | # The template of the system prompt 49 | system_template: str = '{system_message}' 50 | # The system message 51 | system_message: str = '' 52 | # The names of two roles 53 | roles: Tuple[str] = ('USER', 'ASSISTANT') 54 | # All messages. Each item is (role, message). 55 | messages: List[List[str]] = () 56 | # The number of few shot examples 57 | offset: int = 0 58 | # The separator style and configurations 59 | sep_style: SeparatorStyle = SeparatorStyle.ADD_COLON_SINGLE 60 | sep: str = '\n' 61 | sep2: str = None 62 | # Stop criteria (the default one is EOS token) 63 | stop_str: Union[str, List[str]] = None 64 | # Stops generation if meeting any token in this list 65 | stop_token_ids: List[int] = None 66 | 67 | def get_prompt(self) -> str: 68 | """Get the prompt for generation.""" 69 | system_prompt = self.system_template.format(system_message=self.system_message) 70 | if self.sep_style == SeparatorStyle.ADD_COLON_SINGLE: 71 | ret = system_prompt + self.sep 72 | for role, message in self.messages: 73 | if message: 74 | ret += role + ': ' + message + self.sep 75 | else: 76 | ret += role + ':' 77 | return ret 78 | elif self.sep_style == SeparatorStyle.ADD_COLON_TWO: 79 | seps = [self.sep, self.sep2] 80 | ret = system_prompt + seps[0] 81 | for i, (role, message) in enumerate(self.messages): 82 | if message: 83 | ret += role + ': ' + message + seps[i % 2] 84 | else: 85 | ret += role + ':' 86 | return ret 87 | elif self.sep_style == SeparatorStyle.ADD_COLON_SPACE_SINGLE: 88 | ret = system_prompt + self.sep 89 | for role, message in self.messages: 90 | if message: 91 | ret += role + ': ' + message + self.sep 92 | else: 93 | ret += role + ': ' # must be end with a space 94 | return ret 95 | elif self.sep_style == SeparatorStyle.ADD_NEW_LINE_SINGLE: 96 | ret = '' if system_prompt == '' else system_prompt + self.sep 97 | for role, message in self.messages: 98 | if message: 99 | ret += role + '\n' + message + self.sep 100 | else: 101 | ret += role + '\n' 102 | return ret 103 | elif self.sep_style == SeparatorStyle.NO_COLON_SINGLE: 104 | ret = system_prompt 105 | for role, message in self.messages: 106 | if message: 107 | ret += role + message + self.sep 108 | else: 109 | ret += role 110 | return ret 111 | elif self.sep_style == SeparatorStyle.NO_COLON_TWO: 112 | seps = [self.sep, self.sep2] 113 | ret = system_prompt 114 | for i, (role, message) in enumerate(self.messages): 115 | if message: 116 | ret += role + message + seps[i % 2] 117 | else: 118 | ret += role 119 | return ret 120 | elif self.sep_style == SeparatorStyle.RWKV: 121 | ret = system_prompt 122 | for i, (role, message) in enumerate(self.messages): 123 | if message: 124 | ret += ( 125 | role 
126 |                         + ': '
127 |                         + message.replace('\r\n', '\n').replace('\n\n', '\n')
128 |                     )
129 |                     ret += '\n\n'
130 |                 else:
131 |                     ret += role + ':'
132 |             return ret
133 |         elif self.sep_style == SeparatorStyle.LLAMA2:
134 |             seps = [self.sep, self.sep2]
135 |             if self.system_message:
136 |                 ret = system_prompt
137 |             else:
138 |                 ret = '[INST] '
139 |             for i, (role, message) in enumerate(self.messages):
140 |                 tag = self.roles[i % 2]
141 |                 if message:
142 |                     if i == 0:
143 |                         ret += message + ' '
144 |                     else:
145 |                         ret += tag + ' ' + message + seps[i % 2]
146 |                 else:
147 |                     ret += tag
148 |             return ret
149 |         elif self.sep_style == SeparatorStyle.CHATGLM:
150 |             # source: https://huggingface.co/THUDM/chatglm-6b/blob/1d240ba371910e9282298d4592532d7f0f3e9f3e/modeling_chatglm.py#L1302-L1308
151 |             # source2: https://huggingface.co/THUDM/chatglm2-6b/blob/e186c891cf64310ac66ef10a87e6635fa6c2a579/modeling_chatglm.py#L926
152 |             round_add_n = 1 if self.name == 'chatglm2' else 0
153 |             if system_prompt:
154 |                 ret = system_prompt + self.sep
155 |             else:
156 |                 ret = ''
157 | 
158 |             for i, (role, message) in enumerate(self.messages):
159 |                 if i % 2 == 0:
160 |                     ret += f'[Round {i//2 + round_add_n}]{self.sep}'
161 | 
162 |                 if message:
163 |                     ret += f'{role}:{message}{self.sep}'
164 |                 else:
165 |                     ret += f'{role}:'
166 |             return ret
167 |         elif self.sep_style == SeparatorStyle.CHATML:
168 |             ret = '' if system_prompt == '' else system_prompt + self.sep + '\n'
169 |             for role, message in self.messages:
170 |                 if message:
171 |                     ret += role + '\n' + message + self.sep + '\n'
172 |                 else:
173 |                     ret += role + '\n'
174 |             return ret
175 |         elif self.sep_style == SeparatorStyle.CHATGLM3:
176 |             ret = ''
177 |             if self.system_message:
178 |                 ret += system_prompt
179 |             for role, message in self.messages:
180 |                 if message:
181 |                     ret += role + '\n' + ' ' + message
182 |                 else:
183 |                     ret += role
184 |             return ret
185 |         elif self.sep_style == SeparatorStyle.CHATINTERN:
186 |             # source: https://huggingface.co/internlm/internlm-chat-7b-8k/blob/bd546fa984b4b0b86958f56bf37f94aa75ab8831/modeling_internlm.py#L771
187 |             seps = [self.sep, self.sep2]
188 |             ret = system_prompt
189 |             for i, (role, message) in enumerate(self.messages):
190 |                 # if i % 2 == 0:
191 |                 #     ret += "<s>"
192 |                 if message:
193 |                     ret += role + ':' + message + seps[i % 2] + '\n'
194 |                 else:
195 |                     ret += role + ':'
196 |             return ret
197 |         elif self.sep_style == SeparatorStyle.DOLLY:
198 |             seps = [self.sep, self.sep2]
199 |             ret = system_prompt
200 |             for i, (role, message) in enumerate(self.messages):
201 |                 if message:
202 |                     ret += role + ':\n' + message + seps[i % 2]
203 |                     if i % 2 == 1:
204 |                         ret += '\n\n'
205 |                 else:
206 |                     ret += role + ':\n'
207 |             return ret
208 |         elif self.sep_style == SeparatorStyle.PHOENIX:
209 |             ret = system_prompt
210 |             for role, message in self.messages:
211 |                 if message:
212 |                     ret += role + ': ' + '<s>' + message + '</s>'
213 |                 else:
214 |                     ret += role + ': ' + '<s>'
215 |             return ret
216 |         elif self.sep_style == SeparatorStyle.ROBIN:
217 |             ret = system_prompt + self.sep
218 |             for role, message in self.messages:
219 |                 if message:
220 |                     ret += role + ':\n' + message + self.sep
221 |                 else:
222 |                     ret += role + ':\n'
223 |             return ret
224 |         elif self.sep_style == SeparatorStyle.FALCON_CHAT:
225 |             ret = ''
226 |             if self.system_message:
227 |                 ret += system_prompt + self.sep
228 |             for role, message in self.messages:
229 |                 if message:
230 |                     ret += role + ': ' + message + self.sep
231 |                 else:
232 |                     ret += role + ':'
233 | 
234 |             return ret
235 |         elif self.sep_style == SeparatorStyle.INTERNVL_ZH:
236 |             seps = [self.sep, self.sep2]
237 |             ret = self.system_message + seps[0]
238 |             for i, (role, message) in enumerate(self.messages):
239 |                 if message:
240 |                     ret += role + ': ' + message + seps[i % 2]
241 |                 else:
242 |                     ret += role + ':'
243 |             return ret
244 |         elif self.sep_style == SeparatorStyle.MPT:
245 |             ret = system_prompt + self.sep
246 |             for role, message in self.messages:
247 |                 if message:
248 |                     if type(message) is tuple:
249 |                         message, _, _ = message
250 |                     ret += role + message + self.sep
251 |                 else:
252 |                     ret += role
253 |             return ret
254 |         else:
255 |             raise ValueError(f'Invalid style: {self.sep_style}')
256 | 
257 |     def set_system_message(self, system_message: str):
258 |         """Set the system message."""
259 |         self.system_message = system_message
260 | 
261 |     def append_message(self, role: str, message: str):
262 |         """Append a new message."""
263 |         self.messages.append([role, message])
264 | 
265 |     def update_last_message(self, message: str):
266 |         """Update the last output.
267 | 
268 |         The last message is typically set to be None when constructing the prompt,
269 |         so we need to update it in-place after getting the response from a model.
270 |         """
271 |         self.messages[-1][1] = message
272 | 
273 |     def to_gradio_chatbot(self):
274 |         """Convert the conversation to gradio chatbot format."""
275 |         ret = []
276 |         for i, (role, msg) in enumerate(self.messages[self.offset :]):
277 |             if i % 2 == 0:
278 |                 ret.append([msg, None])
279 |             else:
280 |                 ret[-1][-1] = msg
281 |         return ret
282 | 
283 |     def to_openai_api_messages(self):
284 |         """Convert the conversation to OpenAI chat completion format."""
285 |         ret = [{'role': 'system', 'content': self.system_message}]
286 | 
287 |         for i, (_, msg) in enumerate(self.messages[self.offset :]):
288 |             if i % 2 == 0:
289 |                 ret.append({'role': 'user', 'content': msg})
290 |             else:
291 |                 if msg is not None:
292 |                     ret.append({'role': 'assistant', 'content': msg})
293 |         return ret
294 | 
295 |     def copy(self):
296 |         return Conversation(
297 |             name=self.name,
298 |             system_template=self.system_template,
299 |             system_message=self.system_message,
300 |             roles=self.roles,
301 |             messages=[[x, y] for x, y in self.messages],
302 |             offset=self.offset,
303 |             sep_style=self.sep_style,
304 |             sep=self.sep,
305 |             sep2=self.sep2,
306 |             stop_str=self.stop_str,
307 |             stop_token_ids=self.stop_token_ids,
308 |         )
309 | 
310 |     def dict(self):
311 |         return {
312 |             'template_name': self.name,
313 |             'system_message': self.system_message,
314 |             'roles': self.roles,
315 |             'messages': self.messages,
316 |             'offset': self.offset,
317 |         }
318 | 
319 | 
320 | # A global registry for all conversation templates
321 | conv_templates: Dict[str, Conversation] = {}
322 | 
323 | 
324 | def register_conv_template(template: Conversation, override: bool = False):
325 |     """Register a new conversation template."""
326 |     if not override:
327 |         assert (
328 |             template.name not in conv_templates
329 |         ), f'{template.name} has been registered.'
330 | 
331 |     conv_templates[template.name] = template
332 | 
333 | 
334 | def get_conv_template(name: str) -> Conversation:
335 |     """Get a conversation template."""
336 |     return conv_templates[name].copy()
337 | 
338 | 
339 | # Both Hermes-2 and internlm2-chat are chatml-format conversation templates. The difference
340 | # is that during training, the preprocessing function for the Hermes-2 template doesn't add
341 | # <s> at the beginning of the tokenized sequence, while the internlm2-chat template does.
342 | # Therefore, they are completely equivalent during inference.
343 | register_conv_template(
344 |     Conversation(
345 |         name='Hermes-2',
346 |         system_template='<|im_start|>system\n{system_message}',
347 |         # note: The new system prompt was not used here to avoid changes in benchmark performance.
348 |         # system_message='我是书生·万象,英文名是InternVL,是由上海人工智能实验室、清华大学及多家合作单位联合开发的多模态大语言模型。',
349 |         system_message='你是由上海人工智能实验室联合商汤科技开发的书生多模态大模型,英文名叫InternVL, 是一个有用无害的人工智能助手。',
350 |         roles=('<|im_start|>user\n', '<|im_start|>assistant\n'),
351 |         sep_style=SeparatorStyle.MPT,
352 |         sep='<|im_end|>',
353 |         stop_token_ids=[
354 |             2,
355 |             6,
356 |             7,
357 |             8,
358 |         ],
359 |         stop_str='<|endoftext|>',
360 |     )
361 | )
362 | 
363 | 
364 | register_conv_template(
365 |     Conversation(
366 |         name='internlm2-chat',
367 |         system_template='<|im_start|>system\n{system_message}',
368 |         # note: The new system prompt was not used here to avoid changes in benchmark performance.
369 |         # system_message='我是书生·万象,英文名是InternVL,是由上海人工智能实验室、清华大学及多家合作单位联合开发的多模态大语言模型。',
370 |         system_message='你是由上海人工智能实验室联合商汤科技开发的书生多模态大模型,英文名叫InternVL, 是一个有用无害的人工智能助手。',
371 |         roles=('<|im_start|>user\n', '<|im_start|>assistant\n'),
372 |         sep_style=SeparatorStyle.MPT,
373 |         sep='<|im_end|>',
374 |         stop_token_ids=[
375 |             2,
376 |             92543,
377 |             92542
378 |         ]
379 |     )
380 | )
381 | 
382 | 
383 | register_conv_template(
384 |     Conversation(
385 |         name='phi3-chat',
386 |         system_template='<|system|>\n{system_message}',
387 |         # note: The new system prompt was not used here to avoid changes in benchmark performance.
388 |         # system_message='我是书生·万象,英文名是InternVL,是由上海人工智能实验室、清华大学及多家合作单位联合开发的多模态大语言模型。',
389 |         system_message='你是由上海人工智能实验室联合商汤科技开发的书生多模态大模型,英文名叫InternVL, 是一个有用无害的人工智能助手。',
390 |         roles=('<|user|>\n', '<|assistant|>\n'),
391 |         sep_style=SeparatorStyle.MPT,
392 |         sep='<|end|>',
393 |         stop_token_ids=[
394 |             2,
395 |             32000,
396 |             32007
397 |         ]
398 |     )
399 | )
400 | 
--------------------------------------------------------------------------------
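Usage note (added for clarity; not a file in the repository): templates are fetched with get_conv_template, which returns a copy so the global registry is never mutated, and the prompt string is then assembled by get_prompt. Below is a minimal sketch, assuming the repository root is on PYTHONPATH so that models.internvl.conversation is importable; the question text is only a placeholder.

# Minimal usage sketch; the import path and question text are assumptions, as noted above.
from models.internvl.conversation import get_conv_template

conv = get_conv_template('internlm2-chat')    # copy() keeps the registered template untouched
conv.append_message(conv.roles[0], 'Describe the key frame in one sentence.')
conv.append_message(conv.roles[1], None)      # None marks the slot the model should fill
prompt = conv.get_prompt()
# With SeparatorStyle.MPT this yields:
# <|im_start|>system\n{system message}<|im_end|><|im_start|>user\nDescribe the key frame in one sentence.<|im_end|><|im_start|>assistant\n

The template's stop_str and stop_token_ids fields are typically forwarded to the model's generate call as stopping criteria.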