├── .gitignore
├── .gitmodules
├── ARC_Hunyuan_Video_7B.pdf
├── ARIAL.TTF
├── License.txt
├── README.md
├── config
│   ├── zero_stage1_config.json
│   ├── zero_stage2_config.json
│   └── zero_stage3_config.json
├── examples
│   ├── demo1.mp4
│   ├── demo2.mp4
│   ├── demo3.mov
│   └── temp.mp3
├── figures
│   ├── README.md
│   ├── method.jpg
│   ├── shortvid-bench.jpg
│   └── teaser.jpg
├── model_train
│   ├── dist_utils.py
│   ├── model
│   │   └── modeling_arc_hunyuan_video.py
│   ├── patch
│   │   ├── __init__.py
│   │   ├── pad_data_collator.py
│   │   ├── train_dataloader_patch.py
│   │   └── train_sampler_patch.py
│   └── train
│       ├── __init__.py
│       ├── arc_hunyuan_video_finetune.py
│       ├── constants.py
│       ├── dataset.py
│       └── dataset_packed.py
├── model_vllm
│   ├── __init__.py
│   ├── __pycache__
│   │   ├── __init__.cpython-310.pyc
│   │   ├── __init__.cpython-311.pyc
│   │   ├── hunyuan.cpython-310.pyc
│   │   ├── hunyuan.cpython-311.pyc
│   │   ├── hunyuan_video.cpython-310.pyc
│   │   ├── hunyuan_video.cpython-311.pyc
│   │   ├── monkey_patch_mrope.cpython-310.pyc
│   │   ├── monkey_patch_mrope.cpython-311.pyc
│   │   ├── video_audio_encoder.cpython-310.pyc
│   │   ├── video_audio_encoder.cpython-311.pyc
│   │   ├── video_audio_llm.cpython-310.pyc
│   │   └── video_audio_llm.cpython-311.pyc
│   ├── hunyuan.py
│   ├── hunyuan_video.py
│   ├── monkey_patch_mrope.py
│   ├── setup_vllm_env.sh
│   ├── video_audio_encoder.py
│   └── video_audio_llm.py
├── requirements.txt
├── scripts
│   └── arc_hunyuan_video_full_finetune.sh
├── sft_data
│   ├── audios_mp3
│   │   ├── a3545skvqbz.mp3
│   │   ├── c3522vbgwaw.mp3
│   │   ├── e3556d48uo0.mp3
│   │   ├── q35134y59x8.mp3
│   │   ├── u3519j52lb4.mp3
│   │   ├── v3524wr6l4l.mp3
│   │   ├── w33698kgs05.mp3
│   │   ├── x3551nmkn8o.mp3
│   │   ├── x3555e2g3t8.mp3
│   │   └── z1468vawe14.mp3
│   ├── sft_jb_sp_abs_10.jsonl
│   ├── sft_jb_sp_kd_10.json
│   └── videos
│       ├── a3545skvqbz.mp4
│       ├── c3522vbgwaw.mp4
│       ├── e3556d48uo0.mp4
│       ├── q35134y59x8.mp4
│       ├── u3519j52lb4.mp4
│       ├── v3524wr6l4l.mp4
│       ├── w33698kgs05.mp4
│       ├── x3551nmkn8o.mp4
│       ├── x3555e2g3t8.mp4
│       └── z1468vawe14.mp4
├── video_inference.py
├── video_inference_sft.py
└── video_inference_vllm.py
/.gitignore:
--------------------------------------------------------------------------------
1 | __pycache__
2 |
--------------------------------------------------------------------------------
/.gitmodules:
--------------------------------------------------------------------------------
1 | [submodule "model_vllm/vllm"]
2 | path = model_vllm/vllm
3 | url = https://github.com/TencentARC/vllm.git/
4 |
--------------------------------------------------------------------------------
/ARC_Hunyuan_Video_7B.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TencentARC/ARC-Hunyuan-Video-7B/a061f1bf6a4d75f032759fec69c80c71077c3d0c/ARC_Hunyuan_Video_7B.pdf
--------------------------------------------------------------------------------
/ARIAL.TTF:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TencentARC/ARC-Hunyuan-Video-7B/a061f1bf6a4d75f032759fec69c80c71077c3d0c/ARIAL.TTF
--------------------------------------------------------------------------------
/License.txt:
--------------------------------------------------------------------------------
1 | Tencent is pleased to support the open source community by making ARC-Hunyuan-Video-7B available.
2 |
3 | Copyright (C) 2025 Tencent. All rights reserved.
4 |
5 | ARC-Hunyuan-Video-7B is licensed under the License Terms of ARC-Hunyuan-Video-7B.
6 |
7 | For avoidance of doubts, ARC-Hunyuan-Video-7B refers to the inference code made publicly available by Tencent in accordance with ARC-Hunyuan-Video-7B in this repository.
8 |
9 | License Terms of ARC-Hunyuan-Video-7B:
10 | --------------------------------------------------------------------
11 | Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
12 |
13 | - You agree to use the ARC-Hunyuan-Video-7B only for academic purposes, and refrain from using it for any commercial or production purposes under any circumstances.
14 |
15 | - The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.
16 |
17 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
18 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # ARC-Hunyuan-Video-7B
2 |
3 | [Paper](https://arxiv.org/abs/2507.20939)
4 | [Demo](https://arc.tencent.com/en/ai-demos/multimodal)
5 | [Model: ARC-Hunyuan-Video-7B](https://huggingface.co/TencentARC/ARC-Hunyuan-Video-7B)
6 | [Model: ARC-Qwen-Video-7B](https://huggingface.co/TencentARC/ARC-Qwen-Video-7B)
7 | [Model: ARC-Qwen-Video-7B-Narrator](https://huggingface.co/TencentARC/ARC-Qwen-Video-7B-Narrator)
8 | [Blog](https://tencentarc.github.io/posts/arc-video-announcement/)
9 | [Dataset: ShortVid-Bench](https://huggingface.co/datasets/TencentARC/ShortVid-Bench)
10 |
11 |
12 | Please note that in our demo, ARC-Hunyuan-Video-7B refers to the model consistent with the released checkpoint and the paper, while ARC-Hunyuan-Video-7B-V0 only supports video description and summarization in Chinese.
13 | Due to API file-size limits, our demo compresses input video resolutions, which may cause slight performance differences from the paper. For the original results, please run the model locally.
14 |
15 |
16 | ## Introduction
17 |
18 | We introduce **ARC-Hunyuan-Video-7B**, a powerful multimodal model designed for _understanding real-world short videos_.
19 | Understanding user-generated videos is challenging due to their complex visual elements, high
20 | information density in both visuals and audio, and fast pacing that focuses on emotional expression and viewpoint delivery.
21 | To address this challenge, ARC-Hunyuan-Video-7B
22 | processes visual, audio, and textual signals end-to-end for a deep, structured understanding of video by integrating and reasoning over multimodal cues.
23 | Stress tests show an inference time of just 10 seconds for a one-minute video on an H20 GPU, yielding an average of 500 tokens, with
24 | inference accelerated by the vLLM framework.
25 | 
26 | Compared to prior art, we introduce a new paradigm of **Structured Video Comprehension**, with capabilities including:
27 |
28 | - **Deep Understanding of Real-World Short Videos:** ARC-Hunyuan-Video-7B excels at analyzing user-generated content from platforms like WeChat Channels and TikTok. It goes beyond surface-level descriptions to grasp the creator's intent, emotional expression, and core message by processing complex visual elements, dense audio cues, and rapid pacing.
29 | - **Synchronized Audio-Visual Reasoning:** The synchronization of raw visual and audio signals allows our model to answer complex questions that are impossible to solve with only one modality, such as understanding humor in a skit or details in a product review.
30 | - **Precise Temporal Awareness:** ARC-Hunyuan-Video-7B knows not just _what_ happens, but _when_ it happens. It supports multi-granularity timestamped captioning, temporal video grounding, and detailed event summarization, making it perfect for applications like video search, highlight generation, and content analysis.
31 | - **Advanced Reasoning and Application Versatility:** Leveraging a comprehensive multi-stage training regimen including Reinforcement Learning (RL), ARC-Hunyuan-Video-7B demonstrates strong reasoning capabilities. It supports zero-shot or few-shot fine-tuning for diverse downstream applications like video tagging, recommendation, and retrieval.
32 |
33 | The model is capable of multi-granularity timestamped video captioning and summarization, open-ended video question answering, temporal video grounding, and
34 | video reasoning, as shown below.
35 |
36 |
37 | ![Capability teaser](figures/teaser.jpg)
38 |
39 |
40 | Specifically, ARC-Hunyuan-Video-7B is built on top of the Hunyuan-7B vision-language model with the following key designs to meet the requirements of effective structured video comprehension:
41 |
42 | - An extra audio encoder with fine-grained visual-audio synchronization for temporally aligned visual-audio inputs
43 | - A timestamp overlay mechanism on visual frames that explicitly provides the model with temporal awareness
44 | - Millions of real-world videos annotated with a fully automated, bootstrapped annotation pipeline
45 | - A comprehensive training regimen based on the finding that grounding the model in objective
46 | tasks with RL is key to unlocking high-quality, subjective understanding
47 |
48 |
49 | ![Method overview](figures/method.jpg)
50 |
51 |
52 | ## ARC-Qwen-Video-7B
53 | In this version, we have switched the base model from the Hunyuan VLM to [Qwen2.5-VL-7B-Instruct](https://huggingface.co/Qwen/Qwen2.5-VL-7B-Instruct) and introduce [ARC-Qwen-Video-7B](https://huggingface.co/TencentARC/ARC-Qwen-Video-7B). We used the same training data and training stages. Please refer to the `arc-qwen-video` branch for details.
54 |
55 | We are also introducing a new model, [ARC-Qwen-Video-7B-Narrator](https://huggingface.co/TencentARC/ARC-Qwen-Video-7B-Narrator). It can output **timestamped video descriptions, speaker identities, and the specific ASR (Automatic Speech Recognition) content**. By processing its output with an external LLM, you can obtain more comprehensive structured information as follows (Click to watch the video):
56 |
57 | [Watch the example video on YouTube](https://www.youtube.com/watch?v=Bz1T4wCuWc8)
58 |
59 | > ### Video Overview
60 | >
61 | > This is a comedy short about a husband whose secret stash of money, hidden in a padded coat, is accidentally discovered by his wife, who mistakes it for a "surprise" gift he prepared for her. Through a single phone call between the couple, the video vividly traces the husband's shift from carefree ease to stunned disbelief and finally to helpless resignation, full of dramatic reversals and humor.
62 | >
63 | > ### Plot Breakdown
64 | >
65 | > The plot unfolds around one phone call. Below is the detailed timeline with scenes, speakers, and dialogue:
66 | >
67 | > | Timestamp | Scene | Speaker | Dialogue (ASR) |
68 | > | :--- | :--- | :--- | :--- |
69 | > | 0:00 - 0:05 | The husband, wearing a shower cap and wrapped in a towel, leisurely takes selfies beside an indoor pool. | None | (no dialogue) |
70 | > | 0:05 - 0:10 | Cut to the wife in a clothing store, beaming with happiness as she calls her husband. | Wife | "Hey, honey, honey, I love you, I love you, I love you to death, mwah mwah mwah." |
71 | > | 0:10 - 0:18 | The husband answers, puzzled by his wife's enthusiasm; she excitedly reveals the "surprise." | Husband | "Hey, what's gotten into you, why so happy?" |
72 | > | | | Wife | "Today I found the surprise you left me in my padded-coat pocket: ten thousand yuan!" |
73 | > | 0:18 - 0:27 | Hearing "ten thousand yuan," the husband's expression freezes, shifting from confusion to shock and regret, though he forces himself to stay calm. | Husband | "Huh? G-good, as long as you're happy." |
74 | > | 0:27 - 0:34 | The wife happily explains what she did with the money; the husband's face stiffens completely as his shock deepens. | Wife | "Of course I'm happy. I used it to buy a new outfit; I'll wear it for you when I get home tonight." |
75 | > | 0:34 - 0:46 | The husband confirms the money has been spent and falls apart; the wife believes he had authorized it, and he can't help cursing. | Husband | "You already spent it on clothes?" |
76 | > | | | Wife | "Of course, isn't that what you said? To buy whatever I like. Honey, you're the best." |
77 | > | | | Husband | "You really are a spendthrift, you know that?" |
78 | > | 0:46 - 0:59 | The wife senses something off in his tone; the husband immediately backpedals and urges her to come home early. | Wife | "What, honey, what did you say?" |
79 | > | | | Husband | "Huh? I said great, if you look pretty, I'm happy." |
80 | > | | | Wife | "That's what you said, honey. You have to come home early today, I'll be waiting for you." |
81 | > | | | Husband | "Fine, fine, fine, fine, fine." |
82 | >
83 | > ### Characters and Core Conflict
84 | >
85 | > #### 1. Character Analysis
86 | >
87 | > Husband:
88 | > Behavior: hides a private stash of money and, once it is exposed, tries hard to mask his true feelings (heartache, regret).
89 | > Emotional arc: relaxed -> puzzled -> shocked -> devastated -> reluctant acceptance.
90 | > Traits: keen to save face, loving yet resigned toward his wife, a classic "henpecked husband."
91 | >
92 | > Wife:
93 | > Behavior: finds the money, takes it as an expression of her husband's love, and promptly spends it.
94 | > Emotional arc: remains immersed in the happiness and joy of the "surprise" throughout.
95 | > Traits: naive, decisive in spending, full of trust and affection for her husband.
96 | >
97 | > #### 2. Core Conflict
98 | >
99 | > The video's core conflict is the dramatic misunderstanding created by a severe information asymmetry:
100 | >
101 | > * Husband's view: the 10,000 yuan he painstakingly saved has been accidentally discovered and spent, a "shock."
102 | > * Wife's view: a 10,000-yuan romantic fund thoughtfully prepared by her husband, a huge "surprise."
103 | >
104 | > This misunderstanding drives the whole story: the husband's "swallowing the pain in silence" and the wife's "happiness taken for granted" form a sharp comic contrast that delivers a steady stream of laughs.
105 | >
106 | > ### Summary
107 | >
108 | > Through a familiar domestic scenario about a "secret stash of money," the video cleverly builds a story full of reversals and humor. It uses dramatic irony (the audience and the husband know the truth while the wife is kept in the dark) to precisely capture the husband's complicated inner state in an unexpected situation. The result is not only packed with laughs but also subtly touches on communication, trust, and attitudes toward money between spouses, making it easy to spark resonance and discussion among viewers.
172 |
173 |
174 | ## News
175 | - 2025.09.19: We release [ARC-Qwen-Video-7B](https://huggingface.co/TencentARC/ARC-Qwen-Video-7B), which switched the base model from the Hunyuan VLM to [Qwen2.5-VL-7B-Instruct](https://huggingface.co/Qwen/Qwen2.5-VL-7B-Instruct). We also release [ARC-Qwen-Video-7B-Narrator](https://huggingface.co/TencentARC/ARC-Qwen-Video-7B-Narrator), which can output timestamped video descriptions, speaker identities, and the specific ASR (Automatic Speech Recognition) content. Please refer to the `arc-qwen-video` branch for details.
176 | - 2025.08.05: We release [ShortVid-Bench](https://huggingface.co/datasets/TencentARC/ShortVid-Bench), a specialized, human-annotated benchmark with multiple-choice questions for evaluating short-video understanding.
177 | - 2025.07.29: We release the training code for instruction tuning.
178 | - 2025.07.25: We release the [model checkpoint](https://huggingface.co/TencentARC/ARC-Hunyuan-Video-7B) and inference code of ARC-Hunyuan-Video-7B including [vLLM](https://github.com/vllm-project/vllm) version.
179 | - 2025.07.25: We release the [API service](https://arc.tencent.com/zh/document/ARC-Hunyuan-Video-7B) of ARC-Hunyuan-Video-7B, which is supported by [vLLM](https://github.com/vllm-project/vllm). We release two versions: one is V0, which only supports video description and summarization in Chinese; the other is the version consistent with the model checkpoint and the one described in the paper.
180 |
181 | ## TODOs
182 |
183 | - [x] Release ShortVid-Bench, a specialized, human-annotated benchmark with multiple-choice questions
184 | - [x] Release training code for instruction tuning
185 |
186 | ## Usage
187 |
188 | ### Dependencies
189 | Our inference can be performed on a single NVIDIA A100 40GB GPU.
190 |
191 | ### Installation
192 |
193 | Clone the repo and install dependent packages
194 |
195 | ```bash
196 | git clone https://github.com/TencentARC/ARC-Hunyuan-Video-7B.git
197 | cd ARC-Hunyuan-Video-7B
198 | # Install torch 2.6.0 based on your CUDA version
199 | # CUDA 11.8
200 | pip install torch==2.6.0 torchvision==0.21.0 torchaudio==2.6.0 --index-url https://download.pytorch.org/whl/cu118
201 | # CUDA 12.4
202 | pip install torch==2.6.0 torchvision==0.21.0 torchaudio==2.6.0 --index-url https://download.pytorch.org/whl/cu124
203 | # CUDA 12.6
204 | pip install torch==2.6.0 torchvision==0.21.0 torchaudio==2.6.0 --index-url https://download.pytorch.org/whl/cu126
205 |
206 | pip install -r requirements.txt
207 | pip install git+https://github.com/liyz15/transformers.git@arc_hunyuan_video
208 |
209 | # Install flash-attention based on your python version
210 | # If you are unable to install flash-attention, you can modify attn_implementation to "sdpa" in video_inference.py (see the illustration after this block)
211 | pip install https://github.com/Dao-AILab/flash-attention/releases/download/v2.7.4.post1/flash_attn-2.7.4.post1+cu12torch2.6cxx11abiFALSE-cp311-cp311-linux_x86_64.whl
212 |
213 |
214 | # (Optional) For vllm, please follow the instructions below,
215 | git submodule update --init --recursive
216 | cd model_vllm/vllm/
217 | export SETUPTOOLS_SCM_PRETEND_VERSION="0.8.5"
218 | wget https://wheels.vllm.ai/ed2462030f2ccc84be13d8bb2c7476c84930fb71/vllm-1.0.0.dev-cp38-abi3-manylinux1_x86_64.whl
219 | export VLLM_PRECOMPILED_WHEEL_LOCATION=$(pwd)/vllm-1.0.0.dev-cp38-abi3-manylinux1_x86_64.whl
220 | pip install --editable .
221 | # Install flash-attention if you haven't installed it
222 | pip install https://github.com/Dao-AILab/flash-attention/releases/download/v2.7.4.post1/flash_attn-2.7.4.post1+cu12torch2.6cxx11abiFALSE-cp311-cp311-linux_x86_64.whl
223 | ```
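
For reference, the comment about `attn_implementation` above refers to a change of this kind. The snippet below is only a hypothetical illustration using the standard `transformers` loading API; the actual model class and loading call live in `video_inference.py`:

```python
# Hypothetical illustration, not the repo's actual loading code:
# fall back to PyTorch's built-in SDPA attention when flash-attention is unavailable.
from transformers import AutoModel

model = AutoModel.from_pretrained(
    "TencentARC/ARC-Hunyuan-Video-7B",
    trust_remote_code=True,
    attn_implementation="sdpa",  # instead of "flash_attention_2"
)
```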
224 |
225 | ### Model Weights
226 |
227 | - Download [ARC-Hunyuan-Video-7B](https://huggingface.co/TencentARC/ARC-Hunyuan-Video-7B) (including the ViT and LLM) and the original [whisper-large-v3](https://huggingface.co/openai/whisper-large-v3); a hypothetical download snippet is shown below.
228 |
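A hypothetical way to fetch both checkpoints with `huggingface_hub` (the `local_dir` targets are placeholders; any local paths work):

```python
# Hypothetical download helper using huggingface_hub; adjust local_dir as needed.
from huggingface_hub import snapshot_download

snapshot_download("TencentARC/ARC-Hunyuan-Video-7B", local_dir="./ARC-Hunyuan-Video-7B")
snapshot_download("openai/whisper-large-v3", local_dir="./whisper-large-v3")
```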
229 | ### Inference
230 |
231 | > **Note:** Our model currently excels at processing short videos of up to 5 minutes.
232 | > If your video is longer, we recommend following the approach used in our demo and API:
233 | > split the video into segments, run inference on each segment, and then use an LLM
234 | > to integrate the results (a sketch of the splitting step follows below).
235 | 
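As a minimal, hypothetical sketch of the splitting step (not part of this repo), assuming `moviepy` as already used in the Data Preparation section; each segment can then be passed to `video_inference.py`, and the per-segment outputs merged by an LLM of your choice:

```python
# Hypothetical helper, not part of this repo: cut a long video into
# 5-minute segments so each piece can be processed by video_inference.py.
from moviepy.editor import VideoFileClip

def split_into_segments(video_path, out_prefix, segment_sec=300):
    clip = VideoFileClip(video_path)
    segment_paths = []
    start, idx = 0.0, 0
    while start < clip.duration:
        end = min(start + segment_sec, clip.duration)
        out_path = f"{out_prefix}_{idx:02d}.mp4"
        # Re-encode the segment together with its audio track.
        clip.subclip(start, end).write_videofile(out_path, audio_codec="aac", logger=None)
        segment_paths.append(out_path)
        start, idx = end, idx + 1
    clip.close()
    return segment_paths
```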
236 |
237 | #### Inference without vllm
238 |
239 | ```bash
240 | cd ARC-Hunyuan-Video-7B
241 | python3 video_inference.py
242 | ```
243 |
244 | #### Inference with vllm
245 |
246 | ```bash
247 | cd ARC-Hunyuan-Video-7B
248 | python3 video_inference_vllm.py
249 | ```
250 |
251 | ## Training
252 |
253 | ### Installation
254 |
255 | Clone the repo and install dependent packages
256 |
257 | ```bash
258 | git clone https://github.com/TencentARC/ARC-Hunyuan-Video-7B.git
259 | cd ARC-Hunyuan-Video-7B
260 | # Install torch 2.6.0
261 | pip install torch==2.6.0 torchvision==0.21.0 torchaudio==2.6.0 --index-url https://download.pytorch.org/whl/cu124
262 | pip install -r requirements.txt
263 | pip install git+https://github.com/liyz15/transformers.git@arc_hunyuan_video
264 |
265 | # For training
266 | pip install accelerate==1.9.0
267 | # Upgrade the GCC version to 9.0 or above
268 | sudo dnf install gcc-toolset-9
269 | scl enable gcc-toolset-9 bash
270 | source /opt/rh/gcc-toolset-9/enable
271 | gcc -v
272 | ```
273 |
274 | ### Model Weights
275 |
276 | - Download [ARC-Hunyuan-Video-7B](https://huggingface.co/TencentARC/ARC-Hunyuan-Video-7B) (including the ViT and LLM) and the original [whisper-large-v3](https://huggingface.co/openai/whisper-large-v3).
277 |
278 | ### Data Preparation
279 |
280 | Please follow the format of "sft_data/sft_jb_sp_kd_10.json".
281 |
282 | - "root" specifies the path of training videos (supports .mp4; videos shorter than 5 minutes yield better results).
283 | - "audio_root" specifies the path of corresponding audios (Please use the .mp3 format). You can use the code below to extract audio from a video and save it.
284 |
285 | ```python
286 | from moviepy.editor import VideoFileClip
287 | from pydub import AudioSegment
288 |
289 | video = VideoFileClip(video_path)
290 | if video.audio is not None:
291 | video.audio.write_audiofile(audio_path, logger=None)
292 | video.audio.close()
293 | else:
294 | duration_ms = int(video.duration * 1000)
295 | silent_audio = AudioSegment.silent(duration=duration_ms)
296 | silent_audio.export(audio_path, format="mp3")
297 | video.close()
298 | ```
299 |
300 | - "annotation" specifies the path of the annotation in the format of ".jsonl".
301 |
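As a quick, hypothetical sanity check (not part of the repo), the snippet below verifies that every training video has a same-stem `.mp3`, mirroring the `sft_data/videos` and `sft_data/audios_mp3` layout of the bundled example:

```python
# Hypothetical check, not part of the repo: confirm each .mp4 under
# sft_data/videos has a same-stem .mp3 under sft_data/audios_mp3.
from pathlib import Path

video_dir = Path("sft_data/videos")
audio_dir = Path("sft_data/audios_mp3")

missing = [v.name for v in sorted(video_dir.glob("*.mp4"))
           if not (audio_dir / f"{v.stem}.mp3").exists()]
print("all audio present" if not missing else f"missing audio for: {missing}")
```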
302 | ### Model Fully-finetune
303 |
304 | ```bash
305 | # We use DeepSpeed ZeRO-3 with two 98G H20 GPUs.
306 | bash scripts/arc_hunyuan_video_full_finetune.sh
307 | ```
308 |
309 | ### Model Inference
310 |
311 | After finishing training, the model will be saved in ${OUTPUT_DIR}.
312 |
313 | ```bash
314 | # Copy the model-related config files to the directory.
315 | cd /path/to/the/downloaded/ARC-Hunyuan-Video-7B  # i.e., the directory of the downloaded model weights
316 | cp generation_config.json preprocessor_config.json ${OUTPUT_DIR}/checkpoint-500/.
317 |
318 | cd ARC-Hunyuan-Video-7B
319 | # Modify the prompt based on your fine-tuning data, and specify the path of the fine-tuned model.
320 | python3 video_inference_sft.py
321 | ```
322 |
323 | ## API service
324 |
325 | We also provide access to the model via API, which is supported by [vLLM](https://github.com/vllm-project/vllm). For details, please refer to the [documentation](https://arc.tencent.com/zh/document/ARC-Hunyuan-Video-7B).
326 |
327 | We release two versions: V0, which only supports video description and summarization in Chinese, and the version consistent with the model checkpoint and the one described in the paper, which is capable of multi-granularity timestamped video captioning and summarization, open-ended video question answering, temporal video grounding, and video reasoning (it supports Chinese and English videos and particularly excels at Chinese).
328 | For videos longer than 5 minutes, we only support structured descriptions. We process these videos in 5-minute segments and use an LLM to integrate the inference results.
329 |
330 | If you only need to understand and summarize short Chinese videos, we recommend using the V0 version.
331 |
332 | Due to video file size limitations imposed by the deployment API, we compressed input video resolutions for our online demo and API services. Consequently, model performance in these interfaces may slightly deviate from the results reported in the paper. To reproduce the original performance, we recommend local inference.
333 |
334 |
335 | ## ShortVid-Bench
336 |
337 | Existing benchmarks often fall short in capturing the nuanced complexities
338 | of user-generated content. To rigorously evaluate a model's ability to **understand real-world short videos**,
339 | we construct a specialized benchmark named **ShortVid-Bench**. Specifically, we develop an automated pipeline
340 | to generate multi-dimensional questions for each video, targeting capabilities that signify a deep, holistic
341 | comprehension through integrating both visual and audio cues. These dimensions include:
342 | - Temporal Reasoning and Localization
343 | - Affective Intent Classification
344 | - Creator Intent Taxonomy
345 | - Narrative Comprehension
346 | - Humor & Meme Deconstruction
347 | - Creative Innovation Analysis
348 |
349 | For objective assessment, we employ a multiple-choice question (MCQ) format following previous work. Each question is carefully curated by human annotators who
350 | provide the ground-truth answer and design challenging, plausible distractors. Collectively, these dimensions, covering a total of 1,000 multiple-choice questions,
351 | push the evaluation beyond mere descriptive captioning, demanding genuine comprehension of the video's
352 | context, intent, and narrative.
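As a purely illustrative sketch (not the benchmark's official evaluation code), scoring such an MCQ set reduces to exact matching of the chosen option letters; the inputs below are hypothetical:

```python
# Illustrative only: score multiple-choice predictions against ground truth.
# "predictions" / "answers" are hypothetical lists of option letters ("A"-"D").
def mcq_accuracy(predictions, answers):
    assert len(predictions) == len(answers)
    correct = sum(p.strip().upper() == a.strip().upper()
                  for p, a in zip(predictions, answers))
    return correct / len(answers)

print(mcq_accuracy(["A", "C", "B"], ["A", "D", "B"]))  # 0.666...
```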
353 |
354 |
355 | ![ShortVid-Bench examples](figures/shortvid-bench.jpg)
356 |
357 |
358 | ### Model Performance
359 | | Model | fps | #frames | think | ShortVid-Bench |
360 | | :--- | :--- | :--- | :--- | :--- |
361 | | Qwen2.5-VL-7B-Instruct | 1.0 | 150 | × | 69.3 |
362 | | Qwen2.5-Omni-7B | 1.0 | 150 | × | 69.7 |
363 | | Keye-VL-8B | 1.0 | 150 | ✓ | 56.3 |
364 | | ARC-Hunyuan-Video-7B | 1.0 | 150 | ✓ | **73.0** |
365 |
366 |
367 | Please note that the results in the table above differ from those reported in the
368 | [ARC-Hunyuan-Video-7B technical report](https://arxiv.org/abs/2507.20939).
369 | This is because, after releasing the technical report, we expanded the benchmark to 1,000 samples, whereas the results in the paper were based on 400 samples.
370 |
371 |
372 | ## Future Work
373 |
374 | We observe that incorporating generic video datasets during training may inadvertently compromise the model's capacity for real-world video understanding, potentially due to domain shift or noise introduced by non-real-world samples. To address this limitation, we plan to develop a dedicated model trained exclusively on rigorously curated real-world video data.
375 |
376 | ## Citation
377 |
378 | If you find the work helpful, please consider citing:
379 |
380 | ```bibtex
381 | @article{ge2025arc,
382 | title={ARC-Hunyuan-Video-7B: Structured Video Comprehension of Real-World Shorts},
383 | author={Ge, Yuying and Ge, Yixiao and Li, Chen and Wang, Teng and Pu, Junfu and Li, Yizhuo and Qiu, Lu and Ma, Jin and Duan, Lisheng and Zuo, Xinyu and others},
384 | journal={arXiv preprint arXiv:2507.20939},
385 | year={2025}
386 | }
387 | ```
388 |
389 | ## Acknowledgement
390 |
391 | Our training code is built upon [InternVL](https://github.com/OpenGVLab/InternVL). Thanks for their excellent work!
392 |
--------------------------------------------------------------------------------
/config/zero_stage1_config.json:
--------------------------------------------------------------------------------
1 | {
2 | "zero_optimization": {
3 | "stage": 1,
4 | "allgather_partitions": true,
5 | "allgather_bucket_size": 1e9,
6 | "overlap_comm": true,
7 | "reduce_scatter": true,
8 | "reduce_bucket_size": 1e9,
9 | "contiguous_gradients": true
10 | },
11 | "fp16": {
12 | "enabled": "auto",
13 | "auto_cast": true,
14 | "loss_scale": 0,
15 | "initial_scale_power": 32,
16 | "loss_scale_window": 1000,
17 | "hysteresis": 2,
18 | "min_loss_scale": 1
19 | },
20 | "bf16": {
21 | "enabled": "auto"
22 | },
23 | "optimizer": {
24 | "type": "AdamW",
25 | "params": {
26 | "lr": "auto",
27 | "betas": [
28 | 0.9,
29 | 0.999
30 | ],
31 | "eps": 1e-8,
32 | "weight_decay": "auto"
33 | }
34 | },
35 | "gradient_accumulation_steps": "auto",
36 | "gradient_clipping": "auto",
37 | "steps_per_print": 2000,
38 | "train_batch_size": "auto",
39 | "train_micro_batch_size_per_gpu": "auto",
40 | "wall_clock_breakdown": true
41 | }
42 |
--------------------------------------------------------------------------------
/config/zero_stage2_config.json:
--------------------------------------------------------------------------------
1 | {
2 | "zero_optimization": {
3 | "stage": 2,
4 | "allgather_partitions": true,
5 | "allgather_bucket_size": 1e8,
6 | "overlap_comm": true,
7 | "reduce_scatter": true,
8 | "reduce_bucket_size": 1e8,
9 | "contiguous_gradients": true
10 | },
11 | "bf16": {
12 | "enabled": "auto"
13 | },
14 | "optimizer": {
15 | "type": "AdamW",
16 | "params": {
17 | "lr": "auto",
18 | "betas": [
19 | 0.9,
20 | 0.999
21 | ],
22 | "eps": 1e-8,
23 | "weight_decay": "auto"
24 | }
25 | },
26 | "gradient_accumulation_steps": "auto",
27 | "gradient_clipping": "auto",
28 | "steps_per_print": 2000,
29 | "train_batch_size": "auto",
30 | "train_micro_batch_size_per_gpu": "auto",
31 | "wall_clock_breakdown": false
32 | }
33 |
--------------------------------------------------------------------------------
/config/zero_stage3_config.json:
--------------------------------------------------------------------------------
1 | {
2 | "zero_optimization": {
3 | "stage": 3,
4 | "overlap_comm": true,
5 | "contiguous_gradients": true,
6 | "sub_group_size": 1e9,
7 | "reduce_bucket_size": 1e9,
8 | "stage3_prefetch_bucket_size": 1e9,
9 | "stage3_param_persistence_threshold": 1e7,
10 | "stage3_max_live_parameters": 1e9,
11 | "stage3_max_reuse_distance": 1e9,
12 | "stage3_gather_16bit_weights_on_model_save": true
13 | },
14 | "fp16": {
15 | "enabled": "auto",
16 | "auto_cast": true,
17 | "loss_scale": 0,
18 | "initial_scale_power": 32,
19 | "loss_scale_window": 1000,
20 | "hysteresis": 2,
21 | "min_loss_scale": 1
22 | },
23 | "bf16": {
24 | "enabled": "auto"
25 | },
26 | "optimizer": {
27 | "type": "AdamW",
28 | "params": {
29 | "lr": "auto",
30 | "betas": [
31 | 0.9,
32 | 0.999
33 | ],
34 | "eps": 1e-8,
35 | "weight_decay": "auto"
36 | }
37 | },
38 | "gradient_accumulation_steps": "auto",
39 | "gradient_clipping": "auto",
40 | "steps_per_print": 2000,
41 | "train_batch_size": "auto",
42 | "train_micro_batch_size_per_gpu": "auto",
43 | "wall_clock_breakdown": true
44 | }
45 |
--------------------------------------------------------------------------------
/examples/demo1.mp4:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TencentARC/ARC-Hunyuan-Video-7B/a061f1bf6a4d75f032759fec69c80c71077c3d0c/examples/demo1.mp4
--------------------------------------------------------------------------------
/examples/demo2.mp4:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TencentARC/ARC-Hunyuan-Video-7B/a061f1bf6a4d75f032759fec69c80c71077c3d0c/examples/demo2.mp4
--------------------------------------------------------------------------------
/examples/demo3.mov:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TencentARC/ARC-Hunyuan-Video-7B/a061f1bf6a4d75f032759fec69c80c71077c3d0c/examples/demo3.mov
--------------------------------------------------------------------------------
/examples/temp.mp3:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TencentARC/ARC-Hunyuan-Video-7B/a061f1bf6a4d75f032759fec69c80c71077c3d0c/examples/temp.mp3
--------------------------------------------------------------------------------
/figures/README.md:
--------------------------------------------------------------------------------
1 |
2 |
--------------------------------------------------------------------------------
/figures/method.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TencentARC/ARC-Hunyuan-Video-7B/a061f1bf6a4d75f032759fec69c80c71077c3d0c/figures/method.jpg
--------------------------------------------------------------------------------
/figures/shortvid-bench.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TencentARC/ARC-Hunyuan-Video-7B/a061f1bf6a4d75f032759fec69c80c71077c3d0c/figures/shortvid-bench.jpg
--------------------------------------------------------------------------------
/figures/teaser.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TencentARC/ARC-Hunyuan-Video-7B/a061f1bf6a4d75f032759fec69c80c71077c3d0c/figures/teaser.jpg
--------------------------------------------------------------------------------
/model_train/dist_utils.py:
--------------------------------------------------------------------------------
1 | # --------------------------------------------------------
2 | # InternVL
3 | # Copyright (c) 2024 OpenGVLab
4 | # Licensed under The MIT License [see LICENSE for details]
5 | # --------------------------------------------------------
6 | # Modified from InternVL
7 | # Copyright (c) 2025 ARC Lab
8 | # Licensed under LICENSE [see LICENSE for details]
9 | # --------------------------------------------------------
10 |
11 | import os
12 | import socket
13 | import subprocess
14 | from datetime import timedelta
15 |
16 | import deepspeed
17 | import torch
18 | import torch.multiprocessing as mp
19 | from torch import distributed as dist
20 |
21 | timeout = timedelta(minutes=60)
22 |
23 |
24 | def _find_free_port():
25 | # Copied from https://github.com/facebookresearch/detectron2/blob/main/detectron2/engine/launch.py # noqa: E501
26 | sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
27 | # Binding to port 0 will cause the OS to find an available port for us
28 | sock.bind(('', 0))
29 | port = sock.getsockname()[1]
30 | sock.close()
31 | # NOTE: there is still a chance the port could be taken by other processes.
32 | return port
33 |
34 |
35 | def _is_free_port(port):
36 | ips = socket.gethostbyname_ex(socket.gethostname())[-1]
37 | ips.append('localhost')
38 | with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
39 | return all(s.connect_ex((ip, port)) != 0 for ip in ips)
40 |
41 |
42 | def init_dist(launcher, backend='nccl', **kwargs):
43 | if mp.get_start_method(allow_none=True) is None:
44 | mp.set_start_method('spawn')
45 | if launcher == 'pytorch':
46 | _init_dist_pytorch(backend, **kwargs)
47 | elif launcher == 'mpi':
48 | _init_dist_mpi(backend, **kwargs)
49 | elif launcher == 'slurm':
50 | _init_dist_slurm(backend, **kwargs)
51 | else:
52 | raise ValueError(f'Invalid launcher type: {launcher}')
53 |
54 |
55 | def _init_dist_pytorch(backend, **kwargs):
56 | # TODO: use local_rank instead of rank % num_gpus
57 | rank = int(os.environ['RANK'])
58 | num_gpus = torch.cuda.device_count()
59 | torch.cuda.set_device(rank % num_gpus)
60 | # dist.init_process_group(backend=backend, **kwargs)
61 | deepspeed.init_distributed(dist_backend=backend)
62 |
63 |
64 | def _init_dist_mpi(backend, **kwargs):
65 | local_rank = int(os.environ['OMPI_COMM_WORLD_LOCAL_RANK'])
66 | torch.cuda.set_device(local_rank)
67 | if 'MASTER_PORT' not in os.environ:
68 | # 29500 is torch.distributed default port
69 | os.environ['MASTER_PORT'] = '29500'
70 | if 'MASTER_ADDR' not in os.environ:
71 | raise KeyError('The environment variable MASTER_ADDR is not set')
72 | os.environ['WORLD_SIZE'] = os.environ['OMPI_COMM_WORLD_SIZE']
73 | os.environ['RANK'] = os.environ['OMPI_COMM_WORLD_RANK']
74 | dist.init_process_group(backend=backend, **kwargs)
75 |
76 |
77 | def _init_dist_slurm(backend, port=None):
78 | """Initialize slurm distributed training environment.
79 |
80 | If argument ``port`` is not specified, then the master port will be system
81 | environment variable ``MASTER_PORT``. If ``MASTER_PORT`` is not in system
82 | environment variable, then a default port ``29500`` will be used.
83 |
84 | Args:
85 | backend (str): Backend of torch.distributed.
86 | port (int, optional): Master port. Defaults to None.
87 | """
88 | proc_id = int(os.environ['SLURM_PROCID'])
89 | ntasks = int(os.environ['SLURM_NTASKS'])
90 | node_list = os.environ['SLURM_NODELIST']
91 | num_gpus = torch.cuda.device_count()
92 | torch.cuda.set_device(proc_id % num_gpus)
93 | addr = subprocess.getoutput(
94 | f'scontrol show hostname {node_list} | head -n1')
95 | # specify master port
96 | if port is not None:
97 | os.environ['MASTER_PORT'] = str(port)
98 | elif 'MASTER_PORT' in os.environ:
99 | pass # use MASTER_PORT in the environment variable
100 | else:
101 | # if torch.distributed default port(29500) is available
102 | # then use it, else find a free port
103 | if _is_free_port(29500):
104 | os.environ['MASTER_PORT'] = '29500'
105 | else:
106 | os.environ['MASTER_PORT'] = str(_find_free_port())
107 | # use MASTER_ADDR in the environment variable if it already exists
108 | if 'MASTER_ADDR' not in os.environ:
109 | os.environ['MASTER_ADDR'] = addr
110 | os.environ['WORLD_SIZE'] = str(ntasks)
111 | os.environ['LOCAL_RANK'] = str(proc_id % num_gpus)
112 | os.environ['RANK'] = str(proc_id)
113 | # dist.init_process_group(backend=backend, timeout=timeout)
114 | deepspeed.init_distributed(dist_backend=backend)
115 |
--------------------------------------------------------------------------------
/model_train/patch/__init__.py:
--------------------------------------------------------------------------------
1 | # --------------------------------------------------------
2 | # InternVL
3 | # Copyright (c) 2024 OpenGVLab
4 | # Licensed under The MIT License [see LICENSE for details]
5 | # --------------------------------------------------------
6 |
7 | from .pad_data_collator import (concat_pad_data_collator,
8 | dpo_concat_pad_data_collator,
9 | pad_data_collator)
10 | from .train_dataloader_patch import replace_train_dataloader
11 | from .train_sampler_patch import replace_train_sampler
12 |
13 | __all__ = ['replace_train_sampler',
14 |            'replace_train_dataloader',
15 |            'pad_data_collator',
16 |            'dpo_concat_pad_data_collator',
17 |            'concat_pad_data_collator']
26 |
--------------------------------------------------------------------------------
/model_train/patch/pad_data_collator.py:
--------------------------------------------------------------------------------
1 | # --------------------------------------------------------
2 | # InternVL
3 | # Copyright (c) 2024 OpenGVLab
4 | # Licensed under The MIT License [see LICENSE for details]
5 | # --------------------------------------------------------
6 | # Modified from InternVL
7 | # Copyright (c) 2025 ARC Lab
8 | # Licensed under LICENSE [see LICENSE for details]
9 | # --------------------------------------------------------
10 |
11 | import numpy as np
12 | import torch
13 |
14 | IGNORE_INDEX = -100
15 |
16 | def pad_data_collator(features, pad_id=0):
17 |
18 | first = features[0]
19 | batch = {}
20 |
21 | batch_lens = [feat['input_ids'].shape for feat in features]
22 | max_item_length = max(batch_lens)[0]
23 | for idx in range(len(features)):
24 | feat = features[idx]
25 | temp_input_ids = torch.LongTensor([pad_id] * max_item_length)
26 | temp_input_ids[:feat['input_ids'].shape[0]] = feat['input_ids']
27 | feat['input_ids'] = temp_input_ids
28 | temp_labels = torch.LongTensor([IGNORE_INDEX] * max_item_length)
29 | temp_labels[:feat['labels'].shape[0]] = feat['labels']
30 | feat['labels'] = temp_labels
31 | feat['attention_mask'] = feat['input_ids'].ne(pad_id)
32 |
33 | # Special handling for labels.
34 | # Ensure that tensor is created with the correct type
35 | # (it should be automatically the case, but let's make sure of it.)
36 | if 'label' in first and first['label'] is not None:
37 | label = first['label'].item() if isinstance(first['label'], torch.Tensor) else first['label']
38 | dtype = torch.long if isinstance(label, int) else torch.float
39 | batch['labels'] = torch.tensor([f['label'] for f in features], dtype=dtype)
40 | elif 'label_ids' in first and first['label_ids'] is not None:
41 | if isinstance(first['label_ids'], torch.Tensor):
42 | batch['labels'] = torch.stack([f['label_ids'] for f in features])
43 | else:
44 | dtype = torch.long if isinstance(first['label_ids'][0], int) else torch.float
45 | batch['labels'] = torch.tensor([f['label_ids'] for f in features], dtype=dtype)
46 |
47 | # Handling of all other possible keys.
48 | # Again, we will use the first element to figure out which key/values are not None for this model.
49 | for k, v in first.items():
50 | if k not in ('label', 'label_ids') and v is not None and not isinstance(v, str):
51 | if isinstance(v, torch.Tensor):
52 | batch[k] = torch.stack([f[k] for f in features])
53 | elif isinstance(v, np.ndarray):
54 | batch[k] = torch.tensor(np.stack([f[k] for f in features]))
55 | else:
56 | batch[k] = torch.tensor([f[k] for f in features])
57 | return batch
58 |
59 |
60 | def concat_pad_data_collator(features, max_item_length=None, pad_id=0):
61 |
62 | first = features[0]
63 | batch = {}
64 |
65 | batch_lens = [feat['input_ids'].shape for feat in features]
66 | max_item_length = max_item_length or max(batch_lens)[0]
67 | for idx in range(len(features)):
68 | feat = features[idx]
69 | temp_input_ids = torch.LongTensor([pad_id] * max_item_length)
70 | temp_input_ids[:feat['input_ids'].shape[0]] = feat['input_ids']
71 | feat['input_ids'] = temp_input_ids
72 | temp_labels = torch.LongTensor([IGNORE_INDEX] * max_item_length)
73 | temp_labels[:feat['labels'].shape[0]] = feat['labels']
74 | feat['labels'] = temp_labels
75 | feat['attention_mask'] = feat['input_ids'].ne(pad_id)
76 |
77 | # if 'position_ids' in feat:
78 | # temp_position_ids = torch.LongTensor([pad_id] * max_item_length)
79 | # temp_position_ids[:feat['position_ids'].shape[0]] = feat['position_ids']
80 | # feat['position_ids'] = temp_position_ids
81 |
82 | if 'loss_weight' in feat:
83 | temp_loss_weight = torch.FloatTensor([pad_id] * max_item_length)
84 | temp_loss_weight[:feat['loss_weight'].shape[0]] = feat['loss_weight']
85 | feat['loss_weight'] = temp_loss_weight
86 |
87 | # Special handling for labels.
88 | # Ensure that tensor is created with the correct type
89 | # (it should be automatically the case, but let's make sure of it.)
90 | if 'label' in first and first['label'] is not None:
91 | label = first['label'].item() if isinstance(first['label'], torch.Tensor) else first['label']
92 | dtype = torch.long if isinstance(label, int) else torch.float
93 | batch['labels'] = torch.tensor([f['label'] for f in features], dtype=dtype)
94 | elif 'label_ids' in first and first['label_ids'] is not None:
95 | if isinstance(first['label_ids'], torch.Tensor):
96 | batch['labels'] = torch.stack([f['label_ids'] for f in features])
97 | else:
98 | dtype = torch.long if isinstance(first['label_ids'][0], int) else torch.float
99 | batch['labels'] = torch.tensor([f['label_ids'] for f in features], dtype=dtype)
100 |
101 | # Handling of all other possible keys.
102 | # Again, we will use the first element to figure out which key/values are not None for this model.
103 | for k, v in first.items():
104 | if k not in ('label', 'label_ids', 'pixel_values', 'image_flags') and \
105 | v is not None and not isinstance(v, str):
106 | if isinstance(v, torch.Tensor):
107 | batch[k] = torch.stack([f[k] for f in features])
108 | elif isinstance(v, np.ndarray):
109 | batch[k] = torch.tensor(np.stack([f[k] for f in features]))
110 | else:
111 | batch[k] = torch.tensor([f[k] for f in features])
112 | if k in ('pixel_values', 'image_flags'):
113 | if isinstance(v, torch.Tensor):
114 | batch[k] = torch.concat([f[k] for f in features])
115 | elif isinstance(v, np.ndarray):
116 | batch[k] = torch.concat(np.stack([f[k] for f in features]))
117 | else:
118 | batch[k] = torch.concat([f[k] for f in features])
119 | return batch
120 |
121 |
122 | def dpo_concat_pad_data_collator(features, pad_id=0):
123 |
124 | first = features[0]
125 | batch = {}
126 |
127 | for prefix in ['chosen_', 'rejected_']:
128 | batch_lens = [feat[f'{prefix}input_ids'].shape[0] for feat in features]
129 | max_item_length = max(batch_lens)
130 | for idx in range(len(features)):
131 | feat = features[idx]
132 | temp_input_ids = torch.LongTensor([pad_id] * max_item_length)
133 | temp_input_ids[:feat[f'{prefix}input_ids'].shape[0]] = feat[f'{prefix}input_ids']
134 | feat[f'{prefix}input_ids'] = temp_input_ids
135 | temp_labels = torch.LongTensor([IGNORE_INDEX] * max_item_length)
136 | temp_labels[:feat[f'{prefix}labels'].shape[0]] = feat[f'{prefix}labels']
137 | feat[f'{prefix}labels'] = temp_labels
138 | feat[f'{prefix}attention_mask'] = feat[f'{prefix}input_ids'].ne(pad_id)
139 |
140 | # Handling of all other possible keys.
141 | # Again, we will use the first element to figure out which key/values are not None for this model.
142 | for k, v in first.items():
143 | if k not in ('pixel_values', 'image_flags') and \
144 | v is not None and not isinstance(v, str):
145 | if isinstance(v, torch.Tensor):
146 | batch[k] = torch.stack([f[k] for f in features])
147 | elif isinstance(v, np.ndarray):
148 | batch[k] = torch.tensor(np.stack([f[k] for f in features]))
149 | else:
150 | batch[k] = torch.tensor([f[k] for f in features])
151 | if k in ('pixel_values', 'image_flags'):
152 | if isinstance(v, torch.Tensor):
153 | batch[k] = torch.concat([f[k] for f in features])
154 | elif isinstance(v, np.ndarray):
155 | batch[k] = torch.concat(np.stack([f[k] for f in features]))
156 | else:
157 | batch[k] = torch.concat([f[k] for f in features])
158 | if '_logps' in k:
159 | batch[k] = [f[k] for f in features]
160 | return batch
161 |
--------------------------------------------------------------------------------
/model_train/patch/train_dataloader_patch.py:
--------------------------------------------------------------------------------
1 | # --------------------------------------------------------
2 | # InternVL
3 | # Copyright (c) 2024 OpenGVLab
4 | # Licensed under The MIT License [see LICENSE for details]
5 | # --------------------------------------------------------
6 | # Modified from InternVL
7 | # Copyright (c) 2025 ARC Lab
8 | # Licensed under LICENSE [see LICENSE for details]
9 | # --------------------------------------------------------
10 |
11 | import datasets
12 | import torch
13 | import transformers
14 | from functools import partial
15 | import torch.distributed as dist
16 | from torch.utils.data import DataLoader
17 | from transformers.trainer import is_datasets_available, seed_worker
18 |
19 |
20 | def get_train_dataloader(self) -> DataLoader:
21 | """
22 | Returns the training [`~torch.utils.data.DataLoader`].
23 |
24 | Will use no sampler if `train_dataset` does not implement `__len__`, a random sampler (adapted to distributed
25 | training if necessary) otherwise.
26 |
27 | Subclass and override this method if you want to inject some custom behavior.
28 | """
29 | if self.train_dataset is None:
30 | raise ValueError('Trainer: training requires a train_dataset.')
31 |
32 | train_dataset = self.train_dataset
33 | data_collator = self.data_collator
34 | if is_datasets_available() and isinstance(train_dataset, datasets.Dataset):
35 | train_dataset = self._remove_unused_columns(train_dataset, description='training')
36 | else:
37 | data_collator = self._get_collator_with_removed_columns(data_collator, description='training')
38 |
39 | dataloader_params = {
40 | 'batch_size': self._train_batch_size,
41 | 'collate_fn': data_collator,
42 | 'num_workers': self.args.dataloader_num_workers,
43 | 'pin_memory': self.args.dataloader_pin_memory,
44 | 'persistent_workers': self.args.dataloader_persistent_workers,
45 | }
46 |
47 | if not isinstance(train_dataset, torch.utils.data.IterableDataset):
48 | dataloader_params['sampler'] = self._get_train_sampler()
49 | dataloader_params['drop_last'] = self.args.dataloader_drop_last
50 |
51 | num_workers = self.args.dataloader_num_workers
52 | rank = dist.get_rank() if dist.is_initialized() else 0
53 |
54 | dataloader_params['worker_init_fn'] = partial(seed_worker, num_workers=num_workers, rank=rank)
55 | #dataloader_params['worker_init_fn'] = seed_worker
56 |
57 | if self.args.use_packed_ds:
58 | return DataLoader(train_dataset, **dataloader_params)
59 | return self.accelerator.prepare(DataLoader(train_dataset, **dataloader_params))
60 |
61 |
62 | def replace_train_dataloader():
63 | transformers.Trainer.get_train_dataloader = get_train_dataloader
64 | # print('Replace train dataloader!!')
65 |
--------------------------------------------------------------------------------
/model_train/patch/train_sampler_patch.py:
--------------------------------------------------------------------------------
1 | # --------------------------------------------------------
2 | # InternVL
3 | # Copyright (c) 2024 OpenGVLab
4 | # Licensed under The MIT License [see LICENSE for details]
5 | # --------------------------------------------------------
6 | # Modified from InternVL
7 | # Copyright (c) 2025 ARC Lab
8 | # Licensed under LICENSE [see LICENSE for details]
9 | # --------------------------------------------------------
10 |
11 | from typing import List, Optional
12 |
13 | import torch
14 | import transformers
15 | from torch.utils.data import Dataset, Sampler
16 | from transformers.tokenization_utils_base import BatchEncoding
17 | from transformers.trainer import (LengthGroupedSampler, RandomSampler,
18 | has_length)
19 | from transformers.trainer_pt_utils import logger
20 |
21 |
22 | # copy from https://github.com/haotian-liu/LLaVA/blob/main/llava/train/llava_trainer.py#L38
23 | def split_to_even_chunks(indices, lengths, num_chunks):
24 | """
25 | Split a list of indices into `chunks` chunks of roughly equal lengths.
26 | """
27 |
28 | if len(indices) % num_chunks != 0:
29 | return [indices[i::num_chunks] for i in range(num_chunks)]
30 |
31 | num_indices_per_chunk = len(indices) // num_chunks
32 |
33 | chunks = [[] for _ in range(num_chunks)]
34 | chunks_lengths = [0 for _ in range(num_chunks)]
35 | for index in indices:
36 | shortest_chunk = chunks_lengths.index(min(chunks_lengths))
37 | chunks[shortest_chunk].append(index)
38 | chunks_lengths[shortest_chunk] += lengths[index]
39 | if len(chunks[shortest_chunk]) == num_indices_per_chunk:
40 | chunks_lengths[shortest_chunk] = float('inf')
41 |
42 | return chunks
43 |
44 |
45 | # copy from https://github.com/haotian-liu/LLaVA/blob/main/llava/train/llava_trainer.py#L88
46 | def get_length_grouped_indices(lengths, batch_size, world_size, generator=None, merge=True):
47 | # We need to use torch for the random part as a distributed sampler will set the random seed for torch.
48 | indices = torch.randperm(len(lengths), generator=generator)
49 | megabatch_size = world_size * batch_size
50 | megabatches = [indices[i : i + megabatch_size].tolist() for i in range(0, len(lengths), megabatch_size)]
51 | megabatches = [sorted(megabatch, key=lambda i: lengths[i], reverse=True) for megabatch in megabatches]
52 | megabatches = [split_to_even_chunks(megabatch, lengths, world_size) for megabatch in megabatches]
53 |
54 | return [i for megabatch in megabatches for batch in megabatch for i in batch]
55 |
56 |
57 | # modified from https://github.com/haotian-liu/LLaVA/blob/main/llava/train/llava_trainer.py#L99
58 | class LengthGroupedSampler(Sampler):
59 | r"""
60 | Sampler that samples indices in a way that groups together features of the dataset of roughly the same length while
61 | keeping a bit of randomness.
62 | """
63 |
64 | def __init__(
65 | self,
66 | batch_size: int,
67 | world_size: int,
68 | dataset: Optional[Dataset] = None,
69 | lengths: Optional[List[int]] = None,
70 | model_input_name: Optional[str] = None,
71 | generator=None,
72 | ):
73 | if dataset is None and lengths is None:
74 | raise ValueError('One of dataset and lengths must be provided.')
75 |
76 | self.batch_size = batch_size
77 | if lengths is None:
78 | model_input_name = model_input_name if model_input_name is not None else 'input_ids'
79 | if (
80 | not (isinstance(dataset[0], dict) or isinstance(dataset[0], BatchEncoding))
81 | or model_input_name not in dataset[0]
82 | ):
83 | raise ValueError(
84 | 'Can only automatically infer lengths for datasets whose items are dictionaries with an '
85 | f"'{model_input_name}' key."
86 | )
87 | lengths = [len(feature[model_input_name]) for feature in dataset]
88 | elif isinstance(lengths, torch.Tensor):
89 | logger.info(
90 | 'If lengths is a torch.Tensor, LengthGroupedSampler will be slow. Converting lengths to List[int]...'
91 | )
92 | lengths = lengths.tolist()
93 | self.world_size = world_size
94 | self.lengths = lengths
95 | self.generator = generator
96 |
97 | def __len__(self):
98 | return len(self.lengths)
99 |
100 | def __iter__(self):
101 | indices = get_length_grouped_indices(self.lengths, self.batch_size, self.world_size, generator=self.generator)
102 | return iter(indices)
103 |
104 |
105 | # patch trainer
106 | def _get_train_sampler(self) -> Optional[torch.utils.data.Sampler]:
107 | if self.train_dataset is None or not has_length(self.train_dataset):
108 | return None
109 | # Build the sampler.
110 | if self.args.group_by_length:
111 | lengths = []
112 | for dataset in self.train_dataset.datasets:
113 | lengths = lengths + dataset.length
114 | model_input_name = self.tokenizer.model_input_names[0] if self.tokenizer is not None else None
115 | return LengthGroupedSampler(
116 | self.args.train_batch_size,
117 | world_size=self.args.world_size * self.args.gradient_accumulation_steps,
118 | # self.args.train_batch_size * self.args.gradient_accumulation_steps,
119 | dataset=self.train_dataset,
120 | lengths=lengths,
121 | model_input_name=model_input_name,
122 | )
123 | else:
124 | return RandomSampler(self.train_dataset)
125 |
126 |
127 | def replace_train_sampler():
128 | transformers.Trainer._get_train_sampler = _get_train_sampler
129 | # print('Replace train sampler!!')
130 |
--------------------------------------------------------------------------------
/model_train/train/__init__.py:
--------------------------------------------------------------------------------
1 | # --------------------------------------------------------
2 | # InternVL
3 | # Copyright (c) 2024 OpenGVLab
4 | # Licensed under The MIT License [see LICENSE for details]
5 | # --------------------------------------------------------
6 | # Modified from InternVL
7 | # Copyright (c) 2025 ARC Lab
8 | # Licensed under LICENSE [see LICENSE for details]
9 | # --------------------------------------------------------
10 |
--------------------------------------------------------------------------------
/model_train/train/constants.py:
--------------------------------------------------------------------------------
1 | # --------------------------------------------------------
2 | # InternVL
3 | # Copyright (c) 2024 OpenGVLab
4 | # Licensed under The MIT License [see LICENSE for details]
5 | # --------------------------------------------------------
6 | # Modified from InternVL
7 | # Copyright (c) 2025 ARC Lab
8 | # Licensed under LICENSE [see LICENSE for details]
9 | # --------------------------------------------------------
10 |
11 | IMG_CONTEXT_TOKEN = ''
12 | IMG_START_TOKEN = ''
13 | IMG_END_TOKEN = ''
14 | VID_START_TOKEN = ''
15 | VID_END_TOKEN = ''
16 | LINE_TOKEN = ''
17 | IMAGENET_MEAN = (0.485, 0.456, 0.406)
18 | IMAGENET_STD = (0.229, 0.224, 0.225)
19 | CLIP_MEAN = (0.4814546, 0.4578275, 0.40821073)
20 | CLIP_STD = (0.2686295, 0.2613025, 0.2757711)
21 | SIGLIP_MEAN = (0.5, 0.5, 0.5)
22 | SIGLIP_STD = (0.5, 0.5, 0.5)
23 | HUNYUAN_MEAN = (0.48145466, 0.4578275, 0.40821073)
24 | HUNYUAN_STD = (0.26862954, 0.26130258, 0.27577711)
25 |
--------------------------------------------------------------------------------
/model_train/train/dataset.py:
--------------------------------------------------------------------------------
1 | # --------------------------------------------------------
2 | # InternVL
3 | # Copyright (c) 2024 OpenGVLab
4 | # Licensed under The MIT License [see LICENSE for details]
5 | # --------------------------------------------------------
6 | # Modified from InternVL
7 | # Copyright (c) 2025 ARC Lab
8 | # Licensed under LICENSE [see LICENSE for details]
9 | # --------------------------------------------------------
10 |
11 | import io
12 |
13 | from transformers.trainer_pt_utils import LabelSmoother
14 |
15 | IGNORE_TOKEN_ID = LabelSmoother.ignore_index
16 | import os
17 | import random
18 | import re
19 | from collections import Counter
20 | from typing import Dict
21 | import math
22 | import cv2
23 | import imageio
24 | import numpy as np
25 | import torch
26 | import warnings
27 | import torch.nn.functional as F
28 | import torchvision.transforms as T
29 | import transformers
30 | from decord import VideoReader
31 | from PIL import Image
32 | from PIL import ImageDraw, ImageFont
33 | from torch.utils.data import ConcatDataset, WeightedRandomSampler
34 | from torchvision.transforms.functional import InterpolationMode
35 |
36 | from .constants import (CLIP_MEAN, CLIP_STD, IMAGENET_MEAN, IMAGENET_STD,
37 | IMG_CONTEXT_TOKEN, IMG_END_TOKEN, IMG_START_TOKEN,
38 | VID_END_TOKEN, VID_START_TOKEN, LINE_TOKEN,
39 | SIGLIP_MEAN, SIGLIP_STD, HUNYUAN_MEAN, HUNYUAN_STD)
40 |
41 | xdrope_section = [
42 | 0.25,
43 | 0.25,
44 | 0.25,
45 | 0.25
46 | ]
47 |
48 | def calculate_ngram_repetition(text, n):
49 | words = text.split()
50 | ngrams = [tuple(words[i:i+n]) for i in range(len(words)-n+1)]
51 | ngram_counts = Counter(ngrams)
52 | total_ngrams = len(ngrams)
53 | repeated_ngrams = sum(1 for count in ngram_counts.values() if count > 1)
54 | return repeated_ngrams / total_ngrams if total_ngrams > 0 else 0
55 |
56 | def check_conversations_repetition(conversations, repeat_threshold=0.4, ngram=10):
57 | for conversation in conversations:
58 | if conversation['from'] == 'gpt':
59 | model_answer = conversation['value']
60 | repeat_ratio = calculate_ngram_repetition(model_answer, ngram)
61 | if repeat_ratio > repeat_threshold:
62 | raise Exception
63 |
64 | def get_frame_indices(vlen, sample='rand', fix_start=None, input_fps=1, max_num_frames=-1):
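    # vlen is the total number of frames and input_fps the source fps, so duration = vlen / input_fps (seconds).
    # If the video is at most max_num_frames seconds long, one frame is sampled per 1-second window;
    # otherwise the video is split into max_num_frames equal-length segments with one frame sampled per segment.
    # Returns the sampled frame indices together with each frame's (start_sec, end_sec) window,
    # which is later drawn onto the frame by add_timestamp_to_frame.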
65 | duration = vlen / input_fps
66 |
67 | frames_per_second = input_fps
68 |
69 | # Videos longer than 300 s (5 minutes) are still supported, but with the warning below.
70 | if duration > 300:
71 | warnings.warn("The video is longer than 5 minutes. Due to sampling, some audio information may be lost!")
72 |
73 | if duration <= max_num_frames:
74 | interval = 1
75 | intervals = [(int(i * interval * frames_per_second), int((i + 1) * interval * frames_per_second)) for i in range(math.ceil(duration))]
76 | intervals_sec = [(int(i * interval), int((i + 1) * interval)) for i in range(math.ceil(duration))]
77 | else:
78 | num_segments = max_num_frames
79 | segment_duration = duration / num_segments
80 | intervals = [(int(i * segment_duration * frames_per_second), int((i + 1) * segment_duration * frames_per_second)) for i in range(num_segments)]
81 | intervals_sec = [(round(i * segment_duration), round((i + 1) * segment_duration)) for i in range(num_segments)]
82 |
83 | frame_indices = []
84 |
85 | if sample == 'rand':
86 | for start, end in intervals:
87 | if end > vlen:
88 | end = vlen
89 | frame_indices.append(random.choice(range(start, end)))
90 | elif sample == 'middle':
91 | for start, end in intervals:
92 | if end > vlen:
93 | end = vlen
94 | frame_indices.append((start + end) // 2)
95 | else:
96 | raise NotImplementedError
97 |
98 | return frame_indices, intervals_sec
99 |
100 | def seconds_to_mmss(seconds):
101 | m = int(seconds // 60)
102 | s = int(seconds % 60)
103 | return f"{m:02d}:{s:02d}"
104 |
105 | def sec2hms(seconds):
106 | seconds = int(round(seconds))
107 | h = seconds // 3600
108 | m = (seconds % 3600) // 60
109 | s = seconds % 60
110 | return f"{h:02d}:{m:02d}:{s:02d}"
111 |
112 | def add_timestamp_to_frame(frame, start_sec, end_sec, font_size=40):
113 | draw = ImageDraw.Draw(frame)
114 | font_size = int(frame.height * 0.05)
115 | font = ImageFont.truetype("ARIAL.TTF", font_size)
116 | text = f"{sec2hms(start_sec)}-{sec2hms(end_sec)}"
117 | bbox = draw.textbbox((0, 0), text, font=font)
118 | text_w = bbox[2] - bbox[0]
119 | text_h = bbox[3] - bbox[1]
120 | x = frame.width - text_w - 20
121 | y = 20
122 | draw.rectangle([x-10, y-10, x+text_w+10, y+text_h+10], fill=(0,0,0,180))
123 | draw.text((x, y), text, fill=(255,255,255), font=font)
124 | return frame
125 |
126 | def read_frames_decord(
127 | video_path, num_frames, sample='rand', fix_start=None,
128 | client=None, clip=None, use_time=False, min_num_frames=4
129 | ):
130 | video_reader = VideoReader(video_path, num_threads=1)
131 |
132 | vlen = len(video_reader)
133 | fps = video_reader.get_avg_fps()
134 | duration = vlen / float(fps)
135 |
136 | frame_indices, intervals_sec = get_frame_indices(
137 | vlen, sample=sample, fix_start=fix_start,
138 | input_fps=fps, max_num_frames=num_frames
139 | )
140 |
141 | frames = video_reader.get_batch(frame_indices).asnumpy() # (T, H, W, C), np.uint8
142 | frames = [Image.fromarray(frames[i]) for i in range(frames.shape[0])]
143 |
144 | if use_time:
145 | frames_with_ts = []
146 | for i, frame in enumerate(frames):
147 | start_sec, end_sec = intervals_sec[i]
148 | frame_with_ts = add_timestamp_to_frame(frame, start_sec, end_sec)
149 | frames_with_ts.append(frame_with_ts)
150 | frames = frames_with_ts
151 |
152 | save_dir = 'output_frames'
153 | if not os.path.exists(save_dir):
154 | os.makedirs(save_dir, exist_ok=True)
155 |
156 | for idx, frame in enumerate(frames):
157 | video_name = video_path.split('/')[-1].replace('.mp4', '')
158 | frame.save(os.path.join(save_dir, f'{video_name}_frame_{idx:03d}.jpg'))
159 |
160 | return frames
161 |
162 | def extract_frame_number(filename):
163 | # Extract the numeric part from the filename using regular expressions
164 | match = re.search(r'_(\d+).jpg$', filename)
165 | return int(match.group(1)) if match else -1
166 |
167 | def sort_frames(frame_paths):
168 | # Extract filenames from each path and sort by their numeric part
169 | return sorted(frame_paths, key=lambda x: extract_frame_number(os.path.basename(x)))
170 |
171 | def read_frames_folder(
172 | video_path, num_frames, sample='rand', fix_start=None,
173 | client=None, clip=None, min_num_frames=4
174 | ):
175 | image_list = sort_frames(list(os.listdir(video_path)))
176 | frames = []
177 | for image in image_list:
178 | fp = os.path.join(video_path, image)
179 | frame = Image.open(fp).convert('RGB')
180 | frames.append(frame)
181 |
182 | return frames
183 |
184 | class WeightedConcatDataset(ConcatDataset):
185 | def __init__(self, datasets, weights):
186 | super().__init__(datasets)
187 | self.weights = torch.DoubleTensor(weights)
188 | self.total_size = sum(len(d) for d in datasets)
189 | self.sampler = WeightedRandomSampler(weights=self.weights, num_samples=self.total_size, replacement=True)
190 |
191 | def __iter__(self):
192 | return iter(self.sampler)
193 |
194 | def __len__(self):
195 | return self.total_size
196 |
197 | def pil_loader(img_str):
198 | buff = io.BytesIO(img_str)
199 | img = Image.open(buff)
200 | return img.convert('RGB')
201 |
202 | class TCSLoader(object):
203 |
204 | def __init__(self, conf_path, sc_config_key='sensecore'):
205 | print(f'[TCSLoader] config_path: {conf_path}')
206 | # print('--> before Client(conf_path)')
207 | # self.client = Client(conf_path)
208 | # self.sc_config_key = sc_config_key
209 | # print('--> after Client(conf_path)')
210 | self.client = None
211 |
212 | def __call__(self, fn, image_type='image', max_num_frames=-1, min_num_frames=8, sample='rand', use_time=False, clip=None):
213 | #print(image_type, max_num_frames, min_num_frames, clip)
214 | if image_type == 'image':
215 | img_value_str = self.client.get(fn)
216 | img = pil_loader(img_value_str)
217 | return img
218 |
219 | elif image_type == 'video':
220 | if fn.endswith('/'):
221 | frames = read_frames_folder(fn, num_frames=max_num_frames, min_num_frames=min_num_frames,
222 | client=self.client, sample=sample)
223 | else:
224 | frames = read_frames_decord(fn, num_frames=max_num_frames, min_num_frames=min_num_frames,
225 | client=self.client, sample=sample, use_time=use_time, clip=clip)
226 | return frames
227 |
228 |
229 | def expand2square(pil_img, background_color):
230 | width, height = pil_img.size
231 | if width == height:
232 | return pil_img
233 | elif width > height:
234 | result = Image.new(pil_img.mode, (width, width), background_color)
235 | result.paste(pil_img, (0, (width - height) // 2))
236 | return result
237 | else:
238 | result = Image.new(pil_img.mode, (height, height), background_color)
239 | result.paste(pil_img, ((height - width) // 2, 0))
240 | return result
241 |
242 |
243 | def simulate_jpeg_degradation(quality):
244 | def jpeg_degrade(img):
245 | with io.BytesIO() as output:
246 | img.convert('RGB').save(output, format='JPEG', quality=quality)
247 | output.seek(0) # Move the reading cursor to the start of the stream
248 | img_jpeg = Image.open(output).copy() # Use .copy() to make sure the image is loaded in memory
249 | return img_jpeg
250 | return jpeg_degrade
251 |
252 |
253 | # Define the JPEG compression quality range, pre-create all JPEG compression functions
254 | qualities = list(range(75, 101))
255 | jpeg_degrade_functions = {quality: simulate_jpeg_degradation(quality) for quality in qualities}
256 |
257 |
258 | def build_transform(is_train, input_size, pad2square=False, normalize_type='imagenet'):
259 | if normalize_type == 'imagenet':
260 | MEAN, STD = IMAGENET_MEAN, IMAGENET_STD
261 | elif normalize_type == 'clip':
262 | MEAN, STD = CLIP_MEAN, CLIP_STD
263 | elif normalize_type == 'siglip':
264 | MEAN, STD = SIGLIP_MEAN, SIGLIP_STD
265 | elif normalize_type == 'hunyuan':
266 | MEAN, STD = HUNYUAN_MEAN, HUNYUAN_STD
267 | else:
268 | raise NotImplementedError
269 |     if is_train: # use data augmentation
270 | transform = T.Compose([
271 | T.Lambda(lambda img: img.convert('RGB') if img.mode != 'RGB' else img),
272 | T.RandomChoice([T.Lambda(jpeg_degrade_functions[quality]) for quality in qualities]),
273 | T.Resize((input_size, input_size), interpolation=InterpolationMode.BICUBIC),
274 | T.ToTensor(),
275 | T.Normalize(mean=MEAN, std=STD)
276 | ])
277 | else:
278 | if pad2square is False: # now we use this transform function by default
279 | transform = T.Compose([
280 | T.Lambda(lambda img: img.convert('RGB') if img.mode != 'RGB' else img),
281 | T.Resize((input_size, input_size), interpolation=InterpolationMode.BICUBIC),
282 | T.ToTensor(),
283 | T.Normalize(mean=MEAN, std=STD)
284 | ])
285 | else:
286 | transform = T.Compose([
287 | T.Lambda(lambda img: img.convert('RGB') if img.mode != 'RGB' else img),
288 | T.Lambda(lambda img: expand2square(img, tuple(int(x * 255) for x in MEAN))),
289 | T.Resize((input_size, input_size), interpolation=InterpolationMode.BICUBIC),
290 | T.ToTensor(),
291 | T.Normalize(mean=MEAN, std=STD)
292 | ])
293 |
294 | return transform
295 |
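# Usage sketch (illustrative, not part of the original file): the evaluation
# transform converts to RGB, resizes to a square of side `input_size`, and
# normalizes with the statistics chosen by `normalize_type`; the training
# transform additionally applies a random JPEG re-compression (quality 75-100)
# as a light augmentation.
def _demo_build_transform():
    transform = build_transform(is_train=False, input_size=448, normalize_type='hunyuan')
    tensor = transform(Image.new('RGB', (640, 360)))
    assert tuple(tensor.shape) == (3, 448, 448)
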
296 | def generate_tokens(w=448, h=448, use_xrope=True):
297 | total_patch_size = 16 * 2 * 2
298 | tokens = ''
299 | tokens += IMG_START_TOKEN
300 | tokens += IMG_CONTEXT_TOKEN
301 | for i in range(h // total_patch_size):
302 | for j in range(w // total_patch_size):
303 | tokens += IMG_CONTEXT_TOKEN
304 | if use_xrope:
305 | tokens += LINE_TOKEN
306 | else:
307 | tokens += IMG_CONTEXT_TOKEN
308 | tokens += IMG_CONTEXT_TOKEN
309 | tokens += IMG_END_TOKEN
310 | return tokens
311 |
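# Worked example (illustrative, not part of the original file): for the default
# 448x448 frame the effective patch stride is 16 * 2 * 2 = 64, giving a 7x7
# grid. generate_tokens therefore emits IMG_START_TOKEN, 51 IMG_CONTEXT_TOKENs
# (1 leading + 49 grid + 1 trailing), one LINE_TOKEN after each of the 7 rows
# when use_xrope=True, and IMG_END_TOKEN, i.e. 60 placeholder tokens per frame.
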
312 | def get_xdrope_position_ids(
313 | position_ids_t,
314 | position_ids_x,
315 | position_ids_y,
316 | b,
317 | i,
318 | prev_index,
319 | boi_index,
320 | eoi_index,
321 | eol_index,
322 | ):
323 |
324 | position_ids_t[b, (i + 1):] -= (i + 1 - prev_index)
325 | position_ids_x[b, (i + 1):] -= (i + 1 - prev_index)
326 | position_ids_y[b, (i + 1):] -= (i + 1 - prev_index)
327 |
328 | idx_cur = 0
329 | for x in range(boi_index.size()[0]):
330 | m = boi_index[x]
331 | n = eoi_index[x]
332 | assert m < n
333 | # Reset image token position ids.
334 | if m >= prev_index and m < i:
335 | assert n < i
336 | position_ids_t[b, m+1+1:n-1] = idx_cur
337 | idx_cur += 1
338 |
339 | eol_idx_list = []
340 | for y in range(eol_index.size()[0]):
341 | eol_idx = eol_index[y]
342 | # Reset image token position ids.
343 | if eol_idx > m and eol_idx < n:
344 | eol_idx_list.append(eol_idx)
345 | row = len(eol_idx_list)
346 |         assert row > 0, 'the number of rows in an image must be a positive integer'
347 | # -2 is for learnable img start and img end for each image
348 | # -1 is for getting rid of endpoint
349 | column = torch.round((n-m-2-1)/row).long().item()
350 |
351 | assert row * column == n-m-2-1, f"row:\t{row}, column:\t{column}, n:\t{n}, m:\t{m}, {int((n-m-2-1)/row)}"
352 |
353 | idx_xy = 0
354 | for rr in range(row):
355 | for cc in range(column):
356 | position_ids_x[b, m+1+1+idx_xy] = cc
357 | position_ids_y[b, m+1+1+idx_xy] = rr
358 | idx_xy += 1
359 |
360 | return position_ids_t, position_ids_x, position_ids_y
361 |
362 |
363 | def get_attention_masks_and_position_ids(data, eod_id, im_start_id, im_end_id, im_newline_id):
364 | position_embedding_xdrope = True
365 |
366 | micro_batch_size, seq_length = data.size()
367 | att_mask_batch = 1
368 | attention_mask = torch.tril(torch.ones(
369 | (att_mask_batch, seq_length, seq_length))).view(
370 | att_mask_batch, 1, seq_length, seq_length).int()
371 | position_ids = torch.arange(seq_length, dtype=torch.long)
372 | position_ids = position_ids.unsqueeze(0).expand_as(data)
373 | if position_embedding_xdrope:
374 | position_ids_t = position_ids.clone()
375 | position_ids_x = position_ids.clone()
376 | position_ids_y = position_ids.clone()
377 | for b in range(micro_batch_size):
378 |             # Find indices where the EOD token is.
379 | eod_index = position_ids[b, data[b] == eod_id]
380 |
381 | eod_index = eod_index.clone()
382 |             # Detach indices from positions if we are going to modify positions.
383 |             # Loop through EOD indices:
384 | prev_index = 0
385 | if position_embedding_xdrope:
386 | boi_index = position_ids[b, data[b] == im_start_id]
387 | eoi_index = position_ids[b, data[b] == im_end_id]
388 | eol_index = position_ids[b, data[b] == im_newline_id]
389 |
390 | #print(boi_index, eoi_index, eol_index)
391 | for j in range(eod_index.size()[0]):
392 | i = eod_index[j]
393 | attention_mask[b, 0, (i + 1):, :(i + 1)] = 0
394 | position_ids[b, (i + 1):] -= (i + 1 - prev_index)
395 |
396 | if position_embedding_xdrope:
397 | position_ids_t, position_ids_x, position_ids_y = get_xdrope_position_ids(position_ids_t, position_ids_x, position_ids_y,
398 | b, i, prev_index, boi_index,
399 | eoi_index, eol_index
400 | )
401 | prev_index = i + 1
402 | if position_embedding_xdrope:
403 | position_ids = torch.cat([position_ids.unsqueeze(1), position_ids_x.unsqueeze(1), position_ids_y.unsqueeze(1),position_ids_t.unsqueeze(1)], dim=1)
404 |
405 | return attention_mask, position_ids
406 |
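# Illustrative sketch (not part of the original file): a toy sequence with one
# image laid out exactly as generate_tokens() does for a 2x2 grid (im_start,
# leading context, two rows of two context tokens plus a newline, trailing
# context, im_end), followed by a text token and an EOD token. The helper
# returns 4-channel position ids ordered [sequence, x, y, t].
def _demo_xdrope_position_ids():
    data = torch.tensor([[1, 2, 2, 2, 3, 2, 2, 3, 2, 4, 5, 0]])
    _, pos = get_attention_masks_and_position_ids(
        data, eod_id=0, im_start_id=1, im_end_id=4, im_newline_id=3)
    assert pos.shape == (1, 4, 12)
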
407 |
408 | def preprocess_hunyuan(
409 | template_name,
410 | sources,
411 | tokenizer: transformers.PreTrainedTokenizer,
412 | text_only: bool = False,
413 | image_size: int = 448,
414 | num_image: int = 1
415 | ) -> Dict:
416 | assert len(sources) == 1, 'process only the first conversations'
417 | conversations = sources[0]
418 |
419 | if not text_only:
420 | new_conversations = []
421 | current_image_idx = 0
422 | img_tokens_per_frame = generate_tokens(w=image_size, h=image_size)
423 | for conversation in conversations:
424 | if conversation['from'] == 'human':
425 |                 image_cnt = conversation['value'].count('<image>')
426 | for i in range(image_cnt):
427 | if current_image_idx == num_image:
428 | break
429 | if current_image_idx == 0:
430 | if num_image != 1:
431 | image_tokens = f'{VID_START_TOKEN}{img_tokens_per_frame}'
432 | else:
433 | image_tokens = f'{VID_START_TOKEN}{img_tokens_per_frame}{VID_END_TOKEN}'
434 | elif current_image_idx == num_image - 1:
435 | image_tokens = f'{img_tokens_per_frame}{VID_END_TOKEN}'
436 | else:
437 | image_tokens = f'{img_tokens_per_frame}'
438 |                 conversation['value'] = conversation['value'].replace('<image>', image_tokens, 1)
439 | current_image_idx += 1
440 | new_conversations.append(conversation)
441 | conversations = new_conversations
442 | assert current_image_idx == num_image, f'{current_image_idx} != {num_image}'
443 |
444 | batches, roles = [], []
445 |
446 | for conversation in conversations:
447 | if conversation['from'] == 'human':
448 | batches.append(conversation["value"] + '')
449 | roles.append('human')
450 | elif conversation['from'] == 'gpt':
451 | batches.append(f'{conversation["value"]}<|endoftext|>')
452 | roles.append('gpt')
453 | else:
454 | raise NotImplementedError
455 |
456 | final_input_ids, final_targets = [tokenizer.bos_id], [IGNORE_TOKEN_ID]
457 | for role, batch in zip(roles, batches):
458 |
459 | input_ids = tokenizer.encode(
460 | batch,
461 | padding=False,
462 | max_length=tokenizer.model_max_length,
463 | truncation=True,
464 | )
465 | input_ids = np.array(input_ids)
466 |
467 | final_input_ids.extend(input_ids.tolist())
468 |
469 | if role == 'system' or role == 'human':
470 | final_targets.extend(np.full(input_ids.shape, IGNORE_TOKEN_ID).tolist()) # ignore
471 | elif role == 'gpt':
472 | target = input_ids.copy()
473 | final_targets.extend(target.tolist())
474 | else:
475 | raise NotImplementedError
476 |
477 | final_input_ids = np.array(final_input_ids)
478 | final_targets = np.array(final_targets)
479 | input_ids = torch.tensor(final_input_ids)[:tokenizer.model_max_length]
480 | targets = torch.tensor(final_targets)[:tokenizer.model_max_length]
481 |
482 | _, position_ids = get_attention_masks_and_position_ids(input_ids.unsqueeze(0), \
483 | tokenizer.eod_id, tokenizer.im_start_id, tokenizer.im_end_id, tokenizer.im_newline_id)
484 |
485 | input_ids[input_ids == tokenizer.im_newline_id] = tokenizer.image_token_id
486 |
487 | torch.set_printoptions(threshold=float('inf'))
488 |
489 | padding = False
490 | if padding:
491 | current_length = input_ids.size(0)
492 | padding_length = tokenizer.model_max_length - current_length
493 | input_ids = F.pad(input_ids, (0, padding_length), value=tokenizer.pad_id)
494 | targets = F.pad(targets, (0, padding_length), value=IGNORE_TOKEN_ID)
495 |
496 | input_ids = input_ids.unsqueeze(0)
497 | targets = targets.unsqueeze(0)
498 |
499 | return dict(
500 | input_ids=input_ids,
501 | labels=targets,
502 | attention_mask=input_ids.ne(tokenizer.pad_id),
503 | position_ids=position_ids,
504 | )
505 |
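# Illustrative note (not part of the original file): for a single conversation
# with num_image frames, preprocess_hunyuan returns batched tensors of shape
# (1, seq_len): input_ids (BOS, the human turn with each '<image>' placeholder
# expanded to the per-frame layout of generate_tokens, the gpt answer and
# <|endoftext|>), labels with human tokens masked to IGNORE_TOKEN_ID, an
# attention_mask over non-pad positions, and 4-channel xdrope position_ids of
# shape (1, 4, seq_len).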
--------------------------------------------------------------------------------
/model_train/train/dataset_packed.py:
--------------------------------------------------------------------------------
1 | # --------------------------------------------------------
2 | # InternVL
3 | # Copyright (c) 2024 OpenGVLab
4 | # Licensed under The MIT License [see LICENSE for details]
5 | # --------------------------------------------------------
6 | # Modified from InternVL
7 | # Copyright (c) 2025 ARC Lab
8 | # Licensed under LICENSE [see LICENSE for details]
9 | # --------------------------------------------------------
10 |
11 | import bisect
12 | import copy
13 | import logging
14 | from collections import defaultdict
15 | from typing import List, Union
16 |
17 | import numpy as np
18 | import torch
19 | import torch.distributed as dist
20 | from torch.utils.data import IterableDataset, get_worker_info
21 | from transformers.trainer_pt_utils import LabelSmoother
22 |
23 | from .constants import IMG_CONTEXT_TOKEN, IMG_END_TOKEN, IMG_START_TOKEN
24 |
25 | IGNORE_TOKEN_ID = LabelSmoother.ignore_index
26 | logger = logging.getLogger(__name__)
27 | logger.setLevel(logging.INFO)
28 |
29 |
30 | def is_dist_avail_and_initialized():
31 | if not dist.is_available():
32 | return False
33 | if not dist.is_initialized():
34 | return False
35 | return True
36 |
37 |
38 | def get_world_size():
39 | if not is_dist_avail_and_initialized():
40 | return 1
41 | return dist.get_world_size()
42 |
43 |
44 | def get_rank():
45 | if not is_dist_avail_and_initialized():
46 | return 0
47 | return dist.get_rank()
48 |
49 |
50 | class PackedDataset(IterableDataset):
51 | def __init__(
52 | self,
53 | tokenizer,
54 | data_rank,
55 | data_world_size,
56 | datasets: List,
57 | dataset_weight: List[int] = None,
58 | num_images_expected: int = 6,
59 | max_packed_tokens: int = 32768,
60 | max_buffer_size: int = 100,
61 | log_freq: int = 1000000,
62 | strict_mode: bool = False,
63 | debug_mode: bool = False,
64 | replacement: bool = True,
65 | allow_overflow: bool = True,
66 | allow_empty_data: bool = False,
67 | allow_deduplicated_ds_name: bool = False,
68 | ):
69 | super().__init__()
70 | self.tokenizer = tokenizer
71 | self.data_rank = data_rank
72 | self.data_world_size = data_world_size
73 | self.datasets = datasets
74 | self.num_images_expected = num_images_expected
75 | self.max_buffer_size = max_buffer_size
76 | self.log_freq = log_freq
77 | self.strict_mode = strict_mode
78 | self.debug_mode = debug_mode
79 | self.replacement = replacement
80 | self.allow_overflow = allow_overflow
81 | self.allow_empty_data = allow_empty_data
82 |
83 | self.max_packed_tokens = max_packed_tokens
84 |
85 | self.img_start_token_id = self.tokenizer.convert_tokens_to_ids(IMG_START_TOKEN)
86 | self.img_token_id = self.tokenizer.convert_tokens_to_ids(IMG_CONTEXT_TOKEN)
87 | self.img_end_token_id = self.tokenizer.convert_tokens_to_ids(IMG_END_TOKEN)
88 |
89 | assert self.img_start_token_id != self.tokenizer.unk_token_id
90 | assert self.img_token_id != self.tokenizer.unk_token_id
91 | assert self.img_end_token_id != self.tokenizer.unk_token_id
92 |
93 | if dataset_weight is None:
94 | dataset_weight = [1] * len(datasets)
95 | self.dataset_type = [d.dataset_type for d in self.datasets]
96 |
97 | self.datasets_orig = datasets
98 | self.dataset_weight_orig = [w / sum(dataset_weight) for w in dataset_weight]
99 |
100 | self.datasets = [ds for ds in self.datasets_orig]
101 | self.dataset_weight = [w for w in self.dataset_weight_orig]
102 |
103 | # lazy init
104 | self.worker_id = None
105 | self.worker_state_key = None
106 | self.dataset_iter_list = None
107 | self._state_dict = {
108 | 'sample_info': {d.ds_name:0 for d in self.datasets},
109 | }
110 |
111 | self.worker_custom_infos = None
112 |
113 | ds_name_list = [d.ds_name for d in self.datasets]
114 | if not allow_deduplicated_ds_name:
115 |             assert len(ds_name_list) == len(set(ds_name_list)), f'duplicated ds_name: {ds_name_list}'
116 |
117 | for ds in self.datasets:
118 | if ds.max_num_images > self.num_images_expected:
119 | logger.warning(f'{ds.max_num_images=} of {ds.ds_name} is larger than {self.num_images_expected=}')
120 | ds.max_num_images = num_images_expected
121 |
122 | if ds.max_tokens > self.max_packed_tokens:
123 | logger.warning(f'{ds.max_tokens=} of {ds.ds_name} is larger than {self.max_packed_tokens=}')
124 | ds.max_tokens = self.max_packed_tokens
125 |
126 | self._state_dict[ds.ds_name] = {}
127 |
128 | if get_rank() == 0:
129 | logger.info(
130 | f'Loaded dataset to pack: {ds_name_list}, '
131 | f'{self.num_images_expected=}, {self.max_packed_tokens=}, '
132 | f'{self.replacement=}, {self.allow_overflow=}',
133 | )
134 |
135 | temp = []
136 | for ds, ds_w in zip(self.datasets, self.dataset_weight):
137 | temp.append(f'{ds.ds_name:<25}: {ds_w*100:.2f}%')
138 | temp = '\n'.join(temp)
139 | logger.info(
140 | f'Sampling prob for each dataset:\n{temp}'
141 | )
142 |
143 | if self.allow_empty_data:
144 | logger.warning('allow_empty_data is enabled, note that empty data may be generated!')
145 |
146 | def load_state_dict(self, state_dict, custom_infos=None):
147 |
148 | self.worker_custom_infos = custom_infos
149 |
150 | self._state_dict.update(state_dict)
151 | for ds in self.datasets:
152 | if ds.ds_name in self._state_dict:
153 | ds.load_state_dict(self._state_dict[ds.ds_name])
154 | logger.info(f'{ds.ds_name=} is resumed.')
155 | else:
156 | logger.warning(f'{ds.ds_name=} is not resumed.')
157 |
158 | def _should_log(self):
159 | worker_id = 0 if get_worker_info() is None else get_worker_info().id
160 | num_workers = 1 if get_worker_info() is None else get_worker_info().num_workers
161 |
162 | worker_id = num_workers * get_rank() + worker_id
163 | num_workers = num_workers * get_world_size()
164 |
165 | return worker_id == 0
166 |
167 | def next_data(self, current_dataset_idx):
168 | while True:
169 | try:
170 | current_sample = next(self.dataset_iter_list[current_dataset_idx])
171 | break # Exit loop if successful
172 | except StopIteration:
173 | if self.replacement:
174 | # logger.info(f'[Worker id {self.worker_id}] Dataset {self.datasets[current_dataset_idx].ds_name} is exhausted, restart it.')
175 | try:
176 | self.dataset_iter_list[current_dataset_idx] = iter(self.datasets[current_dataset_idx])
177 | current_sample = next(self.dataset_iter_list[current_dataset_idx])
178 | break
179 | except:
180 | # logger.error(f'{self.worker_id=} Fail to get any data from {self.datasets[current_dataset_idx].ds_name}! length={len(self.datasets)}')
181 | self.datasets.pop(current_dataset_idx)
182 | self.dataset_iter_list.pop(current_dataset_idx)
183 | self.dataset_weight.pop(current_dataset_idx)
184 |
185 | if len(self.datasets) == 0:
186 | raise StopIteration
187 | current_dataset_idx = np.random.choice(len(self.datasets))
188 | else:
189 | # logger.error(f'{self.worker_id=} Fail to get any data from {self.datasets[current_dataset_idx].ds_name}! length={len(self.datasets)}')
190 | self.datasets.pop(current_dataset_idx)
191 | self.dataset_iter_list.pop(current_dataset_idx)
192 | self.dataset_weight.pop(current_dataset_idx)
193 |
194 | if len(self.datasets) == 0:
195 | raise StopIteration
196 | current_dataset_idx = np.random.choice(len(self.datasets))
197 | except:
198 | logger.error('Unexpected error!')
199 | if len(self.datasets) == 0:
200 | raise StopIteration
201 | current_dataset_idx = np.random.choice(len(self.datasets))
202 |
203 | current_ds_name = self.datasets[current_dataset_idx].ds_name
204 | current_sample['type_ids'] = torch.zeros_like(current_sample['input_ids']) + current_dataset_idx
205 |
206 | if self.worker_state_key not in self._state_dict[current_ds_name]:
207 | self._state_dict[current_ds_name][self.worker_state_key] = {}
208 |
209 | meta_info = current_sample.pop('meta_info', {})
210 | self._state_dict[current_ds_name][self.worker_state_key].update(**meta_info)
211 | self._state_dict['sample_info'][self.datasets[current_dataset_idx].ds_name] += 1
212 | return current_sample
213 |
214 | def find_buffer(self, buffer_list, new_sample):
215 | # NOTE: use `bisect` to search might be faster
216 |
217 | find = False
218 | find_idx = -1
219 | num_images_current = new_sample['pixel_values'].size(0)
220 | for buffer_idx, buffer in enumerate(buffer_list):
221 | num_images_buffer = buffer['pixel_values'].size(0)
222 | if num_images_buffer + num_images_current <= self.num_images_expected:
223 | num_merged_tokens = new_sample['input_ids'].size(0) + buffer['input_ids'].size(0)
224 |
225 | if num_merged_tokens <= self.max_packed_tokens:
226 | find = True
227 | find_idx = buffer_idx
228 | break
229 |
230 | if self.allow_overflow and len(buffer_list) >= self.max_buffer_size // 2:
231 | find = True
232 | find_idx = buffer_idx
233 |
234 | if find:
235 | return buffer_list.pop(find_idx)
236 | return None
237 |
238 | def update_buffer(self, buffer, new_sample):
239 | if buffer is None:
240 | new_sample['data_index'] = torch.zeros_like(new_sample['input_ids'])
241 | return new_sample
242 |
243 | new_sample['data_index'] = torch.ones_like(new_sample['input_ids']) + buffer['data_index'][-1].item()
244 |
245 | assert buffer.keys() == new_sample.keys()
246 | for k in buffer:
247 | buffer[k] = torch.cat([buffer[k], new_sample[k]])
248 | return buffer
249 |
250 | @staticmethod
251 | def check_valid(sample_to_check, min_active_tokens_ratio=1/256):
252 | num_ignore_tokens = (sample_to_check['labels'] == IGNORE_TOKEN_ID).sum()
253 | num_tokens = sample_to_check['labels'].numel()
254 | return (1 - num_ignore_tokens / num_tokens) > min_active_tokens_ratio
255 |
256 | @staticmethod
257 | def split_buffer(buffer, max_tokens, img_start_token_id, img_token_id, img_end_token_id):
258 | if buffer['input_ids'].size(0) <= max_tokens:
259 | return [buffer]
260 |
261 | def _image_is_splitted(input_ids, cut_idx):
262 | is_image_start = input_ids[cut_idx].item() == img_start_token_id
263 | is_image_token = input_ids[cut_idx].item() == img_token_id
264 | is_image_end = input_ids[cut_idx].item() == img_end_token_id
265 | return is_image_start or is_image_token or is_image_end
266 |
267 | def _split(sample_to_split, left_idx, right_idx, left_img_idx, right_img_idx):
268 | assert (right_idx is None) == (right_img_idx is None)
269 |
270 | left_sample = {}
271 | right_sample = {} if right_idx is not None else None
272 | for k in sample_to_split:
273 | if k in ['input_ids', 'labels', 'attention_mask', 'position_ids', 'data_index', 'type_ids']:
274 | left_sample[k] = sample_to_split[k][:left_idx]
275 | if right_sample is not None:
276 | right_sample[k] = sample_to_split[k][right_idx:]
277 | elif k in ['pixel_values', 'image_flags']:
278 | left_sample[k] = sample_to_split[k][:left_img_idx]
279 | if right_sample is not None:
280 | right_sample[k] = sample_to_split[k][right_img_idx:]
281 | else:
282 | raise NotImplementedError(f'find unsupported keys: {k} from {sample_to_split.keys()}')
283 | return left_sample, right_sample
284 |
285 | splitted_buffer = []
286 | while buffer['input_ids'].size(0) > max_tokens:
287 | img_start_idx_list = (buffer['input_ids'] == img_start_token_id).nonzero().squeeze(1).tolist()
288 | img_end_idx_list = (buffer['input_ids'] == img_end_token_id).nonzero().squeeze(1).tolist()
289 | assert len(img_start_idx_list) == len(img_end_idx_list)
290 |
291 | if _image_is_splitted(buffer['input_ids'], max_tokens):
292 | cut_idx = bisect.bisect_left(img_start_idx_list, max_tokens)
293 | if buffer['input_ids'][max_tokens] == img_start_token_id:
294 | assert max_tokens == img_start_idx_list[cut_idx]
295 | cut_left_idx = img_start_idx_list[cut_idx]
296 | cut_left_img_idx = cut_idx
297 | else:
298 | cut_left_idx = img_start_idx_list[cut_idx - 1]
299 | cut_left_img_idx = cut_idx - 1
300 | cut_right_idx = cut_left_idx
301 | cut_right_img_idx = cut_left_img_idx
302 | else:
303 | cut_img_idx = bisect.bisect(img_start_idx_list, max_tokens)
304 | if cut_img_idx < len(img_start_idx_list):
305 | cut_right_idx = img_start_idx_list[cut_img_idx]
306 | cut_right_img_idx = cut_img_idx
307 | else:
308 | cut_right_idx = None
309 | cut_right_img_idx = None
310 |
311 | cut_left_idx = max_tokens
312 | cut_left_img_idx = cut_right_img_idx if cut_right_img_idx is not None else buffer['pixel_values'].size(0)
313 |
314 | left, right = _split(
315 | sample_to_split=buffer,
316 | left_idx=cut_left_idx,
317 | left_img_idx=cut_left_img_idx,
318 | right_idx=cut_right_idx,
319 | right_img_idx=cut_right_img_idx,
320 | )
321 |
322 | assert (left['input_ids'] == img_end_token_id).sum() == (left['input_ids'] == img_start_token_id).sum() == left['pixel_values'].size(0)
323 | if right is not None:
324 | assert (right['input_ids'] == img_end_token_id).sum() == (right['input_ids'] == img_start_token_id).sum() == right['pixel_values'].size(0)
325 |
326 | if left['pixel_values'].size(0) >= 1 and PackedDataset.check_valid(left):
327 | splitted_buffer.append(left)
328 |
329 | if right is None or right['pixel_values'].size(0) == 0:
330 | break
331 |
332 | buffer = right
333 | if buffer['input_ids'].size(0) <= max_tokens and PackedDataset.check_valid(buffer):
334 | splitted_buffer.append(buffer)
335 | break
336 |
337 | logger.debug(
338 | f'split a sample into {len(splitted_buffer)} samples, '
339 | f'current max_tokens={max_tokens}'
340 | )
341 | return splitted_buffer
342 |
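    # Illustrative note (not part of the original file): split_buffer cuts an
    # over-long packed buffer at image boundaries. If the max_tokens cut point
    # lands inside an image's token span, the cut is moved back to that image's
    # IMG_START so the whole image goes to the right-hand piece, and pieces
    # whose labels are almost entirely IGNORE_TOKEN_ID are discarded by
    # check_valid.
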
343 | def update_buffer_list(self, buffer_list, buffer_max_len_list, buffer):
344 | # NOTE: in-place operation
345 |
346 | splitted_buffer = PackedDataset.split_buffer(
347 | buffer=buffer,
348 | max_tokens=self.max_packed_tokens,
349 | img_start_token_id=self.img_start_token_id,
350 | img_token_id=self.img_token_id,
351 | img_end_token_id=self.img_end_token_id,
352 | )
353 |
354 | for each_buffer in splitted_buffer:
355 | if each_buffer['pixel_values'].size(0) > self.num_images_expected:
356 | logger.error(
357 | f"Find a sample with {each_buffer['pixel_values'].size(0)} images, "
358 | f'which exceeds {self.num_images_expected}'
359 | )
360 | continue
361 |
362 | if each_buffer['input_ids'].size(0) >= self.max_packed_tokens:
363 | assert each_buffer['input_ids'].size(0) == self.max_packed_tokens
364 | buffer_max_len_list.append(each_buffer)
365 | continue
366 |
367 | find_idx = len(buffer_list)
368 | num_images_new_sample = each_buffer['pixel_values'].size(0)
369 | for buffer_idx in range(len(buffer_list)):
370 | if buffer_list[buffer_idx]['pixel_values'].size(0) < num_images_new_sample:
371 | find_idx = buffer_idx
372 | break
373 | buffer_list.insert(find_idx, each_buffer)
374 |
375 | for i in range(1, len(buffer_list)):
376 | assert buffer_list[i-1]['pixel_values'].size(0) >= buffer_list[i]['pixel_values'].size(0)
377 |
378 | return buffer_list, buffer_max_len_list
379 |
380 | def pad_buffer(self, buffer):
381 | if buffer['pixel_values'].size(0) == self.num_images_expected:
382 | return buffer
383 |
384 | num_pad_images = self.num_images_expected - buffer['pixel_values'].size(0)
385 | pad_images = torch.stack([
386 | torch.zeros_like(buffer['pixel_values'][0])
387 | for _ in range(num_pad_images)
388 | ])
389 | pad_image_flags = torch.tensor([0] * num_pad_images, dtype=torch.long)
390 |
391 | buffer['pixel_values'] = torch.cat([buffer['pixel_values'], pad_images])
392 | buffer['image_flags'] = torch.cat([buffer['image_flags'], pad_image_flags])
393 |
394 | return buffer
395 |
396 | def postprocess_buffer(self, buffer, custom_infos=None):
397 | buffer['worker_state_key'] = self.worker_state_key
398 | buffer['worker_state_dict'] = self._state_dict
399 | if custom_infos is not None:
400 | buffer['custom_infos'] = {self.worker_state_key: copy.deepcopy(custom_infos)}
401 | return buffer
402 |
403 | def print_log(self, iter_idx, buffer_list):
404 | if iter_idx % self.log_freq != 0:
405 | return
406 |
407 | if self._should_log():
408 | logger.info(
409 | f"{iter_idx=}, {len(buffer_list)=}, {self._state_dict['sample_info']}"
410 | )
411 |
412 | def __iter__(self):
413 | iter_idx = 0
414 | buffer_list = []
415 | buffer_max_len_list = []
416 |
417 | if self._should_log():
418 | logger.info(f'Begin to iter, {len(buffer_list)=}')
419 |
420 | worker_id = 0 if get_worker_info() is None else get_worker_info().id
421 | num_workers = 1 if get_worker_info() is None else get_worker_info().num_workers
422 |
423 | worker_id = num_workers * self.data_rank + worker_id
424 | num_workers = num_workers * self.data_world_size
425 |
426 | rng = np.random.default_rng(seed=worker_id)
427 |
428 | # reset states of each dataset
429 | self.worker_id = worker_id
430 | self.worker_state_key = f'work_state_{self.worker_id}'
431 | self.datasets = [d for d in self.datasets_orig]
432 | self.dataset_weight = [w for w in self.dataset_weight_orig]
433 | self.dataset_iter_list = [iter(d) for d in self.datasets]
434 |
435 | for ds in self.datasets:
436 | # if not isinstance(ds, (ImageTextPairDataset, InterleavedDataset)):
437 | ds.worker_id = worker_id
438 | ds.worker_state_key = f'work_state_{self.worker_id}'
439 | ds.num_workers = num_workers
440 | if self._should_log() and worker_id == 0:
441 | logger.info(f'set worker_id and num_workers of {ds.__class__.__name__} {ds.ds_name}')
442 |
443 | if self.worker_custom_infos is not None and self.worker_state_key in self.worker_custom_infos:
444 | custom_infos = self.worker_custom_infos[self.worker_state_key]
445 | # buffer list
446 | if 'buffer_list' in custom_infos and isinstance(custom_infos['buffer_list'], list):
447 | buffer_list = custom_infos['buffer_list']
448 | if self._should_log() and worker_id == 0:
449 | logger.info(f'[{self.worker_state_key}] load buffer list --> {len(buffer_list)=}')
450 | # other infos
451 |
452 | # reset
453 | self.worker_custom_infos = None
454 |
455 | logger.debug(
456 | f'{self.__class__.__name__} Rank {self.data_rank} '
457 | f'Worker {worker_id} begin to load data'
458 | )
459 |
460 | while True:
461 | self.dataset_weight = [w / sum(self.dataset_weight) for w in self.dataset_weight]
462 | current_dataset_idx = rng.choice(len(self.dataset_iter_list), p=self.dataset_weight)
463 |
464 | try:
465 | current_sample = self.next_data(current_dataset_idx)
466 | except:
467 | logger.info(f'All datasets are exhausted, begin to empty the buffer_list ({len(buffer_list)=})')
468 | while len(buffer_list) > 0:
469 | if self.strict_mode:
470 | yield self.postprocess_buffer(self.pad_buffer(buffer_list.pop(0)))
471 | else:
472 | yield self.postprocess_buffer(buffer_list.pop(0))
473 | logger.info(f'buffer_list is empty! ({len(buffer_list)=})')
474 | return
475 |
476 | buffer = self.find_buffer(buffer_list, current_sample)
477 | buffer = self.update_buffer(buffer, current_sample)
478 | buffer_list, buffer_max_len_list = self.update_buffer_list(buffer_list, buffer_max_len_list, buffer)
479 |
480 | while len(buffer_max_len_list) > 0:
481 | if buffer_max_len_list[0]['pixel_values'].size(0) != self.max_packed_tokens:
482 | logger.debug(
483 | f'num tokens of a buffer exceed {self.max_packed_tokens=}, '
484 | f"yield a sample with {buffer_max_len_list[0]['pixel_values'].size(0)} images"
485 | )
486 | if self.strict_mode and buffer_max_len_list[0]['pixel_values'].size(0) != self.num_images_expected:
487 | # buffer_max_len_list.pop(0)
488 | yield self.postprocess_buffer(self.pad_buffer(buffer_max_len_list.pop(0)), {'buffer_list': buffer_list})
489 | else:
490 | yield self.postprocess_buffer(buffer_max_len_list.pop(0), {'buffer_list': buffer_list})
491 |
492 | while len(buffer_list) > 0 and buffer_list[0]['pixel_values'].size(0) > self.num_images_expected:
493 | logger.error(
494 | f"num images of a buffer ({buffer_list[0]['pixel_values'].size(0)}) "
495 | f'is larger than num_images_expected({self.num_images_expected})'
496 | )
497 | buffer_list.pop(0)
498 |
499 | while len(buffer_list) > 0 and buffer_list[0]['pixel_values'].size(0) == self.num_images_expected:
500 | if self.debug_mode:
501 | debug_data = self.postprocess_buffer(buffer_list.pop(0), {'buffer_list': buffer_list})
502 | while True:
503 | yield debug_data.copy()
504 |
505 | yield self.postprocess_buffer(buffer_list.pop(0), {'buffer_list': buffer_list})
506 |
507 | while len(buffer_list) > self.max_buffer_size:
508 | logger.debug(
509 | f'Failed to pack data to exactly {self.num_images_expected} images, '
510 | f"yield a data sample with {buffer_list[0]['pixel_values'].size(0)} images."
511 | )
512 | if self.strict_mode:
513 | yield self.postprocess_buffer(self.pad_buffer(buffer_list.pop(0)), {'buffer_list': buffer_list})
514 | else:
515 | yield self.postprocess_buffer(buffer_list.pop(0), {'buffer_list': buffer_list})
516 |
517 | self.print_log(iter_idx=iter_idx, buffer_list=buffer_list)
518 | iter_idx += 1
519 |
520 | @staticmethod
521 | def get_cu_seqlens_and_indexes(
522 | data_index: torch.LongTensor, # (seq_len,)
523 | input_ids: torch.LongTensor, # (seq_len,)
524 | labels: torch.LongTensor, # (seq_len,)
525 | len2weight: callable,
526 | ):
527 | indexes = []
528 | cu_seqlens = [0]
529 | loss_weight = []
530 |
531 | start = data_index.min()
532 | end = data_index.max() + 1
533 | for i in range(start, end):
534 | num_tokens = (data_index == i).sum().item()
535 | indexes.extend(list(range(num_tokens)))
536 | cu_seqlens.append(cu_seqlens[-1] + num_tokens)
537 | assert num_tokens > 0
538 |
539 | curr_data_index = data_index[cu_seqlens[-2]:cu_seqlens[-2]+num_tokens]
540 | assert (curr_data_index == i).all(), data_index
541 |
542 | curr_labels = labels[cu_seqlens[-2]:cu_seqlens[-2]+num_tokens]
543 | num_effective_tokens = (curr_labels != IGNORE_TOKEN_ID).sum().item()
544 | loss_weight.extend([len2weight(num_effective_tokens)] * num_tokens)
545 |
546 | assert len(indexes) == data_index.size(0), f'{len(indexes)=}, {data_index.size(0)=}'
547 |
548 | loss_weight = torch.tensor(loss_weight, dtype=torch.float32)
549 | return cu_seqlens, indexes, loss_weight
550 |
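# Illustrative sketch (not part of the original file): two packed samples of
# length 3 and 2 produce cu_seqlens [0, 3, 5], per-sample restarting indexes
# [0, 1, 2, 0, 1], and one loss weight per token derived (via len2weight) from
# the number of non-ignored labels in that token's sample.
def _demo_cu_seqlens_and_indexes():
    data_index = torch.tensor([0, 0, 0, 1, 1])
    labels = torch.tensor([IGNORE_TOKEN_ID, 7, 7, 7, IGNORE_TOKEN_ID])
    cu_seqlens, indexes, loss_weight = PackedDataset.get_cu_seqlens_and_indexes(
        data_index=data_index,
        input_ids=torch.zeros(5, dtype=torch.long),
        labels=labels,
        len2weight=lambda num_effective_tokens: 1.0,
    )
    assert cu_seqlens == [0, 3, 5] and indexes == [0, 1, 2, 0, 1]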
551 |
552 | WARNING_CNT = defaultdict(int)
553 |
554 |
555 | def packed_collate_fn(
556 | features,
557 | data_collator,
558 | len2weight: callable,
559 | max_item_length: int,
560 | micro_num: int = 1,
561 | loss_reduction_all_gather: bool = False,
562 | pad_id: int = 0,
563 | ):
564 | if not isinstance(features, list):
565 | features = [features]
566 |
567 | if len(features) > micro_num:
568 | raise NotImplementedError(f'{len(features)=} > {micro_num=}')
569 |
570 | if len(features) < micro_num and WARNING_CNT['micro_num_warning'] < 5:
571 | logger.warning(
572 |             f'{len(features)=} < {micro_num=}, '
573 | f'the features will be padded to satisfy micro_num requirement'
574 | )
575 | WARNING_CNT['micro_num_warning'] += 1
576 |
577 | # ensure that the len(features) is equal to the required micro_num
578 | num_features = len(features)
579 | while len(features) < micro_num:
580 | features.append(copy.deepcopy(features[0]))
581 | features[-1]['labels'] = torch.full_like(features[-1]['labels'], IGNORE_TOKEN_ID)
582 |
583 | indexes = []
584 | cu_seqlens = []
585 | cu_num_images_list = [0]
586 |
587 | worker_state_key_list = []
588 | worker_state_dict_list = []
589 | worker_state_custom_infos_list = []
590 |
591 | batch_lens = [feat['input_ids'].shape for feat in features]
592 | max_item_length = max_item_length or max(batch_lens)[0]
593 |
594 | num_samples = 0
595 | num_padding_tokens = 0
596 | for feat_idx, feat in enumerate(features):
597 | data_index = feat.pop('data_index')
598 | curr_cu_seqlens, curr_indexes, curr_loss_weight = PackedDataset.get_cu_seqlens_and_indexes(
599 | data_index=data_index,
600 | input_ids=feat['input_ids'],
601 | labels=feat['labels'],
602 | len2weight=len2weight,
603 | )
604 |
605 | feat['loss_weight'] = curr_loss_weight
606 |
607 | if feat_idx < num_features:
608 | num_samples += len(curr_cu_seqlens) - 1
609 |
610 | if curr_cu_seqlens[-1] < max_item_length:
611 | curr_cu_seqlens.append(max_item_length)
612 | curr_indexes.extend(list(range(max_item_length - curr_cu_seqlens[-2])))
613 |
614 | indexes.append(torch.tensor(curr_indexes, dtype=torch.long))
615 | cu_seqlens.append(torch.tensor(curr_cu_seqlens, dtype=torch.int32))
616 |
617 | worker_state_key_list.append(feat.pop('worker_state_key'))
618 | worker_state_dict_list.append(feat.pop('worker_state_dict'))
619 | worker_state_custom_infos_list.append(feat.pop('custom_infos', None))
620 |
621 | num_padding_tokens += (max_item_length - feat['input_ids'].size(0))
622 | cu_num_images_list.append(cu_num_images_list[-1] + feat['pixel_values'].size(0))
623 |
624 | batch = data_collator(features=features, max_item_length=max_item_length, pad_id=pad_id)
625 |     # keep loss_weight as a Python list so it is not cast to bf16 along with the tensor batch
626 | batch['loss_weight'] = torch.where(batch['labels'] == IGNORE_TOKEN_ID, 0, batch['loss_weight']).tolist()
627 | batch['attention_mask'] = torch.stack(cu_seqlens)
628 | batch['loss_reduction_all_gather'] = loss_reduction_all_gather
629 | batch['statistics'] = torch.tensor(
630 | [
631 | num_samples,
632 | num_padding_tokens,
633 | batch['image_flags'].numel() - batch['image_flags'].sum().item(),
634 | ],
635 | dtype=torch.long,
636 | )
637 | batch.pop('type_ids')
638 | return batch
639 |
--------------------------------------------------------------------------------
/model_vllm/__init__.py:
--------------------------------------------------------------------------------
1 | from vllm import ModelRegistry
2 | from .hunyuan import HunYuanForCausalLM
3 | from .hunyuan_video import HunyuanVideoModel
4 | from .video_audio_encoder import VideoAudioEncoder
5 | from .video_audio_llm import VideoAudioLLM
6 |
7 | ModelRegistry.register_model("HunYuanForCausalLM", HunYuanForCausalLM)
8 | ModelRegistry.register_model("HunyuanVideoModel", HunyuanVideoModel)
9 |
--------------------------------------------------------------------------------
/model_vllm/__pycache__/__init__.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TencentARC/ARC-Hunyuan-Video-7B/a061f1bf6a4d75f032759fec69c80c71077c3d0c/model_vllm/__pycache__/__init__.cpython-310.pyc
--------------------------------------------------------------------------------
/model_vllm/__pycache__/__init__.cpython-311.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TencentARC/ARC-Hunyuan-Video-7B/a061f1bf6a4d75f032759fec69c80c71077c3d0c/model_vllm/__pycache__/__init__.cpython-311.pyc
--------------------------------------------------------------------------------
/model_vllm/__pycache__/hunyuan.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TencentARC/ARC-Hunyuan-Video-7B/a061f1bf6a4d75f032759fec69c80c71077c3d0c/model_vllm/__pycache__/hunyuan.cpython-310.pyc
--------------------------------------------------------------------------------
/model_vllm/__pycache__/hunyuan.cpython-311.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TencentARC/ARC-Hunyuan-Video-7B/a061f1bf6a4d75f032759fec69c80c71077c3d0c/model_vllm/__pycache__/hunyuan.cpython-311.pyc
--------------------------------------------------------------------------------
/model_vllm/__pycache__/hunyuan_video.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TencentARC/ARC-Hunyuan-Video-7B/a061f1bf6a4d75f032759fec69c80c71077c3d0c/model_vllm/__pycache__/hunyuan_video.cpython-310.pyc
--------------------------------------------------------------------------------
/model_vllm/__pycache__/hunyuan_video.cpython-311.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TencentARC/ARC-Hunyuan-Video-7B/a061f1bf6a4d75f032759fec69c80c71077c3d0c/model_vllm/__pycache__/hunyuan_video.cpython-311.pyc
--------------------------------------------------------------------------------
/model_vllm/__pycache__/monkey_patch_mrope.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TencentARC/ARC-Hunyuan-Video-7B/a061f1bf6a4d75f032759fec69c80c71077c3d0c/model_vllm/__pycache__/monkey_patch_mrope.cpython-310.pyc
--------------------------------------------------------------------------------
/model_vllm/__pycache__/monkey_patch_mrope.cpython-311.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TencentARC/ARC-Hunyuan-Video-7B/a061f1bf6a4d75f032759fec69c80c71077c3d0c/model_vllm/__pycache__/monkey_patch_mrope.cpython-311.pyc
--------------------------------------------------------------------------------
/model_vllm/__pycache__/video_audio_encoder.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TencentARC/ARC-Hunyuan-Video-7B/a061f1bf6a4d75f032759fec69c80c71077c3d0c/model_vllm/__pycache__/video_audio_encoder.cpython-310.pyc
--------------------------------------------------------------------------------
/model_vllm/__pycache__/video_audio_encoder.cpython-311.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TencentARC/ARC-Hunyuan-Video-7B/a061f1bf6a4d75f032759fec69c80c71077c3d0c/model_vllm/__pycache__/video_audio_encoder.cpython-311.pyc
--------------------------------------------------------------------------------
/model_vllm/__pycache__/video_audio_llm.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TencentARC/ARC-Hunyuan-Video-7B/a061f1bf6a4d75f032759fec69c80c71077c3d0c/model_vllm/__pycache__/video_audio_llm.cpython-310.pyc
--------------------------------------------------------------------------------
/model_vllm/__pycache__/video_audio_llm.cpython-311.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TencentARC/ARC-Hunyuan-Video-7B/a061f1bf6a4d75f032759fec69c80c71077c3d0c/model_vllm/__pycache__/video_audio_llm.cpython-311.pyc
--------------------------------------------------------------------------------
/model_vllm/hunyuan.py:
--------------------------------------------------------------------------------
1 | from functools import partial
2 | from typing import Any, Dict, Iterable, Optional, Set, Tuple, Type, Union, List
3 |
4 | import torch
5 | from torch import nn
6 | from transformers import PretrainedConfig
7 |
8 | from vllm.attention.layer import Attention
9 | from vllm.compilation.decorators import support_torch_compile
10 | from vllm.config import CacheConfig, VllmConfig
11 | from vllm.distributed import (
12 | get_pp_group,
13 | get_tensor_model_parallel_rank,
14 | get_tensor_model_parallel_world_size,
15 | split_tensor_along_last_dim,
16 | tensor_model_parallel_all_gather,
17 | )
18 | from vllm.model_executor.layers.activation import SiluAndMul
19 | from vllm.model_executor.layers.layernorm import RMSNorm
20 | from vllm.model_executor.layers.linear import (
21 | MergedColumnParallelLinear,
22 | QKVParallelLinear,
23 | RowParallelLinear,
24 | ReplicatedLinear,
25 | )
26 | from vllm.model_executor.layers.logits_processor import LogitsProcessor
27 | from vllm.model_executor.layers.pooler import Pooler, PoolingType
28 | from vllm.model_executor.layers.quantization import QuantizationConfig
29 | from vllm.model_executor.layers.rotary_embedding import (
30 | get_rope,
31 | RotaryEmbedding,
32 | MRotaryEmbedding,
33 | _apply_rotary_emb,
34 | )
35 | from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
36 | from vllm.model_executor.layers.vocab_parallel_embedding import (
37 | ParallelLMHead,
38 | VocabParallelEmbedding,
39 | )
40 | from vllm.model_executor.model_loader.weight_utils import default_weight_loader
41 |
42 | from vllm.model_executor.pooling_metadata import PoolingMetadata
43 | from vllm.model_executor.sampling_metadata import SamplingMetadata
44 | from vllm.sequence import IntermediateTensors, PoolerOutput
45 |
46 | from vllm.model_executor.models.interfaces import SupportsLoRA, SupportsPP
47 | from vllm.model_executor.models.utils import (
48 | is_pp_missing_parameter,
49 | make_empty_intermediate_tensors_factory,
50 | make_layers,
51 | maybe_prefix,
52 | AutoWeightsLoader,
53 | )
54 |
55 |
56 | class DynamicNTKAlphaRotaryEmbedding(RotaryEmbedding):
57 |
58 | def __init__(
59 | self,
60 | head_size: int,
61 | rotary_dim: int,
62 | max_position_embeddings: int,
63 | base: int,
64 | scaling_alpha: float,
65 | dtype: torch.dtype,
66 | is_neox_style: bool = True,
67 | ) -> None:
68 | self.scaling_alpha = scaling_alpha
69 | super().__init__(
70 | head_size,
71 | rotary_dim,
72 | max_position_embeddings,
73 | base,
74 | is_neox_style,
75 | dtype,
76 | )
77 |
78 | def _compute_cos_sin_cache(self) -> torch.Tensor:
79 | # NOTE(woosuk): self.max_position_embeddings is the original
80 | # maximum length before applying the rope scaling.
81 | # Thus, the maximum length after applying the rope scaling is
82 | # self.max_position_embeddings * self.scaling_alpha.
83 | max_len = self.max_position_embeddings * self.scaling_alpha
84 | base = self.base * self.scaling_alpha ** (
85 | self.rotary_dim / (self.rotary_dim - 2)
86 | )
87 | inv_freq = 1.0 / (
88 | base
89 | ** (torch.arange(0, self.rotary_dim, 2).float() / self.rotary_dim)
90 | )
91 | t = torch.arange(max_len, dtype=torch.float)
92 | freqs = torch.einsum("i,j -> ij", t, inv_freq)
93 | cos = freqs.cos()
94 | sin = freqs.sin()
95 | cache = torch.cat((cos, sin), dim=-1)
96 | return cache
97 |
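# Worked example (illustrative, not from the original file; the alpha value is
# assumed): with rotary_dim=128, base=10000 and scaling_alpha=1000, the
# effective base becomes 10000 * 1000 ** (128 / 126), roughly 1.1e7, which
# stretches the rotary frequencies so that positions up to
# max_position_embeddings * 1000 stay distinguishable without retraining, the
# usual dynamic-NTK trick.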
98 |
99 | def rotate_half(x):
100 | """Rotates half the hidden dims of the input."""
101 | x1 = x[..., : x.shape[-1] // 2]
102 | x2 = x[..., x.shape[-1] // 2:]
103 | return torch.cat((-x2, x1), dim=-1)
104 |
105 |
106 | class DynamicNTKAlphaMRotaryEmbedding(MRotaryEmbedding):
107 |
108 | def __init__(
109 | self,
110 | head_size: int,
111 | rotary_dim: int,
112 | max_position_embeddings: int,
113 | base: int,
114 | scaling_alpha: float,
115 | dtype: torch.dtype,
116 | mrope_section: Optional[List[int]] = None,
117 | is_neox_style: bool = True,
118 |         max_model_len: Optional[int] = None,
119 | ) -> None:
120 | self.scaling_alpha = scaling_alpha
121 | self.max_model_len = max_model_len
122 | assert len(mrope_section) == 4, "Currently only 4D is supported"
123 | mrope_section = [int(x * rotary_dim // 2) for x in mrope_section]
124 |
125 | # MRotaryEmbedding will enlarge the max_position_embeddings by 4
126 | # To keep consistent with the original max_position_embeddings,
127 | # we need to divide the max_position_embeddings by 4
128 | max_position_embeddings = max_position_embeddings // 4
129 |
130 | super().__init__(
131 | head_size,
132 | rotary_dim,
133 | max_position_embeddings,
134 | base,
135 | is_neox_style,
136 | dtype,
137 | mrope_section,
138 | )
139 |
140 | def _compute_cos_sin_cache(self) -> torch.Tensor:
141 | if self.max_model_len is not None:
142 | max_len = self.max_model_len
143 | else:
144 | max_len = self.max_position_embeddings * self.scaling_alpha
145 |
146 |
147 | base = self.base * self.scaling_alpha ** (
148 | self.rotary_dim / (self.rotary_dim - 2)
149 | )
150 | inv_freq = 1.0 / (
151 | base
152 | ** (torch.arange(0, self.rotary_dim, 2).float() / self.rotary_dim)
153 | )
154 | t = torch.arange(max_len, dtype=torch.float)
155 | freqs = torch.einsum("i,j -> ij", t, inv_freq)
156 | freqs = torch.cat((freqs, freqs), dim=-1)
157 | cos = freqs.cos()
158 | sin = freqs.sin()
159 | cache = torch.cat((cos, sin), dim=-1)
160 | return cache
161 |
162 |
163 | def forward(
164 | self,
165 | positions: torch.Tensor,
166 | query: torch.Tensor,
167 | key: torch.Tensor,
168 | ) -> Tuple[torch.Tensor, torch.Tensor]:
169 | """XDRope implementation following apply_rotary_pos_emb_xdrope pattern.
170 |
171 | Args:
172 | positions:
173 | [num_tokens,] (text only) or
174 | [4, num_tokens] (4D positions with multimodal inputs)
175 | query: [num_tokens, num_heads * head_size]
176 | key: [num_tokens, num_kv_heads * head_size]
177 | """
178 | assert positions.ndim == 2, f"positions must be 2D, but got {positions.shape}"
179 |
180 | num_tokens = positions.shape[-1]
181 | cos_sin = self.cos_sin_cache[positions]
182 | cos, sin = cos_sin.chunk(2, dim=-1)
183 |
184 | x_dim = len(self.mrope_section)
185 |
186 | cos = cos.permute(1, 0, 2).reshape(-1, x_dim, self.rotary_dim)
187 | sin = sin.permute(1, 0, 2).reshape(-1, x_dim, self.rotary_dim)
188 |
189 | xdrope_section = self.mrope_section * 2
190 | assert sum(xdrope_section) == self.rotary_dim
191 |
192 | cos = torch.cat([
193 | m[:, i % x_dim] for i, m in enumerate(cos.split(xdrope_section, dim=-1))
194 | ], dim=-1)
195 | sin = torch.cat([
196 | m[:, i % x_dim] for i, m in enumerate(sin.split(xdrope_section, dim=-1))
197 | ], dim=-1)
198 |
199 | cos = cos.view(1, -1, self.rotary_dim)
200 | sin = sin.view(1, -1, self.rotary_dim)
201 |
202 | cos = cos.permute(1, 0, 2)
203 | sin = sin.permute(1, 0, 2)
204 |
205 | query_shape = query.shape
206 | query = query.view(num_tokens, -1, self.head_size)
207 | query = (query * cos) + rotate_half(query) * sin
208 | query = query.reshape(query_shape)
209 |
210 | key_shape = key.shape
211 | key = key.view(num_tokens, -1, self.head_size)
212 | key = (key * cos) + rotate_half(key) * sin
213 | key = key.reshape(key_shape)
214 |
215 | return query, key
216 |
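    # Worked example (illustrative, not from the original file): with
    # rotary_dim=128 and xdrope_section=[0.25, 0.25, 0.25, 0.25], mrope_section
    # becomes [16, 16, 16, 16]; doubling it yields eight 16-dim chunks summing
    # to 128, and chunk i is rotated with the cos/sin of position channel
    # i % 4 (sequence, x, y, t), interleaving the four position streams across
    # the head dimension.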
217 |
218 | @classmethod
219 | def get_input_positions(
220 | cls,
221 | input_tokens: List[int],
222 | hf_config: PretrainedConfig,
223 | image_grid_thw: Optional[Union[List[List[int]], torch.Tensor]],
224 | video_grid_thw: Optional[Union[List[List[int]], torch.Tensor]],
225 | second_per_grid_ts: Optional[List[float]],
226 | context_len: int = 0,
227 | seq_len: Optional[int] = None,
228 | audio_feature_lengths: Optional[torch.Tensor] = None,
229 | use_audio_in_video: bool = False,
230 | ) -> Tuple[List[List[int]], int]:
231 | """Get xdrope input positions and delta value."""
232 |
233 | image_grid_thw = [] if image_grid_thw is None else image_grid_thw
234 | video_grid_thw = [] if video_grid_thw is None else video_grid_thw
235 | second_per_grid_ts = [] if second_per_grid_ts is None else \
236 | second_per_grid_ts
237 |
238 | llm_positions, mrope_position_delta = \
239 | cls.get_input_positions_tensor(
240 | input_tokens=input_tokens,
241 | hf_config=hf_config,
242 | image_grid_thw=image_grid_thw,
243 | video_grid_thw=video_grid_thw,
244 | second_per_grid_ts=second_per_grid_ts,
245 | context_len=context_len,
246 | seq_len=seq_len,
247 | audio_feature_lengths=audio_feature_lengths,
248 | use_audio_in_video=use_audio_in_video,
249 | )
250 |
251 | return llm_positions.tolist(), mrope_position_delta
252 |
253 |
254 | @classmethod
255 | def get_input_positions_tensor(
256 | cls,
257 | input_tokens: List[int],
258 | hf_config: PretrainedConfig,
259 | image_grid_thw: Union[List[List[int]], torch.Tensor],
260 | video_grid_thw: Union[List[List[int]], torch.Tensor],
261 | second_per_grid_ts: List[float],
262 | context_len: int = 0,
263 | seq_len: Optional[int] = None,
264 | audio_feature_lengths: Optional[torch.Tensor] = None,
265 | use_audio_in_video: bool = False,
266 | ) -> Tuple[torch.Tensor, int]:
267 | return cls._vl_get_input_positions_tensor(
268 | input_tokens=input_tokens,
269 | hf_config=hf_config,
270 | image_grid_thw=image_grid_thw,
271 | video_grid_thw=video_grid_thw,
272 | second_per_grid_ts=second_per_grid_ts,
273 | context_len=context_len,
274 | seq_len=seq_len,
275 | )
276 |
277 |
278 | @classmethod
279 | def _vl_get_input_positions_tensor(
280 | cls,
281 | input_tokens: List[int],
282 | hf_config: PretrainedConfig,
283 | image_grid_thw: Union[List[List[int]], torch.Tensor],
284 | video_grid_thw: Union[List[List[int]], torch.Tensor],
285 | second_per_grid_ts: List[float],
286 | context_len: int = 0,
287 | seq_len: Optional[int] = None,
288 | ) -> Tuple[torch.Tensor, int]:
289 | """Get xdrope input positions following get_xdrope_position_ids pattern."""
290 |
291 |
292 | image_token_id = hf_config.image_token_id
293 | vision_start_token_id = hf_config.vision_start_token_id
294 |
295 | input_tokens_tensor = torch.tensor(input_tokens)
296 |
297 | # Initialize 4D position embeddings (following xdrope pattern)
298 | seq_length = len(input_tokens)
299 | position_ids_seq = torch.arange(seq_length) # Sequential positions
300 | position_ids_t = position_ids_seq.clone()
301 | position_ids_x = position_ids_seq.clone()
302 | position_ids_y = position_ids_seq.clone()
303 |
304 | vision_start_indices = torch.argwhere(
305 | input_tokens_tensor == vision_start_token_id).squeeze(1)
306 |
307 | if len(vision_start_indices) == 0:
308 | # No vision tokens, return 4D sequential positions
309 | llm_positions = torch.stack([position_ids_seq, position_ids_x, position_ids_y, position_ids_t])
310 | mrope_position_delta = 0
311 | llm_positions = llm_positions[:, context_len:seq_len]
312 | return llm_positions, mrope_position_delta
313 |
314 | # Process vision tokens using image_grid_thw information
315 | image_index, video_index = 0, 0
316 | current_pos = 0
317 |
318 | for start_idx in vision_start_indices:
319 | start_idx = start_idx.item()
320 |
321 | # Determine if this is image or video token
322 | if start_idx + 1 < len(input_tokens):
323 | next_token = input_tokens[start_idx + 1]
324 | is_image = (next_token == image_token_id)
325 |
326 | if is_image and image_index < len(image_grid_thw):
327 | t, h, w = image_grid_thw[image_index]
328 | image_index += 1
329 | else:
330 | continue
331 |
332 | # Calculate grid dimensions
333 | llm_grid_t, llm_grid_h, llm_grid_w = (
334 | t, h, w
335 | )
336 |
337 | # Find end of vision tokens (approximate)
338 | vision_token_count = llm_grid_t * llm_grid_h * llm_grid_w
339 | end_idx = min(start_idx + vision_token_count + 2, seq_length) # +2 for start/end tokens
340 |
341 | # Apply xdrope position assignment pattern
342 | if end_idx > start_idx + 2: # Ensure we have vision tokens
343 | # Reset time dimension for vision tokens (following get_xdrope_position_ids)
344 | position_ids_t[start_idx + 2:end_idx] = current_pos
345 | current_pos += 1
346 |
347 | # Calculate row and column for 2D layout
348 | vision_tokens_between = end_idx - start_idx - 2 # excluding start/end
349 | if llm_grid_h > 0:
350 | tokens_per_row = llm_grid_w
351 | num_rows = llm_grid_h
352 |
353 | # Assign x,y coordinates following the pattern
354 | idx_xy = 0
355 | for rr in range(num_rows):
356 | for cc in range(tokens_per_row):
357 | if start_idx + 2 + idx_xy < end_idx:
358 | position_ids_x[start_idx + 2 + idx_xy] = cc
359 | position_ids_y[start_idx + 2 + idx_xy] = rr
360 | idx_xy += 1
361 |
362 | # Stack into 4D positions
363 | llm_positions = torch.stack([position_ids_seq, position_ids_x, position_ids_y, position_ids_t])
364 | mrope_position_delta = (llm_positions.max() + 1 - len(input_tokens)).item()
365 | llm_positions = llm_positions[:, context_len:seq_len]
366 |
367 | return llm_positions, mrope_position_delta
368 |
369 |
370 | @staticmethod
371 | def get_next_input_positions(
372 | mrope_position_delta: int,
373 | context_len: int,
374 | seq_len: int,
375 | ) -> List[List[int]]:
376 | return [
377 | list(
378 | range(context_len + mrope_position_delta,
379 | seq_len + mrope_position_delta)) for _ in range(4) # Changed from 3 to 4
380 | ]
381 |
382 |
383 | @staticmethod
384 | def get_next_input_positions_tensor(
385 | mrope_position_delta: int,
386 | context_len: int,
387 | seq_len: int,
388 | ) -> torch.Tensor:
389 | return torch.arange(
390 | mrope_position_delta + context_len,
391 | mrope_position_delta + seq_len,
392 | ).expand(4, -1) # Changed from 3 to 4
393 |
394 |
395 | class HunyuanMLP(nn.Module):
396 | def __init__(
397 | self,
398 | config: PretrainedConfig,
399 | hidden_size: int,
400 | intermediate_size: int,
401 | hidden_act: str,
402 | quant_config: Optional[QuantizationConfig] = None,
403 | prefix: str = "",
404 | ) -> None:
405 | super().__init__()
406 | self.gate_and_up_proj = MergedColumnParallelLinear(
407 | hidden_size,
408 | [intermediate_size] * 2,
409 | bias=config.mlp_bias,
410 | quant_config=quant_config,
411 | prefix=f"{prefix}.gate_and_up_proj",
412 | )
413 | # self.down_proj = ReplicatedLinear(
414 | # intermediate_size,
415 | # hidden_size,
416 | # bias=config.mlp_bias,
417 | # quant_config=quant_config,
418 | # prefix=f"{prefix}.down_proj",
419 | # )
420 | self.down_proj = nn.Linear(
421 | intermediate_size,
422 | hidden_size,
423 | bias=config.mlp_bias,
424 | )
425 | self.act_fn = SiluAndMul()
426 |
427 | def forward(self, x):
428 | gate_up, _ = self.gate_and_up_proj(x)
429 | x = self.act_fn(gate_up)
430 | x = self.down_proj(x)
431 | return x
432 |
433 |
434 | class HunYuanAttention(nn.Module):
435 | def __init__(
436 | self,
437 | config: PretrainedConfig,
438 | hidden_size: int,
439 | num_heads: int,
440 | num_kv_heads: int,
441 | rope_theta: float = 10000,
442 | rope_scaling: Optional[Dict[str, Any]] = None,
443 | max_position_embeddings: int = 8192,
444 | attention_bias: bool = False,
445 | cache_config: Optional[CacheConfig] = None,
446 | quant_config: Optional[QuantizationConfig] = None,
447 | prefix: str = "",
448 | ):
449 | super().__init__()
450 | self.hidden_size = hidden_size
451 | self.tp_size = get_tensor_model_parallel_world_size()
452 | self.tp_rank = get_tensor_model_parallel_rank()
453 | self.total_num_heads = num_heads
454 | assert self.total_num_heads % self.tp_size == 0
455 | self.num_heads = self.total_num_heads // self.tp_size
456 | self.total_num_kv_heads = num_kv_heads
457 | if self.total_num_kv_heads >= self.tp_size:
458 | # Number of KV heads is greater than TP size, so we partition
459 | # the KV heads across multiple tensor parallel GPUs.
460 | assert self.total_num_kv_heads % self.tp_size == 0
461 | else:
462 | # Number of KV heads is less than TP size, so we replicate
463 | # the KV heads across multiple tensor parallel GPUs.
464 | assert self.tp_size % self.total_num_kv_heads == 0
465 | self.num_kv_heads = max(1, self.total_num_kv_heads // self.tp_size)
466 | self.head_dim = hidden_size // self.total_num_heads
467 | self.q_size = self.num_heads * self.head_dim
468 | self.kv_size = self.num_kv_heads * self.head_dim
469 | self.key_value_groups = int(self.num_heads / self.num_kv_heads)
470 | self.scaling = self.head_dim**-0.5
471 | self.rope_theta = rope_theta
472 | self.max_position_embeddings = max_position_embeddings
473 |
474 | self.qkv_proj = QKVParallelLinear(
475 | hidden_size,
476 | self.head_dim,
477 | self.total_num_heads,
478 | self.total_num_kv_heads,
479 | bias=attention_bias,
480 | quant_config=quant_config,
481 | prefix=f"{prefix}.wqkv",
482 | )
483 |
484 | # self.o_proj = RowParallelLinear(
485 | # self.total_num_heads * self.head_dim,
486 | # hidden_size,
487 | # bias=attention_bias,
488 | # quant_config=quant_config,
489 | # prefix=f"{prefix}.wo",
490 | # )
491 | self.o_proj = nn.Linear(
492 | self.total_num_heads * self.head_dim,
493 | hidden_size,
494 | bias=attention_bias,
495 | )
496 |
497 | self.query_layernorm = (
498 | RMSNorm(self.head_dim, eps=config.rms_norm_eps)
499 | if config.use_qk_norm
500 | else None
501 | )
502 | self.key_layernorm = (
503 | RMSNorm(self.head_dim, eps=config.rms_norm_eps)
504 | if config.use_qk_norm
505 | else None
506 | )
507 |
508 | self.rotary_emb = (
509 | DynamicNTKAlphaMRotaryEmbedding(
510 | self.head_dim,
511 | self.head_dim,
512 | max_position_embeddings,
513 | int(rope_theta),
514 | scaling_alpha=rope_scaling["alpha"],
515 | dtype=torch.get_default_dtype(),
516 | mrope_section=rope_scaling["mrope_section"],
517 | max_model_len=config.max_model_len,
518 | )
519 | if config.use_rotary_pos_emb
520 | else None
521 | )
522 |
523 | self.attn = Attention(
524 | self.num_heads,
525 | self.head_dim,
526 | self.scaling,
527 | num_kv_heads=self.num_kv_heads,
528 | cache_config=cache_config,
529 | quant_config=quant_config,
530 | prefix=f"{prefix}.attn",
531 | )
532 |
533 | def split_qkv(self, qkv: torch.Tensor):
534 | seq_len = qkv.shape[0]
535 | if self.tp_size > 1:
536 | qkv_map = [self.q_size, self.kv_size, self.kv_size] * self.tp_size
537 | qkv = tensor_model_parallel_all_gather(qkv)
538 | qkv = torch.split(qkv, qkv_map, dim=-1)
539 | qkv = qkv[::3] + qkv[1::3] + qkv[2::3]
540 | qkv = torch.cat(qkv, dim=-1)
541 |
542 | qkv = qkv.view(
543 | seq_len,
544 | self.total_num_kv_heads,
545 | self.key_value_groups + 2,
546 | self.head_dim,
547 | )
548 | q, k, v = torch.split(qkv, [self.key_value_groups, 1, 1], dim=-2)
549 | q = q.reshape(seq_len, self.q_size * self.tp_size)
550 | k = k.reshape(seq_len, self.kv_size * self.tp_size)
551 | v = v.reshape(seq_len, self.kv_size * self.tp_size)
552 |
553 | if self.tp_size > 1:
554 | splitter = partial(
555 | split_tensor_along_last_dim, num_partitions=self.tp_size
556 | )
557 | q = splitter(q)[self.tp_rank]
558 | k = splitter(k)[self.tp_rank]
559 | v = splitter(v)[self.tp_rank]
560 | return q, k, v
561 |
562 | def forward(
563 | self,
564 | positions: torch.Tensor,
565 | hidden_states: torch.Tensor,
566 | ) -> torch.Tensor:
567 | qkv, _ = self.qkv_proj(hidden_states)
568 | q, k, v = self.split_qkv(qkv)
569 |
570 | if self.rotary_emb is not None:
571 | q, k = self.rotary_emb(positions, q, k)
572 |
573 |
574 | if self.query_layernorm is not None:
575 | q = q.reshape(-1, self.num_heads, self.head_dim)
576 | q = self.query_layernorm(q).reshape(-1, self.q_size)
577 |
578 | if self.key_layernorm is not None:
579 | k = k.reshape(-1, self.num_kv_heads, self.head_dim)
580 | k = self.key_layernorm(k).reshape(-1, self.kv_size)
581 |
582 | attn_output = self.attn(q, k, v)
583 | output = self.o_proj(attn_output)
584 | return output
585 |
586 |
587 | class HunYuanDecoderLayer(nn.Module):
588 | def __init__(
589 | self,
590 | config: PretrainedConfig,
591 | cache_config: CacheConfig,
592 | quant_config: QuantizationConfig,
593 | prefix: str = "",
594 | ):
595 | super().__init__()
596 | self.hidden_size = config.hidden_size
597 | self.self_attn = HunYuanAttention(
598 | config,
599 | config.hidden_size,
600 | config.num_attention_heads,
601 | config.num_key_value_heads,
602 | config.rope_theta,
603 | config.rope_scaling,
604 | config.max_position_embeddings,
605 | config.attention_bias,
606 | cache_config,
607 | quant_config,
608 | prefix=f"{prefix}.attention",
609 | )
610 | self.mlp = HunyuanMLP(
611 | config,
612 | config.hidden_size,
613 | config.intermediate_size,
614 | config.hidden_act,
615 | quant_config,
616 | prefix=f"{prefix}.mlp",
617 | )
618 |
619 | self.input_layernorm = RMSNorm(
620 | config.hidden_size, eps=config.rms_norm_eps
621 | )
622 | self.post_attention_layernorm = RMSNorm(
623 | config.hidden_size, eps=config.rms_norm_eps
624 | )
625 |
626 | def forward(
627 | self,
628 | positions: torch.Tensor,
629 | hidden_states: torch.Tensor,
630 | residual: Optional[torch.Tensor],
631 | ) -> Tuple[torch.Tensor, torch.Tensor]:
632 | if residual is None:
633 | residual = hidden_states
634 | hidden_states = self.input_layernorm(hidden_states)
635 | else:
636 | hidden_states, residual = self.input_layernorm(
637 | hidden_states, residual
638 | )
639 |
640 | hidden_states = self.self_attn(
641 | positions=positions,
642 | hidden_states=hidden_states,
643 | )
644 |
645 | hidden_states, residual = self.post_attention_layernorm(
646 | hidden_states, residual
647 | )
648 |
649 | hidden_states = self.mlp(hidden_states)
650 |
651 | return hidden_states, residual
652 |
653 |
654 | @support_torch_compile(
655 | dynamic_arg_dims={
656 | "input_ids": 0,
657 | # positions is of shape (4, seq_len) if mrope/xdrope is enabled,
658 | # otherwise (seq_len, ).
659 | "positions": -1,
660 | "intermediate_tensors": 0,
661 | "inputs_embeds": 0,
662 | })
663 | class HunYuanModel(nn.Module):
664 | def __init__(
665 | self,
666 | *,
667 | vllm_config: VllmConfig,
668 | prefix: str = "",
669 | layer_type: Type[HunYuanDecoderLayer] = HunYuanDecoderLayer,
670 | ):
671 | super().__init__()
672 |
673 | config = vllm_config.model_config.hf_config
674 | cache_config = vllm_config.cache_config
675 | quant_config = vllm_config.quant_config
676 |
677 | self.config = config
678 | self.vocab_size = config.vocab_size
679 |
680 | self.embed_tokens = VocabParallelEmbedding(
681 | config.vocab_size,
682 | config.hidden_size,
683 | ) # TODO: This does not support padding_idx, check if this is an issue
684 |
685 | self.start_layer, self.end_layer, self.layers = make_layers(
686 | config.num_hidden_layers,
687 | lambda prefix: layer_type(
688 | config, cache_config, quant_config, prefix=prefix
689 | ),
690 | prefix=f"{prefix}.layers",
691 | )
692 | self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
693 | self.make_empty_intermediate_tensors = (
694 | make_empty_intermediate_tensors_factory(
695 | ["hidden_states", "residual"], config.hidden_size
696 | )
697 | )
698 |
699 | def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
700 | return self.embed_tokens(input_ids)
701 |
702 | def forward(
703 | self,
704 | input_ids: torch.Tensor,
705 | positions: torch.Tensor,
706 | intermediate_tensors: Optional[IntermediateTensors] = None,
707 | inputs_embeds: Optional[torch.Tensor] = None,
708 | **kwargs: object,
709 | ) -> Union[torch.Tensor, IntermediateTensors]:
710 | if get_pp_group().is_first_rank:
711 | if inputs_embeds is not None:
712 | hidden_states = inputs_embeds
713 | else:
714 | hidden_states = self.get_input_embeddings(input_ids)
715 | residual = None
716 | else:
717 | hidden_states = intermediate_tensors["hidden_states"]
718 | residual = intermediate_tensors["residual"]
719 | for layer in self.layers[self.start_layer : self.end_layer]:
720 | hidden_states, residual = layer(positions, hidden_states, residual)
721 | if not get_pp_group().is_last_rank:
722 | return IntermediateTensors(
723 | {
724 | "hidden_states": hidden_states,
725 | "residual": residual,
726 | }
727 | )
728 | hidden_states, _ = self.norm(hidden_states, residual)
729 |
730 | return hidden_states
731 |
732 |
733 | class HunYuanForCausalLM(nn.Module, SupportsPP):
734 |
735 | def __init__(
736 | self,
737 | *,
738 | vllm_config: VllmConfig,
739 | prefix: str = "",
740 | model_type: Type[HunYuanModel] = HunYuanModel,
741 | ):
742 | super().__init__()
743 |
744 | config = vllm_config.model_config.hf_config
745 | quant_config = vllm_config.quant_config
746 |
747 | self.config = config
748 | self.model = model_type(
749 | vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model")
750 | )
751 |
752 | self.lm_head = ParallelLMHead(
753 | config.vocab_size,
754 | config.hidden_size,
755 | bias=False,
756 | quant_config=quant_config,
757 | prefix=maybe_prefix(prefix, "lm_head"),
758 | )
759 | self.logits_processor = LogitsProcessor(config.vocab_size)
760 | self.sampler = get_sampler()
761 | self.make_empty_intermediate_tensors = (
762 | self.model.make_empty_intermediate_tensors
763 | )
764 |
765 | def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
766 | return self.model.embed_tokens(input_ids)
767 |
768 | def forward(
769 | self,
770 | input_ids: torch.Tensor,
771 | positions: torch.Tensor,
772 | intermediate_tensors: Optional[IntermediateTensors] = None,
773 | inputs_embeds: Optional[torch.Tensor] = None,
774 | **kwargs: object,
775 | ) -> torch.Tensor:
776 | assert positions.ndim == 2, f"positions must be 2D, but got {positions.shape}"
777 | hidden_states = self.model(
778 | input_ids, positions, intermediate_tensors, inputs_embeds
779 | )
780 | return hidden_states
781 |
782 | def compute_logits(
783 | self,
784 | hidden_states: torch.Tensor,
785 | sampling_metadata: SamplingMetadata,
786 | ) -> Optional[torch.Tensor]:
787 | logits = self.logits_processor(
788 | self.lm_head, hidden_states, sampling_metadata
789 | )
790 | return logits
791 |
792 | def sample(
793 | self,
794 | logits: torch.Tensor,
795 | sampling_metadata: SamplingMetadata,
796 | ) -> Optional[SamplerOutput]:
797 | next_tokens = self.sampler(logits, sampling_metadata)
798 | return next_tokens
799 |
800 | def load_weights(
801 | self, weights: Iterable[Tuple[str, torch.Tensor]]
802 | ) -> Set[str]:
803 | loader = AutoWeightsLoader(self)
804 | return loader.load_weights(weights)
805 |
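806 | # Illustrative example (not executed): the rotary-embedding helpers above track four
807 | # position streams [seq, x, y, t] instead of the usual three. During incremental
808 | # decoding the positions are simply the sequential index shifted by
809 | # mrope_position_delta and replicated across the four dims, e.g. with
810 | # mrope_position_delta = -10, context_len = 100, seq_len = 102:
811 | #
812 | #   get_next_input_positions(-10, 100, 102)
813 | #   # -> [[90, 91], [90, 91], [90, 91], [90, 91]]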
--------------------------------------------------------------------------------
/model_vllm/hunyuan_video.py:
--------------------------------------------------------------------------------
1 | from abc import ABC, abstractmethod
2 | from collections.abc import Iterable, Mapping, Sequence
3 | from functools import cached_property
4 | from typing import Literal, Optional, Set, Tuple, TypedDict, TypeVar, Union, Any
5 | from copy import deepcopy
6 |
7 | import torch
8 | import torch.nn as nn
9 | import torchvision.transforms as T
10 | from PIL import Image
11 | from transformers import (
12 | BatchEncoding,
13 | PretrainedConfig,
14 | TensorType,
15 | WhisperFeatureExtractor,
16 | )
17 | import math
18 | import logging
19 |
20 | from vllm.config import VllmConfig
21 | from vllm.model_executor.layers.quantization import QuantizationConfig
22 | from vllm.model_executor.layers.quantization.awq import AWQConfig
23 | from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
24 | from vllm.model_executor.sampling_metadata import SamplingMetadata
25 | from vllm.multimodal import MULTIMODAL_REGISTRY
26 | from vllm.multimodal.inputs import (
27 | MultiModalFieldConfig,
28 | MultiModalKwargs,
29 | NestedTensors,
30 | MultiModalDataDict,
31 | MultiModalInputs,
32 | )
33 | from vllm.multimodal.parse import (
34 | ImageEmbeddingItems,
35 | ImageProcessorItems,
36 | ImageSize,
37 | MultiModalDataItems,
38 | )
39 | from vllm.multimodal.processing import (
40 | BaseMultiModalProcessor,
41 | BaseProcessingInfo,
42 | PromptReplacement,
43 | PromptUpdate,
44 | PromptUpdateDetails,
45 | )
46 | from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
47 | from vllm.sequence import IntermediateTensors
48 | from vllm.transformers_utils.tokenizer import AnyTokenizer
49 |
50 | from vllm.model_executor.models.interfaces import (
51 | MultiModalEmbeddings,
52 | SupportsMultiModal,
53 | SupportsPP,
54 | )
55 | from vllm.model_executor.models.utils import (
56 | AutoWeightsLoader,
57 | flatten_bn,
58 | init_vllm_registered_model,
59 | maybe_prefix,
60 | merge_multimodal_embeddings,
61 | WeightsMapper,
62 | )
63 | from vllm.model_executor.models.whisper import WhisperEncoder
64 | from vllm.multimodal.parse import (
65 | MultiModalDataParser,
66 | ModalityData,
67 | ModalityDataItems,
68 | DictEmbeddingItems,
69 | ProcessorBatchItems,
70 | )
71 | from vllm.multimodal.inputs import ImageItem
72 | from vllm.transformers_utils.tokenizer import decode_tokens
73 | from vllm.multimodal.hasher import MultiModalHasher
74 |
75 |
76 | logger = logging.getLogger(__name__)
77 |
78 |
79 | IMG_START = ""
80 | IMG_END = ""
81 | IMG_CONTEXT = ""
82 |
83 |
84 | def _hunyuan_field_config(hf_inputs: Mapping[str, torch.Tensor]):
85 |
86 | image_grid_thw = hf_inputs.get("image_grid_thw", torch.empty((0, 3)))
87 | image_grid_sizes = image_grid_thw.prod(-1)
88 |
89 | return dict(
90 | pixel_values_flat=MultiModalFieldConfig.batched("image"),
91 | image_embeds=MultiModalFieldConfig.batched("image"),
92 | image_grid_thw=MultiModalFieldConfig.batched("image"),
93 | )
94 |
95 |
96 | class HunyuanMultiModalDataParser(MultiModalDataParser):
97 |
98 | def _parse_image_data(
99 | self,
100 | data: Union[dict[str, torch.Tensor], ModalityData[ImageItem]],
101 | ) -> ModalityDataItems[Any, Any]:
102 | if isinstance(data, dict):
103 | return DictEmbeddingItems(
104 | data,
105 | modality="image",
106 | required_fields={"image_embeds", "image_grid_thw"},
107 | fields_factory=_hunyuan_field_config,
108 | )
109 |
110 | return super()._parse_image_data(data)
111 |
112 |
113 | class HunyuanImageEmbedInputs(TypedDict):
114 | type: Literal["image_embeds"]
115 | data: Union[torch.Tensor, list[torch.Tensor]]
116 | """
117 | A tensor of shape `(num_images, total_image_feature_size, hidden_size)`
118 | or a list of tensors of shape `(total_image_feature_size, hidden_size)`
119 |
120 | `hidden_size` must match the hidden size of language model backbone.
121 | """
122 |
123 | image_grid_thw: torch.Tensor
124 |
125 |
126 | class BaseHunyuanProcessor(ABC):
127 |
128 | def __init__(
129 | self,
130 | config: PretrainedConfig,
131 | tokenizer: AnyTokenizer,
132 | ) -> None:
133 | self.config = config
134 | self.tokenizer = tokenizer
135 | self.num_image_token = config.num_image_token
136 |
137 | @property
138 | @abstractmethod
139 | def image_token_id(self) -> int:
140 | raise NotImplementedError
141 |
142 | @abstractmethod
143 | def get_image_replace(
144 | self,
145 | ) -> PromptUpdateDetails[str]:
146 | raise NotImplementedError
147 |
148 | def __call__(
149 | self,
150 | text: Optional[Union[str, list[str]]] = None,
151 | images: Optional[Union[Image.Image, list[Image.Image]]] = None,
152 | return_tensors: Optional[Union[str, TensorType]] = None,
153 | ) -> Mapping[str, NestedTensors]:
154 | if text is None:
155 | text = []
156 | if not isinstance(text, list):
157 | text = [text]
158 |
159 | if images is not None:
160 | raise NotImplementedError("Image processing not implemented")
161 |
162 | text_inputs = self.tokenizer(text)
163 |
164 | output = {
165 | **BatchEncoding(text_inputs, tensor_type=return_tensors),
166 | }
167 | return output
168 |
169 |
170 | class HunyuanProcessor(BaseHunyuanProcessor):
171 |
172 | @property
173 | def image_token_id(self) -> int:
174 | image_token_id = self.tokenizer.get_vocab()[IMG_CONTEXT]
175 | return image_token_id
176 |
177 | def get_image_replace(
178 | self,
179 | ) -> PromptUpdateDetails[str]:
180 | replace_features = IMG_CONTEXT * self.num_image_token
181 | replace_full = IMG_START + replace_features + IMG_END
182 |
183 | return PromptUpdateDetails.select_text(replace_full, IMG_CONTEXT)
184 | # return PromptUpdateDetails(full=replace_full, features=replace_features)
185 |
186 |
187 | class BaseHunyuanProcessingInfo(BaseProcessingInfo):
188 |
189 | @abstractmethod
190 | def get_hf_processor(
191 | self,
192 | **kwargs: object,
193 | ) -> BaseHunyuanProcessor:
194 | raise NotImplementedError
195 |
196 | def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
197 | return {"image": None}
198 |
199 | def get_mm_max_tokens_per_item(
200 | self,
201 | seq_len: int,
202 | mm_counts: Mapping[str, int],
203 | ) -> Mapping[str, int]:
204 | return {"image": self.get_max_image_tokens()}
205 |
206 | def get_max_image_tokens(self) -> int:
207 | processor = self.get_hf_processor()
208 | num_image_token = processor.num_image_token
209 | return num_image_token
210 |
211 |
212 | _I = TypeVar("_I", bound=BaseHunyuanProcessingInfo)
213 |
214 |
215 | class HunyuanDummyInputsBuilder(BaseDummyInputsBuilder[_I]):
216 |
217 | def get_dummy_processor_inputs(
218 | self,
219 | seq_len: int,
220 | mm_counts: Mapping[str, int],
221 | ) -> ProcessorInputs:
222 | num_images = mm_counts.get("image", 0)
223 |
224 | num_image_token = self.info.get_hf_processor().num_image_token
225 | hidden_size = self.info.get_hf_processor().config.hidden_size
226 |
227 | grid_hw = int((math.sqrt(4 * num_image_token - 7) - 1) / 2)
228 |
229 | grid_thw = torch.tensor([[1, grid_hw, grid_hw]])
230 | grid_thw = grid_thw.repeat(num_images, 1)
231 |
232 | mm_data = {
233 | "image": {
234 | "image_embeds": torch.randn(
235 | num_images,
236 | num_image_token,
237 | hidden_size,
238 | dtype=torch.bfloat16,
239 | ),
240 | "image_grid_thw": grid_thw,
241 | }
242 | }
243 |
244 | return ProcessorInputs(
245 | prompt_text="" * num_images,
246 | mm_data=mm_data,
247 | )
248 |
249 |
250 | class HunyuanMultiModalProcessor(BaseMultiModalProcessor[_I]):
251 |
252 | def _call_hf_processor(
253 | self,
254 | prompt: str,
255 | mm_data: Mapping[str, object],
256 | mm_kwargs: Mapping[str, object],
257 | ) -> Mapping[str, NestedTensors]:
258 | processed_outputs = super()._call_hf_processor(
259 | prompt=prompt,
260 | mm_data=mm_data,
261 | mm_kwargs=mm_kwargs,
262 | )
263 | return processed_outputs
264 |
265 | def _get_data_parser(self) -> HunyuanMultiModalDataParser:
266 | return HunyuanMultiModalDataParser()
267 |
268 | def _get_mm_fields_config(
269 | self,
270 | hf_inputs: Mapping[str, NestedTensors],
271 | hf_processor_mm_kwargs: Mapping[str, object],
272 | ) -> Mapping[str, MultiModalFieldConfig]:
273 | return _hunyuan_field_config(hf_inputs)
274 |
275 | def _get_prompt_updates(
276 | self,
277 | mm_items: MultiModalDataItems,
278 | hf_processor_mm_kwargs: Mapping[str, object],
279 | out_mm_kwargs: MultiModalKwargs,
280 | ) -> Sequence[PromptUpdate]:
281 | hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
282 |
283 | image_replace = hf_processor.get_image_replace()
284 |
285 | return [
286 | PromptReplacement(
287 | modality="image",
288 | target="",
289 | replacement=image_replace,
290 | ),
291 | ]
292 |
293 |
294 | class HunyuanProcessingInfo(BaseHunyuanProcessingInfo):
295 |
296 | def get_hf_processor(
297 | self,
298 | **kwargs: object,
299 | ) -> HunyuanProcessor:
300 | return self.ctx.init_processor(
301 | HunyuanProcessor,
302 | config=self.get_hf_config(),
303 | tokenizer=self.get_tokenizer(),
304 | **kwargs,
305 | )
306 |
307 |
308 | @MULTIMODAL_REGISTRY.register_processor(
309 | HunyuanMultiModalProcessor,
310 | info=HunyuanProcessingInfo,
311 | dummy_inputs=HunyuanDummyInputsBuilder,
312 | )
313 | class HunyuanVideoModel(nn.Module, SupportsMultiModal, SupportsPP):
314 |
315 | def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
316 | super().__init__()
317 | config = vllm_config.model_config.hf_config
318 |
319 | self.config = config
320 |
321 | self.language_model = init_vllm_registered_model(
322 | vllm_config=vllm_config,
323 | prefix=maybe_prefix(prefix, "language_model"),
324 | architectures=["HunYuanForCausalLM"],
325 | )
326 |
327 | self.system_message = None
328 | self.num_samples = 0
329 |
330 | @cached_property
331 | def sampler(self):
332 | if hasattr(self.language_model, "sampler"):
333 | return self.language_model.sampler
334 | else:
335 | raise NotImplementedError
336 |
337 | def get_input_embeddings(
338 | self,
339 | input_ids: torch.Tensor,
340 | multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
341 | ) -> torch.Tensor:
342 |
343 | inputs_embeds = self.language_model.get_input_embeddings(input_ids)
344 | if multimodal_embeddings is not None:
345 | inputs_embeds = merge_multimodal_embeddings(
346 | input_ids,
347 | inputs_embeds,
348 | multimodal_embeddings,
349 | self.config.image_token_id,
350 | )
351 | return inputs_embeds
352 |
353 | def get_multimodal_embeddings(
354 | self, **kwargs: object
355 | ) -> Optional[MultiModalEmbeddings]:
356 | image_input = self._parse_and_validate_image_input(**kwargs)
357 |
358 | if image_input is None:
359 | return None
360 |
361 | return image_input["data"]
362 |
363 | def get_language_model(self) -> torch.nn.Module:
364 | return self.language_model
365 |
366 | def _parse_and_validate_image_input(
367 | self, **kwargs: object
368 | ) -> Optional[HunyuanImageEmbedInputs]:
369 | image_embeds = kwargs.pop("image_embeds", None)
370 | image_grid_thw = kwargs.pop("image_grid_thw", None)
371 |
372 | if image_embeds is None:
373 | return None
374 |
375 | if not isinstance(image_embeds, (torch.Tensor, list)):
376 | raise ValueError(
377 | "Incorrect type of image embeddings. "
378 | f"Got type: {type(image_embeds)}"
379 | )
380 |
381 | image_embeds = image_embeds.to(self.config.torch_dtype)
382 |
383 | return HunyuanImageEmbedInputs(
384 | type="image_embeds",
385 | data=flatten_bn(image_embeds),
386 | image_grid_thw=flatten_bn(image_grid_thw),
387 | )
388 |
389 | def _process_image_input(
390 | self, image_input: HunyuanImageEmbedInputs
391 | ) -> MultiModalEmbeddings:
392 | grid_thw = image_input["image_grid_thw"]
393 | assert grid_thw.ndim == 2
394 |
395 | if image_input["type"] == "image_embeds":
396 | image_embeds = image_input["data"]
397 |
398 | merge_size = 1 # TODO: Check this
399 | sizes = grid_thw.prod(-1) // merge_size // merge_size
400 |
401 | return image_embeds.split(sizes.tolist())
402 |
403 | def forward(
404 | self,
405 | input_ids: torch.Tensor,
406 | positions: torch.Tensor,
407 | intermediate_tensors: Optional[IntermediateTensors] = None,
408 | inputs_embeds: Optional[torch.Tensor] = None,
409 | **kwargs: object,
410 | ) -> Union[SamplerOutput, IntermediateTensors]:
411 |
412 | if intermediate_tensors is not None:
413 | input_ids = None
414 | inputs_embeds = None
415 | elif inputs_embeds is None:
416 | # raise ValueError(f"v0 not supported, {kwargs}")
417 | vision_embeddings = self.get_multimodal_embeddings(**kwargs)
418 | inputs_embeds = self.get_input_embeddings(input_ids,
419 | vision_embeddings)
420 | input_ids = None
421 |
422 | hidden_states = self.language_model.model(
423 | input_ids=input_ids,
424 | positions=positions,
425 | intermediate_tensors=intermediate_tensors,
426 | inputs_embeds=inputs_embeds,
427 | **kwargs,
428 | )
429 |
430 |
431 | return hidden_states
432 |
433 | def compute_logits(
434 | self,
435 | hidden_states: torch.Tensor,
436 | sampling_metadata: SamplingMetadata,
437 | ) -> Optional[torch.Tensor]:
438 | return self.language_model.compute_logits(
439 | hidden_states, sampling_metadata
440 | )
441 |
442 | def sample(
443 | self,
444 | logits: torch.Tensor,
445 | sampling_metadata: SamplingMetadata,
446 | ) -> Optional[SamplerOutput]:
447 | return self.language_model.sample(logits, sampling_metadata)
448 |
449 | def load_weights(
450 | self, weights: Iterable[Tuple[str, torch.Tensor]]
451 | ) -> Set[str]:
452 | loader = AutoWeightsLoader(self)
453 | weights = list(weights)
454 | for i, (k, v) in enumerate(weights):
455 | # The order of SiLU and Mul is different in VLLM
456 | if ".mlp.gate_and_up_proj.weight" in k:
457 | v1, v2 = v.chunk(2, dim=0)
458 | weights[i] = (k, torch.cat([v2, v1], dim=0))
459 |
460 | # Filter out weights that are not in the language model (vit, whisper, mlp2)
461 | weights = [(k, v) for k, v in weights if k.startswith("language_model")]
462 |
463 | if "language_model.lm_head.weight" not in weights:
464 | logger.warning(
465 | "langauge.lm_head.weight not found in weights, "
466 | "will try to load it from language_model.embed_tokens.weight"
467 | )
468 | weights.append(("language_model.lm_head.weight", self.language_model.model.embed_tokens.weight))
469 |
470 | return loader.load_weights(weights)
471 |
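472 | # Illustrative request sketch (not executed; it mirrors forward_llm in
473 | # video_audio_llm.py, and the variable names here are placeholders): this model
474 | # consumes precomputed image embeddings rather than raw pixels, supplied through
475 | # vLLM's multi_modal_data:
476 | #
477 | #   llm.generate(
478 | #       {
479 | #           "prompt": prompt_with_image_placeholders,
480 | #           "multi_modal_data": {
481 | #               "image": {
482 | #                   "image_embeds": image_embeds,      # per-image/frame embeddings
483 | #                   "image_grid_thw": image_grid_thw,  # (num_images, 3) rows of [t, h, w]
484 | #               }
485 | #           },
486 | #       },
487 | #       sampling_params,
488 | #   )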
--------------------------------------------------------------------------------
/model_vllm/monkey_patch_mrope.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import importlib
3 |
4 | # Module where the original MRotaryEmbedding is defined
5 | VLLM_ROTARY_EMBEDDING_MODULE = "vllm.model_executor.layers.rotary_embedding"
6 |
7 | # Path to your custom class
8 | # Adjust the import path if your project structure is different
9 | YOUR_CUSTOM_MODULE = "model_vllm.hunyuan"
10 | YOUR_CUSTOM_CLASS_NAME = "DynamicNTKAlphaMRotaryEmbedding"
11 |
12 | try:
13 | # Import the vLLM module
14 | vllm_rotary_module = importlib.import_module(VLLM_ROTARY_EMBEDDING_MODULE)
15 |
16 | # Import your custom class
17 | custom_module = importlib.import_module(YOUR_CUSTOM_MODULE)
18 | CustomRotaryEmbeddingClass = getattr(custom_module, YOUR_CUSTOM_CLASS_NAME)
19 |
20 | # Perform the monkey patch:
21 | # Replace the MRotaryEmbedding in the vLLM module with your class
22 | setattr(vllm_rotary_module, "MRotaryEmbedding", CustomRotaryEmbeddingClass)
23 |
24 | print(f"Successfully monkey-patched 'MRotaryEmbedding' in '{VLLM_ROTARY_EMBEDDING_MODULE}' "
25 | f"with '{YOUR_CUSTOM_CLASS_NAME}' from '{YOUR_CUSTOM_MODULE}'.")
26 |
27 | except ImportError as e:
28 | print(f"Error during monkey patching: Could not import modules. {e}")
29 | print("Please ensure that vLLM is installed and your custom module path is correct.")
30 | except AttributeError as e:
31 | print(f"Error during monkey patching: Could not find class/attribute. {e}")
32 | print("Please ensure class names and module contents are correct.")
33 | except Exception as e:
34 | print(f"An unexpected error occurred during monkey patching: {e}")
35 |
36 |
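37 | # Usage sketch (assumed workflow): the patch is applied as a side effect of
38 | # importing this module, so import it before vLLM builds the model, e.g.:
39 | #
40 | #   import model_vllm.monkey_patch_mrope  # noqa: F401  (applies the patch)
41 | #   from vllm import LLM
42 | #   llm = LLM(model=..., trust_remote_code=True)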
--------------------------------------------------------------------------------
/model_vllm/setup_vllm_env.sh:
--------------------------------------------------------------------------------
1 | eval "$(conda shell.bash hook)"
2 | conda create -y -n arc python=3.10
3 | conda activate arc
4 |
5 | pip install torch==2.6.0 torchvision==0.21.0 torchaudio==2.6.0 --index-url https://download.pytorch.org/whl/cu124
6 | pip install -r model_vllm/requirements.txt
7 | conda install -y ffmpeg
8 | pip install https://github.com/flashinfer-ai/flashinfer/releases/download/v0.2.2.post1/flashinfer_python-0.2.2.post1+cu124torch2.6-cp38-abi3-linux_x86_64.whl
9 | pip install https://github.com/Dao-AILab/flash-attention/releases/download/v2.7.4.post1/flash_attn-2.7.4.post1+cu12torch2.6cxx11abiFALSE-cp310-cp310-linux_x86_64.whl
10 |
11 |
12 | export VLLM_PRECOMPILED_WHEEL_LOCATION=$(pwd)/model_vllm/vllm/vllm-1.0.0.dev-cp38-abi3-manylinux1_x86_64.whl
13 | export VLLM_VERSION=v0.8.5.post1-1-gbed41f50d
14 | pip install --editable model_vllm/vllm/
15 |
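16 | # After installation, re-activate the environment in new shells before running
17 | # the vLLM-based inference:
18 | #   conda activate arc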
--------------------------------------------------------------------------------
/model_vllm/video_audio_encoder.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import numpy as np
4 | import math
5 | from transformers.modeling_utils import no_init_weights
6 | from transformers import WhisperFeatureExtractor, WhisperModel, AutoConfig
7 | import sys
8 | import os
9 |
10 | from transformers import ARCHunyuanVideoVisionModel, ARCHunyuanVideoAudioEncoder
11 |
12 | class VideoAudioEncoder(nn.Module):
13 | def __init__(self, config, max_num_frames=150):
14 | super().__init__()
15 | self.max_num_frames = max_num_frames
16 |
17 | config.vision_config._attn_implementation = "flash_attention_2"
18 | config.audio_config._attn_implementation = "flash_attention_2"
19 |
20 | with no_init_weights():
21 | # Initialize vision model
22 | self.vision_model = ARCHunyuanVideoVisionModel(
23 | vision_config=config.vision_config,
24 | text_config=config.text_config,
25 | )
26 |
27 | self.speech_encoder = ARCHunyuanVideoAudioEncoder(
28 | config=config.audio_config,
29 | )
30 |
31 | self.speech_dim = config.audio_config.d_model
32 |
33 | llm_hidden_size = config.text_config.hidden_size
34 |
35 | self.mlp2 = nn.Sequential(
36 | nn.LayerNorm(self.speech_dim),
37 | nn.Linear(self.speech_dim, llm_hidden_size),
38 | nn.GELU(),
39 | nn.Linear(llm_hidden_size, llm_hidden_size),
40 | )
41 |
42 | @torch.no_grad()
43 | def extract_image_feature(self, pixel_values):
44 | """Extract features from image tensors using vision model"""
45 | vit_embeds = self.vision_model(pixel_values)
46 | return vit_embeds
47 |
48 | @torch.no_grad()
49 | def extract_audio_feature(self, audio_values):
50 | """Extract features from audio tensors using speech encoder"""
51 | audio_values = audio_values.squeeze(0).reshape(
52 | -1, 128, audio_values.shape[-1]
53 | )
54 | num_segments = audio_values.shape[0]
55 |
56 | speech_embeds = self.speech_encoder(
57 | audio_values, return_dict=True
58 | ).last_hidden_state
59 |
60 | speech_embeds = speech_embeds.reshape(1, -1, speech_embeds.shape[-1])
61 | speech_embeds = self.mlp2(speech_embeds)
62 | return num_segments, speech_embeds
63 |
64 | def create_mixed_embeddings(self, vit_embeds, audio_embeds, duration):
65 | """Create mixed embeddings from visual and audio features"""
66 | # Reshape audio embeddings to match video frames
67 | audio_embeds = audio_embeds.reshape(
68 | audio_embeds.shape[0], -1, 50, audio_embeds.shape[-1]
69 | )
70 | audio_embeds_no_pad = audio_embeds[:, :duration].squeeze(0)
71 |
72 | max_num_frame = self.max_num_frames
73 |
74 | # Handle case where audio duration exceeds max number of frames
75 | if duration > max_num_frame:
76 | per_audio_tokens = math.ceil(
77 | audio_embeds_no_pad.shape[0] / max_num_frame * 50
78 | )
79 | num_audio_tokens_sum = per_audio_tokens * max_num_frame
80 | audio_embeds_no_pad = audio_embeds_no_pad.reshape(
81 | -1, audio_embeds_no_pad.shape[-1]
82 | )
83 |
84 | if num_audio_tokens_sum != audio_embeds_no_pad.shape[0]:
85 | zero_padding = (
86 | torch.zeros(
87 | num_audio_tokens_sum - audio_embeds_no_pad.shape[0],
88 | audio_embeds_no_pad.shape[-1],
89 | )
90 | .to(audio_embeds_no_pad.dtype)
91 | .to(audio_embeds_no_pad.device)
92 | )
93 | audio_embeds_no_pad = torch.cat(
94 | (audio_embeds_no_pad, zero_padding), dim=0
95 | )
96 |
97 | audio_embeds_no_pad = audio_embeds_no_pad.reshape(
98 | max_num_frame, -1, audio_embeds_no_pad.shape[-1]
99 | )
100 |
101 | # Pad or trim to match the visual embedding shape
102 | padding_size = vit_embeds.shape[1] - audio_embeds_no_pad.shape[1]
103 | if padding_size != 0:
104 | zero_padding = (
105 | torch.zeros(
106 | vit_embeds.shape[0],
107 | padding_size,
108 | audio_embeds_no_pad.shape[-1],
109 | )
110 | .to(audio_embeds_no_pad.dtype)
111 | .to(audio_embeds_no_pad.device)
112 | )
113 | audio_embeds_pad = torch.cat(
114 | (audio_embeds_no_pad, zero_padding), dim=1
115 | )
116 | else:
117 | audio_embeds_pad = audio_embeds_no_pad
118 |
119 | mixed_embeds = vit_embeds + audio_embeds_pad
120 |
121 | return mixed_embeds
122 |
123 | def forward(self, pixel_values, audio_values, duration):
124 | """
125 | Encode images and audio to create mixed embeddings
126 |
127 | Args:
128 | pixel_values (torch.Tensor): Batch of images from video (processed frames)
129 | audio_values (torch.Tensor): Processed audio features
130 | duration (int): Duration of the video in frames or seconds
131 |
132 | Returns:
133 | mixed_embeds (torch.Tensor): Mixed embeddings combining vision and audio
134 | """
135 |
136 | # Extract features
137 | vit_embeds = self.extract_image_feature(pixel_values)
138 |
139 | _, audio_embeds = self.extract_audio_feature(audio_values)
140 |
141 | # Create mixed embeddings
142 | mixed_embeds = self.create_mixed_embeddings(
143 | vit_embeds, audio_embeds, duration
144 | )
145 |
146 | return mixed_embeds
147 |
148 |
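149 | # Shape-flow sketch (illustrative; exact shapes depend on upstream preprocessing):
150 | # for a clip sampled into N frames with `duration` seconds of audio,
151 | #   vit_embeds   = vision_model(pixel_values)           # (N, tokens_per_frame, hidden)
152 | #   audio_embeds = mlp2(speech_encoder(audio_values))   # ~50 audio tokens per second,
153 | #                                                        # truncated to `duration` and
154 | #                                                        # zero-padded to tokens_per_frame
155 | #   mixed_embeds = vit_embeds + audio_embeds             # element-wise fusion per frame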
--------------------------------------------------------------------------------
/model_vllm/video_audio_llm.py:
--------------------------------------------------------------------------------
1 | # This creates a class to use the model in vLLM with an mm encoder + LLM.
2 | # The inputs should be preprocessed.
3 | # The mm encoder processes inputs sequentially and the LLM does batched inference.
4 | # This may be less efficient, but it is much clearer and easier to use.
5 |
6 | import os
7 | import json
8 | from pathlib import Path
9 | import tempfile
10 | import shutil
11 |
12 | import torch
13 | from huggingface_hub import snapshot_download
14 |
15 | from transformers import AutoTokenizer, AutoConfig, PretrainedConfig
16 | from vllm import LLM, SamplingParams
17 | from safetensors.torch import load_file as safetensors_load_file
18 |
19 | from model_vllm import VideoAudioEncoder
20 |
21 |
22 | def convert_config_to_legacy(config, max_model_len):
23 | legacy_config = PretrainedConfig()
24 |
25 | legacy_config.update(config.vision_config.to_dict())
26 | legacy_config.update(config.text_config.to_dict())
27 |
28 | force_image_size = config.vision_config.force_image_size
29 | num_image_token = int(
30 | (force_image_size / 64)
31 | * (force_image_size / 64 + 1)
32 | + 2
33 | )
34 |
35 | legacy_config.update({
36 | # Such that vLLM can calculate the max_model_len correctly
37 | "architectures": ["HunyuanVideoModel"],
38 | "image_token_id": config.text_config.image_token_id,
39 | "vision_start_token_id": config.text_config.im_start_id,
40 | "num_image_token": num_image_token,
41 | "rope_scaling": {
42 | "alpha": 1000.0,
43 | "beta_fast": 32,
44 | "beta_slow": 1,
45 | "factor": 1000.0,
46 | "mscale": 1.0,
47 | "mscale_all_dim": 1.0,
48 | "rope_type": "dynamic",
49 | "mrope_section": [0.25, 0.25, 0.25, 0.25],
50 | },
51 | "max_model_len": max_model_len,
52 | })
53 |
54 | if hasattr(legacy_config, "torch_dtype") and isinstance(legacy_config.torch_dtype, str):
55 | # Convert string torch_dtype to torch.dtype
56 | legacy_config.torch_dtype = getattr(torch, legacy_config.torch_dtype)
57 |
58 | return legacy_config
59 |
60 |
61 | def load_state_dict_from_safetensors(path: str, prefixes: list[str]):
62 | def filter_dict_with_k_prefix(d, prefixes):
63 | return {
64 | k: v
65 | for k, v in d.items()
66 | if any(k.startswith(prefix) for prefix in prefixes)
67 | }
68 |
69 | index_path = os.path.join(path, "model.safetensors.index.json")
70 | if not os.path.exists(index_path):
71 | print(f"Index file {index_path} does not exist, loading all weights")
72 | pre_trained_dir = Path(path)
73 | weights_files = sorted(pre_trained_dir.glob("model-*.safetensors"))
74 | else:
75 | weight_map = json.load(open(index_path))["weight_map"]
76 | weights_files = set(
77 | filter_dict_with_k_prefix(weight_map, prefixes).values()
78 | )
79 | weights_files = [os.path.join(path, f) for f in weights_files]
80 |
81 | if len(weights_files) == 0:
82 | raise ValueError(
83 | f"No weights files found in {path} with prefixes {prefixes}"
84 | )
85 |
86 | state_dict = {}
87 | for file in weights_files:
88 | part_state_dict = safetensors_load_file(file)
89 | state_dict.update(part_state_dict)
90 |
91 | state_dict = filter_dict_with_k_prefix(state_dict, prefixes)
92 | return state_dict
93 |
94 |
95 | class VideoAudioLLM:
96 | def __init__(
97 | self,
98 | model_path,
99 | device_enc="cuda",
100 | device_llm="cuda",
101 | **kwargs,
102 | ):
103 | if not os.path.isdir(model_path):
104 | model_path = snapshot_download(repo_id=model_path)
105 |
106 | self.config = AutoConfig.from_pretrained(model_path)
107 | self.device_enc = device_enc
108 | self.device_llm = device_llm
109 |
110 | self.llm, self.sampling_params = self.init_llm(model_path, self.config, self.device_llm, **kwargs)
111 |
112 | self.mm_encoder = self.init_mm_encoder(model_path, self.config, self.device_enc)
113 |
114 |
115 | def init_mm_encoder(self, model_path, config, device):
116 | multi_modal_state_dict = load_state_dict_from_safetensors(
117 | model_path, ("vision_model.", "mlp2.", "speech_encoder.")
118 | )
119 |
120 | multi_modal_encoder = VideoAudioEncoder(
121 | config,
122 | max_num_frames=config.max_num_frame,
123 | )
124 |
125 | missing, unexpected = multi_modal_encoder.load_state_dict(
126 | multi_modal_state_dict, strict=False
127 | )
128 | assert len(missing) == 0, f"Missing keys in mm encoder: {missing}"
129 | assert (
130 | len(unexpected) == 0
131 | ), f"Unexpected keys in mm encoder: {unexpected}"
132 |
133 | multi_modal_encoder.eval()
134 | multi_modal_encoder.to(device)
135 |
136 | return multi_modal_encoder
137 |
138 | def init_llm(self, model_path, config, device, **kwargs):
139 |
140 | if self.device_enc != self.device_llm:
141 | gpu_memory_utilization = 0.9
142 | else: # Reserve memory for the encoder
143 | gpu_memory_utilization = 0.6
144 |
145 | max_model_len = 20480
146 |
147 | llm = LLM(
148 | model=model_path,
149 | tokenizer=model_path,
150 | trust_remote_code=True,
151 | max_model_len=max_model_len,
152 | max_seq_len_to_capture=max_model_len,
153 | dtype="bfloat16",
154 | hf_overrides=lambda x: convert_config_to_legacy(x, max_model_len),
155 | limit_mm_per_prompt={"image": 150},
156 | enforce_eager=False,
157 | disable_mm_preprocessor_cache=True,
158 | enable_prefix_caching=False,
159 | device=device,
160 | gpu_memory_utilization=gpu_memory_utilization,
161 | )
162 |
163 | sampling_params = SamplingParams(
164 | **kwargs,
165 | )
166 |
167 | return llm, sampling_params
168 |
169 | def forward_mm_encoder(self, batch):
170 | """
171 | Run the batch of data through the mm encoder.
172 | Input:
173 | - batch: list of dicts, each dict contains the following keys:
174 | - pixel_values: torch.Tensor
175 | - audio_values: torch.Tensor
176 | - duration: float
177 |
178 | Output:
179 | - list of dicts, each dict contains the following keys:
180 | - embeddings: torch.Tensor
181 | - other keys in the original dict
182 | """
183 | device = self.device_enc
184 | ret = []
185 |
186 | for data in batch:
187 | pixel_values = data["pixel_values"]
188 | audio_values = data["audio_values"]
189 | duration = data["duration"]
190 |
191 | with torch.no_grad(), torch.autocast(device, torch.bfloat16):
192 | pixel_values = pixel_values.to(
193 | device=device, dtype=torch.bfloat16, non_blocking=True
194 | )
195 | audio_values = audio_values.to(
196 | device=device, dtype=torch.bfloat16, non_blocking=True
197 | )
198 |
199 | mixed_embeds = self.mm_encoder(
200 | pixel_values, audio_values, duration
201 | )
202 |
203 | mixed_embeds = mixed_embeds.to(device="cpu").float().share_memory_()
204 | ret.append({"embeddings": mixed_embeds, **data})
205 |
206 |
207 | return ret
208 |
209 | def forward_llm(self, batch):
210 | num_patches = (
211 | self.config.vision_config.force_image_size // 32 // 2
212 | )
213 | image_grid_thw = torch.tensor([[1, num_patches, num_patches + 1]])
214 | prompts = [
215 | {
216 | "prompt": "<|startoftext|>" + item["text_prompt"],
217 | "multi_modal_data": {
218 | "image": {
219 | "image_embeds": item["embeddings"],
220 | "image_grid_thw": image_grid_thw.repeat(
221 | item["embeddings"].shape[0], 1
222 | ),
223 | }
224 | },
225 | }
226 | for item in batch
227 | ]
228 |
229 | outputs = self.llm.generate(prompts, self.sampling_params, use_tqdm=False)
230 |
231 | ret = []
232 |
233 | for data, output in zip(batch, outputs):
234 | if "output" in data:
235 | raise ValueError("Check the batch, there is a key called output")
236 |
237 | ret.append({"output": output.outputs[0].text, **data})
238 |
239 | return ret
240 |
241 | def __call__(self, batch):
242 | if isinstance(batch, dict):
243 | batch = [batch]
244 |
245 | ret = self.forward_mm_encoder(batch)
246 | ret = self.forward_llm(ret)
247 | return ret
248 |
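249 | # Example usage (illustrative sketch; pixel_values / audio_values are assumed to be
250 | # preprocessed upstream, and the extra kwargs are forwarded to vllm.SamplingParams):
251 | #
252 | #   model = VideoAudioLLM(
253 | #       "TencentARC/ARC-Hunyuan-Video-7B",
254 | #       temperature=0.0,
255 | #       max_tokens=1024,
256 | #   )
257 | #   result = model({
258 | #       "pixel_values": pixel_values,   # sampled video frames
259 | #       "audio_values": audio_values,   # Whisper log-mel features
260 | #       "duration": duration,           # clip duration used for audio/frame alignment
261 | #       "text_prompt": text_prompt,     # prompt containing the image placeholders
262 | #   })
263 | #   print(result[0]["output"])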
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | tokenizers==0.21
2 | sentencepiece
3 | shortuuid
4 | accelerate
5 | peft>=0.4.0
6 | bitsandbytes==0.41.0
7 | pydantic
8 | markdown2[all]
9 | numpy<2
10 | scikit-learn>=1.2.2
11 | gradio==3.35.2
12 | gradio_client==0.2.9
13 | requests
14 | httpx==0.24.0
15 | uvicorn
16 | fastapi
17 | deepspeed==0.14.4
18 | einops
19 | einops-exts
20 | timm==0.9.12
21 | decord
22 | tiktoken
23 | librosa
24 | datasets
25 | opencv-python
26 | imageio[ffmpeg]
27 | tensorboardX
28 | av
29 | moviepy==1.0.3
30 |
--------------------------------------------------------------------------------
/scripts/arc_hunyuan_video_full_finetune.sh:
--------------------------------------------------------------------------------
1 | set -x
2 |
3 | export NCCL_IB_GID_INDEX=3
4 | export NCCL_IB_SL=3
5 | export NCCL_CHECKS_DISABLE=1
6 | export NCCL_P2P_DISABLE=0
7 | export NCCL_IB_DISABLE=0
8 | export NCCL_LL_THRESHOLD=16384
9 | export NCCL_IB_CUDA_SUPPORT=1
10 | export NCCL_SOCKET_IFNAME=bond1
11 | export UCX_NET_DEVICES=bond1
12 | export NCCL_IB_HCA=mlx5
13 | export NCCL_COLLNET_ENABLE=0
14 | export SHARP_COLL_ENABLE_SAT=0
15 | export NCCL_NET_GDR_LEVEL=2
16 | export NCCL_IB_QPS_PER_CONNECTION=4
17 | export NCCL_IB_TC=160
18 | export NCCL_PXN_DISABLE=1
19 | export GLOO_SOCKET_IFNAME=bond1
20 | export NCCL_DEBUG=info
21 |
22 | export PYTHONPATH="${PYTHONPATH}:$(pwd)"
23 | export MASTER_PORT=8005
24 | export TF_CPP_MIN_LOG_LEVEL=3
25 | export LAUNCHER=pytorch
26 |
27 | OUTPUT_DIR='work_dirs/brief_summary_sft'
28 | if [ ! -d "$OUTPUT_DIR" ]; then
29 | mkdir -p "$OUTPUT_DIR"
30 | fi
31 |
32 | NODE_RANK=$1
33 | torchrun --nproc_per_node=2 \
34 | model_train/train/arc_hunyuan_video_finetune.py \
35 | --model_name_or_path "TencentARC/ARC-Hunyuan-Video-7B" \
36 | --conv_style "hunyuan" \
37 | --output_dir ${OUTPUT_DIR} \
38 | --meta_path "sft_data/sft_jb_sp_kd_10.json" \
39 | --overwrite_output_dir True \
40 | --force_image_size 640 \
41 | --num_image_token 112 \
42 | --freeze_llm False \
43 | --freeze_speech_encoder True \
44 | --freeze_backbone True \
45 | --dataloader_num_workers 4 \
46 | --bf16 True \
47 | --num_train_epochs 4 \
48 | --per_device_train_batch_size 1 \
49 | --gradient_accumulation_steps 1 \
50 | --save_strategy "steps" \
51 | --save_steps 500 \
52 | --save_total_limit 100 \
53 | --learning_rate 1e-5 \
54 | --warmup_steps 100 \
55 | --weight_decay 0.01 \
56 | --warmup_ratio 0.03 \
57 | --lr_scheduler_type "cosine" \
58 | --logging_steps 1 \
59 | --max_seq_length 20000 \
60 | --max_num_frame 150 \
61 | --do_train True \
62 | --grad_checkpoint True \
63 | --dynamic_image_size False \
64 | --normalize_type hunyuan \
65 | --seed 42 \
66 | --deepspeed "config/zero_stage3_config.json" \
67 | --report_to "tensorboard" \
68 | 2>&1 | tee -a "${OUTPUT_DIR}/training_log.txt"
69 |
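70 | # Example launch (single node, 2 GPUs, node rank 0):
71 | #   bash scripts/arc_hunyuan_video_full_finetune.sh 0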
--------------------------------------------------------------------------------
/sft_data/audios_mp3/a3545skvqbz.mp3:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TencentARC/ARC-Hunyuan-Video-7B/a061f1bf6a4d75f032759fec69c80c71077c3d0c/sft_data/audios_mp3/a3545skvqbz.mp3
--------------------------------------------------------------------------------
/sft_data/audios_mp3/c3522vbgwaw.mp3:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TencentARC/ARC-Hunyuan-Video-7B/a061f1bf6a4d75f032759fec69c80c71077c3d0c/sft_data/audios_mp3/c3522vbgwaw.mp3
--------------------------------------------------------------------------------
/sft_data/audios_mp3/e3556d48uo0.mp3:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TencentARC/ARC-Hunyuan-Video-7B/a061f1bf6a4d75f032759fec69c80c71077c3d0c/sft_data/audios_mp3/e3556d48uo0.mp3
--------------------------------------------------------------------------------
/sft_data/audios_mp3/q35134y59x8.mp3:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TencentARC/ARC-Hunyuan-Video-7B/a061f1bf6a4d75f032759fec69c80c71077c3d0c/sft_data/audios_mp3/q35134y59x8.mp3
--------------------------------------------------------------------------------
/sft_data/audios_mp3/u3519j52lb4.mp3:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TencentARC/ARC-Hunyuan-Video-7B/a061f1bf6a4d75f032759fec69c80c71077c3d0c/sft_data/audios_mp3/u3519j52lb4.mp3
--------------------------------------------------------------------------------
/sft_data/audios_mp3/v3524wr6l4l.mp3:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TencentARC/ARC-Hunyuan-Video-7B/a061f1bf6a4d75f032759fec69c80c71077c3d0c/sft_data/audios_mp3/v3524wr6l4l.mp3
--------------------------------------------------------------------------------
/sft_data/audios_mp3/w33698kgs05.mp3:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TencentARC/ARC-Hunyuan-Video-7B/a061f1bf6a4d75f032759fec69c80c71077c3d0c/sft_data/audios_mp3/w33698kgs05.mp3
--------------------------------------------------------------------------------
/sft_data/audios_mp3/x3551nmkn8o.mp3:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TencentARC/ARC-Hunyuan-Video-7B/a061f1bf6a4d75f032759fec69c80c71077c3d0c/sft_data/audios_mp3/x3551nmkn8o.mp3
--------------------------------------------------------------------------------
/sft_data/audios_mp3/x3555e2g3t8.mp3:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TencentARC/ARC-Hunyuan-Video-7B/a061f1bf6a4d75f032759fec69c80c71077c3d0c/sft_data/audios_mp3/x3555e2g3t8.mp3
--------------------------------------------------------------------------------
/sft_data/audios_mp3/z1468vawe14.mp3:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TencentARC/ARC-Hunyuan-Video-7B/a061f1bf6a4d75f032759fec69c80c71077c3d0c/sft_data/audios_mp3/z1468vawe14.mp3
--------------------------------------------------------------------------------
/sft_data/sft_jb_sp_abs_10.jsonl:
--------------------------------------------------------------------------------
1 | {"id": 0, "video": "w33698kgs05.mp4", "conversations": [{"from": "human", "value": "你是一个视频内容总结助手,你需要按照如下规则根据视频及其标题生成简要描述:\n1. 请以不超过100字的文本简要描述视频的主要内容,请仅准确反映视频中的核心内容,不要引入任何视频中未出现的信息\n2. 请在简要描述中保留视频中的核心人物、事件、场景及可能引人关注的信息\n视频为: