├── .gitignore ├── LICENSE ├── MODEL_LICENSE.pdf ├── README.md ├── README_EN.md ├── demo.py ├── eval_chat.py ├── eval_configs ├── vxverse_hd_benchmark_evaluation.yaml └── vxverse_xverse_hd_eval.yaml ├── eval_vqa.py ├── requirements.txt ├── resources ├── 2_2_Trans.drawio.svg ├── 2_3_Trans.drawio.svg ├── Demo_Trans.svg ├── modelscope.png └── wechat.png └── vxverse ├── __init__.py ├── common ├── __init__.py ├── config.py ├── dist_utils.py ├── eval_utils.py ├── gradcam.py ├── logger.py ├── optims.py ├── registry.py ├── utils.py └── vqa_tools │ ├── VQA │ ├── PythonEvaluationTools │ │ ├── vqaEvalDemo.py │ │ └── vqaEvaluation │ │ │ ├── __init__.py │ │ │ └── vqaEval.py │ ├── PythonHelperTools │ │ ├── vqaDemo.py │ │ └── vqaTools │ │ │ ├── __init__.py │ │ │ └── vqa.py │ ├── QuestionTypes │ │ ├── abstract_v002_question_types.txt │ │ └── mscoco_question_types.txt │ ├── README.md │ └── license.txt │ ├── __init__.py │ ├── vqa.py │ └── vqa_eval.py ├── configs ├── Qformer │ └── bert-base-uncased │ │ ├── config.json │ │ ├── tokenizer.json │ │ ├── tokenizer_config.json │ │ └── vocab.txt ├── datasets │ ├── align │ │ ├── align.yaml │ │ ├── ccs_sub.yaml │ │ └── defaults.yaml │ ├── align_hd │ │ └── align_hd.yaml │ ├── cc_sbu │ │ ├── align.yaml │ │ ├── ccs_sub.yaml │ │ └── defaults.yaml │ └── gqa │ │ └── balanced_val.yaml ├── deepspeed │ └── ds.json ├── default.yaml ├── eva │ └── EVA02-CLIP-bigE-14-336.json └── models │ ├── vxverse_13bchat.yaml │ ├── vxverse_65bchat.yaml │ └── vxverse_7bchat.yaml ├── conversation ├── __init__.py └── conversation.py ├── datasets ├── __init__.py ├── builders │ ├── __init__.py │ ├── base_dataset_builder.py │ └── image_text_pair_builder.py ├── data_utils.py └── datasets │ ├── __init__.py │ ├── base_dataset.py │ ├── dataloader_utils.py │ ├── gqa_datasets.py │ └── vqa_datasets.py ├── models ├── Qformer.py ├── __init__.py ├── base_model.py ├── clip_vit.py ├── eva2_vit.py ├── eva_vit.py ├── modeling_llama.py ├── modeling_xverse.py ├── vxverse.py └── 
vxverse_base.py └── processors ├── __init__.py ├── base_processor.py └── blip_processors.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | **/__pycache__ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | cover/ 54 | 55 | # Translations 56 | *.mo 57 | *.pot 58 | 59 | # Django stuff: 60 | *.log 61 | local_settings.py 62 | db.sqlite3 63 | db.sqlite3-journal 64 | 65 | # Flask stuff: 66 | instance/ 67 | .webassets-cache 68 | 69 | # Scrapy stuff: 70 | .scrapy 71 | 72 | # Sphinx documentation 73 | docs/_build/ 74 | 75 | # PyBuilder 76 | .pybuilder/ 77 | target/ 78 | dataset/ 79 | # Jupyter Notebook 80 | .ipynb_checkpoints 81 | 82 | # IPython 83 | profile_default/ 84 | ipython_config.py 85 | 86 | # pyenv 87 | # For a library or package, you might want to ignore these files since the code is 88 | # intended to run in multiple environments; otherwise, check them in: 89 | # .python-version 90 | 91 | # pipenv 92 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 
93 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 94 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 95 | # install all needed dependencies. 96 | #Pipfile.lock 97 | 98 | # poetry 99 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 100 | # This is especially recommended for binary packages to ensure reproducibility, and is more 101 | # commonly ignored for libraries. 102 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 103 | #poetry.lock 104 | 105 | # pdm 106 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 107 | #pdm.lock 108 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 109 | # in version control. 110 | # https://pdm.fming.dev/#use-with-ide 111 | .pdm.toml 112 | 113 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 114 | __pypackages__/ 115 | 116 | # Celery stuff 117 | celerybeat-schedule 118 | celerybeat.pid 119 | 120 | # SageMath parsed files 121 | *.sage.py 122 | 123 | # Environments 124 | .env 125 | .venv 126 | env/ 127 | venv/ 128 | ENV/ 129 | env.bak/ 130 | venv.bak/ 131 | 132 | # Spyder project settings 133 | .spyderproject 134 | .spyproject 135 | 136 | # Rope project settings 137 | .ropeproject 138 | 139 | # mkdocs documentation 140 | /site 141 | 142 | # mypy 143 | .mypy_cache/ 144 | .dmypy.json 145 | dmypy.json 146 | 147 | # Pyre type checker 148 | .pyre/ 149 | 150 | # pytype static type analyzer 151 | .pytype/ 152 | 153 | # Cython debug symbols 154 | cython_debug/ 155 | 156 | # PyCharm 157 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 158 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 159 | # and can be added to the global gitignore or merged into 
this file. For a more nuclear 160 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 161 | .idea/ 162 | 163 | wandb/ 164 | jobs/logs/ 165 | *.out 166 | *ipynb 167 | .history/ 168 | #*.json 169 | *.sh 170 | .ipynb_common 171 | logs/ 172 | results/ 173 | prompts/ 174 | output/ 175 | std_output/ 176 | vxverse/output/* 177 | benchmark_eval_output/* 178 | ckpt/ 179 | divide_vqa.py 180 | jobs/ 181 | demo_v2.py 182 | 183 | *.slurm 184 | slurm* 185 | sbatch_generate* 186 | eval_data/ 187 | dataset/Evaluation.md 188 | jupyter_notebook.slurm 189 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | 2 | Apache License 3 | Version 2.0, January 2004 4 | http://www.apache.org/licenses/ 5 | 6 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 7 | 8 | 1. Definitions. 9 | 10 | "License" shall mean the terms and conditions for use, reproduction, 11 | and distribution as defined by Sections 1 through 9 of this document. 12 | 13 | "Licensor" shall mean the copyright owner or entity authorized by 14 | the copyright owner that is granting the License. 15 | 16 | "Legal Entity" shall mean the union of the acting entity and all 17 | other entities that control, are controlled by, or are under common 18 | control with that entity. For the purposes of this definition, 19 | "control" means (i) the power, direct or indirect, to cause the 20 | direction or management of such entity, whether by contract or 21 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 22 | outstanding shares, or (iii) beneficial ownership of such entity. 23 | 24 | "You" (or "Your") shall mean an individual or Legal Entity 25 | exercising permissions granted by this License. 
26 | 27 | "Source" form shall mean the preferred form for making modifications, 28 | including but not limited to software source code, documentation 29 | source, and configuration files. 30 | 31 | "Object" form shall mean any form resulting from mechanical 32 | transformation or translation of a Source form, including but 33 | not limited to compiled object code, generated documentation, 34 | and conversions to other media types. 35 | 36 | "Work" shall mean the work of authorship, whether in Source or 37 | Object form, made available under the License, as indicated by a 38 | copyright notice that is included in or attached to the work 39 | (an example is provided in the Appendix below). 40 | 41 | "Derivative Works" shall mean any work, whether in Source or Object 42 | form, that is based on (or derived from) the Work and for which the 43 | editorial revisions, annotations, elaborations, or other modifications 44 | represent, as a whole, an original work of authorship. For the purposes 45 | of this License, Derivative Works shall not include works that remain 46 | separable from, or merely link (or bind by name) to the interfaces of, 47 | the Work and Derivative Works thereof. 48 | 49 | "Contribution" shall mean any work of authorship, including 50 | the original version of the Work and any modifications or additions 51 | to that Work or Derivative Works thereof, that is intentionally 52 | submitted to Licensor for inclusion in the Work by the copyright owner 53 | or by an individual or Legal Entity authorized to submit on behalf of 54 | the copyright owner. 
For the purposes of this definition, "submitted" 55 | means any form of electronic, verbal, or written communication sent 56 | to the Licensor or its representatives, including but not limited to 57 | communication on electronic mailing lists, source code control systems, 58 | and issue tracking systems that are managed by, or on behalf of, the 59 | Licensor for the purpose of discussing and improving the Work, but 60 | excluding communication that is conspicuously marked or otherwise 61 | designated in writing by the copyright owner as "Not a Contribution." 62 | 63 | "Contributor" shall mean Licensor and any individual or Legal Entity 64 | on behalf of whom a Contribution has been received by Licensor and 65 | subsequently incorporated within the Work. 66 | 67 | 2. Grant of Copyright License. Subject to the terms and conditions of 68 | this License, each Contributor hereby grants to You a perpetual, 69 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 70 | copyright license to reproduce, prepare Derivative Works of, 71 | publicly display, publicly perform, sublicense, and distribute the 72 | Work and such Derivative Works in Source or Object form. 73 | 74 | 3. Grant of Patent License. Subject to the terms and conditions of 75 | this License, each Contributor hereby grants to You a perpetual, 76 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 77 | (except as stated in this section) patent license to make, have made, 78 | use, offer to sell, sell, import, and otherwise transfer the Work, 79 | where such license applies only to those patent claims licensable 80 | by such Contributor that are necessarily infringed by their 81 | Contribution(s) alone or by combination of their Contribution(s) 82 | with the Work to which such Contribution(s) was submitted. 
If You 83 | institute patent litigation against any entity (including a 84 | cross-claim or counterclaim in a lawsuit) alleging that the Work 85 | or a Contribution incorporated within the Work constitutes direct 86 | or contributory patent infringement, then any patent licenses 87 | granted to You under this License for that Work shall terminate 88 | as of the date such litigation is filed. 89 | 90 | 4. Redistribution. You may reproduce and distribute copies of the 91 | Work or Derivative Works thereof in any medium, with or without 92 | modifications, and in Source or Object form, provided that You 93 | meet the following conditions: 94 | 95 | (a) You must give any other recipients of the Work or 96 | Derivative Works a copy of this License; and 97 | 98 | (b) You must cause any modified files to carry prominent notices 99 | stating that You changed the files; and 100 | 101 | (c) You must retain, in the Source form of any Derivative Works 102 | that You distribute, all copyright, patent, trademark, and 103 | attribution notices from the Source form of the Work, 104 | excluding those notices that do not pertain to any part of 105 | the Derivative Works; and 106 | 107 | (d) If the Work includes a "NOTICE" text file as part of its 108 | distribution, then any Derivative Works that You distribute must 109 | include a readable copy of the attribution notices contained 110 | within such NOTICE file, excluding those notices that do not 111 | pertain to any part of the Derivative Works, in at least one 112 | of the following places: within a NOTICE text file distributed 113 | as part of the Derivative Works; within the Source form or 114 | documentation, if provided along with the Derivative Works; or, 115 | within a display generated by the Derivative Works, if and 116 | wherever such third-party notices normally appear. The contents 117 | of the NOTICE file are for informational purposes only and 118 | do not modify the License. 
You may add Your own attribution 119 | notices within Derivative Works that You distribute, alongside 120 | or as an addendum to the NOTICE text from the Work, provided 121 | that such additional attribution notices cannot be construed 122 | as modifying the License. 123 | 124 | You may add Your own copyright statement to Your modifications and 125 | may provide additional or different license terms and conditions 126 | for use, reproduction, or distribution of Your modifications, or 127 | for any such Derivative Works as a whole, provided Your use, 128 | reproduction, and distribution of the Work otherwise complies with 129 | the conditions stated in this License. 130 | 131 | 5. Submission of Contributions. Unless You explicitly state otherwise, 132 | any Contribution intentionally submitted for inclusion in the Work 133 | by You to the Licensor shall be under the terms and conditions of 134 | this License, without any additional terms or conditions. 135 | Notwithstanding the above, nothing herein shall supersede or modify 136 | the terms of any separate license agreement you may have executed 137 | with Licensor regarding such Contributions. 138 | 139 | 6. Trademarks. This License does not grant permission to use the trade 140 | names, trademarks, service marks, or product names of the Licensor, 141 | except as required for reasonable and customary use in describing the 142 | origin of the Work and reproducing the content of the NOTICE file. 143 | 144 | 7. Disclaimer of Warranty. Unless required by applicable law or 145 | agreed to in writing, Licensor provides the Work (and each 146 | Contributor provides its Contributions) on an "AS IS" BASIS, 147 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 148 | implied, including, without limitation, any warranties or conditions 149 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 150 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 151 | appropriateness of using or redistributing the Work and assume any 152 | risks associated with Your exercise of permissions under this License. 153 | 154 | 8. Limitation of Liability. In no event and under no legal theory, 155 | whether in tort (including negligence), contract, or otherwise, 156 | unless required by applicable law (such as deliberate and grossly 157 | negligent acts) or agreed to in writing, shall any Contributor be 158 | liable to You for damages, including any direct, indirect, special, 159 | incidental, or consequential damages of any character arising as a 160 | result of this License or out of the use or inability to use the 161 | Work (including but not limited to damages for loss of goodwill, 162 | work stoppage, computer failure or malfunction, or any and all 163 | other commercial damages or losses), even if such Contributor 164 | has been advised of the possibility of such damages. 165 | 166 | 9. Accepting Warranty or Additional Liability. While redistributing 167 | the Work or Derivative Works thereof, You may choose to offer, 168 | and charge a fee for, acceptance of support, warranty, indemnity, 169 | or other liability obligations and/or rights consistent with this 170 | License. However, in accepting such obligations, You may act only 171 | on Your own behalf and on Your sole responsibility, not on behalf 172 | of any other Contributor, and only if You agree to indemnify, 173 | defend, and hold each Contributor harmless for any liability 174 | incurred by, or claims asserted against, such Contributor by reason 175 | of your accepting any such warranty or additional liability. 176 | 177 | END OF TERMS AND CONDITIONS 178 | 179 | APPENDIX: How to apply the Apache License to your work. 180 | 181 | To apply the Apache License to your work, attach the following 182 | boilerplate notice, with the fields enclosed by brackets "[]" 183 | replaced with your own identifying information. 
(Don't include 184 | the brackets!) The text should be enclosed in the appropriate 185 | comment syntax for the file format. We also recommend that a 186 | file or class name and description of purpose be included on the 187 | same "printed page" as the copyright notice for easier 188 | identification within third-party archives. 189 | 190 | Copyright [2023] [XVERSE Technology Inc] 191 | 192 | Licensed under the Apache License, Version 2.0 (the "License"); 193 | you may not use this file except in compliance with the License. 194 | You may obtain a copy of the License at 195 | 196 | http://www.apache.org/licenses/LICENSE-2.0 197 | 198 | Unless required by applicable law or agreed to in writing, software 199 | distributed under the License is distributed on an "AS IS" BASIS, 200 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 201 | See the License for the specific language governing permissions and 202 | limitations under the License. 203 | -------------------------------------------------------------------------------- /MODEL_LICENSE.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xverse-ai/XVERSE-V-13B/989da861ac64dfb730da0a872f5057372cdad4f5/MODEL_LICENSE.pdf -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |
2 |

3 | XVERSE-V-13B 4 |

5 |
6 | 7 |

8 | 🤗 Hugging Face | 9 | ModelScope | 10 | 💬 微信社区 11 |

12 | 13 |

14 |

15 | 中文 | 16 | English 17 |

18 |

19 | 20 | ## 更新信息 21 | - **[2024/04/28]** 发布 **XVERSE-V-13B** 多模态模型。 22 | 23 | ## 模型介绍 24 | 25 | **XVERSE-V-13B** 是由深圳元象科技自主研发的支持图文问答的多模态大模型(Large Multimodal Model),其主要特点如下: 26 | 27 | - **模型结构**:视觉编码器采用了 **openai/clip-vit-large-patch14-224**,文本模型采用了自研的 **XVERSE-13B-Chat** 模型,图像—文本桥接层采用了高效且简洁的两层 **MLP** 结构。 28 | - **训练数据**:图文数据采用的是完全公开的数据集,其中预训练阶段数据量为 2.1B 图文对,微调阶段采用了 8.2M 的指令数据。训练数据几乎全为英文数据,因此模型的能力主要体现在英文方面。 29 | - **图像分辨率**:不同于其他固定图像分辨率的模型,**XVERSE-V-13B** 将图像切分成多个 **224×224** 的块,分别将他们送到视觉模块进行编码,因此能够处理更高分辨率或者不同宽高比的图像,这为我们模型保留了尽可能多的细节信息。 30 | - **训练方式**: **XVERSE-V-13B** 采用了两阶段训练,分别为规模比较大的图文对预训练和规模比较小的指令数据微调。其中预训练阶段,我们冻结❄️视觉模块和 LLM 模块,只训练🔥桥接层部分; 31 | 指令微调阶段,我们依然冻结❄️视觉模块和LLM模块,但是微调🔥桥接层部分以及LLM的所有线性层的 LoRA 参数;另外,在指令微调阶段,我们对桥接层部分和 LoRA 部分采用了差分学习率。 32 | 33 | ## 图像编码示例 34 | 对于 448*448 的图像,我们通过 Sliding Window 将其切分成4个局部图像块以及 Resize 得到一个包含全局信息的图像,如下图所示 35 | ![avatar](resources/2_2_Trans.drawio.svg) 36 | 37 | 对于更高分辨率的 448*672 的图像,我们通过 Sliding Window 将其切分成6个局部图像块以及 Resize 得到一个包含全局信息的图像,如下图所示 38 | ![avatar](resources/2_3_Trans.drawio.svg) 39 | 40 | > 1:Concate* 表示列向量按行进行拼接 41 | > 42 | > 2:对于其他不同分辨率以及不同宽高比的图像,也是同理进行切块编码 43 | 44 | ## 评测结果 45 | 为了综合评估模型的性能,我们在一系列标准数据集上进行了全面测试,包括 MMBench、MMMU、SEEDBench_IMG、MMStar、LLaVABench、AI2D、ScienceQA、VizWiz、TextVQA、OKVQA 和 GQA 等数据集。这些评估覆盖了模型在多个领域的能力,具体包括 OCR,逻辑推理,关系推理,粗粒度感知和细粒度感知。评估结果如下: 46 | 47 | ### OpenCompass 榜单 48 | [OpenCompass](https://opencompass.org.cn/home) 是面向大模型评测的一站式平台。 其主要特点如下: 开源可复现:提供公平、公开、可复现的大模型评测方案。因此,我们报告模型在此榜单上的相关结果。 49 | 50 | | 数据集 | XVERSE-V-13B | GeminiProVision`*` | Qwen-VL-Plus`*` | Claude-3V Sonnet`*` | LLaVA-Next-Vicuna-13B | Monkey-Chat | OmniLMM-12B | DeepSeek-VL-7B | CogVLM-17B-Chat | TransCore-M | Yi-VL-34B | 51 | |--------------------|:------------:|:------------------:|:---------------:|:-------------------:|:---------------------:|:-----------:|:-----------:|:--------------:|:---------------:|:-----------:|:---------:| 52 | | MMBench | 75.6 | 73.6 | 67.0 | 67.8 | 70.0 | 72.4 | 71.7 | 73.8 | 65.8 | **82.3** | 
72.4 | 53 | | MMBench-CN | 74.7 | 74.3 | 70.7 | 64.2 | 68.5 | 67.5 | 62.0 | 71.4 | 55.9 | **80.7** | 70.7 | 54 | | MMStar | **47.8** | 38.6 | 39.7 | 44.2 | 40.4 | 40.7 | 39.6 | 40.5 | 39.9 | 35.6 | 40.5 | 55 | | MMMU-Val | 43.3 | **48.9** | 39.8 | 47.4 | 37.3 | 40.7 | 41.8 | 38.3 | 37.3 | 41.0 | 45.1 | 56 | | MathVistaMini-Test | 44.1 | **46.5** | 37.6 | 45.0 | 34.1 | 35.9 | 34.7 | 36.9 | 35.0 | 32.3 | 31.5 | 57 | | HallusionBench | 31.8 | **45.2** | 40.6 | 41.3 | 31.8 | 39.3 | 35.8 | 34.5 | 35.4 | 27.3 | 35.3 | 58 | | AI2D-Test | 70.4 | 70.2 | 65.7 | 69.9 | **72.2** | 68.5 | 63.3 | 65.3 | 63.3 | 64.1 | 65.9 | 59 | | OCRBench | 489 | 680.0 | **726.0** | 646.0 | 537.0 | 534.0 | 420.0 | 435.0 | 590.0 | 405.0 | 290.0 | 60 | | SEEDBench_IMG | **72.4** | 70.7 | 65.7 | 65.0 | 71.4 | 68.9 | 71.5 | 70.1 | 68.8 | 72.0 | 68.1 | 61 | | LLaVABench | **82.3** | 79.9 | 73.7 | 73.2 | 73.9 | 60.5 | 75.8 | 77.8 | 73.9 | 66.8 | 62.3 | 62 | 63 | > 1:带 `*` 号的模型是闭源模型 64 | 65 | 对于上述所有比较模型,我们优先汇报其官方公布的结果。在缺少官方结果的情况下,我们采用了 [OpenCompass 榜单](https://rank.opencompass.org.cn/leaderboard-multimodal)的报告结果。若 OpenCompass 榜单上仍然缺少相应的数据集评估结果, 66 | 则来自于我们自行执行的评估流程所获得的数据。而评测框架则采用了[VLMEvalKit 评估框架](https://github.com/open-compass/VLMEvalKit/)。 67 | 68 | ### 传统VQA类任务 69 | 传统VQA任务,作为多模态视觉问答领域学术论文常引用的评测任务,具备显著的学术参考价值。因此,我们也将在此类数据集上报告相关的评测结果。 70 | 71 | | 数据集 | XVERSE-V-13B | LLaVA-Next-Vicuna-13B | Monkey-Chat | OmniLMM-12B | DeepSeek-VL-7B | CogVLM-17B-Chat | TransCore-M | Yi-VL-34B | 72 | |--------------------|:------------:|:---------------------:| :-------: | :---------: | :--------: |:---------------:|:-----------:| :--------------: | 73 | | ScienceQA | **86.4** | 73.9 | 82.8 | 80.8 | 81.0 | 70.3 | 74.9 | 75.4 | 74 | | OKVQA | 59.2 | **60.0** | 54.7 | 40.8 | 55.1 | 54.4 | 56.7 | 51.4 | 75 | | GQA | 62.2 | **65.5** | 65.4 | 61.1 | 61.8 | 60.5 | 63.6 | 58.3 | 76 | | VizWiz | **81.9** | 54.6 | 75.6 | 64.0 | 50.1 | 44.0 | 41.4 | 70.8 | 77 | | TextVQA | **74.2** | 64.3 | 53.7 | 62.4 | 63.8 | 69.6 | 63.1 | 
54.0 | 78 | 79 | 同理,对于上述所有比较模型,我们优先汇报其官方公布的结果。在缺少官方结果的情况下,则来自于我们自行执行的评估流程所获得的数据。而评测框架则采用了[VLMEvalKit 评估框架](https://github.com/open-compass/VLMEvalKit/)。 80 | 81 | 82 | ## 效果示例 83 | 这里我们展示全景和细节识别、图表分析、百科解答、教育问答、内容创作和代码生成等能力的样例。 84 | 85 | ![avatar](resources/Demo_Trans.svg) 86 | 87 | ## 使用方法 88 | 89 | ### 环境安装 90 | 91 | 1. 下载本仓库: 92 | 93 | ```shell 94 | git clone git@github.com:xverse-ai/XVERSE-V-13B.git 95 | cd XVERSE-V-13B 96 | ``` 97 | 98 | 2. 使用 pip 安装依赖: 99 | 100 | ```shell 101 | pip install -r requirements.txt 102 | ``` 103 | 104 | ### 模型准备与加载 105 | 1. 模型准备: 106 | 我们的模型分为三个部分:视觉编码器 clip-vit-large-patch14-224,大语言模型 XVERSE-13B-Chat 和桥接层 Adapters,这三部分分别可以从下面提供的链接中下载 107 | 108 | | XVERSE-13B-Chat | clip-vit-large-patch14-224 | Adapters | 109 | |---------------------------------------------------------------| :--------------: |:-----------------------------------------------------:| 110 | |
[下载](https://huggingface.co/xverse/XVERSE-13B-Chat) |
[下载](https://huggingface.co/openai/clip-vit-large-patch14) |
[下载](https://huggingface.co/xverse/XVERSE-V-13B)| 111 | 112 | 2. 模型加载: 113 | 完成步骤1之后,只需要将模型权重路径填入到配置文件相应的位置中即可: 114 | 1. 对于 clip-vit-large-patch14-224 和 Adapters,请将路径分别填写到 ./eval_configs/vxverse_*.yaml 文件中的 vit_path 和 ckpt 字段中; 115 | 2. 对于XVERSE-13B-Chat,请将路径填写到 ./vxverse/configs/models/vxverse_13bchat.yaml 文件对应的字段中。 116 | 117 | 118 | ### **OKVQA** 和 **GQA** 数据集的测评 119 | 1. 数据集准备: 120 | 1. 对于OKVQA测试集可以从此处下载 121 | 2. 对于GQA测试集可以从此处下载 122 | 123 | 2. 运行脚本 124 | ```shell 125 | python ./eval_vqa.py --cfg-path ./eval_configs/vxverse_hd_benchmark_evaluation.yaml --dataset gqa 126 | ``` 127 | 128 | ### 网页 Demo 129 | 130 | 可通过以下代码启动一个web server,在浏览器输入访问地址后,可对 XVERSE-V-13B 模型进行体验: 131 | 132 | ```shell 133 | python demo.py --cfg-path ./eval_configs/vxverse_xverse_hd_eval.yaml --gpu-id 0 134 | ``` 135 | 136 | ## 特别说明 137 | 我们的模型是基于修改并适配后的 [Megatron](https://github.com/NVIDIA/Megatron-LM) 框架训练的,而 PyTorch 框架下的模型加载,demo 体验和数据集的评估则是基于[MiniGPT4](https://github.com/Vision-CAIR/MiniGPT-4)代码修改而来的。 138 | 139 | 140 | ## 局限性与免责声明 141 | 142 | XVERSE-V-13B 与其它所有 LMM 一样,在某些情况下可能会产生不准确、有偏见或其他令人反感的内容。因此,请谨慎使用模型生成的内容,请勿将生成的有害内容进行传播,在部署任何 XVERSE-V-13B 的应用之前,开发人员应根据其具体应用对模型进行安全测试和调优。 143 | 144 | 我们强烈警告不要将 XVERSE-V-13B 模型用于制造或传播有害信息,或进行任何可能损害公众、国家、社会安全或违反法规的活动。如果使用 XVERSE-V-13B 模型产生任何问题,无论是数据安全问题、公共舆论风险,还是模型被误解、滥用、传播或不合规使用所引发的任何风险和问题,我们将不承担任何责任。 145 | 146 | ## 模型开源协议 147 | 148 | 使用本仓库的源码需要遵循 [Apache-2.0](LICENSE) 开源协议,使用 XVERSE-V-13B 的模型权重则需要遵循[模型许可协议](MODEL_LICENSE.pdf)。 149 | 150 | XVERSE-V-13B 模型权重对学术研究**完全开放**,并且支持**免费商用**。如需申请商业许可证,请填写【[申请表](https://chat.xverse.cn/home/business.html)】,如有其他问题或合作,请联系 。 151 | 152 | -------------------------------------------------------------------------------- /demo.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import random 4 | 5 | import numpy as np 6 | import torch 7 | import torch.backends.cudnn as cudnn 8 | import gradio as gr 9 | 10 | from transformers import StoppingCriteriaList
from vxverse.common.config import Config
from vxverse.common.registry import registry
from vxverse.conversation.conversation import Chat, CONV_VISION_XVERSE, StoppingCriteriaSub

# imports modules for registration (currently disabled)
# from vxverse.datasets.builders import *
# from vxverse.models import *
# from vxverse.processors import *
# from vxverse.runners import *
# from vxverse.tasks import *


def parse_args():
    """Parse the command-line arguments of the web demo.

    Returns:
        argparse.Namespace with attributes ``cfg_path`` (str, required),
        ``gpu_id`` (int, default 0), ``server_port`` (int, default 20029)
        and ``options`` (list of "key=value" strings, or None).
    """
    parser = argparse.ArgumentParser(description="Demo")
    parser.add_argument("--cfg-path", required=True, help="path to configuration file.")
    parser.add_argument("--gpu-id", type=int, default=0, help="specify the gpu to load the model.")
    parser.add_argument("--server_port", type=int, default=20029, help="specify the port of the web demo server.")
    parser.add_argument(
        "--options",
        nargs="+",
        # No --cfg-options flag is defined here, so the help must not point users to one.
        help="override some settings in the used config, the key-value pair "
        "in xxx=yyy format will be merged into config file.",
    )
    return parser.parse_args()


# ========================================
#             Model Initialization
# ========================================

# Conversation template to use for each supported ``model_type`` of the model config.
conv_dict = {'pretrain_xverse13b-chat': CONV_VISION_XVERSE}

print('Initializing Chat')
args = parse_args()
cfg = Config(args)

model_config = cfg.model_cfg
# Validate before the (slow) model load rather than failing afterwards with a bare KeyError.
if model_config.model_type not in conv_dict:
    raise ValueError("Unsupported model_type '{}'; expected one of: {}".format(
        model_config.model_type, ", ".join(sorted(conv_dict))))
model_config.device_8bit = args.gpu_id
model_cls = registry.get_model_class(model_config.arch)
model = model_cls.from_config(model_config).to('cuda:{}'.format(args.gpu_id))

CONV_VISION = conv_dict[model_config.model_type]
# The demo reuses the vision processor configured for the GQA dataset
# (alternative kept for reference: cfg.datasets_cfg.cc_sbu_align.vis_processor.train).
vis_processor_cfg = cfg.datasets_cfg.gqa.vis_processor.train
vis_processor = registry.get_processor_class(vis_processor_cfg.name).from_config(vis_processor_cfg)
62 | if "vicuna" in model_config.model_type: 63 | stop_sign = "###" 64 | stop_words_ids = [[835], [2277, 29937]] 65 | elif "xverse" in model_config.model_type: 66 | stop_words_ids = [[2]] 67 | stop_sign = "<|endoftext|>" 68 | elif "llama" in model_config.model_type: 69 | stop_sign = "" 70 | stop_words_ids = [[2]] 71 | else: 72 | raise ValueError("Not support model type.") 73 | 74 | print("stop_sign", stop_sign) 75 | stop_words_ids = [torch.tensor(ids).to(device='cuda:{}'.format(args.gpu_id)) for ids in stop_words_ids] 76 | stopping_criteria = StoppingCriteriaList([StoppingCriteriaSub(stops=stop_words_ids)]) 77 | 78 | chat = Chat(model, vis_processor, device='cuda:{}'.format(args.gpu_id), stopping_criteria=stopping_criteria, vis_processor_name=vis_processor_cfg.name) 79 | print('Initialization Finished') 80 | 81 | 82 | # ======================================== 83 | # Gradio Setting 84 | # ======================================== 85 | 86 | def escape_markdown(text): 87 | # List of Markdown special characters that need to be escaped 88 | md_chars = ['<', '>'] 89 | # Escape each special character 90 | for char in md_chars: 91 | text = text.replace(char, '\\' + char) 92 | text = text.replace("\u200b\n", "") 93 | return text 94 | 95 | def gradio_reset(chat_state, img_list): 96 | if chat_state is not None: 97 | chat_state.messages = [] 98 | if img_list is not None: 99 | img_list = [] 100 | return None, gr.update(value=None, interactive=True), gr.update(placeholder='Upload your image and chat', interactive=True), chat_state, img_list 101 | 102 | 103 | def upload_img(gr_img, text_input, chat_state): 104 | if gr_img is None: 105 | return None, None, gr.update(interactive=True), chat_state, None 106 | chat_state = CONV_VISION.copy() 107 | img_list = [] 108 | llm_message = chat.upload_img(gr_img, chat_state, img_list) 109 | chat.encode_img(img_list) 110 | return gr.update(interactive=False), gr.update(interactive=True, placeholder='Type and press Enter'), 
gr.update(value="Start Chatting", interactive=False), chat_state, img_list 111 | 112 | 113 | 114 | 115 | def gradio_ask(user_message, chatbot, chat_state, gr_img, img_list, upload_flag, replace_flag): 116 | if len(user_message) == 0: 117 | text_box_show = 'Input should not be empty!' 118 | else: 119 | text_box_show = '' 120 | if isinstance(gr_img, dict): 121 | gr_img, mask = gr_img['image'], gr_img['mask'] 122 | else: 123 | mask = None 124 | 125 | if chat_state is None: 126 | chat_state = CONV_VISION.copy() 127 | 128 | if upload_flag: 129 | if replace_flag: 130 | chat_state = CONV_VISION.copy() # new image, reset everything 131 | replace_flag = 0 132 | chatbot = [] 133 | img_list = [] 134 | llm_message = chat.upload_img(gr_img, chat_state, img_list) 135 | upload_flag = 0 136 | chat.ask(user_message, chat_state) 137 | chatbot = chatbot + [[user_message, None]] 138 | 139 | return text_box_show, chatbot, chat_state, img_list, upload_flag, replace_flag 140 | 141 | 142 | def gradio_answer(chatbot, chat_state, img_list, top_p, temperature): 143 | if len(img_list) > 0: 144 | if not isinstance(img_list[0], torch.Tensor): 145 | chat.encode_img(img_list) 146 | llm_message = chat.answer(conv=chat_state, 147 | stop_sign=stop_sign, 148 | img_list=img_list, 149 | top_p=top_p, 150 | temperature=temperature, 151 | max_new_tokens=500, 152 | max_length=2048)[0] 153 | chatbot[-1][1] = llm_message 154 | print("##############") 155 | print("chat state:{}".format(chat_state)) 156 | return chatbot, chat_state, img_list 157 | 158 | 159 | def gradio_stream_answer(chatbot, chat_state, img_list, top_p, temperature, do_sample=False): 160 | # if img_list != None To support pure text conversation 161 | if img_list != None and len(img_list) > 0: 162 | if not isinstance(img_list[0], torch.Tensor): 163 | chat.encode_img(img_list) 164 | streamer = chat.stream_answer(conv=chat_state, 165 | img_list=img_list, 166 | temperature=temperature, 167 | top_p=top_p, 168 | do_sample=do_sample, 169 | 
max_new_tokens=2048, 170 | max_length=8192) 171 | 172 | output = '' 173 | for new_output in streamer: 174 | escapped = escape_markdown(new_output) 175 | output += escapped 176 | output = output.split(stop_sign)[0] # remove the stop sign 177 | output = output.split('Assistant:')[-1] 178 | chatbot[-1][1] = output 179 | yield chatbot, chat_state, img_list 180 | chat_state.messages[-1][1] = output 181 | print("##############") 182 | print("chat_state:{}".format(chat_state)) 183 | return chatbot, chat_state, img_list 184 | 185 | 186 | def image_upload_trigger(upload_flag, replace_flag, img_list): 187 | # set the upload flag to true when receive a new image. 188 | # if there is an old image (and old conversation), set the replace flag to true to reset the conv later. 189 | upload_flag = 1 190 | if img_list: 191 | replace_flag = 1 192 | return upload_flag, replace_flag 193 | 194 | def example_trigger(image, text_input, upload_flag, replace_flag, img_list): 195 | # set the upload flag to true when receive a new image. 196 | # if there is an old image (and old conversation), set the replace flag to true to reset the conv later. 197 | upload_flag = 1 198 | if img_list or replace_flag == 1: 199 | replace_flag = 1 200 | 201 | return upload_flag, replace_flag 202 | 203 | title = """

Demo of XVERSE-V

""" 204 | description = """

This is the demo of XVERSE-V. Upload your images and start chatting!

""" 205 | 206 | 207 | #TODO show examples below 208 | 209 | text_input = gr.Textbox(placeholder='Upload your image and chat', interactive=True, show_label=False, container=False, 210 | scale=8) 211 | with gr.Blocks() as demo: 212 | gr.Markdown(title) 213 | gr.Markdown(description) 214 | 215 | with gr.Row(): 216 | with gr.Column(scale=1): 217 | image = gr.Image(type="pil") 218 | clear = gr.Button("Restart") 219 | 220 | top_p = gr.Slider( 221 | minimum=0.1, 222 | maximum=1, 223 | value=0.8, 224 | step=0.05, 225 | interactive=True, 226 | label="Top P", 227 | ) 228 | 229 | temperature = gr.Slider( 230 | minimum=0.1, 231 | maximum=2.0, 232 | value=1.0, 233 | step=0.1, 234 | interactive=True, 235 | label="Temperature", 236 | ) 237 | do_sample = gr.inputs.Checkbox(label="do_sample") # default False 238 | with gr.Column(scale=2): 239 | chat_state = gr.State() 240 | img_list = gr.State() 241 | chatbot = gr.Chatbot(label='Visual-XVERSE') 242 | 243 | with gr.Row(): 244 | text_input.render() 245 | send = gr.Button("Send", variant='primary', size='sm', scale=1) 246 | 247 | upload_flag = gr.State(value=0) 248 | replace_flag = gr.State(value=0) 249 | image.upload(image_upload_trigger, [upload_flag, replace_flag, img_list], [upload_flag, replace_flag]) 250 | 251 | 252 | text_input.submit( 253 | gradio_ask, 254 | [text_input, chatbot, chat_state, image, img_list, upload_flag, replace_flag], 255 | [text_input, chatbot, chat_state, img_list, upload_flag, replace_flag], queue=False 256 | ).success( 257 | gradio_stream_answer, 258 | [chatbot, chat_state, img_list, top_p, temperature, do_sample], 259 | [chatbot, chat_state] 260 | ) 261 | 262 | send.click( 263 | gradio_ask, 264 | [text_input, chatbot, chat_state, image, img_list, upload_flag, replace_flag], 265 | [text_input, chatbot, chat_state, img_list, upload_flag, replace_flag], queue=False 266 | ).success( 267 | gradio_stream_answer, 268 | [chatbot, chat_state, img_list, top_p, temperature, do_sample], 269 | [chatbot, chat_state] 
270 | ) 271 | clear.click(gradio_reset, [chat_state, img_list], [chatbot, image, text_input, chat_state, img_list], queue=False) 272 | 273 | demo.queue(concurrency_count=4) 274 | demo.launch(share=False, server_name="0.0.0.0", server_port=args.server_port, enable_queue=True) 275 | # demo.launch(share=False, server_name="0.0.0.0", server_port=args.server_port, max_threads=4) 276 | -------------------------------------------------------------------------------- /eval_chat.py: -------------------------------------------------------------------------------- 1 | import os 2 | import re 3 | import json 4 | import argparse 5 | from collections import defaultdict 6 | 7 | import numpy as np 8 | from PIL import Image 9 | from tqdm import tqdm 10 | import torch 11 | from torch.utils.data import DataLoader 12 | # from datasets import load_dataset 13 | from transformers import StoppingCriteriaList 14 | from vxverse.conversation.conversation import Chat, CONV_VISION_Vicuna0, CONV_VISION_LLama2, CONV_VISION_XVERSE, StoppingCriteriaSub 15 | 16 | from vxverse.common.eval_utils import init_model, eval_parser 17 | from vxverse.conversation.conversation import CONV_VISION_Vicuna0, CONV_VISION_XVERSE, CONV_VISION_LLama2 18 | from vxverse.common.config import Config 19 | from vxverse.common.registry import registry 20 | 21 | 22 | 23 | conv_dict = {'pretrain_xverse13b-chat': CONV_VISION_XVERSE} 24 | 25 | def read_json(file): 26 | res = [] 27 | with open(file, 'r', encoding='utf-8') as f: 28 | for line in f.readlines(): 29 | res.append(json.loads(line.strip())) 30 | return res 31 | 32 | def list_of_str(arg): 33 | return list(map(str, arg.split(','))) 34 | 35 | def prepare_texts(texts, conv_temp): 36 | convs = [conv_temp.copy() for _ in range(len(texts))] 37 | [conv.append_message( 38 | conv.roles[0], ' {}'.format(text)) for conv, text in zip(convs, texts)] 39 | [conv.append_message(conv.roles[1], None) for conv in convs] 40 | texts = [conv.get_prompt() for conv in convs] 41 | return texts 
42 | 43 | 44 | parser = eval_parser() 45 | parser.add_argument("--dataset", type=list_of_str, default='gqa', help="dataset to evaluate") 46 | parser.add_argument("--gpu_id", type=int, default=0, help="dataset to evaluate") 47 | args = parser.parse_args() 48 | cfg = Config(args) 49 | 50 | model, vis_processor = init_model(args) 51 | 52 | model_config = cfg.model_cfg 53 | model_cls = registry.get_model_class(model_config.arch) 54 | CONV_VISION = conv_dict[model_config.model_type] 55 | print("model_config.model_type: {}".format(model_config.model_type)) 56 | 57 | 58 | conv_temp = CONV_VISION.copy() 59 | conv_temp.system = "" 60 | model.eval() 61 | save_path = cfg.run_cfg.save_path 62 | 63 | 64 | 65 | eval_file_path = cfg.evaluation_datasets_cfg["chat"]["eval_file_path"] 66 | img_path = None 67 | batch_size = cfg.evaluation_datasets_cfg["chat"]["batch_size"] 68 | max_new_tokens = cfg.evaluation_datasets_cfg["chat"]["max_new_tokens"] 69 | temperature = cfg.evaluation_datasets_cfg["chat"].get("temperature", 0.5) 70 | top_k = cfg.evaluation_datasets_cfg["chat"].get("top_k", 30) 71 | top_p = cfg.evaluation_datasets_cfg["chat"].get("top_p", 0.85) 72 | do_sample = True 73 | repetition_penalty = cfg.evaluation_datasets_cfg["chat"].get("repetition_penalty", 1.1) 74 | 75 | stop_words_ids = [[2]] 76 | stop_sign = "<|endoftext|>" 77 | print("stop_sign", stop_sign) 78 | stop_words_ids = [torch.tensor(ids).to(device='cuda:{}'.format(args.gpu_id)) for ids in stop_words_ids] 79 | stopping_criteria = StoppingCriteriaList([StoppingCriteriaSub(stops=stop_words_ids)]) 80 | chat = Chat(model, vis_processor, device='cuda:{}'.format(args.gpu_id), stopping_criteria=stopping_criteria) 81 | 82 | pure_text_datas = read_json(eval_file_path) 83 | 84 | count=0 85 | total=0 86 | print_prompt_flag = True 87 | img_list = None 88 | for i, sample in enumerate(tqdm(pure_text_datas)): 89 | 90 | conversations = sample["conversations"] 91 | conv = conv_temp.copy() 92 | 93 | 94 | for j in 
range(len(conversations)//2): 95 | 96 | conv.append_message(conv.roles[0], conversations[j*2]["value"]) 97 | 98 | if print_prompt_flag: 99 | print("########## Prompts ###########") 100 | print(conv) 101 | print_prompt_flag = False 102 | 103 | answer, tokens = chat.answer(conv=conv, 104 | stop_sign=stop_sign, 105 | img_list=img_list, 106 | top_p=top_p, 107 | top_k=top_k, 108 | temperature=temperature, 109 | max_new_tokens=max_new_tokens, 110 | max_length=8192, 111 | repetition_penalty=repetition_penalty, 112 | do_sample=do_sample,) 113 | print(f"Answer:\n{answer}") 114 | sample["conversations"][j*2+1]["value"] = str(answer) # turns alternate human/assistant: the reply to human turn j*2 belongs in slot j*2+1 115 | 116 | 117 | 118 | file_save_path = os.path.join(save_path, "open_test_data_lmm_predicts.json") 119 | with open(file_save_path,'w', encoding='utf-8') as f: 120 | for res in pure_text_datas: 121 | f.write(json.dumps(res, ensure_ascii=False)) 122 | f.write("\n") 123 | 124 | 125 | 126 | 127 | 128 | -------------------------------------------------------------------------------- /eval_configs/vxverse_hd_benchmark_evaluation.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | arch: vxverse 3 | model_type: pretrain_xverse13b-chat 4 | max_txt_len: 16 5 | 6 | vit_model: "openai/clip-vit-large-patch14-224" 7 | vit_path: "Your openai/clip-vit-large-patch14-224 path" 8 | 9 | end_sym: "<|endoftext|>" 10 | prompt_template: 'Human: {}\nAssistant: ' 11 | 12 | low_resource: True 13 | 14 | ckpt: "Your ckpt path" 15 | 16 | lora_r: 128 17 | lora_alpha: 256 18 | lora_dropout: 0.05 19 | lora_target_modules: "all_linear" 20 | 21 | has_qformer: False 22 | n_proj_layers: 2 23 | # num_query_token: 64 24 | 25 | freeze_llm: True 26 | freeze_vit: True 27 | 28 | 29 | 30 | datasets: 31 | gqa: 32 | vis_processor: 33 | train: 34 | name: "hd_image_train" 35 | image_size: 224 36 | text_processor: 37 | train: 38 | name: "base_text_process" 39 | 40 | evaluation_datasets: 41 | gqa: 42 | eval_file_path: "Your gqa test jsonl path"
43 | img_path: "Your gqa test images dir" 44 | max_new_tokens: 16 45 | batch_size: 5 46 | okvqa: 47 | eval_file_path: "Your gqa test dir" 48 | img_path: "Your gqa test images dir" 49 | max_new_tokens: 16 50 | batch_size: 5 51 | chat: 52 | eval_file_path: "Your pure-text data path" 53 | max_new_tokens: 2048 54 | batch_size: 1 55 | 56 | run: 57 | task: image_text_pretrain 58 | name: vxverse_evaluation 59 | save_path: 'Your dir to save results' 60 | 61 | 62 | 63 | 64 | 65 | 66 | 67 | 68 | 69 | 70 | -------------------------------------------------------------------------------- /eval_configs/vxverse_xverse_hd_eval.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | arch: vxverse 3 | model_type: pretrain_xverse13b-chat 4 | max_txt_len: 16 5 | 6 | vit_model: "openai/clip-vit-large-patch14-224" 7 | vit_path: "Your openai/clip-vit-large-patch14-224 path" 8 | 9 | end_sym: "<|endoftext|>" 10 | prompt_template: 'Human: {}\nAssistant: ' 11 | 12 | low_resource: True 13 | 14 | ckpt: "Your ckpt path" 15 | 16 | lora_r: 128 17 | lora_alpha: 256 18 | lora_dropout: 0.05 19 | lora_target_modules: "all_linear" 20 | 21 | has_qformer: False 22 | n_proj_layers: 2 23 | # num_query_token: 64 24 | 25 | freeze_llm: True 26 | freeze_vit: True 27 | 28 | 29 | 30 | datasets: 31 | gqa: 32 | vis_processor: 33 | train: 34 | name: "hd_image_train" 35 | image_size: 224 36 | text_processor: 37 | train: 38 | name: "base_text_process" 39 | 40 | run: 41 | task: image_text_pretrain 42 | 43 | -------------------------------------------------------------------------------- /eval_vqa.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | from tqdm import tqdm 4 | import torch 5 | from torch.utils.data import DataLoader 6 | from vxverse.datasets.datasets.vqa_datasets import OKVQAEvalData, GQAEvalData 7 | from vxverse.common.vqa_tools.VQA.PythonHelperTools.vqaTools.vqa import VQA 8 | from 
vxverse.common.vqa_tools.VQA.PythonEvaluationTools.vqaEvaluation.vqaEval import VQAEval 9 | 10 | from vxverse.common.eval_utils import prepare_texts, init_model, eval_parser 11 | from vxverse.conversation.conversation import CONV_VISION_XVERSE 12 | from vxverse.common.config import Config 13 | from vxverse.common.registry import registry 14 | 15 | 16 | conv_dict = {'pretrain_xverse13b-chat': CONV_VISION_XVERSE} 17 | 18 | 19 | stop_words_ids = [[2]] 20 | do_sample = False 21 | 22 | 23 | def collater4hd(samples): 24 | # image, question, question_id, img_id 25 | batch_images, batch_questions, batch_question_ids, batch_img_ids = [], [], [], [] 26 | batch_patches_per_image, batch_total_images, batch_labels = [], [], [] 27 | for sample in samples: 28 | if not isinstance(sample["image"], list): 29 | sample["image"] = [sample["image"]] 30 | patches_per_image = [] 31 | for img in sample["image"]: 32 | patches_per_image.append(img.shape[0]) 33 | batch_patches_per_image.append(patches_per_image) 34 | batch_total_images.append(len(sample["image"])) 35 | 36 | for sample in samples: 37 | batch_images.append(torch.cat(sample["image"], dim=0)) 38 | batch_questions.append(sample["question"]) 39 | batch_question_ids.append(sample.get("question_id", 0)) 40 | batch_img_ids.append(sample.get("img_id", 0)) 41 | batch_labels.append(sample.get("label", None)) 42 | return { 43 | "image": batch_images, 44 | "question": batch_questions, 45 | "question_id": batch_question_ids, 46 | "img_id": batch_img_ids, 47 | "label":batch_labels, 48 | "patches_per_image": batch_patches_per_image, 49 | "total_images": batch_total_images 50 | } 51 | 52 | 53 | def list_of_str(arg): 54 | return list(map(str, arg.split(','))) 55 | 56 | 57 | parser = eval_parser() 58 | parser.add_argument("--dataset", type=list_of_str, default='gqa', help="dataset to evaluate") 59 | args = parser.parse_args() 60 | cfg = Config(args) 61 | 62 | model, vis_processor = init_model(args) 63 | 64 | model_config = cfg.model_cfg 65 | 
model_cls = registry.get_model_class(model_config.arch) 66 | CONV_VISION = conv_dict[model_config.model_type] 67 | proce_type = list(cfg.datasets_cfg.keys())[0] 68 | vis_proce_type = cfg.datasets_cfg.get(proce_type).vis_processor.train.name 69 | print("model_config.model_type: {}".format(model_config.model_type)) 70 | print("vision process type: {}".format(vis_proce_type)) 71 | 72 | conv_temp = CONV_VISION.copy() 73 | conv_temp.system = "" 74 | model.eval() 75 | save_path = cfg.run_cfg.save_path 76 | 77 | 78 | 79 | if 'okvqa' in args.dataset: 80 | 81 | print("################## OKVQA EVAL ###############") 82 | eval_file_path = cfg.evaluation_datasets_cfg["okvqa"]["eval_file_path"] 83 | img_path = cfg.evaluation_datasets_cfg["okvqa"]["img_path"] 84 | batch_size = cfg.evaluation_datasets_cfg["okvqa"]["batch_size"] 85 | max_new_tokens = cfg.evaluation_datasets_cfg["okvqa"]["max_new_tokens"] 86 | 87 | print_res_flag = True 88 | evaluation_annntation_path = os.path.join(eval_file_path, "okvqa_test_split.json") 89 | with open(evaluation_annntation_path) as f: 90 | ok_vqa_test_split = json.load(f) 91 | 92 | 93 | data = OKVQAEvalData(ok_vqa_test_split, vis_processor, img_path) 94 | if vis_proce_type == "hd_image_train": 95 | eval_dataloader = DataLoader(data, batch_size=batch_size, shuffle=False, collate_fn=collater4hd) 96 | else: 97 | eval_dataloader = DataLoader(data, batch_size=batch_size, shuffle=False) 98 | vxverse_predict = [] 99 | 100 | for i, samples in enumerate(tqdm(eval_dataloader)): 101 | images, questions, question_ids, img_ids = samples["image"], samples["question"], samples["question_id"], samples["img_id"] 102 | patches_per_images = samples.get("patches_per_image", None) 103 | total_images = samples.get("total_images", None) 104 | texts = prepare_texts(questions, conv_temp) # warp the texts with conversation template 105 | texts = [text.lstrip() for text in texts] 106 | 107 | answers = model.generate(images, texts, patches_per_images=patches_per_images, 
max_new_tokens=max_new_tokens, do_sample=False, stop_words_ids=stop_words_ids) 108 | 109 | for answer, question_id, question, img_id, text in zip(answers, question_ids, questions, img_ids, texts): 110 | result = dict() 111 | 112 | answer = answer.lower() 113 | 114 | result['answer'] = answer 115 | result['question_id'] = int(question_id) 116 | result["Prompt"] = text 117 | vxverse_predict.append(result) 118 | if i % 10 == 0: 119 | print(vxverse_predict[i]) 120 | 121 | 122 | file_save_path= os.path.join(save_path,"okvqa.json") 123 | with open(file_save_path,'w', encoding='utf-8') as f: 124 | for res in vxverse_predict: 125 | f.write(json.dumps(res, ensure_ascii=False)) 126 | f.write("\n") 127 | 128 | annFile = os.path.join(eval_file_path,"mscoco_val2014_annotations_clean.json") 129 | quesFile = os.path.join(eval_file_path,"OpenEnded_mscoco_val2014_questions_clean.json" ) 130 | 131 | vqa = VQA(annFile, quesFile) 132 | vqaRes = vqa.loadRes(file_save_path, quesFile) 133 | 134 | vqaEval = VQAEval(vqa, vqaRes, n=2) 135 | vqaEval.evaluate() 136 | print ("Overall OKVQA Accuracy is: %.02f\n" %(vqaEval.accuracy['overall']), flush=True) 137 | 138 | 139 | 140 | if 'gqa' in args.dataset: 141 | 142 | eval_file_path = cfg.evaluation_datasets_cfg["gqa"]["eval_file_path"] 143 | img_path = cfg.evaluation_datasets_cfg["gqa"]["img_path"] 144 | batch_size = cfg.evaluation_datasets_cfg["gqa"]["batch_size"] 145 | max_new_tokens = cfg.evaluation_datasets_cfg["gqa"]["max_new_tokens"] 146 | gqa = json.load(open(eval_file_path)) 147 | data = GQAEvalData(gqa, vis_processor, img_path) 148 | if vis_proce_type == "hd_image_train": 149 | eval_dataloader = DataLoader(data, batch_size=batch_size, shuffle=False, collate_fn=collater4hd) 150 | else: 151 | eval_dataloader = DataLoader(data, batch_size=batch_size, shuffle=False) 152 | count=0 153 | total=0 154 | print_prompt_flag = True 155 | vxverse_predict = [] 156 | for i, samples in enumerate(tqdm(eval_dataloader)): 157 | images, texts, labels = 
samples["image"], samples["question"], samples["label"] 158 | patches_per_images = samples.get("patches_per_image", None) 159 | total_images = samples.get("total_images", None) 160 | texts = prepare_texts(texts, conv_temp) # warp the texts with conversation template 161 | texts = [text.lstrip() for text in texts] 162 | 163 | if print_prompt_flag: 164 | print("########## Prompts ###########") 165 | print(texts) 166 | print_prompt_flag = False 167 | answers = model.generate(images, texts, patches_per_images=patches_per_images, max_new_tokens=max_new_tokens, do_sample=do_sample, stop_words_ids=stop_words_ids) 168 | 169 | for answer, label, text in zip(answers, labels, texts): 170 | result = dict() 171 | answer = answer.lower() 172 | result['pred'] = answer 173 | result['gt'] = label 174 | result['Prompt'] = text 175 | vxverse_predict.append(result) 176 | if label in answer.lower(): 177 | count+=1 178 | total+=1 179 | if i % 20 == 0: 180 | print(vxverse_predict[i]) 181 | 182 | print("acc count:", count) 183 | 184 | print('gqa val:', count / total * 100, flush=True) 185 | 186 | file_save_path = os.path.join(save_path, "gqa.json") 187 | with open(file_save_path,'w', encoding='utf-8') as f: 188 | 189 | for res in vxverse_predict: 190 | f.write(json.dumps(res, ensure_ascii=False)) 191 | f.write("\n") 192 | 193 | 194 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch==2.0.0 2 | torchaudio 3 | torchvision 4 | huggingface-hub 5 | psutil==5.9.4 6 | iopath 7 | fifty 8 | pyyaml==6.0 9 | regex==2022.10.31 10 | tokenizers==0.13.3 11 | huggingface-hub==0.18.0 12 | matplotlib==3.7.0 13 | tqdm==4.64.1 14 | transformers==4.30.0 15 | timm==0.6.13 16 | webdataset==0.2.48 17 | omegaconf==2.3.0 18 | peft==0.2.0 19 | opencv-python==4.7.0.72 20 | gradio==3.47.1 21 | decord==0.6.0 22 | sentence-transformers 23 | accelerate==0.20.3 24 | scikit-image 25 
| visual-genome 26 | wandb 27 | ninja 28 | xformers==0.0.20 29 | 30 | 31 | -------------------------------------------------------------------------------- /resources/modelscope.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xverse-ai/XVERSE-V-13B/989da861ac64dfb730da0a872f5057372cdad4f5/resources/modelscope.png -------------------------------------------------------------------------------- /resources/wechat.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xverse-ai/XVERSE-V-13B/989da861ac64dfb730da0a872f5057372cdad4f5/resources/wechat.png -------------------------------------------------------------------------------- /vxverse/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) 2022, salesforce.com, inc. 3 | All rights reserved. 4 | SPDX-License-Identifier: BSD-3-Clause 5 | For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause 6 | """ 7 | 8 | import os 9 | import sys 10 | 11 | from omegaconf import OmegaConf 12 | 13 | from vxverse.common.registry import registry 14 | 15 | from vxverse.datasets.builders import * 16 | from vxverse.models import * 17 | from vxverse.processors import * 18 | # from vxverse.tasks import * 19 | 20 | 21 | root_dir = os.path.dirname(os.path.abspath(__file__)) 22 | default_cfg = OmegaConf.load(os.path.join(root_dir, "configs/default.yaml")) 23 | 24 | registry.register_path("library_root", root_dir) 25 | repo_root = os.path.join(root_dir, "..") 26 | registry.register_path("repo_root", repo_root) 27 | cache_root = os.path.join(repo_root, default_cfg.env.cache_root) 28 | registry.register_path("cache_root", cache_root) 29 | 30 | registry.register("MAX_INT", sys.maxsize) 31 | registry.register("SPLIT_NAMES", ["train", "val", "test"]) 32 | 
-------------------------------------------------------------------------------- /vxverse/common/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xverse-ai/XVERSE-V-13B/989da861ac64dfb730da0a872f5057372cdad4f5/vxverse/common/__init__.py -------------------------------------------------------------------------------- /vxverse/common/dist_utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) 2022, salesforce.com, inc. 3 | All rights reserved. 4 | SPDX-License-Identifier: BSD-3-Clause 5 | For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause 6 | """ 7 | 8 | import datetime 9 | import functools 10 | import os 11 | import logging 12 | import torch 13 | import torch.distributed as dist 14 | import timm.models.hub as timm_hub 15 | 16 | 17 | def setup_for_distributed(is_master): 18 | """ 19 | This function disables printing when not in master process 20 | """ 21 | import builtins as __builtin__ 22 | 23 | builtin_print = __builtin__.print 24 | 25 | def print(*args, **kwargs): 26 | force = kwargs.pop("force", False) 27 | if is_master or force: 28 | builtin_print(*args, **kwargs) 29 | 30 | __builtin__.print = print 31 | 32 | 33 | def is_dist_avail_and_initialized(): 34 | if not dist.is_available(): 35 | return False 36 | if not dist.is_initialized(): 37 | return False 38 | return True 39 | 40 | 41 | def get_world_size(): 42 | if not is_dist_avail_and_initialized(): 43 | return 1 44 | return dist.get_world_size() 45 | 46 | 47 | def get_rank(): 48 | if not is_dist_avail_and_initialized(): 49 | return 0 50 | return dist.get_rank() 51 | 52 | 53 | def is_main_process(): 54 | return get_rank() == 0 55 | 56 | 57 | def init_distributed_mode(args): 58 | if args.distributed is False: 59 | print("Not using distributed mode") 60 | logging.info("Not using distributed mode") 61 | return 62 
| elif "RANK" in os.environ and "WORLD_SIZE" in os.environ: 63 | print("Using distributed mode") 64 | args.rank = int(os.environ["RANK"]) 65 | args.world_size = int(os.environ["WORLD_SIZE"]) 66 | args.gpu = int(os.environ["LOCAL_RANK"]) 67 | elif "SLURM_PROCID" in os.environ: 68 | print("Using SLURM_PROCID distributed mode") 69 | args.rank = int(os.environ["SLURM_PROCID"]) 70 | args.gpu = args.rank % torch.cuda.device_count() 71 | else: 72 | print("Not using distributed mode") 73 | logging.info("Not using distributed mode") 74 | args.distributed = False 75 | return 76 | 77 | args.distributed = True 78 | 79 | torch.cuda.set_device(args.gpu) 80 | args.dist_backend = "nccl" 81 | print( 82 | "| distributed init (rank {}, local_rank {} world_size {}) —— init_method: {}".format( 83 | args.rank, args.gpu, args.world_size, args.dist_url 84 | ), 85 | flush=True, 86 | ) 87 | torch.distributed.init_process_group( 88 | backend=args.dist_backend, 89 | init_method=args.dist_url, 90 | world_size=args.world_size, 91 | rank=args.rank, 92 | timeout=datetime.timedelta( 93 | days=365), # allow auto-downloading and de-compressing 94 | ) 95 | ## May cause [E ProcessGroupNCCL.cpp:455] Some NCCL operations have failed 96 | torch.distributed.barrier() 97 | setup_for_distributed(args.rank == 0) 98 | print("Distributed deployment is initialized...") 99 | 100 | 101 | def get_dist_info(): 102 | if torch.__version__ < "1.0": 103 | initialized = dist._initialized 104 | else: 105 | initialized = dist.is_initialized() 106 | if initialized: 107 | rank = dist.get_rank() 108 | world_size = dist.get_world_size() 109 | else: # non-distributed training 110 | rank = 0 111 | world_size = 1 112 | return rank, world_size 113 | 114 | 115 | def main_process(func): 116 | @functools.wraps(func) 117 | def wrapper(*args, **kwargs): 118 | rank, _ = get_dist_info() 119 | if rank == 0: 120 | return func(*args, **kwargs) 121 | 122 | return wrapper 123 | 124 | 125 | def download_cached_file(url, check_hash=True, 
progress=False): 126 | """ 127 | Download a file from a URL and cache it locally. If the file already exists, it is not downloaded again. 128 | If distributed, only the main process downloads the file, and the other processes wait for the file to be downloaded. 129 | """ 130 | 131 | def get_cached_file_path(): 132 | # a hack to sync the file path across processes 133 | parts = torch.hub.urlparse(url) 134 | filename = os.path.basename(parts.path) 135 | cached_file = os.path.join(timm_hub.get_cache_dir(), filename) 136 | 137 | return cached_file 138 | 139 | if is_main_process(): 140 | timm_hub.download_cached_file(url, check_hash, progress) 141 | 142 | # May cause [E ProcessGroupNCCL.cpp:455] Some NCCL operations have failed 143 | if is_dist_avail_and_initialized(): 144 | dist.barrier() 145 | 146 | return get_cached_file_path() 147 | -------------------------------------------------------------------------------- /vxverse/common/eval_utils.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import numpy as np 3 | from nltk.translate.bleu_score import sentence_bleu 4 | 5 | from vxverse.common.registry import registry 6 | from vxverse.common.config import Config 7 | 8 | # imports modules for registration 9 | # from vxverse.datasets.builders import * 10 | # from vxverse.models import * 11 | # from vxverse.processors import * 12 | # from vxverse.runners import * 13 | # from vxverse.tasks import * 14 | 15 | 16 | 17 | def eval_parser(): 18 | parser = argparse.ArgumentParser(description="Demo") 19 | parser.add_argument("--cfg-path", required=True, help="path to configuration file.") 20 | parser.add_argument("--name", type=str, default='A2', help="evaluation name") 21 | parser.add_argument("--ckpt", type=str, help="path to configuration file.") 22 | parser.add_argument("--eval_opt", type=str, default='all', help="path to configuration file.") 23 | parser.add_argument("--max_new_tokens", type=int, default=10, help="max 
number of generated tokens") 24 | parser.add_argument("--batch_size", type=int, default=32) 25 | parser.add_argument("--lora_r", type=int, default=64, help="lora rank of the model") 26 | parser.add_argument("--lora_alpha", type=int, default=16, help="lora alpha") 27 | parser.add_argument( 28 | "--options", 29 | nargs="+", 30 | help="override some settings in the used config, the key-value pair " 31 | "in xxx=yyy format will be merged into config file (deprecate), " 32 | "change to --cfg-options instead.", 33 | ) 34 | return parser 35 | 36 | 37 | def prepare_texts(texts, conv_temp): 38 | convs = [conv_temp.copy() for _ in range(len(texts))] 39 | [conv.append_message( 40 | conv.roles[0], '\n{}'.format(text)) for conv, text in zip(convs, texts)] 41 | [conv.append_message(conv.roles[1], None) for conv in convs] 42 | texts = [conv.get_prompt() for conv in convs] 43 | return texts 44 | 45 | 46 | def init_model(args): 47 | print('Initialization Model') 48 | cfg = Config(args) 49 | # cfg.model_cfg.ckpt = args.ckpt 50 | # cfg.model_cfg.lora_r = args.lora_r 51 | # cfg.model_cfg.lora_alpha = args.lora_alpha 52 | 53 | 54 | model_config = cfg.model_cfg 55 | model_cls = registry.get_model_class(model_config.arch) 56 | print("############# Model Info #################") 57 | print(model_config) 58 | print(model_cls) 59 | print("#########################################") 60 | model = model_cls.from_config(model_config).to('cuda:{}'.format(getattr(args, 'gpu_id', 0))) # follow an optional --gpu_id (eval_chat.py builds Chat on that device); parsers without it keep cuda:0 61 | 62 | # import pudb; pudb.set_trace() 63 | key = list(cfg.datasets_cfg.keys())[0] 64 | vis_processor_cfg = cfg.datasets_cfg.get(key).vis_processor.train 65 | vis_processor = registry.get_processor_class(vis_processor_cfg.name).from_config(vis_processor_cfg) 66 | print('Initialization Finished') 67 | return model, vis_processor 68 | 69 | def computeIoU(bbox1, bbox2): 70 | x1, y1, x2, y2 = bbox1 71 | x3, y3, x4, y4 = bbox2 72 | intersection_x1 = max(x1, x3) 73 | intersection_y1 = max(y1, y3) 74 | intersection_x2 = min(x2, x4) 75 | intersection_y2 =
min(y2, y4) 76 | intersection_area = max(0, intersection_x2 - intersection_x1 + 1) * max(0, intersection_y2 - intersection_y1 + 1) 77 | bbox1_area = (x2 - x1 + 1) * (y2 - y1 + 1) 78 | bbox2_area = (x4 - x3 + 1) * (y4 - y3 + 1) 79 | union_area = bbox1_area + bbox2_area - intersection_area 80 | iou = intersection_area / union_area 81 | return iou 82 | -------------------------------------------------------------------------------- /vxverse/common/gradcam.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from matplotlib import pyplot as plt 3 | from scipy.ndimage import filters 4 | from skimage import transform as skimage_transform 5 | 6 | 7 | def getAttMap(img, attMap, blur=True, overlap=True): 8 | attMap -= attMap.min() 9 | if attMap.max() > 0: 10 | attMap /= attMap.max() 11 | attMap = skimage_transform.resize(attMap, (img.shape[:2]), order=3, mode="constant") 12 | if blur: 13 | attMap = filters.gaussian_filter(attMap, 0.02 * max(img.shape[:2])) 14 | attMap -= attMap.min() 15 | attMap /= attMap.max() 16 | cmap = plt.get_cmap("jet") 17 | attMapV = cmap(attMap) 18 | attMapV = np.delete(attMapV, 3, 2) 19 | if overlap: 20 | attMap = ( 21 | 1 * (1 - attMap**0.7).reshape(attMap.shape + (1,)) * img 22 | + (attMap**0.7).reshape(attMap.shape + (1,)) * attMapV 23 | ) 24 | return attMap 25 | -------------------------------------------------------------------------------- /vxverse/common/logger.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) 2022, salesforce.com, inc. 3 | All rights reserved. 
4 | SPDX-License-Identifier: BSD-3-Clause 5 | For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause 6 | """ 7 | 8 | import datetime 9 | import logging 10 | import time 11 | from collections import defaultdict, deque 12 | 13 | import torch 14 | import torch.distributed as dist 15 | 16 | from vxverse.common import dist_utils 17 | 18 | 19 | class SmoothedValue(object): 20 | """Track a series of values and provide access to smoothed values over a 21 | window or the global series average. 22 | """ 23 | 24 | def __init__(self, window_size=20, fmt=None): 25 | if fmt is None: 26 | fmt = "{median:.4f} ({global_avg:.4f})" 27 | self.deque = deque(maxlen=window_size) 28 | self.total = 0.0 29 | self.count = 0 30 | self.fmt = fmt 31 | 32 | def update(self, value, n=1): 33 | self.deque.append(value) 34 | self.count += n 35 | self.total += value * n 36 | 37 | def synchronize_between_processes(self): 38 | """ 39 | Warning: does not synchronize the deque! 
40 | """ 41 | if not dist_utils.is_dist_avail_and_initialized(): 42 | return 43 | t = torch.tensor([self.count, self.total], dtype=torch.float64, device="cuda") 44 | # May cause [E ProcessGroupNCCL.cpp:455] Some NCCL operations have failed 45 | dist.barrier() 46 | dist.all_reduce(t) 47 | t = t.tolist() 48 | self.count = int(t[0]) 49 | self.total = t[1] 50 | 51 | @property 52 | def median(self): 53 | d = torch.tensor(list(self.deque)) 54 | return d.median().item() 55 | 56 | @property 57 | def avg(self): 58 | d = torch.tensor(list(self.deque), dtype=torch.float32) 59 | return d.mean().item() 60 | 61 | @property 62 | def global_avg(self): 63 | return self.total / self.count 64 | 65 | @property 66 | def max(self): 67 | return max(self.deque) 68 | 69 | @property 70 | def value(self): 71 | return self.deque[-1] 72 | 73 | def __str__(self): 74 | return self.fmt.format( 75 | median=self.median, 76 | avg=self.avg, 77 | global_avg=self.global_avg, 78 | max=self.max, 79 | value=self.value, 80 | ) 81 | 82 | 83 | class MetricLogger(object): 84 | def __init__(self, delimiter="\t"): 85 | self.meters = defaultdict(SmoothedValue) 86 | self.delimiter = delimiter 87 | 88 | def update(self, **kwargs): 89 | for k, v in kwargs.items(): 90 | if isinstance(v, torch.Tensor): 91 | v = v.item() 92 | assert isinstance(v, (float, int)) 93 | self.meters[k].update(v) 94 | 95 | def __getattr__(self, attr): 96 | if attr in self.meters: 97 | return self.meters[attr] 98 | if attr in self.__dict__: 99 | return self.__dict__[attr] 100 | raise AttributeError( 101 | "'{}' object has no attribute '{}'".format(type(self).__name__, attr) 102 | ) 103 | 104 | def __str__(self): 105 | loss_str = [] 106 | for name, meter in self.meters.items(): 107 | loss_str.append("{}: {}".format(name, str(meter))) 108 | return self.delimiter.join(loss_str) 109 | 110 | def global_avg(self): 111 | loss_str = [] 112 | for name, meter in self.meters.items(): 113 | loss_str.append("{}: {:.6f}".format(name, meter.global_avg)) 114 
| return self.delimiter.join(loss_str) 115 | 116 | def synchronize_between_processes(self): 117 | for meter in self.meters.values(): 118 | meter.synchronize_between_processes() 119 | 120 | def add_meter(self, name, meter): 121 | self.meters[name] = meter 122 | 123 | def log_every(self, iterable, print_freq, header=None): 124 | i = 0 125 | if not header: 126 | header = "" 127 | start_time = time.time() 128 | end = time.time() 129 | iter_time = SmoothedValue(fmt="{avg:.4f}") 130 | data_time = SmoothedValue(fmt="{avg:.4f}") 131 | space_fmt = ":" + str(len(str(len(iterable)))) + "d" 132 | log_msg = [ 133 | "[{log_date}]", 134 | "iteration: {0" + space_fmt + "}/{1}", 135 | header, 136 | "eta: {eta}", 137 | "{meters}", 138 | "time: {time}", 139 | "data: {data}", 140 | ] 141 | if torch.cuda.is_available(): 142 | log_msg.append("max mem: {memory:.0f}") 143 | log_msg = self.delimiter.join(log_msg) 144 | MB = 1024.0 * 1024.0 145 | for obj in iterable: 146 | data_time.update(time.time() - end) 147 | yield obj 148 | iter_time.update(time.time() - end) 149 | if i % print_freq == 0 or i == len(iterable) - 1: 150 | eta_seconds = iter_time.global_avg * (len(iterable) - i) 151 | eta_string = str(datetime.timedelta(seconds=int(eta_seconds))) 152 | now = datetime.datetime.now() 153 | log_date = now.strftime("%Y-%m-%d %H:%M:%S") 154 | if torch.cuda.is_available(): 155 | print( 156 | log_msg.format( 157 | i, 158 | len(iterable), 159 | log_date=log_date, 160 | eta=eta_string, 161 | meters=str(self), 162 | time=str(iter_time), 163 | data=str(data_time), 164 | memory=torch.cuda.max_memory_allocated() / MB, 165 | ) 166 | ) 167 | else: 168 | print( 169 | log_msg.format( 170 | i, 171 | len(iterable), 172 | log_date=log_date, 173 | eta=eta_string, 174 | meters=str(self), 175 | time=str(iter_time), 176 | data=str(data_time), 177 | ) 178 | ) 179 | i += 1 180 | end = time.time() 181 | total_time = time.time() - start_time 182 | total_time_str = str(datetime.timedelta(seconds=int(total_time))) 
183 | print( 184 | "{} Total time: {} ({:.4f} s / it)".format( 185 | header, total_time_str, total_time / len(iterable) 186 | ) 187 | ) 188 | 189 | 190 | class AttrDict(dict): 191 | def __init__(self, *args, **kwargs): 192 | super(AttrDict, self).__init__(*args, **kwargs) 193 | self.__dict__ = self 194 | 195 | 196 | def setup_logger(): 197 | logging.basicConfig( 198 | level=logging.INFO if dist_utils.is_main_process() else logging.WARN, 199 | format="%(asctime)s [%(levelname)s] %(message)s", 200 | handlers=[logging.StreamHandler()], 201 | ) 202 | logging.info("Logger is setted!") 203 | -------------------------------------------------------------------------------- /vxverse/common/optims.py: -------------------------------------------------------------------------------- 1 | 2 | import math 3 | 4 | from vxverse.common.registry import registry 5 | 6 | 7 | @registry.register_lr_scheduler("linear_warmup_step_lr") 8 | class LinearWarmupStepLRScheduler: 9 | def __init__( 10 | self, 11 | optimizer, 12 | max_epoch, 13 | min_lr, 14 | init_lr, 15 | decay_rate=1, 16 | warmup_start_lr=-1, 17 | warmup_steps=0, 18 | **kwargs 19 | ): 20 | self.optimizer = optimizer 21 | 22 | self.max_epoch = max_epoch 23 | self.min_lr = min_lr 24 | 25 | self.decay_rate = decay_rate 26 | 27 | self.init_lr = init_lr 28 | self.warmup_steps = warmup_steps 29 | self.warmup_start_lr = warmup_start_lr if warmup_start_lr >= 0 else init_lr 30 | 31 | def step(self, cur_epoch, cur_step): 32 | if cur_epoch == 0: 33 | warmup_lr_schedule( 34 | step=cur_step, 35 | optimizer=self.optimizer, 36 | max_step=self.warmup_steps, 37 | init_lr=self.warmup_start_lr, 38 | max_lr=self.init_lr, 39 | ) 40 | else: 41 | step_lr_schedule( 42 | epoch=cur_epoch, 43 | optimizer=self.optimizer, 44 | init_lr=self.init_lr, 45 | min_lr=self.min_lr, 46 | decay_rate=self.decay_rate, 47 | ) 48 | 49 | 50 | @registry.register_lr_scheduler("linear_warmup_cosine_lr") 51 | class LinearWarmupCosineLRScheduler: 52 | def __init__( 53 | self, 54 
| optimizer, 55 | max_epoch, 56 | iters_per_epoch, 57 | min_lr, 58 | init_lr, 59 | warmup_steps=0, 60 | warmup_start_lr=-1, 61 | **kwargs 62 | ): 63 | self.optimizer = optimizer 64 | 65 | self.max_epoch = max_epoch 66 | self.iters_per_epoch = iters_per_epoch 67 | self.min_lr = min_lr 68 | 69 | self.init_lr = init_lr 70 | self.warmup_steps = warmup_steps 71 | self.warmup_start_lr = warmup_start_lr if warmup_start_lr >= 0 else init_lr 72 | 73 | def step(self, cur_epoch, cur_step): 74 | total_cur_step = cur_epoch * self.iters_per_epoch + cur_step 75 | if total_cur_step < self.warmup_steps: 76 | warmup_lr_schedule( 77 | step=cur_step, 78 | optimizer=self.optimizer, 79 | max_step=self.warmup_steps, 80 | init_lr=self.warmup_start_lr, 81 | max_lr=self.init_lr, 82 | ) 83 | else: 84 | cosine_lr_schedule( 85 | epoch=total_cur_step, 86 | optimizer=self.optimizer, 87 | max_epoch=self.max_epoch * self.iters_per_epoch, 88 | init_lr=self.init_lr, 89 | min_lr=self.min_lr, 90 | ) 91 | 92 | 93 | @registry.register_lr_scheduler("linear_warmup_cosine_diff_lr") 94 | class LinearWarmupCosineDiffLRScheduler: 95 | def __init__( 96 | self, 97 | optimizer, 98 | max_epoch, 99 | iters_per_epoch, 100 | min_lr, 101 | init_lr, 102 | adapter_lr=None, 103 | warmup_steps=0, 104 | warmup_start_lr=-1, 105 | **kwargs 106 | ): 107 | self.optimizer = optimizer 108 | 109 | self.max_epoch = max_epoch 110 | self.iters_per_epoch = iters_per_epoch 111 | self.min_lr = min_lr 112 | 113 | self.init_lr = init_lr 114 | self.warmup_steps = warmup_steps 115 | self.warmup_start_lr = warmup_start_lr if warmup_start_lr >= 0 else init_lr 116 | if adapter_lr==None: 117 | self.adapter_lr = self.init_lr 118 | else: 119 | self.adapter_lr = adapter_lr 120 | 121 | def step(self, cur_epoch, cur_step): 122 | total_cur_step = cur_epoch * self.iters_per_epoch + cur_step 123 | if total_cur_step < self.warmup_steps: 124 | warmup_diff_lr_schedule( 125 | step=cur_step, 126 | optimizer=self.optimizer, 127 | 
max_step=self.warmup_steps, 128 | init_lr=self.warmup_start_lr, 129 | max_lr=self.init_lr, 130 | adapter_lr=self.adapter_lr, 131 | ) 132 | else: 133 | cosine_diff_lr_schedule( 134 | epoch=total_cur_step, 135 | optimizer=self.optimizer, 136 | max_epoch=self.max_epoch * self.iters_per_epoch, 137 | init_lr=self.init_lr, 138 | min_lr=self.min_lr, 139 | adapter_lr=self.adapter_lr 140 | ) 141 | 142 | 143 | 144 | def cosine_lr_schedule(optimizer, epoch, max_epoch, init_lr, min_lr): 145 | """Decay the learning rate""" 146 | lr = (init_lr - min_lr) * 0.5 * ( 147 | 1.0 + math.cos(math.pi * epoch / max_epoch) 148 | ) + min_lr 149 | for param_group in optimizer.param_groups: 150 | param_group["lr"] = lr 151 | 152 | 153 | def warmup_lr_schedule(optimizer, step, max_step, init_lr, max_lr): 154 | """Warmup the learning rate""" 155 | lr = min(max_lr, init_lr + (max_lr - init_lr) * step / max(max_step, 1)) 156 | for param_group in optimizer.param_groups: 157 | param_group["lr"] = lr 158 | 159 | 160 | def step_lr_schedule(optimizer, epoch, init_lr, min_lr, decay_rate): 161 | """Decay the learning rate""" 162 | lr = max(min_lr, init_lr * (decay_rate**epoch)) 163 | for param_group in optimizer.param_groups: 164 | param_group["lr"] = lr 165 | 166 | 167 | def cosine_diff_lr_schedule(optimizer, epoch, max_epoch, init_lr, min_lr, adapter_lr): 168 | """Decay the learning rate""" 169 | lr = (init_lr - min_lr) * 0.5 * ( 170 | 1.0 + math.cos(math.pi * epoch / max_epoch) 171 | ) + min_lr 172 | a_lr = (adapter_lr - min_lr) * 0.5 * ( 173 | 1.0 + math.cos(math.pi * epoch / max_epoch) 174 | ) + min_lr 175 | 176 | for param_group in optimizer.param_groups: 177 | if "adapt" in param_group["name"]: 178 | param_group["lr"] = a_lr 179 | else: 180 | param_group["lr"] = lr 181 | 182 | 183 | def warmup_diff_lr_schedule(optimizer, step, max_step, init_lr, max_lr, adapter_lr): 184 | """Warmup the learning rate""" 185 | lr = min(max_lr, init_lr + (max_lr - init_lr) * step / max(max_step, 1)) 186 | a_lr = 
min(adapter_lr, init_lr + (adapter_lr - init_lr) * step / max(max_step, 1)) 187 | for param_group in optimizer.param_groups: 188 | if "adapt" in param_group["name"]: 189 | param_group["lr"] = a_lr 190 | else: 191 | param_group["lr"] = lr 192 | 193 | 194 | def step_diff_lr_schedule(optimizer, epoch, init_lr, min_lr, decay_rate, adapter_lr): 195 | """Decay the learning rate""" 196 | lr = max(min_lr, init_lr * (decay_rate**epoch)) 197 | a_lr = max(min_lr, adapter_lr * (decay_rate ** epoch)) 198 | for param_group in optimizer.param_groups: 199 | if "adapt" in param_group["name"]: 200 | param_group["lr"] = a_lr 201 | else: 202 | param_group["lr"] = lr -------------------------------------------------------------------------------- /vxverse/common/registry.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) 2022, salesforce.com, inc. 3 | All rights reserved. 4 | SPDX-License-Identifier: BSD-3-Clause 5 | For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause 6 | """ 7 | 8 | 9 | class Registry: 10 | mapping = { 11 | "builder_name_mapping": {}, 12 | "task_name_mapping": {}, 13 | "processor_name_mapping": {}, 14 | "model_name_mapping": {}, 15 | "lr_scheduler_name_mapping": {}, 16 | "runner_name_mapping": {}, 17 | "state": {}, 18 | "paths": {}, 19 | } 20 | 21 | @classmethod 22 | def register_builder(cls, name): 23 | r"""Register a dataset builder to registry with key 'name' 24 | 25 | Args: 26 | name: Key with which the builder will be registered. 
27 | 28 | Usage: 29 | 30 | from vxverse.common.registry import registry 31 | from vxverse.datasets.base_dataset_builder import BaseDatasetBuilder 32 | """ 33 | 34 | def wrap(builder_cls): 35 | from vxverse.datasets.builders.base_dataset_builder import BaseDatasetBuilder 36 | 37 | assert issubclass( 38 | builder_cls, BaseDatasetBuilder 39 | ), "All builders must inherit BaseDatasetBuilder class, found {}".format( 40 | builder_cls 41 | ) 42 | if name in cls.mapping["builder_name_mapping"]: 43 | raise KeyError( 44 | "Name '{}' already registered for {}.".format( 45 | name, cls.mapping["builder_name_mapping"][name] 46 | ) 47 | ) 48 | cls.mapping["builder_name_mapping"][name] = builder_cls 49 | return builder_cls 50 | 51 | return wrap 52 | 53 | @classmethod 54 | def register_task(cls, name): 55 | r"""Register a task to registry with key 'name' 56 | 57 | Args: 58 | name: Key with which the task will be registered. 59 | 60 | Usage: 61 | 62 | from vxverse.common.registry import registry 63 | """ 64 | 65 | def wrap(task_cls): 66 | from vxverse.tasks.base_task import BaseTask 67 | 68 | assert issubclass( 69 | task_cls, BaseTask 70 | ), "All tasks must inherit BaseTask class" 71 | if name in cls.mapping["task_name_mapping"]: 72 | raise KeyError( 73 | "Name '{}' already registered for {}.".format( 74 | name, cls.mapping["task_name_mapping"][name] 75 | ) 76 | ) 77 | cls.mapping["task_name_mapping"][name] = task_cls 78 | return task_cls 79 | 80 | return wrap 81 | 82 | @classmethod 83 | def register_model(cls, name): 84 | r"""Register a task to registry with key 'name' 85 | 86 | Args: 87 | name: Key with which the task will be registered. 
88 | 89 | Usage: 90 | 91 | from vxverse.common.registry import registry 92 | """ 93 | 94 | def wrap(model_cls): 95 | from vxverse.models import BaseModel 96 | 97 | assert issubclass( 98 | model_cls, BaseModel 99 | ), "All models must inherit BaseModel class" 100 | if name in cls.mapping["model_name_mapping"]: 101 | raise KeyError( 102 | "Name '{}' already registered for {}.".format( 103 | name, cls.mapping["model_name_mapping"][name] 104 | ) 105 | ) 106 | cls.mapping["model_name_mapping"][name] = model_cls 107 | return model_cls 108 | 109 | return wrap 110 | 111 | @classmethod 112 | def register_processor(cls, name): 113 | r"""Register a processor to registry with key 'name' 114 | 115 | Args: 116 | name: Key with which the task will be registered. 117 | 118 | Usage: 119 | 120 | from vxverse.common.registry import registry 121 | """ 122 | 123 | def wrap(processor_cls): 124 | from vxverse.processors import BaseProcessor 125 | 126 | assert issubclass( 127 | processor_cls, BaseProcessor 128 | ), "All processors must inherit BaseProcessor class" 129 | if name in cls.mapping["processor_name_mapping"]: 130 | raise KeyError( 131 | "Name '{}' already registered for {}.".format( 132 | name, cls.mapping["processor_name_mapping"][name] 133 | ) 134 | ) 135 | cls.mapping["processor_name_mapping"][name] = processor_cls 136 | return processor_cls 137 | 138 | return wrap 139 | 140 | @classmethod 141 | def register_lr_scheduler(cls, name): 142 | r"""Register a model to registry with key 'name' 143 | 144 | Args: 145 | name: Key with which the task will be registered. 
146 | 147 | Usage: 148 | 149 | from vxverse.common.registry import registry 150 | """ 151 | 152 | def wrap(lr_sched_cls): 153 | if name in cls.mapping["lr_scheduler_name_mapping"]: 154 | raise KeyError( 155 | "Name '{}' already registered for {}.".format( 156 | name, cls.mapping["lr_scheduler_name_mapping"][name] 157 | ) 158 | ) 159 | cls.mapping["lr_scheduler_name_mapping"][name] = lr_sched_cls 160 | return lr_sched_cls 161 | 162 | return wrap 163 | 164 | @classmethod 165 | def register_runner(cls, name): 166 | r"""Register a model to registry with key 'name' 167 | 168 | Args: 169 | name: Key with which the task will be registered. 170 | 171 | Usage: 172 | 173 | from vxverse.common.registry import registry 174 | """ 175 | 176 | def wrap(runner_cls): 177 | if name in cls.mapping["runner_name_mapping"]: 178 | raise KeyError( 179 | "Name '{}' already registered for {}.".format( 180 | name, cls.mapping["runner_name_mapping"][name] 181 | ) 182 | ) 183 | cls.mapping["runner_name_mapping"][name] = runner_cls 184 | return runner_cls 185 | 186 | return wrap 187 | 188 | @classmethod 189 | def register_path(cls, name, path): 190 | r"""Register a path to registry with key 'name' 191 | 192 | Args: 193 | name: Key with which the path will be registered. 194 | 195 | Usage: 196 | 197 | from vxverse.common.registry import registry 198 | """ 199 | assert isinstance(path, str), "All path must be str." 200 | if name in cls.mapping["paths"]: 201 | raise KeyError("Name '{}' already registered.".format(name)) 202 | cls.mapping["paths"][name] = path 203 | 204 | @classmethod 205 | def register(cls, name, obj): 206 | r"""Register an item to registry with key 'name' 207 | 208 | Args: 209 | name: Key with which the item will be registered. 
210 | 211 | Usage:: 212 | 213 | from vxverse.common.registry import registry 214 | 215 | registry.register("config", {}) 216 | """ 217 | path = name.split(".") 218 | current = cls.mapping["state"] 219 | 220 | for part in path[:-1]: 221 | if part not in current: 222 | current[part] = {} 223 | current = current[part] 224 | 225 | current[path[-1]] = obj 226 | 227 | # @classmethod 228 | # def get_trainer_class(cls, name): 229 | # return cls.mapping["trainer_name_mapping"].get(name, None) 230 | 231 | @classmethod 232 | def get_builder_class(cls, name): 233 | return cls.mapping["builder_name_mapping"].get(name, None) 234 | 235 | @classmethod 236 | def get_model_class(cls, name): 237 | return cls.mapping["model_name_mapping"].get(name, None) 238 | 239 | @classmethod 240 | def get_task_class(cls, name): 241 | return cls.mapping["task_name_mapping"].get(name, None) 242 | 243 | @classmethod 244 | def get_processor_class(cls, name): 245 | return cls.mapping["processor_name_mapping"].get(name, None) 246 | 247 | @classmethod 248 | def get_lr_scheduler_class(cls, name): 249 | return cls.mapping["lr_scheduler_name_mapping"].get(name, None) 250 | 251 | @classmethod 252 | def get_runner_class(cls, name): 253 | return cls.mapping["runner_name_mapping"].get(name, None) 254 | 255 | @classmethod 256 | def list_runners(cls): 257 | return sorted(cls.mapping["runner_name_mapping"].keys()) 258 | 259 | @classmethod 260 | def list_models(cls): 261 | return sorted(cls.mapping["model_name_mapping"].keys()) 262 | 263 | @classmethod 264 | def list_tasks(cls): 265 | return sorted(cls.mapping["task_name_mapping"].keys()) 266 | 267 | @classmethod 268 | def list_processors(cls): 269 | return sorted(cls.mapping["processor_name_mapping"].keys()) 270 | 271 | @classmethod 272 | def list_lr_schedulers(cls): 273 | return sorted(cls.mapping["lr_scheduler_name_mapping"].keys()) 274 | 275 | @classmethod 276 | def list_datasets(cls): 277 | return sorted(cls.mapping["builder_name_mapping"].keys()) 278 | 279 | 
@classmethod 280 | def get_path(cls, name): 281 | return cls.mapping["paths"].get(name, None) 282 | 283 | @classmethod 284 | def get(cls, name, default=None, no_warning=False): 285 | r"""Get an item from registry with key 'name' 286 | 287 | Args: 288 | name (string): Key whose value needs to be retrieved. 289 | default: If passed and key is not in registry, default value will 290 | be returned with a warning. Default: None 291 | no_warning (bool): If passed as True, warning when key doesn't exist 292 | will not be generated. Useful for MMF's 293 | internal operations. Default: False 294 | """ 295 | original_name = name 296 | name = name.split(".") 297 | value = cls.mapping["state"] 298 | for subname in name: 299 | value = value.get(subname, default) 300 | if value is default: 301 | break 302 | 303 | if ( 304 | "writer" in cls.mapping["state"] 305 | and value == default 306 | and no_warning is False 307 | ): 308 | cls.mapping["state"]["writer"].warning( 309 | "Key {} is not present in registry, returning default value " 310 | "of {}".format(original_name, default) 311 | ) 312 | return value 313 | 314 | @classmethod 315 | def unregister(cls, name): 316 | r"""Remove an item from registry with key 'name' 317 | 318 | Args: 319 | name: Key which needs to be removed. 
320 | Usage:: 321 | 322 | from mmf.common.registry import registry 323 | 324 | config = registry.unregister("config") 325 | """ 326 | return cls.mapping["state"].pop(name, None) 327 | 328 | 329 | registry = Registry() 330 | -------------------------------------------------------------------------------- /vxverse/common/vqa_tools/VQA/PythonEvaluationTools/vqaEvalDemo.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | 3 | import sys 4 | dataDir = '../../VQA' 5 | sys.path.insert(0, '%s/PythonHelperTools/vqaTools' %(dataDir)) 6 | from vqa import VQA 7 | from vqaEvaluation.vqaEval import VQAEval 8 | import matplotlib.pyplot as plt 9 | import skimage.io as io 10 | import json 11 | import random 12 | import os 13 | 14 | # # set up file names and paths 15 | # versionType ='v2_' # this should be '' when using VQA v2.0 dataset 16 | # taskType ='OpenEnded' # 'OpenEnded' only for v2.0. 'OpenEnded' or 'MultipleChoice' for v1.0 17 | # dataType ='mscoco' # 'mscoco' only for v1.0. 'mscoco' for real and 'abstract_v002' for abstract for v1.0. 18 | # dataSubType ='train2014' 19 | # annFile ='%s/Annotations/%s%s_%s_annotations.json'%(dataDir, versionType, dataType, dataSubType) 20 | # quesFile ='%s/Questions/%s%s_%s_%s_questions.json'%(dataDir, versionType, taskType, dataType, dataSubType) 21 | # imgDir ='%s/Images/%s/%s/' %(dataDir, dataType, dataSubType) 22 | # resultType ='fake' 23 | # fileTypes = ['results', 'accuracy', 'evalQA', 'evalQuesType', 'evalAnsType'] 24 | # 25 | # # An example result json file has been provided in './Results' folder. 
26 | # 27 | # [resFile, accuracyFile, evalQAFile, evalQuesTypeFile, evalAnsTypeFile] = ['%s/Results/%s%s_%s_%s_%s_%s.json'%(dataDir, versionType, taskType, dataType, dataSubType, \ 28 | # resultType, fileType) for fileType in fileTypes] 29 | # 30 | # # create vqa object and vqaRes object 31 | # vqa = VQA(annFile, quesFile) 32 | # vqaRes = vqa.loadRes(resFile, quesFile) 33 | # 34 | # # create vqaEval object by taking vqa and vqaRes 35 | # vqaEval = VQAEval(vqa, vqaRes, n=2) #n is precision of accuracy (number of places after decimal), default is 2 36 | # 37 | # # evaluate results 38 | # """ 39 | # If you have a list of question ids on which you would like to evaluate your results, pass it as a list to below function 40 | # By default it uses all the question ids in annotation file 41 | # """ 42 | # vqaEval.evaluate() 43 | # 44 | # # print accuracies 45 | # print "\n" 46 | # print "Overall Accuracy is: %.02f\n" %(vqaEval.accuracy['overall']) 47 | # print "Per Question Type Accuracy is the following:" 48 | # for quesType in vqaEval.accuracy['perQuestionType']: 49 | # print "%s : %.02f" %(quesType, vqaEval.accuracy['perQuestionType'][quesType]) 50 | # print "\n" 51 | # print "Per Answer Type Accuracy is the following:" 52 | # for ansType in vqaEval.accuracy['perAnswerType']: 53 | # print "%s : %.02f" %(ansType, vqaEval.accuracy['perAnswerType'][ansType]) 54 | # print "\n" 55 | # # demo how to use evalQA to retrieve low score result 56 | # evals = [quesId for quesId in vqaEval.evalQA if vqaEval.evalQA[quesId]<35] #35 is per question percentage accuracy 57 | # if len(evals) > 0: 58 | # print 'ground truth answers' 59 | # randomEval = random.choice(evals) 60 | # randomAnn = vqa.loadQA(randomEval) 61 | # vqa.showQA(randomAnn) 62 | # 63 | # print '\n' 64 | # print 'generated answer (accuracy %.02f)'%(vqaEval.evalQA[randomEval]) 65 | # ann = vqaRes.loadQA(randomEval)[0] 66 | # print "Answer: %s\n" %(ann['answer']) 67 | # 68 | # imgId = randomAnn[0]['image_id'] 69 | # 
imgFilename = 'COCO_' + dataSubType + '_'+ str(imgId).zfill(12) + '.jpg' 70 | # if os.path.isfile(imgDir + imgFilename): 71 | # I = io.imread(imgDir + imgFilename) 72 | # plt.imshow(I) 73 | # plt.axis('off') 74 | # plt.show() 75 | # 76 | # # plot accuracy for various question types 77 | # plt.bar(range(len(vqaEval.accuracy['perQuestionType'])), vqaEval.accuracy['perQuestionType'].values(), align='center') 78 | # plt.xticks(range(len(vqaEval.accuracy['perQuestionType'])), vqaEval.accuracy['perQuestionType'].keys(), rotation='0',fontsize=10) 79 | # plt.title('Per Question Type Accuracy', fontsize=10) 80 | # plt.xlabel('Question Types', fontsize=10) 81 | # plt.ylabel('Accuracy', fontsize=10) 82 | # plt.show() 83 | # 84 | # # save evaluation results to ./Results folder 85 | # json.dump(vqaEval.accuracy, open(accuracyFile, 'w')) 86 | # json.dump(vqaEval.evalQA, open(evalQAFile, 'w')) 87 | # json.dump(vqaEval.evalQuesType, open(evalQuesTypeFile, 'w')) 88 | # json.dump(vqaEval.evalAnsType, open(evalAnsTypeFile, 'w')) 89 | 90 | -------------------------------------------------------------------------------- /vxverse/common/vqa_tools/VQA/PythonEvaluationTools/vqaEvaluation/__init__.py: -------------------------------------------------------------------------------- 1 | author='aagrawal' 2 | -------------------------------------------------------------------------------- /vxverse/common/vqa_tools/VQA/PythonEvaluationTools/vqaEvaluation/vqaEval.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | 3 | __author__='aagrawal' 4 | 5 | import re 6 | # This code is based on the code written by Tsung-Yi Lin for MSCOCO Python API available at the following link: 7 | # (https://github.com/tylin/coco-caption/blob/master/pycocoevalcap/eval.py). 
8 | import sys 9 | 10 | 11 | class VQAEval: 12 | def __init__(self, vqa, vqaRes, n=2): 13 | self.n = n 14 | self.accuracy = {} 15 | self.evalQA = {} 16 | self.evalQuesType = {} 17 | self.evalAnsType = {} 18 | self.vqa = vqa 19 | self.vqaRes = vqaRes 20 | self.params = {'question_id': vqa.getQuesIds()} 21 | self.contractions = {"aint": "ain't", "arent": "aren't", "cant": "can't", "couldve": "could've", "couldnt": "couldn't", \ 22 | "couldn'tve": "couldn't've", "couldnt've": "couldn't've", "didnt": "didn't", "doesnt": "doesn't", "dont": "don't", "hadnt": "hadn't", \ 23 | "hadnt've": "hadn't've", "hadn'tve": "hadn't've", "hasnt": "hasn't", "havent": "haven't", "hed": "he'd", "hed've": "he'd've", \ 24 | "he'dve": "he'd've", "hes": "he's", "howd": "how'd", "howll": "how'll", "hows": "how's", "Id've": "I'd've", "I'dve": "I'd've", \ 25 | "Im": "I'm", "Ive": "I've", "isnt": "isn't", "itd": "it'd", "itd've": "it'd've", "it'dve": "it'd've", "itll": "it'll", "let's": "let's", \ 26 | "maam": "ma'am", "mightnt": "mightn't", "mightnt've": "mightn't've", "mightn'tve": "mightn't've", "mightve": "might've", \ 27 | "mustnt": "mustn't", "mustve": "must've", "neednt": "needn't", "notve": "not've", "oclock": "o'clock", "oughtnt": "oughtn't", \ 28 | "ow's'at": "'ow's'at", "'ows'at": "'ow's'at", "'ow'sat": "'ow's'at", "shant": "shan't", "shed've": "she'd've", "she'dve": "she'd've", \ 29 | "she's": "she's", "shouldve": "should've", "shouldnt": "shouldn't", "shouldnt've": "shouldn't've", "shouldn'tve": "shouldn't've", \ 30 | "somebody'd": "somebodyd", "somebodyd've": "somebody'd've", "somebody'dve": "somebody'd've", "somebodyll": "somebody'll", \ 31 | "somebodys": "somebody's", "someoned": "someone'd", "someoned've": "someone'd've", "someone'dve": "someone'd've", \ 32 | "someonell": "someone'll", "someones": "someone's", "somethingd": "something'd", "somethingd've": "something'd've", \ 33 | "something'dve": "something'd've", "somethingll": "something'll", "thats": "that's", "thered": 
"there'd", "thered've": "there'd've", \ 34 | "there'dve": "there'd've", "therere": "there're", "theres": "there's", "theyd": "they'd", "theyd've": "they'd've", \ 35 | "they'dve": "they'd've", "theyll": "they'll", "theyre": "they're", "theyve": "they've", "twas": "'twas", "wasnt": "wasn't", \ 36 | "wed've": "we'd've", "we'dve": "we'd've", "weve": "we've", "werent": "weren't", "whatll": "what'll", "whatre": "what're", \ 37 | "whats": "what's", "whatve": "what've", "whens": "when's", "whered": "where'd", "wheres": "where's", "whereve": "where've", \ 38 | "whod": "who'd", "whod've": "who'd've", "who'dve": "who'd've", "wholl": "who'll", "whos": "who's", "whove": "who've", "whyll": "why'll", \ 39 | "whyre": "why're", "whys": "why's", "wont": "won't", "wouldve": "would've", "wouldnt": "wouldn't", "wouldnt've": "wouldn't've", \ 40 | "wouldn'tve": "wouldn't've", "yall": "y'all", "yall'll": "y'all'll", "y'allll": "y'all'll", "yall'd've": "y'all'd've", \ 41 | "y'alld've": "y'all'd've", "y'all'dve": "y'all'd've", "youd": "you'd", "youd've": "you'd've", "you'dve": "you'd've", \ 42 | "youll": "you'll", "youre": "you're", "youve": "you've"} 43 | self.manualMap = { 'none': '0', 44 | 'zero': '0', 45 | 'one': '1', 46 | 'two': '2', 47 | 'three': '3', 48 | 'four': '4', 49 | 'five': '5', 50 | 'six': '6', 51 | 'seven': '7', 52 | 'eight': '8', 53 | 'nine': '9', 54 | 'ten': '10' 55 | } 56 | self.articles = ['a', 57 | 'an', 58 | 'the' 59 | ] 60 | 61 | 62 | self.periodStrip = re.compile("(?!<=\d)(\.)(?!\d)") 63 | self.commaStrip = re.compile("(\d)(\,)(\d)") 64 | self.punct = [';', r"/", '[', ']', '"', '{', '}', 65 | '(', ')', '=', '+', '\\', '_', '-', 66 | '>', '<', '@', '`', ',', '?', '!'] 67 | 68 | 69 | def evaluate(self, quesIds=None): 70 | if quesIds == None: 71 | quesIds = [quesId for quesId in self.params['question_id']] 72 | gts = {} 73 | res = {} 74 | for quesId in quesIds: 75 | gts[quesId] = self.vqa.qa[quesId] 76 | res[quesId] = self.vqaRes.qa[quesId] 77 | 78 | # 
================================================= 79 | # Compute accuracy 80 | # ================================================= 81 | accQA = [] 82 | accQuesType = {} 83 | accAnsType = {} 84 | # print "computing accuracy" 85 | step = 0 86 | for quesId in quesIds: 87 | for ansDic in gts[quesId]['answers']: 88 | ansDic['answer'] = ansDic['answer'].replace('\n', ' ') 89 | ansDic['answer'] = ansDic['answer'].replace('\t', ' ') 90 | ansDic['answer'] = ansDic['answer'].strip() 91 | resAns = res[quesId]['answer'] 92 | resAns = resAns.replace('\n', ' ') 93 | resAns = resAns.replace('\t', ' ') 94 | resAns = resAns.strip() 95 | gtAcc = [] 96 | gtAnswers = [ans['answer'] for ans in gts[quesId]['answers']] 97 | 98 | if len(set(gtAnswers)) > 1: 99 | for ansDic in gts[quesId]['answers']: 100 | ansDic['answer'] = self.processPunctuation(ansDic['answer']) 101 | ansDic['answer'] = self.processDigitArticle(ansDic['answer']) 102 | resAns = self.processPunctuation(resAns) 103 | resAns = self.processDigitArticle(resAns) 104 | 105 | for gtAnsDatum in gts[quesId]['answers']: 106 | otherGTAns = [item for item in gts[quesId]['answers'] if item!=gtAnsDatum] 107 | matchingAns = [item for item in otherGTAns if item['answer'].lower()==resAns.lower()] 108 | acc = min(1, float(len(matchingAns))/3) 109 | gtAcc.append(acc) 110 | quesType = gts[quesId]['question_type'] 111 | ansType = gts[quesId]['answer_type'] 112 | avgGTAcc = float(sum(gtAcc))/len(gtAcc) 113 | accQA.append(avgGTAcc) 114 | if quesType not in accQuesType: 115 | accQuesType[quesType] = [] 116 | accQuesType[quesType].append(avgGTAcc) 117 | if ansType not in accAnsType: 118 | accAnsType[ansType] = [] 119 | accAnsType[ansType].append(avgGTAcc) 120 | self.setEvalQA(quesId, avgGTAcc) 121 | self.setEvalQuesType(quesId, quesType, avgGTAcc) 122 | self.setEvalAnsType(quesId, ansType, avgGTAcc) 123 | if step%100 == 0: 124 | self.updateProgress(step/float(len(quesIds))) 125 | step = step + 1 126 | 127 | self.setAccuracy(accQA, accQuesType, 
accAnsType) 128 | # print "Done computing accuracy" 129 | 130 | def processPunctuation(self, inText): 131 | outText = inText 132 | for p in self.punct: 133 | if (p + ' ' in inText or ' ' + p in inText) or (re.search(self.commaStrip, inText) != None): 134 | outText = outText.replace(p, '') 135 | else: 136 | outText = outText.replace(p, ' ') 137 | outText = self.periodStrip.sub("", 138 | outText, 139 | re.UNICODE) 140 | return outText 141 | 142 | def processDigitArticle(self, inText): 143 | outText = [] 144 | tempText = inText.lower().split() 145 | for word in tempText: 146 | word = self.manualMap.setdefault(word, word) 147 | if word not in self.articles: 148 | outText.append(word) 149 | else: 150 | pass 151 | for wordId, word in enumerate(outText): 152 | if word in self.contractions: 153 | outText[wordId] = self.contractions[word] 154 | outText = ' '.join(outText) 155 | return outText 156 | 157 | def setAccuracy(self, accQA, accQuesType, accAnsType): 158 | self.accuracy['overall'] = round(100*float(sum(accQA))/len(accQA), self.n) 159 | self.accuracy['perQuestionType'] = {quesType: round(100*float(sum(accQuesType[quesType]))/len(accQuesType[quesType]), self.n) for quesType in accQuesType} 160 | self.accuracy['perAnswerType'] = {ansType: round(100*float(sum(accAnsType[ansType]))/len(accAnsType[ansType]), self.n) for ansType in accAnsType} 161 | 162 | def setEvalQA(self, quesId, acc): 163 | self.evalQA[quesId] = round(100*acc, self.n) 164 | 165 | def setEvalQuesType(self, quesId, quesType, acc): 166 | if quesType not in self.evalQuesType: 167 | self.evalQuesType[quesType] = {} 168 | self.evalQuesType[quesType][quesId] = round(100*acc, self.n) 169 | 170 | def setEvalAnsType(self, quesId, ansType, acc): 171 | if ansType not in self.evalAnsType: 172 | self.evalAnsType[ansType] = {} 173 | self.evalAnsType[ansType][quesId] = round(100*acc, self.n) 174 | 175 | def updateProgress(self, progress): 176 | barLength = 20 177 | status = "" 178 | if isinstance(progress, int): 179 | 
progress = float(progress) 180 | if not isinstance(progress, float): 181 | progress = 0 182 | status = "error: progress var must be float\r\n" 183 | if progress < 0: 184 | progress = 0 185 | status = "Halt...\r\n" 186 | if progress >= 1: 187 | progress = 1 188 | status = "Done...\r\n" 189 | block = int(round(barLength*progress)) 190 | text = "\rFinshed Percent: [{0}] {1}% {2}".format( "#"*block + "-"*(barLength-block), int(progress*100), status) 191 | sys.stdout.write(text) 192 | sys.stdout.flush() 193 | -------------------------------------------------------------------------------- /vxverse/common/vqa_tools/VQA/PythonHelperTools/vqaDemo.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | 3 | from vqaTools.vqa import VQA 4 | import random 5 | import skimage.io as io 6 | import matplotlib.pyplot as plt 7 | import os 8 | 9 | dataDir ='../../VQA' 10 | versionType ='v2_' # this should be '' when using VQA v2.0 dataset 11 | taskType ='OpenEnded' # 'OpenEnded' only for v2.0. 'OpenEnded' or 'MultipleChoice' for v1.0 12 | dataType ='mscoco' # 'mscoco' only for v1.0. 'mscoco' for real and 'abstract_v002' for abstract for v1.0. 13 | dataSubType ='train2014' 14 | annFile ='%s/Annotations/%s%s_%s_annotations.json'%(dataDir, versionType, dataType, dataSubType) 15 | quesFile ='%s/Questions/%s%s_%s_%s_questions.json'%(dataDir, versionType, taskType, dataType, dataSubType) 16 | imgDir = '%s/Images/%s/%s/' %(dataDir, dataType, dataSubType) 17 | 18 | # initialize VQA api for QA annotations 19 | vqa=VQA(annFile, quesFile) 20 | 21 | # load and display QA annotations for given question types 22 | """ 23 | All possible quesTypes for abstract and mscoco has been provided in respective text files in ../QuestionTypes/ folder. 
24 | """ 25 | annIds = vqa.getQuesIds(quesTypes='how many'); 26 | anns = vqa.loadQA(annIds) 27 | randomAnn = random.choice(anns) 28 | vqa.showQA([randomAnn]) 29 | imgId = randomAnn['image_id'] 30 | imgFilename = 'COCO_' + dataSubType + '_'+ str(imgId).zfill(12) + '.jpg' 31 | if os.path.isfile(imgDir + imgFilename): 32 | I = io.imread(imgDir + imgFilename) 33 | plt.imshow(I) 34 | plt.axis('off') 35 | plt.show() 36 | 37 | # load and display QA annotations for given answer types 38 | """ 39 | ansTypes can be one of the following 40 | yes/no 41 | number 42 | other 43 | """ 44 | annIds = vqa.getQuesIds(ansTypes='yes/no'); 45 | anns = vqa.loadQA(annIds) 46 | randomAnn = random.choice(anns) 47 | vqa.showQA([randomAnn]) 48 | imgId = randomAnn['image_id'] 49 | imgFilename = 'COCO_' + dataSubType + '_'+ str(imgId).zfill(12) + '.jpg' 50 | if os.path.isfile(imgDir + imgFilename): 51 | I = io.imread(imgDir + imgFilename) 52 | plt.imshow(I) 53 | plt.axis('off') 54 | plt.show() 55 | 56 | # load and display QA annotations for given images 57 | """ 58 | Usage: vqa.getImgIds(quesIds=[], quesTypes=[], ansTypes=[]) 59 | Above method can be used to retrieve imageIds for given question Ids or given question types or given answer types. 
60 | """ 61 | ids = vqa.getImgIds() 62 | annIds = vqa.getQuesIds(imgIds=random.sample(ids,5)); 63 | anns = vqa.loadQA(annIds) 64 | randomAnn = random.choice(anns) 65 | vqa.showQA([randomAnn]) 66 | imgId = randomAnn['image_id'] 67 | imgFilename = 'COCO_' + dataSubType + '_'+ str(imgId).zfill(12) + '.jpg' 68 | if os.path.isfile(imgDir + imgFilename): 69 | I = io.imread(imgDir + imgFilename) 70 | plt.imshow(I) 71 | plt.axis('off') 72 | plt.show() 73 | 74 | -------------------------------------------------------------------------------- /vxverse/common/vqa_tools/VQA/PythonHelperTools/vqaTools/__init__.py: -------------------------------------------------------------------------------- 1 | __author__ = 'aagrawal' 2 | -------------------------------------------------------------------------------- /vxverse/common/vqa_tools/VQA/PythonHelperTools/vqaTools/vqa.py: -------------------------------------------------------------------------------- 1 | __author__ = 'aagrawal' 2 | __version__ = '0.9' 3 | 4 | # Interface for accessing the VQA dataset. 5 | 6 | # This code is based on the code written by Tsung-Yi Lin for MSCOCO Python API available at the following link: 7 | # (https://github.com/pdollar/coco/blob/master/PythonAPI/pycocotools/coco.py). 8 | 9 | # The following functions are defined: 10 | # VQA - VQA class that loads VQA annotation file and prepares data structures. 11 | # getQuesIds - Get question ids that satisfy given filter conditions. 12 | # getImgIds - Get image ids that satisfy given filter conditions. 13 | # loadQA - Load questions and answers with the specified question ids. 14 | # showQA - Display the specified questions and answers. 15 | # loadRes - Load result file and create result object. 
16 | 17 | # Help on each function can be accessed by: "help(COCO.function)" 18 | 19 | import json 20 | import datetime 21 | import copy 22 | 23 | 24 | class VQA: 25 | def __init__(self, annotation_file=None, question_file=None): 26 | """ 27 | Constructor of VQA helper class for reading and visualizing questions and answers. 28 | :param annotation_file (str): location of VQA annotation file 29 | :return: 30 | """ 31 | # load dataset 32 | self.dataset = {} 33 | self.questions = {} 34 | self.qa = {} 35 | self.qqa = {} 36 | self.imgToQA = {} 37 | if not annotation_file == None and not question_file == None: 38 | # print 'loading VQA annotations and questions into memory...' 39 | time_t = datetime.datetime.utcnow() 40 | dataset = json.load(open(annotation_file, 'r')) 41 | questions = json.load(open(question_file, 'r')) 42 | # print datetime.datetime.utcnow() - time_t 43 | self.dataset = dataset 44 | self.questions = questions 45 | self.createIndex() 46 | 47 | def createIndex(self): 48 | imgToQA = {ann['image_id']: [] for ann in self.dataset['annotations']} 49 | qa = {ann['question_id']: [] for ann in self.dataset['annotations']} 50 | qqa = {ann['question_id']: [] for ann in self.dataset['annotations']} 51 | for ann in self.dataset['annotations']: 52 | imgToQA[ann['image_id']] += [ann] 53 | qa[ann['question_id']] = ann 54 | for ques in self.questions['questions']: 55 | qqa[ques['question_id']] = ques 56 | # print 'index created!' 57 | 58 | # create class members 59 | self.qa = qa 60 | self.qqa = qqa 61 | self.imgToQA = imgToQA 62 | 63 | def info(self): 64 | """ 65 | Print information about the VQA annotation file. 66 | :return: 67 | """ 68 | 69 | # for key, value in self.datset['info'].items(): 70 | # print '%s: %s'%(key, value) 71 | 72 | def getQuesIds(self, imgIds=[], quesTypes=[], ansTypes=[]): 73 | """ 74 | Get question ids that satisfy given filter conditions. 
default skips that filter 75 | :param imgIds (int array) : get question ids for given imgs 76 | quesTypes (str array) : get question ids for given question types 77 | ansTypes (str array) : get question ids for given answer types 78 | :return: ids (int array) : integer array of question ids 79 | """ 80 | imgIds = imgIds if type(imgIds) == list else [imgIds] 81 | quesTypes = quesTypes if type(quesTypes) == list else [quesTypes] 82 | ansTypes = ansTypes if type(ansTypes) == list else [ansTypes] 83 | 84 | if len(imgIds) == len(quesTypes) == len(ansTypes) == 0: 85 | anns = self.dataset['annotations'] 86 | else: 87 | if not len(imgIds) == 0: 88 | anns = sum([self.imgToQA[imgId] for imgId in imgIds if imgId in self.imgToQA], []) 89 | else: 90 | anns = self.dataset['annotations'] 91 | anns = anns if len(quesTypes) == 0 else [ann for ann in anns if ann['question_type'] in quesTypes] 92 | anns = anns if len(ansTypes) == 0 else [ann for ann in anns if ann['answer_type'] in ansTypes] 93 | ids = [ann['question_id'] for ann in anns] 94 | return ids 95 | 96 | def getImgIds(self, quesIds=[], quesTypes=[], ansTypes=[]): 97 | """ 98 | Get image ids that satisfy given filter conditions. 
default skips that filter 99 | :param quesIds (int array) : get image ids for given question ids 100 | quesTypes (str array) : get image ids for given question types 101 | ansTypes (str array) : get image ids for given answer types 102 | :return: ids (int array) : integer array of image ids 103 | """ 104 | quesIds = quesIds if type(quesIds) == list else [quesIds] 105 | quesTypes = quesTypes if type(quesTypes) == list else [quesTypes] 106 | ansTypes = ansTypes if type(ansTypes) == list else [ansTypes] 107 | 108 | if len(quesIds) == len(quesTypes) == len(ansTypes) == 0: 109 | anns = self.dataset['annotations'] 110 | else: 111 | if not len(quesIds) == 0: 112 | anns = sum([self.qa[quesId] for quesId in quesIds if quesId in self.qa], []) 113 | else: 114 | anns = self.dataset['annotations'] 115 | anns = anns if len(quesTypes) == 0 else [ann for ann in anns if ann['question_type'] in quesTypes] 116 | anns = anns if len(ansTypes) == 0 else [ann for ann in anns if ann['answer_type'] in ansTypes] 117 | ids = [ann['image_id'] for ann in anns] 118 | return ids 119 | 120 | def loadQA(self, ids=[]): 121 | """ 122 | Load questions and answers with the specified question ids. 123 | :param ids (int array) : integer ids specifying question ids 124 | :return: qa (object array) : loaded qa objects 125 | """ 126 | if type(ids) == list: 127 | return [self.qa[id] for id in ids] 128 | elif type(ids) == int: 129 | return [self.qa[ids]] 130 | 131 | def showQA(self, anns): 132 | """ 133 | Display the specified annotations. 134 | :param anns (array of object): annotations to display 135 | :return: None 136 | """ 137 | if len(anns) == 0: 138 | return 0 139 | for ann in anns: 140 | quesId = ann['question_id'] 141 | print("Question: %s" % (self.qqa[quesId]['question'])) 142 | for ans in ann['answers']: 143 | print("Answer %d: %s" % (ans['answer_id'], ans['answer'])) 144 | 145 | def loadRes(self, resFile, quesFile): 146 | """ 147 | Load result file and return a result object. 
148 | :param resFile (str) : file name of result file 149 | :return: res (obj) : result api object 150 | """ 151 | res = VQA() 152 | res.questions = json.load(open(quesFile)) 153 | res.dataset['info'] = copy.deepcopy(self.questions['info']) 154 | res.dataset['task_type'] = copy.deepcopy(self.questions['task_type']) 155 | res.dataset['data_type'] = copy.deepcopy(self.questions['data_type']) 156 | res.dataset['data_subtype'] = copy.deepcopy(self.questions['data_subtype']) 157 | res.dataset['license'] = copy.deepcopy(self.questions['license']) 158 | 159 | # print 'Loading and preparing results... ' 160 | time_t = datetime.datetime.utcnow() 161 | anns = [] 162 | with open(resFile, 'r', encoding='utf-8') as f: 163 | for line in f.readlines(): 164 | anns.append(json.loads(line.strip())) 165 | # anns = json.load(open(resFile)) 166 | assert type(anns) == list, 'results is not an array of objects' 167 | annsQuesIds = [ann['question_id'] for ann in anns] 168 | assert set(annsQuesIds) == set(self.getQuesIds()), \ 169 | 'Results do not correspond to current VQA set. Either the results do not have predictions for all question ids in annotation file or there is atleast one question id that does not belong to the question ids in the annotation file.' 
170 | for ann in anns: 171 | quesId = ann['question_id'] 172 | if res.dataset['task_type'] == 'Multiple Choice': 173 | assert ann['answer'] in self.qqa[quesId][ 174 | 'multiple_choices'], 'predicted answer is not one of the multiple choices' 175 | qaAnn = self.qa[quesId] 176 | ann['image_id'] = qaAnn['image_id'] 177 | ann['question_type'] = qaAnn['question_type'] 178 | ann['answer_type'] = qaAnn['answer_type'] 179 | # print 'DONE (t=%0.2fs)'%((datetime.datetime.utcnow() - time_t).total_seconds()) 180 | 181 | res.dataset['annotations'] = anns 182 | res.createIndex() 183 | return res 184 | -------------------------------------------------------------------------------- /vxverse/common/vqa_tools/VQA/QuestionTypes/abstract_v002_question_types.txt: -------------------------------------------------------------------------------- 1 | how many 2 | what color is the 3 | is the 4 | where is the 5 | what 6 | what is 7 | are the 8 | what is the 9 | is there a 10 | does the 11 | is the woman 12 | is the man 13 | what is on the 14 | is it 15 | is the girl 16 | is the boy 17 | is the dog 18 | are they 19 | who is 20 | what kind of 21 | what color are the 22 | what is in the 23 | what is the man 24 | is there 25 | what is the woman 26 | what are the 27 | what is the boy 28 | are there 29 | what is the girl 30 | is this 31 | how 32 | which 33 | how many people are 34 | is the cat 35 | why is the 36 | are 37 | will the 38 | what type of 39 | what is the dog 40 | do 41 | is she 42 | does 43 | do the 44 | is 45 | is the baby 46 | are there any 47 | is the lady 48 | can 49 | what animal is 50 | where are the 51 | is the sun 52 | what are they 53 | did the 54 | what is the cat 55 | what is the lady 56 | how many clouds are 57 | is that 58 | is the little girl 59 | is he 60 | are these 61 | how many trees are 62 | how many pillows 63 | are the people 64 | why 65 | is the young 66 | how many windows are 67 | is this a 68 | what is the little 69 | is the tv 70 | how many animals are 71 | 
who 72 | how many pictures 73 | how many plants are 74 | how many birds are 75 | what color is 76 | what is the baby 77 | is anyone 78 | what color 79 | how many bushes 80 | is the old man 81 | none of the above 82 | -------------------------------------------------------------------------------- /vxverse/common/vqa_tools/VQA/QuestionTypes/mscoco_question_types.txt: -------------------------------------------------------------------------------- 1 | how many 2 | is the 3 | what 4 | what color is the 5 | what is the 6 | is this 7 | is this a 8 | what is 9 | are the 10 | what kind of 11 | is there a 12 | what type of 13 | is it 14 | what are the 15 | where is the 16 | is there 17 | does the 18 | what color are the 19 | are these 20 | are there 21 | which 22 | is 23 | what is the man 24 | is the man 25 | are 26 | how 27 | does this 28 | what is on the 29 | what does the 30 | how many people are 31 | what is in the 32 | what is this 33 | do 34 | what are 35 | are they 36 | what time 37 | what sport is 38 | are there any 39 | is he 40 | what color is 41 | why 42 | where are the 43 | what color 44 | who is 45 | what animal is 46 | is the woman 47 | is this an 48 | do you 49 | how many people are in 50 | what room is 51 | has 52 | is this person 53 | what is the woman 54 | can you 55 | why is the 56 | is the person 57 | what is the color of the 58 | what is the person 59 | could 60 | was 61 | is that a 62 | what number is 63 | what is the name 64 | what brand 65 | none of the above 66 | -------------------------------------------------------------------------------- /vxverse/common/vqa_tools/VQA/README.md: -------------------------------------------------------------------------------- 1 | Python API and Evaluation Code for v2.0 and v1.0 releases of the VQA dataset. 
2 | =================== 3 | ## VQA v2.0 release ## 4 | This release consists of 5 | - Real 6 | - 82,783 MS COCO training images, 40,504 MS COCO validation images and 81,434 MS COCO testing images (images are obtained from [MS COCO website](http://mscoco.org/dataset/#download)) 7 | - 443,757 questions for training, 214,354 questions for validation and 447,793 questions for testing 8 | - 4,437,570 answers for training and 2,143,540 answers for validation (10 per question) 9 | 10 | There is only one type of task 11 | - Open-ended task 12 | 13 | ## VQA v1.0 release ## 14 | This release consists of 15 | - Real 16 | - 82,783 MS COCO training images, 40,504 MS COCO validation images and 81,434 MS COCO testing images (images are obtained from [MS COCO website](http://mscoco.org/dataset/#download)) 17 | - 248,349 questions for training, 121,512 questions for validation and 244,302 questions for testing (3 per image) 18 | - 2,483,490 answers for training and 1,215,120 answers for validation (10 per question) 19 | - Abstract 20 | - 20,000 training images, 10,000 validation images and 20,000 testing images 21 | - 60,000 questions for training, 30,000 questions for validation and 60,000 questions for testing (3 per image) 22 | - 600,000 answers for training and 300,000 answers for validation (10 per question) 23 | 24 | There are two types of tasks 25 | - Open-ended task 26 | - Multiple-choice task (18 choices per question) 27 | 28 | ## Requirements ## 29 | - Python 3 (the upstream release targeted Python 2.7) 30 | - scikit-image (visit [this page](http://scikit-image.org/docs/dev/install.html) for installation) 31 | - matplotlib (visit [this page](http://matplotlib.org/users/installing.html) for installation) 32 | 33 | ## Files ## 34 | ./Questions 35 | - For v2.0, download the question files from the [VQA download page](http://www.visualqa.org/download.html), extract them and place in this folder.
36 | - For v1.0, both real and abstract, question files can be found on the [VQA v1 download page](http://www.visualqa.org/vqa_v1_download.html). 37 | - Question files from Beta v0.9 release (123,287 MSCOCO train and val images, 369,861 questions, 3,698,610 answers) can be found below 38 | - [training question files](http://visualqa.org/data/mscoco/prev_rel/Beta_v0.9/Questions_Train_mscoco.zip) 39 | - [validation question files](http://visualqa.org/data/mscoco/prev_rel/Beta_v0.9/Questions_Val_mscoco.zip) 40 | - Question files from Beta v0.1 release (10k MSCOCO images, 30k questions, 300k answers) can be found [here](http://visualqa.org/data/mscoco/prev_rel/Beta_v0.1/Questions_Train_mscoco.zip). 41 | 42 | ./Annotations 43 | - For v2.0, download the annotations files from the [VQA download page](http://www.visualqa.org/download.html), extract them and place in this folder. 44 | - For v1.0, for both real and abstract, annotation files can be found on the [VQA v1 download page](http://www.visualqa.org/vqa_v1_download.html). 45 | - Annotation files from Beta v0.9 release (123,287 MSCOCO train and val images, 369,861 questions, 3,698,610 answers) can be found below 46 | - [training annotation files](http://visualqa.org/data/mscoco/prev_rel/Beta_v0.9/Annotations_Train_mscoco.zip) 47 | - [validation annotation files](http://visualqa.org/data/mscoco/prev_rel/Beta_v0.9/Annotations_Val_mscoco.zip) 48 | - Annotation files from Beta v0.1 release (10k MSCOCO images, 30k questions, 300k answers) can be found [here](http://visualqa.org/data/mscoco/prev_rel/Beta_v0.1/Annotations_Train_mscoco.zip). 49 | 50 | ./Images 51 | - For real, create a directory with name mscoco inside this directory. For each of train, val and test, create directories with names train2014, val2014 and test2015 respectively inside mscoco directory, download respective images from [MS COCO website](http://mscoco.org/dataset/#download) and place them in respective folders. 
52 | - For abstract, create a directory with name abstract_v002 inside this directory. For each of train, val and test, create directories with names train2015, val2015 and test2015 respectively inside abstract_v002 directory, download respective images from [VQA download page](http://www.visualqa.org/download.html) and place them in respective folders. 53 | 54 | ./PythonHelperTools 55 | - This directory contains the Python API to read and visualize the VQA dataset 56 | - vqaDemo.py (demo script) 57 | - vqaTools (API to read and visualize data) 58 | 59 | ./PythonEvaluationTools 60 | - This directory contains the Python evaluation code 61 | - vqaEvalDemo.py (evaluation demo script) 62 | - vqaEvaluation (evaluation code) 63 | 64 | ./Results 65 | - OpenEnded_mscoco_train2014_fake_results.json (an example of a fake results file for v1.0 to run the demo) 66 | - Visit [VQA evaluation page](http://visualqa.org/evaluation) for more details. 67 | 68 | ./QuestionTypes 69 | - This directory contains the following lists of question types for both real and abstract questions (question types are unchanged from v1.0 to v2.0). In a list, if there are question types of length n+k and length n with the same first n words, then the question type of length n does not include questions that belong to the question type of length n+k. 70 | - mscoco_question_types.txt 71 | - abstract_v002_question_types.txt 72 | 73 | ## References ## 74 | - [VQA: Visual Question Answering](http://visualqa.org/) 75 | - [Microsoft COCO](http://mscoco.org/) 76 | 77 | ## Developers ## 78 | - Aishwarya Agrawal (Virginia Tech) 79 | - Code for API is based on [MSCOCO API code](https://github.com/pdollar/coco). 80 | - The format of the code for evaluation is based on [MSCOCO evaluation code](https://github.com/tylin/coco-caption).
81 | -------------------------------------------------------------------------------- /vxverse/common/vqa_tools/VQA/license.txt: -------------------------------------------------------------------------------- 1 | Copyright (c) 2014, Aishwarya Agrawal 2 | All rights reserved. 3 | 4 | Redistribution and use in source and binary forms, with or without 5 | modification, are permitted provided that the following conditions are met: 6 | 7 | 1. Redistributions of source code must retain the above copyright notice, this 8 | list of conditions and the following disclaimer. 9 | 2. Redistributions in binary form must reproduce the above copyright notice, 10 | this list of conditions and the following disclaimer in the documentation 11 | and/or other materials provided with the distribution. 12 | 13 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 14 | AND 15 | ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED 16 | WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 17 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE 18 | FOR 19 | ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES 20 | (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; 21 | LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND 22 | ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 23 | (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS 24 | SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 25 | 26 | The views and conclusions contained in the software and documentation are 27 | those 28 | of the authors and should not be interpreted as representing official 29 | policies, 30 | either expressed or implied, of the FreeBSD Project. 
31 | -------------------------------------------------------------------------------- /vxverse/common/vqa_tools/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) 2022, salesforce.com, inc. 3 | All rights reserved. 4 | SPDX-License-Identifier: BSD-3-Clause 5 | For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause 6 | """ 7 | 8 | __author__ = "aagrawal" 9 | -------------------------------------------------------------------------------- /vxverse/common/vqa_tools/vqa.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) 2022, salesforce.com, inc. 3 | All rights reserved. 4 | SPDX-License-Identifier: BSD-3-Clause 5 | For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause 6 | """ 7 | 8 | __author__ = "aagrawal" 9 | __version__ = "0.9" 10 | 11 | # Interface for accessing the VQA dataset. 12 | 13 | # This code is based on the code written by Tsung-Yi Lin for MSCOCO Python API available at the following link: 14 | # (https://github.com/pdollar/coco/blob/master/PythonAPI/pycocotools/coco.py). 15 | 16 | # The following functions are defined: 17 | # VQA - VQA class that loads VQA annotation file and prepares data structures. 18 | # getQuesIds - Get question ids that satisfy given filter conditions. 19 | # getImgIds - Get image ids that satisfy given filter conditions. 20 | # loadQA - Load questions and answers with the specified question ids. 21 | # showQA - Display the specified questions and answers. 22 | # loadRes - Load result file and create result object. 
23 | 24 | # Help on each function can be accessed by: "help(COCO.function)" 25 | 26 | import json 27 | import datetime 28 | import copy 29 | 30 | 31 | class VQA: 32 | def __init__(self, annotation_file=None, question_file=None): 33 | """ 34 | Constructor of VQA helper class for reading and visualizing questions and answers. 35 | :param annotation_file (str): location of VQA annotation file 36 | :return: 37 | """ 38 | # load dataset 39 | self.dataset = {} 40 | self.questions = {} 41 | self.qa = {} 42 | self.qqa = {} 43 | self.imgToQA = {} 44 | if not annotation_file == None and not question_file == None: 45 | print("loading VQA annotations and questions into memory...") 46 | time_t = datetime.datetime.utcnow() 47 | dataset = json.load(open(annotation_file, "r")) 48 | questions = json.load(open(question_file, "r")) 49 | self.dataset = dataset 50 | self.questions = questions 51 | self.createIndex() 52 | 53 | def createIndex(self): 54 | # create index 55 | print("creating index...") 56 | imgToQA = {ann["image_id"]: [] for ann in self.dataset["annotations"]} 57 | qa = {ann["question_id"]: [] for ann in self.dataset["annotations"]} 58 | qqa = {ann["question_id"]: [] for ann in self.dataset["annotations"]} 59 | for ann in self.dataset["annotations"]: 60 | imgToQA[ann["image_id"]] += [ann] 61 | qa[ann["question_id"]] = ann 62 | for ques in self.questions["questions"]: 63 | qqa[ques["question_id"]] = ques 64 | print("index created!") 65 | 66 | # create class members 67 | self.qa = qa 68 | self.qqa = qqa 69 | self.imgToQA = imgToQA 70 | 71 | def info(self): 72 | """ 73 | Print information about the VQA annotation file. 74 | :return: 75 | """ 76 | for key, value in self.datset["info"].items(): 77 | print("%s: %s" % (key, value)) 78 | 79 | def getQuesIds(self, imgIds=[], quesTypes=[], ansTypes=[]): 80 | """ 81 | Get question ids that satisfy given filter conditions. 
default skips that filter 82 | :param imgIds (int array) : get question ids for given imgs 83 | quesTypes (str array) : get question ids for given question types 84 | ansTypes (str array) : get question ids for given answer types 85 | :return: ids (int array) : integer array of question ids 86 | """ 87 | imgIds = imgIds if type(imgIds) == list else [imgIds] 88 | quesTypes = quesTypes if type(quesTypes) == list else [quesTypes] 89 | ansTypes = ansTypes if type(ansTypes) == list else [ansTypes] 90 | 91 | if len(imgIds) == len(quesTypes) == len(ansTypes) == 0: 92 | anns = self.dataset["annotations"] 93 | else: 94 | if not len(imgIds) == 0: 95 | anns = sum( 96 | [self.imgToQA[imgId] for imgId in imgIds if imgId in self.imgToQA], 97 | [], 98 | ) 99 | else: 100 | anns = self.dataset["annotations"] 101 | anns = ( 102 | anns 103 | if len(quesTypes) == 0 104 | else [ann for ann in anns if ann["question_type"] in quesTypes] 105 | ) 106 | anns = ( 107 | anns 108 | if len(ansTypes) == 0 109 | else [ann for ann in anns if ann["answer_type"] in ansTypes] 110 | ) 111 | ids = [ann["question_id"] for ann in anns] 112 | return ids 113 | 114 | def getImgIds(self, quesIds=[], quesTypes=[], ansTypes=[]): 115 | """ 116 | Get image ids that satisfy given filter conditions. 
default skips that filter 117 | :param quesIds (int array) : get image ids for given question ids 118 | quesTypes (str array) : get image ids for given question types 119 | ansTypes (str array) : get image ids for given answer types 120 | :return: ids (int array) : integer array of image ids 121 | """ 122 | quesIds = quesIds if type(quesIds) == list else [quesIds] 123 | quesTypes = quesTypes if type(quesTypes) == list else [quesTypes] 124 | ansTypes = ansTypes if type(ansTypes) == list else [ansTypes] 125 | 126 | if len(quesIds) == len(quesTypes) == len(ansTypes) == 0: 127 | anns = self.dataset["annotations"] 128 | else: 129 | if not len(quesIds) == 0: 130 | anns = sum( 131 | [self.qa[quesId] for quesId in quesIds if quesId in self.qa], [] 132 | ) 133 | else: 134 | anns = self.dataset["annotations"] 135 | anns = ( 136 | anns 137 | if len(quesTypes) == 0 138 | else [ann for ann in anns if ann["question_type"] in quesTypes] 139 | ) 140 | anns = ( 141 | anns 142 | if len(ansTypes) == 0 143 | else [ann for ann in anns if ann["answer_type"] in ansTypes] 144 | ) 145 | ids = [ann["image_id"] for ann in anns] 146 | return ids 147 | 148 | def loadQA(self, ids=[]): 149 | """ 150 | Load questions and answers with the specified question ids. 151 | :param ids (int array) : integer ids specifying question ids 152 | :return: qa (object array) : loaded qa objects 153 | """ 154 | if type(ids) == list: 155 | return [self.qa[id] for id in ids] 156 | elif type(ids) == int: 157 | return [self.qa[ids]] 158 | 159 | def showQA(self, anns): 160 | """ 161 | Display the specified annotations. 
162 | :param anns (array of object): annotations to display 163 | :return: None 164 | """ 165 | if len(anns) == 0: 166 | return 0 167 | for ann in anns: 168 | quesId = ann["question_id"] 169 | print("Question: %s" % (self.qqa[quesId]["question"])) 170 | for ans in ann["answers"]: 171 | print("Answer %d: %s" % (ans["answer_id"], ans["answer"])) 172 | 173 | def loadRes(self, resFile, quesFile): 174 | """ 175 | Load result file and return a result object. 176 | :param resFile (str) : file name of result file 177 | :return: res (obj) : result api object 178 | """ 179 | res = VQA() 180 | res.questions = json.load(open(quesFile)) 181 | res.dataset["info"] = copy.deepcopy(self.questions["info"]) 182 | res.dataset["task_type"] = copy.deepcopy(self.questions["task_type"]) 183 | res.dataset["data_type"] = copy.deepcopy(self.questions["data_type"]) 184 | res.dataset["data_subtype"] = copy.deepcopy(self.questions["data_subtype"]) 185 | res.dataset["license"] = copy.deepcopy(self.questions["license"]) 186 | 187 | print("Loading and preparing results... ") 188 | time_t = datetime.datetime.utcnow() 189 | anns = json.load(open(resFile)) 190 | assert type(anns) == list, "results is not an array of objects" 191 | annsQuesIds = [ann["question_id"] for ann in anns] 192 | assert set(annsQuesIds) == set( 193 | self.getQuesIds() 194 | ), "Results do not correspond to current VQA set. Either the results do not have predictions for all question ids in annotation file or there is atleast one question id that does not belong to the question ids in the annotation file." 
195 | for ann in anns: 196 | quesId = ann["question_id"] 197 | if res.dataset["task_type"] == "Multiple Choice": 198 | assert ( 199 | ann["answer"] in self.qqa[quesId]["multiple_choices"] 200 | ), "predicted answer is not one of the multiple choices" 201 | qaAnn = self.qa[quesId] 202 | ann["image_id"] = qaAnn["image_id"] 203 | ann["question_type"] = qaAnn["question_type"] 204 | ann["answer_type"] = qaAnn["answer_type"] 205 | print( 206 | "DONE (t=%0.2fs)" % ((datetime.datetime.utcnow() - time_t).total_seconds()) 207 | ) 208 | 209 | res.dataset["annotations"] = anns 210 | res.createIndex() 211 | return res 212 | -------------------------------------------------------------------------------- /vxverse/common/vqa_tools/vqa_eval.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) 2022, salesforce.com, inc. 3 | All rights reserved. 4 | SPDX-License-Identifier: BSD-3-Clause 5 | For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause 6 | """ 7 | 8 | # coding=utf-8 9 | 10 | __author__ = "aagrawal" 11 | 12 | # This code is based on the code written by Tsung-Yi Lin for MSCOCO Python API available at the following link: 13 | # (https://github.com/tylin/coco-caption/blob/master/pycocoevalcap/eval.py). 
import sys
import re


class VQAEval:
    def __init__(self, vqa=None, vqaRes=None, n=2):
        """
        Set up answer-normalization tables and result holders.
        :param vqa: ground-truth VQA api object; its getQuesIds() supplies the
                    default question ids (self.params is only set when vqa is given)
        :param vqaRes: result VQA api object (e.g. from VQA.loadRes)
        :param n (int): number of decimal places used when rounding accuracies
        """
        self.n = n
        self.accuracy = {}
        self.evalQA = {}
        self.evalQuesType = {}
        self.evalAnsType = {}
        self.vqa = vqa
        self.vqaRes = vqaRes
        if vqa is not None:
            self.params = {"question_id": vqa.getQuesIds()}
        # NOTE(review): kept byte-for-byte (including the odd "somebody'd" entry)
        # so scores stay comparable with the upstream VQA evaluation code.
        self.contractions = {
            "aint": "ain't",
            "arent": "aren't",
            "cant": "can't",
            "couldve": "could've",
            "couldnt": "couldn't",
            "couldn'tve": "couldn't've",
            "couldnt've": "couldn't've",
            "didnt": "didn't",
            "doesnt": "doesn't",
            "dont": "don't",
            "hadnt": "hadn't",
            "hadnt've": "hadn't've",
            "hadn'tve": "hadn't've",
            "hasnt": "hasn't",
            "havent": "haven't",
            "hed": "he'd",
            "hed've": "he'd've",
            "he'dve": "he'd've",
            "hes": "he's",
            "howd": "how'd",
            "howll": "how'll",
            "hows": "how's",
            "Id've": "I'd've",
            "I'dve": "I'd've",
            "Im": "I'm",
            "Ive": "I've",
            "isnt": "isn't",
            "itd": "it'd",
            "itd've": "it'd've",
            "it'dve": "it'd've",
            "itll": "it'll",
            "let's": "let's",
            "maam": "ma'am",
            "mightnt": "mightn't",
            "mightnt've": "mightn't've",
            "mightn'tve": "mightn't've",
            "mightve": "might've",
            "mustnt": "mustn't",
            "mustve": "must've",
            "neednt": "needn't",
            "notve": "not've",
            "oclock": "o'clock",
            "oughtnt": "oughtn't",
            "ow's'at": "'ow's'at",
            "'ows'at": "'ow's'at",
            "'ow'sat": "'ow's'at",
            "shant": "shan't",
            "shed've": "she'd've",
            "she'dve": "she'd've",
            "she's": "she's",
            "shouldve": "should've",
            "shouldnt": "shouldn't",
            "shouldnt've": "shouldn't've",
            "shouldn'tve": "shouldn't've",
            "somebody'd": "somebodyd",
            "somebodyd've": "somebody'd've",
            "somebody'dve": "somebody'd've",
            "somebodyll": "somebody'll",
            "somebodys": "somebody's",
            "someoned": "someone'd",
            "someoned've": "someone'd've",
            "someone'dve": "someone'd've",
            "someonell": "someone'll",
            "someones": "someone's",
            "somethingd": "something'd",
            "somethingd've": "something'd've",
            "something'dve": "something'd've",
            "somethingll": "something'll",
            "thats": "that's",
            "thered": "there'd",
            "thered've": "there'd've",
            "there'dve": "there'd've",
            "therere": "there're",
            "theres": "there's",
            "theyd": "they'd",
            "theyd've": "they'd've",
            "they'dve": "they'd've",
            "theyll": "they'll",
            "theyre": "they're",
            "theyve": "they've",
            "twas": "'twas",
            "wasnt": "wasn't",
            "wed've": "we'd've",
            "we'dve": "we'd've",
            "weve": "we've",
            "werent": "weren't",
            "whatll": "what'll",
            "whatre": "what're",
            "whats": "what's",
            "whatve": "what've",
            "whens": "when's",
            "whered": "where'd",
            "wheres": "where's",
            "whereve": "where've",
            "whod": "who'd",
            "whod've": "who'd've",
            "who'dve": "who'd've",
            "wholl": "who'll",
            "whos": "who's",
            "whove": "who've",
            "whyll": "why'll",
            "whyre": "why're",
            "whys": "why's",
            "wont": "won't",
            "wouldve": "would've",
            "wouldnt": "wouldn't",
            "wouldnt've": "wouldn't've",
            "wouldn'tve": "wouldn't've",
            "yall": "y'all",
            "yall'll": "y'all'll",
            "y'allll": "y'all'll",
            "yall'd've": "y'all'd've",
            "y'alld've": "y'all'd've",
            "y'all'dve": "y'all'd've",
            "youd": "you'd",
            "youd've": "you'd've",
            "you'dve": "you'd've",
            "youll": "you'll",
            "youre": "you're",
            "youve": "you've",
        }
        self.manualMap = {
            "none": "0",
            "zero": "0",
            "one": "1",
            "two": "2",
            "three": "3",
            "four": "4",
            "five": "5",
            "six": "6",
            "seven": "7",
            "eight": "8",
            "nine": "9",
            "ten": "10",
        }
        self.articles = ["a", "an", "the"]

        # Raw strings: the compiled patterns are identical to before, minus the
        # invalid "\d" / "\." escape-sequence warnings.
        # NOTE(review): "(?!<=\d)" is a negative *lookahead* for the literal "<="
        # plus a digit, not the lookbehind "(?<!\d)" that was probably intended.
        # Kept as-is because changing it would change reported accuracies.
        self.periodStrip = re.compile(r"(?!<=\d)(\.)(?!\d)")
        self.commaStrip = re.compile(r"(\d)(,)(\d)")
        self.punct = [
            ";",
            r"/",
            "[",
            "]",
            '"',
            "{",
            "}",
            "(",
            ")",
            "=",
            "+",
            "\\",
            "_",
            "-",
            ">",
            "<",
            "@",
            "`",
            ",",
            "?",
            "!",
        ]

    def evaluate(self, quesIds=None):
        """
        Score self.vqaRes answers against self.vqa ground truth.

        For each question the predicted answer is normalized, then compared with
        every leave-one-out subset of the ground-truth answers; a subset scores
        min(1, matches / 3) and the question's accuracy is the mean over subsets.
        Per-question results go through setEvalQA / setEvalQuesType /
        setEvalAnsType, and totals through setAccuracy.

        Side effect: when a question's ground-truth answers disagree, they are
        punctuation-normalized in place (inside self.vqa.qa).
        :param quesIds (list): question ids to score; defaults to self.params["question_id"]
        """
        if quesIds is None:
            quesIds = list(self.params["question_id"])
        gts = {}
        res = {}
        for quesId in quesIds:
            gts[quesId] = self.vqa.qa[quesId]
            res[quesId] = self.vqaRes.qa[quesId]

        # =================================================
        # Compute accuracy
        # =================================================
        accQA = []
        accQuesType = {}
        accAnsType = {}
        print("computing accuracy")
        numQues = len(quesIds)
        for step, quesId in enumerate(quesIds):
            resAns = res[quesId]["answer"]
            resAns = resAns.replace("\n", " ").replace("\t", " ").strip()
            resAns = self.processPunctuation(resAns)
            resAns = self.processDigitArticle(resAns)
            gtAnswers = [ans["answer"] for ans in gts[quesId]["answers"]]
            if len(set(gtAnswers)) > 1:
                for ansDic in gts[quesId]["answers"]:
                    ansDic["answer"] = self.processPunctuation(ansDic["answer"])
            gtAcc = []
            for gtAnsDatum in gts[quesId]["answers"]:
                otherGTAns = [
                    item for item in gts[quesId]["answers"] if item != gtAnsDatum
                ]
                matchingAns = [item for item in otherGTAns if item["answer"] == resAns]
                gtAcc.append(min(1, float(len(matchingAns)) / 3))
            quesType = gts[quesId]["question_type"]
            ansType = gts[quesId]["answer_type"]
            avgGTAcc = float(sum(gtAcc)) / len(gtAcc)
            accQA.append(avgGTAcc)
            accQuesType.setdefault(quesType, []).append(avgGTAcc)
            accAnsType.setdefault(ansType, []).append(avgGTAcc)
            self.setEvalQA(quesId, avgGTAcc)
            self.setEvalQuesType(quesId, quesType, avgGTAcc)
            self.setEvalAnsType(quesId, ansType, avgGTAcc)
            if step % 100 == 0:
                self.updateProgress(step / float(numQues))

        self.setAccuracy(accQA, accQuesType, accAnsType)
        print("Done computing accuracy")

    def processPunctuation(self, inText):
        """
        Normalize punctuation in an answer string.

        Each character in self.punct is deleted when it touches whitespace in the
        input (or when the input has a digit-comma-digit group), and is replaced
        by a space otherwise. Periods matched by self.periodStrip are then removed.
        :param inText (str): answer text
        :return: outText (str): normalized text
        """
        # Loop-invariant: every test below inspects the *original* text.
        hasCommaNumber = self.commaStrip.search(inText) is not None
        outText = inText
        for p in self.punct:
            if (p + " " in inText or " " + p in inText) or hasCommaNumber:
                outText = outText.replace(p, "")
            else:
                outText = outText.replace(p, " ")
        # Previously re.UNICODE (== 32) was passed positionally, which Pattern.sub
        # treats as `count`, silently capping the number of stripped periods at 32.
        outText = self.periodStrip.sub("", outText)
        return outText

    def processDigitArticle(self, inText):
        """
        Lower-case the text, map number words to digits, drop articles and
        expand known contractions.
        :param inText (str): answer text
        :return: outText (str): normalized words joined by single spaces
        """
        outText = []
        for word in inText.lower().split():
            # .get instead of .setdefault: setdefault inserted every unseen word
            # into self.manualMap, growing it for the whole evaluation run.
            word = self.manualMap.get(word, word)
            if word not in self.articles:
                outText.append(word)
        outText = [self.contractions.get(word, word) for word in outText]
        return " ".join(outText)

    def setAccuracy(self, accQA, accQuesType, accAnsType):
        """Store overall, per-question-type and per-answer-type accuracy (percent, rounded to self.n places)."""
        self.accuracy["overall"] = round(100 * float(sum(accQA)) / len(accQA), self.n)
        self.accuracy["perQuestionType"] = {
            quesType: round(
                100 * float(sum(accQuesType[quesType])) / len(accQuesType[quesType]),
                self.n,
            )
            for quesType in accQuesType
        }
        self.accuracy["perAnswerType"] = {
            ansType: round(
                100 * float(sum(accAnsType[ansType])) / len(accAnsType[ansType]), self.n
            )
            for ansType in accAnsType
        }

    def setEvalQA(self, quesId, acc):
        """Record one question's accuracy as a percentage rounded to self.n places."""
        self.evalQA[quesId] = round(100 * acc, self.n)
295 | def setEvalQuesType(self, quesId, quesType, acc): 296 | if quesType not in self.evalQuesType: 297 | self.evalQuesType[quesType] = {} 298 | self.evalQuesType[quesType][quesId] = round(100 * acc, self.n) 299 | 300 | def setEvalAnsType(self, quesId, ansType, acc): 301 | if ansType not in self.evalAnsType: 302 | self.evalAnsType[ansType] = {} 303 | self.evalAnsType[ansType][quesId] = round(100 * acc, self.n) 304 | 305 | def updateProgress(self, progress): 306 | barLength = 20 307 | status = "" 308 | if isinstance(progress, int): 309 | progress = float(progress) 310 | if not isinstance(progress, float): 311 | progress = 0 312 | status = "error: progress var must be float\r\n" 313 | if progress < 0: 314 | progress = 0 315 | status = "Halt...\r\n" 316 | if progress >= 1: 317 | progress = 1 318 | status = "Done...\r\n" 319 | block = int(round(barLength * progress)) 320 | text = "\rFinshed Percent: [{0}] {1}% {2}".format( 321 | "#" * block + "-" * (barLength - block), int(progress * 100), status 322 | ) 323 | sys.stdout.write(text) 324 | sys.stdout.flush() 325 | -------------------------------------------------------------------------------- /vxverse/configs/Qformer/bert-base-uncased/config.json: -------------------------------------------------------------------------------- 1 | { 2 | "architectures": [ 3 | "BertForMaskedLM" 4 | ], 5 | "attention_probs_dropout_prob": 0.1, 6 | "gradient_checkpointing": false, 7 | "hidden_act": "gelu", 8 | "hidden_dropout_prob": 0.1, 9 | "hidden_size": 768, 10 | "initializer_range": 0.02, 11 | "intermediate_size": 3072, 12 | "layer_norm_eps": 1e-12, 13 | "max_position_embeddings": 512, 14 | "model_type": "bert", 15 | "num_attention_heads": 12, 16 | "num_hidden_layers": 12, 17 | "pad_token_id": 0, 18 | "position_embedding_type": "absolute", 19 | "transformers_version": "4.6.0.dev0", 20 | "type_vocab_size": 2, 21 | "use_cache": true, 22 | "vocab_size": 30522 23 | } 24 | 
-------------------------------------------------------------------------------- /vxverse/configs/Qformer/bert-base-uncased/tokenizer_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "do_lower_case": true 3 | } 4 | -------------------------------------------------------------------------------- /vxverse/configs/datasets/align/align.yaml: -------------------------------------------------------------------------------- 1 | datasets: 2 | align: 3 | data_type: images 4 | build_info: 5 | storage: /path/to/triple/ 6 | end_sym: "<|endoftext|>" 7 | max_context_len: 500 8 | max_txt_len: 500 9 | max_seq_len: 992 10 | datasets: 11 | AOK-VQA: 12 | ChineseFoodTriple: 13 | CLEVR: 14 | CLEVR_CoGenT: 15 | CUB-200-2011: 16 | DAQUAR: 17 | LRV-Instruction: 18 | LRV-Instruction_chart: 19 | OCR-VQA: 20 | OK-VQA: 21 | PathVQA: 22 | PMC-CaseReport: 23 | ScienceQA: 24 | sketch: 25 | Slake: 26 | ST-VQA: 27 | VisDial: 28 | VQA-RAD: 29 | VQAv2: 30 | VQAv2_AS: 31 | VQAv2_BBAS: 32 | IconQA: 33 | TextVQA: 34 | ShareGPT4V_instruct: 35 | ShareGPT4V_mix: 36 | ShareGPT4V_captioner: 37 | miniGPT4_CCS_Align: 38 | LLaVA_V1_5_mix665k: 39 | LLaVA_V1_5_instruct158k: 40 | Geo170K_alignment: 41 | Geo170K_qa_tuning: 42 | GeoQA: 43 | GeoQA+_length: 44 | GeoQA+_angle: 45 | GeoQA+_area: 46 | VisualGenome: 47 | xGQA: 48 | GQA: 49 | DVQA: 50 | ChartQA: 51 | AI2D: 52 | DocVQA: 53 | SynthDoG-EN: 54 | VizWiz_VQA: 55 | CLEVR-math: 56 | LLaVA-Instruct-150K_ZH: 57 | PlotQA: 58 | PISC: 59 | ShapeWorld: 60 | UCF101: 61 | HMDB51: 62 | FER-2013: 63 | ArxivQA: 64 | brainscape-std: 65 | shitishuxe-std: 66 | examcoo-std: 67 | KonIQ-10k: 68 | pure_text: 69 | 70 | 71 | 72 | 73 | 74 | 75 | 76 | -------------------------------------------------------------------------------- /vxverse/configs/datasets/align/ccs_sub.yaml: -------------------------------------------------------------------------------- 1 | datasets: 2 | laion: 3 | data_type: images 4 | build_info: 5 | storage: 
/path/to/webdataset/{00000..05421}.tar 6 | -------------------------------------------------------------------------------- /vxverse/configs/datasets/align/defaults.yaml: -------------------------------------------------------------------------------- 1 | datasets: 2 | cc_sbu: 3 | data_type: images 4 | build_info: 5 | storage: /path/to/ccs/webdataset/{00000..05421}.tar 6 | -------------------------------------------------------------------------------- /vxverse/configs/datasets/align_hd/align_hd.yaml: -------------------------------------------------------------------------------- 1 | datasets: 2 | align_hd: 3 | data_type: images 4 | build_info: 5 | storage: /path/to/triple/ 6 | end_sym: "<|endoftext|>" 7 | max_seq_len: 2048 8 | datasets: 9 | dataset1: 10 | dataset2 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | -------------------------------------------------------------------------------- /vxverse/configs/datasets/cc_sbu/align.yaml: -------------------------------------------------------------------------------- 1 | datasets: 2 | cc_sbu_align: 3 | data_type: images 4 | build_info: 5 | storage: /path/to/dataset 6 | -------------------------------------------------------------------------------- /vxverse/configs/datasets/cc_sbu/ccs_sub.yaml: -------------------------------------------------------------------------------- 1 | datasets: 2 | laion: 3 | data_type: images 4 | build_info: 5 | with_instruction: True 6 | storage: /path/to/ccs/webdataset/{00000..05421}.tar 7 | -------------------------------------------------------------------------------- /vxverse/configs/datasets/cc_sbu/defaults.yaml: -------------------------------------------------------------------------------- 1 | datasets: 2 | cc_sbu: 3 | data_type: images 4 | build_info: 5 | storage: /path/to/ccs/webdataset/{00000..05421}.tar 6 | -------------------------------------------------------------------------------- /vxverse/configs/datasets/gqa/balanced_val.yaml: 
-------------------------------------------------------------------------------- 1 | datasets: 2 | gqa: 3 | # data_dir: ${env.data_dir}/datasets 4 | data_type: images # [images|videos|features] 5 | 6 | build_info: 7 | # Be careful not to append minus sign (-) before split to avoid itemizing 8 | annotations: 9 | train: 10 | url: 11 | - https://storage.googleapis.com/sfr-vision-language-research/LAVIS/datasets/gqa/train_balanced_questions.json 12 | storage: 13 | - /path/to/gqa/train_balanced_questions.json 14 | 15 | images: 16 | storage: /path/to/gqa/images 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | -------------------------------------------------------------------------------- /vxverse/configs/deepspeed/ds.json: -------------------------------------------------------------------------------- 1 | { 2 | "fp16": { 3 | "enabled": "auto", 4 | "loss_scale": 0, 5 | "loss_scale_window": 1000, 6 | "initial_scale_power": 16, 7 | "hysteresis": 2, 8 | "min_loss_scale": 1 9 | }, 10 | "zero_optimization": { 11 | "stage": 3, 12 | "offload_optimizer": { 13 | "device": "cpu", 14 | "pin_memory": true 15 | } 16 | }, 17 | "steps_per_print": 50, 18 | "train_micro_batch_size_per_gpu": 1, 19 | "wall_clock_breakdown": false 20 | } -------------------------------------------------------------------------------- /vxverse/configs/default.yaml: -------------------------------------------------------------------------------- 1 | env: 2 | # For default users 3 | # cache_root: "cache" 4 | # For internal use with persistent storage 5 | cache_root: "/export/home/.cache/vxverse" 6 | 7 | -------------------------------------------------------------------------------- /vxverse/configs/eva/EVA02-CLIP-bigE-14-336.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 1024, 3 | "vision_cfg": { 4 | "image_size": 336, 5 | "layers": 64, 6 | "width": 1792, 7 | "head_width": 112, 8 | "mlp_ratio": 8.571428571428571, 9 | "patch_size": 14, 10 | 
"eva_model_name": "eva-clip-4b-14-x", 11 | "drop_path_rate": 0, 12 | "xattn": true, 13 | "postnorm": true, 14 | "fusedLN": true 15 | }, 16 | "text_cfg": { 17 | "context_length": 77, 18 | "vocab_size": 49408, 19 | "width": 1024, 20 | "heads": 16, 21 | "layers": 24, 22 | "xattn": false, 23 | "fusedLN": true 24 | } 25 | } -------------------------------------------------------------------------------- /vxverse/configs/models/vxverse_13bchat.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | arch: vxverse 3 | 4 | llama_model: "./XVERSE-13B-Chat-latest" 5 | 6 | -------------------------------------------------------------------------------- /vxverse/configs/models/vxverse_65bchat.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | arch: vxverse 3 | 4 | llama_model: "./XVERSE-65B-Chat-latest" 5 | 6 | -------------------------------------------------------------------------------- /vxverse/configs/models/vxverse_7bchat.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | arch: vxverse 3 | 4 | llama_model: "./XVERSE-7B-Chat" 5 | 6 | -------------------------------------------------------------------------------- /vxverse/conversation/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xverse-ai/XVERSE-V-13B/989da861ac64dfb730da0a872f5057372cdad4f5/vxverse/conversation/__init__.py -------------------------------------------------------------------------------- /vxverse/conversation/conversation.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import time 3 | from threading import Thread 4 | from PIL import Image 5 | 6 | import torch 7 | from transformers import AutoTokenizer, AutoModelForCausalLM, LlamaTokenizer 8 | from transformers import StoppingCriteria, StoppingCriteriaList, 
TextIteratorStreamer 9 | 10 | import dataclasses 11 | from enum import auto, Enum 12 | from typing import List, Tuple, Any 13 | 14 | # from vxverse.common.registry import registry 15 | 16 | 17 | class SeparatorStyle(Enum): 18 | """Different separator style.""" 19 | SINGLE = auto() 20 | TWO = auto() 21 | 22 | 23 | @dataclasses.dataclass 24 | class Conversation: 25 | """A class that keeps all conversation history.""" 26 | system: str 27 | roles: List[str] 28 | messages: List[List[str]] 29 | offset: int 30 | # system_img: List[Image.Image] = [] 31 | sep_style: SeparatorStyle = SeparatorStyle.SINGLE 32 | sep: str = "###" 33 | sep2: str = None 34 | 35 | skip_next: bool = False 36 | conv_id: Any = None 37 | 38 | def get_prompt(self): 39 | if self.sep_style == SeparatorStyle.SINGLE: 40 | ret = self.system + self.sep 41 | for role, message in self.messages: 42 | if message: 43 | ret += role + message + self.sep 44 | else: 45 | ret += role 46 | return ret 47 | elif self.sep_style == SeparatorStyle.TWO: 48 | seps = [self.sep, self.sep2] 49 | ret = self.system + seps[0] 50 | for i, (role, message) in enumerate(self.messages): 51 | if message: 52 | ret += role + message + seps[i % 2] 53 | else: 54 | ret += role 55 | return ret 56 | else: 57 | raise ValueError(f"Invalid style: {self.sep_style}") 58 | 59 | def append_message(self, role, message): 60 | self.messages.append([role, message]) 61 | 62 | def to_gradio_chatbot(self): 63 | ret = [] 64 | for i, (role, msg) in enumerate(self.messages[self.offset:]): 65 | if i % 2 == 0: 66 | ret.append([msg, None]) 67 | else: 68 | ret[-1][-1] = msg 69 | return ret 70 | 71 | def copy(self): 72 | return Conversation( 73 | system=self.system, 74 | # system_img=self.system_img, 75 | roles=self.roles, 76 | messages=[[x, y] for x, y in self.messages], 77 | offset=self.offset, 78 | sep_style=self.sep_style, 79 | sep=self.sep, 80 | sep2=self.sep2, 81 | conv_id=self.conv_id) 82 | 83 | def dict(self): 84 | return { 85 | "system": self.system, 86 | # 
"system_img": self.system_img, 87 | "roles": self.roles, 88 | "messages": self.messages, 89 | "offset": self.offset, 90 | "sep": self.sep, 91 | "sep2": self.sep2, 92 | "conv_id": self.conv_id, 93 | } 94 | 95 | 96 | class StoppingCriteriaSub(StoppingCriteria): 97 | 98 | def __init__(self, stops=[], encounters=1): 99 | super().__init__() 100 | self.stops = stops 101 | 102 | def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor): 103 | for stop in self.stops: 104 | if torch.all(input_ids[:, -len(stop):] == stop).item(): 105 | return True 106 | 107 | return False 108 | 109 | 110 | CONV_VISION_Vicuna0 = Conversation( 111 | system="Give the following image: ImageContent. " 112 | "You will be able to see the image once I provide it to you. Please answer my questions.", 113 | roles=("Human: ", "Assistant: "), 114 | messages=[], 115 | offset=2, 116 | sep_style=SeparatorStyle.SINGLE, 117 | sep="###", 118 | ) 119 | 120 | CONV_VISION_Vicuna1 = Conversation( 121 | system="A chat between a curious user and an artificial intelligence assistant. " 122 | "The assistant gives helpful, detailed, and polite answers to the user's questions.", 123 | roles=("USER: ", "ASSISTANT: "), 124 | messages=[], 125 | offset=0, 126 | sep_style=SeparatorStyle.TWO, 127 | sep=" ", 128 | sep2="", 129 | ) 130 | 131 | CONV_VISION_XVERSE = Conversation( 132 | system="你是元象开发的具有图文问答能力的XChat,一旦给定你一张图片,你可以依据图片的内容,详细地回答我的问题。", 133 | roles=("Human: ", "Assistant: "), 134 | messages=[], 135 | offset=2, 136 | sep_style=SeparatorStyle.TWO, 137 | sep="\n", 138 | sep2="<|endoftext|>" 139 | ) 140 | 141 | 142 | CONV_VISION_LLama2 = Conversation( 143 | system="Give the following image: ImageContent. " 144 | "You will be able to see the image once I provide it to you. 
Please answer my questions.", 145 | roles=("[INST] ", " [/INST] "), 146 | messages=[], 147 | offset=2, 148 | sep_style=SeparatorStyle.SINGLE, 149 | sep="", 150 | ) 151 | 152 | CONV_VISION_minigptv2 = Conversation( 153 | system="", 154 | roles=("[INST] ", " [/INST]"), 155 | messages=[], 156 | offset=2, 157 | sep_style=SeparatorStyle.SINGLE, 158 | sep="", 159 | ) 160 | 161 | class Chat: 162 | def __init__(self, model, vis_processor, device='cuda:0', stopping_criteria=None, vis_processor_name=None): 163 | self.device = device 164 | self.model = model 165 | self.vis_processor = vis_processor 166 | self.vis_processor_name = vis_processor_name 167 | if stopping_criteria is not None: 168 | self.stopping_criteria = stopping_criteria 169 | else: 170 | stop_words_ids = [torch.tensor([2]).to(self.device)] 171 | self.stopping_criteria = StoppingCriteriaList([StoppingCriteriaSub(stops=stop_words_ids)]) 172 | 173 | def ask(self, text, conv): 174 | if len(conv.messages) > 0 and conv.messages[-1][0] == conv.roles[0] \ 175 | and "\n" in conv.messages[-1][1]: # last message is image. 
176 | conv.messages[-1][1] = conv.messages[-1][1] + text 177 | else: 178 | conv.append_message(conv.roles[0], text) 179 | 180 | def answer_prepare(self, conv, img_list, max_new_tokens=300, num_beams=1, min_length=1, top_p=0.9, 181 | repetition_penalty=1.05, length_penalty=1, temperature=1.0, max_length=2000, top_k=30, do_sample=True): 182 | conv.append_message(conv.roles[1], None) 183 | prompt = conv.get_prompt() 184 | print("##############") 185 | print("prompt:{}".format(prompt)) 186 | patches_per_image = [] 187 | if self.vis_processor_name=="hd_image_train": 188 | if img_list!=None: 189 | for img in img_list: 190 | patches_per_image.append(img.shape[0]) 191 | embs = self.model.get_context_emb(prompt, img_list[0], patches_per_image=patches_per_image, 192 | device=self.device) 193 | else: 194 | patches_per_image=None 195 | embs = self.model.get_context_emb(prompt, img_list, patches_per_image=patches_per_image, 196 | device=self.device) 197 | else: 198 | patches_per_image = None 199 | embs = self.model.get_context_emb(prompt, img_list, patches_per_image=patches_per_image, device=self.device) 200 | 201 | current_max_len = embs.shape[1] + max_new_tokens 202 | if current_max_len - max_length > 0: 203 | print('Warning: The number of tokens in current conversation exceeds the max length. 
' 204 | 'The model will not see the contexts outside the range.') 205 | begin_idx = max(0, current_max_len - max_length) 206 | embs = embs[:, begin_idx:] 207 | 208 | generation_kwargs = dict( 209 | inputs_embeds=embs, 210 | max_new_tokens=max_new_tokens, 211 | stopping_criteria=self.stopping_criteria, 212 | num_beams=num_beams, 213 | do_sample=do_sample, 214 | min_length=min_length, 215 | top_p=top_p, 216 | top_k=top_k, 217 | repetition_penalty=repetition_penalty, 218 | length_penalty=length_penalty, 219 | temperature=float(temperature), 220 | ) 221 | return generation_kwargs 222 | 223 | def answer(self, conv, img_list, stop_sign, **kargs): 224 | generation_dict = self.answer_prepare(conv, img_list, **kargs) 225 | output_token = self.model_generate(**generation_dict)[0] 226 | output_text = self.model.llama_tokenizer.decode(output_token, skip_special_tokens=True) 227 | 228 | output_text = output_text.split(stop_sign)[0] # remove the stop sign '###' 229 | output_text = output_text.split('Assistant:')[-1].strip() 230 | 231 | conv.messages[-1][1] = output_text 232 | return output_text, output_token.cpu().numpy() 233 | 234 | def stream_answer(self, conv, img_list, **kargs): 235 | generation_kwargs = self.answer_prepare(conv, img_list, **kargs) 236 | streamer = TextIteratorStreamer(self.model.llama_tokenizer, skip_special_tokens=True) 237 | generation_kwargs['streamer'] = streamer 238 | thread = Thread(target=self.model_generate, kwargs=generation_kwargs) 239 | thread.start() 240 | return streamer 241 | 242 | def model_generate(self, *args, **kwargs): 243 | # for 8 bit and 16 bit compatibility 244 | with self.model.maybe_autocast(self.model.llm_torch_dtype): 245 | output = self.model.llama_model.generate(*args, **kwargs) 246 | return output 247 | 248 | def encode_img(self, img_list): 249 | image = img_list[0] 250 | img_list.pop(0) 251 | 252 | if self.vis_processor_name=="hd_image_train": 253 | if isinstance(image, str): # is a image path 254 | raw_image = 
Image.open(image).convert('RGB') 255 | image = self.vis_processor(raw_image).to(self.device) 256 | elif isinstance(image, Image.Image): 257 | raw_image = image 258 | image = self.vis_processor(raw_image).to(self.device) 259 | elif isinstance(image, torch.Tensor): 260 | if len(image.shape) == 3: 261 | image = image.unsqueeze(0) 262 | image = image.to(self.device) 263 | patches_per_image = [[image.shape[0]]] 264 | image = [image] 265 | else: 266 | if isinstance(image, str): # is a image path 267 | raw_image = Image.open(image).convert('RGB') 268 | image = self.vis_processor(raw_image).unsqueeze(0).to(self.device) 269 | elif isinstance(image, Image.Image): 270 | raw_image = image 271 | image = self.vis_processor(raw_image).unsqueeze(0).to(self.device) 272 | elif isinstance(image, torch.Tensor): 273 | if len(image.shape) == 3: 274 | image = image.unsqueeze(0) 275 | image = image.to(self.device) 276 | patches_per_image = None 277 | image_emb, _ = self.model.encode_img(image) 278 | if type(image_emb) == list: 279 | img_list.extend(image_emb) 280 | else: 281 | img_list.append(image_emb) 282 | 283 | def upload_img(self, image, conv, img_list): 284 | conv.append_message(conv.roles[0], "\n") 285 | img_list.append(image) 286 | msg = "Received." 287 | 288 | return msg 289 | 290 | -------------------------------------------------------------------------------- /vxverse/datasets/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xverse-ai/XVERSE-V-13B/989da861ac64dfb730da0a872f5057372cdad4f5/vxverse/datasets/__init__.py -------------------------------------------------------------------------------- /vxverse/datasets/builders/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) 2022, salesforce.com, inc. 3 | All rights reserved. 
4 | SPDX-License-Identifier: BSD-3-Clause 5 | For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause 6 | """ 7 | 8 | from vxverse.datasets.builders.base_dataset_builder import load_dataset_config 9 | from vxverse.datasets.builders.image_text_pair_builder import ( 10 | GQABuilder, 11 | # "CCSBUAlignBuilder" 12 | ) 13 | from vxverse.common.registry import registry 14 | 15 | __all__ = [ 16 | "GQABuilder", 17 | # "LaionBuilder", 18 | # "CCSBUAlignBuilder" 19 | ] 20 | 21 | 22 | def load_dataset(name, cfg_path=None, vis_path=None, data_type=None): 23 | """ 24 | Example 25 | 26 | >>> dataset = load_dataset("coco_caption", cfg=None) 27 | >>> splits = dataset.keys() 28 | >>> print([len(dataset[split]) for split in splits]) 29 | 30 | """ 31 | if cfg_path is None: 32 | cfg = None 33 | else: 34 | cfg = load_dataset_config(cfg_path) 35 | 36 | try: 37 | builder = registry.get_builder_class(name)(cfg) 38 | except TypeError: 39 | print( 40 | f"Dataset {name} not found. Available datasets:\n" 41 | + ", ".join([str(k) for k in dataset_zoo.get_names()]) 42 | ) 43 | exit(1) 44 | 45 | if vis_path is not None: 46 | if data_type is None: 47 | # use default data type in the config 48 | data_type = builder.config.data_type 49 | 50 | assert ( 51 | data_type in builder.config.build_info 52 | ), f"Invalid data_type {data_type} for {name}." 
53 | 54 | builder.config.build_info.get(data_type).storage = vis_path 55 | 56 | dataset = builder.build_datasets() 57 | return dataset 58 | 59 | 60 | class DatasetZoo: 61 | def __init__(self) -> None: 62 | self.dataset_zoo = { 63 | k: list(v.DATASET_CONFIG_DICT.keys()) 64 | for k, v in sorted(registry.mapping["builder_name_mapping"].items()) 65 | } 66 | 67 | def get_names(self): 68 | return list(self.dataset_zoo.keys()) 69 | 70 | 71 | dataset_zoo = DatasetZoo() 72 | -------------------------------------------------------------------------------- /vxverse/datasets/builders/base_dataset_builder.py: -------------------------------------------------------------------------------- 1 | """ 2 | This file is from 3 | Copyright (c) 2022, salesforce.com, inc. 4 | All rights reserved. 5 | SPDX-License-Identifier: BSD-3-Clause 6 | For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause 7 | """ 8 | 9 | import logging 10 | import os 11 | import shutil 12 | import warnings 13 | 14 | from omegaconf import OmegaConf 15 | import torch.distributed as dist 16 | from torchvision.datasets.utils import download_url 17 | 18 | import vxverse.common.utils as utils 19 | from vxverse.common.dist_utils import is_dist_avail_and_initialized, is_main_process 20 | from vxverse.common.registry import registry 21 | from vxverse.processors.base_processor import BaseProcessor 22 | 23 | 24 | 25 | class BaseDatasetBuilder: 26 | train_dataset_cls, eval_dataset_cls = None, None 27 | 28 | def __init__(self, cfg=None): 29 | super().__init__() 30 | if cfg is None: 31 | # help to create datasets from default config. 
32 | self.config = load_dataset_config(self.default_config_path()) 33 | elif isinstance(cfg, str): 34 | self.config = load_dataset_config(cfg) 35 | else: 36 | # when called from task.build_dataset() 37 | self.config = cfg 38 | 39 | self.data_type = self.config.data_type 40 | self.vis_processors = {"train": BaseProcessor(), "eval": BaseProcessor()} 41 | self.text_processors = {"train": BaseProcessor(), "eval": BaseProcessor()} 42 | 43 | def build_datasets(self): 44 | # download, split, etc... 45 | # only called on 1 GPU/TPU in distributed 46 | 47 | if is_main_process(): 48 | self._download_data() 49 | 50 | # May cause [E ProcessGroupNCCL.cpp:455] Some NCCL operations have failed 51 | if is_dist_avail_and_initialized(): 52 | dist.barrier() 53 | # at this point, all the annotations and image/videos should be all downloaded to the specified locations. 54 | print("Building datasets...") 55 | logging.info("Building datasets...") 56 | datasets = self.build() # dataset['train'/'val'/'test'] 57 | 58 | return datasets 59 | 60 | def build_processors(self): 61 | vis_proc_cfg = self.config.get("vis_processor") 62 | txt_proc_cfg = self.config.get("text_processor") 63 | 64 | if vis_proc_cfg is not None: 65 | vis_train_cfg = vis_proc_cfg.get("train") 66 | vis_eval_cfg = vis_proc_cfg.get("eval") 67 | self.vis_processors["train"] = self._build_proc_from_cfg(vis_train_cfg) 68 | self.vis_processors["eval"] = self._build_proc_from_cfg(vis_eval_cfg) 69 | 70 | if txt_proc_cfg is not None: 71 | txt_train_cfg = txt_proc_cfg.get("train") 72 | txt_eval_cfg = txt_proc_cfg.get("eval") 73 | 74 | self.text_processors["train"] = self._build_proc_from_cfg(txt_train_cfg) 75 | self.text_processors["eval"] = self._build_proc_from_cfg(txt_eval_cfg) 76 | 77 | @staticmethod 78 | def _build_proc_from_cfg(cfg): 79 | return ( 80 | registry.get_processor_class(cfg.name).from_config(cfg) 81 | if cfg is not None 82 | else None 83 | ) 84 | 85 | @classmethod 86 | def default_config_path(cls, type="default"): 87 
| return utils.get_abs_path(cls.DATASET_CONFIG_DICT[type]) 88 | 89 | def _download_data(self): 90 | self._download_ann() 91 | self._download_vis() 92 | 93 | def _download_ann(self): 94 | """ 95 | Download annotation files if necessary. 96 | All the vision-language datasets should have annotations of unified format. 97 | 98 | storage_path can be: 99 | (1) relative/absolute: will be prefixed with env.cache_root to make full path if relative. 100 | (2) basename/dirname: will be suffixed with base name of URL if dirname is provided. 101 | 102 | Local annotation paths should be relative. 103 | """ 104 | anns = self.config.build_info.annotations 105 | 106 | splits = anns.keys() 107 | 108 | cache_root = registry.get_path("cache_root") 109 | 110 | for split in splits: 111 | info = anns[split] 112 | 113 | urls, storage_paths = info.get("url", None), info.storage 114 | 115 | if isinstance(urls, str): 116 | urls = [urls] 117 | if isinstance(storage_paths, str): 118 | storage_paths = [storage_paths] 119 | 120 | assert len(urls) == len(storage_paths) 121 | 122 | for url_or_filename, storage_path in zip(urls, storage_paths): 123 | # if storage_path is relative, make it full by prefixing with cache_root. 124 | if not os.path.isabs(storage_path): 125 | storage_path = os.path.join(cache_root, storage_path) 126 | 127 | dirname = os.path.dirname(storage_path) 128 | if not os.path.exists(dirname): 129 | os.makedirs(dirname) 130 | 131 | if os.path.isfile(url_or_filename): 132 | src, dst = url_or_filename, storage_path 133 | if not os.path.exists(dst): 134 | shutil.copyfile(src=src, dst=dst) 135 | else: 136 | logging.info("Using existing file {}.".format(dst)) 137 | else: 138 | if os.path.isdir(storage_path): 139 | # if only dirname is provided, suffix with basename of URL. 
140 | raise ValueError( 141 | "Expecting storage_path to be a file path, got directory {}".format( 142 | storage_path 143 | ) 144 | ) 145 | else: 146 | filename = os.path.basename(storage_path) 147 | 148 | download_url(url=url_or_filename, root=dirname, filename=filename) 149 | 150 | def _download_vis(self): 151 | 152 | storage_path = self.config.build_info.get(self.data_type).storage 153 | storage_path = utils.get_cache_path(storage_path) 154 | 155 | if not os.path.exists(storage_path): 156 | warnings.warn( 157 | f""" 158 | The specified path {storage_path} for visual inputs does not exist. 159 | Please provide a correct path to the visual inputs or 160 | refer to datasets/download_scripts/README.md for downloading instructions. 161 | """ 162 | ) 163 | 164 | def build(self): 165 | """ 166 | Create by split datasets inheriting torch.utils.data.Datasets. 167 | 168 | # build() can be dataset-specific. Overwrite to customize. 169 | """ 170 | self.build_processors() 171 | 172 | build_info = self.config.build_info 173 | 174 | ann_info = build_info.annotations 175 | vis_info = build_info.get(self.data_type) 176 | 177 | datasets = dict() 178 | for split in ann_info.keys(): 179 | if split not in ["train", "val", "test"]: 180 | continue 181 | 182 | is_train = split == "train" 183 | 184 | # processors 185 | vis_processor = ( 186 | self.vis_processors["train"] 187 | if is_train 188 | else self.vis_processors["eval"] 189 | ) 190 | text_processor = ( 191 | self.text_processors["train"] 192 | if is_train 193 | else self.text_processors["eval"] 194 | ) 195 | 196 | # annotation path 197 | ann_paths = ann_info.get(split).storage 198 | if isinstance(ann_paths, str): 199 | ann_paths = [ann_paths] 200 | 201 | abs_ann_paths = [] 202 | for ann_path in ann_paths: 203 | if not os.path.isabs(ann_path): 204 | ann_path = utils.get_cache_path(ann_path) 205 | abs_ann_paths.append(ann_path) 206 | ann_paths = abs_ann_paths 207 | 208 | # visual data storage path 209 | vis_path = 
os.path.join(vis_info.storage, split) 210 | 211 | if not os.path.isabs(vis_path): 212 | # vis_path = os.path.join(utils.get_cache_path(), vis_path) 213 | vis_path = utils.get_cache_path(vis_path) 214 | 215 | if not os.path.exists(vis_path): 216 | warnings.warn("storage path {} does not exist.".format(vis_path)) 217 | 218 | # create datasets 219 | dataset_cls = self.train_dataset_cls if is_train else self.eval_dataset_cls 220 | datasets[split] = dataset_cls( 221 | vis_processor=vis_processor, 222 | text_processor=text_processor, 223 | ann_paths=ann_paths, 224 | vis_root=vis_path, 225 | ) 226 | 227 | return datasets 228 | 229 | 230 | def load_dataset_config(cfg_path): 231 | cfg = OmegaConf.load(cfg_path).datasets 232 | cfg = cfg[list(cfg.keys())[0]] 233 | 234 | return cfg 235 | -------------------------------------------------------------------------------- /vxverse/datasets/builders/image_text_pair_builder.py: -------------------------------------------------------------------------------- 1 | import os 2 | import logging 3 | import warnings 4 | 5 | from vxverse.common.registry import registry 6 | from vxverse.datasets.builders.base_dataset_builder import BaseDatasetBuilder 7 | from vxverse.datasets.datasets.gqa_datasets import GQADataset 8 | 9 | 10 | 11 | @registry.register_builder("gqa") 12 | class GQABuilder(BaseDatasetBuilder): 13 | train_dataset_cls = GQADataset 14 | DATASET_CONFIG_DICT = { 15 | "default": "configs/datasets/gqa/balanced_val.yaml", 16 | } 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | 27 | -------------------------------------------------------------------------------- /vxverse/datasets/data_utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) 2022, salesforce.com, inc. 3 | All rights reserved. 
4 | SPDX-License-Identifier: BSD-3-Clause 5 | For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause 6 | """ 7 | 8 | import gzip 9 | import logging 10 | import os 11 | import random as rnd 12 | import tarfile 13 | import zipfile 14 | import random 15 | from typing import List 16 | from tqdm import tqdm 17 | 18 | import decord 19 | from decord import VideoReader 20 | import webdataset as wds 21 | import numpy as np 22 | import torch 23 | from torch.utils.data.dataset import IterableDataset 24 | 25 | from vxverse.common.registry import registry 26 | from vxverse.datasets.datasets.base_dataset import ConcatDataset 27 | 28 | 29 | decord.bridge.set_bridge("torch") 30 | MAX_INT = registry.get("MAX_INT") 31 | 32 | 33 | class ChainDataset(wds.DataPipeline): 34 | r"""Dataset for chaining multiple :class:`DataPipeline` s. 35 | 36 | This class is useful to assemble different existing dataset streams. The 37 | chaining operation is done on-the-fly, so concatenating large-scale 38 | datasets with this class will be efficient. 
39 | 40 | Args: 41 | datasets (iterable of IterableDataset): datasets to be chained together 42 | """ 43 | def __init__(self, datasets: List[wds.DataPipeline]) -> None: 44 | super().__init__() 45 | self.datasets = datasets 46 | self.prob = [] 47 | self.names = [] 48 | for dataset in self.datasets: 49 | if hasattr(dataset, 'name'): 50 | self.names.append(dataset.name) 51 | else: 52 | self.names.append('Unknown') 53 | if hasattr(dataset, 'sample_ratio'): 54 | self.prob.append(dataset.sample_ratio) 55 | else: 56 | self.prob.append(1) 57 | logging.info("One of the datapipeline doesn't define ratio and set to 1 automatically.") 58 | 59 | def __iter__(self): 60 | datastreams = [iter(dataset) for dataset in self.datasets] 61 | while True: 62 | select_datastream = random.choices(datastreams, weights=self.prob, k=1)[0] 63 | yield next(select_datastream) 64 | 65 | 66 | def apply_to_sample(f, sample): 67 | if len(sample) == 0: 68 | return {} 69 | 70 | def _apply(x): 71 | if torch.is_tensor(x): 72 | return f(x) 73 | elif isinstance(x, dict): 74 | return {key: _apply(value) for key, value in x.items()} 75 | elif isinstance(x, list): 76 | return [_apply(x) for x in x] 77 | else: 78 | return x 79 | 80 | return _apply(sample) 81 | 82 | 83 | def move_to_cuda(sample): 84 | def _move_to_cuda(tensor): 85 | return tensor.cuda() 86 | 87 | return apply_to_sample(_move_to_cuda, sample) 88 | 89 | 90 | def prepare_sample(samples, cuda_enabled=True): 91 | if cuda_enabled: 92 | samples = move_to_cuda(samples) 93 | 94 | # TODO fp16 support 95 | 96 | return samples 97 | 98 | 99 | def reorg_datasets_by_split(datasets, batch_sizes): 100 | """ 101 | Organizes datasets by split. 102 | 103 | Args: 104 | datasets: dict of torch.utils.data.Dataset objects by name. 105 | 106 | Returns: 107 | Dict of datasets by split {split_name: List[Datasets]}. 
108 | """ 109 | # if len(datasets) == 1: 110 | # return datasets[list(datasets.keys())[0]] 111 | # else: 112 | reorg_datasets = dict() 113 | reorg_batch_sizes = dict() 114 | 115 | # reorganize by split 116 | for dataset_name, dataset in datasets.items(): 117 | for split_name, dataset_split in dataset.items(): 118 | if split_name not in reorg_datasets: 119 | reorg_datasets[split_name] = [dataset_split] 120 | reorg_batch_sizes[split_name] = [batch_sizes[dataset_name]] 121 | else: 122 | reorg_datasets[split_name].append(dataset_split) 123 | reorg_batch_sizes[split_name].append(batch_sizes[dataset_name]) 124 | 125 | return reorg_datasets, reorg_batch_sizes 126 | 127 | 128 | def concat_datasets(datasets): 129 | """ 130 | Concatenates multiple datasets into a single dataset. 131 | 132 | It supports may-style datasets and DataPipeline from WebDataset. Currently, does not support 133 | generic IterableDataset because it requires creating separate samplers. 134 | 135 | Now only supports conctenating training datasets and assuming validation and testing 136 | have only a single dataset. This is because metrics should not be computed on the concatenated 137 | datasets. 138 | 139 | Args: 140 | datasets: dict of torch.utils.data.Dataset objects by split. 141 | 142 | Returns: 143 | Dict of concatenated datasets by split, "train" is the concatenation of multiple datasets, 144 | "val" and "test" remain the same. 145 | 146 | If the input training datasets contain both map-style and DataPipeline datasets, returns 147 | a tuple, where the first element is a concatenated map-style dataset and the second 148 | element is a chained DataPipeline dataset. 
149 | 150 | """ 151 | # concatenate datasets in the same split 152 | for split_name in datasets: 153 | if split_name != "train": 154 | assert ( 155 | len(datasets[split_name]) == 1 156 | ), "Do not support multiple {} datasets.".format(split_name) 157 | datasets[split_name] = datasets[split_name][0] 158 | else: 159 | iterable_datasets, map_datasets = [], [] 160 | for dataset in datasets[split_name]: 161 | if isinstance(dataset, wds.DataPipeline): 162 | logging.info( 163 | "Dataset {} is IterableDataset, can't be concatenated.".format( 164 | dataset 165 | ) 166 | ) 167 | iterable_datasets.append(dataset) 168 | elif isinstance(dataset, IterableDataset): 169 | raise NotImplementedError( 170 | "Do not support concatenation of generic IterableDataset." 171 | ) 172 | else: 173 | map_datasets.append(dataset) 174 | 175 | # if len(iterable_datasets) > 0: 176 | # concatenate map-style datasets and iterable-style datasets separately 177 | if len(iterable_datasets) > 1: 178 | chained_datasets = ( 179 | ChainDataset(iterable_datasets) 180 | ) 181 | elif len(iterable_datasets) == 1: 182 | chained_datasets = iterable_datasets[0] 183 | else: 184 | chained_datasets = None 185 | 186 | concat_datasets = ( 187 | ConcatDataset(map_datasets) if len(map_datasets) > 0 else None 188 | ) 189 | 190 | train_datasets = concat_datasets, chained_datasets 191 | train_datasets = tuple([x for x in train_datasets if x is not None]) 192 | train_datasets = ( 193 | train_datasets[0] if len(train_datasets) == 1 else train_datasets 194 | ) 195 | 196 | datasets[split_name] = train_datasets 197 | 198 | return datasets 199 | 200 | -------------------------------------------------------------------------------- /vxverse/datasets/datasets/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xverse-ai/XVERSE-V-13B/989da861ac64dfb730da0a872f5057372cdad4f5/vxverse/datasets/datasets/__init__.py 
-------------------------------------------------------------------------------- /vxverse/datasets/datasets/base_dataset.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) 2022, salesforce.com, inc. 3 | All rights reserved. 4 | SPDX-License-Identifier: BSD-3-Clause 5 | For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause 6 | """ 7 | 8 | import json 9 | from typing import Iterable 10 | 11 | from torch.utils.data import Dataset, ConcatDataset 12 | from torch.utils.data.dataloader import default_collate 13 | 14 | 15 | class BaseDataset(Dataset): 16 | def __init__( 17 | self, vis_processor=None, text_processor=None, vis_root=None, ann_paths=[] 18 | ): 19 | """ 20 | vis_root (string): Root directory of images (e.g. coco/images/) 21 | ann_root (string): directory to store the annotation file 22 | """ 23 | self.vis_root = vis_root 24 | 25 | self.annotation = [] 26 | # print("ann paths", ann_paths) 27 | for ann_path in ann_paths: 28 | # print("ann_path", ann_path) 29 | ann = json.load(open(ann_path, "r")) 30 | if isinstance(ann, dict): 31 | self.annotation.extend(json.load(open(ann_path, "r"))['annotations']) 32 | # self.annotation.extend(json.load(open(ann_path, "r"))) 33 | else: 34 | self.annotation.extend(json.load(open(ann_path, "r"))) 35 | 36 | self.vis_processor = vis_processor 37 | self.text_processor = text_processor 38 | 39 | self._add_instance_ids() 40 | 41 | def __len__(self): 42 | return len(self.annotation) 43 | 44 | def collater(self, samples): 45 | return default_collate(samples) 46 | 47 | def set_processors(self, vis_processor, text_processor): 48 | self.vis_processor = vis_processor 49 | self.text_processor = text_processor 50 | 51 | def _add_instance_ids(self, key="instance_id"): 52 | for idx, ann in enumerate(self.annotation): 53 | ann[key] = str(idx) 54 | 55 | 56 | class ConcatDataset(ConcatDataset): 57 | def __init__(self, datasets: 
Iterable[Dataset]) -> None: 58 | super().__init__(datasets) 59 | 60 | def collater(self, samples): 61 | # TODO For now only supports datasets with same underlying collater implementations 62 | 63 | all_keys = set() 64 | for s in samples: 65 | all_keys.update(s) 66 | 67 | shared_keys = all_keys 68 | for s in samples: 69 | shared_keys = shared_keys & set(s.keys()) 70 | 71 | samples_shared_keys = [] 72 | for s in samples: 73 | samples_shared_keys.append({k: s[k] for k in s.keys() if k in shared_keys}) 74 | 75 | return self.datasets[0].collater(samples_shared_keys) 76 | -------------------------------------------------------------------------------- /vxverse/datasets/datasets/dataloader_utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) 2022, salesforce.com, inc. 3 | All rights reserved. 4 | SPDX-License-Identifier: BSD-3-Clause 5 | For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause 6 | """ 7 | 8 | import time 9 | import random 10 | import torch 11 | from vxverse.datasets.data_utils import move_to_cuda 12 | from torch.utils.data import DataLoader 13 | 14 | 15 | class MultiIterLoader: 16 | """ 17 | A simple wrapper for iterating over multiple iterators. 18 | 19 | Args: 20 | loaders (List[Loader]): List of Iterator loaders. 21 | ratios (List[float]): List of ratios to sample from each loader. If None, all loaders are sampled uniformly. 
22 | """ 23 | 24 | def __init__(self, loaders, ratios=None): 25 | # assert all loaders has __next__ method 26 | for loader in loaders: 27 | assert hasattr( 28 | loader, "__next__" 29 | ), "Loader {} has no __next__ method.".format(loader) 30 | 31 | if ratios is None: 32 | ratios = [1.0] * len(loaders) 33 | else: 34 | assert len(ratios) == len(loaders) 35 | ratios = [float(ratio) / sum(ratios) for ratio in ratios] 36 | 37 | self.loaders = loaders 38 | self.ratios = ratios 39 | 40 | def __next__(self): 41 | # random sample from each loader by ratio 42 | loader_idx = random.choices(range(len(self.loaders)), self.ratios, k=1)[0] 43 | return next(self.loaders[loader_idx]) 44 | 45 | # TODO 46 | # This can be an exception if wds.pipeline, see https://github.com/webdataset/webdataset 47 | # object of type '_InfiniteConstantSampler' has no len() 48 | # If you want to have a length property on your dataset, use the with_length(n) method with whatever length you would like to set. 49 | # def __len__(self): 50 | # return sum([len(x) for x in self.loaders]) 51 | 52 | 53 | class PrefetchLoader(object): 54 | """ 55 | Modified from https://github.com/ChenRocks/UNITER. 
56 | 57 | overlap compute and cuda data transfer 58 | (copied and then modified from nvidia apex) 59 | """ 60 | 61 | def __init__(self, loader): 62 | self.loader = loader 63 | self.stream = torch.cuda.Stream() 64 | 65 | def __iter__(self): 66 | loader_it = iter(self.loader) 67 | self.preload(loader_it) 68 | batch = self.next(loader_it) 69 | while batch is not None: 70 | is_tuple = isinstance(batch, tuple) 71 | if is_tuple: 72 | task, batch = batch 73 | 74 | if is_tuple: 75 | yield task, batch 76 | else: 77 | yield batch 78 | batch = self.next(loader_it) 79 | 80 | def __len__(self): 81 | return len(self.loader) 82 | 83 | def preload(self, it): 84 | try: 85 | self.batch = next(it) 86 | except StopIteration: 87 | self.batch = None 88 | return 89 | # if record_stream() doesn't work, another option is to make sure 90 | # device inputs are created on the main stream. 91 | # self.next_input_gpu = torch.empty_like(self.next_input, 92 | # device='cuda') 93 | # self.next_target_gpu = torch.empty_like(self.next_target, 94 | # device='cuda') 95 | # Need to make sure the memory allocated for next_* is not still in use 96 | # by the main stream at the time we start copying to next_*: 97 | # self.stream.wait_stream(torch.cuda.current_stream()) 98 | with torch.cuda.stream(self.stream): 99 | self.batch = move_to_cuda(self.batch) 100 | # more code for the alternative if record_stream() doesn't work: 101 | # copy_ will record the use of the pinned source tensor in this 102 | # side stream. 
103 | # self.next_input_gpu.copy_(self.next_input, non_blocking=True) 104 | # self.next_target_gpu.copy_(self.next_target, non_blocking=True) 105 | # self.next_input = self.next_input_gpu 106 | # self.next_target = self.next_target_gpu 107 | 108 | def next(self, it): 109 | torch.cuda.current_stream().wait_stream(self.stream) 110 | batch = self.batch 111 | if batch is not None: 112 | record_cuda_stream(batch) 113 | self.preload(it) 114 | return batch 115 | 116 | def __getattr__(self, name): 117 | method = self.loader.__getattribute__(name) 118 | return method 119 | 120 | 121 | def record_cuda_stream(batch): 122 | if isinstance(batch, torch.Tensor): 123 | batch.record_stream(torch.cuda.current_stream()) 124 | elif isinstance(batch, list) or isinstance(batch, tuple): 125 | for t in batch: 126 | record_cuda_stream(t) 127 | elif isinstance(batch, dict): 128 | for t in batch.values(): 129 | record_cuda_stream(t) 130 | else: 131 | pass 132 | 133 | 134 | class IterLoader: 135 | """ 136 | A wrapper to convert DataLoader as an infinite iterator. 
137 | 138 | Modified from: 139 | https://github.com/open-mmlab/mmcv/blob/master/mmcv/runner/iter_based_runner.py 140 | """ 141 | 142 | def __init__(self, dataloader: DataLoader, use_distributed: bool = False): 143 | self._dataloader = dataloader 144 | self.iter_loader = iter(self._dataloader) 145 | self._use_distributed = use_distributed 146 | self._epoch = 0 147 | 148 | @property 149 | def epoch(self) -> int: 150 | return self._epoch 151 | 152 | def __next__(self): 153 | try: 154 | data = next(self.iter_loader) 155 | except StopIteration: 156 | self._epoch += 1 157 | if hasattr(self._dataloader.sampler, "set_epoch") and self._use_distributed: 158 | self._dataloader.sampler.set_epoch(self._epoch) 159 | time.sleep(2) # Prevent possible deadlock during epoch transition 160 | self.iter_loader = iter(self._dataloader) 161 | data = next(self.iter_loader) 162 | 163 | return data 164 | 165 | def __iter__(self): 166 | return self 167 | 168 | def __len__(self): 169 | return len(self._dataloader) 170 | -------------------------------------------------------------------------------- /vxverse/datasets/datasets/gqa_datasets.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) 2022, salesforce.com, inc. 3 | All rights reserved. 
4 | SPDX-License-Identifier: BSD-3-Clause 5 | For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause 6 | """ 7 | 8 | import os 9 | import json 10 | 11 | from PIL import Image 12 | 13 | from vxverse.datasets.datasets.vqa_datasets import VQADataset 14 | 15 | from collections import OrderedDict 16 | import random 17 | 18 | class __DisplMixin: 19 | def displ_item(self, index): 20 | sample, ann = self.__getitem__(index), self.annotation[index] 21 | 22 | return OrderedDict( 23 | { 24 | "file": ann["image"], 25 | "question": ann["question"], 26 | "question_id": ann["question_id"], 27 | "answers": "; ".join(ann["answer"]), 28 | "image": sample["image"], 29 | } 30 | ) 31 | 32 | 33 | class GQADataset(VQADataset, __DisplMixin): 34 | def __init__(self, vis_processor, text_processor, vis_root, ann_paths): 35 | super().__init__(vis_processor, text_processor, vis_root, ann_paths) 36 | self.instruction_pool =[ 37 | "[vqa] {}", 38 | "[vqa] Based on the image, respond to this question with a short answer: {}" 39 | ] 40 | 41 | def __getitem__(self, index): 42 | ann = self.annotation[index] 43 | 44 | image_path = os.path.join(self.vis_root, ann["image"]) 45 | image = Image.open(image_path).convert("RGB") 46 | 47 | image = self.vis_processor(image) 48 | question = self.text_processor(ann["question"]) 49 | 50 | instruction = random.choice(self.instruction_pool).format(question) 51 | instruction = " {} ".format(instruction) 52 | 53 | answers = self.text_processor(ann["answer"]) 54 | 55 | return { 56 | "image": image, 57 | "instruction_input": instruction, 58 | "answer": answers, 59 | } 60 | 61 | -------------------------------------------------------------------------------- /vxverse/datasets/datasets/vqa_datasets.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) 2022, salesforce.com, inc. 3 | All rights reserved. 
4 | SPDX-License-Identifier: BSD-3-Clause 5 | For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause 6 | """ 7 | 8 | import torch 9 | from PIL import Image 10 | import os 11 | 12 | from vxverse.datasets.datasets.base_dataset import BaseDataset 13 | 14 | 15 | class VQADataset(BaseDataset): 16 | def __init__(self, vis_processor, text_processor, vis_root, ann_paths): 17 | super().__init__(vis_processor, text_processor, vis_root, ann_paths) 18 | 19 | 20 | class VQAEvalDataset(BaseDataset): 21 | def __init__(self, vis_processor, text_processor, vis_root, ann_paths): 22 | super().__init__(vis_processor, text_processor, vis_root, ann_paths) 23 | 24 | 25 | class OKVQAEvalData(torch.utils.data.Dataset): 26 | def __init__(self, loaded_data, vis_processor, root_path): 27 | self.loaded_data = loaded_data 28 | self.root_path = root_path 29 | self.vis_processor = vis_processor 30 | 31 | def __len__(self): 32 | return len(self.loaded_data) 33 | 34 | def __getitem__(self, idx): 35 | data = self.loaded_data[idx] 36 | img_id = data['image_id'] 37 | question = data['question'] 38 | question_id = data['question_id'] 39 | img_file = "COCO_val2014_" + '{:0>12}.jpg'.format(img_id) 40 | image_path = os.path.join(self.root_path, img_file) 41 | image = Image.open(image_path).convert('RGB') 42 | image = self.vis_processor(image) 43 | 44 | # question = f"Give the following image: ImageContent. Answer the question with a single word or a short phrase simply. Question: {question}" 45 | question = f"Answer the question with a single word or a short phrase simply. 
Question: {question}" 46 | # return image, question, question_id, img_id 47 | return {"image":image, "question":question, "question_id":question_id, "img_id":img_id} 48 | 49 | class VizWizEvalData(torch.utils.data.Dataset): 50 | def __init__(self, loaded_data, vis_processor, root_path): 51 | self.loaded_data = loaded_data 52 | self.root_path = root_path 53 | self.vis_processor = vis_processor 54 | 55 | def __len__(self): 56 | return len(self.loaded_data) 57 | 58 | def __getitem__(self, idx): 59 | data = self.loaded_data[idx] 60 | img_id = data['image'] 61 | question = data['question'] 62 | answers = data['answers'] 63 | answers = '_'.join([answer['answer'] for answer in answers]) 64 | image_path = os.path.join(self.root_path, img_id) 65 | image = Image.open(image_path).convert('RGB') 66 | image = self.vis_processor(image) 67 | question = f"[vqa] The question is '{question}' Based on the image, answer the question with a single word or phrase. and reply 'unanswerable' when the provided information is insufficient" 68 | return image, question, answers 69 | 70 | class IconQAEvalData(torch.utils.data.Dataset): 71 | def __init__(self, loaded_data, vis_processor, root_path): 72 | self.loaded_data = loaded_data 73 | self.root_path = root_path 74 | self.vis_processor = vis_processor 75 | 76 | def __len__(self): 77 | return len(self.loaded_data) 78 | 79 | def __getitem__(self, idx): 80 | data = self.loaded_data[idx] 81 | image_id = data['image_id'] 82 | question = data['question'] 83 | image_path = os.path.join(self.root_path, image_id, 'image.png') 84 | image = Image.open(image_path).convert('RGB') 85 | image = self.vis_processor(image).half().cuda() 86 | candidates = '_'.join(data['choices']) 87 | answer = data['answer'] 88 | question = f"[vqa] Based on the image, respond to this question with a single word or phrase: {question}" 89 | return image, question, candidates, answer 90 | 91 | class GQAEvalData(torch.utils.data.Dataset): 92 | def __init__(self, loaded_data, 
vis_processor, root_path): 93 | self.loaded_data = loaded_data 94 | self.root_path = root_path 95 | self.vis_processor = vis_processor 96 | 97 | def __len__(self): 98 | return len(self.loaded_data) 99 | 100 | def __getitem__(self, idx): 101 | ann = self.loaded_data[idx] 102 | image_id = ann["image"] 103 | image_path = os.path.join(self.root_path, f"{image_id}") 104 | image = Image.open(image_path).convert("RGB") 105 | image = self.vis_processor(image) 106 | question = ann["question"] 107 | # question = f"[vqa] Based on the image, respond to this question with a single word or phrase: {question}" 108 | question = f"Based on the image, respond to this question with a single word or phrase: {question}" 109 | labels = ann["answer"] 110 | # return image, question, labels 111 | return {"image": image, "question": question, "label": labels} 112 | 113 | class HMEvalData(torch.utils.data.Dataset): 114 | def __init__(self, loaded_data, vis_processor, root_path): 115 | self.loaded_data = loaded_data 116 | self.root_path = root_path 117 | self.vis_processor = vis_processor 118 | 119 | def __len__(self): 120 | return len(self.loaded_data) 121 | 122 | def __getitem__(self, idx): 123 | ann = self.loaded_data[idx] 124 | image_id = ann["img"] 125 | image_path = os.path.join(self.root_path, f"{image_id}") 126 | image = Image.open(image_path).convert("RGB") 127 | image = self.vis_processor(image) 128 | question = ann["text"] 129 | question = f"This is an image writting '{question}'. Is this image hateful? Answer yes or no. 
Answer:" 130 | labels = ann["label"] 131 | 132 | return image, question, labels 133 | 134 | class VSREvalData(torch.utils.data.Dataset): 135 | def __init__(self, loaded_data, vis_processor, root_path): 136 | self.loaded_data = loaded_data 137 | self.root_path = root_path 138 | self.vis_processor = vis_processor 139 | 140 | def __len__(self): 141 | return len(self.loaded_data) 142 | 143 | def __getitem__(self, idx): 144 | ann = self.loaded_data[idx] 145 | image_path = os.path.join(self.root_path, ann["image"]) 146 | image = Image.open(image_path).convert("RGB") 147 | image = self.vis_processor(image) 148 | question = ann["caption"] 149 | question = f'[vqa] Based on the image, is this statement true or false? {question}' 150 | labels = 'true' if ann["label"] == 1 else 'false' 151 | 152 | return image, question, labels -------------------------------------------------------------------------------- /vxverse/models/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) 2022, salesforce.com, inc. 3 | All rights reserved. 4 | SPDX-License-Identifier: BSD-3-Clause 5 | For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause 6 | """ 7 | 8 | import logging 9 | import torch 10 | from omegaconf import OmegaConf 11 | 12 | from vxverse.common.registry import registry 13 | from vxverse.models.base_model import BaseModel 14 | from vxverse.models.vxverse_base import VXVERSEBase 15 | from vxverse.models.vxverse import VXVERSE 16 | from vxverse.processors.base_processor import BaseProcessor 17 | 18 | 19 | __all__ = [ 20 | "load_model", 21 | "BaseModel", 22 | "VXVERSEBase", 23 | "VXVERSE" 24 | ] 25 | 26 | 27 | def load_model(name, model_type, is_eval=False, device="cpu", checkpoint=None): 28 | """ 29 | Load supported models. 
30 | 31 | To list all available models and types in registry: 32 | >>> from vxverse.models import model_zoo 33 | >>> print(model_zoo) 34 | 35 | Args: 36 | name (str): name of the model. 37 | model_type (str): type of the model. 38 | is_eval (bool): whether the model is in eval mode. Default: False. 39 | device (str): device to use. Default: "cpu". 40 | checkpoint (str): path or to checkpoint. Default: None. 41 | Note that expecting the checkpoint to have the same keys in state_dict as the model. 42 | 43 | Returns: 44 | model (torch.nn.Module): model. 45 | """ 46 | 47 | model = registry.get_model_class(name).from_pretrained(model_type=model_type) 48 | 49 | if checkpoint is not None: 50 | model.load_checkpoint(checkpoint) 51 | 52 | if is_eval: 53 | model.eval() 54 | 55 | if device == "cpu": 56 | model = model.float() 57 | 58 | return model.to(device) 59 | 60 | 61 | def load_preprocess(config): 62 | """ 63 | Load preprocessor configs and construct preprocessors. 64 | 65 | If no preprocessor is specified, return BaseProcessor, which does not do any preprocessing. 66 | 67 | Args: 68 | config (dict): preprocessor configs. 69 | 70 | Returns: 71 | vis_processors (dict): preprocessors for visual inputs. 72 | txt_processors (dict): preprocessors for text inputs. 73 | 74 | Key is "train" or "eval" for processors used in training and evaluation respectively. 
75 | """ 76 | 77 | def _build_proc_from_cfg(cfg): 78 | return ( 79 | registry.get_processor_class(cfg.name).from_config(cfg) 80 | if cfg is not None 81 | else BaseProcessor() 82 | ) 83 | 84 | vis_processors = dict() 85 | txt_processors = dict() 86 | 87 | vis_proc_cfg = config.get("vis_processor") 88 | txt_proc_cfg = config.get("text_processor") 89 | 90 | if vis_proc_cfg is not None: 91 | vis_train_cfg = vis_proc_cfg.get("train") 92 | vis_eval_cfg = vis_proc_cfg.get("eval") 93 | else: 94 | vis_train_cfg = None 95 | vis_eval_cfg = None 96 | 97 | vis_processors["train"] = _build_proc_from_cfg(vis_train_cfg) 98 | vis_processors["eval"] = _build_proc_from_cfg(vis_eval_cfg) 99 | 100 | if txt_proc_cfg is not None: 101 | txt_train_cfg = txt_proc_cfg.get("train") 102 | txt_eval_cfg = txt_proc_cfg.get("eval") 103 | else: 104 | txt_train_cfg = None 105 | txt_eval_cfg = None 106 | 107 | txt_processors["train"] = _build_proc_from_cfg(txt_train_cfg) 108 | txt_processors["eval"] = _build_proc_from_cfg(txt_eval_cfg) 109 | 110 | return vis_processors, txt_processors 111 | 112 | 113 | def load_model_and_preprocess(name, model_type, is_eval=False, device="cpu"): 114 | """ 115 | Load model and its related preprocessors. 116 | 117 | List all available models and types in registry: 118 | >>> from vxverse.models import model_zoo 119 | >>> print(model_zoo) 120 | 121 | Args: 122 | name (str): name of the model. 123 | model_type (str): type of the model. 124 | is_eval (bool): whether the model is in eval mode. Default: False. 125 | device (str): device to use. Default: "cpu". 126 | 127 | Returns: 128 | model (torch.nn.Module): model. 129 | vis_processors (dict): preprocessors for visual inputs. 130 | txt_processors (dict): preprocessors for text inputs. 
131 | """ 132 | model_cls = registry.get_model_class(name) 133 | 134 | # load model 135 | model = model_cls.from_pretrained(model_type=model_type) 136 | 137 | if is_eval: 138 | model.eval() 139 | 140 | # load preprocess 141 | cfg = OmegaConf.load(model_cls.default_config_path(model_type)) 142 | if cfg is not None: 143 | preprocess_cfg = cfg.preprocess 144 | 145 | vis_processors, txt_processors = load_preprocess(preprocess_cfg) 146 | else: 147 | vis_processors, txt_processors = None, None 148 | logging.info( 149 | f"""No default preprocess for model {name} ({model_type}). 150 | This can happen if the model is not finetuned on downstream datasets, 151 | or it is not intended for direct use without finetuning. 152 | """ 153 | ) 154 | 155 | if device == "cpu" or device == torch.device("cpu"): 156 | model = model.float() 157 | 158 | return model.to(device), vis_processors, txt_processors 159 | 160 | 161 | class ModelZoo: 162 | """ 163 | A utility class to create string representation of available model architectures and types. 
164 | 165 | >>> from vxverse.models import model_zoo 166 | >>> # list all available models 167 | >>> print(model_zoo) 168 | >>> # show total number of models 169 | >>> print(len(model_zoo)) 170 | """ 171 | 172 | def __init__(self) -> None: 173 | self.model_zoo = { 174 | k: list(v.PRETRAINED_MODEL_CONFIG_DICT.keys()) 175 | for k, v in registry.mapping["model_name_mapping"].items() 176 | } 177 | 178 | def __str__(self) -> str: 179 | return ( 180 | "=" * 50 181 | + "\n" 182 | + f"{'Architectures':<30} {'Types'}\n" 183 | + "=" * 50 184 | + "\n" 185 | + "\n".join( 186 | [ 187 | f"{name:<30} {', '.join(types)}" 188 | for name, types in self.model_zoo.items() 189 | ] 190 | ) 191 | ) 192 | 193 | def __iter__(self): 194 | return iter(self.model_zoo.items()) 195 | 196 | def __len__(self): 197 | return sum([len(v) for v in self.model_zoo.values()]) 198 | 199 | 200 | model_zoo = ModelZoo() 201 | -------------------------------------------------------------------------------- /vxverse/models/clip_vit.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch 3 | from transformers import CLIPVisionConfig, CLIPVisionModel, CLIPImageProcessor 4 | 5 | 6 | class CLIP_VIT_MODEL(nn.Module): 7 | def __init__(self, model_name, vit_path=None, trainable=False, select_layer=-2, select_feature='patch'): 8 | super().__init__() 9 | self.vit_path = vit_path 10 | # self.image_processor = CLIPImageProcessor.from_pretrained(model_name) 11 | self.visual = CLIPVisionModel.from_pretrained(vit_path) 12 | self.num_features = self.visual.vision_model.config.hidden_size 13 | self.model_name = model_name 14 | self.trainable = trainable 15 | if not self.trainable: 16 | self.visual.requires_grad_(False) 17 | self.select_layer = select_layer 18 | self.select_feature = select_feature 19 | 20 | def feature_select(self, image_forward_outs): 21 | image_features = image_forward_outs.hidden_states[self.select_layer] 22 | if self.select_feature 
== 'patch': 23 | image_features = image_features[:, 1:] 24 | elif self.select_feature == 'cls_patch': 25 | image_features = image_features 26 | else: 27 | raise ValueError(f'Unexpected select feature: {self.select_feature}') 28 | return image_features 29 | 30 | def forward(self, images): 31 | if type(images) is list: 32 | image_features = [] 33 | for image in images: 34 | image_forward_out = self.visual(image.to(device=self.device, dtype=self.dtype).unsqueeze(0), 35 | output_hidden_states=True) 36 | image_feature = self.feature_select(image_forward_out).to(image.dtype) 37 | image_features.append(image_feature) 38 | else: 39 | image_forward_outs = self.visual(images.to(device=self.device, dtype=self.dtype), 40 | output_hidden_states=True) 41 | image_features = self.feature_select(image_forward_outs).to(images.dtype) 42 | 43 | return image_features 44 | 45 | 46 | @property 47 | def dummy_feature(self): 48 | return torch.zeros(1, self.hidden_size, device=self.device, dtype=self.dtype) 49 | 50 | @property 51 | def dtype(self): 52 | return self.visual.dtype 53 | 54 | @property 55 | def device(self): 56 | return self.visual.device 57 | 58 | @property 59 | def config(self): 60 | if self.is_loaded: 61 | return self.visual.config 62 | else: 63 | return self.cfg_only 64 | 65 | @property 66 | def hidden_size(self): 67 | return self.config.hidden_size 68 | 69 | @property 70 | def num_patches(self): 71 | return (self.config.image_size // self.config.patch_size) ** 2 72 | 73 | 74 | def create_model_and_transforms_for_vit_2(model_name, vit_path): 75 | 76 | return CLIP_VIT_MODEL(model_name, vit_path) 77 | 78 | 79 | -------------------------------------------------------------------------------- /vxverse/models/modeling_llama.py: -------------------------------------------------------------------------------- 1 | import math 2 | from typing import List, Optional, Tuple, Union 3 | 4 | import torch 5 | import torch.nn.functional as F 6 | from torch.nn import CrossEntropyLoss 7 | 8 | 
from transformers.utils import add_start_docstrings_to_model_forward, replace_return_docstrings
from transformers.modeling_outputs import CausalLMOutputWithPast
from transformers.models.llama.modeling_llama import LLAMA_INPUTS_DOCSTRING, _CONFIG_FOR_DOC
from transformers.models.llama.modeling_llama import LlamaForCausalLM as LlamaForCausalLMOrig


class LlamaForCausalLM(LlamaForCausalLMOrig):
    """LLaMA causal LM whose ``forward`` additionally accepts a ``reduction`` mode for the LM loss."""

    @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        reduction: Optional[str] = "mean",
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        r"""
        Args:
            labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
                Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
                config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
                (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
            reduction (`str`, *optional*, defaults to `"mean"`):
                Passed to `torch.nn.CrossEntropyLoss`. With `"none"`, `loss` has shape `(batch_size,)` and holds
                each sample's cross-entropy averaged over its non-ignored (label != -100) target tokens only.

        Returns:

        Example:

        ```python
        >>> from transformers import AutoTokenizer, LlamaForCausalLM

        >>> model = LlamaForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS)
        >>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER)

        >>> prompt = "Hey, are you conscious? Can you talk to me?"
        >>> inputs = tokenizer(prompt, return_tensors="pt")

        >>> # Generate
        >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
        >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
        "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
        ```"""
        # The bare "Returns:" line above is the placeholder that @replace_return_docstrings fills in.
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        hidden_states = outputs[0]
        if hasattr(self.config, 'pretraining_tp') and self.config.pretraining_tp > 1:
            # Reproduce tensor-parallel pretraining numerics: project through the lm_head in slices.
            lm_head_slices = self.lm_head.weight.split(self.vocab_size // self.config.pretraining_tp, dim=0)
            logits = [F.linear(hidden_states, lm_head_slices[i]) for i in range(self.config.pretraining_tp)]
            logits = torch.cat(logits, dim=-1)
        else:
            logits = self.lm_head(hidden_states)
        # Upcast to fp32 before the loss (presumably for numerical stability under fp16/bf16).
        logits = logits.float()

        loss = None
        if labels is not None:
            # Shift so that tokens < n predict n
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            # Flatten the tokens
            loss_fct = CrossEntropyLoss(reduction=reduction)
            shift_logits = shift_logits.view(-1, self.config.vocab_size)
            shift_labels = shift_labels.view(-1)
            # Enable model parallelism
            shift_labels = shift_labels.to(shift_logits.device)
            loss = loss_fct(shift_logits, shift_labels)
            if reduction == "none":
                # Per-token losses are 0 at ignored positions. Average each sample over its own
                # valid targets so the result does not depend on how much padding the batch has.
                # clamp(min=1) keeps fully-masked rows at 0 instead of producing 0/0.
                batch_size = logits.size(0)
                valid = (shift_labels != loss_fct.ignore_index).view(batch_size, -1)
                loss = (loss.view(batch_size, -1) * valid).sum(1) / valid.sum(1).clamp(min=1)

        if not return_dict:
            output = (logits,) + outputs[1:]
            return (loss,) + output if loss is not None else output

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )


# --------------------------------------------------------------------------------
# /vxverse/models/modeling_xverse.py
# --------------------------------------------------------------------------------
import math
from typing import List, Optional, Tuple, Union

import torch
import torch.nn.functional as F
from torch.nn import CrossEntropyLoss

from transformers.utils import add_start_docstrings_to_model_forward, replace_return_docstrings
from transformers.modeling_outputs import CausalLMOutputWithPast
from transformers import AutoModelForCausalLM

# Public alias for callers importing ``XverseForCausalLM`` from this module. The example below
# loads XVERSE through ``AutoModelForCausalLM`` with ``trust_remote_code=True``.
XverseForCausalLM = AutoModelForCausalLM


def new_xverse_forward(
    self,
    input_ids: torch.LongTensor = None,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    past_key_values: Optional[List[torch.FloatTensor]] = None,
    inputs_embeds: Optional[torch.FloatTensor] = None,
    labels: Optional[torch.LongTensor] = None,
    use_cache: Optional[bool] = None,
    output_attentions: Optional[bool] = None,
    output_hidden_states: Optional[bool] = None,
    return_dict: Optional[bool] = None,
    reduction: Optional[str] = "mean",
) -> Union[Tuple, CausalLMOutputWithPast]:
    r"""
    Causal-LM forward for an XVERSE model, written as a free function taking ``self``
    (presumably bound onto the remote-code model class elsewhere — confirm at call site).

    Args:
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
            config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
            (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
        reduction (`str`, *optional*, defaults to `"mean"`):
            Passed to `torch.nn.CrossEntropyLoss`. With `"none"`, `loss` has shape `(batch_size,)` and holds
            each sample's cross-entropy averaged over its non-ignored (label != -100) target tokens only.

    Returns:
        `CausalLMOutputWithPast` when `return_dict` is true; otherwise a tuple
        `(loss, logits, *outputs[1:])`, with `loss` omitted when `labels` is None.

    Example:

    ```python
    >>> from transformers import AutoTokenizer, AutoModelForCausalLM

    >>> model = AutoModelForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS, trust_remote_code=True)
    >>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER)

    >>> prompt = "Hey, are you conscious? Can you talk to me?"
    >>> inputs = tokenizer(prompt, return_tensors="pt")

    >>> # Generate
    >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
    >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
    "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
    ```"""
    output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
    output_hidden_states = (
        output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
    )
    return_dict = return_dict if return_dict is not None else self.config.use_return_dict

    # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
    outputs = self.model(
        input_ids=input_ids,
        attention_mask=attention_mask,
        position_ids=position_ids,
        past_key_values=past_key_values,
        inputs_embeds=inputs_embeds,
        use_cache=use_cache,
        output_attentions=output_attentions,
        output_hidden_states=output_hidden_states,
        return_dict=return_dict,
    )

    hidden_states = outputs[0]
    if hasattr(self.config, 'pretraining_tp') and self.config.pretraining_tp > 1:
        # Reproduce tensor-parallel pretraining numerics: project through the lm_head in slices.
        lm_head_slices = self.lm_head.weight.split(self.vocab_size // self.config.pretraining_tp, dim=0)
        logits = [F.linear(hidden_states, lm_head_slices[i]) for i in range(self.config.pretraining_tp)]
        logits = torch.cat(logits, dim=-1)
    else:
        logits = self.lm_head(hidden_states)

    # Upcast to fp32 before the loss (presumably for numerical stability under fp16/bf16).
    logits = logits.float()

    loss = None
    if labels is not None:
        # Shift so that tokens < n predict n
        shift_logits = logits[..., :-1, :].contiguous()
        shift_labels = labels[..., 1:].contiguous()
        # Flatten the tokens
        loss_fct = CrossEntropyLoss(reduction=reduction)
        shift_logits = shift_logits.view(-1, self.config.vocab_size)
        shift_labels = shift_labels.view(-1)
        # Enable model parallelism
        shift_labels = shift_labels.to(shift_logits.device)
        loss = loss_fct(shift_logits, shift_labels)
        if reduction == "none":
            # Per-token losses are 0 at ignored positions. Average each sample over its own
            # valid targets so the result does not depend on how much padding the batch has.
            # clamp(min=1) keeps fully-masked rows at 0 instead of producing 0/0.
            batch_size = logits.size(0)
            valid = (shift_labels != loss_fct.ignore_index).view(batch_size, -1)
            loss = (loss.view(batch_size, -1) * valid).sum(1) / valid.sum(1).clamp(min=1)

    if not return_dict:
        output = (logits,) + outputs[1:]
        return (loss,) + output if loss is not None else output

    return CausalLMOutputWithPast(
        loss=loss,
        logits=logits,
        past_key_values=outputs.past_key_values,
        hidden_states=outputs.hidden_states,
        attentions=outputs.attentions,
    )


# --------------------------------------------------------------------------------
# /vxverse/processors/__init__.py
# --------------------------------------------------------------------------------
"""
Copyright (c) 2022, salesforce.com, inc.
All rights reserved.
SPDX-License-Identifier: BSD-3-Clause
For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
"""

from vxverse.processors.base_processor import BaseProcessor
from vxverse.processors.blip_processors import (
    Blip2ImageTrainProcessor,
    Blip2ImageEvalProcessor,
    BlipCaptionProcessor,
    HDImageTrainProcessor,
)

from vxverse.common.registry import registry

__all__ = [
    "BaseProcessor",
    "Blip2ImageTrainProcessor",
    "Blip2ImageEvalProcessor",
    "BlipCaptionProcessor",
    "HDImageTrainProcessor",
]


def load_processor(name, cfg=None):
    """
    Build the processor class registered under ``name``.

    Args:
        name: Key the processor class is looked up by in ``registry``.
        cfg: Optional config forwarded unchanged to the class's ``from_config``.

    Returns:
        The processor instance returned by ``from_config(cfg)``.

    Example (the name is illustrative; use a name actually registered in this package)

    >>> processor = load_processor("alpro_video_train", cfg=None)
    """
    processor = registry.get_processor_class(name).from_config(cfg)

    return processor


# --------------------------------------------------------------------------------
# /vxverse/processors/base_processor.py
# --------------------------------------------------------------------------------
"""
Copyright (c) 2022, salesforce.com, inc.
All rights reserved.
SPDX-License-Identifier: BSD-3-Clause
For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
"""


from omegaconf import OmegaConf


class BaseProcessor:
    """Callable wrapper that applies ``self.transform`` to an item (identity by default)."""

    def __init__(self):
        # Identity transform; derived processors may assign a different callable.
        self.transform = lambda x: x

    def __call__(self, item):
        """Return ``self.transform(item)``."""
        return self.transform(item)

    @classmethod
    def from_config(cls, cfg=None):
        """Construct a processor; the base class ignores ``cfg``."""
        return cls()

    def build(self, **kwargs):
        """Wrap ``kwargs`` in an OmegaConf config and construct a processor via ``from_config``."""
        cfg = OmegaConf.create(kwargs)

        return self.from_config(cfg)