├── LICENSE ├── README.md ├── cfgs ├── free_video_llm_34b.yaml └── free_video_llm_7b.yaml ├── dataset.py ├── eval ├── eval_prompt.py └── eval_video_qa.py ├── free_video_llm └── llava │ ├── __init__.py │ ├── constants.py │ ├── conversation.py │ ├── mm_utils.py │ ├── model │ ├── __init__.py │ ├── apply_delta.py │ ├── builder.py │ ├── consolidate.py │ ├── language_model │ │ └── llava_llama.py │ ├── llava_arch.py │ ├── make_delta.py │ ├── multimodal_encoder │ │ ├── builder.py │ │ └── clip_encoder.py │ ├── multimodal_projector │ │ └── builder.py │ └── utils.py │ └── utils.py ├── prompt.py ├── pyproject.toml ├── run_demo.py ├── run_inference.py ├── run_inference_video_qa.py ├── scripts ├── data │ ├── prepare_activitynet_qa_file.py │ ├── prepare_msrvtt_qa_file.py │ ├── prepare_msvd_qa_file.py │ └── prepare_tgif_qa_file.py ├── run_eval_activitynet.sh ├── run_eval_msrvtt.sh ├── run_eval_msvd.sh ├── run_eval_tgif.sh └── run_eval_videoqabench.sh ├── setup_env.sh └── utils.py /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 The Free Video-LLM Project 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # FreeVideoLLM 2 | 3 | Free Video-LLM: Prompt-guided Visual Perception for Efficient Training-free Video LLM [![arXiv](https://img.shields.io/badge/Arxiv-2410.10441-b31b1b.svg?logo=arXiv)](https://arxiv.org/abs/2410.10441) [![github](https://img.shields.io/badge/-Github-black?logo=github)](https://github.com/contrastive/FreeVideoLLM) 4 | 5 | by [Kai Han](https://iamhankai.github.io/), [Jianyuan Guo](https://ggjy.github.io/), [Yehui Tang](https://scholar.google.com/citations?user=TkSZQ6gAAAAJ&hl=zh-CN), [Wei He](https://github.com/contrastive/FreeVideoLLM), [Enhua Wu](https://www.fst.um.edu.mo/people/ehwu/), [Yunhe Wang](https://www.wangyunhe.site/) 6 | 7 | ## Getting Started 8 | 9 | ### Installation 10 | 11 | The code is developed with CUDA 11.7, Python >= 3.10.12, PyTorch >= 2.1.0 12 | 13 | 1. Install the requirements. 14 | ``` 15 | bash setup_env.sh 16 | ``` 17 | 18 | 2. Add OpenAI key and organization to the system environment to use GPT-3.5-turbo for model evaluation. 19 | ``` 20 | export OPENAI_API_KEY=$YOUR_OPENAI_API_KEY 21 | export OPENAI_ORG=$YOUR_OPENAI_ORG # optional 22 | ``` 23 | 24 | 3. 
Download pre-trained LLaVA-v1.6 weights from [`HuggingFace`](https://huggingface.co/collections/liuhaotian/llava-16-65b9e40155f60fd046a5ccf2), and put them under the [`FreeVideoLLM`](./) folder.
25 | ```
26 | git lfs clone https://huggingface.co/liuhaotian/llava-v1.6-vicuna-7b liuhaotian/llava-v1.6-vicuna-7b
27 | git lfs clone https://huggingface.co/liuhaotian/llava-v1.6-34b liuhaotian/llava-v1.6-34b
28 | ```
29 | 
30 | ### Data Preparation
31 | 
32 | 1. We prepare the ground-truth question and answer files based on [`IG-VLM`](https://github.com/imagegridworth/IG-VLM/tree/main), and put them under [playground/gt_qa_files](playground/gt_qa_files).
33 | 
34 |    - MSVD-QA
35 |      - Download the `MSVD_QA.csv` from [`here`](https://github.com/imagegridworth/IG-VLM/blob/main/data/open_ended_qa/MSVD_QA.csv)
36 |      - Reformat the files by running
37 |        ```
38 |        python scripts/data/prepare_msvd_qa_file.py --qa_file $PATH_TO_CSV_FILE
39 |        ```
40 |    - MSRVTT-QA
41 |      - Download the `MSRVTT_QA.csv` from [`here`](https://github.com/imagegridworth/IG-VLM/blob/main/data/open_ended_qa/MSRVTT_QA.csv)
42 |      - Reformat the files by running
43 |        ```
44 |        python scripts/data/prepare_msrvtt_qa_file.py --qa_file $PATH_TO_CSV_FILE
45 |        ```
46 |    - TGIF-QA
47 |      - Download the `TGIF_FrameQA.csv` from [`here`](https://github.com/imagegridworth/IG-VLM/blob/main/data/open_ended_qa/TGIF_FrameQA.csv)
48 |      - Reformat the files by running
49 |        ```
50 |        python scripts/data/prepare_tgif_qa_file.py --qa_file $PATH_TO_CSV_FILE
51 |        ```
52 |    - ActivityNet-QA
53 |      - Download the `ActivityNet_QA.csv` from [`here`](https://github.com/imagegridworth/IG-VLM/blob/main/data/open_ended_qa/ActivityNet_QA.csv)
54 |      - Reformat the files by running
55 |        ```
56 |        python scripts/data/prepare_activitynet_qa_file.py --qa_file $PATH_TO_CSV_FILE
57 |        ```
58 | 
59 | 2. Download the raw videos from the official websites.
60 | 
61 |    - [Recommended] Option 1: Follow the instructions in [`Video-LLaVA`](https://github.com/PKU-YuanGroup/Video-LLaVA/blob/main/TRAIN_AND_VALIDATE.md) to download the raw videos.
62 |    - Option 2: Download the videos from the data owners.
63 |      - [`MSVD-QA`](https://github.com/xudejing/video-question-answering?tab=readme-ov-file)
64 |      - [`MSRVTT-QA`](https://github.com/xudejing/video-question-answering?tab=readme-ov-file)
65 |      - [`TGIF-QA`](https://github.com/YunseokJANG/tgif-qa?tab=readme-ov-file)
66 |      - [`ActivityNet-QA`](https://github.com/MILVLG/activitynet-qa)
67 | 
68 | 
69 | 3. Organize the raw videos under [playground/data](playground/data).
70 | 
71 |    - To directly use our data loaders without changing paths, please organize your datasets as follows:
72 | 
73 |      ```
74 |      $ FreeVideoLLM/playground/data
75 |      ├── video_qa
76 |          ├── MSVD_Zero_Shot_QA
77 |              ├── videos
78 |                  ├── ...
79 |          ├── MSRVTT_Zero_Shot_QA
80 |              ├── videos
81 |                  ├── all
82 |                      ├── ...
83 |          ├── TGIF_Zero_Shot_QA
84 |              ├── mp4
85 |                  ├── ...
86 |          ├── Activitynet_Zero_Shot_QA
87 |              ├── all_test
88 |                  ├── ...
89 |      ```
90 | 
91 | ## Configuration
92 | 
93 | We use YAML configs to control the design choices. You can refer to the code at https://github.com/contrastive/FreeVideoLLM/blob/e973c8840306f60773b0d9058b222287c45c5f97/free_video_llm/llava/model/llava_arch.py#L275 to understand the config options.
94 | 
95 | ## Inference and Evaluation
96 | 
97 | FreeVideoLLM is a training-free method, so inference and evaluation can be run directly without any model training.
98 | 
99 | By default, we use 8 GPUs for the model inference.
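For example, to run the 7B model on four GPUs you could change the corresponding entry in `cfgs/free_video_llm_7b.yaml` (an illustrative value; the shipped configs default to all eight GPUs):
```
CUDA_VISIBLE_DEVICES: "0,1,2,3"
```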
We can modify the `CUDA_VISIBLE_DEVICES` in the config file to accommodate your own settings. Please note that the model inference of FreeVideoLLM-34B requires GPUs with at least 80G memory. 100 | 101 | ``` 102 | cd FreeVideoLLM 103 | python run_inference.py --exp_config $PATH_TO_CONFIG_FILE 104 | ``` 105 | 106 | - This is optional, but use `export PYTHONWARNINGS="ignore"` if you want to suppress the warnings. 107 | 108 | ### Output Structures 109 | 110 | - The inference outputs will be stored under [`outputs/artifacts`](outputs/artifacts). 111 | - The intermediate outputs of GPT-3.5-turbo will be stored under [`outputs/eval_save_dir`](outputs/eval_save_dir). 112 | - The evaluation results will be stored under [`outputs/logs`](outputs/logs). 113 | - All of these can be changed in the config file. 114 | 115 | ## Acknowledgement 116 | 117 | The project is developed based on [LLaVA-v1.6](https://github.com/haotian-liu/LLaVA), [SlowFast-LLaVA](https://github.com/apple/ml-slowfast-llava), [IG-VLM](https://github.com/imagegridworth/IG-VLM), [CLIP](https://github.com/openai/CLIP) and [transformers](https://github.com/huggingface/transformers). 118 | 119 | ## Citation 120 | ``` 121 | @misc{han2024freevideollmpromptguidedvisual, 122 | title={Free Video-LLM: Prompt-guided Visual Perception for Efficient Training-free Video LLMs}, 123 | author={Kai Han and Jianyuan Guo and Yehui Tang and Wei He and Enhua Wu and Yunhe Wang}, 124 | year={2024}, 125 | eprint={2410.10441}, 126 | archivePrefix={arXiv}, 127 | primaryClass={cs.CV}, 128 | url={https://arxiv.org/abs/2410.10441}, 129 | } 130 | ``` 131 | -------------------------------------------------------------------------------- /cfgs/free_video_llm_34b.yaml: -------------------------------------------------------------------------------- 1 | SCRIPT: [ 2 | "bash scripts/run_eval_videoqabench.sh", # Openset VideoQA tasks 3 | ] 4 | 5 | CUDA_VISIBLE_DEVICES: "0,1,2,3,4,5,6,7" 6 | CONFIG_NAME: "auto" 7 | DATA_DIR: [ 8 | "playground/data/video_qa", 9 | ] 10 | GT_QA_DIR: "playground/gt_qa_files" 11 | MODEL_PATH: "liuhaotian/llava-v1.6-34b/" 12 | OUTPUT_DIR: "outputs/artifacts" 13 | TEMP_DIR: "outputs/eval_save_dir" 14 | CONV_MODE: [ 15 | "image_seq_34b_v3", 16 | ] 17 | NUM_FRAMES: "50" 18 | INPUT_STRUCTURE: "image_seq" 19 | IMAGE_ASPECT_RATIO: "resize" 20 | TEMPORAL_AGGREGATION: "slow_3frms_spatial_1d_max_pool_roi8-middle_5frms_spatial_1d_max_pool_24x12-fast_50frms_4x4" 21 | ROPE_SCALING_FACTOR: 2 22 | SAVE_DIR: "outputs/artifacts/logs" 23 | -------------------------------------------------------------------------------- /cfgs/free_video_llm_7b.yaml: -------------------------------------------------------------------------------- 1 | SCRIPT: [ 2 | "bash scripts/run_eval_videoqabench.sh", # Openset VideoQA tasks 3 | ] 4 | 5 | CUDA_VISIBLE_DEVICES: "0,1,2,3,4,5,6,7" 6 | CONFIG_NAME: "auto" 7 | DATA_DIR: [ 8 | "playground/data/video_qa", 9 | ] 10 | GT_QA_DIR: "playground/gt_qa_files" 11 | MODEL_PATH: "liuhaotian/llava-v1.6-vicuna-7b/" 12 | OUTPUT_DIR: "outputs/artifacts" 13 | TEMP_DIR: "outputs/eval_save_dir" 14 | CONV_MODE: [ 15 | "image_seq_v3", 16 | ] 17 | NUM_FRAMES: "50" 18 | INPUT_STRUCTURE: "image_seq" 19 | IMAGE_ASPECT_RATIO: "resize" 20 | TEMPORAL_AGGREGATION: "slow_3frms_spatial_1d_max_pool_roi8-middle_5frms_spatial_1d_max_pool_24x12-fast_50frms_4x4" 21 | ROPE_SCALING_FACTOR: 2 22 | SAVE_DIR: "outputs/artifacts/logs" 23 | -------------------------------------------------------------------------------- /dataset.py: 
-------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE file. 3 | # Copyright (C) 2024 Apple Inc. All Rights Reserved. 4 | # 5 | import os 6 | from decord import VideoReader, cpu 7 | import numpy as np 8 | from PIL import Image 9 | 10 | 11 | def load_frame(video_path, num_clips=1, num_frms=4): 12 | # Currently, this function supports only 1 clip 13 | assert num_clips == 1 14 | 15 | frame_names = sorted(os.listdir(video_path)) 16 | total_num_frames = len(frame_names) 17 | 18 | # Calculate desired number of frames to extract 19 | desired_num_frames = min(total_num_frames, num_frms) 20 | 21 | # Get indices of frames to extract 22 | frame_idx = get_seq_frames(total_num_frames, desired_num_frames) 23 | 24 | # Extract frames and get original sizes 25 | clip_imgs = [] 26 | original_sizes = [] 27 | for i in frame_idx: 28 | img = Image.open(os.path.join(video_path, frame_names[i])) 29 | clip_imgs.append(img) 30 | original_sizes.append(img.size) 31 | original_sizes = tuple(original_sizes) 32 | 33 | return clip_imgs, original_sizes 34 | 35 | 36 | def load_video(video_path, num_clips=1, num_frms=4): 37 | """ 38 | Load video frames from a video file. 39 | 40 | Parameters: 41 | video_path (str): Path to the video file. 42 | num_clips (int): Number of clips to extract from the video. Defaults to 1. 43 | num_frms (int): Number of frames to extract from each clip. Defaults to 4. 44 | 45 | Returns: 46 | list: List of PIL.Image.Image objects representing video frames. 47 | """ 48 | 49 | # Load video frame from a directory 50 | if os.path.isdir(video_path): 51 | return load_frame(video_path, num_clips, num_frms) 52 | 53 | # Load video with VideoReader 54 | vr = VideoReader(video_path, ctx=cpu(0)) 55 | total_num_frames = len(vr) 56 | 57 | # Currently, this function supports only 1 clip 58 | assert num_clips == 1 59 | 60 | # Calculate desired number of frames to extract 61 | desired_num_frames = min(total_num_frames, num_frms) 62 | 63 | # Get indices of frames to extract 64 | frame_idx = get_seq_frames(total_num_frames, desired_num_frames) 65 | 66 | # Extract frames as numpy array 67 | img_array = vr.get_batch(frame_idx).asnumpy() # (T H W C) 68 | clip_imgs = [Image.fromarray(img_array[i]) for i in range(desired_num_frames)] 69 | 70 | # Get original sizes of video frame 71 | original_size = (img_array.shape[-2], img_array.shape[-3]) # (W, H) 72 | original_sizes = (original_size,) * desired_num_frames 73 | 74 | return clip_imgs, original_sizes 75 | 76 | 77 | def get_seq_frames(total_num_frames, desired_num_frames): 78 | """ 79 | Calculate the indices of frames to extract from a video. 80 | 81 | Parameters: 82 | total_num_frames (int): Total number of frames in the video. 83 | desired_num_frames (int): Desired number of frames to extract. 84 | 85 | Returns: 86 | list: List of indices of frames to extract. 
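    Example (illustrative):
        get_seq_frames(101, 4) splits the 101 frames into 4 equal segments and
        returns the approximate middle frame index of each, i.e. [12, 37, 62, 87].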
87 | """ 88 | 89 | # Calculate the size of each segment from which a frame will be extracted 90 | seg_size = float(total_num_frames - 1) / desired_num_frames 91 | 92 | seq = [] 93 | for i in range(desired_num_frames): 94 | # Calculate the start and end indices of each segment 95 | start = int(np.round(seg_size * i)) 96 | end = int(np.round(seg_size * (i + 1))) 97 | 98 | # Append the middle index of the segment to the list 99 | seq.append((start + end) // 2) 100 | return seq 101 | -------------------------------------------------------------------------------- /eval/eval_prompt.py: -------------------------------------------------------------------------------- 1 | def get_eval_prompt(prompt_mode="default"): 2 | """ 3 | prompt_mode: default, 'correctness', 'detailed_orientation', 'context', 'temporal', 'consistency' 4 | """ 5 | system_prompt = { 6 | "default": "You are an intelligent chatbot designed for evaluating the correctness of generative outputs for question-answer pairs. " 7 | "Your task is to compare the predicted answer with the correct answer and determine if they match meaningfully. Here's how you can accomplish the task:" 8 | "------" 9 | "##INSTRUCTIONS: " 10 | "- Focus on the meaningful match between the predicted answer and the correct answer.\n" 11 | "- Consider synonyms or paraphrases as valid matches.\n" 12 | "- Evaluate the correctness of the prediction compared to the answer.", 13 | "consistency": "You are an intelligent chatbot designed for evaluating the consistency of generative outputs for similar video-based question-answer pairs. " 14 | "You will be given two very similar questions, a common answer common to both the questions and predicted answers for the two questions ." 15 | "Your task is to compare the predicted answers for two very similar question, with a common correct answer and determine if they are consistent. Here's how you can accomplish the task:" 16 | "------" 17 | "##INSTRUCTIONS: " 18 | "- Focus on the consistency between the two predicted answers and the correct answer. Both predicted answers should correspond to the correct answer and to each other, and should not contain any contradictions or significant differences in the conveyed information.\n" 19 | "- Both predicted answers must be consistent with each other and the correct answer, in terms of the information they provide about the video content.\n" 20 | "- Consider synonyms or paraphrases as valid matches, but only if they maintain the consistency in the conveyed information.\n" 21 | "- Evaluate the consistency of the two predicted answers compared to the correct answer.", 22 | "correctness": "You are an intelligent chatbot designed for evaluating the factual accuracy of generative outputs for video-based question-answer pairs. " 23 | "Your task is to compare the predicted answer with the correct answer and determine if they are factually consistent. Here's how you can accomplish the task:" 24 | "------" 25 | "##INSTRUCTIONS: " 26 | "- Focus on the factual consistency between the predicted answer and the correct answer. 
The predicted answer should not contain any misinterpretations or misinformation.\n" 27 | "- The predicted answer must be factually accurate and align with the video content.\n" 28 | "- Consider synonyms or paraphrases as valid matches.\n" 29 | "- Evaluate the factual accuracy of the prediction compared to the answer.", 30 | "detailed_orientation": "You are an intelligent chatbot designed for evaluating the detail orientation of generative outputs for video-based question-answer pairs. " 31 | "Your task is to compare the predicted answer with the correct answer and determine its level of detail, considering both completeness and specificity. Here's how you can accomplish the task:" 32 | "------" 33 | "##INSTRUCTIONS: " 34 | "- Check if the predicted answer covers all major points from the video. The response should not leave out any key aspects.\n" 35 | "- Evaluate whether the predicted answer includes specific details rather than just generic points. It should provide comprehensive information that is tied to specific elements of the video.\n" 36 | "- Consider synonyms or paraphrases as valid matches.\n" 37 | "- Provide a single evaluation score that reflects the level of detail orientation of the prediction, considering both completeness and specificity.", 38 | "context": "You are an intelligent chatbot designed for evaluating the contextual understanding of generative outputs for video-based question-answer pairs. " 39 | "Your task is to compare the predicted answer with the correct answer and determine if the generated response aligns with the overall context of the video content. Here's how you can accomplish the task:" 40 | "------" 41 | "##INSTRUCTIONS: " 42 | "- Evaluate whether the predicted answer aligns with the overall context of the video content. It should not provide information that is out of context or misaligned.\n" 43 | "- The predicted answer must capture the main themes and sentiments of the video.\n" 44 | "- Consider synonyms or paraphrases as valid matches.\n" 45 | "- Provide your evaluation of the contextual understanding of the prediction compared to the answer.", 46 | "temporal": "You are an intelligent chatbot designed for evaluating the temporal understanding of generative outputs for video-based question-answer pairs. " 47 | "Your task is to compare the predicted answer with the correct answer and determine if they correctly reflect the temporal sequence of events in the video content. Here's how you can accomplish the task:" 48 | "------" 49 | "##INSTRUCTIONS: " 50 | "- Focus on the temporal consistency between the predicted answer and the correct answer. The predicted answer should correctly reflect the sequence of events or details as they are presented in the video content.\n" 51 | "- Consider synonyms or paraphrases as valid matches, but only if the temporal order is maintained.\n" 52 | "- Evaluate the temporal accuracy of the prediction compared to the answer.", 53 | } 54 | user_prompt = { 55 | "default": "Please evaluate the following video-based question-answer pair:\n\n" 56 | "Question: %s\n" 57 | "Correct Answer: %s\n" 58 | "Predicted Answer: %s\n\n" 59 | "Provide your evaluation only as a yes/no and score where the score is an integer value between 0 and 5, with 5 indicating the highest meaningful match. " 60 | "Please generate the response in the form of a Python dictionary string with keys 'pred' and 'score', where value of 'pred' is a string of 'yes' or 'no' and value of 'score' is in INTEGER, not STRING." 
61 | "DO NOT PROVIDE ANY OTHER OUTPUT TEXT OR EXPLANATION. Only provide the Python dictionary string. " 62 | "For example, your response should look like this: {'pred': 'yes', 'score': 4.8}.", 63 | "consistency": "Please evaluate the following video-based question-answer pair:\n\n" 64 | "Question 1: %s\n" 65 | "Question 2: %s\n" 66 | "Correct Answer: %s\n" 67 | "Predicted Answer to Question 1: %s\n" 68 | "Predicted Answer to Question 2: %s\n\n" 69 | "Provide your evaluation only as a consistency score where the consistency score is an integer value between 0 and 5, with 5 indicating the highest level of consistency. " 70 | "Please generate the response in the form of a Python dictionary string with keys 'score', where its value is the consistency score in INTEGER, not STRING." 71 | "DO NOT PROVIDE ANY OTHER OUTPUT TEXT OR EXPLANATION. Only provide the Python dictionary string. " 72 | "For example, your response should look like this: {''score': 4.8}.", 73 | "correctness": "Please evaluate the following video-based question-answer pair:\n\n" 74 | "Question: %s\n" 75 | "Correct Answer: %s\n" 76 | "Predicted Answer: %s\n\n" 77 | "Provide your evaluation only as a factual accuracy score where the factual accuracy score is an integer value between 0 and 5, with 5 indicating the highest level of factual consistency. " 78 | "Please generate the response in the form of a Python dictionary string with keys 'score', where its value is the factual accuracy score in INTEGER, not STRING." 79 | "DO NOT PROVIDE ANY OTHER OUTPUT TEXT OR EXPLANATION. Only provide the Python dictionary string. " 80 | "For example, your response should look like this: {''score': 4.8}.", 81 | "detailed_orientation": "Please evaluate the following video-based question-answer pair:\n\n" 82 | "Question: %s\n" 83 | "Correct Answer: %s\n" 84 | "Predicted Answer: %s\n\n" 85 | "Provide your evaluation only as a detail orientation score where the detail orientation score is an integer value between 0 and 5, with 5 indicating the highest level of detail orientation. " 86 | "Please generate the response in the form of a Python dictionary string with keys 'score', where its value is the detail orientation score in INTEGER, not STRING." 87 | "DO NOT PROVIDE ANY OTHER OUTPUT TEXT OR EXPLANATION. Only provide the Python dictionary string. " 88 | "For example, your response should look like this: {''score': 4.8}.", 89 | "context": "Please evaluate the following video-based question-answer pair:\n\n" 90 | "Question: %s\n" 91 | "Correct Answer: %s\n" 92 | "Predicted Answer: %s\n\n" 93 | "Provide your evaluation only as a contextual understanding score where the contextual understanding score is an integer value between 0 and 5, with 5 indicating the highest level of contextual understanding. " 94 | "Please generate the response in the form of a Python dictionary string with keys 'score', where its value is contextual understanding score in INTEGER, not STRING." 95 | "DO NOT PROVIDE ANY OTHER OUTPUT TEXT OR EXPLANATION. Only provide the Python dictionary string. " 96 | "For example, your response should look like this: {''score': 4.8}.", 97 | "temporal": "Please evaluate the following video-based question-answer pair:\n\n" 98 | "Question: %s\n" 99 | "Correct Answer: %s\n" 100 | "Predicted Answer: %s\n\n" 101 | "Provide your evaluation only as a temporal accuracy score where the temporal accuracy score is an integer value between 0 and 5, with 5 indicating the highest level of temporal consistency. 
" 102 | "Please generate the response in the form of a Python dictionary string with keys 'score', where its value is the temporal accuracy score in INTEGER, not STRING." 103 | "DO NOT PROVIDE ANY OTHER OUTPUT TEXT OR EXPLANATION. Only provide the Python dictionary string. " 104 | "For example, your response should look like this: {''score': 4.8}.", 105 | } 106 | return system_prompt[prompt_mode], user_prompt[prompt_mode] 107 | -------------------------------------------------------------------------------- /eval/eval_video_qa.py: -------------------------------------------------------------------------------- 1 | from openai import OpenAI 2 | import os 3 | import argparse 4 | import json 5 | import ast 6 | from multiprocessing.pool import Pool 7 | from tqdm import tqdm 8 | 9 | client = OpenAI(organization=os.environ.get("OPENAI_ORG", None)) 10 | 11 | 12 | def parse_args(): 13 | parser = argparse.ArgumentParser(description="question-answer-generation-using-gpt-3") 14 | parser.add_argument("--pred_path", default=r'', help="The path to file containing prediction.") 15 | parser.add_argument("--output_dir", default=r'', help="The path to save annotation json files.") 16 | parser.add_argument("--output_json", default=r'', help="The path to save annotation final combined json file.") 17 | parser.add_argument("--api_key", default="", help="OpenAI API key.") 18 | parser.add_argument("--gpt_version", default="gpt-3.5-turbo", type=str, help="OpenAI API base.") 19 | parser.add_argument("--num_tasks", default=1, type=int, help="Number of splits.") 20 | args = parser.parse_args() 21 | return args 22 | 23 | 24 | def annotate(prediction_set, caption_files, output_dir, args): 25 | """ 26 | Evaluates question and answer pairs using GPT-3 27 | Returns a score for correctness. 28 | """ 29 | for file in caption_files: 30 | key = file[:-5] # Strip file extension 31 | qa_set = prediction_set[key] 32 | question = qa_set['q'] 33 | answer = qa_set['a'] 34 | pred = qa_set['pred'] 35 | try: 36 | # Compute the correctness score 37 | completion = client.chat.completions.create( 38 | model=args.gpt_version, 39 | messages=[ 40 | { 41 | "role": "system", 42 | "content": 43 | "You are an intelligent chatbot designed for evaluating the correctness of generative outputs for question-answer pairs. " 44 | "Your task is to compare the predicted answer with the correct answer and determine if they match meaningfully. Here's how you can accomplish the task:" 45 | "------" 46 | "##INSTRUCTIONS: " 47 | "- Focus on the meaningful match between the predicted answer and the correct answer.\n" 48 | "- Consider synonyms or paraphrases as valid matches.\n" 49 | "- Evaluate the correctness of the prediction compared to the answer." 50 | }, 51 | { 52 | "role": "user", 53 | "content": 54 | "Please evaluate the following video-based question-answer pair:\n\n" 55 | f"Question: {question}\n" 56 | f"Correct Answer: {answer}\n" 57 | f"Predicted Answer: {pred}\n\n" 58 | "Provide your evaluation only as a yes/no and score where the score is an integer value between 0 and 5, with 5 indicating the highest meaningful match. " 59 | "Please generate the response in the form of a Python dictionary string with keys 'pred' and 'score', where value of 'pred' is a string of 'yes' or 'no' and value of 'score' is in INTEGER, not STRING." 60 | "DO NOT PROVIDE ANY OTHER OUTPUT TEXT OR EXPLANATION. Only provide the Python dictionary string. " 61 | "For example, your response should look like this: {'pred': 'yes', 'score': 4.8}." 
62 | } 63 | ] 64 | ) 65 | # Convert response to a Python dictionary. 66 | response_message = completion.choices[0].message.content 67 | response_dict = ast.literal_eval(response_message) 68 | result_qa_pair = [response_dict, qa_set] 69 | 70 | # Save the question-answer pairs to a json file. 71 | with open(f"{output_dir}/{key}.json", "w") as f: 72 | json.dump(result_qa_pair, f) 73 | 74 | except Exception as e: 75 | print(f"Error processing file '{key}': {e}") 76 | 77 | 78 | def main(): 79 | """ 80 | Main function to control the flow of the program. 81 | """ 82 | # Parse arguments. 83 | args = parse_args() 84 | 85 | file = open(args.pred_path) 86 | new_pred_contents = [eval(i.strip()) for i in file.readlines()] 87 | 88 | ''' 89 | # Dictionary to store the count of occurrences for each video_id 90 | video_id_counts = {} 91 | new_pred_contents = [] 92 | 93 | # Iterate through each sample in pred_contents 94 | for sample in pred_contents: 95 | video_id = sample['video_name'] 96 | if video_id in video_id_counts: 97 | video_id_counts[video_id] += 1 98 | else: 99 | video_id_counts[video_id] = 0 100 | 101 | # Create a new sample with the modified key 102 | new_sample = sample 103 | new_sample['video_name'] = f"{video_id}_{video_id_counts[video_id]}" 104 | new_pred_contents.append(new_sample) 105 | ''' 106 | # Generating list of id's and corresponding files 107 | id_list = [x['id'] for x in new_pred_contents] 108 | caption_files = [f"{id}.json" for id in id_list] 109 | caption_set = set(caption_files) 110 | 111 | output_dir = args.output_dir 112 | # Generate output directory if not exists. 113 | if not os.path.exists(output_dir): 114 | os.makedirs(output_dir) 115 | 116 | # Preparing dictionary of question-answer sets 117 | prediction_set = {} 118 | for sample in new_pred_contents: 119 | id = str(sample['id']) 120 | question = sample['question'] 121 | answer = sample['answer'] 122 | pred = sample['pred'] 123 | qa_set = {"q": question, "a": answer, "pred": pred} 124 | prediction_set[id] = qa_set 125 | 126 | num_tasks = args.num_tasks 127 | 128 | # While loop to ensure that all captions are processed. 129 | # Change `while loop` to `for loop`` to avoid endless loop. 130 | num_retries = 100 131 | for _ in range(num_retries + 1): 132 | try: 133 | # Files that have not been processed yet. 134 | completed_files = os.listdir(output_dir) 135 | completed_set = set(completed_files) 136 | print(f"completed_files: {len(completed_files)}") 137 | 138 | # Files that have not been processed yet. 139 | incomplete_files = list(caption_set - completed_set) 140 | print(f"incomplete_files: {len(incomplete_files)}") 141 | 142 | # Break the loop when there are no incomplete files 143 | if len(incomplete_files) == 0: 144 | break 145 | if len(incomplete_files) <= num_tasks: 146 | num_tasks = 1 147 | 148 | # Split tasks into parts. 149 | part_len = len(incomplete_files) // num_tasks 150 | all_parts = [incomplete_files[i:i + part_len] for i in range(0, len(incomplete_files), part_len)] 151 | task_args = [(prediction_set, part, args.output_dir, args) for part in all_parts] 152 | 153 | # Use a pool of workers to process the files in parallel. 154 | with Pool() as pool: 155 | pool.starmap(annotate, task_args) 156 | 157 | except Exception as e: 158 | print(f"Error: {e}") 159 | else: 160 | print( 161 | f"Run into endless loop over {num_retries} times, " \ 162 | f"skip {len(incomplete_files)} incomplete files for now ..." 
163 | ) 164 | 165 | # Combine all the processed files into one 166 | combined_contents = {} 167 | json_path = args.output_json 168 | 169 | # Iterate through json files 170 | for file_name in os.listdir(output_dir): 171 | if file_name.endswith(".json"): 172 | file_path = os.path.join(output_dir, file_name) 173 | with open(file_path, "r") as json_file: 174 | content = json.load(json_file) 175 | combined_contents[file_name[:-5]] = content 176 | 177 | # Write combined content to a json file 178 | with open(json_path, "w") as json_file: 179 | json.dump(combined_contents, json_file) 180 | print("All evaluation completed!") 181 | 182 | # Calculate average score and accuracy 183 | score_sum = 0 184 | count = 0 185 | yes_count = 0 186 | no_count = 0 187 | for key, result in tqdm(combined_contents.items()): 188 | try: 189 | # Computing score 190 | count += 1 191 | score_match = result[0]['score'] 192 | score = int(score_match) 193 | score_sum += score 194 | 195 | # Computing accuracy 196 | pred = result[0]['pred'] 197 | if "yes" in pred.lower(): 198 | yes_count += 1 199 | elif "no" in pred.lower(): 200 | no_count += 1 201 | except: 202 | print(result) 203 | 204 | average_score = score_sum / count 205 | accuracy = yes_count / (yes_count + no_count) 206 | print("Yes count:", yes_count) 207 | print("No count:", no_count) 208 | print("Accuracy:", accuracy) 209 | print("Average score:", average_score) 210 | 211 | 212 | if __name__ == "__main__": 213 | main() 214 | -------------------------------------------------------------------------------- /free_video_llm/llava/__init__.py: -------------------------------------------------------------------------------- 1 | from .model import LlavaLlamaForCausalLM 2 | -------------------------------------------------------------------------------- /free_video_llm/llava/constants.py: -------------------------------------------------------------------------------- 1 | CONTROLLER_HEART_BEAT_EXPIRATION = 30 2 | WORKER_HEART_BEAT_INTERVAL = 15 3 | 4 | LOGDIR = "." 
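# How these constants are typically used in this codebase (see tokenizer_image_token()
# in mm_utils.py and the multimodal input preparation in llava_arch.py): the prompt is
# split on the image token and IMAGE_TOKEN_INDEX (-200) is spliced into input_ids as a
# placeholder that is later replaced by the projected visual features, while
# IGNORE_INDEX (-100) marks label positions that are excluded from the loss.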
5 | 
6 | # Model Constants
7 | IGNORE_INDEX = -100
8 | IMAGE_TOKEN_INDEX = -200
9 | DEFAULT_IMAGE_TOKEN = "<image>"
10 | DEFAULT_IMAGE_PATCH_TOKEN = "<im_patch>"
11 | DEFAULT_IM_START_TOKEN = "<im_start>"
12 | DEFAULT_IM_END_TOKEN = "<im_end>"
13 | IMAGE_PLACEHOLDER = "<image-placeholder>"
14 | 
--------------------------------------------------------------------------------
/free_video_llm/llava/conversation.py:
--------------------------------------------------------------------------------
1 | import dataclasses
2 | from enum import auto, Enum
3 | from typing import List, Tuple
4 | import base64
5 | from io import BytesIO
6 | from PIL import Image
7 | 
8 | 
9 | class SeparatorStyle(Enum):
10 |     """Different separator style."""
11 |     SINGLE = auto()
12 |     TWO = auto()
13 |     MPT = auto()
14 |     PLAIN = auto()
15 |     LLAMA_2 = auto()
16 | 
17 | 
18 | @dataclasses.dataclass
19 | class Conversation:
20 |     """A class that keeps all conversation history."""
21 |     system: str
22 |     roles: List[str]
23 |     messages: List[List[str]]
24 |     offset: int
25 |     sep_style: SeparatorStyle = SeparatorStyle.SINGLE
26 |     sep: str = "###"
27 |     sep2: str = None
28 |     version: str = "Unknown"
29 | 
30 |     skip_next: bool = False
31 | 
32 |     def get_prompt(self):
33 |         messages = self.messages
34 |         if len(messages) > 0 and type(messages[0][1]) is tuple:
35 |             messages = self.messages.copy()
36 |             init_role, init_msg = messages[0].copy()
37 |             init_msg = init_msg[0].replace("<image>", "").strip()
38 |             if 'mmtag' in self.version:
39 |                 messages[0] = (init_role, init_msg)
40 |                 messages.insert(0, (self.roles[0], "<Image><image></Image>"))
41 |                 messages.insert(1, (self.roles[1], "Received."))
42 |             else:
43 |                 messages[0] = (init_role, "<image>\n" + init_msg)
44 | 
45 |         if self.sep_style == SeparatorStyle.SINGLE:
46 |             ret = self.system + self.sep
47 |             for role, message in messages:
48 |                 if message:
49 |                     if type(message) is tuple:
50 |                         message, _, _ = message
51 |                     ret += role + ": " + message + self.sep
52 |                 else:
53 |                     ret += role + ":"
54 |         elif self.sep_style == SeparatorStyle.TWO:
55 |             seps = [self.sep, self.sep2]
56 |             ret = self.system + seps[0]
57 |             for i, (role, message) in enumerate(messages):
58 |                 if message:
59 |                     if type(message) is tuple:
60 |                         message, _, _ = message
61 |                     ret += role + ": " + message + seps[i % 2]
62 |                 else:
63 |                     ret += role + ":"
64 |         elif self.sep_style == SeparatorStyle.MPT:
65 |             ret = self.system + self.sep
66 |             for role, message in messages:
67 |                 if message:
68 |                     if type(message) is tuple:
69 |                         message, _, _ = message
70 |                     ret += role + message + self.sep
71 |                 else:
72 |                     ret += role
73 |         elif self.sep_style == SeparatorStyle.LLAMA_2:
74 |             wrap_sys = lambda msg: f"<<SYS>>\n{msg}\n<</SYS>>\n\n" if len(msg) > 0 else msg
75 |             wrap_inst = lambda msg: f"[INST] {msg} [/INST]"
76 |             ret = ""
77 | 
78 |             for i, (role, message) in enumerate(messages):
79 |                 if i == 0:
80 |                     assert message, "first message should not be none"
81 |                     assert role == self.roles[0], "first message should come from user"
82 |                 if message:
83 |                     if type(message) is tuple:
84 |                         message, _, _ = message
85 |                     if i == 0: message = wrap_sys(self.system) + message
86 |                     if i % 2 == 0:
87 |                         message = wrap_inst(message)
88 |                         ret += self.sep + message
89 |                     else:
90 |                         ret += " " + message + " " + self.sep2
91 |                 else:
92 |                     ret += ""
93 |             ret = ret.lstrip(self.sep)
94 |         elif self.sep_style == SeparatorStyle.PLAIN:
95 |             seps = [self.sep, self.sep2]
96 |             ret = self.system
97 |             for i, (role, message) in enumerate(messages):
98 |                 if message:
99 |                     if type(message) is tuple:
100 |                         message, _, _ = message
101 |                     ret += message + seps[i % 2]
102 |                 else:
103 |                     ret += ""
104 |         else:
105
| raise ValueError(f"Invalid style: {self.sep_style}") 106 | 107 | return ret 108 | 109 | def append_message(self, role, message): 110 | self.messages.append([role, message]) 111 | 112 | def process_image(self, image, image_process_mode, return_pil=False, image_format='PNG', max_len=1344, min_len=672): 113 | if image_process_mode == "Pad": 114 | def expand2square(pil_img, background_color=(122, 116, 104)): 115 | width, height = pil_img.size 116 | if width == height: 117 | return pil_img 118 | elif width > height: 119 | result = Image.new(pil_img.mode, (width, width), background_color) 120 | result.paste(pil_img, (0, (width - height) // 2)) 121 | return result 122 | else: 123 | result = Image.new(pil_img.mode, (height, height), background_color) 124 | result.paste(pil_img, ((height - width) // 2, 0)) 125 | return result 126 | image = expand2square(image) 127 | elif image_process_mode in ["Default", "Crop"]: 128 | pass 129 | elif image_process_mode == "Resize": 130 | image = image.resize((336, 336)) 131 | else: 132 | raise ValueError(f"Invalid image_process_mode: {image_process_mode}") 133 | if max(image.size) > max_len: 134 | max_hw, min_hw = max(image.size), min(image.size) 135 | aspect_ratio = max_hw / min_hw 136 | shortest_edge = int(min(max_len / aspect_ratio, min_len, min_hw)) 137 | longest_edge = int(shortest_edge * aspect_ratio) 138 | W, H = image.size 139 | if H > W: 140 | H, W = longest_edge, shortest_edge 141 | else: 142 | H, W = shortest_edge, longest_edge 143 | image = image.resize((W, H)) 144 | if return_pil: 145 | return image 146 | else: 147 | buffered = BytesIO() 148 | image.save(buffered, format=image_format) 149 | img_b64_str = base64.b64encode(buffered.getvalue()).decode() 150 | return img_b64_str 151 | 152 | def get_images(self, return_pil=False): 153 | images = [] 154 | for i, (role, msg) in enumerate(self.messages[self.offset:]): 155 | if i % 2 == 0: 156 | if type(msg) is tuple: 157 | msg, image, image_process_mode = msg 158 | image = self.process_image(image, image_process_mode, return_pil=return_pil) 159 | images.append(image) 160 | return images 161 | 162 | def to_gradio_chatbot(self): 163 | ret = [] 164 | for i, (role, msg) in enumerate(self.messages[self.offset:]): 165 | if i % 2 == 0: 166 | if type(msg) is tuple: 167 | msg, image, image_process_mode = msg 168 | img_b64_str = self.process_image( 169 | image, "Default", return_pil=False, 170 | image_format='JPEG') 171 | img_str = f'user upload image' 172 | msg = img_str + msg.replace('', '').strip() 173 | ret.append([msg, None]) 174 | else: 175 | ret.append([msg, None]) 176 | else: 177 | ret[-1][-1] = msg 178 | return ret 179 | 180 | def copy(self): 181 | return Conversation( 182 | system=self.system, 183 | roles=self.roles, 184 | messages=[[x, y] for x, y in self.messages], 185 | offset=self.offset, 186 | sep_style=self.sep_style, 187 | sep=self.sep, 188 | sep2=self.sep2, 189 | version=self.version) 190 | 191 | def dict(self): 192 | if len(self.get_images()) > 0: 193 | return { 194 | "system": self.system, 195 | "roles": self.roles, 196 | "messages": [[x, y[0] if type(y) is tuple else y] for x, y in self.messages], 197 | "offset": self.offset, 198 | "sep": self.sep, 199 | "sep2": self.sep2, 200 | } 201 | return { 202 | "system": self.system, 203 | "roles": self.roles, 204 | "messages": self.messages, 205 | "offset": self.offset, 206 | "sep": self.sep, 207 | "sep2": self.sep2, 208 | } 209 | 210 | 211 | conv_vicuna_v0 = Conversation( 212 | system="A chat between a curious human and an artificial intelligence 
assistant. " 213 | "The assistant gives helpful, detailed, and polite answers to the human's questions.", 214 | roles=("Human", "Assistant"), 215 | messages=( 216 | ("Human", "What are the key differences between renewable and non-renewable energy sources?"), 217 | ("Assistant", 218 | "Renewable energy sources are those that can be replenished naturally in a relatively " 219 | "short amount of time, such as solar, wind, hydro, geothermal, and biomass. " 220 | "Non-renewable energy sources, on the other hand, are finite and will eventually be " 221 | "depleted, such as coal, oil, and natural gas. Here are some key differences between " 222 | "renewable and non-renewable energy sources:\n" 223 | "1. Availability: Renewable energy sources are virtually inexhaustible, while non-renewable " 224 | "energy sources are finite and will eventually run out.\n" 225 | "2. Environmental impact: Renewable energy sources have a much lower environmental impact " 226 | "than non-renewable sources, which can lead to air and water pollution, greenhouse gas emissions, " 227 | "and other negative effects.\n" 228 | "3. Cost: Renewable energy sources can be more expensive to initially set up, but they typically " 229 | "have lower operational costs than non-renewable sources.\n" 230 | "4. Reliability: Renewable energy sources are often more reliable and can be used in more remote " 231 | "locations than non-renewable sources.\n" 232 | "5. Flexibility: Renewable energy sources are often more flexible and can be adapted to different " 233 | "situations and needs, while non-renewable sources are more rigid and inflexible.\n" 234 | "6. Sustainability: Renewable energy sources are more sustainable over the long term, while " 235 | "non-renewable sources are not, and their depletion can lead to economic and social instability.\n") 236 | ), 237 | offset=2, 238 | sep_style=SeparatorStyle.SINGLE, 239 | sep="###", 240 | ) 241 | 242 | conv_vicuna_v1 = Conversation( 243 | system="A chat between a curious user and an artificial intelligence assistant. " 244 | "The assistant gives helpful, detailed, and polite answers to the user's questions.", 245 | roles=("USER", "ASSISTANT"), 246 | version="v1", 247 | messages=(), 248 | offset=0, 249 | sep_style=SeparatorStyle.TWO, 250 | sep=" ", 251 | sep2="", 252 | ) 253 | 254 | conv_llama_2 = Conversation( 255 | system="""You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature. 256 | 257 | If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information.""", 258 | roles=("USER", "ASSISTANT"), 259 | version="llama_v2", 260 | messages=(), 261 | offset=0, 262 | sep_style=SeparatorStyle.LLAMA_2, 263 | sep="", 264 | sep2="", 265 | ) 266 | 267 | conv_llava_llama_2 = Conversation( 268 | system="You are a helpful language and vision assistant. 
" 269 | "You are able to understand the visual content that the user provides, " 270 | "and assist the user with a variety of tasks using natural language.", 271 | roles=("USER", "ASSISTANT"), 272 | version="llama_v2", 273 | messages=(), 274 | offset=0, 275 | sep_style=SeparatorStyle.LLAMA_2, 276 | sep="", 277 | sep2="", 278 | ) 279 | 280 | conv_mpt = Conversation( 281 | system="""<|im_start|>system 282 | A conversation between a user and an LLM-based AI assistant. The assistant gives helpful and honest answers.""", 283 | roles=("<|im_start|>user\n", "<|im_start|>assistant\n"), 284 | version="mpt", 285 | messages=(), 286 | offset=0, 287 | sep_style=SeparatorStyle.MPT, 288 | sep="<|im_end|>", 289 | ) 290 | 291 | conv_llava_plain = Conversation( 292 | system="", 293 | roles=("", ""), 294 | messages=( 295 | ), 296 | offset=0, 297 | sep_style=SeparatorStyle.PLAIN, 298 | sep="\n", 299 | ) 300 | 301 | conv_llava_v0 = Conversation( 302 | system="A chat between a curious human and an artificial intelligence assistant. " 303 | "The assistant gives helpful, detailed, and polite answers to the human's questions.", 304 | roles=("Human", "Assistant"), 305 | messages=( 306 | ), 307 | offset=0, 308 | sep_style=SeparatorStyle.SINGLE, 309 | sep="###", 310 | ) 311 | 312 | conv_llava_v0_mmtag = Conversation( 313 | system="A chat between a curious user and an artificial intelligence assistant. " 314 | "The assistant is able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language." 315 | "The visual content will be provided with the following format: visual content.", 316 | roles=("Human", "Assistant"), 317 | messages=( 318 | ), 319 | offset=0, 320 | sep_style=SeparatorStyle.SINGLE, 321 | sep="###", 322 | version="v0_mmtag", 323 | ) 324 | 325 | conv_llava_v1 = Conversation( 326 | system="A chat between a curious human and an artificial intelligence assistant. " 327 | "The assistant gives helpful, detailed, and polite answers to the human's questions.", 328 | roles=("USER", "ASSISTANT"), 329 | version="v1", 330 | messages=(), 331 | offset=0, 332 | sep_style=SeparatorStyle.TWO, 333 | sep=" ", 334 | sep2="", 335 | ) 336 | 337 | conv_llava_v1_mmtag = Conversation( 338 | system="A chat between a curious user and an artificial intelligence assistant. " 339 | "The assistant is able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language." 
340 | "The visual content will be provided with the following format: visual content.", 341 | roles=("USER", "ASSISTANT"), 342 | messages=(), 343 | offset=0, 344 | sep_style=SeparatorStyle.TWO, 345 | sep=" ", 346 | sep2="", 347 | version="v1_mmtag", 348 | ) 349 | 350 | conv_mistral_instruct = Conversation( 351 | system="", 352 | roles=("USER", "ASSISTANT"), 353 | version="llama_v2", 354 | messages=(), 355 | offset=0, 356 | sep_style=SeparatorStyle.LLAMA_2, 357 | sep="", 358 | sep2="", 359 | ) 360 | 361 | conv_chatml_direct = Conversation( 362 | system="""<|im_start|>system 363 | Answer the questions.""", 364 | roles=("<|im_start|>user\n", "<|im_start|>assistant\n"), 365 | version="mpt", 366 | messages=(), 367 | offset=0, 368 | sep_style=SeparatorStyle.MPT, 369 | sep="<|im_end|>", 370 | ) 371 | 372 | default_conversation = conv_vicuna_v1 373 | conv_templates = { 374 | "default": conv_vicuna_v0, 375 | "v0": conv_vicuna_v0, 376 | "v1": conv_vicuna_v1, 377 | "vicuna_v1": conv_vicuna_v1, 378 | "llama_2": conv_llama_2, 379 | "mistral_instruct": conv_mistral_instruct, 380 | "chatml_direct": conv_chatml_direct, 381 | "mistral_direct": conv_chatml_direct, 382 | 383 | "plain": conv_llava_plain, 384 | "v0_plain": conv_llava_plain, 385 | "llava_v0": conv_llava_v0, 386 | "v0_mmtag": conv_llava_v0_mmtag, 387 | "llava_v1": conv_llava_v1, 388 | "v1_mmtag": conv_llava_v1_mmtag, 389 | "llava_llama_2": conv_llava_llama_2, 390 | 391 | "mpt": conv_mpt, 392 | } 393 | 394 | 395 | if __name__ == "__main__": 396 | print(default_conversation.get_prompt()) 397 | -------------------------------------------------------------------------------- /free_video_llm/llava/mm_utils.py: -------------------------------------------------------------------------------- 1 | from PIL import Image 2 | from io import BytesIO 3 | import base64 4 | import torch 5 | import math 6 | import ast 7 | 8 | from transformers import StoppingCriteria 9 | from llava.constants import IMAGE_TOKEN_INDEX 10 | 11 | 12 | def select_best_resolution(original_size, possible_resolutions): 13 | """ 14 | Selects the best resolution from a list of possible resolutions based on the original size. 15 | 16 | Args: 17 | original_size (tuple): The original size of the image in the format (width, height). 18 | possible_resolutions (list): A list of possible resolutions in the format [(width1, height1), (width2, height2), ...]. 19 | 20 | Returns: 21 | tuple: The best fit resolution in the format (width, height). 22 | """ 23 | original_width, original_height = original_size 24 | best_fit = None 25 | max_effective_resolution = 0 26 | min_wasted_resolution = float('inf') 27 | 28 | for width, height in possible_resolutions: 29 | scale = min(width / original_width, height / original_height) 30 | downscaled_width, downscaled_height = int(original_width * scale), int(original_height * scale) 31 | effective_resolution = min(downscaled_width * downscaled_height, original_width * original_height) 32 | wasted_resolution = (width * height) - effective_resolution 33 | 34 | if effective_resolution > max_effective_resolution or (effective_resolution == max_effective_resolution and wasted_resolution < min_wasted_resolution): 35 | max_effective_resolution = effective_resolution 36 | min_wasted_resolution = wasted_resolution 37 | best_fit = (width, height) 38 | 39 | return best_fit 40 | 41 | 42 | def resize_and_pad_image(image, target_resolution): 43 | """ 44 | Resize and pad an image to a target resolution while maintaining aspect ratio. 
45 | 46 | Args: 47 | image (PIL.Image.Image): The input image. 48 | target_resolution (tuple): The target resolution (width, height) of the image. 49 | 50 | Returns: 51 | PIL.Image.Image: The resized and padded image. 52 | """ 53 | original_width, original_height = image.size 54 | target_width, target_height = target_resolution 55 | 56 | scale_w = target_width / original_width 57 | scale_h = target_height / original_height 58 | 59 | if scale_w < scale_h: 60 | new_width = target_width 61 | new_height = min(math.ceil(original_height * scale_w), target_height) 62 | else: 63 | new_height = target_height 64 | new_width = min(math.ceil(original_width * scale_h), target_width) 65 | 66 | # Resize the image 67 | resized_image = image.resize((new_width, new_height)) 68 | 69 | new_image = Image.new('RGB', (target_width, target_height), (0, 0, 0)) 70 | paste_x = (target_width - new_width) // 2 71 | paste_y = (target_height - new_height) // 2 72 | new_image.paste(resized_image, (paste_x, paste_y)) 73 | 74 | return new_image 75 | 76 | 77 | def divide_to_patches(image, patch_size): 78 | """ 79 | Divides an image into patches of a specified size. 80 | 81 | Args: 82 | image (PIL.Image.Image): The input image. 83 | patch_size (int): The size of each patch. 84 | 85 | Returns: 86 | list: A list of PIL.Image.Image objects representing the patches. 87 | """ 88 | patches = [] 89 | width, height = image.size 90 | for i in range(0, height, patch_size): 91 | for j in range(0, width, patch_size): 92 | box = (j, i, j + patch_size, i + patch_size) 93 | patch = image.crop(box) 94 | patches.append(patch) 95 | 96 | return patches 97 | 98 | 99 | def get_anyres_image_grid_shape(image_size, grid_pinpoints, patch_size): 100 | """ 101 | Calculate the shape of the image patch grid after the preprocessing for images of any resolution. 102 | 103 | Args: 104 | image_size (tuple): The size of the input image in the format (width, height). 105 | grid_pinpoints (str): A string representation of a list of possible resolutions. 106 | patch_size (int): The size of each image patch. 107 | 108 | Returns: 109 | tuple: The shape of the image patch grid in the format (width, height). 110 | """ 111 | if type(grid_pinpoints) is list: 112 | possible_resolutions = grid_pinpoints 113 | else: 114 | possible_resolutions = ast.literal_eval(grid_pinpoints) 115 | width, height = select_best_resolution(image_size, possible_resolutions) 116 | return width // patch_size, height // patch_size 117 | 118 | 119 | def process_anyres_image(image, processor, grid_pinpoints): 120 | """ 121 | Process an image with variable resolutions. 122 | 123 | Args: 124 | image (PIL.Image.Image): The input image to be processed. 125 | processor: The image processor object. 126 | grid_pinpoints (str): A string representation of a list of possible resolutions. 127 | 128 | Returns: 129 | torch.Tensor: A tensor containing the processed image patches. 
130 | """ 131 | if type(grid_pinpoints) is list: 132 | possible_resolutions = grid_pinpoints 133 | else: 134 | possible_resolutions = ast.literal_eval(grid_pinpoints) 135 | best_resolution = select_best_resolution(image.size, possible_resolutions) 136 | image_padded = resize_and_pad_image(image, best_resolution) 137 | 138 | patches = divide_to_patches(image_padded, processor.crop_size['height']) 139 | 140 | image_original_resize = image.resize((processor.size['shortest_edge'], processor.size['shortest_edge'])) 141 | 142 | image_patches = [image_original_resize] + patches 143 | image_patches = [processor.preprocess(image_patch, return_tensors='pt')['pixel_values'][0] 144 | for image_patch in image_patches] 145 | return torch.stack(image_patches, dim=0) 146 | 147 | 148 | def load_image_from_base64(image): 149 | return Image.open(BytesIO(base64.b64decode(image))) 150 | 151 | 152 | def expand2square(pil_img, background_color): 153 | width, height = pil_img.size 154 | if width == height: 155 | return pil_img 156 | elif width > height: 157 | result = Image.new(pil_img.mode, (width, width), background_color) 158 | result.paste(pil_img, (0, (width - height) // 2)) 159 | return result 160 | else: 161 | result = Image.new(pil_img.mode, (height, height), background_color) 162 | result.paste(pil_img, ((height - width) // 2, 0)) 163 | return result 164 | 165 | 166 | def process_images(images, image_processor, model_cfg): 167 | image_aspect_ratio = getattr(model_cfg, "image_aspect_ratio", None) 168 | new_images = [] 169 | if image_aspect_ratio == 'pad': 170 | for image in images: 171 | image = expand2square(image, tuple(int(x*255) for x in image_processor.image_mean)) 172 | image = image_processor.preprocess(image, return_tensors='pt')['pixel_values'][0] 173 | new_images.append(image) 174 | elif image_aspect_ratio == "anyres": 175 | for image in images: 176 | image = process_anyres_image(image, image_processor, model_cfg.image_grid_pinpoints) 177 | new_images.append(image) 178 | elif image_aspect_ratio == "resize": 179 | for image in images: 180 | image = image.resize(( 181 | image_processor.size['shortest_edge'], 182 | image_processor.size['shortest_edge'], 183 | )) 184 | image = image_processor.preprocess(image, return_tensors='pt')['pixel_values'][0] 185 | new_images.append(image) 186 | else: 187 | return image_processor(images, return_tensors='pt')['pixel_values'] 188 | if all(x.shape == new_images[0].shape for x in new_images): 189 | new_images = torch.stack(new_images, dim=0) 190 | return new_images 191 | 192 | 193 | def tokenizer_image_token(prompt, tokenizer, image_token_index=IMAGE_TOKEN_INDEX, return_tensors=None): 194 | prompt_chunks = [tokenizer(chunk).input_ids for chunk in prompt.split('')] 195 | 196 | def insert_separator(X, sep): 197 | return [ele for sublist in zip(X, [sep]*len(X)) for ele in sublist][:-1] 198 | 199 | input_ids = [] 200 | offset = 0 201 | if len(prompt_chunks) > 0 and len(prompt_chunks[0]) > 0 and prompt_chunks[0][0] == tokenizer.bos_token_id: 202 | offset = 1 203 | input_ids.append(prompt_chunks[0][0]) 204 | 205 | for x in insert_separator(prompt_chunks, [image_token_index] * (offset + 1)): 206 | input_ids.extend(x[offset:]) 207 | 208 | if return_tensors is not None: 209 | if return_tensors == 'pt': 210 | return torch.tensor(input_ids, dtype=torch.long) 211 | raise ValueError(f'Unsupported tensor type: {return_tensors}') 212 | return input_ids 213 | 214 | 215 | def get_model_name_from_path(model_path): 216 | model_path = model_path.strip("/") 217 | model_paths = 
model_path.split("/") 218 | if model_paths[-1].startswith('checkpoint-'): 219 | return model_paths[-2] + "_" + model_paths[-1] 220 | else: 221 | return model_paths[-1] 222 | 223 | class KeywordsStoppingCriteria(StoppingCriteria): 224 | def __init__(self, keywords, tokenizer, input_ids): 225 | self.keywords = keywords 226 | self.keyword_ids = [] 227 | self.max_keyword_len = 0 228 | for keyword in keywords: 229 | cur_keyword_ids = tokenizer(keyword).input_ids 230 | if len(cur_keyword_ids) > 1 and cur_keyword_ids[0] == tokenizer.bos_token_id: 231 | cur_keyword_ids = cur_keyword_ids[1:] 232 | if len(cur_keyword_ids) > self.max_keyword_len: 233 | self.max_keyword_len = len(cur_keyword_ids) 234 | self.keyword_ids.append(torch.tensor(cur_keyword_ids)) 235 | self.tokenizer = tokenizer 236 | self.start_len = input_ids.shape[1] 237 | 238 | def call_for_batch(self, output_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool: 239 | offset = min(output_ids.shape[1] - self.start_len, self.max_keyword_len) 240 | self.keyword_ids = [keyword_id.to(output_ids.device) for keyword_id in self.keyword_ids] 241 | for keyword_id in self.keyword_ids: 242 | truncated_output_ids = output_ids[0, -keyword_id.shape[0]:] 243 | if torch.equal(truncated_output_ids, keyword_id): 244 | return True 245 | outputs = self.tokenizer.batch_decode(output_ids[:, -offset:], skip_special_tokens=True)[0] 246 | for keyword in self.keywords: 247 | if keyword in outputs: 248 | return True 249 | return False 250 | 251 | def __call__(self, output_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool: 252 | outputs = [] 253 | for i in range(output_ids.shape[0]): 254 | outputs.append(self.call_for_batch(output_ids[i].unsqueeze(0), scores)) 255 | return all(outputs) 256 | -------------------------------------------------------------------------------- /free_video_llm/llava/model/__init__.py: -------------------------------------------------------------------------------- 1 | try: 2 | from .language_model.llava_llama import LlavaLlamaForCausalLM, LlavaConfig 3 | from .language_model.llava_mpt import LlavaMptForCausalLM, LlavaMptConfig 4 | from .language_model.llava_mistral import LlavaMistralForCausalLM, LlavaMistralConfig 5 | except: 6 | pass 7 | -------------------------------------------------------------------------------- /free_video_llm/llava/model/apply_delta.py: -------------------------------------------------------------------------------- 1 | """ 2 | Usage: 3 | python3 -m fastchat.model.apply_delta --base ~/model_weights/llama-7b --target ~/model_weights/vicuna-7b --delta lmsys/vicuna-7b-delta 4 | """ 5 | import argparse 6 | 7 | import torch 8 | from tqdm import tqdm 9 | from transformers import AutoTokenizer, AutoModelForCausalLM 10 | from llava import LlavaLlamaForCausalLM 11 | 12 | 13 | def apply_delta(base_model_path, target_model_path, delta_path): 14 | print("Loading base model") 15 | base = AutoModelForCausalLM.from_pretrained( 16 | base_model_path, torch_dtype=torch.float16, low_cpu_mem_usage=True) 17 | 18 | print("Loading delta") 19 | delta = LlavaLlamaForCausalLM.from_pretrained(delta_path, torch_dtype=torch.float16, low_cpu_mem_usage=True) 20 | delta_tokenizer = AutoTokenizer.from_pretrained(delta_path) 21 | 22 | print("Applying delta") 23 | for name, param in tqdm(delta.state_dict().items(), desc="Applying delta"): 24 | if name not in base.state_dict(): 25 | assert name in ['model.mm_projector.weight', 'model.mm_projector.bias'], f'{name} not in base model' 26 | continue 27 | if 
param.data.shape == base.state_dict()[name].shape: 28 | param.data += base.state_dict()[name] 29 | else: 30 | assert name in ['model.embed_tokens.weight', 'lm_head.weight'], \ 31 | f'{name} dimension mismatch: {param.data.shape} vs {base.state_dict()[name].shape}' 32 | bparam = base.state_dict()[name] 33 | param.data[:bparam.shape[0], :bparam.shape[1]] += bparam 34 | 35 | print("Saving target model") 36 | delta.save_pretrained(target_model_path) 37 | delta_tokenizer.save_pretrained(target_model_path) 38 | 39 | 40 | if __name__ == "__main__": 41 | parser = argparse.ArgumentParser() 42 | parser.add_argument("--base-model-path", type=str, required=True) 43 | parser.add_argument("--target-model-path", type=str, required=True) 44 | parser.add_argument("--delta-path", type=str, required=True) 45 | 46 | args = parser.parse_args() 47 | 48 | apply_delta(args.base_model_path, args.target_model_path, args.delta_path) 49 | -------------------------------------------------------------------------------- /free_video_llm/llava/model/builder.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 Haotian Liu 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | 16 | import os 17 | import warnings 18 | import shutil 19 | 20 | from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig, BitsAndBytesConfig 21 | import torch 22 | from llava.model import * 23 | from llava.constants import DEFAULT_IMAGE_PATCH_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN 24 | 25 | 26 | def load_pretrained_model(model_path, model_base, model_name, load_8bit=False, load_4bit=False, device_map="auto", device="cuda", use_flash_attn=False, **kwargs): 27 | kwargs = {"device_map": device_map, **kwargs} 28 | 29 | if device != "cuda": 30 | kwargs['device_map'] = {"": device} 31 | 32 | if load_8bit: 33 | kwargs['load_in_8bit'] = True 34 | elif load_4bit: 35 | kwargs['load_in_4bit'] = True 36 | kwargs['quantization_config'] = BitsAndBytesConfig( 37 | load_in_4bit=True, 38 | bnb_4bit_compute_dtype=torch.float16, 39 | bnb_4bit_use_double_quant=True, 40 | bnb_4bit_quant_type='nf4' 41 | ) 42 | else: 43 | kwargs['torch_dtype'] = torch.float16 44 | 45 | if use_flash_attn: 46 | kwargs['attn_implementation'] = 'flash_attention_2' 47 | 48 | if 'llava' in model_name.lower(): 49 | # Load LLaVA model 50 | if 'lora' in model_name.lower() and model_base is None: 51 | warnings.warn('There is `lora` in model name but no `model_base` is provided. If you are loading a LoRA model, please provide the `model_base` argument. 
Detailed instruction: https://github.com/haotian-liu/LLaVA#launch-a-model-worker-lora-weights-unmerged.') 52 | if 'lora' in model_name.lower() and model_base is not None: 53 | from llava.model.language_model.llava_llama import LlavaConfig 54 | lora_cfg_pretrained = LlavaConfig.from_pretrained(model_path) 55 | tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=False) 56 | print('Loading LLaVA from base model...') 57 | model = LlavaLlamaForCausalLM.from_pretrained(model_base, low_cpu_mem_usage=True, config=lora_cfg_pretrained, **kwargs) 58 | token_num, tokem_dim = model.lm_head.out_features, model.lm_head.in_features 59 | if model.lm_head.weight.shape[0] != token_num: 60 | model.lm_head.weight = torch.nn.Parameter(torch.empty(token_num, tokem_dim, device=model.device, dtype=model.dtype)) 61 | model.model.embed_tokens.weight = torch.nn.Parameter(torch.empty(token_num, tokem_dim, device=model.device, dtype=model.dtype)) 62 | 63 | print('Loading additional LLaVA weights...') 64 | if os.path.exists(os.path.join(model_path, 'non_lora_trainables.bin')): 65 | non_lora_trainables = torch.load(os.path.join(model_path, 'non_lora_trainables.bin'), map_location='cpu') 66 | else: 67 | # this is probably from HF Hub 68 | from huggingface_hub import hf_hub_download 69 | def load_from_hf(repo_id, filename, subfolder=None): 70 | cache_file = hf_hub_download( 71 | repo_id=repo_id, 72 | filename=filename, 73 | subfolder=subfolder) 74 | return torch.load(cache_file, map_location='cpu') 75 | non_lora_trainables = load_from_hf(model_path, 'non_lora_trainables.bin') 76 | non_lora_trainables = {(k[11:] if k.startswith('base_model.') else k): v for k, v in non_lora_trainables.items()} 77 | if any(k.startswith('model.model.') for k in non_lora_trainables): 78 | non_lora_trainables = {(k[6:] if k.startswith('model.') else k): v for k, v in non_lora_trainables.items()} 79 | model.load_state_dict(non_lora_trainables, strict=False) 80 | 81 | from peft import PeftModel 82 | print('Loading LoRA weights...') 83 | model = PeftModel.from_pretrained(model, model_path) 84 | print('Merging LoRA weights...') 85 | model = model.merge_and_unload() 86 | print('Model is loaded...') 87 | elif model_base is not None: 88 | # this may be mm projector only 89 | print('Loading LLaVA from base model...') 90 | if 'mpt' in model_name.lower(): 91 | if not os.path.isfile(os.path.join(model_path, 'configuration_mpt.py')): 92 | shutil.copyfile(os.path.join(model_base, 'configuration_mpt.py'), os.path.join(model_path, 'configuration_mpt.py')) 93 | tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=True) 94 | cfg_pretrained = AutoConfig.from_pretrained(model_path, trust_remote_code=True) 95 | model = LlavaMptForCausalLM.from_pretrained(model_base, low_cpu_mem_usage=True, config=cfg_pretrained, **kwargs) 96 | else: 97 | tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=False) 98 | cfg_pretrained = AutoConfig.from_pretrained(model_path) 99 | model = LlavaLlamaForCausalLM.from_pretrained(model_base, low_cpu_mem_usage=True, config=cfg_pretrained, **kwargs) 100 | 101 | mm_projector_weights = torch.load(os.path.join(model_path, 'mm_projector.bin'), map_location='cpu') 102 | mm_projector_weights = {k: v.to(torch.float16) for k, v in mm_projector_weights.items()} 103 | model.load_state_dict(mm_projector_weights, strict=False) 104 | else: 105 | if 'mpt' in model_name.lower(): 106 | tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=True) 107 | model = LlavaMptForCausalLM.from_pretrained(model_path, 
low_cpu_mem_usage=True, **kwargs) 108 | elif 'mistral' in model_name.lower(): 109 | tokenizer = AutoTokenizer.from_pretrained(model_path) 110 | model = LlavaMistralForCausalLM.from_pretrained( 111 | model_path, 112 | low_cpu_mem_usage=True, 113 | **kwargs 114 | ) 115 | else: 116 | tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False) 117 | cfg_pretrained = AutoConfig.from_pretrained(model_path) 118 | rope_scaling_factor = int(kwargs.pop("rope_scaling_factor", 1)) 119 | if rope_scaling_factor >= 2: 120 | setattr(cfg_pretrained, "rope_scaling", {"factor": float(rope_scaling_factor), "type": "dynamic"}) 121 | setattr(cfg_pretrained, "max_sequence_length", 4096 * rope_scaling_factor) 122 | setattr(cfg_pretrained, "tokenizer_model_max_length", 4096 * rope_scaling_factor) 123 | model = LlavaLlamaForCausalLM.from_pretrained( 124 | model_path, 125 | low_cpu_mem_usage=True, 126 | config=cfg_pretrained, 127 | **kwargs 128 | ) 129 | else: 130 | # Load language model 131 | if model_base is not None: 132 | # PEFT model 133 | from peft import PeftModel 134 | tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=False) 135 | model = AutoModelForCausalLM.from_pretrained(model_base, low_cpu_mem_usage=True, **kwargs) 136 | print(f"Loading LoRA weights from {model_path}") 137 | model = PeftModel.from_pretrained(model, model_path) 138 | print(f"Merging weights") 139 | model = model.merge_and_unload() 140 | print('Convert to FP16...') 141 | model.to(torch.float16) 142 | else: 143 | use_fast = False 144 | if 'mpt' in model_name.lower(): 145 | tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=True) 146 | model = AutoModelForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, trust_remote_code=True, **kwargs) 147 | else: 148 | tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False) 149 | model = AutoModelForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, **kwargs) 150 | 151 | image_processor = None 152 | 153 | if 'llava' in model_name.lower(): 154 | mm_use_im_start_end = getattr(model.config, "mm_use_im_start_end", False) 155 | mm_use_im_patch_token = getattr(model.config, "mm_use_im_patch_token", True) 156 | if mm_use_im_patch_token: 157 | tokenizer.add_tokens([DEFAULT_IMAGE_PATCH_TOKEN], special_tokens=True) 158 | if mm_use_im_start_end: 159 | tokenizer.add_tokens([DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True) 160 | model.resize_token_embeddings(len(tokenizer)) 161 | 162 | vision_tower = model.get_vision_tower() 163 | if not vision_tower.is_loaded: 164 | vision_tower.load_model(device_map=device_map) 165 | if device_map != 'auto': 166 | vision_tower.to(device=device_map, dtype=torch.float16) 167 | image_processor = vision_tower.image_processor 168 | 169 | clip_double_tower = model.get_clip_double_tower() 170 | if not clip_double_tower.is_loaded: 171 | clip_double_tower.load_model(device_map=device_map) 172 | if device_map != 'auto': 173 | clip_double_tower.to(device=device_map, dtype=torch.float16) 174 | 175 | if hasattr(model.config, "max_sequence_length"): 176 | context_len = model.config.max_sequence_length 177 | else: 178 | context_len = 2048 179 | 180 | return tokenizer, model, image_processor, context_len 181 | -------------------------------------------------------------------------------- /free_video_llm/llava/model/consolidate.py: -------------------------------------------------------------------------------- 1 | """ 2 | Usage: 3 | python3 -m llava.model.consolidate --src ~/model_weights/llava-7b --dst 
~/model_weights/llava-7b_consolidate 4 | """ 5 | import argparse 6 | 7 | import torch 8 | from transformers import AutoTokenizer, AutoModelForCausalLM 9 | from llava.model import * 10 | from llava.model.utils import auto_upgrade 11 | 12 | 13 | def consolidate_ckpt(src_path, dst_path): 14 | print("Loading model") 15 | auto_upgrade(src_path) 16 | src_model = AutoModelForCausalLM.from_pretrained(src_path, torch_dtype=torch.float16, low_cpu_mem_usage=True) 17 | src_tokenizer = AutoTokenizer.from_pretrained(src_path, use_fast=False) 18 | src_model.save_pretrained(dst_path) 19 | src_tokenizer.save_pretrained(dst_path) 20 | 21 | 22 | if __name__ == "__main__": 23 | parser = argparse.ArgumentParser() 24 | parser.add_argument("--src", type=str, required=True) 25 | parser.add_argument("--dst", type=str, required=True) 26 | 27 | args = parser.parse_args() 28 | 29 | consolidate_ckpt(args.src, args.dst) 30 | -------------------------------------------------------------------------------- /free_video_llm/llava/model/language_model/llava_llama.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 Haotian Liu 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | 16 | from typing import List, Optional, Tuple, Union 17 | 18 | import torch 19 | import torch.nn as nn 20 | 21 | from transformers import AutoConfig, AutoModelForCausalLM, \ 22 | LlamaConfig, LlamaModel, LlamaForCausalLM 23 | 24 | from transformers.modeling_outputs import CausalLMOutputWithPast 25 | from transformers.generation.utils import GenerateOutput 26 | 27 | from ..llava_arch import LlavaMetaModel, LlavaMetaForCausalLM 28 | 29 | 30 | class LlavaConfig(LlamaConfig): 31 | model_type = "llava_llama" 32 | 33 | 34 | class LlavaLlamaModel(LlavaMetaModel, LlamaModel): 35 | config_class = LlavaConfig 36 | 37 | def __init__(self, config: LlamaConfig): 38 | super(LlavaLlamaModel, self).__init__(config) 39 | 40 | 41 | class LlavaLlamaForCausalLM(LlamaForCausalLM, LlavaMetaForCausalLM): 42 | config_class = LlavaConfig 43 | 44 | def __init__(self, config): 45 | super(LlamaForCausalLM, self).__init__(config) 46 | self.model = LlavaLlamaModel(config) 47 | self.pretraining_tp = config.pretraining_tp 48 | self.vocab_size = config.vocab_size 49 | self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) 50 | 51 | # Initialize weights and apply final processing 52 | self.post_init() 53 | 54 | def get_model(self): 55 | return self.model 56 | 57 | def forward( 58 | self, 59 | input_ids: torch.LongTensor = None, 60 | attention_mask: Optional[torch.Tensor] = None, 61 | position_ids: Optional[torch.LongTensor] = None, 62 | past_key_values: Optional[List[torch.FloatTensor]] = None, 63 | inputs_embeds: Optional[torch.FloatTensor] = None, 64 | labels: Optional[torch.LongTensor] = None, 65 | use_cache: Optional[bool] = None, 66 | output_attentions: Optional[bool] = None, 67 | output_hidden_states: Optional[bool] = None, 68 | images: Optional[torch.FloatTensor] 
= None, 69 | image_sizes: Optional[List[List[int]]] = None, 70 | return_dict: Optional[bool] = None, 71 | ) -> Union[Tuple, CausalLMOutputWithPast]: 72 | 73 | if inputs_embeds is None: 74 | ( 75 | input_ids, 76 | position_ids, 77 | attention_mask, 78 | past_key_values, 79 | inputs_embeds, 80 | labels 81 | ) = self.prepare_inputs_labels_for_multimodal( 82 | input_ids, 83 | position_ids, 84 | attention_mask, 85 | past_key_values, 86 | labels, 87 | images, 88 | image_sizes 89 | ) 90 | 91 | return super().forward( 92 | input_ids=input_ids, 93 | attention_mask=attention_mask, 94 | position_ids=position_ids, 95 | past_key_values=past_key_values, 96 | inputs_embeds=inputs_embeds, 97 | labels=labels, 98 | use_cache=use_cache, 99 | output_attentions=output_attentions, 100 | output_hidden_states=output_hidden_states, 101 | return_dict=return_dict 102 | ) 103 | 104 | @torch.no_grad() 105 | def generate( 106 | self, 107 | inputs: Optional[torch.Tensor] = None, 108 | images: Optional[torch.Tensor] = None, 109 | image_sizes: Optional[torch.Tensor] = None, 110 | **kwargs, 111 | ) -> Union[GenerateOutput, torch.LongTensor]: 112 | position_ids = kwargs.pop("position_ids", None) 113 | attention_mask = kwargs.pop("attention_mask", None) 114 | question = kwargs.pop("question", None) 115 | if "inputs_embeds" in kwargs: 116 | raise NotImplementedError("`inputs_embeds` is not supported") 117 | 118 | if images is not None: 119 | if question is not None: 120 | ( 121 | inputs, 122 | position_ids, 123 | attention_mask, 124 | _, 125 | inputs_embeds, 126 | _ 127 | ) = self.prepare_inputs_labels_for_multimodal( 128 | inputs, 129 | position_ids, 130 | attention_mask, 131 | None, 132 | None, 133 | images, 134 | image_sizes=image_sizes, 135 | temporal_aggregation=kwargs.pop("temporal_aggregation", None), 136 | question=question, 137 | ) 138 | else: 139 | ( 140 | inputs, 141 | position_ids, 142 | attention_mask, 143 | _, 144 | inputs_embeds, 145 | _ 146 | ) = self.prepare_inputs_labels_for_multimodal( 147 | inputs, 148 | position_ids, 149 | attention_mask, 150 | None, 151 | None, 152 | images, 153 | image_sizes=image_sizes, 154 | temporal_aggregation=kwargs.pop("temporal_aggregation", None), 155 | ) 156 | else: 157 | inputs_embeds = self.get_model().embed_tokens(inputs) 158 | 159 | return super().generate( 160 | position_ids=position_ids, 161 | attention_mask=attention_mask, 162 | inputs_embeds=inputs_embeds, 163 | **kwargs 164 | ) 165 | 166 | def prepare_inputs_for_generation(self, input_ids, past_key_values=None, 167 | inputs_embeds=None, **kwargs): 168 | images = kwargs.pop("images", None) 169 | image_sizes = kwargs.pop("image_sizes", None) 170 | inputs = super().prepare_inputs_for_generation( 171 | input_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, **kwargs 172 | ) 173 | # Fix the bug according to "https://github.com/haotian-liu/LLaVA/issues/1448" 174 | inputs.pop("cache_position") 175 | if images is not None: 176 | inputs['images'] = images 177 | if image_sizes is not None: 178 | inputs['image_sizes'] = image_sizes 179 | return inputs 180 | 181 | AutoConfig.register("llava_llama", LlavaConfig) 182 | AutoModelForCausalLM.register(LlavaConfig, LlavaLlamaForCausalLM) 183 | -------------------------------------------------------------------------------- /free_video_llm/llava/model/llava_arch.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------- 2 | # Based on the code by Haotian liu, 2020 and Apple 2024 3 | # 
-------------------------------------------------- 4 | from abc import ABC, abstractmethod 5 | import math 6 | import re 7 | 8 | import torch 9 | import torch.nn as nn 10 | import torch.nn.functional as F 11 | 12 | from .multimodal_encoder.builder import build_vision_tower 13 | from .multimodal_projector.builder import build_vision_projector 14 | from .multimodal_encoder.clip_encoder import CLIPDoubleTower 15 | 16 | from llava.constants import IGNORE_INDEX, IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_PATCH_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN 17 | 18 | from llava.mm_utils import get_anyres_image_grid_shape 19 | 20 | from einops import rearrange 21 | 22 | 23 | class LlavaMetaModel: 24 | 25 | def __init__(self, config): 26 | super(LlavaMetaModel, self).__init__(config) 27 | 28 | if hasattr(config, "mm_vision_tower"): 29 | self.vision_tower = build_vision_tower(config, delay_load=True) 30 | self.mm_projector = build_vision_projector(config) 31 | model_name = getattr(config, 'mm_vision_tower', getattr(config, 'vision_tower', None)) 32 | self.clip_double_tower = CLIPDoubleTower(model_name, delay_load=True) 33 | 34 | if 'unpad' in getattr(config, 'mm_patch_merge_type', ''): 35 | self.image_newline = nn.Parameter( 36 | torch.empty(config.hidden_size, dtype=self.dtype) 37 | ) 38 | 39 | def get_clip_double_tower(self): 40 | clip_double_tower = getattr(self, 'clip_double_tower', None) 41 | return clip_double_tower 42 | 43 | def get_vision_tower(self): 44 | vision_tower = getattr(self, 'vision_tower', None) 45 | if type(vision_tower) is list: 46 | vision_tower = vision_tower[0] 47 | return vision_tower 48 | 49 | def initialize_vision_modules(self, model_args, fsdp=None): 50 | vision_tower = model_args.vision_tower 51 | mm_vision_select_layer = model_args.mm_vision_select_layer 52 | mm_vision_select_feature = model_args.mm_vision_select_feature 53 | pretrain_mm_mlp_adapter = model_args.pretrain_mm_mlp_adapter 54 | mm_patch_merge_type = model_args.mm_patch_merge_type 55 | 56 | self.config.mm_vision_tower = vision_tower 57 | 58 | if self.get_vision_tower() is None: 59 | vision_tower = build_vision_tower(model_args) 60 | 61 | if fsdp is not None and len(fsdp) > 0: 62 | self.vision_tower = [vision_tower] 63 | else: 64 | self.vision_tower = vision_tower 65 | else: 66 | if fsdp is not None and len(fsdp) > 0: 67 | vision_tower = self.vision_tower[0] 68 | else: 69 | vision_tower = self.vision_tower 70 | vision_tower.load_model() 71 | 72 | self.config.use_mm_proj = True 73 | self.config.mm_projector_type = getattr(model_args, 'mm_projector_type', 'linear') 74 | self.config.mm_hidden_size = vision_tower.hidden_size 75 | self.config.mm_vision_select_layer = mm_vision_select_layer 76 | self.config.mm_vision_select_feature = mm_vision_select_feature 77 | self.config.mm_patch_merge_type = mm_patch_merge_type 78 | 79 | if getattr(self, 'mm_projector', None) is None: 80 | self.mm_projector = build_vision_projector(self.config) 81 | 82 | if 'unpad' in mm_patch_merge_type: 83 | embed_std = 1 / torch.sqrt(torch.tensor(self.config.hidden_size, dtype=self.dtype)) 84 | self.image_newline = nn.Parameter( 85 | torch.randn(self.config.hidden_size, dtype=self.dtype) * embed_std 86 | ) 87 | else: 88 | # In case it is frozen by LoRA 89 | for p in self.mm_projector.parameters(): 90 | p.requires_grad = True 91 | 92 | if pretrain_mm_mlp_adapter is not None: 93 | mm_projector_weights = torch.load(pretrain_mm_mlp_adapter, map_location='cpu') 94 | def get_w(weights, keyword): 95 | return {k.split(keyword + '.')[1]: v for k, v in 
weights.items() if keyword in k} 96 | 97 | self.mm_projector.load_state_dict(get_w(mm_projector_weights, 'mm_projector')) 98 | 99 | 100 | def unpad_image(tensor, original_size): 101 | """ 102 | Unpads a PyTorch tensor of a padded and resized image. 103 | 104 | Args: 105 | tensor (torch.Tensor): The image tensor, assumed to be in CxHxW format. 106 | original_size (tuple): The original size of PIL image (width, height). 107 | 108 | Returns: 109 | torch.Tensor: The unpadded image tensor. 110 | """ 111 | original_width, original_height = original_size 112 | current_height, current_width = tensor.shape[1:] 113 | 114 | original_aspect_ratio = original_width / original_height 115 | current_aspect_ratio = current_width / current_height 116 | 117 | if original_aspect_ratio > current_aspect_ratio: 118 | scale_factor = current_width / original_width 119 | new_height = int(original_height * scale_factor) 120 | padding = (current_height - new_height) // 2 121 | unpadded_tensor = tensor[:, padding:current_height - padding, :] 122 | else: 123 | scale_factor = current_height / original_height 124 | new_width = int(original_width * scale_factor) 125 | padding = (current_width - new_width) // 2 126 | unpadded_tensor = tensor[:, :, padding:current_width - padding] 127 | 128 | return unpadded_tensor 129 | 130 | 131 | class LlavaMetaForCausalLM(ABC): 132 | 133 | @abstractmethod 134 | def get_model(self): 135 | pass 136 | 137 | def get_vision_tower(self): 138 | return self.get_model().get_vision_tower() 139 | 140 | def get_clip_double_tower(self): 141 | return self.get_model().get_clip_double_tower() 142 | 143 | def encode_images(self, images): 144 | image_features = self.get_model().get_vision_tower()(images) 145 | image_features = self.get_model().mm_projector(image_features) 146 | return image_features 147 | 148 | def roi_box(self, points, h, w, ratio=0.5): 149 | x_list = points % w 150 | y_list = points // w 151 | x_mean = x_list.float().mean(dim=-1).to(torch.int32) 152 | y_mean = y_list.float().mean(dim=-1).to(torch.int32) 153 | h_new, w_new = round(h * math.sqrt(ratio)), round(w * math.sqrt(ratio)) 154 | left, right = x_mean - w_new // 2, x_mean + w_new - w_new // 2 155 | top, bottom = y_mean - h_new // 2, y_mean + h_new - h_new // 2 156 | for i in range(left.shape[0]): 157 | if left[i] < 0: 158 | left[i], right[i] = 0, w_new 159 | if right[i] > w: 160 | left[i], right[i] = w - w_new, w 161 | if top[i] < 0: 162 | top[i], bottom[i] = 0, h_new 163 | if bottom[i] > h: 164 | top[i], bottom[i] = h - h_new, h 165 | return left, right, top, bottom 166 | 167 | def temporal_aggregation(self, image_features, temporal_aggregation, adaptive_size=None): 168 | T, N, D = image_features.shape 169 | 170 | if temporal_aggregation == "concat": 171 | ## temporal cat 172 | image_features = image_features.view(T * N, D) 173 | elif temporal_aggregation == "spatial_1d_max_pool": 174 | ## horizontal max pool + temporal cat 175 | pool2 = nn.MaxPool1d(kernel_size=2, stride=2) 176 | image_features = rearrange(image_features, 't n d -> t d n') 177 | image_features = pool2(image_features) 178 | image_features = rearrange(image_features, 't d n -> t n d', t=T) 179 | image_features = image_features.view(-1, D) 180 | elif temporal_aggregation == "vertical_1d_max_pool": 181 | ## spatial max pool + temporal cat 182 | n0 = n1 = int(math.sqrt(N)) 183 | pool2 = nn.MaxPool2d(kernel_size=(2, 1), stride=(2, 1)) 184 | image_features = rearrange(image_features, 't (n0 n1) d -> d t n0 n1', n0=n0, n1=n1) 185 | image_features = pool2(image_features) 
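# Shape sketch for the pooling above (assuming, as in the other branches, that N is a perfect square): the
# [T, N, D] input was rearranged to [D, T, n0, n1] with n0 = n1 = sqrt(N), and MaxPool2d((2, 1)) halves only
# the vertical token axis, giving [D, T, n0/2, n1]; the rearrange on the next line flattens this back to a
# [(T * n0/2 * n1), D] token sequence.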
186 | image_features = rearrange(image_features, 'd t n0 n1 -> (t n0 n1) d') 187 | elif temporal_aggregation == "spatial_1d_avg_pool": 188 | ## horizontal avg pool + temporal cat 189 | pool2 = nn.AvgPool1d(kernel_size=2, stride=2) 190 | image_features = rearrange(image_features, 't n d -> t d n') 191 | image_features = pool2(image_features) 192 | image_features = rearrange(image_features, 't d n -> t n d', t=T) 193 | image_features = image_features.view(-1, D) 194 | elif temporal_aggregation == "spatial_2d_max_pool": 195 | ## spatial max pool + temporal cat 196 | n0 = n1 = int(math.sqrt(N)) 197 | pool2 = nn.MaxPool2d(kernel_size=2, stride=2) 198 | image_features = rearrange(image_features, 't (n0 n1) d -> d t n0 n1', n0=n0, n1=n1) 199 | image_features = pool2(image_features) 200 | image_features = rearrange(image_features, 'd t n0 n1 -> (t n0 n1) d') 201 | elif temporal_aggregation == "spatial_2d_avg_pool": 202 | ## spatial avg pool + temporal cat 203 | n0 = n1 = int(math.sqrt(N)) 204 | pool2 = nn.AvgPool2d(kernel_size=2, stride=2) 205 | image_features = rearrange(image_features, 't (n0 n1) d -> d t n0 n1', n0=n0, n1=n1) 206 | image_features = pool2(image_features) 207 | image_features = rearrange(image_features, 'd t n0 n1 -> (t n0 n1) d') 208 | elif temporal_aggregation == "adaptive_2d_pool": 209 | ## Adaptive avg pool + temporal cat 210 | n0 = n1 = int(math.sqrt(N)) 211 | pool2 = nn.AdaptiveAvgPool2d(adaptive_size) 212 | image_features = rearrange(image_features, 't (n0 n1) d -> d t n0 n1', n0=n0, n1=n1) # [T N D] -> [D T N0 N1] 213 | image_features = pool2(image_features) 214 | image_features = rearrange(image_features, 'd t n0 n1 -> (t n0 n1) d') # [D T N0/2 N1/2] -> [-1 D] 215 | elif temporal_aggregation == "spatial_temporal_pool": 216 | ## spatial pool + temporal pool 217 | pooling_size = (16, 12, 12) 218 | n0 = n1 = int(math.sqrt(N)) 219 | pool3 = nn.AdaptiveAvgPool3d(pooling_size) 220 | image_features = rearrange(image_features, 't (n0 n1) d -> d t n0 n1', n0=n0, n1=n1) 221 | image_features = pool3(image_features) 222 | image_features = rearrange(image_features, 'd t n0 n1 -> (t n0 n1) d') 223 | elif temporal_aggregation == "temporal_global_pool": 224 | ## temporal pool 225 | image_features = torch.mean(image_features, dim=0) 226 | else: 227 | raise ValueError(f'Unknown temporal aggregation method: {temporal_aggregation}') 228 | 229 | image_features = image_features.unsqueeze(0) 230 | return image_features 231 | 232 | def text_guided_frames(self, question, images, num_textguided): 233 | text_pooled, image_pooled = self.get_clip_double_tower().get_pooled_features(question, images) 234 | cos = F.cosine_similarity(text_pooled, image_pooled, dim=-1) 235 | _, sorted_idxs = torch.sort(cos, dim=-1) 236 | select_idxs = sorted_idxs[-num_textguided:] 237 | textguided_idx = select_idxs.to(torch.int32).sort()[0].tolist() 238 | return textguided_idx 239 | 240 | def top_change_frames(self, image_features, num_selected): 241 | global_features = torch.mean(image_features, dim=1) # [T,D] 242 | cos = F.cosine_similarity(global_features[:-1,:], global_features[1:,:], dim=-1) 243 | _, selected_idxs = torch.topk(cos, num_selected-1, largest=False, sorted=False) 244 | change_idxs = [0] 245 | change_idxs += (selected_idxs+1).to(torch.int32).sort()[0].tolist() 246 | return change_idxs 247 | 248 | def roi_crop(self, question, images, image_features, temporal_aggregation, output_size, roi_ratio): 249 | h, w = output_size 250 | text_pooled, _, _, image_finegrained = 
self.get_clip_double_tower().get_finegrained_features(question, images) 251 | image_finegrained = self.get_clip_double_tower().clip_model.visual_projection(image_finegrained[:,1:]) 252 | image_finegrained = self.temporal_aggregation( 253 | image_finegrained, 254 | temporal_aggregation, 255 | output_size, 256 | ) 257 | image_finegrained = rearrange(image_finegrained, 'b (t n) d -> (b t) n d', t=images.shape[0]) 258 | 259 | cos = F.cosine_similarity(text_pooled, image_finegrained, dim=-1) 260 | _, top_idxs = torch.topk(cos, int(h*w*roi_ratio), dim=1, largest=True, sorted=False) 261 | top_idxs = top_idxs.to(torch.int32).sort(dim=1)[0] 262 | left, right, top, bottom = self.roi_box(top_idxs, h, w, roi_ratio) 263 | 264 | image_features = rearrange(image_features, 'b (t h w) d -> b t h w d', t=images.shape[0], h=h, w=w) 265 | select_features = [] 266 | for i in range(images.shape[0]): 267 | select_features += [image_features[:, i, top[i]:bottom[i], left[i]:right[i]]] 268 | select_features = torch.stack(select_features, dim=1) 269 | select_features = rearrange(select_features, 'b t h w d -> b (t h w) d') 270 | return select_features 271 | 272 | def prepare_visual_aggregation(self, image_features, temporal_aggregation, question, images): 273 | T, N, D = image_features.shape 274 | 275 | # Example: temporal_aggregation = "slow_3frms_spatial_1d_max_pool_roi6-middle_3frms_spatial_1d_max_pool_24x12-fast_50frms_4x4" 276 | visual_aggregation_match = re.match(r'^slow_(\d+)frms_(\w+)_roi(\d+)-middle_(\d+)frms_(\w+)_(\d+)x(\d+)-fast_(\d+)frms_(\d+)x(\d+)$', temporal_aggregation) 277 | 278 | if not visual_aggregation_match: 279 | raise ValueError(f'Failed to parse the temporal aggregation: {temporal_aggregation}') 280 | num_slowpath = int(visual_aggregation_match.group(1)) 281 | slowpath_temporal_aggregation = visual_aggregation_match.group(2) 282 | roi_ratio = int(visual_aggregation_match.group(3)) / 10.0 283 | middle_frames = int(visual_aggregation_match.group(4)) 284 | middle_temporal_aggregation = visual_aggregation_match.group(5) 285 | middle_output_size = ( 286 | int(visual_aggregation_match.group(6)), 287 | int(visual_aggregation_match.group(7)), 288 | ) 289 | fast_frames = int(visual_aggregation_match.group(8)) 290 | fastpath_output_size = ( 291 | int(visual_aggregation_match.group(9)), 292 | int(visual_aggregation_match.group(10)), 293 | ) 294 | 295 | ##### text-guided images 296 | if num_slowpath < 1: 297 | slowpath_features = None 298 | else: 299 | textguided_idx = self.text_guided_frames(question, images, num_slowpath) 300 | slowpath_features = self.temporal_aggregation( 301 | image_features[textguided_idx], 302 | slowpath_temporal_aggregation, 303 | ) 304 | 305 | ###### text-guided roi crop 306 | if roi_ratio < 1: 307 | h, w = int(math.sqrt(N))//2, int(math.sqrt(N)) 308 | slowpath_features = self.roi_crop(question, images[textguided_idx], slowpath_features, slowpath_temporal_aggregation, (h, w), roi_ratio) 309 | 310 | # select top-k change frames 311 | if middle_frames < 1: 312 | middle_features = None 313 | else: 314 | middle_idx = torch.linspace(0, T, middle_frames + 1) 315 | middle_idx = middle_idx.to(torch.int32).tolist() 316 | middle_idx.pop() 317 | middle_features = self.temporal_aggregation( 318 | image_features[middle_idx], 319 | middle_temporal_aggregation, 320 | ) 321 | ###### text-guided roi crop 322 | if roi_ratio < 1: 323 | h, w = int(math.sqrt(N))//2, int(math.sqrt(N)) 324 | middle_features = self.roi_crop(question, images[middle_idx], middle_features, middle_temporal_aggregation, 
(h, w), roi_ratio) 325 | 326 | # Prepare fast pathway 327 | if fast_frames < 1: 328 | fastpath_features = None 329 | else: 330 | if fast_frames < T: 331 | fastpath_idx = self.text_guided_frames(question, images, fast_frames) 332 | fastpath_features = image_features[fastpath_idx] # [T N D] 333 | else: 334 | fastpath_features = image_features 335 | pool2 = nn.AdaptiveAvgPool2d(fastpath_output_size) 336 | n0 = n1 = int(math.sqrt(N)) 337 | fastpath_features = rearrange(fastpath_features, 't (n0 n1) d -> d t n0 n1', n0=n0, n1=n1) # [T N D] -> [D T N0 N1] 338 | fastpath_features = pool2(fastpath_features) 339 | fastpath_features = rearrange(fastpath_features, 'd t n0 n1 -> (t n0 n1) d') # [D T N0/2 N1/2] -> [-1 D] 340 | fastpath_features = fastpath_features.unsqueeze(0) 341 | 342 | feature_list = [x for x in [slowpath_features, middle_features, fastpath_features] if x is not None] 343 | visual_aggregation_features = torch.cat(feature_list, dim=1) 344 | return visual_aggregation_features 345 | 346 | def prepare_inputs_labels_for_multimodal( 347 | self, input_ids, position_ids, attention_mask, past_key_values, labels, 348 | images, image_sizes=None, temporal_aggregation=None, question=None, 349 | ): 350 | vision_tower = self.get_vision_tower() 351 | if vision_tower is None or images is None or input_ids.shape[1] == 1: 352 | return input_ids, position_ids, attention_mask, past_key_values, None, labels 353 | 354 | if type(images) is list or images.ndim == 5: 355 | if type(images) is list: 356 | images = [x.unsqueeze(0) if x.ndim == 3 else x for x in images] 357 | concat_images = torch.cat([image for image in images], dim=0) 358 | image_features = self.encode_images(concat_images) 359 | split_sizes = [image.shape[0] for image in images] 360 | image_features = torch.split(image_features, split_sizes, dim=0) 361 | mm_patch_merge_type = getattr(self.config, 'mm_patch_merge_type', 'flat') 362 | image_aspect_ratio = getattr(self.config, 'image_aspect_ratio', 'square') 363 | if mm_patch_merge_type == 'flat': 364 | image_features = [x.flatten(0, 1) for x in image_features] 365 | elif mm_patch_merge_type.startswith('spatial'): 366 | new_image_features = [] 367 | for image_idx, image_feature in enumerate(image_features): 368 | if image_feature.shape[0] > 1: 369 | base_image_feature = image_feature[0] 370 | image_feature = image_feature[1:] 371 | height = width = self.get_vision_tower().num_patches_per_side 372 | assert height * width == base_image_feature.shape[0] 373 | if image_aspect_ratio == 'anyres': 374 | num_patch_width, num_patch_height = get_anyres_image_grid_shape(image_sizes[image_idx], self.config.image_grid_pinpoints, self.get_vision_tower().config.image_size) 375 | image_feature = image_feature.view(num_patch_height, num_patch_width, height, width, -1) 376 | else: 377 | raise NotImplementedError 378 | if 'unpad' in mm_patch_merge_type: 379 | image_feature = image_feature.permute(4, 0, 2, 1, 3).contiguous() 380 | image_feature = image_feature.flatten(1, 2).flatten(2, 3) 381 | image_feature = unpad_image(image_feature, image_sizes[image_idx]) 382 | image_feature = torch.cat(( 383 | image_feature, 384 | self.model.image_newline[:, None, None].expand(*image_feature.shape[:-1], 1).to(image_feature.device) 385 | ), dim=-1) 386 | image_feature = image_feature.flatten(1, 2).transpose(0, 1) 387 | else: 388 | image_feature = image_feature.permute(0, 2, 1, 3, 4).contiguous() 389 | image_feature = image_feature.flatten(0, 3) 390 | image_feature = torch.cat((base_image_feature, image_feature), dim=0) 391 | 
else: 392 | image_feature = image_feature[0] 393 | if 'unpad' in mm_patch_merge_type: 394 | image_feature = torch.cat(( 395 | image_feature, 396 | self.model.image_newline[None].to(image_feature.device) 397 | ), dim=0) 398 | new_image_features.append(image_feature) 399 | image_features = new_image_features 400 | else: 401 | raise ValueError(f"Unexpected mm_patch_merge_type: {self.config.mm_patch_merge_type}") 402 | else: 403 | image_features = self.encode_images(images) 404 | 405 | if temporal_aggregation and \ 406 | temporal_aggregation.lower() != 'none' and \ 407 | temporal_aggregation.lower() != 'false': 408 | if temporal_aggregation.startswith('slow'): 409 | image_features = self.prepare_visual_aggregation(image_features, temporal_aggregation, question, images) 410 | else: 411 | image_features = self.temporal_aggregation(image_features, temporal_aggregation, question, images) 412 | 413 | # TODO: image start / end is not implemented here to support pretraining. 414 | if getattr(self.config, 'tune_mm_mlp_adapter', False) and getattr(self.config, 'mm_use_im_start_end', False): 415 | raise NotImplementedError 416 | 417 | # Let's just add dummy tensors if they do not exist, 418 | # it is a headache to deal with None all the time. 419 | # But it is not ideal, and if you have a better idea, 420 | # please open an issue / submit a PR, thanks. 421 | _labels = labels 422 | _position_ids = position_ids 423 | _attention_mask = attention_mask 424 | if attention_mask is None: 425 | attention_mask = torch.ones_like(input_ids, dtype=torch.bool) 426 | else: 427 | attention_mask = attention_mask.bool() 428 | if position_ids is None: 429 | position_ids = torch.arange(0, input_ids.shape[1], dtype=torch.long, device=input_ids.device) 430 | if labels is None: 431 | labels = torch.full_like(input_ids, IGNORE_INDEX) 432 | 433 | # remove the padding using attention_mask -- FIXME 434 | _input_ids = input_ids 435 | input_ids = [cur_input_ids[cur_attention_mask] for cur_input_ids, cur_attention_mask in zip(input_ids, attention_mask)] 436 | labels = [cur_labels[cur_attention_mask] for cur_labels, cur_attention_mask in zip(labels, attention_mask)] 437 | 438 | new_input_embeds = [] 439 | new_labels = [] 440 | cur_image_idx = 0 441 | for batch_idx, cur_input_ids in enumerate(input_ids): 442 | num_images = (cur_input_ids == IMAGE_TOKEN_INDEX).sum() 443 | if num_images == 0: 444 | cur_image_features = image_features[cur_image_idx] 445 | cur_input_embeds_1 = self.get_model().embed_tokens(cur_input_ids) 446 | cur_input_embeds = torch.cat([cur_input_embeds_1, cur_image_features[0:0]], dim=0) 447 | new_input_embeds.append(cur_input_embeds) 448 | new_labels.append(labels[batch_idx]) 449 | cur_image_idx += 1 450 | continue 451 | 452 | image_token_indices = [-1] + torch.where(cur_input_ids == IMAGE_TOKEN_INDEX)[0].tolist() + [cur_input_ids.shape[0]] 453 | cur_input_ids_noim = [] 454 | cur_labels = labels[batch_idx] 455 | cur_labels_noim = [] 456 | for i in range(len(image_token_indices) - 1): 457 | cur_input_ids_noim.append(cur_input_ids[image_token_indices[i]+1:image_token_indices[i+1]]) 458 | cur_labels_noim.append(cur_labels[image_token_indices[i]+1:image_token_indices[i+1]]) 459 | split_sizes = [x.shape[0] for x in cur_labels_noim] 460 | cur_input_embeds = self.get_model().embed_tokens(torch.cat(cur_input_ids_noim)) 461 | cur_input_embeds_no_im = torch.split(cur_input_embeds, split_sizes, dim=0) 462 | cur_new_input_embeds = [] 463 | cur_new_labels = [] 464 | 465 | for i in range(num_images + 1): 466 | 
cur_new_input_embeds.append(cur_input_embeds_no_im[i]) 467 | cur_new_labels.append(cur_labels_noim[i]) 468 | if i < num_images: 469 | cur_image_features = image_features[cur_image_idx] 470 | cur_image_idx += 1 471 | cur_new_input_embeds.append(cur_image_features) 472 | cur_new_labels.append(torch.full((cur_image_features.shape[0],), IGNORE_INDEX, device=cur_labels.device, dtype=cur_labels.dtype)) 473 | 474 | cur_new_input_embeds = [x.to(self.device) for x in cur_new_input_embeds] 475 | 476 | cur_new_input_embeds = torch.cat(cur_new_input_embeds) 477 | cur_new_labels = torch.cat(cur_new_labels) 478 | 479 | new_input_embeds.append(cur_new_input_embeds) 480 | new_labels.append(cur_new_labels) 481 | 482 | # Truncate sequences to max length as image embeddings can make the sequence longer 483 | tokenizer_model_max_length = getattr(self.config, 'tokenizer_model_max_length', None) 484 | if tokenizer_model_max_length is not None: 485 | new_input_embeds = [x[:tokenizer_model_max_length] for x in new_input_embeds] 486 | new_labels = [x[:tokenizer_model_max_length] for x in new_labels] 487 | 488 | # Combine them 489 | max_len = max(x.shape[0] for x in new_input_embeds) 490 | batch_size = len(new_input_embeds) 491 | 492 | new_input_embeds_padded = [] 493 | new_labels_padded = torch.full((batch_size, max_len), IGNORE_INDEX, dtype=new_labels[0].dtype, device=new_labels[0].device) 494 | attention_mask = torch.zeros((batch_size, max_len), dtype=attention_mask.dtype, device=attention_mask.device) 495 | position_ids = torch.zeros((batch_size, max_len), dtype=position_ids.dtype, device=position_ids.device) 496 | 497 | for i, (cur_new_embed, cur_new_labels) in enumerate(zip(new_input_embeds, new_labels)): 498 | cur_len = cur_new_embed.shape[0] 499 | if getattr(self.config, 'tokenizer_padding_side', 'right') == "left": 500 | new_input_embeds_padded.append(torch.cat(( 501 | torch.zeros((max_len - cur_len, cur_new_embed.shape[1]), dtype=cur_new_embed.dtype, device=cur_new_embed.device), 502 | cur_new_embed 503 | ), dim=0)) 504 | if cur_len > 0: 505 | new_labels_padded[i, -cur_len:] = cur_new_labels 506 | attention_mask[i, -cur_len:] = True 507 | position_ids[i, -cur_len:] = torch.arange(0, cur_len, dtype=position_ids.dtype, device=position_ids.device) 508 | else: 509 | new_input_embeds_padded.append(torch.cat(( 510 | cur_new_embed, 511 | torch.zeros((max_len - cur_len, cur_new_embed.shape[1]), dtype=cur_new_embed.dtype, device=cur_new_embed.device) 512 | ), dim=0)) 513 | if cur_len > 0: 514 | new_labels_padded[i, :cur_len] = cur_new_labels 515 | attention_mask[i, :cur_len] = True 516 | position_ids[i, :cur_len] = torch.arange(0, cur_len, dtype=position_ids.dtype, device=position_ids.device) 517 | 518 | new_input_embeds = torch.stack(new_input_embeds_padded, dim=0) 519 | 520 | if _labels is None: 521 | new_labels = None 522 | else: 523 | new_labels = new_labels_padded 524 | 525 | if _attention_mask is None: 526 | attention_mask = None 527 | else: 528 | attention_mask = attention_mask.to(dtype=_attention_mask.dtype) 529 | 530 | if _position_ids is None: 531 | position_ids = None 532 | 533 | return None, position_ids, attention_mask, past_key_values, new_input_embeds, new_labels 534 | 535 | def initialize_vision_tokenizer(self, model_args, tokenizer): 536 | if model_args.mm_use_im_patch_token: 537 | tokenizer.add_tokens([DEFAULT_IMAGE_PATCH_TOKEN], special_tokens=True) 538 | self.resize_token_embeddings(len(tokenizer)) 539 | 540 | if model_args.mm_use_im_start_end: 541 | num_new_tokens = 
tokenizer.add_tokens([DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True) 542 | self.resize_token_embeddings(len(tokenizer)) 543 | 544 | if num_new_tokens > 0: 545 | input_embeddings = self.get_input_embeddings().weight.data 546 | output_embeddings = self.get_output_embeddings().weight.data 547 | 548 | input_embeddings_avg = input_embeddings[:-num_new_tokens].mean( 549 | dim=0, keepdim=True) 550 | output_embeddings_avg = output_embeddings[:-num_new_tokens].mean( 551 | dim=0, keepdim=True) 552 | 553 | input_embeddings[-num_new_tokens:] = input_embeddings_avg 554 | output_embeddings[-num_new_tokens:] = output_embeddings_avg 555 | 556 | if model_args.tune_mm_mlp_adapter: 557 | for p in self.get_input_embeddings().parameters(): 558 | p.requires_grad = True 559 | for p in self.get_output_embeddings().parameters(): 560 | p.requires_grad = False 561 | 562 | if model_args.pretrain_mm_mlp_adapter: 563 | mm_projector_weights = torch.load(model_args.pretrain_mm_mlp_adapter, map_location='cpu') 564 | embed_tokens_weight = mm_projector_weights['model.embed_tokens.weight'] 565 | assert num_new_tokens == 2 566 | if input_embeddings.shape == embed_tokens_weight.shape: 567 | input_embeddings[-num_new_tokens:] = embed_tokens_weight[-num_new_tokens:] 568 | elif embed_tokens_weight.shape[0] == num_new_tokens: 569 | input_embeddings[-num_new_tokens:] = embed_tokens_weight 570 | else: 571 | raise ValueError(f"Unexpected embed_tokens_weight shape. Pretrained: {embed_tokens_weight.shape}. Current: {input_embeddings.shape}. Numer of new tokens: {num_new_tokens}.") 572 | elif model_args.mm_use_im_patch_token: 573 | if model_args.tune_mm_mlp_adapter: 574 | for p in self.get_input_embeddings().parameters(): 575 | p.requires_grad = False 576 | for p in self.get_output_embeddings().parameters(): 577 | p.requires_grad = False 578 | -------------------------------------------------------------------------------- /free_video_llm/llava/model/make_delta.py: -------------------------------------------------------------------------------- 1 | """ 2 | Usage: 3 | python3 -m llava.model.make_delta --base ~/model_weights/llama-7b --target ~/model_weights/llava-7b --delta ~/model_weights/llava-7b-delta --hub-repo-id liuhaotian/llava-7b-delta 4 | """ 5 | import argparse 6 | 7 | import torch 8 | from tqdm import tqdm 9 | from transformers import AutoTokenizer, AutoModelForCausalLM 10 | from llava.model.utils import auto_upgrade 11 | 12 | 13 | def make_delta(base_model_path, target_model_path, delta_path, hub_repo_id): 14 | print("Loading base model") 15 | base = AutoModelForCausalLM.from_pretrained( 16 | base_model_path, torch_dtype=torch.float16, low_cpu_mem_usage=True) 17 | 18 | print("Loading target model") 19 | auto_upgrade(target_model_path) 20 | target = AutoModelForCausalLM.from_pretrained(target_model_path, torch_dtype=torch.float16, low_cpu_mem_usage=True) 21 | 22 | print("Calculating delta") 23 | for name, param in tqdm(target.state_dict().items(), desc="Calculating delta"): 24 | if name not in base.state_dict(): 25 | assert name in ['model.mm_projector.weight', 'model.mm_projector.bias'], f'{name} not in base model' 26 | continue 27 | if param.data.shape == base.state_dict()[name].shape: 28 | param.data -= base.state_dict()[name] 29 | else: 30 | assert name in ['model.embed_tokens.weight', 'lm_head.weight'], f'{name} dimension mismatch: {param.data.shape} vs {base.state_dict()[name].shape}' 31 | bparam = base.state_dict()[name] 32 | param.data[:bparam.shape[0], :bparam.shape[1]] -= bparam 33 | 34 | 
print("Saving delta") 35 | if hub_repo_id: 36 | kwargs = {"push_to_hub": True, "repo_id": hub_repo_id} 37 | else: 38 | kwargs = {} 39 | target.save_pretrained(delta_path, **kwargs) 40 | target_tokenizer = AutoTokenizer.from_pretrained(target_model_path) 41 | target_tokenizer.save_pretrained(delta_path, **kwargs) 42 | 43 | 44 | if __name__ == "__main__": 45 | parser = argparse.ArgumentParser() 46 | parser.add_argument("--base-model-path", type=str, required=True) 47 | parser.add_argument("--target-model-path", type=str, required=True) 48 | parser.add_argument("--delta-path", type=str, required=True) 49 | parser.add_argument("--hub-repo-id", type=str, default=None) 50 | args = parser.parse_args() 51 | 52 | make_delta(args.base_model_path, args.target_model_path, args.delta_path, args.hub_repo_id) 53 | -------------------------------------------------------------------------------- /free_video_llm/llava/model/multimodal_encoder/builder.py: -------------------------------------------------------------------------------- 1 | import os 2 | from .clip_encoder import CLIPVisionTower, CLIPVisionTowerS2 3 | 4 | 5 | def build_vision_tower(vision_tower_cfg, **kwargs): 6 | vision_tower = getattr(vision_tower_cfg, 'mm_vision_tower', getattr(vision_tower_cfg, 'vision_tower', None)) 7 | is_absolute_path_exists = os.path.exists(vision_tower) 8 | use_s2 = getattr(vision_tower_cfg, 's2', False) 9 | if is_absolute_path_exists or vision_tower.startswith("openai") or vision_tower.startswith("laion") or "ShareGPT4V" in vision_tower: 10 | if use_s2: 11 | return CLIPVisionTowerS2(vision_tower, args=vision_tower_cfg, **kwargs) 12 | else: 13 | return CLIPVisionTower(vision_tower, args=vision_tower_cfg, **kwargs) 14 | 15 | raise ValueError(f'Unknown vision tower: {vision_tower}') 16 | -------------------------------------------------------------------------------- /free_video_llm/llava/model/multimodal_encoder/clip_encoder.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from transformers import CLIPVisionModel, CLIPImageProcessor, CLIPVisionConfig 5 | from transformers import AutoTokenizer, CLIPModel 6 | 7 | 8 | class CLIPDoubleTower(nn.Module): 9 | def __init__(self, model_name, delay_load=False): 10 | super().__init__() 11 | 12 | self.is_loaded = False 13 | self.model_name = model_name 14 | 15 | if not delay_load: 16 | self.load_model() 17 | 18 | def load_model(self, device_map=None): 19 | if self.is_loaded: 20 | print('{} is already loaded, `load_model` called again, skipping.'.format(self.model_name)) 21 | return 22 | 23 | self.tokenizer = AutoTokenizer.from_pretrained(self.model_name, device_map=device_map) 24 | self.clip_model = CLIPModel.from_pretrained(self.model_name, device_map=device_map) 25 | self.clip_model.requires_grad_(False) 26 | 27 | self.is_loaded = True 28 | 29 | @torch.no_grad() 30 | def get_pooled_features(self, question, images): 31 | inputs = self.tokenizer(question, padding=True, return_tensors="pt") 32 | input_ids = inputs['input_ids'].to(device=self.device) 33 | text_features = self.clip_model.get_text_features(input_ids) 34 | image_features = self.clip_model.get_image_features(images) 35 | return text_features, image_features 36 | 37 | @torch.no_grad() 38 | def get_finegrained_features(self, question, images): 39 | inputs = self.tokenizer(question, padding=True, return_tensors="pt") 40 | input_ids = inputs['input_ids'].to(device=self.device) 41 | outputs = self.clip_model(input_ids, images) 42 | 
text_pooled, image_pooled = outputs.text_embeds, outputs.image_embeds 43 | text_finegrained, image_finegrained = outputs.text_model_output.last_hidden_state, outputs.vision_model_output.last_hidden_state 44 | return text_pooled, image_pooled, text_finegrained, image_finegrained 45 | 46 | @property 47 | def dtype(self): 48 | return self.clip_model.dtype 49 | 50 | @property 51 | def device(self): 52 | return self.clip_model.device 53 | 54 | 55 | class CLIPVisionTower(nn.Module): 56 | def __init__(self, vision_tower, args, delay_load=False): 57 | super().__init__() 58 | 59 | self.is_loaded = False 60 | 61 | self.vision_tower_name = vision_tower 62 | self.select_layer = args.mm_vision_select_layer 63 | self.select_feature = getattr(args, 'mm_vision_select_feature', 'patch') 64 | 65 | if not delay_load: 66 | self.load_model() 67 | elif getattr(args, 'unfreeze_mm_vision_tower', False): 68 | self.load_model() 69 | else: 70 | self.cfg_only = CLIPVisionConfig.from_pretrained(self.vision_tower_name) 71 | 72 | def load_model(self, device_map=None): 73 | if self.is_loaded: 74 | print('{} is already loaded, `load_model` called again, skipping.'.format(self.vision_tower_name)) 75 | return 76 | 77 | self.image_processor = CLIPImageProcessor.from_pretrained(self.vision_tower_name) 78 | self.vision_tower = CLIPVisionModel.from_pretrained(self.vision_tower_name, device_map=device_map) 79 | self.vision_tower.requires_grad_(False) 80 | 81 | self.is_loaded = True 82 | 83 | def feature_select(self, image_forward_outs): 84 | image_features = image_forward_outs.hidden_states[self.select_layer] 85 | if self.select_feature == 'patch': 86 | image_features = image_features[:, 1:] 87 | elif self.select_feature == 'cls_patch': 88 | image_features = image_features 89 | else: 90 | raise ValueError(f'Unexpected select feature: {self.select_feature}') 91 | return image_features 92 | 93 | @torch.no_grad() 94 | def forward(self, images): 95 | if type(images) is list: 96 | image_features = [] 97 | for image in images: 98 | image_forward_out = self.vision_tower(image.to(device=self.device, dtype=self.dtype).unsqueeze(0), output_hidden_states=True) 99 | image_feature = self.feature_select(image_forward_out).to(image.dtype) 100 | image_features.append(image_feature) 101 | else: 102 | image_forward_outs = self.vision_tower(images.to(device=self.device, dtype=self.dtype), output_hidden_states=True) 103 | image_features = self.feature_select(image_forward_outs).to(images.dtype) 104 | 105 | return image_features 106 | 107 | @property 108 | def dummy_feature(self): 109 | return torch.zeros(1, self.hidden_size, device=self.device, dtype=self.dtype) 110 | 111 | @property 112 | def dtype(self): 113 | return self.vision_tower.dtype 114 | 115 | @property 116 | def device(self): 117 | return self.vision_tower.device 118 | 119 | @property 120 | def config(self): 121 | if self.is_loaded: 122 | return self.vision_tower.config 123 | else: 124 | return self.cfg_only 125 | 126 | @property 127 | def hidden_size(self): 128 | return self.config.hidden_size 129 | 130 | @property 131 | def num_patches_per_side(self): 132 | return self.config.image_size // self.config.patch_size 133 | 134 | @property 135 | def num_patches(self): 136 | return (self.config.image_size // self.config.patch_size) ** 2 137 | 138 | 139 | 140 | class CLIPVisionTowerS2(CLIPVisionTower): 141 | def __init__(self, vision_tower, args, delay_load=False): 142 | super().__init__(vision_tower, args, delay_load) 143 | 144 | self.s2_scales = getattr(args, 's2_scales', '336,672,1008') 
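# Worked example of how the lines below consume this setting (assuming the default string above):
# '336,672,1008' is split and sorted into [336, 672, 1008], so s2_split_size becomes 336 (the smallest
# scale) and s2_image_size becomes 1008 (the largest scale), which is also used later to reset the
# image processor's crop/shortest-edge sizes.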
145 | self.s2_scales = list(map(int, self.s2_scales.split(','))) 146 | self.s2_scales.sort() 147 | self.s2_split_size = self.s2_scales[0] 148 | self.s2_image_size = self.s2_scales[-1] 149 | 150 | try: 151 | from s2wrapper import forward as multiscale_forward 152 | except ImportError: 153 | raise ImportError('Package s2wrapper not found! Please install by running: \npip install git+https://github.com/bfshi/scaling_on_scales.git') 154 | self.multiscale_forward = multiscale_forward 155 | 156 | # change resize/crop size in preprocessing to the largest image size in s2_scale 157 | if not delay_load or getattr(args, 'unfreeze_mm_vision_tower', False): 158 | self.image_processor.size['shortest_edge'] = self.s2_image_size 159 | self.image_processor.crop_size['height'] = self.image_processor.crop_size['width'] = self.s2_image_size 160 | 161 | def load_model(self, device_map=None): 162 | if self.is_loaded: 163 | print('{} is already loaded, `load_model` called again, skipping.'.format(self.vision_tower_name)) 164 | return 165 | 166 | self.image_processor = CLIPImageProcessor.from_pretrained(self.vision_tower_name) 167 | self.vision_tower = CLIPVisionModel.from_pretrained(self.vision_tower_name, device_map=device_map) 168 | self.vision_tower.requires_grad_(False) 169 | 170 | self.image_processor.size['shortest_edge'] = self.s2_image_size 171 | self.image_processor.crop_size['height'] = self.image_processor.crop_size['width'] = self.s2_image_size 172 | 173 | self.is_loaded = True 174 | 175 | @torch.no_grad() 176 | def forward_feature(self, images): 177 | image_forward_outs = self.vision_tower(images.to(device=self.device, dtype=self.dtype), output_hidden_states=True) 178 | image_features = self.feature_select(image_forward_outs).to(images.dtype) 179 | return image_features 180 | 181 | @torch.no_grad() 182 | def forward(self, images): 183 | if type(images) is list: 184 | image_features = [] 185 | for image in images: 186 | image_feature = self.multiscale_forward(self.forward_feature, image.unsqueeze(0), img_sizes=self.s2_scales, max_split_size=self.s2_split_size) 187 | image_features.append(image_feature) 188 | else: 189 | image_features = self.multiscale_forward(self.forward_feature, images, img_sizes=self.s2_scales, max_split_size=self.s2_split_size) 190 | 191 | return image_features 192 | 193 | @property 194 | def hidden_size(self): 195 | return self.config.hidden_size * len(self.s2_scales) 196 | -------------------------------------------------------------------------------- /free_video_llm/llava/model/multimodal_projector/builder.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import re 4 | 5 | 6 | class IdentityMap(nn.Module): 7 | def __init__(self): 8 | super().__init__() 9 | 10 | def forward(self, x, *args, **kwargs): 11 | return x 12 | 13 | @property 14 | def config(self): 15 | return {"mm_projector_type": 'identity'} 16 | 17 | 18 | class SimpleResBlock(nn.Module): 19 | def __init__(self, channels): 20 | super().__init__() 21 | self.pre_norm = nn.LayerNorm(channels) 22 | 23 | self.proj = nn.Sequential( 24 | nn.Linear(channels, channels), 25 | nn.GELU(), 26 | nn.Linear(channels, channels) 27 | ) 28 | def forward(self, x): 29 | x = self.pre_norm(x) 30 | return x + self.proj(x) 31 | 32 | 33 | def build_vision_projector(config, delay_load=False, **kwargs): 34 | projector_type = getattr(config, 'mm_projector_type', 'linear') 35 | 36 | if projector_type == 'linear': 37 | return nn.Linear(config.mm_hidden_size, 
config.hidden_size) 38 | 39 | mlp_gelu_match = re.match(r'^mlp(\d+)x_gelu$', projector_type) 40 | if mlp_gelu_match: 41 | mlp_depth = int(mlp_gelu_match.group(1)) 42 | modules = [nn.Linear(config.mm_hidden_size, config.hidden_size)] 43 | for _ in range(1, mlp_depth): 44 | modules.append(nn.GELU()) 45 | modules.append(nn.Linear(config.hidden_size, config.hidden_size)) 46 | return nn.Sequential(*modules) 47 | 48 | if projector_type == 'identity': 49 | return IdentityMap() 50 | 51 | raise ValueError(f'Unknown projector type: {projector_type}') 52 | -------------------------------------------------------------------------------- /free_video_llm/llava/model/utils.py: -------------------------------------------------------------------------------- 1 | from transformers import AutoConfig 2 | 3 | 4 | def auto_upgrade(config): 5 | cfg = AutoConfig.from_pretrained(config) 6 | if 'llava' in config and 'llava' not in cfg.model_type: 7 | assert cfg.model_type == 'llama' 8 | print("You are using newer LLaVA code base, while the checkpoint of v0 is from older code base.") 9 | print("You must upgrade the checkpoint to the new code base (this can be done automatically).") 10 | confirm = input("Please confirm that you want to upgrade the checkpoint. [Y/N]") 11 | if confirm.lower() in ["y", "yes"]: 12 | print("Upgrading checkpoint...") 13 | assert len(cfg.architectures) == 1 14 | setattr(cfg.__class__, "model_type", "llava") 15 | cfg.architectures[0] = 'LlavaLlamaForCausalLM' 16 | cfg.save_pretrained(config) 17 | print("Checkpoint upgraded.") 18 | else: 19 | print("Checkpoint upgrade aborted.") 20 | exit(1) 21 | -------------------------------------------------------------------------------- /free_video_llm/llava/utils.py: -------------------------------------------------------------------------------- 1 | import datetime 2 | import logging 3 | import logging.handlers 4 | import os 5 | import sys 6 | 7 | import requests 8 | 9 | from llava.constants import LOGDIR 10 | 11 | server_error_msg = "**NETWORK ERROR DUE TO HIGH TRAFFIC. PLEASE REGENERATE OR REFRESH THIS PAGE.**" 12 | moderation_msg = "YOUR INPUT VIOLATES OUR CONTENT MODERATION GUIDELINES. PLEASE TRY AGAIN." 
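# `handler` (declared just below) is the module-level file handler shared by build_logger(): it is created
# lazily on the first call as a daily-rotating TimedRotatingFileHandler under LOGDIR and then attached to
# every registered logger, so their output is collected in one place.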
13 | 14 | handler = None 15 | 16 | 17 | def build_logger(logger_name, logger_filename): 18 | global handler 19 | 20 | formatter = logging.Formatter( 21 | fmt="%(asctime)s | %(levelname)s | %(name)s | %(message)s", 22 | datefmt="%Y-%m-%d %H:%M:%S", 23 | ) 24 | 25 | # Set the format of root handlers 26 | if not logging.getLogger().handlers: 27 | logging.basicConfig(level=logging.INFO) 28 | logging.getLogger().handlers[0].setFormatter(formatter) 29 | 30 | # Redirect stdout and stderr to loggers 31 | stdout_logger = logging.getLogger("stdout") 32 | stdout_logger.setLevel(logging.INFO) 33 | sl = StreamToLogger(stdout_logger, logging.INFO) 34 | sys.stdout = sl 35 | 36 | stderr_logger = logging.getLogger("stderr") 37 | stderr_logger.setLevel(logging.ERROR) 38 | sl = StreamToLogger(stderr_logger, logging.ERROR) 39 | sys.stderr = sl 40 | 41 | # Get logger 42 | logger = logging.getLogger(logger_name) 43 | logger.setLevel(logging.INFO) 44 | 45 | # Add a file handler for all loggers 46 | if handler is None: 47 | os.makedirs(LOGDIR, exist_ok=True) 48 | filename = os.path.join(LOGDIR, logger_filename) 49 | handler = logging.handlers.TimedRotatingFileHandler( 50 | filename, when='D', utc=True, encoding='UTF-8') 51 | handler.setFormatter(formatter) 52 | 53 | for name, item in logging.root.manager.loggerDict.items(): 54 | if isinstance(item, logging.Logger): 55 | item.addHandler(handler) 56 | 57 | return logger 58 | 59 | 60 | class StreamToLogger(object): 61 | """ 62 | Fake file-like stream object that redirects writes to a logger instance. 63 | """ 64 | def __init__(self, logger, log_level=logging.INFO): 65 | self.terminal = sys.stdout 66 | self.logger = logger 67 | self.log_level = log_level 68 | self.linebuf = '' 69 | 70 | def __getattr__(self, attr): 71 | return getattr(self.terminal, attr) 72 | 73 | def write(self, buf): 74 | temp_linebuf = self.linebuf + buf 75 | self.linebuf = '' 76 | for line in temp_linebuf.splitlines(True): 77 | # From the io.TextIOWrapper docs: 78 | # On output, if newline is None, any '\n' characters written 79 | # are translated to the system default line separator. 80 | # By default sys.stdout.write() expects '\n' newlines and then 81 | # translates them so this is still cross platform. 82 | if line[-1] == '\n': 83 | self.logger.log(self.log_level, line.rstrip()) 84 | else: 85 | self.linebuf += line 86 | 87 | def flush(self): 88 | if self.linebuf != '': 89 | self.logger.log(self.log_level, self.linebuf.rstrip()) 90 | self.linebuf = '' 91 | 92 | 93 | def disable_torch_init(): 94 | """ 95 | Disable the redundant torch default initialization to accelerate model creation. 96 | """ 97 | import torch 98 | setattr(torch.nn.Linear, "reset_parameters", lambda self: None) 99 | setattr(torch.nn.LayerNorm, "reset_parameters", lambda self: None) 100 | 101 | 102 | def violates_moderation(text): 103 | """ 104 | Check whether the text violates OpenAI moderation API. 
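    Uses the OPENAI_API_KEY environment variable for authentication and returns False
    whenever the request fails or the response cannot be parsed.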
105 | """ 106 | url = "https://api.openai.com/v1/moderations" 107 | headers = {"Content-Type": "application/json", 108 | "Authorization": "Bearer " + os.environ["OPENAI_API_KEY"]} 109 | text = text.replace("\n", "") 110 | data = "{" + '"input": ' + f'"{text}"' + "}" 111 | data = data.encode("utf-8") 112 | try: 113 | ret = requests.post(url, headers=headers, data=data, timeout=5) 114 | flagged = ret.json()["results"][0]["flagged"] 115 | except requests.exceptions.RequestException as e: 116 | flagged = False 117 | except KeyError as e: 118 | flagged = False 119 | 120 | return flagged 121 | 122 | 123 | def pretty_print_semaphore(semaphore): 124 | if semaphore is None: 125 | return "None" 126 | return f"Semaphore(value={semaphore._value}, locked={semaphore.locked()})" 127 | -------------------------------------------------------------------------------- /prompt.py: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE file. 3 | # Copyright (C) 2024 Apple Inc. All Rights Reserved. 4 | # 5 | from llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN 6 | from llava.conversation import conv_templates, SeparatorStyle 7 | 8 | 9 | def get_option_prompt(candidates, version="default"): 10 | option_prompt = "" 11 | options = [] 12 | for idx, candidate in enumerate(candidates): 13 | choice = chr(ord("A") + idx) 14 | if version == "v4": 15 | option_prompt += f"({choice}) {candidate}\n" 16 | else: 17 | option_prompt += f"({choice}):{candidate} " 18 | options.append(choice) 19 | options = "(" + ",".join(options) + ")" 20 | return option_prompt, options 21 | 22 | 23 | def get_multiple_choice_prompt(model, conv_mode, question, candidates): 24 | if conv_mode == "multiple_choice_allvideo_v4": 25 | prompt = "You are a helpful expert in video analysis. Select the best option to answer the question. USER: \nThe input consists of a sequence of key frames from a video.\nQuestion: %s\nOptions:\n%sOnly give the best option. \nASSISTANT:\nAnswer: Best option:(" 26 | option_prompt, options = get_option_prompt(candidates, version="v4") 27 | prompt = prompt % (question, option_prompt) 28 | elif conv_mode == "multiple_choice_allvideo_34b_v4": 29 | prompt = "<|im_start|>system\n You are a helpful expert in video analysis. Select the best option to answer the question. <|im_end|>\n<|im_start|>user\n \nThe input consists of a sequence of key frames from a video. Question: %s\nOptions:\n%sOnly give the best option. <|im_end|>\n<|im_start|>assistant\nAnswer: Best option:(" 30 | option_prompt, options = get_option_prompt(candidates, version="v4") 31 | prompt = prompt % (question, option_prompt) 32 | else: 33 | raise ValueError(f"Unknown conv_mode: {conv_mode}") 34 | return prompt 35 | 36 | 37 | def get_prompt(model, conv_mode, question): 38 | if conv_mode == "image_seq_v3": 39 | prompt = "USER: \nThe input consists of a sequence of key frames from a video. Answer concisely with overall content and context of the video, highlighting any significant events, characters, or objects that appear throughout the video. Question: %s \nASSISTANT:\nAnswer: In the video," 40 | prompt = prompt % question 41 | elif conv_mode == "image_seq_34b_v3": 42 | prompt = "<|im_start|>system\n Answer the question. <|im_end|>\n<|im_start|>user\n \nThe input consists of a sequence of key frames from a video. 
Answer concisely with overall content and context of the video, highlighting any significant events, characters, or objects that appear throughout the video. Question: %s <|im_end|>\n<|im_start|>assistant\nAnswer: In the video," 43 | prompt = prompt % question 44 | else: 45 | if model.config.mm_use_im_start_end: 46 | ques = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN + "\n" + question 47 | else: 48 | ques = DEFAULT_IMAGE_TOKEN + "\n" + question 49 | conv = conv_templates[conv_mode].copy() 50 | conv.append_message(conv.roles[0], ques) 51 | conv.append_message(conv.roles[1], None) 52 | prompt = conv.get_prompt() 53 | return prompt 54 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ["setuptools>=61.0"] 3 | build-backend = "setuptools.build_meta" 4 | 5 | [project] 6 | name = "sf_llava" 7 | version = "1.2.2.post1" 8 | description = "A Strong Training-Free Video LLM" 9 | readme = "README.md" 10 | requires-python = ">=3.8" 11 | classifiers = [ 12 | "Programming Language :: Python :: 3", 13 | "License :: OSI Approved :: Apache Software License", 14 | ] 15 | dependencies = [ 16 | "torch==2.2.0", "torchvision==0.17.0", 17 | "transformers==4.38.2", "tokenizers==0.15.1", "sentencepiece==0.1.99", "shortuuid==1.0.13", 18 | "accelerate==0.21.0", "peft==0.4.0", "bitsandbytes==0.41.0", 19 | "pydantic", "markdown2[all]", "numpy==1.26.4", "scikit-learn==1.5.0", 20 | "requests", "httpx==0.24.0", "uvicorn", "fastapi", 21 | "einops==0.6.1", "einops-exts==0.0.4", "timm==0.6.13", 22 | "openai==1.14.3", "peft==0.4.0", "safetensors==0.4.3", 23 | "decord", "opencv-python", "pytorchvideo==0.1.5", 24 | ] 25 | 26 | [project.optional-dependencies] 27 | train = ["deepspeed==0.12.6", "ninja", "wandb"] 28 | build = ["build", "twine"] 29 | 30 | [tool.setuptools.packages.find] 31 | exclude = ["assets*", "benchmark*", "docs", "dist*", "playground*", "scripts*", "tests*"] 32 | 33 | [tool.wheel] 34 | exclude = ["assets*", "benchmark*", "docs", "dist*", "playground*", "scripts*", "tests*"] 35 | -------------------------------------------------------------------------------- /run_demo.py: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE file. 3 | # Copyright (C) 2024 Apple Inc. All Rights Reserved. 
4 | # 5 | import argparse 6 | import os 7 | import sys 8 | from pathlib import Path 9 | sys.path.insert(0, Path(__file__).parent.as_posix()) 10 | sys.path.insert(0, os.path.join(Path(__file__).parent.as_posix(), "free_video_llm")) 11 | import torch 12 | 13 | from llava.constants import IMAGE_TOKEN_INDEX 14 | from llava.model.builder import load_pretrained_model 15 | from llava.utils import disable_torch_init 16 | from llava.mm_utils import tokenizer_image_token, process_images, get_model_name_from_path 17 | 18 | from dataset import load_video 19 | from prompt import get_prompt 20 | 21 | 22 | def llava_inference( 23 | video_frames, 24 | question, 25 | conv_mode, 26 | model, 27 | tokenizer, 28 | image_processor, 29 | image_sizes, 30 | temporal_aggregation, 31 | ): 32 | # Get prompt 33 | prompt = get_prompt(model, conv_mode, question) 34 | 35 | # Get text inputs 36 | input_ids = tokenizer_image_token( 37 | prompt, 38 | tokenizer, 39 | IMAGE_TOKEN_INDEX, 40 | return_tensors="pt", 41 | ).unsqueeze(0).cuda() 42 | 43 | # Get image inputs 44 | image_tensor = process_images(video_frames, image_processor, model.config) 45 | 46 | with torch.inference_mode(): 47 | output_ids = model.generate( 48 | input_ids, 49 | images=image_tensor.to(dtype=torch.float16, device="cuda", non_blocking=True), 50 | image_sizes=image_sizes, 51 | do_sample=False, 52 | temperature=0, 53 | top_p=None, 54 | num_beams=1, 55 | max_new_tokens=256, 56 | use_cache=True, 57 | temporal_aggregation=temporal_aggregation, 58 | question=question, 59 | ) 60 | 61 | outputs = tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0].strip() 62 | return outputs 63 | 64 | 65 | def run_inference(args): 66 | """ 67 | Run inference 68 | 69 | Args: 70 | args: Command-line arguments. 71 | """ 72 | 73 | disable_torch_init() 74 | 75 | # Load tokenizer, model and image processor 76 | model_path = os.path.expanduser(args.model_path) 77 | model_name = get_model_name_from_path(model_path) 78 | tokenizer, model, image_processor, context_len = load_pretrained_model( 79 | model_path, None, model_name, 80 | device=torch.cuda.current_device(), 81 | device_map="cuda", 82 | rope_scaling_factor=args.rope_scaling_factor, 83 | ) 84 | 85 | # Override image aspect ratio if needed 86 | if args.image_aspect_ratio: 87 | model.config.image_aspect_ratio = args.image_aspect_ratio 88 | 89 | # Load video 90 | video_frames, sizes = load_video(args.video_path, num_frms=args.num_frames) 91 | 92 | try: 93 | # Run inference on the video 94 | output = llava_inference( 95 | video_frames, 96 | args.question, 97 | args.conv_mode, 98 | model, 99 | tokenizer, 100 | image_processor, 101 | sizes, 102 | args.temporal_aggregation, 103 | ) 104 | print(f"Question: {args.question}") 105 | print(f"\nAnswer: In this video, {output}") 106 | except Exception as e: 107 | print(f"Error processing video file '{args.video_path}': {e}") 108 | 109 | 110 | def parse_args(): 111 | """ 112 | Parse command-line arguments. 
113 | """ 114 | parser = argparse.ArgumentParser() 115 | parser.add_argument("--video_path", help="input video path", required=True) 116 | parser.add_argument("--model_path", help="LLaVA model path", type=str, required=True) 117 | parser.add_argument("--question", help="Input question and prompt", type=str, required=True) 118 | parser.add_argument("--conv_mode", type=str, required=False, default="image_seq_v3") 119 | parser.add_argument("--num_frames", type=int, default=50) 120 | parser.add_argument("--input_structure", type=str, default="image_seq") 121 | parser.add_argument("--image_aspect_ratio", type=str, default="resize") 122 | parser.add_argument("--temporal_aggregation", type=str, default="slowfast-slow_10frms_spatial_1d_max_pool-fast_4x4") 123 | parser.add_argument("--rope_scaling_factor", type=int, default=2) 124 | return parser.parse_args() 125 | 126 | 127 | if __name__ == "__main__": 128 | args = parse_args() 129 | run_inference(args) 130 | -------------------------------------------------------------------------------- /run_inference.py: -------------------------------------------------------------------------------- 1 | import subprocess 2 | from argparse import ArgumentParser 3 | from pathlib import Path 4 | import yaml 5 | 6 | 7 | def get_args(): 8 | parser = ArgumentParser(description="Free-VideoLLM") 9 | parser.add_argument( 10 | "--exp_config", 11 | type=str, 12 | required=True, 13 | help="path to exp config file", 14 | ) 15 | return parser.parse_args() 16 | 17 | 18 | def main(): 19 | # Load exp config 20 | args = get_args() 21 | with open(args.exp_config, "r") as f: 22 | exp_config = yaml.safe_load(f) 23 | if exp_config["CONFIG_NAME"] == "auto": 24 | exp_config["CONFIG_NAME"] = Path(args.exp_config).stem 25 | 26 | # Get commands 27 | commands = exp_config.pop("SCRIPT", None) 28 | if commands is None: 29 | raise RuntimeError("Script was not found in the config") 30 | if type(commands) is not list: 31 | commands = [commands] 32 | 33 | # Get parameters 34 | parameters = [] 35 | for k, v in exp_config.items(): 36 | if type(v) is not list: 37 | v = [v] * len(commands) 38 | else: 39 | assert len(v) == len(commands), \ 40 | f"The number of parameters in {k} must match the number of SCRIPT" 41 | parameters.append((k, v)) 42 | 43 | print(f":::: Start Inference ::::") 44 | 45 | # Iterate all scripts 46 | for idx, cmd in enumerate(commands): 47 | params = "" 48 | for k, v in parameters: 49 | params += f"{k}={v[idx]} " 50 | cmd = params + cmd 51 | 52 | # Run command 53 | subprocess.check_call(["bash", "-c", cmd]) 54 | 55 | 56 | if __name__ == "__main__": 57 | main() -------------------------------------------------------------------------------- /run_inference_video_qa.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import sys 4 | from pathlib import Path 5 | sys.path.insert(0, Path(__file__).parent.as_posix()) 6 | sys.path.insert(0, os.path.join(Path(__file__).parent.as_posix(), "free_video_llm")) 7 | import json 8 | from tqdm import tqdm 9 | import torch 10 | 11 | from llava.constants import IMAGE_TOKEN_INDEX 12 | from llava.model.builder import load_pretrained_model 13 | from llava.utils import disable_torch_init 14 | from llava.mm_utils import tokenizer_image_token, process_images, get_model_name_from_path 15 | 16 | from dataset import load_video 17 | from prompt import get_prompt 18 | from utils import get_chunk 19 | 20 | 21 | VIDEO_FORMATS = [".mp4", ".avi", ".mov", ".mkv"] 22 | 23 | 24 | def 
llava_inference( 25 | video_frames, 26 | question, 27 | conv_mode, 28 | model, 29 | tokenizer, 30 | image_processor, 31 | image_sizes, 32 | temperature, 33 | top_p, 34 | num_beams, 35 | temporal_aggregation, 36 | ): 37 | # Get prompt 38 | prompt = get_prompt(model, conv_mode, question) 39 | 40 | # Get text inputs 41 | input_ids = tokenizer_image_token( 42 | prompt, 43 | tokenizer, 44 | IMAGE_TOKEN_INDEX, 45 | return_tensors="pt", 46 | ).unsqueeze(0).cuda() 47 | 48 | # Get image inputs 49 | image_tensor = process_images(video_frames, image_processor, model.config) 50 | 51 | with torch.inference_mode(): 52 | output_ids = model.generate( 53 | input_ids, 54 | images=image_tensor.to(dtype=torch.float16, device="cuda", non_blocking=True), 55 | image_sizes=image_sizes, 56 | do_sample=True if temperature > 0 else False, 57 | temperature=temperature, 58 | top_p=top_p, 59 | num_beams=num_beams, 60 | max_new_tokens=128, 61 | use_cache=True, 62 | temporal_aggregation=temporal_aggregation, 63 | question=question, 64 | ) 65 | 66 | outputs = tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0].strip() 67 | return outputs 68 | 69 | 70 | def run_inference(args): 71 | """ 72 | Run inference on Video QA Dataset. 73 | 74 | Args: 75 | args: Command-line arguments. 76 | """ 77 | 78 | disable_torch_init() 79 | 80 | # Load tokenizer, model and image processor 81 | model_path = os.path.expanduser(args.model_path) 82 | model_name = get_model_name_from_path(model_path) 83 | tokenizer, model, image_processor, context_len = load_pretrained_model( 84 | model_path, args.model_base, model_name, 85 | device=torch.cuda.current_device(), 86 | device_map="cuda", 87 | rope_scaling_factor=args.rope_scaling_factor, 88 | ) 89 | 90 | # Override image aspect ratio if needed 91 | if args.image_aspect_ratio: 92 | model.config.image_aspect_ratio = args.image_aspect_ratio 93 | 94 | # Load questions and answers 95 | gt_questions = json.load(open(args.gt_file_question, "r")) 96 | gt_questions = get_chunk(gt_questions, args.num_chunks, args.chunk_idx) 97 | gt_answers = json.load(open(args.gt_file_answers, "r")) 98 | gt_answers = get_chunk(gt_answers, args.num_chunks, args.chunk_idx) 99 | 100 | os.makedirs(args.output_dir, exist_ok=True) 101 | ans_file = open( 102 | os.path.join(args.output_dir, f"{args.output_name}.json"), "w") 103 | 104 | # Iterate over each sample in the ground truth file 105 | for index, sample in enumerate(tqdm(gt_questions)): 106 | video_name = sample["video_name"] 107 | question = sample["question"] 108 | question_id = sample["question_id"] 109 | answer = gt_answers[index]["answer"] 110 | 111 | sample_set = { 112 | "question": question, 113 | "id": question_id, 114 | "answer": answer, 115 | } 116 | 117 | for fmt in VIDEO_FORMATS: 118 | # Load video 119 | updated_video_name = f"v_{video_name}" if "Activitynet" in args.video_dir else video_name 120 | video_path = os.path.join(args.video_dir, f"{updated_video_name}{fmt}") 121 | 122 | if os.path.exists(video_path): 123 | try: 124 | video_frames, sizes = load_video(video_path, num_frms=args.num_frames) 125 | except Exception as e: 126 | print(f"Failed to load {video_path}, continue...") 127 | continue 128 | 129 | # Run inference on the video 130 | output = llava_inference( 131 | video_frames, 132 | question, 133 | args.conv_mode, 134 | model, 135 | tokenizer, 136 | image_processor, 137 | sizes, 138 | args.temperature, 139 | args.top_p, 140 | args.num_beams, 141 | args.temporal_aggregation, 142 | ) 143 | # print(output) 144 | sample_set["pred"] = output 145 | 
ans_file.write(json.dumps(sample_set) + "\n") 146 | break 147 | 148 | ans_file.close() 149 | 150 | 151 | def parse_args(): 152 | """ 153 | Parse command-line arguments. 154 | """ 155 | parser = argparse.ArgumentParser() 156 | parser.add_argument("--video_dir", help="Directory containing video files.", required=True) 157 | parser.add_argument("--gt_file_question", help="Path to the ground truth file containing question.", required=True) 158 | parser.add_argument("--gt_file_answers", help="Path to the ground truth file containing answers.", required=True) 159 | parser.add_argument("--output_dir", help="Directory to save the model results JSON.", required=True) 160 | parser.add_argument("--output_name", help="Name of the file for storing results JSON.", required=True) 161 | parser.add_argument("--model_path", type=str, required=True) 162 | parser.add_argument("--model_base", type=str, default=None) 163 | parser.add_argument("--conv_mode", type=str, default="vicuna_v1") 164 | parser.add_argument("--num_chunks", type=int, default=1) 165 | parser.add_argument("--chunk_idx", type=int, default=0) 166 | parser.add_argument("--num_frames", type=int, default=100) 167 | parser.add_argument("--temperature", type=float, default=0.2) 168 | parser.add_argument("--top_p", type=float, default=None) 169 | parser.add_argument("--num_beams", type=int, default=1) 170 | parser.add_argument("--input_structure", type=str, default="image_seq") 171 | parser.add_argument("--image_aspect_ratio", type=str, default=None) 172 | parser.add_argument("--temporal_aggregation", type=str, default=None) 173 | parser.add_argument("--rope_scaling_factor", type=int, default=1) 174 | return parser.parse_args() 175 | 176 | 177 | if __name__ == "__main__": 178 | args = parse_args() 179 | run_inference(args) 180 | -------------------------------------------------------------------------------- /scripts/data/prepare_activitynet_qa_file.py: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE file. 3 | # Copyright (C) 2024 Apple Inc. All Rights Reserved. 4 | # 5 | import os 6 | import argparse 7 | import csv 8 | import json 9 | 10 | 11 | def main(args, task_name="Activitynet_Zero_Shot_QA"): 12 | data_q = [] 13 | data_a = [] 14 | with open(args.qa_file, newline="") as csvfile: 15 | spamreader = csv.reader(csvfile, delimiter=",") 16 | # ,video_id,answer,question,video_name,question_id,question_type 17 | for idx, row in enumerate(spamreader): 18 | if idx == 0: 19 | continue 20 | _, video_id, answer, question, video_name, question_id, question_type = row 21 | data_q.append({ 22 | "video_name": video_name[2:], 23 | "question_id": question_id, 24 | "question": question, 25 | }) 26 | data_a.append({ 27 | "answer": answer, 28 | "type": int(question_type), 29 | "question_id": question_id, 30 | }) 31 | 32 | folder = f"playground/gt_qa_files/{task_name}" 33 | os.makedirs(folder, exist_ok=True) 34 | with open(f"{folder}/test_q.json", "w") as f: 35 | json.dump(data_q, f, indent=4) 36 | with open(f"{folder}/test_a.json", "w") as f: 37 | json.dump(data_a, f, indent=4) 38 | 39 | 40 | def parse_args(): 41 | """ 42 | Parse command-line arguments. 
43 | """ 44 | parser = argparse.ArgumentParser() 45 | parser.add_argument("--qa_file", help="Path to Activitynet_QA.csv", required=True) 46 | return parser.parse_args() 47 | 48 | 49 | if __name__ == "__main__": 50 | main(parse_args()) 51 | -------------------------------------------------------------------------------- /scripts/data/prepare_msrvtt_qa_file.py: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE file. 3 | # Copyright (C) 2024 Apple Inc. All Rights Reserved. 4 | # 5 | import os 6 | import argparse 7 | import csv 8 | import json 9 | 10 | 11 | def main(args, task_name="MSRVTT_Zero_Shot_QA"): 12 | data_q = [] 13 | data_a = [] 14 | with open(args.qa_file, newline="") as csvfile: 15 | spamreader = csv.reader(csvfile, delimiter=",") 16 | # ,video_id,answer,question,video_name,question_id,question_type 17 | for idx, row in enumerate(spamreader): 18 | if idx == 0: 19 | continue 20 | _, video_id, answer, question, video_name, question_id, question_type = row 21 | data_q.append({ 22 | "video_name": video_name, 23 | "question_id": question_id, 24 | "question": question, 25 | }) 26 | data_a.append({ 27 | "answer": answer, 28 | "type": int(question_type), 29 | "question_id": question_id, 30 | }) 31 | 32 | folder = f"playground/gt_qa_files/{task_name}" 33 | os.makedirs(folder, exist_ok=True) 34 | with open(f"{folder}/val_q.json", "w") as f: 35 | json.dump(data_q, f, indent=4) 36 | with open(f"{folder}/val_a.json", "w") as f: 37 | json.dump(data_a, f, indent=4) 38 | 39 | 40 | def parse_args(): 41 | """ 42 | Parse command-line arguments. 43 | """ 44 | parser = argparse.ArgumentParser() 45 | parser.add_argument("--qa_file", help="Path to MSRVTT_QA.csv", required=True) 46 | return parser.parse_args() 47 | 48 | 49 | if __name__ == "__main__": 50 | main(parse_args()) 51 | -------------------------------------------------------------------------------- /scripts/data/prepare_msvd_qa_file.py: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE file. 3 | # Copyright (C) 2024 Apple Inc. All Rights Reserved. 4 | # 5 | import os 6 | import argparse 7 | import csv 8 | import json 9 | 10 | 11 | def main(args, task_name="MSVD_Zero_Shot_QA"): 12 | data_q = [] 13 | data_a = [] 14 | with open(args.qa_file, newline="") as csvfile: 15 | spamreader = csv.reader(csvfile, delimiter=",") 16 | # ,video_id,answer,question,video_name,question_id,question_type 17 | for idx, row in enumerate(spamreader): 18 | if idx == 0: 19 | continue 20 | _, video_id, answer, question, video_name, question_id, question_type = row 21 | data_q.append({ 22 | "video_name": video_name, 23 | "question_id": question_id, 24 | "question": question, 25 | }) 26 | data_a.append({ 27 | "answer": answer, 28 | "type": int(question_type), 29 | "question_id": question_id, 30 | }) 31 | 32 | folder = f"playground/gt_qa_files/{task_name}" 33 | os.makedirs(folder, exist_ok=True) 34 | with open(f"playground/gt_qa_files/{task_name}/val_q.json", "w") as f: 35 | json.dump(data_q, f, indent=4) 36 | with open(f"playground/gt_qa_files/{task_name}/val_a.json", "w") as f: 37 | json.dump(data_a, f, indent=4) 38 | 39 | 40 | def parse_args(): 41 | """ 42 | Parse command-line arguments. 
43 | """ 44 | parser = argparse.ArgumentParser() 45 | parser.add_argument("--qa_file", help="Path to MSVD_QA.csv", required=True) 46 | return parser.parse_args() 47 | 48 | 49 | if __name__ == "__main__": 50 | main(parse_args()) 51 | -------------------------------------------------------------------------------- /scripts/data/prepare_tgif_qa_file.py: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE file. 3 | # Copyright (C) 2024 Apple Inc. All Rights Reserved. 4 | # 5 | import os 6 | import argparse 7 | import csv 8 | import json 9 | 10 | 11 | def main(args, task_name="TGIF_Zero_Shot_QA"): 12 | data_q = [] 13 | data_a = [] 14 | with open(args.qa_file, newline="") as csvfile: 15 | spamreader = csv.reader(csvfile, delimiter=",") 16 | # ,video_id,answer,question,video_name,question_id,question_type 17 | for idx, row in enumerate(spamreader): 18 | if idx == 0: 19 | continue 20 | _, video_id, answer, question, video_name, question_id, question_type = row 21 | data_q.append({ 22 | "video_name": video_name, 23 | "question_id": question_id, 24 | "question": question, 25 | }) 26 | data_a.append({ 27 | "answer": answer, 28 | "type": int(question_type), 29 | "question_id": question_id, 30 | }) 31 | 32 | folder = f"playground/gt_qa_files/{task_name}" 33 | os.makedirs(folder, exist_ok=True) 34 | with open(f"{folder}/val_q.json", "w") as f: 35 | json.dump(data_q, f, indent=4) 36 | with open(f"{folder}/val_a.json", "w") as f: 37 | json.dump(data_a, f, indent=4) 38 | 39 | 40 | def parse_args(): 41 | """ 42 | Parse command-line arguments. 43 | """ 44 | parser = argparse.ArgumentParser() 45 | parser.add_argument("--qa_file", help="Path to TGIF_FrameQA.csv", required=True) 46 | return parser.parse_args() 47 | 48 | 49 | if __name__ == "__main__": 50 | main(parse_args()) 51 | -------------------------------------------------------------------------------- /scripts/run_eval_activitynet.sh: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE file. 3 | # Copyright (C) 2024 Apple Inc. All Rights Reserved. 
4 | # 5 | ROOT_DIR=${ROOT_DIR:-"$(dirname "$(dirname "$(readlink -f "$0")")")"} 6 | CONFIG_NAME=${CONFIG_NAME:-"slowfast_llava_7b-resize-slow_10frms_spatial_1d_max_pool_fast_4x4-50_frms"} 7 | DATA_DIR=${DATA_DIR:-"${ROOT_DIR}/playground/data/video_qa"} 8 | GT_QA_DIR=${GT_QA_DIR:-"${ROOT_DIR}/playground/gt_qa_files"} 9 | MODEL_PATH=${MODEL_PATH:-"${ROOT_DIR}/liuhaotian/llava-v1.6-vicuna-7b/"} 10 | OUTPUT_DIR=${OUTPUT_DIR:-"${ROOT_DIR}/outputs/artifacts"} 11 | TEMP_DIR=${TEMP_DIR:-"${ROOT_DIR}/outputs/eval_save_dir"} 12 | CONV_MODE=${CONV_MODE:-"image_seq_v3"} 13 | NUM_FRAMES=${NUM_FRAMES:-"50"} 14 | INPUT_STRUCTURE=${INPUT_STRUCTURE:-"image_seq"} 15 | TEMPORAL_AGGREGATION=${TEMPORAL_AGGREGATION:-"slowfast-slow_10frms_spatial_1d_max_pool-fast_4x4"} 16 | IMAGE_ASPECT_RATIO=${IMAGE_ASPECT_RATIO:-"resize"} 17 | ROPE_SCALING_FACTOR=${ROPE_SCALING_FACTOR:-"2"} 18 | 19 | ################################# Run ################################## 20 | 21 | gpu_list="${CUDA_VISIBLE_DEVICES:-0}" 22 | IFS=',' read -ra GPULIST <<< "$gpu_list" 23 | 24 | CHUNKS=${#GPULIST[@]} 25 | 26 | 27 | for IDX in $(seq 0 $((CHUNKS-1))); do 28 | CUDA_VISIBLE_DEVICES=${GPULIST[$IDX]} python3 ${ROOT_DIR}/run_inference_video_qa.py \ 29 | --video_dir ${DATA_DIR}/Activitynet_Zero_Shot_QA/all_test \ 30 | --gt_file_question ${GT_QA_DIR}/Activitynet_Zero_Shot_QA/test_q.json \ 31 | --gt_file_answers ${GT_QA_DIR}/Activitynet_Zero_Shot_QA/test_a.json \ 32 | --output_dir ${OUTPUT_DIR}/Activitynet_Zero_Shot_QA/${CONFIG_NAME} \ 33 | --output_name ${CHUNKS}_${IDX} \ 34 | --model_path ${MODEL_PATH} \ 35 | --conv_mode ${CONV_MODE} \ 36 | --num_chunks ${CHUNKS} \ 37 | --chunk_idx ${IDX} \ 38 | --num_frames ${NUM_FRAMES} \ 39 | --temperature 0 \ 40 | --input_structure ${INPUT_STRUCTURE} \ 41 | --temporal_aggregation ${TEMPORAL_AGGREGATION} \ 42 | --image_aspect_ratio ${IMAGE_ASPECT_RATIO} \ 43 | --rope_scaling_factor ${ROPE_SCALING_FACTOR} & 44 | done 45 | 46 | wait 47 | 48 | output_dir=${OUTPUT_DIR}/Activitynet_Zero_Shot_QA/${CONFIG_NAME} 49 | output_file=${output_dir}/merge.jsonl 50 | temp_dir=${TEMP_DIR}/Activitynet_Zero_Shot_QA/${CONFIG_NAME} 51 | 52 | # Clear out the output file if it exists. 53 | > "${output_file}" 54 | 55 | # Loop through the indices and concatenate each file. 56 | for IDX in $(seq 0 $((CHUNKS-1))); do 57 | cat ${output_dir}/${CHUNKS}_${IDX}.json >> "${output_file}" 58 | done 59 | 60 | ################################# Eval ################################## 61 | 62 | gpt_version="gpt-3.5-turbo-0125" 63 | num_tasks=25 64 | 65 | python3 ${ROOT_DIR}/eval/eval_video_qa.py \ 66 | --pred_path ${output_file} \ 67 | --output_dir ${temp_dir}/${gpt_version} \ 68 | --output_json ${output_dir}/results.json \ 69 | --gpt_version ${gpt_version} \ 70 | --num_tasks ${num_tasks} 71 | -------------------------------------------------------------------------------- /scripts/run_eval_msrvtt.sh: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE file. 3 | # Copyright (C) 2024 Apple Inc. All Rights Reserved. 
4 | # 5 | ROOT_DIR=${ROOT_DIR:-"$(dirname "$(dirname "$(readlink -f "$0")")")"} 6 | CONFIG_NAME=${CONFIG_NAME:-"slowfast_llava_7b-resize-slow_10frms_spatial_1d_max_pool_fast_4x4-50_frms"} 7 | DATA_DIR=${DATA_DIR:-"${ROOT_DIR}/playground/data/video_qa"} 8 | GT_QA_DIR=${GT_QA_DIR:-"${ROOT_DIR}/playground/gt_qa_files"} 9 | MODEL_PATH=${MODEL_PATH:-"${ROOT_DIR}/liuhaotian/llava-v1.6-vicuna-7b/"} 10 | OUTPUT_DIR=${OUTPUT_DIR:-"${ROOT_DIR}/outputs/artifacts"} 11 | TEMP_DIR=${TEMP_DIR:-"${ROOT_DIR}/outputs/eval_save_dir"} 12 | CONV_MODE=${CONV_MODE:-"image_seq_v3"} 13 | NUM_FRAMES=${NUM_FRAMES:-"50"} 14 | INPUT_STRUCTURE=${INPUT_STRUCTURE:-"image_seq"} 15 | TEMPORAL_AGGREGATION=${TEMPORAL_AGGREGATION:-"slowfast-slow_10frms_spatial_1d_max_pool-fast_4x4"} 16 | IMAGE_ASPECT_RATIO=${IMAGE_ASPECT_RATIO:-"resize"} 17 | ROPE_SCALING_FACTOR=${ROPE_SCALING_FACTOR:-"2"} 18 | 19 | ################################# Run ################################## 20 | 21 | gpu_list="${CUDA_VISIBLE_DEVICES:-0}" 22 | IFS=',' read -ra GPULIST <<< "$gpu_list" 23 | 24 | CHUNKS=${#GPULIST[@]} 25 | 26 | 27 | for IDX in $(seq 0 $((CHUNKS-1))); do 28 | CUDA_VISIBLE_DEVICES=${GPULIST[$IDX]} python3 ${ROOT_DIR}/run_inference_video_qa.py \ 29 | --video_dir ${DATA_DIR}/MSRVTT_Zero_Shot_QA/videos/all \ 30 | --gt_file_question ${GT_QA_DIR}/MSRVTT_Zero_Shot_QA/val_q.json \ 31 | --gt_file_answers ${GT_QA_DIR}/MSRVTT_Zero_Shot_QA/val_a.json \ 32 | --output_dir ${OUTPUT_DIR}/MSRVTT_Zero_Shot_QA/${CONFIG_NAME} \ 33 | --output_name ${CHUNKS}_${IDX} \ 34 | --model_path ${MODEL_PATH} \ 35 | --conv_mode ${CONV_MODE} \ 36 | --num_chunks ${CHUNKS} \ 37 | --chunk_idx ${IDX} \ 38 | --num_frames ${NUM_FRAMES} \ 39 | --temperature 0 \ 40 | --input_structure ${INPUT_STRUCTURE} \ 41 | --temporal_aggregation ${TEMPORAL_AGGREGATION} \ 42 | --image_aspect_ratio ${IMAGE_ASPECT_RATIO} \ 43 | --rope_scaling_factor ${ROPE_SCALING_FACTOR} & 44 | done 45 | 46 | wait 47 | 48 | output_dir=${OUTPUT_DIR}/MSRVTT_Zero_Shot_QA/${CONFIG_NAME} 49 | output_file=${output_dir}/merge.jsonl 50 | temp_dir=${TEMP_DIR}/MSRVTT_Zero_Shot_QA/${CONFIG_NAME} 51 | 52 | # Clear out the output file if it exists. 53 | > "${output_file}" 54 | 55 | # Loop through the indices and concatenate each file. 56 | for IDX in $(seq 0 $((CHUNKS-1))); do 57 | cat ${output_dir}/${CHUNKS}_${IDX}.json >> "${output_file}" 58 | done 59 | 60 | ################################# Eval ################################## 61 | 62 | gpt_version="gpt-3.5-turbo-0125" 63 | num_tasks=25 64 | 65 | python3 ${ROOT_DIR}/eval/eval_video_qa.py \ 66 | --pred_path ${output_file} \ 67 | --output_dir ${temp_dir}/${gpt_version} \ 68 | --output_json ${output_dir}/results.json \ 69 | --gpt_version ${gpt_version} \ 70 | --num_tasks ${num_tasks} 71 | -------------------------------------------------------------------------------- /scripts/run_eval_msvd.sh: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE file. 3 | # Copyright (C) 2024 Apple Inc. All Rights Reserved. 
4 | # 5 | ROOT_DIR=${ROOT_DIR:-"$(dirname "$(dirname "$(readlink -f "$0")")")"} 6 | CONFIG_NAME=${CONFIG_NAME:-"slowfast_llava_7b-resize-slow_10frms_spatial_1d_max_pool_fast_4x4-50_frms"} 7 | DATA_DIR=${DATA_DIR:-"${ROOT_DIR}/playground/data/video_qa"} 8 | GT_QA_DIR=${GT_QA_DIR:-"${ROOT_DIR}/playground/gt_qa_files"} 9 | MODEL_PATH=${MODEL_PATH:-"${ROOT_DIR}/liuhaotian/llava-v1.6-vicuna-7b/"} 10 | OUTPUT_DIR=${OUTPUT_DIR:-"${ROOT_DIR}/outputs/artifacts"} 11 | TEMP_DIR=${TEMP_DIR:-"${ROOT_DIR}/outputs/eval_save_dir"} 12 | CONV_MODE=${CONV_MODE:-"image_seq_v3"} 13 | NUM_FRAMES=${NUM_FRAMES:-"50"} 14 | INPUT_STRUCTURE=${INPUT_STRUCTURE:-"image_seq"} 15 | TEMPORAL_AGGREGATION=${TEMPORAL_AGGREGATION:-"slowfast-slow_10frms_spatial_1d_max_pool-fast_4x4"} 16 | IMAGE_ASPECT_RATIO=${IMAGE_ASPECT_RATIO:-"resize"} 17 | ROPE_SCALING_FACTOR=${ROPE_SCALING_FACTOR:-"2"} 18 | 19 | ################################# Run ################################## 20 | 21 | gpu_list="${CUDA_VISIBLE_DEVICES:-0}" 22 | IFS=',' read -ra GPULIST <<< "$gpu_list" 23 | 24 | CHUNKS=${#GPULIST[@]} 25 | 26 | 27 | for IDX in $(seq 0 $((CHUNKS-1))); do 28 | CUDA_VISIBLE_DEVICES=${GPULIST[$IDX]} python3 ${ROOT_DIR}/run_inference_video_qa.py \ 29 | --video_dir ${DATA_DIR}/MSVD_Zero_Shot_QA/videos \ 30 | --gt_file_question ${GT_QA_DIR}/MSVD_Zero_Shot_QA/val_q.json \ 31 | --gt_file_answers ${GT_QA_DIR}/MSVD_Zero_Shot_QA/val_a.json \ 32 | --output_dir ${OUTPUT_DIR}/MSVD_Zero_Shot_QA/${CONFIG_NAME} \ 33 | --output_name ${CHUNKS}_${IDX} \ 34 | --model_path ${MODEL_PATH} \ 35 | --conv_mode ${CONV_MODE} \ 36 | --num_chunks ${CHUNKS} \ 37 | --chunk_idx ${IDX} \ 38 | --num_frames ${NUM_FRAMES} \ 39 | --temperature 0 \ 40 | --input_structure ${INPUT_STRUCTURE} \ 41 | --temporal_aggregation ${TEMPORAL_AGGREGATION} \ 42 | --image_aspect_ratio ${IMAGE_ASPECT_RATIO} \ 43 | --rope_scaling_factor ${ROPE_SCALING_FACTOR} & 44 | done 45 | 46 | wait 47 | 48 | output_dir=${OUTPUT_DIR}/MSVD_Zero_Shot_QA/${CONFIG_NAME} 49 | output_file=${output_dir}/merge.jsonl 50 | temp_dir=${TEMP_DIR}/MSVD_Zero_Shot_QA/${CONFIG_NAME} 51 | 52 | # Clear out the output file if it exists. 53 | > "${output_file}" 54 | 55 | # Loop through the indices and concatenate each file. 56 | for IDX in $(seq 0 $((CHUNKS-1))); do 57 | cat ${output_dir}/${CHUNKS}_${IDX}.json >> "${output_file}" 58 | done 59 | 60 | ################################# Eval ################################## 61 | 62 | gpt_version="gpt-3.5-turbo-0125" 63 | num_tasks=25 64 | 65 | python3 ${ROOT_DIR}/eval/eval_video_qa.py \ 66 | --pred_path ${output_file} \ 67 | --output_dir ${temp_dir}/${gpt_version} \ 68 | --output_json ${output_dir}/results.json \ 69 | --gpt_version ${gpt_version} \ 70 | --num_tasks ${num_tasks} 71 | -------------------------------------------------------------------------------- /scripts/run_eval_tgif.sh: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE file. 3 | # Copyright (C) 2024 Apple Inc. All Rights Reserved. 
4 | # 5 | ROOT_DIR=${ROOT_DIR:-"$(dirname "$(dirname "$(readlink -f "$0")")")"} 6 | CONFIG_NAME=${CONFIG_NAME:-"slowfast_llava_7b-resize-slow_10frms_spatial_1d_max_pool_fast_4x4-50_frms"} 7 | DATA_DIR=${DATA_DIR:-"${ROOT_DIR}/playground/data/video_qa"} 8 | GT_QA_DIR=${GT_QA_DIR:-"${ROOT_DIR}/playground/gt_qa_files"} 9 | MODEL_PATH=${MODEL_PATH:-"${ROOT_DIR}/liuhaotian/llava-v1.6-vicuna-7b/"} 10 | OUTPUT_DIR=${OUTPUT_DIR:-"${ROOT_DIR}/outputs/artifacts"} 11 | TEMP_DIR=${TEMP_DIR:-"${ROOT_DIR}/outputs/eval_save_dir"} 12 | CONV_MODE=${CONV_MODE:-"image_seq_v3"} 13 | NUM_FRAMES=${NUM_FRAMES:-"50"} 14 | INPUT_STRUCTURE=${INPUT_STRUCTURE:-"image_seq"} 15 | TEMPORAL_AGGREGATION=${TEMPORAL_AGGREGATION:-"slowfast-slow_10frms_spatial_1d_max_pool-fast_4x4"} 16 | IMAGE_ASPECT_RATIO=${IMAGE_ASPECT_RATIO:-"resize"} 17 | ROPE_SCALING_FACTOR=${ROPE_SCALING_FACTOR:-"2"} 18 | 19 | ################################# Run ################################## 20 | 21 | gpu_list="${CUDA_VISIBLE_DEVICES:-0}" 22 | IFS=',' read -ra GPULIST <<< "$gpu_list" 23 | 24 | CHUNKS=${#GPULIST[@]} 25 | 26 | 27 | for IDX in $(seq 0 $((CHUNKS-1))); do 28 | CUDA_VISIBLE_DEVICES=${GPULIST[$IDX]} python3 ${ROOT_DIR}/run_inference_video_qa.py \ 29 | --video_dir ${DATA_DIR}/TGIF_Zero_Shot_QA/mp4 \ 30 | --gt_file_question ${GT_QA_DIR}/TGIF_Zero_Shot_QA/val_q.json \ 31 | --gt_file_answers ${GT_QA_DIR}/TGIF_Zero_Shot_QA/val_a.json \ 32 | --output_dir ${OUTPUT_DIR}/TGIF_Zero_Shot_QA/${CONFIG_NAME} \ 33 | --output_name ${CHUNKS}_${IDX} \ 34 | --model_path ${MODEL_PATH} \ 35 | --conv_mode ${CONV_MODE} \ 36 | --num_chunks ${CHUNKS} \ 37 | --chunk_idx ${IDX} \ 38 | --num_frames ${NUM_FRAMES} \ 39 | --temperature 0 \ 40 | --input_structure ${INPUT_STRUCTURE} \ 41 | --temporal_aggregation ${TEMPORAL_AGGREGATION} \ 42 | --image_aspect_ratio ${IMAGE_ASPECT_RATIO} \ 43 | --rope_scaling_factor ${ROPE_SCALING_FACTOR} & 44 | done 45 | 46 | wait 47 | 48 | output_dir=${OUTPUT_DIR}/TGIF_Zero_Shot_QA/${CONFIG_NAME} 49 | output_file=${output_dir}/merge.jsonl 50 | temp_dir=${TEMP_DIR}/TGIF_Zero_Shot_QA/${CONFIG_NAME} 51 | 52 | # Clear out the output file if it exists. 53 | > "${output_file}" 54 | 55 | # Loop through the indices and concatenate each file. 56 | for IDX in $(seq 0 $((CHUNKS-1))); do 57 | cat ${output_dir}/${CHUNKS}_${IDX}.json >> "${output_file}" 58 | done 59 | 60 | ################################# Eval ################################## 61 | 62 | gpt_version="gpt-3.5-turbo-0125" 63 | num_tasks=25 64 | 65 | python3 ${ROOT_DIR}/eval/eval_video_qa.py \ 66 | --pred_path ${output_file} \ 67 | --output_dir ${temp_dir}/${gpt_version} \ 68 | --output_json ${output_dir}/results.json \ 69 | --gpt_version ${gpt_version} \ 70 | --num_tasks ${num_tasks} 71 | -------------------------------------------------------------------------------- /scripts/run_eval_videoqabench.sh: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE file. 3 | # Copyright (C) 2024 Apple Inc. All Rights Reserved. 
4 | # 5 | CUDA_VISIBLE_DEVICES=${CUDA_VISIBLE_DEVICES:-"0,1,2,3,4,5,6,7"} 6 | ROOT_DIR=${ROOT_DIR:-"$(dirname "$(dirname "$(readlink -f "$0")")")"} 7 | CONFIG_NAME=${CONFIG_NAME:-"slowfast_llava_7b-resize-slow_10frms_spatial_1d_max_pool_fast_4x4-50_frms"} 8 | DATA_DIR=${DATA_DIR:-"${ROOT_DIR}/playground/data/video_qa"} 9 | GT_QA_DIR=${GT_QA_DIR:-"${ROOT_DIR}/playground/gt_qa_files"} 10 | MODEL_PATH=${MODEL_PATH:-"${ROOT_DIR}/liuhaotian/llava-v1.6-vicuna-7b/"} 11 | OUTPUT_DIR=${OUTPUT_DIR:-"${ROOT_DIR}/outputs/artifacts"} 12 | TEMP_DIR=${TEMP_DIR:-"${ROOT_DIR}/outputs/eval_save_dir"} 13 | CONV_MODE=${CONV_MODE:-"image_seq_v3"} 14 | NUM_FRAMES=${NUM_FRAMES:-"50"} 15 | INPUT_STRUCTURE=${INPUT_STRUCTURE:-"image_seq"} 16 | TEMPORAL_AGGREGATION=${TEMPORAL_AGGREGATION:-"slowfast-slow_10frms_spatial_1d_max_pool-fast_4x4"} 17 | IMAGE_ASPECT_RATIO=${IMAGE_ASPECT_RATIO:-"resize"} 18 | ROPE_SCALING_FACTOR=${ROPE_SCALING_FACTOR:-"2"} 19 | SAVE_DIR=${SAVE_DIR:-"${ROOT_DIR}/outputs/artifacts/logs"} 20 | 21 | mkdir -p ${TEMP_DIR} 22 | mkdir -p ${SAVE_DIR} 23 | 24 | ################################# Run ################################## 25 | 26 | echo "evaluating msvd ..." 27 | 28 | CUDA_VISIBLE_DEVICES=${CUDA_VISIBLE_DEVICES} \ 29 | CONFIG_NAME=${CONFIG_NAME} \ 30 | DATA_DIR=${DATA_DIR} \ 31 | GT_QA_DIR=${GT_QA_DIR} \ 32 | MODEL_PATH=${MODEL_PATH} \ 33 | OUTPUT_DIR=${OUTPUT_DIR} \ 34 | TEMP_DIR=${TEMP_DIR} \ 35 | CONV_MODE=${CONV_MODE} \ 36 | NUM_FRAMES=${NUM_FRAMES} \ 37 | INPUT_STRUCTURE=${INPUT_STRUCTURE} \ 38 | TEMPORAL_AGGREGATION=${TEMPORAL_AGGREGATION} \ 39 | IMAGE_ASPECT_RATIO=${IMAGE_ASPECT_RATIO} \ 40 | ROPE_SCALING_FACTOR=${ROPE_SCALING_FACTOR} \ 41 | bash scripts/run_eval_msvd.sh >> ${SAVE_DIR}/${CONFIG_NAME}_msvd.log 42 | 43 | wait 44 | 45 | echo "evaluating msrvtt ..." 46 | 47 | CUDA_VISIBLE_DEVICES=${CUDA_VISIBLE_DEVICES} \ 48 | CONFIG_NAME=${CONFIG_NAME} \ 49 | DATA_DIR=${DATA_DIR} \ 50 | GT_QA_DIR=${GT_QA_DIR} \ 51 | MODEL_PATH=${MODEL_PATH} \ 52 | OUTPUT_DIR=${OUTPUT_DIR} \ 53 | TEMP_DIR=${TEMP_DIR} \ 54 | CONV_MODE=${CONV_MODE} \ 55 | NUM_FRAMES=${NUM_FRAMES} \ 56 | INPUT_STRUCTURE=${INPUT_STRUCTURE} \ 57 | TEMPORAL_AGGREGATION=${TEMPORAL_AGGREGATION} \ 58 | IMAGE_ASPECT_RATIO=${IMAGE_ASPECT_RATIO} \ 59 | ROPE_SCALING_FACTOR=${ROPE_SCALING_FACTOR} \ 60 | bash scripts/run_eval_msrvtt.sh >> ${SAVE_DIR}/${CONFIG_NAME}_msrvtt.log 61 | 62 | wait 63 | 64 | echo "evaluating tgif ..." 65 | 66 | CUDA_VISIBLE_DEVICES=${CUDA_VISIBLE_DEVICES} \ 67 | CONFIG_NAME=${CONFIG_NAME} \ 68 | DATA_DIR=${DATA_DIR} \ 69 | GT_QA_DIR=${GT_QA_DIR} \ 70 | MODEL_PATH=${MODEL_PATH} \ 71 | OUTPUT_DIR=${OUTPUT_DIR} \ 72 | TEMP_DIR=${TEMP_DIR} \ 73 | CONV_MODE=${CONV_MODE} \ 74 | NUM_FRAMES=${NUM_FRAMES} \ 75 | INPUT_STRUCTURE=${INPUT_STRUCTURE} \ 76 | TEMPORAL_AGGREGATION=${TEMPORAL_AGGREGATION} \ 77 | IMAGE_ASPECT_RATIO=${IMAGE_ASPECT_RATIO} \ 78 | ROPE_SCALING_FACTOR=${ROPE_SCALING_FACTOR} \ 79 | bash scripts/run_eval_tgif.sh >> ${SAVE_DIR}/${CONFIG_NAME}_tgif.log 80 | 81 | wait 82 | 83 | echo "evaluating activitynet ..." 
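# As with the datasets above, the settings are passed to the per-dataset script as
# environment variables; the merged predictions end up in
# ${OUTPUT_DIR}/Activitynet_Zero_Shot_QA/${CONFIG_NAME}/merge.jsonl, the GPT-graded
# accuracy in results.json next to it, and the console output is appended to
# ${SAVE_DIR}/${CONFIG_NAME}_activitynet.log.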
84 | 85 | CUDA_VISIBLE_DEVICES=${CUDA_VISIBLE_DEVICES} \ 86 | CONFIG_NAME=${CONFIG_NAME} \ 87 | DATA_DIR=${DATA_DIR} \ 88 | GT_QA_DIR=${GT_QA_DIR} \ 89 | MODEL_PATH=${MODEL_PATH} \ 90 | OUTPUT_DIR=${OUTPUT_DIR} \ 91 | TEMP_DIR=${TEMP_DIR} \ 92 | CONV_MODE=${CONV_MODE} \ 93 | NUM_FRAMES=${NUM_FRAMES} \ 94 | INPUT_STRUCTURE=${INPUT_STRUCTURE} \ 95 | TEMPORAL_AGGREGATION=${TEMPORAL_AGGREGATION} \ 96 | IMAGE_ASPECT_RATIO=${IMAGE_ASPECT_RATIO} \ 97 | ROPE_SCALING_FACTOR=${ROPE_SCALING_FACTOR} \ 98 | bash scripts/run_eval_activitynet.sh >> ${SAVE_DIR}/${CONFIG_NAME}_activitynet.log 99 | -------------------------------------------------------------------------------- /setup_env.sh: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE file. 3 | # Copyright (C) 2024 Apple Inc. All Rights Reserved. 4 | # 5 | #!/bin/bash 6 | 7 | pip install -e ".[train]" 8 | pip install flash-attn --no-build-isolation --no-cache-dir 9 | 10 | apt-get update 11 | apt-get install git-lfs 12 | git-lfs install 13 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE file. 3 | # Copyright (C) 2024 Apple Inc. All Rights Reserved. 4 | # 5 | import math 6 | 7 | 8 | def split_list(lst, n): 9 | """Split a list into n (roughly) equal-sized chunks""" 10 | chunk_size = math.ceil(len(lst) / n) # ceiling division, so every item is assigned to a chunk 11 | return [lst[i: i + chunk_size] for i in range(0, len(lst), chunk_size)] 12 | 13 | 14 | def get_chunk(lst, n, k): 15 | chunks = split_list(lst, n) 16 | return chunks[k] 17 | --------------------------------------------------------------------------------
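For reference, here is a minimal sketch, not part of the repository, of how the chunking helpers in `utils.py` above shard the QA samples across GPUs during evaluation. It assumes the snippet is run from the repository root so that `utils.py` is importable, and it uses a small in-memory QA list in place of the real `test_q.json` / `val_q.json` files; the chunk count of 4 is likewise illustrative.

```python
from utils import split_list, get_chunk  # helpers defined in utils.py above

# Toy QA list standing in for the contents of a test_q.json / val_q.json file.
questions = [{"question_id": i, "question": f"q{i}"} for i in range(10)]

num_chunks = 4  # e.g. one chunk per GPU listed in CUDA_VISIBLE_DEVICES
for chunk_idx in range(num_chunks):
    shard = get_chunk(questions, num_chunks, chunk_idx)
    # Each run_inference_video_qa.py process handles only its shard and writes
    # {num_chunks}_{chunk_idx}.json with one JSON object per line; the run_eval_*.sh
    # scripts later concatenate these per-chunk files into merge.jsonl.
    print(chunk_idx, [q["question_id"] for q in shard])

# Sanity check: together the shards cover every question exactly once.
covered = [q["question_id"] for chunk in split_list(questions, num_chunks) for q in chunk]
assert sorted(covered) == list(range(10))
```

Because `split_list` uses ceiling division, the last shard may be smaller than the others, and requesting far more chunks than there are samples would leave some chunk indices without a shard; the evaluation scripts simply set the chunk count to the number of visible GPUs.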