├── README.md
├── figures
│   ├── arch.png
│   ├── result-dfec.png
│   ├── result-emotion.png
│   └── result-mvbench.png
├── humanomni
│   ├── __init__.py
│   ├── constants.py
│   ├── conversation.py
│   ├── conversation_llava.py
│   ├── eval
│   │   ├── eval_mafw_dfew.py
│   │   ├── eval_video_mcqa_mvbench.py
│   │   ├── inference_dfec.py
│   │   └── inference_video_mcqa_mvbench.py
│   ├── humanomni_trainer.py
│   ├── mm_utils.py
│   ├── model
│   │   ├── __init__.py
│   │   ├── encoder.py
│   │   ├── humanomni_arch.py
│   │   ├── humanomni_model.py
│   │   └── projector.py
│   ├── train_flash_attn.py
│   ├── train_humanomni.py
│   └── utils.py
├── inference.py
├── requirements.txt
└── scripts
    ├── eval
    │   └── eval_video_mcqa_mvbench.sh
    ├── train
    │   └── finetune_omni.sh
    └── zero3.json
/README.md:
--------------------------------------------------------------------------------
1 | # HumanOmni: A Large Vision-Speech Language Model for Human-Centric Video Understanding
2 |
3 | [ModelScope](https://modelscope.cn/models/iic/HumanOmni-7B)
4 | [Hugging Face](https://huggingface.co/StarJiaxing/HumanOmni-7B)
5 | [arXiv](https://arxiv.org/abs/2501.15111)
6 |
7 |
8 |
9 |
10 |
11 | ## News
12 | 1) Building upon HumanOmni, we are the **first to combine RLVR (Reinforcement Learning with Verifiable Rewards) with an Omni model**, introducing [R1-Omni](https://github.com/HumanMLLM/R1-Omni), a reasoning-based large multimodal model.
13 |
14 |
15 | ## 📖 Introduction
16 | **HumanOmni** is the industry's first human-centric Omni-multimodal large language model for comprehensive understanding of human-centric scenes.
17 | 1) **2.4M human-centric video clips with over 14M double-checked instructions**: We have constructed a dataset containing over 2.4M human-centric video clips, providing rich and detailed information about individuals, together with over 14M instruction samples for visual pretraining.
18 | 2) **50K video clips with more than 100K manually annotated instructions**: We have manually annotated 50K video clips with more than 100K instructions related to emotion recognition, facial description, and speaker-specific speech recognition for visual fine-tuning and cross-modal interaction integration.
19 | 3) **Three human-specific branches**: We use three branches to handle face-related, body-related, and interaction-related scenes separately in HumanOmni. HumanOmni dynamically adjusts its fusion weights based on the input instruction, ensuring accurate responses across various scenes.
20 | 4) **Audio-visual synergy**: HumanOmni can simultaneously understand vision and speech, allowing for a more comprehensive understanding of complex scenes.
21 |
22 |
23 |
24 | ## 📦 Model Download
25 |
26 |
27 | | **Model** | **Stage** | **#Params** | **HuggingFace** | **ModelScope** |
28 | |------------------------|------------------------------------|-------------|---------------------------------------------------------------------------------|-------------------------------------------------------------------------|
29 | | `HumanOmni-Video`       | Visual Capability Construction      | 7B          | [HumanOmni-7B-Video](https://hf.co/StarJiaxing/HumanOmni-7B-Video) | [HumanOmni-7B-Video](https://modelscope.cn/models/iic/HumanOmni-7B-Video) |
30 | | `HumanOmni-Audio`       | Auditory Capability Development     | 7B          | [HumanOmni-7B-Audio](https://hf.co/StarJiaxing/HumanOmni-7B-Audio) | [HumanOmni-7B-Audio](https://modelscope.cn/models/iic/HumanOmni-7B-Audio) |
31 | | `HumanOmni-Omni`        | Cross-Modal Interaction Integration | 7B          | [HumanOmni-7B](https://hf.co/StarJiaxing/HumanOmni-7B) | [HumanOmni-7B](https://modelscope.cn/models/iic/HumanOmni-7B) |
32 |
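
The checkpoints can also be fetched programmatically. The snippet below is a minimal sketch (not part of this repository); it relies only on the repo IDs listed in the table above and the public `huggingface_hub` / `modelscope` download APIs.

```
# Minimal download sketch -- only the repo IDs from the table above are taken from this README.
from huggingface_hub import snapshot_download

local_dir = snapshot_download(repo_id="StarJiaxing/HumanOmni-7B")  # full audio-visual model
print("HumanOmni-7B downloaded to:", local_dir)

# ModelScope alternative:
# from modelscope import snapshot_download
# local_dir = snapshot_download("iic/HumanOmni-7B")
```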
33 |
34 |
35 | Our training pipeline consists of three progressive stages to establish multimodal understanding capabilities:
36 |
37 | 📹 Visual Capability Construction
38 |
39 | - Model: HumanOmni-Video
40 | - Objective: Learn spatio-temporal feature representations to analyze human actions and scene dynamics in videos.
41 |
42 | 🎧 Auditory Capability Development
43 |
44 | - Model: HumanOmni-Audio
45 | - Objective: Develop robust speech comprehension and audio interpretation through large-scale acoustic modeling.
46 |
47 | 🌐 Cross-Modal Interaction Integration
48 |
49 | - Model: HumanOmni-Omni (also referred to as HumanOmni)
50 | - Objective: Enable synergistic vision-audio reasoning by fine-tuning parameters from both HumanOmni-Video and HumanOmni-Audio.
51 |
52 |
53 | ## 🏆 Performance
54 |
55 | - Emotion Understanding:
56 |
57 | | Method | Modalities | DFEW (UAR) | DFEW (WAR) | MAFW (UAR) | MAFW (WAR) |
58 | |----------------------------------|------------|-------------|-------------|-------------|-------------|
59 | | **Specialized models for emotion-related tasks** | | | | | |
60 | | Wav2Vec2.0 | A | 36.15 | 43.05 | 21.59 | 29.69 |
61 | | HuBERT | A | 35.98 | 43.24 | 25.00 | 32.60 |
62 | | DFER-CLIP | V | 59.61 | 71.25 | 38.89 | 52.55 |
63 | | MAE-DFER | V | 63.41 | 74.43 | 41.62 | 54.31 |
64 | | HiCMAE | AV | 63.76 | 75.01 | 42.65 | 56.17 |
65 | | Emotion-LLaMA | AV | 64.21 | 77.06 | - | - |
66 | | MMA-DFER | AV | 66.85 | 77.43 | 44.25 | 58.45 |
67 | | **Other models** | | | | | |
68 | | Qwen2-VL-7B | V | 43.08 | 52.83 | 31.67 | 45.89 |
69 | | Qwen2-VL-72B | V | 39.24 | 45.12 | 42.61 | 46.07 |
70 | | VITA | AV | 21.36 | 32.07 | 14.05 | 33.38 |
71 | | InternLM-XComposer-2.5-OL | AV | 44.23 | 51.29 | 33.78 | 46.81 |
72 | | GPT4-O | AV | 50.57 | 57.19 | 38.29 | 48.82 |
73 | | **HumanOmni** | AV | **74.86** | **82.46** | **52.94** | **68.40** |
74 |
75 | - [Dynamic Facial Expression Caption](https://modelscope.cn/datasets/iic/DFEC):
76 |
77 | | Method | Correctness | Detail | Context | Temporal | CIDEr | Rouge-L | AutoDQ |
78 | |----------------------------------|-------------|--------|---------|----------|--------|---------|---------|
79 | | **Vision large language model** | | | | | | | |
80 | | VideoLLaMA | 3.60 | 3.67 | 3.84 | 3.50 | 0.189 | 0.196 | 0.303 |
81 | | VideoChat | 3.47 | 3.52 | 3.92 | 3.38 | 0.251 | 0.192 | 0.344 |
82 | | VideoChat2 | 3.70 | 3.56 | 4.16 | 3.52 | 0.202 | 0.229 | 0.311 |
83 | | Chat-UniVI | 3.64 | 3.63 | 4.21 | 3.61 | 0.189 | 0.231 | 0.396 |
84 | | LLaVA-Next-Video | 4.19 | 4.07 | 4.39 | 4.04 | 0.250 | 0.249 | 0.395 |
85 | | ShareGPT4Video | 4.24 | 4.13 | 4.35 | 4.09 | 0.192 | 0.205 | 0.394 |
86 | | LLaMA-VID | 3.95 | 4.01 | 4.22 | 3.71 | 0.195 | 0.231 | 0.339 |
87 | | VideoLLaMA2 | 4.17 | 4.02 | 4.47 | 3.93 | 0.253 | 0.266 | 0.344 |
88 | | PLLaVA | 4.21 | 4.15 | 4.37 | 4.08 | 0.268 | 0.250 | 0.393 |
89 | | ST-LLM | 4.00 | 3.98 | 4.31 | 3.94 | 0.213 | 0.238 | 0.321 |
90 | | Tarsier | 3.59 | 3.50 | 4.07 | 3.41 | 0.143 | 0.185 | 0.415 |
91 | | LLaVA-OneVision | 3.68 | 3.47 | 4.10 | 3.42 | 0.115 | 0.165 | 0.379 |
92 | | FaceTrack-MM | 4.42 | 4.30 | 4.60 | 4.26 | 0.418 | 0.473 | 0.483 |
93 | | Qwen2-VL-72B | 4.28 | 4.14 | 4.55 | 4.08 | 0.241 | 0.314 | 0.449 |
94 | | Qwen2-VL-7B | 4.23 | 4.16 | 4.52 | 4.02 | 0.204 | 0.233 | 0.422 |
95 | | Qwen2-VL-2B | 4.01 | 3.98 | 4.37 | 3.88 | 0.202 | 0.221 | 0.406 |
96 | | Claude3.5-Sonnet | 4.13 | 4.01 | 4.49 | 4.05 | 0.243 | 0.228 | 0.442 |
97 | | **Omni-modality large language model** | | | | | | | |
98 | | GPT4-O | 4.22 | 3.97 | 4.48 | 3.90 | 0.264 | 0.213 | 0.432 |
99 | | VITA | 3.98 | 3.74 | 4.11 | 3.59 | 0.191 | 0.224 | 0.366 |
100 | | InternLM-XComposer-2.5-OL | 3.91 | 3.70 | 4.12 | 3.54 | 0.113 | 0.164 | 0.382 |
101 | | **HumanOmni** | **4.58** | **4.41**| **4.70**| **4.41** | 0.412 | 0.468 | **0.523**|
102 |
103 | - Action and Pose Understanding:
104 |
105 | | Method | Action Sequence | Unexpected Action | Action Antonym | Object Interaction | Action Count | Fine-grained Action | Avg |
106 | |----------------------------------|-----|-----|-----|-----|-----|-----|------|
107 | | **Vision large language model** | | | | | | | |
108 | | Otter-V | 23.0| 29.5| 27.5| 28.0| 26.0| 27.0| 26.8 |
109 | | mPLUG-Owl-V | 22.0| 29.0| 34.0| 27.0| 31.5| 29.0| 28.8 |
110 | | Video-LLaMA | 27.5| 39.0| 51.0| 40.5| 34.0| 29.0| 36.8 |
111 | | LLaMA-Adapter | 23.0| 33.0| 51.0| 32.5| 29.0| 30.0| 33.1 |
112 | | Video-ChatGPT | 23.5| 26.5| 62.0| 28.0| 30.5| 22.5| 32.2 |
113 | | VideoChat | 33.5| 40.5| 56.0| 40.5| 35.0| 33.5| 39.8 |
114 | | VideoChat2 | 75.5| 60.5| 83.5| 74.5| 37.0| 50.5| 63.6 |
115 | | ST-LLM | 66.0| 58.5| 84.0| 73.5| 36.5| 44.0| 60.4 |
116 | | PLLaVA | 58.0| 61.0| 55.5| 61.0| 39.5| 41.0| 52.6 |
117 | | VideoLLaMB | 54.5| 52.0| 86.5| 58.5| 40.5| 44.5| 56.1 |
118 | | Qwen2-VL-72B* | 51.5| 82.0| 93.5| 81.5| 48.5| 49.0| 67.7 |
119 | | Qwen2-VL-7B* | 73.5| 80.0| 79.0| 78.5| 46.0| 49.0| 67.7 |
120 | | Qwen2-VL-2B* | 77.5| 76.5| 76.5| 77.5| 50.0| 47.5| 67.6 |
121 | | GPT-4V | 55.5| 63.5| 72.0| 59.0| 39.0| 47.5| 56.1 |
122 | | **Omni-modality large language model** | | | | | | | |
123 | | VITA | 58.0| 81.5| 73.5| 61.5| 45.5| 42.0| 60.3 |
124 | | InternLM-XComposer-2.5-OL | 84.5| 81.0| 75.0| 79.5| 60.5| 46.0| 71.1 |
125 | | **HumanOmni** | 70.0| 78.0| 92.5| 80.5| 65.5| 49.0| **72.6** |
126 |
127 | ## 📺 Demo
128 | https://github.com/user-attachments/assets/ecf07e98-f448-4cb6-a124-6e2a2885e581
129 | ```
130 | User:What's the major emotion?
131 | HumanOmni:angry.
132 |
133 | User:What's the major facial expression?
134 | HumanOmni:frown.
135 |
136 | User:Describe the woman.
137 | HumanOmni:The woman has long brown hair and is wearing a pink top with a dark cardigan.
138 |
139 | User:Please describe the major action in detail.
140 | HumanOmni:The woman is frowning and shaking her head, indicating that she is angry. She is also gesturing with her hands as if she is explaining something or emphasizing a point.
141 |
142 | User:What did the woman say?
143 | HumanOmni:The woman said "Of course we object! I object, Steve!"
144 | ```
145 |
146 | https://github.com/user-attachments/assets/0f0d9c5c-6e1d-473f-8c6b-819ebdc8aae8
147 | ```
148 | User:What's the major emotion?
149 | HumanOmni:happy.
150 |
151 | User:What's the major facial expression?
152 | HumanOmni:smile.
153 |
154 | User:Describe the woman in detail.
155 | HumanOmni:The woman in the video has blonde hair, is wearing a black top, and is smiling.
156 |
157 | User:What did the woman say?
158 | HumanOmni:The woman said "Dogs are sitting by the door."
159 | ```
160 |
161 | [https://humanmllm.github.io/HumanOmni/chinese_video.mp4](https://github.com/user-attachments/assets/65b49d37-7213-452d-a492-cfddbbc8b41d)
162 | ```
163 | User:Describe the appearance and clothing of the man in the video.
164 | HumanOmni:The man in the video has short black hair and wears round-framed glasses. He is wearing a brown wool coat over a white inner layer.
165 |
166 | User:What did the man say in the video?
167 | HumanOmni:The man talked about what color of clothes he is wearing, what environment he is currently in, and what the weather is like today.
168 |
169 | User:Does the man in the video show any emotion?
170 | HumanOmni:No emotion.
171 | ```
172 |
173 | ## 🛠️ Environment Setup
174 |
175 | To set up the recommended environment for HumanOmni, follow these instructions:
176 |
177 | ### Recommended Environment
178 | - **Python**: >=3.10
179 | - **CUDA**: >=12.1
180 | - **PyTorch**: >=2.2 (with CUDA support)
181 | - **Transformers**: >=4.45
182 | - **Accelerate**: >=0.30.1
183 |
184 | Or you can quickly set up the environment as follows:
185 |
186 | ```
187 | git clone https://github.com/HumanMLLM/HumanOmni
188 | cd HumanOmni
189 | conda create -n humanOmni python=3.10 -y
190 | conda activate humanOmni
191 | pip install --upgrade pip
192 | pip install -r requirements.txt
193 | pip install flash-attn --no-build-isolation
194 | ```
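
After installation, a quick sanity check (a minimal sketch, not part of the repo) confirms that the core dependencies match the recommended versions:

```
# Quick environment sanity check (illustrative only).
import torch, transformers, accelerate

print("torch:", torch.__version__, "| CUDA available:", torch.cuda.is_available())
print("transformers:", transformers.__version__)
print("accelerate:", accelerate.__version__)
```
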
195 | ## 🧠 Training on Custom Dataset
196 | ### Data Preparation
197 | An example JSON file of the training data is shown below, followed by a short loading check:
198 | ```
199 | [
200 | {
201 | "video": "human/DFEW/videos/1.mp4",
202 | "conversations": [
203 | {
204 | "from": "human",
205 |                 "value": "<video>\n<audio>\nAs an emotional recognition expert; throughout the video, which emotion conveyed by the characters is the most obvious to you?\nfear ,angry ,surprise ,happy ,neutral ,sad ,disgust"
206 | },
207 | {
208 | "from": "gpt",
209 | "value": "sad"
210 | }
211 |         ]
212 | },
213 | {
214 | "video": "human/DFEW/videos/1.mp4",
215 | "conversations": [
216 | {
217 | "from": "human",
218 |                 "value": "<video>\n<audio>\nAs an emotional recognition expert, in the video, when the characters display their emotions, which predominant feeling is most clearly expressed?\nfear ,disgust ,happy ,sad ,surprise"
219 | },
220 | {
221 | "from": "gpt",
222 | "value": "sad"
223 | }
224 |         ]
225 | },
226 | ...
227 | ]
228 | ```
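
Since the annotation file is a plain JSON list, it can be loaded and sanity-checked with a few lines of Python before training (a minimal sketch; the file name below is hypothetical):

```
import json

# hypothetical annotation file following the format above
with open("train_annotations.json", "r", encoding="utf-8") as f:
    samples = json.load(f)

for sample in samples[:3]:
    assert "video" in sample and "conversations" in sample
    turns = sample["conversations"]
    # turns alternate between "human" prompts and "gpt" answers
    assert turns[0]["from"] == "human" and turns[-1]["from"] == "gpt"
    print(sample["video"], "->", turns[-1]["value"])
```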
229 |
230 | ### Multi-Modal SFT
231 | - Download the required weights: (1) [HumanOmni-7B-Video](https://modelscope.cn/models/iic/HumanOmni-7B-Video) (2) [HumanOmni-7B-Audio](https://modelscope.cn/models/iic/HumanOmni-7B-Audio)
232 | - Edit `scripts/train/finetune_omni.sh` to set the paths of the downloaded weights and the prepared dataset.
233 | - Run `bash scripts/train/finetune_omni.sh`.
234 |
235 | ## 🔍 Inference
236 | We provide `inference.py` for single-video inference; a Python-level usage sketch follows the command examples below.
237 | - video + audio
238 | ```
239 | python inference.py --modal video_audio \
240 | --model_path ./HumanOmni_7B \
241 | --video_path video.mp4 \
242 | --instruct "Describe this video."
243 | ```
244 | - only video
245 | ```
246 | python inference.py --modal video \
247 | --model_path ./HumanOmni_7B \
248 | --video_path video.mp4 \
249 | --instruct "Describe this video."
250 | ```
251 | - only audio
252 | ```
253 | python inference.py --modal audio \
254 | --model_path ./HumanOmni_7B \
255 | --video_path video.mp4 \
256 | --instruct "Describe this video."
257 | ```
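
The same pipeline can also be driven from Python through `model_init` and `mm_infer` in `humanomni/__init__.py`. The sketch below mirrors the video-only CLI call above; the exact argument accepted by the video processor (a file path, as in `inference.py`) is an assumption.

```
# Python-level inference sketch based on humanomni/__init__.py (video-only query).
from humanomni import model_init, mm_infer

model, processor, tokenizer = model_init("./HumanOmni_7B")

# preprocess the clip; `processor` is the dict returned by model_init
video_tensor = processor["video"]("video.mp4")
response = mm_infer(video_tensor, "Describe this video.",
                    model=model, tokenizer=tokenizer, modal="video")
print(response)
```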
258 |
259 | ## 🤝 Related Work
260 | - [LLaVA-Octopus: Unlocking Instruction-Driven Adaptive Projector Fusion for Video Understanding](https://arxiv.org/abs/2501.05067)
261 | - [Omni-Emotion: Extending Video MLLM with Detailed Face and Audio Modeling for Multimodal Emotion Analysis](https://arxiv.org/abs/2501.09502)
262 | - [Qwen2.5](https://github.com/QwenLM/Qwen2.5)
263 |
264 | ## 📚 Citation
265 | If you find our work helpful, please cite it as follows:
266 | ```
267 | @article{zhao2025humanomni,
268 | title={HumanOmni: A Large Vision-Speech Language Model for Human-Centric Video Understanding},
269 | author={Zhao, Jiaxing and Yang, Qize and Peng, Yixing and Bai, Detao and Yao, Shimin and Sun, Boyuan and Chen, Xiang and Fu, Shenghao and Wei, Xihan and Bo, Liefeng and others},
270 | journal={arXiv preprint arXiv:2501.15111},
271 | year={2025}
272 | }
273 | ```
274 |
--------------------------------------------------------------------------------
/figures/arch.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HumanMLLM/HumanOmni/26fa491492d39a66eef0d9e805c7bf33bf2cb0ee/figures/arch.png
--------------------------------------------------------------------------------
/figures/result-dfec.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HumanMLLM/HumanOmni/26fa491492d39a66eef0d9e805c7bf33bf2cb0ee/figures/result-dfec.png
--------------------------------------------------------------------------------
/figures/result-emotion.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HumanMLLM/HumanOmni/26fa491492d39a66eef0d9e805c7bf33bf2cb0ee/figures/result-emotion.png
--------------------------------------------------------------------------------
/figures/result-mvbench.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HumanMLLM/HumanOmni/26fa491492d39a66eef0d9e805c7bf33bf2cb0ee/figures/result-mvbench.png
--------------------------------------------------------------------------------
/humanomni/__init__.py:
--------------------------------------------------------------------------------
1 | import os
2 | import copy
3 | import warnings
4 | import shutil
5 | from functools import partial
6 |
7 | import torch
8 |
9 | from .model import load_pretrained_model
10 | from .mm_utils import process_image, process_video, process_audio, tokenizer_multimodal_token, get_model_name_from_path, KeywordsStoppingCriteria, process_image_npary
11 | from .constants import NUM_FRAMES, DEFAULT_IMAGE_TOKEN, DEFAULT_VIDEO_TOKEN, MODAL_INDEX_MAP, DEFAULT_AUDIO_TOKEN
12 | import transformers
13 |
14 | def model_init(model_path=None, **kwargs):
15 | # with_face = kwargs.get('with_face', False)
16 | model_path = "HumanOmni_7B" if model_path is None else model_path
17 | model_name = get_model_name_from_path(model_path)
18 |
19 | tokenizer, model, processor, context_len, audio_processor = load_pretrained_model(model_path, None, model_name, **kwargs)
20 |
21 | if tokenizer.pad_token is None and tokenizer.unk_token is not None:
22 | tokenizer.pad_token = tokenizer.unk_token
23 |
24 | num_frames = model.config.num_frames if hasattr(model.config, "num_frames") else NUM_FRAMES
25 | if "qwen2vit" in model_path:
26 | from .mm_utils import process_image_qwen, process_video_qwen
27 | processor = {
28 | 'image': partial(process_image_qwen, processor=processor, aspect_ratio=None),
29 | 'video': partial(process_video_qwen, processor=processor, aspect_ratio=None, num_frames=num_frames),
30 | }
31 | else:
32 | processor = {
33 | 'image': partial(process_image, processor=processor, aspect_ratio=None),
34 | 'video': partial(process_video, processor=processor, aspect_ratio=None, num_frames=num_frames),
35 | 'face': partial(process_image_npary, processor=processor, aspect_ratio=None),
36 | 'audio': partial(process_audio, processor=audio_processor),
37 | }
38 | return model, processor, tokenizer
39 |
40 |
41 | def mm_infer(image_or_video, instruct, model, tokenizer, audio=None, modal='video', question=None, bert_tokeni=None, **kwargs):
42 | """inference api of HumanOmni for video understanding.
43 |
44 | Args:
45 | model: HumanOmni model.
46 | image_or_video (torch.Tensor): image tensor (1, C, H, W) / video tensor (T, C, H, W).
47 | instruct (str): text instruction for understanding video.
48 | tokenizer: tokenizer.
49 | do_sample (bool): whether to sample.
50 | modal (str): inference modality.
51 | Returns:
52 | str: response of the model.
53 | """
54 | question_prompt = None
55 | if question is not None:
56 | question = [question]
57 | question_prompt = bert_tokeni(question, return_tensors='pt', padding=True, truncation=True,add_special_tokens=True)
58 | question_prompt = {key: value.to('cuda') for key, value in question_prompt.items()}
59 |
60 | if modal == 'image':
61 | modal_token = DEFAULT_IMAGE_TOKEN
62 | elif modal == 'video':
63 | modal_token = DEFAULT_VIDEO_TOKEN
64 | elif modal == 'audio':
65 | modal_token = DEFAULT_AUDIO_TOKEN
66 | elif modal == 'video_audio':
67 | modal_token = DEFAULT_VIDEO_TOKEN + '\n' +DEFAULT_AUDIO_TOKEN
68 | elif modal == 'text':
69 | modal_token = ''
70 | else:
71 | raise ValueError(f"Unsupported modal: {modal}")
72 |
73 |
74 | # 1. vision preprocess (load & transform image or video).
75 |
76 | if modal == 'text' or modal == 'audio':
77 | tensor = [(torch.zeros(32, 3, 384, 384).cuda().half(), "video")]
78 | else:
79 | if "video" in modal:
80 | vi_modal = "video"
81 | else:
82 | vi_modal = "image"
83 |
84 | if isinstance(image_or_video, transformers.image_processing_base.BatchFeature):
85 |             # process all tensors inside the BatchFeature
86 | processed_data = transformers.image_processing_base.BatchFeature({
87 | 'pixel_values_videos': image_or_video['pixel_values_videos'][0].half().cuda(),
88 | 'video_grid_thw': image_or_video['video_grid_thw'][0].cuda()
89 | })
90 | else:
91 |             # process a plain tensor
92 | processed_data = image_or_video.half().cuda()
93 | tensor = [(processed_data, vi_modal)]
94 |
95 |
96 | if audio is not None:
97 | audio = audio.half().cuda()
98 |
99 | # 2. text preprocess (tag process & generate prompt).
100 | if isinstance(instruct, str):
101 | message = [{'role': 'user', 'content': modal_token + '\n' + instruct}]
102 | elif isinstance(instruct, list):
103 | message = copy.deepcopy(instruct)
104 | message[0]['content'] = modal_token + '\n' + message[0]['content']
105 | else:
106 | raise ValueError(f"Unsupported type of instruct: {type(instruct)}")
107 |
108 |
109 |
110 | if model.config.model_type in ['HumanOmni', 'HumanOmni_mistral', 'HumanOmni_mixtral']:
111 | system_message = [
112 | {'role': 'system', 'content': (
113 |                 """<<SYS>>\nYou are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature."""
114 |                 """\n"""
115 |                 """If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information.\n<</SYS>>""")
116 | }
117 | ]
118 | else:
119 | system_message = []
120 |
121 | message = system_message + message
122 | prompt = tokenizer.apply_chat_template(message, tokenize=False, add_generation_prompt=True)
123 |
124 |     # add modal wrapper tokens (start/end tokens defined in constants.py)
125 |     if model.config.mm_use_x_start_end:
126 |         prompt = prompt.replace("<image>", "<im_start><image><im_end>").replace("<video>", "<vi_start><video><vi_end>").replace("<audio>", "<au_start><audio><au_end>")
127 |
128 |
129 | input_ids = tokenizer_multimodal_token(prompt, tokenizer, modal_token, return_tensors='pt').unsqueeze(0).long().cuda()
130 | attention_masks = input_ids.ne(tokenizer.pad_token_id).long().cuda()
131 |
132 | # 3. generate response according to visual signals and prompts.
133 | keywords = [tokenizer.eos_token]
134 | stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids)
135 |
136 | do_sample = kwargs.get('do_sample', False)
137 | temperature = kwargs.get('temperature', 0.2 if do_sample else 0.0)
138 | top_p = kwargs.get('top_p', 0.9)
139 | max_new_tokens = kwargs.get('max_new_tokens', 2048)
140 |
141 |
142 | with torch.inference_mode():
143 | output_ids = model.generate(
144 | input_ids,
145 | attention_mask=attention_masks,
146 | images=tensor,
147 | do_sample=do_sample,
148 | temperature=temperature,
149 | max_new_tokens=max_new_tokens,
150 | top_p=top_p,
151 | use_cache=True,
152 | stopping_criteria=[stopping_criteria],
153 | pad_token_id=tokenizer.eos_token_id,
154 | prompts=question_prompt,
155 | audios=audio
156 | )
157 |
158 | outputs = tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0].strip()
159 |
160 | return outputs
161 |
--------------------------------------------------------------------------------
/humanomni/constants.py:
--------------------------------------------------------------------------------
1 | CONTROLLER_HEART_BEAT_EXPIRATION = 30
2 | WORKER_HEART_BEAT_INTERVAL = 15
3 |
4 | LOGDIR = "."
5 |
6 | # Model Constants
7 | IGNORE_INDEX = -100
8 |
9 | # Image arguments
10 | IMAGE_TOKEN_INDEX = -200
11 | IMAGE_TOKEN_PATCH = -300
12 | DEFAULT_IMAGE_TOKEN = "<image>"
13 | DEFAULT_IMAGE_PATCH_TOKEN = "<im_patch>"
14 | DEFAULT_IM_START_TOKEN = "<im_start>"
15 | DEFAULT_IM_END_TOKEN = "<im_end>"
16 | IMAGE_PLACEHOLDER = "<image-placeholder>"
17 |
18 | # Video arguments
19 | VIDEO_TOKEN_INDEX = -201
20 | DEFAULT_VIDEO_TOKEN = "<video>"
21 | NUM_FRAMES = 8
22 | MAX_FRAMES = 32
23 | NUM_FRAMES_PER_SECOND = 1
24 |
25 | # Audio arguments
26 | AUDIO_TOKEN_INDEX = -202
27 | DEFAULT_AUDIO_TOKEN = "<audio>"
28 |
29 | MODAL_INDEX_MAP = {
30 |     "<audio>": -202,
31 |     "<video>": -201,
32 |     "<image>": -200,
33 | }
34 |
35 | MODAL_INDEX_REMAP = {v: k for k, v in MODAL_INDEX_MAP.items()}
36 | DEFAULT_X_START_TOKEN = {'IMAGE': "<im_start>", 'VIDEO': "<vi_start>", 'AUDIO': "<au_start>", 'THERMAL': "<th_start>", 'DEPTH': "<de_start>"}
37 | DEFAULT_X_END_TOKEN = {'IMAGE': "<im_end>", 'VIDEO': "<vi_end>", 'AUDIO': "<au_end>", 'THERMAL': "<th_end>", 'DEPTH': "<de_end>"}
--------------------------------------------------------------------------------
/humanomni/conversation.py:
--------------------------------------------------------------------------------
1 | import base64
2 | import dataclasses
3 | from io import BytesIO
4 | from enum import auto, Enum
5 | from typing import List, Tuple
6 |
7 | from PIL import Image
8 | from .constants import LOGDIR, NUM_FRAMES
9 |
10 |
11 | class SeparatorStyle(Enum):
12 | """Different separator style."""
13 | SINGLE = auto()
14 | TWO = auto()
15 | PLAIN = auto()
16 | LLAMA2 = auto()
17 | QWEN = auto()
18 |
19 | @dataclasses.dataclass
20 | class Conversation:
21 | """A class that keeps all conversation history."""
22 | system: str
23 | roles: List[str]
24 | messages: List[List[str]]
25 | offset: int
26 | sep_style: SeparatorStyle = SeparatorStyle.SINGLE
27 | sep: str = "###"
28 | sep2: str = None
29 | version: str = "Unknown"
30 |
31 | skip_next: bool = False
32 | modality: str = "image"
33 |
34 | def get_prompt(self):
35 | messages = self.messages
36 | modality_token = f"<{self.modality}>"
37 | if len(messages) > 0 and type(messages[0][1]) is tuple:
38 | messages = self.messages.copy()
39 | init_role, init_msg = messages[0].copy()
40 | init_msg = init_msg[0].replace(modality_token, "").strip()
41 | if 'mmtag' in self.version:
42 | messages[0] = (init_role, init_msg)
43 |                 messages.insert(0, (self.roles[0], "<Image><image></Image>"))
44 | messages.insert(1, (self.roles[1], "Received."))
45 | else:
46 | messages[0] = (init_role, f"{modality_token}\n" + init_msg)
47 |
48 | if self.sep_style == SeparatorStyle.SINGLE:
49 | ret = self.system + self.sep
50 | for role, message in messages:
51 | if message:
52 | if type(message) is tuple:
53 | message, _, _ = message
54 | ret += role + ": " + message + self.sep
55 | else:
56 | ret += role + ":"
57 | elif self.sep_style == SeparatorStyle.TWO:
58 | seps = [self.sep, self.sep2]
59 | ret = self.system + seps[0]
60 | for i, (role, message) in enumerate(messages):
61 | if message:
62 | if type(message) is tuple:
63 | message, _, _ = message
64 | ret += role + ": " + message + seps[i % 2]
65 | else:
66 | ret += role + ":"
67 | elif self.sep_style == SeparatorStyle.LLAMA2:
68 |             wrap_sys = lambda msg: f"<<SYS>>\n{msg}\n<</SYS>>\n\n"
69 | wrap_inst = lambda msg: f"[INST] {msg} [/INST]"
70 | ret = ""
71 |
72 | for i, (role, message) in enumerate(messages):
73 | if i == 0:
74 | assert message, "first message should not be none"
75 | assert role == self.roles[0], "first message should come from user"
76 | if message:
77 | if type(message) is tuple:
78 | message, _, _ = message
79 | if i == 0: message = wrap_sys(self.system) + message
80 | if i % 2 == 0:
81 | message = wrap_inst(message)
82 | ret += self.sep + message
83 | else:
84 | ret += " " + message + " " + self.sep2
85 | else:
86 | ret += ""
87 | ret = ret.lstrip(self.sep)
88 | elif self.sep_style == SeparatorStyle.QWEN:
89 | ret = ""
90 | # 1. Add system prompt
91 | ret += self.system + self.sep + "\n"
92 | # 2. Iterate message
93 | for i, (role, message) in enumerate(messages):
94 | if i == 0:
95 | assert message, "first message should not be none"
96 | assert role == self.roles[0], "first message should come from user"
97 | if message:
98 | if type(message) is tuple:
99 | message, _, _ = message
100 | # 2.1 Add role and message
101 | ret += role + message + self.sep + "\n"
102 | else:
103 | # 2.2 Add generation prompt
104 | ret += role
105 | elif self.sep_style == SeparatorStyle.PLAIN:
106 | seps = [self.sep, self.sep2]
107 | ret = self.system
108 | for i, (role, message) in enumerate(messages):
109 | if message:
110 | if type(message) is tuple:
111 | message, _, _ = message
112 | ret += role + message + seps[i % 2]
113 | else:
114 | ret += role
115 | else:
116 | raise ValueError(f"Invalid style: {self.sep_style}")
117 |
118 | return ret
119 |
120 | def append_message(self, role, message):
121 | self.messages.append([role, message])
122 |
123 | def process_image(self, image, image_process_mode, return_pil=False, image_format='PNG', max_len=800, min_len=400):
124 | if image_process_mode == "Pad":
125 | def expand2square(pil_img, background_color=(122, 116, 104)):
126 | width, height = pil_img.size
127 | if width == height:
128 | return pil_img
129 | elif width > height:
130 | result = Image.new(pil_img.mode, (width, width), background_color)
131 | result.paste(pil_img, (0, (width - height) // 2))
132 | return result
133 | else:
134 | result = Image.new(pil_img.mode, (height, height), background_color)
135 | result.paste(pil_img, ((height - width) // 2, 0))
136 | return result
137 | image = expand2square(image)
138 | elif image_process_mode in ["Default", "Crop"]:
139 | pass
140 | elif image_process_mode == "Resize":
141 | image = image.resize((336, 336))
142 | else:
143 | raise ValueError(f"Invalid image_process_mode: {image_process_mode}")
144 | if max(image.size) > max_len:
145 | max_hw, min_hw = max(image.size), min(image.size)
146 | aspect_ratio = max_hw / min_hw
147 | shortest_edge = int(min(max_len / aspect_ratio, min_len, min_hw))
148 | longest_edge = int(shortest_edge * aspect_ratio)
149 | W, H = image.size
150 | if H > W:
151 | H, W = longest_edge, shortest_edge
152 | else:
153 | H, W = shortest_edge, longest_edge
154 | image = image.resize((W, H))
155 | if return_pil:
156 | return image
157 | else:
158 | buffered = BytesIO()
159 | image.save(buffered, format=image_format)
160 | img_b64_str = base64.b64encode(buffered.getvalue()).decode()
161 | return img_b64_str
162 |
163 |
164 | def get_videos(self, return_pil=False):
165 | video_frames = []
166 | for i, (role, msg) in enumerate(self.messages[self.offset:]):
167 | if i % 2 == 0:
168 | if type(msg) is tuple:
169 | from decord import VideoReader, cpu
170 | import numpy as np
171 | # here video is the file path of input video
172 | msg, video, image_process_mode = msg
173 | if not return_pil:
174 | # return filepath
175 | video_frames.append(video)
176 | else:
177 | # read video using decord.VideoReader
178 | decord_vr = VideoReader(uri=video, ctx=cpu(0))
179 | duration = len(decord_vr)
180 | frame_id_list = np.linspace(0, duration-1, NUM_FRAMES, dtype=int)
181 | # convert the extracted image frames into PIL objects
182 | all_images = [Image.fromarray(f) for f in decord_vr.get_batch(frame_id_list).asnumpy()]
183 | video_frames.extend([self.process_image(image, image_process_mode, return_pil=return_pil) for image in all_images])
184 | return video_frames
185 |
186 |
187 | def get_images(self, return_pil=False):
188 | images = []
189 | for i, (role, msg) in enumerate(self.messages[self.offset:]):
190 | if i % 2 == 0:
191 | if type(msg) is tuple:
192 | msg, image, image_process_mode = msg
193 | image = self.process_image(image, image_process_mode, return_pil=return_pil)
194 | images.append(image)
195 |
196 | return images
197 |
198 | def to_gradio_chatbot(self):
199 | ret = []
200 | for i, (role, msg) in enumerate(self.messages[self.offset:]):
201 | if i % 2 == 0:
202 | if type(msg) is tuple:
203 |
204 | msg, image_or_video, image_process_mode = msg
205 | ##print("imagebox:", image)
206 | if isinstance(image_or_video, Image.Image):
207 | # image is PIL object
208 | img_b64_str = self.process_image(image_or_video, "Default", return_pil=False, image_format='JPEG')
209 |                     img_str = f'<img src="data:image/jpeg;base64,{img_b64_str}" alt="user upload image" />'
210 |                     msg = img_str + msg.replace('<image>', '').strip()
211 | else:
212 | # video is file path
213 |                     vid_str = f'<video controls playsinline width="500" style="display: inline-block;" src="./file={image_or_video}"></video>'
214 |                     msg = vid_str + msg.replace('<video>', '').strip()
215 | ret.append([msg, None])
216 | else:
217 | ret.append([msg, None])
218 | else:
219 | ret[-1][-1] = msg
220 | return ret
221 |
222 | def copy(self):
223 | return Conversation(
224 | system=self.system,
225 | roles=self.roles,
226 | messages=[[x, y] for x, y in self.messages],
227 | offset=self.offset,
228 | sep_style=self.sep_style,
229 | sep=self.sep,
230 | sep2=self.sep2,
231 | version=self.version)
232 |
233 | def dict(self):
234 | if (self.modality == "image" and len(self.get_images()) > 0) or \
235 | (self.modality == "video" and len(self.get_videos()) > 0):
236 | return {
237 | "system": self.system,
238 | "roles": self.roles,
239 | "messages": [[x, y[0] if type(y) is tuple else y] for x, y in self.messages],
240 | "offset": self.offset,
241 | "sep": self.sep,
242 | "sep2": self.sep2,
243 | "modality": self.modality
244 | }
245 | return {
246 | "system": self.system,
247 | "roles": self.roles,
248 | "messages": self.messages,
249 | "offset": self.offset,
250 | "sep": self.sep,
251 | "sep2": self.sep2,
252 | }
253 |
254 |
255 | conv_vicuna_v0 = Conversation(
256 | system="A chat between a curious human and an artificial intelligence assistant. "
257 | "The assistant gives helpful, detailed, and polite answers to the human's questions.",
258 | roles=("Human", "Assistant"),
259 | messages=(
260 | ("Human", "What are the key differences between renewable and non-renewable energy sources?"),
261 | ("Assistant",
262 | "Renewable energy sources are those that can be replenished naturally in a relatively "
263 | "short amount of time, such as solar, wind, hydro, geothermal, and biomass. "
264 | "Non-renewable energy sources, on the other hand, are finite and will eventually be "
265 | "depleted, such as coal, oil, and natural gas. Here are some key differences between "
266 | "renewable and non-renewable energy sources:\n"
267 | "1. Availability: Renewable energy sources are virtually inexhaustible, while non-renewable "
268 | "energy sources are finite and will eventually run out.\n"
269 | "2. Environmental impact: Renewable energy sources have a much lower environmental impact "
270 | "than non-renewable sources, which can lead to air and water pollution, greenhouse gas emissions, "
271 | "and other negative effects.\n"
272 | "3. Cost: Renewable energy sources can be more expensive to initially set up, but they typically "
273 | "have lower operational costs than non-renewable sources.\n"
274 | "4. Reliability: Renewable energy sources are often more reliable and can be used in more remote "
275 | "locations than non-renewable sources.\n"
276 | "5. Flexibility: Renewable energy sources are often more flexible and can be adapted to different "
277 | "situations and needs, while non-renewable sources are more rigid and inflexible.\n"
278 | "6. Sustainability: Renewable energy sources are more sustainable over the long term, while "
279 | "non-renewable sources are not, and their depletion can lead to economic and social instability.\n")
280 | ),
281 | offset=2,
282 | sep_style=SeparatorStyle.SINGLE,
283 | sep="###",
284 | )
285 |
286 | conv_llava_plain = Conversation(
287 | system="",
288 | roles=("", ""),
289 | messages=(),
290 | offset=0,
291 | sep_style=SeparatorStyle.PLAIN,
292 | sep="",
293 | sep2="\n"
294 | )
295 |
296 | conv_llava_v0_mmtag = Conversation(
297 | system="A chat between a curious user and an artificial intelligence assistant. "
298 | "The assistant is able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language."
299 |            "The visual content will be provided with the following format: <Image>visual content</Image>.",
300 | roles=("Human", "Assistant"),
301 | messages=(
302 | ),
303 | offset=0,
304 | sep_style=SeparatorStyle.SINGLE,
305 | sep="###",
306 | version="v0_mmtag",
307 | )
308 |
309 | conv_llava_v0 = Conversation(
310 | system="A chat between a curious human and an artificial intelligence assistant. "
311 | "The assistant gives helpful, detailed, and polite answers to the human's questions.",
312 | roles=("Human", "Assistant"),
313 | messages=(
314 | ),
315 | offset=0,
316 | sep_style=SeparatorStyle.SINGLE,
317 | sep="###",
318 | )
319 |
320 | conv_vicuna_v1 = Conversation(
321 | system="A chat between a curious user and an artificial intelligence assistant. "
322 | "The assistant gives helpful, detailed, and polite answers to the user's questions.",
323 | roles=("USER", "ASSISTANT"),
324 | version="v1",
325 | messages=(),
326 | offset=0,
327 | sep_style=SeparatorStyle.TWO,
328 | sep=" ",
329 |     sep2="</s>",
330 | )
331 |
332 | conv_llava_v1_mmtag = Conversation(
333 | system="A chat between a curious user and an artificial intelligence assistant. "
334 | "The assistant is able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language."
335 |            "The visual content will be provided with the following format: <Image>visual content</Image>.",
336 | roles=("USER", "ASSISTANT"),
337 | messages=(),
338 | offset=0,
339 | sep_style=SeparatorStyle.TWO,
340 | sep=" ",
341 |     sep2="</s>",
342 | version="v1_mmtag",
343 | )
344 |
345 | conv_llava_v1 = Conversation(
346 | system="A chat between a curious human and an artificial intelligence assistant. "
347 | "The assistant gives helpful, detailed, and polite answers to the human's questions.",
348 | roles=("USER", "ASSISTANT"),
349 | version="v1",
350 | messages=(),
351 | offset=0,
352 | sep_style=SeparatorStyle.TWO,
353 | sep=" ",
354 |     sep2="</s>",
355 | )
356 |
357 | conv_llava_llama2 = Conversation(
358 | system="You are a helpful language and vision assistant. "
359 | "You are able to understand the visual content that the user provides, "
360 | "and assist the user with a variety of tasks using natural language.",
361 | roles=("USER", "ASSISTANT"),
362 | version="llama2",
363 | messages=(),
364 | offset=0,
365 | sep_style=SeparatorStyle.LLAMA2,
366 |     sep="<s>",
367 |     sep2="</s>",
368 | )
369 |
370 | conv_llama2 = Conversation(
371 | system="""You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature.
372 |
373 | If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information.""",
374 | roles=("USER", "ASSISTANT"),
375 | version="llama2",
376 | messages=(),
377 | offset=0,
378 | sep_style=SeparatorStyle.LLAMA2,
379 |     sep="<s>",
380 |     sep2="</s>",
381 | )
382 |
383 | conv_mistral = Conversation(
384 | system="A chat between a curious user and an artificial intelligence assistant. "
385 | "The assistant gives helpful, detailed, and polite answers to the user's questions.",
386 | roles=("USER", "ASSISTANT"),
387 | version="llama2",
388 | messages=(),
389 | offset=0,
390 | sep_style=SeparatorStyle.LLAMA2,
391 | sep="",
392 |     sep2="</s>",
393 | )
394 |
395 | conv_qwen = Conversation(
396 | system="<|im_start|>system\nYou are a helpful assistant.",
397 | roles=("<|im_start|>user\n", "<|im_start|>assistant\n"),
398 | messages=(),
399 | offset=0,
400 | sep_style=SeparatorStyle.QWEN,
401 | sep="<|im_end|>",
402 | version="qwen",
403 | )
404 |
405 | conv_qwen_plain = Conversation(
406 | system="",
407 | roles=("<|im_start|>user\n", "<|im_start|>assistant\n"),
408 | messages=(),
409 | offset=0,
410 | sep_style=SeparatorStyle.PLAIN,
411 | sep="<|im_end|>",
412 | sep2="<|im_end|>",
413 | version="qwen_plain",
414 | )
415 |
416 | default_conversation = conv_mistral
417 | conv_templates = {
418 | "default": conv_vicuna_v0,
419 | # pretrain template
420 | "plain": conv_llava_plain,
421 | # llava v0
422 | "v0": conv_vicuna_v0,
423 | "v0_plain": conv_llava_plain,
424 | "v0_mmtag": conv_llava_v0_mmtag,
425 | "llava_v0": conv_llava_v0,
426 | # llava v1
427 | "v1": conv_vicuna_v1,
428 | "v1_mmtag": conv_llava_v1_mmtag,
429 | "llava_v1": conv_llava_v1,
430 | "vicuna_v1": conv_vicuna_v1,
431 | # llava v1.5
432 | "llava_llama2": conv_llava_llama2,
433 | # llama2
434 | "llama2": conv_llama2,
435 | # mistral
436 | "mistral": conv_mistral,
437 | # qwen
438 | "qwen": conv_qwen,
439 | "qwen_plain": conv_qwen_plain,
440 | }
441 |
442 |
443 | if __name__ == "__main__":
444 | print(default_conversation.get_prompt())
445 |
--------------------------------------------------------------------------------
/humanomni/conversation_llava.py:
--------------------------------------------------------------------------------
1 | import dataclasses
2 | from enum import auto, Enum
3 | from typing import List, Any, Dict, Union, Tuple
4 | import re
5 | import base64
6 | from io import BytesIO
7 | from PIL import Image
8 | from transformers import AutoTokenizer
9 |
10 |
11 | class SeparatorStyle(Enum):
12 | """Different separator style."""
13 |
14 | SINGLE = auto()
15 | TWO = auto()
16 | MPT = auto()
17 | PLAIN = auto()
18 | CHATML = auto()
19 | LLAMA_2 = auto()
20 | LLAMA_3 = auto()
21 | QWEN = auto()
22 | GEMMA = auto()
23 |
24 |
25 | @dataclasses.dataclass
26 | class Conversation:
27 | """A class that keeps all conversation history."""
28 |
29 | system: str
30 | roles: List[str]
31 | messages: List[List[str]]
32 | offset: int
33 | sep_style: SeparatorStyle = SeparatorStyle.SINGLE
34 | sep: str = "###"
35 | sep2: str = None
36 | version: str = "Unknown"
37 |
38 | tokenizer_id: str = ""
39 | tokenizer: Any = None
40 | # Stop criteria (the default one is EOS token)
41 | stop_str: Union[str, List[str]] = None
42 | # Stops generation if meeting any token in this list
43 | stop_token_ids: List[int] = None
44 |
45 | skip_next: bool = False
46 |
47 | def get_prompt(self):
48 | messages = self.messages
49 | if len(messages) > 0 and type(messages[0][1]) is tuple:
50 | messages = self.messages.copy()
51 | init_role, init_msg = messages[0].copy()
52 | init_msg = init_msg[0]
53 | if "mmtag" in self.version:
54 |                 init_msg = init_msg.replace("<image>", "").strip()
55 |                 messages[0] = (init_role, init_msg)
56 |                 messages.insert(0, (self.roles[0], "<Image><image></Image>"))
57 |                 messages.insert(1, (self.roles[1], "Received."))
58 |             elif not init_msg.startswith("<image>"):
59 |                 init_msg = init_msg.replace("<image>", "").strip()
60 |                 messages[0] = (init_role, "<image>\n" + init_msg)
61 | else:
62 | messages[0] = (init_role, init_msg)
63 |
64 | if self.sep_style == SeparatorStyle.SINGLE:
65 | ret = self.system + self.sep
66 | for role, message in messages:
67 | if message:
68 | if type(message) is tuple:
69 | message, _, _ = message
70 | ret += role + ": " + message + self.sep
71 | else:
72 | ret += role + ":"
73 |
74 | elif self.sep_style == SeparatorStyle.TWO:
75 | seps = [self.sep, self.sep2]
76 | ret = self.system + seps[0]
77 | for i, (role, message) in enumerate(messages):
78 | if message:
79 | if type(message) is tuple:
80 | message, _, _ = message
81 | ret += role + ": " + message + seps[i % 2]
82 | else:
83 | ret += role + ":"
84 |
85 | elif self.sep_style == SeparatorStyle.CHATML:
86 | ret = "" if self.system == "" else self.system + self.sep + "\n"
87 | for role, message in messages:
88 | if message:
89 | if type(message) is tuple:
90 | message, images, _ = message
91 |                     message = "<image>" * len(images) + message
92 | ret += role + "\n" + message + self.sep + "\n"
93 | else:
94 | ret += role + "\n"
95 | return ret
96 |
97 | elif self.sep_style == SeparatorStyle.LLAMA_3:
98 | if self.tokenizer is None:
99 | raise ValueError("Llama 3 tokenizer is not available. Make sure you have the necessary permissions.")
100 | chat_template_messages = [{"role": "system", "content": self.system}]
101 | for role, message in messages:
102 | if message:
103 | if type(message) is tuple:
104 | message, images = message
105 |                     message = "<image>" * len(images) + message
106 | chat_template_messages.append({"role": role, "content": message})
107 |
108 | # print(chat_template_messages)
109 | return self.tokenizer.apply_chat_template(chat_template_messages, tokenize=False, add_generation_prompt=True)
110 |
111 | elif self.sep_style == SeparatorStyle.MPT:
112 | ret = self.system + self.sep
113 | for role, message in messages:
114 | if message:
115 | if type(message) is tuple:
116 | message, _, _ = message
117 | ret += role + message + self.sep
118 | else:
119 | ret += role
120 |
121 | elif self.sep_style == SeparatorStyle.GEMMA:
122 | ret = ""
123 | for i, (role, message) in enumerate(messages):
124 | assert role == self.roles[i % 2], "Conversation should alternate user/assistant/user/assistant/..."
125 | if message:
126 | if type(message) is tuple:
127 | message, _, _ = message
128 | ret += role + message + self.sep
129 | else:
130 | ret += role
131 |
132 | elif self.sep_style == SeparatorStyle.LLAMA_2:
133 |             wrap_sys = lambda msg: f"<<SYS>>\n{msg}\n<</SYS>>\n\n" if len(msg) > 0 else msg
134 | wrap_inst = lambda msg: f"[INST] {msg} [/INST]"
135 | ret = ""
136 |
137 | for i, (role, message) in enumerate(messages):
138 | if i == 0:
139 | assert message, "first message should not be none"
140 | assert role == self.roles[0], "first message should come from user"
141 | if message:
142 | if type(message) is tuple:
143 | message, _, _ = message
144 | if i == 0:
145 | message = wrap_sys(self.system) + message
146 | if i % 2 == 0:
147 | message = wrap_inst(message)
148 | ret += self.sep + message
149 | else:
150 | ret += " " + message + " " + self.sep2
151 | else:
152 | ret += ""
153 | ret = ret.lstrip(self.sep)
154 |
155 | elif self.sep_style == SeparatorStyle.PLAIN:
156 | seps = [self.sep, self.sep2]
157 | ret = self.system
158 | for i, (role, message) in enumerate(messages):
159 | if message:
160 | if type(message) is tuple:
161 | message, _, _ = message
162 | ret += message + seps[i % 2]
163 | else:
164 | ret += ""
165 | else:
166 | raise ValueError(f"Invalid style: {self.sep_style}")
167 |
168 | return ret
169 |
170 | def append_message(self, role, message):
171 | self.messages.append([role, message])
172 |
173 | def process_image(self, image, image_process_mode, return_pil=False, image_format="PNG"):
174 | if image_process_mode == "Pad":
175 |
176 | def expand2square(pil_img, background_color=(122, 116, 104)):
177 | width, height = pil_img.size
178 | if width == height:
179 | return pil_img
180 | elif width > height:
181 | result = Image.new(pil_img.mode, (width, width), background_color)
182 | result.paste(pil_img, (0, (width - height) // 2))
183 | return result
184 | else:
185 | result = Image.new(pil_img.mode, (height, height), background_color)
186 | result.paste(pil_img, ((height - width) // 2, 0))
187 | return result
188 |
189 | image = expand2square(image)
190 | elif image_process_mode in ["Default", "Crop"]:
191 | pass
192 | elif image_process_mode == "Resize":
193 | image = image.resize((336, 336))
194 | else:
195 | raise ValueError(f"Invalid image_process_mode: {image_process_mode}")
196 |
197 | if type(image) is not Image.Image:
198 | image = Image.open(image).convert("RGB")
199 |
200 | max_hw, min_hw = max(image.size), min(image.size)
201 | aspect_ratio = max_hw / min_hw
202 | max_len, min_len = 672, 448
203 | shortest_edge = int(min(max_len / aspect_ratio, min_len, min_hw))
204 | longest_edge = int(shortest_edge * aspect_ratio)
205 | W, H = image.size
206 | if H > W:
207 | H, W = longest_edge, shortest_edge
208 | else:
209 | H, W = shortest_edge, longest_edge
210 | image = image.resize((W, H))
211 | if return_pil:
212 | return image
213 | else:
214 | buffered = BytesIO()
215 | image.save(buffered, format=image_format)
216 | img_b64_str = base64.b64encode(buffered.getvalue()).decode()
217 | return img_b64_str
218 |
219 | def get_images(self, return_pil=False, return_path=False):
220 | images = []
221 | for i, (role, msg) in enumerate(self.messages[self.offset :]):
222 | if i % 2 == 0:
223 | if type(msg) is tuple:
224 | msg, image, image_process_mode = msg
225 | if type(image) != list:
226 | image = [image]
227 | for img in image:
228 | if not return_path and self.is_image_file(img):
229 | img = self.process_image(img, image_process_mode, return_pil=return_pil)
230 | else:
231 | images.append(img)
232 | return images
233 |
234 | def is_image_file(self, filename):
235 | image_extensions = [".png", ".jpg", ".jpeg", ".gif", ".bmp", ".tiff", ".webp"]
236 | return any(filename.lower().endswith(ext) for ext in image_extensions)
237 |
238 | def is_video_file(self, filename):
239 | video_extensions = [".mp4", ".mov", ".avi", ".mkv", ".wmv", ".flv", ".mpeg", ".mpg"]
240 | return any(filename.lower().endswith(ext) for ext in video_extensions)
241 |
242 | def to_gradio_chatbot(self):
243 | ret = []
244 | for i, (role, msg) in enumerate(self.messages[self.offset :]):
245 | if i % 2 == 0:
246 | if type(msg) is tuple:
247 | msg, image, image_process_mode = msg
248 | if type(image) != list:
249 | image = [image]
250 | if len(image) == 1:
251 |                         msg = "<image>\n" + msg.replace("<image>", "").strip()
252 | else:
253 |                         msg = re.sub(r"(<image>)\n(?=<image>)", r"\1 ", msg)
254 |
255 | img_str_list = []
256 | for img in image:
257 | if self.is_image_file(img):
258 | img_b64_str = self.process_image(img, "Default", return_pil=False, image_format="JPEG")
259 |                             img_str = f'<img src="data:image/jpeg;base64,{img_b64_str}" alt="user upload image" />'
260 | img_str_list.append(img_str)
261 | elif self.is_video_file(img):
262 | ret.append(((img,), None))
263 |
264 | msg = msg.strip()
265 | img_place_holder = ""
266 | for img_str in img_str_list:
267 | img_place_holder += f"{img_str}\n\n"
268 |
269 | if len(img_str_list) > 0:
270 | msg = f"{img_place_holder}\n\n{msg}"
271 |
272 | if len(msg) > 0:
273 | ret.append([msg, None])
274 | else:
275 | ret.append([msg, None])
276 | else:
277 | ret[-1][-1] = msg
278 | return ret
279 |
280 | def copy(self):
281 | return Conversation(system=self.system, roles=self.roles, messages=[[x, y] for x, y in self.messages], offset=self.offset, sep_style=self.sep_style, sep=self.sep, sep2=self.sep2, version=self.version)
282 |
283 | def dict(self):
284 | if len(self.get_images()) > 0:
285 | return {
286 | "system": self.system,
287 | "roles": self.roles,
288 | "messages": [[x, y[0] if type(y) is tuple else y] for x, y in self.messages],
289 | "offset": self.offset,
290 | "sep": self.sep,
291 | "sep2": self.sep2,
292 | }
293 | return {
294 | "system": self.system,
295 | "roles": self.roles,
296 | "messages": self.messages,
297 | "offset": self.offset,
298 | "sep": self.sep,
299 | "sep2": self.sep2,
300 | }
301 |
302 |
303 | conv_vicuna_v0 = Conversation(
304 | system="A chat between a curious human and an artificial intelligence assistant. " "The assistant gives helpful, detailed, and polite answers to the human's questions.",
305 | roles=("Human", "Assistant"),
306 | messages=[
307 | ["Human", "What are the key differences between renewable and non-renewable energy sources?"],
308 | [
309 | "Assistant",
310 | "Renewable energy sources are those that can be replenished naturally in a relatively "
311 | "short amount of time, such as solar, wind, hydro, geothermal, and biomass. "
312 | "Non-renewable energy sources, on the other hand, are finite and will eventually be "
313 | "depleted, such as coal, oil, and natural gas. Here are some key differences between "
314 | "renewable and non-renewable energy sources:\n"
315 | "1. Availability: Renewable energy sources are virtually inexhaustible, while non-renewable "
316 | "energy sources are finite and will eventually run out.\n"
317 | "2. Environmental impact: Renewable energy sources have a much lower environmental impact "
318 | "than non-renewable sources, which can lead to air and water pollution, greenhouse gas emissions, "
319 | "and other negative effects.\n"
320 | "3. Cost: Renewable energy sources can be more expensive to initially set up, but they typically "
321 | "have lower operational costs than non-renewable sources.\n"
322 | "4. Reliability: Renewable energy sources are often more reliable and can be used in more remote "
323 | "locations than non-renewable sources.\n"
324 | "5. Flexibility: Renewable energy sources are often more flexible and can be adapted to different "
325 | "situations and needs, while non-renewable sources are more rigid and inflexible.\n"
326 | "6. Sustainability: Renewable energy sources are more sustainable over the long term, while "
327 | "non-renewable sources are not, and their depletion can lead to economic and social instability.\n",
328 | ],
329 | ],
330 | offset=2,
331 | sep_style=SeparatorStyle.SINGLE,
332 | sep="###",
333 | )
334 |
335 | conv_vicuna_v1 = Conversation(
336 | system="A chat between a curious user and an artificial intelligence assistant. " "The assistant gives helpful, detailed, and polite answers to the user's questions.",
337 | roles=("USER", "ASSISTANT"),
338 | version="v1",
339 | messages=[],
340 | offset=0,
341 | sep_style=SeparatorStyle.TWO,
342 | sep=" ",
343 |     sep2="</s>",
344 | )
345 |
346 | conv_llama_2 = Conversation(
347 | system="""You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature.
348 |
349 | If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information.""",
350 | roles=("USER", "ASSISTANT"),
351 | version="llama_v2",
352 | messages=[],
353 | offset=0,
354 | sep_style=SeparatorStyle.LLAMA_2,
355 |     sep="<s>",
356 |     sep2="</s>",
357 | )
358 |
359 | conv_llava_llama_2 = Conversation(
360 | system="You are a helpful language and vision assistant. " "You are able to understand the visual content that the user provides, " "and assist the user with a variety of tasks using natural language.",
361 | roles=("USER", "ASSISTANT"),
362 | version="llama_v2",
363 | messages=[],
364 | offset=0,
365 | sep_style=SeparatorStyle.LLAMA_2,
366 |     sep="<s>",
367 |     sep2="</s>",
368 | )
369 |
370 | def safe_load_tokenizer(tokenizer_id):
371 | try:
372 | return AutoTokenizer.from_pretrained(tokenizer_id)
373 | except Exception:
374 | return None
375 |
376 | conv_llava_llama_3 = Conversation(
377 | system="You are a helpful language and vision assistant. " "You are able to understand the visual content that the user provides, " "and assist the user with a variety of tasks using natural language.",
378 | roles=("user", "assistant"),
379 | version="llama_v3",
380 | messages=[],
381 | offset=0,
382 | sep="<|eot_id|>",
383 | sep_style=SeparatorStyle.LLAMA_3,
384 | tokenizer_id="meta-llama/Meta-Llama-3-8B-Instruct",
385 | tokenizer=safe_load_tokenizer("meta-llama/Meta-Llama-3-8B-Instruct"),
386 | stop_token_ids=[128009],
387 | )
388 |
389 |
390 | conv_llava_llama_31 = Conversation(
391 | system="You are a helpful language and vision assistant. " "You are able to understand the visual content that the user provides, " "and assist the user with a variety of tasks using natural language.",
392 | roles=("user", "assistant"),
393 | version="llama3",
394 | messages=[],
395 | offset=0,
396 | sep_style=SeparatorStyle.LLAMA_3,
397 | tokenizer_id="meta-llama/Meta-Llama-3.1-8B-Instruct",
398 | stop_token_ids=[128009],
399 | )
400 |
401 |
402 |
403 | conv_mistral_instruct = Conversation(
404 | system="",
405 | roles=("USER", "ASSISTANT"),
406 | version="llama_v2",
407 | messages=[],
408 | offset=0,
409 | sep_style=SeparatorStyle.LLAMA_2,
410 | sep="",
411 |     sep2="</s>",
412 | )
413 |
414 | conv_llava_llama_2_simple = Conversation(
415 | system="Answer the questions about the visual content that the user provides.",
416 | roles=("USER", "ASSISTANT"),
417 | version="llama_v2",
418 | messages=[],
419 | offset=0,
420 | sep_style=SeparatorStyle.LLAMA_2,
421 |     sep="<s>",
422 |     sep2="</s>",
423 | )
424 |
425 | conv_llava_llama_2_mmtag = Conversation(
426 |     system="Answer the questions about the visual content that the user provides." "The visual content will be provided with the following format: <Image>visual content</Image>.",
427 | roles=("USER", "ASSISTANT"),
428 | version="llama_v2_mmtag",
429 | messages=[],
430 | offset=0,
431 | sep_style=SeparatorStyle.LLAMA_2,
432 |     sep="<s>",
433 |     sep2="</s>",
434 | )
435 |
436 | conv_mpt = Conversation(
437 | system="""<|im_start|>system
438 | A conversation between a user and an LLM-based AI assistant. The assistant gives helpful and honest answers.""",
439 | roles=("<|im_start|>user\n", "<|im_start|>assistant\n"),
440 | version="mpt",
441 | messages=[],
442 | offset=0,
443 | sep_style=SeparatorStyle.MPT,
444 | sep="<|im_end|>",
445 | )
446 |
447 | conv_qwen = Conversation(
448 | system="""<|im_start|>system
449 | You are a helpful assistant.""",
450 | roles=("<|im_start|>user", "<|im_start|>assistant"),
451 | version="qwen",
452 | messages=[],
453 | offset=0,
454 | sep_style=SeparatorStyle.CHATML,
455 | sep="<|im_end|>",
456 | )
457 |
458 | conv_gemma_instruct = Conversation(system="", roles=("user\n", "model\n"), version="gemma", messages=[], offset=0, sep_style=SeparatorStyle.GEMMA, sep="\n")
459 |
460 | conv_llava_plain = Conversation(
461 | system="",
462 | roles=("", ""),
463 | messages=[],
464 | offset=0,
465 | sep_style=SeparatorStyle.PLAIN,
466 | sep="\n",
467 | )
468 |
469 | conv_llava_v0 = Conversation(
470 | system="A chat between a curious human and an artificial intelligence assistant. " "The assistant gives helpful, detailed, and polite answers to the human's questions.",
471 | roles=("Human", "Assistant"),
472 | messages=[],
473 | offset=0,
474 | sep_style=SeparatorStyle.SINGLE,
475 | sep="###",
476 | )
477 |
478 | conv_llava_v0_mmtag = Conversation(
479 | system="A chat between a curious user and an artificial intelligence assistant. "
480 | "The assistant is able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language."
481 |     "The visual content will be provided with the following format: <Image>visual content</Image>.",
482 | roles=("Human", "Assistant"),
483 | messages=[],
484 | offset=0,
485 | sep_style=SeparatorStyle.SINGLE,
486 | sep="###",
487 | version="v0_mmtag",
488 | )
489 |
490 | conv_llava_v1 = Conversation(
491 | system="A chat between a curious human and an artificial intelligence assistant. " "The assistant gives helpful, detailed, and polite answers to the human's questions.",
492 | roles=("USER", "ASSISTANT"),
493 | version="v1",
494 | messages=[],
495 | offset=0,
496 | sep_style=SeparatorStyle.TWO,
497 | sep=" ",
498 | sep2="",
499 | )
500 |
501 | conv_llava_v1_mmtag = Conversation(
502 | system="A chat between a curious user and an artificial intelligence assistant. "
503 | "The assistant is able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language."
504 | "The visual content will be provided with the following format: visual content .",
505 | roles=("USER", "ASSISTANT"),
506 | messages=[],
507 | offset=0,
508 | sep_style=SeparatorStyle.TWO,
509 | sep=" ",
510 | sep2="",
511 | version="v1_mmtag",
512 | )
513 |
514 | conv_mistral_orca = Conversation(
515 | system="""<|im_start|>system
516 | You are MistralOrca, a large language model trained by Alignment Lab AI. Write out your reasoning step-by-step to be sure you get the right answers!""",
517 | roles=("<|im_start|>user\n", "<|im_start|>assistant\n"),
518 | version="mpt",
519 | messages=[],
520 | offset=0,
521 | sep_style=SeparatorStyle.MPT,
522 | sep="<|im_end|>",
523 | )
524 |
525 | conv_mistral_zephyr = Conversation(
526 | system="""<|system|>
527 | You are a helpful AI assistant.""",
528 | roles=("<|user|>\n", "<|assistant|>\n"),
529 | version="mpt",
530 | messages=[],
531 | offset=0,
532 | sep_style=SeparatorStyle.MPT,
533 | sep="",
534 | )
535 |
536 | conv_mistral_direct = Conversation(
537 | system="""<|im_start|>system
538 | Answer the questions.""",
539 | roles=("<|im_start|>user\n", "<|im_start|>assistant\n"),
540 | version="mpt",
541 | messages=[],
542 | offset=0,
543 | sep_style=SeparatorStyle.MPT,
544 | sep="<|im_end|>",
545 | )
546 |
547 | conv_chatml_direct = Conversation(
548 | system="""<|im_start|>system
549 | Answer the questions.""",
550 | roles=("<|im_start|>user\n", "<|im_start|>assistant\n"),
551 | version="mpt",
552 | messages=[],
553 | offset=0,
554 | sep_style=SeparatorStyle.MPT,
555 | sep="<|im_end|>",
556 | )
557 |
558 | default_conversation = conv_vicuna_v0
559 | conv_templates = {
560 | "default": conv_vicuna_v0,
561 | "v0": conv_vicuna_v0,
562 | "v1": conv_vicuna_v1,
563 | "vicuna_v1": conv_vicuna_v1,
564 | "llama_2": conv_llama_2,
565 | "mistral_instruct": conv_mistral_instruct,
566 | "mistral_orca": conv_mistral_orca,
567 | "mistral_zephyr": conv_mistral_zephyr,
568 | "mistral_direct": conv_mistral_direct,
569 | "plain": conv_llava_plain,
570 | "v0_plain": conv_llava_plain,
571 | "chatml_direct": conv_chatml_direct,
572 | "llava_v0": conv_llava_v0,
573 | "llava_v0_mmtag": conv_llava_v0_mmtag,
574 | "llava_v1": conv_llava_v1,
575 | "llava_v1_mmtag": conv_llava_v1_mmtag,
576 | "llava_llama_2": conv_llava_llama_2,
577 | "llava_llama_3": conv_llava_llama_3,
578 | "llava_llama_31": conv_llava_llama_31,
579 | "llava_llama_2_simple": conv_llava_llama_2_simple,
580 | "llava_llama_2_mmtag": conv_llava_llama_2_mmtag,
581 | "llava_mistral_instruct": conv_mistral_instruct,
582 | "mpt": conv_mpt,
583 | "qwen_1_5": conv_qwen,
584 | "qwen_2": conv_qwen,
585 | "gemma_instruct": conv_gemma_instruct,
586 | }
587 |
588 |
589 | if __name__ == "__main__":
590 | print(default_conversation.get_prompt())
591 |
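For orientation, a minimal usage sketch of the templates registered above; it assumes the standard LLaVA-style `Conversation` methods (`copy`, `append_message`, `get_prompt`) defined earlier in this file, and the import path and prompt text are illustrative assumptions:

```python
# Hedged sketch: serialize a ChatML prompt from the "qwen_1_5" template above.
# The module path below is an assumption; adjust to wherever this file lives.
from humanomni.conversation_llava import conv_templates

conv = conv_templates["qwen_1_5"].copy()           # copy so the shared template is not mutated
conv.append_message(conv.roles[0], "<video>\nWhat emotion does the person in the video show?")
conv.append_message(conv.roles[1], None)           # leave the assistant turn open for generation
print(conv.get_prompt())                           # the serialized ChatML-style prompt
```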
--------------------------------------------------------------------------------
/humanomni/eval/eval_mafw_dfew.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import itertools
3 | import json
4 | import os
5 | import random
6 | import time
7 | from functools import partial
8 | import torch
9 | import requests
10 |
11 | from tqdm import tqdm
12 | from transformers import AutoProcessor, Qwen2AudioForConditionalGeneration
13 | from transformers.pipelines.audio_utils import ffmpeg_read
14 | from sklearn.metrics import accuracy_score
15 | from sklearn.metrics import confusion_matrix, recall_score
16 | import numpy as np
17 |
18 | from humanomni import model_init, mm_infer
19 |
20 | ds_collections = {
21 | 'emotion': {'path': '/mnt/data/qize.yqz/datasets/human/annos/1021_val_MAFW_DFEW_it_without_tag.json'}
22 | }
23 |
24 | from transformers import BertModel, BertTokenizer
25 | bert_model = "bert-base-uncased"
26 | bert_tokenizer = BertTokenizer.from_pretrained(bert_model)
27 |
28 | def weighted_average_recall(y_true, y_pred):
29 | unique_classes = np.unique(y_true)
30 | recalls = recall_score(y_true, y_pred, average=None, labels=unique_classes)
31 | print(f"{'cls':<12} | {'recall':<20}")
32 | for cls,recall in zip(unique_classes,recalls):
33 | print(f"{cls:<12} | {recall:<20.15f}")
34 | weights = [np.sum(np.array(y_true) == cls) for cls in unique_classes]
35 | total_samples = len(y_true)
36 | weights = np.array(weights) / total_samples
37 | mm=confusion_matrix(y_true, y_pred)
38 | print(mm)
39 |
40 | war = np.sum(weights * recalls)
41 | return war*100
42 |
43 | def unweighted_average_recall(y_true, y_pred):
44 | recalls = recall_score(y_true, y_pred, average=None)
45 | uar = np.mean(recalls)
46 | return uar*100
47 |
48 | class AudioDataset(torch.utils.data.Dataset):
49 |
50 | def __init__(self, ds):
51 | path = ds['path']
52 | self.datas = json.load(open(path))
53 |
54 | def __len__(self):
55 | return len(self.datas)
56 |
57 | def __getitem__(self, idx):
58 | data = self.datas[idx]
59 | video = data['video']
60 | clip_meta_path = data['clip_meta_path']
61 |
62 | # prompt = data['conversations'][0]['value'].replace("\n\n", "")
63 | gt = data['conversations'][1]['value']
64 | source = 'mafw' if "MAFW" in video else 'dfew'
65 | base_prompt = "As an emotional recognition expert, in the video, when the characters display their emotions, which predominant feeling is most clearly expressed?\n"
66 | options_dfew = "happy ,surprise ,neutral ,angry ,disgust ,sad ,fear"
67 | options_mafw = "happy ,surprise ,neutral ,angry ,disgust ,sad ,fear ,contemptuous, disappointed, helpless, anxious"
68 | prompt = base_prompt+options_mafw if "MAFW" in video else base_prompt+options_dfew
69 | return {
70 | 'video': video,
71 | 'prompt': prompt,
72 | 'gt': gt,
73 | 'clip_meta_path': clip_meta_path,
74 | 'source': source
75 | }
76 |
77 |
78 | def collate_fn(inputs, processor):
79 | input_texts = [_['prompt'] for _ in inputs]
80 | source = [_['source'] for _ in inputs]
81 | gt = [_['gt'] for _ in inputs]
82 | input_videos = [_['video'] for _ in inputs]
83 | input_allinone = [ _['clip_meta_path'] for _ in inputs]
84 |
85 | return input_texts, input_videos, input_allinone, gt, source
86 |
87 |
88 | class InferenceSampler(torch.utils.data.sampler.Sampler):
89 |
90 | def __init__(self, size):
91 | self._size = int(size)
92 | assert size > 0
93 | self._rank = torch.distributed.get_rank()
94 | self._world_size = torch.distributed.get_world_size()
95 | self._local_indices = self._get_local_indices(size, self._world_size,
96 | self._rank)
97 | @staticmethod
98 | def _get_local_indices(total_size, world_size, rank):
99 | shard_size = total_size // world_size
100 | left = total_size % world_size
101 | shard_sizes = [shard_size + int(r < left) for r in range(world_size)]
102 |
103 | begin = sum(shard_sizes[:rank])
104 | end = min(sum(shard_sizes[:rank + 1]), total_size)
105 | return range(begin, end)
106 |
107 | def __iter__(self):
108 | yield from self._local_indices
109 |
110 | def __len__(self):
111 | return len(self._local_indices)
112 |
113 |
114 | if __name__ == '__main__':
115 |
116 | parser = argparse.ArgumentParser()
117 | parser.add_argument('--checkpoint', type=str, default='Qwen/Qwen2-Audio-7B')
118 | parser.add_argument('--dataset', type=str, default='')
119 | parser.add_argument('--batch-size', type=int, default=1)
120 | parser.add_argument('--num-workers', type=int, default=1)
121 | parser.add_argument('--seed', type=int, default=0)
122 | args = parser.parse_args()
123 |
124 | torch.distributed.init_process_group(
125 | backend='nccl',
126 | world_size=int(os.getenv('WORLD_SIZE', '1')),
127 | rank=int(os.getenv('RANK', '0')),
128 | )
129 |
130 | torch.cuda.set_device(int(os.getenv('LOCAL_RANK', 0)))
131 |
132 | model, processor, tokenizer = model_init(args.checkpoint, device_map='cuda')
133 |
134 |
135 |
136 |
137 | random.seed(args.seed)
138 | dataset = AudioDataset(
139 | ds=ds_collections[args.dataset],
140 | )
141 | data_loader = torch.utils.data.DataLoader(
142 | dataset=dataset,
143 | sampler=InferenceSampler(len(dataset)),
144 | batch_size=args.batch_size,
145 | num_workers=args.num_workers,
146 | pin_memory=True,
147 | drop_last=False,
148 | collate_fn=partial(collate_fn, processor=processor),
149 | )
150 |
151 | gts = []
152 | sources = []
153 | rets = []
154 | video_paths = []
155 |
156 |
157 |
158 | for _, (inputs, video_path, allinones, gt, source) in tqdm(enumerate(data_loader)):
159 | audio_tensor = processor["audio"](video_path[0])[0]
160 | video_tensor = processor["video"](video_path[0])
161 | # print(audio_tensor.size(), video_tensor.size())
162 | output = mm_infer(
163 | image_or_video=video_tensor,
164 | instruct=inputs[0],
165 | model=model,
166 | tokenizer=tokenizer,
167 | audio=audio_tensor,
168 | modal='video_audio',
169 | do_sample=False,
170 | question=inputs[0],
171 | bert_tokeni=bert_tokenizer
172 | )
173 | print(inputs[0], video_path[0], output, gt[0])
174 | gts.extend(gt)
175 | rets.append(output)
176 | sources.extend(source)
177 | video_paths.extend(video_path)
178 |
179 | torch.distributed.barrier()
180 |
181 | world_size = torch.distributed.get_world_size()
182 | merged_gts = [None for _ in range(world_size)]
183 | merged_sources = [None for _ in range(world_size)]
184 | merged_responses = [None for _ in range(world_size)]
185 | merged_video_paths = [None for _ in range(world_size)]
186 |
187 | torch.distributed.all_gather_object(merged_gts, gts)
188 | torch.distributed.all_gather_object(merged_sources, sources)
189 | torch.distributed.all_gather_object(merged_responses, rets)
190 | torch.distributed.all_gather_object(merged_video_paths, video_paths)
191 |
192 | merged_gts = [_ for _ in itertools.chain.from_iterable(merged_gts)]
193 | merged_sources = [_ for _ in itertools.chain.from_iterable(merged_sources)]
194 | merged_video_paths = [_ for _ in itertools.chain.from_iterable(merged_video_paths)]
195 | merged_responses = [
196 | _ for _ in itertools.chain.from_iterable(merged_responses)
197 | ]
198 |
199 | if torch.distributed.get_rank() == 0:
200 | print(f"Evaluating {args.dataset} ...")
201 |
202 | results = []
203 | for gt, response, source, video_path in zip(merged_gts, merged_responses, merged_sources, merged_video_paths):
204 | results.append({
205 | 'gt': gt,
206 | 'response': response,
207 | 'source': source,
208 | 'video_path': video_path
209 | })
210 | time_prefix = time.strftime('%y%m%d%H%M%S', time.localtime())
211 | results_file = f'{args.dataset}_{time_prefix}.json'
212 | json.dump(results, open(results_file, 'w'))
213 | results_dict = {}
214 | for item in tqdm(results):
215 | source = item["source"]
216 | results_dict.setdefault(source, []).append(item)
217 |
218 | for source in results_dict:
219 | refs, hyps = [], []
220 | results_list = results_dict[source]
221 | for result in results_list:
222 | gt = result["gt"]
223 | response = result["response"].lstrip()
224 | refs.append(gt)
225 | hyps.append(response)
226 | score = accuracy_score(refs, hyps) * 100
227 | war = weighted_average_recall(refs, hyps)
228 | uar = unweighted_average_recall(refs, hyps)
229 | print(f"{source} acc: {score:.2f}%\t war: {war:.2f}% \t uar: {uar:.2f}% len:{len(hyps)}")
230 |
231 |
232 | torch.distributed.barrier()
233 |
234 | """
235 |
236 | python -m torch.distributed.launch --use_env --master_port=29501 --nproc_per_node 8 --nnodes 1 \
237 | humanomni/eval/eval_mafw_dfew.py \
238 | --checkpoint HumanOmni_7B/ \
239 | --dataset emotion
240 | """
--------------------------------------------------------------------------------
/humanomni/eval/eval_video_mcqa_mvbench.py:
--------------------------------------------------------------------------------
1 | import json
2 | import argparse
3 | from tabulate import tabulate
4 | import ipdb
5 |
6 | tasks = {
7 | "Action Sequence": ("action_sequence.json", "star/Charades_v1_480/", "video", True), # has start & end
8 | "Action Prediction": ("action_prediction.json", "star/Charades_v1_480/", "video", True), # has start & end
9 | "Action Antonym": ("action_antonym.json", "ssv2_video/", "video", False),
10 | "Fine-grained Action": ("fine_grained_action.json", "Moments_in_Time_Raw/videos/", "video", False),
11 | "Unexpected Action": ("unexpected_action.json", "FunQA_test/test/", "video", False),
12 | "Object Existence": ("object_existence.json", "clevrer/video_validation/", "video", False),
13 | "Object Interaction": ("object_interaction.json", "star/Charades_v1_480/", "video", True), # has start & end
14 | "Object Shuffle": ("object_shuffle.json", "perception/videos/", "video", False),
15 | "Moving Direction": ("moving_direction.json", "clevrer/video_validation/", "video", False),
16 | "Action Localization": ("action_localization.json", "sta/sta_video/", "video", True), # has start & end
17 | "Scene Transition": ("scene_transition.json", "scene_qa/video/", "video", False),
18 | "Action Count": ("action_count.json", "perception/videos/", "video", False),
19 | "Moving Count": ("moving_count.json", "clevrer/video_validation/", "video", False),
20 | "Moving Attribute": ("moving_attribute.json", "clevrer/video_validation/", "video", False),
21 | "State Change": ("state_change.json", "perception/videos/", "video", False),
22 | "Fine-grained Pose": ("fine_grained_pose.json", "nturgbd/", "video", False),
23 | "Character Order": ("character_order.json", "perception/videos/", "video", False),
24 | "Egocentric Navigation": ("egocentric_navigation.json", "vlnqa/", "video", False),
25 | "Episodic Reasoning": ("episodic_reasoning.json", "tvqa/frames_fps3_hq/", "frame", True), # has start & end, read frame
26 | "Counterfactual Inference": ("counterfactual_inference.json", "clevrer/video_validation/", "video", False),
27 | }
28 |
29 |
30 | def main():
31 | args = parse_args()
32 | res = [json.loads(x.strip()) for x in open(args.pred_path, 'r').readlines()]
33 | task_types = tasks.keys()
34 | task_acc = {x: [] for x in task_types}
35 | acc = []
36 | for i, x in enumerate(res):
37 | value = 1
38 | # print(x)
39 | if x['pred'] != x['gt']:
40 | value = 0
41 | acc.append(value)
42 | task_acc[x['task_type']].append(value)
43 | acc = sum(acc) * 100 / len(acc)
44 | print(task_acc)
45 | task_acc = {x: sum(task_acc[x]) * 100 / len(task_acc[x]) for x in task_acc}
46 | print(f"{args.pred_path}:", acc)
47 | task_names = list(task_acc.keys())
48 |
49 | table_data = []
50 | for i in range(len(task_names) // 4):
51 | row_task_names = task_names[i * 4: (i + 1) * 4]
52 | row_task_acc = [task_acc[x] for x in row_task_names]
53 | table_data.append(row_task_names)
54 | table_data.append(row_task_acc)
55 | print(tabulate(table_data, floatfmt=".1f"), '\n')
56 |
57 |
58 |
59 | def parse_args():
60 | parser = argparse.ArgumentParser(description="Evaluate video captioning.")
61 | parser.add_argument("--pred_path", default=r'', help="The path to file containing prediction.")
62 | args = parser.parse_args()
63 | return args
64 |
65 |
66 | if __name__ == '__main__':
67 | main()
68 |
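A hedged sketch of the per-line prediction format this script reads from `--pred_path` (it matches what `inference_video_mcqa_mvbench.py`, shown later in this dump, writes per sample); the concrete values are placeholders:

```python
# Illustrative only: one line of the --pred_path file consumed above.
# Accuracy is simply pred == gt, aggregated overall and per task_type.
import json

record = {"vid": "example_video.mp4",      # placeholder path
          "task_type": "Action Sequence",  # must be a key of `tasks`
          "pred": 2,                       # predicted option index
          "gt": 2}                         # ground-truth option index
print(json.dumps(record))
```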
--------------------------------------------------------------------------------
/humanomni/eval/inference_dfec.py:
--------------------------------------------------------------------------------
1 | import os
2 | import re
3 | import math
4 | import json
5 | import argparse
6 | import warnings
7 | import traceback
8 |
9 | import torch
10 | import numpy as np
11 | from PIL import Image
12 | from tqdm import tqdm
13 | from decord import VideoReader, cpu
14 | from torch.utils.data import Dataset, DataLoader
15 | import random
16 | import sys
17 | sys.path.append('./')
18 | from humanomni import model_init, mm_infer
19 | from humanomni.utils import disable_torch_init
20 | withaudio = True
21 | # NOTE: Ignore TypedStorage warning, which refers to this link~(https://github.com/pytorch/pytorch/issues/97207#issuecomment-1494781560)
22 | warnings.filterwarnings('ignore', category=UserWarning, message='TypedStorage is deprecated')
23 |
24 | from transformers import BertModel, BertTokenizer
25 | bert_model = "bert-base-uncased"
26 | bert_tokenizer = BertTokenizer.from_pretrained(bert_model)
27 | def split_list(lst, n):
28 | """Split a list into n (roughly) equal-sized chunks"""
29 | chunk_size = math.ceil(len(lst) / n)  # ceiling division so every element is covered
30 | return [lst[i:i+chunk_size] for i in range(0, len(lst), chunk_size)]
31 |
32 |
33 | def get_chunk(lst, n, k):
34 | chunks = split_list(lst, n)
35 | return chunks[k]
36 |
37 |
38 | class MVBenchDataset(Dataset):
39 |
40 | def __init__(self, data_list, question_list, processor, processor_audio=None):
41 | self.data_list = data_list
42 | self.processor = processor
43 | self.question_list = question_list
44 | self.processor_audio = processor_audio
45 |
46 | def __len__(self):
47 | return len(self.data_list)
48 |
49 | def __getitem__(self, idx):
50 | bound = (None, None)
51 | video_path = self.data_list[idx]
52 | video_path = os.path.join('/mnt/data/jiaxing.zjx/datasets/DFEC_CVPR/test/', video_path)
53 | try:
54 | torch_imgs = self.processor(video_path, s=bound[0], e=bound[1])
55 | if self.processor_audio is not None:
56 | audio = self.processor_audio(video_path, s=bound[0], e=bound[1])
57 | except Exception as e:
58 | backup_idx = random.randint(0, len(self.data_list)-1)
59 | print(f"Encounted error when reading video {video_path}, use {backup_idx}-th example instead!!!")
60 | return self.__getitem__(backup_idx)
61 | torch_imgs = self.processor(video_path, s=bound[0], e=bound[1])
62 | question = self.question_list[idx]
63 | if self.processor_audio is not None:
64 | audio = self.processor_audio(video_path, s=bound[0], e=bound[1])
65 | else:
66 | audio = None
67 | return {
68 | 'video': torch_imgs,
69 | 'video_path': video_path,
70 | 'instruction': question,
71 | 'question': question,
72 | 'audio': audio
73 | }
74 |
75 | def load_file(test_file):
76 | with open(test_file, 'r') as f:
77 | datas = json.load(f)
78 | video_paths = [data['video'] for data in datas]
79 | # question = "Please provide a detailed description of the facial appearance attributes and expression changes of the character in the video, including their expression state at the beginning and end of the video."
80 | question = "Please provide a detailed description of the facial appearance attributes and expression changes of the character in the video"
81 | questions = [question] * len(video_paths)
82 | return video_paths, questions
83 |
84 | def build_mvbench_eval(args, processor, processor_audio=None):
85 | video_paths, questions = load_file(args.question_file)
86 | dataset = MVBenchDataset(video_paths, questions, processor, processor_audio)
87 | dataloader = DataLoader(dataset, batch_size=args.batch_size, shuffle=False, num_workers=1)
88 |
89 | return dataloader
90 |
91 |
92 | def mvbench_dump(vid, instruct, letters, options, output):
93 |
94 | output = output.replace('answer', '')
95 | output = output.replace('Answer', '')
96 | pred_answer = re.findall(f'[\(,\ ]*[{letters[0]}-{letters[-1]}][\),\ ]*', output)
97 | try:
98 | find_flag = False
99 | if len(pred_answer) == 0:
100 | for idx, opt in enumerate(options):
101 | # fallback: check whether an option's text appears in the output
102 | if opt.lower() in output.lower():
103 | pred_idx = idx
104 | find_flag = True
105 | break
106 | else:
107 | pred_answer = pred_answer[0].strip()
108 | pred_answer = pred_answer.strip('()')
109 | pred_idx = letters.index(pred_answer)
110 | find_flag = True
111 |
112 | assert find_flag, 'The video \"{}\" instruct: \n\"{}\"\n output: \n\"{}\"\n is not in the expected format'.format(vid, instruct, output)
113 | except:
114 | traceback.print_exc()
115 | pred_idx = 2
116 |
117 | return pred_idx
118 |
119 |
120 | def run_inference(args):
121 | disable_torch_init()
122 |
123 | model, processor, tokenizer = model_init(args.model_path)
124 |
125 |
126 | if withaudio:
127 | val_loader = build_mvbench_eval(args, processor['video'], processor['audio'])
128 | else:
129 | val_loader = build_mvbench_eval(args, processor['video'])
130 | results = []
131 | # NOTE: only support batch size 1 for now
132 | for i, line in enumerate(tqdm(val_loader)):
133 | video_tensor = line['video'][0].to(args.device)
134 | question = line['question'][0]
135 | instruct = line['instruction'][0]
136 | video_path = line['video_path'][0]
137 |
138 | if withaudio:
139 | audio = line['audio'][0]
140 | else:
141 | audio = None
142 |
143 | if withaudio:
144 | output = mm_infer(video_tensor, instruct, model=model, tokenizer=tokenizer, modal='video_audio', question=question,bert_tokeni=bert_tokenizer,do_sample=False, audio=audio)
145 | else:
146 | output = mm_infer(video_tensor, instruct, model=model, tokenizer=tokenizer, modal='video', question=question,bert_tokeni=bert_tokenizer,do_sample=False, audio=audio)
147 |
148 |
149 | result = {"video_path": video_path.replace('/mnt/data/jiaxing.zjx/datasets/DFEC_CVPR/test/', ''), "instruction": question, "output": output}
150 | results.append(result)
151 |
152 | with open(args.answer_file, 'w', encoding='utf-8') as f:
153 | json.dump(results, f, ensure_ascii=False, indent=4)
154 |
155 | if __name__ == "__main__":
156 |
157 | parser = argparse.ArgumentParser()
158 | parser.add_argument('--model-path', required=True)
159 | parser.add_argument('--question-file', required=True)
160 | parser.add_argument('--answer-file', required=True)
161 | parser.add_argument("--device", type=str, default='cuda:0')
162 | parser.add_argument("--batch-size", type=int, default=1)
163 |
164 | args = parser.parse_args()
165 | run_inference(args)
166 |
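For reference, a hedged sketch of one entry in the `--answer-file` JSON written by `run_inference` above; every value is a placeholder:

```python
# Illustrative only: the shape of a single result entry dumped by run_inference above.
result = {
    "video_path": "example_clip.mp4",   # the dataset prefix is stripped from the absolute path
    "instruction": "Please provide a detailed description of the facial appearance attributes "
                   "and expression changes of the character in the video",
    "output": "placeholder model response",
}
```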
--------------------------------------------------------------------------------
/humanomni/eval/inference_video_mcqa_mvbench.py:
--------------------------------------------------------------------------------
1 | import os
2 | import re
3 | import math
4 | import json
5 | import argparse
6 | import warnings
7 | import traceback
8 |
9 | import torch
10 | import numpy as np
11 | from PIL import Image
12 | from tqdm import tqdm
13 | from decord import VideoReader, cpu
14 | from torch.utils.data import Dataset, DataLoader
15 | import random
16 | import sys
17 | sys.path.append('./')
18 | from humanomni import model_init, mm_infer
19 | from humanomni.utils import disable_torch_init
20 |
21 | withaudio = True
22 |
23 | # NOTE: Ignore TypedStorage warning, which refers to this link~(https://github.com/pytorch/pytorch/issues/97207#issuecomment-1494781560)
24 | warnings.filterwarnings('ignore', category=UserWarning, message='TypedStorage is deprecated')
25 |
26 | from transformers import BertModel, BertTokenizer
27 | bert_model = "bert-base-uncased"
28 | bert_tokenizer = BertTokenizer.from_pretrained(bert_model)
29 | def split_list(lst, n):
30 | """Split a list into n (roughly) equal-sized chunks"""
31 | chunk_size = math.ceil(len(lst) / n)  # ceiling division so every element is covered
32 | return [lst[i:i+chunk_size] for i in range(0, len(lst), chunk_size)]
33 |
34 |
35 | def get_chunk(lst, n, k):
36 | chunks = split_list(lst, n)
37 | return chunks[k]
38 |
39 |
40 | class MVBenchDataset(Dataset):
41 |
42 | def __init__(self, data_list, processor, processor_audio=None):
43 | self.data_list = data_list
44 | self.processor = processor
45 |
46 | self.processor_audio = processor_audio
47 |
48 | def __len__(self):
49 | return len(self.data_list)
50 |
51 | def __getitem__(self, idx):
52 |
53 | bound = (None, None)
54 | if self.data_list[idx]['bound']:
55 | bound = (self.data_list[idx]['data']['start'], self.data_list[idx]['data']['end'])
56 | video_path = os.path.join(self.data_list[idx]['prefix'], self.data_list[idx]['data']['video'])
57 | try:
58 | torch_imgs = self.processor(video_path, s=bound[0], e=bound[1])
59 | if self.processor_audio is not None:
60 | audio = self.processor_audio(video_path, s=bound[0], e=bound[1])
61 | except Exception as e:
62 | backup_idx = random.randint(0, len(self.data_list)-1)
63 | print(f"Encounted error when reading video {video_path}, use {backup_idx}-th example instead!!!")
64 | return self.__getitem__(backup_idx)
65 | torch_imgs = self.processor(video_path, s=bound[0], e=bound[1])
66 | question = self.data_list[idx]['data']['question']
67 | options = self.data_list[idx]['data']['candidates']
68 | answer = self.data_list[idx]['data']['answer']
69 | task_type = self.data_list[idx]['task_type']
70 | if self.processor_audio is not None:
71 | audio = self.processor_audio(video_path, s=bound[0], e=bound[1])
72 | else:
73 | audio = None
74 |
75 |
76 | # original multiple-choice question logic
77 | answer_idx = -1
78 | letters = []
79 | options_string = ''
80 | for option_idx, c in enumerate(options):
81 | letters.append(f"{chr(ord('A') + option_idx)}")
82 | options_string += f"({chr(ord('A') + option_idx)}) {c}\n"
83 | if c == answer:
84 | answer_idx = option_idx
85 |
86 | instruct = f'Question: {question}\nOptions:\n{options_string}Answer with the option\'s letter from the given choices directly and only give the best option.'
87 | # instruct = "Select the best answer to the following multiple-choice question based on the video. Respond with only the letter (A, B, C, or D) of the correct option.\n" + instruct
88 | if audio is not None:
89 | return {
90 | 'video': torch_imgs,
91 | 'video_path': video_path,
92 | 'instruct': instruct,
93 | 'letters': letters,
94 | 'options': options,
95 | 'answer_idx': answer_idx,
96 | 'task_type': task_type,
97 | 'question': question,
98 | 'audio': audio
99 | }
100 | else:
101 | return {
102 | 'video': torch_imgs,
103 | 'video_path': video_path,
104 | 'instruct': instruct,
105 | 'letters': letters,
106 | 'options': options,
107 | 'answer_idx': answer_idx,
108 | 'task_type': task_type,
109 | 'question': question
110 | }
111 |
112 |
113 |
114 | tasks = {
115 | "Action Sequence": ("action_sequence.json", "star/Charades_v1_480/", "video", True), # has start & end
116 | "Action Prediction": ("action_prediction.json", "star/Charades_v1_480/", "video", True), # has start & end
117 | "Action Antonym": ("action_antonym.json", "ssv2_video/", "video", False),
118 | "Fine-grained Action": ("fine_grained_action.json", "Moments_in_Time_Raw/videos/", "video", False),
119 | "Unexpected Action": ("unexpected_action.json", "FunQA_test/test/", "video", False),
120 | "Object Existence": ("object_existence.json", "clevrer/video_validation/", "video", False),
121 | "Object Interaction": ("object_interaction.json", "star/Charades_v1_480/", "video", True), # has start & end
122 | "Object Shuffle": ("object_shuffle.json", "perception/videos/", "video", False),
123 | "Moving Direction": ("moving_direction.json", "clevrer/video_validation/", "video", False),
124 | "Action Localization": ("action_localization.json", "sta/sta_video/", "video", True), # has start & end
125 | "Scene Transition": ("scene_transition.json", "scene_qa/video/", "video", False),
126 | "Action Count": ("action_count.json", "perception/videos/", "video", False),
127 | "Moving Count": ("moving_count.json", "clevrer/video_validation/", "video", False),
128 | "Moving Attribute": ("moving_attribute.json", "clevrer/video_validation/", "video", False),
129 | "State Change": ("state_change.json", "perception/videos/", "video", False),
130 | "Fine-grained Pose": ("fine_grained_pose.json", "nturgbd/", "video", False),
131 | "Character Order": ("character_order.json", "perception/videos/", "video", False),
132 | "Egocentric Navigation": ("egocentric_navigation.json", "vlnqa/", "video", False),
133 | "Episodic Reasoning": ("episodic_reasoning.json", "tvqa/frames_fps3_hq/", "frame", True), # has start & end, read frame
134 | "Counterfactual Inference": ("counterfactual_inference.json", "clevrer/video_validation/", "video", False),
135 | }
136 | # tasks = {
137 | # "Action Sequence": ("action_sequence.json", "star/Charades_v1_480/", "video", True), # has start & end
138 | # "Action Prediction": ("action_prediction.json", "star/Charades_v1_480/", "video", True), # has start & end
139 | # "Action Antonym": ("action_antonym.json", "ssv2_video/", "video", False),
140 | # "Fine-grained Action": ("fine_grained_action.json", "Moments_in_Time_Raw/videos/", "video", False),
141 | # "Object Interaction": ("object_interaction.json", "star/Charades_v1_480/", "video", True), # has start & end
142 | # "Action Localization": ("action_localization.json", "sta/sta_video/", "video", True), # has start & end
143 | # "Action Count": ("action_count.json", "perception/videos/", "video", False),
144 | # "Fine-grained Pose": ("fine_grained_pose.json", "nturgbd/", "video", False),
145 | # }
146 | # tasks = {
147 | # "Moving Direction": ("moving_direction.json", "clevrer/video_validation/", "video", False),
148 | # "Object Interaction": ("object_interaction.json", "star/Charades_v1_480/", "video", True), # has start & end
149 | # "Object Interaction_only_replace_candidates_withfix": ("object_interaction_only_replace_candidates_withfix.json", "star/Charades_v1_480/", "video", True), # has start & end
150 | # "Object Interaction_only_shuffle_order": ("object_interaction_only_shuffle_order.json", "star/Charades_v1_480/", "video", True), # has start & end
151 | # "Object Interaction_shuffle_and_replace": ("object_interaction_shuffle_and_replace.json", "star/Charades_v1_480/", "video", True), # has start & end
152 | #}
153 |
154 |
155 |
156 | def build_mvbench_eval(args, processor, processor_audio=None):
157 | data_list = []
158 | for task_name, task in tasks.items():
159 | json_file = os.path.join(args.question_file, task[0])
160 | vis_folder = os.path.join(args.video_folder, task[1])
161 | with open(json_file, 'r') as f:
162 | json_data = json.load(f)
163 | for data in json_data:
164 | data_list.append({
165 | 'task_type': task_name,
166 | 'prefix': vis_folder,
167 | 'data_type': task[2],
168 | 'bound': task[3],
169 | 'data': data
170 | })
171 | data_list = get_chunk(data_list, args.num_chunks, args.chunk_idx)
172 | dataset = MVBenchDataset(data_list, processor, processor_audio)
173 | dataloader = DataLoader(dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.num_workers)
174 |
175 | return dataloader
176 |
177 |
178 | def mvbench_dump(vid, instruct, letters, options, output):
179 |
180 | output = output.replace('answer', '')
181 | output = output.replace('Answer', '')
182 | pred_answer = re.findall(f'[\(,\ ]*[{letters[0]}-{letters[-1]}][\),\ ]*', output)
183 | try:
184 | find_flag = False
185 | if len(pred_answer) == 0:
186 | for idx, opt in enumerate(options):
187 | # fallback: check whether an option's text appears in the output
188 | if opt.lower() in output.lower():
189 | pred_idx = idx
190 | find_flag = True
191 | break
192 | else:
193 | pred_answer = pred_answer[0].strip()
194 | pred_answer = pred_answer.strip('()')
195 | pred_idx = letters.index(pred_answer)
196 | find_flag = True
197 |
198 | assert find_flag, 'The video \"{}\" instruct: \n\"{}\"\n output: \n\"{}\"\n is not in the expected format'.format(vid, instruct, output)
199 | except:
200 | traceback.print_exc()
201 | pred_idx = 2
202 |
203 | return pred_idx
204 |
205 |
206 | def run_inference(args):
207 |
208 | disable_torch_init()
209 |
210 | model, processor, tokenizer = model_init(args.model_path)
211 |
212 | answer_file = os.path.expanduser(args.answer_file)
213 | os.makedirs(os.path.dirname(answer_file), exist_ok=True)
214 | ans_file = open(answer_file, "w")
215 |
216 | if withaudio:
217 | val_loader = build_mvbench_eval(args, processor['video'], processor['audio'])
218 | else:
219 | val_loader = build_mvbench_eval(args, processor['video'])
220 |
221 | # NOTE: only support batch size 1 for now
222 | for i, line in enumerate(tqdm(val_loader)):
223 | vid = line['video_path'][0]
224 | if "qwen2vit" in args.model_path:
225 | video_tensor = line['video']
226 | else:
227 | video_tensor = line['video'][0]
228 |
229 | if withaudio:
230 | audio = line['audio'][0]
231 | else:
232 | audio = None
233 | task_type = line['task_type'][0]
234 | instruct = line['instruct'][0]
235 | question = line['question'][0]
236 |
237 | # original multiple-choice question logic
238 | letters = list(zip(*line['letters']))[0]
239 | options = list(zip(*line['options']))[0]
240 | answer_idx = line['answer_idx'][0].item()
241 |
242 | if withaudio:
243 | output = mm_infer(video_tensor, instruct, model=model, tokenizer=tokenizer, modal='video_audio', question=question,bert_tokeni=bert_tokenizer,do_sample=False, audio=audio)
244 | else:
245 | output = mm_infer(video_tensor, instruct, model=model, tokenizer=tokenizer, modal='video', question=question,bert_tokeni=bert_tokenizer,do_sample=False, audio=audio)
246 |
247 |
248 | pred_idx = mvbench_dump(vid, instruct, letters, options, output)
249 |
250 | ans_file.write(json.dumps({"vid": vid, "task_type": task_type, "pred": pred_idx, "gt": answer_idx}) + '\n')
251 |
252 | ans_file.close()
253 |
254 |
255 | if __name__ == "__main__":
256 | parser = argparse.ArgumentParser()
257 |
258 | parser.add_argument('--model-path', help='Path to the HumanOmni checkpoint.', required=True)
259 | parser.add_argument('--video-folder', help='Directory containing video files.', required=True)
260 | parser.add_argument('--question-file', help='Path to the ground truth file containing questions.', required=True)
261 | parser.add_argument('--answer-file', help='Path to the output file for saving predictions.', required=True)
262 | parser.add_argument("--num-chunks", type=int, default=1)
263 | parser.add_argument("--chunk-idx", type=int, default=0)
264 | parser.add_argument("--device", type=str, required=False, default='cuda:0')
265 | parser.add_argument("--batch-size", type=int, default=1)
266 | parser.add_argument("--num-workers", type=int, default=8)
267 | args = parser.parse_args()
268 |
269 | run_inference(args)
270 |
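To make the answer-parsing step above concrete, a small self-contained sketch of the letter extraction performed by `mvbench_dump` (toy strings, simplified error handling):

```python
# Toy illustration of mvbench_dump's extraction logic: look for an option letter
# in the model output, otherwise fall back to matching the option text.
import re

letters = ["A", "B", "C", "D"]
options = ["running", "jumping", "sitting", "waving"]
output = "The answer is (B)."

matches = re.findall(rf'[(, ]*[{letters[0]}-{letters[-1]}][), ]*', output)
if matches:
    pred_idx = letters.index(matches[0].strip().strip('()'))
else:
    # default to index 2 on failure, mirroring the original error path
    pred_idx = next((i for i, opt in enumerate(options) if opt.lower() in output.lower()), 2)
print(pred_idx)   # -> 1 (option "B")
```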
--------------------------------------------------------------------------------
/humanomni/humanomni_trainer.py:
--------------------------------------------------------------------------------
1 | # Adopted from: https://github.com/haotian-liu/LLaVA/blob/main/llava/train/llava_trainer.py
2 | import os
3 | import logging
4 | from typing import List, Optional
5 |
6 | import torch
7 | import torch.nn as nn
8 | from torch.utils.data import Sampler
9 |
10 | from transformers import Trainer
11 | from transformers.trainer import (
12 | is_sagemaker_mp_enabled,
13 | get_parameter_names,
14 | has_length,
15 | ALL_LAYERNORM_LAYERS,
16 | logger,
17 | TRAINER_STATE_NAME,
18 | )
19 |
20 | # from trl_our.trainer import DPOTrainer
21 | # from trl_our.trainer.utils import DPODataCollatorWithPadding
22 |
23 | def maybe_zero_3(param, ignore_status=False, name=None):
24 | from deepspeed import zero
25 | from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus
26 | if hasattr(param, "ds_id"):
27 | if param.ds_status == ZeroParamStatus.NOT_AVAILABLE:
28 | if not ignore_status:
29 | logging.warning(f"{name}: param.ds_status != ZeroParamStatus.NOT_AVAILABLE: {param.ds_status}")
30 | with zero.GatheredParameters([param]):
31 | param = param.data.detach().cpu().clone()
32 | else:
33 | param = param.detach().cpu().clone()
34 | return param
35 |
36 |
37 | def get_mm_adapter_state_maybe_zero_3(named_params, keys_to_match):
38 | to_return = {k: t for k, t in named_params if any(key_match in k for key_match in keys_to_match)}
39 | to_return = {k: maybe_zero_3(v, ignore_status=True, name=k).cpu() for k, v in to_return.items()}
40 | return to_return
41 |
42 |
43 | # Borrowed from peft.utils.get_peft_model_state_dict
44 | def get_peft_state_maybe_zero_3(named_params, bias):
45 | if bias == "none":
46 | to_return = {k: t for k, t in named_params if "lora_" in k}
47 | elif bias == "all":
48 | to_return = {k: t for k, t in named_params if "lora_" in k or "bias" in k}
49 | elif bias == "lora_only":
50 | to_return = {}
51 | maybe_lora_bias = {}
52 | lora_bias_names = set()
53 | for k, t in named_params:
54 | if "lora_" in k:
55 | to_return[k] = t
56 | bias_name = k.split("lora_")[0] + "bias"
57 | lora_bias_names.add(bias_name)
58 | elif "bias" in k:
59 | maybe_lora_bias[k] = t
60 | for k, t in maybe_lora_bias.items():
61 | if k in lora_bias_names:
62 | to_return[k] = t
63 | else:
64 | raise NotImplementedError
65 | to_return = {k: maybe_zero_3(v, ignore_status=True) for k, v in to_return.items()}
66 | return to_return
67 |
68 |
69 | def get_peft_state_non_lora_maybe_zero_3(named_params, require_grad_only=True):
70 | to_return = {k: t for k, t in named_params if "lora_" not in k}
71 | if require_grad_only:
72 | to_return = {k: t for k, t in to_return.items() if t.requires_grad}
73 | to_return = {k: maybe_zero_3(v, ignore_status=True).cpu() for k, v in to_return.items()}
74 | return to_return
75 |
76 |
77 | def find_all_linear_names(model):
78 | cls = torch.nn.Linear
79 | lora_module_names = set()
80 | multimodal_keywords = ['mm_projector', 'vision_tower', 'vision_resampler', 'vlm_att']
81 | for name, module in model.named_modules():
82 | if any(mm_keyword in name for mm_keyword in multimodal_keywords):
83 | continue
84 | if isinstance(module, cls):
85 | names = name.split('.')
86 | lora_module_names.add(names[0] if len(names) == 1 else names[-1])
87 |
88 | if 'lm_head' in lora_module_names: # needed for 16-bit
89 | lora_module_names.remove('lm_head')
90 | return list(lora_module_names)
91 |
92 |
93 | def safe_save_model_for_hf_trainer(trainer: Trainer,
94 | output_dir: str):
95 | """Collects the state dict and dump to disk."""
96 |
97 | if hasattr(trainer.args, "tune_mm_mlp_adapter") and trainer.args.tune_mm_mlp_adapter:
98 | check_only_save_mm_adapter_tunnable = True
99 | # only has mm_mlp_adapter and mm_vision_resampler in the tuneable parts
100 | elif hasattr(trainer.args, "mm_tunable_parts") and "language" not in trainer.args.mm_tunable_parts:
101 | check_only_save_mm_adapter_tunnable = True
102 | else:
103 | check_only_save_mm_adapter_tunnable = False
104 | trainer.accelerator.wait_for_everyone()
105 | torch.cuda.synchronize()
106 | # rank0_print(f"Only save projectors: {check_only_save_mm_adapter_tunnable}")
107 | if check_only_save_mm_adapter_tunnable:
108 | # Only save Adapter
109 | keys_to_match = ["mm_projector", "vision_resampler", "image_newline"]
110 | # if trainer.model.get_input_embeddings().weight.requires_grad and (getattr(trainer.args, "use_im_start_end", False)
111 | # or getattr(trainer.args, "use_x_start_end", False)):
112 | # keys_to_match.extend(["embed_tokens", "embed_in"])
113 |
114 | weight_to_save = get_mm_adapter_state_maybe_zero_3(trainer.model.named_parameters(), keys_to_match)
115 |
116 | if trainer.model.get_audio_tower() is not None:
117 | keys_to_match = ["audio_projector"]
118 | audio_weight_to_save = get_mm_adapter_state_maybe_zero_3(trainer.model.named_parameters(), keys_to_match)
119 |
120 |
121 | trainer.model.config.save_pretrained(output_dir)
122 |
123 | current_folder = output_dir.split("/")[-1]
124 | parent_folder = os.path.dirname(output_dir)
125 | if trainer.args.local_rank == 0 or trainer.args.local_rank == -1:
126 | if current_folder.startswith("checkpoint-"):
127 | mm_projector_folder = os.path.join(parent_folder, "mm_projector")
128 | os.makedirs(mm_projector_folder, exist_ok=True)
129 | torch.save(weight_to_save, os.path.join(mm_projector_folder, f"{current_folder}.bin"))
130 | if trainer.model.get_audio_tower() is not None:
131 | torch.save(audio_weight_to_save, os.path.join(mm_projector_folder, f"{current_folder}.bin").replace("mm_projector", "audio_projector"))
132 |
133 | else:
134 | torch.save(weight_to_save, os.path.join(output_dir, f"mm_projector.bin"))
135 | if trainer.model.get_audio_tower() is not None:
136 | torch.save(audio_weight_to_save, os.path.join(output_dir, f"audio_projector.bin"))
137 | # return
138 |
139 | if trainer.deepspeed:
140 | torch.cuda.synchronize()
141 | trainer.save_model(output_dir)
142 | return
143 |
144 | state_dict = trainer.model.state_dict()
145 | if trainer.args.should_save:
146 | cpu_state_dict = {
147 | key: value.cpu()
148 | for key, value in state_dict.items()
149 | }
150 | del state_dict
151 | trainer._save(output_dir, state_dict=cpu_state_dict) # noqa
152 |
153 |
154 | def split_to_even_chunks(indices, lengths, num_chunks):
155 | """
156 | Split a list of indices into `num_chunks` chunks of roughly equal total lengths.
157 | """
158 |
159 | if len(indices) % num_chunks != 0:
160 | return [indices[i::num_chunks] for i in range(num_chunks)]
161 |
162 | num_indices_per_chunk = len(indices) // num_chunks
163 |
164 | chunks = [[] for _ in range(num_chunks)]
165 | chunks_lengths = [0 for _ in range(num_chunks)]
166 | for index in indices:
167 | shortest_chunk = chunks_lengths.index(min(chunks_lengths))
168 | chunks[shortest_chunk].append(index)
169 | chunks_lengths[shortest_chunk] += lengths[index]
170 | if len(chunks[shortest_chunk]) == num_indices_per_chunk:
171 | chunks_lengths[shortest_chunk] = float("inf")
172 |
173 | return chunks
174 |
175 |
176 | def get_modality_length_grouped_indices(lengths, batch_size, world_size, generator=None):
177 | # We need to use torch for the random part as a distributed sampler will set the random seed for torch.
178 | assert all(l != 0 for l in lengths), "Should not have zero length."
179 | if all(l > 0 for l in lengths) or all(l < 0 for l in lengths):
180 | # all samples are in the same modality
181 | return get_length_grouped_indices(lengths, batch_size, world_size, generator=generator)
182 | mm_indices, mm_lengths = zip(*[(i, l) for i, l in enumerate(lengths) if l > 0])
183 | lang_indices, lang_lengths = zip(*[(i, -l) for i, l in enumerate(lengths) if l < 0])
184 |
185 | mm_shuffle = [mm_indices[i] for i in get_length_grouped_indices(mm_lengths, batch_size, world_size, generator=None)]
186 | lang_shuffle = [lang_indices[i] for i in get_length_grouped_indices(lang_lengths, batch_size, world_size, generator=None)]
187 | megabatch_size = world_size * batch_size
188 | mm_megabatches = [mm_shuffle[i : i + megabatch_size] for i in range(0, len(mm_shuffle), megabatch_size)]
189 | lang_megabatches = [lang_shuffle[i : i + megabatch_size] for i in range(0, len(lang_shuffle), megabatch_size)]
190 |
191 | last_mm = mm_megabatches[-1]
192 | last_lang = lang_megabatches[-1]
193 | additional_batch = last_mm + last_lang
194 | megabatches = mm_megabatches[:-1] + lang_megabatches[:-1]
195 | megabatch_indices = torch.randperm(len(megabatches), generator=generator)
196 | megabatches = [megabatches[i] for i in megabatch_indices]
197 |
198 | if len(additional_batch) > 0:
199 | megabatches.append(sorted(additional_batch))
200 |
201 | return [i for megabatch in megabatches for i in megabatch]
202 |
203 |
204 | def get_length_grouped_indices(lengths, batch_size, world_size, generator=None, merge=True):
205 | # We need to use torch for the random part as a distributed sampler will set the random seed for torch.
206 | indices = torch.randperm(len(lengths), generator=generator)
207 | megabatch_size = world_size * batch_size
208 | megabatches = [indices[i : i + megabatch_size].tolist() for i in range(0, len(lengths), megabatch_size)]
209 | megabatches = [sorted(megabatch, key=lambda i: lengths[i], reverse=True) for megabatch in megabatches]
210 | megabatches = [split_to_even_chunks(megabatch, lengths, world_size) for megabatch in megabatches]
211 |
212 | return [i for megabatch in megabatches for batch in megabatch for i in batch]
213 |
214 |
215 | class LengthGroupedSampler(Sampler):
216 | r"""
217 | Sampler that samples indices in a way that groups together features of the dataset of roughly the same length while
218 | keeping a bit of randomness.
219 | """
220 |
221 | def __init__(
222 | self,
223 | batch_size: int,
224 | world_size: int,
225 | lengths: Optional[List[int]] = None,
226 | generator=None,
227 | group_by_modality: bool = False,
228 | ):
229 | if lengths is None:
230 | raise ValueError("Lengths must be provided.")
231 |
232 | self.batch_size = batch_size
233 | self.world_size = world_size
234 | self.lengths = lengths
235 | self.generator = generator
236 | self.group_by_modality = group_by_modality
237 |
238 | def __len__(self):
239 | return len(self.lengths)
240 |
241 | def __iter__(self):
242 | if self.group_by_modality:
243 | indices = get_modality_length_grouped_indices(self.lengths, self.batch_size, self.world_size, generator=self.generator)
244 | else:
245 | indices = get_length_grouped_indices(self.lengths, self.batch_size, self.world_size, generator=self.generator)
246 | return iter(indices)
247 |
248 |
249 | class HumanOmniTrainer(Trainer):
250 |
251 | def _get_train_sampler(self) -> Optional[torch.utils.data.Sampler]:
252 | if self.train_dataset is None or not has_length(self.train_dataset):
253 | return None
254 |
255 | if self.args.group_by_modality_length:
256 | lengths = self.train_dataset.modality_lengths
257 | return LengthGroupedSampler(
258 | self.args.train_batch_size,
259 | world_size=self.args.world_size * self.args.gradient_accumulation_steps,
260 | lengths=lengths,
261 | group_by_modality=True,
262 | )
263 | else:
264 | return super()._get_train_sampler()
265 |
266 | def create_optimizer(self):
267 | """
268 | Setup the optimizer.
269 |
270 | We provide a reasonable default that works well. If you want to use something else, you can pass a tuple in the
271 | Trainer's init through `optimizers`, or subclass and override this method in a subclass.
272 | """
273 | if is_sagemaker_mp_enabled():
274 | return super().create_optimizer()
275 |
276 | opt_model = self.model
277 |
278 | if self.optimizer is None:
279 | decay_parameters = get_parameter_names(opt_model, ALL_LAYERNORM_LAYERS)
280 | decay_parameters = [name for name in decay_parameters if "bias" not in name]
281 | lr_mapper = {}
282 |
283 | if self.args.mm_projector_lr is not None:
284 | lr_mapper["mm_projector"] = self.args.mm_projector_lr
285 | lr_mapper["audio_projector"] = self.args.mm_projector_lr
286 | lr_mapper["vision_resampler"] = self.args.mm_projector_lr
287 | if len(lr_mapper) > 0:
288 | special_lr_parameters = [name for name, _ in opt_model.named_parameters() if any(module_keyword in name for module_keyword in lr_mapper)]
289 | optimizer_grouped_parameters = [
290 | {
291 | "params": [p for n, p in opt_model.named_parameters() if (n in decay_parameters and n not in special_lr_parameters and p.requires_grad)],
292 | "weight_decay": self.args.weight_decay,
293 | },
294 | {
295 | "params": [p for n, p in opt_model.named_parameters() if (n not in decay_parameters and n not in special_lr_parameters and p.requires_grad)],
296 | "weight_decay": 0.0,
297 | },
298 | ]
299 | for module_keyword, lr in lr_mapper.items():
300 | module_parameters = [name for name, _ in opt_model.named_parameters() if module_keyword in name]
301 | optimizer_grouped_parameters.extend(
302 | [
303 | {
304 | "params": [p for n, p in opt_model.named_parameters() if (n in decay_parameters and n in module_parameters and p.requires_grad)],
305 | "weight_decay": self.args.weight_decay,
306 | "lr": lr,
307 | },
308 | {
309 | "params": [p for n, p in opt_model.named_parameters() if (n not in decay_parameters and n in module_parameters and p.requires_grad)],
310 | "weight_decay": 0.0,
311 | "lr": lr,
312 | },
313 | ]
314 | )
315 | else:
316 | optimizer_grouped_parameters = [
317 | {
318 | "params": [p for n, p in opt_model.named_parameters() if (n in decay_parameters and p.requires_grad)],
319 | "weight_decay": self.args.weight_decay,
320 | },
321 | {
322 | "params": [p for n, p in opt_model.named_parameters() if (n not in decay_parameters and p.requires_grad)],
323 | "weight_decay": 0.0,
324 | },
325 | ]
326 |
327 | optimizer_cls, optimizer_kwargs = Trainer.get_optimizer_cls_and_kwargs(self.args)
328 |
329 | self.optimizer = optimizer_cls(optimizer_grouped_parameters, **optimizer_kwargs)
330 | if optimizer_cls.__name__ == "Adam8bit":
331 | import bitsandbytes
332 |
333 | manager = bitsandbytes.optim.GlobalOptimManager.get_instance()
334 |
335 | skipped = 0
336 | for module in opt_model.modules():
337 | if isinstance(module, nn.Embedding):
338 | skipped += sum({p.data_ptr(): p.numel() for p in module.parameters()}.values())
339 | logger.info(f"skipped {module}: {skipped/2**20}M params")
340 | manager.register_module_override(module, "weight", {"optim_bits": 32})
341 | logger.debug(f"bitsandbytes: will optimize {module} in fp32")
342 | logger.info(f"skipped: {skipped/2**20}M params")
343 |
344 | return self.optimizer
345 |
346 | def _save_checkpoint(self, model, trial, metrics=None):
347 | if getattr(self.args, 'tune_mm_mlp_adapter', False) or (
348 | hasattr(self.args, "mm_tunable_parts") and "language" not in self.args.mm_tunable_parts):
349 | from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR
350 | checkpoint_folder = f"{PREFIX_CHECKPOINT_DIR}-{self.state.global_step}"
351 |
352 | run_dir = self._get_output_dir(trial=trial)
353 | output_dir = os.path.join(run_dir, checkpoint_folder)
354 |
355 | # Only save Adapter
356 | keys_to_match = ['mm_projector', 'vision_resampler']
357 |
358 | weight_to_save = get_mm_adapter_state_maybe_zero_3(self.model.named_parameters(), keys_to_match)
359 |
360 | if self.args.local_rank == 0 or self.args.local_rank == -1:
361 | self.model.config.save_pretrained(output_dir)
362 | torch.save(weight_to_save, os.path.join(output_dir, f'mm_projector.bin'))
363 |
364 | if model.get_audio_tower() is not None:
365 |
366 | keys_to_match = ["audio_projector"]
367 |
368 | weight_to_save = get_mm_adapter_state_maybe_zero_3(self.model.named_parameters(), keys_to_match)
369 |
370 | if self.args.local_rank == 0 or self.args.local_rank == -1:
371 | self.model.config.save_pretrained(output_dir)
372 | torch.save(weight_to_save, os.path.join(output_dir, f"audio_projector.bin"))
373 |
374 |
375 | # Save optimizer and scheduler
376 | self._save_optimizer_and_scheduler(output_dir)
377 | # Save RNG state
378 | self._save_rng_state(output_dir)
379 | self.state.save_to_json(os.path.join(output_dir, TRAINER_STATE_NAME))
380 | self.args.distributed_state.wait_for_everyone()
381 | else:
382 | # NOTE: Supporting save complete lora checkpoint during training.
383 | if self.args.lora_enable:
384 | from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR
385 | checkpoint_folder = f"{PREFIX_CHECKPOINT_DIR}-{self.state.global_step}"
386 |
387 | run_dir = self._get_output_dir(trial=trial)
388 | output_dir = os.path.join(run_dir, checkpoint_folder)
389 |
390 | state_dict = get_peft_state_maybe_zero_3(self.model.named_parameters(), self.args.lora_bias)
391 | non_lora_state_dict = get_peft_state_non_lora_maybe_zero_3(self.model.named_parameters())
392 | if self.args.local_rank == 0 or self.args.local_rank == -1:
393 | # save for acquiring `config.json`
394 | self.model.config.save_pretrained(output_dir)
395 | # save for acquiring `adapter_config.json`, `adapter_model.bin`
396 | # self.model.save_pretrained(output_dir, state_dict=state_dict)
397 | torch.save(non_lora_state_dict, os.path.join(output_dir, 'non_lora_trainables.bin'))
398 |
399 | # save for acquiring lora adapter parameters & trainer states: `adapter_config.json`, `adapter_model.safetensors`
400 | super(HumanOmniTrainer, self)._save_checkpoint(model, trial, metrics)
401 | else:
402 | super(HumanOmniTrainer, self)._save_checkpoint(model, trial, metrics)
403 |
404 | def _save(self, output_dir: Optional[str] = None, state_dict=None):
405 | if getattr(self.args, 'tune_mm_mlp_adapter', False):
406 | pass
407 | else:
408 | super(HumanOmniTrainer, self)._save(output_dir, state_dict)
409 |
410 |
411 |
--------------------------------------------------------------------------------
/humanomni/mm_utils.py:
--------------------------------------------------------------------------------
1 | import ast
2 | import os, re  # NOTE: re is required by process_anyres_image below
3 | import math
4 | import base64
5 | import traceback
6 | from io import BytesIO
7 | import io
8 | import cv2
9 | import torch
10 | import imageio
11 | import numpy as np
12 | from PIL import Image
13 | from decord import VideoReader, cpu, AudioReader
14 | from moviepy.editor import VideoFileClip
15 | from transformers import StoppingCriteria
16 | import random
17 | from .constants import NUM_FRAMES, MAX_FRAMES, NUM_FRAMES_PER_SECOND, MODAL_INDEX_MAP, DEFAULT_IMAGE_TOKEN
18 | import concurrent.futures
19 | import ipdb
20 |
21 | def chunk_list(input_list, chunk_size):
22 | return [input_list[i:i + chunk_size] for i in range(0, len(input_list), chunk_size)]
23 |
24 |
25 | def load_image_from_base64(image):
26 | return Image.open(BytesIO(base64.b64decode(image)))
27 |
28 |
29 | def expand2square(pil_img, background_color):
30 | width, height = pil_img.size
31 | if width == height:
32 | return pil_img
33 | elif width > height:
34 | result = Image.new(pil_img.mode, (width, width), background_color)
35 | result.paste(pil_img, (0, (width - height) // 2))
36 | return result
37 | else:
38 | result = Image.new(pil_img.mode, (height, height), background_color)
39 | result.paste(pil_img, ((height - width) // 2, 0))
40 | return result
41 |
42 |
43 | def create_photo_grid(arr, rows=None, cols=None):
44 | """
45 | Create a photo grid from a 4D numpy array with shape [t, h, w, c].
46 |
47 | Parameters:
48 | arr (numpy.ndarray): Input array with shape [t, h, w, c].
49 | rows (int): Optional. Number of rows in the grid. If not set, it will be determined based on `cols` or the square root of `t`.
50 | cols (int): Optional. Number of columns in the grid. If not set, it will be determined based on `rows` or the square root of `t`.
51 |
52 | Returns:
53 | numpy.ndarray: A 3D numpy array representing the photo grid.
54 | """
55 |
56 | if isinstance(arr, list):
57 | if isinstance(arr[0], Image.Image):
58 | arr = np.stack([np.array(img) for img in arr])
59 | elif isinstance(arr[0], np.ndarray):
60 | arr = np.stack(arr)
61 | else:
62 | raise ValueError("Invalid input type. Expected list of Images or numpy arrays.")
63 |
64 | t, h, w, c = arr.shape
65 |
66 | # Calculate the number of rows and columns if not provided
67 | if rows is None and cols is None:
68 | rows = math.ceil(math.sqrt(t))
69 | cols = math.ceil(t / rows)
70 | elif rows is None:
71 | rows = math.ceil(t / cols)
72 | elif cols is None:
73 | cols = math.ceil(t / rows)
74 |
75 | # Check if the grid can hold all the images
76 | if rows * cols < t:
77 | raise ValueError(f"Not enough grid cells ({rows}x{cols}) to hold all images ({t}).")
78 |
79 | # Create the grid array with appropriate height and width
80 | grid_height = h * rows
81 | grid_width = w * cols
82 | grid = np.zeros((grid_height, grid_width, c), dtype=arr.dtype)
83 |
84 | # Fill the grid with images
85 | for i in range(t):
86 | row_idx = i // cols
87 | col_idx = i % cols
88 | grid[row_idx*h:(row_idx+1)*h, col_idx*w:(col_idx+1)*w, :] = arr[i]
89 |
90 | return grid
91 |
92 | def select_best_resolution(original_size, possible_resolutions):
93 | """
94 | Selects the best resolution from a list of possible resolutions based on the original size.
95 |
96 | Args:
97 | original_size (tuple): The original size of the image in the format (width, height).
98 | possible_resolutions (list): A list of possible resolutions in the format [(width1, height1), (width2, height2), ...].
99 |
100 | Returns:
101 | tuple: The best fit resolution in the format (width, height).
102 | """
103 | original_width, original_height = original_size
104 | best_fit = None
105 | max_effective_resolution = 0
106 | min_wasted_resolution = float("inf")
107 |
108 | for width, height in possible_resolutions:
109 | # Calculate the downscaled size to keep the aspect ratio
110 | scale = min(width / original_width, height / original_height)
111 | downscaled_width, downscaled_height = int(original_width * scale), int(original_height * scale)
112 |
113 | # Calculate effective and wasted resolutions
114 | effective_resolution = min(downscaled_width * downscaled_height, original_width * original_height)
115 | wasted_resolution = (width * height) - effective_resolution
116 |
117 | if effective_resolution > max_effective_resolution or (effective_resolution == max_effective_resolution and wasted_resolution < min_wasted_resolution):
118 | max_effective_resolution = effective_resolution
119 | min_wasted_resolution = wasted_resolution
120 | best_fit = (width, height)
121 |
122 | return best_fit
123 | def resize_and_pad_image(image, target_resolution):
124 | """
125 | Resize and pad an image to a target resolution while maintaining aspect ratio.
126 |
127 | Args:
128 | image (PIL.Image.Image): The input image.
129 | target_resolution (tuple): The target resolution (width, height) of the image.
130 |
131 | Returns:
132 | PIL.Image.Image: The resized and padded image.
133 | """
134 | original_width, original_height = image.size
135 | target_width, target_height = target_resolution
136 |
137 | # Determine which dimension (width or height) to fill
138 | scale_w = target_width / original_width
139 | scale_h = target_height / original_height
140 |
141 | if scale_w < scale_h:
142 | # Width will be filled completely
143 | new_width = target_width
144 | new_height = min(math.ceil(original_height * scale_w), target_height)
145 | else:
146 | # Height will be filled completely
147 | new_height = target_height
148 | new_width = min(math.ceil(original_width * scale_h), target_width)
149 |
150 | # Resize the image
151 | resized_image = image.resize((new_width, new_height))
152 |
153 | # Create a new image with the target size and paste the resized image onto it
154 | new_image = Image.new("RGB", (target_width, target_height), (0, 0, 0))
155 | paste_x = (target_width - new_width) // 2
156 | paste_y = (target_height - new_height) // 2
157 | new_image.paste(resized_image, (paste_x, paste_y))
158 |
159 | return new_image
160 | def divide_to_patches(image, patch_size):
161 | """
162 | Divides an image into patches of a specified size.
163 |
164 | Args:
165 | image (PIL.Image.Image): The input image.
166 | patch_size (int): The size of each patch.
167 |
168 | Returns:
169 | list: A list of PIL.Image.Image objects representing the patches.
170 | """
171 | patches = []
172 | width, height = image.size
173 | for i in range(0, height, patch_size):
174 | for j in range(0, width, patch_size):
175 | box = (j, i, j + patch_size, i + patch_size)
176 | patch = image.crop(box)
177 | patches.append(patch)
178 |
179 | return patches
180 | def process_anyres_image(image, processor, grid_pinpoints):
181 | """
182 | Process an image with variable resolutions.
183 |
184 | Args:
185 | image (PIL.Image.Image): The input image to be processed.
186 | processor: The image processor object.
187 | grid_pinpoints (str): A string representation of a list of possible resolutions.
188 |
189 | Returns:
190 | torch.Tensor: A tensor containing the processed image patches.
191 | """
192 | # Convert grid_pinpoints from string to list
193 | if isinstance(grid_pinpoints, str) and "x" in grid_pinpoints:
194 | try:
195 | patch_size = processor.size[0]
196 | except Exception as e:
197 | patch_size = processor.size["shortest_edge"] if "shortest_edge" in processor.size else processor.size["height"]
198 | assert patch_size in [224, 336, 384, 448, 512], "patch_size should be in [224, 336, 384, 448, 512]"
199 | # Use regex to extract the range from the input string
200 | matches = re.findall(r"\((\d+)x(\d+)\)", grid_pinpoints)
201 | range_start = tuple(map(int, matches[0]))
202 | range_end = tuple(map(int, matches[-1]))
203 | # Generate a matrix of tuples from (range_start[0], range_start[1]) to (range_end[0], range_end[1])
204 | grid_pinpoints = [(i, j) for i in range(range_start[0], range_end[0] + 1) for j in range(range_start[1], range_end[1] + 1)]
205 | # Multiply all elements by patch_size
206 | grid_pinpoints = [[dim * patch_size for dim in pair] for pair in grid_pinpoints]
207 |
208 | if type(grid_pinpoints) is list:
209 | possible_resolutions = grid_pinpoints
210 | else:
211 | possible_resolutions = ast.literal_eval(grid_pinpoints)
212 | # print("@@@@@@@", image.size)
213 | best_resolution = select_best_resolution(image.size, possible_resolutions)
214 |
215 | image_padded = resize_and_pad_image(image, best_resolution)
216 | # print("@@@@@", processor.size)
217 | patches = divide_to_patches(image_padded, processor.size["height"])
218 | print("image.size:", image.size, "possible_resolutions:", possible_resolutions, "best_resolution:", best_resolution, len(patches))
219 |     # FIXME: this seems to be a bug: it resizes instead of padding.
220 |     # but to keep it consistent with previous behavior, I will keep it as is
221 | # TODO: uncomment below to ablate with the padding
222 | if isinstance(processor.size, dict):
223 | shortest_edge = processor.size["shortest_edge"] if "shortest_edge" in processor.size else processor.size["height"]
224 | else:
225 | shortest_edge = min(processor.size)
226 | image_original_resize = image.resize((shortest_edge, shortest_edge))
227 |
228 | image_patches = [image_original_resize] + patches
229 |
230 | image_patches = [processor.preprocess(image_patch, return_tensors="pt")["pixel_values"][0] for image_patch in image_patches]
231 | return torch.stack(image_patches, dim=0)
232 |
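# Illustrative note (assumes a processor whose size dict reports height=384): a
# grid_pinpoints string such as "(1x1),(2x2),(3x3)" is parsed into the pairs
# (1,1)..(3,3) and scaled by the 384-pixel patch size, giving candidate
# resolutions from (384, 384) up to (1152, 1152).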
233 |
234 | def read_video_patch(patch_info, data_folder="/mnt/data/yixing.pyx/checkpoints/Oryx-SFT-DATA"):
235 | # import ipdb;ipdb.set_trace()
236 | is_image = False
237 | if 'img_path' in patch_info.keys():
238 | image = Image.open(patch_info['img_path']).convert('RGB')
239 | is_image = True
240 | else:
241 | image_file_name = os.path.join(data_folder, patch_info['patch'])
242 | start_bytes = int(patch_info['start_num'])
243 | file_size = patch_info['size'] # list of int
244 | if len(file_size) == 1:
245 | is_image = True
246 | else:
247 | is_image = False
248 | total_file_size = 0
249 | images_all = []
250 | with open(image_file_name, 'rb') as f:
251 | for idx in range(len(file_size)):
252 | f.seek(start_bytes + total_file_size)
253 | if 'image_encoding' in patch_info.keys() and patch_info['image_encoding'] == 'base64':
254 | image = Image.open(io.BytesIO(base64.b64decode(f.read(int(file_size[idx])).decode()))).convert("RGB")
255 | else:
256 | if 'sharegpt4o' in image_file_name or 'ShareGPT4Video/new_patch' in image_file_name or 'cinepile' in image_file_name or 'nextqa' in image_file_name or 'perceptiontest' in image_file_name:
257 | byte_str = io.BytesIO(f.read(int(file_size[idx])))
258 | array = np.frombuffer(byte_str.getvalue(), dtype=np.uint8)
259 | image = cv2.imdecode(array, cv2.IMREAD_COLOR)
260 | image = Image.fromarray(image)
261 | else:
262 | image = Image.open(io.BytesIO(f.read(int(file_size[idx])))).convert("RGB")
263 | images_all.append(image)
264 | total_file_size += int(file_size[idx])
265 | # import ipdb;ipdb.set_trace()
266 | return images_all, is_image
267 |
268 | def resize_with_limit(image, max_size=512):
269 | """辅助函数:限制图像最长边"""
270 | width, height = image.size
271 | if max(width, height) > max_size:
272 | scale = max_size / max(width, height)
273 | new_width = int(width * scale)
274 | new_height = int(height * scale)
275 | return image.resize((new_width, new_height), Image.Resampling.LANCZOS)
276 | return image
277 |
278 | def process_image(image_path, processor, aspect_ratio='pad'):
279 | image = Image.open(image_path).convert('RGB')
280 |
281 | images = [np.array(image)]
282 |
283 | if aspect_ratio == 'pad':
284 | images = [Image.fromarray(f) for f in images]
285 | images = [expand2square(image, tuple(int(x*255) for x in processor.image_mean)) for image in images]
286 | else:
287 | images = [Image.fromarray(f) for f in images]
288 |
289 | images = processor.preprocess(images=images, videos=None, return_tensors='pt')['pixel_values']
290 | return images
291 |
292 | def process_image_qwen(image_path, processor, aspect_ratio='pad'):
293 | image = Image.open(image_path).convert('RGB')
294 |
295 |     # Use resize_with_limit to cap the image size
296 | image = resize_with_limit(image)
297 |
298 | images = [np.array(image)]
299 |
300 | if aspect_ratio == 'pad':
301 | images = [Image.fromarray(f) for f in images]
302 | images = [expand2square(image, tuple(int(x*255) for x in processor.image_mean)) for image in images]
303 | else:
304 | images = [Image.fromarray(f) for f in images]
305 |
306 | images = processor(images=images, return_tensors='pt')
307 | return images
308 |
309 | def process_image_npary(images, processor, aspect_ratio='pad'):
310 | if images is None:
311 | return None
312 | if aspect_ratio == 'pad':
313 | images = [Image.fromarray(f) for f in images]
314 | images = [expand2square(image, tuple(int(x*255) for x in processor.image_mean)) for image in images]
315 | else:
316 | images = [Image.fromarray(f) for f in images]
317 |
318 | images = processor.preprocess(images, return_tensors='pt')['pixel_values']
319 | return images
320 |
321 | def frame_sample(duration, mode='uniform', num_frames=None, fps=None):
322 | if mode == 'uniform':
323 | assert num_frames is not None, "Number of frames must be provided for uniform sampling."
324 | # NOTE: v1 version
325 | # Calculate the size of each segment from which a frame will be extracted
326 | seg_size = float(duration - 1) / num_frames
327 |
328 | frame_ids = []
329 | for i in range(num_frames):
330 | # Calculate the start and end indices of each segment
331 | start = seg_size * i
332 | end = seg_size * (i + 1)
333 | # Append the middle index of the segment to the list
334 | frame_ids.append((start + end) / 2)
335 |
336 | return np.round(np.array(frame_ids) + 1e-6).astype(int)
337 | # NOTE: v0 version
338 | # return np.linspace(0, duration-1, num_frames, dtype=int)
339 | elif mode == 'fps':
340 | assert fps is not None, "FPS must be provided for FPS sampling."
341 | segment_len = min(fps // NUM_FRAMES_PER_SECOND, duration)
342 | return np.arange(segment_len // 2, duration, segment_len, dtype=int)
343 | else:
344 |         raise ValueError(f'Unsupported frame sampling mode: {mode}')
345 |
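# Worked example (illustrative): uniform sampling with duration=10 and num_frames=4
# uses a segment size of 2.25 and keeps the segment midpoints 1.125, 3.375, 5.625 and
# 7.875, which round to the frame indices [1, 3, 6, 8].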
346 |
347 |
348 | def process_video(video_path, processor, s=None, e=None, aspect_ratio='pad', num_frames=NUM_FRAMES):
349 | if isinstance(video_path, str):
350 | if s is not None and e is not None:
351 | s = s if s >= 0. else 0.
352 | e = e if e >= 0. else 0.
353 | if s > e:
354 | s, e = e, s
355 | elif s == e:
356 | e = s + 1
357 |
358 | # 1. Loading Video
359 | if os.path.isdir(video_path):
360 | frame_files = sorted(os.listdir(video_path))
361 |
362 | fps = 3
363 | num_frames_of_video = len(frame_files)
364 | elif video_path.endswith('.gif'):
365 | gif_reader = imageio.get_reader(video_path)
366 |
367 | fps = 25
368 | num_frames_of_video = len(gif_reader)
369 | else:
370 | vreader = VideoReader(video_path, ctx=cpu(0), num_threads=1)
371 |
372 | fps = vreader.get_avg_fps()
373 | num_frames_of_video = len(vreader)
374 |
375 |         if num_frames is not None and num_frames > 10000:
376 | num_frames = num_frames_of_video
377 | # 2. Determine frame range & Calculate frame indices
378 | f_start = 0 if s is None else max(int(s * fps) - 1, 0)
379 | f_end = num_frames_of_video - 1 if e is None else min(int(e * fps) - 1, num_frames_of_video - 1)
380 | frame_indices = list(range(f_start, f_end + 1))
381 |
382 | duration = len(frame_indices)
383 | # 3. Sampling frame indices
384 | if num_frames is None:
385 | sampled_frame_indices = [frame_indices[i] for i in frame_sample(duration, mode='fps', fps=fps)]
386 | else:
387 | sampled_frame_indices = [frame_indices[i] for i in frame_sample(duration, mode='uniform', num_frames=num_frames)]
388 |
389 |
390 | if os.path.isdir(video_path):
391 | video_data = [Image.open(os.path.join(video_path, frame_files[f_idx])) for f_idx in sampled_frame_indices]
392 | elif video_path.endswith('.gif'):
393 | video_data = [Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_RGBA2RGB)) for idx, frame in enumerate(gif_reader) if idx in sampled_frame_indices]
394 | else:
395 | video_data = [Image.fromarray(frame) for frame in vreader.get_batch(sampled_frame_indices).asnumpy()]
396 |
397 | elif isinstance(video_path, np.ndarray):
398 | video_data = [Image.fromarray(f) for f in video_path]
399 | elif isinstance(video_path, list) and isinstance(video_path[0], np.ndarray):
400 | video_data = [Image.fromarray(f) for f in video_path]
401 | elif isinstance(video_path, list) and isinstance(video_path[0], str):
402 | video_data = [Image.open(f) for f in video_path]
403 | elif isinstance(video_path, list) and isinstance(video_path[0], Image.Image):
404 | video_data = video_path
405 | else:
406 | raise ValueError(f"Unsupported video path type: {type(video_path)}")
407 | while num_frames is not None and len(video_data) < num_frames:
408 |         video_data.append(Image.fromarray(np.zeros((*video_data[-1].size[::-1], 3), dtype=np.uint8)))  # PIL size is (W, H); reverse to (H, W) for the numpy shape
409 | if aspect_ratio == 'pad':
410 | images = [expand2square(f, tuple(int(x*255) for x in processor.image_mean)) for f in video_data]
411 | video = processor.preprocess(images, return_tensors='pt')['pixel_values']
412 | else:
413 | images = [f for f in video_data]
414 | video = processor.preprocess(images, return_tensors='pt')['pixel_values']
415 | return video
416 |
417 |
418 |
419 | def process_video_qwen(video_path, processor, s=None, e=None, aspect_ratio='pad', num_frames=NUM_FRAMES):
420 | if isinstance(video_path, str):
421 | if s is not None and e is not None:
422 | s = s if s >= 0. else 0.
423 | e = e if e >= 0. else 0.
424 | if s > e:
425 | s, e = e, s
426 | elif s == e:
427 | e = s + 1
428 |
429 | # 1. Loading Video
430 | if os.path.isdir(video_path):
431 | frame_files = sorted(os.listdir(video_path))
432 | fps = 3
433 | num_frames_of_video = len(frame_files)
434 | elif video_path.endswith('.gif'):
435 | gif_reader = imageio.get_reader(video_path)
436 | fps = 25
437 | num_frames_of_video = len(gif_reader)
438 | else:
439 | vreader = VideoReader(video_path, ctx=cpu(0), num_threads=1)
440 | fps = vreader.get_avg_fps()
441 | num_frames_of_video = len(vreader)
442 |
443 |         if num_frames is not None and num_frames > 10000:
444 | num_frames = num_frames_of_video
445 | # 2. Determine frame range & Calculate frame indices
446 | f_start = 0 if s is None else max(int(s * fps) - 1, 0)
447 | f_end = num_frames_of_video - 1 if e is None else min(int(e * fps) - 1, num_frames_of_video - 1)
448 | frame_indices = list(range(f_start, f_end + 1))
449 |
450 | duration = len(frame_indices)
451 | # 3. Sampling frame indices
452 | if num_frames is None:
453 | sampled_frame_indices = [frame_indices[i] for i in frame_sample(duration, mode='fps', fps=fps)]
454 | else:
455 | sampled_frame_indices = [frame_indices[i] for i in frame_sample(duration, mode='uniform', num_frames=num_frames)]
456 |
457 |         # Load frames and apply the size limit
458 | if os.path.isdir(video_path):
459 | video_data = [resize_with_limit(Image.open(os.path.join(video_path, frame_files[f_idx]))) for f_idx in sampled_frame_indices]
460 | elif video_path.endswith('.gif'):
461 | video_data = [resize_with_limit(Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_RGBA2RGB))) for idx, frame in enumerate(gif_reader) if idx in sampled_frame_indices]
462 | else:
463 | video_data = [resize_with_limit(Image.fromarray(frame)) for frame in vreader.get_batch(sampled_frame_indices).asnumpy()]
464 |
465 | elif isinstance(video_path, np.ndarray):
466 | video_data = [resize_with_limit(Image.fromarray(f)) for f in video_path]
467 | elif isinstance(video_path, list) and isinstance(video_path[0], np.ndarray):
468 | video_data = [resize_with_limit(Image.fromarray(f)) for f in video_path]
469 | elif isinstance(video_path, list) and isinstance(video_path[0], str):
470 | video_data = [resize_with_limit(Image.open(f)) for f in video_path]
471 | elif isinstance(video_path, list) and isinstance(video_path[0], Image.Image):
472 | video_data = [resize_with_limit(f) for f in video_path]
473 | else:
474 | raise ValueError(f"Unsupported video path type: {type(video_path)}")
475 |
476 | while num_frames is not None and len(video_data) < num_frames:
477 |         video_data.append(Image.fromarray(np.zeros((*video_data[-1].size[::-1], 3), dtype=np.uint8)))  # PIL size is (W, H); reverse to (H, W) for the numpy shape
478 |
479 | if aspect_ratio == 'pad':
480 | images = [expand2square(f, tuple(int(x*255) for x in processor.image_mean)) for f in video_data]
481 | video = processor(images=None, videos=images, return_tensors='pt')
482 | else:
483 | images = [f for f in video_data]
484 | video = processor(images=None, videos=images, return_tensors='pt')
485 | return video
486 |
487 |
488 |
489 | def process_audio(audio_path, processor=None, sample_rate=16000, duration=10, s=None, e=None, return_empty=False):
490 | if return_empty:
491 | num_samples = int(duration * sample_rate)
492 | audio_data = torch.zeros(num_samples, dtype=torch.float32)
493 | if processor is not None:
494 | audio_data = processor(audio_data, sampling_rate=sample_rate, return_tensors='pt')['input_features']
495 | if torch.isnan(audio_data).any():
496 | audio_data = torch.nan_to_num(audio_data, nan=-1.5)
497 | return audio_data, processor.sampling_rate
498 | return audio_data, sample_rate
499 |
500 | try:
501 | audio_reader = AudioReader(audio_path, ctx=cpu(0), sample_rate=sample_rate)
502 | audio_data = torch.from_numpy(audio_reader._array)
503 | audio_sample_rate = audio_reader.sample_rate
504 |
505 | if torch.isnan(audio_data).any():
506 | audio_data = torch.nan_to_num(audio_data, nan=-1.5)
507 |
508 | if s is not None and e is not None:
509 | s = s if s >= 0. else 0.
510 | e = e if e >= 0. else 0.
511 | if s > e:
512 | s, e = e, s
513 | elif s == e:
514 | e = s + 1
515 |
516 | start_idx = int(s * audio_sample_rate)
517 | end_idx = int(e * audio_sample_rate)
518 | start_idx = max(0, start_idx)
519 | end_idx = min(len(audio_data), end_idx)
520 | audio_data = audio_data[start_idx:end_idx]
521 |
522 | if len(audio_data.shape) > 1:
523 | audio_data = audio_data.mean(dim=0)
524 |
525 | except Exception as e:
526 | num_samples = int(duration * sample_rate)
527 | audio_data = torch.zeros(num_samples, dtype=torch.float32)
528 | audio_sample_rate = sample_rate
529 |
530 | if processor is not None:
531 | audio_data = processor(audio_data, sampling_rate=audio_sample_rate, return_tensors='pt')['input_features']
532 | if torch.isnan(audio_data).any():
533 | audio_data = torch.nan_to_num(audio_data, nan=-1.5)
534 | audio_sample_rate = processor.sampling_rate
535 |
536 | return audio_data, audio_sample_rate
537 |
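# Editor's usage sketch (hypothetical path, assumes a WhisperFeatureExtractor):
#   features, sr = process_audio("clip.wav", processor=audio_processor, s=0.0, e=5.0)
# With a processor, `features` is the extractor's 'input_features' tensor and `sr`
# is processor.sampling_rate; without one, raw waveform samples are returned.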
538 |
539 | def tokenizer_multimodal_token(prompt, tokenizer, multimodal_token=DEFAULT_IMAGE_TOKEN, return_tensors=None):
540 | """Tokenize text and multimodal tag to input_ids.
541 | Args:
542 |         prompt (str): Text prompt (w/ multimodal tag), e.g., '<video>\nDescribe the video.'
543 | tokenizer (transformers.PreTrainedTokenizer): Tokenizer object.
544 |         multimodal_token (str): The multimodal tag string (e.g., DEFAULT_IMAGE_TOKEN) to be replaced by its token index.
545 | """
546 | # multimodal_token_index = MODAL_INDEX_MAP.get(multimodal_token, None)
547 | if multimodal_token is None or multimodal_token=="":
548 | input_ids = tokenizer(prompt, add_special_tokens=False).input_ids
549 | else:
550 | prompt_chunks = [prompt]
551 | separators = []
552 | # Split prompt by each token type
553 | for token_type, token_index in MODAL_INDEX_MAP.items():
554 | lower_token = token_type
555 | if lower_token in prompt:
556 | split_chunks = []
557 | for chunk in prompt_chunks:
558 | if isinstance(chunk, str):
559 | parts = chunk.split(lower_token)
560 | split_chunks.extend([part for sublist in zip(parts, [token_index] * len(parts)) for part in sublist][:-1])
561 | else:
562 | split_chunks.append(chunk)
563 | # split_chunks.append(parts[-1])
564 | prompt_chunks = split_chunks
565 | # Log the token index for insertion
566 | # sep_positions = [token_index for _ in range(len(parts)-1)]
567 | # separators.extend(sep_positions)
568 | encoded_chunks = [tokenizer(chunk, add_special_tokens=False).input_ids if isinstance(chunk, str) else [chunk] for chunk in prompt_chunks]
569 | # Insert tokens into encoded chunks
570 | input_ids = []
571 | for chunk in encoded_chunks:
572 | input_ids.extend(chunk)
573 |
574 | if return_tensors is not None:
575 | if return_tensors == 'pt':
576 | return torch.tensor(input_ids, dtype=torch.long)
577 | else:
578 | raise ValueError(f'Unsupported tensor type: {return_tensors}')
579 |
580 | return input_ids
581 |
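# Illustrative note (assuming "<video>" is a key of MODAL_INDEX_MAP): for a prompt like
# "<video>\nDescribe the video." the tag is replaced by its mapped index, so the returned
# ids are that index followed by the tokenizer ids of "\nDescribe the video.".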
582 |
583 |
584 | def get_model_name_from_path(model_path):
585 | model_path = model_path.strip("/")
586 | model_paths = model_path.split("/")
587 | if model_paths[-1].startswith('checkpoint-'):
588 | return model_paths[-2] + "_" + model_paths[-1]
589 | else:
590 | return model_paths[-1]
591 |
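# Example (illustrative, hypothetical paths): "/ckpts/HumanOmni-7B" -> "HumanOmni-7B",
# and "/ckpts/HumanOmni-7B/checkpoint-100" -> "HumanOmni-7B_checkpoint-100".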
592 |
593 | class KeywordsStoppingCriteria(StoppingCriteria):
594 | def __init__(self, keywords, tokenizer, input_ids):
595 | self.keywords = keywords
596 | self.keyword_ids = []
597 | self.max_keyword_len = 0
598 | for keyword in keywords:
599 | cur_keyword_ids = tokenizer(keyword).input_ids
600 | if len(cur_keyword_ids) > 1 and cur_keyword_ids[0] == tokenizer.bos_token_id:
601 | cur_keyword_ids = cur_keyword_ids[1:]
602 | if len(cur_keyword_ids) > self.max_keyword_len:
603 | self.max_keyword_len = len(cur_keyword_ids)
604 | self.keyword_ids.append(torch.tensor(cur_keyword_ids))
605 | self.tokenizer = tokenizer
606 | self.start_len = input_ids.shape[1]
607 |
608 | def call_for_batch(self, output_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
609 | offset = min(output_ids.shape[1] - self.start_len, self.max_keyword_len)
610 | self.keyword_ids = [keyword_id.to(output_ids.device) for keyword_id in self.keyword_ids]
611 | for keyword_id in self.keyword_ids:
612 | if (output_ids[0, -keyword_id.shape[0]:] == keyword_id).all():
613 | return True
614 | outputs = self.tokenizer.batch_decode(output_ids[:, -offset:], skip_special_tokens=True)[0]
615 | for keyword in self.keywords:
616 | if keyword in outputs:
617 | return True
618 | return False
619 |
620 | def __call__(self, output_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
621 | outputs = []
622 | for i in range(output_ids.shape[0]):
623 | outputs.append(self.call_for_batch(output_ids[i].unsqueeze(0), scores))
624 | return all(outputs)
625 |
--------------------------------------------------------------------------------
/humanomni/model/__init__.py:
--------------------------------------------------------------------------------
1 | # Adopted from https://github.com/haotian-liu/LLaVA. Below is the original copyright:
2 | # Copyright 2023 Haotian Liu
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 |
17 | import os
18 | import warnings
19 | import shutil
20 |
21 | import torch
22 | from transformers import PretrainedConfig, AutoTokenizer, AutoModelForCausalLM, AutoConfig, BitsAndBytesConfig
23 |
24 | from .projector import load_mm_projector
25 | from .humanomni_model import HumanOmniQwen2ForCausalLM, HumanOmniQwen2Config
26 |
27 |
28 |
29 | VLLMs = {
30 | "HumanOmni_qwen2": HumanOmniQwen2ForCausalLM,
31 | }
32 |
33 | VLLMConfigs = {
34 | "HumanOmni_qwen2": HumanOmniQwen2Config,
35 | }
36 |
37 |
38 | def load_pretrained_model(model_path, model_base, model_name, load_8bit=False, load_4bit=False, device_map="auto", device="cuda", use_flash_attn=False, **kwargs):
39 | if 'token' in kwargs:
40 | token = kwargs['token']
41 | else:
42 | token = None
43 |
44 | kwargs = {"device_map": device_map, **kwargs}
45 |
46 | if device != "cuda":
47 | kwargs['device_map'] = {"": device}
48 |
49 | if load_8bit:
50 | kwargs['load_in_8bit'] = True
51 | elif load_4bit:
52 | # NOTE: High-version Transformers will report: """ValueError: You can't pass `load_in_4bit`or `load_in_8bit` as a kwarg when passing `quantization_config` argument at the same time."""
53 | # kwargs['load_in_4bit'] = True
54 | kwargs['quantization_config'] = BitsAndBytesConfig(
55 | load_in_4bit=True,
56 | bnb_4bit_compute_dtype=torch.float16,
57 | bnb_4bit_use_double_quant=True,
58 | bnb_4bit_quant_type='nf4'
59 | )
60 | else:
61 | kwargs['torch_dtype'] = torch.float16
62 |
63 | if use_flash_attn:
64 | kwargs['attn_implementation'] = 'flash_attention_2'
65 |
66 | config = AutoConfig.from_pretrained(model_path)
67 |
68 | # judge model type
69 | model_type = config.model_type
70 |
71 | # judge pretrain/finetune
72 | try:
73 | is_pretraining = config.tune_mm_mlp_adapter
74 | except:
75 | is_pretraining = False
76 |
77 | # NOTE: SFT model loading
78 | tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False, token=token)
79 | model = HumanOmniQwen2ForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, config=config, **kwargs)
80 | processor = None
81 |
82 | if "HumanOmni" in model_type:
83 | vision_tower = model.get_vision_tower()
84 | if not vision_tower.is_loaded:
85 | vision_tower.load_model()
86 | vision_tower.to(device=device, dtype=torch.float16)
87 |         # NOTE: HumanOmni adopts the same processor for processing image and video.
88 |
89 | processor = vision_tower.image_processor
90 |
91 | if hasattr(model.config, "max_sequence_length"):
92 | context_len = model.config.max_sequence_length
93 | else:
94 | context_len = 2048
95 |
96 | if getattr(model.config, "mm_audio_tower", None):
97 | audio_tower = model.get_audio_tower()
98 | if not audio_tower.is_loaded:
99 | audio_tower.load_model()
100 | audio_tower.to(device=device, dtype=torch.float16)
101 |
102 | audio_processor = audio_tower.audio_processor
103 | return tokenizer, model, processor, context_len, audio_processor
104 | else:
105 | return tokenizer, model, processor, context_len, None
106 | # return tokenizer, model, processor, context_len
107 |
--------------------------------------------------------------------------------
/humanomni/model/encoder.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | import torch
4 | import torch.nn as nn
5 |
6 | from transformers import (
7 | CLIPVisionModel, CLIPImageProcessor, CLIPVisionConfig,
8 | SiglipVisionModel, SiglipImageProcessor, SiglipVisionConfig,
9 | WhisperFeatureExtractor, WhisperProcessor, WhisperConfig, WhisperForAudioClassification
10 | )
11 |
12 | class CLIPVisionTower(nn.Module):
13 |
14 | def __init__(self, vision_tower, args, delay_load=False):
15 | super().__init__()
16 |
17 | self.is_loaded = False
18 |
19 | self.vision_tower_name = vision_tower
20 | self.select_layer = args.mm_vision_select_layer
21 | self.select_feature = getattr(args, 'mm_vision_select_feature', 'patch')
22 |
23 | if not delay_load:
24 | self.load_model()
25 | else:
26 | self.cfg_only = CLIPVisionConfig.from_pretrained(self.vision_tower_name)
27 |
28 | def load_model(self):
29 | self.image_processor = CLIPImageProcessor.from_pretrained(self.vision_tower_name)
30 |
31 | self.vision_tower = CLIPVisionModel.from_pretrained(self.vision_tower_name)
32 | self.vision_tower.requires_grad_(False)
33 |
34 | self.is_loaded = True
35 |
36 | def feature_select(self, image_forward_outs):
37 | image_features = image_forward_outs.hidden_states[self.select_layer]
38 | if self.select_feature == 'patch':
39 | image_features = image_features[:, 1:]
40 | elif self.select_feature == 'cls_patch':
41 | image_features = image_features
42 | else:
43 | raise ValueError(f'Unexpected select feature: {self.select_feature}')
44 | return image_features
45 |
46 | @torch.no_grad()
47 | def forward(self, images):
48 | if type(images) is list:
49 | image_features = []
50 | for image in images:
51 | image_forward_out = self.vision_tower(image.to(device=self.device, dtype=self.dtype).unsqueeze(0), output_hidden_states=True)
52 | image_feature = self.feature_select(image_forward_out).to(image.dtype)
53 | image_features.append(image_feature)
54 | else:
55 | image_forward_outs = self.vision_tower(images.to(device=self.device, dtype=self.dtype), output_hidden_states=True)
56 | image_features = self.feature_select(image_forward_outs).to(images.dtype)
57 |
58 | return image_features
59 |
60 | @property
61 | def dummy_feature(self):
62 | return torch.zeros(1, self.hidden_size, device=self.device, dtype=self.dtype)
63 |
64 | @property
65 | def dtype(self):
66 | return self.vision_tower.dtype
67 |
68 | @property
69 | def device(self):
70 | return self.vision_tower.device
71 |
72 | @property
73 | def config(self):
74 | if self.is_loaded:
75 | return self.vision_tower.config
76 | else:
77 | return self.cfg_only
78 |
79 | @property
80 | def hidden_size(self):
81 | return self.config.hidden_size
82 |
83 | @property
84 | def num_patches(self):
85 | return (self.config.image_size // self.config.patch_size) ** 2
86 |
87 | @property
88 | def num_patches_per_side(self):
89 | return self.config.image_size // self.config.patch_size
90 |
91 | @property
92 | def image_size(self):
93 | return self.config.image_size
94 |
95 |
96 | class SiglipVisionTower(nn.Module):
97 |
98 | def __init__(self, vision_tower, args, delay_load=False):
99 | super().__init__()
100 |
101 | self.is_loaded = False
102 |
103 | self.vision_tower_name = vision_tower
104 | self.select_layer = args.mm_vision_select_layer
105 | self.select_feature = getattr(args, 'mm_vision_select_feature', 'patch')
106 | if not delay_load:
107 | self.load_model()
108 | else:
109 | self.cfg_only = SiglipVisionConfig.from_pretrained(self.vision_tower_name)
110 |
111 | def load_model(self):
112 | self.image_processor = SiglipImageProcessor.from_pretrained(self.vision_tower_name)
113 |
114 | self.vision_tower = SiglipVisionModel.from_pretrained(self.vision_tower_name)
115 | self.vision_tower.requires_grad_(False)
116 |
117 | self.is_loaded = True
118 |
119 | def feature_select(self, image_forward_outs):
120 | image_features = image_forward_outs.hidden_states[self.select_layer]
121 | if self.select_feature == 'patch':
122 | image_features = image_features
123 | else:
124 | raise ValueError(f'Unexpected select feature: {self.select_feature}')
125 | return image_features
126 |
127 | @torch.no_grad()
128 | def forward(self, images, raw_datas=None):
129 | if type(images) is list:
130 | image_features = []
131 | for image in images:
132 | image_forward_out = self.vision_tower(image.to(device=self.device, dtype=self.dtype).unsqueeze(0), output_hidden_states=True)
133 | image_feature = self.feature_select(image_forward_out).to(image.dtype)
134 | image_features.append(image_feature)
135 | else:
136 | image_forward_outs = self.vision_tower(images.to(device=self.device, dtype=self.dtype), output_hidden_states=True)
137 | image_features = self.feature_select(image_forward_outs).to(images.dtype)
138 |
139 | return image_features
140 |
141 | @property
142 | def dummy_feature(self):
143 | return torch.zeros(1, self.hidden_size, device=self.device, dtype=self.dtype)
144 |
145 | @property
146 | def dtype(self):
147 | return self.vision_tower.dtype
148 |
149 | @property
150 | def device(self):
151 | return self.vision_tower.device
152 |
153 | @property
154 | def config(self):
155 | if self.is_loaded:
156 | return self.vision_tower.config
157 | else:
158 | return self.cfg_only
159 |
160 | @property
161 | def hidden_size(self):
162 | return self.config.hidden_size
163 |
164 | @property
165 | def num_patches(self):
166 | return (self.config.image_size // self.config.patch_size) ** 2
167 |
168 | @property
169 | def num_patches_per_side(self):
170 | return self.config.image_size // self.config.patch_size
171 |
172 | @property
173 | def image_size(self):
174 | return self.config.image_size
175 |
176 | class WhisperAudioTower(nn.Module):
177 | def __init__(self, audio_tower, args, delay_load=False):
178 | super().__init__()
179 | self.is_loaded = False
180 | self.audio_tower_name = audio_tower
181 | self.select_layer = args.mm_vision_select_layer
182 |
183 | if not delay_load:
184 | self.load_model()
185 | elif getattr(args, "unfreeze_mm_audio_tower", False):
186 | # TODO: better detector is needed.
187 | print(f"The checkpoint seems to contain `audio_tower` weights: `unfreeze_mm_audio_tower`: True.")
188 | self.load_model()
189 | else:
190 | self.cfg_only = WhisperConfig.from_pretrained(self.audio_tower_name)
191 | def load_model(self, device_map=None):
192 | if self.is_loaded:
193 | print("{} is already loaded, `load_model` called again, skipping.".format(self.audio_tower_name))
194 | return
195 | self.audio_processor = WhisperFeatureExtractor.from_pretrained(self.audio_tower_name)
196 | self.audio_tower = WhisperForAudioClassification.from_pretrained(self.audio_tower_name)
197 | self.audio_tower.requires_grad_(False)
198 | self.is_loaded = True
199 | def feature_select(self, audio_forward_outs):
200 | audio_features = audio_forward_outs.hidden_states[self.select_layer]
201 | return audio_features
202 | def forward(self, samples):
203 | if isinstance(samples, list):
204 | audio_features = []
205 | for sample in samples:
206 |
207 | audio_forward_outs = self.audio_tower.encoder(sample, output_hidden_states=True)
208 | # audio_feature = self.feature_select(audio_forward_outs).to(audio_features.dtype)
209 | audio_features.append(audio_forward_outs.last_hidden_state)
210 | else:
211 |
212 | audio_forward_outs = self.audio_tower.encoder(samples, return_dict=True)
213 | audio_features = audio_forward_outs.last_hidden_state
214 | return audio_features
215 | @property
216 | def dummy_feature(self):
217 | return torch.zeros(1, self.hidden_size, device=self.device, dtype=self.dtype)
218 | @property
219 | def dtype(self):
220 | return self.audio_tower.dtype
221 | @property
222 | def device(self):
223 | return self.audio_tower.device
224 | @property
225 | def config(self):
226 | if self.is_loaded:
227 | return self.audio_tower.config
228 | else:
229 | return self.cfg_only
230 | @property
231 | def hidden_size(self):
232 | return self.config.hidden_size
233 |
234 |
235 | def build_vision_tower(vision_tower_cfg, **kwargs):
236 | vision_tower = getattr(vision_tower_cfg, 'mm_vision_tower', getattr(vision_tower_cfg, 'vision_tower', None))
237 |
238 | if 'clip' in vision_tower:
239 | vision_tower = CLIPVisionTower(vision_tower, args=vision_tower_cfg, **kwargs)
240 | elif 'siglip' in vision_tower:
241 | vision_tower = SiglipVisionTower(vision_tower, args=vision_tower_cfg, **kwargs)
242 | else:
243 | raise ValueError(f'Unknown vision tower: {vision_tower}')
244 |
245 | return vision_tower
246 |
247 | def build_audio_tower(audio_tower_cfg, **kwargs):
248 | audio_tower = getattr(audio_tower_cfg, 'mm_audio_tower', getattr(audio_tower_cfg, 'audio_tower', None))
249 |
250 | if "whisper" in audio_tower:
251 | return WhisperAudioTower(audio_tower, args=audio_tower_cfg, **kwargs)
252 |
--------------------------------------------------------------------------------
/humanomni/model/humanomni_arch.py:
--------------------------------------------------------------------------------
1 | # Adopted from https://github.com/haotian-liu/LLaVA. Below is the original copyright:
2 | # Copyright 2023 Haotian Liu
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | import time
16 | import os
17 | from abc import ABC, abstractmethod
18 | import math
19 | import re
20 | import einops
21 | import torch
22 | import torch.nn as nn
23 | import torch.nn.functional as F
24 | from .projector import load_mm_projector, build_vision_projector, build_audio_projector
25 | from .encoder import build_vision_tower, build_audio_tower
26 | from ..constants import IGNORE_INDEX, NUM_FRAMES, MODAL_INDEX_MAP, IMAGE_TOKEN_PATCH, MODAL_INDEX_REMAP
27 | from humanomni.mm_utils import frame_sample
28 | from transformers import BertModel, BertTokenizer
29 | import h5py
30 | import torch.distributed as dist
31 | import ipdb
32 |
33 | class SFDynamicCompressor(nn.Module):
34 | def __init__(self, model_args, vision_tower):
35 | super().__init__()
36 |
37 | self.out_channels = vision_tower.hidden_size
38 | self.mid_channel = 256
39 |
40 | self.vlm_query_projector = nn.Linear(self.out_channels, self.mid_channel)
41 | self.vlm_key_projector = nn.Linear(self.out_channels, self.mid_channel)
42 |
43 | def downsample(self, x):
44 | return F.avg_pool2d(x, 2, 2)
45 |
46 | def downsample_4(self, x):
47 | return F.avg_pool2d(x, 4, 4)
48 |
49 | def forward(self, image_features, image_size=None):
50 | if image_size is None:
51 | W = int(math.sqrt(image_features.shape[1]))
52 | H = int(W)
53 | else:
54 | H, W = image_size
55 | image_features = einops.rearrange(image_features, 't (r w) h -> t r w h', r = H)
56 | T, H, W, C = image_features.shape
57 | image_features = image_features.unsqueeze(0)
58 | B = 1
59 | fast_feature = F.avg_pool2d(image_features.permute(0, 1, 4, 2, 3).view(B*T, C, H, W), 2, 2) # B * T, C, H // 2, W //2
60 | fast_feature = fast_feature.view(B*T, C, -1)
61 | fast_feature = fast_feature.permute(0, 2, 1).view(B, T, -1, C).view(B, -1, C)
62 |
63 | index = torch.arange(1, T, 4)
64 | if len(index) == 0:
65 | index = torch.tensor([0])
66 | slow_feature = image_features[:, index, :, :, :].view(B, -1, C)
67 |
68 | final_feature = torch.cat([fast_feature, slow_feature], dim=1)
69 | return final_feature
70 |
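    # Design note (added comment): the "fast" path 2x2 average-pools the spatial tokens of
    # every frame, while the "slow" path keeps full-resolution tokens from every 4th frame
    # (starting at index 1); both token streams are concatenated along the sequence dimension.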
71 |
72 | class HumanOmniMetaModel:
73 |
74 | def __init__(self, config):
75 | super(HumanOmniMetaModel, self).__init__(config)
76 | if hasattr(config, "mm_vision_tower"):
77 | self.vision_tower = build_vision_tower(config, delay_load=True)
78 | self.mm_projector = build_vision_projector(config)
79 |
80 |
81 | # Comment out this part of the code during training to avoid repeated initialization.
82 | num_branches = 3
83 | bert_model = "bert-base-uncased"
84 | self.bert_model = BertModel.from_pretrained(bert_model)
85 | self.bert_tokenizer = BertTokenizer.from_pretrained(bert_model)
86 | modules = [nn.Linear(self.bert_model.config.hidden_size, 3584)]
87 | modules.append(nn.GELU())
88 | modules.append(nn.Linear(3584, num_branches))
89 | self.bert_gate = nn.Sequential(*modules)
90 | self.bert_softmax = nn.Softmax(dim=1)
91 | self.feature_compressor = SFDynamicCompressor(config, self.vision_tower)
92 | #####
93 |
94 | if hasattr(config, "mm_audio_tower"):
95 | self.audio_tower = build_audio_tower(config, delay_load=True)
96 | self.config.audio_hidden_size = getattr(self.audio_tower, "hidden_size", 1280)
97 | self.audio_projector = build_audio_projector(config, vision_cfg=self.audio_tower.config)
98 |
99 | def get_vision_tower(self):
100 | vision_tower = getattr(self, 'vision_tower', None)
101 | if type(vision_tower) is list:
102 | vision_tower = vision_tower[0]
103 | return vision_tower
104 |
105 | def get_audio_tower(self):
106 | audio_tower = getattr(self, "audio_tower", None)
107 | return audio_tower
108 |
109 | def initialize_vision_modules(self, model_args, fsdp=None):
110 | vision_tower = model_args.vision_tower
111 | mm_vision_select_layer = model_args.mm_vision_select_layer
112 | mm_vision_select_feature = model_args.mm_vision_select_feature
113 | pretrain_mm_mlp_adapter = model_args.pretrain_mm_mlp_adapter
114 |
115 | self.config.mm_vision_tower = vision_tower
116 |
117 | if self.get_vision_tower() is None:
118 | vision_tower = build_vision_tower(model_args)
119 |
120 | if fsdp is not None and len(fsdp) > 0:
121 | self.vision_tower = [vision_tower]
122 | else:
123 | self.vision_tower = vision_tower
124 | else:
125 | if fsdp is not None and len(fsdp) > 0:
126 | vision_tower = self.vision_tower[0]
127 | else:
128 | vision_tower = self.vision_tower
129 | vision_tower.load_model()
130 |
131 | self.config.use_mm_proj = True
132 | self.config.mm_projector_type = getattr(model_args, 'mm_projector_type', 'linear')
133 | self.config.mm_hidden_size = vision_tower.hidden_size
134 | self.config.mm_vision_select_layer = mm_vision_select_layer
135 | self.config.mm_vision_select_feature = mm_vision_select_feature
136 |
137 | if getattr(self, 'mm_projector', None) is None:
138 | self.mm_projector = build_vision_projector(self.config)
139 | else:
140 | # In case it is frozen by LoRA
141 | for p in self.mm_projector.parameters():
142 | p.requires_grad = True
143 |
144 | if model_args.audio_tower:
145 | self.initialize_audio_modules(model_args, fsdp)
146 |
147 | if pretrain_mm_mlp_adapter is not None:
148 | if os.path.exists(pretrain_mm_mlp_adapter):
149 | is_local = True
150 | if os.path.isdir(pretrain_mm_mlp_adapter):
151 | mm_projector_weights = load_mm_projector(pretrain_mm_mlp_adapter)
152 | else:
153 | mm_projector_weights = torch.load(pretrain_mm_mlp_adapter, map_location='cpu')
154 | else:
155 | # Support loading projector weights from remote HuggingFace model hub
156 | is_local = False
157 | pretrain_mm_mlp_adapter = pretrain_mm_mlp_adapter.replace('mm_projector.bin', '')
158 | pretrain_mm_mlp_adapter = pretrain_mm_mlp_adapter.strip('/').strip('\\').strip()
159 | mm_projector_weights = load_mm_projector(pretrain_mm_mlp_adapter)
160 |
161 | def get_w(weights, keyword):
162 | return {k.split(keyword + '.')[1]: v for k, v in weights.items() if keyword in k}
163 |
164 | # self.mm_projector.load_state_dict(get_w(mm_projector_weights, 'mm_projector'))
165 |             # NOTE: strict=False would avoid a missing-key error for bert.embeddings.position_ids; strict=True is used here.
166 |
167 | self.mm_projector.load_state_dict(get_w(mm_projector_weights, 'mm_projector'), strict=True)
168 |
169 |
170 | self.feature_compressor = SFDynamicCompressor(model_args, vision_tower)
171 | num_branches = 3
172 | bert_model = "bert-base-uncased"
173 | self.bert_model = BertModel.from_pretrained(bert_model)
174 | self.bert_tokenizer = BertTokenizer.from_pretrained(bert_model)
175 | # self.bert_gate = nn.Linear(self.bert_model.config.hidden_size, num_branches)
176 | modules = [nn.Linear(self.bert_model.config.hidden_size, 3584)]
177 | modules.append(nn.GELU())
178 | modules.append(nn.Linear(3584, num_branches))
179 | self.bert_gate = nn.Sequential(*modules)
180 | self.bert_softmax = nn.Softmax(dim=1)
181 |
182 |
183 | def initialize_audio_modules(self, model_args, fsdp=None):
184 | audio_tower = model_args.audio_tower
185 | pretrain_audio_mlp_adapter = model_args.pretrain_audio_mlp_adapter
186 | self.config.mm_audio_tower = audio_tower
187 | self.config.mm_audio_projector_type = getattr(model_args, "mm_audio_projector_type", "mlp2x_gelu")
188 | if self.get_audio_tower() is None:
189 | audio_tower = build_audio_tower(model_args)
190 |
191 | if fsdp is not None and len(fsdp) > 0:
192 | self.audio_tower = [audio_tower]
193 | else:
194 | self.audio_tower = audio_tower
195 | else:
196 | if fsdp is not None and len(fsdp) > 0:
197 | audio_tower = self.audio_tower[0]
198 | else:
199 | audio_tower = self.audio_tower
200 | audio_tower.load_model()
201 | self.config.audio_hidden_size = getattr(audio_tower, "hidden_size", 1280)
202 | if getattr(self, "audio_projector", None) is None:
203 | self.audio_projector = build_audio_projector(self.config, vision_cfg=audio_tower.config)
204 | else:
205 | # In case it is frozen by LoRA
206 | for p in self.audio_projector.parameters():
207 | p.requires_grad = True
208 | def get_w(weights, keyword):
209 | return {k.split(keyword + '.')[1]: v for k, v in weights.items() if keyword in k}
210 | if pretrain_audio_mlp_adapter is not None:
211 | audio_projector_weights = torch.load(pretrain_audio_mlp_adapter, map_location="cpu")
212 | # # import pdb; pdb.set_trace()
213 | incompatible_keys = self.audio_projector.load_state_dict(get_w(audio_projector_weights, "audio_projector"))
214 | print(f"load audio projector: {incompatible_keys}")
215 | num_trainable_parameters = sum(p.numel() for p in self.audio_projector.parameters() if p.requires_grad) / 1e6
216 | print(f"Number of trainable parameters in audio projector: {num_trainable_parameters}M")
217 |
218 |
219 | class HumanOmniMetaForCausalLM(ABC):
220 |
221 | @abstractmethod
222 | def get_model(self):
223 | pass
224 |
225 | def num_frames(self):
226 | if hasattr(self.config, 'num_frames'):
227 | return self.config.num_frames
228 | else:
229 | return NUM_FRAMES
230 |
231 | def get_vision_tower(self):
232 | return self.get_model().get_vision_tower()
233 |
234 | def get_audio_tower(self):
235 | return self.get_model().get_audio_tower()
236 |
237 | def get_2dPool(self, image_feature, stride=2):
238 | height = width = self.get_vision_tower().num_patches_per_side
239 | num_frames, num_tokens, num_dim = image_feature.shape
240 | image_feature = image_feature.view(num_frames, height, width, -1)
241 | image_feature = image_feature.permute(0, 3, 1, 2).contiguous()
242 |
243 |         height, width = image_feature.shape[2:]
244 |         scaled_shape = [math.ceil(height / stride), math.ceil(width / stride)]
245 | image_feature = nn.functional.interpolate(image_feature, size=scaled_shape, mode='bilinear')
246 |
247 | image_feature = image_feature.permute(0, 2, 3, 1)
248 | image_feature = image_feature.view(num_frames, -1, num_dim)
249 | return image_feature
250 |
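    # Illustrative note: with the default stride of 2, a frame whose vision tokens form a
    # 27x27 grid (e.g. a 384-px SigLIP tower with 14-px patches) is bilinearly resized to
    # ceil(27/2) x ceil(27/2) = 14x14 = 196 tokens per frame.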
251 | def encode_images_or_videos(self, images, device=None,prompts=None):
252 |
253 | num_frames = self.config.num_frames if hasattr(self.config, 'num_frames') else NUM_FRAMES
254 | current_device = torch.cuda.current_device()
255 |
256 | data_batch = []
257 | video_idx_in_batch = []
258 | for i, (data, modal) in enumerate(images):
259 | data = data
260 | video_idx_in_batch.append(i)
261 | data_batch.append(data)
262 | batch_size = len(data_batch)
263 | split_sizes = [image.shape[0] for image in data_batch]
264 | frames = torch.cat([image for image in data_batch], dim=0)
265 |         # encode all frames with the vision tower
266 | frames_features = self.get_model().get_vision_tower()(frames)
267 | video_features = einops.rearrange(frames_features, '(b t) n h -> b t n h', b = batch_size)
268 | body_features = video_features
269 | face_features = frames_features
270 | video_features, body_features, face_features = self.get_model().mm_projector(video_features, body_features, face_features)
271 | face_features = einops.rearrange(face_features, '(b t) n h -> b t n h', b = batch_size)
272 |
273 | inputs_bert = prompts
274 | # Get BERT features
275 | outputs_bert = self.get_model().bert_model(**inputs_bert)
276 | last_hidden_state_bert = outputs_bert.last_hidden_state
277 | # Use [CLS] token representation
278 | cls_token_embedding_bert = last_hidden_state_bert[:, 0, :]
279 | # Calculate branch probabilities
280 | logits = self.get_model().bert_gate(cls_token_embedding_bert)
281 | branch_probs = self.get_model().bert_softmax(logits)
282 |
283 | image_features = []
284 | for idx, image_feat in enumerate(face_features):
285 | if idx in video_idx_in_batch:
286 | image_features.append(self.get_2dPool(image_feat))
287 | else:
288 | image_features.append(image_feat)
289 | face_features = image_features
290 |
291 | new_image_features = []
292 | for image_idx, face_feature in enumerate(face_features):
293 | video_feature = video_features[image_idx]
294 | body_feature = body_features[image_idx]
295 | if image_idx in video_idx_in_batch: # video operations
296 | face_feature = face_feature.flatten(0, 1)
297 | image_feature = video_feature * branch_probs[image_idx][0] + body_feature * branch_probs[image_idx][1] + face_feature * branch_probs[image_idx][2]
298 |                 ### Uncomment if the slow-fast branch is used
299 | image_feature = einops.rearrange(image_feature, '(t n) h -> t n h', t = num_frames)
300 | image_feature = self.get_model().feature_compressor(image_feature)
301 | new_image_features.append(image_feature)
302 |
303 | return new_image_features
304 |
305 |
306 | def encode_audios(self, audios):
307 | audio_features = self.get_model().get_audio_tower()(audios).permute(0, 2, 1).contiguous() #b, t, c -> b, c, t # torch.Size([1, 1280, 1500])
308 | audio_features = torch.nn.functional.avg_pool1d(audio_features, kernel_size=3, stride=3).permute(0, 2, 1).contiguous() # torch.Size([1, 1280, 500])
309 | audio_features = self.get_model().audio_projector(audio_features)
310 | return audio_features
311 |
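    # Shape note (added comment, sizes follow the inline comments above): a 30 s Whisper
    # mel input gives roughly (B, 1500, 1280) encoder states; stride-3 average pooling
    # shortens this to (B, 500, 1280) before the audio projector maps it to the LLM width.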
312 |
313 | def prepare_inputs_labels_for_multimodal(
314 | self, input_ids, attention_mask, past_key_values, labels, images, prompts=None,audios=None
315 | ):
316 |
317 | if audios is not None:
318 | if len(audios.shape) == 4 and audios.shape[0] == 1:
319 |                 audios = audios.squeeze(0)  # remove the leading dimension
320 | vision_tower = self.get_vision_tower()
321 | audio_tower = self.get_audio_tower()
322 | # NOTE: text-only situation
323 | if vision_tower is None or images is None or input_ids.shape[1] == 1:
324 | return input_ids, attention_mask, past_key_values, None, labels
325 | device_ = input_ids.device
326 | mm_features = self.encode_images_or_videos(images ,device_,prompts)
327 |
328 | if audios is not None and audio_tower is not None:
329 | audio_features = self.encode_audios(audios)
330 | new_input_embeds = []
331 | new_labels = [] if labels is not None else None
332 | cur_mm_idx = 0
333 | for batch_idx, cur_input_ids in enumerate(input_ids):
334 | num_multimodals = sum((cur_input_ids == mm_token_idx).sum() for mm_token_idx in MODAL_INDEX_MAP.values())
335 | # pure text input
336 | if num_multimodals == 0:
337 | half_len = cur_input_ids.shape[0] // 2
338 | cur_mm_features = mm_features[cur_mm_idx]
339 | cur_input_embeds_1 = self.get_model().embed_tokens(cur_input_ids[:half_len])
340 | cur_input_embeds_2 = self.get_model().embed_tokens(cur_input_ids[half_len:])
341 | cur_input_embeds = torch.cat([cur_input_embeds_1, cur_mm_features[0:0], cur_input_embeds_2], dim=0)
342 | new_input_embeds.append(cur_input_embeds)
343 | if labels is not None:
344 | new_labels.append(labels[batch_idx])
345 | cur_mm_idx += 1
346 | continue
347 |
348 | cur_new_input_embeds = []
349 | if labels is not None:
350 | cur_labels = labels[batch_idx]
351 | cur_new_labels = []
352 | assert cur_labels.shape == cur_input_ids.shape
353 |
354 | mm_token_indices = torch.where(sum([cur_input_ids == mm_token_idx for mm_token_idx in MODAL_INDEX_MAP.values()]))[0]
355 | while mm_token_indices.numel() > 0:
356 | mm_token_start = mm_token_indices[0]
357 | cur_modal = MODAL_INDEX_REMAP[cur_input_ids[mm_token_start].item()]
358 |                 if cur_modal in ["<image>", "<video>"]:
359 | cur_mm_idx += 1
360 | cur_mm_features = mm_features[batch_idx]
361 | if len(cur_mm_features.size())==3:
362 | cur_mm_features=cur_mm_features.flatten(0,1)
363 |                 elif cur_modal in ["<audio>"] and audio_tower is not None:
364 | cur_mm_features = audio_features[batch_idx]
365 |
366 | cur_new_input_embeds.append(self.get_model().embed_tokens(cur_input_ids[:mm_token_start]))
367 | cur_new_input_embeds.append(cur_mm_features)
368 | if labels is not None:
369 | cur_new_labels.append(cur_labels[:mm_token_start])
370 | cur_new_labels.append(torch.full((cur_mm_features.shape[0],), IGNORE_INDEX, device=labels.device, dtype=labels.dtype))
371 | cur_labels = cur_labels[mm_token_start+1:]
372 |
373 | cur_input_ids = cur_input_ids[mm_token_start+1:]
374 | mm_token_indices = torch.where(sum([cur_input_ids == mm_token_idx for mm_token_idx in MODAL_INDEX_MAP.values()]))[0]
375 |
376 | if cur_input_ids.numel() > 0:
377 | cur_new_input_embeds.append(self.get_model().embed_tokens(cur_input_ids))
378 | if labels is not None:
379 | cur_new_labels.append(cur_labels)
380 | cur_new_input_embeds = [x.to(device=self.device) for x in cur_new_input_embeds]
381 | cur_new_input_embeds = torch.cat(cur_new_input_embeds, dim=0)
382 | new_input_embeds.append(cur_new_input_embeds)
383 | if labels is not None:
384 | cur_new_labels = torch.cat(cur_new_labels, dim=0)
385 | new_labels.append(cur_new_labels)
386 |
387 | # padding
388 | if any(x.shape != new_input_embeds[0].shape for x in new_input_embeds):
389 | max_len = max(x.shape[0] for x in new_input_embeds)
390 |
391 | new_input_embeds_align = []
392 | for cur_new_embed in new_input_embeds:
393 | cur_new_embed = torch.cat((cur_new_embed, torch.zeros((max_len - cur_new_embed.shape[0], cur_new_embed.shape[1]), dtype=cur_new_embed.dtype, device=cur_new_embed.device)), dim=0)
394 | new_input_embeds_align.append(cur_new_embed)
395 | new_input_embeds = torch.stack(new_input_embeds_align, dim=0)
396 |
397 | if labels is not None:
398 | new_labels_align = []
399 | _new_labels = new_labels
400 | for cur_new_label in new_labels:
401 | cur_new_label = torch.cat((cur_new_label, torch.full((max_len - cur_new_label.shape[0],), IGNORE_INDEX, dtype=cur_new_label.dtype, device=cur_new_label.device)), dim=0)
402 | new_labels_align.append(cur_new_label)
403 | new_labels = torch.stack(new_labels_align, dim=0)
404 |
405 | if attention_mask is not None:
406 | new_attention_mask = []
407 | for cur_attention_mask, cur_new_labels, cur_new_labels_align in zip(attention_mask, _new_labels, new_labels):
408 | new_attn_mask_pad_left = torch.full((cur_new_labels.shape[0] - labels.shape[1],), True, dtype=attention_mask.dtype, device=attention_mask.device)
409 | new_attn_mask_pad_right = torch.full((cur_new_labels_align.shape[0] - cur_new_labels.shape[0],), False, dtype=attention_mask.dtype, device=attention_mask.device)
410 | cur_new_attention_mask = torch.cat((new_attn_mask_pad_left, cur_attention_mask, new_attn_mask_pad_right), dim=0)
411 | new_attention_mask.append(cur_new_attention_mask)
412 | attention_mask = torch.stack(new_attention_mask, dim=0)
413 | assert attention_mask.shape == new_labels.shape
414 | else:
415 | new_input_embeds = torch.stack(new_input_embeds, dim=0)
416 | if labels is not None:
417 | new_labels = torch.stack(new_labels, dim=0)
418 |
419 | if attention_mask is not None:
420 | new_attn_mask_pad_left = torch.full((attention_mask.shape[0], new_input_embeds.shape[1] - input_ids.shape[1]), True, dtype=attention_mask.dtype, device=attention_mask.device)
421 | attention_mask = torch.cat((new_attn_mask_pad_left, attention_mask), dim=1)
422 | assert attention_mask.shape == new_input_embeds.shape[:2]
423 | return None, attention_mask, past_key_values, new_input_embeds, new_labels
424 |
--------------------------------------------------------------------------------
/humanomni/model/humanomni_model.py:
--------------------------------------------------------------------------------
1 | # Adopted from: https://github.com/haotian-liu/LLaVA. Below is the original copyright:
2 | # Copyright 2023 Haotian Liu
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 |
17 | from typing import List, Optional, Tuple, Union
18 |
19 | import torch
20 | import torch.nn as nn
21 |
22 | from transformers import AutoConfig, AutoModelForCausalLM, \
23 | Qwen2Config, Qwen2Model, Qwen2ForCausalLM
24 | from transformers.modeling_outputs import CausalLMOutputWithPast
25 | from transformers.generation.utils import GenerateOutput
26 |
27 | from .humanomni_arch import HumanOmniMetaModel, HumanOmniMetaForCausalLM
28 | from torch.nn import CrossEntropyLoss
29 |
30 | class HumanOmniQwen2Config(Qwen2Config):
31 | model_type = "HumanOmni_qwen2"
32 |
33 | def __init__(self, **kwargs):
34 | super().__init__(**kwargs)
35 | self.model_type = "HumanOmni_qwen2"
36 |
37 |
38 | class HumanOmniQwen2Model(HumanOmniMetaModel, Qwen2Model):
39 | config_class = HumanOmniQwen2Config
40 |
41 | def __init__(self, config: HumanOmniQwen2Config):
42 | super(HumanOmniQwen2Model, self).__init__(config)
43 |
44 |
45 | class HumanOmniQwen2ForCausalLM(Qwen2ForCausalLM,HumanOmniMetaForCausalLM):
46 | config_class = HumanOmniQwen2Config
47 |
48 | def __init__(self, config, **kwargs):
49 | super(Qwen2ForCausalLM, self).__init__(config)
50 | self.model = HumanOmniQwen2Model(config)
51 | self.vocab_size = config.vocab_size
52 | self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
53 |
54 | # Initialize weights and apply final processing
55 | self.post_init()
56 |
57 | def get_model(self):
58 | return self.model
59 |
60 | def forward(
61 | self,
62 | input_ids: torch.LongTensor = None,
63 | attention_mask: Optional[torch.Tensor] = None,
64 | position_ids: Optional[torch.LongTensor] = None,
65 | past_key_values: Optional[List[torch.FloatTensor]] = None,
66 | inputs_embeds: Optional[torch.FloatTensor] = None,
67 | labels: Optional[torch.LongTensor] = None,
68 | use_cache: Optional[bool] = None,
69 | output_attentions: Optional[bool] = None,
70 | output_hidden_states: Optional[bool] = None,
71 | images: Optional[torch.FloatTensor] = None,
72 | return_dict: Optional[bool] = None,
73 | cache_position: Optional[int] = None,
74 | prompts: Optional[List[str]] = None,
75 | audios: Optional[torch.FloatTensor] = None,
76 | **kwargs
77 | ) -> Union[Tuple, CausalLMOutputWithPast]:
78 | # audios=kwargs.get('audios', None)
79 | if inputs_embeds is None:
80 |
81 | (
82 | input_ids,
83 | attention_mask,
84 | past_key_values,
85 | inputs_embeds,
86 | labels
87 | ) = self.prepare_inputs_labels_for_multimodal(
88 | input_ids,
89 | attention_mask,
90 | past_key_values,
91 | labels,
92 | images,
93 | prompts=prompts,
94 | audios=audios
95 | )
96 |
97 |
98 | outputs = super().forward(
99 | input_ids=input_ids,
100 | attention_mask=attention_mask,
101 | past_key_values=past_key_values,
102 | inputs_embeds=inputs_embeds,
103 | labels=labels,
104 | use_cache=use_cache,
105 | output_attentions=output_attentions,
106 | output_hidden_states=output_hidden_states,
107 | return_dict=return_dict,
108 | cache_position=cache_position,
109 | )
110 |
111 | outputs.labels = labels
112 | return outputs
113 |
114 |
115 |
116 |
117 | @torch.no_grad()
118 | def generate(
119 | self,
120 | inputs: Optional[torch.Tensor] = None,
121 | images: Optional[torch.Tensor] = None,
122 | audios: Optional[torch.Tensor] = None,
123 | **kwargs,
124 | ) -> Union[GenerateOutput, torch.LongTensor]:
125 | position_ids = kwargs.pop("position_ids", None)
126 | attention_mask = kwargs.pop("attention_mask", None)
127 | prompts = kwargs.pop("prompts", None)
128 | face_videos = kwargs.pop("face_videos", None)
129 | body_videos = kwargs.pop("body_videos", None)
130 | if "inputs_embeds" in kwargs:
131 | raise NotImplementedError("`inputs_embeds` is not supported")
132 |
133 | if images is not None:
134 | if face_videos is None:
135 | (
136 | input_ids,
137 | attention_mask,
138 | past_key_values,
139 | inputs_embeds,
140 | _
141 | ) = self.prepare_inputs_labels_for_multimodal(
142 | input_ids=inputs,
143 | attention_mask=attention_mask,
144 | past_key_values=None,
145 | labels=None,
146 | images=images,
147 | prompts=prompts,
148 | audios=audios
149 | )
150 | else:
151 | (
152 | input_ids,
153 | attention_mask,
154 | past_key_values,
155 | inputs_embeds,
156 | _
157 | ) = self.prepare_inputs_labels_for_multimodal(
158 | input_ids=inputs,
159 | attention_mask=attention_mask,
160 | past_key_values=None,
161 | labels=None,
162 | images=images,
163 | prompts=prompts,
164 | face_videos=face_videos,
165 | body_videos=body_videos,
166 | audios=audios
167 | )
168 | else:
169 | inputs_embeds = self.get_model().embed_tokens(inputs)
170 |
171 | return super().generate(
172 | position_ids=position_ids,
173 | attention_mask=attention_mask,
174 | inputs_embeds=inputs_embeds,
175 | **kwargs
176 | )
177 |
178 | def prepare_inputs_for_generation(self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs):
179 | images = kwargs.pop("images", None)
180 | _inputs = super().prepare_inputs_for_generation(
181 | input_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, **kwargs
182 | )
183 | if images is not None:
184 | _inputs['images'] = images
185 | return _inputs
186 |
187 |
188 | AutoConfig.register("HumanOmni_qwen2", HumanOmniQwen2Config)
189 | AutoModelForCausalLM.register(HumanOmniQwen2Config, HumanOmniQwen2ForCausalLM)
190 |
--------------------------------------------------------------------------------
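The `register()` calls at the bottom of `humanomni_model.py` make the custom `HumanOmni_qwen2` model type visible to the `transformers` Auto classes. A minimal loading sketch, assuming a local checkpoint whose config carries the vision/audio tower settings (the path below is illustrative):

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Importing the module executes the AutoConfig/AutoModelForCausalLM register() calls above.
from humanomni.model.humanomni_model import HumanOmniQwen2ForCausalLM  # noqa: F401

ckpt = "./HumanOmni_7B"  # illustrative local checkpoint directory
model = AutoModelForCausalLM.from_pretrained(ckpt, torch_dtype=torch.bfloat16).eval()
tokenizer = AutoTokenizer.from_pretrained(ckpt)
```

In practice the repository's `model_init` helper (see `inference.py`) is the more convenient entry point, since it also returns the video/audio processors.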
/humanomni/train_flash_attn.py:
--------------------------------------------------------------------------------
1 | # Adapted from https://github.com/haotian-liu/LLaVA. Below is the original copyright:
2 | # Adapted from https://github.com/lm-sys/FastChat. Below is the original copyright:
3 | # Adapted from tatsu-lab@stanford_alpaca. Below is the original copyright:
4 | # Entry point that enables memory-efficient training with FlashAttention-2 via the `attn_implementation` flag.
5 |
6 | import sys
7 | sys.path.append('./')
8 |
9 | from humanomni.train_humanomni import train
10 |
11 |
12 | if __name__ == "__main__":
13 | train(attn_implementation="flash_attention_2")
14 |
--------------------------------------------------------------------------------
/humanomni/utils.py:
--------------------------------------------------------------------------------
1 | import datetime
2 | import logging
3 | import logging.handlers
4 | import os
5 | import sys
6 |
7 | import requests
8 |
9 | from .constants import LOGDIR
10 |
11 | server_error_msg = "**NETWORK ERROR DUE TO HIGH TRAFFIC. PLEASE REGENERATE OR REFRESH THIS PAGE.**"
12 | moderation_msg = "YOUR INPUT VIOLATES OUR CONTENT MODERATION GUIDELINES. PLEASE TRY AGAIN."
13 |
14 | handler = None
15 |
16 |
17 | def build_logger(logger_name, logger_filename):
18 | global handler
19 |
20 | formatter = logging.Formatter(
21 | fmt="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
22 | datefmt="%Y-%m-%d %H:%M:%S",
23 | )
24 |
25 | # Set the format of root handlers
26 | if not logging.getLogger().handlers:
27 | logging.basicConfig(level=logging.INFO)
28 | logging.getLogger().handlers[0].setFormatter(formatter)
29 |
30 | # Redirect stdout and stderr to loggers
31 | stdout_logger = logging.getLogger("stdout")
32 | stdout_logger.setLevel(logging.INFO)
33 | sl = StreamToLogger(stdout_logger, logging.INFO)
34 | sys.stdout = sl
35 |
36 | stderr_logger = logging.getLogger("stderr")
37 | stderr_logger.setLevel(logging.ERROR)
38 | sl = StreamToLogger(stderr_logger, logging.ERROR)
39 | sys.stderr = sl
40 |
41 | # Get logger
42 | logger = logging.getLogger(logger_name)
43 | logger.setLevel(logging.INFO)
44 |
45 | # Add a file handler for all loggers
46 | if handler is None:
47 | os.makedirs(LOGDIR, exist_ok=True)
48 | filename = os.path.join(LOGDIR, logger_filename)
49 | handler = logging.handlers.TimedRotatingFileHandler(
50 | filename, when='D', utc=True, encoding='UTF-8')
51 | handler.setFormatter(formatter)
52 |
53 | for name, item in logging.root.manager.loggerDict.items():
54 | if isinstance(item, logging.Logger):
55 | item.addHandler(handler)
56 |
57 | return logger
58 |
59 |
60 | class StreamToLogger(object):
61 | """
62 | Fake file-like stream object that redirects writes to a logger instance.
63 | """
64 | def __init__(self, logger, log_level=logging.INFO):
65 | self.terminal = sys.stdout
66 | self.logger = logger
67 | self.log_level = log_level
68 | self.linebuf = ''
69 |
70 | def __getattr__(self, attr):
71 | return getattr(self.terminal, attr)
72 |
73 | def write(self, buf):
74 | temp_linebuf = self.linebuf + buf
75 | self.linebuf = ''
76 | for line in temp_linebuf.splitlines(True):
77 | # From the io.TextIOWrapper docs:
78 | # On output, if newline is None, any '\n' characters written
79 | # are translated to the system default line separator.
80 | # By default sys.stdout.write() expects '\n' newlines and then
81 | # translates them so this is still cross platform.
82 | if line[-1] == '\n':
83 | self.logger.log(self.log_level, line.rstrip())
84 | else:
85 | self.linebuf += line
86 |
87 | def flush(self):
88 | if self.linebuf != '':
89 | self.logger.log(self.log_level, self.linebuf.rstrip())
90 | self.linebuf = ''
91 |
92 |
93 | def disable_torch_init():
94 | """
95 | Disable the redundant torch default initialization to accelerate model creation.
96 | """
97 | import torch
98 | setattr(torch.nn.Linear, "reset_parameters", lambda self: None)
99 | setattr(torch.nn.LayerNorm, "reset_parameters", lambda self: None)
100 |
101 |
102 | def violates_moderation(text):
103 | """
104 | Check whether the text violates OpenAI moderation API.
105 | """
106 | url = "https://api.openai.com/v1/moderations"
107 | headers = {"Content-Type": "application/json",
108 | "Authorization": "Bearer " + os.environ["OPENAI_API_KEY"]}
109 |     import json
110 |     text = text.replace("\n", "")
111 |     data = json.dumps({"input": text}).encode("utf-8")  # escapes quotes instead of raw string concatenation
112 | try:
113 | ret = requests.post(url, headers=headers, data=data, timeout=5)
114 | flagged = ret.json()["results"][0]["flagged"]
115 |     except requests.exceptions.RequestException:
116 |         flagged = False
117 |     except KeyError:
118 |         flagged = False
119 |
120 | return flagged
121 |
122 |
123 | def pretty_print_semaphore(semaphore):
124 | if semaphore is None:
125 | return "None"
126 | return f"Semaphore(value={semaphore._value}, locked={semaphore.locked()})"
127 |
--------------------------------------------------------------------------------
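`build_logger` attaches a shared daily-rotating file handler under `LOGDIR` and swaps `sys.stdout`/`sys.stderr` for `StreamToLogger` wrappers, so plain `print` output is captured as well. A minimal usage sketch (the logger name and filename are illustrative):

```python
from humanomni.utils import build_logger

# Creates LOGDIR/demo.log and redirects stdout/stderr into the logging machinery.
logger = build_logger("demo", "demo.log")
logger.info("written through the logger")
print("also captured, because sys.stdout is now a StreamToLogger")
```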
/inference.py:
--------------------------------------------------------------------------------
1 | import os
2 | import argparse
3 | from humanomni import model_init, mm_infer
4 | from humanomni.utils import disable_torch_init
5 | from transformers import BertTokenizer
6 |
7 | # Set environment variables
8 | os.environ['TRANSFORMERS_OFFLINE'] = '1'
9 | os.environ['CUDA_VISIBLE_DEVICES'] = '0'
10 |
11 | def main():
12 | parser = argparse.ArgumentParser(description="HumanOmni Inference Script")
13 |     parser.add_argument('--modal', type=str, default='video_audio', help='Modal type: video, audio, or video_audio')
14 | parser.add_argument('--model_path', type=str, required=True, help='Path to the model')
15 | parser.add_argument('--video_path', type=str, required=True, help='Path to the video file')
16 | parser.add_argument('--instruct', type=str, required=True, help='Instruction for the model')
17 |
18 | args = parser.parse_args()
19 |
20 |     # Initialize the BERT tokenizer
21 | bert_model = "bert-base-uncased"
22 | bert_tokenizer = BertTokenizer.from_pretrained(bert_model)
23 |
24 |     # Disable redundant torch default initialization
25 | disable_torch_init()
26 |
27 |     # Initialize the model, processor, and tokenizer
28 | model, processor, tokenizer = model_init(args.model_path)
29 |
30 |     # Process the video input
31 | video_tensor = processor['video'](args.video_path)
32 |
33 |     # Decide whether to process audio based on the modal type
34 | if args.modal == 'video_audio' or args.modal == 'audio':
35 | audio = processor['audio'](args.video_path)[0]
36 | else:
37 | audio = None
38 |
39 |     # Run inference
40 | output = mm_infer(video_tensor, args.instruct, model=model, tokenizer=tokenizer, modal=args.modal, question=args.instruct, bert_tokeni=bert_tokenizer, do_sample=False, audio=audio)
41 | print(output)
42 |
43 | if __name__ == "__main__":
44 | main()
45 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | torch==2.2.1
2 | torchvision==0.17.1
3 | transformers==4.45.0
4 | tokenizers==0.19.1
5 | deepspeed==0.14.0
6 | accelerate==0.30.1
7 | peft==0.11.0
8 | timm==1.0.3
9 | numpy==1.22.2
10 | decord==0.6.0
11 | imageio==2.34.2
12 | imageio-ffmpeg==0.6.0
13 | moviepy==1.0.3
14 | scenedetect==0.6.3
15 | opencv-python==4.6.0.66
16 | pysubs2
17 | scikit-learn==1.3.0
18 | huggingface_hub==0.24.5
19 | einops==0.6.1
20 | bitsandbytes==0.43.1
21 | pydantic==1.10.7
22 | requests
23 | openai
24 | uvicorn
25 | fastapi
26 | tensorboard
27 | wandb
28 | tabulate
--------------------------------------------------------------------------------
/scripts/eval/eval_video_mcqa_mvbench.sh:
--------------------------------------------------------------------------------
1 | set -x
2 | export HF_HOME=/mnt/data/jiaxing.zjx/cache/huggingface/
3 | export HF_ENDPOINT=http://hf-mirror.com
4 | EVAL_DATA_DIR=/mnt/data/qize.yqz/datasets/video_eval
5 | OUTPUT_DIR=eval_output_qwen_noaudio
6 | CKPT=./HumanOmni_7B
7 |
8 | CKPT_NAME=$(echo $CKPT | rev | cut -d'/' -f1 | rev)
9 |
10 | gpu_list="${CUDA_VISIBLE_DEVICES:-0}"
11 | IFS=',' read -ra GPULIST <<< "$gpu_list"
12 |
13 | # split the evaluation data into chunks, one chunk per group of GPUS_PER_TASK GPUs
14 | GPUS_PER_TASK=1
15 | CHUNKS=$((${#GPULIST[@]}/$GPUS_PER_TASK))
16 |
17 | output_file=${OUTPUT_DIR}/mvbench/answers/${CKPT_NAME}/merge.json
18 |
19 | # if the merged output file is missing or empty, remove any stale per-chunk answer files
20 | if [ ! -f "$output_file" ] || [ $(cat "$output_file" | wc -l) -eq 0 ]; then
21 | rm -f ${OUTPUT_DIR}/mvbench/answers/${CKPT_NAME}/*.json
22 | fi
23 |
24 | if [ ! -f "$output_file" ]; then
25 | for IDX in $(seq 0 $((CHUNKS-1))); do
26 | gpu_devices=$(IFS=,; echo "${GPULIST[*]:$(($IDX*$GPUS_PER_TASK)):$GPUS_PER_TASK}")
27 | TRANSFORMERS_OFFLINE=1 CUDA_VISIBLE_DEVICES=${gpu_devices} python3 humanomni/eval/inference_video_mcqa_mvbench.py \
28 | --model-path ${CKPT} \
29 | --video-folder ${EVAL_DATA_DIR}/MVBench/video \
30 | --question-file ${EVAL_DATA_DIR}/MVBench/json \
31 | --answer-file ${OUTPUT_DIR}/mvbench/answers/${CKPT_NAME}/${CHUNKS}_${IDX}.json \
32 | --num-chunks $CHUNKS \
33 | --chunk-idx $IDX &
34 | done
35 |
36 | wait
37 |
38 | # Clear out the output file if it exists.
39 | > "$output_file"
40 |
41 | # Loop through the indices and concatenate each file.
42 | for IDX in $(seq 0 $((CHUNKS-1))); do
43 | cat ${OUTPUT_DIR}/mvbench/answers/${CKPT_NAME}/${CHUNKS}_${IDX}.json >> "$output_file"
44 | done
45 | fi
46 |
47 | python3 humanomni/eval/eval_video_mcqa_mvbench.py \
48 | --pred_path ${output_file} \
49 |
--------------------------------------------------------------------------------
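The script fans the benchmark out across the visible GPUs: with `GPUS_PER_TASK=1`, each GPU runs one copy of `inference_video_mcqa_mvbench.py`, writes `${CHUNKS}_${IDX}.json`, and the final loop concatenates the per-chunk files into `merge.json`. A quick sanity check of the chunk arithmetic (the GPU list is illustrative):

```python
# Mirrors the shell arithmetic in eval_video_mcqa_mvbench.sh for CUDA_VISIBLE_DEVICES=0,1,2,3.
gpu_list = "0,1,2,3".split(",")
GPUS_PER_TASK = 1
chunks = len(gpu_list) // GPUS_PER_TASK              # 4 parallel inference jobs
answer_files = [f"{chunks}_{idx}.json" for idx in range(chunks)]
print(answer_files)                                  # ['4_0.json', '4_1.json', '4_2.json', '4_3.json']
```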
/scripts/train/finetune_omni.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | # Environment Variables
4 | ARG_WORLD_SIZE=${1:-1}
5 | ARG_NPROC_PER_NODE=${2:-8}
6 | ARG_MASTER_ADDR="127.0.0.1"
7 | ARG_MASTER_PORT=16666
8 | ARG_RANK=0
9 |
10 | # Fall back to the defaults above when these variables are not provided by the environment
11 | if [ -z "$WORLD_SIZE" ] || [ -z "$NPROC_PER_NODE" ]; then
12 | WORLD_SIZE=$ARG_WORLD_SIZE
13 | NPROC_PER_NODE=$ARG_NPROC_PER_NODE
14 | fi
15 | if [ -z "$MASTER_ADDR" ] || [ -z "$MASTER_PORT" ] || [ -z "$RANK" ]; then
16 | MASTER_ADDR=$ARG_MASTER_ADDR
17 | MASTER_PORT=$ARG_MASTER_PORT
18 | RANK=$ARG_RANK
19 | fi
20 |
21 | echo "WORLD_SIZE: $WORLD_SIZE"
22 | echo "NPROC_PER_NODE: $NPROC_PER_NODE"
23 |
24 | # Training Arguments
25 | GLOBAL_BATCH_SIZE=128
26 | LOCAL_BATCH_SIZE=1
27 | GRADIENT_ACCUMULATION_STEPS=$((GLOBAL_BATCH_SIZE / (WORLD_SIZE * NPROC_PER_NODE * LOCAL_BATCH_SIZE)))
28 | echo "GRADIENT_ACCUMULATION_STEPS: $GRADIENT_ACCUMULATION_STEPS"
29 |
30 | # Log Arguments
31 | export TRANSFORMERS_OFFLINE=1
32 | export WANDB_PROJECT=humanomniqwen2_siglip
33 | export HF_HOME=/mnt/data/jiaxing.zjx/cache/huggingface/
34 | export HF_ENDPOINT=http://hf-mirror.com
35 | RUN_NAME=HumanOmni
36 | DATA_DIR=/mnt/data/jiaxing.zjx/datasets/Video-LLaVA/
37 | OUTP_DIR=work_dirs
38 |
39 | torchrun --nnodes $WORLD_SIZE \
40 | --nproc_per_node $NPROC_PER_NODE \
41 | --master_addr=$MASTER_ADDR \
42 | --master_port=$MASTER_PORT \
43 | --node_rank $RANK \
44 | humanomni/train_flash_attn.py \
45 | --deepspeed scripts/zero3.json \
46 | --model_type HumanOmni_qwen2 \
47 | --model_path /mnt/data/jiaxing.zjx/code/HumanOmni/HumanOmni_7B_Video/ \
48 | --vision_tower google/siglip-so400m-patch14-384 \
49 | --audio_tower openai/whisper-large-v3 \
50 | --mm_projector_type all_in_one \
51 | --mm_tunable_parts "mm_mlp_adapter,audio_projector,mm_language_model" \
52 | --pretrain_audio_mlp_adapter /mnt/data/jiaxing.zjx/code/HumanOmni/HumanOmni_7B_Audio/audio_projector.bin \
53 | --data_path ./yamls/oryx_audio.yaml \
54 | --data_folder / \
55 | --mm_vision_select_layer -2 \
56 | --image_aspect_ratio pad \
57 | --num_frames 32 \
58 | --bf16 True \
59 | --tf32 True \
60 | --fp16 False \
61 | --output_dir ${OUTP_DIR}/${WANDB_PROJECT}/finetune_${RUN_NAME} \
62 | --num_train_epochs 1 \
63 | --per_device_train_batch_size $LOCAL_BATCH_SIZE \
64 | --per_device_eval_batch_size 4 \
65 | --gradient_accumulation_steps $GRADIENT_ACCUMULATION_STEPS \
66 | --evaluation_strategy "no" \
67 | --save_strategy "steps" \
68 | --save_steps 500 \
69 | --save_total_limit 99 \
70 | --learning_rate 2e-5 \
71 | --weight_decay 0. \
72 | --warmup_ratio 0.03 \
73 | --lr_scheduler_type "cosine" \
74 | --logging_steps 1 \
75 | --model_max_length 2048 \
76 | --gradient_checkpointing True \
77 | --mm_use_x_start_end True \
78 | --dataloader_num_workers 4 \
79 | --report_to tensorboard \
80 | --run_name $RUN_NAME \
81 |
--------------------------------------------------------------------------------
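With the script's defaults (one node, 8 GPUs, micro-batch 1), the accumulation arithmetic keeps the effective global batch size at 128:

```python
# Mirrors the GRADIENT_ACCUMULATION_STEPS computation in finetune_omni.sh with its default values.
GLOBAL_BATCH_SIZE = 128
WORLD_SIZE = 1          # nodes
NPROC_PER_NODE = 8      # GPUs per node
LOCAL_BATCH_SIZE = 1    # per-GPU micro-batch size
steps = GLOBAL_BATCH_SIZE // (WORLD_SIZE * NPROC_PER_NODE * LOCAL_BATCH_SIZE)
print(steps)            # 16 -> 1 node * 8 GPUs * 1 sample * 16 steps = 128 samples per optimizer update
```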
/scripts/zero3.json:
--------------------------------------------------------------------------------
1 | {
2 | "fp16": {
3 | "enabled": "auto",
4 | "loss_scale": 0,
5 | "loss_scale_window": 1000,
6 | "initial_scale_power": 16,
7 | "hysteresis": 2,
8 | "min_loss_scale": 1
9 | },
10 | "bf16": {
11 | "enabled": "auto"
12 | },
13 | "train_micro_batch_size_per_gpu": "auto",
14 | "train_batch_size": "auto",
15 | "gradient_accumulation_steps": "auto",
16 | "zero_optimization": {
17 | "stage": 3,
18 | "overlap_comm": true,
19 | "contiguous_gradients": true,
20 | "sub_group_size": 1e9,
21 | "reduce_bucket_size": "auto",
22 | "stage3_prefetch_bucket_size": "auto",
23 | "stage3_param_persistence_threshold": "auto",
24 | "stage3_max_live_parameters": 1e9,
25 | "stage3_max_reuse_distance": 1e9,
26 | "stage3_gather_16bit_weights_on_model_save": true
27 | }
28 | }
--------------------------------------------------------------------------------
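The `"auto"` entries in `zero3.json` are not meant to be edited by hand: when the config is passed to the HuggingFace `Trainer` (via `--deepspeed scripts/zero3.json` in `finetune_omni.sh`), the trainer resolves them from its own arguments (batch sizes, accumulation steps, bf16/fp16). A minimal sketch of the equivalent programmatic wiring, with illustrative values mirroring the training script:

```python
from transformers import TrainingArguments

# The DeepSpeed integration fills the "auto" fields in scripts/zero3.json from these arguments.
args = TrainingArguments(
    output_dir="work_dirs/example",      # illustrative output directory
    per_device_train_batch_size=1,
    gradient_accumulation_steps=16,
    bf16=True,
    deepspeed="scripts/zero3.json",
)
```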