├── LICENSE
├── README.md
├── assets
│   ├── example.mp4
│   ├── radar_comparison.png
│   └── teaser2.png
├── flash_vstream
│   ├── __init__.py
│   ├── constants.py
│   ├── conversation.py
│   ├── eval_video
│   │   ├── eval_activitynet_qa.py
│   │   ├── eval_any_dataset_features.py
│   │   ├── model_msvd_qa.py
│   │   └── model_msvd_qa_featuresloader.py
│   ├── mm_utils.py
│   ├── model
│   │   ├── __init__.py
│   │   ├── builder.py
│   │   ├── compress_functions.py
│   │   ├── language_model
│   │   │   └── vstream_llama.py
│   │   ├── multimodal_encoder
│   │   │   ├── builder.py
│   │   │   └── clip_encoder.py
│   │   ├── multimodal_projector
│   │   │   └── builder.py
│   │   └── vstream_arch.py
│   ├── serve
│   │   └── cli_video_stream.py
│   ├── train
│   │   ├── llama_flash_attn_monkey_patch.py
│   │   ├── llama_xformers_attn_monkey_patch.py
│   │   ├── train.py
│   │   ├── train_mem.py
│   │   ├── train_xformers.py
│   │   └── vstream_trainer.py
│   └── utils.py
├── pyproject.toml
├── scripts
│   ├── eval.sh
│   ├── merge_lora_weights.py
│   ├── realtime_cli.sh
│   ├── train_and_eval.sh
│   ├── zero0.json
│   ├── zero1.json
│   ├── zero2.json
│   ├── zero3.json
│   └── zero3_offload.json
└── vstream.egg-info
    ├── PKG-INFO
    ├── SOURCES.txt
    ├── dependency_links.txt
    ├── requires.txt
    └── top_level.txt
/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 |
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner. For the purposes of this definition, "submitted"
54 | means any form of electronic, verbal, or written communication sent
55 | to the Licensor or its representatives, including but not limited to
56 | communication on electronic mailing lists, source code control systems,
57 | and issue tracking systems that are managed by, or on behalf of, the
58 | Licensor for the purpose of discussing and improving the Work, but
59 | excluding communication that is conspicuously marked or otherwise
60 | designated in writing by the copyright owner as "Not a Contribution."
61 |
62 | "Contributor" shall mean Licensor and any individual or Legal Entity
63 | on behalf of whom a Contribution has been received by Licensor and
64 | subsequently incorporated within the Work.
65 |
66 | 2. Grant of Copyright License. Subject to the terms and conditions of
67 | this License, each Contributor hereby grants to You a perpetual,
68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 | copyright license to reproduce, prepare Derivative Works of,
70 | publicly display, publicly perform, sublicense, and distribute the
71 | Work and such Derivative Works in Source or Object form.
72 |
73 | 3. Grant of Patent License. Subject to the terms and conditions of
74 | this License, each Contributor hereby grants to You a perpetual,
75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 | (except as stated in this section) patent license to make, have made,
77 | use, offer to sell, sell, import, and otherwise transfer the Work,
78 | where such license applies only to those patent claims licensable
79 | by such Contributor that are necessarily infringed by their
80 | Contribution(s) alone or by combination of their Contribution(s)
81 | with the Work to which such Contribution(s) was submitted. If You
82 | institute patent litigation against any entity (including a
83 | cross-claim or counterclaim in a lawsuit) alleging that the Work
84 | or a Contribution incorporated within the Work constitutes direct
85 | or contributory patent infringement, then any patent licenses
86 | granted to You under this License for that Work shall terminate
87 | as of the date such litigation is filed.
88 |
89 | 4. Redistribution. You may reproduce and distribute copies of the
90 | Work or Derivative Works thereof in any medium, with or without
91 | modifications, and in Source or Object form, provided that You
92 | meet the following conditions:
93 |
94 | (a) You must give any other recipients of the Work or
95 | Derivative Works a copy of this License; and
96 |
97 | (b) You must cause any modified files to carry prominent notices
98 | stating that You changed the files; and
99 |
100 | (c) You must retain, in the Source form of any Derivative Works
101 | that You distribute, all copyright, patent, trademark, and
102 | attribution notices from the Source form of the Work,
103 | excluding those notices that do not pertain to any part of
104 | the Derivative Works; and
105 |
106 | (d) If the Work includes a "NOTICE" text file as part of its
107 | distribution, then any Derivative Works that You distribute must
108 | include a readable copy of the attribution notices contained
109 | within such NOTICE file, excluding those notices that do not
110 | pertain to any part of the Derivative Works, in at least one
111 | of the following places: within a NOTICE text file distributed
112 | as part of the Derivative Works; within the Source form or
113 | documentation, if provided along with the Derivative Works; or,
114 | within a display generated by the Derivative Works, if and
115 | wherever such third-party notices normally appear. The contents
116 | of the NOTICE file are for informational purposes only and
117 | do not modify the License. You may add Your own attribution
118 | notices within Derivative Works that You distribute, alongside
119 | or as an addendum to the NOTICE text from the Work, provided
120 | that such additional attribution notices cannot be construed
121 | as modifying the License.
122 |
123 | You may add Your own copyright statement to Your modifications and
124 | may provide additional or different license terms and conditions
125 | for use, reproduction, or distribution of Your modifications, or
126 | for any such Derivative Works as a whole, provided Your use,
127 | reproduction, and distribution of the Work otherwise complies with
128 | the conditions stated in this License.
129 |
130 | 5. Submission of Contributions. Unless You explicitly state otherwise,
131 | any Contribution intentionally submitted for inclusion in the Work
132 | by You to the Licensor shall be under the terms and conditions of
133 | this License, without any additional terms or conditions.
134 | Notwithstanding the above, nothing herein shall supersede or modify
135 | the terms of any separate license agreement you may have executed
136 | with Licensor regarding such Contributions.
137 |
138 | 6. Trademarks. This License does not grant permission to use the trade
139 | names, trademarks, service marks, or product names of the Licensor,
140 | except as required for reasonable and customary use in describing the
141 | origin of the Work and reproducing the content of the NOTICE file.
142 |
143 | 7. Disclaimer of Warranty. Unless required by applicable law or
144 | agreed to in writing, Licensor provides the Work (and each
145 | Contributor provides its Contributions) on an "AS IS" BASIS,
146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 | implied, including, without limitation, any warranties or conditions
148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 | PARTICULAR PURPOSE. You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
178 | APPENDIX: How to apply the Apache License to your work.
179 |
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "[]"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 |
189 | Copyright [yyyy] [name of copyright owner]
190 |
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 |
195 | http://www.apache.org/licenses/LICENSE-2.0
196 |
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
202 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Flash-VStream: Memory-Based Real-Time Understanding for Long Video Streams
2 |
3 |
4 | Haoji Zhang\*,
5 | Yiqin Wang\*,
6 | Yansong Tang †,
7 | Yong Liu,
8 | Jiashi Feng,
9 | Jifeng Dai,
10 | Xiaojie Jin†‡
11 |
12 | \* Equally contributing first authors, †Correspondence, ‡Project Lead
13 |
14 | **Work done when interning at Bytedance.**
15 |
16 |
17 |
18 |
19 |
20 |
21 |
22 |
23 |
24 |
25 | [Zero-Shot Video Question Answer on MSVD-QA](https://paperswithcode.com/sota/zeroshot-video-question-answer-on-msvd-qa?p=flash-vstream-memory-based-real-time)
26 |
27 |
28 | [Zero-Shot Video Question Answer on MSRVTT-QA](https://paperswithcode.com/sota/zeroshot-video-question-answer-on-msrvtt-qa?p=flash-vstream-memory-based-real-time)
29 |
30 |
31 | [Question Answering on NExT-QA (Open-Ended)](https://paperswithcode.com/sota/question-answering-on-next-qa-open-ended?p=flash-vstream-memory-based-real-time)
32 |
33 |
34 | [Zero-Shot Video Question Answer on ActivityNet](https://paperswithcode.com/sota/zeroshot-video-question-answer-on-activitynet?p=flash-vstream-memory-based-real-time)
35 |
36 |
37 |
38 |
39 |
40 |
41 |
42 |
43 |
44 | We present **Flash-VStream**, a novel LMM able to process extremely long video streams in real time and respond to user queries simultaneously.
45 |
46 | We also propose **VStream-QA**, a novel question answering benchmark specifically designed for online video stream understanding.
47 |
48 |
49 |
50 |
51 |
52 |
53 |
54 |
55 | ## News
56 | - [2024/6/15] 🏅 Our team won 1st place in the [Long-Term Video Question Answering Challenge](https://sites.google.com/view/loveucvpr24/track1) of the [LOVEU Workshop@CVPR'24](https://sites.google.com/view/loveucvpr24/home). Here is our [certification](https://github.com/bytedance/Flash-VStream/assets/37479394/e1496dec-52c8-4707-aabe-fd1970c8f874).
57 | We used a Hierarchical Memory model based on Flash-VStream-7b.
58 |
59 | - [2024/06/12] 🔥 Flash-VStream is coming! We release the
60 | [homepage](https://invinciblewyq.github.io/vstream-page),
61 | [paper](https://arxiv.org/abs/2406.08085v1),
62 | [code](https://github.com/IVGSZ/Flash-VStream)
63 | and [model](https://huggingface.co/IVGSZ/Flash-VStream-7b)
64 | for Flash-VStream.
65 | We release the [dataset](https://huggingface.co/datasets/IVGSZ/VStream-QA) for VStream-QA benchmark.
66 |
67 | ## Contents
68 | - [Install](#install)
69 | - [Model](#model)
70 | - [Preparation](#preparation)
71 | - [Train](#train)
72 | - [Evaluation](#evaluation)
73 | - [Real-time CLI Inference](#real-time-cli-inference)
74 | - [VStream-QA Benchmark](#vstream-qa-benchmark)
75 | - [Citation](#citation)
76 | - [Acknowledgement](#acknowledgement)
77 | - [License](#license)
78 |
79 | ## Install
80 | Please follow the instructions below to install the required packages.
81 | 1. Clone this repository
82 |
83 | 2. Install Package
84 | ```bash
85 | conda create -n vstream python=3.10 -y
86 | conda activate vstream
87 | cd Flash-VStream
88 | pip install --upgrade pip
89 | pip install -e .
90 | ```
91 |
92 | 3. Install additional packages for training
93 | ```bash
94 | pip install ninja
95 | pip install flash-attn --no-build-isolation
96 | ```
97 |
98 | ## Model
99 |
100 | We provide our Flash-VStream model after Stage-1 (pretrain) and Stage-2 (finetune) training:
101 |
102 | | Model | Weight | Initialized from LLM | Initialized from ViT |
103 | | --- | --- | --- | --- |
104 | | Flash-VStream-7b | [Flash-VStream-7b](https://huggingface.co/IVGSZ/Flash-VStream-7b) | [lmsys/vicuna-7b-v1.5](https://huggingface.co/lmsys/vicuna-7b-v1.5) | [openai/clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) |
105 |
106 |
107 | ## Preparation
108 |
109 | ### Dataset
110 |
111 | **Image VQA Dataset.**
112 | Please organize the Image VQA training data following [this](https://github.com/haotian-liu/LLaVA/blob/main/docs/Data.md) and the evaluation data following [this](https://github.com/haotian-liu/LLaVA/blob/main/docs/Evaluation.md).
113 | Please put the pretraining data, finetuning data, and evaluation data in the `pretrain`, `finetune`, and `eval_video` folders following [Structure](#structure).
114 |
115 | **Video VQA Dataset.**
116 | Please download the 2.5M subset of [WebVid](https://maxbain.com/webvid-dataset/) and the ActivityNet dataset from the [official website](http://activity-net.org/download.html) or [Video-ChatGPT](https://github.com/mbzuai-oryx/Video-ChatGPT/blob/main/docs/train_video_chatgpt.md).
117 |
118 | If you want to perform evaluation, please also download the corresponding files for
119 | [ActivityNet-QA](https://github.com/mbzuai-oryx/Video-ChatGPT/blob/main/quantitative_evaluation/README.md)
120 | and [NExT-QA-OE](https://github.com/doc-doc/NExT-QA).
121 | You can download
122 | [MSVD-QA](https://mycuhk-my.sharepoint.com/:u:/g/personal/1155186668_link_cuhk_edu_hk/EUNEXqg8pctPq3WZPHb4Fd8BYIxHO5qPCnU6aWsrV1O4JQ?e=guynwu)
123 | and [MSRVTT-QA](https://mycuhk-my.sharepoint.com/:u:/g/personal/1155186668_link_cuhk_edu_hk/EcEXh1HfTXhLrRnuwHbl15IBJeRop-d50Q90njHmhvLwtA?e=SE24eG) from LLaMA-VID.
124 |
125 |
126 | **Meta Info.**
127 | For meta info of training data, please download the following files and organize them as in [Structure](#structure).
128 |
129 | | Training Stage | Data file name | Size |
130 | | --- | --- | ---: |
131 | | Pretrain | [llava_558k_with_webvid.json](https://huggingface.co/datasets/YanweiLi/LLaMA-VID-Data) | 254 MB |
132 | | Finetune | [llava_v1_5_mix665k_with_video_chatgpt.json](https://huggingface.co/datasets/YanweiLi/LLaMA-VID-Data) | 860 MB |
133 |
134 | For meta info of evaluation data, please reformat each QA list into a JSON file named `test_qa.json`, placed as shown in [Structure](#structure), with a format like this:
135 |
136 | ```json
137 | [
138 | {
139 | "video_id": "v_1QIUV7WYKXg",
140 | "question": "is the athlete wearing trousers",
141 | "id": "v_1QIUV7WYKXg_3",
142 | "answer": "no",
143 | "answer_type": 3,
144 | "duration": 9.88
145 | },
146 | {
147 | "video_id": "v_9eniCub7u60",
148 | "question": "does the girl in black clothes have long hair",
149 | "id": "v_9eniCub7u60_2",
150 | "answer": "yes",
151 | "answer_type": 3,
152 | "duration": 19.43
153 | }
154 | ]
155 | ```
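
After reformatting, a small sanity check like the following sketch (the path is only an example) can confirm that a `test_qa.json` file carries the expected keys:

```python
import json

# Example path; point this at whichever dataset you reformatted.
path = "data/eval_video/MSVD-QA/test_qa.json"

with open(path) as f:
    qa_list = json.load(f)

required = {"video_id", "question", "id", "answer"}
for item in qa_list:
    missing = required - set(item.keys())
    assert not missing, f"{item.get('id')} is missing keys: {missing}"
print(f"{len(qa_list)} QA pairs look well-formed.")
```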
156 |
157 | ### Pretrained Weights
158 | We recommend downloading the pretrained weights
159 | [Vicuna-7b-v1.5](https://huggingface.co/lmsys/vicuna-7b-v1.5) and
160 | [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14),
161 | and putting them in `ckpt` following [Structure](#structure).
162 |
163 | ### Feature Extraction
164 |
165 | We recommend extracting ViT features for the training and evaluation data in advance, which speeds up training and evaluation considerably. If you do so, replace `.mp4` with `.safetensors` in each video filename and put the feature files in the `image_features` and `video_features` folders. Otherwise, ignore the `image_features` and `video_features` folders.
166 |
167 | We load video features at fps=1 and arrange them in temporal order.
168 |
169 | Each `.safetensors` file should contain a dict like this:
170 |
171 | ```python
172 | {
173 | 'feature': torch.Tensor,  # shape [256, 1024] for an image, [Length, 256, 1024] for a video
174 | }
175 | ```
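
As an illustration (this is not the repo's extraction script, only a sketch of the expected file layout), such a file could be written and read back with the `safetensors` library like this, using a random tensor as a stand-in for real CLIP ViT-L/14 features:

```python
import torch
from safetensors.torch import save_file, load_file

# Stand-in for real per-frame CLIP features: [Length, 256, 1024] for a video
video_feature = torch.randn(32, 256, 1024)  # e.g. a 32-second clip sampled at fps=1

# Save under the 'feature' key so the feature loaders can find it
save_file({"feature": video_feature}, "v_1QIUV7WYKXg.safetensors")

# Load it back; returns a dict of tensors keyed the same way
loaded = load_file("v_1QIUV7WYKXg.safetensors")
print(loaded["feature"].shape)  # torch.Size([32, 256, 1024])
```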
176 |
177 |
178 | ### Structure
179 | The folder structure should be organized as follows before training.
180 |
181 | ```
182 | Flash-VStream
183 | ├── checkpoints-finetune
184 | ├── checkpoints-pretrain
185 | ├── ckpt
186 | │ ├── clip-vit-large-patch14
187 | │ ├── vicuna-7b-v1.5
188 | ├── data
189 | │ ├── pretrain
190 | │ │ ├── llava_558k_with_webvid.json
191 | │ │ ├── image_features
192 | │ │ ├── images
193 | │ │ ├── video_features
194 | │ │ ├── videos
195 | │ ├── finetune
196 | │ │ ├── llava_v1_5_mix665k_with_video_chatgpt.json
197 | │ │ ├── activitynet
198 | │ │ ├── coco
199 | │ │ ├── gqa
200 | │ │ ├── image_features
201 | │ │ │ ├── coco
202 | │ │ │ ├── gqa
203 | │ │ │ ├── ocr_vqa
204 | │ │ │ ├── textvqa
205 | │ │ │ ├── vg
206 | │ │ ├── ocr_vqa
207 | │ │ ├── textvqa
208 | │ │ ├── vg
209 | │ │ ├── video_features
210 | │ │ │ ├── activitynet
211 | │ ├── eval_video
212 | │ │ ├── ActivityNet-QA
213 | │ │ │ ├── video_features
214 | │ │ │ ├── test_qa.json
215 | │ │ ├── MSRVTT-QA
216 | │ │ │ ├── video_features
217 | │ │ │ ├── test_qa.json
218 | │ │ ├── MSVD-QA
219 | │ │ │ ├── video_features
220 | │ │ │ ├── test_qa.json
221 | │ │ ├── nextoe
222 | │ │ │ ├── video_features
223 | │ │ │ ├── test_qa.json
224 | │ │ ├── vstream
225 | │ │ │ ├── video_features
226 | │ │ │ ├── test_qa.json
227 | │ │ ├── vstream-realtime
228 | │ │ │ ├── video_features
229 | │ │ │ ├── test_qa.json
230 | ├── flash_vstream
231 | ├── scripts
232 |
233 | ```
234 |
235 | ## Train
236 | Flash-VStream is trained on 8 A100 GPUs with 80GB memory each. To train on fewer GPUs, reduce `per_device_train_batch_size` and increase `gradient_accumulation_steps` accordingly, always keeping the global batch size the same: `per_device_train_batch_size` x `gradient_accumulation_steps` x `num_gpus`. If your GPUs have less than 80GB of memory, you may try the ZeRO-2 and ZeRO-3 stages (see `scripts/zero2.json` and `scripts/zero3.json`).
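
For example (hypothetical numbers, not necessarily the defaults in `scripts/train_and_eval.sh`), the invariant to keep is:

```python
# global batch size = per_device_train_batch_size * gradient_accumulation_steps * num_gpus
assert 16 * 1 * 8 == 16 * 2 * 4 == 128  # 8 GPUs, vs. 4 GPUs with doubled gradient accumulation
```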
237 |
238 | Please make sure you download and organize the data following [Preparation](#preparation) before training.
239 |
240 | Like LLaVA, Flash-VStream has two training stages: pretrain and finetune. Their checkpoints are saved in the `checkpoints-pretrain` and `checkpoints-finetune` folders. The two stages take about 15 hours in total on 8 A100 GPUs.
241 |
242 | If you want to train Flash-VStream from pretrained LLM and evaluate it, please run the following command:
243 |
244 | ```bash
245 | bash scripts/train_and_eval.sh
246 | ```
247 |
248 | ## Evaluation
249 | Please make sure you download and organize the data following [Preparation](#preparation) before evaluation.
250 |
251 | If you want to evaluate a Flash-VStream model, please run the following command:
252 |
253 | ```bash
254 | bash scripts/eval.sh
255 | ```
256 |
257 | ## Real-time CLI Inference
258 | We provide a real-time CLI inference script that simulates video stream input by reading frames of a video file at a fixed frame rate. You can ask any question and get an answer at any timestamp of the video stream. Run the following command and have a try:
259 |
260 | ```bash
261 | bash scripts/realtime_cli.sh
262 | ```
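
For intuition, here is a rough sketch (using OpenCV; this is not the actual logic of `flash_vstream/serve/cli_video_stream.py`) of what reading a video file at a fixed frame rate to simulate a live stream looks like:

```python
import time
import cv2  # pip install opencv-python

cap = cv2.VideoCapture("assets/example.mp4")
native_fps = cap.get(cv2.CAP_PROP_FPS) or 30  # fall back if metadata is missing
frame_idx = 0
while True:
    ok, frame = cap.read()
    if not ok:
        break
    frame_idx += 1
    if frame_idx % int(native_fps) == 0:  # keep roughly one frame per second of video
        # The real script would hand this frame to the visual encoder / memory here;
        # we only pace the loop so frames arrive as they would from a live stream.
        time.sleep(1.0)
cap.release()
```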
263 |
264 | ## VStream-QA Benchmark
265 | Please download the VStream-QA benchmark from [this](https://huggingface.co/datasets/IVGSZ/VStream-QA) repo.
266 |
267 | ## Citation
268 | If you find this project useful in your research, please consider citing:
269 |
270 | ```
271 | @article{flashvstream,
272 | title={Flash-VStream: Memory-Based Real-Time Understanding for Long Video Streams},
273 | author={Haoji Zhang and Yiqin Wang and Yansong Tang and Yong Liu and Jiashi Feng and Jifeng Dai and Xiaojie Jin},
274 | year={2024},
275 | eprint={2406.08085},
276 | archivePrefix={arXiv},
277 | primaryClass={cs.CV}
278 | }
279 | ```
280 |
281 | ## Acknowledgement
282 | We would like to thank the following repos for their great work:
283 |
284 | - This work is built upon [LLaVA](https://github.com/haotian-liu/LLaVA).
285 | - This work utilizes LLMs from [Vicuna](https://github.com/lm-sys/FastChat).
286 | - Some code is borrowed from [LLaMA-VID](https://github.com/dvlab-research/LLaMA-VID).
287 | - We perform video-based evaluation from [Video-ChatGPT](https://github.com/mbzuai-oryx/Video-ChatGPT).
288 |
289 | ## License
290 | [](LICENSE)
291 |
292 | This project is licensed under the [Apache-2.0 License](LICENSE).
293 |
294 |
--------------------------------------------------------------------------------
/assets/example.mp4:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/IVGSZ/Flash-VStream/be61d780acb39f29a7193935cf36c210cfc16695/assets/example.mp4
--------------------------------------------------------------------------------
/assets/radar_comparison.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/IVGSZ/Flash-VStream/be61d780acb39f29a7193935cf36c210cfc16695/assets/radar_comparison.png
--------------------------------------------------------------------------------
/assets/teaser2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/IVGSZ/Flash-VStream/be61d780acb39f29a7193935cf36c210cfc16695/assets/teaser2.png
--------------------------------------------------------------------------------
/flash_vstream/__init__.py:
--------------------------------------------------------------------------------
1 | from flash_vstream.model import VStreamLlamaForCausalLM
2 |
--------------------------------------------------------------------------------
/flash_vstream/constants.py:
--------------------------------------------------------------------------------
1 | # Based on https://github.com/haotian-liu/LLaVA.
2 |
3 | CONTROLLER_HEART_BEAT_EXPIRATION = 30
4 | WORKER_HEART_BEAT_INTERVAL = 15
5 |
6 | LOGDIR = "."
7 |
8 | # Model Constants
9 | IGNORE_INDEX = -100
10 | IMAGE_TOKEN_INDEX = -200
11 | DEFAULT_IMAGE_TOKEN = "<image>"
12 | DEFAULT_IMAGE_PATCH_TOKEN = "<im_patch>"
13 | DEFAULT_IM_START_TOKEN = "<im_start>"
14 | DEFAULT_IM_END_TOKEN = "<im_end>"
15 | IMAGE_PLACEHOLDER = "<image-placeholder>"
16 |
--------------------------------------------------------------------------------
/flash_vstream/conversation.py:
--------------------------------------------------------------------------------
1 | # Based on https://github.com/haotian-liu/LLaVA.
2 |
3 | import dataclasses
4 | from enum import auto, Enum
5 | from typing import List, Tuple
6 |
7 |
8 | class SeparatorStyle(Enum):
9 | """Different separator style."""
10 | SINGLE = auto()
11 | TWO = auto()
12 | MPT = auto()
13 | PLAIN = auto()
14 | LLAMA_2 = auto()
15 |
16 |
17 | @dataclasses.dataclass
18 | class Conversation:
19 | """A class that keeps all conversation history."""
20 | system: str
21 | roles: List[str]
22 | messages: List[List[str]]
23 | offset: int
24 | sep_style: SeparatorStyle = SeparatorStyle.SINGLE
25 | sep: str = "###"
26 | sep2: str = None
27 | version: str = "Unknown"
28 |
29 | skip_next: bool = False
30 |
31 | def get_prompt(self):
32 | messages = self.messages
33 | if len(messages) > 0 and type(messages[0][1]) is tuple:
34 | messages = self.messages.copy()
35 | init_role, init_msg = messages[0].copy()
36 | init_msg = init_msg[0].replace("<image>", "").strip()
37 | if 'mmtag' in self.version:
38 | messages[0] = (init_role, init_msg)
39 | messages.insert(0, (self.roles[0], "<Image><image></Image>"))
40 | messages.insert(1, (self.roles[1], "Received."))
41 | else:
42 | messages[0] = (init_role, "<image>\n" + init_msg)
43 |
44 | if self.sep_style == SeparatorStyle.SINGLE:
45 | ret = self.system + self.sep
46 | for role, message in messages:
47 | if message:
48 | if type(message) is tuple:
49 | message, _, _ = message
50 | ret += role + ": " + message + self.sep
51 | else:
52 | ret += role + ":"
53 | elif self.sep_style == SeparatorStyle.TWO:
54 | seps = [self.sep, self.sep2]
55 | ret = self.system + seps[0]
56 | for i, (role, message) in enumerate(messages):
57 | if message:
58 | if type(message) is tuple:
59 | message, _, _ = message
60 | ret += role + ": " + message + seps[i % 2]
61 | else:
62 | ret += role + ":"
63 | elif self.sep_style == SeparatorStyle.MPT:
64 | ret = self.system + self.sep
65 | for role, message in messages:
66 | if message:
67 | if type(message) is tuple:
68 | message, _, _ = message
69 | ret += role + message + self.sep
70 | else:
71 | ret += role
72 | elif self.sep_style == SeparatorStyle.LLAMA_2:
73 | wrap_sys = lambda msg: f"<<SYS>>\n{msg}\n<</SYS>>\n\n"
74 | wrap_inst = lambda msg: f"[INST] {msg} [/INST]"
75 | ret = ""
76 |
77 | for i, (role, message) in enumerate(messages):
78 | if i == 0:
79 | assert message, "first message should not be none"
80 | assert role == self.roles[0], "first message should come from user"
81 | if message:
82 | if type(message) is tuple:
83 | message, _, _ = message
84 | if i == 0: message = wrap_sys(self.system) + message
85 | if i % 2 == 0:
86 | message = wrap_inst(message)
87 | ret += self.sep + message
88 | else:
89 | ret += " " + message + " " + self.sep2
90 | else:
91 | ret += ""
92 | ret = ret.lstrip(self.sep)
93 | elif self.sep_style == SeparatorStyle.PLAIN:
94 | seps = [self.sep, self.sep2]
95 | ret = self.system
96 | for i, (role, message) in enumerate(messages):
97 | if message:
98 | if type(message) is tuple:
99 | message, _, _ = message
100 | ret += message + seps[i % 2]
101 | else:
102 | ret += ""
103 | else:
104 | raise ValueError(f"Invalid style: {self.sep_style}")
105 |
106 | return ret
107 |
108 | def append_message(self, role, message):
109 | self.messages.append([role, message])
110 |
111 | def get_images(self, return_pil=False):
112 | images = []
113 | for i, (role, msg) in enumerate(self.messages[self.offset:]):
114 | if i % 2 == 0:
115 | if type(msg) is tuple:
116 | import base64
117 | from io import BytesIO
118 | from PIL import Image
119 | msg, image, image_process_mode = msg
120 | if image_process_mode == "Pad":
121 | def expand2square(pil_img, background_color=(122, 116, 104)):
122 | width, height = pil_img.size
123 | if width == height:
124 | return pil_img
125 | elif width > height:
126 | result = Image.new(pil_img.mode, (width, width), background_color)
127 | result.paste(pil_img, (0, (width - height) // 2))
128 | return result
129 | else:
130 | result = Image.new(pil_img.mode, (height, height), background_color)
131 | result.paste(pil_img, ((height - width) // 2, 0))
132 | return result
133 | image = expand2square(image)
134 | elif image_process_mode in ["Default", "Crop"]:
135 | pass
136 | elif image_process_mode == "Resize":
137 | image = image.resize((336, 336))
138 | else:
139 | raise ValueError(f"Invalid image_process_mode: {image_process_mode}")
140 | max_hw, min_hw = max(image.size), min(image.size)
141 | aspect_ratio = max_hw / min_hw
142 | max_len, min_len = 800, 400
143 | shortest_edge = int(min(max_len / aspect_ratio, min_len, min_hw))
144 | longest_edge = int(shortest_edge * aspect_ratio)
145 | W, H = image.size
146 | if longest_edge != max(image.size):
147 | if H > W:
148 | H, W = longest_edge, shortest_edge
149 | else:
150 | H, W = shortest_edge, longest_edge
151 | image = image.resize((W, H))
152 | if return_pil:
153 | images.append(image)
154 | else:
155 | buffered = BytesIO()
156 | image.save(buffered, format="PNG")
157 | img_b64_str = base64.b64encode(buffered.getvalue()).decode()
158 | images.append(img_b64_str)
159 | return images
160 |
161 | def to_gradio_chatbot(self):
162 | ret = []
163 | for i, (role, msg) in enumerate(self.messages[self.offset:]):
164 | if i % 2 == 0:
165 | if type(msg) is tuple:
166 | import base64
167 | from io import BytesIO
168 | msg, image, image_process_mode = msg
169 | max_hw, min_hw = max(image.size), min(image.size)
170 | aspect_ratio = max_hw / min_hw
171 | max_len, min_len = 800, 400
172 | shortest_edge = int(min(max_len / aspect_ratio, min_len, min_hw))
173 | longest_edge = int(shortest_edge * aspect_ratio)
174 | W, H = image.size
175 | if H > W:
176 | H, W = longest_edge, shortest_edge
177 | else:
178 | H, W = shortest_edge, longest_edge
179 | image = image.resize((W, H))
180 | buffered = BytesIO()
181 | image.save(buffered, format="JPEG")
182 | img_b64_str = base64.b64encode(buffered.getvalue()).decode()
183 | img_str = f'<img src="data:image/jpeg;base64,{img_b64_str}" alt="user upload image" />'
184 | msg = img_str + msg.replace('<image>', '').strip()
185 | ret.append([msg, None])
186 | else:
187 | ret.append([msg, None])
188 | else:
189 | ret[-1][-1] = msg
190 | return ret
191 |
192 | def copy(self):
193 | return Conversation(
194 | system=self.system,
195 | roles=self.roles,
196 | messages=[[x, y] for x, y in self.messages],
197 | offset=self.offset,
198 | sep_style=self.sep_style,
199 | sep=self.sep,
200 | sep2=self.sep2,
201 | version=self.version)
202 |
203 | def dict(self):
204 | if len(self.get_images()) > 0:
205 | return {
206 | "system": self.system,
207 | "roles": self.roles,
208 | "messages": [[x, y[0] if type(y) is tuple else y] for x, y in self.messages],
209 | "offset": self.offset,
210 | "sep": self.sep,
211 | "sep2": self.sep2,
212 | }
213 | return {
214 | "system": self.system,
215 | "roles": self.roles,
216 | "messages": self.messages,
217 | "offset": self.offset,
218 | "sep": self.sep,
219 | "sep2": self.sep2,
220 | }
221 |
222 |
223 | conv_vicuna_v0 = Conversation(
224 | system="A chat between a curious human and an artificial intelligence assistant. "
225 | "The assistant gives helpful, detailed, and polite answers to the human's questions.",
226 | roles=("Human", "Assistant"),
227 | messages=(
228 | ("Human", "What are the key differences between renewable and non-renewable energy sources?"),
229 | ("Assistant",
230 | "Renewable energy sources are those that can be replenished naturally in a relatively "
231 | "short amount of time, such as solar, wind, hydro, geothermal, and biomass. "
232 | "Non-renewable energy sources, on the other hand, are finite and will eventually be "
233 | "depleted, such as coal, oil, and natural gas. Here are some key differences between "
234 | "renewable and non-renewable energy sources:\n"
235 | "1. Availability: Renewable energy sources are virtually inexhaustible, while non-renewable "
236 | "energy sources are finite and will eventually run out.\n"
237 | "2. Environmental impact: Renewable energy sources have a much lower environmental impact "
238 | "than non-renewable sources, which can lead to air and water pollution, greenhouse gas emissions, "
239 | "and other negative effects.\n"
240 | "3. Cost: Renewable energy sources can be more expensive to initially set up, but they typically "
241 | "have lower operational costs than non-renewable sources.\n"
242 | "4. Reliability: Renewable energy sources are often more reliable and can be used in more remote "
243 | "locations than non-renewable sources.\n"
244 | "5. Flexibility: Renewable energy sources are often more flexible and can be adapted to different "
245 | "situations and needs, while non-renewable sources are more rigid and inflexible.\n"
246 | "6. Sustainability: Renewable energy sources are more sustainable over the long term, while "
247 | "non-renewable sources are not, and their depletion can lead to economic and social instability.\n")
248 | ),
249 | offset=2,
250 | sep_style=SeparatorStyle.SINGLE,
251 | sep="###",
252 | )
253 |
254 | conv_vicuna_v1 = Conversation(
255 | system="A chat between a curious user and an artificial intelligence assistant. "
256 | "The assistant gives helpful, detailed, and polite answers to the user's questions.",
257 | roles=("USER", "ASSISTANT"),
258 | version="v1",
259 | messages=(),
260 | offset=0,
261 | sep_style=SeparatorStyle.TWO,
262 | sep=" ",
263 | sep2="</s>",
264 | )
265 |
266 | conv_vicuna_v1_mcq = Conversation(
267 | system="A chat between a curious user and an artificial intelligence assistant. "
268 | "The assistant gives helpful, detailed, and polite answers to the user's questions. "
269 | "The assistant should give the number of correct answer.",
270 | roles=("USER", "ASSISTANT"),
271 | version="v1",
272 | messages=(),
273 | offset=0,
274 | sep_style=SeparatorStyle.TWO,
275 | sep=" ",
276 | sep2="</s>",
277 | )
278 |
279 | conv_tiny = Conversation(
280 | system="""<|system|>
281 | A conversation between a user and an AI assistant. The assistant gives short and honest answers.""",
282 | roles=("<|user|>\n", "<|assistant|>\n"),
283 | version="mpt",
284 | messages=(),
285 | offset=0,
286 | sep_style=SeparatorStyle.MPT,
287 | sep="",
288 | )
289 |
290 | conv_llama_2 = Conversation(
291 | system="""You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature.
292 |
293 | If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information.""",
294 | roles=("USER", "ASSISTANT"),
295 | version="llama_v2",
296 | messages=(),
297 | offset=0,
298 | sep_style=SeparatorStyle.LLAMA_2,
299 | sep="<s>",
300 | sep2="</s>",
301 | )
302 |
303 | conv_mpt = Conversation(
304 | system="""<|im_start|>system
305 | A conversation between a user and an LLM-based AI assistant. The assistant gives helpful and honest answers.""",
306 | roles=("<|im_start|>user\n", "<|im_start|>assistant\n"),
307 | version="mpt",
308 | messages=(),
309 | offset=0,
310 | sep_style=SeparatorStyle.MPT,
311 | sep="<|im_end|>",
312 | )
313 |
314 | conv_plain = Conversation(
315 | system="",
316 | roles=("", ""),
317 | messages=(
318 | ),
319 | offset=0,
320 | sep_style=SeparatorStyle.PLAIN,
321 | sep="\n",
322 | )
323 |
324 |
325 | default_conversation = conv_vicuna_v1
326 | conv_templates = {
327 | "default": conv_vicuna_v0,
328 | "v0": conv_vicuna_v0,
329 | "v1": conv_vicuna_v1,
330 | "vicuna_v1": conv_vicuna_v1,
331 | "llama_2": conv_llama_2,
332 | "plain": conv_plain,
333 | }
334 |
335 |
336 | if __name__ == "__main__":
337 | print(default_conversation.get_prompt())
338 |
--------------------------------------------------------------------------------
/flash_vstream/eval_video/eval_activitynet_qa.py:
--------------------------------------------------------------------------------
1 | # Based on https://github.com/haotian-liu/LLaVA.
2 |
3 | import os
4 | import ast
5 | import json
6 | import openai
7 | import argparse
8 | from tqdm import tqdm
9 | from time import sleep
10 | from collections import defaultdict
11 | from multiprocessing.pool import Pool
12 |
13 | def parse_args():
14 | parser = argparse.ArgumentParser(description="question-answer-generation-using-gpt-3")
15 | parser.add_argument("--pred_path", required=True, help="The path to file containing prediction.")
16 | parser.add_argument("--output_dir", required=True, help="The path to save annotation json files.")
17 | parser.add_argument("--output_json", required=True, help="The path to save annotation final combined json file.")
18 | parser.add_argument("--num_tasks", required=True, type=int, help="Number of splits.")
19 | parser.add_argument("--num_chunks", default=1, type=int, help="Result splits")
20 | parser.add_argument("--api_key", required=True, type=str, help="OpenAI API key")
21 | parser.add_argument("--api_type", default=None, type=str, help="OpenAI API type")
22 | parser.add_argument("--api_version", default=None, type=str, help="OpenAI API version")
23 | parser.add_argument("--api_base", default=None, type=str, help="OpenAI API base")
24 | args = parser.parse_args()
25 | return args
26 |
27 |
28 | def annotate(prediction_set, caption_files, output_dir):
29 | """
30 | Evaluates question and answer pairs using GPT-3.5 (gpt-3.5-turbo).
31 | Returns a score for correctness.
32 | """
33 | for file in tqdm(caption_files):
34 | key = file[:-5] # Strip file extension
35 | qa_set = prediction_set[key]
36 | question = qa_set['q']
37 | answer = qa_set['a']
38 | pred = qa_set['pred']
39 | try:
40 | # Compute the correctness score
41 | completion = openai.ChatCompletion.create(
42 | model="gpt-3.5-turbo",
43 | messages=[
44 | {
45 | "role": "system",
46 | "content":
47 | "You are an intelligent chatbot designed for evaluating the correctness of generative outputs for question-answer pairs. "
48 | "Your task is to compare the predicted answer with the correct answer and determine if they match meaningfully. Here's how you can accomplish the task:"
49 | "------"
50 | "##INSTRUCTIONS: "
51 | "- Focus on the meaningful match between the predicted answer and the correct answer.\n"
52 | "- Consider synonyms or paraphrases as valid matches.\n"
53 | "- Evaluate the correctness of the prediction compared to the answer."
54 | },
55 | {
56 | "role": "user",
57 | "content":
58 | "Please evaluate the following video-based question-answer pair:\n\n"
59 | f"Question: {question}\n"
60 | f"Correct Answer: {answer}\n"
61 | f"Predicted Answer: {pred}\n\n"
62 | "Provide your evaluation only as a yes/no and score where the score is an integer value between 0 and 5, with 5 indicating the highest meaningful match. "
63 | "Please generate the response in the form of a Python dictionary string with keys 'pred' and 'score', where value of 'pred' is a string of 'yes' or 'no' and value of 'score' is in INTEGER, not STRING."
64 | "DO NOT PROVIDE ANY OTHER OUTPUT TEXT OR EXPLANATION. Only provide the Python dictionary string. "
65 | "For example, your response should look like this: {'pred': 'yes', 'score': 4.8}."
66 | }
67 | ],
68 | temperature=0.002
69 | )
70 | # Convert response to a Python dictionary.
71 | response_message = completion["choices"][0]["message"]["content"]
72 | response_dict = ast.literal_eval(response_message)
73 | result_qa_pair = [response_dict, qa_set]
74 |
75 | # Save the question-answer pairs to a json file.
76 | with open(f"{output_dir}/{key}.json", "w") as f:
77 | json.dump(result_qa_pair, f)
78 | sleep(0.5)
79 |
80 | except Exception as e:
81 | print(f"Error processing file '{key}': {e}")
82 | sleep(1)
83 |
84 |
85 | def main():
86 | """
87 | Main function to control the flow of the program.
88 | """
89 | # Parse arguments.
90 | args = parse_args()
91 |
92 | if args.num_chunks > 1:
93 | pred_contents = []
94 | for _idx in range(args.num_chunks):
95 | file = os.path.join(args.pred_path, f"{args.num_chunks}_{_idx}.json")
96 | pred_contents += [json.loads(line) for line in open(file)]
97 |
98 | else:
99 | file = os.path.join(args.pred_path, "pred.json")
100 | pred_contents = [json.loads(line) for line in open(file)]
101 |
102 | # Dictionary to store the count of occurrences for each video_id
103 | video_id_counts = {}
104 | new_pred_contents = []
105 |
106 | # Iterate through each sample in pred_contents
107 | for sample in pred_contents:
108 | video_id = sample['id']
109 | if video_id in video_id_counts:
110 | video_id_counts[video_id] += 1
111 | else:
112 | video_id_counts[video_id] = 0
113 |
114 | # Create a new sample with the modified key
115 | new_sample = sample
116 | new_sample['id'] = f"{video_id}_{video_id_counts[video_id]}"
117 | new_pred_contents.append(new_sample)
118 |
119 | # Generating list of id's and corresponding files
120 | id_list = [x['id'] for x in new_pred_contents]
121 | caption_files = [f"{id}.json" for id in id_list]
122 |
123 | output_dir = args.output_dir
124 | # Generate output directory if not exists.
125 | if not os.path.exists(output_dir):
126 | os.makedirs(output_dir)
127 |
128 | # Preparing dictionary of question-answer sets
129 | prediction_set = {}
130 | for sample in new_pred_contents:
131 | id = sample['id']
132 | question = sample['question']
133 | answer = sample['answer']
134 | pred = sample['pred']
135 | qa_set = {"q": question, "a": answer, "pred": pred, "a_type": sample['answer_type'] if 'answer_type' in sample else None}
136 | prediction_set[id] = qa_set
137 |
138 | # Set the OpenAI API key.
139 | openai.api_key = args.api_key # Your API key here
140 | if args.api_type:
141 | openai.api_type = args.api_type
142 | if args.api_version:
143 | openai.api_version = args.api_version
144 | if args.api_base:
145 | openai.api_base = args.api_base # Your API base here
146 | num_tasks = args.num_tasks
147 |
148 | # While loop to ensure that all captions are processed.
149 | incomplete_lengths = []
150 | for _ in range(100):
151 | try:
152 | # Files that have already been processed.
153 | completed_files = os.listdir(output_dir)
154 | print(f"completed_files: {len(completed_files)}")
155 |
156 | # Files that have not been processed yet.
157 | incomplete_files = [f for f in caption_files if f not in completed_files]
158 | print(f"incomplete_files: {len(incomplete_files)}")
159 | incomplete_lengths.append(len(incomplete_files))
160 | if len(incomplete_lengths) > 5 and len(set(incomplete_lengths[-5:])) <= 1:
161 | print(f"incomplete_lengths: {incomplete_lengths}")
162 | print(f"incomplete_files: {incomplete_files}")
163 | print(f"completed_files: {completed_files}")
164 | print(f"failed for 5 times, break")
165 | break
166 |
167 | # Break the loop when there are no incomplete files
168 | if len(incomplete_files) == 0:
169 | break
170 | if len(incomplete_files) <= num_tasks:
171 | num_tasks = 1
172 |
173 | # Split tasks into parts.
174 | part_len = len(incomplete_files) // num_tasks
175 | all_parts = [incomplete_files[i:i + part_len] for i in range(0, len(incomplete_files), part_len)]
176 | task_args = [(prediction_set, part, args.output_dir) for part in all_parts]
177 |
178 | # Use a pool of workers to process the files in parallel.
179 | with Pool() as pool:
180 | pool.starmap(annotate, task_args)
181 |
182 | except Exception as e:
183 | print(f"Error: {e}")
184 |
185 | # Combine all the processed files into one
186 | combined_contents = {}
187 | json_path = args.output_json
188 |
189 | # Iterate through json files
190 | for file_name in os.listdir(output_dir):
191 | if file_name.endswith(".json"):
192 | file_path = os.path.join(output_dir, file_name)
193 | with open(file_path, "r") as json_file:
194 | content = json.load(json_file)
195 | assert 'pred' in content[0], f"Error: {file_name} doesn't have key=pred"
196 | assert 'score' in content[0], f"Error: {file_name} doesn't have key=score"
197 | combined_contents[file_name[:-5]] = content
198 |
199 | # Write combined content to a json file
200 | with open(json_path, "w") as json_file:
201 | json.dump(combined_contents, json_file)
202 | print("All evaluation completed!")
203 |
204 | class ScoreMeter:
205 | def __init__(self):
206 | self.score_sum = 0
207 | self.count = 0
208 | self.yes_count = 0
209 | self.no_count = 0
210 | self.score_dict = {'yes': defaultdict(int), 'no': defaultdict(int)}
211 |
212 | def add_score(self, score, pred):
213 | self.score_sum += score
214 | self.count += 1
215 | pred_lower = pred.lower()
216 | if 'yes' in pred_lower:
217 | self.yes_count += 1
218 | self.score_dict['yes'][score] += 1
219 | elif 'no' in pred_lower:
220 | self.no_count += 1
221 | self.score_dict['no'][score] += 1
222 |
223 | def get_average_score(self):
224 | res = (self.score_sum / self.count) if self.count else 0
225 | return f"{res:.6f}"
226 |
227 | def get_accuracy(self, response_type):
228 | if response_type == 'yes':
229 | res = (self.yes_count / self.count) if self.count else 0
230 | elif response_type == 'no':
231 | res = (self.no_count / self.count) if self.count else 0
232 | else:
233 | res = 0
234 | return f"{res:.6f}"
235 |
236 | meter_dic = {'total': ScoreMeter()}
237 | for key, result in combined_contents.items():
238 | # Computing score
239 | score_match = result[0]['score']
240 | score = int(score_match)
241 | pred = result[0]['pred']
242 |
243 | meter_dic["total"].add_score(score, pred)
244 | if 'a_type' in result[1] and result[1]['a_type'] is not None:
245 | typ = str(result[1]['a_type'])
246 | if typ not in meter_dic:
247 | meter_dic[typ] = ScoreMeter()
248 | meter_dic[typ].add_score(score, pred)
249 |
250 | if 'next' in args.output_dir:
251 | typ = typ[0]
252 | if typ not in meter_dic:
253 | meter_dic[typ] = ScoreMeter()
254 | meter_dic[typ].add_score(score, pred)
255 |
256 | csv_dic = {'acc': meter_dic["total"].get_accuracy('yes'), 'score': meter_dic["total"].get_average_score()}
257 |
258 | output = ""
259 | output += "Yes count: " + str(meter_dic["total"].yes_count) + "\n"
260 | output += "No count: " + str(meter_dic["total"].no_count) + "\n"
261 | output += "Accuracy: " + str(meter_dic["total"].get_accuracy('yes')) + "\n"
262 | output += "Average score: " + str(meter_dic["total"].get_average_score()) + "\n"
263 | output += "\n"
264 | output += "Total Score Yes/No distribution:\n"
265 | for key, value in meter_dic["total"].score_dict.items():
266 | output += f"{key}:\n"
267 | for k in range(0, 6):
268 | v = value[k]
269 | output += f"{k}: {v}\n"
270 | output += "\n"
271 | output += "Answer Type Score distribution:\n"
272 | output += 'Type, Accuracy, Avg_score\n'
273 | key_list = sorted([k for k in meter_dic.keys()])
274 | for key in key_list:
275 | output += f"{key}, {meter_dic[key].get_accuracy('yes')}, {meter_dic[key].get_average_score()}\n"
276 | csv_dic[key] = meter_dic[key].get_accuracy('yes')
277 |
278 | output += "\n"
279 | for k in csv_dic.keys():
280 | output += f"{k}, "
281 | output = output.rstrip(', ') # Remove the trailing comma and space
282 | output += "\n"
283 |
284 | for k in csv_dic.keys():
285 | output += str(csv_dic[k]) + ", "
286 | output = output.rstrip(', ') # Remove the trailing comma and space
287 | output += "\n"
288 |
289 | print(output)
290 | args.output_csv = args.output_json.replace(".json", ".csv")
291 | with open(args.output_csv, 'w') as f:
292 | f.write(output)
293 |
294 | if __name__ == "__main__":
295 | main()
296 |
297 |
--------------------------------------------------------------------------------
/flash_vstream/eval_video/eval_any_dataset_features.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Flash-VStream Authors
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import os
16 | import argparse
17 | import subprocess
18 | import multiprocessing
19 |
20 | def exec(cmd, sub=False, device=None):
21 | print(f'exec: {cmd}')
22 | if not sub:
23 | if isinstance(cmd, list):
24 | cmd = ' '.join(cmd)
25 | os.system(cmd)
26 | else:
27 | my_env = os.environ.copy()
28 | my_env["CUDA_VISIBLE_DEVICES"] = device
29 | subprocess.run(cmd, env=my_env)
30 |
31 | # multi gpu, feature
32 | def eval_msvd(args):
33 | model_path = args.model_path
34 | num_chunks = args.num_chunks
35 | if not args.only_eval:
36 | processes = []
37 | for idx in range(0, num_chunks):
38 | cmd = ["python", "flash_vstream/eval_video/model_msvd_qa_featuresloader.py",
39 | "--model-path", model_path,
40 | "--video_dir", "./data/eval_video/MSVD-QA/video_features",
41 | "--gt_file", "./data/eval_video/MSVD-QA/test_qa.json",
42 | "--output_dir", os.path.join(model_path, "evaluation", "msvd"),
43 | "--output_name", "pred",
44 | "--num-chunks", str(num_chunks),
45 | "--chunk-idx", str(idx),
46 | "--conv-mode", "vicuna_v1"]
47 | p = multiprocessing.Process(target=exec, args=(cmd, True, str(idx)))
48 | processes.append(p)
49 | p.start() # start the subprocess
50 | for p in processes:
51 | p.join()
52 | cmd = ["python", "flash_vstream/eval_video/eval_activitynet_qa.py",
53 | "--pred_path", os.path.join(model_path, "evaluation", "msvd"),
54 | "--output_dir", os.path.join(model_path, "evaluation", "msvd", "results"),
55 | "--output_json", os.path.join(model_path, "evaluation", "msvd", "results.json"),
56 | "--num_chunks", str(num_chunks),
57 | "--num_tasks", "16",
58 | "--api_key", args.api_key,
59 | "--api_base", args.api_base,
60 | "--api_type", args.api_type,
61 | "--api_version", args.api_version,
62 | ]
63 | exec(cmd)
64 |
65 | # multi gpu, feature
66 | def eval_msrvtt(args):
67 | model_path = args.model_path
68 | num_chunks = args.num_chunks
69 | if not args.only_eval:
70 | processes = []
71 | for idx in range(0, num_chunks):
72 | cmd = ["python", "flash_vstream/eval_video/model_msvd_qa_featuresloader.py",
73 | "--model-path", model_path,
74 | "--video_dir", "./data/eval_video/MSRVTT-QA/video_features",
75 | "--gt_file", "./data/eval_video/MSRVTT-QA/test_qa.json",
76 | "--output_dir", os.path.join(model_path, "evaluation", "msrvtt"),
77 | "--output_name", "pred",
78 | "--num-chunks", str(num_chunks),
79 | "--chunk-idx", str(idx),
80 | "--conv-mode", "vicuna_v1"]
81 | p = multiprocessing.Process(target=exec, args=(cmd, True, str(idx)))
82 | processes.append(p)
83 | p.start() # start the subprocess
84 | for p in processes:
85 | p.join()
86 | cmd = ["python", "flash_vstream/eval_video/eval_activitynet_qa.py",
87 | "--pred_path", os.path.join(model_path, "evaluation", "msrvtt"),
88 | "--output_dir", os.path.join(model_path, "evaluation", "msrvtt", "results"),
89 | "--output_json", os.path.join(model_path, "evaluation", "msrvtt", "results.json"),
90 | "--num_chunks", str(num_chunks),
91 | "--num_tasks", "16",
92 | "--api_key", args.api_key,
93 | "--api_base", args.api_base,
94 | "--api_type", args.api_type,
95 | "--api_version", args.api_version,
96 | ]
97 | exec(cmd)
98 |
99 | # multi gpu, feature
100 | def eval_actnet(args):
101 | model_path = args.model_path
102 | num_chunks = args.num_chunks
103 | if not args.only_eval:
104 | processes = []
105 | for idx in range(0, num_chunks):
106 | cmd = ["python", "flash_vstream/eval_video/model_msvd_qa_featuresloader.py",
107 | "--model-path", model_path,
108 | "--video_dir", "./data/eval_video/ActivityNet-QA/video_features",
109 | "--gt_file", "./data/eval_video/ActivityNet-QA/test_qa.json",
110 | "--output_dir", os.path.join(model_path, "evaluation", "actnet"),
111 | "--output_name", "pred",
112 | "--num-chunks", str(num_chunks),
113 | "--chunk-idx", str(idx),
114 | "--conv-mode", "vicuna_v1",
115 | ]
116 |
117 | p = multiprocessing.Process(target=exec, args=(cmd, True, str(idx)))
118 | processes.append(p)
119 | p.start() # start the subprocess
120 | for p in processes:
121 | p.join()
122 | cmd = ["python", "flash_vstream/eval_video/eval_activitynet_qa.py",
123 | "--pred_path", os.path.join(model_path, "evaluation", "actnet"),
124 | "--output_dir", os.path.join(model_path, "evaluation", "actnet", "results"),
125 | "--output_json", os.path.join(model_path, "evaluation", "actnet", "results.json"),
126 | "--num_chunks", str(num_chunks),
127 | "--num_tasks", "16",
128 | "--api_key", args.api_key,
129 | "--api_base", args.api_base,
130 | "--api_type", args.api_type,
131 | "--api_version", args.api_version,
132 | ]
133 | exec(cmd)
134 |
135 | # multi gpu, feature
136 | def eval_nextoe(args): # follow msvd format, OE follow actnet
137 | model_path = args.model_path
138 | num_chunks = args.num_chunks
139 | if not args.only_eval:
140 | processes = []
141 | for idx in range(0, num_chunks):
142 | cmd = ["python", "flash_vstream/eval_video/model_msvd_qa_featuresloader.py",
143 | "--model-path", model_path,
144 | "--video_dir", "./data/eval_video/nextoe/video_features",
145 | "--gt_file", "./data/eval_video/nextoe/test_qa.json",
146 | "--output_dir", os.path.join(model_path, "evaluation", "nextoe"),
147 | "--output_name", "pred",
148 | "--num-chunks", str(num_chunks),
149 | "--chunk-idx", str(idx),
150 | "--conv-mode", "vicuna_v1",
151 | ]
152 |
153 | p = multiprocessing.Process(target=exec, args=(cmd, True, str(idx)))
154 | processes.append(p)
155 | p.start() # start the subprocess
156 | for p in processes:
157 | p.join()
158 | cmd = ["python", "flash_vstream/eval_video/eval_activitynet_qa.py",
159 | "--pred_path", os.path.join(model_path, "evaluation", "nextoe"),
160 | "--output_dir", os.path.join(model_path, "evaluation", "nextoe", "results"),
161 | "--output_json", os.path.join(model_path, "evaluation", "nextoe", "results.json"),
162 | "--num_chunks", str(num_chunks),
163 | "--num_tasks", "16",
164 | "--api_key", args.api_key,
165 | "--api_base", args.api_base,
166 | "--api_type", args.api_type,
167 | "--api_version", args.api_version,
168 | ]
169 | exec(cmd)
170 |
171 | # multi gpu, feature
172 | def eval_vsmovienet(args): # follow msvd format
173 | model_path = args.model_path
174 | num_chunks = args.num_chunks
175 | if not args.only_eval:
176 | processes = []
177 | for idx in range(0, num_chunks):
178 | cmd = ["python", "flash_vstream/eval_video/model_msvd_qa_featuresloader.py",
179 | "--model-path", model_path,
180 | "--video_dir", "./data/eval_video/vstream/movienet_video_features",
181 | "--gt_file", "./data/eval_video/vstream/test_qa_movienet.json",
182 | "--output_dir", os.path.join(model_path, "evaluation", "vsmovienet"),
183 | "--output_name", "pred",
184 | "--num-chunks", str(num_chunks),
185 | "--chunk-idx", str(idx),
186 | "--conv-mode", "vicuna_v1",
187 | ]
188 |
189 | p = multiprocessing.Process(target=exec, args=(cmd, True, str(idx)))
190 | processes.append(p)
191 |             p.start()  # start the worker subprocess for this chunk
192 | for p in processes:
193 | p.join()
194 | cmd = ["python", "llama_vstream/eval_video/eval_activitynet_qa.py",
195 | "--pred_path", os.path.join(model_path, "evaluation", "vsmovienet"),
196 | "--output_dir", os.path.join(model_path, "evaluation", "vsmovienet", "results"),
197 | "--output_json", os.path.join(model_path, "evaluation", "vsmovienet", "results.json"),
198 | "--num_chunks", str(num_chunks),
199 | "--num_tasks", "16",
200 | "--api_key", args.api_key,
201 | "--api_base", args.api_base,
202 | "--api_type", args.api_type,
203 | "--api_version", args.api_version,
204 | ]
205 | exec(cmd)
206 |
207 | # multi-GPU evaluation on pre-extracted video features
208 | def eval_vsego4d(args):  # predictions follow the MSVD script format
209 | model_path = args.model_path
210 | num_chunks = args.num_chunks
211 | if not args.only_eval:
212 | processes = []
213 | for idx in range(0, num_chunks):
214 | cmd = ["python", "llama_vstream/eval_video/model_msvd_qa_featuresloader.py",
215 | "--model-path", model_path,
216 | "--video_dir", "./data/eval_video/vstream/ego4d_video_features",
217 | "--gt_file", "./data/eval_video/vstream/test_qa_ego4d.json",
218 | "--output_dir", os.path.join(model_path, "evaluation", "vsego4d"),
219 | "--output_name", "pred",
220 | "--num-chunks", str(num_chunks),
221 | "--chunk-idx", str(idx),
222 | "--conv-mode", "vicuna_v1",
223 | ]
224 |
225 | p = multiprocessing.Process(target=exec, args=(cmd, True, str(idx)))
226 | processes.append(p)
227 | p.start() # 启动子进程
228 | for p in processes:
229 | p.join()
230 | cmd = ["python", "llama_vstream/eval_video/eval_activitynet_qa.py",
231 | "--pred_path", os.path.join(model_path, "evaluation", "vsego4d"),
232 | "--output_dir", os.path.join(model_path, "evaluation", "vsego4d", "results"),
233 | "--output_json", os.path.join(model_path, "evaluation", "vsego4d", "results.json"),
234 | "--num_chunks", str(num_chunks),
235 | "--num_tasks", "16",
236 | "--api_key", args.api_key,
237 | "--api_base", args.api_base,
238 | "--api_type", args.api_type,
239 | "--api_version", args.api_version,
240 | ]
241 | exec(cmd)
242 |
243 | # multi-GPU evaluation on pre-extracted video features
244 | def eval_realtime_vsmovienet(args):  # predictions follow the MSVD script format
245 | model_path = args.model_path
246 | num_chunks = args.num_chunks
247 | if not args.only_eval:
248 | processes = []
249 | for idx in range(0, num_chunks):
250 | cmd = ["python", "llama_vstream/eval_video/model_msvd_qa_featuresloader.py",
251 | "--model-path", model_path,
252 | "--video_dir", "./data/eval_video/vstream-realtime/movienet_video_features",
253 | "--gt_file", "./data/eval_video/vstream-realtime/test_qa_movienet.json",
254 | "--output_dir", os.path.join(model_path, "evaluation", "realtime_vsmovienet"),
255 | "--output_name", "pred",
256 | "--num-chunks", str(num_chunks),
257 | "--chunk-idx", str(idx),
258 | "--conv-mode", "vicuna_v1",
259 | ]
260 |
261 | p = multiprocessing.Process(target=exec, args=(cmd, True, str(idx)))
262 | processes.append(p)
263 |             p.start()  # start the worker subprocess for this chunk
264 | for p in processes:
265 | p.join()
266 | cmd = ["python", "llama_vstream/eval_video/eval_activitynet_qa.py",
267 | "--pred_path", os.path.join(model_path, "evaluation", "realtime_vsmovienet"),
268 | "--output_dir", os.path.join(model_path, "evaluation", "realtime_vsmovienet", "results"),
269 | "--output_json", os.path.join(model_path, "evaluation", "realtime_vsmovienet", "results.json"),
270 | "--num_chunks", str(num_chunks),
271 | "--num_tasks", "16",
272 | "--api_key", args.api_key,
273 | "--api_base", args.api_base,
274 | "--api_type", args.api_type,
275 | "--api_version", args.api_version,
276 | ]
277 | exec(cmd)
278 |
279 | # multi-GPU evaluation on pre-extracted video features
280 | def eval_realtime_vsego4d(args):  # predictions follow the MSVD script format
281 | model_path = args.model_path
282 | num_chunks = args.num_chunks
283 | if not args.only_eval:
284 | processes = []
285 | for idx in range(0, num_chunks):
286 | cmd = ["python", "llama_vstream/eval_video/model_msvd_qa_featuresloader.py",
287 | "--model-path", model_path,
288 | "--video_dir", "./data/eval_video/vstream-realtime/ego4d_video_features",
289 | "--gt_file", "./data/eval_video/vstream-realtime/test_qa_ego4d.json",
290 | "--output_dir", os.path.join(model_path, "evaluation", "realtime_vsego4d"),
291 | "--output_name", "pred",
292 | "--num-chunks", str(num_chunks),
293 | "--chunk-idx", str(idx),
294 | "--conv-mode", "vicuna_v1",
295 | ]
296 |
297 | p = multiprocessing.Process(target=exec, args=(cmd, True, str(idx)))
298 | processes.append(p)
299 | p.start() # 启动子进程
300 | for p in processes:
301 | p.join()
302 | cmd = ["python", "llama_vstream/eval_video/eval_activitynet_qa.py",
303 | "--pred_path", os.path.join(model_path, "evaluation", "realtime_vsego4d"),
304 | "--output_dir", os.path.join(model_path, "evaluation", "realtime_vsego4d", "results"),
305 | "--output_json", os.path.join(model_path, "evaluation", "realtime_vsego4d", "results.json"),
306 | "--num_chunks", str(num_chunks),
307 | "--num_tasks", "16",
308 | "--api_key", args.api_key,
309 | "--api_base", args.api_base,
310 | "--api_type", args.api_type,
311 | "--api_version", args.api_version,
312 | ]
313 | exec(cmd)
314 |
315 |
316 | if __name__ == "__main__":
317 | parser = argparse.ArgumentParser()
318 | parser.add_argument("--model-path", type=str, default="facebook/opt-350m")
319 | parser.add_argument("--dataset", type=str, default=None)
320 | parser.add_argument("--api_key", type=str, default=None)
321 | parser.add_argument("--api_base", type=str, default=None)
322 | parser.add_argument("--api_type", type=str, default=None)
323 | parser.add_argument("--api_version", type=str, default=None)
324 | parser.add_argument("--num_chunks", type=int, default=1)
325 | parser.add_argument("--only_eval", action="store_true")
326 | parser.add_argument("--vizlen", type=int, default=0)
327 | parser.add_argument("--use_speech", action="store_true", default=False)
328 | args = parser.parse_args()
329 | func_dic = {'msvd': eval_msvd,
330 | 'msrvtt': eval_msrvtt,
331 | 'actnet': eval_actnet,
332 | 'nextoe': eval_nextoe,
333 | 'vsmovienet': eval_vsmovienet,
334 | 'vsego4d': eval_vsego4d,
335 | 'realtime_vsmovienet': eval_realtime_vsmovienet,
336 | 'realtime_vsego4d': eval_realtime_vsego4d,
337 | }
338 | if args.dataset in func_dic:
339 | print(f'Execute {args.dataset} evaluation')
340 | func_dic[args.dataset](args)
341 |
--------------------------------------------------------------------------------
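The launchers above hand each chunk's command to an `exec` helper defined earlier in this file (it shadows the Python builtin) and pass the chunk index as its third argument. A minimal stand-in, assuming that argument pins the worker to one GPU via `CUDA_VISIBLE_DEVICES` (the real helper may differ), could look like:

import os
import subprocess

def exec_sketch(cmd, use_gpu=False, gpu_id="0"):
    # Run one worker command; optionally restrict it to a single GPU.
    env = os.environ.copy()
    if use_gpu:
        env["CUDA_VISIBLE_DEVICES"] = gpu_id
    print("Running:", " ".join(cmd))
    subprocess.run(cmd, env=env, check=True)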
/flash_vstream/eval_video/model_msvd_qa.py:
--------------------------------------------------------------------------------
1 | # Based on https://github.com/haotian-liu/LLaVA.
2 |
3 | import os
4 | import json
5 | import math
6 | import torch
7 | import argparse
8 | from tqdm import tqdm
9 | from decord import VideoReader, cpu
10 |
11 | from flash_vstream.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
12 | from flash_vstream.conversation import conv_templates, SeparatorStyle
13 | from flash_vstream.model.builder import load_pretrained_model
14 | from flash_vstream.utils import disable_torch_init
15 | from flash_vstream.mm_utils import tokenizer_image_token, get_model_name_from_path, KeywordsStoppingCriteria
16 |
17 |
18 | def split_list(lst, n):
19 | """Split a list into n (roughly) equal-sized chunks"""
20 |     chunk_size = math.ceil(len(lst) / n)  # ceiling division, so every element is assigned to a chunk
21 | return [lst[i:i+chunk_size] for i in range(0, len(lst), chunk_size)]
22 |
23 |
24 | def get_chunk(lst, n, k):
25 | chunks = split_list(lst, n)
26 | return chunks[k]
27 |
28 |
29 | def parse_args():
30 | """
31 | Parse command-line arguments.
32 | """
33 | parser = argparse.ArgumentParser()
34 |
35 | # Define the command-line arguments
36 | parser.add_argument('--video_dir', help='Directory containing video files.', required=True)
37 | parser.add_argument('--gt_file', help='Path to the ground truth file containing question.', required=True)
38 | parser.add_argument('--output_dir', help='Directory to save the model results JSON.', required=True)
39 | parser.add_argument('--output_name', help='Name of the file for storing results JSON.', required=True)
40 | parser.add_argument("--model-path", type=str, default="facebook/opt-350m")
41 | parser.add_argument("--model-base", type=str, default=None)
42 | parser.add_argument("--conv-mode", type=str, default=None)
43 | parser.add_argument("--num-chunks", type=int, default=1)
44 | parser.add_argument("--chunk-idx", type=int, default=0)
45 | parser.add_argument("--model-max-length", type=int, default=None)
46 |
47 | return parser.parse_args()
48 |
49 |
50 | def load_video(video_path):
51 | vr = VideoReader(video_path, ctx=cpu(0))
52 | total_frame_num = len(vr)
53 | fps = round(vr.get_avg_fps())
54 | frame_idx = [i for i in range(0, len(vr), fps)]
55 | spare_frames = vr.get_batch(frame_idx).asnumpy()
56 | return spare_frames
57 |
58 |
59 | def run_inference(args):
60 | """
61 |     Run open-ended video question answering inference with the Flash-VStream model.
62 |
63 | Args:
64 | args: Command-line arguments.
65 | """
66 | # Initialize the model
67 | model_name = get_model_name_from_path(args.model_path)
68 | tokenizer, model, image_processor, context_len = load_pretrained_model(args.model_path, args.model_base, model_name, args.model_max_length)
69 |
70 | # Load both ground truth file containing questions and answers
71 | with open(args.gt_file) as file:
72 | gt_questions = json.load(file)
73 | gt_questions = get_chunk(gt_questions, args.num_chunks, args.chunk_idx)
74 |
75 | # Create the output directory if it doesn't exist
76 | if not os.path.exists(args.output_dir):
77 | try:
78 | os.makedirs(args.output_dir)
79 | except Exception as e:
80 | print(f'mkdir Except: {e}')
81 |
82 | video_formats = ['.mp4', '.avi', '.mov', '.mkv']
83 | if args.num_chunks > 1:
84 | output_name = f"{args.num_chunks}_{args.chunk_idx}"
85 | else:
86 | output_name = args.output_name
87 | answers_file = os.path.join(args.output_dir, f"{output_name}.json")
88 | ans_file = open(answers_file, "w")
89 |
90 | for sample in tqdm(gt_questions, desc=f"cuda:{args.chunk_idx} "):
91 | video_name = sample['video_id']
92 | question = sample['question']
93 | id = sample['id']
94 | answer = sample['answer']
95 |
96 | sample_set = {'id': id, 'question': question, 'answer': answer}
97 |
98 |         # Load the video file, trying each supported container format
99 |         video_path = None
100 |         for fmt in video_formats:
101 |             temp_path = os.path.join(args.video_dir, f"{video_name}{fmt}")
102 |             if os.path.exists(temp_path):
103 |                 video_path = temp_path
104 |                 break
105 | # Check if the video exists
106 |         if video_path is not None:
107 | video = load_video(video_path)
108 | video = image_processor.preprocess(video, return_tensors='pt')['pixel_values'].half().cuda()
109 | video = [video]
110 |
111 | qs = question
112 | if model.config.mm_use_im_start_end:
113 | qs = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN + '\n' + qs
114 | else:
115 | qs = DEFAULT_IMAGE_TOKEN + '\n' + qs
116 |
117 | conv = conv_templates[args.conv_mode].copy()
118 | conv.append_message(conv.roles[0], qs)
119 | conv.append_message(conv.roles[1], None)
120 | prompt = conv.get_prompt()
121 |
122 | input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).cuda()
123 |
124 | stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2
125 | keywords = [stop_str]
126 | stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids)
127 |
128 | with torch.inference_mode():
129 | output_ids = model.generate(
130 | input_ids,
131 | images=video,
132 | do_sample=True,
133 | temperature=0.002,
134 | max_new_tokens=1024,
135 | use_cache=True,
136 | stopping_criteria=[stopping_criteria])
137 |
138 | input_token_len = input_ids.shape[1]
139 | n_diff_input_output = (input_ids != output_ids[:, :input_token_len]).sum().item()
140 | if n_diff_input_output > 0:
141 | print(f'[Warning] {n_diff_input_output} output_ids are not the same as the input_ids')
142 | outputs = tokenizer.batch_decode(output_ids[:, input_token_len:], skip_special_tokens=True)[0]
143 | outputs = outputs.strip()
144 | if outputs.endswith(stop_str):
145 | outputs = outputs[:-len(stop_str)]
146 | outputs = outputs.strip()
147 |
148 | sample_set['pred'] = outputs
149 | ans_file.write(json.dumps(sample_set) + "\n")
150 | ans_file.flush()
151 |
152 | ans_file.close()
153 |
154 |
155 | if __name__ == "__main__":
156 | args = parse_args()
157 | run_inference(args)
158 |
--------------------------------------------------------------------------------
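`load_video` above samples roughly one frame per second by stepping through the decord reader at its rounded average FPS. The sampling step in isolation (toy numbers, no video file needed):

total_frame_num = 300                         # e.g. a 10-second clip decoded at 30 FPS
fps = 30
frame_idx = list(range(0, total_frame_num, fps))
print(len(frame_idx), frame_idx[:3])          # 10 sampled frames: indices 0, 30, 60, ...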
/flash_vstream/eval_video/model_msvd_qa_featuresloader.py:
--------------------------------------------------------------------------------
1 | # This file may have been modified by Flash-VStream Authors ("Flash-VStream Modifications"). All Flash-VStream Modifications are Copyright 2024 Flash-VStream Authors.
2 | # Based on https://github.com/haotian-liu/LLaVA.
3 |
4 | import os
5 | import json
6 | import math
7 | import torch
8 | import random
9 | import argparse
10 | from tqdm import tqdm
11 | from torch.utils.data import Dataset, DataLoader
12 | from safetensors.torch import load_file
13 |
14 | from flash_vstream.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
15 | from flash_vstream.conversation import conv_templates, SeparatorStyle
16 | from flash_vstream.model.builder import load_pretrained_model
17 | from flash_vstream.mm_utils import tokenizer_image_token, get_model_name_from_path, KeywordsStoppingCriteria
18 |
19 |
20 | def split_list(lst, n):
21 | """Split a list into n (roughly) equal-sized chunks"""
22 |     chunk_size = math.ceil(len(lst) / n)  # ceiling division, so every element is assigned to a chunk
23 | return [lst[i:i+chunk_size] for i in range(0, len(lst), chunk_size)]
24 |
25 |
26 | def get_chunk(lst, n, k):
27 | chunks = split_list(lst, n)
28 | return chunks[k]
29 |
30 |
31 | def parse_args():
32 | """
33 | Parse command-line arguments.
34 | """
35 | parser = argparse.ArgumentParser()
36 |
37 | # Define the command-line arguments
38 | parser.add_argument('--video_dir', help='Directory containing video files.', required=True)
39 | parser.add_argument('--gt_file', help='Path to the ground truth file containing question.', required=True)
40 | parser.add_argument('--output_dir', help='Directory to save the model results JSON.', required=True)
41 | parser.add_argument('--output_name', help='Name of the file for storing results JSON.', required=True)
42 | parser.add_argument("--model-path", type=str, default="facebook/opt-350m")
43 | parser.add_argument("--model-base", type=str, default=None)
44 | parser.add_argument("--conv-mode", type=str, default=None)
45 | parser.add_argument("--num-chunks", type=int, default=1)
46 | parser.add_argument("--chunk-idx", type=int, default=0)
47 | parser.add_argument("--model-max-length", type=int, default=None)
48 | return parser.parse_args()
49 |
50 |
51 | class CustomDataset(Dataset):
52 | def __init__(self, questions, video_dir, tokenizer, image_processor, model_config):
53 | self.questions = questions
54 | self.video_dir = video_dir
55 | self.tokenizer = tokenizer
56 | self.image_processor = image_processor
57 | self.model_config = model_config
58 |
59 | def __getitem__(self, index):
60 | sample = self.questions[index]
61 | video_name = sample['video_id']
62 | try:
63 | video_path = os.path.join(self.video_dir, video_name + '.safetensors')
64 | video_tensor = load_file(video_path)['feature']
65 | except Exception as e:
66 | print(f'Dataset Exception: {e}, randomly choose one.')
67 | idx = random.randint(0, len(self.questions) - 1)
68 | return self.__getitem__(idx)
69 | qs = sample['question']
70 | if self.model_config.mm_use_im_start_end:
71 | qs = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN + '\n' + qs
72 | else:
73 | qs = DEFAULT_IMAGE_TOKEN + '\n' + qs
74 | conv = conv_templates[args.conv_mode].copy()
75 | if 'system' in sample:
76 | conv.system = conv.system + ' ' + sample['system']
77 | conv.append_message(conv.roles[0], qs)
78 | conv.append_message(conv.roles[1], None)
79 | prompt = conv.get_prompt()
80 | input_ids = tokenizer_image_token(prompt, self.tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt')
81 | return input_ids, video_tensor
82 |
83 | def __len__(self):
84 | return len(self.questions)
85 |
86 |
87 | def create_data_loader(questions, video_dir, tokenizer, image_processor, model_config, batch_size=1, num_workers=2):
88 | assert batch_size == 1, "batch_size must be 1"
89 | dataset = CustomDataset(questions, video_dir, tokenizer, image_processor, model_config)
90 | data_loader = DataLoader(dataset, batch_size=batch_size, num_workers=num_workers, shuffle=False)
91 | return data_loader
92 |
93 |
94 | def run_inference(args):
95 | """
96 |     Run open-ended video question answering inference with the Flash-VStream model, loading pre-extracted video features.
97 |
98 | Args:
99 | args: Command-line arguments.
100 | """
101 | # Initialize the model
102 | model_name = get_model_name_from_path(args.model_path)
103 | tokenizer, model, image_processor, context_len = load_pretrained_model(args.model_path, args.model_base, model_name, args.model_max_length)
104 |
105 | # Load both ground truth file containing questions and answers
106 | with open(args.gt_file) as file:
107 | gt_questions = json.load(file)
108 | gt_questions = get_chunk(gt_questions, args.num_chunks, args.chunk_idx)
109 |
110 | # Create the output directory if it doesn't exist
111 | if not os.path.exists(args.output_dir):
112 | try:
113 | os.makedirs(args.output_dir)
114 | except Exception as e:
115 | print(f'mkdir Except: {e}')
116 |
117 | video_formats = ['.mp4', '.avi', '.mov', '.mkv']
118 | if args.num_chunks > 1:
119 | output_name = f"{args.num_chunks}_{args.chunk_idx}"
120 | else:
121 | output_name = args.output_name
122 | answers_file = os.path.join(args.output_dir, f"{output_name}.json")
123 | # resume from old exp
124 | exist_id_set = set()
125 | if os.path.exists(answers_file):
126 | with open(answers_file) as f:
127 | exist_pred_contents = [json.loads(line) for line in f]
128 | exist_id_set = set([x['id'] for x in exist_pred_contents])
129 |
130 | new_gt_questions = []
131 | for sample in tqdm(gt_questions):
132 | if not sample['id'] in exist_id_set:
133 | new_gt_questions.append(sample)
134 | gt_questions = new_gt_questions
135 |
136 | data_loader = create_data_loader(gt_questions, args.video_dir, tokenizer, image_processor, model.config)
137 |
138 | conv = conv_templates[args.conv_mode].copy()
139 | stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2
140 | keywords = [stop_str]
141 |
142 | with open(answers_file, "a") as ans_file:
143 | for data, sample in tqdm(zip(data_loader, gt_questions), desc=f"cuda:{args.chunk_idx} ", total=len(gt_questions)):
144 | input_ids, video_tensors = data
145 | input_ids = input_ids.to(device='cuda', non_blocking=True)
146 | stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids)
147 | with torch.inference_mode():
148 | output_ids = model.generate(
149 | input_ids,
150 | features=video_tensors.to(dtype=torch.float16, device='cuda', non_blocking=True),
151 | do_sample=True,
152 | temperature=0.002,
153 | max_new_tokens=1024,
154 | use_cache=True,
155 | stopping_criteria=[stopping_criteria],
156 | )
157 | input_token_len = input_ids.shape[1]
158 | n_diff_input_output = (input_ids != output_ids[:, :input_token_len]).sum().item()
159 | if n_diff_input_output > 0:
160 | print(f'[Warning] {n_diff_input_output} output_ids are not the same as the input_ids')
161 | outputs = tokenizer.batch_decode(output_ids[:, input_token_len:], skip_special_tokens=True)[0]
162 | outputs = outputs.strip()
163 | if outputs.endswith(stop_str):
164 | outputs = outputs[:-len(stop_str)]
165 | outputs = outputs.strip()
166 | sample_set = {
167 | 'id': sample['id'],
168 | 'question': sample['question'],
169 | 'answer': sample['answer'],
170 | 'answer_type': sample['answer_type'] if 'answer_type' in sample else None,
171 | 'pred': outputs
172 | }
173 | ans_file.write(json.dumps(sample_set) + "\n")
174 | ans_file.flush()
175 |
176 |
177 | if __name__ == "__main__":
178 | args = parse_args()
179 | run_inference(args)
180 |
--------------------------------------------------------------------------------
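`CustomDataset` above expects one `<video_id>.safetensors` file per video in `--video_dir`, holding the pre-extracted vision features under the key `'feature'`. A minimal sketch of writing such a file, assuming a toy `[T, P, D]` (frames, patch tokens, feature dim) tensor and a made-up video id:

import torch
from safetensors.torch import save_file, load_file

feature = torch.randn(64, 256, 1024, dtype=torch.float16)  # toy [T, P, D] feature stack
save_file({'feature': feature}, 'video0001.safetensors')   # hypothetical video_id "video0001"

loaded = load_file('video0001.safetensors')['feature']     # mirrors what __getitem__ does above
print(loaded.shape)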
/flash_vstream/mm_utils.py:
--------------------------------------------------------------------------------
1 | # Based on https://github.com/haotian-liu/LLaVA.
2 |
3 | from PIL import Image
4 | from io import BytesIO
5 | import base64
6 |
7 | import torch
8 | from transformers import StoppingCriteria
9 | from flash_vstream.constants import IMAGE_TOKEN_INDEX
10 |
11 |
12 | def load_image_from_base64(image):
13 | return Image.open(BytesIO(base64.b64decode(image)))
14 |
15 |
16 | def expand2square(pil_img, background_color):
17 | width, height = pil_img.size
18 | if width == height:
19 | return pil_img
20 | elif width > height:
21 | result = Image.new(pil_img.mode, (width, width), background_color)
22 | result.paste(pil_img, (0, (width - height) // 2))
23 | return result
24 | else:
25 | result = Image.new(pil_img.mode, (height, height), background_color)
26 | result.paste(pil_img, ((height - width) // 2, 0))
27 | return result
28 |
29 |
30 | def process_images(images, image_processor, model_cfg):
31 | image_aspect_ratio = getattr(model_cfg, "image_aspect_ratio", None)
32 | new_images = []
33 | if image_aspect_ratio == 'pad':
34 | for image in images:
35 | image = expand2square(image, tuple(int(x*255) for x in image_processor.image_mean))
36 | image = image_processor.preprocess(image, return_tensors='pt')['pixel_values'][0]
37 | new_images.append(image)
38 | else:
39 | return image_processor(images, return_tensors='pt')['pixel_values']
40 | if all(x.shape == new_images[0].shape for x in new_images):
41 | new_images = torch.stack(new_images, dim=0)
42 | return new_images
43 |
44 |
45 | def tokenizer_image_token(prompt, tokenizer, image_token_index=IMAGE_TOKEN_INDEX, return_tensors=None):
46 |     prompt_chunks = [tokenizer(chunk).input_ids for chunk in prompt.split('<image>')]
47 |
48 | def insert_separator(X, sep):
49 | return [ele for sublist in zip(X, [sep]*len(X)) for ele in sublist][:-1]
50 |
51 | input_ids = []
52 | offset = 0
53 | if len(prompt_chunks) > 0 and len(prompt_chunks[0]) > 0 and prompt_chunks[0][0] == tokenizer.bos_token_id:
54 | offset = 1
55 | input_ids.append(prompt_chunks[0][0])
56 |
57 | for x in insert_separator(prompt_chunks, [image_token_index] * (offset + 1)):
58 | input_ids.extend(x[offset:])
59 |
60 | if return_tensors is not None:
61 | if return_tensors == 'pt':
62 | return torch.tensor(input_ids, dtype=torch.long)
63 | raise ValueError(f'Unsupported tensor type: {return_tensors}')
64 | return input_ids
65 |
66 |
67 | def get_model_name_from_path(model_path):
68 | model_path = model_path.strip("/")
69 | model_paths = model_path.split("/")
70 | if model_paths[-1].startswith('checkpoint-'):
71 | return model_paths[-2] + "_" + model_paths[-1]
72 | else:
73 | return model_paths[-1]
74 |
75 | class KeywordsStoppingCriteria(StoppingCriteria):
76 | def __init__(self, keywords, tokenizer, input_ids):
77 | self.keywords = keywords
78 | self.keyword_ids = []
79 | self.max_keyword_len = 0
80 | for keyword in keywords:
81 | cur_keyword_ids = tokenizer(keyword).input_ids
82 | if len(cur_keyword_ids) > 1 and cur_keyword_ids[0] == tokenizer.bos_token_id:
83 | cur_keyword_ids = cur_keyword_ids[1:]
84 | if len(cur_keyword_ids) > self.max_keyword_len:
85 | self.max_keyword_len = len(cur_keyword_ids)
86 | self.keyword_ids.append(torch.tensor(cur_keyword_ids))
87 | self.tokenizer = tokenizer
88 | self.start_len = input_ids.shape[1]
89 |
90 | def call_for_batch(self, output_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
91 | offset = min(output_ids.shape[1] - self.start_len, self.max_keyword_len)
92 | self.keyword_ids = [keyword_id.to(output_ids.device) for keyword_id in self.keyword_ids]
93 | for keyword_id in self.keyword_ids:
94 | if (output_ids[0, -keyword_id.shape[0]:] == keyword_id).all():
95 | return True
96 | outputs = self.tokenizer.batch_decode(output_ids[:, -offset:], skip_special_tokens=True)[0]
97 | for keyword in self.keywords:
98 | if keyword in outputs:
99 | return True
100 | return False
101 |
102 | def __call__(self, output_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
103 | outputs = []
104 | for i in range(output_ids.shape[0]):
105 | outputs.append(self.call_for_batch(output_ids[i].unsqueeze(0), scores))
106 | return all(outputs)
107 |
--------------------------------------------------------------------------------
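`tokenizer_image_token` splits the prompt on the `<image>` placeholder and splices `IMAGE_TOKEN_INDEX` (a sentinel id from `flash_vstream/constants.py`) between the tokenized text chunks, marking where visual embeddings are inserted later. A usage sketch, assuming a LLaMA-family tokenizer such as `lmsys/vicuna-7b-v1.5` is available:

from transformers import AutoTokenizer
from flash_vstream.constants import IMAGE_TOKEN_INDEX
from flash_vstream.mm_utils import tokenizer_image_token

tokenizer = AutoTokenizer.from_pretrained("lmsys/vicuna-7b-v1.5", use_fast=False)  # assumed tokenizer
prompt = "USER: <image>\nWhat is happening in the video? ASSISTANT:"
input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt')
print((input_ids == IMAGE_TOKEN_INDEX).nonzero())  # position reserved for the visual features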
/flash_vstream/model/__init__.py:
--------------------------------------------------------------------------------
1 | from .language_model.vstream_llama import VStreamLlamaForCausalLM, VStreamConfig
2 |
--------------------------------------------------------------------------------
/flash_vstream/model/builder.py:
--------------------------------------------------------------------------------
1 | # This file may have been modified by Flash-VStream Authors ("Flash-VStream Modifications"). All Flash-VStream Modifications are Copyright 2024 Flash-VStream Authors.
2 | # ------------------------------------------------------------------------
3 | # Based on https://github.com/haotian-liu/LLaVA. Below is the original copyright:
4 | # Copyright 2023 Haotian Liu
5 | #
6 | # Licensed under the Apache License, Version 2.0 (the "License");
7 | # you may not use this file except in compliance with the License.
8 | # You may obtain a copy of the License at
9 | #
10 | # http://www.apache.org/licenses/LICENSE-2.0
11 | #
12 | # Unless required by applicable law or agreed to in writing, software
13 | # distributed under the License is distributed on an "AS IS" BASIS,
14 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15 | # See the License for the specific language governing permissions and
16 | # limitations under the License.
17 |
18 |
19 | import os
20 | import warnings
21 | import shutil
22 |
23 | from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig, BitsAndBytesConfig
24 | import torch
25 | from flash_vstream.model import VStreamLlamaForCausalLM, VStreamConfig
26 | from flash_vstream.constants import DEFAULT_IMAGE_PATCH_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
27 |
28 |
29 | def load_pretrained_model(model_path, model_base, model_name, load_8bit=False, load_4bit=False, device_map="auto", device="cuda", **kwargs):
30 | kwargs = {"device_map": device_map, **kwargs}
31 |
32 | if device != "cuda":
33 | kwargs['device_map'] = {"": device}
34 |
35 | if load_8bit:
36 | kwargs['load_in_8bit'] = True
37 | elif load_4bit:
38 | kwargs['load_in_4bit'] = True
39 | kwargs['quantization_config'] = BitsAndBytesConfig(
40 | load_in_4bit=True,
41 | bnb_4bit_compute_dtype=torch.float16,
42 | bnb_4bit_use_double_quant=True,
43 | bnb_4bit_quant_type='nf4'
44 | )
45 | else:
46 | kwargs['torch_dtype'] = torch.float16
47 |
48 | if 'vstream' in model_name.lower():
49 | # Load LLaMA-VStream model
50 | if 'lora' in model_name.lower() and model_base is None:
51 | warnings.warn('There is `lora` in model name but no `model_base` is provided. If you are loading a LoRA model, please provide the `model_base` argument. Detailed instruction: https://github.com/haotian-liu/LLaVA#launch-a-model-worker-lora-weights-unmerged.')
52 | if 'lora' in model_name.lower() and model_base is not None:
53 | lora_cfg_pretrained = AutoConfig.from_pretrained(model_path)
54 | tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=False)
55 | print('(LoRA) Loading LLaMA-VStream from base model...')
56 | model = VStreamLlamaForCausalLM.from_pretrained(model_base, low_cpu_mem_usage=True, config=lora_cfg_pretrained, **kwargs)
57 |             token_num, token_dim = model.lm_head.out_features, model.lm_head.in_features
58 |             if model.lm_head.weight.shape[0] != token_num:
59 |                 model.lm_head.weight = torch.nn.Parameter(torch.empty(token_num, token_dim, device=model.device, dtype=model.dtype))
60 |                 model.model.embed_tokens.weight = torch.nn.Parameter(torch.empty(token_num, token_dim, device=model.device, dtype=model.dtype))
61 |
62 | print('(LoRA) Loading additional LLaMA-VStream weights...')
63 | if os.path.exists(os.path.join(model_path, 'non_lora_trainables.bin')):
64 | non_lora_trainables = torch.load(os.path.join(model_path, 'non_lora_trainables.bin'), map_location='cpu')
65 | else:
66 | # this is probably from HF Hub
67 | from huggingface_hub import hf_hub_download
68 | def load_from_hf(repo_id, filename, subfolder=None):
69 | cache_file = hf_hub_download(
70 | repo_id=repo_id,
71 | filename=filename,
72 | subfolder=subfolder)
73 | return torch.load(cache_file, map_location='cpu')
74 | non_lora_trainables = load_from_hf(model_path, 'non_lora_trainables.bin')
75 | non_lora_trainables = {(k[11:] if k.startswith('base_model.') else k): v for k, v in non_lora_trainables.items()}
76 | if any(k.startswith('model.model.') for k in non_lora_trainables):
77 | non_lora_trainables = {(k[6:] if k.startswith('model.') else k): v for k, v in non_lora_trainables.items()}
78 | model.load_state_dict(non_lora_trainables, strict=False)
79 |
80 | from peft import PeftModel
81 | print('Loading LoRA weights...')
82 | model = PeftModel.from_pretrained(model, model_path)
83 | print('Merging LoRA weights...')
84 | model = model.merge_and_unload()
85 | print('Model is loaded...')
86 | elif model_base is not None:
87 | # this may be mm projector only
88 | print('Loading LLaMA-VStream from base model...')
89 | tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=False)
90 | cfg_pretrained = AutoConfig.from_pretrained(model_path)
91 | model = VStreamLlamaForCausalLM.from_pretrained(model_base, low_cpu_mem_usage=True, config=cfg_pretrained, **kwargs)
92 |
93 | mm_projector_weights = torch.load(os.path.join(model_path, 'mm_projector.bin'), map_location='cpu')
94 | mm_projector_weights = {k: v.to(torch.float16) for k, v in mm_projector_weights.items()}
95 | model.load_state_dict(mm_projector_weights, strict=False)
96 | else:
97 | tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False)
98 | model = VStreamLlamaForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, **kwargs)
99 | else:
100 | # Load language model
101 | if model_base is not None:
102 | # PEFT model
103 | from peft import PeftModel
104 | tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=False)
105 | model = AutoModelForCausalLM.from_pretrained(model_base, low_cpu_mem_usage=True, **kwargs)
106 | print(f"Loading LoRA weights from {model_path}")
107 | model = PeftModel.from_pretrained(model, model_path)
108 | print(f"Merging weights")
109 | model = model.merge_and_unload()
110 | print('Convert to FP16...')
111 | model.to(torch.float16)
112 | else:
113 | use_fast = False
114 | tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False)
115 | model = AutoModelForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, **kwargs)
116 |
117 | image_processor = None
118 |
119 | if 'vstream' in model_name.lower():
120 | mm_use_im_start_end = getattr(model.config, "mm_use_im_start_end", False)
121 | mm_use_im_patch_token = getattr(model.config, "mm_use_im_patch_token", True)
122 | if mm_use_im_patch_token:
123 | tokenizer.add_tokens([DEFAULT_IMAGE_PATCH_TOKEN], special_tokens=True)
124 | if mm_use_im_start_end:
125 | tokenizer.add_tokens([DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True)
126 | model.resize_token_embeddings(len(tokenizer))
127 |
128 | vision_tower = model.get_vision_tower()
129 | if not vision_tower.is_loaded:
130 | vision_tower.load_model()
131 | vision_tower.to(device=device, dtype=torch.float16)
132 | image_processor = vision_tower.image_processor
133 |
134 | if hasattr(model.config, "max_sequence_length"):
135 | context_len = model.config.max_sequence_length
136 | else:
137 | context_len = 2048
138 |
139 | return tokenizer, model, image_processor, context_len
140 |
--------------------------------------------------------------------------------
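A usage sketch for `load_pretrained_model` above, with a hypothetical local checkpoint path; the model name derived from the path decides whether the VStream branch or the plain language-model branch is taken, and `load_4bit` enables the BitsAndBytes path shown above:

from flash_vstream.mm_utils import get_model_name_from_path
from flash_vstream.model.builder import load_pretrained_model

model_path = "./checkpoints/flash-vstream-7b"        # hypothetical path; name contains "vstream"
model_name = get_model_name_from_path(model_path)
tokenizer, model, image_processor, context_len = load_pretrained_model(
    model_path, None, model_name, load_4bit=True)    # 4-bit loading to reduce GPU memory
print(model_name, context_len)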
/flash_vstream/model/compress_functions.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Flash-VStream Authors
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import random
16 | import torch
17 | import torch.nn as nn
18 | import torch.nn.functional as F
19 |
20 | def drop_feature(img_feature, video_max_frames, img_similarity=None):
21 | T, P, D = img_feature.shape
22 | indices = [[i] for i in range(T)]
23 | T0 = video_max_frames
24 | if T <= T0:
25 | return img_feature, img_similarity, [indices]
26 | cur_feature = img_feature[:T0] # [T0, P, D]
27 | if img_similarity is not None:
28 | cur_sim = img_similarity[:T0 - 1]
29 | else:
30 | cur_sim = F.cosine_similarity(cur_feature[:-1].view(T0 - 1, P * D), cur_feature[1:].view(T0 - 1, P * D)) # [T0 - 1]
31 | cur_indices = indices[:T0]
32 | step_indices = [cur_indices]
33 | for i in range(T0, T):
34 | new_feature = img_feature[i]
35 | new_sim = F.cosine_similarity(cur_feature[-1].view(-1), new_feature.view(-1), dim=0)
36 | all_feature = torch.cat([cur_feature, new_feature.unsqueeze(0)], dim=0)
37 | all_indices = cur_indices + [[i]]
38 | all_sim = torch.cat([cur_sim, new_sim.unsqueeze(0)], dim=0)
39 | idx = torch.argmax(all_sim)
40 | if random.randint(0, 1) > 0:
41 | idx = idx + 1
42 | cur_feature = torch.cat([all_feature[:idx], all_feature[idx + 1:]])
43 | if idx + 1 == T0 + 1:
44 | cur_sim = all_sim[:T0 - 1]
45 | cur_indices = all_indices[:-1]
46 | elif idx == 0:
47 | cur_sim = all_sim[1:]
48 | cur_indices = all_indices[1:]
49 | else:
50 | cur_sim = torch.cat([all_sim[:idx], all_sim[idx + 1:]])
51 | cur_sim[idx - 1] = F.cosine_similarity(all_feature[idx - 1].view(-1), all_feature[idx + 1].view(-1), dim=0)
52 | cur_indices = all_indices[:idx] + all_indices[idx + 1:]
53 | step_indices.append(cur_indices)
54 | # print(f'Note: perform drop feature {img_feature.shape} to {cur_feature.shape}')
55 | return cur_feature, cur_sim, step_indices
56 |
57 |
58 | def merge_feature(img_feature, video_max_frames, img_similarity=None):
59 | T, P, D = img_feature.shape
60 | indices = [[i] for i in range(T)]
61 | T0 = video_max_frames
62 | if T <= T0:
63 | return img_feature, img_similarity, [indices]
64 | cur_feature = img_feature[:T0] # [T0, P, D]
65 | cur_indices = indices[:T0]
66 | step_indices = [cur_indices]
67 | if img_similarity is not None:
68 | cur_sim = img_similarity[:T0 - 1]
69 | else:
70 | cur_sim = F.cosine_similarity(cur_feature[:-1].view(T0 - 1, P * D), cur_feature[1:].view(T0 - 1, P * D)) # [T0 - 1]
71 | for i in range(T0, T):
72 | new_feature = img_feature[i]
73 | new_sim = F.cosine_similarity(cur_feature[-1].view(-1), new_feature.view(-1), dim=0)
74 | all_feature = torch.cat([cur_feature, new_feature.unsqueeze(0)], dim=0)
75 | all_sim = torch.cat([cur_sim, new_sim.unsqueeze(0)], dim=0)
76 | all_indices = cur_indices + [[i]]
77 | idx = torch.argmax(all_sim)
78 | all_feature[idx + 1] = (all_feature[idx] + all_feature[idx + 1]) / 2.0
79 | all_indices[idx + 1] = all_indices[idx] + all_indices[idx + 1]
80 | cur_feature = torch.cat([all_feature[:idx], all_feature[idx + 1:]])
81 | cur_sim = torch.cat([all_sim[:idx], all_sim[idx + 1:]])
82 | cur_indices = all_indices[:idx] + all_indices[idx + 1:]
83 | if idx > 0:
84 | cur_sim[idx - 1] = F.cosine_similarity(all_feature[idx - 1].view(-1), all_feature[idx + 1].view(-1), dim=0)
85 | if idx + 1 < T0:
86 | cur_sim[idx] = F.cosine_similarity(all_feature[idx + 1].view(-1), all_feature[idx + 2].view(-1), dim=0)
87 | step_indices.append(cur_indices)
88 | # print(f'Note: perform merge feature {img_feature.shape} to {cur_feature.shape}')
89 | return cur_feature, cur_sim, step_indices
90 |
91 |
92 | def kmeans_feature(img_feature, video_max_frames, img_similarity=None):
93 | def kmeans_torch(X, num_clusters, distance='euclidean', tol=1e-4, max_iter=10):
94 | indices = torch.randperm(X.size(0))[:num_clusters]
95 | centroids = X[indices]
96 | for i in range(max_iter):
97 | if distance == 'euclidean':
98 | dists = torch.cdist(X, centroids, p=2)
99 | else:
100 | raise NotImplementedError("Only Euclidean distance is supported yet")
101 | labels = torch.argmin(dists, dim=1)
102 | new_centroids = []
103 | for j in range(num_clusters):
104 | cluster_points = X[labels == j]
105 | if len(cluster_points) > 0:
106 | new_centroid = cluster_points.mean(0)
107 | else: # fix nan centroids
108 | new_centroid = X[random.randint(0, X.size(0) - 1)]
109 | new_centroids.append(new_centroid)
110 | new_centroids = torch.stack(new_centroids)
111 | diff = torch.norm(centroids - new_centroids, dim=1).sum()
112 | if diff < tol:
113 | break
114 | centroids = new_centroids
115 | return centroids, labels, i
116 | T, P, D = img_feature.shape
117 | T0 = video_max_frames
118 | if T <= T0:
119 | return img_feature, img_similarity, [[[i] for i in range(T)]]
120 | X = img_feature.view(T, -1) # [T, P, D]
121 | centroids, labels, exit_step = kmeans_torch(X, T0)
122 | reduced_feature = centroids.view(T0, P, D)
123 | # print(f'Note: perform kmeans feature {img_feature.shape} to {reduced_feature.shape}, exit at step={exit_step}') # actually, K=T0
124 | step_indices = [[] for _ in range(T0)]
125 | for i in range(T0):
126 | step_indices[i] = [j for j in range(T) if labels[j] == i]
127 | return reduced_feature, img_similarity, [step_indices]
128 |
129 |
130 | def weighted_kmeans_feature(img_feature, video_max_frames, weights=None):
131 | if weights is None:
132 | weights = torch.ones(img_feature.size(0), dtype=img_feature.dtype, device=img_feature.device)
133 | def weighted_kmeans_torch(X, num_clusters, weights=None, distance='euclidean', tol=1e-4, max_iter=10):
134 | indices = torch.randperm(X.size(0), device=X.device)[:num_clusters]
135 | centroids = X[indices]
136 | for i in range(max_iter):
137 | if distance == 'euclidean':
138 | dists = ((X.unsqueeze(1) - centroids.unsqueeze(0)) ** 2).sum(dim=2).sqrt()
139 | else:
140 | raise NotImplementedError("Only Euclidean distance is supported yet")
141 | labels = torch.argmin(dists, dim=1)
142 | weighted_sum = torch.zeros_like(centroids)
143 | weights_sum = torch.zeros(num_clusters, dtype=X.dtype, device=X.device)
144 | for j in range(num_clusters):
145 | cluster_mask = labels == j
146 | weighted_sum[j] = torch.sum(weights[cluster_mask, None] * X[cluster_mask], dim=0)
147 | weights_sum[j] = torch.sum(weights[cluster_mask])
148 | mask = weights_sum > 0
149 | new_centroids = torch.zeros_like(weighted_sum)
150 | new_centroids[mask] = weighted_sum[mask] / weights_sum[mask, None]
151 | if mask.sum() < num_clusters: # fix nan centroids
152 | new_centroids[~mask] = torch.stack([X[random.randint(0, X.size(0) - 1)] for _ in range(num_clusters - mask.sum())])
153 | diff = torch.norm(centroids - new_centroids, dim=1).sum()
154 | if diff < tol:
155 | break
156 | centroids = new_centroids
157 | return centroids, labels, weights_sum, i
158 | T, P, D = img_feature.shape
159 | T0 = video_max_frames
160 | if T <= T0:
161 | return img_feature, weights, [[[i] for i in range(T)]]
162 | X = img_feature.view(T, -1) # [T, P, D]
163 | centroids, labels, weights, exit_step = weighted_kmeans_torch(X, T0, weights)
164 | reduced_feature = centroids.view(T0, P, D)
165 | # print(f'Note: perform weighted kmeans feature {img_feature.shape} to {reduced_feature.shape}, exit at step={exit_step}') # actually, K=T0
166 | step_indices = [[] for _ in range(T0)]
167 | for i in range(T0):
168 | step_indices[i] = [j for j in range(T) if labels[j] == i]
169 | return reduced_feature, weights, [step_indices]
170 |
171 |
172 | def k_drop_feature(img_feature, video_max_frames, img_similarity=None):
173 | T, P, D = img_feature.shape
174 | indices = [[i] for i in range(T)]
175 | T0 = video_max_frames
176 | if T <= T0:
177 | return img_feature, img_similarity, [indices]
178 | cur_feature = img_feature[:T0] # [T0, P, D]
179 | normed_cur_features = F.normalize(cur_feature.view(T0, P * D), p=2, dim=1)
180 | cur_sim = torch.mm(normed_cur_features, normed_cur_features.T) # [T0, T0]
181 | cur_sim.fill_diagonal_(-100.0)
182 | cur_indices = indices[:T0]
183 | step_indices = [cur_indices]
184 | for i in range(T0, T):
185 | # get new feature
186 | new_feature = img_feature[i]
187 | normed_new_feature = F.normalize(new_feature.view(1, P * D), p=2, dim=1)
188 | new_sim = torch.mm(normed_cur_features, normed_new_feature.T) # [T0, 1]
189 | all_feature = torch.cat([cur_feature, new_feature.unsqueeze(0)], dim=0)
190 | normed_all_features = torch.cat([normed_cur_features, normed_new_feature], dim=0)
191 | all_indices = cur_indices + [[i]]
192 | # get new similarity
193 | all_sim_1 = torch.cat([cur_sim, new_sim], dim=1) # [T0, T0 + 1]
194 | all_sim = torch.cat([all_sim_1, torch.ones_like(all_sim_1[-1:]) * -100.0], dim=0) # [T0 + 1, T0 + 1]
195 | all_sim[-1, :-1] = new_sim.T
196 | # choose compression position
197 | idx = torch.argmax(all_sim)
198 | left, right = idx // (T0 + 1), idx % (T0 + 1)
199 | if random.randint(0, 1) > 0:
200 | idx = left
201 | else:
202 | idx = right
203 | assert all_sim[left, right] == torch.max(all_sim)
204 | # get compressed feature and similarity
205 | cur_feature = torch.cat([all_feature[:idx], all_feature[idx + 1:]])
206 | normed_cur_features = torch.cat([normed_all_features[:idx], normed_all_features[idx + 1:]])
207 | cur_indices = all_indices[:idx] + all_indices[idx + 1:]
208 | cur_sim_1 = torch.cat([all_sim[:idx], all_sim[idx + 1:]], dim=0) # [T0, T0 + 1]
209 | cur_sim = torch.cat([cur_sim_1[:, :idx], cur_sim_1[:, idx + 1:]], dim=1) # [T0, T0]
210 | step_indices.append(cur_indices)
211 | # print(f'Note: perform k-drop feature {img_feature.shape} to {cur_feature.shape}')
212 | return cur_feature, None, step_indices
213 |
214 |
215 | def k_merge_feature(img_feature, video_max_frames, img_similarity=None):
216 | T, P, D = img_feature.shape
217 | indices = [[i] for i in range(T)]
218 | T0 = video_max_frames
219 | if T <= T0:
220 | return img_feature, img_similarity, [indices]
221 | cur_feature = img_feature[:T0] # [T0, P, D]
222 | normed_cur_features = F.normalize(cur_feature.view(T0, P * D), p=2, dim=1)
223 | cur_sim = torch.mm(normed_cur_features, normed_cur_features.T) # [T0, T0]
224 | cur_sim.fill_diagonal_(-100.0)
225 | cur_indices = indices[:T0]
226 | step_indices = [cur_indices]
227 | for i in range(T0, T):
228 | # get new feature
229 | new_feature = img_feature[i]
230 | normed_new_feature = F.normalize(new_feature.view(1, P * D), p=2, dim=1)
231 | new_sim = torch.mm(normed_cur_features, normed_new_feature.T) # [T0, 1]
232 | all_feature = torch.cat([cur_feature, new_feature.unsqueeze(0)], dim=0)
233 | normed_all_features = torch.cat([normed_cur_features, normed_new_feature], dim=0)
234 | all_indices = cur_indices + [[i]]
235 | # get new similarity
236 | all_sim_1 = torch.cat([cur_sim, new_sim], dim=1) # [T0, T0 + 1]
237 | all_sim = torch.cat([all_sim_1, torch.ones_like(all_sim_1[-1:]) * -100.0], dim=0) # [T0 + 1, T0 + 1]
238 | all_sim[-1, :-1] = new_sim.T
239 | # choose compression position
240 | idx = torch.argmax(all_sim)
241 | left, right = idx // (T0 + 1), idx % (T0 + 1)
242 | assert all_sim[left, right] == torch.max(all_sim)
243 | # update feature
244 | all_feature[right] = (all_feature[left] + all_feature[right]) / 2.0
245 | normed_all_features[right] = F.normalize(all_feature[right].view(1, P * D), p=2, dim=1)
246 | all_indices[right] = all_indices[left] + all_indices[right]
247 | # update similarity
248 | new_sim = torch.mm(normed_all_features, normed_all_features[right:right+1].T) # [T0 + 1, 1]
249 | all_sim[right, :] = new_sim.T
250 | all_sim[:, right:right+1] = new_sim
251 | all_sim[right, right] = -100.0
252 | # get compressed feature and similarity
253 | cur_feature = torch.cat([all_feature[:left], all_feature[left + 1:]])
254 | normed_cur_features = torch.cat([normed_all_features[:left], normed_all_features[left + 1:]])
255 | cur_indices = all_indices[:left] + all_indices[left + 1:]
256 | cur_sim_1 = torch.cat([all_sim[:left], all_sim[left + 1:]], dim=0) # [T0, T0 + 1]
257 | cur_sim = torch.cat([cur_sim_1[:, :left], cur_sim_1[:, left + 1:]], dim=1) # [T0, T0]
258 | step_indices.append(cur_indices)
259 | # print(f'Note: perform k-merge feature {img_feature.shape} to {cur_feature.shape}')
260 | return cur_feature, cur_sim, step_indices
261 |
262 |
263 | def attention_feature(img_feature, video_max_frames, attention_fn=None, update_ratio=0.2):
264 | T, P, D = img_feature.shape
265 | T0 = video_max_frames
266 | if T <= T0:
267 | return img_feature, None
268 | cur_feature = img_feature[:T0] # [T0, P, D]
269 | turing_memory = cur_feature.reshape(T0*P, D) # [T0*P, D]
270 | for i in range(T0, T, T0):
271 | j = min(i + T0, T)
272 |         new_feature = img_feature[i:j] # [n, P, D] with n = j - i
273 |         new_feature = new_feature.reshape(-1, D) # [n*P, D]
274 |         turing_memory = attention_fn(turing_memory, new_feature, update_ratio=update_ratio) # updated memory, [T0*P, D]
275 | cur_feature = turing_memory.reshape(T0, P, D)
276 | # print(f'Note: perform {attention_fn.__name__} feature {img_feature.shape} to {cur_feature.shape}')
277 | return cur_feature, None
278 |
--------------------------------------------------------------------------------
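Each reducer above maps a `[T, P, D]` stack of frame features down to at most `video_max_frames` frames and reports which source frames every kept slot covers. A quick check on random tensors with toy sizes (any of the other reducers can be swapped in the same way):

import torch
from flash_vstream.model.compress_functions import merge_feature, weighted_kmeans_feature

feats = torch.randn(32, 16, 256)          # T=32 frames, P=16 patch tokens, D=256 dims (toy sizes)
merged, sim, steps = merge_feature(feats, video_max_frames=8)
clustered, weights, groups = weighted_kmeans_feature(feats, video_max_frames=8)
print(merged.shape, clustered.shape)      # both torch.Size([8, 16, 256])
print(len(groups[0]))                     # 8 clusters, each listing its source frame indices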
/flash_vstream/model/language_model/vstream_llama.py:
--------------------------------------------------------------------------------
1 | # This file may have been modified by Flash-VStream Authors ("Flash-VStream Modifications"). All Flash-VStream Modifications are Copyright 2024 Flash-VStream Authors.
2 | # ------------------------------------------------------------------------
3 | # Based on https://github.com/haotian-liu/LLaVA. Below is the original copyright:
4 | # Copyright 2023 Haotian Liu
5 | #
6 | # Licensed under the Apache License, Version 2.0 (the "License");
7 | # you may not use this file except in compliance with the License.
8 | # You may obtain a copy of the License at
9 | #
10 | # http://www.apache.org/licenses/LICENSE-2.0
11 | #
12 | # Unless required by applicable law or agreed to in writing, software
13 | # distributed under the License is distributed on an "AS IS" BASIS,
14 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15 | # See the License for the specific language governing permissions and
16 | # limitations under the License.
17 |
18 | import torch
19 | import torch.nn as nn
20 | import torch.nn.functional as F
21 | from typing import List, Optional, Tuple, Union
22 | from transformers import AutoConfig, AutoModelForCausalLM, \
23 | LlamaConfig, LlamaModel, LlamaForCausalLM
24 | from transformers.modeling_outputs import CausalLMOutputWithPast
25 | from flash_vstream.model.vstream_arch import VStreamMetaModel, VStreamMetaForCausalLM
26 |
27 |
28 | class VStreamConfig(LlamaConfig):
29 | model_type = "vstream"
30 |
31 |
32 | class VStreamLlamaModel(VStreamMetaModel, LlamaModel):
33 | config_class = VStreamConfig
34 |
35 | def __init__(self, config: LlamaConfig):
36 | super(VStreamLlamaModel, self).__init__(config)
37 |
38 |
39 | class VStreamLlamaForCausalLM(VStreamMetaForCausalLM, LlamaForCausalLM):
40 | config_class = VStreamConfig
41 |
42 | def __init__(self, config):
43 | super(VStreamLlamaForCausalLM, self).__init__(config)
44 | self.model = VStreamLlamaModel(config)
45 | self.pretraining_tp = config.pretraining_tp
46 | self.vocab_size = config.vocab_size
47 | self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
48 |
49 | # Initialize weights and apply final processing
50 | self.post_init()
51 |
52 | def get_model(self):
53 | return self.model
54 |
55 | def forward(
56 | self,
57 | input_ids: torch.LongTensor = None,
58 | attention_mask: Optional[torch.Tensor] = None,
59 | position_ids: Optional[torch.LongTensor] = None,
60 | past_key_values: Optional[List[torch.FloatTensor]] = None,
61 | inputs_embeds: Optional[torch.FloatTensor] = None,
62 | labels: Optional[torch.LongTensor] = None,
63 | use_cache: Optional[bool] = True,
64 | output_attentions: Optional[bool] = None,
65 | output_hidden_states: Optional[bool] = None,
66 | images: Optional[torch.FloatTensor] = None,
67 | features: Optional[torch.FloatTensor] = None,
68 | return_dict: Optional[bool] = None,
69 | ) -> Union[Tuple, CausalLMOutputWithPast]:
70 | if inputs_embeds is None:
71 | if self.use_video_streaming_mode:
72 | (
73 | input_ids,
74 | position_ids,
75 | attention_mask,
76 | past_key_values,
77 | inputs_embeds,
78 | labels
79 | ) = self.prepare_inputs_labels_for_multimodal_streaming(
80 | input_ids,
81 | position_ids,
82 | attention_mask,
83 | past_key_values,
84 | labels,
85 | )
86 | else:
87 | (
88 | input_ids,
89 | position_ids,
90 | attention_mask,
91 | past_key_values,
92 | inputs_embeds,
93 | labels
94 | ) = self.prepare_inputs_labels_for_multimodal(
95 | input_ids,
96 | position_ids,
97 | attention_mask,
98 | past_key_values,
99 | labels,
100 | images,
101 | features,
102 | )
103 | return super().forward(
104 | input_ids=input_ids,
105 | attention_mask=attention_mask,
106 | position_ids=position_ids,
107 | past_key_values=past_key_values,
108 | inputs_embeds=inputs_embeds,
109 | labels=labels,
110 | use_cache=use_cache,
111 | output_attentions=output_attentions,
112 | output_hidden_states=output_hidden_states,
113 | return_dict=return_dict
114 | )
115 |
116 | def prepare_inputs_for_generation(self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs):
117 | images = kwargs.pop("images", None)
118 | features = kwargs.pop("features", None)
119 | _inputs = super().prepare_inputs_for_generation(
120 | input_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, **kwargs
121 | )
122 | if images is not None:
123 | _inputs['images'] = images
124 | if features is not None:
125 | _inputs['features'] = features
126 | return _inputs
127 |
128 | AutoConfig.register("vstream", VStreamConfig)
129 | AutoModelForCausalLM.register(VStreamConfig, VStreamLlamaForCausalLM)
130 |
--------------------------------------------------------------------------------
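Because the module registers `VStreamConfig` and `VStreamLlamaForCausalLM` with the Auto classes at import time, a checkpoint whose config declares `model_type: "vstream"` resolves through the standard `transformers` Auto API once this package has been imported. A sketch with a hypothetical checkpoint path:

from transformers import AutoConfig, AutoModelForCausalLM
import flash_vstream.model.language_model.vstream_llama  # noqa: F401  (runs the register() calls above)

config = AutoConfig.from_pretrained("./checkpoints/flash-vstream-7b")  # hypothetical path
print(config.model_type)                                               # "vstream"
model = AutoModelForCausalLM.from_config(config)                       # builds a VStreamLlamaForCausalLM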
/flash_vstream/model/multimodal_encoder/builder.py:
--------------------------------------------------------------------------------
1 | # Based on https://github.com/haotian-liu/LLaVA.
2 |
3 | import os
4 | from .clip_encoder import CLIPVisionTower
5 |
6 |
7 | def build_vision_tower(vision_tower_cfg, **kwargs):
8 | vision_tower = getattr(vision_tower_cfg, 'mm_vision_tower', getattr(vision_tower_cfg, 'vision_tower', None))
9 | is_absolute_path_exists = os.path.exists(vision_tower)
10 | if is_absolute_path_exists or vision_tower.startswith("openai") or vision_tower.startswith("laion"):
11 | return CLIPVisionTower(vision_tower, args=vision_tower_cfg, **kwargs)
12 |
13 | raise ValueError(f'Unknown vision tower: {vision_tower}')
14 |
--------------------------------------------------------------------------------
/flash_vstream/model/multimodal_encoder/clip_encoder.py:
--------------------------------------------------------------------------------
1 | # Based on https://github.com/haotian-liu/LLaVA.
2 |
3 | import torch
4 | import torch.nn as nn
5 |
6 | from transformers import CLIPVisionModel, CLIPImageProcessor, CLIPVisionConfig
7 |
8 |
9 | class CLIPVisionTower(nn.Module):
10 | def __init__(self, vision_tower, args, delay_load=False):
11 | super().__init__()
12 |
13 | self.is_loaded = False
14 |
15 | self.vision_tower_name = vision_tower
16 | self.select_layer = args.mm_vision_select_layer
17 | self.select_feature = getattr(args, 'mm_vision_select_feature', 'patch')
18 |
19 | if not delay_load:
20 | self.load_model()
21 | else:
22 | self.cfg_only = CLIPVisionConfig.from_pretrained(self.vision_tower_name)
23 |
24 | def load_model(self):
25 | self.image_processor = CLIPImageProcessor.from_pretrained(self.vision_tower_name)
26 | self.vision_tower = CLIPVisionModel.from_pretrained(self.vision_tower_name)
27 | self.vision_tower.requires_grad_(False)
28 |
29 | self.is_loaded = True
30 |
31 | def feature_select(self, image_forward_outs):
32 | image_features = image_forward_outs.hidden_states[self.select_layer]
33 | if self.select_feature == 'patch':
34 | image_features = image_features[:, 1:]
35 | elif self.select_feature == 'cls_patch':
36 | image_features = image_features
37 | else:
38 | raise ValueError(f'Unexpected select feature: {self.select_feature}')
39 | return image_features
40 |
41 | @torch.no_grad()
42 | def forward(self, images):
43 | if type(images) is list:
44 | image_features = []
45 | for image in images:
46 | image_forward_out = self.vision_tower(image.to(device=self.device, dtype=self.dtype).unsqueeze(0), output_hidden_states=True)
47 | image_feature = self.feature_select(image_forward_out).to(image.dtype)
48 | image_features.append(image_feature)
49 | else:
50 | image_forward_outs = self.vision_tower(images.to(device=self.device, dtype=self.dtype), output_hidden_states=True)
51 | image_features = self.feature_select(image_forward_outs).to(images.dtype)
52 |
53 | return image_features
54 |
55 | @property
56 | def dummy_feature(self):
57 | return torch.zeros(1, self.hidden_size, device=self.device, dtype=self.dtype)
58 |
59 | @property
60 | def dtype(self):
61 | return self.vision_tower.dtype
62 |
63 | @property
64 | def device(self):
65 | return self.vision_tower.device
66 |
67 | @property
68 | def config(self):
69 | if self.is_loaded:
70 | return self.vision_tower.config
71 | else:
72 | return self.cfg_only
73 |
74 | @property
75 | def hidden_size(self):
76 | return self.config.hidden_size
77 |
78 | @property
79 | def num_patches(self):
80 | return (self.config.image_size // self.config.patch_size) ** 2
81 |
--------------------------------------------------------------------------------
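A small sketch of building and running the frozen CLIP tower through `build_vision_tower`; the config object only needs the attributes read above, and the output shape follows from the ViT-L/14 patch grid (224 / 14 = 16, i.e. 256 patch tokens of width 1024). The `openai/clip-vit-large-patch14` weights are downloaded on first use:

import torch
from types import SimpleNamespace
from flash_vstream.model.multimodal_encoder.builder import build_vision_tower

cfg = SimpleNamespace(mm_vision_tower="openai/clip-vit-large-patch14",
                      mm_vision_select_layer=-2)   # penultimate layer, a common LLaVA-style choice
tower = build_vision_tower(cfg)

feats = tower(torch.randn(2, 3, 224, 224))         # -> [2, 256, 1024] patch features
print(feats.shape)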
/flash_vstream/model/multimodal_projector/builder.py:
--------------------------------------------------------------------------------
1 | # Based on https://github.com/haotian-liu/LLaVA.
2 |
3 | import torch
4 | import torch.nn as nn
5 | import re
6 |
7 |
8 | class IdentityMap(nn.Module):
9 | def __init__(self):
10 | super().__init__()
11 |
12 | def forward(self, x, *args, **kwargs):
13 | return x
14 |
15 | @property
16 | def config(self):
17 | return {"mm_projector_type": 'identity'}
18 |
19 |
20 | class SimpleResBlock(nn.Module):
21 | def __init__(self, channels):
22 | super().__init__()
23 | self.pre_norm = nn.LayerNorm(channels)
24 |
25 | self.proj = nn.Sequential(
26 | nn.Linear(channels, channels),
27 | nn.GELU(),
28 | nn.Linear(channels, channels)
29 | )
30 | def forward(self, x):
31 | x = self.pre_norm(x)
32 | return x + self.proj(x)
33 |
34 |
35 | def build_vision_projector(config, input_dim, delay_load=False, **kwargs):
36 | projector_type = getattr(config, 'mm_projector_type', 'linear')
37 |
38 | if projector_type == 'linear':
39 | return nn.Linear(input_dim, config.hidden_size)
40 | mlp_gelu_match = re.match(r'^mlp(\d+)x_gelu$', projector_type)
41 | if mlp_gelu_match:
42 | mlp_depth = int(mlp_gelu_match.group(1))
43 | modules = [nn.Linear(input_dim, config.hidden_size)]
44 | for _ in range(1, mlp_depth):
45 | modules.append(nn.GELU())
46 | modules.append(nn.Linear(config.hidden_size, config.hidden_size))
47 | return nn.Sequential(*modules)
48 | if projector_type == 'identity':
49 | return IdentityMap()
50 |
51 | raise ValueError(f'Unknown projector type: {projector_type}')
52 |
--------------------------------------------------------------------------------
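A sketch of `build_vision_projector` with a toy config: `mlp2x_gelu` yields Linear -> GELU -> Linear, mapping vision features (e.g. 1024-dimensional CLIP patches) into the LLM hidden size:

import torch
from types import SimpleNamespace
from flash_vstream.model.multimodal_projector.builder import build_vision_projector

cfg = SimpleNamespace(mm_projector_type="mlp2x_gelu", hidden_size=4096)  # toy config values
projector = build_vision_projector(cfg, input_dim=1024)

out = projector(torch.randn(1, 256, 1024))   # -> [1, 256, 4096]
print(out.shape)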
/flash_vstream/serve/cli_video_stream.py:
--------------------------------------------------------------------------------
1 | # This file may have been modified by Flash-VStream Authors ("Flash-VStream Modifications"). All Flash-VStream Modifications are Copyright 2024 Flash-VStream Authors.
2 | # Based on https://github.com/haotian-liu/LLaVA.
3 | """
4 | This file demonstrates an implementation of a multiprocess Real-time Long Video Understanding System, with a multiprocess logging module.
5 | main process: CLI server I/O, LLM inference
6 | process-1: logger listener
7 | process-2: frame generator
8 | process-3: frame memory manager
9 | Authors: Haoji Zhang, Haotian Liu
10 | (This code is based on https://github.com/haotian-liu/LLaVA)
11 | """
12 | import argparse
13 | import requests
14 | import logging
15 | import torch
16 | import numpy as np
17 | import time
18 | import os
19 |
20 | from flash_vstream.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
21 | from flash_vstream.conversation import conv_templates, SeparatorStyle
22 | from flash_vstream.model.builder import load_pretrained_model
23 | from flash_vstream.utils import disable_torch_init
24 | from flash_vstream.mm_utils import process_images, tokenizer_image_token, get_model_name_from_path, KeywordsStoppingCriteria
25 |
26 | from torch.multiprocessing import Process, Queue, Manager
27 | from transformers import TextStreamer
28 | from decord import VideoReader
29 | from datetime import datetime
30 | from PIL import Image
31 | from io import BytesIO
32 |
33 | class _Metric:
34 | def __init__(self):
35 | self._latest_value = None
36 | self._sum = 0.0
37 | self._max = 0.0
38 | self._count = 0
39 |
40 | @property
41 | def val(self):
42 | return self._latest_value
43 |
44 | @property
45 | def max(self):
46 | return self._max
47 |
48 | @property
49 | def avg(self):
50 | if self._count == 0:
51 | return float('nan')
52 | return self._sum / self._count
53 |
54 | def add(self, value):
55 | self._latest_value = value
56 | self._sum += value
57 | self._count += 1
58 | if value > self._max:
59 | self._max = value
60 |
61 | def __str__(self):
62 | latest_formatted = f"{self.val:.6f}" if self.val is not None else "None"
63 | average_formatted = f"{self.avg:.6f}"
64 | max_formatted = f"{self.max:.6f}"
65 | return f"{latest_formatted} ({average_formatted}, {max_formatted})"
66 |
67 |
68 | class MetricMeter:
69 | def __init__(self):
70 | self._metrics = {}
71 |
72 | def add(self, key, value):
73 | if key not in self._metrics:
74 | self._metrics[key] = _Metric()
75 | self._metrics[key].add(value)
76 |
77 | def val(self, key):
78 | metric = self._metrics.get(key)
79 | if metric is None or metric.val is None:
80 | raise ValueError(f"No values have been added for key '{key}'.")
81 | return metric.val
82 |
83 | def avg(self, key):
84 | metric = self._metrics.get(key)
85 | if metric is None:
86 | raise ValueError(f"No values have been added for key '{key}'.")
87 | return metric.avg
88 |
89 | def max(self, key):
90 | metric = self._metrics.get(key)
91 | if metric is None:
92 | raise ValueError(f"No values have been added for key '{key}'.")
93 | return metric.max
94 |
95 | def __getitem__(self, key):
96 | metric = self._metrics.get(key)
97 | if metric is None:
98 | raise KeyError(f"The key '{key}' does not exist.")
99 | return str(metric)
100 |
101 | def load_image(image_file):
102 | if image_file.startswith('http://') or image_file.startswith('https://'):
103 | response = requests.get(image_file)
104 | image = Image.open(BytesIO(response.content)).convert('RGB')
105 | else:
106 | image = Image.open(image_file).convert('RGB')
107 | return image
108 |
109 | def listener(queue, filename):
110 | ############## Start sub process-1: Listener #############
111 | import sys, traceback
112 | root = logging.getLogger()
113 | root.setLevel(logging.DEBUG)
114 | # h = logging.StreamHandler(sys.stdout)
115 | h = logging.FileHandler(filename)
116 | f = logging.Formatter('%(asctime)s %(processName)-10s %(name)s %(levelname)-8s %(message)s')
117 | h.setFormatter(f)
118 | root.addHandler(h)
119 | while True:
120 | try:
121 | record = queue.get()
122 | if record is None: # None is a signal to finish
123 | break
124 | logger = logging.getLogger(record.name)
125 | logger.handle(record) # No level or filter logic applied - just do it!
126 | except Exception:
127 | import sys, traceback
128 | print('Whoops! Problem:', file=sys.stderr)
129 | traceback.print_exc(file=sys.stderr)
130 |
131 | def worker_configurer(queue):
132 | h = logging.handlers.QueueHandler(queue) # Just the one handler needed
133 | root = logging.getLogger()
134 | root.addHandler(h)
135 | root.setLevel(logging.DEBUG)
136 |
137 | def video_stream_simulator(video_file, frame_queue, log_queue, video_fps=1.0, play_speed=1.0):
138 | ############## Start sub process-2: Simulator #############
139 | worker_configurer(log_queue)
140 | logger = logging.getLogger(__name__)
141 | logger.setLevel(logging.DEBUG)
142 |
143 | vr = VideoReader(video_file)
144 | sample_fps = round(vr.get_avg_fps() / video_fps)
145 | frame_idx = [i for i in range(0, len(vr), sample_fps)]
146 | video = vr.get_batch(frame_idx).asnumpy()
147 |     video = np.repeat(video, 6, axis=0)  # repeat each sampled frame 6 times to lengthen the simulated stream
148 | length = video.shape[0]
149 | sleep_time = 1 / video_fps / play_speed
150 | time_meter = MetricMeter()
151 | logger.info(f'Simulator Process: start, length = {length}')
152 | try:
153 | for start in range(0, length):
154 | start_time = time.perf_counter()
155 | end = min(start + 1, length)
156 | video_clip = video[start:end]
157 | frame_queue.put(video_clip)
158 | if start > 0:
159 | time_meter.add('real_sleep', start_time - last_start)
160 | logger.info(f'Simulator: write {end - start} frames,\t{start} to {end},\treal_sleep={time_meter["real_sleep"]}')
161 | if end < length:
162 | time.sleep(sleep_time)
163 | last_start = start_time
164 | frame_queue.put(None)
165 | except Exception as e:
166 | print(f'Simulator Exception: {e}')
167 | time.sleep(0.1)
168 | logger.info(f'Simulator Process: end')
169 |
170 | def frame_memory_manager(model, image_processor, frame_queue, log_queue):
171 | ############## Start sub process-3: Memory Manager #############
172 | worker_configurer(log_queue)
173 | logger = logging.getLogger(__name__)
174 | logger.setLevel(logging.DEBUG)
175 |
176 | time_meter = MetricMeter()
177 | logger.info(f'MemManager Process: start')
178 | frame_cnt = 0
179 | while True:
180 | try:
181 | video_clip = frame_queue.get()
182 | start_time = time.perf_counter()
183 | if video_clip is None:
184 |                 logger.info('MemManager: received None sentinel, stopping')
185 | break
186 | logger.info(f'MemManager: get {video_clip.shape[0]} frames from queue')
187 | image = image_processor.preprocess(video_clip, return_tensors='pt')['pixel_values']
188 | image = image.unsqueeze(0)
189 | image_tensor = image.to(model.device, dtype=torch.float16)
190 | # time_2 = time.perf_counter()
191 | logger.info(f'MemManager: Start embedding')
192 | with torch.inference_mode():
193 | model.embed_video_streaming(image_tensor)
194 | logger.info(f'MemManager: End embedding')
195 | end_time = time.perf_counter()
196 | if frame_cnt > 0:
197 | time_meter.add('memory_latency', end_time - start_time)
198 | logger.info(f'MemManager: embedded {video_clip.shape[0]} frames,\tidx={frame_cnt},\tmemory_latency={time_meter["memory_latency"]}')
199 | else:
200 | logger.info(f'MemManager: embedded {video_clip.shape[0]} frames,\tidx={frame_cnt},\tmemory_latency={end_time - start_time:.6f}, not logged')
201 | frame_cnt += video_clip.shape[0]
202 | except Exception as e:
203 | print(f'MemManager Exception: {e}')
204 | time.sleep(0.1)
205 | logger.info(f'MemManager Process: end')
206 |
207 | def main(args):
208 | # torch.multiprocessing.log_to_stderr(logging.DEBUG)
209 | torch.multiprocessing.set_start_method('spawn', force=True)
210 | disable_torch_init()
211 |
212 | log_queue = Queue()
213 | frame_queue = Queue(maxsize=10)
214 | processes = []
215 |
216 | ############## Start listener process #############
217 | p1 = Process(target=listener, args=(log_queue, args.log_file))
218 | processes.append(p1)
219 | p1.start()
220 |
221 | ############## Start main process #############
222 | worker_configurer(log_queue)
223 | logger = logging.getLogger(__name__)
224 |
225 | model_name = get_model_name_from_path(args.model_path)
226 | tokenizer, model, image_processor, context_len = load_pretrained_model(args.model_path, args.model_base, model_name, args.load_8bit, args.load_4bit, device=args.device)
227 |
228 | logger.info(f'Using conv_mode={args.conv_mode}')
229 |
230 | conv = conv_templates[args.conv_mode].copy()
231 | if "mpt" in model_name.lower():
232 | roles = ('user', 'assistant')
233 | else:
234 | roles = conv.roles
235 |
236 | with Manager() as manager:
237 | image_tensor = None
238 | model.use_video_streaming_mode = True
239 | model.video_embedding_memory = manager.list()
240 | if args.video_max_frames is not None:
241 | model.config.video_max_frames = args.video_max_frames
242 | logger.info(f'Important: set model.config.video_max_frames = {model.config.video_max_frames}')
243 |
244 | logger.info(f'Important: set video_fps = {args.video_fps}')
245 | logger.info(f'Important: set play_speed = {args.play_speed}')
246 |
247 | ############## Start simulator process #############
248 |         p2 = Process(target=video_stream_simulator,
249 | args=(args.video_file, frame_queue, log_queue, args.video_fps, args.play_speed))
250 | processes.append(p2)
251 | p2.start()
252 |
253 | ############## Start memory manager process #############
254 | p3 = Process(target=frame_memory_manager,
255 | args=(model, image_processor, frame_queue, log_queue))
256 | processes.append(p3)
257 | p3.start()
258 |
259 | # start QA server
260 | start_time = datetime.now()
261 | time_meter = MetricMeter()
262 | conv_cnt = 0
263 | while True:
264 | time.sleep(5)
265 | try:
266 | # inp = input(f"{roles[0]}: ")
267 | inp = "what is in the video?"
268 | except EOFError:
269 | inp = ""
270 | if not inp:
271 | print("exit...")
272 | break
273 |
274 |             # Get the current time
275 | now = datetime.now()
276 | conv_start_time = time.perf_counter()
277 |             # Format the current time as a string
278 | current_time = now.strftime("%H:%M:%S")
279 | duration = now.timestamp() - start_time.timestamp()
280 |
281 |             # Print the current time
282 | print("\nCurrent Time:", current_time, "Run for:", duration)
283 | print(f"{roles[0]}: {inp}", end="\n")
284 | print(f"{roles[1]}: ", end="")
285 | # every conversation is a new conversation
286 | conv = conv_templates[args.conv_mode].copy()
287 | inp = DEFAULT_IMAGE_TOKEN + '\n' + inp
288 | conv.append_message(conv.roles[0], inp)
289 |
290 | conv.append_message(conv.roles[1], None)
291 | prompt = conv.get_prompt()
292 |
293 | input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).to(model.device)
294 | stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2
295 | keywords = [stop_str]
296 | stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids)
297 | streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
298 |
299 | llm_start_time = time.perf_counter()
300 | with torch.inference_mode():
301 | output_ids = model.generate(
302 | input_ids,
303 | images=image_tensor,
304 | do_sample=True if args.temperature > 0 else False,
305 | temperature=args.temperature,
306 | max_new_tokens=args.max_new_tokens,
307 | streamer=streamer,
308 | use_cache=True,
309 | stopping_criteria=[stopping_criteria]
310 | )
311 | llm_end_time = time.perf_counter()
312 |
313 | outputs = tokenizer.decode(output_ids[0, input_ids.shape[1]:]).strip()
314 | conv.messages[-1][-1] = outputs
315 | conv_end_time = time.perf_counter()
316 | if conv_cnt > 0:
317 | time_meter.add('conv_latency', conv_end_time - conv_start_time)
318 | time_meter.add('llm_latency', llm_end_time - llm_start_time)
319 | time_meter.add('real_sleep', conv_start_time - last_conv_start_time)
320 | logger.info(f'CliServer: idx={conv_cnt},\treal_sleep={time_meter["real_sleep"]},\tconv_latency={time_meter["conv_latency"]},\tllm_latency={time_meter["llm_latency"]}')
321 | else:
322 | logger.info(f'CliServer: idx={conv_cnt},\tconv_latency={conv_end_time - conv_start_time},\tllm_latency={llm_end_time - llm_start_time}')
323 | conv_cnt += 1
324 | last_conv_start_time = conv_start_time
325 |
326 | for p in processes:
327 | p.terminate()
328 | print("All processes finished.")
329 |
330 |
331 | if __name__ == "__main__":
332 | parser = argparse.ArgumentParser()
333 | parser.add_argument("--model-path", type=str, default="facebook/opt-350m")
334 | parser.add_argument("--model-base", type=str, default=None)
335 | parser.add_argument("--image-file", type=str, default=None)
336 | parser.add_argument("--video-file", type=str, default=None)
337 | parser.add_argument("--device", type=str, default="cuda")
338 | parser.add_argument("--conv-mode", type=str, default="vicuna_v1")
339 | parser.add_argument("--temperature", type=float, default=0.2)
340 | parser.add_argument("--max-new-tokens", type=int, default=512)
341 | parser.add_argument("--load-8bit", action="store_true")
342 | parser.add_argument("--load-4bit", action="store_true")
343 | parser.add_argument("--debug", action="store_true")
344 |
345 | parser.add_argument("--log-file", type=str, default="tmp_cli.log")
346 | parser.add_argument("--use_1process", action="store_true")
347 | parser.add_argument("--video_max_frames", type=int, default=None)
348 | parser.add_argument("--video_fps", type=float, default=1.0)
349 | parser.add_argument("--play_speed", type=float, default=1.0)
350 | args = parser.parse_args()
351 | main(args)
352 |
--------------------------------------------------------------------------------
/flash_vstream/train/llama_flash_attn_monkey_patch.py:
--------------------------------------------------------------------------------
1 | # Based on https://github.com/haotian-liu/LLaVA.
2 |
3 | from typing import Optional, Tuple
4 | import warnings
5 |
6 | import torch
7 |
8 | import transformers
9 | from transformers.models.llama.modeling_llama import apply_rotary_pos_emb, repeat_kv
10 |
11 | try:
12 | from flash_attn.flash_attn_interface import flash_attn_unpadded_qkvpacked_func
13 | except ImportError:
14 | from flash_attn.flash_attn_interface import flash_attn_varlen_qkvpacked_func as flash_attn_unpadded_qkvpacked_func
15 | from flash_attn.bert_padding import unpad_input, pad_input
16 |
17 |
18 | def forward(
19 | self,
20 | hidden_states: torch.Tensor,
21 | attention_mask: Optional[torch.Tensor] = None,
22 | position_ids: Optional[torch.Tensor] = None,
23 | past_key_value: Optional[Tuple[torch.Tensor]] = None,
24 | output_attentions: bool = False,
25 | use_cache: bool = False,
26 | ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
27 | if output_attentions:
28 | warnings.warn(
29 | "Output attentions is not supported for patched `LlamaAttention`, returning `None` instead."
30 | )
31 |
32 | bsz, q_len, _ = hidden_states.size()
33 |
34 | query_states = (
35 | self.q_proj(hidden_states)
36 | .view(bsz, q_len, self.num_heads, self.head_dim)
37 | .transpose(1, 2)
38 | )
39 | key_states = (
40 | self.k_proj(hidden_states)
41 | .view(bsz, q_len, self.num_key_value_heads, self.head_dim)
42 | .transpose(1, 2)
43 | )
44 | value_states = (
45 | self.v_proj(hidden_states)
46 | .view(bsz, q_len, self.num_key_value_heads, self.head_dim)
47 | .transpose(1, 2)
48 | ) # shape: (b, num_heads, s, head_dim)
49 |
50 | kv_seq_len = key_states.shape[-2]
51 | if past_key_value is not None:
52 | kv_seq_len += past_key_value[0].shape[-2]
53 |
54 | cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
55 | query_states, key_states = apply_rotary_pos_emb(
56 | query_states, key_states, cos, sin, position_ids
57 | )
58 |
59 | if past_key_value is not None:
60 | # reuse k, v
61 | key_states = torch.cat([past_key_value[0], key_states], dim=2)
62 | value_states = torch.cat([past_key_value[1], value_states], dim=2)
63 |
64 | past_key_value = (key_states, value_states) if use_cache else None
65 |
66 | # repeat k/v heads if n_kv_heads < n_heads
67 | key_states = repeat_kv(key_states, self.num_key_value_groups)
68 | value_states = repeat_kv(value_states, self.num_key_value_groups)
69 |
70 | # Transform the data into the format required by flash attention
71 | qkv = torch.stack([query_states, key_states, value_states], dim=2)
72 | qkv = qkv.transpose(1, 3) # shape: [b, s, 3, num_heads, head_dim]
73 | key_padding_mask = attention_mask
74 |
75 | if key_padding_mask is None:
76 | qkv = qkv.reshape(-1, 3, self.num_heads, self.head_dim)
77 | cu_q_lens = torch.arange(
78 | 0, (bsz + 1) * q_len, step=q_len, dtype=torch.int32, device=qkv.device
79 | )
80 | max_s = q_len
81 | output = flash_attn_unpadded_qkvpacked_func(
82 | qkv, cu_q_lens, max_s, 0.0, softmax_scale=None, causal=True
83 | )
84 | output = output.view(bsz, q_len, -1)
85 | else:
86 | qkv = qkv.reshape(bsz, q_len, -1)
87 | qkv, indices, cu_q_lens, max_s = unpad_input(qkv, key_padding_mask)
88 | qkv = qkv.view(-1, 3, self.num_heads, self.head_dim)
89 | output_unpad = flash_attn_unpadded_qkvpacked_func(
90 | qkv, cu_q_lens, max_s, 0.0, softmax_scale=None, causal=True
91 | )
92 | output_unpad = output_unpad.reshape(-1, self.num_heads * self.head_dim)
93 | output = pad_input(output_unpad, indices, bsz, q_len)
94 |
95 | return self.o_proj(output), None, past_key_value
96 |
97 |
98 | # Disable the transformation of the attention mask in LlamaModel as the flash attention
99 | # requires the attention mask to be the same as the key_padding_mask
100 | def _prepare_decoder_attention_mask(
101 | self, attention_mask, input_shape, inputs_embeds, past_key_values_length
102 | ):
103 | # [bsz, seq_len]
104 | return attention_mask
105 |
106 |
107 | def replace_llama_attn_with_flash_attn():
108 | cuda_major, cuda_minor = torch.cuda.get_device_capability()
109 | if cuda_major < 8:
110 | warnings.warn(
111 | "Flash attention is only supported on A100 or H100 GPU during training due to head dim > 64 backward."
112 | "ref: https://github.com/HazyResearch/flash-attention/issues/190#issuecomment-1523359593"
113 | )
114 | transformers.models.llama.modeling_llama.LlamaModel._prepare_decoder_attention_mask = (
115 | _prepare_decoder_attention_mask
116 | )
117 | transformers.models.llama.modeling_llama.LlamaAttention.forward = forward
118 |
--------------------------------------------------------------------------------
/flash_vstream/train/llama_xformers_attn_monkey_patch.py:
--------------------------------------------------------------------------------
1 | # Based on https://github.com/haotian-liu/LLaVA.
2 |
3 | """
4 | Directly copied the code from https://raw.githubusercontent.com/oobabooga/text-generation-webui/main/modules/llama_attn_hijack.py and made some adjustments
5 | """
6 |
7 | import logging
8 | import math
9 | from typing import Optional, Tuple
10 |
11 | import torch
12 | import transformers.models.llama.modeling_llama
13 | from torch import nn
14 |
15 | try:
16 | import xformers.ops
17 | except ImportError:
18 | logging.error("xformers not found! Please install it before trying to use it.")
19 |
20 |
21 | def replace_llama_attn_with_xformers_attn():
22 | transformers.models.llama.modeling_llama.LlamaAttention.forward = xformers_forward
23 |
24 |
25 | def xformers_forward(
26 | self,
27 | hidden_states: torch.Tensor,
28 | attention_mask: Optional[torch.Tensor] = None,
29 | position_ids: Optional[torch.LongTensor] = None,
30 | past_key_value: Optional[Tuple[torch.Tensor]] = None,
31 | output_attentions: bool = False,
32 | use_cache: bool = False,
33 | ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
34 | # pylint: disable=duplicate-code
35 | bsz, q_len, _ = hidden_states.size()
36 |
37 | query_states = (
38 | self.q_proj(hidden_states)
39 | .view(bsz, q_len, self.num_heads, self.head_dim)
40 | .transpose(1, 2)
41 | )
42 | key_states = (
43 | self.k_proj(hidden_states)
44 | .view(bsz, q_len, self.num_heads, self.head_dim)
45 | .transpose(1, 2)
46 | )
47 | value_states = (
48 | self.v_proj(hidden_states)
49 | .view(bsz, q_len, self.num_heads, self.head_dim)
50 | .transpose(1, 2)
51 | )
52 |
53 | kv_seq_len = key_states.shape[-2]
54 | if past_key_value is not None:
55 | kv_seq_len += past_key_value[0].shape[-2]
56 | cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
57 | (
58 | query_states,
59 | key_states,
60 | ) = transformers.models.llama.modeling_llama.apply_rotary_pos_emb(
61 | query_states, key_states, cos, sin, position_ids
62 | )
63 | # [bsz, nh, t, hd]
64 |
65 | if past_key_value is not None:
66 | # reuse k, v, self_attention
67 | key_states = torch.cat([past_key_value[0], key_states], dim=2)
68 | value_states = torch.cat([past_key_value[1], value_states], dim=2)
69 |
70 | past_key_value = (key_states, value_states) if use_cache else None
71 |
72 | # We only apply xformers optimizations if we don't need to output the whole attention matrix
73 | if not output_attentions:
74 | query_states = query_states.transpose(1, 2)
75 | key_states = key_states.transpose(1, 2)
76 | value_states = value_states.transpose(1, 2)
77 |
78 | # This is a nasty hack. We know attention_mask in transformers is either LowerTriangular or all Zeros.
79 | # We therefore check if one element in the upper triangular portion is zero. If it is, then the mask is all zeros.
80 | if attention_mask is None or attention_mask[0, 0, 0, 1] == 0:
81 | # input and output should be of form (bsz, q_len, num_heads, head_dim)
82 | attn_output = xformers.ops.memory_efficient_attention(
83 | query_states, key_states, value_states, attn_bias=None
84 | )
85 | else:
86 | # input and output should be of form (bsz, q_len, num_heads, head_dim)
87 | attn_output = xformers.ops.memory_efficient_attention(
88 | query_states,
89 | key_states,
90 | value_states,
91 | attn_bias=xformers.ops.LowerTriangularMask(),
92 | )
93 | attn_weights = None
94 | else:
95 | attn_weights = torch.matmul(
96 | query_states, key_states.transpose(2, 3)
97 | ) / math.sqrt(self.head_dim)
98 |
99 | if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
100 | raise ValueError(
101 | f"Attention weights should be of size {(bsz * self.num_heads, q_len, kv_seq_len)}, but is"
102 | f" {attn_weights.size()}"
103 | )
104 |
105 | if attention_mask is not None:
106 | if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
107 | raise ValueError(
108 | f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
109 | )
110 | attn_weights = attn_weights + attention_mask
111 | attn_weights = torch.max(
112 | attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min)
113 | )
114 |
115 | # upcast attention to fp32
116 | attn_weights = nn.functional.softmax(
117 | attn_weights, dim=-1, dtype=torch.float32
118 | ).to(query_states.dtype)
119 | attn_output = torch.matmul(attn_weights, value_states)
120 |
121 | if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
122 | raise ValueError(
123 | f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
124 | f" {attn_output.size()}"
125 | )
126 |
127 | attn_output = attn_output.transpose(1, 2)
128 |
129 | attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
130 | attn_output = self.o_proj(attn_output)
131 | return attn_output, attn_weights, past_key_value
132 |
--------------------------------------------------------------------------------
/flash_vstream/train/train_mem.py:
--------------------------------------------------------------------------------
1 | # Adopted from https://github.com/haotian-liu/LLaVA.
2 | # Adopted from https://github.com/lm-sys/FastChat. Below is the original copyright:
3 | # Adopted from tatsu-lab@stanford_alpaca. Below is the original copyright:
4 | # Make it more memory efficient by monkey patching the LLaMA model with FlashAttn.
5 |
6 | # Need to call this before importing transformers.
7 | from flash_vstream.train.llama_flash_attn_monkey_patch import replace_llama_attn_with_flash_attn
8 |
9 | replace_llama_attn_with_flash_attn()
10 |
11 | from flash_vstream.train.train import train
12 |
13 | if __name__ == "__main__":
14 | train()
15 |
--------------------------------------------------------------------------------
/flash_vstream/train/train_xformers.py:
--------------------------------------------------------------------------------
1 | # Based on https://github.com/haotian-liu/LLaVA.
2 |
3 | # Make it more memory efficient by monkey patching the LLaMA model with xformers attention.
4 |
5 | # Need to call this before importing transformers.
6 | from flash_vstream.train.llama_xformers_attn_monkey_patch import (
7 | replace_llama_attn_with_xformers_attn,
8 | )
9 |
10 | replace_llama_attn_with_xformers_attn()
11 |
12 | from flash_vstream.train.train import train
13 |
14 | if __name__ == "__main__":
15 | train()
16 |
--------------------------------------------------------------------------------
/flash_vstream/train/vstream_trainer.py:
--------------------------------------------------------------------------------
1 | # This file may have been modified by Flash-VStream Authors ("Flash-VStream Modifications"). All Flash-VStream Modifications are Copyright 2024 Flash-VStream Authors.
2 | # ------------------------------------------------------------------------
3 | # Based on https://github.com/haotian-liu/LLaVA. Below is the original copyright:
4 | # Copyright 2023 Haotian Liu
5 | #
6 | # Licensed under the Apache License, Version 2.0 (the "License");
7 | # you may not use this file except in compliance with the License.
8 | # You may obtain a copy of the License at
9 | #
10 | # http://www.apache.org/licenses/LICENSE-2.0
11 | #
12 | # Unless required by applicable law or agreed to in writing, software
13 | # distributed under the License is distributed on an "AS IS" BASIS,
14 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15 | # See the License for the specific language governing permissions and
16 | # limitations under the License.
17 |
18 | import os
19 | import torch
20 | import torch.nn as nn
21 |
22 | from torch.utils.data import Sampler
23 |
24 | from transformers import Trainer
25 | from transformers.trainer import (
26 | is_sagemaker_mp_enabled,
27 | get_parameter_names,
28 | has_length,
29 | ALL_LAYERNORM_LAYERS,
30 | ShardedDDPOption,
31 | logger,
32 | )
33 | from typing import List, Optional
34 |
35 |
36 | def maybe_zero_3(param, ignore_status=False, name=None):
37 | from deepspeed import zero
38 | from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus
39 | if hasattr(param, "ds_id"):
40 | if param.ds_status == ZeroParamStatus.NOT_AVAILABLE:
41 | if not ignore_status:
42 | print(name, 'no ignore status')
43 | with zero.GatheredParameters([param]):
44 | param = param.data.detach().cpu().clone()
45 | else:
46 | param = param.detach().cpu().clone()
47 | return param
48 |
49 |
50 | def get_mm_adapter_state_maybe_zero_3(named_params, keys_to_match):
51 | to_return = {k: t for k, t in named_params if any(key_match in k for key_match in keys_to_match)}
52 | to_return = {k: maybe_zero_3(v, ignore_status=True, name=k).cpu() for k, v in to_return.items()}
53 | return to_return
54 |
55 |
56 | def split_to_even_chunks(indices, lengths, num_chunks):
57 | """
58 |     Split a list of indices into `num_chunks` chunks of roughly equal total length.
59 | """
60 |
61 | if len(indices) % num_chunks != 0:
62 | return [indices[i::num_chunks] for i in range(num_chunks)]
63 |
64 | num_indices_per_chunk = len(indices) // num_chunks
65 |
66 | chunks = [[] for _ in range(num_chunks)]
67 | chunks_lengths = [0 for _ in range(num_chunks)]
68 | for index in indices:
69 | shortest_chunk = chunks_lengths.index(min(chunks_lengths))
70 | chunks[shortest_chunk].append(index)
71 | chunks_lengths[shortest_chunk] += lengths[index]
72 | if len(chunks[shortest_chunk]) == num_indices_per_chunk:
73 | chunks_lengths[shortest_chunk] = float("inf")
74 |
75 | return chunks
76 |
77 |
78 | def get_modality_length_grouped_indices(lengths, batch_size, world_size, generator=None):
79 | # We need to use torch for the random part as a distributed sampler will set the random seed for torch.
80 | assert all(l != 0 for l in lengths), "Should not have zero length."
81 | if all(l > 0 for l in lengths) or all(l < 0 for l in lengths):
82 | # all samples are in the same modality
83 | return get_length_grouped_indices(lengths, batch_size, world_size, generator=generator)
84 | mm_indices, mm_lengths = zip(*[(i, l) for i, l in enumerate(lengths) if l > 0])
85 | lang_indices, lang_lengths = zip(*[(i, -l) for i, l in enumerate(lengths) if l < 0])
86 |
87 | mm_shuffle = [mm_indices[i] for i in get_length_grouped_indices(mm_lengths, batch_size, world_size, generator=None)]
88 | lang_shuffle = [lang_indices[i] for i in get_length_grouped_indices(lang_lengths, batch_size, world_size, generator=None)]
89 | megabatch_size = world_size * batch_size
90 | mm_megabatches = [mm_shuffle[i : i + megabatch_size] for i in range(0, len(mm_shuffle), megabatch_size)]
91 | lang_megabatches = [lang_shuffle[i : i + megabatch_size] for i in range(0, len(lang_shuffle), megabatch_size)]
92 |
93 | last_mm = mm_megabatches[-1]
94 | last_lang = lang_megabatches[-1]
95 | additional_batch = last_mm + last_lang
96 | megabatches = mm_megabatches[:-1] + lang_megabatches[:-1]
97 | megabatch_indices = torch.randperm(len(megabatches), generator=generator)
98 | megabatches = [megabatches[i] for i in megabatch_indices]
99 |
100 | if len(additional_batch) > 0:
101 | megabatches.append(sorted(additional_batch))
102 |
103 | return [i for megabatch in megabatches for i in megabatch]
104 |
105 |
106 | def get_length_grouped_indices(lengths, batch_size, world_size, generator=None, merge=True):
107 | # We need to use torch for the random part as a distributed sampler will set the random seed for torch.
108 | indices = torch.randperm(len(lengths), generator=generator)
109 | megabatch_size = world_size * batch_size
110 | megabatches = [indices[i : i + megabatch_size].tolist() for i in range(0, len(lengths), megabatch_size)]
111 | megabatches = [sorted(megabatch, key=lambda i: lengths[i], reverse=True) for megabatch in megabatches]
112 | megabatches = [split_to_even_chunks(megabatch, lengths, world_size) for megabatch in megabatches]
113 |
114 | return [i for megabatch in megabatches for batch in megabatch for i in batch]
115 |
116 |
117 | class LengthGroupedSampler(Sampler):
118 | r"""
119 | Sampler that samples indices in a way that groups together features of the dataset of roughly the same length while
120 | keeping a bit of randomness.
121 | """
122 |
123 | def __init__(
124 | self,
125 | batch_size: int,
126 | world_size: int,
127 | lengths: Optional[List[int]] = None,
128 | generator=None,
129 | group_by_modality: bool = False,
130 | ):
131 | if lengths is None:
132 | raise ValueError("Lengths must be provided.")
133 |
134 | self.batch_size = batch_size
135 | self.world_size = world_size
136 | self.lengths = lengths
137 | self.generator = generator
138 | self.group_by_modality = group_by_modality
139 |
140 | def __len__(self):
141 | return len(self.lengths)
142 |
143 | def __iter__(self):
144 | if self.group_by_modality:
145 | indices = get_modality_length_grouped_indices(self.lengths, self.batch_size, self.world_size, generator=self.generator)
146 | else:
147 | indices = get_length_grouped_indices(self.lengths, self.batch_size, self.world_size, generator=self.generator)
148 | return iter(indices)
149 |
150 |
151 | class VStreamTrainer(Trainer):
152 |
153 | def _get_train_sampler(self) -> Optional[torch.utils.data.Sampler]:
154 | if self.train_dataset is None or not has_length(self.train_dataset):
155 | return None
156 |
157 | if self.args.group_by_modality_length:
158 | lengths = self.train_dataset.modality_lengths
159 | return LengthGroupedSampler(
160 | self.args.train_batch_size,
161 | world_size=self.args.world_size * self.args.gradient_accumulation_steps,
162 | lengths=lengths,
163 | group_by_modality=True,
164 | )
165 | else:
166 | return super()._get_train_sampler()
167 |
168 | def create_optimizer(self):
169 | """
170 | Setup the optimizer.
171 |
172 | We provide a reasonable default that works well. If you want to use something else, you can pass a tuple in the
173 | Trainer's init through `optimizers`, or subclass and override this method in a subclass.
174 | """
175 | if is_sagemaker_mp_enabled():
176 | return super().create_optimizer()
177 | if self.sharded_ddp == ShardedDDPOption.SIMPLE:
178 | return super().create_optimizer()
179 |
180 | opt_model = self.model
181 |
182 | if self.optimizer is None:
183 | decay_parameters = get_parameter_names(opt_model, ALL_LAYERNORM_LAYERS)
184 | decay_parameters = [name for name in decay_parameters if "bias" not in name]
185 | if self.args.mm_projector_lr is not None:
186 | projector_parameters = [name for name, _ in opt_model.named_parameters() if "mm_projector" in name]
187 | optimizer_grouped_parameters = [
188 | {
189 | "params": [
190 | p for n, p in opt_model.named_parameters() if (n in decay_parameters and n not in projector_parameters and p.requires_grad)
191 | ],
192 | "weight_decay": self.args.weight_decay,
193 | },
194 | {
195 | "params": [
196 | p for n, p in opt_model.named_parameters() if (n not in decay_parameters and n not in projector_parameters and p.requires_grad)
197 | ],
198 | "weight_decay": 0.0,
199 | },
200 | {
201 | "params": [
202 | p for n, p in opt_model.named_parameters() if (n in decay_parameters and n in projector_parameters and p.requires_grad)
203 | ],
204 | "weight_decay": self.args.weight_decay,
205 | "lr": self.args.mm_projector_lr,
206 | },
207 | {
208 | "params": [
209 | p for n, p in opt_model.named_parameters() if (n not in decay_parameters and n in projector_parameters and p.requires_grad)
210 | ],
211 | "weight_decay": 0.0,
212 | "lr": self.args.mm_projector_lr,
213 | },
214 | ]
215 | else:
216 | optimizer_grouped_parameters = [
217 | {
218 | "params": [
219 | p for n, p in opt_model.named_parameters() if (n in decay_parameters and p.requires_grad)
220 | ],
221 | "weight_decay": self.args.weight_decay,
222 | },
223 | {
224 | "params": [
225 | p for n, p in opt_model.named_parameters() if (n not in decay_parameters and p.requires_grad)
226 | ],
227 | "weight_decay": 0.0,
228 | },
229 | ]
230 |
231 | optimizer_cls, optimizer_kwargs = Trainer.get_optimizer_cls_and_kwargs(self.args)
232 |
233 | self.optimizer = optimizer_cls(optimizer_grouped_parameters, **optimizer_kwargs)
234 | if optimizer_cls.__name__ == "Adam8bit":
235 | import bitsandbytes
236 |
237 | manager = bitsandbytes.optim.GlobalOptimManager.get_instance()
238 |
239 | skipped = 0
240 | for module in opt_model.modules():
241 | if isinstance(module, nn.Embedding):
242 | skipped += sum({p.data_ptr(): p.numel() for p in module.parameters()}.values())
243 | logger.info(f"skipped {module}: {skipped/2**20}M params")
244 | manager.register_module_override(module, "weight", {"optim_bits": 32})
245 | logger.debug(f"bitsandbytes: will optimize {module} in fp32")
246 | logger.info(f"skipped: {skipped/2**20}M params")
247 |
248 | return self.optimizer
249 |
--------------------------------------------------------------------------------
/flash_vstream/utils.py:
--------------------------------------------------------------------------------
1 | # Based on https://github.com/haotian-liu/LLaVA.
2 |
3 | import datetime
4 | import logging
5 | import logging.handlers
6 | import os
7 | import sys
8 |
9 | import requests
10 |
11 | from flash_vstream.constants import LOGDIR
12 |
13 | server_error_msg = "**NETWORK ERROR DUE TO HIGH TRAFFIC. PLEASE REGENERATE OR REFRESH THIS PAGE.**"
14 | moderation_msg = "YOUR INPUT VIOLATES OUR CONTENT MODERATION GUIDELINES. PLEASE TRY AGAIN."
15 |
16 | handler = None
17 |
18 |
19 | def build_logger(logger_name, logger_filename):
20 | global handler
21 |
22 | formatter = logging.Formatter(
23 | fmt="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
24 | datefmt="%Y-%m-%d %H:%M:%S",
25 | )
26 |
27 | # Set the format of root handlers
28 | if not logging.getLogger().handlers:
29 | logging.basicConfig(level=logging.INFO)
30 | logging.getLogger().handlers[0].setFormatter(formatter)
31 |
32 | # Redirect stdout and stderr to loggers
33 | stdout_logger = logging.getLogger("stdout")
34 | stdout_logger.setLevel(logging.INFO)
35 | sl = StreamToLogger(stdout_logger, logging.INFO)
36 | sys.stdout = sl
37 |
38 | stderr_logger = logging.getLogger("stderr")
39 | stderr_logger.setLevel(logging.ERROR)
40 | sl = StreamToLogger(stderr_logger, logging.ERROR)
41 | sys.stderr = sl
42 |
43 | # Get logger
44 | logger = logging.getLogger(logger_name)
45 | logger.setLevel(logging.INFO)
46 |
47 | # Add a file handler for all loggers
48 | if handler is None:
49 | os.makedirs(LOGDIR, exist_ok=True)
50 | filename = os.path.join(LOGDIR, logger_filename)
51 | handler = logging.handlers.TimedRotatingFileHandler(
52 | filename, when='D', utc=True, encoding='UTF-8')
53 | handler.setFormatter(formatter)
54 |
55 | for name, item in logging.root.manager.loggerDict.items():
56 | if isinstance(item, logging.Logger):
57 | item.addHandler(handler)
58 |
59 | return logger
60 |
61 |
62 | class StreamToLogger(object):
63 | """
64 | Fake file-like stream object that redirects writes to a logger instance.
65 | """
66 | def __init__(self, logger, log_level=logging.INFO):
67 | self.terminal = sys.stdout
68 | self.logger = logger
69 | self.log_level = log_level
70 | self.linebuf = ''
71 |
72 | def __getattr__(self, attr):
73 | return getattr(self.terminal, attr)
74 |
75 | def write(self, buf):
76 | temp_linebuf = self.linebuf + buf
77 | self.linebuf = ''
78 | for line in temp_linebuf.splitlines(True):
79 | # From the io.TextIOWrapper docs:
80 | # On output, if newline is None, any '\n' characters written
81 | # are translated to the system default line separator.
82 | # By default sys.stdout.write() expects '\n' newlines and then
83 | # translates them so this is still cross platform.
84 | if line[-1] == '\n':
85 | self.logger.log(self.log_level, line.rstrip())
86 | else:
87 | self.linebuf += line
88 |
89 | def flush(self):
90 | if self.linebuf != '':
91 | self.logger.log(self.log_level, self.linebuf.rstrip())
92 | self.linebuf = ''
93 |
94 |
95 | def disable_torch_init():
96 | """
97 | Disable the redundant torch default initialization to accelerate model creation.
98 | """
99 | import torch
100 | setattr(torch.nn.Linear, "reset_parameters", lambda self: None)
101 | setattr(torch.nn.LayerNorm, "reset_parameters", lambda self: None)
102 |
103 |
104 | def violates_moderation(text):
105 | """
106 | Check whether the text violates OpenAI moderation API.
107 | """
108 | url = "https://api.openai.com/v1/moderations"
109 | headers = {"Content-Type": "application/json",
110 | "Authorization": "Bearer " + os.environ["OPENAI_API_KEY"]}
111 | text = text.replace("\n", "")
112 | data = "{" + '"input": ' + f'"{text}"' + "}"
113 | data = data.encode("utf-8")
114 | try:
115 | ret = requests.post(url, headers=headers, data=data, timeout=5)
116 | flagged = ret.json()["results"][0]["flagged"]
117 | except requests.exceptions.RequestException as e:
118 | flagged = False
119 | except KeyError as e:
120 | flagged = False
121 |
122 | return flagged
123 |
124 |
125 | def pretty_print_semaphore(semaphore):
126 | if semaphore is None:
127 | return "None"
128 | return f"Semaphore(value={semaphore._value}, locked={semaphore.locked()})"
129 |
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [build-system]
2 | requires = ["setuptools"]
3 | build-backend = "setuptools.build_meta"
4 |
5 | [project]
6 | name = "vstream"
7 | version = "1.0"
8 | description = "Flash-VStream"
9 | readme = "README.md"
10 | requires-python = ">=3.10"
11 | classifiers = [
12 | "Programming Language :: Python :: 3",
13 | "License :: OSI Approved :: Apache Software License",
14 | ]
15 | dependencies = [
16 | "numpy",
17 | "tokenizers>=0.12.1",
18 | "torch==2.0.1",
19 | "torchvision==0.15.2",
20 | "wandb",
21 | "tensorboard",
22 | "tensorboardX",
23 | "httpx==0.23.0",
24 | "deepspeed==0.9.5",
25 | "peft==0.4.0",
26 | "transformers==4.31.0",
27 | "accelerate==0.21.0",
28 | "bitsandbytes==0.41.0",
29 | "scikit-learn==1.2.2",
30 | "sentencepiece==0.1.99",
31 | "einops==0.6.1", "einops-exts==0.0.4", "timm==0.6.13",
32 | "decord",
33 | "openai==0.28.0",
34 | ]
35 |
36 | [project.urls]
37 | "Homepage" = "https://github.com/zhang9302002/Flash-VStream"
38 | "Bug Tracker" = "https://github.com/zhang9302002/Flash-VStream/issues"
39 |
40 | [tool.setuptools.packages.find]
41 | exclude = ["checkpoints*", "data*", "docs", "scripts*"]
42 |
43 | [tool.wheel]
44 | exclude = ["checkpoints*", "data*", "docs", "scripts*"]
--------------------------------------------------------------------------------
/scripts/eval.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | # set up python environment
4 | conda activate vstream
5 |
6 | # set important configurations
7 | ngpus=8
8 | gputype=A100
9 |
10 | # auto calculate configurations
11 | gpus_list=$(seq -s, 0 $((ngpus - 1)))
12 | date_device="$(date +%m%d)_${ngpus}${gputype}"
13 |
14 | echo start eval
15 | # define your openai info here
16 |
17 | for dataset in actnet nextoe msvd msrvtt vsmovienet vsego4d realtime_vsmovienet realtime_vsego4d
18 | do
19 | echo start eval ${dataset}
20 | python -m flash_vstream.eval_video.eval_any_dataset_features \
21 | --model-path your_model_checkpoint_path \
22 | --dataset ${dataset} \
23 | --num_chunks $ngpus \
24 | --api_key $OPENAIKEY \
25 | --api_base $OPENAIBASE \
26 | --api_type $OPENAITYPE \
27 | --api_version $OPENAIVERSION \
28 | --test \
29 | >> ${date_device}_vstream-7b-eval-${dataset}.log 2>&1
30 | done
31 |
32 |
--------------------------------------------------------------------------------
/scripts/merge_lora_weights.py:
--------------------------------------------------------------------------------
1 | # Based on https://github.com/haotian-liu/LLaVA.
2 |
3 | import argparse
4 | from flash_vstream.model.builder import load_pretrained_model
5 | from flash_vstream.mm_utils import get_model_name_from_path
6 |
7 |
8 | def merge_lora(args):
9 | model_name = get_model_name_from_path(args.model_path)
10 | tokenizer, model, image_processor, context_len = load_pretrained_model(args.model_path, args.model_base, model_name, device_map='cpu')
11 |
12 | model.save_pretrained(args.save_model_path)
13 | tokenizer.save_pretrained(args.save_model_path)
14 |
15 |
16 | if __name__ == "__main__":
17 | parser = argparse.ArgumentParser()
18 | parser.add_argument("--model-path", type=str, required=True)
19 | parser.add_argument("--model-base", type=str, required=True)
20 | parser.add_argument("--save-model-path", type=str, required=True)
21 |
22 | args = parser.parse_args()
23 |
24 | merge_lora(args)
25 |
--------------------------------------------------------------------------------
/scripts/realtime_cli.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | python -m flash_vstream.serve.cli_video_stream \
4 | --model-path your_model_checkpoint_path \
5 | --video-file assets/example.mp4 \
6 | --conv-mode vicuna_v1 --temperature 0.0 \
7 | --video_max_frames 1200 \
8 | --video_fps 1.0 --play_speed 1.0 \
9 | --log-file realtime_cli.log
10 |
--------------------------------------------------------------------------------
/scripts/train_and_eval.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | # set up python environment
4 | conda activate vstream
5 |
6 | # set important configurations
7 | type=weighted_kmeans
8 | suffix=STAR
9 | cur_length=1
10 | cur_size=8
11 | long_length=25
12 | long_size=4
13 | Turing_length=25
14 | Turing_size=1
15 | ngpus=8
16 | gputype=A100
17 |
18 | # auto calculate configurations
19 | gpus_list=$(seq -s, 0 $((ngpus - 1)))
20 | date_device="$(date +%m%d)_${ngpus}${gputype}"
21 |
22 | echo start pretrain
23 | deepspeed --master_addr 127.0.0.1 --master_port 12345 --include localhost:${gpus_list} flash_vstream/train/train_mem.py \
24 | --deepspeed ./scripts/zero0.json \
25 | --model_name_or_path ./ckpt/vicuna-7b-v1.5 \
26 | --version plain \
27 | --data_path ./data/pretrain/llava_558k_with_webvid.json \
28 | --image_folder ./data/pretrain/image_features \
29 | --video_folder ./data/pretrain/video_features \
30 | --vision_tower ./ckpt/clip-vit-large-patch14 \
31 | --mm_projector_type mlp2x_gelu \
32 | --tune_mm_mlp_adapter True \
33 | --mm_vision_select_layer -2 \
34 | --mm_use_im_start_end False \
35 | --mm_use_im_patch_token False \
36 | --video_fps 1 \
37 | --compress_type mean \
38 | --compress_size ${cur_size} \
39 | --compress_long_memory_size ${long_size} \
40 | --compress_Turing_memory_size ${Turing_size} \
41 | --video_max_frames $((cur_length + long_length)) \
42 | --video_current_memory_length ${cur_length} \
43 | --video_long_memory_length ${long_length} \
44 | --video_Turing_memory_length ${Turing_length} \
45 | --video_sample_type ${type} \
46 | --group_by_modality_length \
47 | --bf16 \
48 | --output_dir ./checkpoints-pretrain/vstream-7b-pretrain-${type}${cur_length}*${cur_size}-${long_length}*${long_size}-${Turing_length}*${Turing_size}-${suffix} \
49 | --num_train_epochs 1 \
50 | --per_device_train_batch_size 32 \
51 | --per_device_eval_batch_size 4 \
52 | --gradient_accumulation_steps 1 \
53 | --evaluation_strategy "no" \
54 | --save_strategy "steps" \
55 | --save_steps 100 \
56 | --save_total_limit 10 \
57 | --learning_rate 1e-3 \
58 | --weight_decay 0. \
59 | --warmup_ratio 0.03 \
60 | --lr_scheduler_type "cosine" \
61 | --logging_steps 1 \
62 | --model_max_length 2048 \
63 | --gradient_checkpointing True \
64 | --dataloader_num_workers 4 \
65 | --lazy_preprocess True \
66 | --report_to tensorboard \
67 | >> ${date_device}_vstream-7b-pretrain-${type}${cur_length}*${cur_size}-${long_length}*${long_size}-${Turing_length}*${Turing_size}-${suffix}.log 2>&1
68 |
69 | echo start finetune
70 | deepspeed --master_addr 127.0.2.1 --master_port 12345 --include localhost:${gpus_list} flash_vstream/train/train_mem.py \
71 | --deepspeed ./scripts/zero1.json \
72 | --model_name_or_path ./checkpoints-pretrain/vstream-7b-pretrain-${type}${cur_length}*${cur_size}-${long_length}*${long_size}-${Turing_length}*${Turing_size}-${suffix}/checkpoint-3000 \
73 | --version v1 \
74 | --data_path ./data/finetune/llava_v1_5_mix665k_with_video_chatgpt.json \
75 | --image_folder ./data/finetune/image_features \
76 | --video_folder ./data/finetune/video_features \
77 | --vision_tower ./ckpt/clip-vit-large-patch14 \
78 | --mm_projector_type mlp2x_gelu \
79 | --mm_vision_select_layer -2 \
80 | --mm_use_im_start_end False \
81 | --mm_use_im_patch_token False \
82 | --image_aspect_ratio pad \
83 | --video_fps 1 \
84 | --compress_type mean \
85 | --compress_size ${cur_size} \
86 | --compress_long_memory_size ${long_size} \
87 | --compress_Turing_memory_size ${Turing_size} \
88 | --video_max_frames $((cur_length + long_length)) \
89 | --video_current_memory_length ${cur_length} \
90 | --video_long_memory_length ${long_length} \
91 | --video_Turing_memory_length ${Turing_length} \
92 | --video_sample_type ${type} \
93 | --group_by_modality_length \
94 | --bf16 \
95 | --output_dir ./checkpoints-finetune/vstream-7b-finetune-${type}${cur_length}*${cur_size}-${long_length}*${long_size}-${Turing_length}*${Turing_size}-${suffix} \
96 | --num_train_epochs 1 \
97 | --per_device_train_batch_size 16 \
98 | --per_device_eval_batch_size 4 \
99 | --gradient_accumulation_steps 1 \
100 | --evaluation_strategy "no" \
101 | --save_strategy "steps" \
102 | --save_steps 100 \
103 | --save_total_limit 10 \
104 | --learning_rate 2e-5 \
105 | --weight_decay 0. \
106 | --warmup_ratio 0.03 \
107 | --lr_scheduler_type "cosine" \
108 | --logging_steps 1 \
109 | --model_max_length 2048 \
110 | --gradient_checkpointing True \
111 | --dataloader_num_workers 4 \
112 | --lazy_preprocess True \
113 | --report_to tensorboard \
114 | >> ${date_device}_vstream-7b-finetune-${type}${cur_length}*${cur_size}-${long_length}*${long_size}-${Turing_length}*${Turing_size}-${suffix}.log 2>&1
115 |
116 |
117 | echo start eval
118 | # define your openai info here
119 |
120 | for dataset in actnet nextoe msvd msrvtt vsmovienet vsego4d realtime_vsmovienet realtime_vsego4d
121 | do
122 | echo start eval ${dataset}
123 | python -m flash_vstream.eval_video.eval_any_dataset_features \
124 | --model-path checkpoints-finetune/vstream-7b-finetune-${type}${cur_length}*${cur_size}-${long_length}*${long_size}-${Turing_length}*${Turing_size}-${suffix}/checkpoint-5900 \
125 | --dataset ${dataset} \
126 | --num_chunks $ngpus \
127 | --api_key $OPENAIKEY \
128 | --api_base $OPENAIBASE \
129 | --api_type $OPENAITYPE \
130 | --api_version $OPENAIVERSION \
131 | >> ${date_device}_vstream-7b-eval-${dataset}-${type}${cur_length}*${cur_size}-${long_length}*${long_size}-${Turing_length}*${Turing_size}-${suffix}.log 2>&1
132 | done
133 |
--------------------------------------------------------------------------------
/scripts/zero0.json:
--------------------------------------------------------------------------------
1 | {
2 | "fp16": {
3 | "enabled": "auto",
4 | "loss_scale": 0,
5 | "loss_scale_window": 1000,
6 | "initial_scale_power": 16,
7 | "hysteresis": 2,
8 | "min_loss_scale": 1
9 | },
10 | "bf16": {
11 | "enabled": "auto"
12 | },
13 | "train_micro_batch_size_per_gpu": "auto",
14 | "train_batch_size": "auto",
15 | "gradient_accumulation_steps": "auto",
16 | "zero_optimization": {
17 | "stage": 0
18 | }
19 | }
--------------------------------------------------------------------------------
/scripts/zero1.json:
--------------------------------------------------------------------------------
1 | {
2 | "fp16": {
3 | "enabled": "auto",
4 | "loss_scale": 0,
5 | "loss_scale_window": 1000,
6 | "initial_scale_power": 16,
7 | "hysteresis": 2,
8 | "min_loss_scale": 1
9 | },
10 | "bf16": {
11 | "enabled": "auto"
12 | },
13 | "train_micro_batch_size_per_gpu": "auto",
14 | "train_batch_size": "auto",
15 | "gradient_accumulation_steps": "auto",
16 | "zero_optimization": {
17 | "stage": 1
18 | }
19 | }
--------------------------------------------------------------------------------
/scripts/zero2.json:
--------------------------------------------------------------------------------
1 | {
2 | "fp16": {
3 | "enabled": "auto",
4 | "loss_scale": 0,
5 | "loss_scale_window": 1000,
6 | "initial_scale_power": 16,
7 | "hysteresis": 2,
8 | "min_loss_scale": 1
9 | },
10 | "bf16": {
11 | "enabled": "auto"
12 | },
13 | "train_micro_batch_size_per_gpu": "auto",
14 | "train_batch_size": "auto",
15 | "gradient_accumulation_steps": "auto",
16 | "zero_optimization": {
17 | "stage": 2,
18 | "overlap_comm": true,
19 | "contiguous_gradients": true,
20 | "sub_group_size": 1e9,
21 | "reduce_bucket_size": "auto"
22 | }
23 | }
--------------------------------------------------------------------------------
/scripts/zero3.json:
--------------------------------------------------------------------------------
1 | {
2 | "fp16": {
3 | "enabled": "auto",
4 | "loss_scale": 0,
5 | "loss_scale_window": 1000,
6 | "initial_scale_power": 16,
7 | "hysteresis": 2,
8 | "min_loss_scale": 1
9 | },
10 | "bf16": {
11 | "enabled": "auto"
12 | },
13 | "train_micro_batch_size_per_gpu": "auto",
14 | "train_batch_size": "auto",
15 | "gradient_accumulation_steps": "auto",
16 | "zero_optimization": {
17 | "stage": 3,
18 | "overlap_comm": true,
19 | "contiguous_gradients": true,
20 | "sub_group_size": 1e9,
21 | "reduce_bucket_size": "auto",
22 | "stage3_prefetch_bucket_size": "auto",
23 | "stage3_param_persistence_threshold": "auto",
24 | "stage3_max_live_parameters": 1e9,
25 | "stage3_max_reuse_distance": 1e9,
26 | "stage3_gather_16bit_weights_on_model_save": true
27 | }
28 | }
--------------------------------------------------------------------------------
/scripts/zero3_offload.json:
--------------------------------------------------------------------------------
1 | {
2 | "fp16": {
3 | "enabled": "auto",
4 | "loss_scale": 0,
5 | "loss_scale_window": 1000,
6 | "initial_scale_power": 16,
7 | "hysteresis": 2,
8 | "min_loss_scale": 1
9 | },
10 | "bf16": {
11 | "enabled": "auto"
12 | },
13 | "optimizer": {
14 | "type": "AdamW",
15 | "params": {
16 | "lr": "auto",
17 | "betas": "auto",
18 | "eps": "auto",
19 | "weight_decay": "auto"
20 | }
21 | },
22 | "scheduler": {
23 | "type": "WarmupLR",
24 | "params": {
25 | "warmup_min_lr": "auto",
26 | "warmup_max_lr": "auto",
27 | "warmup_num_steps": "auto"
28 | }
29 | },
30 | "zero_optimization": {
31 | "stage": 3,
32 | "offload_optimizer": {
33 | "device": "cpu",
34 | "pin_memory": true
35 | },
36 | "offload_param": {
37 | "device": "cpu",
38 | "pin_memory": true
39 | },
40 | "overlap_comm": true,
41 | "contiguous_gradients": true,
42 | "sub_group_size": 1e9,
43 | "reduce_bucket_size": "auto",
44 | "stage3_prefetch_bucket_size": "auto",
45 | "stage3_param_persistence_threshold": "auto",
46 | "stage3_max_live_parameters": 1e9,
47 | "stage3_max_reuse_distance": 1e9,
48 | "gather_16bit_weights_on_model_save": true
49 | },
50 | "gradient_accumulation_steps": "auto",
51 | "gradient_clipping": "auto",
52 | "train_batch_size": "auto",
53 | "train_micro_batch_size_per_gpu": "auto",
54 | "steps_per_print": 1e5,
55 | "wall_clock_breakdown": false
56 | }
--------------------------------------------------------------------------------
/vstream.egg-info/PKG-INFO:
--------------------------------------------------------------------------------
1 | Metadata-Version: 2.1
2 | Name: vstream
3 | Version: 1.0
4 | Summary: Flash-VStream
5 | Project-URL: Homepage, https://github.com/zhang9302002/Flash-VStream
6 | Project-URL: Bug Tracker, https://github.com/zhang9302002/Flash-VStream/issues
7 | Classifier: Programming Language :: Python :: 3
8 | Classifier: License :: OSI Approved :: Apache Software License
9 | Requires-Python: >=3.10
10 | Description-Content-Type: text/markdown
11 | License-File: LICENSE
12 | Requires-Dist: numpy
13 | Requires-Dist: tokenizers>=0.12.1
14 | Requires-Dist: torch==2.0.1
15 | Requires-Dist: torchvision==0.15.2
16 | Requires-Dist: wandb
17 | Requires-Dist: tensorboard
18 | Requires-Dist: tensorboardX
19 | Requires-Dist: httpx==0.24.1
20 | Requires-Dist: deepspeed==0.9.5
21 | Requires-Dist: peft==0.4.0
22 | Requires-Dist: transformers==4.31.0
23 | Requires-Dist: accelerate==0.21.0
24 | Requires-Dist: bitsandbytes==0.41.0
25 | Requires-Dist: scikit-learn==1.2.2
26 | Requires-Dist: sentencepiece==0.1.99
27 | Requires-Dist: einops==0.6.1
28 | Requires-Dist: einops-exts==0.0.4
29 | Requires-Dist: timm==0.6.13
30 | Requires-Dist: decord
31 | Requires-Dist: openai==0.28.0
32 |
33 | # Flash-VStream: Memory-Based Real-Time Understanding for Long Video Streams
34 |
35 |
36 | Haoji Zhang\*,
37 | Yiqin Wang\*,
38 | Yansong Tang †,
39 | Yong Liu,
40 | Jiashi Feng,
41 | Jifeng Dai,
42 | Xiaojie Jin†‡
43 |
44 | \* Equally contributing first authors, †Correspondence, ‡Project Lead
45 |
46 |
47 |
48 |
49 |
50 |
51 |
52 |
53 |
54 | [](https://paperswithcode.com/sota/zeroshot-video-question-answer-on-msvd-qa?p=flash-vstream-memory-based-real-time)
55 |
56 |
57 | [](https://paperswithcode.com/sota/zeroshot-video-question-answer-on-msrvtt-qa?p=flash-vstream-memory-based-real-time)
58 |
59 |
60 | [](https://paperswithcode.com/sota/question-answering-on-next-qa-open-ended?p=flash-vstream-memory-based-real-time)
61 |
62 |
63 | [](https://paperswithcode.com/sota/zeroshot-video-question-answer-on-activitynet?p=flash-vstream-memory-based-real-time)
64 |
65 |
66 |
67 |
68 |
69 | We present **Flash-VStream**, a novel LMM able to process extremely long video streams in real time and respond to user queries simultaneously.
70 |
71 | We also propose **VStream-QA**, a novel question answering benchmark specifically designed for online video streaming understanding.
72 |
73 |
74 |
75 |
76 |
77 |
78 |
79 |
80 | ## News
81 | - [2024/6/15] 🏅 Our team won the 1st Place at [Long-Term Video Question Answering Challenge](https://sites.google.com/view/loveucvpr24/track1) of [LOVEU Workshop@CVPR'24](https://sites.google.com/view/loveucvpr24/home). We used a Hierarchical Memory model based on Flash-VStream-7b.
82 |
83 | - [2024/06/12] 🔥 Flash-VStream is coming! We release the
84 | [homepage](https://invinciblewyq.github.io/vstream-page),
85 | [paper](https://arxiv.org/abs/2406.08085v1),
86 | [code](https://github.com/IVGSZ/Flash-VStream)
87 | and [model](https://huggingface.co/IVGSZ/Flash-VStream-7b)
88 | for Flash-VStream.
89 | We release the [dataset](https://huggingface.co/datasets/IVGSZ/VStream-QA) for VStream-QA benchmark.
90 |
91 | ## Contents
92 | - [Install](#install)
93 | - [Model](#model)
94 | - [Preparation](#preparation)
95 | - [Train](#train)
96 | - [Evaluation](#evaluation)
97 | - [Real-time CLI Inference](#real-time-cli-inference)
98 | - [VStream-QA Benchmark](#vstream-qa-benchmark)
99 | - [Citation](#citation)
100 | - [Acknowledgement](#acknowledgement)
101 | - [License](#license)
102 |
103 | ## Install
104 | Please follow the instructions below to install the required packages.
105 | 1. Clone this repository
106 |
107 | 2. Install Package
108 | ```bash
109 | conda create -n vstream python=3.10 -y
110 | conda activate vstream
111 | cd Flash-VStream
112 | pip install --upgrade pip
113 | pip install -e .
114 | ```
115 |
116 | 3. Install additional packages for training
117 | ```bash
118 | pip install ninja
119 | pip install flash-attn --no-build-isolation
120 | ```
121 |
122 | ## Model
123 |
124 | We provide our Flash-VStream model after Stage-1 (pretrain) and Stage-2 (finetune) training:
125 |
126 | | Model | Weight | Initialized from LLM | Initialized from ViT |
127 | | --- | --- | --- | --- |
128 | | Flash-VStream-7b | [Flash-VStream-7b](https://huggingface.co/IVGSZ/Flash-VStream-7b) | [lmsys/vicuna-7b-v1.5](https://huggingface.co/lmsys/vicuna-7b-v1.5) | [openai/clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) |
129 |
130 |
131 | ## Preparation
132 |
133 | ### Dataset
134 |
135 | **Image VQA Dataset.**
136 | Please organize the image VQA training data following [this guide](https://github.com/haotian-liu/LLaVA/blob/main/docs/Data.md) and the evaluation data following [this guide](https://github.com/haotian-liu/LLaVA/blob/main/docs/Evaluation.md).
137 | Put the pretraining, finetuning, and evaluation data in the `pretrain`, `finetune`, and `eval_video` folders, respectively, following [Structure](#structure).
138 |
139 | **Video VQA Dataset.**
140 | Please download the 2.5M subset of [WebVid](https://maxbain.com/webvid-dataset/) and the ActivityNet dataset from its [official website](http://activity-net.org/download.html) or from [Video-ChatGPT](https://github.com/mbzuai-oryx/Video-ChatGPT/blob/main/docs/train_video_chatgpt.md).
141 |
142 | If you want to perform evaluation, please also download the corresponding files for
143 | [ActivityNet-QA](https://github.com/mbzuai-oryx/Video-ChatGPT/blob/main/quantitative_evaluation/README.md)
144 | and [NExT-QA-OE](https://github.com/doc-doc/NExT-QA).
145 | You can download
146 | [MSVD-QA](https://mycuhk-my.sharepoint.com/:u:/g/personal/1155186668_link_cuhk_edu_hk/EUNEXqg8pctPq3WZPHb4Fd8BYIxHO5qPCnU6aWsrV1O4JQ?e=guynwu)
147 | and [MSRVTT-QA](https://mycuhk-my.sharepoint.com/:u:/g/personal/1155186668_link_cuhk_edu_hk/EcEXh1HfTXhLrRnuwHbl15IBJeRop-d50Q90njHmhvLwtA?e=SE24eG) from LLaMA-VID.
148 |
149 |
150 | **Meta Info.**
151 | For the meta info of the training data, please download the following files and organize them as in [Structure](#structure).
152 |
153 | | Training Stage | Data file name | Size |
154 | | --- | --- | ---: |
155 | | Pretrain | [llava_558k_with_webvid.json](https://huggingface.co/datasets/YanweiLi/LLaMA-VID-Data) | 254 MB |
156 | | Finetune | [llava_v1_5_mix665k_with_video_chatgpt.json](https://huggingface.co/datasets/YanweiLi/LLaMA-VID-Data) | 860 MB |
157 |
158 | For the meta info of the evaluation data, please reformat each QA list into a JSON file named `test_qa.json`, placed as in [Structure](#structure), with the following format:
159 |
160 | ```json
161 | [
162 | {
163 | "video_id": "v_1QIUV7WYKXg",
164 | "question": "is the athlete wearing trousers",
165 | "id": "v_1QIUV7WYKXg_3",
166 | "answer": "no",
167 | "answer_type": 3,
168 | "duration": 9.88
169 | },
170 | {
171 | "video_id": "v_9eniCub7u60",
172 | "question": "does the girl in black clothes have long hair",
173 | "id": "v_9eniCub7u60_2",
174 | "answer": "yes",
175 | "answer_type": 3,
176 | "duration": 19.43
177 |     }
178 | ]
179 | ```
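
If it helps, the reformatting can be as simple as the sketch below. The input file name and field names (`raw_qa.json`, `vid`, `q`, `a`, `type`, `dur`) are hypothetical placeholders for whatever your original annotation files use; only the output schema shown above comes from this repo.

```python
import json

# Hypothetical raw annotations: a list of dicts with keys vid/q/a/type/dur.
# Adjust the key mapping to match the annotation files you actually downloaded.
with open('raw_qa.json') as f:
    raw = json.load(f)

qa_list = [
    {
        'video_id': item['vid'],
        'question': item['q'],
        'id': f"{item['vid']}_{i}",
        'answer': item['a'],
        'answer_type': item.get('type', 3),
        'duration': item['dur'],
    }
    for i, item in enumerate(raw)
]

with open('test_qa.json', 'w') as f:
    json.dump(qa_list, f, indent=4)
```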
180 |
181 | ### Pretrained Weights
182 | We recommend downloading the pretrained weights from the following links:
183 | [Vicuna-7b-v1.5](https://huggingface.co/lmsys/vicuna-7b-v1.5)
184 | and [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14),
185 | then put them in the `ckpt` folder following [Structure](#structure).
186 |
187 | ### Feature Extraction
188 |
189 | We recommend extracting ViT features for the training and evaluation data in advance, which greatly speeds up training and evaluation. If you do so, simply replace the `.mp4` extension with `.safetensors` in each video filename and put the feature files in the `image_features` and `video_features` folders. Otherwise, ignore the `image_features` and `video_features` folders.
190 |
191 | We load video features at fps=1 and arrange them in temporal order.
192 |
193 | Each `.safetensors` file should contain a dict like this:
194 |
195 | ```python
196 | {
197 |     'feature': feature_tensor,  # torch.Tensor of shape [256, 1024] for an image, or [Length, 256, 1024] for a video
198 | }
199 | ```
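
For reference, here is a minimal extraction sketch (it additionally needs the `safetensors` package). It assumes the ViT is `openai/clip-vit-large-patch14` (as listed under [Model](#model)), samples frames at fps=1 with decord, and drops the CLS token so each frame yields a [256, 1024] feature matching the shape above; treat it as an illustration under these assumptions rather than the exact pipeline of this repo.

```python
import numpy as np
import torch
from decord import VideoReader
from safetensors.torch import save_file
from transformers import CLIPImageProcessor, CLIPVisionModel

model_name = 'openai/clip-vit-large-patch14'
processor = CLIPImageProcessor.from_pretrained(model_name)
encoder = CLIPVisionModel.from_pretrained(model_name, torch_dtype=torch.float16).cuda().eval()

@torch.no_grad()
def extract_video_feature(video_path, out_path):
    vr = VideoReader(video_path)
    # Sample one frame per second, in temporal order.
    step = max(int(round(vr.get_avg_fps())), 1)
    indices = np.arange(0, len(vr), step)
    frames = vr.get_batch(indices).asnumpy()                      # [Length, H, W, 3], uint8
    pixels = processor(images=list(frames), return_tensors='pt')['pixel_values']
    feats = encoder(pixels.to('cuda', dtype=torch.float16)).last_hidden_state
    feats = feats[:, 1:, :].contiguous().cpu()                    # drop CLS -> [Length, 256, 1024]
    save_file({'feature': feats}, out_path)

# For long videos, encode the frames in chunks instead of a single batch to avoid OOM.
extract_video_feature('path/to/video.mp4', 'video_features/video.safetensors')
```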
200 |
201 |
202 | ### Structure
203 | The folder structure should be organized as follows before training.
204 |
205 | ```
206 | Flash-VStream
207 | ├── checkpoints-finetune
208 | ├── checkpoints-pretrain
209 | ├── ckpt
210 | │ ├── clip-vit-large-patch14
211 | │ ├── vicuna-7b-v1.5
212 | ├── data
213 | │ ├── pretrain
214 | │ │ ├── llava_558k_with_webvid.json
215 | │ │ ├── image_features
216 | │ │ ├── images
217 | │ │ ├── video_features
218 | │ │ ├── videos
219 | │ ├── finetune
220 | │ │ ├── llava_v1_5_mix665k_with_video_chatgpt.json
221 | │ │ ├── activitynet
222 | │ │ ├── coco
223 | │ │ ├── gqa
224 | │ │ ├── image_features
225 | │ │ │ ├── coco
226 | │ │ │ ├── gqa
227 | │ │ │ ├── ocr_vqa
228 | │ │ │ ├── textvqa
229 | │ │ │ ├── vg
230 | │ │ ├── ocr_vqa
231 | │ │ ├── textvqa
232 | │ │ ├── vg
233 | │ │ ├── video_features
234 | │ │ │ ├── activitynet
235 | │ ├── eval_video
236 | │ │ ├── ActivityNet-QA
237 | │ │ │ ├── video_features
238 | │ │ │ ├── test_qa.json
239 | │ │ ├── MSRVTT-QA
240 | │ │ │ ├── video_features
241 | │ │ │ ├── test_qa.json
242 | │ │ ├── MSVD-QA
243 | │ │ │ ├── video_features
244 | │ │ │ ├── test_qa.json
245 | │ │ ├── nextqa
246 | │ │ │ ├── video_features
247 | │ │ │ ├── test_qa.json
248 | │ │ ├── vstream
249 | │ │ │ ├── video_features
250 | │ │ │ ├── test_qa.json
251 | │ │ ├── vstream-realtime
252 | │ │ │ ├── video_features
253 | │ │ │ ├── test_qa.json
254 | ├── flash_vstream
255 | ├── scripts
256 |
257 | ```
258 |
259 | ## Train
260 | Flash-VStream is trained on 8 A100 GPUs with 80GB memory. To train on fewer GPUs, you can reduce the `per_device_train_batch_size` and increase the `gradient_accumulation_steps` accordingly, always keeping the global batch size the same: `per_device_train_batch_size` × `gradient_accumulation_steps` × `num_gpus`. If your GPUs have less than 80GB of memory, you may try the ZeRO-2 or ZeRO-3 stages.
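
As a quick sanity check, the snippet below illustrates the invariant with hypothetical numbers (not necessarily the values used in the released training scripts): different GPU and accumulation settings are interchangeable as long as the product stays the same.

```python
# Hypothetical configurations; only the product (the global batch size) must stay constant.
configs = [
    # (per_device_train_batch_size, gradient_accumulation_steps, num_gpus)
    (16, 1, 8),   # e.g. 8 GPUs
    (16, 2, 4),   # half the GPUs, double the accumulation steps
    (8,  4, 4),   # smaller per-device batch if memory is tight
]
for per_device, accum, gpus in configs:
    assert per_device * accum * gpus == 128   # same global batch size in every case
```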
261 |
262 | Please make sure you download and organize the data following [Preparation](#preparation) before training.
263 |
264 | Like LLaVA, Flash-VStream has two training stages: pretrain and finetune. Their checkpoints are saved in the `checkpoints-pretrain` and `checkpoints-finetune` folders, respectively. The two stages take about 15 hours in total on 8 A100 GPUs.
265 |
266 | If you want to train Flash-VStream from the pretrained LLM and then evaluate it, run the following command:
267 |
268 | ```bash
269 | bash scripts/train_and_eval.sh
270 | ```
271 |
272 | ## Evaluation
273 | Please make sure you download and organize the data following [Preparation](#preparation) before evaluation.
274 |
275 | If you want to evaluate a Flash-VStream model, please run the following command:
276 |
277 | ```bash
278 | bash scripts/eval.sh
279 | ```
280 |
281 | ## Real-time CLI Inference
282 | We provide a real-time CLI inference script that simulates video stream input by reading frames from a video file at a fixed frame rate. You can ask any question and get an answer at any timestamp of the video stream. Run the following command to try it:
283 |
284 | ```bash
285 | bash scripts/realtime_cli.sh
286 | ```
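
Conceptually, the stream simulation looks like the sketch below. This is only an illustration with a hypothetical `process_frame` callback, not the actual logic of `flash_vstream/serve/cli_video_stream.py`.

```python
import time
from decord import VideoReader

def simulate_stream(video_path, target_fps=1.0, process_frame=lambda ts, frame: None):
    """Feed frames to `process_frame` at a fixed rate, as if they arrived live."""
    vr = VideoReader(video_path)
    src_fps = vr.get_avg_fps()
    step = max(int(round(src_fps / target_fps)), 1)    # read every `step`-th source frame
    start = time.time()
    for i in range(0, len(vr), step):
        timestamp = i / src_fps                        # position in the video, in seconds
        # Wait until the frame's timestamp so wall-clock time tracks the video timeline.
        time.sleep(max(0.0, start + timestamp - time.time()))
        frame = vr[i].asnumpy()                        # [H, W, 3] uint8 frame
        process_frame(timestamp, frame)                # e.g. update memory, answer a pending query

simulate_stream('path/to/video.mp4', target_fps=1.0)
```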
287 |
288 | ## VStream-QA Benchmark
289 | Please download the VStream-QA benchmark from [this](https://huggingface.co/datasets/IVGSZ/VStream-QA) dataset repo.
290 |
291 | ## Citation
292 | If you find this project useful in your research, please consider citing:
293 |
294 | ```
295 | @article{flashvstream,
296 | title={Flash-VStream: Memory-Based Real-Time Understanding for Long Video Streams},
297 | author={Haoji Zhang and Yiqin Wang and Yansong Tang and Yong Liu and Jiashi Feng and Jifeng Dai and Xiaojie Jin},
298 | year={2024},
299 | eprint={2406.08085},
300 | archivePrefix={arXiv},
301 | primaryClass={cs.CV}
302 | }
303 | ```
304 |
305 | ## Acknowledgement
306 | We would like to thank the following repos for their great work:
307 |
308 | - This work is built upon [LLaVA](https://github.com/haotian-liu/LLaVA).
309 | - This work utilizes LLMs from [Vicuna](https://github.com/lm-sys/FastChat).
310 | - Some code is borrowed from [LLaMA-VID](https://github.com/dvlab-research/LLaMA-VID).
311 | - Our video-based evaluation protocol follows [Video-ChatGPT](https://github.com/mbzuai-oryx/Video-ChatGPT).
312 |
313 | ## License
314 | [](LICENSE)
315 |
316 | This project is licensed under the [Apache-2.0 License](LICENSE).
317 |
--------------------------------------------------------------------------------
/vstream.egg-info/SOURCES.txt:
--------------------------------------------------------------------------------
1 | LICENSE
2 | README.md
3 | pyproject.toml
4 | flash_vstream/__init__.py
5 | flash_vstream/constants.py
6 | flash_vstream/conversation.py
7 | flash_vstream/mm_utils.py
8 | flash_vstream/utils.py
9 | flash_vstream/eval_video/eval_activitynet_qa.py
10 | flash_vstream/eval_video/eval_any_dataset_features.py
11 | flash_vstream/eval_video/model_msvd_qa.py
12 | flash_vstream/eval_video/model_msvd_qa_featuresloader.py
13 | flash_vstream/model/__init__.py
14 | flash_vstream/model/builder.py
15 | flash_vstream/model/compress_functions.py
16 | flash_vstream/model/vstream_arch.py
17 | flash_vstream/model/language_model/vstream_llama.py
18 | flash_vstream/model/multimodal_encoder/builder.py
19 | flash_vstream/model/multimodal_encoder/clip_encoder.py
20 | flash_vstream/model/multimodal_projector/builder.py
21 | flash_vstream/serve/cli_video_stream.py
22 | flash_vstream/train/llama_flash_attn_monkey_patch.py
23 | flash_vstream/train/llama_xformers_attn_monkey_patch.py
24 | flash_vstream/train/train.py
25 | flash_vstream/train/train_mem.py
26 | flash_vstream/train/train_xformers.py
27 | flash_vstream/train/vstream_trainer.py
28 | vstream.egg-info/PKG-INFO
29 | vstream.egg-info/SOURCES.txt
30 | vstream.egg-info/dependency_links.txt
31 | vstream.egg-info/requires.txt
32 | vstream.egg-info/top_level.txt
--------------------------------------------------------------------------------
/vstream.egg-info/dependency_links.txt:
--------------------------------------------------------------------------------
1 |
2 |
--------------------------------------------------------------------------------
/vstream.egg-info/requires.txt:
--------------------------------------------------------------------------------
1 | numpy
2 | tokenizers>=0.12.1
3 | torch==2.0.1
4 | torchvision==0.15.2
5 | wandb
6 | tensorboard
7 | tensorboardX
8 | httpx==0.24.1
9 | deepspeed==0.9.5
10 | peft==0.4.0
11 | transformers==4.31.0
12 | accelerate==0.21.0
13 | bitsandbytes==0.41.0
14 | scikit-learn==1.2.2
15 | sentencepiece==0.1.99
16 | einops==0.6.1
17 | einops-exts==0.0.4
18 | timm==0.6.13
19 | decord
20 | openai==0.28.0
21 |
--------------------------------------------------------------------------------
/vstream.egg-info/top_level.txt:
--------------------------------------------------------------------------------
1 | assets
2 | flash_vstream
3 |
--------------------------------------------------------------------------------