├── LICENSE ├── README.md ├── annotation_pipeline ├── 1_scenedetect_and_keyframes.py ├── 2_caption_keyframe_llava.py ├── 3_dense_video_description.py ├── 4_generate_qa.py └── run_pipeline.sh ├── docs └── images │ ├── IVAL_logo.png │ ├── MBZUAI_logo.png │ ├── MVBench_quantitative.png │ ├── Oryx_logo.png │ ├── VCGBench_quantitative.png │ ├── VCGDiverse_quantitative.png │ ├── block_diagram.png │ ├── demo_vcg+_full_part1.jpg │ ├── demo_vcg+_full_part2.jpg │ ├── demo_vcg+_main.png │ ├── intro_radar_plot.png │ ├── vcg120k_block_diagram.png │ ├── vcgbench_block_diag.png │ ├── videogpt_plus_face.jpeg │ └── zero_shot_quantitative.png ├── eval ├── README.md ├── merge.py ├── mvbench │ ├── evaluation │ │ └── evaluate_mvbench.py │ └── inference │ │ ├── ddp.py │ │ └── infer.py ├── vcgbench │ ├── gpt_evaluation │ │ ├── evaluate_benchmark_1_correctness.py │ │ ├── evaluate_benchmark_2_detailed_orientation.py │ │ ├── evaluate_benchmark_3_context.py │ │ ├── evaluate_benchmark_4_temporal.py │ │ ├── evaluate_benchmark_5_consistency.py │ │ └── vcgbench_evaluate.sh │ └── inference │ │ ├── ddp.py │ │ ├── infer_consistency.py │ │ ├── infer_general.py │ │ └── run_ddp_inference.sh ├── vcgbench_diverse │ ├── gpt_evaluation │ │ ├── 1_correctness_of_information.py │ │ ├── 2_detailed_orientation.py │ │ ├── 3_contextual_information.py │ │ ├── 4_temporal_information.py │ │ ├── 5_consistency.py │ │ ├── dense_captioning_spatial_and_reasoning_scores.py │ │ └── vcgbench_diverse_evaluate.sh │ ├── inference │ │ ├── ddp.py │ │ ├── infer.py │ │ └── run_ddp_inference.sh │ └── qa_generation │ │ └── generate_vcgbench_diverse_qa.py └── video_encoding.py ├── requirements.txt ├── scripts ├── README.md ├── finetune_dual_encoder.sh ├── pretrain_projector_image_encoder.sh ├── pretrain_projector_video_encoder.sh ├── zero.json ├── zero2.json ├── zero3.json └── zero3_offload.json └── videogpt_plus ├── __init__.py ├── config ├── __init__.py └── dataset_config.py ├── constants.py ├── conversation.py ├── mm_utils.py ├── model ├── __init__.py ├── arch.py ├── builder.py ├── dataloader.py ├── internvideo │ ├── build_internvideo.py │ ├── config.py │ ├── easydict.py │ ├── flash_attention_class.py │ ├── internvideo2.py │ ├── internvideo2_stage2_config_vision.py │ ├── pos_embed.py │ └── utils.py ├── language_model │ └── phi3.py ├── multimodal_encoder │ ├── builder.py │ ├── clip_encoder.py │ └── processor.py └── multimodal_projector │ └── builder.py └── train ├── pretrain.py ├── train.py └── trainer.py /annotation_pipeline/1_scenedetect_and_keyframes.py: -------------------------------------------------------------------------------- 1 | """ 2 | Semi-automatic Video Annotation Pipeline - Step # 1: Detect scenes and extract keyframes 3 | 4 | Copyright 2024 MBZUAI ORYX 5 | 6 | Licensed under the Apache License, Version 2.0 (the "License"); 7 | you may not use this file except in compliance with the License. 8 | You may obtain a copy of the License at 9 | 10 | http://www.apache.org/licenses/LICENSE-2.0 11 | 12 | Unless required by applicable law or agreed to in writing, software 13 | distributed under the License is distributed on an "AS IS" BASIS, 14 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 15 | See the License for the specific language governing permissions and 16 | limitations under the License. 
17 | """ 18 | 19 | import argparse 20 | from Katna.video import Video 21 | from Katna.writer import KeyFrameDiskWriter 22 | import os 23 | from scenedetect import detect, ContentDetector, split_video_ffmpeg, open_video, SceneManager 24 | import warnings 25 | import json 26 | from tqdm import tqdm 27 | import sys 28 | import contextlib 29 | 30 | # Suppress FutureWarnings 31 | warnings.simplefilter(action='ignore', category=FutureWarning) 32 | 33 | 34 | def parse_args(): 35 | """ 36 | Command-line argument parser. 37 | """ 38 | parser = argparse.ArgumentParser(description="Detect scenes and extract keyframes.") 39 | 40 | parser.add_argument("--video_dir", required=True, help="Directory containing ActivityNet videos.") 41 | 42 | parser.add_argument("--ann_video_ids_file", required=True, 43 | help="Path to the unique video ids JSON file (e.g. path to unique_video_ids.json).") 44 | parser.add_argument("--gt_caption_file", required=True, 45 | help="Path to the ground truth captions file (e.g. path to activitynet_gt_captions_train.json).") 46 | 47 | parser.add_argument("--scene_output_dir", required=False, help="Path to save the scene files.", default="scenes") 48 | parser.add_argument("--frames_output_dir", required=False, help="Path to save the keyframes.", default="key_frames") 49 | parser.add_argument("--num_keyframes", type=int, default=1, help="Number of keyframes to extract per scene.") 50 | 51 | return parser.parse_args() 52 | 53 | 54 | @contextlib.contextmanager 55 | def suppress_output(): 56 | with open(os.devnull, "w") as devnull: 57 | old_stdout = sys.stdout 58 | sys.stdout = devnull 59 | try: 60 | yield 61 | finally: 62 | sys.stdout = old_stdout 63 | 64 | 65 | def get_keyframes(video_path, num_keyframes, output_dir): 66 | """ 67 | Extracts keyframes using Katna from the video and returns their file paths, 68 | operating within a temporary directory. 
69 | """ 70 | # Create a temporary directory for extracted frames 71 | # Initialize video module and disk writer 72 | vd = Video() 73 | diskwriter = KeyFrameDiskWriter(location=output_dir) 74 | 75 | # Suppress print output during keyframe extraction 76 | with suppress_output(): 77 | vd.extract_video_keyframes(no_of_frames=num_keyframes, file_path=video_path, writer=diskwriter) 78 | 79 | return None 80 | 81 | 82 | def get_scenes(video_path, output_dir): 83 | video = open_video(video_path) 84 | scene_manager = SceneManager() 85 | scene_manager.add_detector(ContentDetector()) 86 | scene_manager.detect_scenes(video) 87 | # If `start_in_scene` is True, len(scene_list) will always be >= 1 88 | scene_list = scene_manager.get_scene_list(start_in_scene=True) 89 | split_video_ffmpeg(video_path, scene_list, output_dir) 90 | 91 | return scene_list 92 | 93 | 94 | def main(): 95 | args = parse_args() 96 | os.makedirs(args.scene_output_dir, exist_ok=True) 97 | os.makedirs(args.frames_output_dir, exist_ok=True) 98 | with open(args.ann_video_ids_file, 'r') as file: 99 | data = json.load(file) 100 | video_ids_to_annotate = data['v2_videos'] 101 | 102 | # Read ground truth captions file 103 | gt_file = args.gt_caption_file 104 | with open(gt_file) as file: 105 | gt_json_data = json.load(file) 106 | 107 | video_ids_to_annotate = [id for id in video_ids_to_annotate if id in gt_json_data] 108 | 109 | files_to_annotate = [file for file in os.listdir(args.video_dir) if file.split('.')[0] in video_ids_to_annotate] 110 | 111 | for file in tqdm(files_to_annotate): 112 | try: 113 | video_id = file.split('.')[0] 114 | video_path = os.path.join(args.video_dir, file) 115 | curr_scene_dir = f'{args.scene_output_dir}/{video_id}' 116 | _ = get_scenes(video_path, curr_scene_dir) # Extract the scenes and save in the curr_scene_dir 117 | scenes_to_annotate = os.listdir(curr_scene_dir) 118 | for scene in tqdm(scenes_to_annotate): 119 | sce_video_path = os.path.join(curr_scene_dir, scene) 120 | get_keyframes(sce_video_path, num_keyframes=args.num_keyframes, output_dir=args.frames_output_dir) 121 | except Exception as e: 122 | print(f"Error processing video {file}: {e}") 123 | 124 | 125 | if __name__ == '__main__': 126 | main() 127 | -------------------------------------------------------------------------------- /annotation_pipeline/2_caption_keyframe_llava.py: -------------------------------------------------------------------------------- 1 | """ 2 | Semi-automatic Video Annotation Pipeline - Step # 2: Frame level detailed captioning using LLaVA-v1.6-34b 3 | 4 | Copyright 2024 MBZUAI ORYX 5 | Copyright 2024 LLaVA https://github.com/haotian-liu/LLaVA 6 | 7 | Licensed under the Apache License, Version 2.0 (the "License"); 8 | you may not use this file except in compliance with the License. 9 | You may obtain a copy of the License at 10 | 11 | http://www.apache.org/licenses/LICENSE-2.0 12 | 13 | Unless required by applicable law or agreed to in writing, software 14 | distributed under the License is distributed on an "AS IS" BASIS, 15 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 16 | See the License for the specific language governing permissions and 17 | limitations under the License. 
18 | """ 19 | 20 | import argparse 21 | import torch 22 | from llava.constants import (IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, 23 | DEFAULT_IM_END_TOKEN, IMAGE_PLACEHOLDER, ) 24 | from llava.conversation import conv_templates, SeparatorStyle 25 | from llava.model.builder import load_pretrained_model 26 | from llava.utils import disable_torch_init 27 | from llava.mm_utils import process_images, tokenizer_image_token, get_model_name_from_path 28 | from tqdm import tqdm 29 | from PIL import Image 30 | from io import BytesIO 31 | import requests 32 | import json 33 | import re 34 | import os 35 | 36 | 37 | def parse_args(): 38 | parser = argparse.ArgumentParser() 39 | 40 | parser.add_argument("--key_frame_dir", type=str, required=False, help="Directory containing extracted keyframes.", 41 | default="key_frames") 42 | parser.add_argument("--output_dir", type=str, required=False, default='llava_captions_keyframes', 43 | help="Directory to save output files.") 44 | parser.add_argument("--question", type=str, default="Describe the image in detail.", 45 | help="Question to ask about the image.") 46 | 47 | parser.add_argument("--model-path", type=str, required=False, help="Path to the pretrained model.", 48 | default="liuhaotian/llava-v1.6-34b") 49 | parser.add_argument("--model-base", type=str, default=None, help="Base model to use.") 50 | parser.add_argument("--conv-mode", type=str, default=None, help="Conversation mode.") 51 | parser.add_argument("--sep", type=str, default=",", help="Separator.") 52 | parser.add_argument("--temperature", type=float, default=0.2, help="Temperature for sampling.") 53 | parser.add_argument("--top_p", type=float, default=None, help="Top-p sampling parameter.") 54 | parser.add_argument("--num_beams", type=int, default=1, help="Number of beams for beam search.") 55 | parser.add_argument("--max_new_tokens", type=int, default=512, help="Maximum number of new tokens to generate.") 56 | 57 | return parser.parse_args() 58 | 59 | 60 | def load_image(image_file): 61 | if image_file.startswith("http") or image_file.startswith("https"): 62 | response = requests.get(image_file) 63 | image = Image.open(BytesIO(response.content)).convert("RGB") 64 | else: 65 | image = Image.open(image_file).convert("RGB") 66 | return image 67 | 68 | 69 | def load_images(image_files): 70 | out = [] 71 | for image_file in image_files: 72 | image = load_image(image_file) 73 | out.append(image) 74 | return out 75 | 76 | 77 | def load_model(args): 78 | # Model 79 | disable_torch_init() 80 | 81 | model_name = get_model_name_from_path(args.model_path) 82 | tokenizer, model, image_processor, context_len = load_pretrained_model( 83 | args.model_path, args.model_base, model_name 84 | ) 85 | 86 | if "v1.6-34b" in model_name.lower(): 87 | conv_mode = "chatml_direct" 88 | elif "v1" in model_name.lower(): 89 | conv_mode = "llava_v1" 90 | 91 | return model, image_processor, tokenizer, conv_mode 92 | 93 | 94 | def prepare_conv(qs, model, tokenizer, conv_mode): 95 | conv = conv_templates[conv_mode].copy() 96 | image_token_se = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN 97 | if IMAGE_PLACEHOLDER in qs: 98 | if model.config.mm_use_im_start_end: 99 | qs = re.sub(IMAGE_PLACEHOLDER, image_token_se, qs) 100 | else: 101 | qs = re.sub(IMAGE_PLACEHOLDER, DEFAULT_IMAGE_TOKEN, qs) 102 | else: 103 | if model.config.mm_use_im_start_end: 104 | qs = image_token_se + "\n" + qs 105 | else: 106 | qs = DEFAULT_IMAGE_TOKEN + "\n" + qs 107 | conv.append_message(conv.roles[0], qs) 108 | 
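# Appending None as the assistant message keeps that turn empty, so get_prompt() ends with the assistant role prefix and generation continues from that point.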
conv.append_message(conv.roles[1], None) 109 | prompt = conv.get_prompt() 110 | input_ids = (tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt").unsqueeze(0).cuda()) 111 | return input_ids 112 | 113 | 114 | def inference(image_files, input_ids, model, image_processor, tokenizer, args): 115 | images = load_images(image_files) 116 | image_sizes = [x.size for x in images] 117 | images_tensor = process_images( 118 | images, 119 | image_processor, 120 | model.config 121 | ).to(model.device, dtype=torch.float16) 122 | 123 | with torch.inference_mode(): 124 | output_ids = model.generate( 125 | input_ids, 126 | images=images_tensor, 127 | image_sizes=image_sizes, 128 | do_sample=True if args.temperature > 0 else False, 129 | temperature=args.temperature, 130 | top_p=args.top_p, 131 | num_beams=args.num_beams, 132 | max_new_tokens=args.max_new_tokens, 133 | use_cache=True, 134 | ) 135 | 136 | outputs = tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0].strip() 137 | return outputs 138 | 139 | 140 | def main(args): 141 | key_frame_dir = args.key_frame_dir 142 | key_frame_files = os.listdir(key_frame_dir) 143 | output_dir = args.output_dir 144 | os.makedirs(output_dir, exist_ok=True) 145 | model, image_processor, tokenizer, conv_mode = load_model(args) 146 | 147 | question = args.question 148 | 149 | input_ids = prepare_conv(question, model, tokenizer, conv_mode) 150 | 151 | for file in tqdm(key_frame_files): 152 | file_name = file.split('.')[0] 153 | output_path = os.path.join(output_dir, f'{file_name}.json') 154 | if not os.path.exists(output_path): 155 | image_path = os.path.join(key_frame_dir, file) 156 | image_files = [image_path] 157 | result = inference(image_files, input_ids, model, image_processor, tokenizer, args) 158 | 159 | result_dict = {'result': result} 160 | with open(output_path, 'w') as f: 161 | json.dump(result_dict, f, indent=2) 162 | 163 | 164 | if __name__ == "__main__": 165 | args = parse_args() 166 | main(args) 167 | -------------------------------------------------------------------------------- /annotation_pipeline/3_dense_video_description.py: -------------------------------------------------------------------------------- 1 | """ 2 | Semi-automatic Video Annotation Pipeline - Step # 3: Use short ground truth caption along with the frame-level detailed captions to generate a detailed video caption using GPT4-Turbo. 3 | 4 | Copyright 2024 MBZUAI ORYX 5 | 6 | Licensed under the Apache License, Version 2.0 (the "License"); 7 | you may not use this file except in compliance with the License. 8 | You may obtain a copy of the License at 9 | 10 | http://www.apache.org/licenses/LICENSE-2.0 11 | 12 | Unless required by applicable law or agreed to in writing, software 13 | distributed under the License is distributed on an "AS IS" BASIS, 14 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 15 | See the License for the specific language governing permissions and 16 | limitations under the License. 17 | """ 18 | 19 | import openai 20 | import os 21 | import json 22 | import time 23 | import argparse 24 | import warnings 25 | from tqdm import tqdm 26 | from multiprocessing.pool import Pool 27 | 28 | # Suppressing all warnings 29 | warnings.filterwarnings('ignore') 30 | 31 | 32 | def parse_args(): 33 | """ 34 | Command-line argument parser. 
35 | """ 36 | parser = argparse.ArgumentParser(description="Detailed video caption using GPT4-Turbo.") 37 | 38 | parser.add_argument("--ann_video_ids_file", required=True, 39 | help="Path to the JSON file with unique video IDs (e.g. path to unique_video_ids.json).") 40 | parser.add_argument("--output_dir", required=False, help="Directory to save the annotation JSON files.", 41 | default="video_descriptions") 42 | parser.add_argument("--captions_dir", required=False, help="Directory path containing generated video captions.", 43 | default="llava_captions_keyframes") 44 | parser.add_argument("--gt_caption_file", required=True, 45 | help="Path to the ground truth captions file (e.g. path to activitynet_gt_captions_train.json).") 46 | parser.add_argument("--api_keys", required=True, nargs='+', help="List of OpenAI API keys.") 47 | parser.add_argument("--num_tasks", type=int, default=16, help="Number of splits.") 48 | 49 | return parser.parse_args() 50 | 51 | 52 | def get_caption_summary_prompt(gt_caption, predicted_captions): 53 | prompt_prefix_1 = "Generate a detailed and accurate description of a video based on the given ground-truth video caption and multiple frame-level captions. " \ 54 | "Use the following details to create a clear and complete narrative:\n" 55 | prompt_prefix_2 = "\nGround-truth Video Caption: " 56 | prompt_prefix_3 = "\nFrame-level Captions: " 57 | prompt_suffix = """\n\nInstructions for writing the detailed description: 58 | 1. Focus on describing key visual details such as appearance, motion, sequence of actions, objects involved, and interactions between elements in the video. 59 | 2. Check for consistency between the ground-truth caption and frame-level captions, and prioritize details that match the ground-truth caption. Ignore any conflicting or irrelevant details from the frame-level captions. 60 | 3. Leave out any descriptions about the atmosphere, mood, style, aesthetics, proficiency, or emotional tone of the video. 61 | 4. Make sure the description is no more than 20 sentences. 62 | 5. Combine and organize information from all captions into one clear and detailed description, removing any repeated or conflicting details. 63 | 6. Emphasize important points like the order of events, appearance and actions of people or objects, and any significant changes or movements. 64 | 7. Do not mention that the information comes from ground-truth captions or frame-level captions. 65 | 8. Give a brief yet thorough description, highlighting the key visual and temporal details while keeping it clear and easy to understand. 66 | Use your intelligence to combine and refine the captions into a brief yet informative description of the entire video.""" 67 | 68 | # Create the prompt by iterating over the list_of_elements and formatting the template 69 | prompt = prompt_prefix_1 70 | prompt += f"{prompt_prefix_2}{gt_caption}{prompt_prefix_3}{'; '.join(predicted_captions)}" 71 | prompt += prompt_suffix 72 | 73 | return prompt 74 | 75 | 76 | def annotate(gt_file, caption_files, output_dir, captions_dir, api_key): 77 | """ 78 | Generate question-answer pairs using caption and 79 | dense-captions summarized from off-the-shelf models using OpenAI GPT-3. 
80 | """ 81 | openai.api_key = api_key # Set the OpenAI API key for this process 82 | 83 | for file in tqdm(caption_files): 84 | annotated_dit = {} 85 | key = file.split('.')[0] 86 | gt_caption = get_gt_caption(gt_file, key) 87 | 88 | # Get pre-computed off-the-shelf predictions 89 | prediction_captions = get_pseudo_caption(captions_dir, key) 90 | 91 | # Summarize pre-computed off-the-shelf predictions into dense caption 92 | summary_prompt = get_caption_summary_prompt(gt_caption, prediction_captions) 93 | 94 | dense_caption_summary = openai.ChatCompletion.create( 95 | model="gpt-4-turbo", messages=[{"role": "user", "content": summary_prompt}] 96 | ) 97 | dense_caption = '' 98 | for choice in dense_caption_summary.choices: 99 | dense_caption += choice.message.content 100 | 101 | annotated_dit['dense_caption'] = dense_caption 102 | 103 | # Save the response dictionary into a JSON file 104 | json_file_path = os.path.join(output_dir, f"{key}.json") 105 | with open(json_file_path, "w", encoding='utf-8') as f: 106 | json.dump(annotated_dit, f, ensure_ascii=False, indent=2) 107 | 108 | print(f"Completed, Annotations saved in {output_dir}") 109 | 110 | 111 | def get_gt_caption(json_data, video_id): 112 | video_data = json_data[video_id] 113 | gt_captions = video_data['sentences'] 114 | gt_caption = ''.join(gt_captions) 115 | return gt_caption 116 | 117 | 118 | def get_pseudo_caption(pseudo_data_dir, video_id): 119 | curr_files = [file for file in os.listdir(pseudo_data_dir) if file.startswith(video_id)] 120 | pred_captions = [] 121 | for file in curr_files: 122 | pred_caption = json.load(open(f'{pseudo_data_dir}/{file}'))['result'] 123 | pred_captions.append(pred_caption) 124 | return pred_captions 125 | 126 | 127 | def main(): 128 | """ 129 | Main function to control the flow of the program. 130 | """ 131 | # Parse arguments 132 | args = parse_args() 133 | os.makedirs(args.output_dir, exist_ok=True) 134 | 135 | with open(args.ann_video_ids_file, 'r') as file: 136 | data = json.load(file) 137 | video_ids_to_annotate = data['v2_videos'] 138 | 139 | # Read ground truth captions file 140 | gt_file = args.gt_caption_file 141 | with open(gt_file) as file: 142 | gt_json_data = json.load(file) 143 | 144 | video_ids_to_annotate = [id for id in video_ids_to_annotate if id in gt_json_data] 145 | 146 | # Prepare list of caption files 147 | caption_files = [f'{video_id}.json' for video_id in video_ids_to_annotate] 148 | 149 | # List of OpenAI API keys 150 | api_keys = args.api_keys 151 | 152 | num_tasks = args.num_tasks 153 | 154 | # Main loop: Continues until all question-answer pairs are generated for all captions 155 | while True: 156 | try: 157 | # Files that have already been completed. 158 | completed_files = os.listdir(args.output_dir) 159 | print(f"completed_files: {len(completed_files)}") 160 | 161 | # Files that have not been processed yet. 162 | incomplete_files = [f for f in caption_files if f not in completed_files] 163 | print(f"incomplete_files: {len(incomplete_files)}") 164 | 165 | if len(incomplete_files) == 0: 166 | print("All tasks completed!") 167 | break 168 | 169 | if len(incomplete_files) <= num_tasks: 170 | num_tasks = 1 171 | 172 | # Split tasks into parts. 
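# Cap the number of parallel tasks at the number of remaining files so every chunk handed to a worker is non-empty.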
173 | num_tasks = min(len(incomplete_files), num_tasks) 174 | part_len = len(incomplete_files) // num_tasks 175 | all_parts = [incomplete_files[i:i + part_len] for i in range(0, len(incomplete_files), part_len)] 176 | 177 | # Distribute API keys to tasks 178 | task_args = [(gt_json_data, part, args.output_dir, args.captions_dir, api_keys[i % len(api_keys)]) for 179 | i, part 180 | in enumerate(all_parts)] 181 | 182 | # Use a pool of workers to process the files in parallel. 183 | with Pool() as pool: 184 | pool.starmap(annotate, task_args) 185 | 186 | except Exception as e: 187 | print(f"Error: {e}") 188 | print("Sleeping for 1 minute...") 189 | time.sleep(60) # wait for 1 minute before trying again 190 | 191 | 192 | if __name__ == "__main__": 193 | main() 194 | -------------------------------------------------------------------------------- /annotation_pipeline/run_pipeline.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | 3 | ## Path containing the videos 4 | VIDEO_DIR_PATH=$1 5 | ## Path to unique_video_ids.json file 6 | ANN_VIDEO_IDS_FILE=$2 7 | ## Path to ActivityNet GT captions 8 | GT_CAPTION_FILE=$3 9 | ## Output directory path to store the intermediate and final outputs 10 | OUTPUT_DIR_PATH=$4 11 | 12 | 13 | ## Step # 1: Detect scenes and extract keyframes 14 | python 1_scenedetect_and_keyframes.py --video_dir "$VIDEO_DIR_PATH" --ann_video_ids_file "$ANN_VIDEO_IDS_FILE" --gt_caption_file "$GT_CAPTION_FILE" --scene_output_dir "$OUTPUT_DIR_PATH/scenes" --frames_output_dir "$OUTPUT_DIR_PATH/key_frames" 15 | 16 | 17 | ## Step # 2: Frame level detailed captioning using LLaVA-v1.6-34b 18 | python 2_caption_keyframe_llava.py --key_frame_dir "$OUTPUT_DIR_PATH/key_frames" --output_dir "$OUTPUT_DIR_PATH/llava_captions_keyframes" 19 | 20 | 21 | ## Step # 3: Use short ground truth caption along with the frame-level detailed captions to generate a detailed video caption using GPT4-Turbo. 22 | python 3_dense_video_description.py --ann_video_ids_file "$ANN_VIDEO_IDS_FILE" --gt_caption_file "$GT_CAPTION_FILE" --captions_dir "$OUTPUT_DIR_PATH/llava_captions_keyframes" --output_dir "$OUTPUT_DIR_PATH/video_descriptions" 23 | 24 | 25 | ## Step # 4: Generate QA pairs using video descriptions generated in Step # 3 using GPT-3.5-Turbo. 
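## (4_generate_qa.py consumes the dense captions written to the video_descriptions directory in Step # 3)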
26 | python 4_generate_qa.py --ann_video_ids_file "$ANN_VIDEO_IDS_FILE" --gt_caption_file "$GT_CAPTION_FILE" --video_descriptions_path "$OUTPUT_DIR_PATH/video_descriptions" --output_dir "$OUTPUT_DIR_PATH/video_qa" -------------------------------------------------------------------------------- /docs/images/IVAL_logo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mbzuai-oryx/VideoGPT-plus/7eac0aba1042c110cbc1645d2d5ed8e98ca5107b/docs/images/IVAL_logo.png -------------------------------------------------------------------------------- /docs/images/MBZUAI_logo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mbzuai-oryx/VideoGPT-plus/7eac0aba1042c110cbc1645d2d5ed8e98ca5107b/docs/images/MBZUAI_logo.png -------------------------------------------------------------------------------- /docs/images/MVBench_quantitative.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mbzuai-oryx/VideoGPT-plus/7eac0aba1042c110cbc1645d2d5ed8e98ca5107b/docs/images/MVBench_quantitative.png -------------------------------------------------------------------------------- /docs/images/Oryx_logo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mbzuai-oryx/VideoGPT-plus/7eac0aba1042c110cbc1645d2d5ed8e98ca5107b/docs/images/Oryx_logo.png -------------------------------------------------------------------------------- /docs/images/VCGBench_quantitative.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mbzuai-oryx/VideoGPT-plus/7eac0aba1042c110cbc1645d2d5ed8e98ca5107b/docs/images/VCGBench_quantitative.png -------------------------------------------------------------------------------- /docs/images/VCGDiverse_quantitative.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mbzuai-oryx/VideoGPT-plus/7eac0aba1042c110cbc1645d2d5ed8e98ca5107b/docs/images/VCGDiverse_quantitative.png -------------------------------------------------------------------------------- /docs/images/block_diagram.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mbzuai-oryx/VideoGPT-plus/7eac0aba1042c110cbc1645d2d5ed8e98ca5107b/docs/images/block_diagram.png -------------------------------------------------------------------------------- /docs/images/demo_vcg+_full_part1.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mbzuai-oryx/VideoGPT-plus/7eac0aba1042c110cbc1645d2d5ed8e98ca5107b/docs/images/demo_vcg+_full_part1.jpg -------------------------------------------------------------------------------- /docs/images/demo_vcg+_full_part2.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mbzuai-oryx/VideoGPT-plus/7eac0aba1042c110cbc1645d2d5ed8e98ca5107b/docs/images/demo_vcg+_full_part2.jpg -------------------------------------------------------------------------------- /docs/images/demo_vcg+_main.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mbzuai-oryx/VideoGPT-plus/7eac0aba1042c110cbc1645d2d5ed8e98ca5107b/docs/images/demo_vcg+_main.png 
--------------------------------------------------------------------------------
/docs/images/intro_radar_plot.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mbzuai-oryx/VideoGPT-plus/7eac0aba1042c110cbc1645d2d5ed8e98ca5107b/docs/images/intro_radar_plot.png
--------------------------------------------------------------------------------
/docs/images/vcg120k_block_diagram.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mbzuai-oryx/VideoGPT-plus/7eac0aba1042c110cbc1645d2d5ed8e98ca5107b/docs/images/vcg120k_block_diagram.png
--------------------------------------------------------------------------------
/docs/images/vcgbench_block_diag.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mbzuai-oryx/VideoGPT-plus/7eac0aba1042c110cbc1645d2d5ed8e98ca5107b/docs/images/vcgbench_block_diag.png
--------------------------------------------------------------------------------
/docs/images/videogpt_plus_face.jpeg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mbzuai-oryx/VideoGPT-plus/7eac0aba1042c110cbc1645d2d5ed8e98ca5107b/docs/images/videogpt_plus_face.jpeg
--------------------------------------------------------------------------------
/docs/images/zero_shot_quantitative.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mbzuai-oryx/VideoGPT-plus/7eac0aba1042c110cbc1645d2d5ed8e98ca5107b/docs/images/zero_shot_quantitative.png
--------------------------------------------------------------------------------
/eval/README.md:
--------------------------------------------------------------------------------
1 | # Quantitative Evaluation 📊
2 | 
3 | We provide instructions to evaluate the VideoGPT+ model on VCGBench, VCGBench-Diverse and MVBench. Please follow the instructions below.
4 | 
5 | ## VCGBench
6 | VCGBench is a commonly used benchmark for video-conversation models, proposed in the `Video-ChatGPT` work. It uses GPT-3.5-Turbo to evaluate Correctness of Information (CI), Detail Orientation (DO),
7 | Contextual Understanding (CU), Temporal Understanding (TU) and Consistency (CO) of a video conversation model. Please follow the steps below to evaluate VideoGPT+ on VCGBench.
8 | 
9 | ### Download VCGBench Dataset
10 | You can download the videos and annotations following the instructions on the official page [https://mbzuai-oryx.github.io/Video-ChatGPT](https://mbzuai-oryx.github.io/Video-ChatGPT).
11 | 
12 | ### Download the VideoGPT+ Model
13 | All the VideoGPT+ models are available on [HuggingFace](https://huggingface.co/collections/MBZUAI/videogpt-665c8643221dda4987a67d8d). Please follow the instructions below to download it.
14 | 
15 | Save the downloaded model under the `MBZUAI` directory.
16 | 
17 | ```bash
18 | 
19 | mkdir MBZUAI
20 | 
21 | git lfs install
22 | git clone https://huggingface.co/MBZUAI/VideoGPT-plus_Phi3-mini-4k
23 | ```
24 | 
25 | ### Run Inference
26 | We provide the [eval/vcgbench/run_ddp_inference.sh](eval/vcgbench/run_ddp_inference.sh) script to run inference on multiple GPUs:
27 | 
28 | ```bash
29 | 
30 | bash eval/vcgbench/run_ddp_inference.sh MBZUAI/VideoGPT-plus_Phi3-mini-4k/vcgbench microsoft/Phi-3-mini-4k-instruct MBZUAI/VCGBench
31 | 
32 | ```
33 | 
34 | Here, `MBZUAI/VideoGPT-plus_Phi3-mini-4k/vcgbench` is the path to the pretrained VideoGPT+ checkpoints, `microsoft/Phi-3-mini-4k-instruct` is the base model path, and `MBZUAI/VCGBench` is the VCGBench dataset path.
35 | 
36 | ### Evaluation
37 | We provide evaluation scripts using GPT-3.5-Turbo. Please use the script [eval/vcgbench/gpt_evaluation/vcgbench_evaluate.sh](eval/vcgbench/gpt_evaluation/vcgbench_evaluate.sh) for evaluation.
38 | 
39 | 
40 | ## VCGBench-Diverse
41 | VCGBench-Diverse is our proposed benchmark, which effectively addresses the limitations of VCGBench by including videos from 18 broad video categories. We use GPT-3.5-Turbo for the evaluation and report results for
42 | Correctness of Information (CI), Detail Orientation (DO),
43 | Contextual Understanding (CU), Temporal Understanding (TU), Consistency (CO),
44 | Dense Captioning, Spatial Understanding and Reasoning Abilities of video conversation models. Please follow the steps below to evaluate VideoGPT+ on VCGBench-Diverse.
45 | 
46 | 
47 | ```bash
48 | # Download and extract the VCGBench-Diverse dataset
49 | mkdir MBZUAI
50 | cd MBZUAI
51 | git lfs install
52 | git clone https://huggingface.co/datasets/MBZUAI/VCGBench-Diverse
53 | cd VCGBench-Diverse
54 | tar -xvf videos.tar.gz
55 | 
56 | # Run inference
57 | bash eval/vcgbench_diverse/inference/run_ddp_inference.sh MBZUAI/VideoGPT-plus_Phi3-mini-4k/vcgbench microsoft/Phi-3-mini-4k-instruct MBZUAI/VCGBench-Diverse
58 | 
59 | # Run GPT-3.5-Turbo evaluation (replace with your OpenAI API Key)
60 | bash eval/vcgbench_diverse/gpt_evaluation/vcgbench_diverse_evaluate.sh MBZUAI/VCGBench-Diverse/vcgbench_diverse_qa.json MBZUAI/VideoGPT-plus_Phi3-mini-4k/vcgbench/vcgbench_diverse_eval/answer-vcgbench-diverse.json MBZUAI/VideoGPT-plus_Phi3-mini-4k/vcgbench/vcgbench_diverse_eval/results
61 | ```
62 | 
63 | ## MVBench
64 | MVBench is a comprehensive video understanding benchmark which covers 20 challenging video tasks that cannot be effectively solved with a single frame. It is introduced in the `MVBench: A Comprehensive Multi-modal Video Understanding Benchmark` paper.
65 | Please follow the steps below for evaluation:
66 | 
67 | ```bash
68 | # Download and extract the MVBench dataset following the official HuggingFace link
69 | mkdir OpenGVLab
70 | git lfs install
71 | git clone https://huggingface.co/datasets/OpenGVLab/MVBench
72 | 
73 | # Extract all the videos in OpenGVLab/MVBench/video
74 | 
75 | # Run inference
76 | python eval/mvbench/inference/infer.py --model-path MBZUAI/VideoGPT-plus_Phi3-mini-4k/mvbench --model-base microsoft/Phi-3-mini-4k-instruct
77 | 
78 | # Evaluate
79 | python eval/mvbench/evaluation/evaluate_mvbench.py
80 | ```
--------------------------------------------------------------------------------
/eval/merge.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | import json
4 | from tqdm import tqdm
5 | 
6 | 
7 | def parse_args():
8 | """
9 | Parse command-line arguments.
10 | """ 11 | parser = argparse.ArgumentParser() 12 | 13 | # Define the command-line arguments 14 | parser.add_argument('--input_dir', help='Directory containing json files.', required=True) 15 | 16 | return parser.parse_args() 17 | 18 | 19 | if __name__ == '__main__': 20 | args = parse_args() 21 | all_json_names = os.listdir(args.input_dir) 22 | 23 | all_contents_list = [] 24 | for json_name in tqdm(all_json_names): 25 | contents = json.load(open(os.path.join(args.input_dir, json_name), 'r')) 26 | all_contents_list.append(contents) 27 | 28 | with open(f"{args.input_dir}.json", 'w') as f: 29 | json.dump(all_contents_list, f, indent=2) 30 | -------------------------------------------------------------------------------- /eval/mvbench/evaluation/evaluate_mvbench.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | from tqdm import tqdm 4 | import argparse 5 | 6 | 7 | def check_ans(pred, gt): 8 | flag = False 9 | 10 | pred_list = pred.lower().split(' ') 11 | pred_option, pred_content = pred_list[1], ' '.join(pred_list[1:]) 12 | gt_list = gt.lower().split(' ') 13 | gt_option, gt_content = gt_list[0], ' '.join(gt_list[1:]) 14 | if gt_content[-1] == '.': 15 | gt_content = gt_content[:-1] 16 | print(gt_option, pred_option) 17 | if pred_option.replace('.', '') in gt_option: 18 | flag = True 19 | elif gt_option in pred_option: 20 | flag = True 21 | 22 | return flag 23 | 24 | 25 | def main(args): 26 | result_files = os.listdir(args.output_dir) 27 | 28 | correct = 0 29 | total = 0 30 | acc_dict = {} 31 | 32 | for file in tqdm(result_files): 33 | if file.endswith('.json'): 34 | json_file = os.path.join(args.output_dir, file) 35 | json_data = json.load(open(json_file)) 36 | video_name = json_data['video_name'] 37 | task_type = json_data['task_type'] 38 | pred = json_data['pred'] 39 | gt_answer = json_data['A'] 40 | question = json_data['Q'] 41 | 42 | if task_type not in acc_dict: 43 | acc_dict[task_type] = [0, 0] # correct, total 44 | acc_dict[task_type][1] += 1 45 | total += 1 46 | if check_ans(pred=pred, gt=gt_answer): 47 | acc_dict[task_type][0] += 1 48 | correct += 1 49 | 50 | types = {'Action Sequence': 0, 'Action Prediction': 0, 'Action Antonym': 0, 'Fine-grained Action': 0, 51 | 'Unexpected Action': 0, 'Object Existence': 0, 'Object Interaction': 0, 'Object Shuffle': 0, 52 | 'Moving Direction': 0, 'Action Localization': 0, 'Scene Transition': 0, 'Action Count': 0, 53 | 'Moving Count': 0, 'Moving Attribute': 0, 'State Change': 0, 'Fine-grained Pose': 0, 'Character Order': 0, 54 | 'Egocentric Navigation': 0, 'Episodic Reasoning': 0, 'Counterfactual Inference': 0} 55 | 56 | result_list = [] 57 | for task_type, v in types.items(): 58 | print('-' * 30, task_type, '-' * 30) 59 | Acc = acc_dict[task_type][0] / acc_dict[task_type][1] * 100 60 | print(f"{task_type} Acc: {Acc :.2f}%") 61 | result_list.append(Acc) 62 | print(f"All Acc: {result_list}%") 63 | print(f"Total Acc: {correct / total * 100 :.2f}%") 64 | 65 | 66 | if __name__ == "__main__": 67 | parser = argparse.ArgumentParser() 68 | parser.add_argument("--output_dir", type=str, default="MBZUAI/VideoGPT-plus_Phi3-mini-4k/mvbench_eval") 69 | args = parser.parse_args() 70 | main(args) 71 | 72 | 73 | 74 | 75 | -------------------------------------------------------------------------------- /eval/mvbench/inference/ddp.py: -------------------------------------------------------------------------------- 1 | import json 2 | from torch.utils.data import Dataset 3 | import torch 4 | import 
subprocess 5 | from videogpt_plus.constants import * 6 | from eval.video_encoding import _get_rawvideo_dec, read_frame_mod, read_gif_mod 7 | 8 | 9 | class EvalDatasetMvBench(Dataset): 10 | def __init__(self, gt_dir, video_dir, image_processor, video_processor, mvbench_data_list): 11 | self.gt_contents = [] 12 | for k, v in mvbench_data_list.items(): 13 | with open(os.path.join(gt_dir, v[0]), 'r') as f: 14 | json_data = json.load(f) 15 | for data in json_data: 16 | self.gt_contents.append( 17 | {'task_type': k, 'prefix': v[1], 'data_type': v[2], 'bound': v[3], 'data': data} 18 | ) 19 | self.video_dir = video_dir 20 | self.image_processor = image_processor 21 | self.video_processor = video_processor 22 | 23 | def __len__(self): 24 | return len(self.gt_contents) 25 | 26 | def __getitem__(self, idx): 27 | sample = self.gt_contents[idx] 28 | 29 | task_type = sample['task_type'] 30 | 31 | if sample['bound']: 32 | bound = (sample['data']['start'], sample['data']['end'],) 33 | else: 34 | bound = None 35 | data_type = sample['data_type'] 36 | prefix = sample['prefix'].replace('your_data_path/', '') 37 | video_name = sample['data']['video'] 38 | video_path = os.path.join(self.video_dir, prefix, video_name) 39 | if os.path.exists(video_path): 40 | if data_type == 'video': 41 | if bound: 42 | video_frames, context_frames, slice_len = ( 43 | _get_rawvideo_dec(video_path, self.image_processor, self.video_processor, s=bound[0], 44 | e=bound[1], max_frames=NUM_FRAMES, image_resolution=224, 45 | num_video_frames=NUM_FRAMES, num_context_images=NUM_CONTEXT_IMAGES)) 46 | else: 47 | video_frames, context_frames, slice_len = ( 48 | _get_rawvideo_dec(video_path, self.image_processor, self.video_processor, 49 | max_frames=NUM_FRAMES, image_resolution=224, 50 | num_video_frames=NUM_FRAMES, num_context_images=NUM_CONTEXT_IMAGES)) 51 | elif data_type == 'gif': 52 | if bound: 53 | video_frames, slice_len = read_gif_mod( 54 | video_path, self.image_processor, s=bound[0], e=bound[1], max_frames=NUM_FRAMES 55 | ) 56 | else: 57 | video_frames, slice_len = read_gif_mod( 58 | video_path, self.image_processor, max_frames=NUM_FRAMES 59 | ) 60 | elif data_type == 'frame': 61 | if bound: 62 | video_frames, context_frames, slice_len = read_frame_mod( 63 | video_path, self.image_processor, self.video_processor, s=bound[0], e=bound[1], 64 | max_frames=NUM_FRAMES, image_resolution=224, 65 | num_video_frames=NUM_FRAMES, num_context_images=NUM_CONTEXT_IMAGES 66 | ) 67 | else: 68 | video_frames, context_frames, slice_len = read_frame_mod( 69 | video_path, self.image_processor, self.image_processor, max_frames=NUM_FRAMES, 70 | image_resolution=224, 71 | num_video_frames=NUM_FRAMES, num_context_images=NUM_CONTEXT_IMAGES 72 | ) 73 | else: 74 | video_frames, slice_len = "None", 0 75 | print('Video not found:', video_path) 76 | 77 | sample_set = {} 78 | question, answer = qa_template(sample['data']) 79 | sample_set['video_name'] = f'{prefix}_{video_name}' 80 | sample_set['Q'] = question 81 | sample_set['A'] = answer 82 | sample_set['task_type'] = task_type 83 | 84 | return idx, [sample_set], video_frames, context_frames, slice_len 85 | 86 | 87 | def qa_template(data): 88 | question = f"Question: {data['question']}\n" 89 | question += "Options:\n" 90 | answer = data['answer'] 91 | answer_idx = -1 92 | for idx, c in enumerate(data['candidates']): 93 | question += f"({chr(ord('A') + idx)}) {c}\n" 94 | if c == answer: 95 | answer_idx = idx 96 | question = question.rstrip() 97 | answer = f"({chr(ord('A') + answer_idx)}) {answer}" 98 | 99 | # 
Add the instruction to question 100 | question_prompt = "\nOnly give the best option." # to change 101 | question += question_prompt 102 | 103 | return question, answer 104 | 105 | 106 | def setup_for_distributed(is_master): 107 | """ 108 | This function disables printing when not in master process 109 | """ 110 | import builtins as __builtin__ 111 | builtin_print = __builtin__.print 112 | 113 | def print(*args, **kwargs): 114 | force = kwargs.pop('force', False) 115 | if is_master or force: 116 | builtin_print(*args, **kwargs) 117 | 118 | __builtin__.print = print 119 | 120 | 121 | def init_distributed_mode(args): 122 | if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ: 123 | args.rank = int(os.environ["RANK"]) 124 | args.world_size = int(os.environ['WORLD_SIZE']) 125 | args.gpu = int(os.environ['LOCAL_RANK']) 126 | args.dist_url = 'env://' 127 | os.environ['LOCAL_SIZE'] = str(torch.cuda.device_count()) 128 | print('Using distributed mode: 1') 129 | elif 'SLURM_PROCID' in os.environ: 130 | proc_id = int(os.environ['SLURM_PROCID']) 131 | ntasks = int(os.environ['SLURM_NTASKS']) 132 | node_list = os.environ['SLURM_NODELIST'] 133 | num_gpus = torch.cuda.device_count() 134 | addr = subprocess.getoutput( 135 | 'scontrol show hostname {} | head -n1'.format(node_list)) 136 | os.environ['MASTER_PORT'] = os.environ.get('MASTER_PORT', '3460') 137 | os.environ['MASTER_ADDR'] = addr 138 | os.environ['WORLD_SIZE'] = str(ntasks) 139 | os.environ['RANK'] = str(proc_id) 140 | os.environ['LOCAL_RANK'] = str(proc_id % num_gpus) 141 | os.environ['LOCAL_SIZE'] = str(num_gpus) 142 | args.dist_url = 'env://' 143 | args.world_size = ntasks 144 | args.rank = proc_id 145 | args.gpu = proc_id % num_gpus 146 | print('Using distributed mode: slurm') 147 | print(f"world: {os.environ['WORLD_SIZE']}, rank:{os.environ['RANK']}," 148 | f" local_rank{os.environ['LOCAL_RANK']}, local_size{os.environ['LOCAL_SIZE']}") 149 | else: 150 | print('Not using distributed mode') 151 | args.distributed = False 152 | return 153 | 154 | args.distributed = True 155 | 156 | torch.cuda.set_device(args.gpu) 157 | args.dist_backend = 'nccl' 158 | print('| distributed init (rank {}): {}'.format( 159 | args.rank, args.dist_url), flush=True) 160 | torch.distributed.init_process_group(backend=args.dist_backend, init_method=args.dist_url, 161 | world_size=args.world_size, rank=args.rank) 162 | torch.distributed.barrier() 163 | setup_for_distributed(args.rank == 0) 164 | -------------------------------------------------------------------------------- /eval/mvbench/inference/infer.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from tqdm import tqdm 3 | import shortuuid 4 | from videogpt_plus.conversation import conv_templates 5 | from videogpt_plus.model.builder import load_pretrained_model 6 | from videogpt_plus.mm_utils import tokenizer_image_token, get_model_name_from_path 7 | from eval.mvbench.inference.ddp import * 8 | from torch.utils.data import DataLoader, DistributedSampler 9 | import traceback 10 | 11 | 12 | def disable_torch_init(): 13 | """ 14 | Disable the redundant torch default initialization to accelerate model creation. 
15 | """ 16 | import torch 17 | setattr(torch.nn.Linear, "reset_parameters", lambda self: None) 18 | 19 | 20 | mvbench_data_list = { 21 | "Episodic Reasoning": ("episodic_reasoning.json", "your_data_path/tvqa/frames_fps3_hq/", "frame", True), 22 | "Action Sequence": ("action_sequence.json", "your_data_path/star/Charades_v1_480/", "video", True), 23 | "Action Prediction": ("action_prediction.json", "your_data_path/star/Charades_v1_480/", "video", True), 24 | "Action Antonym": ("action_antonym.json", "your_data_path/ssv2_video/", "video", False), 25 | "Fine-grained Action": ("fine_grained_action.json", "your_data_path/Moments_in_Time_Raw/videos/", "video", False), 26 | "Unexpected Action": ("unexpected_action.json", "your_data_path/FunQA_test/test/", "video", False), 27 | "Object Existence": ("object_existence.json", "your_data_path/clevrer/video_validation/", "video", False), 28 | "Object Interaction": ("object_interaction.json", "your_data_path/star/Charades_v1_480/", "video", True), 29 | "Object Shuffle": ("object_shuffle.json", "your_data_path/perception/videos/", "video", False), 30 | "Moving Direction": ("moving_direction.json", "your_data_path/clevrer/video_validation/", "video", False), 31 | "Action Localization": ("action_localization.json", "your_data_path/sta/sta_video/", "video", True), 32 | "Scene Transition": ("scene_transition.json", "your_data_path/scene_qa/video/", "video", False), 33 | "Action Count": ("action_count.json", "your_data_path/perception/videos/", "video", False), 34 | "Moving Count": ("moving_count.json", "your_data_path/clevrer/video_validation/", "video", False), 35 | "Moving Attribute": ("moving_attribute.json", "your_data_path/clevrer/video_validation/", "video", False), 36 | "State Change": ("state_change.json", "your_data_path/perception/videos/", "video", False), 37 | "Fine-grained Pose": ("fine_grained_pose.json", "your_data_path/nturgbd/", "video", False), 38 | "Character Order": ("character_order.json", "your_data_path/perception/videos/", "video", False), 39 | "Egocentric Navigation": ("egocentric_navigation.json", "your_data_path/vlnqa/", "video", False), 40 | "Counterfactual Inference": ( 41 | "counterfactual_inference.json", "your_data_path/clevrer/video_validation/", "video", False), 42 | } 43 | 44 | 45 | def eval_model(args): 46 | # Model 47 | disable_torch_init() 48 | model_path = os.path.expanduser(args.model_path) 49 | model_name = get_model_name_from_path(args.model_path) 50 | tokenizer, model, image_processor, context_len = load_pretrained_model(model_path, args.model_base, model_name) 51 | 52 | mm_use_im_start_end = getattr(model.config, "mm_use_im_start_end", False) 53 | mm_use_im_patch_token = getattr(model.config, "mm_use_im_patch_token", True) 54 | if mm_use_im_patch_token: 55 | tokenizer.add_tokens([DEFAULT_IMAGE_PATCH_TOKEN], special_tokens=True) 56 | if mm_use_im_start_end: 57 | tokenizer.add_tokens([DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True) 58 | model.resize_token_embeddings(len(tokenizer)) 59 | 60 | vision_tower = model.get_vision_tower() 61 | vision_tower.load_model(model.config.mm_vision_tower) 62 | video_processor = vision_tower.image_processor 63 | 64 | image_vision_tower = model.get_image_vision_tower() 65 | image_vision_tower.load_model() 66 | image_processor = image_vision_tower.image_processor 67 | 68 | model = model.to("cuda") 69 | 70 | dataset = EvalDatasetMvBench(args.question_dir, args.video_folder, image_processor, 71 | video_processor, mvbench_data_list) 72 | distributed_sampler = 
DistributedSampler(dataset, rank=args.rank, shuffle=False) 73 | dataloader = DataLoader(dataset, batch_size=args.batch_size_per_gpu, num_workers=4, sampler=distributed_sampler) 74 | 75 | for (idx, sample_set, video_frames, context_frames, slice_len) in tqdm(dataloader): 76 | idx, sample_set, video_frames, context_frames, slice_len = int(idx[0]), sample_set[ 77 | 0], video_frames, context_frames, int(slice_len[0]) 78 | 79 | sample = sample_set 80 | qs = sample['Q'][0] 81 | 82 | try: 83 | cur_prompt = qs 84 | if model.config.mm_use_im_start_end: 85 | qs = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN * slice_len + DEFAULT_IM_END_TOKEN + '\n' + qs 86 | else: 87 | qs = DEFAULT_IMAGE_TOKEN * slice_len + '\n' + qs 88 | 89 | conv = conv_templates[args.conv_mode].copy() 90 | conv.append_message(conv.roles[0], qs) 91 | conv.append_message(conv.roles[1], None) 92 | prompt = conv.get_prompt() 93 | 94 | input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, 95 | return_tensors='pt').unsqueeze(0).cuda() 96 | 97 | # stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2 98 | stop_str = "<|end|>" 99 | 100 | with torch.inference_mode(): 101 | output_ids = model.generate( 102 | input_ids, 103 | images=torch.cat(video_frames, dim=0).half().cuda(), 104 | context_images=torch.cat(context_frames, dim=0).half().cuda(), 105 | do_sample=True if args.temperature > 0 else False, 106 | temperature=args.temperature, 107 | top_p=args.top_p, 108 | num_beams=args.num_beams, 109 | max_new_tokens=1024, 110 | use_cache=True) 111 | 112 | input_token_len = input_ids.shape[1] 113 | n_diff_input_output = (input_ids != output_ids[:, :input_token_len]).sum().item() 114 | if n_diff_input_output > 0: 115 | print(f'[Warning] {n_diff_input_output} output_ids are not the same as the input_ids') 116 | outputs = tokenizer.batch_decode(output_ids[:, input_token_len:], skip_special_tokens=True)[0] 117 | outputs = outputs.strip() 118 | if outputs.endswith(stop_str): 119 | outputs = outputs[:-len(stop_str)] 120 | outputs = outputs.strip() 121 | outputs = outputs.replace("<|end|>", '') 122 | outputs = outputs.strip() 123 | 124 | ans_id = shortuuid.uuid() 125 | video_json_name = sample['video_name'][0].replace('/', '_') 126 | if len(video_json_name) > 100: 127 | video_json_name = video_json_name[50:] 128 | 129 | results = {'video_name': sample['video_name'][0], 130 | "prompt": cur_prompt, 131 | "pred": outputs, 132 | "answer_id": ans_id, 133 | "Q": sample_set['Q'][0], 134 | "task_type": sample['task_type'][0], 135 | "A": sample['A'][0]} 136 | with open(f"{args.output_dir}/{video_json_name}_{idx}.json", "w") as f: 137 | json.dump(results, f) 138 | except Exception as e: 139 | trace = traceback.format_exc() 140 | print(f"Error processing video file '{sample['video_name'][0]}': {e}") 141 | print("Detailed traceback:") 142 | print(trace) 143 | 144 | 145 | if __name__ == "__main__": 146 | parser = argparse.ArgumentParser() 147 | parser.add_argument("--model-path", type=str, default="MBZUAI/VideoGPT-plus_Phi3-mini-4k/mvbench") 148 | parser.add_argument("--model-base", type=str, default="microsoft/Phi-3-mini-4k-instruct") 149 | parser.add_argument("--video-folder", type=str, default="OpenGVLab/MVBench/video") 150 | parser.add_argument("--question-dir", type=str, default="OpenGVLab/MVBench/json") 151 | parser.add_argument("--output-dir", type=str, default="MBZUAI/VideoGPT-plus_Phi3-mini-4k/mvbench_eval") 152 | parser.add_argument("--conv-mode", type=str, default="phi3_instruct") 153 | 
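# Generation settings: with the default temperature of 0.0, eval_model() disables sampling and decodes greedily.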
parser.add_argument("--temperature", type=float, default=0.0) 154 | parser.add_argument("--top_p", type=float, default=None) 155 | parser.add_argument("--num_beams", type=int, default=1) 156 | 157 | parser.add_argument("--batch_size_per_gpu", required=False, default=1) 158 | parser.add_argument('--world_size', default=1, type=int, help='number of distributed processes') 159 | parser.add_argument('--local_rank', default=-1, type=int) 160 | parser.add_argument('--dist_url', default='env://', help='url used to set up distributed training') 161 | 162 | args = parser.parse_args() 163 | 164 | init_distributed_mode(args) 165 | 166 | os.makedirs(args.output_dir, exist_ok=True) 167 | 168 | eval_model(args) 169 | -------------------------------------------------------------------------------- /eval/vcgbench/gpt_evaluation/evaluate_benchmark_1_correctness.py: -------------------------------------------------------------------------------- 1 | """ 2 | VCGBench - Evaluation Script for Correctness of Information (CI) using gpt-3.5-turbo-0613 3 | 4 | Exactly the same prompts are used as proposed in https://github.com/mbzuai-oryx/Video-ChatGPT for a fair comparison with previous methods. 5 | """ 6 | 7 | import openai 8 | import os 9 | import argparse 10 | import json 11 | import ast 12 | from multiprocessing.pool import Pool 13 | from tqdm import tqdm 14 | 15 | 16 | def parse_args(): 17 | parser = argparse.ArgumentParser(description="VCGBench - Evaluation Script for Correctness of Information (CI).") 18 | parser.add_argument("--pred_path", required=True, help="The path to file containing prediction.") 19 | parser.add_argument("--output_dir", required=True, help="The path to save annotation json files.") 20 | parser.add_argument("--output_json", required=True, help="The path to save annotation final combined json file.") 21 | parser.add_argument("--api_key", required=True, help="OpenAI API key.") 22 | parser.add_argument("--num_tasks", required=True, type=int, help="Number of splits.") 23 | args = parser.parse_args() 24 | return args 25 | 26 | 27 | def annotate(prediction_set, caption_files, output_dir): 28 | """ 29 | Evaluates question and answer pairs using GPT-3 30 | Returns a score for correctness. 31 | """ 32 | for file in tqdm(caption_files): 33 | key = file[:-5] # Strip file extension 34 | qa_set = prediction_set[key] 35 | question = qa_set['q'] 36 | answer = qa_set['a'] 37 | pred = qa_set['pred'] 38 | try: 39 | # Compute the correctness score 40 | completion = openai.ChatCompletion.create( 41 | model="gpt-3.5-turbo-0613", 42 | messages=[ 43 | { 44 | "role": "system", 45 | "content": 46 | "You are an intelligent chatbot designed for evaluating the factual accuracy of generative outputs for video-based question-answer pairs. " 47 | "Your task is to compare the predicted answer with the correct answer and determine if they are factually consistent. Here's how you can accomplish the task:" 48 | "------" 49 | "##INSTRUCTIONS: " 50 | "- Focus on the factual consistency between the predicted answer and the correct answer. The predicted answer should not contain any misinterpretations or misinformation.\n" 51 | "- The predicted answer must be factually accurate and align with the video content.\n" 52 | "- Consider synonyms or paraphrases as valid matches.\n" 53 | "- Evaluate the factual accuracy of the prediction compared to the answer." 
54 | }, 55 | { 56 | "role": "user", 57 | "content": 58 | "Please evaluate the following video-based question-answer pair:\n\n" 59 | f"Question: {question}\n" 60 | f"Correct Answer: {answer}\n" 61 | f"Predicted Answer: {pred}\n\n" 62 | "Provide your evaluation only as a factual accuracy score where the factual accuracy score is an integer value between 0 and 5, with 5 indicating the highest level of factual consistency. " 63 | "Please generate the response in the form of a Python dictionary string with keys 'score', where its value is the factual accuracy score in INTEGER, not STRING." 64 | "DO NOT PROVIDE ANY OTHER OUTPUT TEXT OR EXPLANATION. Only provide the Python dictionary string. " 65 | "For example, your response should look like this: {''score': 4.8}." 66 | } 67 | ] 68 | ) 69 | # Convert response to a Python dictionary. 70 | response_message = completion["choices"][0]["message"]["content"] 71 | response_dict = ast.literal_eval(response_message) 72 | result_qa_pair = [response_dict, qa_set] 73 | 74 | # Save the question-answer pairs to a json file. 75 | with open(f"{output_dir}/{key}.json", "w") as f: 76 | json.dump(result_qa_pair, f) 77 | 78 | except Exception as e: 79 | print(f"Error processing file '{key}': {e}") 80 | 81 | 82 | def main(): 83 | """ 84 | Main function to control the flow of the program. 85 | """ 86 | # Parse arguments. 87 | args = parse_args() 88 | 89 | file = args.pred_path 90 | pred_contents = json.load(open(file, 'r')) 91 | 92 | # Dictionary to store the count of occurrences for each video_id 93 | video_id_counts = {} 94 | new_pred_contents = [] 95 | 96 | # Iterate through each sample in pred_contents 97 | for sample in pred_contents: 98 | sample['video_name'] = 1 99 | video_id = sample['video_name'] 100 | if video_id in video_id_counts: 101 | video_id_counts[video_id] += 1 102 | else: 103 | video_id_counts[video_id] = 0 104 | 105 | # Create a new sample with the modified key 106 | new_sample = sample 107 | new_sample['video_name'] = f"{video_id}_{video_id_counts[video_id]}" 108 | new_pred_contents.append(new_sample) 109 | 110 | # Generating list of id's and corresponding files 111 | id_list = [x['video_name'] for x in new_pred_contents] 112 | caption_files = [f"{id}.json" for id in id_list] 113 | 114 | output_dir = args.output_dir 115 | # Generate output directory if not exists. 116 | if not os.path.exists(output_dir): 117 | os.makedirs(output_dir) 118 | 119 | # Preparing dictionary of question-answer sets 120 | prediction_set = {} 121 | for sample in new_pred_contents: 122 | id = sample['video_name'] 123 | question = sample['prompt'] 124 | answer = sample['answer'] 125 | pred = sample['text'] 126 | qa_set = {"q": question, "a": answer, "pred": pred} 127 | prediction_set[id] = qa_set 128 | 129 | # Set the OpenAI API key. 130 | openai.api_key = args.api_key 131 | num_tasks = args.num_tasks 132 | 133 | # While loop to ensure that all captions are processed. 134 | while True: 135 | try: 136 | # Files that have not been processed yet. 137 | completed_files = os.listdir(output_dir) 138 | print(f"completed_files: {len(completed_files)}") 139 | 140 | # Files that have not been processed yet. 141 | incomplete_files = [f for f in caption_files if f not in completed_files] 142 | print(f"incomplete_files: {len(incomplete_files)}") 143 | 144 | # Break the loop when there are no incomplete files 145 | if len(incomplete_files) == 0: 146 | break 147 | if len(incomplete_files) <= num_tasks: 148 | num_tasks = 1 149 | 150 | # Split tasks into parts. 
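# Integer division sets the chunk size, so the list comprehension below can yield more than num_tasks chunks when the files do not divide evenly.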
151 | part_len = len(incomplete_files) // num_tasks 152 | all_parts = [incomplete_files[i:i + part_len] for i in range(0, len(incomplete_files), part_len)] 153 | task_args = [(prediction_set, part, args.output_dir) for part in all_parts] 154 | 155 | # Use a pool of workers to process the files in parallel. 156 | with Pool() as pool: 157 | pool.starmap(annotate, task_args) 158 | 159 | except Exception as e: 160 | print(f"Error: {e}") 161 | 162 | # Combine all the processed files into one 163 | combined_contents = {} 164 | json_path = args.output_json 165 | 166 | # Iterate through json files 167 | for file_name in os.listdir(output_dir): 168 | if file_name.endswith(".json"): 169 | file_path = os.path.join(output_dir, file_name) 170 | with open(file_path, "r") as json_file: 171 | content = json.load(json_file) 172 | combined_contents[file_name[:-5]] = content 173 | 174 | # Write combined content to a json file 175 | with open(json_path, "w") as json_file: 176 | json.dump(combined_contents, json_file) 177 | print("All evaluation completed!") 178 | 179 | # Calculate average score 180 | score_sum = 0 181 | count = 0 182 | for key, result in combined_contents.items(): 183 | count += 1 184 | score_match = result[0]['score'] 185 | score = int(score_match) 186 | score_sum += score 187 | average_score = score_sum / count 188 | 189 | print("Average score for correctness:", average_score) 190 | 191 | 192 | if __name__ == "__main__": 193 | main() 194 | -------------------------------------------------------------------------------- /eval/vcgbench/gpt_evaluation/evaluate_benchmark_2_detailed_orientation.py: -------------------------------------------------------------------------------- 1 | """ 2 | VCGBench - Evaluation Script for Detailed Orientation (DO) using gpt-3.5-turbo-0613 3 | 4 | Exactly the same prompts are used as proposed in https://github.com/mbzuai-oryx/Video-ChatGPT for a fair comparison with previous methods. 5 | """ 6 | 7 | import openai 8 | import os 9 | import argparse 10 | import json 11 | import ast 12 | from multiprocessing.pool import Pool 13 | from tqdm import tqdm 14 | 15 | 16 | def parse_args(): 17 | parser = argparse.ArgumentParser(description="VCGBench - Evaluation Script for Detailed Orientation (DO).") 18 | parser.add_argument("--pred_path", required=True, help="The path to file containing prediction.") 19 | parser.add_argument("--output_dir", required=True, help="The path to save annotation json files.") 20 | parser.add_argument("--output_json", required=True, help="The path to save annotation final combined json file.") 21 | parser.add_argument("--api_key", required=True, help="OpenAI API key.") 22 | parser.add_argument("--num_tasks", required=True, type=int, help="Number of splits.") 23 | args = parser.parse_args() 24 | return args 25 | 26 | 27 | def annotate(prediction_set, caption_files, output_dir): 28 | """ 29 | Evaluates question and answer pairs using GPT-3 and 30 | returns a score for detailed orientation. 31 | """ 32 | for file in tqdm(caption_files): 33 | key = file[:-5] # Strip file extension 34 | qa_set = prediction_set[key] 35 | question = qa_set['q'] 36 | answer = qa_set['a'] 37 | pred = qa_set['pred'] 38 | try: 39 | # Compute the detailed-orientation score 40 | completion = openai.ChatCompletion.create( 41 | model="gpt-3.5-turbo-0613", 42 | messages=[ 43 | { 44 | "role": "system", 45 | "content": 46 | "You are an intelligent chatbot designed for evaluating the detail orientation of generative outputs for video-based question-answer pairs. 
" 47 | "Your task is to compare the predicted answer with the correct answer and determine its level of detail, considering both completeness and specificity. Here's how you can accomplish the task:" 48 | "------" 49 | "##INSTRUCTIONS: " 50 | "- Check if the predicted answer covers all major points from the video. The response should not leave out any key aspects.\n" 51 | "- Evaluate whether the predicted answer includes specific details rather than just generic points. It should provide comprehensive information that is tied to specific elements of the video.\n" 52 | "- Consider synonyms or paraphrases as valid matches.\n" 53 | "- Provide a single evaluation score that reflects the level of detail orientation of the prediction, considering both completeness and specificity." 54 | }, 55 | { 56 | "role": "user", 57 | "content": 58 | "Please evaluate the following video-based question-answer pair:\n\n" 59 | f"Question: {question}\n" 60 | f"Correct Answer: {answer}\n" 61 | f"Predicted Answer: {pred}\n\n" 62 | "Provide your evaluation only as a detail orientation score where the detail orientation score is an integer value between 0 and 5, with 5 indicating the highest level of detail orientation. " 63 | "Please generate the response in the form of a Python dictionary string with keys 'score', where its value is the detail orientation score in INTEGER, not STRING." 64 | "DO NOT PROVIDE ANY OTHER OUTPUT TEXT OR EXPLANATION. Only provide the Python dictionary string. " 65 | "For example, your response should look like this: {''score': 4.8}." 66 | } 67 | ] 68 | ) 69 | # Convert response to a Python dictionary. 70 | response_message = completion["choices"][0]["message"]["content"] 71 | response_dict = ast.literal_eval(response_message) 72 | result_qa_pair = [response_dict, qa_set] 73 | 74 | # Save the question-answer pairs to a json file. 75 | with open(f"{output_dir}/{key}.json", "w") as f: 76 | json.dump(result_qa_pair, f) 77 | 78 | except Exception as e: 79 | print(f"Error processing file '{key}': {e}") 80 | 81 | 82 | def main(): 83 | """ 84 | Main function to control the flow of the program. 85 | """ 86 | # Parse arguments. 87 | args = parse_args() 88 | 89 | file = args.pred_path 90 | pred_contents = json.load(open(file, 'r')) 91 | 92 | # Dictionary to store the count of occurrences for each video_id 93 | video_id_counts = {} 94 | new_pred_contents = [] 95 | 96 | # Iterate through each sample in pred_contents 97 | for sample in pred_contents: 98 | sample['video_name'] = 1 99 | video_id = sample['video_name'] 100 | if video_id in video_id_counts: 101 | video_id_counts[video_id] += 1 102 | else: 103 | video_id_counts[video_id] = 0 104 | 105 | # Create a new sample with the modified key 106 | new_sample = sample 107 | new_sample['video_name'] = f"{video_id}_{video_id_counts[video_id]}" 108 | new_pred_contents.append(new_sample) 109 | 110 | # Generating list of id's and corresponding files 111 | id_list = [x['video_name'] for x in new_pred_contents] 112 | caption_files = [f"{id}.json" for id in id_list] 113 | 114 | output_dir = args.output_dir 115 | # Generate output directory if not exists. 
116 | if not os.path.exists(output_dir): 117 | os.makedirs(output_dir) 118 | 119 | # Preparing dictionary of question-answer sets 120 | prediction_set = {} 121 | for sample in new_pred_contents: 122 | id = sample['video_name'] 123 | question = sample['prompt'] 124 | answer = sample['answer'] 125 | pred = sample['text'] 126 | qa_set = {"q": question, "a": answer, "pred": pred} 127 | prediction_set[id] = qa_set 128 | 129 | # Set the OpenAI API key. 130 | openai.api_key = args.api_key 131 | num_tasks = args.num_tasks 132 | 133 | # While loop to ensure that all captions are processed. 134 | while True: 135 | try: 136 | # Files that have not been processed yet. 137 | completed_files = os.listdir(output_dir) 138 | print(f"completed_files: {len(completed_files)}") 139 | 140 | # Files that have not been processed yet. 141 | incomplete_files = [f for f in caption_files if f not in completed_files] 142 | print(f"incomplete_files: {len(incomplete_files)}") 143 | 144 | # Break the loop when there are no incomplete files 145 | if len(incomplete_files) == 0: 146 | break 147 | if len(incomplete_files) <= num_tasks: 148 | num_tasks = 1 149 | 150 | # Split tasks into parts. 151 | part_len = len(incomplete_files) // num_tasks 152 | all_parts = [incomplete_files[i:i + part_len] for i in range(0, len(incomplete_files), part_len)] 153 | task_args = [(prediction_set, part, args.output_dir) for part in all_parts] 154 | 155 | # Use a pool of workers to process the files in parallel. 156 | with Pool() as pool: 157 | pool.starmap(annotate, task_args) 158 | 159 | except Exception as e: 160 | print(f"Error: {e}") 161 | 162 | # Combine all the processed files into one 163 | combined_contents = {} 164 | json_path = args.output_json 165 | 166 | # Iterate through json files 167 | for file_name in os.listdir(output_dir): 168 | if file_name.endswith(".json"): 169 | file_path = os.path.join(output_dir, file_name) 170 | with open(file_path, "r") as json_file: 171 | content = json.load(json_file) 172 | combined_contents[file_name[:-5]] = content 173 | 174 | # Write combined content to a json file 175 | with open(json_path, "w") as json_file: 176 | json.dump(combined_contents, json_file) 177 | print("All evaluation completed!") 178 | 179 | # Calculate average score 180 | score_sum = 0 181 | count = 0 182 | for key, result in combined_contents.items(): 183 | count += 1 184 | score_match = result[0]['score'] 185 | score = int(score_match) 186 | score_sum += score 187 | average_score = score_sum / count 188 | 189 | print("Average score for detailed orientation:", average_score) 190 | 191 | 192 | if __name__ == "__main__": 193 | main() 194 | -------------------------------------------------------------------------------- /eval/vcgbench/gpt_evaluation/evaluate_benchmark_3_context.py: -------------------------------------------------------------------------------- 1 | """ 2 | VCGBench - Evaluation Script for Contextual Understanding (CU) using gpt-3.5-turbo-0613 3 | 4 | Exactly the same prompts are used as proposed in https://github.com/mbzuai-oryx/Video-ChatGPT for a fair comparison with previous methods. 
5 | """ 6 | 7 | import openai 8 | import os 9 | import argparse 10 | import json 11 | import ast 12 | from multiprocessing.pool import Pool 13 | from tqdm import tqdm 14 | 15 | 16 | def parse_args(): 17 | parser = argparse.ArgumentParser(description="VCGBench - Evaluation Script for Contextual Understanding (CU).") 18 | parser.add_argument("--pred_path", required=True, help="The path to file containing prediction.") 19 | parser.add_argument("--output_dir", required=True, help="The path to save annotation json files.") 20 | parser.add_argument("--output_json", required=True, help="The path to save annotation final combined json file.") 21 | parser.add_argument("--api_key", required=True, help="OpenAI API key.") 22 | parser.add_argument("--num_tasks", required=True, type=int, help="Number of splits.") 23 | args = parser.parse_args() 24 | return args 25 | 26 | 27 | def annotate(prediction_set, caption_files, output_dir): 28 | """ 29 | Evaluates question and answer pairs using GPT-3 and 30 | returns a score for contextual understanding. 31 | """ 32 | for file in tqdm(caption_files): 33 | key = file[:-5] # Strip file extension 34 | qa_set = prediction_set[key] 35 | question = qa_set['q'] 36 | answer = qa_set['a'] 37 | pred = qa_set['pred'] 38 | try: 39 | # Compute the contextual understanding score 40 | completion = openai.ChatCompletion.create( 41 | model="gpt-3.5-turbo-0613", 42 | messages=[ 43 | { 44 | "role": "system", 45 | "content": 46 | "You are an intelligent chatbot designed for evaluating the contextual understanding of generative outputs for video-based question-answer pairs. " 47 | "Your task is to compare the predicted answer with the correct answer and determine if the generated response aligns with the overall context of the video content. Here's how you can accomplish the task:" 48 | "------" 49 | "##INSTRUCTIONS: " 50 | "- Evaluate whether the predicted answer aligns with the overall context of the video content. It should not provide information that is out of context or misaligned.\n" 51 | "- The predicted answer must capture the main themes and sentiments of the video.\n" 52 | "- Consider synonyms or paraphrases as valid matches.\n" 53 | "- Provide your evaluation of the contextual understanding of the prediction compared to the answer." 54 | }, 55 | { 56 | "role": "user", 57 | "content": 58 | "Please evaluate the following video-based question-answer pair:\n\n" 59 | f"Question: {question}\n" 60 | f"Correct Answer: {answer}\n" 61 | f"Predicted Answer: {pred}\n\n" 62 | "Provide your evaluation only as a contextual understanding score where the contextual understanding score is an integer value between 0 and 5, with 5 indicating the highest level of contextual understanding. " 63 | "Please generate the response in the form of a Python dictionary string with keys 'score', where its value is contextual understanding score in INTEGER, not STRING." 64 | "DO NOT PROVIDE ANY OTHER OUTPUT TEXT OR EXPLANATION. Only provide the Python dictionary string. " 65 | "For example, your response should look like this: {''score': 4.8}." 66 | } 67 | ] 68 | ) 69 | # Convert response to a Python dictionary. 70 | response_message = completion["choices"][0]["message"]["content"] 71 | response_dict = ast.literal_eval(response_message) 72 | result_qa_pair = [response_dict, qa_set] 73 | 74 | # Save the question-answer pairs to a json file. 
75 | with open(f"{output_dir}/{key}.json", "w") as f: 76 | json.dump(result_qa_pair, f) 77 | 78 | except Exception as e: 79 | print(f"Error processing file '{key}': {e}") 80 | 81 | 82 | def main(): 83 | """ 84 | Main function to control the flow of the program. 85 | """ 86 | # Parse arguments. 87 | args = parse_args() 88 | 89 | file = args.pred_path 90 | pred_contents = json.load(open(file, 'r')) 91 | 92 | # Dictionary to store the count of occurrences for each video_id 93 | video_id_counts = {} 94 | new_pred_contents = [] 95 | 96 | # Iterate through each sample in pred_contents 97 | for sample in pred_contents: 98 | sample['video_name'] = 1 99 | video_id = sample['video_name'] 100 | if video_id in video_id_counts: 101 | video_id_counts[video_id] += 1 102 | else: 103 | video_id_counts[video_id] = 0 104 | 105 | # Create a new sample with the modified key 106 | new_sample = sample 107 | new_sample['video_name'] = f"{video_id}_{video_id_counts[video_id]}" 108 | new_pred_contents.append(new_sample) 109 | 110 | # Generating list of id's and corresponding files 111 | id_list = [x['video_name'] for x in new_pred_contents] 112 | caption_files = [f"{id}.json" for id in id_list] 113 | 114 | output_dir = args.output_dir 115 | # Generate output directory if not exists. 116 | if not os.path.exists(output_dir): 117 | os.makedirs(output_dir) 118 | 119 | # Preparing dictionary of question-answer sets 120 | prediction_set = {} 121 | for sample in new_pred_contents: 122 | id = sample['video_name'] 123 | question = sample['prompt'] 124 | answer = sample['answer'] 125 | pred = sample['text'] 126 | qa_set = {"q": question, "a": answer, "pred": pred} 127 | prediction_set[id] = qa_set 128 | 129 | # Set the OpenAI API key. 130 | openai.api_key = args.api_key 131 | num_tasks = args.num_tasks 132 | 133 | # While loop to ensure that all captions are processed. 134 | while True: 135 | try: 136 | # Files that have not been processed yet. 137 | completed_files = os.listdir(output_dir) 138 | print(f"completed_files: {len(completed_files)}") 139 | 140 | # Files that have not been processed yet. 141 | incomplete_files = [f for f in caption_files if f not in completed_files] 142 | print(f"incomplete_files: {len(incomplete_files)}") 143 | 144 | # Break the loop when there are no incomplete files 145 | if len(incomplete_files) == 0: 146 | break 147 | if len(incomplete_files) <= num_tasks: 148 | num_tasks = 1 149 | 150 | # Split tasks into parts. 151 | part_len = len(incomplete_files) // num_tasks 152 | all_parts = [incomplete_files[i:i + part_len] for i in range(0, len(incomplete_files), part_len)] 153 | task_args = [(prediction_set, part, args.output_dir) for part in all_parts] 154 | 155 | # Use a pool of workers to process the files in parallel. 
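# Roughly equivalent sequential form of the starmap call below, shown only as a
# sketch (each tuple in task_args is unpacked into the three parameters of
# annotate):
# for part in all_parts:
#     annotate(prediction_set, part, args.output_dir)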
156 | with Pool() as pool:
157 | pool.starmap(annotate, task_args)
158 |
159 | except Exception as e:
160 | print(f"Error: {e}")
161 |
162 | # Combine all the processed files into one
163 | combined_contents = {}
164 | json_path = args.output_json
165 |
166 | # Iterate through json files
167 | for file_name in os.listdir(output_dir):
168 | if file_name.endswith(".json"):
169 | file_path = os.path.join(output_dir, file_name)
170 | with open(file_path, "r") as json_file:
171 | content = json.load(json_file)
172 | combined_contents[file_name[:-5]] = content
173 |
174 | # Write combined content to a json file
175 | with open(json_path, "w") as json_file:
176 | json.dump(combined_contents, json_file)
177 | print("All evaluation completed!")
178 |
179 | # Calculate average score
180 | score_sum = 0
181 | count = 0
182 | for key, result in combined_contents.items():
183 | count += 1
184 | score_match = result[0]['score']
185 | score = int(score_match)
186 | score_sum += score
187 | average_score = score_sum / count
188 |
189 | print("Average score for contextual understanding:", average_score)
190 |
191 |
192 | if __name__ == "__main__":
193 | main()
194 |
-------------------------------------------------------------------------------- /eval/vcgbench/gpt_evaluation/evaluate_benchmark_4_temporal.py: --------------------------------------------------------------------------------
1 | """
2 | VCGBench - Evaluation Script for Temporal Understanding (TU) using gpt-3.5-turbo-0613
3 |
4 | Exactly the same prompts are used as proposed in https://github.com/mbzuai-oryx/Video-ChatGPT for a fair comparison with previous methods.
5 | """
6 |
7 | import openai
8 | import os
9 | import argparse
10 | import json
11 | import ast
12 | from multiprocessing.pool import Pool
13 | from tqdm import tqdm
14 |
15 |
16 | def parse_args():
17 | parser = argparse.ArgumentParser(description="VCGBench - Evaluation Script for Temporal Understanding (TU).")
18 | parser.add_argument("--pred_path", required=True, help="The path to file containing prediction.")
19 | parser.add_argument("--output_dir", required=True, help="The path to save annotation json files.")
20 | parser.add_argument("--output_json", required=True, help="The path to save annotation final combined json file.")
21 | parser.add_argument("--api_key", required=True, help="OpenAI API key.")
22 | parser.add_argument("--num_tasks", required=True, type=int, help="Number of splits.")
23 | args = parser.parse_args()
24 | return args
25 |
26 |
27 | def annotate(prediction_set, caption_files, output_dir):
28 | """
29 | Evaluates question and answer pairs using GPT-3 and
30 | returns a score for temporal understanding.
31 | """
32 | for file in tqdm(caption_files):
33 | key = file[:-5] # Strip file extension
34 | qa_set = prediction_set[key]
35 | question = qa_set['q']
36 | answer = qa_set['a']
37 | pred = qa_set['pred']
38 | try:
39 | # Compute the temporal understanding score
40 | completion = openai.ChatCompletion.create(
41 | model="gpt-3.5-turbo-0613",
42 | messages=[
43 | {
44 | "role": "system",
45 | "content":
46 | "You are an intelligent chatbot designed for evaluating the temporal understanding of generative outputs for video-based question-answer pairs. "
47 | "Your task is to compare the predicted answer with the correct answer and determine if they correctly reflect the temporal sequence of events in the video content. 
Here's how you can accomplish the task:" 48 | "------" 49 | "##INSTRUCTIONS: " 50 | "- Focus on the temporal consistency between the predicted answer and the correct answer. The predicted answer should correctly reflect the sequence of events or details as they are presented in the video content.\n" 51 | "- Consider synonyms or paraphrases as valid matches, but only if the temporal order is maintained.\n" 52 | "- Evaluate the temporal accuracy of the prediction compared to the answer." 53 | }, 54 | { 55 | "role": "user", 56 | "content": 57 | "Please evaluate the following video-based question-answer pair:\n\n" 58 | f"Question: {question}\n" 59 | f"Correct Answer: {answer}\n" 60 | f"Predicted Answer: {pred}\n\n" 61 | "Provide your evaluation only as a temporal accuracy score where the temporal accuracy score is an integer value between 0 and 5, with 5 indicating the highest level of temporal consistency. " 62 | "Please generate the response in the form of a Python dictionary string with keys 'score', where its value is the temporal accuracy score in INTEGER, not STRING." 63 | "DO NOT PROVIDE ANY OTHER OUTPUT TEXT OR EXPLANATION. Only provide the Python dictionary string. " 64 | "For example, your response should look like this: {''score': 4.8}." 65 | } 66 | ] 67 | ) 68 | # Convert response to a Python dictionary. 69 | response_message = completion["choices"][0]["message"]["content"] 70 | response_dict = ast.literal_eval(response_message) 71 | result_qa_pair = [response_dict, qa_set] 72 | 73 | # Save the question-answer pairs to a json file. 74 | with open(f"{output_dir}/{key}.json", "w") as f: 75 | json.dump(result_qa_pair, f) 76 | 77 | except Exception as e: 78 | print(f"Error processing file '{key}': {e}") 79 | 80 | 81 | def main(): 82 | """ 83 | Main function to control the flow of the program. 84 | """ 85 | # Parse arguments. 86 | args = parse_args() 87 | 88 | file = args.pred_path 89 | pred_contents = json.load(open(file, 'r')) 90 | 91 | # Dictionary to store the count of occurrences for each video_id 92 | video_id_counts = {} 93 | new_pred_contents = [] 94 | 95 | # Iterate through each sample in pred_contents 96 | for sample in pred_contents: 97 | sample['video_name'] = 1 98 | video_id = sample['video_name'] 99 | if video_id in video_id_counts: 100 | video_id_counts[video_id] += 1 101 | else: 102 | video_id_counts[video_id] = 0 103 | 104 | # Create a new sample with the modified key 105 | new_sample = sample 106 | new_sample['video_name'] = f"{video_id}_{video_id_counts[video_id]}" 107 | new_pred_contents.append(new_sample) 108 | 109 | # Generating list of id's and corresponding files 110 | id_list = [x['video_name'] for x in new_pred_contents] 111 | caption_files = [f"{id}.json" for id in id_list] 112 | 113 | output_dir = args.output_dir 114 | # Generate output directory if not exists. 115 | if not os.path.exists(output_dir): 116 | os.makedirs(output_dir) 117 | 118 | # Preparing dictionary of question-answer sets 119 | prediction_set = {} 120 | for sample in new_pred_contents: 121 | id = sample['video_name'] 122 | question = sample['prompt'] 123 | answer = sample['answer'] 124 | pred = sample['text'] 125 | qa_set = {"q": question, "a": answer, "pred": pred} 126 | prediction_set[id] = qa_set 127 | 128 | # Set the OpenAI API key. 129 | openai.api_key = args.api_key 130 | num_tasks = args.num_tasks 131 | 132 | # While loop to ensure that all captions are processed. 133 | while True: 134 | try: 135 | # Files that have not been processed yet. 
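# Descriptive note (not original code): the listing below collects the result
# files already written to output_dir; any sample whose OpenAI call failed
# inside annotate() is retried on the next pass of the enclosing while loop,
# until no incomplete files remain.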
136 | completed_files = os.listdir(output_dir) 137 | print(f"completed_files: {len(completed_files)}") 138 | 139 | # Files that have not been processed yet. 140 | incomplete_files = [f for f in caption_files if f not in completed_files] 141 | print(f"incomplete_files: {len(incomplete_files)}") 142 | 143 | # Break the loop when there are no incomplete files 144 | if len(incomplete_files) == 0: 145 | break 146 | if len(incomplete_files) <= num_tasks: 147 | num_tasks = 1 148 | 149 | # Split tasks into parts. 150 | part_len = len(incomplete_files) // num_tasks 151 | all_parts = [incomplete_files[i:i + part_len] for i in range(0, len(incomplete_files), part_len)] 152 | task_args = [(prediction_set, part, args.output_dir) for part in all_parts] 153 | 154 | # Use a pool of workers to process the files in parallel. 155 | with Pool() as pool: 156 | pool.starmap(annotate, task_args) 157 | 158 | except Exception as e: 159 | print(f"Error: {e}") 160 | 161 | # Combine all the processed files into one 162 | combined_contents = {} 163 | json_path = args.output_json 164 | 165 | # Iterate through json files 166 | for file_name in os.listdir(output_dir): 167 | if file_name.endswith(".json"): 168 | file_path = os.path.join(output_dir, file_name) 169 | with open(file_path, "r") as json_file: 170 | content = json.load(json_file) 171 | combined_contents[file_name[:-5]] = content 172 | 173 | # Write combined content to a json file 174 | with open(json_path, "w") as json_file: 175 | json.dump(combined_contents, json_file) 176 | print("All evaluation completed!") 177 | 178 | # Calculate average score 179 | score_sum = 0 180 | count = 0 181 | for key, result in combined_contents.items(): 182 | count += 1 183 | score_match = result[0]['score'] 184 | score = int(score_match) 185 | score_sum += score 186 | average_score = score_sum / count 187 | 188 | print("Average score temporal understanding:", average_score) 189 | 190 | 191 | if __name__ == "__main__": 192 | main() 193 | -------------------------------------------------------------------------------- /eval/vcgbench/gpt_evaluation/evaluate_benchmark_5_consistency.py: -------------------------------------------------------------------------------- 1 | """ 2 | VCGBench - Evaluation Script for Consistency (CO) using gpt-3.5-turbo-0613 3 | 4 | Exactly the same prompts are used as proposed in https://github.com/mbzuai-oryx/Video-ChatGPT for a fair comparison with previous methods. 5 | """ 6 | 7 | import openai 8 | import os 9 | import argparse 10 | import json 11 | import ast 12 | from multiprocessing.pool import Pool 13 | from tqdm import tqdm 14 | 15 | 16 | def parse_args(): 17 | parser = argparse.ArgumentParser(description="VCGBench - Evaluation Script for Consistency (CO).") 18 | parser.add_argument("--pred_path", required=True, help="The path to file containing prediction.") 19 | parser.add_argument("--output_dir", required=True, help="The path to save annotation json files.") 20 | parser.add_argument("--output_json", required=True, help="The path to save annotation final combined json file.") 21 | parser.add_argument("--api_key", required=True, help="OpenAI API key.") 22 | parser.add_argument("--num_tasks", required=True, type=int, help="Number of splits.") 23 | args = parser.parse_args() 24 | return args 25 | 26 | 27 | def annotate(prediction_set, caption_files, output_dir): 28 | """ 29 | Evaluates question and answer pairs using GPT-3 and 30 | returns a score for consistency. 
31 | """ 32 | for file in tqdm(caption_files): 33 | key = file[:-5] # Strip file extension 34 | qa_set = prediction_set[key] 35 | question1 = qa_set['q1'] 36 | question2 = qa_set['q2'] 37 | answer = qa_set['a'] 38 | pred1 = qa_set['pred1'] 39 | pred2 = qa_set['pred2'] 40 | try: 41 | # Compute the consistency score 42 | completion = openai.ChatCompletion.create( 43 | model="gpt-3.5-turbo-0613", 44 | messages=[ 45 | { 46 | "role": "system", 47 | "content": 48 | "You are an intelligent chatbot designed for evaluating the consistency of generative outputs for similar video-based question-answer pairs. " 49 | "You will be given two very similar questions, a common answer common to both the questions and predicted answers for the two questions ." 50 | "Your task is to compare the predicted answers for two very similar question, with a common correct answer and determine if they are consistent. Here's how you can accomplish the task:" 51 | "------" 52 | "##INSTRUCTIONS: " 53 | "- Focus on the consistency between the two predicted answers and the correct answer. Both predicted answers should correspond to the correct answer and to each other, and should not contain any contradictions or significant differences in the conveyed information.\n" 54 | "- Both predicted answers must be consistent with each other and the correct answer, in terms of the information they provide about the video content.\n" 55 | "- Consider synonyms or paraphrases as valid matches, but only if they maintain the consistency in the conveyed information.\n" 56 | "- Evaluate the consistency of the two predicted answers compared to the correct answer." 57 | }, 58 | { 59 | "role": "user", 60 | "content": 61 | "Please evaluate the following video-based question-answer pair:\n\n" 62 | f"Question 1: {question1}\n" 63 | f"Question 2: {question2}\n" 64 | f"Correct Answer: {answer}\n" 65 | f"Predicted Answer to Question 1: {pred1}\n" 66 | f"Predicted Answer to Question 2: {pred2}\n\n" 67 | "Provide your evaluation only as a consistency score where the consistency score is an integer value between 0 and 5, with 5 indicating the highest level of consistency. " 68 | "Please generate the response in the form of a Python dictionary string with keys 'score', where its value is the consistency score in INTEGER, not STRING." 69 | "DO NOT PROVIDE ANY OTHER OUTPUT TEXT OR EXPLANATION. Only provide the Python dictionary string. " 70 | "For example, your response should look like this: {''score': 4.8}." 71 | } 72 | ], 73 | ) 74 | # Convert response to a Python dictionary. 75 | response_message = completion["choices"][0]["message"]["content"] 76 | response_dict = ast.literal_eval(response_message) 77 | result_qa_pair = [response_dict, qa_set] 78 | 79 | # Save the question-answer pairs to a json file. 80 | with open(f"{output_dir}/{key}.json", "w") as f: 81 | json.dump(result_qa_pair, f) 82 | 83 | except Exception as e: 84 | print(f"Error processing file '{key}': {e}") 85 | 86 | 87 | def main(): 88 | """ 89 | Main function to control the flow of the program. 90 | """ 91 | # Parse arguments. 
92 | args = parse_args() 93 | 94 | file = args.pred_path 95 | pred_contents = json.load(open(file, 'r')) 96 | 97 | # Dictionary to store the count of occurrences for each video_id 98 | video_id_counts = {} 99 | new_pred_contents = [] 100 | 101 | # Iterate through each sample in pred_contents 102 | for sample in pred_contents: 103 | # video_id = sample['video_name'] 104 | video_id = 1 105 | if video_id in video_id_counts: 106 | video_id_counts[video_id] += 1 107 | else: 108 | video_id_counts[video_id] = 0 109 | 110 | # Create a new sample with the modified key 111 | new_sample = sample 112 | new_sample['video_name'] = f"{video_id}_{video_id_counts[video_id]}" 113 | new_pred_contents.append(new_sample) 114 | 115 | # Generating list of id's and corresponding files 116 | id_list = [x['video_name'] for x in new_pred_contents] 117 | caption_files = [f"{id}.json" for id in id_list] 118 | 119 | output_dir = args.output_dir 120 | # Generate output directory if not exists. 121 | if not os.path.exists(output_dir): 122 | os.makedirs(output_dir) 123 | 124 | # Preparing dictionary of question-answer sets 125 | prediction_set = {} 126 | for sample in new_pred_contents: 127 | id = sample['video_name'] 128 | question1 = sample['prompt_1'] 129 | question2 = sample['prompt_2'] 130 | answer = sample['answer'] 131 | pred1 = sample['text_1'] 132 | pred2 = sample['text_2'] 133 | qa_set = {"q1": question1, "q2": question2, "a": answer, "pred1": pred1, "pred2": pred2} 134 | prediction_set[id] = qa_set 135 | 136 | # Set the OpenAI API key. 137 | openai.api_key = args.api_key 138 | num_tasks = args.num_tasks 139 | 140 | # While loop to ensure that all captions are processed. 141 | while True: 142 | try: 143 | # Files that have not been processed yet. 144 | completed_files = os.listdir(output_dir) 145 | print(f"completed_files: {len(completed_files)}") 146 | 147 | # Files that have not been processed yet. 148 | incomplete_files = [f for f in caption_files if f not in completed_files] 149 | print(f"incomplete_files: {len(incomplete_files)}") 150 | 151 | # Break the loop when there are no incomplete files 152 | if len(incomplete_files) == 0: 153 | break 154 | if len(incomplete_files) <= num_tasks: 155 | num_tasks = 1 156 | 157 | # Split tasks into parts. 158 | part_len = len(incomplete_files) // num_tasks 159 | all_parts = [incomplete_files[i:i + part_len] for i in range(0, len(incomplete_files), part_len)] 160 | task_args = [(prediction_set, part, args.output_dir) for part in all_parts] 161 | 162 | # Use a pool of workers to process the files in parallel. 
163 | with Pool() as pool:
164 | pool.starmap(annotate, task_args)
165 |
166 | except Exception as e:
167 | print(f"Error: {e}")
168 |
169 | # Combine all the processed files into one
170 | combined_contents = {}
171 | json_path = args.output_json
172 |
173 | # Iterate through json files
174 | for file_name in os.listdir(output_dir):
175 | if file_name.endswith(".json"):
176 | file_path = os.path.join(output_dir, file_name)
177 | with open(file_path, "r") as json_file:
178 | content = json.load(json_file)
179 | combined_contents[file_name[:-5]] = content
180 |
181 | # Write combined content to a json file
182 | with open(json_path, "w") as json_file:
183 | json.dump(combined_contents, json_file)
184 | print("All evaluation completed!")
185 |
186 | # Calculate average score
187 | score_sum = 0
188 | count = 0
189 | for key, result in combined_contents.items():
190 | count += 1
191 | score_match = result[0]['score']
192 | score = int(score_match)
193 | score_sum += score
194 | average_score = score_sum / count
195 |
196 | print("Average score for consistency:", average_score)
197 |
198 |
199 | if __name__ == "__main__":
200 | main()
201 |
-------------------------------------------------------------------------------- /eval/vcgbench/gpt_evaluation/vcgbench_evaluate.sh: --------------------------------------------------------------------------------
1 | #!/bin/sh
2 |
3 | ## Path to directory containing predictions (answer-video-generic.json, answer-video-temporal.json, answer-video-consistency.json)
4 | PRED_PATH=$1
5 | ## Path to save the results
6 | OUTPUT_DIR_PATH=$2
7 | ## OpenAI API Key
8 | OPENAI_API_KEY=$3
9 |
10 | python evaluate_benchmark_1_correctness.py --pred_path "$PRED_PATH/answer-video-generic.json" --output_dir "$OUTPUT_DIR_PATH/correctness" --output_json "$OUTPUT_DIR_PATH/correctness.json" --api_key "$OPENAI_API_KEY" --num_tasks 16
11 |
12 |
13 | python evaluate_benchmark_2_detailed_orientation.py --pred_path "$PRED_PATH/answer-video-generic.json" --output_dir "$OUTPUT_DIR_PATH/detail" --output_json "$OUTPUT_DIR_PATH/detail.json" --api_key "$OPENAI_API_KEY" --num_tasks 16
14 |
15 |
16 | python evaluate_benchmark_3_context.py --pred_path "$PRED_PATH/answer-video-generic.json" --output_dir "$OUTPUT_DIR_PATH/context" --output_json "$OUTPUT_DIR_PATH/context.json" --api_key "$OPENAI_API_KEY" --num_tasks 16
17 |
18 |
19 | python evaluate_benchmark_4_temporal.py --pred_path "$PRED_PATH/answer-video-temporal.json" --output_dir "$OUTPUT_DIR_PATH/temporal" --output_json "$OUTPUT_DIR_PATH/temporal.json" --api_key "$OPENAI_API_KEY" --num_tasks 16
20 |
21 |
22 | python evaluate_benchmark_5_consistency.py --pred_path "$PRED_PATH/answer-video-consistency.json" --output_dir "$OUTPUT_DIR_PATH/consistency" --output_json "$OUTPUT_DIR_PATH/consistency.json" --api_key "$OPENAI_API_KEY" --num_tasks 16
23 |
-------------------------------------------------------------------------------- /eval/vcgbench/inference/ddp.py: --------------------------------------------------------------------------------
1 | import json
2 | from torch.utils.data import Dataset
3 | import torch
4 | import subprocess
5 | from videogpt_plus.constants import *
6 | from eval.video_encoding import _get_rawvideo_dec
7 |
8 |
9 | class EvalDatasetGeneric(Dataset):
10 | def __init__(self, qa_path, video_dir, image_processor, video_processor):
11 | with open(qa_path) as file:
12 | self.gt_contents = json.load(file)
13 | self.video_dir = video_dir
14 | self.image_processor = image_processor
15 | self.video_processor = 
video_processor 16 | 17 | self.video_formats = ['.mp4', '.avi', '.mov', '.mkv'] 18 | 19 | def __len__(self): 20 | return len(self.gt_contents) 21 | 22 | def __getitem__(self, idx): 23 | sample = self.gt_contents[idx] 24 | video_name = sample['video_name'] 25 | sample_set = sample 26 | 27 | # Load the video file 28 | for fmt in self.video_formats: # Added this line 29 | temp_path = os.path.join(self.video_dir, f"{video_name}{fmt}") 30 | if os.path.exists(temp_path): 31 | video_path = temp_path 32 | break 33 | 34 | # Check if the video exists 35 | if os.path.exists(video_path): # Modified this line 36 | video_frames, context_frames, slice_len = _get_rawvideo_dec(video_path, self.image_processor, 37 | self.video_processor, 38 | max_frames=NUM_FRAMES, 39 | image_resolution=224, 40 | num_video_frames=NUM_FRAMES, 41 | num_context_images=NUM_CONTEXT_IMAGES) 42 | else: 43 | print(f'Video {video_path} not found') 44 | video_frames, context_frames, slice_len = "None", "None", 0 45 | 46 | return idx, [sample_set], video_frames, context_frames, slice_len 47 | 48 | 49 | def setup_for_distributed(is_master): 50 | """ 51 | This function disables printing when not in master process 52 | """ 53 | import builtins as __builtin__ 54 | builtin_print = __builtin__.print 55 | 56 | def print(*args, **kwargs): 57 | force = kwargs.pop('force', False) 58 | if is_master or force: 59 | builtin_print(*args, **kwargs) 60 | 61 | __builtin__.print = print 62 | 63 | 64 | def init_distributed_mode(args): 65 | if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ: 66 | args.rank = int(os.environ["RANK"]) 67 | args.world_size = int(os.environ['WORLD_SIZE']) 68 | args.gpu = int(os.environ['LOCAL_RANK']) 69 | args.dist_url = 'env://' 70 | os.environ['LOCAL_SIZE'] = str(torch.cuda.device_count()) 71 | print('Using distributed mode: 1') 72 | elif 'SLURM_PROCID' in os.environ: 73 | proc_id = int(os.environ['SLURM_PROCID']) 74 | ntasks = int(os.environ['SLURM_NTASKS']) 75 | node_list = os.environ['SLURM_NODELIST'] 76 | num_gpus = torch.cuda.device_count() 77 | addr = subprocess.getoutput( 78 | 'scontrol show hostname {} | head -n1'.format(node_list)) 79 | os.environ['MASTER_PORT'] = os.environ.get('MASTER_PORT', '3460') 80 | os.environ['MASTER_ADDR'] = addr 81 | os.environ['WORLD_SIZE'] = str(ntasks) 82 | os.environ['RANK'] = str(proc_id) 83 | os.environ['LOCAL_RANK'] = str(proc_id % num_gpus) 84 | os.environ['LOCAL_SIZE'] = str(num_gpus) 85 | args.dist_url = 'env://' 86 | args.world_size = ntasks 87 | args.rank = proc_id 88 | args.gpu = proc_id % num_gpus 89 | print('Using distributed mode: slurm') 90 | print(f"world: {os.environ['WORLD_SIZE']}, rank:{os.environ['RANK']}," 91 | f" local_rank{os.environ['LOCAL_RANK']}, local_size{os.environ['LOCAL_SIZE']}") 92 | else: 93 | print('Not using distributed mode') 94 | args.distributed = False 95 | return 96 | 97 | args.distributed = True 98 | 99 | torch.cuda.set_device(args.gpu) 100 | args.dist_backend = 'nccl' 101 | print('| distributed init (rank {}): {}'.format( 102 | args.rank, args.dist_url), flush=True) 103 | torch.distributed.init_process_group(backend=args.dist_backend, init_method=args.dist_url, 104 | world_size=args.world_size, rank=args.rank) 105 | torch.distributed.barrier() 106 | setup_for_distributed(args.rank == 0) 107 | -------------------------------------------------------------------------------- /eval/vcgbench/inference/infer_consistency.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from tqdm import tqdm 
3 | import shortuuid 4 | from videogpt_plus.conversation import conv_templates 5 | from videogpt_plus.model.builder import load_pretrained_model 6 | from videogpt_plus.mm_utils import tokenizer_image_token, get_model_name_from_path 7 | from eval.vcgbench.inference.ddp import * 8 | from torch.utils.data import DataLoader, DistributedSampler 9 | import traceback 10 | 11 | 12 | def disable_torch_init(): 13 | """ 14 | Disable the redundant torch default initialization to accelerate model creation. 15 | """ 16 | import torch 17 | setattr(torch.nn.Linear, "reset_parameters", lambda self: None) 18 | setattr(torch.nn.LayerNorm, "reset_parameters", lambda self: None) 19 | 20 | 21 | def eval_model(args): 22 | # Model 23 | disable_torch_init() 24 | model_path = os.path.expanduser(args.model_path) 25 | model_name = get_model_name_from_path(args.model_path) 26 | tokenizer, model, image_processor, context_len = load_pretrained_model(model_path, args.model_base, model_name) 27 | 28 | mm_use_im_start_end = getattr(model.config, "mm_use_im_start_end", False) 29 | mm_use_im_patch_token = getattr(model.config, "mm_use_im_patch_token", True) 30 | if mm_use_im_patch_token: 31 | tokenizer.add_tokens([DEFAULT_IMAGE_PATCH_TOKEN], special_tokens=True) 32 | if mm_use_im_start_end: 33 | tokenizer.add_tokens([DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True) 34 | model.resize_token_embeddings(len(tokenizer)) 35 | 36 | vision_tower = model.get_vision_tower() 37 | vision_tower.load_model(model.config.mm_vision_tower) 38 | video_processor = vision_tower.image_processor 39 | 40 | image_vision_tower = model.get_image_vision_tower() 41 | image_vision_tower.load_model() 42 | image_processor = image_vision_tower.image_processor 43 | 44 | model = model.to("cuda") 45 | 46 | dataset = EvalDatasetGeneric(args.question_file, args.video_folder, image_processor, video_processor) 47 | distributed_sampler = DistributedSampler(dataset, rank=args.rank, shuffle=False) 48 | dataloader = DataLoader(dataset, batch_size=args.batch_size_per_gpu, num_workers=4, sampler=distributed_sampler) 49 | 50 | for (idx, sample_set, video_frames, context_frames, slice_len) in tqdm(dataloader): 51 | idx, sample_set, video_frames, context_frames, slice_len = int(idx[0]), sample_set[ 52 | 0], video_frames, context_frames, int(slice_len[0]) 53 | 54 | sample = sample_set 55 | question_1 = sample['Q1'][0] 56 | question_2 = sample['Q2'][0] 57 | 58 | try: 59 | qs = question_1 60 | if model.config.mm_use_im_start_end: 61 | qs = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN * slice_len + DEFAULT_IM_END_TOKEN + '\n' + qs 62 | else: 63 | qs = DEFAULT_IMAGE_TOKEN * slice_len + '\n' + qs 64 | 65 | conv = conv_templates[args.conv_mode].copy() 66 | conv.append_message(conv.roles[0], qs) 67 | conv.append_message(conv.roles[1], None) 68 | prompt = conv.get_prompt() 69 | 70 | input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, 71 | return_tensors='pt').unsqueeze(0).cuda() 72 | 73 | stop_str = "<|end|>" 74 | 75 | with torch.inference_mode(): 76 | output_ids = model.generate( 77 | input_ids, 78 | images=torch.cat(video_frames, dim=0).half().cuda(), 79 | context_images=torch.cat(context_frames, dim=0).half().cuda(), 80 | do_sample=True if args.temperature > 0 else False, 81 | temperature=args.temperature, 82 | top_p=args.top_p, 83 | num_beams=args.num_beams, 84 | max_new_tokens=1024, 85 | use_cache=True) 86 | 87 | input_token_len = input_ids.shape[1] 88 | n_diff_input_output = (input_ids != output_ids[:, :input_token_len]).sum().item() 89 
| if n_diff_input_output > 0: 90 | print(f'[Warning] {n_diff_input_output} output_ids are not the same as the input_ids') 91 | outputs = tokenizer.batch_decode(output_ids[:, input_token_len:], skip_special_tokens=True)[0] 92 | outputs = outputs.strip() 93 | if outputs.endswith(stop_str): 94 | outputs = outputs[:-len(stop_str)] 95 | outputs_1 = outputs.strip() 96 | outputs_1 = outputs_1.replace("<|end|>", '') 97 | outputs_1 = outputs_1.strip() 98 | 99 | qs = question_2 100 | if model.config.mm_use_im_start_end: 101 | qs = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN * slice_len + DEFAULT_IM_END_TOKEN + '\n' + qs 102 | else: 103 | qs = DEFAULT_IMAGE_TOKEN * slice_len + '\n' + qs 104 | 105 | conv = conv_templates[args.conv_mode].copy() 106 | conv.append_message(conv.roles[0], qs) 107 | conv.append_message(conv.roles[1], None) 108 | prompt = conv.get_prompt() 109 | 110 | input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze( 111 | 0).cuda() 112 | 113 | stop_str = "<|end|>" 114 | 115 | with torch.inference_mode(): 116 | output_ids = model.generate( 117 | input_ids, 118 | images=torch.cat(video_frames, dim=0).half().cuda(), 119 | context_images=torch.cat(context_frames, dim=0).half().cuda(), 120 | do_sample=True if args.temperature > 0 else False, 121 | temperature=args.temperature, 122 | top_p=args.top_p, 123 | num_beams=args.num_beams, 124 | max_new_tokens=1024, 125 | use_cache=True) 126 | 127 | input_token_len = input_ids.shape[1] 128 | n_diff_input_output = (input_ids != output_ids[:, :input_token_len]).sum().item() 129 | if n_diff_input_output > 0: 130 | print(f'[Warning] {n_diff_input_output} output_ids are not the same as the input_ids') 131 | outputs = tokenizer.batch_decode(output_ids[:, input_token_len:], skip_special_tokens=True)[0] 132 | outputs = outputs.strip() 133 | if outputs.endswith(stop_str): 134 | outputs = outputs[:-len(stop_str)] 135 | outputs_2 = outputs.strip() 136 | outputs_2 = outputs_2.replace("<|end|>", '') 137 | outputs_2 = outputs_2.strip() 138 | 139 | ans_id = shortuuid.uuid() 140 | results = {'video_name': sample['video_name'][0], 141 | "prompt_1": question_1, 142 | "text_1": outputs_1, 143 | "prompt_2": question_2, 144 | "text_2": outputs_2, 145 | "answer_id": ans_id, 146 | "model_id": model_name, 147 | "answer": sample['A'][0], 148 | "metadata": {}} 149 | with open(f"{args.output_dir}/{sample['video_name'][0]}_{idx}.json", "w") as f: 150 | json.dump(results, f) 151 | except Exception as e: 152 | trace = traceback.format_exc() 153 | print(f"Error processing video file '{sample['video_name'][0]}': {e}") 154 | print("Detailed traceback:") 155 | print(trace) 156 | 157 | 158 | if __name__ == "__main__": 159 | parser = argparse.ArgumentParser() 160 | parser.add_argument("--model-path", type=str, default="MBZUAI/VideoGPT-plus_Phi3-mini-4k/vcgbench") 161 | parser.add_argument("--model-base", type=str, default="microsoft/Phi-3-mini-4k-instruct") 162 | parser.add_argument("--video-folder", type=str, default="VCGBench/Test_Videos") 163 | parser.add_argument("--question-file", type=str, default="VCGBench/Benchmarking_QA/consistency_qa.json") 164 | parser.add_argument("--output-dir", type=str, 165 | default="MBZUAI/VideoGPT-plus_Phi3-mini-4k/vcgbench/vcgbench_eval/answer-vcgbench-consistency") 166 | parser.add_argument("--conv-mode", type=str, default="phi3_instruct") 167 | parser.add_argument("--temperature", type=float, default=0.0) 168 | parser.add_argument("--top_p", type=float, default=None) 169 | 
parser.add_argument("--num_beams", type=int, default=1) 170 | 171 | parser.add_argument("--batch_size_per_gpu", required=False, default=1) 172 | parser.add_argument('--world_size', default=1, type=int, help='number of distributed processes') 173 | parser.add_argument('--local_rank', default=-1, type=int) 174 | parser.add_argument('--dist_url', default='env://', help='url used to set up distributed training') 175 | 176 | args = parser.parse_args() 177 | 178 | init_distributed_mode(args) 179 | 180 | os.makedirs(args.output_dir, exist_ok=True) 181 | 182 | eval_model(args) 183 | -------------------------------------------------------------------------------- /eval/vcgbench/inference/infer_general.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from tqdm import tqdm 3 | import shortuuid 4 | from videogpt_plus.conversation import conv_templates 5 | from videogpt_plus.model.builder import load_pretrained_model 6 | from videogpt_plus.mm_utils import tokenizer_image_token, get_model_name_from_path 7 | from eval.vcgbench.inference.ddp import * 8 | from torch.utils.data import DataLoader, DistributedSampler 9 | import traceback 10 | 11 | 12 | def disable_torch_init(): 13 | """ 14 | Disable the redundant torch default initialization to accelerate model creation. 15 | """ 16 | import torch 17 | setattr(torch.nn.Linear, "reset_parameters", lambda self: None) 18 | setattr(torch.nn.LayerNorm, "reset_parameters", lambda self: None) 19 | 20 | 21 | def eval_model(args): 22 | # Model 23 | disable_torch_init() 24 | model_path = os.path.expanduser(args.model_path) 25 | model_name = get_model_name_from_path(args.model_path) 26 | tokenizer, model, image_processor, context_len = load_pretrained_model(model_path, args.model_base, model_name) 27 | 28 | mm_use_im_start_end = getattr(model.config, "mm_use_im_start_end", False) 29 | mm_use_im_patch_token = getattr(model.config, "mm_use_im_patch_token", True) 30 | if mm_use_im_patch_token: 31 | tokenizer.add_tokens([DEFAULT_IMAGE_PATCH_TOKEN], special_tokens=True) 32 | if mm_use_im_start_end: 33 | tokenizer.add_tokens([DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True) 34 | model.resize_token_embeddings(len(tokenizer)) 35 | 36 | vision_tower = model.get_vision_tower() 37 | vision_tower.load_model(model.config.mm_vision_tower) 38 | video_processor = vision_tower.image_processor 39 | 40 | image_vision_tower = model.get_image_vision_tower() 41 | image_vision_tower.load_model() 42 | image_processor = image_vision_tower.image_processor 43 | 44 | model = model.to("cuda") 45 | 46 | dataset = EvalDatasetGeneric(args.question_file, args.video_folder, image_processor, video_processor) 47 | distributed_sampler = DistributedSampler(dataset, rank=args.rank, shuffle=False) 48 | dataloader = DataLoader(dataset, batch_size=args.batch_size_per_gpu, num_workers=4, sampler=distributed_sampler) 49 | 50 | for (idx, sample_set, video_frames, context_frames, slice_len) in tqdm(dataloader): 51 | idx, sample_set, video_frames, context_frames, slice_len = int(idx[0]), sample_set[ 52 | 0], video_frames, context_frames, int(slice_len[0]) 53 | 54 | sample = sample_set 55 | qs = sample['Q'][0] 56 | 57 | try: 58 | cur_prompt = qs 59 | if model.config.mm_use_im_start_end: 60 | qs = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN * slice_len + DEFAULT_IM_END_TOKEN + '\n' + qs 61 | else: 62 | qs = DEFAULT_IMAGE_TOKEN * slice_len + '\n' + qs 63 | 64 | conv = conv_templates[args.conv_mode].copy() 65 | 
conv.append_message(conv.roles[0], qs) 66 | conv.append_message(conv.roles[1], None) 67 | prompt = conv.get_prompt() 68 | 69 | input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, 70 | return_tensors='pt').unsqueeze(0).cuda() 71 | 72 | stop_str = "<|end|>" 73 | 74 | with torch.inference_mode(): 75 | output_ids = model.generate( 76 | input_ids, 77 | images=torch.cat(video_frames, dim=0).half().cuda(), 78 | context_images=torch.cat(context_frames, dim=0).half().cuda(), 79 | do_sample=True if args.temperature > 0 else False, 80 | temperature=args.temperature, 81 | top_p=args.top_p, 82 | num_beams=args.num_beams, 83 | max_new_tokens=1024, 84 | use_cache=True) 85 | 86 | input_token_len = input_ids.shape[1] 87 | n_diff_input_output = (input_ids != output_ids[:, :input_token_len]).sum().item() 88 | if n_diff_input_output > 0: 89 | print(f'[Warning] {n_diff_input_output} output_ids are not the same as the input_ids') 90 | outputs = tokenizer.batch_decode(output_ids[:, input_token_len:], skip_special_tokens=True)[0] 91 | outputs = outputs.strip() 92 | if outputs.endswith(stop_str): 93 | outputs = outputs[:-len(stop_str)] 94 | outputs = outputs.strip() 95 | outputs = outputs.replace("<|end|>", '') 96 | outputs = outputs.strip() 97 | 98 | ans_id = shortuuid.uuid() 99 | results = {'video_name': sample['video_name'][0], 100 | "prompt": cur_prompt, 101 | "text": outputs, 102 | "answer_id": ans_id, 103 | "model_id": model_name, 104 | "answer": sample['A'][0], 105 | "metadata": {}} 106 | with open(f"{args.output_dir}/{sample['video_name'][0]}_{idx}.json", "w") as f: 107 | json.dump(results, f) 108 | except Exception as e: 109 | trace = traceback.format_exc() 110 | print(f"Error processing video file '{sample['video_name'][0]}': {e}") 111 | print("Detailed traceback:") 112 | print(trace) 113 | 114 | 115 | if __name__ == "__main__": 116 | parser = argparse.ArgumentParser() 117 | parser.add_argument("--model-path", type=str, default="MBZUAI/VideoGPT-plus_Phi3-mini-4k/vcgbench") 118 | parser.add_argument("--model-base", type=str, default="microsoft/Phi-3-mini-4k-instruct") 119 | parser.add_argument("--video-folder", type=str, default="VCGBench/Test_Videos") 120 | parser.add_argument("--question-file", type=str, default="VCGBench/Benchmarking_QA/generic_qa.json") 121 | parser.add_argument("--output-dir", type=str, default="MBZUAI/VideoGPT-plus_Phi3-mini-4k/vcgbench/vcgbench_eval/answer-vcgbench-general") 122 | parser.add_argument("--conv-mode", type=str, default="phi3_instruct") 123 | parser.add_argument("--temperature", type=float, default=0.0) 124 | parser.add_argument("--top_p", type=float, default=None) 125 | parser.add_argument("--num_beams", type=int, default=1) 126 | 127 | parser.add_argument("--batch_size_per_gpu", required=False, default=1) 128 | parser.add_argument('--world_size', default=1, type=int, help='number of distributed processes') 129 | parser.add_argument('--local_rank', default=-1, type=int) 130 | parser.add_argument('--dist_url', default='env://', help='url used to set up distributed training') 131 | 132 | args = parser.parse_args() 133 | 134 | init_distributed_mode(args) 135 | 136 | os.makedirs(args.output_dir, exist_ok=True) 137 | 138 | eval_model(args) 139 | -------------------------------------------------------------------------------- /eval/vcgbench/inference/run_ddp_inference.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | 3 | # Update the number of gpus as per your configuration 4 | NUM_GPUS=8 5 | 
MODEL_PATH=MBZUAI/VideoGPT-plus_Phi3-mini-4k/vcgbench 6 | MODEL_BASE=microsoft/Phi-3-mini-4k-instruct 7 | VCGBench_PATH=MBZUAI/VCGBench 8 | 9 | export PYTHONPATH="./:$PYTHONPATH" 10 | 11 | # General 12 | torchrun --nproc_per_node="$NUM_GPUS" eval/vcgbench/inference/infer_general.py --model-path "$MODEL_PATH" --model-base "$MODEL_BASE" --video-folder "$VCGBench_PATH/Test_Videos" --question-file "$VCGBench_PATH/Benchmarking_QA/generic_qa.json" --output-dir "$MODEL_PATH/vcgbench_eval/answer-video-generic" --conv-mode "phi3_instruct" 13 | python eval/merge.py --input_dir "$MODEL_PATH/vcgbench_eval/answer-video-generic" 14 | 15 | 16 | # Temporal 17 | torchrun --nproc_per_node="$NUM_GPUS" eval/vcgbench/inference/infer_general.py --model-path "$MODEL_PATH" --model-base "$MODEL_BASE" --video-folder "$VCGBench_PATH/Test_Videos" --question-file "$VCGBench_PATH/Benchmarking_QA/temporal_qa.json" --output-dir "$MODEL_PATH/vcgbench_eval/answer-video-temporal" --conv-mode "phi3_instruct" 18 | python eval/merge.py --input_dir "$MODEL_PATH/vcgbench_eval/answer-video-temporal" 19 | 20 | 21 | # Consistency 22 | torchrun --nproc_per_node="$NUM_GPUS" eval/vcgbench/inference/infer_consistency.py --model-path "$MODEL_PATH" --model-base "$MODEL_BASE" --video-folder "$VCGBench_PATH/Test_Videos" --question-file "$VCGBench_PATH/Benchmarking_QA/consistency_qa.json" --output-dir "$MODEL_PATH/vcgbench_eval/answer-video-consistency" --conv-mode "phi3_instruct" 23 | python eval/merge.py --input_dir "$MODEL_PATH/vcgbench_eval/answer-video-consistency" 24 | -------------------------------------------------------------------------------- /eval/vcgbench_diverse/gpt_evaluation/1_correctness_of_information.py: -------------------------------------------------------------------------------- 1 | """ 2 | VCGBench-Diverse - Evaluation Script for Correctness of Information (CI) using gpt-3.5-turbo-0125 3 | 4 | Copyright 2024 MBZUAI ORYX 5 | 6 | Licensed under the Apache License, Version 2.0 (the "License"); 7 | you may not use this file except in compliance with the License. 8 | You may obtain a copy of the License at 9 | 10 | http://www.apache.org/licenses/LICENSE-2.0 11 | 12 | Unless required by applicable law or agreed to in writing, software 13 | distributed under the License is distributed on an "AS IS" BASIS, 14 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 15 | See the License for the specific language governing permissions and 16 | limitations under the License. 
17 | """ 18 | 19 | import openai 20 | import os 21 | import argparse 22 | import json 23 | import ast 24 | from multiprocessing.pool import Pool 25 | from tqdm import tqdm 26 | 27 | 28 | def parse_args(): 29 | parser = argparse.ArgumentParser(description="VCGBench-Diverse - Evaluation Script for Correctness of Information (CI)") 30 | parser.add_argument("--pred_path", required=True, help="The path to file containing prediction.") 31 | parser.add_argument("--output_dir", required=True, help="The path to save annotation json files.") 32 | parser.add_argument("--output_json", required=True, help="The path to save annotation final combined json file.") 33 | parser.add_argument("--gt_json_path", required=True, help="The path to file containing ground_truths.") 34 | parser.add_argument("--api_key", required=True, help="OpenAI API key.") 35 | parser.add_argument("--num_tasks", required=True, type=int, help="Number of splits.") 36 | args = parser.parse_args() 37 | return args 38 | 39 | 40 | def annotate(prediction_set, caption_files, output_dir): 41 | """ 42 | Evaluates question and answer pairs using GPT-3 43 | Returns a score for correctness. 44 | """ 45 | for file in tqdm(caption_files): 46 | key = file.split('.')[0] # Strip file extension 47 | qa_set = prediction_set[int(key)] 48 | question = qa_set['q'] 49 | answer = qa_set['a'] 50 | pred = qa_set['pred'] 51 | try: 52 | # Compute the correctness score 53 | completion = openai.ChatCompletion.create( 54 | model="gpt-3.5-turbo-0125", 55 | temperature=0.0, 56 | messages=[ 57 | { 58 | "role": "system", 59 | "content": 60 | "You are an AI assistant tasked with evaluating the factual accuracy of generative outputs for video-based question-answer pairs. " 61 | "Your task is to compare the predicted answer with the correct answer and determine if they are factually consistent." 62 | "------" 63 | "##INSTRUCTIONS: " 64 | "- Focus on the factual consistency between the predicted answer and the correct answer. The predicted answer should correctly reflect the factual information presented in the video and should not contain any misinterpretations or misinformation.\n" 65 | "- Consider synonyms or paraphrases as valid matches, but only if the response is factually accurate and align with the video content.\n" 66 | "- Evaluate the factual accuracy of the prediction compared to the answer, do not assume anything from the world knowledge.\n" 67 | "- Assign a factual accuracy score between 0 and 5, where 5 indicates the highest level of factual consistency.\n" 68 | "- Base your evaluation on the following scale:\n" 69 | " 5: PERFECT match in terms of correctness with no factual errors.\n" 70 | " 4: Very little discrepancies in details, but the information generated is mostly correct and aligns with the video content.\n" 71 | " 3: Mostly correct information with minor discrepancies.\n" 72 | " 2: Very little correct information, though some parts are correct.\n" 73 | " 1: Mostly incorrect or irrelevant details, though some parts are correct\n" 74 | " 0: COMPLETELY incorrect response with no factual consistency.\n" 75 | }, 76 | { 77 | "role": "user", 78 | "content": 79 | "Please evaluate the following video-based question-answer pair:\n\n" 80 | f"Question: {question}\n" 81 | f"Correct Answer: {answer}\n" 82 | f"Predicted Answer: {pred}\n\n" 83 | "Provide your evaluation only as a factual accuracy score where the factual accuracy score is an integer value between 0 and 5, with 5 indicating the highest level of factual consistency. 
" 84 | "Please generate the response in the form of a Python dictionary string with keys 'score', where its value is the factual accuracy score in INTEGER, not STRING." 85 | "DO NOT PROVIDE ANY OTHER OUTPUT TEXT OR EXPLANATION. Only provide the Python dictionary string. " 86 | "For example, your response should look like this: {'score': 2}." 87 | } 88 | ] 89 | ) 90 | # Convert response to a Python dictionary. 91 | response_message = completion["choices"][0]["message"]["content"] 92 | response_dict = ast.literal_eval(response_message) 93 | result_qa_pair = [response_dict, qa_set] 94 | 95 | # Save the question-answer pairs to a json file. 96 | with open(f"{output_dir}/{key}.json", "w") as f: 97 | json.dump(result_qa_pair, f) 98 | 99 | except Exception as e: 100 | print(f"Error processing file '{key}': {e}") 101 | 102 | 103 | def main(): 104 | """ 105 | Main function to control the flow of the program. 106 | """ 107 | # Parse arguments. 108 | args = parse_args() 109 | 110 | file = args.pred_path 111 | pred_contents = json.load(open(file, 'r')) 112 | 113 | # Read GT file 114 | gt_contents = json.load(open(args.gt_json_path, 'r')) 115 | types = ['summary', 'spatial', 'reasoning'] 116 | generic_ids = [x['id'] for x in gt_contents if x['type'] in types] 117 | # Generating list of id's and corresponding files 118 | id_list = [x['ann_id'] for x in pred_contents if x['ann_id'] in generic_ids] 119 | caption_files = [f"{id}.json" for id in id_list] 120 | 121 | output_dir = args.output_dir 122 | # Generate output directory if not exists. 123 | if not os.path.exists(output_dir): 124 | os.makedirs(output_dir) 125 | 126 | # Preparing dictionary of question-answer sets 127 | prediction_set = {} 128 | for sample in pred_contents: 129 | id = sample['ann_id'] 130 | if id in id_list: 131 | question = sample['prompt'] 132 | answer = sample['answer'] 133 | pred = sample['text'] 134 | qa_set = {"ann_id": id, "q": question, "a": answer, "pred": pred} 135 | prediction_set[id] = qa_set 136 | 137 | # Set the OpenAI API key. 138 | openai.api_key = args.api_key 139 | num_tasks = args.num_tasks 140 | 141 | # While loop to ensure that all captions are processed. 142 | while True: 143 | try: 144 | # Files that have not been processed yet. 145 | completed_files = os.listdir(output_dir) 146 | print(f"completed_files: {len(completed_files)}") 147 | 148 | # Files that have not been processed yet. 149 | incomplete_files = [f for f in caption_files if f not in completed_files] 150 | print(f"incomplete_files: {len(incomplete_files)}") 151 | 152 | # Break the loop when there are no incomplete files 153 | if len(incomplete_files) == 0: 154 | break 155 | if len(incomplete_files) <= num_tasks: 156 | num_tasks = 1 157 | 158 | # Split tasks into parts. 159 | part_len = len(incomplete_files) // num_tasks 160 | all_parts = [incomplete_files[i:i + part_len] for i in range(0, len(incomplete_files), part_len)] 161 | task_args = [(prediction_set, part, args.output_dir) for part in all_parts] 162 | 163 | # Use a pool of workers to process the files in parallel. 
164 | with Pool() as pool: 165 | pool.starmap(annotate, task_args) 166 | 167 | except Exception as e: 168 | print(f"Error: {e}") 169 | 170 | # Combine all the processed files into one 171 | combined_contents = {} 172 | json_path = args.output_json 173 | 174 | # Iterate through json files 175 | for file_name in os.listdir(output_dir): 176 | if file_name.endswith(".json"): 177 | file_path = os.path.join(output_dir, file_name) 178 | with open(file_path, "r") as json_file: 179 | content = json.load(json_file) 180 | key = file_name.split(".")[0] 181 | combined_contents[key] = content 182 | 183 | # Write combined content to a json file 184 | with open(json_path, "w") as json_file: 185 | json.dump(combined_contents, json_file) 186 | print("All evaluation completed!") 187 | 188 | # Calculate average score 189 | score_sum = 0 190 | count = 0 191 | for key, result in combined_contents.items(): 192 | count += 1 193 | score_match = result[0]['score'] 194 | score = int(score_match) 195 | score_sum += score 196 | average_score = score_sum / count 197 | 198 | print("Average score for correctness:", average_score) 199 | 200 | 201 | if __name__ == "__main__": 202 | main() 203 | -------------------------------------------------------------------------------- /eval/vcgbench_diverse/gpt_evaluation/3_contextual_information.py: -------------------------------------------------------------------------------- 1 | """ 2 | VCGBench-Diverse - Evaluation Script for Contextual Understanding (CU) using gpt-3.5-turbo-0125 3 | 4 | Copyright 2024 MBZUAI ORYX 5 | 6 | Licensed under the Apache License, Version 2.0 (the "License"); 7 | you may not use this file except in compliance with the License. 8 | You may obtain a copy of the License at 9 | 10 | http://www.apache.org/licenses/LICENSE-2.0 11 | 12 | Unless required by applicable law or agreed to in writing, software 13 | distributed under the License is distributed on an "AS IS" BASIS, 14 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 15 | See the License for the specific language governing permissions and 16 | limitations under the License. 17 | """ 18 | 19 | import openai 20 | import os 21 | import argparse 22 | import json 23 | import ast 24 | from multiprocessing.pool import Pool 25 | from tqdm import tqdm 26 | 27 | 28 | def parse_args(): 29 | parser = argparse.ArgumentParser(description="VCGBench-Diverse - Evaluation Script for Contextual Understanding (CU)") 30 | parser.add_argument("--pred_path", required=True, help="The path to file containing prediction.") 31 | parser.add_argument("--output_dir", required=True, help="The path to save annotation json files.") 32 | parser.add_argument("--output_json", required=True, help="The path to save annotation final combined json file.") 33 | parser.add_argument("--gt_json_path", required=True, help="The path to file containing ground_truths.") 34 | parser.add_argument("--api_key", required=True, help="OpenAI API key.") 35 | parser.add_argument("--num_tasks", required=True, type=int, help="Number of splits.") 36 | args = parser.parse_args() 37 | return args 38 | 39 | 40 | def annotate(prediction_set, caption_files, output_dir): 41 | """ 42 | Evaluates question and answer pairs using GPT-3 and 43 | returns a score for contextual understanding. 
44 | """ 45 | for file in tqdm(caption_files): 46 | key = file.split('.')[0] # Strip file extension 47 | qa_set = prediction_set[int(key)] 48 | question = qa_set['q'] 49 | answer = qa_set['a'] 50 | pred = qa_set['pred'] 51 | try: 52 | # Compute the contextual understanding score 53 | completion = openai.ChatCompletion.create( 54 | model="gpt-3.5-turbo-0125", 55 | temperature=0.0, 56 | messages=[ 57 | { 58 | "role": "system", 59 | "content": 60 | "You are an AI assistant tasked with evaluating the contextual understanding in results for video-based question-answer pairs. " 61 | "Your task is to compare the predicted answer with the correct answer and determine if the generated response aligns with the overall context of the video content." 62 | "------" 63 | "##INSTRUCTIONS: " 64 | "- Evaluate whether the predicted answer aligns with the overall context of the video content. It should not provide information that is out of context or misaligned.\n" 65 | "- The predicted answer must capture the main themes and sentiments of the video.\n" 66 | "- Consider synonyms or paraphrases as valid matches.\n" 67 | "- Provide a single evaluation score that reflects the level of contextual understanding of the prediction compared to the answer.\n" 68 | "- Assign a contextual understanding score between 0 and 5, where 5 indicates the highest level of contextual understanding.\n" 69 | "- Base your evaluation on the following scale:\n" 70 | " 5: PERFECT match in terms of context, themes, and sentiments.\n" 71 | " 4: Very little misalignments in context or themes, but mostly correct.\n" 72 | " 3: Mostly correct themes or sentiments, but minor misalignments.\n" 73 | " 2: Very little correct elements, though parts are relevant.\n" 74 | " 1: Mostly incorrect context or themes, though some correct elements.\n" 75 | " 0: COMPLETELY incorrect context or themes with no correct elements." 76 | }, 77 | { 78 | "role": "user", 79 | "content": 80 | "Please evaluate the following video-based question-answer pair:\n\n" 81 | f"Question: {question}\n" 82 | f"Correct Answer: {answer}\n" 83 | f"Predicted Answer: {pred}\n\n" 84 | "Provide your evaluation only as a contextual understanding score where the contextual understanding score is an integer value between 0 and 5, with 5 indicating the highest level of contextual understanding. " 85 | "Please generate the response in the form of a Python dictionary string with keys 'score', where its value is the contextual understanding score in INTEGER, not STRING." 86 | "DO NOT PROVIDE ANY OTHER OUTPUT TEXT OR EXPLANATION. Only provide the Python dictionary string. " 87 | "For example, your response should look like this: {'score': 2}." 88 | } 89 | ] 90 | ) 91 | # Convert response to a Python dictionary. 92 | response_message = completion["choices"][0]["message"]["content"] 93 | response_dict = ast.literal_eval(response_message) 94 | result_qa_pair = [response_dict, qa_set] 95 | 96 | # Save the question-answer pairs to a json file. 97 | with open(f"{output_dir}/{key}.json", "w") as f: 98 | json.dump(result_qa_pair, f) 99 | 100 | except Exception as e: 101 | print(f"Error processing file '{key}': {e}") 102 | 103 | 104 | def main(): 105 | """ 106 | Main function to control the flow of the program. 107 | """ 108 | # Parse arguments. 
109 | args = parse_args() 110 | 111 | file = args.pred_path 112 | pred_contents = json.load(open(file, 'r')) 113 | 114 | # Read GT file 115 | gt_contents = json.load(open(args.gt_json_path, 'r')) 116 | types = ['summary', 'spatial', 'reasoning'] 117 | generic_ids = [x['id'] for x in gt_contents if x['type'] in types] 118 | # Generate the list of ids and corresponding result files 119 | id_list = [x['ann_id'] for x in pred_contents if x['ann_id'] in generic_ids] 120 | caption_files = [f"{id}.json" for id in id_list] 121 | 122 | output_dir = args.output_dir 123 | # Create the output directory if it does not exist. 124 | if not os.path.exists(output_dir): 125 | os.makedirs(output_dir) 126 | 127 | # Preparing dictionary of question-answer sets 128 | prediction_set = {} 129 | for sample in pred_contents: 130 | id = sample['ann_id'] 131 | if id in id_list: 132 | question = sample['prompt'] 133 | answer = sample['answer'] 134 | pred = sample['text'] 135 | qa_set = {"ann_id": id, "q": question, "a": answer, "pred": pred} 136 | prediction_set[id] = qa_set 137 | 138 | # Set the OpenAI API key. 139 | openai.api_key = args.api_key 140 | num_tasks = args.num_tasks 141 | 142 | # While loop to ensure that all captions are processed. 143 | while True: 144 | try: 145 | # Files that have already been processed. 146 | completed_files = os.listdir(output_dir) 147 | print(f"completed_files: {len(completed_files)}") 148 | 149 | # Files that have not been processed yet. 150 | incomplete_files = [f for f in caption_files if f not in completed_files] 151 | print(f"incomplete_files: {len(incomplete_files)}") 152 | 153 | # Break the loop when there are no incomplete files 154 | if len(incomplete_files) == 0: 155 | break 156 | if len(incomplete_files) <= num_tasks: 157 | num_tasks = 1 158 | 159 | # Split tasks into parts. 160 | part_len = len(incomplete_files) // num_tasks 161 | all_parts = [incomplete_files[i:i + part_len] for i in range(0, len(incomplete_files), part_len)] 162 | task_args = [(prediction_set, part, args.output_dir) for part in all_parts] 163 | 164 | # Use a pool of workers to process the files in parallel. 
165 | with Pool() as pool: 166 | pool.starmap(annotate, task_args) 167 | 168 | except Exception as e: 169 | print(f"Error: {e}") 170 | 171 | # Combine all the processed files into one 172 | combined_contents = {} 173 | json_path = args.output_json 174 | 175 | # Iterate through json files 176 | for file_name in os.listdir(output_dir): 177 | if file_name.endswith(".json"): 178 | file_path = os.path.join(output_dir, file_name) 179 | with open(file_path, "r") as json_file: 180 | content = json.load(json_file) 181 | key = file_name.split(".")[0] 182 | combined_contents[key] = content 183 | 184 | # Write combined content to a json file 185 | with open(json_path, "w") as json_file: 186 | json.dump(combined_contents, json_file) 187 | print("All evaluation completed!") 188 | 189 | # Calculate average score 190 | score_sum = 0 191 | count = 0 192 | for key, result in combined_contents.items(): 193 | count += 1 194 | score_match = result[0]['score'] 195 | score = int(score_match) 196 | score_sum += score 197 | average_score = score_sum / count 198 | 199 | print("Average score for contextual understanding:", average_score) 200 | 201 | 202 | if __name__ == "__main__": 203 | main() 204 | -------------------------------------------------------------------------------- /eval/vcgbench_diverse/gpt_evaluation/dense_captioning_spatial_and_reasoning_scores.py: -------------------------------------------------------------------------------- 1 | """ 2 | VCGBench-Diverse - Evaluation Script for Dense Captioning, Spatial Understanding and Reasoning 3 | 4 | Copyright 2024 MBZUAI ORYX 5 | 6 | Licensed under the Apache License, Version 2.0 (the "License"); 7 | you may not use this file except in compliance with the License. 8 | You may obtain a copy of the License at 9 | 10 | http://www.apache.org/licenses/LICENSE-2.0 11 | 12 | Unless required by applicable law or agreed to in writing, software 13 | distributed under the License is distributed on an "AS IS" BASIS, 14 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 15 | See the License for the specific language governing permissions and 16 | limitations under the License. 17 | """ 18 | 19 | import argparse 20 | import json 21 | import os 22 | from tqdm import tqdm 23 | 24 | 25 | def parse_args(): 26 | """ 27 | Command-line argument parser. 28 | """ 29 | parser = argparse.ArgumentParser(description="") 30 | 31 | parser.add_argument('--gt_json_path', type=str, required=True, help="The path to file containing ground_truths.") 32 | parser.add_argument('--results_dir_path', type=str, required=True, 33 | help="The path containing correctness and detail evaluation results (i.e. 
correctness.json and detail.json files).") 34 | 35 | return parser.parse_args() 36 | 37 | 38 | def read_jsonl(file_path): 39 | all_data = [] 40 | with open(file_path, 'r', encoding='utf-8') as file: 41 | for line in file: 42 | data = json.loads(line) 43 | all_data.append(data) 44 | return all_data 45 | 46 | 47 | def main(): 48 | args = parse_args() 49 | 50 | gt_json_contents = json.load(open(args.gt_json_path)) 51 | id_to_type_dict = {} 52 | for content in gt_json_contents: 53 | id_to_type_dict[content['id']] = content['type'] 54 | 55 | type_to_score_dict = {"summary": [], "spatial": [], "reasoning": []} 56 | target_jsonl_names = ["correctness.json", "detail.json"] 57 | for target_jsonl_name in target_jsonl_names: 58 | target_json_path = os.path.join(args.results_dir_path, target_jsonl_name) 59 | target_json_data = json.load(open(target_json_path)) 60 | for id_key in tqdm(target_json_data.keys()): 61 | ann_type = id_to_type_dict[int(id_key)] 62 | if ann_type in type_to_score_dict.keys(): 63 | type_to_score_dict[ann_type].append(target_json_data[id_key][0]['score']) 64 | 65 | for key in type_to_score_dict.keys(): 66 | type_to_score_dict[key] = sum(type_to_score_dict[key]) / len(type_to_score_dict[key]) 67 | 68 | print(f"Dense Caption: {type_to_score_dict['summary']}\n" 69 | f"Spatial: {type_to_score_dict['spatial']}\n" 70 | f"Reasoning: {type_to_score_dict['reasoning']}") 71 | 72 | 73 | if __name__ == '__main__': 74 | main() 75 | -------------------------------------------------------------------------------- /eval/vcgbench_diverse/gpt_evaluation/vcgbench_diverse_evaluate.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | 3 | ## Path to the VCGBench ground truth path (vcgbench_diverse_qa.json) 4 | GT_PATH=$1 5 | ## Path to directory containing predictions (answer-vcgbench-diverse.json) 6 | PRED_PATH=$2 7 | ## Path to save the results 8 | OUTPUT_DIR_PATH=$3 9 | ## OpenAI API Key 10 | OPENAI_API_KEY=$4 11 | 12 | 13 | ## FORMAT of the PREDICTION FILE (answer-vcgbench-diverse.json) should be as follows. 14 | 15 | ## List of dictionaries where each dictionary represents one sample. 16 | ## For consistency questions, the dictionary must have the keys ann_id, video_name, prompt_1, text_1, prompt_2, text_2, and answer. 17 | ## Here ann_id represents the unique annotation id from the ground truth (vcgbench_diverse_qa.json). 
18 | 19 | ## An example of the consistency prediction is, 20 | ## {"ann_id": 1715, "video_name": "Mwn9ir0CkF4.mp4", "prompt_1": question_1, "text_1": answer_1, "prompt_2": question_2, "text_2": answer_2, "answer": gt_answer} 21 | 22 | ## For all other types of question, the prediction will have only one question and answer as follows, 23 | ## {"ann_id": 1071, "video_name": "7A3n_hJJjgg.mp4", "prompt": question, "text": answer, "answer": gt_answer} 24 | 25 | 26 | python 1_correctness_of_information.py --pred_path "$PRED_PATH/answer-vcgbench-diverse.json" --output_dir "$OUTPUT_DIR_PATH/correctness" --output_json "$OUTPUT_DIR_PATH/correctness.json" --gt_json_path "$GT_PATH" --api_key "$OPENAI_API_KEY" --num_tasks 16 27 | 28 | 29 | python 2_detailed_orientation.py --pred_path "$PRED_PATH/answer-vcgbench-diverse.json" --output_dir "$OUTPUT_DIR_PATH/detail" --output_json "$OUTPUT_DIR_PATH/detail.json" --gt_json_path "$GT_PATH" --api_key "$OPENAI_API_KEY" --num_tasks 16 30 | 31 | 32 | python 3_contextual_information.py --pred_path "$PRED_PATH/answer-vcgbench-diverse.json" --output_dir "$OUTPUT_DIR_PATH/context" --output_json "$OUTPUT_DIR_PATH/context.json" --gt_json_path "$GT_PATH" --api_key "$OPENAI_API_KEY" --num_tasks 16 33 | 34 | 35 | python 4_temporal_information.py --pred_path "$PRED_PATH/answer-vcgbench-diverse.json" --output_dir "$OUTPUT_DIR_PATH/temporal" --output_json "$OUTPUT_DIR_PATH/temporal.json" --gt_json_path "$GT_PATH" --api_key "$OPENAI_API_KEY" --num_tasks 16 36 | 37 | 38 | python 5_consistency.py --pred_path "$PRED_PATH/answer-vcgbench-diverse.json" --output_dir "$OUTPUT_DIR_PATH/consistency" --output_json "$OUTPUT_DIR_PATH/consistency.json" --gt_json_path "$GT_PATH" --api_key "$OPENAI_API_KEY" --num_tasks 16 39 | 40 | 41 | python dense_captioning_spatial_and_reasoning_scores.py --gt_json_path "$GT_PATH" --results_dir_path "$OUTPUT_DIR_PATH" 42 | -------------------------------------------------------------------------------- /eval/vcgbench_diverse/inference/ddp.py: -------------------------------------------------------------------------------- 1 | import json 2 | from torch.utils.data import Dataset 3 | import torch 4 | import subprocess 5 | from videogpt_plus.constants import * 6 | from eval.video_encoding import _get_rawvideo_dec 7 | 8 | 9 | class EvalDatasetGeneric(Dataset): 10 | def __init__(self, qa_path, video_dir, image_processor, video_processor): 11 | with open(qa_path) as file: 12 | self.gt_contents = json.load(file) 13 | self.video_dir = video_dir 14 | self.image_processor = image_processor 15 | self.video_processor = video_processor 16 | 17 | self.video_formats = ['.mp4', '.avi', '.mov', '.mkv'] 18 | 19 | def __len__(self): 20 | return len(self.gt_contents) 21 | 22 | def __getitem__(self, idx): 23 | sample = self.gt_contents[idx] 24 | video_name = sample['video_name'] 25 | sample_set = sample 26 | 27 | # Load the video file 28 | video_path = os.path.join(self.video_dir, video_name) 29 | 30 | # Check if the video exists 31 | if os.path.exists(video_path): # Modified this line 32 | video_frames, context_frames, slice_len = _get_rawvideo_dec(video_path, self.image_processor, 33 | self.video_processor, 34 | max_frames=NUM_FRAMES, 35 | image_resolution=224, 36 | num_video_frames=NUM_FRAMES, 37 | num_context_images=NUM_CONTEXT_IMAGES) 38 | else: 39 | print(f'Video {video_path} not found') 40 | video_frames, context_frames, slice_len = "None", "None", 0 41 | 42 | return idx, [sample_set], video_frames, context_frames, slice_len 43 | 44 | 45 | def 
setup_for_distributed(is_master): 46 | """ 47 | This function disables printing when not in master process 48 | """ 49 | import builtins as __builtin__ 50 | builtin_print = __builtin__.print 51 | 52 | def print(*args, **kwargs): 53 | force = kwargs.pop('force', False) 54 | if is_master or force: 55 | builtin_print(*args, **kwargs) 56 | 57 | __builtin__.print = print 58 | 59 | 60 | def init_distributed_mode(args): 61 | if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ: 62 | args.rank = int(os.environ["RANK"]) 63 | args.world_size = int(os.environ['WORLD_SIZE']) 64 | args.gpu = int(os.environ['LOCAL_RANK']) 65 | args.dist_url = 'env://' 66 | os.environ['LOCAL_SIZE'] = str(torch.cuda.device_count()) 67 | print('Using distributed mode: 1') 68 | elif 'SLURM_PROCID' in os.environ: 69 | proc_id = int(os.environ['SLURM_PROCID']) 70 | ntasks = int(os.environ['SLURM_NTASKS']) 71 | node_list = os.environ['SLURM_NODELIST'] 72 | num_gpus = torch.cuda.device_count() 73 | addr = subprocess.getoutput( 74 | 'scontrol show hostname {} | head -n1'.format(node_list)) 75 | os.environ['MASTER_PORT'] = os.environ.get('MASTER_PORT', '3460') 76 | os.environ['MASTER_ADDR'] = addr 77 | os.environ['WORLD_SIZE'] = str(ntasks) 78 | os.environ['RANK'] = str(proc_id) 79 | os.environ['LOCAL_RANK'] = str(proc_id % num_gpus) 80 | os.environ['LOCAL_SIZE'] = str(num_gpus) 81 | args.dist_url = 'env://' 82 | args.world_size = ntasks 83 | args.rank = proc_id 84 | args.gpu = proc_id % num_gpus 85 | print('Using distributed mode: slurm') 86 | print(f"world: {os.environ['WORLD_SIZE']}, rank:{os.environ['RANK']}," 87 | f" local_rank{os.environ['LOCAL_RANK']}, local_size{os.environ['LOCAL_SIZE']}") 88 | else: 89 | print('Not using distributed mode') 90 | args.distributed = False 91 | return 92 | 93 | args.distributed = True 94 | 95 | torch.cuda.set_device(args.gpu) 96 | args.dist_backend = 'nccl' 97 | print('| distributed init (rank {}): {}'.format( 98 | args.rank, args.dist_url), flush=True) 99 | torch.distributed.init_process_group(backend=args.dist_backend, init_method=args.dist_url, 100 | world_size=args.world_size, rank=args.rank) 101 | torch.distributed.barrier() 102 | setup_for_distributed(args.rank == 0) 103 | -------------------------------------------------------------------------------- /eval/vcgbench_diverse/inference/run_ddp_inference.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | 3 | # Update the number of gpus as per your configuration 4 | NUM_GPUS=8 5 | MODEL_PATH=MBZUAI/VideoGPT-plus_Phi3-mini-4k/vcgbench 6 | MODEL_BASE=microsoft/Phi-3-mini-4k-instruct 7 | VCGBench_Diverse_PATH=MBZUAI/VCGBench-Diverse 8 | 9 | export PYTHONPATH="./:$PYTHONPATH" 10 | 11 | torchrun --nproc_per_node="$NUM_GPUS" eval/vcgbench_diverse/inference/infer.py --model-path "$MODEL_PATH" --model-base "$MODEL_BASE" --video-folder "$VCGBench_Diverse_PATH/videos" --question-file "$VCGBench_Diverse_PATH/vcgbench_diverse_qa.json" --output-dir "$MODEL_PATH/vcgbench_diverse_eval/answer-vcgbench-diverse" --conv-mode "phi3_instruct" 12 | 13 | python eval/merge.py --input_dir "$MODEL_PATH/vcgbench_diverse_eval/answer-vcgbench-diverse" 14 | -------------------------------------------------------------------------------- /eval/video_encoding.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import io 3 | import imageio 4 | import torch 5 | import numpy as np 6 | from decord import VideoReader 7 | from PIL import Image 8 | from 
videogpt_plus.constants import * 9 | from mmengine import fileio 10 | from mmengine.fileio import FileClient 11 | 12 | client = FileClient('disk') 13 | 14 | 15 | def uniform_sample(lst, n): 16 | assert n <= len(lst) 17 | m = len(lst) 18 | step = m // n # Calculate the step size 19 | return [lst[i * step] for i in range(n)] 20 | 21 | 22 | def _get_rawvideo_dec(video_path, image_processor, video_processor, max_frames=16, min_frames=16, image_resolution=224, 23 | video_framerate=1, s=None, e=None, num_video_frames=NUM_FRAMES, num_context_images=16): 24 | # Speed up video decode via decord. 25 | video_mask = np.zeros(max_frames, dtype=np.int64) 26 | max_video_length = 0 27 | 28 | if s is None: 29 | start_time, end_time = None, None 30 | else: 31 | start_time = int(s) 32 | end_time = int(e) 33 | start_time = start_time if start_time >= 0. else 0. 34 | end_time = end_time if end_time >= 0. else 0. 35 | if start_time > end_time: 36 | start_time, end_time = end_time, start_time 37 | elif start_time == end_time: 38 | end_time = start_time + 1 39 | 40 | if os.path.exists(video_path): 41 | try: 42 | vreader = VideoReader(video_path, num_threads=1) 43 | except Exception as e: 44 | try: 45 | video_bytes = fileio.get(video_path) 46 | vreader = VideoReader(io.BytesIO(video_bytes), num_threads=1) 47 | except Exception as e: 48 | print("Both options failed for video path:", video_path) 49 | else: 50 | raise FileNotFoundError(video_path) 51 | 52 | fps = vreader.get_avg_fps() 53 | f_start = 0 if start_time is None else int(start_time * fps) 54 | f_end = int(min(1000000000 if end_time is None else end_time * fps, len(vreader) - 1)) 55 | num_frames = f_end - f_start + 1 56 | 57 | # T x 3 x H x W 58 | sample_fps = int(video_framerate) 59 | t_stride = int(round(float(fps) / sample_fps)) 60 | 61 | all_pos = list(range(f_start, f_end + 1, t_stride)) 62 | if len(all_pos) > max_frames: 63 | sample_pos = [all_pos[_] for _ in np.linspace(0, len(all_pos) - 1, num=max_frames, dtype=int)] 64 | elif len(all_pos) < min_frames: 65 | if num_frames < min_frames: 66 | min_frames = num_frames 67 | t_stride = max(1, (f_end - f_start) // (min_frames - 1)) 68 | adjusted_f_end = f_start + t_stride * (min_frames - 1) 69 | sample_pos = list(range(f_start, adjusted_f_end + 1, t_stride)) 70 | else: 71 | sample_pos = all_pos 72 | 73 | all_images = [f for f in vreader.get_batch(sample_pos).asnumpy()] 74 | # In case if we can't sample MAX_IMAGE_LENGTH frames 75 | num_video_frames_sampled = min(num_video_frames, len(all_images)) 76 | num_context_images_sampled = min(num_context_images, len(all_images)) 77 | 78 | patch_images = uniform_sample(all_images, num_video_frames_sampled) 79 | context_images = uniform_sample(all_images, num_context_images_sampled) 80 | 81 | patch_images = video_processor.preprocess(patch_images)['pixel_values'] 82 | context_images = [image_processor.preprocess(i, return_tensors='pt')['pixel_values'][0] for i in context_images] 83 | 84 | if len(context_images) < num_context_images: # Pad 85 | while len(context_images) < num_context_images: 86 | context_images.append( 87 | torch.zeros((3, image_processor.crop_size['height'], image_processor.crop_size['width']))) 88 | 89 | slice_len = len(patch_images) 90 | if slice_len < 1: 91 | pass 92 | else: 93 | while len(patch_images) < num_video_frames: 94 | patch_images.append(torch.zeros((3, image_resolution, image_resolution))) 95 | 96 | return patch_images, context_images, slice_len 97 | 98 | 99 | def read_gif_mod(video_path, image_processor, max_frames=16, 
image_resolution=224, video_framerate=25, 100 | s=None, e=None, sample_fps=1): 101 | # Initialize data structures 102 | video = np.zeros((max_frames, 3, image_resolution, image_resolution), dtype=np.float64) 103 | 104 | # Load GIF file 105 | video_bytes = client.get(video_path) 106 | gif_reader = imageio.get_reader(io.BytesIO(video_bytes)) 107 | num_frames = len(gif_reader) 108 | 109 | # Calculate frame indices 110 | fps = video_framerate 111 | f_start = 0 if s is None else max(int(s * fps), 0) 112 | f_end = min(num_frames - 1, int(e * fps)) if e is not None else num_frames - 1 113 | 114 | t_stride = max(int(round(float(fps) / sample_fps)), 1) 115 | frame_indices = range(f_start, f_end + 1, t_stride) 116 | 117 | # Process frames 118 | processed_frames = [] 119 | for i, frame in enumerate(gif_reader): 120 | if i in frame_indices: 121 | img = cv2.cvtColor(frame, cv2.COLOR_RGBA2RGB) 122 | img_pil = Image.fromarray(img).resize((image_resolution, image_resolution)) 123 | processed_frames.append(img_pil) 124 | 125 | if len(processed_frames) >= max_frames: 126 | break 127 | # Transform images 128 | patch_images = processed_frames 129 | patch_images = image_processor.preprocess(patch_images)['pixel_values'] 130 | 131 | slice_len = patch_images.shape[0] 132 | 133 | # Store video data 134 | video[:slice_len, ...] = patch_images 135 | 136 | return video, slice_len 137 | 138 | 139 | def read_frame_mod(video_path, image_processor, video_processor, max_frames=16, image_resolution=224, video_framerate=3, 140 | s=None, e=None, sample_fps=1, num_video_frames=16, num_context_images=16): 141 | # Initialize data structures 142 | video = np.zeros((max_frames, 3, image_resolution, image_resolution), dtype=np.float64) 143 | max_video_length = 0 144 | 145 | # Check if video path exists 146 | if not os.path.exists(video_path): 147 | raise FileNotFoundError(f"Video path {video_path} not found.") 148 | 149 | # Determine frame range 150 | frame_files = sorted(os.listdir(video_path)) 151 | num_frames = len(frame_files) 152 | 153 | # Calculate frame indices 154 | fps = video_framerate 155 | f_start = 0 if s is None else max(int(s * fps), 0) 156 | f_end = min(num_frames - 1, int(e * fps)) if e is not None else num_frames - 1 157 | 158 | t_stride = max(int(round(float(fps) / sample_fps)), 1) 159 | frame_indices = range(f_start, f_end + 1, t_stride) 160 | 161 | # Process frames 162 | all_frames = [] 163 | for idx in frame_indices: 164 | img_path = os.path.join(video_path, frame_files[idx]) 165 | img = np.array(Image.open(img_path)) 166 | all_frames.append(img) 167 | 168 | if len(all_frames) >= max_frames: 169 | break 170 | 171 | num_video_frames_sampled = min(num_video_frames, len(all_frames)) 172 | num_context_images_sampled = min(num_context_images, len(all_frames)) 173 | 174 | patch_images = uniform_sample(all_frames, num_video_frames_sampled) 175 | context_images = uniform_sample(all_frames, num_context_images_sampled) 176 | 177 | patch_images = video_processor.preprocess(patch_images)['pixel_values'] 178 | context_images = [image_processor.preprocess(i, return_tensors='pt')['pixel_values'][0] for i in context_images] 179 | 180 | if len(context_images) < num_context_images: # Pad 181 | while len(context_images) < num_context_images: 182 | context_images.append( 183 | torch.zeros((3, image_processor.crop_size['height'], image_processor.crop_size['width']))) 184 | 185 | slice_len = len(patch_images) 186 | 187 | if slice_len < 1: 188 | pass 189 | else: 190 | while len(patch_images) < num_video_frames: 191 | 
patch_images.append(torch.zeros((3, image_resolution, image_resolution))) 192 | 193 | return patch_images, context_images, slice_len 194 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | tokenizers 2 | sentencepiece 3 | shortuuid 4 | accelerate 5 | peft 6 | bitsandbytes 7 | pydantic 8 | markdown2[all] 9 | numpy 10 | scikit-learn 11 | gradio 12 | gradio_client 13 | requests 14 | httpx 15 | uvicorn 16 | fastapi 17 | einops 18 | einops-exts 19 | timm -------------------------------------------------------------------------------- /scripts/README.md: -------------------------------------------------------------------------------- 1 | # Training VideoGPT+ :train: 2 | We provide scripts for projector pretraining and video fine-tuning of VideoGPT+. Please follow the instructions below. 3 | 4 | ## Download Training Dataset 5 | You can download all the pretraining and fine-tuning datasets from HuggingFace by following the instructions below: 6 | 7 | ```bash 8 | mkdir playground 9 | mkdir playground/data 10 | cd playground/data 11 | git lfs install 12 | git clone https://huggingface.co/datasets/MBZUAI/VideoGPT-plus_Training_Dataset 13 | ``` 14 | 15 | ## Projector pretraining with CLIP Image Encoder 16 | Use the script [scripts/pretrain_projector_image_encoder.sh](scripts/pretrain_projector_image_encoder.sh) to run MLP projector pretraining with the CLIP image encoder. 17 | 18 | ## Projector pretraining with InternVideo2 Video Encoder 19 | Please use the script [scripts/pretrain_projector_video_encoder.sh](scripts/pretrain_projector_video_encoder.sh) to run MLP projector pretraining with the InternVideo2 video encoder. 20 | 21 | Alternatively, you can download the pretrained projector weights we provide from HuggingFace: 22 | 23 | ```bash 24 | git lfs install 25 | git clone https://huggingface.co/MBZUAI/VideoGPT-plus_Phi3-mini-4k_Pretrain 26 | ``` 27 | 28 | ## Video Instruction Fine-tuning 29 | Please use the script [scripts/finetune_dual_encoder.sh](finetune_dual_encoder.sh) for video instruction fine-tuning; a minimal launch sketch is shown below. 
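A minimal launch sketch (illustrative, assuming you run from the repository root and the training data sits under `playground/data`, the default `DATASET_DIR` used inside the scripts):

```bash
# Stage 1: pretrain the image and video MLP projectors
# (or reuse the downloaded VideoGPT-plus_Phi3-mini-4k_Pretrain weights instead).
bash scripts/pretrain_projector_image_encoder.sh
bash scripts/pretrain_projector_video_encoder.sh

# Stage 2: dual-encoder video instruction fine-tuning with LoRA.
bash scripts/finetune_dual_encoder.sh
```

Adjust the variables at the top of each script (e.g. `OUTPUT_DIR_PATH`, `PRETRAIN_VIDEO_MLP_PATH`, `PRETRAIN_IMAGE_MLP_PATH`) if your data or checkpoints live elsewhere.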
30 | -------------------------------------------------------------------------------- /scripts/finetune_dual_encoder.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | 3 | 4 | export DATASET_DIR=playground/data 5 | 6 | BASE_LLM_PATH=microsoft/Phi-3-mini-4k-instruct 7 | VISION_TOWER=OpenGVLab/InternVideo2-Stage2_1B-224p-f4 8 | IMAGE_VISION_TOWER=openai/clip-vit-large-patch14-336 9 | PROJECTOR_TYPE=mlp2x_gelu 10 | PRETRAIN_VIDEO_MLP_PATH=MBZUAI/VideoGPT-plus_Phi3-mini-4k_Pretrain/mlp2x_gelu_internvideo2/mm_projector.bin 11 | PRETRAIN_IMAGE_MLP_PATH=MBZUAI/VideoGPT-plus_Phi3-mini-4k_Pretrain/mlp2x_gelu_clip_l14_336px/mm_projector.bin 12 | OUTPUT_DIR_PATH=results/videogpt_plus_finetune 13 | 14 | deepspeed videogpt_plus/train/train.py \ 15 | --lora_enable True --lora_r 128 --lora_alpha 256 --mm_projector_lr 2e-5 \ 16 | --deepspeed scripts/zero3.json \ 17 | --model_name_or_path "$BASE_LLM_PATH" \ 18 | --version phi3_instruct \ 19 | --dataset_use FINETUNING \ 20 | --vision_tower "$VISION_TOWER" \ 21 | --image_vision_tower "$IMAGE_VISION_TOWER" \ 22 | --mm_projector_type "$PROJECTOR_TYPE" \ 23 | --image_mm_projector_type "$PROJECTOR_TYPE" \ 24 | --pretrain_mm_mlp_adapter "$PRETRAIN_VIDEO_MLP_PATH" \ 25 | --pretrain_image_mm_mlp_adapter "$PRETRAIN_IMAGE_MLP_PATH" \ 26 | --mm_vision_select_layer -2 \ 27 | --mm_use_im_start_end False \ 28 | --mm_use_im_patch_token False \ 29 | --image_aspect_ratio pad \ 30 | --group_by_modality_length True \ 31 | --bf16 True \ 32 | --output_dir $OUTPUT_DIR_PATH \ 33 | --num_train_epochs 1 \ 34 | --per_device_train_batch_size 8 \ 35 | --per_device_eval_batch_size 4 \ 36 | --gradient_accumulation_steps 2 \ 37 | --evaluation_strategy "no" \ 38 | --save_strategy "steps" \ 39 | --save_steps 50000 \ 40 | --save_total_limit 1 \ 41 | --learning_rate 2e-4 \ 42 | --weight_decay 0. \ 43 | --warmup_ratio 0.03 \ 44 | --lr_scheduler_type "cosine" \ 45 | --logging_steps 1 \ 46 | --tf32 True \ 47 | --model_max_length 4096 \ 48 | --gradient_checkpointing True \ 49 | --dataloader_num_workers 4 \ 50 | --lazy_preprocess True \ 51 | --report_to none 52 | -------------------------------------------------------------------------------- /scripts/pretrain_projector_image_encoder.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | 3 | 4 | export DATASET_DIR=playground/data 5 | 6 | BASE_LLM_PATH=microsoft/Phi-3-mini-4k-instruct 7 | IMAGE_VISION_TOWER=openai/clip-vit-large-patch14-336 8 | PROJECTOR_TYPE=mlp2x_gelu 9 | OUTPUT_DIR_PATH=results/mlp2x_gelu_clip_l14_336px 10 | 11 | deepspeed videogpt_plus/train/pretrain.py \ 12 | --deepspeed scripts/zero2.json \ 13 | --tune_image_mm_mlp_adapter True \ 14 | --model_name_or_path "$BASE_LLM_PATH" \ 15 | --version plain \ 16 | --dataset_use PRETRAINING \ 17 | --image_vision_tower "$IMAGE_VISION_TOWER" \ 18 | --image_mm_projector_type "$PROJECTOR_TYPE" \ 19 | --mm_vision_select_layer -2 \ 20 | --mm_use_im_start_end False \ 21 | --mm_use_im_patch_token False \ 22 | --bf16 True \ 23 | --output_dir $OUTPUT_DIR_PATH \ 24 | --num_train_epochs 1 \ 25 | --per_device_train_batch_size 16 \ 26 | --per_device_eval_batch_size 4 \ 27 | --gradient_accumulation_steps 2 \ 28 | --evaluation_strategy "no" \ 29 | --save_strategy "steps" \ 30 | --save_steps 50000 \ 31 | --save_total_limit 1 \ 32 | --learning_rate 1e-3 \ 33 | --weight_decay 0. 
\ 34 | --warmup_ratio 0.03 \ 35 | --lr_scheduler_type "cosine" \ 36 | --logging_steps 1 \ 37 | --tf32 True \ 38 | --model_max_length 4096 \ 39 | --gradient_checkpointing True \ 40 | --dataloader_num_workers 4 \ 41 | --lazy_preprocess True \ 42 | --report_to none 43 | -------------------------------------------------------------------------------- /scripts/pretrain_projector_video_encoder.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | 3 | 4 | export DATASET_DIR=playground/data 5 | 6 | BASE_LLM_PATH=microsoft/Phi-3-mini-4k-instruct 7 | VISION_TOWER=OpenGVLab/InternVideo2-Stage2_1B-224p-f4 8 | PROJECTOR_TYPE=mlp2x_gelu 9 | OUTPUT_DIR_PATH=results/mlp2x_gelu_internvideo2 10 | 11 | deepspeed videogpt_plus/train/pretrain.py \ 12 | --deepspeed scripts/zero2.json \ 13 | --tune_mm_mlp_adapter True \ 14 | --model_name_or_path "$BASE_LLM_PATH" \ 15 | --version plain \ 16 | --dataset_use PRETRAINING \ 17 | --vision_tower "$VISION_TOWER" \ 18 | --mm_projector_type "$PROJECTOR_TYPE" \ 19 | --mm_vision_select_layer -2 \ 20 | --mm_use_im_start_end False \ 21 | --mm_use_im_patch_token False \ 22 | --bf16 True \ 23 | --output_dir $OUTPUT_DIR_PATH \ 24 | --num_train_epochs 1 \ 25 | --per_device_train_batch_size 16 \ 26 | --per_device_eval_batch_size 4 \ 27 | --gradient_accumulation_steps 2 \ 28 | --evaluation_strategy "no" \ 29 | --save_strategy "steps" \ 30 | --save_steps 50000 \ 31 | --save_total_limit 1 \ 32 | --learning_rate 1e-3 \ 33 | --weight_decay 0. \ 34 | --warmup_ratio 0.03 \ 35 | --lr_scheduler_type "cosine" \ 36 | --logging_steps 1 \ 37 | --tf32 True \ 38 | --model_max_length 4096 \ 39 | --gradient_checkpointing True \ 40 | --dataloader_num_workers 4 \ 41 | --lazy_preprocess True \ 42 | --report_to none 43 | -------------------------------------------------------------------------------- /scripts/zero.json: -------------------------------------------------------------------------------- 1 | { 2 | "fp16": { 3 | "enabled": "auto", 4 | "loss_scale": 0, 5 | "loss_scale_window": 1000, 6 | "initial_scale_power": 16, 7 | "hysteresis": 2, 8 | "min_loss_scale": 1 9 | }, 10 | "bf16": { 11 | "enabled": "auto" 12 | }, 13 | "train_micro_batch_size_per_gpu": "auto", 14 | "train_batch_size": "auto", 15 | "gradient_accumulation_steps": "auto", 16 | "zero_optimization": { 17 | "stage": 1, 18 | "overlap_comm": true, 19 | "contiguous_gradients": true, 20 | "sub_group_size": 1e9, 21 | "reduce_bucket_size": "auto" 22 | } 23 | } -------------------------------------------------------------------------------- /scripts/zero2.json: -------------------------------------------------------------------------------- 1 | { 2 | "fp16": { 3 | "enabled": "auto", 4 | "loss_scale": 0, 5 | "loss_scale_window": 1000, 6 | "initial_scale_power": 16, 7 | "hysteresis": 2, 8 | "min_loss_scale": 1 9 | }, 10 | "bf16": { 11 | "enabled": "auto" 12 | }, 13 | "train_micro_batch_size_per_gpu": "auto", 14 | "train_batch_size": "auto", 15 | "gradient_accumulation_steps": "auto", 16 | "zero_optimization": { 17 | "stage": 2, 18 | "overlap_comm": true, 19 | "contiguous_gradients": true, 20 | "sub_group_size": 1e9, 21 | "reduce_bucket_size": "auto" 22 | } 23 | } -------------------------------------------------------------------------------- /scripts/zero3.json: -------------------------------------------------------------------------------- 1 | { 2 | "fp16": { 3 | "enabled": "auto", 4 | "loss_scale": 0, 5 | "loss_scale_window": 1000, 6 | "initial_scale_power": 16, 7 | "hysteresis": 2, 8 | 
"min_loss_scale": 1 9 | }, 10 | "bf16": { 11 | "enabled": "auto" 12 | }, 13 | "train_micro_batch_size_per_gpu": "auto", 14 | "train_batch_size": "auto", 15 | "gradient_accumulation_steps": "auto", 16 | "zero_optimization": { 17 | "stage": 3, 18 | "overlap_comm": true, 19 | "contiguous_gradients": true, 20 | "sub_group_size": 1e9, 21 | "reduce_bucket_size": "auto", 22 | "stage3_prefetch_bucket_size": "auto", 23 | "stage3_param_persistence_threshold": "auto", 24 | "stage3_max_live_parameters": 1e9, 25 | "stage3_max_reuse_distance": 1e9, 26 | "stage3_gather_16bit_weights_on_model_save": true 27 | } 28 | } -------------------------------------------------------------------------------- /scripts/zero3_offload.json: -------------------------------------------------------------------------------- 1 | { 2 | "fp16": { 3 | "enabled": "auto", 4 | "loss_scale": 0, 5 | "loss_scale_window": 1000, 6 | "initial_scale_power": 16, 7 | "hysteresis": 2, 8 | "min_loss_scale": 1 9 | }, 10 | "bf16": { 11 | "enabled": "auto" 12 | }, 13 | "optimizer": { 14 | "type": "AdamW", 15 | "params": { 16 | "lr": "auto", 17 | "betas": "auto", 18 | "eps": "auto", 19 | "weight_decay": "auto" 20 | } 21 | }, 22 | "scheduler": { 23 | "type": "WarmupLR", 24 | "params": { 25 | "warmup_min_lr": "auto", 26 | "warmup_max_lr": "auto", 27 | "warmup_num_steps": "auto" 28 | } 29 | }, 30 | "zero_optimization": { 31 | "stage": 3, 32 | "offload_optimizer": { 33 | "device": "cpu", 34 | "pin_memory": true 35 | }, 36 | "offload_param": { 37 | "device": "cpu", 38 | "pin_memory": true 39 | }, 40 | "overlap_comm": true, 41 | "contiguous_gradients": true, 42 | "sub_group_size": 1e9, 43 | "reduce_bucket_size": "auto", 44 | "stage3_prefetch_bucket_size": "auto", 45 | "stage3_param_persistence_threshold": "auto", 46 | "stage3_max_live_parameters": 1e9, 47 | "stage3_max_reuse_distance": 1e9, 48 | "gather_16bit_weights_on_model_save": true 49 | }, 50 | "gradient_accumulation_steps": "auto", 51 | "gradient_clipping": "auto", 52 | "train_batch_size": "auto", 53 | "train_micro_batch_size_per_gpu": "auto", 54 | "steps_per_print": 1e5, 55 | "wall_clock_breakdown": false 56 | } -------------------------------------------------------------------------------- /videogpt_plus/__init__.py: -------------------------------------------------------------------------------- 1 | from .model import VideoGPTPlusPhi3ForCausalLM 2 | -------------------------------------------------------------------------------- /videogpt_plus/config/__init__.py: -------------------------------------------------------------------------------- 1 | from .dataset_config import * 2 | 3 | DataConfig = { 4 | "PRETRAINING": [CC3M_595K, COCO_CAP, COCO_REG, COCO_REC], 5 | 6 | "FINETUNING": [CONV_VideoChatGPT, VCG_HUMAN, VCG_PLUS_112K, CAPTION_VIDEOCHAT, CLASSIFICATION_K710, CLASSIFICATION_SSV2, CONV_VideoChat1, REASONING_NExTQA, REASONING_CLEVRER_QA, REASONING_CLEVRER_MC, VQA_WEBVID_QA], 7 | 8 | "VCGBench_FINETUNING": [CONV_VideoChatGPT, VCG_HUMAN, VCG_PLUS_112K, CAPTION_VIDEOCHAT, CONV_VideoChat1, VQA_WEBVID_QA], 9 | "MVBench_FINETUNING": [CLASSIFICATION_K710, CLASSIFICATION_SSV2, CONV_VideoChatGPT, REASONING_NExTQA, REASONING_CLEVRER_QA, REASONING_CLEVRER_MC, VQA_WEBVID_QA], 10 | 11 | } 12 | -------------------------------------------------------------------------------- /videogpt_plus/config/dataset_config.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | DATASET_DIR = os.environ.get("DATASET_DIR", "playground/data") 4 | 5 | CC3M_595K = { 6 
| "annotation_path": f"{DATASET_DIR}/pretraining/CC3M-595K/chat.json", 7 | "data_path": f"{DATASET_DIR}/pretraining/CC3M-595K", 8 | } 9 | 10 | COCO_CAP = { 11 | "annotation_path": f"{DATASET_DIR}/pretraining/COCO/coco_cap_chat.json", 12 | "data_path": f"{DATASET_DIR}/pretraining/COCO/train2014", 13 | } 14 | 15 | COCO_REG = { 16 | "annotation_path": f"{DATASET_DIR}/pretraining/COCO/coco_reg_chat.json", 17 | "data_path": f"{DATASET_DIR}/pretraining/COCO/train2014", 18 | } 19 | 20 | COCO_REC = { 21 | "annotation_path": f"{DATASET_DIR}/pretraining/COCO/coco_rec_chat.json", 22 | "data_path": f"{DATASET_DIR}/pretraining/COCO/train2014", 23 | } 24 | 25 | CONV_VideoChatGPT = { 26 | "annotation_path": f"{DATASET_DIR}/annotations/conversation_videochatgpt.json", 27 | "data_path": f"{DATASET_DIR}/instruction_tuning/Activity_Videos", 28 | } 29 | 30 | VCG_HUMAN = { 31 | "annotation_path": f"{DATASET_DIR}/annotations/vcg_human_annotated.json", 32 | "data_path": f"{DATASET_DIR}/instruction_tuning/Activity_Videos", 33 | } 34 | 35 | VCG_PLUS_112K = { 36 | "annotation_path": f"{DATASET_DIR}/annotations/vcg-plus_112K.json", 37 | "data_path": f"{DATASET_DIR}/instruction_tuning/Activity_Videos", 38 | } 39 | 40 | CAPTION_VIDEOCHAT = { 41 | "annotation_path": f"{DATASET_DIR}/annotations/caption_videochat.json", 42 | "data_path": f"{DATASET_DIR}/instruction_tuning/webvid", 43 | } 44 | 45 | CLASSIFICATION_K710 = { 46 | "annotation_path": f"{DATASET_DIR}/annotations/classification_k710.json", 47 | "data_path": f"{DATASET_DIR}/instruction_tuning/k710", 48 | } 49 | 50 | CLASSIFICATION_SSV2 = { 51 | "annotation_path": f"{DATASET_DIR}/annotations/classification_ssv2.json", 52 | "data_path": f"{DATASET_DIR}/instruction_tuning/ssv2", 53 | } 54 | 55 | CONV_VideoChat1 = { 56 | "annotation_path": f"{DATASET_DIR}/annotations/conversation_videochat1.json", 57 | "data_path": f"{DATASET_DIR}/instruction_tuning/videochat_it", 58 | } 59 | 60 | REASONING_NExTQA = { 61 | "annotation_path": f"{DATASET_DIR}/annotations/reasoning_next_qa.json", 62 | "data_path": f"{DATASET_DIR}/instruction_tuning/NExTQA", 63 | } 64 | 65 | REASONING_CLEVRER_QA = { 66 | "annotation_path": f"{DATASET_DIR}/annotations/reasoning_clevrer_qa.json", 67 | "data_path": f"{DATASET_DIR}/instruction_tuning/clevrer", 68 | } 69 | 70 | REASONING_CLEVRER_MC = { 71 | "annotation_path": f"{DATASET_DIR}/annotations/reasoning_clevrer_mc.json", 72 | "data_path": f"{DATASET_DIR}/instruction_tuning/clevrer", 73 | } 74 | 75 | VQA_WEBVID_QA = { 76 | "annotation_path": f"{DATASET_DIR}/annotations/vqa_webvid_qa.json", 77 | "data_path": f"{DATASET_DIR}/instruction_tuning/webvid", 78 | } 79 | -------------------------------------------------------------------------------- /videogpt_plus/constants.py: -------------------------------------------------------------------------------- 1 | import os 2 | from distutils.util import strtobool 3 | 4 | # Configuration Constants 5 | # TODO: Change the chunk size if you use any other video encoder accordingly 6 | CHUNK_SIZE = 4 # Video chunk size for InternVideo2-Stage2_1B-224p-f4 which is trained using 4 frames per video 7 | NUM_FRAMES = int(os.environ.get("NUM_FRAMES", 16)) # Number of video frames (if using video) 8 | NUM_CONTEXT_IMAGES = int(os.environ.get("NUM_CONTEXT_IMAGES", 16)) # Number of context images for video 9 | 10 | # Model Constants 11 | IGNORE_INDEX = -100 12 | IMAGE_TOKEN_INDEX = -200 13 | DEFAULT_IMAGE_TOKEN = "" 14 | DEFAULT_VIDEO_TOKEN = "