├── .gitignore
├── LICENSE
├── README.md
├── configs
│   ├── llava_video
│   │   ├── llava-video_lvbench.yaml
│   │   ├── llava-video_mlvu.yaml
│   │   ├── llava-video_videomme.yaml
│   │   ├── retake_llava-video_lvbench.yaml
│   │   ├── retake_llava-video_mlvu.yaml
│   │   └── retake_llava-video_videomme.yaml
│   ├── qwen2_vl
│   │   ├── qwen2-vl_lvbench.yaml
│   │   ├── qwen2-vl_mlvu.yaml
│   │   ├── qwen2-vl_videomme.yaml
│   │   ├── retake_qwen2-vl_lvbench.yaml
│   │   ├── retake_qwen2-vl_mlvu.yaml
│   │   └── retake_qwen2-vl_videomme.yaml
│   ├── retake_demo.yaml
│   └── retake_demo_npu.yaml
├── demo.py
├── docs
│   ├── prepare_lvbench.md
│   ├── prepare_mlvu.md
│   └── prepare_videomme.md
├── environment.yaml
├── environment_npu.yaml
├── misc
│   ├── Q8AZ16uBhr8_resized_fps2_mute.mp4
│   └── overview.png
├── retake
│   ├── dataset_utils.py
│   ├── infer_eval.py
│   ├── llava_onevision.py
│   ├── longvideo_cache.py
│   ├── monkeypatch.py
│   ├── qwen2_vl.py
│   └── visual_compression.py
└── scripts
    ├── infer_eval_retake.sh
    └── utils
        ├── build_lvbench_dataset.py
        ├── build_mlvu_dataset.py
        ├── build_mlvu_test_dataset.py
        ├── build_videomme_dataset.py
        ├── cal_flops.py
        ├── cal_ttft.py
        ├── convert_llava_video_weights_to_hf.py
        └── frame_extraction.py
/.gitignore:
--------------------------------------------------------------------------------
1 | /dataset
2 | /results
3 | */__pycache__
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2024 SCZwangxiao
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # [ReTaKe: Reducing Temporal and Knowledge Redundancy for Long Video Understanding](https://arxiv.org/abs/2412.20504)
2 |
3 | ReTaKe is a novel approach for long video understanding that reduces temporal and knowledge redundancy, enabling MLLMs to process 8x longer video sequences (up to 2048 frames) under the same memory budget.
4 |
5 | ---
6 |
7 | ## 📢 Recent Updates
8 | - **2025/03/11**: Polished the paper, improved the readability of the methods section, and added more ablation studies and results on LongVideoBench.
9 | - **2025/02/01**: Added support for the latest version of Transformers (v4.48).
10 | - **2025/01/29**: Added support for LLaVA-Video and LLaVA-OneVision.
11 |
12 | ---
13 |
14 | ## 🚀 Key Contributions
15 |
16 | - **Training-Free Framework**: ReTaKe is the first method to jointly model temporal and knowledge redundancy for long video understanding, reducing the model sequence length to 1/4 of the original with a relative performance loss within 1%.
17 |
18 | - **Novel Techniques**:
19 | - **DPSelect**: A keyframe selection method to reduce low-level temporal redundancy.
20 | - **PivotKV**: A KV cache compression method to reduce high-level knowledge redundancy in long videos.
21 |
22 |
23 |
24 |
25 |
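For intuition, here is a minimal, illustrative sketch of keyframe selection by low similarity to the preceding frame. It is **not** the actual DPSelect/PivotKV implementation (the real code lives in the `retake/` package, e.g. `visual_compression.py` and `longvideo_cache.py`); the function, the pooled feature shape, and `keep_ratio` below are assumptions made purely for illustration.

```python
# Illustrative only: keep the frames whose visual content differs most from the previous frame.
import torch

def select_keyframes(frame_features: torch.Tensor, keep_ratio: float = 0.25) -> torch.Tensor:
    """frame_features: (num_frames, dim) pooled per-frame visual features."""
    feats = torch.nn.functional.normalize(frame_features, dim=-1)
    sim_to_prev = (feats[1:] * feats[:-1]).sum(dim=-1)        # cosine similarity, shape (num_frames - 1,)
    change_score = torch.cat([torch.tensor([float("inf")]),   # force-keep the first frame
                              1.0 - sim_to_prev])
    num_keep = max(1, int(keep_ratio * frame_features.shape[0]))
    return change_score.topk(num_keep).indices.sort().values  # sorted indices of selected keyframes

keyframe_indices = select_keyframes(torch.randn(64, 512), keep_ratio=0.25)  # keep ~16 of 64 frames
```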
26 | ---
27 |
28 | ## ⚙️ Environment Setup
29 |
30 | ### For GPU Users:
31 | ```bash
32 | conda env create -f environment.yaml
33 | ```
34 |
35 | ### For NPU Users:
36 | ```bash
37 | conda env create -f environment_npu.yaml
38 | ```
39 |
40 | ### Additional Dependencies:
41 | ```bash
42 | apt-get install ffmpeg # Required for full functionality; the quick demo does not need ffmpeg.
43 | ```
44 |
45 | ---
46 |
47 | ## 🖥️ Quick Demo
48 |
49 | ### Step 1: Update Configuration
50 | Modify `hf_model_path` in `./demo.py` to point to your local path of `Qwen2-VL-7B-Instruct` (see the reference snippet below).
51 | For NPU users, also update `config_path` to `'configs/retake_demo_npu.yaml'`.
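
For reference, these are the variables to edit near the top of the `__main__` block in `demo.py` (the values below mirror the defaults in `demo.py`):

```python
hf_model_path = 'Qwen/Qwen2-VL-7B-Instruct'  # or a local path such as '/path_to/Qwen2-VL-7B-Instruct'
model_name = 'qwen2_vl'

config_path = 'configs/retake_demo.yaml'     # NPU users: 'configs/retake_demo_npu.yaml'
device = 'cuda:0'                            # NPU users: 'npu:0'
```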
52 |
53 | ### Step 2 (Optional for LLaVA-Video): Convert Model
54 | ```bash
55 | # Convert LLaVA-Video model into Hugging Face format
56 | # Ensure the following models are downloaded: Qwen2-7B-Instruct, siglip-so400m-patch14-384, and LLaVAVideoQwen2_7B.
57 | python scripts/utils/convert_llava_video_weights_to_hf.py \
58 | --text_model_id /path_to/Qwen2-7B-Instruct \
59 | --vision_model_id /path_to/siglip-so400m-patch14-384 \
60 | --output_hub_path /path_to/llava-video-qwen2-7b-hf \
61 | --old_state_dict_id /path_to/LLaVAVideoQwen2_7B
62 | ```
63 |
64 | ### Step 3: Run the Demo
65 | ```bash
66 | python demo.py
67 | ```
68 |
69 | ---
70 |
71 | ## 📊 Reproducing ReTaKe Results
72 |
73 | ### Step 1: Prepare Datasets
74 | Follow the documentation to prepare the required datasets:
75 | - [VideoMME](docs/prepare_videomme.md)
76 | - [MLVU](docs/prepare_mlvu.md)
77 | - [LVBench](docs/prepare_lvbench.md)
78 |
79 | ### Step 2: Run Inference and Evaluation
80 | Use the provided script to perform inference and evaluation:
81 | ```bash
82 | bash scripts/infer_eval_retake.sh ${YOUR_PATH_TO_Qwen2-VL-7B-Instruct} configs/qwen2_vl/retake_qwen2-vl_videomme.yaml 8
83 | bash scripts/infer_eval_retake.sh ${YOUR_PATH_TO_Qwen2-VL-7B-Instruct} configs/qwen2_vl/retake_qwen2-vl_mlvu.yaml 8
84 | bash scripts/infer_eval_retake.sh ${YOUR_PATH_TO_Qwen2-VL-7B-Instruct} configs/qwen2_vl/retake_qwen2-vl_lvbench.yaml 8
85 | ```
86 |
87 | - Results will be saved in the `./results` directory.
88 |
89 | ---
90 |
91 | ## 📚 Citation
92 | If you find this work helpful, please consider citing:
93 | ```bibtex
94 | @misc{xiao_retake_2024,
95 | author = {Xiao Wang and
96 | Qingyi Si and
97 | Jianlong Wu and
98 | Shiyu Zhu and
99 | Li Cao and
100 | Liqiang Nie},
101 | title = {{ReTaKe}: {Reducing} {Temporal} and {Knowledge} {Redundancy} for {Long} {Video} {Understanding}},
102 | year = {2024},
103 | note = {arXiv:2412.20504 [cs]}
104 | }
105 | ```
106 |
--------------------------------------------------------------------------------
/configs/llava_video/llava-video_lvbench.yaml:
--------------------------------------------------------------------------------
1 | ### model
2 | model_name: llava_video
3 | method: retake
4 | attn_implementation: "flash_attention_2"
5 |
6 | ### dataset
7 | dataset_name: lvbench
8 | anno_file: dataset/lvbench/lvbench.json
9 | dataloader_num_workers: 4
10 |
11 | ### data
12 | sample_fps: 2
13 | max_num_frames: 64
14 | longsize_resolution: 682 # short-side can be 384
15 |
16 | ### generate
17 | do_sample: false
18 |
19 | ### output
20 | output_dir: results/llava-video_lvbench_f64_2fps_r682/base
21 |
--------------------------------------------------------------------------------
/configs/llava_video/llava-video_mlvu.yaml:
--------------------------------------------------------------------------------
1 | ### model
2 | model_name: llava_video
3 | method: retake
4 | attn_implementation: "flash_attention_2"
5 |
6 | ### dataset
7 | dataset_name: mlvu
8 | anno_file: dataset/mlvu/mlvu.json
9 | dataloader_num_workers: 4
10 |
11 | ### data
12 | sample_fps: 2
13 | max_num_frames: 64
14 | longsize_resolution: 682 # short-side can be 384
15 |
16 | ### generate
17 | do_sample: false
18 |
19 | ### output
20 | output_dir: results/llava-video_mlvu_f64_2fps_r682/base
21 |
--------------------------------------------------------------------------------
/configs/llava_video/llava-video_videomme.yaml:
--------------------------------------------------------------------------------
1 | ### model
2 | model_name: llava_video
3 | method: retake
4 | attn_implementation: "flash_attention_2"
5 |
6 | ### dataset
7 | dataset_name: videomme
8 | anno_file: dataset/video_mme/video_mme.json
9 | dataloader_num_workers: 4
10 |
11 | ### data
12 | sample_fps: 2
13 | max_num_frames: 64
14 | longsize_resolution: 682 # short-side can be 384
15 |
16 | ### generate
17 | do_sample: false
18 |
19 | ### output
20 | output_dir: results/llava-video_video_mme_f64_2fps_r682/base
21 |
--------------------------------------------------------------------------------
/configs/llava_video/retake_llava-video_lvbench.yaml:
--------------------------------------------------------------------------------
1 | ### model
2 | model_name: llava_video
3 | method: retake
4 | scaling_factor: 4
5 | attn_implementation: "flash_attention_2"
6 | longvideo_kwargs: {
7 | 'frame_chunk_size': 32,
8 | 'chunked_prefill_frames': 32,
9 | # Keyframe compression
10 | 'visual_compression': True,
11 | 'visual_compression_kwargs': {
12 | 'compression_ratio': 1.0,
13 | 'compression_method': 'Keyframe',
14 | 'patch_sync': False,
15 | 'return_keyframe_mask': True
16 | },
17 | # KVCache compression
18 | 'kvcache_compression': True,
19 | 'kvcache_compression_kwargs': {
20 | 'dynamic_compression_ratio': True,
21 | 'compression_method': 'pivotkv',
22 | 'pos_embed_reforge': True,
23 | 'max_input_length': 40000
24 | },
25 | }
26 |
27 | ### dataset
28 | dataset_name: lvbench
29 | anno_file: dataset/lvbench/lvbench.json
30 | dataloader_num_workers: 4
31 |
32 | ### data
33 | sample_fps: 2
34 | max_num_frames: 1024
35 | longsize_resolution: 682
36 |
37 | ### generate
38 | do_sample: false
39 |
40 | ### output
41 | output_dir: results/llava-video_f1024_2fps_r682/retake_dp1-async_pivot-40k
42 |
--------------------------------------------------------------------------------
/configs/llava_video/retake_llava-video_mlvu.yaml:
--------------------------------------------------------------------------------
1 | ### model
2 | model_name: llava_video
3 | method: retake
4 | scaling_factor: 4
5 | attn_implementation: "flash_attention_2"
6 | longvideo_kwargs: {
7 | 'frame_chunk_size': 32,
8 | 'chunked_prefill_frames': 32,
9 | # Keyframe compression
10 | 'visual_compression': True,
11 | 'visual_compression_kwargs': {
12 | 'compression_ratio': 1.0,
13 | 'compression_method': 'Keyframe',
14 | 'patch_sync': False,
15 | 'return_keyframe_mask': True
16 | },
17 | # KVCache compression
18 | 'kvcache_compression': True,
19 | 'kvcache_compression_kwargs': {
20 | 'dynamic_compression_ratio': True,
21 | 'compression_method': 'pivotkv',
22 | 'pos_embed_reforge': True,
23 | 'max_input_length': 40000
24 | },
25 | }
26 |
27 | ### dataset
28 | dataset_name: mlvu
29 | anno_file: dataset/mlvu/mlvu.json
30 | dataloader_num_workers: 4
31 |
32 | ### data
33 | sample_fps: 2
34 | max_num_frames: 1024
35 | longsize_resolution: 682
36 |
37 | ### generate
38 | do_sample: false
39 |
40 | ### output
41 | output_dir: results/llava-video_rope4_mlvu_f1024_2fps_r682/retake_dp1-async_pivot-40k
42 |
--------------------------------------------------------------------------------
/configs/llava_video/retake_llava-video_videomme.yaml:
--------------------------------------------------------------------------------
1 | ### model
2 | model_name: llava_video
3 | method: retake
4 | scaling_factor: 4
5 | attn_implementation: "flash_attention_2"
6 | longvideo_kwargs: {
7 | 'frame_chunk_size': 32,
8 | 'chunked_prefill_frames': 32,
9 | # Keyframe compression
10 | 'visual_compression': True,
11 | 'visual_compression_kwargs': {
12 | 'compression_ratio': 1.0,
13 | 'compression_method': 'Keyframe',
14 | 'patch_sync': False,
15 | 'return_keyframe_mask': True
16 | },
17 | # KVCache compression
18 | 'kvcache_compression': True,
19 | 'kvcache_compression_kwargs': {
20 | 'dynamic_compression_ratio': True,
21 | 'compression_method': 'pivotkv',
22 | 'pos_embed_reforge': True,
23 | 'max_input_length': 40000
24 | },
25 | }
26 |
27 | ### dataset
28 | dataset_name: videomme
29 | anno_file: dataset/video_mme/video_mme.json
30 | dataloader_num_workers: 4
31 |
32 | ### data
33 | sample_fps: 2
34 | max_num_frames: 1024
35 | longsize_resolution: 682
36 |
37 | ### generate
38 | do_sample: false
39 |
40 | ### output
41 | output_dir: results/llava-video_rope4_video_mme_f1024_2fps_r682/retake_dp1-async_pivot-40k
42 |
--------------------------------------------------------------------------------
/configs/qwen2_vl/qwen2-vl_lvbench.yaml:
--------------------------------------------------------------------------------
1 | ### model
2 | model_name: qwen2_vl
3 | method: retake
4 | scaling_factor: 4
5 | attn_implementation: "flash_attention_2"
6 |
7 | ### dataset
8 | dataset_name: lvbench
9 | anno_file: dataset/lvbench/lvbench.json
10 | dataloader_num_workers: 2
11 |
12 | ### data
13 | sample_fps: 2
14 | max_num_frames: 256
15 | longsize_resolution: 448
16 |
17 | ### generate
18 | do_sample: false
19 |
20 | ### output
21 | output_dir: results/qwen2vl_7b_lvbench_f256_2fps_r448/base
22 |
--------------------------------------------------------------------------------
/configs/qwen2_vl/qwen2-vl_mlvu.yaml:
--------------------------------------------------------------------------------
1 | ### model
2 | model_name: qwen2_vl
3 | method: retake
4 | scaling_factor: 4
5 | attn_implementation: "flash_attention_2"
6 |
7 | ### dataset
8 | dataset_name: mlvu
9 | anno_file: dataset/mlvu/mlvu.json
10 | dataloader_num_workers: 2
11 |
12 | ### data
13 | sample_fps: 4
14 | max_num_frames: 256
15 | longsize_resolution: 448
16 |
17 | ### generate
18 | do_sample: false
19 |
20 | ### output
21 | output_dir: results/qwen2vl_7b_mlvu_f256_4fps_r448/base
22 |
--------------------------------------------------------------------------------
/configs/qwen2_vl/qwen2-vl_videomme.yaml:
--------------------------------------------------------------------------------
1 | ### model
2 | model_name: qwen2_vl
3 | method: retake
4 | scaling_factor: 4
5 | attn_implementation: "flash_attention_2"
6 |
7 | ### dataset
8 | dataset_name: videomme
9 | anno_file: dataset/video_mme/video_mme.json
10 | dataloader_num_workers: 2
11 |
12 | ### data
13 | sample_fps: 4
14 | max_num_frames: 256
15 | longsize_resolution: 448
16 |
17 | ### generate
18 | do_sample: false
19 |
20 | ### output
21 | output_dir: results/qwen2vl_7b_videomme_f256_4fps_r448/base
22 |
--------------------------------------------------------------------------------
/configs/qwen2_vl/retake_qwen2-vl_lvbench.yaml:
--------------------------------------------------------------------------------
1 | ### model
2 | model_name: qwen2_vl
3 | method: retake
4 | scaling_factor: 4
5 | attn_implementation: "flash_attention_2"
6 | longvideo_kwargs: {
7 | 'frame_chunk_size': 128,
8 | 'chunked_prefill_frames': 32,
9 | # Keyframe compression
10 | 'visual_compression': True,
11 | 'visual_compression_kwargs': {
12 | 'compression_ratio': 1.0,
13 | 'compression_method': 'Keyframe',
14 | 'patch_sync': False,
15 | 'return_keyframe_mask': True
16 | },
17 | # KVCache compression
18 | 'kvcache_compression': True,
19 | 'kvcache_compression_kwargs': {
20 | 'dynamic_compression_ratio': True,
21 | 'compression_method': 'pivotkv',
22 | 'pos_embed_reforge': True,
23 | 'max_input_length': 32000
24 | },
25 | }
26 |
27 | ### dataset
28 | dataset_name: lvbench
29 | anno_file: dataset/lvbench/lvbench.json
30 | dataloader_num_workers: 2
31 |
32 | ### data
33 | sample_fps: 2
34 | max_num_frames: 2048
35 | longsize_resolution: 448
36 |
37 | ### generate
38 | do_sample: false
39 |
40 | ### output
41 | output_dir: results/qwen2vl_7b_lvbench_f2048_2fps_r448/retake_dp1-async_pivot-32k
42 |
--------------------------------------------------------------------------------
/configs/qwen2_vl/retake_qwen2-vl_mlvu.yaml:
--------------------------------------------------------------------------------
1 | ### model
2 | model_name: qwen2_vl
3 | method: retake
4 | scaling_factor: 4
5 | attn_implementation: "flash_attention_2"
6 | longvideo_kwargs: {
7 | 'frame_chunk_size': 128,
8 | 'chunked_prefill_frames': 32,
9 | # Keyframe compression
10 | 'visual_compression': True,
11 | 'visual_compression_kwargs': {
12 | 'compression_ratio': 1.0,
13 | 'compression_method': 'Keyframe',
14 | 'patch_sync': False,
15 | 'return_keyframe_mask': True
16 | },
17 | # KVCache compression
18 | 'kvcache_compression': True,
19 | 'kvcache_compression_kwargs': {
20 | 'dynamic_compression_ratio': True,
21 | 'compression_method': 'pivotkv',
22 | 'pos_embed_reforge': True,
23 | 'max_input_length': 32000
24 | },
25 | }
26 |
27 | ### dataset
28 | dataset_name: mlvu
29 | anno_file: dataset/mlvu/mlvu.json
30 | dataloader_num_workers: 2
31 |
32 | ### data
33 | sample_fps: 4
34 | max_num_frames: 2048
35 | longsize_resolution: 448
36 |
37 | ### generate
38 | do_sample: false
39 |
40 | ### output
41 | output_dir: results/qwen2vl_7b_mlvu_f2048_4fps_r448/retake_dp1-async_pivot-32k
42 |
--------------------------------------------------------------------------------
/configs/qwen2_vl/retake_qwen2-vl_videomme.yaml:
--------------------------------------------------------------------------------
1 | ### model
2 | model_name: qwen2_vl
3 | method: retake
4 | scaling_factor: 4
5 | attn_implementation: "flash_attention_2"
6 | longvideo_kwargs: {
7 | 'frame_chunk_size': 128,
8 | 'chunked_prefill_frames': 32,
9 | # KVCache compression
10 | 'kvcache_compression': True,
11 | 'kvcache_compression_kwargs': {
12 | 'dynamic_compression_ratio': True,
13 | 'compression_method': 'pivotkv',
14 | 'pos_embed_reforge': True,
15 | 'max_input_length': 32000
16 | },
17 | }
18 |
19 |
20 | ### dataset
21 | dataset_name: videomme
22 | anno_file: dataset/video_mme/video_mme.json
23 | dataloader_num_workers: 2
24 |
25 | ### data
26 | sample_fps: 4
27 | max_num_frames: 2048
28 | longsize_resolution: 448
29 |
30 | ### generate
31 | do_sample: false
32 |
33 | ### output
34 | output_dir: results/qwen2vl_7b_video_mme_f2048_4fps_r448/retake_pivot-32k
35 |
--------------------------------------------------------------------------------
/configs/retake_demo.yaml:
--------------------------------------------------------------------------------
1 | ### model
2 | method: retake
3 | scaling_factor: 4
4 | attn_implementation: "flash_attention_2"
5 | longvideo_kwargs: {
6 | 'frame_chunk_size': 128,
7 | 'chunked_prefill_frames': 32,
8 | # Keyframe compression
9 | 'visual_compression': True,
10 | 'visual_compression_kwargs': {
11 | 'compression_ratio': 1.0,
12 | 'compression_method': 'Keyframe',
13 | 'patch_sync': False,
14 | 'return_keyframe_mask': True
15 | },
16 | # KVCache compression
17 | 'kvcache_compression': True,
18 | 'kvcache_compression_kwargs': {
19 | 'dynamic_compression_ratio': True,
20 | 'compression_method': 'pivotkv',
21 | 'pos_embed_reforge': True,
22 | 'max_input_length': 32000
23 | },
24 | }
25 |
26 | ### data
27 | sample_fps: 4
28 | max_num_frames: 2048
29 | longsize_resolution: 448
30 |
31 | ### generate
32 | do_sample: false
--------------------------------------------------------------------------------
/configs/retake_demo_npu.yaml:
--------------------------------------------------------------------------------
1 | ### model
2 | method: retake
3 | scaling_factor: 4
4 | attn_implementation: "eager" # NPUs do not support SDPA attention yet
5 | longvideo_kwargs: {
6 | 'frame_chunk_size': 16, # Trade-off between peak memory and speed
7 | 'chunked_prefill_frames': 16, # Trade-off between peak memory and speed
8 | # Keyframe compression
9 | 'visual_compression': True,
10 | 'visual_compression_kwargs': {
11 | 'compression_ratio': 1.0,
12 | 'compression_method': 'Keyframe',
13 | 'patch_sync': False,
14 | 'return_keyframe_mask': True
15 | },
16 | # KVCache compression
17 | 'kvcache_compression': True,
18 | 'kvcache_compression_kwargs': {
19 | 'dynamic_compression_ratio': True,
20 | 'compression_method': 'pivotkv',
21 | 'pos_embed_reforge': True,
22 | 'max_input_length': 32000
23 | },
24 | }
25 |
26 | ### data
27 | sample_fps: 4
28 | max_num_frames: 2048
29 | longsize_resolution: 448
30 |
31 | ### generate
32 | do_sample: false
--------------------------------------------------------------------------------
/demo.py:
--------------------------------------------------------------------------------
1 | import os
2 | import cv2
3 | import math
4 | import yaml
5 | from PIL import Image
6 | from typing import List, Union
7 |
8 | import torch
9 | import numpy as np
10 | from torchvision.transforms.functional import pil_to_tensor
11 | from transformers import AutoProcessor
12 |
13 | import retake
14 |
15 |
16 | def get_frame_indices(total_frames, max_num_frames, sample_fps, extraction_fps):
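# Example: a 25 fps video with 3000 frames, sampled at sample_fps=2 with max_num_frames=256,
# gives 3000/25*2 = 240 candidate frames; min(3000, 256, 240) = 240, rounded down to an even
# count -> 240 evenly spaced frame indices in [0, 2999].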
17 | # Get number of sampled frames
18 | sample_frames = float(total_frames / extraction_fps) * sample_fps
19 | sample_frames = min(total_frames, max_num_frames, sample_frames)
20 | sample_frames = math.floor(sample_frames)
21 | sample_frames = int(sample_frames / 2) * 2
22 | # Get sampled frame indices
23 | frame_indices = np.linspace(0, total_frames - 1, sample_frames).astype(np.int32)
24 | return frame_indices
25 |
26 |
27 | def load_specific_frames(cap, frame_indices):
28 | # List to store the frames
29 | frames = []
30 | # Read frames from the video
31 | for frame_index in frame_indices:
32 | # Set the video position to the desired frame index
33 | cap.set(cv2.CAP_PROP_POS_FRAMES, frame_index)
34 | # Read the frame
35 | ret, frame = cap.read()
36 | # If the frame was read successfully, append it to the list
37 | if ret:
38 | # Convert the frame from BGR to RGB
39 | frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
40 | # Create a PIL Image from the frame
41 | frame = Image.fromarray(frame_rgb)
42 | frames.append(frame)
43 | else:
44 | raise ValueError(f"Could not read frame at index {frame_index}. It may be out of range.")
45 | return frames
46 |
47 |
48 | def load_video(video_path: str, max_num_frames: int, fps: Union[int, float]=None, frame_extraction_fps: Union[int, float]=None):
49 | """Load video frames at `fps`; if that exceeds `max_num_frames`, downsample uniformly.
50 | If `fps` is `None`, uniformly sample `max_num_frames` frames.
51 | 
52 | video_path: either a video file or a directory of extracted frames.
53 | 
54 | # NOTE: Extracted frames must follow the name pattern `%06d.(ext)`, or the loaded frame order will be wrong.
55 | """
56 | if video_path.startswith("file://"):
57 | video_path = video_path[7:]
58 | if os.path.isdir(video_path): # directory of extracted frames
59 | assert frame_extraction_fps is not None
60 | frame_files = [os.path.join(video_path, name) for name in sorted(os.listdir(video_path))]
   | frame_indices = get_frame_indices(len(frame_files), max_num_frames, fps, frame_extraction_fps)
   | frames = [Image.open(frame_files[idx]) for idx in frame_indices]
61 | else: # filename of a video
62 | # Open the video file
63 | cap = cv2.VideoCapture(video_path)
64 | if not cap.isOpened():
65 | raise ValueError("Error: Could not open video.")
66 | total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
67 | frame_extraction_fps = cap.get(cv2.CAP_PROP_FPS)
68 | # Get indices of sampled frame
69 | frame_indices = get_frame_indices(total_frames, max_num_frames, fps, frame_extraction_fps)
70 | # Get frames
71 | frames = load_specific_frames(cap, frame_indices)
72 | # Release the video capture object
73 | cap.release()
74 |
75 | # Convert into RGB format
76 | frames = [
77 | frame.convert("RGB") if frame.mode != "RGB" else frame
78 | for frame in frames
79 | ]
80 |
81 | return frames
82 |
83 |
84 | def resize_image_longside(image, image_resolution):
85 | r"""
86 | Pre-processes a single image.
87 | """
88 | if max(image.width, image.height) > image_resolution:
89 | resize_factor = image_resolution / max(image.width, image.height)
90 | width, height = int(image.width * resize_factor), int(image.height * resize_factor)
91 | image = image.resize((width, height), resample=Image.NEAREST)
92 |
93 | return image
94 |
95 |
96 | def resize_video_longside(frames: List, video_resolution):
97 | """
98 | frames: list of PIL images.
99 | """
100 | frames = [
101 | resize_image_longside(frame, video_resolution)
102 | for frame in frames
103 | ]
104 | return frames
105 |
106 |
107 | def load_yaml(file_path):
108 | with open(file_path, 'r') as file:
109 | data = yaml.safe_load(file)
110 | return data
111 |
112 |
113 | def fetch_video(video_info, max_num_frames, sample_fps, longsize_resolution):
114 | frames = load_video(video_info['video'], max_num_frames, sample_fps, video_info.get('frame_extraction_fps', None))
115 | frames = resize_video_longside(frames, longsize_resolution)
116 | frames = [pil_to_tensor(frame) for frame in frames]
117 | return frames
118 |
119 |
120 | def load_and_patch_model(model_name, hf_model_path, exp_configs, device):
121 | model_name = model_name if model_name is not None else exp_configs['model_name']
122 | model_name = model_name.lower().replace('-', '').replace('_', '')
123 | if model_name == 'qwen2vl': # QWen2VL
124 | from transformers import Qwen2VLConfig, Qwen2VLForConditionalGeneration
125 | from retake.monkeypatch import patch_qwen2vl, patch_qwen2vl_config
126 | retake.qwen2_vl.DEBUG_MODE = True
127 | patch_qwen2vl(exp_configs['method']) # Replace some functions of QWen2VL with those from ReTaKe
128 | qwen2vl_config = Qwen2VLConfig.from_pretrained(hf_model_path)
129 | qwen2vl_config = patch_qwen2vl_config(qwen2vl_config, exp_configs)
130 | model = Qwen2VLForConditionalGeneration.from_pretrained(
131 | hf_model_path,
132 | config=qwen2vl_config,
133 | torch_dtype=torch.bfloat16,
134 | attn_implementation=exp_configs.get('attn_implementation', None),
135 | device_map=device # "auto"
136 | ).eval()
137 | processor = AutoProcessor.from_pretrained(hf_model_path)
138 | elif model_name in ['llavaonevision', 'llavavideo']: # LLaVA-OneVision, LLaVA-Video
139 | from transformers import LlavaOnevisionConfig, LlavaOnevisionForConditionalGeneration
140 | from retake.monkeypatch import patch_llava_onevision, patch_llava_onevision_config
141 | retake.llava_onevision.DEBUG_MODE = True
142 | patch_llava_onevision(exp_configs['method']) # Replace some functions of LLaVA-Video with those from ReTaKe
143 | llava_onevision_config = LlavaOnevisionConfig.from_pretrained(hf_model_path)
144 | llava_onevision_config = patch_llava_onevision_config(llava_onevision_config, exp_configs)
145 | processor = AutoProcessor.from_pretrained(hf_model_path)
146 | model = LlavaOnevisionForConditionalGeneration.from_pretrained(
147 | hf_model_path,
148 | config=llava_onevision_config,
149 | torch_dtype=torch.bfloat16,
150 | attn_implementation=exp_configs.get('attn_implementation', None),
151 | device_map=device # "auto"
152 | )
153 | else:
154 | raise NotImplementedError
155 | return model, processor
156 |
157 |
158 | DEMO_VIDEO = 'misc/Q8AZ16uBhr8_resized_fps2_mute.mp4'
159 | DEMO_QUESTIONS = [
160 | "As depicted in the video, how is the relationship between the rabbit and human?\nOptions:\nA. Hostile.\nB. Friend.\nC. Cooperator.\nD. No one is correct above.\nAnswer with the option's letter from the given choices directly.",
161 | "What is the impression of the video?\nOptions:\nA. Sad.\nB. Funny.\nC. Horrible.\nD. Silent.\nAnswer with the option's letter from the given choices directly.",
162 | "What is the subject of the video?\nOptions:\nA. Rabbit likes to eat carrots.\nB. How to raise a rabbit.\nC. A rabbit gives people trouble.\nD. A rabbit performs for food.\nAnswer with the option's letter from the given choices directly.",
163 | ]
164 | EXPECTED_ANSWERS = ['A', 'B', 'C']
165 |
166 |
167 | if __name__ == "__main__":
168 | #------------------- Modify the following configs ------------------#
169 | hf_model_path = 'Qwen/Qwen2-VL-7B-Instruct' # TODO: replace with a local path if you have trouble downloading Hugging Face models
170 | model_name = 'qwen2_vl'
171 | # hf_model_path = '/path_to/llava-video-qwen2-7b-hf'
172 | # model_name = 'llava_video'
173 | # hf_model_path = 'llava-hf/llava-onevision-qwen2-7b-ov-hf'
174 | # model_name = 'llava_onevision'
175 |
176 | # NOTE: for Nvidia GPUs
177 | config_path = 'configs/retake_demo.yaml'
178 | device = 'cuda:0'
179 |
180 | # NOTE: for NPUs or GPUs without support for FlashAttention
181 | # config_path = 'configs/retake_demo_npu.yaml'
182 | # device = 'npu:0'
183 |
184 | #------------------------ No need to change ------------------------#
185 | video_info = {"type": "video",
186 | "video": DEMO_VIDEO,
187 | "fps": 2.0}
188 |
189 | exp_configs = load_yaml(config_path)
190 |
191 | model, processor = load_and_patch_model(model_name, hf_model_path, exp_configs, device)
192 |
193 | # Video
194 | video = fetch_video(video_info, exp_configs['max_num_frames'], exp_configs['sample_fps'], exp_configs['longsize_resolution'])
195 | for question, expect_answer in zip(DEMO_QUESTIONS, EXPECTED_ANSWERS):
196 | conversation = [
197 | {
198 | "role": "user",
199 | "content": [
200 | {"type": "video"},
201 | {"type": "text", "text": question},
202 | ],
203 | }
204 | ]
205 |
206 | # Preprocess the inputs
207 | text_prompt = processor.apply_chat_template(conversation, add_generation_prompt=True)
208 | print('Input prompt:\n', text_prompt)
209 |
210 | inputs = processor(text=[text_prompt], videos=[video], padding=True, return_tensors="pt")
211 | inputs = inputs.to(device)
212 | inputs['pixel_values_videos'] = inputs['pixel_values_videos'].to(torch.bfloat16)
213 |
214 | # Inference: Generation of the output
215 | output_ids = model.generate(**inputs, do_sample=False, max_new_tokens=128)
216 | generated_ids = [output_ids[len(input_ids):] for input_ids, output_ids in zip(inputs.input_ids, output_ids)]
217 | output_text = processor.batch_decode(generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True)
218 | output_text = output_text[0]
219 | print('Output text:\n', output_text)
220 | print('Expected answer:\n', expect_answer)
221 |
--------------------------------------------------------------------------------
/docs/prepare_lvbench.md:
--------------------------------------------------------------------------------
1 | ## Prepare LVBench Dataset
2 |
3 |
4 | ### Step 1: Download LVBench data from [huggingface](https://huggingface.co/datasets/THUDM/LVBench/tree/main)
5 | ```bash
6 | git clone https://huggingface.co/datasets/THUDM/LVBench # Contain annotations only
7 | git clone https://huggingface.co/datasets/AIWinter/LVBench # Contain videos only
8 | ```
9 | Move all files from `AIWinter/LVBench` into `THUDM/LVBench`.
10 |
11 | Denote the root directory of the downloaded LVBench dataset as `lvbench_root`; it should have the following structure:
12 | ```
13 | ${lvbench_root}/
14 | ├── docs/
15 | ├── video_info.meta.jsonl
16 | ├── all_videos_split.zip.001
17 | ├── all_videos_split.zip.002
18 | ├── ...
19 | └── all_videos_split.zip.014
20 | ```
21 |
22 |
23 | ### Step 2: Unzip everything
24 | ```bash
25 | cd ${lvbench_root}
26 | cat all_videos_split.zip.* > all_videos.zip
27 | unzip all_videos.zip
28 | ```
29 |
30 |
31 | ### Step 3: Extract frames of all videos
32 | ```bash
33 | cd ${retake_repo_root}
34 | python scripts/utils/frame_extraction.py \
35 | --videofile_tpl ${lvbench_root}/all_videos/'*.mp4' \
36 | --results_dir ${lvbench_root}/video_25fps \
37 | --fps 25 \
38 | --num_workers 32
39 | ```
40 |
41 |
42 | ### Step 4: Build LVBench dataset
43 | ```bash
44 | cd ${retake_repo_root}
45 | python scripts/utils/build_lvbench_dataset.py --hf_root ${lvbench_root}
46 | ```
47 | Note that you must NOT modify the folder `${lvbench_root}/video_25fps` after this step, since the absolute paths of the extracted frames are written into the annotation file `lvbench.json`:
48 | ```
49 | retake_repo_root/
50 | └── dataset/
51 |     └── lvbench/
52 |         ├── lvbench.json
53 |         └── ...
54 | ```
--------------------------------------------------------------------------------
/docs/prepare_mlvu.md:
--------------------------------------------------------------------------------
1 | ## Prepare MLVU Dataset
2 |
3 |
4 | ### Step 1: Download the MLVU dataset from [huggingface](https://huggingface.co/datasets/MLVU/MVLU)
5 | ```bash
6 | git clone https://huggingface.co/datasets/MLVU/MVLU
7 | ```
8 |
9 | Denote the root directory of the downloaded MLVU dataset as `mlvu_root`; it should have the following structure:
10 | ```
11 | ${mlvu_root}/
12 | ├── MLVU/
13 | │   ├── json/
14 | │   │   └── ...
15 | │   └── video/
16 | │       └── ...
17 | └── figs/
18 | ```
19 |
20 |
21 | ### Step 2: Extract frames of all videos
22 | ```bash
23 | cd ${retake_repo_root}
24 | python scripts/utils/frame_extraction.py \
25 | --videofile_tpl ${mlvu_root}/MLVU/video/'*/*.mp4' \
26 | --results_dir ${mlvu_root}/MLVU/video_25fps \
27 | --fps 25 \
28 | --num_workers 32
29 | ```
30 |
31 |
32 | ### Step 3: Build MLVU dataset
33 | ```bash
34 | cd ${retake_repo_root}
35 | python scripts/utils/build_mlvu_dataset.py --hf_root ${mlvu_root}
36 | ```
37 | Note that you must NOT modify the folder `${mlvu_root}/MLVU/video_25fps` after this step, since the absolute paths of the extracted frames are written into the annotation file `mlvu.json`:
38 | ```
39 | retake_repo_root/
40 | └── dataset/
41 |     └── mlvu/
42 |         ├── mlvu.json
43 |         └── ...
44 | ```
--------------------------------------------------------------------------------
/docs/prepare_videomme.md:
--------------------------------------------------------------------------------
1 | ## Prepare VideoMME Dataset
2 |
3 |
4 | ### Step 1: Download the VideoMME dataset from [huggingface](https://huggingface.co/datasets/lmms-lab/Video-MME)
5 | ```bash
6 | git clone https://huggingface.co/datasets/lmms-lab/Video-MME
7 | ```
8 |
9 | Denote the root directory of the downloaded VideoMME dataset as `videomme_root`; it should have the following structure:
10 | ```
11 | ${videomme_root}/
12 | ├── videomme/
13 | ├── subtitle.zip
14 | ├── videos_chunked_01.zip
15 | ├── videos_chunked_02.zip
16 | ├── ...
17 | └── videos_chunked_20.zip
18 | ```
19 |
20 |
21 | ### Step 2: Unzip everything
22 | ```bash
23 | cd ${videomme_root}
24 | unzip subtitle.zip
25 | cat videos_chunked_*.zip > videos.zip
26 | unzip videos.zip
27 | ```
28 |
29 |
30 | ### Step 3: Extract frames of all videos
31 | ```bash
32 | cd ${retake_repo_root}
33 | python scripts/utils/frame_extraction.py \
34 | --videofile_tpl ${videomme_root}/data/'*.mp4' \
35 | --results_dir ${videomme_root}/data_25fps \
36 | --fps 25 \
37 | --num_workers 32
38 | ```
39 |
40 |
41 | ### Step 4: Build VideoMME dataset
42 | ```bash
43 | cd ${retake_repo_root}
44 | python scripts/utils/build_videomme_dataset.py \
45 | --hf_qwen2vl7b_path ${PATH_TO_Qwen2-VL-7B-Instruct} \
46 | --hf_root ${videomme_root}
47 | ```
48 | Note that you must NOT modify the folder `${videomme_root}/data_25fps` after this step, since the absolute paths of the extracted frames are written into the annotation files `video_mme.json` and `video_mme_subtitle.json`:
49 | ```
50 | retake_repo_root/
51 | └── dataset/
52 |     └── video_mme/
53 |         ├── video_mme_subtitle.json
54 |         ├── video_mme.json
55 |         └── ...
56 | ```
--------------------------------------------------------------------------------
/environment.yaml:
--------------------------------------------------------------------------------
1 | name: retake
2 | channels:
3 | - defaults
4 | dependencies:
5 | - python==3.11
6 | - pip:
7 | - torch==2.4.0
8 | - torchvision==0.19.0
9 | - transformers==4.48
10 | - accelerate==0.34.2
11 | - flash-attn==2.6.3
12 | - av==13.1.0
13 | - pyyaml==6.0.2
14 | - opencv-python-headless==4.10.0.84
15 | - pandas==2.2.3
16 | - pysubs2==1.7.3
17 | - pyarrow==17.0.0
18 | - openai==1.56.0
--------------------------------------------------------------------------------
/environment_npu.yaml:
--------------------------------------------------------------------------------
1 | name: retake
2 | channels:
3 | - defaults
4 | dependencies:
5 | - python==3.11
6 | - pip:
7 | - numpy==1.26.4
8 | - scipy==1.14.1
9 | - torch==2.4.0
10 | - torch-npu==2.4.0
11 | - torchvision==0.19.0
12 | - transformers==4.48
13 | - accelerate==0.34.2
14 | - av==13.1.0
15 | - pyyaml==6.0.2
16 | - opencv-python-headless==4.10.0.84
17 | - pandas==2.2.3
18 | - pysubs2==1.7.3
19 | - pyarrow==17.0.0
20 | - openai==1.56.0
--------------------------------------------------------------------------------
/misc/Q8AZ16uBhr8_resized_fps2_mute.mp4:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SCZwangxiao/video-ReTaKe/e268c32c242061093d694f0cc219794857e71dd8/misc/Q8AZ16uBhr8_resized_fps2_mute.mp4
--------------------------------------------------------------------------------
/misc/overview.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SCZwangxiao/video-ReTaKe/e268c32c242061093d694f0cc219794857e71dd8/misc/overview.png
--------------------------------------------------------------------------------
/retake/dataset_utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | import os.path as osp
3 | import io
4 | import re
5 | import json
6 | import math
7 | import base64
8 | from PIL import Image
9 | import pandas as pd
10 | from tqdm import tqdm
11 | from typing import Optional, List
12 |
13 | try:
14 |     # os.environ['OPENAI_BASE_URL'] = "..."  # optionally set your own endpoint here (must be a string) for MLVU evaluation
15 |     # os.environ['OPENAI_API_KEY'] = "..."   # optionally set your API key here
16 |     import openai
17 | except ImportError:
18 |     print("Warning! openai not installed for MLVU evaluation")
19 | import numpy as np
20 |
21 |
22 | class BaseDataset:
23 | def __init__(self,
24 | anno_file: str,
25 | processor_kwargs: dict
26 | ) -> None:
27 | self.processor_kwargs = processor_kwargs
28 | # Load annotations
29 | with open(anno_file, 'r') as F:
30 | self.annos = json.load(F)
31 | # Preprocess meta
32 | for anno in self.annos:
33 | # NOTE: Pyarrow caching in LLaMA-Factory raises errors for some complicated
34 | # JSON data, so such fields are stored as JSON strings and parsed back here.
35 | if type(anno['meta']) == str:
36 | anno['meta'] = json.loads(anno['meta'])
37 |
38 | @staticmethod
39 | def _get_video_sample_extracted_frames(frame_files: List[str], **kwargs) -> int:
40 | video_fps = kwargs.get("video_fps")
41 | video_maxlen = kwargs.get("video_maxlen")
42 | extraction_fps = kwargs.get("video_frame_extraction_fps")
43 | total_frames = len(frame_files)
44 | sample_frames = float(total_frames / extraction_fps) * video_fps
45 | sample_frames = min(total_frames, video_maxlen, sample_frames)
46 | sample_frames = math.floor(sample_frames)
47 | return int(sample_frames / 2) * 2
48 |
49 | @staticmethod
50 | def _preprocess_image(image, **kwargs):
51 | r"""
52 | Pre-processes a single image.
53 | """
54 | image_resolution: int = kwargs.get("image_resolution")
55 | if max(image.width, image.height) > image_resolution:
56 | resize_factor = image_resolution / max(image.width, image.height)
57 | width, height = int(image.width * resize_factor), int(image.height * resize_factor)
58 | image = image.resize((width, height), resample=Image.NEAREST)
59 |
60 | if image.mode != "RGB":
61 | image = image.convert("RGB")
62 |
63 | return image
64 |
65 | def __len__(self):
66 | return len(self.annos)
67 |
68 | def get_video_message(self, video_root: str):
69 | frames = []
70 | frame_files = [
71 | os.path.join(video_root, file) for file in list(sorted(os.listdir(video_root)))
72 | ]
73 | total_frames = len(frame_files)
74 | sample_frames = self._get_video_sample_extracted_frames(frame_files, **self.processor_kwargs)
75 | sample_indices = np.linspace(0, total_frames - 1, sample_frames).astype(np.int32)
76 | for frame_idx, frame_file in enumerate(frame_files):
77 | if frame_idx in sample_indices:
78 | # NOTE: Loading and resizing frames on the fly greatly reduces the RAM cost of the dataloader
79 | image = Image.open(frame_file)
80 | resized_image = self._preprocess_image(image, **self.processor_kwargs)
81 | frames.append(resized_image)
82 |
83 | return frames
84 |
85 | def __getitem__(self, idx):
86 | anno = self.annos[idx]
87 |
88 | question = anno["messages"][0]["content"].replace('