├── LICENSE ├── README.md ├── assets ├── example.mp4 ├── radar_comparison.png └── teaser2.png ├── flash_vstream ├── __init__.py ├── constants.py ├── conversation.py ├── eval_video │ ├── eval_activitynet_qa.py │ ├── eval_any_dataset_features.py │ ├── model_msvd_qa.py │ └── model_msvd_qa_featuresloader.py ├── mm_utils.py ├── model │ ├── __init__.py │ ├── builder.py │ ├── compress_functions.py │ ├── language_model │ │ └── vstream_llama.py │ ├── multimodal_encoder │ │ ├── builder.py │ │ └── clip_encoder.py │ ├── multimodal_projector │ │ └── builder.py │ └── vstream_arch.py ├── serve │ └── cli_video_stream.py ├── train │ ├── llama_flash_attn_monkey_patch.py │ ├── llama_xformers_attn_monkey_patch.py │ ├── train.py │ ├── train_mem.py │ ├── train_xformers.py │ └── vstream_trainer.py └── utils.py ├── pyproject.toml ├── scripts ├── eval.sh ├── merge_lora_weights.py ├── realtime_cli.sh ├── train_and_eval.sh ├── zero0.json ├── zero1.json ├── zero2.json ├── zero3.json └── zero3_offload.json └── vstream.egg-info ├── PKG-INFO ├── SOURCES.txt ├── dependency_links.txt ├── requires.txt └── top_level.txt /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 
47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 
202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Flash-VStream: Memory-Based Real-Time Understanding for Long Video Streams 2 | 3 | 4 | Haoji Zhang\*, 5 | Yiqin Wang\*, 6 | Yansong Tang †, 7 | Yong Liu, 8 | Jiashi Feng, 9 | Jifeng Dai, 10 | Xiaojie Jin†‡ 11 | 12 | \* Equally contributing first authors, †Correspondence, ‡Project Lead 13 | 14 | **Work done when interning at Bytedance.** 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/flash-vstream-memory-based-real-time/zeroshot-video-question-answer-on-msvd-qa)](https://paperswithcode.com/sota/zeroshot-video-question-answer-on-msvd-qa?p=flash-vstream-memory-based-real-time) 26 | 27 | 28 | [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/flash-vstream-memory-based-real-time/zeroshot-video-question-answer-on-msrvtt-qa)](https://paperswithcode.com/sota/zeroshot-video-question-answer-on-msrvtt-qa?p=flash-vstream-memory-based-real-time) 29 | 30 | 31 | [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/flash-vstream-memory-based-real-time/question-answering-on-next-qa-open-ended)](https://paperswithcode.com/sota/question-answering-on-next-qa-open-ended?p=flash-vstream-memory-based-real-time) 32 | 33 | 34 | [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/flash-vstream-memory-based-real-time/zeroshot-video-question-answer-on-activitynet)](https://paperswithcode.com/sota/zeroshot-video-question-answer-on-activitynet?p=flash-vstream-memory-based-real-time) 35 | 36 | 37 | 38 | 39 | 40 | 41 | 42 | 43 | 44 | We present **Flash-VStream**, a novel LMM able to process extremely long video streams in real time and respond to user queries simultaneously. 45 | 46 | We also propose **VStream-QA**, a novel question answering benchmark specifically designed for online video streaming understanding. 47 | 48 |


54 | 55 | ## News 56 | - [2024/6/15] 🏅 Our team won 1st place in the [Long-Term Video Question Answering Challenge](https://sites.google.com/view/loveucvpr24/track1) of the [LOVEU Workshop@CVPR'24](https://sites.google.com/view/loveucvpr24/home). Here is our [certificate](https://github.com/bytedance/Flash-VStream/assets/37479394/e1496dec-52c8-4707-aabe-fd1970c8f874). 57 | We used a Hierarchical Memory model based on Flash-VStream-7b. 58 | 59 | - [2024/06/12] 🔥 Flash-VStream is coming! We release the 60 | [homepage](https://invinciblewyq.github.io/vstream-page), 61 | [paper](https://arxiv.org/abs/2406.08085v1), 62 | [code](https://github.com/IVGSZ/Flash-VStream) 63 | and [model](https://huggingface.co/IVGSZ/Flash-VStream-7b) 64 | for Flash-VStream. 65 | We also release the [dataset](https://huggingface.co/datasets/IVGSZ/VStream-QA) for the VStream-QA benchmark. 66 | 67 | ## Contents 68 | - [Install](#install) 69 | - [Model](#model) 70 | - [Preparation](#preparation) 71 | - [Train](#train) 72 | - [Evaluation](#evaluation) 73 | - [Real-time CLI Inference](#Real-time-CLI-Inference) 74 | - [VStream-QA Benchmark](#VStream-QA-Benchmark) 75 | - [Citation](#citation) 76 | - [Acknowledgement](#acknowledgement) 77 | - [License](#license) 78 | 79 | ## Install 80 | Please follow the instructions below to install the required packages. 81 | 1. Clone this repository 82 | 83 | 2. Install Package 84 | ```bash 85 | conda create -n vstream python=3.10 -y 86 | conda activate vstream 87 | cd Flash-VStream 88 | pip install --upgrade pip 89 | pip install -e . 90 | ``` 91 | 92 | 3. Install additional packages for training 93 | ```bash 94 | pip install ninja 95 | pip install flash-attn --no-build-isolation 96 | ``` 97 | 98 | ## Model 99 | 100 | We provide our Flash-VStream model after Stage 1 and Stage 2 training: 101 | 102 | | Model | Weight | Initialized from LLM | Initialized from ViT | 103 | | --- | --- | --- | --- | 104 | | Flash-VStream-7b | [Flash-VStream-7b](https://huggingface.co/IVGSZ/Flash-VStream-7b) | [lmsys/vicuna-7b-v1.5](https://huggingface.co/lmsys/vicuna-7b-v1.5) | [openai/clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) | 105 | 106 | 107 | ## Preparation 108 | 109 | ### Dataset 110 | 111 | **Image VQA Dataset.** 112 | Please organize the Image VQA training data following [this](https://github.com/haotian-liu/LLaVA/blob/main/docs/Data.md) and the evaluation data following [this](https://github.com/haotian-liu/LLaVA/blob/main/docs/Evaluation.md). 113 | Please put the pretraining data, finetuning data, and evaluation data in the `pretrain`, `finetune`, and `eval_video` folders following [Structure](#structure). 114 | 115 | **Video VQA Dataset.** 116 | Please download the 2.5M subset of [WebVid](https://maxbain.com/webvid-dataset/) and the ActivityNet dataset from the [official website](http://activity-net.org/download.html) or [Video-ChatGPT](https://github.com/mbzuai-oryx/Video-ChatGPT/blob/main/docs/train_video_chatgpt.md). 117 | 118 | If you want to perform evaluation, please also download the corresponding files of 119 | [ActivityNet-QA](https://github.com/mbzuai-oryx/Video-ChatGPT/blob/main/quantitative_evaluation/README.md) 120 | and [NExT-QA-OE](https://github.com/doc-doc/NExT-QA).
121 | You can download 122 | [MSVD-QA](https://mycuhk-my.sharepoint.com/:u:/g/personal/1155186668_link_cuhk_edu_hk/EUNEXqg8pctPq3WZPHb4Fd8BYIxHO5qPCnU6aWsrV1O4JQ?e=guynwu) 123 | and [MSRVTT-QA](https://mycuhk-my.sharepoint.com/:u:/g/personal/1155186668_link_cuhk_edu_hk/EcEXh1HfTXhLrRnuwHbl15IBJeRop-d50Q90njHmhvLwtA?e=SE24eG) from LLaMA-VID. 124 | 125 | 126 | **Meta Info.** 127 | For meta info of training data, please download the following files and organize them as in [Structure](#structure). 128 | 129 | | Training Stage | Data file name | Size | 130 | | --- | --- | ---: | 131 | | Pretrain | [llava_558k_with_webvid.json](https://huggingface.co/datasets/YanweiLi/LLaMA-VID-Data) | 254 MB | 132 | | Finetune | [llava_v1_5_mix665k_with_video_chatgpt.json](https://huggingface.co/datasets/YanweiLi/LLaMA-VID-Data) | 860 MB | 133 | 134 | For meta info of evaluation data, please reformat each QA list into a JSON file named `test_qa.json`, placed as shown in [Structure](#structure), with the following format: 135 | 136 | ```json 137 | [ 138 | { 139 | "video_id": "v_1QIUV7WYKXg", 140 | "question": "is the athlete wearing trousers", 141 | "id": "v_1QIUV7WYKXg_3", 142 | "answer": "no", 143 | "answer_type": 3, 144 | "duration": 9.88 145 | }, 146 | { 147 | "video_id": "v_9eniCub7u60", 148 | "question": "does the girl in black clothes have long hair", 149 | "id": "v_9eniCub7u60_2", 150 | "answer": "yes", 151 | "answer_type": 3, 152 | "duration": 19.43 153 | } 154 | ] 155 | ``` 156 | 157 | ### Pretrained Weights 158 | We recommend downloading the pretrained weights from the following links: 159 | [Vicuna-7b-v1.5](https://huggingface.co/lmsys/vicuna-7b-v1.5) 160 | and [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14), 161 | and putting them in the `ckpt` folder following [Structure](#structure). 162 | 163 | ### Feature Extraction 164 | 165 | We recommend extracting ViT features for the training and evaluation data in advance, which significantly accelerates training and evaluation. If you do so, simply replace the `.mp4` extension with `.safetensors` in each video filename and put the resulting files in the `image_features` and `video_features` folders; otherwise, ignore the `image_features` and `video_features` folders. A minimal extraction sketch is provided below, after the [VStream-QA Benchmark](#VStream-QA-Benchmark) section. 166 | 167 | Video features are loaded at 1 fps and arranged in temporal order. 168 | 169 | Each `.safetensors` file should contain a dict like this: 170 | 171 | ```python 172 | { 173 | 'feature': torch.Tensor with shape=[256, 1024] for an image and shape=[Length, 256, 1024] for a video. 174 | } 175 | ``` 176 | 177 | 178 | ### Structure 179 | The folder structure should be organized as follows before training.
180 | 181 | ``` 182 | Flash-VStream 183 | ├── checkpoints-finetune 184 | ├── checkpoints-pretrain 185 | ├── ckpt 186 | │ ├── clip-vit-large-patch14 187 | │ ├── vicuna-7b-v1.5 188 | ├── data 189 | │ ├── pretrain 190 | │ │ ├── llava_558k_with_webvid.json 191 | │ │ ├── image_features 192 | │ │ ├── images 193 | │ │ ├── video_features 194 | │ │ ├── videos 195 | │ ├── finetune 196 | │ │ ├── llava_v1_5_mix665k_with_video_chatgpt.json 197 | │ │ ├── activitynet 198 | │ │ ├── coco 199 | │ │ ├── gqa 200 | │ │ ├── image_features 201 | │ │ │ ├── coco 202 | │ │ │ ├── gqa 203 | │ │ │ ├── ocr_vqa 204 | │ │ │ ├── textvqa 205 | │ │ │ ├── vg 206 | │ │ ├── ocr_vqa 207 | │ │ ├── textvqa 208 | │ │ ├── vg 209 | │ │ ├── video_features 210 | │ │ │ ├── activitynet 211 | │ ├── eval_video 212 | │ │ ├── ActivityNet-QA 213 | │ │ │ ├── video_features 214 | │ │ │ ├── test_qa.json 215 | │ │ ├── MSRVTT-QA 216 | │ │ │ ├── video_features 217 | │ │ │ ├── test_qa.json 218 | │ │ ├── MSVD-QA 219 | │ │ │ ├── video_features 220 | │ │ │ ├── test_qa.json 221 | │ │ ├── nextoe 222 | │ │ │ ├── video_features 223 | │ │ │ ├── test_qa.json 224 | │ │ ├── vstream 225 | │ │ │ ├── video_features 226 | │ │ │ ├── test_qa.json 227 | │ │ ├── vstream-realtime 228 | │ │ │ ├── video_features 229 | │ │ │ ├── test_qa.json 230 | ├── flash_vstream 231 | ├── scripts 232 | 233 | ``` 234 | 235 | ## Train 236 | Flash-VStream is trained on 8 A100 GPUs with 80GB memory. To train on fewer GPUs, you can reduce the `per_device_train_batch_size` and increase the `gradient_accumulation_steps` accordingly. Always keep the global batch size the same: `per_device_train_batch_size` x `gradient_accumulation_steps` x `num_gpus`. If your GPUs have less than 80GB memory, you may try ZeRO-2 and ZeRO-3 stages. 237 | 238 | Please make sure you download and organize the data following [Preparation](#preparation) before training. 239 | 240 | Like LLaVA, Flash-VStream has two training stages: pretrain and finetune. Their checkpoints will be saved in `checkpoints-pretrain` and `checkpoints-finetune` folder. These two stages will take about 15 hours on 8 A100 GPUs in total. 241 | 242 | If you want to train Flash-VStream from pretrained LLM and evaluate it, please run the following command: 243 | 244 | ```bash 245 | bash scripts/train_and_eval.sh 246 | ``` 247 | 248 | ## Evaluation 249 | Please make sure you download and organize the data following [Preparation](#preparation) before evaluation. 250 | 251 | If you want to evaluate a Flash-VStream model, please run the following command: 252 | 253 | ```bash 254 | bash scripts/eval.sh 255 | ``` 256 | 257 | ## Real-time CLI Inference 258 | We provide a real-time CLI inference script, which simulates video stream input by reading frames of a video file at a fixed frame speed. You can ask any question and get the answer at any timestamp of the video stream. Run the following command and have a try: 259 | 260 | ```bash 261 | bash scripts/realtime_cli.sh 262 | ``` 263 | 264 | ## VStream-QA Benchmark 265 | Please download VStream-QA Benchmark following [this](https://huggingface.co/datasets/IVGSZ/VStream-QA) repo. 
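For convenience, here is a minimal sketch of how the per-video `.safetensors` feature files described in [Feature Extraction](#feature-extraction) could be produced. It assumes the features are CLIP ViT-L/14 patch embeddings with the CLS token dropped, sampled at roughly 1 fps; the exact layer selection and preprocessing used by the released checkpoints may differ, so please check `flash_vstream/model/multimodal_encoder/clip_encoder.py` before relying on it. The function name `extract_video_features` and the example paths are illustrative only.

```python
# Illustrative sketch: writes one {'feature': Tensor[Length, 256, 1024]} file per video.
import torch
from decord import VideoReader, cpu
from safetensors.torch import save_file
from transformers import CLIPImageProcessor, CLIPVisionModel

processor = CLIPImageProcessor.from_pretrained("openai/clip-vit-large-patch14")
encoder = CLIPVisionModel.from_pretrained("openai/clip-vit-large-patch14").eval()

@torch.no_grad()
def extract_video_features(video_path, output_path):
    vr = VideoReader(video_path, ctx=cpu(0))
    fps = round(vr.get_avg_fps())
    frames = vr.get_batch(list(range(0, len(vr), fps))).asnumpy()   # sample ~1 fps, [Length, H, W, 3]
    pixel_values = processor(images=list(frames), return_tensors="pt")["pixel_values"]
    hidden = encoder(pixel_values=pixel_values).last_hidden_state   # [Length, 257, 1024]
    feature = hidden[:, 1:, :].contiguous().cpu()                   # drop CLS -> [Length, 256, 1024]
    save_file({"feature": feature}, output_path)

# e.g. extract_video_features("data/finetune/videos/xxx.mp4",
#                             "data/finetune/video_features/xxx.safetensors")
```

For long videos, you may want to encode frames in smaller batches to keep memory usage bounded.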
266 | 267 | ## Citation 268 | If you find this project useful in your research, please consider citing: 269 | 270 | ``` 271 | @article{flashvstream, 272 | title={Flash-VStream: Memory-Based Real-Time Understanding for Long Video Streams}, 273 | author={Haoji Zhang and Yiqin Wang and Yansong Tang and Yong Liu and Jiashi Feng and Jifeng Dai and Xiaojie Jin}, 274 | year={2024}, 275 | eprint={2406.08085}, 276 | archivePrefix={arXiv}, 277 | primaryClass={cs.CV} 278 | } 279 | ``` 280 | 281 | ## Acknowledgement 282 | We would like to thank the following repos for their great work: 283 | 284 | - This work is built upon [LLaVA](https://github.com/haotian-liu/LLaVA). 285 | - This work utilizes LLMs from [Vicuna](https://github.com/lm-sys/FastChat). 286 | - Some code is borrowed from [LLaMA-VID](https://github.com/dvlab-research/LLaMA-VID). 287 | - Our video-based evaluation follows [Video-ChatGPT](https://github.com/mbzuai-oryx/Video-ChatGPT). 288 | 289 | ## License 290 | [![Code License](https://img.shields.io/badge/Code%20License-Apache_2.0-yellow.svg)](LICENSE) 291 | 292 | This project is licensed under the [Apache-2.0 License](LICENSE). 293 | 294 | -------------------------------------------------------------------------------- /assets/example.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IVGSZ/Flash-VStream/be61d780acb39f29a7193935cf36c210cfc16695/assets/example.mp4 -------------------------------------------------------------------------------- /assets/radar_comparison.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IVGSZ/Flash-VStream/be61d780acb39f29a7193935cf36c210cfc16695/assets/radar_comparison.png -------------------------------------------------------------------------------- /assets/teaser2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IVGSZ/Flash-VStream/be61d780acb39f29a7193935cf36c210cfc16695/assets/teaser2.png -------------------------------------------------------------------------------- /flash_vstream/__init__.py: -------------------------------------------------------------------------------- 1 | from flash_vstream.model import VStreamLlamaForCausalLM 2 | -------------------------------------------------------------------------------- /flash_vstream/constants.py: -------------------------------------------------------------------------------- 1 | # Based on https://github.com/haotian-liu/LLaVA. 2 | 3 | CONTROLLER_HEART_BEAT_EXPIRATION = 30 4 | WORKER_HEART_BEAT_INTERVAL = 15 5 | 6 | LOGDIR = "." 7 | 8 | # Model Constants 9 | IGNORE_INDEX = -100 10 | IMAGE_TOKEN_INDEX = -200 11 | DEFAULT_IMAGE_TOKEN = "<image>" 12 | DEFAULT_IMAGE_PATCH_TOKEN = "<im_patch>" 13 | DEFAULT_IM_START_TOKEN = "<im_start>" 14 | DEFAULT_IM_END_TOKEN = "<im_end>" 15 | IMAGE_PLACEHOLDER = "<image-placeholder>" 16 | -------------------------------------------------------------------------------- /flash_vstream/conversation.py: -------------------------------------------------------------------------------- 1 | # Based on https://github.com/haotian-liu/LLaVA.
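# This module defines the Conversation dataclass and the prompt templates ("v0",
# "v1"/"vicuna_v1", "llama_2", "plain") registered in conv_templates, which serialize
# chat history into a single prompt string. Typical usage (illustrative):
#   conv = conv_templates["vicuna_v1"].copy()
#   conv.append_message(conv.roles[0], DEFAULT_IMAGE_TOKEN + "\n" + question)
#   conv.append_message(conv.roles[1], None)
#   prompt = conv.get_prompt()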
2 | 3 | import dataclasses 4 | from enum import auto, Enum 5 | from typing import List, Tuple 6 | 7 | 8 | class SeparatorStyle(Enum): 9 | """Different separator style.""" 10 | SINGLE = auto() 11 | TWO = auto() 12 | MPT = auto() 13 | PLAIN = auto() 14 | LLAMA_2 = auto() 15 | 16 | 17 | @dataclasses.dataclass 18 | class Conversation: 19 | """A class that keeps all conversation history.""" 20 | system: str 21 | roles: List[str] 22 | messages: List[List[str]] 23 | offset: int 24 | sep_style: SeparatorStyle = SeparatorStyle.SINGLE 25 | sep: str = "###" 26 | sep2: str = None 27 | version: str = "Unknown" 28 | 29 | skip_next: bool = False 30 | 31 | def get_prompt(self): 32 | messages = self.messages 33 | if len(messages) > 0 and type(messages[0][1]) is tuple: 34 | messages = self.messages.copy() 35 | init_role, init_msg = messages[0].copy() 36 | init_msg = init_msg[0].replace("", "").strip() 37 | if 'mmtag' in self.version: 38 | messages[0] = (init_role, init_msg) 39 | messages.insert(0, (self.roles[0], "")) 40 | messages.insert(1, (self.roles[1], "Received.")) 41 | else: 42 | messages[0] = (init_role, "\n" + init_msg) 43 | 44 | if self.sep_style == SeparatorStyle.SINGLE: 45 | ret = self.system + self.sep 46 | for role, message in messages: 47 | if message: 48 | if type(message) is tuple: 49 | message, _, _ = message 50 | ret += role + ": " + message + self.sep 51 | else: 52 | ret += role + ":" 53 | elif self.sep_style == SeparatorStyle.TWO: 54 | seps = [self.sep, self.sep2] 55 | ret = self.system + seps[0] 56 | for i, (role, message) in enumerate(messages): 57 | if message: 58 | if type(message) is tuple: 59 | message, _, _ = message 60 | ret += role + ": " + message + seps[i % 2] 61 | else: 62 | ret += role + ":" 63 | elif self.sep_style == SeparatorStyle.MPT: 64 | ret = self.system + self.sep 65 | for role, message in messages: 66 | if message: 67 | if type(message) is tuple: 68 | message, _, _ = message 69 | ret += role + message + self.sep 70 | else: 71 | ret += role 72 | elif self.sep_style == SeparatorStyle.LLAMA_2: 73 | wrap_sys = lambda msg: f"<>\n{msg}\n<>\n\n" 74 | wrap_inst = lambda msg: f"[INST] {msg} [/INST]" 75 | ret = "" 76 | 77 | for i, (role, message) in enumerate(messages): 78 | if i == 0: 79 | assert message, "first message should not be none" 80 | assert role == self.roles[0], "first message should come from user" 81 | if message: 82 | if type(message) is tuple: 83 | message, _, _ = message 84 | if i == 0: message = wrap_sys(self.system) + message 85 | if i % 2 == 0: 86 | message = wrap_inst(message) 87 | ret += self.sep + message 88 | else: 89 | ret += " " + message + " " + self.sep2 90 | else: 91 | ret += "" 92 | ret = ret.lstrip(self.sep) 93 | elif self.sep_style == SeparatorStyle.PLAIN: 94 | seps = [self.sep, self.sep2] 95 | ret = self.system 96 | for i, (role, message) in enumerate(messages): 97 | if message: 98 | if type(message) is tuple: 99 | message, _, _ = message 100 | ret += message + seps[i % 2] 101 | else: 102 | ret += "" 103 | else: 104 | raise ValueError(f"Invalid style: {self.sep_style}") 105 | 106 | return ret 107 | 108 | def append_message(self, role, message): 109 | self.messages.append([role, message]) 110 | 111 | def get_images(self, return_pil=False): 112 | images = [] 113 | for i, (role, msg) in enumerate(self.messages[self.offset:]): 114 | if i % 2 == 0: 115 | if type(msg) is tuple: 116 | import base64 117 | from io import BytesIO 118 | from PIL import Image 119 | msg, image, image_process_mode = msg 120 | if image_process_mode == "Pad": 121 | def 
expand2square(pil_img, background_color=(122, 116, 104)): 122 | width, height = pil_img.size 123 | if width == height: 124 | return pil_img 125 | elif width > height: 126 | result = Image.new(pil_img.mode, (width, width), background_color) 127 | result.paste(pil_img, (0, (width - height) // 2)) 128 | return result 129 | else: 130 | result = Image.new(pil_img.mode, (height, height), background_color) 131 | result.paste(pil_img, ((height - width) // 2, 0)) 132 | return result 133 | image = expand2square(image) 134 | elif image_process_mode in ["Default", "Crop"]: 135 | pass 136 | elif image_process_mode == "Resize": 137 | image = image.resize((336, 336)) 138 | else: 139 | raise ValueError(f"Invalid image_process_mode: {image_process_mode}") 140 | max_hw, min_hw = max(image.size), min(image.size) 141 | aspect_ratio = max_hw / min_hw 142 | max_len, min_len = 800, 400 143 | shortest_edge = int(min(max_len / aspect_ratio, min_len, min_hw)) 144 | longest_edge = int(shortest_edge * aspect_ratio) 145 | W, H = image.size 146 | if longest_edge != max(image.size): 147 | if H > W: 148 | H, W = longest_edge, shortest_edge 149 | else: 150 | H, W = shortest_edge, longest_edge 151 | image = image.resize((W, H)) 152 | if return_pil: 153 | images.append(image) 154 | else: 155 | buffered = BytesIO() 156 | image.save(buffered, format="PNG") 157 | img_b64_str = base64.b64encode(buffered.getvalue()).decode() 158 | images.append(img_b64_str) 159 | return images 160 | 161 | def to_gradio_chatbot(self): 162 | ret = [] 163 | for i, (role, msg) in enumerate(self.messages[self.offset:]): 164 | if i % 2 == 0: 165 | if type(msg) is tuple: 166 | import base64 167 | from io import BytesIO 168 | msg, image, image_process_mode = msg 169 | max_hw, min_hw = max(image.size), min(image.size) 170 | aspect_ratio = max_hw / min_hw 171 | max_len, min_len = 800, 400 172 | shortest_edge = int(min(max_len / aspect_ratio, min_len, min_hw)) 173 | longest_edge = int(shortest_edge * aspect_ratio) 174 | W, H = image.size 175 | if H > W: 176 | H, W = longest_edge, shortest_edge 177 | else: 178 | H, W = shortest_edge, longest_edge 179 | image = image.resize((W, H)) 180 | buffered = BytesIO() 181 | image.save(buffered, format="JPEG") 182 | img_b64_str = base64.b64encode(buffered.getvalue()).decode() 183 | img_str = f'user upload image' 184 | msg = img_str + msg.replace('', '').strip() 185 | ret.append([msg, None]) 186 | else: 187 | ret.append([msg, None]) 188 | else: 189 | ret[-1][-1] = msg 190 | return ret 191 | 192 | def copy(self): 193 | return Conversation( 194 | system=self.system, 195 | roles=self.roles, 196 | messages=[[x, y] for x, y in self.messages], 197 | offset=self.offset, 198 | sep_style=self.sep_style, 199 | sep=self.sep, 200 | sep2=self.sep2, 201 | version=self.version) 202 | 203 | def dict(self): 204 | if len(self.get_images()) > 0: 205 | return { 206 | "system": self.system, 207 | "roles": self.roles, 208 | "messages": [[x, y[0] if type(y) is tuple else y] for x, y in self.messages], 209 | "offset": self.offset, 210 | "sep": self.sep, 211 | "sep2": self.sep2, 212 | } 213 | return { 214 | "system": self.system, 215 | "roles": self.roles, 216 | "messages": self.messages, 217 | "offset": self.offset, 218 | "sep": self.sep, 219 | "sep2": self.sep2, 220 | } 221 | 222 | 223 | conv_vicuna_v0 = Conversation( 224 | system="A chat between a curious human and an artificial intelligence assistant. 
" 225 | "The assistant gives helpful, detailed, and polite answers to the human's questions.", 226 | roles=("Human", "Assistant"), 227 | messages=( 228 | ("Human", "What are the key differences between renewable and non-renewable energy sources?"), 229 | ("Assistant", 230 | "Renewable energy sources are those that can be replenished naturally in a relatively " 231 | "short amount of time, such as solar, wind, hydro, geothermal, and biomass. " 232 | "Non-renewable energy sources, on the other hand, are finite and will eventually be " 233 | "depleted, such as coal, oil, and natural gas. Here are some key differences between " 234 | "renewable and non-renewable energy sources:\n" 235 | "1. Availability: Renewable energy sources are virtually inexhaustible, while non-renewable " 236 | "energy sources are finite and will eventually run out.\n" 237 | "2. Environmental impact: Renewable energy sources have a much lower environmental impact " 238 | "than non-renewable sources, which can lead to air and water pollution, greenhouse gas emissions, " 239 | "and other negative effects.\n" 240 | "3. Cost: Renewable energy sources can be more expensive to initially set up, but they typically " 241 | "have lower operational costs than non-renewable sources.\n" 242 | "4. Reliability: Renewable energy sources are often more reliable and can be used in more remote " 243 | "locations than non-renewable sources.\n" 244 | "5. Flexibility: Renewable energy sources are often more flexible and can be adapted to different " 245 | "situations and needs, while non-renewable sources are more rigid and inflexible.\n" 246 | "6. Sustainability: Renewable energy sources are more sustainable over the long term, while " 247 | "non-renewable sources are not, and their depletion can lead to economic and social instability.\n") 248 | ), 249 | offset=2, 250 | sep_style=SeparatorStyle.SINGLE, 251 | sep="###", 252 | ) 253 | 254 | conv_vicuna_v1 = Conversation( 255 | system="A chat between a curious user and an artificial intelligence assistant. " 256 | "The assistant gives helpful, detailed, and polite answers to the user's questions.", 257 | roles=("USER", "ASSISTANT"), 258 | version="v1", 259 | messages=(), 260 | offset=0, 261 | sep_style=SeparatorStyle.TWO, 262 | sep=" ", 263 | sep2="", 264 | ) 265 | 266 | conv_vicuna_v1_mcq = Conversation( 267 | system="A chat between a curious user and an artificial intelligence assistant. " 268 | "The assistant gives helpful, detailed, and polite answers to the user's questions. " 269 | "The assistant should give the number of correct answer.", 270 | roles=("USER", "ASSISTANT"), 271 | version="v1", 272 | messages=(), 273 | offset=0, 274 | sep_style=SeparatorStyle.TWO, 275 | sep=" ", 276 | sep2="", 277 | ) 278 | 279 | conv_tiny = Conversation( 280 | system="""<|system|> 281 | A conversation between a user and an AI assistant. The assistant gives short and honest answers.""", 282 | roles=("<|user|>\n", "<|assistant|>\n"), 283 | version="mpt", 284 | messages=(), 285 | offset=0, 286 | sep_style=SeparatorStyle.MPT, 287 | sep="", 288 | ) 289 | 290 | conv_llama_2 = Conversation( 291 | system="""You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature. 
292 | 293 | If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information.""", 294 | roles=("USER", "ASSISTANT"), 295 | version="llama_v2", 296 | messages=(), 297 | offset=0, 298 | sep_style=SeparatorStyle.LLAMA_2, 299 | sep="", 300 | sep2="", 301 | ) 302 | 303 | conv_mpt = Conversation( 304 | system="""<|im_start|>system 305 | A conversation between a user and an LLM-based AI assistant. The assistant gives helpful and honest answers.""", 306 | roles=("<|im_start|>user\n", "<|im_start|>assistant\n"), 307 | version="mpt", 308 | messages=(), 309 | offset=0, 310 | sep_style=SeparatorStyle.MPT, 311 | sep="<|im_end|>", 312 | ) 313 | 314 | conv_plain = Conversation( 315 | system="", 316 | roles=("", ""), 317 | messages=( 318 | ), 319 | offset=0, 320 | sep_style=SeparatorStyle.PLAIN, 321 | sep="\n", 322 | ) 323 | 324 | 325 | default_conversation = conv_vicuna_v1 326 | conv_templates = { 327 | "default": conv_vicuna_v0, 328 | "v0": conv_vicuna_v0, 329 | "v1": conv_vicuna_v1, 330 | "vicuna_v1": conv_vicuna_v1, 331 | "llama_2": conv_llama_2, 332 | "plain": conv_plain, 333 | } 334 | 335 | 336 | if __name__ == "__main__": 337 | print(default_conversation.get_prompt()) 338 | -------------------------------------------------------------------------------- /flash_vstream/eval_video/eval_activitynet_qa.py: -------------------------------------------------------------------------------- 1 | # Based on https://github.com/haotian-liu/LLaVA. 2 | 3 | import os 4 | import ast 5 | import json 6 | import openai 7 | import argparse 8 | from tqdm import tqdm 9 | from time import sleep 10 | from collections import defaultdict 11 | from multiprocessing.pool import Pool 12 | 13 | def parse_args(): 14 | parser = argparse.ArgumentParser(description="question-answer-generation-using-gpt-3") 15 | parser.add_argument("--pred_path", required=True, help="The path to file containing prediction.") 16 | parser.add_argument("--output_dir", required=True, help="The path to save annotation json files.") 17 | parser.add_argument("--output_json", required=True, help="The path to save annotation final combined json file.") 18 | parser.add_argument("--num_tasks", required=True, type=int, help="Number of splits.") 19 | parser.add_argument("--num_chunks", default=1, type=int, help="Result splits") 20 | parser.add_argument("--api_key", required=True, type=str, help="OpenAI API key") 21 | parser.add_argument("--api_type", default=None, type=str, help="OpenAI API type") 22 | parser.add_argument("--api_version", default=None, type=str, help="OpenAI API version") 23 | parser.add_argument("--api_base", default=None, type=str, help="OpenAI API base") 24 | args = parser.parse_args() 25 | return args 26 | 27 | 28 | def annotate(prediction_set, caption_files, output_dir): 29 | """ 30 | Evaluates question and answer pairs using GPT-3 31 | Returns a score for correctness. 32 | """ 33 | for file in tqdm(caption_files): 34 | key = file[:-5] # Strip file extension 35 | qa_set = prediction_set[key] 36 | question = qa_set['q'] 37 | answer = qa_set['a'] 38 | pred = qa_set['pred'] 39 | try: 40 | # Compute the correctness score 41 | completion = openai.ChatCompletion.create( 42 | model="gpt-3.5-turbo", 43 | messages=[ 44 | { 45 | "role": "system", 46 | "content": 47 | "You are an intelligent chatbot designed for evaluating the correctness of generative outputs for question-answer pairs. 
" 48 | "Your task is to compare the predicted answer with the correct answer and determine if they match meaningfully. Here's how you can accomplish the task:" 49 | "------" 50 | "##INSTRUCTIONS: " 51 | "- Focus on the meaningful match between the predicted answer and the correct answer.\n" 52 | "- Consider synonyms or paraphrases as valid matches.\n" 53 | "- Evaluate the correctness of the prediction compared to the answer." 54 | }, 55 | { 56 | "role": "user", 57 | "content": 58 | "Please evaluate the following video-based question-answer pair:\n\n" 59 | f"Question: {question}\n" 60 | f"Correct Answer: {answer}\n" 61 | f"Predicted Answer: {pred}\n\n" 62 | "Provide your evaluation only as a yes/no and score where the score is an integer value between 0 and 5, with 5 indicating the highest meaningful match. " 63 | "Please generate the response in the form of a Python dictionary string with keys 'pred' and 'score', where value of 'pred' is a string of 'yes' or 'no' and value of 'score' is in INTEGER, not STRING." 64 | "DO NOT PROVIDE ANY OTHER OUTPUT TEXT OR EXPLANATION. Only provide the Python dictionary string. " 65 | "For example, your response should look like this: {'pred': 'yes', 'score': 4.8}." 66 | } 67 | ], 68 | temperature=0.002 69 | ) 70 | # Convert response to a Python dictionary. 71 | response_message = completion["choices"][0]["message"]["content"] 72 | response_dict = ast.literal_eval(response_message) 73 | result_qa_pair = [response_dict, qa_set] 74 | 75 | # Save the question-answer pairs to a json file. 76 | with open(f"{output_dir}/{key}.json", "w") as f: 77 | json.dump(result_qa_pair, f) 78 | sleep(0.5) 79 | 80 | except Exception as e: 81 | print(f"Error processing file '{key}': {e}") 82 | sleep(1) 83 | 84 | 85 | def main(): 86 | """ 87 | Main function to control the flow of the program. 88 | """ 89 | # Parse arguments. 90 | args = parse_args() 91 | 92 | if args.num_chunks > 1: 93 | pred_contents = [] 94 | for _idx in range(args.num_chunks): 95 | file = os.path.join(args.pred_path, f"{args.num_chunks}_{_idx}.json") 96 | pred_contents += [json.loads(line) for line in open(file)] 97 | 98 | else: 99 | file = os.path.join(args.pred_path, f"pred.json") 100 | pred_contents = [json.loads(line) for line in open(file)] 101 | 102 | # Dictionary to store the count of occurrences for each video_id 103 | video_id_counts = {} 104 | new_pred_contents = [] 105 | 106 | # Iterate through each sample in pred_contents 107 | for sample in pred_contents: 108 | video_id = sample['id'] 109 | if video_id in video_id_counts: 110 | video_id_counts[video_id] += 1 111 | else: 112 | video_id_counts[video_id] = 0 113 | 114 | # Create a new sample with the modified key 115 | new_sample = sample 116 | new_sample['id'] = f"{video_id}_{video_id_counts[video_id]}" 117 | new_pred_contents.append(new_sample) 118 | 119 | # Generating list of id's and corresponding files 120 | id_list = [x['id'] for x in new_pred_contents] 121 | caption_files = [f"{id}.json" for id in id_list] 122 | 123 | output_dir = args.output_dir 124 | # Generate output directory if not exists. 
125 | if not os.path.exists(output_dir): 126 | os.makedirs(output_dir) 127 | 128 | # Preparing dictionary of question-answer sets 129 | prediction_set = {} 130 | for sample in new_pred_contents: 131 | id = sample['id'] 132 | question = sample['question'] 133 | answer = sample['answer'] 134 | pred = sample['pred'] 135 | qa_set = {"q": question, "a": answer, "pred": pred, "a_type": sample['answer_type'] if 'answer_type' in sample else None} 136 | prediction_set[id] = qa_set 137 | 138 | # Set the OpenAI API key. 139 | openai.api_key = args.api_key # Your API key here 140 | if args.api_type: 141 | openai.api_type = args.api_type 142 | if args.api_version: 143 | openai.api_version = args.api_version 144 | if args.api_base: 145 | openai.api_base = args.api_base # Your API base here 146 | num_tasks = args.num_tasks 147 | 148 | # While loop to ensure that all captions are processed. 149 | incomplete_lengths = [] 150 | for _ in range(100): 151 | try: 152 | # Files that have not been processed yet. 153 | completed_files = os.listdir(output_dir) 154 | print(f"completed_files: {len(completed_files)}") 155 | 156 | # Files that have not been processed yet. 157 | incomplete_files = [f for f in caption_files if f not in completed_files] 158 | print(f"incomplete_files: {len(incomplete_files)}") 159 | incomplete_lengths.append(len(incomplete_files)) 160 | if len(incomplete_lengths) > 5 and len(set(incomplete_lengths[-5:])) <= 1: 161 | print(f"incomplete_lengths: {incomplete_lengths}") 162 | print(f"incomplete_files: {incomplete_files}") 163 | print(f"completed_files: {completed_files}") 164 | print(f"failed for 5 times, break") 165 | break 166 | 167 | # Break the loop when there are no incomplete files 168 | if len(incomplete_files) == 0: 169 | break 170 | if len(incomplete_files) <= num_tasks: 171 | num_tasks = 1 172 | 173 | # Split tasks into parts. 174 | part_len = len(incomplete_files) // num_tasks 175 | all_parts = [incomplete_files[i:i + part_len] for i in range(0, len(incomplete_files), part_len)] 176 | task_args = [(prediction_set, part, args.output_dir) for part in all_parts] 177 | 178 | # Use a pool of workers to process the files in parallel. 
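# Each worker process runs annotate() on its own slice of the incomplete files, sending one
# gpt-3.5-turbo request per QA pair and writing the verdict to <output_dir>/<id>.json; the
# enclosing retry loop then re-scans output_dir to find files that still need processing.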
179 | with Pool() as pool: 180 | pool.starmap(annotate, task_args) 181 | 182 | except Exception as e: 183 | print(f"Error: {e}") 184 | 185 | # Combine all the processed files into one 186 | combined_contents = {} 187 | json_path = args.output_json 188 | 189 | # Iterate through json files 190 | for file_name in os.listdir(output_dir): 191 | if file_name.endswith(".json"): 192 | file_path = os.path.join(output_dir, file_name) 193 | with open(file_path, "r") as json_file: 194 | content = json.load(json_file) 195 | assert 'pred' in content[0], f"Error: {file_name} don't has key=pred" 196 | assert 'score' in content[0], f"Error: {file_name} don't has key=score" 197 | combined_contents[file_name[:-5]] = content 198 | 199 | # Write combined content to a json file 200 | with open(json_path, "w") as json_file: 201 | json.dump(combined_contents, json_file) 202 | print("All evaluation completed!") 203 | 204 | class ScoreMeter: 205 | def __init__(self): 206 | self.score_sum = 0 207 | self.count = 0 208 | self.yes_count = 0 209 | self.no_count = 0 210 | self.score_dict = {'yes': defaultdict(int), 'no': defaultdict(int)} 211 | 212 | def add_score(self, score, pred): 213 | self.score_sum += score 214 | self.count += 1 215 | pred_lower = pred.lower() 216 | if 'yes' in pred_lower: 217 | self.yes_count += 1 218 | self.score_dict['yes'][score] += 1 219 | elif 'no' in pred_lower: 220 | self.no_count += 1 221 | self.score_dict['no'][score] += 1 222 | 223 | def get_average_score(self): 224 | res = (self.score_sum / self.count) if self.count else 0 225 | return f"{res:.6f}" 226 | 227 | def get_accuracy(self, response_type): 228 | if response_type == 'yes': 229 | res = (self.yes_count / self.count) if self.count else 0 230 | elif response_type == 'no': 231 | res = (self.no_count / self.count) if self.count else 0 232 | else: 233 | res = 0 234 | return f"{res:.6f}" 235 | 236 | meter_dic = {'total': ScoreMeter()} 237 | for key, result in combined_contents.items(): 238 | # Computing score 239 | score_match = result[0]['score'] 240 | score = int(score_match) 241 | pred = result[0]['pred'] 242 | 243 | meter_dic["total"].add_score(score, pred) 244 | if 'a_type' in result[1] and result[1]['a_type'] is not None: 245 | typ = str(result[1]['a_type']) 246 | if typ not in meter_dic: 247 | meter_dic[typ] = ScoreMeter() 248 | meter_dic[typ].add_score(score, pred) 249 | 250 | if 'next' in args.output_dir: 251 | typ = typ[0] 252 | if typ not in meter_dic: 253 | meter_dic[typ] = ScoreMeter() 254 | meter_dic[typ].add_score(score, pred) 255 | 256 | csv_dic = {'acc': meter_dic["total"].get_accuracy('yes'), 'score': meter_dic["total"].get_average_score()} 257 | 258 | output = "" 259 | output += "Yes count: " + str(meter_dic["total"].yes_count) + "\n" 260 | output += "No count: " + str(meter_dic["total"].no_count) + "\n" 261 | output += "Accuracy: " + str(meter_dic["total"].get_accuracy('yes')) + "\n" 262 | output += "Average score: " + str(meter_dic["total"].get_average_score()) + "\n" 263 | output += "\n" 264 | output += "Total Score Yes/No distribution:\n" 265 | for key, value in meter_dic["total"].score_dict.items(): 266 | output += f"{key}:\n" 267 | for k in range(0, 6): 268 | v = value[k] 269 | output += f"{k}: {v}\n" 270 | output += "\n" 271 | output += "Answer Type Score distribution:\n" 272 | output += 'Type, Accuracy, Avg_score\n' 273 | key_list = sorted([k for k in meter_dic.keys()]) 274 | for key in key_list: 275 | output += f"{key}, {meter_dic[key].get_accuracy('yes')}, {meter_dic[key].get_average_score()}\n" 276 | 
csv_dic[key] = meter_dic[key].get_accuracy('yes') 277 | 278 | output += "\n" 279 | for k in csv_dic.keys(): 280 | output += f"{k}, " 281 | output = output.rstrip(', ') # Remove the trailing comma and space 282 | output += "\n" 283 | 284 | for k in csv_dic.keys(): 285 | output += str(csv_dic[k]) + ", " 286 | output = output.rstrip(', ') # Remove the trailing comma and space 287 | output += "\n" 288 | 289 | print(output) 290 | args.output_csv = args.output_json.replace(".json", ".csv") 291 | with open(args.output_csv, 'w') as f: 292 | f.write(output) 293 | 294 | if __name__ == "__main__": 295 | main() 296 | 297 | -------------------------------------------------------------------------------- /flash_vstream/eval_video/eval_any_dataset_features.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Flash-VStream Authors 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import os 16 | import argparse 17 | import subprocess 18 | import multiprocessing 19 | 20 | def exec(cmd, sub=False, device=None): 21 | print(f'exec: {cmd}') 22 | if not sub: 23 | if isinstance(cmd, list): 24 | cmd = ' '.join(cmd) 25 | os.system(cmd) 26 | else: 27 | my_env = os.environ.copy() 28 | my_env["CUDA_VISIBLE_DEVICES"] = device 29 | subprocess.run(cmd, env=my_env) 30 | 31 | # multi gpu, feature 32 | def eval_msvd(args): 33 | model_path = args.model_path 34 | num_chunks = args.num_chunks 35 | if not args.only_eval: 36 | processes = [] 37 | for idx in range(0, num_chunks): 38 | cmd = ["python", "llama_vstream/eval_video/model_msvd_qa_featuresloader.py", 39 | "--model-path", model_path, 40 | "--video_dir", "./data/eval_video/MSVD-QA/video_features", 41 | "--gt_file", "./data/eval_video/MSVD-QA/test_qa.json", 42 | "--output_dir", os.path.join(model_path, "evaluation", "msvd"), 43 | "--output_name", "pred", 44 | "--num-chunks", str(num_chunks), 45 | "--chunk-idx", str(idx), 46 | "--conv-mode", "vicuna_v1"] 47 | p = multiprocessing.Process(target=exec, args=(cmd, True, str(idx))) 48 | processes.append(p) 49 | p.start() # 启动子进程 50 | for p in processes: 51 | p.join() 52 | cmd = ["python", "llama_vstream/eval_video/eval_activitynet_qa.py", 53 | "--pred_path", os.path.join(model_path, "evaluation", "msvd"), 54 | "--output_dir", os.path.join(model_path, "evaluation", "msvd", "results"), 55 | "--output_json", os.path.join(model_path, "evaluation", "msvd", "results.json"), 56 | "--num_chunks", str(num_chunks), 57 | "--num_tasks", "16", 58 | "--api_key", args.api_key, 59 | "--api_base", args.api_base, 60 | "--api_type", args.api_type, 61 | "--api_version", args.api_version, 62 | ] 63 | exec(cmd) 64 | 65 | # multi gpu, feature 66 | def eval_msrvtt(args): 67 | model_path = args.model_path 68 | num_chunks = args.num_chunks 69 | if not args.only_eval: 70 | processes = [] 71 | for idx in range(0, num_chunks): 72 | cmd = ["python", "llama_vstream/eval_video/model_msvd_qa_featuresloader.py", 73 | "--model-path", model_path, 74 | "--video_dir", 
"./data/eval_video/MSRVTT-QA/video_features", 75 | "--gt_file", "./data/eval_video/MSRVTT-QA/test_qa.json", 76 | "--output_dir", os.path.join(model_path, "evaluation", "msrvtt"), 77 | "--output_name", "pred", 78 | "--num-chunks", str(num_chunks), 79 | "--chunk-idx", str(idx), 80 | "--conv-mode", "vicuna_v1"] 81 | p = multiprocessing.Process(target=exec, args=(cmd, True, str(idx))) 82 | processes.append(p) 83 | p.start() # 启动子进程 84 | for p in processes: 85 | p.join() 86 | cmd = ["python", "llama_vstream/eval_video/eval_activitynet_qa.py", 87 | "--pred_path", os.path.join(model_path, "evaluation", "msrvtt"), 88 | "--output_dir", os.path.join(model_path, "evaluation", "msrvtt", "results"), 89 | "--output_json", os.path.join(model_path, "evaluation", "msrvtt", "results.json"), 90 | "--num_chunks", str(num_chunks), 91 | "--num_tasks", "16", 92 | "--api_key", args.api_key, 93 | "--api_base", args.api_base, 94 | "--api_type", args.api_type, 95 | "--api_version", args.api_version, 96 | ] 97 | exec(cmd) 98 | 99 | # multi gpu, feature 100 | def eval_actnet(args): 101 | model_path = args.model_path 102 | num_chunks = args.num_chunks 103 | if not args.only_eval: 104 | processes = [] 105 | for idx in range(0, num_chunks): 106 | cmd = ["python", "llama_vstream/eval_video/model_msvd_qa_featuresloader.py", 107 | "--model-path", model_path, 108 | "--video_dir", "./data/eval_video/ActivityNet-QA/video_features", 109 | "--gt_file", "./data/eval_video/ActivityNet-QA/test_qa.json", 110 | "--output_dir", os.path.join(model_path, "evaluation", "actnet"), 111 | "--output_name", "pred", 112 | "--num-chunks", str(num_chunks), 113 | "--chunk-idx", str(idx), 114 | "--conv-mode", "vicuna_v1", 115 | ] 116 | 117 | p = multiprocessing.Process(target=exec, args=(cmd, True, str(idx))) 118 | processes.append(p) 119 | p.start() # 启动子进程 120 | for p in processes: 121 | p.join() 122 | cmd = ["python", "llama_vstream/eval_video/eval_activitynet_qa.py", 123 | "--pred_path", os.path.join(model_path, "evaluation", "actnet"), 124 | "--output_dir", os.path.join(model_path, "evaluation", "actnet", "results"), 125 | "--output_json", os.path.join(model_path, "evaluation", "actnet", "results.json"), 126 | "--num_chunks", str(num_chunks), 127 | "--num_tasks", "16", 128 | "--api_key", args.api_key, 129 | "--api_base", args.api_base, 130 | "--api_type", args.api_type, 131 | "--api_version", args.api_version, 132 | ] 133 | exec(cmd) 134 | 135 | # multi gpu, feature 136 | def eval_nextoe(args): # follow msvd format, OE follow actnet 137 | model_path = args.model_path 138 | num_chunks = args.num_chunks 139 | if not args.only_eval: 140 | processes = [] 141 | for idx in range(0, num_chunks): 142 | cmd = ["python", "llama_vstream/eval_video/model_msvd_qa_featuresloader.py", 143 | "--model-path", model_path, 144 | "--video_dir", "./data/eval_video/nextoe/video_features", 145 | "--gt_file", "./data/eval_video/nextoe/test_qa.json", 146 | "--output_dir", os.path.join(model_path, "evaluation", "nextoe"), 147 | "--output_name", "pred", 148 | "--num-chunks", str(num_chunks), 149 | "--chunk-idx", str(idx), 150 | "--conv-mode", "vicuna_v1", 151 | ] 152 | 153 | p = multiprocessing.Process(target=exec, args=(cmd, True, str(idx))) 154 | processes.append(p) 155 | p.start() # 启动子进程 156 | for p in processes: 157 | p.join() 158 | cmd = ["python", "llama_vstream/eval_video/eval_activitynet_qa.py", 159 | "--pred_path", os.path.join(model_path, "evaluation", "nextoe"), 160 | "--output_dir", os.path.join(model_path, "evaluation", "nextoe", "results"), 161 | 
"--output_json", os.path.join(model_path, "evaluation", "nextoe", "results.json"), 162 | "--num_chunks", str(num_chunks), 163 | "--num_tasks", "16", 164 | "--api_key", args.api_key, 165 | "--api_base", args.api_base, 166 | "--api_type", args.api_type, 167 | "--api_version", args.api_version, 168 | ] 169 | exec(cmd) 170 | 171 | # multi gpu, feature 172 | def eval_vsmovienet(args): # follow msvd format 173 | model_path = args.model_path 174 | num_chunks = args.num_chunks 175 | if not args.only_eval: 176 | processes = [] 177 | for idx in range(0, num_chunks): 178 | cmd = ["python", "llama_vstream/eval_video/model_msvd_qa_featuresloader.py", 179 | "--model-path", model_path, 180 | "--video_dir", "./data/eval_video/vstream/movienet_video_features", 181 | "--gt_file", "./data/eval_video/vstream/test_qa_movienet.json", 182 | "--output_dir", os.path.join(model_path, "evaluation", "vsmovienet"), 183 | "--output_name", "pred", 184 | "--num-chunks", str(num_chunks), 185 | "--chunk-idx", str(idx), 186 | "--conv-mode", "vicuna_v1", 187 | ] 188 | 189 | p = multiprocessing.Process(target=exec, args=(cmd, True, str(idx))) 190 | processes.append(p) 191 | p.start() # 启动子进程 192 | for p in processes: 193 | p.join() 194 | cmd = ["python", "llama_vstream/eval_video/eval_activitynet_qa.py", 195 | "--pred_path", os.path.join(model_path, "evaluation", "vsmovienet"), 196 | "--output_dir", os.path.join(model_path, "evaluation", "vsmovienet", "results"), 197 | "--output_json", os.path.join(model_path, "evaluation", "vsmovienet", "results.json"), 198 | "--num_chunks", str(num_chunks), 199 | "--num_tasks", "16", 200 | "--api_key", args.api_key, 201 | "--api_base", args.api_base, 202 | "--api_type", args.api_type, 203 | "--api_version", args.api_version, 204 | ] 205 | exec(cmd) 206 | 207 | # multi gpu, feature 208 | def eval_vsego4d(args): # follow msvd format 209 | model_path = args.model_path 210 | num_chunks = args.num_chunks 211 | if not args.only_eval: 212 | processes = [] 213 | for idx in range(0, num_chunks): 214 | cmd = ["python", "llama_vstream/eval_video/model_msvd_qa_featuresloader.py", 215 | "--model-path", model_path, 216 | "--video_dir", "./data/eval_video/vstream/ego4d_video_features", 217 | "--gt_file", "./data/eval_video/vstream/test_qa_ego4d.json", 218 | "--output_dir", os.path.join(model_path, "evaluation", "vsego4d"), 219 | "--output_name", "pred", 220 | "--num-chunks", str(num_chunks), 221 | "--chunk-idx", str(idx), 222 | "--conv-mode", "vicuna_v1", 223 | ] 224 | 225 | p = multiprocessing.Process(target=exec, args=(cmd, True, str(idx))) 226 | processes.append(p) 227 | p.start() # 启动子进程 228 | for p in processes: 229 | p.join() 230 | cmd = ["python", "llama_vstream/eval_video/eval_activitynet_qa.py", 231 | "--pred_path", os.path.join(model_path, "evaluation", "vsego4d"), 232 | "--output_dir", os.path.join(model_path, "evaluation", "vsego4d", "results"), 233 | "--output_json", os.path.join(model_path, "evaluation", "vsego4d", "results.json"), 234 | "--num_chunks", str(num_chunks), 235 | "--num_tasks", "16", 236 | "--api_key", args.api_key, 237 | "--api_base", args.api_base, 238 | "--api_type", args.api_type, 239 | "--api_version", args.api_version, 240 | ] 241 | exec(cmd) 242 | 243 | # multi gpu, feature 244 | def eval_realtime_vsmovienet(args): # follow msvd format 245 | model_path = args.model_path 246 | num_chunks = args.num_chunks 247 | if not args.only_eval: 248 | processes = [] 249 | for idx in range(0, num_chunks): 250 | cmd = ["python", 
"llama_vstream/eval_video/model_msvd_qa_featuresloader.py", 251 | "--model-path", model_path, 252 | "--video_dir", "./data/eval_video/vstream-realtime/movienet_video_features", 253 | "--gt_file", "./data/eval_video/vstream-realtime/test_qa_movienet.json", 254 | "--output_dir", os.path.join(model_path, "evaluation", "realtime_vsmovienet"), 255 | "--output_name", "pred", 256 | "--num-chunks", str(num_chunks), 257 | "--chunk-idx", str(idx), 258 | "--conv-mode", "vicuna_v1", 259 | ] 260 | 261 | p = multiprocessing.Process(target=exec, args=(cmd, True, str(idx))) 262 | processes.append(p) 263 | p.start() # 启动子进程 264 | for p in processes: 265 | p.join() 266 | cmd = ["python", "llama_vstream/eval_video/eval_activitynet_qa.py", 267 | "--pred_path", os.path.join(model_path, "evaluation", "realtime_vsmovienet"), 268 | "--output_dir", os.path.join(model_path, "evaluation", "realtime_vsmovienet", "results"), 269 | "--output_json", os.path.join(model_path, "evaluation", "realtime_vsmovienet", "results.json"), 270 | "--num_chunks", str(num_chunks), 271 | "--num_tasks", "16", 272 | "--api_key", args.api_key, 273 | "--api_base", args.api_base, 274 | "--api_type", args.api_type, 275 | "--api_version", args.api_version, 276 | ] 277 | exec(cmd) 278 | 279 | # multi gpu, feature 280 | def eval_realtime_vsego4d(args): # follow msvd format 281 | model_path = args.model_path 282 | num_chunks = args.num_chunks 283 | if not args.only_eval: 284 | processes = [] 285 | for idx in range(0, num_chunks): 286 | cmd = ["python", "llama_vstream/eval_video/model_msvd_qa_featuresloader.py", 287 | "--model-path", model_path, 288 | "--video_dir", "./data/eval_video/vstream-realtime/ego4d_video_features", 289 | "--gt_file", "./data/eval_video/vstream-realtime/test_qa_ego4d.json", 290 | "--output_dir", os.path.join(model_path, "evaluation", "realtime_vsego4d"), 291 | "--output_name", "pred", 292 | "--num-chunks", str(num_chunks), 293 | "--chunk-idx", str(idx), 294 | "--conv-mode", "vicuna_v1", 295 | ] 296 | 297 | p = multiprocessing.Process(target=exec, args=(cmd, True, str(idx))) 298 | processes.append(p) 299 | p.start() # 启动子进程 300 | for p in processes: 301 | p.join() 302 | cmd = ["python", "llama_vstream/eval_video/eval_activitynet_qa.py", 303 | "--pred_path", os.path.join(model_path, "evaluation", "realtime_vsego4d"), 304 | "--output_dir", os.path.join(model_path, "evaluation", "realtime_vsego4d", "results"), 305 | "--output_json", os.path.join(model_path, "evaluation", "realtime_vsego4d", "results.json"), 306 | "--num_chunks", str(num_chunks), 307 | "--num_tasks", "16", 308 | "--api_key", args.api_key, 309 | "--api_base", args.api_base, 310 | "--api_type", args.api_type, 311 | "--api_version", args.api_version, 312 | ] 313 | exec(cmd) 314 | 315 | 316 | if __name__ == "__main__": 317 | parser = argparse.ArgumentParser() 318 | parser.add_argument("--model-path", type=str, default="facebook/opt-350m") 319 | parser.add_argument("--dataset", type=str, default=None) 320 | parser.add_argument("--api_key", type=str, default=None) 321 | parser.add_argument("--api_base", type=str, default=None) 322 | parser.add_argument("--api_type", type=str, default=None) 323 | parser.add_argument("--api_version", type=str, default=None) 324 | parser.add_argument("--num_chunks", type=int, default=1) 325 | parser.add_argument("--only_eval", action="store_true") 326 | parser.add_argument("--vizlen", type=int, default=0) 327 | parser.add_argument("--use_speech", action="store_true", default=False) 328 | args = parser.parse_args() 329 | func_dic = 
{'msvd': eval_msvd, 330 | 'msrvtt': eval_msrvtt, 331 | 'actnet': eval_actnet, 332 | 'nextoe': eval_nextoe, 333 | 'vsmovienet': eval_vsmovienet, 334 | 'vsego4d': eval_vsego4d, 335 | 'realtime_vsmovienet': eval_realtime_vsmovienet, 336 | 'realtime_vsego4d': eval_realtime_vsego4d, 337 | } 338 | if args.dataset in func_dic: 339 | print(f'Execute {args.dataset} evaluation') 340 | func_dic[args.dataset](args) 341 | -------------------------------------------------------------------------------- /flash_vstream/eval_video/model_msvd_qa.py: -------------------------------------------------------------------------------- 1 | # Based on https://github.com/haotian-liu/LLaVA. 2 | 3 | import os 4 | import json 5 | import math 6 | import torch 7 | import argparse 8 | from tqdm import tqdm 9 | from decord import VideoReader, cpu 10 | 11 | from flash_vstream.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN 12 | from flash_vstream.conversation import conv_templates, SeparatorStyle 13 | from flash_vstream.model.builder import load_pretrained_model 14 | from flash_vstream.utils import disable_torch_init 15 | from flash_vstream.mm_utils import tokenizer_image_token, get_model_name_from_path, KeywordsStoppingCriteria 16 | 17 | 18 | def split_list(lst, n): 19 | """Split a list into n (roughly) equal-sized chunks""" 20 | chunk_size = math.ceil(len(lst) / n) # ceiling division, so the last chunk may be shorter 21 | return [lst[i:i+chunk_size] for i in range(0, len(lst), chunk_size)] 22 | 23 | 24 | def get_chunk(lst, n, k): 25 | chunks = split_list(lst, n) 26 | return chunks[k] 27 | 28 | 29 | def parse_args(): 30 | """ 31 | Parse command-line arguments. 32 | """ 33 | parser = argparse.ArgumentParser() 34 | 35 | # Define the command-line arguments 36 | parser.add_argument('--video_dir', help='Directory containing video files.', required=True) 37 | parser.add_argument('--gt_file', help='Path to the ground truth file containing question.', required=True) 38 | parser.add_argument('--output_dir', help='Directory to save the model results JSON.', required=True) 39 | parser.add_argument('--output_name', help='Name of the file for storing results JSON.', required=True) 40 | parser.add_argument("--model-path", type=str, default="facebook/opt-350m") 41 | parser.add_argument("--model-base", type=str, default=None) 42 | parser.add_argument("--conv-mode", type=str, default=None) 43 | parser.add_argument("--num-chunks", type=int, default=1) 44 | parser.add_argument("--chunk-idx", type=int, default=0) 45 | parser.add_argument("--model-max-length", type=int, default=None) 46 | 47 | return parser.parse_args() 48 | 49 | 50 | def load_video(video_path): 51 | vr = VideoReader(video_path, ctx=cpu(0)) 52 | total_frame_num = len(vr) 53 | fps = round(vr.get_avg_fps()) 54 | frame_idx = [i for i in range(0, len(vr), fps)] # sample roughly one frame per second 55 | spare_frames = vr.get_batch(frame_idx).asnumpy() 56 | return spare_frames 57 | 58 | 59 | def run_inference(args): 60 | """ 61 | Run inference on an MSVD-format video QA dataset with the Flash-VStream model. 62 | 63 | Args: 64 | args: Command-line arguments.
65 | """ 66 | # Initialize the model 67 | model_name = get_model_name_from_path(args.model_path) 68 | tokenizer, model, image_processor, context_len = load_pretrained_model(args.model_path, args.model_base, model_name, args.model_max_length) 69 | 70 | # Load both ground truth file containing questions and answers 71 | with open(args.gt_file) as file: 72 | gt_questions = json.load(file) 73 | gt_questions = get_chunk(gt_questions, args.num_chunks, args.chunk_idx) 74 | 75 | # Create the output directory if it doesn't exist 76 | if not os.path.exists(args.output_dir): 77 | try: 78 | os.makedirs(args.output_dir) 79 | except Exception as e: 80 | print(f'mkdir Except: {e}') 81 | 82 | video_formats = ['.mp4', '.avi', '.mov', '.mkv'] 83 | if args.num_chunks > 1: 84 | output_name = f"{args.num_chunks}_{args.chunk_idx}" 85 | else: 86 | output_name = args.output_name 87 | answers_file = os.path.join(args.output_dir, f"{output_name}.json") 88 | ans_file = open(answers_file, "w") 89 | 90 | for sample in tqdm(gt_questions, desc=f"cuda:{args.chunk_idx} "): 91 | video_name = sample['video_id'] 92 | question = sample['question'] 93 | id = sample['id'] 94 | answer = sample['answer'] 95 | 96 | sample_set = {'id': id, 'question': question, 'answer': answer} 97 | 98 | # Load the video file 99 | for fmt in video_formats: # Added this line 100 | temp_path = os.path.join(args.video_dir, f"{video_name}{fmt}") 101 | if os.path.exists(temp_path): 102 | video_path = temp_path 103 | break 104 | 105 | # Check if the video exists 106 | if os.path.exists(video_path): 107 | video = load_video(video_path) 108 | video = image_processor.preprocess(video, return_tensors='pt')['pixel_values'].half().cuda() 109 | video = [video] 110 | 111 | qs = question 112 | if model.config.mm_use_im_start_end: 113 | qs = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN + '\n' + qs 114 | else: 115 | qs = DEFAULT_IMAGE_TOKEN + '\n' + qs 116 | 117 | conv = conv_templates[args.conv_mode].copy() 118 | conv.append_message(conv.roles[0], qs) 119 | conv.append_message(conv.roles[1], None) 120 | prompt = conv.get_prompt() 121 | 122 | input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).cuda() 123 | 124 | stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2 125 | keywords = [stop_str] 126 | stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids) 127 | 128 | with torch.inference_mode(): 129 | output_ids = model.generate( 130 | input_ids, 131 | images=video, 132 | do_sample=True, 133 | temperature=0.002, 134 | max_new_tokens=1024, 135 | use_cache=True, 136 | stopping_criteria=[stopping_criteria]) 137 | 138 | input_token_len = input_ids.shape[1] 139 | n_diff_input_output = (input_ids != output_ids[:, :input_token_len]).sum().item() 140 | if n_diff_input_output > 0: 141 | print(f'[Warning] {n_diff_input_output} output_ids are not the same as the input_ids') 142 | outputs = tokenizer.batch_decode(output_ids[:, input_token_len:], skip_special_tokens=True)[0] 143 | outputs = outputs.strip() 144 | if outputs.endswith(stop_str): 145 | outputs = outputs[:-len(stop_str)] 146 | outputs = outputs.strip() 147 | 148 | sample_set['pred'] = outputs 149 | ans_file.write(json.dumps(sample_set) + "\n") 150 | ans_file.flush() 151 | 152 | ans_file.close() 153 | 154 | 155 | if __name__ == "__main__": 156 | args = parse_args() 157 | run_inference(args) 158 | -------------------------------------------------------------------------------- 
/flash_vstream/eval_video/model_msvd_qa_featuresloader.py: -------------------------------------------------------------------------------- 1 | # This file may have been modified by Flash-VStream Authors (Flash-VStream Modifications”). All Flash-VStream Modifications are Copyright 2024 Flash-VStream Authors. 2 | # Based on https://github.com/haotian-liu/LLaVA. 3 | 4 | import os 5 | import json 6 | import math 7 | import torch 8 | import random 9 | import argparse 10 | from tqdm import tqdm 11 | from torch.utils.data import Dataset, DataLoader 12 | from safetensors.torch import load_file 13 | 14 | from flash_vstream.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN 15 | from flash_vstream.conversation import conv_templates, SeparatorStyle 16 | from flash_vstream.model.builder import load_pretrained_model 17 | from flash_vstream.mm_utils import tokenizer_image_token, get_model_name_from_path, KeywordsStoppingCriteria 18 | 19 | 20 | def split_list(lst, n): 21 | """Split a list into n (roughly) equal-sized chunks""" 22 | chunk_size = math.ceil(len(lst) / n) # integer division 23 | return [lst[i:i+chunk_size] for i in range(0, len(lst), chunk_size)] 24 | 25 | 26 | def get_chunk(lst, n, k): 27 | chunks = split_list(lst, n) 28 | return chunks[k] 29 | 30 | 31 | def parse_args(): 32 | """ 33 | Parse command-line arguments. 34 | """ 35 | parser = argparse.ArgumentParser() 36 | 37 | # Define the command-line arguments 38 | parser.add_argument('--video_dir', help='Directory containing video files.', required=True) 39 | parser.add_argument('--gt_file', help='Path to the ground truth file containing question.', required=True) 40 | parser.add_argument('--output_dir', help='Directory to save the model results JSON.', required=True) 41 | parser.add_argument('--output_name', help='Name of the file for storing results JSON.', required=True) 42 | parser.add_argument("--model-path", type=str, default="facebook/opt-350m") 43 | parser.add_argument("--model-base", type=str, default=None) 44 | parser.add_argument("--conv-mode", type=str, default=None) 45 | parser.add_argument("--num-chunks", type=int, default=1) 46 | parser.add_argument("--chunk-idx", type=int, default=0) 47 | parser.add_argument("--model-max-length", type=int, default=None) 48 | return parser.parse_args() 49 | 50 | 51 | class CustomDataset(Dataset): 52 | def __init__(self, questions, video_dir, tokenizer, image_processor, model_config): 53 | self.questions = questions 54 | self.video_dir = video_dir 55 | self.tokenizer = tokenizer 56 | self.image_processor = image_processor 57 | self.model_config = model_config 58 | 59 | def __getitem__(self, index): 60 | sample = self.questions[index] 61 | video_name = sample['video_id'] 62 | try: 63 | video_path = os.path.join(self.video_dir, video_name + '.safetensors') 64 | video_tensor = load_file(video_path)['feature'] 65 | except Exception as e: 66 | print(f'Dataset Exception: {e}, randomly choose one.') 67 | idx = random.randint(0, len(self.questions) - 1) 68 | return self.__getitem__(idx) 69 | qs = sample['question'] 70 | if self.model_config.mm_use_im_start_end: 71 | qs = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN + '\n' + qs 72 | else: 73 | qs = DEFAULT_IMAGE_TOKEN + '\n' + qs 74 | conv = conv_templates[args.conv_mode].copy() 75 | if 'system' in sample: 76 | conv.system = conv.system + ' ' + sample['system'] 77 | conv.append_message(conv.roles[0], qs) 78 | conv.append_message(conv.roles[1], None) 79 | prompt = 
conv.get_prompt() 80 | input_ids = tokenizer_image_token(prompt, self.tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt') 81 | return input_ids, video_tensor 82 | 83 | def __len__(self): 84 | return len(self.questions) 85 | 86 | 87 | def create_data_loader(questions, video_dir, tokenizer, image_processor, model_config, batch_size=1, num_workers=2): 88 | assert batch_size == 1, "batch_size must be 1" 89 | dataset = CustomDataset(questions, video_dir, tokenizer, image_processor, model_config) 90 | data_loader = DataLoader(dataset, batch_size=batch_size, num_workers=num_workers, shuffle=False) 91 | return data_loader 92 | 93 | 94 | def run_inference(args): 95 | """ 96 | Run inference on ActivityNet QA DataSet using the Video-ChatGPT model. 97 | 98 | Args: 99 | args: Command-line arguments. 100 | """ 101 | # Initialize the model 102 | model_name = get_model_name_from_path(args.model_path) 103 | tokenizer, model, image_processor, context_len = load_pretrained_model(args.model_path, args.model_base, model_name, args.model_max_length) 104 | 105 | # Load both ground truth file containing questions and answers 106 | with open(args.gt_file) as file: 107 | gt_questions = json.load(file) 108 | gt_questions = get_chunk(gt_questions, args.num_chunks, args.chunk_idx) 109 | 110 | # Create the output directory if it doesn't exist 111 | if not os.path.exists(args.output_dir): 112 | try: 113 | os.makedirs(args.output_dir) 114 | except Exception as e: 115 | print(f'mkdir Except: {e}') 116 | 117 | video_formats = ['.mp4', '.avi', '.mov', '.mkv'] 118 | if args.num_chunks > 1: 119 | output_name = f"{args.num_chunks}_{args.chunk_idx}" 120 | else: 121 | output_name = args.output_name 122 | answers_file = os.path.join(args.output_dir, f"{output_name}.json") 123 | # resume from old exp 124 | exist_id_set = set() 125 | if os.path.exists(answers_file): 126 | with open(answers_file) as f: 127 | exist_pred_contents = [json.loads(line) for line in f] 128 | exist_id_set = set([x['id'] for x in exist_pred_contents]) 129 | 130 | new_gt_questions = [] 131 | for sample in tqdm(gt_questions): 132 | if not sample['id'] in exist_id_set: 133 | new_gt_questions.append(sample) 134 | gt_questions = new_gt_questions 135 | 136 | data_loader = create_data_loader(gt_questions, args.video_dir, tokenizer, image_processor, model.config) 137 | 138 | conv = conv_templates[args.conv_mode].copy() 139 | stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2 140 | keywords = [stop_str] 141 | 142 | with open(answers_file, "a") as ans_file: 143 | for data, sample in tqdm(zip(data_loader, gt_questions), desc=f"cuda:{args.chunk_idx} ", total=len(gt_questions)): 144 | input_ids, video_tensors = data 145 | input_ids = input_ids.to(device='cuda', non_blocking=True) 146 | stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids) 147 | with torch.inference_mode(): 148 | output_ids = model.generate( 149 | input_ids, 150 | features=video_tensors.to(dtype=torch.float16, device='cuda', non_blocking=True), 151 | do_sample=True, 152 | temperature=0.002, 153 | max_new_tokens=1024, 154 | use_cache=True, 155 | stopping_criteria=[stopping_criteria], 156 | ) 157 | input_token_len = input_ids.shape[1] 158 | n_diff_input_output = (input_ids != output_ids[:, :input_token_len]).sum().item() 159 | if n_diff_input_output > 0: 160 | print(f'[Warning] {n_diff_input_output} output_ids are not the same as the input_ids') 161 | outputs = tokenizer.batch_decode(output_ids[:, input_token_len:], skip_special_tokens=True)[0] 162 | outputs = 
outputs.strip() 163 | if outputs.endswith(stop_str): 164 | outputs = outputs[:-len(stop_str)] 165 | outputs = outputs.strip() 166 | sample_set = { 167 | 'id': sample['id'], 168 | 'question': sample['question'], 169 | 'answer': sample['answer'], 170 | 'answer_type': sample['answer_type'] if 'answer_type' in sample else None, 171 | 'pred': outputs 172 | } 173 | ans_file.write(json.dumps(sample_set) + "\n") 174 | ans_file.flush() 175 | 176 | 177 | if __name__ == "__main__": 178 | args = parse_args() 179 | run_inference(args) 180 | -------------------------------------------------------------------------------- /flash_vstream/mm_utils.py: -------------------------------------------------------------------------------- 1 | # Based on https://github.com/haotian-liu/LLaVA. 2 | 3 | from PIL import Image 4 | from io import BytesIO 5 | import base64 6 | 7 | import torch 8 | from transformers import StoppingCriteria 9 | from flash_vstream.constants import IMAGE_TOKEN_INDEX 10 | 11 | 12 | def load_image_from_base64(image): 13 | return Image.open(BytesIO(base64.b64decode(image))) 14 | 15 | 16 | def expand2square(pil_img, background_color): 17 | width, height = pil_img.size 18 | if width == height: 19 | return pil_img 20 | elif width > height: 21 | result = Image.new(pil_img.mode, (width, width), background_color) 22 | result.paste(pil_img, (0, (width - height) // 2)) 23 | return result 24 | else: 25 | result = Image.new(pil_img.mode, (height, height), background_color) 26 | result.paste(pil_img, ((height - width) // 2, 0)) 27 | return result 28 | 29 | 30 | def process_images(images, image_processor, model_cfg): 31 | image_aspect_ratio = getattr(model_cfg, "image_aspect_ratio", None) 32 | new_images = [] 33 | if image_aspect_ratio == 'pad': 34 | for image in images: 35 | image = expand2square(image, tuple(int(x*255) for x in image_processor.image_mean)) 36 | image = image_processor.preprocess(image, return_tensors='pt')['pixel_values'][0] 37 | new_images.append(image) 38 | else: 39 | return image_processor(images, return_tensors='pt')['pixel_values'] 40 | if all(x.shape == new_images[0].shape for x in new_images): 41 | new_images = torch.stack(new_images, dim=0) 42 | return new_images 43 | 44 | 45 | def tokenizer_image_token(prompt, tokenizer, image_token_index=IMAGE_TOKEN_INDEX, return_tensors=None): 46 | prompt_chunks = [tokenizer(chunk).input_ids for chunk in prompt.split('<image>')] # tokenize the prompt around the literal '<image>' placeholder 47 | 48 | def insert_separator(X, sep): 49 | return [ele for sublist in zip(X, [sep]*len(X)) for ele in sublist][:-1] 50 | 51 | input_ids = [] 52 | offset = 0 53 | if len(prompt_chunks) > 0 and len(prompt_chunks[0]) > 0 and prompt_chunks[0][0] == tokenizer.bos_token_id: 54 | offset = 1 55 | input_ids.append(prompt_chunks[0][0]) 56 | 57 | for x in insert_separator(prompt_chunks, [image_token_index] * (offset + 1)): 58 | input_ids.extend(x[offset:]) 59 | 60 | if return_tensors is not None: 61 | if return_tensors == 'pt': 62 | return torch.tensor(input_ids, dtype=torch.long) 63 | raise ValueError(f'Unsupported tensor type: {return_tensors}') 64 | return input_ids 65 | 66 | 67 | def get_model_name_from_path(model_path): 68 | model_path = model_path.strip("/") 69 | model_paths = model_path.split("/") 70 | if model_paths[-1].startswith('checkpoint-'): 71 | return model_paths[-2] + "_" + model_paths[-1] 72 | else: 73 | return model_paths[-1] 74 | 75 | class KeywordsStoppingCriteria(StoppingCriteria): 76 | def __init__(self, keywords, tokenizer, input_ids): 77 | self.keywords = keywords 78 | self.keyword_ids = [] 79 |
self.max_keyword_len = 0 80 | for keyword in keywords: 81 | cur_keyword_ids = tokenizer(keyword).input_ids 82 | if len(cur_keyword_ids) > 1 and cur_keyword_ids[0] == tokenizer.bos_token_id: 83 | cur_keyword_ids = cur_keyword_ids[1:] 84 | if len(cur_keyword_ids) > self.max_keyword_len: 85 | self.max_keyword_len = len(cur_keyword_ids) 86 | self.keyword_ids.append(torch.tensor(cur_keyword_ids)) 87 | self.tokenizer = tokenizer 88 | self.start_len = input_ids.shape[1] 89 | 90 | def call_for_batch(self, output_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool: 91 | offset = min(output_ids.shape[1] - self.start_len, self.max_keyword_len) 92 | self.keyword_ids = [keyword_id.to(output_ids.device) for keyword_id in self.keyword_ids] 93 | for keyword_id in self.keyword_ids: 94 | if (output_ids[0, -keyword_id.shape[0]:] == keyword_id).all(): 95 | return True 96 | outputs = self.tokenizer.batch_decode(output_ids[:, -offset:], skip_special_tokens=True)[0] 97 | for keyword in self.keywords: 98 | if keyword in outputs: 99 | return True 100 | return False 101 | 102 | def __call__(self, output_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool: 103 | outputs = [] 104 | for i in range(output_ids.shape[0]): 105 | outputs.append(self.call_for_batch(output_ids[i].unsqueeze(0), scores)) 106 | return all(outputs) 107 | -------------------------------------------------------------------------------- /flash_vstream/model/__init__.py: -------------------------------------------------------------------------------- 1 | from .language_model.vstream_llama import VStreamLlamaForCausalLM, VStreamConfig 2 | -------------------------------------------------------------------------------- /flash_vstream/model/builder.py: -------------------------------------------------------------------------------- 1 | # This file may have been modified by Flash-VStream Authors (Flash-VStream Modifications”). All Flash-VStream Modifications are Copyright 2024 Flash-VStream Authors. 2 | # ------------------------------------------------------------------------ 3 | # Based on https://github.com/haotian-liu/LLaVA. Below is the original copyright: 4 | # Copyright 2023 Haotian Liu 5 | # 6 | # Licensed under the Apache License, Version 2.0 (the "License"); 7 | # you may not use this file except in compliance with the License. 8 | # You may obtain a copy of the License at 9 | # 10 | # http://www.apache.org/licenses/LICENSE-2.0 11 | # 12 | # Unless required by applicable law or agreed to in writing, software 13 | # distributed under the License is distributed on an "AS IS" BASIS, 14 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 15 | # See the License for the specific language governing permissions and 16 | # limitations under the License. 
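Before moving on to the model builder below, a brief illustrative sketch of how the mm_utils helper tokenizer_image_token (shown above) splices the image placeholder id into the prompt's token stream. The DummyTokenizer and the -200 index are assumptions made only to keep the example self-contained; the real code path uses a HuggingFace tokenizer and flash_vstream.constants.IMAGE_TOKEN_INDEX.

from types import SimpleNamespace
from flash_vstream.mm_utils import tokenizer_image_token

class DummyTokenizer:
    """Stand-in tokenizer: a BOS id followed by one fake id per whitespace-separated word."""
    bos_token_id = 1
    def __call__(self, text):
        return SimpleNamespace(input_ids=[self.bos_token_id] + [100 + i for i, _ in enumerate(text.split())])

ids = tokenizer_image_token("<image>\nWhat is happening in the video?",
                            DummyTokenizer(), image_token_index=-200)
print(ids)  # [1, -200, 100, 101, 102, 103, 104, 105]: BOS, image slot, question tokens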
17 | 18 | 19 | import os 20 | import warnings 21 | import shutil 22 | 23 | from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig, BitsAndBytesConfig 24 | import torch 25 | from flash_vstream.model import VStreamLlamaForCausalLM, VStreamConfig 26 | from flash_vstream.constants import DEFAULT_IMAGE_PATCH_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN 27 | 28 | 29 | def load_pretrained_model(model_path, model_base, model_name, load_8bit=False, load_4bit=False, device_map="auto", device="cuda", **kwargs): 30 | kwargs = {"device_map": device_map, **kwargs} 31 | 32 | if device != "cuda": 33 | kwargs['device_map'] = {"": device} 34 | 35 | if load_8bit: 36 | kwargs['load_in_8bit'] = True 37 | elif load_4bit: 38 | kwargs['load_in_4bit'] = True 39 | kwargs['quantization_config'] = BitsAndBytesConfig( 40 | load_in_4bit=True, 41 | bnb_4bit_compute_dtype=torch.float16, 42 | bnb_4bit_use_double_quant=True, 43 | bnb_4bit_quant_type='nf4' 44 | ) 45 | else: 46 | kwargs['torch_dtype'] = torch.float16 47 | 48 | if 'vstream' in model_name.lower(): 49 | # Load LLaMA-VStream model 50 | if 'lora' in model_name.lower() and model_base is None: 51 | warnings.warn('There is `lora` in model name but no `model_base` is provided. If you are loading a LoRA model, please provide the `model_base` argument. Detailed instruction: https://github.com/haotian-liu/LLaVA#launch-a-model-worker-lora-weights-unmerged.') 52 | if 'lora' in model_name.lower() and model_base is not None: 53 | lora_cfg_pretrained = AutoConfig.from_pretrained(model_path) 54 | tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=False) 55 | print('(LoRA) Loading LLaMA-VStream from base model...') 56 | model = VStreamLlamaForCausalLM.from_pretrained(model_base, low_cpu_mem_usage=True, config=lora_cfg_pretrained, **kwargs) 57 | token_num, tokem_dim = model.lm_head.out_features, model.lm_head.in_features 58 | if model.lm_head.weight.shape[0] != token_num: 59 | model.lm_head.weight = torch.nn.Parameter(torch.empty(token_num, tokem_dim, device=model.device, dtype=model.dtype)) 60 | model.model.embed_tokens.weight = torch.nn.Parameter(torch.empty(token_num, tokem_dim, device=model.device, dtype=model.dtype)) 61 | 62 | print('(LoRA) Loading additional LLaMA-VStream weights...') 63 | if os.path.exists(os.path.join(model_path, 'non_lora_trainables.bin')): 64 | non_lora_trainables = torch.load(os.path.join(model_path, 'non_lora_trainables.bin'), map_location='cpu') 65 | else: 66 | # this is probably from HF Hub 67 | from huggingface_hub import hf_hub_download 68 | def load_from_hf(repo_id, filename, subfolder=None): 69 | cache_file = hf_hub_download( 70 | repo_id=repo_id, 71 | filename=filename, 72 | subfolder=subfolder) 73 | return torch.load(cache_file, map_location='cpu') 74 | non_lora_trainables = load_from_hf(model_path, 'non_lora_trainables.bin') 75 | non_lora_trainables = {(k[11:] if k.startswith('base_model.') else k): v for k, v in non_lora_trainables.items()} 76 | if any(k.startswith('model.model.') for k in non_lora_trainables): 77 | non_lora_trainables = {(k[6:] if k.startswith('model.') else k): v for k, v in non_lora_trainables.items()} 78 | model.load_state_dict(non_lora_trainables, strict=False) 79 | 80 | from peft import PeftModel 81 | print('Loading LoRA weights...') 82 | model = PeftModel.from_pretrained(model, model_path) 83 | print('Merging LoRA weights...') 84 | model = model.merge_and_unload() 85 | print('Model is loaded...') 86 | elif model_base is not None: 87 | # this may be mm projector only 88 | 
print('Loading LLaMA-VStream from base model...') 89 | tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=False) 90 | cfg_pretrained = AutoConfig.from_pretrained(model_path) 91 | model = VStreamLlamaForCausalLM.from_pretrained(model_base, low_cpu_mem_usage=True, config=cfg_pretrained, **kwargs) 92 | 93 | mm_projector_weights = torch.load(os.path.join(model_path, 'mm_projector.bin'), map_location='cpu') 94 | mm_projector_weights = {k: v.to(torch.float16) for k, v in mm_projector_weights.items()} 95 | model.load_state_dict(mm_projector_weights, strict=False) 96 | else: 97 | tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False) 98 | model = VStreamLlamaForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, **kwargs) 99 | else: 100 | # Load language model 101 | if model_base is not None: 102 | # PEFT model 103 | from peft import PeftModel 104 | tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=False) 105 | model = AutoModelForCausalLM.from_pretrained(model_base, low_cpu_mem_usage=True, **kwargs) 106 | print(f"Loading LoRA weights from {model_path}") 107 | model = PeftModel.from_pretrained(model, model_path) 108 | print(f"Merging weights") 109 | model = model.merge_and_unload() 110 | print('Convert to FP16...') 111 | model.to(torch.float16) 112 | else: 113 | use_fast = False 114 | tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False) 115 | model = AutoModelForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, **kwargs) 116 | 117 | image_processor = None 118 | 119 | if 'vstream' in model_name.lower(): 120 | mm_use_im_start_end = getattr(model.config, "mm_use_im_start_end", False) 121 | mm_use_im_patch_token = getattr(model.config, "mm_use_im_patch_token", True) 122 | if mm_use_im_patch_token: 123 | tokenizer.add_tokens([DEFAULT_IMAGE_PATCH_TOKEN], special_tokens=True) 124 | if mm_use_im_start_end: 125 | tokenizer.add_tokens([DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True) 126 | model.resize_token_embeddings(len(tokenizer)) 127 | 128 | vision_tower = model.get_vision_tower() 129 | if not vision_tower.is_loaded: 130 | vision_tower.load_model() 131 | vision_tower.to(device=device, dtype=torch.float16) 132 | image_processor = vision_tower.image_processor 133 | 134 | if hasattr(model.config, "max_sequence_length"): 135 | context_len = model.config.max_sequence_length 136 | else: 137 | context_len = 2048 138 | 139 | return tokenizer, model, image_processor, context_len 140 | -------------------------------------------------------------------------------- /flash_vstream/model/compress_functions.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Flash-VStream Authors 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
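A usage sketch for the loader defined in builder.py above: the checkpoint path below is a placeholder, and load_4bit=True is shown only to exercise the BitsAndBytes branch; the call follows the load_pretrained_model signature as defined there.

from flash_vstream.mm_utils import get_model_name_from_path
from flash_vstream.model.builder import load_pretrained_model

model_path = "checkpoints/flash-vstream-7b"         # placeholder checkpoint directory
model_name = get_model_name_from_path(model_path)   # -> "flash-vstream-7b", so the vstream branch is taken
tokenizer, model, image_processor, context_len = load_pretrained_model(
    model_path, model_base=None, model_name=model_name, load_4bit=True)
print(type(model).__name__, context_len)            # e.g. VStreamLlamaForCausalLM 2048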
14 | 15 | import random 16 | import torch 17 | import torch.nn as nn 18 | import torch.nn.functional as F 19 | 20 | def drop_feature(img_feature, video_max_frames, img_similarity=None): 21 | T, P, D = img_feature.shape 22 | indices = [[i] for i in range(T)] 23 | T0 = video_max_frames 24 | if T <= T0: 25 | return img_feature, img_similarity, [indices] 26 | cur_feature = img_feature[:T0] # [T0, P, D] 27 | if img_similarity is not None: 28 | cur_sim = img_similarity[:T0 - 1] 29 | else: 30 | cur_sim = F.cosine_similarity(cur_feature[:-1].view(T0 - 1, P * D), cur_feature[1:].view(T0 - 1, P * D)) # [T0 - 1] 31 | cur_indices = indices[:T0] 32 | step_indices = [cur_indices] 33 | for i in range(T0, T): 34 | new_feature = img_feature[i] 35 | new_sim = F.cosine_similarity(cur_feature[-1].view(-1), new_feature.view(-1), dim=0) 36 | all_feature = torch.cat([cur_feature, new_feature.unsqueeze(0)], dim=0) 37 | all_indices = cur_indices + [[i]] 38 | all_sim = torch.cat([cur_sim, new_sim.unsqueeze(0)], dim=0) 39 | idx = torch.argmax(all_sim) 40 | if random.randint(0, 1) > 0: 41 | idx = idx + 1 42 | cur_feature = torch.cat([all_feature[:idx], all_feature[idx + 1:]]) 43 | if idx + 1 == T0 + 1: 44 | cur_sim = all_sim[:T0 - 1] 45 | cur_indices = all_indices[:-1] 46 | elif idx == 0: 47 | cur_sim = all_sim[1:] 48 | cur_indices = all_indices[1:] 49 | else: 50 | cur_sim = torch.cat([all_sim[:idx], all_sim[idx + 1:]]) 51 | cur_sim[idx - 1] = F.cosine_similarity(all_feature[idx - 1].view(-1), all_feature[idx + 1].view(-1), dim=0) 52 | cur_indices = all_indices[:idx] + all_indices[idx + 1:] 53 | step_indices.append(cur_indices) 54 | # print(f'Note: perform drop feature {img_feature.shape} to {cur_feature.shape}') 55 | return cur_feature, cur_sim, step_indices 56 | 57 | 58 | def merge_feature(img_feature, video_max_frames, img_similarity=None): 59 | T, P, D = img_feature.shape 60 | indices = [[i] for i in range(T)] 61 | T0 = video_max_frames 62 | if T <= T0: 63 | return img_feature, img_similarity, [indices] 64 | cur_feature = img_feature[:T0] # [T0, P, D] 65 | cur_indices = indices[:T0] 66 | step_indices = [cur_indices] 67 | if img_similarity is not None: 68 | cur_sim = img_similarity[:T0 - 1] 69 | else: 70 | cur_sim = F.cosine_similarity(cur_feature[:-1].view(T0 - 1, P * D), cur_feature[1:].view(T0 - 1, P * D)) # [T0 - 1] 71 | for i in range(T0, T): 72 | new_feature = img_feature[i] 73 | new_sim = F.cosine_similarity(cur_feature[-1].view(-1), new_feature.view(-1), dim=0) 74 | all_feature = torch.cat([cur_feature, new_feature.unsqueeze(0)], dim=0) 75 | all_sim = torch.cat([cur_sim, new_sim.unsqueeze(0)], dim=0) 76 | all_indices = cur_indices + [[i]] 77 | idx = torch.argmax(all_sim) 78 | all_feature[idx + 1] = (all_feature[idx] + all_feature[idx + 1]) / 2.0 79 | all_indices[idx + 1] = all_indices[idx] + all_indices[idx + 1] 80 | cur_feature = torch.cat([all_feature[:idx], all_feature[idx + 1:]]) 81 | cur_sim = torch.cat([all_sim[:idx], all_sim[idx + 1:]]) 82 | cur_indices = all_indices[:idx] + all_indices[idx + 1:] 83 | if idx > 0: 84 | cur_sim[idx - 1] = F.cosine_similarity(all_feature[idx - 1].view(-1), all_feature[idx + 1].view(-1), dim=0) 85 | if idx + 1 < T0: 86 | cur_sim[idx] = F.cosine_similarity(all_feature[idx + 1].view(-1), all_feature[idx + 2].view(-1), dim=0) 87 | step_indices.append(cur_indices) 88 | # print(f'Note: perform merge feature {img_feature.shape} to {cur_feature.shape}') 89 | return cur_feature, cur_sim, step_indices 90 | 91 | 92 | def kmeans_feature(img_feature, video_max_frames, 
img_similarity=None): 93 | def kmeans_torch(X, num_clusters, distance='euclidean', tol=1e-4, max_iter=10): 94 | indices = torch.randperm(X.size(0))[:num_clusters] 95 | centroids = X[indices] 96 | for i in range(max_iter): 97 | if distance == 'euclidean': 98 | dists = torch.cdist(X, centroids, p=2) 99 | else: 100 | raise NotImplementedError("Only Euclidean distance is supported yet") 101 | labels = torch.argmin(dists, dim=1) 102 | new_centroids = [] 103 | for j in range(num_clusters): 104 | cluster_points = X[labels == j] 105 | if len(cluster_points) > 0: 106 | new_centroid = cluster_points.mean(0) 107 | else: # fix nan centroids 108 | new_centroid = X[random.randint(0, X.size(0) - 1)] 109 | new_centroids.append(new_centroid) 110 | new_centroids = torch.stack(new_centroids) 111 | diff = torch.norm(centroids - new_centroids, dim=1).sum() 112 | if diff < tol: 113 | break 114 | centroids = new_centroids 115 | return centroids, labels, i 116 | T, P, D = img_feature.shape 117 | T0 = video_max_frames 118 | if T <= T0: 119 | return img_feature, img_similarity, [[[i] for i in range(T)]] 120 | X = img_feature.view(T, -1) # [T, P, D] 121 | centroids, labels, exit_step = kmeans_torch(X, T0) 122 | reduced_feature = centroids.view(T0, P, D) 123 | # print(f'Note: perform kmeans feature {img_feature.shape} to {reduced_feature.shape}, exit at step={exit_step}') # actually, K=T0 124 | step_indices = [[] for _ in range(T0)] 125 | for i in range(T0): 126 | step_indices[i] = [j for j in range(T) if labels[j] == i] 127 | return reduced_feature, img_similarity, [step_indices] 128 | 129 | 130 | def weighted_kmeans_feature(img_feature, video_max_frames, weights=None): 131 | if weights is None: 132 | weights = torch.ones(img_feature.size(0), dtype=img_feature.dtype, device=img_feature.device) 133 | def weighted_kmeans_torch(X, num_clusters, weights=None, distance='euclidean', tol=1e-4, max_iter=10): 134 | indices = torch.randperm(X.size(0), device=X.device)[:num_clusters] 135 | centroids = X[indices] 136 | for i in range(max_iter): 137 | if distance == 'euclidean': 138 | dists = ((X.unsqueeze(1) - centroids.unsqueeze(0)) ** 2).sum(dim=2).sqrt() 139 | else: 140 | raise NotImplementedError("Only Euclidean distance is supported yet") 141 | labels = torch.argmin(dists, dim=1) 142 | weighted_sum = torch.zeros_like(centroids) 143 | weights_sum = torch.zeros(num_clusters, dtype=X.dtype, device=X.device) 144 | for j in range(num_clusters): 145 | cluster_mask = labels == j 146 | weighted_sum[j] = torch.sum(weights[cluster_mask, None] * X[cluster_mask], dim=0) 147 | weights_sum[j] = torch.sum(weights[cluster_mask]) 148 | mask = weights_sum > 0 149 | new_centroids = torch.zeros_like(weighted_sum) 150 | new_centroids[mask] = weighted_sum[mask] / weights_sum[mask, None] 151 | if mask.sum() < num_clusters: # fix nan centroids 152 | new_centroids[~mask] = torch.stack([X[random.randint(0, X.size(0) - 1)] for _ in range(num_clusters - mask.sum())]) 153 | diff = torch.norm(centroids - new_centroids, dim=1).sum() 154 | if diff < tol: 155 | break 156 | centroids = new_centroids 157 | return centroids, labels, weights_sum, i 158 | T, P, D = img_feature.shape 159 | T0 = video_max_frames 160 | if T <= T0: 161 | return img_feature, weights, [[[i] for i in range(T)]] 162 | X = img_feature.view(T, -1) # [T, P, D] 163 | centroids, labels, weights, exit_step = weighted_kmeans_torch(X, T0, weights) 164 | reduced_feature = centroids.view(T0, P, D) 165 | # print(f'Note: perform weighted kmeans feature {img_feature.shape} to 
{reduced_feature.shape}, exit at step={exit_step}') # actually, K=T0 166 | step_indices = [[] for _ in range(T0)] 167 | for i in range(T0): 168 | step_indices[i] = [j for j in range(T) if labels[j] == i] 169 | return reduced_feature, weights, [step_indices] 170 | 171 | 172 | def k_drop_feature(img_feature, video_max_frames, img_similarity=None): 173 | T, P, D = img_feature.shape 174 | indices = [[i] for i in range(T)] 175 | T0 = video_max_frames 176 | if T <= T0: 177 | return img_feature, img_similarity, [indices] 178 | cur_feature = img_feature[:T0] # [T0, P, D] 179 | normed_cur_features = F.normalize(cur_feature.view(T0, P * D), p=2, dim=1) 180 | cur_sim = torch.mm(normed_cur_features, normed_cur_features.T) # [T0, T0] 181 | cur_sim.fill_diagonal_(-100.0) 182 | cur_indices = indices[:T0] 183 | step_indices = [cur_indices] 184 | for i in range(T0, T): 185 | # get new feature 186 | new_feature = img_feature[i] 187 | normed_new_feature = F.normalize(new_feature.view(1, P * D), p=2, dim=1) 188 | new_sim = torch.mm(normed_cur_features, normed_new_feature.T) # [T0, 1] 189 | all_feature = torch.cat([cur_feature, new_feature.unsqueeze(0)], dim=0) 190 | normed_all_features = torch.cat([normed_cur_features, normed_new_feature], dim=0) 191 | all_indices = cur_indices + [[i]] 192 | # get new similarity 193 | all_sim_1 = torch.cat([cur_sim, new_sim], dim=1) # [T0, T0 + 1] 194 | all_sim = torch.cat([all_sim_1, torch.ones_like(all_sim_1[-1:]) * -100.0], dim=0) # [T0 + 1, T0 + 1] 195 | all_sim[-1, :-1] = new_sim.T 196 | # choose compression position 197 | idx = torch.argmax(all_sim) 198 | left, right = idx // (T0 + 1), idx % (T0 + 1) 199 | if random.randint(0, 1) > 0: 200 | idx = left 201 | else: 202 | idx = right 203 | assert all_sim[left, right] == torch.max(all_sim) 204 | # get compressed feature and similarity 205 | cur_feature = torch.cat([all_feature[:idx], all_feature[idx + 1:]]) 206 | normed_cur_features = torch.cat([normed_all_features[:idx], normed_all_features[idx + 1:]]) 207 | cur_indices = all_indices[:idx] + all_indices[idx + 1:] 208 | cur_sim_1 = torch.cat([all_sim[:idx], all_sim[idx + 1:]], dim=0) # [T0, T0 + 1] 209 | cur_sim = torch.cat([cur_sim_1[:, :idx], cur_sim_1[:, idx + 1:]], dim=1) # [T0, T0] 210 | step_indices.append(cur_indices) 211 | # print(f'Note: perform k-drop feature {img_feature.shape} to {cur_feature.shape}') 212 | return cur_feature, None, step_indices 213 | 214 | 215 | def k_merge_feature(img_feature, video_max_frames, img_similarity=None): 216 | T, P, D = img_feature.shape 217 | indices = [[i] for i in range(T)] 218 | T0 = video_max_frames 219 | if T <= T0: 220 | return img_feature, img_similarity, [indices] 221 | cur_feature = img_feature[:T0] # [T0, P, D] 222 | normed_cur_features = F.normalize(cur_feature.view(T0, P * D), p=2, dim=1) 223 | cur_sim = torch.mm(normed_cur_features, normed_cur_features.T) # [T0, T0] 224 | cur_sim.fill_diagonal_(-100.0) 225 | cur_indices = indices[:T0] 226 | step_indices = [cur_indices] 227 | for i in range(T0, T): 228 | # get new feature 229 | new_feature = img_feature[i] 230 | normed_new_feature = F.normalize(new_feature.view(1, P * D), p=2, dim=1) 231 | new_sim = torch.mm(normed_cur_features, normed_new_feature.T) # [T0, 1] 232 | all_feature = torch.cat([cur_feature, new_feature.unsqueeze(0)], dim=0) 233 | normed_all_features = torch.cat([normed_cur_features, normed_new_feature], dim=0) 234 | all_indices = cur_indices + [[i]] 235 | # get new similarity 236 | all_sim_1 = torch.cat([cur_sim, new_sim], dim=1) # [T0, T0 + 1] 237 | 
all_sim = torch.cat([all_sim_1, torch.ones_like(all_sim_1[-1:]) * -100.0], dim=0) # [T0 + 1, T0 + 1] 238 | all_sim[-1, :-1] = new_sim.T 239 | # choose compression position 240 | idx = torch.argmax(all_sim) 241 | left, right = idx // (T0 + 1), idx % (T0 + 1) 242 | assert all_sim[left, right] == torch.max(all_sim) 243 | # update feature 244 | all_feature[right] = (all_feature[left] + all_feature[right]) / 2.0 245 | normed_all_features[right] = F.normalize(all_feature[right].view(1, P * D), p=2, dim=1) 246 | all_indices[right] = all_indices[left] + all_indices[right] 247 | # update similarity 248 | new_sim = torch.mm(normed_all_features, normed_all_features[right:right+1].T) # [T0 + 1, 1] 249 | all_sim[right, :] = new_sim.T 250 | all_sim[:, right:right+1] = new_sim 251 | all_sim[right, right] = -100.0 252 | # get compressed feature and similarity 253 | cur_feature = torch.cat([all_feature[:left], all_feature[left + 1:]]) 254 | normed_cur_features = torch.cat([normed_all_features[:left], normed_all_features[left + 1:]]) 255 | cur_indices = all_indices[:left] + all_indices[left + 1:] 256 | cur_sim_1 = torch.cat([all_sim[:left], all_sim[left + 1:]], dim=0) # [T0, T0 + 1] 257 | cur_sim = torch.cat([cur_sim_1[:, :left], cur_sim_1[:, left + 1:]], dim=1) # [T0, T0] 258 | step_indices.append(cur_indices) 259 | # print(f'Note: perform k-merge feature {img_feature.shape} to {cur_feature.shape}') 260 | return cur_feature, cur_sim, step_indices 261 | 262 | 263 | def attention_feature(img_feature, video_max_frames, attention_fn=None, update_ratio=0.2): 264 | T, P, D = img_feature.shape 265 | T0 = video_max_frames 266 | if T <= T0: 267 | return img_feature, None 268 | cur_feature = img_feature[:T0] # [T0, P, D] 269 | turing_memory = cur_feature.reshape(T0*P, D) # [T0*P, D] 270 | for i in range(T0, T, T0): 271 | j = min(i + T0, T) 272 | new_feature = img_feature[i:j] # [P, D] 273 | new_feature = new_feature.reshape(-1, D) # [n*P, D] 274 | turing_memory = attention_fn(turing_memory, new_feature, update_ratio=update_ratio) # [T0*P, n*P] 275 | cur_feature = turing_memory.reshape(T0, P, D) 276 | # print(f'Note: perform {attention_fn.__name__} feature {img_feature.shape} to {cur_feature.shape}') 277 | return cur_feature, None 278 | -------------------------------------------------------------------------------- /flash_vstream/model/language_model/vstream_llama.py: -------------------------------------------------------------------------------- 1 | # This file may have been modified by Flash-VStream Authors (Flash-VStream Modifications”). All Flash-VStream Modifications are Copyright 2024 Flash-VStream Authors. 2 | # ------------------------------------------------------------------------ 3 | # Based on https://github.com/haotian-liu/LLaVA. Below is the original copyright: 4 | # Copyright 2023 Haotian Liu 5 | # 6 | # Licensed under the Apache License, Version 2.0 (the "License"); 7 | # you may not use this file except in compliance with the License. 8 | # You may obtain a copy of the License at 9 | # 10 | # http://www.apache.org/licenses/LICENSE-2.0 11 | # 12 | # Unless required by applicable law or agreed to in writing, software 13 | # distributed under the License is distributed on an "AS IS" BASIS, 14 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 15 | # See the License for the specific language governing permissions and 16 | # limitations under the License. 
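To make the memory-compression interface above concrete, here is an illustrative sketch (arbitrary sizes, not part of the model code) that runs two of the functions from compress_functions.py on random features shaped [T frames, P patches, D channels].

import torch
from flash_vstream.model.compress_functions import merge_feature, weighted_kmeans_feature

feats = torch.randn(32, 8, 16)  # T=32 incoming frame features, P=8 patches, D=16 channels

merged, sim, steps = merge_feature(feats, video_max_frames=12)
print(merged.shape)                                    # torch.Size([12, 8, 16])
print(len(steps[-1]), sum(len(g) for g in steps[-1]))  # 12 groups covering all 32 source frames

reduced, weights, (groups,) = weighted_kmeans_feature(feats, video_max_frames=12)
print(reduced.shape, weights.shape)                    # torch.Size([12, 8, 16]) torch.Size([12])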
17 | 18 | import torch 19 | import torch.nn as nn 20 | import torch.nn.functional as F 21 | from typing import List, Optional, Tuple, Union 22 | from transformers import AutoConfig, AutoModelForCausalLM, \ 23 | LlamaConfig, LlamaModel, LlamaForCausalLM 24 | from transformers.modeling_outputs import CausalLMOutputWithPast 25 | from flash_vstream.model.vstream_arch import VStreamMetaModel, VStreamMetaForCausalLM 26 | 27 | 28 | class VStreamConfig(LlamaConfig): 29 | model_type = "vstream" 30 | 31 | 32 | class VStreamLlamaModel(VStreamMetaModel, LlamaModel): 33 | config_class = VStreamConfig 34 | 35 | def __init__(self, config: LlamaConfig): 36 | super(VStreamLlamaModel, self).__init__(config) 37 | 38 | 39 | class VStreamLlamaForCausalLM(VStreamMetaForCausalLM, LlamaForCausalLM): 40 | config_class = VStreamConfig 41 | 42 | def __init__(self, config): 43 | super(VStreamLlamaForCausalLM, self).__init__(config) 44 | self.model = VStreamLlamaModel(config) 45 | self.pretraining_tp = config.pretraining_tp 46 | self.vocab_size = config.vocab_size 47 | self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) 48 | 49 | # Initialize weights and apply final processing 50 | self.post_init() 51 | 52 | def get_model(self): 53 | return self.model 54 | 55 | def forward( 56 | self, 57 | input_ids: torch.LongTensor = None, 58 | attention_mask: Optional[torch.Tensor] = None, 59 | position_ids: Optional[torch.LongTensor] = None, 60 | past_key_values: Optional[List[torch.FloatTensor]] = None, 61 | inputs_embeds: Optional[torch.FloatTensor] = None, 62 | labels: Optional[torch.LongTensor] = None, 63 | use_cache: Optional[bool] = True, 64 | output_attentions: Optional[bool] = None, 65 | output_hidden_states: Optional[bool] = None, 66 | images: Optional[torch.FloatTensor] = None, 67 | features: Optional[torch.FloatTensor] = None, 68 | return_dict: Optional[bool] = None, 69 | ) -> Union[Tuple, CausalLMOutputWithPast]: 70 | if inputs_embeds is None: 71 | if self.use_video_streaming_mode: 72 | ( 73 | input_ids, 74 | position_ids, 75 | attention_mask, 76 | past_key_values, 77 | inputs_embeds, 78 | labels 79 | ) = self.prepare_inputs_labels_for_multimodal_streaming( 80 | input_ids, 81 | position_ids, 82 | attention_mask, 83 | past_key_values, 84 | labels, 85 | ) 86 | else: 87 | ( 88 | input_ids, 89 | position_ids, 90 | attention_mask, 91 | past_key_values, 92 | inputs_embeds, 93 | labels 94 | ) = self.prepare_inputs_labels_for_multimodal( 95 | input_ids, 96 | position_ids, 97 | attention_mask, 98 | past_key_values, 99 | labels, 100 | images, 101 | features, 102 | ) 103 | return super().forward( 104 | input_ids=input_ids, 105 | attention_mask=attention_mask, 106 | position_ids=position_ids, 107 | past_key_values=past_key_values, 108 | inputs_embeds=inputs_embeds, 109 | labels=labels, 110 | use_cache=use_cache, 111 | output_attentions=output_attentions, 112 | output_hidden_states=output_hidden_states, 113 | return_dict=return_dict 114 | ) 115 | 116 | def prepare_inputs_for_generation(self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs): 117 | images = kwargs.pop("images", None) 118 | features = kwargs.pop("features", None) 119 | _inputs = super().prepare_inputs_for_generation( 120 | input_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, **kwargs 121 | ) 122 | if images is not None: 123 | _inputs['images'] = images 124 | if features is not None: 125 | _inputs['features'] = features 126 | return _inputs 127 | 128 | AutoConfig.register("vstream", VStreamConfig) 129 | 
AutoModelForCausalLM.register(VStreamConfig, VStreamLlamaForCausalLM) 130 | -------------------------------------------------------------------------------- /flash_vstream/model/multimodal_encoder/builder.py: -------------------------------------------------------------------------------- 1 | # Based on https://github.com/haotian-liu/LLaVA. 2 | 3 | import os 4 | from .clip_encoder import CLIPVisionTower 5 | 6 | 7 | def build_vision_tower(vision_tower_cfg, **kwargs): 8 | vision_tower = getattr(vision_tower_cfg, 'mm_vision_tower', getattr(vision_tower_cfg, 'vision_tower', None)) 9 | is_absolute_path_exists = os.path.exists(vision_tower) 10 | if is_absolute_path_exists or vision_tower.startswith("openai") or vision_tower.startswith("laion"): 11 | return CLIPVisionTower(vision_tower, args=vision_tower_cfg, **kwargs) 12 | 13 | raise ValueError(f'Unknown vision tower: {vision_tower}') 14 | -------------------------------------------------------------------------------- /flash_vstream/model/multimodal_encoder/clip_encoder.py: -------------------------------------------------------------------------------- 1 | # Based on https://github.com/haotian-liu/LLaVA. 2 | 3 | import torch 4 | import torch.nn as nn 5 | 6 | from transformers import CLIPVisionModel, CLIPImageProcessor, CLIPVisionConfig 7 | 8 | 9 | class CLIPVisionTower(nn.Module): 10 | def __init__(self, vision_tower, args, delay_load=False): 11 | super().__init__() 12 | 13 | self.is_loaded = False 14 | 15 | self.vision_tower_name = vision_tower 16 | self.select_layer = args.mm_vision_select_layer 17 | self.select_feature = getattr(args, 'mm_vision_select_feature', 'patch') 18 | 19 | if not delay_load: 20 | self.load_model() 21 | else: 22 | self.cfg_only = CLIPVisionConfig.from_pretrained(self.vision_tower_name) 23 | 24 | def load_model(self): 25 | self.image_processor = CLIPImageProcessor.from_pretrained(self.vision_tower_name) 26 | self.vision_tower = CLIPVisionModel.from_pretrained(self.vision_tower_name) 27 | self.vision_tower.requires_grad_(False) 28 | 29 | self.is_loaded = True 30 | 31 | def feature_select(self, image_forward_outs): 32 | image_features = image_forward_outs.hidden_states[self.select_layer] 33 | if self.select_feature == 'patch': 34 | image_features = image_features[:, 1:] 35 | elif self.select_feature == 'cls_patch': 36 | image_features = image_features 37 | else: 38 | raise ValueError(f'Unexpected select feature: {self.select_feature}') 39 | return image_features 40 | 41 | @torch.no_grad() 42 | def forward(self, images): 43 | if type(images) is list: 44 | image_features = [] 45 | for image in images: 46 | image_forward_out = self.vision_tower(image.to(device=self.device, dtype=self.dtype).unsqueeze(0), output_hidden_states=True) 47 | image_feature = self.feature_select(image_forward_out).to(image.dtype) 48 | image_features.append(image_feature) 49 | else: 50 | image_forward_outs = self.vision_tower(images.to(device=self.device, dtype=self.dtype), output_hidden_states=True) 51 | image_features = self.feature_select(image_forward_outs).to(images.dtype) 52 | 53 | return image_features 54 | 55 | @property 56 | def dummy_feature(self): 57 | return torch.zeros(1, self.hidden_size, device=self.device, dtype=self.dtype) 58 | 59 | @property 60 | def dtype(self): 61 | return self.vision_tower.dtype 62 | 63 | @property 64 | def device(self): 65 | return self.vision_tower.device 66 | 67 | @property 68 | def config(self): 69 | if self.is_loaded: 70 | return self.vision_tower.config 71 | else: 72 | return self.cfg_only 73 | 74 | 
@property 75 | def hidden_size(self): 76 | return self.config.hidden_size 77 | 78 | @property 79 | def num_patches(self): 80 | return (self.config.image_size // self.config.patch_size) ** 2 81 | -------------------------------------------------------------------------------- /flash_vstream/model/multimodal_projector/builder.py: -------------------------------------------------------------------------------- 1 | # Based on https://github.com/haotian-liu/LLaVA. 2 | 3 | import torch 4 | import torch.nn as nn 5 | import re 6 | 7 | 8 | class IdentityMap(nn.Module): 9 | def __init__(self): 10 | super().__init__() 11 | 12 | def forward(self, x, *args, **kwargs): 13 | return x 14 | 15 | @property 16 | def config(self): 17 | return {"mm_projector_type": 'identity'} 18 | 19 | 20 | class SimpleResBlock(nn.Module): 21 | def __init__(self, channels): 22 | super().__init__() 23 | self.pre_norm = nn.LayerNorm(channels) 24 | 25 | self.proj = nn.Sequential( 26 | nn.Linear(channels, channels), 27 | nn.GELU(), 28 | nn.Linear(channels, channels) 29 | ) 30 | def forward(self, x): 31 | x = self.pre_norm(x) 32 | return x + self.proj(x) 33 | 34 | 35 | def build_vision_projector(config, input_dim, delay_load=False, **kwargs): 36 | projector_type = getattr(config, 'mm_projector_type', 'linear') 37 | 38 | if projector_type == 'linear': 39 | return nn.Linear(input_dim, config.hidden_size) 40 | mlp_gelu_match = re.match(r'^mlp(\d+)x_gelu$', projector_type) 41 | if mlp_gelu_match: 42 | mlp_depth = int(mlp_gelu_match.group(1)) 43 | modules = [nn.Linear(input_dim, config.hidden_size)] 44 | for _ in range(1, mlp_depth): 45 | modules.append(nn.GELU()) 46 | modules.append(nn.Linear(config.hidden_size, config.hidden_size)) 47 | return nn.Sequential(*modules) 48 | if projector_type == 'identity': 49 | return IdentityMap() 50 | 51 | raise ValueError(f'Unknown projector type: {projector_type}') 52 | -------------------------------------------------------------------------------- /flash_vstream/serve/cli_video_stream.py: -------------------------------------------------------------------------------- 1 | # This file may have been modified by Flash-VStream Authors ("Flash-VStream Modifications"). All Flash-VStream Modifications are Copyright 2024 Flash-VStream Authors. 2 | # Based on https://github.com/haotian-liu/LLaVA. 3 | """ 4 | This file demonstrates a multiprocess implementation of a real-time long video understanding system, with a multiprocess logging module.
5 | main process: CLI server I/O, LLM inference 6 | process-1: logger listener 7 | process-2: frame generator, 8 | process-3: frame memory manager 9 | Author: Haoji Zhang, Haotian Liu 10 | (This code is based on https://github.com/haotian-liu/LLaVA) 11 | """ 12 | import argparse 13 | import requests 14 | import logging 15 | import torch 16 | import numpy as np 17 | import time 18 | import os 19 | 20 | from flash_vstream.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN 21 | from flash_vstream.conversation import conv_templates, SeparatorStyle 22 | from flash_vstream.model.builder import load_pretrained_model 23 | from flash_vstream.utils import disable_torch_init 24 | from flash_vstream.mm_utils import process_images, tokenizer_image_token, get_model_name_from_path, KeywordsStoppingCriteria 25 | 26 | from torch.multiprocessing import Process, Queue, Manager 27 | from transformers import TextStreamer 28 | from decord import VideoReader 29 | from datetime import datetime 30 | from PIL import Image 31 | from io import BytesIO 32 | 33 | class _Metric: 34 | def __init__(self): 35 | self._latest_value = None 36 | self._sum = 0.0 37 | self._max = 0.0 38 | self._count = 0 39 | 40 | @property 41 | def val(self): 42 | return self._latest_value 43 | 44 | @property 45 | def max(self): 46 | return self._max 47 | 48 | @property 49 | def avg(self): 50 | if self._count == 0: 51 | return float('nan') 52 | return self._sum / self._count 53 | 54 | def add(self, value): 55 | self._latest_value = value 56 | self._sum += value 57 | self._count += 1 58 | if value > self._max: 59 | self._max = value 60 | 61 | def __str__(self): 62 | latest_formatted = f"{self.val:.6f}" if self.val is not None else "None" 63 | average_formatted = f"{self.avg:.6f}" 64 | max_formatted = f"{self.max:.6f}" 65 | return f"{latest_formatted} ({average_formatted}, {max_formatted})" 66 | 67 | 68 | class MetricMeter: 69 | def __init__(self): 70 | self._metrics = {} 71 | 72 | def add(self, key, value): 73 | if key not in self._metrics: 74 | self._metrics[key] = _Metric() 75 | self._metrics[key].add(value) 76 | 77 | def val(self, key): 78 | metric = self._metrics.get(key) 79 | if metric is None or metric.val is None: 80 | raise ValueError(f"No values have been added for key '{key}'.") 81 | return metric.val 82 | 83 | def avg(self, key): 84 | metric = self._metrics.get(key) 85 | if metric is None: 86 | raise ValueError(f"No values have been added for key '{key}'.") 87 | return metric.avg 88 | 89 | def max(self, key): 90 | metric = self._metrics.get(key) 91 | if metric is None: 92 | raise ValueError(f"No values have been added for key '{key}'.") 93 | return metric.max 94 | 95 | def __getitem__(self, key): 96 | metric = self._metrics.get(key) 97 | if metric is None: 98 | raise KeyError(f"The key '{key}' does not exist.") 99 | return str(metric) 100 | 101 | def load_image(image_file): 102 | if image_file.startswith('http://') or image_file.startswith('https://'): 103 | response = requests.get(image_file) 104 | image = Image.open(BytesIO(response.content)).convert('RGB') 105 | else: 106 | image = Image.open(image_file).convert('RGB') 107 | return image 108 | 109 | def listener(queue, filename): 110 | ############## Start sub process-1: Listener ############# 111 | import sys, traceback 112 | root = logging.getLogger() 113 | root.setLevel(logging.DEBUG) 114 | # h = logging.StreamHandler(sys.stdout) 115 | h = logging.FileHandler(filename) 116 | f = logging.Formatter('%(asctime)s %(processName)-10s 
%(name)s %(levelname)-8s %(message)s') 117 | h.setFormatter(f) 118 | root.addHandler(h) 119 | while True: 120 | try: 121 | record = queue.get() 122 | if record is None: # None is a signal to finish 123 | break 124 | logger = logging.getLogger(record.name) 125 | logger.handle(record) # No level or filter logic applied - just do it! 126 | except Exception: 127 | import sys, traceback 128 | print('Whoops! Problem:', file=sys.stderr) 129 | traceback.print_exc(file=sys.stderr) 130 | 131 | def worker_configurer(queue): 132 | h = logging.handlers.QueueHandler(queue) # Just the one handler needed 133 | root = logging.getLogger() 134 | root.addHandler(h) 135 | root.setLevel(logging.DEBUG) 136 | 137 | def video_stream_similator(video_file, frame_queue, log_queue, video_fps=1.0, play_speed=1.0): 138 | ############## Start sub process-2: Simulator ############# 139 | worker_configurer(log_queue) 140 | logger = logging.getLogger(__name__) 141 | logger.setLevel(logging.DEBUG) 142 | 143 | vr = VideoReader(video_file) 144 | sample_fps = round(vr.get_avg_fps() / video_fps) 145 | frame_idx = [i for i in range(0, len(vr), sample_fps)] 146 | video = vr.get_batch(frame_idx).asnumpy() 147 | video = np.repeat(video, 6, axis=0) 148 | length = video.shape[0] 149 | sleep_time = 1 / video_fps / play_speed 150 | time_meter = MetricMeter() 151 | logger.info(f'Simulator Process: start, length = {length}') 152 | try: 153 | for start in range(0, length): 154 | start_time = time.perf_counter() 155 | end = min(start + 1, length) 156 | video_clip = video[start:end] 157 | frame_queue.put(video_clip) 158 | if start > 0: 159 | time_meter.add('real_sleep', start_time - last_start) 160 | logger.info(f'Simulator: write {end - start} frames,\t{start} to {end},\treal_sleep={time_meter["real_sleep"]}') 161 | if end < length: 162 | time.sleep(sleep_time) 163 | last_start = start_time 164 | frame_queue.put(None) 165 | except Exception as e: 166 | print(f'Simulator Exception: {e}') 167 | time.sleep(0.1) 168 | logger.info(f'Simulator Process: end') 169 | 170 | def frame_memory_manager(model, image_processor, frame_queue, log_queue): 171 | ############## Start sub process-3: Memory Manager ############# 172 | worker_configurer(log_queue) 173 | logger = logging.getLogger(__name__) 174 | logger.setLevel(logging.DEBUG) 175 | 176 | time_meter = MetricMeter() 177 | logger.info(f'MemManager Process: start') 178 | frame_cnt = 0 179 | while True: 180 | try: 181 | video_clip = frame_queue.get() 182 | start_time = time.perf_counter() 183 | if video_clip is None: 184 | logger.info(f'MemManager: Ooops, get None') 185 | break 186 | logger.info(f'MemManager: get {video_clip.shape[0]} frames from queue') 187 | image = image_processor.preprocess(video_clip, return_tensors='pt')['pixel_values'] 188 | image = image.unsqueeze(0) 189 | image_tensor = image.to(model.device, dtype=torch.float16) 190 | # time_2 = time.perf_counter() 191 | logger.info(f'MemManager: Start embedding') 192 | with torch.inference_mode(): 193 | model.embed_video_streaming(image_tensor) 194 | logger.info(f'MemManager: End embedding') 195 | end_time = time.perf_counter() 196 | if frame_cnt > 0: 197 | time_meter.add('memory_latency', end_time - start_time) 198 | logger.info(f'MemManager: embedded {video_clip.shape[0]} frames,\tidx={frame_cnt},\tmemory_latency={time_meter["memory_latency"]}') 199 | else: 200 | logger.info(f'MemManager: embedded {video_clip.shape[0]} frames,\tidx={frame_cnt},\tmemory_latency={end_time - start_time:.6f}, not logged') 201 | frame_cnt += video_clip.shape[0] 
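                # The first clip's latency is printed but kept out of the running average (frame_cnt == 0 branch above),
                # so one-off warm-up cost does not skew memory_latency. The embedded features are shared with the main
                # process through model.video_embedding_memory, the multiprocessing Manager list attached in main().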
202 | except Exception as e: 203 | print(f'MemManager Exception: {e}') 204 | time.sleep(0.1) 205 | logger.info(f'MemManager Process: end') 206 | 207 | def main(args): 208 | # torch.multiprocessing.log_to_stderr(logging.DEBUG) 209 | torch.multiprocessing.set_start_method('spawn', force=True) 210 | disable_torch_init() 211 | 212 | log_queue = Queue() 213 | frame_queue = Queue(maxsize=10) 214 | processes = [] 215 | 216 | ############## Start listener process ############# 217 | p1 = Process(target=listener, args=(log_queue, args.log_file)) 218 | processes.append(p1) 219 | p1.start() 220 | 221 | ############## Start main process ############# 222 | worker_configurer(log_queue) 223 | logger = logging.getLogger(__name__) 224 | 225 | model_name = get_model_name_from_path(args.model_path) 226 | tokenizer, model, image_processor, context_len = load_pretrained_model(args.model_path, args.model_base, model_name, args.load_8bit, args.load_4bit, device=args.device) 227 | 228 | logger.info(f'Using conv_mode={args.conv_mode}') 229 | 230 | conv = conv_templates[args.conv_mode].copy() 231 | if "mpt" in model_name.lower(): 232 | roles = ('user', 'assistant') 233 | else: 234 | roles = conv.roles 235 | 236 | with Manager() as manager: 237 | image_tensor = None 238 | model.use_video_streaming_mode = True 239 | model.video_embedding_memory = manager.list() 240 | if args.video_max_frames is not None: 241 | model.config.video_max_frames = args.video_max_frames 242 | logger.info(f'Important: set model.config.video_max_frames = {model.config.video_max_frames}') 243 | 244 | logger.info(f'Important: set video_fps = {args.video_fps}') 245 | logger.info(f'Important: set play_speed = {args.play_speed}') 246 | 247 | ############## Start simulator process ############# 248 | p2 = Process(target=video_stream_similator, 249 | args=(args.video_file, frame_queue, log_queue, args.video_fps, args.play_speed)) 250 | processes.append(p2) 251 | p2.start() 252 | 253 | ############## Start memory manager process ############# 254 | p3 = Process(target=frame_memory_manager, 255 | args=(model, image_processor, frame_queue, log_queue)) 256 | processes.append(p3) 257 | p3.start() 258 | 259 | # start QA server 260 | start_time = datetime.now() 261 | time_meter = MetricMeter() 262 | conv_cnt = 0 263 | while True: 264 | time.sleep(5) 265 | try: 266 | # inp = input(f"{roles[0]}: ") 267 | inp = "what is in the video?" 
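                # Note that the interactive input() above is commented out, so this demo loop re-asks the same fixed
                # question every 5 seconds while the stream keeps being embedded in the background; restore the input()
                # call (and drop the fixed string) for genuinely interactive use.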
268 | except EOFError: 269 | inp = "" 270 | if not inp: 271 | print("exit...") 272 | break 273 | 274 | # 获取当前时间 275 | now = datetime.now() 276 | conv_start_time = time.perf_counter() 277 | # 将当前时间格式化为字符串 278 | current_time = now.strftime("%H:%M:%S") 279 | duration = now.timestamp() - start_time.timestamp() 280 | 281 | # 打印当前时间 282 | print("\nCurrent Time:", current_time, "Run for:", duration) 283 | print(f"{roles[0]}: {inp}", end="\n") 284 | print(f"{roles[1]}: ", end="") 285 | # every conversation is a new conversation 286 | conv = conv_templates[args.conv_mode].copy() 287 | inp = DEFAULT_IMAGE_TOKEN + '\n' + inp 288 | conv.append_message(conv.roles[0], inp) 289 | 290 | conv.append_message(conv.roles[1], None) 291 | prompt = conv.get_prompt() 292 | 293 | input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).to(model.device) 294 | stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2 295 | keywords = [stop_str] 296 | stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids) 297 | streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True) 298 | 299 | llm_start_time = time.perf_counter() 300 | with torch.inference_mode(): 301 | output_ids = model.generate( 302 | input_ids, 303 | images=image_tensor, 304 | do_sample=True if args.temperature > 0 else False, 305 | temperature=args.temperature, 306 | max_new_tokens=args.max_new_tokens, 307 | streamer=streamer, 308 | use_cache=True, 309 | stopping_criteria=[stopping_criteria] 310 | ) 311 | llm_end_time = time.perf_counter() 312 | 313 | outputs = tokenizer.decode(output_ids[0, input_ids.shape[1]:]).strip() 314 | conv.messages[-1][-1] = outputs 315 | conv_end_time = time.perf_counter() 316 | if conv_cnt > 0: 317 | time_meter.add('conv_latency', conv_end_time - conv_start_time) 318 | time_meter.add('llm_latency', llm_end_time - llm_start_time) 319 | time_meter.add('real_sleep', conv_start_time - last_conv_start_time) 320 | logger.info(f'CliServer: idx={conv_cnt},\treal_sleep={time_meter["real_sleep"]},\tconv_latency={time_meter["conv_latency"]},\tllm_latency={time_meter["llm_latency"]}') 321 | else: 322 | logger.info(f'CliServer: idx={conv_cnt},\tconv_latency={conv_end_time - conv_start_time},\tllm_latency={llm_end_time - llm_start_time}') 323 | conv_cnt += 1 324 | last_conv_start_time = conv_start_time 325 | 326 | for p in processes: 327 | p.terminate() 328 | print("All processes finished.") 329 | 330 | 331 | if __name__ == "__main__": 332 | parser = argparse.ArgumentParser() 333 | parser.add_argument("--model-path", type=str, default="facebook/opt-350m") 334 | parser.add_argument("--model-base", type=str, default=None) 335 | parser.add_argument("--image-file", type=str, default=None) 336 | parser.add_argument("--video-file", type=str, default=None) 337 | parser.add_argument("--device", type=str, default="cuda") 338 | parser.add_argument("--conv-mode", type=str, default="vicuna_v1") 339 | parser.add_argument("--temperature", type=float, default=0.2) 340 | parser.add_argument("--max-new-tokens", type=int, default=512) 341 | parser.add_argument("--load-8bit", action="store_true") 342 | parser.add_argument("--load-4bit", action="store_true") 343 | parser.add_argument("--debug", action="store_true") 344 | 345 | parser.add_argument("--log-file", type=str, default="tmp_cli.log") 346 | parser.add_argument("--use_1process", action="store_true") 347 | parser.add_argument("--video_max_frames", type=int, default=None) 348 | 
parser.add_argument("--video_fps", type=float, default=1.0) 349 | parser.add_argument("--play_speed", type=float, default=1.0) 350 | args = parser.parse_args() 351 | main(args) 352 | -------------------------------------------------------------------------------- /flash_vstream/train/llama_flash_attn_monkey_patch.py: -------------------------------------------------------------------------------- 1 | # Based on https://github.com/haotian-liu/LLaVA. 2 | 3 | from typing import Optional, Tuple 4 | import warnings 5 | 6 | import torch 7 | 8 | import transformers 9 | from transformers.models.llama.modeling_llama import apply_rotary_pos_emb, repeat_kv 10 | 11 | try: 12 | from flash_attn.flash_attn_interface import flash_attn_unpadded_qkvpacked_func 13 | except ImportError: 14 | from flash_attn.flash_attn_interface import flash_attn_varlen_qkvpacked_func as flash_attn_unpadded_qkvpacked_func 15 | from flash_attn.bert_padding import unpad_input, pad_input 16 | 17 | 18 | def forward( 19 | self, 20 | hidden_states: torch.Tensor, 21 | attention_mask: Optional[torch.Tensor] = None, 22 | position_ids: Optional[torch.Tensor] = None, 23 | past_key_value: Optional[Tuple[torch.Tensor]] = None, 24 | output_attentions: bool = False, 25 | use_cache: bool = False, 26 | ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: 27 | if output_attentions: 28 | warnings.warn( 29 | "Output attentions is not supported for patched `LlamaAttention`, returning `None` instead." 30 | ) 31 | 32 | bsz, q_len, _ = hidden_states.size() 33 | 34 | query_states = ( 35 | self.q_proj(hidden_states) 36 | .view(bsz, q_len, self.num_heads, self.head_dim) 37 | .transpose(1, 2) 38 | ) 39 | key_states = ( 40 | self.k_proj(hidden_states) 41 | .view(bsz, q_len, self.num_key_value_heads, self.head_dim) 42 | .transpose(1, 2) 43 | ) 44 | value_states = ( 45 | self.v_proj(hidden_states) 46 | .view(bsz, q_len, self.num_key_value_heads, self.head_dim) 47 | .transpose(1, 2) 48 | ) # shape: (b, num_heads, s, head_dim) 49 | 50 | kv_seq_len = key_states.shape[-2] 51 | if past_key_value is not None: 52 | kv_seq_len += past_key_value[0].shape[-2] 53 | 54 | cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) 55 | query_states, key_states = apply_rotary_pos_emb( 56 | query_states, key_states, cos, sin, position_ids 57 | ) 58 | 59 | if past_key_value is not None: 60 | # reuse k, v 61 | key_states = torch.cat([past_key_value[0], key_states], dim=2) 62 | value_states = torch.cat([past_key_value[1], value_states], dim=2) 63 | 64 | past_key_value = (key_states, value_states) if use_cache else None 65 | 66 | # repeat k/v heads if n_kv_heads < n_heads 67 | key_states = repeat_kv(key_states, self.num_key_value_groups) 68 | value_states = repeat_kv(value_states, self.num_key_value_groups) 69 | 70 | # Transform the data into the format required by flash attention 71 | qkv = torch.stack([query_states, key_states, value_states], dim=2) 72 | qkv = qkv.transpose(1, 3) # shape: [b, s, 3, num_heads, head_dim] 73 | key_padding_mask = attention_mask 74 | 75 | if key_padding_mask is None: 76 | qkv = qkv.reshape(-1, 3, self.num_heads, self.head_dim) 77 | cu_q_lens = torch.arange( 78 | 0, (bsz + 1) * q_len, step=q_len, dtype=torch.int32, device=qkv.device 79 | ) 80 | max_s = q_len 81 | output = flash_attn_unpadded_qkvpacked_func( 82 | qkv, cu_q_lens, max_s, 0.0, softmax_scale=None, causal=True 83 | ) 84 | output = output.view(bsz, q_len, -1) 85 | else: 86 | qkv = qkv.reshape(bsz, q_len, -1) 87 | qkv, indices, cu_q_lens, max_s = 
unpad_input(qkv, key_padding_mask) 88 | qkv = qkv.view(-1, 3, self.num_heads, self.head_dim) 89 | output_unpad = flash_attn_unpadded_qkvpacked_func( 90 | qkv, cu_q_lens, max_s, 0.0, softmax_scale=None, causal=True 91 | ) 92 | output_unpad = output_unpad.reshape(-1, self.num_heads * self.head_dim) 93 | output = pad_input(output_unpad, indices, bsz, q_len) 94 | 95 | return self.o_proj(output), None, past_key_value 96 | 97 | 98 | # Disable the transformation of the attention mask in LlamaModel as the flash attention 99 | # requires the attention mask to be the same as the key_padding_mask 100 | def _prepare_decoder_attention_mask( 101 | self, attention_mask, input_shape, inputs_embeds, past_key_values_length 102 | ): 103 | # [bsz, seq_len] 104 | return attention_mask 105 | 106 | 107 | def replace_llama_attn_with_flash_attn(): 108 | cuda_major, cuda_minor = torch.cuda.get_device_capability() 109 | if cuda_major < 8: 110 | warnings.warn( 111 | "Flash attention is only supported on A100 or H100 GPU during training due to head dim > 64 backward." 112 | "ref: https://github.com/HazyResearch/flash-attention/issues/190#issuecomment-1523359593" 113 | ) 114 | transformers.models.llama.modeling_llama.LlamaModel._prepare_decoder_attention_mask = ( 115 | _prepare_decoder_attention_mask 116 | ) 117 | transformers.models.llama.modeling_llama.LlamaAttention.forward = forward 118 | -------------------------------------------------------------------------------- /flash_vstream/train/llama_xformers_attn_monkey_patch.py: -------------------------------------------------------------------------------- 1 | # Based on https://github.com/haotian-liu/LLaVA. 2 | 3 | """ 4 | Directly copied the code from https://raw.githubusercontent.com/oobabooga/text-generation-webui/main/modules/llama_attn_hijack.py and made some adjustments 5 | """ 6 | 7 | import logging 8 | import math 9 | from typing import Optional, Tuple 10 | 11 | import torch 12 | import transformers.models.llama.modeling_llama 13 | from torch import nn 14 | 15 | try: 16 | import xformers.ops 17 | except ImportError: 18 | logging.error("xformers not found! 
Please install it before trying to use it.") 19 | 20 | 21 | def replace_llama_attn_with_xformers_attn(): 22 | transformers.models.llama.modeling_llama.LlamaAttention.forward = xformers_forward 23 | 24 | 25 | def xformers_forward( 26 | self, 27 | hidden_states: torch.Tensor, 28 | attention_mask: Optional[torch.Tensor] = None, 29 | position_ids: Optional[torch.LongTensor] = None, 30 | past_key_value: Optional[Tuple[torch.Tensor]] = None, 31 | output_attentions: bool = False, 32 | use_cache: bool = False, 33 | ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: 34 | # pylint: disable=duplicate-code 35 | bsz, q_len, _ = hidden_states.size() 36 | 37 | query_states = ( 38 | self.q_proj(hidden_states) 39 | .view(bsz, q_len, self.num_heads, self.head_dim) 40 | .transpose(1, 2) 41 | ) 42 | key_states = ( 43 | self.k_proj(hidden_states) 44 | .view(bsz, q_len, self.num_heads, self.head_dim) 45 | .transpose(1, 2) 46 | ) 47 | value_states = ( 48 | self.v_proj(hidden_states) 49 | .view(bsz, q_len, self.num_heads, self.head_dim) 50 | .transpose(1, 2) 51 | ) 52 | 53 | kv_seq_len = key_states.shape[-2] 54 | if past_key_value is not None: 55 | kv_seq_len += past_key_value[0].shape[-2] 56 | cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) 57 | ( 58 | query_states, 59 | key_states, 60 | ) = transformers.models.llama.modeling_llama.apply_rotary_pos_emb( 61 | query_states, key_states, cos, sin, position_ids 62 | ) 63 | # [bsz, nh, t, hd] 64 | 65 | if past_key_value is not None: 66 | # reuse k, v, self_attention 67 | key_states = torch.cat([past_key_value[0], key_states], dim=2) 68 | value_states = torch.cat([past_key_value[1], value_states], dim=2) 69 | 70 | past_key_value = (key_states, value_states) if use_cache else None 71 | 72 | # We only apply xformers optimizations if we don't need to output the whole attention matrix 73 | if not output_attentions: 74 | query_states = query_states.transpose(1, 2) 75 | key_states = key_states.transpose(1, 2) 76 | value_states = value_states.transpose(1, 2) 77 | 78 | # This is a nasty hack. We know attention_mask in transformers is either LowerTriangular or all Zeros. 79 | # We therefore check if one element in the upper triangular portion is zero. If it is, then the mask is all zeros. 
80 | if attention_mask is None or attention_mask[0, 0, 0, 1] == 0: 81 | # input and output should be of form (bsz, q_len, num_heads, head_dim) 82 | attn_output = xformers.ops.memory_efficient_attention( 83 | query_states, key_states, value_states, attn_bias=None 84 | ) 85 | else: 86 | # input and output should be of form (bsz, q_len, num_heads, head_dim) 87 | attn_output = xformers.ops.memory_efficient_attention( 88 | query_states, 89 | key_states, 90 | value_states, 91 | attn_bias=xformers.ops.LowerTriangularMask(), 92 | ) 93 | attn_weights = None 94 | else: 95 | attn_weights = torch.matmul( 96 | query_states, key_states.transpose(2, 3) 97 | ) / math.sqrt(self.head_dim) 98 | 99 | if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len): 100 | raise ValueError( 101 | f"Attention weights should be of size {(bsz * self.num_heads, q_len, kv_seq_len)}, but is" 102 | f" {attn_weights.size()}" 103 | ) 104 | 105 | if attention_mask is not None: 106 | if attention_mask.size() != (bsz, 1, q_len, kv_seq_len): 107 | raise ValueError( 108 | f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}" 109 | ) 110 | attn_weights = attn_weights + attention_mask 111 | attn_weights = torch.max( 112 | attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min) 113 | ) 114 | 115 | # upcast attention to fp32 116 | attn_weights = nn.functional.softmax( 117 | attn_weights, dim=-1, dtype=torch.float32 118 | ).to(query_states.dtype) 119 | attn_output = torch.matmul(attn_weights, value_states) 120 | 121 | if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim): 122 | raise ValueError( 123 | f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is" 124 | f" {attn_output.size()}" 125 | ) 126 | 127 | attn_output = attn_output.transpose(1, 2) 128 | 129 | attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) 130 | attn_output = self.o_proj(attn_output) 131 | return attn_output, attn_weights, past_key_value 132 | -------------------------------------------------------------------------------- /flash_vstream/train/train_mem.py: -------------------------------------------------------------------------------- 1 | # Adopted from https://github.com/haotian-liu/LLaVA. 2 | # Adopted from https://github.com/lm-sys/FastChat. Below is the original copyright: 3 | # Adopted from tatsu-lab@stanford_alpaca. Below is the original copyright: 4 | # Make it more memory efficient by monkey patching the LLaMA model with FlashAttn. 5 | 6 | # Need to call this before importing transformers. 7 | from flash_vstream.train.llama_flash_attn_monkey_patch import replace_llama_attn_with_flash_attn 8 | 9 | replace_llama_attn_with_flash_attn() 10 | 11 | from flash_vstream.train.train import train 12 | 13 | if __name__ == "__main__": 14 | train() 15 | -------------------------------------------------------------------------------- /flash_vstream/train/train_xformers.py: -------------------------------------------------------------------------------- 1 | # Based on https://github.com/haotian-liu/LLaVA. 2 | 3 | # Make it more memory efficient by monkey patching the LLaMA model with xformers attention. 4 | 5 | # Need to call this before importing transformers. 
6 | from flash_vstream.train.llama_xformers_attn_monkey_patch import ( 7 | replace_llama_attn_with_xformers_attn, 8 | ) 9 | 10 | replace_llama_attn_with_xformers_attn() 11 | 12 | from flash_vstream.train.train import train 13 | 14 | if __name__ == "__main__": 15 | train() 16 | -------------------------------------------------------------------------------- /flash_vstream/train/vstream_trainer.py: -------------------------------------------------------------------------------- 1 | # This file may have been modified by Flash-VStream Authors (Flash-VStream Modifications”). All Flash-VStream Modifications are Copyright 2024 Flash-VStream Authors. 2 | # ------------------------------------------------------------------------ 3 | # Based on https://github.com/haotian-liu/LLaVA. Below is the original copyright: 4 | # Copyright 2023 Haotian Liu 5 | # 6 | # Licensed under the Apache License, Version 2.0 (the "License"); 7 | # you may not use this file except in compliance with the License. 8 | # You may obtain a copy of the License at 9 | # 10 | # http://www.apache.org/licenses/LICENSE-2.0 11 | # 12 | # Unless required by applicable law or agreed to in writing, software 13 | # distributed under the License is distributed on an "AS IS" BASIS, 14 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 15 | # See the License for the specific language governing permissions and 16 | # limitations under the License. 17 | 18 | import os 19 | import torch 20 | import torch.nn as nn 21 | 22 | from torch.utils.data import Sampler 23 | 24 | from transformers import Trainer 25 | from transformers.trainer import ( 26 | is_sagemaker_mp_enabled, 27 | get_parameter_names, 28 | has_length, 29 | ALL_LAYERNORM_LAYERS, 30 | ShardedDDPOption, 31 | logger, 32 | ) 33 | from typing import List, Optional 34 | 35 | 36 | def maybe_zero_3(param, ignore_status=False, name=None): 37 | from deepspeed import zero 38 | from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus 39 | if hasattr(param, "ds_id"): 40 | if param.ds_status == ZeroParamStatus.NOT_AVAILABLE: 41 | if not ignore_status: 42 | print(name, 'no ignore status') 43 | with zero.GatheredParameters([param]): 44 | param = param.data.detach().cpu().clone() 45 | else: 46 | param = param.detach().cpu().clone() 47 | return param 48 | 49 | 50 | def get_mm_adapter_state_maybe_zero_3(named_params, keys_to_match): 51 | to_return = {k: t for k, t in named_params if any(key_match in k for key_match in keys_to_match)} 52 | to_return = {k: maybe_zero_3(v, ignore_status=True, name=k).cpu() for k, v in to_return.items()} 53 | return to_return 54 | 55 | 56 | def split_to_even_chunks(indices, lengths, num_chunks): 57 | """ 58 | Split a list of indices into `chunks` chunks of roughly equal lengths. 
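    Each index is assigned greedily to the chunk whose accumulated length is currently smallest; once a
    chunk holds its quota of indices it is removed from consideration. If the indices do not divide evenly
    into `num_chunks`, a simple strided split is used instead.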
59 | """ 60 | 61 | if len(indices) % num_chunks != 0: 62 | return [indices[i::num_chunks] for i in range(num_chunks)] 63 | 64 | num_indices_per_chunk = len(indices) // num_chunks 65 | 66 | chunks = [[] for _ in range(num_chunks)] 67 | chunks_lengths = [0 for _ in range(num_chunks)] 68 | for index in indices: 69 | shortest_chunk = chunks_lengths.index(min(chunks_lengths)) 70 | chunks[shortest_chunk].append(index) 71 | chunks_lengths[shortest_chunk] += lengths[index] 72 | if len(chunks[shortest_chunk]) == num_indices_per_chunk: 73 | chunks_lengths[shortest_chunk] = float("inf") 74 | 75 | return chunks 76 | 77 | 78 | def get_modality_length_grouped_indices(lengths, batch_size, world_size, generator=None): 79 | # We need to use torch for the random part as a distributed sampler will set the random seed for torch. 80 | assert all(l != 0 for l in lengths), "Should not have zero length." 81 | if all(l > 0 for l in lengths) or all(l < 0 for l in lengths): 82 | # all samples are in the same modality 83 | return get_length_grouped_indices(lengths, batch_size, world_size, generator=generator) 84 | mm_indices, mm_lengths = zip(*[(i, l) for i, l in enumerate(lengths) if l > 0]) 85 | lang_indices, lang_lengths = zip(*[(i, -l) for i, l in enumerate(lengths) if l < 0]) 86 | 87 | mm_shuffle = [mm_indices[i] for i in get_length_grouped_indices(mm_lengths, batch_size, world_size, generator=None)] 88 | lang_shuffle = [lang_indices[i] for i in get_length_grouped_indices(lang_lengths, batch_size, world_size, generator=None)] 89 | megabatch_size = world_size * batch_size 90 | mm_megabatches = [mm_shuffle[i : i + megabatch_size] for i in range(0, len(mm_shuffle), megabatch_size)] 91 | lang_megabatches = [lang_shuffle[i : i + megabatch_size] for i in range(0, len(lang_shuffle), megabatch_size)] 92 | 93 | last_mm = mm_megabatches[-1] 94 | last_lang = lang_megabatches[-1] 95 | additional_batch = last_mm + last_lang 96 | megabatches = mm_megabatches[:-1] + lang_megabatches[:-1] 97 | megabatch_indices = torch.randperm(len(megabatches), generator=generator) 98 | megabatches = [megabatches[i] for i in megabatch_indices] 99 | 100 | if len(additional_batch) > 0: 101 | megabatches.append(sorted(additional_batch)) 102 | 103 | return [i for megabatch in megabatches for i in megabatch] 104 | 105 | 106 | def get_length_grouped_indices(lengths, batch_size, world_size, generator=None, merge=True): 107 | # We need to use torch for the random part as a distributed sampler will set the random seed for torch. 108 | indices = torch.randperm(len(lengths), generator=generator) 109 | megabatch_size = world_size * batch_size 110 | megabatches = [indices[i : i + megabatch_size].tolist() for i in range(0, len(lengths), megabatch_size)] 111 | megabatches = [sorted(megabatch, key=lambda i: lengths[i], reverse=True) for megabatch in megabatches] 112 | megabatches = [split_to_even_chunks(megabatch, lengths, world_size) for megabatch in megabatches] 113 | 114 | return [i for megabatch in megabatches for batch in megabatch for i in batch] 115 | 116 | 117 | class LengthGroupedSampler(Sampler): 118 | r""" 119 | Sampler that samples indices in a way that groups together features of the dataset of roughly the same length while 120 | keeping a bit of randomness. 
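    Positive entries in `lengths` mark multimodal samples and negative entries mark language-only samples;
    with `group_by_modality=True` the two groups are shuffled and chunked separately before the resulting
    megabatches are interleaved in random order.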
121 | """ 122 | 123 | def __init__( 124 | self, 125 | batch_size: int, 126 | world_size: int, 127 | lengths: Optional[List[int]] = None, 128 | generator=None, 129 | group_by_modality: bool = False, 130 | ): 131 | if lengths is None: 132 | raise ValueError("Lengths must be provided.") 133 | 134 | self.batch_size = batch_size 135 | self.world_size = world_size 136 | self.lengths = lengths 137 | self.generator = generator 138 | self.group_by_modality = group_by_modality 139 | 140 | def __len__(self): 141 | return len(self.lengths) 142 | 143 | def __iter__(self): 144 | if self.group_by_modality: 145 | indices = get_modality_length_grouped_indices(self.lengths, self.batch_size, self.world_size, generator=self.generator) 146 | else: 147 | indices = get_length_grouped_indices(self.lengths, self.batch_size, self.world_size, generator=self.generator) 148 | return iter(indices) 149 | 150 | 151 | class VStreamTrainer(Trainer): 152 | 153 | def _get_train_sampler(self) -> Optional[torch.utils.data.Sampler]: 154 | if self.train_dataset is None or not has_length(self.train_dataset): 155 | return None 156 | 157 | if self.args.group_by_modality_length: 158 | lengths = self.train_dataset.modality_lengths 159 | return LengthGroupedSampler( 160 | self.args.train_batch_size, 161 | world_size=self.args.world_size * self.args.gradient_accumulation_steps, 162 | lengths=lengths, 163 | group_by_modality=True, 164 | ) 165 | else: 166 | return super()._get_train_sampler() 167 | 168 | def create_optimizer(self): 169 | """ 170 | Setup the optimizer. 171 | 172 | We provide a reasonable default that works well. If you want to use something else, you can pass a tuple in the 173 | Trainer's init through `optimizers`, or subclass and override this method in a subclass. 174 | """ 175 | if is_sagemaker_mp_enabled(): 176 | return super().create_optimizer() 177 | if self.sharded_ddp == ShardedDDPOption.SIMPLE: 178 | return super().create_optimizer() 179 | 180 | opt_model = self.model 181 | 182 | if self.optimizer is None: 183 | decay_parameters = get_parameter_names(opt_model, ALL_LAYERNORM_LAYERS) 184 | decay_parameters = [name for name in decay_parameters if "bias" not in name] 185 | if self.args.mm_projector_lr is not None: 186 | projector_parameters = [name for name, _ in opt_model.named_parameters() if "mm_projector" in name] 187 | optimizer_grouped_parameters = [ 188 | { 189 | "params": [ 190 | p for n, p in opt_model.named_parameters() if (n in decay_parameters and n not in projector_parameters and p.requires_grad) 191 | ], 192 | "weight_decay": self.args.weight_decay, 193 | }, 194 | { 195 | "params": [ 196 | p for n, p in opt_model.named_parameters() if (n not in decay_parameters and n not in projector_parameters and p.requires_grad) 197 | ], 198 | "weight_decay": 0.0, 199 | }, 200 | { 201 | "params": [ 202 | p for n, p in opt_model.named_parameters() if (n in decay_parameters and n in projector_parameters and p.requires_grad) 203 | ], 204 | "weight_decay": self.args.weight_decay, 205 | "lr": self.args.mm_projector_lr, 206 | }, 207 | { 208 | "params": [ 209 | p for n, p in opt_model.named_parameters() if (n not in decay_parameters and n in projector_parameters and p.requires_grad) 210 | ], 211 | "weight_decay": 0.0, 212 | "lr": self.args.mm_projector_lr, 213 | }, 214 | ] 215 | else: 216 | optimizer_grouped_parameters = [ 217 | { 218 | "params": [ 219 | p for n, p in opt_model.named_parameters() if (n in decay_parameters and p.requires_grad) 220 | ], 221 | "weight_decay": self.args.weight_decay, 222 | }, 223 | { 224 | 
"params": [ 225 | p for n, p in opt_model.named_parameters() if (n not in decay_parameters and p.requires_grad) 226 | ], 227 | "weight_decay": 0.0, 228 | }, 229 | ] 230 | 231 | optimizer_cls, optimizer_kwargs = Trainer.get_optimizer_cls_and_kwargs(self.args) 232 | 233 | self.optimizer = optimizer_cls(optimizer_grouped_parameters, **optimizer_kwargs) 234 | if optimizer_cls.__name__ == "Adam8bit": 235 | import bitsandbytes 236 | 237 | manager = bitsandbytes.optim.GlobalOptimManager.get_instance() 238 | 239 | skipped = 0 240 | for module in opt_model.modules(): 241 | if isinstance(module, nn.Embedding): 242 | skipped += sum({p.data_ptr(): p.numel() for p in module.parameters()}.values()) 243 | logger.info(f"skipped {module}: {skipped/2**20}M params") 244 | manager.register_module_override(module, "weight", {"optim_bits": 32}) 245 | logger.debug(f"bitsandbytes: will optimize {module} in fp32") 246 | logger.info(f"skipped: {skipped/2**20}M params") 247 | 248 | return self.optimizer 249 | -------------------------------------------------------------------------------- /flash_vstream/utils.py: -------------------------------------------------------------------------------- 1 | # Based on https://github.com/haotian-liu/LLaVA. 2 | 3 | import datetime 4 | import logging 5 | import logging.handlers 6 | import os 7 | import sys 8 | 9 | import requests 10 | 11 | from flash_vstream.constants import LOGDIR 12 | 13 | server_error_msg = "**NETWORK ERROR DUE TO HIGH TRAFFIC. PLEASE REGENERATE OR REFRESH THIS PAGE.**" 14 | moderation_msg = "YOUR INPUT VIOLATES OUR CONTENT MODERATION GUIDELINES. PLEASE TRY AGAIN." 15 | 16 | handler = None 17 | 18 | 19 | def build_logger(logger_name, logger_filename): 20 | global handler 21 | 22 | formatter = logging.Formatter( 23 | fmt="%(asctime)s | %(levelname)s | %(name)s | %(message)s", 24 | datefmt="%Y-%m-%d %H:%M:%S", 25 | ) 26 | 27 | # Set the format of root handlers 28 | if not logging.getLogger().handlers: 29 | logging.basicConfig(level=logging.INFO) 30 | logging.getLogger().handlers[0].setFormatter(formatter) 31 | 32 | # Redirect stdout and stderr to loggers 33 | stdout_logger = logging.getLogger("stdout") 34 | stdout_logger.setLevel(logging.INFO) 35 | sl = StreamToLogger(stdout_logger, logging.INFO) 36 | sys.stdout = sl 37 | 38 | stderr_logger = logging.getLogger("stderr") 39 | stderr_logger.setLevel(logging.ERROR) 40 | sl = StreamToLogger(stderr_logger, logging.ERROR) 41 | sys.stderr = sl 42 | 43 | # Get logger 44 | logger = logging.getLogger(logger_name) 45 | logger.setLevel(logging.INFO) 46 | 47 | # Add a file handler for all loggers 48 | if handler is None: 49 | os.makedirs(LOGDIR, exist_ok=True) 50 | filename = os.path.join(LOGDIR, logger_filename) 51 | handler = logging.handlers.TimedRotatingFileHandler( 52 | filename, when='D', utc=True, encoding='UTF-8') 53 | handler.setFormatter(formatter) 54 | 55 | for name, item in logging.root.manager.loggerDict.items(): 56 | if isinstance(item, logging.Logger): 57 | item.addHandler(handler) 58 | 59 | return logger 60 | 61 | 62 | class StreamToLogger(object): 63 | """ 64 | Fake file-like stream object that redirects writes to a logger instance. 
65 | """ 66 | def __init__(self, logger, log_level=logging.INFO): 67 | self.terminal = sys.stdout 68 | self.logger = logger 69 | self.log_level = log_level 70 | self.linebuf = '' 71 | 72 | def __getattr__(self, attr): 73 | return getattr(self.terminal, attr) 74 | 75 | def write(self, buf): 76 | temp_linebuf = self.linebuf + buf 77 | self.linebuf = '' 78 | for line in temp_linebuf.splitlines(True): 79 | # From the io.TextIOWrapper docs: 80 | # On output, if newline is None, any '\n' characters written 81 | # are translated to the system default line separator. 82 | # By default sys.stdout.write() expects '\n' newlines and then 83 | # translates them so this is still cross platform. 84 | if line[-1] == '\n': 85 | self.logger.log(self.log_level, line.rstrip()) 86 | else: 87 | self.linebuf += line 88 | 89 | def flush(self): 90 | if self.linebuf != '': 91 | self.logger.log(self.log_level, self.linebuf.rstrip()) 92 | self.linebuf = '' 93 | 94 | 95 | def disable_torch_init(): 96 | """ 97 | Disable the redundant torch default initialization to accelerate model creation. 98 | """ 99 | import torch 100 | setattr(torch.nn.Linear, "reset_parameters", lambda self: None) 101 | setattr(torch.nn.LayerNorm, "reset_parameters", lambda self: None) 102 | 103 | 104 | def violates_moderation(text): 105 | """ 106 | Check whether the text violates OpenAI moderation API. 107 | """ 108 | url = "https://api.openai.com/v1/moderations" 109 | headers = {"Content-Type": "application/json", 110 | "Authorization": "Bearer " + os.environ["OPENAI_API_KEY"]} 111 | text = text.replace("\n", "") 112 | data = "{" + '"input": ' + f'"{text}"' + "}" 113 | data = data.encode("utf-8") 114 | try: 115 | ret = requests.post(url, headers=headers, data=data, timeout=5) 116 | flagged = ret.json()["results"][0]["flagged"] 117 | except requests.exceptions.RequestException as e: 118 | flagged = False 119 | except KeyError as e: 120 | flagged = False 121 | 122 | return flagged 123 | 124 | 125 | def pretty_print_semaphore(semaphore): 126 | if semaphore is None: 127 | return "None" 128 | return f"Semaphore(value={semaphore._value}, locked={semaphore.locked()})" 129 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ["setuptools"] 3 | build-backend = "setuptools.build_meta" 4 | 5 | [project] 6 | name = "vstream" 7 | version = "1.0" 8 | description = "Flash-VStream" 9 | readme = "README.md" 10 | requires-python = ">=3.10" 11 | classifiers = [ 12 | "Programming Language :: Python :: 3", 13 | "License :: OSI Approved :: Apache Software License", 14 | ] 15 | dependencies = [ 16 | "numpy", 17 | "tokenizers>=0.12.1", 18 | "torch==2.0.1", 19 | "torchvision==0.15.2", 20 | "wandb", 21 | "tensorboard", 22 | "tensorboardX", 23 | "httpx==0.23.0", 24 | "deepspeed==0.9.5", 25 | "peft==0.4.0", 26 | "transformers==4.31.0", 27 | "accelerate==0.21.0", 28 | "bitsandbytes==0.41.0", 29 | "scikit-learn==1.2.2", 30 | "sentencepiece==0.1.99", 31 | "einops==0.6.1", "einops-exts==0.0.4", "timm==0.6.13", 32 | "decord", 33 | "openai==0.28.0", 34 | ] 35 | 36 | [project.urls] 37 | "Homepage" = "https://github.com/zhang9302002/Flash-VStream" 38 | "Bug Tracker" = "https://github.com/zhang9302002/Flash-VStream/issues" 39 | 40 | [tool.setuptools.packages.find] 41 | exclude = ["checkpoints*", "data*", "docs", "scripts*"] 42 | 43 | [tool.wheel] 44 | exclude = ["checkpoints*", "data*", "docs", "scripts*"] 
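The version pins above are load-bearing: `flash_vstream/train/vstream_trainer.py` imports `ShardedDDPOption` from `transformers.trainer`, and the attention monkey patches override `LlamaAttention.forward` with the old tuple-based `past_key_value` signature, both of which later transformers releases changed or removed. A minimal environment check, offered as a sketch rather than a script shipped with the repo:

```python
# Sketch: compare a few of the pinned packages against what is actually installed.
from importlib.metadata import version, PackageNotFoundError

pinned = {"transformers": "4.31.0", "torch": "2.0.1", "deepspeed": "0.9.5", "peft": "0.4.0"}
for package, expected in pinned.items():
    try:
        installed = version(package)
    except PackageNotFoundError:
        installed = "not installed"
    flag = "" if installed == expected else "  <-- differs from pyproject.toml"
    print(f"{package:<14} {installed:<16} (pinned: {expected}){flag}")
```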
-------------------------------------------------------------------------------- /scripts/eval.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # set up python environment 4 | conda activate vstream 5 | 6 | # set important configurations 7 | ngpus=8 8 | gputype=A100 9 | 10 | # auto calculate configurations 11 | gpus_list=$(seq -s, 0 $((ngpus - 1))) 12 | date_device="$(date +%m%d)_${ngpus}${gputype}" 13 | 14 | echo start eval 15 | # define your openai info here 16 | 17 | for dataset in actnet nextoe msvd msrvtt vsmovienet vsego4d realtime_vsmovienet realtime_vsego4d 18 | do 19 | echo start eval ${dataset} 20 | python -m flash_vstream.eval_video.eval_any_dataset_features \ 21 | --model-path your_model_checkpoint_path \ 22 | --dataset ${dataset} \ 23 | --num_chunks $ngpus \ 24 | --api_key $OPENAIKEY \ 25 | --api_base $OPENAIBASE \ 26 | --api_type $OPENAITYPE \ 27 | --api_version $OPENAIVERSION \ 28 | --test \ 29 | >> ${date_device}_vstream-7b-eval-${dataset}.log 2>&1 30 | done 31 | 32 | -------------------------------------------------------------------------------- /scripts/merge_lora_weights.py: -------------------------------------------------------------------------------- 1 | # Based on https://github.com/haotian-liu/LLaVA. 2 | 3 | import argparse 4 | from llama_vstream.model.builder import load_pretrained_model 5 | from llama_vstream.mm_utils import get_model_name_from_path 6 | 7 | 8 | def merge_lora(args): 9 | model_name = get_model_name_from_path(args.model_path) 10 | tokenizer, model, image_processor, context_len = load_pretrained_model(args.model_path, args.model_base, model_name, device_map='cpu') 11 | 12 | model.save_pretrained(args.save_model_path) 13 | tokenizer.save_pretrained(args.save_model_path) 14 | 15 | 16 | if __name__ == "__main__": 17 | parser = argparse.ArgumentParser() 18 | parser.add_argument("--model-path", type=str, required=True) 19 | parser.add_argument("--model-base", type=str, required=True) 20 | parser.add_argument("--save-model-path", type=str, required=True) 21 | 22 | args = parser.parse_args() 23 | 24 | merge_lora(args) 25 | -------------------------------------------------------------------------------- /scripts/realtime_cli.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | python -m flash_vstream.serve.cli_video_stream \ 4 | --model-path your_model_checkpoint_path \ 5 | --video-file assets/example.mp4 \ 6 | --conv-mode vicuna_v1 --temperature 0.0 \ 7 | --video_max_frames 1200 \ 8 | --video_fps 1.0 --play_speed 1.0 \ 9 | --log-file realtime_cli.log 10 | -------------------------------------------------------------------------------- /scripts/train_and_eval.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # set up python environment 4 | conda activate vstream 5 | 6 | # set important configurations 7 | type=weighted_kmeans 8 | suffix=STAR 9 | cur_length=1 10 | cur_size=8 11 | long_length=25 12 | long_size=4 13 | Turing_length=25 14 | Turing_size=1 15 | ngpus=8 16 | gputype=A100 17 | 18 | # auto calculate configurations 19 | gpus_list=$(seq -s, 0 $((ngpus - 1))) 20 | date_device="$(date +%m%d)_${ngpus}${gputype}" 21 | 22 | echo start pretrain 23 | deepspeed --master_addr 127.0.0.1 --master_port 12345 --include localhost:${gpus_list} flash_vstream/train/train_mem.py \ 24 | --deepspeed ./scripts/zero0.json \ 25 | --model_name_or_path ./ckpt/vicuna-7b-v1.5 \ 26 | --version plain \ 27 | 
--data_path ./data/pretrain/llava_558k_with_webvid.json \ 28 | --image_folder ./data/pretrain/image_features \ 29 | --video_folder ./data/pretrain/video_features \ 30 | --vision_tower ./ckpt/clip-vit-large-patch14 \ 31 | --mm_projector_type mlp2x_gelu \ 32 | --tune_mm_mlp_adapter True \ 33 | --mm_vision_select_layer -2 \ 34 | --mm_use_im_start_end False \ 35 | --mm_use_im_patch_token False \ 36 | --video_fps 1 \ 37 | --compress_type mean \ 38 | --compress_size ${cur_size} \ 39 | --compress_long_memory_size ${long_size} \ 40 | --compress_Turing_memory_size ${Turing_size} \ 41 | --video_max_frames $((cur_length + long_length)) \ 42 | --video_current_memory_length ${cur_length} \ 43 | --video_long_memory_length ${long_length} \ 44 | --video_Turing_memory_length ${Turing_length} \ 45 | --video_sample_type ${type} \ 46 | --group_by_modality_length \ 47 | --bf16 \ 48 | --output_dir ./checkpoints-pretrain/vstream-7b-pretrain-${type}${cur_length}*${cur_size}-${long_length}*${long_size}-${Turing_length}*${Turing_size}-${suffix} \ 49 | --num_train_epochs 1 \ 50 | --per_device_train_batch_size 32 \ 51 | --per_device_eval_batch_size 4 \ 52 | --gradient_accumulation_steps 1 \ 53 | --evaluation_strategy "no" \ 54 | --save_strategy "steps" \ 55 | --save_steps 100 \ 56 | --save_total_limit 10 \ 57 | --learning_rate 1e-3 \ 58 | --weight_decay 0. \ 59 | --warmup_ratio 0.03 \ 60 | --lr_scheduler_type "cosine" \ 61 | --logging_steps 1 \ 62 | --model_max_length 2048 \ 63 | --gradient_checkpointing True \ 64 | --dataloader_num_workers 4 \ 65 | --lazy_preprocess True \ 66 | --report_to tensorboard \ 67 | >> ${date_device}_vstream-7b-pretrain-${type}${cur_length}*${cur_size}-${long_length}*${long_size}-${Turing_length}*${Turing_size}-${suffix}.log 2>&1 68 | 69 | echo start finetune 70 | deepspeed --master_addr 127.0.2.1 --master_port 12345 --include localhost:${gpus_list} flash_vstream/train/train_mem.py \ 71 | --deepspeed ./scripts/zero1.json \ 72 | --model_name_or_path ./checkpoints-pretrain/vstream-7b-pretrain-${type}${cur_length}*${cur_size}-${long_length}*${long_size}-${Turing_length}*${Turing_size}-${suffix}/checkpoint-3000 \ 73 | --version v1 \ 74 | --data_path ./data/finetune/llava_v1_5_mix665k_with_video_chatgpt.json \ 75 | --image_folder ./data/finetune/image_features \ 76 | --video_folder ./data/finetune/video_features \ 77 | --vision_tower ./ckpt/clip-vit-large-patch14 \ 78 | --mm_projector_type mlp2x_gelu \ 79 | --mm_vision_select_layer -2 \ 80 | --mm_use_im_start_end False \ 81 | --mm_use_im_patch_token False \ 82 | --image_aspect_ratio pad \ 83 | --video_fps 1 \ 84 | --compress_type mean \ 85 | --compress_size ${cur_size} \ 86 | --compress_long_memory_size ${long_size} \ 87 | --compress_Turing_memory_size ${Turing_size} \ 88 | --video_max_frames $((cur_length + long_length)) \ 89 | --video_current_memory_length ${cur_length} \ 90 | --video_long_memory_length ${long_length} \ 91 | --video_Turing_memory_length ${Turing_length} \ 92 | --video_sample_type ${type} \ 93 | --group_by_modality_length \ 94 | --bf16 \ 95 | --output_dir ./checkpoints-finetune/vstream-7b-finetune-${type}${cur_length}*${cur_size}-${long_length}*${long_size}-${Turing_length}*${Turing_size}-${suffix} \ 96 | --num_train_epochs 1 \ 97 | --per_device_train_batch_size 16 \ 98 | --per_device_eval_batch_size 4 \ 99 | --gradient_accumulation_steps 1 \ 100 | --evaluation_strategy "no" \ 101 | --save_strategy "steps" \ 102 | --save_steps 100 \ 103 | --save_total_limit 10 \ 104 | --learning_rate 2e-5 \ 105 | --weight_decay 0. 
\ 106 | --warmup_ratio 0.03 \ 107 | --lr_scheduler_type "cosine" \ 108 | --logging_steps 1 \ 109 | --model_max_length 2048 \ 110 | --gradient_checkpointing True \ 111 | --dataloader_num_workers 4 \ 112 | --lazy_preprocess True \ 113 | --report_to tensorboard \ 114 | >> ${date_device}_vstream-7b-finetune-${type}${cur_length}*${cur_size}-${long_length}*${long_size}-${Turing_length}*${Turing_size}-${suffix}.log 2>&1 115 | 116 | 117 | echo start eval 118 | # define your openai info here 119 | 120 | for dataset in actnet nextoe msvd msrvtt vsmovienet vsego4d realtime_vsmovienet realtime_vsego4d 121 | do 122 | echo start eval ${dataset} 123 | python -m flash_vstream.eval_video.eval_any_dataset_features \ 124 | --model-path checkpoints-finetune/vstream-7b-finetune-${type}${cur_length}*${cur_size}-${long_length}*${long_size}-${Turing_length}*${Turing_size}-${suffix}/checkpoint-5900 \ 125 | --dataset ${dataset} \ 126 | --num_chunks $ngpus \ 127 | --api_key $OPENAIKEY \ 128 | --api_base $OPENAIBASE \ 129 | --api_type $OPENAITYPE \ 130 | --api_version $OPENAIVERSION \ 131 | >> ${date_device}_vstream-7b-eval-${dataset}-${type}${cur_length}*${cur_size}-${long_length}*${long_size}-${Turing_length}*${Turing_size}-${suffix}.log 2>&1 132 | done 133 | -------------------------------------------------------------------------------- /scripts/zero0.json: -------------------------------------------------------------------------------- 1 | { 2 | "fp16": { 3 | "enabled": "auto", 4 | "loss_scale": 0, 5 | "loss_scale_window": 1000, 6 | "initial_scale_power": 16, 7 | "hysteresis": 2, 8 | "min_loss_scale": 1 9 | }, 10 | "bf16": { 11 | "enabled": "auto" 12 | }, 13 | "train_micro_batch_size_per_gpu": "auto", 14 | "train_batch_size": "auto", 15 | "gradient_accumulation_steps": "auto", 16 | "zero_optimization": { 17 | "stage": 0 18 | } 19 | } -------------------------------------------------------------------------------- /scripts/zero1.json: -------------------------------------------------------------------------------- 1 | { 2 | "fp16": { 3 | "enabled": "auto", 4 | "loss_scale": 0, 5 | "loss_scale_window": 1000, 6 | "initial_scale_power": 16, 7 | "hysteresis": 2, 8 | "min_loss_scale": 1 9 | }, 10 | "bf16": { 11 | "enabled": "auto" 12 | }, 13 | "train_micro_batch_size_per_gpu": "auto", 14 | "train_batch_size": "auto", 15 | "gradient_accumulation_steps": "auto", 16 | "zero_optimization": { 17 | "stage": 1 18 | } 19 | } -------------------------------------------------------------------------------- /scripts/zero2.json: -------------------------------------------------------------------------------- 1 | { 2 | "fp16": { 3 | "enabled": "auto", 4 | "loss_scale": 0, 5 | "loss_scale_window": 1000, 6 | "initial_scale_power": 16, 7 | "hysteresis": 2, 8 | "min_loss_scale": 1 9 | }, 10 | "bf16": { 11 | "enabled": "auto" 12 | }, 13 | "train_micro_batch_size_per_gpu": "auto", 14 | "train_batch_size": "auto", 15 | "gradient_accumulation_steps": "auto", 16 | "zero_optimization": { 17 | "stage": 2, 18 | "overlap_comm": true, 19 | "contiguous_gradients": true, 20 | "sub_group_size": 1e9, 21 | "reduce_bucket_size": "auto" 22 | } 23 | } -------------------------------------------------------------------------------- /scripts/zero3.json: -------------------------------------------------------------------------------- 1 | { 2 | "fp16": { 3 | "enabled": "auto", 4 | "loss_scale": 0, 5 | "loss_scale_window": 1000, 6 | "initial_scale_power": 16, 7 | "hysteresis": 2, 8 | "min_loss_scale": 1 9 | }, 10 | "bf16": { 11 | "enabled": "auto" 12 | }, 
13 | "train_micro_batch_size_per_gpu": "auto", 14 | "train_batch_size": "auto", 15 | "gradient_accumulation_steps": "auto", 16 | "zero_optimization": { 17 | "stage": 3, 18 | "overlap_comm": true, 19 | "contiguous_gradients": true, 20 | "sub_group_size": 1e9, 21 | "reduce_bucket_size": "auto", 22 | "stage3_prefetch_bucket_size": "auto", 23 | "stage3_param_persistence_threshold": "auto", 24 | "stage3_max_live_parameters": 1e9, 25 | "stage3_max_reuse_distance": 1e9, 26 | "stage3_gather_16bit_weights_on_model_save": true 27 | } 28 | } -------------------------------------------------------------------------------- /scripts/zero3_offload.json: -------------------------------------------------------------------------------- 1 | { 2 | "fp16": { 3 | "enabled": "auto", 4 | "loss_scale": 0, 5 | "loss_scale_window": 1000, 6 | "initial_scale_power": 16, 7 | "hysteresis": 2, 8 | "min_loss_scale": 1 9 | }, 10 | "bf16": { 11 | "enabled": "auto" 12 | }, 13 | "optimizer": { 14 | "type": "AdamW", 15 | "params": { 16 | "lr": "auto", 17 | "betas": "auto", 18 | "eps": "auto", 19 | "weight_decay": "auto" 20 | } 21 | }, 22 | "scheduler": { 23 | "type": "WarmupLR", 24 | "params": { 25 | "warmup_min_lr": "auto", 26 | "warmup_max_lr": "auto", 27 | "warmup_num_steps": "auto" 28 | } 29 | }, 30 | "zero_optimization": { 31 | "stage": 3, 32 | "offload_optimizer": { 33 | "device": "cpu", 34 | "pin_memory": true 35 | }, 36 | "offload_param": { 37 | "device": "cpu", 38 | "pin_memory": true 39 | }, 40 | "overlap_comm": true, 41 | "contiguous_gradients": true, 42 | "sub_group_size": 1e9, 43 | "reduce_bucket_size": "auto", 44 | "stage3_prefetch_bucket_size": "auto", 45 | "stage3_param_persistence_threshold": "auto", 46 | "stage3_max_live_parameters": 1e9, 47 | "stage3_max_reuse_distance": 1e9, 48 | "gather_16bit_weights_on_model_save": true 49 | }, 50 | "gradient_accumulation_steps": "auto", 51 | "gradient_clipping": "auto", 52 | "train_batch_size": "auto", 53 | "train_micro_batch_size_per_gpu": "auto", 54 | "steps_per_print": 1e5, 55 | "wall_clock_breakdown": false 56 | } -------------------------------------------------------------------------------- /vstream.egg-info/PKG-INFO: -------------------------------------------------------------------------------- 1 | Metadata-Version: 2.1 2 | Name: vstream 3 | Version: 1.0 4 | Summary: Flash-VStream 5 | Project-URL: Homepage, https://github.com/zhang9302002/Flash-VStream 6 | Project-URL: Bug Tracker, https://github.com/zhang9302002/Flash-VStream/issues 7 | Classifier: Programming Language :: Python :: 3 8 | Classifier: License :: OSI Approved :: Apache Software License 9 | Requires-Python: >=3.10 10 | Description-Content-Type: text/markdown 11 | License-File: LICENSE 12 | Requires-Dist: numpy 13 | Requires-Dist: tokenizers>=0.12.1 14 | Requires-Dist: torch==2.0.1 15 | Requires-Dist: torchvision==0.15.2 16 | Requires-Dist: wandb 17 | Requires-Dist: tensorboard 18 | Requires-Dist: tensorboardX 19 | Requires-Dist: httpx==0.24.1 20 | Requires-Dist: deepspeed==0.9.5 21 | Requires-Dist: peft==0.4.0 22 | Requires-Dist: transformers==4.31.0 23 | Requires-Dist: accelerate==0.21.0 24 | Requires-Dist: bitsandbytes==0.41.0 25 | Requires-Dist: scikit-learn==1.2.2 26 | Requires-Dist: sentencepiece==0.1.99 27 | Requires-Dist: einops==0.6.1 28 | Requires-Dist: einops-exts==0.0.4 29 | Requires-Dist: timm==0.6.13 30 | Requires-Dist: decord 31 | Requires-Dist: openai==0.28.0 32 | 33 | # Flash-VStream: Memory-Based Real-Time Understanding for Long Video Streams 34 | 35 | 36 | Haoji Zhang\*, 37 | 
Yiqin Wang\*, 38 | Yansong Tang †, 39 | Yong Liu, 40 | Jiashi Feng, 41 | Jifeng Dai, 42 | Xiaojie Jin†‡ 43 | 44 | \* Equally contributing first authors, †Correspondence, ‡Project Lead 45 | 46 | 47 | 48 | 49 | 50 | 51 | 52 | 53 | 54 | [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/flash-vstream-memory-based-real-time/zeroshot-video-question-answer-on-msvd-qa)](https://paperswithcode.com/sota/zeroshot-video-question-answer-on-msvd-qa?p=flash-vstream-memory-based-real-time) 55 | 56 | 57 | [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/flash-vstream-memory-based-real-time/zeroshot-video-question-answer-on-msrvtt-qa)](https://paperswithcode.com/sota/zeroshot-video-question-answer-on-msrvtt-qa?p=flash-vstream-memory-based-real-time) 58 | 59 | 60 | [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/flash-vstream-memory-based-real-time/question-answering-on-next-qa-open-ended)](https://paperswithcode.com/sota/question-answering-on-next-qa-open-ended?p=flash-vstream-memory-based-real-time) 61 | 62 | 63 | [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/flash-vstream-memory-based-real-time/zeroshot-video-question-answer-on-activitynet)](https://paperswithcode.com/sota/zeroshot-video-question-answer-on-activitynet?p=flash-vstream-memory-based-real-time) 64 | 65 | 66 | 67 | 68 | 69 | We presented **Flash-VStream**, a noval LMM able to process extremely long video streams in real-time and respond to user queries simultaneously. 70 | 71 | We also proposed **VStream-QA**, a novel question answering benchmark specifically designed for online video streaming understanding. 72 | 73 |
74 | 75 |
76 |
77 | 78 |
79 | 80 | ## News 81 | - [2024/6/15] 🏅 Our team won the 1st Place at [Long-Term Video Question Answering Challenge](https://sites.google.com/view/loveucvpr24/track1) of [LOVEU Workshop@CVPR'24](https://sites.google.com/view/loveucvpr24/home). We used a Hierarchical Memory model based on Flash-VStream-7b. 82 | 83 | - [2024/06/12] 🔥 Flash-VStream is coming! We release the 84 | [homepage](https://invinciblewyq.github.io/vstream-page), 85 | [paper](https://arxiv.org/abs/2406.08085v1), 86 | [code](https://github.com/IVGSZ/Flash-VStream) 87 | and [model](https://huggingface.co/IVGSZ/Flash-VStream-7b) 88 | for Flash-VStream. 89 | We release the [dataset](https://huggingface.co/datasets/IVGSZ/VStream-QA) for VStream-QA benchmark. 90 | 91 | ## Contents 92 | - [Install](#install) 93 | - [Model](#model) 94 | - [Preparation](#preparation) 95 | - [Train](#train) 96 | - [Evaluation](#evaluation) 97 | - [Real-time CLI Inference](#Real-time-CLI-Inference) 98 | - [VStream-QA Benchmark](#VStream-QA-Benchmark) 99 | - [Citation](#citation) 100 | - [Acknowledgement](#acknowledgement) 101 | - [License](#license) 102 | 103 | ## Install 104 | Please follow the instructions below to install the required packages. 105 | 1. Clone this repository 106 | 107 | 2. Install Package 108 | ```bash 109 | conda create -n vstream python=3.10 -y 110 | conda activate vstream 111 | cd Flash-VStream 112 | pip install --upgrade pip 113 | pip install -e . 114 | ``` 115 | 116 | 3. Install additional packages for training cases 117 | ```bash 118 | pip install ninja 119 | pip install flash-attn --no-build-isolation 120 | ``` 121 | 122 | ## Model 123 | 124 | We provide our Flash-VStream models after Stage 1 and 2 finetuning: 125 | 126 | | Model | Weight | Initialized from LLM | Initialized from ViT | 127 | | --- | --- | --- | --- | 128 | | Flash-VStream-7b | [Flash-VStream-7b](https://huggingface.co/IVGSZ/Flash-VStream-7b) | [lmsys/vicuna-7b-v1.5](https://huggingface.co/lmsys/vicuna-7b-v1.5) | [openai/clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) | 129 | 130 | 131 | ## Preparation 132 | 133 | ### Dataset 134 | 135 | **Image VQA Dataset.** 136 | Please organize the training Image VQA training data following [this](https://github.com/haotian-liu/LLaVA/blob/main/docs/Data.md) and evaluation data following [this](https://github.com/haotian-liu/LLaVA/blob/main/docs/Evaluation.md). 137 | Please put the pretraining data, finetuning data, and evaluation data in `pretrain`, `finetune`, and `eval_video` folder following [Structure](#structure). 138 | 139 | **Video VQA Dataset.** 140 | please download the 2.5M subset from [WebVid](https://maxbain.com/webvid-dataset/) and ActivityNet dataset from [official website](http://activity-net.org/download.html) or [video-chatgpt](https://github.com/mbzuai-oryx/Video-ChatGPT/blob/main/docs/train_video_chatgpt.md). 141 | 142 | If you want to perform evaluation, please also download corresponding files of 143 | [ActivityNet-QA](https://github.com/mbzuai-oryx/Video-ChatGPT/blob/main/quantitative_evaluation/README.md) 144 | and [NExT-QA-OE](https://github.com/doc-doc/NExT-QA). 145 | You can download 146 | [MSVD-QA](https://mycuhk-my.sharepoint.com/:u:/g/personal/1155186668_link_cuhk_edu_hk/EUNEXqg8pctPq3WZPHb4Fd8BYIxHO5qPCnU6aWsrV1O4JQ?e=guynwu) 147 | and [MSRVTT-QA](https://mycuhk-my.sharepoint.com/:u:/g/personal/1155186668_link_cuhk_edu_hk/EcEXh1HfTXhLrRnuwHbl15IBJeRop-d50Q90njHmhvLwtA?e=SE24eG) from LLaMA-VID. 
148 | 149 | 150 | **Meta Info.** 151 | For meta info of training data, please download the following files and organize them as in [Structure](#structure). 152 | 153 | | Training Stage | Data file name | Size | 154 | | --- | --- | ---: | 155 | | Pretrain | [llava_558k_with_webvid.json](https://huggingface.co/datasets/YanweiLi/LLaMA-VID-Data) | 254 MB | 156 | | Finetune | [llava_v1_5_mix665k_with_video_chatgpt.json](https://huggingface.co/datasets/YanweiLi/LLaMA-VID-Data) | 860 MB | 157 | 158 | For meta info of evaluation data, please reformat each QA list to a json file named `test_qa.json` under [Structure](#structure) with format like this: 159 | 160 | ```json 161 | [ 162 | { 163 | "video_id": "v_1QIUV7WYKXg", 164 | "question": "is the athlete wearing trousers", 165 | "id": "v_1QIUV7WYKXg_3", 166 | "answer": "no", 167 | "answer_type": 3, 168 | "duration": 9.88 169 | }, 170 | { 171 | "video_id": "v_9eniCub7u60", 172 | "question": "does the girl in black clothes have long hair", 173 | "id": "v_9eniCub7u60_2", 174 | "answer": "yes", 175 | "answer_type": 3, 176 | "duration": 19.43 177 | }, 178 | ] 179 | ``` 180 | 181 | ### Pretrained Weights 182 | We recommend users to download the pretrained weights from the following link 183 | [Vicuna-7b-v1.5](https://huggingface.co/lmsys/vicuna-7b-v1.5), 184 | [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14), 185 | and put them in `ckpt` following [Structure](#structure). 186 | 187 | ### Feature Extraction 188 | 189 | We recommend users to extract ViT features of training and evaluation data, which accelerates training and evaluating a lot. If you do so, just replace `.mp4` with `.safetensors` in video filename and put them in `image_features` and `video_features` folder. If not, ignore the `image_features` and `video_features` folder. 190 | 191 | We load video feature at fps=1 and arrange them in the time order. 192 | 193 | Each `.safetensors` file should contain a dict like this: 194 | 195 | ```python 196 | { 197 | 'feature': torch.Tensor() with shape=[256, 1024] for image and shape=[Length, 256, 1024] for video. 198 | } 199 | ``` 200 | 201 | 202 | ### Structure 203 | The folder structure should be organized as follows before training. 
204 | 
205 | ```
206 | Flash-VStream
207 | ├── checkpoints-finetune
208 | ├── checkpoints-pretrain
209 | ├── ckpt
210 | │   ├── clip-vit-large-patch14
211 | │   ├── vicuna-7b-v1.5
212 | ├── data
213 | │   ├── pretrain
214 | │   │   ├── llava_558k_with_webvid.json
215 | │   │   ├── image_features
216 | │   │   ├── images
217 | │   │   ├── video_features
218 | │   │   ├── videos
219 | │   ├── finetune
220 | │   │   ├── llava_v1_5_mix665k_with_video_chatgpt.json
221 | │   │   ├── activitynet
222 | │   │   ├── coco
223 | │   │   ├── gqa
224 | │   │   ├── image_features
225 | │   │   │   ├── coco
226 | │   │   │   ├── gqa
227 | │   │   │   ├── ocr_vqa
228 | │   │   │   ├── textvqa
229 | │   │   │   ├── vg
230 | │   │   ├── ocr_vqa
231 | │   │   ├── textvqa
232 | │   │   ├── vg
233 | │   │   ├── video_features
234 | │   │   │   ├── activitynet
235 | │   ├── eval_video
236 | │   │   ├── ActivityNet-QA
237 | │   │   │   ├── video_features
238 | │   │   │   ├── test_qa.json
239 | │   │   ├── MSRVTT-QA
240 | │   │   │   ├── video_features
241 | │   │   │   ├── test_qa.json
242 | │   │   ├── MSVD-QA
243 | │   │   │   ├── video_features
244 | │   │   │   ├── test_qa.json
245 | │   │   ├── nextqa
246 | │   │   │   ├── video_features
247 | │   │   │   ├── test_qa.json
248 | │   │   ├── vstream
249 | │   │   │   ├── video_features
250 | │   │   │   ├── test_qa.json
251 | │   │   ├── vstream-realtime
252 | │   │   │   ├── video_features
253 | │   │   │   ├── test_qa.json
254 | ├── flash_vstream
255 | ├── scripts
256 | 
257 | ```
258 | 
259 | ## Train
260 | Flash-VStream is trained on 8 A100 GPUs with 80GB memory each. To train on fewer GPUs, reduce `per_device_train_batch_size` and increase `gradient_accumulation_steps` accordingly, so that the global batch size, i.e. `per_device_train_batch_size` x `gradient_accumulation_steps` x `num_gpus`, stays the same. If your GPUs have less than 80GB memory, you may try the DeepSpeed ZeRO-2 or ZeRO-3 stages.
261 | 
262 | Please make sure you download and organize the data following [Preparation](#preparation) before training.
263 | 
264 | Like LLaVA, Flash-VStream has two training stages: pretrain and finetune. Their checkpoints are saved in the `checkpoints-pretrain` and `checkpoints-finetune` folders. The two stages take about 15 hours in total on 8 A100 GPUs.
265 | 
266 | If you want to train Flash-VStream from a pretrained LLM and then evaluate it, please run the following command:
267 | 
268 | ```bash
269 | bash scripts/train_and_eval.sh
270 | ```
271 | 
272 | ## Evaluation
273 | Please make sure you download and organize the data following [Preparation](#preparation) before evaluation.
274 | 
275 | If you want to evaluate a Flash-VStream model, please run the following command:
276 | 
277 | ```bash
278 | bash scripts/eval.sh
279 | ```
280 | 
281 | ## Real-time CLI Inference
282 | We provide a real-time CLI inference script, which simulates a video stream input by reading frames from a video file at a fixed frame rate. You can ask a question and get an answer at any timestamp of the video stream. Run the following command and have a try:
283 | 
284 | ```bash
285 | bash scripts/realtime_cli.sh
286 | ```
287 | 
288 | ## VStream-QA Benchmark
289 | Please download the VStream-QA benchmark from [this](https://huggingface.co/datasets/IVGSZ/VStream-QA) dataset repo.
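Both the training data and the evaluation benchmarks above use the same pre-extracted feature layout (see [Feature Extraction](#feature-extraction) and [Structure](#structure)). Before launching training or evaluation, it can save debugging time to verify that one extracted `.safetensors` file matches the expected layout. Below is a minimal sketch; the `safetensors` Python package, the dtype, and the file path are assumptions (the package is not pinned in this repo's requirements, and the path is only an example).

```python
import torch
from safetensors.torch import load_file, save_file

# Example path only; real feature files live under the video_features folders shown
# in Structure and mirror the original video filename with a .safetensors suffix.
path = "example.safetensors"

# Writing: a single tensor stored under the key 'feature'.
# dtype is an assumption; shape is [length, 256, 1024] for a video sampled at 1 FPS.
save_file({"feature": torch.zeros(30, 256, 1024, dtype=torch.float16)}, path)

# Reading it back and checking the expected layout.
feature = load_file(path)["feature"]
assert feature.dim() == 3 and feature.shape[1:] == (256, 1024)
print(feature.shape)  # torch.Size([30, 256, 1024])
```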
290 | 
291 | ## Citation
292 | If you find this project useful in your research, please consider citing:
293 | 
294 | ```
295 | @article{flashvstream,
296 |     title={Flash-VStream: Memory-Based Real-Time Understanding for Long Video Streams},
297 |     author={Haoji Zhang and Yiqin Wang and Yansong Tang and Yong Liu and Jiashi Feng and Jifeng Dai and Xiaojie Jin},
298 |     year={2024},
299 |     eprint={2406.08085},
300 |     archivePrefix={arXiv},
301 |     primaryClass={cs.CV}
302 | }
303 | ```
304 | 
305 | ## Acknowledgement
306 | We would like to thank the following projects for their great work:
307 | 
308 | - This work is built upon [LLaVA](https://github.com/haotian-liu/LLaVA).
309 | - This work utilizes LLMs from [Vicuna](https://github.com/lm-sys/FastChat).
310 | - Some code is borrowed from [LLaMA-VID](https://github.com/dvlab-research/LLaMA-VID).
311 | - Our video-based evaluation follows [Video-ChatGPT](https://github.com/mbzuai-oryx/Video-ChatGPT).
312 | 
313 | ## License
314 | [![Code License](https://img.shields.io/badge/Code%20License-Apache_2.0-yellow.svg)](LICENSE)
315 | 
316 | This project is licensed under the [Apache-2.0 License](LICENSE).
317 | 
--------------------------------------------------------------------------------
/vstream.egg-info/SOURCES.txt:
--------------------------------------------------------------------------------
1 | LICENSE
2 | README.md
3 | pyproject.toml
4 | flash_vstream/__init__.py
5 | flash_vstream/constants.py
6 | flash_vstream/conversation.py
7 | flash_vstream/mm_utils.py
8 | flash_vstream/utils.py
9 | flash_vstream/eval_video/eval_activitynet_qa.py
10 | flash_vstream/eval_video/eval_any_dataset_features.py
11 | flash_vstream/eval_video/model_msvd_qa.py
12 | flash_vstream/eval_video/model_msvd_qa_featuresloader.py
13 | flash_vstream/model/__init__.py
14 | flash_vstream/model/builder.py
15 | flash_vstream/model/compress_functions.py
16 | flash_vstream/model/vstream_arch.py
17 | flash_vstream/model/language_model/vstream_llama.py
18 | flash_vstream/model/multimodal_encoder/builder.py
19 | flash_vstream/model/multimodal_encoder/clip_encoder.py
20 | flash_vstream/model/multimodal_projector/builder.py
21 | flash_vstream/serve/cli_video_stream.py
22 | flash_vstream/train/llama_flash_attn_monkey_patch.py
23 | flash_vstream/train/llama_xformers_attn_monkey_patch.py
24 | flash_vstream/train/train.py
25 | flash_vstream/train/train_mem.py
26 | flash_vstream/train/train_xformers.py
27 | flash_vstream/train/vstream_trainer.py
28 | vstream.egg-info/PKG-INFO
29 | vstream.egg-info/SOURCES.txt
30 | vstream.egg-info/dependency_links.txt
31 | vstream.egg-info/requires.txt
32 | vstream.egg-info/top_level.txt
--------------------------------------------------------------------------------
/vstream.egg-info/dependency_links.txt:
--------------------------------------------------------------------------------
1 | 
2 | 
--------------------------------------------------------------------------------
/vstream.egg-info/requires.txt:
--------------------------------------------------------------------------------
1 | numpy
2 | tokenizers>=0.12.1
3 | torch==2.0.1
4 | torchvision==0.15.2
5 | wandb
6 | tensorboard
7 | tensorboardX
8 | httpx==0.24.1
9 | deepspeed==0.9.5
10 | peft==0.4.0
11 | transformers==4.31.0
12 | accelerate==0.21.0
13 | bitsandbytes==0.41.0
14 | scikit-learn==1.2.2
15 | sentencepiece==0.1.99
16 | einops==0.6.1
17 | einops-exts==0.0.4
18 | timm==0.6.13
19 | decord
20 | openai==0.28.0
--------------------------------------------------------------------------------
/vstream.egg-info/top_level.txt: -------------------------------------------------------------------------------- 1 | assets 2 | flash_vstream 3 | --------------------------------------------------------------------------------