├── ape_tools ├── ape_service.py └── ape_api.py ├── vidrag_pipeline ├── tools │ ├── filter_keywords.py │ ├── rag_retriever_dynamic.py │ └── scene_graph.py └── vidrag_pipeline.py ├── readme.md └── evals ├── generate_videomme.py ├── generate_longvideobench.py └── generate_mlvu.py /ape_tools/ape_service.py: -------------------------------------------------------------------------------- 1 | import socket 2 | import pickle 3 | from ape_api import setup_cfg, ape_inference, VisualizationDemo 4 | 5 | def main(): 6 | # Initialize demo once 7 | cfg = setup_cfg() 8 | demo = VisualizationDemo(cfg, args=None) 9 | 10 | # Set up a socket server 11 | server_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM) 12 | server_socket.bind(('0.0.0.0', 9999)) 13 | server_socket.listen() 14 | 15 | print("Server is listening...") 16 | 17 | while True: 18 | client_socket, addr = server_socket.accept() 19 | print(f"Connection from {addr}") 20 | 21 | data = client_socket.recv(8192) 22 | if not data: 23 | break 24 | 25 | # Deserialize datas 26 | input_files, text_prompt = pickle.loads(data) 27 | 28 | # Run inference 29 | result = ape_inference(input_files, text_prompt, demo) 30 | 31 | # Send back result 32 | client_socket.send(pickle.dumps(result)) 33 | client_socket.close() 34 | 35 | if __name__ == "__main__": 36 | main() 37 | -------------------------------------------------------------------------------- /vidrag_pipeline/tools/filter_keywords.py: -------------------------------------------------------------------------------- 1 | import io 2 | import spacy 3 | 4 | nlp = spacy.load("en_core_web_sm") 5 | 6 | # pip install spacy 7 | # pip install https://github.com/explosion/spacy-models/releases/download/en_core_web_sm-3.0.0/en_core_web_sm-3.0.0.tar.gz 8 | 9 | def filter_keywords(keywords): 10 | filtered_keywords = [] 11 | 12 | for phrase in keywords: 13 | doc = nlp(phrase) 14 | if len(doc) == 1: 15 | if doc[0].pos_ in ["NOUN", "ADJ", "VERB"] and phrase != 'video': 16 | filtered_keywords.append(phrase) 17 | else: 18 | is_valid = False 19 | if len(doc) == 2 and ( 20 | (doc[0].pos_ == "ADJ" and doc[1].pos_ in ["NOUN", "PROPN"]) or 21 | (doc[0].pos_ in ["NOUN", "PROPN"] and doc[1].pos_ in ["NOUN", "PROPN"]) 22 | ): 23 | is_valid = True 24 | elif len(doc) == 3 and doc[0].pos_ == "ADJ" and doc[1].pos_ in ["NOUN", "PROPN"] and doc[2].pos_ in ["NOUN", "PROPN"]: 25 | is_valid = True 26 | elif len(doc) == 2 and doc[0].pos_ == "VERB" and doc[1].pos_ in ["NOUN", "PROPN"]: 27 | is_valid = True 28 | 29 | if is_valid and phrase != 'video': 30 | filtered_keywords.append(phrase) 31 | 32 | return filtered_keywords -------------------------------------------------------------------------------- /vidrag_pipeline/tools/rag_retriever_dynamic.py: -------------------------------------------------------------------------------- 1 | from transformers import AutoTokenizer, AutoModel 2 | import torch 3 | import numpy as np 4 | import faiss 5 | 6 | tokenizer = AutoTokenizer.from_pretrained('facebook/contriever') 7 | model = AutoModel.from_pretrained('facebook/contriever') 8 | 9 | def text_to_vector(text, max_length=512): 10 | inputs = tokenizer(text, return_tensors='pt', truncation=True, padding=True, max_length=max_length) 11 | with torch.no_grad(): 12 | outputs = model(**inputs) 13 | return outputs.last_hidden_state.mean(dim=1).squeeze().cpu().numpy() 14 | 15 | def retrieve_documents_with_dynamic(documents, queries, threshold=0.4): 16 | if isinstance(queries, list): 17 | query_vectors = np.array([text_to_vector(query) for query in queries]) 18 | average_query_vector = np.mean(query_vectors, axis=0) 19 | query_vector = average_query_vector / np.linalg.norm(average_query_vector) 20 | query_vector = query_vector.reshape(1, -1) 21 | else: 22 | query_vector = text_to_vector(queries) 23 | query_vector = query_vector / np.linalg.norm(query_vector) 24 | query_vector = query_vector.reshape(1, -1) 25 | 26 | document_vectors = np.array([text_to_vector(doc) for doc in documents]) 27 | document_vectors = document_vectors / np.linalg.norm(document_vectors, axis=1, keepdims=True) 28 | dimension = document_vectors.shape[1] 29 | 30 | index = faiss.IndexFlatIP(dimension) 31 | index.add(document_vectors) 32 | lims, D, I = index.range_search(query_vector, threshold) 33 | start = lims[0] 34 | end = lims[1] 35 | I = I[start:end] 36 | 37 | if len(I) == 0: 38 | top_documents = [] 39 | idx = [] 40 | else: 41 | idx = I.tolist() 42 | top_documents = [documents[i] for i in idx] 43 | 44 | return top_documents, idx 45 | -------------------------------------------------------------------------------- /vidrag_pipeline/tools/scene_graph.py: -------------------------------------------------------------------------------- 1 | import networkx as nx 2 | 3 | def calculate_xmax_ymax(bbox): 4 | xmin, ymin, width, height = bbox 5 | xmax = xmin + width 6 | ymax = ymin + height 7 | return xmax, ymax 8 | 9 | def calculate_spatial_relations(bbox1, bbox2): 10 | xmin1, ymin1, width1, height1 = bbox1 11 | xmin2, ymin2, width2, height2 = bbox2 12 | 13 | xmax1, ymax1 = calculate_xmax_ymax(bbox1) 14 | xmax2, ymax2 = calculate_xmax_ymax(bbox2) 15 | 16 | relations = [] 17 | 18 | if xmin1 < xmax2 and xmax1 > xmin2 and ymin1 < ymax2 and ymax1 > ymin2: 19 | relations.append("overlaps") 20 | 21 | if xmax1 < xmin2: 22 | relations.append("left_of") 23 | 24 | if xmin1 > xmax2: 25 | relations.append("right_of") 26 | 27 | if ymax1 < ymin2: 28 | relations.append("above") 29 | 30 | if ymin1 > ymax2: 31 | relations.append("below") 32 | 33 | return relations 34 | 35 | 36 | def relation_to_text(source_id, source_label, relation, target_id, target_label): 37 | if relation == "overlaps": 38 | return f"Object {source_id} ({source_label}) overlaps with Object {target_id} ({target_label})." 39 | elif relation == "left_of": 40 | return f"Object {source_id} ({source_label}) is to the left of Object {target_id} ({target_label})." 41 | elif relation == "right_of": 42 | return f"Object {source_id} ({source_label}) is to the right of Object {target_id} ({target_label})." 43 | elif relation == "above": 44 | return f"Object {source_id} ({source_label}) is above Object {target_id} ({target_label})." 45 | elif relation == "below": 46 | return f"Object {source_id} ({source_label}) is below Object {target_id} ({target_label})." 47 | elif relation == "same_object_type": 48 | return f"Object {source_id} ({source_label}) is of the same type as Object {target_id} ({target_label})." 49 | else: 50 | return f"Object {source_id} ({source_label}) is related to Object {target_id} ({target_label})." 51 | 52 | def generate_scene_graph_description(objects, location_des, relation_des, number_des): 53 | 54 | scene_graph = nx.DiGraph() 55 | object_count = {} 56 | 57 | for obj in objects: 58 | scene_graph.add_node(obj['id'], label=obj['label'], bbox=obj['bbox']) 59 | 60 | label = obj['label'] 61 | if label in object_count: 62 | object_count[label] += 1 63 | else: 64 | object_count[label] = 1 65 | 66 | for node1, data1 in scene_graph.nodes(data=True): 67 | for node2, data2 in scene_graph.nodes(data=True): 68 | if node1 < node2: 69 | bbox1 = data1['bbox'] 70 | bbox2 = data2['bbox'] 71 | relations = calculate_spatial_relations(bbox1, bbox2) 72 | 73 | for relation in relations: 74 | scene_graph.add_edge(node1, node2, relation=relation) 75 | 76 | descriptions = [] 77 | 78 | if location_des: 79 | for node, data in scene_graph.nodes(data=True): 80 | label = data.get('label', 'unknown object') 81 | bbox = data.get('bbox', []) 82 | description = f"Object {node} is a {label} located at coordinates [{bbox[0]}, {bbox[1]}] with dimensions {bbox[2]}x{bbox[3]}." 83 | descriptions.append(description) 84 | 85 | if relation_des: 86 | for source, target, data in scene_graph.edges(data=True): 87 | relation = data.get('relation', 'related to') 88 | source_label = scene_graph.nodes[source]['label'] 89 | target_label = scene_graph.nodes[target]['label'] 90 | description = relation_to_text(source, source_label, relation, target, target_label) 91 | descriptions.append(description) 92 | 93 | if number_des: 94 | count_description = "Object counting:\n" 95 | for label, count in object_count.items(): 96 | count_description += f"- {label}: {count}\n" 97 | descriptions.append(count_description) 98 | 99 | return "\n".join(descriptions) -------------------------------------------------------------------------------- /ape_tools/ape_api.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | import os 3 | from collections import abc 4 | import cv2 5 | import numpy as np 6 | import tqdm 7 | from PIL import Image 8 | from detectron2.config import LazyConfig, get_cfg 9 | from detectron2.data.detection_utils import read_image 10 | from detectron2.evaluation.coco_evaluation import instances_to_coco_json 11 | from detectron2.utils.logger import setup_logger 12 | from predictor_lazy import VisualizationDemo 13 | from decord import VideoReader, cpu 14 | 15 | import logging 16 | logging.getLogger().setLevel(logging.ERROR) 17 | import warnings 18 | warnings.filterwarnings("ignore") 19 | 20 | def setup_cfg(): 21 | # load config from file and command-line arguments 22 | config_file = "configs/LVISCOCOCOCOSTUFF_O365_OID_VGR_SA1B_REFCOCO_GQA_PhraseCut_Flickr30k/ape_deta/ape_deta_vitl_eva02_clip_vlf_lsj1024_cp_16x4_1080k.py" 23 | opts = ['train.init_checkpoint=/checkpoints/model_final.pth', 'model.model_language.cache_dir=', 'model.model_vision.select_box_nums_for_evaluation=500', 'model.model_vision.text_feature_bank_reset=True', 'model.model_vision.backbone.net.xattn=False'] 24 | cfg = LazyConfig.load(config_file) 25 | cfg = LazyConfig.apply_overrides(cfg, opts) 26 | confidence_threshold = 0.1 27 | 28 | if "output_dir" in cfg.model: 29 | cfg.model.output_dir = cfg.train.output_dir 30 | if "model_vision" in cfg.model and "output_dir" in cfg.model.model_vision: 31 | cfg.model.model_vision.output_dir = cfg.train.output_dir 32 | if "train" in cfg.dataloader: 33 | if isinstance(cfg.dataloader.train, abc.MutableSequence): 34 | for i in range(len(cfg.dataloader.train)): 35 | if "output_dir" in cfg.dataloader.train[i].mapper: 36 | cfg.dataloader.train[i].mapper.output_dir = cfg.train.output_dir 37 | else: 38 | if "output_dir" in cfg.dataloader.train.mapper: 39 | cfg.dataloader.train.mapper.output_dir = cfg.train.output_dir 40 | 41 | if "model_vision" in cfg.model: 42 | cfg.model.model_vision.test_score_thresh = confidence_threshold 43 | else: 44 | cfg.model.test_score_thresh = confidence_threshold 45 | 46 | setup_logger(name="ape") 47 | setup_logger(name="timm") 48 | 49 | return cfg 50 | 51 | def ape_inference(input, text_prompt, demo): 52 | 53 | res_list = [] 54 | 55 | for path in tqdm.tqdm(input): 56 | # use PIL, to be consistent with evaluation 57 | try: 58 | img = read_image(path, format="BGR") 59 | except Exception as e: 60 | continue 61 | 62 | predictions, visualized_output, visualized_outputs, metadata = demo.run_on_image( 63 | img, 64 | text_prompt=text_prompt, 65 | with_box=True, 66 | with_mask=False, 67 | with_sseg=False, 68 | ) 69 | 70 | res = "" 71 | with_box = True 72 | if "instances" in predictions: 73 | results = instances_to_coco_json( 74 | predictions["instances"].to(demo.cpu_device), path 75 | ) 76 | if with_box: 77 | for result in results: 78 | res += metadata.thing_classes[result["category_id"]] + ": [" 79 | for idx, box in enumerate(result['bbox']): 80 | if idx != 3: 81 | res += str(int(box)) + ", " 82 | else: 83 | res += str(int(box)) 84 | res += "]; " 85 | else: 86 | for result in results: 87 | res += metadata.thing_classes[result["category_id"]] + ", " 88 | 89 | if len(res) > 0: 90 | if with_box: 91 | res_list.append(res[:-2]) 92 | else: 93 | res_list.append(res) 94 | else: 95 | res_list.append("") 96 | 97 | return res_list 98 | 99 | if __name__ == "__main__": 100 | 101 | text_prompt = "Apples,Candles,Berries" 102 | res = ape_inference(input, text_prompt) 103 | print(res) 104 | 105 | 106 | 107 | -------------------------------------------------------------------------------- /readme.md: -------------------------------------------------------------------------------- 1 | # Video-RAG: Visually-aligned Retrieval-Augmented Long Video Comprehension 2 | 3 | [![Arxiv](https://img.shields.io/badge/Arxiv-2411.13093-red)](https://arxiv.org/abs/2411.13093) 4 | ![](https://img.shields.io/badge/Task-VideoQA-blue) [![Arxiv](https://img.shields.io/badge/Web-Project_Page-yellow)](https://video-rag.github.io/) 5 | [![YouTube](https://img.shields.io/badge/-YouTube-000000?logo=youtube&logoColor=FF0000)](https://www.youtube.com/watch?v=WTs3xHicR_0) 6 | [![Blog](https://img.shields.io/badge/Blog-LearnOpenCV-blue?style=flat&logo=data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAAJYAAACWCAMAAAAL34HQAAAAilBMVEVHcEwuLi4qKio3NzdQUFBjY2NoaGhiYmJ8fHx4eHiGhoabm5t1dXW2tranp6eOjo7Pz8/////9/f35+fn29vbz8/Pv7+/t7e3q6urn5+fj4+Pg4OA4svIyrfAsp+0noesinekemOcalOUchMVjZGQXaZxOT08OT3oMPFwvMTIONEoKIS8OFx0DAwPBWB/1AAAAEXRSTlMACxw5ZXqQmqG1vdfv8ff7/XwvPHUAAAnaSURBVHja3ZyJkqI6GIUbXHpcwHSUbtuMdtuiLcu8/+vdCIRDjElYxLLuz701W03NV+ccAmT5XzqV4w6Go/FkPl/4RS3m88l4NBy4zsuDC0Sv05nnEcLYhvH/ebHsp4R43mz6+ng2ZzD8M+NAOcaGI6Hy38ngZn+Gg4eRuYPR1LsQVWk+iwtVoE1HA/cROo1nORPLcW4WBwSaNxv3rJk7nGRMGwlpfV2CrSDbEG8ydPsTajQjFZXWJc/H+gMFuAteqdls1I9k3D1Sca4AQr3zyn4EXa6bAOvBS2fw6mdQYCpheAWVKyuwZaIVYP7rfcHcUqkKkwC6UQWeICsl44q595NqOCUQCkwFwuq6gJaTQbINBxs6dwrVxMuUUphAtFwti1ot+a8A914hK7z0JoN7+Dea5VLlUIJpFeQ8uip0WwkyDgbBRm53qWSokglItKg3KqpkA5kM1lEwZzgjhX+AqjDREgklACtkAMuc7Jowd+wJqQAlmMBzq3I2QQYwIZjX/pYcTAkT/mVQEMqIBDReQrIMDE4yMm1pZGkgoCDUW80CGcAuXGzDjWwVK5+xMlWZfRDKjqOC5VZCMMb8odOcymOsTBWkAlMzslIwJIwxrymXM/JgoALVFQxcI6cR1VihAlQXMImLlzd2GlMVBopUAaotmJQwnlpwNXDQKFV3wYRe8NGadpUKUF3qysiMq27uVSpIdRfBwFXcj7Wo/AqVRqqugkl6Mb8G12CWj1f9UMFI5Iux2cBG5U4J02jVHxeZurahgfSmFbjeJK5PzjV2LHHXUPWslzc0B0s8BzlVAKoe9BK5z8d7Y7zc7A25ORWly+DyD/AnQu2/UXCJYYJM9PEaIVh1qehqTfxFGJ6irE5huPDJekXr6gWukcVCBItSGxPxwyhN0yQ+FxUn/JdR6BMbGaWIl9FGBxbWolqyxSlKk/Pv8XD4+fnZ/+z3/IfD4fh7TtLotGDLmlzCRsdwF66FhRaqFZlzJo70c6M4Giebk5WZCzbq70Z3imBZqQISRskZTLfIzkkUksDOVdo4dTUDqWShwT4ScqHApCPjkoXEYOW1jWNHm3d7sOjHIkoBZQZLo8UHtcZLm3rnNRfLbiElpzQGlKWOcXoi1G5jLtercy2WX9PCwI/gX536TSI/qGmjP1CTVSvv6/lFqv1PgzrE6XxdJ/VIF5IliaWjoixMz6pUe7kUrnMaMqrjglxqukYQy2QhO6W/GiQ9GNf2Nz0xk42QaySNWRDLYCHlVMfbUN9l3SY7ci6qt1HIxdjMrQ7wuA0NYilUQEIBTeEyyIWbcVh9GmLM0ou1DmWqK6av7JLIZK5wbZQrd5FNHATew5ilFSuYy7kC1NdVAUwaKNJ5oJULY5eH0IvRwXQbUj89a6DAZQQ7pz413YxijFACbxCLkig+qFSCYycKZArXIY4INcglxghX8VAv1scp0VJdaPiV/8eropjMlZzWerngIgYtW+CXi2qwAAWdpJIEq8ZrsbSEHkOXM7V7KCxUqQTXtrgA9q1ycRvtLk6duh4GYSpTyVJtqyUpJrgwSgR1XRzKHt7MOyxUqS4ogKqgcSrBBRs9nVxwcZh5+Mfq4Sqs5D2nunBdKfWXX5cSZn5Br0rqw5XVxT8OhgeThyRSxZKp/qK2kl6KXBGxujhzRbSMT+nlXIgFKkBlTIASYLutFC/INV/q70WEyx4tyiKM79AKVGqVgsFGyMVqheuVWKJFF+nxplgyFMfYKlyqjcd0QW/KhXCRV4xa+mitTvBQthBUh2Oc8IqPh62GCy6eVtpwYeRC4hEtNfAQq7QQWu2OcfqvqDQ+7gQXbIRcCL0hXDM3T7wxWj48hFg80YLqEP+TKj6AC3LBxZsvEteZH3oi8StgqYOWilVSJf+uKgFXRS5p6DJn3hviOa1L/Do6yx5CLFDJlR5kuWQXz9H6Ta1sQMXTekwsiceDRxYLVGol4FJd/E2JJfNk/DJhFiwpWhArp9oiV1LFWe4hVwVrz8NlwWKTl7nu0YNR66B6KCw8/tPUUWCp9+IBI5fu8TN/mW/M4wMNEwULHsY6rHgLFwUWMn8Ta1libeYvCwnrRhTDWI1WmfcUIJrU3whXHC7NI8Rm8eIzM1ZwOktYiod2F3fXWKfgJhZM9DmWeTR9j4Bl9xAVq5nHCPGuVas9FkzcJnqspBqub+lWPEcfdixupWmQX19hfVdM3JmwdsiWota6o1oq1ldzLFUtYKEUtYxYH7fU2tYx0aDWR2e1Ak227JHXZ8se+f4HCHXcOp8CG9YzDacYtxZ9P3x2moePeZSf9/uohloNH9WtXmyAtbO+2Kg3Yq0XmzFhbV4Dd+1fA481XgNbvzRvDVyJ8aU51gzy0kszPjGCOp8Y+KQ2f2KoHlpmR+RPjBYfZMCyfZDtWn+Qtf983dk/X7cNP18xZdPxY/+v6WN/1/5jXz81grrn1MgeUyNKtDA1UmuOUjuRZJ5H6jSRhMxrw0XnycE0GajHajvtdqdJShkJWrWdpKw3pbtUp3QFlzynK5hA1XxKl2VTuoZwocwT4CBTqNpNgNdcLqDm5YILh4Sko9ofk3rLBd0XV+TaVplaL67UXoqi+qWoL4WJU+06LkWpLi7rLdxh3RVkQLrHwp3TeZlzl7Hxy7DMua+9zNl1URhg+kXhfeNF4WZL6Cv9Evp3myV0lLKEfr8NB9hvAChIVXvDwYZNnIbbM+iDt2d038zyvecsF572m1kQeJRj2foDLuQLZGope7iglVLK1p+eNkr9dN4ohaq5rYzm28rU2udX921lLTfhrRpuwju02YSHqr1lcdloy+Kh5ZbF/jd4rhts8Oy0HZY+YDssxq4mm4e9MMKGZt225ij0mmwenrlPu9W6+8Z0io3pYOu6Mb37Nn7+h5238cPCzWzwDIceqHLo4RmOiLypR0Se4EANrXGgBuU86PjRm3r86AkOa1GJCgOpvtxp71yqVmTqPsVBwJWgWuMgYKNjk8FDjk0+3SFTxP1JjuSCyml6BL3/A8xNDqI7Y3DdTzBACSoc936Gw/FBcTj+c8NA9RStBODgczZeePY2Fd2betCuTT1Ed5b/SwsUNIxhbRvGUEPDGIaGMQ9qr0PL0rfX2XRt4OQMJkQIBjCQAa1aS6nlD6CEVGQy6Kl1U6C0bspYUCswaVo39d/oSuHpt9EV2oIVYCBDozK5JVi1LxiEQr8ypOrxTdT4b9uaqPXfck5uOlcUWs4BCi3nHtKgL4MoWN4/QASmDRr03b2kdoY5GuBQ6LX4iHaGaP4IMgH3WW2tiOaPn4IJzR/7AuNeVttSantlPqhVJsqRGosKPFwbILFHNBYFGNqwSnQA0rRhfXDTWtTmPk1rn7TF739gu3see8j9YQAAAABJRU5ErkJggg==)](https://learnopencv.com/video-rag-for-long-videos/) 7 | ![visitors](https://visitor-badge.laobi.icu/badge?page_id=Leon1207.Video-RAG-masterX&left_color=gray&right_color=orange) 8 | 9 |
[[🍎 Project Page](https://video-rag.github.io/)] [[📖 arXiv Paper](https://arxiv.org/pdf/2411.13093)]
10 | 11 | ## 😮 Highlights 12 | ![radar](https://github.com/user-attachments/assets/3ee6d1a7-24d7-42ff-a592-081491862ce8) 13 | 14 | - **We integrate RAG into open-source LVLMs:** Video-RAG incorporates three types of visually-aligned auxiliary texts (OCR, ASR, and object detection) processed by external tools and retrieved via RAG, enhancing the LVLM. It’s implemented using completely open-source tools, without the need for any commercial APIs. 15 | - **We design a versatile plug-and-play RAG-based pipeline for any LVLM:** Video-RAG offers a training-free solution for a wide range of LVLMs, delivering performance improvements with minimal additional resource requirements. 16 | - **We achieve proprietary-level performance with open-source models:** Applying Video-RAG to a 72B open-source model yields state-of-the-art performance in Video-MME, surpassing models such as Gemini-1.5-Pro. 17 | ![framework](https://github.com/user-attachments/assets/9c9b176c-10a8-483e-be6b-de72b2b68191) 18 | ![results](https://github.com/user-attachments/assets/657454fe-5252-4043-a83d-a486283dce96) 19 | 20 | 21 | 22 | ## 🔨 Usage 23 | 24 | This repo is built upon LLaVA-NeXT: 25 | 26 | - Step 1: Clone and build LLaVA-NeXT conda environment: 27 | 28 | ``` 29 | git clone https://github.com/LLaVA-VL/LLaVA-NeXT 30 | cd LLaVA-NeXT 31 | conda create -n llava python=3.10 -y 32 | conda activate llava 33 | pip install --upgrade pip # Enable PEP 660 support. 34 | pip install -e ".[train]" 35 | ``` 36 | Then install the following packages in llava environment: 37 | ``` 38 | pip install spacy faiss-cpu easyocr ffmpeg-python 39 | pip install torch==2.1.2 torchaudio numpy 40 | python -m spacy download en_core_web_sm 41 | # Optional: pip install https://github.com/explosion/spacy-models/releases/download/en_core_web_sm-3.0.0/en_core_web_sm-3.0.0.tar.gz 42 | ``` 43 | 44 | - Step 2: Clone and build another conda environment for APE by: 45 | 46 | ``` 47 | git clone https://github.com/shenyunhang/APE 48 | cd APE 49 | pip3 install -r requirements.txt 50 | python3 -m pip install -e . 51 | ``` 52 | 53 | - Step 3: Copy all the files in `vidrag_pipeline` under the root dir of LLaVA-NeXT; 54 | 55 | - Step 4: Copy all the files in `ape_tools` under the `demo` dir of APE; 56 | 57 | - Step 5: Opening a service of APE by running the code under `APE/demo`: 58 | 59 | ``` 60 | python demo/ape_service.py 61 | ``` 62 | 63 | - Step 6: You can now run our pipeline build upon LLaVA-Video-7B by: 64 | 65 | ``` 66 | python vidrag_pipeline.py 67 | ``` 68 | 69 | > [!NOTE] 70 | > You can also use our pipeline in any LVLMs by implementing some modifications in `vidrag_pipeline.py`: 71 | ``` 72 | 1. The video-language model you load (line #161). 73 | 2. The llava_inference() function, make sure your model supports both inputs with/without video (line #175). 74 | 3. The process_video() function may suit your model (line #34). 75 | 4. The final prompt may suit your model (line #366). 76 | ``` 77 | 78 | ## ✏️ Citation 79 | 80 | If you find our paper and code useful in your research, please consider giving a star ⭐ and citation 📝: 81 | 82 | ``` 83 | @misc{luo2024videoragvisuallyalignedretrievalaugmentedlong, 84 | title={Video-RAG: Visually-aligned Retrieval-Augmented Long Video Comprehension}, 85 | author={Yongdong Luo and Xiawu Zheng and Xiao Yang and Guilin Li and Haojia Lin and Jinfa Huang and Jiayi Ji and Fei Chao and Jiebo Luo and Rongrong Ji}, 86 | year={2024}, 87 | eprint={2411.13093}, 88 | archivePrefix={arXiv}, 89 | primaryClass={cs.CV}, 90 | url={https://arxiv.org/abs/2411.13093}, 91 | } 92 | ``` 93 | 94 | 95 | 96 | 97 | 98 | 99 | 100 | 101 | -------------------------------------------------------------------------------- /vidrag_pipeline/vidrag_pipeline.py: -------------------------------------------------------------------------------- 1 | from PIL import Image 2 | from llava.model.builder import load_pretrained_model 3 | from llava.mm_utils import get_model_name_from_path, process_images, tokenizer_image_token 4 | from llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN, IGNORE_INDEX 5 | from llava.conversation import conv_templates, SeparatorStyle 6 | import torch 7 | from transformers import AutoProcessor, WhisperForConditionalGeneration, WhisperProcessor, CLIPProcessor, CLIPModel 8 | import copy 9 | from decord import VideoReader, cpu 10 | import numpy as np 11 | import json 12 | from tqdm import tqdm 13 | import os 14 | import easyocr 15 | from tools.rag_retriever_dynamic import retrieve_documents_with_dynamic 16 | import re 17 | import ast 18 | import socket 19 | import pickle 20 | from tools.filter_keywords import filter_keywords 21 | from tools.scene_graph import generate_scene_graph_description 22 | import ffmpeg, torchaudio 23 | 24 | max_frames_num = 32 25 | clip_model = CLIPModel.from_pretrained("openai/clip-vit-large-patch14-336", torch_dtype=torch.float16, device_map="auto") 26 | clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-large-patch14-336") 27 | whisper_model = WhisperForConditionalGeneration.from_pretrained( 28 | "openai/whisper-large", 29 | torch_dtype=torch.float16, 30 | device_map="auto" 31 | ) 32 | whisper_processor = WhisperProcessor.from_pretrained("openai/whisper-large") 33 | 34 | def process_video(video_path, max_frames_num, fps=1, force_sample=False): 35 | if max_frames_num == 0: 36 | return np.zeros((1, 336, 336, 3)) 37 | vr = VideoReader(video_path, ctx=cpu(),num_threads=1) 38 | total_frame_num = len(vr) 39 | video_time = total_frame_num / vr.get_avg_fps() 40 | fps = round(vr.get_avg_fps()/fps) 41 | frame_idx = [i for i in range(0, len(vr), fps)] 42 | frame_time = [i/fps for i in frame_idx] 43 | if len(frame_idx) > max_frames_num or force_sample: 44 | sample_fps = max_frames_num 45 | uniform_sampled_frames = np.linspace(0, total_frame_num - 1, sample_fps, dtype=int) 46 | frame_idx = uniform_sampled_frames.tolist() 47 | frame_time = [i/vr.get_avg_fps() for i in frame_idx] 48 | frame_time = ",".join([f"{i:.2f}s" for i in frame_time]) 49 | spare_frames = vr.get_batch(frame_idx).asnumpy() 50 | 51 | return spare_frames, frame_time, video_time 52 | 53 | def extract_audio(video_path, audio_path): 54 | if not os.path.exists(audio_path): 55 | ffmpeg.input(video_path).output(audio_path, acodec='pcm_s16le', ac=1, ar='16k').run() 56 | 57 | def chunk_audio(audio_path, chunk_length_s=30): 58 | speech, sr = torchaudio.load(audio_path) 59 | speech = speech.mean(dim=0) 60 | speech = torchaudio.transforms.Resample(orig_freq=sr, new_freq=16000)(speech) 61 | num_samples_per_chunk = chunk_length_s * 16000 62 | chunks = [] 63 | for i in range(0, len(speech), num_samples_per_chunk): 64 | chunks.append(speech[i:i + num_samples_per_chunk]) 65 | return chunks 66 | 67 | def transcribe_chunk(chunk): 68 | 69 | inputs = whisper_processor(chunk, return_tensors="pt") 70 | inputs["input_features"] = inputs["input_features"].to(whisper_model.device, torch.float16) 71 | with torch.no_grad(): 72 | predicted_ids = whisper_model.generate( 73 | inputs["input_features"], 74 | no_repeat_ngram_size=2, 75 | early_stopping=True 76 | ) 77 | transcription = whisper_processor.batch_decode(predicted_ids, skip_special_tokens=True)[0] 78 | return transcription 79 | 80 | def get_asr_docs(video_path, audio_path): 81 | 82 | full_transcription = [] 83 | 84 | try: 85 | extract_audio(video_path, audio_path) 86 | except: 87 | return full_transcription 88 | audio_chunks = chunk_audio(audio_path, chunk_length_s=30) 89 | 90 | for chunk in audio_chunks: 91 | transcription = transcribe_chunk(chunk) 92 | full_transcription.append(transcription) 93 | 94 | return full_transcription 95 | 96 | def get_ocr_docs(frames): 97 | reader = easyocr.Reader(['en']) 98 | text_set = [] 99 | ocr_docs = [] 100 | for img in frames: 101 | ocr_results = reader.readtext(img) 102 | det_info = "" 103 | for result in ocr_results: 104 | text = result[1] 105 | confidence = result[2] 106 | if confidence > 0.5 and text not in text_set: 107 | det_info += f"{text}; " 108 | text_set.append(text) 109 | if len(det_info) > 0: 110 | ocr_docs.append(det_info) 111 | 112 | return ocr_docs 113 | 114 | 115 | def save_frames(frames): 116 | file_paths = [] 117 | for i, frame in enumerate(frames): 118 | img = Image.fromarray(frame) 119 | file_path = f'restore/frame_{i}.png' 120 | img.save(file_path) 121 | file_paths.append(file_path) 122 | return file_paths 123 | 124 | def get_det_docs(frames, prompt): 125 | prompt = ",".join(prompt) 126 | frames_path = save_frames(frames) 127 | res = [] 128 | if len(frames) > 0: 129 | client_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM) 130 | client_socket.connect(('0.0.0.0', 9999)) 131 | data = (frames_path, prompt) 132 | client_socket.send(pickle.dumps(data)) 133 | result_data = client_socket.recv(4096) 134 | try: 135 | res = pickle.loads(result_data) 136 | except: 137 | res = [] 138 | return res 139 | 140 | def det_preprocess(det_docs, location, relation, number): 141 | 142 | scene_descriptions = [] 143 | 144 | for det_doc_per_frame in det_docs: 145 | objects = [] 146 | scene_description = "" 147 | if len(det_doc_per_frame) > 0: 148 | for obj_id, objs in enumerate(det_doc_per_frame.split(";")): 149 | obj_name = objs.split(":")[0].strip() 150 | obj_bbox = objs.split(":")[1].strip() 151 | obj_bbox = ast.literal_eval(obj_bbox) 152 | objects.append({"id": obj_id, "label": obj_name, "bbox": obj_bbox}) 153 | 154 | scene_description = generate_scene_graph_description(objects, location, relation, number) 155 | scene_descriptions.append(scene_description) 156 | 157 | return scene_descriptions 158 | 159 | 160 | # load your VLM 161 | device = "cuda" 162 | overwrite_config = {} 163 | tokenizer, model, image_processor, max_length = load_pretrained_model( 164 | "LLaVA-Video-7B-Qwen2", 165 | None, 166 | "llava_qwen", 167 | torch_dtype="bfloat16", 168 | device_map="auto", 169 | overwrite_config=overwrite_config) # Add any other thing you want to pass in llava_model_args 170 | model.eval() 171 | conv_template = "qwen_1_5" # Make sure you use correct chat template for different models 172 | 173 | 174 | # The inference function of your VLM 175 | def llava_inference(qs, video): 176 | if video is not None: 177 | question = DEFAULT_IMAGE_TOKEN + qs 178 | else: 179 | question = qs 180 | conv = copy.deepcopy(conv_templates[conv_template]) 181 | conv.append_message(conv.roles[0], question) 182 | conv.append_message(conv.roles[1], None) 183 | prompt_question = conv.get_prompt() 184 | input_ids = tokenizer_image_token(prompt_question, tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt").unsqueeze(0).to(device) 185 | cont = model.generate( 186 | input_ids, 187 | images=video, 188 | modalities= ["video"], 189 | do_sample=False, 190 | temperature=0, 191 | max_new_tokens=4096, 192 | ) 193 | text_outputs = tokenizer.batch_decode(cont, skip_special_tokens=True)[0].strip() 194 | return text_outputs 195 | 196 | 197 | # super-parameters setting 198 | rag_threshold = 0.3 199 | clip_threshold = 0.3 200 | beta = 3.0 201 | 202 | # Choose the auxiliary texts you want 203 | USE_OCR = True 204 | USE_ASR = True 205 | USE_DET = True 206 | print(f"---------------OCR{rag_threshold}: {USE_OCR}-----------------") 207 | print(f"---------------ASR{rag_threshold}: {USE_ASR}-----------------") 208 | print(f"---------------DET{beta}-{clip_threshold}: {USE_DET}-----------------") 209 | print(f"---------------Frames: {max_frames_num}-----------------") 210 | 211 | 212 | video_path = "/path/to/your/video.mp4" # your video path 213 | question = "How many people appear in the video? A. 1. B. 2. C. 3. D. 4." # your question 214 | 215 | 216 | frames, frame_time, video_time = process_video(video_path, max_frames_num, 1, force_sample=True) 217 | raw_video = [f for f in frames] 218 | 219 | video = image_processor.preprocess(frames, return_tensors="pt")["pixel_values"].cuda().bfloat16() 220 | video = [video] 221 | 222 | if USE_DET: 223 | video_tensor = [] 224 | for frame in raw_video: 225 | processed = clip_processor(images=frame, return_tensors="pt")["pixel_values"].to(clip_model.device, dtype=torch.float16) 226 | video_tensor.append(processed.squeeze(0)) 227 | video_tensor = torch.stack(video_tensor, dim=0) 228 | 229 | if USE_OCR: 230 | ocr_docs_total = get_ocr_docs(frames) 231 | 232 | if USE_ASR: 233 | if os.path.exists(os.path.join("restore/audio", os.path.basename(video_path).split(".")[0] + ".txt")): 234 | with open(os.path.join("restore/audio", os.path.basename(video_path).split(".")[0] + ".txt"), 'r', encoding='utf-8') as f: 235 | asr_docs_total = f.readlines() 236 | else: 237 | audio_path = os.path.join("restore/audio", os.path.basename(video_path).split(".")[0] + ".wav") 238 | asr_docs_total = get_asr_docs(video_path, audio_path) 239 | with open(os.path.join("restore/audio", os.path.basename(video_path).split(".")[0] + ".txt"), 'w', encoding='utf-8') as f: 240 | for doc in asr_docs_total: 241 | f.write(doc + '\n') 242 | 243 | # step 0: get cot information 244 | retrieve_pmt_0 = "Question: " + question 245 | # you can change this decouple prompt to fit your requirements 246 | retrieve_pmt_0 += "\nTo answer the question step by step, you can provide your retrieve request to assist you by the following json format:" 247 | retrieve_pmt_0 += '''{ 248 | "ASR": Optional[str]. The subtitles of the video that may relavent to the question you want to retrieve, in two sentences. If you no need for this information, please return null. 249 | "DET": Optional[list]. (The output must include only physical entities, not abstract concepts, less than five entities) All the physical entities and their location related to the question you want to retrieve, not abstract concepts. If you no need for this information, please return null. 250 | "TYPE": Optional[list]. (The output must be specified as null or a list containing only one or more of the following strings: 'location', 'number', 'relation'. No other values are valid for this field) The information you want to obtain about the detected objects. If you need the object location in the video frame, output "location"; if you need the number of specific object, output "number"; if you need the positional relationship between objects, output "relation". 251 | } 252 | ## Example 1: 253 | Question: How many blue balloons are over the long table in the middle of the room at the end of this video? A. 1. B. 2. C. 3. D. 4. 254 | Your retrieve can be: 255 | { 256 | "ASR": "The location and the color of balloons, the number of the blue balloons.", 257 | "DET": ["blue ballons", "long table"], 258 | "TYPE": ["relation", "number"] 259 | } 260 | ## Example 2: 261 | Question: In the lower left corner of the video, what color is the woman wearing on the right side of the man in black clothes? A. Blue. B. White. C. Red. D. Yellow. 262 | Your retrieve can be: 263 | { 264 | "ASR": null, 265 | "DET": ["the man in black", "woman"], 266 | "TYPE": ["location", "relation"] 267 | } 268 | ## Example 3: 269 | Question: In which country is the comedy featured in the video recognized worldwide? A. China. B. UK. C. Germany. D. United States. 270 | Your retrieve can be: 271 | { 272 | "ASR": "The country recognized worldwide for its comedy.", 273 | "DET": null, 274 | "TYPE": null 275 | } 276 | Note that you don't need to answer the question in this step, so you don't need any infomation about the video of image. You only need to provide your retrieve request (it's optional), and I will help you retrieve the infomation you want. Please provide the json format.''' 277 | 278 | json_request, _ = llava_inference(retrieve_pmt_0, None) 279 | 280 | # step 1: get docs information 281 | query = [question] 282 | 283 | # APE fetch 284 | if USE_DET: 285 | det_docs = [] 286 | try: 287 | request_det = json.loads(json_request)["DET"] 288 | request_det = filter_keywords(request_det) 289 | clip_text = ["A picture of " + txt for txt in request_det] 290 | if len(clip_text) == 0: 291 | clip_text = ["A picture of object"] 292 | except: 293 | request_det = None 294 | clip_text = ["A picture of object"] 295 | 296 | clip_inputs = clip_processor(text=clip_text, return_tensors="pt", padding=True, truncation=True).to(clip_model.device) 297 | clip_img_feats = clip_model.get_image_features(video_tensor) 298 | with torch.no_grad(): 299 | text_features = clip_model.get_text_features(**clip_inputs) 300 | similarities = (clip_img_feats @ text_features.T).squeeze(0).mean(1).cpu() 301 | similarities = np.array(similarities, dtype=np.float64) 302 | alpha = beta * (len(similarities) / 16) 303 | similarities = similarities * alpha / np.sum(similarities) 304 | 305 | del clip_inputs, clip_img_feats, text_features 306 | torch.cuda.empty_cache() 307 | 308 | det_top_idx = [idx for idx in range(max_frames_num) if similarities[idx] > clip_threshold] 309 | 310 | if request_det is not None and len(request_det) > 0: 311 | det_docs = get_det_docs(frames[det_top_idx], request_det) 312 | 313 | L, R, N = False, False, False 314 | try: 315 | det_retrieve_info = json.loads(json_request)["TYPE"] 316 | except: 317 | det_retrieve_info = None 318 | if det_retrieve_info is not None: 319 | if "location" in det_retrieve_info: 320 | L = True 321 | if "relation" in det_retrieve_info: 322 | R = True 323 | if "number" in det_retrieve_info: 324 | N = True 325 | det_docs = det_preprocess(det_docs, location=L, relation=R, number=N) # pre-process of APE information 326 | 327 | 328 | # OCR fetch 329 | if USE_OCR: 330 | try: 331 | request_det = json.loads(json_request)["DET"] 332 | request_det = filter_keywords(request_det) 333 | except: 334 | request_det = None 335 | ocr_docs = [] 336 | if len(ocr_docs_total) > 0: 337 | ocr_query = query.copy() 338 | if request_det is not None and len(request_det) > 0: 339 | ocr_query.extend(request_det) 340 | ocr_docs, _ = retrieve_documents_with_dynamic(ocr_docs_total, ocr_query, threshold=rag_threshold) 341 | 342 | # ASR fetch 343 | if USE_ASR: 344 | asr_docs = [] 345 | try: 346 | request_asr = json.loads(json_request)["ASR"] 347 | except: 348 | request_asr = None 349 | if len(asr_docs_total) > 0: 350 | asr_query = query.copy() 351 | if request_asr is not None: 352 | asr_query.append(request_asr) 353 | asr_docs, _ = retrieve_documents_with_dynamic(asr_docs_total, asr_query, threshold=rag_threshold) 354 | 355 | qs = "" 356 | if USE_DET and len(det_docs) > 0: 357 | for i, info in enumerate(det_docs): 358 | if len(info) > 0: 359 | qs += f"Frame {str(det_top_idx[i]+1)}: " + info + "\n" 360 | if len(qs) > 0: 361 | qs = f"\nVideo have {str(max_frames_num)} frames in total, the detected objects' information in specific frames: " + qs 362 | if USE_ASR and len(asr_docs) > 0: 363 | qs += "\nVideo Automatic Speech Recognition information (given in chronological order of the video): " + " ".join(asr_docs) 364 | if USE_OCR and len(ocr_docs) > 0: 365 | qs += "\nVideo OCR information (given in chronological order of the video): " + "; ".join(ocr_docs) 366 | qs += "Select the best answer to the following multiple-choice question based on the video and the information (if given). Respond with only the letter (A, B, C, or D) of the correct option. Question: " + question # you can change this prompt 367 | 368 | res = llava_inference(qs, video) 369 | print(res) 370 | -------------------------------------------------------------------------------- /evals/generate_videomme.py: -------------------------------------------------------------------------------- 1 | import requests 2 | from PIL import Image 3 | from llava.model.builder import load_pretrained_model 4 | from llava.mm_utils import get_model_name_from_path, process_images, tokenizer_image_token 5 | from llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN, IGNORE_INDEX 6 | from llava.conversation import conv_templates, SeparatorStyle 7 | import torch 8 | from transformers import AutoProcessor, LlavaForConditionalGeneration, CLIPProcessor, CLIPModel, WhisperForConditionalGeneration, WhisperProcessor 9 | import copy 10 | from decord import VideoReader, cpu 11 | import numpy as np 12 | import json 13 | from tqdm import tqdm 14 | import os 15 | import easyocr 16 | from tools.rag_retriever_dynamic import retrieve_documents_with_dynamic 17 | import re 18 | import ast 19 | import socket 20 | import pickle 21 | from tools.filter_keywords import filter_keywords 22 | from tools.scene_graph import generate_scene_graph_description 23 | import torchaudio, ffmpeg 24 | 25 | max_frames_num = 64 26 | clip_model = CLIPModel.from_pretrained("clip-vit-large-patch14-336", torch_dtype=torch.float16, device_map="auto") 27 | clip_processor = CLIPProcessor.from_pretrained("clip-vit-large-patch14-336") 28 | whisper_model = WhisperForConditionalGeneration.from_pretrained("whisper-large", torch_dtype=torch.float16, device_map="auto") 29 | whisper_processor = WhisperProcessor.from_pretrained("whisper-large") 30 | 31 | def process_video(video_path, max_frames_num, fps=1, force_sample=False): 32 | if max_frames_num == 0: 33 | return np.zeros((1, 336, 336, 3)) 34 | vr = VideoReader(video_path, ctx=cpu(),num_threads=1) 35 | total_frame_num = len(vr) 36 | video_time = total_frame_num / vr.get_avg_fps() 37 | fps = round(vr.get_avg_fps()/fps) 38 | frame_idx = [i for i in range(0, len(vr), fps)] 39 | frame_time = [i/fps for i in frame_idx] 40 | if len(frame_idx) > max_frames_num or force_sample: 41 | sample_fps = max_frames_num 42 | uniform_sampled_frames = np.linspace(0, total_frame_num - 1, sample_fps, dtype=int) 43 | frame_idx = uniform_sampled_frames.tolist() 44 | frame_time = [i/vr.get_avg_fps() for i in frame_idx] 45 | frame_time = ",".join([f"{i:.2f}s" for i in frame_time]) 46 | spare_frames = vr.get_batch(frame_idx).asnumpy() 47 | return spare_frames, frame_time, video_time 48 | 49 | def extract_audio(video_path, audio_path): 50 | if not os.path.exists(audio_path): 51 | ffmpeg.input(video_path).output(audio_path, acodec='pcm_s16le', ac=1, ar='16k').run() 52 | 53 | def chunk_audio(audio_path, chunk_length_s=30): 54 | speech, sr = torchaudio.load(audio_path) 55 | speech = speech.mean(dim=0) 56 | speech = torchaudio.transforms.Resample(orig_freq=sr, new_freq=16000)(speech) 57 | num_samples_per_chunk = chunk_length_s * 16000 58 | chunks = [] 59 | for i in range(0, len(speech), num_samples_per_chunk): 60 | chunks.append(speech[i:i + num_samples_per_chunk]) 61 | return chunks 62 | 63 | def transcribe_chunk(chunk): 64 | 65 | inputs = whisper_processor(chunk, return_tensors="pt") 66 | inputs["input_features"] = inputs["input_features"].to(whisper_model.device, torch.float16) 67 | with torch.no_grad(): 68 | predicted_ids = whisper_model.generate( 69 | inputs["input_features"], 70 | no_repeat_ngram_size=2, 71 | early_stopping=True 72 | ) 73 | transcription = whisper_processor.batch_decode(predicted_ids, skip_special_tokens=True)[0] 74 | return transcription 75 | 76 | def get_asr_docs(video_path, audio_path): 77 | 78 | full_transcription = [] 79 | try: 80 | extract_audio(video_path, audio_path) 81 | except: 82 | return full_transcription 83 | audio_chunks = chunk_audio(audio_path, chunk_length_s=30) 84 | 85 | for chunk in audio_chunks: 86 | transcription = transcribe_chunk(chunk) 87 | full_transcription.append(transcription) 88 | 89 | return full_transcription 90 | 91 | def get_ocr_docs(frames): 92 | reader = easyocr.Reader(['en']) 93 | text_set = [] 94 | ocr_docs = [] 95 | for img in frames: 96 | ocr_results = reader.readtext(img) 97 | det_info = "" 98 | for result in ocr_results: 99 | text = result[1] 100 | confidence = result[2] 101 | if confidence > 0.5 and text not in text_set: 102 | det_info += f"{text}; " 103 | text_set.append(text) 104 | if len(det_info) > 0: 105 | ocr_docs.append(det_info) 106 | 107 | return ocr_docs 108 | 109 | 110 | def save_frames(frames, file_name): 111 | file_paths = [] 112 | for i, frame in enumerate(frames): 113 | img = Image.fromarray(frame) 114 | file_path = f'restore/{file_name}/frame_{i}.png' 115 | img.save(file_path) 116 | file_paths.append(file_path) 117 | return file_paths 118 | 119 | def get_det_docs(frames, prompt, file_name): 120 | prompt = ",".join(prompt) 121 | frames_path = save_frames(frames, file_name) 122 | res = [] 123 | if len(frames) > 0: 124 | client_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM) 125 | client_socket.connect(('0.0.0.0', 9999)) 126 | data = (frames_path, prompt) 127 | client_socket.send(pickle.dumps(data)) 128 | result_data = client_socket.recv(4096) 129 | try: 130 | res = pickle.loads(result_data) 131 | except: 132 | res = [] 133 | return res 134 | 135 | def det_preprocess(det_docs, location, relation, number): 136 | 137 | scene_descriptions = [] 138 | 139 | for det_doc_per_frame in det_docs: 140 | objects = [] 141 | scene_description = "" 142 | if len(det_doc_per_frame) > 0: 143 | for obj_id, objs in enumerate(det_doc_per_frame.split(";")): 144 | obj_name = objs.split(":")[0].strip() 145 | obj_bbox = objs.split(":")[1].strip() 146 | obj_bbox = ast.literal_eval(obj_bbox) 147 | objects.append({"id": obj_id, "label": obj_name, "bbox": obj_bbox}) 148 | 149 | scene_description = generate_scene_graph_description(objects, location, relation, number) 150 | scene_descriptions.append(scene_description) 151 | 152 | return scene_descriptions 153 | 154 | device = "cuda" 155 | tokenizer, model, image_processor, max_length = load_pretrained_model( 156 | "LLaVA-Video-7B-Qwen2", 157 | None, 158 | "llava_qwen", 159 | torch_dtype="bfloat16", 160 | device_map="auto") # Add any other thing you want to pass in llava_model_args 161 | model.eval() 162 | conv_template = "qwen_1_5" # Make sure you use correct chat template for different models 163 | 164 | def llava_inference(qs, video): 165 | if video is not None: 166 | time_instruciton = f"The video lasts for {video_time:.2f} seconds, and {len(video[0])} frames are uniformly sampled from it. These frames are located at {frame_time}.Please answer the following questions related to this video." 167 | question = DEFAULT_IMAGE_TOKEN + f"{time_instruciton}\n" + qs 168 | else: 169 | question = qs 170 | conv = copy.deepcopy(conv_templates[conv_template]) 171 | conv.append_message(conv.roles[0], question) 172 | conv.append_message(conv.roles[1], None) 173 | prompt_question = conv.get_prompt() 174 | input_ids = tokenizer_image_token(prompt_question, tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt").unsqueeze(0).to(device) 175 | 176 | if video is not None: 177 | cont = model.generate( 178 | input_ids, 179 | images=video, 180 | modalities= ["video"], 181 | do_sample=False, 182 | temperature=0, 183 | max_new_tokens=16, 184 | top_p=1.0, 185 | num_beams=1 186 | ) 187 | else: 188 | cont = model.generate( 189 | input_ids, 190 | images=video, 191 | modalities= ["video"], 192 | do_sample=False, 193 | temperature=0, 194 | max_new_tokens=4096, 195 | ) 196 | 197 | text_outputs = tokenizer.batch_decode(cont, skip_special_tokens=True)[0].strip() 198 | return text_outputs, 199 | 200 | rep_list = [] 201 | rag_threshold = 0.3 202 | clip_threshold = 0.3 203 | beta = 3.0 204 | USE_OCR = True 205 | USE_ASR = True 206 | USE_DET = True 207 | print(f"---------------OCR{rag_threshold}: {USE_OCR}-----------------") 208 | print(f"---------------ASR{rag_threshold}: {USE_ASR}-----------------") 209 | print(f"---------------DET{beta}-{clip_threshold}: {USE_DET}-----------------") 210 | print(f"---------------Frames: {max_frames_num}-----------------") 211 | 212 | file_name = f"generate_videomme" 213 | file_path = os.path.join("restore", file_name) 214 | if not os.path.exists(file_path): 215 | os.mkdir(file_path) 216 | data_path = "/path/to/Video-MME/data" 217 | json_file = f"results/{file_name}.json" 218 | 219 | with open("videomme_json_file.json", 'r', encoding='utf-8') as file: 220 | mme_data = json.load(file) 221 | 222 | if os.path.exists(json_file): 223 | with open(json_file, 'r', encoding='utf-8') as file: 224 | rep_list = json.load(file) 225 | 226 | index = len(rep_list) 227 | for item in tqdm(mme_data[index:], desc="Processing items"): 228 | 229 | video_path = os.path.join(data_path, item['url'] + ".mp4") 230 | content = item.copy() 231 | frames, frame_time, video_time = process_video(video_path, max_frames_num, 1, force_sample=True) 232 | raw_video = [f for f in frames] 233 | 234 | video = image_processor.preprocess(frames, return_tensors="pt")["pixel_values"].cuda().bfloat16() 235 | video = [video] 236 | 237 | if USE_DET: 238 | video_tensor = [] 239 | for frame in raw_video: 240 | processed = clip_processor(images=frame, return_tensors="pt")["pixel_values"].to(clip_model.device, dtype=torch.float16) 241 | video_tensor.append(processed.squeeze(0)) 242 | video_tensor = torch.stack(video_tensor, dim=0) 243 | 244 | if USE_OCR: 245 | ocr_docs_total = get_ocr_docs(frames) 246 | 247 | if USE_ASR: 248 | if os.path.exists(os.path.join("restore/audio", os.path.basename(video_path).split(".")[0] + ".txt")): 249 | with open(os.path.join("restore/audio", os.path.basename(video_path).split(".")[0] + ".txt"), 'r', encoding='utf-8') as f: 250 | asr_docs_total = f.readlines() 251 | else: 252 | audio_path = os.path.join("restore/audio", os.path.basename(video_path).split(".")[0] + ".wav") 253 | asr_docs_total = get_asr_docs(video_path, audio_path) 254 | with open(os.path.join("restore/audio", os.path.basename(video_path).split(".")[0] + ".txt"), 'w', encoding='utf-8') as f: 255 | for doc in asr_docs_total: 256 | f.write(doc + '\n') 257 | 258 | for q_num, question in enumerate(content['questions']): 259 | 260 | # step 0: get cot information 261 | retrieve_pmt_0 = "Question: " + question['question'] + '\n' + " ".join(question['options']) 262 | retrieve_pmt_0 += "\nTo answer the question step by step, you can provide your retrieve request to assist you by the following json format:" 263 | retrieve_pmt_0 += '''{ 264 | "ASR": Optional[str]. The subtitles of the video that may relavent to the question you want to retrieve, in two sentences. If you no need for this information, please return null. 265 | "DET": Optional[list]. (The output must include only physical entities, not abstract concepts, less than five entities) All the physical entities and their location related to the question you want to retrieve, not abstract concepts. If you no need for this information, please return null. 266 | "TYPE": Optional[list]. (The output must be specified as null or a list containing only one or more of the following strings: 'location', 'number', 'relation'. No other values are valid for this field) The information you want to obtain about the detected objects. If you need the object location in the video frame, output "location"; if you need the number of specific object, output "number"; if you need the positional relationship between objects, output "relation". 267 | } 268 | ## Example 1: 269 | Question: How many blue balloons are over the long table in the middle of the room at the end of this video? A. 1. B. 2. C. 3. D. 4. 270 | Your retrieve can be: 271 | { 272 | "ASR": "The location and the color of balloons, the number of the blue balloons.", 273 | "DET": ["blue ballons", "long table"], 274 | "TYPE": ["relation", "number"] 275 | } 276 | ## Example 2: 277 | Question: In the lower left corner of the video, what color is the woman wearing on the right side of the man in black clothes? A. Blue. B. White. C. Red. D. Yellow. 278 | Your retrieve can be: 279 | { 280 | "ASR": null, 281 | "DET": ["the man in black", "woman"], 282 | "TYPE": ["location", "relation"] 283 | } 284 | ## Example 3: 285 | Question: In which country is the comedy featured in the video recognized worldwide? A. China. B. UK. C. Germany. D. United States. 286 | Your retrieve can be: 287 | { 288 | "ASR": "The country recognized worldwide for its comedy.", 289 | "DET": null, 290 | "TYPE": null 291 | } 292 | Note that you don't need to answer the question in this step, so you don't need any infomation about the video of image. You only need to provide your retrieve request (it's optional), and I will help you retrieve the infomation you want. Please provide the json format.''' 293 | 294 | qs = "" 295 | 296 | if USE_ASR or USE_DET or USE_OCR: 297 | json_request = llava_inference(retrieve_pmt_0, None) 298 | 299 | # step 1: get docs information 300 | query = [question['question']] 301 | for o in question['options']: 302 | query.append(o) 303 | 304 | torch.cuda.empty_cache() 305 | 306 | # APE fetch 307 | if USE_DET: 308 | det_docs = [] 309 | try: 310 | request_det = json.loads(json_request)["DET"] 311 | request_det = filter_keywords(request_det) 312 | clip_text = ["A picture of " + txt for txt in request_det] 313 | if len(clip_text) == 0: 314 | clip_text = ["A picture of object"] 315 | except: 316 | request_det = None 317 | clip_text = ["A picture of object"] 318 | 319 | clip_inputs = clip_processor(text=clip_text, return_tensors="pt", padding=True, truncation=True).to(clip_model.device) 320 | clip_img_feats = clip_model.get_image_features(video_tensor) 321 | with torch.no_grad(): 322 | text_features = clip_model.get_text_features(**clip_inputs) 323 | similarities = (clip_img_feats @ text_features.T).squeeze(0).mean(1).cpu() 324 | similarities = np.array(similarities, dtype=np.float64) 325 | alpha = beta * (len(similarities) / 16) 326 | similarities = similarities * alpha / np.sum(similarities) 327 | 328 | del clip_inputs, clip_img_feats, text_features 329 | torch.cuda.empty_cache() 330 | 331 | det_top_idx = [idx for idx in range(max_frames_num) if similarities[idx] > clip_threshold] 332 | 333 | if request_det is not None and len(request_det) > 0: 334 | # process directly 335 | det_docs = get_det_docs(frames[det_top_idx], request_det, file_name) 336 | 337 | L, R, N = False, False, False 338 | try: 339 | det_retrieve_info = json.loads(json_request)["TYPE"] 340 | except: 341 | det_retrieve_info = None 342 | if det_retrieve_info is not None: 343 | if "location" in det_retrieve_info: 344 | L = True 345 | if "relation" in det_retrieve_info: 346 | R = True 347 | if "number" in det_retrieve_info: 348 | N = True 349 | det_docs = det_preprocess(det_docs, location=L, relation=R, number=N) # pre-process of APE information 350 | 351 | 352 | # OCR fetch 353 | if USE_OCR: 354 | try: 355 | request_det = json.loads(json_request)["DET"] 356 | request_det = filter_keywords(request_det) 357 | except: 358 | request_det = None 359 | ocr_docs = [] 360 | if len(ocr_docs_total) > 0: 361 | ocr_query = query.copy() 362 | if request_det is not None and len(request_det) > 0: 363 | ocr_query.extend(request_det) 364 | ocr_docs, _ = retrieve_documents_with_dynamic(ocr_docs_total, ocr_query, threshold=rag_threshold) 365 | 366 | # ASR fetch 367 | if USE_ASR: 368 | asr_docs = [] 369 | try: 370 | request_asr = json.loads(json_request)["ASR"] 371 | except: 372 | request_asr = None 373 | if len(asr_docs_total) > 0: 374 | asr_query = query.copy() 375 | if request_asr is not None: 376 | asr_query.append(request_asr) 377 | asr_docs, _ = retrieve_documents_with_dynamic(asr_docs_total, asr_query, threshold=rag_threshold) 378 | 379 | 380 | if USE_DET and len(det_docs) > 0: 381 | for i, info in enumerate(det_docs): 382 | if len(info) > 0: 383 | qs += f"Frame {str(det_top_idx[i]+1)}: " + info + "\n" 384 | if len(qs) > 0: 385 | qs = f"\nVideo have {str(max_frames_num)} frames in total, the detected objects' information in specific frames: " + qs 386 | if USE_ASR and len(asr_docs) > 0: 387 | qs += "\nVideo Automatic Speech Recognition information (given in chronological order of the video): " + " ".join(asr_docs) 388 | if USE_OCR and len(ocr_docs) > 0: 389 | qs += "\nVideo OCR information (given in chronological order of the video): " + "; ".join(ocr_docs) 390 | 391 | qs += "Select the best answer to the following multiple-choice question based on the video and the information (if given). Respond with only the letter (A, B, C, or D) of the correct option. Question: " + question['question'] + '\n' + " ".join(question['options']) + '\nThe best answer is:' 392 | 393 | res, token_num = llava_inference(qs, video) 394 | question['response'] = res 395 | 396 | rep_list.append(content) 397 | 398 | with open(json_file, "w", encoding='utf-8') as file: 399 | json.dump(rep_list, file, ensure_ascii=False, indent=4) 400 | -------------------------------------------------------------------------------- /evals/generate_longvideobench.py: -------------------------------------------------------------------------------- 1 | import requests 2 | from PIL import Image 3 | from llava.model.builder import load_pretrained_model 4 | from llava.mm_utils import get_model_name_from_path, process_images, tokenizer_image_token 5 | from llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN, IGNORE_INDEX 6 | from llava.conversation import conv_templates, SeparatorStyle 7 | import torch 8 | from transformers import AutoProcessor, WhisperForConditionalGeneration, WhisperProcessor, CLIPProcessor, CLIPModel 9 | import copy 10 | from decord import VideoReader, cpu 11 | import numpy as np 12 | import json 13 | from tqdm import tqdm 14 | import os 15 | import easyocr 16 | from tools.rag_retriever_dynamic import retrieve_documents_with_dynamic 17 | import re 18 | import ast 19 | import socket 20 | import pickle 21 | import torchaudio, ffmpeg 22 | from tools.scene_graph import generate_scene_graph_description 23 | from longvideobench import LongVideoBenchDataset 24 | 25 | max_frames_num = 64 26 | clip_model = CLIPModel.from_pretrained("clip-vit-large-patch14-336", torch_dtype=torch.float16, device_map="auto") 27 | clip_processor = CLIPProcessor.from_pretrained("clip-vit-large-patch14-336") 28 | whisper_model = WhisperForConditionalGeneration.from_pretrained( 29 | "whisper-large", 30 | torch_dtype=torch.float16, 31 | device_map="auto" 32 | ) 33 | whisper_processor = WhisperProcessor.from_pretrained("whisper-large") 34 | 35 | def process_video(video_path, max_frames_num, fps=1, force_sample=False): 36 | if max_frames_num == 0: 37 | return np.zeros((1, 336, 336, 3)) 38 | vr = VideoReader(video_path, ctx=cpu(),num_threads=1) 39 | total_frame_num = len(vr) 40 | video_time = total_frame_num / vr.get_avg_fps() 41 | fps = round(vr.get_avg_fps()/fps) 42 | frame_idx = [i for i in range(0, len(vr), fps)] 43 | frame_time = [i/fps for i in frame_idx] 44 | if len(frame_idx) > max_frames_num or force_sample: 45 | sample_fps = max_frames_num 46 | uniform_sampled_frames = np.linspace(0, total_frame_num - 1, sample_fps, dtype=int) 47 | frame_idx = uniform_sampled_frames.tolist() 48 | frame_time = [i/vr.get_avg_fps() for i in frame_idx] 49 | frame_time = ",".join([f"{i:.2f}s" for i in frame_time]) 50 | spare_frames = vr.get_batch(frame_idx).asnumpy() 51 | 52 | return spare_frames, frame_time, video_time 53 | 54 | def extract_audio(video_path, audio_path): 55 | if not os.path.exists(audio_path): 56 | ffmpeg.input(video_path).output(audio_path, acodec='pcm_s16le', ac=1, ar='16k').run() 57 | 58 | def chunk_audio(audio_path, chunk_length_s=30): 59 | speech, sr = torchaudio.load(audio_path) 60 | speech = speech.mean(dim=0) 61 | speech = torchaudio.transforms.Resample(orig_freq=sr, new_freq=16000)(speech) 62 | num_samples_per_chunk = chunk_length_s * 16000 63 | chunks = [] 64 | for i in range(0, len(speech), num_samples_per_chunk): 65 | chunks.append(speech[i:i + num_samples_per_chunk]) 66 | return chunks 67 | 68 | def transcribe_chunk(chunk): 69 | 70 | inputs = whisper_processor(chunk, return_tensors="pt") 71 | inputs["input_features"] = inputs["input_features"].to(whisper_model.device, torch.float16) 72 | with torch.no_grad(): 73 | predicted_ids = whisper_model.generate( 74 | inputs["input_features"], 75 | no_repeat_ngram_size=2, 76 | early_stopping=True 77 | ) 78 | transcription = whisper_processor.batch_decode(predicted_ids, skip_special_tokens=True)[0] 79 | return transcription 80 | 81 | def get_asr_docs(video_path, audio_path): 82 | 83 | full_transcription = [] 84 | 85 | try: 86 | extract_audio(video_path, audio_path) 87 | except: 88 | return full_transcription 89 | audio_chunks = chunk_audio(audio_path, chunk_length_s=30) 90 | 91 | for chunk in audio_chunks: 92 | transcription = transcribe_chunk(chunk) 93 | full_transcription.append(transcription) 94 | 95 | return full_transcription 96 | 97 | def chunk_subtitles(text_list, chunk_size=20): 98 | result = [] 99 | current_chunk = [] 100 | word_count = 0 101 | 102 | for text in text_list: 103 | words = text.split() 104 | 105 | for word in words: 106 | current_chunk.append(word) 107 | word_count += 1 108 | 109 | if word_count >= chunk_size: 110 | result.append(' '.join(current_chunk)) 111 | current_chunk = [] 112 | word_count = 0 113 | 114 | if current_chunk: 115 | result.append(' '.join(current_chunk)) 116 | 117 | return result 118 | 119 | 120 | def get_ocr_docs(frames): 121 | reader = easyocr.Reader(['en']) 122 | text_set = [] 123 | ocr_docs = [] 124 | for img in frames: 125 | ocr_results = reader.readtext(img) 126 | det_info = "" 127 | for result in ocr_results: 128 | text = result[1] 129 | confidence = result[2] 130 | if confidence > 0.5 and text not in text_set: 131 | det_info += f"{text}; " 132 | text_set.append(text) 133 | if len(det_info) > 0: 134 | ocr_docs.append(det_info) 135 | 136 | return ocr_docs 137 | 138 | 139 | def save_frames(frames, file_name): 140 | file_paths = [] 141 | for i, frame in enumerate(frames): 142 | img = Image.fromarray(frame) 143 | file_path = f'restore/{file_name}/frame_{i}.png' 144 | img.save(file_path) 145 | file_paths.append(file_path) 146 | return file_paths 147 | 148 | def get_det_docs(frames, prompt, file_name): 149 | prompt = ",".join(prompt) 150 | frames_path = save_frames(frames, file_name) 151 | res = [] 152 | if len(frames) > 0: 153 | client_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM) 154 | client_socket.connect(('0.0.0.0', 9999)) 155 | data = (frames_path, prompt) 156 | client_socket.send(pickle.dumps(data)) 157 | result_data = client_socket.recv(4096) 158 | try: 159 | res = pickle.loads(result_data) 160 | except: 161 | res = [] 162 | return res 163 | 164 | def det_preprocess(det_docs, location, relation, number): 165 | 166 | scene_descriptions = [] 167 | 168 | for det_doc_per_frame in det_docs: 169 | objects = [] 170 | scene_description = "" 171 | if len(det_doc_per_frame) > 0: 172 | for obj_id, objs in enumerate(det_doc_per_frame.split(";")): 173 | obj_name = objs.split(":")[0].strip() 174 | obj_bbox = objs.split(":")[1].strip() 175 | obj_bbox = ast.literal_eval(obj_bbox) 176 | objects.append({"id": obj_id, "label": obj_name, "bbox": obj_bbox}) 177 | 178 | scene_description = generate_scene_graph_description(objects, location, relation, number) 179 | scene_descriptions.append(scene_description) 180 | 181 | return scene_descriptions 182 | 183 | device = "cuda" 184 | overwrite_config = {} 185 | overwrite_config['mm_vision_tower'] = "siglip-so400m-patch14-384" 186 | tokenizer, model, image_processor, max_length = load_pretrained_model( 187 | "LLaVA-Video-7B-Qwen2", 188 | None, 189 | "llava_qwen", 190 | torch_dtype="bfloat16", 191 | device_map="auto", 192 | overwrite_config=overwrite_config) # Add any other thing you want to pass in llava_model_args 193 | model.eval() 194 | conv_template = "qwen_1_5" # Make sure you use correct chat template for different models 195 | 196 | def llava_inference(qs, video): 197 | if video is not None: 198 | time_instruciton = f"The video lasts for {video_time:.2f} seconds, and {len(video[0])} frames are uniformly sampled from it. These frames are located at {frame_time}.Please answer the following questions related to this video." 199 | question = DEFAULT_IMAGE_TOKEN + f"{time_instruciton}\n" + qs 200 | else: 201 | question = qs 202 | conv = copy.deepcopy(conv_templates[conv_template]) 203 | conv.append_message(conv.roles[0], question) 204 | conv.append_message(conv.roles[1], None) 205 | prompt_question = conv.get_prompt() 206 | input_ids = tokenizer_image_token(prompt_question, tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt").unsqueeze(0).to(device) 207 | if video is not None: 208 | cont = model.generate( 209 | input_ids, 210 | images=video, 211 | modalities= ["video"], 212 | do_sample=False, 213 | temperature=0, 214 | max_new_tokens=16, 215 | top_p=1.0, 216 | num_beams=1 217 | ) 218 | else: 219 | cont = model.generate( 220 | input_ids, 221 | images=video, 222 | modalities= ["video"], 223 | do_sample=False, 224 | temperature=0, 225 | max_new_tokens=4096, 226 | ) 227 | text_outputs = tokenizer.batch_decode(cont, skip_special_tokens=True)[0].strip() 228 | return text_outputs, input_ids.size(1) 229 | 230 | rep_list = [] 231 | rag_threshold = 0.3 232 | asr_chunk_size = 5 233 | clip_threshold = 0.3 234 | beta = 3.0 235 | USE_OCR = True 236 | USE_ASR = True 237 | USE_DET = True 238 | print(f"---------------OCR{rag_threshold}: {USE_OCR}-----------------") 239 | print(f"---------------ASR{rag_threshold}: {USE_ASR}-----------------") 240 | print(f"---------------DET{beta}-{clip_threshold}: {USE_DET}-----------------") 241 | print(f"---------------Frames: {max_frames_num}-----------------") 242 | 243 | file_name = f"LVB_Video-RAG" 244 | file_path = os.path.join("restore", file_name) 245 | if not os.path.exists(file_path): 246 | os.mkdir(file_path) 247 | data_path = "/path/to/LongVideoBenchData" 248 | json_file = f"results/{file_name}.json" 249 | letter = ['A. ', 'B. ', 'C. ', 'D. ', 'E. '] 250 | 251 | # test set 252 | # mme_data = LongVideoBenchDataset(data_path, "lvb_test_wo_gt.json", max_num_frames=max_frames_num).data 253 | 254 | # val set 255 | mme_data = LongVideoBenchDataset(data_path, "lvb_val.json", max_num_frames=max_frames_num).data 256 | 257 | if os.path.exists(json_file): 258 | with open(json_file, 'r', encoding='utf-8') as file: 259 | rep_list = json.load(file) 260 | 261 | index = len(rep_list) 262 | for item in tqdm(mme_data[index:], desc="Processing items"): 263 | 264 | video_path = os.path.join(data_path, "videos", item['video_path']) 265 | content = {} 266 | frames, frame_time, video_time = process_video(video_path, max_frames_num, 1, force_sample=True) 267 | raw_video = [f for f in frames] 268 | 269 | video = image_processor.preprocess(frames, return_tensors="pt")["pixel_values"].cuda().bfloat16() 270 | video = [video] 271 | 272 | if USE_DET: 273 | video_tensor = [] 274 | for frame in raw_video: 275 | processed = clip_processor(images=frame, return_tensors="pt")["pixel_values"].to(clip_model.device, dtype=torch.float16) 276 | video_tensor.append(processed.squeeze(0)) 277 | video_tensor = torch.stack(video_tensor, dim=0) 278 | 279 | if USE_OCR: 280 | ocr_docs_total = get_ocr_docs(frames) 281 | 282 | if USE_ASR: 283 | if os.path.exists(os.path.join("restore/audio", os.path.basename(video_path).split(".")[0] + ".txt")): 284 | with open(os.path.join("restore/audio", os.path.basename(video_path).split(".")[0] + ".txt"), 'r', encoding='utf-8') as f: 285 | asr_docs_total = f.readlines() 286 | else: 287 | audio_path = os.path.join("restore/audio", os.path.basename(video_path).split(".")[0] + ".wav") 288 | asr_docs_total = get_asr_docs(video_path, audio_path) 289 | with open(os.path.join("restore/audio", os.path.basename(video_path).split(".")[0] + ".txt"), 'w', encoding='utf-8') as f: 290 | for doc in asr_docs_total: 291 | f.write(doc + '\n') 292 | 293 | 294 | # step 0: get cot information 295 | retrieve_pmt_0 = "Question: " + item['question'] 296 | for i in range(len(item['candidates'])): 297 | retrieve_pmt_0 += letter[i] + str(item['candidates'][i]) + " " 298 | retrieve_pmt_0 += "\nTo answer the question step by step, list all the physical entities related to the question you want to retrieve, you can provide your retrieve request to assist you by the following json format:" 299 | retrieve_pmt_0 += '''{ 300 | "ASR": Optional[str]. The subtitles of the video that may relavent to the question you want to retrieve, in two sentences. If you no need for this information, please return null. 301 | "DET": Optional[list]. (The output must include only physical entities, not abstract concepts, less than five entities) All the physical entities and their location related to the question you want to retrieve, not abstract concepts. If you no need for this information, please return null. 302 | "TYPE": Optional[list]. (The output must be specified as null or a list containing only one or more of the following strings: 'location', 'number', 'relation'. No other values are valid for this field) The information you want to obtain about the detected objects. If you need the object location in the video frame, output "location"; if you need the number of specific object, output "number"; if you need the positional relationship between objects, output "relation". 303 | } 304 | ## Example 1: 305 | Question: How many blue balloons are over the long table in the middle of the room at the end of this video? A. 1. B. 2. C. 3. D. 4. 306 | Your retrieve can be: 307 | { 308 | "ASR": "The location and the color of balloons, the number of the blue balloons.", 309 | "DET": ["blue ballons", "long table"], 310 | "TYPE": ["relation", "number"] 311 | } 312 | ## Example 2: 313 | Question: In the lower left corner of the video, what color is the woman wearing on the right side of the man in black clothes? A. Blue. B. White. C. Red. D. Yellow. 314 | Your retrieve can be: 315 | { 316 | "ASR": null, 317 | "DET": ["the man in black", "woman"], 318 | "TYPE": ["location", "relation"] 319 | } 320 | ## Example 3: 321 | Question: In which country is the comedy featured in the video recognized worldwide? A. China. B. UK. C. Germany. D. United States. 322 | Your retrieve can be: 323 | { 324 | "ASR": "The country recognized worldwide for its comedy.", 325 | "DET": null, 326 | "TYPE": null 327 | } 328 | Note that you don't need to answer the question in this step, so you don't need any infomation about the video of image. You only need to provide your retrieve request (it's optional), and I will help you retrieve the infomation you want. Please provide the json format.''' 329 | 330 | qs = "" 331 | if USE_OCR or USE_DET or USE_ASR: 332 | 333 | json_request, _ = llava_inference(retrieve_pmt_0, None) 334 | 335 | # step 1: get docs information 336 | query = [item['question']] 337 | for o in item['candidates']: 338 | query.append(str(o)) 339 | 340 | # APE fetch 341 | if USE_DET: 342 | det_docs = [] 343 | try: 344 | request_det = json.loads(json_request)["DET"] 345 | # request_det = filter_keywords(request_det) 346 | clip_text = ["A picture of " + txt for txt in request_det] 347 | if len(clip_text) == 0: 348 | clip_text = ["A picture of object"] 349 | except: 350 | request_det = None 351 | clip_text = ["A picture of object"] 352 | 353 | clip_inputs = clip_processor(text=clip_text, return_tensors="pt", padding=True, truncation=True).to(clip_model.device) 354 | clip_img_feats = clip_model.get_image_features(video_tensor) 355 | with torch.no_grad(): 356 | text_features = clip_model.get_text_features(**clip_inputs) 357 | similarities = (clip_img_feats @ text_features.T).squeeze(0).mean(1).cpu() 358 | similarities = np.array(similarities, dtype=np.float64) 359 | alpha = beta * (len(similarities) / 16) 360 | similarities = similarities * alpha / np.sum(similarities) 361 | 362 | del clip_inputs, clip_img_feats, text_features 363 | torch.cuda.empty_cache() 364 | 365 | det_top_idx = [idx for idx in range(max_frames_num) if similarities[idx] > clip_threshold] 366 | 367 | if request_det is not None and len(request_det) > 0: 368 | # process directly 369 | det_docs = get_det_docs(frames[det_top_idx], request_det, file_name) 370 | 371 | L, R, N = False, False, False 372 | try: 373 | det_retrieve_info = json.loads(json_request)["TYPE"] 374 | except: 375 | det_retrieve_info = None 376 | if det_retrieve_info is not None: 377 | if "location" in det_retrieve_info: 378 | L = True 379 | if "relation" in det_retrieve_info: 380 | R = True 381 | if "number" in det_retrieve_info: 382 | N = True 383 | det_docs = det_preprocess(det_docs, location=L, relation=R, number=N) # pre-process of APE information 384 | 385 | 386 | # OCR fetch 387 | if USE_OCR: 388 | try: 389 | request_det = json.loads(json_request)["DET"] 390 | # request_det = filter_keywords(request_det) 391 | except: 392 | request_det = None 393 | ocr_docs = [] 394 | if len(ocr_docs_total) > 0: 395 | ocr_query = query.copy() 396 | if request_det is not None and len(request_det) > 0: 397 | ocr_query.extend(request_det) 398 | ocr_docs, _ = retrieve_documents_with_dynamic(ocr_docs_total, ocr_query, threshold=rag_threshold) 399 | 400 | # ASR fetch 401 | if USE_ASR: 402 | asr_docs = [] 403 | try: 404 | request_asr = json.loads(json_request)["ASR"] 405 | except: 406 | request_asr = None 407 | if len(asr_docs_total) > 0: 408 | asr_query = query.copy() 409 | if request_asr is not None: 410 | asr_query.append(request_asr) 411 | asr_docs, _ = retrieve_documents_with_dynamic(asr_docs_total, asr_query, threshold=rag_threshold) 412 | 413 | if USE_DET and len(det_docs) > 0: 414 | for i, info in enumerate(det_docs): 415 | if len(info) > 0: 416 | qs += f"Frame {str(det_top_idx[i]+1)}: " + info + "\n" 417 | if len(qs) > 0: 418 | qs = f"\nVideo have {str(max_frames_num)} frames in total, the detected objects' information in specific frames: " + qs 419 | if USE_ASR and len(asr_docs) > 0: 420 | qs += "\nVideo Automatic Speech Recognition information (given in chronological order of the video): " + " ".join(asr_docs) 421 | if USE_OCR and len(ocr_docs) > 0: 422 | qs += "\nVideo OCR information (given in chronological order of the video): " + "; ".join(ocr_docs) 423 | 424 | qs += "\nSelect the best answer to the following multiple-choice question based on the video and the information (if given). Respond with only the letter (A, B, C, D or E) of the correct option. Question: " + item['question'] + '\n' 425 | for i in range(len(item['candidates'])): 426 | qs += letter[i] + str(item['candidates'][i]) + " " 427 | qs += "\nThe best answer is:" 428 | 429 | res, _ = llava_inference(qs, video) 430 | 431 | # val set 432 | start_chr = 'A' 433 | content = { 434 | "video": item['video_id'], 435 | "pred": res[0], 436 | "gt": chr(ord(start_chr) + item['correct_choice']) 437 | } 438 | 439 | # test set 440 | # content[item['video_id']] = res 441 | 442 | rep_list.append(content) 443 | index += 1 444 | 445 | with open(json_file, "w", encoding='utf-8') as file: 446 | json.dump(rep_list, file, ensure_ascii=False, indent=4) 447 | -------------------------------------------------------------------------------- /evals/generate_mlvu.py: -------------------------------------------------------------------------------- 1 | import requests 2 | from PIL import Image 3 | from llava.model.builder import load_pretrained_model 4 | from llava.mm_utils import get_model_name_from_path, process_images, tokenizer_image_token 5 | from llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN, IGNORE_INDEX 6 | from llava.conversation import conv_templates, SeparatorStyle 7 | import copy 8 | import torch 9 | from transformers import AutoProcessor, CLIPProcessor, CLIPModel, WhisperForConditionalGeneration, WhisperProcessor 10 | from decord import VideoReader, cpu 11 | import numpy as np 12 | import json 13 | from tqdm import tqdm 14 | import os 15 | import easyocr 16 | from tools.rag_retriever_dynamic import retrieve_documents_with_dynamic 17 | import re 18 | import ast 19 | import socket 20 | import pickle 21 | from tools.filter_keywords import filter_keywords 22 | from tools.scene_graph import generate_scene_graph_description 23 | import ffmpeg 24 | import torchaudio 25 | from torch.utils.data import Dataset 26 | from openai import OpenAI 27 | import time 28 | import openai 29 | 30 | max_frames_num = 64 31 | clip_model = CLIPModel.from_pretrained("clip-vit-large-patch14-336", torch_dtype=torch.float16, device_map="auto") 32 | clip_processor = CLIPProcessor.from_pretrained("clip-vit-large-patch14-336") 33 | whisper_model = WhisperForConditionalGeneration.from_pretrained( 34 | "whisper-large", 35 | torch_dtype=torch.float16, 36 | device_map="auto" 37 | ) 38 | whisper_processor = WhisperProcessor.from_pretrained("whisper-large") 39 | 40 | def process_video(video_path, max_frames_num, fps=1, force_sample=False): 41 | if max_frames_num == 0: 42 | return np.zeros((1, 336, 336, 3)) 43 | vr = VideoReader(video_path, ctx=cpu(),num_threads=1) 44 | total_frame_num = len(vr) 45 | video_time = total_frame_num / vr.get_avg_fps() 46 | fps = round(vr.get_avg_fps()/fps) 47 | frame_idx = [i for i in range(0, len(vr), fps)] 48 | frame_time = [i/fps for i in frame_idx] 49 | if len(frame_idx) > max_frames_num or force_sample: 50 | sample_fps = max_frames_num 51 | uniform_sampled_frames = np.linspace(0, total_frame_num - 1, sample_fps, dtype=int) 52 | frame_idx = uniform_sampled_frames.tolist() 53 | frame_time = [i/vr.get_avg_fps() for i in frame_idx] 54 | frame_time = ",".join([f"{i:.2f}s" for i in frame_time]) 55 | spare_frames = vr.get_batch(frame_idx).asnumpy() 56 | return spare_frames, frame_time, video_time 57 | 58 | def extract_audio(video_path, audio_path): 59 | if not os.path.exists(audio_path): 60 | ffmpeg.input(video_path).output(audio_path, acodec='pcm_s16le', ac=1, ar='16k').run() 61 | 62 | def chunk_audio(audio_path, chunk_length_s=30): 63 | speech, sr = torchaudio.load(audio_path) 64 | speech = speech.mean(dim=0) 65 | speech = torchaudio.transforms.Resample(orig_freq=sr, new_freq=16000)(speech) 66 | num_samples_per_chunk = chunk_length_s * 16000 67 | chunks = [] 68 | for i in range(0, len(speech), num_samples_per_chunk): 69 | chunks.append(speech[i:i + num_samples_per_chunk]) 70 | return chunks 71 | 72 | def transcribe_chunk(chunk): 73 | 74 | inputs = whisper_processor(chunk, return_tensors="pt") 75 | inputs["input_features"] = inputs["input_features"].to(whisper_model.device, torch.float16) 76 | with torch.no_grad(): 77 | predicted_ids = whisper_model.generate( 78 | inputs["input_features"], 79 | no_repeat_ngram_size=2, 80 | early_stopping=True 81 | ) 82 | transcription = whisper_processor.batch_decode(predicted_ids, skip_special_tokens=True)[0] 83 | return transcription 84 | 85 | def get_asr_docs(video_path, audio_path): 86 | 87 | full_transcription = [] 88 | 89 | try: 90 | extract_audio(video_path, audio_path) 91 | except: 92 | return full_transcription 93 | audio_chunks = chunk_audio(audio_path, chunk_length_s=30) 94 | 95 | for chunk in audio_chunks: 96 | transcription = transcribe_chunk(chunk) 97 | full_transcription.append(transcription) 98 | 99 | return full_transcription 100 | 101 | def get_ocr_docs(frames): 102 | reader = easyocr.Reader(['en']) 103 | text_set = [] 104 | ocr_docs = [] 105 | for img in frames: 106 | ocr_results = reader.readtext(img) 107 | det_info = "" 108 | for result in ocr_results: 109 | text = result[1] 110 | confidence = result[2] 111 | if confidence > 0.5 and text not in text_set: 112 | det_info += f"{text}; " 113 | text_set.append(text) 114 | if len(det_info) > 0: 115 | ocr_docs.append(det_info) 116 | 117 | return ocr_docs 118 | 119 | 120 | def save_frames(frames, file_name): 121 | file_paths = [] 122 | for i, frame in enumerate(frames): 123 | img = Image.fromarray(frame) 124 | file_path = f'{file_name}/frame_{i}.png' 125 | img.save(file_path) 126 | file_paths.append(file_path) 127 | return file_paths 128 | 129 | def get_det_docs(frames, prompt, file_name): 130 | prompt = ",".join(prompt) 131 | frames_path = save_frames(frames, file_name) 132 | res = [] 133 | if len(frames) > 0: 134 | client_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM) 135 | client_socket.connect(('10.24.82.203', 9999)) 136 | data = (frames_path, prompt) 137 | client_socket.send(pickle.dumps(data)) 138 | result_data = client_socket.recv(4096) 139 | try: 140 | res = pickle.loads(result_data) 141 | except: 142 | res = [] 143 | return res 144 | 145 | def det_preprocess(det_docs, location, relation, number): 146 | 147 | scene_descriptions = [] 148 | 149 | for det_doc_per_frame in det_docs: 150 | objects = [] 151 | scene_description = "" 152 | if len(det_doc_per_frame) > 0: 153 | for obj_id, objs in enumerate(det_doc_per_frame.split(";")): 154 | obj_name = objs.split(":")[0].strip() 155 | obj_bbox = objs.split(":")[1].strip() 156 | obj_bbox = ast.literal_eval(obj_bbox) 157 | objects.append({"id": obj_id, "label": obj_name, "bbox": obj_bbox}) 158 | 159 | scene_description = generate_scene_graph_description(objects, location, relation, number) 160 | scene_descriptions.append(scene_description) 161 | 162 | return scene_descriptions 163 | 164 | device = "cuda" 165 | overwrite_config = {} 166 | overwrite_config['mm_vision_tower'] = "/mnt/82_store/LLM-weights/siglip-so400m-patch14-384" 167 | tokenizer, model, image_processor, max_length = load_pretrained_model( 168 | "/mnt/82_store/LLM-weights/LLaVA-Video-7B-Qwen2", 169 | None, 170 | "llava_qwen", 171 | torch_dtype="bfloat16", 172 | device_map="auto", 173 | overwrite_config=overwrite_config) # Add any other thing you want to pass in llava_model_args 174 | model.eval() 175 | conv_template = "qwen_1_5" # Make sure you use correct chat template for different models 176 | 177 | def llava_inference(qs, video): 178 | if video is not None: 179 | time_instruciton = f"The video lasts for {video_time:.2f} seconds, and {len(video[0])} frames are uniformly sampled from it. These frames are located at {frame_time}.Please answer the following questions related to this video." 180 | question = DEFAULT_IMAGE_TOKEN + f"{time_instruciton}\n" + qs 181 | else: 182 | question = qs 183 | conv = copy.deepcopy(conv_templates[conv_template]) 184 | conv.append_message(conv.roles[0], question) 185 | conv.append_message(conv.roles[1], None) 186 | prompt_question = conv.get_prompt() 187 | input_ids = tokenizer_image_token(prompt_question, tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt").unsqueeze(0).to(device) 188 | 189 | if video is not None: 190 | cont = model.generate( 191 | input_ids, 192 | images=video, 193 | modalities= ["video"], 194 | do_sample=False, 195 | temperature=0, 196 | max_new_tokens=16, 197 | top_p=1.0, 198 | num_beams=1 199 | ) 200 | else: 201 | cont = model.generate( 202 | input_ids, 203 | images=video, 204 | modalities= ["video"], 205 | do_sample=False, 206 | temperature=0, 207 | max_new_tokens=4096, 208 | ) 209 | 210 | text_outputs = tokenizer.batch_decode(cont, skip_special_tokens=True)[0].strip() 211 | return text_outputs 212 | 213 | 214 | def get_prompt2(conv): 215 | ret = conv.system + conv.sep 216 | count = 0 217 | for role, message in conv.messages: 218 | count += 1 219 | if count == len(conv.messages): 220 | ret += role + ": " + message 221 | else: 222 | if message: 223 | ret += role + ": " + message + conv.sep 224 | else: 225 | ret += role + ":" 226 | return ret 227 | 228 | class MLVU(Dataset): 229 | def __init__(self, data_dir, video_folder): 230 | self.video_folder=video_folder 231 | self.data_list = [] 232 | with open(os.path.join(data_dir), 'r') as f: 233 | json_data = json.load(f) 234 | for data in json_data: 235 | self.data_list.append({ 236 | 'task_type': data["question_type"], 237 | 'data': data, 238 | 'question_id': data["video"] + "_" + data["question"], 239 | 'candidates': data["candidates"], 240 | 'answer': data["answer"] 241 | }) 242 | 243 | def __str__(self): 244 | len_list = {} 245 | option_list = {} 246 | for data in self.data_list: 247 | if data['task_type'] not in len_list: 248 | len_list[data['task_type']] = 0 249 | len_list[data['task_type']] += 1 250 | if data['task_type'] not in option_list: 251 | option_list[data['task_type']] = 0 252 | option_list[data['task_type']] += len(data['data']['candidates']) 253 | 254 | correct = 0 255 | total = 0 256 | res = f"There are {len(self.data_list)} videos as follow:\n" 257 | for k, v in len_list.items(): 258 | correct += len_list[k] 259 | total += option_list[k] 260 | res += f"{v} for {k} ({option_list[k]} options => {len_list[k]/option_list[k]*100:.2f}%)\n" 261 | correct = correct + 1 / option_list[k] 262 | res += f"Total random accuracy: {correct/total*100:.2f}%" 263 | return res.rstrip() 264 | 265 | def __len__(self): 266 | return len(self.data_list) 267 | 268 | def get_index(self, bound, fps, max_frame, first_idx=0): 269 | start, end = -100000, 100000 270 | start_idx = max(first_idx, round(start * fps)) 271 | end_idx = min(round(end * fps), max_frame) 272 | seg_size = float(end_idx - start_idx) / self.num_segments 273 | frame_indices = np.array([ 274 | int(start_idx + (seg_size / 2) + np.round(seg_size * idx)) 275 | for idx in range(self.num_segments) 276 | ]) 277 | return frame_indices 278 | 279 | 280 | def qa_template(self, data): 281 | question = f"Question: {data['question']}\n" 282 | question += "Options:\n" 283 | 284 | for idx, c in enumerate(data['candidates']): 285 | question += f"({chr(ord('A') + idx)}) {c}\n" 286 | 287 | question = question.rstrip() 288 | 289 | return question 290 | 291 | def __getitem__(self, idx): 292 | bound = None 293 | video_path = os.path.join(self.video_folder, self.data_list[idx]['data']['video']) 294 | question = self.qa_template(self.data_list[idx]['data']) 295 | start_chr = 'A' 296 | return { 297 | 'video': video_path, 298 | 'question': question, 299 | 'task_type': self.data_list[idx]['task_type'], 300 | "question_id": self.data_list[idx]["question_id"], 301 | "answer": chr(ord(start_chr) + self.data_list[idx]["candidates"].index(self.data_list[idx]['answer'])) 302 | } 303 | 304 | def chunk_string_by_words(string, words_per_chunk): 305 | words = string.split() 306 | return [' '.join(words[i:i+words_per_chunk]) for i in range(0, len(words), words_per_chunk)] 307 | 308 | 309 | total = 0 310 | res_list = [] 311 | acc_dict = {} 312 | process_list = [] 313 | rag_threshold = 0.0 314 | asr_chunk_size = 5 315 | clip_threshold = 0.3 316 | beta = 3.0 317 | USE_OCR = True 318 | USE_ASR = True 319 | USE_DET = True 320 | print(f"---------------OCR{rag_threshold}: {USE_OCR}-----------------") 321 | print(f"---------------ASR{rag_threshold}: {USE_ASR}-----------------") 322 | print(f"---------------DET{clip_threshold}-{beta}: {USE_DET}-----------------") 323 | print(f"---------------Frames: {max_frames_num}-----------------") 324 | 325 | file_name = f"7B_DEV_MC_asr{USE_ASR}_ocr{USE_OCR}_ape{beta}{USE_DET}_{max_frames_num}frames_th{rag_threshold}_dep" 326 | file_path = os.path.join("restore", file_name) 327 | if not os.path.exists(file_path): 328 | os.mkdir(file_path) 329 | video_folder="MLVU/video" 330 | data_dir = f"MLVU/MLVU_Dev.json" 331 | dataset = MLVU(data_dir, video_folder) 332 | json_file = f"results/{file_name}.json" 333 | 334 | if os.path.exists(json_file): 335 | with open(json_file, 'r', encoding='utf-8') as file: 336 | res_list = json.load(file) 337 | total = len(res_list) 338 | for i in res_list: 339 | process_list.append(i['question_id']) 340 | 341 | for example in tqdm(dataset): 342 | 343 | if example["question_id"] in process_list: 344 | continue 345 | 346 | video_path = example["video"] 347 | frames, frame_time, video_time = process_video(video_path, max_frames_num, 1, force_sample=True) 348 | raw_video = [f for f in frames] 349 | 350 | video = image_processor.preprocess(frames, return_tensors="pt")["pixel_values"].cuda().bfloat16() 351 | video = [video] 352 | 353 | if USE_DET: 354 | video_tensor = [] 355 | for frame in raw_video: 356 | processed = clip_processor(images=frame, return_tensors="pt")["pixel_values"].to(clip_model.device, dtype=torch.float16) 357 | video_tensor.append(processed.squeeze(0)) 358 | video_tensor = torch.stack(video_tensor, dim=0) 359 | 360 | if USE_OCR: 361 | ocr_docs_total = get_ocr_docs(frames) 362 | 363 | if USE_ASR: 364 | if os.path.exists(os.path.join("audio", os.path.basename(video_path).split(".")[0] + ".txt")): 365 | with open(os.path.join("audio", os.path.basename(video_path).split(".")[0] + ".txt"), 'r', encoding='utf-8') as f: 366 | asr_docs_total = f.readlines() 367 | else: 368 | audio_path = os.path.join("audio", os.path.basename(video_path).split(".")[0] + ".wav") 369 | asr_docs_total = get_asr_docs(video_path, audio_path) 370 | with open(os.path.join("audio", os.path.basename(video_path).split(".")[0] + ".txt"), 'w', encoding='utf-8') as f: 371 | for doc in asr_docs_total: 372 | f.write(doc + '\n') 373 | 374 | # step 0: get cot information 375 | retrieve_pmt_0 = example["question"] 376 | retrieve_pmt_0 += "\nTo answer the question step by step, list all the physical entities related to the question you want to retrieve, you can provide your retrieve request to assist you by the following json format:" 377 | retrieve_pmt_0 += '''{ 378 | "ASR": Optional[str]. The subtitles of the video that may relavent to the question you want to retrieve, in two sentences. If you no need for this information, please return null. 379 | "DET": Optional[list]. (The output must include only physical entities, not abstract concepts, less than five entities) All the physical entities and their location related to the question you want to retrieve, not abstract concepts. If you no need for this information, please return null. 380 | "TYPE": Optional[list]. (The output must be specified as null or a list containing only one or more of the following strings: 'location', 'number', 'relation'. No other values are valid for this field) The information you want to obtain about the detected objects. If you need the object location in the video frame, output "location"; if you need the number of specific object, output "number"; if you need the positional relationship between objects, output "relation". 381 | } 382 | ## Example 1: 383 | Question: How many blue balloons are over the long table in the middle of the room at the end of this video? A. 1. B. 2. C. 3. D. 4. 384 | Your retrieve can be: 385 | { 386 | "ASR": "The location and the color of balloons, the number of the blue balloons.", 387 | "DET": ["blue ballons", "long table"], 388 | "TYPE": ["relation", "number"] 389 | } 390 | ## Example 2: 391 | Question: In the lower left corner of the video, what color is the woman wearing on the right side of the man in black clothes? A. Blue. B. White. C. Red. D. Yellow. 392 | Your retrieve can be: 393 | { 394 | "ASR": null, 395 | "DET": ["the man in black", "woman"], 396 | "TYPE": ["location", "relation"] 397 | } 398 | ## Example 3: 399 | Question: In which country is the comedy featured in the video recognized worldwide? A. China. B. UK. C. Germany. D. United States. 400 | Your retrieve can be: 401 | { 402 | "ASR": "The country recognized worldwide for its comedy.", 403 | "DET": null, 404 | "TYPE": null 405 | } 406 | Note that you don't need to answer the question in this step, so you don't need any infomation about the video of image. You only need to provide your retrieve request (it's optional), and I will help you retrieve the infomation you want. Please provide the json format.''' 407 | 408 | qs = "" 409 | if USE_OCR or USE_DET or USE_ASR: 410 | 411 | json_request = llava_inference(retrieve_pmt_0, None) 412 | 413 | # step 1: get docs information 414 | query = [example["question"]] 415 | 416 | # APE fetch 417 | if USE_DET: 418 | det_docs = [] 419 | try: 420 | request_det = json.loads(json_request)["DET"] 421 | request_det = filter_keywords(request_det) 422 | clip_text = ["A picture of " + txt for txt in request_det] 423 | if len(clip_text) == 0: 424 | clip_text = ["A picture of object"] 425 | except: 426 | request_det = None 427 | clip_text = ["A picture of object"] 428 | 429 | clip_inputs = clip_processor(text=clip_text, return_tensors="pt", padding=True, truncation=True).to(clip_model.device) 430 | clip_img_feats = clip_model.get_image_features(video_tensor) 431 | with torch.no_grad(): 432 | text_features = clip_model.get_text_features(**clip_inputs) 433 | similarities = (clip_img_feats @ text_features.T).squeeze(0).mean(1).cpu() 434 | similarities = np.array(similarities, dtype=np.float64) 435 | alpha = beta * (len(similarities) / 16) 436 | similarities = similarities * alpha / np.sum(similarities) 437 | 438 | del clip_inputs, clip_img_feats, text_features 439 | torch.cuda.empty_cache() 440 | 441 | det_top_idx = [idx for idx in range(max_frames_num) if similarities[idx] > clip_threshold] 442 | 443 | if request_det is not None and len(request_det) > 0: 444 | # process directly 445 | det_docs = get_det_docs(frames[det_top_idx], request_det, file_name) 446 | 447 | L, R, N = False, False, False 448 | try: 449 | det_retrieve_info = json.loads(json_request)["TYPE"] 450 | except: 451 | det_retrieve_info = None 452 | if det_retrieve_info is not None: 453 | if "location" in det_retrieve_info: 454 | L = True 455 | if "relation" in det_retrieve_info: 456 | R = True 457 | if "number" in det_retrieve_info: 458 | N = True 459 | det_docs = det_preprocess(det_docs, location=L, relation=R, number=N) # pre-process of APE information 460 | 461 | 462 | # OCR fetch 463 | if USE_OCR: 464 | try: 465 | request_det = json.loads(json_request)["DET"] 466 | request_det = filter_keywords(request_det) 467 | except: 468 | request_det = None 469 | ocr_docs = [] 470 | if len(ocr_docs_total) > 0: 471 | ocr_query = query.copy() 472 | if request_det is not None and len(request_det) > 0: 473 | ocr_query.extend(request_det) 474 | ocr_docs, _ = retrieve_documents_with_dynamic(ocr_docs_total, ocr_query, threshold=rag_threshold) 475 | 476 | # ASR fetch 477 | if USE_ASR: 478 | asr_docs = [] 479 | try: 480 | request_asr = json.loads(json_request)["ASR"] 481 | except: 482 | request_asr = None 483 | if len(asr_docs_total) > 0: 484 | asr_query = query.copy() 485 | if request_asr is not None: 486 | asr_query.append(request_asr) 487 | asr_docs, _ = retrieve_documents_with_dynamic(asr_docs_total, asr_query, threshold=rag_threshold) 488 | 489 | if USE_DET and len(det_docs) > 0: 490 | for i, info in enumerate(det_docs): 491 | if len(info) > 0: 492 | qs += f"Frame {str(det_top_idx[i]+1)}: " + info + "\n" 493 | if len(qs) > 0: 494 | qs = f"\nVideo have {str(max_frames_num)} frames in total, the detected objects' information in specific frames: " + qs 495 | if USE_ASR and len(asr_docs) > 0: 496 | qs += "\nVideo Automatic Speech Recognition information (given in chronological order of the video): " + " ".join(asr_docs) 497 | if USE_OCR and len(ocr_docs) > 0: 498 | qs += "\nVideo OCR information (given in chronological order of the video): " + "; ".join(ocr_docs) 499 | 500 | qs += "Select the best answer to the following multiple-choice question based on the video and the information (if given). Respond with only the letter (A, B, C, D, E or F) of the correct option." + example["question"] + '\nThe best answer is:' 501 | 502 | res = llava_inference(qs, video) 503 | start_chr = 'A' 504 | res_list.append({ 505 | 'question_id': example["question_id"], 506 | 'question_type': example['task_type'], 507 | 'option': res[0], 508 | 'answer': example["answer"] 509 | }) 510 | 511 | with open(json_file, "w", encoding='utf-8') as file: 512 | json.dump(res_list, file, ensure_ascii=False, indent=4) 513 | --------------------------------------------------------------------------------