├── .gitignore ├── LICENSE ├── README.md ├── configs ├── llava_video │ ├── llava-video_lvbench.yaml │ ├── llava-video_mlvu.yaml │ ├── llava-video_videomme.yaml │ ├── retake_llava-video_lvbench.yaml │ ├── retake_llava-video_mlvu.yaml │ └── retake_llava-video_videomme.yaml ├── qwen2_vl │ ├── qwen2-vl_lvbench.yaml │ ├── qwen2-vl_mlvu.yaml │ ├── qwen2-vl_videomme.yaml │ ├── retake_qwen2-vl_lvbench.yaml │ ├── retake_qwen2-vl_mlvu.yaml │ └── retake_qwen2-vl_videomme.yaml ├── retake_demo.yaml └── retake_demo_npu.yaml ├── demo.py ├── docs ├── prepare_lvbench.md ├── prepare_mlvu.md └── prepare_videomme.md ├── environment.yaml ├── environment_npu.yaml ├── misc ├── Q8AZ16uBhr8_resized_fps2_mute.mp4 └── overview.png ├── retake ├── dataset_utils.py ├── infer_eval.py ├── llava_onevision.py ├── longvideo_cache.py ├── monkeypatch.py ├── qwen2_vl.py └── visual_compression.py └── scripts ├── infer_eval_retake.sh └── utils ├── build_lvbench_dataset.py ├── build_mlvu_dataset.py ├── build_mlvu_test_dataset.py ├── build_videomme_dataset.py ├── cal_flops.py ├── cal_ttft.py ├── convert_llava_video_weights_to_hf.py └── frame_extraction.py /.gitignore: -------------------------------------------------------------------------------- 1 | /dataset 2 | /results 3 | */__pycache__ -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 SCZwangxiao 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # [ReTaKe: Reducing Temporal and Knowledge Redundancy for Long Video Understanding](https://arxiv.org/abs/2412.20504) 2 | 3 | ReTaKe is a novel approach for long video understanding that reduces temporal and knowledge redundancy, enabling MLLMs to process 8x longer video sequences (up to 2048 frames) under the same memory budget. 4 | 5 | --- 6 | 7 | ## 📢 Recent Updates 8 | - **2025/03/11**: Polish the paper, improve the readability of the methods section, and add more ablation studies and results for LongVideoBench. 9 | - **2025/02/01**: Support for the latest version of Transformers (v4.48). 10 | - **2025/01/29**: Added support for LLaVA-Video and LLaVA-OneVision. 
11 | 12 | --- 13 | 14 | ## 🚀 Key Contributions 15 | 16 | - **Training-Free Framework**: ReTaKe is the first method to jointly model temporal and knowledge redundancy for long video understanding, reducing the model sequence length to 1/4 of the original with a relative performance loss within 1%. 17 | 18 | - **Novel Techniques**: 19 | - **DPSelect**: A keyframe selection method to reduce low-level temporal redundancy. 20 | - **PivotKV**: A KV cache compression method to reduce high-level knowledge redundancy in long videos. 21 | 22 |

23 | Overview of ReTaKe 24 |
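The toy snippet below is only meant to convey the intuition behind the two techniques; it is **not** the repository implementation (DPSelect and PivotKV presumably live in `retake/visual_compression.py` and `retake/longvideo_cache.py`), and the tensor shapes and scoring heuristics here are illustrative assumptions.

```python
import torch
import torch.nn.functional as F


def select_keyframes(frame_feats: torch.Tensor, keep_ratio: float = 0.25) -> torch.Tensor:
    """DPSelect-style idea (toy): keep the frames that differ most from their
    temporal neighbour, dropping low-level temporal redundancy."""
    # frame_feats: (num_frames, dim) pooled visual features, one row per frame
    dissim = 1 - F.cosine_similarity(frame_feats[1:], frame_feats[:-1], dim=-1)
    dissim = torch.cat([dissim.new_ones(1), dissim])  # the first frame is always a candidate
    num_keep = max(1, int(keep_ratio * frame_feats.size(0)))
    return torch.topk(dissim, num_keep).indices.sort().values  # indices of kept frames


def prune_kv_cache(keys: torch.Tensor, values: torch.Tensor,
                   attn_mass: torch.Tensor, max_len: int):
    """PivotKV-style idea (toy): once the cache exceeds `max_len`, keep only the
    entries ("pivots") that accumulated the most attention."""
    # keys/values: (seq_len, head_dim); attn_mass: (seq_len,) accumulated attention per cached token
    if keys.size(0) <= max_len:
        return keys, values
    pivots = torch.topk(attn_mass, max_len).indices.sort().values
    return keys[pivots], values[pivots]


# Example: keep 16 of 64 frames, then cap a 4096-entry cache at 1024 entries
kept = select_keyframes(torch.randn(64, 256))
k, v = prune_kv_cache(torch.randn(4096, 128), torch.randn(4096, 128), torch.rand(4096), 1024)
```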

25 | 26 | --- 27 | 28 | ## ⚙️ Environment Setup 29 | 30 | ### For GPU Users: 31 | ```bash 32 | conda env create -f environment.yaml 33 | ``` 34 | 35 | ### For NPU Users: 36 | ```bash 37 | conda env create -f environment_npu.yaml 38 | ``` 39 | 40 | ### Additional Dependencies: 41 | ```bash 42 | apt-get install ffmpeg # Required for full functionality; quick demo does not require ffmpeg. 43 | ``` 44 | 45 | --- 46 | 47 | ## 🖥️ Quick Demo 48 | 49 | ### Step 1: Update Configuration 50 | Modify the `hf_qwen2vl7b_path` in `./demo.py` to point to your local path for `Qwen2-VL-7B-Instruct`. 51 | For NPU users, also update `config_path` to `'configs/retake_demo_npu.yaml'`. 52 | 53 | ### Step 2 (Optional for LLaVA-Video): Convert Model 54 | ```bash 55 | # Convert LLaVA-Video model into Hugging Face format 56 | # Ensure the following models are downloaded: Qwen2-7B-Instruct, siglip-so400m-patch14-384, and LLaVAVideoQwen2_7B. 57 | python scripts/utils/convert_llava_video_weights_to_hf.py \ 58 | --text_model_id /path_to/Qwen2-7B-Instruct \ 59 | --vision_model_id /path_to/siglip-so400m-patch14-384 \ 60 | --output_hub_path /path_to/llava-video-qwen2-7b-hf \ 61 | --old_state_dict_id /path_to/LLaVAVideoQwen2_7B 62 | ``` 63 | 64 | ### Step 3: Run the Demo 65 | ```bash 66 | python demo.py 67 | ``` 68 | 69 | --- 70 | 71 | ## 📊 Reproducing ReTaKe Results 72 | 73 | ### Step 1: Prepare Datasets 74 | Follow the documentation to prepare the required datasets: 75 | - [VideoMME](docs/prepare_videomme.md) 76 | - [MLVU](docs/prepare_mlvu.md) 77 | - [LVBench](docs/prepare_lvbench.md) 78 | 79 | ### Step 2: Run Inference and Evaluation 80 | Use the provided script to perform inference and evaluation: 81 | ```bash 82 | bash scripts/infer_eval_retake.sh ${YOUR_PATH_TO_Qwen2-VL-7B-Instruct} configs/qwen2_vl/retake_qwen2-vl_videomme.yaml 8 83 | bash scripts/infer_eval_retake.sh ${YOUR_PATH_TO_Qwen2-VL-7B-Instruct} configs/qwen2_vl/retake_qwen2-vl_mlvu.yaml 8 84 | bash scripts/infer_eval_retake.sh ${YOUR_PATH_TO_Qwen2-VL-7B-Instruct} configs/qwen2_vl/retake_qwen2-vl_lvbench.yaml 8 85 | ``` 86 | 87 | - Results will be saved in the `./results` directory. 
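- The trailing argument (`8` above) is presumably the number of GPUs/processes to shard inference over; adjust it to your hardware.
- Each config also fixes its own sub-directory of `./results` via `output_dir`. As a quick sanity check before launching (a minimal sketch using the `pyyaml` dependency already listed in `environment.yaml`):

```python
import yaml

with open("configs/qwen2_vl/retake_qwen2-vl_videomme.yaml") as f:
    cfg = yaml.safe_load(f)

print(cfg["output_dir"])                         # results/qwen2vl_7b_video_mme_f2048_4fps_r448/retake_pivot-32k
print(cfg["max_num_frames"], cfg["sample_fps"])  # 2048 4
```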
88 | 89 | --- 90 | 91 | ## 📚 Citation 92 | If you find this work helpful, please consider citing: 93 | ```bibtex 94 | @misc{xiao_retake_2024, 95 | author = {Xiao Wang and 96 | Qingyi Si and 97 | Jianlong Wu and 98 | Shiyu Zhu and 99 | Li Cao and 100 | Liqiang Nie}, 101 | title = {{ReTaKe}: {Reducing} {Temporal} and {Knowledge} {Redundancy} for {Long} {Video} {Understanding}}, 102 | year = {2024}, 103 | note = {arXiv:2412.20504 [cs]} 104 | } 105 | ``` 106 | -------------------------------------------------------------------------------- /configs/llava_video/llava-video_lvbench.yaml: -------------------------------------------------------------------------------- 1 | ### model 2 | model_name: llava_video 3 | method: retake 4 | attn_implementation: "flash_attention_2" 5 | 6 | ### dataset 7 | dataset_name: lvbench 8 | anno_file: dataset/lvbench/lvbench.json 9 | dataloader_num_workers: 4 10 | 11 | ### data 12 | sample_fps: 2 13 | max_num_frames: 64 14 | longsize_resolution: 682 # short-side can be 384 15 | 16 | ### generate 17 | do_sample: false 18 | 19 | ### output 20 | output_dir: results/llava-video_lvbench_f64_2fps_r682/base 21 | -------------------------------------------------------------------------------- /configs/llava_video/llava-video_mlvu.yaml: -------------------------------------------------------------------------------- 1 | ### model 2 | model_name: llava_video 3 | method: retake 4 | attn_implementation: "flash_attention_2" 5 | 6 | ### dataset 7 | dataset_name: mlvu 8 | anno_file: dataset/mlvu/mlvu.json 9 | dataloader_num_workers: 4 10 | 11 | ### data 12 | sample_fps: 2 13 | max_num_frames: 64 14 | longsize_resolution: 682 # short-side can be 384 15 | 16 | ### generate 17 | do_sample: false 18 | 19 | ### output 20 | output_dir: results/llava-video_mlvu_f64_2fps_r682/base 21 | -------------------------------------------------------------------------------- /configs/llava_video/llava-video_videomme.yaml: -------------------------------------------------------------------------------- 1 | ### model 2 | model_name: llava_video 3 | method: retake 4 | attn_implementation: "flash_attention_2" 5 | 6 | ### dataset 7 | dataset_name: videomme 8 | anno_file: dataset/video_mme/video_mme.json 9 | dataloader_num_workers: 4 10 | 11 | ### data 12 | sample_fps: 2 13 | max_num_frames: 64 14 | longsize_resolution: 682 # short-side can be 384 15 | 16 | ### generate 17 | do_sample: false 18 | 19 | ### output 20 | output_dir: results/llava-video_video_mme_f64_2fps_r682/base 21 | -------------------------------------------------------------------------------- /configs/llava_video/retake_llava-video_lvbench.yaml: -------------------------------------------------------------------------------- 1 | ### model 2 | model_name: llava_video 3 | method: retake 4 | scaling_factor: 4 5 | attn_implementation: "flash_attention_2" 6 | longvideo_kwargs: { 7 | 'frame_chunk_size': 32, 8 | 'chunked_prefill_frames': 32, 9 | # Keyframe compression 10 | 'visual_compression': True, 11 | 'visual_compression_kwargs': { 12 | 'compression_ratio': 1.0, 13 | 'compression_method': 'Keyframe', 14 | 'patch_sync': False, 15 | 'return_keyframe_mask': True 16 | }, 17 | # KVCache compression 18 | 'kvcache_compression': True, 19 | 'kvcache_compression_kwargs': { 20 | 'dynamic_compression_ratio': True, 21 | 'compression_method': 'pivotkv', 22 | 'pos_embed_reforge': True, 23 | 'max_input_length': 40000 24 | }, 25 | } 26 | 27 | ### dataset 28 | dataset_name: lvbench 29 | anno_file: dataset/lvbench/lvbench.json 30 | dataloader_num_workers: 4 
31 | 32 | ### data 33 | sample_fps: 2 34 | max_num_frames: 1024 35 | longsize_resolution: 682 36 | 37 | ### generate 38 | do_sample: false 39 | 40 | ### output 41 | output_dir: results/llava-video_f1024_2fps_r682/retake_dp1-async_pivot-40k 42 | -------------------------------------------------------------------------------- /configs/llava_video/retake_llava-video_mlvu.yaml: -------------------------------------------------------------------------------- 1 | ### model 2 | model_name: llava_video 3 | method: retake 4 | scaling_factor: 4 5 | attn_implementation: "flash_attention_2" 6 | longvideo_kwargs: { 7 | 'frame_chunk_size': 32, 8 | 'chunked_prefill_frames': 32, 9 | # Keyframe compression 10 | 'visual_compression': True, 11 | 'visual_compression_kwargs': { 12 | 'compression_ratio': 1.0, 13 | 'compression_method': 'Keyframe', 14 | 'patch_sync': False, 15 | 'return_keyframe_mask': True 16 | }, 17 | # KVCache compression 18 | 'kvcache_compression': True, 19 | 'kvcache_compression_kwargs': { 20 | 'dynamic_compression_ratio': True, 21 | 'compression_method': 'pivotkv', 22 | 'pos_embed_reforge': True, 23 | 'max_input_length': 40000 24 | }, 25 | } 26 | 27 | ### dataset 28 | dataset_name: mlvu 29 | anno_file: dataset/mlvu/mlvu.json 30 | dataloader_num_workers: 4 31 | 32 | ### data 33 | sample_fps: 2 34 | max_num_frames: 1024 35 | longsize_resolution: 682 36 | 37 | ### generate 38 | do_sample: false 39 | 40 | ### output 41 | output_dir: results/llava-video_rope4_mlvu_f1024_2fps_r682/retake_dp1-async_pivot-40k 42 | -------------------------------------------------------------------------------- /configs/llava_video/retake_llava-video_videomme.yaml: -------------------------------------------------------------------------------- 1 | ### model 2 | model_name: llava_video 3 | method: retake 4 | scaling_factor: 4 5 | attn_implementation: "flash_attention_2" 6 | longvideo_kwargs: { 7 | 'frame_chunk_size': 32, 8 | 'chunked_prefill_frames': 32, 9 | # Keyframe compression 10 | 'visual_compression': True, 11 | 'visual_compression_kwargs': { 12 | 'compression_ratio': 1.0, 13 | 'compression_method': 'Keyframe', 14 | 'patch_sync': False, 15 | 'return_keyframe_mask': True 16 | }, 17 | # KVCache compression 18 | 'kvcache_compression': True, 19 | 'kvcache_compression_kwargs': { 20 | 'dynamic_compression_ratio': True, 21 | 'compression_method': 'pivotkv', 22 | 'pos_embed_reforge': True, 23 | 'max_input_length': 40000 24 | }, 25 | } 26 | 27 | ### dataset 28 | dataset_name: videomme 29 | anno_file: dataset/video_mme/video_mme.json 30 | dataloader_num_workers: 4 31 | 32 | ### data 33 | sample_fps: 2 34 | max_num_frames: 1024 35 | longsize_resolution: 682 36 | 37 | ### generate 38 | do_sample: false 39 | 40 | ### output 41 | output_dir: results/llava-video_rope4_video_mme_f1024_2fps_r682/retake_dp1-async_pivot-40k 42 | -------------------------------------------------------------------------------- /configs/qwen2_vl/qwen2-vl_lvbench.yaml: -------------------------------------------------------------------------------- 1 | ### model 2 | model_name: qwen2_vl 3 | method: retake 4 | scaling_factor: 4 5 | attn_implementation: "flash_attention_2" 6 | 7 | ### dataset 8 | dataset_name: lvbench 9 | anno_file: dataset/lvbench/lvbench.json 10 | dataloader_num_workers: 2 11 | 12 | ### data 13 | sample_fps: 2 14 | max_num_frames: 256 15 | longsize_resolution: 448 16 | 17 | ### generate 18 | do_sample: false 19 | 20 | ### output 21 | output_dir: results/qwen2vl_7b_lvbench_f256_2fps_r448/base 22 | 
-------------------------------------------------------------------------------- /configs/qwen2_vl/qwen2-vl_mlvu.yaml: -------------------------------------------------------------------------------- 1 | ### model 2 | model_name: qwen2_vl 3 | method: retake 4 | scaling_factor: 4 5 | attn_implementation: "flash_attention_2" 6 | 7 | ### dataset 8 | dataset_name: mlvu 9 | anno_file: dataset/mlvu/mlvu.json 10 | dataloader_num_workers: 2 11 | 12 | ### data 13 | sample_fps: 4 14 | max_num_frames: 256 15 | longsize_resolution: 448 16 | 17 | ### generate 18 | do_sample: false 19 | 20 | ### output 21 | output_dir: results/qwen2vl_7b_mlvu_f256_4fps_r448/base 22 | -------------------------------------------------------------------------------- /configs/qwen2_vl/qwen2-vl_videomme.yaml: -------------------------------------------------------------------------------- 1 | ### model 2 | model_name: qwen2_vl 3 | method: retake 4 | scaling_factor: 4 5 | attn_implementation: "flash_attention_2" 6 | 7 | ### dataset 8 | dataset_name: videomme 9 | anno_file: dataset/video_mme/video_mme.json 10 | dataloader_num_workers: 2 11 | 12 | ### data 13 | sample_fps: 4 14 | max_num_frames: 256 15 | longsize_resolution: 448 16 | 17 | ### generate 18 | do_sample: false 19 | 20 | ### output 21 | output_dir: results/qwen2vl_7b_videomme_f256_4fps_r448/base 22 | -------------------------------------------------------------------------------- /configs/qwen2_vl/retake_qwen2-vl_lvbench.yaml: -------------------------------------------------------------------------------- 1 | ### model 2 | model_name: qwen2_vl 3 | method: retake 4 | scaling_factor: 4 5 | attn_implementation: "flash_attention_2" 6 | longvideo_kwargs: { 7 | 'frame_chunk_size': 128, 8 | 'chunked_prefill_frames': 32, 9 | # Keyframe compression 10 | 'visual_compression': True, 11 | 'visual_compression_kwargs': { 12 | 'compression_ratio': 1.0, 13 | 'compression_method': 'Keyframe', 14 | 'patch_sync': False, 15 | 'return_keyframe_mask': True 16 | }, 17 | # KVCache compression 18 | 'kvcache_compression': True, 19 | 'kvcache_compression_kwargs': { 20 | 'dynamic_compression_ratio': True, 21 | 'compression_method': 'pivotkv', 22 | 'pos_embed_reforge': True, 23 | 'max_input_length': 32000 24 | }, 25 | } 26 | 27 | ### dataset 28 | dataset_name: lvbench 29 | anno_file: dataset/lvbench/lvbench.json 30 | dataloader_num_workers: 2 31 | 32 | ### data 33 | sample_fps: 2 34 | max_num_frames: 2048 35 | longsize_resolution: 448 36 | 37 | ### generate 38 | do_sample: false 39 | 40 | ### output 41 | output_dir: results/qwen2vl_7b_lvbench_f2048_2fps_r448/retake_dp1-async_pivot-32k 42 | -------------------------------------------------------------------------------- /configs/qwen2_vl/retake_qwen2-vl_mlvu.yaml: -------------------------------------------------------------------------------- 1 | ### model 2 | model_name: qwen2_vl 3 | method: retake 4 | scaling_factor: 4 5 | attn_implementation: "flash_attention_2" 6 | longvideo_kwargs: { 7 | 'frame_chunk_size': 128, 8 | 'chunked_prefill_frames': 32, 9 | # Keyframe compression 10 | 'visual_compression': True, 11 | 'visual_compression_kwargs': { 12 | 'compression_ratio': 1.0, 13 | 'compression_method': 'Keyframe', 14 | 'patch_sync': False, 15 | 'return_keyframe_mask': True 16 | }, 17 | # KVCache compression 18 | 'kvcache_compression': True, 19 | 'kvcache_compression_kwargs': { 20 | 'dynamic_compression_ratio': True, 21 | 'compression_method': 'pivotkv', 22 | 'pos_embed_reforge': True, 23 | 'max_input_length': 32000 24 | }, 25 | } 26 | 27 | ### 
dataset 28 | dataset_name: mlvu 29 | anno_file: dataset/mlvu/mlvu.json 30 | dataloader_num_workers: 2 31 | 32 | ### data 33 | sample_fps: 4 34 | max_num_frames: 2048 35 | longsize_resolution: 448 36 | 37 | ### generate 38 | do_sample: false 39 | 40 | ### output 41 | output_dir: results/qwen2vl_7b_mlvu_f2048_4fps_r448/retake_dp1-async_pivot-32k 42 | -------------------------------------------------------------------------------- /configs/qwen2_vl/retake_qwen2-vl_videomme.yaml: -------------------------------------------------------------------------------- 1 | ### model 2 | model_name: qwen2_vl 3 | method: retake 4 | scaling_factor: 4 5 | attn_implementation: "flash_attention_2" 6 | longvideo_kwargs: { 7 | 'frame_chunk_size': 128, 8 | 'chunked_prefill_frames': 32, 9 | # KVCache compression 10 | 'kvcache_compression': True, 11 | 'kvcache_compression_kwargs': { 12 | 'dynamic_compression_ratio': True, 13 | 'compression_method': 'pivotkv', 14 | 'pos_embed_reforge': True, 15 | 'max_input_length': 32000 16 | }, 17 | } 18 | 19 | 20 | ### dataset 21 | dataset_name: videomme 22 | anno_file: dataset/video_mme/video_mme.json 23 | dataloader_num_workers: 2 24 | 25 | ### data 26 | sample_fps: 4 27 | max_num_frames: 2048 28 | longsize_resolution: 448 29 | 30 | ### generate 31 | do_sample: false 32 | 33 | ### output 34 | output_dir: results/qwen2vl_7b_video_mme_f2048_4fps_r448/retake_pivot-32k 35 | -------------------------------------------------------------------------------- /configs/retake_demo.yaml: -------------------------------------------------------------------------------- 1 | ### model 2 | method: retake 3 | scaling_factor: 4 4 | attn_implementation: "flash_attention_2" 5 | longvideo_kwargs: { 6 | 'frame_chunk_size': 128, 7 | 'chunked_prefill_frames': 32, 8 | # Keyframe compression 9 | 'visual_compression': True, 10 | 'visual_compression_kwargs': { 11 | 'compression_ratio': 1.0, 12 | 'compression_method': 'Keyframe', 13 | 'patch_sync': False, 14 | 'return_keyframe_mask': True 15 | }, 16 | # KVCache compression 17 | 'kvcache_compression': True, 18 | 'kvcache_compression_kwargs': { 19 | 'dynamic_compression_ratio': True, 20 | 'compression_method': 'pivotkv', 21 | 'pos_embed_reforge': True, 22 | 'max_input_length': 32000 23 | }, 24 | } 25 | 26 | ### data 27 | sample_fps: 4 28 | max_num_frames: 2048 29 | longsize_resolution: 448 30 | 31 | ### generate 32 | do_sample: false -------------------------------------------------------------------------------- /configs/retake_demo_npu.yaml: -------------------------------------------------------------------------------- 1 | ### model 2 | method: retake 3 | scaling_factor: 4 4 | attn_implementation: "eager" # NPU does not support sdpa attention now 5 | longvideo_kwargs: { 6 | 'frame_chunk_size': 16, # Trade-off beteen peak memory and speed 7 | 'chunked_prefill_frames': 16, # Trade-off beteen peak memory and speed 8 | # Keyframe compression 9 | 'visual_compression': True, 10 | 'visual_compression_kwargs': { 11 | 'compression_ratio': 1.0, 12 | 'compression_method': 'Keyframe', 13 | 'patch_sync': False, 14 | 'return_keyframe_mask': True 15 | }, 16 | # KVCache compression 17 | 'kvcache_compression': True, 18 | 'kvcache_compression_kwargs': { 19 | 'dynamic_compression_ratio': True, 20 | 'compression_method': 'pivotkv', 21 | 'pos_embed_reforge': True, 22 | 'max_input_length': 32000 23 | }, 24 | } 25 | 26 | ### data 27 | sample_fps: 4 28 | max_num_frames: 2048 29 | longsize_resolution: 448 30 | 31 | ### generate 32 | do_sample: false 
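All `retake_*` configs above share the same `longvideo_kwargs` layout, which (by the key names) chunks the visual encoding (`frame_chunk_size`), prefills a few frames at a time (`chunked_prefill_frames`), and prunes the KV cache with PivotKV once it passes `max_input_length` tokens. The loop below is only a schematic of that budget-keeping control flow: a plain list stands in for the cache, simple truncation stands in for PivotKV's attention-based scoring, and `tokens_per_frame` is an arbitrary illustrative number. The real logic presumably lives in `retake/longvideo_cache.py`.

```python
def prefill_in_chunks(num_frames=2048, tokens_per_frame=100,
                      chunk_frames=32, max_cache_tokens=32000):
    cache = []                                   # stands in for past_key_values
    next_token = 0
    for start in range(0, num_frames, chunk_frames):
        n_new = min(chunk_frames, num_frames - start) * tokens_per_frame
        cache.extend(range(next_token, next_token + n_new))  # "prefill" one chunk of frames
        next_token += n_new
        if len(cache) > max_cache_tokens:        # budget exceeded -> compress the cache
            cache = cache[-max_cache_tokens:]    # placeholder for PivotKV pruning
    return cache


print(len(prefill_in_chunks()))  # 32000 -- peak cache size stays bounded even for 2048 frames
```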
-------------------------------------------------------------------------------- /demo.py: -------------------------------------------------------------------------------- 1 | import os 2 | import cv2 3 | import math 4 | import yaml 5 | from PIL import Image 6 | from typing import List, Union 7 | 8 | import torch 9 | import numpy as np 10 | from torchvision.transforms.functional import pil_to_tensor 11 | from transformers import AutoProcessor 12 | 13 | import retake 14 | 15 | 16 | def get_frame_indices(total_frames, max_num_frames, sample_fps, extraction_fps): 17 | # Get number of sampled frames 18 | sample_frames = float(total_frames / extraction_fps) * sample_fps 19 | sample_frames = min(total_frames, max_num_frames, sample_frames) 20 | sample_frames = math.floor(sample_frames) 21 | sample_frames = int(sample_frames / 2) * 2 22 | # Get sampled frame indices 23 | frame_indices = np.linspace(0, total_frames - 1, sample_frames).astype(np.int32) 24 | return frame_indices 25 | 26 | 27 | def load_specific_frames(cap, frame_indices): 28 | # List to store the frames 29 | frames = [] 30 | # Read frames from the video 31 | for frame_index in frame_indices: 32 | # Set the video position to the desired frame index 33 | cap.set(cv2.CAP_PROP_POS_FRAMES, frame_index) 34 | # Read the frame 35 | ret, frame = cap.read() 36 | # If the frame was read successfully, append it to the list 37 | if ret: 38 | # Convert the frame from BGR to RGB 39 | frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) 40 | # Create a PIL Image from the frame 41 | frame = Image.fromarray(frame_rgb) 42 | frames.append(frame) 43 | else: 44 | ValueError(f"Warning: Could not read frame at index {frame_index}. It may be out of range.") 45 | return frames 46 | 47 | 48 | def load_video(video_path: str, max_num_frames: int, fps: Union[int, float]=None, frame_extraction_fps: Union[int, float]=None): 49 | """Load video frames at fps. If total frames larger than `max_num_frames`, do downsample. 50 | If 'fps' is `None`, load uniformly sample `max_num_frames` frames. 51 | 52 | video_path: Should either be a videofile or a directory of extracted frames. 53 | 54 | # NOTE: The extract frames must have name pattern of `%06d.(ext)`, or the loaded frame order will be wrong. 55 | """ 56 | if video_path.startswith("file://"): 57 | video_path = video_path[7:] 58 | if os.path.isdir(video_path): # directory extracted frames 59 | assert frame_extraction_fps is not None 60 | pass 61 | else: # filename of a video 62 | # Open the video file 63 | cap = cv2.VideoCapture(video_path) 64 | if not cap.isOpened(): 65 | raise ValueError("Error: Could not open video.") 66 | total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT)) 67 | frame_extraction_fps = cap.get(cv2.CAP_PROP_FPS) 68 | # Get indices of sampled frame 69 | frame_indices = get_frame_indices(total_frames, max_num_frames, fps, frame_extraction_fps) 70 | # Get frames 71 | frames = load_specific_frames(cap, frame_indices) 72 | # Release the video capture object 73 | cap.release() 74 | 75 | # Convert into RGB format 76 | frames = [ 77 | frame.convert("RGB") if frame.mode != "RGB" else frame 78 | for frame in frames 79 | ] 80 | 81 | return frames 82 | 83 | 84 | def resize_image_longside(image, image_resolution): 85 | r""" 86 | Pre-processes a single image. 
87 | """ 88 | if max(image.width, image.height) > image_resolution: 89 | resize_factor = image_resolution / max(image.width, image.height) 90 | width, height = int(image.width * resize_factor), int(image.height * resize_factor) 91 | image = image.resize((width, height), resample=Image.NEAREST) 92 | 93 | return image 94 | 95 | 96 | def resize_video_longside(frames: List, video_resolution): 97 | """ 98 | frames: list of PIL images. 99 | """ 100 | frames = [ 101 | resize_image_longside(frame, video_resolution) 102 | for frame in frames 103 | ] 104 | return frames 105 | 106 | 107 | def load_yaml(file_path): 108 | with open(file_path, 'r') as file: 109 | data = yaml.safe_load(file) 110 | return data 111 | 112 | 113 | def fetch_video(video_info, max_num_frames, sample_fps, longsize_resolution): 114 | frames = load_video(video_info['video'], max_num_frames, sample_fps, video_info.get('frame_extraction_fps', None)) 115 | frames = resize_video_longside(frames, longsize_resolution) 116 | frames = [pil_to_tensor(frame) for frame in frames] 117 | return frames 118 | 119 | 120 | def load_and_patch_model(model_name, hf_model_path, exp_configs, device): 121 | model_name = model_name if model_name is not None else exp_configs['model_name'] 122 | model_name = model_name.lower().replace('-', '').replace('_', '') 123 | if model_name == 'qwen2vl': # QWen2VL 124 | from transformers import Qwen2VLConfig, Qwen2VLForConditionalGeneration 125 | from retake.monkeypatch import patch_qwen2vl, patch_qwen2vl_config 126 | retake.qwen2_vl.DEBUG_MODE = True 127 | patch_qwen2vl(exp_configs['method']) # Replace some functions of QWen2VL with those from ReTaKe 128 | qwen2vl_config = Qwen2VLConfig.from_pretrained(hf_model_path) 129 | qwen2vl_config = patch_qwen2vl_config(qwen2vl_config, exp_configs) 130 | model = Qwen2VLForConditionalGeneration.from_pretrained( 131 | hf_model_path, 132 | config=qwen2vl_config, 133 | torch_dtype=torch.bfloat16, 134 | attn_implementation=exp_configs.get('attn_implementation', None), 135 | device_map=device # "auto" 136 | ).eval() 137 | processor = AutoProcessor.from_pretrained(hf_model_path) 138 | elif model_name in ['llavaonevision', 'llavavideo']: # LLaVA-OneVision, LLaVA-Video 139 | from transformers import LlavaOnevisionConfig, LlavaOnevisionForConditionalGeneration 140 | from retake.monkeypatch import patch_llava_onevision, patch_llava_onevision_config 141 | retake.llava_onevision.DEBUG_MODE = True 142 | patch_llava_onevision(exp_configs['method']) # Replace some functions of LLaVA-Video with those from ReTaKe 143 | llava_onevision_config = LlavaOnevisionConfig.from_pretrained(hf_model_path) 144 | llava_onevision_config = patch_llava_onevision_config(llava_onevision_config, exp_configs) 145 | processor = AutoProcessor.from_pretrained(hf_model_path) 146 | model = LlavaOnevisionForConditionalGeneration.from_pretrained( 147 | hf_model_path, 148 | config=llava_onevision_config, 149 | torch_dtype=torch.bfloat16, 150 | attn_implementation=exp_configs.get('attn_implementation', None), 151 | device_map=device # "auto" 152 | ) 153 | else: 154 | raise NotImplementedError 155 | return model, processor 156 | 157 | 158 | DEMO_VIDEO = 'misc/Q8AZ16uBhr8_resized_fps2_mute.mp4' 159 | DEMO_QUESTIONS = [ 160 | "As depicted in the video, how is the relationship between the rabbit and human?\nOptions:\nA. Hostile.\nB. Friend.\nC. Cooperator.\nD. No one is correct above.\nAnswer with the option's letter from the given choices directly.", 161 | "What is the impression of the video?\nOptions:\nA. Sad.\nB. 
Funny.\nC. Horrible.\nD. Silent.\nAnswer with the option's letter from the given choices directly.", 162 | "What is the subject of the video?\nOptions:\nA. Rabbit likes to eat carrots.\nB. How to raise a rabbit.\nC. A rabbit gives people trouble.\nD. A rabbit performs for food.\nAnswer with the option's letter from the given choices directly.", 163 | ] 164 | EXPECTED_ANSWERS = ['A', 'B', 'C'] 165 | 166 | 167 | if __name__ == "__main__": 168 | #------------------- Modify the following configs ------------------# 169 | hf_model_path = 'Qwen/Qwen2-VL-7B-Instruct' # TODO: replace to local path if you have trouble downloading huggingface models 170 | model_name = 'qwen2_vl' 171 | # hf_model_path = '/path_to/llava-video-qwen2-7b-hf' 172 | # model_name = 'llava_video' 173 | # hf_model_path = 'llava-hf/llava-onevision-qwen2-7b-ov-hf' 174 | # model_name = 'llava_onevision' 175 | 176 | # NOTE: for Nvidia GPUs 177 | config_path = 'configs/retake_demo.yaml' 178 | device = 'cuda:0' 179 | 180 | # NOTE: for NPUs or GPUs without support for FlashAttention 181 | # config_path = 'configs/retake_demo_npu.yaml' 182 | # device = 'npu:0' 183 | 184 | #------------------------ No need to change ------------------------# 185 | video_info = {"type": "video", 186 | "video": DEMO_VIDEO, 187 | "fps": 2.0} 188 | 189 | exp_configs = load_yaml(config_path) 190 | 191 | model, processor = load_and_patch_model(model_name, hf_model_path, exp_configs, device) 192 | 193 | # Video 194 | video = fetch_video(video_info, exp_configs['max_num_frames'], exp_configs['sample_fps'], exp_configs['longsize_resolution']) 195 | for question, expect_answer in zip(DEMO_QUESTIONS, EXPECTED_ANSWERS): 196 | conversation = [ 197 | { 198 | "role": "user", 199 | "content": [ 200 | {"type": "video"}, 201 | {"type": "text", "text": question}, 202 | ], 203 | } 204 | ] 205 | 206 | # Preprocess the inputs 207 | text_prompt = processor.apply_chat_template(conversation, add_generation_prompt=True) 208 | print('Input prompt:\n', text_prompt) 209 | 210 | inputs = processor(text=[text_prompt], videos=[video], padding=True, return_tensors="pt") 211 | inputs = inputs.to(device) 212 | inputs['pixel_values_videos'] = inputs['pixel_values_videos'].to(torch.bfloat16) 213 | 214 | # Inference: Generation of the output 215 | output_ids = model.generate(**inputs, do_sample=False, max_new_tokens=128) 216 | generated_ids = [output_ids[len(input_ids):] for input_ids, output_ids in zip(inputs.input_ids, output_ids)] 217 | output_text = processor.batch_decode(generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True) 218 | output_text = output_text[0] 219 | print('Output text:\n', output_text) 220 | print('Expected answer:\n', expect_answer) 221 | -------------------------------------------------------------------------------- /docs/prepare_lvbench.md: -------------------------------------------------------------------------------- 1 | ## Prepare LVBench Dataset 2 | 3 | 4 | ### Step 1: download LVBench data from [huggingface](https://huggingface.co/datasets/THUDM/LVBench/tree/main) 5 | ```bash 6 | git clone https://huggingface.co/datasets/THUDM/LVBench # Contain annotations only 7 | git clone https://huggingface.co/datasets/AIWinter/LVBench # Contain videos only 8 | ``` 9 | Move all_files in `AIWinter/LVBench` into `THUDM/LVBench`. 
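Both clone commands default to a local directory named `LVBench`, so one way to do the move (directory names below are only an example, assuming the `AIWinter` repo holds the `all_videos_split.zip.*` parts listed next) is:
```bash
git clone https://huggingface.co/datasets/THUDM/LVBench lvbench            # annotations
git clone https://huggingface.co/datasets/AIWinter/LVBench lvbench_videos  # videos
mv lvbench_videos/all_videos_split.zip.* lvbench/
```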
10 | 11 | Denote the root directory of download LVBench dataset as `lvbench_root`, it should has the following structure: 12 | ``` 13 | ${lvbench_root}/ 14 | ├── docs/ 15 | ├── video_info.meta.jsonl 16 | ├── all_videos_split.zip.001 17 | ├── all_videos_split.zip.002 18 | ├── ... 19 | └── all_videos_split.zip.014 20 | ``` 21 | 22 | 23 | ### Step 2: Unzip everything 24 | ```bash 25 | cd ${lvbench_root} 26 | cat all_videos_split.zip.* > all_videos.zip 27 | unzip all_videos.zip 28 | ``` 29 | 30 | 31 | ### Step 3: Extract frames of all videos 32 | ```bash 33 | cd ${retake_repo_root} 34 | python scripts/utils/frame_extraction.py \ 35 | --videofile_tpl ${lvbench_root}/all_videos/'*.mp4' \ 36 | --results_dir ${lvbench_root}/video_25fps \ 37 | --fps 25 \ 38 | --num_workers 32 39 | ``` 40 | 41 | 42 | ### Step 4: Build LVBench dataset 43 | ```bash 44 | cd ${retake_repo_root} 45 | python scripts/utils/build_lvbench_dataset.py --hf_root ${lvbench_root} 46 | ``` 47 | Note that you can NOT modify folder `${lvbench_root}/video_25fps` after this step, since the absolute path of extracted frames are written into annotation files `lvbench.json`: 48 | ``` 49 | retake_repo_root/ 50 | ├── dataset/ 51 | ├── lvbench/ 52 | ├── lvbench.json 53 | ├── ... 54 | ``` -------------------------------------------------------------------------------- /docs/prepare_mlvu.md: -------------------------------------------------------------------------------- 1 | ## Prepare MLVU Dataset 2 | 3 | 4 | ### Step 1: download MLVU dataset from [huggingface](https://huggingface.co/datasets/MLVU/MVLU) 5 | ```bash 6 | git clone https://huggingface.co/datasets/MLVU/MVLU 7 | ``` 8 | 9 | Denote the root directory of download MLVU dataset as `mlvu_root`, it should has the following structure: 10 | ``` 11 | ${mlvu_root}/ 12 | ├── MLVU/ 13 | ├── json 14 | ... 15 | ├── video 16 | ... 17 | ├── figs/ 18 | ``` 19 | 20 | 21 | ### Step 2: Extract frames of all videos 22 | ```bash 23 | cd ${retake_repo_root} 24 | python scripts/utils/frame_extraction.py \ 25 | --videofile_tpl ${mlvu_root}/MLVU/video/'*/*.mp4' \ 26 | --results_dir ${mlvu_root}/MLVU/video_25fps \ 27 | --fps 25 \ 28 | --num_workers 32 29 | ``` 30 | 31 | 32 | ### Step 3: Build MLVU dataset 33 | ```bash 34 | cd ${retake_repo_root} 35 | python scripts/utils/build_mlvu_dataset.py --hf_root ${mlvu_root} 36 | ``` 37 | Note that you can NOT modify folder `${mlvu_root}/MLVU/video_25fps` after this step, since the absolute path of extracted frames are written into annotation files `mlvu.json`: 38 | ``` 39 | retake_repo_root/ 40 | ├── dataset/ 41 | ├── mlvu/ 42 | ├── mlvu.json 43 | ├── ... 44 | ``` -------------------------------------------------------------------------------- /docs/prepare_videomme.md: -------------------------------------------------------------------------------- 1 | ## Prepare VideoMME Dataset 2 | 3 | 4 | ### Step 1: download VideoMME dataset from [huggingface](https://huggingface.co/datasets/lmms-lab/Video-MME) 5 | ```bash 6 | git clone https://huggingface.co/datasets/lmms-lab/Video-MME 7 | ``` 8 | 9 | Denote the root directory of download VideoMME dataset as `videomme_root`, it should has the following structure: 10 | ``` 11 | ${videomme_root}/ 12 | ├── videomme/ 13 | ├── subtitle.zip 14 | ├── videos_chunked_01.zip 15 | ├── videos_chunked_02.zip 16 | ├── ... 
17 | └── videos_chunked_20.zip 18 | ``` 19 | 20 | 21 | ### Step 2: Unzip everything 22 | ```bash 23 | cd ${videomme_root} 24 | unzip subtitle.zip 25 | cat videos_chunked_*.zip > videos.zip 26 | unzip videos.zip 27 | ``` 28 | 29 | 30 | ### Step 3: Extract frames of all videos 31 | ```bash 32 | cd ${retake_repo_root} 33 | python scripts/utils/frame_extraction.py \ 34 | --videofile_tpl ${videomme_root}/data/'*.mp4' \ 35 | --results_dir ${videomme_root}/data_25fps \ 36 | --fps 25 \ 37 | --num_workers 32 38 | ``` 39 | 40 | 41 | ### Step 4: Build VideoMME dataset 42 | ```bash 43 | cd ${retake_repo_root} 44 | python scripts/utils/build_videomme_dataset.py \ 45 | --hf_qwen2vl7b_path ${PATH_TO_Qwen2-VL-7B-Instruct} \ 46 | --hf_root ${videomme_root} 47 | ``` 48 | Note that you can NOT modify folder `${videomme_root}/data_25fps` after this step, since the absolute path of extracted frames are written into annotation files `video_mme.json` and `video_mme_subtitle.json`: 49 | ``` 50 | retake_repo_root/ 51 | ├── dataset/ 52 | ├── video_mme/ 53 | ├── video_mme_subtitle.json 54 | ├── video_mme.json 55 | ├── ... 56 | ``` -------------------------------------------------------------------------------- /environment.yaml: -------------------------------------------------------------------------------- 1 | name: retake 2 | channels: 3 | - defaults 4 | dependencies: 5 | - python==3.11 6 | - pip: 7 | - torch==2.4.0 8 | - torchvision==0.19.0 9 | - transformers==4.48 10 | - accelerate==0.34.2 11 | - flash-attn==2.6.3 12 | - av==13.1.0 13 | - pyyaml==6.0.2 14 | - opencv-python-headless==4.10.0.84 15 | - pandas==2.2.3 16 | - pysubs2==1.7.3 17 | - pyarrow==17.0.0 18 | - openai==1.56.0 -------------------------------------------------------------------------------- /environment_npu.yaml: -------------------------------------------------------------------------------- 1 | name: retake 2 | channels: 3 | - defaults 4 | dependencies: 5 | - python==3.11 6 | - pip: 7 | - numpy==1.26.4 8 | - scipy==1.14.1 9 | - torch==2.4.0 10 | - torch-npu==2.4.0 11 | - torchvision==0.19.0 12 | - transformers==4.48 13 | - accelerate==0.34.2 14 | - av==13.1.0 15 | - pyyaml==6.0.2 16 | - opencv-python-headless==4.10.0.84 17 | - pandas==2.2.3 18 | - pysubs2==1.7.3 19 | - pyarrow==17.0.0 20 | - openai==1.56.0 -------------------------------------------------------------------------------- /misc/Q8AZ16uBhr8_resized_fps2_mute.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SCZwangxiao/video-ReTaKe/e268c32c242061093d694f0cc219794857e71dd8/misc/Q8AZ16uBhr8_resized_fps2_mute.mp4 -------------------------------------------------------------------------------- /misc/overview.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SCZwangxiao/video-ReTaKe/e268c32c242061093d694f0cc219794857e71dd8/misc/overview.png -------------------------------------------------------------------------------- /retake/dataset_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import os.path as osp 3 | import io 4 | import re 5 | import json 6 | import math 7 | import base64 8 | from PIL import Image 9 | import pandas as pd 10 | from tqdm import tqdm 11 | from typing import Optional, List 12 | 13 | try: 14 | os.environ['OPENAI_BASE_URL'] = None 15 | os.environ['OPENAI_API_KEY'] = None 16 | import openai 17 | except: 18 | print("Warning! 
openai not installed for MLVU evalutation") 19 | import numpy as np 20 | 21 | 22 | class BaseDataset: 23 | def __init__(self, 24 | anno_file: str, 25 | processor_kwargs: str 26 | ) -> None: 27 | self.processor_kwargs = processor_kwargs 28 | # Load annotations 29 | with open(anno_file, 'r') as F: 30 | self.annos = json.load(F) 31 | # Preprocess meta 32 | for anno in self.annos: 33 | # NOTE: Pyarrow caching in LLaMA-Factory will raise error 34 | # for some complicate json data. So dump to jsons. 35 | if type(anno['meta']) == str: 36 | anno['meta'] = json.loads(anno['meta']) 37 | 38 | @staticmethod 39 | def _get_video_sample_extracted_frames(frame_files: List[str], **kwargs) -> int: 40 | video_fps = kwargs.get("video_fps") 41 | video_maxlen = kwargs.get("video_maxlen") 42 | extraction_fps = kwargs.get("video_frame_extraction_fps") 43 | total_frames = len(frame_files) 44 | sample_frames = float(total_frames / extraction_fps) * video_fps 45 | sample_frames = min(total_frames, video_maxlen, sample_frames) 46 | sample_frames = math.floor(sample_frames) 47 | return int(sample_frames / 2) * 2 48 | 49 | @staticmethod 50 | def _preprocess_image(image, **kwargs): 51 | r""" 52 | Pre-processes a single image. 53 | """ 54 | image_resolution: int = kwargs.get("image_resolution") 55 | if max(image.width, image.height) > image_resolution: 56 | resize_factor = image_resolution / max(image.width, image.height) 57 | width, height = int(image.width * resize_factor), int(image.height * resize_factor) 58 | image = image.resize((width, height), resample=Image.NEAREST) 59 | 60 | if image.mode != "RGB": 61 | image = image.convert("RGB") 62 | 63 | return image 64 | 65 | def __len__(self): 66 | return len(self.annos) 67 | 68 | def get_video_message(self, video_root: str): 69 | frames = [] 70 | frame_files = [ 71 | os.path.join(video_root, file) for file in list(sorted(os.listdir(video_root))) 72 | ] 73 | total_frames = len(frame_files) 74 | sample_frames = self._get_video_sample_extracted_frames(frame_files, **self.processor_kwargs) 75 | sample_indices = np.linspace(0, total_frames - 1, sample_frames).astype(np.int32) 76 | for frame_idx, frame_file in enumerate(frame_files): 77 | if frame_idx in sample_indices: 78 | # NOTE: Load and resize on the fly can creatly RAM cost of dataloader 79 | image = Image.open(frame_file) 80 | resized_image = self._preprocess_image(image, **self.processor_kwargs) 81 | frames.append(resized_image) 82 | 83 | return frames 84 | 85 | def __getitem__(self, idx): 86 | anno = self.annos[idx] 87 | 88 | question = anno["messages"][0]["content"].replace('