├── .python-version ├── lvu ├── __init__.py ├── scripts │ ├── activate.sh │ ├── create_env.sh │ ├── timing_quickvideo_interleaved.sh │ ├── timing_baseline.sh │ └── timing_quickvideo.sh ├── models │ ├── __init__.py │ ├── qwen25_vl.py │ ├── qwen25_lvu.py │ └── qwen25_lvu_interleaved.py ├── lvu_config.py ├── lvu_cache.py ├── lvu.py └── utils.py ├── assets ├── logo.png └── imgs │ ├── teaser.png │ ├── interleaving_time.png │ ├── video_processing_times.png │ └── kv_pruning_avg_performance.png ├── main.py ├── .gitmodules ├── pyproject.toml ├── LICENSE ├── .vscode └── launch.json ├── .gitignore ├── video_length_timings.py ├── sparsity_timing.py ├── timing.py └── README.md /.python-version: -------------------------------------------------------------------------------- 1 | 3.11 2 | -------------------------------------------------------------------------------- /lvu/__init__.py: -------------------------------------------------------------------------------- 1 | from .lvu import LVU 2 | from .lvu_config import LVUConfig -------------------------------------------------------------------------------- /assets/logo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TIGER-AI-Lab/QuickVideo/HEAD/assets/logo.png -------------------------------------------------------------------------------- /assets/imgs/teaser.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TIGER-AI-Lab/QuickVideo/HEAD/assets/imgs/teaser.png -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | def main(): 2 | print("Hello from lvu!") 3 | 4 | 5 | if __name__ == "__main__": 6 | main() 7 | -------------------------------------------------------------------------------- /assets/imgs/interleaving_time.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TIGER-AI-Lab/QuickVideo/HEAD/assets/imgs/interleaving_time.png -------------------------------------------------------------------------------- /.gitmodules: -------------------------------------------------------------------------------- 1 | [submodule "lmms-eval"] 2 | path = lmms-eval 3 | url = https://github.com/jdf-prog/lmms-eval.git 4 | branch = dev/lvu 5 | -------------------------------------------------------------------------------- /assets/imgs/video_processing_times.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TIGER-AI-Lab/QuickVideo/HEAD/assets/imgs/video_processing_times.png -------------------------------------------------------------------------------- /assets/imgs/kv_pruning_avg_performance.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TIGER-AI-Lab/QuickVideo/HEAD/assets/imgs/kv_pruning_avg_performance.png -------------------------------------------------------------------------------- /lvu/scripts/activate.sh: -------------------------------------------------------------------------------- 1 | 2 | if [[ "$0" == "${BASH_SOURCE[0]}" ]]; then 3 | echo This must be sourced. 
4 | exit 1 5 | fi 6 | 7 | source ./.venv/bin/activate -------------------------------------------------------------------------------- /lvu/scripts/create_env.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | if [ -d ".venv" ]; then 4 | yes | rm -r .venv 5 | fi 6 | 7 | uv sync 8 | uv pip install flash-attn --no-build-isolation 9 | -------------------------------------------------------------------------------- /lvu/scripts/timing_quickvideo_interleaved.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | export CUDA_VISIBLE_DEVICES="0" 4 | export QUICKCODEC_CORES="16" 5 | #export DEEPCODEC_DISABLED="TRUE" 6 | 7 | for i in {1..10}; do 8 | echo "Run #$i" 9 | python -m lvu.lvu "qwen25_lvu_interleaved" "32" "/scratch/b3schnei/movie1080p.BluRay.1hour_30min.mp4" 10 | done 11 | -------------------------------------------------------------------------------- /lvu/scripts/timing_baseline.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | export CUDA_VISIBLE_DEVICES="0" 4 | export QUICKCODEC_CORES="16" 5 | export DEEPCODEC_DISABLED="TRUE" 6 | 7 | for i in {1..10}; do 8 | echo "Run #$i" 9 | python -m lvu.lvu --model_type "qwen25_lvu" --video_group_size 0 --video_path "/scratch/b3schnei/movie1080p.BluRay.1hour_30min.mp4" 10 | done 11 | -------------------------------------------------------------------------------- /lvu/scripts/timing_quickvideo.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | export CUDA_VISIBLE_DEVICES="0" 4 | export QUICKCODEC_CORES="16" 5 | #export DEEPCODEC_DISABLED="TRUE" 6 | 7 | for i in {1..10}; do 8 | echo "Run #$i" 9 | python -m lvu.lvu --model_type "qwen25_lvu" --video_group_size 32 --video_path "/scratch/b3schnei/movie1080p.BluRay.1hour_30min.mp4" 10 | done 11 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [project] 2 | name = "lvu" 3 | version = "0.1.0" 4 | description = "Add your description here" 5 | readme = "README.md" 6 | requires-python = ">=3.11" 7 | dependencies = [ 8 | "accelerate>=1.5.2", 9 | "datasets>=3.4.1", 10 | "decord>=0.6.0", 11 | "fire>=0.7.0", 12 | "pillow>=11.1.0", 13 | "qwen-vl-utils>=0.0.10", 14 | "sentencepiece>=0.2.0", 15 | "torch>=2.6.0", 16 | "torchvision>=0.21.0", 17 | "transformers==4.50.0", 18 | "deepcodec==0.0.8", 19 | "setuptools", 20 | "torchcodec==0.2", 21 | "scipy", 22 | ] 23 | 24 | [tool.setuptools.packages.find] 25 | include = ["lvu"] 26 | exclude = ["assets"] -------------------------------------------------------------------------------- /lvu/models/__init__.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | from pathlib import Path 3 | cur_dir = Path(__file__).parent 4 | 5 | lvu_init_model_map = {} 6 | lvu_run_model_map = {} 7 | lvu_chat_model_map = {} 8 | 9 | for file in cur_dir.glob("*.py"): 10 | if file.name == "__init__.py": 11 | continue 12 | module_name = file.stem 13 | module = importlib.import_module(f".{module_name}", package=__package__) 14 | assert hasattr(module, "init_lvu_model"), f"Module {module_name} does not have init_lvu_model function." 15 | assert hasattr(module, "run_lvu_model"), f"Module {module_name} does not have run_lvu_model function." 
16 | lvu_init_model_map[module_name] = module.init_lvu_model 17 | lvu_run_model_map[module_name] = module.run_lvu_model 18 | if hasattr(module, "chat_lvu_model"): 19 | lvu_chat_model_map[module_name] = module.chat_lvu_model 20 | 21 | __all__ = [] 22 | for module_name in lvu_init_model_map.keys(): 23 | __all__.append(module_name) 24 | 25 | __all__.append("lvu_init_model_map") 26 | __all__.append("lvu_run_model_map") 27 | __all__.append("lvu_chat_model_map") 28 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2025 Dongfu Jiang 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /.vscode/launch.json: -------------------------------------------------------------------------------- 1 | { 2 | "version": "0.2.0", 3 | "configurations": [ 4 | { 5 | "name": "Debug", 6 | "type": "debugpy", 7 | "request": "launch", 8 | "module": "lvu.lvu", 9 | "console": "integratedTerminal", 10 | "args": [ 11 | "qwen25_lvu" 12 | ], 13 | "env": { 14 | "CUDA_VISIBLE_DEVICES": "0", 15 | "QUICKCODEC_CORES": "16", 16 | "DEEPCODEC_DISABLED": "TRUE" 17 | }, 18 | "justMyCode": false 19 | }, 20 | { 21 | "name": "Debug_Interleaved", 22 | "type": "debugpy", 23 | "request": "launch", 24 | "module": "lvu.lvu", 25 | "console": "integratedTerminal", 26 | "env": { 27 | "CUDA_VISIBLE_DEVICES": "0", 28 | "QUICKCODEC_CORES": "8" 29 | }, 30 | "args": [ 31 | "qwen25_lvu_interleaved" 32 | ], 33 | "justMyCode": false 34 | }, 35 | { 36 | "name": "Debug (deepcodec off)", 37 | "type": "debugpy", 38 | "request": "launch", 39 | "module": "lvu.lvu", 40 | "console": "integratedTerminal", 41 | "args": [ 42 | "qwen25_lvu" 43 | ], 44 | "env": { 45 | "CUDA_VISIBLE_DEVICES": "0", 46 | "QUICKCODEC_CORES": "8", 47 | "DEEPCODEC_DISABLED": "TRUE" 48 | }, 49 | "justMyCode": false 50 | }, 51 | ] 52 | } -------------------------------------------------------------------------------- /lvu/lvu_config.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | 3 | @dataclass 4 | class LVUConfig: 5 | model_name_or_path: str 6 | model_type: str = "qwen25_vl" 7 | top_k_predict_type: str = "key_norms_small" 8 | top_k: int = None 9 | top_p: float = None 10 | top_k_starting_layer: int = None 11 | do_top_k_for_query: bool = False 12 | adaptive_local_attention: bool = True 13 | video_group_size: int = None # per frame 14 | prefill_prune_starting_layer: int = None 15 | fps: int = None 16 | num_frames: int = 32 17 | use_tqdm: bool = False 18 | extra_kwargs: dict = None 19 | enable: bool = True 20 | cache_dir: str = None 21 | save_video_cache: bool = False 22 | top_k_decay_factor: float = None 23 | top_k_decay_type: str = None 24 | query_based: bool = False 25 | 26 | def __post_init__(self): 27 | # check and auto set default values 28 | if self.top_k_decay_type == "linear" and self.top_k_decay_factor is None: 29 | print(f"Warning: top_k_decay_type is set to {self.top_k_decay_type} but top_k_decay_factor is None. 
Setting it to 0.5.") 30 | self.top_k_decay_factor = 0.5 31 | if "query" in self.top_k_predict_type: 32 | # this is a query based predict type 33 | self.query_based = True 34 | @dataclass 35 | class LVULayerConfig: 36 | layer_idx: int 37 | total_layers: int 38 | lvu_config: LVUConfig 39 | is_last_layer: bool = False 40 | prune_for_next_layer: bool = False 41 | 42 | def __post_init__(self): 43 | self.is_last_layer = (self.layer_idx == self.total_layers - 1) 44 | if self.lvu_config is None: 45 | self.lvu_config = LVUConfig() 46 | if self.layer_idx is None: 47 | raise ValueError("layer_idx cannot be None") 48 | if self.is_last_layer is None: 49 | raise ValueError("is_last_layer cannot be None") 50 | if isinstance(self.lvu_config.prefill_prune_starting_layer, int) and \ 51 | self.lvu_config.prefill_prune_starting_layer >= 0 and \ 52 | self.layer_idx >= self.lvu_config.prefill_prune_starting_layer: 53 | self.prune_for_next_layer = True 54 | else: 55 | self.prune_for_next_layer = False 56 | 57 | 58 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # UV 98 | # Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | #uv.lock 102 | 103 | # poetry 104 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 
105 | # This is especially recommended for binary packages to ensure reproducibility, and is more 106 | # commonly ignored for libraries. 107 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 108 | #poetry.lock 109 | 110 | # pdm 111 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 112 | #pdm.lock 113 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 114 | # in version control. 115 | # https://pdm.fming.dev/latest/usage/project/#working-with-version-control 116 | .pdm.toml 117 | .pdm-python 118 | .pdm-build/ 119 | 120 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 121 | __pypackages__/ 122 | 123 | # Celery stuff 124 | celerybeat-schedule 125 | celerybeat.pid 126 | 127 | # SageMath parsed files 128 | *.sage.py 129 | 130 | # Environments 131 | .env 132 | .venv 133 | env/ 134 | venv/ 135 | ENV/ 136 | env.bak/ 137 | venv.bak/ 138 | 139 | # Spyder project settings 140 | .spyderproject 141 | .spyproject 142 | 143 | # Rope project settings 144 | .ropeproject 145 | 146 | # mkdocs documentation 147 | /site 148 | 149 | # mypy 150 | .mypy_cache/ 151 | .dmypy.json 152 | dmypy.json 153 | 154 | # Pyre type checker 155 | .pyre/ 156 | 157 | # pytype static type analyzer 158 | .pytype/ 159 | 160 | # Cython debug symbols 161 | cython_debug/ 162 | 163 | # PyCharm 164 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 165 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 166 | # and can be added to the global gitignore or merged into this file. For a more nuclear 167 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 168 | #.idea/ 169 | 170 | # Ruff stuff: 171 | .ruff_cache/ 172 | 173 | # PyPI configuration file 174 | .pypirc 175 | /test* 176 | /logs 177 | /test* 178 | /*.mp4 -------------------------------------------------------------------------------- /lvu/lvu_cache.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from transformers.cache_utils import ( 3 | DynamicCache, 4 | Iterable, 5 | List, 6 | Dict, 7 | Optional, 8 | Any, 9 | Tuple, 10 | ) 11 | from .lvu_config import LVUConfig 12 | from PIL import Image 13 | import numpy as np 14 | 15 | def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: 16 | """ 17 | This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). 
The hidden states go from (batch, 18 | num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) 19 | """ 20 | batch, num_key_value_heads, slen, head_dim = hidden_states.shape 21 | if n_rep == 1: 22 | return hidden_states 23 | hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) 24 | return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) 25 | 26 | 27 | # Assuming your ndarray is called 'array' with shape (C, H, W) 28 | def save_ndarray_as_image(array, filename): 29 | # Convert from (C, H, W) to (H, W, C) 30 | array_hwc = np.transpose(array, (1, 2, 0)) 31 | 32 | # Ensure the array is uint8 type 33 | array_hwc = array_hwc.astype(np.uint8) 34 | 35 | # Handle different numbers of channels 36 | if array.shape[0] == 1: 37 | # Grayscale image - squeeze to remove channel dimension 38 | image = Image.fromarray(array_hwc.squeeze(), mode='L') 39 | elif array.shape[0] == 3: 40 | # RGB image 41 | image = Image.fromarray(array_hwc, mode='RGB') 42 | elif array.shape[0] == 4: 43 | # RGBA image 44 | image = Image.fromarray(array_hwc, mode='RGBA') 45 | else: 46 | raise ValueError(f"Unsupported number of channels: {array.shape[0]}") 47 | 48 | # Save the image 49 | image.save(filename) 50 | 51 | def load_image_as_ndarray(filename, channels_first=True): 52 | # Load the image 53 | image = Image.open(filename) 54 | 55 | # Convert to numpy array 56 | array = np.array(image) 57 | 58 | # If the image is grayscale and we want it in (C, H, W) format 59 | if len(array.shape) == 2 and channels_first: 60 | # Add channel dimension (1, H, W) 61 | array = np.expand_dims(array, axis=0) 62 | elif len(array.shape) == 3 and channels_first: 63 | # Convert from (H, W, C) to (C, H, W) 64 | array = np.transpose(array, (2, 0, 1)) 65 | 66 | return array 67 | 68 | class LVUCache(DynamicCache): 69 | """ 70 | A class to manage caching for LVU models. 71 | Inherits from DynamicCache to provide caching functionality. 72 | """ 73 | 74 | def __init__(self, _distributed_cache_data: Iterable = None, lvu_config: LVUConfig = None): 75 | super().__init__(_distributed_cache_data) 76 | self.lvu_config = lvu_config 77 | # self.key_cache: List[torch.Tensor] = [] 78 | # self.value_cache: List[torch.Tensor] = [] 79 | self.accum_attn_scores: Dict[int, List[torch.Tensor]] = {} 80 | self.prompt_length: int = 0 81 | 82 | def set_prompt_length(self, prompt_length: int=0): 83 | """ 84 | Set the prompt length for the cache. 85 | Args: 86 | prompt_length (int): The length of the prompt. 
87 | """ 88 | self.prompt_length = prompt_length 89 | 90 | def update( 91 | self, 92 | key_states: torch.Tensor, 93 | value_states: torch.Tensor, 94 | layer_idx: int, 95 | cache_kwargs: Optional[Dict[str, Any]] = None, 96 | ) -> Tuple[torch.Tensor, torch.Tensor]: 97 | if not self.prompt_length: 98 | return super().update(key_states, value_states, layer_idx, cache_kwargs) 99 | else: 100 | query_states = cache_kwargs["query_states"] # (bz, num_heads, Q, head_dim) 101 | query_states = query_states[:, :, -self.prompt_length:, :] 102 | key_states = key_states[:, :, :-self.prompt_length, :] 103 | value_states = value_states[:, :, :-self.prompt_length, :] 104 | super_result = super().update(key_states, value_states, layer_idx, cache_kwargs) 105 | # postprocess 106 | bsz, num_heads, q_len, head_dim = query_states.shape 107 | num_key_value_heads, k_len = key_states.shape[1:3] 108 | # attention scores of query to key 109 | key_states_repeated = repeat_kv(key_states, num_heads // num_key_value_heads) 110 | attn_scores = torch.einsum("bhqd,bhkd->bhqk", query_states, key_states_repeated) / (head_dim ** 0.5) 111 | attn_scores = torch.softmax(attn_scores, dim=-1, dtype=torch.float32).to( 112 | query_states.dtype 113 | ).detach() # # (bz, num_heads, Q, K) 114 | attn_scores = attn_scores.sum(-2).mean(1) # average over num_key_value_heads (bz, k_len) 115 | self.accum_attn_scores[layer_idx] = self.accum_attn_scores.get(layer_idx, []) 116 | self.accum_attn_scores[layer_idx].append(attn_scores) 117 | return super_result -------------------------------------------------------------------------------- /video_length_timings.py: -------------------------------------------------------------------------------- 1 | import time 2 | from deepcodec import VideoReader as DCVR 3 | import sys 4 | import os 5 | import numpy as np 6 | from scipy import stats 7 | import glob 8 | from torchvision import transforms as T 9 | 10 | NUM_RUNS = 5 # Number of runs for averaging and confidence intervals 11 | FIXED_THREADS = 16 # Fixed number of threads for all tests 12 | SAMPLE_FPS = 1 # Sample at 1 frame per second 13 | height = 448 14 | width = 448 15 | 16 | def calculate_ci(times): 17 | mean_time = np.mean(times) 18 | n = len(times) 19 | if n >= 2: 20 | ci = stats.t.interval(0.95, n-1, loc=mean_time, scale=stats.sem(times)) 21 | return f"{mean_time:.3f} ± {(ci[1] - ci[0])/2:.3f}" 22 | return f"{mean_time:.3f} ± 0.000" 23 | 24 | def main(): 25 | if len(sys.argv) < 2: 26 | print("Usage: python script.py ") 27 | sys.exit(1) 28 | 29 | directory_path = sys.argv[1] 30 | video_files = sorted(glob.glob(os.path.join(directory_path, "movie1080p.BluRay.1hour*.mp4"))) 31 | 32 | if not video_files: 33 | print(f"No matching video files found in {directory_path}") 34 | sys.exit(1) 35 | 36 | for video_path in video_files: 37 | print(f"\n===== Testing file: {os.path.basename(video_path)} =====") 38 | 39 | # Get video metadata 40 | temp = DCVR(video_path, num_threads=1) 41 | num_frames = len(temp) 42 | video_fps = round(temp.get_fps()) 43 | duration = num_frames / video_fps 44 | del temp 45 | 46 | minutes, seconds = divmod(duration, 60) 47 | print(f"Duration: {int(minutes)}m {int(seconds)}s, FPS: {video_fps}, Total frames: {num_frames}") 48 | 49 | # Calculate frame indices for 1 FPS sampling 50 | frame_step = video_fps // SAMPLE_FPS 51 | indices = list(range(0, num_frames, frame_step)) 52 | print(f"Sampling {len(indices)} frames at {SAMPLE_FPS} FPS") 53 | 54 | # TorchCodec 55 | try: 56 | import torch 57 | from torchcodec.decoders import VideoDecoder 
58 | 59 | resize_transform = T.Resize((height, width), 60 | interpolation=T.InterpolationMode.BICUBIC, 61 | antialias=True) 62 | 63 | decode_times = [] 64 | combined_times = [] 65 | 66 | for _ in range(NUM_RUNS): 67 | # Decoding phase 68 | start_decode = time.time() 69 | device = "cpu" 70 | decoder = VideoDecoder(video_path, device=device, num_ffmpeg_threads=FIXED_THREADS) 71 | frames = decoder.get_frames_at(indices=indices).data 72 | elapsed_decode = time.time() - start_decode 73 | decode_times.append(elapsed_decode) 74 | del decoder # Release decoder immediately 75 | 76 | # Resizing phase 77 | start_resize = time.time() 78 | resized_frames = resize_transform(frames) 79 | elapsed_resize = time.time() - start_resize 80 | combined_times.append(elapsed_decode + elapsed_resize) 81 | del frames, resized_frames # Clean up memory 82 | 83 | print(f"TorchCodec Decode: {calculate_ci(decode_times)} sec") 84 | print(f"TorchCodec Decode+Resize: {calculate_ci(combined_times)} sec") 85 | 86 | except Exception as e: 87 | print(f"TorchCodec failed: {str(e)}") 88 | 89 | # DeepCodec 90 | try: 91 | from deepcodec import VideoReader 92 | 93 | times = [] 94 | for _ in range(NUM_RUNS): 95 | start = time.time() 96 | vr = VideoReader(video_path, num_threads=FIXED_THREADS, 97 | height=height, width=width) 98 | vr.interpolation = "LANCZOS" 99 | _ = vr.get_batch(indices) 100 | elapsed = time.time() - start 101 | times.append(elapsed) 102 | del vr 103 | 104 | print(f"DeepCodec: {calculate_ci(times)} sec (includes resize)") 105 | 106 | except Exception as e: 107 | print(f"DeepCodec failed: {str(e)}") 108 | 109 | # Decord 110 | try: 111 | import decord 112 | from decord import VideoReader as DecordVideoReader 113 | from decord import cpu 114 | 115 | times = [] 116 | for _ in range(NUM_RUNS): 117 | start = time.time() 118 | vr = DecordVideoReader(video_path, ctx=cpu(0), 119 | num_threads=FIXED_THREADS, 120 | height=height, width=width) 121 | _ = vr.get_batch(indices) 122 | elapsed = time.time() - start 123 | times.append(elapsed) 124 | del vr 125 | 126 | print(f"Decord: {calculate_ci(times)} sec (includes resize)") 127 | 128 | except Exception as e: 129 | print(f"Decord failed: {str(e)}") 130 | 131 | if __name__ == "__main__": 132 | main() -------------------------------------------------------------------------------- /sparsity_timing.py: -------------------------------------------------------------------------------- 1 | import time 2 | from deepcodec import VideoReader as DCVR 3 | import sys 4 | import numpy as np 5 | from scipy import stats 6 | from torchvision import transforms as T 7 | 8 | NUM_RUNS = 5 # Number of runs for averaging and confidence intervals 9 | height = 448 10 | width = 448 11 | 12 | def main(): 13 | video_path = sys.argv[1] 14 | num_threads = 16 # Fixed number of threads 15 | temp = DCVR(video_path, num_threads=1) 16 | num_frames = len(temp) 17 | fps = round(temp.get_fps()) 18 | print(f"FPS is {fps}") 19 | del temp 20 | 21 | seconds_between_frames = [1, 2, 4, 8, 16] 22 | for seconds in seconds_between_frames: 23 | frame_interval = seconds * fps 24 | indices = list(range(0, num_frames, frame_interval)) 25 | print(f"\n===== Sampling every {seconds} seconds (interval = {frame_interval} frames) =====") 26 | 27 | # TorchCodec 28 | try: 29 | import torch 30 | from torchcodec.decoders import VideoDecoder 31 | 32 | resize_transform = T.Resize((height, width), interpolation=T.InterpolationMode.BICUBIC, antialias=True) 33 | 34 | decode_times = [] 35 | combined_times = [] 36 | 37 | for _ in range(NUM_RUNS): 
38 | # Measure decoding time 39 | start_decode = time.time() 40 | device = "cpu" 41 | decoder = VideoDecoder(video_path, device=device, num_ffmpeg_threads=num_threads) 42 | b = decoder.get_frames_at(indices=indices).data 43 | elapsed_decode = time.time() - start_decode 44 | decode_times.append(elapsed_decode) 45 | del decoder # Release decoder resources 46 | 47 | # Measure resizing time 48 | start_resize = time.time() 49 | resized_frames = resize_transform(b) 50 | elapsed_resize = time.time() - start_resize 51 | combined_time = elapsed_decode + elapsed_resize 52 | combined_times.append(combined_time) 53 | del b, resized_frames # Free memory 54 | 55 | def calculate_ci(times): 56 | mean_time = np.mean(times) 57 | n = len(times) 58 | if n >= 2: 59 | ci = stats.t.interval(0.95, n-1, loc=mean_time, scale=stats.sem(times)) 60 | return f"{mean_time:.3f} ± {(ci[1] - ci[0])/2:.3f}" 61 | return f"{mean_time:.3f} ± 0.000" 62 | 63 | decode_ci = calculate_ci(decode_times) 64 | combined_ci = calculate_ci(combined_times) 65 | 66 | print(f"TorchCodec Decode Only: {decode_ci} sec (95% CI over {len(decode_times)} runs)") 67 | print(f"TorchCodec Combined (Decode + Resize): {combined_ci} sec (95% CI over {len(combined_times)} runs)") 68 | 69 | except Exception as e: 70 | print(f"TorchCodec failed: {str(e)}") 71 | 72 | # DeepCodec 73 | try: 74 | from deepcodec import VideoReader 75 | 76 | times = [] 77 | for _ in range(NUM_RUNS): 78 | start = time.time() 79 | vr = VideoReader(video_path, num_threads=num_threads, height=height, width=width) 80 | vr.interpolation = "LANCZOS" 81 | _ = vr.get_batch(indices) 82 | elapsed = time.time() - start 83 | times.append(elapsed) 84 | del vr 85 | 86 | mean_time = np.mean(times) 87 | std_dev = np.std(times, ddof=1) 88 | n = len(times) 89 | if n >= 2: 90 | ci = stats.t.interval(0.95, n-1, loc=mean_time, scale=stats.sem(times)) 91 | ci_str = f"{mean_time:.3f} ± {(ci[1] - ci[0])/2:.3f}" 92 | else: 93 | ci_str = f"{mean_time:.3f} ± 0.000" 94 | print(f"DeepCodec: {ci_str} sec (95% CI over {n} runs)") 95 | except Exception as e: 96 | print(f"DeepCodec failed: {str(e)}") 97 | 98 | # Decord 99 | try: 100 | import decord 101 | from decord import VideoReader as DecordVideoReader 102 | from decord import cpu 103 | 104 | times = [] 105 | for _ in range(NUM_RUNS): 106 | start = time.time() 107 | vr = DecordVideoReader(video_path, ctx=cpu(0), num_threads=num_threads, height=height, width=width) 108 | _ = vr.get_batch(indices) 109 | elapsed = time.time() - start 110 | times.append(elapsed) 111 | del vr 112 | 113 | mean_time = np.mean(times) 114 | std_dev = np.std(times, ddof=1) 115 | n = len(times) 116 | if n >= 2: 117 | ci = stats.t.interval(0.95, n-1, loc=mean_time, scale=stats.sem(times)) 118 | ci_str = f"{mean_time:.3f} ± {(ci[1] - ci[0])/2:.3f}" 119 | else: 120 | ci_str = f"{mean_time:.3f} ± 0.000" 121 | print(f"Decord: {ci_str} sec (95% CI over {n} runs)") 122 | except Exception as e: 123 | print(f"Decord failed: {str(e)}") 124 | 125 | if __name__ == "__main__": 126 | main() -------------------------------------------------------------------------------- /timing.py: -------------------------------------------------------------------------------- 1 | import time 2 | from deepcodec import VideoReader as DCVR 3 | import sys 4 | import numpy as np 5 | from scipy import stats 6 | from torchvision import transforms as T 7 | 8 | 9 | NUM_RUNS = 5 # Number of runs for averaging and confidence intervals 10 | 11 | height = 448 12 | width = 448 13 | 14 | def main(): 15 | video_path = sys.argv[1] 
16 | max_num_threads = [2, 4, 8, 16, 32] 17 | temp = DCVR(video_path, num_threads=1) 18 | num_frames = len(temp) 19 | fps = round(temp.get_fps()) 20 | print(f"FPS is {fps}") 21 | indices = list(range(0, num_frames, fps)) 22 | del temp 23 | 24 | for thread in max_num_threads: 25 | print(f"\n===== Testing with {thread} threads =====") 26 | 27 | # TorchCodec 28 | try: 29 | import torch 30 | from torchcodec.decoders import VideoDecoder 31 | resize_transform = T.Resize((448, 448), interpolation=T.InterpolationMode.BICUBIC, antialias=True) 32 | 33 | times = [] 34 | decode_times = [] 35 | combined_times = [] 36 | 37 | for _ in range(NUM_RUNS): 38 | # Measure decoding time 39 | start_decode = time.time() 40 | device = "cpu" 41 | decoder = VideoDecoder(video_path, device=device, num_ffmpeg_threads=thread) 42 | b = decoder.get_frames_at(indices=indices).data 43 | elapsed_decode = time.time() - start_decode 44 | decode_times.append(elapsed_decode) 45 | del decoder # Ensure decoder is released 46 | 47 | #print(f"Shape before resize: {b.shape}") 48 | # Measure resizing time 49 | start_resize = time.time() 50 | r = resize_transform(b) 51 | elapsed_resize = time.time() - start_resize 52 | #print(f"Shape after resize: {r.shape}") 53 | del b 54 | del r 55 | # Track combined time (decode + resize) 56 | combined_time = elapsed_decode + elapsed_resize 57 | combined_times.append(combined_time) 58 | 59 | def calculate_ci(times): 60 | mean_time = np.mean(times) 61 | n = len(times) 62 | if n >= 2: 63 | ci = stats.t.interval(0.95, n-1, loc=mean_time, scale=stats.sem(times)) 64 | return f"{mean_time:.3f} ± {(ci[1] - ci[0])/2:.3f}" 65 | return f"{mean_time:.3f} ± 0.000" 66 | 67 | decode_ci = calculate_ci(decode_times) 68 | combined_ci = calculate_ci(combined_times) 69 | 70 | print(f"Decode Only: {decode_ci} sec (95% CI over {len(decode_times)} runs)") 71 | print(f"Combined (Decode + Resize): {combined_ci} sec (95% CI over {len(combined_times)} runs)") 72 | 73 | except Exception as e: 74 | print(f"TorchCodec failed: {str(e)}") 75 | 76 | # DeepCodec 77 | try: 78 | from deepcodec import VideoReader 79 | 80 | times = [] 81 | for _ in range(NUM_RUNS): 82 | start = time.time() 83 | vr = VideoReader(video_path, num_threads=thread, height=height, width=width) 84 | vr.interpolation = "LANCZOS" 85 | _ = vr.get_batch(indices) 86 | elapsed = time.time() - start 87 | times.append(elapsed) 88 | del vr 89 | 90 | mean_time = np.mean(times) 91 | std_dev = np.std(times, ddof=1) 92 | n = len(times) 93 | if n >= 2: 94 | ci = stats.t.interval(0.95, n-1, loc=mean_time, scale=stats.sem(times)) 95 | ci_str = f"{mean_time:.3f} ± {(ci[1] - ci[0])/2:.3f}" 96 | else: 97 | ci_str = f"{mean_time:.3f} ± 0.000" 98 | 99 | print(f"DeepCodec: {ci_str} sec (95% CI over {n} runs)") 100 | 101 | except Exception as e: 102 | print(f"DeepCodec failed: {str(e)}") 103 | 104 | # Decord 105 | try: 106 | import decord 107 | from decord import VideoReader as DecordVideoReader 108 | from decord import cpu 109 | 110 | times = [] 111 | for _ in range(NUM_RUNS): 112 | start = time.time() 113 | vr = DecordVideoReader(video_path, ctx=cpu(0), num_threads=thread, height=height, width=width) 114 | _ = vr.get_batch(indices) 115 | elapsed = time.time() - start 116 | times.append(elapsed) 117 | del vr 118 | 119 | mean_time = np.mean(times) 120 | std_dev = np.std(times, ddof=1) 121 | n = len(times) 122 | if n >= 2: 123 | ci = stats.t.interval(0.95, n-1, loc=mean_time, scale=stats.sem(times)) 124 | ci_str = f"{mean_time:.3f} ± {(ci[1] - ci[0])/2:.3f}" 125 | else: 126 | ci_str 
= f"{mean_time:.3f} ± 0.000" 127 | 128 | print(f"Decord: {ci_str} sec (95% CI over {n} runs)") 129 | 130 | except Exception as e: 131 | print(f"Decord failed: {str(e)}") 132 | 133 | if __name__ == "__main__": 134 | main() -------------------------------------------------------------------------------- /lvu/lvu.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | from transformers import AutoProcessor, AutoModelForImageTextToText 4 | from .models import lvu_init_model_map, lvu_run_model_map, lvu_chat_model_map 5 | from .lvu_config import LVUConfig 6 | 7 | class LVU: 8 | def __init__(self, config, model=None, processor=None, model_init_kwargs={}): 9 | self.config = config 10 | if model is None: 11 | model_init_kwargs = { 12 | "torch_dtype": torch.bfloat16, 13 | "device_map": "auto", 14 | "attn_implementation": "flash_attention_2", 15 | } 16 | model = AutoModelForImageTextToText.from_pretrained(config.model_name_or_path, **model_init_kwargs) 17 | 18 | # time processing 19 | if processor is None: 20 | processor = AutoProcessor.from_pretrained(config.model_name_or_path) 21 | 22 | self.model = model 23 | self.processor = processor 24 | self.model = self.init_lvu() 25 | 26 | def run_model_func(self, question, video_path, **generation_kwargs): 27 | raise NotImplementedError("run_model_func not implemented.") 28 | 29 | def chat_model_func(self, messages, **generation_kwargs): 30 | raise NotImplementedError("chat_model_func not implemented.") 31 | 32 | def init_lvu(self): 33 | if self.config.model_type not in lvu_init_model_map: 34 | raise ValueError(f"Model type {self.config.model_type} not supported.") 35 | 36 | init_model_func = lvu_init_model_map[self.config.model_type] 37 | run_model_func = lvu_run_model_map[self.config.model_type] 38 | model = init_model_func(self.model, self.config) 39 | self.run_model_func = run_model_func.__get__(self) 40 | if self.config.model_type in lvu_chat_model_map: 41 | self.chat_model_func = lvu_chat_model_map[self.config.model_type].__get__(self) 42 | 43 | return model 44 | 45 | def generate(self, question, video_path, **generation_kwargs): 46 | if self.config.model_type not in lvu_run_model_map: 47 | raise ValueError(f"Model type {self.config.model_type} not supported.") 48 | 49 | output = self.run_model_func(question, video_path, **generation_kwargs) 50 | 51 | return output 52 | 53 | def chat(self, messages:dict, **generation_kwargs): 54 | if self.config.model_type not in lvu_run_model_map: 55 | raise ValueError(f"Model type {self.config.model_type} not supported.") 56 | output = self.chat_model_func(messages, **generation_kwargs) 57 | return output 58 | 59 | def main( 60 | model_name_or_path: str = "Qwen/Qwen2.5-VL-7B-Instruct", 61 | model_type: str = "qwen25_lvu", 62 | video_group_size: int = 1, 63 | video_path: str = "Q8AZ16uBhr8_resized_fps2_mute.mp4", 64 | top_k_predict_type: str = "key_norms_small", 65 | num_frames=64, 66 | ): 67 | assert isinstance(video_path, str), "video_path should be a string." 68 | assert os.path.exists(video_path), f"video_path {video_path} does not exist." 
69 | config = LVUConfig( 70 | model_name_or_path=model_name_or_path, 71 | model_type=model_type, 72 | # top_k_predict_type="query_attention_weights", 73 | # top_k_predict_type="query_attention_weights_by_value_norm", 74 | top_k_predict_type=top_k_predict_type, 75 | video_group_size=video_group_size, 76 | top_k=None, 77 | top_p=0.2, 78 | prefill_prune_starting_layer=None, 79 | adaptive_local_attention=True, 80 | num_frames=num_frames, 81 | save_video_cache=True, 82 | # fps=1, 83 | use_tqdm=True, 84 | # top_k_decay_type="linear", 85 | # top_k_decay_factor=0.33, 86 | ) 87 | lvu = LVU(config) 88 | 89 | # question = "Describe this video." 90 | # video_path = "Q8AZ16uBhr8_resized_fps2_mute.mp4" 91 | # generation_kwargs = { 92 | # "max_new_tokens": 512, 93 | # "do_sample": False, 94 | # "top_p": 1.0, 95 | # } 96 | # output = lvu.generate(question, video_path, **generation_kwargs) 97 | # print(output) 98 | 99 | DEMO_QUESTIONS = [ 100 | "As depicted in the video, how is the relationship between the rabbit and human?\nOptions:\nA. Hostile.\nB. Friend.\nC. Cooperator.\nD. No one is correct above.\nAnswer with the option's letter from the given choices directly.", 101 | # "What is the impression of the video?\nOptions:\nA. Sad.\nB. Funny.\nC. Horrible.\nD. Silent.\nAnswer with the option's letter from the given choices directly.", 102 | # "What is the subject of the video?\nOptions:\nA. Rabbit likes to eat carrots.\nB. How to raise a rabbit.\nC. A rabbit gives people trouble.\nD. A rabbit performs for food.\nAnswer with the option's letter from the given choices directly.", 103 | ] 104 | EXPECTED_ANSWERS = ['A', 'B', 'C'] 105 | 106 | for question, expected_answer in zip(DEMO_QUESTIONS, EXPECTED_ANSWERS): 107 | generation_kwargs = { 108 | "max_new_tokens": 512, 109 | "do_sample": False, 110 | "top_p": 1.0, 111 | } 112 | output = lvu.generate(question, video_path, **generation_kwargs) 113 | print(f"Question: {question}") 114 | print(f"Expected Answer: {expected_answer}") 115 | print(f"Model Output: {output}") 116 | 117 | if __name__ == "__main__": 118 | import fire 119 | fire.Fire(main) -------------------------------------------------------------------------------- /lvu/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import random 3 | import torch.nn.functional as F 4 | from typing import Tuple, Union, Optional 5 | from .lvu_config import LVUConfig, LVULayerConfig 6 | from .lvu_cache import DynamicCache, LVUCache 7 | import math 8 | import time 9 | import threading 10 | from queue import Queue 11 | import numpy as np 12 | from PIL import Image 13 | import os 14 | 15 | def get_top_k_mask_to_predict(attn_weights, keys, values, outputs, top_k=100, predict_type="attention_weights"): 16 | """ 17 | Args: 18 | attn_weights: (bz, 1, Q_len, K_len) or (bz, K_len) 19 | keys: (bz, num_heads, Q_len, C) 20 | values: (bz, num_heads, K_len, C) 21 | outputs: (bz, Q_len, C) 22 | Returns: 23 | top_k_mask: (bz, K_len) 24 | """ 25 | if top_k <= 0: 26 | return None 27 | # random.seed(0) 28 | bz, _, k_len, _ = values.shape 29 | bz_top_k_idxs = [] 30 | for bz_i in range(bz): 31 | if attn_weights is not None: 32 | if attn_weights.dim() == 4: 33 | attn_weights_i = attn_weights[bz_i].mean(0)[:, -k_len:] # (K_len, K_len) 34 | elif attn_weights.dim() == 2: 35 | attn_weights_i = attn_weights[bz_i] # (k_len) 36 | else: 37 | raise ValueError(f"Unknown attn_weights shape: {attn_weights.shape}") 38 | else: 39 | attn_weights_i = None 40 | keys_i = keys[bz_i] 41 | values_i = 
values[bz_i] 42 | outputs_i = outputs[bz_i] 43 | if predict_type == "salient_tokens": 44 | slident_value = [] 45 | for i in range(len(attn_weights_i)): 46 | weights = attn_weights_i[i:, i] 47 | slident_value.append(weights.std().item() + weights.mean().item()) 48 | top_k_idxs = sorted(range(len(slident_value)), key=lambda x: slident_value[x], reverse=True)[:top_k] 49 | elif predict_type == "attention_weights": 50 | mean_weights = [] 51 | for i in range(len(attn_weights_i)): 52 | weights = attn_weights_i[i:, i] 53 | mean_weights.append(weights.mean().item()) 54 | top_k_idxs = sorted(range(len(mean_weights)), key=lambda x: mean_weights[x], reverse=True)[:top_k] 55 | elif predict_type == "query_attention_weights": 56 | assert attn_weights_i is not None and attn_weights_i.dim() == 1, f"attn_weights_i should be 1D, but got {attn_weights_i.shape}" 57 | top_k_idxs = attn_weights_i.argsort(descending=True)[:top_k].tolist() 58 | elif predict_type == "query_attention_weights_by_value_norm": 59 | assert attn_weights_i is not None and attn_weights_i.dim() == 1, f"attn_weights should be 1D, but got {attn_weights_i.shape}" 60 | cur_layer_value_vectors = values_i.transpose(0, 1).flatten(1, 2) 61 | vector_norms = cur_layer_value_vectors.norm(2, dim=-1) 62 | weighted_vector_norms = attn_weights_i * vector_norms 63 | top_k_idxs = weighted_vector_norms.argsort(descending=True)[:top_k].tolist() 64 | elif predict_type == "attention_weights_sum": 65 | sum_weights = [] 66 | for i in range(len(attn_weights_i)): 67 | weights = attn_weights_i[i:, i] 68 | sum_weights.append(weights.sum().item()) 69 | top_k_idxs = sorted(range(len(sum_weights)), key=lambda x: sum_weights[x], reverse=True)[:top_k] 70 | elif predict_type == "attention_weights_sum_head_tail": 71 | sum_weights = [] 72 | for i in range(len(attn_weights_i)): 73 | weights = attn_weights_i[i:, i] 74 | sum_weights.append(weights.sum().item()) 75 | top_k_idxs = sorted(range(len(sum_weights)), key=lambda x: sum_weights[x], reverse=True) 76 | top_k_idxs = top_k_idxs[:top_k//2] + top_k_idxs[-top_k//2:] 77 | elif predict_type == "attention_weights_sum_per_image": 78 | sum_weights = [] 79 | for i in range(len(attn_weights_i)): 80 | weights = attn_weights_i[i:i+258, i] # 258 is the number of tokens in an image 81 | sum_weights.append(weights.sum().item()) 82 | top_k_idxs = sorted(range(len(sum_weights)), key=lambda x: sum_weights[x], reverse=True)[:top_k] 83 | elif predict_type == "attention_weights_sum_with_random": 84 | sum_weights = [] 85 | for i in range(len(attn_weights_i)): 86 | weights = attn_weights_i[i:, i] 87 | sum_weights.append(weights.sum().item()) 88 | top_k_idxs = sorted(range(len(sum_weights)), key=lambda x: sum_weights[x], reverse=True) 89 | top_k_idxs = top_k_idxs[:top_k//2] 90 | random_top_k_idxs = list(set(list(range(len(sum_weights)))) - set(top_k_idxs)) 91 | random_top_k_idxs = random.sample(random_top_k_idxs, min(top_k//2, len(random_top_k_idxs))) 92 | top_k_idxs.extend(random_top_k_idxs) 93 | elif predict_type == "attention_weights_deduplication": 94 | # pivot:retained tokens = 1:32 95 | num_pivot_tokens = (top_k - 1) // 2 + 1 96 | sum_weights = [] 97 | for i in range(len(attn_weights_i)): 98 | weights = attn_weights_i[i:, i] 99 | sum_weights.append(weights.sum().item()) 100 | top_k_idxs = sorted(range(len(sum_weights)), key=lambda x: sum_weights[x], reverse=True) 101 | top_k_idxs, other_top_k_idxs = top_k_idxs[:num_pivot_tokens], top_k_idxs[num_pivot_tokens:] 102 | # select num_other_tokens from other_top_k_idxs by the lowest cosine 
similarity 103 | cur_layer_value_vectors = values_i.transpose(0, 1).flatten(1, 2) 104 | local_self_attn_value_vectors = cur_layer_value_vectors[:attn_weights_i.shape[0]] 105 | pivot_tokens_values = local_self_attn_value_vectors[top_k_idxs] # (P, C) 106 | other_tokens_values = local_self_attn_value_vectors[other_top_k_idxs] # (O, C) 107 | # Step 1: Normalize both sets of vectors 108 | pivot_tokens_normalized = F.normalize(pivot_tokens_values, p=2, dim=1) # Normalize along embedding dimension 109 | other_tokens_normalized = F.normalize(other_tokens_values, p=2, dim=1) # Normalize along embedding dimension 110 | 111 | # Step 2: Compute the cosine similarity matrix 112 | # This performs a matrix multiplication: (P, C) × (C, O) = (P, O) 113 | cosine_similarity_matrix = torch.matmul(pivot_tokens_normalized, other_tokens_normalized.transpose(0, 1)) 114 | top_k_idxs.extend([other_top_k_idxs[j] for j in cosine_similarity_matrix.mean(dim=0).argsort()[:top_k - num_pivot_tokens]]) 115 | 116 | # # select the num_pick_tokens from other_top_k_idxs for each pivot token 117 | # for i in range(len(top_k_idxs)): 118 | # pivot_cosine_similarity = cosine_similarity_matrix[i] 119 | # top_k_idxs.extend([other_top_k_idxs[j] for j in pivot_cosine_similarity.argsort()[:num_pivot_tokens]]) 120 | top_k_idxs = list(set(top_k_idxs)) 121 | elif predict_type == "vector_norms": 122 | cur_layer_value_vectors = values_i.transpose(0, 1).flatten(1, 2) 123 | vector_norms = cur_layer_value_vectors.norm(2, dim=-1) 124 | top_k_idxs = vector_norms.argsort(descending=True)[:top_k].tolist() 125 | elif predict_type == "vector_norms_small": 126 | cur_layer_value_vectors = values_i.transpose(0, 1).flatten(1, 2) 127 | vector_norms = cur_layer_value_vectors.norm(2, dim=-1) 128 | top_k_idxs = vector_norms.argsort(descending=False)[:top_k].tolist() 129 | elif predict_type == "key_norms": 130 | cur_layer_key_vectors = keys_i.transpose(0, 1).flatten(1, 2) 131 | key_norms = cur_layer_key_vectors.norm(2, dim=-1) 132 | top_k_idxs = key_norms.argsort(descending=True)[:top_k].tolist() 133 | elif predict_type == "key_norms_small": 134 | cur_layer_key_vectors = keys_i.transpose(0, 1).flatten(1, 2) 135 | key_norms = cur_layer_key_vectors.norm(2, dim=-1) 136 | top_k_idxs = key_norms.argsort(descending=False)[:top_k].tolist() 137 | elif predict_type == "key_norms_small_random": 138 | # half of the top_k tokens are selected by the highest key norms, and the other half are randomly selected 139 | cur_layer_key_vectors = keys_i.transpose(0, 1).flatten(1, 2) 140 | key_norms = cur_layer_key_vectors.norm(2, dim=-1) 141 | sorted_idxs = key_norms.argsort(descending=False) 142 | top_k_idxs = sorted_idxs[:top_k//2].tolist() 143 | random_top_k_idxs = sorted_idxs[top_k//2:].tolist() 144 | random_top_k_idxs = random.sample(random_top_k_idxs, min(top_k//2, len(random_top_k_idxs))) 145 | top_k_idxs.extend(random_top_k_idxs) 146 | elif predict_type == "random": 147 | top_k_idxs = random.sample(range(k_len), top_k) 148 | if 0 not in top_k_idxs: 149 | top_k_idxs.append(0) 150 | elif predict_type == "key_norms_small_deduplication": 151 | num_pivot_tokens = (top_k - 1) // 16 + 1 152 | key_vectors = keys_i.transpose(0, 1).flatten(1, 2) 153 | key_norms = key_vectors.norm(2, dim=-1) 154 | sorted_idxs = key_norms.argsort(descending=False) 155 | top_k_idxs = sorted_idxs[:num_pivot_tokens].tolist() 156 | other_top_k_idxs = sorted_idxs[num_pivot_tokens:].tolist() 157 | # select num_other_tokens from other_top_k_idxs by the lowest cosine similarity 158 | # keys_i: (num_heads, 
Q_len, C) 159 | normalized_key_vectors = F.normalize(key_vectors, p=2, dim=-1) 160 | pivot_key_vectors = normalized_key_vectors[top_k_idxs] # (P, C) 161 | other_key_vectors = normalized_key_vectors[other_top_k_idxs] 162 | cosine_similarity_matrix = torch.matmul(pivot_key_vectors, other_key_vectors.transpose(0, 1)) 163 | top_k_idxs.extend([other_top_k_idxs[j] for j in cosine_similarity_matrix.mean(dim=0).argsort()[:top_k - num_pivot_tokens]]) 164 | top_k_idxs = list(set(top_k_idxs)) 165 | elif predict_type == "key_weighted_vector_norms": 166 | cur_layer_key_vectors = keys_i.transpose(0, 1).flatten(1, 2) 167 | key_norms = cur_layer_key_vectors.norm(2, dim=-1) 168 | # softmax the key norms 169 | key_norms = F.softmax(key_norms, dim=-1) 170 | cur_layer_value_vectors = values_i.transpose(0, 1).flatten(1, 2) 171 | value_norms = cur_layer_value_vectors.norm(2, dim=-1) 172 | weighted_norms = key_norms * value_norms 173 | top_k_idxs = weighted_norms.argsort(descending=True)[:top_k].tolist() 174 | elif predict_type == "output_norms": 175 | outputs_norms = outputs_i.norm(2, dim=-1) 176 | top_k_idxs = outputs_norms.argsort(descending=True)[:top_k].tolist() 177 | elif predict_type == "weighted_norms": 178 | weights = attn_weights_i # (Q_len, K_len) 179 | cur_layer_value_vectors = values_i.transpose(0, 1).flatten(1, 2) # (K_len, C) 180 | all_weighted_norms = [] 181 | for q_i in range(len(weights)): 182 | cur_weights = weights[q_i] 183 | weighted_vectors = cur_weights.unsqueeze(-1) * cur_layer_value_vectors 184 | weighted_norms = weighted_vectors.norm(2, dim=-1) 185 | all_weighted_norms.append(weighted_norms) 186 | all_weighted_norms = torch.stack(all_weighted_norms, dim=0).mean(dim=0) 187 | top_k_idxs = all_weighted_norms.argsort(descending=True)[:top_k].tolist() 188 | else: 189 | raise ValueError(f"Unknown predict type: {predict_type}") 190 | bz_top_k_idxs.append(top_k_idxs) 191 | bz_top_k_idxs = torch.tensor(bz_top_k_idxs, device=values.device) 192 | top_k_select_mask = torch.zeros(bz, k_len, dtype=torch.bool, device=values.device) 193 | top_k_select_mask.scatter_(1, bz_top_k_idxs, 1) 194 | return top_k_select_mask 195 | 196 | 197 | def post_process_kv_cache( 198 | hidden_states: torch.Tensor, 199 | attention_mask: torch.Tensor, 200 | position_ids: torch.LongTensor, 201 | cache_position: torch.Tensor, 202 | position_embeddings: torch.Tensor, 203 | attn_weights: torch.Tensor, 204 | present_key_value: Union[Tuple[torch.Tensor, torch.Tensor], DynamicCache, LVUCache], 205 | lvu_layer_config: LVULayerConfig, 206 | ): 207 | """ 208 | Args: 209 | hidden_states: (bz, Q_len, C) 210 | attn_weights: (bz, 1, Q_len, K_len) 211 | position_ids: (bz, Q_len) 212 | cache_position: (bz, Q_len) 213 | position_embeddings: (bz, Q_len, C) 214 | present_key_value: keys and values: ((bz, num_heads, K_len, C), (bz, num_heads, K_len, C)) 215 | values: 216 | top_k: int 217 | predict_type: str 218 | Returns: 219 | hidden_states: (bz, top_k, C) 220 | attention_mask: (bz, top_k) or None 221 | position_ids: (bz, top_k) or None 222 | cache_position: (bz, top_k) or None 223 | position_embeddings: (bz, top_k, C) or None 224 | present_key_value: ((bz, num_heads, top_k, C), (bz, num_heads, top_k, C)) 225 | 226 | """ 227 | if lvu_layer_config is None: 228 | return hidden_states, attention_mask, position_ids, cache_position, position_embeddings, present_key_value 229 | lvu_config = lvu_layer_config.lvu_config 230 | 231 | top_k = lvu_config.top_k 232 | top_p = lvu_config.top_p 233 | predict_type = lvu_config.top_k_predict_type 234 | 
layer_idx = lvu_layer_config.layer_idx 235 | prune_for_next_layer = lvu_layer_config.prune_for_next_layer 236 | q_len = hidden_states.shape[1] 237 | if isinstance(present_key_value, LVUCache) and present_key_value.prompt_length > 0: 238 | q_len -= present_key_value.prompt_length 239 | attn_weights = present_key_value.accum_attn_scores[layer_idx][-1] 240 | 241 | if top_p is not None and top_p >= 0: 242 | top_k = min((top_k or q_len), int(q_len * top_p)) 243 | 244 | if not lvu_config.top_k_decay_type: 245 | top_k = top_k 246 | elif lvu_config.top_k_decay_type == "linear": 247 | top_k = top_k - int(top_k * (layer_idx / lvu_layer_config.total_layers)) 248 | elif lvu_config.top_k_decay_type == "exponential": 249 | top_k = int(top_k * (lvu_config.top_k_decay_factor ** layer_idx)) 250 | else: 251 | raise ValueError(f"Unknown top_k_decay_type: {lvu_config.top_k_decay_type}") 252 | if not lvu_config.enable or not top_k or top_k <= 0 or q_len <= top_k or \ 253 | (isinstance(lvu_config.top_k_starting_layer, int) and lvu_config.top_k_starting_layer > 0 and lvu_config.layer_idx < lvu_config.top_k_starting_layer): 254 | # no need to prune 255 | return hidden_states, attention_mask, position_ids, cache_position, position_embeddings, present_key_value 256 | 257 | if isinstance(present_key_value, DynamicCache): 258 | keys, values = present_key_value[layer_idx] 259 | elif isinstance(present_key_value, tuple): 260 | keys, values = present_key_value 261 | else: 262 | raise ValueError(f"Unknown present_key_value type: {type(present_key_value)}") 263 | bz = keys.shape[0] 264 | assert bz == 1, f"Only support batch size 1 for now, but got {bz}" 265 | 266 | # only process the current new k 267 | past_keys = keys[:, :, :-q_len] 268 | past_values = values[:, :, :-q_len] 269 | keys = keys[:, :, -q_len:] 270 | values = values[:, :, -q_len:] 271 | old_k_shape = keys.shape 272 | 273 | top_k_select_mask = get_top_k_mask_to_predict(attn_weights, keys, values, hidden_states[:, :q_len], top_k=top_k, predict_type=predict_type) 274 | 275 | top_k_keys_list = [] 276 | top_k_values_list = [] 277 | top_k_hidden_states_list = [] if prune_for_next_layer else None 278 | top_k_attention_mask_list = [] if prune_for_next_layer else None 279 | top_k_position_ids_list = [] if prune_for_next_layer else None 280 | top_k_cache_position_list = [] if prune_for_next_layer else None 281 | top_k_position_embeddings_list = [] if prune_for_next_layer else None 282 | for bz_i in range(bz): 283 | top_k_select_mask_i = top_k_select_mask[bz_i] 284 | indices = torch.nonzero(top_k_select_mask_i, as_tuple=True)[0].cpu() 285 | assert len(indices) == top_k, f"top_k_select_mask_i: {top_k_select_mask_i}, indices: {indices}" 286 | 287 | bz_top_k_keys = keys[bz_i][:, indices] 288 | bz_top_k_values = values[bz_i][:, indices] 289 | top_k_keys_list.append(bz_top_k_keys) 290 | top_k_values_list.append(bz_top_k_values) 291 | 292 | if prune_for_next_layer: 293 | bz_top_k_hidden_states = hidden_states[bz_i][indices] 294 | bz_top_k_cache_position = cache_position[indices] 295 | if position_ids.dim() == 3: 296 | # (constant, bz, q_len) 297 | bz_top_k_position_ids = position_ids[:, bz_i][:, indices] 298 | elif position_ids.dim() == 2: 299 | # (bz, q_len) 300 | bz_top_k_position_ids = position_ids[bz_i][indices] 301 | if isinstance(position_embeddings, tuple): 302 | bz_top_k_position_embeddings = [] 303 | for x in position_embeddings: 304 | if x.dim() == 4: 305 | # (constant, bz, q_len, c) 306 | bz_top_k_position_embeddings.append(x[:, bz_i][:, indices]) 307 | elif 
x.dim() == 3: 308 | # (bz, q_len, c) 309 | bz_top_k_position_embeddings.append(x[bz_i][indices]) 310 | else: 311 | raise ValueError(f"Unknown position_embeddings shape: {x.shape}") 312 | bz_top_k_position_embeddings = tuple(bz_top_k_position_embeddings) 313 | elif position_embeddings.dim() == 3: 314 | # (bz, q_len, c) 315 | bz_top_k_position_embeddings = position_embeddings[bz_i][indices] 316 | else: 317 | raise ValueError(f"Unknown position_embeddings type: {type(position_embeddings)}") 318 | if attention_mask is not None: 319 | if attention_mask.dim() == 2: 320 | bz_top_k_attention_mask = attention_mask[bz_i][indices] 321 | elif attention_mask.dim() == 4: 322 | bz_top_k_attention_mask = attention_mask[bz_i][:, indices, indices] 323 | else: 324 | raise ValueError(f"Unknown attention_mask shape: {attention_mask.shape}") 325 | else: 326 | bz_top_k_attention_mask = None 327 | top_k_hidden_states_list.append(bz_top_k_hidden_states) 328 | top_k_attention_mask_list.append(bz_top_k_attention_mask) 329 | top_k_position_ids_list.append(bz_top_k_position_ids) 330 | top_k_cache_position_list.append(bz_top_k_cache_position) 331 | top_k_position_embeddings_list.append(bz_top_k_position_embeddings) 332 | 333 | top_k_keys = torch.stack(top_k_keys_list, dim=0) 334 | top_k_values = torch.stack(top_k_values_list, dim=0) 335 | keys = torch.cat([past_keys, top_k_keys], dim=2) 336 | values = torch.cat([past_values, top_k_values], dim=2) 337 | if isinstance(present_key_value, DynamicCache): 338 | # present_key_value.update(keys, values, layer_idx) 339 | present_key_value.key_cache[layer_idx] = keys 340 | present_key_value.value_cache[layer_idx] = values 341 | else: 342 | present_key_value = (keys, values) 343 | 344 | if prune_for_next_layer: 345 | hidden_states = torch.stack(top_k_hidden_states_list, dim=0) 346 | if not top_k_attention_mask_list or None in top_k_attention_mask_list: 347 | attention_mask = None 348 | else: 349 | attention_mask = torch.stack(top_k_attention_mask_list, dim=0) 350 | if position_ids.dim() == 3: 351 | position_ids = torch.stack(top_k_position_ids_list, dim=1) 352 | elif position_ids.dim() == 2: 353 | position_ids = torch.stack(top_k_position_ids_list, dim=0) 354 | cache_position = top_k_cache_position_list[0] 355 | 356 | if isinstance(position_embeddings, tuple): 357 | new_position_embeddings = [] 358 | for i in range(len(position_embeddings)): 359 | if position_embeddings[i].dim() == 4: 360 | # (constant, bz, q_len, c), stack in the batch dim 361 | new_position_embeddings.append(torch.stack([x[i] for x in top_k_position_embeddings_list], dim=1)) 362 | elif position_embeddings[i].dim() == 3: 363 | # (bz, q_len, c), stack in the batch dim 364 | new_position_embeddings.append(torch.stack([x[i] for x in top_k_position_embeddings_list], dim=0)) 365 | else: 366 | raise ValueError(f"Unknown position_embeddings shape: {position_embeddings[i].shape}") 367 | position_embeddings = tuple(new_position_embeddings) 368 | elif position_embeddings.dim() == 3: 369 | # (bz, q_len, c), stack in the batch dim 370 | position_embeddings = torch.stack(top_k_position_embeddings_list, dim=0) 371 | else: 372 | raise ValueError(f"Unknown position_embeddings type: {type(position_embeddings)}") 373 | 374 | # print(f"Reduced keys and values from {old_k_shape} to {top_k_keys.shape} for layer {layer_idx}") 375 | 376 | return hidden_states, attention_mask, position_ids, cache_position, position_embeddings, present_key_value -------------------------------------------------------------------------------- /README.md: 
-------------------------------------------------------------------------------- 1 | # QuickVideo 2 |

3 | QuickVideo Logo 4 |

5 | 6 |

7 | Efficient video loading and context prefill for hour-long video understanding 8 |

9 | 10 |

11 | Benjamin Schneider*Dongfu Jiang*Chao DuTianyu PangWenhu Chen 12 |

13 | 14 |

15 | University of Waterloo • SeaAI Lab 16 |

17 | 18 |

19 | *Equal contribution 20 |

21 | 22 |

23 | | 24 | Quick Start | 25 | Paper | 26 | QuickCodec | 27 | QuickPrefill 28 | | 29 |

30 | 31 | --- 32 | 33 | ## 🎯 Overview 34 | 35 | Long video understanding has emerged as a crucial capability for real-world applications such as meeting summarization, video surveillance, educational lecture analysis, and content moderation. However, it remains computationally prohibitive for VideoLLMs due to two critical bottlenecks: 36 | 37 | 1. **Sequential video decoding** - Converting raw bit streams to RGB frames can take up to a minute for hour-long videos 38 | 2. **Costly prefilling** - Processing millions of tokens for LLM inference results in high latency and memory usage 39 | 40 |

41 | QuickVideo System Overview 42 |

43 | 44 | **QuickVideo** is a system-algorithm co-design that achieves **3.5× speedup** (from 70s to 20s for 1-hour videos) while maintaining **97% performance** with **50% less memory**. 45 | 46 | ## 🚀 Key Innovations 47 | 48 | ### 🔧 QuickDecoder 49 | - **Parallelized CPU-based decoder** that splits videos into keyframe-aligned intervals 50 | - **2-3× faster** than sequential processing through concurrent execution 51 | 52 | ### ⚡ QuickPrefill 53 | - **Group-based prefilling** for memory-efficient activation handling 54 | - **KV-cache pruning** using key norm selection (L2) to retain only essential tokens 55 | - **50% memory reduction** while preserving 97% of original performance 56 | 57 | ### 🔄 Overlapping Pipeline 58 | - **Concurrent CPU decoding and GPU inference** to minimize end-to-end latency 59 | - Intelligent scheduling reduces total processing time significantly 60 | 61 |

62 | Pipeline Optimization 63 |
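To make the key-norm pruning used by QuickPrefill concrete, below is a minimal sketch of the idea (not the repository's implementation; the per-head averaging and the tensor shapes are illustrative assumptions): for each group, score every cached token by the L2 norm of its key vector and keep only the `top_k` tokens with the smallest norms.

```python
import torch

def prune_kv_by_key_norm(keys: torch.Tensor, values: torch.Tensor, top_k: int):
    """Keep the top_k tokens whose key vectors have the smallest L2 norm.

    keys/values: (batch, num_heads, seq_len, head_dim). Returns the pruned
    keys/values plus the kept indices, sorted to preserve temporal order.
    """
    # Average per-head key norms so all heads keep the same token positions.
    key_norms = keys.norm(dim=-1).mean(dim=1)                 # (batch, seq_len)
    k = min(top_k, key_norms.shape[-1])
    kept = torch.topk(-key_norms, k=k, dim=-1).indices        # smallest norms first
    kept = kept.sort(dim=-1).values                           # keep original order
    idx = kept[:, None, :, None].expand(-1, keys.shape[1], -1, keys.shape[-1])
    return keys.gather(2, idx), values.gather(2, idx), kept

# Toy example: shrink a 1,000-token KV cache of one group down to 64 tokens.
keys = torch.randn(1, 8, 1000, 128)
values = torch.randn(1, 8, 1000, 128)
small_k, small_v, kept = prune_kv_by_key_norm(keys, values, top_k=64)
print(small_k.shape)  # torch.Size([1, 8, 64, 128])
```

Scoring by key norm needs no extra attention pass, which is what keeps this selection cheap compared to attention-score based pruning.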

64 | 65 | ## 📊 Performance Results 66 | 67 | We evaluate QuickCodec on video decoding efficiency (left figure) and QuickPrefill on average QA accuracy across four long video understanding benchmarks: VideoMME, LongVideoBench, LVBench, and MLVU (right figure and the collapsible table below). The results show significant speedups and memory savings while preserving 97% of the original performance. 68 | 69 | 70 | 71 | 74 | 77 | 78 |
72 | Video Processing Times 73 | 75 | KV Pruning Average Performance 76 |
79 | 80 |
<details>
<summary>Performance Table</summary>

| Group Size | KV Pruning Method | ρ | VideoMME | LongVideoBench (val) | LVBench | MLVU (dev) | Avg | Performance |
|---|---|---|---|---|---|---|---|---|
| **64 Frames** | | | | | | | | |
| - | - | 1 | 62.41 | 59.69 | 40.09 | 63.86 | 56.51 | 100.00% |
| 16 | Value Norms | 0.5 | 47.63 | 35.98 | 30.92 | 31.38 | 36.48 | 64.55% |
| 16 | Attention Scores | 0.5 | 58.63 | 52.95 | 37.83 | 59.87 | 52.32 | 92.58% |
| 16 | Key Norms (↓) | 0.5 | 60.56 | 56.17 | 37.70 | 62.34 | 54.19 | 95.90% |
| **128 Frames** | | | | | | | | |
| - | - | 1 | 66.41 | 60.96 | 42.87 | 66.86 | 59.27 | 100.00% |
| 16 | Value Norms | 0.5 | 48.56 | 37.32 | 30.73 | 38.51 | 38.78 | 65.42% |
| 16 | Attention Scores | 0.5 | 60.96 | 55.20 | 39.70 | 64.36 | 55.06 | 92.89% |
| 16 | Key Norms (↓) | 0.5 | 63.41 | 58.19 | 39.57 | 64.99 | 56.54 | 95.39% |
| **256 Frames** | | | | | | | | |
| - | - | 1 | 65.78 | 61.56 | 43.90 | 68.65 | 59.97 | 100.00% |
| 16 | Value Norms | 0.5 | 48.33 | 38.89 | 31.38 | 37.74 | 39.08 | 65.17% |
| 16 | Attention Scores | 0.5 | 62.52 | 57.22 | 41.96 | 67.27 | 57.24 | 95.45% |
| 16 | Key Norms (↓) | 0.5 | 64.04 | 60.21 | 41.90 | 66.73 | 58.22 | 97.08% |
| **1024 Frames** | | | | | | | | |
| - | - | 1 | 62.00 | 60.43 | 42.29 | 63.48 | 57.05 | 100.00% |
| 16 | Value Norms | 0.5 | 47.37 | 33.66 | 29.18 | 32.65 | 35.71 | 62.60% |
| 16 | Attention Scores | 0.5 | 62.22 | 58.49 | 42.03 | 64.45 | 56.80 | 99.56% |
| 16 | Key Norms | 0.5 | 59.99 | 61.59 | 40.80 | 64.76 | 56.78 | 99.53% |

</details>
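For reference, the Performance column appears to be the pruned Avg divided by the full-prefill Avg at the same frame budget; a quick check against the 64-frame rows is shown below (±0.01% discrepancies can come from rounding of the displayed averages):

```python
# Sanity check of the "Performance" column using the 64-frame rows above.
full_avg = 56.51          # no pruning (ρ = 1)
value_norms_avg = 36.48   # ρ = 0.5, value-norm pruning
key_norms_avg = 54.19     # ρ = 0.5, key-norm pruning

print(f"{value_norms_avg / full_avg:.2%}")  # 64.55% (matches the table)
print(f"{key_norms_avg / full_avg:.2%}")    # 95.89% (table: 95.90%, rounding)
```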
302 | 303 | ## 🛠️ Installation 304 | 305 | ```bash 306 | # Set up the environment (run from the root of the cloned QuickVideo repository) 307 | uv sync 308 | source .venv/bin/activate 309 | uv pip install -e . 310 | uv pip install flash-attn==2.7.3 --no-build-isolation 311 | ``` 312 | 313 | **Important** 314 | Please use `transformers==4.50.0`; this is the version we have tested. Newer releases (e.g. `transformers==4.52.4`) may not work because they changed the Qwen VL model source code. We plan to support the latest version in the future. 315 | 316 | ## 🎮 Quick Start 317 | 318 | ### 1. Download Example Video 319 | ```bash 320 | wget https://github.com/TIGER-AI-Lab/QuickVideo/raw/refs/heads/dev/video/Q8AZ16uBhr8_resized_fps2_mute.mp4 321 | video_path="Q8AZ16uBhr8_resized_fps2_mute.mp4" 322 | ``` 323 | 324 | ### 2. Run QuickVideo (Recommended) 325 | **With interleaved processing + KV pruning** - ⚡ **Fastest configuration** 326 | 327 | ```python 328 | from lvu import LVU, LVUConfig 329 | 330 | # Configure QuickVideo with all optimizations 331 | config = LVUConfig( 332 | model_name_or_path="Qwen/Qwen2.5-VL-7B-Instruct", 333 | model_type="qwen25_lvu_interleaved", # Enable interleaved processing 334 | top_k_predict_type="key_norms_small", # Use key norm pruning 335 | video_group_size=16, # Process 16 frames per group 336 | top_k=64, # Keep 64 most important tokens per group 337 | num_frames=1024, # Process up to 1024 frames 338 | use_tqdm=True, 339 | ) 340 | 341 | lvu = LVU(config) 342 | question = "Describe this video." 343 | video_path = "Q8AZ16uBhr8_resized_fps2_mute.mp4" 344 | 345 | # Generate response 346 | output = lvu.generate(question, video_path, max_new_tokens=128, do_sample=False) 347 | print(output) 348 | ``` 349 | 350 | **Expected Output:** 351 | ``` 352 | ⏱️ Performance Metrics: 353 | • Frame fetching: 0.33s 354 | • Processing: 10.44s 355 | • Prefill: 22.95s 356 | • End-to-end: 27.65s (vs 57.86s baseline) 357 | • Time saved: 10.57s ⚡ 358 | 359 | 🎬 Generated Response: 360 | ['The video is a compilation of classic animated shorts featuring iconic characters from the 1940s and 1950s, showcasing slapstick humor and vibrant animation styles typical of that era. The clips include:\n\n1. **"A Bug\'s Life"**: A rabbit character is seen in a desert setting, engaging in a comedic chase sequence with a carrot. The rabbit exhibits exaggerated expressions and movements, typical of the cartoon\'s slapstick style.\n\n2. **"The Wabbit Who Could"**: Bugs Bunny appears in a whimsical scene where he is performing a magic trick involving a carrot. The animation is colorful and lively'] 361 | "The video is a compilation of classic animated shorts featuring iconic 362 | characters from the 1940s and 1950s, showcasing slapstick humor and 363 | vibrant animation styles typical of that era..." 364 | ``` 365 | 366 | **Important**: We recommend running the interleaved version with **at least 2 CPU cores**; otherwise the interleaving strategy will do no better than standard sequential processing. If you see no improvement with interleaved processing, please check the number of CPU cores available on your machine. 367 | 368 | ### 3.
Baseline Comparison 369 | **Without interleaved processing** - 🐌 **Slower but still optimized** 370 | 371 | ```python 372 | config = LVUConfig( 373 | model_name_or_path="Qwen/Qwen2.5-VL-7B-Instruct", 374 | model_type="qwen25_lvu", # Standard processing 375 | video_group_size=16, 376 | top_k=64, 377 | num_frames=1024, 378 | use_tqdm=True, 379 | ) 380 | # Same usage as above - notice the 2x slower processing time 381 | ``` 382 | 383 | ## 🔬 Benchmark Evaluation 384 | 385 | Evaluate QuickVideo performance on standard video understanding benchmarks: 386 | 387 | ```bash 388 | # Setup evaluation environment 389 | git submodule update --init --recursive 390 | cd lmms-eval 391 | uv pip install -e . 392 | 393 | # Configure environment 394 | export QUICKCODEC_CORES=8 395 | export FORCE_QWENVL_VIDEO_READER='deepcodec' 396 | ``` 397 | 398 | **Run comprehensive evaluation:** 399 | 400 | ```bash 401 | # Example evaluation script 402 | num_frame=1024 403 | benchmark_name="videomme,longvideobench_val_v,lvbench,mlvu_dev" 404 | 405 | accelerate launch --num_processes 8 --main_process_port 12351 -m lmms_eval \ 406 | --model qwen2_5_vl \ 407 | --model_args "pretrained=Qwen/Qwen2.5-VL-7B-Instruct,max_num_frames=$num_frame,use_flash_attention_2=True,adaptive_local_attention=True,local_attention_group_size=16,top_k=64,predict_type=key_norms_small" \ 408 | --tasks $benchmark_name \ 409 | --batch_size 1 \ 410 | --log_samples \ 411 | --output_path ./logs/quickvideo_evaluation 412 | ``` 413 | 414 | ## 🧪 Advanced Configuration 415 | 416 |
417 | QuickCodec Configuration 418 | 419 | | Environment Variable | Description | Default | Options | 420 | |-----------|-------------|---------|---------| 421 | | `QUICKCODEC_CORES` | CPU cores used for video decoding. | `8` | `2-128` | 422 | | `QUICKCODEC_INTERVALS` | Number of video segments to queue for loading. | `64` | `Any` | 423 | 424 | - Environment variables can be changed at runtime to support different settings for different videos (see the sketch after this list). 425 | - The more cores you can use, the better! Ideally, several cores should be reserved for video decoding. 426 | - `QUICKCODEC_INTERVALS` is used for our overlapped prefill (see the paper for details). Intervals should be *at least* one keyframe apart. 427 | 428 |
429 | 430 | 431 |
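As a small illustration of these knobs (a sketch, not repository code; it only sets the environment variables documented above, which the video readers look up when a video is loaded):

```python
import os

# Reserve plenty of cores for decoding and queue 64 intervals for the
# overlapped prefill. QUICKCODEC_CORES is read when a video is opened, so
# the values can be adjusted between videos within the same process.
os.environ["QUICKCODEC_CORES"] = "16"
os.environ["QUICKCODEC_INTERVALS"] = "64"

from lvu import LVU, LVUConfig  # setting the variables before use keeps things unambiguous
```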
432 | QuickPrefill Configuration 433 | 434 | | Parameter | Description | Default | Options | 435 | |-----------|-------------|---------|---------| 436 | | `model_type` | Processing mode | `qwen25_lvu` | `qwen25_lvu`, `qwen25_lvu_interleaved` | 437 | | `video_group_size` | Frames per processing group | `16` | `8`, `16`, `32`, ... | 438 | | `top_k` | Tokens to keep per group | `64` | Any positive integer | 439 | | `top_k_predict_type` | Pruning strategy | `key_norms_small` | `key_norms_small`, `attention_scores`, `value_norms` | 440 | | `num_frames` | Maximum frames to process | `1024` | `64`, `128`, `256`, `1024`, ... | 441 | | `top_p` | Percentage-based pruning | `None` | `0.0` to `1.0` | 442 | 443 |
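For example, a ratio-based configuration might look like the sketch below, built from the parameters in the table above; note that `top_p=0.5` is assumed to keep half of the tokens per group (its exact interaction with `top_k` is not documented here, so verify against `LVUConfig`):

```python
from lvu import LVU, LVUConfig

# Ratio-based pruning: keep a fraction of the KV tokens in each 16-frame group
# (corresponding to ρ = 0.5 in the results table) instead of a fixed top_k.
config = LVUConfig(
    model_name_or_path="Qwen/Qwen2.5-VL-7B-Instruct",
    model_type="qwen25_lvu_interleaved",
    top_k_predict_type="key_norms_small",
    video_group_size=16,
    top_p=0.5,            # assumption: fraction of tokens kept per group
    num_frames=256,
    use_tqdm=True,
)

lvu = LVU(config)
output = lvu.generate("Describe this video.", "Q8AZ16uBhr8_resized_fps2_mute.mp4",
                      max_new_tokens=128, do_sample=False)
print(output)
```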
444 | 445 | 446 | ## 🤝 Contributing 447 | 448 | We welcome contributions! To add new models or KV pruning methods: 449 | 450 | 1. **Fork the repository** 451 | 2. **Create a feature branch**: `git checkout -b feature/new-model` 452 | 3. **Implement your changes** following our coding standards 453 | 4. **Add tests** and documentation 454 | 5. **Submit a pull request** 455 | 456 | See our [contribution guidelines](CONTRIBUTING.md) for detailed instructions. (under construction) 457 | 458 | ## 📜 Citation 459 | 460 | If you find QuickVideo useful in your research, please cite our paper: 461 | 462 | ```bibtex 463 | @inproceedings{Schneider2025QuickVideoRL, 464 | title={QuickVideo: Real-Time Long Video Understanding with System Algorithm Co-Design}, 465 | author={Benjamin Schneider and Dongfu Jiang and Chao Du and Tianyu Pang and Wenhu Chen}, 466 | year={2025}, 467 | url={https://api.semanticscholar.org/CorpusID:278789043} 468 | } 469 | ``` 470 | 471 | ## 📄 License 472 | 473 | This project is licensed under the MIT License - see the [LICENSE](LICENSE) file for details. 474 | 475 | ## Star History 476 | 477 | [![Star History Chart](https://api.star-history.com/svg?repos=TIGER-AI-Lab/QuickVideo&type=Date)](https://www.star-history.com/#TIGER-AI-Lab/QuickVideo&Date) 478 | 479 | --- 480 | 481 |

482 | Made with ❤️ by the TIGER AI Lab team 483 |

484 | -------------------------------------------------------------------------------- /lvu/models/qwen25_vl.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import dataclasses 4 | import time 5 | import hashlib 6 | from pathlib import Path 7 | from tqdm import tqdm 8 | from typing import Optional, Tuple 9 | from transformers import Qwen2_5_VLForConditionalGeneration, Qwen2_5_VLModel 10 | from transformers.feature_extraction_utils import BatchFeature 11 | from transformers.cache_utils import Cache 12 | from qwen_vl_utils import process_vision_info, extract_vision_info 13 | from ..utils import post_process_kv_cache 14 | from ..lvu_config import LVUConfig, LVULayerConfig 15 | from ..lvu_cache import LVUCache, save_ndarray_as_image, load_image_as_ndarray 16 | from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import ( 17 | apply_multimodal_rotary_pos_emb, 18 | repeat_kv, 19 | _flash_attention_forward, 20 | ) 21 | 22 | def lvu_qwen25_vl_flash_attention_2_forward( 23 | self, 24 | hidden_states: torch.Tensor, 25 | attention_mask: Optional[torch.Tensor] = None, 26 | position_ids: Optional[torch.LongTensor] = None, 27 | past_key_value: Optional[Cache] = None, 28 | output_attentions: bool = False, 29 | use_cache: bool = False, 30 | cache_position: Optional[torch.LongTensor] = None, 31 | position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC 32 | ): 33 | bsz, q_len, _ = hidden_states.size() 34 | 35 | query_states = self.q_proj(hidden_states) 36 | key_states = self.k_proj(hidden_states) 37 | value_states = self.v_proj(hidden_states) 38 | 39 | query_states = query_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) 40 | key_states = key_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) 41 | value_states = value_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) 42 | 43 | # Because the input can be padded, the absolute sequence length depends on the max position id. 44 | cos, sin = position_embeddings 45 | query_states, key_states = apply_multimodal_rotary_pos_emb( 46 | query_states, key_states, cos, sin, self.rope_scaling["mrope_section"] 47 | ) 48 | 49 | if past_key_value is not None: 50 | cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position, "query_states": query_states} # Specific to RoPE models 51 | key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) 52 | 53 | # repeat k/v heads if n_kv_heads < n_heads 54 | key_states = repeat_kv(key_states, self.num_key_value_groups) 55 | value_states = repeat_kv(value_states, self.num_key_value_groups) 56 | dropout_rate = 0.0 if not self.training else self.attention_dropout 57 | 58 | # In PEFT, usually we cast the layer norms in float32 for training stability reasons 59 | # therefore the input hidden states gets silently casted in float32. Hence, we need 60 | # cast them back in float16 just to be sure everything works as expected. 
61 | input_dtype = query_states.dtype 62 | if input_dtype == torch.float32: 63 | if torch.is_autocast_enabled(): 64 | target_dtype = torch.get_autocast_gpu_dtype() 65 | # Handle the case where the model is quantized 66 | elif hasattr(self.config, "_pre_quantization_dtype"): 67 | target_dtype = self.config._pre_quantization_dtype 68 | else: 69 | target_dtype = self.q_proj.weight.dtype 70 | 71 | logger.warning_once( 72 | f"The input hidden states seems to be silently casted in float32, this might be related to" 73 | f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in" 74 | f" {target_dtype}." 75 | ) 76 | 77 | query_states = query_states.to(target_dtype) 78 | key_states = key_states.to(target_dtype) 79 | value_states = value_states.to(target_dtype) 80 | 81 | # Reashape to the expected shape for Flash Attention 82 | query_states = query_states.transpose(1, 2) 83 | key_states = key_states.transpose(1, 2) 84 | value_states = value_states.transpose(1, 2) 85 | 86 | if ( 87 | self.config.use_sliding_window 88 | and getattr(self.config, "sliding_window", None) is not None 89 | and self.layer_idx >= self.config.max_window_layers 90 | ): 91 | sliding_window = self.config.sliding_window 92 | else: 93 | sliding_window = None 94 | 95 | attn_output = _flash_attention_forward( 96 | query_states, 97 | key_states, 98 | value_states, 99 | attention_mask, 100 | q_len, 101 | dropout=dropout_rate, 102 | sliding_window=sliding_window, 103 | is_causal=self.is_causal, 104 | use_top_left_mask=self._flash_attn_uses_top_left_mask, 105 | ) 106 | 107 | attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous() 108 | attn_output = self.o_proj(attn_output) 109 | 110 | if not output_attentions: 111 | attn_weights = None 112 | 113 | return attn_output, attn_weights, past_key_value 114 | 115 | def lvu_qwen25_vl_decoder_layer_forward( 116 | self, 117 | hidden_states: torch.Tensor, 118 | attention_mask: Optional[torch.Tensor] = None, 119 | position_ids: Optional[torch.LongTensor] = None, 120 | past_key_value: Optional[Tuple[torch.Tensor]] = None, 121 | output_attentions: Optional[bool] = False, 122 | use_cache: Optional[bool] = False, 123 | cache_position: Optional[torch.LongTensor] = None, 124 | position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC 125 | **kwargs, 126 | ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: 127 | """ 128 | Args: 129 | hidden_states (`torch.FloatTensor`): 130 | - input to the layer of shape `(batch, seq_len, embed_dim)` 131 | - or a tuple of `(hidden_states, attention_mask, position_ids, cache_position, position_embeddings)` 132 | meaning that the previous layer has prune the hidden states to topk 133 | attention_mask (`torch.FloatTensor`, *optional*): attention mask of size 134 | `(batch, sequence_length)` where padding elements are indicated by 0. 135 | output_attentions (`bool`, *optional*): 136 | Whether or not to return the attentions tensors of all attention layers. See `attentions` under 137 | returned tensors for more detail. 138 | use_cache (`bool`, *optional*): 139 | If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding 140 | (see `past_key_values`). 
141 | past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states 142 | cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): 143 | Indices depicting the position of the input sequence tokens in the sequence. 144 | position_embeddings (`Tuple[torch.FloatTensor, torch.FloatTensor]`, *optional*): 145 | Tuple containing the cosine and sine positional embeddings of shape `(batch_size, seq_len, head_dim)`, 146 | with `head_dim` being the embedding dimension of each attention head. 147 | kwargs (`dict`, *optional*): 148 | Arbitrary kwargs to be ignored, used for FSDP and other methods that injects code 149 | into the model 150 | """ 151 | lvu_layer_config = getattr(self, "lvu_layer_config", None) 152 | lvu_config = getattr(lvu_layer_config, "lvu_config", None) 153 | if lvu_config is None: 154 | raise ValueError("LVUConfig is not set in the model. Please initialize the LVU model first.") 155 | 156 | if isinstance(hidden_states, tuple): 157 | # this means that previous layer has prune the hidden states to topk 158 | hidden_states, attention_mask, position_ids, cache_position, position_embeddings = hidden_states 159 | 160 | residual = hidden_states 161 | 162 | hidden_states = self.input_layernorm(hidden_states) 163 | 164 | # Self Attention 165 | hidden_states, self_attn_weights, present_key_value = self.self_attn( 166 | hidden_states=hidden_states, 167 | attention_mask=attention_mask, 168 | position_ids=position_ids, 169 | past_key_value=past_key_value, 170 | output_attentions=output_attentions, 171 | use_cache=use_cache, 172 | cache_position=cache_position, 173 | position_embeddings=position_embeddings, 174 | ) 175 | hidden_states = residual.to(hidden_states.device) + hidden_states 176 | hidden_states, attention_mask, position_ids, cache_position, position_embeddings, present_key_value = post_process_kv_cache( 177 | hidden_states, 178 | attention_mask, 179 | position_ids=position_ids, 180 | cache_position=cache_position, 181 | position_embeddings=position_embeddings, 182 | attn_weights=self_attn_weights, 183 | present_key_value=present_key_value, 184 | lvu_layer_config=lvu_layer_config, 185 | ) 186 | 187 | # Fully Connected 188 | residual = hidden_states 189 | hidden_states = self.post_attention_layernorm(hidden_states) 190 | hidden_states = self.mlp(hidden_states) 191 | hidden_states = residual + hidden_states 192 | 193 | if lvu_config.enable and lvu_layer_config.prune_for_next_layer and not lvu_layer_config.is_last_layer: 194 | # pass all the pruned information to next layer. If the last layer, we don't need to save other information except hidden_states 195 | hidden_states = (hidden_states, attention_mask, position_ids, cache_position, position_embeddings) 196 | 197 | outputs = (hidden_states,) 198 | 199 | if output_attentions: 200 | outputs += (self_attn_weights,) 201 | 202 | if use_cache: 203 | outputs += (present_key_value,) 204 | 205 | return outputs 206 | 207 | import qwen_vl_utils.vision_process 208 | from qwen_vl_utils.vision_process import * 209 | import sys 210 | FPS_MAX_FRAMES = 100_000 # 768 = 256 * 3 211 | def smart_nframes( 212 | ele: dict, 213 | total_frames: int, 214 | video_fps: int | float, 215 | ) -> int: 216 | """calculate the number of frames for video used for model inputs. 217 | 218 | Args: 219 | ele (dict): a dict contains the configuration of video. 220 | support either `fps` or `nframes`: 221 | - nframes: the number of frames to extract for model inputs. 
222 | - fps: the fps to extract frames for model inputs. 223 | - min_frames: the minimum number of frames of the video, only used when fps is provided. 224 | - max_frames: the maximum number of frames of the video, only used when fps is provided. 225 | total_frames (int): the original total number of frames of the video. 226 | video_fps (int | float): the original fps of the video. 227 | 228 | Raises: 229 | ValueError: nframes should in interval [FRAME_FACTOR, total_frames]. 230 | 231 | Returns: 232 | int: the number of frames for video used for model inputs. 233 | """ 234 | assert not ("fps" in ele and "nframes" in ele), "Only accept either `fps` or `nframes`" 235 | if "nframes" in ele: 236 | nframes = round_by_factor(ele["nframes"], FRAME_FACTOR) 237 | nframes = min(nframes, total_frames) 238 | nframes -= (nframes % FRAME_FACTOR) 239 | else: 240 | fps = ele.get("fps", FPS) 241 | min_frames = ceil_by_factor(ele.get("min_frames", FPS_MIN_FRAMES), FRAME_FACTOR) 242 | max_frames = floor_by_factor(ele.get("max_frames", min(FPS_MAX_FRAMES, total_frames)), FRAME_FACTOR) 243 | nframes = total_frames / video_fps * fps 244 | if nframes > total_frames: 245 | logger.warning(f"smart_nframes: nframes[{nframes}] > total_frames[{total_frames}]") 246 | nframes = min(min(max(nframes, min_frames), max_frames), total_frames) 247 | nframes = floor_by_factor(nframes, FRAME_FACTOR) 248 | if not (FRAME_FACTOR <= nframes and nframes <= total_frames): 249 | raise ValueError(f"nframes should in interval [{FRAME_FACTOR}, {total_frames}], but got {nframes}.") 250 | return nframes 251 | #sys.modules["qwen_vl_utils.vision_process"].smart_nframes = smart_nframes 252 | 253 | import inspect 254 | def _get_initial_cache_position(self, *args, **kwargs): 255 | # Get the function signature 256 | sig = inspect.signature(self.old_get_initial_cache_position) 257 | 258 | # Get parameter names (excluding 'self') 259 | param_names = [param.name for param in sig.parameters.values() 260 | if param.name not in ('self', 'args', 'kwargs')] 261 | 262 | # Transform *args to **kwargs using parameter names 263 | args_as_kwargs = dict(zip(param_names, args)) 264 | 265 | # Combine with existing kwargs 266 | all_kwargs = {**args_as_kwargs, **kwargs} 267 | # Find model_kwargs in the mapped arguments 268 | model_kwargs = all_kwargs.get('model_kwargs') 269 | 270 | # Early return if cache_position already exists in model_kwargs 271 | if model_kwargs is not None and "cache_position" in model_kwargs: 272 | return model_kwargs 273 | return self.old_get_initial_cache_position(*args, **kwargs) 274 | 275 | 276 | def init_lvu_model(model, config: LVUConfig): 277 | """ 278 | Initialize the LVU model for Qwen 2.5 VL. 279 | - replace the decoder layer forward function with the LVU version 280 | Args: 281 | model: The model to be initialized. 282 | config: The configuration for the LVU model. 
283 | """ 284 | _model = model 285 | if isinstance(_model, Qwen2_5_VLForConditionalGeneration): 286 | _model = _model.model 287 | if not hasattr(model, "get_rope_index"): 288 | # for transformers > 4.50.0 289 | model.get_rope_index = _model.get_rope_index 290 | if isinstance(_model, Qwen2_5_VLModel): 291 | if hasattr(_model, "layers"): 292 | _model = _model 293 | elif hasattr(_model, "language_model"): 294 | _model = _model.language_model 295 | else: 296 | raise ValueError("Qwen2_5_VLModel must have either `model` or `language_model` attribute.") 297 | try: 298 | decoder_layers = _model.layers 299 | except AttributeError: 300 | raise ValueError("Did not find `layers` attribute in the model. Please check your qwen2.5_vl source code and transformers version.") 301 | 302 | 303 | total_layers= len(decoder_layers) 304 | for i, layer in enumerate(decoder_layers): 305 | # Set the forward function for each decoder layer and filling the parameters in the config 306 | layer.forward = lvu_qwen25_vl_decoder_layer_forward.__get__(layer) 307 | layer.self_attn.forward = lvu_qwen25_vl_flash_attention_2_forward.__get__(layer.self_attn) 308 | layer.lvu_layer_config = LVULayerConfig(layer_idx=layer.self_attn.layer_idx, total_layers=total_layers, lvu_config=config) 309 | model.old_get_initial_cache_position = model._get_initial_cache_position 310 | model._get_initial_cache_position = _get_initial_cache_position.__get__(model) 311 | 312 | return model 313 | 314 | def run_lvu_model(self, question, video_path, **generation_kwargs): 315 | lvu_config = self.config 316 | fps = lvu_config.fps 317 | num_frames = lvu_config.num_frames 318 | extra_kwargs = lvu_config.extra_kwargs or {} 319 | max_pixels = extra_kwargs.get("max_pixels", 360 * 420) 320 | min_pixels = extra_kwargs.get("min_pixels", None) 321 | 322 | video_content = { 323 | "type": "video", 324 | "video": video_path, 325 | } 326 | if max_pixels is not None: 327 | video_content["max_pixels"] = max_pixels 328 | if min_pixels is not None: 329 | video_content["min_pixels"] = min_pixels 330 | if fps is not None: 331 | video_content["fps"] = fps 332 | elif num_frames is not None: 333 | video_content["nframes"] = num_frames 334 | else: 335 | raise ValueError("Either fps or num_frames should be set.") 336 | # Messages containing a local video path and a text query 337 | messages = [ 338 | { 339 | "role": "user", 340 | "content": [ 341 | video_content, 342 | {"type": "text", "text": question} 343 | ], 344 | } 345 | ] 346 | return chat_lvu_model(self, messages, **generation_kwargs) 347 | 348 | def chat_lvu_model(self, messages, **generation_kwargs): 349 | model = self.model 350 | processor = self.processor 351 | lvu_config = self.config 352 | 353 | # Process the messages 354 | #In Qwen 2.5 VL, frame rate information is also input into the model to align with absolute time. 355 | # Preparation for inference 356 | text = processor.apply_chat_template( 357 | messages, tokenize=False, add_generation_prompt=True 358 | ) 359 | 360 | start = time.time() 361 | cache_dir = lvu_config.cache_dir or "~/.cache/video_cache/qwen25_vl" 362 | vision_info = extract_vision_info(messages) 363 | assert len(vision_info) == 1, "Only one video is supported for now." 
364 | video_path = Path(vision_info[0]["video"]) 365 | cache_key = video_path.stem 366 | for k, v in vision_info[0].items(): 367 | if k not in ['type', 'video']: 368 | cache_key += f"_{k}={v}" 369 | # cache_key = video_path.stem + "_" + hashlib.md5(str(vision_info).encode()).hexdigest() 370 | cache_dir = Path(cache_dir).expanduser() 371 | cache_file = cache_dir / f"{cache_key}.pt" 372 | cache_dir.mkdir(parents=True, exist_ok=True) 373 | if not cache_file.exists(): 374 | image_inputs, video_inputs, video_kwargs = process_vision_info(messages, return_video_kwargs=True) 375 | # save to cache 376 | if lvu_config.save_video_cache: 377 | cache_images_folder = cache_file.parent / f"{cache_key}_images" 378 | cache_images_folder.mkdir(parents=True, exist_ok=True) 379 | total_size = 0 380 | for i, image in enumerate(video_inputs[0]): 381 | # image is a numpy array of (C, H, W) 382 | save_ndarray_as_image(image.numpy(), cache_images_folder / f"{i:04d}.jpg") 383 | total_size += os.path.getsize(cache_images_folder / f"{i:04d}.jpg") 384 | torch.save({ 385 | "image_inputs": image_inputs, 386 | "video_kwargs": video_kwargs, 387 | }, cache_file) 388 | total_size_gb = total_size / (1024 ** 3) 389 | print(f"Saved video cache to {cache_file} ({total_size_gb:.2f} GB)") 390 | else: 391 | print(f"Cache file {cache_file} found. Loading video frames...") 392 | results = torch.load(cache_file) 393 | image_inputs = results["image_inputs"] 394 | video_kwargs = results["video_kwargs"] 395 | video_inputs = [] 396 | cache_images_folder = cache_file.parent / f"{cache_key}_images" 397 | all_images = sorted(cache_images_folder.glob("*.jpg")) 398 | for i in range(len(all_images)): 399 | image = torch.tensor(load_image_as_ndarray(cache_images_folder / f"{i:04d}.jpg")) 400 | video_inputs.append(image) 401 | video_inputs = [torch.stack(video_inputs)] 402 | end = time.time() 403 | print(f"Preprocessing time for video: {end - start:.2f}s") 404 | 405 | whole_inputs = processor( 406 | text=text, 407 | images=image_inputs, 408 | videos=video_inputs, 409 | padding=True, 410 | return_tensors="pt", 411 | **video_kwargs, 412 | ) 413 | whole_inputs = whole_inputs.to(model.device) 414 | n_video_tokens = (whole_inputs['input_ids'] == model.config.video_token_id).sum().item() 415 | video_token_idxs = (whole_inputs['input_ids'] == model.config.video_token_id).nonzero(as_tuple=True)[1] 416 | first_video_token_id_idx = video_token_idxs[0].item() 417 | last_video_token_id_idx = video_token_idxs[-1].item() 418 | position_ids, rope_deltas = model.get_rope_index( 419 | whole_inputs['input_ids'], 420 | whole_inputs.get('image_grid_thw', None), 421 | whole_inputs.get('video_grid_thw', None), 422 | whole_inputs.get('second_per_grid_ts', None), 423 | whole_inputs['attention_mask'], 424 | ) 425 | model.rope_deltas = rope_deltas 426 | 427 | assert len(video_inputs) <= 1, "Only one video is supported for now." 428 | video_group_size = lvu_config.video_group_size 429 | temporal_patch_size = processor.image_processor.temporal_patch_size 430 | if not video_group_size % temporal_patch_size == 0: 431 | video_group_size += temporal_patch_size - (video_group_size % temporal_patch_size) 432 | if video_group_size is not None and video_group_size > 0: 433 | video_groups = video_inputs[0].split(video_group_size) 434 | assert all(len(group) % 2 == 0 for group in video_groups), "The video group size should be even." 
435 | video_groups_tokens = [int(n_video_tokens * (len(group) / len(video_inputs[0]))) for group in video_groups] 436 | video_grid_thw = whole_inputs['video_grid_thw'][0] 437 | video_groups_grid_thw = [] 438 | for group in video_groups: 439 | video_groups_grid_thw.append( 440 | torch.tensor( 441 | [(len(group) -1 ) // temporal_patch_size + 1, 442 | video_grid_thw[1], 443 | video_grid_thw[2]] 444 | ).unsqueeze(0) 445 | ) 446 | pixel_values_videos_group_size = round((video_group_size / len(video_inputs[0])) * whole_inputs['pixel_values_videos'].shape[0]) 447 | pixel_values_videos_groups = whole_inputs['pixel_values_videos'].split(pixel_values_videos_group_size) 448 | else: 449 | video_groups = [video_inputs[0]] 450 | video_groups_tokens = [n_video_tokens] 451 | video_groups_grid_thw = [whole_inputs['video_grid_thw']] 452 | pixel_values_videos_groups = [whole_inputs['pixel_values_videos']] 453 | 454 | # print("Sampled video frames: ", len(video_inputs[0])) 455 | # print("Video groups: ", [len(group) for group in video_groups]) 456 | # print("Video groups tokens: ", video_groups_tokens) 457 | print("Video groups grid thw: ", video_groups_grid_thw) 458 | # print("Pixel values videos groups: ", [group.shape for group in pixel_values_videos_groups]) 459 | 460 | # if any([group.shape[0] % 4 != 0 for group in pixel_values_videos_groups]): 461 | # print("Warning: The number of frames in each video group should be divisible by 4. Please check the video group size.") 462 | # return "" 463 | 464 | # preprepare the chunk processing 465 | past_key_values = LVUCache() 466 | past_len = 0 467 | video_token_idxs = (whole_inputs['input_ids'] == model.config.video_token_id).nonzero(as_tuple=True)[1] 468 | first_video_token_id_idx = video_token_idxs[0].item() 469 | last_video_token_id_idx = video_token_idxs[-1].item() 470 | prompt_input_ids = whole_inputs['input_ids'][:, last_video_token_id_idx + 1:] 471 | prompt_attention_mask = whole_inputs['attention_mask'][:, last_video_token_id_idx + 1:] 472 | if lvu_config.query_based: 473 | past_key_values.set_prompt_length(prompt_input_ids.shape[1]) 474 | video_groups_tokens[0] += first_video_token_id_idx # add the tokens before the first video group as well 475 | 476 | # start processing the video groups 477 | for i, pixel_values_videos_groups_i in tqdm(enumerate(pixel_values_videos_groups), 478 | desc="Processing video groups", total=len(pixel_values_videos_groups), disable=not lvu_config.use_tqdm): 479 | group_i_inputs = { 480 | "video_grid_thw": video_groups_grid_thw[i], 481 | "second_per_grid_ts": whole_inputs['second_per_grid_ts'], 482 | "pixel_values_videos": pixel_values_videos_groups_i, 483 | } 484 | group_i_inputs = BatchFeature(data=group_i_inputs) 485 | group_i_inputs['input_ids'] = whole_inputs['input_ids'][:, past_len:past_len + video_groups_tokens[i]] 486 | group_i_inputs['attention_mask'] = whole_inputs['attention_mask'][:, past_len:past_len + video_groups_tokens[i]] 487 | if lvu_config.query_based: 488 | group_i_inputs['input_ids'] = torch.cat((group_i_inputs['input_ids'], prompt_input_ids), dim=1) 489 | group_i_inputs['attention_mask'] = torch.cat((group_i_inputs['attention_mask'], prompt_attention_mask), dim=1) 490 | 491 | group_i_inputs['cache_position'] = torch.arange(group_i_inputs['input_ids'].shape[1], dtype=torch.int64, device=model.device) + past_len 492 | group_i_inputs['position_ids'] = position_ids[:, :, past_len:past_len + group_i_inputs['input_ids'].shape[1]] 493 | past_len += video_groups_tokens[i] # only the video group tokens are 
counted, prompt tokens are not counted 494 | group_i_inputs = group_i_inputs.to(model.device) 495 | group_i_inputs['use_cache'] = True 496 | if lvu_config.adaptive_local_attention: 497 | group_i_inputs['past_key_values'] = past_key_values 498 | with torch.no_grad(): 499 | outputs = model(**group_i_inputs) 500 | # later video groups will use the past key values 501 | past_key_values = outputs.past_key_values 502 | else: 503 | with torch.no_grad(): 504 | outputs = model(**group_i_inputs) 505 | if not past_key_values: 506 | # first time parsing, the video grid information is not correct 507 | past_key_values = outputs.past_key_values 508 | else: 509 | # update the past key values 510 | if isinstance(outputs.past_key_values, Cache): 511 | for i in range(len(outputs.past_key_values)): 512 | past_key_values.update(outputs.past_key_values[i][0], outputs.past_key_values[i][1], i) 513 | else: 514 | for i in range(len(outputs.past_key_values)): 515 | for j in range(len(outputs.past_key_values[i])): 516 | past_key_values[i][j] = torch.cat((past_key_values[i][j], outputs.past_key_values[i][j]), dim=2) 517 | # print(f"past_key_values shape: {past_key_values[0][0].shape}") 518 | assert past_len < whole_inputs['input_ids'].shape[1], "The past length should be less than the final input length." 519 | if lvu_config.query_based: 520 | # reset prompt length as all video groups are processed 521 | past_key_values.set_prompt_length(0) 522 | # end of processing the video groups 523 | 524 | final_inputs = { 525 | "input_ids": whole_inputs['input_ids'][:, past_len:], 526 | "attention_mask": whole_inputs['attention_mask'][:, past_len:], 527 | } 528 | final_inputs = BatchFeature(data=final_inputs) 529 | final_inputs['cache_position'] = torch.arange(final_inputs.input_ids.shape[1], dtype=torch.int64, device=model.device) + past_len 530 | final_inputs['position_ids'] = position_ids[:, :, past_len:] 531 | assert final_inputs['input_ids'].shape[1] == final_inputs['position_ids'].shape[2], "The input ids and position ids should have the same length, but got {} and {}".format( 532 | final_inputs['input_ids'].shape[1], final_inputs['position_ids'].shape[2]) 533 | final_inputs = final_inputs.to(model.device) 534 | final_inputs['past_key_values'] = past_key_values 535 | final_inputs['use_cache'] = True 536 | 537 | cache_enable = lvu_config.enable 538 | lvu_config.enable = lvu_config.do_top_k_for_query # determine whether to do topk or not 539 | generated_ids = model.generate(**final_inputs, **generation_kwargs) 540 | lvu_config.enable = cache_enable 541 | 542 | generated_ids_trimmed = [ 543 | out_ids[len(in_ids) :] for in_ids, out_ids in zip(final_inputs.input_ids, generated_ids) 544 | ] 545 | output_text = processor.batch_decode( 546 | generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False 547 | ) 548 | return output_text 549 | 550 | -------------------------------------------------------------------------------- /lvu/models/qwen25_lvu.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import dataclasses 4 | import time 5 | import sys 6 | import hashlib 7 | from PIL import Image 8 | from pathlib import Path 9 | from tqdm import tqdm 10 | from typing import Optional, Tuple 11 | from transformers import Qwen2_5_VLForConditionalGeneration, Qwen2_5_VLModel 12 | from transformers.feature_extraction_utils import BatchFeature 13 | from transformers.cache_utils import Cache 14 | from qwen_vl_utils import process_vision_info, 
extract_vision_info 15 | from ..utils import post_process_kv_cache 16 | from ..lvu_config import LVUConfig, LVULayerConfig 17 | from ..lvu_cache import LVUCache, save_ndarray_as_image, load_image_as_ndarray 18 | from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import ( 19 | apply_multimodal_rotary_pos_emb, 20 | repeat_kv, 21 | _flash_attention_forward, 22 | ) 23 | from torch.profiler import profile, record_function, ProfilerActivity 24 | 25 | import qwen_vl_utils.vision_process 26 | from qwen_vl_utils.vision_process import * 27 | FPS_MAX_FRAMES = 100_000 # originally: 768 = 256 * 3 28 | 29 | def lvu_qwen25_vl_flash_attention_2_forward( 30 | self, 31 | hidden_states: torch.Tensor, 32 | attention_mask: Optional[torch.Tensor] = None, 33 | position_ids: Optional[torch.LongTensor] = None, 34 | past_key_value: Optional[Cache] = None, 35 | output_attentions: bool = False, 36 | use_cache: bool = False, 37 | cache_position: Optional[torch.LongTensor] = None, 38 | position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC 39 | ): 40 | bsz, q_len, _ = hidden_states.size() 41 | 42 | query_states = self.q_proj(hidden_states) 43 | key_states = self.k_proj(hidden_states) 44 | value_states = self.v_proj(hidden_states) 45 | 46 | query_states = query_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) 47 | key_states = key_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) 48 | value_states = value_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) 49 | 50 | # Because the input can be padded, the absolute sequence length depends on the max position id. 51 | cos, sin = position_embeddings 52 | query_states, key_states = apply_multimodal_rotary_pos_emb( 53 | query_states, key_states, cos, sin, self.rope_scaling["mrope_section"] 54 | ) 55 | 56 | if past_key_value is not None: 57 | cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position, "query_states": query_states} # Specific to RoPE models 58 | key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) 59 | 60 | # repeat k/v heads if n_kv_heads < n_heads 61 | key_states = repeat_kv(key_states, self.num_key_value_groups) 62 | value_states = repeat_kv(value_states, self.num_key_value_groups) 63 | dropout_rate = 0.0 if not self.training else self.attention_dropout 64 | 65 | # In PEFT, usually we cast the layer norms in float32 for training stability reasons 66 | # therefore the input hidden states gets silently casted in float32. Hence, we need 67 | # cast them back in float16 just to be sure everything works as expected. 68 | input_dtype = query_states.dtype 69 | if input_dtype == torch.float32: 70 | if torch.is_autocast_enabled(): 71 | target_dtype = torch.get_autocast_gpu_dtype() 72 | # Handle the case where the model is quantized 73 | elif hasattr(self.config, "_pre_quantization_dtype"): 74 | target_dtype = self.config._pre_quantization_dtype 75 | else: 76 | target_dtype = self.q_proj.weight.dtype 77 | 78 | logger.warning_once( 79 | f"The input hidden states seems to be silently casted in float32, this might be related to" 80 | f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in" 81 | f" {target_dtype}." 
82 | ) 83 | 84 | query_states = query_states.to(target_dtype) 85 | key_states = key_states.to(target_dtype) 86 | value_states = value_states.to(target_dtype) 87 | 88 | # Reashape to the expected shape for Flash Attention 89 | query_states = query_states.transpose(1, 2) 90 | key_states = key_states.transpose(1, 2) 91 | value_states = value_states.transpose(1, 2) 92 | 93 | if ( 94 | self.config.use_sliding_window 95 | and getattr(self.config, "sliding_window", None) is not None 96 | and self.layer_idx >= self.config.max_window_layers 97 | ): 98 | sliding_window = self.config.sliding_window 99 | else: 100 | sliding_window = None 101 | 102 | attn_output = _flash_attention_forward( 103 | query_states, 104 | key_states, 105 | value_states, 106 | attention_mask, 107 | q_len, 108 | dropout=dropout_rate, 109 | sliding_window=sliding_window, 110 | is_causal=self.is_causal, 111 | use_top_left_mask=self._flash_attn_uses_top_left_mask, 112 | ) 113 | 114 | attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous() 115 | attn_output = self.o_proj(attn_output) 116 | 117 | if not output_attentions: 118 | attn_weights = None 119 | 120 | return attn_output, attn_weights, past_key_value 121 | 122 | def lvu_qwen25_vl_decoder_layer_forward( 123 | self, 124 | hidden_states: torch.Tensor, 125 | attention_mask: Optional[torch.Tensor] = None, 126 | position_ids: Optional[torch.LongTensor] = None, 127 | past_key_value: Optional[Tuple[torch.Tensor]] = None, 128 | output_attentions: Optional[bool] = False, 129 | use_cache: Optional[bool] = False, 130 | cache_position: Optional[torch.LongTensor] = None, 131 | position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC 132 | **kwargs, 133 | ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: 134 | """ 135 | Args: 136 | hidden_states (`torch.FloatTensor`): 137 | - input to the layer of shape `(batch, seq_len, embed_dim)` 138 | - or a tuple of `(hidden_states, attention_mask, position_ids, cache_position, position_embeddings)` 139 | meaning that the previous layer has prune the hidden states to topk 140 | attention_mask (`torch.FloatTensor`, *optional*): attention mask of size 141 | `(batch, sequence_length)` where padding elements are indicated by 0. 142 | output_attentions (`bool`, *optional*): 143 | Whether or not to return the attentions tensors of all attention layers. See `attentions` under 144 | returned tensors for more detail. 145 | use_cache (`bool`, *optional*): 146 | If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding 147 | (see `past_key_values`). 148 | past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states 149 | cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): 150 | Indices depicting the position of the input sequence tokens in the sequence. 151 | position_embeddings (`Tuple[torch.FloatTensor, torch.FloatTensor]`, *optional*): 152 | Tuple containing the cosine and sine positional embeddings of shape `(batch_size, seq_len, head_dim)`, 153 | with `head_dim` being the embedding dimension of each attention head. 
154 | kwargs (`dict`, *optional*): 155 | Arbitrary kwargs to be ignored, used for FSDP and other methods that injects code 156 | into the model 157 | """ 158 | lvu_layer_config = getattr(self, "lvu_layer_config", None) 159 | lvu_config = getattr(lvu_layer_config, "lvu_config", None) 160 | if lvu_config is None: 161 | raise ValueError("LVUConfig is not set in the model. Please initialize the LVU model first.") 162 | 163 | if isinstance(hidden_states, tuple): 164 | # this means that previous layer has prune the hidden states to topk 165 | hidden_states, attention_mask, position_ids, cache_position, position_embeddings = hidden_states 166 | 167 | residual = hidden_states 168 | 169 | hidden_states = self.input_layernorm(hidden_states) 170 | 171 | # Self Attention 172 | hidden_states, self_attn_weights, present_key_value = self.self_attn( 173 | hidden_states=hidden_states, 174 | attention_mask=attention_mask, 175 | position_ids=position_ids, 176 | past_key_value=past_key_value, 177 | output_attentions=output_attentions, 178 | use_cache=use_cache, 179 | cache_position=cache_position, 180 | position_embeddings=position_embeddings, 181 | ) 182 | hidden_states = residual.to(hidden_states.device) + hidden_states 183 | hidden_states, attention_mask, position_ids, cache_position, position_embeddings, present_key_value = post_process_kv_cache( 184 | hidden_states, 185 | attention_mask, 186 | position_ids=position_ids, 187 | cache_position=cache_position, 188 | position_embeddings=position_embeddings, 189 | attn_weights=self_attn_weights, 190 | present_key_value=present_key_value, 191 | lvu_layer_config=lvu_layer_config, 192 | ) 193 | 194 | # Fully Connected 195 | residual = hidden_states 196 | hidden_states = self.post_attention_layernorm(hidden_states) 197 | hidden_states = self.mlp(hidden_states) 198 | hidden_states = residual + hidden_states 199 | 200 | if lvu_config.enable and lvu_layer_config.prune_for_next_layer and not lvu_layer_config.is_last_layer: 201 | # pass all the pruned information to next layer. If the last layer, we don't need to save other information except hidden_states 202 | hidden_states = (hidden_states, attention_mask, position_ids, cache_position, position_embeddings) 203 | 204 | outputs = (hidden_states,) 205 | 206 | if output_attentions: 207 | outputs += (self_attn_weights,) 208 | 209 | if use_cache: 210 | outputs += (present_key_value,) 211 | 212 | return outputs 213 | 214 | 215 | def _read_video_decord_cpu( 216 | ele: dict, 217 | ) -> (torch.Tensor, float): 218 | """read video using decord.VideoReader 219 | 220 | Args: 221 | ele (dict): a dict contains the configuration of video. 222 | support keys: 223 | - video: the path of video. support "file://", "http://", "https://" and local path. 224 | - video_start: the start time of video. 225 | - video_end: the end time of video. 226 | Returns: 227 | torch.Tensor: the video tensor with shape (T, C, H, W). 
228 | """ 229 | import decord 230 | num_cores = int(os.environ.get("QUICKCODEC_CORES", "4")) 231 | video_path = ele["video"] 232 | st = time.time() 233 | vr = decord.VideoReader(video_path, num_threads=num_cores) 234 | # TODO: support start_pts and end_pts 235 | if 'video_start' in ele or 'video_end' in ele: 236 | raise NotImplementedError("not support start_pts and end_pts in decord for now.") 237 | total_frames, video_fps = len(vr), vr.get_avg_fps() 238 | logger.info(f"decord: {video_path=}, {total_frames=}, {video_fps=}, time={time.time() - st:.3f}s") 239 | nframes = smart_nframes(ele, total_frames=total_frames, video_fps=video_fps) 240 | idx = torch.linspace(0, total_frames - 1, nframes).round().long().tolist() 241 | video = vr.get_batch(idx).asnumpy() 242 | video = torch.tensor(video).permute(0, 3, 1, 2) # Convert to TCHW format 243 | sample_fps = nframes / max(total_frames, 1e-6) * video_fps 244 | return video, sample_fps 245 | 246 | 247 | def is_deepcodec_available() -> bool: 248 | import importlib.util 249 | if "DEEPCODEC_DISABLED" in os.environ: 250 | return False 251 | else: 252 | return importlib.util.find_spec("deepcodec") is not None 253 | 254 | @lru_cache(maxsize=1) 255 | def get_video_reader_backend() -> str: 256 | if FORCE_QWENVL_VIDEO_READER is not None: 257 | video_reader_backend = FORCE_QWENVL_VIDEO_READER 258 | elif is_deepcodec_available(): 259 | video_reader_backend = "deepcodec" 260 | elif is_decord_available(): 261 | video_reader_backend = "decord" 262 | else: 263 | video_reader_backend = "torchvision" 264 | print(f"qwen-vl-utils using {video_reader_backend} to read video.", file=sys.stderr) 265 | return video_reader_backend 266 | 267 | def _read_video_deepcodec( 268 | ele: dict, 269 | ) -> (torch.Tensor, float): 270 | """read video using decord.VideoReader 271 | 272 | Args: 273 | ele (dict): a dict contains the configuration of video. 274 | support keys: 275 | - video: the path of video. support "file://", "http://", "https://" and local path. 276 | - video_start: the start time of video. 277 | - video_end: the end time of video. 278 | Returns: 279 | torch.Tensor: the video tensor with shape (T, C, H, W). 
280 | """ 281 | from deepcodec import VideoReader as DCVideoReader 282 | video_path = ele["video"] 283 | resize = ele.pop("resize") 284 | 285 | st = time.time() 286 | num_cores = int(os.environ.get("QUICKCODEC_CORES", "4")) 287 | vr = DCVideoReader(video_path, num_threads=num_cores) 288 | 289 | total_frames, video_fps = len(vr), vr.get_fps() 290 | nframes = smart_nframes(ele, total_frames=total_frames, video_fps=video_fps) 291 | 292 | total_pixels = ele.get("total_pixels", VIDEO_TOTAL_PIXELS) 293 | min_pixels = ele.get("min_pixels", VIDEO_MIN_PIXELS) 294 | max_pixels = max(min(VIDEO_MAX_PIXELS, total_pixels / nframes * FRAME_FACTOR), int(min_pixels * 1.05)) 295 | max_pixels_supposed = ele.get("max_pixels", max_pixels) 296 | if max_pixels_supposed > max_pixels: 297 | logger.warning(f"The given max_pixels[{max_pixels_supposed}] exceeds limit[{max_pixels}].") 298 | max_pixels = min(max_pixels_supposed, max_pixels) 299 | 300 | height, width = vr.height, vr.width 301 | resized_height, resized_width = resize["fxn"]( 302 | height, 303 | width, 304 | factor=resize["image_factor"], 305 | min_pixels=min_pixels, 306 | max_pixels=max_pixels, 307 | ) 308 | vr.height = resized_height 309 | vr.width = resized_width 310 | vr.interpolation = "LANCZOS" 311 | 312 | # TODO: support start_pts and end_pts 313 | if 'video_start' in ele or 'video_end' in ele: 314 | raise NotImplementedError("not support start_pts and end_pts in deepcodec for now.") 315 | 316 | idx = torch.linspace(0, total_frames - 1, nframes).round().long().tolist() 317 | video = torch.from_numpy(vr.get_batch(idx)) 318 | batch_end_time = time.time() 319 | print(f"deepcodec: {video_path=}, {total_frames=}, {video_fps=}, time={batch_end_time-st:.3f}s") 320 | print(video.shape) 321 | # deepcodec already returns in TCHW format 322 | #video = torch.tensor(video).permute(0, 3, 1, 2) 323 | sample_fps = nframes / max(total_frames, 1e-6) * video_fps 324 | 325 | return video, sample_fps 326 | 327 | VIDEO_READER_BACKENDS = { 328 | "deepcodec": _read_video_deepcodec, 329 | "decord": _read_video_decord_cpu, 330 | "torchvision": qwen_vl_utils.vision_process._read_video_torchvision, 331 | } 332 | 333 | def fetch_video(ele: dict, image_factor: int = IMAGE_FACTOR, return_video_sample_fps: bool = False) -> torch.Tensor | list[Image.Image]: 334 | if isinstance(ele["video"], str): 335 | 336 | video_reader_backend = get_video_reader_backend() 337 | 338 | if video_reader_backend == "deepcodec": 339 | ele["resize"] = { 340 | "fxn": smart_resize, 341 | "image_factor": image_factor, 342 | } 343 | 344 | try: 345 | video, sample_fps = VIDEO_READER_BACKENDS[video_reader_backend](ele) 346 | except Exception as e: 347 | logger.warning(f"video_reader_backend {video_reader_backend} error, use torchvision as default, msg: {e}") 348 | video, sample_fps = VIDEO_READER_BACKENDS["torchvision"](ele) 349 | 350 | nframes, _, height, width = video.shape 351 | total_pixels = ele.get("total_pixels", VIDEO_TOTAL_PIXELS) 352 | min_pixels = ele.get("min_pixels", VIDEO_MIN_PIXELS) 353 | max_pixels = max(min(VIDEO_MAX_PIXELS, total_pixels / nframes * FRAME_FACTOR), int(min_pixels * 1.05)) 354 | max_pixels_supposed = ele.get("max_pixels", max_pixels) 355 | if max_pixels_supposed > max_pixels: 356 | logger.warning(f"The given max_pixels[{max_pixels_supposed}] exceeds limit[{max_pixels}].") 357 | max_pixels = min(max_pixels_supposed, max_pixels) 358 | if "resized_height" in ele and "resized_width" in ele: 359 | resized_height, resized_width = smart_resize( 360 | ele["resized_height"], 361 | 
ele["resized_width"], 362 | factor=image_factor, 363 | ) 364 | elif video_reader_backend != "deepcodec": 365 | resized_height, resized_width = smart_resize( 366 | height, 367 | width, 368 | factor=image_factor, 369 | min_pixels=min_pixels, 370 | max_pixels=max_pixels, 371 | ) 372 | 373 | # deepcodec already handles resizing 374 | if video_reader_backend == "deepcodec": 375 | video = video.float() 376 | else: 377 | video = transforms.functional.resize( 378 | video, 379 | [resized_height, resized_width], 380 | interpolation=InterpolationMode.BICUBIC, 381 | antialias=True, 382 | ).float() 383 | if return_video_sample_fps: 384 | return video, sample_fps 385 | return video 386 | else: 387 | assert isinstance(ele["video"], (list, tuple)) 388 | process_info = ele.copy() 389 | process_info.pop("type", None) 390 | process_info.pop("video", None) 391 | images = [ 392 | fetch_image({"image": video_element, **process_info}, size_factor=image_factor) 393 | for video_element in ele["video"] 394 | ] 395 | nframes = ceil_by_factor(len(images), FRAME_FACTOR) 396 | if len(images) < nframes: 397 | images.extend([images[-1]] * (nframes - len(images))) 398 | if return_video_sample_fps: 399 | return images, process_info.pop("fps", 2.0) 400 | return images 401 | 402 | 403 | def smart_nframes( 404 | ele: dict, 405 | total_frames: int, 406 | video_fps: int | float, 407 | ) -> int: 408 | """calculate the number of frames for video used for model inputs. 409 | 410 | Args: 411 | ele (dict): a dict contains the configuration of video. 412 | support either `fps` or `nframes`: 413 | - nframes: the number of frames to extract for model inputs. 414 | - fps: the fps to extract frames for model inputs. 415 | - min_frames: the minimum number of frames of the video, only used when fps is provided. 416 | - max_frames: the maximum number of frames of the video, only used when fps is provided. 417 | total_frames (int): the original total number of frames of the video. 418 | video_fps (int | float): the original fps of the video. 419 | 420 | Raises: 421 | ValueError: nframes should in interval [FRAME_FACTOR, total_frames]. 422 | 423 | Returns: 424 | int: the number of frames for video used for model inputs. 
425 | """ 426 | assert not ("fps" in ele and "nframes" in ele), "Only accept either `fps` or `nframes`" 427 | if "nframes" in ele: 428 | nframes = round_by_factor(ele["nframes"], FRAME_FACTOR) 429 | nframes = min(nframes, total_frames) 430 | nframes -= (nframes % FRAME_FACTOR) 431 | else: 432 | fps = ele.get("fps", FPS) 433 | min_frames = ceil_by_factor(ele.get("min_frames", FPS_MIN_FRAMES), FRAME_FACTOR) 434 | max_frames = floor_by_factor(ele.get("max_frames", min(FPS_MAX_FRAMES, total_frames)), FRAME_FACTOR) 435 | nframes = total_frames / video_fps * fps 436 | if nframes > total_frames: 437 | logger.warning(f"smart_nframes: nframes[{nframes}] > total_frames[{total_frames}]") 438 | nframes = min(min(max(nframes, min_frames), max_frames), total_frames) 439 | nframes = floor_by_factor(nframes, FRAME_FACTOR) 440 | if not (FRAME_FACTOR <= nframes and nframes <= total_frames): 441 | raise ValueError(f"nframes should in interval [{FRAME_FACTOR}, {total_frames}], but got {nframes}.") 442 | return nframes 443 | 444 | import inspect 445 | def _get_initial_cache_position(self, *args, **kwargs): 446 | # Get the function signature 447 | sig = inspect.signature(self.old_get_initial_cache_position) 448 | 449 | # Get parameter names (excluding 'self') 450 | param_names = [param.name for param in sig.parameters.values() 451 | if param.name not in ('self', 'args', 'kwargs')] 452 | 453 | # Transform *args to **kwargs using parameter names 454 | args_as_kwargs = dict(zip(param_names, args)) 455 | 456 | # Combine with existing kwargs 457 | all_kwargs = {**args_as_kwargs, **kwargs} 458 | # Find model_kwargs in the mapped arguments 459 | model_kwargs = all_kwargs.get('model_kwargs') 460 | 461 | # Early return if cache_position already exists in model_kwargs 462 | if model_kwargs is not None and "cache_position" in model_kwargs: 463 | return model_kwargs 464 | return self.old_get_initial_cache_position(*args, **kwargs) 465 | 466 | 467 | def init_lvu_model(model, config: LVUConfig): 468 | """ 469 | Initialize the LVU model for Qwen 2.5 VL. 470 | - replace the decoder layer forward function with the LVU version 471 | Args: 472 | model: The model to be initialized. 473 | config: The configuration for the LVU model. 474 | """ 475 | _model = model 476 | if isinstance(_model, Qwen2_5_VLForConditionalGeneration): 477 | _model = _model.model 478 | if not hasattr(model, "get_rope_index"): 479 | # for transformers > 4.50.0 480 | model.get_rope_index = _model.get_rope_index 481 | if isinstance(_model, Qwen2_5_VLModel): 482 | if hasattr(_model, "layers"): 483 | _model = _model 484 | elif hasattr(_model, "language_model"): 485 | _model = _model.language_model 486 | else: 487 | raise ValueError("Qwen2_5_VLModel must have either `model` or `language_model` attribute.") 488 | try: 489 | decoder_layers = _model.layers 490 | except AttributeError: 491 | raise ValueError("Did not find `layers` attribute in the model. 
Please check your qwen2.5_vl source code and transformers version.") 492 | 493 | total_layers= len(decoder_layers) 494 | for i, layer in enumerate(decoder_layers): 495 | # Set the forward function for each decoder layer and filling the parameters in the config 496 | layer.forward = lvu_qwen25_vl_decoder_layer_forward.__get__(layer) 497 | layer.self_attn.forward = lvu_qwen25_vl_flash_attention_2_forward.__get__(layer.self_attn) 498 | layer.lvu_layer_config = LVULayerConfig(layer_idx=layer.self_attn.layer_idx, total_layers=total_layers, lvu_config=config) 499 | model.old_get_initial_cache_position = model._get_initial_cache_position 500 | model._get_initial_cache_position = _get_initial_cache_position.__get__(model) 501 | 502 | return model 503 | 504 | def run_lvu_model(self, question, video_path, **generation_kwargs): 505 | lvu_config = self.config 506 | fps = lvu_config.fps 507 | num_frames = lvu_config.num_frames 508 | extra_kwargs = lvu_config.extra_kwargs or {} 509 | max_pixels = extra_kwargs.get("max_pixels", None) 510 | min_pixels = extra_kwargs.get("min_pixels", None) 511 | 512 | video_content = { 513 | "type": "video", 514 | "video": video_path, 515 | } 516 | if max_pixels is not None: 517 | video_content["max_pixels"] = max_pixels 518 | if min_pixels is not None: 519 | video_content["min_pixels"] = min_pixels 520 | if fps is not None: 521 | video_content["fps"] = fps 522 | elif num_frames is not None: 523 | video_content["nframes"] = num_frames 524 | else: 525 | raise ValueError("Either fps or num_frames should be set.") 526 | # Messages containing a local video path and a text query 527 | messages = [ 528 | { 529 | "role": "user", 530 | "content": [ 531 | video_content, 532 | {"type": "text", "text": question} 533 | ], 534 | } 535 | ] 536 | return chat_lvu_model(self, messages, **generation_kwargs) 537 | 538 | def chat_lvu_model(self, messages, **generation_kwargs): 539 | model = self.model 540 | processor = self.processor 541 | lvu_config = self.config 542 | 543 | # Process the messages 544 | #In Qwen 2.5 VL, frame rate information is also input into the model to align with absolute time. 545 | # Preparation for inference 546 | text = processor.apply_chat_template( 547 | messages, tokenize=False, add_generation_prompt=True 548 | ) 549 | 550 | start = time.time() 551 | e2e_start = time.time() 552 | cache_dir = lvu_config.cache_dir or "~/.cache/video_cache/qwen25_vl" 553 | vision_info = extract_vision_info(messages) 554 | assert len(vision_info) == 1, "Only one video is supported for now." 
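# Illustrative sketch of the cache key built just below (hypothetical values, added for
# clarity): for vision_info[0] == {"type": "video", "video": "/data/movie.mp4", "nframes": 64},
# the loop yields cache_key == "movie_nframes=64" and cache_file ends up as
# ~/.cache/video_cache/qwen25_vl/movie_nframes=64.pt, so cached frames are only reused
# when the video path stem and every sampling setting match exactly.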
555 | video_path = Path(vision_info[0]["video"]) 556 | cache_key = video_path.stem 557 | for k, v in vision_info[0].items(): 558 | if k not in ['type', 'video']: 559 | cache_key += f"_{k}={v}" 560 | # cache_key = video_path.stem + "_" + hashlib.md5(str(vision_info).encode()).hexdigest() 561 | cache_dir = Path(cache_dir).expanduser() 562 | cache_file = cache_dir / f"{cache_key}.pt" 563 | cache_dir.mkdir(parents=True, exist_ok=True) 564 | if not cache_file.exists(): 565 | image_inputs, video_inputs, video_kwargs = process_vision_info(messages, return_video_kwargs=True) 566 | # save to cache 567 | if lvu_config.save_video_cache: 568 | cache_images_folder = cache_file.parent / f"{cache_key}_images" 569 | cache_images_folder.mkdir(parents=True, exist_ok=True) 570 | total_size = 0 571 | for i, image in enumerate(video_inputs[0]): 572 | # image is a numpy array of (C, H, W) 573 | save_ndarray_as_image(image.numpy(), cache_images_folder / f"{i:04d}.jpg") 574 | total_size += os.path.getsize(cache_images_folder / f"{i:04d}.jpg") 575 | torch.save({ 576 | "image_inputs": image_inputs, 577 | "video_kwargs": video_kwargs, 578 | }, cache_file) 579 | total_size_gb = total_size / (1024 ** 3) 580 | print(f"Saved video cache to {cache_file} ({total_size_gb:.2f} GB)") 581 | else: 582 | print(f"Cache file {cache_file} found. Loading video frames...") 583 | results = torch.load(cache_file) 584 | image_inputs = results["image_inputs"] 585 | video_kwargs = results["video_kwargs"] 586 | video_inputs = [] 587 | cache_images_folder = cache_file.parent / f"{cache_key}_images" 588 | all_images = sorted(cache_images_folder.glob("*.jpg")) 589 | for i in range(len(all_images)): 590 | image = torch.tensor(load_image_as_ndarray(cache_images_folder / f"{i:04d}.jpg")) 591 | video_inputs.append(image) 592 | video_inputs = [torch.stack(video_inputs)] 593 | end = time.time() 594 | video_processing_time = end - start 595 | 596 | processor_start = time.time() 597 | whole_inputs = processor( 598 | text=text, 599 | images=image_inputs, 600 | videos=video_inputs, 601 | padding=True, 602 | return_tensors="pt", 603 | **video_kwargs, 604 | ) 605 | processor_end = time.time() 606 | processor_time = processor_end - processor_start 607 | 608 | whole_inputs = whole_inputs.to(model.device) 609 | n_video_tokens = (whole_inputs['input_ids'] == model.config.video_token_id).sum().item() 610 | video_token_idxs = (whole_inputs['input_ids'] == model.config.video_token_id).nonzero(as_tuple=True)[1] 611 | first_video_token_id_idx = video_token_idxs[0].item() 612 | last_video_token_id_idx = video_token_idxs[-1].item() 613 | position_ids, rope_deltas = model.get_rope_index( 614 | whole_inputs['input_ids'], 615 | whole_inputs.get('image_grid_thw', None), 616 | whole_inputs.get('video_grid_thw', None), 617 | whole_inputs.get('second_per_grid_ts', None), 618 | whole_inputs['attention_mask'], 619 | ) 620 | model.rope_deltas = rope_deltas 621 | 622 | assert len(video_inputs) <= 1, "Only one video is supported for now." 623 | video_group_size = lvu_config.video_group_size 624 | temporal_patch_size = processor.image_processor.temporal_patch_size 625 | if not video_group_size % temporal_patch_size == 0: 626 | video_group_size += temporal_patch_size - (video_group_size % temporal_patch_size) 627 | if video_group_size is not None and video_group_size > 0: 628 | video_groups = video_inputs[0].split(video_group_size) 629 | assert all(len(group) % 2 == 0 for group in video_groups), "The video group size should be even." 
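# Worked example of the bookkeeping below (assumed numbers, not measured): with
# video_group_size=32, a 256-frame sampled video, and n_video_tokens=25600, each group
# carries int(25600 * 32 / 256) = 3200 video tokens, and its grid's temporal dimension is
# (32 - 1) // temporal_patch_size + 1 = 16 when temporal_patch_size == 2, so the prefill
# loop advances the KV cache one 32-frame chunk at a time.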
630 | video_groups_tokens = [int(n_video_tokens * (len(group) / len(video_inputs[0]))) for group in video_groups]
631 | video_grid_thw = whole_inputs['video_grid_thw'][0]
632 | video_groups_grid_thw = []
633 | for group in video_groups:
634 | video_groups_grid_thw.append(
635 | torch.tensor(
636 | [(len(group) -1 ) // temporal_patch_size + 1,
637 | video_grid_thw[1],
638 | video_grid_thw[2]]
639 | ).unsqueeze(0)
640 | )
641 | pixel_values_videos_group_size = round((video_group_size / len(video_inputs[0])) * whole_inputs['pixel_values_videos'].shape[0])
642 | pixel_values_videos_groups = whole_inputs['pixel_values_videos'].split(pixel_values_videos_group_size)
643 | else:
644 | video_groups = [video_inputs[0]]
645 | video_groups_tokens = [n_video_tokens]
646 | video_groups_grid_thw = [whole_inputs['video_grid_thw']]
647 | pixel_values_videos_groups = [whole_inputs['pixel_values_videos']]
648 | 
649 | # print("Sampled video frames: ", len(video_inputs[0]))
650 | # print("Video groups: ", [len(group) for group in video_groups])
651 | # print("Video groups tokens: ", video_groups_tokens)
652 | # print("Video groups grid thw: ", video_groups_grid_thw)
653 | # print("Pixel values videos groups: ", [group.shape for group in pixel_values_videos_groups])
654 | 
655 | # prepare the chunked processing of the video groups
656 | past_key_values = LVUCache()
657 | past_len = 0
658 | video_token_idxs = (whole_inputs['input_ids'] == model.config.video_token_id).nonzero(as_tuple=True)[1]
659 | first_video_token_id_idx = video_token_idxs[0].item()
660 | last_video_token_id_idx = video_token_idxs[-1].item()
661 | prompt_input_ids = whole_inputs['input_ids'][:, last_video_token_id_idx + 1:]
662 | prompt_attention_mask = whole_inputs['attention_mask'][:, last_video_token_id_idx + 1:]
663 | if lvu_config.query_based:
664 | past_key_values.set_prompt_length(prompt_input_ids.shape[1])
665 | video_groups_tokens[0] += first_video_token_id_idx # add the tokens before the first video group as well
666 | 
667 | total_prefill = 0
668 | 
669 | # start processing the video groups
670 | print(f"Processing a total of {len(video_groups)} video groups, each with {video_group_size} frames.")
671 | for i, pixel_values_videos_groups_i in tqdm(enumerate(pixel_values_videos_groups),
672 | desc="Processing video groups", total=len(pixel_values_videos_groups), disable=not lvu_config.use_tqdm):
673 | 
674 | prefill_start = time.time()
675 | 
676 | group_i_inputs = {
677 | "video_grid_thw": video_groups_grid_thw[i],
678 | "second_per_grid_ts": whole_inputs['second_per_grid_ts'],
679 | "pixel_values_videos": pixel_values_videos_groups_i,
680 | }
681 | group_i_inputs = BatchFeature(data=group_i_inputs)
682 | group_i_inputs['input_ids'] = whole_inputs['input_ids'][:, past_len:past_len + video_groups_tokens[i]]
683 | group_i_inputs['attention_mask'] = whole_inputs['attention_mask'][:, past_len:past_len + video_groups_tokens[i]]
684 | if lvu_config.query_based:
685 | group_i_inputs['input_ids'] = torch.cat((group_i_inputs['input_ids'], prompt_input_ids), dim=1)
686 | group_i_inputs['attention_mask'] = torch.cat((group_i_inputs['attention_mask'], prompt_attention_mask), dim=1)
687 | 
688 | group_i_inputs['cache_position'] = torch.arange(group_i_inputs['input_ids'].shape[1], dtype=torch.int64, device=model.device) + past_len
689 | group_i_inputs['position_ids'] = position_ids[:, :, past_len:past_len + group_i_inputs['input_ids'].shape[1]]
690 | past_len += video_groups_tokens[i] # only the video group tokens are counted, prompt tokens are not counted
691 | 
group_i_inputs = group_i_inputs.to(model.device) 692 | group_i_inputs['use_cache'] = True 693 | 694 | if lvu_config.adaptive_local_attention: 695 | group_i_inputs['past_key_values'] = past_key_values 696 | with torch.no_grad(): 697 | outputs = model(**group_i_inputs) 698 | # later video groups will use the past key values 699 | past_key_values = outputs.past_key_values 700 | else: 701 | with torch.no_grad(): 702 | outputs = model(**group_i_inputs) 703 | if not past_key_values: 704 | # first time parsing, the video grid information is not correct 705 | past_key_values = outputs.past_key_values 706 | else: 707 | # update the past key values 708 | if isinstance(outputs.past_key_values, Cache): 709 | for i in range(len(outputs.past_key_values)): 710 | past_key_values.update(outputs.past_key_values[i][0], outputs.past_key_values[i][1], i) 711 | else: 712 | for i in range(len(outputs.past_key_values)): 713 | for j in range(len(outputs.past_key_values[i])): 714 | past_key_values[i][j] = torch.cat((past_key_values[i][j], outputs.past_key_values[i][j]), dim=2) 715 | # print(f"past_key_values shape: {past_key_values[0][0].shape}") 716 | prefill_end = time.time() 717 | total_prefill += prefill_end-prefill_start 718 | assert past_len < whole_inputs['input_ids'].shape[1], "The past length should be less than the final input length." 719 | if lvu_config.query_based: 720 | # reset prompt length as all video groups are processed 721 | past_key_values.set_prompt_length(0) 722 | # end of processing the video groups 723 | 724 | final_inputs = { 725 | "input_ids": whole_inputs['input_ids'][:, past_len:], 726 | "attention_mask": whole_inputs['attention_mask'][:, past_len:], 727 | } 728 | final_inputs = BatchFeature(data=final_inputs) 729 | final_inputs['cache_position'] = torch.arange(final_inputs.input_ids.shape[1], dtype=torch.int64, device=model.device) + past_len 730 | final_inputs['position_ids'] = position_ids[:, :, past_len:] 731 | assert final_inputs['input_ids'].shape[1] == final_inputs['position_ids'].shape[2], "The input ids and position ids should have the same length, but got {} and {}".format( 732 | final_inputs['input_ids'].shape[1], final_inputs['position_ids'].shape[2]) 733 | final_inputs = final_inputs.to(model.device) 734 | final_inputs['past_key_values'] = past_key_values 735 | final_inputs['use_cache'] = True 736 | 737 | cache_enable = lvu_config.enable 738 | lvu_config.enable = lvu_config.do_top_k_for_query # determine whether to do topk or not 739 | decoding_start = time.time() 740 | generated_ids = model.generate(**final_inputs, **generation_kwargs) 741 | decoding_end = time.time() 742 | lvu_config.enable = cache_enable 743 | 744 | e2e_end = time.time() 745 | e2e_time = e2e_end - e2e_start 746 | decoding_time = decoding_end - decoding_start 747 | 748 | print(f"total time spent fetching frames was: {video_processing_time}") 749 | print(f"total time spent on processor was: {processor_time}") 750 | print(f"total time spent on prefill was: {total_prefill}") 751 | print(f"total time spent on decoding was: {decoding_time}") 752 | print(f"total time spent on e2e fetching and decoding was: {e2e_time}") 753 | print(f"Time saved by interleaved processing was: {video_processing_time + processor_time + total_prefill + decoding_time - e2e_time}") 754 | 755 | generated_ids_trimmed = [ 756 | out_ids[len(in_ids) :] for in_ids, out_ids in zip(final_inputs.input_ids, generated_ids) 757 | ] 758 | output_text = processor.batch_decode( 759 | generated_ids_trimmed, skip_special_tokens=True, 
clean_up_tokenization_spaces=False 760 | ) 761 | return output_text 762 | 763 | 764 | sys.modules["qwen_vl_utils.vision_process"].get_video_reader_backend = get_video_reader_backend 765 | sys.modules["qwen_vl_utils.vision_process"].VIDEO_READER_BACKENDS = VIDEO_READER_BACKENDS 766 | sys.modules["qwen_vl_utils.vision_process"].fetch_video = fetch_video 767 | sys.modules["qwen_vl_utils.vision_process"].smart_nframes = smart_nframes 768 | -------------------------------------------------------------------------------- /lvu/models/qwen25_lvu_interleaved.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import time 4 | import time 5 | import sys 6 | import types 7 | import threading 8 | import numpy as np 9 | import os 10 | from PIL import Image 11 | from queue import Queue 12 | from pathlib import Path 13 | from tqdm import tqdm 14 | from typing import Optional, Tuple, Union 15 | from transformers import Qwen2_5_VLForConditionalGeneration, Qwen2_5_VLModel 16 | from transformers.feature_extraction_utils import BatchFeature 17 | from transformers.cache_utils import Cache 18 | from qwen_vl_utils import process_vision_info, extract_vision_info 19 | from ..utils import post_process_kv_cache 20 | from ..lvu_config import LVUConfig, LVULayerConfig 21 | from ..lvu_cache import LVUCache 22 | from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import ( 23 | apply_multimodal_rotary_pos_emb, 24 | repeat_kv, 25 | _flash_attention_forward, 26 | ) 27 | from transformers.models.qwen2_5_vl.processing_qwen2_5_vl import Qwen2_5_VLProcessorKwargs, BatchFeature 28 | import warnings 29 | import qwen_vl_utils.vision_process 30 | from qwen_vl_utils.vision_process import * 31 | from deepcodec import InterleavedVideoReader 32 | FPS_MAX_FRAMES = 100_000 # originally: 768 = 256 * 3 33 | 34 | def lvu_qwen25_vl_flash_attention_2_forward( 35 | self, 36 | hidden_states: torch.Tensor, 37 | attention_mask: Optional[torch.Tensor] = None, 38 | position_ids: Optional[torch.LongTensor] = None, 39 | past_key_value: Optional[Cache] = None, 40 | output_attentions: bool = False, 41 | use_cache: bool = False, 42 | cache_position: Optional[torch.LongTensor] = None, 43 | position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC 44 | ): 45 | bsz, q_len, _ = hidden_states.size() 46 | 47 | query_states = self.q_proj(hidden_states) 48 | key_states = self.k_proj(hidden_states) 49 | value_states = self.v_proj(hidden_states) 50 | 51 | query_states = query_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) 52 | key_states = key_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) 53 | value_states = value_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) 54 | 55 | # Because the input can be padded, the absolute sequence length depends on the max position id. 
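# Note added for clarity (not part of the stock HF attention): in addition to the usual
# RoPE kwargs, `query_states` is passed to the cache update below, presumably so that the
# LVUCache from lvu_cache.py can score cached keys against the current query when deciding
# which KV entries to keep.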
56 | cos, sin = position_embeddings 57 | query_states, key_states = apply_multimodal_rotary_pos_emb( 58 | query_states, key_states, cos, sin, self.rope_scaling["mrope_section"] 59 | ) 60 | 61 | if past_key_value is not None: 62 | cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position, "query_states": query_states} # Specific to RoPE models 63 | key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) 64 | 65 | # repeat k/v heads if n_kv_heads < n_heads 66 | key_states = repeat_kv(key_states, self.num_key_value_groups) 67 | value_states = repeat_kv(value_states, self.num_key_value_groups) 68 | dropout_rate = 0.0 if not self.training else self.attention_dropout 69 | 70 | # In PEFT, usually we cast the layer norms in float32 for training stability reasons 71 | # therefore the input hidden states gets silently casted in float32. Hence, we need 72 | # cast them back in float16 just to be sure everything works as expected. 73 | input_dtype = query_states.dtype 74 | if input_dtype == torch.float32: 75 | if torch.is_autocast_enabled(): 76 | target_dtype = torch.get_autocast_gpu_dtype() 77 | # Handle the case where the model is quantized 78 | elif hasattr(self.config, "_pre_quantization_dtype"): 79 | target_dtype = self.config._pre_quantization_dtype 80 | else: 81 | target_dtype = self.q_proj.weight.dtype 82 | 83 | logger.warning_once( 84 | f"The input hidden states seems to be silently casted in float32, this might be related to" 85 | f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in" 86 | f" {target_dtype}." 87 | ) 88 | 89 | query_states = query_states.to(target_dtype) 90 | key_states = key_states.to(target_dtype) 91 | value_states = value_states.to(target_dtype) 92 | 93 | # Reashape to the expected shape for Flash Attention 94 | query_states = query_states.transpose(1, 2) 95 | key_states = key_states.transpose(1, 2) 96 | value_states = value_states.transpose(1, 2) 97 | 98 | if ( 99 | self.config.use_sliding_window 100 | and getattr(self.config, "sliding_window", None) is not None 101 | and self.layer_idx >= self.config.max_window_layers 102 | ): 103 | sliding_window = self.config.sliding_window 104 | else: 105 | sliding_window = None 106 | 107 | attn_output = _flash_attention_forward( 108 | query_states, 109 | key_states, 110 | value_states, 111 | attention_mask, 112 | q_len, 113 | dropout=dropout_rate, 114 | sliding_window=sliding_window, 115 | is_causal=self.is_causal, 116 | use_top_left_mask=self._flash_attn_uses_top_left_mask, 117 | ) 118 | 119 | attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous() 120 | attn_output = self.o_proj(attn_output) 121 | 122 | if not output_attentions: 123 | attn_weights = None 124 | 125 | return attn_output, attn_weights, past_key_value 126 | 127 | def lvu_qwen25_vl_decoder_layer_forward( 128 | self, 129 | hidden_states: torch.Tensor, 130 | attention_mask: Optional[torch.Tensor] = None, 131 | position_ids: Optional[torch.LongTensor] = None, 132 | past_key_value: Optional[Tuple[torch.Tensor]] = None, 133 | output_attentions: Optional[bool] = False, 134 | use_cache: Optional[bool] = False, 135 | cache_position: Optional[torch.LongTensor] = None, 136 | position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC 137 | **kwargs, 138 | ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: 139 | """ 140 | Args: 141 | hidden_states 
(`torch.FloatTensor`): 142 | - input to the layer of shape `(batch, seq_len, embed_dim)` 143 | - or a tuple of `(hidden_states, attention_mask, position_ids, cache_position, position_embeddings)` 144 | meaning that the previous layer has prune the hidden states to topk 145 | attention_mask (`torch.FloatTensor`, *optional*): attention mask of size 146 | `(batch, sequence_length)` where padding elements are indicated by 0. 147 | output_attentions (`bool`, *optional*): 148 | Whether or not to return the attentions tensors of all attention layers. See `attentions` under 149 | returned tensors for more detail. 150 | use_cache (`bool`, *optional*): 151 | If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding 152 | (see `past_key_values`). 153 | past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states 154 | cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): 155 | Indices depicting the position of the input sequence tokens in the sequence. 156 | position_embeddings (`Tuple[torch.FloatTensor, torch.FloatTensor]`, *optional*): 157 | Tuple containing the cosine and sine positional embeddings of shape `(batch_size, seq_len, head_dim)`, 158 | with `head_dim` being the embedding dimension of each attention head. 159 | kwargs (`dict`, *optional*): 160 | Arbitrary kwargs to be ignored, used for FSDP and other methods that injects code 161 | into the model 162 | """ 163 | lvu_layer_config = getattr(self, "lvu_layer_config", None) 164 | lvu_config = getattr(lvu_layer_config, "lvu_config", None) 165 | if lvu_config is None: 166 | raise ValueError("LVUConfig is not set in the model. Please initialize the LVU model first.") 167 | 168 | if isinstance(hidden_states, tuple): 169 | # this means that previous layer has prune the hidden states to topk 170 | hidden_states, attention_mask, position_ids, cache_position, position_embeddings = hidden_states 171 | 172 | residual = hidden_states 173 | 174 | hidden_states = self.input_layernorm(hidden_states) 175 | 176 | # Self Attention 177 | hidden_states, self_attn_weights, present_key_value = self.self_attn( 178 | hidden_states=hidden_states, 179 | attention_mask=attention_mask, 180 | position_ids=position_ids, 181 | past_key_value=past_key_value, 182 | output_attentions=output_attentions, 183 | use_cache=use_cache, 184 | cache_position=cache_position, 185 | position_embeddings=position_embeddings, 186 | ) 187 | hidden_states = residual.to(hidden_states.device) + hidden_states 188 | hidden_states, attention_mask, position_ids, cache_position, position_embeddings, present_key_value = post_process_kv_cache( 189 | hidden_states, 190 | attention_mask, 191 | position_ids=position_ids, 192 | cache_position=cache_position, 193 | position_embeddings=position_embeddings, 194 | attn_weights=self_attn_weights, 195 | present_key_value=present_key_value, 196 | lvu_layer_config=lvu_layer_config, 197 | ) 198 | 199 | # Fully Connected 200 | residual = hidden_states 201 | hidden_states = self.post_attention_layernorm(hidden_states) 202 | hidden_states = self.mlp(hidden_states) 203 | hidden_states = residual + hidden_states 204 | 205 | if lvu_config.enable and lvu_layer_config.prune_for_next_layer and not lvu_layer_config.is_last_layer: 206 | # pass all the pruned information to next layer. 
If the last layer, we don't need to save other information except hidden_states 207 | hidden_states = (hidden_states, attention_mask, position_ids, cache_position, position_embeddings) 208 | 209 | outputs = (hidden_states,) 210 | 211 | if output_attentions: 212 | outputs += (self_attn_weights,) 213 | 214 | if use_cache: 215 | outputs += (present_key_value,) 216 | 217 | return outputs 218 | 219 | 220 | def save_image_to_home(img_array: torch.Tensor, filename: str = "img/output.png"): 221 | # Ensure shape is (H, W, 3) 222 | img = np.transpose(img_array.cpu().numpy(), (1, 2, 0)) 223 | 224 | if img.dtype != np.uint8: 225 | img = np.clip(img, 0, 1) if img.max() <= 1.0 else np.clip(img, 0, 255) 226 | img = (img * 255).astype(np.uint8) if img.max() <= 1.0 else img.astype(np.uint8) 227 | 228 | pil_img = Image.fromarray(img) 229 | # Get home directory 230 | home_path = os.path.expanduser("~") 231 | # Save image 232 | save_path = os.path.join(home_path, filename) 233 | pil_img.save(save_path) 234 | print(f"Image saved to {save_path}") 235 | 236 | 237 | class PixelIterator: 238 | 239 | def __init__(self,qwen_vr, vr, frames_per_block,video_kwargs, processor): 240 | 241 | self.qwen_vr = qwen_vr 242 | self.iterations = 0 243 | self.frames_per_block = frames_per_block 244 | self.vr = vr 245 | self.processor = processor 246 | self.video_kwargs = video_kwargs 247 | self.processor_timing = 0 248 | 249 | def __iter__(self): 250 | return self 251 | 252 | def __next__(self): 253 | 254 | s = time.time() 255 | frames = torch.from_numpy(next(self.vr)).float() 256 | e = time.time() 257 | self.qwen_vr.total_timing += e-s 258 | s = time.time() 259 | #save_image_to_home(frames[8], f"img/{self.iterations}.png") 260 | pixels = self.processor( 261 | text="a", 262 | images=[], 263 | videos=[frames], 264 | padding=True, 265 | return_tensors="pt", 266 | **self.video_kwargs, 267 | )['pixel_values_videos'] 268 | self.iterations += 1 269 | e = time.time() 270 | self.processor_timing += e - s 271 | return pixels 272 | 273 | class AsyncPixelIterator(PixelIterator): 274 | def __init__(self, qwen_vr, vr, frames_per_block, video_kwargs, processor, buffer_size=3): 275 | super().__init__(qwen_vr, vr, frames_per_block, video_kwargs, processor) 276 | # Threading components 277 | self.buffer = Queue(maxsize=buffer_size) 278 | self.is_finished = False 279 | self.worker_thread = None 280 | self.exception = None 281 | 282 | def __iter__(self): 283 | # Start the background worker thread 284 | self.worker_thread = threading.Thread(target=self._background_worker, daemon=True) 285 | self.worker_thread.start() 286 | return self 287 | 288 | def __next__(self): 289 | # Get the next processed frame from the buffer 290 | while True: 291 | if self.exception: 292 | raise self.exception 293 | 294 | if not self.buffer.empty(): 295 | return self.buffer.get() 296 | 297 | if self.is_finished and self.buffer.empty(): 298 | raise StopIteration 299 | 300 | # Wait a bit before checking again 301 | time.sleep(0.01) 302 | 303 | def _background_worker(self): 304 | """Background thread that continuously processes frames""" 305 | try: 306 | while True: 307 | # Process the next frame 308 | pixels = self._process_frame() 309 | if pixels is None: 310 | break 311 | self.buffer.put(pixels) 312 | except StopIteration: 313 | self.is_finished = True 314 | except Exception as e: 315 | self.exception = e 316 | self.is_finished = True 317 | 318 | def _process_frame(self): 319 | """Process a single frame""" 320 | try: 321 | s = time.time() 322 | 323 | # Get the next frame 
324 | frames = torch.from_numpy(next(self.vr)).float()
325 | e = time.time()
326 | self.qwen_vr.total_timing += e - s
327 | #save_image_to_home(frames[8], f"img/{self.iterations}.png")
328 | s = time.time()
329 | pixels = self.processor(
330 | text="a",
331 | images=[],
332 | videos=[frames],
333 | padding=True,
334 | return_tensors="pt",
335 | **self.video_kwargs,
336 | )['pixel_values_videos']
337 | self.iterations += 1
338 | e = time.time()
339 | self.processor_timing += e - s
340 | return pixels
341 | except StopIteration:
342 | raise
343 | 
344 | def smart_nframes(
345 | ele: dict,
346 | total_frames: int,
347 | video_fps: int | float,
348 | ) -> int:
349 | """calculate the number of frames for video used for model inputs.
350 | 
351 | Args:
352 | ele (dict): a dict contains the configuration of video.
353 | support either `fps` or `nframes`:
354 | - nframes: the number of frames to extract for model inputs.
355 | - fps: the fps to extract frames for model inputs.
356 | - min_frames: the minimum number of frames of the video, only used when fps is provided.
357 | - max_frames: the maximum number of frames of the video, only used when fps is provided.
358 | total_frames (int): the original total number of frames of the video.
359 | video_fps (int | float): the original fps of the video.
360 | 
361 | Raises:
362 | ValueError: nframes should in interval [FRAME_FACTOR, total_frames].
363 | 
364 | Returns:
365 | int: the number of frames for video used for model inputs.
366 | """
367 | assert not ("fps" in ele and "nframes" in ele), "Only accept either `fps` or `nframes`"
368 | if "nframes" in ele:
369 | nframes = round_by_factor(ele["nframes"], FRAME_FACTOR)
370 | nframes = min(nframes, total_frames)
371 | nframes -= (nframes % FRAME_FACTOR)
372 | else:
373 | fps = ele.get("fps", FPS)
374 | min_frames = ceil_by_factor(ele.get("min_frames", FPS_MIN_FRAMES), FRAME_FACTOR)
375 | max_frames = floor_by_factor(ele.get("max_frames", min(FPS_MAX_FRAMES, total_frames)), FRAME_FACTOR)
376 | nframes = total_frames / video_fps * fps
377 | if nframes > total_frames:
378 | logger.warning(f"smart_nframes: nframes[{nframes}] > total_frames[{total_frames}]")
379 | nframes = min(min(max(nframes, min_frames), max_frames), total_frames)
380 | nframes = floor_by_factor(nframes, FRAME_FACTOR)
381 | if not (FRAME_FACTOR <= nframes and nframes <= total_frames):
382 | raise ValueError(f"nframes should in interval [{FRAME_FACTOR}, {total_frames}], but got {nframes}.")
383 | return nframes
384 | 
385 | def _read_video_interleaved(
386 | ele: dict,
387 | ):
388 | 
389 | video_path = ele["video"]
390 | 
391 | num_cores = int(os.environ.get("QUICKCODEC_CORES", "8"))
392 | num_intervals = int(os.environ.get("QUICKCODEC_INTERVALS", "64"))
393 | 
394 | if num_cores > os.cpu_count() - 1:  # clamp the request to the cores actually available
395 | num_cores = os.cpu_count() - 1 if os.cpu_count() - 1 > 0 else 1
396 | warnings.warn(f"QuickCodec requested more cores than the system supports, num_cores was set to {num_cores}.")
397 | 
398 | vr = InterleavedVideoReader(video_path, num_threads=num_cores, num_intervals=num_intervals)
399 | # TODO: support start_pts and end_pts
400 | if 'video_start' in ele or 'video_end' in ele:
401 | raise NotImplementedError("start_pts and end_pts are not supported in deepcodec for now.")
402 | total_frames, video_fps = len(vr), vr.get_fps()
403 | 
404 | nframes = smart_nframes(ele, total_frames=total_frames, video_fps=video_fps)
405 | idx = torch.linspace(0, total_frames - 1, nframes).round().long().tolist()
406 | #video = torch.from_numpy(vr.get_batch(idx))
407 | 
408
| sample_fps = nframes / max(total_frames, 1e-6) * video_fps 409 | 410 | return vr,idx, sample_fps 411 | 412 | def fetch_video(ele: dict, image_factor: int = IMAGE_FACTOR, return_video_sample_fps: bool = False) -> torch.Tensor | list[Image.Image]: 413 | if isinstance(ele["video"], str): 414 | vr, idx, sample_fps = _read_video_interleaved(ele) 415 | nframes, height, width = len(idx), vr.height, vr.width 416 | min_pixels = ele.get("min_pixels", VIDEO_MIN_PIXELS) 417 | total_pixels = ele.get("total_pixels", VIDEO_TOTAL_PIXELS) 418 | max_pixels = max(min(VIDEO_MAX_PIXELS, total_pixels / nframes * FRAME_FACTOR), int(min_pixels * 1.05)) 419 | max_pixels_supposed = ele.get("max_pixels", max_pixels) 420 | if max_pixels_supposed > max_pixels: 421 | logger.warning(f"The given max_pixels[{max_pixels_supposed}] exceeds limit[{max_pixels}].") 422 | max_pixels = min(max_pixels_supposed, max_pixels) 423 | if "resized_height" in ele and "resized_width" in ele: 424 | resized_height, resized_width = smart_resize( 425 | ele["resized_height"], 426 | ele["resized_width"], 427 | factor=image_factor, 428 | ) 429 | else: 430 | resized_height, resized_width = smart_resize( 431 | height, 432 | width, 433 | factor=image_factor, 434 | min_pixels=min_pixels, 435 | max_pixels=max_pixels, 436 | ) 437 | 438 | vr.height = resized_height 439 | vr.width = resized_width 440 | vr.interpolation = "LANCZOS" 441 | #vr.interpolation = "BICUBIC" 442 | vr.process(idx) 443 | 444 | if return_video_sample_fps: 445 | return vr, sample_fps, nframes 446 | return vr 447 | else: 448 | raise NotImplementedError 449 | 450 | 451 | def process_vision_info( 452 | conversations: list[dict] | list[list[dict]], 453 | return_video_kwargs: bool = False, 454 | ) -> tuple[list[Image.Image] | None, list[torch.Tensor | list[Image.Image]] | None, Optional[dict]]: 455 | 456 | vision_infos = extract_vision_info(conversations) 457 | ## Read images or videos 458 | image_inputs = [] 459 | video_inputs = [] 460 | video_sample_fps_list = [] 461 | for vision_info in vision_infos: 462 | if "image" in vision_info or "image_url" in vision_info: 463 | raise Exception("NotImplementedError") 464 | elif "video" in vision_info: 465 | video_input, video_sample_fps, nframes = fetch_video(vision_info, return_video_sample_fps=True) 466 | video_sample_fps_list.append(video_sample_fps) 467 | video_inputs.append(video_input) 468 | else: 469 | raise ValueError("image, image_url or video should in content.") 470 | if len(image_inputs) == 0: 471 | image_inputs = None 472 | if len(video_inputs) == 0: 473 | video_inputs = None 474 | if return_video_kwargs: 475 | return image_inputs, video_inputs[0], {'fps': video_sample_fps_list}, nframes 476 | return image_inputs, video_inputs[0] 477 | 478 | class QwenVideoReaderInterleaved: 479 | def __init__(self, path, processor): 480 | self.total_timing = 0 481 | self.path = path 482 | self.frames_per_block = None 483 | self.batch = None 484 | self.processor = processor 485 | 486 | def process(self, conv): 487 | 488 | # conv = [{ 489 | # "role": "user", 490 | # "content": [{"type": "video", "video": str(self.path), "max_pixels": math.inf, "fps": 2}] 491 | # }] 492 | s = time.time() 493 | self.image_inputs, self.vr, self.video_kwargs, self.nframes = process_vision_info(conv, return_video_kwargs=True) 494 | e = time.time() 495 | self.total_timing += e-s 496 | 497 | def dummy_input(self): 498 | 499 | return { 500 | "video_grid_thw" : torch.tensor(( 501 | self.nframes / self.processor.image_processor.temporal_patch_size, 502 | self.vr.height / 
self.processor.image_processor.patch_size, 503 | self.vr.width / self.processor.image_processor.patch_size 504 | ), dtype=torch.int64).unsqueeze(dim=0), 505 | "second_per_grid_ts": -1, 506 | "pixel_values_videos": None, 507 | "fps" : self.video_kwargs.get("fps", 2.0), 508 | } 509 | 510 | def dummy_video_inputs(self): 511 | return [torch.empty((self.nframes, 3, self.vr.height, self.vr.width), dtype=torch.float32)] 512 | 513 | def set_frames_per_block(self, num_frames): 514 | self.vr.frame_iter = num_frames 515 | self.frames_per_block = num_frames 516 | 517 | 518 | def get_pixel_iterator(self): 519 | # return PixelIterator(self, self.vr, self.frames_per_block,self.video_kwargs, self.processor) 520 | return AsyncPixelIterator(self, self.vr, self.frames_per_block,self.video_kwargs, self.processor) 521 | 522 | def dummy_call( 523 | self, 524 | images = None, 525 | text = None, 526 | videos = None, 527 | **kwargs, 528 | ): 529 | """ 530 | Main method to prepare for the model one or several sequences(s) and image(s). This method forwards the `text` 531 | and `kwargs` arguments to Qwen2TokenizerFast's [`~Qwen2TokenizerFast.__call__`] if `text` is not `None` to encode 532 | the text. To prepare the vision inputs, this method forwards the `vision_infos` and `kwrags` arguments to 533 | Qwen2VLImageProcessor's [`~Qwen2VLImageProcessor.__call__`] if `vision_infos` is not `None`. 534 | 535 | Args: 536 | images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`, `List[torch.Tensor]`): 537 | The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch 538 | tensor. Both channels-first and channels-last formats are supported. 539 | text (`str`, `List[str]`, `List[List[str]]`): 540 | The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings 541 | (pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set 542 | `is_split_into_words=True` (to lift the ambiguity with a batch of sequences). 543 | videos (`np.ndarray`, `torch.Tensor`, `List[np.ndarray]`, `List[torch.Tensor]`): 544 | The image or batch of videos to be prepared. Each video can be a 4D NumPy array or PyTorch 545 | tensor, or a nested list of 3D frames. Both channels-first and channels-last formats are supported. 546 | return_tensors (`str` or [`~utils.TensorType`], *optional*): 547 | If set, will return tensors of a particular framework. Acceptable values are: 548 | - `'tf'`: Return TensorFlow `tf.constant` objects. 549 | - `'pt'`: Return PyTorch `torch.Tensor` objects. 550 | - `'np'`: Return NumPy `np.ndarray` objects. 551 | - `'jax'`: Return JAX `jnp.ndarray` objects. 552 | 553 | Returns: 554 | [`BatchFeature`]: A [`BatchFeature`] with the following fields: 555 | 556 | - **input_ids** -- List of token ids to be fed to a model. Returned when `text` is not `None`. 557 | - **attention_mask** -- List of indices specifying which tokens should be attended to by the model (when 558 | `return_attention_mask=True` or if *"attention_mask"* is in `self.model_input_names` and if `text` is not 559 | `None`). 560 | - **pixel_values** -- Pixel values to be fed to a model. Returned when `images` is not `None`. 561 | - **pixel_values_videos** -- Pixel values of videos to be fed to a model. Returned when `videos` is not `None`. 562 | - **image_grid_thw** -- List of image 3D grid in LLM. Returned when `images` is not `None`. 563 | - **video_grid_thw** -- List of video 3D grid in LLM. 
Returned when `videos` is not `None`. 564 | - **second_per_grid_ts** -- List of video seconds per time grid. Returned when `videos` is not `None`. 565 | """ 566 | video_in = kwargs.pop("video_inputs") 567 | 568 | 569 | output_kwargs = self._merge_kwargs( 570 | Qwen2_5_VLProcessorKwargs, 571 | tokenizer_init_kwargs=self.tokenizer.init_kwargs, 572 | **kwargs, 573 | ) 574 | if images is not None: 575 | image_inputs = self.image_processor(images=images, videos=None, **output_kwargs["images_kwargs"]) 576 | image_grid_thw = image_inputs["image_grid_thw"] 577 | else: 578 | image_inputs = {} 579 | image_grid_thw = None 580 | 581 | if videos is not None: 582 | 583 | videos_inputs = { 584 | "video_grid_thw" : video_in["video_grid_thw"], 585 | "second_per_grid_ts": -1, 586 | "pixel_values_videos": None 587 | } 588 | 589 | video_grid_thw = video_in["video_grid_thw"] 590 | 591 | #fps = output_kwargs["videos_kwargs"].pop("fps", 2.0) 592 | fps = video_in["fps"] 593 | if isinstance(fps, (int, float)): 594 | second_per_grid_ts = [self.image_processor.temporal_patch_size / fps] * len(video_grid_thw) 595 | elif hasattr(fps, "__len__") and len(fps) == len(video_grid_thw): 596 | second_per_grid_ts = [self.image_processor.temporal_patch_size / tmp for tmp in fps] 597 | else: 598 | raise ValueError( 599 | f"The length of fps ({len(fps) if hasattr(fps, '__len__') else fps}) must be equal to the length of video_grid_thw ({len(video_grid_thw)}) or fps should be a single number." 600 | ) 601 | videos_inputs.update({"second_per_grid_ts": second_per_grid_ts}) 602 | 603 | else: 604 | videos_inputs = {} 605 | video_grid_thw = None 606 | 607 | if not isinstance(text, list): 608 | text = [text] 609 | 610 | if image_grid_thw is not None: 611 | merge_length = self.image_processor.merge_size**2 612 | index = 0 613 | for i in range(len(text)): 614 | while self.image_token in text[i]: 615 | text[i] = text[i].replace( 616 | self.image_token, 617 | "<|placeholder|>" * (image_grid_thw[index].prod() // merge_length), 618 | 1, 619 | ) 620 | index += 1 621 | text[i] = text[i].replace("<|placeholder|>", self.image_token) 622 | 623 | if video_grid_thw is not None: 624 | merge_length = self.image_processor.merge_size**2 625 | index = 0 626 | for i in range(len(text)): 627 | while self.video_token in text[i]: 628 | text[i] = text[i].replace( 629 | self.video_token, 630 | "<|placeholder|>" * (video_grid_thw[index].prod() // merge_length), 631 | 1, 632 | ) 633 | index += 1 634 | text[i] = text[i].replace("<|placeholder|>", self.video_token) 635 | 636 | text_inputs = self.tokenizer(text, **output_kwargs["text_kwargs"]) 637 | 638 | return BatchFeature(data={**text_inputs, **image_inputs, **videos_inputs}) 639 | 640 | import inspect 641 | def _get_initial_cache_position(self, *args, **kwargs): 642 | # Get the function signature 643 | sig = inspect.signature(self.old_get_initial_cache_position) 644 | 645 | # Get parameter names (excluding 'self') 646 | param_names = [param.name for param in sig.parameters.values() 647 | if param.name not in ('self', 'args', 'kwargs')] 648 | 649 | # Transform *args to **kwargs using parameter names 650 | args_as_kwargs = dict(zip(param_names, args)) 651 | 652 | # Combine with existing kwargs 653 | all_kwargs = {**args_as_kwargs, **kwargs} 654 | # Find model_kwargs in the mapped arguments 655 | model_kwargs = all_kwargs.get('model_kwargs') 656 | 657 | # Early return if cache_position already exists in model_kwargs 658 | if model_kwargs is not None and "cache_position" in model_kwargs: 659 | return model_kwargs 
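# Otherwise defer to the stock implementation. As far as the surrounding code suggests,
# this wrapper exists because the chunked prefill sets `cache_position` by hand for every
# video group, and `generate()` would otherwise recompute it and ignore the positions that
# were accumulated across groups.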
660 | return self.old_get_initial_cache_position(*args, **kwargs) 661 | 662 | def init_lvu_model(model, config: LVUConfig): 663 | """ 664 | Initialize the LVU model for Qwen 2.5 VL. 665 | - replace the decoder layer forward function with the LVU version 666 | Args: 667 | model: The model to be initialized. 668 | config: The configuration for the LVU model. 669 | """ 670 | _model = model 671 | if isinstance(_model, Qwen2_5_VLForConditionalGeneration): 672 | _model = _model.model 673 | if not hasattr(model, "get_rope_index"): 674 | # for transformers > 4.50.0 675 | model.get_rope_index = _model.get_rope_index 676 | if isinstance(_model, Qwen2_5_VLModel): 677 | if hasattr(_model, "layers"): 678 | _model = _model 679 | elif hasattr(_model, "language_model"): 680 | _model = _model.language_model 681 | else: 682 | raise ValueError("Qwen2_5_VLModel must have either `model` or `language_model` attribute.") 683 | try: 684 | decoder_layers = _model.layers 685 | except AttributeError: 686 | raise ValueError("Did not find `layers` attribute in the model. Please check your qwen2.5_vl source code and transformers version.") 687 | 688 | total_layers= len(decoder_layers) 689 | for i, layer in enumerate(decoder_layers): 690 | # Set the forward function for each decoder layer and filling the parameters in the config 691 | layer.forward = lvu_qwen25_vl_decoder_layer_forward.__get__(layer) 692 | layer.self_attn.forward = lvu_qwen25_vl_flash_attention_2_forward.__get__(layer.self_attn) 693 | layer.lvu_layer_config = LVULayerConfig(layer_idx=layer.self_attn.layer_idx, total_layers=total_layers, lvu_config=config) 694 | model.old_get_initial_cache_position = model._get_initial_cache_position 695 | model._get_initial_cache_position = _get_initial_cache_position.__get__(model) 696 | 697 | return model 698 | 699 | def run_lvu_model(self, question, video_path, **generation_kwargs): 700 | lvu_config = self.config 701 | fps = lvu_config.fps 702 | num_frames = lvu_config.num_frames 703 | extra_kwargs = lvu_config.extra_kwargs or {} 704 | max_pixels = extra_kwargs.get("max_pixels", None) 705 | min_pixels = extra_kwargs.get("min_pixels", None) 706 | 707 | video_content = { 708 | "type": "video", 709 | "video": video_path, 710 | } 711 | if max_pixels is not None: 712 | video_content["max_pixels"] = max_pixels 713 | if min_pixels is not None: 714 | video_content["min_pixels"] = min_pixels 715 | if fps is not None: 716 | video_content["fps"] = fps 717 | elif num_frames is not None: 718 | video_content["nframes"] = num_frames 719 | else: 720 | raise ValueError("Either fps or num_frames should be set.") 721 | # Messages containing a local video path and a text query 722 | messages = [ 723 | { 724 | "role": "user", 725 | "content": [ 726 | video_content, 727 | {"type": "text", "text": question} 728 | ], 729 | } 730 | ] 731 | return chat_lvu_model(self, messages, **generation_kwargs) 732 | 733 | def chat_lvu_model(self, messages, **generation_kwargs): 734 | model = self.model 735 | processor = self.processor 736 | lvu_config = self.config 737 | 738 | # Process the messages 739 | #In Qwen 2.5 VL, frame rate information is also input into the model to align with absolute time. 
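# Rough shape of the interleaved path that follows (summary added for clarity):
#   1. QwenVideoReaderInterleaved wraps deepcodec's InterleavedVideoReader and exposes
#      dummy shapes so the full prompt can be tokenized before any pixels are decoded.
#   2. AsyncPixelIterator decodes and preprocesses each block of frames in a background
#      thread and pushes the pixel values into a small Queue.
#   3. The prefill loop consumes one block at a time, so decoding and resizing of block
#      i+1 overlaps with the model forward pass on block i.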
740 | # Preparation for inference
741 | text = processor.apply_chat_template(
742 | messages, tokenize=False, add_generation_prompt=True
743 | )
744 | 
745 | start = time.time()
746 | cache_dir = lvu_config.cache_dir or "~/.cache/video_cache/qwen25_vl"
747 | vision_info = extract_vision_info(messages)
748 | assert len(vision_info) == 1, "Only one video is supported for now."
749 | video_path = Path(vision_info[0]["video"])
750 | cache_key = video_path.stem
751 | for k, v in vision_info[0].items():
752 | if k not in ['type', 'video']:
753 | cache_key += f"_{k}={v}"
754 | # cache_key = video_path.stem + "_" + hashlib.md5(str(vision_info).encode()).hexdigest()
755 | cache_dir = Path(cache_dir).expanduser()
756 | cache_file = cache_dir / f"{cache_key}.pt"
757 | cache_dir.mkdir(parents=True, exist_ok=True)
758 | if not cache_file.exists():
759 | # Interleaved processing
760 | vr = QwenVideoReaderInterleaved(video_path, processor)
761 | vr.process(messages)
762 | 
763 | # used for finding correct shapes of blocks
764 | # I think only time dim is needed ~ can probably remove other dims
765 | video_inputs = vr.dummy_video_inputs()
766 | # save to cache
767 | if lvu_config.save_video_cache:
768 | torch.save({
769 | "image_inputs": vr.image_inputs,  # populated by vr.process(messages) above
770 | "video_inputs": video_inputs,
771 | "video_kwargs": vr.video_kwargs,
772 | }, cache_file)
773 | cache_file_size_gb = cache_file.stat().st_size / (1024 ** 3)
774 | print(f"Saved video cache to {cache_file} ({cache_file_size_gb:.2f} GB)")
775 | else:
776 | print(f"Cache file {cache_file} found. Loading video frames...")
777 | results = torch.load(cache_file)
778 | image_inputs = results["image_inputs"]
779 | video_inputs = results["video_inputs"]
780 | video_kwargs = results["video_kwargs"]  # NOTE: this branch does not rebuild vr, which the dummy_call below still expects
781 | end = time.time()
782 | print(f"Preprocessing time for video: {end - start:.2f}s")
783 | 
784 | s = time.time()
785 | 
786 | processor.dummy_call = types.MethodType(dummy_call, processor)
787 | whole_inputs = processor.dummy_call(
788 | text=text,
789 | images=None,
790 | videos=[],
791 | padding=True,
792 | return_tensors="pt",
793 | **vr.video_kwargs,
794 | video_inputs = vr.dummy_input()
795 | )
796 | 
797 | e = time.time()
798 | print(f"Tokenizer time was: {e - s:.2f}s")
799 | whole_inputs = whole_inputs.to(model.device)
800 | n_video_tokens = (whole_inputs['input_ids'] == model.config.video_token_id).sum().item()
801 | video_token_idxs = (whole_inputs['input_ids'] == model.config.video_token_id).nonzero(as_tuple=True)[1]
802 | first_video_token_id_idx = video_token_idxs[0].item()
803 | last_video_token_id_idx = video_token_idxs[-1].item()
804 | position_ids, rope_deltas = model.get_rope_index(
805 | whole_inputs['input_ids'],
806 | whole_inputs.get('image_grid_thw', None),
807 | whole_inputs.get('video_grid_thw', None),
808 | whole_inputs.get('second_per_grid_ts', None),
809 | whole_inputs['attention_mask'],
810 | )
811 | model.rope_deltas = rope_deltas
812 | 
813 | assert len(video_inputs) <= 1, "Only one video is supported for now."
814 | video_group_size = lvu_config.video_group_size
815 | temporal_patch_size = processor.image_processor.temporal_patch_size
816 | if video_group_size and video_group_size % temporal_patch_size != 0:  # round up to a multiple of the temporal patch size
817 | video_group_size += temporal_patch_size - (video_group_size % temporal_patch_size)
818 | if video_group_size is not None and video_group_size > 0:
819 | video_groups = video_inputs[0].split(video_group_size)
820 | assert all(len(group) % 2 == 0 for group in video_groups), "The video group size should be even."
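# Unlike the non-interleaved qwen25_lvu path, `video_groups` here is split from a dummy
# tensor and is only used for shapes and token counts; the actual pixel values for each
# group are produced lazily by the AsyncPixelIterator returned from vr.get_pixel_iterator()
# a few lines below.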
821 | video_groups_tokens = [int(n_video_tokens * (len(group) / len(video_inputs[0]))) for group in video_groups]
822 | video_grid_thw = whole_inputs['video_grid_thw'][0]
823 | video_groups_grid_thw = []
824 | for group in video_groups:
825 | video_groups_grid_thw.append(
826 | torch.tensor(
827 | [(len(group) -1 ) // temporal_patch_size + 1,
828 | video_grid_thw[1],
829 | video_grid_thw[2]]
830 | ).unsqueeze(0)
831 | )
832 | 
833 | vr.set_frames_per_block(video_group_size)
834 | pixel_iter = vr.get_pixel_iterator()
835 | 
836 | # prepare the chunked processing of the video groups
837 | past_key_values = LVUCache()
838 | past_len = 0
839 | video_token_idxs = (whole_inputs['input_ids'] == model.config.video_token_id).nonzero(as_tuple=True)[1]
840 | first_video_token_id_idx = video_token_idxs[0].item()
841 | last_video_token_id_idx = video_token_idxs[-1].item()
842 | prompt_input_ids = whole_inputs['input_ids'][:, last_video_token_id_idx + 1:]
843 | prompt_attention_mask = whole_inputs['attention_mask'][:, last_video_token_id_idx + 1:]
844 | if lvu_config.query_based:
845 | past_key_values.set_prompt_length(prompt_input_ids.shape[1])
846 | video_groups_tokens[0] += first_video_token_id_idx # add the tokens before the first video group as well
847 | 
848 | total_prefill_time = 0
849 | 
850 | # start processing the video groups
851 | print(f"Processing a total of {vr.nframes} frames in groups of {video_group_size} frames each.")
852 | e2e_start = time.time()
853 | for i, (pixel_values_videos_groups_i) in tqdm(enumerate(pixel_iter),
854 | desc="Processing video groups", disable=not lvu_config.use_tqdm,
855 | total=vr.nframes // video_group_size,
856 | ):
857 | start_of_block_prefill = time.time()
858 | group_i_inputs = {
859 | "video_grid_thw": video_groups_grid_thw[i],
860 | "second_per_grid_ts": whole_inputs['second_per_grid_ts'],
861 | "pixel_values_videos": pixel_values_videos_groups_i,
862 | }
863 | group_i_inputs = BatchFeature(data=group_i_inputs)
864 | group_i_inputs['input_ids'] = whole_inputs['input_ids'][:, past_len:past_len + video_groups_tokens[i]]
865 | group_i_inputs['attention_mask'] = whole_inputs['attention_mask'][:, past_len:past_len + video_groups_tokens[i]]
866 | if lvu_config.query_based:
867 | group_i_inputs['input_ids'] = torch.cat((group_i_inputs['input_ids'], prompt_input_ids), dim=1)
868 | group_i_inputs['attention_mask'] = torch.cat((group_i_inputs['attention_mask'], prompt_attention_mask), dim=1)
869 | 
870 | group_i_inputs['cache_position'] = torch.arange(group_i_inputs['input_ids'].shape[1], dtype=torch.int64, device=model.device) + past_len
871 | group_i_inputs['position_ids'] = position_ids[:, :, past_len:past_len + group_i_inputs['input_ids'].shape[1]]
872 | past_len += video_groups_tokens[i] # only the video group tokens are counted, prompt tokens are not counted
873 | group_i_inputs = group_i_inputs.to(model.device)
874 | group_i_inputs['use_cache'] = True
875 | if lvu_config.adaptive_local_attention:
876 | group_i_inputs['past_key_values'] = past_key_values
877 | with torch.no_grad():
878 | outputs = model(**group_i_inputs)
879 | # later video groups will use the past key values
880 | past_key_values = outputs.past_key_values
881 | else:
882 | with torch.no_grad():
883 | outputs = model(**group_i_inputs)
884 | if not past_key_values:
885 | # first time parsing, the video grid information is not correct
886 | past_key_values = outputs.past_key_values
887 | else:
888 | # update the past key values
889 | if isinstance(outputs.past_key_values, Cache):
890 | for i in
range(len(outputs.past_key_values)): 891 | past_key_values.update(outputs.past_key_values[i][0], outputs.past_key_values[i][1], i) 892 | else: 893 | for i in range(len(outputs.past_key_values)): 894 | for j in range(len(outputs.past_key_values[i])): 895 | past_key_values[i][j] = torch.cat((past_key_values[i][j], outputs.past_key_values[i][j]), dim=2) 896 | end_of_block_prefill_time = time.time() 897 | total_prefill_time += end_of_block_prefill_time - start_of_block_prefill 898 | # print(f"past_key_values shape: {past_key_values[0][0].shape}") 899 | assert past_len < whole_inputs['input_ids'].shape[1], "The past length should be less than the final input length." 900 | if lvu_config.query_based: 901 | # reset prompt length as all video groups are processed 902 | past_key_values.set_prompt_length(0) 903 | # end of processing the video groups 904 | start_of_decoding = time.time() 905 | 906 | final_inputs = { 907 | "input_ids": whole_inputs['input_ids'][:, past_len:], 908 | "attention_mask": whole_inputs['attention_mask'][:, past_len:], 909 | } 910 | final_inputs = BatchFeature(data=final_inputs) 911 | final_inputs['cache_position'] = torch.arange(final_inputs.input_ids.shape[1], dtype=torch.int64, device=model.device) + past_len 912 | final_inputs['position_ids'] = position_ids[:, :, past_len:] 913 | assert final_inputs['input_ids'].shape[1] == final_inputs['position_ids'].shape[2], "The input ids and position ids should have the same length, but got {} and {}".format( 914 | final_inputs['input_ids'].shape[1], final_inputs['position_ids'].shape[2]) 915 | final_inputs = final_inputs.to(model.device) 916 | final_inputs['past_key_values'] = past_key_values 917 | final_inputs['use_cache'] = True 918 | 919 | cache_enable = lvu_config.enable 920 | lvu_config.enable = lvu_config.do_top_k_for_query # determine whether to do topk or not 921 | generated_ids = model.generate(**final_inputs, **generation_kwargs) 922 | lvu_config.enable = cache_enable 923 | end_of_decoding = time.time() 924 | decoding_time = end_of_decoding - start_of_decoding 925 | 926 | e2e_end = time.time() 927 | e2e_time = e2e_end - e2e_start 928 | 929 | print(f"total time spent fetching frames was: {vr.total_timing}") 930 | print(f"total time spent on processor was: {pixel_iter.processor_timing}") 931 | print(f"total time spent on prefill was: {total_prefill_time}") 932 | print(f"total time spent on decoding was: {decoding_time}") 933 | print(f"total time spent on e2e fetching and decoding was: {e2e_time}") 934 | print(f"Time saved by interleaved processing was: {vr.total_timing + pixel_iter.processor_timing + total_prefill_time + decoding_time - e2e_time}") 935 | 936 | generated_ids_trimmed = [ 937 | out_ids[len(in_ids) :] for in_ids, out_ids in zip(final_inputs.input_ids, generated_ids) 938 | ] 939 | output_text = processor.batch_decode( 940 | generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False 941 | ) 942 | return output_text 943 | 944 | sys.modules["qwen_vl_utils.vision_process"].smart_nframes = smart_nframes 945 | --------------------------------------------------------------------------------