├── .gitignore
├── LICENSE
├── README.md
├── ckpt
│   └── .place_image_mllm_ckpt
├── cog.yaml
├── figs
│   ├── ben.png
│   ├── pipeline.png
│   ├── sota.png
│   ├── table1.png
│   ├── table2.png
│   ├── table3.png
│   └── vis.png
├── llava
│   ├── __init__.py
│   ├── constants.py
│   ├── conversation.py
│   ├── eval
│   │   ├── model_utils.py
│   │   ├── run_inference_benchmark_consistency.py
│   │   ├── run_inference_benchmark_general.py
│   │   ├── run_inference_qa.py
│   │   └── single_video_inference.py
│   ├── mm_utils.py
│   ├── model
│   │   ├── __init__.py
│   │   ├── apply_delta.py
│   │   ├── builder.py
│   │   ├── consolidate.py
│   │   ├── language_model
│   │   │   ├── llava_llama.py
│   │   │   ├── llava_mistral.py
│   │   │   └── llava_mpt.py
│   │   ├── llava_arch.py
│   │   ├── make_delta.py
│   │   ├── multimodal_encoder
│   │   │   ├── builder.py
│   │   │   └── clip_encoder.py
│   │   ├── multimodal_projector
│   │   │   └── builder.py
│   │   └── utils.py
│   └── utils.py
└── scripts
    ├── gpt_eval
    │   ├── eval_qa_activitynet.sh
    │   ├── eval_qa_benchmark.sh
    │   ├── eval_qa_msrvtt.sh
    │   ├── eval_qa_msvd.sh
    │   ├── eval_video_qa.py
    │   ├── evaluate_benchmark_1_correctness.py
    │   ├── evaluate_benchmark_2_detailed_orientation.py
    │   ├── evaluate_benchmark_3_context.py
    │   ├── evaluate_benchmark_4_temporal.py
    │   └── evaluate_benchmark_5_consistency.py
    └── infer_video
        ├── run_benchmark_consistency_qa.sh
        ├── run_benchmark_generic_qa.sh
        ├── run_benchmark_temporal_qa.sh
        ├── run_one_video.sh
        ├── run_qa_anet_13B.sh
        ├── run_qa_anet_7B.sh
        ├── run_qa_msrvtt_13B.sh
        ├── run_qa_msrvtt_7B.sh
        ├── run_qa_msvd_13B.sh
        └── run_qa_msvd_7B.sh
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | share/python-wheels/
24 | *.egg-info/
25 | .installed.cfg
26 | *.egg
27 | MANIFEST
28 |
29 | # PyInstaller
30 | # Usually these files are written by a python script from a template
31 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
32 | *.manifest
33 | *.spec
34 |
35 | # Installer logs
36 | pip-log.txt
37 | pip-delete-this-directory.txt
38 |
39 | # Unit test / coverage reports
40 | htmlcov/
41 | .tox/
42 | .nox/
43 | .coverage
44 | .coverage.*
45 | .cache
46 | nosetests.xml
47 | coverage.xml
48 | *.cover
49 | *.py,cover
50 | .hypothesis/
51 | .pytest_cache/
52 | cover/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | .pybuilder/
76 | target/
77 |
78 | # Jupyter Notebook
79 | .ipynb_checkpoints
80 |
81 | # IPython
82 | profile_default/
83 | ipython_config.py
84 |
85 | # pyenv
86 | # For a library or package, you might want to ignore these files since the code is
87 | # intended to run in multiple environments; otherwise, check them in:
88 | # .python-version
89 |
90 | # pipenv
91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
94 | # install all needed dependencies.
95 | #Pipfile.lock
96 |
97 | # poetry
98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
99 | # This is especially recommended for binary packages to ensure reproducibility, and is more
100 | # commonly ignored for libraries.
101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
102 | #poetry.lock
103 |
104 | # pdm
105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
106 | #pdm.lock
107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
108 | # in version control.
109 | # https://pdm.fming.dev/#use-with-ide
110 | .pdm.toml
111 |
112 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
113 | __pypackages__/
114 |
115 | # Celery stuff
116 | celerybeat-schedule
117 | celerybeat.pid
118 |
119 | # SageMath parsed files
120 | *.sage.py
121 |
122 | # Environments
123 | .env
124 | .venv
125 | env/
126 | venv/
127 | ENV/
128 | env.bak/
129 | venv.bak/
130 |
131 | # Spyder project settings
132 | .spyderproject
133 | .spyproject
134 |
135 | # Rope project settings
136 | .ropeproject
137 |
138 | # mkdocs documentation
139 | /site
140 |
141 | # mypy
142 | .mypy_cache/
143 | .dmypy.json
144 | dmypy.json
145 |
146 | # Pyre type checker
147 | .pyre/
148 |
149 | # pytype static type analyzer
150 | .pytype/
151 |
152 | # Cython debug symbols
153 | cython_debug/
154 |
155 | # PyCharm
156 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
157 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
158 | # and can be added to the global gitignore or merged into this file. For a more nuclear
159 | # option (not recommended) you can uncomment the following to ignore the entire idea folder.
160 | #.idea/
161 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 |
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner. For the purposes of this definition, "submitted"
54 | means any form of electronic, verbal, or written communication sent
55 | to the Licensor or its representatives, including but not limited to
56 | communication on electronic mailing lists, source code control systems,
57 | and issue tracking systems that are managed by, or on behalf of, the
58 | Licensor for the purpose of discussing and improving the Work, but
59 | excluding communication that is conspicuously marked or otherwise
60 | designated in writing by the copyright owner as "Not a Contribution."
61 |
62 | "Contributor" shall mean Licensor and any individual or Legal Entity
63 | on behalf of whom a Contribution has been received by Licensor and
64 | subsequently incorporated within the Work.
65 |
66 | 2. Grant of Copyright License. Subject to the terms and conditions of
67 | this License, each Contributor hereby grants to You a perpetual,
68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 | copyright license to reproduce, prepare Derivative Works of,
70 | publicly display, publicly perform, sublicense, and distribute the
71 | Work and such Derivative Works in Source or Object form.
72 |
73 | 3. Grant of Patent License. Subject to the terms and conditions of
74 | this License, each Contributor hereby grants to You a perpetual,
75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 | (except as stated in this section) patent license to make, have made,
77 | use, offer to sell, sell, import, and otherwise transfer the Work,
78 | where such license applies only to those patent claims licensable
79 | by such Contributor that are necessarily infringed by their
80 | Contribution(s) alone or by combination of their Contribution(s)
81 | with the Work to which such Contribution(s) was submitted. If You
82 | institute patent litigation against any entity (including a
83 | cross-claim or counterclaim in a lawsuit) alleging that the Work
84 | or a Contribution incorporated within the Work constitutes direct
85 | or contributory patent infringement, then any patent licenses
86 | granted to You under this License for that Work shall terminate
87 | as of the date such litigation is filed.
88 |
89 | 4. Redistribution. You may reproduce and distribute copies of the
90 | Work or Derivative Works thereof in any medium, with or without
91 | modifications, and in Source or Object form, provided that You
92 | meet the following conditions:
93 |
94 | (a) You must give any other recipients of the Work or
95 | Derivative Works a copy of this License; and
96 |
97 | (b) You must cause any modified files to carry prominent notices
98 | stating that You changed the files; and
99 |
100 | (c) You must retain, in the Source form of any Derivative Works
101 | that You distribute, all copyright, patent, trademark, and
102 | attribution notices from the Source form of the Work,
103 | excluding those notices that do not pertain to any part of
104 | the Derivative Works; and
105 |
106 | (d) If the Work includes a "NOTICE" text file as part of its
107 | distribution, then any Derivative Works that You distribute must
108 | include a readable copy of the attribution notices contained
109 | within such NOTICE file, excluding those notices that do not
110 | pertain to any part of the Derivative Works, in at least one
111 | of the following places: within a NOTICE text file distributed
112 | as part of the Derivative Works; within the Source form or
113 | documentation, if provided along with the Derivative Works; or,
114 | within a display generated by the Derivative Works, if and
115 | wherever such third-party notices normally appear. The contents
116 | of the NOTICE file are for informational purposes only and
117 | do not modify the License. You may add Your own attribution
118 | notices within Derivative Works that You distribute, alongside
119 | or as an addendum to the NOTICE text from the Work, provided
120 | that such additional attribution notices cannot be construed
121 | as modifying the License.
122 |
123 | You may add Your own copyright statement to Your modifications and
124 | may provide additional or different license terms and conditions
125 | for use, reproduction, or distribution of Your modifications, or
126 | for any such Derivative Works as a whole, provided Your use,
127 | reproduction, and distribution of the Work otherwise complies with
128 | the conditions stated in this License.
129 |
130 | 5. Submission of Contributions. Unless You explicitly state otherwise,
131 | any Contribution intentionally submitted for inclusion in the Work
132 | by You to the Licensor shall be under the terms and conditions of
133 | this License, without any additional terms or conditions.
134 | Notwithstanding the above, nothing herein shall supersede or modify
135 | the terms of any separate license agreement you may have executed
136 | with Licensor regarding such Contributions.
137 |
138 | 6. Trademarks. This License does not grant permission to use the trade
139 | names, trademarks, service marks, or product names of the Licensor,
140 | except as required for reasonable and customary use in describing the
141 | origin of the Work and reproducing the content of the NOTICE file.
142 |
143 | 7. Disclaimer of Warranty. Unless required by applicable law or
144 | agreed to in writing, Licensor provides the Work (and each
145 | Contributor provides its Contributions) on an "AS IS" BASIS,
146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 | implied, including, without limitation, any warranties or conditions
148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 | PARTICULAR PURPOSE. You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
178 | APPENDIX: How to apply the Apache License to your work.
179 |
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "[]"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 |
189 | Copyright [yyyy] [name of copyright owner]
190 |
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 |
195 | http://www.apache.org/licenses/LICENSE-2.0
196 |
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
202 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
2 | # FreeVA: Offline MLLM as Training-Free Video Assistant
3 |
4 |
5 |
6 | [arXiv](https://arxiv.org/abs/2405.07798)
7 |
8 | #### [Wenhao Wu](https://whwu95.github.io/)
9 |
10 | #### [The University of Sydney](https://www.sydney.edu.au/)
11 |
12 | If you like our project, please give us a star ⭐ on GitHub for the latest updates.
13 |
14 |
15 | ***
16 |
17 |
18 | Welcome to **FreeVA** - a plug-and-play, simple yet effective study exploring the utilization of existing image MLLMs as video conversational models in a training-free manner. ⚡The core code can be just one line!
19 |
20 |
21 | ## Main Take-aways💡
22 | The study provides an essential, must-know baseline and reveals several surprising findings:
23 | 1) 😄FreeVA, leveraging only offline image-based MLLM without additional training, excels in zero-shot video question-answering (e.g., MSVD-QA, ActivityNet-QA, and MSRVTT-QA), even surpassing state-of-the-art methods that involve video instruction tuning.
24 | 2) 🤔While mainstream video-based MLLMs typically initialize from an image-based MLLM (e.g., LLaVA) and then fine-tune with video instruction tuning, the study indicates that using the widely adopted VideoInstruct-100K for video instruction tuning doesn't actually lead to better performance than not training at all.
25 | 3) ⚠️The commonly used evaluation metrics in existing works are significantly influenced by changes in the GPT-3.5 API version over time. If ignored, this could affect the fairness and uniformity of comparisons between different methods and impact the analysis and judgment of researchers in the field.
26 |
27 |
28 |
29 | ## 📢News
30 | - [x] **[Jun 7, 2024]** FreeVA results for more MLLMs, such as [InstructBLIP](https://github.com/salesforce/LAVIS/tree/main/projects/instructblip), [InternVL](https://github.com/OpenGVLab/InternVL), and [Dense Connector](https://github.com/HJYao00/DenseConnector), are provided.
31 | - [x] **[May 14, 2024]** [Preprint](https://arxiv.org/pdf/2405.07798) has been released.
32 | - [x] **[May 13, 2024]** Code has been released. Thanks for your star 😝.
33 |
34 |
35 |
36 | ## Overview
37 |
38 | ![FreeVA pipeline](figs/pipeline.png)
39 |
40 | An illustration of (a) an overview of the image MLLM inference process and (b) our proposed FreeVA for zero-shot video inference using existing image MLLMs.
41 |
42 |
43 |
44 |
45 |
46 |
47 |
48 |
49 | ## Results
50 |
51 |
52 |
53 | ### Quantitative Results
54 |
55 |
56 |
57 |
58 |
59 |
60 |
61 | ### Qualitative Results
62 |
63 |
64 |
65 |
66 |
67 | ## Empirical Study📊
68 |
69 |
74 |
75 |
76 | ## Running Video QA💬
77 |
78 | FreeVA can be applied to any image-based MLLM, and its core code is straightforward, simply involving a temporal aggregation. Please refer to [temporal_aggregation](./llava/model/llava_arch.py#L148) for implementation details.
79 |
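For intuition, here is a minimal sketch of such a parameter-free temporal aggregation, assuming simple mean pooling over frame features and illustrative tensor shapes; the exact operation used by FreeVA lives in the linked `llava_arch.py`.

```python
# Hedged sketch only: mean pooling is one parameter-free choice of temporal
# aggregation; the shapes and the function name here are illustrative assumptions.
import torch

def temporal_aggregation(frame_features: torch.Tensor) -> torch.Tensor:
    """Collapse per-frame visual tokens into a single image-like token sequence.

    Assumes `frame_features` has shape (num_frames, num_tokens, hidden_dim),
    i.e. the image encoder was run independently on each sampled frame.
    """
    # (num_frames, num_tokens, hidden_dim) -> (num_tokens, hidden_dim):
    # the LLM then consumes the video exactly as it would a single image.
    return frame_features.mean(dim=0)

# Example with made-up sizes: 16 frames, 576 CLIP patch tokens, hidden size 1024.
video_tokens = torch.randn(16, 576, 1024)
print(temporal_aggregation(video_tokens).shape)  # torch.Size([576, 1024])
```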
80 | Below, we provide guidance on running the code using LLaVA-1.5 as an example.
81 |
82 | Before running:
83 | 1) Please refer to [cog.yaml](./cog.yaml) for environment configuration regarding LLaVA-1.5.
84 | 2) Please download the LLaVA model in advance and place it in the "ckpt" folder, for example, "llava-v1.5-7b" or "llava-v1.5-13b".
85 | 3) Please refer to [Video-ChatGPT](https://github.com/mbzuai-oryx/Video-ChatGPT) for downloading the evaluation dataset and corresponding annotations.
86 |
87 | ### Zero-shot Video Question-Answering
88 | To enhance evaluation efficiency, we provide a script for single-machine *multi-GPU* evaluation. Taking the ActivityNet-QA dataset as an example, the specific steps to run the script are as follows:
89 |
90 | **Step1: Obtain the prediction file.**
91 | ```sh
92 | CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 bash scripts/infer_video/run_qa_anet_7B.sh
93 | ```
94 | You will get a prediction file, *merge.jsonl*.
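Under the hood, the run script splits the question list across GPUs via the `--num_chunks`/`--chunk_idx` arguments of `llava/eval/run_inference_qa.py`, and the per-chunk outputs are merged into the final *merge.jsonl*. The snippet below mirrors the repo's `split_list`/`get_chunk` helpers to show how the questions are partitioned; the question list itself is a stand-in.

```python
# Mirrors split_list/get_chunk from llava/eval/run_inference_qa.py; the question
# list here is a placeholder just to show how work is divided across 8 GPUs.
import math

def split_list(lst, n):
    """Split a list into n (roughly) equal-sized chunks."""
    chunk_size = math.ceil(len(lst) / n)
    return [lst[i:i + chunk_size] for i in range(0, len(lst), chunk_size)]

def get_chunk(lst, n, k):
    """Return the k-th of n chunks."""
    return split_list(lst, n)[k]

questions = [f"q{i}" for i in range(1000)]   # stand-in for the ActivityNet-QA questions
print(len(get_chunk(questions, 8, 0)))       # 125 -> the slice handled by chunk_idx 0
```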
95 |
96 | **Step2: GPT-assistant evaluation**
97 |
98 | Running the following command will provide you with accuracy and score.
99 | Before running, please fill in your OpenAI API key, the path to the prediction file from Step 1, the number of worker processes for multiprocessing (to accelerate evaluation), and the GPT-3.5 version number.
100 |
101 | ⚠️*Note: The default version of gpt-3.5-turbo has gone through three releases in chronological order: gpt-3.5-turbo-0301, gpt-3.5-turbo-0613, and gpt-3.5-turbo-0125, with significant performance differences between versions.*
102 | ```sh
103 | bash scripts/gpt_eval/eval_qa_activitynet.sh
104 | ```
105 |
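For reference, the kind of GPT-assisted judging these scripts perform looks roughly like the sketch below. It assumes the legacy `openai<1.0` Python SDK and a simplified prompt; the exact prompts, scoring scale, and multiprocessing logic live in `scripts/gpt_eval/eval_video_qa.py`.

```python
# Hedged sketch of a GPT-assisted judge; the prompt and parsing are simplified
# assumptions, not the exact ones used by scripts/gpt_eval/eval_video_qa.py.
import ast
import openai

openai.api_key = "YOUR_OPENAI_API_KEY"  # placeholder

def judge_answer(question, answer, prediction, model="gpt-3.5-turbo-0125"):
    """Ask GPT-3.5 whether `prediction` matches `answer` and to score it 0-5."""
    completion = openai.ChatCompletion.create(
        model=model,
        messages=[
            {"role": "system",
             "content": "You evaluate video question-answering predictions. "
                        "Reply only with a Python dict like {'pred': 'yes', 'score': 4}."},
            {"role": "user",
             "content": f"Question: {question}\n"
                        f"Correct Answer: {answer}\n"
                        f"Predicted Answer: {prediction}"},
        ],
    )
    # The judge replies with a dict literal such as {'pred': 'yes', 'score': 4}.
    return ast.literal_eval(completion["choices"][0]["message"]["content"])
```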
106 | *The evaluation process for other datasets (MSRVTT-QA, MSVD-QA) follows the same procedure. Please refer to the steps outlined above.*
107 |
108 | ### Video-Based Text Generation Performance
109 |
110 | The generative performance benchmark includes five evaluation metrics: Correctness of Information, Detail Orientation, Contextual Understanding, Temporal Understanding, and Consistency.
111 |
112 | **Step1: Obtain the prediction file.**
113 | ```sh
114 | CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 bash scripts/infer_video/run_benchmark_generic_qa.sh
115 | CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 bash scripts/infer_video/run_benchmark_temporal_qa.sh
116 | CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 bash scripts/infer_video/run_benchmark_consistency_qa.sh
117 | ```
118 | You will get the prediction files *generic.jsonl*, *temporal.jsonl*, and *consistency.jsonl*, respectively.
119 |
120 |
121 | **Step2: GPT-assistant evaluation**
122 |
123 | Running the following script will generate these five metrics.
124 |
125 | Before running, please fill in your OpenAI API key, the path to the prediction file from Step 1, the number of worker processes for multiprocessing (to accelerate evaluation), and the GPT-3.5 version number.
126 |
127 | ```sh
128 | bash scripts/gpt_eval/eval_qa_benchmark.sh
129 | ```
130 |
131 | ## Acknowledgement🙏
132 | We extend our sincere gratitude to the following awesome projects:
133 | - [LLaVA](https://github.com/haotian-liu/LLaVA): Visual Instruction Tuning
134 | - [Video-ChatGPT](https://github.com/mbzuai-oryx/Video-ChatGPT): Towards Detailed Video Understanding via Large Vision and Language Models
135 |
136 |
137 | ## BibTeX & Citation
138 |
139 | If you use our code in your research or wish to refer to the results, please star 🌟 this repo and use the following BibTeX 📑 entry.
140 |
141 | ```bibtex
142 | @article{FreeVA,
143 | title={FreeVA: Offline MLLM as Training-Free Video Assistant},
144 | author={Wu, Wenhao},
145 | journal={arXiv preprint arXiv:2405.07798},
146 | year={2024}
147 | }
148 | ```
149 |
--------------------------------------------------------------------------------
/ckpt/.place_image_mllm_ckpt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/whwu95/FreeVA/ddc838dbfcc4355630525582e060081198a33124/ckpt/.place_image_mllm_ckpt
--------------------------------------------------------------------------------
/cog.yaml:
--------------------------------------------------------------------------------
1 | # Configuration for Cog ⚙️
2 | # Reference: https://github.com/replicate/cog/blob/main/docs/yaml.md
3 |
4 | build:
5 | gpu: true
6 |
7 | python_version: "3.11"
8 |
9 | python_packages:
10 | - "torch==2.0.1"
11 | - "accelerate==0.21.0"
12 | - "bitsandbytes==0.41.0"
13 | - "deepspeed==0.9.5"
14 | - "einops-exts==0.0.4"
15 | - "einops==0.6.1"
16 | - "gradio==3.35.2"
17 | - "gradio_client==0.2.9"
18 | - "httpx==0.24.0"
19 | - "markdown2==2.4.10"
20 | - "numpy==1.26.0"
21 | - "peft==0.4.0"
22 | - "scikit-learn==1.2.2"
23 | - "sentencepiece==0.1.99"
24 | - "shortuuid==1.0.11"
25 | - "timm==0.6.13"
26 | - "tokenizers==0.13.3"
27 | - "torch==2.0.1"
28 | - "torchvision==0.15.2"
29 | - "transformers==4.31.0"
30 | - "wandb==0.15.12"
31 | - "wavedrom==2.0.3.post3"
32 | - "Pygments==2.16.1"
33 | run:
34 | - curl -o /usr/local/bin/pget -L "https://github.com/replicate/pget/releases/download/v0.0.3/pget" && chmod +x /usr/local/bin/pget
35 |
36 | # predict.py defines how predictions are run on your model
37 | predict: "predict.py:Predictor"
38 |
--------------------------------------------------------------------------------
/figs/ben.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/whwu95/FreeVA/ddc838dbfcc4355630525582e060081198a33124/figs/ben.png
--------------------------------------------------------------------------------
/figs/pipeline.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/whwu95/FreeVA/ddc838dbfcc4355630525582e060081198a33124/figs/pipeline.png
--------------------------------------------------------------------------------
/figs/sota.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/whwu95/FreeVA/ddc838dbfcc4355630525582e060081198a33124/figs/sota.png
--------------------------------------------------------------------------------
/figs/table1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/whwu95/FreeVA/ddc838dbfcc4355630525582e060081198a33124/figs/table1.png
--------------------------------------------------------------------------------
/figs/table2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/whwu95/FreeVA/ddc838dbfcc4355630525582e060081198a33124/figs/table2.png
--------------------------------------------------------------------------------
/figs/table3.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/whwu95/FreeVA/ddc838dbfcc4355630525582e060081198a33124/figs/table3.png
--------------------------------------------------------------------------------
/figs/vis.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/whwu95/FreeVA/ddc838dbfcc4355630525582e060081198a33124/figs/vis.png
--------------------------------------------------------------------------------
/llava/__init__.py:
--------------------------------------------------------------------------------
1 | from .model import LlavaLlamaForCausalLM
2 |
--------------------------------------------------------------------------------
/llava/constants.py:
--------------------------------------------------------------------------------
1 | CONTROLLER_HEART_BEAT_EXPIRATION = 30
2 | WORKER_HEART_BEAT_INTERVAL = 15
3 |
4 | LOGDIR = "."
5 |
6 | # Model Constants
7 | IGNORE_INDEX = -100
8 | IMAGE_TOKEN_INDEX = -200
9 | DEFAULT_IMAGE_TOKEN = "<image>"
10 | DEFAULT_IMAGE_PATCH_TOKEN = "<im_patch>"
11 | DEFAULT_IM_START_TOKEN = "<im_start>"
12 | DEFAULT_IM_END_TOKEN = "<im_end>"
13 | IMAGE_PLACEHOLDER = "<image-placeholder>"
--------------------------------------------------------------------------------
/llava/conversation.py:
--------------------------------------------------------------------------------
1 | import dataclasses
2 | from enum import auto, Enum
3 | from typing import List, Tuple
4 | import base64
5 | from io import BytesIO
6 | from PIL import Image
7 |
8 |
9 | class SeparatorStyle(Enum):
10 | """Different separator style."""
11 | SINGLE = auto()
12 | TWO = auto()
13 | MPT = auto()
14 | PLAIN = auto()
15 | LLAMA_2 = auto()
16 |
17 |
18 | @dataclasses.dataclass
19 | class Conversation:
20 | """A class that keeps all conversation history."""
21 | system: str
22 | roles: List[str]
23 | messages: List[List[str]]
24 | offset: int
25 | sep_style: SeparatorStyle = SeparatorStyle.SINGLE
26 | sep: str = "###"
27 | sep2: str = None
28 | version: str = "Unknown"
29 |
30 | skip_next: bool = False
31 |
32 | def get_prompt(self):
33 | messages = self.messages
34 | if len(messages) > 0 and type(messages[0][1]) is tuple:
35 | messages = self.messages.copy()
36 | init_role, init_msg = messages[0].copy()
37 |             init_msg = init_msg[0].replace("<image>", "").strip()
38 | if 'mmtag' in self.version:
39 | messages[0] = (init_role, init_msg)
40 |                 messages.insert(0, (self.roles[0], "<Image><image></Image>"))
41 | messages.insert(1, (self.roles[1], "Received."))
42 | else:
43 |                 messages[0] = (init_role, "<image>\n" + init_msg)
44 |
45 | if self.sep_style == SeparatorStyle.SINGLE:
46 | ret = self.system + self.sep
47 | for role, message in messages:
48 | if message:
49 | if type(message) is tuple:
50 | message, _, _ = message
51 | ret += role + ": " + message + self.sep
52 | else:
53 | ret += role + ":"
54 | elif self.sep_style == SeparatorStyle.TWO:
55 | seps = [self.sep, self.sep2]
56 | ret = self.system + seps[0]
57 | for i, (role, message) in enumerate(messages):
58 | if message:
59 | if type(message) is tuple:
60 | message, _, _ = message
61 | ret += role + ": " + message + seps[i % 2]
62 | else:
63 | ret += role + ":"
64 | elif self.sep_style == SeparatorStyle.MPT:
65 | ret = self.system + self.sep
66 | for role, message in messages:
67 | if message:
68 | if type(message) is tuple:
69 | message, _, _ = message
70 | ret += role + message + self.sep
71 | else:
72 | ret += role
73 | elif self.sep_style == SeparatorStyle.LLAMA_2:
74 |             wrap_sys = lambda msg: f"<<SYS>>\n{msg}\n<</SYS>>\n\n" if len(msg) > 0 else msg
75 | wrap_inst = lambda msg: f"[INST] {msg} [/INST]"
76 | ret = ""
77 |
78 | for i, (role, message) in enumerate(messages):
79 | if i == 0:
80 | assert message, "first message should not be none"
81 | assert role == self.roles[0], "first message should come from user"
82 | if message:
83 | if type(message) is tuple:
84 | message, _, _ = message
85 | if i == 0: message = wrap_sys(self.system) + message
86 | if i % 2 == 0:
87 | message = wrap_inst(message)
88 | ret += self.sep + message
89 | else:
90 | ret += " " + message + " " + self.sep2
91 | else:
92 | ret += ""
93 | ret = ret.lstrip(self.sep)
94 | elif self.sep_style == SeparatorStyle.PLAIN:
95 | seps = [self.sep, self.sep2]
96 | ret = self.system
97 | for i, (role, message) in enumerate(messages):
98 | if message:
99 | if type(message) is tuple:
100 | message, _, _ = message
101 | ret += message + seps[i % 2]
102 | else:
103 | ret += ""
104 | else:
105 | raise ValueError(f"Invalid style: {self.sep_style}")
106 |
107 | return ret
108 |
109 | def append_message(self, role, message):
110 | self.messages.append([role, message])
111 |
112 | def process_image(self, image, image_process_mode, return_pil=False, image_format='PNG', max_len=1344, min_len=672):
113 | if image_process_mode == "Pad":
114 | def expand2square(pil_img, background_color=(122, 116, 104)):
115 | width, height = pil_img.size
116 | if width == height:
117 | return pil_img
118 | elif width > height:
119 | result = Image.new(pil_img.mode, (width, width), background_color)
120 | result.paste(pil_img, (0, (width - height) // 2))
121 | return result
122 | else:
123 | result = Image.new(pil_img.mode, (height, height), background_color)
124 | result.paste(pil_img, ((height - width) // 2, 0))
125 | return result
126 | image = expand2square(image)
127 | elif image_process_mode in ["Default", "Crop"]:
128 | pass
129 | elif image_process_mode == "Resize":
130 | image = image.resize((336, 336))
131 | else:
132 | raise ValueError(f"Invalid image_process_mode: {image_process_mode}")
133 | if max(image.size) > max_len:
134 | max_hw, min_hw = max(image.size), min(image.size)
135 | aspect_ratio = max_hw / min_hw
136 | shortest_edge = int(min(max_len / aspect_ratio, min_len, min_hw))
137 | longest_edge = int(shortest_edge * aspect_ratio)
138 | W, H = image.size
139 | if H > W:
140 | H, W = longest_edge, shortest_edge
141 | else:
142 | H, W = shortest_edge, longest_edge
143 | image = image.resize((W, H))
144 | if return_pil:
145 | return image
146 | else:
147 | buffered = BytesIO()
148 | image.save(buffered, format=image_format)
149 | img_b64_str = base64.b64encode(buffered.getvalue()).decode()
150 | return img_b64_str
151 |
152 | def get_images(self, return_pil=False):
153 | images = []
154 | for i, (role, msg) in enumerate(self.messages[self.offset:]):
155 | if i % 2 == 0:
156 | if type(msg) is tuple:
157 | msg, image, image_process_mode = msg
158 | image = self.process_image(image, image_process_mode, return_pil=return_pil)
159 | images.append(image)
160 | return images
161 |
162 | def to_gradio_chatbot(self):
163 | ret = []
164 | for i, (role, msg) in enumerate(self.messages[self.offset:]):
165 | if i % 2 == 0:
166 | if type(msg) is tuple:
167 | msg, image, image_process_mode = msg
168 | img_b64_str = self.process_image(
169 | image, "Default", return_pil=False,
170 | image_format='JPEG')
171 |                     img_str = f'<img src="data:image/jpeg;base64,{img_b64_str}" alt="user upload image" />'
172 |                     msg = img_str + msg.replace('<image>', '').strip()
173 | ret.append([msg, None])
174 | else:
175 | ret.append([msg, None])
176 | else:
177 | ret[-1][-1] = msg
178 | return ret
179 |
180 | def copy(self):
181 | return Conversation(
182 | system=self.system,
183 | roles=self.roles,
184 | messages=[[x, y] for x, y in self.messages],
185 | offset=self.offset,
186 | sep_style=self.sep_style,
187 | sep=self.sep,
188 | sep2=self.sep2,
189 | version=self.version)
190 |
191 | def dict(self):
192 | if len(self.get_images()) > 0:
193 | return {
194 | "system": self.system,
195 | "roles": self.roles,
196 | "messages": [[x, y[0] if type(y) is tuple else y] for x, y in self.messages],
197 | "offset": self.offset,
198 | "sep": self.sep,
199 | "sep2": self.sep2,
200 | }
201 | return {
202 | "system": self.system,
203 | "roles": self.roles,
204 | "messages": self.messages,
205 | "offset": self.offset,
206 | "sep": self.sep,
207 | "sep2": self.sep2,
208 | }
209 |
210 |
211 | conv_vicuna_v0 = Conversation(
212 | system="A chat between a curious human and an artificial intelligence assistant. "
213 | "The assistant gives helpful, detailed, and polite answers to the human's questions.",
214 | roles=("Human", "Assistant"),
215 | messages=(
216 | ("Human", "What are the key differences between renewable and non-renewable energy sources?"),
217 | ("Assistant",
218 | "Renewable energy sources are those that can be replenished naturally in a relatively "
219 | "short amount of time, such as solar, wind, hydro, geothermal, and biomass. "
220 | "Non-renewable energy sources, on the other hand, are finite and will eventually be "
221 | "depleted, such as coal, oil, and natural gas. Here are some key differences between "
222 | "renewable and non-renewable energy sources:\n"
223 | "1. Availability: Renewable energy sources are virtually inexhaustible, while non-renewable "
224 | "energy sources are finite and will eventually run out.\n"
225 | "2. Environmental impact: Renewable energy sources have a much lower environmental impact "
226 | "than non-renewable sources, which can lead to air and water pollution, greenhouse gas emissions, "
227 | "and other negative effects.\n"
228 | "3. Cost: Renewable energy sources can be more expensive to initially set up, but they typically "
229 | "have lower operational costs than non-renewable sources.\n"
230 | "4. Reliability: Renewable energy sources are often more reliable and can be used in more remote "
231 | "locations than non-renewable sources.\n"
232 | "5. Flexibility: Renewable energy sources are often more flexible and can be adapted to different "
233 | "situations and needs, while non-renewable sources are more rigid and inflexible.\n"
234 | "6. Sustainability: Renewable energy sources are more sustainable over the long term, while "
235 | "non-renewable sources are not, and their depletion can lead to economic and social instability.\n")
236 | ),
237 | offset=2,
238 | sep_style=SeparatorStyle.SINGLE,
239 | sep="###",
240 | )
241 |
242 | conv_vicuna_v1 = Conversation(
243 | system="A chat between a curious user and an artificial intelligence assistant. "
244 | "The assistant gives helpful, detailed, and polite answers to the user's questions.",
245 | roles=("USER", "ASSISTANT"),
246 | version="v1",
247 | messages=(),
248 | offset=0,
249 | sep_style=SeparatorStyle.TWO,
250 | sep=" ",
251 |     sep2="</s>",
252 | )
253 |
254 | conv_llama_2 = Conversation(
255 | system="""You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature.
256 |
257 | If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information.""",
258 | roles=("USER", "ASSISTANT"),
259 | version="llama_v2",
260 | messages=(),
261 | offset=0,
262 | sep_style=SeparatorStyle.LLAMA_2,
263 |     sep="<s>",
264 |     sep2="</s>",
265 | )
266 |
267 | conv_llava_llama_2 = Conversation(
268 | system="You are a helpful language and vision assistant. "
269 | "You are able to understand the visual content that the user provides, "
270 | "and assist the user with a variety of tasks using natural language.",
271 | roles=("USER", "ASSISTANT"),
272 | version="llama_v2",
273 | messages=(),
274 | offset=0,
275 | sep_style=SeparatorStyle.LLAMA_2,
276 |     sep="<s>",
277 |     sep2="</s>",
278 | )
279 |
280 | conv_mpt = Conversation(
281 | system="""<|im_start|>system
282 | A conversation between a user and an LLM-based AI assistant. The assistant gives helpful and honest answers.""",
283 | roles=("<|im_start|>user\n", "<|im_start|>assistant\n"),
284 | version="mpt",
285 | messages=(),
286 | offset=0,
287 | sep_style=SeparatorStyle.MPT,
288 | sep="<|im_end|>",
289 | )
290 |
291 | conv_llava_plain = Conversation(
292 | system="",
293 | roles=("", ""),
294 | messages=(
295 | ),
296 | offset=0,
297 | sep_style=SeparatorStyle.PLAIN,
298 | sep="\n",
299 | )
300 |
301 | conv_llava_v0 = Conversation(
302 | system="A chat between a curious human and an artificial intelligence assistant. "
303 | "The assistant gives helpful, detailed, and polite answers to the human's questions.",
304 | roles=("Human", "Assistant"),
305 | messages=(
306 | ),
307 | offset=0,
308 | sep_style=SeparatorStyle.SINGLE,
309 | sep="###",
310 | )
311 |
312 | conv_llava_v0_mmtag = Conversation(
313 | system="A chat between a curious user and an artificial intelligence assistant. "
314 | "The assistant is able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language."
315 | "The visual content will be provided with the following format: visual content.",
316 | roles=("Human", "Assistant"),
317 | messages=(
318 | ),
319 | offset=0,
320 | sep_style=SeparatorStyle.SINGLE,
321 | sep="###",
322 | version="v0_mmtag",
323 | )
324 |
325 | conv_llava_v1 = Conversation(
326 | system="A chat between a curious human and an artificial intelligence assistant. "
327 | "The assistant gives helpful, detailed, and polite answers to the human's questions.",
328 | roles=("USER", "ASSISTANT"),
329 | version="v1",
330 | messages=(),
331 | offset=0,
332 | sep_style=SeparatorStyle.TWO,
333 | sep=" ",
334 |     sep2="</s>",
335 | )
336 |
337 | conv_llava_v1_mmtag = Conversation(
338 | system="A chat between a curious user and an artificial intelligence assistant. "
339 | "The assistant is able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language."
340 | "The visual content will be provided with the following format: visual content.",
341 | roles=("USER", "ASSISTANT"),
342 | messages=(),
343 | offset=0,
344 | sep_style=SeparatorStyle.TWO,
345 | sep=" ",
346 |     sep2="</s>",
347 | version="v1_mmtag",
348 | )
349 |
350 | conv_mistral_instruct = Conversation(
351 | system="",
352 | roles=("USER", "ASSISTANT"),
353 | version="llama_v2",
354 | messages=(),
355 | offset=0,
356 | sep_style=SeparatorStyle.LLAMA_2,
357 | sep="",
358 |     sep2="</s>",
359 | )
360 |
361 | conv_chatml_direct = Conversation(
362 | system="""<|im_start|>system
363 | Answer the questions.""",
364 | roles=("<|im_start|>user\n", "<|im_start|>assistant\n"),
365 | version="mpt",
366 | messages=(),
367 | offset=0,
368 | sep_style=SeparatorStyle.MPT,
369 | sep="<|im_end|>",
370 | )
371 |
372 | default_conversation = conv_vicuna_v1
373 | conv_templates = {
374 | "default": conv_vicuna_v0,
375 | "v0": conv_vicuna_v0,
376 | "v1": conv_vicuna_v1,
377 | "vicuna_v1": conv_vicuna_v1,
378 | "llama_2": conv_llama_2,
379 | "mistral_instruct": conv_mistral_instruct,
380 | "chatml_direct": conv_chatml_direct,
381 | "mistral_direct": conv_chatml_direct,
382 |
383 | "plain": conv_llava_plain,
384 | "v0_plain": conv_llava_plain,
385 | "llava_v0": conv_llava_v0,
386 | "v0_mmtag": conv_llava_v0_mmtag,
387 | "llava_v1": conv_llava_v1,
388 | "v1_mmtag": conv_llava_v1_mmtag,
389 | "llava_llama_2": conv_llava_llama_2,
390 |
391 | "mpt": conv_mpt,
392 | }
393 |
394 |
395 | if __name__ == "__main__":
396 | print(default_conversation.get_prompt())
397 |
--------------------------------------------------------------------------------
/llava/eval/model_utils.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from PIL import Image
3 | from decord import VideoReader, cpu
4 |
5 |
6 |
7 | def load_video(vis_path, n_clips=1, num_frm=100):
8 | """
9 | Load video frames from a video file.
10 |
11 | Parameters:
12 | vis_path (str): Path to the video file.
13 | n_clips (int): Number of clips to extract from the video. Defaults to 1.
14 | num_frm (int): Number of frames to extract from each clip. Defaults to 100.
15 |
16 | Returns:
17 | list: List of PIL.Image.Image objects representing video frames.
18 | """
19 |
20 | # decord.bridge.set_bridge('torch')
21 | # Load video with VideoReader
22 | vr = VideoReader(vis_path, ctx=cpu(0))
23 | total_frame_num = len(vr)
24 |
25 | # Currently, this function supports only 1 clip
26 | assert n_clips == 1
27 |
28 | # Calculate total number of frames to extract
29 | total_num_frm = min(total_frame_num, num_frm)
30 | # Get indices of frames to extract
31 | frame_idx = get_seq_frames(total_frame_num, total_num_frm)
32 | # Extract frames as numpy array
33 | img_array = vr.get_batch(frame_idx).asnumpy() # T H W C
34 |
35 | original_size = (img_array.shape[-2], img_array.shape[-3]) # (width, height)
36 | original_sizes = (original_size,) * total_num_frm
37 |
38 | clip_imgs = [Image.fromarray(img_array[j]) for j in range(total_num_frm)]
39 |
40 |
41 | return clip_imgs, original_sizes
42 |
43 |
44 |
45 |
46 | def get_seq_frames(total_num_frames, desired_num_frames):
47 | """
48 | Calculate the indices of frames to extract from a video.
49 |
50 | Parameters:
51 | total_num_frames (int): Total number of frames in the video.
52 | desired_num_frames (int): Desired number of frames to extract.
53 |
54 | Returns:
55 | list: List of indices of frames to extract.
56 | """
57 |
58 | # Calculate the size of each segment from which a frame will be extracted
59 | seg_size = float(total_num_frames - 1) / desired_num_frames
60 |
61 | seq = []
62 | for i in range(desired_num_frames):
63 | # Calculate the start and end indices of each segment
64 | start = int(np.round(seg_size * i))
65 | end = int(np.round(seg_size * (i + 1)))
66 |
67 | # Append the middle index of the segment to the list
68 | seq.append((start + end) // 2)
69 |
70 | return seq
71 |
--------------------------------------------------------------------------------
/llava/eval/run_inference_benchmark_consistency.py:
--------------------------------------------------------------------------------
1 | import math
2 | import os
3 | import argparse
4 | import json
5 |
6 | from tqdm import tqdm
7 | from llava.eval.model_utils import load_video
8 |
9 | from llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
10 | from llava.conversation import conv_templates, SeparatorStyle
11 | from llava.model.builder import load_pretrained_model
12 | from llava.utils import disable_torch_init
13 | from llava.mm_utils import tokenizer_image_token, process_images, get_model_name_from_path
14 |
15 | from PIL import Image
16 | import math
17 | import torch
18 | import time
19 |
20 | def llava_inference(video_frames, question, conv_mode, model, tokenizer, image_processor, image_sizes):
21 | if model.config.mm_use_im_start_end:
22 | qs = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN + '\n' + question
23 | else:
24 | qs = DEFAULT_IMAGE_TOKEN + '\n' + question
25 |
26 | conv = conv_templates[conv_mode].copy()
27 | conv.append_message(conv.roles[0], qs)
28 | conv.append_message(conv.roles[1], None)
29 | prompt = conv.get_prompt()
30 |
31 | input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).cuda()
32 | image_tensor = process_images(video_frames, image_processor, model.config)
33 |
34 | with torch.inference_mode():
35 | output_ids = model.generate(
36 | input_ids,
37 | images=image_tensor.to(dtype=torch.float16, device='cuda', non_blocking=True),
38 | image_sizes=image_sizes,
39 | do_sample=True if args.temperature > 0 else False,
40 | temperature=args.temperature,
41 | top_p=args.top_p,
42 | num_beams=args.num_beams,
43 | max_new_tokens=128,
44 | use_cache=True)
45 |
46 | outputs = tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0].strip()
47 | return outputs
48 |
49 | def split_list(lst, n):
50 | """Split a list into n (roughly) equal-sized chunks"""
51 |     chunk_size = math.ceil(len(lst) / n)  # ceiling division
52 | return [lst[i:i+chunk_size] for i in range(0, len(lst), chunk_size)]
53 |
54 | def get_chunk(lst, n, k):
55 | chunks = split_list(lst, n)
56 | return chunks[k]
57 |
58 | def parse_args():
59 | """
60 | Parse command-line arguments.
61 | """
62 | parser = argparse.ArgumentParser()
63 |
64 | # Define the command-line arguments
65 | parser.add_argument('--video_dir', help='Directory containing video files.', required=True)
66 | parser.add_argument('--gt_file', help='Path to the ground truth file.', required=True)
67 | parser.add_argument('--output_dir', help='Directory to save the model results JSON.', required=True)
68 | parser.add_argument('--output_name', help='Name of the file for storing results JSON.', required=True)
69 | parser.add_argument("--model_name", type=str, required=True)
70 | parser.add_argument("--conv-mode", type=str, required=False, default='video-chatgpt_v1')
71 | parser.add_argument("--num_chunks", type=int, default=1)
72 | parser.add_argument("--chunk_idx", type=int, default=0)
73 | parser.add_argument("--num_frames", type=int, default=100)
74 | parser.add_argument("--device", type=str, required=False, default='cuda:0')
75 | parser.add_argument("--model-base", type=str, default=None)
76 | parser.add_argument("--num_beams", type=int, default=1)
77 | parser.add_argument("--temperature", type=float, default=0.2)
78 | parser.add_argument("--top_p", type=float, default=None)
79 |
80 | return parser.parse_args()
81 |
82 |
83 | def run_inference(args):
84 | """
85 | Run inference on a set of video files using the provided model.
86 |
87 | Args:
88 | args: Command-line arguments.
89 | """
90 |
91 | disable_torch_init()
92 | model_path = os.path.expanduser(args.model_name)
93 | model_name = get_model_name_from_path(model_path)
94 | tokenizer, model, image_processor, context_len = load_pretrained_model(model_path, args.model_base, model_name)
95 |
96 |
97 | gt_contents = json.load(open(args.gt_file, "r"))
98 | gt_contents = get_chunk(gt_contents, args.num_chunks, args.chunk_idx)
99 |
100 | answers_file = os.path.join(args.output_dir, f"{args.output_name}.json")
101 | os.makedirs(args.output_dir, exist_ok=True)
102 | ans_file = open(answers_file, "w")
103 |
104 | # Create the output directory if it doesn't exist
105 | if not os.path.exists(args.output_dir):
106 | os.makedirs(args.output_dir)
107 |
108 | output_list = [] # List to store the output results
109 | conv_mode = args.conv_mode
110 |
111 | video_formats = ['.mp4', '.avi', '.mov', '.mkv']
112 |
113 | # Iterate over each sample in the ground truth file
114 | index = 0
115 | for sample in tqdm(gt_contents):
116 | video_name = sample['video_name']
117 | sample_set = sample
118 | question_1 = sample['Q1']
119 | question_2 = sample['Q2']
120 |
121 | # Load the video file
122 | for fmt in video_formats: # Added this line
123 | temp_path = os.path.join(args.video_dir, f"{video_name}{fmt}")
124 | if os.path.exists(temp_path):
125 | video_path = temp_path
126 | video_frames, sizes = load_video(video_path, num_frm=args.num_frames)
127 | # Run inference on the video for the first question and add the output to the list
128 | output_1 = llava_inference(video_frames, question_1, conv_mode, model,
129 | tokenizer, image_processor, sizes)
130 | sample_set['pred1'] = output_1
131 |
132 | # Run inference on the video for the second question and add the output to the list
133 | output_2 = llava_inference(video_frames, question_2, conv_mode, model,
134 | tokenizer, image_processor, sizes)
135 | sample_set['pred2'] = output_2
136 |
137 | output_list.append(sample_set)
138 | ans_file.write(json.dumps(sample_set) + "\n")
139 | index += 1
140 | break
141 |
142 | ans_file.close()
143 |
144 |
145 | if __name__ == "__main__":
146 | args = parse_args()
147 | run_inference(args)
148 |
--------------------------------------------------------------------------------
/llava/eval/run_inference_benchmark_general.py:
--------------------------------------------------------------------------------
1 | import math
2 | import os
3 | import argparse
4 | import json
5 |
6 | from tqdm import tqdm
7 | from llava.eval.model_utils import load_video
8 |
9 | from llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
10 | from llava.conversation import conv_templates, SeparatorStyle
11 | from llava.model.builder import load_pretrained_model
12 | from llava.utils import disable_torch_init
13 | from llava.mm_utils import tokenizer_image_token, process_images, get_model_name_from_path
14 |
15 | from PIL import Image
16 | import math
17 | import torch
18 | import time
19 |
20 | def llava_inference(video_frames, question, conv_mode, model, tokenizer, image_processor, image_sizes):
21 | if model.config.mm_use_im_start_end:
22 | qs = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN + '\n' + question
23 | else:
24 | qs = DEFAULT_IMAGE_TOKEN + '\n' + question
25 |
26 | conv = conv_templates[conv_mode].copy()
27 | conv.append_message(conv.roles[0], qs)
28 | conv.append_message(conv.roles[1], None)
29 | prompt = conv.get_prompt()
30 |
31 | input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).cuda()
32 | image_tensor = process_images(video_frames, image_processor, model.config)
33 |
34 | with torch.inference_mode():
35 | output_ids = model.generate(
36 | input_ids,
37 | images=image_tensor.to(dtype=torch.float16, device='cuda', non_blocking=True),
38 | image_sizes=image_sizes,
39 | do_sample=True if args.temperature > 0 else False,
40 | temperature=args.temperature,
41 | top_p=args.top_p,
42 | num_beams=args.num_beams,
43 | max_new_tokens=128,
44 | use_cache=True)
45 |
46 | outputs = tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0].strip()
47 | return outputs
48 |
49 | def split_list(lst, n):
50 | """Split a list into n (roughly) equal-sized chunks"""
51 |     chunk_size = math.ceil(len(lst) / n)  # ceiling division
52 | return [lst[i:i+chunk_size] for i in range(0, len(lst), chunk_size)]
53 |
54 | def get_chunk(lst, n, k):
55 | chunks = split_list(lst, n)
56 | return chunks[k]
57 |
58 | def parse_args():
59 | """
60 | Parse command-line arguments.
61 | """
62 | parser = argparse.ArgumentParser()
63 |
64 | # Define the command-line arguments
65 | parser.add_argument('--video_dir', help='Directory containing video files.', required=True)
66 | parser.add_argument('--gt_file', help='Path to the ground truth file.', required=True)
67 | parser.add_argument('--output_dir', help='Directory to save the model results JSON.', required=True)
68 | parser.add_argument('--output_name', help='Name of the file for storing results JSON.', required=True)
69 | parser.add_argument("--model_name", type=str, required=True)
70 | parser.add_argument("--conv-mode", type=str, required=False, default='video-chatgpt_v1')
71 | parser.add_argument("--num_chunks", type=int, default=1)
72 | parser.add_argument("--chunk_idx", type=int, default=0)
73 | parser.add_argument("--num_frames", type=int, default=100)
74 | parser.add_argument("--device", type=str, required=False, default='cuda:0')
75 | parser.add_argument("--model-base", type=str, default=None)
76 | parser.add_argument("--num_beams", type=int, default=1)
77 | parser.add_argument("--temperature", type=float, default=0.2)
78 | parser.add_argument("--top_p", type=float, default=None)
79 |
80 | return parser.parse_args()
81 |
82 |
83 | def run_inference(args):
84 | """
85 | Run inference on a set of video files using the provided model.
86 |
87 | Args:
88 | args: Command-line arguments.
89 | """
90 |
91 | disable_torch_init()
92 | model_path = os.path.expanduser(args.model_name)
93 | model_name = get_model_name_from_path(model_path)
94 | tokenizer, model, image_processor, context_len = load_pretrained_model(model_path, args.model_base, model_name)
95 |
96 |
97 | gt_contents = json.load(open(args.gt_file, "r"))
98 | gt_contents = get_chunk(gt_contents, args.num_chunks, args.chunk_idx)
99 |
100 | answers_file = os.path.join(args.output_dir, f"{args.output_name}.json")
101 | os.makedirs(args.output_dir, exist_ok=True)
102 | ans_file = open(answers_file, "w")
103 |
104 | # Create the output directory if it doesn't exist
105 | if not os.path.exists(args.output_dir):
106 | os.makedirs(args.output_dir)
107 |
108 | output_list = [] # List to store the output results
109 | conv_mode = args.conv_mode
110 |
111 | video_formats = ['.mp4', '.avi', '.mov', '.mkv']
112 |
113 | # Iterate over each sample in the ground truth file
114 | index = 0
115 | for sample in tqdm(gt_contents):
116 | video_name = sample['video_name']
117 | sample_set = sample
118 | question = sample['Q']
119 |
120 | # Load the video file
121 | for fmt in video_formats: # Added this line
122 | temp_path = os.path.join(args.video_dir, f"{video_name}{fmt}")
123 | if os.path.exists(temp_path):
124 | video_path = temp_path
125 | video_frames, sizes = load_video(video_path, num_frm=args.num_frames)
126 | # Run inference on the video and add the output to the list
127 | output = llava_inference(video_frames, question, conv_mode, model,
128 | tokenizer, image_processor, sizes)
129 | sample_set['pred'] = output
130 |
131 | output_list.append(sample_set)
132 | ans_file.write(json.dumps(sample_set) + "\n")
133 | index += 1
134 | break
135 |
136 | ans_file.close()
137 |
138 |
139 |
140 | if __name__ == "__main__":
141 | args = parse_args()
142 | run_inference(args)
143 |
--------------------------------------------------------------------------------
/llava/eval/run_inference_qa.py:
--------------------------------------------------------------------------------
1 | import math
2 | import os
3 | import argparse
4 | import json
5 |
6 | from tqdm import tqdm
7 | from llava.eval.model_utils import load_video
8 |
9 | from llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
10 | from llava.conversation import conv_templates, SeparatorStyle
11 | from llava.model.builder import load_pretrained_model
12 | from llava.utils import disable_torch_init
13 | from llava.mm_utils import tokenizer_image_token, process_images, get_model_name_from_path
14 |
15 | from PIL import Image
16 | import math
17 | import torch
18 | import time
19 |
20 | def llava_inference(video_frames, question, conv_mode, model, tokenizer, image_processor, image_sizes):
21 | if model.config.mm_use_im_start_end:
22 | qs = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN + '\n' + question
23 | else:
24 | qs = DEFAULT_IMAGE_TOKEN + '\n' + question
25 |
26 | conv = conv_templates[conv_mode].copy()
27 | conv.append_message(conv.roles[0], qs)
28 | conv.append_message(conv.roles[1], None)
29 | prompt = conv.get_prompt()
30 |
31 | input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).cuda()
32 | image_tensor = process_images(video_frames, image_processor, model.config)
33 |
34 | with torch.inference_mode():
35 | output_ids = model.generate(
36 | input_ids,
37 | images=image_tensor.to(dtype=torch.float16, device='cuda', non_blocking=True),
38 | image_sizes=image_sizes,
39 |             do_sample=True if args.temperature > 0 else False,  # note: reads the module-level args parsed in __main__
40 | temperature=args.temperature,
41 | top_p=args.top_p,
42 | num_beams=args.num_beams,
43 | max_new_tokens=128,
44 | use_cache=True)
45 |
46 | outputs = tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0].strip()
47 | return outputs
48 |
49 | def split_list(lst, n):
50 |     """Split a list into (roughly) n equal-sized chunks."""
51 |     chunk_size = math.ceil(len(lst) / n)  # ceiling division
52 | return [lst[i:i+chunk_size] for i in range(0, len(lst), chunk_size)]
53 |
54 | def get_chunk(lst, n, k):
55 | chunks = split_list(lst, n)
56 | return chunks[k]
57 |
58 | def parse_args():
59 | """
60 | Parse command-line arguments.
61 | """
62 | parser = argparse.ArgumentParser()
63 |
64 | # Define the command-line arguments
65 | parser.add_argument('--video_dir', help='Directory containing video files.', required=True)
66 | parser.add_argument('--gt_file_question', help='Path to the ground truth file containing question.', required=True)
67 | parser.add_argument('--gt_file_answers', help='Path to the ground truth file containing answers.', required=True)
68 | parser.add_argument('--output_dir', help='Directory to save the model results JSON.', required=True)
69 | parser.add_argument('--output_name', help='Name of the file for storing results JSON.', required=True)
70 | parser.add_argument("--model_name", type=str, required=True)
71 | parser.add_argument("--conv-mode", type=str, required=False, default='video-chatgpt_v1')
72 | parser.add_argument("--num_chunks", type=int, default=1)
73 | parser.add_argument("--chunk_idx", type=int, default=0)
74 | parser.add_argument("--num_frames", type=int, default=100)
75 | parser.add_argument("--device", type=str, required=False, default='cuda:0')
76 | parser.add_argument("--model-base", type=str, default=None)
77 | parser.add_argument("--num_beams", type=int, default=1)
78 | parser.add_argument("--temperature", type=float, default=0.2)
79 | parser.add_argument("--top_p", type=float, default=None)
80 |
81 | return parser.parse_args()
82 |
83 |
84 | def run_inference(args):
85 | """
86 |     Run inference on a Video QA dataset.
87 |
88 | Args:
89 | args: Command-line arguments.
90 | """
91 |
92 | disable_torch_init()
93 | model_path = os.path.expanduser(args.model_name)
94 | model_name = get_model_name_from_path(model_path)
95 | tokenizer, model, image_processor, context_len = load_pretrained_model(model_path, args.model_base, model_name)
96 |
97 |
98 | gt_questions = json.load(open(args.gt_file_question, "r"))
99 | gt_questions = get_chunk(gt_questions, args.num_chunks, args.chunk_idx)
100 | gt_answers = json.load(open(args.gt_file_answers, "r"))
101 | gt_answers = get_chunk(gt_answers, args.num_chunks, args.chunk_idx)
102 |
103 | answers_file = os.path.join(args.output_dir, f"{args.output_name}.json")
104 | os.makedirs(args.output_dir, exist_ok=True)
105 | ans_file = open(answers_file, "w")
106 |
107 | # Create the output directory if it doesn't exist
108 | if not os.path.exists(args.output_dir):
109 | os.makedirs(args.output_dir)
110 |
111 | output_list = [] # List to store the output results
112 | conv_mode = args.conv_mode
113 |
114 | video_formats = ['.mp4', '.avi', '.mov', '.mkv']
115 |
116 | # Iterate over each sample in the ground truth file
117 | index = 0
118 | for sample in tqdm(gt_questions):
119 | video_name = sample['video_name']
120 | question = sample['question']
121 |         question_id = sample['question_id']
122 |         answer = gt_answers[index]['answer']
123 |         index += 1
124 |
125 |         sample_set = {'id': question_id, 'question': question, 'answer': answer}
126 |
127 | # Load the video file
128 |         for fmt in video_formats:  # try each supported container extension
129 | vid_name = f"v_{video_name}" if 'Activitynet' in args.video_dir else video_name
130 | temp_path = os.path.join(args.video_dir, f"{vid_name}{fmt}")
131 |
132 | if os.path.exists(temp_path):
133 | # print(f'processing {idx}/{len(gt_questions)}')
134 | video_path = temp_path
135 | video_frames, sizes = load_video(video_path, num_frm=args.num_frames)
136 | # Run inference on the video and add the output to the list
137 | output = llava_inference(video_frames, question, conv_mode, model,
138 | tokenizer, image_processor, sizes)
139 | print(output)
140 | sample_set['pred'] = output
141 | output_list.append(sample_set)
142 | ans_file.write(json.dumps(sample_set) + "\n")
143 | break
144 |
145 | ans_file.close()
146 |
147 |
148 | if __name__ == "__main__":
149 | args = parse_args()
150 | run_inference(args)
151 |
--------------------------------------------------------------------------------
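
The --num_chunks / --chunk_idx flags shard the ground-truth list via split_list and get_chunk above, so several processes can each handle one slice. A small illustrative check of that sharding, using made-up data (not part of the repository):

    from llava.eval.run_inference_qa import split_list, get_chunk

    samples = list(range(10))
    print(split_list(samples, 3))    # [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9]]
    print(get_chunk(samples, 3, 1))  # [4, 5, 6, 7]
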
/llava/eval/single_video_inference.py:
--------------------------------------------------------------------------------
1 | import math
2 | import os
3 | import argparse
4 | import json
5 |
6 | from tqdm import tqdm
7 | from llava.eval.model_utils import load_video
8 |
9 | from llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
10 | from llava.conversation import conv_templates, SeparatorStyle
11 | from llava.model.builder import load_pretrained_model
12 | from llava.utils import disable_torch_init
13 | from llava.mm_utils import tokenizer_image_token, process_images, get_model_name_from_path
14 |
15 | from PIL import Image
16 | import math
17 | import torch
18 | import time
19 |
20 | def llava_inference(video_frames, question, conv_mode, model, tokenizer, image_processor, image_sizes):
21 | if model.config.mm_use_im_start_end:
22 | qs = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN + '\n' + question
23 | else:
24 | qs = DEFAULT_IMAGE_TOKEN + '\n' + question
25 |
26 | conv = conv_templates[conv_mode].copy()
27 | conv.append_message(conv.roles[0], qs)
28 | conv.append_message(conv.roles[1], None)
29 | prompt = conv.get_prompt()
30 |
31 | input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).cuda()
32 | image_tensor = process_images(video_frames, image_processor, model.config)
33 |
34 | with torch.inference_mode():
35 | output_ids = model.generate(
36 | input_ids,
37 | images=image_tensor.to(dtype=torch.float16, device='cuda', non_blocking=True),
38 | image_sizes=image_sizes,
39 | do_sample=False,
40 | temperature=0,
41 | top_p=None,
42 | num_beams=1,
43 | max_new_tokens=128,
44 | use_cache=True)
45 |
46 | outputs = tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0].strip()
47 | return outputs
48 |
49 | def parse_args():
50 | """
51 | Parse command-line arguments.
52 | """
53 | parser = argparse.ArgumentParser()
54 |
55 | # Define the command-line arguments
56 | parser.add_argument('--video_path', help='input video path', required=True)
57 | parser.add_argument("--model_name", type=str, required=True)
58 | parser.add_argument("--conv-mode", type=str, required=False, default='video-chatgpt_v1')
59 | parser.add_argument("--num_frames", type=int, default=4)
60 |
61 | return parser.parse_args()
62 |
63 |
64 | def run_inference(args):
65 | """
66 | Run inference
67 |
68 | Args:
69 | args: Command-line arguments.
70 | """
71 |
72 | disable_torch_init()
73 | model_path = os.path.expanduser(args.model_name)
74 | model_name = get_model_name_from_path(model_path)
75 | tokenizer, model, image_processor, context_len = load_pretrained_model(model_path, None, model_name)
76 |
77 | conv_mode = args.conv_mode
78 |
79 | video_path = args.video_path
80 |
81 | video_frames, sizes = load_video(video_path, num_frm=args.num_frames)
82 |
83 | question = "describe the video in detail"
84 |
85 | try:
86 | # Run inference on the video and add the output to the list
87 | output = llava_inference(video_frames, question, conv_mode, model,
88 | tokenizer, image_processor, sizes)
89 | print("\n\n", output)
90 |
91 | except Exception as e:
92 | print(f"Error processing video file '{video_path}': {e}")
93 |
94 |
95 | if __name__ == "__main__":
96 | args = parse_args()
97 | run_inference(args)
98 |
--------------------------------------------------------------------------------
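
single_video_inference.py is normally driven from the command line, but since run_inference only reads attributes from the parsed arguments, it can also be called programmatically. A minimal sketch with placeholder paths (the video and checkpoint locations are assumptions):

    from argparse import Namespace
    from llava.eval.single_video_inference import run_inference

    args = Namespace(
        video_path="demo.mp4",                    # placeholder video
        model_name="ckpt/your-llava-checkpoint",  # placeholder checkpoint directory
        conv_mode="video-chatgpt_v1",
        num_frames=4,
    )
    run_inference(args)
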
/llava/mm_utils.py:
--------------------------------------------------------------------------------
1 | from PIL import Image
2 | from io import BytesIO
3 | import base64
4 | import torch
5 | import math
6 | import ast
7 |
8 | from transformers import StoppingCriteria
9 | from llava.constants import IMAGE_TOKEN_INDEX
10 |
11 |
12 | def select_best_resolution(original_size, possible_resolutions):
13 | """
14 | Selects the best resolution from a list of possible resolutions based on the original size.
15 |
16 | Args:
17 | original_size (tuple): The original size of the image in the format (width, height).
18 | possible_resolutions (list): A list of possible resolutions in the format [(width1, height1), (width2, height2), ...].
19 |
20 | Returns:
21 | tuple: The best fit resolution in the format (width, height).
22 | """
23 | original_width, original_height = original_size
24 | best_fit = None
25 | max_effective_resolution = 0
26 | min_wasted_resolution = float('inf')
27 |
28 | for width, height in possible_resolutions:
29 | scale = min(width / original_width, height / original_height)
30 | downscaled_width, downscaled_height = int(original_width * scale), int(original_height * scale)
31 | effective_resolution = min(downscaled_width * downscaled_height, original_width * original_height)
32 | wasted_resolution = (width * height) - effective_resolution
33 |
34 | if effective_resolution > max_effective_resolution or (effective_resolution == max_effective_resolution and wasted_resolution < min_wasted_resolution):
35 | max_effective_resolution = effective_resolution
36 | min_wasted_resolution = wasted_resolution
37 | best_fit = (width, height)
38 |
39 | return best_fit
40 |
41 |
42 | def resize_and_pad_image(image, target_resolution):
43 | """
44 | Resize and pad an image to a target resolution while maintaining aspect ratio.
45 |
46 | Args:
47 | image (PIL.Image.Image): The input image.
48 | target_resolution (tuple): The target resolution (width, height) of the image.
49 |
50 | Returns:
51 | PIL.Image.Image: The resized and padded image.
52 | """
53 | original_width, original_height = image.size
54 | target_width, target_height = target_resolution
55 |
56 | scale_w = target_width / original_width
57 | scale_h = target_height / original_height
58 |
59 | if scale_w < scale_h:
60 | new_width = target_width
61 | new_height = min(math.ceil(original_height * scale_w), target_height)
62 | else:
63 | new_height = target_height
64 | new_width = min(math.ceil(original_width * scale_h), target_width)
65 |
66 | # Resize the image
67 | resized_image = image.resize((new_width, new_height))
68 |
69 | new_image = Image.new('RGB', (target_width, target_height), (0, 0, 0))
70 | paste_x = (target_width - new_width) // 2
71 | paste_y = (target_height - new_height) // 2
72 | new_image.paste(resized_image, (paste_x, paste_y))
73 |
74 | return new_image
75 |
76 |
77 | def divide_to_patches(image, patch_size):
78 | """
79 | Divides an image into patches of a specified size.
80 |
81 | Args:
82 | image (PIL.Image.Image): The input image.
83 | patch_size (int): The size of each patch.
84 |
85 | Returns:
86 | list: A list of PIL.Image.Image objects representing the patches.
87 | """
88 | patches = []
89 | width, height = image.size
90 | for i in range(0, height, patch_size):
91 | for j in range(0, width, patch_size):
92 | box = (j, i, j + patch_size, i + patch_size)
93 | patch = image.crop(box)
94 | patches.append(patch)
95 |
96 | return patches
97 |
98 |
99 | def get_anyres_image_grid_shape(image_size, grid_pinpoints, patch_size):
100 | """
101 | Calculate the shape of the image patch grid after the preprocessing for images of any resolution.
102 |
103 | Args:
104 | image_size (tuple): The size of the input image in the format (width, height).
105 | grid_pinpoints (str): A string representation of a list of possible resolutions.
106 | patch_size (int): The size of each image patch.
107 |
108 | Returns:
109 | tuple: The shape of the image patch grid in the format (width, height).
110 | """
111 | if type(grid_pinpoints) is list:
112 | possible_resolutions = grid_pinpoints
113 | else:
114 | possible_resolutions = ast.literal_eval(grid_pinpoints)
115 | width, height = select_best_resolution(image_size, possible_resolutions)
116 | return width // patch_size, height // patch_size
117 |
118 |
119 | def process_anyres_image(image, processor, grid_pinpoints):
120 | """
121 | Process an image with variable resolutions.
122 |
123 | Args:
124 | image (PIL.Image.Image): The input image to be processed.
125 | processor: The image processor object.
126 | grid_pinpoints (str): A string representation of a list of possible resolutions.
127 |
128 | Returns:
129 | torch.Tensor: A tensor containing the processed image patches.
130 | """
131 | if type(grid_pinpoints) is list:
132 | possible_resolutions = grid_pinpoints
133 | else:
134 | possible_resolutions = ast.literal_eval(grid_pinpoints)
135 | best_resolution = select_best_resolution(image.size, possible_resolutions)
136 | image_padded = resize_and_pad_image(image, best_resolution)
137 |
138 | patches = divide_to_patches(image_padded, processor.crop_size['height'])
139 |
140 | image_original_resize = image.resize((processor.size['shortest_edge'], processor.size['shortest_edge']))
141 |
142 | image_patches = [image_original_resize] + patches
143 | image_patches = [processor.preprocess(image_patch, return_tensors='pt')['pixel_values'][0]
144 | for image_patch in image_patches]
145 | return torch.stack(image_patches, dim=0)
146 |
147 |
148 | def load_image_from_base64(image):
149 | return Image.open(BytesIO(base64.b64decode(image)))
150 |
151 |
152 | def expand2square(pil_img, background_color):
153 | width, height = pil_img.size
154 | if width == height:
155 | return pil_img
156 | elif width > height:
157 | result = Image.new(pil_img.mode, (width, width), background_color)
158 | result.paste(pil_img, (0, (width - height) // 2))
159 | return result
160 | else:
161 | result = Image.new(pil_img.mode, (height, height), background_color)
162 | result.paste(pil_img, ((height - width) // 2, 0))
163 | return result
164 |
165 |
166 | def process_images(images, image_processor, model_cfg):
167 | image_aspect_ratio = getattr(model_cfg, "image_aspect_ratio", None)
168 | new_images = []
169 | if image_aspect_ratio == 'pad':
170 | for image in images:
171 | image = expand2square(image, tuple(int(x*255) for x in image_processor.image_mean))
172 | image = image_processor.preprocess(image, return_tensors='pt')['pixel_values'][0]
173 | new_images.append(image)
174 | elif image_aspect_ratio == "anyres":
175 | for image in images:
176 | image = process_anyres_image(image, image_processor, model_cfg.image_grid_pinpoints)
177 | new_images.append(image)
178 | else:
179 | return image_processor(images, return_tensors='pt')['pixel_values']
180 | if all(x.shape == new_images[0].shape for x in new_images):
181 | new_images = torch.stack(new_images, dim=0)
182 |
183 | return new_images
184 |
185 |
186 | def tokenizer_image_token(prompt, tokenizer, image_token_index=IMAGE_TOKEN_INDEX, return_tensors=None):
187 |     prompt_chunks = [tokenizer(chunk).input_ids for chunk in prompt.split('<image>')]  # split on DEFAULT_IMAGE_TOKEN
188 |
189 | def insert_separator(X, sep):
190 | return [ele for sublist in zip(X, [sep]*len(X)) for ele in sublist][:-1]
191 |
192 | input_ids = []
193 | offset = 0
194 | if len(prompt_chunks) > 0 and len(prompt_chunks[0]) > 0 and prompt_chunks[0][0] == tokenizer.bos_token_id:
195 | offset = 1
196 | input_ids.append(prompt_chunks[0][0])
197 |
198 | for x in insert_separator(prompt_chunks, [image_token_index] * (offset + 1)):
199 | input_ids.extend(x[offset:])
200 |
201 | if return_tensors is not None:
202 | if return_tensors == 'pt':
203 | return torch.tensor(input_ids, dtype=torch.long)
204 | raise ValueError(f'Unsupported tensor type: {return_tensors}')
205 | return input_ids
206 |
207 |
208 | def get_model_name_from_path(model_path):
209 | model_path = model_path.strip("/")
210 | model_paths = model_path.split("/")
211 | if model_paths[-1].startswith('checkpoint-'):
212 | return model_paths[-2] + "_" + model_paths[-1]
213 | else:
214 | return model_paths[-1]
215 |
216 | class KeywordsStoppingCriteria(StoppingCriteria):
217 | def __init__(self, keywords, tokenizer, input_ids):
218 | self.keywords = keywords
219 | self.keyword_ids = []
220 | self.max_keyword_len = 0
221 | for keyword in keywords:
222 | cur_keyword_ids = tokenizer(keyword).input_ids
223 | if len(cur_keyword_ids) > 1 and cur_keyword_ids[0] == tokenizer.bos_token_id:
224 | cur_keyword_ids = cur_keyword_ids[1:]
225 | if len(cur_keyword_ids) > self.max_keyword_len:
226 | self.max_keyword_len = len(cur_keyword_ids)
227 | self.keyword_ids.append(torch.tensor(cur_keyword_ids))
228 | self.tokenizer = tokenizer
229 | self.start_len = input_ids.shape[1]
230 |
231 | def call_for_batch(self, output_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
232 | offset = min(output_ids.shape[1] - self.start_len, self.max_keyword_len)
233 | self.keyword_ids = [keyword_id.to(output_ids.device) for keyword_id in self.keyword_ids]
234 | for keyword_id in self.keyword_ids:
235 | truncated_output_ids = output_ids[0, -keyword_id.shape[0]:]
236 | if torch.equal(truncated_output_ids, keyword_id):
237 | return True
238 | outputs = self.tokenizer.batch_decode(output_ids[:, -offset:], skip_special_tokens=True)[0]
239 | for keyword in self.keywords:
240 | if keyword in outputs:
241 | return True
242 | return False
243 |
244 | def __call__(self, output_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
245 | outputs = []
246 | for i in range(output_ids.shape[0]):
247 | outputs.append(self.call_for_batch(output_ids[i].unsqueeze(0), scores))
248 | return all(outputs)
249 |
--------------------------------------------------------------------------------
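
A small illustrative check (not part of the repository) of how select_best_resolution and get_anyres_image_grid_shape behave, using a made-up candidate list; the expected values follow directly from the definitions above:

    from llava.mm_utils import select_best_resolution, get_anyres_image_grid_shape

    candidates = [(336, 672), (672, 336), (672, 672), (1008, 336), (336, 1008)]
    print(select_best_resolution((640, 480), candidates))            # (672, 672)
    print(get_anyres_image_grid_shape((640, 480), candidates, 336))  # (2, 2)
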
/llava/model/__init__.py:
--------------------------------------------------------------------------------
1 | try:
2 | from .language_model.llava_llama import LlavaLlamaForCausalLM, LlavaConfig
3 | from .language_model.llava_mpt import LlavaMptForCausalLM, LlavaMptConfig
4 | from .language_model.llava_mistral import LlavaMistralForCausalLM, LlavaMistralConfig
5 | except:  # optional backbones (MPT / Mistral) may be unavailable in the installed transformers
6 |     pass
7 |
--------------------------------------------------------------------------------
/llava/model/apply_delta.py:
--------------------------------------------------------------------------------
1 | """
2 | Usage:
3 | python3 -m fastchat.model.apply_delta --base ~/model_weights/llama-7b --target ~/model_weights/vicuna-7b --delta lmsys/vicuna-7b-delta
4 | """
5 | import argparse
6 |
7 | import torch
8 | from tqdm import tqdm
9 | from transformers import AutoTokenizer, AutoModelForCausalLM
10 | from llava import LlavaLlamaForCausalLM
11 |
12 |
13 | def apply_delta(base_model_path, target_model_path, delta_path):
14 | print("Loading base model")
15 | base = AutoModelForCausalLM.from_pretrained(
16 | base_model_path, torch_dtype=torch.float16, low_cpu_mem_usage=True)
17 |
18 | print("Loading delta")
19 | delta = LlavaLlamaForCausalLM.from_pretrained(delta_path, torch_dtype=torch.float16, low_cpu_mem_usage=True)
20 | delta_tokenizer = AutoTokenizer.from_pretrained(delta_path)
21 |
22 | print("Applying delta")
23 | for name, param in tqdm(delta.state_dict().items(), desc="Applying delta"):
24 | if name not in base.state_dict():
25 | assert name in ['model.mm_projector.weight', 'model.mm_projector.bias'], f'{name} not in base model'
26 | continue
27 | if param.data.shape == base.state_dict()[name].shape:
28 | param.data += base.state_dict()[name]
29 | else:
30 | assert name in ['model.embed_tokens.weight', 'lm_head.weight'], \
31 | f'{name} dimension mismatch: {param.data.shape} vs {base.state_dict()[name].shape}'
32 | bparam = base.state_dict()[name]
33 | param.data[:bparam.shape[0], :bparam.shape[1]] += bparam
34 |
35 | print("Saving target model")
36 | delta.save_pretrained(target_model_path)
37 | delta_tokenizer.save_pretrained(target_model_path)
38 |
39 |
40 | if __name__ == "__main__":
41 | parser = argparse.ArgumentParser()
42 | parser.add_argument("--base-model-path", type=str, required=True)
43 | parser.add_argument("--target-model-path", type=str, required=True)
44 | parser.add_argument("--delta-path", type=str, required=True)
45 |
46 | args = parser.parse_args()
47 |
48 | apply_delta(args.base_model_path, args.target_model_path, args.delta_path)
49 |
--------------------------------------------------------------------------------
/llava/model/builder.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023 Haotian Liu
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 |
16 | import os
17 | import warnings
18 | import shutil
19 |
20 | from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig, BitsAndBytesConfig
21 | import torch
22 | from llava.model import *
23 | from llava.constants import DEFAULT_IMAGE_PATCH_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
24 |
25 |
26 | def load_pretrained_model(model_path, model_base, model_name, load_8bit=False, load_4bit=False, device_map="auto", device="cuda", use_flash_attn=False, **kwargs):
27 | kwargs = {"device_map": device_map, **kwargs}
28 |
29 | if device != "cuda":
30 | kwargs['device_map'] = {"": device}
31 |
32 | if load_8bit:
33 | kwargs['load_in_8bit'] = True
34 | elif load_4bit:
35 | kwargs['load_in_4bit'] = True
36 | kwargs['quantization_config'] = BitsAndBytesConfig(
37 | load_in_4bit=True,
38 | bnb_4bit_compute_dtype=torch.float16,
39 | bnb_4bit_use_double_quant=True,
40 | bnb_4bit_quant_type='nf4'
41 | )
42 | else:
43 | kwargs['torch_dtype'] = torch.float16
44 |
45 | if use_flash_attn:
46 | kwargs['attn_implementation'] = 'flash_attention_2'
47 |
48 | if 'llava' in model_name.lower():
49 | # Load LLaVA model
50 | if 'lora' in model_name.lower() and model_base is None:
51 | warnings.warn('There is `lora` in model name but no `model_base` is provided. If you are loading a LoRA model, please provide the `model_base` argument. Detailed instruction: https://github.com/haotian-liu/LLaVA#launch-a-model-worker-lora-weights-unmerged.')
52 | if 'lora' in model_name.lower() and model_base is not None:
53 | from llava.model.language_model.llava_llama import LlavaConfig
54 | lora_cfg_pretrained = LlavaConfig.from_pretrained(model_path)
55 | tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=False)
56 | print('Loading LLaVA from base model...')
57 | model = LlavaLlamaForCausalLM.from_pretrained(model_base, low_cpu_mem_usage=True, config=lora_cfg_pretrained, **kwargs)
58 |             token_num, token_dim = model.lm_head.out_features, model.lm_head.in_features
59 |             if model.lm_head.weight.shape[0] != token_num:
60 |                 model.lm_head.weight = torch.nn.Parameter(torch.empty(token_num, token_dim, device=model.device, dtype=model.dtype))
61 |                 model.model.embed_tokens.weight = torch.nn.Parameter(torch.empty(token_num, token_dim, device=model.device, dtype=model.dtype))
62 |
63 | print('Loading additional LLaVA weights...')
64 | if os.path.exists(os.path.join(model_path, 'non_lora_trainables.bin')):
65 | non_lora_trainables = torch.load(os.path.join(model_path, 'non_lora_trainables.bin'), map_location='cpu')
66 | else:
67 | # this is probably from HF Hub
68 | from huggingface_hub import hf_hub_download
69 | def load_from_hf(repo_id, filename, subfolder=None):
70 | cache_file = hf_hub_download(
71 | repo_id=repo_id,
72 | filename=filename,
73 | subfolder=subfolder)
74 | return torch.load(cache_file, map_location='cpu')
75 | non_lora_trainables = load_from_hf(model_path, 'non_lora_trainables.bin')
76 | non_lora_trainables = {(k[11:] if k.startswith('base_model.') else k): v for k, v in non_lora_trainables.items()}
77 | if any(k.startswith('model.model.') for k in non_lora_trainables):
78 | non_lora_trainables = {(k[6:] if k.startswith('model.') else k): v for k, v in non_lora_trainables.items()}
79 | model.load_state_dict(non_lora_trainables, strict=False)
80 |
81 | from peft import PeftModel
82 | print('Loading LoRA weights...')
83 | model = PeftModel.from_pretrained(model, model_path)
84 | print('Merging LoRA weights...')
85 | model = model.merge_and_unload()
86 | print('Model is loaded...')
87 | elif model_base is not None:
88 | # this may be mm projector only
89 | print('Loading LLaVA from base model...')
90 | if 'mpt' in model_name.lower():
91 | if not os.path.isfile(os.path.join(model_path, 'configuration_mpt.py')):
92 | shutil.copyfile(os.path.join(model_base, 'configuration_mpt.py'), os.path.join(model_path, 'configuration_mpt.py'))
93 | tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=True)
94 | cfg_pretrained = AutoConfig.from_pretrained(model_path, trust_remote_code=True)
95 | model = LlavaMptForCausalLM.from_pretrained(model_base, low_cpu_mem_usage=True, config=cfg_pretrained, **kwargs)
96 | else:
97 | tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=False)
98 | cfg_pretrained = AutoConfig.from_pretrained(model_path)
99 | model = LlavaLlamaForCausalLM.from_pretrained(model_base, low_cpu_mem_usage=True, config=cfg_pretrained, **kwargs)
100 |
101 | mm_projector_weights = torch.load(os.path.join(model_path, 'mm_projector.bin'), map_location='cpu')
102 | mm_projector_weights = {k: v.to(torch.float16) for k, v in mm_projector_weights.items()}
103 | model.load_state_dict(mm_projector_weights, strict=False)
104 | else:
105 | if 'mpt' in model_name.lower():
106 | tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=True)
107 | model = LlavaMptForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, **kwargs)
108 | elif 'mistral' in model_name.lower():
109 | tokenizer = AutoTokenizer.from_pretrained(model_path)
110 | model = LlavaMistralForCausalLM.from_pretrained(
111 | model_path,
112 | low_cpu_mem_usage=True,
113 | **kwargs
114 | )
115 | else:
116 | tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False)
117 | model = LlavaLlamaForCausalLM.from_pretrained(
118 | model_path,
119 | low_cpu_mem_usage=True,
120 | **kwargs
121 | )
122 | else:
123 | # Load language model
124 | if model_base is not None:
125 | # PEFT model
126 | from peft import PeftModel
127 | tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=False)
128 | model = AutoModelForCausalLM.from_pretrained(model_base, low_cpu_mem_usage=True, **kwargs)
129 | print(f"Loading LoRA weights from {model_path}")
130 | model = PeftModel.from_pretrained(model, model_path)
131 |             print("Merging weights")
132 | model = model.merge_and_unload()
133 | print('Convert to FP16...')
134 | model.to(torch.float16)
135 | else:
136 | use_fast = False
137 | if 'mpt' in model_name.lower():
138 | tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=True)
139 | model = AutoModelForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, trust_remote_code=True, **kwargs)
140 | else:
141 | tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False)
142 | model = AutoModelForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, **kwargs)
143 |
144 | image_processor = None
145 |
146 | if 'llava' in model_name.lower():
147 | mm_use_im_start_end = getattr(model.config, "mm_use_im_start_end", False)
148 | mm_use_im_patch_token = getattr(model.config, "mm_use_im_patch_token", True)
149 | if mm_use_im_patch_token:
150 | tokenizer.add_tokens([DEFAULT_IMAGE_PATCH_TOKEN], special_tokens=True)
151 | if mm_use_im_start_end:
152 | tokenizer.add_tokens([DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True)
153 | model.resize_token_embeddings(len(tokenizer))
154 |
155 | vision_tower = model.get_vision_tower()
156 | if not vision_tower.is_loaded:
157 | vision_tower.load_model(device_map=device_map)
158 | if device_map != 'auto':
159 | vision_tower.to(device=device_map, dtype=torch.float16)
160 | # vision_tower.to(device=device, dtype=torch.float16)
161 | image_processor = vision_tower.image_processor
162 |
163 | if hasattr(model.config, "max_sequence_length"):
164 | context_len = model.config.max_sequence_length
165 | else:
166 | context_len = 2048
167 |
168 | return tokenizer, model, image_processor, context_len
169 |
--------------------------------------------------------------------------------
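
For reference, the eval scripts earlier in this repository reach load_pretrained_model roughly as follows; the checkpoint path here is a placeholder:

    from llava.mm_utils import get_model_name_from_path
    from llava.model.builder import load_pretrained_model

    model_path = "ckpt/your-llava-checkpoint"  # placeholder
    model_name = get_model_name_from_path(model_path)
    tokenizer, model, image_processor, context_len = load_pretrained_model(
        model_path, model_base=None, model_name=model_name)
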
/llava/model/consolidate.py:
--------------------------------------------------------------------------------
1 | """
2 | Usage:
3 | python3 -m llava.model.consolidate --src ~/model_weights/llava-7b --dst ~/model_weights/llava-7b_consolidate
4 | """
5 | import argparse
6 |
7 | import torch
8 | from transformers import AutoTokenizer, AutoModelForCausalLM
9 | from llava.model import *
10 | from llava.model.utils import auto_upgrade
11 |
12 |
13 | def consolidate_ckpt(src_path, dst_path):
14 | print("Loading model")
15 | auto_upgrade(src_path)
16 | src_model = AutoModelForCausalLM.from_pretrained(src_path, torch_dtype=torch.float16, low_cpu_mem_usage=True)
17 | src_tokenizer = AutoTokenizer.from_pretrained(src_path, use_fast=False)
18 | src_model.save_pretrained(dst_path)
19 | src_tokenizer.save_pretrained(dst_path)
20 |
21 |
22 | if __name__ == "__main__":
23 | parser = argparse.ArgumentParser()
24 | parser.add_argument("--src", type=str, required=True)
25 | parser.add_argument("--dst", type=str, required=True)
26 |
27 | args = parser.parse_args()
28 |
29 | consolidate_ckpt(args.src, args.dst)
30 |
--------------------------------------------------------------------------------
/llava/model/language_model/llava_llama.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023 Haotian Liu
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 |
16 | from typing import List, Optional, Tuple, Union
17 |
18 | import torch
19 | import torch.nn as nn
20 |
21 | from transformers import AutoConfig, AutoModelForCausalLM, \
22 | LlamaConfig, LlamaModel, LlamaForCausalLM
23 |
24 | from transformers.modeling_outputs import CausalLMOutputWithPast
25 | from transformers.generation.utils import GenerateOutput
26 |
27 | from ..llava_arch import LlavaMetaModel, LlavaMetaForCausalLM
28 |
29 |
30 | class LlavaConfig(LlamaConfig):
31 | model_type = "llava_llama"
32 |
33 |
34 | class LlavaLlamaModel(LlavaMetaModel, LlamaModel):
35 | config_class = LlavaConfig
36 |
37 | def __init__(self, config: LlamaConfig):
38 | super(LlavaLlamaModel, self).__init__(config)
39 |
40 |
41 | class LlavaLlamaForCausalLM(LlamaForCausalLM, LlavaMetaForCausalLM):
42 | config_class = LlavaConfig
43 |
44 | def __init__(self, config):
45 | super(LlamaForCausalLM, self).__init__(config)
46 | self.model = LlavaLlamaModel(config)
47 | self.pretraining_tp = config.pretraining_tp
48 | self.vocab_size = config.vocab_size
49 | self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
50 |
51 | # Initialize weights and apply final processing
52 | self.post_init()
53 |
54 | def get_model(self):
55 | return self.model
56 |
57 | def forward(
58 | self,
59 | input_ids: torch.LongTensor = None,
60 | attention_mask: Optional[torch.Tensor] = None,
61 | position_ids: Optional[torch.LongTensor] = None,
62 | past_key_values: Optional[List[torch.FloatTensor]] = None,
63 | inputs_embeds: Optional[torch.FloatTensor] = None,
64 | labels: Optional[torch.LongTensor] = None,
65 | use_cache: Optional[bool] = None,
66 | output_attentions: Optional[bool] = None,
67 | output_hidden_states: Optional[bool] = None,
68 | images: Optional[torch.FloatTensor] = None,
69 | image_sizes: Optional[List[List[int]]] = None,
70 | return_dict: Optional[bool] = None,
71 | ) -> Union[Tuple, CausalLMOutputWithPast]:
72 |
73 | if inputs_embeds is None:
74 | (
75 | input_ids,
76 | position_ids,
77 | attention_mask,
78 | past_key_values,
79 | inputs_embeds,
80 | labels
81 | ) = self.prepare_inputs_labels_for_multimodal(
82 | input_ids,
83 | position_ids,
84 | attention_mask,
85 | past_key_values,
86 | labels,
87 | images,
88 | image_sizes,
89 | )
90 |
91 | return super().forward(
92 | input_ids=input_ids,
93 | attention_mask=attention_mask,
94 | position_ids=position_ids,
95 | past_key_values=past_key_values,
96 | inputs_embeds=inputs_embeds,
97 | labels=labels,
98 | use_cache=use_cache,
99 | output_attentions=output_attentions,
100 | output_hidden_states=output_hidden_states,
101 | return_dict=return_dict
102 | )
103 |
104 | @torch.no_grad()
105 | def generate(
106 | self,
107 | inputs: Optional[torch.Tensor] = None,
108 | images: Optional[torch.Tensor] = None,
109 | image_sizes: Optional[torch.Tensor] = None,
110 | **kwargs,
111 | ) -> Union[GenerateOutput, torch.LongTensor]:
112 | position_ids = kwargs.pop("position_ids", None)
113 | attention_mask = kwargs.pop("attention_mask", None)
114 | if "inputs_embeds" in kwargs:
115 | raise NotImplementedError("`inputs_embeds` is not supported")
116 |
117 | if images is not None:
118 | (
119 | inputs,
120 | position_ids,
121 | attention_mask,
122 | _,
123 | inputs_embeds,
124 | _
125 | ) = self.prepare_inputs_labels_for_multimodal(
126 | inputs,
127 | position_ids,
128 | attention_mask,
129 | None,
130 | None,
131 | images,
132 | image_sizes=image_sizes,
133 | )
134 | else:
135 | inputs_embeds = self.get_model().embed_tokens(inputs)
136 |
137 | return super().generate(
138 | position_ids=position_ids,
139 | attention_mask=attention_mask,
140 | inputs_embeds=inputs_embeds,
141 | **kwargs
142 | )
143 |
144 | def prepare_inputs_for_generation(self, input_ids, past_key_values=None,
145 | inputs_embeds=None, **kwargs):
146 | images = kwargs.pop("images", None)
147 | image_sizes = kwargs.pop("image_sizes", None)
148 | inputs = super().prepare_inputs_for_generation(
149 | input_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, **kwargs
150 | )
151 | if images is not None:
152 | inputs['images'] = images
153 | if image_sizes is not None:
154 | inputs['image_sizes'] = image_sizes
155 |
156 | return inputs
157 |
158 | AutoConfig.register("llava_llama", LlavaConfig)
159 | AutoModelForCausalLM.register(LlavaConfig, LlavaLlamaForCausalLM)
160 |
--------------------------------------------------------------------------------
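
Because LlavaConfig and LlavaLlamaForCausalLM are registered with the Auto classes at the bottom of the file, a checkpoint whose config declares model_type == "llava_llama" can be resolved generically through transformers once this module has been imported. A minimal sketch with a placeholder checkpoint path:

    import llava.model.language_model.llava_llama  # noqa: F401  (runs the AutoConfig / AutoModel registrations above)
    from transformers import AutoConfig, AutoModelForCausalLM

    cfg = AutoConfig.from_pretrained("ckpt/your-llava-checkpoint")               # placeholder; resolves to LlavaConfig
    model = AutoModelForCausalLM.from_pretrained("ckpt/your-llava-checkpoint")   # resolves to LlavaLlamaForCausalLM
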
/llava/model/language_model/llava_mistral.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023 Haotian Liu
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 |
16 | from typing import List, Optional, Tuple, Union
17 |
18 | import torch
19 | import torch.nn as nn
20 | from torch.nn import CrossEntropyLoss
21 |
22 | from transformers import AutoConfig, AutoModelForCausalLM, \
23 | MistralConfig, MistralModel, MistralForCausalLM
24 |
25 | from transformers.modeling_outputs import CausalLMOutputWithPast
26 | from transformers.generation.utils import GenerateOutput
27 |
28 | from ..llava_arch import LlavaMetaModel, LlavaMetaForCausalLM
29 |
30 |
31 | class LlavaMistralConfig(MistralConfig):
32 | model_type = "llava_mistral"
33 |
34 |
35 | class LlavaMistralModel(LlavaMetaModel, MistralModel):
36 | config_class = LlavaMistralConfig
37 |
38 | def __init__(self, config: MistralConfig):
39 | super(LlavaMistralModel, self).__init__(config)
40 |
41 |
42 | class LlavaMistralForCausalLM(MistralForCausalLM, LlavaMetaForCausalLM):
43 | config_class = LlavaMistralConfig
44 |
45 | def __init__(self, config):
46 | super(MistralForCausalLM, self).__init__(config)
47 | self.model = LlavaMistralModel(config)
48 |
49 | self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
50 |
51 | # Initialize weights and apply final processing
52 | self.post_init()
53 |
54 | def get_model(self):
55 | return self.model
56 |
57 | def forward(
58 | self,
59 | input_ids: torch.LongTensor = None,
60 | attention_mask: Optional[torch.Tensor] = None,
61 | position_ids: Optional[torch.LongTensor] = None,
62 | past_key_values: Optional[List[torch.FloatTensor]] = None,
63 | inputs_embeds: Optional[torch.FloatTensor] = None,
64 | labels: Optional[torch.LongTensor] = None,
65 | use_cache: Optional[bool] = None,
66 | output_attentions: Optional[bool] = None,
67 | output_hidden_states: Optional[bool] = None,
68 | images: Optional[torch.FloatTensor] = None,
69 | image_sizes: Optional[List[List[int]]] = None,
70 | return_dict: Optional[bool] = None,
71 | ) -> Union[Tuple, CausalLMOutputWithPast]:
72 |
73 | if inputs_embeds is None:
74 | (
75 | input_ids,
76 | position_ids,
77 | attention_mask,
78 | past_key_values,
79 | inputs_embeds,
80 | labels
81 | ) = self.prepare_inputs_labels_for_multimodal(
82 | input_ids,
83 | position_ids,
84 | attention_mask,
85 | past_key_values,
86 | labels,
87 | images,
88 | image_sizes
89 | )
90 |
91 | return super().forward(
92 | input_ids=input_ids,
93 | attention_mask=attention_mask,
94 | position_ids=position_ids,
95 | past_key_values=past_key_values,
96 | inputs_embeds=inputs_embeds,
97 | labels=labels,
98 | use_cache=use_cache,
99 | output_attentions=output_attentions,
100 | output_hidden_states=output_hidden_states,
101 | return_dict=return_dict
102 | )
103 |
104 | @torch.no_grad()
105 | def generate(
106 | self,
107 | inputs: Optional[torch.Tensor] = None,
108 | images: Optional[torch.Tensor] = None,
109 | image_sizes: Optional[torch.Tensor] = None,
110 | **kwargs,
111 | ) -> Union[GenerateOutput, torch.LongTensor]:
112 | position_ids = kwargs.pop("position_ids", None)
113 | attention_mask = kwargs.pop("attention_mask", None)
114 | if "inputs_embeds" in kwargs:
115 | raise NotImplementedError("`inputs_embeds` is not supported")
116 |
117 | if images is not None:
118 | (
119 | inputs,
120 | position_ids,
121 | attention_mask,
122 | _,
123 | inputs_embeds,
124 | _
125 | ) = self.prepare_inputs_labels_for_multimodal(
126 | inputs,
127 | position_ids,
128 | attention_mask,
129 | None,
130 | None,
131 | images,
132 | image_sizes=image_sizes
133 | )
134 | else:
135 | inputs_embeds = self.get_model().embed_tokens(inputs)
136 |
137 | return super().generate(
138 | position_ids=position_ids,
139 | attention_mask=attention_mask,
140 | inputs_embeds=inputs_embeds,
141 | **kwargs
142 | )
143 |
144 | def prepare_inputs_for_generation(self, input_ids, past_key_values=None,
145 | inputs_embeds=None, **kwargs):
146 | images = kwargs.pop("images", None)
147 | image_sizes = kwargs.pop("image_sizes", None)
148 | inputs = super().prepare_inputs_for_generation(
149 | input_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, **kwargs
150 | )
151 | if images is not None:
152 | inputs['images'] = images
153 | if image_sizes is not None:
154 | inputs['image_sizes'] = image_sizes
155 | return inputs
156 |
157 | AutoConfig.register("llava_mistral", LlavaMistralConfig)
158 | AutoModelForCausalLM.register(LlavaMistralConfig, LlavaMistralForCausalLM)
159 |
--------------------------------------------------------------------------------
/llava/model/language_model/llava_mpt.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023 Haotian Liu
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 |
16 | from typing import Optional, Tuple
17 |
18 | import torch
19 |
20 | from transformers import AutoConfig, AutoModelForCausalLM, \
21 | MptConfig, MptForCausalLM, MptModel
22 | from llava.model.llava_arch import LlavaMetaModel, LlavaMetaForCausalLM
23 |
24 |
25 | class LlavaMptConfig(MptConfig):
26 | model_type = "llava_mpt"
27 |
28 |
29 | class LlavaMptModel(LlavaMetaModel, MptModel):
30 | config_class = LlavaMptConfig
31 |
32 | def __init__(self, config: MptConfig):
33 | config.hidden_size = config.d_model
34 | super(LlavaMptModel, self).__init__(config)
35 |
36 | def embed_tokens(self, x):
37 | return self.wte(x)
38 |
39 |
40 | class LlavaMptForCausalLM(MptForCausalLM, LlavaMetaForCausalLM):
41 | config_class = LlavaMptConfig
42 | supports_gradient_checkpointing = True
43 |
44 | def __init__(self, config):
45 | super(MptForCausalLM, self).__init__(config)
46 |
47 | self.transformer = LlavaMptModel(config)
48 | self.lm_head = torch.nn.Linear(config.hidden_size, config.vocab_size, bias=False)
49 |
50 | # Initialize weights and apply final processing
51 | self.post_init()
52 |
53 | def get_model(self):
54 | return self.transformer
55 |
56 | def _set_gradient_checkpointing(self, module, value=False):
57 | if isinstance(module, LlavaMptModel):
58 | module.gradient_checkpointing = value
59 |
60 | def forward(
61 | self,
62 | input_ids: Optional[torch.LongTensor] = None,
63 | past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None,
64 | attention_mask: Optional[torch.Tensor] = None,
65 | inputs_embeds: Optional[torch.Tensor] = None,
66 | labels: Optional[torch.Tensor] = None,
67 | use_cache: Optional[bool] = None,
68 | output_attentions: Optional[bool] = None,
69 | output_hidden_states: Optional[bool] = None,
70 | return_dict: Optional[bool] = None,
71 | images=None):
72 |
73 | input_ids, attention_mask, past_key_values, inputs_embeds, labels = self.prepare_inputs_labels_for_multimodal(input_ids, attention_mask, past_key_values, labels, images)
74 |
75 | return super().forward(
76 | input_ids,
77 | past_key_values=past_key_values,
78 | attention_mask=attention_mask,
79 | inputs_embeds=inputs_embeds,
80 | labels=labels,
81 | use_cache=use_cache,
82 | output_attentions=output_attentions,
83 | output_hidden_states=output_hidden_states,
84 | return_dict=return_dict,
85 | )
86 |
87 | def prepare_inputs_for_generation(self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs):
88 | images = kwargs.pop("images", None)
89 | _inputs = super().prepare_inputs_for_generation(
90 | input_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, **kwargs
91 | )
92 | _inputs['images'] = images
93 | return _inputs
94 |
95 |
96 | AutoConfig.register("llava_mpt", LlavaMptConfig)
97 | AutoModelForCausalLM.register(LlavaMptConfig, LlavaMptForCausalLM)
98 |
--------------------------------------------------------------------------------
/llava/model/llava_arch.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023 Haotian Liu
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 |
16 | from abc import ABC, abstractmethod
17 |
18 | import torch
19 | import torch.nn as nn
20 |
21 | from .multimodal_encoder.builder import build_vision_tower
22 | from .multimodal_projector.builder import build_vision_projector
23 |
24 | from llava.constants import IGNORE_INDEX, IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_PATCH_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
25 |
26 | from llava.mm_utils import get_anyres_image_grid_shape
27 |
28 | from einops import rearrange
29 |
30 |
31 | class LlavaMetaModel:
32 |
33 | def __init__(self, config):
34 | super(LlavaMetaModel, self).__init__(config)
35 |
36 | if hasattr(config, "mm_vision_tower"):
37 | self.vision_tower = build_vision_tower(config, delay_load=True)
38 | self.mm_projector = build_vision_projector(config)
39 |
40 | if 'unpad' in getattr(config, 'mm_patch_merge_type', ''):
41 | self.image_newline = nn.Parameter(
42 | torch.empty(config.hidden_size, dtype=self.dtype)
43 | )
44 |
45 | def get_vision_tower(self):
46 | vision_tower = getattr(self, 'vision_tower', None)
47 | if type(vision_tower) is list:
48 | vision_tower = vision_tower[0]
49 | return vision_tower
50 |
51 | def initialize_vision_modules(self, model_args, fsdp=None):
52 | vision_tower = model_args.vision_tower
53 | mm_vision_select_layer = model_args.mm_vision_select_layer
54 | mm_vision_select_feature = model_args.mm_vision_select_feature
55 | pretrain_mm_mlp_adapter = model_args.pretrain_mm_mlp_adapter
56 | mm_patch_merge_type = model_args.mm_patch_merge_type
57 |
58 | self.config.mm_vision_tower = vision_tower
59 |
60 | if self.get_vision_tower() is None:
61 | vision_tower = build_vision_tower(model_args)
62 |
63 | if fsdp is not None and len(fsdp) > 0:
64 | self.vision_tower = [vision_tower]
65 | else:
66 | self.vision_tower = vision_tower
67 | else:
68 | if fsdp is not None and len(fsdp) > 0:
69 | vision_tower = self.vision_tower[0]
70 | else:
71 | vision_tower = self.vision_tower
72 | vision_tower.load_model()
73 |
74 | self.config.use_mm_proj = True
75 | self.config.mm_projector_type = getattr(model_args, 'mm_projector_type', 'linear')
76 | self.config.mm_hidden_size = vision_tower.hidden_size
77 | self.config.mm_vision_select_layer = mm_vision_select_layer
78 | self.config.mm_vision_select_feature = mm_vision_select_feature
79 | self.config.mm_patch_merge_type = mm_patch_merge_type
80 |
81 | if getattr(self, 'mm_projector', None) is None:
82 | self.mm_projector = build_vision_projector(self.config)
83 |
84 | if 'unpad' in mm_patch_merge_type:
85 | embed_std = 1 / torch.sqrt(torch.tensor(self.config.hidden_size, dtype=self.dtype))
86 | self.image_newline = nn.Parameter(
87 | torch.randn(self.config.hidden_size, dtype=self.dtype) * embed_std
88 | )
89 | else:
90 | # In case it is frozen by LoRA
91 | for p in self.mm_projector.parameters():
92 | p.requires_grad = True
93 |
94 | if pretrain_mm_mlp_adapter is not None:
95 | mm_projector_weights = torch.load(pretrain_mm_mlp_adapter, map_location='cpu')
96 | def get_w(weights, keyword):
97 | return {k.split(keyword + '.')[1]: v for k, v in weights.items() if keyword in k}
98 |
99 | self.mm_projector.load_state_dict(get_w(mm_projector_weights, 'mm_projector'))
100 |
101 |
102 | def unpad_image(tensor, original_size):
103 | """
104 | Unpads a PyTorch tensor of a padded and resized image.
105 |
106 | Args:
107 | tensor (torch.Tensor): The image tensor, assumed to be in CxHxW format.
108 | original_size (tuple): The original size of the image (height, width).
109 |
110 | Returns:
111 | torch.Tensor: The unpadded image tensor.
112 | """
113 | original_width, original_height = original_size
114 | current_height, current_width = tensor.shape[1:]
115 |
116 | original_aspect_ratio = original_width / original_height
117 | current_aspect_ratio = current_width / current_height
118 |
119 | if original_aspect_ratio > current_aspect_ratio:
120 | scale_factor = current_width / original_width
121 | new_height = int(original_height * scale_factor)
122 | padding = (current_height - new_height) // 2
123 | unpadded_tensor = tensor[:, padding:current_height - padding, :]
124 | else:
125 | scale_factor = current_height / original_height
126 | new_width = int(original_width * scale_factor)
127 | padding = (current_width - new_width) // 2
128 | unpadded_tensor = tensor[:, :, padding:current_width - padding]
129 |
130 | return unpadded_tensor
131 |
132 |
133 | class LlavaMetaForCausalLM(ABC):
134 |
135 | @abstractmethod
136 | def get_model(self):
137 | pass
138 |
139 | def get_vision_tower(self):
140 | return self.get_model().get_vision_tower()
141 |
142 | def encode_images(self, images):
143 | image_features = self.get_model().get_vision_tower()(images)
144 | image_features = self.get_model().mm_projector(image_features)
145 | return image_features
146 |
147 |
148 | def temporal_aggregation(self, image_features):
149 | T, N, D = image_features.shape
150 |
151 | ## D1: temporal cat (Just one line!)
152 | image_features = image_features.view(T * N, D) # [T*N D]
153 |
154 | ## D2: spatial pool + temporal cat (Uncomment to use)
155 | # pool2 = nn.MaxPool1d(kernel_size=2, stride=2)
156 | # image_features = rearrange(image_features, 't n d -> t d n')
157 | # image_features = pool2(image_features) # [t d n] -> [t d (n/2)]
158 | # image_features = rearrange(image_features, 't d n -> t n d', t=T)
159 | # image_features = image_features.view(-1, D) # [T*N D]
160 |
161 | ## S1: GAP
162 | # image_features = torch.mean(image_features, dim=0) # [T N D] -> [N D]
163 |
164 | ####### unsqueeze
165 | image_features = image_features.unsqueeze(0) # [1 T*N D]
166 | return image_features
167 |
168 | def prepare_inputs_labels_for_multimodal(
169 | self, input_ids, position_ids, attention_mask, past_key_values, labels,
170 | images, image_sizes=None
171 | ):
172 | vision_tower = self.get_vision_tower()
173 | if vision_tower is None or images is None or input_ids.shape[1] == 1:
174 | return input_ids, position_ids, attention_mask, past_key_values, None, labels
175 |
176 | if type(images) is list or images.ndim == 5:
177 | # images: [T S C H W]
178 | if type(images) is list:
179 | images = [x.unsqueeze(0) if x.ndim == 3 else x for x in images]
180 | concat_images = torch.cat([image for image in images], dim=0) # [TS C H W]
181 | image_features = self.encode_images(concat_images) # [TS N D]
182 | split_sizes = [image.shape[0] for image in images] # T * [S]
183 | image_features = torch.split(image_features, split_sizes, dim=0) # T * [S N D]
184 | mm_patch_merge_type = getattr(self.config, 'mm_patch_merge_type', 'flat')
185 | image_aspect_ratio = getattr(self.config, 'image_aspect_ratio', 'square')
186 | if mm_patch_merge_type == 'flat':
187 | image_features = [x.flatten(0, 1) for x in image_features]
188 | elif mm_patch_merge_type.startswith('spatial'):
189 | new_image_features = []
190 | for image_idx, image_feature in enumerate(image_features):
191 | if image_feature.shape[0] > 1:
192 | base_image_feature = image_feature[0] # [N D]
193 | image_feature = image_feature[1:] # [S-1 N D]
194 |
195 | height = width = self.get_vision_tower().num_patches_per_side
196 | assert height * width == base_image_feature.shape[0]
197 | if image_aspect_ratio == 'anyres':
198 | num_patch_width, num_patch_height = get_anyres_image_grid_shape(image_sizes[image_idx], self.config.image_grid_pinpoints, self.get_vision_tower().config.image_size)
199 | image_feature = image_feature.view(num_patch_height, num_patch_width, height, width, -1) # [sqrt(S-1) sqrt(S-1) H W D]
200 | else:
201 | raise NotImplementedError
202 | if 'unpad' in mm_patch_merge_type:
203 | image_feature = image_feature.permute(4, 0, 2, 1, 3).contiguous()
204 | image_feature = image_feature.flatten(1, 2).flatten(2, 3)
205 | image_feature = unpad_image(image_feature, image_sizes[image_idx])
206 | image_feature = torch.cat((
207 | image_feature,
208 | self.model.image_newline[:, None, None].expand(*image_feature.shape[:-1], 1).to(image_feature.device)
209 | ), dim=-1)
210 | image_feature = image_feature.flatten(1, 2).transpose(0, 1)
211 | else:
212 | image_feature = image_feature.permute(0, 2, 1, 3, 4).contiguous()
213 | image_feature = image_feature.flatten(0, 3)
214 | # image_feature = torch.cat((base_image_feature, image_feature), dim=0)
215 | # Using base feature only for LLaVA-1.6 to avoid extra tokens from high-resolution features.
216 | image_feature = base_image_feature # [576 D]
217 | else:
218 | image_feature = image_feature[0]
219 | if 'unpad' in mm_patch_merge_type:
220 | image_feature = torch.cat((
221 | image_feature,
222 | self.model.image_newline[None].to(image_feature.device)
223 | ), dim=0)
224 | new_image_features.append(image_feature)
225 | image_features = new_image_features
226 |                 # image_features: list of length T, each element [New_N D]
227 |                 # stack the per-frame features along a new temporal dimension
228 |                 image_features = torch.stack(image_features, dim=0)  # [T New_N D]
229 | else:
230 | raise ValueError(f"Unexpected mm_patch_merge_type: {self.config.mm_patch_merge_type}")
231 | else:
232 | image_features = self.encode_images(images).to(self.device) # [T 576 D]
233 |
234 |         # aggregate the per-frame features temporally (also adds the leading batch dimension)
235 | image_features = self.temporal_aggregation(image_features)
236 |
237 |
238 | # TODO: image start / end is not implemented here to support pretraining.
239 | if getattr(self.config, 'tune_mm_mlp_adapter', False) and getattr(self.config, 'mm_use_im_start_end', False):
240 | raise NotImplementedError
241 |
242 | # Let's just add dummy tensors if they do not exist,
243 | # it is a headache to deal with None all the time.
244 | # But it is not ideal, and if you have a better idea,
245 | # please open an issue / submit a PR, thanks.
246 | _labels = labels
247 | _position_ids = position_ids
248 | _attention_mask = attention_mask
249 | if attention_mask is None:
250 | attention_mask = torch.ones_like(input_ids, dtype=torch.bool)
251 | else:
252 | attention_mask = attention_mask.bool()
253 | if position_ids is None:
254 | position_ids = torch.arange(0, input_ids.shape[1], dtype=torch.long, device=input_ids.device)
255 | if labels is None:
256 | labels = torch.full_like(input_ids, IGNORE_INDEX)
257 |
258 | # remove the padding using attention_mask -- FIXME
259 | _input_ids = input_ids
260 | input_ids = [cur_input_ids[cur_attention_mask] for cur_input_ids, cur_attention_mask in zip(input_ids, attention_mask)]
261 | labels = [cur_labels[cur_attention_mask] for cur_labels, cur_attention_mask in zip(labels, attention_mask)]
262 |
263 | new_input_embeds = []
264 | new_labels = []
265 | cur_image_idx = 0
266 | for batch_idx, cur_input_ids in enumerate(input_ids):
267 | num_images = (cur_input_ids == IMAGE_TOKEN_INDEX).sum()
268 | # always 1
269 | if num_images == 0:
270 | cur_image_features = image_features[cur_image_idx]
271 | cur_input_embeds_1 = self.get_model().embed_tokens(cur_input_ids)
272 | cur_input_embeds = torch.cat([cur_input_embeds_1, cur_image_features[0:0]], dim=0)
273 | new_input_embeds.append(cur_input_embeds)
274 | new_labels.append(labels[batch_idx])
275 | cur_image_idx += 1
276 | continue
277 |
278 | image_token_indices = [-1] + torch.where(cur_input_ids == IMAGE_TOKEN_INDEX)[0].tolist() + [cur_input_ids.shape[0]]
279 | cur_input_ids_noim = []
280 | cur_labels = labels[batch_idx]
281 | cur_labels_noim = []
282 | for i in range(len(image_token_indices) - 1):
283 | cur_input_ids_noim.append(cur_input_ids[image_token_indices[i]+1:image_token_indices[i+1]])
284 | cur_labels_noim.append(cur_labels[image_token_indices[i]+1:image_token_indices[i+1]])
285 | split_sizes = [x.shape[0] for x in cur_labels_noim]
286 | cur_input_embeds = self.get_model().embed_tokens(torch.cat(cur_input_ids_noim))
287 | cur_input_embeds_no_im = torch.split(cur_input_embeds, split_sizes, dim=0)
288 | cur_new_input_embeds = []
289 | cur_new_labels = []
290 |
291 | for i in range(num_images + 1):
292 | cur_new_input_embeds.append(cur_input_embeds_no_im[i])
293 | cur_new_labels.append(cur_labels_noim[i])
294 | if i < num_images:
295 | cur_image_features = image_features[cur_image_idx]
296 | cur_image_idx += 1
297 | cur_new_input_embeds.append(cur_image_features)
298 | cur_new_labels.append(torch.full((cur_image_features.shape[0],), IGNORE_INDEX, device=cur_labels.device, dtype=cur_labels.dtype))
299 |
300 | cur_new_input_embeds = [x.to(self.device) for x in cur_new_input_embeds]
301 |
302 |             cur_new_input_embeds = torch.cat(cur_new_input_embeds)  # e.g. 627 = 576 image tokens + 51 text (prefix) tokens; 1999 = 1948 + 51 with 'unpad'
303 | cur_new_labels = torch.cat(cur_new_labels)
304 |
305 | new_input_embeds.append(cur_new_input_embeds)
306 | new_labels.append(cur_new_labels)
307 |
308 | # Truncate sequences to max length as image embeddings can make the sequence longer
309 | tokenizer_model_max_length = getattr(self.config, 'tokenizer_model_max_length', None)
310 | if tokenizer_model_max_length is not None:
311 | new_input_embeds = [x[:tokenizer_model_max_length] for x in new_input_embeds]
312 | new_labels = [x[:tokenizer_model_max_length] for x in new_labels]
313 |
314 | # Combine them
315 | max_len = max(x.shape[0] for x in new_input_embeds)
316 | batch_size = len(new_input_embeds)
317 |
318 | new_input_embeds_padded = []
319 | new_labels_padded = torch.full((batch_size, max_len), IGNORE_INDEX, dtype=new_labels[0].dtype, device=new_labels[0].device)
320 | attention_mask = torch.zeros((batch_size, max_len), dtype=attention_mask.dtype, device=attention_mask.device)
321 | position_ids = torch.zeros((batch_size, max_len), dtype=position_ids.dtype, device=position_ids.device)
322 |
323 | for i, (cur_new_embed, cur_new_labels) in enumerate(zip(new_input_embeds, new_labels)):
324 | cur_len = cur_new_embed.shape[0]
325 | if getattr(self.config, 'tokenizer_padding_side', 'right') == "left":
326 | new_input_embeds_padded.append(torch.cat((
327 | torch.zeros((max_len - cur_len, cur_new_embed.shape[1]), dtype=cur_new_embed.dtype, device=cur_new_embed.device),
328 | cur_new_embed
329 | ), dim=0))
330 | if cur_len > 0:
331 | new_labels_padded[i, -cur_len:] = cur_new_labels
332 | attention_mask[i, -cur_len:] = True
333 | position_ids[i, -cur_len:] = torch.arange(0, cur_len, dtype=position_ids.dtype, device=position_ids.device)
334 | else:
335 | new_input_embeds_padded.append(torch.cat((
336 | cur_new_embed,
337 | torch.zeros((max_len - cur_len, cur_new_embed.shape[1]), dtype=cur_new_embed.dtype, device=cur_new_embed.device)
338 | ), dim=0))
339 | if cur_len > 0:
340 | new_labels_padded[i, :cur_len] = cur_new_labels
341 | attention_mask[i, :cur_len] = True
342 | position_ids[i, :cur_len] = torch.arange(0, cur_len, dtype=position_ids.dtype, device=position_ids.device)
343 |
344 | new_input_embeds = torch.stack(new_input_embeds_padded, dim=0)
345 |
346 | if _labels is None:
347 | new_labels = None
348 | else:
349 | new_labels = new_labels_padded
350 |
351 | if _attention_mask is None:
352 | attention_mask = None
353 | else:
354 | attention_mask = attention_mask.to(dtype=_attention_mask.dtype)
355 |
356 | if _position_ids is None:
357 | position_ids = None
358 |
359 | return None, position_ids, attention_mask, past_key_values, new_input_embeds, new_labels
360 |
361 |
362 | def initialize_vision_tokenizer(self, model_args, tokenizer):
363 | if model_args.mm_use_im_patch_token:
364 | tokenizer.add_tokens([DEFAULT_IMAGE_PATCH_TOKEN], special_tokens=True)
365 | self.resize_token_embeddings(len(tokenizer))
366 |
367 | if model_args.mm_use_im_start_end:
368 | num_new_tokens = tokenizer.add_tokens([DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True)
369 | self.resize_token_embeddings(len(tokenizer))
370 |
371 | if num_new_tokens > 0:
372 | input_embeddings = self.get_input_embeddings().weight.data
373 | output_embeddings = self.get_output_embeddings().weight.data
374 |
375 | input_embeddings_avg = input_embeddings[:-num_new_tokens].mean(
376 | dim=0, keepdim=True)
377 | output_embeddings_avg = output_embeddings[:-num_new_tokens].mean(
378 | dim=0, keepdim=True)
379 |
380 | input_embeddings[-num_new_tokens:] = input_embeddings_avg
381 | output_embeddings[-num_new_tokens:] = output_embeddings_avg
382 |
383 | if model_args.tune_mm_mlp_adapter:
384 | for p in self.get_input_embeddings().parameters():
385 | p.requires_grad = True
386 | for p in self.get_output_embeddings().parameters():
387 | p.requires_grad = False
388 |
389 | if model_args.pretrain_mm_mlp_adapter:
390 | mm_projector_weights = torch.load(model_args.pretrain_mm_mlp_adapter, map_location='cpu')
391 | embed_tokens_weight = mm_projector_weights['model.embed_tokens.weight']
392 | assert num_new_tokens == 2
393 | if input_embeddings.shape == embed_tokens_weight.shape:
394 | input_embeddings[-num_new_tokens:] = embed_tokens_weight[-num_new_tokens:]
395 | elif embed_tokens_weight.shape[0] == num_new_tokens:
396 | input_embeddings[-num_new_tokens:] = embed_tokens_weight
397 | else:
398 |                 raise ValueError(f"Unexpected embed_tokens_weight shape. Pretrained: {embed_tokens_weight.shape}. Current: {input_embeddings.shape}. Number of new tokens: {num_new_tokens}.")
399 | elif model_args.mm_use_im_patch_token:
400 | if model_args.tune_mm_mlp_adapter:
401 | for p in self.get_input_embeddings().parameters():
402 | p.requires_grad = False
403 | for p in self.get_output_embeddings().parameters():
404 | p.requires_grad = False
405 |
--------------------------------------------------------------------------------
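
The padding loop above left- or right-pads the variable-length multimodal embedding sequences produced after the image features are spliced into the text embeddings. Below is a minimal, self-contained sketch of that padding logic; `pad_multimodal_batch` is a hypothetical helper, and `IGNORE_INDEX = -100` is assumed (the actual constant is imported from `llava.constants`).

```python
import torch

IGNORE_INDEX = -100  # assumed value; the real constant lives in llava/constants.py


def pad_multimodal_batch(embeds, labels, padding_side="right"):
    """embeds: list of [L_i, D] tensors; labels: list of [L_i] tensors."""
    max_len = max(e.shape[0] for e in embeds)
    dim = embeds[0].shape[1]
    padded_embeds, padded_labels, masks = [], [], []
    for e, l in zip(embeds, labels):
        pad_len = max_len - e.shape[0]
        zero_pad = torch.zeros(pad_len, dim, dtype=e.dtype)
        label_pad = torch.full((pad_len,), IGNORE_INDEX, dtype=l.dtype)
        mask = torch.zeros(max_len, dtype=torch.bool)
        if padding_side == "left":
            # Pad on the left; real tokens occupy the tail of the sequence.
            padded_embeds.append(torch.cat([zero_pad, e], dim=0))
            padded_labels.append(torch.cat([label_pad, l], dim=0))
            mask[pad_len:] = True
        else:
            # Pad on the right; real tokens occupy the head of the sequence.
            padded_embeds.append(torch.cat([e, zero_pad], dim=0))
            padded_labels.append(torch.cat([l, label_pad], dim=0))
            mask[:e.shape[0]] = True
        masks.append(mask)
    return torch.stack(padded_embeds), torch.stack(padded_labels), torch.stack(masks)


# Example: two sequences of length 3 and 5 with embedding dim 4.
embeds = [torch.randn(3, 4), torch.randn(5, 4)]
labels = [torch.zeros(3, dtype=torch.long), torch.zeros(5, dtype=torch.long)]
e, lab, m = pad_multimodal_batch(embeds, labels, padding_side="left")
print(e.shape, lab.shape, m.shape)  # torch.Size([2, 5, 4]) torch.Size([2, 5]) torch.Size([2, 5])
```
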
/llava/model/make_delta.py:
--------------------------------------------------------------------------------
1 | """
2 | Usage:
3 |     python3 -m llava.model.make_delta --base-model-path ~/model_weights/llama-7b --target-model-path ~/model_weights/llava-7b --delta-path ~/model_weights/llava-7b-delta --hub-repo-id liuhaotian/llava-7b-delta
4 | """
5 | import argparse
6 |
7 | import torch
8 | from tqdm import tqdm
9 | from transformers import AutoTokenizer, AutoModelForCausalLM
10 | from llava.model.utils import auto_upgrade
11 |
12 |
13 | def make_delta(base_model_path, target_model_path, delta_path, hub_repo_id):
14 | print("Loading base model")
15 | base = AutoModelForCausalLM.from_pretrained(
16 | base_model_path, torch_dtype=torch.float16, low_cpu_mem_usage=True)
17 |
18 | print("Loading target model")
19 | auto_upgrade(target_model_path)
20 | target = AutoModelForCausalLM.from_pretrained(target_model_path, torch_dtype=torch.float16, low_cpu_mem_usage=True)
21 |
22 | print("Calculating delta")
23 | for name, param in tqdm(target.state_dict().items(), desc="Calculating delta"):
24 | if name not in base.state_dict():
25 | assert name in ['model.mm_projector.weight', 'model.mm_projector.bias'], f'{name} not in base model'
26 | continue
27 | if param.data.shape == base.state_dict()[name].shape:
28 | param.data -= base.state_dict()[name]
29 | else:
30 | assert name in ['model.embed_tokens.weight', 'lm_head.weight'], f'{name} dimension mismatch: {param.data.shape} vs {base.state_dict()[name].shape}'
31 | bparam = base.state_dict()[name]
32 | param.data[:bparam.shape[0], :bparam.shape[1]] -= bparam
33 |
34 | print("Saving delta")
35 | if hub_repo_id:
36 | kwargs = {"push_to_hub": True, "repo_id": hub_repo_id}
37 | else:
38 | kwargs = {}
39 | target.save_pretrained(delta_path, **kwargs)
40 | target_tokenizer = AutoTokenizer.from_pretrained(target_model_path)
41 | target_tokenizer.save_pretrained(delta_path, **kwargs)
42 |
43 |
44 | if __name__ == "__main__":
45 | parser = argparse.ArgumentParser()
46 | parser.add_argument("--base-model-path", type=str, required=True)
47 | parser.add_argument("--target-model-path", type=str, required=True)
48 | parser.add_argument("--delta-path", type=str, required=True)
49 | parser.add_argument("--hub-repo-id", type=str, default=None)
50 | args = parser.parse_args()
51 |
52 | make_delta(args.base_model_path, args.target_model_path, args.delta_path, args.hub_repo_id)
53 |
--------------------------------------------------------------------------------
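
`make_delta.py` stores `target − base` for every tensor shared with the base model. As a hedged sketch of the inverse step (adding the delta back onto the base weights to recover the target), the helper below mirrors the subtraction above; it is not the repo's `apply_delta.py`, and `apply_delta_state_dict` is a hypothetical name.

```python
import torch


def apply_delta_state_dict(base_sd, delta_sd):
    """Reconstruct target weights from a base state dict and a delta state dict."""
    target_sd = {}
    for name, dparam in delta_sd.items():
        if name not in base_sd:
            # Tensors such as model.mm_projector.* exist only in the delta.
            target_sd[name] = dparam
            continue
        bparam = base_sd[name]
        if dparam.shape == bparam.shape:
            target_sd[name] = dparam + bparam
        else:
            # embed_tokens / lm_head may have extra rows in the target model.
            out = dparam.clone()
            out[:bparam.shape[0], :bparam.shape[1]] += bparam
            target_sd[name] = out
    return target_sd
```

Loading the two checkpoints with `AutoModelForCausalLM.from_pretrained(...).state_dict()` and then calling `load_state_dict(target_sd)` on a target-shaped model would complete the round trip.
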
/llava/model/multimodal_encoder/builder.py:
--------------------------------------------------------------------------------
1 | import os
2 | from .clip_encoder import CLIPVisionTower
3 |
4 |
5 | def build_vision_tower(vision_tower_cfg, **kwargs):
6 | vision_tower = getattr(vision_tower_cfg, 'mm_vision_tower', getattr(vision_tower_cfg, 'vision_tower', None))
7 | is_absolute_path_exists = os.path.exists(vision_tower)
8 | if is_absolute_path_exists or vision_tower.startswith("openai") or vision_tower.startswith("laion") or "ShareGPT4V" in vision_tower:
9 | return CLIPVisionTower(vision_tower, args=vision_tower_cfg, **kwargs)
10 |
11 | raise ValueError(f'Unknown vision tower: {vision_tower}')
12 |
--------------------------------------------------------------------------------
/llava/model/multimodal_encoder/clip_encoder.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 | from transformers import CLIPVisionModel, CLIPImageProcessor, CLIPVisionConfig
5 |
6 |
7 | class CLIPVisionTower(nn.Module):
8 | def __init__(self, vision_tower, args, delay_load=False):
9 | super().__init__()
10 |
11 | self.is_loaded = False
12 |
13 | self.vision_tower_name = vision_tower
14 | self.select_layer = args.mm_vision_select_layer
15 | self.select_feature = getattr(args, 'mm_vision_select_feature', 'patch')
16 |
17 | if not delay_load:
18 | self.load_model()
19 | elif getattr(args, 'unfreeze_mm_vision_tower', False):
20 | self.load_model()
21 | else:
22 | self.cfg_only = CLIPVisionConfig.from_pretrained(self.vision_tower_name, cache_dir='./cache_dir')
23 |
24 | def load_model(self, device_map=None):
25 | if self.is_loaded:
26 | print('{} is already loaded, `load_model` called again, skipping.'.format(self.vision_tower_name))
27 | return
28 |
29 | self.image_processor = CLIPImageProcessor.from_pretrained(self.vision_tower_name, cache_dir='./cache_dir')
30 |
31 | self.vision_tower = CLIPVisionModel.from_pretrained(self.vision_tower_name, device_map=device_map, cache_dir='./cache_dir')
32 | self.vision_tower.requires_grad_(False)
33 |
34 | self.is_loaded = True
35 |
36 | def feature_select(self, image_forward_outs):
37 | image_features = image_forward_outs.hidden_states[self.select_layer]
38 | if self.select_feature == 'patch':
39 | image_features = image_features[:, 1:]
40 | elif self.select_feature == 'cls_patch':
41 | image_features = image_features
42 | else:
43 | raise ValueError(f'Unexpected select feature: {self.select_feature}')
44 | return image_features
45 |
46 | @torch.no_grad()
47 | def forward(self, images):
48 | if type(images) is list:
49 | image_features = []
50 | for image in images:
51 | image_forward_out = self.vision_tower(image.to(device=self.device, dtype=self.dtype).unsqueeze(0), output_hidden_states=True)
52 | image_feature = self.feature_select(image_forward_out).to(image.dtype)
53 | image_features.append(image_feature)
54 | else:
55 | image_forward_outs = self.vision_tower(images.to(device=self.device, dtype=self.dtype), output_hidden_states=True)
56 | image_features = self.feature_select(image_forward_outs).to(images.dtype)
57 |
58 | return image_features
59 |
60 | @property
61 | def dummy_feature(self):
62 | return torch.zeros(1, self.hidden_size, device=self.device, dtype=self.dtype)
63 |
64 | @property
65 | def dtype(self):
66 | return self.vision_tower.dtype
67 |
68 | @property
69 | def device(self):
70 | return self.vision_tower.device
71 |
72 | @property
73 | def config(self):
74 | if self.is_loaded:
75 | return self.vision_tower.config
76 | else:
77 | return self.cfg_only
78 |
79 | @property
80 | def hidden_size(self):
81 | return self.config.hidden_size
82 |
83 | @property
84 | def num_patches_per_side(self):
85 | return self.config.image_size // self.config.patch_size
86 |
87 | @property
88 | def num_patches(self):
89 | return (self.config.image_size // self.config.patch_size) ** 2
90 |
--------------------------------------------------------------------------------
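
The `num_patches_per_side` / `num_patches` properties explain the `[T 576 D]` shape comments seen earlier: for a ViT-L/14 tower at 336 px input (an assumed default), the patch grid is 24 × 24. A one-line check:

```python
image_size, patch_size = 336, 14            # assumed CLIP ViT-L/14-336 settings
patches_per_side = image_size // patch_size
print(patches_per_side, patches_per_side ** 2)   # 24 576
```
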
/llava/model/multimodal_projector/builder.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import re
4 |
5 |
6 | class IdentityMap(nn.Module):
7 | def __init__(self):
8 | super().__init__()
9 |
10 | def forward(self, x, *args, **kwargs):
11 | return x
12 |
13 | @property
14 | def config(self):
15 | return {"mm_projector_type": 'identity'}
16 |
17 |
18 | class SimpleResBlock(nn.Module):
19 | def __init__(self, channels):
20 | super().__init__()
21 | self.pre_norm = nn.LayerNorm(channels)
22 |
23 | self.proj = nn.Sequential(
24 | nn.Linear(channels, channels),
25 | nn.GELU(),
26 | nn.Linear(channels, channels)
27 | )
28 | def forward(self, x):
29 | x = self.pre_norm(x)
30 | return x + self.proj(x)
31 |
32 |
33 | def build_vision_projector(config, delay_load=False, **kwargs):
34 | projector_type = getattr(config, 'mm_projector_type', 'linear')
35 |
36 | if projector_type == 'linear':
37 | return nn.Linear(config.mm_hidden_size, config.hidden_size)
38 |
39 | mlp_gelu_match = re.match(r'^mlp(\d+)x_gelu$', projector_type)
40 | if mlp_gelu_match:
41 | mlp_depth = int(mlp_gelu_match.group(1))
42 | modules = [nn.Linear(config.mm_hidden_size, config.hidden_size)]
43 | for _ in range(1, mlp_depth):
44 | modules.append(nn.GELU())
45 | modules.append(nn.Linear(config.hidden_size, config.hidden_size))
46 | return nn.Sequential(*modules)
47 |
48 | if projector_type == 'identity':
49 | return IdentityMap()
50 |
51 | raise ValueError(f'Unknown projector type: {projector_type}')
52 |
--------------------------------------------------------------------------------
/llava/model/utils.py:
--------------------------------------------------------------------------------
1 | from transformers import AutoConfig
2 |
3 |
4 | def auto_upgrade(config):
5 | cfg = AutoConfig.from_pretrained(config)
6 | if 'llava' in config and 'llava' not in cfg.model_type:
7 | assert cfg.model_type == 'llama'
8 |         print("You are using a newer LLaVA code base, but this checkpoint is from the older v0 code base.")
9 | print("You must upgrade the checkpoint to the new code base (this can be done automatically).")
10 | confirm = input("Please confirm that you want to upgrade the checkpoint. [Y/N]")
11 | if confirm.lower() in ["y", "yes"]:
12 | print("Upgrading checkpoint...")
13 | assert len(cfg.architectures) == 1
14 | setattr(cfg.__class__, "model_type", "llava")
15 | cfg.architectures[0] = 'LlavaLlamaForCausalLM'
16 | cfg.save_pretrained(config)
17 | print("Checkpoint upgraded.")
18 | else:
19 | print("Checkpoint upgrade aborted.")
20 | exit(1)
21 |
--------------------------------------------------------------------------------
/llava/utils.py:
--------------------------------------------------------------------------------
1 | import datetime
2 | import logging
3 | import logging.handlers
4 | import os
5 | import sys
6 |
7 | import requests
8 |
9 | from llava.constants import LOGDIR
10 |
11 | server_error_msg = "**NETWORK ERROR DUE TO HIGH TRAFFIC. PLEASE REGENERATE OR REFRESH THIS PAGE.**"
12 | moderation_msg = "YOUR INPUT VIOLATES OUR CONTENT MODERATION GUIDELINES. PLEASE TRY AGAIN."
13 |
14 | handler = None
15 |
16 |
17 | def build_logger(logger_name, logger_filename):
18 | global handler
19 |
20 | formatter = logging.Formatter(
21 | fmt="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
22 | datefmt="%Y-%m-%d %H:%M:%S",
23 | )
24 |
25 | # Set the format of root handlers
26 | if not logging.getLogger().handlers:
27 | logging.basicConfig(level=logging.INFO)
28 | logging.getLogger().handlers[0].setFormatter(formatter)
29 |
30 | # Redirect stdout and stderr to loggers
31 | stdout_logger = logging.getLogger("stdout")
32 | stdout_logger.setLevel(logging.INFO)
33 | sl = StreamToLogger(stdout_logger, logging.INFO)
34 | sys.stdout = sl
35 |
36 | stderr_logger = logging.getLogger("stderr")
37 | stderr_logger.setLevel(logging.ERROR)
38 | sl = StreamToLogger(stderr_logger, logging.ERROR)
39 | sys.stderr = sl
40 |
41 | # Get logger
42 | logger = logging.getLogger(logger_name)
43 | logger.setLevel(logging.INFO)
44 |
45 | # Add a file handler for all loggers
46 | if handler is None:
47 | os.makedirs(LOGDIR, exist_ok=True)
48 | filename = os.path.join(LOGDIR, logger_filename)
49 | handler = logging.handlers.TimedRotatingFileHandler(
50 | filename, when='D', utc=True, encoding='UTF-8')
51 | handler.setFormatter(formatter)
52 |
53 | for name, item in logging.root.manager.loggerDict.items():
54 | if isinstance(item, logging.Logger):
55 | item.addHandler(handler)
56 |
57 | return logger
58 |
59 |
60 | class StreamToLogger(object):
61 | """
62 | Fake file-like stream object that redirects writes to a logger instance.
63 | """
64 | def __init__(self, logger, log_level=logging.INFO):
65 | self.terminal = sys.stdout
66 | self.logger = logger
67 | self.log_level = log_level
68 | self.linebuf = ''
69 |
70 | def __getattr__(self, attr):
71 | return getattr(self.terminal, attr)
72 |
73 | def write(self, buf):
74 | temp_linebuf = self.linebuf + buf
75 | self.linebuf = ''
76 | for line in temp_linebuf.splitlines(True):
77 | # From the io.TextIOWrapper docs:
78 | # On output, if newline is None, any '\n' characters written
79 | # are translated to the system default line separator.
80 | # By default sys.stdout.write() expects '\n' newlines and then
81 | # translates them so this is still cross platform.
82 | if line[-1] == '\n':
83 | self.logger.log(self.log_level, line.rstrip())
84 | else:
85 | self.linebuf += line
86 |
87 | def flush(self):
88 | if self.linebuf != '':
89 | self.logger.log(self.log_level, self.linebuf.rstrip())
90 | self.linebuf = ''
91 |
92 |
93 | def disable_torch_init():
94 | """
95 | Disable the redundant torch default initialization to accelerate model creation.
96 | """
97 | import torch
98 | setattr(torch.nn.Linear, "reset_parameters", lambda self: None)
99 | setattr(torch.nn.LayerNorm, "reset_parameters", lambda self: None)
100 |
101 |
102 | def violates_moderation(text):
103 | """
104 | Check whether the text violates OpenAI moderation API.
105 | """
106 | url = "https://api.openai.com/v1/moderations"
107 | headers = {"Content-Type": "application/json",
108 | "Authorization": "Bearer " + os.environ["OPENAI_API_KEY"]}
109 |     text = text.replace("\n", "")
110 |     # Let requests serialize the JSON body; manual string concatenation breaks on quotes in `text`.
111 |     payload = {"input": text}
112 |     try:
113 |         ret = requests.post(url, headers=headers, json=payload, timeout=5)
114 | flagged = ret.json()["results"][0]["flagged"]
115 | except requests.exceptions.RequestException as e:
116 | flagged = False
117 | except KeyError as e:
118 | flagged = False
119 |
120 | return flagged
121 |
122 |
123 | def pretty_print_semaphore(semaphore):
124 | if semaphore is None:
125 | return "None"
126 | return f"Semaphore(value={semaphore._value}, locked={semaphore.locked()})"
127 |
--------------------------------------------------------------------------------
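
A small sketch of the stdout redirection that `build_logger` performs: wrap a logger in `StreamToLogger` and assign it to `sys.stdout` so that `print()` output flows through the logging pipeline. This assumes the `llava` package and its dependencies are importable; it is an illustration, not part of the repo.

```python
import logging
import sys

from llava.utils import StreamToLogger

logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
)
sys.stdout = StreamToLogger(logging.getLogger("stdout"), logging.INFO)
print("hello from print()")   # routed through the 'stdout' logger (handlers write to stderr)
sys.stdout = sys.__stdout__   # restore the real stdout afterwards
```
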
/scripts/gpt_eval/eval_qa_activitynet.sh:
--------------------------------------------------------------------------------
1 | gpt_version="gpt-3.5-turbo-0301" #"gpt-3.5-turbo-0613" "gpt-3.5-turbo-0125"
2 | output_name="llava-v1.5-7b_u4FRS"
3 | pred_path="Activitynet_Zero_Shot_QA/${output_name}/merge.jsonl"
4 | output_dir="Activitynet_Zero_Shot_QA/${output_name}/${gpt_version}"
5 | output_json="Activitynet_Zero_Shot_QA/${output_name}/results_${gpt_version}.json"
6 | api_key="sk-xxx"
7 | num_tasks=25
8 |
9 |
10 |
11 | python3 scripts/gpt_eval/eval_video_qa.py \
12 | --pred_path ${pred_path} \
13 | --output_dir ${output_dir} \
14 | --output_json ${output_json} \
15 | --api_key ${api_key} \
16 | --gpt_version ${gpt_version} \
17 | --num_tasks ${num_tasks}
--------------------------------------------------------------------------------
/scripts/gpt_eval/eval_qa_benchmark.sh:
--------------------------------------------------------------------------------
1 | gpt_version="gpt-3.5-turbo-0301" #"gpt-3.5-turbo-0613" "gpt-3.5-turbo-0125"
2 | output_name="llava-v1.5-7b_u4FRS"
3 | pred_dir="Video_Benchmark/${output_name}"
4 | output_dir="Video_Benchmark/${output_name}/${gpt_version}"
5 | api_key="sk-xxx"
6 | num_tasks=32
7 |
8 |
9 | # Run the "correctness" evaluation script
10 | python3 scripts/gpt_eval/evaluate_benchmark_1_correctness.py \
11 | --pred_path "${pred_dir}/generic.jsonl" \
12 | --output_dir "${output_dir}/correctness_pred" \
13 | --output_json "${output_dir}/correctness_results.json" \
14 | --api_key ${api_key} \
15 | --gpt_version ${gpt_version} \
16 | --num_tasks ${num_tasks}
17 |
18 |
19 | # Run the "detailed orientation" evaluation script
20 | python3 scripts/gpt_eval/evaluate_benchmark_2_detailed_orientation.py \
21 | --pred_path "${pred_dir}/generic.jsonl" \
22 | --output_dir "${output_dir}/detailed_eval" \
23 | --output_json "${output_dir}/detailed_orientation_results.json" \
24 | --api_key ${api_key} \
25 | --gpt_version ${gpt_version} \
26 | --num_tasks ${num_tasks}
27 |
28 |
29 | # Run the "contextual understanding" evaluation script
30 | python3 scripts/gpt_eval/evaluate_benchmark_3_context.py \
31 | --pred_path "${pred_dir}/generic.jsonl" \
32 | --output_dir "${output_dir}/context_eval" \
33 | --output_json "${output_dir}/contextual_understanding_results.json" \
34 | --api_key ${api_key} \
35 | --gpt_version ${gpt_version} \
36 | --num_tasks ${num_tasks}
37 |
38 |
39 | # Run the "temporal understanding" evaluation script
40 | python3 scripts/gpt_eval/evaluate_benchmark_4_temporal.py \
41 | --pred_path "${pred_dir}/temporal.jsonl" \
42 | --output_dir "${output_dir}/temporal_eval" \
43 | --output_json "${output_dir}/temporal_understanding_results.json" \
44 | --api_key ${api_key} \
45 | --gpt_version ${gpt_version} \
46 | --num_tasks ${num_tasks}
47 |
48 |
49 | # Run the "consistency" evaluation script
50 | python3 scripts/gpt_eval/evaluate_benchmark_5_consistency.py \
51 | --pred_path "${pred_dir}/consistency.jsonl" \
52 | --output_dir "${output_dir}/consistency_eval" \
53 | --output_json "${output_dir}/consistency_results.json" \
54 | --api_key ${api_key} \
55 | --gpt_version ${gpt_version} \
56 | --num_tasks ${num_tasks}
57 |
58 |
59 | echo "All evaluations completed!"
--------------------------------------------------------------------------------
/scripts/gpt_eval/eval_qa_msrvtt.sh:
--------------------------------------------------------------------------------
1 | gpt_version="gpt-3.5-turbo-0301" #"gpt-3.5-turbo-0613" "gpt-3.5-turbo-0125"
2 | output_name="llava-v1.5-7b_u4FRS"
3 | pred_path="MSRVTT_Zero_Shot_QA/${output_name}/merge.jsonl"
4 | output_dir="MSRVTT_Zero_Shot_QA/${output_name}/${gpt_version}"
5 | output_json="MSRVTT_Zero_Shot_QA/${output_name}/results_${gpt_version}.json"
6 | api_key="sk-xxx"
7 | num_tasks=32
8 |
9 |
10 |
11 | python3 scripts/gpt_eval/eval_video_qa.py \
12 | --pred_path ${pred_path} \
13 | --output_dir ${output_dir} \
14 | --output_json ${output_json} \
15 | --api_key ${api_key} \
16 | --gpt_version ${gpt_version} \
17 | --num_tasks ${num_tasks}
18 |
--------------------------------------------------------------------------------
/scripts/gpt_eval/eval_qa_msvd.sh:
--------------------------------------------------------------------------------
1 | gpt_version="gpt-3.5-turbo-0301" #"gpt-3.5-turbo-0613" "gpt-3.5-turbo-0125"
2 | output_name="llava-v1.5-7b_u4FRS"
3 | pred_path="MSVD_Zero_Shot_QA/${output_name}/merge.jsonl"
4 | output_dir="MSVD_Zero_Shot_QA/${output_name}/${gpt_version}"
5 | output_json="MSVD_Zero_Shot_QA/${output_name}/results_${gpt_version}.json"
6 | api_key="sk-xxx"
7 | num_tasks=25
8 |
9 |
10 |
11 | python3 scripts/gpt_eval/eval_video_qa.py \
12 | --pred_path ${pred_path} \
13 | --output_dir ${output_dir} \
14 | --output_json ${output_json} \
15 | --api_key ${api_key} \
16 | --gpt_version ${gpt_version} \
17 | --num_tasks ${num_tasks}
--------------------------------------------------------------------------------
/scripts/gpt_eval/eval_video_qa.py:
--------------------------------------------------------------------------------
1 | import openai
2 | import os
3 | import argparse
4 | import json
5 | import ast
6 | from multiprocessing.pool import Pool
7 | from tqdm import tqdm
8 |
9 | def parse_args():
10 | parser = argparse.ArgumentParser(description="question-answer-generation-using-gpt-3")
11 |     parser.add_argument("--pred_path", default=r'', help="Path to the file containing predictions.")
12 |     parser.add_argument("--output_dir", default=r'', help="Directory in which to save the per-sample annotation JSON files.")
13 |     parser.add_argument("--output_json", default=r'', help="Path of the final combined annotation JSON file.")
14 |     parser.add_argument("--api_key", default="", help="OpenAI API key.")
15 |     parser.add_argument("--gpt_version", default="gpt-3.5-turbo", type=str, help="GPT model version used for evaluation.")
16 | parser.add_argument("--num_tasks", default=1, type=int, help="Number of splits.")
17 | args = parser.parse_args()
18 | return args
19 |
20 |
21 | def annotate(prediction_set, caption_files, output_dir, args):
22 | """
23 |     Evaluates question-answer pairs with the configured GPT model and
24 |     returns a correctness score.
25 | """
26 | # Set the OpenAI API key.
27 | openai.api_key = args.api_key
28 | # if args.api_base is not None:
29 | # openai.api_base = args.api_base
30 | for file in caption_files:
31 | key = file[:-5] # Strip file extension
32 | qa_set = prediction_set[key]
33 | question = qa_set['q']
34 | answer = qa_set['a']
35 | pred = qa_set['pred']
36 | try:
37 | # Compute the correctness score
38 | completion = openai.ChatCompletion.create(
39 | model=args.gpt_version,
40 | messages=[
41 | {
42 | "role": "system",
43 | "content":
44 | "You are an intelligent chatbot designed for evaluating the correctness of generative outputs for question-answer pairs. "
45 | "Your task is to compare the predicted answer with the correct answer and determine if they match meaningfully. Here's how you can accomplish the task:"
46 | "------"
47 | "##INSTRUCTIONS: "
48 | "- Focus on the meaningful match between the predicted answer and the correct answer.\n"
49 | "- Consider synonyms or paraphrases as valid matches.\n"
50 | "- Evaluate the correctness of the prediction compared to the answer."
51 | },
52 | {
53 | "role": "user",
54 | "content":
55 | "Please evaluate the following video-based question-answer pair:\n\n"
56 | f"Question: {question}\n"
57 | f"Correct Answer: {answer}\n"
58 | f"Predicted Answer: {pred}\n\n"
59 | "Provide your evaluation only as a yes/no and score where the score is an integer value between 0 and 5, with 5 indicating the highest meaningful match. "
60 |                         "Please generate the response in the form of a Python dictionary string with keys 'pred' and 'score', where the value of 'pred' is a string of 'yes' or 'no' and the value of 'score' is an INTEGER, not a STRING. "
61 |                         "DO NOT PROVIDE ANY OTHER OUTPUT TEXT OR EXPLANATION. Only provide the Python dictionary string. "
62 |                         "For example, your response should look like this: {'pred': 'yes', 'score': 4}."
63 | }
64 | ]
65 | )
66 | # Convert response to a Python dictionary.
67 | response_message = completion["choices"][0]["message"]["content"]
68 | response_dict = ast.literal_eval(response_message)
69 | result_qa_pair = [response_dict, qa_set]
70 |
71 | # Save the question-answer pairs to a json file.
72 | with open(f"{output_dir}/{key}.json", "w") as f:
73 | json.dump(result_qa_pair, f)
74 |
75 | except Exception as e:
76 | print(f"Error processing file '{key}': {e}")
77 |
78 |
79 | def main():
80 | """
81 | Main function to control the flow of the program.
82 | """
83 | # Parse arguments.
84 | args = parse_args()
85 |
86 |     with open(args.pred_path) as file:
87 |         new_pred_contents = [ast.literal_eval(line.strip()) for line in file]
88 |
89 | # pred_contents = [json.loads(line) for line in open(args.pred_path)]
90 | # # Dictionary to store the count of occurrences for each video_id
91 | # video_id_counts = {}
92 | # new_pred_contents = []
93 |
94 | # # Iterate through each sample in pred_contents
95 | # for sample in pred_contents:
96 | # video_id = sample['id']
97 | # if video_id in video_id_counts:
98 | # video_id_counts[video_id] += 1
99 | # else:
100 | # video_id_counts[video_id] = 0
101 |
102 | # # Create a new sample with the modified key
103 | # new_sample = sample
104 | # new_sample['id'] = f"{video_id}_{video_id_counts[video_id]}"
105 | # new_pred_contents.append(new_sample)
106 |
107 | # Generating list of id's and corresponding files
108 | id_list = [x['id'] for x in new_pred_contents]
109 | caption_files = [f"{id}.json" for id in id_list]
110 |
111 | output_dir = args.output_dir
112 | # Generate output directory if not exists.
113 | if not os.path.exists(output_dir):
114 | os.makedirs(output_dir)
115 |
116 | # Preparing dictionary of question-answer sets
117 | prediction_set = {}
118 | for sample in new_pred_contents:
119 | id = sample['id']
120 | question = sample['question']
121 | answer = sample['answer']
122 | pred = sample['pred']
123 | qa_set = {"q": question, "a": answer, "pred": pred}
124 | prediction_set[id] = qa_set
125 |
126 | num_tasks = args.num_tasks
127 |
128 | # While loop to ensure that all captions are processed.
129 | while True:
130 | try:
131 |             # Files that have already been processed.
132 | completed_files = os.listdir(output_dir)
133 | print(f"completed_files: {len(completed_files)}")
134 |
135 | # Files that have not been processed yet.
136 | incomplete_files = [f for f in caption_files if f not in completed_files]
137 | print(f"incomplete_files: {len(incomplete_files)}")
138 |
139 | # Break the loop when there are no incomplete files
140 | if len(incomplete_files) == 0:
141 | break
142 | if len(incomplete_files) <= num_tasks:
143 | num_tasks = 1
144 |
145 | # Split tasks into parts.
146 | part_len = len(incomplete_files) // num_tasks
147 | all_parts = [incomplete_files[i:i + part_len] for i in range(0, len(incomplete_files), part_len)]
148 | task_args = [(prediction_set, part, args.output_dir, args) for part in all_parts]
149 |
150 | # Use a pool of workers to process the files in parallel.
151 | with Pool() as pool:
152 | pool.starmap(annotate, task_args)
153 |
154 | except Exception as e:
155 | print(f"Error: {e}")
156 |
157 | # Combine all the processed files into one
158 | combined_contents = {}
159 | json_path = args.output_json
160 |
161 | # Iterate through json files
162 | for file_name in os.listdir(output_dir):
163 | if file_name.endswith(".json"):
164 | file_path = os.path.join(output_dir, file_name)
165 | with open(file_path, "r") as json_file:
166 | content = json.load(json_file)
167 | combined_contents[file_name[:-5]] = content
168 |
169 | # Write combined content to a json file
170 | with open(json_path, "w") as json_file:
171 | json.dump(combined_contents, json_file)
172 | print("All evaluation completed!")
173 |
174 | # Calculate average score and accuracy
175 | score_sum = 0
176 | count = 0
177 | yes_count = 0
178 | no_count = 0
179 | for key, result in tqdm(combined_contents.items()):
180 | try:
181 | # Computing score
182 | count += 1
183 | score_match = result[0]['score']
184 | score = int(score_match)
185 | score_sum += score
186 |
187 | # Computing accuracy
188 | pred = result[0]['pred']
189 | if "yes" in pred.lower():
190 | yes_count += 1
191 | elif "no" in pred.lower():
192 | no_count += 1
193 | except:
194 | print(result)
195 |
196 | average_score = score_sum / count
197 | accuracy = yes_count / (yes_count + no_count)
198 | print("Yes count:", yes_count)
199 | print("No count:", no_count)
200 | print("Accuracy:", accuracy)
201 | print("Average score:", average_score)
202 |
203 |
204 | if __name__ == "__main__":
205 | main()
206 |
207 |
--------------------------------------------------------------------------------
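
The aggregation at the end of `eval_video_qa.py` expects each judge reply to be a Python-dict string with `'pred'` and `'score'` keys. A tiny, self-contained sketch of that parsing and scoring, using made-up responses rather than real API output:

```python
import ast

# Hypothetical judge replies in the format the prompt above requests.
responses = ["{'pred': 'yes', 'score': 4}", "{'pred': 'no', 'score': 1}"]
parsed = [ast.literal_eval(r) for r in responses]

yes = sum("yes" in r["pred"].lower() for r in parsed)
no = sum("no" in r["pred"].lower() for r in parsed)
avg = sum(int(r["score"]) for r in parsed) / len(parsed)
print(f"Accuracy: {yes / (yes + no):.2f}  Average score: {avg:.2f}")
```
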
/scripts/gpt_eval/evaluate_benchmark_1_correctness.py:
--------------------------------------------------------------------------------
1 | import openai
2 | import os
3 | import argparse
4 | import json
5 | import ast
6 | from multiprocessing.pool import Pool
7 | from tqdm import tqdm
8 |
9 | def parse_args():
10 | parser = argparse.ArgumentParser(description="question-answer-generation-using-gpt-3")
11 |     parser.add_argument("--pred_path", default=r'', help="Path to the file containing predictions.")
12 |     parser.add_argument("--output_dir", default=r'', help="Directory in which to save the per-sample annotation JSON files.")
13 |     parser.add_argument("--output_json", default=r'', help="Path of the final combined annotation JSON file.")
14 |     parser.add_argument("--api_key", default="", help="OpenAI API key.")
15 |     parser.add_argument("--gpt_version", default="gpt-3.5-turbo", type=str, help="GPT model version used for evaluation.")
16 | parser.add_argument("--num_tasks", default=1, type=int, help="Number of splits.")
17 | args = parser.parse_args()
18 | return args
19 |
20 |
21 | def annotate(prediction_set, caption_files, output_dir, args):
22 | """
23 |     Evaluates question-answer pairs with the configured GPT model and
24 |     returns a factual-correctness score.
25 | """
26 | # Set the OpenAI API key.
27 | openai.api_key = args.api_key
28 | # if args.api_base is not None:
29 | # openai.api_base = args.api_base
30 | for file in caption_files:
31 | key = file[:-5] # Strip file extension
32 | qa_set = prediction_set[key]
33 | question = qa_set['q']
34 | answer = qa_set['a']
35 | pred = qa_set['pred']
36 | try:
37 | # Compute the correctness score
38 | completion = openai.ChatCompletion.create(
39 |                 model=args.gpt_version,
40 | messages=[
41 | {
42 | "role": "system",
43 | "content":
44 | "You are an intelligent chatbot designed for evaluating the factual accuracy of generative outputs for video-based question-answer pairs. "
45 | "Your task is to compare the predicted answer with the correct answer and determine if they are factually consistent. Here's how you can accomplish the task:"
46 | "------"
47 | "##INSTRUCTIONS: "
48 | "- Focus on the factual consistency between the predicted answer and the correct answer. The predicted answer should not contain any misinterpretations or misinformation.\n"
49 | "- The predicted answer must be factually accurate and align with the video content.\n"
50 | "- Consider synonyms or paraphrases as valid matches.\n"
51 | "- Evaluate the factual accuracy of the prediction compared to the answer."
52 | },
53 | {
54 | "role": "user",
55 | "content":
56 | "Please evaluate the following video-based question-answer pair:\n\n"
57 | f"Question: {question}\n"
58 | f"Correct Answer: {answer}\n"
59 | f"Predicted Answer: {pred}\n\n"
60 | "Provide your evaluation only as a factual accuracy score where the factual accuracy score is an integer value between 0 and 5, with 5 indicating the highest level of factual consistency. "
61 |                         "Please generate the response in the form of a Python dictionary string with the key 'score', where its value is the factual accuracy score in INTEGER, not STRING. "
62 |                         "DO NOT PROVIDE ANY OTHER OUTPUT TEXT OR EXPLANATION. Only provide the Python dictionary string. "
63 |                         "For example, your response should look like this: {'score': 4}."
64 | }
65 | ]
66 | )
67 | # Convert response to a Python dictionary.
68 | response_message = completion["choices"][0]["message"]["content"]
69 | response_dict = ast.literal_eval(response_message)
70 | result_qa_pair = [response_dict, qa_set]
71 |
72 | # Save the question-answer pairs to a json file.
73 | with open(f"{output_dir}/{key}.json", "w") as f:
74 | json.dump(result_qa_pair, f)
75 |
76 | except Exception as e:
77 | print(f"Error processing file '{key}': {e}")
78 |
79 |
80 | def main():
81 | """
82 | Main function to control the flow of the program.
83 | """
84 | # Parse arguments.
85 | args = parse_args()
86 |
87 | pred_contents = [json.loads(line) for line in open(args.pred_path)]
88 |
89 | # Dictionary to store the count of occurrences for each video_id
90 | video_id_counts = {}
91 | new_pred_contents = []
92 |
93 | # Iterate through each sample in pred_contents
94 | for sample in pred_contents:
95 | video_id = sample['video_name']
96 | if video_id in video_id_counts:
97 | video_id_counts[video_id] += 1
98 | else:
99 | video_id_counts[video_id] = 0
100 |
101 | # Create a new sample with the modified key
102 | new_sample = sample
103 | new_sample['video_name'] = f"{video_id}_{video_id_counts[video_id]}"
104 | new_pred_contents.append(new_sample)
105 |
106 | # Generating list of id's and corresponding files
107 | id_list = [x['video_name'] for x in new_pred_contents]
108 | caption_files = [f"{id}.json" for id in id_list]
109 |
110 | output_dir = args.output_dir
111 | # Generate output directory if not exists.
112 | if not os.path.exists(output_dir):
113 | os.makedirs(output_dir)
114 |
115 | # Preparing dictionary of question-answer sets
116 | prediction_set = {}
117 | for sample in new_pred_contents:
118 | id = sample['video_name']
119 | question = sample['Q']
120 | answer = sample['A']
121 | pred = sample['pred']
122 | qa_set = {"q": question, "a": answer, "pred": pred}
123 | prediction_set[id] = qa_set
124 |
125 | num_tasks = args.num_tasks
126 |
127 | # While loop to ensure that all captions are processed.
128 | while True:
129 | try:
130 |             # Files that have already been processed.
131 | completed_files = os.listdir(output_dir)
132 | print(f"completed_files: {len(completed_files)}")
133 |
134 | # Files that have not been processed yet.
135 | incomplete_files = [f for f in caption_files if f not in completed_files]
136 | print(f"incomplete_files: {len(incomplete_files)}")
137 |
138 | # Break the loop when there are no incomplete files
139 | if len(incomplete_files) == 0:
140 | break
141 | if len(incomplete_files) <= num_tasks:
142 | num_tasks = 1
143 |
144 | # Split tasks into parts.
145 | part_len = len(incomplete_files) // num_tasks
146 | all_parts = [incomplete_files[i:i + part_len] for i in range(0, len(incomplete_files), part_len)]
147 | task_args = [(prediction_set, part, args.output_dir, args) for part in all_parts]
148 |
149 | # Use a pool of workers to process the files in parallel.
150 | with Pool() as pool:
151 | pool.starmap(annotate, task_args)
152 |
153 | except Exception as e:
154 | print(f"Error: {e}")
155 |
156 | # Combine all the processed files into one
157 | combined_contents = {}
158 | json_path = args.output_json
159 |
160 | # Iterate through json files
161 | for file_name in os.listdir(output_dir):
162 | if file_name.endswith(".json"):
163 | file_path = os.path.join(output_dir, file_name)
164 | with open(file_path, "r") as json_file:
165 | content = json.load(json_file)
166 | combined_contents[file_name[:-5]] = content
167 |
168 | # Write combined content to a json file
169 | with open(json_path, "w") as json_file:
170 | json.dump(combined_contents, json_file)
171 | print("All evaluation completed!")
172 |
173 | # Calculate average score
174 | score_sum = 0
175 | count = 0
176 | for key, result in combined_contents.items():
177 | count += 1
178 | score_match = result[0]['score']
179 | score = int(score_match)
180 | score_sum += score
181 | average_score = score_sum / count
182 |
183 | print("Average score for correctness:", average_score)
184 |
185 |
186 | if __name__ == "__main__":
187 | main()
188 |
189 |
--------------------------------------------------------------------------------
/scripts/gpt_eval/evaluate_benchmark_2_detailed_orientation.py:
--------------------------------------------------------------------------------
1 | import openai
2 | import os
3 | import argparse
4 | import json
5 | import ast
6 | from multiprocessing.pool import Pool
7 | from tqdm import tqdm
8 |
9 | def parse_args():
10 | parser = argparse.ArgumentParser(description="question-answer-generation-using-gpt-3")
11 |     parser.add_argument("--pred_path", default=r'', help="Path to the file containing predictions.")
12 |     parser.add_argument("--output_dir", default=r'', help="Directory in which to save the per-sample annotation JSON files.")
13 |     parser.add_argument("--output_json", default=r'', help="Path of the final combined annotation JSON file.")
14 |     parser.add_argument("--api_key", default="", help="OpenAI API key.")
15 |     parser.add_argument("--gpt_version", default="gpt-3.5-turbo", type=str, help="GPT model version used for evaluation.")
16 | parser.add_argument("--num_tasks", default=1, type=int, help="Number of splits.")
17 | args = parser.parse_args()
18 | return args
19 |
20 |
21 | def annotate(prediction_set, caption_files, output_dir, args):
22 | """
23 | Evaluates question and answer pairs using GPT-3 and
24 | returns a score for detailed orientation.
25 | """
26 | # Set the OpenAI API key.
27 | openai.api_key = args.api_key
28 | # if args.api_base is not None:
29 | # openai.api_base = args.api_base
30 | for file in caption_files:
31 | key = file[:-5] # Strip file extension
32 | qa_set = prediction_set[key]
33 | question = qa_set['q']
34 | answer = qa_set['a']
35 | pred = qa_set['pred']
36 | try:
37 | # Compute the detailed-orientation score
38 | completion = openai.ChatCompletion.create(
39 |                 model=args.gpt_version,
40 | messages=[
41 | {
42 | "role": "system",
43 | "content":
44 | "You are an intelligent chatbot designed for evaluating the detail orientation of generative outputs for video-based question-answer pairs. "
45 | "Your task is to compare the predicted answer with the correct answer and determine its level of detail, considering both completeness and specificity. Here's how you can accomplish the task:"
46 | "------"
47 | "##INSTRUCTIONS: "
48 | "- Check if the predicted answer covers all major points from the video. The response should not leave out any key aspects.\n"
49 | "- Evaluate whether the predicted answer includes specific details rather than just generic points. It should provide comprehensive information that is tied to specific elements of the video.\n"
50 | "- Consider synonyms or paraphrases as valid matches.\n"
51 | "- Provide a single evaluation score that reflects the level of detail orientation of the prediction, considering both completeness and specificity."
52 | },
53 | {
54 | "role": "user",
55 | "content":
56 | "Please evaluate the following video-based question-answer pair:\n\n"
57 | f"Question: {question}\n"
58 | f"Correct Answer: {answer}\n"
59 | f"Predicted Answer: {pred}\n\n"
60 | "Provide your evaluation only as a detail orientation score where the detail orientation score is an integer value between 0 and 5, with 5 indicating the highest level of detail orientation. "
61 |                         "Please generate the response in the form of a Python dictionary string with the key 'score', where its value is the detail orientation score in INTEGER, not STRING. "
62 |                         "DO NOT PROVIDE ANY OTHER OUTPUT TEXT OR EXPLANATION. Only provide the Python dictionary string. "
63 |                         "For example, your response should look like this: {'score': 4}."
64 | }
65 | ]
66 | )
67 | # Convert response to a Python dictionary.
68 | response_message = completion["choices"][0]["message"]["content"]
69 | response_dict = ast.literal_eval(response_message)
70 | result_qa_pair = [response_dict, qa_set]
71 |
72 | # Save the question-answer pairs to a json file.
73 | with open(f"{output_dir}/{key}.json", "w") as f:
74 | json.dump(result_qa_pair, f)
75 |
76 | except Exception as e:
77 | print(f"Error processing file '{key}': {e}")
78 |
79 |
80 | def main():
81 | """
82 | Main function to control the flow of the program.
83 | """
84 | # Parse arguments.
85 | args = parse_args()
86 |
87 | pred_contents = [json.loads(line) for line in open(args.pred_path)]
88 |
89 | # Dictionary to store the count of occurrences for each video_id
90 | video_id_counts = {}
91 | new_pred_contents = []
92 |
93 | # Iterate through each sample in pred_contents
94 | for sample in pred_contents:
95 | video_id = sample['video_name']
96 | if video_id in video_id_counts:
97 | video_id_counts[video_id] += 1
98 | else:
99 | video_id_counts[video_id] = 0
100 |
101 | # Create a new sample with the modified key
102 | new_sample = sample
103 | new_sample['video_name'] = f"{video_id}_{video_id_counts[video_id]}"
104 | new_pred_contents.append(new_sample)
105 |
106 | # Generating list of id's and corresponding files
107 | id_list = [x['video_name'] for x in new_pred_contents]
108 | caption_files = [f"{id}.json" for id in id_list]
109 |
110 | output_dir = args.output_dir
111 | # Generate output directory if not exists.
112 | if not os.path.exists(output_dir):
113 | os.makedirs(output_dir)
114 |
115 | # Preparing dictionary of question-answer sets
116 | prediction_set = {}
117 | for sample in new_pred_contents:
118 | id = sample['video_name']
119 | question = sample['Q']
120 | answer = sample['A']
121 | pred = sample['pred']
122 | qa_set = {"q": question, "a": answer, "pred": pred}
123 | prediction_set[id] = qa_set
124 |
125 | num_tasks = args.num_tasks
126 |
127 | # While loop to ensure that all captions are processed.
128 | while True:
129 | try:
130 |             # Files that have already been processed.
131 | completed_files = os.listdir(output_dir)
132 | print(f"completed_files: {len(completed_files)}")
133 |
134 | # Files that have not been processed yet.
135 | incomplete_files = [f for f in caption_files if f not in completed_files]
136 | print(f"incomplete_files: {len(incomplete_files)}")
137 |
138 | # Break the loop when there are no incomplete files
139 | if len(incomplete_files) == 0:
140 | break
141 | if len(incomplete_files) <= num_tasks:
142 | num_tasks = 1
143 |
144 | # Split tasks into parts.
145 | part_len = len(incomplete_files) // num_tasks
146 | all_parts = [incomplete_files[i:i + part_len] for i in range(0, len(incomplete_files), part_len)]
147 | task_args = [(prediction_set, part, args.output_dir, args) for part in all_parts]
148 |
149 | # Use a pool of workers to process the files in parallel.
150 | with Pool() as pool:
151 | pool.starmap(annotate, task_args)
152 |
153 | except Exception as e:
154 | print(f"Error: {e}")
155 |
156 | # Combine all the processed files into one
157 | combined_contents = {}
158 | json_path = args.output_json
159 |
160 | # Iterate through json files
161 | for file_name in os.listdir(output_dir):
162 | if file_name.endswith(".json"):
163 | file_path = os.path.join(output_dir, file_name)
164 | with open(file_path, "r") as json_file:
165 | content = json.load(json_file)
166 | combined_contents[file_name[:-5]] = content
167 |
168 | # Write combined content to a json file
169 | with open(json_path, "w") as json_file:
170 | json.dump(combined_contents, json_file)
171 | print("All evaluation completed!")
172 |
173 | # Calculate average score
174 | score_sum = 0
175 | count = 0
176 | for key, result in combined_contents.items():
177 | count += 1
178 | score_match = result[0]['score']
179 | score = int(score_match)
180 | score_sum += score
181 | average_score = score_sum / count
182 |
183 | print("Average score for detailed orientation:", average_score)
184 |
185 |
186 | if __name__ == "__main__":
187 | main()
188 |
189 |
--------------------------------------------------------------------------------
/scripts/gpt_eval/evaluate_benchmark_3_context.py:
--------------------------------------------------------------------------------
1 | import openai
2 | import os
3 | import argparse
4 | import json
5 | import ast
6 | from multiprocessing.pool import Pool
7 | from tqdm import tqdm
8 |
9 | def parse_args():
10 | parser = argparse.ArgumentParser(description="question-answer-generation-using-gpt-3")
11 |     parser.add_argument("--pred_path", default=r'', help="Path to the file containing predictions.")
12 |     parser.add_argument("--output_dir", default=r'', help="Directory in which to save the per-sample annotation JSON files.")
13 |     parser.add_argument("--output_json", default=r'', help="Path of the final combined annotation JSON file.")
14 |     parser.add_argument("--api_key", default="", help="OpenAI API key.")
15 |     parser.add_argument("--gpt_version", default="gpt-3.5-turbo", type=str, help="GPT model version used for evaluation.")
16 | parser.add_argument("--num_tasks", default=1, type=int, help="Number of splits.")
17 | args = parser.parse_args()
18 | return args
19 |
20 |
21 | def annotate(prediction_set, caption_files, output_dir, args):
22 | """
23 | Evaluates question and answer pairs using GPT-3 and
24 | returns a score for contextual understanding.
25 | """
26 | # Set the OpenAI API key.
27 | openai.api_key = args.api_key
28 | # if args.api_base is not None:
29 | # openai.api_base = args.api_base
30 | for file in caption_files:
31 | key = file[:-5] # Strip file extension
32 | qa_set = prediction_set[key]
33 | question = qa_set['q']
34 | answer = qa_set['a']
35 | pred = qa_set['pred']
36 | try:
37 | # Compute the contextual understanding score
38 | completion = openai.ChatCompletion.create(
39 |                 model=args.gpt_version,
40 | messages=[
41 | {
42 | "role": "system",
43 | "content":
44 | "You are an intelligent chatbot designed for evaluating the contextual understanding of generative outputs for video-based question-answer pairs. "
45 | "Your task is to compare the predicted answer with the correct answer and determine if the generated response aligns with the overall context of the video content. Here's how you can accomplish the task:"
46 | "------"
47 | "##INSTRUCTIONS: "
48 | "- Evaluate whether the predicted answer aligns with the overall context of the video content. It should not provide information that is out of context or misaligned.\n"
49 | "- The predicted answer must capture the main themes and sentiments of the video.\n"
50 | "- Consider synonyms or paraphrases as valid matches.\n"
51 | "- Provide your evaluation of the contextual understanding of the prediction compared to the answer."
52 | },
53 | {
54 | "role": "user",
55 | "content":
56 | "Please evaluate the following video-based question-answer pair:\n\n"
57 | f"Question: {question}\n"
58 | f"Correct Answer: {answer}\n"
59 | f"Predicted Answer: {pred}\n\n"
60 | "Provide your evaluation only as a contextual understanding score where the contextual understanding score is an integer value between 0 and 5, with 5 indicating the highest level of contextual understanding. "
61 |                         "Please generate the response in the form of a Python dictionary string with the key 'score', where its value is the contextual understanding score in INTEGER, not STRING. "
62 |                         "DO NOT PROVIDE ANY OTHER OUTPUT TEXT OR EXPLANATION. Only provide the Python dictionary string. "
63 |                         "For example, your response should look like this: {'score': 4}."
64 | }
65 | ]
66 | )
67 | # Convert response to a Python dictionary.
68 | response_message = completion["choices"][0]["message"]["content"]
69 | response_dict = ast.literal_eval(response_message)
70 | result_qa_pair = [response_dict, qa_set]
71 |
72 | # Save the question-answer pairs to a json file.
73 | with open(f"{output_dir}/{key}.json", "w") as f:
74 | json.dump(result_qa_pair, f)
75 |
76 | except Exception as e:
77 | print(f"Error processing file '{key}': {e}")
78 |
79 |
80 | def main():
81 | """
82 | Main function to control the flow of the program.
83 | """
84 | # Parse arguments.
85 | args = parse_args()
86 |
87 | pred_contents = [json.loads(line) for line in open(args.pred_path)]
88 |
89 | # Dictionary to store the count of occurrences for each video_id
90 | video_id_counts = {}
91 | new_pred_contents = []
92 |
93 | # Iterate through each sample in pred_contents
94 | for sample in pred_contents:
95 | video_id = sample['video_name']
96 | if video_id in video_id_counts:
97 | video_id_counts[video_id] += 1
98 | else:
99 | video_id_counts[video_id] = 0
100 |
101 | # Create a new sample with the modified key
102 | new_sample = sample
103 | new_sample['video_name'] = f"{video_id}_{video_id_counts[video_id]}"
104 | new_pred_contents.append(new_sample)
105 |
106 | # Generating list of id's and corresponding files
107 | id_list = [x['video_name'] for x in new_pred_contents]
108 | caption_files = [f"{id}.json" for id in id_list]
109 |
110 | output_dir = args.output_dir
111 | # Generate output directory if not exists.
112 | if not os.path.exists(output_dir):
113 | os.makedirs(output_dir)
114 |
115 | # Preparing dictionary of question-answer sets
116 | prediction_set = {}
117 | for sample in new_pred_contents:
118 | id = sample['video_name']
119 | question = sample['Q']
120 | answer = sample['A']
121 | pred = sample['pred']
122 | qa_set = {"q": question, "a": answer, "pred": pred}
123 | prediction_set[id] = qa_set
124 |
125 | num_tasks = args.num_tasks
126 |
127 | # While loop to ensure that all captions are processed.
128 | while True:
129 | try:
130 |             # Files that have already been processed.
131 | completed_files = os.listdir(output_dir)
132 | print(f"completed_files: {len(completed_files)}")
133 |
134 | # Files that have not been processed yet.
135 | incomplete_files = [f for f in caption_files if f not in completed_files]
136 | print(f"incomplete_files: {len(incomplete_files)}")
137 |
138 | # Break the loop when there are no incomplete files
139 | if len(incomplete_files) == 0:
140 | break
141 | if len(incomplete_files) <= num_tasks:
142 | num_tasks = 1
143 |
144 | # Split tasks into parts.
145 | part_len = len(incomplete_files) // num_tasks
146 | all_parts = [incomplete_files[i:i + part_len] for i in range(0, len(incomplete_files), part_len)]
147 | task_args = [(prediction_set, part, args.output_dir, args) for part in all_parts]
148 |
149 | # Use a pool of workers to process the files in parallel.
150 | with Pool() as pool:
151 | pool.starmap(annotate, task_args)
152 |
153 | except Exception as e:
154 | print(f"Error: {e}")
155 |
156 | # Combine all the processed files into one
157 | combined_contents = {}
158 | json_path = args.output_json
159 |
160 | # Iterate through json files
161 | for file_name in os.listdir(output_dir):
162 | if file_name.endswith(".json"):
163 | file_path = os.path.join(output_dir, file_name)
164 | with open(file_path, "r") as json_file:
165 | content = json.load(json_file)
166 | combined_contents[file_name[:-5]] = content
167 |
168 | # Write combined content to a json file
169 | with open(json_path, "w") as json_file:
170 | json.dump(combined_contents, json_file)
171 | print("All evaluation completed!")
172 |
173 | # Calculate average score
174 | score_sum = 0
175 | count = 0
176 | for key, result in combined_contents.items():
177 | count += 1
178 | score_match = result[0]['score']
179 | score = int(score_match)
180 | score_sum += score
181 | average_score = score_sum / count
182 |
183 | print("Average score for contextual understanding:", average_score)
184 |
185 |
186 | if __name__ == "__main__":
187 | main()
188 |
189 |
--------------------------------------------------------------------------------
/scripts/gpt_eval/evaluate_benchmark_4_temporal.py:
--------------------------------------------------------------------------------
1 | import openai
2 | import os
3 | import argparse
4 | import json
5 | import ast
6 | from multiprocessing.pool import Pool
7 | from tqdm import tqdm
8 |
9 | def parse_args():
10 | parser = argparse.ArgumentParser(description="question-answer-generation-using-gpt-3")
11 |     parser.add_argument("--pred_path", default=r'', help="Path to the file containing predictions.")
12 |     parser.add_argument("--output_dir", default=r'', help="Directory in which to save the per-sample annotation JSON files.")
13 |     parser.add_argument("--output_json", default=r'', help="Path of the final combined annotation JSON file.")
14 |     parser.add_argument("--api_key", default="", help="OpenAI API key.")
15 |     parser.add_argument("--gpt_version", default="gpt-3.5-turbo", type=str, help="GPT model version used for evaluation.")
16 | parser.add_argument("--num_tasks", default=1, type=int, help="Number of splits.")
17 | args = parser.parse_args()
18 | return args
19 |
20 |
21 | def annotate(prediction_set, caption_files, output_dir, args):
22 | """
23 | Evaluates question and answer pairs using GPT-3 and
24 | returns a score for temporal understanding.
25 | """
26 | # Set the OpenAI API key.
27 | openai.api_key = args.api_key
28 | # if args.api_base is not None:
29 | # openai.api_base = args.api_base
30 | for file in caption_files:
31 | key = file[:-5] # Strip file extension
32 | qa_set = prediction_set[key]
33 | question = qa_set['q']
34 | answer = qa_set['a']
35 | pred = qa_set['pred']
36 | try:
37 | # Compute the temporal understanding score
38 | completion = openai.ChatCompletion.create(
39 |                 model=args.gpt_version,
40 | messages=[
41 | {
42 | "role": "system",
43 | "content":
44 | "You are an intelligent chatbot designed for evaluating the temporal understanding of generative outputs for video-based question-answer pairs. "
45 | "Your task is to compare the predicted answer with the correct answer and determine if they correctly reflect the temporal sequence of events in the video content. Here's how you can accomplish the task:"
46 | "------"
47 | "##INSTRUCTIONS: "
48 | "- Focus on the temporal consistency between the predicted answer and the correct answer. The predicted answer should correctly reflect the sequence of events or details as they are presented in the video content.\n"
49 | "- Consider synonyms or paraphrases as valid matches, but only if the temporal order is maintained.\n"
50 | "- Evaluate the temporal accuracy of the prediction compared to the answer."
51 | },
52 | {
53 | "role": "user",
54 | "content":
55 | "Please evaluate the following video-based question-answer pair:\n\n"
56 | f"Question: {question}\n"
57 | f"Correct Answer: {answer}\n"
58 | f"Predicted Answer: {pred}\n\n"
59 | "Provide your evaluation only as a temporal accuracy score where the temporal accuracy score is an integer value between 0 and 5, with 5 indicating the highest level of temporal consistency. "
60 | "Please generate the response in the form of a Python dictionary string with keys 'score', where its value is the temporal accuracy score in INTEGER, not STRING."
61 | "DO NOT PROVIDE ANY OTHER OUTPUT TEXT OR EXPLANATION. Only provide the Python dictionary string. "
62 | "For example, your response should look like this: {''score': 4.8}."
63 | }
64 | ]
65 | )
66 | # Convert response to a Python dictionary.
67 | response_message = completion["choices"][0]["message"]["content"]
68 | response_dict = ast.literal_eval(response_message)
69 | result_qa_pair = [response_dict, qa_set]
70 |
71 | # Save the question-answer pairs to a json file.
72 | with open(f"{output_dir}/{key}.json", "w") as f:
73 | json.dump(result_qa_pair, f)
74 |
75 | except Exception as e:
76 | print(f"Error processing file '{key}': {e}")
77 |
78 |
79 | def main():
80 | """
81 | Main function to control the flow of the program.
82 | """
83 | # Parse arguments.
84 | args = parse_args()
85 |
86 | pred_contents = [json.loads(line) for line in open(args.pred_path)]
87 |
88 | # Dictionary to store the count of occurrences for each video_id
89 | video_id_counts = {}
90 | new_pred_contents = []
91 |
92 | # Iterate through each sample in pred_contents
93 | for sample in pred_contents:
94 | video_id = sample['video_name']
95 | if video_id in video_id_counts:
96 | video_id_counts[video_id] += 1
97 | else:
98 | video_id_counts[video_id] = 0
99 |
100 | # Create a new sample with the modified key
101 | new_sample = sample
102 | new_sample['video_name'] = f"{video_id}_{video_id_counts[video_id]}"
103 | new_pred_contents.append(new_sample)
104 |
105 | # Generating list of id's and corresponding files
106 | id_list = [x['video_name'] for x in new_pred_contents]
107 | caption_files = [f"{id}.json" for id in id_list]
108 |
109 | output_dir = args.output_dir
110 | # Generate output directory if not exists.
111 | if not os.path.exists(output_dir):
112 | os.makedirs(output_dir)
113 |
114 | # Preparing dictionary of question-answer sets
115 | prediction_set = {}
116 | for sample in new_pred_contents:
117 | id = sample['video_name']
118 | question = sample['Q']
119 | answer = sample['A']
120 | pred = sample['pred']
121 | qa_set = {"q": question, "a": answer, "pred": pred}
122 | prediction_set[id] = qa_set
123 |
124 | num_tasks = args.num_tasks
125 |
126 | # While loop to ensure that all captions are processed.
127 | while True:
128 | try:
129 |             # Files that have already been processed.
130 | completed_files = os.listdir(output_dir)
131 | print(f"completed_files: {len(completed_files)}")
132 |
133 | # Files that have not been processed yet.
134 | incomplete_files = [f for f in caption_files if f not in completed_files]
135 | print(f"incomplete_files: {len(incomplete_files)}")
136 |
137 | # Break the loop when there are no incomplete files
138 | if len(incomplete_files) == 0:
139 | break
140 | if len(incomplete_files) <= num_tasks:
141 | num_tasks = 1
142 |
143 | # Split tasks into parts.
144 | part_len = len(incomplete_files) // num_tasks
145 | all_parts = [incomplete_files[i:i + part_len] for i in range(0, len(incomplete_files), part_len)]
146 | task_args = [(prediction_set, part, args.output_dir, args) for part in all_parts]
147 |
148 | # Use a pool of workers to process the files in parallel.
149 | with Pool() as pool:
150 | pool.starmap(annotate, task_args)
151 |
152 | except Exception as e:
153 | print(f"Error: {e}")
154 |
155 | # Combine all the processed files into one
156 | combined_contents = {}
157 | json_path = args.output_json
158 |
159 | # Iterate through json files
160 | for file_name in os.listdir(output_dir):
161 | if file_name.endswith(".json"):
162 | file_path = os.path.join(output_dir, file_name)
163 | with open(file_path, "r") as json_file:
164 | content = json.load(json_file)
165 | combined_contents[file_name[:-5]] = content
166 |
167 | # Write combined content to a json file
168 | with open(json_path, "w") as json_file:
169 | json.dump(combined_contents, json_file)
170 | print("All evaluation completed!")
171 |
172 | # Calculate average score
173 | score_sum = 0
174 | count = 0
175 | for key, result in combined_contents.items():
176 | count += 1
177 | score_match = result[0]['score']
178 | score = int(score_match)
179 | score_sum += score
180 | average_score = score_sum / count
181 |
182 | print("Average score temporal understanding:", average_score)
183 |
184 |
185 | if __name__ == "__main__":
186 | main()
187 |
188 |
--------------------------------------------------------------------------------
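Note: ast.literal_eval raises as soon as the model wraps the dictionary in any extra text, and the affected file is only retried on the next pass of the while loop. A possible hardening, shown purely as a sketch (the regex fallback below is not part of the script above), extracts the first {...} span before parsing:

import ast
import re

def parse_score(response_message):
    """Parse a "{'score': N}" style reply, tolerating surrounding text."""
    try:
        return ast.literal_eval(response_message.strip())
    except (ValueError, SyntaxError):
        # Fall back to the first {...} span in the reply, if any.
        match = re.search(r"\{.*?\}", response_message, re.DOTALL)
        if match:
            return ast.literal_eval(match.group(0))
        raise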
/scripts/gpt_eval/evaluate_benchmark_5_consistency.py:
--------------------------------------------------------------------------------
1 | import openai
2 | import os
3 | import argparse
4 | import json
5 | import ast
6 | from multiprocessing.pool import Pool
7 | from tqdm import tqdm
8 |
9 | def parse_args():
10 |     parser = argparse.ArgumentParser(description="Consistency evaluation of video QA predictions using GPT-3.5.")
11 | parser.add_argument("--pred_path", default=r'', help="The path to file containing prediction.")
12 | parser.add_argument("--output_dir", default=r'', help="The path to save annotation json files.")
13 | parser.add_argument("--output_json", default=r'', help="The path to save annotation final combined json file.")
14 | parser.add_argument("--api_key", default="", help="OpenAI API key.")
15 | parser.add_argument("--gpt_version", default="gpt-3.5-turbo", type=str, help="OpenAI API base.")
16 | parser.add_argument("--num_tasks", default=1, type=int, help="Number of splits.")
17 | args = parser.parse_args()
18 | return args
19 |
20 |
21 | def annotate(prediction_set, caption_files, output_dir, args):
22 | """
23 |     Evaluates question and answer pairs using GPT-3.5 and
24 | returns a score for consistency.
25 | """
26 | # Set the OpenAI API key.
27 | openai.api_key = args.api_key
28 | # if args.api_base is not None:
29 | # openai.api_base = args.api_base
30 | for file in caption_files:
31 | key = file[:-5] # Strip file extension
32 | qa_set = prediction_set[key]
33 | question1 = qa_set['q1']
34 | question2 = qa_set['q2']
35 | answer = qa_set['a']
36 | pred1 = qa_set['pred1']
37 | pred2 = qa_set['pred2']
38 | try:
39 | # Compute the consistency score
40 | completion = openai.ChatCompletion.create(
41 | model="gpt-3.5-turbo",
42 | messages=[
43 | {
44 | "role": "system",
45 | "content":
46 | "You are an intelligent chatbot designed for evaluating the consistency of generative outputs for similar video-based question-answer pairs. "
47 | "You will be given two very similar questions, a common answer common to both the questions and predicted answers for the two questions ."
48 | "Your task is to compare the predicted answers for two very similar question, with a common correct answer and determine if they are consistent. Here's how you can accomplish the task:"
49 | "------"
50 | "##INSTRUCTIONS: "
51 | "- Focus on the consistency between the two predicted answers and the correct answer. Both predicted answers should correspond to the correct answer and to each other, and should not contain any contradictions or significant differences in the conveyed information.\n"
52 | "- Both predicted answers must be consistent with each other and the correct answer, in terms of the information they provide about the video content.\n"
53 | "- Consider synonyms or paraphrases as valid matches, but only if they maintain the consistency in the conveyed information.\n"
54 | "- Evaluate the consistency of the two predicted answers compared to the correct answer."
55 | },
56 | {
57 | "role": "user",
58 | "content":
59 | "Please evaluate the following video-based question-answer pair:\n\n"
60 | f"Question 1: {question1}\n"
61 | f"Question 2: {question2}\n"
62 | f"Correct Answer: {answer}\n"
63 | f"Predicted Answer to Question 1: {pred1}\n"
64 | f"Predicted Answer to Question 2: {pred2}\n\n"
65 | "Provide your evaluation only as a consistency score where the consistency score is an integer value between 0 and 5, with 5 indicating the highest level of consistency. "
66 | "Please generate the response in the form of a Python dictionary string with keys 'score', where its value is the consistency score in INTEGER, not STRING."
67 | "DO NOT PROVIDE ANY OTHER OUTPUT TEXT OR EXPLANATION. Only provide the Python dictionary string. "
68 | "For example, your response should look like this: {''score': 4.8}."
69 | }
70 | ]
71 | )
72 | # Convert response to a Python dictionary.
73 | response_message = completion["choices"][0]["message"]["content"]
74 | response_dict = ast.literal_eval(response_message)
75 | result_qa_pair = [response_dict, qa_set]
76 |
77 | # Save the question-answer pairs to a json file.
78 | with open(f"{output_dir}/{key}.json", "w") as f:
79 | json.dump(result_qa_pair, f)
80 |
81 | except Exception as e:
82 | print(f"Error processing file '{key}': {e}")
83 |
84 |
85 | def main():
86 | """
87 | Main function to control the flow of the program.
88 | """
89 | # Parse arguments.
90 | args = parse_args()
91 |
92 | pred_contents = [json.loads(line) for line in open(args.pred_path)]
93 |
94 | # Dictionary to store the count of occurrences for each video_id
95 | video_id_counts = {}
96 | new_pred_contents = []
97 |
98 | # Iterate through each sample in pred_contents
99 | for sample in pred_contents:
100 | video_id = sample['video_name']
101 | if video_id in video_id_counts:
102 | video_id_counts[video_id] += 1
103 | else:
104 | video_id_counts[video_id] = 0
105 |
106 | # Create a new sample with the modified key
107 | new_sample = sample
108 | new_sample['video_name'] = f"{video_id}_{video_id_counts[video_id]}"
109 | new_pred_contents.append(new_sample)
110 |
111 | # Generating list of id's and corresponding files
112 | id_list = [x['video_name'] for x in new_pred_contents]
113 | caption_files = [f"{id}.json" for id in id_list]
114 |
115 | output_dir = args.output_dir
116 | # Generate output directory if not exists.
117 | if not os.path.exists(output_dir):
118 | os.makedirs(output_dir)
119 |
120 | # Preparing dictionary of question-answer sets
121 | prediction_set = {}
122 | for sample in new_pred_contents:
123 | id = sample['video_name']
124 | question1 = sample['Q1']
125 | question2 = sample['Q2']
126 | answer = sample['A']
127 | pred1 = sample['pred1']
128 | pred2 = sample['pred2']
129 | qa_set = {"q1": question1, "q2": question2, "a": answer, "pred1": pred1, "pred2": pred2}
130 | prediction_set[id] = qa_set
131 |
132 | num_tasks = args.num_tasks
133 |
134 | # While loop to ensure that all captions are processed.
135 | while True:
136 | try:
137 |             # Files that have already been processed.
138 | completed_files = os.listdir(output_dir)
139 | print(f"completed_files: {len(completed_files)}")
140 |
141 | # Files that have not been processed yet.
142 | incomplete_files = [f for f in caption_files if f not in completed_files]
143 | print(f"incomplete_files: {len(incomplete_files)}")
144 |
145 | # Break the loop when there are no incomplete files
146 | if len(incomplete_files) == 0:
147 | break
148 | if len(incomplete_files) <= num_tasks:
149 | num_tasks = 1
150 |
151 | # Split tasks into parts.
152 | part_len = len(incomplete_files) // num_tasks
153 | all_parts = [incomplete_files[i:i + part_len] for i in range(0, len(incomplete_files), part_len)]
154 | task_args = [(prediction_set, part, args.output_dir, args) for part in all_parts]
155 |
156 | # Use a pool of workers to process the files in parallel.
157 | with Pool() as pool:
158 | pool.starmap(annotate, task_args)
159 |
160 | except Exception as e:
161 | print(f"Error: {e}")
162 |
163 | # Combine all the processed files into one
164 | combined_contents = {}
165 | json_path = args.output_json
166 |
167 | # Iterate through json files
168 | for file_name in os.listdir(output_dir):
169 | if file_name.endswith(".json"):
170 | file_path = os.path.join(output_dir, file_name)
171 | with open(file_path, "r") as json_file:
172 | content = json.load(json_file)
173 | combined_contents[file_name[:-5]] = content
174 |
175 | # Write combined content to a json file
176 | with open(json_path, "w") as json_file:
177 | json.dump(combined_contents, json_file)
178 | print("All evaluation completed!")
179 |
180 | # Calculate average score
181 | score_sum = 0
182 | count = 0
183 | for key, result in combined_contents.items():
184 | count += 1
185 | score_match = result[0]['score']
186 | score = int(score_match)
187 | score_sum += score
188 | average_score = score_sum / count
189 |
190 | print("Average score for consistency:", average_score)
191 |
192 |
193 | if __name__ == "__main__":
194 | main()
195 |
196 |
--------------------------------------------------------------------------------
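Note: all three evaluators read --pred_path as JSON Lines. The consistency script expects the keys 'video_name', 'Q1', 'Q2', 'A', 'pred1' and 'pred2' in every record (the generic/temporal scripts expect 'Q', 'A' and 'pred'). A minimal sketch of writing one such record, with illustrative values and an illustrative file name:

import json

record = {
    "video_name": "v_example",   # repeated names get _0, _1, ... suffixes inside the evaluator
    "Q1": "What happens first in the video?",
    "Q2": "Which event occurs at the beginning of the video?",
    "A": "A man starts mowing the lawn.",
    "pred1": "A man begins mowing the lawn.",
    "pred2": "The man starts cutting the grass.",
}

with open("consistency.jsonl", "a") as f:    # illustrative path
    f.write(json.dumps(record) + "\n")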
/scripts/infer_video/run_benchmark_consistency_qa.sh:
--------------------------------------------------------------------------------
1 | CKPT_NAME="llava-v1.5-7b"
2 | num_frames=4
3 | model_path="ckpt/llava-v1.5-7b"
4 | GPT_Zero_Shot_QA="/bpfs/v2_mnt/VIS/wuwenhao/mllm_data/GPT_Zero_Shot_QA"
5 | video_dir="${GPT_Zero_Shot_QA}/VideoChatGPT_Test_Videos"
6 | gt_file="${GPT_Zero_Shot_QA}/consistency_qa.json"
7 | output_dir="output/Benchmark_Consistency_QA/${CKPT_NAME}_u${num_frames}FRS"
8 |
9 | gpu_list="${CUDA_VISIBLE_DEVICES:-0}"
10 | IFS=',' read -ra GPULIST <<< "$gpu_list"
11 |
12 | CHUNKS=${#GPULIST[@]}
13 |
14 |
15 | for IDX in $(seq 0 $((CHUNKS-1))); do
16 | CUDA_VISIBLE_DEVICES=${GPULIST[$IDX]} python3 llava/eval/run_inference_benchmark_consistency.py \
17 | --video_dir ${video_dir} \
18 | --gt_file ${gt_file} \
19 | --output_dir ${output_dir} \
20 | --output_name ${CHUNKS}_${IDX} \
21 | --model_name ${model_path} \
22 | --num_chunks $CHUNKS \
23 | --num_frames $num_frames \
24 | --conv-mode vicuna_v1 \
25 | --temperature 0 \
26 | --chunk_idx $IDX &
27 | done
28 |
29 | wait
30 |
31 | output_file=${output_dir}/consistency.jsonl
32 |
33 | # Clear out the output file if it exists.
34 | > "$output_file"
35 |
36 | # Loop through the indices and concatenate each file.
37 | for IDX in $(seq 0 $((CHUNKS-1))); do
38 | cat ${output_dir}/${CHUNKS}_${IDX}.json >> "$output_file"
39 | done
--------------------------------------------------------------------------------
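Note: each inference script in this directory launches one process per visible GPU and then concatenates the per-chunk ${CHUNKS}_${IDX}.json files into a single JSON-Lines file. A small sanity check, shown only as a sketch with paths matching the defaults in the script above, confirms that the merge did not drop any records:

import glob
import os

output_dir = "output/Benchmark_Consistency_QA/llava-v1.5-7b_u4FRS"   # matches ${output_dir} above
merged_path = os.path.join(output_dir, "consistency.jsonl")

# Per-chunk outputs are named like 2_0.json, 2_1.json, ...
chunk_lines = sum(sum(1 for _ in open(p)) for p in glob.glob(os.path.join(output_dir, "[0-9]*_[0-9]*.json")))
merged_lines = sum(1 for _ in open(merged_path))
print(f"chunk lines: {chunk_lines}, merged lines: {merged_lines}")
assert chunk_lines == merged_lines, "merged file is missing records"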
/scripts/infer_video/run_benchmark_generic_qa.sh:
--------------------------------------------------------------------------------
1 | CKPT_NAME="llava-v1.5-7b"
2 | num_frames=4
3 | model_path="ckpt/llava-v1.5-7b"
4 | GPT_Zero_Shot_QA="/bpfs/v2_mnt/VIS/wuwenhao/mllm_data/GPT_Zero_Shot_QA"
5 | video_dir="${GPT_Zero_Shot_QA}/VideoChatGPT_Test_Videos"
6 | gt_file="${GPT_Zero_Shot_QA}/generic_qa.json"
7 | output_dir="output/Benchmark_Generic_QA/${CKPT_NAME}_u${num_frames}FRS"
8 |
9 | gpu_list="${CUDA_VISIBLE_DEVICES:-0}"
10 | IFS=',' read -ra GPULIST <<< "$gpu_list"
11 |
12 | CHUNKS=${#GPULIST[@]}
13 |
14 |
15 | for IDX in $(seq 0 $((CHUNKS-1))); do
16 | CUDA_VISIBLE_DEVICES=${GPULIST[$IDX]} python3 llava/eval/run_inference_benchmark_general.py \
17 | --video_dir ${video_dir} \
18 | --gt_file ${gt_file} \
19 | --output_dir ${output_dir} \
20 | --output_name ${CHUNKS}_${IDX} \
21 | --model_name ${model_path} \
22 | --num_chunks $CHUNKS \
23 | --num_frames $num_frames \
24 | --conv-mode vicuna_v1 \
25 | --temperature 0 \
26 | --chunk_idx $IDX &
27 | done
28 |
29 | wait
30 |
31 | output_file=${output_dir}/generic.jsonl
32 |
33 | # Clear out the output file if it exists.
34 | > "$output_file"
35 |
36 | # Loop through the indices and concatenate each file.
37 | for IDX in $(seq 0 $((CHUNKS-1))); do
38 | cat ${output_dir}/${CHUNKS}_${IDX}.json >> "$output_file"
39 | done
--------------------------------------------------------------------------------
/scripts/infer_video/run_benchmark_temporal_qa.sh:
--------------------------------------------------------------------------------
1 | CKPT_NAME="llava-v1.5-7b"
2 | num_frames=4
3 | model_path="ckpt/llava-v1.5-7b"
4 | GPT_Zero_Shot_QA="/bpfs/v2_mnt/VIS/wuwenhao/mllm_data/GPT_Zero_Shot_QA"
5 | video_dir="${GPT_Zero_Shot_QA}/VideoChatGPT_Test_Videos"
6 | gt_file="${GPT_Zero_Shot_QA}/temporal_qa.json"
7 | output_dir="output/Benchmark_Temporal_QA/${CKPT_NAME}_u${num_frames}FRS"
8 |
9 | gpu_list="${CUDA_VISIBLE_DEVICES:-0}"
10 | IFS=',' read -ra GPULIST <<< "$gpu_list"
11 |
12 | CHUNKS=${#GPULIST[@]}
13 |
14 |
15 | for IDX in $(seq 0 $((CHUNKS-1))); do
16 | CUDA_VISIBLE_DEVICES=${GPULIST[$IDX]} python3 llava/eval/run_inference_benchmark_general.py \
17 | --video_dir ${video_dir} \
18 | --gt_file ${gt_file} \
19 | --output_dir ${output_dir} \
20 | --output_name ${CHUNKS}_${IDX} \
21 | --model_name ${model_path} \
22 | --num_chunks $CHUNKS \
23 | --num_frames $num_frames \
24 | --conv-mode vicuna_v1 \
25 | --temperature 0 \
26 | --chunk_idx $IDX &
27 | done
28 |
29 | wait
30 |
31 | output_file=${output_dir}/temporal.jsonl
32 |
33 | # Clear out the output file if it exists.
34 | > "$output_file"
35 |
36 | # Loop through the indices and concatenate each file.
37 | for IDX in $(seq 0 $((CHUNKS-1))); do
38 | cat ${output_dir}/${CHUNKS}_${IDX}.json >> "$output_file"
39 | done
--------------------------------------------------------------------------------
/scripts/infer_video/run_one_video.sh:
--------------------------------------------------------------------------------
1 | CKPT_NAME="llava-v1.5-7b"
2 | num_frames=4
3 | model_path="ckpt/llava-v1.5-7b"
4 | video_path="video_samples/sample_demo_23.mp4"
5 |
6 | CUDA_VISIBLE_DEVICES=0 python3 llava/eval/single_video_inference.py \
7 | --video_path ${video_path} \
8 | --model_name ${model_path} \
9 | --num_frames $num_frames \
10 | --conv-mode vicuna_v1
11 |
12 |
--------------------------------------------------------------------------------
/scripts/infer_video/run_qa_anet_13B.sh:
--------------------------------------------------------------------------------
1 | CKPT_NAME="llava-v1.5-13b"
2 | num_frames=4
3 | model_path="ckpt/llava-v1.5-13b"
4 | GPT_Zero_Shot_QA="/bpfs/v2_mnt/VIS/wuwenhao/mllm_data/GPT_Zero_Shot_QA"
5 | video_dir="${GPT_Zero_Shot_QA}/Activitynet_Zero_Shot_QA/all_test"
6 | gt_file_question="${GPT_Zero_Shot_QA}/Activitynet_Zero_Shot_QA/test_q.json"
7 | gt_file_answers="${GPT_Zero_Shot_QA}/Activitynet_Zero_Shot_QA/test_a.json"
8 | output_dir="output/Activitynet_Zero_Shot_QA/${CKPT_NAME}_u${num_frames}FRS"
9 |
10 | gpu_list="${CUDA_VISIBLE_DEVICES:-0}"
11 | IFS=',' read -ra GPULIST <<< "$gpu_list"
12 |
13 | CHUNKS=${#GPULIST[@]}
14 |
15 |
16 | for IDX in $(seq 0 $((CHUNKS-1))); do
17 | CUDA_VISIBLE_DEVICES=${GPULIST[$IDX]} python3 llava/eval/run_inference_qa.py \
18 | --video_dir ${video_dir} \
19 | --gt_file_question ${gt_file_question} \
20 | --gt_file_answers ${gt_file_answers} \
21 | --output_dir ${output_dir} \
22 | --output_name ${CHUNKS}_${IDX} \
23 | --model_name ${model_path} \
24 | --num_chunks $CHUNKS \
25 | --num_frames $num_frames \
26 | --conv-mode vicuna_v1 \
27 | --temperature 0 \
28 | --chunk_idx $IDX &
29 | done
30 |
31 | wait
32 |
33 | output_file=${output_dir}/merge.jsonl
34 |
35 | # Clear out the output file if it exists.
36 | > "$output_file"
37 |
38 | # Loop through the indices and concatenate each file.
39 | for IDX in $(seq 0 $((CHUNKS-1))); do
40 | cat ${output_dir}/${CHUNKS}_${IDX}.json >> "$output_file"
41 | done
--------------------------------------------------------------------------------
/scripts/infer_video/run_qa_anet_7B.sh:
--------------------------------------------------------------------------------
1 | CKPT_NAME="llava-v1.5-7b"
2 | num_frames=4
3 | model_path="ckpt/llava-v1.5-7b"
4 | GPT_Zero_Shot_QA="/bpfs/v2_mnt/VIS/wuwenhao/mllm_data/GPT_Zero_Shot_QA"
5 | video_dir="${GPT_Zero_Shot_QA}/Activitynet_Zero_Shot_QA/all_test"
6 | gt_file_question="${GPT_Zero_Shot_QA}/Activitynet_Zero_Shot_QA/test_q.json"
7 | gt_file_answers="${GPT_Zero_Shot_QA}/Activitynet_Zero_Shot_QA/test_a.json"
8 | output_dir="output/Activitynet_Zero_Shot_QA/${CKPT_NAME}_u${num_frames}FRS"
9 |
10 | gpu_list="${CUDA_VISIBLE_DEVICES:-0}"
11 | IFS=',' read -ra GPULIST <<< "$gpu_list"
12 |
13 | CHUNKS=${#GPULIST[@]}
14 |
15 |
16 | for IDX in $(seq 0 $((CHUNKS-1))); do
17 | CUDA_VISIBLE_DEVICES=${GPULIST[$IDX]} python3 llava/eval/run_inference_qa.py \
18 | --video_dir ${video_dir} \
19 | --gt_file_question ${gt_file_question} \
20 | --gt_file_answers ${gt_file_answers} \
21 | --output_dir ${output_dir} \
22 | --output_name ${CHUNKS}_${IDX} \
23 | --model_name ${model_path} \
24 | --num_chunks $CHUNKS \
25 | --num_frames $num_frames \
26 | --conv-mode vicuna_v1 \
27 | --temperature 0 \
28 | --chunk_idx $IDX &
29 | done
30 |
31 | wait
32 |
33 | output_file=${output_dir}/merge.jsonl
34 |
35 | # Clear out the output file if it exists.
36 | > "$output_file"
37 |
38 | # Loop through the indices and concatenate each file.
39 | for IDX in $(seq 0 $((CHUNKS-1))); do
40 | cat ${output_dir}/${CHUNKS}_${IDX}.json >> "$output_file"
41 | done
--------------------------------------------------------------------------------
/scripts/infer_video/run_qa_msrvtt_13B.sh:
--------------------------------------------------------------------------------
1 | CKPT_NAME="llava-v1.5-13b"
2 | num_frames=4
3 | model_path="ckpt/llava-v1.5-13b"
4 | GPT_Zero_Shot_QA="/bpfs/v2_mnt/VIS/wuwenhao/mllm_data/GPT_Zero_Shot_QA"
5 | video_dir="${GPT_Zero_Shot_QA}/MSRVTT_Zero_Shot_QA/videos/all"
6 | gt_file_question="${GPT_Zero_Shot_QA}/MSRVTT_Zero_Shot_QA/test_q.json"
7 | gt_file_answers="${GPT_Zero_Shot_QA}/MSRVTT_Zero_Shot_QA/test_a.json"
8 | output_dir="output/MSRVTT_Zero_Shot_QA/${CKPT_NAME}_u${num_frames}FRS"
9 |
10 |
11 | gpu_list="${CUDA_VISIBLE_DEVICES:-0}"
12 | IFS=',' read -ra GPULIST <<< "$gpu_list"
13 |
14 | CHUNKS=${#GPULIST[@]}
15 |
16 |
17 | for IDX in $(seq 0 $((CHUNKS-1))); do
18 | CUDA_VISIBLE_DEVICES=${GPULIST[$IDX]} python3 llava/eval/run_inference_qa.py \
19 | --video_dir ${video_dir} \
20 | --gt_file_question ${gt_file_question} \
21 | --gt_file_answers ${gt_file_answers} \
22 | --output_dir ${output_dir} \
23 | --output_name ${CHUNKS}_${IDX} \
24 | --model_name ${model_path} \
25 | --num_chunks $CHUNKS \
26 | --num_frames $num_frames \
27 | --conv-mode vicuna_v1 \
28 | --temperature 0 \
29 | --chunk_idx $IDX &
30 | done
31 |
32 | wait
33 |
34 | output_file=${output_dir}/merge.jsonl
35 |
36 | # Clear out the output file if it exists.
37 | > "$output_file"
38 |
39 | # Loop through the indices and concatenate each file.
40 | for IDX in $(seq 0 $((CHUNKS-1))); do
41 | cat ${output_dir}/${CHUNKS}_${IDX}.json >> "$output_file"
42 | done
--------------------------------------------------------------------------------
/scripts/infer_video/run_qa_msrvtt_7B.sh:
--------------------------------------------------------------------------------
1 | CKPT_NAME="llava-v1.5-7b"
2 | num_frames=4
3 | model_path="ckpt/llava-v1.5-7b"
4 | GPT_Zero_Shot_QA="/bpfs/v2_mnt/VIS/wuwenhao/mllm_data/GPT_Zero_Shot_QA"
5 | video_dir="${GPT_Zero_Shot_QA}/MSRVTT_Zero_Shot_QA/videos/all"
6 | gt_file_question="${GPT_Zero_Shot_QA}/MSRVTT_Zero_Shot_QA/test_q.json"
7 | gt_file_answers="${GPT_Zero_Shot_QA}/MSRVTT_Zero_Shot_QA/test_a.json"
8 | output_dir="output/MSRVTT_Zero_Shot_QA/${CKPT_NAME}_u${num_frames}FRS"
9 |
10 |
11 | gpu_list="${CUDA_VISIBLE_DEVICES:-0}"
12 | IFS=',' read -ra GPULIST <<< "$gpu_list"
13 |
14 | CHUNKS=${#GPULIST[@]}
15 |
16 |
17 | for IDX in $(seq 0 $((CHUNKS-1))); do
18 | CUDA_VISIBLE_DEVICES=${GPULIST[$IDX]} python3 llava/eval/run_inference_qa.py \
19 | --video_dir ${video_dir} \
20 | --gt_file_question ${gt_file_question} \
21 | --gt_file_answers ${gt_file_answers} \
22 | --output_dir ${output_dir} \
23 | --output_name ${CHUNKS}_${IDX} \
24 | --model_name ${model_path} \
25 | --num_chunks $CHUNKS \
26 | --num_frames $num_frames \
27 | --conv-mode vicuna_v1 \
28 | --temperature 0 \
29 | --chunk_idx $IDX &
30 | done
31 |
32 | wait
33 |
34 | output_file=${output_dir}/merge.jsonl
35 |
36 | # Clear out the output file if it exists.
37 | > "$output_file"
38 |
39 | # Loop through the indices and concatenate each file.
40 | for IDX in $(seq 0 $((CHUNKS-1))); do
41 | cat ${output_dir}/${CHUNKS}_${IDX}.json >> "$output_file"
42 | done
--------------------------------------------------------------------------------
/scripts/infer_video/run_qa_msvd_13B.sh:
--------------------------------------------------------------------------------
1 | CKPT_NAME="llava-v1.5-13b"
2 | num_frames=4
3 | model_path="ckpt/llava-v1.5-13b"
4 | GPT_Zero_Shot_QA="/bpfs/v2_mnt/VIS/wuwenhao/mllm_data/GPT_Zero_Shot_QA"
5 | video_dir="${GPT_Zero_Shot_QA}/MSVD_Zero_Shot_QA/videos"
6 | gt_file_question="${GPT_Zero_Shot_QA}/MSVD_Zero_Shot_QA/test_q.json"
7 | gt_file_answers="${GPT_Zero_Shot_QA}/MSVD_Zero_Shot_QA/test_a.json"
8 | output_dir="output/MSVD_Zero_Shot_QA/${CKPT_NAME}_u${num_frames}FRS"
9 |
10 | gpu_list="${CUDA_VISIBLE_DEVICES:-0}"
11 | IFS=',' read -ra GPULIST <<< "$gpu_list"
12 |
13 | CHUNKS=${#GPULIST[@]}
14 |
15 |
16 | for IDX in $(seq 0 $((CHUNKS-1))); do
17 | CUDA_VISIBLE_DEVICES=${GPULIST[$IDX]} python3 llava/eval/run_inference_qa.py \
18 | --video_dir ${video_dir} \
19 | --gt_file_question ${gt_file_question} \
20 | --gt_file_answers ${gt_file_answers} \
21 | --output_dir ${output_dir} \
22 | --output_name ${CHUNKS}_${IDX} \
23 | --model_name ${model_path} \
24 | --num_chunks $CHUNKS \
25 | --num_frames $num_frames \
26 | --conv-mode vicuna_v1 \
27 | --temperature 0 \
28 | --chunk_idx $IDX &
29 | done
30 |
31 | wait
32 |
33 | output_file=${output_dir}/merge.jsonl
34 |
35 | # Clear out the output file if it exists.
36 | > "$output_file"
37 |
38 | # Loop through the indices and concatenate each file.
39 | for IDX in $(seq 0 $((CHUNKS-1))); do
40 | cat ${output_dir}/${CHUNKS}_${IDX}.json >> "$output_file"
41 | done
--------------------------------------------------------------------------------
/scripts/infer_video/run_qa_msvd_7B.sh:
--------------------------------------------------------------------------------
1 | CKPT_NAME="llava-v1.5-7b"
2 | num_frames=4
3 | model_path="../ckpt/llava-v1.5-7b"
4 | GPT_Zero_Shot_QA="/bpfs/v2_mnt/VIS/wuwenhao/mllm_data/GPT_Zero_Shot_QA"
5 | video_dir="${GPT_Zero_Shot_QA}/MSVD_Zero_Shot_QA/videos"
6 | gt_file_question="${GPT_Zero_Shot_QA}/MSVD_Zero_Shot_QA/test_q.json"
7 | gt_file_answers="${GPT_Zero_Shot_QA}/MSVD_Zero_Shot_QA/test_a.json"
8 | output_dir="output/MSVD_Zero_Shot_QA/${CKPT_NAME}_u${num_frames}FRS"
9 |
10 | gpu_list="${CUDA_VISIBLE_DEVICES:-0}"
11 | IFS=',' read -ra GPULIST <<< "$gpu_list"
12 |
13 | CHUNKS=${#GPULIST[@]}
14 |
15 |
16 | for IDX in $(seq 0 $((CHUNKS-1))); do
17 | CUDA_VISIBLE_DEVICES=${GPULIST[$IDX]} python3 llava/eval/run_inference_qa.py \
18 | --video_dir ${video_dir} \
19 | --gt_file_question ${gt_file_question} \
20 | --gt_file_answers ${gt_file_answers} \
21 | --output_dir ${output_dir} \
22 | --output_name ${CHUNKS}_${IDX} \
23 | --model_name ${model_path} \
24 | --num_chunks $CHUNKS \
25 | --num_frames $num_frames \
26 | --conv-mode vicuna_v1 \
27 | --temperature 0 \
28 | --chunk_idx $IDX &
29 | done
30 |
31 | wait
32 |
33 | output_file=${output_dir}/merge.jsonl
34 |
35 | # Clear out the output file if it exists.
36 | > "$output_file"
37 |
38 | # Loop through the indices and concatenate each file.
39 | for IDX in $(seq 0 $((CHUNKS-1))); do
40 | cat ${output_dir}/${CHUNKS}_${IDX}.json >> "$output_file"
41 | done
--------------------------------------------------------------------------------