├── README.md
├── asset
│   ├── data.png
│   ├── icon.png
│   ├── main_fig.png
│   └── performance.png
├── evaluation
│   ├── eval.py
│   ├── example_result.jsonl
│   └── gpt_judge.py
└── llava_examples
    ├── README.md
    ├── evaluation_utils.py
    ├── llava-ov-7b-videoniah-result.jsonl
    ├── model_video_niah.py
    └── niah_eval.sh
/README.md:
--------------------------------------------------------------------------------
1 | # Needle In A Video Haystack: A Scalable Synthetic Framework for Benchmarking Video MLLMs (VideoNIAH)
2 |
16 | [[🍎 Project Page](https://videoniah.github.io/)] [[📖 arXiv Paper](https://arxiv.org/abs/2406.09367)] [[📊 Dataset](https://huggingface.co/datasets/Joez1717/VNBench)]
17 |
18 | ---
19 |
20 | ## 🔥 News
21 | * **`2024.06.13`** 🌟 We are very proud to launch VideoNIAH, a scalable synthetic method for benchmarking video MLLMs, and VNBench, a comprehensive video synthetic benchmark!
22 | * **`2025.03.07`** 🌟 VideoNIAH has been accepted as a poster presentation at ICLR 2025.
23 |
24 |
25 | ## 👀 Overview
26 |
27 | We propose **VideoNIAH (Video Needle In A Haystack)**, a benchmark construction framework based on synthetic video generation.
28 | **VideoNIAH** decouples test video content from its query-response pairs by inserting unrelated image/text "needles" into source videos. It generates annotations solely from these needles, ensuring diverse video sources and varied query-response pairs.
29 | In addition, by inserting multiple needles, **VideoNIAH** rigorously evaluates a model's temporal understanding capabilities.
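As a purely illustrative sketch of this idea (not the actual generation pipeline), inserting a needle and deriving the annotation from the needle alone might look like:

```python
import random

# Illustrative only: insert an unrelated "needle" frame into a haystack video
# and derive the annotation from the needle itself, independent of video content.
def insert_needle(haystack_frames: list, needle_frame) -> tuple:
    position = random.randrange(len(haystack_frames) + 1)
    frames = haystack_frames[:position] + [needle_frame] + haystack_frames[position:]
    annotation = {"needle_position": position, "num_frames": len(frames)}
    return frames, annotation
```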
30 |
31 |
32 |
33 |
34 | We use VideoNIAH to compile a video benchmark, **VNBench**, which includes **retrieval**, **ordering**, and **counting** tasks and contains 1,350 samples in total.
35 | **VNBench** efficiently evaluates a video model's fine-grained understanding and spatio-temporal modeling abilities, and it also supports long-context evaluation.
36 |
37 |
38 |
39 |
40 |
41 | **VideoNIAH** is a **simple** yet highly **scalable** benchmark construction framework, and we believe it will inspire future work on video benchmarks!
42 |
43 |
44 | ## 🔍 Dataset
45 | Download the raw videos of VNBench from this [Google Drive link](https://drive.google.com/file/d/1KOUzy07viQzpmpcBqydUA043VQZ4nmRv/view?usp=sharing).
46 | Download the VNBench annotations from [Hugging Face](https://huggingface.co/datasets/Joez1717/VNBench/tree/main).
47 | **License**:
48 | ```
49 | VNBench is only used for academic research. Commercial use in any form is prohibited.
50 | The copyright of all videos belongs to the video owners.
51 | ```
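After downloading, you can inspect the annotations directly. A minimal sketch, assuming the annotation file is `VNBench-main-4try.json` (the filename used in `llava_examples/niah_eval.sh`) and using the field names consumed by `llava_examples/model_video_niah.py`:

```python
import json

# Load the VNBench annotation file (filename as used in llava_examples/niah_eval.sh).
with open("VNBench-main-4try.json") as f:
    annotations = json.load(f)

# Each entry carries the video path, the question, four options, the answer,
# the task type, and the try index (fields as read by model_video_niah.py).
sample = annotations[0]
print(sample["video"], sample["type"], sample["try"])
print(sample["question"], sample["options"], sample["gt"])
```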
52 |
53 |
54 | ## 🔮 Evaluation Pipeline
55 | **Prompt**:
56 |
57 | The common prompt used in our evaluation follows this format:
58 |
59 | ```
60 | [QUESTION]
61 | A. [OPTION 1]
62 | B. [OPTION 2]
63 | C. [OPTION 3]
64 | D. [OPTION 4]
65 | Answer with the option's letter from the given choices directly.
66 | ```
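For reference, this prompt can be assembled from the `question` and `options` fields of each annotation, mirroring what `llava_examples/model_video_niah.py` does:

```python
# Build the evaluation prompt from one VNBench annotation
# (field names follow llava_examples/model_video_niah.py).
def build_prompt(sample: dict) -> str:
    q, opts = sample["question"], sample["options"]
    return (
        f"{q}\n"
        f"A. {opts[0]}\nB. {opts[1]}\nC. {opts[2]}\nD. {opts[3]}\n"
        "Answer with the option's letter from the given choices directly."
    )
```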
67 |
68 |
69 |
70 | **Evaluation**:
71 |
72 | We recommend saving your inference results in the same format as [example_result.jsonl](./evaluation/example_result.jsonl). Once the model responses are prepared in this format, run our evaluation script [eval.py](./evaluation/eval.py) to obtain the accuracy scores.
73 |
74 |
75 | ```bash
76 | python eval.py \
77 |     --path $RESULTS_FILE
78 | ```
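Each line of the results file is one JSON object. If your inference loop collects predictions in memory, a minimal sketch for writing them in this format (keys taken from `example_result.jsonl`; the `results` list and output filename are placeholders) is:

```python
import json

# Hypothetical list of per-question records from your own inference loop;
# the keys match evaluation/example_result.jsonl.
results = [
    {"question_id": "3261112327_ret_insert1_0", "pred": "D", "gt": "D",
     "type": "ret_insert1", "try": 0, "prompt": "INFERENCE PROMPT"},
]

with open("your_model_result.jsonl", "w") as f:
    for record in results:
        f.write(json.dumps(record) + "\n")
```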
79 |
80 | If you want to use GPT-3.5 as the judge, please use the script we provide, [gpt_judge.py](./evaluation/gpt_judge.py), after filling in your OpenAI API base and key at the top of the file.
81 | ```bash
82 | python gpt_judge.py \
83 |     --input_file $INPUT_FILE \
84 |     --output_file $OUTPUT_FILE
85 | ```
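The judged file written by `gpt_judge.py` can then be scored with `eval.py`, which automatically uses the `gpt_judge` field of each record when it is present instead of matching the option letter.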
86 |
87 | For convenience, we provide [evaluation code examples](./llava_examples) for the [LLaVA-NeXT](https://github.com/LLaVA-VL/LLaVA-NeXT) series of models.
88 |
89 | ## 📈 Experimental Results
90 | - **Evaluation results of different Video MLLMs.**
91 | Please visit the [leaderboard](https://videoniah.github.io/) for more details.
92 |
93 |
94 |
95 |
96 | ## Citation
97 |
98 | If you find our work helpful for your research, please consider citing it.
99 |
100 | ```bibtex
101 | @article{zhao2024videoniah,
102 |   title={Needle In A Video Haystack: A Scalable Synthetic Framework for Benchmarking Video MLLMs},
103 |   author={Zhao, Zijia and Lu, Haoyu and Huo, Yuqi and Du, Yifan and Yue, Tongtian and Guo, Longteng and Wang, Bingning and Chen, Weipeng and Liu, Jing},
104 |   journal={arXiv preprint arXiv:2406.09367},
105 |   year={2024}
106 | }
107 | ```
108 |
--------------------------------------------------------------------------------
/asset/data.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/joez17/VideoNIAH/8193a3b20ac1c2fcac9c3e8c1550095628f7628f/asset/data.png
--------------------------------------------------------------------------------
/asset/icon.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/joez17/VideoNIAH/8193a3b20ac1c2fcac9c3e8c1550095628f7628f/asset/icon.png
--------------------------------------------------------------------------------
/asset/main_fig.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/joez17/VideoNIAH/8193a3b20ac1c2fcac9c3e8c1550095628f7628f/asset/main_fig.png
--------------------------------------------------------------------------------
/asset/performance.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/joez17/VideoNIAH/8193a3b20ac1c2fcac9c3e8c1550095628f7628f/asset/performance.png
--------------------------------------------------------------------------------
/evaluation/eval.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import json
3 | import os
4 | parser = argparse.ArgumentParser()
5 | parser.add_argument("--path", type=str, default="evaluation/example_result.jsonl")
6 | args = parser.parse_args()
7 | annos = [json.loads(q) for q in open(os.path.expanduser(args.path), "r")]
8 | res = {}
9 | for anno in annos:
10 |     name = anno['question_id'][:-2]  # group the 4 tries of one sample by stripping the trailing "_<try>"
11 |     label = anno['type']
12 |     if anno['pred'] is None:
13 |         continue
14 |     if not label in res:
15 |         res[label] = []
16 |     if anno['gt'] in [0, 1, 2, 3]:
17 |         anno['gt'] = chr(ord('A') + anno['gt'])  # convert a numeric ground truth to its option letter
18 |     anno['pred'] = anno['pred'].split('.')[0]
19 |     dic = {
20 |         'name': name,
21 |         'gt': anno['gt'],
22 |         'pred': anno['pred'],
23 |     }
24 |     if "gpt_judge" in anno:
25 |         dic['judge'] = anno['gpt_judge'][0]
26 |     res[label].append(dic)
27 | 
28 | RES = {}
29 | result = {}
30 | sorted_items = sorted(res.items(), key=lambda x: x[0])
31 | for k, vv in sorted_items:
32 |     acc = {}
33 |     for v in vv:
34 |         name = v['name']
35 |         if not name in acc:
36 |             acc[name] = 0
37 |         if 'judge' in v:
38 |             acc[name] += (v['judge']=='1')
39 |         else:
40 |             pred = v['pred']
41 |             if 'A' in pred:
42 |                 pred = 'A'
43 |             elif 'B' in pred:
44 |                 pred = 'B'
45 |             elif 'C' in pred:
46 |                 pred = 'C'
47 |             elif 'D' in pred:
48 |                 pred = 'D'
49 |             acc[name] += (v['gt']==pred)
50 |     accuracy = 0
51 |     for n, ac in acc.items():
52 |         if ac==4:  # a sample counts as correct only if all 4 tries are answered correctly
53 |             accuracy += 1
54 |     st = f'true: {accuracy}, total: {len(acc)}, acc: {accuracy/len(acc)}'
55 |     RES[k] = st
56 |     result[k] = accuracy/len(acc)
57 | RES_list = []
58 | for k, v in result.items():
59 |     print(k)
60 |     print(RES[k])
61 |     RES_list.append(result[k])
62 | print('Overall: ', sum(RES_list)/len(RES_list))
63 |
--------------------------------------------------------------------------------
/evaluation/example_result.jsonl:
--------------------------------------------------------------------------------
1 | {"question_id": "3261112327_ret_insert1_0", "pred": "D", "gt": "D", "type": "ret_insert1", "try": 0, "prompt": 'INFERENCE PROMPT'}
2 | {"question_id": "3261112327_ret_insert1_1", "pred": "B", "gt": "B", "type": "ret_insert1", "try": 1, "prompt": 'INFERENCE PROMPT'}
3 | {"question_id": "3261112327_ret_insert1_2", "pred": "B", "gt": "B", "type": "ret_insert1", "try": 2, "prompt": 'INFERENCE PROMPT'}
4 | {"question_id": "3261112327_ret_insert1_3", "pred": "A", "gt": "A", "type": "ret_insert1", "try": 3, "prompt": 'INFERENCE PROMPT'}
5 | {"question_id": "3261112327_ret_insert2_0", "pred": "D", "gt": "D", "type": "ret_insert2", "try": 0, "prompt": 'INFERENCE PROMPT'}
6 | {"question_id": "3261112327_ret_insert2_1", "pred": "B", "gt": "B", "type": "ret_insert2", "try": 1, "prompt": 'INFERENCE PROMPT'}
7 | {"question_id": "3261112327_ret_insert2_2", "pred": "B", "gt": "B", "type": "ret_insert2", "try": 2, "prompt": 'INFERENCE PROMPT'}
8 | {"question_id": "3261112327_ret_insert2_3", "pred": "A", "gt": "A", "type": "ret_insert2", "try": 3, "prompt": 'INFERENCE PROMPT'}
9 | {"question_id": "3261112327_ret_edit1_0", "pred": "D", "gt": "B", "type": "ret_edit1", "try": 0, "prompt": 'INFERENCE PROMPT'}
10 | {"question_id": "3261112327_ret_edit1_1", "pred": "A", "gt": "B", "type": "ret_edit1", "try": 1, "prompt": 'INFERENCE PROMPT'}
11 | {"question_id": "3261112327_ret_edit1_2", "pred": "C", "gt": "B", "type": "ret_edit1", "try": 2, "prompt": 'INFERENCE PROMPT'}
12 | {"question_id": "3261112327_ret_edit1_3", "pred": "A", "gt": "A", "type": "ret_edit1", "try": 3, "prompt": 'INFERENCE PROMPT'}
--------------------------------------------------------------------------------
/evaluation/gpt_judge.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import json
3 | import os
4 | import time
5 | from multiprocessing import Pool
6 | 
7 | import openai
8 | import tqdm
9 | 
10 | # Fill in your OpenAI-compatible API base URL and key before running.
11 | openai.api_base = ""
12 | openai.api_key = ''
13 | gpt_model = 'gpt-3.5-turbo'
14 | 
15 | parser = argparse.ArgumentParser()
16 | parser.add_argument("--input_file", type=str)
17 | parser.add_argument("--output_file", type=str)
18 | args = parser.parse_args()
19 | 
20 | system_prompt = '''
21 | You are an intelligent chatbot designed for evaluating the correctness of generative outputs for question-answer pairs. Your task is to compare the predicted answer with the correct answer and determine if they match meaningfully. Here's how you can accomplish the task:
22 | ------
23 | ##INSTRUCTIONS:
24 | - Focus on the meaningful match between the predicted answer and the correct answer.
25 | - Consider synonyms or paraphrases as valid matches.
26 | - Evaluate the correctness of the prediction compared to the answer.
27 | '''
28 | 
29 | 
30 | def judge(ele):
31 |     template = '''Please evaluate the following video-based question-answer pair:
32 | Question: {}
33 | Correct Answer: {}
34 | Predicted Answer: {}
35 | If the predicted answer expresses the same meaning as the correct answer, please output 1; otherwise, output 0.
36 | DO NOT PROVIDE ANY OTHER OUTPUT TEXT OR EXPLANATION. Only provide 0 or 1.
37 | '''
38 |     gpt_judge = []
39 |     prompt = template.format(ele['prompt'].replace("Answer with the option's letter from the given choices directly.", ""), ele['gt'], ele['pred'])
40 |     max_retries = 20
41 |     retry_delay = 5
42 |     retries = 0
43 |     output = None
44 |     while output is None and retries < max_retries:
45 |         try:
46 |             messages = [
47 |                 {"role": "system", "content": system_prompt},
48 |                 {"role": "user", "content": prompt},
49 |             ]
50 |             output = openai.ChatCompletion.create(
51 |                 model=gpt_model,
52 |                 max_tokens=10,
53 |                 temperature=0,
54 |                 messages=messages)
55 |             if output is not None:
56 |                 output = output['choices'][0]['message']['content']
57 |             else:
58 |                 retries += 1
59 |                 print(f"Attempt {retries}: Failed to get response, retrying after {retry_delay} seconds...")
60 |                 time.sleep(retry_delay)
61 |         except Exception as e:
62 |             print(f"An error occurred: {e}")
63 |             retries += 1
64 |             print(f"Attempt {retries}: Exception encountered, retrying after {retry_delay} seconds...")
65 |             time.sleep(retry_delay)
66 |     if output is None:
67 |         print("Failed to get a valid response from the API after maximum retries.")
68 |         gpt_judge.append("No response")
69 |     else:
70 |         gpt_judge.append(output)
71 |         print(output)
72 |     ele['gpt_judge'] = gpt_judge
73 |     return ele
74 | 
75 | 
76 | if __name__ == "__main__":
77 |     output_file_path = args.output_file
78 |     output_file = open(output_file_path, 'a')
79 |     gpt_input = [json.loads(q) for q in open(os.path.expanduser(args.input_file), "r")]
80 |     with Pool(150) as p:
81 |         result = list(tqdm.tqdm(p.imap(judge, gpt_input), total=len(gpt_input)))
82 |     for ele in result:
83 |         output_file.write(json.dumps(ele)+"\n")
84 |     output_file.close()
--------------------------------------------------------------------------------
/llava_examples/README.md:
--------------------------------------------------------------------------------
1 | # Instructions for Adding and Running Evaluation Scripts in LLaVA-NeXT
2 |
3 | ## File Placement
4 |
5 | You need to place three code files in the correct locations within the **LLaVA-NeXT** repository:
6 |
7 | | File | Destination Path |
8 | |------|-----------------|
9 | | `model_video_niah.py` | `LLaVA-NeXT/llava/eval/model_video_niah.py` |
10 | | `evaluation_utils.py` | `LLaVA-NeXT/scripts/video/eval/evaluation_utils.py` |
11 | | `niah_eval.sh` | `LLaVA-NeXT/scripts/video/eval/niah_eval.sh` |
12 |
13 | ## Modify `niah_eval.sh`
14 |
15 | After placing the files, edit the path variables at the top of `niah_eval.sh` (`ROOT_DIR`, `CKPT`, `ANNOFILE`, `VIDEODIR`, `OUTPUT`) so that they point to:
16 |
17 | - Your **code directory**
18 | - The **model checkpoint**
19 | - The **video data folder**
20 | - The **annotation file path**
21 |
22 | ## Run the Evaluation Script
23 |
24 | Execute the following command to run the evaluation:
25 |
26 | ```bash
27 | bash scripts/video/eval/niah_eval.sh
28 | ```
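The script launches one inference chunk per visible GPU, merges the per-chunk prediction files into `merge.jsonl` under the output directory, and then prints per-task and overall accuracies via `evaluation_utils.py`.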
--------------------------------------------------------------------------------
/llava_examples/evaluation_utils.py:
--------------------------------------------------------------------------------
1 | # Prepare your results as a list named `data`.
2 | # Each element of `data` is a dict containing video_path, task_type, and judge_result
3 | # (judge_result is 1 for correct and 0 for incorrect; if it is missing, it is derived
4 | # from the `pred` and `gt` fields). For the official VNBench, len(data) == 5400.
5 | import argparse
6 | import json
7 | import os
8 | 
9 | 
10 | def get_detail_result(data):
11 |     task_result = {}
12 |     processed_data = {}
13 |     for d in data:
14 |         if not 'judge_result' in d:
15 |             d['judge_result'] = 1 if d['pred'][0]==d['gt'] else 0
16 |         if d['video_path'] in task_result:
17 |             task_result[d['video_path']] += d['judge_result']
18 |         else:
19 |             task_result[d['video_path']] = d['judge_result']
20 |             processed_data[d['video_path']] = {'video_path': d['video_path'],
21 |                                                'task_type': d['task_type']}
22 |     # A sample counts as correct only if all 4 tries on the same video are correct.
23 |     for k, v in task_result.items():
24 |         if v==4:
25 |             processed_data[k]['result'] = 1
26 |         else:
27 |             processed_data[k]['result'] = 0
28 |     result = list(processed_data.values())
29 |     assert 4*len(result) == len(data)
30 |     acc = {}
31 |     sample_num = {}
32 |     for res in result:
33 |         if res['task_type'] in acc:
34 |             acc[res['task_type']] += res['result']
35 |         else:
36 |             acc[res['task_type']] = res['result']
37 |         if res['task_type'] in sample_num:
38 |             sample_num[res['task_type']] += 1
39 |         else:
40 |             sample_num[res['task_type']] = 1
41 |     res = {}
42 |     res['ret'] = 0.0
43 |     res['ord'] = 0.0
44 |     res['cnt'] = 0.0
45 |     for k, v in acc.items():
46 |         acc[k] = v/sample_num[k]
47 |         if 'ord' in k:
48 |             res['ord'] += acc[k]
49 |         if 'cnt' in k:
50 |             res['cnt'] += acc[k]
51 |         if 'ret' in k:
52 |             res['ret'] += acc[k]
53 |     tmp = acc['ret_insert1']
54 |     acc['ret_insert1'] = acc['ret_insert2']
55 |     acc['ret_insert2'] = tmp
56 |     acc = {key: acc[key] for key in sorted(acc.keys())}
57 | 
58 |     # Each coarse task type (retrieval, ordering, counting) averages its 3 subtasks.
59 |     res['ord'] = res['ord']/3
60 |     res['cnt'] = res['cnt']/3
61 |     res['ret'] = res['ret']/3
62 |     res = {key: res[key] for key in sorted(res.keys())}
63 |     res['Overall'] = (res['ord']+res['cnt']+res['ret'])/3
64 |     acc.update(res)
65 |     return acc
66 | 
67 | 
68 | def parse_args():
69 |     """
70 |     Parse command-line arguments.
71 |     """
72 |     parser = argparse.ArgumentParser()
73 |     parser.add_argument("--result_path", help="result path, jsonl format", required=True)
74 |     return parser.parse_args()
75 | 
76 | 
77 | if __name__ == "__main__":
78 |     args = parse_args()
79 |     data = [json.loads(q) for q in open(os.path.expanduser(args.result_path), "r")]
80 |     print(get_detail_result(data))
81 | 
--------------------------------------------------------------------------------
/llava_examples/model_video_niah.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import copy
3 | import json
4 | import math
5 | import os
6 | import warnings
7 | 
8 | import cv2
9 | import numpy as np
10 | import requests
11 | import torch
12 | from PIL import Image
13 | from decord import VideoReader, cpu
14 | from operator import attrgetter
15 | from tqdm import tqdm
16 | from transformers import AutoConfig
17 | 
18 | from llava.model.builder import load_pretrained_model
19 | from llava.mm_utils import get_model_name_from_path, process_images, tokenizer_image_token
20 | from llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN, IGNORE_INDEX
21 | from llava.conversation import conv_templates, SeparatorStyle
22 | 
23 | 
24 | def split_list(lst, n):
25 |     """Split a list into n (roughly) equal-sized chunks."""
26 |     chunk_size = math.ceil(len(lst) / n)  # ceiling division so every element is covered
27 |     return [lst[i : i + chunk_size] for i in range(0, len(lst), chunk_size)]
28 | 
29 | 
30 | def get_chunk(lst, n, k):
31 |     chunks = split_list(lst, n)
32 |     return chunks[k]
33 | 
34 | 
35 | def parse_args():
36 |     """
37 |     Parse command-line arguments.
38 |     """
39 |     parser = argparse.ArgumentParser()
40 | 
41 |     # Define the command-line arguments
42 |     parser.add_argument("--video_dir", help="Directory containing video files.", required=True)
43 |     parser.add_argument('--question_fp', help='Path to the question file.', required=True)
44 |     parser.add_argument("--output_dir", help="Directory to save the model results JSON.", required=True)
45 |     parser.add_argument("--output_name", help="Name of the file for storing results JSON.", required=True)
46 |     parser.add_argument("--model-path", type=str, default="facebook/opt-350m")
47 |     parser.add_argument("--num-chunks", type=int, default=1)
48 |     parser.add_argument("--chunk-idx", type=int, default=0)
49 |     parser.add_argument("--frames_num", type=int, default=4)
50 |     return parser.parse_args()
51 | 
52 | 
53 | # Function to extract frames from a video
54 | def load_video(video_path, max_frames_num):
55 |     if type(video_path) == str:
56 |         vr = VideoReader(video_path, ctx=cpu(0))
57 |     else:
58 |         vr = VideoReader(video_path[0], ctx=cpu(0))
59 |     total_frame_num = len(vr)
60 |     # Sample max_frames_num frames uniformly over the whole video.
61 |     uniform_sampled_frames = np.linspace(0, total_frame_num - 1, max_frames_num, dtype=int)
62 |     frame_idx = uniform_sampled_frames.tolist()
63 |     spare_frames = vr.get_batch(frame_idx).asnumpy()
64 |     return spare_frames  # (frames, height, width, channels)
65 | 
66 | 
67 | def run_inference(args):
68 |     """
69 |     Run inference on the VNBench question file with a LLaVA-OneVision model.
70 | 
71 |     Args:
72 |         args: Command-line arguments.
73 |     """
74 |     warnings.filterwarnings("ignore")
75 |     # Load the OneVision model
76 |     pretrained = args.model_path
77 |     model_name = "llava_qwen"
78 |     device = "cuda"
79 |     device_map = "auto"
80 |     tokenizer, model, image_processor, max_length = load_pretrained_model(pretrained, None, model_name, device_map=device_map, attn_implementation="sdpa")
81 | 
82 |     model.eval()
83 | 
84 |     if '.jsonl' in args.question_fp:
85 |         question_dict = [json.loads(q) for q in open(os.path.expanduser(args.question_fp), "r")]
86 |     else:
87 |         question_dict = json.load(open(args.question_fp))
88 |     question_dict = get_chunk(question_dict, args.num_chunks, args.chunk_idx)
89 | 
90 |     # Create the output directory if it doesn't exist
91 |     if not os.path.exists(args.output_dir):
92 |         os.makedirs(args.output_dir)
93 | 
94 |     video_formats = [".mp4", ".avi", ".mov", ".mkv"]
95 |     if args.num_chunks > 1:
96 |         output_name = f"{args.num_chunks}_{args.chunk_idx}"
97 |     else:
98 |         output_name = args.output_name
99 |     answers_file = os.path.join(args.output_dir, f"{output_name}.json")
100 |     ans_file = open(answers_file, "w")
101 | 
102 |     for q_dict in tqdm(question_dict):
103 |         q_uid = q_dict['video'].split('/')[-1].replace('.mp4', '')
104 |         if not os.path.exists(q_dict["video"]):
105 |             video_path = os.path.join(args.video_dir, q_dict["video"])
106 |         else:
107 |             video_path = q_dict["video"]
108 | 
109 |         # Check if the video exists
110 |         if os.path.exists(video_path):
111 |             video_frames = load_video(video_path, args.frames_num)
112 |             # print(video_frames.shape) # (16, 1024, 576, 3)
113 |             image_tensors = []
114 |             frames = image_processor.preprocess(video_frames, return_tensors="pt")["pixel_values"].half().cuda()
115 |             image_tensors.append(frames)
116 | 
117 |             # Build the multiple-choice prompt from the question and its four options.
118 |             question0 = q_dict['question']
119 |             options = q_dict['options']
120 |             question = f"{question0}\nA. {options[0]}\nB. {options[1]}\nC. {options[2]}\nD. {options[3]}\nAnswer with the option's letter from the given choices directly."
121 | 
122 |             # Prepare conversation input
123 |             conv_template = "qwen_1_5"
124 |             question = f"{DEFAULT_IMAGE_TOKEN}\n{question}"
125 | 
126 |             conv = copy.deepcopy(conv_templates[conv_template])
127 |             conv.append_message(conv.roles[0], question)
128 |             conv.append_message(conv.roles[1], None)
129 |             prompt = conv.get_prompt()
130 | 
131 |             input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt").unsqueeze(0).to(device)
132 |             image_sizes = [frame.size for frame in video_frames]
133 | 
134 |             with torch.inference_mode():
135 |                 output_ids = model.generate(
136 |                     input_ids,
137 |                     images=image_tensors,
138 |                     image_sizes=image_sizes,
139 |                     do_sample=False,
140 |                     temperature=0,
141 |                     max_new_tokens=5,
142 |                     modalities=["video"],
143 |                 )
144 |             outputs = tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0]
145 |             outputs = outputs.strip()
146 | 
147 |             # Convert the ground-truth option text to its letter and record the prediction.
148 |             gt = chr(options.index(q_dict['gt']) + ord('A'))
149 |             inf_res = {"video_path": q_dict['video'],
150 |                        "prompt": prompt,
151 |                        "pred": outputs,
152 |                        "gt": gt,
153 |                        "task_type": q_dict['type'],
154 |                        "try": q_dict['try'],
155 |                        "model_id": model_name}
156 |             ans_file.write(json.dumps(inf_res) + "\n")
157 |             ans_file.flush()
158 | 
159 |     ans_file.close()
160 | 
161 | 
162 | if __name__ == "__main__":
163 |     args = parse_args()
164 |     run_inference(args)
--------------------------------------------------------------------------------
/llava_examples/niah_eval.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | ROOT_DIR="YOURPATH/LLaVA-NeXT"
3 | CKPT="YOURPATH/ckpt/llava-onevision-qwen2-0.5b-ov"
4 | ANNOFILE="YOURPATH/VNBench/VNBench-main-4try.json"
5 | VIDEODIR="YOURPATH/VNBench"
6 | OUTPUT="YOURPATH/output"
7 | if [ ! -e "$ROOT_DIR" ]; then
8 |     echo "The root dir does not exist. Exiting the script."
9 |     exit 1
10 | fi
11 | 
12 | cd "$ROOT_DIR"
13 | export PYTHONPATH="./:$PYTHONPATH"
14 | export PYTHONWARNINGS=ignore
15 | export TOKENIZERS_PARALLELISM=false
16 | CUDA_VISIBLE_DEVICES='0,1,2,3,4,5,6,7'
17 | gpu_list="${CUDA_VISIBLE_DEVICES:-0}"
18 | IFS=',' read -ra GPULIST <<< "$gpu_list"
19 | 
20 | CHUNKS=${#GPULIST[@]}
21 | echo "Using $CHUNKS GPUs"
22 | FRAMES=64
23 | mkdir -p $OUTPUT/video_niah_${FRAMES}
24 | 
25 | # Launch one chunk of the annotation file on each GPU.
26 | for IDX in $(seq 0 $((CHUNKS-1))); do
27 |     GPU_ID=${GPULIST[$IDX]}
28 |     echo "Running on GPU $GPU_ID"
29 |     CUDA_VISIBLE_DEVICES=$GPU_ID python3 llava/eval/model_video_niah.py \
30 |         --model-path $CKPT \
31 |         --video_dir $VIDEODIR \
32 |         --question_fp $ANNOFILE \
33 |         --output_dir $OUTPUT/video_niah_${FRAMES} \
34 |         --output_name pred \
35 |         --num-chunks $CHUNKS \
36 |         --chunk-idx $IDX \
37 |         --frames_num $FRAMES &
38 | done
39 | wait
40 | 
41 | output_file=$OUTPUT/video_niah_${FRAMES}/merge.jsonl
42 | 
43 | # Clear out the output file if it exists.
44 | > "$output_file"
45 | 
46 | # Loop through the chunk indices and concatenate each prediction file.
47 | for IDX in $(seq 0 $((CHUNKS-1))); do
48 |     cat $OUTPUT/video_niah_${FRAMES}/${CHUNKS}_${IDX}.json >> "$output_file"
49 | done
50 | 
51 | python ./scripts/video/eval/evaluation_utils.py --result_path $output_file
--------------------------------------------------------------------------------