├── .gitignore ├── LICENSE ├── README.md ├── ckpt └── .place_image_mllm_ckpt ├── cog.yaml ├── figs ├── ben.png ├── pipeline.png ├── sota.png ├── table1.png ├── table2.png ├── table3.png └── vis.png ├── llava ├── __init__.py ├── constants.py ├── conversation.py ├── eval │ ├── model_utils.py │ ├── run_inference_benchmark_consistency.py │ ├── run_inference_benchmark_general.py │ ├── run_inference_qa.py │ └── single_video_inference.py ├── mm_utils.py ├── model │ ├── __init__.py │ ├── apply_delta.py │ ├── builder.py │ ├── consolidate.py │ ├── language_model │ │ ├── llava_llama.py │ │ ├── llava_mistral.py │ │ └── llava_mpt.py │ ├── llava_arch.py │ ├── make_delta.py │ ├── multimodal_encoder │ │ ├── builder.py │ │ └── clip_encoder.py │ ├── multimodal_projector │ │ └── builder.py │ └── utils.py └── utils.py └── scripts ├── gpt_eval ├── eval_qa_activitynet.sh ├── eval_qa_benchmark.sh ├── eval_qa_msrvtt.sh ├── eval_qa_msvd.sh ├── eval_video_qa.py ├── evaluate_benchmark_1_correctness.py ├── evaluate_benchmark_2_detailed_orientation.py ├── evaluate_benchmark_3_context.py ├── evaluate_benchmark_4_temporal.py └── evaluate_benchmark_5_consistency.py └── infer_video ├── run_benchmark_consistency_qa.sh ├── run_benchmark_generic_qa.sh ├── run_benchmark_temporal_qa.sh ├── run_one_video.sh ├── run_qa_anet_13B.sh ├── run_qa_anet_7B.sh ├── run_qa_msrvtt_13B.sh ├── run_qa_msrvtt_7B.sh ├── run_qa_msvd_13B.sh └── run_qa_msvd_7B.sh /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 
95 | #Pipfile.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 102 | #poetry.lock 103 | 104 | # pdm 105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 106 | #pdm.lock 107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 108 | # in version control. 109 | # https://pdm.fming.dev/#use-with-ide 110 | .pdm.toml 111 | 112 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 113 | __pypackages__/ 114 | 115 | # Celery stuff 116 | celerybeat-schedule 117 | celerybeat.pid 118 | 119 | # SageMath parsed files 120 | *.sage.py 121 | 122 | # Environments 123 | .env 124 | .venv 125 | env/ 126 | venv/ 127 | ENV/ 128 | env.bak/ 129 | venv.bak/ 130 | 131 | # Spyder project settings 132 | .spyderproject 133 | .spyproject 134 | 135 | # Rope project settings 136 | .ropeproject 137 | 138 | # mkdocs documentation 139 | /site 140 | 141 | # mypy 142 | .mypy_cache/ 143 | .dmypy.json 144 | dmypy.json 145 | 146 | # Pyre type checker 147 | .pyre/ 148 | 149 | # pytype static type analyzer 150 | .pytype/ 151 | 152 | # Cython debug symbols 153 | cython_debug/ 154 | 155 | # PyCharm 156 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 157 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 158 | # and can be added to the global gitignore or merged into this file. For a more nuclear 159 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 160 | #.idea/ 161 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 
29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. 
If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. 
Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 
202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |
2 | 3 |

## FreeVA: Offline MLLM as Training-Free Video Assistant

4 | 5 | 6 | [![arXiv](https://img.shields.io/badge/Arxiv-2405.07798-b31b1b.svg?logo=arXiv)](https://arxiv.org/abs/2405.07798) 7 | 8 | #### [Wenhao Wu](https://whwu95.github.io/) 9 | 10 | #### [The University of Sydney](https://www.sydney.edu.au/) 11 | 12 |
If you like our project, please give us a star ⭐ on GitHub for the latest updates.
13 |
14 | 15 | *** 16 | 17 | 18 | Welcome to **FreeVA** - a plug-and-play, simple yet effective study exploring the use of existing image MLLMs as video conversational models in a training-free manner. ⚡The core code can be just one line! 19 | 20 | 21 | ## Main Take-aways💡 22 | The study provides an essential, must-know baseline and reveals several surprising findings: 23 | 1) 😄FreeVA, leveraging only an offline image-based MLLM without additional training, excels in zero-shot video question-answering (e.g., MSVD-QA, ActivityNet-QA, and MSRVTT-QA), even surpassing state-of-the-art methods that involve video instruction tuning. 24 | 2) 🤔While mainstream video-based MLLMs typically initialize with an image-based MLLM (e.g., LLaVA) and then fine-tune using video instruction tuning, the study indicates that using the widely adopted VideoInstruct-100K for video instruction tuning doesn't actually lead to better performance than not training at all. 25 | 3) ⚠️The commonly used evaluation metrics in existing works are significantly influenced by changes in the GPT-3.5 API version over time. If ignored, this could affect the fairness and uniformity of comparisons between different methods and impact the analysis and judgment of researchers in the field. 26 | 27 | 28 | 29 | ## 📢News 30 | - [x] **[Jun 7, 2024]** FreeVA results for more MLLMs, such as [InstructBLIP](https://github.com/salesforce/LAVIS/tree/main/projects/instructblip), [InternVL](https://github.com/OpenGVLab/InternVL), and [Dense Connector](https://github.com/HJYao00/DenseConnector), are provided. 31 | - [x] **[May 14, 2024]** [Preprint](https://arxiv.org/pdf/2405.07798) has been released. 32 | - [x] **[May 13, 2024]** Code has been released. Thanks for your star 😝. 33 | 34 | 35 | 36 | ## Overview 37 | 38 |
39 | 40 | An illustration of (a) an overview of the image MLLM inference process and (b) our proposed FreeVA for zero-shot video inference using existing image MLLMs. 41 | 42 | 43 | 44 |
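To make the idea concrete, the temporal aggregation that FreeVA adds on top of an off-the-shelf image MLLM can be sketched in a few lines. This is an illustrative sketch only: the tensor shapes and the mean-pooling choice are assumptions made for the example, and the actual implementation lives in `llava/model/llava_arch.py`.

```python
# Illustrative sketch, not the repository's exact code: pool per-frame visual
# tokens into one image-like token set that an off-the-shelf image MLLM can consume.
import torch


def temporal_aggregation(frame_features: torch.Tensor) -> torch.Tensor:
    """Average visual tokens over the time axis.

    frame_features: (T, N, C) tensor holding N visual tokens of dimension C
    for each of T sampled frames. Returns an (N, C) tensor that the image MLLM
    can treat as if it came from a single image.
    """
    return frame_features.mean(dim=0)


# Example: 4 sampled frames, 576 patch tokens of width 1024 from the vision encoder.
pooled = temporal_aggregation(torch.randn(4, 576, 1024))
print(pooled.shape)  # torch.Size([576, 1024])
```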
45 | 46 | 47 | 48 | 49 | ## Results 50 | 51 | 52 | 53 | ### Quantitative Results 54 | 55 |
56 | 57 | 58 | 59 |
60 | 61 | ### Qualitative Results 62 |
63 | 64 |
65 | 66 | 67 | ## Empirical Study📊 68 | 69 |
70 | 71 | 72 | 73 |
74 | 75 | 76 | ## Running Video QA💬 77 | 78 | FreeVA can be applied to any image-based MLLM, and its core code is straightforward, simply involving a temporal aggregation. Please refer to [temporal_aggregation](./llava/model/llava_arch.py#L148) for implementation details. 79 | 80 | Below, we provide guidance on running the code using LLaVA-1.5 as an example. 81 | 82 | Before running: 83 | 1) Please refer to [cog.yaml](./cog.yaml) for environment configuration regarding LLaVA-1.5. 84 | 2) Please download the LLaVA model in advance and place it in the "ckpt" folder, for example, "llava-v1.5-7b" or "llava-v1.5-13b". 85 | 3) Please refer to [Video-ChatGPT](https://github.com/mbzuai-oryx/Video-ChatGPT) for downloading the evaluation dataset and corresponding annotations. 86 | 87 | ### Zero-shot Video Question-Answering 88 | To enhance evaluation efficiency, we provide a script for single-machine *multi-GPU* evaluation. Taking the ActivityNet-QA dataset as an example, the specific steps to run the script are as follows: 89 | 90 | **Step 1: Obtain the prediction file.** 91 | ```sh 92 | CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 bash scripts/infer_video/run_qa_anet_7B.sh 93 | ``` 94 | You will get a prediction file, *merge.jsonl*. 95 | 96 | **Step 2: GPT-assisted evaluation** 97 | 98 | Running the following command will provide you with accuracy and score. 99 | Before running, please fill in your OpenAI API Key, the prediction file path from Step 1, the number of worker processes for multiprocessing (to accelerate inference), and the version number of GPT-3.5. 100 | 101 | ⚠️*Note: The default version of gpt-3.5-turbo has been updated three times in chronological order: gpt-3.5-turbo-0301, gpt-3.5-turbo-0613, gpt-3.5-turbo-0125, with significant performance differences between versions.* 102 | ```sh 103 | bash scripts/gpt_eval/eval_qa_activitynet.sh 104 | ``` 105 | 106 | *The evaluation process for the other datasets (MSRVTT-QA, MSVD-QA) follows the same procedure. Please refer to the steps outlined above.* 107 | 108 | ### Video-Based Text Generation Performance 109 | 110 | The generative performance benchmark includes five evaluation metrics: Correctness of Information, Detail Orientation, Contextual Understanding, Temporal Understanding, and Consistency. 111 | 112 | **Step 1: Obtain the prediction files.** 113 | ```sh 114 | CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 bash scripts/infer_video/run_benchmark_generic_qa.sh 115 | CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 bash scripts/infer_video/run_benchmark_temporal_qa.sh 116 | CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 bash scripts/infer_video/run_benchmark_consistency_qa.sh 117 | ``` 118 | You will get the prediction files *generic.jsonl*, *temporal.jsonl*, and *consistency.jsonl*, respectively. 119 | 120 | 121 | **Step 2: GPT-assisted evaluation** 122 | 123 | Running the following script will generate all five metrics. 124 | 125 | Before running, please fill in your OpenAI API Key, the prediction file paths from Step 1, the number of worker processes for multiprocessing (to accelerate inference), and the version number of GPT-3.5.
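For reference, each worker process essentially issues a request of the following shape to the chosen GPT-3.5 version. This is a minimal sketch assuming the legacy `openai`-python (<1.0) interface; the exact prompt, response parsing, and multiprocessing logic live in the scripts under `scripts/gpt_eval/` and may differ. The question/answer strings below are purely hypothetical placeholders.

```python
# Illustrative sketch of a GPT-assisted scoring request, not the repository's exact code.
import openai

openai.api_key = "YOUR_OPENAI_API_KEY"    # fill in your OpenAI API key
gpt_version = "gpt-3.5-turbo-0613"        # pin the GPT-3.5 version explicitly

completion = openai.ChatCompletion.create(
    model=gpt_version,
    messages=[
        {"role": "system", "content": "Rate how well the predicted answer matches the correct answer. Reply with yes or no and an integer score from 0 to 5."},
        {"role": "user", "content": "Question: What is the man doing?\nCorrect Answer: playing guitar\nPredicted Answer: The man is playing a guitar."},
    ],
)
print(completion["choices"][0]["message"]["content"])
```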
126 | 127 | ```sh 128 | bash scripts/gpt_eval/eval_qa_benchmark.sh 129 | ``` 130 | 131 | ## Acknowledgement🙏 132 | We extend our sincere gratitude to the following awesome projects: 133 | - [LLaVA](https://github.com/haotian-liu/LLaVA): Visual Instruction Tuning 134 | - [Video-ChatGPT](https://github.com/mbzuai-oryx/Video-ChatGPT): Towards Detailed Video Understanding via Large Vision and Language Models 135 | 136 | 137 | ## BibTeX & Citation 138 | 139 | If you use our code in your research or wish to refer to the results, please star 🌟 this repo and use the following BibTeX 📑 entry. 140 | 141 | ```bibtex 142 | @article{FreeVA, 143 | title={FreeVA: Offline MLLM as Training-Free Video Assistant}, 144 | author={Wu, Wenhao}, 145 | booktitle={arXiv preprint arXiv:2405.07798}, 146 | year={2024} 147 | } 148 | 149 | -------------------------------------------------------------------------------- /ckpt/.place_image_mllm_ckpt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/whwu95/FreeVA/ddc838dbfcc4355630525582e060081198a33124/ckpt/.place_image_mllm_ckpt -------------------------------------------------------------------------------- /cog.yaml: -------------------------------------------------------------------------------- 1 | # Configuration for Cog ⚙️ 2 | # Reference: https://github.com/replicate/cog/blob/main/docs/yaml.md 3 | 4 | build: 5 | gpu: true 6 | 7 | python_version: "3.11" 8 | 9 | python_packages: 10 | - "torch==2.0.1" 11 | - "accelerate==0.21.0" 12 | - "bitsandbytes==0.41.0" 13 | - "deepspeed==0.9.5" 14 | - "einops-exts==0.0.4" 15 | - "einops==0.6.1" 16 | - "gradio==3.35.2" 17 | - "gradio_client==0.2.9" 18 | - "httpx==0.24.0" 19 | - "markdown2==2.4.10" 20 | - "numpy==1.26.0" 21 | - "peft==0.4.0" 22 | - "scikit-learn==1.2.2" 23 | - "sentencepiece==0.1.99" 24 | - "shortuuid==1.0.11" 25 | - "timm==0.6.13" 26 | - "tokenizers==0.13.3" 27 | - "torch==2.0.1" 28 | - "torchvision==0.15.2" 29 | - "transformers==4.31.0" 30 | - "wandb==0.15.12" 31 | - "wavedrom==2.0.3.post3" 32 | - "Pygments==2.16.1" 33 | run: 34 | - curl -o /usr/local/bin/pget -L "https://github.com/replicate/pget/releases/download/v0.0.3/pget" && chmod +x /usr/local/bin/pget 35 | 36 | # predict.py defines how predictions are run on your model 37 | predict: "predict.py:Predictor" 38 | -------------------------------------------------------------------------------- /figs/ben.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/whwu95/FreeVA/ddc838dbfcc4355630525582e060081198a33124/figs/ben.png -------------------------------------------------------------------------------- /figs/pipeline.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/whwu95/FreeVA/ddc838dbfcc4355630525582e060081198a33124/figs/pipeline.png -------------------------------------------------------------------------------- /figs/sota.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/whwu95/FreeVA/ddc838dbfcc4355630525582e060081198a33124/figs/sota.png -------------------------------------------------------------------------------- /figs/table1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/whwu95/FreeVA/ddc838dbfcc4355630525582e060081198a33124/figs/table1.png 
-------------------------------------------------------------------------------- /figs/table2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/whwu95/FreeVA/ddc838dbfcc4355630525582e060081198a33124/figs/table2.png -------------------------------------------------------------------------------- /figs/table3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/whwu95/FreeVA/ddc838dbfcc4355630525582e060081198a33124/figs/table3.png -------------------------------------------------------------------------------- /figs/vis.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/whwu95/FreeVA/ddc838dbfcc4355630525582e060081198a33124/figs/vis.png -------------------------------------------------------------------------------- /llava/__init__.py: -------------------------------------------------------------------------------- 1 | from .model import LlavaLlamaForCausalLM 2 | -------------------------------------------------------------------------------- /llava/constants.py: -------------------------------------------------------------------------------- 1 | CONTROLLER_HEART_BEAT_EXPIRATION = 30 2 | WORKER_HEART_BEAT_INTERVAL = 15 3 | 4 | LOGDIR = "." 5 | 6 | # Model Constants 7 | IGNORE_INDEX = -100 8 | IMAGE_TOKEN_INDEX = -200 9 | DEFAULT_IMAGE_TOKEN = "<image>" 10 | DEFAULT_IMAGE_PATCH_TOKEN = "<im_patch>" 11 | DEFAULT_IM_START_TOKEN = "<im_start>" 12 | DEFAULT_IM_END_TOKEN = "<im_end>" 13 | IMAGE_PLACEHOLDER = "<image-placeholder>" -------------------------------------------------------------------------------- /llava/conversation.py: -------------------------------------------------------------------------------- 1 | import dataclasses 2 | from enum import auto, Enum 3 | from typing import List, Tuple 4 | import base64 5 | from io import BytesIO 6 | from PIL import Image 7 | 8 | 9 | class SeparatorStyle(Enum): 10 | """Different separator style.""" 11 | SINGLE = auto() 12 | TWO = auto() 13 | MPT = auto() 14 | PLAIN = auto() 15 | LLAMA_2 = auto() 16 | 17 | 18 | @dataclasses.dataclass 19 | class Conversation: 20 | """A class that keeps all conversation history.""" 21 | system: str 22 | roles: List[str] 23 | messages: List[List[str]] 24 | offset: int 25 | sep_style: SeparatorStyle = SeparatorStyle.SINGLE 26 | sep: str = "###" 27 | sep2: str = None 28 | version: str = "Unknown" 29 | 30 | skip_next: bool = False 31 | 32 | def get_prompt(self): 33 | messages = self.messages 34 | if len(messages) > 0 and type(messages[0][1]) is tuple: 35 | messages = self.messages.copy() 36 | init_role, init_msg = messages[0].copy() 37 | init_msg = init_msg[0].replace("<image>", "").strip() 38 | if 'mmtag' in self.version: 39 | messages[0] = (init_role, init_msg) 40 | messages.insert(0, (self.roles[0], "<Image><image></Image>")) 41 | messages.insert(1, (self.roles[1], "Received.")) 42 | else: 43 | messages[0] = (init_role, "<image>\n" + init_msg) 44 | 45 | if self.sep_style == SeparatorStyle.SINGLE: 46 | ret = self.system + self.sep 47 | for role, message in messages: 48 | if message: 49 | if type(message) is tuple: 50 | message, _, _ = message 51 | ret += role + ": " + message + self.sep 52 | else: 53 | ret += role + ":" 54 | elif self.sep_style == SeparatorStyle.TWO: 55 | seps = [self.sep, self.sep2] 56 | ret = self.system + seps[0] 57 | for i, (role, message) in enumerate(messages): 58 | if message: 59 | if type(message) is tuple: 60 | message, _, _ = message 61 | ret += role + ": " + message + seps[i % 2]
62 | else: 63 | ret += role + ":" 64 | elif self.sep_style == SeparatorStyle.MPT: 65 | ret = self.system + self.sep 66 | for role, message in messages: 67 | if message: 68 | if type(message) is tuple: 69 | message, _, _ = message 70 | ret += role + message + self.sep 71 | else: 72 | ret += role 73 | elif self.sep_style == SeparatorStyle.LLAMA_2: 74 | wrap_sys = lambda msg: f"<>\n{msg}\n<>\n\n" if len(msg) > 0 else msg 75 | wrap_inst = lambda msg: f"[INST] {msg} [/INST]" 76 | ret = "" 77 | 78 | for i, (role, message) in enumerate(messages): 79 | if i == 0: 80 | assert message, "first message should not be none" 81 | assert role == self.roles[0], "first message should come from user" 82 | if message: 83 | if type(message) is tuple: 84 | message, _, _ = message 85 | if i == 0: message = wrap_sys(self.system) + message 86 | if i % 2 == 0: 87 | message = wrap_inst(message) 88 | ret += self.sep + message 89 | else: 90 | ret += " " + message + " " + self.sep2 91 | else: 92 | ret += "" 93 | ret = ret.lstrip(self.sep) 94 | elif self.sep_style == SeparatorStyle.PLAIN: 95 | seps = [self.sep, self.sep2] 96 | ret = self.system 97 | for i, (role, message) in enumerate(messages): 98 | if message: 99 | if type(message) is tuple: 100 | message, _, _ = message 101 | ret += message + seps[i % 2] 102 | else: 103 | ret += "" 104 | else: 105 | raise ValueError(f"Invalid style: {self.sep_style}") 106 | 107 | return ret 108 | 109 | def append_message(self, role, message): 110 | self.messages.append([role, message]) 111 | 112 | def process_image(self, image, image_process_mode, return_pil=False, image_format='PNG', max_len=1344, min_len=672): 113 | if image_process_mode == "Pad": 114 | def expand2square(pil_img, background_color=(122, 116, 104)): 115 | width, height = pil_img.size 116 | if width == height: 117 | return pil_img 118 | elif width > height: 119 | result = Image.new(pil_img.mode, (width, width), background_color) 120 | result.paste(pil_img, (0, (width - height) // 2)) 121 | return result 122 | else: 123 | result = Image.new(pil_img.mode, (height, height), background_color) 124 | result.paste(pil_img, ((height - width) // 2, 0)) 125 | return result 126 | image = expand2square(image) 127 | elif image_process_mode in ["Default", "Crop"]: 128 | pass 129 | elif image_process_mode == "Resize": 130 | image = image.resize((336, 336)) 131 | else: 132 | raise ValueError(f"Invalid image_process_mode: {image_process_mode}") 133 | if max(image.size) > max_len: 134 | max_hw, min_hw = max(image.size), min(image.size) 135 | aspect_ratio = max_hw / min_hw 136 | shortest_edge = int(min(max_len / aspect_ratio, min_len, min_hw)) 137 | longest_edge = int(shortest_edge * aspect_ratio) 138 | W, H = image.size 139 | if H > W: 140 | H, W = longest_edge, shortest_edge 141 | else: 142 | H, W = shortest_edge, longest_edge 143 | image = image.resize((W, H)) 144 | if return_pil: 145 | return image 146 | else: 147 | buffered = BytesIO() 148 | image.save(buffered, format=image_format) 149 | img_b64_str = base64.b64encode(buffered.getvalue()).decode() 150 | return img_b64_str 151 | 152 | def get_images(self, return_pil=False): 153 | images = [] 154 | for i, (role, msg) in enumerate(self.messages[self.offset:]): 155 | if i % 2 == 0: 156 | if type(msg) is tuple: 157 | msg, image, image_process_mode = msg 158 | image = self.process_image(image, image_process_mode, return_pil=return_pil) 159 | images.append(image) 160 | return images 161 | 162 | def to_gradio_chatbot(self): 163 | ret = [] 164 | for i, (role, msg) in 
enumerate(self.messages[self.offset:]): 165 | if i % 2 == 0: 166 | if type(msg) is tuple: 167 | msg, image, image_process_mode = msg 168 | img_b64_str = self.process_image( 169 | image, "Default", return_pil=False, 170 | image_format='JPEG') 171 | img_str = f'user upload image' 172 | msg = img_str + msg.replace('', '').strip() 173 | ret.append([msg, None]) 174 | else: 175 | ret.append([msg, None]) 176 | else: 177 | ret[-1][-1] = msg 178 | return ret 179 | 180 | def copy(self): 181 | return Conversation( 182 | system=self.system, 183 | roles=self.roles, 184 | messages=[[x, y] for x, y in self.messages], 185 | offset=self.offset, 186 | sep_style=self.sep_style, 187 | sep=self.sep, 188 | sep2=self.sep2, 189 | version=self.version) 190 | 191 | def dict(self): 192 | if len(self.get_images()) > 0: 193 | return { 194 | "system": self.system, 195 | "roles": self.roles, 196 | "messages": [[x, y[0] if type(y) is tuple else y] for x, y in self.messages], 197 | "offset": self.offset, 198 | "sep": self.sep, 199 | "sep2": self.sep2, 200 | } 201 | return { 202 | "system": self.system, 203 | "roles": self.roles, 204 | "messages": self.messages, 205 | "offset": self.offset, 206 | "sep": self.sep, 207 | "sep2": self.sep2, 208 | } 209 | 210 | 211 | conv_vicuna_v0 = Conversation( 212 | system="A chat between a curious human and an artificial intelligence assistant. " 213 | "The assistant gives helpful, detailed, and polite answers to the human's questions.", 214 | roles=("Human", "Assistant"), 215 | messages=( 216 | ("Human", "What are the key differences between renewable and non-renewable energy sources?"), 217 | ("Assistant", 218 | "Renewable energy sources are those that can be replenished naturally in a relatively " 219 | "short amount of time, such as solar, wind, hydro, geothermal, and biomass. " 220 | "Non-renewable energy sources, on the other hand, are finite and will eventually be " 221 | "depleted, such as coal, oil, and natural gas. Here are some key differences between " 222 | "renewable and non-renewable energy sources:\n" 223 | "1. Availability: Renewable energy sources are virtually inexhaustible, while non-renewable " 224 | "energy sources are finite and will eventually run out.\n" 225 | "2. Environmental impact: Renewable energy sources have a much lower environmental impact " 226 | "than non-renewable sources, which can lead to air and water pollution, greenhouse gas emissions, " 227 | "and other negative effects.\n" 228 | "3. Cost: Renewable energy sources can be more expensive to initially set up, but they typically " 229 | "have lower operational costs than non-renewable sources.\n" 230 | "4. Reliability: Renewable energy sources are often more reliable and can be used in more remote " 231 | "locations than non-renewable sources.\n" 232 | "5. Flexibility: Renewable energy sources are often more flexible and can be adapted to different " 233 | "situations and needs, while non-renewable sources are more rigid and inflexible.\n" 234 | "6. Sustainability: Renewable energy sources are more sustainable over the long term, while " 235 | "non-renewable sources are not, and their depletion can lead to economic and social instability.\n") 236 | ), 237 | offset=2, 238 | sep_style=SeparatorStyle.SINGLE, 239 | sep="###", 240 | ) 241 | 242 | conv_vicuna_v1 = Conversation( 243 | system="A chat between a curious user and an artificial intelligence assistant. 
" 244 | "The assistant gives helpful, detailed, and polite answers to the user's questions.", 245 | roles=("USER", "ASSISTANT"), 246 | version="v1", 247 | messages=(), 248 | offset=0, 249 | sep_style=SeparatorStyle.TWO, 250 | sep=" ", 251 | sep2="", 252 | ) 253 | 254 | conv_llama_2 = Conversation( 255 | system="""You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature. 256 | 257 | If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information.""", 258 | roles=("USER", "ASSISTANT"), 259 | version="llama_v2", 260 | messages=(), 261 | offset=0, 262 | sep_style=SeparatorStyle.LLAMA_2, 263 | sep="", 264 | sep2="", 265 | ) 266 | 267 | conv_llava_llama_2 = Conversation( 268 | system="You are a helpful language and vision assistant. " 269 | "You are able to understand the visual content that the user provides, " 270 | "and assist the user with a variety of tasks using natural language.", 271 | roles=("USER", "ASSISTANT"), 272 | version="llama_v2", 273 | messages=(), 274 | offset=0, 275 | sep_style=SeparatorStyle.LLAMA_2, 276 | sep="", 277 | sep2="", 278 | ) 279 | 280 | conv_mpt = Conversation( 281 | system="""<|im_start|>system 282 | A conversation between a user and an LLM-based AI assistant. The assistant gives helpful and honest answers.""", 283 | roles=("<|im_start|>user\n", "<|im_start|>assistant\n"), 284 | version="mpt", 285 | messages=(), 286 | offset=0, 287 | sep_style=SeparatorStyle.MPT, 288 | sep="<|im_end|>", 289 | ) 290 | 291 | conv_llava_plain = Conversation( 292 | system="", 293 | roles=("", ""), 294 | messages=( 295 | ), 296 | offset=0, 297 | sep_style=SeparatorStyle.PLAIN, 298 | sep="\n", 299 | ) 300 | 301 | conv_llava_v0 = Conversation( 302 | system="A chat between a curious human and an artificial intelligence assistant. " 303 | "The assistant gives helpful, detailed, and polite answers to the human's questions.", 304 | roles=("Human", "Assistant"), 305 | messages=( 306 | ), 307 | offset=0, 308 | sep_style=SeparatorStyle.SINGLE, 309 | sep="###", 310 | ) 311 | 312 | conv_llava_v0_mmtag = Conversation( 313 | system="A chat between a curious user and an artificial intelligence assistant. " 314 | "The assistant is able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language." 315 | "The visual content will be provided with the following format: visual content.", 316 | roles=("Human", "Assistant"), 317 | messages=( 318 | ), 319 | offset=0, 320 | sep_style=SeparatorStyle.SINGLE, 321 | sep="###", 322 | version="v0_mmtag", 323 | ) 324 | 325 | conv_llava_v1 = Conversation( 326 | system="A chat between a curious human and an artificial intelligence assistant. " 327 | "The assistant gives helpful, detailed, and polite answers to the human's questions.", 328 | roles=("USER", "ASSISTANT"), 329 | version="v1", 330 | messages=(), 331 | offset=0, 332 | sep_style=SeparatorStyle.TWO, 333 | sep=" ", 334 | sep2="", 335 | ) 336 | 337 | conv_llava_v1_mmtag = Conversation( 338 | system="A chat between a curious user and an artificial intelligence assistant. 
" 339 | "The assistant is able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language." 340 | "The visual content will be provided with the following format: visual content.", 341 | roles=("USER", "ASSISTANT"), 342 | messages=(), 343 | offset=0, 344 | sep_style=SeparatorStyle.TWO, 345 | sep=" ", 346 | sep2="", 347 | version="v1_mmtag", 348 | ) 349 | 350 | conv_mistral_instruct = Conversation( 351 | system="", 352 | roles=("USER", "ASSISTANT"), 353 | version="llama_v2", 354 | messages=(), 355 | offset=0, 356 | sep_style=SeparatorStyle.LLAMA_2, 357 | sep="", 358 | sep2="", 359 | ) 360 | 361 | conv_chatml_direct = Conversation( 362 | system="""<|im_start|>system 363 | Answer the questions.""", 364 | roles=("<|im_start|>user\n", "<|im_start|>assistant\n"), 365 | version="mpt", 366 | messages=(), 367 | offset=0, 368 | sep_style=SeparatorStyle.MPT, 369 | sep="<|im_end|>", 370 | ) 371 | 372 | default_conversation = conv_vicuna_v1 373 | conv_templates = { 374 | "default": conv_vicuna_v0, 375 | "v0": conv_vicuna_v0, 376 | "v1": conv_vicuna_v1, 377 | "vicuna_v1": conv_vicuna_v1, 378 | "llama_2": conv_llama_2, 379 | "mistral_instruct": conv_mistral_instruct, 380 | "chatml_direct": conv_chatml_direct, 381 | "mistral_direct": conv_chatml_direct, 382 | 383 | "plain": conv_llava_plain, 384 | "v0_plain": conv_llava_plain, 385 | "llava_v0": conv_llava_v0, 386 | "v0_mmtag": conv_llava_v0_mmtag, 387 | "llava_v1": conv_llava_v1, 388 | "v1_mmtag": conv_llava_v1_mmtag, 389 | "llava_llama_2": conv_llava_llama_2, 390 | 391 | "mpt": conv_mpt, 392 | } 393 | 394 | 395 | if __name__ == "__main__": 396 | print(default_conversation.get_prompt()) 397 | -------------------------------------------------------------------------------- /llava/eval/model_utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from PIL import Image 3 | from decord import VideoReader, cpu 4 | 5 | 6 | 7 | def load_video(vis_path, n_clips=1, num_frm=100): 8 | """ 9 | Load video frames from a video file. 10 | 11 | Parameters: 12 | vis_path (str): Path to the video file. 13 | n_clips (int): Number of clips to extract from the video. Defaults to 1. 14 | num_frm (int): Number of frames to extract from each clip. Defaults to 100. 15 | 16 | Returns: 17 | list: List of PIL.Image.Image objects representing video frames. 18 | """ 19 | 20 | # decord.bridge.set_bridge('torch') 21 | # Load video with VideoReader 22 | vr = VideoReader(vis_path, ctx=cpu(0)) 23 | total_frame_num = len(vr) 24 | 25 | # Currently, this function supports only 1 clip 26 | assert n_clips == 1 27 | 28 | # Calculate total number of frames to extract 29 | total_num_frm = min(total_frame_num, num_frm) 30 | # Get indices of frames to extract 31 | frame_idx = get_seq_frames(total_frame_num, total_num_frm) 32 | # Extract frames as numpy array 33 | img_array = vr.get_batch(frame_idx).asnumpy() # T H W C 34 | 35 | original_size = (img_array.shape[-2], img_array.shape[-3]) # (width, height) 36 | original_sizes = (original_size,) * total_num_frm 37 | 38 | clip_imgs = [Image.fromarray(img_array[j]) for j in range(total_num_frm)] 39 | 40 | 41 | return clip_imgs, original_sizes 42 | 43 | 44 | 45 | 46 | def get_seq_frames(total_num_frames, desired_num_frames): 47 | """ 48 | Calculate the indices of frames to extract from a video. 49 | 50 | Parameters: 51 | total_num_frames (int): Total number of frames in the video. 
52 | desired_num_frames (int): Desired number of frames to extract. 53 | 54 | Returns: 55 | list: List of indices of frames to extract. 56 | """ 57 | 58 | # Calculate the size of each segment from which a frame will be extracted 59 | seg_size = float(total_num_frames - 1) / desired_num_frames 60 | 61 | seq = [] 62 | for i in range(desired_num_frames): 63 | # Calculate the start and end indices of each segment 64 | start = int(np.round(seg_size * i)) 65 | end = int(np.round(seg_size * (i + 1))) 66 | 67 | # Append the middle index of the segment to the list 68 | seq.append((start + end) // 2) 69 | 70 | return seq 71 | -------------------------------------------------------------------------------- /llava/eval/run_inference_benchmark_consistency.py: -------------------------------------------------------------------------------- 1 | import math 2 | import os 3 | import argparse 4 | import json 5 | 6 | from tqdm import tqdm 7 | from llava.eval.model_utils import load_video 8 | 9 | from llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN 10 | from llava.conversation import conv_templates, SeparatorStyle 11 | from llava.model.builder import load_pretrained_model 12 | from llava.utils import disable_torch_init 13 | from llava.mm_utils import tokenizer_image_token, process_images, get_model_name_from_path 14 | 15 | from PIL import Image 16 | import math 17 | import torch 18 | import time 19 | 20 | def llava_inference(video_frames, question, conv_mode, model, tokenizer, image_processor, image_sizes): 21 | if model.config.mm_use_im_start_end: 22 | qs = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN + '\n' + question 23 | else: 24 | qs = DEFAULT_IMAGE_TOKEN + '\n' + question 25 | 26 | conv = conv_templates[conv_mode].copy() 27 | conv.append_message(conv.roles[0], qs) 28 | conv.append_message(conv.roles[1], None) 29 | prompt = conv.get_prompt() 30 | 31 | input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).cuda() 32 | image_tensor = process_images(video_frames, image_processor, model.config) 33 | 34 | with torch.inference_mode(): 35 | output_ids = model.generate( 36 | input_ids, 37 | images=image_tensor.to(dtype=torch.float16, device='cuda', non_blocking=True), 38 | image_sizes=image_sizes, 39 | do_sample=True if args.temperature > 0 else False, 40 | temperature=args.temperature, 41 | top_p=args.top_p, 42 | num_beams=args.num_beams, 43 | max_new_tokens=128, 44 | use_cache=True) 45 | 46 | outputs = tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0].strip() 47 | return outputs 48 | 49 | def split_list(lst, n): 50 | """Split a list into n (roughly) equal-sized chunks""" 51 | chunk_size = math.ceil(len(lst) / n) # integer division 52 | return [lst[i:i+chunk_size] for i in range(0, len(lst), chunk_size)] 53 | 54 | def get_chunk(lst, n, k): 55 | chunks = split_list(lst, n) 56 | return chunks[k] 57 | 58 | def parse_args(): 59 | """ 60 | Parse command-line arguments. 
61 | """ 62 | parser = argparse.ArgumentParser() 63 | 64 | # Define the command-line arguments 65 | parser.add_argument('--video_dir', help='Directory containing video files.', required=True) 66 | parser.add_argument('--gt_file', help='Path to the ground truth file.', required=True) 67 | parser.add_argument('--output_dir', help='Directory to save the model results JSON.', required=True) 68 | parser.add_argument('--output_name', help='Name of the file for storing results JSON.', required=True) 69 | parser.add_argument("--model_name", type=str, required=True) 70 | parser.add_argument("--conv-mode", type=str, required=False, default='video-chatgpt_v1') 71 | parser.add_argument("--num_chunks", type=int, default=1) 72 | parser.add_argument("--chunk_idx", type=int, default=0) 73 | parser.add_argument("--num_frames", type=int, default=100) 74 | parser.add_argument("--device", type=str, required=False, default='cuda:0') 75 | parser.add_argument("--model-base", type=str, default=None) 76 | parser.add_argument("--num_beams", type=int, default=1) 77 | parser.add_argument("--temperature", type=float, default=0.2) 78 | parser.add_argument("--top_p", type=float, default=None) 79 | 80 | return parser.parse_args() 81 | 82 | 83 | def run_inference(args): 84 | """ 85 | Run inference on a set of video files using the provided model. 86 | 87 | Args: 88 | args: Command-line arguments. 89 | """ 90 | 91 | disable_torch_init() 92 | model_path = os.path.expanduser(args.model_name) 93 | model_name = get_model_name_from_path(model_path) 94 | tokenizer, model, image_processor, context_len = load_pretrained_model(model_path, args.model_base, model_name) 95 | 96 | 97 | gt_contents = json.load(open(args.gt_file, "r")) 98 | gt_contents = get_chunk(gt_contents, args.num_chunks, args.chunk_idx) 99 | 100 | answers_file = os.path.join(args.output_dir, f"{args.output_name}.json") 101 | os.makedirs(args.output_dir, exist_ok=True) 102 | ans_file = open(answers_file, "w") 103 | 104 | # Create the output directory if it doesn't exist 105 | if not os.path.exists(args.output_dir): 106 | os.makedirs(args.output_dir) 107 | 108 | output_list = [] # List to store the output results 109 | conv_mode = args.conv_mode 110 | 111 | video_formats = ['.mp4', '.avi', '.mov', '.mkv'] 112 | 113 | # Iterate over each sample in the ground truth file 114 | index = 0 115 | for sample in tqdm(gt_contents): 116 | video_name = sample['video_name'] 117 | sample_set = sample 118 | question_1 = sample['Q1'] 119 | question_2 = sample['Q2'] 120 | 121 | # Load the video file 122 | for fmt in video_formats: # Added this line 123 | temp_path = os.path.join(args.video_dir, f"{video_name}{fmt}") 124 | if os.path.exists(temp_path): 125 | video_path = temp_path 126 | video_frames, sizes = load_video(video_path, num_frm=args.num_frames) 127 | # Run inference on the video for the first question and add the output to the list 128 | output_1 = llava_inference(video_frames, question_1, conv_mode, model, 129 | tokenizer, image_processor, sizes) 130 | sample_set['pred1'] = output_1 131 | 132 | # Run inference on the video for the second question and add the output to the list 133 | output_2 = llava_inference(video_frames, question_2, conv_mode, model, 134 | tokenizer, image_processor, sizes) 135 | sample_set['pred2'] = output_2 136 | 137 | output_list.append(sample_set) 138 | ans_file.write(json.dumps(sample_set) + "\n") 139 | index += 1 140 | break 141 | 142 | ans_file.close() 143 | 144 | 145 | if __name__ == "__main__": 146 | args = parse_args() 147 | 
run_inference(args) 148 | -------------------------------------------------------------------------------- /llava/eval/run_inference_benchmark_general.py: -------------------------------------------------------------------------------- 1 | import math 2 | import os 3 | import argparse 4 | import json 5 | 6 | from tqdm import tqdm 7 | from llava.eval.model_utils import load_video 8 | 9 | from llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN 10 | from llava.conversation import conv_templates, SeparatorStyle 11 | from llava.model.builder import load_pretrained_model 12 | from llava.utils import disable_torch_init 13 | from llava.mm_utils import tokenizer_image_token, process_images, get_model_name_from_path 14 | 15 | from PIL import Image 16 | import math 17 | import torch 18 | import time 19 | 20 | def llava_inference(video_frames, question, conv_mode, model, tokenizer, image_processor, image_sizes): 21 | if model.config.mm_use_im_start_end: 22 | qs = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN + '\n' + question 23 | else: 24 | qs = DEFAULT_IMAGE_TOKEN + '\n' + question 25 | 26 | conv = conv_templates[conv_mode].copy() 27 | conv.append_message(conv.roles[0], qs) 28 | conv.append_message(conv.roles[1], None) 29 | prompt = conv.get_prompt() 30 | 31 | input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).cuda() 32 | image_tensor = process_images(video_frames, image_processor, model.config) 33 | 34 | with torch.inference_mode(): 35 | output_ids = model.generate( 36 | input_ids, 37 | images=image_tensor.to(dtype=torch.float16, device='cuda', non_blocking=True), 38 | image_sizes=image_sizes, 39 | do_sample=True if args.temperature > 0 else False, 40 | temperature=args.temperature, 41 | top_p=args.top_p, 42 | num_beams=args.num_beams, 43 | max_new_tokens=128, 44 | use_cache=True) 45 | 46 | outputs = tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0].strip() 47 | return outputs 48 | 49 | def split_list(lst, n): 50 | """Split a list into n (roughly) equal-sized chunks""" 51 | chunk_size = math.ceil(len(lst) / n) # integer division 52 | return [lst[i:i+chunk_size] for i in range(0, len(lst), chunk_size)] 53 | 54 | def get_chunk(lst, n, k): 55 | chunks = split_list(lst, n) 56 | return chunks[k] 57 | 58 | def parse_args(): 59 | """ 60 | Parse command-line arguments. 
61 | """ 62 | parser = argparse.ArgumentParser() 63 | 64 | # Define the command-line arguments 65 | parser.add_argument('--video_dir', help='Directory containing video files.', required=True) 66 | parser.add_argument('--gt_file', help='Path to the ground truth file.', required=True) 67 | parser.add_argument('--output_dir', help='Directory to save the model results JSON.', required=True) 68 | parser.add_argument('--output_name', help='Name of the file for storing results JSON.', required=True) 69 | parser.add_argument("--model_name", type=str, required=True) 70 | parser.add_argument("--conv-mode", type=str, required=False, default='video-chatgpt_v1') 71 | parser.add_argument("--num_chunks", type=int, default=1) 72 | parser.add_argument("--chunk_idx", type=int, default=0) 73 | parser.add_argument("--num_frames", type=int, default=100) 74 | parser.add_argument("--device", type=str, required=False, default='cuda:0') 75 | parser.add_argument("--model-base", type=str, default=None) 76 | parser.add_argument("--num_beams", type=int, default=1) 77 | parser.add_argument("--temperature", type=float, default=0.2) 78 | parser.add_argument("--top_p", type=float, default=None) 79 | 80 | return parser.parse_args() 81 | 82 | 83 | def run_inference(args): 84 | """ 85 | Run inference on a set of video files using the provided model. 86 | 87 | Args: 88 | args: Command-line arguments. 89 | """ 90 | 91 | disable_torch_init() 92 | model_path = os.path.expanduser(args.model_name) 93 | model_name = get_model_name_from_path(model_path) 94 | tokenizer, model, image_processor, context_len = load_pretrained_model(model_path, args.model_base, model_name) 95 | 96 | 97 | gt_contents = json.load(open(args.gt_file, "r")) 98 | gt_contents = get_chunk(gt_contents, args.num_chunks, args.chunk_idx) 99 | 100 | answers_file = os.path.join(args.output_dir, f"{args.output_name}.json") 101 | os.makedirs(args.output_dir, exist_ok=True) 102 | ans_file = open(answers_file, "w") 103 | 104 | # Create the output directory if it doesn't exist 105 | if not os.path.exists(args.output_dir): 106 | os.makedirs(args.output_dir) 107 | 108 | output_list = [] # List to store the output results 109 | conv_mode = args.conv_mode 110 | 111 | video_formats = ['.mp4', '.avi', '.mov', '.mkv'] 112 | 113 | # Iterate over each sample in the ground truth file 114 | index = 0 115 | for sample in tqdm(gt_contents): 116 | video_name = sample['video_name'] 117 | sample_set = sample 118 | question = sample['Q'] 119 | 120 | # Load the video file 121 | for fmt in video_formats: # Added this line 122 | temp_path = os.path.join(args.video_dir, f"{video_name}{fmt}") 123 | if os.path.exists(temp_path): 124 | video_path = temp_path 125 | video_frames, sizes = load_video(video_path, num_frm=args.num_frames) 126 | # Run inference on the video and add the output to the list 127 | output = llava_inference(video_frames, question, conv_mode, model, 128 | tokenizer, image_processor, sizes) 129 | sample_set['pred'] = output 130 | 131 | output_list.append(sample_set) 132 | ans_file.write(json.dumps(sample_set) + "\n") 133 | index += 1 134 | break 135 | 136 | ans_file.close() 137 | 138 | 139 | 140 | if __name__ == "__main__": 141 | args = parse_args() 142 | run_inference(args) 143 | -------------------------------------------------------------------------------- /llava/eval/run_inference_qa.py: -------------------------------------------------------------------------------- 1 | import math 2 | import os 3 | import argparse 4 | import json 5 | 6 | from tqdm import tqdm 7 | from 
llava.eval.model_utils import load_video 8 | 9 | from llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN 10 | from llava.conversation import conv_templates, SeparatorStyle 11 | from llava.model.builder import load_pretrained_model 12 | from llava.utils import disable_torch_init 13 | from llava.mm_utils import tokenizer_image_token, process_images, get_model_name_from_path 14 | 15 | from PIL import Image 16 | import math 17 | import torch 18 | import time 19 | 20 | def llava_inference(video_frames, question, conv_mode, model, tokenizer, image_processor, image_sizes): 21 | if model.config.mm_use_im_start_end: 22 | qs = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN + '\n' + question 23 | else: 24 | qs = DEFAULT_IMAGE_TOKEN + '\n' + question 25 | 26 | conv = conv_templates[conv_mode].copy() 27 | conv.append_message(conv.roles[0], qs) 28 | conv.append_message(conv.roles[1], None) 29 | prompt = conv.get_prompt() 30 | 31 | input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).cuda() 32 | image_tensor = process_images(video_frames, image_processor, model.config) 33 | 34 | with torch.inference_mode(): 35 | output_ids = model.generate( 36 | input_ids, 37 | images=image_tensor.to(dtype=torch.float16, device='cuda', non_blocking=True), 38 | image_sizes=image_sizes, 39 | do_sample=True if args.temperature > 0 else False, 40 | temperature=args.temperature, 41 | top_p=args.top_p, 42 | num_beams=args.num_beams, 43 | max_new_tokens=128, 44 | use_cache=True) 45 | 46 | outputs = tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0].strip() 47 | return outputs 48 | 49 | def split_list(lst, n): 50 | """Split a list into n (roughly) equal-sized chunks""" 51 | chunk_size = math.ceil(len(lst) / n) # integer division 52 | return [lst[i:i+chunk_size] for i in range(0, len(lst), chunk_size)] 53 | 54 | def get_chunk(lst, n, k): 55 | chunks = split_list(lst, n) 56 | return chunks[k] 57 | 58 | def parse_args(): 59 | """ 60 | Parse command-line arguments. 
61 | """ 62 | parser = argparse.ArgumentParser() 63 | 64 | # Define the command-line arguments 65 | parser.add_argument('--video_dir', help='Directory containing video files.', required=True) 66 | parser.add_argument('--gt_file_question', help='Path to the ground truth file containing question.', required=True) 67 | parser.add_argument('--gt_file_answers', help='Path to the ground truth file containing answers.', required=True) 68 | parser.add_argument('--output_dir', help='Directory to save the model results JSON.', required=True) 69 | parser.add_argument('--output_name', help='Name of the file for storing results JSON.', required=True) 70 | parser.add_argument("--model_name", type=str, required=True) 71 | parser.add_argument("--conv-mode", type=str, required=False, default='video-chatgpt_v1') 72 | parser.add_argument("--num_chunks", type=int, default=1) 73 | parser.add_argument("--chunk_idx", type=int, default=0) 74 | parser.add_argument("--num_frames", type=int, default=100) 75 | parser.add_argument("--device", type=str, required=False, default='cuda:0') 76 | parser.add_argument("--model-base", type=str, default=None) 77 | parser.add_argument("--num_beams", type=int, default=1) 78 | parser.add_argument("--temperature", type=float, default=0.2) 79 | parser.add_argument("--top_p", type=float, default=None) 80 | 81 | return parser.parse_args() 82 | 83 | 84 | def run_inference(args): 85 | """ 86 | Run inference on Video QA DataSetå. 87 | 88 | Args: 89 | args: Command-line arguments. 90 | """ 91 | 92 | disable_torch_init() 93 | model_path = os.path.expanduser(args.model_name) 94 | model_name = get_model_name_from_path(model_path) 95 | tokenizer, model, image_processor, context_len = load_pretrained_model(model_path, args.model_base, model_name) 96 | 97 | 98 | gt_questions = json.load(open(args.gt_file_question, "r")) 99 | gt_questions = get_chunk(gt_questions, args.num_chunks, args.chunk_idx) 100 | gt_answers = json.load(open(args.gt_file_answers, "r")) 101 | gt_answers = get_chunk(gt_answers, args.num_chunks, args.chunk_idx) 102 | 103 | answers_file = os.path.join(args.output_dir, f"{args.output_name}.json") 104 | os.makedirs(args.output_dir, exist_ok=True) 105 | ans_file = open(answers_file, "w") 106 | 107 | # Create the output directory if it doesn't exist 108 | if not os.path.exists(args.output_dir): 109 | os.makedirs(args.output_dir) 110 | 111 | output_list = [] # List to store the output results 112 | conv_mode = args.conv_mode 113 | 114 | video_formats = ['.mp4', '.avi', '.mov', '.mkv'] 115 | 116 | # Iterate over each sample in the ground truth file 117 | index = 0 118 | for sample in tqdm(gt_questions): 119 | video_name = sample['video_name'] 120 | question = sample['question'] 121 | id = sample['question_id'] 122 | answer = gt_answers[index]['answer'] 123 | index += 1 124 | 125 | sample_set = {'id': id, 'question': question, 'answer': answer} 126 | 127 | # Load the video file 128 | for fmt in video_formats: # Added this line 129 | vid_name = f"v_{video_name}" if 'Activitynet' in args.video_dir else video_name 130 | temp_path = os.path.join(args.video_dir, f"{vid_name}{fmt}") 131 | 132 | if os.path.exists(temp_path): 133 | # print(f'processing {idx}/{len(gt_questions)}') 134 | video_path = temp_path 135 | video_frames, sizes = load_video(video_path, num_frm=args.num_frames) 136 | # Run inference on the video and add the output to the list 137 | output = llava_inference(video_frames, question, conv_mode, model, 138 | tokenizer, image_processor, sizes) 139 | print(output) 140 | 
sample_set['pred'] = output 141 | output_list.append(sample_set) 142 | ans_file.write(json.dumps(sample_set) + "\n") 143 | break 144 | 145 | ans_file.close() 146 | 147 | 148 | if __name__ == "__main__": 149 | args = parse_args() 150 | run_inference(args) 151 | -------------------------------------------------------------------------------- /llava/eval/single_video_inference.py: -------------------------------------------------------------------------------- 1 | import math 2 | import os 3 | import argparse 4 | import json 5 | 6 | from tqdm import tqdm 7 | from llava.eval.model_utils import load_video 8 | 9 | from llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN 10 | from llava.conversation import conv_templates, SeparatorStyle 11 | from llava.model.builder import load_pretrained_model 12 | from llava.utils import disable_torch_init 13 | from llava.mm_utils import tokenizer_image_token, process_images, get_model_name_from_path 14 | 15 | from PIL import Image 16 | import math 17 | import torch 18 | import time 19 | 20 | def llava_inference(video_frames, question, conv_mode, model, tokenizer, image_processor, image_sizes): 21 | if model.config.mm_use_im_start_end: 22 | qs = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN + '\n' + question 23 | else: 24 | qs = DEFAULT_IMAGE_TOKEN + '\n' + question 25 | 26 | conv = conv_templates[conv_mode].copy() 27 | conv.append_message(conv.roles[0], qs) 28 | conv.append_message(conv.roles[1], None) 29 | prompt = conv.get_prompt() 30 | 31 | input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).cuda() 32 | image_tensor = process_images(video_frames, image_processor, model.config) 33 | 34 | with torch.inference_mode(): 35 | output_ids = model.generate( 36 | input_ids, 37 | images=image_tensor.to(dtype=torch.float16, device='cuda', non_blocking=True), 38 | image_sizes=image_sizes, 39 | do_sample=False, 40 | temperature=0, 41 | top_p=None, 42 | num_beams=1, 43 | max_new_tokens=128, 44 | use_cache=True) 45 | 46 | outputs = tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0].strip() 47 | return outputs 48 | 49 | def parse_args(): 50 | """ 51 | Parse command-line arguments. 52 | """ 53 | parser = argparse.ArgumentParser() 54 | 55 | # Define the command-line arguments 56 | parser.add_argument('--video_path', help='input video path', required=True) 57 | parser.add_argument("--model_name", type=str, required=True) 58 | parser.add_argument("--conv-mode", type=str, required=False, default='video-chatgpt_v1') 59 | parser.add_argument("--num_frames", type=int, default=4) 60 | 61 | return parser.parse_args() 62 | 63 | 64 | def run_inference(args): 65 | """ 66 | Run inference 67 | 68 | Args: 69 | args: Command-line arguments. 
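    Typically driven from the command line (illustrative paths; see
    scripts/infer_video/run_one_video.sh for the packaged equivalent):

        python llava/eval/single_video_inference.py \
            --video_path ./demo/sample_video.mp4 \
            --model_name ./ckpt/llava-v1.6-vicuna-7b \
            --num_frames 4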
70 | """ 71 | 72 | disable_torch_init() 73 | model_path = os.path.expanduser(args.model_name) 74 | model_name = get_model_name_from_path(model_path) 75 | tokenizer, model, image_processor, context_len = load_pretrained_model(model_path, None, model_name) 76 | 77 | conv_mode = args.conv_mode 78 | 79 | video_path = args.video_path 80 | 81 | video_frames, sizes = load_video(video_path, num_frm=args.num_frames) 82 | 83 | question = "describe the video in detail" 84 | 85 | try: 86 | # Run inference on the video and add the output to the list 87 | output = llava_inference(video_frames, question, conv_mode, model, 88 | tokenizer, image_processor, sizes) 89 | print("\n\n", output) 90 | 91 | except Exception as e: 92 | print(f"Error processing video file '{video_path}': {e}") 93 | 94 | 95 | if __name__ == "__main__": 96 | args = parse_args() 97 | run_inference(args) 98 | -------------------------------------------------------------------------------- /llava/mm_utils.py: -------------------------------------------------------------------------------- 1 | from PIL import Image 2 | from io import BytesIO 3 | import base64 4 | import torch 5 | import math 6 | import ast 7 | 8 | from transformers import StoppingCriteria 9 | from llava.constants import IMAGE_TOKEN_INDEX 10 | 11 | 12 | def select_best_resolution(original_size, possible_resolutions): 13 | """ 14 | Selects the best resolution from a list of possible resolutions based on the original size. 15 | 16 | Args: 17 | original_size (tuple): The original size of the image in the format (width, height). 18 | possible_resolutions (list): A list of possible resolutions in the format [(width1, height1), (width2, height2), ...]. 19 | 20 | Returns: 21 | tuple: The best fit resolution in the format (width, height). 22 | """ 23 | original_width, original_height = original_size 24 | best_fit = None 25 | max_effective_resolution = 0 26 | min_wasted_resolution = float('inf') 27 | 28 | for width, height in possible_resolutions: 29 | scale = min(width / original_width, height / original_height) 30 | downscaled_width, downscaled_height = int(original_width * scale), int(original_height * scale) 31 | effective_resolution = min(downscaled_width * downscaled_height, original_width * original_height) 32 | wasted_resolution = (width * height) - effective_resolution 33 | 34 | if effective_resolution > max_effective_resolution or (effective_resolution == max_effective_resolution and wasted_resolution < min_wasted_resolution): 35 | max_effective_resolution = effective_resolution 36 | min_wasted_resolution = wasted_resolution 37 | best_fit = (width, height) 38 | 39 | return best_fit 40 | 41 | 42 | def resize_and_pad_image(image, target_resolution): 43 | """ 44 | Resize and pad an image to a target resolution while maintaining aspect ratio. 45 | 46 | Args: 47 | image (PIL.Image.Image): The input image. 48 | target_resolution (tuple): The target resolution (width, height) of the image. 49 | 50 | Returns: 51 | PIL.Image.Image: The resized and padded image. 
52 | """ 53 | original_width, original_height = image.size 54 | target_width, target_height = target_resolution 55 | 56 | scale_w = target_width / original_width 57 | scale_h = target_height / original_height 58 | 59 | if scale_w < scale_h: 60 | new_width = target_width 61 | new_height = min(math.ceil(original_height * scale_w), target_height) 62 | else: 63 | new_height = target_height 64 | new_width = min(math.ceil(original_width * scale_h), target_width) 65 | 66 | # Resize the image 67 | resized_image = image.resize((new_width, new_height)) 68 | 69 | new_image = Image.new('RGB', (target_width, target_height), (0, 0, 0)) 70 | paste_x = (target_width - new_width) // 2 71 | paste_y = (target_height - new_height) // 2 72 | new_image.paste(resized_image, (paste_x, paste_y)) 73 | 74 | return new_image 75 | 76 | 77 | def divide_to_patches(image, patch_size): 78 | """ 79 | Divides an image into patches of a specified size. 80 | 81 | Args: 82 | image (PIL.Image.Image): The input image. 83 | patch_size (int): The size of each patch. 84 | 85 | Returns: 86 | list: A list of PIL.Image.Image objects representing the patches. 87 | """ 88 | patches = [] 89 | width, height = image.size 90 | for i in range(0, height, patch_size): 91 | for j in range(0, width, patch_size): 92 | box = (j, i, j + patch_size, i + patch_size) 93 | patch = image.crop(box) 94 | patches.append(patch) 95 | 96 | return patches 97 | 98 | 99 | def get_anyres_image_grid_shape(image_size, grid_pinpoints, patch_size): 100 | """ 101 | Calculate the shape of the image patch grid after the preprocessing for images of any resolution. 102 | 103 | Args: 104 | image_size (tuple): The size of the input image in the format (width, height). 105 | grid_pinpoints (str): A string representation of a list of possible resolutions. 106 | patch_size (int): The size of each image patch. 107 | 108 | Returns: 109 | tuple: The shape of the image patch grid in the format (width, height). 110 | """ 111 | if type(grid_pinpoints) is list: 112 | possible_resolutions = grid_pinpoints 113 | else: 114 | possible_resolutions = ast.literal_eval(grid_pinpoints) 115 | width, height = select_best_resolution(image_size, possible_resolutions) 116 | return width // patch_size, height // patch_size 117 | 118 | 119 | def process_anyres_image(image, processor, grid_pinpoints): 120 | """ 121 | Process an image with variable resolutions. 122 | 123 | Args: 124 | image (PIL.Image.Image): The input image to be processed. 125 | processor: The image processor object. 126 | grid_pinpoints (str): A string representation of a list of possible resolutions. 127 | 128 | Returns: 129 | torch.Tensor: A tensor containing the processed image patches. 
130 | """ 131 | if type(grid_pinpoints) is list: 132 | possible_resolutions = grid_pinpoints 133 | else: 134 | possible_resolutions = ast.literal_eval(grid_pinpoints) 135 | best_resolution = select_best_resolution(image.size, possible_resolutions) 136 | image_padded = resize_and_pad_image(image, best_resolution) 137 | 138 | patches = divide_to_patches(image_padded, processor.crop_size['height']) 139 | 140 | image_original_resize = image.resize((processor.size['shortest_edge'], processor.size['shortest_edge'])) 141 | 142 | image_patches = [image_original_resize] + patches 143 | image_patches = [processor.preprocess(image_patch, return_tensors='pt')['pixel_values'][0] 144 | for image_patch in image_patches] 145 | return torch.stack(image_patches, dim=0) 146 | 147 | 148 | def load_image_from_base64(image): 149 | return Image.open(BytesIO(base64.b64decode(image))) 150 | 151 | 152 | def expand2square(pil_img, background_color): 153 | width, height = pil_img.size 154 | if width == height: 155 | return pil_img 156 | elif width > height: 157 | result = Image.new(pil_img.mode, (width, width), background_color) 158 | result.paste(pil_img, (0, (width - height) // 2)) 159 | return result 160 | else: 161 | result = Image.new(pil_img.mode, (height, height), background_color) 162 | result.paste(pil_img, ((height - width) // 2, 0)) 163 | return result 164 | 165 | 166 | def process_images(images, image_processor, model_cfg): 167 | image_aspect_ratio = getattr(model_cfg, "image_aspect_ratio", None) 168 | new_images = [] 169 | if image_aspect_ratio == 'pad': 170 | for image in images: 171 | image = expand2square(image, tuple(int(x*255) for x in image_processor.image_mean)) 172 | image = image_processor.preprocess(image, return_tensors='pt')['pixel_values'][0] 173 | new_images.append(image) 174 | elif image_aspect_ratio == "anyres": 175 | for image in images: 176 | image = process_anyres_image(image, image_processor, model_cfg.image_grid_pinpoints) 177 | new_images.append(image) 178 | else: 179 | return image_processor(images, return_tensors='pt')['pixel_values'] 180 | if all(x.shape == new_images[0].shape for x in new_images): 181 | new_images = torch.stack(new_images, dim=0) 182 | 183 | return new_images 184 | 185 | 186 | def tokenizer_image_token(prompt, tokenizer, image_token_index=IMAGE_TOKEN_INDEX, return_tensors=None): 187 | prompt_chunks = [tokenizer(chunk).input_ids for chunk in prompt.split('')] 188 | 189 | def insert_separator(X, sep): 190 | return [ele for sublist in zip(X, [sep]*len(X)) for ele in sublist][:-1] 191 | 192 | input_ids = [] 193 | offset = 0 194 | if len(prompt_chunks) > 0 and len(prompt_chunks[0]) > 0 and prompt_chunks[0][0] == tokenizer.bos_token_id: 195 | offset = 1 196 | input_ids.append(prompt_chunks[0][0]) 197 | 198 | for x in insert_separator(prompt_chunks, [image_token_index] * (offset + 1)): 199 | input_ids.extend(x[offset:]) 200 | 201 | if return_tensors is not None: 202 | if return_tensors == 'pt': 203 | return torch.tensor(input_ids, dtype=torch.long) 204 | raise ValueError(f'Unsupported tensor type: {return_tensors}') 205 | return input_ids 206 | 207 | 208 | def get_model_name_from_path(model_path): 209 | model_path = model_path.strip("/") 210 | model_paths = model_path.split("/") 211 | if model_paths[-1].startswith('checkpoint-'): 212 | return model_paths[-2] + "_" + model_paths[-1] 213 | else: 214 | return model_paths[-1] 215 | 216 | class KeywordsStoppingCriteria(StoppingCriteria): 217 | def __init__(self, keywords, tokenizer, input_ids): 218 | self.keywords = 
keywords 219 | self.keyword_ids = [] 220 | self.max_keyword_len = 0 221 | for keyword in keywords: 222 | cur_keyword_ids = tokenizer(keyword).input_ids 223 | if len(cur_keyword_ids) > 1 and cur_keyword_ids[0] == tokenizer.bos_token_id: 224 | cur_keyword_ids = cur_keyword_ids[1:] 225 | if len(cur_keyword_ids) > self.max_keyword_len: 226 | self.max_keyword_len = len(cur_keyword_ids) 227 | self.keyword_ids.append(torch.tensor(cur_keyword_ids)) 228 | self.tokenizer = tokenizer 229 | self.start_len = input_ids.shape[1] 230 | 231 | def call_for_batch(self, output_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool: 232 | offset = min(output_ids.shape[1] - self.start_len, self.max_keyword_len) 233 | self.keyword_ids = [keyword_id.to(output_ids.device) for keyword_id in self.keyword_ids] 234 | for keyword_id in self.keyword_ids: 235 | truncated_output_ids = output_ids[0, -keyword_id.shape[0]:] 236 | if torch.equal(truncated_output_ids, keyword_id): 237 | return True 238 | outputs = self.tokenizer.batch_decode(output_ids[:, -offset:], skip_special_tokens=True)[0] 239 | for keyword in self.keywords: 240 | if keyword in outputs: 241 | return True 242 | return False 243 | 244 | def __call__(self, output_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool: 245 | outputs = [] 246 | for i in range(output_ids.shape[0]): 247 | outputs.append(self.call_for_batch(output_ids[i].unsqueeze(0), scores)) 248 | return all(outputs) 249 | -------------------------------------------------------------------------------- /llava/model/__init__.py: -------------------------------------------------------------------------------- 1 | try: 2 | from .language_model.llava_llama import LlavaLlamaForCausalLM, LlavaConfig 3 | from .language_model.llava_mpt import LlavaMptForCausalLM, LlavaMptConfig 4 | from .language_model.llava_mistral import LlavaMistralForCausalLM, LlavaMistralConfig 5 | except: 6 | pass 7 | -------------------------------------------------------------------------------- /llava/model/apply_delta.py: -------------------------------------------------------------------------------- 1 | """ 2 | Usage: 3 | python3 -m fastchat.model.apply_delta --base ~/model_weights/llama-7b --target ~/model_weights/vicuna-7b --delta lmsys/vicuna-7b-delta 4 | """ 5 | import argparse 6 | 7 | import torch 8 | from tqdm import tqdm 9 | from transformers import AutoTokenizer, AutoModelForCausalLM 10 | from llava import LlavaLlamaForCausalLM 11 | 12 | 13 | def apply_delta(base_model_path, target_model_path, delta_path): 14 | print("Loading base model") 15 | base = AutoModelForCausalLM.from_pretrained( 16 | base_model_path, torch_dtype=torch.float16, low_cpu_mem_usage=True) 17 | 18 | print("Loading delta") 19 | delta = LlavaLlamaForCausalLM.from_pretrained(delta_path, torch_dtype=torch.float16, low_cpu_mem_usage=True) 20 | delta_tokenizer = AutoTokenizer.from_pretrained(delta_path) 21 | 22 | print("Applying delta") 23 | for name, param in tqdm(delta.state_dict().items(), desc="Applying delta"): 24 | if name not in base.state_dict(): 25 | assert name in ['model.mm_projector.weight', 'model.mm_projector.bias'], f'{name} not in base model' 26 | continue 27 | if param.data.shape == base.state_dict()[name].shape: 28 | param.data += base.state_dict()[name] 29 | else: 30 | assert name in ['model.embed_tokens.weight', 'lm_head.weight'], \ 31 | f'{name} dimension mismatch: {param.data.shape} vs {base.state_dict()[name].shape}' 32 | bparam = base.state_dict()[name] 33 | param.data[:bparam.shape[0], 
:bparam.shape[1]] += bparam 34 | 35 | print("Saving target model") 36 | delta.save_pretrained(target_model_path) 37 | delta_tokenizer.save_pretrained(target_model_path) 38 | 39 | 40 | if __name__ == "__main__": 41 | parser = argparse.ArgumentParser() 42 | parser.add_argument("--base-model-path", type=str, required=True) 43 | parser.add_argument("--target-model-path", type=str, required=True) 44 | parser.add_argument("--delta-path", type=str, required=True) 45 | 46 | args = parser.parse_args() 47 | 48 | apply_delta(args.base_model_path, args.target_model_path, args.delta_path) 49 | -------------------------------------------------------------------------------- /llava/model/builder.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 Haotian Liu 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | 16 | import os 17 | import warnings 18 | import shutil 19 | 20 | from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig, BitsAndBytesConfig 21 | import torch 22 | from llava.model import * 23 | from llava.constants import DEFAULT_IMAGE_PATCH_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN 24 | 25 | 26 | def load_pretrained_model(model_path, model_base, model_name, load_8bit=False, load_4bit=False, device_map="auto", device="cuda", use_flash_attn=False, **kwargs): 27 | kwargs = {"device_map": device_map, **kwargs} 28 | 29 | if device != "cuda": 30 | kwargs['device_map'] = {"": device} 31 | 32 | if load_8bit: 33 | kwargs['load_in_8bit'] = True 34 | elif load_4bit: 35 | kwargs['load_in_4bit'] = True 36 | kwargs['quantization_config'] = BitsAndBytesConfig( 37 | load_in_4bit=True, 38 | bnb_4bit_compute_dtype=torch.float16, 39 | bnb_4bit_use_double_quant=True, 40 | bnb_4bit_quant_type='nf4' 41 | ) 42 | else: 43 | kwargs['torch_dtype'] = torch.float16 44 | 45 | if use_flash_attn: 46 | kwargs['attn_implementation'] = 'flash_attention_2' 47 | 48 | if 'llava' in model_name.lower(): 49 | # Load LLaVA model 50 | if 'lora' in model_name.lower() and model_base is None: 51 | warnings.warn('There is `lora` in model name but no `model_base` is provided. If you are loading a LoRA model, please provide the `model_base` argument. 
Detailed instruction: https://github.com/haotian-liu/LLaVA#launch-a-model-worker-lora-weights-unmerged.') 52 | if 'lora' in model_name.lower() and model_base is not None: 53 | from llava.model.language_model.llava_llama import LlavaConfig 54 | lora_cfg_pretrained = LlavaConfig.from_pretrained(model_path) 55 | tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=False) 56 | print('Loading LLaVA from base model...') 57 | model = LlavaLlamaForCausalLM.from_pretrained(model_base, low_cpu_mem_usage=True, config=lora_cfg_pretrained, **kwargs) 58 | token_num, tokem_dim = model.lm_head.out_features, model.lm_head.in_features 59 | if model.lm_head.weight.shape[0] != token_num: 60 | model.lm_head.weight = torch.nn.Parameter(torch.empty(token_num, tokem_dim, device=model.device, dtype=model.dtype)) 61 | model.model.embed_tokens.weight = torch.nn.Parameter(torch.empty(token_num, tokem_dim, device=model.device, dtype=model.dtype)) 62 | 63 | print('Loading additional LLaVA weights...') 64 | if os.path.exists(os.path.join(model_path, 'non_lora_trainables.bin')): 65 | non_lora_trainables = torch.load(os.path.join(model_path, 'non_lora_trainables.bin'), map_location='cpu') 66 | else: 67 | # this is probably from HF Hub 68 | from huggingface_hub import hf_hub_download 69 | def load_from_hf(repo_id, filename, subfolder=None): 70 | cache_file = hf_hub_download( 71 | repo_id=repo_id, 72 | filename=filename, 73 | subfolder=subfolder) 74 | return torch.load(cache_file, map_location='cpu') 75 | non_lora_trainables = load_from_hf(model_path, 'non_lora_trainables.bin') 76 | non_lora_trainables = {(k[11:] if k.startswith('base_model.') else k): v for k, v in non_lora_trainables.items()} 77 | if any(k.startswith('model.model.') for k in non_lora_trainables): 78 | non_lora_trainables = {(k[6:] if k.startswith('model.') else k): v for k, v in non_lora_trainables.items()} 79 | model.load_state_dict(non_lora_trainables, strict=False) 80 | 81 | from peft import PeftModel 82 | print('Loading LoRA weights...') 83 | model = PeftModel.from_pretrained(model, model_path) 84 | print('Merging LoRA weights...') 85 | model = model.merge_and_unload() 86 | print('Model is loaded...') 87 | elif model_base is not None: 88 | # this may be mm projector only 89 | print('Loading LLaVA from base model...') 90 | if 'mpt' in model_name.lower(): 91 | if not os.path.isfile(os.path.join(model_path, 'configuration_mpt.py')): 92 | shutil.copyfile(os.path.join(model_base, 'configuration_mpt.py'), os.path.join(model_path, 'configuration_mpt.py')) 93 | tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=True) 94 | cfg_pretrained = AutoConfig.from_pretrained(model_path, trust_remote_code=True) 95 | model = LlavaMptForCausalLM.from_pretrained(model_base, low_cpu_mem_usage=True, config=cfg_pretrained, **kwargs) 96 | else: 97 | tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=False) 98 | cfg_pretrained = AutoConfig.from_pretrained(model_path) 99 | model = LlavaLlamaForCausalLM.from_pretrained(model_base, low_cpu_mem_usage=True, config=cfg_pretrained, **kwargs) 100 | 101 | mm_projector_weights = torch.load(os.path.join(model_path, 'mm_projector.bin'), map_location='cpu') 102 | mm_projector_weights = {k: v.to(torch.float16) for k, v in mm_projector_weights.items()} 103 | model.load_state_dict(mm_projector_weights, strict=False) 104 | else: 105 | if 'mpt' in model_name.lower(): 106 | tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=True) 107 | model = LlavaMptForCausalLM.from_pretrained(model_path, 
low_cpu_mem_usage=True, **kwargs) 108 | elif 'mistral' in model_name.lower(): 109 | tokenizer = AutoTokenizer.from_pretrained(model_path) 110 | model = LlavaMistralForCausalLM.from_pretrained( 111 | model_path, 112 | low_cpu_mem_usage=True, 113 | **kwargs 114 | ) 115 | else: 116 | tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False) 117 | model = LlavaLlamaForCausalLM.from_pretrained( 118 | model_path, 119 | low_cpu_mem_usage=True, 120 | **kwargs 121 | ) 122 | else: 123 | # Load language model 124 | if model_base is not None: 125 | # PEFT model 126 | from peft import PeftModel 127 | tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=False) 128 | model = AutoModelForCausalLM.from_pretrained(model_base, low_cpu_mem_usage=True, **kwargs) 129 | print(f"Loading LoRA weights from {model_path}") 130 | model = PeftModel.from_pretrained(model, model_path) 131 | print(f"Merging weights") 132 | model = model.merge_and_unload() 133 | print('Convert to FP16...') 134 | model.to(torch.float16) 135 | else: 136 | use_fast = False 137 | if 'mpt' in model_name.lower(): 138 | tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=True) 139 | model = AutoModelForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, trust_remote_code=True, **kwargs) 140 | else: 141 | tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False) 142 | model = AutoModelForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, **kwargs) 143 | 144 | image_processor = None 145 | 146 | if 'llava' in model_name.lower(): 147 | mm_use_im_start_end = getattr(model.config, "mm_use_im_start_end", False) 148 | mm_use_im_patch_token = getattr(model.config, "mm_use_im_patch_token", True) 149 | if mm_use_im_patch_token: 150 | tokenizer.add_tokens([DEFAULT_IMAGE_PATCH_TOKEN], special_tokens=True) 151 | if mm_use_im_start_end: 152 | tokenizer.add_tokens([DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True) 153 | model.resize_token_embeddings(len(tokenizer)) 154 | 155 | vision_tower = model.get_vision_tower() 156 | if not vision_tower.is_loaded: 157 | vision_tower.load_model(device_map=device_map) 158 | if device_map != 'auto': 159 | vision_tower.to(device=device_map, dtype=torch.float16) 160 | # vision_tower.to(device=device, dtype=torch.float16) 161 | image_processor = vision_tower.image_processor 162 | 163 | if hasattr(model.config, "max_sequence_length"): 164 | context_len = model.config.max_sequence_length 165 | else: 166 | context_len = 2048 167 | 168 | return tokenizer, model, image_processor, context_len 169 | -------------------------------------------------------------------------------- /llava/model/consolidate.py: -------------------------------------------------------------------------------- 1 | """ 2 | Usage: 3 | python3 -m llava.model.consolidate --src ~/model_weights/llava-7b --dst ~/model_weights/llava-7b_consolidate 4 | """ 5 | import argparse 6 | 7 | import torch 8 | from transformers import AutoTokenizer, AutoModelForCausalLM 9 | from llava.model import * 10 | from llava.model.utils import auto_upgrade 11 | 12 | 13 | def consolidate_ckpt(src_path, dst_path): 14 | print("Loading model") 15 | auto_upgrade(src_path) 16 | src_model = AutoModelForCausalLM.from_pretrained(src_path, torch_dtype=torch.float16, low_cpu_mem_usage=True) 17 | src_tokenizer = AutoTokenizer.from_pretrained(src_path, use_fast=False) 18 | src_model.save_pretrained(dst_path) 19 | src_tokenizer.save_pretrained(dst_path) 20 | 21 | 22 | if __name__ == "__main__": 23 | parser = 
argparse.ArgumentParser() 24 | parser.add_argument("--src", type=str, required=True) 25 | parser.add_argument("--dst", type=str, required=True) 26 | 27 | args = parser.parse_args() 28 | 29 | consolidate_ckpt(args.src, args.dst) 30 | -------------------------------------------------------------------------------- /llava/model/language_model/llava_llama.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 Haotian Liu 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | 16 | from typing import List, Optional, Tuple, Union 17 | 18 | import torch 19 | import torch.nn as nn 20 | 21 | from transformers import AutoConfig, AutoModelForCausalLM, \ 22 | LlamaConfig, LlamaModel, LlamaForCausalLM 23 | 24 | from transformers.modeling_outputs import CausalLMOutputWithPast 25 | from transformers.generation.utils import GenerateOutput 26 | 27 | from ..llava_arch import LlavaMetaModel, LlavaMetaForCausalLM 28 | 29 | 30 | class LlavaConfig(LlamaConfig): 31 | model_type = "llava_llama" 32 | 33 | 34 | class LlavaLlamaModel(LlavaMetaModel, LlamaModel): 35 | config_class = LlavaConfig 36 | 37 | def __init__(self, config: LlamaConfig): 38 | super(LlavaLlamaModel, self).__init__(config) 39 | 40 | 41 | class LlavaLlamaForCausalLM(LlamaForCausalLM, LlavaMetaForCausalLM): 42 | config_class = LlavaConfig 43 | 44 | def __init__(self, config): 45 | super(LlamaForCausalLM, self).__init__(config) 46 | self.model = LlavaLlamaModel(config) 47 | self.pretraining_tp = config.pretraining_tp 48 | self.vocab_size = config.vocab_size 49 | self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) 50 | 51 | # Initialize weights and apply final processing 52 | self.post_init() 53 | 54 | def get_model(self): 55 | return self.model 56 | 57 | def forward( 58 | self, 59 | input_ids: torch.LongTensor = None, 60 | attention_mask: Optional[torch.Tensor] = None, 61 | position_ids: Optional[torch.LongTensor] = None, 62 | past_key_values: Optional[List[torch.FloatTensor]] = None, 63 | inputs_embeds: Optional[torch.FloatTensor] = None, 64 | labels: Optional[torch.LongTensor] = None, 65 | use_cache: Optional[bool] = None, 66 | output_attentions: Optional[bool] = None, 67 | output_hidden_states: Optional[bool] = None, 68 | images: Optional[torch.FloatTensor] = None, 69 | image_sizes: Optional[List[List[int]]] = None, 70 | return_dict: Optional[bool] = None, 71 | ) -> Union[Tuple, CausalLMOutputWithPast]: 72 | 73 | if inputs_embeds is None: 74 | ( 75 | input_ids, 76 | position_ids, 77 | attention_mask, 78 | past_key_values, 79 | inputs_embeds, 80 | labels 81 | ) = self.prepare_inputs_labels_for_multimodal( 82 | input_ids, 83 | position_ids, 84 | attention_mask, 85 | past_key_values, 86 | labels, 87 | images, 88 | image_sizes, 89 | ) 90 | 91 | return super().forward( 92 | input_ids=input_ids, 93 | attention_mask=attention_mask, 94 | position_ids=position_ids, 95 | past_key_values=past_key_values, 96 | inputs_embeds=inputs_embeds, 97 | 
labels=labels, 98 | use_cache=use_cache, 99 | output_attentions=output_attentions, 100 | output_hidden_states=output_hidden_states, 101 | return_dict=return_dict 102 | ) 103 | 104 | @torch.no_grad() 105 | def generate( 106 | self, 107 | inputs: Optional[torch.Tensor] = None, 108 | images: Optional[torch.Tensor] = None, 109 | image_sizes: Optional[torch.Tensor] = None, 110 | **kwargs, 111 | ) -> Union[GenerateOutput, torch.LongTensor]: 112 | position_ids = kwargs.pop("position_ids", None) 113 | attention_mask = kwargs.pop("attention_mask", None) 114 | if "inputs_embeds" in kwargs: 115 | raise NotImplementedError("`inputs_embeds` is not supported") 116 | 117 | if images is not None: 118 | ( 119 | inputs, 120 | position_ids, 121 | attention_mask, 122 | _, 123 | inputs_embeds, 124 | _ 125 | ) = self.prepare_inputs_labels_for_multimodal( 126 | inputs, 127 | position_ids, 128 | attention_mask, 129 | None, 130 | None, 131 | images, 132 | image_sizes=image_sizes, 133 | ) 134 | else: 135 | inputs_embeds = self.get_model().embed_tokens(inputs) 136 | 137 | return super().generate( 138 | position_ids=position_ids, 139 | attention_mask=attention_mask, 140 | inputs_embeds=inputs_embeds, 141 | **kwargs 142 | ) 143 | 144 | def prepare_inputs_for_generation(self, input_ids, past_key_values=None, 145 | inputs_embeds=None, **kwargs): 146 | images = kwargs.pop("images", None) 147 | image_sizes = kwargs.pop("image_sizes", None) 148 | inputs = super().prepare_inputs_for_generation( 149 | input_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, **kwargs 150 | ) 151 | if images is not None: 152 | inputs['images'] = images 153 | if image_sizes is not None: 154 | inputs['image_sizes'] = image_sizes 155 | 156 | return inputs 157 | 158 | AutoConfig.register("llava_llama", LlavaConfig) 159 | AutoModelForCausalLM.register(LlavaConfig, LlavaLlamaForCausalLM) 160 | -------------------------------------------------------------------------------- /llava/model/language_model/llava_mistral.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 Haotian Liu 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
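# Note: this file mirrors llava_llama.py above -- it mixes LlavaMetaModel /
# LlavaMetaForCausalLM into the Mistral backbone and registers the
# "llava_mistral" model type with AutoConfig / AutoModelForCausalLM at the
# bottom, so configs saved with that model_type resolve to these classes.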
14 | 15 | 16 | from typing import List, Optional, Tuple, Union 17 | 18 | import torch 19 | import torch.nn as nn 20 | from torch.nn import CrossEntropyLoss 21 | 22 | from transformers import AutoConfig, AutoModelForCausalLM, \ 23 | MistralConfig, MistralModel, MistralForCausalLM 24 | 25 | from transformers.modeling_outputs import CausalLMOutputWithPast 26 | from transformers.generation.utils import GenerateOutput 27 | 28 | from ..llava_arch import LlavaMetaModel, LlavaMetaForCausalLM 29 | 30 | 31 | class LlavaMistralConfig(MistralConfig): 32 | model_type = "llava_mistral" 33 | 34 | 35 | class LlavaMistralModel(LlavaMetaModel, MistralModel): 36 | config_class = LlavaMistralConfig 37 | 38 | def __init__(self, config: MistralConfig): 39 | super(LlavaMistralModel, self).__init__(config) 40 | 41 | 42 | class LlavaMistralForCausalLM(MistralForCausalLM, LlavaMetaForCausalLM): 43 | config_class = LlavaMistralConfig 44 | 45 | def __init__(self, config): 46 | super(MistralForCausalLM, self).__init__(config) 47 | self.model = LlavaMistralModel(config) 48 | 49 | self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) 50 | 51 | # Initialize weights and apply final processing 52 | self.post_init() 53 | 54 | def get_model(self): 55 | return self.model 56 | 57 | def forward( 58 | self, 59 | input_ids: torch.LongTensor = None, 60 | attention_mask: Optional[torch.Tensor] = None, 61 | position_ids: Optional[torch.LongTensor] = None, 62 | past_key_values: Optional[List[torch.FloatTensor]] = None, 63 | inputs_embeds: Optional[torch.FloatTensor] = None, 64 | labels: Optional[torch.LongTensor] = None, 65 | use_cache: Optional[bool] = None, 66 | output_attentions: Optional[bool] = None, 67 | output_hidden_states: Optional[bool] = None, 68 | images: Optional[torch.FloatTensor] = None, 69 | image_sizes: Optional[List[List[int]]] = None, 70 | return_dict: Optional[bool] = None, 71 | ) -> Union[Tuple, CausalLMOutputWithPast]: 72 | 73 | if inputs_embeds is None: 74 | ( 75 | input_ids, 76 | position_ids, 77 | attention_mask, 78 | past_key_values, 79 | inputs_embeds, 80 | labels 81 | ) = self.prepare_inputs_labels_for_multimodal( 82 | input_ids, 83 | position_ids, 84 | attention_mask, 85 | past_key_values, 86 | labels, 87 | images, 88 | image_sizes 89 | ) 90 | 91 | return super().forward( 92 | input_ids=input_ids, 93 | attention_mask=attention_mask, 94 | position_ids=position_ids, 95 | past_key_values=past_key_values, 96 | inputs_embeds=inputs_embeds, 97 | labels=labels, 98 | use_cache=use_cache, 99 | output_attentions=output_attentions, 100 | output_hidden_states=output_hidden_states, 101 | return_dict=return_dict 102 | ) 103 | 104 | @torch.no_grad() 105 | def generate( 106 | self, 107 | inputs: Optional[torch.Tensor] = None, 108 | images: Optional[torch.Tensor] = None, 109 | image_sizes: Optional[torch.Tensor] = None, 110 | **kwargs, 111 | ) -> Union[GenerateOutput, torch.LongTensor]: 112 | position_ids = kwargs.pop("position_ids", None) 113 | attention_mask = kwargs.pop("attention_mask", None) 114 | if "inputs_embeds" in kwargs: 115 | raise NotImplementedError("`inputs_embeds` is not supported") 116 | 117 | if images is not None: 118 | ( 119 | inputs, 120 | position_ids, 121 | attention_mask, 122 | _, 123 | inputs_embeds, 124 | _ 125 | ) = self.prepare_inputs_labels_for_multimodal( 126 | inputs, 127 | position_ids, 128 | attention_mask, 129 | None, 130 | None, 131 | images, 132 | image_sizes=image_sizes 133 | ) 134 | else: 135 | inputs_embeds = self.get_model().embed_tokens(inputs) 136 | 137 | 
return super().generate( 138 | position_ids=position_ids, 139 | attention_mask=attention_mask, 140 | inputs_embeds=inputs_embeds, 141 | **kwargs 142 | ) 143 | 144 | def prepare_inputs_for_generation(self, input_ids, past_key_values=None, 145 | inputs_embeds=None, **kwargs): 146 | images = kwargs.pop("images", None) 147 | image_sizes = kwargs.pop("image_sizes", None) 148 | inputs = super().prepare_inputs_for_generation( 149 | input_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, **kwargs 150 | ) 151 | if images is not None: 152 | inputs['images'] = images 153 | if image_sizes is not None: 154 | inputs['image_sizes'] = image_sizes 155 | return inputs 156 | 157 | AutoConfig.register("llava_mistral", LlavaMistralConfig) 158 | AutoModelForCausalLM.register(LlavaMistralConfig, LlavaMistralForCausalLM) 159 | -------------------------------------------------------------------------------- /llava/model/language_model/llava_mpt.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 Haotian Liu 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | 16 | from typing import Optional, Tuple 17 | 18 | import torch 19 | 20 | from transformers import AutoConfig, AutoModelForCausalLM, \ 21 | MptConfig, MptForCausalLM, MptModel 22 | from llava.model.llava_arch import LlavaMetaModel, LlavaMetaForCausalLM 23 | 24 | 25 | class LlavaMptConfig(MptConfig): 26 | model_type = "llava_mpt" 27 | 28 | 29 | class LlavaMptModel(LlavaMetaModel, MptModel): 30 | config_class = LlavaMptConfig 31 | 32 | def __init__(self, config: MptConfig): 33 | config.hidden_size = config.d_model 34 | super(LlavaMptModel, self).__init__(config) 35 | 36 | def embed_tokens(self, x): 37 | return self.wte(x) 38 | 39 | 40 | class LlavaMptForCausalLM(MptForCausalLM, LlavaMetaForCausalLM): 41 | config_class = LlavaMptConfig 42 | supports_gradient_checkpointing = True 43 | 44 | def __init__(self, config): 45 | super(MptForCausalLM, self).__init__(config) 46 | 47 | self.transformer = LlavaMptModel(config) 48 | self.lm_head = torch.nn.Linear(config.hidden_size, config.vocab_size, bias=False) 49 | 50 | # Initialize weights and apply final processing 51 | self.post_init() 52 | 53 | def get_model(self): 54 | return self.transformer 55 | 56 | def _set_gradient_checkpointing(self, module, value=False): 57 | if isinstance(module, LlavaMptModel): 58 | module.gradient_checkpointing = value 59 | 60 | def forward( 61 | self, 62 | input_ids: Optional[torch.LongTensor] = None, 63 | past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None, 64 | attention_mask: Optional[torch.Tensor] = None, 65 | inputs_embeds: Optional[torch.Tensor] = None, 66 | labels: Optional[torch.Tensor] = None, 67 | use_cache: Optional[bool] = None, 68 | output_attentions: Optional[bool] = None, 69 | output_hidden_states: Optional[bool] = None, 70 | return_dict: Optional[bool] = None, 71 | images=None): 72 | 73 | input_ids, attention_mask, past_key_values, 
inputs_embeds, labels = self.prepare_inputs_labels_for_multimodal(input_ids, attention_mask, past_key_values, labels, images) 74 | 75 | return super().forward( 76 | input_ids, 77 | past_key_values=past_key_values, 78 | attention_mask=attention_mask, 79 | inputs_embeds=inputs_embeds, 80 | labels=labels, 81 | use_cache=use_cache, 82 | output_attentions=output_attentions, 83 | output_hidden_states=output_hidden_states, 84 | return_dict=return_dict, 85 | ) 86 | 87 | def prepare_inputs_for_generation(self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs): 88 | images = kwargs.pop("images", None) 89 | _inputs = super().prepare_inputs_for_generation( 90 | input_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, **kwargs 91 | ) 92 | _inputs['images'] = images 93 | return _inputs 94 | 95 | 96 | AutoConfig.register("llava_mpt", LlavaMptConfig) 97 | AutoModelForCausalLM.register(LlavaMptConfig, LlavaMptForCausalLM) 98 | -------------------------------------------------------------------------------- /llava/model/llava_arch.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 Haotian Liu 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | 16 | from abc import ABC, abstractmethod 17 | 18 | import torch 19 | import torch.nn as nn 20 | 21 | from .multimodal_encoder.builder import build_vision_tower 22 | from .multimodal_projector.builder import build_vision_projector 23 | 24 | from llava.constants import IGNORE_INDEX, IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_PATCH_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN 25 | 26 | from llava.mm_utils import get_anyres_image_grid_shape 27 | 28 | from einops import rearrange 29 | 30 | 31 | class LlavaMetaModel: 32 | 33 | def __init__(self, config): 34 | super(LlavaMetaModel, self).__init__(config) 35 | 36 | if hasattr(config, "mm_vision_tower"): 37 | self.vision_tower = build_vision_tower(config, delay_load=True) 38 | self.mm_projector = build_vision_projector(config) 39 | 40 | if 'unpad' in getattr(config, 'mm_patch_merge_type', ''): 41 | self.image_newline = nn.Parameter( 42 | torch.empty(config.hidden_size, dtype=self.dtype) 43 | ) 44 | 45 | def get_vision_tower(self): 46 | vision_tower = getattr(self, 'vision_tower', None) 47 | if type(vision_tower) is list: 48 | vision_tower = vision_tower[0] 49 | return vision_tower 50 | 51 | def initialize_vision_modules(self, model_args, fsdp=None): 52 | vision_tower = model_args.vision_tower 53 | mm_vision_select_layer = model_args.mm_vision_select_layer 54 | mm_vision_select_feature = model_args.mm_vision_select_feature 55 | pretrain_mm_mlp_adapter = model_args.pretrain_mm_mlp_adapter 56 | mm_patch_merge_type = model_args.mm_patch_merge_type 57 | 58 | self.config.mm_vision_tower = vision_tower 59 | 60 | if self.get_vision_tower() is None: 61 | vision_tower = build_vision_tower(model_args) 62 | 63 | if fsdp is not None and len(fsdp) > 0: 64 | self.vision_tower = [vision_tower] 65 | else: 66 | 
self.vision_tower = vision_tower 67 | else: 68 | if fsdp is not None and len(fsdp) > 0: 69 | vision_tower = self.vision_tower[0] 70 | else: 71 | vision_tower = self.vision_tower 72 | vision_tower.load_model() 73 | 74 | self.config.use_mm_proj = True 75 | self.config.mm_projector_type = getattr(model_args, 'mm_projector_type', 'linear') 76 | self.config.mm_hidden_size = vision_tower.hidden_size 77 | self.config.mm_vision_select_layer = mm_vision_select_layer 78 | self.config.mm_vision_select_feature = mm_vision_select_feature 79 | self.config.mm_patch_merge_type = mm_patch_merge_type 80 | 81 | if getattr(self, 'mm_projector', None) is None: 82 | self.mm_projector = build_vision_projector(self.config) 83 | 84 | if 'unpad' in mm_patch_merge_type: 85 | embed_std = 1 / torch.sqrt(torch.tensor(self.config.hidden_size, dtype=self.dtype)) 86 | self.image_newline = nn.Parameter( 87 | torch.randn(self.config.hidden_size, dtype=self.dtype) * embed_std 88 | ) 89 | else: 90 | # In case it is frozen by LoRA 91 | for p in self.mm_projector.parameters(): 92 | p.requires_grad = True 93 | 94 | if pretrain_mm_mlp_adapter is not None: 95 | mm_projector_weights = torch.load(pretrain_mm_mlp_adapter, map_location='cpu') 96 | def get_w(weights, keyword): 97 | return {k.split(keyword + '.')[1]: v for k, v in weights.items() if keyword in k} 98 | 99 | self.mm_projector.load_state_dict(get_w(mm_projector_weights, 'mm_projector')) 100 | 101 | 102 | def unpad_image(tensor, original_size): 103 | """ 104 | Unpads a PyTorch tensor of a padded and resized image. 105 | 106 | Args: 107 | tensor (torch.Tensor): The image tensor, assumed to be in CxHxW format. 108 | original_size (tuple): The original size of the image (height, width). 109 | 110 | Returns: 111 | torch.Tensor: The unpadded image tensor. 112 | """ 113 | original_width, original_height = original_size 114 | current_height, current_width = tensor.shape[1:] 115 | 116 | original_aspect_ratio = original_width / original_height 117 | current_aspect_ratio = current_width / current_height 118 | 119 | if original_aspect_ratio > current_aspect_ratio: 120 | scale_factor = current_width / original_width 121 | new_height = int(original_height * scale_factor) 122 | padding = (current_height - new_height) // 2 123 | unpadded_tensor = tensor[:, padding:current_height - padding, :] 124 | else: 125 | scale_factor = current_height / original_height 126 | new_width = int(original_width * scale_factor) 127 | padding = (current_width - new_width) // 2 128 | unpadded_tensor = tensor[:, :, padding:current_width - padding] 129 | 130 | return unpadded_tensor 131 | 132 | 133 | class LlavaMetaForCausalLM(ABC): 134 | 135 | @abstractmethod 136 | def get_model(self): 137 | pass 138 | 139 | def get_vision_tower(self): 140 | return self.get_model().get_vision_tower() 141 | 142 | def encode_images(self, images): 143 | image_features = self.get_model().get_vision_tower()(images) 144 | image_features = self.get_model().mm_projector(image_features) 145 | return image_features 146 | 147 | 148 | def temporal_aggregation(self, image_features): 149 | T, N, D = image_features.shape 150 | 151 | ## D1: temporal cat (Just one line!) 
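        # D1 flattens the per-frame tokens along time, [T, N, D] -> [T*N, D]; the
        # unsqueeze(0) at the end of this method then yields a single sequence of
        # shape [1, T*N, D] for the LLM. Illustrative sizes (not fixed here): with
        # T=4 frames and N=576 CLIP patch tokens, the video contributes 2304 tokens.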
152 | image_features = image_features.view(T * N, D) # [T*N D] 153 | 154 | ## D2: spatial pool + temporal cat (Uncomment to use) 155 | # pool2 = nn.MaxPool1d(kernel_size=2, stride=2) 156 | # image_features = rearrange(image_features, 't n d -> t d n') 157 | # image_features = pool2(image_features) # [t d n] -> [t d (n/2)] 158 | # image_features = rearrange(image_features, 't d n -> t n d', t=T) 159 | # image_features = image_features.view(-1, D) # [T*N D] 160 | 161 | ## S1: GAP 162 | # image_features = torch.mean(image_features, dim=0) # [T N D] -> [N D] 163 | 164 | ####### unsqueeze 165 | image_features = image_features.unsqueeze(0) # [1 T*N D] 166 | return image_features 167 | 168 | def prepare_inputs_labels_for_multimodal( 169 | self, input_ids, position_ids, attention_mask, past_key_values, labels, 170 | images, image_sizes=None 171 | ): 172 | vision_tower = self.get_vision_tower() 173 | if vision_tower is None or images is None or input_ids.shape[1] == 1: 174 | return input_ids, position_ids, attention_mask, past_key_values, None, labels 175 | 176 | if type(images) is list or images.ndim == 5: 177 | # images: [T S C H W] 178 | if type(images) is list: 179 | images = [x.unsqueeze(0) if x.ndim == 3 else x for x in images] 180 | concat_images = torch.cat([image for image in images], dim=0) # [TS C H W] 181 | image_features = self.encode_images(concat_images) # [TS N D] 182 | split_sizes = [image.shape[0] for image in images] # T * [S] 183 | image_features = torch.split(image_features, split_sizes, dim=0) # T * [S N D] 184 | mm_patch_merge_type = getattr(self.config, 'mm_patch_merge_type', 'flat') 185 | image_aspect_ratio = getattr(self.config, 'image_aspect_ratio', 'square') 186 | if mm_patch_merge_type == 'flat': 187 | image_features = [x.flatten(0, 1) for x in image_features] 188 | elif mm_patch_merge_type.startswith('spatial'): 189 | new_image_features = [] 190 | for image_idx, image_feature in enumerate(image_features): 191 | if image_feature.shape[0] > 1: 192 | base_image_feature = image_feature[0] # [N D] 193 | image_feature = image_feature[1:] # [S-1 N D] 194 | 195 | height = width = self.get_vision_tower().num_patches_per_side 196 | assert height * width == base_image_feature.shape[0] 197 | if image_aspect_ratio == 'anyres': 198 | num_patch_width, num_patch_height = get_anyres_image_grid_shape(image_sizes[image_idx], self.config.image_grid_pinpoints, self.get_vision_tower().config.image_size) 199 | image_feature = image_feature.view(num_patch_height, num_patch_width, height, width, -1) # [sqrt(S-1) sqrt(S-1) H W D] 200 | else: 201 | raise NotImplementedError 202 | if 'unpad' in mm_patch_merge_type: 203 | image_feature = image_feature.permute(4, 0, 2, 1, 3).contiguous() 204 | image_feature = image_feature.flatten(1, 2).flatten(2, 3) 205 | image_feature = unpad_image(image_feature, image_sizes[image_idx]) 206 | image_feature = torch.cat(( 207 | image_feature, 208 | self.model.image_newline[:, None, None].expand(*image_feature.shape[:-1], 1).to(image_feature.device) 209 | ), dim=-1) 210 | image_feature = image_feature.flatten(1, 2).transpose(0, 1) 211 | else: 212 | image_feature = image_feature.permute(0, 2, 1, 3, 4).contiguous() 213 | image_feature = image_feature.flatten(0, 3) 214 | # image_feature = torch.cat((base_image_feature, image_feature), dim=0) 215 | # Using base feature only for LLaVA-1.6 to avoid extra tokens from high-resolution features. 
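                        # Keeping only the base feature means each frame contributes a fixed
                        # 576 tokens (24x24 patches for a 336-px CLIP ViT-L/14 tower, which
                        # matches the [576 D] notes below), independent of the anyres grid.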
216 | image_feature = base_image_feature # [576 D] 217 | else: 218 | image_feature = image_feature[0] 219 | if 'unpad' in mm_patch_merge_type: 220 | image_feature = torch.cat(( 221 | image_feature, 222 | self.model.image_newline[None].to(image_feature.device) 223 | ), dim=0) 224 | new_image_features.append(image_feature) 225 | image_features = new_image_features 226 | # len=T, [1948 D] 227 | # whwu. concat 228 | image_features = torch.stack(image_features, dim=0) # [T New_N D] 229 | else: 230 | raise ValueError(f"Unexpected mm_patch_merge_type: {self.config.mm_patch_merge_type}") 231 | else: 232 | image_features = self.encode_images(images).to(self.device) # [T 576 D] 233 | 234 | # whwu: unsqueeze for image_features 235 | image_features = self.temporal_aggregation(image_features) 236 | 237 | 238 | # TODO: image start / end is not implemented here to support pretraining. 239 | if getattr(self.config, 'tune_mm_mlp_adapter', False) and getattr(self.config, 'mm_use_im_start_end', False): 240 | raise NotImplementedError 241 | 242 | # Let's just add dummy tensors if they do not exist, 243 | # it is a headache to deal with None all the time. 244 | # But it is not ideal, and if you have a better idea, 245 | # please open an issue / submit a PR, thanks. 246 | _labels = labels 247 | _position_ids = position_ids 248 | _attention_mask = attention_mask 249 | if attention_mask is None: 250 | attention_mask = torch.ones_like(input_ids, dtype=torch.bool) 251 | else: 252 | attention_mask = attention_mask.bool() 253 | if position_ids is None: 254 | position_ids = torch.arange(0, input_ids.shape[1], dtype=torch.long, device=input_ids.device) 255 | if labels is None: 256 | labels = torch.full_like(input_ids, IGNORE_INDEX) 257 | 258 | # remove the padding using attention_mask -- FIXME 259 | _input_ids = input_ids 260 | input_ids = [cur_input_ids[cur_attention_mask] for cur_input_ids, cur_attention_mask in zip(input_ids, attention_mask)] 261 | labels = [cur_labels[cur_attention_mask] for cur_labels, cur_attention_mask in zip(labels, attention_mask)] 262 | 263 | new_input_embeds = [] 264 | new_labels = [] 265 | cur_image_idx = 0 266 | for batch_idx, cur_input_ids in enumerate(input_ids): 267 | num_images = (cur_input_ids == IMAGE_TOKEN_INDEX).sum() 268 | # always 1 269 | if num_images == 0: 270 | cur_image_features = image_features[cur_image_idx] 271 | cur_input_embeds_1 = self.get_model().embed_tokens(cur_input_ids) 272 | cur_input_embeds = torch.cat([cur_input_embeds_1, cur_image_features[0:0]], dim=0) 273 | new_input_embeds.append(cur_input_embeds) 274 | new_labels.append(labels[batch_idx]) 275 | cur_image_idx += 1 276 | continue 277 | 278 | image_token_indices = [-1] + torch.where(cur_input_ids == IMAGE_TOKEN_INDEX)[0].tolist() + [cur_input_ids.shape[0]] 279 | cur_input_ids_noim = [] 280 | cur_labels = labels[batch_idx] 281 | cur_labels_noim = [] 282 | for i in range(len(image_token_indices) - 1): 283 | cur_input_ids_noim.append(cur_input_ids[image_token_indices[i]+1:image_token_indices[i+1]]) 284 | cur_labels_noim.append(cur_labels[image_token_indices[i]+1:image_token_indices[i+1]]) 285 | split_sizes = [x.shape[0] for x in cur_labels_noim] 286 | cur_input_embeds = self.get_model().embed_tokens(torch.cat(cur_input_ids_noim)) 287 | cur_input_embeds_no_im = torch.split(cur_input_embeds, split_sizes, dim=0) 288 | cur_new_input_embeds = [] 289 | cur_new_labels = [] 290 | 291 | for i in range(num_images + 1): 292 | cur_new_input_embeds.append(cur_input_embeds_no_im[i]) 293 | 
cur_new_labels.append(cur_labels_noim[i]) 294 | if i < num_images: 295 | cur_image_features = image_features[cur_image_idx] 296 | cur_image_idx += 1 297 | cur_new_input_embeds.append(cur_image_features) 298 | cur_new_labels.append(torch.full((cur_image_features.shape[0],), IGNORE_INDEX, device=cur_labels.device, dtype=cur_labels.dtype)) 299 | 300 | cur_new_input_embeds = [x.to(self.device) for x in cur_new_input_embeds] 301 | 302 | cur_new_input_embeds = torch.cat(cur_new_input_embeds) # 627: 576 + 51, 1999: 1948 + 51, 51 is prefix 303 | cur_new_labels = torch.cat(cur_new_labels) 304 | 305 | new_input_embeds.append(cur_new_input_embeds) 306 | new_labels.append(cur_new_labels) 307 | 308 | # Truncate sequences to max length as image embeddings can make the sequence longer 309 | tokenizer_model_max_length = getattr(self.config, 'tokenizer_model_max_length', None) 310 | if tokenizer_model_max_length is not None: 311 | new_input_embeds = [x[:tokenizer_model_max_length] for x in new_input_embeds] 312 | new_labels = [x[:tokenizer_model_max_length] for x in new_labels] 313 | 314 | # Combine them 315 | max_len = max(x.shape[0] for x in new_input_embeds) 316 | batch_size = len(new_input_embeds) 317 | 318 | new_input_embeds_padded = [] 319 | new_labels_padded = torch.full((batch_size, max_len), IGNORE_INDEX, dtype=new_labels[0].dtype, device=new_labels[0].device) 320 | attention_mask = torch.zeros((batch_size, max_len), dtype=attention_mask.dtype, device=attention_mask.device) 321 | position_ids = torch.zeros((batch_size, max_len), dtype=position_ids.dtype, device=position_ids.device) 322 | 323 | for i, (cur_new_embed, cur_new_labels) in enumerate(zip(new_input_embeds, new_labels)): 324 | cur_len = cur_new_embed.shape[0] 325 | if getattr(self.config, 'tokenizer_padding_side', 'right') == "left": 326 | new_input_embeds_padded.append(torch.cat(( 327 | torch.zeros((max_len - cur_len, cur_new_embed.shape[1]), dtype=cur_new_embed.dtype, device=cur_new_embed.device), 328 | cur_new_embed 329 | ), dim=0)) 330 | if cur_len > 0: 331 | new_labels_padded[i, -cur_len:] = cur_new_labels 332 | attention_mask[i, -cur_len:] = True 333 | position_ids[i, -cur_len:] = torch.arange(0, cur_len, dtype=position_ids.dtype, device=position_ids.device) 334 | else: 335 | new_input_embeds_padded.append(torch.cat(( 336 | cur_new_embed, 337 | torch.zeros((max_len - cur_len, cur_new_embed.shape[1]), dtype=cur_new_embed.dtype, device=cur_new_embed.device) 338 | ), dim=0)) 339 | if cur_len > 0: 340 | new_labels_padded[i, :cur_len] = cur_new_labels 341 | attention_mask[i, :cur_len] = True 342 | position_ids[i, :cur_len] = torch.arange(0, cur_len, dtype=position_ids.dtype, device=position_ids.device) 343 | 344 | new_input_embeds = torch.stack(new_input_embeds_padded, dim=0) 345 | 346 | if _labels is None: 347 | new_labels = None 348 | else: 349 | new_labels = new_labels_padded 350 | 351 | if _attention_mask is None: 352 | attention_mask = None 353 | else: 354 | attention_mask = attention_mask.to(dtype=_attention_mask.dtype) 355 | 356 | if _position_ids is None: 357 | position_ids = None 358 | 359 | return None, position_ids, attention_mask, past_key_values, new_input_embeds, new_labels 360 | 361 | 362 | def initialize_vision_tokenizer(self, model_args, tokenizer): 363 | if model_args.mm_use_im_patch_token: 364 | tokenizer.add_tokens([DEFAULT_IMAGE_PATCH_TOKEN], special_tokens=True) 365 | self.resize_token_embeddings(len(tokenizer)) 366 | 367 | if model_args.mm_use_im_start_end: 368 | num_new_tokens = 
tokenizer.add_tokens([DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True) 369 | self.resize_token_embeddings(len(tokenizer)) 370 | 371 | if num_new_tokens > 0: 372 | input_embeddings = self.get_input_embeddings().weight.data 373 | output_embeddings = self.get_output_embeddings().weight.data 374 | 375 | input_embeddings_avg = input_embeddings[:-num_new_tokens].mean( 376 | dim=0, keepdim=True) 377 | output_embeddings_avg = output_embeddings[:-num_new_tokens].mean( 378 | dim=0, keepdim=True) 379 | 380 | input_embeddings[-num_new_tokens:] = input_embeddings_avg 381 | output_embeddings[-num_new_tokens:] = output_embeddings_avg 382 | 383 | if model_args.tune_mm_mlp_adapter: 384 | for p in self.get_input_embeddings().parameters(): 385 | p.requires_grad = True 386 | for p in self.get_output_embeddings().parameters(): 387 | p.requires_grad = False 388 | 389 | if model_args.pretrain_mm_mlp_adapter: 390 | mm_projector_weights = torch.load(model_args.pretrain_mm_mlp_adapter, map_location='cpu') 391 | embed_tokens_weight = mm_projector_weights['model.embed_tokens.weight'] 392 | assert num_new_tokens == 2 393 | if input_embeddings.shape == embed_tokens_weight.shape: 394 | input_embeddings[-num_new_tokens:] = embed_tokens_weight[-num_new_tokens:] 395 | elif embed_tokens_weight.shape[0] == num_new_tokens: 396 | input_embeddings[-num_new_tokens:] = embed_tokens_weight 397 | else: 398 | raise ValueError(f"Unexpected embed_tokens_weight shape. Pretrained: {embed_tokens_weight.shape}. Current: {input_embeddings.shape}. Number of new tokens: {num_new_tokens}.") 399 | elif model_args.mm_use_im_patch_token: 400 | if model_args.tune_mm_mlp_adapter: 401 | for p in self.get_input_embeddings().parameters(): 402 | p.requires_grad = False 403 | for p in self.get_output_embeddings().parameters(): 404 | p.requires_grad = False 405 | -------------------------------------------------------------------------------- /llava/model/make_delta.py: -------------------------------------------------------------------------------- 1 | """ 2 | Usage: 3 | python3 -m llava.model.make_delta --base-model-path ~/model_weights/llama-7b --target-model-path ~/model_weights/llava-7b --delta-path ~/model_weights/llava-7b-delta --hub-repo-id liuhaotian/llava-7b-delta 4 | """ 5 | import argparse 6 | 7 | import torch 8 | from tqdm import tqdm 9 | from transformers import AutoTokenizer, AutoModelForCausalLM 10 | from llava.model.utils import auto_upgrade 11 | 12 | 13 | def make_delta(base_model_path, target_model_path, delta_path, hub_repo_id): 14 | print("Loading base model") 15 | base = AutoModelForCausalLM.from_pretrained( 16 | base_model_path, torch_dtype=torch.float16, low_cpu_mem_usage=True) 17 | 18 | print("Loading target model") 19 | auto_upgrade(target_model_path) 20 | target = AutoModelForCausalLM.from_pretrained(target_model_path, torch_dtype=torch.float16, low_cpu_mem_usage=True) 21 | 22 | print("Calculating delta") 23 | for name, param in tqdm(target.state_dict().items(), desc="Calculating delta"): 24 | if name not in base.state_dict(): 25 | assert name in ['model.mm_projector.weight', 'model.mm_projector.bias'], f'{name} not in base model' 26 | continue 27 | if param.data.shape == base.state_dict()[name].shape: 28 | param.data -= base.state_dict()[name] 29 | else: 30 | assert name in ['model.embed_tokens.weight', 'lm_head.weight'], f'{name} dimension mismatch: {param.data.shape} vs {base.state_dict()[name].shape}' 31 | bparam = base.state_dict()[name] 32 | param.data[:bparam.shape[0], :bparam.shape[1]] -= bparam 33 | 34 | print("Saving delta")
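# Note: every tensor saved below is a parameter-wise difference (target - base); for the
# size-mismatched embedding / lm_head matrices only the overlapping block is differenced,
# so the rows added for new tokens are stored as-is. Reconstructing the target weights
# therefore amounts to adding the base back in, roughly as in this illustrative sketch
# (not part of this file; presumably mirrored by apply_delta.py):
#   for name, param in delta.state_dict().items():
#       base_param = base.state_dict().get(name)
#       if base_param is not None and param.shape == base_param.shape:
#           param.data += base_param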
35 | if hub_repo_id: 36 | kwargs = {"push_to_hub": True, "repo_id": hub_repo_id} 37 | else: 38 | kwargs = {} 39 | target.save_pretrained(delta_path, **kwargs) 40 | target_tokenizer = AutoTokenizer.from_pretrained(target_model_path) 41 | target_tokenizer.save_pretrained(delta_path, **kwargs) 42 | 43 | 44 | if __name__ == "__main__": 45 | parser = argparse.ArgumentParser() 46 | parser.add_argument("--base-model-path", type=str, required=True) 47 | parser.add_argument("--target-model-path", type=str, required=True) 48 | parser.add_argument("--delta-path", type=str, required=True) 49 | parser.add_argument("--hub-repo-id", type=str, default=None) 50 | args = parser.parse_args() 51 | 52 | make_delta(args.base_model_path, args.target_model_path, args.delta_path, args.hub_repo_id) 53 | -------------------------------------------------------------------------------- /llava/model/multimodal_encoder/builder.py: -------------------------------------------------------------------------------- 1 | import os 2 | from .clip_encoder import CLIPVisionTower 3 | 4 | 5 | def build_vision_tower(vision_tower_cfg, **kwargs): 6 | vision_tower = getattr(vision_tower_cfg, 'mm_vision_tower', getattr(vision_tower_cfg, 'vision_tower', None)) 7 | is_absolute_path_exists = os.path.exists(vision_tower) 8 | if is_absolute_path_exists or vision_tower.startswith("openai") or vision_tower.startswith("laion") or "ShareGPT4V" in vision_tower: 9 | return CLIPVisionTower(vision_tower, args=vision_tower_cfg, **kwargs) 10 | 11 | raise ValueError(f'Unknown vision tower: {vision_tower}') 12 | -------------------------------------------------------------------------------- /llava/model/multimodal_encoder/clip_encoder.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from transformers import CLIPVisionModel, CLIPImageProcessor, CLIPVisionConfig 5 | 6 | 7 | class CLIPVisionTower(nn.Module): 8 | def __init__(self, vision_tower, args, delay_load=False): 9 | super().__init__() 10 | 11 | self.is_loaded = False 12 | 13 | self.vision_tower_name = vision_tower 14 | self.select_layer = args.mm_vision_select_layer 15 | self.select_feature = getattr(args, 'mm_vision_select_feature', 'patch') 16 | 17 | if not delay_load: 18 | self.load_model() 19 | elif getattr(args, 'unfreeze_mm_vision_tower', False): 20 | self.load_model() 21 | else: 22 | self.cfg_only = CLIPVisionConfig.from_pretrained(self.vision_tower_name, cache_dir='./cache_dir') 23 | 24 | def load_model(self, device_map=None): 25 | if self.is_loaded: 26 | print('{} is already loaded, `load_model` called again, skipping.'.format(self.vision_tower_name)) 27 | return 28 | 29 | self.image_processor = CLIPImageProcessor.from_pretrained(self.vision_tower_name, cache_dir='./cache_dir') 30 | 31 | self.vision_tower = CLIPVisionModel.from_pretrained(self.vision_tower_name, device_map=device_map, cache_dir='./cache_dir') 32 | self.vision_tower.requires_grad_(False) 33 | 34 | self.is_loaded = True 35 | 36 | def feature_select(self, image_forward_outs): 37 | image_features = image_forward_outs.hidden_states[self.select_layer] 38 | if self.select_feature == 'patch': 39 | image_features = image_features[:, 1:] 40 | elif self.select_feature == 'cls_patch': 41 | image_features = image_features 42 | else: 43 | raise ValueError(f'Unexpected select feature: {self.select_feature}') 44 | return image_features 45 | 46 | @torch.no_grad() 47 | def forward(self, images): 48 | if type(images) is list: 49 | image_features = [] 50 | 
for image in images: 51 | image_forward_out = self.vision_tower(image.to(device=self.device, dtype=self.dtype).unsqueeze(0), output_hidden_states=True) 52 | image_feature = self.feature_select(image_forward_out).to(image.dtype) 53 | image_features.append(image_feature) 54 | else: 55 | image_forward_outs = self.vision_tower(images.to(device=self.device, dtype=self.dtype), output_hidden_states=True) 56 | image_features = self.feature_select(image_forward_outs).to(images.dtype) 57 | 58 | return image_features 59 | 60 | @property 61 | def dummy_feature(self): 62 | return torch.zeros(1, self.hidden_size, device=self.device, dtype=self.dtype) 63 | 64 | @property 65 | def dtype(self): 66 | return self.vision_tower.dtype 67 | 68 | @property 69 | def device(self): 70 | return self.vision_tower.device 71 | 72 | @property 73 | def config(self): 74 | if self.is_loaded: 75 | return self.vision_tower.config 76 | else: 77 | return self.cfg_only 78 | 79 | @property 80 | def hidden_size(self): 81 | return self.config.hidden_size 82 | 83 | @property 84 | def num_patches_per_side(self): 85 | return self.config.image_size // self.config.patch_size 86 | 87 | @property 88 | def num_patches(self): 89 | return (self.config.image_size // self.config.patch_size) ** 2 90 | -------------------------------------------------------------------------------- /llava/model/multimodal_projector/builder.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import re 4 | 5 | 6 | class IdentityMap(nn.Module): 7 | def __init__(self): 8 | super().__init__() 9 | 10 | def forward(self, x, *args, **kwargs): 11 | return x 12 | 13 | @property 14 | def config(self): 15 | return {"mm_projector_type": 'identity'} 16 | 17 | 18 | class SimpleResBlock(nn.Module): 19 | def __init__(self, channels): 20 | super().__init__() 21 | self.pre_norm = nn.LayerNorm(channels) 22 | 23 | self.proj = nn.Sequential( 24 | nn.Linear(channels, channels), 25 | nn.GELU(), 26 | nn.Linear(channels, channels) 27 | ) 28 | def forward(self, x): 29 | x = self.pre_norm(x) 30 | return x + self.proj(x) 31 | 32 | 33 | def build_vision_projector(config, delay_load=False, **kwargs): 34 | projector_type = getattr(config, 'mm_projector_type', 'linear') 35 | 36 | if projector_type == 'linear': 37 | return nn.Linear(config.mm_hidden_size, config.hidden_size) 38 | 39 | mlp_gelu_match = re.match(r'^mlp(\d+)x_gelu$', projector_type) 40 | if mlp_gelu_match: 41 | mlp_depth = int(mlp_gelu_match.group(1)) 42 | modules = [nn.Linear(config.mm_hidden_size, config.hidden_size)] 43 | for _ in range(1, mlp_depth): 44 | modules.append(nn.GELU()) 45 | modules.append(nn.Linear(config.hidden_size, config.hidden_size)) 46 | return nn.Sequential(*modules) 47 | 48 | if projector_type == 'identity': 49 | return IdentityMap() 50 | 51 | raise ValueError(f'Unknown projector type: {projector_type}') 52 | -------------------------------------------------------------------------------- /llava/model/utils.py: -------------------------------------------------------------------------------- 1 | from transformers import AutoConfig 2 | 3 | 4 | def auto_upgrade(config): 5 | cfg = AutoConfig.from_pretrained(config) 6 | if 'llava' in config and 'llava' not in cfg.model_type: 7 | assert cfg.model_type == 'llama' 8 | print("You are using newer LLaVA code base, while the checkpoint of v0 is from older code base.") 9 | print("You must upgrade the checkpoint to the new code base (this can be done automatically).") 10 | confirm = 
input("Please confirm that you want to upgrade the checkpoint. [Y/N]") 11 | if confirm.lower() in ["y", "yes"]: 12 | print("Upgrading checkpoint...") 13 | assert len(cfg.architectures) == 1 14 | setattr(cfg.__class__, "model_type", "llava") 15 | cfg.architectures[0] = 'LlavaLlamaForCausalLM' 16 | cfg.save_pretrained(config) 17 | print("Checkpoint upgraded.") 18 | else: 19 | print("Checkpoint upgrade aborted.") 20 | exit(1) 21 | -------------------------------------------------------------------------------- /llava/utils.py: -------------------------------------------------------------------------------- 1 | import datetime 2 | import logging 3 | import logging.handlers 4 | import os 5 | import sys 6 | 7 | import requests 8 | 9 | from llava.constants import LOGDIR 10 | 11 | server_error_msg = "**NETWORK ERROR DUE TO HIGH TRAFFIC. PLEASE REGENERATE OR REFRESH THIS PAGE.**" 12 | moderation_msg = "YOUR INPUT VIOLATES OUR CONTENT MODERATION GUIDELINES. PLEASE TRY AGAIN." 13 | 14 | handler = None 15 | 16 | 17 | def build_logger(logger_name, logger_filename): 18 | global handler 19 | 20 | formatter = logging.Formatter( 21 | fmt="%(asctime)s | %(levelname)s | %(name)s | %(message)s", 22 | datefmt="%Y-%m-%d %H:%M:%S", 23 | ) 24 | 25 | # Set the format of root handlers 26 | if not logging.getLogger().handlers: 27 | logging.basicConfig(level=logging.INFO) 28 | logging.getLogger().handlers[0].setFormatter(formatter) 29 | 30 | # Redirect stdout and stderr to loggers 31 | stdout_logger = logging.getLogger("stdout") 32 | stdout_logger.setLevel(logging.INFO) 33 | sl = StreamToLogger(stdout_logger, logging.INFO) 34 | sys.stdout = sl 35 | 36 | stderr_logger = logging.getLogger("stderr") 37 | stderr_logger.setLevel(logging.ERROR) 38 | sl = StreamToLogger(stderr_logger, logging.ERROR) 39 | sys.stderr = sl 40 | 41 | # Get logger 42 | logger = logging.getLogger(logger_name) 43 | logger.setLevel(logging.INFO) 44 | 45 | # Add a file handler for all loggers 46 | if handler is None: 47 | os.makedirs(LOGDIR, exist_ok=True) 48 | filename = os.path.join(LOGDIR, logger_filename) 49 | handler = logging.handlers.TimedRotatingFileHandler( 50 | filename, when='D', utc=True, encoding='UTF-8') 51 | handler.setFormatter(formatter) 52 | 53 | for name, item in logging.root.manager.loggerDict.items(): 54 | if isinstance(item, logging.Logger): 55 | item.addHandler(handler) 56 | 57 | return logger 58 | 59 | 60 | class StreamToLogger(object): 61 | """ 62 | Fake file-like stream object that redirects writes to a logger instance. 63 | """ 64 | def __init__(self, logger, log_level=logging.INFO): 65 | self.terminal = sys.stdout 66 | self.logger = logger 67 | self.log_level = log_level 68 | self.linebuf = '' 69 | 70 | def __getattr__(self, attr): 71 | return getattr(self.terminal, attr) 72 | 73 | def write(self, buf): 74 | temp_linebuf = self.linebuf + buf 75 | self.linebuf = '' 76 | for line in temp_linebuf.splitlines(True): 77 | # From the io.TextIOWrapper docs: 78 | # On output, if newline is None, any '\n' characters written 79 | # are translated to the system default line separator. 80 | # By default sys.stdout.write() expects '\n' newlines and then 81 | # translates them so this is still cross platform. 
82 | if line[-1] == '\n': 83 | self.logger.log(self.log_level, line.rstrip()) 84 | else: 85 | self.linebuf += line 86 | 87 | def flush(self): 88 | if self.linebuf != '': 89 | self.logger.log(self.log_level, self.linebuf.rstrip()) 90 | self.linebuf = '' 91 | 92 | 93 | def disable_torch_init(): 94 | """ 95 | Disable the redundant torch default initialization to accelerate model creation. 96 | """ 97 | import torch 98 | setattr(torch.nn.Linear, "reset_parameters", lambda self: None) 99 | setattr(torch.nn.LayerNorm, "reset_parameters", lambda self: None) 100 | 101 | 102 | def violates_moderation(text): 103 | """ 104 | Check whether the text violates OpenAI moderation API. 105 | """ 106 | url = "https://api.openai.com/v1/moderations" 107 | headers = {"Content-Type": "application/json", 108 | "Authorization": "Bearer " + os.environ["OPENAI_API_KEY"]} 109 | text = text.replace("\n", "") 110 | data = "{" + '"input": ' + f'"{text}"' + "}" 111 | data = data.encode("utf-8") 112 | try: 113 | ret = requests.post(url, headers=headers, data=data, timeout=5) 114 | flagged = ret.json()["results"][0]["flagged"] 115 | except requests.exceptions.RequestException as e: 116 | flagged = False 117 | except KeyError as e: 118 | flagged = False 119 | 120 | return flagged 121 | 122 | 123 | def pretty_print_semaphore(semaphore): 124 | if semaphore is None: 125 | return "None" 126 | return f"Semaphore(value={semaphore._value}, locked={semaphore.locked()})" 127 | -------------------------------------------------------------------------------- /scripts/gpt_eval/eval_qa_activitynet.sh: -------------------------------------------------------------------------------- 1 | gpt_version="gpt-3.5-turbo-0301" #"gpt-3.5-turbo-0613" "gpt-3.5-turbo-0125" 2 | output_name="llava-v1.5-7b_u4FRS" 3 | pred_path="Activitynet_Zero_Shot_QA/${output_name}/merge.jsonl" 4 | output_dir="Activitynet_Zero_Shot_QA/${output_name}/${gpt_version}" 5 | output_json="Activitynet_Zero_Shot_QA/${output_name}/results_${gpt_version}.json" 6 | api_key="sk-xxx" 7 | num_tasks=25 8 | 9 | 10 | 11 | python3 scripts/gpt_eval/eval_video_qa.py \ 12 | --pred_path ${pred_path} \ 13 | --output_dir ${output_dir} \ 14 | --output_json ${output_json} \ 15 | --api_key ${api_key} \ 16 | --gpt_version ${gpt_version} \ 17 | --num_tasks ${num_tasks} -------------------------------------------------------------------------------- /scripts/gpt_eval/eval_qa_benchmark.sh: -------------------------------------------------------------------------------- 1 | gpt_version="gpt-3.5-turbo-0301" #"gpt-3.5-turbo-0613" "gpt-3.5-turbo-0125" 2 | output_name="llava-v1.5-7b_u4FRS" 3 | pred_dir="Video_Benchmark/${output_name}" 4 | output_dir="Video_Benchmark/${output_name}/${gpt_version}" 5 | api_key="sk-xxx" 6 | num_tasks=32 7 | 8 | 9 | # Run the "correctness" evaluation script 10 | python3 scripts/gpt_eval/evaluate_benchmark_1_correctness.py \ 11 | --pred_path "${pred_dir}/generic.jsonl" \ 12 | --output_dir "${output_dir}/correctness_pred" \ 13 | --output_json "${output_dir}/correctness_results.json" \ 14 | --api_key ${api_key} \ 15 | --gpt_version ${gpt_version} \ 16 | --num_tasks ${num_tasks} 17 | 18 | 19 | # Run the "detailed orientation" evaluation script 20 | python3 scripts/gpt_eval/evaluate_benchmark_2_detailed_orientation.py \ 21 | --pred_path "${pred_dir}/generic.jsonl" \ 22 | --output_dir "${output_dir}/detailed_eval" \ 23 | --output_json "${output_dir}/detailed_orientation_results.json" \ 24 | --api_key ${api_key} \ 25 | --gpt_version ${gpt_version} \ 26 | --num_tasks ${num_tasks} 27 
| 28 | 29 | # Run the "contextual understanding" evaluation script 30 | python3 scripts/gpt_eval/evaluate_benchmark_3_context.py \ 31 | --pred_path "${pred_dir}/generic.jsonl" \ 32 | --output_dir "${output_dir}/context_eval" \ 33 | --output_json "${output_dir}/contextual_understanding_results.json" \ 34 | --api_key ${api_key} \ 35 | --gpt_version ${gpt_version} \ 36 | --num_tasks ${num_tasks} 37 | 38 | 39 | # Run the "temporal understanding" evaluation script 40 | python3 scripts/gpt_eval/evaluate_benchmark_4_temporal.py \ 41 | --pred_path "${pred_dir}/temporal.jsonl" \ 42 | --output_dir "${output_dir}/temporal_eval" \ 43 | --output_json "${output_dir}/temporal_understanding_results.json" \ 44 | --api_key ${api_key} \ 45 | --gpt_version ${gpt_version} \ 46 | --num_tasks ${num_tasks} 47 | 48 | 49 | # Run the "consistency" evaluation script 50 | python3 scripts/gpt_eval/evaluate_benchmark_5_consistency.py \ 51 | --pred_path "${pred_dir}/consistency.jsonl" \ 52 | --output_dir "${output_dir}/consistency_eval" \ 53 | --output_json "${output_dir}/consistency_results.json" \ 54 | --api_key ${api_key} \ 55 | --gpt_version ${gpt_version} \ 56 | --num_tasks ${num_tasks} 57 | 58 | 59 | echo "All evaluations completed!" -------------------------------------------------------------------------------- /scripts/gpt_eval/eval_qa_msrvtt.sh: -------------------------------------------------------------------------------- 1 | gpt_version="gpt-3.5-turbo-0301" #"gpt-3.5-turbo-0613" "gpt-3.5-turbo-0125" 2 | output_name="llava-v1.5-7b_u4FRS" 3 | pred_path="MSRVTT_Zero_Shot_QA/${output_name}/merge.jsonl" 4 | output_dir="MSRVTT_Zero_Shot_QA/${output_name}/${gpt_version}" 5 | output_json="MSRVTT_Zero_Shot_QA/${output_name}/results_${gpt_version}.json" 6 | api_key="sk-xxx" 7 | num_tasks=32 8 | 9 | 10 | 11 | python3 scripts/gpt_eval/eval_video_qa.py \ 12 | --pred_path ${pred_path} \ 13 | --output_dir ${output_dir} \ 14 | --output_json ${output_json} \ 15 | --api_key ${api_key} \ 16 | --gpt_version ${gpt_version} \ 17 | --num_tasks ${num_tasks} 18 | -------------------------------------------------------------------------------- /scripts/gpt_eval/eval_qa_msvd.sh: -------------------------------------------------------------------------------- 1 | gpt_version="gpt-3.5-turbo-0301" #"gpt-3.5-turbo-0613" "gpt-3.5-turbo-0125" 2 | output_name="llava-v1.5-7b_u4FRS" 3 | pred_path="MSVD_Zero_Shot_QA/${output_name}/merge.jsonl" 4 | output_dir="MSVD_Zero_Shot_QA/${output_name}/${gpt_version}" 5 | output_json="MSVD_Zero_Shot_QA/${output_name}/results_${gpt_version}.json" 6 | api_key="sk-xxx" 7 | num_tasks=25 8 | 9 | 10 | 11 | python3 scripts/gpt_eval/eval_video_qa.py \ 12 | --pred_path ${pred_path} \ 13 | --output_dir ${output_dir} \ 14 | --output_json ${output_json} \ 15 | --api_key ${api_key} \ 16 | --gpt_version ${gpt_version} \ 17 | --num_tasks ${num_tasks} -------------------------------------------------------------------------------- /scripts/gpt_eval/eval_video_qa.py: -------------------------------------------------------------------------------- 1 | import openai 2 | import os 3 | import argparse 4 | import json 5 | import ast 6 | from multiprocessing.pool import Pool 7 | from tqdm import tqdm 8 | 9 | def parse_args(): 10 | parser = argparse.ArgumentParser(description="question-answer-generation-using-gpt-3") 11 | parser.add_argument("--pred_path", default=r'', help="The path to file containing prediction.") 12 | parser.add_argument("--output_dir", default=r'', help="The path to save annotation json files.") 13 | 
parser.add_argument("--output_json", default=r'', help="The path to save annotation final combined json file.") 14 | parser.add_argument("--api_key", default="", help="OpenAI API key.") 15 | parser.add_argument("--gpt_version", default="gpt-3.5-turbo", type=str, help="OpenAI API base.") 16 | parser.add_argument("--num_tasks", default=1, type=int, help="Number of splits.") 17 | args = parser.parse_args() 18 | return args 19 | 20 | 21 | def annotate(prediction_set, caption_files, output_dir, args): 22 | """ 23 | Evaluates question and answer pairs using GPT-3 24 | Returns a score for correctness. 25 | """ 26 | # Set the OpenAI API key. 27 | openai.api_key = args.api_key 28 | # if args.api_base is not None: 29 | # openai.api_base = args.api_base 30 | for file in caption_files: 31 | key = file[:-5] # Strip file extension 32 | qa_set = prediction_set[key] 33 | question = qa_set['q'] 34 | answer = qa_set['a'] 35 | pred = qa_set['pred'] 36 | try: 37 | # Compute the correctness score 38 | completion = openai.ChatCompletion.create( 39 | model=args.gpt_version, 40 | messages=[ 41 | { 42 | "role": "system", 43 | "content": 44 | "You are an intelligent chatbot designed for evaluating the correctness of generative outputs for question-answer pairs. " 45 | "Your task is to compare the predicted answer with the correct answer and determine if they match meaningfully. Here's how you can accomplish the task:" 46 | "------" 47 | "##INSTRUCTIONS: " 48 | "- Focus on the meaningful match between the predicted answer and the correct answer.\n" 49 | "- Consider synonyms or paraphrases as valid matches.\n" 50 | "- Evaluate the correctness of the prediction compared to the answer." 51 | }, 52 | { 53 | "role": "user", 54 | "content": 55 | "Please evaluate the following video-based question-answer pair:\n\n" 56 | f"Question: {question}\n" 57 | f"Correct Answer: {answer}\n" 58 | f"Predicted Answer: {pred}\n\n" 59 | "Provide your evaluation only as a yes/no and score where the score is an integer value between 0 and 5, with 5 indicating the highest meaningful match. " 60 | "Please generate the response in the form of a Python dictionary string with keys 'pred' and 'score', where value of 'pred' is a string of 'yes' or 'no' and value of 'score' is in INTEGER, not STRING." 61 | "DO NOT PROVIDE ANY OTHER OUTPUT TEXT OR EXPLANATION. Only provide the Python dictionary string. " 62 | "For example, your response should look like this: {'pred': 'yes', 'score': 4.8}." 63 | } 64 | ] 65 | ) 66 | # Convert response to a Python dictionary. 67 | response_message = completion["choices"][0]["message"]["content"] 68 | response_dict = ast.literal_eval(response_message) 69 | result_qa_pair = [response_dict, qa_set] 70 | 71 | # Save the question-answer pairs to a json file. 72 | with open(f"{output_dir}/{key}.json", "w") as f: 73 | json.dump(result_qa_pair, f) 74 | 75 | except Exception as e: 76 | print(f"Error processing file '{key}': {e}") 77 | 78 | 79 | def main(): 80 | """ 81 | Main function to control the flow of the program. 82 | """ 83 | # Parse arguments. 
84 | args = parse_args() 85 | 86 | file = open(args.pred_path) 87 | new_pred_contents = [eval(i.strip()) for i in file.readlines()] 88 | 89 | # pred_contents = [json.loads(line) for line in open(args.pred_path)] 90 | # # Dictionary to store the count of occurrences for each video_id 91 | # video_id_counts = {} 92 | # new_pred_contents = [] 93 | 94 | # # Iterate through each sample in pred_contents 95 | # for sample in pred_contents: 96 | # video_id = sample['id'] 97 | # if video_id in video_id_counts: 98 | # video_id_counts[video_id] += 1 99 | # else: 100 | # video_id_counts[video_id] = 0 101 | 102 | # # Create a new sample with the modified key 103 | # new_sample = sample 104 | # new_sample['id'] = f"{video_id}_{video_id_counts[video_id]}" 105 | # new_pred_contents.append(new_sample) 106 | 107 | # Generating list of id's and corresponding files 108 | id_list = [x['id'] for x in new_pred_contents] 109 | caption_files = [f"{id}.json" for id in id_list] 110 | 111 | output_dir = args.output_dir 112 | # Generate output directory if not exists. 113 | if not os.path.exists(output_dir): 114 | os.makedirs(output_dir) 115 | 116 | # Preparing dictionary of question-answer sets 117 | prediction_set = {} 118 | for sample in new_pred_contents: 119 | id = sample['id'] 120 | question = sample['question'] 121 | answer = sample['answer'] 122 | pred = sample['pred'] 123 | qa_set = {"q": question, "a": answer, "pred": pred} 124 | prediction_set[id] = qa_set 125 | 126 | num_tasks = args.num_tasks 127 | 128 | # While loop to ensure that all captions are processed. 129 | while True: 130 | try: 131 | # Files that have not been processed yet. 132 | completed_files = os.listdir(output_dir) 133 | print(f"completed_files: {len(completed_files)}") 134 | 135 | # Files that have not been processed yet. 136 | incomplete_files = [f for f in caption_files if f not in completed_files] 137 | print(f"incomplete_files: {len(incomplete_files)}") 138 | 139 | # Break the loop when there are no incomplete files 140 | if len(incomplete_files) == 0: 141 | break 142 | if len(incomplete_files) <= num_tasks: 143 | num_tasks = 1 144 | 145 | # Split tasks into parts. 146 | part_len = len(incomplete_files) // num_tasks 147 | all_parts = [incomplete_files[i:i + part_len] for i in range(0, len(incomplete_files), part_len)] 148 | task_args = [(prediction_set, part, args.output_dir, args) for part in all_parts] 149 | 150 | # Use a pool of workers to process the files in parallel. 
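# starmap blocks until every chunk has been processed; annotate() only prints
# per-sample failures, so any missing "<id>.json" outputs are simply picked up
# again on the next pass of the enclosing while-loop.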
151 | with Pool() as pool: 152 | pool.starmap(annotate, task_args) 153 | 154 | except Exception as e: 155 | print(f"Error: {e}") 156 | 157 | # Combine all the processed files into one 158 | combined_contents = {} 159 | json_path = args.output_json 160 | 161 | # Iterate through json files 162 | for file_name in os.listdir(output_dir): 163 | if file_name.endswith(".json"): 164 | file_path = os.path.join(output_dir, file_name) 165 | with open(file_path, "r") as json_file: 166 | content = json.load(json_file) 167 | combined_contents[file_name[:-5]] = content 168 | 169 | # Write combined content to a json file 170 | with open(json_path, "w") as json_file: 171 | json.dump(combined_contents, json_file) 172 | print("All evaluation completed!") 173 | 174 | # Calculate average score and accuracy 175 | score_sum = 0 176 | count = 0 177 | yes_count = 0 178 | no_count = 0 179 | for key, result in tqdm(combined_contents.items()): 180 | try: 181 | # Computing score 182 | count += 1 183 | score_match = result[0]['score'] 184 | score = int(score_match) 185 | score_sum += score 186 | 187 | # Computing accuracy 188 | pred = result[0]['pred'] 189 | if "yes" in pred.lower(): 190 | yes_count += 1 191 | elif "no" in pred.lower(): 192 | no_count += 1 193 | except: 194 | print(result) 195 | 196 | average_score = score_sum / count 197 | accuracy = yes_count / (yes_count + no_count) 198 | print("Yes count:", yes_count) 199 | print("No count:", no_count) 200 | print("Accuracy:", accuracy) 201 | print("Average score:", average_score) 202 | 203 | 204 | if __name__ == "__main__": 205 | main() 206 | 207 | -------------------------------------------------------------------------------- /scripts/gpt_eval/evaluate_benchmark_1_correctness.py: -------------------------------------------------------------------------------- 1 | import openai 2 | import os 3 | import argparse 4 | import json 5 | import ast 6 | from multiprocessing.pool import Pool 7 | from tqdm import tqdm 8 | 9 | def parse_args(): 10 | parser = argparse.ArgumentParser(description="question-answer-generation-using-gpt-3") 11 | parser.add_argument("--pred_path", default=r'', help="The path to file containing prediction.") 12 | parser.add_argument("--output_dir", default=r'', help="The path to save annotation json files.") 13 | parser.add_argument("--output_json", default=r'', help="The path to save annotation final combined json file.") 14 | parser.add_argument("--api_key", default="", help="OpenAI API key.") 15 | parser.add_argument("--gpt_version", default="gpt-3.5-turbo", type=str, help="OpenAI API base.") 16 | parser.add_argument("--num_tasks", default=1, type=int, help="Number of splits.") 17 | args = parser.parse_args() 18 | return args 19 | 20 | 21 | def annotate(prediction_set, caption_files, output_dir, args): 22 | """ 23 | Evaluates question and answer pairs using GPT-3 24 | Returns a score for correctness. 25 | """ 26 | # Set the OpenAI API key. 
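# Compatibility note: these evaluation scripts target the pre-1.0 `openai` package
# (module-level `openai.api_key` plus `openai.ChatCompletion.create`); that interface
# was removed in openai>=1.0, so an older client release is required to run them as-is.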
27 | openai.api_key = args.api_key 28 | # if args.api_base is not None: 29 | # openai.api_base = args.api_base 30 | for file in caption_files: 31 | key = file[:-5] # Strip file extension 32 | qa_set = prediction_set[key] 33 | question = qa_set['q'] 34 | answer = qa_set['a'] 35 | pred = qa_set['pred'] 36 | try: 37 | # Compute the correctness score 38 | completion = openai.ChatCompletion.create( 39 | model="gpt-3.5-turbo", 40 | messages=[ 41 | { 42 | "role": "system", 43 | "content": 44 | "You are an intelligent chatbot designed for evaluating the factual accuracy of generative outputs for video-based question-answer pairs. " 45 | "Your task is to compare the predicted answer with the correct answer and determine if they are factually consistent. Here's how you can accomplish the task:" 46 | "------" 47 | "##INSTRUCTIONS: " 48 | "- Focus on the factual consistency between the predicted answer and the correct answer. The predicted answer should not contain any misinterpretations or misinformation.\n" 49 | "- The predicted answer must be factually accurate and align with the video content.\n" 50 | "- Consider synonyms or paraphrases as valid matches.\n" 51 | "- Evaluate the factual accuracy of the prediction compared to the answer." 52 | }, 53 | { 54 | "role": "user", 55 | "content": 56 | "Please evaluate the following video-based question-answer pair:\n\n" 57 | f"Question: {question}\n" 58 | f"Correct Answer: {answer}\n" 59 | f"Predicted Answer: {pred}\n\n" 60 | "Provide your evaluation only as a factual accuracy score where the factual accuracy score is an integer value between 0 and 5, with 5 indicating the highest level of factual consistency. " 61 | "Please generate the response in the form of a Python dictionary string with keys 'score', where its value is the factual accuracy score in INTEGER, not STRING." 62 | "DO NOT PROVIDE ANY OTHER OUTPUT TEXT OR EXPLANATION. Only provide the Python dictionary string. " 63 | "For example, your response should look like this: {''score': 4.8}." 64 | } 65 | ] 66 | ) 67 | # Convert response to a Python dictionary. 68 | response_message = completion["choices"][0]["message"]["content"] 69 | response_dict = ast.literal_eval(response_message) 70 | result_qa_pair = [response_dict, qa_set] 71 | 72 | # Save the question-answer pairs to a json file. 73 | with open(f"{output_dir}/{key}.json", "w") as f: 74 | json.dump(result_qa_pair, f) 75 | 76 | except Exception as e: 77 | print(f"Error processing file '{key}': {e}") 78 | 79 | 80 | def main(): 81 | """ 82 | Main function to control the flow of the program. 83 | """ 84 | # Parse arguments. 
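# Input format note: the prediction file here (generic.jsonl when launched via
# scripts/gpt_eval/eval_qa_benchmark.sh) is JSONL with one dict per line providing at
# least the keys consumed below: 'video_name', 'Q', 'A' and 'pred'. The detailed-orientation,
# context and temporal scripts expect the same layout. Illustrative line (values are made up):
#   {"video_name": "v_abc123", "Q": "What is happening in the video?", "A": "A man is cooking.", "pred": "A man prepares food in a kitchen."}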
85 | args = parse_args() 86 | 87 | pred_contents = [json.loads(line) for line in open(args.pred_path)] 88 | 89 | # Dictionary to store the count of occurrences for each video_id 90 | video_id_counts = {} 91 | new_pred_contents = [] 92 | 93 | # Iterate through each sample in pred_contents 94 | for sample in pred_contents: 95 | video_id = sample['video_name'] 96 | if video_id in video_id_counts: 97 | video_id_counts[video_id] += 1 98 | else: 99 | video_id_counts[video_id] = 0 100 | 101 | # Create a new sample with the modified key 102 | new_sample = sample 103 | new_sample['video_name'] = f"{video_id}_{video_id_counts[video_id]}" 104 | new_pred_contents.append(new_sample) 105 | 106 | # Generating list of id's and corresponding files 107 | id_list = [x['video_name'] for x in new_pred_contents] 108 | caption_files = [f"{id}.json" for id in id_list] 109 | 110 | output_dir = args.output_dir 111 | # Generate output directory if not exists. 112 | if not os.path.exists(output_dir): 113 | os.makedirs(output_dir) 114 | 115 | # Preparing dictionary of question-answer sets 116 | prediction_set = {} 117 | for sample in new_pred_contents: 118 | id = sample['video_name'] 119 | question = sample['Q'] 120 | answer = sample['A'] 121 | pred = sample['pred'] 122 | qa_set = {"q": question, "a": answer, "pred": pred} 123 | prediction_set[id] = qa_set 124 | 125 | num_tasks = args.num_tasks 126 | 127 | # While loop to ensure that all captions are processed. 128 | while True: 129 | try: 130 | # Files that have not been processed yet. 131 | completed_files = os.listdir(output_dir) 132 | print(f"completed_files: {len(completed_files)}") 133 | 134 | # Files that have not been processed yet. 135 | incomplete_files = [f for f in caption_files if f not in completed_files] 136 | print(f"incomplete_files: {len(incomplete_files)}") 137 | 138 | # Break the loop when there are no incomplete files 139 | if len(incomplete_files) == 0: 140 | break 141 | if len(incomplete_files) <= num_tasks: 142 | num_tasks = 1 143 | 144 | # Split tasks into parts. 145 | part_len = len(incomplete_files) // num_tasks 146 | all_parts = [incomplete_files[i:i + part_len] for i in range(0, len(incomplete_files), part_len)] 147 | task_args = [(prediction_set, part, args.output_dir, args) for part in all_parts] 148 | 149 | # Use a pool of workers to process the files in parallel. 
150 | with Pool() as pool: 151 | pool.starmap(annotate, task_args) 152 | 153 | except Exception as e: 154 | print(f"Error: {e}") 155 | 156 | # Combine all the processed files into one 157 | combined_contents = {} 158 | json_path = args.output_json 159 | 160 | # Iterate through json files 161 | for file_name in os.listdir(output_dir): 162 | if file_name.endswith(".json"): 163 | file_path = os.path.join(output_dir, file_name) 164 | with open(file_path, "r") as json_file: 165 | content = json.load(json_file) 166 | combined_contents[file_name[:-5]] = content 167 | 168 | # Write combined content to a json file 169 | with open(json_path, "w") as json_file: 170 | json.dump(combined_contents, json_file) 171 | print("All evaluation completed!") 172 | 173 | # Calculate average score 174 | score_sum = 0 175 | count = 0 176 | for key, result in combined_contents.items(): 177 | count += 1 178 | score_match = result[0]['score'] 179 | score = int(score_match) 180 | score_sum += score 181 | average_score = score_sum / count 182 | 183 | print("Average score for correctness:", average_score) 184 | 185 | 186 | if __name__ == "__main__": 187 | main() 188 | 189 | -------------------------------------------------------------------------------- /scripts/gpt_eval/evaluate_benchmark_2_detailed_orientation.py: -------------------------------------------------------------------------------- 1 | import openai 2 | import os 3 | import argparse 4 | import json 5 | import ast 6 | from multiprocessing.pool import Pool 7 | from tqdm import tqdm 8 | 9 | def parse_args(): 10 | parser = argparse.ArgumentParser(description="question-answer-generation-using-gpt-3") 11 | parser.add_argument("--pred_path", default=r'', help="The path to file containing prediction.") 12 | parser.add_argument("--output_dir", default=r'', help="The path to save annotation json files.") 13 | parser.add_argument("--output_json", default=r'', help="The path to save annotation final combined json file.") 14 | parser.add_argument("--api_key", default="", help="OpenAI API key.") 15 | parser.add_argument("--gpt_version", default="gpt-3.5-turbo", type=str, help="OpenAI API base.") 16 | parser.add_argument("--num_tasks", default=1, type=int, help="Number of splits.") 17 | args = parser.parse_args() 18 | return args 19 | 20 | 21 | def annotate(prediction_set, caption_files, output_dir, args): 22 | """ 23 | Evaluates question and answer pairs using GPT-3 and 24 | returns a score for detailed orientation. 25 | """ 26 | # Set the OpenAI API key. 27 | openai.api_key = args.api_key 28 | # if args.api_base is not None: 29 | # openai.api_base = args.api_base 30 | for file in caption_files: 31 | key = file[:-5] # Strip file extension 32 | qa_set = prediction_set[key] 33 | question = qa_set['q'] 34 | answer = qa_set['a'] 35 | pred = qa_set['pred'] 36 | try: 37 | # Compute the detailed-orientation score 38 | completion = openai.ChatCompletion.create( 39 | model="gpt-3.5-turbo", 40 | messages=[ 41 | { 42 | "role": "system", 43 | "content": 44 | "You are an intelligent chatbot designed for evaluating the detail orientation of generative outputs for video-based question-answer pairs. " 45 | "Your task is to compare the predicted answer with the correct answer and determine its level of detail, considering both completeness and specificity. Here's how you can accomplish the task:" 46 | "------" 47 | "##INSTRUCTIONS: " 48 | "- Check if the predicted answer covers all major points from the video. 
The response should not leave out any key aspects.\n" 49 | "- Evaluate whether the predicted answer includes specific details rather than just generic points. It should provide comprehensive information that is tied to specific elements of the video.\n" 50 | "- Consider synonyms or paraphrases as valid matches.\n" 51 | "- Provide a single evaluation score that reflects the level of detail orientation of the prediction, considering both completeness and specificity." 52 | }, 53 | { 54 | "role": "user", 55 | "content": 56 | "Please evaluate the following video-based question-answer pair:\n\n" 57 | f"Question: {question}\n" 58 | f"Correct Answer: {answer}\n" 59 | f"Predicted Answer: {pred}\n\n" 60 | "Provide your evaluation only as a detail orientation score where the detail orientation score is an integer value between 0 and 5, with 5 indicating the highest level of detail orientation. " 61 | "Please generate the response in the form of a Python dictionary string with keys 'score', where its value is the detail orientation score in INTEGER, not STRING." 62 | "DO NOT PROVIDE ANY OTHER OUTPUT TEXT OR EXPLANATION. Only provide the Python dictionary string. " 63 | "For example, your response should look like this: {''score': 4.8}." 64 | } 65 | ] 66 | ) 67 | # Convert response to a Python dictionary. 68 | response_message = completion["choices"][0]["message"]["content"] 69 | response_dict = ast.literal_eval(response_message) 70 | result_qa_pair = [response_dict, qa_set] 71 | 72 | # Save the question-answer pairs to a json file. 73 | with open(f"{output_dir}/{key}.json", "w") as f: 74 | json.dump(result_qa_pair, f) 75 | 76 | except Exception as e: 77 | print(f"Error processing file '{key}': {e}") 78 | 79 | 80 | def main(): 81 | """ 82 | Main function to control the flow of the program. 83 | """ 84 | # Parse arguments. 85 | args = parse_args() 86 | 87 | pred_contents = [json.loads(line) for line in open(args.pred_path)] 88 | 89 | # Dictionary to store the count of occurrences for each video_id 90 | video_id_counts = {} 91 | new_pred_contents = [] 92 | 93 | # Iterate through each sample in pred_contents 94 | for sample in pred_contents: 95 | video_id = sample['video_name'] 96 | if video_id in video_id_counts: 97 | video_id_counts[video_id] += 1 98 | else: 99 | video_id_counts[video_id] = 0 100 | 101 | # Create a new sample with the modified key 102 | new_sample = sample 103 | new_sample['video_name'] = f"{video_id}_{video_id_counts[video_id]}" 104 | new_pred_contents.append(new_sample) 105 | 106 | # Generating list of id's and corresponding files 107 | id_list = [x['video_name'] for x in new_pred_contents] 108 | caption_files = [f"{id}.json" for id in id_list] 109 | 110 | output_dir = args.output_dir 111 | # Generate output directory if not exists. 112 | if not os.path.exists(output_dir): 113 | os.makedirs(output_dir) 114 | 115 | # Preparing dictionary of question-answer sets 116 | prediction_set = {} 117 | for sample in new_pred_contents: 118 | id = sample['video_name'] 119 | question = sample['Q'] 120 | answer = sample['A'] 121 | pred = sample['pred'] 122 | qa_set = {"q": question, "a": answer, "pred": pred} 123 | prediction_set[id] = qa_set 124 | 125 | num_tasks = args.num_tasks 126 | 127 | # While loop to ensure that all captions are processed. 128 | while True: 129 | try: 130 | # Files that have not been processed yet. 131 | completed_files = os.listdir(output_dir) 132 | print(f"completed_files: {len(completed_files)}") 133 | 134 | # Files that have not been processed yet. 
135 | incomplete_files = [f for f in caption_files if f not in completed_files] 136 | print(f"incomplete_files: {len(incomplete_files)}") 137 | 138 | # Break the loop when there are no incomplete files 139 | if len(incomplete_files) == 0: 140 | break 141 | if len(incomplete_files) <= num_tasks: 142 | num_tasks = 1 143 | 144 | # Split tasks into parts. 145 | part_len = len(incomplete_files) // num_tasks 146 | all_parts = [incomplete_files[i:i + part_len] for i in range(0, len(incomplete_files), part_len)] 147 | task_args = [(prediction_set, part, args.output_dir, args) for part in all_parts] 148 | 149 | # Use a pool of workers to process the files in parallel. 150 | with Pool() as pool: 151 | pool.starmap(annotate, task_args) 152 | 153 | except Exception as e: 154 | print(f"Error: {e}") 155 | 156 | # Combine all the processed files into one 157 | combined_contents = {} 158 | json_path = args.output_json 159 | 160 | # Iterate through json files 161 | for file_name in os.listdir(output_dir): 162 | if file_name.endswith(".json"): 163 | file_path = os.path.join(output_dir, file_name) 164 | with open(file_path, "r") as json_file: 165 | content = json.load(json_file) 166 | combined_contents[file_name[:-5]] = content 167 | 168 | # Write combined content to a json file 169 | with open(json_path, "w") as json_file: 170 | json.dump(combined_contents, json_file) 171 | print("All evaluation completed!") 172 | 173 | # Calculate average score 174 | score_sum = 0 175 | count = 0 176 | for key, result in combined_contents.items(): 177 | count += 1 178 | score_match = result[0]['score'] 179 | score = int(score_match) 180 | score_sum += score 181 | average_score = score_sum / count 182 | 183 | print("Average score for detailed orientation:", average_score) 184 | 185 | 186 | if __name__ == "__main__": 187 | main() 188 | 189 | -------------------------------------------------------------------------------- /scripts/gpt_eval/evaluate_benchmark_3_context.py: -------------------------------------------------------------------------------- 1 | import openai 2 | import os 3 | import argparse 4 | import json 5 | import ast 6 | from multiprocessing.pool import Pool 7 | from tqdm import tqdm 8 | 9 | def parse_args(): 10 | parser = argparse.ArgumentParser(description="question-answer-generation-using-gpt-3") 11 | parser.add_argument("--pred_path", default=r'', help="The path to file containing prediction.") 12 | parser.add_argument("--output_dir", default=r'', help="The path to save annotation json files.") 13 | parser.add_argument("--output_json", default=r'', help="The path to save annotation final combined json file.") 14 | parser.add_argument("--api_key", default="", help="OpenAI API key.") 15 | parser.add_argument("--gpt_version", default="gpt-3.5-turbo", type=str, help="OpenAI API base.") 16 | parser.add_argument("--num_tasks", default=1, type=int, help="Number of splits.") 17 | args = parser.parse_args() 18 | return args 19 | 20 | 21 | def annotate(prediction_set, caption_files, output_dir, args): 22 | """ 23 | Evaluates question and answer pairs using GPT-3 and 24 | returns a score for contextual understanding. 25 | """ 26 | # Set the OpenAI API key. 
27 | openai.api_key = args.api_key 28 | # if args.api_base is not None: 29 | # openai.api_base = args.api_base 30 | for file in caption_files: 31 | key = file[:-5] # Strip file extension 32 | qa_set = prediction_set[key] 33 | question = qa_set['q'] 34 | answer = qa_set['a'] 35 | pred = qa_set['pred'] 36 | try: 37 | # Compute the contextual understanding score 38 | completion = openai.ChatCompletion.create( 39 | model="gpt-3.5-turbo", 40 | messages=[ 41 | { 42 | "role": "system", 43 | "content": 44 | "You are an intelligent chatbot designed for evaluating the contextual understanding of generative outputs for video-based question-answer pairs. " 45 | "Your task is to compare the predicted answer with the correct answer and determine if the generated response aligns with the overall context of the video content. Here's how you can accomplish the task:" 46 | "------" 47 | "##INSTRUCTIONS: " 48 | "- Evaluate whether the predicted answer aligns with the overall context of the video content. It should not provide information that is out of context or misaligned.\n" 49 | "- The predicted answer must capture the main themes and sentiments of the video.\n" 50 | "- Consider synonyms or paraphrases as valid matches.\n" 51 | "- Provide your evaluation of the contextual understanding of the prediction compared to the answer." 52 | }, 53 | { 54 | "role": "user", 55 | "content": 56 | "Please evaluate the following video-based question-answer pair:\n\n" 57 | f"Question: {question}\n" 58 | f"Correct Answer: {answer}\n" 59 | f"Predicted Answer: {pred}\n\n" 60 | "Provide your evaluation only as a contextual understanding score where the contextual understanding score is an integer value between 0 and 5, with 5 indicating the highest level of contextual understanding. " 61 | "Please generate the response in the form of a Python dictionary string with keys 'score', where its value is contextual understanding score in INTEGER, not STRING." 62 | "DO NOT PROVIDE ANY OTHER OUTPUT TEXT OR EXPLANATION. Only provide the Python dictionary string. " 63 | "For example, your response should look like this: {''score': 4.8}." 64 | } 65 | ] 66 | ) 67 | # Convert response to a Python dictionary. 68 | response_message = completion["choices"][0]["message"]["content"] 69 | response_dict = ast.literal_eval(response_message) 70 | result_qa_pair = [response_dict, qa_set] 71 | 72 | # Save the question-answer pairs to a json file. 73 | with open(f"{output_dir}/{key}.json", "w") as f: 74 | json.dump(result_qa_pair, f) 75 | 76 | except Exception as e: 77 | print(f"Error processing file '{key}': {e}") 78 | 79 | 80 | def main(): 81 | """ 82 | Main function to control the flow of the program. 83 | """ 84 | # Parse arguments. 
85 | args = parse_args() 86 | 87 | pred_contents = [json.loads(line) for line in open(args.pred_path)] 88 | 89 | # Dictionary to store the count of occurrences for each video_id 90 | video_id_counts = {} 91 | new_pred_contents = [] 92 | 93 | # Iterate through each sample in pred_contents 94 | for sample in pred_contents: 95 | video_id = sample['video_name'] 96 | if video_id in video_id_counts: 97 | video_id_counts[video_id] += 1 98 | else: 99 | video_id_counts[video_id] = 0 100 | 101 | # Create a new sample with the modified key 102 | new_sample = sample 103 | new_sample['video_name'] = f"{video_id}_{video_id_counts[video_id]}" 104 | new_pred_contents.append(new_sample) 105 | 106 | # Generating list of id's and corresponding files 107 | id_list = [x['video_name'] for x in new_pred_contents] 108 | caption_files = [f"{id}.json" for id in id_list] 109 | 110 | output_dir = args.output_dir 111 | # Generate output directory if not exists. 112 | if not os.path.exists(output_dir): 113 | os.makedirs(output_dir) 114 | 115 | # Preparing dictionary of question-answer sets 116 | prediction_set = {} 117 | for sample in new_pred_contents: 118 | id = sample['video_name'] 119 | question = sample['Q'] 120 | answer = sample['A'] 121 | pred = sample['pred'] 122 | qa_set = {"q": question, "a": answer, "pred": pred} 123 | prediction_set[id] = qa_set 124 | 125 | num_tasks = args.num_tasks 126 | 127 | # While loop to ensure that all captions are processed. 128 | while True: 129 | try: 130 | # Files that have not been processed yet. 131 | completed_files = os.listdir(output_dir) 132 | print(f"completed_files: {len(completed_files)}") 133 | 134 | # Files that have not been processed yet. 135 | incomplete_files = [f for f in caption_files if f not in completed_files] 136 | print(f"incomplete_files: {len(incomplete_files)}") 137 | 138 | # Break the loop when there are no incomplete files 139 | if len(incomplete_files) == 0: 140 | break 141 | if len(incomplete_files) <= num_tasks: 142 | num_tasks = 1 143 | 144 | # Split tasks into parts. 145 | part_len = len(incomplete_files) // num_tasks 146 | all_parts = [incomplete_files[i:i + part_len] for i in range(0, len(incomplete_files), part_len)] 147 | task_args = [(prediction_set, part, args.output_dir, args) for part in all_parts] 148 | 149 | # Use a pool of workers to process the files in parallel. 
150 | with Pool() as pool: 151 | pool.starmap(annotate, task_args) 152 | 153 | except Exception as e: 154 | print(f"Error: {e}") 155 | 156 | # Combine all the processed files into one 157 | combined_contents = {} 158 | json_path = args.output_json 159 | 160 | # Iterate through json files 161 | for file_name in os.listdir(output_dir): 162 | if file_name.endswith(".json"): 163 | file_path = os.path.join(output_dir, file_name) 164 | with open(file_path, "r") as json_file: 165 | content = json.load(json_file) 166 | combined_contents[file_name[:-5]] = content 167 | 168 | # Write combined content to a json file 169 | with open(json_path, "w") as json_file: 170 | json.dump(combined_contents, json_file) 171 | print("All evaluation completed!") 172 | 173 | # Calculate average score 174 | score_sum = 0 175 | count = 0 176 | for key, result in combined_contents.items(): 177 | count += 1 178 | score_match = result[0]['score'] 179 | score = int(score_match) 180 | score_sum += score 181 | average_score = score_sum / count 182 | 183 | print("Average score for contextual understanding:", average_score) 184 | 185 | 186 | if __name__ == "__main__": 187 | main() 188 | 189 | -------------------------------------------------------------------------------- /scripts/gpt_eval/evaluate_benchmark_4_temporal.py: -------------------------------------------------------------------------------- 1 | import openai 2 | import os 3 | import argparse 4 | import json 5 | import ast 6 | from multiprocessing.pool import Pool 7 | from tqdm import tqdm 8 | 9 | def parse_args(): 10 | parser = argparse.ArgumentParser(description="question-answer-generation-using-gpt-3") 11 | parser.add_argument("--pred_path", default=r'', help="The path to file containing prediction.") 12 | parser.add_argument("--output_dir", default=r'', help="The path to save annotation json files.") 13 | parser.add_argument("--output_json", default=r'', help="The path to save annotation final combined json file.") 14 | parser.add_argument("--api_key", default="", help="OpenAI API key.") 15 | parser.add_argument("--gpt_version", default="gpt-3.5-turbo", type=str, help="OpenAI API base.") 16 | parser.add_argument("--num_tasks", default=1, type=int, help="Number of splits.") 17 | args = parser.parse_args() 18 | return args 19 | 20 | 21 | def annotate(prediction_set, caption_files, output_dir, args): 22 | """ 23 | Evaluates question and answer pairs using GPT-3 and 24 | returns a score for temporal understanding. 25 | """ 26 | # Set the OpenAI API key. 27 | openai.api_key = args.api_key 28 | # if args.api_base is not None: 29 | # openai.api_base = args.api_base 30 | for file in caption_files: 31 | key = file[:-5] # Strip file extension 32 | qa_set = prediction_set[key] 33 | question = qa_set['q'] 34 | answer = qa_set['a'] 35 | pred = qa_set['pred'] 36 | try: 37 | # Compute the temporal understanding score 38 | completion = openai.ChatCompletion.create( 39 | model="gpt-3.5-turbo", 40 | messages=[ 41 | { 42 | "role": "system", 43 | "content": 44 | "You are an intelligent chatbot designed for evaluating the temporal understanding of generative outputs for video-based question-answer pairs. " 45 | "Your task is to compare the predicted answer with the correct answer and determine if they correctly reflect the temporal sequence of events in the video content. Here's how you can accomplish the task:" 46 | "------" 47 | "##INSTRUCTIONS: " 48 | "- Focus on the temporal consistency between the predicted answer and the correct answer. 
The predicted answer should correctly reflect the sequence of events or details as they are presented in the video content.\n" 49 | "- Consider synonyms or paraphrases as valid matches, but only if the temporal order is maintained.\n" 50 | "- Evaluate the temporal accuracy of the prediction compared to the answer." 51 | }, 52 | { 53 | "role": "user", 54 | "content": 55 | "Please evaluate the following video-based question-answer pair:\n\n" 56 | f"Question: {question}\n" 57 | f"Correct Answer: {answer}\n" 58 | f"Predicted Answer: {pred}\n\n" 59 | "Provide your evaluation only as a temporal accuracy score where the temporal accuracy score is an integer value between 0 and 5, with 5 indicating the highest level of temporal consistency. " 60 | "Please generate the response in the form of a Python dictionary string with keys 'score', where its value is the temporal accuracy score in INTEGER, not STRING." 61 | "DO NOT PROVIDE ANY OTHER OUTPUT TEXT OR EXPLANATION. Only provide the Python dictionary string. " 62 | "For example, your response should look like this: {''score': 4.8}." 63 | } 64 | ] 65 | ) 66 | # Convert response to a Python dictionary. 67 | response_message = completion["choices"][0]["message"]["content"] 68 | response_dict = ast.literal_eval(response_message) 69 | result_qa_pair = [response_dict, qa_set] 70 | 71 | # Save the question-answer pairs to a json file. 72 | with open(f"{output_dir}/{key}.json", "w") as f: 73 | json.dump(result_qa_pair, f) 74 | 75 | except Exception as e: 76 | print(f"Error processing file '{key}': {e}") 77 | 78 | 79 | def main(): 80 | """ 81 | Main function to control the flow of the program. 82 | """ 83 | # Parse arguments. 84 | args = parse_args() 85 | 86 | pred_contents = [json.loads(line) for line in open(args.pred_path)] 87 | 88 | # Dictionary to store the count of occurrences for each video_id 89 | video_id_counts = {} 90 | new_pred_contents = [] 91 | 92 | # Iterate through each sample in pred_contents 93 | for sample in pred_contents: 94 | video_id = sample['video_name'] 95 | if video_id in video_id_counts: 96 | video_id_counts[video_id] += 1 97 | else: 98 | video_id_counts[video_id] = 0 99 | 100 | # Create a new sample with the modified key 101 | new_sample = sample 102 | new_sample['video_name'] = f"{video_id}_{video_id_counts[video_id]}" 103 | new_pred_contents.append(new_sample) 104 | 105 | # Generating list of id's and corresponding files 106 | id_list = [x['video_name'] for x in new_pred_contents] 107 | caption_files = [f"{id}.json" for id in id_list] 108 | 109 | output_dir = args.output_dir 110 | # Generate output directory if not exists. 111 | if not os.path.exists(output_dir): 112 | os.makedirs(output_dir) 113 | 114 | # Preparing dictionary of question-answer sets 115 | prediction_set = {} 116 | for sample in new_pred_contents: 117 | id = sample['video_name'] 118 | question = sample['Q'] 119 | answer = sample['A'] 120 | pred = sample['pred'] 121 | qa_set = {"q": question, "a": answer, "pred": pred} 122 | prediction_set[id] = qa_set 123 | 124 | num_tasks = args.num_tasks 125 | 126 | # While loop to ensure that all captions are processed. 127 | while True: 128 | try: 129 | # Files that have not been processed yet. 130 | completed_files = os.listdir(output_dir) 131 | print(f"completed_files: {len(completed_files)}") 132 | 133 | # Files that have not been processed yet. 
134 | incomplete_files = [f for f in caption_files if f not in completed_files] 135 | print(f"incomplete_files: {len(incomplete_files)}") 136 | 137 | # Break the loop when there are no incomplete files 138 | if len(incomplete_files) == 0: 139 | break 140 | if len(incomplete_files) <= num_tasks: 141 | num_tasks = 1 142 | 143 | # Split tasks into parts. 144 | part_len = len(incomplete_files) // num_tasks 145 | all_parts = [incomplete_files[i:i + part_len] for i in range(0, len(incomplete_files), part_len)] 146 | task_args = [(prediction_set, part, args.output_dir, args) for part in all_parts] 147 | 148 | # Use a pool of workers to process the files in parallel. 149 | with Pool() as pool: 150 | pool.starmap(annotate, task_args) 151 | 152 | except Exception as e: 153 | print(f"Error: {e}") 154 | 155 | # Combine all the processed files into one 156 | combined_contents = {} 157 | json_path = args.output_json 158 | 159 | # Iterate through json files 160 | for file_name in os.listdir(output_dir): 161 | if file_name.endswith(".json"): 162 | file_path = os.path.join(output_dir, file_name) 163 | with open(file_path, "r") as json_file: 164 | content = json.load(json_file) 165 | combined_contents[file_name[:-5]] = content 166 | 167 | # Write combined content to a json file 168 | with open(json_path, "w") as json_file: 169 | json.dump(combined_contents, json_file) 170 | print("All evaluation completed!") 171 | 172 | # Calculate average score 173 | score_sum = 0 174 | count = 0 175 | for key, result in combined_contents.items(): 176 | count += 1 177 | score_match = result[0]['score'] 178 | score = int(score_match) 179 | score_sum += score 180 | average_score = score_sum / count 181 | 182 | print("Average score temporal understanding:", average_score) 183 | 184 | 185 | if __name__ == "__main__": 186 | main() 187 | 188 | -------------------------------------------------------------------------------- /scripts/gpt_eval/evaluate_benchmark_5_consistency.py: -------------------------------------------------------------------------------- 1 | import openai 2 | import os 3 | import argparse 4 | import json 5 | import ast 6 | from multiprocessing.pool import Pool 7 | from tqdm import tqdm 8 | 9 | def parse_args(): 10 | parser = argparse.ArgumentParser(description="question-answer-generation-using-gpt-3") 11 | parser.add_argument("--pred_path", default=r'', help="The path to file containing prediction.") 12 | parser.add_argument("--output_dir", default=r'', help="The path to save annotation json files.") 13 | parser.add_argument("--output_json", default=r'', help="The path to save annotation final combined json file.") 14 | parser.add_argument("--api_key", default="", help="OpenAI API key.") 15 | parser.add_argument("--gpt_version", default="gpt-3.5-turbo", type=str, help="OpenAI API base.") 16 | parser.add_argument("--num_tasks", default=1, type=int, help="Number of splits.") 17 | args = parser.parse_args() 18 | return args 19 | 20 | 21 | def annotate(prediction_set, caption_files, output_dir, args): 22 | """ 23 | Evaluates question and answer pairs using GPT-3 and 24 | returns a score for consistency. 25 | """ 26 | # Set the OpenAI API key. 
27 | openai.api_key = args.api_key
28 | # if args.api_base is not None:
29 | # openai.api_base = args.api_base
30 | for file in caption_files:
31 | key = file[:-5] # Strip file extension
32 | qa_set = prediction_set[key]
33 | question1 = qa_set['q1']
34 | question2 = qa_set['q2']
35 | answer = qa_set['a']
36 | pred1 = qa_set['pred1']
37 | pred2 = qa_set['pred2']
38 | try:
39 | # Compute the consistency score
40 | completion = openai.ChatCompletion.create(
41 | model=args.gpt_version,
42 | messages=[
43 | {
44 | "role": "system",
45 | "content":
46 | "You are an intelligent chatbot designed for evaluating the consistency of generative outputs for similar video-based question-answer pairs. "
47 | "You will be given two very similar questions, a correct answer that is common to both questions, and the predicted answers for the two questions. "
48 | "Your task is to compare the predicted answers for the two very similar questions, given the common correct answer, and determine if they are consistent. Here's how you can accomplish the task:"
49 | "------"
50 | "##INSTRUCTIONS: "
51 | "- Focus on the consistency between the two predicted answers and the correct answer. Both predicted answers should correspond to the correct answer and to each other, and should not contain any contradictions or significant differences in the conveyed information.\n"
52 | "- Both predicted answers must be consistent with each other and the correct answer, in terms of the information they provide about the video content.\n"
53 | "- Consider synonyms or paraphrases as valid matches, but only if they maintain the consistency in the conveyed information.\n"
54 | "- Evaluate the consistency of the two predicted answers compared to the correct answer."
55 | },
56 | {
57 | "role": "user",
58 | "content":
59 | "Please evaluate the following video-based question-answer pair:\n\n"
60 | f"Question 1: {question1}\n"
61 | f"Question 2: {question2}\n"
62 | f"Correct Answer: {answer}\n"
63 | f"Predicted Answer to Question 1: {pred1}\n"
64 | f"Predicted Answer to Question 2: {pred2}\n\n"
65 | "Provide your evaluation only as a consistency score where the consistency score is an integer value between 0 and 5, with 5 indicating the highest level of consistency. "
66 | "Please generate the response in the form of a Python dictionary string with the key 'score', whose value is the consistency score as an INTEGER, not a STRING. "
67 | "DO NOT PROVIDE ANY OTHER OUTPUT TEXT OR EXPLANATION. Only provide the Python dictionary string. "
68 | "For example, your response should look like this: {'score': 4}."
69 | }
70 | ]
71 | )
72 | # Convert response to a Python dictionary.
73 | response_message = completion["choices"][0]["message"]["content"]
74 | response_dict = ast.literal_eval(response_message)
75 | result_qa_pair = [response_dict, qa_set]
76 |
77 | # Save the question-answer pairs to a json file.
78 | with open(f"{output_dir}/{key}.json", "w") as f:
79 | json.dump(result_qa_pair, f)
80 |
81 | except Exception as e:
82 | print(f"Error processing file '{key}': {e}")
83 |
84 |
85 | def main():
86 | """
87 | Main function to control the flow of the program.
88 | """
89 | # Parse arguments.
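# --pred_path is expected to point at a JSONL file in which every line is a
# JSON object with the keys 'video_name', 'Q1', 'Q2', 'A', 'pred1' and 'pred2';
# this should be the merged consistency.jsonl produced by
# scripts/infer_video/run_benchmark_consistency_qa.sh.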
90 | args = parse_args()
91 |
92 | pred_contents = [json.loads(line) for line in open(args.pred_path)]
93 |
94 | # Dictionary to store the count of occurrences for each video_id
95 | video_id_counts = {}
96 | new_pred_contents = []
97 |
98 | # Iterate through each sample in pred_contents
99 | for sample in pred_contents:
100 | video_id = sample['video_name']
101 | if video_id in video_id_counts:
102 | video_id_counts[video_id] += 1
103 | else:
104 | video_id_counts[video_id] = 0
105 |
106 | # Create a new sample with the modified key
107 | new_sample = sample
108 | new_sample['video_name'] = f"{video_id}_{video_id_counts[video_id]}"
109 | new_pred_contents.append(new_sample)
110 |
111 | # Generating list of ids and corresponding files
112 | id_list = [x['video_name'] for x in new_pred_contents]
113 | caption_files = [f"{id}.json" for id in id_list]
114 |
115 | output_dir = args.output_dir
116 | # Create the output directory if it does not exist.
117 | if not os.path.exists(output_dir):
118 | os.makedirs(output_dir)
119 |
120 | # Preparing dictionary of question-answer sets
121 | prediction_set = {}
122 | for sample in new_pred_contents:
123 | id = sample['video_name']
124 | question1 = sample['Q1']
125 | question2 = sample['Q2']
126 | answer = sample['A']
127 | pred1 = sample['pred1']
128 | pred2 = sample['pred2']
129 | qa_set = {"q1": question1, "q2": question2, "a": answer, "pred1": pred1, "pred2": pred2}
130 | prediction_set[id] = qa_set
131 |
132 | num_tasks = args.num_tasks
133 |
134 | # While loop to ensure that all captions are processed.
135 | while True:
136 | try:
137 | # Files that have already been processed.
138 | completed_files = os.listdir(output_dir)
139 | print(f"completed_files: {len(completed_files)}")
140 |
141 | # Files that have not been processed yet.
142 | incomplete_files = [f for f in caption_files if f not in completed_files]
143 | print(f"incomplete_files: {len(incomplete_files)}")
144 |
145 | # Break the loop when there are no incomplete files
146 | if len(incomplete_files) == 0:
147 | break
148 | if len(incomplete_files) <= num_tasks:
149 | num_tasks = 1
150 |
151 | # Split tasks into parts.
152 | part_len = len(incomplete_files) // num_tasks
153 | all_parts = [incomplete_files[i:i + part_len] for i in range(0, len(incomplete_files), part_len)]
154 | task_args = [(prediction_set, part, args.output_dir, args) for part in all_parts]
155 |
156 | # Use a pool of workers to process the files in parallel.
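# Pool() with no arguments starts os.cpu_count() worker processes; starmap runs
# annotate() once per slice of incomplete files, passing along the shared
# prediction_set and the parsed arguments.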
157 | with Pool() as pool: 158 | pool.starmap(annotate, task_args) 159 | 160 | except Exception as e: 161 | print(f"Error: {e}") 162 | 163 | # Combine all the processed files into one 164 | combined_contents = {} 165 | json_path = args.output_json 166 | 167 | # Iterate through json files 168 | for file_name in os.listdir(output_dir): 169 | if file_name.endswith(".json"): 170 | file_path = os.path.join(output_dir, file_name) 171 | with open(file_path, "r") as json_file: 172 | content = json.load(json_file) 173 | combined_contents[file_name[:-5]] = content 174 | 175 | # Write combined content to a json file 176 | with open(json_path, "w") as json_file: 177 | json.dump(combined_contents, json_file) 178 | print("All evaluation completed!") 179 | 180 | # Calculate average score 181 | score_sum = 0 182 | count = 0 183 | for key, result in combined_contents.items(): 184 | count += 1 185 | score_match = result[0]['score'] 186 | score = int(score_match) 187 | score_sum += score 188 | average_score = score_sum / count 189 | 190 | print("Average score for consistency:", average_score) 191 | 192 | 193 | if __name__ == "__main__": 194 | main() 195 | 196 | -------------------------------------------------------------------------------- /scripts/infer_video/run_benchmark_consistency_qa.sh: -------------------------------------------------------------------------------- 1 | CKPT_NAME="llava-v1.5-7b" 2 | num_frames=4 3 | model_path="ckpt/llava-v1.5-7b" 4 | GPT_Zero_Shot_QA="/bpfs/v2_mnt/VIS/wuwenhao/mllm_data/GPT_Zero_Shot_QA" 5 | video_dir="${GPT_Zero_Shot_QA}/VideoChatGPT_Test_Videos" 6 | gt_file="${GPT_Zero_Shot_QA}/consistency_qa.json" 7 | output_dir="output/Benchmark_Consistency_QA/${CKPT_NAME}_u${num_frames}FRS" 8 | 9 | gpu_list="${CUDA_VISIBLE_DEVICES:-0}" 10 | IFS=',' read -ra GPULIST <<< "$gpu_list" 11 | 12 | CHUNKS=${#GPULIST[@]} 13 | 14 | 15 | for IDX in $(seq 0 $((CHUNKS-1))); do 16 | CUDA_VISIBLE_DEVICES=${GPULIST[$IDX]} python3 llava/eval/run_inference_benchmark_consistency.py \ 17 | --video_dir ${video_dir} \ 18 | --gt_file ${gt_file} \ 19 | --output_dir ${output_dir} \ 20 | --output_name ${CHUNKS}_${IDX} \ 21 | --model_name ${model_path} \ 22 | --num_chunks $CHUNKS \ 23 | --num_frames $num_frames \ 24 | --conv-mode vicuna_v1 \ 25 | --temperature 0 \ 26 | --chunk_idx $IDX & 27 | done 28 | 29 | wait 30 | 31 | output_file=${output_dir}/consistency.jsonl 32 | 33 | # Clear out the output file if it exists. 34 | > "$output_file" 35 | 36 | # Loop through the indices and concatenate each file. 
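# The resulting consistency.jsonl is intended as the --pred_path input for
# scripts/gpt_eval/evaluate_benchmark_5_consistency.py.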
37 | for IDX in $(seq 0 $((CHUNKS-1))); do 38 | cat ${output_dir}/${CHUNKS}_${IDX}.json >> "$output_file" 39 | done -------------------------------------------------------------------------------- /scripts/infer_video/run_benchmark_generic_qa.sh: -------------------------------------------------------------------------------- 1 | CKPT_NAME="llava-v1.5-7b" 2 | num_frames=4 3 | model_path="ckpt/llava-v1.5-7b" 4 | GPT_Zero_Shot_QA="/bpfs/v2_mnt/VIS/wuwenhao/mllm_data/GPT_Zero_Shot_QA" 5 | video_dir="${GPT_Zero_Shot_QA}/VideoChatGPT_Test_Videos" 6 | gt_file="${GPT_Zero_Shot_QA}/generic_qa.json" 7 | output_dir="output/Benchmark_Generic_QA/${CKPT_NAME}_u${num_frames}FRS" 8 | 9 | gpu_list="${CUDA_VISIBLE_DEVICES:-0}" 10 | IFS=',' read -ra GPULIST <<< "$gpu_list" 11 | 12 | CHUNKS=${#GPULIST[@]} 13 | 14 | 15 | for IDX in $(seq 0 $((CHUNKS-1))); do 16 | CUDA_VISIBLE_DEVICES=${GPULIST[$IDX]} python3 llava/eval/run_inference_benchmark_general.py \ 17 | --video_dir ${video_dir} \ 18 | --gt_file ${gt_file} \ 19 | --output_dir ${output_dir} \ 20 | --output_name ${CHUNKS}_${IDX} \ 21 | --model_name ${model_path} \ 22 | --num_chunks $CHUNKS \ 23 | --num_frames $num_frames \ 24 | --conv-mode vicuna_v1 \ 25 | --temperature 0 \ 26 | --chunk_idx $IDX & 27 | done 28 | 29 | wait 30 | 31 | output_file=${output_dir}/generic.jsonl 32 | 33 | # Clear out the output file if it exists. 34 | > "$output_file" 35 | 36 | # Loop through the indices and concatenate each file. 37 | for IDX in $(seq 0 $((CHUNKS-1))); do 38 | cat ${output_dir}/${CHUNKS}_${IDX}.json >> "$output_file" 39 | done -------------------------------------------------------------------------------- /scripts/infer_video/run_benchmark_temporal_qa.sh: -------------------------------------------------------------------------------- 1 | CKPT_NAME="llava-v1.5-7b" 2 | num_frames=4 3 | model_path="ckpt/llava-v1.5-7b" 4 | GPT_Zero_Shot_QA="/bpfs/v2_mnt/VIS/wuwenhao/mllm_data/GPT_Zero_Shot_QA" 5 | video_dir="${GPT_Zero_Shot_QA}/VideoChatGPT_Test_Videos" 6 | gt_file="${GPT_Zero_Shot_QA}/temporal_qa.json" 7 | output_dir="output/Benchmark_Temporal_QA/${CKPT_NAME}_u${num_frames}FRS" 8 | 9 | gpu_list="${CUDA_VISIBLE_DEVICES:-0}" 10 | IFS=',' read -ra GPULIST <<< "$gpu_list" 11 | 12 | CHUNKS=${#GPULIST[@]} 13 | 14 | 15 | for IDX in $(seq 0 $((CHUNKS-1))); do 16 | CUDA_VISIBLE_DEVICES=${GPULIST[$IDX]} python3 llava/eval/run_inference_benchmark_general.py \ 17 | --video_dir ${video_dir} \ 18 | --gt_file ${gt_file} \ 19 | --output_dir ${output_dir} \ 20 | --output_name ${CHUNKS}_${IDX} \ 21 | --model_name ${model_path} \ 22 | --num_chunks $CHUNKS \ 23 | --num_frames $num_frames \ 24 | --conv-mode vicuna_v1 \ 25 | --temperature 0 \ 26 | --chunk_idx $IDX & 27 | done 28 | 29 | wait 30 | 31 | output_file=${output_dir}/temporal.jsonl 32 | 33 | # Clear out the output file if it exists. 34 | > "$output_file" 35 | 36 | # Loop through the indices and concatenate each file. 
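# The merged temporal.jsonl is intended as the --pred_path input for
# scripts/gpt_eval/evaluate_benchmark_4_temporal.py, which reads one JSON
# object per line with the keys 'video_name', 'Q', 'A' and 'pred'.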
37 | for IDX in $(seq 0 $((CHUNKS-1))); do 38 | cat ${output_dir}/${CHUNKS}_${IDX}.json >> "$output_file" 39 | done -------------------------------------------------------------------------------- /scripts/infer_video/run_one_video.sh: -------------------------------------------------------------------------------- 1 | CKPT_NAME="llava-v1.5-7b" 2 | num_frames=4 3 | model_path="ckpt/llava-v1.5-7b" 4 | video_path="video_samples/sample_demo_23.mp4" 5 | 6 | CUDA_VISIBLE_DEVICES=0 python3 llava/eval/single_video_inference.py \ 7 | --video_path ${video_path} \ 8 | --model_name ${model_path} \ 9 | --num_frames $num_frames \ 10 | --conv-mode vicuna_v1 11 | 12 | -------------------------------------------------------------------------------- /scripts/infer_video/run_qa_anet_13B.sh: -------------------------------------------------------------------------------- 1 | CKPT_NAME="llava-v1.5-13b" 2 | num_frames=4 3 | model_path="ckpt/llava-v1.5-13b" 4 | GPT_Zero_Shot_QA="/bpfs/v2_mnt/VIS/wuwenhao/mllm_data/GPT_Zero_Shot_QA" 5 | video_dir="${GPT_Zero_Shot_QA}/Activitynet_Zero_Shot_QA/all_test" 6 | gt_file_question="${GPT_Zero_Shot_QA}/Activitynet_Zero_Shot_QA/test_q.json" 7 | gt_file_answers="${GPT_Zero_Shot_QA}/Activitynet_Zero_Shot_QA/test_a.json" 8 | output_dir="output/Activitynet_Zero_Shot_QA/${CKPT_NAME}_u${num_frames}FRS" 9 | 10 | gpu_list="${CUDA_VISIBLE_DEVICES:-0}" 11 | IFS=',' read -ra GPULIST <<< "$gpu_list" 12 | 13 | CHUNKS=${#GPULIST[@]} 14 | 15 | 16 | for IDX in $(seq 0 $((CHUNKS-1))); do 17 | CUDA_VISIBLE_DEVICES=${GPULIST[$IDX]} python3 llava/eval/run_inference_qa.py \ 18 | --video_dir ${video_dir} \ 19 | --gt_file_question ${gt_file_question} \ 20 | --gt_file_answers ${gt_file_answers} \ 21 | --output_dir ${output_dir} \ 22 | --output_name ${CHUNKS}_${IDX} \ 23 | --model_name ${model_path} \ 24 | --num_chunks $CHUNKS \ 25 | --num_frames $num_frames \ 26 | --conv-mode vicuna_v1 \ 27 | --temperature 0 \ 28 | --chunk_idx $IDX & 29 | done 30 | 31 | wait 32 | 33 | output_file=${output_dir}/merge.jsonl 34 | 35 | # Clear out the output file if it exists. 36 | > "$output_file" 37 | 38 | # Loop through the indices and concatenate each file. 
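# merge.jsonl collects the ActivityNet answers from all GPU chunks into a
# single file for the downstream GPT-based QA evaluation (see
# scripts/gpt_eval/eval_qa_activitynet.sh).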
39 | for IDX in $(seq 0 $((CHUNKS-1))); do 40 | cat ${output_dir}/${CHUNKS}_${IDX}.json >> "$output_file" 41 | done -------------------------------------------------------------------------------- /scripts/infer_video/run_qa_anet_7B.sh: -------------------------------------------------------------------------------- 1 | CKPT_NAME="llava-v1.5-7b" 2 | num_frames=4 3 | model_path="ckpt/llava-v1.5-7b" 4 | GPT_Zero_Shot_QA="/bpfs/v2_mnt/VIS/wuwenhao/mllm_data/GPT_Zero_Shot_QA" 5 | video_dir="${GPT_Zero_Shot_QA}/Activitynet_Zero_Shot_QA/all_test" 6 | gt_file_question="${GPT_Zero_Shot_QA}/Activitynet_Zero_Shot_QA/test_q.json" 7 | gt_file_answers="${GPT_Zero_Shot_QA}/Activitynet_Zero_Shot_QA/test_a.json" 8 | output_dir="output/Activitynet_Zero_Shot_QA/${CKPT_NAME}_u${num_frames}FRS" 9 | 10 | gpu_list="${CUDA_VISIBLE_DEVICES:-0}" 11 | IFS=',' read -ra GPULIST <<< "$gpu_list" 12 | 13 | CHUNKS=${#GPULIST[@]} 14 | 15 | 16 | for IDX in $(seq 0 $((CHUNKS-1))); do 17 | CUDA_VISIBLE_DEVICES=${GPULIST[$IDX]} python3 llava/eval/run_inference_qa.py \ 18 | --video_dir ${video_dir} \ 19 | --gt_file_question ${gt_file_question} \ 20 | --gt_file_answers ${gt_file_answers} \ 21 | --output_dir ${output_dir} \ 22 | --output_name ${CHUNKS}_${IDX} \ 23 | --model_name ${model_path} \ 24 | --num_chunks $CHUNKS \ 25 | --num_frames $num_frames \ 26 | --conv-mode vicuna_v1 \ 27 | --temperature 0 \ 28 | --chunk_idx $IDX & 29 | done 30 | 31 | wait 32 | 33 | output_file=${output_dir}/merge.jsonl 34 | 35 | # Clear out the output file if it exists. 36 | > "$output_file" 37 | 38 | # Loop through the indices and concatenate each file. 39 | for IDX in $(seq 0 $((CHUNKS-1))); do 40 | cat ${output_dir}/${CHUNKS}_${IDX}.json >> "$output_file" 41 | done -------------------------------------------------------------------------------- /scripts/infer_video/run_qa_msrvtt_13B.sh: -------------------------------------------------------------------------------- 1 | CKPT_NAME="llava-v1.5-13b" 2 | num_frames=4 3 | model_path="ckpt/llava-v1.5-13b" 4 | GPT_Zero_Shot_QA="/bpfs/v2_mnt/VIS/wuwenhao/mllm_data/GPT_Zero_Shot_QA" 5 | video_dir="${GPT_Zero_Shot_QA}/MSRVTT_Zero_Shot_QA/videos/all" 6 | gt_file_question="${GPT_Zero_Shot_QA}/MSRVTT_Zero_Shot_QA/test_q.json" 7 | gt_file_answers="${GPT_Zero_Shot_QA}/MSRVTT_Zero_Shot_QA/test_a.json" 8 | output_dir="output/MSRVTT_Zero_Shot_QA/${CKPT_NAME}_u${num_frames}FRS" 9 | 10 | 11 | gpu_list="${CUDA_VISIBLE_DEVICES:-0}" 12 | IFS=',' read -ra GPULIST <<< "$gpu_list" 13 | 14 | CHUNKS=${#GPULIST[@]} 15 | 16 | 17 | for IDX in $(seq 0 $((CHUNKS-1))); do 18 | CUDA_VISIBLE_DEVICES=${GPULIST[$IDX]} python3 llava/eval/run_inference_qa.py \ 19 | --video_dir ${video_dir} \ 20 | --gt_file_question ${gt_file_question} \ 21 | --gt_file_answers ${gt_file_answers} \ 22 | --output_dir ${output_dir} \ 23 | --output_name ${CHUNKS}_${IDX} \ 24 | --model_name ${model_path} \ 25 | --num_chunks $CHUNKS \ 26 | --num_frames $num_frames \ 27 | --conv-mode vicuna_v1 \ 28 | --temperature 0 \ 29 | --chunk_idx $IDX & 30 | done 31 | 32 | wait 33 | 34 | output_file=${output_dir}/merge.jsonl 35 | 36 | # Clear out the output file if it exists. 37 | > "$output_file" 38 | 39 | # Loop through the indices and concatenate each file. 
40 | for IDX in $(seq 0 $((CHUNKS-1))); do 41 | cat ${output_dir}/${CHUNKS}_${IDX}.json >> "$output_file" 42 | done -------------------------------------------------------------------------------- /scripts/infer_video/run_qa_msrvtt_7B.sh: -------------------------------------------------------------------------------- 1 | CKPT_NAME="llava-v1.5-7b" 2 | num_frames=4 3 | model_path="ckpt/llava-v1.5-7b" 4 | GPT_Zero_Shot_QA="/bpfs/v2_mnt/VIS/wuwenhao/mllm_data/GPT_Zero_Shot_QA" 5 | video_dir="${GPT_Zero_Shot_QA}/MSRVTT_Zero_Shot_QA/videos/all" 6 | gt_file_question="${GPT_Zero_Shot_QA}/MSRVTT_Zero_Shot_QA/test_q.json" 7 | gt_file_answers="${GPT_Zero_Shot_QA}/MSRVTT_Zero_Shot_QA/test_a.json" 8 | output_dir="output/MSRVTT_Zero_Shot_QA/${CKPT_NAME}_u${num_frames}FRS" 9 | 10 | 11 | gpu_list="${CUDA_VISIBLE_DEVICES:-0}" 12 | IFS=',' read -ra GPULIST <<< "$gpu_list" 13 | 14 | CHUNKS=${#GPULIST[@]} 15 | 16 | 17 | for IDX in $(seq 0 $((CHUNKS-1))); do 18 | CUDA_VISIBLE_DEVICES=${GPULIST[$IDX]} python3 llava/eval/run_inference_qa.py \ 19 | --video_dir ${video_dir} \ 20 | --gt_file_question ${gt_file_question} \ 21 | --gt_file_answers ${gt_file_answers} \ 22 | --output_dir ${output_dir} \ 23 | --output_name ${CHUNKS}_${IDX} \ 24 | --model_name ${model_path} \ 25 | --num_chunks $CHUNKS \ 26 | --num_frames $num_frames \ 27 | --conv-mode vicuna_v1 \ 28 | --temperature 0 \ 29 | --chunk_idx $IDX & 30 | done 31 | 32 | wait 33 | 34 | output_file=${output_dir}/merge.jsonl 35 | 36 | # Clear out the output file if it exists. 37 | > "$output_file" 38 | 39 | # Loop through the indices and concatenate each file. 40 | for IDX in $(seq 0 $((CHUNKS-1))); do 41 | cat ${output_dir}/${CHUNKS}_${IDX}.json >> "$output_file" 42 | done -------------------------------------------------------------------------------- /scripts/infer_video/run_qa_msvd_13B.sh: -------------------------------------------------------------------------------- 1 | CKPT_NAME="llava-v1.5-13b" 2 | num_frames=4 3 | model_path="ckpt/llava-v1.5-13b" 4 | GPT_Zero_Shot_QA="/bpfs/v2_mnt/VIS/wuwenhao/mllm_data/GPT_Zero_Shot_QA" 5 | video_dir="${GPT_Zero_Shot_QA}/MSVD_Zero_Shot_QA/videos" 6 | gt_file_question="${GPT_Zero_Shot_QA}/MSVD_Zero_Shot_QA/test_q.json" 7 | gt_file_answers="${GPT_Zero_Shot_QA}/MSVD_Zero_Shot_QA/test_a.json" 8 | output_dir="output/MSVD_Zero_Shot_QA/${CKPT_NAME}_u${num_frames}FRS" 9 | 10 | gpu_list="${CUDA_VISIBLE_DEVICES:-0}" 11 | IFS=',' read -ra GPULIST <<< "$gpu_list" 12 | 13 | CHUNKS=${#GPULIST[@]} 14 | 15 | 16 | for IDX in $(seq 0 $((CHUNKS-1))); do 17 | CUDA_VISIBLE_DEVICES=${GPULIST[$IDX]} python3 llava/eval/run_inference_qa.py \ 18 | --video_dir ${video_dir} \ 19 | --gt_file_question ${gt_file_question} \ 20 | --gt_file_answers ${gt_file_answers} \ 21 | --output_dir ${output_dir} \ 22 | --output_name ${CHUNKS}_${IDX} \ 23 | --model_name ${model_path} \ 24 | --num_chunks $CHUNKS \ 25 | --num_frames $num_frames \ 26 | --conv-mode vicuna_v1 \ 27 | --temperature 0 \ 28 | --chunk_idx $IDX & 29 | done 30 | 31 | wait 32 | 33 | output_file=${output_dir}/merge.jsonl 34 | 35 | # Clear out the output file if it exists. 36 | > "$output_file" 37 | 38 | # Loop through the indices and concatenate each file. 
39 | for IDX in $(seq 0 $((CHUNKS-1))); do 40 | cat ${output_dir}/${CHUNKS}_${IDX}.json >> "$output_file" 41 | done -------------------------------------------------------------------------------- /scripts/infer_video/run_qa_msvd_7B.sh: -------------------------------------------------------------------------------- 1 | CKPT_NAME="llava-v1.5-7b" 2 | num_frames=4 3 | model_path="../ckpt/llava-v1.5-7b" 4 | GPT_Zero_Shot_QA="/bpfs/v2_mnt/VIS/wuwenhao/mllm_data/GPT_Zero_Shot_QA" 5 | video_dir="${GPT_Zero_Shot_QA}/MSVD_Zero_Shot_QA/videos" 6 | gt_file_question="${GPT_Zero_Shot_QA}/MSVD_Zero_Shot_QA/test_q.json" 7 | gt_file_answers="${GPT_Zero_Shot_QA}/MSVD_Zero_Shot_QA/test_a.json" 8 | output_dir="output/MSVD_Zero_Shot_QA/${CKPT_NAME}_u${num_frames}FRS" 9 | 10 | gpu_list="${CUDA_VISIBLE_DEVICES:-0}" 11 | IFS=',' read -ra GPULIST <<< "$gpu_list" 12 | 13 | CHUNKS=${#GPULIST[@]} 14 | 15 | 16 | for IDX in $(seq 0 $((CHUNKS-1))); do 17 | CUDA_VISIBLE_DEVICES=${GPULIST[$IDX]} python3 llava/eval/run_inference_qa.py \ 18 | --video_dir ${video_dir} \ 19 | --gt_file_question ${gt_file_question} \ 20 | --gt_file_answers ${gt_file_answers} \ 21 | --output_dir ${output_dir} \ 22 | --output_name ${CHUNKS}_${IDX} \ 23 | --model_name ${model_path} \ 24 | --num_chunks $CHUNKS \ 25 | --num_frames $num_frames \ 26 | --conv-mode vicuna_v1 \ 27 | --temperature 0 \ 28 | --chunk_idx $IDX & 29 | done 30 | 31 | wait 32 | 33 | output_file=${output_dir}/merge.jsonl 34 | 35 | # Clear out the output file if it exists. 36 | > "$output_file" 37 | 38 | # Loop through the indices and concatenate each file. 39 | for IDX in $(seq 0 $((CHUNKS-1))); do 40 | cat ${output_dir}/${CHUNKS}_${IDX}.json >> "$output_file" 41 | done --------------------------------------------------------------------------------
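# Example invocation (illustrative only; --output_dir and --output_json below are
# placeholder paths): once an inference script above has produced a merged
# prediction file, the GPT-based scoring can be launched with the arguments
# defined in scripts/gpt_eval/evaluate_benchmark_5_consistency.py, e.g.
python3 scripts/gpt_eval/evaluate_benchmark_5_consistency.py \
    --pred_path output/Benchmark_Consistency_QA/llava-v1.5-7b_u4FRS/consistency.jsonl \
    --output_dir output/Benchmark_Consistency_QA/llava-v1.5-7b_u4FRS/gpt_eval \
    --output_json output/Benchmark_Consistency_QA/llava-v1.5-7b_u4FRS/consistency_results.json \
    --api_key "$OPENAI_API_KEY" \
    --gpt_version gpt-3.5-turbo \
    --num_tasks 8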