├── .gitignore ├── LICENSE ├── README.md ├── configs ├── accelerate │ └── default_config.yaml ├── datasets │ └── mmduetit.json └── deepspeed │ ├── pipeline.json │ ├── zero1.json │ ├── zero2.json │ ├── zero2offload.json │ └── zero3.json ├── data ├── __init__.py ├── data_collator.py ├── dvc.py ├── grounding.py ├── magqa.py ├── stream.py └── utils.py ├── demo ├── app.py ├── assets │ ├── assistant_avatar.png │ ├── cooking.mp4 │ ├── drive.mp4 │ ├── office.mp4 │ └── user_avatar.png └── liveinfer.py ├── models ├── __init__.py ├── arguments_live.py ├── configuration_live.py ├── live_llava │ └── video_head_live_llava_qwen.py ├── modeling_live.py ├── tokenization_live.py └── vision_live.py ├── requirements.txt ├── scripts ├── inference │ ├── charades.sh │ ├── magqa.sh │ ├── qvh.sh │ └── youcook2.sh └── train.sh ├── test ├── __init__.py ├── analyze_magqa_results.py ├── datasets.py ├── dvc │ ├── eval_dvc.py │ └── metrics │ │ ├── README.md │ │ ├── cider.py │ │ ├── cider_scorer.py │ │ ├── meteor-1.5.jar │ │ ├── meteor.py │ │ ├── ptbtokenizer.py │ │ └── stanford-corenlp-3.4.1.jar ├── evaluate.py ├── inference.py ├── openai_batch.py └── qvh │ ├── eval.py │ └── utils.py ├── train.py └── utils └── dist_utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__/ 2 | /outputs 3 | /datasets 4 | /dataset_upload 5 | /ffmpeg 6 | 7 | **/debug.ipynb 8 | **/debug.py 9 | **/nohup.out 10 | /test/dvc/metrics/data/paraphrase-en.gz 11 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 Yueqian Wang 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # MMDuet 2 | [![Static Badge](https://img.shields.io/badge/🤗Model-MMDuet-yellow)](https://huggingface.co/wangyueqian/MMDuet) 3 | [![Static Badge](https://img.shields.io/badge/🤗Dataset-MMDuetIT-yellow)](https://huggingface.co/datasets/wangyueqian/MMDuetIT) 4 | [![arXiv](https://img.shields.io/badge/arXiv-2411.17991-b31b1b.svg)](https://arxiv.org/abs/2411.17991) 5 | 6 | 7 | Official implementation of the paper *VideoLLM Knows When to Speak: Enhancing Time-Sensitive Video Comprehension with Video-Text Duet Interaction Format* 8 | 9 | # Introduction 10 | 11 | [![Watch video on Youtube](http://img.youtube.com/vi/n1OybwhQvtk/0.jpg)](https://www.youtube.com/watch?v=n1OybwhQvtk) 12 | 13 | The video is also available on [Bilibili (゜-゜)つロ干杯~](https://www.bilibili.com/video/BV1nwzGYBEPE) 14 | 15 | MMDuet is a VideoLLM implemented in the *video-text duet interaction format*, which treats the video stream as a role in the conversation akin to the user and the assistant. Under this interaction format, the video is continuously played and input to the model frame-by-frame. Both the user and the model can insert text messages right after any frame during video playback. When a text message ends, the video continues to play, like two performers in a duet. 16 | 17 | **This not only ensures a timely response for video comprehension, but also improves performance on many time-sensitive video-text multimodal tasks, such as temporal video grounding, highlight detection, and dense video captioning.** 18 | 19 | # Installation 20 | 1. Clone this repository, create a conda environment, and use pip to install the required packages 21 | ```shell 22 | git clone https://github.com/yellow-binary-tree/MMDuet 23 | cd MMDuet 24 | 25 | conda create -n mmduet python=3.10 26 | conda activate mmduet 27 | pip install --upgrade pip 28 | pip install -r requirements.txt 29 | ``` 30 | 31 | 2. Install llava following the instructions in [https://github.com/LLaVA-VL/LLaVA-NeXT](https://github.com/LLaVA-VL/LLaVA-NeXT) 32 | ```bash 33 | git clone https://github.com/LLaVA-VL/LLaVA-NeXT 34 | cd LLaVA-NeXT 35 | pip install -e ".[train]" 36 | ``` 37 | 38 | 3. Install flash-attention following the instructions in [https://github.com/Dao-AILab/flash-attention](https://github.com/Dao-AILab/flash-attention). If you have difficulty installing it, add `--attn_implementation sdpa` to every command to use the sdpa implementation of transformer attention for training or inference. 39 | 40 | 4. Download the MMDuet checkpoints from HuggingFace: [https://huggingface.co/wangyueqian/MMDuet](https://huggingface.co/wangyueqian/MMDuet) and put the files under the folder `./outputs/mmduet`. 41 | 42 | # Demo 43 | To launch a Gradio demo: `python -m demo.app --lora_pretrained outputs/mmduet` 44 | 45 | # Inference 46 | ## Download model and data 47 | 48 | - Download our data annotations for training (MMDuetIT) and evaluation from [wangyueqian/MMDuetIT](https://huggingface.co/datasets/wangyueqian/MMDuetIT) and put them in the `datasets` folder. 49 | 50 | - Download the videos, and link each video folder to `datasets/${DATASET_NAME}/videos`.
Here we list recommended video download links; you can also download the videos from other sources: 51 | - YouCook2: [https://opendatalab.com/OpenDataLab/YouCook2](https://opendatalab.com/OpenDataLab/YouCook2) 52 | - Shot2Story: [https://huggingface.co/mhan/shot2story-videos](https://huggingface.co/mhan/shot2story-videos) 53 | - Charades: [https://prior.allenai.org/projects/charades](https://prior.allenai.org/projects/charades) 54 | - QVHighlights: [https://github.com/jayleicn/moment_detr/blob/main/data/README.md](https://github.com/jayleicn/moment_detr/blob/main/data/README.md) 55 | 56 | - Download [paraphrase-en.gz](https://github.com/lichengunc/refer/raw/refs/heads/master/evaluation/meteor/data/paraphrase-en.gz) (59MB), which is used for dense video captioning evaluation, and put it at `test/dvc/metrics/data/paraphrase-en.gz`. 57 | 58 | ## Inference and evaluation 59 | Scripts for inference on all benchmarks are listed in `./scripts/inference/`. 60 | 61 | **WARNING**: Each script file contains many steps for inference and evaluation. DO NOT run these script files directly. Instead, read their contents carefully and run them step by step. 62 | 63 | - YouCook2 dense video captioning: `./scripts/inference/youcook2.sh` 64 | - Shot2Story-MAGQA-39k multi-answer grounded video question answering (MAGQA): `./scripts/inference/magqa.sh` 65 | - **Note**: To save compute, we do not calculate the similarity score between the predicted answer and the gold answer if the predicted time is not in the gold timespan; we simply set this score to 1 in the score matrix of `evaluator_output`. These placeholder scores are not used when calculating the final metric (in-span score), so they do not affect it. 66 | - Charades-STA temporal video grounding: `./scripts/inference/charades.sh` 67 | - QVHighlights highlight detection: `./scripts/inference/qvh.sh` 68 | 69 | 70 | # Training 71 | 72 | - If you want to reproduce the training process, you also need to download the training data. Download the videos, and link each video folder to `datasets/${DATASET_NAME}/videos`. Here we list recommended video download links; you can also download the videos from other sources: 73 | - COIN: [https://huggingface.co/datasets/WHB139426/coin](https://huggingface.co/datasets/WHB139426/coin) 74 | - HiREST: [https://github.com/j-min/HiREST](https://github.com/j-min/HiREST) 75 | - DiDeMo: [https://github.com/LisaAnne/TemporalLanguageRelease](https://github.com/LisaAnne/TemporalLanguageRelease) 76 | - QueryD: [https://www.robots.ox.ac.uk/~vgg/data/queryd/](https://www.robots.ox.ac.uk/~vgg/data/queryd/) 77 | 78 | Run `./scripts/train.sh`. 79 | 80 | When running the training code for the first time, the dataset code traverses all videos in the training datasets, records the frame rate, duration, and number of frames of each video, and stores this information in `datasets/${dataset_name}/videos_metadata.json`. This can take quite a long time. 81 | Since videos downloaded from different sources may differ slightly, we do not include this metadata in our data release; generating it locally ensures that your videos are loaded correctly. 82 | 83 | # Acknowledgment 84 | The following projects have been of great help to this work: 85 | - [VideoLLM-online](https://github.com/showlab/VideoLLM-online) for providing the codebase we built upon, 86 | - [LLaVA-NeXT](https://github.com/LLaVA-VL/LLaVA-NeXT) for providing awesome multi-modal foundation models, 87 | - [Shot2Story](https://github.com/bytedance/Shot2Story) for providing high-quality clip-level video captions.
88 | 89 | # Citation 90 | If you find this work useful in your research, please consider citing: 91 | ```bibtex 92 | @misc{wang2024mmduet, 93 | title={VideoLLM Knows When to Speak: Enhancing Time-Sensitive Video Comprehension with Video-Text Duet Interaction Format}, 94 | author={Yueqian Wang and Xiaojun Meng and Yuxuan Wang and Jianxin Liang and Jiansheng Wei and Huishuai Zhang and Dongyan Zhao}, 95 | year={2024}, 96 | eprint={2411.17991}, 97 | archivePrefix={arXiv}, 98 | primaryClass={cs.CV}, 99 | url={https://arxiv.org/abs/2411.17991}, 100 | } 101 | ``` 102 | -------------------------------------------------------------------------------- /configs/accelerate/default_config.yaml: -------------------------------------------------------------------------------- 1 | compute_environment: LOCAL_MACHINE 2 | debug: false 3 | deepspeed_config: 4 | deepspeed_config_file: configs/deepspeed/zero2.json 5 | zero3_init_flag: false 6 | downcast_bf16: 'no' 7 | enable_cpu_affinity: false 8 | machine_rank: 0 9 | main_training_function: main 10 | num_machines: 1 11 | num_processes: 2 12 | rdzv_backend: static 13 | same_network: true 14 | tpu_env: [] 15 | tpu_use_cluster: false 16 | tpu_use_sudo: false 17 | use_cpu: false 18 | 19 | distributed_type: FSDP 20 | fsdp_config: 21 | use_fsdp: true 22 | fsdp_sharding_strategy: 1 23 | fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP 24 | fsdp_transformer_layer_cls_to_wrap: Qwen2DecoderLayer 25 | fsdp_ignored_modules: [vision_tower, mm_projector] -------------------------------------------------------------------------------- /configs/datasets/mmduetit.json: -------------------------------------------------------------------------------- 1 | [ 2 | { 3 | "dataset_cls": "DenseVideoCaptioningStreamDataset", 4 | "video_root": "datasets/shot2story/videos", 5 | "anno_file": "datasets/shot2story/annotations/dvc_train-human_anno-0.25_0.5_earlier.json", 6 | "metadata_path": "datasets/shot2story/videos_metadata.json" 7 | }, 8 | { 9 | "dataset_cls": "MAGQAStreamDataset", 10 | "video_root": "datasets/shot2story/videos", 11 | "anno_file": "datasets/shot2story/annotations/magqa_train-0.25_0.5-earlier.json", 12 | "metadata_path": "datasets/shot2story/videos_metadata.json" 13 | }, 14 | { 15 | "dataset_cls": "GroundingStreamDataset", 16 | "video_root": "datasets/didemo/videos", 17 | "anno_file": "datasets/didemo/annotations/train.json", 18 | "metadata_path": "datasets/didemo/videos_metadata.json" 19 | }, 20 | { 21 | "dataset_cls": "GroundingStreamDataset", 22 | "video_root": "datasets/hirest_grounding/videos", 23 | "anno_file": "datasets/hirest_grounding/annotations/train.json", 24 | "metadata_path": "datasets/hirest_grounding/videos_metadata.json", 25 | "frame_fps": 0.333, "max_num_frames": 120 26 | }, 27 | { 28 | "dataset_cls": "GroundingStreamDataset", 29 | "video_root": "datasets/queryd/videos", 30 | "anno_file": "datasets/queryd/annotations/train.json", 31 | "metadata_path": "datasets/queryd/videos_metadata.json", 32 | "frame_fps": 0.5, "max_num_frames": 120 33 | }, 34 | { 35 | "dataset_cls": "DenseVideoCaptioningStreamDataset", 36 | "video_root": "datasets/coin/videos", 37 | "anno_file": "datasets/coin/annotations/train-0.25_0.5_earlier-120s_240s.json", 38 | "metadata_path": "datasets/coin/videos_metadata.json", 39 | "frame_fps": 0.5, "max_num_frames": 120 40 | } 41 | ] -------------------------------------------------------------------------------- /configs/deepspeed/pipeline.json: -------------------------------------------------------------------------------- 1 | { 2 | "bf16": 
{ 3 | "enabled": true 4 | }, 5 | "zero_optimization": { 6 | "stage": 1 7 | }, 8 | "optimizer": { 9 | "type": "AdamW", 10 | "params": { 11 | "lr": 2e-5, 12 | "betas": [0.9,0.99], 13 | "eps": 1e-7, 14 | "weight_decay": 0 15 | } 16 | }, 17 | "scheduler": { 18 | "type": "WarmupLR", 19 | "params": { 20 | "warmup_min_lr": 0, 21 | "warmup_max_lr": 2e-5, 22 | "warmup_num_steps": 100 23 | } 24 | }, 25 | "gradient_accumulation_steps": 32, 26 | "train_micro_batch_size_per_gpu": 1, 27 | "stage3_gather_16bit_weights_on_model_save": false 28 | } -------------------------------------------------------------------------------- /configs/deepspeed/zero1.json: -------------------------------------------------------------------------------- 1 | { 2 | "bf16": {"enabled": "auto"}, 3 | "zero_optimization": { 4 | "stage": 1 5 | }, 6 | "gradient_accumulation_steps": "auto", 7 | "train_batch_size": "auto", 8 | "train_micro_batch_size_per_gpu": "auto", 9 | "stage3_gather_16bit_weights_on_model_save": false 10 | } -------------------------------------------------------------------------------- /configs/deepspeed/zero2.json: -------------------------------------------------------------------------------- 1 | { 2 | "bf16": {"enabled": "auto"}, 3 | "zero_optimization": { 4 | "stage": 2, 5 | "offload_optimizer": { 6 | "device": "none", 7 | "pin_memory": true 8 | }, 9 | "allgather_partitions": true, 10 | "allgather_bucket_size": 5e8, 11 | "overlap_comm": true, 12 | "reduce_scatter": true, 13 | "reduce_bucket_size": 5e8, 14 | "contiguous_gradients": true 15 | }, 16 | "gradient_accumulation_steps": "auto", 17 | "train_batch_size": "auto", 18 | "train_micro_batch_size_per_gpu": "auto" 19 | } -------------------------------------------------------------------------------- /configs/deepspeed/zero2offload.json: -------------------------------------------------------------------------------- 1 | { 2 | "fp16": {"enabled": "auto"}, 3 | "bf16": {"enabled": "auto"}, 4 | "zero_optimization": { 5 | "stage": 2, 6 | "offload_optimizer": { 7 | "device": "cpu", 8 | "pin_memory": true 9 | }, 10 | "allgather_partitions": true, 11 | "allgather_bucket_size": 2e8, 12 | "overlap_comm": true, 13 | "reduce_scatter": true, 14 | "reduce_bucket_size": 2e8, 15 | "contiguous_gradients": true 16 | }, 17 | "gradient_accumulation_steps": "auto", 18 | "train_batch_size": "auto", 19 | "train_micro_batch_size_per_gpu": "auto" 20 | } -------------------------------------------------------------------------------- /configs/deepspeed/zero3.json: -------------------------------------------------------------------------------- 1 | { 2 | "fp16": {"enabled": "auto"}, 3 | "bf16": {"enabled": "auto"}, 4 | "zero_optimization": { 5 | "stage": 3, 6 | "offload_optimizer": { 7 | "device": "none", 8 | "pin_memory": true 9 | }, 10 | "offload_param": { 11 | "device": "none", 12 | "pin_memory": true 13 | }, 14 | "overlap_comm": true, 15 | "contiguous_gradients": true, 16 | "sub_group_size": 1e9, 17 | "stage3_max_live_parameters": 1e9, 18 | "stage3_max_reuse_distance": 1e9, 19 | "stage3_gather_16bit_weights_on_model_save": true 20 | }, 21 | "gradient_accumulation_steps": "auto", 22 | "train_batch_size": "auto", 23 | "train_micro_batch_size_per_gpu": "auto", 24 | "wall_clock_breakdown": false 25 | } -------------------------------------------------------------------------------- /data/__init__.py: -------------------------------------------------------------------------------- 1 | from torch.utils.data import ConcatDataset, Dataset 2 | from functools import partial 3 | 4 | # 
all datasets loaded here 5 | from .data_collator import get_data_collator 6 | from .dvc import DenseVideoCaptioningStreamDataset 7 | from .magqa import MAGQAStreamDataset 8 | from .grounding import GroundingStreamDataset 9 | 10 | __all__ = [ 11 | 'build_concat_train_dataset', 12 | 'build_eval_dataset_dict', 13 | 'get_data_collator', 14 | 'get_compute_metrics_dict' 15 | ] 16 | 17 | def build_concat_train_dataset_from_config(tokenizer, config): 18 | datasets = list() 19 | for dataset_config in config: 20 | dataset_cls = dataset_config.pop('dataset_cls') 21 | datasets.append(globals()[dataset_cls](tokenizer=tokenizer, **dataset_config)) 22 | return ConcatDataset(datasets) 23 | -------------------------------------------------------------------------------- /data/data_collator.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from functools import partial 3 | from transformers import PreTrainedTokenizer 4 | from transformers.trainer_pt_utils import LabelSmoother 5 | 6 | 7 | def data_collator_with_video_labels( 8 | batch: list[list], *, 9 | tokenizer: PreTrainedTokenizer = None, image_processor = None, 10 | model_config=None, **kwargs 11 | ): 12 | v_placeholder_id = model_config.v_placeholder_id 13 | frame_num_tokens = model_config.frame_num_tokens 14 | 15 | batch = list(zip(*batch)) 16 | batch_text, batch_frames, batch_learn_ranges, src_batch_response_labels, src_batch_related_labels, \ 17 | batch_sample_idx = batch 18 | batch = tokenizer(batch_text, return_offsets_mapping=True, add_special_tokens=False, return_tensors="pt", padding=True) 19 | 20 | batch_labels = torch.full_like(batch.input_ids, LabelSmoother.ignore_index, dtype=torch.long) 21 | batch_response_labels = torch.full_like(batch.input_ids, LabelSmoother.ignore_index, dtype=torch.long) 22 | batch_related_labels = torch.full_like(batch.input_ids, LabelSmoother.ignore_index, dtype=torch.long) 23 | 24 | for text, labels, response_labels, related_labels, src_response_labels, src_related_labels, \ 25 | input_ids, offset_mapping, learn_range in zip( 26 | batch_text, batch_labels, batch_response_labels, batch_related_labels, src_batch_response_labels, src_batch_related_labels, 27 | batch.input_ids, batch.offset_mapping, batch_learn_ranges 28 | ): 29 | for learn_r in learn_range: 30 | start = torch.nonzero(offset_mapping[:,0] == learn_r.start).item() 31 | if offset_mapping[:,0][-1] >= learn_r.stop: 32 | stop = torch.nonzero(offset_mapping[:,0] == learn_r.stop).item() 33 | else: # the last eos token 34 | stop = len(input_ids) 35 | labels[start-1:stop-1] = input_ids[start:stop] 36 | 37 | v_placeholder_indices = torch.nonzero(input_ids == v_placeholder_id).squeeze() 38 | indices_to_learn = v_placeholder_indices[frame_num_tokens-1::frame_num_tokens] 39 | if src_response_labels is not None: 40 | response_labels[indices_to_learn] = torch.tensor(src_response_labels, dtype=torch.long) 41 | if src_related_labels is not None: 42 | related_labels[indices_to_learn] = torch.tensor(src_related_labels, dtype=torch.long) 43 | 44 | batch['labels'] = batch_labels 45 | batch['response_labels'] = batch_response_labels 46 | batch['related_labels'] = batch_related_labels 47 | batch.pop('offset_mapping') 48 | batch['frames'] = torch.cat(batch_frames) 49 | if image_processor is not None: 50 | batch['frames'] = image_processor.preprocess(batch['frames'], return_tensors="pt")['pixel_values'] 51 | batch['sample_idxs'] = torch.tensor(batch_sample_idx) 52 | return batch 53 | 54 | def get_data_collator(**kwargs): 
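    # bind tokenizer / image_processor / model_config via functools.partial, so the returned
    # callable only takes a batch (e.g. usable directly as a `data_collator` in a training loop)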
55 | return partial(data_collator_with_video_labels, **kwargs) 56 | -------------------------------------------------------------------------------- /data/dvc.py: -------------------------------------------------------------------------------- 1 | import tqdm, random 2 | import numpy as np 3 | 4 | from .stream import StreamMixIn 5 | from .utils import ceil_time_by_fps, DictWithTo, reformat_example_for_debug 6 | from transformers.utils import logging 7 | 8 | logger = logging.get_logger(__name__) 9 | 10 | 11 | class DenseVideoCaptioningStreamDataset(StreamMixIn): 12 | instructions = [ 13 | {"role": "user", "content": "Please concisely narrate the video in real time."}, 14 | {"role": "user", "content": "Help me to illustrate my view in short."}, 15 | {"role": "user", "content": "Please simply describe what do you see."}, 16 | {"role": "user", "content": "Continuously answer what you observed with simple text."}, 17 | {"role": "user", "content": "Do concise real-time narration."}, 18 | {"role": "user", "content": "Hey assistant, do you know the current video content? Reply me concisely."}, 19 | {"role": "user", "content": "Simply interpret the scene for me."}, 20 | {"role": "user", "content": "What can you tell me about? Be concise."}, 21 | {"role": "user", "content": "Use simple text to explain what is shown in front of me."}, 22 | {"role": "user", "content": "What is the action now? Please response in short."}, 23 | ] 24 | 25 | def __init__(self, **kwargs): 26 | super().__init__(**kwargs) 27 | annos, self.annos = self.annos, [] 28 | for video_uid, _annotation_uid_narrations in tqdm.tqdm(annos.items(), desc=self.anno_file): 29 | if video_uid not in self.metadata: 30 | continue 31 | duration = self.metadata[video_uid]['duration'] 32 | for narrations in _annotation_uid_narrations.values(): 33 | if not narrations: 34 | continue 35 | start_time = ceil_time_by_fps(0, self.frame_fps, min_time=0, max_time=duration) 36 | conversation = [] 37 | last_time = start_time 38 | last_text = None 39 | for narration in narrations: 40 | if last_time >= duration: 41 | break 42 | text = narration['text'] 43 | learn = narration.get('learn', True) 44 | if text == last_text: 45 | continue 46 | time = ceil_time_by_fps(narration['time'], self.frame_fps, min_time=0, max_time=duration) 47 | if time == last_time: # since we have sorted and ceiled, so directly replace, this time is more close 48 | conversation[-1]['content'] = text 49 | else: # time > last_time 50 | num_frames = int((time - last_time) * self.frame_fps) 51 | # here we set informative_label = 1 for the frames that are after the middle point of the frame, but before the assistant turn point 52 | # as once a response is already generated, we should not generate another one at once. 53 | response_start_time = ceil_time_by_fps(np.mean([narration['timespan'][0], narration['timespan'][1]]), self.frame_fps, min_time=0, max_time=duration) 54 | response_frame_num = int((time - response_start_time) * self.frame_fps) + 1 55 | conversation.extend([ 56 | {"role": "stream", 'num_frames': num_frames, 'learn': True}, 57 | {"role": "assistant", "content": text, 'learn': learn, 'response_frame_num': response_frame_num}, 58 | ]) 59 | last_time = time 60 | last_text = text 61 | if not conversation: 62 | continue 63 | self.annos.append({ 64 | 'conversation': conversation, 65 | 'load_ranges': {video_uid: range(int(start_time*self.frame_fps), int(last_time*self.frame_fps))} 66 | }) 67 | print(f'Dataset {self.__class__.__name__} has {len(self)} examples. 
Example data: {reformat_example_for_debug(self[0])}') 68 | 69 | def preprocess_conversation(self, conversation): 70 | return [random.choice(self.instructions)] + conversation 71 | 72 | def get_relevance_labels(self, conversation): 73 | # this label is for grounding task, no need to learn here 74 | return None 75 | 76 | def __getitem__(self, index): 77 | try: 78 | anno = self.annos[index] 79 | return *super().__getitem__( 80 | conversation=self.preprocess_conversation(anno['conversation']), 81 | load_ranges=anno['load_ranges'], 82 | ), index 83 | except Exception as e: 84 | logger.warning(f'Error in dataset {self.anno_file} when getting index {index}: {e}') 85 | logger.warning(f'Using a random data instead.') 86 | return self.__getitem__(random.choice(list(range(len(self))))) 87 | 88 | 89 | if __name__ == '__main__': 90 | from models.configuration_live import LiveConfigMixin 91 | from models.tokenization_live import build_live_tokenizer_and_update_config 92 | llava_config = LiveConfigMixin(frame_token_cls=True, frame_token_pooled=[3,3], frame_num_tokens=10) 93 | llava_tokenizer = build_live_tokenizer_and_update_config('lmms-lab/llava-onevision-qwen2-7b-ov', llava_config) 94 | 95 | dataset = DenseVideoCaptioningStreamDataset( 96 | video_root='datasets/shot2story/videos_2fps_max384', 97 | anno_file='datasets/shot2story/annotations/narration_stream_train-human_anno-0.25_0.5_earlier.json', 98 | metadata_path='datasets/shot2story/videos_2fps_max384_metadata.json', 99 | system_prompt='This is a system prompt.', 100 | tokenizer=llava_tokenizer, 101 | frame_fps=2, max_num_frames=100 102 | ) 103 | 104 | print('length of the dataset:', len(dataset)) 105 | for i in range(0, min(1000, len(dataset)), 20): 106 | example = dataset[i] 107 | print(reformat_example_for_debug(example)) 108 | -------------------------------------------------------------------------------- /data/grounding.py: -------------------------------------------------------------------------------- 1 | import json, os, shutil 2 | import cv2 3 | from tqdm import tqdm, trange 4 | import math 5 | import random 6 | 7 | from transformers.utils import logging 8 | from .stream import StreamMixIn 9 | from .utils import reformat_example_for_debug, DictWithTo 10 | logger = logging.get_logger(__name__) 11 | 12 | 13 | class GroundingStreamDataset(StreamMixIn): 14 | query_templates = [ 15 | "%s", 16 | "%s", 17 | "What segment of the video addresses the topic '%s'?", 18 | "At what timestamp can I find information about '%s' in the video?", 19 | "Can you highlight the section of the video that pertains to '%s'?", 20 | "Which moments in the video discuss '%s' in detail?", 21 | "Identify the parts that mention '%s'.", 22 | "Where in the video is '%s' demonstrated or explained?", 23 | "What parts are relevant to the concept of '%s'?", 24 | "Which clips in the video relate to the query '%s'?", 25 | "Can you point out the video segments that cover '%s'?", 26 | "What are the key timestamps in the video for the topic '%s'?" 
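        # the first two templates keep the raw query text unchanged; the rest rephrase it as a natural-language question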
27 | ] 28 | 29 | def __init__(self, **kwargs): 30 | super().__init__(**kwargs) 31 | annos, self.annos = self.annos, list() 32 | for anno in tqdm(annos): 33 | video_uid = anno['video_uid'] 34 | if video_uid not in self.metadata: 35 | continue 36 | duration = self.metadata[video_uid]['duration'] 37 | conversation, current_frame = list(), 0 38 | conversation.append({'role': 'user', 'content': random.choice(self.query_templates) % anno['query'], 'learn': False}) 39 | related_info = list() 40 | for start_time, end_time in anno['timestamps']: 41 | start_frame = math.floor(start_time * self.frame_fps) 42 | if start_frame > current_frame: 43 | related_info.append({'related': False, 'num_frames': start_frame - current_frame}) 44 | end_frame = math.floor(end_time * self.frame_fps) 45 | related_info.append({'related': True, 'num_frames': end_frame - start_frame}) 46 | current_frame = end_frame 47 | last_frame = math.floor(duration * self.frame_fps) 48 | if last_frame > current_frame: 49 | related_info.append({'related': False, 'num_frames': last_frame - current_frame}) 50 | conversation.append({'role': 'stream', 'num_frames': last_frame, 'learn': True, 'related': related_info}) 51 | self.annos.append({ 52 | 'conversation': conversation, 53 | 'load_ranges': {video_uid: range(0, last_frame)} 54 | }) 55 | print(f'Dataset {self.__class__.__name__} has {len(self)} examples. Example data: {reformat_example_for_debug(self[0])}') 56 | 57 | def get_informative_labels(self, conversation): 58 | # this label is for captioning and qa task, no need to learn here 59 | return None 60 | 61 | def __getitem__(self, index): 62 | try: 63 | anno = self.annos[index] 64 | res = *super().__getitem__( 65 | conversation=anno['conversation'], 66 | load_ranges=anno['load_ranges'], 67 | ), index 68 | except Exception as e: 69 | logger.warning(f'Error in dataset {self.anno_file} when getting index {index}: {e}') 70 | logger.warning(f'Using a random data instead.') 71 | res = self.__getitem__(random.choice(list(range(len(self))))) 72 | return res 73 | 74 | if __name__ == '__main__': 75 | from models.configuration_live import LiveConfigMixin 76 | from models.tokenization_live import build_live_tokenizer_and_update_config 77 | llava_config = LiveConfigMixin(frame_token_cls=False, frame_token_pooled=[1,1], frame_num_tokens=1) 78 | llava_tokenizer = build_live_tokenizer_and_update_config('lmms-lab/llava-onevision-qwen2-7b-ov', llava_config) 79 | 80 | dataset = GroundingStreamDataset( 81 | video_root='datasets/queryd/videos', 82 | anno_file='datasets/queryd/annotations/train.json', 83 | metadata_path='datasets/queryd/videos_metadata.json', 84 | system_prompt='This is a system prompt.', tokenizer=llava_tokenizer, 85 | frame_fps=0.5, max_num_frames=120 86 | ) 87 | 88 | print('length of the dataset:', len(dataset)) 89 | for i in range(0, min(1000, len(dataset)), 20): 90 | example = dataset[i] 91 | print(reformat_example_for_debug(example)) 92 | -------------------------------------------------------------------------------- /data/magqa.py: -------------------------------------------------------------------------------- 1 | import random, json, tqdm 2 | import numpy as np 3 | import torch 4 | from .stream import StreamMixIn 5 | from .utils import ceil_time_by_fps, floor_time_by_fps, rand_bool, DictWithTo, reformat_example_for_debug 6 | 7 | from transformers.utils import logging 8 | logger = logging.get_logger(__name__) 9 | 10 | 11 | class MAGQAStreamDataset(StreamMixIn): 12 | def __init__(self, **kwargs): 13 | super().__init__(**kwargs) 14 
| annos, self.annos = self.annos, [] 15 | for anno in tqdm.tqdm(annos): 16 | video_uid = anno['video_uid'] 17 | if video_uid not in self.metadata: 18 | continue 19 | duration = self.metadata[video_uid]['duration'] 20 | if not anno['conversation']: 21 | continue 22 | role = anno['conversation'][0]['role'] 23 | time = anno['conversation'][0]['time'] 24 | video_start_time = anno.get('video_start_time', 100000000) # video starting from here should be used as input 25 | content = anno['conversation'][0]['content'] 26 | if not (role == 'user' and time > 0 and time <= duration and content): 27 | continue 28 | 29 | # 1. add random frames before the user 30 | fps_time = ceil_time_by_fps(time, self.frame_fps, 0, duration) 31 | waiting_frames = random.randint(int((fps_time - video_start_time) * self.frame_fps), int(fps_time * self.frame_fps)) 32 | waiting_frames = max(0, min(20, waiting_frames)) 33 | conversation = [] 34 | if waiting_frames: 35 | conversation.append({'role': 'stream', 'num_frames': waiting_frames, 'learn': waiting_frames - 1}) 36 | conversation.append({'role': 'user', 'content': content, 'time': time, 'fps_time': fps_time}) 37 | start_fps_time = fps_time - waiting_frames / self.frame_fps 38 | 39 | # 2. for loop to add message 40 | for message in anno['conversation'][1:]: 41 | role, content, time, learn, timespan = message['role'], message['content'], message['time'], message.get('learn', True), message.get('timespan', None) 42 | if time > duration: 43 | break 44 | 45 | if role == 'user': 46 | fps_time = ceil_time_by_fps(time, self.frame_fps, conversation[-1]['fps_time'], duration) 47 | if fps_time > duration: 48 | break 49 | if fps_time > conversation[-1]['fps_time']: 50 | conversation.append({'role': 'stream', 'num_frames': int((fps_time - conversation[-1]['fps_time']) * self.frame_fps), 'learn': True}) 51 | conversation.append({'role': 'user', 'content': content, 'time': time, 'fps_time': fps_time}) 52 | else: 53 | fps_time = ceil_time_by_fps(time, self.frame_fps, conversation[-1]['fps_time'], duration) 54 | if fps_time > duration: 55 | break 56 | if fps_time > conversation[-1]['fps_time']: 57 | num_frames = int((fps_time - conversation[-1]['fps_time']) * self.frame_fps) 58 | conversation.append({'role': 'stream', 'num_frames': num_frames, 'learn': True}) 59 | # here we set informative_label = 1 for the frames that are after the middle point of the frame, but before the assistant turn point 60 | # as once a response is already generated, we should not generate another one at once. 61 | response_start_time = ceil_time_by_fps(np.mean([timespan[0], timespan[1]]), self.frame_fps, min_time=0, max_time=duration) 62 | response_frame_num = int((time - response_start_time) * self.frame_fps) + 1 63 | response_frame_num = min(response_frame_num, num_frames) 64 | conversation.append({'role': 'assistant', 'content': content, 'time': time, 'fps_time': fps_time, 'learn': learn, 'response_frame_num': response_frame_num}) 65 | if not conversation: 66 | continue 67 | self.annos.append({ 68 | 'conversation': conversation, 69 | 'load_ranges': {video_uid: range(int(start_fps_time*self.frame_fps), int(conversation[-1]['fps_time']*self.frame_fps))} 70 | }) 71 | 72 | print(f'Dataset {self.__class__.__name__} has {len(self)} examples. 
Example data: {reformat_example_for_debug(self[0])}') 73 | 74 | # DEPRECATED 75 | def preprocess_conversation(self, conversation): 76 | if self.augmentation and self.is_training and len(conversation) >= 4: # 2 round 77 | i = random.randint(0, len(conversation) - 1) # stream, assistant, stream, ... 78 | if i > len(conversation) - 3: 79 | return [random.choice(self.user_instructions)] + conversation 80 | if conversation[i]['role'] == 'stream': 81 | i += 1 # assistant 82 | assert conversation[i]['role'] == 'assistant' 83 | correct_assistant = conversation[i] 84 | wrong_texts = set([turn['content'] for turn in conversation if 'assistant' == turn['role']]) - set(correct_assistant['content']) 85 | wrong_texts = list(wrong_texts) + [''] 86 | wrong_assistant = {'role': 'assistant', 'content': random.choice(wrong_texts)} 87 | augmented = [wrong_assistant] 88 | num_next_frames = conversation[i+1]['intervals'].numel() 89 | if num_next_frames > 1: 90 | if rand_bool(): # promptly fix behavior 91 | frame_placeholder_with_interval = self.v_placeholders_per_frame + self.frame_interval 92 | next_stream_placeholder = frame_placeholder_with_interval * (num_next_frames - 1) 93 | next_intervals = torch.arange(len(frame_placeholder_with_interval), len(next_stream_placeholder)+1, len(frame_placeholder_with_interval)) - len(self.frame_interval) 94 | if self.frame_interval: # last frame does not have frame interval 95 | next_stream_placeholder = next_stream_placeholder[:-len(self.frame_interval)] 96 | augmented += [ 97 | {'role': 'stream', 'content': self.v_placeholders_per_frame, 'intervals': torch.tensor([len(self.v_placeholders_per_frame)])}, 98 | correct_assistant, 99 | {'role': 'stream', 'content': next_stream_placeholder, 'intervals': next_intervals} 100 | ] 101 | else: # condition on video behavior 102 | augmented += [ 103 | {'role': 'stream', 'content': conversation[i+1]['content']} 104 | ] 105 | else: 106 | augmented += [conversation[i+1]] 107 | conversation = conversation[:i] + augmented + conversation[i+2:] 108 | return [random.choice(self.user_instructions)] + conversation 109 | 110 | def get_relevance_labels(self, conversation): 111 | # this label is for grounding task, no need to learn here 112 | return None 113 | 114 | def __getitem__(self, index): 115 | try: 116 | anno = self.annos[index] 117 | res = *super().__getitem__( 118 | conversation=anno['conversation'], 119 | load_ranges=anno['load_ranges'], 120 | ), index 121 | except Exception as e: 122 | logger.warning(f'Error in dataset {self.anno_file} when getting index {index}: {e}') 123 | logger.warning(f'Using a random data instead.') 124 | res = self.__getitem__(random.choice(list(range(len(self))))) 125 | return res 126 | 127 | 128 | if __name__ == '__main__': 129 | from models.configuration_live import LiveConfigMixin 130 | from models.tokenization_live import build_live_tokenizer_and_update_config 131 | llava_config = LiveConfigMixin(frame_token_cls=True, frame_token_pooled=[3,3], frame_num_tokens=10) 132 | llava_tokenizer = build_live_tokenizer_and_update_config('lmms-lab/llava-onevision-qwen2-7b-ov', llava_config) 133 | 134 | dataset = MAGQAStreamDataset( 135 | video_root='datasets/shot2story/videos_2fps_max384', 136 | anno_file='datasets/shot2story/annotations/livechat_train-multiturn-gpt4o-0.25_0.5-earlier.json', 137 | metadata_path='datasets/shot2story/videos_2fps_max384_metadata.json', 138 | system_prompt='This is a system prompt.', 139 | tokenizer=llava_tokenizer 140 | ) 141 | print(len(dataset)) 142 | 
print(reformat_example_for_debug(dataset[0])) 143 | print(reformat_example_for_debug(dataset[1])) 144 | -------------------------------------------------------------------------------- /data/stream.py: -------------------------------------------------------------------------------- 1 | import torch, os, json, tqdm, math, random, cv2 2 | import numpy as np 3 | from transformers import PreTrainedTokenizer 4 | import torch.distributed as dist 5 | import multiprocessing as mp 6 | 7 | from .utils import rand_bool, resize_and_pad_frame 8 | 9 | 10 | def get_all_files(directory): 11 | relative_file_list = [] 12 | for root, dirs, files in os.walk(directory): 13 | for file in files: 14 | # Get the relative path by removing the directory part from the absolute path 15 | relative_path = os.path.relpath(os.path.join(root, file), directory) 16 | relative_file_list.append(relative_path) 17 | return relative_file_list 18 | 19 | 20 | def get_video_duration_and_fps(args): 21 | file, video_root = args 22 | path = os.path.join(video_root, file) 23 | cap = cv2.VideoCapture(path) 24 | fps = cap.get(cv2.CAP_PROP_FPS) 25 | frame_count = cap.get(cv2.CAP_PROP_FRAME_COUNT) 26 | duration = frame_count / fps if fps > 0 else 0 27 | return file, {'duration': duration, 'fps': fps, 'path': path, 'frame_count': frame_count} 28 | 29 | 30 | class StreamMixIn(torch.utils.data.Dataset): 31 | def __init__(self, 32 | video_root: str = None, anno_file: str = None, metadata_path: str = None, frame_fps: float = 2, frame_size: int = 384, 33 | system_prompt: str = None, augmentation: bool = False, 34 | max_num_frames: int = 128, tokenizer: PreTrainedTokenizer = None, skip_video=False, **kwargs): 35 | super().__init__() 36 | self.video_root = video_root 37 | self.anno_file = anno_file 38 | self.metadata_path = metadata_path 39 | self.frame_fps = frame_fps 40 | self.frame_size = frame_size 41 | self.system_prompt = system_prompt if system_prompt is not None else "A multimodal AI assistant is helping users with some activities. Below is their conversation, interleaved with the list of video frames received by the assistant." 
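        # get_metadata() below builds (or loads from metadata_path) a dict mapping each relative video
        # path to {'duration', 'fps', 'path', 'frame_count'}, computed once with OpenCV and cached as JSON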
42 | self.augmentation = augmentation 43 | self.tokenizer = tokenizer 44 | self.max_num_frames = max_num_frames 45 | self.skip_video = skip_video # used in text-only scenarios 46 | self.metadata = self.get_metadata() 47 | self.annos = self.get_annos() 48 | 49 | def __len__(self): 50 | return len(self.annos) 51 | 52 | def get_annos(self) -> dict: 53 | anno_path = os.path.join(self.anno_file) 54 | assert os.path.exists(anno_path) 55 | return json.load(open(anno_path)) 56 | 57 | def max_frames_clip(self, conversation: list[dict], load_ranges: dict[str, range], max_num_frames: int): 58 | cum_num_frames = 0 59 | for i, message in enumerate(conversation): 60 | if message['role'] == 'stream': 61 | if cum_num_frames + message['num_frames'] >= max_num_frames: 62 | if cum_num_frames < max_num_frames: 63 | # crop this video stream to fewer frames 64 | conversation[i]['num_frames'] = max_num_frames - cum_num_frames 65 | conversation = conversation[:i+1] 66 | else: 67 | conversation = conversation[:i] 68 | load_ranges = {path: range(ranger.start, ranger.start + max_num_frames) for path, ranger in load_ranges.items()} 69 | break 70 | cum_num_frames += message['num_frames'] 71 | return conversation, load_ranges 72 | 73 | def get_metadata(self): 74 | if os.path.exists(self.metadata_path): 75 | print(f'load {self.metadata_path}...') 76 | metadata = json.load(open(self.metadata_path)) 77 | else: 78 | metadata = {} 79 | if not dist.is_initialized() or dist.get_rank() == 0: 80 | # only the main process needs to prepare metadata 81 | files = get_all_files(self.video_root) 82 | with mp.Pool(20) as pool: 83 | results = list(tqdm.tqdm(pool.imap( 84 | get_video_duration_and_fps, [(file, self.video_root) for file in files]), 85 | total=len(files), desc=f'prepare {self.metadata_path}...')) 86 | for key, value in results: 87 | metadata[key] = value 88 | with open(self.metadata_path, 'w') as f: 89 | json.dump(metadata, f, indent=4) 90 | if dist.is_initialized(): 91 | dist.barrier() 92 | else: 93 | dist.barrier() 94 | metadata = json.load(open(self.metadata_path)) 95 | return metadata 96 | 97 | def load_video(self, file): 98 | video_metadata = self.metadata[file] 99 | # load the frames, and downsample to self.frame_fps 100 | cap = cv2.VideoCapture(video_metadata['path']) 101 | num_frames_total = math.floor(video_metadata['duration'] * self.frame_fps) 102 | frame_sec = [i / self.frame_fps for i in range(num_frames_total)] 103 | frames, cur_time, frame_index = [], 0, 0 104 | while True: 105 | ret, frame = cap.read() 106 | if not ret: 107 | break 108 | if frame_index < len(frame_sec) and cur_time >= frame_sec[frame_index]: 109 | frame = resize_and_pad_frame(frame, self.frame_size) 110 | frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) 111 | frames.append(frame) 112 | frame_index += 1 113 | cur_time += 1 / video_metadata['fps'] 114 | cap.release() 115 | frames = np.array(frames) # shape will be (T, H, W, C) 116 | frames = np.transpose(frames, (0, 3, 1, 2)) # Change to (T, C, H, W) 117 | return torch.tensor(frames) 118 | 119 | def get_informative_labels(self, conversation): 120 | informative_labels = list() 121 | for i, turn in enumerate(conversation): 122 | if turn['role'] == 'stream' and turn['num_frames'] > 0: 123 | if turn['learn']: 124 | if i != len(conversation) - 1: 125 | next_turn = conversation[i + 1] 126 | response_frame_num = next_turn.get('response_frame_num', 1) 127 | next_role = next_turn['role'] 128 | else: 129 | response_frame_num = 1 130 | next_role = None 131 | informative_labels += [0] * (turn['num_frames'] 
- response_frame_num) 132 | informative_labels += [int(next_role == 'assistant')] * response_frame_num 133 | else: 134 | informative_labels += [-100] * turn['num_frames'] 135 | return informative_labels 136 | 137 | def get_relevance_labels(self, conversation): 138 | relevance_labels = list() 139 | for turn in conversation: 140 | if turn['role'] == 'stream' and turn['num_frames'] > 0: 141 | if turn['learn']: 142 | for related_info in turn['related']: 143 | relevance_labels += [int(related_info['related'])] * related_info['num_frames'] 144 | else: 145 | relevance_labels += [-100] * turn['num_frames'] 146 | return relevance_labels 147 | 148 | def __getitem__(self, *, conversation: list[dict], load_ranges: dict[str, range] | torch.Tensor = None, add_generation_prompt=False, **kwargs): 149 | # 1. load videos 150 | if self.skip_video: 151 | frames = torch.tensor([]) 152 | elif isinstance(load_ranges, torch.Tensor): 153 | frames = load_ranges 154 | elif load_ranges is not None: 155 | conversation, load_ranges = self.max_frames_clip(conversation, load_ranges, self.max_num_frames) 156 | # after max_frames_clip, sometimes there may be no conversation left due to the conversations are too late. 157 | # we also need to keep this kind of data, as no conversation can also be a real-time situation 158 | ranges = [self.load_video(path)[ranger] for path, ranger in load_ranges.items()] 159 | frames = torch.cat(ranges) 160 | else: 161 | frames = torch.tensor([]) 162 | 163 | # 2. prepare texts 164 | if self.augmentation: 165 | conversation = self.augment(conversation) 166 | conversation = [{"role": "system", "content": self.system_prompt}] + conversation 167 | text = self.tokenizer.apply_chat_template(conversation, tokenize=False, add_generation_prompt=add_generation_prompt) 168 | 169 | # 3. learn ranges 170 | learn_ranges = self.tokenizer.get_learn_ranges(conversation) if not add_generation_prompt else [] 171 | # check if the number of frames in video and text is equal 172 | if not self.skip_video: 173 | num_frames_in_video = len(frames) 174 | num_frames_in_text = sum([turn['num_frames'] for turn in conversation if turn['role'] == 'stream']) 175 | assert num_frames_in_video == num_frames_in_text, f"num_frames_in_video: {num_frames_in_video}, num_frames_in_text: {num_frames_in_text}" 176 | 177 | # 4. get response labels or related labels according to subclass 178 | # the default logic is written in this class. 
if do not want to learn with this label, you can override in subclass with `return None` 179 | informative_labels, relevance_labels = self.get_informative_labels(conversation), self.get_relevance_labels(conversation) 180 | if not self.skip_video and informative_labels is not None: 181 | assert len(informative_labels) >= len(frames), f"len(informative_labels): {len(informative_labels)}, len(frames): {len(frames)}" 182 | informative_labels = informative_labels[:len(frames)] 183 | if not self.skip_video and relevance_labels is not None: 184 | assert len(relevance_labels) >= len(frames), f"len(relevance_labels): {len(relevance_labels)}, len(frames): {len(frames)}" 185 | relevance_labels = relevance_labels[:len(frames)] 186 | 187 | return text, frames, learn_ranges, informative_labels, relevance_labels 188 | -------------------------------------------------------------------------------- /data/utils.py: -------------------------------------------------------------------------------- 1 | import random, torch, tqdm, os, subprocess, torchvision, pathlib, submitit, math 2 | import cv2 3 | from itertools import takewhile 4 | try: 5 | torchvision.set_video_backend('video_reader') 6 | except: 7 | pass 8 | from transformers import AutoModel 9 | from torchvision.transforms.functional import to_pil_image, normalize 10 | 11 | 12 | def reformat_example_for_debug(data): 13 | if isinstance(data, torch.Tensor) and len(data.size()) > 1: 14 | return data.size() 15 | if isinstance(data, dict): 16 | return {k: reformat_example_for_debug(v) for k, v in data.items()} 17 | if isinstance(data, (list, tuple)): 18 | return [reformat_example_for_debug(v) for v in data] 19 | return data 20 | 21 | 22 | class DictWithTo(dict): 23 | def to(self, *args, **kwargs): 24 | return self 25 | 26 | def inverse_preprocess_to_pil_images(frames: torch.Tensor, mean: list, std: list): 27 | frames = normalize(frames, mean=tuple(-m / s for m, s in zip(mean, std)), std=tuple(1.0 / s for s in std)) 28 | frames = (frames * 255).to(torch.uint8) 29 | return list(map(to_pil_image, frames)) 30 | 31 | def rand_bool(): 32 | return bool(random.getrandbits(1)) 33 | 34 | def case_connect(prefix: str, suffix: str): 35 | if not prefix: 36 | return suffix[0].upper() + suffix[1:] 37 | if not suffix: 38 | return prefix 39 | if prefix[-1] == ',' or prefix[-1] == ':': 40 | return prefix + ' ' + suffix[0].lower() + suffix[1:] 41 | return prefix + ' ' + suffix[0].upper() + suffix[1:] 42 | 43 | def batch_temporal_iou(sequences1: torch.Tensor, sequences2: torch.Tensor): 44 | area1 = sequences1[:, 1] - sequences1[:, 0] 45 | area2 = sequences2[:, 1] - sequences2[:, 0] 46 | l = torch.maximum(sequences1[:,None,0], sequences2[:,0]) 47 | r = torch.minimum(sequences1[:,None,1], sequences2[:,1]) 48 | inter = (r - l).clamp(min=0) 49 | union = area1[:, None] + area2 - inter 50 | iou = inter / union 51 | return iou 52 | 53 | def temporal_iou(region1, region2): 54 | area1 = region1[1] - region1[0] 55 | area2 = region2[1] - region2[0] 56 | l = max(region1[0], region2[0]) 57 | r = min(region1[1], region2[1]) 58 | inter = max(0, (r - l)) 59 | union = area1 + area2 - inter 60 | iou = inter / union 61 | return iou 62 | 63 | def ffmpeg_once(src_path: str, dst_path: str, *, fps: int = None, resolution: int = None, pad: str = '#000000', mode='bicubic'): 64 | os.makedirs(os.path.dirname(dst_path), exist_ok=True) 65 | command = [ 66 | './ffmpeg/ffmpeg', 67 | # '-y', 68 | '-n', # skip if target file exists 69 | '-sws_flags', mode, 70 | '-i', src_path, 71 | '-an', 72 | '-threads', 
'10', 73 | ] 74 | if fps is not None: 75 | command += ['-r', str(fps)] 76 | if resolution is not None: 77 | command += ['-vf', f"scale='if(gt(iw\\,ih)\\,{resolution}\\,-2)':'if(gt(iw\\,ih)\\,-2\\,{resolution})',pad={resolution}:{resolution}:(ow-iw)/2:(oh-ih)/2:color='{pad}'"] 78 | command += [dst_path] 79 | subprocess.run(command, check=True) 80 | 81 | def distributed_ffmpeg(*, src_root: str, fps: int = None, resolution: int = None, pad: str = '#000000', mode='bicubic'): 82 | import submitit 83 | env = submitit.JobEnvironment() 84 | src_root = src_root.rstrip('/') 85 | pather = pathlib.Path(src_root) 86 | src_paths = [str(path) for path in pather.rglob('*') if path.is_file() and str(path).endswith('.mp4')] 87 | dst_root = src_root 88 | if fps is not None: 89 | dst_root += f'_{fps}fps' 90 | if resolution is not None: 91 | assert (pad is not None) 92 | dst_root += f'_max{resolution}' 93 | for i, src_path in tqdm.tqdm(enumerate(src_paths), desc=f'{src_root} -> {dst_root}'): 94 | if i % env.num_tasks != env.global_rank: 95 | continue 96 | dst_path = src_path.replace(src_root, dst_root) 97 | ffmpeg_once(src_path, dst_path, fps=fps, resolution=resolution, pad=pad, mode=mode) 98 | 99 | def distributed_encode(*, src_root: str, vision_pretrained: str, vision_encode: callable, batch_size: int, embed_mark: str, save_bf16: bool = False, **kwargs): 100 | env = submitit.JobEnvironment() 101 | src_root = src_root.rstrip('/') 102 | model = AutoModel.from_pretrained(vision_pretrained, device_map=f'cuda:{env.local_rank}').vision_model 103 | model.eval() 104 | dst_root = f"{src_root}_{embed_mark.split('_')[-1]}_{vision_pretrained.replace('/', '--')}" 105 | os.makedirs(dst_root, exist_ok=True) 106 | for i, file in tqdm.tqdm(enumerate(os.listdir(src_root)), desc=f'{src_root} -> {dst_root}'): 107 | if i % env.num_tasks != env.global_rank: 108 | continue 109 | frame_path = os.path.join(src_root, file) 110 | save_path = os.path.splitext(frame_path)[0] + '.pt' 111 | save_path = save_path.replace(src_root, dst_root) 112 | frames = torchvision.io.read_video(frame_path, pts_unit='sec', output_format='TCHW')[0] 113 | with torch.no_grad(): 114 | frames = torch.cat([vision_encode(model, batch.to(f'cuda:{env.local_rank}')).cpu() for batch in frames.split(batch_size)]) 115 | if save_bf16: 116 | frames = frames.to(torch.bfloat16) 117 | torch.save(frames, save_path) 118 | 119 | def round_time_by_fps(time: float, fps: int, min_time: float, max_time: float): 120 | return min(max(round(time * fps) / fps, min_time), max_time) 121 | 122 | def ceil_time_by_fps(time: float, fps: int, min_time: float, max_time: float): 123 | return min(max(math.ceil(time * fps) / fps, min_time), max_time) 124 | 125 | def floor_time_by_fps(time: float, fps: int, min_time: float, max_time: float): 126 | return min(max(math.floor(time * fps) / fps, min_time), max_time) 127 | 128 | def resize_and_pad_frame(frame, output_size, pad_color=(0, 0, 0)): 129 | input_height, input_width = frame.shape[:2] 130 | if input_height == output_size and input_width == output_size: 131 | return frame 132 | if input_width > input_height: 133 | # Landscape video: scale width to the resolution, adjust height 134 | new_width = output_size 135 | new_height = int((input_height / input_width) * output_size) 136 | else: 137 | # Portrait video: scale height to the resolution, adjust width 138 | new_height = output_size 139 | new_width = int((input_width / input_height) * output_size) 140 | resized_frame = cv2.resize(frame, (new_width, new_height)) 141 | # pad the frame 142 | 
canvas = cv2.copyMakeBorder( 143 | resized_frame, 144 | top=(output_size - new_height) // 2, 145 | bottom=(output_size - new_height + 1) // 2, 146 | left=(output_size - new_width) // 2, 147 | right=(output_size - new_width + 1) // 2, 148 | borderType=cv2.BORDER_CONSTANT, 149 | value=pad_color 150 | ) 151 | return canvas 152 | -------------------------------------------------------------------------------- /demo/app.py: -------------------------------------------------------------------------------- 1 | import os, torchvision, transformers, time 2 | import gradio as gr 3 | from threading import Event 4 | 5 | from models import parse_args 6 | from demo.liveinfer import LiveInferForDemo, load_video 7 | logger = transformers.logging.get_logger('liveinfer') 8 | 9 | args = parse_args('test') 10 | args.stream_end_prob_threshold = 0.3 11 | liveinfer = LiveInferForDemo(args) 12 | 13 | pause_event = Event() # Event for pausing/resuming 14 | pause_event.set() # Initially, processing is allowed (not paused) 15 | 16 | css = """ 17 | #gr_title {text-align: center;} 18 | #gr_video {max-height: 480px;} 19 | #gr_chatbot {max-height: 480px;} 20 | """ 21 | 22 | 23 | class HistorySynchronizer: 24 | def __init__(self): 25 | self.history = [] 26 | 27 | def set_history(self, history): 28 | self.history = history 29 | 30 | def get_history(self): 31 | return self.history 32 | 33 | def reset(self): 34 | self.history = [] 35 | 36 | history_synchronizer = HistorySynchronizer() 37 | 38 | 39 | class ChatInterfaceWithUserMsgTime(gr.ChatInterface): 40 | async def _display_input( 41 | self, message: str, history 42 | ): 43 | message = f"[time={liveinfer.video_time:.1f}s] {message}" 44 | history = history_synchronizer.get_history() 45 | if isinstance(message, str) and self.type == "tuples": 46 | history.append([message, None]) # type: ignore 47 | elif isinstance(message, str) and self.type == "messages": 48 | history.append({"role": "user", "content": message}) # type: ignore 49 | history_synchronizer.set_history(history) 50 | return history # type: ignore 51 | 52 | 53 | with gr.Blocks(title="MMDuet", css=css) as demo: 54 | gr.Markdown("# VideoLLM Knows When to Speak: Enhancing Time-Sensitive Video Comprehension with Video-Text Duet Interaction Format", elem_id='gr_title') 55 | with gr.Row(): # row for instructions 56 | with gr.Column(): 57 | gr.Markdown(( 58 | 'This demo demonstrates **MMDuet**, a VideoLLM you can interact with in a real-time manner while the video plays.\n' 59 | '## Usage\n' 60 | '1. Upload the video, and set the "Threshold Mode", "Scores Used", "Remove Previous Model Turns in Context" and "Threshold" hyperparameters. After this click "Start Chat", the frames will be sampled from the video and encoded by the ViT of MMDuet,' 61 | 'then the video will start to progress in the bottom left corner.\n' 62 | '1. When the video is progressing, if you want to send a message, type in the message box and click "Submit". Your message will be inserted at the current position of the video.\n' 63 | 'You can also pause the video before typing & submitting your message.\n' 64 | '1. The conversation will be listed in the chatbot in the bottom right corner.\n' 65 | '1. If you want to change the video or hyperparameters, click the red "Stop Video" button.\n' 66 | '## Hyparameters\n' 67 | '1. **Threshold Mode**: When "single-frame score" is selected, MMDuet responses when the video score of current frame exceeds the "Score Threshold".' 
68 | 'When "sum score" is selected, MMDuet responses when the sum of video scores of the several latest frames exceeds the "Score Threshold".\n' 69 | '1. **Scores Used**: The video score is set to the sum of the scores selected here.\n' 70 | '1. **Remove Previous Model Turns in Context**: Whether to remove the previous model-generated responses from the dialogue context.' 71 | 'If set to "yes", The model will not be affected by previous content when generating subsequent responses, which will significantly reduce the occurrence of duplicate responses.' 72 | 'However, this is at the cost of losing those context information.\n' 73 | '1. **Threshold**: The threshold for the video score to trigger a response from MMDuet.\n' 74 | '1. **Frame Interval**: The time interval between frames sampled from the video.\n' 75 | '## Examples\n' 76 | 'In the upper right part we list 3 examples. The first 2 examples are about answering questions about videos in a real-time manner, and the third question is about dense video captioning for a **10-minutes long** video.\n' 77 | 'Choose the example, click "Start Chat", and you will see an example user message filled in the message input box. Click "Submit" to submit it as early in the video as possible.\n' 78 | '## Notes\n' 79 | '- For a clear demonstration, we add a short `time.sleep()` after each frame. Therefore, the video progresses much slower than the actual inference speed of MMDuet.\n' 80 | '- If the app stuck for some reason, refreshing the page usually doesn\'t help. Restart the app and try again.' 81 | )) 82 | 83 | with gr.Row(visible=False): 84 | def handle_user_input(message, history): 85 | liveinfer.encode_given_query(message) 86 | gr_chat_interface = ChatInterfaceWithUserMsgTime( 87 | fn=handle_user_input, 88 | chatbot=gr.Chatbot( 89 | elem_id="gr_chatbot", 90 | label='chatbot', 91 | avatar_images=('demo/assets/user_avatar.png', 'demo/assets/assistant_avatar.png'), 92 | render=False 93 | ), 94 | retry_btn=None, undo_btn=None, 95 | examples=[], 96 | ) 97 | 98 | with gr.Row(), gr.Blocks() as hyperparam_block: 99 | gr_video = gr.Video(label="Input Video", visible=True, sources=['upload'], autoplay=False) 100 | 101 | with gr.Column(): 102 | # choose the threshold before starting inference 103 | gr_thres_mode = gr.Radio(choices=["single-frame score", "sum score"], value="single-frame score", label="Threshold Mode") 104 | gr_used_scores = gr.CheckboxGroup(choices=["informative score", "relevance score"], value=["informative score"], label="Scores Used") 105 | gr_rm_ass_turns = gr.Radio(choices=["yes", "no"], value="yes", label="Remove Previous Model Turns in Context") 106 | gr_threshold = gr.Slider(minimum=0, maximum=3, step=0.05, value=args.stream_end_prob_threshold, interactive=True, label="Score Threshold") 107 | gr_frame_interval = gr.Slider(minimum=0.1, maximum=10, step=0.1, value=(1/args.frame_fps), interactive=True, label="Frame Interval (sec)") 108 | gr_start_button = gr.Button("Start Chat", variant="primary") 109 | 110 | gr_examples = gr.Examples( 111 | examples=[ 112 | ["demo/assets/office.mp4", "single-frame score", ["informative score", "relevance score"], 0.4, 0.5, "What is happening in the office?", "no"], 113 | ["demo/assets/drive.mp4", "single-frame score", ["informative score", "relevance score"], 0.4, 0.5, "Who is driving the car?", "no"], 114 | ["demo/assets/cooking.mp4", "sum score", ["informative score"], 1.5, 2.0, "Please simply describe what do you see.", "yes"], 115 | ], 116 | inputs=[gr_video, gr_thres_mode, gr_used_scores, 
gr_threshold, gr_frame_interval, gr_chat_interface.textbox, gr_rm_ass_turns], 117 | # outputs=[gr_video, gr_thres_mode, gr_used_scores, gr_threshold, gr_frame_interval, gr_chat_interface.examples], 118 | label="Examples" 119 | ) 120 | 121 | with gr.Row() as chat: 122 | with gr.Column(): 123 | gr_frame_display = gr.Image(label="Current Model Input Frame", interactive=False) 124 | with gr.Row(): 125 | gr_time_display = gr.Number(label="Current Video Time", value=0) 126 | with gr.Row(): 127 | gr_inf_score_display = gr.Number(label="Informative Score", value=0) 128 | gr_rel_score_display = gr.Number(label="Relevance Score", value=0) 129 | with gr.Row(): 130 | gr_pause_button = gr.Button("Pause Video") 131 | gr_stop_button = gr.Button("Stop Video", variant='stop') 132 | 133 | with gr.Column(): 134 | gr_chat_interface.render() 135 | 136 | def start_chat(src_video_path, thres_mode, rm_ass_turns, scores, threshold, frame_interval, history): 137 | # clear the chatbot and frame display 138 | yield 0, 0, 0, None, [] 139 | 140 | # set the hyperparams 141 | liveinfer.reset() 142 | history_synchronizer.reset() 143 | liveinfer.score_heads = [s.replace(' ', '_') for s in scores] 144 | if thres_mode == 'single-frame score': 145 | liveinfer.stream_end_prob_threshold = threshold 146 | liveinfer.stream_end_score_sum_threshold = None 147 | elif thres_mode == 'sum score': 148 | liveinfer.stream_end_prob_threshold = None 149 | liveinfer.stream_end_score_sum_threshold = threshold 150 | liveinfer.remove_assistant_turns = rm_ass_turns == 'yes' 151 | 152 | # disable the hyperparams 153 | for component in hyperparam_block.children: 154 | component.interactive = False 155 | 156 | # upload the video 157 | frame_fps = 1 / frame_interval 158 | liveinfer.set_fps(frame_fps) 159 | video_input, original_frame_list = load_video(src_video_path, frame_fps) 160 | liveinfer.input_video_stream(video_input) 161 | 162 | while liveinfer.frame_embeds_queue: 163 | start_time = time.time() 164 | pause_event.wait() 165 | ret = liveinfer.input_one_frame() 166 | history = history_synchronizer.get_history() 167 | 168 | if ret['response'] is not None: 169 | history.append((None, f"[time={ret['time']}s] {ret['response']}")) 170 | history_synchronizer.set_history(history) 171 | 172 | elapsed_time = time.time() - start_time 173 | target_delay_time = min(frame_interval, 0.2) 174 | if elapsed_time < target_delay_time: # or the video plays too fast 175 | time.sleep(frame_interval - elapsed_time) 176 | yield ret['time'], ret['informative_score'], ret['relevance_score'], \ 177 | original_frame_list[ret['frame_idx'] - 1], history 178 | 179 | gr_start_button.click( 180 | fn=start_chat, 181 | inputs=[gr_video, gr_thres_mode, gr_rm_ass_turns, gr_used_scores, gr_threshold, gr_frame_interval, gr_chat_interface.chatbot], 182 | outputs=[gr_time_display, gr_inf_score_display, gr_rel_score_display, gr_frame_display, gr_chat_interface.chatbot] 183 | ) 184 | 185 | def toggle_pause(): 186 | if pause_event.is_set(): 187 | pause_event.clear() # Pause processing 188 | return "Resume Video" 189 | else: 190 | pause_event.set() # Resume processing 191 | return "Pause Video" 192 | 193 | gr_pause_button.click( 194 | toggle_pause, 195 | inputs=[], 196 | outputs=gr_pause_button 197 | ) 198 | 199 | def stop_chat(): 200 | liveinfer.reset() 201 | history_synchronizer.reset() 202 | 203 | # enable the hyperparams 204 | for component in hyperparam_block.children: 205 | component.interactive = True 206 | 207 | return 0, 0, 0, None, [] 208 | 209 | gr_stop_button.click( 210 | 
stop_chat, 211 | inputs=[], 212 | outputs=[gr_time_display, gr_inf_score_display, gr_rel_score_display, gr_frame_display, gr_chat_interface.chatbot] 213 | ) 214 | 215 | demo.queue() 216 | demo.launch(share=False) 217 | -------------------------------------------------------------------------------- /demo/assets/assistant_avatar.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yellow-binary-tree/MMDuet/33be1387cd18643614d3c414c3c237bdd4ff59cc/demo/assets/assistant_avatar.png -------------------------------------------------------------------------------- /demo/assets/cooking.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yellow-binary-tree/MMDuet/33be1387cd18643614d3c414c3c237bdd4ff59cc/demo/assets/cooking.mp4 -------------------------------------------------------------------------------- /demo/assets/drive.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yellow-binary-tree/MMDuet/33be1387cd18643614d3c414c3c237bdd4ff59cc/demo/assets/drive.mp4 -------------------------------------------------------------------------------- /demo/assets/office.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yellow-binary-tree/MMDuet/33be1387cd18643614d3c414c3c237bdd4ff59cc/demo/assets/office.mp4 -------------------------------------------------------------------------------- /demo/assets/user_avatar.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yellow-binary-tree/MMDuet/33be1387cd18643614d3c414c3c237bdd4ff59cc/demo/assets/user_avatar.png -------------------------------------------------------------------------------- /demo/liveinfer.py: -------------------------------------------------------------------------------- 1 | import os, math 2 | import cv2 3 | import numpy as np 4 | import torch 5 | from test.inference import LiveInferForBenchmark 6 | 7 | 8 | def load_video(video_file, output_fps): 9 | pad_color = (0, 0, 0) 10 | output_resolution = 384 11 | max_num_frames = 400 12 | 13 | cap = cv2.VideoCapture(video_file) 14 | # Get original video properties 15 | input_fps = cap.get(cv2.CAP_PROP_FPS) 16 | frame_count = cap.get(cv2.CAP_PROP_FRAME_COUNT) 17 | video_duration = frame_count / input_fps 18 | input_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)) 19 | input_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT)) 20 | output_width = output_height = output_resolution 21 | 22 | output_fps = output_fps if output_fps > 0 else max_num_frames / video_duration 23 | num_frames_total = math.floor(video_duration * output_fps) 24 | frame_sec = [i / output_fps for i in range(num_frames_total)] 25 | frame_list, original_frame_list, cur_time, frame_index = [], [], 0, 0 26 | while cap.isOpened(): 27 | ret, frame = cap.read() 28 | if not ret: 29 | break 30 | if frame_index < len(frame_sec) and cur_time >= frame_sec[frame_index]: 31 | original_frame_list.append(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)) 32 | if input_width > input_height: 33 | # Landscape video: scale width to the resolution, adjust height 34 | new_width = output_resolution 35 | new_height = int((input_height / input_width) * output_resolution) 36 | else: 37 | # Portrait video: scale height to the resolution, adjust width 38 | new_height = output_resolution 39 | new_width = int((input_width / input_height) * 
output_resolution) 40 | resized_frame = cv2.resize(frame, (new_width, new_height)) 41 | # pad the frame 42 | canvas = cv2.copyMakeBorder( 43 | resized_frame, 44 | top=(output_height - new_height) // 2, 45 | bottom=(output_height - new_height + 1) // 2, 46 | left=(output_width - new_width) // 2, 47 | right=(output_width - new_width + 1) // 2, 48 | borderType=cv2.BORDER_CONSTANT, 49 | value=pad_color 50 | ) 51 | frame_list.append(np.transpose(cv2.cvtColor(canvas, cv2.COLOR_BGR2RGB), (2, 0, 1))) 52 | frame_index += 1 53 | if len(frame_list) >= max_num_frames: 54 | break 55 | cur_time += 1 / input_fps 56 | cap.release() 57 | return torch.tensor(np.stack(frame_list)), original_frame_list 58 | 59 | 60 | class LiveInferForDemo(LiveInferForBenchmark): 61 | def encode_given_query(self, query): 62 | self.last_ids = self.tokenizer.apply_chat_template([{'role': 'user', 'content': query}], add_stream_query_prompt=self.last_role == 'stream', add_stream_prompt=True, return_tensors='pt').to('cuda') 63 | inputs_embeds = self.model.get_input_embeddings()(self.last_ids) 64 | outputs = self.model(inputs_embeds=inputs_embeds, past_key_values=self.past_key_values, use_cache=True, return_dict=True) 65 | self.past_key_values = outputs.past_key_values 66 | self.last_ids = outputs.logits[:, -1:].argmax(dim=-1) 67 | self.last_role = 'user' 68 | 69 | def input_one_frame(self): 70 | """ 71 | in the interactive demo, we need to input 1 frame each time this function is called. 72 | to ensure that user can stop the video and input user messages. 73 | """ 74 | # 1. the check query step is skipped, as all user input is from the demo page 75 | 76 | # 2. input a frame, and update the scores list 77 | video_scores = self._encode_frame() 78 | ret = dict(frame_idx=self.frame_idx, time=round(self.video_time, 1), **video_scores) # the frame_idx here is after self.frame_idx += 1 79 | 80 | # 3. check the scores, if need to generate a response 81 | need_response = False 82 | stream_end_score = sum([v for k, v in video_scores.items() if k in self.score_heads]) 83 | self.stream_end_prob_list.append(stream_end_score) 84 | self.stream_end_score_sum += stream_end_score 85 | if isinstance(self.running_list_length, int) and self.running_list_length > 0: 86 | self.stream_end_prob_list = self.stream_end_prob_list[-self.running_list_length:] 87 | if self.stream_end_score_sum_threshold is not None and self.stream_end_score_sum > self.stream_end_score_sum_threshold: 88 | need_response = True 89 | self.stream_end_score_sum = 0 90 | if self.stream_end_prob_threshold is not None and stream_end_score > self.stream_end_prob_threshold: 91 | need_response = True 92 | 93 | # 4. record the responses 94 | if need_response: 95 | response = self._generate_response() 96 | self.num_frames_no_reply = 0 97 | self.consecutive_n_frames = 0 98 | else: 99 | response = None 100 | ret['response'] = response 101 | 102 | # 5. 
update the video time 103 | self.video_time += 1 / self.frame_fps 104 | 105 | return ret 106 | -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- 1 | from transformers import HfArgumentParser 2 | 3 | from .arguments_live import LiveTrainingArguments, get_args_class 4 | from .live_llava.video_head_live_llava_qwen import build_video_head_live_llava_qwen 5 | from .modeling_live import fast_greedy_generate 6 | 7 | 8 | def build_model_and_tokenizer(is_training, **kwargs): 9 | llm_pretrained = kwargs.get('llm_pretrained', None) 10 | if 'llava' in llm_pretrained: 11 | return build_video_head_live_llava_qwen(is_training=is_training, **kwargs) 12 | else: 13 | raise NotImplementedError(f'Not support {llm_pretrained}') 14 | 15 | def parse_args(live_version=None) -> LiveTrainingArguments: 16 | if live_version is None: 17 | args, = HfArgumentParser(LiveTrainingArguments).parse_args_into_dataclasses() 18 | live_version = args.live_version 19 | args, = HfArgumentParser(get_args_class(live_version)).parse_args_into_dataclasses() 20 | return args -------------------------------------------------------------------------------- /models/arguments_live.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass, field 2 | from transformers import TrainingArguments 3 | from typing import Union 4 | 5 | @dataclass 6 | class LiveTrainingArguments(TrainingArguments): 7 | live_version: str = 'live1+' 8 | dataset_config: str = None 9 | stream_loss_weight: float = 1.0 10 | llm_pretrained: str = 'lmms-lab/llava-onevision-qwen2-7b-ov' 11 | vision_pretrained: str = 'google/siglip-large-patch16-384' 12 | lora_pretrained: str = None 13 | lora_modules: str = "model\.layers.*(q_proj|k_proj|v_proj|o_proj|gate_proj|up_proj|down_proj)$" 14 | lora_r: int = 16 15 | lora_alpha: int = 32 16 | finetune_modules: list[str] = field(default_factory=lambda: ['connector', 'mm_projector', 'response_head', 'related_head']) 17 | frame_fps: float = 2 18 | frame_token_cls: bool = False 19 | frame_token_pooled: list[int] = field(default_factory=lambda: [7,7]) 20 | frame_num_tokens: int = 49 21 | video_pooling_stride: int = 4 22 | frame_resolution: int = 384 23 | embed_mark: str = '2fps_384_1+3x3' 24 | v_placeholder: str = '' 25 | max_num_frames: int = 100 26 | augmentation: bool = False 27 | attn_implementation: str = 'flash_attention_2' 28 | output_dir: str = 'outputs/debug' 29 | 30 | 31 | @dataclass 32 | class LiveTestArguments(LiveTrainingArguments): 33 | system_prompt: str = ( 34 | "A multimodal AI assistant is helping users with some activities." 35 | " Below is their conversation, interleaved with the list of video frames received by the assistant." 
36 | ) 37 | live_version: str = 'test' 38 | is_online_model: bool = True 39 | grounding_mode: bool = False # if set, only output probs, never generate reply 40 | input_dir: str = 'datasets/shot2story/videos/' 41 | test_fname: str = '' 42 | output_fname: str = '' 43 | repetition_penalty: float = None 44 | stream_end_prob_threshold: float = None 45 | response_min_interval_frames: int = None 46 | threshold_z: float = None 47 | first_n_frames_no_generate: int = 0 48 | consecutive_n_frames_threshold: int = 1 49 | running_list_length: int = 20 50 | start_idx: int = 0 51 | end_idx: int = None 52 | time_instruction_format: str = None 53 | stream_end_score_sum_threshold: float = None 54 | remove_assistant_turns: bool = False # if True, do not add assistant-generated content to input context (kv_cache) 55 | score_heads: str = 'informative_score' # a list of score names, seperated with comma. e.g.: `relevance_score,informative_score` 56 | 57 | 58 | def get_args_class(args_version: str): 59 | if args_version == 'train': 60 | return LiveTrainingArguments 61 | elif args_version == 'test': 62 | return LiveTestArguments 63 | raise NotImplementedError 64 | -------------------------------------------------------------------------------- /models/configuration_live.py: -------------------------------------------------------------------------------- 1 | 2 | from transformers import PretrainedConfig 3 | 4 | class LiveConfigMixin(PretrainedConfig): 5 | def __init__(self, *, vision_pretrained: str = None, 6 | frame_resolution: int = None, frame_token_cls: bool = None, frame_token_pooled: list[int] = None, frame_num_tokens: int = None, 7 | v_placeholder: str = '', v_placeholder_id: int = None, 8 | stream_loss_weight: float = 1.0, vision_hidden_size=1024, **kwargs 9 | ): 10 | super().__init__(**kwargs) 11 | self.vision_pretrained = vision_pretrained 12 | self.frame_resolution = frame_resolution 13 | self.frame_token_cls = frame_token_cls 14 | self.frame_token_pooled = frame_token_pooled 15 | self.frame_num_tokens = frame_num_tokens 16 | self.vision_hidden_size = vision_hidden_size 17 | self.stream_loss_weight = stream_loss_weight 18 | self.v_placeholder = v_placeholder 19 | self.v_placeholder_id = v_placeholder_id 20 | 21 | 22 | class VideoHeadLiveConfigMixin(PretrainedConfig): 23 | def __init__(self, *, vision_pretrained: str = None, 24 | frame_resolution: int = None, frame_token_cls: bool = None, frame_token_pooled: list[int] = None, frame_num_tokens: int = None, 25 | v_placeholder: str = '', v_placeholder_id: int = None, vision_hidden_size=1024, **kwargs 26 | ): 27 | super().__init__(**kwargs) 28 | self.vision_pretrained = vision_pretrained 29 | self.frame_resolution = frame_resolution 30 | self.frame_token_cls = frame_token_cls 31 | self.frame_token_pooled = frame_token_pooled 32 | self.frame_num_tokens = frame_num_tokens 33 | self.vision_hidden_size = vision_hidden_size 34 | 35 | self.v_placeholder = v_placeholder 36 | self.v_placeholder_id = v_placeholder_id 37 | -------------------------------------------------------------------------------- /models/live_llava/video_head_live_llava_qwen.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Hao Zhang 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | 16 | import math 17 | import copy 18 | import random 19 | from typing import List, Optional, Tuple, Union, Dict 20 | import torch 21 | import torch.nn as nn 22 | from torch.nn import CrossEntropyLoss 23 | from dataclasses import dataclass 24 | 25 | import transformers 26 | from transformers import AutoConfig, AutoModelForCausalLM, LlamaConfig, LlamaModel, LlamaForCausalLM 27 | 28 | from transformers.modeling_outputs import CausalLMOutputWithPast 29 | from transformers.generation.utils import GenerateOutput 30 | 31 | from llava.model.llava_arch import LlavaMetaModel 32 | from transformers import Qwen2Config, Qwen2Model, Qwen2ForCausalLM 33 | 34 | from ..modeling_live import build_live, LiveMixin 35 | from ..configuration_live import VideoHeadLiveConfigMixin 36 | 37 | from transformers.utils import logging 38 | logger = logging.get_logger(__name__) 39 | 40 | 41 | class VideoHeadLiveLlavaQwenConfig(Qwen2Config, VideoHeadLiveConfigMixin): 42 | def __init__(self, video_pooling_stride=4, video_head_stop_grad=False, **kwargs): 43 | super().__init__(**kwargs) 44 | self.video_pooling_stride = video_pooling_stride 45 | self.video_head_stop_grad = video_head_stop_grad 46 | 47 | 48 | @dataclass 49 | class VideoHeadCausalLMOutputWithPast(CausalLMOutputWithPast): 50 | loss: Optional[torch.FloatTensor] = None 51 | logits: torch.FloatTensor = None 52 | past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None 53 | hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None 54 | attentions: Optional[Tuple[torch.FloatTensor, ...]] = None 55 | lm_loss: Optional[torch.FloatTensor] = None 56 | video_loss: Optional[torch.FloatTensor] = None 57 | informative_logits: Optional[torch.FloatTensor] = None 58 | relevance_logits: Optional[torch.FloatTensor] = None 59 | 60 | class VideoHeadLlavaQwenModel(LlavaMetaModel, Qwen2Model): 61 | config_class = VideoHeadLiveLlavaQwenConfig 62 | 63 | def __init__(self, config: Qwen2Config): 64 | super(VideoHeadLlavaQwenModel, self).__init__(config) 65 | 66 | 67 | class VideoHeadLiveLlavaQwenForCausalLM(Qwen2ForCausalLM, LiveMixin): 68 | config_class = VideoHeadLiveLlavaQwenConfig 69 | 70 | def __init__(self, config): 71 | Qwen2ForCausalLM.__init__(self, config) 72 | config.model_type = "llava_qwen" 73 | config.rope_scaling = None 74 | 75 | self.model = VideoHeadLlavaQwenModel(config) 76 | self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) 77 | self.informative_head = nn.Linear(config.hidden_size, 2, bias=False) 78 | self.relevance_head = nn.Linear(config.hidden_size, 2, bias=False) 79 | 80 | # Initialize weights and apply final processing 81 | self.post_init() 82 | self.vision_encoder = self.get_vision_tower() 83 | self.lm_loss_weight = 1 84 | self.video_loss_weight = 1 85 | print(f"using lm_loss_weight: {self.lm_loss_weight}, video_loss_weight: {self.video_loss_weight} for training") 86 | 87 | def get_model(self): 88 | return self.model 89 | 90 | def connector(self, frames): 91 | return self.get_model().mm_projector(frames) 92 | 93 | def get_vision_tower(self): 94 | return 
self.get_model().get_vision_tower() 95 | 96 | def vision_encode(self, vision_tower, frames): 97 | frame_features = vision_tower(frames) 98 | return frame_features 99 | 100 | def post_projector_pooling(self, image_feature): 101 | stride = self.config.video_pooling_stride 102 | height = width = self.get_vision_tower().num_patches_per_side 103 | num_frames, num_tokens, num_dim = image_feature.shape 104 | image_feature = image_feature.view(num_frames, height, width, -1) 105 | image_feature = image_feature.permute(0, 3, 1, 2).contiguous() 106 | # image_feature = nn.functional.max_pool2d(image_feature, self.config.mm_spatial_pool_stride) 107 | if self.config.mm_spatial_pool_mode == "average": 108 | image_feature = nn.functional.avg_pool2d(image_feature, stride) 109 | elif self.config.mm_spatial_pool_mode == "max": 110 | image_feature = nn.functional.max_pool2d(image_feature, stride) 111 | elif self.config.mm_spatial_pool_mode == "bilinear": 112 | height, weight = image_feature.shape[2:] 113 | scaled_shape = [math.ceil(height / stride), math.ceil(weight / stride)] 114 | image_feature = nn.functional.interpolate(image_feature, size=scaled_shape, mode='bilinear') 115 | else: 116 | raise ValueError(f"Unexpected mm_spatial_pool_mode: {self.config.mm_spatial_pool_mode}") 117 | image_feature = image_feature.permute(0, 2, 3, 1) 118 | image_feature = image_feature.view(num_frames, -1, num_dim).contiguous() 119 | return image_feature 120 | 121 | def forward( 122 | self, 123 | input_ids: torch.LongTensor = None, 124 | attention_mask: Optional[torch.Tensor] = None, 125 | position_ids: Optional[torch.LongTensor] = None, 126 | past_key_values: Optional[List[torch.FloatTensor]] = None, 127 | inputs_embeds: Optional[torch.FloatTensor] = None, 128 | labels: Optional[torch.LongTensor] = None, 129 | informative_labels: Optional[torch.LongTensor] = None, 130 | relevance_labels: Optional[torch.LongTensor] = None, 131 | use_cache: Optional[bool] = None, 132 | output_attentions: Optional[bool] = None, 133 | output_hidden_states: Optional[bool] = None, 134 | frames: Optional[torch.FloatTensor] = None, 135 | return_dict: Optional[bool] = None, 136 | **kwargs, 137 | ) -> Union[Tuple, CausalLMOutputWithPast]: 138 | if inputs_embeds is None: 139 | inputs_embeds = self.joint_embed(input_ids, frames) 140 | 141 | outputs = self.model( 142 | attention_mask=attention_mask, 143 | position_ids=position_ids, 144 | past_key_values=past_key_values, 145 | inputs_embeds=inputs_embeds, 146 | use_cache=use_cache, 147 | output_attentions=output_attentions, 148 | output_hidden_states=output_hidden_states, 149 | return_dict=return_dict, 150 | ) 151 | 152 | model_outputs = copy.copy(outputs) 153 | hidden_states = outputs[0] 154 | outputs = outputs[1:] 155 | logits = self.lm_head(hidden_states).float() 156 | if self.config.video_head_stop_grad: 157 | hidden_states_no_grad = hidden_states.detach() 158 | else: 159 | hidden_states_no_grad = hidden_states 160 | informative_logits = self.informative_head(hidden_states_no_grad).float() 161 | relevance_logits = self.relevance_head(hidden_states_no_grad).float() 162 | 163 | # NOTE: all labels used here are already shifted in data collator 164 | loss_fct = CrossEntropyLoss() 165 | loss = 0. 
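# The loss computed below combines two objectives: `lm_loss`, the cross-entropy over tokens whose
# labels are not -100 (the learnable assistant spans prepared by the data collator), and `video_loss`,
# the cross-entropy of the informative and relevance heads over `informative_labels` / `relevance_labels`.
# The two parts are weighted by `lm_loss_weight` and `video_loss_weight` (both set to 1 in `__init__`).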
166 | 167 | if labels is not None: 168 | if not(labels != -100).any(): 169 | labels[:, 0] = input_ids[:, 1] # make sure lm_loss is calculated for every example, or the deepspeed training process will hang 170 | lm_loss = loss_fct(logits.flatten(0, 1), labels.flatten()) 171 | if not return_dict: 172 | outputs = (logits,) + outputs + (loss,) 173 | else: 174 | lm_loss = 0. 175 | 176 | # merge the 2 labels together, so this loss must be calculated as all training examples contains either informative_label or relevance_label. 177 | # otherwise the deepspeed training process will hang 178 | if informative_labels is not None and relevance_labels is not None: 179 | video_labels = torch.cat([informative_labels, relevance_labels], dim=0) 180 | video_logits = torch.cat([informative_logits, relevance_logits], dim=0) 181 | if not (video_labels != -100).any(): 182 | video_labels[:, 0] = 0 # make sure video_loss is calculated for every example, or the deepspeed training process will hang 183 | video_loss = loss_fct(video_logits.flatten(0, 1), video_labels.flatten()) 184 | if not return_dict: 185 | outputs = outputs + (video_loss,) 186 | else: 187 | video_loss = 0. 188 | 189 | loss = lm_loss * self.lm_loss_weight + video_loss * self.video_loss_weight 190 | 191 | if not return_dict: 192 | outputs = (loss,) + outputs 193 | return outputs 194 | 195 | return VideoHeadCausalLMOutputWithPast( 196 | loss=loss, 197 | logits=logits, 198 | past_key_values=model_outputs.past_key_values, 199 | hidden_states=model_outputs.hidden_states, 200 | attentions=model_outputs.attentions, 201 | lm_loss=lm_loss, 202 | video_loss=video_loss, 203 | informative_logits=informative_logits, 204 | relevance_logits=relevance_logits, 205 | ) 206 | 207 | def generate_after_embed(self, input_ids, frames, **kwargs): 208 | return super().generate(inputs_embeds=self.joint_embed(input_ids, frames), **kwargs) 209 | 210 | @torch.no_grad() 211 | def generate( 212 | self, 213 | input_ids: torch.LongTensor = None, 214 | attention_mask: Optional[torch.Tensor] = None, 215 | position_ids: Optional[torch.LongTensor] = None, 216 | past_key_values: Optional[List[torch.FloatTensor]] = None, 217 | inputs_embeds: Optional[torch.FloatTensor] = None, 218 | use_cache: Optional[bool] = None, 219 | output_attentions: Optional[bool] = None, 220 | output_hidden_states: Optional[bool] = None, 221 | frames: Optional[torch.FloatTensor] = None, 222 | return_dict: Optional[bool] = None, 223 | **kwargs, 224 | ) -> Union[GenerateOutput, torch.LongTensor]: 225 | ''' 226 | The original generate function of LLaVA. 227 | ''' 228 | logger.warning('You are calling the generate function of LLaVA, which is deprecated for Live Video models. 
Please use a LiveInfer class for inference.') 229 | if inputs_embeds is None: 230 | inputs_embeds = self.joint_embed(input_ids, frames) 231 | outputs = super().generate( 232 | attention_mask=attention_mask, 233 | position_ids=position_ids, 234 | past_key_values=past_key_values, 235 | inputs_embeds=inputs_embeds, 236 | use_cache=use_cache, 237 | output_attentions=output_attentions, 238 | output_hidden_states=output_hidden_states, 239 | return_dict=return_dict, 240 | **kwargs 241 | ) 242 | return outputs 243 | 244 | 245 | def build_video_head_live_llava_qwen(**kwargs): 246 | model, tokenizer = build_live(config_class=VideoHeadLiveLlavaQwenConfig, model_class=VideoHeadLiveLlavaQwenForCausalLM, **kwargs) 247 | # freeze vit 248 | print('freezing ViT') 249 | for param in model.get_vision_tower().parameters(): 250 | param.requires_grad = False 251 | return model, tokenizer 252 | 253 | if __name__ == '__main__': 254 | from transformers import HfArgumentParser 255 | from models.arguments_live import LiveTrainingArguments 256 | args, = HfArgumentParser(LiveTrainingArguments).parse_args_into_dataclasses() 257 | args.llm_pretrained = "lmms-lab/llava-onevision-qwen2-7b-ov" 258 | print(args.to_dict()) 259 | model, tokenizer = build_video_head_live_llava_qwen(is_training=True, **args.to_dict()) 260 | print(model.config, tokenizer) 261 | -------------------------------------------------------------------------------- /models/modeling_live.py: -------------------------------------------------------------------------------- 1 | import torch, os 2 | import torch.distributed as dist 3 | from peft import LoraConfig, get_peft_model, PeftModel 4 | from transformers import AutoModelForCausalLM, Cache 5 | from transformers.utils import logging 6 | from transformers import LogitsProcessorList, RepetitionPenaltyLogitsProcessor 7 | 8 | from .tokenization_live import build_live_tokenizer_and_update_config 9 | from .vision_live import build_live_vision 10 | 11 | logger = logging.get_logger(__name__) 12 | 13 | class LiveMixin(AutoModelForCausalLM): 14 | def set_vision_inside(self): 15 | logger.warning_once("!!! Set vision encoder in the model, only recommended for on in-the-wild inference. " 16 | "Please dont call this for efficient training & evaluation. 
Instead, do visual feature pre-extraction.") 17 | if not hasattr(self, 'vision_encoder'): 18 | self.vision_encoder, self.vision_encode = build_live_vision(self.config) 19 | else: 20 | logger.warning_once("Vision encoder already exists, skip setting vision encoder inside the model.") 21 | 22 | def unset_vision_inside(self): 23 | del self.vision_encoder 24 | del self.vision_encode 25 | 26 | def visual_embed(self, frames: torch.Tensor): 27 | if hasattr(self, 'vision_encode'): 28 | with torch.cuda.amp.autocast(): 29 | frames = self.vision_encode(self.vision_encoder, frames) 30 | frames = self.connector(frames) 31 | if hasattr(self, 'post_projector_pooling'): 32 | frames = self.post_projector_pooling(frames) 33 | return frames.view(-1, frames.shape[-1]) 34 | 35 | def joint_embed( 36 | self, 37 | input_ids: torch.Tensor = None, 38 | frames: torch.Tensor = None, 39 | ): 40 | if frames is None: 41 | return self.get_input_embeddings()(input_ids) 42 | if input_ids is None: 43 | return self.visual_embed(frames) 44 | inputs_embeds = self.get_input_embeddings()(input_ids.clamp(max=self.vocab_size-1)) 45 | v_mask = input_ids == self.config.v_placeholder_id 46 | if v_mask.any(): 47 | inputs_embeds[v_mask] = self.visual_embed(frames).to(inputs_embeds.dtype) 48 | return inputs_embeds 49 | 50 | 51 | def fast_greedy_generate(*, model: LiveMixin, inputs_embeds: torch.Tensor, past_key_values: Cache, eos_token_id: int, inplace_output_ids: torch.Tensor, 52 | repetition_penalty=None, generated_token_ids=list()): 53 | if repetition_penalty is not None: 54 | assert isinstance(repetition_penalty, float) 55 | logits_processor = RepetitionPenaltyLogitsProcessor(penalty=repetition_penalty) 56 | 57 | for i in range(inplace_output_ids.size(1)): 58 | outputs = model(inputs_embeds=inputs_embeds, past_key_values=past_key_values, use_cache=True, return_dict=True) 59 | past_key_values = outputs.past_key_values 60 | if repetition_penalty is not None: 61 | if len(generated_token_ids) > 0: 62 | outputs_logits = logits_processor( 63 | input_ids=torch.tensor(generated_token_ids).unsqueeze(0).to(device=inplace_output_ids.device, dtype=torch.long), scores=outputs.logits[:, -1, :]) 64 | outputs_logits = outputs_logits.unsqueeze(1) 65 | else: 66 | outputs_logits = outputs.logits[:, -1:] 67 | new_token_id = outputs_logits.argmax(dim=-1) 68 | if not new_token_id == eos_token_id: # special tokens should not be penalized 69 | generated_token_ids.append(new_token_id.item()) 70 | else: 71 | outputs_logits = outputs.logits 72 | new_token_id = outputs_logits[:, -1:].argmax(dim=-1) 73 | inplace_output_ids[:, i] = new_token_id 74 | if new_token_id == eos_token_id: 75 | break 76 | inputs_embeds = model.get_input_embeddings()(new_token_id) 77 | return inplace_output_ids[:, :i+1], past_key_values, generated_token_ids 78 | 79 | 80 | def build_live( 81 | *, 82 | is_training: bool, 83 | config_class: type, 84 | model_class: type, 85 | llm_pretrained: str = None, 86 | lora_pretrained: str = None, 87 | finetune_modules: list[str] = None, 88 | lora_modules: str = None, 89 | lora_r: int = None, 90 | lora_alpha: int = None, 91 | set_vision_inside: bool = False, 92 | attn_implementation: str = 'flash_attention_2', 93 | torch_dtype: str | torch.dtype = 'auto', 94 | **kwargs 95 | ): 96 | model = model_class.from_pretrained( 97 | llm_pretrained, config=config_class.from_pretrained(llm_pretrained, **kwargs), 98 | torch_dtype=torch_dtype, attn_implementation=attn_implementation, 99 | device_map='cuda' if torch.cuda.device_count() == 1 or dist.is_initialized() 
else 'auto') 100 | tokenizer = build_live_tokenizer_and_update_config(llm_pretrained, model.config) 101 | logger.warning(f"model config after update: {model.config}") 102 | if is_training: 103 | if lora_pretrained: 104 | print(f'loading lora from checkpoint: {lora_pretrained}') 105 | model = PeftModel.from_pretrained(model, lora_pretrained, is_trainable=False) 106 | else: 107 | lora_config = LoraConfig( 108 | r=lora_r, 109 | lora_alpha=lora_alpha, 110 | target_modules=lora_modules, 111 | lora_dropout=0.05, 112 | task_type="CAUSAL_LM", 113 | modules_to_save=finetune_modules, 114 | inference_mode=False, 115 | ) 116 | print(f'creating lora with config: {lora_config}') 117 | model = get_peft_model(model, lora_config, autocast_adapter_dtype=False) 118 | model.print_trainable_parameters() 119 | 120 | else: 121 | if lora_pretrained: 122 | logger.info(f'loading lora from checkpoint: {lora_pretrained}') 123 | model = PeftModel.from_pretrained(model, lora_pretrained, is_trainable=False) 124 | else: 125 | logger.warning(f'!!! Fail to load lora from checkpoint: {lora_pretrained}. Return a new initialized model.') 126 | if set_vision_inside: 127 | model.set_vision_inside() 128 | model.requires_grad_(False) 129 | return model, tokenizer 130 | -------------------------------------------------------------------------------- /models/tokenization_live.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from transformers import AutoTokenizer, Qwen2Tokenizer 3 | from functools import partial 4 | from .configuration_live import LiveConfigMixin, VideoHeadLiveConfigMixin 5 | 6 | 7 | def get_stream_placeholder_len(num_frames: int, model_config: VideoHeadLiveConfigMixin) -> str: 8 | return num_frames * model_config.frame_num_tokens * len(model_config.v_placeholder) 9 | 10 | 11 | def get_stream_placeholder_jinja2(model_config: VideoHeadLiveConfigMixin) -> str: 12 | # * (frame_num_tokens * num_frames) 13 | return f"''.join([{model_config.frame_num_tokens} * '{model_config.v_placeholder}'] * message['num_frames'])" 14 | 15 | 16 | def get_stream_learn_ranges(num_frames: int, model_config: LiveConfigMixin, is_grounding_task) -> torch.Tensor: 17 | ''' 18 | the start/end idx of every frame_token_interval or stream_end_token after each frame 19 | ''' 20 | len_frame_placeholder_with_interval = model_config.frame_num_tokens * len(model_config.v_placeholder) + len(model_config.frame_token_interval) 21 | intermediate_interval_idxs = torch.arange( 22 | len_frame_placeholder_with_interval, 23 | len_frame_placeholder_with_interval * num_frames + 1, 24 | len_frame_placeholder_with_interval 25 | ) - len(model_config.frame_token_interval) 26 | len_learn = torch.LongTensor([len(model_config.frame_token_interval)] * (num_frames - 1) + [len(model_config.frame_token_interval) if is_grounding_task else len(model_config.stream_end_token)]) 27 | learn_ranges = torch.stack([ 28 | intermediate_interval_idxs, 29 | intermediate_interval_idxs + len_learn 30 | ], dim=1) 31 | return learn_ranges 32 | 33 | 34 | def chat_template_llava(self, stream_placeholder): 35 | template = ( 36 | "{% if messages[0]['role'] == 'system' %}" 37 | "{{ bos_token + 'system\n' + messages[0]['content'] + eos_token}}" # system 38 | "{% set messages = messages[1:] %}" 39 | "{% endif %}" 40 | "{% for i in range(messages | length) %}" 41 | "{% set message = messages[i] %}" 42 | "{% if message['role'] == 'user' %}" 43 | "{% if add_stream_query_prompt %}" 44 | "{{ eos_token + '\n' + bos_token + 'user\n' + 
message['content'] + eos_token }}" 45 | "{% else %}" 46 | "{{ '\n' + bos_token + 'user\n' + message['content'] + eos_token }}" 47 | "{% endif %}" 48 | "{% elif message['role'] == 'assistant' %}" 49 | "{{ '\n' + bos_token + 'assistant\n' + message['content'] + eos_token }}" 50 | "{% elif message['role'] == 'stream' and message['num_frames'] > 0 %}" 51 | "{{ '\n' + bos_token + 'stream\n' + STREAM_PLACEHOLDER + eos_token }}" 52 | "{% endif %}" 53 | "{% endfor %}" 54 | "{% if add_generation_prompt %}" 55 | "{{ '\n' + bos_token + 'assistant\n' }}" 56 | "{% elif add_stream_prompt %}" 57 | "{{ '\n' + bos_token + 'stream\n' }}" 58 | "{% elif add_stream_generation_prompt %}" 59 | "{{ eos_token + '\n' + bos_token + 'assistant\n' }}" 60 | "{% endif %}" 61 | ) 62 | template = template.replace('STREAM_PLACEHOLDER', stream_placeholder) 63 | return template 64 | 65 | 66 | def chat_template_offsets_llava(tokenizer): 67 | # now the turn of all roles start with similar beginnings 68 | def chat_template_transition(): 69 | return { 70 | (None, 'system'): f'{tokenizer.bos_token}system\n', 71 | ('system', 'user'): f'{tokenizer.eos_token}\n{tokenizer.bos_token}user\n', 72 | ('system', 'stream'): f'{tokenizer.eos_token}\n{tokenizer.bos_token}stream\n', 73 | ('user', 'assistant'): f'{tokenizer.eos_token}\n{tokenizer.bos_token}assistant\n', 74 | ('user', 'stream'): f'{tokenizer.eos_token}\n{tokenizer.bos_token}stream\n', 75 | ('user', 'user'): f'{tokenizer.eos_token}\n{tokenizer.bos_token}user\n', 76 | ('assistant', 'user'): f'{tokenizer.eos_token}\n{tokenizer.bos_token}user\n', 77 | ('assistant', 'stream'): f'{tokenizer.eos_token}\n{tokenizer.bos_token}stream\n', 78 | ('stream', 'user'): f'{tokenizer.eos_token}\n{tokenizer.bos_token}user\n', 79 | ('stream', 'assistant'): f'{tokenizer.eos_token}\n{tokenizer.bos_token}assistant\n', 80 | ('stream', 'stream'): f'{tokenizer.eos_token}\n{tokenizer.bos_token}stream\n', 81 | 'assistant': f'{tokenizer.bos_token}assistant\n', 82 | 'eos_token': tokenizer.eos_token, 83 | } 84 | return {k:len(v) for k, v in chat_template_transition().items()} 85 | 86 | 87 | CHAT_TEMPLATES = { 88 | 'llava': chat_template_llava 89 | } 90 | 91 | CHAT_TEMPLATE_OFFSETS = { 92 | 'llava': chat_template_offsets_llava 93 | } 94 | 95 | 96 | def get_learn_ranges(conversation: list[dict], *, chat_template_offsets: dict[tuple, int], model_config: VideoHeadLiveConfigMixin): 97 | offset = 0 98 | learn_ranges = [] 99 | last_role = None 100 | for message_i, message in enumerate(conversation): 101 | role = message['role'] 102 | offset += chat_template_offsets[(last_role, role)] 103 | last_role = role 104 | if role == 'stream': 105 | # we do not to use lm_loss to learn anything in the stream 106 | offset += get_stream_placeholder_len(message['num_frames'], model_config) 107 | else: 108 | if role == 'assistant': 109 | if message.get('learn', False): 110 | learn_ranges.append(range(offset, offset + len(message['content']) + chat_template_offsets['eos_token'])) 111 | offset += len(message['content']) 112 | return learn_ranges 113 | 114 | 115 | def build_live_tokenizer_and_update_config(llm_pretrained: str, model_config: LiveConfigMixin) -> AutoTokenizer: 116 | if 'llava' in llm_pretrained: 117 | tokenizer = AutoTokenizer.from_pretrained(llm_pretrained, use_fast=True, padding_side='left') 118 | tokenizer.add_special_tokens({'additional_special_tokens': [model_config.v_placeholder]}) 119 | v_placeholder_id = tokenizer.convert_tokens_to_ids(model_config.v_placeholder) 120 | tokenizer.bos_token, tokenizer.eos_token = 
"<|im_start|>", "<|im_end|>" 121 | 122 | model_config.update(dict( 123 | v_placeholder_id=v_placeholder_id, 124 | eos_token_id=tokenizer.eos_token_id)) 125 | 126 | tokenizer.chat_template = CHAT_TEMPLATES['llava']( 127 | tokenizer, 128 | get_stream_placeholder_jinja2(model_config), 129 | ) 130 | tokenizer.get_learn_ranges = partial(get_learn_ranges, chat_template_offsets=CHAT_TEMPLATE_OFFSETS['llava'](tokenizer), model_config=model_config) 131 | return tokenizer 132 | 133 | else: 134 | raise NotImplementedError 135 | 136 | 137 | if __name__ == '__main__': 138 | chat = [ 139 | {'role': 'system', 'content': 'System message 1.'}, 140 | {'role': 'stream', 'num_frames': 2, 'learn': 1}, 141 | {'role': 'user', 'content': 'User message 1?'}, 142 | {'role': 'assistant', 'content': 'Assistant message 1.', 'learn': True}, 143 | {'role': 'stream', 'num_frames': 3, 'learn': 3}, 144 | {'role': 'assistant', 'content': 'Assistant message 2.', 'learn': True}, 145 | {'role': 'user', 'content': 'User message 2?'}, 146 | {'role': 'stream', 'num_frames': 4, 'learn': 4}, 147 | {'role': 'assistant', 'content': 'Assistant message 3.', 'learn': True}, 148 | ] 149 | 150 | llava_config = VideoHeadLiveConfigMixin(v_placeholder='', 151 | frame_token_cls=True, frame_token_pooled=[3,3], frame_num_tokens=10) 152 | llava_tokenizer = build_live_tokenizer_and_update_config('lmms-lab/llava-onevision-qwen2-7b-ov', llava_config) 153 | 154 | for model_name, tokenizer in [('llava', llava_tokenizer)]: 155 | print('model name:', model_name) 156 | prompt = tokenizer.apply_chat_template(chat, tokenize=False, add_generation_prompt=False) 157 | learn_ranges = tokenizer.get_learn_ranges(chat) 158 | batch = tokenizer([prompt], return_offsets_mapping=True, add_special_tokens=False, return_tensors="pt", padding=True) 159 | print('prompt:', prompt) 160 | print('batch:', batch) 161 | print('learn_ranges:', learn_ranges) 162 | print('learend text:') 163 | for learn_r in learn_ranges: 164 | print(prompt[learn_r.start:learn_r.stop], end='\n----------\n') 165 | 166 | batch_labels = torch.full_like(batch.input_ids, -100, dtype=torch.long) 167 | for text, labels, input_ids, offset_mapping, learn_range in zip( 168 | [prompt], batch_labels, batch.input_ids, batch.offset_mapping, [learn_ranges] 169 | ): 170 | for learn_r in learn_range: 171 | start = torch.nonzero(offset_mapping[:,0] == learn_r.start).item() 172 | if offset_mapping[:,0][-1] >= learn_r.stop: 173 | stop = torch.nonzero(offset_mapping[:,0] == learn_r.stop).item() 174 | else: # the last eos token 175 | stop = len(input_ids) 176 | labels[start-1:stop-1] = input_ids[start:stop] 177 | labels[labels >= len(tokenizer) - 1] = tokenizer.eos_token_id 178 | print(batch.input_ids) 179 | print(batch_labels) -------------------------------------------------------------------------------- /models/vision_live.py: -------------------------------------------------------------------------------- 1 | import math, torch 2 | from functools import partial 3 | from torch import nn, Tensor 4 | from torchvision.transforms.functional import normalize 5 | from transformers import AutoModel 6 | from transformers.utils.constants import OPENAI_CLIP_MEAN, OPENAI_CLIP_STD 7 | 8 | from .configuration_live import LiveConfigMixin 9 | 10 | 11 | def _siglip_vision_encode(vision_model: nn.Module, frames: Tensor, frame_token_cls: bool, frame_token_pooled: tuple, 12 | mean=[0.5,0.5,0.5], std=[0.5,0.5,0.5], rescale_factor=0.00392156862745098, **kwargs): 13 | frames = normalize(frames * rescale_factor, mean=mean, std=std) 14 | 
with torch.cuda.amp.autocast(): 15 | vision_outputs = vision_model(frames) 16 | last_hidden_state = vision_outputs.last_hidden_state 17 | if frame_token_pooled: 18 | s = int(math.sqrt(last_hidden_state.shape[1])) 19 | spatial_tokens = torch.nn.functional.adaptive_avg_pool2d( 20 | last_hidden_state.reshape( 21 | last_hidden_state.shape[0], s, s, last_hidden_state.shape[-1] 22 | ).permute(0, 3, 1, 2), 23 | frame_token_pooled 24 | ).flatten(2, 3).permute(0, 2, 1) 25 | if not frame_token_cls: 26 | return spatial_tokens 27 | if frame_token_cls: 28 | cls_token = vision_outputs.pooler_output[:, None] 29 | if not frame_token_pooled: 30 | return cls_token 31 | return torch.cat([cls_token, spatial_tokens], dim=1) 32 | 33 | 34 | def _clip_vision_encode(vision_model: nn.Module, frames: Tensor, frame_token_cls: bool, frame_token_pooled: tuple, 35 | mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, rescale_factor=0.00392156862745098, **kwargs): 36 | frames = normalize(frames * rescale_factor, mean=mean, std=std) 37 | with torch.cuda.amp.autocast(): 38 | vision_outputs = vision_model(frames) 39 | last_hidden_state = vision_outputs.last_hidden_state 40 | if frame_token_pooled: 41 | s = int(math.sqrt(last_hidden_state.shape[1])) 42 | spatial_tokens = torch.nn.functional.adaptive_avg_pool2d( 43 | last_hidden_state[:,1:].reshape( 44 | last_hidden_state.shape[0], s, s, last_hidden_state.shape[-1] 45 | ).permute(0, 3, 1, 2), 46 | frame_token_pooled 47 | ).flatten(2, 3).permute(0, 2, 1) 48 | if not frame_token_cls: 49 | return spatial_tokens 50 | if frame_token_cls: 51 | cls_token = last_hidden_state[:,0] 52 | if not frame_token_pooled: 53 | return cls_token 54 | return torch.cat([cls_token, spatial_tokens], dim=1) 55 | 56 | 57 | def build_live_vision(config: LiveConfigMixin): 58 | model = AutoModel.from_pretrained(config.vision_pretrained).vision_model 59 | if 'google/siglip-large-patch16-384' == config.vision_pretrained: 60 | return model, partial(_siglip_vision_encode, frame_token_cls=config.frame_token_cls, frame_token_pooled=config.frame_token_pooled) 61 | elif 'laion/CLIP-ViT-L-14-DataComp.XL-s13B-b90k' == config.vision_pretrained or 'openai/clip-vit-large-patch14-336' == config.vision_pretrained: 62 | return model, partial(_clip_vision_encode, frame_token_cls=config.frame_token_cls, frame_token_pooled=config.frame_token_pooled) 63 | else: 64 | raise ValueError(f'Unverified vision_pretrained: {config.vision_pretrained}') -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | accelerate==0.33.0 2 | av==13.0.0 3 | bitsandbytes==0.41.0 4 | datasets==2.16.1 5 | decord==0.6.0 6 | deepspeed==0.15.1 7 | einops==0.6.1 8 | fastapi==0.112.2 9 | ffmpeg-python==0.2.0 10 | Flask==3.0.3 11 | gradio==4.42.0 12 | gradio_client==1.3.0 13 | huggingface-hub==0.24.6 14 | Jinja2==3.1.4 15 | matplotlib==3.9.2 16 | moviepy==1.0.3 17 | ninja==1.11.1.1 18 | nltk==3.9.1 19 | numpy==1.26.1 20 | omegaconf==2.3.0 21 | open_clip_torch==2.26.1 22 | openai==1.51.2 23 | opencv-python==4.10.0.84 24 | pandas==2.2.2 25 | peft==0.12.0 26 | pillow==10.4.0 27 | pycocotools==2.0.8 28 | safetensors==0.4.4 29 | scikit-learn==1.2.2 30 | scipy==1.14.1 31 | submitit==1.5.1 32 | tensorboard==2.17.1 33 | tiktoken==0.8.0 34 | timm==1.0.9 35 | tokenizers==0.19.1 36 | tomlkit==0.12.0 37 | torch==2.1.2 38 | torchaudio==2.1.2 39 | torchvision==0.16.2 40 | tqdm==4.66.5 41 | transformers==4.44.2 42 | wandb==0.17.9 43 | --------------------------------------------------------------------------------
/scripts/inference/charades.sh: -------------------------------------------------------------------------------- 1 | output_dir=outputs/mmduet 2 | mkdir -vp ${output_dir}/eval 3 | 4 | 5 | # -------------------- 6 | # run inference 7 | # -------------------- 8 | python -u -m test.inference --grounding_mode true \ 9 | --llm_pretrained lmms-lab/llava-onevision-qwen2-7b-ov --bf16 true \ 10 | --lora_pretrained ${output_dir} \ 11 | --stream_end_prob_threshold 1 \ 12 | --input_dir datasets/charades/videos --frame_fps 2 --max_num_frames 400 \ 13 | --test_fname datasets/charades/annotations/test-random_prompt.json \ 14 | --output_fname ${output_dir}/eval/charades_test-random_prompt-pred.json \ 15 | > ${output_dir}/eval/charades_test-random_prompt-pred.log 2>&1 & 16 | wait 17 | 18 | 19 | # -------------------- 20 | # evaluate 21 | # -------------------- 22 | python -u -m test.evaluate --func grounding \ 23 | --pred_file ${output_dir}/eval/charades_test-random_prompt-pred.json \ 24 | --gold_file datasets/charades/annotations/test-random_prompt.json \ 25 | --output_file ${output_dir}/eval/charades_test-random_prompt-eval.json \ 26 | > ${output_dir}/eval/charades_test-random_prompt-eval.log 2>&1 & 27 | -------------------------------------------------------------------------------- /scripts/inference/magqa.sh: -------------------------------------------------------------------------------- 1 | output_dir=outputs/mmduet 2 | mkdir -vp ${output_dir}/eval 3 | 4 | thres=0.5 5 | 6 | # -------------------- 7 | # run inference 8 | # -------------------- 9 | python -u -m test.inference \ 10 | --llm_pretrained lmms-lab/llava-onevision-qwen2-7b-ov --bf16 true \ 11 | --lora_pretrained ${output_dir} \ 12 | --input_dir datasets/shot2story/videos --frame_fps 2 --max_num_frames 400 \ 13 | --test_fname datasets/shot2story/annotations/magqa_test.json \ 14 | --stream_end_prob_threshold ${thres} --score_heads "informative_score,relevance_score" \ 15 | --remove_assistant_turns true \ 16 | --output_fname ${output_dir}/eval/magqa_test-thres_${thres}-rm_ass_turn-pred.json \ 17 | > ${output_dir}/eval/magqa_test-thres_${thres}-rm_ass_turn-pred.log 2>&1 & 18 | wait 19 | 20 | # --stream_end_prob_threshold when the scores reach this theshold, stop the video stream and start a assistant turn. 21 | # --score_heads "informative_score,relevance_score" the two scores are added when comparing with stream_end_prob_threshold 22 | # --remove_assistant_turns is the rm. ass. turns trick that do not add previous assistant-generated responses in conversation context. 23 | 24 | 25 | # -------------------- 26 | # use LLaMA to evaluate and get the in-span score 27 | # -------------------- 28 | # 1. calculate similarities between pred and gold answers 29 | python -u -m test.evaluate --func magqa \ 30 | --llm_pretrained meta-llama/Meta-Llama-3.1-70B-Instruct \ 31 | --gold_file datasets/shot2story/annotations/magqa_test.json \ 32 | --pred_file ${output_dir}/eval/magqa_test-thres_${thres}-rm_ass_turn-pred.json \ 33 | --output_file ${output_dir}/eval/magqa_test-thres_${thres}-rm_ass_turn-llama_score-eval.json \ 34 | > ${output_dir}/eval/magqa_test-thres_${thres}-rm_ass_turn-llama_score-eval.log 2>&1 & 35 | wait 36 | 37 | # 2. 
analyze the LLaMA-calculated scores to get the final in-span score 38 | python test/analyze_magqa_results.py \ 39 | --fname ${output_dir}/eval/magqa_test-thres_${thres}-rm_ass_turn-llama_score-eval.json 40 | 41 | 42 | # -------------------- 43 | # use GPT-4o to evaluate and get the in-span score 44 | # -------------------- 45 | # we use OpenAI batch api to calculate the pred-gold answer similarity to save money and time 46 | 47 | # 0. set the openai key 48 | export OPENAI_API_KEY="your_openai_api_key" 49 | 50 | # 1. create batch input 51 | python test/openai_batch.py --func batch_input \ 52 | --pred_file ${output_dir}/eval/magqa_test-thres_${thres}-rm_ass_turn-pred.json \ 53 | --gold_file datasets/shot2story/annotations/magqa_test.json \ 54 | --output_file ${output_dir}/eval/openai/magqa_test-thres_${thres}-rm_ass_turn-pred-batch_input.jsonl 55 | wc -l ${output_dir}/eval/openai/magqa_test-thres_${thres}-rm_ass_turn-pred-batch_input.jsonl 56 | 57 | # 2. submit this batch 58 | python test/openai_batch.py --func send_batch \ 59 | --pred_file ${output_dir}/eval/openai/magqa_test-thres_${thres}-rm_ass_turn-pred-batch_input.jsonl \ 60 | --description "xxx magqa test set evaluate" # you can change to other descriptions for this task 61 | 62 | # 3. wait until the Batch API service finish the evaluation process. You can check the progress by running the following command and get output_file_id 63 | python test/openai_batch.py --func check_batch 64 | 65 | # 4. download the similarity results 66 | python test/openai_batch.py --func get_batch \ 67 | --file_id OUTPUT_FILE_ID_YOU_GOT_IN_THE_LAST_STEP \ 68 | --output_file ${output_dir}/eval/openai/magqa_test-thres_${thres}-rm_ass_turn-pred-batch_output.jsonl 69 | wc -l ${output_dir}/eval/openai/magqa_test-thres_${thres}-rm_ass_turn-pred-batch_output.jsonl 70 | 71 | # 5. reformat the results to the format similar to LLaMA output: 72 | python test/openai_batch.py --func batch_output \ 73 | --pred_file ${output_dir}/eval/magqa_test-thres_${thres}-rm_ass_turn-pred.json \ 74 | --gold_file datasets/shot2story/annotations/magqa_test.json \ 75 | --openai_file ${output_dir}/eval/openai/magqa_test-thres_${thres}-rm_ass_turn-pred-batch_output.jsonl \ 76 | --output_file ${output_dir}/eval/magqa_test-thres_${thres}-rm_ass_turn-gpt4o_score-eval.json 77 | wc -l ${output_dir}/eval/magqa_test-thres_${thres}-rm_ass_turn-gpt4o_score-eval.json 78 | 79 | # 6. 
analyze the GPT-4o-calculated scores to get the final in-span score 80 | python test/analyze_magqa_results.py \ 81 | --fname ${output_dir}/eval/magqa_test-thres_${thres}-rm_ass_turn-gpt4o_score-eval.json 82 | -------------------------------------------------------------------------------- /scripts/inference/qvh.sh: -------------------------------------------------------------------------------- 1 | output_dir=outputs/mmduet 2 | mkdir -vp ${output_dir}/eval 3 | 4 | 5 | # -------------------- 6 | # run inference 7 | # -------------------- 8 | python -u -m test.inference --grounding_mode true \ 9 | --llm_pretrained lmms-lab/llava-onevision-qwen2-7b-ov --bf16 true \ 10 | --lora_pretrained ${output_dir} \ 11 | --stream_end_prob_threshold 1 \ 12 | --input_dir datasets/qvh/videos --frame_fps 1 --max_num_frames 400 \ 13 | --test_fname datasets/qvh/annotations/highlight_val-random_prompt.json \ 14 | --output_fname ${output_dir}/eval/qvh_val-random_prompt-pred.json \ 15 | > ${output_dir}/eval/qvh_val-random_prompt-pred.log 2>&1 & 16 | wait 17 | 18 | 19 | # -------------------- 20 | # evaluate 21 | # -------------------- 22 | python -u -m test.evaluate --func qvh_highlight \ 23 | --pred_file ${output_dir}/eval/qvh_val-random_prompt-pred.json \ 24 | --gold_file datasets/qvh/annotations/highlight_val_release.jsonl \ 25 | --output_file ${output_dir}/eval/qvh_val-random_prompt-eval.json \ 26 | > ${output_dir}/eval/qvh_val-random_prompt-eval.log 2>&1 & 27 | -------------------------------------------------------------------------------- /scripts/inference/youcook2.sh: -------------------------------------------------------------------------------- 1 | output_dir=outputs/mmduet 2 | mkdir -vp ${output_dir}/eval 3 | 4 | thres_sum=2 5 | 6 | # -------------------- 7 | # run inference 8 | # -------------------- 9 | python -u -m test.inference \ 10 | --llm_pretrained lmms-lab/llava-onevision-qwen2-7b-ov --bf16 true \ 11 | --lora_pretrained ${output_dir} \ 12 | --input_dir datasets/youcook2/videos --frame_fps 0.5 --max_num_frames 200 \ 13 | --test_fname datasets/youcook2/annotations/val-random_prompt.json \ 14 | --stream_end_score_sum_threshold ${thres_sum} --remove_assistant_turns true \ 15 | --output_fname ${output_dir}/eval/youcook2_val-thres_sum_${thres_sum}-rm_ass_turns-pred.json \ 16 | > ${output_dir}/eval/youcook2_val-thres_sum_${thres_sum}-rm_ass_turns-pred.log 2>&1 & 17 | wait 18 | 19 | # --stream_end_score_sum_threshold is the threshold of the sum of the informative score, 20 | # when the sum reaches this threshold, the assistant generates a response 21 | # --remove_assistant_turns is the rm. ass. turns trick that does not add previous assistant-generated responses to the conversation context.
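# A rough illustration of the sum-score mode (hypothetical numbers): with --stream_end_score_sum_threshold 2,
# the per-frame informative scores are accumulated as the video streams in, e.g. 0.3 + 0.6 + 0.7 + 0.5 = 2.1 > 2
# at the 4th sampled frame (~8s of video at 0.5 fps), so a caption is generated there and the running sum is reset to 0.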
22 | 23 | # -------------------- 24 | # evaluate the model generated results 25 | # -------------------- 26 | python -m test.evaluate --func dense_captioning \ 27 | --pred_file ${output_dir}/eval/youcook2_val-thres_sum_${thres_sum}-rm_ass_turns-pred.json \ 28 | --gold_file datasets/youcook2/annotations/val-random_prompt.json \ 29 | > ${output_dir}/eval/youcook2_val-thres_sum_${thres_sum}-rm_ass_turns-eval.log 2>&1 & 30 | -------------------------------------------------------------------------------- /scripts/train.sh: -------------------------------------------------------------------------------- 1 | output_dir=outputs/mmduet-train_output 2 | mkdir -p $output_dir 3 | 4 | torchrun --nproc_per_node 8 --master_port 29506 train.py --deepspeed configs/deepspeed/zero2.json \ 5 | --bf16 true --tf32 true \ 6 | --dataset_config configs/datasets/mmduetit.json \ 7 | --llm_pretrained lmms-lab/llava-onevision-qwen2-7b-ov \ 8 | --num_train_epochs 1 --per_device_train_batch_size 1 --per_device_eval_batch_size 1 \ 9 | --gradient_accumulation_steps 16 --gradient_checkpointing true \ 10 | --evaluation_strategy no --prediction_loss_only false \ 11 | --save_strategy steps --save_steps 500 --save_total_limit 5 \ 12 | --learning_rate 0.00002 --optim adamw_torch --lr_scheduler_type cosine --warmup_ratio 0.05 \ 13 | --dataloader_num_workers 4 \ 14 | --logging_steps 10 \ 15 | --report_to tensorboard \ 16 | --output_dir $output_dir \ 17 | > $output_dir/train.log 2>&1 & 18 | 19 | # check `configs/datasets/mmduetit.json` for datasets used. 20 | # If you want to use your own dataset to train MMDuet, write your own data config file like this file. 21 | -------------------------------------------------------------------------------- /test/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yellow-binary-tree/MMDuet/33be1387cd18643614d3c414c3c237bdd4ff59cc/test/__init__.py -------------------------------------------------------------------------------- /test/analyze_magqa_results.py: -------------------------------------------------------------------------------- 1 | 2 | # analyze shotstory livechat results evaluated by llama or gpt4o 3 | import json, argparse 4 | from tqdm import tqdm 5 | import numpy as np 6 | 7 | 8 | def text_score_to_int(text): 9 | if not isinstance(text, str): return text 10 | return int(text[0]) if text[0] in '12345' else 1 11 | 12 | 13 | if __name__ == '__main__': 14 | parser = argparse.ArgumentParser() 15 | parser.add_argument("--fname", type=str) 16 | parser.add_argument("--model_type", type=str, default='online') 17 | parser.add_argument("--num_examples", type=int, default=2000) 18 | parser.add_argument("--baseline_all_match", type=int, default=1, 19 | help="if set to 1, when a baseline model does not provide time in the pred answer, match this pred answer with all turns of gold answers. 
if set to 0, skip this example.") 20 | parser.add_argument("--pad_with_one", type=int, default=1, 21 | help="if the number of examples stated is less than --num_examples, add 1 as in-span score for remaining examples") 22 | args = parser.parse_args() 23 | print(args) 24 | 25 | num_turns_list_dedup, num_turns_list = list(), list() 26 | max_acc_score_list = list() 27 | nearest_acc_score_list = list() 28 | in_span_acc_score_list = list() 29 | 30 | for line in tqdm(open(args.fname).readlines()[:args.num_examples]): 31 | eval_example = json.loads(line) 32 | 33 | if not args.baseline_all_match: 34 | if eval_example['model_response_list'][0]['time'] == -1: continue 35 | 36 | # stat length 37 | sentences = [turn['content'] for turn in eval_example['model_response_list'] if turn['role'] == 'assistant'] 38 | num_turns_list.append(len(sentences)) 39 | num_turns_list_dedup.append(len(set(sentences))) 40 | 41 | # stat scores 42 | max_acc_score_list.append(np.mean([max([text_score_to_int(score)] for score in turn_scores) for turn_scores in eval_example['evaluator_output']])) 43 | 44 | example_acc_score_list = list() 45 | turn_time_list = [turn['time'] for turn in eval_example['model_response_list'] if turn['role'] == 'assistant'] 46 | for score_list, answer_time in zip(eval_example['evaluator_output'], eval_example['answer_time']): 47 | if args.baseline_all_match: 48 | answer_in_span_idx = [turn_idx for turn_idx, turn_time in enumerate(turn_time_list) if (answer_time[0] <= turn_time <= answer_time[1] or turn_time == -1)] 49 | else: 50 | answer_in_span_idx = [turn_idx for turn_idx, turn_time in enumerate(turn_time_list) if answer_time[0] <= turn_time <= answer_time[1]] 51 | if not answer_in_span_idx: 52 | example_acc_score_list.append(1) 53 | # pass 54 | else: 55 | example_acc_score_list.append(np.mean([text_score_to_int(score_list[idx]) for idx in answer_in_span_idx])) 56 | if not example_acc_score_list: 57 | example_acc_score_list.append(1) 58 | # pass 59 | else: 60 | in_span_acc_score_list.append(np.mean(example_acc_score_list)) 61 | 62 | if len(num_turns_list) < args.num_examples and args.pad_with_one: 63 | num_turns_list = num_turns_list + [0] * (args.num_examples - len(num_turns_list)) 64 | num_turns_list_dedup = num_turns_list_dedup + [0] * (args.num_examples - len(num_turns_list_dedup)) 65 | max_acc_score_list = max_acc_score_list + [1] * (args.num_examples - len(max_acc_score_list)) 66 | in_span_acc_score_list = in_span_acc_score_list + [1] * (args.num_examples - len(in_span_acc_score_list)) 67 | 68 | print(args.fname, len(num_turns_list)) 69 | # latex table format output: score & turns / turns(dedup.) 
& \\ 70 | print(round(np.mean(in_span_acc_score_list), 2), end=' & ') 71 | print(round(np.mean(num_turns_list), 2), end='/') 72 | print(round(np.mean(num_turns_list_dedup), 2), end=' & ') 73 | print('\\') 74 | -------------------------------------------------------------------------------- /test/datasets.py: -------------------------------------------------------------------------------- 1 | import os, json, math, random 2 | import torch 3 | import cv2 4 | import numpy as np 5 | from torch.utils.data import Dataset 6 | 7 | 8 | class FastAndAccurateStreamingVideoQADataset(Dataset): 9 | """ 10 | Dataset class for Fast and Accurate Streaming Video Question Answering Benchmarks 11 | """ 12 | def __init__(self, data_file, video_base_folder, start_idx=0, end_idx=None, 13 | output_fps=2, output_resolution=384, max_num_frames=100, time_instruction_format=None, 14 | system_prompt="A multimodal AI assistant is helping users with some activities." 15 | " Below is their conversation, interleaved with the list of video frames received by the assistant."): 16 | """ 17 | set output_fps = 'auto' to always load "max_num_fraems" frames from the video.a 18 | this is used when the lengths of videos vary significantly in the test set. 19 | """ 20 | self.data = json.load(open(data_file))[start_idx: end_idx] 21 | self.video_base_folder = video_base_folder 22 | self.output_fps = output_fps 23 | self.output_resolution = output_resolution 24 | self.max_num_frames = max_num_frames 25 | self.pad_color = (0, 0, 0) 26 | self.system_prompt = system_prompt 27 | self.time_instruction_format = time_instruction_format # provide frame time for traditional video llms 28 | print(f'loaded {len(self)} samples from {data_file}. Example data:') 29 | print(self[0]) 30 | print(self[random.randint(0, len(self)-1)]) 31 | 32 | def load_video(self, video_file): 33 | video_file = os.path.join(self.video_base_folder, video_file) 34 | cap = cv2.VideoCapture(video_file) 35 | # Get original video properties 36 | input_fps = cap.get(cv2.CAP_PROP_FPS) 37 | frame_count = cap.get(cv2.CAP_PROP_FRAME_COUNT) 38 | video_duration = frame_count / input_fps 39 | input_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)) 40 | input_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT)) 41 | output_width = output_height = self.output_resolution 42 | 43 | output_fps = self.output_fps if self.output_fps > 0 else self.max_num_frames / video_duration 44 | num_frames_total = math.ceil(video_duration * output_fps) 45 | frame_sec = [i / output_fps for i in range(num_frames_total)] 46 | frame_list, cur_time, frame_index = [], 0, 0 47 | while cap.isOpened(): 48 | ret, frame = cap.read() 49 | if not ret: 50 | break 51 | if frame_index < len(frame_sec) and cur_time >= frame_sec[frame_index]: 52 | if input_width > input_height: 53 | # Landscape video: scale width to the resolution, adjust height 54 | new_width = self.output_resolution 55 | new_height = int((input_height / input_width) * self.output_resolution) 56 | else: 57 | # Portrait video: scale height to the resolution, adjust width 58 | new_height = self.output_resolution 59 | new_width = int((input_width / input_height) * self.output_resolution) 60 | resized_frame = cv2.resize(frame, (new_width, new_height)) 61 | # pad the frame 62 | canvas = cv2.copyMakeBorder( 63 | resized_frame, 64 | top=(output_height - new_height) // 2, 65 | bottom=(output_height - new_height + 1) // 2, 66 | left=(output_width - new_width) // 2, 67 | right=(output_width - new_width + 1) // 2, 68 | borderType=cv2.BORDER_CONSTANT, 69 | 
value=self.pad_color 70 | ) 71 | frame_list.append(np.transpose(cv2.cvtColor(canvas, cv2.COLOR_BGR2RGB), (2, 0, 1))) 72 | frame_index += 1 73 | if len(frame_list) >= self.max_num_frames: 74 | break 75 | cur_time += 1 / input_fps 76 | cap.release() 77 | 78 | if self.time_instruction_format == 'timechat': 79 | frame_sec_str = ",".join([f"{i:.2f}s" for i in frame_sec]) 80 | time_instruciton = f"The video lasts for {video_duration:.2f} seconds, and {len(frame_list)} frames are uniformly sampled from it. These frames are located at {frame_sec_str}.Please answer the following questions related to this video." 81 | return torch.tensor(np.stack(frame_list)), output_fps, video_duration, time_instruciton 82 | elif self.time_instruction_format == 'vtimellm': 83 | time_instruciton = f"This is a video with {len(frame_list)} frames." 84 | return torch.tensor(np.stack(frame_list)), output_fps, video_duration, time_instruciton 85 | return torch.tensor(np.stack(frame_list)), output_fps, video_duration 86 | 87 | def __len__(self): 88 | return len(self.data) 89 | 90 | def __getitem__(self, idx): 91 | example = self.data[idx] 92 | try: 93 | conversation = example['conversation'] 94 | question_id = example['question_id'] 95 | if self.time_instruction_format is None: 96 | video_frames, output_fps, video_duration = self.load_video(example['video']) 97 | else: 98 | video_frames, output_fps, video_duration, time_instruction = self.load_video(example['video']) 99 | conversation[0]['content'] = time_instruction + '\n' + conversation[0]['content'] 100 | conversation.insert(0, {"role": "system", "content": self.system_prompt}) 101 | return question_id, video_frames, conversation, output_fps, video_duration 102 | except Exception as e: 103 | print(f"error loading {example['question_id']} due to exception {e}, this example will be skipped") 104 | return None, None, None, None, None 105 | 106 | 107 | class StreamingVideoQADatasetWithGenTime(FastAndAccurateStreamingVideoQADataset): 108 | def __getitem__(self, idx): 109 | example = self.data[idx] 110 | try: 111 | conversation = example['conversation'] 112 | question_id = example['question_id'] 113 | video_frames, output_fps, video_duration = self.load_video(example['video']) 114 | conversation.insert(0, {"role": "system", "content": self.system_prompt}) 115 | gen_time_list = [i['time'][1] for i in example['answer']] 116 | return question_id, video_frames, conversation, output_fps, video_duration, gen_time_list 117 | except Exception as e: 118 | print(f"error loading {example['question_id']} due to exception {e}, this example will be skipped") 119 | return None, None, None, None, None 120 | -------------------------------------------------------------------------------- /test/dvc/metrics/README.md: -------------------------------------------------------------------------------- 1 | # captioning-metrics 2 | 3 | This is a fork from https://github.com/salaniz/pycocoevalcap, with several functionalities that make it easier to run for dense video captioning, e.g. not closing the METEOR jar at every call but only once per evaluation. 4 | 5 | To use it, you may download https://github.com/salaniz/pycocoevalcap/tree/master/meteor/data and put in under the data/ folder. 
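For orientation, the metric classes in this folder follow the pycocoevalcap-style interface: both arguments of `compute_score` are dicts keyed by segment id, the references are a list of strings, and each hypothesis list must contain exactly one string. A sketch with made-up ids and captions:

```python
# Input convention for the metric classes in this folder (ids and captions are invented).
gts = {"vid1_seg0": ["add the onions to the pan", "put the onions into the pan"]}  # >= 1 reference
res = {"vid1_seg0": ["add onions to the pan"]}                                     # exactly 1 hypothesis
# corpus_score, per_segment_scores = Cider(n=4, sigma=6.0).compute_score(gts, res)
```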
6 | -------------------------------------------------------------------------------- /test/dvc/metrics/cider.py: -------------------------------------------------------------------------------- 1 | """Computes the CIDEr (Consensus-Based Image Description Evaluation) Metric.""" 2 | 3 | # Filename: cider.py 4 | # 5 | # Description: Describes the class to compute the CIDEr 6 | # (Consensus-Based Image Description Evaluation) Metric 7 | # by Vedantam, Zitnick, and Parikh (http://arxiv.org/abs/1411.5726) 8 | # 9 | # Creation Date: Sun Feb 8 14:16:54 2015 10 | # 11 | # Authors: Ramakrishna Vedantam 12 | # and Tsung-Yi Lin 13 | 14 | from .cider_scorer import CiderScorer 15 | 16 | 17 | class Cider: 18 | """Main Class to compute the CIDEr metric.""" 19 | 20 | def __init__(self, n=4, sigma=6.0): 21 | # set cider to sum over 1 to 4-grams 22 | self._n = n 23 | # set the standard deviation parameter for gaussian penalty 24 | self._sigma = sigma 25 | 26 | def compute_score(self, gts, res): 27 | """Main function to compute CIDEr score. 28 | 29 | Args: 30 | gts: dictionary with key and value 32 | res: dictionary with key and value 33 | 34 | Returns: 35 | Computed CIDEr float score for the corpus. 36 | """ 37 | 38 | assert sorted(gts.keys()) == sorted(res.keys()) 39 | imgids = list(gts.keys()) 40 | 41 | cider_scorer = CiderScorer(n=self._n, sigma=self._sigma) 42 | 43 | # Sort the IDs to be able to have control over the order 44 | # of the individual scores. 45 | for iid in sorted(imgids): 46 | hypo = res[iid] 47 | ref = gts[iid] 48 | 49 | # Sanity check. 50 | assert isinstance(hypo, list) 51 | assert len(hypo) == 1 52 | assert isinstance(ref, list) 53 | assert ref 54 | 55 | cider_scorer += (hypo[0], ref) 56 | 57 | (score, scores) = cider_scorer.compute_score() 58 | 59 | return score, scores 60 | 61 | def method(self): 62 | return "CIDEr" 63 | -------------------------------------------------------------------------------- /test/dvc/metrics/cider_scorer.py: -------------------------------------------------------------------------------- 1 | """Computes the CIDEr (Consensus-Based Image Description Evaluation) Metric.""" 2 | 3 | # Tsung-Yi Lin 4 | # Ramakrishna Vedantam 5 | 6 | from __future__ import absolute_import 7 | from __future__ import division 8 | from __future__ import print_function 9 | 10 | import collections 11 | import copy 12 | import math 13 | 14 | import numpy as np 15 | import six 16 | from six.moves import range 17 | from six.moves import zip 18 | 19 | 20 | def precook(s, n=4): 21 | """Takes a string as input. 22 | 23 | And returns an object that can be given to either cook_refs or cook_test. 24 | This is optional: cook_refs and cook_test can take string arguments as well. 25 | 26 | Args: 27 | s: string : sentence to be converted into ngrams. 28 | n: int : number of ngrams for which representation is calculated. 29 | 30 | Returns: 31 | Term frequency vector for occuring ngrams. 32 | """ 33 | words = s.split() 34 | counts = collections.defaultdict(int) 35 | for k in range(1, n + 1): 36 | for i in range(len(words) - k + 1): 37 | ngram = tuple(words[i:i + k]) 38 | counts[ngram] += 1 39 | return counts 40 | 41 | 42 | def cook_refs(refs, n=4): ## lhuang: oracle will call with "average" 43 | """Takes a list of reference sentences for a single segment. 44 | 45 | And returns an object that encapsulates everything that BLEU 46 | needs to know about them. 47 | 48 | Args: 49 | refs: list of string : reference sentences for some image. 
50 | n: int : number of ngrams for which (ngram) representation is calculated. 51 | 52 | Returns: 53 | result (list of dict). 54 | """ 55 | return [precook(ref, n) for ref in refs] 56 | 57 | 58 | def cook_test(test, n=4): 59 | """Takes a test sentence. 60 | 61 | And returns an object that encapsulates everything 62 | that BLEU needs to know about it. 63 | 64 | Args: 65 | test: list of string : hypothesis sentence for some image. 66 | n: int : number of ngrams for which (ngram) representation is calculated. 67 | 68 | Returns: 69 | result (dict). 70 | """ 71 | return precook(test, n) 72 | 73 | 74 | class CiderScorer(object): 75 | """CIDEr scorer.""" 76 | 77 | def copy(self): 78 | """Copy the refs.""" 79 | new = CiderScorer(n=self.n) 80 | new.ctest = copy.copy(self.ctest) 81 | new.crefs = copy.copy(self.crefs) 82 | return new 83 | 84 | def __init__(self, test=None, refs=None, n=4, sigma=6.0): 85 | """Singular instance.""" 86 | self.n = n 87 | self.sigma = sigma 88 | self.crefs = [] 89 | self.ctest = [] 90 | self.document_frequency = collections.defaultdict(float) 91 | self.cook_append(test, refs) 92 | self.ref_len = None 93 | 94 | def cook_append(self, test, refs): 95 | """called by constructor and __iadd__ to avoid creating new instances.""" 96 | 97 | if refs is not None: 98 | self.crefs.append(cook_refs(refs)) 99 | if test is not None: 100 | self.ctest.append(cook_test(test)) ## N.B.: -1 101 | else: 102 | self.ctest.append(None) # lens of crefs and ctest have to match 103 | 104 | def size(self): 105 | assert len(self.crefs) == len( 106 | self.ctest), "refs/test mismatch! %d<>%d" % (len( 107 | self.crefs), len(self.ctest)) 108 | return len(self.crefs) 109 | 110 | def __iadd__(self, other): 111 | """add an instance (e.g., from another sentence).""" 112 | 113 | if isinstance(other, tuple): 114 | ## avoid creating new CiderScorer instances 115 | self.cook_append(other[0], other[1]) 116 | else: 117 | self.ctest.extend(other.ctest) 118 | self.crefs.extend(other.crefs) 119 | 120 | return self 121 | 122 | def compute_doc_freq(self): 123 | """Compute term frequency for reference data. 124 | 125 | This will be used to compute idf (inverse document frequency later). 126 | The term frequency is stored in the object. 127 | 128 | Returns: 129 | None 130 | """ 131 | for refs in self.crefs: 132 | # refs, k ref captions of one image 133 | for ngram in set( 134 | [ngram for ref in refs for (ngram, count) in six.iteritems(ref)]): # pylint: disable=g-complex-comprehension 135 | self.document_frequency[ngram] += 1 136 | # maxcounts[ngram] = max(maxcounts.get(ngram,0), count) 137 | 138 | def compute_cider(self, corpus_size=None): 139 | """Function computing CIDEr.""" 140 | 141 | def counts2vec(cnts): 142 | """Function maps counts of ngram to vector of tfidf weights. 143 | 144 | The function returns vec, an array of dictionary that store mapping 145 | of n-gram and tf-idf weights. 146 | The n-th entry of array denotes length of n-grams. 147 | Args: 148 | cnts: 149 | Returns: 150 | vec (array of dict), norm (array of float), length (int). 
151 | """ 152 | vec = [collections.defaultdict(float) for _ in range(self.n)] 153 | length = 0 154 | norm = [0.0 for _ in range(self.n)] 155 | for (ngram, term_freq) in six.iteritems(cnts): 156 | # give word count 1 if it doesn't appear in reference corpus 157 | df = np.log(max(1.0, self.document_frequency[ngram])) 158 | # ngram index 159 | n = len(ngram) - 1 160 | # tf (term_freq) * idf (precomputed idf) for n-grams 161 | vec[n][ngram] = float(term_freq) * (self.ref_len - df) 162 | # compute norm for the vector. 163 | # the norm will be used for computing similarity. 164 | norm[n] += pow(vec[n][ngram], 2) 165 | 166 | if n == 1: 167 | length += term_freq 168 | norm = [np.sqrt(n) for n in norm] 169 | return vec, norm, length 170 | 171 | def sim(vec_hyp, vec_ref, norm_hyp, norm_ref, length_hyp, length_ref): 172 | """Compute the cosine similarity of two vectors. 173 | 174 | Args: 175 | vec_hyp: array of dictionary for vector corresponding to hypothesis. 176 | vec_ref: array of dictionary for vector corresponding to reference. 177 | norm_hyp: array of float for vector corresponding to hypothesis. 178 | norm_ref: array of float for vector corresponding to reference. 179 | length_hyp: int containing length of hypothesis. 180 | length_ref: int containing length of reference. 181 | 182 | Returns: 183 | Array of score for each n-grams cosine similarity. 184 | """ 185 | delta = float(length_hyp - length_ref) 186 | # measure consine similarity 187 | val = np.array([0.0 for _ in range(self.n)]) 188 | for n in range(self.n): 189 | # ngram 190 | for (ngram, _) in six.iteritems(vec_hyp[n]): 191 | # vrama91 : added clipping 192 | val[n] += min(vec_hyp[n][ngram], 193 | vec_ref[n][ngram]) * vec_ref[n][ngram] 194 | 195 | if (norm_hyp[n] != 0) and (norm_ref[n] != 0): 196 | val[n] /= (norm_hyp[n] * norm_ref[n]) 197 | 198 | assert not math.isnan(val[n]) 199 | # vrama91: added a length based gaussian penalty 200 | val[n] *= np.e**(-(delta**2) / (2 * self.sigma**2)) 201 | return val 202 | 203 | corpus_size = float(corpus_size or len(self.crefs)) 204 | self.ref_len = np.log(corpus_size) 205 | 206 | scores = [] 207 | for test, refs in zip(self.ctest, self.crefs): 208 | # compute vector for test captions 209 | vec, norm, length = counts2vec(test) 210 | # compute vector for ref captions 211 | score = np.array([0.0 for _ in range(self.n)]) 212 | for ref in refs: 213 | vec_ref, norm_ref, length_ref = counts2vec(ref) 214 | score += sim(vec, vec_ref, norm, norm_ref, length, length_ref) 215 | # change by vrama91 - mean of ngram scores, instead of sum 216 | score_avg = np.mean(score) 217 | # divide by number of references 218 | score_avg /= len(refs) 219 | # multiply score by 10 220 | score_avg *= 10.0 221 | # append score of an image to the score list 222 | scores.append(score_avg) 223 | return scores 224 | 225 | def compute_score(self): 226 | # compute idf 227 | self.compute_doc_freq() 228 | # assert to check document frequency 229 | assert (len(self.ctest) >= max(self.document_frequency.values())) 230 | # compute cider score 231 | score = self.compute_cider() 232 | # debug 233 | # print score 234 | return np.mean(np.array(score)), np.array(score) 235 | -------------------------------------------------------------------------------- /test/dvc/metrics/meteor-1.5.jar: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yellow-binary-tree/MMDuet/33be1387cd18643614d3c414c3c237bdd4ff59cc/test/dvc/metrics/meteor-1.5.jar 
-------------------------------------------------------------------------------- /test/dvc/metrics/meteor.py: -------------------------------------------------------------------------------- 1 | """Python wrapper for METEOR implementation.""" 2 | 3 | from __future__ import absolute_import 4 | from __future__ import division 5 | from __future__ import print_function 6 | 7 | import os 8 | import subprocess 9 | import threading 10 | 11 | import numpy as np 12 | import six 13 | 14 | 15 | class Meteor(object): 16 | """Meteor scorer.""" 17 | 18 | def __init__(self, 19 | meteor_jar_path=None, 20 | java_jre_path=None, 21 | jdk_java_options=None): 22 | if java_jre_path: 23 | self.java_bin = java_jre_path 24 | elif 'JRE_BIN_JAVA' in os.environ: 25 | self.java_bin = os.environ['JRE_BIN_JAVA'] 26 | else: 27 | self.java_bin = 'java' 28 | 29 | if meteor_jar_path: 30 | meteor_jar = meteor_jar_path 31 | else: 32 | meteor_jar = os.path.join( 33 | './metrics', 'meteor-1.5.jar' 34 | ) 35 | 36 | assert os.path.exists(meteor_jar), meteor_jar 37 | 38 | jdk_java_options = jdk_java_options or ['-Xmx2G'] 39 | meteor_cmd = [ 40 | self.java_bin, '-jar', '-Xmx2G', meteor_jar, '-', '-', '-stdio', 41 | '-l', 'en', '-norm' 42 | ] 43 | 44 | self.meteor_p = subprocess.Popen( 45 | meteor_cmd, stdin=subprocess.PIPE, stdout=subprocess.PIPE) 46 | self.lock = threading.Lock() 47 | 48 | def compute_score(self, gts, res): 49 | """Compute METEOR scores.""" 50 | with self.lock: 51 | assert sorted(gts.keys()) == sorted(res.keys()) 52 | img_ids = sorted(gts.keys()) 53 | scores = [] 54 | 55 | eval_line = 'EVAL ||| ' 56 | stats = self._stat(img_ids, res, gts) 57 | eval_line += ' ||| '.join(stats) 58 | self.meteor_p.stdin.write(six.ensure_binary(eval_line + '\n')) 59 | self.meteor_p.stdin.flush() 60 | scores = [float(six.ensure_str(self.meteor_p.stdout.readline())) 61 | for _ in img_ids] 62 | # get the aggregated value 63 | score = self.meteor_p.stdout.readline() 64 | # do not close the file inside this function to keep it open for full eval 65 | return float(score), np.asarray(scores) 66 | 67 | def method(self): 68 | return 'METEOR' 69 | 70 | def _stat(self, img_ids, hypothesis_str, reference_list): # pylint: disable=missing-function-docstring 71 | # SCORE ||| reference 1 words ||| reference n words ||| hypothesis words 72 | stat_lines = [] 73 | for i in img_ids: 74 | assert len(hypothesis_str[i]) == 1 75 | hypo = hypothesis_str[i][0].replace('|||', '').replace(' ', ' ') 76 | score_line = ' ||| '.join(('SCORE', ' ||| '.join(reference_list[i]), 77 | hypo)) 78 | 79 | self.meteor_p.stdin.write(six.ensure_binary(score_line + '\n')) 80 | self.meteor_p.stdin.flush() 81 | stat_lines.append(six.ensure_str(self.meteor_p.stdout.readline()).strip()) 82 | return stat_lines 83 | -------------------------------------------------------------------------------- /test/dvc/metrics/ptbtokenizer.py: -------------------------------------------------------------------------------- 1 | """PTBTokenizer.""" 2 | 3 | # File Name : ptbtokenizer.py 4 | # 5 | # Description : Do the PTB Tokenization and remove punctuations. 
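# Input/output note: tokenize() below expects a dict mapping an id to a list of
# {"caption": str} entries, e.g. {"vid1": [{"caption": "A man cuts a tomato."}]},
# and returns a dict mapping the same ids to lists of lowercased, punctuation-stripped
# caption strings.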
6 | # 7 | # Creation Date : 29-12-2014 8 | # Last Modified : Thu Mar 19 09:53:35 2015 9 | # Authors : Hao Fang and Tsung-Yi Lin 10 | 11 | import os 12 | import subprocess 13 | import tempfile 14 | 15 | # pylint: disable=g-inconsistent-quotes 16 | 17 | # punctuations to be removed from the sentences 18 | PUNCTUATIONS = ["''", "'", "``", "`", "-LRB-", "-RRB-", "-LCB-", "-RCB-", 19 | ".", "?", "!", ",", ":", "-", "--", "...", ";"] 20 | 21 | 22 | class PTBTokenizer: 23 | """Python wrapper of Stanford PTBTokenizer.""" 24 | 25 | def __init__(self, 26 | ptbtokenizer_jar_path=None, 27 | java_jre_path=None): 28 | if java_jre_path: 29 | self.java_bin = java_jre_path 30 | elif 'JRE_BIN_JAVA' in os.environ: 31 | self.java_bin = os.environ['JRE_BIN_JAVA'] 32 | else: 33 | self.java_bin = 'java' 34 | 35 | if ptbtokenizer_jar_path: 36 | self.ptbtokenizer_jar = ptbtokenizer_jar_path 37 | else: 38 | self.ptbtokenizer_jar = os.path.join( 39 | "./metrics", 40 | "stanford-corenlp-3.4.1.jar", 41 | ) 42 | 43 | assert os.path.exists(self.ptbtokenizer_jar), self.ptbtokenizer_jar 44 | 45 | def tokenize(self, captions_for_image): 46 | """Tokenization.""" 47 | 48 | cmd = [self.java_bin, '-cp', self.ptbtokenizer_jar, 49 | 'edu.stanford.nlp.process.PTBTokenizer', 50 | '-preserveLines', '-lowerCase'] 51 | 52 | # ====================================================== 53 | # prepare data for PTB Tokenizer 54 | # ====================================================== 55 | final_tokenized_captions_for_image = {} 56 | image_id = [k for k, v in captions_for_image.items() for _ in range(len(v))] # pylint: disable=g-complex-comprehension 57 | sentences = "\n".join( 58 | [ # pylint: disable=g-complex-comprehension 59 | c["caption"].replace("\n", " ") 60 | for k, v in captions_for_image.items() 61 | for c in v 62 | ] 63 | ) 64 | 65 | # ====================================================== 66 | # save sentences to temporary file 67 | # ====================================================== 68 | fd, tmpfname = tempfile.mkstemp() 69 | with os.fdopen(fd, 'w') as f: 70 | f.write(sentences) 71 | 72 | # ====================================================== 73 | # tokenize sentence 74 | # ====================================================== 75 | cmd.append(tmpfname) 76 | p_tokenizer = subprocess.Popen(cmd, stdout=subprocess.PIPE) 77 | token_lines = p_tokenizer.communicate(input=sentences.rstrip().encode())[0] 78 | token_lines = token_lines.decode() 79 | lines = token_lines.split('\n') 80 | # remove temp file 81 | os.remove(tmpfname) 82 | 83 | # ====================================================== 84 | # create dictionary for tokenized captions 85 | # ====================================================== 86 | for k, line in zip(image_id, lines): 87 | if k not in final_tokenized_captions_for_image: 88 | final_tokenized_captions_for_image[k] = [] 89 | tokenized_caption = ' '.join([w for w in line.rstrip().split(' ') 90 | if w not in PUNCTUATIONS]) 91 | final_tokenized_captions_for_image[k].append(tokenized_caption) 92 | 93 | return final_tokenized_captions_for_image 94 | -------------------------------------------------------------------------------- /test/dvc/metrics/stanford-corenlp-3.4.1.jar: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yellow-binary-tree/MMDuet/33be1387cd18643614d3c414c3c237bdd4ff59cc/test/dvc/metrics/stanford-corenlp-3.4.1.jar -------------------------------------------------------------------------------- /test/inference.py: 
-------------------------------------------------------------------------------- 1 | import collections, math, json, copy 2 | from dataclasses import asdict 3 | from tqdm import tqdm 4 | import numpy as np 5 | import torch 6 | import transformers 7 | from torchvision.io import read_video 8 | from peft import PeftModel 9 | logger = transformers.logging.get_logger('inference') 10 | 11 | from llava.mm_utils import tokenizer_image_token 12 | from llava.model.builder import load_pretrained_model 13 | from llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN 14 | from llava.conversation import conv_templates 15 | 16 | from models import build_model_and_tokenizer, fast_greedy_generate, parse_args 17 | from .datasets import FastAndAccurateStreamingVideoQADataset 18 | 19 | 20 | class LiveInferForBenchmark: 21 | def __init__(self, args) -> None: 22 | assert not (args.bf16 and args.fp16), "only one of --bf16 true and --fp16 true can be set" 23 | self.torch_dtype = torch.bfloat16 if args.bf16 else torch.float16 if args.fp16 else torch.float32 24 | self.model, self.tokenizer = build_model_and_tokenizer(is_training=False, set_vision_inside=True, torch_dtype=self.torch_dtype, **asdict(args)) 25 | self.model.eval() 26 | if 'llava' in args.llm_pretrained: 27 | self.image_processor = self.model.get_vision_tower().image_processor 28 | else: 29 | self.image_processor = None 30 | # self.model.to('cuda') 31 | 32 | # visual 33 | self.hidden_size = self.model.config.hidden_size 34 | if args.frame_fps > 0: 35 | self.set_fps(args.frame_fps) 36 | self.frame_resolution = self.model.config.frame_resolution 37 | self.frame_num_tokens = self.model.config.frame_num_tokens 38 | self.frame_v_placeholder = self.model.config.v_placeholder * self.frame_num_tokens 39 | 40 | # generation 41 | self.system_prompt = args.system_prompt 42 | self.inplace_output_ids = torch.zeros(1, 200, device='cuda', dtype=torch.long) 43 | self.stream_end_prob_threshold = args.stream_end_prob_threshold 44 | self.response_min_interval_frames = args.response_min_interval_frames 45 | self.threshold_z = args.threshold_z 46 | self.first_n_frames_no_generate = args.first_n_frames_no_generate 47 | self.running_list_length = args.running_list_length 48 | self.stream_end_score_sum_threshold = args.stream_end_score_sum_threshold 49 | self.score_heads = args.score_heads.split(',') 50 | self.consecutive_n_frames_threshold = args.consecutive_n_frames_threshold 51 | print(f'score heads: {self.score_heads}') 52 | 53 | if int(self.threshold_z is not None) + int(self.stream_end_prob_threshold is not None) + int(self.stream_end_score_sum_threshold is not None) != 1: 54 | raise ValueError(f'only one of --stream_end_prob_threshold, --threshold_z and --stream_end_score_sum_threshold can be set. 
However, they are: {self.stream_end_prob_threshold}, {self.threshold_z}, {self.stream_end_score_sum_threshold}') 55 | if self.threshold_z is not None and self.first_n_frames_no_generate is None: 56 | raise ValueError('--first_n_frames_no_generate must be set when --threshold_z is set') 57 | 58 | self.remove_assistant_turns = args.remove_assistant_turns 59 | 60 | self.eos_token_id = self.model.config.eos_token_id 61 | self._start_ids = self.tokenizer.apply_chat_template([{'role': 'system', 'content': self.system_prompt}], return_tensors='pt').to('cuda') 62 | self._added_stream_prompt_ids = self.tokenizer.apply_chat_template([{}], add_stream_prompt=True, return_tensors='pt').to('cuda') 63 | self._added_stream_generation_ids = self.tokenizer.apply_chat_template([{}], add_stream_generation_prompt=True, return_tensors='pt').to('cuda') 64 | self.repetition_penalty = args.repetition_penalty 65 | 66 | self.reset() 67 | 68 | def set_fps(self, fps=None, frame_interval=None): 69 | assert fps is not None or frame_interval is not None 70 | assert not (fps is not None and frame_interval is not None) 71 | if fps is not None: 72 | self.frame_fps = fps 73 | self.frame_interval = 1 / self.frame_fps 74 | else: 75 | self.frame_interval = frame_interval 76 | self.frame_fps = 1 / self.frame_interval 77 | 78 | # DEPRECATED 79 | def _call_for_response(self, video_time, query): 80 | if query is not None: 81 | # encode the user query 82 | self.last_ids = self.tokenizcer.apply_chat_template([{'role': 'user', 'content': query}], add_stream_query_prompt=self.last_role == 'stream', add_stream_prompt=True, return_tensors='pt').to('cuda') 83 | inputs_embeds = self.model.get_input_embeddings()(self.last_ids) 84 | outputs = self.model(inputs_embeds=inputs_embeds, past_key_values=self.past_key_values, use_cache=True, return_dict=True) 85 | self.past_key_values = outputs.past_key_values 86 | self.last_ids = outputs.logits[:, -1:].argmax(dim=-1) 87 | self.last_role = 'user' 88 | return query, None 89 | 90 | self.last_ids = self._added_stream_generation_ids 91 | inputs_embeds = self.model.get_input_embeddings()(self.last_ids) 92 | output_ids, past_key_values, self.generated_token_ids = fast_greedy_generate( 93 | model=self.model, inputs_embeds=inputs_embeds, past_key_values=self.past_key_values, eos_token_id=self.eos_token_id, inplace_output_ids=self.inplace_output_ids, 94 | repetition_penalty=self.repetition_penalty, generated_token_ids=self.generated_token_ids 95 | ) 96 | 97 | if not self.remove_assistant_turns: 98 | self.past_key_values = past_key_values 99 | self.last_ids = output_ids[:, -1:] 100 | else: 101 | self.last_ids = torch.tensor([[]], device='cuda', dtype=torch.long) 102 | 103 | response = self.tokenizer.decode(output_ids[0], skip_special_tokens=True, clean_up_tokenization_spaces=True) 104 | self.num_frames_no_reply = 0 105 | self.last_role = 'assistant' 106 | return query, response 107 | 108 | # DEPRECATED 109 | def _call_for_streaming(self): 110 | while self.frame_embeds_queue: 111 | # 1. 
if query is before next frame, encode the user query first 112 | if self.query_queue and self.frame_embeds_queue[0][0] > self.query_queue[0][0]: 113 | video_time, query = self.query_queue.popleft() 114 | return video_time, query 115 | video_time, frame_embeds = self.frame_embeds_queue.popleft() 116 | if not self.past_key_values: 117 | self.last_ids = self._start_ids 118 | elif self.last_role == 'assistant' and not self.remove_assistant_turns: 119 | self.last_ids = torch.cat([self.last_ids, self._added_stream_prompt_ids], dim=1) 120 | else: # last_role is stream, now we just input another frame 121 | self.last_ids = torch.tensor([[]], device='cuda', dtype=torch.long) 122 | inputs_embeds = torch.cat([ 123 | self.model.get_input_embeddings()(self.last_ids).view(1, -1, self.hidden_size), 124 | frame_embeds.view(1, -1, self.hidden_size).to(self.last_ids.device), 125 | ], dim=1) 126 | outputs = self.model(inputs_embeds=inputs_embeds, use_cache=True, past_key_values=self.past_key_values, return_dict=True) 127 | self.past_key_values = outputs.past_key_values 128 | 129 | self.frame_idx += 1 130 | self.num_frames_no_reply += 1 131 | 132 | informative_score = outputs.informative_logits[0,-1].softmax(dim=-1) 133 | relevance_score = outputs.relevance_logits[0,-1].softmax(dim=-1) 134 | 135 | # 3. if no user input, check what informative_heads returns 136 | video_heads_data = {'video_time': video_time, 'informative_score': informative_score.tolist(), 'relevance_score': relevance_score.tolist()} 137 | self.debug_data_list.append(video_heads_data) 138 | 139 | if self.stream_end_score_sum_threshold is not None: 140 | stream_end_prob_threshold = None 141 | elif self.stream_end_prob_threshold is not None: 142 | stream_end_prob_threshold = self.stream_end_prob_threshold 143 | else: 144 | if len(self.stream_end_prob_list) < self.first_n_frames_no_generate: 145 | # set the threshold to a very large number, to ensure the informative_head can not produce a score larger than this 146 | stream_end_prob_threshold = 1 147 | else: 148 | # set the threshold as mean + z * std 149 | stream_end_prob_threshold = np.mean(self.stream_end_prob_list) + self.threshold_z * np.std(self.stream_end_prob_list) 150 | 151 | stream_end_score = 0. 
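# Combine the probability heads named in --score_heads (informative and/or relevance)
# into a single per-frame score that drives the response trigger checked below.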
152 | for score_name in ['informative_score', 'relevance_score']: 153 | if score_name in self.score_heads: 154 | stream_end_score += video_heads_data[score_name][1] 155 | 156 | self.stream_end_prob_list.append(stream_end_score) 157 | self.stream_end_score_sum += stream_end_score 158 | 159 | if isinstance(self.running_list_length, int) and self.running_list_length > 0: 160 | self.stream_end_prob_list = self.stream_end_prob_list[-self.running_list_length:] 161 | self.last_role = 'stream' 162 | if stream_end_prob_threshold is not None and stream_end_score > stream_end_prob_threshold: 163 | return video_time, None 164 | if stream_end_prob_threshold is None and self.stream_end_score_sum > self.stream_end_score_sum_threshold: 165 | self.stream_end_score_sum = 0 166 | return video_time, None 167 | return None, None 168 | 169 | def reset(self, ): 170 | self.query_queue = collections.deque() 171 | self.frame_embeds_queue = collections.deque() 172 | self.video_time = 0 173 | self.frame_idx = 0 174 | self.last_role = 'system' 175 | self.video_tensor = None 176 | self.last_ids = torch.tensor([[]], device='cuda', dtype=torch.long) 177 | self.past_key_values = None 178 | self.debug_data_list = list() 179 | self.generated_token_ids = list() 180 | self.num_frames_no_reply = 0 181 | self.stream_end_prob_list = list() 182 | self.stream_end_score_sum = 0 183 | self.consecutive_n_frames = 0 184 | 185 | @torch.no_grad() 186 | def load_video(self, video_path): 187 | self.video_tensor = read_video(video_path, pts_unit='sec', output_format='TCHW')[0] 188 | if self.image_processor is not None: 189 | self.video_tensor = self.image_processor.preprocess(self.video_tensor, return_tensors='pt')['pixel_values'].to('cuda').to(self.torch_dtype) 190 | else: 191 | self.video_tensor = self.video_tensor.to('cuda').to(self.torch_dtype) 192 | self.num_video_frames = self.video_tensor.size(0) 193 | self.video_duration = self.video_tensor.size(0) / self.frame_fps 194 | logger.warning(f'{video_path} -> {self.video_tensor.shape}, {self.frame_fps} FPS') 195 | 196 | def input_video_stream(self, video_frames): 197 | """ 198 | input all video to video_frames at a time 199 | video_frames: input to visual encoder, after preprocessor 200 | """ 201 | torch.cuda.empty_cache() # prevent oov on small gpus 202 | if self.image_processor is not None: 203 | video_frames = self.image_processor.preprocess(video_frames, return_tensors='pt')['pixel_values'].to('cuda').to(self.torch_dtype) 204 | else: 205 | video_frames = video_frames.to('cuda').to(self.torch_dtype) 206 | 207 | # encode the video frames in batches to prevent oov 208 | batch_size = 32 209 | for batch_i in range(0, math.ceil(len(video_frames) / batch_size)): 210 | video_frames_batch = video_frames[batch_i*batch_size: batch_i*batch_size+batch_size] 211 | frame_embeds = self.model.visual_embed(video_frames_batch).split(self.frame_num_tokens) 212 | self.frame_embeds_queue.extend([((r + batch_i * batch_size) / self.frame_fps, f.to('cpu')) for r, f in enumerate(frame_embeds)]) 213 | del frame_embeds 214 | torch.cuda.empty_cache() # prevent oov on small gpus? 
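# Typical driver sequence for this class (mirrors the __main__ loop at the bottom of this file):
#   infer.reset(); infer.set_fps(fps=fps)
#   infer.input_video_stream(video_frames)    # pre-encode every frame into frame_embeds_queue
#   infer.input_query_stream(conversation)    # queue (time, question) pairs for the user turns
#   model_response_list = infer.inference()   # replay the stream and collect timed responses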
215 | 216 | def input_query_stream(self, conversation): 217 | for turn in conversation: 218 | if turn['role'] == 'user': 219 | self.query_queue.append((turn['time'], turn['content'])) 220 | 221 | def _encode_frame(self): 222 | """ 223 | returns: informative_score, relevance_score 224 | """ 225 | if not self.frame_embeds_queue: 226 | return None, None 227 | 228 | video_time, frame_embeds = self.frame_embeds_queue.popleft() 229 | if not self.past_key_values: 230 | self.last_ids = self._start_ids 231 | elif self.last_role == 'assistant' and not self.remove_assistant_turns: 232 | self.last_ids = torch.cat([self.last_ids, self._added_stream_prompt_ids], dim=1) 233 | else: # last_role is stream, now we just input another frame 234 | self.last_ids = torch.tensor([[]], device='cuda', dtype=torch.long) 235 | inputs_embeds = torch.cat([ 236 | self.model.get_input_embeddings()(self.last_ids).view(1, -1, self.hidden_size), 237 | frame_embeds.view(1, -1, self.hidden_size).to(self.last_ids.device), 238 | ], dim=1) 239 | outputs = self.model(inputs_embeds=inputs_embeds, use_cache=True, past_key_values=self.past_key_values, return_dict=True) 240 | self.past_key_values = outputs.past_key_values 241 | self.frame_idx += 1 242 | self.num_frames_no_reply += 1 243 | informative_score = outputs.informative_logits[0,-1].softmax(dim=-1)[1].item() 244 | relevance_score = outputs.relevance_logits[0,-1].softmax(dim=-1)[1].item() 245 | self.last_role = 'stream' 246 | return {"informative_score": informative_score, "relevance_score": relevance_score} 247 | 248 | def _encode_query(self): 249 | query_time, query = self.query_queue.popleft() 250 | self.last_ids = self.tokenizer.apply_chat_template([{'role': 'user', 'content': query}], add_stream_query_prompt=self.last_role == 'stream', add_stream_prompt=True, return_tensors='pt').to('cuda') 251 | inputs_embeds = self.model.get_input_embeddings()(self.last_ids) 252 | outputs = self.model(inputs_embeds=inputs_embeds, past_key_values=self.past_key_values, use_cache=True, return_dict=True) 253 | self.past_key_values = outputs.past_key_values 254 | self.last_ids = outputs.logits[:, -1:].argmax(dim=-1) 255 | self.last_role = 'user' 256 | 257 | def _generate_response(self): 258 | self.last_ids = self._added_stream_generation_ids 259 | inputs_embeds = self.model.get_input_embeddings()(self.last_ids) 260 | output_ids, past_key_values, self.generated_token_ids = fast_greedy_generate( 261 | model=self.model, inputs_embeds=inputs_embeds, past_key_values=self.past_key_values, eos_token_id=self.eos_token_id, inplace_output_ids=self.inplace_output_ids, 262 | repetition_penalty=self.repetition_penalty, generated_token_ids=self.generated_token_ids 263 | ) 264 | 265 | if not self.remove_assistant_turns: 266 | self.past_key_values = past_key_values 267 | self.last_ids = output_ids[:, -1:] 268 | else: 269 | self.last_ids = torch.tensor([[]], device='cuda', dtype=torch.long) 270 | 271 | response = self.tokenizer.decode(output_ids[0], skip_special_tokens=True, clean_up_tokenization_spaces=True) 272 | self.num_frames_no_reply = 0 273 | self.last_role = 'assistant' 274 | return response 275 | 276 | @torch.no_grad() 277 | def inference(self): 278 | model_response_list = [{'time': q[0], 'content': q[1], 'role': 'user'} for q in self.query_queue] 279 | while self.frame_embeds_queue: 280 | # 1. check if a user query is at current time 281 | if self.query_queue and self.video_time >= self.query_queue[0][0]: 282 | self._encode_query() 283 | 284 | # 2. 
input a frame, and update the scores list 285 | video_scores = self._encode_frame() 286 | self.debug_data_list.append(dict(time=self.video_time, **video_scores)) 287 | 288 | # 3. check the scores, if need to generate a response 289 | need_response = False 290 | stream_end_score = sum([v for k, v in video_scores.items() if k in self.score_heads]) 291 | self.stream_end_prob_list.append(stream_end_score) 292 | self.stream_end_score_sum += stream_end_score 293 | if isinstance(self.running_list_length, int) and self.running_list_length > 0: 294 | self.stream_end_prob_list = self.stream_end_prob_list[-self.running_list_length:] 295 | if self.stream_end_score_sum_threshold is not None and self.stream_end_score_sum > self.stream_end_score_sum_threshold: 296 | need_response = True 297 | self.stream_end_score_sum = 0 298 | if self.stream_end_prob_threshold is not None and stream_end_score > self.stream_end_prob_threshold: 299 | need_response = True 300 | 301 | # 4. record the responses 302 | if need_response: 303 | response = self._generate_response() 304 | model_response_list.append({'time': self.video_time, 'content': response, 'role': 'assistant'}) 305 | self.num_frames_no_reply = 0 306 | self.consecutive_n_frames = 0 307 | else: 308 | response = None 309 | 310 | # 5. update the video time 311 | self.video_time += 1 / self.frame_fps 312 | 313 | return sorted(model_response_list, key=lambda x: x['time']) 314 | 315 | 316 | class DoNothingDataCollator: 317 | def __call__(self, batch): 318 | # Since batch size is 1, just return the first (and only) element 319 | return batch[0] 320 | 321 | 322 | def round_numbers(data, n): 323 | if isinstance(data, list): 324 | return [round_numbers(d, n) for d in data] 325 | elif isinstance(data, dict): 326 | return {k: round_numbers(v, n) for k, v in data.items()} 327 | elif isinstance(data, float): 328 | return round(data, n) 329 | return data 330 | 331 | 332 | if __name__ == '__main__': 333 | args = parse_args('test') 334 | print(args) 335 | dataset = FastAndAccurateStreamingVideoQADataset( 336 | data_file=args.test_fname, video_base_folder=args.input_dir, 337 | start_idx=args.start_idx, end_idx=args.end_idx, 338 | output_fps=args.frame_fps, output_resolution=args.frame_resolution, max_num_frames=args.max_num_frames, 339 | time_instruction_format=args.time_instruction_format, system_prompt=args.system_prompt 340 | ) 341 | dataloader = torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False, num_workers=4, collate_fn=DoNothingDataCollator()) 342 | 343 | infer = LiveInferForBenchmark(args) 344 | f_out = open(args.output_fname, 'w') 345 | 346 | if args.is_online_model: 347 | if not args.grounding_mode: 348 | for data_i, data in enumerate(tqdm(dataloader)): 349 | question_id, video_frames, conversation, fps, video_duration = data 350 | if question_id is None: continue 351 | infer.reset() 352 | print(f"num frames and fps for {question_id}: {len(video_frames)}, {fps}") 353 | infer.set_fps(fps=fps) 354 | infer.input_video_stream(video_frames) 355 | infer.input_query_stream(conversation) 356 | model_response_list = infer.inference() 357 | res = {'question_id': question_id, 'model_response_list': model_response_list, 'video_duration': video_duration} 358 | res['debug_data'] = round_numbers(infer.debug_data_list, 3) 359 | f_out.write(json.dumps(res) + '\n') 360 | if data_i % 5 == 0: 361 | f_out.flush() 362 | f_out.close() 363 | 364 | else: 365 | infer.first_n_frames_no_generate = 100000 # so the generation process is never called, we just want `relevance_score` 
results 366 | for data_i, data in enumerate(tqdm(dataloader)): 367 | question_id, video_frames, conversation, fps, video_duration = data 368 | if question_id is None: continue 369 | infer.reset() 370 | print(f"num frames and fps for {question_id}: {len(video_frames)}, {fps}") 371 | infer.set_fps(fps=fps) 372 | infer.input_video_stream(video_frames) 373 | infer.input_query_stream(conversation) 374 | model_response_list = infer.inference() 375 | res = {'question_id': question_id, 'model_response_list': model_response_list, 'video_duration': video_duration} 376 | res['debug_data'] = round_numbers(infer.debug_data_list, 3) 377 | f_out.write(json.dumps(res) + '\n') 378 | if data_i % 5 == 0: 379 | f_out.flush() 380 | f_out.close() 381 | 382 | else: 383 | # llava onevision baseline 384 | tokenizer, model, image_processor, max_length = load_pretrained_model(args.llm_pretrained, None, "llava_qwen", device_map="auto", attn_implementation=args.attn_implementation) # Add any other thing you want to pass in llava_model_args 385 | model.eval() 386 | 387 | if args.lora_pretrained is not None: 388 | print(f"loading lora ckpt from {args.lora_pretrained}, and setting mm_spatial_pool_stride to {args.video_pooling_stride}") 389 | model = PeftModel.from_pretrained(model, args.lora_pretrained, is_trainable=False) 390 | model.config.mm_spatial_pool_stride = args.video_pooling_stride 391 | 392 | f_out = open(args.output_fname, 'w') 393 | for data_i, data in enumerate(tqdm(dataloader)): 394 | question_id, video_frames, conversation, fps, video_duration = data 395 | if question_id is None: continue 396 | conv_template = "qwen_1_5" 397 | original_question = [e['content'] for e in conversation if e['role'] == 'user'][0] 398 | question = f"{DEFAULT_IMAGE_TOKEN}\n{original_question}" 399 | conv = copy.deepcopy(conv_templates[conv_template]) 400 | conv.append_message(conv.roles[0], question) 401 | conv.append_message(conv.roles[1], None) 402 | prompt_question = conv.get_prompt() 403 | if data_i < 5: 404 | print(f'model input at example {data_i}: {prompt_question}') 405 | input_ids = tokenizer_image_token(prompt_question, tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt").unsqueeze(0).to(model.device) 406 | image_sizes = [frame.size() for frame in video_frames] 407 | image_tensor = image_processor.preprocess(video_frames, return_tensors="pt")["pixel_values"].half().to(model.device) 408 | modalities = ["video"] * len(video_frames) 409 | cont = model.generate( 410 | input_ids, 411 | images=[image_tensor], 412 | image_sizes=image_sizes, 413 | do_sample=False, 414 | temperature=0, 415 | max_new_tokens=512, 416 | modalities=modalities, 417 | ) 418 | text_outputs = tokenizer.batch_decode(cont, skip_special_tokens=True) 419 | res = {'question_id': question_id, 'model_response': text_outputs, 'question': original_question, 'video_duration': video_duration} 420 | f_out.write(json.dumps(res) + '\n') 421 | if data_i % 10 == 0: 422 | f_out.flush() 423 | f_out.close() 424 | -------------------------------------------------------------------------------- /test/openai_batch.py: -------------------------------------------------------------------------------- 1 | # reformat model output to opanai batch input, and reformat openai batch output to our sentence similarity matrix formatted eval results 2 | import argparse, json, re, os 3 | from tqdm import tqdm 4 | import numpy as np 5 | 6 | def convert_to_online_format(example): 7 | model_response_list = list() 8 | # vtimellm format 9 | pattern = r"From (\d+) to (\d+), (.*)" 10 | matches = 
re.findall(pattern, example['model_response'][0]) 11 | video_length = example['video_duration'] 12 | for match in matches: 13 | reply_time = (int(match[0]) / 100 * video_length + int(match[1]) / 100 * video_length) / 2 14 | caption = match[2] 15 | model_response_list.append({'time': reply_time, 'content': caption, 'role': 'assistant'}) 16 | 17 | # timechat format 18 | pattern = r"(\d+\.\d+) - (\d+\.\d+)\s*seconds,\s*(.*)" 19 | matches = re.findall(pattern, example['model_response'][0]) 20 | for match in matches: 21 | start_time, end_time, caption = float(match[0]), float(match[1]), match[2] 22 | model_response_list.append({'time': (start_time + end_time) / 2, 'content': caption, 'role': 'assistant'}) 23 | 24 | if len(model_response_list) == 0: 25 | # the answer is not generated as grounded format; we use the entire response as 1 turn of answer, and set time = -1 26 | model_response_list.append({'time': -1, 'content': example['model_response'][0], 'role': 'assistant'}) 27 | example['model_response_list'] = model_response_list 28 | 29 | 30 | def model_output_to_openai_batch_input( 31 | pred_file, gold_file, output_file, is_online_model=True, 32 | num_examples=None, last_question_id=None 33 | ): 34 | pred_examples = [json.loads(line) for line in open(pred_file)] 35 | gold_examples = json.load(open(gold_file)) 36 | gold_dict = {example['question_id']: example for example in gold_examples} 37 | print(f"{len(pred_examples)} pred examples to evaluate") 38 | 39 | os.makedirs(os.path.dirname(output_file), exist_ok=True) 40 | f_out = open(output_file, 'w') 41 | 42 | for example_i, example in enumerate(tqdm(pred_examples)): 43 | if example_i == num_examples: break 44 | if is_online_model: 45 | pass 46 | else: # convert the timechat/vtimellm generated text to online model format 47 | convert_to_online_format(example) 48 | 49 | if 'model_response_list' in example: 50 | if 'debug_data' in example: del example['debug_data'] 51 | answers = [e for e in example['model_response_list'] if e['role'] == 'assistant'] 52 | if not len(answers): 53 | continue 54 | pred_list = [e['content'] for e in answers] 55 | pred_time_list = [e['time'] for e in answers] 56 | 57 | gold_list = [e['content'] for e in gold_dict[example['question_id']]['answer']] 58 | gold_timespan_list = [e['time'] for e in gold_dict[example['question_id']]['answer']] 59 | 60 | # in case that there may be some identical turns, we only need to evaluate once for them all 61 | pred_text_to_turn_i = dict() 62 | for turn_i, text in enumerate(pred_list): 63 | if text not in pred_text_to_turn_i: 64 | pred_text_to_turn_i[text] = list() 65 | pred_text_to_turn_i[text].append(turn_i) 66 | 67 | gold_text_to_turn_i = dict() 68 | for turn_i, text in enumerate(gold_list): 69 | if text not in gold_text_to_turn_i: 70 | gold_text_to_turn_i[text] = list() 71 | gold_text_to_turn_i[text].append(turn_i) 72 | 73 | question = gold_dict[example['question_id']]['conversation'][0]['content'] 74 | for gold_answer, gold_turn_ids in gold_text_to_turn_i.items(): 75 | for pred_answer, pred_turn_ids in pred_text_to_turn_i.items(): 76 | 77 | # we only need to evaluate the pred answer that is in the gold span to for the in-span metric 78 | gold_timespan = [gold_timespan_list[i] for i in gold_turn_ids] 79 | pred_time = [pred_time_list[i] for i in pred_turn_ids] 80 | pred_time_in_gold_timespan_list = [span[0] <= time <= span[1] or time == -1 for span in gold_timespan for time in pred_time] 81 | if not any(pred_time_in_gold_timespan_list): 82 | continue 83 | 84 | conversation = [ 85 | 
{"role": "system", "content": ( 86 | "You are an evaluator for a video question answering system. Your task is to rate the " 87 | "correctness of the predicted answers against the ground truth answers. Use the following scale to assign a score:\n" 88 | "- 5: Perfect match; the predicted answer is completely correct and contains all the relevant information.\n" 89 | "- 4: Mostly correct; the predicted answer is largely accurate but may have minor omissions or slight inaccuracies.\n" 90 | "- 3: Partially correct; the predicted answer has some correct information, but also contains significant inaccuracies or missing key points.\n" 91 | "- 2: Slightly correct; the predicted answer has only a few correct elements, but most of the information is incorrect or irrelevant, or the predicted answer conflicts with the ground truth answer.\n" 92 | "- 1: Incorrect; the predicted answer is entirely wrong or does not address the question at all.\n" 93 | "Only reply with a number from 1 to 5, and nothing else.") 94 | }, 95 | {"role": "user", "content": f"Question: {question}\nGround Truth Answer: {gold_answer}\nPredicted Answer: {pred_answer}"}, 96 | ] 97 | custom_id = f"{example['question_id']}*{','.join(map(str, gold_turn_ids))}*{','.join(map(str, pred_turn_ids))}" 98 | output_example = { 99 | "custom_id": custom_id, "method": "POST", "url": "/v1/chat/completions", 100 | "body": {"model": "gpt-4o-2024-08-06", "messages": conversation} 101 | } 102 | f_out.write(json.dumps(output_example) + '\n') 103 | 104 | if example['question_id'] == last_question_id: 105 | break 106 | 107 | 108 | def openai_batch_output_to_eval_results( 109 | pred_file, openai_file, gold_file, output_file, is_online_model=True, 110 | num_examples=None, last_question_id=None 111 | ): 112 | assert not os.path.exists(output_file), "check your filename, why do you want to create this file again?" 
113 | 114 | openai_scores_dict = dict() 115 | for line in open(openai_file): 116 | openai_example = json.loads(line) 117 | question_id, gold_turn_ids, pred_turn_ids = openai_example['custom_id'].split('*') 118 | gold_turn_ids = list(map(int, gold_turn_ids.split(','))) 119 | pred_turn_ids = list(map(int, pred_turn_ids.split(','))) 120 | if question_id not in openai_scores_dict: 121 | openai_scores_dict[question_id] = dict() 122 | for gold_turn_id in gold_turn_ids: 123 | for pred_turn_id in pred_turn_ids: 124 | openai_scores_dict[question_id][(gold_turn_id, pred_turn_id)] = int(openai_example['response']['body']['choices'][0]['message']['content']) 125 | 126 | pred_examples = [json.loads(line) for line in open(pred_file)] 127 | gold_examples = json.load(open(gold_file)) 128 | gold_dict = {example['question_id']: example for example in gold_examples} 129 | 130 | os.makedirs(os.path.dirname(output_file), exist_ok=True) 131 | f_out = open(output_file, 'w') 132 | for example_i, example in enumerate(pred_examples): 133 | if example_i == num_examples: break 134 | if not is_online_model: 135 | convert_to_online_format(example) 136 | 137 | if 'model_response_list' in example: 138 | if 'debug_data' in example: del example['debug_data'] 139 | answers = [e for e in example['model_response_list'] if e['role'] == 'assistant'] 140 | if not len(answers): 141 | continue 142 | pred_list = [e['content'] for e in answers] 143 | gold_list = [e['content'] for e in gold_dict[example['question_id']]['answer']] 144 | 145 | score_matrix = np.ones((len(gold_list), len(pred_list))) 146 | if example['question_id'] in openai_scores_dict: 147 | for (gold_turn_id, pred_turn_id), score in openai_scores_dict[example['question_id']].items(): 148 | score_matrix[gold_turn_id, pred_turn_id] = score 149 | example['evaluator_output'] = score_matrix.tolist() 150 | example['answer'] = gold_list 151 | example['answer_time'] = [turn['time'] for turn in gold_dict[example['question_id']]['answer']] 152 | f_out.write(json.dumps(example) + '\n') 153 | if example['question_id'] == last_question_id: 154 | break 155 | 156 | 157 | def openai_send_batch(batch_input_fname, description="debug"): 158 | from openai import OpenAI 159 | client = OpenAI() 160 | batch_input_file = client.files.create(file=open(batch_input_fname, "rb"), purpose="batch") 161 | batch_input_file_id = batch_input_file.id 162 | batch_metadata = client.batches.create( 163 | input_file_id=batch_input_file_id, 164 | endpoint="/v1/chat/completions", completion_window="24h", 165 | metadata={"description": description}) 166 | print(batch_input_fname) 167 | print(batch_metadata) 168 | 169 | 170 | def openai_get_batch(output_file_id, output_fname): 171 | from openai import OpenAI 172 | client = OpenAI() 173 | if output_file_id is not None: 174 | file_response = client.files.content(output_file_id) 175 | print(f'saving result file {output_file_id} to {output_fname}') 176 | os.makedirs(os.path.dirname(output_fname), exist_ok=True) 177 | with open(output_fname, 'w') as f_out: 178 | f_out.write(file_response.text) 179 | else: 180 | print('output_file_id is None, batch not completed') 181 | 182 | 183 | if __name__ == '__main__': 184 | parser = argparse.ArgumentParser() 185 | parser.add_argument('--func', type=str, default='batch_input') 186 | parser.add_argument('--file_id', type=str) 187 | parser.add_argument('--description', type=str) 188 | parser.add_argument('--pred_file', type=str) 189 | parser.add_argument('--openai_file', type=str) 190 | parser.add_argument('--gold_file', 
type=str) 191 | parser.add_argument('--output_file', type=str) 192 | parser.add_argument('--is_online_model', type=int, default=1) 193 | # parser.add_argument('--num_examples', type=int, default=2000) 194 | # parser.add_argument('--last_question_id', type=str, default='d05eUOI81LA.35.mp4') 195 | args = parser.parse_args() 196 | 197 | if args.func == 'batch_input': 198 | model_output_to_openai_batch_input( 199 | pred_file=args.pred_file, gold_file=args.gold_file, output_file=args.output_file, 200 | is_online_model=bool(args.is_online_model) #, num_examples=args.num_examples, last_question_id=args.last_question_id, 201 | ) 202 | 203 | elif args.func == 'batch_output': 204 | openai_batch_output_to_eval_results( 205 | pred_file=args.pred_file, openai_file=args.openai_file, gold_file=args.gold_file, output_file=args.output_file, 206 | is_online_model=bool(args.is_online_model) #, num_examples=args.num_examples, last_question_id=args.last_question_id, 207 | ) 208 | 209 | elif args.func == 'send_batch': 210 | openai_send_batch(batch_input_fname=args.pred_file, description=args.description) 211 | 212 | elif args.func == 'get_batch': 213 | openai_get_batch(output_file_id=args.file_id, output_fname=args.output_file) 214 | 215 | elif args.func == 'check_batch': 216 | from openai import OpenAI 217 | client = OpenAI() 218 | for task in client.batches.list(limit=6).data: 219 | print(task, end='\n\n') 220 | 221 | -------------------------------------------------------------------------------- /test/qvh/eval.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from collections import OrderedDict, defaultdict 3 | import json 4 | import time 5 | import copy 6 | import multiprocessing as mp 7 | from .utils import compute_average_precision_detection, \ 8 | compute_temporal_iou_batch_cross, compute_temporal_iou_batch_paired, load_jsonl, get_ap 9 | 10 | 11 | def compute_average_precision_detection_wrapper( 12 | input_triple, tiou_thresholds=np.linspace(0.5, 0.95, 10)): 13 | qid, ground_truth, prediction = input_triple 14 | scores = compute_average_precision_detection( 15 | ground_truth, prediction, tiou_thresholds=tiou_thresholds) 16 | return qid, scores 17 | 18 | 19 | def compute_mr_ap(submission, ground_truth, iou_thds=np.linspace(0.5, 0.95, 10), 20 | max_gt_windows=None, max_pred_windows=10, num_workers=8, chunksize=50): 21 | iou_thds = [float(f"{e:.2f}") for e in iou_thds] 22 | pred_qid2data = defaultdict(list) 23 | for d in submission: 24 | pred_windows = d["pred_relevant_windows"][:max_pred_windows] \ 25 | if max_pred_windows is not None else d["pred_relevant_windows"] 26 | qid = d["qid"] 27 | for w in pred_windows: 28 | pred_qid2data[qid].append({ 29 | "video-id": d["qid"], # in order to use the API 30 | "t-start": w[0], 31 | "t-end": w[1], 32 | "score": w[2] 33 | }) 34 | 35 | gt_qid2data = defaultdict(list) 36 | for d in ground_truth: 37 | gt_windows = d["relevant_windows"][:max_gt_windows] \ 38 | if max_gt_windows is not None else d["relevant_windows"] 39 | qid = d["qid"] 40 | for w in gt_windows: 41 | gt_qid2data[qid].append({ 42 | "video-id": d["qid"], 43 | "t-start": w[0], 44 | "t-end": w[1] 45 | }) 46 | qid2ap_list = {} 47 | # start_time = time.time() 48 | data_triples = [[qid, gt_qid2data[qid], pred_qid2data[qid]] for qid in pred_qid2data] 49 | from functools import partial 50 | compute_ap_from_triple = partial( 51 | compute_average_precision_detection_wrapper, tiou_thresholds=iou_thds) 52 | 53 | if num_workers > 1: 54 | with 
mp.Pool(num_workers) as pool: 55 | for qid, scores in pool.imap_unordered(compute_ap_from_triple, data_triples, chunksize=chunksize): 56 | qid2ap_list[qid] = scores 57 | else: 58 | for data_triple in data_triples: 59 | qid, scores = compute_ap_from_triple(data_triple) 60 | qid2ap_list[qid] = scores 61 | 62 | # print(f"compute_average_precision_detection {time.time() - start_time:.2f} seconds.") 63 | ap_array = np.array(list(qid2ap_list.values())) # (#queries, #thd) 64 | ap_thds = ap_array.mean(0) # mAP at different IoU thresholds. 65 | iou_thd2ap = dict(zip([str(e) for e in iou_thds], ap_thds)) 66 | iou_thd2ap["average"] = np.mean(ap_thds) 67 | # formatting 68 | iou_thd2ap = {k: float(f"{100 * v:.2f}") for k, v in iou_thd2ap.items()} 69 | return iou_thd2ap 70 | 71 | 72 | def compute_mr_r1(submission, ground_truth, iou_thds=np.linspace(0.5, 0.95, 10)): 73 | """If a predicted segment has IoU >= iou_thd with one of the 1st GT segment, we define it positive""" 74 | iou_thds = [float(f"{e:.2f}") for e in iou_thds] 75 | pred_qid2window = {d["qid"]: d["pred_relevant_windows"][0][:2] for d in submission} # :2 rm scores 76 | # gt_qid2window = {d["qid"]: d["relevant_windows"][0] for d in ground_truth} 77 | gt_qid2window = {} 78 | for d in ground_truth: 79 | cur_gt_windows = d["relevant_windows"] 80 | cur_qid = d["qid"] 81 | cur_max_iou_idx = 0 82 | if len(cur_gt_windows) > 0: # select the GT window that has the highest IoU 83 | cur_ious = compute_temporal_iou_batch_cross( 84 | np.array([pred_qid2window[cur_qid]]), np.array(d["relevant_windows"]) 85 | )[0] 86 | cur_max_iou_idx = np.argmax(cur_ious) 87 | gt_qid2window[cur_qid] = cur_gt_windows[cur_max_iou_idx] 88 | 89 | qids = list(pred_qid2window.keys()) 90 | pred_windows = np.array([pred_qid2window[k] for k in qids]).astype(float) 91 | gt_windows = np.array([gt_qid2window[k] for k in qids]).astype(float) 92 | pred_gt_iou = compute_temporal_iou_batch_paired(pred_windows, gt_windows) 93 | iou_thd2recall_at_one = {} 94 | for thd in iou_thds: 95 | iou_thd2recall_at_one[str(thd)] = float(f"{np.mean(pred_gt_iou >= thd) * 100:.2f}") 96 | return iou_thd2recall_at_one 97 | 98 | 99 | def get_window_len(window): 100 | return window[1] - window[0] 101 | 102 | 103 | def get_data_by_range(submission, ground_truth, len_range): 104 | """ keep queries with ground truth window length in the specified length range. 105 | Args: 106 | submission: 107 | ground_truth: 108 | len_range: [min_l (int), max_l (int)]. 
the range is (min_l, max_l], i.e., min_l < l <= max_l 109 | """ 110 | min_l, max_l = len_range 111 | if min_l == 0 and max_l == 150: # min and max l in dataset 112 | return submission, ground_truth 113 | 114 | # only keep ground truth with windows in the specified length range 115 | # if multiple GT windows exists, we only keep the ones in the range 116 | ground_truth_in_range = [] 117 | gt_qids_in_range = set() 118 | for d in ground_truth: 119 | rel_windows_in_range = [ 120 | w for w in d["relevant_windows"] if min_l < get_window_len(w) <= max_l] 121 | if len(rel_windows_in_range) > 0: 122 | d = copy.deepcopy(d) 123 | d["relevant_windows"] = rel_windows_in_range 124 | ground_truth_in_range.append(d) 125 | gt_qids_in_range.add(d["qid"]) 126 | 127 | # keep only submissions for ground_truth_in_range 128 | submission_in_range = [] 129 | for d in submission: 130 | if d["qid"] in gt_qids_in_range: 131 | submission_in_range.append(copy.deepcopy(d)) 132 | 133 | return submission_in_range, ground_truth_in_range 134 | 135 | 136 | def eval_moment_retrieval(submission, ground_truth, verbose=True): 137 | length_ranges = [[0, 10], [10, 30], [30, 150], [0, 150], ] # 138 | range_names = ["short", "middle", "long", "full"] 139 | 140 | ret_metrics = {} 141 | for l_range, name in zip(length_ranges, range_names): 142 | if verbose: 143 | start_time = time.time() 144 | _submission, _ground_truth = get_data_by_range(submission, ground_truth, l_range) 145 | print(f"{name}: {l_range}, {len(_ground_truth)}/{len(ground_truth)}=" 146 | f"{100*len(_ground_truth)/len(ground_truth):.2f} examples.") 147 | iou_thd2average_precision = compute_mr_ap(_submission, _ground_truth, num_workers=8, chunksize=50) 148 | iou_thd2recall_at_one = compute_mr_r1(_submission, _ground_truth) 149 | ret_metrics[name] = {"MR-mAP": iou_thd2average_precision, "MR-R1": iou_thd2recall_at_one} 150 | if verbose: 151 | print(f"[eval_moment_retrieval] [{name}] {time.time() - start_time:.2f} seconds") 152 | return ret_metrics 153 | 154 | 155 | def compute_hl_hit1(qid2preds, qid2gt_scores_binary): 156 | qid2max_scored_clip_idx = {k: np.argmax(v["pred_saliency_scores"]) for k, v in qid2preds.items()} 157 | hit_scores = np.zeros((len(qid2preds), 3)) 158 | qids = list(qid2preds.keys()) 159 | for idx, qid in enumerate(qids): 160 | pred_clip_idx = qid2max_scored_clip_idx[qid] 161 | gt_scores_binary = qid2gt_scores_binary[qid] # (#clips, 3) 162 | if pred_clip_idx < len(gt_scores_binary): 163 | hit_scores[idx] = gt_scores_binary[pred_clip_idx] 164 | # aggregate scores from 3 separate annotations (3 workers) by taking the max. 165 | # then average scores from all queries. 
166 | hit_at_one = float(f"{100 * np.mean(np.max(hit_scores, 1)):.2f}") 167 | return hit_at_one 168 | 169 | 170 | def compute_hl_ap(qid2preds, qid2gt_scores_binary, num_workers=8, chunksize=50): 171 | qid2pred_scores = {k: v["pred_saliency_scores"] for k, v in qid2preds.items()} 172 | ap_scores = np.zeros((len(qid2preds), 3)) # (#preds, 3) 173 | qids = list(qid2preds.keys()) 174 | input_tuples = [] 175 | for idx, qid in enumerate(qids): 176 | for w_idx in range(3): # annotation score idx 177 | y_true = qid2gt_scores_binary[qid][:, w_idx] 178 | y_predict = np.array(qid2pred_scores[qid]) 179 | input_tuples.append((idx, w_idx, y_true, y_predict)) 180 | 181 | if num_workers > 1: 182 | with mp.Pool(num_workers) as pool: 183 | for idx, w_idx, score in pool.imap_unordered( 184 | compute_ap_from_tuple, input_tuples, chunksize=chunksize): 185 | ap_scores[idx, w_idx] = score 186 | else: 187 | for input_tuple in input_tuples: 188 | idx, w_idx, score = compute_ap_from_tuple(input_tuple) 189 | ap_scores[idx, w_idx] = score 190 | 191 | # it's the same if we first average across different annotations, then average across queries 192 | # since all queries have the same #annotations. 193 | mean_ap = float(f"{100 * np.mean(ap_scores):.2f}") 194 | return mean_ap 195 | 196 | 197 | def compute_ap_from_tuple(input_tuple): 198 | idx, w_idx, y_true, y_predict = input_tuple 199 | if len(y_true) < len(y_predict): 200 | # print(f"len(y_true) < len(y_predict) {len(y_true), len(y_predict)}") 201 | y_predict = y_predict[:len(y_true)] 202 | elif len(y_true) > len(y_predict): 203 | # print(f"len(y_true) > len(y_predict) {len(y_true), len(y_predict)}") 204 | _y_predict = np.zeros(len(y_true)) 205 | _y_predict[:len(y_predict)] = y_predict 206 | y_predict = _y_predict 207 | 208 | score = get_ap(y_true, y_predict) 209 | return idx, w_idx, score 210 | 211 | 212 | def mk_gt_scores(gt_data, clip_length=2): 213 | """gt_data, dict, """ 214 | num_clips = int(gt_data["duration"] / clip_length) 215 | saliency_scores_full_video = np.zeros((num_clips, 3)) 216 | relevant_clip_ids = np.array(gt_data["relevant_clip_ids"]) # (#relevant_clip_ids, ) 217 | saliency_scores_relevant_clips = np.array(gt_data["saliency_scores"]) # (#relevant_clip_ids, 3) 218 | saliency_scores_full_video[relevant_clip_ids] = saliency_scores_relevant_clips 219 | return saliency_scores_full_video # (#clips_in_video, 3) the scores are in range [0, 4] 220 | 221 | 222 | def eval_highlight(submission, ground_truth, verbose=True): 223 | """ 224 | Args: 225 | submission: 226 | ground_truth: 227 | verbose: 228 | """ 229 | qid2preds = {d["qid"]: d for d in submission} 230 | qid2gt_scores_full_range = {d["qid"]: mk_gt_scores(d) for d in ground_truth} # scores in range [0, 4] 231 | # gt_saliency_score_min: int, in [0, 1, 2, 3, 4]. The minimum score for a positive clip. 
232 | gt_saliency_score_min_list = [2, 3, 4] 233 | saliency_score_names = ["Fair", "Good", "VeryGood"] 234 | highlight_det_metrics = {} 235 | for gt_saliency_score_min, score_name in zip(gt_saliency_score_min_list, saliency_score_names): 236 | start_time = time.time() 237 | qid2gt_scores_binary = { 238 | k: (v >= gt_saliency_score_min).astype(float) 239 | for k, v in qid2gt_scores_full_range.items()} # scores in [0, 1] 240 | hit_at_one = compute_hl_hit1(qid2preds, qid2gt_scores_binary) 241 | mean_ap = compute_hl_ap(qid2preds, qid2gt_scores_binary) 242 | highlight_det_metrics[f"HL-min-{score_name}"] = {"HL-mAP": mean_ap, "HL-Hit1": hit_at_one} 243 | if verbose: 244 | print(f"Calculating highlight scores with min score {gt_saliency_score_min} ({score_name})") 245 | print(f"Time cost {time.time() - start_time:.2f} seconds") 246 | return highlight_det_metrics 247 | 248 | 249 | def eval_submission(submission, ground_truth, verbose=True, match_number=True): 250 | """ 251 | Args: 252 | submission: list(dict), each dict is { 253 | qid: str, 254 | query: str, 255 | vid: str, 256 | pred_relevant_windows: list([st, ed]), 257 | pred_saliency_scores: list(float), len == #clips in video. 258 | i.e., each clip in the video will have a saliency score. 259 | } 260 | ground_truth: list(dict), each dict is { 261 | "qid": 7803, 262 | "query": "Man in gray top walks from outside to inside.", 263 | "duration": 150, 264 | "vid": "RoripwjYFp8_360.0_510.0", 265 | "relevant_clip_ids": [13, 14, 15, 16, 17] 266 | "saliency_scores": [[4, 4, 2], [3, 4, 2], [2, 2, 3], [2, 2, 2], [0, 1, 3]] 267 | each sublist corresponds to one clip in relevant_clip_ids. 268 | The 3 elements in the sublist are scores from 3 different workers. The 269 | scores are in [0, 1, 2, 3, 4], meaning [Very Bad, ..., Good, Very Good] 270 | } 271 | verbose: 272 | match_number: 273 | 274 | Returns: 275 | 276 | """ 277 | pred_qids = set([e["qid"] for e in submission]) 278 | gt_qids = set([e["qid"] for e in ground_truth]) 279 | if match_number: 280 | assert pred_qids == gt_qids, \ 281 | f"qids in ground_truth and submission must match. 
" \ 282 | f"use `match_number=False` if you wish to disable this check" 283 | else: # only leave the items that exists in both submission and ground_truth 284 | shared_qids = pred_qids.intersection(gt_qids) 285 | submission = [e for e in submission if e["qid"] in shared_qids] 286 | ground_truth = [e for e in ground_truth if e["qid"] in shared_qids] 287 | 288 | eval_metrics = {} 289 | eval_metrics_brief = OrderedDict() 290 | if "pred_relevant_windows" in submission[0]: 291 | moment_ret_scores = eval_moment_retrieval( 292 | submission, ground_truth, verbose=verbose) 293 | eval_metrics.update(moment_ret_scores) 294 | moment_ret_scores_brief = { 295 | "MR-full-mAP": moment_ret_scores["full"]["MR-mAP"]["average"], 296 | "MR-full-mAP@0.5": moment_ret_scores["full"]["MR-mAP"]["0.5"], 297 | "MR-full-mAP@0.75": moment_ret_scores["full"]["MR-mAP"]["0.75"], 298 | "MR-short-mAP": moment_ret_scores["short"]["MR-mAP"]["average"], 299 | "MR-middle-mAP": moment_ret_scores["middle"]["MR-mAP"]["average"], 300 | "MR-long-mAP": moment_ret_scores["long"]["MR-mAP"]["average"], 301 | "MR-full-R1@0.5": moment_ret_scores["full"]["MR-R1"]["0.5"], 302 | "MR-full-R1@0.7": moment_ret_scores["full"]["MR-R1"]["0.7"], 303 | } 304 | eval_metrics_brief.update( 305 | sorted([(k, v) for k, v in moment_ret_scores_brief.items()], key=lambda x: x[0])) 306 | 307 | if "pred_saliency_scores" in submission[0]: 308 | highlight_det_scores = eval_highlight( 309 | submission, ground_truth, verbose=verbose) 310 | eval_metrics.update(highlight_det_scores) 311 | highlight_det_scores_brief = dict([ 312 | (f"{k}-{sub_k.split('-')[1]}", v[sub_k]) 313 | for k, v in highlight_det_scores.items() for sub_k in v]) 314 | eval_metrics_brief.update(highlight_det_scores_brief) 315 | 316 | # sort by keys 317 | final_eval_metrics = OrderedDict() 318 | final_eval_metrics["brief"] = eval_metrics_brief 319 | final_eval_metrics.update(sorted([(k, v) for k, v in eval_metrics.items()], key=lambda x: x[0])) 320 | return final_eval_metrics 321 | 322 | 323 | def eval_main(): 324 | import argparse 325 | parser = argparse.ArgumentParser(description="Moments and Highlights Evaluation Script") 326 | parser.add_argument("--submission_path", type=str, help="path to generated prediction file") 327 | parser.add_argument("--gt_path", type=str, help="path to GT file") 328 | parser.add_argument("--save_path", type=str, help="path to save the results") 329 | parser.add_argument("--not_verbose", action="store_true") 330 | args = parser.parse_args() 331 | 332 | verbose = not args.not_verbose 333 | submission = load_jsonl(args.submission_path) 334 | gt = load_jsonl(args.gt_path) 335 | results = eval_submission(submission, gt, verbose=verbose) 336 | if verbose: 337 | print(json.dumps(results, indent=4)) 338 | 339 | with open(args.save_path, "w") as f: 340 | f.write(json.dumps(results, indent=4)) 341 | 342 | 343 | if __name__ == '__main__': 344 | eval_main() -------------------------------------------------------------------------------- /test/qvh/utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copied from MMAction2 3 | https://github.com/open-mmlab/mmaction2/blob/master/mmaction/core/evaluation/eval_detection.py 4 | """ 5 | import json 6 | import numpy as np 7 | from sklearn.metrics import precision_recall_curve 8 | 9 | 10 | def load_jsonl(filename): 11 | with open(filename, "r") as f: 12 | return [json.loads(l.strip("\n")) for l in f.readlines()] 13 | 14 | 15 | def compute_temporal_iou_batch_paired(pred_windows, 
gt_windows): 16 | """ compute intersection-over-union along temporal axis for each pair of windows in pred_windows and gt_windows. 17 | Args: 18 | pred_windows: np.ndarray, (N, 2), [st (float), ed (float)] * N 19 | gt_windows: np.ndarray, (N, 2), [st (float), ed (float)] * N 20 | Returns: 21 | iou (float): np.ndarray, (N, ) 22 | 23 | References: 24 | for np.divide with zeros, see https://stackoverflow.com/a/37977222 25 | """ 26 | intersection = np.maximum( 27 | 0, np.minimum(pred_windows[:, 1], gt_windows[:, 1]) - np.maximum(pred_windows[:, 0], gt_windows[:, 0]) 28 | ) 29 | union = np.maximum(pred_windows[:, 1], gt_windows[:, 1]) \ 30 | - np.minimum(pred_windows[:, 0], gt_windows[:, 0]) # not the correct union though 31 | return np.divide(intersection, union, out=np.zeros_like(intersection), where=union != 0) 32 | 33 | 34 | def compute_temporal_iou_batch_cross(spans1, spans2): 35 | """ 36 | Args: 37 | spans1: (N, 2) np.ndarray, each row defines a span [st, ed] 38 | spans2: (M, 2) np.ndarray, ... 39 | 40 | Returns: 41 | iou: (N, M) np.ndarray 42 | union: (N, M) np.ndarray 43 | >>> spans1 = np.array([[0, 0.2, 0.9], [0.5, 1.0, 0.2]]) 44 | >>> spans2 = np.array([[0, 0.3], [0., 1.0]]) 45 | >>> compute_temporal_iou_batch_cross(spans1, spans2) 46 | (tensor([[0.6667, 0.2000], 47 | [0.0000, 0.5000]]), 48 | tensor([[0.3000, 1.0000], 49 | [0.8000, 1.0000]])) 50 | """ 51 | areas1 = spans1[:, 1] - spans1[:, 0] # (N, ) 52 | areas2 = spans2[:, 1] - spans2[:, 0] # (M, ) 53 | 54 | left = np.maximum(spans1[:, None, 0], spans2[None, :, 0]) # (N, M) 55 | right = np.minimum(spans1[:, None, 1], spans2[None, :, 1]) # (N, M) 56 | 57 | inter = np.clip(right - left, 0, None) # (N, M) 58 | union = areas1[:, None] + areas2[None, :] - inter # (N, M) 59 | 60 | iou = inter / union 61 | return iou, union 62 | 63 | 64 | def interpolated_precision_recall(precision, recall): 65 | """Interpolated AP - VOCdevkit from VOC 2011. 66 | 67 | Args: 68 | precision (np.ndarray): The precision of different thresholds. 69 | recall (np.ndarray): The recall of different thresholds. 70 | 71 | Returns: 72 | float: Average precision score. 73 | """ 74 | mprecision = np.hstack([[0], precision, [0]]) 75 | mrecall = np.hstack([[0], recall, [1]]) 76 | for i in range(len(mprecision) - 1)[::-1]: 77 | mprecision[i] = max(mprecision[i], mprecision[i + 1]) 78 | idx = np.where(mrecall[1::] != mrecall[0:-1])[0] + 1 79 | ap = np.sum((mrecall[idx] - mrecall[idx - 1]) * mprecision[idx]) 80 | return ap 81 | 82 | 83 | def compute_average_precision_detection(ground_truth, 84 | prediction, 85 | tiou_thresholds=np.linspace( 86 | 0.5, 0.95, 10)): 87 | """Compute average precision (detection task) between ground truth and 88 | predictions data frames. If multiple predictions occurs for the same 89 | predicted segment, only the one with highest score is matches as true 90 | positive. This code is greatly inspired by Pascal VOC devkit. 91 | 92 | Args: 93 | ground_truth (list[dict]): List containing the ground truth instances 94 | (dictionaries). Required keys are 'video-id', 't-start' and 95 | 't-end'. 96 | prediction (list[dict]): List containing the prediction instances 97 | (dictionaries). Required keys are: 'video-id', 't-start', 't-end' 98 | and 'score'. 99 | tiou_thresholds (np.ndarray): A 1darray indicates the temporal 100 | intersection over union threshold, which is optional. 101 | Default: ``np.linspace(0.5, 0.95, 10)``. 102 | 103 | Returns: 104 | Float: ap, Average precision score. 
105 | """ 106 | num_thresholds = len(tiou_thresholds) 107 | num_gts = len(ground_truth) 108 | num_preds = len(prediction) 109 | ap = np.zeros(num_thresholds) 110 | if len(prediction) == 0: 111 | return ap 112 | 113 | num_positive = float(num_gts) 114 | lock_gt = np.ones((num_thresholds, num_gts)) * -1 115 | # Sort predictions by decreasing score order. 116 | prediction.sort(key=lambda x: -x['score']) 117 | # Initialize true positive and false positive vectors. 118 | tp = np.zeros((num_thresholds, num_preds)) 119 | fp = np.zeros((num_thresholds, num_preds)) 120 | 121 | # Adaptation to query faster 122 | ground_truth_by_videoid = {} 123 | for i, item in enumerate(ground_truth): 124 | item['index'] = i 125 | ground_truth_by_videoid.setdefault(item['video-id'], []).append(item) 126 | 127 | # Assigning true positive to truly grount truth instances. 128 | for idx, pred in enumerate(prediction): 129 | if pred['video-id'] in ground_truth_by_videoid: 130 | gts = ground_truth_by_videoid[pred['video-id']] 131 | else: 132 | fp[:, idx] = 1 133 | continue 134 | 135 | _pred = np.array([[pred['t-start'], pred['t-end']], ]) 136 | _gt = np.array([[gt['t-start'], gt['t-end']] for gt in gts]) 137 | tiou_arr = compute_temporal_iou_batch_cross(_pred, _gt)[0] 138 | 139 | tiou_arr = tiou_arr.reshape(-1) 140 | # We would like to retrieve the predictions with highest tiou score. 141 | tiou_sorted_idx = tiou_arr.argsort()[::-1] 142 | for t_idx, tiou_threshold in enumerate(tiou_thresholds): 143 | for j_idx in tiou_sorted_idx: 144 | if tiou_arr[j_idx] < tiou_threshold: 145 | fp[t_idx, idx] = 1 146 | break 147 | if lock_gt[t_idx, gts[j_idx]['index']] >= 0: 148 | continue 149 | # Assign as true positive after the filters above. 150 | tp[t_idx, idx] = 1 151 | lock_gt[t_idx, gts[j_idx]['index']] = idx 152 | break 153 | 154 | if fp[t_idx, idx] == 0 and tp[t_idx, idx] == 0: 155 | fp[t_idx, idx] = 1 156 | 157 | tp_cumsum = np.cumsum(tp, axis=1).astype(float) 158 | fp_cumsum = np.cumsum(fp, axis=1).astype(float) 159 | recall_cumsum = tp_cumsum / num_positive 160 | 161 | precision_cumsum = tp_cumsum / (tp_cumsum + fp_cumsum) 162 | 163 | for t_idx in range(len(tiou_thresholds)): 164 | ap[t_idx] = interpolated_precision_recall(precision_cumsum[t_idx, :], 165 | recall_cumsum[t_idx, :]) 166 | return ap 167 | 168 | 169 | def get_ap(y_true, y_predict, interpolate=True, point_11=False): 170 | """ 171 | Average precision in different formats: (non-) interpolated and/or 11-point approximated 172 | point_11=True and interpolate=True corresponds to the 11-point interpolated AP used in 173 | the PASCAL VOC challenge up to the 2008 edition and has been verfied against the vlfeat implementation 174 | The exact average precision (interpolate=False, point_11=False) corresponds to the one of vl_feat 175 | 176 | :param y_true: list/ numpy vector of true labels in {0,1} for each element 177 | :param y_predict: predicted score for each element 178 | :param interpolate: Use interpolation? 179 | :param point_11: Use 11-point approximation to average precision? 
180 | :return: average precision 181 | 182 | ref: https://github.com/gyglim/video2gif_dataset/blob/master/v2g_evaluation/__init__.py 183 | 184 | """ 185 | # Check inputs 186 | assert len(y_true) == len(y_predict), "Prediction and ground truth need to be of the same length" 187 | if len(set(y_true)) == 1: 188 | if y_true[0] == 0: 189 | return 0 # True labels are all zeros 190 | # raise ValueError('True labels cannot all be zero') 191 | else: 192 | return 1 193 | else: 194 | assert sorted(set(y_true)) == [0, 1], "Ground truth can only contain elements {0,1}" 195 | 196 | # Compute precision and recall 197 | precision, recall, _ = precision_recall_curve(y_true, y_predict) 198 | recall = recall.astype(np.float32) 199 | 200 | if interpolate: # Compute the interpolated precision 201 | for i in range(1, len(precision)): 202 | precision[i] = max(precision[i - 1], precision[i]) 203 | 204 | if point_11: # Compute the 11-point approximated AP 205 | precision_11 = [precision[np.where(recall >= t)[0][-1]] for t in np.arange(0, 1.01, 0.1)] 206 | return np.mean(precision_11) 207 | else: # Compute the AP using precision at every additionally recalled sample 208 | indices = np.where(np.diff(recall)) 209 | return np.mean(precision[indices]) -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import json 2 | from dataclasses import asdict 3 | 4 | import torch 5 | from models import build_model_and_tokenizer, parse_args 6 | from data import ( 7 | build_concat_train_dataset_from_config, get_data_collator 8 | ) 9 | from transformers import Trainer 10 | 11 | 12 | class TrainerWithLossErrorCatch(Trainer): 13 | def training_step(self, model, inputs): 14 | try: 15 | loss = super().training_step(model, inputs) 16 | return loss 17 | except Exception as e: 18 | print(f"Error during training step: {e}, use a dummy loss = 0.0") 19 | return torch.tensor(0., device=self.args.device, 20 | dtype=torch.float16 if self.args.fp16 else torch.bfloat16 if self.args.bf16 else torch.float32) # dummy loss 21 | 22 | 23 | def rank0_print(*args): 24 | if torch.distributed.get_rank() == 0: 25 | print(*args) 26 | 27 | 28 | def train(): 29 | args = parse_args('train') 30 | rank0_print(args) 31 | model, tokenizer = build_model_and_tokenizer(is_training=True, **asdict(args)) 32 | if 'llava' in args.llm_pretrained: 33 | image_processor = model.get_vision_tower().image_processor 34 | else: 35 | image_processor = None 36 | 37 | for name, param in model.named_parameters(): 38 | rank0_print(name, param.shape, param.dtype, param.requires_grad) 39 | 40 | train_dataset_config = json.load(open(args.dataset_config)) 41 | train_dataset = build_concat_train_dataset_from_config( 42 | tokenizer=tokenizer, config=train_dataset_config 43 | ) 44 | data_collator = get_data_collator(tokenizer=tokenizer, image_processor=image_processor, model_config=model.config, **asdict(args)) 45 | 46 | args.gradient_checkpointing_kwargs = {'use_reentrant': False} 47 | trainer = TrainerWithLossErrorCatch( 48 | model=model, tokenizer=tokenizer, 49 | args=args, 50 | train_dataset=train_dataset, 51 | data_collator=data_collator, 52 | ) 53 | 54 | with torch.cuda.amp.autocast(): 55 | trainer.train(resume_from_checkpoint=args.resume_from_checkpoint) 56 | trainer.save_model() 57 | 58 | if __name__ == "__main__": 59 | train() 60 | -------------------------------------------------------------------------------- /utils/dist_utils.py: 
-------------------------------------------------------------------------------- 1 | import os 2 | import random 3 | import datetime 4 | import numpy as np 5 | import torch 6 | import torch.backends.cudnn as cudnn 7 | import torch.distributed as dist 8 | 9 | 10 | def setup_seeds(seed): 11 | seed = seed + get_rank() 12 | random.seed(seed) 13 | np.random.seed(seed) 14 | torch.manual_seed(seed) 15 | cudnn.benchmark = False 16 | cudnn.deterministic = True 17 | 18 | # functions below inherited from lavis 19 | 20 | def get_rank(): 21 | if not (dist.is_available() and dist.is_initialized()): 22 | return 0 23 | return dist.get_rank() 24 | 25 | 26 | def is_main_process(): 27 | return get_rank() == 0 28 | 29 | 30 | def setup_for_distributed(is_master): 31 | """ 32 | This function disables printing when not in master process 33 | """ 34 | import builtins as __builtin__ 35 | 36 | builtin_print = __builtin__.print 37 | 38 | def print(*args, **kwargs): 39 | force = kwargs.pop("force", False) 40 | if is_master or force: 41 | builtin_print(*args, **kwargs) 42 | 43 | __builtin__.print = print 44 | 45 | 46 | def init_deepspeed_distributed_mode(): 47 | import deepspeed 48 | if "RANK" in os.environ and "WORLD_SIZE" in os.environ: 49 | rank = int(os.environ["RANK"]) 50 | world_size = int(os.environ["WORLD_SIZE"]) 51 | gpu = int(os.environ["LOCAL_RANK"]) 52 | elif "SLURM_PROCID" in os.environ: 53 | rank = int(os.environ["SLURM_PROCID"]) 54 | gpu = rank % torch.cuda.device_count() 55 | else: 56 | print("Not using distributed mode") 57 | return 58 | 59 | torch.cuda.set_device(gpu) 60 | dist_backend = "nccl" 61 | dist_url = 'env://' 62 | print( 63 | "| distributed init (rank {}, world {}): {}".format( 64 | rank, world_size, dist_url 65 | ), 66 | flush=True, 67 | ) 68 | deepspeed.init_distributed( 69 | dist_backend=dist_backend, 70 | init_method=dist_url, 71 | world_size=world_size, 72 | rank=rank, 73 | timeout=datetime.timedelta( 74 | days=365 75 | ), # allow auto-downloading and de-compressing 76 | ) 77 | torch.distributed.barrier() 78 | setup_for_distributed(rank == 0) 79 | --------------------------------------------------------------------------------
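
The MAGQA evaluation in `test/openai_batch.py` is normally driven through its CLI (`--func batch_input | send_batch | get_batch | batch_output`), but the same four stages can be scripted directly. A minimal sketch, assuming the repository root is on `PYTHONPATH` and `OPENAI_API_KEY` is set; every file path, the batch description, and the `file-XXXXXXXX` id below are placeholders for illustration, not values from the repo:

```python
# Illustrative driver for the four evaluation stages defined in test/openai_batch.py.
# All paths and the output_file_id are placeholders; only the function names/kwargs
# come from the repository code.
from test.openai_batch import (
    model_output_to_openai_batch_input,
    openai_send_batch,
    openai_get_batch,
    openai_batch_output_to_eval_results,
)

pred_file = "outputs/magqa/predictions.jsonl"          # model inference results, one JSON per line
gold_file = "datasets/magqa/test.json"                 # gold annotations with per-turn answers
batch_input_file = "outputs/magqa/openai_batch_input.jsonl"
batch_output_file = "outputs/magqa/openai_batch_output.jsonl"
eval_file = "outputs/magqa/eval_results.jsonl"

# 1) Convert model predictions + gold answers into an OpenAI batch-request file.
model_output_to_openai_batch_input(
    pred_file=pred_file, gold_file=gold_file,
    output_file=batch_input_file, is_online_model=True)

# 2) Upload the request file and create the batch job (prints the batch metadata).
openai_send_batch(batch_input_file, description="magqa-eval")

# 3) After the batch completes, download its result file by id
#    (the id can be found via client.batches.list(), cf. the `check_batch` branch).
openai_get_batch(output_file_id="file-XXXXXXXX", output_fname=batch_output_file)

# 4) Merge the GPT scores back into per-example score matrices
#    (stored under `evaluator_output` next to the gold answers).
openai_batch_output_to_eval_results(
    pred_file=pred_file, openai_file=batch_output_file,
    gold_file=gold_file, output_file=eval_file, is_online_model=True)
```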
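
Similarly, the temporal-IoU and average-precision helpers in `test/qvh/utils.py` that underpin the QVHighlights metrics can be sanity-checked in isolation. A toy sketch, run from the repository root, with made-up windows and scores (none of the numbers below come from the datasets):

```python
# Toy check of the span-IoU and AP helpers in test/qvh/utils.py.
# The windows and scores are invented purely for illustration.
import numpy as np
from test.qvh.utils import compute_temporal_iou_batch_cross, get_ap

pred_windows = np.array([[0.0, 10.0], [20.0, 40.0]])   # predicted [start, end] in seconds
gold_windows = np.array([[5.0, 15.0], [18.0, 38.0]])   # ground-truth [start, end]

iou, union = compute_temporal_iou_batch_cross(pred_windows, gold_windows)
# iou[i, j] is the temporal IoU of pred_windows[i] vs gold_windows[j]:
# [[0.333, 0.   ],
#  [0.   , 0.818]]
print(iou.round(3))

# Saliency AP on a single query: binary per-clip relevance vs. predicted scores.
y_true = np.array([0, 1, 1, 0, 0])
y_score = np.array([0.1, 0.8, 0.6, 0.4, 0.2])
print(get_ap(y_true, y_score))   # 1.0 -- both relevant clips are ranked first
```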