├── .gitignore ├── .gitmodules ├── ARC_Hunyuan_Video_7B.pdf ├── ARIAL.TTF ├── License.txt ├── README.md ├── config ├── zero_stage1_config.json ├── zero_stage2_config.json └── zero_stage3_config.json ├── examples ├── demo1.mp4 ├── demo2.mp4 ├── demo3.mov └── temp.mp3 ├── figures ├── README.md ├── method.jpg ├── shortvid-bench.jpg └── teaser.jpg ├── model_train ├── dist_utils.py ├── model │ └── modeling_arc_hunyuan_video.py ├── patch │ ├── __init__.py │ ├── pad_data_collator.py │ ├── train_dataloader_patch.py │ └── train_sampler_patch.py └── train │ ├── __init__.py │ ├── arc_hunyuan_video_finetune.py │ ├── constants.py │ ├── dataset.py │ └── dataset_packed.py ├── model_vllm ├── __init__.py ├── __pycache__ │ ├── __init__.cpython-310.pyc │ ├── __init__.cpython-311.pyc │ ├── hunyuan.cpython-310.pyc │ ├── hunyuan.cpython-311.pyc │ ├── hunyuan_video.cpython-310.pyc │ ├── hunyuan_video.cpython-311.pyc │ ├── monkey_patch_mrope.cpython-310.pyc │ ├── monkey_patch_mrope.cpython-311.pyc │ ├── video_audio_encoder.cpython-310.pyc │ ├── video_audio_encoder.cpython-311.pyc │ ├── video_audio_llm.cpython-310.pyc │ └── video_audio_llm.cpython-311.pyc ├── hunyuan.py ├── hunyuan_video.py ├── monkey_patch_mrope.py ├── setup_vllm_env.sh ├── video_audio_encoder.py └── video_audio_llm.py ├── requirements.txt ├── scripts └── arc_hunyuan_video_full_finetune.sh ├── sft_data ├── audios_mp3 │ ├── a3545skvqbz.mp3 │ ├── c3522vbgwaw.mp3 │ ├── e3556d48uo0.mp3 │ ├── q35134y59x8.mp3 │ ├── u3519j52lb4.mp3 │ ├── v3524wr6l4l.mp3 │ ├── w33698kgs05.mp3 │ ├── x3551nmkn8o.mp3 │ ├── x3555e2g3t8.mp3 │ └── z1468vawe14.mp3 ├── sft_jb_sp_abs_10.jsonl ├── sft_jb_sp_kd_10.json └── videos │ ├── a3545skvqbz.mp4 │ ├── c3522vbgwaw.mp4 │ ├── e3556d48uo0.mp4 │ ├── q35134y59x8.mp4 │ ├── u3519j52lb4.mp4 │ ├── v3524wr6l4l.mp4 │ ├── w33698kgs05.mp4 │ ├── x3551nmkn8o.mp4 │ ├── x3555e2g3t8.mp4 │ └── z1468vawe14.mp4 ├── video_inference.py ├── video_inference_sft.py └── video_inference_vllm.py /.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__ 2 | -------------------------------------------------------------------------------- /.gitmodules: -------------------------------------------------------------------------------- 1 | [submodule "model_vllm/vllm"] 2 | path = model_vllm/vllm 3 | url = https://github.com/TencentARC/vllm.git/ 4 | -------------------------------------------------------------------------------- /ARC_Hunyuan_Video_7B.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TencentARC/ARC-Hunyuan-Video-7B/a061f1bf6a4d75f032759fec69c80c71077c3d0c/ARC_Hunyuan_Video_7B.pdf -------------------------------------------------------------------------------- /ARIAL.TTF: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TencentARC/ARC-Hunyuan-Video-7B/a061f1bf6a4d75f032759fec69c80c71077c3d0c/ARIAL.TTF -------------------------------------------------------------------------------- /License.txt: -------------------------------------------------------------------------------- 1 | Tencent is pleased to support the open source community by making ARC-Hunyuan-Video-7B available. 2 | 3 | Copyright (C) 2025 Tencent. All rights reserved. 4 | 5 | ARC-Hunyuan-Video-7B is licensed under the License Terms of ARC-Hunyuan-Video-7B. 
6 | 7 | For avoidance of doubts, ARC-Hunyuan-Video-7B refers to the inference code made publicly available by Tencent in accordance with ARC-Hunyuan-Video-7B in this repository. 8 | 9 | License Terms of ARC-Hunyuan-Video-7B: 10 | -------------------------------------------------------------------- 11 | Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: 12 | 13 | - You agree to use the ARC-Hunyuan-Video-7B only for academic purposes, and refrain from using it for any commercial or production purposes under any circumstances. 14 | 15 | - The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. 16 | 17 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 18 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # ARC-Hunyuan-Video-7B 2 | 3 | [![arXiv](https://img.shields.io/badge/arXiv-2507.20939-b31b1b.svg)](https://arxiv.org/abs/2507.20939) 4 | [![Demo](https://img.shields.io/badge/ARC-Demo-blue)](https://arc.tencent.com/en/ai-demos/multimodal) 5 | [![Static Badge](https://img.shields.io/badge/Model-Huggingface-yellow)](https://huggingface.co/TencentARC/ARC-Hunyuan-Video-7B) 6 | [![Static Badge](https://img.shields.io/badge/Model-Huggingface-yellow)](https://huggingface.co/TencentARC/ARC-Qwen-Video-7B) 7 | [![Static Badge](https://img.shields.io/badge/Model-Huggingface-yellow)](https://huggingface.co/TencentARC/ARC-Qwen-Video-7B-Narrator) 8 | [![Blog](https://img.shields.io/badge/ARC-Blog-green)](https://tencentarc.github.io/posts/arc-video-announcement/) 9 | [![Benchmark](https://img.shields.io/badge/ShortVid-Bench-orange)](https://huggingface.co/datasets/TencentARC/ShortVid-Bench) 10 | 11 | 12 | Please note that in our Demo, ARC-Hunyuan-Video-7B is the model consistent with the model checkpoint and the one described in the paper, while ARC-Hunyuan-Video-7B-V0 only supports video description and summarization in Chinese. 13 | Due to API file size limits, our demo uses compressed input video resolutions, which may cause slight performance differences from the paper. For original results, please run locally. 14 | 15 | 16 | ## Introduction 17 | 18 | We introduce **ARC-Hunyuan-Video-7B**, a powerful multimodal model designed for _understanding real-world short videos_. 19 | Understanding user-generated videos is actually challenging due to their complex visual elements, high 20 | information density in both visuals and audio, and fast pacing that focuses on emotional expression and viewpoint delivery. 
21 | To address this challenge, ARC-Hunyuan-Video-7B 22 | processes visual, audio, and textual signals end-to-end for a deep, structured understanding of video by integrating and reasoning over multimodal cues. 23 | Stress test reports show an inference time of just 10 seconds for a one-minute video on an H20 GPU, yielding an average of 500 tokens, with 24 | inference accelerated by the vLLM framework. 25 | 26 | Compared to prior art, we introduce a new paradigm of **Structured Video Comprehension**, with capabilities including: 27 | 28 | - **Deep Understanding of Real-World Short Videos:** ARC-Hunyuan-Video-7B excels at analyzing user-generated content from platforms like WeChat Channels and TikTok. It goes beyond surface-level descriptions to grasp the creator's intent, emotional expression, and core message by processing complex visual elements, dense audio cues, and rapid pacing. 29 | - **Synchronized Audio-Visual Reasoning:** The synchronization of raw visual and audio signals allows our model to answer complex questions that are impossible to solve with only one modality, such as understanding humor in a skit or details in a product review. 30 | - **Precise Temporal Awareness:** ARC-Hunyuan-Video-7B knows not just _what_ happens, but _when_ it happens. It supports multi-granularity timestamped captioning, temporal video grounding, and detailed event summarization, making it perfect for applications like video search, highlight generation, and content analysis. 31 | - **Advanced Reasoning and Application Versatility:** Leveraging a comprehensive multi-stage training regimen including Reinforcement Learning (RL), ARC-Hunyuan-Video-7B demonstrates strong reasoning capabilities. It supports zero-shot or few-shot fine-tuning for diverse downstream applications like video tagging, recommendation, and retrieval. 32 | 33 | The model is capable of multi-granularity timestamped video captioning and summarization, open-ended video question answering, temporal video grounding, and 34 | video reasoning, as shown below. 35 | 36 |
<div align=center><img src="figures/teaser.jpg"></div>
39 | 40 | Specifically, ARC-Hunyuan-Video-7B is built on top of the Hunyuan-7B vision-language model with the following key designs to meet the requirements of effective structured video comprehension: 41 | 42 | - An extra audio encoder with fine-grained visual-audio synchronization for temporally aligned visual-audio inputs 43 | - A timestamp overlay mechanism on visual frames that explicitly provides the model with temporal awareness 44 | - Millions of real-world videos with a totally automated bootstrapped annotation pipeline 45 | - A comprehensive training regimen based on the finding that grounding the model in objective 46 | tasks with RL is key to unlocking high-quality, subjective understanding 47 | 48 |
<div align=center><img src="figures/method.jpg"></div>
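For instance, the timestamp overlay from the design points above is implemented in `model_train/train/dataset.py` (see `add_timestamp_to_frame` further down in this repository); the condensed sketch below illustrates the idea, and the helper name `overlay_interval` is introduced here purely for illustration.

```python
# Condensed illustration of the timestamp overlay: each sampled frame is stamped with the
# "HH:MM:SS-HH:MM:SS" interval it was drawn from, so the model can read absolute time
# directly off the pixels. Simplified from sec2hms / add_timestamp_to_frame in
# model_train/train/dataset.py.
from PIL import Image, ImageDraw, ImageFont

def sec2hms(seconds: float) -> str:
    s = int(round(seconds))
    return f"{s // 3600:02d}:{(s % 3600) // 60:02d}:{s % 60:02d}"

def overlay_interval(frame: Image.Image, start_sec: float, end_sec: float) -> Image.Image:
    draw = ImageDraw.Draw(frame)
    font = ImageFont.truetype("ARIAL.TTF", int(frame.height * 0.05))  # font file shipped in the repo root
    text = f"{sec2hms(start_sec)}-{sec2hms(end_sec)}"
    x0, y0, x1, y1 = draw.textbbox((0, 0), text, font=font)
    w, h = x1 - x0, y1 - y0
    x, y = frame.width - w - 20, 20  # top-right corner with a small margin
    draw.rectangle([x - 10, y - 10, x + w + 10, y + h + 10], fill=(0, 0, 0))  # dark backdrop for legibility
    draw.text((x, y), text, fill=(255, 255, 255), font=font)
    return frame
```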
51 | 52 | ## ARC-Qwen-Video-7B 53 | In this version, we have switched the base model from the Hunyuan VLM to [Qwen2.5-VL-7B-Instruct](https://huggingface.co/Qwen/Qwen2.5-VL-7B-Instruct) and introduce [ARC-Qwen-Video-7B](https://huggingface.co/TencentARC/ARC-Qwen-Video-7B). We used the same training data and training stages. Please refer to the `arc-qwen-video` branch for details. 54 | 55 | We are also introducing a new model, [ARC-Qwen-Video-7B-Narrator](https://huggingface.co/TencentARC/ARC-Qwen-Video-7B-Narrator). It can output **timestamped video descriptions, speaker identities, and the specific ASR (Automatic Speech Recognition) content**. By processing its output with an external LLM, you can obtain more comprehensive structured information as follows (click to watch the video): 56 | 57 | [视频](https://www.youtube.com/watch?v=Bz1T4wCuWc8) 58 | 59 | > ### 视频概述 60 | > 61 | > 这是一个喜剧短片，讲述了一位丈夫藏在棉衣里的私房钱被妻子意外发现，并误以为是丈夫准备的“惊喜”礼物。视频通过夫妻二人的一通电话，生动展现了丈夫从悠闲自得，到震惊错愕，再到崩溃无奈的全过程，充满了戏剧性的反转和幽默感。 62 | > 63 | > ### 情节发展分解 64 | > 65 | > 视频情节围绕一通电话展开，以下是详细的时间线、场景、说话人和对话内容： 66 | >
> | 时间戳 | 场景描述 | 说话人 | 对话内容 (ASR) |
> | :--- | :--- | :--- | :--- |
> | 0:00 - 0:05 | 丈夫头戴浴帽，围着浴巾，在室内泳池边悠闲地自拍。 | 无 | (无对话) |
> | 0:05 - 0:10 | 镜头切换：妻子在服装店里，满脸幸福地给丈夫打电话。 | 妻子 | “哎，老公，老公，我爱你爱你，爱死你了，么么么。” |
> | 0:10 - 0:18 | 丈夫接起电话，对妻子的热情感到好奇，妻子则兴奋地揭晓了“惊喜”。 | 丈夫 | “哎，怎么了你这是，这么高兴啊？” |
> | | | 妻子 | “今天我在我的棉衣兜里，发现了你给我的惊喜，一万元哟。” |
> | 0:18 - 0:27 | 听到“一万元”，丈夫表情瞬间凝固，从疑惑变为震惊和懊悔，但仍强装镇定。 | 丈夫 | “啊？好啊，你你你你开心高兴就行。” |
> | 0:27 - 0:34 | 妻子开心地告知钱的用途，丈夫的表情彻底僵住，震惊加剧。 | 妻子 | “我当然高兴啊，我用它买了一件新衣裳，等晚上回去穿给你看啊。” |
> | 0:34 - 0:46 | 丈夫确认钱已被花掉，情绪崩溃。妻子则认为是丈夫授权的，丈夫忍不住骂了一句。 | 丈夫 | “你已经给买成衣服了？” |
> | | | 妻子 | “当然啦，不是你说的吗？说买我自己喜欢的东西。老公，你真是太好了。” |
> | | | 丈夫 | “你真是败家娘们儿啊你。” |
> | 0:46 - 0:59 | 妻子察觉丈夫语气不对，丈夫立刻改口掩饰，并催促妻子早点回家。 | 妻子 | “什么，老公，你说什么？” |
> | | | 丈夫 | “啊？我说好啊，你漂亮我高兴。” |
> | | | 妻子 | “你说的，老公。你今天呀，一定要早点回来哟，我等你哟。” |
> | | | 丈夫 | “行行行行行。” |
145 | > 146 | > ### 人物与核心冲突 147 | > 148 | > #### 1. 人物分析 149 | > 150 | > 丈夫: 151 | > 行为: 藏私房钱,事发后极力掩饰自己的真实情绪(心痛、懊悔)。 152 | > 心理变化: 悠闲 -> 疑惑 -> 震惊 -> 崩溃 -> 无奈接受。 153 | > 特点: 爱面子,对妻子既有爱意也有无奈,典型的“妻管严”形象。 154 | > 155 | > 妻子: 156 | > 行为: 发现钱后,认为是丈夫的爱意表达,并迅速将其消费。 157 | > 心理变化: 全程处于发现“惊喜”的幸福和喜悦中。 158 | > 特点: 天真、消费果断,对丈夫充满信任和爱意。 159 | > 160 | > #### 2. 核心冲突 161 | > 162 | > 视频的核心冲突在于 “信息的严重不对等” 所造成的戏剧性误会: 163 | > 164 | > * 丈夫视角: 辛苦攒下的 $10,000$ 元私房钱被意外发现并花掉,是一场“惊吓”。 165 | > * 妻子视角: 丈夫精心准备的 $10,000$ 元浪漫基金,是一份巨大的“惊喜”。 166 | > 167 | > 这个误会推动了整个故事的发展,丈夫的“打碎牙往肚里咽”和妻子的“理所当然的幸福”形成了强烈的喜剧反差,制造了密集的笑点。 168 | > 169 | > ### 总结 170 | > 171 | > 该视频通过一个关于“私房钱”的常见家庭情景,巧妙地构建了一个充满反转和幽默的故事。它利用戏剧性讽刺(观众和丈夫知道真相,而妻子蒙在鼓里)的手法,精准捕捉了丈夫在突发状况下的复杂心理活动。整个过程不仅笑料百出,也含蓄地探讨了夫妻间的沟通、信任和金钱观等话题,容易引发观众的共鸣和讨论。 172 | 173 | 174 | ## News 175 | - 2025.09.19: We release [ARC-Qwen-Video-7B](https://huggingface.co/TencentARC/ARC-Qwen-Video-7B), which switched the base model from hunyuan VLM to [Qwen2.5-VL-7B-Instruct](https://huggingface.co/Qwen/Qwen2.5-VL-7B-Instruct). We also release [ARC-Qwen-Video-7B-Narrator](https://huggingface.co/TencentARC/ARC-Qwen-Video-7B-Narrator), which can output timestamped video descriptions, speaker identities, and the specific ASR (Automatic Speech Recognition) content. Please refere to the `arc-qwen-video` branch for details. 176 | - 2025.08.05: We release [ShortVid-Bench](https://huggingface.co/datasets/TencentARC/ShortVid-Bench), a specialized, human-annotated benchmark with multiple-choice questions for evaluating short-video understanding. 177 | - 2025.07.29: We release the training code for instruction tuning. 178 | - 2025.07.25: We release the [model checkpoint](https://huggingface.co/TencentARC/ARC-Hunyuan-Video-7B) and inference code of ARC-Hunyuan-Video-7B including [vLLM](https://github.com/vllm-project/vllm) version. 179 | - 2025.07.25: We release the [API service](https://arc.tencent.com/zh/document/ARC-Hunyuan-Video-7B) of ARC-Hunyuan-Video-7B, which is supported by [vLLM](https://github.com/vllm-project/vllm). We release two versions: one is V0, which only supports video description and summarization in Chinese; the other is the version consistent with the model checkpoint and the one described in the paper. 180 | 181 | ## TODOs 182 | 183 | - [x] Relase ShortVid-Bench, a specialized, human-annotated benchmark with multiple-choice questions 184 | - [x] Release training code for instruction tuning 185 | 186 | ## Usage 187 | 188 | ### Dependencies 189 | Our inference can be performed on a single NVIDIA A100 40GB GPU. 
190 | 191 | ### Installation 192 | 193 | Clone the repo and install dependent packages 194 | 195 | ```bash 196 | git clone https://github.com/TencentARC/ARC-Hunyuan-Video-7B.git 197 | cd ARC-Hunyuan-Video-7B 198 | # Install torch 2.6.0 based on your CUDA version 199 | # CUDA 11.8 200 | pip install torch==2.6.0 torchvision==0.21.0 torchaudio==2.6.0 --index-url https://download.pytorch.org/whl/cu118 201 | # CUDA 12.4 202 | pip install torch==2.6.0 torchvision==0.21.0 torchaudio==2.6.0 --index-url https://download.pytorch.org/whl/cu124 203 | # CUDA 12.6 204 | pip install torch==2.6.0 torchvision==0.21.0 torchaudio==2.6.0 --index-url https://download.pytorch.org/whl/cu126 205 | 206 | pip install -r requirements.txt 207 | pip install git+https://github.com/liyz15/transformers.git@arc_hunyuan_video 208 | 209 | # Install flash-attention based on your python version 210 | # If you are unable to install flash-attention, you can modify attn_implementation to "sdpa" in video_inference.py 211 | pip install https://github.com/Dao-AILab/flash-attention/releases/download/v2.7.4.post1/flash_attn-2.7.4.post1+cu12torch2.6cxx11abiFALSE-cp311-cp311-linux_x86_64.whl 212 | 213 | 214 | # (Optional) For vllm, please follow the instructions below, 215 | git submodule update --init --recursive 216 | cd model_vllm/vllm/ 217 | export SETUPTOOLS_SCM_PRETEND_VERSION="0.8.5" 218 | wget https://wheels.vllm.ai/ed2462030f2ccc84be13d8bb2c7476c84930fb71/vllm-1.0.0.dev-cp38-abi3-manylinux1_x86_64.whl 219 | export VLLM_PRECOMPILED_WHEEL_LOCATION=$(pwd)/vllm-1.0.0.dev-cp38-abi3-manylinux1_x86_64.whl 220 | pip install --editable . 221 | # Install flash-attention if you haven't installed it 222 | pip install https://github.com/Dao-AILab/flash-attention/releases/download/v2.7.4.post1/flash_attn-2.7.4.post1+cu12torch2.6cxx11abiFALSE-cp311-cp311-linux_x86_64.whl 223 | ``` 224 | 225 | ### Model Weights 226 | 227 | - Download [ARC-Hunyuan-Video-7B](https://huggingface.co/TencentARC/ARC-Hunyuan-Video-7B) including ViT and LLM and the original [whisper-large-v3](https://huggingface.co/openai/whisper-large-v3) . 228 | 229 | ### Inference 230 | 231 | ```bash 232 | # Our model currently excels at processing short videos of up to 5 minutes. 233 | # If your video is longer, we recommend following the approach used in our demo and API: 234 | # split the video into segments for inference, and then use an LLM to integrate the results. 
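# For example, a long video can first be cut into roughly 5-minute segments (ffmpeg is shown
# here as one common option; it is not bundled with this repo), each segment run through
# video_inference.py, and the per-segment outputs merged by an LLM of your choice.
# With stream copy the cuts land on keyframes, so segment lengths are approximate.
ffmpeg -i long_video.mp4 -c copy -map 0 -f segment -segment_time 300 -reset_timestamps 1 segment_%03d.mp4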
235 | ``` 236 | 237 | #### Inference without vllm 238 | 239 | ```bash 240 | cd ARC-Hunyuan-Video-7B 241 | python3 video_inference.py 242 | ``` 243 | 244 | #### Inference with vllm 245 | 246 | ```bash 247 | cd ARC-Hunyuan-Video-7B 248 | python3 video_inference_vllm.py 249 | ``` 250 | 251 | ## Training 252 | 253 | ### Installation 254 | 255 | Clone the repo and install dependent packages 256 | 257 | ```bash 258 | git clone https://github.com/TencentARC/ARC-Hunyuan-Video-7B.git 259 | cd ARC-Hunyuan-Video-7B 260 | # Install torch 2.6.0 261 | pip install torch==2.6.0 torchvision==0.21.0 torchaudio==2.6.0 --index-url https://download.pytorch.org/whl/cu124 262 | pip install -r requirements.txt 263 | pip install git+https://github.com/liyz15/transformers.git@arc_hunyuan_video 264 | 265 | # For training 266 | pip install accelerate==1.9.0 267 | # Upgrade the GCC version to 9.0 or above 268 | sudo dnf install gcc-toolset-9 269 | scl enable gcc-toolset-9 bash 270 | source /opt/rh/gcc-toolset-9/enable 271 | gcc -v 272 | ``` 273 | 274 | ### Model Weights 275 | 276 | - Download [ARC-Hunyuan-Video-7B](https://huggingface.co/TencentARC/ARC-Hunyuan-Video-7B) including ViT and LLM and the original [whisper-large-v3](https://huggingface.co/openai/whisper-large-v3) . 277 | 278 | ### Data Preparation 279 | 280 | Please follow the format of "sft_data/sft_jb_sp_kd_10.json". 281 | 282 | - "root" specifies the path of training videos (supports .mp4; videos shorter than 5 minutes yield better results). 283 | - "audio_root" specifies the path of corresponding audios (Please use the .mp3 format). You can use the code below to extract audio from a video and save it. 284 | 285 | ```bash 286 | from moviepy.editor import VideoFileClip 287 | from pydub import AudioSegment 288 | 289 | video = VideoFileClip(video_path) 290 | if video.audio is not None: 291 | video.audio.write_audiofile(audio_path, logger=None) 292 | video.audio.close() 293 | else: 294 | duration_ms = int(video.duration * 1000) 295 | silent_audio = AudioSegment.silent(duration=duration_ms) 296 | silent_audio.export(audio_path, format="mp3") 297 | video.close() 298 | ``` 299 | 300 | - "annotation" specifies the path of the annotation in the format of ".jsonl". 301 | 302 | ### Model Fully-finetune 303 | 304 | ```bash 305 | # We use DeepSpeed Zero-3 with two 98G-H20 GPUs. 306 | bash scripts/arc_hunyuan_video_full_finetune.sh 307 | ``` 308 | 309 | ### Model Inference 310 | 311 | After finishing training, the model will be saved in ${OUTPUT_DIR}. 312 | 313 | ```bash 314 | # Copy the model-related config files to the directory. 315 | cd path of the downloaded ARC-Hunyuan-Video-7B 316 | cp generation_config.json preprocessor_config.json ${OUTPUT_DIR}/checkpoint-500/. 317 | 318 | cd ARC-Hunyuan-Video-7B 319 | # Modify the prompt based on your fine-tuning data, and specify the path of the fine-tuned model. 320 | python3 video_inference_sft.py 321 | ``` 322 | 323 | ## API service 324 | 325 | We also provide access to the model via API, which is supported by [vLLM](https://github.com/vllm-project/vllm). For details, please refer to the [documentation](https://arc.tencent.com/zh/document/ARC-Hunyuan-Video-7B). 
326 | 327 | We release two versions: one is V0, which only supports video description and summarization in Chinese; the other is the version consistent with the model checkpoint and the one described in the paper, which is capable of multi-granularity timestamped video captioning and summarization, open-ended video question answering, temporal video grounding, and video reasoning (it supports Chinese and English videos and particularly excels at Chinese). 328 | For videos longer than 5 minutes, we only support structured descriptions. We process these videos in 5-minute segments and use an LLM to integrate the inference results. 329 | 330 | If you only need to understand and summarize short Chinese videos, we recommend using the V0 version. 331 | 332 | Due to video file size limitations imposed by the deployment API, we compressed input video resolutions for our online demo and API services. Consequently, model performance in these interfaces may slightly deviate from the results reported in the paper. To reproduce the original performance, we recommend local inference. 333 | 334 | 335 | ## ShortVid-Bench 336 | 337 | Existing benchmarks often fall short in capturing the nuanced complexities 338 | of user-generated content. To rigorously evaluate a model’s ability to **understand real-world short videos**, 339 | we construct a specialized benchmark named **ShortVid-Bench**. Specifically, we develop an automated pipeline 340 | to generate multi-dimensional questions for each video, targeting capabilities that signify a deep, holistic 341 | comprehension integrating both visual and audio cues. These dimensions include: 342 | - Temporal Reasoning and Localization 343 | - Affective Intent Classification 344 | - Creator Intent Taxonomy 345 | - Narrative Comprehension 346 | - Humor & Meme Deconstruction 347 | - Creative Innovation Analysis 348 | 349 | For objective assessment, we employ a multiple-choice question (MCQ) format following previous work. Each question is carefully curated by human annotators who 350 | provide the ground-truth answer and design challenging, plausible distractors. Collectively, these dimensions, with a total of 1,000 multiple-choice questions, 351 | push the evaluation beyond mere descriptive captioning, demanding a genuine comprehension of the video’s 352 | context, intent, and narrative. 353 | 354 |
<div align=center><img src="figures/shortvid-bench.jpg"></div>
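For reference, the snippet below sketches how accuracy on such an MCQ benchmark can be computed. The split name, the field names (`video`, `question`, `options`, `answer`), and the `predict` callback are illustrative assumptions rather than the dataset's actual schema; please consult the ShortVid-Bench dataset card for the real format.

```python
# Minimal MCQ-scoring sketch; the field names and split are assumptions, not the actual schema.
from datasets import load_dataset

def evaluate(predict, split="test"):
    ds = load_dataset("TencentARC/ShortVid-Bench", split=split)
    letters = "ABCDEFGH"
    correct = 0
    for sample in ds:
        # Present the options as "A. ...", "B. ..." and ask the model to answer with a letter.
        prompt = sample["question"] + "\n" + "\n".join(
            f"{letters[i]}. {opt}" for i, opt in enumerate(sample["options"])
        )
        pred = predict(sample["video"], prompt)  # e.g. a wrapper around video_inference.py
        correct += int(pred.strip().upper().startswith(sample["answer"]))
    return correct / len(ds)
```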
357 | 358 | ### Model Performance 359 | | Model | fps | #frames | think | ShortVid-Bench | 360 | | :--- | :--- | :--- | :--- | :--- | 361 | | Qwen2.5-VL-7B-Instruct | 1.0 | 150 | × | 69.3 | 362 | | Qwen2.5-Omni-7B | 1.0 | 150 | × | 69.7 | 363 | | Keye-VL-8B | 1.0 | 150 | ✓ | 56.3 | 364 | | ARC-Hunyuan-Video-7B | 1.0 | 150 | ✓ | **73.0** | 365 | 366 | 367 | Please note that the results in the table above are different from those in 368 | ARC-Hunyuan-Video-7B. 369 | This is because, after releasing the technical report, we expanded the benchmark dataset to 1,000 samples, whereas the results in the paper were based on 400 samples. 370 | 371 | 372 | ## Future Work 373 | 374 | We observe that incorporating generic video datasets during training may inadvertently compromise the model's capacity for real-world video understanding, potentially due to domain shift or noise introduced by non-real-world samples. To address this limitation, we plan to develop a dedicated model trained exclusively on rigorously curated real-world video data. 375 | 376 | ## Citation 377 | 378 | If you find the work helpful, please consider citing: 379 | 380 | ```bash 381 | @article{ge2025arc, 382 | title={ARC-Hunyuan-Video-7B: Structured Video Comprehension of Real-World Shorts}, 383 | author={Ge, Yuying and Ge, Yixiao and Li, Chen and Wang, Teng and Pu, Junfu and Li, Yizhuo and Qiu, Lu and Ma, Jin and Duan, Lisheng and Zuo, Xinyu and others}, 384 | journal={arXiv preprint arXiv:2507.20939}, 385 | year={2025} 386 | } 387 | ``` 388 | 389 | ## Acknowledge 390 | 391 | Our training code is built upon [InternVL](https://github.com/OpenGVLab/InternVL). Thanks for their excellent work! 392 | -------------------------------------------------------------------------------- /config/zero_stage1_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "zero_optimization": { 3 | "stage": 1, 4 | "allgather_partitions": true, 5 | "allgather_bucket_size": 1e9, 6 | "overlap_comm": true, 7 | "reduce_scatter": true, 8 | "reduce_bucket_size": 1e9, 9 | "contiguous_gradients": true 10 | }, 11 | "fp16": { 12 | "enabled": "auto", 13 | "auto_cast": true, 14 | "loss_scale": 0, 15 | "initial_scale_power": 32, 16 | "loss_scale_window": 1000, 17 | "hysteresis": 2, 18 | "min_loss_scale": 1 19 | }, 20 | "bf16": { 21 | "enabled": "auto" 22 | }, 23 | "optimizer": { 24 | "type": "AdamW", 25 | "params": { 26 | "lr": "auto", 27 | "betas": [ 28 | 0.9, 29 | 0.999 30 | ], 31 | "eps": 1e-8, 32 | "weight_decay": "auto" 33 | } 34 | }, 35 | "gradient_accumulation_steps": "auto", 36 | "gradient_clipping": "auto", 37 | "steps_per_print": 2000, 38 | "train_batch_size": "auto", 39 | "train_micro_batch_size_per_gpu": "auto", 40 | "wall_clock_breakdown": true 41 | } 42 | -------------------------------------------------------------------------------- /config/zero_stage2_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "zero_optimization": { 3 | "stage": 2, 4 | "allgather_partitions": true, 5 | "allgather_bucket_size": 1e8, 6 | "overlap_comm": true, 7 | "reduce_scatter": true, 8 | "reduce_bucket_size": 1e8, 9 | "contiguous_gradients": true 10 | }, 11 | "bf16": { 12 | "enabled": "auto" 13 | }, 14 | "optimizer": { 15 | "type": "AdamW", 16 | "params": { 17 | "lr": "auto", 18 | "betas": [ 19 | 0.9, 20 | 0.999 21 | ], 22 | "eps": 1e-8, 23 | "weight_decay": "auto" 24 | } 25 | }, 26 | "gradient_accumulation_steps": "auto", 27 | "gradient_clipping": "auto", 28 | 
"steps_per_print": 2000, 29 | "train_batch_size": "auto", 30 | "train_micro_batch_size_per_gpu": "auto", 31 | "wall_clock_breakdown": false 32 | } 33 | -------------------------------------------------------------------------------- /config/zero_stage3_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "zero_optimization": { 3 | "stage": 3, 4 | "overlap_comm": true, 5 | "contiguous_gradients": true, 6 | "sub_group_size": 1e9, 7 | "reduce_bucket_size": 1e9, 8 | "stage3_prefetch_bucket_size": 1e9, 9 | "stage3_param_persistence_threshold": 1e7, 10 | "stage3_max_live_parameters": 1e9, 11 | "stage3_max_reuse_distance": 1e9, 12 | "stage3_gather_16bit_weights_on_model_save": true 13 | }, 14 | "fp16": { 15 | "enabled": "auto", 16 | "auto_cast": true, 17 | "loss_scale": 0, 18 | "initial_scale_power": 32, 19 | "loss_scale_window": 1000, 20 | "hysteresis": 2, 21 | "min_loss_scale": 1 22 | }, 23 | "bf16": { 24 | "enabled": "auto" 25 | }, 26 | "optimizer": { 27 | "type": "AdamW", 28 | "params": { 29 | "lr": "auto", 30 | "betas": [ 31 | 0.9, 32 | 0.999 33 | ], 34 | "eps": 1e-8, 35 | "weight_decay": "auto" 36 | } 37 | }, 38 | "gradient_accumulation_steps": "auto", 39 | "gradient_clipping": "auto", 40 | "steps_per_print": 2000, 41 | "train_batch_size": "auto", 42 | "train_micro_batch_size_per_gpu": "auto", 43 | "wall_clock_breakdown": true 44 | } 45 | -------------------------------------------------------------------------------- /examples/demo1.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TencentARC/ARC-Hunyuan-Video-7B/a061f1bf6a4d75f032759fec69c80c71077c3d0c/examples/demo1.mp4 -------------------------------------------------------------------------------- /examples/demo2.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TencentARC/ARC-Hunyuan-Video-7B/a061f1bf6a4d75f032759fec69c80c71077c3d0c/examples/demo2.mp4 -------------------------------------------------------------------------------- /examples/demo3.mov: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TencentARC/ARC-Hunyuan-Video-7B/a061f1bf6a4d75f032759fec69c80c71077c3d0c/examples/demo3.mov -------------------------------------------------------------------------------- /examples/temp.mp3: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TencentARC/ARC-Hunyuan-Video-7B/a061f1bf6a4d75f032759fec69c80c71077c3d0c/examples/temp.mp3 -------------------------------------------------------------------------------- /figures/README.md: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /figures/method.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TencentARC/ARC-Hunyuan-Video-7B/a061f1bf6a4d75f032759fec69c80c71077c3d0c/figures/method.jpg -------------------------------------------------------------------------------- /figures/shortvid-bench.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TencentARC/ARC-Hunyuan-Video-7B/a061f1bf6a4d75f032759fec69c80c71077c3d0c/figures/shortvid-bench.jpg -------------------------------------------------------------------------------- 
/figures/teaser.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TencentARC/ARC-Hunyuan-Video-7B/a061f1bf6a4d75f032759fec69c80c71077c3d0c/figures/teaser.jpg -------------------------------------------------------------------------------- /model_train/dist_utils.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # InternVL 3 | # Copyright (c) 2024 OpenGVLab 4 | # Licensed under The MIT License [see LICENSE for details] 5 | # -------------------------------------------------------- 6 | # Modified from InternVL 7 | # Copyright (c) 2025 ARC Lab 8 | # Licensed under LICENSE [see LICENSE for details] 9 | # -------------------------------------------------------- 10 | 11 | import os 12 | import socket 13 | import subprocess 14 | from datetime import timedelta 15 | 16 | import deepspeed 17 | import torch 18 | import torch.multiprocessing as mp 19 | from torch import distributed as dist 20 | 21 | timeout = timedelta(minutes=60) 22 | 23 | 24 | def _find_free_port(): 25 | # Copied from https://github.com/facebookresearch/detectron2/blob/main/detectron2/engine/launch.py # noqa: E501 26 | sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) 27 | # Binding to port 0 will cause the OS to find an available port for us 28 | sock.bind(('', 0)) 29 | port = sock.getsockname()[1] 30 | sock.close() 31 | # NOTE: there is still a chance the port could be taken by other processes. 32 | return port 33 | 34 | 35 | def _is_free_port(port): 36 | ips = socket.gethostbyname_ex(socket.gethostname())[-1] 37 | ips.append('localhost') 38 | with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: 39 | return all(s.connect_ex((ip, port)) != 0 for ip in ips) 40 | 41 | 42 | def init_dist(launcher, backend='nccl', **kwargs): 43 | if mp.get_start_method(allow_none=True) is None: 44 | mp.set_start_method('spawn') 45 | if launcher == 'pytorch': 46 | _init_dist_pytorch(backend, **kwargs) 47 | elif launcher == 'mpi': 48 | _init_dist_mpi(backend, **kwargs) 49 | elif launcher == 'slurm': 50 | _init_dist_slurm(backend, **kwargs) 51 | else: 52 | raise ValueError(f'Invalid launcher type: {launcher}') 53 | 54 | 55 | def _init_dist_pytorch(backend, **kwargs): 56 | # TODO: use local_rank instead of rank % num_gpus 57 | rank = int(os.environ['RANK']) 58 | num_gpus = torch.cuda.device_count() 59 | torch.cuda.set_device(rank % num_gpus) 60 | # dist.init_process_group(backend=backend, **kwargs) 61 | deepspeed.init_distributed(dist_backend=backend) 62 | 63 | 64 | def _init_dist_mpi(backend, **kwargs): 65 | local_rank = int(os.environ['OMPI_COMM_WORLD_LOCAL_RANK']) 66 | torch.cuda.set_device(local_rank) 67 | if 'MASTER_PORT' not in os.environ: 68 | # 29500 is torch.distributed default port 69 | os.environ['MASTER_PORT'] = '29500' 70 | if 'MASTER_ADDR' not in os.environ: 71 | raise KeyError('The environment variable MASTER_ADDR is not set') 72 | os.environ['WORLD_SIZE'] = os.environ['OMPI_COMM_WORLD_SIZE'] 73 | os.environ['RANK'] = os.environ['OMPI_COMM_WORLD_RANK'] 74 | dist.init_process_group(backend=backend, **kwargs) 75 | 76 | 77 | def _init_dist_slurm(backend, port=None): 78 | """Initialize slurm distributed training environment. 79 | 80 | If argument ``port`` is not specified, then the master port will be system 81 | environment variable ``MASTER_PORT``. If ``MASTER_PORT`` is not in system 82 | environment variable, then a default port ``29500`` will be used. 
83 | 84 | Args: 85 | backend (str): Backend of torch.distributed. 86 | port (int, optional): Master port. Defaults to None. 87 | """ 88 | proc_id = int(os.environ['SLURM_PROCID']) 89 | ntasks = int(os.environ['SLURM_NTASKS']) 90 | node_list = os.environ['SLURM_NODELIST'] 91 | num_gpus = torch.cuda.device_count() 92 | torch.cuda.set_device(proc_id % num_gpus) 93 | addr = subprocess.getoutput( 94 | f'scontrol show hostname {node_list} | head -n1') 95 | # specify master port 96 | if port is not None: 97 | os.environ['MASTER_PORT'] = str(port) 98 | elif 'MASTER_PORT' in os.environ: 99 | pass # use MASTER_PORT in the environment variable 100 | else: 101 | # if torch.distributed default port(29500) is available 102 | # then use it, else find a free port 103 | if _is_free_port(29500): 104 | os.environ['MASTER_PORT'] = '29500' 105 | else: 106 | os.environ['MASTER_PORT'] = str(_find_free_port()) 107 | # use MASTER_ADDR in the environment variable if it already exists 108 | if 'MASTER_ADDR' not in os.environ: 109 | os.environ['MASTER_ADDR'] = addr 110 | os.environ['WORLD_SIZE'] = str(ntasks) 111 | os.environ['LOCAL_RANK'] = str(proc_id % num_gpus) 112 | os.environ['RANK'] = str(proc_id) 113 | # dist.init_process_group(backend=backend, timeout=timeout) 114 | deepspeed.init_distributed(dist_backend=backend) 115 | -------------------------------------------------------------------------------- /model_train/patch/__init__.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # InternVL 3 | # Copyright (c) 2024 OpenGVLab 4 | # Licensed under The MIT License [see LICENSE for details] 5 | # -------------------------------------------------------- 6 | 7 | from .pad_data_collator import (concat_pad_data_collator, 8 | dpo_concat_pad_data_collator, 9 | pad_data_collator) 10 | from .train_dataloader_patch import replace_train_dataloader 11 | from .train_sampler_patch import replace_train_sampler 12 | 13 | __all__ = ['replace_llama_attn_with_flash_attn', 14 | 'replace_llama_rmsnorm_with_fused_rmsnorm', 15 | 'replace_llama2_attn_with_flash_attn', 16 | 'replace_train_sampler', 17 | 'replace_train_dataloader', 18 | 'replace_internlm2_attention_class', 19 | 'replace_qwen2_attention_class', 20 | 'replace_phi3_attention_class', 21 | 'replace_llama_attention_class', 22 | 'pad_data_collator', 23 | 'dpo_concat_pad_data_collator', 24 | 'concat_pad_data_collator', 25 | 'apply_liger_kernel_to_internvit'] 26 | -------------------------------------------------------------------------------- /model_train/patch/pad_data_collator.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # InternVL 3 | # Copyright (c) 2024 OpenGVLab 4 | # Licensed under The MIT License [see LICENSE for details] 5 | # -------------------------------------------------------- 6 | # Modified from InternVL 7 | # Copyright (c) 2025 ARC Lab 8 | # Licensed under LICENSE [see LICENSE for details] 9 | # -------------------------------------------------------- 10 | 11 | import numpy as np 12 | import torch 13 | 14 | IGNORE_INDEX = -100 15 | 16 | def pad_data_collator(features, pad_id=0): 17 | 18 | first = features[0] 19 | batch = {} 20 | 21 | batch_lens = [feat['input_ids'].shape for feat in features] 22 | max_item_length = max(batch_lens)[0] 23 | for idx in range(len(features)): 24 | feat = features[idx] 25 | temp_input_ids = torch.LongTensor([pad_id] * 
max_item_length) 26 | temp_input_ids[:feat['input_ids'].shape[0]] = feat['input_ids'] 27 | feat['input_ids'] = temp_input_ids 28 | temp_labels = torch.LongTensor([IGNORE_INDEX] * max_item_length) 29 | temp_labels[:feat['labels'].shape[0]] = feat['labels'] 30 | feat['labels'] = temp_labels 31 | feat['attention_mask'] = feat['input_ids'].ne(pad_id) 32 | 33 | # Special handling for labels. 34 | # Ensure that tensor is created with the correct type 35 | # (it should be automatically the case, but let's make sure of it.) 36 | if 'label' in first and first['label'] is not None: 37 | label = first['label'].item() if isinstance(first['label'], torch.Tensor) else first['label'] 38 | dtype = torch.long if isinstance(label, int) else torch.float 39 | batch['labels'] = torch.tensor([f['label'] for f in features], dtype=dtype) 40 | elif 'label_ids' in first and first['label_ids'] is not None: 41 | if isinstance(first['label_ids'], torch.Tensor): 42 | batch['labels'] = torch.stack([f['label_ids'] for f in features]) 43 | else: 44 | dtype = torch.long if isinstance(first['label_ids'][0], int) else torch.float 45 | batch['labels'] = torch.tensor([f['label_ids'] for f in features], dtype=dtype) 46 | 47 | # Handling of all other possible keys. 48 | # Again, we will use the first element to figure out which key/values are not None for this model. 49 | for k, v in first.items(): 50 | if k not in ('label', 'label_ids') and v is not None and not isinstance(v, str): 51 | if isinstance(v, torch.Tensor): 52 | batch[k] = torch.stack([f[k] for f in features]) 53 | elif isinstance(v, np.ndarray): 54 | batch[k] = torch.tensor(np.stack([f[k] for f in features])) 55 | else: 56 | batch[k] = torch.tensor([f[k] for f in features]) 57 | return batch 58 | 59 | 60 | def concat_pad_data_collator(features, max_item_length=None, pad_id=0): 61 | 62 | first = features[0] 63 | batch = {} 64 | 65 | batch_lens = [feat['input_ids'].shape for feat in features] 66 | max_item_length = max_item_length or max(batch_lens)[0] 67 | for idx in range(len(features)): 68 | feat = features[idx] 69 | temp_input_ids = torch.LongTensor([pad_id] * max_item_length) 70 | temp_input_ids[:feat['input_ids'].shape[0]] = feat['input_ids'] 71 | feat['input_ids'] = temp_input_ids 72 | temp_labels = torch.LongTensor([IGNORE_INDEX] * max_item_length) 73 | temp_labels[:feat['labels'].shape[0]] = feat['labels'] 74 | feat['labels'] = temp_labels 75 | feat['attention_mask'] = feat['input_ids'].ne(pad_id) 76 | 77 | # if 'position_ids' in feat: 78 | # temp_position_ids = torch.LongTensor([pad_id] * max_item_length) 79 | # temp_position_ids[:feat['position_ids'].shape[0]] = feat['position_ids'] 80 | # feat['position_ids'] = temp_position_ids 81 | 82 | if 'loss_weight' in feat: 83 | temp_loss_weight = torch.FloatTensor([pad_id] * max_item_length) 84 | temp_loss_weight[:feat['loss_weight'].shape[0]] = feat['loss_weight'] 85 | feat['loss_weight'] = temp_loss_weight 86 | 87 | # Special handling for labels. 88 | # Ensure that tensor is created with the correct type 89 | # (it should be automatically the case, but let's make sure of it.) 
90 | if 'label' in first and first['label'] is not None: 91 | label = first['label'].item() if isinstance(first['label'], torch.Tensor) else first['label'] 92 | dtype = torch.long if isinstance(label, int) else torch.float 93 | batch['labels'] = torch.tensor([f['label'] for f in features], dtype=dtype) 94 | elif 'label_ids' in first and first['label_ids'] is not None: 95 | if isinstance(first['label_ids'], torch.Tensor): 96 | batch['labels'] = torch.stack([f['label_ids'] for f in features]) 97 | else: 98 | dtype = torch.long if isinstance(first['label_ids'][0], int) else torch.float 99 | batch['labels'] = torch.tensor([f['label_ids'] for f in features], dtype=dtype) 100 | 101 | # Handling of all other possible keys. 102 | # Again, we will use the first element to figure out which key/values are not None for this model. 103 | for k, v in first.items(): 104 | if k not in ('label', 'label_ids', 'pixel_values', 'image_flags') and \ 105 | v is not None and not isinstance(v, str): 106 | if isinstance(v, torch.Tensor): 107 | batch[k] = torch.stack([f[k] for f in features]) 108 | elif isinstance(v, np.ndarray): 109 | batch[k] = torch.tensor(np.stack([f[k] for f in features])) 110 | else: 111 | batch[k] = torch.tensor([f[k] for f in features]) 112 | if k in ('pixel_values', 'image_flags'): 113 | if isinstance(v, torch.Tensor): 114 | batch[k] = torch.concat([f[k] for f in features]) 115 | elif isinstance(v, np.ndarray): 116 | batch[k] = torch.concat(np.stack([f[k] for f in features])) 117 | else: 118 | batch[k] = torch.concat([f[k] for f in features]) 119 | return batch 120 | 121 | 122 | def dpo_concat_pad_data_collator(features, pad_id=0): 123 | 124 | first = features[0] 125 | batch = {} 126 | 127 | for prefix in ['chosen_', 'rejected_']: 128 | batch_lens = [feat[f'{prefix}input_ids'].shape[0] for feat in features] 129 | max_item_length = max(batch_lens) 130 | for idx in range(len(features)): 131 | feat = features[idx] 132 | temp_input_ids = torch.LongTensor([pad_id] * max_item_length) 133 | temp_input_ids[:feat[f'{prefix}input_ids'].shape[0]] = feat[f'{prefix}input_ids'] 134 | feat[f'{prefix}input_ids'] = temp_input_ids 135 | temp_labels = torch.LongTensor([IGNORE_INDEX] * max_item_length) 136 | temp_labels[:feat[f'{prefix}labels'].shape[0]] = feat[f'{prefix}labels'] 137 | feat[f'{prefix}labels'] = temp_labels 138 | feat[f'{prefix}attention_mask'] = feat[f'{prefix}input_ids'].ne(pad_id) 139 | 140 | # Handling of all other possible keys. 141 | # Again, we will use the first element to figure out which key/values are not None for this model. 
142 | for k, v in first.items(): 143 | if k not in ('pixel_values', 'image_flags') and \ 144 | v is not None and not isinstance(v, str): 145 | if isinstance(v, torch.Tensor): 146 | batch[k] = torch.stack([f[k] for f in features]) 147 | elif isinstance(v, np.ndarray): 148 | batch[k] = torch.tensor(np.stack([f[k] for f in features])) 149 | else: 150 | batch[k] = torch.tensor([f[k] for f in features]) 151 | if k in ('pixel_values', 'image_flags'): 152 | if isinstance(v, torch.Tensor): 153 | batch[k] = torch.concat([f[k] for f in features]) 154 | elif isinstance(v, np.ndarray): 155 | batch[k] = torch.concat(np.stack([f[k] for f in features])) 156 | else: 157 | batch[k] = torch.concat([f[k] for f in features]) 158 | if '_logps' in k: 159 | batch[k] = [f[k] for f in features] 160 | return batch 161 | -------------------------------------------------------------------------------- /model_train/patch/train_dataloader_patch.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # InternVL 3 | # Copyright (c) 2024 OpenGVLab 4 | # Licensed under The MIT License [see LICENSE for details] 5 | # -------------------------------------------------------- 6 | # Modified from InternVL 7 | # Copyright (c) 2025 ARC Lab 8 | # Licensed under LICENSE [see LICENSE for details] 9 | # -------------------------------------------------------- 10 | 11 | import datasets 12 | import torch 13 | import transformers 14 | from functools import partial 15 | import torch.distributed as dist 16 | from torch.utils.data import DataLoader 17 | from transformers.trainer import is_datasets_available, seed_worker 18 | 19 | 20 | def get_train_dataloader(self) -> DataLoader: 21 | """ 22 | Returns the training [`~torch.utils.data.DataLoader`]. 23 | 24 | Will use no sampler if `train_dataset` does not implement `__len__`, a random sampler (adapted to distributed 25 | training if necessary) otherwise. 26 | 27 | Subclass and override this method if you want to inject some custom behavior. 
28 | """ 29 | if self.train_dataset is None: 30 | raise ValueError('Trainer: training requires a train_dataset.') 31 | 32 | train_dataset = self.train_dataset 33 | data_collator = self.data_collator 34 | if is_datasets_available() and isinstance(train_dataset, datasets.Dataset): 35 | train_dataset = self._remove_unused_columns(train_dataset, description='training') 36 | else: 37 | data_collator = self._get_collator_with_removed_columns(data_collator, description='training') 38 | 39 | dataloader_params = { 40 | 'batch_size': self._train_batch_size, 41 | 'collate_fn': data_collator, 42 | 'num_workers': self.args.dataloader_num_workers, 43 | 'pin_memory': self.args.dataloader_pin_memory, 44 | 'persistent_workers': self.args.dataloader_persistent_workers, 45 | } 46 | 47 | if not isinstance(train_dataset, torch.utils.data.IterableDataset): 48 | dataloader_params['sampler'] = self._get_train_sampler() 49 | dataloader_params['drop_last'] = self.args.dataloader_drop_last 50 | 51 | num_workers = self.args.dataloader_num_workers 52 | rank = dist.get_rank() if dist.is_initialized() else 0 53 | 54 | dataloader_params['worker_init_fn'] = partial(seed_worker, num_workers=num_workers, rank=rank) 55 | #dataloader_params['worker_init_fn'] = seed_worker 56 | 57 | if self.args.use_packed_ds: 58 | return DataLoader(train_dataset, **dataloader_params) 59 | return self.accelerator.prepare(DataLoader(train_dataset, **dataloader_params)) 60 | 61 | 62 | def replace_train_dataloader(): 63 | transformers.Trainer.get_train_dataloader = get_train_dataloader 64 | # print('Replace train dataloader!!') 65 | -------------------------------------------------------------------------------- /model_train/patch/train_sampler_patch.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # InternVL 3 | # Copyright (c) 2024 OpenGVLab 4 | # Licensed under The MIT License [see LICENSE for details] 5 | # -------------------------------------------------------- 6 | # Modified from InternVL 7 | # Copyright (c) 2025 ARC Lab 8 | # Licensed under LICENSE [see LICENSE for details] 9 | # -------------------------------------------------------- 10 | 11 | from typing import List, Optional 12 | 13 | import torch 14 | import transformers 15 | from torch.utils.data import Dataset, Sampler 16 | from transformers.tokenization_utils_base import BatchEncoding 17 | from transformers.trainer import (LengthGroupedSampler, RandomSampler, 18 | has_length) 19 | from transformers.trainer_pt_utils import logger 20 | 21 | 22 | # copy from https://github.com/haotian-liu/LLaVA/blob/main/llava/train/llava_trainer.py#L38 23 | def split_to_even_chunks(indices, lengths, num_chunks): 24 | """ 25 | Split a list of indices into `chunks` chunks of roughly equal lengths. 
26 | """ 27 | 28 | if len(indices) % num_chunks != 0: 29 | return [indices[i::num_chunks] for i in range(num_chunks)] 30 | 31 | num_indices_per_chunk = len(indices) // num_chunks 32 | 33 | chunks = [[] for _ in range(num_chunks)] 34 | chunks_lengths = [0 for _ in range(num_chunks)] 35 | for index in indices: 36 | shortest_chunk = chunks_lengths.index(min(chunks_lengths)) 37 | chunks[shortest_chunk].append(index) 38 | chunks_lengths[shortest_chunk] += lengths[index] 39 | if len(chunks[shortest_chunk]) == num_indices_per_chunk: 40 | chunks_lengths[shortest_chunk] = float('inf') 41 | 42 | return chunks 43 | 44 | 45 | # copy from https://github.com/haotian-liu/LLaVA/blob/main/llava/train/llava_trainer.py#L88 46 | def get_length_grouped_indices(lengths, batch_size, world_size, generator=None, merge=True): 47 | # We need to use torch for the random part as a distributed sampler will set the random seed for torch. 48 | indices = torch.randperm(len(lengths), generator=generator) 49 | megabatch_size = world_size * batch_size 50 | megabatches = [indices[i : i + megabatch_size].tolist() for i in range(0, len(lengths), megabatch_size)] 51 | megabatches = [sorted(megabatch, key=lambda i: lengths[i], reverse=True) for megabatch in megabatches] 52 | megabatches = [split_to_even_chunks(megabatch, lengths, world_size) for megabatch in megabatches] 53 | 54 | return [i for megabatch in megabatches for batch in megabatch for i in batch] 55 | 56 | 57 | # modified from https://github.com/haotian-liu/LLaVA/blob/main/llava/train/llava_trainer.py#L99 58 | class LengthGroupedSampler(Sampler): 59 | r""" 60 | Sampler that samples indices in a way that groups together features of the dataset of roughly the same length while 61 | keeping a bit of randomness. 62 | """ 63 | 64 | def __init__( 65 | self, 66 | batch_size: int, 67 | world_size: int, 68 | dataset: Optional[Dataset] = None, 69 | lengths: Optional[List[int]] = None, 70 | model_input_name: Optional[str] = None, 71 | generator=None, 72 | ): 73 | if dataset is None and lengths is None: 74 | raise ValueError('One of dataset and lengths must be provided.') 75 | 76 | self.batch_size = batch_size 77 | if lengths is None: 78 | model_input_name = model_input_name if model_input_name is not None else 'input_ids' 79 | if ( 80 | not (isinstance(dataset[0], dict) or isinstance(dataset[0], BatchEncoding)) 81 | or model_input_name not in dataset[0] 82 | ): 83 | raise ValueError( 84 | 'Can only automatically infer lengths for datasets whose items are dictionaries with an ' 85 | f"'{model_input_name}' key." 86 | ) 87 | lengths = [len(feature[model_input_name]) for feature in dataset] 88 | elif isinstance(lengths, torch.Tensor): 89 | logger.info( 90 | 'If lengths is a torch.Tensor, LengthGroupedSampler will be slow. Converting lengths to List[int]...' 91 | ) 92 | lengths = lengths.tolist() 93 | self.world_size = world_size 94 | self.lengths = lengths 95 | self.generator = generator 96 | 97 | def __len__(self): 98 | return len(self.lengths) 99 | 100 | def __iter__(self): 101 | indices = get_length_grouped_indices(self.lengths, self.batch_size, self.world_size, generator=self.generator) 102 | return iter(indices) 103 | 104 | 105 | # patch trainer 106 | def _get_train_sampler(self) -> Optional[torch.utils.data.Sampler]: 107 | if self.train_dataset is None or not has_length(self.train_dataset): 108 | return None 109 | # Build the sampler. 
110 | if self.args.group_by_length: 111 | lengths = [] 112 | for dataset in self.train_dataset.datasets: 113 | lengths = lengths + dataset.length 114 | model_input_name = self.tokenizer.model_input_names[0] if self.tokenizer is not None else None 115 | return LengthGroupedSampler( 116 | self.args.train_batch_size, 117 | world_size=self.args.world_size * self.args.gradient_accumulation_steps, 118 | # self.args.train_batch_size * self.args.gradient_accumulation_steps, 119 | dataset=self.train_dataset, 120 | lengths=lengths, 121 | model_input_name=model_input_name, 122 | ) 123 | else: 124 | return RandomSampler(self.train_dataset) 125 | 126 | 127 | def replace_train_sampler(): 128 | transformers.Trainer._get_train_sampler = _get_train_sampler 129 | # print('Replace train sampler!!') 130 | -------------------------------------------------------------------------------- /model_train/train/__init__.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # InternVL 3 | # Copyright (c) 2024 OpenGVLab 4 | # Licensed under The MIT License [see LICENSE for details] 5 | # -------------------------------------------------------- 6 | # Modified from InternVL 7 | # Copyright (c) 2025 ARC Lab 8 | # Licensed under LICENSE [see LICENSE for details] 9 | # -------------------------------------------------------- 10 | -------------------------------------------------------------------------------- /model_train/train/constants.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # InternVL 3 | # Copyright (c) 2024 OpenGVLab 4 | # Licensed under The MIT License [see LICENSE for details] 5 | # -------------------------------------------------------- 6 | # Modified from InternVL 7 | # Copyright (c) 2025 ARC Lab 8 | # Licensed under LICENSE [see LICENSE for details] 9 | # -------------------------------------------------------- 10 | 11 | IMG_CONTEXT_TOKEN = '' 12 | IMG_START_TOKEN = '' 13 | IMG_END_TOKEN = '' 14 | VID_START_TOKEN = '' 15 | VID_END_TOKEN = '' 16 | LINE_TOKEN = '' 17 | IMAGENET_MEAN = (0.485, 0.456, 0.406) 18 | IMAGENET_STD = (0.229, 0.224, 0.225) 19 | CLIP_MEAN = (0.4814546, 0.4578275, 0.40821073) 20 | CLIP_STD = (0.2686295, 0.2613025, 0.2757711) 21 | SIGLIP_MEAN = (0.5, 0.5, 0.5) 22 | SIGLIP_STD = (0.5, 0.5, 0.5) 23 | HUNYUAN_MEAN = (0.48145466, 0.4578275, 0.40821073) 24 | HUNYUAN_STD = (0.26862954, 0.26130258, 0.27577711) 25 | -------------------------------------------------------------------------------- /model_train/train/dataset.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # InternVL 3 | # Copyright (c) 2024 OpenGVLab 4 | # Licensed under The MIT License [see LICENSE for details] 5 | # -------------------------------------------------------- 6 | # Modified from InternVL 7 | # Copyright (c) 2025 ARC Lab 8 | # Licensed under LICENSE [see LICENSE for details] 9 | # -------------------------------------------------------- 10 | 11 | import io 12 | 13 | from transformers.trainer_pt_utils import LabelSmoother 14 | 15 | IGNORE_TOKEN_ID = LabelSmoother.ignore_index 16 | import os 17 | import random 18 | import re 19 | from collections import Counter 20 | from typing import Dict 21 | import math 22 | import cv2 23 | import imageio 24 | import numpy as np 25 | import torch 26 | import warnings 27 | import 
torch.nn.functional as F 28 | import torchvision.transforms as T 29 | import transformers 30 | from decord import VideoReader 31 | from PIL import Image 32 | from PIL import ImageDraw, ImageFont 33 | from torch.utils.data import ConcatDataset, WeightedRandomSampler 34 | from torchvision.transforms.functional import InterpolationMode 35 | 36 | from .constants import (CLIP_MEAN, CLIP_STD, IMAGENET_MEAN, IMAGENET_STD, 37 | IMG_CONTEXT_TOKEN, IMG_END_TOKEN, IMG_START_TOKEN, 38 | VID_END_TOKEN, VID_START_TOKEN, LINE_TOKEN, 39 | SIGLIP_MEAN, SIGLIP_STD, HUNYUAN_MEAN, HUNYUAN_STD) 40 | 41 | xdrope_section = [ 42 | 0.25, 43 | 0.25, 44 | 0.25, 45 | 0.25 46 | ] 47 | 48 | def calculate_ngram_repetition(text, n): 49 | words = text.split() 50 | ngrams = [tuple(words[i:i+n]) for i in range(len(words)-n+1)] 51 | ngram_counts = Counter(ngrams) 52 | total_ngrams = len(ngrams) 53 | repeated_ngrams = sum(1 for count in ngram_counts.values() if count > 1) 54 | return repeated_ngrams / total_ngrams if total_ngrams > 0 else 0 55 | 56 | def check_conversations_repetition(conversations, repeat_threshold=0.4, ngram=10): 57 | for conversation in conversations: 58 | if conversation['from'] == 'gpt': 59 | model_answer = conversation['value'] 60 | repeat_ratio = calculate_ngram_repetition(model_answer, ngram) 61 | if repeat_ratio > repeat_threshold: 62 | raise Exception 63 | 64 | def get_frame_indices(vlen, sample='rand', fix_start=None, input_fps=1, max_num_frames=-1): 65 | duration = vlen / input_fps 66 | 67 | frames_per_second = input_fps 68 | 69 | ## current support videos > 300s 70 | if duration > 300: 71 | warnings.warn("The video is longer than 5 minutes. Due to sampling, some audio information may be lost!") 72 | 73 | if duration <= max_num_frames: 74 | interval = 1 75 | intervals = [(int(i * interval * frames_per_second), int((i + 1) * interval * frames_per_second)) for i in range(math.ceil(duration))] 76 | intervals_sec = [(int(i * interval), int((i + 1) * interval)) for i in range(math.ceil(duration))] 77 | else: 78 | num_segments = max_num_frames 79 | segment_duration = duration / num_segments 80 | intervals = [(int(i * segment_duration * frames_per_second), int((i + 1) * segment_duration * frames_per_second)) for i in range(num_segments)] 81 | intervals_sec = [(round(i * segment_duration), round((i + 1) * segment_duration)) for i in range(num_segments)] 82 | 83 | frame_indices = [] 84 | 85 | if sample == 'rand': 86 | for start, end in intervals: 87 | if end > vlen: 88 | end = vlen 89 | frame_indices.append(random.choice(range(start, end))) 90 | elif sample == 'middle': 91 | for start, end in intervals: 92 | if end > vlen: 93 | end = vlen 94 | frame_indices.append((start + end) // 2) 95 | else: 96 | raise NotImplementedError 97 | 98 | return frame_indices, intervals_sec 99 | 100 | def seconds_to_mmss(seconds): 101 | m = int(seconds // 60) 102 | s = int(seconds % 60) 103 | return f"{m:02d}:{s:02d}" 104 | 105 | def sec2hms(seconds): 106 | seconds = int(round(seconds)) 107 | h = seconds // 3600 108 | m = (seconds % 3600) // 60 109 | s = seconds % 60 110 | return f"{h:02d}:{m:02d}:{s:02d}" 111 | 112 | def add_timestamp_to_frame(frame, start_sec, end_sec, font_size=40): 113 | draw = ImageDraw.Draw(frame) 114 | font_size = int(frame.height * 0.05) 115 | font = ImageFont.truetype("ARIAL.TTF", font_size) 116 | text = f"{sec2hms(start_sec)}-{sec2hms(end_sec)}" 117 | bbox = draw.textbbox((0, 0), text, font=font) 118 | text_w = bbox[2] - bbox[0] 119 | text_h = bbox[3] - bbox[1] 120 | x = frame.width - text_w - 20 121 
| y = 20 122 | draw.rectangle([x-10, y-10, x+text_w+10, y+text_h+10], fill=(0,0,0,180)) 123 | draw.text((x, y), text, fill=(255,255,255), font=font) 124 | return frame 125 | 126 | def read_frames_decord( 127 | video_path, num_frames, sample='rand', fix_start=None, 128 | client=None, clip=None, use_time=False, min_num_frames=4 129 | ): 130 | video_reader = VideoReader(video_path, num_threads=1) 131 | 132 | vlen = len(video_reader) 133 | fps = video_reader.get_avg_fps() 134 | duration = vlen / float(fps) 135 | 136 | frame_indices, intervals_sec = get_frame_indices( 137 | vlen, sample=sample, fix_start=fix_start, 138 | input_fps=fps, max_num_frames=num_frames 139 | ) 140 | 141 | frames = video_reader.get_batch(frame_indices).asnumpy() # (T, H, W, C), np.uint8 142 | frames = [Image.fromarray(frames[i]) for i in range(frames.shape[0])] 143 | 144 | if use_time: 145 | frames_with_ts = [] 146 | for i, frame in enumerate(frames): 147 | start_sec, end_sec = intervals_sec[i] 148 | frame_with_ts = add_timestamp_to_frame(frame, start_sec, end_sec) 149 | frames_with_ts.append(frame_with_ts) 150 | frames = frames_with_ts 151 | 152 | save_dir = 'output_frames' 153 | if not os.path.exists(save_dir): 154 | os.makedirs(save_dir, exist_ok=True) 155 | 156 | for idx, frame in enumerate(frames): 157 | video_name = video_path.split('/')[-1].replace('.mp4', '') 158 | frame.save(os.path.join(save_dir, f'{video_name}_frame_{idx:03d}.jpg')) 159 | 160 | return frames 161 | 162 | def extract_frame_number(filename): 163 | # Extract the numeric part from the filename using regular expressions 164 | match = re.search(r'_(\d+).jpg$', filename) 165 | return int(match.group(1)) if match else -1 166 | 167 | def sort_frames(frame_paths): 168 | # Extract filenames from each path and sort by their numeric part 169 | return sorted(frame_paths, key=lambda x: extract_frame_number(os.path.basename(x))) 170 | 171 | def read_frames_folder( 172 | video_path, num_frames, sample='rand', fix_start=None, 173 | client=None, clip=None, min_num_frames=4 174 | ): 175 | image_list = sort_frames(list(os.listdir(video_path))) 176 | frames = [] 177 | for image in image_list: 178 | fp = os.path.join(video_path, image) 179 | frame = Image.open(fp).convert('RGB') 180 | frames.append(frame) 181 | 182 | return frames 183 | 184 | class WeightedConcatDataset(ConcatDataset): 185 | def __init__(self, datasets, weights): 186 | super().__init__(datasets) 187 | self.weights = torch.DoubleTensor(weights) 188 | self.total_size = sum(len(d) for d in datasets) 189 | self.sampler = WeightedRandomSampler(weights=self.weights, num_samples=self.total_size, replacement=True) 190 | 191 | def __iter__(self): 192 | return iter(self.sampler) 193 | 194 | def __len__(self): 195 | return self.total_size 196 | 197 | def pil_loader(img_str): 198 | buff = io.BytesIO(img_str) 199 | img = Image.open(buff) 200 | return img.convert('RGB') 201 | 202 | class TCSLoader(object): 203 | 204 | def __init__(self, conf_path, sc_config_key='sensecore'): 205 | print(f'[TCSLoader] config_path: {conf_path}') 206 | # print('--> before Client(conf_path)') 207 | # self.client = Client(conf_path) 208 | # self.sc_config_key = sc_config_key 209 | # print('--> after Client(conf_path)') 210 | self.client = None 211 | 212 | def __call__(self, fn, image_type='image', max_num_frames=-1, min_num_frames=8, sample='rand', use_time=False, clip=None): 213 | #print(image_type, max_num_frames, min_num_frames, clip) 214 | if image_type == 'image': 215 | img_value_str = self.client.get(fn) 216 | img = 
pil_loader(img_value_str) 217 | return img 218 | 219 | elif image_type == 'video': 220 | if fn.endswith('/'): 221 | frames = read_frames_folder(fn, num_frames=max_num_frames, min_num_frames=min_num_frames, 222 | client=self.client, sample=sample) 223 | else: 224 | frames = read_frames_decord(fn, num_frames=max_num_frames, min_num_frames=min_num_frames, 225 | client=self.client, sample=sample, use_time=use_time, clip=clip) 226 | return frames 227 | 228 | 229 | def expand2square(pil_img, background_color): 230 | width, height = pil_img.size 231 | if width == height: 232 | return pil_img 233 | elif width > height: 234 | result = Image.new(pil_img.mode, (width, width), background_color) 235 | result.paste(pil_img, (0, (width - height) // 2)) 236 | return result 237 | else: 238 | result = Image.new(pil_img.mode, (height, height), background_color) 239 | result.paste(pil_img, ((height - width) // 2, 0)) 240 | return result 241 | 242 | 243 | def simulate_jpeg_degradation(quality): 244 | def jpeg_degrade(img): 245 | with io.BytesIO() as output: 246 | img.convert('RGB').save(output, format='JPEG', quality=quality) 247 | output.seek(0) # Move the reading cursor to the start of the stream 248 | img_jpeg = Image.open(output).copy() # Use .copy() to make sure the image is loaded in memory 249 | return img_jpeg 250 | return jpeg_degrade 251 | 252 | 253 | # Define the JPEG compression quality range, pre-create all JPEG compression functions 254 | qualities = list(range(75, 101)) 255 | jpeg_degrade_functions = {quality: simulate_jpeg_degradation(quality) for quality in qualities} 256 | 257 | 258 | def build_transform(is_train, input_size, pad2square=False, normalize_type='imagenet'): 259 | if normalize_type == 'imagenet': 260 | MEAN, STD = IMAGENET_MEAN, IMAGENET_STD 261 | elif normalize_type == 'clip': 262 | MEAN, STD = CLIP_MEAN, CLIP_STD 263 | elif normalize_type == 'siglip': 264 | MEAN, STD = SIGLIP_MEAN, SIGLIP_STD 265 | elif normalize_type == 'hunyuan': 266 | MEAN, STD = HUNYUAN_MEAN, HUNYUAN_STD 267 | else: 268 | raise NotImplementedError 269 | if is_train: # use data augmentation 270 | transform = T.Compose([ 271 | T.Lambda(lambda img: img.convert('RGB') if img.mode != 'RGB' else img), 272 | T.RandomChoice([T.Lambda(jpeg_degrade_functions[quality]) for quality in qualities]), 273 | T.Resize((input_size, input_size), interpolation=InterpolationMode.BICUBIC), 274 | T.ToTensor(), 275 | T.Normalize(mean=MEAN, std=STD) 276 | ]) 277 | else: 278 | if pad2square is False: # now we use this transform function by default 279 | transform = T.Compose([ 280 | T.Lambda(lambda img: img.convert('RGB') if img.mode != 'RGB' else img), 281 | T.Resize((input_size, input_size), interpolation=InterpolationMode.BICUBIC), 282 | T.ToTensor(), 283 | T.Normalize(mean=MEAN, std=STD) 284 | ]) 285 | else: 286 | transform = T.Compose([ 287 | T.Lambda(lambda img: img.convert('RGB') if img.mode != 'RGB' else img), 288 | T.Lambda(lambda img: expand2square(img, tuple(int(x * 255) for x in MEAN))), 289 | T.Resize((input_size, input_size), interpolation=InterpolationMode.BICUBIC), 290 | T.ToTensor(), 291 | T.Normalize(mean=MEAN, std=STD) 292 | ]) 293 | 294 | return transform 295 | 296 | def generate_tokens(w=448, h=448, use_xrope=True): 297 | total_patch_size = 16 * 2 * 2 298 | tokens = '' 299 | tokens += IMG_START_TOKEN 300 | tokens += IMG_CONTEXT_TOKEN 301 | for i in range(h // total_patch_size): 302 | for j in range(w // total_patch_size): 303 | tokens += IMG_CONTEXT_TOKEN 304 | if use_xrope: 305 | tokens += LINE_TOKEN 306 |
else: 307 | tokens += IMG_CONTEXT_TOKEN 308 | tokens += IMG_CONTEXT_TOKEN 309 | tokens += IMG_END_TOKEN 310 | return tokens 311 | 312 | def get_xdrope_position_ids( 313 | position_ids_t, 314 | position_ids_x, 315 | position_ids_y, 316 | b, 317 | i, 318 | prev_index, 319 | boi_index, 320 | eoi_index, 321 | eol_index, 322 | ): 323 | 324 | position_ids_t[b, (i + 1):] -= (i + 1 - prev_index) 325 | position_ids_x[b, (i + 1):] -= (i + 1 - prev_index) 326 | position_ids_y[b, (i + 1):] -= (i + 1 - prev_index) 327 | 328 | idx_cur = 0 329 | for x in range(boi_index.size()[0]): 330 | m = boi_index[x] 331 | n = eoi_index[x] 332 | assert m < n 333 | # Reset image token position ids. 334 | if m >= prev_index and m < i: 335 | assert n < i 336 | position_ids_t[b, m+1+1:n-1] = idx_cur 337 | idx_cur += 1 338 | 339 | eol_idx_list = [] 340 | for y in range(eol_index.size()[0]): 341 | eol_idx = eol_index[y] 342 | # Collect end-of-line token indices inside this image. 343 | if eol_idx > m and eol_idx < n: 344 | eol_idx_list.append(eol_idx) 345 | row = len(eol_idx_list) 346 | assert row > 0, 'the row of an image must be a positive integer' 347 | # -2 is for learnable img start and img end for each image 348 | # -1 is for getting rid of endpoint 349 | column = torch.round((n-m-2-1)/row).long().item() 350 | 351 | assert row * column == n-m-2-1, f"row:\t{row}, column:\t{column}, n:\t{n}, m:\t{m}, {int((n-m-2-1)/row)}" 352 | 353 | idx_xy = 0 354 | for rr in range(row): 355 | for cc in range(column): 356 | position_ids_x[b, m+1+1+idx_xy] = cc 357 | position_ids_y[b, m+1+1+idx_xy] = rr 358 | idx_xy += 1 359 | 360 | return position_ids_t, position_ids_x, position_ids_y 361 | 362 | 363 | def get_attention_masks_and_position_ids(data, eod_id, im_start_id, im_end_id, im_newline_id): 364 | position_embedding_xdrope = True 365 | 366 | micro_batch_size, seq_length = data.size() 367 | att_mask_batch = 1 368 | attention_mask = torch.tril(torch.ones( 369 | (att_mask_batch, seq_length, seq_length))).view( 370 | att_mask_batch, 1, seq_length, seq_length).int() 371 | position_ids = torch.arange(seq_length, dtype=torch.long) 372 | position_ids = position_ids.unsqueeze(0).expand_as(data) 373 | if position_embedding_xdrope: 374 | position_ids_t = position_ids.clone() 375 | position_ids_x = position_ids.clone() 376 | position_ids_y = position_ids.clone() 377 | for b in range(micro_batch_size): 378 | # Find indices where EOD token is. 379 | eod_index = position_ids[b, data[b] == eod_id] 380 | 381 | eod_index = eod_index.clone() 382 | # Detach indices from positions if going to modify positions.
383 | # Loop through EOD indices: 384 | prev_index = 0 385 | if position_embedding_xdrope: 386 | boi_index = position_ids[b, data[b] == im_start_id] 387 | eoi_index = position_ids[b, data[b] == im_end_id] 388 | eol_index = position_ids[b, data[b] == im_newline_id] 389 | 390 | #print(boi_index, eoi_index, eol_index) 391 | for j in range(eod_index.size()[0]): 392 | i = eod_index[j] 393 | attention_mask[b, 0, (i + 1):, :(i + 1)] = 0 394 | position_ids[b, (i + 1):] -= (i + 1 - prev_index) 395 | 396 | if position_embedding_xdrope: 397 | position_ids_t, position_ids_x, position_ids_y = get_xdrope_position_ids(position_ids_t, position_ids_x, position_ids_y, 398 | b, i, prev_index, boi_index, 399 | eoi_index, eol_index 400 | ) 401 | prev_index = i + 1 402 | if position_embedding_xdrope: 403 | position_ids = torch.cat([position_ids.unsqueeze(1), position_ids_x.unsqueeze(1), position_ids_y.unsqueeze(1),position_ids_t.unsqueeze(1)], dim=1) 404 | 405 | return attention_mask, position_ids 406 | 407 | 408 | def preprocess_hunyuan( 409 | template_name, 410 | sources, 411 | tokenizer: transformers.PreTrainedTokenizer, 412 | text_only: bool = False, 413 | image_size: int = 448, 414 | num_image: int = 1 415 | ) -> Dict: 416 | assert len(sources) == 1, 'process only the first conversation' 417 | conversations = sources[0] 418 | 419 | if not text_only: 420 | new_conversations = [] 421 | current_image_idx = 0 422 | img_tokens_per_frame = generate_tokens(w=image_size, h=image_size) 423 | for conversation in conversations: 424 | if conversation['from'] == 'human': 425 | image_cnt = conversation['value'].count('<image>')  # '<image>' placeholders mark where frame tokens are inserted 426 | for i in range(image_cnt): 427 | if current_image_idx == num_image: 428 | break 429 | if current_image_idx == 0: 430 | if num_image != 1: 431 | image_tokens = f'{VID_START_TOKEN}{img_tokens_per_frame}' 432 | else: 433 | image_tokens = f'{VID_START_TOKEN}{img_tokens_per_frame}{VID_END_TOKEN}' 434 | elif current_image_idx == num_image - 1: 435 | image_tokens = f'{img_tokens_per_frame}{VID_END_TOKEN}' 436 | else: 437 | image_tokens = f'{img_tokens_per_frame}' 438 | conversation['value'] = conversation['value'].replace('<image>', image_tokens, 1) 439 | current_image_idx += 1 440 | new_conversations.append(conversation) 441 | conversations = new_conversations 442 | assert current_image_idx == num_image, f'{current_image_idx} != {num_image}' 443 | 444 | batches, roles = [], [] 445 | 446 | for conversation in conversations: 447 | if conversation['from'] == 'human': 448 | batches.append(conversation["value"] + '') 449 | roles.append('human') 450 | elif conversation['from'] == 'gpt': 451 | batches.append(f'{conversation["value"]}<|endoftext|>') 452 | roles.append('gpt') 453 | else: 454 | raise NotImplementedError 455 | 456 | final_input_ids, final_targets = [tokenizer.bos_id], [IGNORE_TOKEN_ID] 457 | for role, batch in zip(roles, batches): 458 | 459 | input_ids = tokenizer.encode( 460 | batch, 461 | padding=False, 462 | max_length=tokenizer.model_max_length, 463 | truncation=True, 464 | ) 465 | input_ids = np.array(input_ids) 466 | 467 | final_input_ids.extend(input_ids.tolist()) 468 | 469 | if role == 'system' or role == 'human': 470 | final_targets.extend(np.full(input_ids.shape, IGNORE_TOKEN_ID).tolist()) # ignore 471 | elif role == 'gpt': 472 | target = input_ids.copy() 473 | final_targets.extend(target.tolist()) 474 | else: 475 | raise NotImplementedError 476 | 477 | final_input_ids = np.array(final_input_ids) 478 | final_targets = np.array(final_targets) 479 | input_ids = 
torch.tensor(final_input_ids)[:tokenizer.model_max_length] 480 | targets = torch.tensor(final_targets)[:tokenizer.model_max_length] 481 | 482 | _, position_ids = get_attention_masks_and_position_ids(input_ids.unsqueeze(0), \ 483 | tokenizer.eod_id, tokenizer.im_start_id, tokenizer.im_end_id, tokenizer.im_newline_id) 484 | 485 | input_ids[input_ids == tokenizer.im_newline_id] = tokenizer.image_token_id 486 | 487 | torch.set_printoptions(threshold=float('inf')) 488 | 489 | padding = False 490 | if padding: 491 | current_length = input_ids.size(0) 492 | padding_length = tokenizer.model_max_length - current_length 493 | input_ids = F.pad(input_ids, (0, padding_length), value=tokenizer.pad_id) 494 | targets = F.pad(targets, (0, padding_length), value=IGNORE_TOKEN_ID) 495 | 496 | input_ids = input_ids.unsqueeze(0) 497 | targets = targets.unsqueeze(0) 498 | 499 | return dict( 500 | input_ids=input_ids, 501 | labels=targets, 502 | attention_mask=input_ids.ne(tokenizer.pad_id), 503 | position_ids=position_ids, 504 | ) 505 | -------------------------------------------------------------------------------- /model_train/train/dataset_packed.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # InternVL 3 | # Copyright (c) 2024 OpenGVLab 4 | # Licensed under The MIT License [see LICENSE for details] 5 | # -------------------------------------------------------- 6 | # Modified from InternVL 7 | # Copyright (c) 2025 ARC Lab 8 | # Licensed under LICENSE [see LICENSE for details] 9 | # -------------------------------------------------------- 10 | 11 | import bisect 12 | import copy 13 | import logging 14 | from collections import defaultdict 15 | from typing import List, Union 16 | 17 | import numpy as np 18 | import torch 19 | import torch.distributed as dist 20 | from torch.utils.data import IterableDataset, get_worker_info 21 | from transformers.trainer_pt_utils import LabelSmoother 22 | 23 | from .constants import IMG_CONTEXT_TOKEN, IMG_END_TOKEN, IMG_START_TOKEN 24 | 25 | IGNORE_TOKEN_ID = LabelSmoother.ignore_index 26 | logger = logging.getLogger(__name__) 27 | logger.setLevel(logging.INFO) 28 | 29 | 30 | def is_dist_avail_and_initialized(): 31 | if not dist.is_available(): 32 | return False 33 | if not dist.is_initialized(): 34 | return False 35 | return True 36 | 37 | 38 | def get_world_size(): 39 | if not is_dist_avail_and_initialized(): 40 | return 1 41 | return dist.get_world_size() 42 | 43 | 44 | def get_rank(): 45 | if not is_dist_avail_and_initialized(): 46 | return 0 47 | return dist.get_rank() 48 | 49 | 50 | class PackedDataset(IterableDataset): 51 | def __init__( 52 | self, 53 | tokenizer, 54 | data_rank, 55 | data_world_size, 56 | datasets: List, 57 | dataset_weight: List[int] = None, 58 | num_images_expected: int = 6, 59 | max_packed_tokens: int = 32768, 60 | max_buffer_size: int = 100, 61 | log_freq: int = 1000000, 62 | strict_mode: bool = False, 63 | debug_mode: bool = False, 64 | replacement: bool = True, 65 | allow_overflow: bool = True, 66 | allow_empty_data: bool = False, 67 | allow_deduplicated_ds_name: bool = False, 68 | ): 69 | super().__init__() 70 | self.tokenizer = tokenizer 71 | self.data_rank = data_rank 72 | self.data_world_size = data_world_size 73 | self.datasets = datasets 74 | self.num_images_expected = num_images_expected 75 | self.max_buffer_size = max_buffer_size 76 | self.log_freq = log_freq 77 | self.strict_mode = strict_mode 78 | self.debug_mode = debug_mode 79 | 
self.replacement = replacement 80 | self.allow_overflow = allow_overflow 81 | self.allow_empty_data = allow_empty_data 82 | 83 | self.max_packed_tokens = max_packed_tokens 84 | 85 | self.img_start_token_id = self.tokenizer.convert_tokens_to_ids(IMG_START_TOKEN) 86 | self.img_token_id = self.tokenizer.convert_tokens_to_ids(IMG_CONTEXT_TOKEN) 87 | self.img_end_token_id = self.tokenizer.convert_tokens_to_ids(IMG_END_TOKEN) 88 | 89 | assert self.img_start_token_id != self.tokenizer.unk_token_id 90 | assert self.img_token_id != self.tokenizer.unk_token_id 91 | assert self.img_end_token_id != self.tokenizer.unk_token_id 92 | 93 | if dataset_weight is None: 94 | dataset_weight = [1] * len(datasets) 95 | self.dataset_type = [d.dataset_type for d in self.datasets] 96 | 97 | self.datasets_orig = datasets 98 | self.dataset_weight_orig = [w / sum(dataset_weight) for w in dataset_weight] 99 | 100 | self.datasets = [ds for ds in self.datasets_orig] 101 | self.dataset_weight = [w for w in self.dataset_weight_orig] 102 | 103 | # lazy init 104 | self.worker_id = None 105 | self.worker_state_key = None 106 | self.dataset_iter_list = None 107 | self._state_dict = { 108 | 'sample_info': {d.ds_name:0 for d in self.datasets}, 109 | } 110 | 111 | self.worker_custom_infos = None 112 | 113 | ds_name_list = [d.ds_name for d in self.datasets] 114 | if not allow_deduplicated_ds_name: 115 | assert len(ds_name_list) == len(set(ds_name_list)), f'deduplicated ds_name: {ds_name_list}' 116 | 117 | for ds in self.datasets: 118 | if ds.max_num_images > self.num_images_expected: 119 | logger.warning(f'{ds.max_num_images=} of {ds.ds_name} is larger than {self.num_images_expected=}') 120 | ds.max_num_images = num_images_expected 121 | 122 | if ds.max_tokens > self.max_packed_tokens: 123 | logger.warning(f'{ds.max_tokens=} of {ds.ds_name} is larger than {self.max_packed_tokens=}') 124 | ds.max_tokens = self.max_packed_tokens 125 | 126 | self._state_dict[ds.ds_name] = {} 127 | 128 | if get_rank() == 0: 129 | logger.info( 130 | f'Loaded dataset to pack: {ds_name_list}, ' 131 | f'{self.num_images_expected=}, {self.max_packed_tokens=}, ' 132 | f'{self.replacement=}, {self.allow_overflow=}', 133 | ) 134 | 135 | temp = [] 136 | for ds, ds_w in zip(self.datasets, self.dataset_weight): 137 | temp.append(f'{ds.ds_name:<25}: {ds_w*100:.2f}%') 138 | temp = '\n'.join(temp) 139 | logger.info( 140 | f'Sampling prob for each dataset:\n{temp}' 141 | ) 142 | 143 | if self.allow_empty_data: 144 | logger.warning('allow_empty_data is enabled, note that empty data may be generated!') 145 | 146 | def load_state_dict(self, state_dict, custom_infos=None): 147 | 148 | self.worker_custom_infos = custom_infos 149 | 150 | self._state_dict.update(state_dict) 151 | for ds in self.datasets: 152 | if ds.ds_name in self._state_dict: 153 | ds.load_state_dict(self._state_dict[ds.ds_name]) 154 | logger.info(f'{ds.ds_name=} is resumed.') 155 | else: 156 | logger.warning(f'{ds.ds_name=} is not resumed.') 157 | 158 | def _should_log(self): 159 | worker_id = 0 if get_worker_info() is None else get_worker_info().id 160 | num_workers = 1 if get_worker_info() is None else get_worker_info().num_workers 161 | 162 | worker_id = num_workers * get_rank() + worker_id 163 | num_workers = num_workers * get_world_size() 164 | 165 | return worker_id == 0 166 | 167 | def next_data(self, current_dataset_idx): 168 | while True: 169 | try: 170 | current_sample = next(self.dataset_iter_list[current_dataset_idx]) 171 | break # Exit loop if successful 172 | except StopIteration: 173 | if 
self.replacement: 174 | # logger.info(f'[Worker id {self.worker_id}] Dataset {self.datasets[current_dataset_idx].ds_name} is exhausted, restart it.') 175 | try: 176 | self.dataset_iter_list[current_dataset_idx] = iter(self.datasets[current_dataset_idx]) 177 | current_sample = next(self.dataset_iter_list[current_dataset_idx]) 178 | break 179 | except: 180 | # logger.error(f'{self.worker_id=} Fail to get any data from {self.datasets[current_dataset_idx].ds_name}! length={len(self.datasets)}') 181 | self.datasets.pop(current_dataset_idx) 182 | self.dataset_iter_list.pop(current_dataset_idx) 183 | self.dataset_weight.pop(current_dataset_idx) 184 | 185 | if len(self.datasets) == 0: 186 | raise StopIteration 187 | current_dataset_idx = np.random.choice(len(self.datasets)) 188 | else: 189 | # logger.error(f'{self.worker_id=} Fail to get any data from {self.datasets[current_dataset_idx].ds_name}! length={len(self.datasets)}') 190 | self.datasets.pop(current_dataset_idx) 191 | self.dataset_iter_list.pop(current_dataset_idx) 192 | self.dataset_weight.pop(current_dataset_idx) 193 | 194 | if len(self.datasets) == 0: 195 | raise StopIteration 196 | current_dataset_idx = np.random.choice(len(self.datasets)) 197 | except: 198 | logger.error('Unexpected error!') 199 | if len(self.datasets) == 0: 200 | raise StopIteration 201 | current_dataset_idx = np.random.choice(len(self.datasets)) 202 | 203 | current_ds_name = self.datasets[current_dataset_idx].ds_name 204 | current_sample['type_ids'] = torch.zeros_like(current_sample['input_ids']) + current_dataset_idx 205 | 206 | if self.worker_state_key not in self._state_dict[current_ds_name]: 207 | self._state_dict[current_ds_name][self.worker_state_key] = {} 208 | 209 | meta_info = current_sample.pop('meta_info', {}) 210 | self._state_dict[current_ds_name][self.worker_state_key].update(**meta_info) 211 | self._state_dict['sample_info'][self.datasets[current_dataset_idx].ds_name] += 1 212 | return current_sample 213 | 214 | def find_buffer(self, buffer_list, new_sample): 215 | # NOTE: use `bisect` to search might be faster 216 | 217 | find = False 218 | find_idx = -1 219 | num_images_current = new_sample['pixel_values'].size(0) 220 | for buffer_idx, buffer in enumerate(buffer_list): 221 | num_images_buffer = buffer['pixel_values'].size(0) 222 | if num_images_buffer + num_images_current <= self.num_images_expected: 223 | num_merged_tokens = new_sample['input_ids'].size(0) + buffer['input_ids'].size(0) 224 | 225 | if num_merged_tokens <= self.max_packed_tokens: 226 | find = True 227 | find_idx = buffer_idx 228 | break 229 | 230 | if self.allow_overflow and len(buffer_list) >= self.max_buffer_size // 2: 231 | find = True 232 | find_idx = buffer_idx 233 | 234 | if find: 235 | return buffer_list.pop(find_idx) 236 | return None 237 | 238 | def update_buffer(self, buffer, new_sample): 239 | if buffer is None: 240 | new_sample['data_index'] = torch.zeros_like(new_sample['input_ids']) 241 | return new_sample 242 | 243 | new_sample['data_index'] = torch.ones_like(new_sample['input_ids']) + buffer['data_index'][-1].item() 244 | 245 | assert buffer.keys() == new_sample.keys() 246 | for k in buffer: 247 | buffer[k] = torch.cat([buffer[k], new_sample[k]]) 248 | return buffer 249 | 250 | @staticmethod 251 | def check_valid(sample_to_check, min_active_tokens_ratio=1/256): 252 | num_ignore_tokens = (sample_to_check['labels'] == IGNORE_TOKEN_ID).sum() 253 | num_tokens = sample_to_check['labels'].numel() 254 | return (1 - num_ignore_tokens / num_tokens) > min_active_tokens_ratio 
255 | 256 | @staticmethod 257 | def split_buffer(buffer, max_tokens, img_start_token_id, img_token_id, img_end_token_id): 258 | if buffer['input_ids'].size(0) <= max_tokens: 259 | return [buffer] 260 | 261 | def _image_is_splitted(input_ids, cut_idx): 262 | is_image_start = input_ids[cut_idx].item() == img_start_token_id 263 | is_image_token = input_ids[cut_idx].item() == img_token_id 264 | is_image_end = input_ids[cut_idx].item() == img_end_token_id 265 | return is_image_start or is_image_token or is_image_end 266 | 267 | def _split(sample_to_split, left_idx, right_idx, left_img_idx, right_img_idx): 268 | assert (right_idx is None) == (right_img_idx is None) 269 | 270 | left_sample = {} 271 | right_sample = {} if right_idx is not None else None 272 | for k in sample_to_split: 273 | if k in ['input_ids', 'labels', 'attention_mask', 'position_ids', 'data_index', 'type_ids']: 274 | left_sample[k] = sample_to_split[k][:left_idx] 275 | if right_sample is not None: 276 | right_sample[k] = sample_to_split[k][right_idx:] 277 | elif k in ['pixel_values', 'image_flags']: 278 | left_sample[k] = sample_to_split[k][:left_img_idx] 279 | if right_sample is not None: 280 | right_sample[k] = sample_to_split[k][right_img_idx:] 281 | else: 282 | raise NotImplementedError(f'find unsupported keys: {k} from {sample_to_split.keys()}') 283 | return left_sample, right_sample 284 | 285 | splitted_buffer = [] 286 | while buffer['input_ids'].size(0) > max_tokens: 287 | img_start_idx_list = (buffer['input_ids'] == img_start_token_id).nonzero().squeeze(1).tolist() 288 | img_end_idx_list = (buffer['input_ids'] == img_end_token_id).nonzero().squeeze(1).tolist() 289 | assert len(img_start_idx_list) == len(img_end_idx_list) 290 | 291 | if _image_is_splitted(buffer['input_ids'], max_tokens): 292 | cut_idx = bisect.bisect_left(img_start_idx_list, max_tokens) 293 | if buffer['input_ids'][max_tokens] == img_start_token_id: 294 | assert max_tokens == img_start_idx_list[cut_idx] 295 | cut_left_idx = img_start_idx_list[cut_idx] 296 | cut_left_img_idx = cut_idx 297 | else: 298 | cut_left_idx = img_start_idx_list[cut_idx - 1] 299 | cut_left_img_idx = cut_idx - 1 300 | cut_right_idx = cut_left_idx 301 | cut_right_img_idx = cut_left_img_idx 302 | else: 303 | cut_img_idx = bisect.bisect(img_start_idx_list, max_tokens) 304 | if cut_img_idx < len(img_start_idx_list): 305 | cut_right_idx = img_start_idx_list[cut_img_idx] 306 | cut_right_img_idx = cut_img_idx 307 | else: 308 | cut_right_idx = None 309 | cut_right_img_idx = None 310 | 311 | cut_left_idx = max_tokens 312 | cut_left_img_idx = cut_right_img_idx if cut_right_img_idx is not None else buffer['pixel_values'].size(0) 313 | 314 | left, right = _split( 315 | sample_to_split=buffer, 316 | left_idx=cut_left_idx, 317 | left_img_idx=cut_left_img_idx, 318 | right_idx=cut_right_idx, 319 | right_img_idx=cut_right_img_idx, 320 | ) 321 | 322 | assert (left['input_ids'] == img_end_token_id).sum() == (left['input_ids'] == img_start_token_id).sum() == left['pixel_values'].size(0) 323 | if right is not None: 324 | assert (right['input_ids'] == img_end_token_id).sum() == (right['input_ids'] == img_start_token_id).sum() == right['pixel_values'].size(0) 325 | 326 | if left['pixel_values'].size(0) >= 1 and PackedDataset.check_valid(left): 327 | splitted_buffer.append(left) 328 | 329 | if right is None or right['pixel_values'].size(0) == 0: 330 | break 331 | 332 | buffer = right 333 | if buffer['input_ids'].size(0) <= max_tokens and PackedDataset.check_valid(buffer): 334 | 
splitted_buffer.append(buffer) 335 | break 336 | 337 | logger.debug( 338 | f'split a sample into {len(splitted_buffer)} samples, ' 339 | f'current max_tokens={max_tokens}' 340 | ) 341 | return splitted_buffer 342 | 343 | def update_buffer_list(self, buffer_list, buffer_max_len_list, buffer): 344 | # NOTE: in-place operation 345 | 346 | splitted_buffer = PackedDataset.split_buffer( 347 | buffer=buffer, 348 | max_tokens=self.max_packed_tokens, 349 | img_start_token_id=self.img_start_token_id, 350 | img_token_id=self.img_token_id, 351 | img_end_token_id=self.img_end_token_id, 352 | ) 353 | 354 | for each_buffer in splitted_buffer: 355 | if each_buffer['pixel_values'].size(0) > self.num_images_expected: 356 | logger.error( 357 | f"Find a sample with {each_buffer['pixel_values'].size(0)} images, " 358 | f'which exceeds {self.num_images_expected}' 359 | ) 360 | continue 361 | 362 | if each_buffer['input_ids'].size(0) >= self.max_packed_tokens: 363 | assert each_buffer['input_ids'].size(0) == self.max_packed_tokens 364 | buffer_max_len_list.append(each_buffer) 365 | continue 366 | 367 | find_idx = len(buffer_list) 368 | num_images_new_sample = each_buffer['pixel_values'].size(0) 369 | for buffer_idx in range(len(buffer_list)): 370 | if buffer_list[buffer_idx]['pixel_values'].size(0) < num_images_new_sample: 371 | find_idx = buffer_idx 372 | break 373 | buffer_list.insert(find_idx, each_buffer) 374 | 375 | for i in range(1, len(buffer_list)): 376 | assert buffer_list[i-1]['pixel_values'].size(0) >= buffer_list[i]['pixel_values'].size(0) 377 | 378 | return buffer_list, buffer_max_len_list 379 | 380 | def pad_buffer(self, buffer): 381 | if buffer['pixel_values'].size(0) == self.num_images_expected: 382 | return buffer 383 | 384 | num_pad_images = self.num_images_expected - buffer['pixel_values'].size(0) 385 | pad_images = torch.stack([ 386 | torch.zeros_like(buffer['pixel_values'][0]) 387 | for _ in range(num_pad_images) 388 | ]) 389 | pad_image_flags = torch.tensor([0] * num_pad_images, dtype=torch.long) 390 | 391 | buffer['pixel_values'] = torch.cat([buffer['pixel_values'], pad_images]) 392 | buffer['image_flags'] = torch.cat([buffer['image_flags'], pad_image_flags]) 393 | 394 | return buffer 395 | 396 | def postprocess_buffer(self, buffer, custom_infos=None): 397 | buffer['worker_state_key'] = self.worker_state_key 398 | buffer['worker_state_dict'] = self._state_dict 399 | if custom_infos is not None: 400 | buffer['custom_infos'] = {self.worker_state_key: copy.deepcopy(custom_infos)} 401 | return buffer 402 | 403 | def print_log(self, iter_idx, buffer_list): 404 | if iter_idx % self.log_freq != 0: 405 | return 406 | 407 | if self._should_log(): 408 | logger.info( 409 | f"{iter_idx=}, {len(buffer_list)=}, {self._state_dict['sample_info']}" 410 | ) 411 | 412 | def __iter__(self): 413 | iter_idx = 0 414 | buffer_list = [] 415 | buffer_max_len_list = [] 416 | 417 | if self._should_log(): 418 | logger.info(f'Begin to iter, {len(buffer_list)=}') 419 | 420 | worker_id = 0 if get_worker_info() is None else get_worker_info().id 421 | num_workers = 1 if get_worker_info() is None else get_worker_info().num_workers 422 | 423 | worker_id = num_workers * self.data_rank + worker_id 424 | num_workers = num_workers * self.data_world_size 425 | 426 | rng = np.random.default_rng(seed=worker_id) 427 | 428 | # reset states of each dataset 429 | self.worker_id = worker_id 430 | self.worker_state_key = f'work_state_{self.worker_id}' 431 | self.datasets = [d for d in self.datasets_orig] 432 | self.dataset_weight = [w 
for w in self.dataset_weight_orig] 433 | self.dataset_iter_list = [iter(d) for d in self.datasets] 434 | 435 | for ds in self.datasets: 436 | # if not isinstance(ds, (ImageTextPairDataset, InterleavedDataset)): 437 | ds.worker_id = worker_id 438 | ds.worker_state_key = f'work_state_{self.worker_id}' 439 | ds.num_workers = num_workers 440 | if self._should_log() and worker_id == 0: 441 | logger.info(f'set worker_id and num_workers of {ds.__class__.__name__} {ds.ds_name}') 442 | 443 | if self.worker_custom_infos is not None and self.worker_state_key in self.worker_custom_infos: 444 | custom_infos = self.worker_custom_infos[self.worker_state_key] 445 | # buffer list 446 | if 'buffer_list' in custom_infos and isinstance(custom_infos['buffer_list'], list): 447 | buffer_list = custom_infos['buffer_list'] 448 | if self._should_log() and worker_id == 0: 449 | logger.info(f'[{self.worker_state_key}] load buffer list --> {len(buffer_list)=}') 450 | # other infos 451 | 452 | # reset 453 | self.worker_custom_infos = None 454 | 455 | logger.debug( 456 | f'{self.__class__.__name__} Rank {self.data_rank} ' 457 | f'Worker {worker_id} begin to load data' 458 | ) 459 | 460 | while True: 461 | self.dataset_weight = [w / sum(self.dataset_weight) for w in self.dataset_weight] 462 | current_dataset_idx = rng.choice(len(self.dataset_iter_list), p=self.dataset_weight) 463 | 464 | try: 465 | current_sample = self.next_data(current_dataset_idx) 466 | except: 467 | logger.info(f'All datasets are exhausted, begin to empty the buffer_list ({len(buffer_list)=})') 468 | while len(buffer_list) > 0: 469 | if self.strict_mode: 470 | yield self.postprocess_buffer(self.pad_buffer(buffer_list.pop(0))) 471 | else: 472 | yield self.postprocess_buffer(buffer_list.pop(0)) 473 | logger.info(f'buffer_list is empty! 
({len(buffer_list)=})') 474 | return 475 | 476 | buffer = self.find_buffer(buffer_list, current_sample) 477 | buffer = self.update_buffer(buffer, current_sample) 478 | buffer_list, buffer_max_len_list = self.update_buffer_list(buffer_list, buffer_max_len_list, buffer) 479 | 480 | while len(buffer_max_len_list) > 0: 481 | if buffer_max_len_list[0]['pixel_values'].size(0) != self.max_packed_tokens: 482 | logger.debug( 483 | f'num tokens of a buffer exceed {self.max_packed_tokens=}, ' 484 | f"yield a sample with {buffer_max_len_list[0]['pixel_values'].size(0)} images" 485 | ) 486 | if self.strict_mode and buffer_max_len_list[0]['pixel_values'].size(0) != self.num_images_expected: 487 | # buffer_max_len_list.pop(0) 488 | yield self.postprocess_buffer(self.pad_buffer(buffer_max_len_list.pop(0)), {'buffer_list': buffer_list}) 489 | else: 490 | yield self.postprocess_buffer(buffer_max_len_list.pop(0), {'buffer_list': buffer_list}) 491 | 492 | while len(buffer_list) > 0 and buffer_list[0]['pixel_values'].size(0) > self.num_images_expected: 493 | logger.error( 494 | f"num images of a buffer ({buffer_list[0]['pixel_values'].size(0)}) " 495 | f'is larger than num_images_expected({self.num_images_expected})' 496 | ) 497 | buffer_list.pop(0) 498 | 499 | while len(buffer_list) > 0 and buffer_list[0]['pixel_values'].size(0) == self.num_images_expected: 500 | if self.debug_mode: 501 | debug_data = self.postprocess_buffer(buffer_list.pop(0), {'buffer_list': buffer_list}) 502 | while True: 503 | yield debug_data.copy() 504 | 505 | yield self.postprocess_buffer(buffer_list.pop(0), {'buffer_list': buffer_list}) 506 | 507 | while len(buffer_list) > self.max_buffer_size: 508 | logger.debug( 509 | f'Failed to pack data to exactly {self.num_images_expected} images, ' 510 | f"yield a data sample with {buffer_list[0]['pixel_values'].size(0)} images." 
511 | ) 512 | if self.strict_mode: 513 | yield self.postprocess_buffer(self.pad_buffer(buffer_list.pop(0)), {'buffer_list': buffer_list}) 514 | else: 515 | yield self.postprocess_buffer(buffer_list.pop(0), {'buffer_list': buffer_list}) 516 | 517 | self.print_log(iter_idx=iter_idx, buffer_list=buffer_list) 518 | iter_idx += 1 519 | 520 | @staticmethod 521 | def get_cu_seqlens_and_indexes( 522 | data_index: torch.LongTensor, # (seq_len,) 523 | input_ids: torch.LongTensor, # (seq_len,) 524 | labels: torch.LongTensor, # (seq_len,) 525 | len2weight: callable, 526 | ): 527 | indexes = [] 528 | cu_seqlens = [0] 529 | loss_weight = [] 530 | 531 | start = data_index.min() 532 | end = data_index.max() + 1 533 | for i in range(start, end): 534 | num_tokens = (data_index == i).sum().item() 535 | indexes.extend(list(range(num_tokens))) 536 | cu_seqlens.append(cu_seqlens[-1] + num_tokens) 537 | assert num_tokens > 0 538 | 539 | curr_data_index = data_index[cu_seqlens[-2]:cu_seqlens[-2]+num_tokens] 540 | assert (curr_data_index == i).all(), data_index 541 | 542 | curr_labels = labels[cu_seqlens[-2]:cu_seqlens[-2]+num_tokens] 543 | num_effective_tokens = (curr_labels != IGNORE_TOKEN_ID).sum().item() 544 | loss_weight.extend([len2weight(num_effective_tokens)] * num_tokens) 545 | 546 | assert len(indexes) == data_index.size(0), f'{len(indexes)=}, {data_index.size(0)=}' 547 | 548 | loss_weight = torch.tensor(loss_weight, dtype=torch.float32) 549 | return cu_seqlens, indexes, loss_weight 550 | 551 | 552 | WARNING_CNT = defaultdict(int) 553 | 554 | 555 | def packed_collate_fn( 556 | features, 557 | data_collator, 558 | len2weight: callable, 559 | max_item_length: int, 560 | micro_num: int = 1, 561 | loss_reduction_all_gather: bool = False, 562 | pad_id: int = 0, 563 | ): 564 | if not isinstance(features, list): 565 | features = [features] 566 | 567 | if len(features) > micro_num: 568 | raise NotImplementedError(f'{len(features)=} > {micro_num=}') 569 | 570 | if len(features) < micro_num and WARNING_CNT['micro_num_warning'] < 5: 571 | logger.warning( 572 | f'{len(features)=} > {micro_num=}, ' 573 | f'the features will be padded to satisfy micro_num requirement' 574 | ) 575 | WARNING_CNT['micro_num_warning'] += 1 576 | 577 | # ensure that the len(features) is equal to the required micro_num 578 | num_features = len(features) 579 | while len(features) < micro_num: 580 | features.append(copy.deepcopy(features[0])) 581 | features[-1]['labels'] = torch.full_like(features[-1]['labels'], IGNORE_TOKEN_ID) 582 | 583 | indexes = [] 584 | cu_seqlens = [] 585 | cu_num_images_list = [0] 586 | 587 | worker_state_key_list = [] 588 | worker_state_dict_list = [] 589 | worker_state_custom_infos_list = [] 590 | 591 | batch_lens = [feat['input_ids'].shape for feat in features] 592 | max_item_length = max_item_length or max(batch_lens)[0] 593 | 594 | num_samples = 0 595 | num_padding_tokens = 0 596 | for feat_idx, feat in enumerate(features): 597 | data_index = feat.pop('data_index') 598 | curr_cu_seqlens, curr_indexes, curr_loss_weight = PackedDataset.get_cu_seqlens_and_indexes( 599 | data_index=data_index, 600 | input_ids=feat['input_ids'], 601 | labels=feat['labels'], 602 | len2weight=len2weight, 603 | ) 604 | 605 | feat['loss_weight'] = curr_loss_weight 606 | 607 | if feat_idx < num_features: 608 | num_samples += len(curr_cu_seqlens) - 1 609 | 610 | if curr_cu_seqlens[-1] < max_item_length: 611 | curr_cu_seqlens.append(max_item_length) 612 | curr_indexes.extend(list(range(max_item_length - curr_cu_seqlens[-2]))) 613 | 614 | 
indexes.append(torch.tensor(curr_indexes, dtype=torch.long)) 615 | cu_seqlens.append(torch.tensor(curr_cu_seqlens, dtype=torch.int32)) 616 | 617 | worker_state_key_list.append(feat.pop('worker_state_key')) 618 | worker_state_dict_list.append(feat.pop('worker_state_dict')) 619 | worker_state_custom_infos_list.append(feat.pop('custom_infos', None)) 620 | 621 | num_padding_tokens += (max_item_length - feat['input_ids'].size(0)) 622 | cu_num_images_list.append(cu_num_images_list[-1] + feat['pixel_values'].size(0)) 623 | 624 | batch = data_collator(features=features, max_item_length=max_item_length, pad_id=pad_id) 625 | # convert it to list in case it is converted into bf16 626 | batch['loss_weight'] = torch.where(batch['labels'] == IGNORE_TOKEN_ID, 0, batch['loss_weight']).tolist() 627 | batch['attention_mask'] = torch.stack(cu_seqlens) 628 | batch['loss_reduction_all_gather'] = loss_reduction_all_gather 629 | batch['statistics'] = torch.tensor( 630 | [ 631 | num_samples, 632 | num_padding_tokens, 633 | batch['image_flags'].numel() - batch['image_flags'].sum().item(), 634 | ], 635 | dtype=torch.long, 636 | ) 637 | batch.pop('type_ids') 638 | return batch 639 | -------------------------------------------------------------------------------- /model_vllm/__init__.py: -------------------------------------------------------------------------------- 1 | from vllm import ModelRegistry 2 | from .hunyuan import HunYuanForCausalLM 3 | from .hunyuan_video import HunyuanVideoModel 4 | from .video_audio_encoder import VideoAudioEncoder 5 | from .video_audio_llm import VideoAudioLLM 6 | 7 | ModelRegistry.register_model("HunYuanForCausalLM", HunYuanForCausalLM) 8 | ModelRegistry.register_model("HunyuanVideoModel", HunyuanVideoModel) 9 | -------------------------------------------------------------------------------- /model_vllm/__pycache__/__init__.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TencentARC/ARC-Hunyuan-Video-7B/a061f1bf6a4d75f032759fec69c80c71077c3d0c/model_vllm/__pycache__/__init__.cpython-310.pyc -------------------------------------------------------------------------------- /model_vllm/__pycache__/__init__.cpython-311.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TencentARC/ARC-Hunyuan-Video-7B/a061f1bf6a4d75f032759fec69c80c71077c3d0c/model_vllm/__pycache__/__init__.cpython-311.pyc -------------------------------------------------------------------------------- /model_vllm/__pycache__/hunyuan.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TencentARC/ARC-Hunyuan-Video-7B/a061f1bf6a4d75f032759fec69c80c71077c3d0c/model_vllm/__pycache__/hunyuan.cpython-310.pyc -------------------------------------------------------------------------------- /model_vllm/__pycache__/hunyuan.cpython-311.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TencentARC/ARC-Hunyuan-Video-7B/a061f1bf6a4d75f032759fec69c80c71077c3d0c/model_vllm/__pycache__/hunyuan.cpython-311.pyc -------------------------------------------------------------------------------- /model_vllm/__pycache__/hunyuan_video.cpython-310.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/TencentARC/ARC-Hunyuan-Video-7B/a061f1bf6a4d75f032759fec69c80c71077c3d0c/model_vllm/__pycache__/hunyuan_video.cpython-310.pyc -------------------------------------------------------------------------------- /model_vllm/__pycache__/hunyuan_video.cpython-311.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TencentARC/ARC-Hunyuan-Video-7B/a061f1bf6a4d75f032759fec69c80c71077c3d0c/model_vllm/__pycache__/hunyuan_video.cpython-311.pyc -------------------------------------------------------------------------------- /model_vllm/__pycache__/monkey_patch_mrope.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TencentARC/ARC-Hunyuan-Video-7B/a061f1bf6a4d75f032759fec69c80c71077c3d0c/model_vllm/__pycache__/monkey_patch_mrope.cpython-310.pyc -------------------------------------------------------------------------------- /model_vllm/__pycache__/monkey_patch_mrope.cpython-311.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TencentARC/ARC-Hunyuan-Video-7B/a061f1bf6a4d75f032759fec69c80c71077c3d0c/model_vllm/__pycache__/monkey_patch_mrope.cpython-311.pyc -------------------------------------------------------------------------------- /model_vllm/__pycache__/video_audio_encoder.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TencentARC/ARC-Hunyuan-Video-7B/a061f1bf6a4d75f032759fec69c80c71077c3d0c/model_vllm/__pycache__/video_audio_encoder.cpython-310.pyc -------------------------------------------------------------------------------- /model_vllm/__pycache__/video_audio_encoder.cpython-311.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TencentARC/ARC-Hunyuan-Video-7B/a061f1bf6a4d75f032759fec69c80c71077c3d0c/model_vllm/__pycache__/video_audio_encoder.cpython-311.pyc -------------------------------------------------------------------------------- /model_vllm/__pycache__/video_audio_llm.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TencentARC/ARC-Hunyuan-Video-7B/a061f1bf6a4d75f032759fec69c80c71077c3d0c/model_vllm/__pycache__/video_audio_llm.cpython-310.pyc -------------------------------------------------------------------------------- /model_vllm/__pycache__/video_audio_llm.cpython-311.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TencentARC/ARC-Hunyuan-Video-7B/a061f1bf6a4d75f032759fec69c80c71077c3d0c/model_vllm/__pycache__/video_audio_llm.cpython-311.pyc -------------------------------------------------------------------------------- /model_vllm/hunyuan.py: -------------------------------------------------------------------------------- 1 | from functools import partial 2 | from typing import Any, Dict, Iterable, Optional, Set, Tuple, Type, Union, List 3 | 4 | import torch 5 | from torch import nn 6 | from transformers import PretrainedConfig 7 | 8 | from vllm.attention.layer import Attention 9 | from vllm.compilation.decorators import support_torch_compile 10 | from vllm.config import CacheConfig, VllmConfig 11 | from vllm.distributed import ( 12 | get_pp_group, 13 | get_tensor_model_parallel_rank, 14 | get_tensor_model_parallel_world_size, 15 
| split_tensor_along_last_dim, 16 | tensor_model_parallel_all_gather, 17 | ) 18 | from vllm.model_executor.layers.activation import SiluAndMul 19 | from vllm.model_executor.layers.layernorm import RMSNorm 20 | from vllm.model_executor.layers.linear import ( 21 | MergedColumnParallelLinear, 22 | QKVParallelLinear, 23 | RowParallelLinear, 24 | ReplicatedLinear, 25 | ) 26 | from vllm.model_executor.layers.logits_processor import LogitsProcessor 27 | from vllm.model_executor.layers.pooler import Pooler, PoolingType 28 | from vllm.model_executor.layers.quantization import QuantizationConfig 29 | from vllm.model_executor.layers.rotary_embedding import ( 30 | get_rope, 31 | RotaryEmbedding, 32 | MRotaryEmbedding, 33 | _apply_rotary_emb, 34 | ) 35 | from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler 36 | from vllm.model_executor.layers.vocab_parallel_embedding import ( 37 | ParallelLMHead, 38 | VocabParallelEmbedding, 39 | ) 40 | from vllm.model_executor.model_loader.weight_utils import default_weight_loader 41 | 42 | from vllm.model_executor.pooling_metadata import PoolingMetadata 43 | from vllm.model_executor.sampling_metadata import SamplingMetadata 44 | from vllm.sequence import IntermediateTensors, PoolerOutput 45 | 46 | from vllm.model_executor.models.interfaces import SupportsLoRA, SupportsPP 47 | from vllm.model_executor.models.utils import ( 48 | is_pp_missing_parameter, 49 | make_empty_intermediate_tensors_factory, 50 | make_layers, 51 | maybe_prefix, 52 | AutoWeightsLoader, 53 | ) 54 | 55 | 56 | class DynamicNTKAlphaRotaryEmbedding(RotaryEmbedding): 57 | 58 | def __init__( 59 | self, 60 | head_size: int, 61 | rotary_dim: int, 62 | max_position_embeddings: int, 63 | base: int, 64 | scaling_alpha: float, 65 | dtype: torch.dtype, 66 | is_neox_style: bool = True, 67 | ) -> None: 68 | self.scaling_alpha = scaling_alpha 69 | super().__init__( 70 | head_size, 71 | rotary_dim, 72 | max_position_embeddings, 73 | base, 74 | is_neox_style, 75 | dtype, 76 | ) 77 | 78 | def _compute_cos_sin_cache(self) -> torch.Tensor: 79 | # NOTE(woosuk): self.max_position_embeddings is the original 80 | # maximum length before applying the rope scaling. 81 | # Thus, the maximum length after applying the rope scaling is 82 | # self.max_position_embeddings * self.scaling_alpha. 
83 | max_len = self.max_position_embeddings * self.scaling_alpha 84 | base = self.base * self.scaling_alpha ** ( 85 | self.rotary_dim / (self.rotary_dim - 2) 86 | ) 87 | inv_freq = 1.0 / ( 88 | base 89 | ** (torch.arange(0, self.rotary_dim, 2).float() / self.rotary_dim) 90 | ) 91 | t = torch.arange(max_len, dtype=torch.float) 92 | freqs = torch.einsum("i,j -> ij", t, inv_freq) 93 | cos = freqs.cos() 94 | sin = freqs.sin() 95 | cache = torch.cat((cos, sin), dim=-1) 96 | return cache 97 | 98 | 99 | def rotate_half(x): 100 | """Rotates half the hidden dims of the input.""" 101 | x1 = x[..., : x.shape[-1] // 2] 102 | x2 = x[..., x.shape[-1] // 2:] 103 | return torch.cat((-x2, x1), dim=-1) 104 | 105 | 106 | class DynamicNTKAlphaMRotaryEmbedding(MRotaryEmbedding): 107 | 108 | def __init__( 109 | self, 110 | head_size: int, 111 | rotary_dim: int, 112 | max_position_embeddings: int, 113 | base: int, 114 | scaling_alpha: float, 115 | dtype: torch.dtype, 116 | mrope_section: Optional[List[int]] = None, 117 | is_neox_style: bool = True, 118 | max_model_len: bool = None, 119 | ) -> None: 120 | self.scaling_alpha = scaling_alpha 121 | self.max_model_len = max_model_len 122 | assert len(mrope_section) == 4, "Currently only 4D is supported" 123 | mrope_section = [int(x * rotary_dim // 2) for x in mrope_section] 124 | 125 | # MRotaryEmbedding will enlarge the max_position_embeddings by 4 126 | # To keep consistent with the original max_position_embeddings, 127 | # we need to divide the max_position_embeddings by 4 128 | max_position_embeddings = max_position_embeddings // 4 129 | 130 | super().__init__( 131 | head_size, 132 | rotary_dim, 133 | max_position_embeddings, 134 | base, 135 | is_neox_style, 136 | dtype, 137 | mrope_section, 138 | ) 139 | 140 | def _compute_cos_sin_cache(self) -> torch.Tensor: 141 | if self.max_model_len is not None: 142 | max_len = self.max_model_len 143 | else: 144 | max_len = self.max_position_embeddings * self.scaling_alpha 145 | 146 | 147 | base = self.base * self.scaling_alpha ** ( 148 | self.rotary_dim / (self.rotary_dim - 2) 149 | ) 150 | inv_freq = 1.0 / ( 151 | base 152 | ** (torch.arange(0, self.rotary_dim, 2).float() / self.rotary_dim) 153 | ) 154 | t = torch.arange(max_len, dtype=torch.float) 155 | freqs = torch.einsum("i,j -> ij", t, inv_freq) 156 | freqs = torch.cat((freqs, freqs), dim=-1) 157 | cos = freqs.cos() 158 | sin = freqs.sin() 159 | cache = torch.cat((cos, sin), dim=-1) 160 | return cache 161 | 162 | 163 | def forward( 164 | self, 165 | positions: torch.Tensor, 166 | query: torch.Tensor, 167 | key: torch.Tensor, 168 | ) -> Tuple[torch.Tensor, torch.Tensor]: 169 | """XDRope implementation following apply_rotary_pos_emb_xdrope pattern. 
170 | 171 | Args: 172 | positions: 173 | [num_tokens,] (text only) or 174 | [4, num_tokens] (4D positions with multimodal inputs) 175 | query: [num_tokens, num_heads * head_size] 176 | key: [num_tokens, num_kv_heads * head_size] 177 | """ 178 | assert positions.ndim == 2, f"positions must be 2D, but got {positions.shape}" 179 | 180 | num_tokens = positions.shape[-1] 181 | cos_sin = self.cos_sin_cache[positions] 182 | cos, sin = cos_sin.chunk(2, dim=-1) 183 | 184 | x_dim = len(self.mrope_section) 185 | 186 | cos = cos.permute(1, 0, 2).reshape(-1, x_dim, self.rotary_dim) 187 | sin = sin.permute(1, 0, 2).reshape(-1, x_dim, self.rotary_dim) 188 | 189 | xdrope_section = self.mrope_section * 2 190 | assert sum(xdrope_section) == self.rotary_dim 191 | 192 | cos = torch.cat([ 193 | m[:, i % x_dim] for i, m in enumerate(cos.split(xdrope_section, dim=-1)) 194 | ], dim=-1) 195 | sin = torch.cat([ 196 | m[:, i % x_dim] for i, m in enumerate(sin.split(xdrope_section, dim=-1)) 197 | ], dim=-1) 198 | 199 | cos = cos.view(1, -1, self.rotary_dim) 200 | sin = sin.view(1, -1, self.rotary_dim) 201 | 202 | cos = cos.permute(1, 0, 2) 203 | sin = sin.permute(1, 0, 2) 204 | 205 | query_shape = query.shape 206 | query = query.view(num_tokens, -1, self.head_size) 207 | query = (query * cos) + rotate_half(query) * sin 208 | query = query.reshape(query_shape) 209 | 210 | key_shape = key.shape 211 | key = key.view(num_tokens, -1, self.head_size) 212 | key = (key * cos) + rotate_half(key) * sin 213 | key = key.reshape(key_shape) 214 | 215 | return query, key 216 | 217 | 218 | @classmethod 219 | def get_input_positions( 220 | cls, 221 | input_tokens: List[int], 222 | hf_config: PretrainedConfig, 223 | image_grid_thw: Optional[Union[List[List[int]], torch.Tensor]], 224 | video_grid_thw: Optional[Union[List[List[int]], torch.Tensor]], 225 | second_per_grid_ts: Optional[List[float]], 226 | context_len: int = 0, 227 | seq_len: Optional[int] = None, 228 | audio_feature_lengths: Optional[torch.Tensor] = None, 229 | use_audio_in_video: bool = False, 230 | ) -> Tuple[List[List[int]], int]: 231 | """Get xdrope input positions and delta value.""" 232 | 233 | image_grid_thw = [] if image_grid_thw is None else image_grid_thw 234 | video_grid_thw = [] if video_grid_thw is None else video_grid_thw 235 | second_per_grid_ts = [] if second_per_grid_ts is None else \ 236 | second_per_grid_ts 237 | 238 | llm_positions, mrope_position_delta = \ 239 | cls.get_input_positions_tensor( 240 | input_tokens=input_tokens, 241 | hf_config=hf_config, 242 | image_grid_thw=image_grid_thw, 243 | video_grid_thw=video_grid_thw, 244 | second_per_grid_ts=second_per_grid_ts, 245 | context_len=context_len, 246 | seq_len=seq_len, 247 | audio_feature_lengths=audio_feature_lengths, 248 | use_audio_in_video=use_audio_in_video, 249 | ) 250 | 251 | return llm_positions.tolist(), mrope_position_delta 252 | 253 | 254 | @classmethod 255 | def get_input_positions_tensor( 256 | cls, 257 | input_tokens: List[int], 258 | hf_config: PretrainedConfig, 259 | image_grid_thw: Union[List[List[int]], torch.Tensor], 260 | video_grid_thw: Union[List[List[int]], torch.Tensor], 261 | second_per_grid_ts: List[float], 262 | context_len: int = 0, 263 | seq_len: Optional[int] = None, 264 | audio_feature_lengths: Optional[torch.Tensor] = None, 265 | use_audio_in_video: bool = False, 266 | ) -> Tuple[torch.Tensor, int]: 267 | return cls._vl_get_input_positions_tensor( 268 | input_tokens=input_tokens, 269 | hf_config=hf_config, 270 | image_grid_thw=image_grid_thw, 271 | 
video_grid_thw=video_grid_thw, 272 | second_per_grid_ts=second_per_grid_ts, 273 | context_len=context_len, 274 | seq_len=seq_len, 275 | ) 276 | 277 | 278 | @classmethod 279 | def _vl_get_input_positions_tensor( 280 | cls, 281 | input_tokens: List[int], 282 | hf_config: PretrainedConfig, 283 | image_grid_thw: Union[List[List[int]], torch.Tensor], 284 | video_grid_thw: Union[List[List[int]], torch.Tensor], 285 | second_per_grid_ts: List[float], 286 | context_len: int = 0, 287 | seq_len: Optional[int] = None, 288 | ) -> Tuple[torch.Tensor, int]: 289 | """Get xdrope input positions following get_xdrope_position_ids pattern.""" 290 | 291 | 292 | image_token_id = hf_config.image_token_id 293 | vision_start_token_id = hf_config.vision_start_token_id 294 | 295 | input_tokens_tensor = torch.tensor(input_tokens) 296 | 297 | # Initialize 4D position embeddings (following xdrope pattern) 298 | seq_length = len(input_tokens) 299 | position_ids_seq = torch.arange(seq_length) # Sequential positions 300 | position_ids_t = position_ids_seq.clone() 301 | position_ids_x = position_ids_seq.clone() 302 | position_ids_y = position_ids_seq.clone() 303 | 304 | vision_start_indices = torch.argwhere( 305 | input_tokens_tensor == vision_start_token_id).squeeze(1) 306 | 307 | if len(vision_start_indices) == 0: 308 | # No vision tokens, return 4D sequential positions 309 | llm_positions = torch.stack([position_ids_seq, position_ids_x, position_ids_y, position_ids_t]) 310 | mrope_position_delta = 0 311 | llm_positions = llm_positions[:, context_len:seq_len] 312 | return llm_positions, mrope_position_delta 313 | 314 | # Process vision tokens using image_grid_thw information 315 | image_index, video_index = 0, 0 316 | current_pos = 0 317 | 318 | for start_idx in vision_start_indices: 319 | start_idx = start_idx.item() 320 | 321 | # Determine if this is image or video token 322 | if start_idx + 1 < len(input_tokens): 323 | next_token = input_tokens[start_idx + 1] 324 | is_image = (next_token == image_token_id) 325 | 326 | if is_image and image_index < len(image_grid_thw): 327 | t, h, w = image_grid_thw[image_index] 328 | image_index += 1 329 | else: 330 | continue 331 | 332 | # Calculate grid dimensions 333 | llm_grid_t, llm_grid_h, llm_grid_w = ( 334 | t, h, w 335 | ) 336 | 337 | # Find end of vision tokens (approximate) 338 | vision_token_count = llm_grid_t * llm_grid_h * llm_grid_w 339 | end_idx = min(start_idx + vision_token_count + 2, seq_length) # +2 for start/end tokens 340 | 341 | # Apply xdrope position assignment pattern 342 | if end_idx > start_idx + 2: # Ensure we have vision tokens 343 | # Reset time dimension for vision tokens (following get_xdrope_position_ids) 344 | position_ids_t[start_idx + 2:end_idx] = current_pos 345 | current_pos += 1 346 | 347 | # Calculate row and column for 2D layout 348 | vision_tokens_between = end_idx - start_idx - 2 # excluding start/end 349 | if llm_grid_h > 0: 350 | tokens_per_row = llm_grid_w 351 | num_rows = llm_grid_h 352 | 353 | # Assign x,y coordinates following the pattern 354 | idx_xy = 0 355 | for rr in range(num_rows): 356 | for cc in range(tokens_per_row): 357 | if start_idx + 2 + idx_xy < end_idx: 358 | position_ids_x[start_idx + 2 + idx_xy] = cc 359 | position_ids_y[start_idx + 2 + idx_xy] = rr 360 | idx_xy += 1 361 | 362 | # Stack into 4D positions 363 | llm_positions = torch.stack([position_ids_seq, position_ids_x, position_ids_y, position_ids_t]) 364 | mrope_position_delta = (llm_positions.max() + 1 - len(input_tokens)).item() 365 | llm_positions = 
llm_positions[:, context_len:seq_len] 366 | 367 | return llm_positions, mrope_position_delta 368 | 369 | 370 | @staticmethod 371 | def get_next_input_positions( 372 | mrope_position_delta: int, 373 | context_len: int, 374 | seq_len: int, 375 | ) -> List[List[int]]: 376 | return [ 377 | list( 378 | range(context_len + mrope_position_delta, 379 | seq_len + mrope_position_delta)) for _ in range(4) # Changed from 3 to 4 380 | ] 381 | 382 | 383 | @staticmethod 384 | def get_next_input_positions_tensor( 385 | mrope_position_delta: int, 386 | context_len: int, 387 | seq_len: int, 388 | ) -> torch.Tensor: 389 | return torch.arange( 390 | mrope_position_delta + context_len, 391 | mrope_position_delta + seq_len, 392 | ).expand(4, -1) # Changed from 3 to 4 393 | 394 | 395 | class HunyuanMLP(nn.Module): 396 | def __init__( 397 | self, 398 | config: PretrainedConfig, 399 | hidden_size: int, 400 | intermediat_size: int, 401 | hidden_act: str, 402 | quant_config: Optional[QuantizationConfig] = None, 403 | prefix: str = "", 404 | ) -> None: 405 | super().__init__() 406 | self.gate_and_up_proj = MergedColumnParallelLinear( 407 | hidden_size, 408 | [intermediat_size] * 2, 409 | bias=config.mlp_bias, 410 | quant_config=quant_config, 411 | prefix=f"{prefix}.gate_and_up_proj", 412 | ) 413 | # self.down_proj = ReplicatedLinear( 414 | # intermediat_size, 415 | # hidden_size, 416 | # bias=config.mlp_bias, 417 | # quant_config=quant_config, 418 | # prefix=f"{prefix}.down_proj", 419 | # ) 420 | self.down_proj = nn.Linear( 421 | intermediat_size, 422 | hidden_size, 423 | bias=config.mlp_bias, 424 | ) 425 | self.act_fn = SiluAndMul() 426 | 427 | def forward(self, x): 428 | gate_up, _ = self.gate_and_up_proj(x) 429 | x = self.act_fn(gate_up) 430 | x = self.down_proj(x) 431 | return x 432 | 433 | 434 | class HunYuanAttention(nn.Module): 435 | def __init__( 436 | self, 437 | config: PretrainedConfig, 438 | hidden_size: int, 439 | num_heads: int, 440 | num_kv_heads: int, 441 | rope_theta: float = 10000, 442 | rope_scaling: Optional[Dict[str, Any]] = None, 443 | max_position_embeddings: int = 8192, 444 | attention_bias: bool = False, 445 | cache_config: Optional[CacheConfig] = None, 446 | quant_config: Optional[QuantizationConfig] = None, 447 | prefix: str = "", 448 | ): 449 | super().__init__() 450 | self.hidden_size = hidden_size 451 | self.tp_size = get_tensor_model_parallel_world_size() 452 | self.tp_rank = get_tensor_model_parallel_rank() 453 | self.total_num_heads = num_heads 454 | assert self.total_num_heads % self.tp_size == 0 455 | self.num_heads = self.total_num_heads // self.tp_size 456 | self.total_num_kv_heads = num_kv_heads 457 | if self.total_num_kv_heads >= self.tp_size: 458 | # Number of KV heads is greater than TP size, so we partition 459 | # the KV heads across multiple tensor parallel GPUs. 460 | assert self.total_num_kv_heads % self.tp_size == 0 461 | else: 462 | # Number of KV heads is less than TP size, so we replicate 463 | # the KV heads across multiple tensor parallel GPUs. 
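            # In this replicated case every TP rank keeps a full copy of the
            # KV heads, so num_kv_heads per rank is clamped to at least 1 below.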
464 | assert self.tp_size % self.total_num_kv_heads == 0 465 | self.num_kv_heads = max(1, self.total_num_kv_heads // self.tp_size) 466 | self.head_dim = hidden_size // self.total_num_heads 467 | self.q_size = self.num_heads * self.head_dim 468 | self.kv_size = self.num_kv_heads * self.head_dim 469 | self.key_value_groups = int(self.num_heads / self.num_kv_heads) 470 | self.scaling = self.head_dim**-0.5 471 | self.rope_theta = rope_theta 472 | self.max_position_embeddings = max_position_embeddings 473 | 474 | self.qkv_proj = QKVParallelLinear( 475 | hidden_size, 476 | self.head_dim, 477 | self.total_num_heads, 478 | self.total_num_kv_heads, 479 | bias=attention_bias, 480 | quant_config=quant_config, 481 | prefix=f"{prefix}.wqkv", 482 | ) 483 | 484 | # self.o_proj = RowParallelLinear( 485 | # self.total_num_heads * self.head_dim, 486 | # hidden_size, 487 | # bias=attention_bias, 488 | # quant_config=quant_config, 489 | # prefix=f"{prefix}.wo", 490 | # ) 491 | self.o_proj = nn.Linear( 492 | self.total_num_heads * self.head_dim, 493 | hidden_size, 494 | bias=attention_bias, 495 | ) 496 | 497 | self.query_layernorm = ( 498 | RMSNorm(self.head_dim, eps=config.rms_norm_eps) 499 | if config.use_qk_norm 500 | else None 501 | ) 502 | self.key_layernorm = ( 503 | RMSNorm(self.head_dim, eps=config.rms_norm_eps) 504 | if config.use_qk_norm 505 | else None 506 | ) 507 | 508 | self.rotary_emb = ( 509 | DynamicNTKAlphaMRotaryEmbedding( 510 | self.head_dim, 511 | self.head_dim, 512 | max_position_embeddings, 513 | int(rope_theta), 514 | scaling_alpha=rope_scaling["alpha"], 515 | dtype=torch.get_default_dtype(), 516 | mrope_section=rope_scaling["mrope_section"], 517 | max_model_len=config.max_model_len, 518 | ) 519 | if config.use_rotary_pos_emb 520 | else None 521 | ) 522 | 523 | self.attn = Attention( 524 | self.num_heads, 525 | self.head_dim, 526 | self.scaling, 527 | num_kv_heads=self.num_kv_heads, 528 | cache_config=cache_config, 529 | quant_config=quant_config, 530 | prefix=f"{prefix}.attn", 531 | ) 532 | 533 | def split_qkv(self, qkv: torch.Tensor): 534 | seq_len = qkv.shape[0] 535 | if self.tp_size > 1: 536 | qkv_map = [self.q_size, self.kv_size, self.kv_size] * self.tp_size 537 | qkv = tensor_model_parallel_all_gather(qkv) 538 | qkv = torch.split(qkv, qkv_map, dim=-1) 539 | qkv = qkv[::3] + qkv[1::3] + qkv[2::3] 540 | qkv = torch.cat(qkv, dim=-1) 541 | 542 | qkv = qkv.view( 543 | seq_len, 544 | self.total_num_kv_heads, 545 | self.key_value_groups + 2, 546 | self.head_dim, 547 | ) 548 | q, k, v = torch.split(qkv, [self.key_value_groups, 1, 1], dim=-2) 549 | q = q.reshape(seq_len, self.q_size * self.tp_size) 550 | k = k.reshape(seq_len, self.kv_size * self.tp_size) 551 | v = v.reshape(seq_len, self.kv_size * self.tp_size) 552 | 553 | if self.tp_size > 1: 554 | splitter = partial( 555 | split_tensor_along_last_dim, num_partitions=self.tp_size 556 | ) 557 | q = splitter(q)[self.tp_rank] 558 | k = splitter(k)[self.tp_rank] 559 | v = splitter(v)[self.tp_rank] 560 | return q, k, v 561 | 562 | def forward( 563 | self, 564 | positions: torch.Tensor, 565 | hidden_states: torch.Tensor, 566 | ) -> torch.Tensor: 567 | qkv, _ = self.qkv_proj(hidden_states) 568 | q, k, v = self.split_qkv(qkv) 569 | 570 | if self.rotary_emb is not None: 571 | q, k = self.rotary_emb(positions, q, k) 572 | 573 | 574 | if self.query_layernorm is not None: 575 | q = q.reshape(-1, self.num_heads, self.head_dim) 576 | q = self.query_layernorm(q).reshape(-1, self.q_size) 577 | 578 | if self.key_layernorm is not None: 579 | k = 
k.reshape(-1, self.num_kv_heads, self.head_dim) 580 | k = self.key_layernorm(k).reshape(-1, self.kv_size) 581 | 582 | attn_output = self.attn(q, k, v) 583 | output = self.o_proj(attn_output) 584 | return output 585 | 586 | 587 | class HunYuanDecoderLayer(nn.Module): 588 | def __init__( 589 | self, 590 | config: PretrainedConfig, 591 | cache_config: CacheConfig, 592 | quant_config: QuantizationConfig, 593 | prefix: str = "", 594 | ): 595 | super().__init__() 596 | self.hidden_size = config.hidden_size 597 | self.self_attn = HunYuanAttention( 598 | config, 599 | config.hidden_size, 600 | config.num_attention_heads, 601 | config.num_key_value_heads, 602 | config.rope_theta, 603 | config.rope_scaling, 604 | config.max_position_embeddings, 605 | config.attention_bias, 606 | cache_config, 607 | quant_config, 608 | prefix=f"{prefix}.attention", 609 | ) 610 | self.mlp = HunyuanMLP( 611 | config, 612 | config.hidden_size, 613 | config.intermediate_size, 614 | config.hidden_act, 615 | quant_config, 616 | prefix=f"{prefix}.mlp", 617 | ) 618 | 619 | self.input_layernorm = RMSNorm( 620 | config.hidden_size, eps=config.rms_norm_eps 621 | ) 622 | self.post_attention_layernorm = RMSNorm( 623 | config.hidden_size, eps=config.rms_norm_eps 624 | ) 625 | 626 | def forward( 627 | self, 628 | positions: torch.Tensor, 629 | hidden_states: torch.Tensor, 630 | residual: Optional[torch.Tensor], 631 | ) -> Tuple[torch.Tensor, torch.Tensor]: 632 | if residual is None: 633 | residual = hidden_states 634 | hidden_states = self.input_layernorm(hidden_states) 635 | else: 636 | hidden_states, residual = self.input_layernorm( 637 | hidden_states, residual 638 | ) 639 | 640 | hidden_states = self.self_attn( 641 | positions=positions, 642 | hidden_states=hidden_states, 643 | ) 644 | 645 | hidden_states, residual = self.post_attention_layernorm( 646 | hidden_states, residual 647 | ) 648 | 649 | hidden_states = self.mlp(hidden_states) 650 | 651 | return hidden_states, residual 652 | 653 | 654 | @support_torch_compile( 655 | dynamic_arg_dims={ 656 | "input_ids": 0, 657 | # positions is of shape (3, seq_len) if mrope is enabled for qwen2-vl, 658 | # otherwise (seq_len, ). 
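        # For this model, mrope positions are 4D, i.e. (4, seq_len) holding the
        # (sequence, x, y, t) indices built above, so the last dimension is
        # still the dynamic sequence-length dimension.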
659 | "positions": -1, 660 | "intermediate_tensors": 0, 661 | "inputs_embeds": 0, 662 | }) 663 | class HunYuanModel(nn.Module): 664 | def __init__( 665 | self, 666 | *, 667 | vllm_config: VllmConfig, 668 | prefix: str = "", 669 | layer_type: Type[HunYuanDecoderLayer] = HunYuanDecoderLayer, 670 | ): 671 | super().__init__() 672 | 673 | config = vllm_config.model_config.hf_config 674 | cache_config = vllm_config.cache_config 675 | quant_config = vllm_config.quant_config 676 | 677 | self.config = config 678 | self.vocab_size = config.vocab_size 679 | 680 | self.embed_tokens = VocabParallelEmbedding( 681 | config.vocab_size, 682 | config.hidden_size, 683 | ) # TODO: This does not support padding_idx, check if this is an issue 684 | 685 | self.start_layer, self.end_layer, self.layers = make_layers( 686 | config.num_hidden_layers, 687 | lambda prefix: layer_type( 688 | config, cache_config, quant_config, prefix=prefix 689 | ), 690 | prefix=f"{prefix}.layers", 691 | ) 692 | self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) 693 | self.make_empty_intermediate_tensors = ( 694 | make_empty_intermediate_tensors_factory( 695 | ["hidden_states", "residual"], config.hidden_size 696 | ) 697 | ) 698 | 699 | def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: 700 | return self.embed_tokens(input_ids) 701 | 702 | def forward( 703 | self, 704 | input_ids: torch.Tensor, 705 | positions: torch.Tensor, 706 | intermediate_tensors: Optional[IntermediateTensors] = None, 707 | inputs_embeds: Optional[torch.Tensor] = None, 708 | **kwargs: object, 709 | ) -> Union[torch.Tensor, IntermediateTensors]: 710 | if get_pp_group().is_first_rank: 711 | if inputs_embeds is not None: 712 | hidden_states = inputs_embeds 713 | else: 714 | hidden_states = self.get_input_embeddings(input_ids) 715 | residual = None 716 | else: 717 | hidden_states = intermediate_tensors["hidden_states"] 718 | residual = intermediate_tensors["residual"] 719 | for layer in self.layers[self.start_layer : self.end_layer]: 720 | hidden_states, residual = layer(positions, hidden_states, residual) 721 | if not get_pp_group().is_first_rank: 722 | return IntermediateTensors( 723 | { 724 | "hidden_states": hidden_states, 725 | "residual": residual, 726 | } 727 | ) 728 | hidden_states, _ = self.norm(hidden_states, residual) 729 | 730 | return hidden_states 731 | 732 | 733 | class HunYuanForCausalLM(nn.Module, SupportsPP): 734 | 735 | def __init__( 736 | self, 737 | *, 738 | vllm_config: VllmConfig, 739 | prefix: str = "", 740 | model_type: Type[HunYuanModel] = HunYuanModel, 741 | ): 742 | super().__init__() 743 | 744 | config = vllm_config.model_config.hf_config 745 | quant_config = vllm_config.quant_config 746 | 747 | self.config = config 748 | self.model = model_type( 749 | vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model") 750 | ) 751 | 752 | self.lm_head = ParallelLMHead( 753 | config.vocab_size, 754 | config.hidden_size, 755 | bias=False, 756 | quant_config=quant_config, 757 | prefix=maybe_prefix(prefix, "lm_head"), 758 | ) 759 | self.logits_processor = LogitsProcessor(config.vocab_size) 760 | self.sampler = get_sampler() 761 | self.make_empty_intermediate_tensors = ( 762 | self.model.make_empty_intermediate_tensors 763 | ) 764 | 765 | def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: 766 | return self.model.embed_tokens(input_ids) 767 | 768 | def forward( 769 | self, 770 | input_ids: torch.Tensor, 771 | positions: torch.Tensor, 772 | intermediate_tensors: Optional[IntermediateTensors] = 
None, 773 | inputs_embeds: Optional[torch.Tensor] = None, 774 | **kwargs: object, 775 | ) -> torch.Tensor: 776 | assert positions.ndim == 2, f"positions must be 2D, but got {positions.shape}" 777 | hidden_states = self.model( 778 | input_ids, positions, intermediate_tensors, inputs_embeds 779 | ) 780 | return hidden_states 781 | 782 | def compute_logits( 783 | self, 784 | hidden_states: torch.Tensor, 785 | sampling_metadata: SamplingMetadata, 786 | ) -> Optional[torch.Tensor]: 787 | logits = self.logits_processor( 788 | self.lm_head, hidden_states, sampling_metadata 789 | ) 790 | return logits 791 | 792 | def sample( 793 | self, 794 | logits: torch.Tensor, 795 | sampling_metadata: SamplingMetadata, 796 | ) -> Optional[SamplerOutput]: 797 | next_tokens = self.sampler(logits, sampling_metadata) 798 | return next_tokens 799 | 800 | def load_weights( 801 | self, weights: Iterable[Tuple[str, torch.Tensor]] 802 | ) -> Set[str]: 803 | loader = AutoWeightsLoader(self) 804 | return loader.load_weights(weights) 805 | -------------------------------------------------------------------------------- /model_vllm/hunyuan_video.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | from collections.abc import Iterable, Mapping, Sequence 3 | from functools import cached_property 4 | from typing import Literal, Optional, Set, Tuple, TypedDict, TypeVar, Union, Any 5 | from copy import deepcopy 6 | 7 | import torch 8 | import torch.nn as nn 9 | import torchvision.transforms as T 10 | from PIL import Image 11 | from transformers import ( 12 | BatchEncoding, 13 | PretrainedConfig, 14 | TensorType, 15 | WhisperFeatureExtractor, 16 | ) 17 | import math 18 | import logging 19 | 20 | from vllm.config import VllmConfig 21 | from vllm.model_executor.layers.quantization import QuantizationConfig 22 | from vllm.model_executor.layers.quantization.awq import AWQConfig 23 | from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler 24 | from vllm.model_executor.sampling_metadata import SamplingMetadata 25 | from vllm.multimodal import MULTIMODAL_REGISTRY 26 | from vllm.multimodal.inputs import ( 27 | MultiModalFieldConfig, 28 | MultiModalKwargs, 29 | NestedTensors, 30 | MultiModalDataDict, 31 | MultiModalInputs, 32 | ) 33 | from vllm.multimodal.parse import ( 34 | ImageEmbeddingItems, 35 | ImageProcessorItems, 36 | ImageSize, 37 | MultiModalDataItems, 38 | ) 39 | from vllm.multimodal.processing import ( 40 | BaseMultiModalProcessor, 41 | BaseProcessingInfo, 42 | PromptReplacement, 43 | PromptUpdate, 44 | PromptUpdateDetails, 45 | ) 46 | from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs 47 | from vllm.sequence import IntermediateTensors 48 | from vllm.transformers_utils.tokenizer import AnyTokenizer 49 | 50 | from vllm.model_executor.models.interfaces import ( 51 | MultiModalEmbeddings, 52 | SupportsMultiModal, 53 | SupportsPP, 54 | ) 55 | from vllm.model_executor.models.utils import ( 56 | AutoWeightsLoader, 57 | flatten_bn, 58 | init_vllm_registered_model, 59 | maybe_prefix, 60 | merge_multimodal_embeddings, 61 | WeightsMapper, 62 | ) 63 | from vllm.model_executor.models.whisper import WhisperEncoder 64 | from vllm.multimodal.parse import ( 65 | MultiModalDataParser, 66 | ModalityData, 67 | ModalityDataItems, 68 | DictEmbeddingItems, 69 | ProcessorBatchItems, 70 | ) 71 | from vllm.multimodal.inputs import ImageItem 72 | from vllm.transformers_utils.tokenizer import decode_tokens 73 | from 
vllm.multimodal.hasher import MultiModalHasher 74 | 75 | 76 | logger = logging.getLogger(__name__) 77 | 78 | 79 | IMG_START = "" 80 | IMG_END = "" 81 | IMG_CONTEXT = "" 82 | 83 | 84 | def _hunyuan_field_config(hf_inputs: Mapping[str, torch.Tensor]): 85 | 86 | image_grid_thw = hf_inputs.get("image_grid_thw", torch.empty((0, 3))) 87 | image_grid_sizes = image_grid_thw.prod(-1) 88 | 89 | return dict( 90 | pixel_values_flat=MultiModalFieldConfig.batched("image"), 91 | image_embeds=MultiModalFieldConfig.batched("image"), 92 | image_grid_thw=MultiModalFieldConfig.batched("image"), 93 | ) 94 | 95 | 96 | class HunyuanMultiModalDataParser(MultiModalDataParser): 97 | 98 | def _parse_image_data( 99 | self, 100 | data: Union[dict[str, torch.Tensor], ModalityData[ImageItem]], 101 | ) -> ModalityDataItems[Any, Any]: 102 | if isinstance(data, dict): 103 | return DictEmbeddingItems( 104 | data, 105 | modality="image", 106 | required_fields={"image_embeds", "image_grid_thw"}, 107 | fields_factory=_hunyuan_field_config, 108 | ) 109 | 110 | return super()._parse_image_data(data) 111 | 112 | 113 | class HunyuanImageEmbedInputs(TypedDict): 114 | type: Literal["image_embeds"] 115 | data: Union[torch.Tensor, list[torch.Tensor]] 116 | """ 117 | A tensor of shape `(num_images, total_image_feature_size, hidden_size)` 118 | or a list of tensors of shape `(total_image_feature_size, hidden_size)` 119 | 120 | `hidden_size` must match the hidden size of language model backbone. 121 | """ 122 | 123 | image_grid_thw: torch.Tensor 124 | 125 | 126 | class BaseHunyuanProcessor(ABC): 127 | 128 | def __init__( 129 | self, 130 | config: PretrainedConfig, 131 | tokenizer: AnyTokenizer, 132 | ) -> None: 133 | self.config = config 134 | self.tokenizer = tokenizer 135 | self.num_image_token = config.num_image_token 136 | 137 | @property 138 | @abstractmethod 139 | def image_token_id(self) -> int: 140 | raise NotImplementedError 141 | 142 | @abstractmethod 143 | def get_image_replace( 144 | self, 145 | ) -> PromptUpdateDetails[str]: 146 | raise NotImplementedError 147 | 148 | def __call__( 149 | self, 150 | text: Optional[Union[str, list[str]]] = None, 151 | images: Optional[Union[Image.Image, list[Image.Image]]] = None, 152 | return_tensors: Optional[Union[str, TensorType]] = None, 153 | ) -> Mapping[str, NestedTensors]: 154 | if text is None: 155 | text = [] 156 | if not isinstance(text, list): 157 | text = [text] 158 | 159 | if images is not None: 160 | raise NotImplementedError("Image processing not implemented") 161 | 162 | text_inputs = self.tokenizer(text) 163 | 164 | output = { 165 | **BatchEncoding(text_inputs, tensor_type=return_tensors), 166 | } 167 | return output 168 | 169 | 170 | class HunyuanProcessor(BaseHunyuanProcessor): 171 | 172 | @property 173 | def image_token_id(self) -> int: 174 | image_token_id = self.tokenizer.get_vocab()[IMG_CONTEXT] 175 | return image_token_id 176 | 177 | def get_image_replace( 178 | self, 179 | ) -> PromptUpdateDetails[str]: 180 | replace_features = IMG_CONTEXT * self.num_image_token 181 | replace_full = IMG_START + replace_features + IMG_END 182 | 183 | return PromptUpdateDetails.select_text(replace_full, IMG_CONTEXT) 184 | # return PromptUpdateDetails(full=replace_full, features=replace_features) 185 | 186 | 187 | class BaseHunyuanProcessingInfo(BaseProcessingInfo): 188 | 189 | @abstractmethod 190 | def get_hf_processor( 191 | self, 192 | **kwargs: object, 193 | ) -> BaseHunyuanProcessor: 194 | raise NotImplementedError 195 | 196 | def get_supported_mm_limits(self) -> Mapping[str, 
Optional[int]]: 197 | return {"image": None} 198 | 199 | def get_mm_max_tokens_per_item( 200 | self, 201 | seq_len: int, 202 | mm_counts: Mapping[str, int], 203 | ) -> Mapping[str, int]: 204 | return {"image": self.get_max_image_tokens()} 205 | 206 | def get_max_image_tokens(self) -> int: 207 | processor = self.get_hf_processor() 208 | num_image_token = processor.num_image_token 209 | return num_image_token 210 | 211 | 212 | _I = TypeVar("_I", bound=BaseHunyuanProcessingInfo) 213 | 214 | 215 | class HunyuanDummyInputsBuilder(BaseDummyInputsBuilder[_I]): 216 | 217 | def get_dummy_processor_inputs( 218 | self, 219 | seq_len: int, 220 | mm_counts: Mapping[str, int], 221 | ) -> ProcessorInputs: 222 | num_images = mm_counts.get("image", 0) 223 | 224 | num_image_token = self.info.get_hf_processor().num_image_token 225 | hidden_size = self.info.get_hf_processor().config.hidden_size 226 | 227 | grid_hw = int((math.sqrt(4 * num_image_token - 7) - 1) / 2) 228 | 229 | grid_thw = torch.tensor([[1, grid_hw, grid_hw]]) 230 | grid_thw = grid_thw.repeat(num_images, 1) 231 | 232 | mm_data = { 233 | "image": { 234 | "image_embeds": torch.randn( 235 | num_images, 236 | num_image_token, 237 | hidden_size, 238 | dtype=torch.bfloat16, 239 | ), 240 | "image_grid_thw": grid_thw, 241 | } 242 | } 243 | 244 | return ProcessorInputs( 245 | prompt_text="" * num_images, 246 | mm_data=mm_data, 247 | ) 248 | 249 | 250 | class HunyuanMultiModalProcessor(BaseMultiModalProcessor[_I]): 251 | 252 | def _call_hf_processor( 253 | self, 254 | prompt: str, 255 | mm_data: Mapping[str, object], 256 | mm_kwargs: Mapping[str, object], 257 | ) -> Mapping[str, NestedTensors]: 258 | processed_outputs = super()._call_hf_processor( 259 | prompt=prompt, 260 | mm_data=mm_data, 261 | mm_kwargs=mm_kwargs, 262 | ) 263 | return processed_outputs 264 | 265 | def _get_data_parser(self) -> HunyuanMultiModalDataParser: 266 | return HunyuanMultiModalDataParser() 267 | 268 | def _get_mm_fields_config( 269 | self, 270 | hf_inputs: Mapping[str, NestedTensors], 271 | hf_processor_mm_kwargs: Mapping[str, object], 272 | ) -> Mapping[str, MultiModalFieldConfig]: 273 | return _hunyuan_field_config(hf_inputs) 274 | 275 | def _get_prompt_updates( 276 | self, 277 | mm_items: MultiModalDataItems, 278 | hf_processor_mm_kwargs: Mapping[str, object], 279 | out_mm_kwargs: MultiModalKwargs, 280 | ) -> Sequence[PromptUpdate]: 281 | hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs) 282 | 283 | image_replace = hf_processor.get_image_replace() 284 | 285 | return [ 286 | PromptReplacement( 287 | modality="image", 288 | target="", 289 | replacement=image_replace, 290 | ), 291 | ] 292 | 293 | 294 | class HunyuanProcessingInfo(BaseHunyuanProcessingInfo): 295 | 296 | def get_hf_processor( 297 | self, 298 | **kwargs: object, 299 | ) -> HunyuanProcessor: 300 | return self.ctx.init_processor( 301 | HunyuanProcessor, 302 | config=self.get_hf_config(), 303 | tokenizer=self.get_tokenizer(), 304 | **kwargs, 305 | ) 306 | 307 | 308 | @MULTIMODAL_REGISTRY.register_processor( 309 | HunyuanMultiModalProcessor, 310 | info=HunyuanProcessingInfo, 311 | dummy_inputs=HunyuanDummyInputsBuilder, 312 | ) 313 | class HunyuanVideoModel(nn.Module, SupportsMultiModal, SupportsPP): 314 | 315 | def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None: 316 | super().__init__() 317 | config = vllm_config.model_config.hf_config 318 | 319 | self.config = config 320 | 321 | self.language_model = init_vllm_registered_model( 322 | vllm_config=vllm_config, 323 | 
prefix=maybe_prefix(prefix, "language_model"), 324 | architectures=["HunYuanForCausalLM"], 325 | ) 326 | 327 | self.system_message = None 328 | self.num_samples = 0 329 | 330 | @cached_property 331 | def sampler(self): 332 | if hasattr(self.language_model, "sampler"): 333 | return self.language_model.sampler 334 | else: 335 | raise NotImplementedError 336 | 337 | def get_input_embeddings( 338 | self, 339 | input_ids: torch.Tensor, 340 | multimodal_embeddings: Optional[MultiModalEmbeddings] = None, 341 | ) -> torch.Tensor: 342 | 343 | inputs_embeds = self.language_model.get_input_embeddings(input_ids) 344 | if multimodal_embeddings is not None: 345 | inputs_embeds = merge_multimodal_embeddings( 346 | input_ids, 347 | inputs_embeds, 348 | multimodal_embeddings, 349 | self.config.image_token_id, 350 | ) 351 | return inputs_embeds 352 | 353 | def get_multimodal_embeddings( 354 | self, **kwargs: object 355 | ) -> Optional[MultiModalEmbeddings]: 356 | image_input = self._parse_and_validate_image_input(**kwargs) 357 | 358 | if image_input is None: 359 | return None 360 | 361 | return image_input["data"] 362 | 363 | def get_language_model(self) -> torch.nn.Module: 364 | return self.language_model 365 | 366 | def _parse_and_validate_image_input( 367 | self, **kwargs: object 368 | ) -> Optional[HunyuanImageEmbedInputs]: 369 | image_embeds = kwargs.pop("image_embeds", None) 370 | image_grid_thw = kwargs.pop("image_grid_thw", None) 371 | 372 | if image_embeds is None: 373 | return None 374 | 375 | if not isinstance(image_embeds, (torch.Tensor, list)): 376 | raise ValueError( 377 | "Incorrect type of image embeddings. " 378 | f"Got type: {type(image_embeds)}" 379 | ) 380 | 381 | image_embeds = image_embeds.to(self.config.torch_dtype) 382 | 383 | return HunyuanImageEmbedInputs( 384 | type="image_embeds", 385 | data=flatten_bn(image_embeds), 386 | image_grid_thw=flatten_bn(image_grid_thw), 387 | ) 388 | 389 | def _process_image_input( 390 | self, image_input: HunyuanImageEmbedInputs 391 | ) -> MultiModalEmbeddings: 392 | grid_thw = image_input["image_grid_thw"] 393 | assert grid_thw.ndim == 2 394 | 395 | if image_input["type"] == "image_embeds": 396 | image_embeds = image_input["data"] 397 | 398 | merge_size = 1 # TODO: Check this 399 | sizes = grid_thw.prod(-1) // merge_size // merge_size 400 | 401 | return image_embeds.split(sizes.tolist()) 402 | 403 | def forward( 404 | self, 405 | input_ids: torch.Tensor, 406 | positions: torch.Tensor, 407 | intermediate_tensors: Optional[IntermediateTensors] = None, 408 | inputs_embeds: Optional[torch.Tensor] = None, 409 | **kwargs: object, 410 | ) -> Union[SamplerOutput, IntermediateTensors]: 411 | 412 | if intermediate_tensors is not None: 413 | input_ids = None 414 | inputs_embeds = None 415 | elif inputs_embeds is None: 416 | # raise ValueError(f"v0 not supported, {kwargs}") 417 | vision_embeddings = self.get_multimodal_embeddings(**kwargs) 418 | inputs_embeds = self.get_input_embeddings(input_ids, 419 | vision_embeddings) 420 | input_ids = None 421 | 422 | hidden_states = self.language_model.model( 423 | input_ids=input_ids, 424 | positions=positions, 425 | intermediate_tensors=intermediate_tensors, 426 | inputs_embeds=inputs_embeds, 427 | **kwargs, 428 | ) 429 | 430 | 431 | return hidden_states 432 | 433 | def compute_logits( 434 | self, 435 | hidden_states: torch.Tensor, 436 | sampling_metadata: SamplingMetadata, 437 | ) -> Optional[torch.Tensor]: 438 | return self.language_model.compute_logits( 439 | hidden_states, sampling_metadata 440 | ) 441 | 442 | def 
sample( 443 | self, 444 | logits: torch.Tensor, 445 | sampling_metadata: SamplingMetadata, 446 | ) -> Optional[SamplerOutput]: 447 | return self.language_model.sample(logits, sampling_metadata) 448 | 449 | def load_weights( 450 | self, weights: Iterable[Tuple[str, torch.Tensor]] 451 | ) -> Set[str]: 452 | loader = AutoWeightsLoader(self) 453 | weights = list(weights) 454 | for i, (k, v) in enumerate(weights): 455 | # The order of SiLU and Mul is different in VLLM 456 | if ".mlp.gate_and_up_proj.weight" in k: 457 | v1, v2 = v.chunk(2, dim=0) 458 | weights[i] = (k, torch.cat([v2, v1], dim=0)) 459 | 460 | # Filter out weights that are not in the language model (vit, whisper, mlp2) 461 | weights = [(k, v) for k, v in weights if k.startswith("language_model")] 462 | 463 | if "language_model.lm_head.weight" not in weights: 464 | logger.warning( 465 | "langauge.lm_head.weight not found in weights, " 466 | "will try to load it from language_model.embed_tokens.weight" 467 | ) 468 | weights.append(("language_model.lm_head.weight", self.language_model.model.embed_tokens.weight)) 469 | 470 | return loader.load_weights(weights) 471 | -------------------------------------------------------------------------------- /model_vllm/monkey_patch_mrope.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import importlib 3 | 4 | # Module where the original MRotaryEmbedding is defined 5 | VLLM_ROTARY_EMBEDDING_MODULE = "vllm.model_executor.layers.rotary_embedding" 6 | 7 | # Path to your custom class 8 | # Adjust the import path if your project structure is different 9 | YOUR_CUSTOM_MODULE = "model_vllm.hunyuan" 10 | YOUR_CUSTOM_CLASS_NAME = "DynamicNTKAlphaMRotaryEmbedding" 11 | 12 | try: 13 | # Import the vLLM module 14 | vllm_rotary_module = importlib.import_module(VLLM_ROTARY_EMBEDDING_MODULE) 15 | 16 | # Import your custom class 17 | custom_module = importlib.import_module(YOUR_CUSTOM_MODULE) 18 | CustomRotaryEmbeddingClass = getattr(custom_module, YOUR_CUSTOM_CLASS_NAME) 19 | 20 | # Perform the monkey patch: 21 | # Replace the MRotaryEmbedding in the vLLM module with your class 22 | setattr(vllm_rotary_module, "MRotaryEmbedding", CustomRotaryEmbeddingClass) 23 | 24 | print(f"Successfully monkey-patched 'MRotaryEmbedding' in '{VLLM_ROTARY_EMBEDDING_MODULE}' " 25 | f"with '{YOUR_CUSTOM_CLASS_NAME}' from '{YOUR_CUSTOM_MODULE}'.") 26 | 27 | except ImportError as e: 28 | print(f"Error during monkey patching: Could not import modules. {e}") 29 | print("Please ensure that vLLM is installed and your custom module path is correct.") 30 | except AttributeError as e: 31 | print(f"Error during monkey patching: Could not find class/attribute. 
{e}") 32 | print("Please ensure class names and module contents are correct.") 33 | except Exception as e: 34 | print(f"An unexpected error occurred during monkey patching: {e}") 35 | 36 | -------------------------------------------------------------------------------- /model_vllm/setup_vllm_env.sh: -------------------------------------------------------------------------------- 1 | eval "$(conda shell.bash hook)" 2 | conda create -y -n arc python=3.10 3 | conda activate arc 4 | 5 | pip install torch==2.6.0 torchvision==0.21.0 torchaudio==2.6.0 --index-url https://download.pytorch.org/whl/cu124 6 | pip install -r model_vllm/requirements.txt 7 | conda install -y ffmpeg 8 | pip install https://github.com/flashinfer-ai/flashinfer/releases/download/v0.2.2.post1/flashinfer_python-0.2.2.post1+cu124torch2.6-cp38-abi3-linux_x86_64.whl 9 | pip install https://github.com/Dao-AILab/flash-attention/releases/download/v2.7.4.post1/flash_attn-2.7.4.post1+cu12torch2.6cxx11abiFALSE-cp310-cp310-linux_x86_64.whl 10 | 11 | 12 | export VLLM_PRECOMPILED_WHEEL_LOCATION=$(pwd)/model_vllm/vllm/vllm-1.0.0.dev-cp38-abi3-manylinux1_x86_64.whl 13 | export VLLM_VERSION=v0.8.5.post1-1-gbed41f50d 14 | pip install --editable model_vllm/vllm/ 15 | -------------------------------------------------------------------------------- /model_vllm/video_audio_encoder.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import numpy as np 4 | import math 5 | from transformers.modeling_utils import no_init_weights 6 | from transformers import WhisperFeatureExtractor, WhisperModel, AutoConfig 7 | import sys 8 | import os 9 | 10 | from transformers import ARCHunyuanVideoVisionModel, ARCHunyuanVideoAudioEncoder 11 | 12 | class VideoAudioEncoder(nn.Module): 13 | def __init__(self, config, max_num_frames=150): 14 | super().__init__() 15 | self.max_num_frames = max_num_frames 16 | 17 | config.vision_config._attn_implementation = "flash_attention_2" 18 | config.audio_config._attn_implementation = "flash_attention_2" 19 | 20 | with no_init_weights(): 21 | # Initialize vision model 22 | self.vision_model = ARCHunyuanVideoVisionModel( 23 | vision_config=config.vision_config, 24 | text_config=config.text_config, 25 | ) 26 | 27 | self.speech_encoder = ARCHunyuanVideoAudioEncoder( 28 | config=config.audio_config, 29 | ) 30 | 31 | self.speech_dim = config.audio_config.d_model 32 | 33 | llm_hidden_size = config.text_config.hidden_size 34 | 35 | self.mlp2 = nn.Sequential( 36 | nn.LayerNorm(self.speech_dim), 37 | nn.Linear(self.speech_dim, llm_hidden_size), 38 | nn.GELU(), 39 | nn.Linear(llm_hidden_size, llm_hidden_size), 40 | ) 41 | 42 | @torch.no_grad() 43 | def extract_image_feature(self, pixel_values): 44 | """Extract features from image tensors using vision model""" 45 | vit_embeds = self.vision_model(pixel_values) 46 | return vit_embeds 47 | 48 | @torch.no_grad() 49 | def extract_audio_feature(self, audio_values): 50 | """Extract features from audio tensors using speech encoder""" 51 | audio_values = audio_values.squeeze(0).reshape( 52 | -1, 128, audio_values.shape[-1] 53 | ) 54 | num_segments = audio_values.shape[0] 55 | 56 | speech_embeds = self.speech_encoder( 57 | audio_values, return_dict=True 58 | ).last_hidden_state 59 | 60 | speech_embeds = speech_embeds.reshape(1, -1, speech_embeds.shape[-1]) 61 | speech_embeds = self.mlp2(speech_embeds) 62 | return num_segments, speech_embeds 63 | 64 | def create_mixed_embeddings(self, vit_embeds, audio_embeds, duration): 
65 | """Create mixed embeddings from visual and audio features""" 66 | # Reshape audio embeddings to match video frames 67 | audio_embeds = audio_embeds.reshape( 68 | audio_embeds.shape[0], -1, 50, audio_embeds.shape[-1] 69 | ) 70 | audio_embeds_no_pad = audio_embeds[:, :duration].squeeze(0) 71 | 72 | max_num_frame = self.max_num_frames 73 | 74 | # Handle case where audio duration exceeds max number of frames 75 | if duration > max_num_frame: 76 | per_audio_tokens = math.ceil( 77 | audio_embeds_no_pad.shape[0] / max_num_frame * 50 78 | ) 79 | num_audio_tokens_sum = per_audio_tokens * max_num_frame 80 | audio_embeds_no_pad = audio_embeds_no_pad.reshape( 81 | -1, audio_embeds_no_pad.shape[-1] 82 | ) 83 | 84 | if num_audio_tokens_sum != audio_embeds_no_pad.shape[0]: 85 | zero_padding = ( 86 | torch.zeros( 87 | num_audio_tokens_sum - audio_embeds_no_pad.shape[0], 88 | audio_embeds_no_pad.shape[-1], 89 | ) 90 | .to(audio_embeds_no_pad.dtype) 91 | .to(audio_embeds_no_pad.device) 92 | ) 93 | audio_embeds_no_pad = torch.cat( 94 | (audio_embeds_no_pad, zero_padding), dim=0 95 | ) 96 | 97 | audio_embeds_no_pad = audio_embeds_no_pad.reshape( 98 | max_num_frame, -1, audio_embeds_no_pad.shape[-1] 99 | ) 100 | 101 | # Pad or trim to match the visual embedding shape 102 | padding_size = vit_embeds.shape[1] - audio_embeds_no_pad.shape[1] 103 | if padding_size != 0: 104 | zero_padding = ( 105 | torch.zeros( 106 | vit_embeds.shape[0], 107 | padding_size, 108 | audio_embeds_no_pad.shape[-1], 109 | ) 110 | .to(audio_embeds_no_pad.dtype) 111 | .to(audio_embeds_no_pad.device) 112 | ) 113 | audio_embeds_pad = torch.cat( 114 | (audio_embeds_no_pad, zero_padding), dim=1 115 | ) 116 | else: 117 | audio_embeds_pad = audio_embeds_no_pad 118 | 119 | mixed_embeds = vit_embeds + audio_embeds_pad 120 | 121 | return mixed_embeds 122 | 123 | def forward(self, pixel_values, audio_values, duration): 124 | """ 125 | Encode images and audio to create mixed embeddings 126 | 127 | Args: 128 | pixel_values (torch.Tensor): Batch of images from video (processed frames) 129 | audio_values (torch.Tensor): Processed audio features 130 | duration (int): Duration of the video in frames or seconds 131 | 132 | Returns: 133 | mixed_embeds (torch.Tensor): Mixed embeddings combining vision and audio 134 | """ 135 | 136 | # Extract features 137 | vit_embeds = self.extract_image_feature(pixel_values) 138 | 139 | _, audio_embeds = self.extract_audio_feature(audio_values) 140 | 141 | # Create mixed embeddings 142 | mixed_embeds = self.create_mixed_embeddings( 143 | vit_embeds, audio_embeds, duration 144 | ) 145 | 146 | return mixed_embeds 147 | 148 | -------------------------------------------------------------------------------- /model_vllm/video_audio_llm.py: -------------------------------------------------------------------------------- 1 | # This create a class to use the model in vllm with mm encoder + llm 2 | # The inputs should be preprocessed 3 | # The mm encoder will process input sequentially and llm will do batch inference 4 | # This may be less efficient, but much more clear and easy to use 5 | 6 | import os 7 | import json 8 | from pathlib import Path 9 | import tempfile 10 | import shutil 11 | 12 | import torch 13 | from huggingface_hub import snapshot_download 14 | 15 | from transformers import AutoTokenizer, AutoConfig, PretrainedConfig 16 | from vllm import LLM, SamplingParams 17 | from safetensors.torch import load_file as safetensors_load_file 18 | 19 | from model_vllm import VideoAudioEncoder 20 | 21 | 22 | def 
convert_config_to_legacy(config, max_model_len): 23 | legacy_config = PretrainedConfig() 24 | 25 | legacy_config.update(config.vision_config.to_dict()) 26 | legacy_config.update(config.text_config.to_dict()) 27 | 28 | force_image_size = config.vision_config.force_image_size 29 | num_image_token = int( 30 | (force_image_size / 64) 31 | * (force_image_size / 64 + 1) 32 | + 2 33 | ) 34 | 35 | legacy_config.update({ 36 | # Such that vllm can caculate the max_model_len correctly 37 | "architectures": ["HunyuanVideoModel"], 38 | "image_token_id": config.text_config.image_token_id, 39 | "vision_start_token_id": config.text_config.im_start_id, 40 | "num_image_token": num_image_token, 41 | "rope_scaling": { 42 | "alpha": 1000.0, 43 | "beta_fast": 32, 44 | "beta_slow": 1, 45 | "factor": 1000.0, 46 | "mscale": 1.0, 47 | "mscale_all_dim": 1.0, 48 | "rope_type": "dynamic", 49 | "mrope_section": [0.25, 0.25, 0.25, 0.25], 50 | }, 51 | "max_model_len": max_model_len, 52 | }) 53 | 54 | if hasattr(legacy_config, "torch_dtype") and isinstance(legacy_config.torch_dtype, str): 55 | # Convert string torch_dtype to torch.dtype 56 | legacy_config.torch_dtype = getattr(torch, legacy_config.torch_dtype) 57 | 58 | return legacy_config 59 | 60 | 61 | def load_state_dict_from_safetensors(path: str, prefixes: list[str]): 62 | def filter_dict_with_k_prefix(d, prefixes): 63 | return { 64 | k: v 65 | for k, v in d.items() 66 | if any(k.startswith(prefix) for prefix in prefixes) 67 | } 68 | 69 | index_path = os.path.join(path, "model.safetensors.index.json") 70 | if not os.path.exists(index_path): 71 | print(f"Index file {index_path} does not exist, loading all weights") 72 | pre_trained_dir = Path(path) 73 | weights_files = sorted(pre_trained_dir.glob("model-*.safetensors")) 74 | else: 75 | weight_map = json.load(open(index_path))["weight_map"] 76 | weights_files = set( 77 | filter_dict_with_k_prefix(weight_map, prefixes).values() 78 | ) 79 | weights_files = [os.path.join(path, f) for f in weights_files] 80 | 81 | if len(weights_files) == 0: 82 | raise ValueError( 83 | f"No weights files found in {path} with prefixes {prefixes}" 84 | ) 85 | 86 | state_dict = {} 87 | for file in weights_files: 88 | part_state_dict = safetensors_load_file(file) 89 | state_dict.update(part_state_dict) 90 | 91 | state_dict = filter_dict_with_k_prefix(state_dict, prefixes) 92 | return state_dict 93 | 94 | 95 | class VideoAudioLLM: 96 | def __init__( 97 | self, 98 | model_path, 99 | device_enc="cuda", 100 | device_llm="cuda", 101 | **kwargs, 102 | ): 103 | if not os.path.isdir(model_path): 104 | model_path = snapshot_download(repo_id=model_path) 105 | 106 | self.config = AutoConfig.from_pretrained(model_path) 107 | self.device_enc = device_enc 108 | self.device_llm = device_llm 109 | 110 | self.llm, self.sampling_params = self.init_llm(model_path, self.config, self.device_llm, **kwargs) 111 | 112 | self.mm_encoder = self.init_mm_encoder(model_path, self.config, self.device_enc) 113 | 114 | 115 | def init_mm_encoder(self, model_path, config, device): 116 | multi_modal_state_dict = load_state_dict_from_safetensors( 117 | model_path, ("vision_model.", "mlp2.", "speech_encoder.") 118 | ) 119 | 120 | multi_modal_encoder = VideoAudioEncoder( 121 | config, 122 | max_num_frames=config.max_num_frame, 123 | ) 124 | 125 | missing, unexpected = multi_modal_encoder.load_state_dict( 126 | multi_modal_state_dict, strict=False 127 | ) 128 | assert len(missing) == 0, f"Missing keys in mm encoder: {missing}" 129 | assert ( 130 | len(unexpected) == 0 131 | ), 
f"Unexpected keys in mm encoder: {unexpected}" 132 | 133 | multi_modal_encoder.eval() 134 | multi_modal_encoder.to(device) 135 | 136 | return multi_modal_encoder 137 | 138 | def init_llm(self, model_path, config, device, **kwargs): 139 | 140 | if self.device_enc != self.device_llm: 141 | gpu_memory_utilization = 0.9 142 | else: # Reserve memory for the encoder 143 | gpu_memory_utilization = 0.6 144 | 145 | max_model_len = 20480 146 | 147 | llm = LLM( 148 | model=model_path, 149 | tokenizer=model_path, 150 | trust_remote_code=True, 151 | max_model_len=max_model_len, 152 | max_seq_len_to_capture=max_model_len, 153 | dtype="bfloat16", 154 | hf_overrides=lambda x: convert_config_to_legacy(x, max_model_len), 155 | limit_mm_per_prompt={"image": 150}, 156 | enforce_eager=False, 157 | disable_mm_preprocessor_cache=True, 158 | enable_prefix_caching=False, 159 | device=device, 160 | gpu_memory_utilization=gpu_memory_utilization, 161 | ) 162 | 163 | sampling_params = SamplingParams( 164 | **kwargs, 165 | ) 166 | 167 | return llm, sampling_params 168 | 169 | def forward_mm_encoder(self, batch): 170 | """ 171 | This function will process the batch of data in the mm encoder 172 | Input: 173 | - batch: list of dicts, each dict contains the following keys: 174 | - pixel_values: torch.Tensor 175 | - audio_values: torch.Tensor 176 | - duration: float 177 | 178 | Output: 179 | - list of dicts, each dict contains the following keys: 180 | - embeddings: torch.Tensor 181 | - other keys in the original dict 182 | """ 183 | device = self.device_enc 184 | ret = [] 185 | 186 | for data in batch: 187 | pixel_values = data["pixel_values"] 188 | audio_values = data["audio_values"] 189 | duration = data["duration"] 190 | 191 | with torch.no_grad(), torch.autocast(device, torch.bfloat16): 192 | pixel_values = pixel_values.to( 193 | device=device, dtype=torch.bfloat16, non_blocking=True 194 | ) 195 | audio_values = audio_values.to( 196 | device=device, dtype=torch.bfloat16, non_blocking=True 197 | ) 198 | 199 | mixed_embeds = self.mm_encoder( 200 | pixel_values, audio_values, duration 201 | ) 202 | 203 | mixed_embeds = mixed_embeds.to(device="cpu").float().share_memory_() 204 | ret.append({"embeddings": mixed_embeds, **data}) 205 | 206 | 207 | return ret 208 | 209 | def forward_llm(self, batch): 210 | num_patches = ( 211 | self.config.vision_config.force_image_size // 32 // 2 212 | ) 213 | image_grid_thw = torch.tensor([[1, num_patches, num_patches + 1]]) 214 | prompts = [ 215 | { 216 | "prompt": "<|startoftext|>" + item["text_prompt"], 217 | "multi_modal_data": { 218 | "image": { 219 | "image_embeds": item["embeddings"], 220 | "image_grid_thw": image_grid_thw.repeat( 221 | item["embeddings"].shape[0], 1 222 | ), 223 | } 224 | }, 225 | } 226 | for item in batch 227 | ] 228 | 229 | outputs = self.llm.generate(prompts, self.sampling_params, use_tqdm=False) 230 | 231 | ret = [] 232 | 233 | for data, output in zip(batch, outputs): 234 | if "output" in data: 235 | raise ValueError("Check the batch, there is a key called output") 236 | 237 | ret.append({"output": output.outputs[0].text, **data}) 238 | 239 | return ret 240 | 241 | def __call__(self, batch): 242 | if isinstance(batch, dict): 243 | batch = [batch] 244 | 245 | ret = self.forward_mm_encoder(batch) 246 | ret = self.forward_llm(ret) 247 | return ret 248 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | tokenizers==0.21 2 | 
sentencepiece 3 | shortuuid 4 | accelerate 5 | peft>=0.4.0 6 | bitsandbytes==0.41.0 7 | pydantic 8 | markdown2[all] 9 | numpy<2 10 | scikit-learn>=1.2.2 11 | gradio==3.35.2 12 | gradio_client==0.2.9 13 | requests 14 | httpx==0.24.0 15 | uvicorn 16 | fastapi 17 | deepspeed==0.14.4 18 | einops 19 | einops-exts 20 | timm==0.9.12 21 | decord 22 | tiktoken 23 | librosa 24 | datasets 25 | opencv-python 26 | imageio[ffmpeg] 27 | tensorboardX 28 | av 29 | moviepy==1.0.3 30 | -------------------------------------------------------------------------------- /scripts/arc_hunyuan_video_full_finetune.sh: -------------------------------------------------------------------------------- 1 | set -x 2 | 3 | export NCCL_IB_GID_INDEX=3 4 | export NCCL_IB_SL=3 5 | export NCCL_CHECKS_DISABLE=1 6 | export NCCL_P2P_DISABLE=0 7 | export NCCL_IB_DISABLE=0 8 | export NCCL_LL_THRESHOLD=16384 9 | export NCCL_IB_CUDA_SUPPORT=1 10 | export NCCL_SOCKET_IFNAME=bond1 11 | export UCX_NET_DEVICES=bond1 12 | export NCCL_IB_HCA=mlx5 13 | export NCCL_COLLNET_ENABLE=0 14 | export SHARP_COLL_ENABLE_SAT=0 15 | export NCCL_NET_GDR_LEVEL=2 16 | export NCCL_IB_QPS_PER_CONNECTION=4 17 | export NCCL_IB_TC=160 18 | export NCCL_PXN_DISABLE=1 19 | export GLOO_SOCKET_IFNAME=bond1 20 | export NCCL_DEBUG=info 21 | 22 | export PYTHONPATH="${PYTHONPATH}:$(pwd)" 23 | export MASTER_PORT=8005 24 | export TF_CPP_MIN_LOG_LEVEL=3 25 | export LAUNCHER=pytorch 26 | 27 | OUTPUT_DIR='work_dirs/brief_summary_sft' 28 | if [ ! -d "$OUTPUT_DIR" ]; then 29 | mkdir -p "$OUTPUT_DIR" 30 | fi 31 | 32 | NODE_RANK=$1 33 | torchrun --nproc_per_node=2 \ 34 | model_train/train/arc_hunyuan_video_finetune.py \ 35 | --model_name_or_path "TencentARC/ARC-Hunyuan-Video-7B" \ 36 | --conv_style "hunyuan" \ 37 | --output_dir ${OUTPUT_DIR} \ 38 | --meta_path "sft_data/sft_jb_sp_kd_10.json" \ 39 | --overwrite_output_dir True \ 40 | --force_image_size 640 \ 41 | --num_image_token 112 \ 42 | --freeze_llm False \ 43 | --freeze_speech_encoder True \ 44 | --freeze_backbone True \ 45 | --dataloader_num_workers 4 \ 46 | --bf16 True \ 47 | --num_train_epochs 4 \ 48 | --per_device_train_batch_size 1 \ 49 | --gradient_accumulation_steps 1 \ 50 | --save_strategy "steps" \ 51 | --save_steps 500 \ 52 | --save_total_limit 100 \ 53 | --learning_rate 1e-5 \ 54 | --warmup_steps 100 \ 55 | --weight_decay 0.01 \ 56 | --warmup_ratio 0.03 \ 57 | --lr_scheduler_type "cosine" \ 58 | --logging_steps 1 \ 59 | --max_seq_length 20000 \ 60 | --max_num_frame 150 \ 61 | --do_train True \ 62 | --grad_checkpoint True \ 63 | --dynamic_image_size False \ 64 | --normalize_type hunyuan \ 65 | --seed 42 \ 66 | --deepspeed "config/zero_stage3_config.json" \ 67 | --report_to "tensorboard" \ 68 | 2>&1 | tee -a "${OUTPUT_DIR}/training_log.txt" 69 | -------------------------------------------------------------------------------- /sft_data/audios_mp3/a3545skvqbz.mp3: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TencentARC/ARC-Hunyuan-Video-7B/a061f1bf6a4d75f032759fec69c80c71077c3d0c/sft_data/audios_mp3/a3545skvqbz.mp3 -------------------------------------------------------------------------------- /sft_data/audios_mp3/c3522vbgwaw.mp3: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TencentARC/ARC-Hunyuan-Video-7B/a061f1bf6a4d75f032759fec69c80c71077c3d0c/sft_data/audios_mp3/c3522vbgwaw.mp3 -------------------------------------------------------------------------------- 
/sft_data/audios_mp3/e3556d48uo0.mp3: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TencentARC/ARC-Hunyuan-Video-7B/a061f1bf6a4d75f032759fec69c80c71077c3d0c/sft_data/audios_mp3/e3556d48uo0.mp3 -------------------------------------------------------------------------------- /sft_data/audios_mp3/q35134y59x8.mp3: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TencentARC/ARC-Hunyuan-Video-7B/a061f1bf6a4d75f032759fec69c80c71077c3d0c/sft_data/audios_mp3/q35134y59x8.mp3 -------------------------------------------------------------------------------- /sft_data/audios_mp3/u3519j52lb4.mp3: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TencentARC/ARC-Hunyuan-Video-7B/a061f1bf6a4d75f032759fec69c80c71077c3d0c/sft_data/audios_mp3/u3519j52lb4.mp3 -------------------------------------------------------------------------------- /sft_data/audios_mp3/v3524wr6l4l.mp3: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TencentARC/ARC-Hunyuan-Video-7B/a061f1bf6a4d75f032759fec69c80c71077c3d0c/sft_data/audios_mp3/v3524wr6l4l.mp3 -------------------------------------------------------------------------------- /sft_data/audios_mp3/w33698kgs05.mp3: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TencentARC/ARC-Hunyuan-Video-7B/a061f1bf6a4d75f032759fec69c80c71077c3d0c/sft_data/audios_mp3/w33698kgs05.mp3 -------------------------------------------------------------------------------- /sft_data/audios_mp3/x3551nmkn8o.mp3: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TencentARC/ARC-Hunyuan-Video-7B/a061f1bf6a4d75f032759fec69c80c71077c3d0c/sft_data/audios_mp3/x3551nmkn8o.mp3 -------------------------------------------------------------------------------- /sft_data/audios_mp3/x3555e2g3t8.mp3: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TencentARC/ARC-Hunyuan-Video-7B/a061f1bf6a4d75f032759fec69c80c71077c3d0c/sft_data/audios_mp3/x3555e2g3t8.mp3 -------------------------------------------------------------------------------- /sft_data/audios_mp3/z1468vawe14.mp3: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TencentARC/ARC-Hunyuan-Video-7B/a061f1bf6a4d75f032759fec69c80c71077c3d0c/sft_data/audios_mp3/z1468vawe14.mp3 -------------------------------------------------------------------------------- /sft_data/sft_jb_sp_abs_10.jsonl: -------------------------------------------------------------------------------- 1 | {"id": 0, "video": "w33698kgs05.mp4", "conversations": [{"from": "human", "value": "你是一个视频内容总结助手,你需要按照如下规则根据视频及其标题生成简要描述:\n1. 请以不超过100字的文本简要描述视频的主要内容,请仅准确反映视频中的核心内容,不要引入任何视频中未出现的信息\n2. 请在简要描述中保留视频中的核心人物、事件、场景及可能引人关注的信息\n视频为: