├── .github └── stale.yml ├── .gitignore ├── LICENSE ├── README.md ├── README_ZH.md ├── examples ├── VCR │ ├── README.md │ └── vcr_sample_images │ │ ├── .gitkeep │ │ └── lsmdc_1054_Harry_Potter_and_the_prisoner_of_azkaban │ │ ├── .gitkeep │ │ ├── 1054_Harry_Potter_and_the_prisoner_of_azkaban_00.01.46.736-00.01.50.168@0.jpg │ │ └── 1054_Harry_Potter_and_the_prisoner_of_azkaban_00.01.46.736-00.01.50.168@0.json └── VQA │ ├── README.md │ ├── convert_checkpoint_after_ft.py │ ├── run_vqav2_ft.py │ ├── vqa_train_config.json │ ├── vqav2_datamodule.py │ ├── vqav2_train_module.py │ ├── write_vqa.py │ └── zero_to_fp32.py ├── models └── VLE │ ├── __init__.py │ ├── configuration_vle.py │ ├── modeling_vle.py │ ├── pipeline_vle.py │ └── processing_vle.py └── pics ├── VQALLM_workflow.png ├── banner.png ├── birds.jpg ├── demo-banner.png ├── dogs.png ├── door.png ├── fishing.png ├── model.png ├── pink_tongues.png ├── qrcode.jpg └── truck.png /.github/stale.yml: -------------------------------------------------------------------------------- 1 | # Number of days of inactivity before an issue becomes stale 2 | daysUntilStale: 4 3 | # Number of days of inactivity before a stale issue is closed 4 | daysUntilClose: 4 5 | # Issues with these labels will never be considered stale 6 | exemptLabels: 7 | - pinned 8 | - security 9 | # Label to use when marking an issue as stale 10 | staleLabel: stale 11 | # Comment to post when marking an issue as stale. Set to `false` to disable 12 | markComment: > 13 | This issue has been automatically marked as stale because it has not had 14 | recent activity. It will be closed if no further activity occurs. Thank you 15 | for your contributions. 16 | # Comment to post when closing a stale issue. Set to `false` to disable 17 | closeComment: > 18 | Closing the issue, since no updates observed. 19 | Feel free to re-open if you need any further assistance. 20 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | .DS_Store 2 | */.DS_Store 3 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 
29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. 
If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. 
Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright 2023 HY_cog9 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 
202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | [**中文**](README_ZH.md) | [**English**](https://github.com/iflytek/VLE) 2 | 3 |

4 | [README header: banner image and GitHub shield badges (license, repo size, top language, last commit)]

16 | 17 | 18 | 19 | 20 | # VLE: Vision-Language Encoder 21 | 22 | Multimodal pre-trained models are trained on massive multimodal data, and they can utilize information from different modalities and perform various cross-modal tasks. 23 | 24 | In this repository, we introduce **VLE** (**V**ision-**L**anguage **E**ncoder), an image-text multimodal understanding model built on the pre-trained text and image encoders. It can be used for multimodal discriminative tasks such as visual question answering and image-text retrieval. Especially on the visual commonsense reasoning (VCR) task, which requires high-level language understanding and reasoning skills, VLE achieves the best performance among the public methods. 25 | 26 | Recently, LLMs (Large Language Models) have achieved great success and have been used for a wide range of text tasks, including translation, question answering, text summarization, etc. While LLMs are unimodal, their abilities can be leveraged for multimodal understanding tasks. We propose a VQA+LLM pipeline that integrates multimodal models with LLMs for the visual question answering task. It helps the VQA model generate more accurate and fluent answers. 27 | 28 | We open-source VLE-related resources for promoting academic research and better facilitating our community. 29 | 30 | **Try our VLE-based [VQA Demo](https://huggingface.co/spaces/hfl/VQA_VLE_LLM) at 🤗Space 👇👇👇** 31 | 32 |
[image: VLE-based VQA Demo]
33 | 34 | ---- 35 | 36 | [Chinese LERT](https://github.com/ymcui/LERT) | [Chinese and English PERT](https://github.com/ymcui/PERT) | [Chinese MacBERT](https://github.com/ymcui/MacBERT) | [ChineseMiniRBT](https://github.com/iflytek/MiniRBT) | [Chinese ELECTRA](https://github.com/ymcui/Chinese-ELECTRA) | [Chinese XLNet](https://github.com/ymcui/Chinese-XLNet) | [Chinese BERT](https://github.com/ymcui/Chinese-BERT-wwm) | [ Knowledge distillation tool TextBrewer](https://github.com/airaria/TextBrewer) | [Model pruning tool TextPruner](https://github.com/airaria/TextPruner) 37 | 38 | More resources released by HFL: https://github.com/iflytek/HFL-Anthology 39 | 40 | ## Table of Contents 41 | 42 | | Section | Description | 43 | | ----------------------------- | ----------------------------------- | 44 | | [Introduction](#introduction) | Introduction to VLE | 45 | | [Downloads](#downloads) | Download links for VLE | 46 | | [Comparison](#comparison) | Comparison of VLE with other models | 47 | | [VQA with LLM](#vqa-with-llm) | Visual question answering with LLM | 48 | | [Usage](#usage) | How to load VLE for different tasks | 49 | 50 | ## Introduction 51 | 52 | ### Structure 53 | 54 | The structure of VLE is similar to [METER](https://arxiv.org/abs/2111.02387), which consists of two unimodal encoders for text and image separately, followed by a crossmodal fusion module. However, there are several structural differences between VLE and METER: 55 | 56 | * VLE uses DeBERTa-v3 as the text encoder, which is stronger than RoBERTa-base used in METER. 57 | * In the large version of VLE (VLE-large), the hidden size of the crossmodal co-attention fusion module is scaled up to 1024 to increase capacities. 58 | * During fine-tuning, VLE introduces additional token_type_embeddings. 59 | 60 | ### Pre-training 61 | 62 | VLE is pre-trained with image-caption pairs. There are four objectives applied during the pre-training stage: 63 | * **MLM** (Masked Language Modeling): Given an image-caption pair, we randomly mask some input text tokens, and the model is trained to reconstruct the original tokens. 64 | * **ITM** (Image-Text Matching): Given a batch of matched or mismatched image-caption pairs, the model needs to identify which images and captions correspond to each other. 65 | * **MPC** (Masked Patch-box Classification): Given an image-caption pair with some patches masked, the model needs to predict the classes of the objects in the masked patches. 66 | * **PBC** (Patch-box Classification): Given an image-caption pair, the models need to identify which patches are related to the caption. 67 | 68 | VLE models are pre-trained on 14M public English image-caption pairs for 25k steps with a batch size of 2048. 69 | 70 | The following figure illustrates the VLE structure and the pre-training objectives (for simplicity, we omit the PBC objective in the figure). 71 | 72 |
[figure: VLE structure and pre-training tasks]
73 | 74 | ### Adaptation for downstream tasks 75 | 76 | #### Visual Question Answering (VQA) 77 | 78 | * We follow the standard practice to train the models on VQA with both training and validation data, and test the models on the test-dev set. The pooler output from the last layer of the fusion module is used for classification. 79 | 80 | #### Visual Commonsense Reasoning (VCR) 81 | 82 | * We format VCR as a multiple-choice task which is similar to RACE. For each object in the image in each example, we append the average of patches that cover the object to the image feature embeddings before the fusion module. We also assign token_type_ids to the objects in the image and text to improve alignment between different modalities. 83 | 84 | 85 | ## Downloads 86 | 87 | The model weights are in PyTorch format and can be downloaded through the 🤗 transformers model hub. You can either download the weights and configurations manually or initialize a VLE model with `from_pretrained(model_name)` method in your code. See [Usage](#usage) for details. 88 | 89 | ### Pre-trained Checkpoints 90 | 91 | | Model | Text Encoder | Image Encoder | # Params* | MODEL_NAME | Link | 92 | | --------- | ---------------- | ---------------------- | -------------------- | ------------- | -------------------------------------------- | 93 | | VLE-base | DeBERTa-v3-base | CLIP-ViT-base-patch16 | 378M | hfl/vle-base | [link](https://huggingface.co/hfl/vle-base) | 94 | | VLE-large | DeBERTa-v3-large | CLIP-ViT-large-patch14 | 930M | hfl/vle-large | [link](https://huggingface.co/hfl/vle-large) | 95 | 96 | * : We exclude task heads when counting the number of parameters. 97 | 98 | ### Fine-tuned Checkpoints 99 | 100 | | Model | Text Encoder | Image Encoder | MODEL_NAME | Link | 101 | | ---------------------- | ---------------- | ---------------------- | -------------------------- | --------------------------------------------------------- | 102 | | VLE-base-for-VQA | DeBERTa-v3-base | CLIP-ViT-base-patch16 | hfl/vle-base-for-vqa | [link](https://huggingface.co/hfl/vle-base-for-vqa) | 103 | | VLE-large-for-VQA | DeBERTa-v3-large | CLIP-ViT-large-patch14 | hfl/vle-large-for-vqa | [link](https://huggingface.co/hfl/vle-large-for-vqa) | 104 | | VLE-base-for-VCR-q2a | DeBERTa-v3-base | CLIP-ViT-base-patch16 | hfl/vle-base-for-vcr-q2a | [link](https://huggingface.co/hfl/vle-base-for-vcr-q2a) | 105 | | VLE-large-for-VCR-q2a | DeBERTa-v3-large | CLIP-ViT-large-patch14 | hfl/vle-large-for-vcr-q2a | [link](https://huggingface.co/hfl/vle-large-for-vcr-q2a) | 106 | | VLE-base-for-VCR-qa2r | DeBERTa-v3-base | CLIP-ViT-base-patch16 | hfl/vle-base-for-vcr-qa2r | [link](https://huggingface.co/hfl/vle-base-for-vcr-qa2r) | 107 | | VLE-large-for-VCR-qa2r | DeBERTa-v3-large | CLIP-ViT-large-patch14 | hfl/vle-large-for-vcr-qa2r | [link](https://huggingface.co/hfl/vle-large-for-vcr-qa2r) | 108 | 109 | ## Comparison 110 | 111 | In the following table, we compare the performance of VLE with METER and other multimodal models. The VQA results are on the test-dev set, and the VCR results are on the dev set. 
112 | 113 | | Model | VQA | VCR (QA2R) | VCR (Q2A) | #Params | #PT data* | 114 | | ------------------- | ---------------- | -------------- | ------------- | ------------ | ------- | 115 | | CoCa | 82.3 | - | - | 2.1 B | unknown | 116 | | BeiT-3 | 84.2 | - | - | 1.9 B | 21M(I-T) + 14M(I) + 160G(T) | 117 | | OFA | 82.0 | - | - | 930M | 20M(I-T) + 39M(I) + 140G(T) | 118 | | BLIP | 78.3 | - | - | 385M | ~130M(I-T) | 119 | | METER-base | 77.7 (76.8†‡) | 79.8§ | 77.6§ | 345M | 9M(I-T) | 120 | | METER-Huge | 80.3 | - | - | 878M | 20M(I-T) | 121 | | VLE-base | 77.6 | 83.7§ | 79.9§ | 378M | 15M(I-T) | 122 | | VLE-large | 79.3 | 87.5§ | 84.3§ | 930M | 15M(I-T) | 123 | 124 | : Result from our reimplementation. 125 | 126 | : Fine-tuning hyperparameters: lr=7e-6, batch_size={256, 512}, num_epochs=10 127 | 128 | § : Fine-tuning hyperparameters: lr=1e-5, batch_size=128, num_epochs=5 129 | 130 | * : Pre-training data. I-T: Image-caption pairs. I: Images. T: Text. 131 | 132 | From the above results, we can see that: 133 | 134 | * **VLE is pre-training efficient**. Compared to models with similar model sizes, VLE achieves comparable or even better performance on VQA with much less pre-training data. 135 | 136 | * **VLE shows higher reasoning ability**. Especially it significantly outperforms METER on Visual Commonsense Reasoning (VCR), which requires higher level language and reasoning skills than VQA. 137 | 138 | ## VQA with LLM 139 | 140 | ### Generating Accurate and Fluent VQA Answers 141 | 142 | LLMs have achieved great success on a wide range of text tasks, while the abilities of LLMs can also be leveraged for multimodal understanding tasks. Specifically, we present a VQA+LLM pipeline that integrates multimodal models with LLMs for the visual question answering task, which helps the VQA model to generate more accurate and fluent answers. 143 | 144 | The workflows are shown in the figure below. 145 | 146 |
[figure: Workflows]
147 | 148 | (a) VQA: This is the standard way to perform the VQA task with a discriminative model. The question and the image are fed into the multimodal model, and the model is trained to predict the correct answer labels. 149 | 150 | (b) VQA + LLM: The captioning model generates a caption of the image. The caption, question, and answer candidates predicted by the VQA model are concatenated and fed to the LLM. The LLM is asked to give the most reasonable answer. 151 | 152 | We find that VQA+LLM can not only make more accurate predictions, but also generate more fluent and readable predictions. We list some examples: 153 | 154 |
[example image: men and truck]
155 | 156 | 157 | 158 |
[example image: hatch]
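For illustration only, here is a minimal sketch of how the caption, the question, and the VQA candidates described above could be concatenated into an LLM prompt. The exact template used in the demo and the output format of the VQA model are not specified here, so both are assumptions.

```python
def build_vqa_llm_prompt(caption, question, vqa_answers):
    """Assemble an LLM prompt from an image caption, a question, and VQA candidates.

    vqa_answers is assumed to be a list of {"answer": str, "score": float} dicts,
    e.g. the top-k predictions of a VQA model; adapt the keys to the actual
    output format of the pipeline you use.
    """
    candidates = "; ".join(f'{a["answer"]} (score {a["score"]:.2f})' for a in vqa_answers)
    return (
        f"Image caption: {caption}\n"
        f"Question: {question}\n"
        f"Candidate answers from the VQA model: {candidates}\n"
        "Give the most reasonable answer to the question in one fluent sentence."
    )

# Hypothetical example values, for demonstration only.
print(build_vqa_llm_prompt(
    caption="Two men are standing next to a truck.",
    question="What are the men doing?",
    vqa_answers=[{"answer": "talking", "score": 0.62}, {"answer": "loading", "score": 0.21}],
))
```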
159 | 160 | The demo is available at : https://huggingface.co/spaces/hfl/VQA_VLE_LLM 161 | 162 | 163 | ## Usage 164 | 165 | **Requirements** 166 | 167 | * PIL 168 | * Transformers >= 4.25 169 | * PyTorch Lightning (only required for running fine-tuning scripts) 170 | 171 | The model classes and utilities are defined in the `*.py` files in [models/VLE](models/VLE). To import VLE into your code, just copy [models](models) directory into your project. 172 | 173 | To run the following demo code, `git clone` the repository and `cd` into it, ensuring you are in the repository's root directory. 174 | 175 | ### Load the VLEModel 176 | 177 | ```python 178 | from models.VLE import VLEModel, VLEProcessor 179 | from PIL import Image 180 | import torch 181 | 182 | model_name="hfl/vle-large" 183 | images = [Image.open('pics/dogs.png')] 184 | text = ["There are dogs on the grass."] 185 | 186 | model = VLEModel.from_pretrained(model_name) 187 | vle_processor = VLEProcessor.from_pretrained(model_name) 188 | multimodal_inputs = vle_processor(text=text,images=images, return_tensors='pt',padding=True) 189 | 190 | #forward 191 | vle_output = model(**multimodal_inputs) 192 | ``` 193 | 194 | ### Inference 195 | 196 | #### Visual Question Answering (VQA) 197 | 198 | ```python 199 | from models.VLE import VLEForVQA, VLEProcessor, VLEForVQAPipeline 200 | from PIL import Image 201 | 202 | model_name="hfl/vle-base-for-vqa" 203 | text= "What is the color of the floor?" 204 | image = Image.open("pics/door.png") 205 | 206 | model = VLEForVQA.from_pretrained(model_name) 207 | vle_processor = VLEProcessor.from_pretrained(model_name) 208 | vqa_pipeline = VLEForVQAPipeline(model=model, device='cpu', vle_processor=vle_processor) 209 | 210 | vqa_answers = vqa_pipeline(image=image, question=text, top_k=5) 211 | print(f"Question: {text}. 
Answers: {vqa_answers}") 212 | ``` 213 | 214 | #### Image-Text Matching 215 | 216 | ```python 217 | from models.VLE import VLEForITM, VLEProcessor, VLEForITMPipeline 218 | from PIL import Image 219 | 220 | model_dir = 'hfl/vle-base' 221 | itm_text = ["a photo of a cat.", "a photo of dogs."] 222 | itm_images = Image.open("pics/dogs.png") 223 | 224 | print("Init ITM model") 225 | model = VLEForITM.from_pretrained(model_dir) 226 | vle_processor = VLEProcessor.from_pretrained(model_dir) 227 | 228 | print("init ITM pipeline") 229 | itm_pipeline = VLEForITMPipeline(model=model, device='cpu', vle_processor=vle_processor) 230 | itm_pred = itm_pipeline([{"image": itm_images, "text": itm_text[0]}, 231 | {"image": itm_images, "text": itm_text[1]}]) 232 | 233 | for t, pred in zip(itm_text,itm_pred): 234 | print(t,pred) 235 | ``` 236 | 237 | #### Patch Box Classification 238 | 239 | ```python 240 | from models.VLE import VLEForPBC, VLEProcessor, VLEForPBCPipeline 241 | from PIL import Image 242 | 243 | model_dir = 'hfl/vle-base' 244 | pbc_text = "pink tongues" 245 | pbc_image = Image.open("pics/dogs.png") 246 | 247 | print("Init PBC model") 248 | model = VLEForPBC.from_pretrained(model_dir) 249 | vle_processor = VLEProcessor.from_pretrained(model_dir) 250 | 251 | print("init PBC pipeline") 252 | pbc_pipeline = VLEForPBCPipeline(model=model, device='cpu', vle_processor=vle_processor) 253 | pbc_pred = pbc_pipeline(image=pbc_image,text=pbc_text) 254 | print(pbc_text) 255 | pbc_pred['image'].save('pics/pink_tongues.png') 256 | ``` 257 | 258 | 259 | #### Visual Commonsense Reasoning (VCR) 260 | 261 | Please follow the instructions in [examples/VCR/README.md](examples/VCR/README.md) 262 | 263 | ### Fine-tuning 264 | 265 | #### Fine-tuning on VQA 266 | 267 | Please follow the instructions in [examples/VQA/README.md](examples/VQA/README.md) 268 | 269 | ## Follow us 270 | 271 | Welcome to follow the official WeChat account of HFL to keep up with the latest technical developments. 272 | 273 | ![qrcode.png](pics/qrcode.jpg) 274 | 275 | ## Disclaimer 276 | This repository's resources are solely intended for academic purposes, and we assume no responsibility for any unforeseen damages or losses that may result from their use. 277 | 278 | This is not an official product by iFLYTEK Co., Ltd. 279 | 280 | ## Issues 281 | 282 | If you have questions, please submit them in a GitHub Issue. 283 | 284 | - Before submitting an issue, please check whether the FAQ can solve the problem, and it is recommended to check whether the previous issue can solve your problem. 285 | - Duplicate and unrelated issues will be handled by [stable-bot](stale · GitHub Marketplace). 286 | - We will try our best to answer your questions, but there is no guarantee that your questions will be answered. 287 | - Politely ask questions and build a harmonious discussion community. 288 | -------------------------------------------------------------------------------- /README_ZH.md: -------------------------------------------------------------------------------- 1 | [**中文**](README_ZH.md) | [**English**](https://github.com/iflytek/VLE) 2 | 3 |

4 | [README 页眉：横幅图片与 GitHub 徽章（license、repo size、top language、last commit）]

16 | 17 | 18 | # VLE:视觉-语言多模态预训练模型 19 | 20 | 多模态预训练模型通过在多种模态的大规模数据上的预训练,可以综合利用来自不同模态的信息,执行各种跨模态任务。在本项目中,我们推出了**VLE** (**V**ision-**L**anguage **E**ncoder),一种基于预训练文本和图像编码器的图像-文本多模态理解模型,可应用于如视觉问答、图像-文本检索等多模态判别任务。特别地,在对语言理解和推理能力有更强要求的视觉常识推理(VCR)任务中,VLE取得了公开模型中的最佳效果。 21 | 22 | 最近,大型语言模型(LLM)取得了巨大成功,并被用于翻译、问答、摘要等文本任务。虽然LLM是单模态模型,但它们的能力也可用于辅助多模态理解任务。借助LLM的zero-shot能力,我们设计了一种VQA+LLM方案,将大型语言模型集成到视觉问答任务中,实现了帮助视觉问答模型生成更准确和流畅的答案。 23 | 24 | 我们开源VLE相关资源以供学术研究参考。 25 | 26 | 在线演示地址:https://huggingface.co/spaces/hfl/VQA_VLE_LLM 27 | 28 | ---- 29 | 30 | [中文LERT](https://github.com/ymcui/LERT) | [中英文PERT](https://github.com/ymcui/PERT) | [中文MacBERT](https://github.com/ymcui/MacBERT) | [中文MiniRBT](https://github.com/iflytek/MiniRBT) | [中文ELECTRA](https://github.com/ymcui/Chinese-ELECTRA) | [中文XLNet](https://github.com/ymcui/Chinese-XLNet) | [中文BERT](https://github.com/ymcui/Chinese-BERT-wwm) | [知识蒸馏工具TextBrewer](https://github.com/airaria/TextBrewer) | [模型裁剪工具TextPruner](https://github.com/airaria/TextPruner) 31 | 32 | 查看更多哈工大讯飞联合实验室(HFL)发布的资源:https://github.com/iflytek/HFL-Anthology 33 | 34 | ## 内容导引 35 | 36 | | 章节 | 描述 | 37 | | ----------------------------- | ----------------------------------------------------------- | 38 | | [简介](#简介) | VLE的结构和训练技术 | 39 | | [模型下载](#模型下载) | VLE预训练模型下载地址 | 40 | | [模型对比](#模型对比) | 多模态任务上VLE与其他模型的比较 | 41 | | [结合大模型的视觉问答](#结合大模型的视觉问答) | 结合大模型的视觉问答策略 | 42 | | [模型使用](#模型使用) | 模型的加载与使用方法 | 43 | 44 | ## 简介 45 | 46 | ### 模型结构 47 | 48 | VLE模型采用双流结构,与METER模型结构类似,由两个单模态编码器(图像编码器和文本编码器)和一个跨模态融合模块构成。VLE与METER的结构上的差异在于: 49 | 50 | * VLE使用DeBERTa-v3作为文本编码器,其性能优于METER中使用的RoBERTa-base。 51 | * 在VLE-large中,跨模态融合模块的隐层维度增加至1024,以增加模型的容量。 52 | * 在精调阶段,VLE引入了额外的token类型向量表示。 53 | 54 | ### 预训练 55 | 56 | VLE使用图文对数据进行预训练。在预训练阶段,VLE采用了四个预训练任务: 57 | 58 | * **MLM** (Masked Language Modeling):掩码预测任务。给定图文对,随机遮掩文本中的部分单词,训练模型还原遮掩的文本。 59 | * **ITM** (Image-Text Matching):图文匹配预测任务。给定图文对,训练模型判断图像和文本是否匹配。 60 | * **MPC** (Masked Patch-box Classification):遮掩Patch分类任务,给定图文对,并遮掩掉图片中包含具体对象的patch,训练模型预测被遮掩的对象种类。 61 | * **PBC** (Patch-box classification):Patch分类任务。给定图文对,预测图片中的哪些patch与文本描述相关。 62 | 63 | VLE在14M的英文图文对数据上进行了25000步的预训练,batch大小为2048。下图展示了VLE的模型结构和部分预训练任务(MLM、ITM和MPC)。 64 | 65 | 66 |
[图示：VLE structure and pre-training tasks]
67 | 68 | ### 下游任务适配 69 | 70 | #### 视觉问答 (VQA) 71 | 72 | * 我们遵循标准做法,使用VQA的训练集(training set)和验证集(validation set)训练模型,在test-dev集上进行验证。我们采用模型的融合层的pooler的输出进行分类任务的训练。 73 | 74 | #### 视觉常识推理 (VCR) 75 | 76 | * 我们将VCR格式化为一个类似于RACE的选择题任务,并对于每张图像中的对象,将覆盖该对象的patch的表示的平均池化值添加到融合模块之前的图像特征序列中。我们还为图像和文本中的对象添加额外的token_type_ids,以注入不同模态之间的对齐信息,提升模型的对齐性能。 77 | 78 | 79 | ## 模型下载 80 | 81 | 本次发布了VLE-base和VLE-large两个版本的预训练模型,模型权重为PyTorch格式,可以选择手动从🤗 transformers模型库下载权重和配置文件,或者在代码中使用 `from_pretrained(model_name)` 以自动加载模型。详细方法参加[模型使用](#模型使用)。 82 | 83 | ### 预训练权重 84 | 85 | | 模型 | 文本编码器 | 图像编码器 | 参数量* | MODEL_NAME | 链接 | 86 | | --------- | ---------------- | ---------------------- | ------------------ | ------------- | -------------------------------------------- | 87 | | VLE-base | DeBERTa-v3-base | CLIP-ViT-base-patch16 | 378M | hfl/vle-base | [link](https://huggingface.co/hfl/vle-base) | 88 | | VLE-large | DeBERTa-v3-large | CLIP-ViT-large-patch14 | 930M | hfl/vle-large | [link](https://huggingface.co/hfl/vle-large) | 89 | 90 | * : 仅计算encoder和emebddings的参数。特定任务的预测层的参数量未计入。 91 | 92 | ### 精调权重 93 | 94 | | 模型 | 文本编码器 | 图像编码器 | MODEL_NAME | 链接 | 95 | | ---------------------- | ---------------- | ---------------------- | -------------------------- | --------------------------------------------------------- | 96 | | VLE-base-for-VQA | DeBERTa-v3-base | CLIP-ViT-base-patch16 | hfl/vle-base-for-vqa | [link](https://huggingface.co/hfl/vle-base-for-vqa) | 97 | | VLE-large-for-VQA | DeBERTa-v3-large | CLIP-ViT-large-patch14 | hfl/vle-large-for-vqa | [link](https://huggingface.co/hfl/vle-large-for-vqa) | 98 | | VLE-base-for-VCR-q2a | DeBERTa-v3-base | CLIP-ViT-base-patch16 | hfl/vle-base-for-vcr-q2a | [link](https://huggingface.co/hfl/vle-base-for-vcr-q2a) | 99 | | VLE-large-for-VCR-q2a | DeBERTa-v3-large | CLIP-ViT-large-patch14 | hfl/vle-large-for-vcr-q2a | [link](https://huggingface.co/hfl/vle-large-for-vcr-q2a) | 100 | | VLE-base-for-VCR-qa2r | DeBERTa-v3-base | CLIP-ViT-base-patch16 | hfl/vle-base-for-vcr-qa2r | [link](https://huggingface.co/hfl/vle-base-for-vcr-qa2r) | 101 | | VLE-large-for-VCR-qa2r | DeBERTa-v3-large | CLIP-ViT-large-patch14 | hfl/vle-large-for-vcr-qa2r | [link](https://huggingface.co/hfl/vle-large-for-vcr-qa2r) | 102 | 103 | ## 模型对比 104 | 105 | 在下表中,我们比较了VLE、METER以及其他多模态模型的参数量、预训练数据和下游任务效果。其中VQA展示的的是test-dev集上的效果;VCR展示的是dev集上的效果。 106 | 107 | | 模型 | VQA | VCR (QA2R) | VCR (Q2A) | 参数量 | 预训练数据量* | 108 | | ------------------- | ---------------- | -------------- | ------------- | ------------ | ------- | 109 | | CoCa | 82.3 | - | - | 2.1 B | 未知 | 110 | | BeiT-3 | 84.2 | - | - | 1.9 B | 21M(I-T) + 14M(I) + 160G(T) | 111 | | OFA | 82.0 | - | - | 930M | 20M(I-T) + 39M(I) + 140G(T) | 112 | | BLIP | 78.3 | - | - | 385M | ~130M(I-T) | 113 | | METER-base | 77.7 (76.8†‡) | 79.8§ | 77.6§ | 345M | 9M(I-T) | 114 | | METER-Huge | 80.3 | - | - | 878M | 20M(I-T) | 115 | | VLE-base | 77.6 | 83.7§ | 79.9§ | 378M | 15M(I-T) | 116 | | VLE-large | 79.3 | 87.5§ | 84.3§ | 930M | 15M(I-T) | 117 | 118 | : 复现效果 119 | 120 | : 精调参数: lr=7e-6, batch_size={256, 512}, num_epochs=10 121 | 122 | § : 精调参数: lr=1e-5, batch_size=128, num_epochs=5 123 | 124 | * : I-T: 图文对. I: 图像. T: 文本. 
125 | 126 | 观察上表可以发现: 127 | 128 | * **VLE的预训练更高效**:与大小相近的模型相比,VLE使用了更少的预训练数据,并在视觉问答上取得了相当甚至更好的效果。 129 | * **VLE有更强的推理能力**: 特别地,在对推理能力要求更高的视觉常识推理(VCR)任务上,VLE显著地超过了具有相似结构的METER。 130 | 131 | ## 结合大模型的视觉问答 132 | 133 | 最近,随着指令微调、RLHF等技术的发展,LLM在多种文本任务中取得了巨大的成功。尽管LLM是单模态模型,但它们的能力也可用于辅助多模态理解任务。具体而言,我们提出一种VQA + LLM方案,将多模态模型与LLM集成到视觉问答任务中,从而帮助VQA模型生成更准确和流畅的答案。下图展示了系统流程。 134 | 135 |
[图示：Workflows]
136 | 137 | (a) VQA: 这是使用判别模型执行VQA任务的标准方式。输入问题和图像到多模态模型中,训练模型预测正确的答案标签。 138 | 139 | (b) VQA + LLM: 首先利用captioning模型生成图片的描述;将图片描述、问题以及VQA模型的详细预测结果拼接,组合成合适的prompt的形式送入LLM,最后要求LLM模型回复最合理的答案。 140 | 141 | VQA+LLM生成的答案更准确,也有更高的可读性。下面是一些例子: 142 | 143 |
[示例图片：men and truck]
144 | 145 | 146 | 147 |
[示例图片：hatch]
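下面给出一个仅供参考的示意代码，演示如何将上文提到的图片描述、问题和VQA模型的候选答案拼接成送入LLM的prompt。Demo中实际使用的prompt模板以及VQA模型输出的具体格式在本仓库中并未给出，以下写法均为假设。

```python
def build_vqa_llm_prompt(caption, question, vqa_answers):
    """Assemble an LLM prompt from an image caption, a question, and VQA candidates.

    vqa_answers is assumed to be a list of {"answer": str, "score": float} dicts,
    e.g. the top-k predictions of a VQA model; adapt the keys to the actual
    output format of the pipeline you use.
    """
    candidates = "; ".join(f'{a["answer"]} (score {a["score"]:.2f})' for a in vqa_answers)
    return (
        f"Image caption: {caption}\n"
        f"Question: {question}\n"
        f"Candidate answers from the VQA model: {candidates}\n"
        "Give the most reasonable answer to the question in one fluent sentence."
    )

# 示例输入仅用于演示。
print(build_vqa_llm_prompt(
    caption="Two men are standing next to a truck.",
    question="What are the men doing?",
    vqa_answers=[{"answer": "talking", "score": 0.62}, {"answer": "loading", "score": 0.21}],
))
```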
148 | 149 | Demo地址(仅供学术研究):https://huggingface.co/spaces/hfl/VQA_VLE_LLM 150 | 151 | 152 | 153 | 154 | ## 模型使用 155 | 156 | **环境要求** 157 | 158 | * PIL 159 | * Transformers >= 4.25 160 | * PyTorch Lightning (仅用于运行精调脚本) 161 | 162 | 模型相关代码位于[models/VLE](models/VLE)目录下的`*py`文件中。因此,要使用VLE模型,仅需把[models](models)目录复制到你的项目代码目录即可。 163 | 164 | 要运行以下演示代码,请使用`git clone`命令下载本仓库至本地,并进入仓库的根目录。 165 | 166 | ### 加载VLEModel 167 | 168 | ```python 169 | from models.VLE import VLEModel, VLEProcessor 170 | from PIL import Image 171 | import torch 172 | 173 | model_name="hfl/vle-large" 174 | images = [Image.open('pics/dogs.png')] 175 | text = ["There are dogs on the grass."] 176 | 177 | model = VLEModel.from_pretrained(model_name) 178 | vle_processor = VLEProcessor.from_pretrained(model_name) 179 | multimodal_inputs = vle_processor(text=text,images=images, return_tensors='pt',padding=True) 180 | 181 | #forward 182 | vle_output = model(**multimodal_inputs) 183 | ``` 184 | 185 | ### 推理 186 | 187 | #### 视觉问答 (VQA) 188 | 189 | ```python 190 | from models.VLE import VLEForVQA, VLEProcessor, VLEForVQAPipeline 191 | from PIL import Image 192 | 193 | model_name="hfl/vle-base-for-vqa" 194 | text= "What is the color of the floor?" 195 | image = Image.open("pics/door.png") 196 | 197 | model = VLEForVQA.from_pretrained(model_name) 198 | vle_processor = VLEProcessor.from_pretrained(model_name) 199 | vqa_pipeline = VLEForVQAPipeline(model=model, device='cpu', vle_processor=vle_processor) 200 | 201 | vqa_answers = vqa_pipeline(image=image, question=text, top_k=5) 202 | print(f"Question: {text}. Answers: {vqa_answers}") 203 | ``` 204 | 205 | #### 图文匹配(ITM) 206 | 207 | ```python 208 | from models.VLE import VLEForITM, VLEProcessor, VLEForITMPipeline 209 | from PIL import Image 210 | 211 | model_dir = 'hfl/vle-base' 212 | itm_text = ["a photo of a cat.", "a photo of dogs."] 213 | itm_images = Image.open("pics/dogs.png") 214 | 215 | print("Init ITM model") 216 | model = VLEForITM.from_pretrained(model_dir) 217 | vle_processor = VLEProcessor.from_pretrained(model_dir) 218 | 219 | print("init ITM pipeline") 220 | itm_pipeline = VLEForITMPipeline(model=model, device='cpu', vle_processor=vle_processor) 221 | itm_pred = itm_pipeline([{"image": itm_images, "text": itm_text[0]}, 222 | {"image": itm_images, "text": itm_text[1]}]) 223 | 224 | for t, pred in zip(itm_text,itm_pred): 225 | print(t,pred) 226 | ``` 227 | 228 | #### Patch分类(PBC) 229 | 230 | ```python 231 | from models.VLE import VLEForPBC, VLEProcessor, VLEForPBCPipeline 232 | from PIL import Image 233 | 234 | model_dir = 'hfl/vle-base' 235 | pbc_text = "pink tongues" 236 | pbc_image = Image.open("pics/dogs.png") 237 | 238 | print("Init PBC model") 239 | model = VLEForPBC.from_pretrained(model_dir) 240 | vle_processor = VLEProcessor.from_pretrained(model_dir) 241 | 242 | print("init PBC pipeline") 243 | pbc_pipeline = VLEForPBCPipeline(model=model, device='cpu', vle_processor=vle_processor) 244 | pbc_pred = pbc_pipeline(image=pbc_image,text=pbc_text) 245 | print(pbc_text) 246 | pbc_pred['image'].save('pics/pink_tongues.png') 247 | ``` 248 | 249 | 250 | #### 视觉常识推理(VCR) 251 | 252 | 详细步骤参见 [examples/VCR/README.md](examples/VCR/README.md) 253 | 254 | ### 精调 255 | 256 | #### VQA任务精调 257 | 258 | 详细步骤参见 [examples/VQA/README.md](examples/VQA/README.md) 259 | 260 | ## 关注我们 261 | 欢迎关注哈工大讯飞联合实验室官方微信公众号,了解最新的技术动态。 262 | 263 | ![qrcode.png](pics/qrcode.jpg) 264 | 265 | ## 免责声明 266 | 本项目相关资源应仅用于学术研究。我们不对使用相关资源产生的任何损失负责。 267 | 268 | 本目录相关内容不属于科大讯飞官方产品。 269 | 270 | ## 问题反馈 271 | 如有问题,请在GitHub 
Issue中提交。 272 | 273 | - 在提交问题之前,请先查看FAQ能否解决问题,同时建议查阅以往的issue是否能解决你的问题。 274 | - 重复以及与本项目无关的issue会被[stable-bot](stale · GitHub Marketplace)处理,敬请谅解。 275 | - 我们会尽可能的解答你的问题,但无法保证你的问题一定会被解答。 276 | - 礼貌地提出问题,构建和谐的讨论社区。 277 | -------------------------------------------------------------------------------- /examples/VCR/README.md: -------------------------------------------------------------------------------- 1 | # Inference on Visual Commonsense Reasoning (VCR) 2 | 3 | ## Dataset Preparation for VCR 4 | 5 | Download the VCR dataset from [VCR official site](https://visualcommonsense.com/download/), including [Annotations](https://s3.us-west-2.amazonaws.com/ai2-rowanz/vcr1annots.zip) and [Images](https://s3.us-west-2.amazonaws.com/ai2-rowanz/vcr1images.zip). 6 | Unzip the downloaded file. 7 | 8 | ## Inference with VCR pipeline 9 | 10 | Here are two examples of using fine-tuned VLEForVCR Q2A and QA2R models to infer on a VCR sample. 11 | The sample's image is placed in `vcr_sample_images/` as follows, and the annotation `meta_data` is taken from VCR validation set. 12 | 13 | vcr_sample_images 14 | └──lsmdc_1054_Harry_Potter_and_the_prisoner_of_azkaban 15 | ├──1054_Harry_Potter_and_the_prisoner_of_azkaban_00.01.46.736-00.01.50.168_0.jpg 16 | └──1054_Harry_Potter_and_the_prisoner_of_azkaban_00.01.46.736-00.01.50.168_0.json 17 | 18 | ### VCR Q2A 19 | ```python 20 | from models.VLE import VLEForVCRQ2A, VLEProcessor, VLEForVCRQ2APipeline 21 | 22 | model_name = 'hfl/vle-large-for-vcr-q2a' 23 | model = VLEForVCRQ2A.from_pretrained(model_name) 24 | vle_processor = VLEProcessor.from_pretrained(model_name) 25 | vcr_q2a_pipeline = VLEForVCRQ2APipeline(model=model, device='cpu', vle_processor=vle_processor) 26 | 27 | vcr_image_root = 'examples/VCR/vcr_sample_images' 28 | meta_data = {"movie": "1054_Harry_Potter_and_the_prisoner_of_azkaban", "objects": ["person", "person", "person", "car", "cellphone", "clock"], "interesting_scores": [-1, 0], "answer_likelihood": "possible", "img_fn": "lsmdc_1054_Harry_Potter_and_the_prisoner_of_azkaban/1054_Harry_Potter_and_the_prisoner_of_azkaban_00.01.46.736-00.01.50.168@0.jpg", "metadata_fn": "lsmdc_1054_Harry_Potter_and_the_prisoner_of_azkaban/1054_Harry_Potter_and_the_prisoner_of_azkaban_00.01.46.736-00.01.50.168@0.json", "answer_orig": "No, 1 is a visitor.", "question_orig": "Does 1 live in this house?", "rationale_orig": "1 is wearing outerwear, holding an umbrella, and there is a car outside.", "question": ["Does", [0], "live", "in", "this", "house", "?"], "answer_match_iter": [2, 3, 0, 1], "answer_sources": [10104, 5332, 1, 16646], "answer_choices": [["No", ",", [0], "lives", "nowhere", "close", "."], ["Yes", ",", [0], "works", "there", "."], ["No", ",", [0], "is", "a", "visitor", "."], ["No", [1], "does", "not", "belong", "here", "."]], "answer_label": 2, "rationale_choices": [[[0], "is", "nicely", "dressed", "with", "a", "tie", ".", "people", "dress", "up", "when", "they", "visit", "someone", "else", "."], [[2], "sits", "comfortably", "in", "a", "chair", ",", "reading", "papers", ",", "while", "it", "seems", [0], "has", "just", "arrived", "and", "is", "settling", "in", "."], [[1], "is", "wearing", "a", "coat", "and", "muff", "and", "is", "sitting", "as", "if", "a", "visitor", "."], [[0], "is", "wearing", "outerwear", ",", "holding", "an", "umbrella", ",", "and", "there", "is", "a", "car", "outside", "."]], "rationale_sources": [26162, 12999, 6661, 1], "rationale_match_iter": [1, 3, 2, 0], "rationale_label": 3, "img_id": "val-0", "question_number": 1, "annot_id": 
"val-1", "match_fold": "val-0", "match_index": 1} 29 | 30 | vcr_outputs = vcr_q2a_pipeline(vcr_image_root=vcr_image_root, meta_inputs=meta_data) 31 | pred = vcr_outputs[0]["pred"] 32 | print(f'Q: {meta_data["question"]}') 33 | print(f'A1: {meta_data["answer_choices"][0]}') 34 | print(f'A2: {meta_data["answer_choices"][1]}') 35 | print(f'A3: {meta_data["answer_choices"][2]}') 36 | print(f'A4: {meta_data["answer_choices"][3]}') 37 | print(f'Label: {meta_data["answer_label"] + 1}') 38 | print(f'predict: {pred[0] + 1}') 39 | ``` 40 | 41 | ### VCR QA2R 42 | ```python 43 | from models.VLE import VLEForVCRQA2R, VLEProcessor, VLEForVCRQA2RPipeline 44 | 45 | model_name = 'hfl/vle-large-for-vcr-qa2r' 46 | model = VLEForVCRQA2R.from_pretrained(model_name) 47 | vle_processor = VLEProcessor.from_pretrained(model_name) 48 | vcr_qa2r_pipeline = VLEForVCRQA2RPipeline(model=model, device='cpu', vle_processor=vle_processor) 49 | 50 | vcr_image_root = 'examples/VCR/vcr_sample_images' 51 | meta_data = {"movie": "1054_Harry_Potter_and_the_prisoner_of_azkaban", "objects": ["person", "person", "person", "car", "cellphone", "clock"], "interesting_scores": [-1, 0], "answer_likelihood": "possible", "img_fn": "lsmdc_1054_Harry_Potter_and_the_prisoner_of_azkaban/1054_Harry_Potter_and_the_prisoner_of_azkaban_00.01.46.736-00.01.50.168@0.jpg", "metadata_fn": "lsmdc_1054_Harry_Potter_and_the_prisoner_of_azkaban/1054_Harry_Potter_and_the_prisoner_of_azkaban_00.01.46.736-00.01.50.168@0.json", "answer_orig": "No, 1 is a visitor.", "question_orig": "Does 1 live in this house?", "rationale_orig": "1 is wearing outerwear, holding an umbrella, and there is a car outside.", "question": ["Does", [0], "live", "in", "this", "house", "?"], "answer_match_iter": [2, 3, 0, 1], "answer_sources": [10104, 5332, 1, 16646], "answer_choices": [["No", ",", [0], "lives", "nowhere", "close", "."], ["Yes", ",", [0], "works", "there", "."], ["No", ",", [0], "is", "a", "visitor", "."], ["No", [1], "does", "not", "belong", "here", "."]], "answer_label": 2, "rationale_choices": [[[0], "is", "nicely", "dressed", "with", "a", "tie", ".", "people", "dress", "up", "when", "they", "visit", "someone", "else", "."], [[2], "sits", "comfortably", "in", "a", "chair", ",", "reading", "papers", ",", "while", "it", "seems", [0], "has", "just", "arrived", "and", "is", "settling", "in", "."], [[1], "is", "wearing", "a", "coat", "and", "muff", "and", "is", "sitting", "as", "if", "a", "visitor", "."], [[0], "is", "wearing", "outerwear", ",", "holding", "an", "umbrella", ",", "and", "there", "is", "a", "car", "outside", "."]], "rationale_sources": [26162, 12999, 6661, 1], "rationale_match_iter": [1, 3, 2, 0], "rationale_label": 3, "img_id": "val-0", "question_number": 1, "annot_id": "val-1", "match_fold": "val-0", "match_index": 1} 52 | 53 | vcr_outputs = vcr_qa2r_pipeline(vcr_image_root=vcr_image_root, meta_inputs=meta_data) 54 | pred = vcr_outputs[0]["pred"] 55 | print(f'Q: {meta_data["question"]}') 56 | print(f'A: {meta_data["answer_choices"][meta_data["answer_label"]]}') 57 | print(f'R1: {meta_data["rationale_choices"][0]}') 58 | print(f'R2: {meta_data["rationale_choices"][1]}') 59 | print(f'R3: {meta_data["rationale_choices"][2]}') 60 | print(f'R4: {meta_data["rationale_choices"][3]}') 61 | print(f'Label: {meta_data["rationale_label"] + 1}') 62 | print(f'predict: {pred[0] + 1}') 63 | ``` 64 | -------------------------------------------------------------------------------- /examples/VCR/vcr_sample_images/.gitkeep: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/iflytek/VLE/e322bb97b4ecd2c1eed11959416aed640a4bf76a/examples/VCR/vcr_sample_images/.gitkeep -------------------------------------------------------------------------------- /examples/VCR/vcr_sample_images/lsmdc_1054_Harry_Potter_and_the_prisoner_of_azkaban/.gitkeep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/iflytek/VLE/e322bb97b4ecd2c1eed11959416aed640a4bf76a/examples/VCR/vcr_sample_images/lsmdc_1054_Harry_Potter_and_the_prisoner_of_azkaban/.gitkeep -------------------------------------------------------------------------------- /examples/VCR/vcr_sample_images/lsmdc_1054_Harry_Potter_and_the_prisoner_of_azkaban/1054_Harry_Potter_and_the_prisoner_of_azkaban_00.01.46.736-00.01.50.168@0.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/iflytek/VLE/e322bb97b4ecd2c1eed11959416aed640a4bf76a/examples/VCR/vcr_sample_images/lsmdc_1054_Harry_Potter_and_the_prisoner_of_azkaban/1054_Harry_Potter_and_the_prisoner_of_azkaban_00.01.46.736-00.01.50.168@0.jpg -------------------------------------------------------------------------------- /examples/VCR/vcr_sample_images/lsmdc_1054_Harry_Potter_and_the_prisoner_of_azkaban/1054_Harry_Potter_and_the_prisoner_of_azkaban_00.01.46.736-00.01.50.168@0.json: -------------------------------------------------------------------------------- 1 | {"boxes": [[955.7418212890625, 52.329559326171875, 1551.5677490234375, 789.0325927734375, 0.9993261098861694], [20.42578125, 32.5933837890625, 916.5200805664062, 787.1964721679688, 0.9991201758384705], [902.0066528320312, 111.65625, 1035.2879638671875, 701.32861328125, 0.9876241683959961], [1403.113037109375, 282.20465087890625, 1542.410888671875, 557.3927612304688, 0.7517433762550354], [785.5349731445312, 527.7738647460938, 841.3390502929688, 657.4290161132812, 0.8897306323051453], [366.7726745605469, 0.0, 487.1645812988281, 79.29119873046875, 0.9390438199043274]], "segms": [[[[1133, 64], [1126, 66], [1115, 69], [1112, 70], [1106, 73], [1100, 75], [1081, 79], [1074, 81], [1072, 82], [1070, 83], [1065, 88], [1061, 95], [1050, 105], [1044, 111], [1042, 114], [1040, 117], [1039, 120], [1038, 127], [1035, 136], [1031, 144], [1030, 152], [1029, 161], [1028, 169], [1028, 171], [1029, 172], [1031, 181], [1033, 192], [1034, 199], [1035, 207], [1035, 215], [1036, 224], [1039, 236], [1042, 247], [1045, 253], [1048, 256], [1048, 257], [1050, 259], [1063, 271], [1066, 275], [1066, 281], [1065, 290], [1064, 299], [1063, 302], [1056, 309], [1053, 311], [1043, 316], [1040, 318], [1028, 330], [1025, 334], [1022, 338], [1016, 347], [1013, 352], [1010, 359], [1004, 367], [1002, 371], [999, 377], [996, 386], [994, 395], [990, 407], [985, 416], [983, 421], [979, 433], [978, 438], [973, 467], [970, 479], [967, 486], [965, 495], [964, 500], [963, 506], [962, 512], [961, 521], [960, 532], [959, 548], [959, 624], [960, 643], [961, 654], [962, 664], [963, 672], [964, 678], [965, 683], [966, 687], [968, 692], [971, 697], [973, 703], [978, 720], [979, 722], [980, 724], [982, 727], [988, 733], [992, 738], [995, 742], [1000, 750], [1005, 755], [1011, 760], [1023, 769], [1026, 771], [1029, 773], [1031, 774], [1034, 775], [1037, 776], [1041, 777], [1046, 778], [1051, 779], [1066, 779], [1067, 780], [1081, 780], [1120, 779], [1153, 777], [1159, 777], [1171, 779], [1193, 778], [1203, 778], [1218, 780], 
[1227, 780], [1260, 780], [1268, 780], [1325, 778], [1394, 778], [1429, 775], [1458, 775], [1476, 778], [1497, 778], [1514, 780], [1522, 780], [1523, 779], [1532, 779], [1540, 778], [1542, 776], [1543, 774], [1544, 772], [1547, 763], [1548, 759], [1549, 754], [1550, 749], [1550, 587], [1549, 575], [1548, 565], [1546, 555], [1545, 550], [1543, 543], [1542, 540], [1541, 538], [1540, 536], [1536, 531], [1530, 519], [1527, 509], [1521, 497], [1517, 492], [1514, 486], [1512, 480], [1506, 462], [1504, 451], [1501, 442], [1498, 435], [1493, 426], [1490, 417], [1488, 410], [1484, 390], [1481, 381], [1480, 379], [1479, 377], [1473, 370], [1471, 366], [1467, 358], [1466, 353], [1463, 347], [1461, 344], [1459, 341], [1451, 333], [1449, 330], [1447, 325], [1445, 322], [1441, 318], [1441, 317], [1438, 314], [1427, 304], [1423, 297], [1417, 291], [1413, 289], [1409, 286], [1405, 282], [1402, 278], [1401, 275], [1399, 272], [1392, 265], [1388, 263], [1378, 253], [1376, 249], [1368, 241], [1360, 236], [1356, 233], [1351, 229], [1339, 217], [1336, 215], [1328, 211], [1323, 210], [1310, 208], [1301, 206], [1297, 205], [1293, 204], [1290, 203], [1287, 202], [1285, 201], [1283, 200], [1268, 186], [1260, 178], [1258, 175], [1257, 173], [1256, 171], [1255, 166], [1254, 158], [1254, 149], [1253, 141], [1252, 135], [1251, 130], [1250, 125], [1249, 121], [1248, 117], [1246, 114], [1233, 101], [1231, 98], [1228, 92], [1225, 88], [1221, 84], [1218, 83], [1203, 78], [1199, 77], [1191, 76], [1183, 74], [1176, 71], [1169, 70], [1159, 66], [1155, 65], [1145, 64]]], [[[253, 38], [244, 39], [236, 40], [230, 41], [227, 42], [224, 43], [211, 50], [206, 52], [203, 53], [176, 61], [171, 63], [166, 65], [164, 66], [159, 70], [157, 72], [157, 73], [154, 77], [149, 83], [134, 97], [128, 103], [125, 107], [123, 110], [121, 114], [118, 120], [117, 123], [116, 126], [115, 133], [112, 143], [110, 148], [104, 163], [103, 165], [102, 167], [100, 169], [99, 171], [98, 173], [96, 178], [95, 181], [94, 186], [92, 199], [91, 208], [92, 218], [93, 228], [94, 236], [95, 243], [97, 254], [98, 259], [99, 262], [100, 265], [105, 275], [108, 284], [109, 288], [110, 292], [111, 297], [114, 315], [115, 323], [116, 341], [119, 362], [120, 368], [123, 374], [125, 377], [135, 386], [138, 389], [138, 390], [142, 394], [143, 396], [144, 398], [147, 406], [152, 423], [154, 427], [157, 433], [159, 436], [161, 439], [169, 447], [171, 450], [171, 452], [170, 467], [169, 477], [168, 479], [167, 481], [163, 486], [159, 494], [157, 499], [154, 508], [153, 512], [152, 516], [150, 526], [148, 541], [146, 550], [145, 554], [144, 558], [143, 560], [141, 563], [124, 580], [122, 583], [121, 585], [119, 591], [113, 611], [112, 613], [103, 622], [96, 626], [89, 633], [87, 636], [86, 638], [85, 640], [84, 645], [81, 652], [80, 654], [79, 656], [76, 661], [74, 664], [58, 680], [55, 684], [52, 689], [50, 694], [48, 697], [46, 700], [30, 716], [28, 719], [27, 721], [26, 728], [25, 736], [24, 745], [24, 748], [26, 755], [29, 764], [32, 771], [33, 773], [36, 776], [38, 777], [45, 779], [57, 782], [62, 783], [68, 784], [76, 784], [106, 783], [121, 782], [134, 781], [142, 780], [149, 779], [163, 776], [174, 776], [200, 777], [218, 779], [238, 780], [263, 781], [357, 782], [500, 782], [554, 782], [584, 781], [593, 780], [601, 779], [609, 778], [614, 777], [623, 774], [631, 770], [637, 768], [643, 767], [646, 766], [651, 764], [658, 761], [678, 751], [685, 745], [688, 743], [691, 741], [697, 738], [702, 736], [712, 733], [720, 729], [723, 727], [726, 725], [734, 717], [738, 
714], [744, 711], [755, 707], [761, 704], [770, 699], [773, 697], [776, 695], [784, 687], [789, 681], [792, 677], [795, 673], [797, 670], [799, 667], [800, 665], [803, 658], [809, 645], [810, 642], [811, 639], [812, 633], [813, 626], [814, 618], [815, 606], [815, 581], [814, 574], [813, 568], [812, 562], [811, 558], [805, 552], [803, 551], [798, 550], [791, 549], [783, 548], [773, 547], [769, 549], [766, 551], [752, 565], [748, 568], [743, 571], [736, 574], [733, 576], [729, 579], [719, 589], [716, 591], [708, 595], [704, 596], [700, 597], [686, 600], [676, 601], [671, 602], [662, 604], [655, 606], [645, 610], [643, 610], [637, 609], [631, 607], [625, 606], [619, 606], [613, 605], [608, 603], [600, 599], [584, 591], [578, 587], [574, 583], [565, 573], [558, 567], [548, 560], [539, 551], [533, 546], [517, 535], [511, 529], [509, 525], [505, 517], [501, 505], [501, 488], [502, 474], [503, 470], [504, 466], [506, 459], [507, 456], [508, 453], [511, 447], [513, 444], [520, 436], [522, 433], [524, 429], [527, 423], [528, 419], [530, 409], [531, 400], [532, 395], [536, 383], [538, 378], [542, 370], [550, 363], [554, 359], [558, 354], [561, 349], [564, 343], [565, 338], [566, 326], [567, 320], [568, 314], [571, 302], [576, 289], [577, 285], [577, 279], [581, 262], [582, 255], [583, 246], [584, 237], [584, 234], [583, 231], [581, 226], [577, 219], [575, 214], [574, 211], [574, 208], [569, 196], [566, 180], [565, 177], [563, 172], [561, 167], [560, 165], [559, 163], [556, 158], [554, 155], [548, 149], [541, 144], [536, 139], [531, 133], [527, 126], [522, 121], [518, 118], [514, 116], [507, 113], [502, 110], [498, 107], [493, 103], [485, 95], [479, 91], [466, 83], [459, 78], [455, 74], [452, 72], [449, 70], [445, 68], [439, 65], [434, 63], [425, 60], [421, 59], [411, 58], [388, 57], [378, 56], [370, 55], [362, 54], [355, 53], [348, 52], [342, 51], [336, 50], [331, 49], [326, 48], [312, 44], [297, 41], [290, 40], [281, 39], [270, 38]]], [[[941, 111], [926, 112], [922, 113], [920, 114], [915, 119], [912, 124], [909, 129], [908, 131], [907, 134], [906, 137], [905, 141], [904, 150], [903, 171], [903, 201], [904, 221], [906, 236], [907, 251], [907, 367], [908, 390], [910, 406], [911, 423], [912, 446], [912, 494], [913, 515], [914, 528], [917, 550], [918, 562], [919, 577], [920, 596], [920, 617], [922, 638], [922, 676], [924, 692], [926, 693], [931, 694], [939, 694], [944, 693], [947, 692], [949, 691], [949, 678], [948, 668], [947, 664], [945, 660], [944, 656], [943, 652], [943, 636], [944, 624], [947, 612], [947, 583], [948, 576], [949, 569], [950, 564], [954, 554], [955, 550], [955, 546], [956, 542], [960, 531], [966, 507], [970, 481], [974, 468], [975, 459], [981, 436], [983, 429], [986, 423], [992, 402], [994, 394], [997, 382], [998, 377], [999, 366], [1002, 355], [1002, 332], [999, 316], [999, 309], [1002, 302], [1003, 295], [1004, 293], [1005, 291], [1009, 287], [1010, 284], [1012, 278], [1013, 274], [1013, 269], [1014, 267], [1016, 265], [1017, 263], [1019, 258], [1020, 253], [1021, 247], [1022, 238], [1023, 229], [1024, 217], [1025, 205], [1025, 174], [1024, 159], [1023, 154], [1022, 149], [1021, 145], [1020, 141], [1019, 139], [1018, 137], [1016, 134], [1009, 127], [1006, 126], [1004, 125], [1003, 124], [1003, 123], [1001, 120], [999, 118], [990, 114], [986, 113], [980, 112], [966, 111]]], [[[1437, 282], [1425, 283], [1422, 284], [1420, 285], [1418, 286], [1417, 287], [1416, 289], [1415, 292], [1415, 297], [1422, 304], [1423, 307], [1425, 309], [1428, 311], [1436, 317], [1439, 319], [1442, 320], 
[1445, 322], [1448, 325], [1448, 326], [1451, 329], [1454, 330], [1458, 334], [1460, 338], [1461, 339], [1466, 343], [1467, 345], [1469, 350], [1471, 355], [1472, 358], [1474, 366], [1477, 377], [1478, 382], [1479, 391], [1480, 395], [1482, 398], [1484, 405], [1484, 411], [1485, 416], [1488, 423], [1489, 430], [1493, 438], [1499, 458], [1502, 463], [1503, 467], [1503, 470], [1504, 474], [1509, 483], [1511, 488], [1512, 491], [1514, 501], [1515, 504], [1517, 507], [1517, 510], [1518, 513], [1523, 522], [1524, 528], [1525, 531], [1526, 533], [1534, 542], [1535, 545], [1540, 543], [1540, 539], [1541, 538], [1542, 525], [1542, 478], [1541, 417], [1539, 358], [1539, 350], [1540, 306], [1539, 299], [1538, 295], [1537, 293], [1535, 290], [1533, 290], [1532, 289], [1528, 289], [1527, 288], [1524, 288], [1520, 290], [1495, 290], [1478, 287], [1473, 285], [1466, 284], [1460, 283], [1450, 282]]], [[[807, 529], [805, 530], [803, 532], [800, 536], [799, 539], [798, 540], [797, 540], [795, 542], [794, 546], [793, 550], [793, 566], [792, 568], [791, 570], [786, 575], [786, 601], [787, 604], [791, 608], [793, 611], [794, 613], [795, 617], [796, 633], [796, 651], [797, 653], [799, 656], [801, 656], [802, 657], [818, 656], [830, 657], [833, 657], [834, 656], [839, 655], [840, 651], [840, 573], [839, 544], [838, 541], [836, 539], [835, 534], [833, 531], [831, 530], [826, 529]]], [[[387, 2], [379, 3], [375, 4], [371, 5], [370, 6], [368, 6], [368, 8], [367, 9], [367, 35], [368, 36], [368, 44], [372, 47], [377, 49], [382, 53], [391, 57], [396, 59], [399, 60], [402, 60], [411, 64], [429, 69], [436, 73], [448, 76], [452, 76], [455, 75], [459, 72], [461, 72], [465, 70], [466, 67], [470, 66], [472, 66], [474, 64], [474, 62], [477, 59], [481, 60], [483, 57], [484, 55], [485, 53], [486, 45], [486, 21], [485, 10], [482, 6], [481, 5], [479, 4], [474, 3], [464, 2]]]], "names": ["person", "person", "person", "car", "cellphone", "clock"], "width": 1920, "height": 797} -------------------------------------------------------------------------------- /examples/VQA/README.md: -------------------------------------------------------------------------------- 1 | # Fine-tuning on VQA 2 | 3 | ## Requirements 4 | 5 | We use `Pytorch-Lightning` to fine-tuning the pre-trained VLEModel on VQA. To speedup training, we use `DeepSpeed`. The main packages are as follows: 6 | 7 | ```bash 8 | pytorch_lightning==1.5.10 9 | transformers==4.26.0 10 | deepspeed==0.7.7 11 | Pillow==8.1.0 12 | tqdm==4.64.1 13 | ipdb==0.13.4 14 | numpy==1.21.6 15 | einops==0.3.0 16 | pyarrow==2.0.0 17 | sacred==0.8.2 18 | pandas==1.1.5 19 | timm==0.4.12 20 | ftfy 21 | torchvision~=0.8.2 22 | torch~=1.7.1 23 | ``` 24 | 25 | ## Dataset Preparation for VQAv2 26 | 27 | Download the VQAv2 dataset from [VQA official site](https://visualqa.org/download.html), including COCO [2014 train images](http://images.cocodataset.org/zips/train2014.zip), [2014 val images](http://images.cocodataset.org/zips/val2014.zip), [2015 test images](http://images.cocodataset.org/zips/test2015.zip), annotations ([train](https://s3.amazonaws.com/cvmlp/vqa/mscoco/vqa/v2_Annotations_Train_mscoco.zip), [val](https://s3.amazonaws.com/cvmlp/vqa/mscoco/vqa/v2_Annotations_Val_mscoco.zip)), and questions ([train](https://s3.amazonaws.com/cvmlp/vqa/mscoco/vqa/v2_Questions_Train_mscoco.zip), [val](https://s3.amazonaws.com/cvmlp/vqa/mscoco/vqa/v2_Questions_Val_mscoco.zip), [test](https://s3.amazonaws.com/cvmlp/vqa/mscoco/vqa/v2_Questions_Test_mscoco.zip)). 
28 | 29 | Please unzip and organize the dataset as follows: 30 | 31 | root 32 | ├── train2014 33 | │ ├── COCO_train2014_000000000009.jpg 34 | │ └── ... 35 | ├── val2014 36 | │ ├── COCO_val2014_000000000042.jpg 37 | │ └── ... 38 | ├── test2015 39 | │ ├── COCO_test2015_000000000001.jpg 40 | │ └── ... 41 | ├── v2_OpenEnded_mscoco_train2014_questions.json 42 | ├── v2_OpenEnded_mscoco_val2014_questions.json 43 | ├── v2_OpenEnded_mscoco_test2015_questions.json 44 | ├── v2_OpenEnded_mscoco_test-dev2015_questions.json 45 | ├── v2_mscoco_train2014_annotations.json 46 | └── v2_mscoco_val2014_annotations.json 47 | 48 | We use `pyarrow` to serialize the datasets; the conversion script is `write_vqa.py`. 49 | Please replace the value of `id2label` in the VLE model's config with the generated mapping file `label2answer.json` (example: `config.json` of `hfl/vle-base-for-vqa`). 50 | 51 | ## Fine-tuning VLE on VQAv2 52 | 53 | Hyperparameters for training are set in `vqa_train_config.json`. 54 | 55 | Move the training-related files to the same directory level as `models`, as follows: 56 | 57 | root 58 | ├── models 59 | │ └── VLE 60 | │ └── ... 61 | ├── run_vqav2_ft.py 62 | ├── vqav2_datamodule.py 63 | └── vqav2_train_module.py 64 | 65 | Specify the config file with `--train_config_file` and run the training script `run_vqav2_ft.py`. Here is an example: 66 | 67 | ```bash 68 | export MASTER_ADDR=$DIST_0_IP 69 | export MASTER_PORT=$DIST_0_PORT 70 | export NODE_RANK=$DIST_RANK 71 | python run_vqav2_ft.py --train_config_file=vqa_train_config.json 72 | ``` 73 | 74 | ## Postprocess the checkpoint 75 | 76 | After training, we convert the saved checkpoint so that it can be loaded by `VLEModel`. 77 | 78 | We first convert the DeepSpeed checkpoint to a standard PyTorch checkpoint with the conversion script `zero_to_fp32.py`. If you did not use `DeepSpeed` during training, this step can be skipped. 79 | 80 | ```bash 81 | python zero_to_fp32.py <checkpoint_dir> <output_file> <tag> 82 | # for example: 83 | python zero_to_fp32.py ./logs/VQAv2_seed0_from_vle-base-ft-vqa/version_0/checkpoints/epoch\=0-step\=0.ckpt step\=0.ckpt global_step0 84 | ``` 85 | 86 | Then, we convert the parameter names to the same format as `VLEModel`; the conversion script is `convert_checkpoint_after_ft.py`.
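As a quick sanity check, the converted weights can be loaded back with the classes in `models/VLE`. The snippet below is only a minimal sketch: it assumes you have placed the converted `pytorch_model.bin` in a directory (here called `./vle-base-ft-vqa`, an arbitrary example name) together with the fine-tuned model's `config.json` (with the updated `id2label`) and the processor files copied from the original model directory.

```python
from models.VLE import VLEForVQA, VLEProcessor

# Directory assembled by hand: config.json + processor files + converted pytorch_model.bin
ckpt_dir = "./vle-base-ft-vqa"

model = VLEForVQA.from_pretrained(ckpt_dir)        # loads the converted weights
processor = VLEProcessor.from_pretrained(ckpt_dir)
model.eval()
```

Ready-made inference pipelines (e.g. `VLEForVQAPipeline`) are also available in `models/VLE/pipeline_vle.py`.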
87 | -------------------------------------------------------------------------------- /examples/VQA/convert_checkpoint_after_ft.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | pl_ckpt = "./step=0.ckpt" 4 | output_dir = './' 5 | hf_ckpt = {} 6 | 7 | loaded = torch.load(pl_ckpt,map_location='cpu') 8 | if 'state_dict' in loaded: 9 | sd = loaded['state_dict'] 10 | else: 11 | sd = loaded 12 | for k,v in sd.items(): 13 | if k.startswith('model.'): 14 | new_key = k.replace('model.', '') 15 | hf_ckpt[new_key] = v 16 | elif k.startswith('module.model.'): 17 | new_key = k.replace('module.model.', '') 18 | hf_ckpt[new_key] = v 19 | else: 20 | print("unhandled keys:",k) 21 | 22 | torch.save(hf_ckpt, output_dir + 'pytorch_model.bin') -------------------------------------------------------------------------------- /examples/VQA/run_vqav2_ft.py: -------------------------------------------------------------------------------- 1 | import pytorch_lightning as pl 2 | from pytorch_lightning.plugins import DeepSpeedPlugin 3 | import torch 4 | import json 5 | import copy 6 | import os 7 | os.environ["NCCL_DEBUG"] = "INFO" 8 | import argparse 9 | 10 | from vqav2_datamodule import VQAv2DataModule 11 | from vqav2_train_module import VLEForVQA_PL 12 | from models.VLE.processing_vle import VLEProcessor 13 | 14 | 15 | def main(_config): 16 | _config = copy.deepcopy(_config) 17 | pl.seed_everything(_config["seed"], workers=True) 18 | vle_processor = VLEProcessor.from_pretrained(_config["model_dir"]) 19 | if vle_processor.image_processor.size["shortest_edge"] != _config["image_size"]: 20 | vle_processor.image_processor.crop_size["height"] = _config["image_size"] 21 | vle_processor.image_processor.crop_size["width"] = _config["image_size"] 22 | vle_processor.image_processor.size["shortest_edge"] = _config["image_size"] 23 | dm = VQAv2DataModule(vle_processor, _config) 24 | 25 | model = VLEForVQA_PL(_config) 26 | exp_name = 'VQAv2' 27 | 28 | os.makedirs(_config["log_dir"], exist_ok=True) 29 | checkpoint_callback = pl.callbacks.ModelCheckpoint( 30 | save_top_k=_config["save_top_k"], 31 | verbose=True, 32 | monitor="val/the_metric", 33 | mode="max", 34 | save_last=True, 35 | save_weights_only=_config["save_weights_only"] 36 | ) 37 | logger = pl.loggers.TensorBoardLogger( 38 | _config["log_dir"], 39 | name=f'{exp_name}_seed{_config["seed"]}_from_{_config["model_dir"].split("/")[-1]}', 40 | ) 41 | 42 | lr_callback = pl.callbacks.LearningRateMonitor(logging_interval="step") 43 | callbacks = [checkpoint_callback, lr_callback] 44 | 45 | num_gpus = ( 46 | _config["num_gpus"] 47 | if isinstance(_config["num_gpus"], int) 48 | else len(_config["num_gpus"]) 49 | ) 50 | 51 | grad_steps = max(_config["batch_size"] // ( 52 | _config["per_gpu_batchsize"] * num_gpus * _config["num_nodes"] 53 | ), 1) 54 | 55 | max_steps = _config["max_steps"] if _config["max_steps"] is not None else -1 56 | 57 | if _config["use_deepspeed"]: 58 | deepspeed_config = _config["deepspeed_config"] 59 | if _config["precision"] != "16" and _config["precision"] != 16: 60 | deepspeed_config["fp16"]["enabled"] = False 61 | if _config["precision"] == "bf16": 62 | deepspeed_config["bf16"] = {"enabled": True} 63 | ds_plugin = DeepSpeedPlugin(config=deepspeed_config) 64 | strategy = ds_plugin 65 | else: 66 | strategy = "ddp" 67 | 68 | trainer = pl.Trainer( 69 | gpus=_config["num_gpus"], 70 | num_nodes=_config["num_nodes"], 71 | precision=_config["precision"], 72 | accelerator="gpu", 73 | strategy=strategy, 74 
| benchmark=True, 75 | deterministic=False, 76 | max_epochs=_config["max_epoch"] if max_steps is -1 else 1000, 77 | max_steps=max_steps, 78 | callbacks=callbacks, 79 | logger=logger, 80 | prepare_data_per_node=True, 81 | replace_sampler_ddp=False, 82 | accumulate_grad_batches=grad_steps, 83 | log_every_n_steps=10, 84 | flush_logs_every_n_steps=10, 85 | resume_from_checkpoint=_config["resume_from"], 86 | weights_summary="top", 87 | fast_dev_run=_config["fast_dev_run"], 88 | val_check_interval=_config["val_check_interval"], 89 | num_sanity_val_steps=0, 90 | ) 91 | 92 | if not _config["test_only"]: 93 | trainer.fit(model, datamodule=dm) 94 | torch.cuda.empty_cache() 95 | trainer.test(ckpt_path="best", datamodule=dm) 96 | else: 97 | trainer.test(model, datamodule=dm) 98 | 99 | 100 | if __name__ == '__main__': 101 | parser = argparse.ArgumentParser(description="Args for finetuning VLE on VQAv2.") 102 | parser.add_argument("--train_config_file", type=str, default="vqa_train_config.json", help="Config file for training.") 103 | args = parser.parse_args() 104 | train_config_file = args.train_config_file 105 | train_config = json.load(open(train_config_file, 'r')) 106 | 107 | main(train_config) -------------------------------------------------------------------------------- /examples/VQA/vqa_train_config.json: -------------------------------------------------------------------------------- 1 | { 2 | 3 | "seed": 0, 4 | "model_dir": "hfl/vle-base", 5 | "data_root": "./data/vqav2/vqav2_arrow", 6 | "num_workers": 4, 7 | "batch_size": 16, 8 | "per_gpu_batchsize": 4, 9 | "image_size": 384, 10 | "max_text_len": 50, 11 | "draw_false_image": 0, 12 | "draw_false_text": 0, 13 | "image_only": false, 14 | "log_dir": "logs", 15 | "num_gpus": 1, 16 | "num_nodes": 1, 17 | "max_epoch": 10, 18 | "max_steps": -1, 19 | "precision": 16, 20 | "resume_from": "", 21 | "fast_dev_run": false, 22 | "val_check_interval": 1.0, 23 | "save_top_k": 1, 24 | "save_weights_only": true, 25 | "test_only": false, 26 | "learning_rate":1e-5, 27 | "weight_decay": 0.01, 28 | "lr_mult_head": 50, 29 | "lr_mult_cross_modal": 5, 30 | "end_lr": 0, 31 | "decay_power": 1, 32 | "optim_type": "adamw", 33 | "warmup_steps": 0.1, 34 | "use_deepspeed": true, 35 | "deepspeed_config":{ 36 | "fp16": { 37 | "enabled": true, 38 | "initial_scale_power": 12, 39 | "min_loss_scale": 2e-10, 40 | "loss_scale_window": 128 41 | }, 42 | "zero_optimization": { 43 | "stage": 2, 44 | "reduce_bucket_size": 5e7, 45 | "allgather_bucket_size": 1.25e9, 46 | "overlap_comm": true, 47 | "contiguous_gradients": true 48 | }, 49 | "zero_allow_untested_optimizer": true 50 | } 51 | } 52 | -------------------------------------------------------------------------------- /examples/VQA/vqav2_datamodule.py: -------------------------------------------------------------------------------- 1 | import io 2 | import pyarrow as pa 3 | import os 4 | from copy import deepcopy 5 | from PIL import Image 6 | Image.MAX_IMAGE_PIXELS = 1000000000 7 | 8 | import torch 9 | from pytorch_lightning import LightningDataModule 10 | from torch.utils.data import DataLoader 11 | from torch.utils.data.distributed import DistributedSampler 12 | 13 | 14 | class BaseDataModule(LightningDataModule): 15 | def __init__(self, feature_processor, _config): 16 | super().__init__() 17 | 18 | self.data_dir = _config["data_root"] 19 | 20 | self.num_workers = _config["num_workers"] 21 | self.batch_size = _config["per_gpu_batchsize"] 22 | self.eval_batch_size = self.batch_size 23 | 24 | self.image_size = _config["image_size"] 
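        # Note: "per_gpu_batchsize" is the batch size actually used by the dataloaders below;
        # the global "batch_size" in the config is only used in run_vqav2_ft.py to derive the
        # number of gradient-accumulation steps. "image_size" must match the size configured
        # on the VLEProcessor (run_vqav2_ft.py adjusts the processor when they differ).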
25 | self.max_text_len = _config["max_text_len"] 26 | self.draw_false_image = _config["draw_false_image"] 27 | self.draw_false_text = _config["draw_false_text"] 28 | self.image_only = _config["image_only"] 29 | 30 | self.feature_processor = feature_processor 31 | self.vocab_size = self.feature_processor.tokenizer.vocab_size 32 | self.setup_flag = False 33 | self.config = _config 34 | 35 | @property 36 | def dataset_cls(self): 37 | raise NotImplementedError("return tuple of dataset class") 38 | 39 | @property 40 | def dataset_name(self): 41 | raise NotImplementedError("return name of dataset") 42 | 43 | def set_train_dataset(self): 44 | self.train_dataset = self.dataset_cls( 45 | self.data_dir, 46 | split="train", 47 | image_size=self.image_size, 48 | max_text_len=self.max_text_len, 49 | draw_false_image=self.draw_false_image, 50 | draw_false_text=self.draw_false_text, 51 | image_only=self.image_only, 52 | ) 53 | 54 | def set_val_dataset(self): 55 | self.val_dataset = self.dataset_cls( 56 | self.data_dir, 57 | split="val", 58 | image_size=self.image_size, 59 | max_text_len=self.max_text_len, 60 | draw_false_image=self.draw_false_image, 61 | draw_false_text=self.draw_false_text, 62 | image_only=self.image_only, 63 | ) 64 | 65 | 66 | def set_test_dataset(self): 67 | self.test_dataset = self.dataset_cls( 68 | self.data_dir, 69 | split="test", 70 | image_size=self.image_size, 71 | max_text_len=self.max_text_len, 72 | draw_false_image=self.draw_false_image, 73 | draw_false_text=self.draw_false_text, 74 | image_only=self.image_only, 75 | ) 76 | 77 | def setup(self, stage): 78 | if not self.setup_flag: 79 | self.set_train_dataset() 80 | self.set_val_dataset() 81 | self.set_test_dataset() 82 | 83 | self.train_dataset.feature_processor = self.feature_processor 84 | self.val_dataset.feature_processor = self.feature_processor 85 | self.test_dataset.feature_processor = self.feature_processor 86 | 87 | self.setup_flag = True 88 | 89 | self.train_sampler = DistributedSampler(self.train_dataset, shuffle=True) 90 | self.val_sampler = DistributedSampler(self.val_dataset, shuffle=False) 91 | self.test_sampler = DistributedSampler(self.test_dataset, shuffle=False) 92 | 93 | def train_dataloader(self): 94 | loader = DataLoader( 95 | self.train_dataset, 96 | batch_size=self.batch_size, 97 | sampler=self.train_sampler, 98 | num_workers=self.num_workers, 99 | collate_fn=self.train_dataset.collate, 100 | ) 101 | return loader 102 | 103 | def val_dataloader(self): 104 | loader = DataLoader( 105 | self.val_dataset, 106 | batch_size=self.batch_size, 107 | sampler=self.val_sampler, 108 | num_workers=self.num_workers, 109 | collate_fn=self.val_dataset.collate, 110 | ) 111 | return loader 112 | 113 | def test_dataloader(self): 114 | loader = DataLoader( 115 | self.test_dataset, 116 | batch_size=self.batch_size, 117 | sampler=self.test_sampler, 118 | num_workers=self.num_workers, 119 | collate_fn=self.test_dataset.collate, 120 | ) 121 | return loader 122 | 123 | 124 | class VQAv2DataModule(BaseDataModule): 125 | def __init__(self, *args, **kwargs): 126 | super().__init__(*args, **kwargs) 127 | 128 | @property 129 | def dataset_cls(self): 130 | return VQAv2Dataset 131 | 132 | @property 133 | def dataset_name(self): 134 | return "vqa" 135 | 136 | def setup(self, stage): 137 | super().setup(stage) 138 | 139 | 140 | class BaseDataset(torch.utils.data.Dataset): 141 | def __init__( 142 | self, 143 | data_dir: str, 144 | image_size: int, 145 | names: list, 146 | text_column_name: str = "", 147 | remove_duplicate=True, 148 | 
max_text_len=40, 149 | draw_false_image=0, 150 | draw_false_text=0, 151 | image_only=False, 152 | tokenizer=None, 153 | ): 154 | """ 155 | data_dir : where dataset file *.arrow lives; existence should be guaranteed via DataModule.prepare_data 156 | text_column_name : pyarrow table column name that has list of strings as elements 157 | """ 158 | super().__init__() 159 | 160 | self.text_column_name = text_column_name 161 | self.names = names 162 | self.max_text_len = max_text_len 163 | self.draw_false_image = draw_false_image 164 | self.draw_false_text = draw_false_text 165 | self.image_only = image_only 166 | self.data_dir = data_dir 167 | print(names) 168 | if len(names) != 0: 169 | tables = [ 170 | pa.ipc.RecordBatchFileReader( 171 | pa.memory_map(f"{data_dir}/{name}.arrow", "r") 172 | ).read_all() 173 | if os.path.isfile(f"{data_dir}/{name}.arrow") 174 | else print(f"{data_dir}/{name}.arrow" + " not found.") 175 | for name in names 176 | ] 177 | 178 | self.table_names = list() 179 | for i, name in enumerate(names): 180 | self.table_names += [name] * len(tables[i]) 181 | 182 | self.table = pa.concat_tables(tables, promote=True) 183 | if text_column_name != "": 184 | self.text_column_name = text_column_name 185 | self.all_texts = self.table[text_column_name].to_pandas().tolist() 186 | self.all_texts = ( 187 | [list(set(texts)) for texts in self.all_texts] 188 | if remove_duplicate 189 | else self.all_texts 190 | ) 191 | else: 192 | self.all_texts = list() 193 | else: 194 | self.all_texts = list() 195 | 196 | self.index_mapper = dict() 197 | 198 | if text_column_name != "" and not self.image_only: 199 | j = 0 200 | for i, texts in enumerate(self.all_texts): 201 | for _j in range(len(texts)): 202 | self.index_mapper[j] = (i, _j) 203 | j += 1 204 | else: 205 | for i in range(len(self.table)): 206 | self.index_mapper[i] = (i, None) 207 | 208 | @property 209 | def corpus(self): 210 | return [text for texts in self.all_texts for text in texts] 211 | 212 | def __len__(self): 213 | return len(self.index_mapper) 214 | 215 | def get_raw_image(self, index, image_key="image"): 216 | index, caption_index = self.index_mapper[index] 217 | image_bytes = io.BytesIO(deepcopy(self.table[image_key][index]).as_py()) 218 | image_bytes.seek(0) 219 | return Image.open(image_bytes).convert("RGBA") 220 | 221 | def collate(self, batch): 222 | batch_size = len(batch) 223 | keys = set([key for b in batch for key in b.keys()]) 224 | dict_batch = {k: [dic[k] if k in dic else None for dic in batch] for k in keys} 225 | inputs_batch = {k:torch.concat([dict_batch["inputs"][i][k] for i in range(batch_size)]) for k in dict_batch["inputs"][0].keys()} 226 | dict_batch["inputs"] = inputs_batch 227 | 228 | return dict_batch 229 | 230 | 231 | class VQAv2Dataset(BaseDataset): 232 | def __init__(self, *args, split="", **kwargs): 233 | assert split in ["train", "val", "test"] 234 | self.split = split 235 | 236 | if split == "train": 237 | names = ["vqav2_train", "vqav2_trainable_val"] 238 | # names = ["vqav2_rest_val"] 239 | elif split == "val": 240 | names = ["vqav2_rest_val"] 241 | elif split == "test": 242 | names = ["vqav2_rest_val"] 243 | 244 | super().__init__( 245 | *args, 246 | **kwargs, 247 | names=names, 248 | text_column_name="questions", 249 | remove_duplicate=False, 250 | ) 251 | 252 | def __getitem__(self, index): 253 | image = self.get_raw_image(index) 254 | image_index, question_index = self.index_mapper[index] 255 | text = self.all_texts[image_index][question_index] 256 | model_inputs = 
self.feature_processor(text=text, images=image, return_tensors="pt",padding="max_length", max_length=self.max_text_len) 257 | 258 | qid = self.table["question_id"][image_index][question_index].as_py() 259 | 260 | if self.split != "test": 261 | answers = self.table["answers"][image_index][question_index].as_py() 262 | labels = self.table["answer_labels"][image_index][question_index].as_py() 263 | scores = self.table["answer_scores"][image_index][question_index].as_py() 264 | else: 265 | answers = list() 266 | labels = list() 267 | scores = list() 268 | 269 | return { 270 | "inputs": model_inputs, 271 | "text": text, 272 | "vqa_answer": answers, 273 | "vqa_labels": labels, 274 | "vqa_scores": scores, 275 | "qid": qid, 276 | } -------------------------------------------------------------------------------- /examples/VQA/vqav2_train_module.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import pytorch_lightning as pl 4 | import json 5 | import os 6 | import glob 7 | 8 | from torchmetrics.metric import Metric 9 | from transformers.optimization import AdamW 10 | from transformers import ( 11 | get_polynomial_decay_schedule_with_warmup, 12 | get_cosine_schedule_with_warmup, 13 | ) 14 | 15 | from models.VLE import VLEForVQA 16 | from models.VLE.modeling_vle import extend_position_embedding 17 | 18 | 19 | class Scalar(Metric): 20 | def __init__(self, dist_sync_on_step=False): 21 | super().__init__(dist_sync_on_step=dist_sync_on_step) 22 | self.add_state("scalar", default=torch.tensor(0.0), dist_reduce_fx="sum") 23 | self.add_state("total", default=torch.tensor(0.0), dist_reduce_fx="sum") 24 | 25 | def update(self, scalar): 26 | if isinstance(scalar, torch.Tensor): 27 | scalar = scalar.detach().to(self.scalar.device) 28 | else: 29 | scalar = torch.tensor(scalar).float().to(self.scalar.device) 30 | self.scalar += scalar.item() 31 | self.total += 1 32 | 33 | def compute(self): 34 | if self.total.item() == 0: 35 | return 0 36 | return self.scalar.item() / self.total.item() 37 | 38 | class VQAScore(Metric): 39 | def __init__(self, dist_sync_on_step=False): 40 | super().__init__(dist_sync_on_step=dist_sync_on_step) 41 | self.add_state("score", default=torch.tensor(0.0), dist_reduce_fx="sum") 42 | self.add_state("total", default=torch.tensor(0.0), dist_reduce_fx="sum") 43 | 44 | def update(self, logits, target): 45 | logits, target = ( 46 | logits.detach().float().to(self.score.device), 47 | target.detach().float().to(self.score.device), 48 | ) 49 | logits = torch.max(logits, 1)[1] 50 | one_hots = torch.zeros(*target.size()).to(target) 51 | one_hots.scatter_(1, logits.view(-1, 1), 1) 52 | scores = one_hots * target 53 | 54 | self.score += scores.sum().item() 55 | self.total += len(logits) 56 | 57 | def compute(self): 58 | return self.score / self.total 59 | 60 | 61 | class VLEForVQA_PL(pl.LightningModule): 62 | def __init__(self, config): 63 | super().__init__() 64 | self.save_hyperparameters() 65 | 66 | self.model = VLEForVQA.from_pretrained(config["model_dir"]) 67 | 68 | if config["image_size"] != self.model.config.vision_config.image_size: 69 | patch_size = self.model.config.vision_config.patch_size 70 | position_length_after = (config["image_size"]//self.model.config.vision_config.patch_size)**2 + 1 71 | position_embed_dim = self.model.vle.vision_model.vision_model.embeddings.position_embedding.embedding_dim 72 | 73 | new_state_dict = extend_position_embedding(self.model.state_dict(), patch_size, config["image_size"]) 
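            # Fine-tuning at a resolution different from pre-training: extend_position_embedding
            # produces a state dict whose vision position embeddings are resized for the new
            # image size, and the embedding module and its position_ids buffer are re-created
            # with the new length before the adapted weights are loaded back in.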
74 | self.model.vle.vision_model.vision_model.embeddings.position_embedding = nn.Embedding(position_length_after, position_embed_dim) 75 | self.model.vle.vision_model.vision_model.embeddings.register_buffer("position_ids", torch.arange(position_length_after).expand((1, -1))) 76 | self.model.load_state_dict(new_state_dict) 77 | 78 | for split in ["train", "val"]: 79 | setattr(self, f"{split}_vqa_score", VQAScore()) 80 | setattr(self, f"{split}_vqa_loss", Scalar()) 81 | 82 | def forward(self, batch): 83 | ret = dict() 84 | model_inputs = batch["inputs"] 85 | model_outputs = self.model(**model_inputs,vqa_labels=batch["vqa_labels"], vqa_scores=batch["vqa_scores"], return_loss=True) 86 | vqa_logits = model_outputs["logits"] 87 | vqa_loss = model_outputs["loss"] 88 | 89 | vqa_targets = torch.zeros(vqa_logits.size()[:2]).to(vqa_logits.device) 90 | vqa_labels = batch["vqa_labels"] 91 | vqa_scores = batch["vqa_scores"] 92 | for i, (_label, _score) in enumerate(zip(vqa_labels, vqa_scores)): 93 | for l, s in zip(_label, _score): 94 | vqa_targets[i, l] = s 95 | 96 | ret = { 97 | "vqa_loss": vqa_loss, 98 | "vqa_logits": vqa_logits, 99 | "vqa_targets": vqa_targets, 100 | "vqa_labels": vqa_labels, 101 | "vqa_scores": vqa_scores, 102 | } 103 | 104 | phase = "train" if self.training else "val" 105 | loss = getattr(self, f"{phase}_vqa_loss")(ret["vqa_loss"]) 106 | score = getattr(self, f"{phase}_vqa_score")( 107 | ret["vqa_logits"], ret["vqa_targets"] 108 | ) 109 | self.log(f"vqa/{phase}/loss", loss, batch_size=self.hparams.config["per_gpu_batchsize"]) 110 | self.log(f"vqa/{phase}/score", score, batch_size=self.hparams.config["per_gpu_batchsize"]) 111 | 112 | return ret 113 | 114 | def training_step(self, batch, batch_idx): 115 | output = self(batch) 116 | total_loss = sum([v for k, v in output.items() if "loss" in k]) 117 | 118 | return total_loss 119 | 120 | def training_epoch_end(self, outs): 121 | self.epoch_wrapup() 122 | 123 | def validation_step(self, batch, batch_idx): 124 | output = self(batch) 125 | 126 | def validation_epoch_end(self, outs): 127 | self.epoch_wrapup() 128 | 129 | def test_step(self, batch, batch_idx): 130 | output = self(batch) 131 | ret = dict() 132 | # update vqa answer 133 | id2label = self.model.config.id2label 134 | vqa_logits = output["vqa_logits"] 135 | vqa_preds = vqa_logits.argmax(dim=-1) 136 | vqa_preds = [id2label[pred.item()] for pred in vqa_preds] 137 | questions = batch["text"] 138 | qids = batch["qid"] 139 | ret.update({"qids": qids, "questions": questions, "preds": vqa_preds}) 140 | 141 | return ret 142 | 143 | def test_epoch_end(self, outs): 144 | model_name = self.hparams.config["model_dir"].split("/")[-1] 145 | save_dir = self.trainer.logger.log_dir 146 | rank = torch.distributed.get_rank() 147 | if rank != 0: 148 | save_dir_id = int(save_dir.split('_')[-1]) - 1 149 | save_dir = '_'.join(save_dir.split('_')[:-1]+[str(save_dir_id)]) 150 | qids, preds = list(), list() 151 | for out in outs: 152 | qids += out["qids"] 153 | preds += out["preds"] 154 | 155 | rets = list() 156 | for qid, pred in zip(qids, preds): 157 | rets.append({"question_id": qid, "answer": pred}) 158 | with open(os.path.join(save_dir, f"vqa_submit_{rank}.json"), "w") as fp: 159 | json.dump(rets, fp, indent=4) 160 | 161 | torch.distributed.barrier() 162 | 163 | if rank == 0: 164 | jsons = list() 165 | paths = list(glob.glob(os.path.join(save_dir,"vqa_submit_*.json"))) 166 | for path in paths: 167 | with open(path, "r") as fp: 168 | jsons += json.load(fp) 169 | 
os.makedirs(os.path.join(save_dir,"result"), exist_ok=True) 170 | with open(os.path.join(save_dir, f"result/vqa_submit_{model_name}.json"), "w") as fp: 171 | json.dump(jsons, fp, indent=4) 172 | 173 | torch.distributed.barrier() 174 | os.remove(os.path.join(save_dir, f"vqa_submit_{rank}.json")) 175 | 176 | self.epoch_wrapup(test_mode=True) 177 | 178 | def configure_optimizers(self): 179 | lr = self.hparams.config["learning_rate"] 180 | wd = self.hparams.config["weight_decay"] 181 | 182 | no_decay = [ 183 | "bias", 184 | "LayerNorm.bias", 185 | "LayerNorm.weight", 186 | "norm.bias", 187 | "norm.weight", 188 | "norm1.bias", 189 | "norm1.weight", 190 | "norm2.bias", 191 | "norm2.weight", 192 | ] 193 | head_names = ["vqa_classifier"] 194 | cross_modal_names = ['cross_modal'] 195 | lr_mult_head = self.hparams.config["lr_mult_head"] 196 | lr_mult_cross_modal = self.hparams.config["lr_mult_cross_modal"] 197 | end_lr = self.hparams.config["end_lr"] 198 | decay_power = self.hparams.config["decay_power"] 199 | optim_type = self.hparams.config["optim_type"] 200 | all_grad_parameters = [(n,p) for n,p in self.named_parameters()] 201 | optimizer_grouped_parameters = [ 202 | { 203 | "params": [ 204 | p 205 | for n, p in all_grad_parameters 206 | if not any(nd in n for nd in no_decay) 207 | and not any(bb in n for bb in head_names) 208 | and not any(ht in n for ht in cross_modal_names) 209 | ], 210 | "weight_decay": wd, 211 | "lr": lr, 212 | }, 213 | { 214 | "params": [ 215 | p 216 | for n, p in all_grad_parameters 217 | if any(nd in n for nd in no_decay) 218 | and not any(bb in n for bb in head_names) 219 | and not any(ht in n for ht in cross_modal_names) 220 | ], 221 | "weight_decay": 0.0, 222 | "lr": lr, 223 | }, 224 | { 225 | "params": [ 226 | p 227 | for n, p in all_grad_parameters 228 | if not any(nd in n for nd in no_decay) 229 | and any(bb in n for bb in head_names) 230 | and not any(ht in n for ht in cross_modal_names) 231 | ], 232 | "weight_decay": wd, 233 | "lr": lr * lr_mult_head, 234 | }, 235 | { 236 | "params": [ 237 | p 238 | for n, p in all_grad_parameters 239 | if any(nd in n for nd in no_decay) 240 | and any(bb in n for bb in head_names) 241 | and not any(ht in n for ht in cross_modal_names) 242 | ], 243 | "weight_decay": 0.0, 244 | "lr": lr * lr_mult_head, 245 | }, 246 | { 247 | "params": [ 248 | p 249 | for n, p in all_grad_parameters 250 | if not any(nd in n for nd in no_decay) 251 | and not any(bb in n for bb in head_names) 252 | and any(ht in n for ht in cross_modal_names) 253 | ], 254 | "weight_decay": wd, 255 | "lr": lr * lr_mult_cross_modal, 256 | }, 257 | { 258 | "params": [ 259 | p 260 | for n, p in all_grad_parameters 261 | if any(nd in n for nd in no_decay) 262 | and not any(bb in n for bb in head_names) 263 | and any(ht in n for ht in cross_modal_names) 264 | ], 265 | "weight_decay": 0.0, 266 | "lr": lr * lr_mult_cross_modal, 267 | }, 268 | ] 269 | 270 | if optim_type == "adamw": 271 | optimizer = AdamW( 272 | optimizer_grouped_parameters, lr=lr, eps=1e-8, betas=(0.9, 0.98) 273 | ) 274 | elif optim_type == "adam": 275 | optimizer = torch.optim.Adam(optimizer_grouped_parameters, lr=lr) 276 | elif optim_type == "sgd": 277 | optimizer = torch.optim.SGD(optimizer_grouped_parameters, lr=lr, momentum=0.9) 278 | 279 | if self.trainer.max_steps is -1: 280 | max_steps = ( 281 | len(self.trainer.datamodule.train_dataloader()) 282 | * self.trainer.max_epochs 283 | // self.trainer.accumulate_grad_batches 284 | ) 285 | else: 286 | max_steps = self.trainer.max_steps 287 | 288 | 
warmup_steps = self.hparams.config["warmup_steps"] 289 | if isinstance(self.hparams.config["warmup_steps"], float): 290 | warmup_steps = int(max_steps * warmup_steps) 291 | 292 | if decay_power == "cosine": 293 | scheduler = get_cosine_schedule_with_warmup( 294 | optimizer, num_warmup_steps=warmup_steps, num_training_steps=max_steps, 295 | ) 296 | else: 297 | scheduler = get_polynomial_decay_schedule_with_warmup( 298 | optimizer, 299 | num_warmup_steps=warmup_steps, 300 | num_training_steps=max_steps, 301 | lr_end=end_lr, 302 | power=decay_power, 303 | ) 304 | 305 | sched = {"scheduler": scheduler, "interval": "step"} 306 | 307 | return ( 308 | [optimizer], 309 | [sched], 310 | ) 311 | 312 | def epoch_wrapup(self, test_mode=False): 313 | phase = "train" if self.training else "val" 314 | loss_name = 'vqa' 315 | value = getattr(self, f"{phase}_{loss_name}_score").compute() 316 | self.log(f"{loss_name}/{phase}/score_epoch", value) 317 | getattr(self, f"{phase}_{loss_name}_score").reset() 318 | self.log( 319 | f"{loss_name}/{phase}/loss_epoch", 320 | getattr(self, f"{phase}_{loss_name}_loss").compute(), 321 | ) 322 | getattr(self, f"{phase}_{loss_name}_loss").reset() 323 | 324 | self.log(f"{phase}/the_metric", value) -------------------------------------------------------------------------------- /examples/VQA/write_vqa.py: -------------------------------------------------------------------------------- 1 | import json 2 | import pandas as pd 3 | import pyarrow as pa 4 | import random 5 | import os 6 | 7 | from tqdm import tqdm 8 | from glob import glob 9 | from collections import defaultdict, Counter 10 | 11 | import re 12 | 13 | contractions = { 14 | "aint": "ain't", 15 | "arent": "aren't", 16 | "cant": "can't", 17 | "couldve": "could've", 18 | "couldnt": "couldn't", 19 | "couldn'tve": "couldn't've", 20 | "couldnt've": "couldn't've", 21 | "didnt": "didn't", 22 | "doesnt": "doesn't", 23 | "dont": "don't", 24 | "hadnt": "hadn't", 25 | "hadnt've": "hadn't've", 26 | "hadn'tve": "hadn't've", 27 | "hasnt": "hasn't", 28 | "havent": "haven't", 29 | "hed": "he'd", 30 | "hed've": "he'd've", 31 | "he'dve": "he'd've", 32 | "hes": "he's", 33 | "howd": "how'd", 34 | "howll": "how'll", 35 | "hows": "how's", 36 | "Id've": "I'd've", 37 | "I'dve": "I'd've", 38 | "Im": "I'm", 39 | "Ive": "I've", 40 | "isnt": "isn't", 41 | "itd": "it'd", 42 | "itd've": "it'd've", 43 | "it'dve": "it'd've", 44 | "itll": "it'll", 45 | "let's": "let's", 46 | "maam": "ma'am", 47 | "mightnt": "mightn't", 48 | "mightnt've": "mightn't've", 49 | "mightn'tve": "mightn't've", 50 | "mightve": "might've", 51 | "mustnt": "mustn't", 52 | "mustve": "must've", 53 | "neednt": "needn't", 54 | "notve": "not've", 55 | "oclock": "o'clock", 56 | "oughtnt": "oughtn't", 57 | "ow's'at": "'ow's'at", 58 | "'ows'at": "'ow's'at", 59 | "'ow'sat": "'ow's'at", 60 | "shant": "shan't", 61 | "shed've": "she'd've", 62 | "she'dve": "she'd've", 63 | "she's": "she's", 64 | "shouldve": "should've", 65 | "shouldnt": "shouldn't", 66 | "shouldnt've": "shouldn't've", 67 | "shouldn'tve": "shouldn't've", 68 | "somebody'd": "somebodyd", 69 | "somebodyd've": "somebody'd've", 70 | "somebody'dve": "somebody'd've", 71 | "somebodyll": "somebody'll", 72 | "somebodys": "somebody's", 73 | "someoned": "someone'd", 74 | "someoned've": "someone'd've", 75 | "someone'dve": "someone'd've", 76 | "someonell": "someone'll", 77 | "someones": "someone's", 78 | "somethingd": "something'd", 79 | "somethingd've": "something'd've", 80 | "something'dve": "something'd've", 81 | "somethingll": 
"something'll", 82 | "thats": "that's", 83 | "thered": "there'd", 84 | "thered've": "there'd've", 85 | "there'dve": "there'd've", 86 | "therere": "there're", 87 | "theres": "there's", 88 | "theyd": "they'd", 89 | "theyd've": "they'd've", 90 | "they'dve": "they'd've", 91 | "theyll": "they'll", 92 | "theyre": "they're", 93 | "theyve": "they've", 94 | "twas": "'twas", 95 | "wasnt": "wasn't", 96 | "wed've": "we'd've", 97 | "we'dve": "we'd've", 98 | "weve": "we've", 99 | "werent": "weren't", 100 | "whatll": "what'll", 101 | "whatre": "what're", 102 | "whats": "what's", 103 | "whatve": "what've", 104 | "whens": "when's", 105 | "whered": "where'd", 106 | "wheres": "where's", 107 | "whereve": "where've", 108 | "whod": "who'd", 109 | "whod've": "who'd've", 110 | "who'dve": "who'd've", 111 | "wholl": "who'll", 112 | "whos": "who's", 113 | "whove": "who've", 114 | "whyll": "why'll", 115 | "whyre": "why're", 116 | "whys": "why's", 117 | "wont": "won't", 118 | "wouldve": "would've", 119 | "wouldnt": "wouldn't", 120 | "wouldnt've": "wouldn't've", 121 | "wouldn'tve": "wouldn't've", 122 | "yall": "y'all", 123 | "yall'll": "y'all'll", 124 | "y'allll": "y'all'll", 125 | "yall'd've": "y'all'd've", 126 | "y'alld've": "y'all'd've", 127 | "y'all'dve": "y'all'd've", 128 | "youd": "you'd", 129 | "youd've": "you'd've", 130 | "you'dve": "you'd've", 131 | "youll": "you'll", 132 | "youre": "you're", 133 | "youve": "you've", 134 | } 135 | 136 | manual_map = { 137 | "none": "0", 138 | "zero": "0", 139 | "one": "1", 140 | "two": "2", 141 | "three": "3", 142 | "four": "4", 143 | "five": "5", 144 | "six": "6", 145 | "seven": "7", 146 | "eight": "8", 147 | "nine": "9", 148 | "ten": "10", 149 | } 150 | articles = ["a", "an", "the"] 151 | period_strip = re.compile("(?!<=\d)(\.)(?!\d)") 152 | comma_strip = re.compile("(\d)(\,)(\d)") 153 | punct = [ 154 | ";", 155 | r"/", 156 | "[", 157 | "]", 158 | '"', 159 | "{", 160 | "}", 161 | "(", 162 | ")", 163 | "=", 164 | "+", 165 | "\\", 166 | "_", 167 | "-", 168 | ">", 169 | "<", 170 | "@", 171 | "`", 172 | ",", 173 | "?", 174 | "!", 175 | ] 176 | 177 | 178 | def normalize_word(token): 179 | _token = token 180 | for p in punct: 181 | if (p + " " in token or " " + p in token) or ( 182 | re.search(comma_strip, token) != None 183 | ): 184 | _token = _token.replace(p, "") 185 | else: 186 | _token = _token.replace(p, " ") 187 | token = period_strip.sub("", _token, re.UNICODE) 188 | 189 | _token = [] 190 | temp = token.lower().split() 191 | for word in temp: 192 | word = manual_map.setdefault(word, word) 193 | if word not in articles: 194 | _token.append(word) 195 | for i, word in enumerate(_token): 196 | if word in contractions: 197 | _token[i] = contractions[word] 198 | token = " ".join(_token) 199 | token = token.replace(",", "") 200 | return token 201 | 202 | 203 | def get_score(occurences): 204 | if occurences == 0: 205 | return 0.0 206 | elif occurences == 1: 207 | return 0.3 208 | elif occurences == 2: 209 | return 0.6 210 | elif occurences == 3: 211 | return 0.9 212 | else: 213 | return 1.0 214 | 215 | 216 | def path2rest(path, split, annotations, label2ans): 217 | iid = int(path.split("/")[-1].split("_")[-1][:-4]) 218 | 219 | with open(path, "rb") as fp: 220 | binary = fp.read() 221 | 222 | _annot = annotations[split][iid] 223 | _annot = list(_annot.items()) 224 | qids, qas = [a[0] for a in _annot], [a[1] for a in _annot] 225 | questions = [qa[0] for qa in qas] 226 | answers = [qa[1] for qa in qas] if "test" not in split else list(list()) 227 | answer_labels = ( 228 | 
[a["labels"] for a in answers] if "test" not in split else list(list()) 229 | ) 230 | answer_scores = ( 231 | [a["scores"] for a in answers] if "test" not in split else list(list()) 232 | ) 233 | answers = ( 234 | [[label2ans[l] for l in al] for al in answer_labels] 235 | if "test" not in split 236 | else list(list()) 237 | ) 238 | 239 | return [binary, questions, answers, answer_labels, answer_scores, iid, qids, split] 240 | 241 | 242 | def make_arrow(root, dataset_root): 243 | with open(f"{root}/v2_OpenEnded_mscoco_train2014_questions.json", "r") as fp: 244 | questions_train2014 = json.load(fp)["questions"] 245 | with open(f"{root}/v2_OpenEnded_mscoco_val2014_questions.json", "r") as fp: 246 | questions_val2014 = json.load(fp)["questions"] 247 | with open(f"{root}/v2_OpenEnded_mscoco_test2015_questions.json", "r") as fp: 248 | questions_test2015 = json.load(fp)["questions"] 249 | with open(f"{root}/v2_OpenEnded_mscoco_test-dev2015_questions.json", "r") as fp: 250 | questions_test_dev2015 = json.load(fp)["questions"] 251 | 252 | with open(f"{root}/v2_mscoco_train2014_annotations.json", "r") as fp: 253 | annotations_train2014 = json.load(fp)["annotations"] 254 | with open(f"{root}/v2_mscoco_val2014_annotations.json", "r") as fp: 255 | annotations_val2014 = json.load(fp)["annotations"] 256 | 257 | annotations = dict() 258 | 259 | for split, questions in zip( 260 | ["train", "val", "test", "test-dev"], 261 | [ 262 | questions_train2014, 263 | questions_val2014, 264 | questions_test2015, 265 | questions_test_dev2015, 266 | ], 267 | ): 268 | _annot = defaultdict(dict) 269 | for q in tqdm(questions): 270 | _annot[q["image_id"]][q["question_id"]] = [q["question"]] 271 | 272 | annotations[split] = _annot 273 | 274 | all_major_answers = list() 275 | 276 | for split, annots in zip( 277 | ["train", "val"], [annotations_train2014, annotations_val2014], 278 | ): 279 | _annot = annotations[split] 280 | for q in tqdm(annots): 281 | all_major_answers.append(q["multiple_choice_answer"]) 282 | 283 | all_major_answers = [normalize_word(word) for word in tqdm(all_major_answers)] 284 | counter = {k: v for k, v in Counter(all_major_answers).items() if v >= 9} 285 | ans2label = {k: i for i, k in enumerate(counter.keys())} 286 | label2ans = list(counter.keys()) 287 | 288 | with open(os.path.join(dataset_root, "answer2label.json"), 'w', encoding='utf8') as f: 289 | json.dump(ans2label, f) 290 | with open(os.path.join(dataset_root, "label2answer.json"), 'w', encoding='utf8') as f: 291 | json.dump(label2ans, f) 292 | 293 | for split, annots in zip( 294 | ["train", "val"], [annotations_train2014, annotations_val2014], 295 | ): 296 | _annot = annotations[split] 297 | for q in tqdm(annots): 298 | answers = q["answers"] 299 | answer_count = {} 300 | for answer in answers: 301 | answer_ = answer["answer"] 302 | answer_count[answer_] = answer_count.get(answer_, 0) + 1 303 | 304 | labels = [] 305 | scores = [] 306 | for answer in answer_count: 307 | if answer not in ans2label: 308 | continue 309 | labels.append(ans2label[answer]) 310 | score = get_score(answer_count[answer]) 311 | scores.append(score) 312 | 313 | _annot[q["image_id"]][q["question_id"]].append( 314 | {"labels": labels, "scores": scores,} 315 | ) 316 | 317 | for split in ["train", "val"]: 318 | filtered_annot = dict() 319 | for ik, iv in annotations[split].items(): 320 | new_q = dict() 321 | for qk, qv in iv.items(): 322 | if len(qv[1]["labels"]) != 0: 323 | new_q[qk] = qv 324 | if len(new_q) != 0: 325 | filtered_annot[ik] = new_q 326 | annotations[split] = 
filtered_annot 327 | 328 | for split in [ 329 | "train", 330 | "val", 331 | "test", 332 | "test-dev", 333 | ]: 334 | annot = annotations[split] 335 | split_name = { 336 | "train": "train2014", 337 | "val": "val2014", 338 | "test": "test2015", 339 | "test-dev": "test2015", 340 | }[split] 341 | paths = list(glob(f"{root}/{split_name}/*.jpg")) 342 | random.shuffle(paths) 343 | annot_paths = [ 344 | path 345 | for path in paths 346 | if int(path.split("/")[-1].split("_")[-1][:-4]) in annot 347 | ] 348 | 349 | if len(paths) == len(annot_paths): 350 | print("all images have caption annotations") 351 | else: 352 | print("not all images have caption annotations") 353 | print( 354 | len(paths), len(annot_paths), len(annot), 355 | ) 356 | 357 | bs = [ 358 | path2rest(path, split, annotations, label2ans) for path in tqdm(annot_paths) 359 | ] 360 | 361 | dataframe = pd.DataFrame( 362 | bs, 363 | columns=[ 364 | "image", 365 | "questions", 366 | "answers", 367 | "answer_labels", 368 | "answer_scores", 369 | "image_id", 370 | "question_id", 371 | "split", 372 | ], 373 | ) 374 | 375 | table = pa.Table.from_pandas(dataframe) 376 | 377 | os.makedirs(dataset_root, exist_ok=True) 378 | with pa.OSFile(f"{dataset_root}/vqav2_{split}.arrow", "wb") as sink: 379 | with pa.RecordBatchFileWriter(sink, table.schema) as writer: 380 | writer.write_table(table) 381 | 382 | table = pa.ipc.RecordBatchFileReader( 383 | pa.memory_map(f"{dataset_root}/vqav2_val.arrow", "r") 384 | ).read_all() 385 | 386 | pdtable = table.to_pandas() 387 | 388 | df1 = pdtable[:-1000] 389 | df2 = pdtable[-1000:] 390 | 391 | df1 = pa.Table.from_pandas(df1) 392 | df2 = pa.Table.from_pandas(df2) 393 | 394 | with pa.OSFile(f"{dataset_root}/vqav2_trainable_val.arrow", "wb") as sink: 395 | with pa.RecordBatchFileWriter(sink, df1.schema) as writer: 396 | writer.write_table(df1) 397 | 398 | with pa.OSFile(f"{dataset_root}/vqav2_rest_val.arrow", "wb") as sink: 399 | with pa.RecordBatchFileWriter(sink, df2.schema) as writer: 400 | writer.write_table(df2) 401 | 402 | 403 | if __name__ == "__main__": 404 | root = "./data/vqav2" # directory where store the raw dataset 405 | dataset_root = "./data/vqav2/vqav2_arrow" # directory where output the pyarrow format dataset 406 | make_arrow(root, dataset_root) -------------------------------------------------------------------------------- /examples/VQA/zero_to_fp32.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | # This script extracts fp32 consolidated weights from a zero 2 and 3 DeepSpeed checkpoints. It gets 4 | # copied into the top level checkpoint dir, so the user can easily do the conversion at any point in 5 | # the future. Once extracted, the weights don't require DeepSpeed and can be used in any 6 | # application. 7 | # 8 | # example: python zero_to_fp32.py . pytorch_model.bin 9 | 10 | import argparse 11 | import torch 12 | import glob 13 | import math 14 | import os 15 | import re 16 | from collections import OrderedDict 17 | 18 | # while this script doesn't use deepspeed to recover data, since the checkpoints are pickled with 19 | # DeepSpeed data structures it has to be available in the current python environment. 
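# Note: the usage example above omits the checkpoint tag, but in this copy of the script the
# tag (e.g. "global_step0") is a required third positional argument; see the argparse setup
# at the bottom of the file.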
20 | from deepspeed.utils import logger 21 | from deepspeed.checkpoint.constants import (DS_VERSION, 22 | OPTIMIZER_STATE_DICT, 23 | SINGLE_PARTITION_OF_FP32_GROUPS, 24 | FP32_FLAT_GROUPS, 25 | ZERO_STAGE, 26 | PARTITION_COUNT, 27 | PARAM_SHAPES, 28 | BUFFER_NAMES) 29 | 30 | debug = 0 31 | 32 | # load to cpu 33 | device = torch.device('cpu') 34 | 35 | 36 | def atoi(text): 37 | return int(text) if text.isdigit() else text 38 | 39 | 40 | def natural_keys(text): 41 | ''' 42 | alist.sort(key=natural_keys) sorts in human order 43 | http://nedbatchelder.com/blog/200712/human_sorting.html 44 | (See Toothy's implementation in the comments) 45 | ''' 46 | return [atoi(c) for c in re.split(r'(\d+)', text)] 47 | 48 | 49 | def get_model_state_file(checkpoint_dir, zero_stage): 50 | if not os.path.isdir(checkpoint_dir): 51 | raise FileNotFoundError(f"Directory '{checkpoint_dir}' doesn't exist") 52 | 53 | # there should be only one file 54 | if zero_stage == 2: 55 | file = os.path.join(checkpoint_dir, "mp_rank_00_model_states.pt") 56 | elif zero_stage == 3: 57 | file = os.path.join(checkpoint_dir, "zero_pp_rank_0_mp_rank_00_model_states.pt") 58 | 59 | if not os.path.exists(file): 60 | raise FileNotFoundError(f"can't find model states file at '{file}'") 61 | 62 | return file 63 | 64 | 65 | def get_optim_files(checkpoint_dir): 66 | # XXX: need to test that this simple glob rule works for multi-node setup too 67 | optim_files = sorted(glob.glob(os.path.join(checkpoint_dir, 68 | "*_optim_states.pt")), 69 | key=natural_keys) 70 | 71 | if len(optim_files) == 0: 72 | raise FileNotFoundError( 73 | f"can't find '*_optim_states.pt' files in directory '{checkpoint_dir}'") 74 | 75 | return optim_files 76 | 77 | 78 | def parse_model_state(file): 79 | state_dict = torch.load(file, map_location=device) 80 | 81 | if BUFFER_NAMES not in state_dict: 82 | raise ValueError(f"{file} is not a model state checkpoint") 83 | buffer_names = state_dict[BUFFER_NAMES] 84 | if debug: 85 | print("Found buffers:", buffer_names) 86 | 87 | # recover just the buffers while restoring them to fp32 if they were saved in fp16 88 | buffers = { 89 | k: v.float() 90 | for k, 91 | v in state_dict["module"].items() if k in buffer_names 92 | } 93 | param_shapes = state_dict[PARAM_SHAPES] 94 | 95 | ds_version = state_dict.get(DS_VERSION, None) 96 | 97 | return buffers, param_shapes, ds_version 98 | 99 | 100 | def parse_optim_states(files, ds_checkpoint_dir): 101 | 102 | total_files = len(files) 103 | state_dicts = [] 104 | for f in files: 105 | state_dicts.append(torch.load(f, map_location=device)) 106 | 107 | if not ZERO_STAGE in state_dicts[0][OPTIMIZER_STATE_DICT]: 108 | raise ValueError(f"{files[0]} is not a zero checkpoint") 109 | zero_stage = state_dicts[0][OPTIMIZER_STATE_DICT][ZERO_STAGE] 110 | world_size = state_dicts[0][OPTIMIZER_STATE_DICT][PARTITION_COUNT] 111 | 112 | # For ZeRO-2 each param group can have different partition_count as data parallelism for expert 113 | # parameters can be different from data parallelism for non-expert parameters. So we can just 114 | # use the max of the partition_count to get the dp world_size. 115 | 116 | if type(world_size) is list: 117 | world_size = max(world_size) 118 | 119 | if world_size != total_files: 120 | raise ValueError( 121 | f"Expected {world_size} of '*_optim_states.pt' under '{ds_checkpoint_dir}' but found {total_files} files. " 122 | "Possibly due to an overwrite of an old checkpoint, or a checkpoint didn't get saved by one or more processes." 
123 | ) 124 | 125 | # the groups are named differently in each stage 126 | if zero_stage == 2: 127 | fp32_groups_key = SINGLE_PARTITION_OF_FP32_GROUPS 128 | elif zero_stage == 3: 129 | fp32_groups_key = FP32_FLAT_GROUPS 130 | else: 131 | raise ValueError(f"unknown zero stage {zero_stage}") 132 | 133 | if zero_stage == 2: 134 | fp32_flat_groups = [ 135 | state_dicts[i][OPTIMIZER_STATE_DICT][fp32_groups_key] 136 | for i in range(len(state_dicts)) 137 | ] 138 | elif zero_stage == 3: 139 | # if there is more than one param group, there will be multiple flattened tensors - one 140 | # flattened tensor per group - for simplicity merge them into a single tensor 141 | # 142 | # XXX: could make the script more memory efficient for when there are multiple groups - it 143 | # will require matching the sub-lists of param_shapes for each param group flattened tensor 144 | 145 | fp32_flat_groups = [ 146 | torch.cat(state_dicts[i][OPTIMIZER_STATE_DICT][fp32_groups_key], 147 | 0) for i in range(len(state_dicts)) 148 | ] 149 | 150 | return zero_stage, world_size, fp32_flat_groups 151 | 152 | 153 | def _get_fp32_state_dict_from_zero_checkpoint(ds_checkpoint_dir): 154 | """ 155 | Returns fp32 state_dict reconstructed from ds checkpoint 156 | 157 | Args: 158 | - ``ds_checkpoint_dir``: path to the deepspeed checkpoint folder (where the optimizer files are) 159 | 160 | """ 161 | print(f"Processing zero checkpoint '{ds_checkpoint_dir}'") 162 | 163 | optim_files = get_optim_files(ds_checkpoint_dir) 164 | zero_stage, world_size, fp32_flat_groups = parse_optim_states(optim_files, ds_checkpoint_dir) 165 | print( 166 | f"Detected checkpoint of type zero stage {zero_stage}, world_size: {world_size}") 167 | 168 | model_file = get_model_state_file(ds_checkpoint_dir, zero_stage) 169 | buffers, param_shapes, ds_version = parse_model_state(model_file) 170 | print(f'Parsing checkpoint created by deepspeed=={ds_version}') 171 | 172 | if zero_stage == 2: 173 | return _get_fp32_state_dict_from_zero2_checkpoint(world_size, 174 | param_shapes, 175 | fp32_flat_groups, 176 | buffers) 177 | elif zero_stage == 3: 178 | return _get_fp32_state_dict_from_zero3_checkpoint(world_size, 179 | param_shapes, 180 | fp32_flat_groups, 181 | buffers) 182 | 183 | 184 | def _get_fp32_state_dict_from_zero2_checkpoint(world_size, 185 | param_shapes, 186 | fp32_flat_groups, 187 | buffers): 188 | 189 | # Reconstruction protocol: 190 | # 191 | # XXX: document this 192 | 193 | if debug: 194 | for i in range(world_size): 195 | for j in range(len(fp32_flat_groups[0])): 196 | print( 197 | f"{FP32_FLAT_GROUPS}[{i}][{j}].shape={fp32_flat_groups[i][j].shape}") 198 | 199 | # XXX: memory usage doubles here (zero2) 200 | num_param_groups = len(fp32_flat_groups[0]) 201 | merged_single_partition_of_fp32_groups = [] 202 | for i in range(num_param_groups): 203 | merged_partitions = [sd[i] for sd in fp32_flat_groups] 204 | full_single_fp32_vector = torch.cat(merged_partitions, 0) 205 | merged_single_partition_of_fp32_groups.append(full_single_fp32_vector) 206 | avail_numel = sum([ 207 | full_single_fp32_vector.numel() 208 | for full_single_fp32_vector in merged_single_partition_of_fp32_groups 209 | ]) 210 | 211 | if debug: 212 | wanted_params = sum([len(shapes) for shapes in param_shapes]) 213 | wanted_numel = sum( 214 | [sum(shape.numel() for shape in shapes.values()) for shapes in param_shapes]) 215 | # not asserting if there is a mismatch due to possible padding 216 | print(f"Have {avail_numel} numels to process.") 217 | print(f"Need {wanted_numel} numels in 
{wanted_params} params.") 218 | 219 | state_dict = OrderedDict() 220 | 221 | # buffers 222 | state_dict.update(buffers) 223 | if debug: 224 | print(f"added {len(buffers)} buffers") 225 | 226 | # params 227 | # XXX: for huge models that can't fit into the host's RAM we will have to recode this to support 228 | # out-of-core computing solution 229 | total_numel = 0 230 | total_params = 0 231 | for shapes, full_single_fp32_vector in zip(param_shapes, merged_single_partition_of_fp32_groups): 232 | offset = 0 233 | avail_numel = full_single_fp32_vector.numel() 234 | for name, shape in shapes.items(): 235 | 236 | unpartitioned_numel = shape.numel() 237 | total_numel += unpartitioned_numel 238 | total_params += 1 239 | 240 | if debug: 241 | print( 242 | f"{name} full shape: {shape} unpartitioned numel {unpartitioned_numel} " 243 | ) 244 | state_dict[name] = full_single_fp32_vector.narrow( 245 | 0, 246 | offset, 247 | unpartitioned_numel).view(shape) 248 | offset += unpartitioned_numel 249 | 250 | # Z2 started to align to 2*world_size to improve nccl performance. Therefore both offset and 251 | # avail_numel can differ by anywhere between 0..2*world_size. Due to two unrelated complex 252 | # paddings performed in the code it's almost impossible to predict the exact numbers w/o the 253 | # live optimizer object, so we are checking that the numbers are within the right range 254 | align_to = 2 * world_size 255 | 256 | def zero2_align(x): 257 | return align_to * math.ceil(x / align_to) 258 | 259 | if debug: 260 | print(f"original offset={offset}, avail_numel={avail_numel}") 261 | 262 | offset = zero2_align(offset) 263 | avail_numel = zero2_align(avail_numel) 264 | 265 | if debug: 266 | print(f"aligned offset={offset}, avail_numel={avail_numel}") 267 | 268 | # Sanity check 269 | if offset != avail_numel: 270 | raise ValueError( 271 | f"consumed {offset} numels out of {avail_numel} - something is wrong") 272 | 273 | print( 274 | f"Reconstructed fp32 state dict with {total_params} params {total_numel} elements" 275 | ) 276 | 277 | return state_dict 278 | 279 | 280 | def zero3_partitioned_param_info(unpartitioned_numel, world_size): 281 | remainder = unpartitioned_numel % world_size 282 | padding_numel = (world_size - remainder) if remainder else 0 283 | partitioned_numel = math.ceil(unpartitioned_numel / world_size) 284 | return partitioned_numel, padding_numel 285 | 286 | 287 | def _get_fp32_state_dict_from_zero3_checkpoint(world_size, 288 | param_shapes, 289 | fp32_flat_groups, 290 | buffers): 291 | 292 | # Reconstruction protocol: For zero3 we need to zip the partitions together at boundary of each 293 | # param, re-consolidating each param, while dealing with padding if any 294 | 295 | avail_numel = fp32_flat_groups[0].numel() * world_size 296 | # merge list of dicts, preserving order 297 | param_shapes = {k: v for d in param_shapes for k, v in d.items()} 298 | 299 | if debug: 300 | for i in range(world_size): 301 | print(f"{FP32_FLAT_GROUPS}[{i}].shape={fp32_flat_groups[i].shape}") 302 | 303 | wanted_params = len(param_shapes) 304 | wanted_numel = sum(shape.numel() for shape in param_shapes.values()) 305 | # not asserting if there is a mismatch due to possible padding 306 | print(f"Have {avail_numel} numels to process.") 307 | print(f"Need {wanted_numel} numels in {wanted_params} params.") 308 | 309 | state_dict = OrderedDict() 310 | 311 | # buffers 312 | state_dict.update(buffers) 313 | if debug: 314 | print(f"added {len(buffers)} buffers") 315 | 316 | # params 317 | # XXX: for huge models that 
can't fit into the host's RAM we will have to recode this to support 318 | # out-of-core computing solution 319 | offset = 0 320 | total_numel = 0 321 | total_params = 0 322 | for name, shape in param_shapes.items(): 323 | 324 | unpartitioned_numel = shape.numel() 325 | total_numel += unpartitioned_numel 326 | total_params += 1 327 | 328 | partitioned_numel, partitioned_padding_numel = zero3_partitioned_param_info(unpartitioned_numel, world_size) 329 | 330 | if debug: 331 | print( 332 | f"{total_params} {name} full shape: {shape} partition0 numel={partitioned_numel} partitioned_padding_numel={partitioned_padding_numel}" 333 | ) 334 | 335 | # XXX: memory usage doubles here 336 | state_dict[name] = torch.cat( 337 | tuple(fp32_flat_groups[i].narrow(0, 338 | offset, 339 | partitioned_numel) 340 | for i in range(world_size)), 341 | 0).narrow(0, 342 | 0, 343 | unpartitioned_numel).view(shape) 344 | offset += partitioned_numel 345 | 346 | offset *= world_size 347 | 348 | # Sanity check 349 | if offset != avail_numel: 350 | raise ValueError( 351 | f"consumed {offset} numels out of {avail_numel} - something is wrong") 352 | 353 | print( 354 | f"Reconstructed fp32 state dict with {total_params} params {total_numel} elements" 355 | ) 356 | 357 | return state_dict 358 | 359 | 360 | def get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir, tag=None): 361 | """ 362 | Convert ZeRO 2 or 3 checkpoint into a single fp32 consolidated state_dict that can be loaded with 363 | ``load_state_dict()`` and used for training without DeepSpeed or shared with others, for example 364 | via a model hub. 365 | 366 | Args: 367 | - ``checkpoint_dir``: path to the desired checkpoint folder 368 | - ``tag``: checkpoint tag used as a unique identifier for checkpoint. If not provided will attempt to load tag in 'latest' file. e.g., ``global_step14`` 369 | 370 | Returns: 371 | - pytorch ``state_dict`` 372 | 373 | Note: this approach may not work if your application doesn't have sufficient free CPU memory and 374 | you may need to use the offline approach using the ``zero_to_fp32.py`` script that is saved with 375 | the checkpoint. 376 | 377 | A typical usage might be :: 378 | 379 | from deepspeed.utils.zero_to_fp32 import get_fp32_state_dict_from_zero_checkpoint 380 | # do the training and checkpoint saving 381 | state_dict = get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir) # already on cpu 382 | model = model.cpu() # move to cpu 383 | model.load_state_dict(state_dict) 384 | # submit to model hub or save the model to share with others 385 | 386 | In this example the ``model`` will no longer be usable in the deepspeed context of the same 387 | application. i.e. you will need to re-initialize the deepspeed engine, since 388 | ``model.load_state_dict(state_dict)`` will remove all the deepspeed magic from it. 389 | 390 | If you want it all done for you, use ``load_state_dict_from_zero_checkpoint`` instead. 
391 | 392 | """ 393 | if tag is None: 394 | latest_path = os.path.join(checkpoint_dir, 'latest') 395 | if os.path.isfile(latest_path): 396 | with open(latest_path, 'r') as fd: 397 | tag = fd.read().strip() 398 | else: 399 | raise ValueError(f"Unable to find 'latest' file at {latest_path}") 400 | 401 | ds_checkpoint_dir = os.path.join(checkpoint_dir, tag) 402 | 403 | if not os.path.isdir(ds_checkpoint_dir): 404 | raise FileNotFoundError(f"Directory '{ds_checkpoint_dir}' doesn't exist") 405 | 406 | return _get_fp32_state_dict_from_zero_checkpoint(ds_checkpoint_dir) 407 | 408 | 409 | def convert_zero_checkpoint_to_fp32_state_dict(checkpoint_dir, output_file, tag=None): 410 | """ 411 | Convert ZeRO 2 or 3 checkpoint into a single fp32 consolidated ``state_dict`` file that can be 412 | loaded with ``torch.load(file)`` + ``load_state_dict()`` and used for training without DeepSpeed. 413 | 414 | Args: 415 | - ``checkpoint_dir``: path to the desired checkpoint folder. (one that contains the tag-folder, like ``global_step14``) 416 | - ``output_file``: path to the pytorch fp32 state_dict output file (e.g. path/pytorch_model.bin) 417 | - ``tag``: checkpoint tag used as a unique identifier for checkpoint. If not provided will attempt to load tag in the file named ``latest`` in the checkpoint folder, e.g., ``global_step14`` 418 | """ 419 | 420 | state_dict = get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir, tag) 421 | print(f"Saving fp32 state dict to {output_file}") 422 | torch.save(state_dict, output_file) 423 | 424 | 425 | def load_state_dict_from_zero_checkpoint(model, checkpoint_dir, tag=None): 426 | """ 427 | 1. Put the provided model to cpu 428 | 2. Convert ZeRO 2 or 3 checkpoint into a single fp32 consolidated ``state_dict`` 429 | 3. Load it into the provided model 430 | 431 | Args: 432 | - ``model``: the model object to update 433 | - ``checkpoint_dir``: path to the desired checkpoint folder. (one that contains the tag-folder, like ``global_step14``) 434 | - ``tag``: checkpoint tag used as a unique identifier for checkpoint. If not provided will attempt to load tag in the file named ``latest`` in the checkpoint folder, e.g., ``global_step14`` 435 | 436 | Returns: 437 | - ``model`: modified model 438 | 439 | Make sure you have plenty of CPU memory available before you call this function. If you don't 440 | have enough use the ``zero_to_fp32.py`` utility to do the conversion. You will find it 441 | conveniently placed for you in the checkpoint folder. 442 | 443 | A typical usage might be :: 444 | 445 | from deepspeed.utils.zero_to_fp32 import load_state_dict_from_zero_checkpoint 446 | model = load_state_dict_from_zero_checkpoint(trainer.model, checkpoint_dir) 447 | # submit to model hub or save the model to share with others 448 | 449 | Note, that once this was run, the ``model`` will no longer be usable in the deepspeed context 450 | of the same application. i.e. you will need to re-initialize the deepspeed engine, since 451 | ``model.load_state_dict(state_dict)`` will remove all the deepspeed magic from it. 
452 | 453 | """ 454 | logger.info(f"Extracting fp32 weights") 455 | state_dict = get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir, tag) 456 | 457 | logger.info(f"Overwriting model with fp32 weights") 458 | model = model.cpu() 459 | model.load_state_dict(state_dict, strict=False) 460 | 461 | return model 462 | 463 | 464 | if __name__ == "__main__": 465 | 466 | parser = argparse.ArgumentParser() 467 | parser.add_argument( 468 | "checkpoint_dir", 469 | type=str, 470 | help="path to the desired checkpoint folder, e.g., path/checkpoint-12") 471 | parser.add_argument( 472 | "output_file", 473 | type=str, 474 | help= 475 | "path to the pytorch fp32 state_dict output file (e.g. path/checkpoint-12/pytorch_model.bin)" 476 | ) 477 | parser.add_argument( 478 | "tag", 479 | type=str, 480 | help= 481 | "checkpoint tag used as a unique identifier for checkpoint (e.g. global_step14)" 482 | ) 483 | parser.add_argument("-d", "--debug", action='store_true', help="enable debug") 484 | args = parser.parse_args() 485 | 486 | debug = args.debug 487 | 488 | print(args.checkpoint_dir) 489 | print(args.output_file) 490 | print(args.tag) 491 | 492 | convert_zero_checkpoint_to_fp32_state_dict(args.checkpoint_dir, args.output_file, args.tag) 493 | -------------------------------------------------------------------------------- /models/VLE/__init__.py: -------------------------------------------------------------------------------- 1 | from .modeling_vle import ( 2 | VLEModel, 3 | VLEForVQA, 4 | VLEForITM, 5 | VLEForMLM, 6 | VLEForPBC, 7 | VLEForVCRQ2A, 8 | VLEForVCRQA2R 9 | ) 10 | 11 | from .configuration_vle import VLEConfig 12 | from .processing_vle import VLEProcessor 13 | from .pipeline_vle import VLEForVQAPipeline, VLEForITMPipeline, VLEForPBCPipeline, VLEForVCRQA2RPipeline, VLEForVCRQ2APipeline 14 | -------------------------------------------------------------------------------- /models/VLE/configuration_vle.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright The HuggingFace Inc. team. All rights reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
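For reference, the ``zero_to_fp32.py`` utility shown above can be invoked from the command line as ``python zero_to_fp32.py <checkpoint_dir> <output_file> <tag>``, or driven programmatically. Below is a minimal sketch of the programmatic route; it assumes the script is importable from the working directory, and the checkpoint folder ``output/checkpoint-12`` and tag ``global_step14`` are placeholders, not paths from this repository.

```python
# Minimal sketch, not part of the repository: assumes zero_to_fp32.py is on the import path,
# that "output/checkpoint-12" is a placeholder DeepSpeed checkpoint folder, and that
# "global_step14" is the placeholder tag sub-folder inside it.
from zero_to_fp32 import (
    convert_zero_checkpoint_to_fp32_state_dict,
    get_fp32_state_dict_from_zero_checkpoint,
)

checkpoint_dir = "output/checkpoint-12"

# Write a consolidated fp32 state_dict to disk ...
convert_zero_checkpoint_to_fp32_state_dict(checkpoint_dir, "output/pytorch_model.bin", tag="global_step14")

# ... or keep the consolidated fp32 weights in memory and load them into a model yourself
# (make sure enough free CPU RAM is available, as noted in the docstrings above).
state_dict = get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir, tag="global_step14")
```

If ``tag`` is omitted, both functions fall back to the tag recorded in the ``latest`` file inside the checkpoint folder, as described in the docstrings above.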
15 | """ VLE model configuration""" 16 | 17 | import copy 18 | 19 | from transformers.configuration_utils import PretrainedConfig 20 | from transformers.utils import logging 21 | from transformers.models.auto.configuration_auto import AutoConfig 22 | from transformers.models.clip.configuration_clip import CLIPVisionConfig 23 | from typing import Union, Dict 24 | 25 | logger = logging.get_logger(__name__) 26 | 27 | 28 | class VLEConfig(PretrainedConfig): 29 | 30 | model_type = "vle" 31 | is_composition = True 32 | 33 | def __init__( 34 | self, 35 | text_config: Union[PretrainedConfig, Dict], 36 | vision_config: Union[PretrainedConfig, Dict], 37 | num_token_types=2, 38 | hidden_size=768, 39 | num_hidden_layers=6, 40 | num_attention_heads=12, 41 | intermediate_size=3072, 42 | hidden_act="gelu", 43 | hidden_dropout_prob=0.1, 44 | attention_probs_dropout_prob=0.1, 45 | initializer_range=0.02, 46 | layer_norm_eps=1e-12, 47 | classifier_dropout=None, 48 | **kwargs): 49 | super().__init__(**kwargs) 50 | 51 | if not isinstance(text_config,PretrainedConfig): 52 | text_model_type = text_config.pop('model_type') 53 | text_config = AutoConfig.for_model(text_model_type, **text_config) 54 | self.text_config = text_config 55 | 56 | if not isinstance(vision_config, PretrainedConfig): 57 | vision_model_type = vision_config.pop('model_type') 58 | if vision_model_type == "clip": 59 | vision_config = AutoConfig.for_model(vision_model_type, **vision_config).vision_config 60 | elif vision_model_type == "clip_vision_model": 61 | vision_config = CLIPVisionConfig(**vision_config) 62 | else: 63 | vision_config = AutoConfig.for_model(vision_model_type, **vision_config) 64 | self.vision_config = vision_config 65 | else: 66 | vision_model_type = vision_config.model_type 67 | if vision_model_type== "clip": 68 | vision_config = vision_config.vision_config 69 | self.vision_config = vision_config 70 | 71 | 72 | 73 | # co-attention 74 | self.num_token_types=num_token_types 75 | self.hidden_size=hidden_size 76 | self.num_hidden_layers=num_hidden_layers 77 | self.num_attention_heads=num_attention_heads 78 | self.intermediate_size=intermediate_size 79 | self.hidden_act=hidden_act 80 | self.hidden_dropout_prob=hidden_dropout_prob 81 | self.attention_probs_dropout_prob=attention_probs_dropout_prob 82 | self.initializer_range=initializer_range 83 | self.layer_norm_eps=layer_norm_eps 84 | self.classifier_dropout=classifier_dropout 85 | 86 | 87 | def to_dict(self): 88 | """ 89 | Serializes this instance to a Python dictionary. Override the default [`~PretrainedConfig.to_dict`]. 90 | 91 | Returns: 92 | `Dict[str, any]`: Dictionary of all the attributes that make up this configuration instance, 93 | """ 94 | output = copy.deepcopy(self.__dict__) 95 | output["vision_config"] = self.vision_config.to_dict() 96 | output["text_config"] = self.text_config.to_dict() 97 | output["model_type"] = self.__class__.model_type 98 | return output 99 | -------------------------------------------------------------------------------- /models/VLE/modeling_vle.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2021 The HuggingFace Inc. team. All rights reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 
6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """ PyTorch VLE model.""" 16 | 17 | 18 | from typing import Optional, Tuple, Union 19 | 20 | import torch 21 | from torch import nn 22 | import torch.nn.functional as F 23 | from torch.nn.utils.rnn import pad_sequence 24 | 25 | from transformers.modeling_utils import PreTrainedModel 26 | from transformers.utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings, ModelOutput 27 | from transformers.models.auto.configuration_auto import AutoConfig 28 | from transformers.models.auto.modeling_auto import AutoModel 29 | 30 | from transformers.models.bert.modeling_bert import BertAttention, BertIntermediate, BertOutput, apply_chunking_to_forward 31 | from transformers.models.clip.modeling_clip import CLIPOutput, CLIPVisionConfig, CLIPVisionModel 32 | from transformers.models.deberta_v2.modeling_deberta_v2 import DebertaV2OnlyMLMHead 33 | from .configuration_vle import VLEConfig 34 | from dataclasses import dataclass 35 | 36 | logger = logging.get_logger(__name__) 37 | 38 | _CONFIG_FOR_DOC = "VLEConfig" 39 | 40 | 41 | @dataclass 42 | class VLEModelOutput(ModelOutput): 43 | 44 | pooler_output: torch.FloatTensor = None 45 | text_embeds: torch.FloatTensor = None 46 | image_embeds: torch.FloatTensor = None 47 | 48 | 49 | @dataclass 50 | class VLEForITMOutput(ModelOutput): 51 | 52 | loss: torch.FloatTensor = None 53 | logits: torch.FloatTensor = None 54 | 55 | @dataclass 56 | class VLEForPBCOutput(ModelOutput): 57 | 58 | loss: torch.FloatTensor = None 59 | logits: torch.FloatTensor = None 60 | 61 | @dataclass 62 | class VLEForMLMOutput(ModelOutput): 63 | 64 | loss: torch.FloatTensor = None 65 | logits: torch.FloatTensor = None 66 | 67 | @dataclass 68 | class VLEForVQAOutput(ModelOutput): 69 | 70 | loss : torch.FloatTensor = None 71 | logits: torch.FloatTensor = None 72 | 73 | @dataclass 74 | class VLEForVCRQ2AOutput(ModelOutput): 75 | 76 | loss : torch.FloatTensor = None 77 | logits: torch.FloatTensor = None 78 | 79 | @dataclass 80 | class VLEForVCRQA2ROutput(ModelOutput): 81 | 82 | loss : torch.FloatTensor = None 83 | logits: torch.FloatTensor = None 84 | 85 | class ITMHead(nn.Module): 86 | def __init__(self, hidden_size): 87 | super().__init__() 88 | self.fc = nn.Linear(hidden_size, 2) 89 | 90 | def forward(self, x): 91 | x = self.fc(x) 92 | return x 93 | 94 | 95 | def extend_position_embedding(state_dict, patch_size, after): 96 | """ 97 | modify state_dict in-place for longer position embeddings 98 | """ 99 | keys = {} 100 | for k,v in state_dict.items(): 101 | if k.endswith('vision_model.embeddings.position_embedding.weight'): 102 | assert k not in keys 103 | keys['pe'] = (k,v) 104 | if k.endswith('vision_model.embeddings.position_ids'): 105 | assert k not in keys 106 | keys['pi'] = (k,v) 107 | 108 | pe_weight = keys['pe'][1] 109 | position_length_before = pe_weight.shape[0] 110 | embed_dim = pe_weight.shape[1] 111 | grid_before = int((position_length_before - 1)**(1/2)) 112 | position_length_after = (after // patch_size) ** 2 + 1 113 | grid_after = int((position_length_after - 1)**(1/2)) 114 | 115 | 
new_pe_weight = pe_weight[1:].reshape((grid_before,grid_before,-1)) 116 | new_pe_weight = torch.nn.functional.interpolate( 117 | new_pe_weight.permute(2,0,1).unsqueeze(0), 118 | size = (grid_after,grid_after), mode = 'bicubic') 119 | new_pe_weight = new_pe_weight.squeeze(0).permute(1,2,0).reshape(grid_after*grid_after, -1) 120 | new_pe_weight = torch.cat((pe_weight[0:1],new_pe_weight), dim=0) 121 | assert new_pe_weight.shape == (grid_after*grid_after + 1, embed_dim) 122 | 123 | state_dict[keys['pe'][0]] = new_pe_weight 124 | state_dict[keys['pi'][0]] = torch.arange(grid_after*grid_after + 1).unsqueeze(0) 125 | return state_dict 126 | 127 | 128 | class Pooler(nn.Module): 129 | def __init__(self, hidden_size): 130 | super().__init__() 131 | self.dense = nn.Linear(hidden_size, hidden_size) 132 | self.activation = nn.Tanh() 133 | 134 | def forward(self, hidden_states): 135 | first_token_tensor = hidden_states[:, 0] 136 | pooled_output = self.dense(first_token_tensor) 137 | pooled_output = self.activation(pooled_output) 138 | return pooled_output 139 | 140 | 141 | class BertCrossLayer(nn.Module): 142 | def __init__(self, config): 143 | super().__init__() 144 | self.chunk_size_feed_forward = config.chunk_size_feed_forward 145 | self.seq_len_dim = 1 146 | self.attention = BertAttention(config) 147 | self.is_decoder = config.is_decoder 148 | self.add_cross_attention = config.add_cross_attention 149 | self.crossattention = BertAttention(config) 150 | self.intermediate = BertIntermediate(config) 151 | self.output = BertOutput(config) 152 | 153 | def forward( 154 | self, 155 | hidden_states, 156 | encoder_hidden_states, 157 | attention_mask=None, 158 | encoder_attention_mask=None, 159 | output_attentions=False, 160 | ): 161 | # decoder uni-directional self-attention cached key/values tuple is at positions 1,2 162 | self_attn_past_key_value = None #past_key_value[:2] if past_key_value is not None else None 163 | self_attention_outputs = self.attention( 164 | hidden_states, 165 | attention_mask, 166 | head_mask=None, 167 | output_attentions=output_attentions, 168 | past_key_value=None, 169 | ) 170 | attention_output = self_attention_outputs[0] 171 | 172 | # if decoder, the last output is tuple of self-attn cache 173 | outputs = self_attention_outputs[1:] # add self attentions if we output attention weights 174 | 175 | cross_attn_present_key_value = None 176 | cross_attention_outputs = self.crossattention( 177 | attention_output, 178 | attention_mask, 179 | None, 180 | encoder_hidden_states, 181 | encoder_attention_mask, 182 | None, 183 | output_attentions, 184 | ) 185 | attention_output = cross_attention_outputs[0] 186 | outputs = outputs + cross_attention_outputs[1:] # add cross attentions if we output attention weights 187 | 188 | layer_output = apply_chunking_to_forward( 189 | self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output 190 | ) 191 | outputs = (layer_output,) + outputs 192 | 193 | return outputs 194 | 195 | def feed_forward_chunk(self, attention_output): 196 | intermediate_output = self.intermediate(attention_output) 197 | layer_output = self.output(intermediate_output, attention_output) 198 | return layer_output 199 | 200 | 201 | class VLEPreTrainedModel(PreTrainedModel): 202 | """ 203 | An abstract class to handle weights initialization. 
204 | """ 205 | 206 | config_class = VLEConfig 207 | base_model_prefix = "vle" 208 | supports_gradient_checkpointing = False 209 | _keys_to_ignore_on_load_missing = [r"position_ids"] 210 | 211 | def _init_weights(self, module): 212 | """Initialize the weights""" 213 | if isinstance(module, nn.Linear): 214 | module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) 215 | if module.bias is not None: 216 | module.bias.data.zero_() 217 | elif isinstance(module, nn.Embedding): 218 | module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) 219 | if module.padding_idx is not None: 220 | module.weight.data[module.padding_idx].zero_() 221 | elif isinstance(module, nn.LayerNorm): 222 | module.bias.data.zero_() 223 | module.weight.data.fill_(1.0) 224 | # no supported 225 | # def _set_gradient_checkpointing(self, module, value=False): 226 | # if isinstance(module, BertEncoder): 227 | # module.gradient_checkpointing = value 228 | 229 | class VLEModel(VLEPreTrainedModel): 230 | def __init__( 231 | self, 232 | config: Optional[VLEConfig] = None, 233 | vision_model: Optional[PreTrainedModel] = None, 234 | text_model: Optional[PreTrainedModel] = None, 235 | ): 236 | 237 | if config is None and (vision_model is None or text_model is None): 238 | raise ValueError("Either a configuration or an vision and a text model has to be provided") 239 | 240 | if config is None: 241 | config = VLEConfig(text_config=text_model.config, vision_config=vision_model.config) 242 | else: 243 | if not isinstance(config, self.config_class): 244 | raise ValueError(f"config: {config} has to be of type {self.config_class}") 245 | 246 | # initialize with config 247 | super().__init__(config) 248 | 249 | if vision_model is None: 250 | if isinstance(config.vision_config, CLIPVisionConfig): 251 | vision_model = CLIPVisionModel(config.vision_config) 252 | else: 253 | vision_model = AutoModel.from_config(config.vision_config) 254 | 255 | if text_model is None: 256 | text_model = AutoModel.from_config(config.text_config) 257 | 258 | self.vision_model = vision_model 259 | self.text_model = text_model 260 | 261 | # make sure that the individual model's config refers to the shared config 262 | # so that the updates to the config will be synced 263 | self.vision_model.config = self.config.vision_config 264 | self.text_model.config = self.config.text_config 265 | 266 | self.vision_embed_dim = config.vision_config.hidden_size 267 | self.text_embed_dim = config.text_config.hidden_size 268 | self.coattention_dim = config.hidden_size 269 | 270 | # add projection layers 271 | self.text_projection_layer = nn.Linear(self.text_embed_dim, self.coattention_dim) 272 | self.image_projection_layer = nn.Linear(self.vision_embed_dim, self.coattention_dim) 273 | 274 | #self.logit_scale = nn.Parameter(torch.ones([]) * self.config.logit_scale_init_value) 275 | self.token_type_embeddings = nn.Embedding(config.num_token_types, config.hidden_size) 276 | 277 | self.cross_modal_image_layers = nn.ModuleList([BertCrossLayer(config) for _ in range(config.num_hidden_layers)]) 278 | self.cross_modal_text_layers = nn.ModuleList([BertCrossLayer(config) for _ in range(config.num_hidden_layers)]) 279 | self.cross_modal_image_pooler = Pooler(config.hidden_size) 280 | self.cross_modal_text_pooler = Pooler(config.hidden_size) 281 | 282 | # Initialize weights and apply final processing 283 | self.token_type_embeddings.apply(self._init_weights) 284 | self.cross_modal_image_layers.apply(self._init_weights) 285 | 
self.cross_modal_text_layers.apply(self._init_weights) 286 | self.cross_modal_image_pooler.apply(self._init_weights) 287 | self.cross_modal_text_pooler.apply(self._init_weights) 288 | if hasattr(self,"text_projection_layer"): 289 | self.text_projection_layer.apply(self._init_weights) 290 | if hasattr(self,"image_projection_layer"): 291 | self.image_projection_layer.apply(self._init_weights) 292 | 293 | 294 | def forward( 295 | self, 296 | input_ids: Optional[torch.LongTensor] = None, 297 | pixel_values: Optional[torch.FloatTensor] = None, 298 | attention_mask: Optional[torch.Tensor] = None, 299 | position_ids: Optional[torch.LongTensor] = None, 300 | token_type_ids: Optional[torch.LongTensor] = None, 301 | patch_ids = None, 302 | extend_token_type_ids = None, 303 | return_loss: Optional[bool] = None, 304 | return_dict: Optional[bool] = None, 305 | ) -> Union[Tuple[torch.Tensor], VLEModelOutput]: 306 | 307 | return_dict = return_dict if return_dict is not None else self.config.return_dict 308 | 309 | vision_outputs = self.vision_model( 310 | pixel_values=pixel_values, 311 | return_dict=return_dict, 312 | ) 313 | 314 | text_outputs = self.text_model( 315 | input_ids=input_ids, 316 | attention_mask=attention_mask, 317 | token_type_ids=token_type_ids, 318 | position_ids=position_ids, 319 | return_dict=return_dict, 320 | ) 321 | 322 | image_embeds = self.vision_model.vision_model.post_layernorm(vision_outputs[0]) # last_hidden_state 323 | image_embeds = self.image_projection_layer(image_embeds) 324 | 325 | text_embeds = text_outputs[0] # last_hidden_state 326 | text_embeds = self.text_projection_layer(text_embeds) 327 | 328 | if patch_ids is not None: 329 | # add box image embeddings (mean) 330 | # image_embeds : batch_size * (num_patch+1) * dims 331 | image_embeds_size_1 = image_embeds.size(1) 332 | new_image_embeds = [] 333 | for item_image_embeds, item_patch_ids in zip(image_embeds, patch_ids): 334 | add_item_image_embeds = [] 335 | for i_, box_patch_ids in enumerate(item_patch_ids): 336 | # skip cls embedding 337 | box_image_embeds = item_image_embeds[torch.as_tensor(box_patch_ids) + 1] 338 | box_image_embeds = torch.mean(box_image_embeds, dim=0, keepdim=True) 339 | add_item_image_embeds.append(box_image_embeds) 340 | new_image_embeds.append(torch.cat([item_image_embeds] + add_item_image_embeds)) 341 | image_embeds = pad_sequence(new_image_embeds, batch_first=True) 342 | 343 | len_of_ones = torch.as_tensor([len(box_p_ids) for box_p_ids in patch_ids], dtype=torch.int64, device=image_embeds.device) + image_embeds_size_1 344 | image_mask_ones = [torch.ones(l, dtype=torch.long, device=image_embeds.device) for l in len_of_ones] 345 | image_mask_zeros = [torch.zeros(image_embeds.size(1) - l, dtype=torch.long, device=image_embeds.device) for l in len_of_ones] 346 | image_masks = torch.cat([torch.cat([one, zero]).unsqueeze(0) for one, zero in zip(image_mask_ones, image_mask_zeros)], dim=0) 347 | 348 | text_token_type_ids = pad_sequence(list(map(lambda x: torch.as_tensor(x, device=image_embeds.device, dtype=torch.long),extend_token_type_ids[1])),batch_first=True) 349 | text_token_type_ids = torch.cat([text_token_type_ids, torch.zeros(text_embeds.size(0), text_embeds.size(1) - text_token_type_ids.size(1), device=image_embeds.device, dtype=torch.long)], dim=1) 350 | image_added_token_type_ids = pad_sequence(list(map(lambda x: torch.as_tensor(x, device=image_embeds.device, dtype=torch.long),extend_token_type_ids[0])),batch_first=True) 351 | image_token_type_ids = 
torch.cat([torch.ones(image_embeds.size(0), image_embeds_size_1, dtype=torch.long, device=image_embeds.device), image_added_token_type_ids], dim=1) 352 | else: 353 | image_masks = torch.ones((image_embeds.size(0), image_embeds.size(1)), dtype=torch.long, device=image_embeds.device) 354 | extend_image_masks = self.text_model.get_extended_attention_mask(image_masks, image_masks.size()) 355 | extend_text_masks = self.text_model.get_extended_attention_mask(attention_mask, attention_mask.size()) 356 | 357 | if patch_ids is not None and extend_token_type_ids is not None: 358 | text_embeds = text_embeds + self.token_type_embeddings(text_token_type_ids) 359 | image_embeds = image_embeds + self.token_type_embeddings(image_token_type_ids) 360 | else: 361 | image_embeds = image_embeds + self.token_type_embeddings(torch.full_like(image_masks, 1)) 362 | text_embeds = text_embeds + self.token_type_embeddings(torch.zeros_like(attention_mask)) 363 | 364 | x, y = text_embeds, image_embeds 365 | for text_layer, image_layer in zip(self.cross_modal_text_layers, self.cross_modal_image_layers): 366 | x1 = text_layer(x, y, extend_text_masks, extend_image_masks) 367 | y1 = image_layer(y, x, extend_image_masks, extend_text_masks) 368 | x, y = x1[0], y1[0] 369 | 370 | text_embeds, image_embeds = x, y 371 | text_pooler_output = self.cross_modal_text_pooler(x) 372 | image_pooler_output = self.cross_modal_image_pooler(y) 373 | pooler_output = torch.cat([text_pooler_output, image_pooler_output], dim=-1) 374 | 375 | if not return_dict: 376 | output = (pooler_output, text_embeds, image_embeds) 377 | return output 378 | return VLEModelOutput( 379 | pooler_output = pooler_output, 380 | text_embeds = text_embeds, 381 | image_embeds = image_embeds 382 | ) 383 | 384 | 385 | @classmethod 386 | def from_pretrained(cls, *args, **kwargs): 387 | # At the moment fast initialization is not supported 388 | # for composite models 389 | kwargs["_fast_init"] = False 390 | return super().from_pretrained(*args, **kwargs) 391 | 392 | @classmethod 393 | def from_vision_text_pretrained( 394 | cls, 395 | vision_model_name_or_path: str = None, 396 | text_model_name_or_path: str = None, 397 | *model_args, 398 | **kwargs, 399 | ) -> PreTrainedModel: 400 | 401 | kwargs_vision = { 402 | argument[len("vision_") :]: value for argument, value in kwargs.items() if argument.startswith("vision_") 403 | } 404 | 405 | kwargs_text = { 406 | argument[len("text_") :]: value for argument, value in kwargs.items() if argument.startswith("text_") 407 | } 408 | 409 | # remove vision, text kwargs from kwargs 410 | for key in kwargs_vision.keys(): 411 | del kwargs["vision_" + key] 412 | for key in kwargs_text.keys(): 413 | del kwargs["text_" + key] 414 | 415 | # Load and initialize the vision and text model 416 | vision_model = kwargs_vision.pop("model", None) 417 | if vision_model is None: 418 | if vision_model_name_or_path is None: 419 | raise ValueError( 420 | "If `vision_model` is not defined as an argument, a `vision_model_name_or_path` has to be defined" 421 | ) 422 | 423 | if "config" not in kwargs_vision: 424 | vision_config = AutoConfig.from_pretrained(vision_model_name_or_path) 425 | 426 | if vision_config.model_type == "clip": 427 | kwargs_vision["config"] = vision_config.vision_config 428 | vision_model = CLIPVisionModel.from_pretrained(vision_model_name_or_path, *model_args, **kwargs_vision) 429 | else: 430 | kwargs_vision["config"] = vision_config 431 | vision_model = AutoModel.from_pretrained(vision_model_name_or_path, *model_args, **kwargs_vision) 
432 | 433 | text_model = kwargs_text.pop("model", None) 434 | if text_model is None: 435 | if text_model_name_or_path is None: 436 | raise ValueError( 437 | "If `text_model` is not defined as an argument, a `text_model_name_or_path` has to be defined" 438 | ) 439 | 440 | if "config" not in kwargs_text: 441 | text_config = AutoConfig.from_pretrained(text_model_name_or_path) 442 | kwargs_text["config"] = text_config 443 | 444 | text_model = AutoModel.from_pretrained(text_model_name_or_path, *model_args, **kwargs_text) 445 | 446 | # instantiate config with corresponding kwargs 447 | config = VLEConfig(text_config=text_model.config, vision_config=vision_model.config, **kwargs) 448 | 449 | # init model 450 | model = cls(config=config, vision_model=vision_model, text_model=text_model) 451 | 452 | # the projection layers are always newly initialized when loading the model 453 | # using pre-trained vision and text model. 454 | logger.warning( 455 | "The coattention layers and projection layers are newly initialized. You should probably TRAIN this model on a down-stream task to be" 456 | " able to use it for predictions and inference." 457 | ) 458 | return model 459 | 460 | 461 | def get_text_features( 462 | self, 463 | input_ids=None, 464 | attention_mask=None, 465 | position_ids=None, 466 | token_type_ids=None, 467 | output_attentions=None, 468 | output_hidden_states=None, 469 | return_dict=None, 470 | ): 471 | text_outputs = self.text_model( 472 | input_ids=input_ids, 473 | attention_mask=attention_mask, 474 | position_ids=position_ids, 475 | token_type_ids=token_type_ids, 476 | #output_attentions=output_attentions, 477 | #output_hidden_states=output_hidden_states, 478 | return_dict=return_dict, 479 | ) 480 | return text_outputs[0] # last_hidden_state 481 | 482 | def get_image_features( 483 | self, 484 | pixel_values=None, 485 | output_attentions=None, 486 | output_hidden_states=None, 487 | return_dict=None, 488 | ): 489 | r""" 490 | Returns: 491 | image_features (`torch.FloatTensor` of shape `(batch_size, output_dim`): The image embeddings obtained by 492 | applying the projection layer to the pooled output of [`CLIPVisionModel`]. 
493 | 494 | Examples: 495 | 496 | ```python 497 | >>> from PIL import Image 498 | >>> import requests 499 | >>> from transformers import VLEModel, AutoImageProcessor 500 | 501 | >>> model = VLEModel.from_pretrained("clip-italian/clip-italian") 502 | >>> image_processor = AutoImageProcessor.from_pretrained("google/vit-base-patch16-224") 503 | 504 | >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg" 505 | >>> image = Image.open(requests.get(url, stream=True).raw) 506 | 507 | >>> inputs = image_processor(images=image, return_tensors="pt") 508 | 509 | >>> image_features = model.get_image_features(**inputs) 510 | ```""" 511 | vision_outputs = self.vision_model( 512 | pixel_values=pixel_values, 513 | #output_attentions=output_attentions, 514 | #output_hidden_states=output_hidden_states, 515 | return_dict=return_dict, 516 | ) 517 | last_hidden_state = self.vision_model.vision_model.post_layernorm(vision_outputs[0]) 518 | return last_hidden_state 519 | def get_input_embeddings(self): 520 | return self.text_model.embeddings.word_embeddings 521 | 522 | def set_input_embeddings(self, new_embeddings): 523 | self.text_model.embeddings.word_embeddings = new_embeddings 524 | 525 | class VLEForVQA(VLEPreTrainedModel): 526 | def __init__( 527 | self, 528 | config: Optional[VLEConfig] = None, 529 | vision_model: Optional[PreTrainedModel] = None, 530 | text_model: Optional[PreTrainedModel] = None, 531 | ): 532 | super().__init__(config) 533 | self.vle = VLEModel(config, vision_model, text_model) 534 | 535 | hidden_size = config.hidden_size 536 | self.num_vqa_labels = len(self.config.id2label) 537 | self.vqa_classifier = nn.Sequential( 538 | nn.Linear(hidden_size * 2, hidden_size * 2), 539 | nn.LayerNorm(hidden_size * 2), 540 | nn.GELU(), 541 | nn.Linear(hidden_size * 2, self.num_vqa_labels), 542 | ) 543 | self.vqa_classifier.apply(self._init_weights) 544 | 545 | def forward(self, 546 | input_ids: Optional[torch.LongTensor], 547 | pixel_values: Optional[torch.FloatTensor], 548 | attention_mask: Optional[torch.Tensor] = None, 549 | position_ids: Optional[torch.LongTensor] = None, 550 | token_type_ids: Optional[torch.LongTensor] = None, 551 | patch_ids = None, 552 | vqa_labels = None, 553 | vqa_scores = None, 554 | return_loss: Optional[bool] = None, 555 | return_dict: Optional[bool] = None, 556 | ) -> Union[Tuple[torch.Tensor], VLEForVQAOutput]: 557 | 558 | return_dict = return_dict if return_dict is not None else self.config.return_dict 559 | 560 | vle_output = self.vle( 561 | input_ids = input_ids, 562 | pixel_values = pixel_values, 563 | attention_mask = attention_mask, 564 | position_ids = position_ids, 565 | token_type_ids = token_type_ids, 566 | patch_ids = patch_ids,) 567 | pooler_output = vle_output[0] 568 | vqa_logits = self.vqa_classifier(pooler_output) 569 | 570 | 571 | vqa_loss = None 572 | if return_loss and vqa_labels is not None and vqa_scores is not None: 573 | vqa_targets = torch.zeros(len(vqa_logits), self.num_vqa_labels,device=vqa_logits.device) 574 | for i, (_label, _score) in enumerate(zip(vqa_labels, vqa_scores)): 575 | for l, s in zip(_label, _score): 576 | vqa_targets[i, l] = s 577 | vqa_loss = F.binary_cross_entropy_with_logits(vqa_logits, vqa_targets) * vqa_targets.shape[1] 578 | # https://github.com/jnhwkim/ban-vqa/blob/master/train.py#L19 579 | 580 | if not return_dict: 581 | output = (vqa_logits,) 582 | return ((vqa_loss,) + output) if vqa_loss is not None else output 583 | return VLEForVQAOutput( 584 | loss = vqa_loss, 585 | logits = vqa_logits 586 | ) 587 | 588 | 
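The ``VLEForVQA`` head above is trained against soft targets rather than one-hot labels: each index in ``vqa_labels`` receives its annotator-agreement score from ``vqa_scores``, and the binary cross-entropy is scaled by the size of the answer vocabulary (the BAN-VQA convention linked in the code). A minimal, self-contained sketch of that target and loss construction, using made-up sizes and values purely for illustration:

```python
import torch
import torch.nn.functional as F

# Illustrative sizes only; in the model the label space comes from config.id2label.
num_vqa_labels = 6
vqa_labels = [[1, 3], [2]]                    # answer indices per example
vqa_scores = [[1.0, 0.3], [0.9]]              # soft agreement scores per answer
vqa_logits = torch.randn(2, num_vqa_labels)   # stand-in for the vqa_classifier output

# Scatter each answer's score into a dense target tensor, as in VLEForVQA.forward.
vqa_targets = torch.zeros(len(vqa_logits), num_vqa_labels)
for i, (labels, scores) in enumerate(zip(vqa_labels, vqa_scores)):
    for l, s in zip(labels, scores):
        vqa_targets[i, l] = s

# Same objective as VLEForVQA.forward: BCE-with-logits, scaled by the number of answer classes.
vqa_loss = F.binary_cross_entropy_with_logits(vqa_logits, vqa_targets) * vqa_targets.shape[1]
```

At inference time the VQA pipeline simply takes a softmax over ``vqa_logits`` and maps the top indices through ``config.id2label``, as in ``VLEForVQAPipeline.postprocess`` further below.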
589 | class VLEForITM(VLEPreTrainedModel): 590 | def __init__( 591 | self, 592 | config: Optional[VLEConfig] = None, 593 | vision_model: Optional[PreTrainedModel] = None, 594 | text_model: Optional[PreTrainedModel] = None, 595 | ): 596 | super().__init__(config) 597 | self.vle = VLEModel(config, vision_model, text_model) 598 | 599 | hidden_size = config.hidden_size 600 | self.itm_score = ITMHead(hidden_size*2) 601 | self.itm_score.apply(self._init_weights) 602 | 603 | def forward(self, 604 | input_ids: Optional[torch.LongTensor], 605 | pixel_values: Optional[torch.FloatTensor], 606 | attention_mask: Optional[torch.Tensor] = None, 607 | position_ids: Optional[torch.LongTensor] = None, 608 | token_type_ids: Optional[torch.LongTensor] = None, 609 | patch_ids = None, 610 | itm_labels = None, 611 | return_loss: Optional[bool] = None, 612 | return_dict: Optional[bool] = None, 613 | ) -> Union[Tuple[torch.Tensor], VLEForITMOutput]: 614 | 615 | return_dict = return_dict if return_dict is not None else self.config.return_dict 616 | 617 | vle_output = self.vle( 618 | input_ids = input_ids, 619 | pixel_values = pixel_values, 620 | attention_mask = attention_mask, 621 | position_ids = position_ids, 622 | token_type_ids = token_type_ids, 623 | patch_ids = patch_ids,) 624 | pooler_output = vle_output[0] 625 | 626 | itm_logits = self.itm_score(pooler_output) 627 | itm_loss = None 628 | if return_loss and itm_labels is not None: 629 | itm_loss = nn.functional.cross_entropy(itm_logits, torch.tensor(itm_labels).long().to(itm_logits.device)) 630 | if not return_dict: 631 | output = (itm_logits,) 632 | return ((itm_loss,) + output) if itm_loss is not None else output 633 | return VLEForITMOutput(loss = itm_loss, logits = itm_logits) 634 | 635 | 636 | class VLEForPBC(VLEPreTrainedModel): 637 | def __init__( 638 | self, 639 | config: Optional[VLEConfig] = None, 640 | vision_model: Optional[PreTrainedModel] = None, 641 | text_model: Optional[PreTrainedModel] = None, 642 | ): 643 | super().__init__(config) 644 | self.vle = VLEModel(config, vision_model, text_model) 645 | 646 | hidden_size = config.hidden_size 647 | self.pbc_classifier = nn.Sequential( 648 | nn.Linear(hidden_size, hidden_size), 649 | nn.LayerNorm(hidden_size), 650 | nn.GELU(), 651 | nn.Linear(hidden_size, 2), 652 | ) 653 | self.pbc_classifier.apply(self._init_weights) 654 | 655 | def forward(self, 656 | input_ids: Optional[torch.LongTensor], 657 | pixel_values: Optional[torch.FloatTensor], 658 | attention_mask: Optional[torch.Tensor] = None, 659 | position_ids: Optional[torch.LongTensor] = None, 660 | token_type_ids: Optional[torch.LongTensor] = None, 661 | patch_ids = None, 662 | pbc_labels = None, 663 | return_loss: Optional[bool] = None, 664 | return_dict: Optional[bool] = None, 665 | ) -> Union[Tuple[torch.Tensor], VLEForPBCOutput]: 666 | 667 | return_dict = return_dict if return_dict is not None else self.config.return_dict 668 | 669 | vle_output = self.vle( 670 | input_ids = input_ids, 671 | pixel_values = pixel_values, 672 | attention_mask = attention_mask, 673 | position_ids = position_ids, 674 | token_type_ids = token_type_ids, 675 | patch_ids = patch_ids,) 676 | image_embeds = vle_output['image_embeds'] 677 | pbc_logits = self.pbc_classifier(image_embeds[:,1:,:]) 678 | 679 | pbc_loss = None 680 | if return_loss and pbc_labels is not None: 681 | pbc_loss = F.cross_entropy(pbc_logits, torch.tensor(pbc_labels).long().to(pbc_logits.device)) 682 | 683 | if not return_dict: 684 | output = (pbc_logits,) 685 | return ((pbc_loss,) + output) if 
pbc_loss is not None else output 686 | return VLEForPBCOutput(loss = pbc_loss, logits = pbc_logits) 687 | 688 | 689 | class VLEForMLM(VLEPreTrainedModel): 690 | _keys_to_ignore_on_load_missing = [r"mlm_score.1.predictions.decoder.weight",r"mlm_score.1.predictions.decoder.bias"] 691 | def __init__( 692 | self, 693 | config: Optional[VLEConfig] = None, 694 | vision_model: Optional[PreTrainedModel] = None, 695 | text_model: Optional[PreTrainedModel] = None, 696 | ): 697 | super().__init__(config) 698 | self.vle = VLEModel(config, vision_model, text_model) 699 | 700 | hidden_size = config.hidden_size 701 | mlm_head = DebertaV2OnlyMLMHead(self.config.text_config) 702 | mlm_transform = nn.Linear(hidden_size, self.config.text_config.hidden_size) 703 | self.mlm_score = nn.Sequential( 704 | mlm_transform, 705 | mlm_head, 706 | ) 707 | 708 | def forward(self, 709 | input_ids: Optional[torch.LongTensor], 710 | pixel_values: Optional[torch.FloatTensor], 711 | attention_mask: Optional[torch.Tensor] = None, 712 | position_ids: Optional[torch.LongTensor] = None, 713 | token_type_ids: Optional[torch.LongTensor] = None, 714 | patch_ids = None, 715 | mlm_labels = None, 716 | return_loss: Optional[bool] = None, 717 | return_dict: Optional[bool] = None, 718 | ) -> Union[Tuple[torch.Tensor], VLEForMLMOutput]: 719 | 720 | return_dict = return_dict if return_dict is not None else self.config.return_dict 721 | 722 | vle_output = self.vle( 723 | input_ids = input_ids, 724 | pixel_values = pixel_values, 725 | attention_mask = attention_mask, 726 | position_ids = position_ids, 727 | token_type_ids = token_type_ids, 728 | patch_ids = patch_ids,) 729 | text_feats = vle_output.text_embeds 730 | 731 | mlm_logits = self.mlm_score(text_feats) 732 | mlm_loss = None 733 | if return_loss and mlm_labels is not None: 734 | mlm_loss = F.cross_entropy( 735 | mlm_logits.view(-1, self.config.text_config.vocab_size), 736 | mlm_labels.view(-1), 737 | ignore_index=-100, 738 | ) 739 | if not return_dict: 740 | output = (mlm_logits,) 741 | return ((mlm_loss,) + output) if mlm_loss is not None else output 742 | return VLEForMLMOutput(loss = mlm_loss, logits = mlm_logits) 743 | 744 | 745 | def get_output_embeddings(self): 746 | return self.mlm_score[1].predictions.decoder 747 | 748 | def set_output_embeddings(self, new_embeddings): 749 | self.mlm_score[1].predictions.decoder = new_embeddings 750 | 751 | 752 | class VLEForVCRQ2A(VLEPreTrainedModel): 753 | def __init__( 754 | self, 755 | config: Optional[VLEConfig] = None, 756 | vision_model: Optional[PreTrainedModel] = None, 757 | text_model: Optional[PreTrainedModel] = None, 758 | ): 759 | super().__init__(config) 760 | self.vle = VLEModel(config, vision_model, text_model) 761 | 762 | hidden_size = config.hidden_size 763 | self.vcr_q2a_logit = nn.Sequential( 764 | nn.Linear(hidden_size * 2, hidden_size * 2), 765 | nn.LayerNorm(hidden_size * 2), 766 | nn.GELU(), 767 | nn.Linear(hidden_size * 2, 1), 768 | ) 769 | self.vcr_q2a_logit.apply(self._init_weights) 770 | 771 | def forward(self, 772 | input_ids: Optional[torch.LongTensor], 773 | pixel_values: Optional[torch.FloatTensor], 774 | attention_mask: Optional[torch.Tensor] = None, 775 | position_ids: Optional[torch.LongTensor] = None, 776 | token_type_ids: Optional[torch.LongTensor] = None, 777 | patch_ids = None, 778 | vcr_labels = None, 779 | extend_token_type_ids = None, 780 | return_loss: Optional[bool] = None, 781 | return_dict: Optional[bool] = None, 782 | ) -> Union[Tuple[torch.Tensor], VLEForVQAOutput]: 783 | 784 | return_dict = 
return_dict if return_dict is not None else self.config.return_dict 785 | 786 | infers = [] 787 | for i in range(4): 788 | vle_output = self.vle( 789 | input_ids = input_ids[i], 790 | pixel_values = pixel_values, 791 | attention_mask = attention_mask[i], 792 | position_ids = position_ids, 793 | token_type_ids = token_type_ids[i], 794 | patch_ids = patch_ids[i], 795 | extend_token_type_ids = extend_token_type_ids[i]) 796 | pooler_output = vle_output[0] 797 | logits = self.vcr_q2a_logit(pooler_output) 798 | infers.append(logits) 799 | 800 | vcr_logits = torch.cat(infers, dim=-1) 801 | vcr_loss = None 802 | if return_loss and vcr_labels is not None: 803 | vcr_targets = torch.zeros(len(vcr_logits), dtype=torch.long).to(self.device) 804 | for i, _label in enumerate(vcr_labels): 805 | vcr_targets[i] = _label 806 | vcr_loss = F.cross_entropy(vcr_logits, vcr_targets.view(-1)) 807 | 808 | if not return_dict: 809 | output = (vcr_logits,) 810 | return ((vcr_loss,) + output) if vcr_loss is not None else output 811 | return VLEForVCRQ2AOutput( 812 | loss = vcr_loss, 813 | logits = vcr_logits 814 | ) 815 | 816 | 817 | class VLEForVCRQA2R(VLEPreTrainedModel): 818 | def __init__( 819 | self, 820 | config: Optional[VLEConfig] = None, 821 | vision_model: Optional[PreTrainedModel] = None, 822 | text_model: Optional[PreTrainedModel] = None, 823 | ): 824 | super().__init__(config) 825 | self.vle = VLEModel(config, vision_model, text_model) 826 | 827 | hidden_size = config.hidden_size 828 | self.vcr_qa2r_logit = nn.Sequential( 829 | nn.Linear(hidden_size * 2, hidden_size * 2), 830 | nn.LayerNorm(hidden_size * 2), 831 | nn.GELU(), 832 | nn.Linear(hidden_size * 2, 1), 833 | ) 834 | self.vcr_qa2r_logit.apply(self._init_weights) 835 | 836 | def forward(self, 837 | input_ids: Optional[torch.LongTensor], 838 | pixel_values: Optional[torch.FloatTensor], 839 | attention_mask: Optional[torch.Tensor] = None, 840 | position_ids: Optional[torch.LongTensor] = None, 841 | token_type_ids: Optional[torch.LongTensor] = None, 842 | patch_ids = None, 843 | vcr_labels = None, 844 | extend_token_type_ids = None, 845 | return_loss: Optional[bool] = None, 846 | return_dict: Optional[bool] = None, 847 | ) -> Union[Tuple[torch.Tensor], VLEForVQAOutput]: 848 | 849 | return_dict = return_dict if return_dict is not None else self.config.return_dict 850 | 851 | infers = [] 852 | for i in range(4): 853 | vle_output = self.vle( 854 | input_ids = input_ids[i], 855 | pixel_values = pixel_values, 856 | attention_mask = attention_mask[i], 857 | position_ids = position_ids, 858 | token_type_ids = token_type_ids[i], 859 | patch_ids = patch_ids[i], 860 | extend_token_type_ids = extend_token_type_ids[i]) 861 | pooler_output = vle_output[0] 862 | logits = self.vcr_qa2r_logit(pooler_output) 863 | infers.append(logits) 864 | 865 | vcr_logits = torch.cat(infers, dim=-1) 866 | vcr_loss = None 867 | if return_loss and vcr_labels is not None: 868 | vcr_targets = torch.zeros(len(vcr_logits), dtype=torch.long).to(self.device) 869 | for i, _label in enumerate(vcr_labels): 870 | vcr_targets[i] = _label 871 | vcr_loss = F.cross_entropy(vcr_logits, vcr_targets.view(-1)) 872 | 873 | if not return_dict: 874 | output = (vcr_logits,) 875 | return ((vcr_loss,) + output) if vcr_loss is not None else output 876 | return VLEForVCRQA2ROutput( 877 | loss = vcr_loss, 878 | logits = vcr_logits 879 | ) 880 | -------------------------------------------------------------------------------- /models/VLE/pipeline_vle.py: 
-------------------------------------------------------------------------------- 1 | import torch 2 | from transformers import Pipeline, BatchEncoding 3 | from PIL import Image 4 | from typing import Union 5 | from copy import deepcopy 6 | import matplotlib.pyplot as plt 7 | import io 8 | import json 9 | import unicodedata 10 | import os 11 | 12 | class VLEForVQAPipeline(Pipeline): 13 | 14 | def __init__(self, vle_processor, *args, **kwargs): 15 | self.vle_processor = vle_processor 16 | super().__init__(*args, **kwargs) 17 | 18 | def _sanitize_parameters(self, top_k=None, **kwargs): 19 | preprocess_params, forward_params, postprocess_params = {}, {}, {} 20 | if top_k is not None: 21 | postprocess_params["top_k"] = top_k 22 | return preprocess_params, forward_params, postprocess_params 23 | 24 | def __call__(self, image: Union["Image.Image", str], question: str = None, **kwargs): 25 | 26 | if isinstance(image, (Image.Image, str)) and isinstance(question, str): 27 | inputs = {"image": image, "question": question} 28 | else: 29 | """ 30 | Supports the following format 31 | - {"image": image, "question": question} 32 | - [{"image": image, "question": question}] 33 | - Generator and datasets 34 | """ 35 | inputs = image 36 | results = super().__call__(inputs, **kwargs) 37 | return results 38 | 39 | def preprocess(self, inputs): 40 | model_inputs = self.vle_processor(text=inputs['question'], images=inputs['image'], return_tensors="pt",padding=True) 41 | return model_inputs 42 | 43 | def _forward(self, model_inputs): 44 | model_outputs = self.model(**model_inputs) 45 | return model_outputs 46 | 47 | def postprocess(self, model_outputs, top_k=1): 48 | if top_k > self.model.num_vqa_labels: 49 | top_k = self.model.num_vqa_labels 50 | probs = torch.softmax(model_outputs['logits'], dim=-1) 51 | probs, preds = torch.sort(probs, descending=True) 52 | probs = probs[:,:top_k].tolist()[0] 53 | preds = preds[:,:top_k].tolist()[0] 54 | 55 | return [{"score": score, "answer": self.model.config.id2label[pred]} for score, pred in zip(probs, preds)] 56 | 57 | 58 | 59 | class VLEForPBCPipeline(Pipeline): 60 | def __init__(self, vle_processor, *args, **kwargs): 61 | self.vle_processor = vle_processor 62 | self.id2label = {0:"False",1:"True"} 63 | super().__init__(*args, **kwargs) 64 | 65 | def _sanitize_parameters(self, **kwargs): 66 | preprocess_params, forward_params, postprocess_params = {}, {}, {} 67 | return preprocess_params, forward_params, postprocess_params 68 | 69 | def __call__(self, image: Union["Image.Image", str], text: str = None, **kwargs): 70 | if isinstance(image, (Image.Image, str)) and isinstance(text, str): 71 | inputs = {"image": image, "text": text} 72 | else: 73 | """ 74 | Supports the following format 75 | - {"image": image, "text": text} 76 | - [{"image": image, "text": text}] 77 | - Generator and datasets 78 | """ 79 | inputs = image 80 | results = super().__call__(inputs, **kwargs) 81 | return results 82 | 83 | def preprocess(self, inputs): 84 | model_inputs = self.vle_processor(text=inputs['text'], images=inputs['image'], return_tensors="pt",padding=True) 85 | return model_inputs, inputs['image'] 86 | 87 | def _forward(self, model_inputs): 88 | model_outputs = self.model(**model_inputs[0]) 89 | return model_outputs, model_inputs[1] 90 | 91 | def postprocess(self, model_outputs): 92 | probs = torch.softmax(model_outputs[0]['logits'], dim=-1) 93 | probs = probs.tolist()[0] 94 | new_image = self.paint_in_image(model_outputs[0]['logits'], model_outputs[1]) 95 | return {"score": probs, 
"image": new_image} 96 | 97 | def paint_in_image(self, logits, raw_image): 98 | image_back = deepcopy(raw_image) 99 | raw_image_size = image_back.size 100 | resized_image_size = self.model.config.vision_config.image_size 101 | patch_size = self.model.config.vision_config.patch_size 102 | probs = torch.softmax(logits.detach()[0,:,1].to('cpu'),dim=-1).numpy().reshape(-1, resized_image_size//patch_size) 103 | 104 | plt.close('all') 105 | plt.axis('off') 106 | plt.imshow(probs, cmap='gray', interpolation='None', vmin=(probs.max()-probs.min())*2/5+probs.min(),alpha=0.7) 107 | plt.xticks([]) 108 | plt.yticks([]) 109 | buf = io.BytesIO() 110 | plt.savefig(buf, dpi=100, transparent=True, bbox_inches='tight', pad_inches=0) 111 | image_front = Image.open(buf) 112 | 113 | def filter_image_front(img: Image.Image): 114 | width, height = img.width, img.height 115 | for x in range(width): 116 | for y in range(height): 117 | r,g,b,a = img.getpixel((x,y)) 118 | a = int (a * (1-r/255)) 119 | img.putpixel((x,y), (r,g,b,a)) 120 | return img 121 | 122 | image_front = filter_image_front(image_front).resize(raw_image_size) 123 | image_back.paste(image_front, (0,0), image_front) 124 | mixed_image = image_back.resize(raw_image_size) 125 | buf.close() 126 | 127 | return mixed_image 128 | 129 | 130 | 131 | class VLEForITMPipeline(Pipeline): 132 | def __init__(self, vle_processor, *args, **kwargs): 133 | self.vle_processor = vle_processor 134 | self.id2label = {0:"False",1:"True"} 135 | super().__init__(*args, **kwargs) 136 | 137 | def _sanitize_parameters(self, **kwargs): 138 | preprocess_params, forward_params, postprocess_params = {}, {}, {} 139 | return preprocess_params, forward_params, postprocess_params 140 | 141 | def __call__(self, image: Union["Image.Image", str], text: str = None, **kwargs): 142 | if isinstance(image, (Image.Image, str)) and isinstance(text, str): 143 | inputs = {"image": image, "text": text} 144 | else: 145 | """ 146 | Supports the following format 147 | - {"image": image, "text": text} 148 | - [{"image": image, "text": text}] 149 | - Generator and datasets 150 | """ 151 | inputs = image 152 | results = super().__call__(inputs, **kwargs) 153 | return results 154 | 155 | def preprocess(self, inputs): 156 | model_inputs = self.vle_processor(text=inputs['text'], images=inputs['image'], return_tensors="pt",padding=True) 157 | return model_inputs 158 | 159 | def _forward(self, model_inputs): 160 | model_outputs = self.model(**model_inputs) 161 | return model_outputs 162 | 163 | def postprocess(self, model_outputs): 164 | probs = torch.softmax(model_outputs['logits'], dim=-1) 165 | preds = torch.argmax(probs, dim=-1) 166 | probs = probs.tolist()[0] 167 | preds = self.id2label[preds.tolist()[0]] 168 | 169 | return {"score": probs, "match": preds} 170 | 171 | 172 | class VLEForVCRQ2APipeline(Pipeline): 173 | 174 | def __init__(self, vle_processor, *args, **kwargs): 175 | self.vle_processor = vle_processor 176 | self.vle_tokenizer = self.vle_processor.tokenizer 177 | self.GENDER_NEUTRAL_NAMES = ['Casey', 'Riley', 'Jessie', 'Jackie', 'Avery', 'Jaime', 'Peyton', 'Kerry', 'Jody', 'Kendall', 178 | 'Payton', 'Skyler', 'Frankie', 'Pat', 'Quinn'] 179 | self.person_name_id = 0 180 | self.max_text_len = 80 181 | super().__init__(*args, **kwargs) 182 | 183 | def _sanitize_parameters(self, **kwargs): 184 | preprocess_params, forward_params, postprocess_params = {}, {}, {} 185 | return preprocess_params, forward_params, postprocess_params 186 | 187 | def __call__(self, vcr_image_root: str, meta_inputs: dict, 
**kwargs): 188 | 189 | inputs = {"vcr_image_root": vcr_image_root, "meta_inputs": meta_inputs} 190 | results = super().__call__(inputs, **kwargs) 191 | return results 192 | 193 | def preprocess(self, inputs): 194 | model_inputs = self.vcr_q2a_preprocess(inputs["vcr_image_root"], inputs["meta_inputs"]) 195 | return model_inputs 196 | 197 | def _forward(self, model_inputs): 198 | model_outputs = self.model(**model_inputs) 199 | return model_outputs 200 | 201 | def postprocess(self, model_outputs, top_k=1): 202 | logits = model_outputs["logits"] 203 | loss = model_outputs["loss"] 204 | preds = torch.argmax(logits, dim=-1, keepdim=True) 205 | 206 | return [{"score": score, "pred": pred} for score, pred in zip(logits, preds)] 207 | 208 | def vcr_q2a_preprocess(self, vcr_image_root, data): 209 | image_fn = data["img_fn"] 210 | objects = data["objects"] 211 | metadata_fn = data["metadata_fn"] 212 | question = data["question"] 213 | answer_choices = data["answer_choices"] 214 | rationale_choices = data["rationale_choices"] 215 | answer_label = data["answer_label"] 216 | rationale_label = data["rationale_label"] 217 | 218 | question_text = question 219 | answer_text = answer_choices 220 | text_tokens, text_ids, obj_tags, text_raw = self.build_text(question_text, answer_text, objects, self.vle_tokenizer) 221 | encoding = [self.vle_tokenizer( 222 | ''.join(self.vle_tokenizer.convert_ids_to_tokens(text_ids_)), 223 | padding="max_length", 224 | truncation=True, 225 | max_length=self.max_text_len, 226 | return_special_tokens_mask=True, 227 | ) for text_ids_ in text_ids] 228 | 229 | obj_tags = [[-1] + tags + [-1] for tags in obj_tags] 230 | 231 | image = Image.open(os.path.join(vcr_image_root, image_fn)) 232 | image_feature = self.vle_processor.image_processor(image) 233 | width = image.size[0] 234 | height = image.size[1] 235 | with open(os.path.join(vcr_image_root, metadata_fn), 'r') as f: 236 | vcr_image_metadata = json.load(f) 237 | boxes = vcr_image_metadata['boxes'] 238 | patch_boxes = self.get_patch_box(boxes, width, height, self.model.config.vision_config.patch_size, self.model.config.vision_config.image_size) 239 | related_box_ids, image_added_token_type_ids, text_token_type_ids = list(zip(*[self.get_related_box_ids_and_token_type_ids(tags) for tags in obj_tags])) 240 | related_patch_boxes = [[patch_boxes[i] for i in related_box_ids[j]] for j in range(4)] 241 | 242 | processed_data = { 243 | "image": image_feature, 244 | "text": (text_raw, encoding), 245 | "obj_tags": obj_tags, 246 | "label": answer_label, 247 | "patch_ids": [[patch_box[1] for patch_box in related_patch_boxes[i]] for i in range(4)], 248 | "extend_token_type_ids": [(image_added_token_type_ids[i], text_token_type_ids[i]) for i in range(4)], 249 | } 250 | 251 | model_inputs = { 252 | "input_ids": [[processed_data["text"][1][i]["input_ids"]] for i in range(4)], 253 | "attention_mask": [[processed_data["text"][1][i]["attention_mask"]] for i in range(4)], 254 | "token_type_ids": [[processed_data["text"][1][i]["token_type_ids"]] for i in range(4)], 255 | "pixel_values": torch.Tensor([processed_data["image"]["pixel_values"][0]]), 256 | } 257 | model_inputs = BatchEncoding(model_inputs, tensor_type='pt') 258 | model_inputs.update({ 259 | "patch_ids": [[processed_data["patch_ids"][i]] for i in range(4)], 260 | "vcr_labels": [processed_data["label"]], 261 | "extend_token_type_ids": [list(zip(processed_data["extend_token_type_ids"][i])) for i in range(4)], 262 | "return_loss": True, 263 | }) 264 | return model_inputs 265 | 266 | def 
retokenize_and_convert_to_ids_with_tag(self, raw_tokens, objects_replace_name, tokenizer, non_obj_tag=-1, add_space_b4_first_token=False): 267 | parsed_tokens = [] 268 | tags = [] 269 | align_ids = [] 270 | raw = [] 271 | align_id = 0 272 | for idx, mixed_token in enumerate(raw_tokens): 273 | if isinstance(mixed_token, list): 274 | tokens = [" " + objects_replace_name[o] for o in mixed_token] 275 | if idx == 0 and not add_space_b4_first_token: 276 | tokens[0] = tokens[0].lstrip() 277 | retokenized_tokens = tokenizer.tokenize(tokens[0]) 278 | raw.append(tokens[0]) 279 | tags.extend([mixed_token[0] + non_obj_tag + 1 for _ in retokenized_tokens]) 280 | align_ids.extend([align_id for _ in retokenized_tokens]) 281 | align_id += 1 282 | for token, o in zip(tokens[1:], mixed_token[1:]): 283 | retokenized_tokens.append(tokenizer.tokenize(' and')[0]) 284 | tags.append(non_obj_tag) 285 | align_ids.append(align_id) 286 | align_id += 1 287 | re_tokens = tokenizer.tokenize(token) 288 | retokenized_tokens.extend(re_tokens) 289 | tags.extend([o + non_obj_tag + 1 for _ in re_tokens]) 290 | align_ids.extend([align_id for _ in re_tokens]) 291 | align_id += 1 292 | raw.extend([' and', token]) 293 | parsed_tokens.extend(retokenized_tokens) 294 | else: 295 | # fully align to original tokens 296 | if True in [unicodedata.category(str_) == 'Co' for str_ in mixed_token]: 297 | continue 298 | if idx != 0 or add_space_b4_first_token: 299 | mixed_token = " " + mixed_token 300 | raw.append(mixed_token) 301 | retokenized_tokens = tokenizer.tokenize(mixed_token) 302 | parsed_tokens.extend(retokenized_tokens) 303 | align_ids.extend([align_id for _ in retokenized_tokens]) 304 | tags.extend([non_obj_tag for _ in retokenized_tokens]) 305 | align_id += 1 306 | ids = tokenizer.convert_tokens_to_ids(parsed_tokens) 307 | ids_with_tag = list(zip(parsed_tokens, ids, tags, align_ids)) 308 | 309 | return ids_with_tag, raw 310 | 311 | def build_text(self, question_text, answer_text, objects, tokenizer): 312 | objects_replace_name = [] 313 | for o in objects: 314 | if o == 'person': 315 | objects_replace_name.append(self.GENDER_NEUTRAL_NAMES[self.person_name_id]) 316 | self.person_name_id = (self.person_name_id + 1) % len(self.GENDER_NEUTRAL_NAMES) 317 | else: 318 | objects_replace_name.append(o) 319 | 320 | non_obj_tag = -1 321 | question_text, question_text_raw = self.retokenize_and_convert_to_ids_with_tag(question_text, objects_replace_name, tokenizer, non_obj_tag=non_obj_tag) 322 | answer_text = [self.retokenize_and_convert_to_ids_with_tag(a_t, objects_replace_name, tokenizer, non_obj_tag=non_obj_tag) for a_t in answer_text] 323 | answer_text, answer_text_raw = list(zip(*answer_text)) 324 | for a_t, a_t_raw in zip(answer_text, answer_text_raw): 325 | while len(question_text) + len(a_t) > self.max_text_len - 3: 326 | if len(question_text) > len(a_t): 327 | question_text.pop() 328 | else: 329 | a_t.pop() 330 | 331 | text_tokens = [[q_t[0] for q_t in question_text] + [self.vle_tokenizer.sep_token] + [a_t_t[0] for a_t_t in a_t] for a_t in answer_text] 332 | text_ids = [[q_t[1] for q_t in question_text] + [self.vle_tokenizer.sep_token_id] + [a_t_t[1] for a_t_t in a_t] for a_t in answer_text] 333 | obj_tags = [[q_t[2] for q_t in question_text] + [-1] + [a_t_t[2] for a_t_t in a_t] for a_t in answer_text] 334 | text_raw = [question_text_raw + answer_text_raw_ for answer_text_raw_ in answer_text_raw] 335 | 336 | return text_tokens, text_ids, obj_tags, text_raw 337 | 338 | def get_patch_box(self, boxes, width, height, patch_size, 
image_size): 339 | patch_count_w = image_size // patch_size 340 | patch_count_h = image_size // patch_size 341 | patch_width = width / patch_count_w 342 | patch_height = height / patch_count_h 343 | 344 | patch_boxes = [] 345 | for box in boxes: 346 | box = box[:4] 347 | patch_x1 = int(box[0] // patch_width) 348 | patch_y1 = int(box[1] // patch_height) 349 | patch_x2 = int(box[2] // patch_width) 350 | patch_y2 = int(box[3] // patch_height) 351 | 352 | patch_x1 = patch_x1 if patch_x1 >= 0 else 0 353 | patch_y1 = patch_y1 if patch_y1 >= 0 else 0 354 | patch_x2 = patch_x2 + 1 if patch_x2 < patch_count_w else patch_count_w 355 | patch_y2 = patch_y2 + 1 if patch_y2 < patch_count_h else patch_count_h 356 | 357 | patch_box = [ 358 | patch_x1 * patch_width, 359 | patch_y1 * patch_height, 360 | patch_x2 * patch_width, 361 | patch_y2 * patch_height 362 | ] 363 | 364 | patch_ids = [patch_count_w * y + x for y in range(patch_y1, patch_y2) for x in range(patch_x1, patch_x2)] 365 | patch_boxes.append([patch_box, patch_ids]) 366 | 367 | return patch_boxes 368 | 369 | def get_related_box_ids_and_token_type_ids(self, obj_tags): 370 | no_obj_tag = -1 371 | obj_tags_set = set() 372 | for tag in obj_tags: 373 | if tag != no_obj_tag: 374 | obj_tags_set.add(tag) 375 | obj_tag_remap = {t: i + 2 for i, t in enumerate(obj_tags_set)} 376 | text_token_type_ids = [obj_tag_remap[tag] if tag != no_obj_tag else 0 for tag in obj_tags] 377 | related_box_ids = list(obj_tag_remap.keys()) 378 | image_added_token_type_ids = list(obj_tag_remap.values()) 379 | 380 | return related_box_ids, image_added_token_type_ids, text_token_type_ids 381 | 382 | 383 | class VLEForVCRQA2RPipeline(Pipeline): 384 | 385 | def __init__(self, vle_processor, *args, **kwargs): 386 | self.vle_processor = vle_processor 387 | self.vle_tokenizer = self.vle_processor.tokenizer 388 | self.GENDER_NEUTRAL_NAMES = ['Casey', 'Riley', 'Jessie', 'Jackie', 'Avery', 'Jaime', 'Peyton', 'Kerry', 'Jody', 'Kendall', 389 | 'Payton', 'Skyler', 'Frankie', 'Pat', 'Quinn'] 390 | self.person_name_id = 0 391 | self.max_text_len = 80 392 | super().__init__(*args, **kwargs) 393 | 394 | def _sanitize_parameters(self, **kwargs): 395 | preprocess_params, forward_params, postprocess_params = {}, {}, {} 396 | return preprocess_params, forward_params, postprocess_params 397 | 398 | def __call__(self, vcr_image_root: str, meta_inputs: dict, **kwargs): 399 | 400 | inputs = {"vcr_image_root": vcr_image_root, "meta_inputs": meta_inputs} 401 | results = super().__call__(inputs, **kwargs) 402 | return results 403 | 404 | def preprocess(self, inputs): 405 | model_inputs = self.vcr_qa2r_preprocess(inputs["vcr_image_root"], inputs["meta_inputs"]) 406 | return model_inputs 407 | 408 | def _forward(self, model_inputs): 409 | model_outputs = self.model(**model_inputs) 410 | return model_outputs 411 | 412 | def postprocess(self, model_outputs, top_k=1): 413 | logits = model_outputs["logits"] 414 | loss = model_outputs["loss"] 415 | preds = torch.argmax(logits, dim=-1, keepdim=True) 416 | 417 | return [{"score": score, "pred": pred} for score, pred in zip(logits, preds)] 418 | 419 | def vcr_qa2r_preprocess(self, vcr_image_root, data): 420 | image_fn = data["img_fn"] 421 | objects = data["objects"] 422 | metadata_fn = data["metadata_fn"] 423 | question = data["question"] 424 | answer_choices = data["answer_choices"] 425 | rationale_choices = data["rationale_choices"] 426 | answer_label = data["answer_label"] 427 | rationale_label = data["rationale_label"] 428 | 429 | question_text = question 430 
| answer_text = answer_choices[answer_label] 431 | rationale_text = rationale_choices 432 | text_tokens, text_ids, obj_tags, text_raw = self.build_text(question_text, answer_text, rationale_text, objects, self.vle_tokenizer) 433 | encoding = [self.vle_tokenizer( 434 | ''.join(self.vle_tokenizer.convert_ids_to_tokens(text_ids_)), 435 | padding="max_length", 436 | truncation=True, 437 | max_length=self.max_text_len, 438 | return_special_tokens_mask=True, 439 | ) for text_ids_ in text_ids] 440 | 441 | obj_tags = [[-1] + tags + [-1] for tags in obj_tags] 442 | 443 | image = Image.open(os.path.join(vcr_image_root, image_fn)) 444 | image_feature = self.vle_processor.image_processor(image) 445 | width = image.size[0] 446 | height = image.size[1] 447 | with open(os.path.join(vcr_image_root, metadata_fn), 'r') as f: 448 | vcr_image_metadata = json.load(f) 449 | boxes = vcr_image_metadata['boxes'] 450 | patch_boxes = self.get_patch_box(boxes, width, height, self.model.config.vision_config.patch_size, self.model.config.vision_config.image_size) 451 | related_box_ids, image_added_token_type_ids, text_token_type_ids = list(zip(*[self.get_related_box_ids_and_token_type_ids(tags) for tags in obj_tags])) 452 | related_patch_boxes = [[patch_boxes[i] for i in related_box_ids[j]] for j in range(4)] 453 | 454 | processed_data = { 455 | "image": image_feature, 456 | "text": (text_raw, encoding), 457 | "obj_tags": obj_tags, 458 | "label": rationale_label, 459 | "patch_ids": [[patch_box[1] for patch_box in related_patch_boxes[i]] for i in range(4)], 460 | "extend_token_type_ids": [(image_added_token_type_ids[i], text_token_type_ids[i]) for i in range(4)], 461 | } 462 | 463 | model_inputs = { 464 | "input_ids": [[processed_data["text"][1][i]["input_ids"]] for i in range(4)], 465 | "attention_mask": [[processed_data["text"][1][i]["attention_mask"]] for i in range(4)], 466 | "token_type_ids": [[processed_data["text"][1][i]["token_type_ids"]] for i in range(4)], 467 | "pixel_values": torch.Tensor([processed_data["image"]["pixel_values"][0]]), 468 | } 469 | model_inputs = BatchEncoding(model_inputs, tensor_type='pt') 470 | model_inputs.update({ 471 | "patch_ids": [[processed_data["patch_ids"][i]] for i in range(4)], 472 | "vcr_labels": [processed_data["label"]], 473 | "extend_token_type_ids": [list(zip(processed_data["extend_token_type_ids"][i])) for i in range(4)], 474 | "return_loss": True, 475 | }) 476 | return model_inputs 477 | 478 | def retokenize_and_convert_to_ids_with_tag(self, raw_tokens, objects_replace_name, tokenizer, non_obj_tag=-1, add_space_b4_first_token=False): 479 | parsed_tokens = [] 480 | tags = [] 481 | align_ids = [] 482 | raw = [] 483 | align_id = 0 484 | for idx, mixed_token in enumerate(raw_tokens): 485 | if isinstance(mixed_token, list): 486 | tokens = [" " + objects_replace_name[o] for o in mixed_token] 487 | if idx == 0 and not add_space_b4_first_token: 488 | tokens[0] = tokens[0].lstrip() 489 | retokenized_tokens = tokenizer.tokenize(tokens[0]) 490 | raw.append(tokens[0]) 491 | tags.extend([mixed_token[0] + non_obj_tag + 1 for _ in retokenized_tokens]) 492 | align_ids.extend([align_id for _ in retokenized_tokens]) 493 | align_id += 1 494 | for token, o in zip(tokens[1:], mixed_token[1:]): 495 | retokenized_tokens.append(tokenizer.tokenize(' and')[0]) 496 | tags.append(non_obj_tag) 497 | align_ids.append(align_id) 498 | align_id += 1 499 | re_tokens = tokenizer.tokenize(token) 500 | retokenized_tokens.extend(re_tokens) 501 | tags.extend([o + non_obj_tag + 1 for _ in re_tokens]) 502 | 
align_ids.extend([align_id for _ in re_tokens]) 503 | align_id += 1 504 | raw.extend([' and', token]) 505 | parsed_tokens.extend(retokenized_tokens) 506 | else: 507 | # fully align to original tokens 508 | if True in [unicodedata.category(str_) == 'Co' for str_ in mixed_token]: 509 | continue 510 | if idx != 0 or add_space_b4_first_token: 511 | mixed_token = " " + mixed_token 512 | raw.append(mixed_token) 513 | retokenized_tokens = tokenizer.tokenize(mixed_token) 514 | parsed_tokens.extend(retokenized_tokens) 515 | align_ids.extend([align_id for _ in retokenized_tokens]) 516 | tags.extend([non_obj_tag for _ in retokenized_tokens]) 517 | align_id += 1 518 | ids = tokenizer.convert_tokens_to_ids(parsed_tokens) 519 | ids_with_tag = list(zip(parsed_tokens, ids, tags, align_ids)) 520 | 521 | return ids_with_tag, raw 522 | 523 | def build_text(self, question_text, answer_text, rationale_text, objects, tokenizer): 524 | objects_replace_name = [] 525 | for o in objects: 526 | if o == 'person': 527 | objects_replace_name.append(self.GENDER_NEUTRAL_NAMES[self.person_name_id]) 528 | self.person_name_id = (self.person_name_id + 1) % len(self.GENDER_NEUTRAL_NAMES) 529 | else: 530 | objects_replace_name.append(o) 531 | 532 | non_obj_tag = -1 533 | question_text, question_text_raw = self.retokenize_and_convert_to_ids_with_tag(question_text, objects_replace_name, tokenizer, non_obj_tag=non_obj_tag) 534 | answer_text, answer_text_raw = self.retokenize_and_convert_to_ids_with_tag(answer_text, objects_replace_name, tokenizer, non_obj_tag=non_obj_tag) 535 | rationale_text = [self.retokenize_and_convert_to_ids_with_tag(r_t, objects_replace_name, tokenizer, non_obj_tag=non_obj_tag) for r_t in rationale_text] 536 | rationale_text, rationale_text_raw = list(zip(*rationale_text)) 537 | for r_t, r_t_raw in zip(rationale_text, rationale_text_raw): 538 | while len(question_text) + len(answer_text) + len(r_t) > self.max_text_len - 4: 539 | if len(r_t) > len(question_text) + len(answer_text): 540 | r_t.pop() 541 | elif len(question_text) > 1: 542 | question_text.pop() 543 | else: 544 | answer_text.pop() 545 | 546 | text_tokens = [[q_t[0] for q_t in question_text] + [tokenizer.sep_token] + [a_t[0] for a_t in answer_text] + [tokenizer.sep_token] + [r_t_t[0] for r_t_t in r_t] for r_t in rationale_text] 547 | text_ids = [[q_t[1] for q_t in question_text] + [tokenizer.sep_token_id] + [a_t[1] for a_t in answer_text] + [tokenizer.sep_token_id] + [r_t_t[1] for r_t_t in r_t] for r_t in rationale_text] 548 | obj_tags = [[q_t[2] for q_t in question_text] + [-1] + [a_t[2] for a_t in answer_text] + [-1] + [r_t_t[2] for r_t_t in r_t] for r_t in rationale_text] 549 | text_raw = [question_text_raw + answer_text_raw + rationale_text_raw_ for rationale_text_raw_ in rationale_text_raw] 550 | 551 | return text_tokens, text_ids, obj_tags, text_raw 552 | 553 | def get_patch_box(self, boxes, width, height, patch_size, image_size): 554 | patch_count_w = image_size // patch_size 555 | patch_count_h = image_size // patch_size 556 | patch_width = width / patch_count_w 557 | patch_height = height / patch_count_h 558 | 559 | patch_boxes = [] 560 | for box in boxes: 561 | box = box[:4] 562 | patch_x1 = int(box[0] // patch_width) 563 | patch_y1 = int(box[1] // patch_height) 564 | patch_x2 = int(box[2] // patch_width) 565 | patch_y2 = int(box[3] // patch_height) 566 | 567 | patch_x1 = patch_x1 if patch_x1 >= 0 else 0 568 | patch_y1 = patch_y1 if patch_y1 >= 0 else 0 569 | patch_x2 = patch_x2 + 1 if patch_x2 < patch_count_w else patch_count_w 570 | 
patch_y2 = patch_y2 + 1 if patch_y2 < patch_count_h else patch_count_h 571 | 572 | patch_box = [ 573 | patch_x1 * patch_width, 574 | patch_y1 * patch_height, 575 | patch_x2 * patch_width, 576 | patch_y2 * patch_height 577 | ] 578 | 579 | patch_ids = [patch_count_w * y + x for y in range(patch_y1, patch_y2) for x in range(patch_x1, patch_x2)] 580 | patch_boxes.append([patch_box, patch_ids]) 581 | 582 | return patch_boxes 583 | 584 | def get_related_box_ids_and_token_type_ids(self, obj_tags): 585 | no_obj_tag = -1 586 | obj_tags_set = set() 587 | for tag in obj_tags: 588 | if tag != no_obj_tag: 589 | obj_tags_set.add(tag) 590 | obj_tag_remap = {t: i + 2 for i, t in enumerate(obj_tags_set)} 591 | text_token_type_ids = [obj_tag_remap[tag] if tag != no_obj_tag else 0 for tag in obj_tags] 592 | related_box_ids = list(obj_tag_remap.keys()) 593 | image_added_token_type_ids = list(obj_tag_remap.values()) 594 | 595 | return related_box_ids, image_added_token_type_ids, text_token_type_ids -------------------------------------------------------------------------------- /models/VLE/processing_vle.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2021 The HuggingFace Inc. team. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """ 16 | Processor class for VLE 17 | """ 18 | 19 | import warnings 20 | 21 | from transformers.processing_utils import ProcessorMixin 22 | from transformers.tokenization_utils_base import BatchEncoding 23 | 24 | 25 | class VLEProcessor(ProcessorMixin): 26 | r""" 27 | Constructs a VLE processor which wraps an image processor and a tokenizer into a single 28 | processor. 29 | 30 | [`VLEProcessor`] offers all the functionalities of [`AutoImageProcessor`] and [`AutoTokenizer`]. 31 | See the [`~VLEProcessor.__call__`] and [`~VLEProcessor.decode`] for more 32 | information. 33 | 34 | Args: 35 | image_processor ([`AutoImageProcessor`]): 36 | The image processor is a required input. 37 | tokenizer ([`PreTrainedTokenizer`]): 38 | The tokenizer is a required input. 
39 | """ 40 | attributes = ["image_processor", "tokenizer"] 41 | image_processor_class = "CLIPImageProcessor" 42 | tokenizer_class = "DebertaV2Tokenizer" 43 | 44 | def __init__(self, image_processor=None, tokenizer=None, **kwargs): 45 | if "feature_extractor" in kwargs: 46 | warnings.warn( 47 | "The `feature_extractor` argument is deprecated and will be removed in v5, use `image_processor`" 48 | " instead.", 49 | FutureWarning, 50 | ) 51 | feature_extractor = kwargs.pop("feature_extractor") 52 | 53 | image_processor = image_processor if image_processor is not None else feature_extractor 54 | if image_processor is None: 55 | raise ValueError("You need to specify an `image_processor`.") 56 | if tokenizer is None: 57 | raise ValueError("You need to specify a `tokenizer`.") 58 | 59 | super().__init__(image_processor, tokenizer) 60 | self.current_processor = self.image_processor 61 | 62 | def __call__(self, text=None, images=None, return_tensors=None, **kwargs): 63 | """ 64 | Main method to prepare for the model one or several sequences(s) and image(s). This method forwards the `text` 65 | and `kwargs` arguments to VLETokenizer's [`~PreTrainedTokenizer.__call__`] if `text` is not 66 | `None` to encode the text. To prepare the image(s), this method forwards the `images` and `kwrags` arguments to 67 | AutoImageProcessor's [`~AutoImageProcessor.__call__`] if `images` is not `None`. Please refer to the doctsring 68 | of the above two methods for more information. 69 | 70 | Args: 71 | text (`str`, `List[str]`, `List[List[str]]`): 72 | The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings 73 | (pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set 74 | `is_split_into_words=True` (to lift the ambiguity with a batch of sequences). 75 | images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`, `List[torch.Tensor]`): 76 | The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch 77 | tensor. In case of a NumPy array/PyTorch tensor, each image should be of shape (C, H, W), where C is a 78 | number of channels, H and W are image height and width. 79 | 80 | return_tensors (`str` or [`~utils.TensorType`], *optional*): 81 | If set, will return tensors of a particular framework. Acceptable values are: 82 | 83 | - `'tf'`: Return TensorFlow `tf.constant` objects. 84 | - `'pt'`: Return PyTorch `torch.Tensor` objects. 85 | - `'np'`: Return NumPy `np.ndarray` objects. 86 | - `'jax'`: Return JAX `jnp.ndarray` objects. 87 | 88 | Returns: 89 | [`BatchEncoding`]: A [`BatchEncoding`] with the following fields: 90 | 91 | - **input_ids** -- List of token ids to be fed to a model. Returned when `text` is not `None`. 92 | - **attention_mask** -- List of indices specifying which tokens should be attended to by the model (when 93 | `return_attention_mask=True` or if *"attention_mask"* is in `self.model_input_names` and if `text` is not 94 | `None`). 95 | - **pixel_values** -- Pixel values to be fed to a model. Returned when `images` is not `None`. 96 | """ 97 | 98 | if text is None and images is None: 99 | raise ValueError("You have to specify either text or images. 
Both cannot be none.") 100 | 101 | if text is not None: 102 | encoding = self.tokenizer(text, return_tensors=return_tensors, **kwargs) 103 | 104 | if images is not None: 105 | image_features = self.image_processor(images, return_tensors=return_tensors, **kwargs) 106 | 107 | if text is not None and images is not None: 108 | encoding["pixel_values"] = image_features.pixel_values 109 | return encoding 110 | elif text is not None: 111 | return encoding 112 | else: 113 | return BatchEncoding(data=dict(**image_features), tensor_type=return_tensors) 114 | 115 | def batch_decode(self, *args, **kwargs): 116 | """ 117 | This method forwards all its arguments to VLETokenizer's 118 | [`~PreTrainedTokenizer.batch_decode`]. Please refer to the docstring of this method for more information. 119 | """ 120 | return self.tokenizer.batch_decode(*args, **kwargs) 121 | 122 | def decode(self, *args, **kwargs): 123 | """ 124 | This method forwards all its arguments to VLETokenizer's [`~PreTrainedTokenizer.decode`]. 125 | Please refer to the docstring of this method for more information. 126 | """ 127 | return self.tokenizer.decode(*args, **kwargs) 128 | 129 | @property 130 | def model_input_names(self): 131 | tokenizer_input_names = self.tokenizer.model_input_names 132 | image_processor_input_names = self.image_processor.model_input_names 133 | return list(dict.fromkeys(tokenizer_input_names + image_processor_input_names)) 134 | 135 | @property 136 | def feature_extractor_class(self): 137 | warnings.warn( 138 | "`feature_extractor_class` is deprecated and will be removed in v5. Use `image_processor_class` instead.", 139 | FutureWarning, 140 | ) 141 | return self.image_processor_class 142 | 143 | @property 144 | def feature_extractor(self): 145 | warnings.warn( 146 | "`feature_extractor` is deprecated and will be removed in v5. 
Use `image_processor` instead.", 147 | FutureWarning, 148 | ) 149 | return self.image_processor 150 | -------------------------------------------------------------------------------- /pics/VQALLM_workflow.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/iflytek/VLE/e322bb97b4ecd2c1eed11959416aed640a4bf76a/pics/VQALLM_workflow.png -------------------------------------------------------------------------------- /pics/banner.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/iflytek/VLE/e322bb97b4ecd2c1eed11959416aed640a4bf76a/pics/banner.png -------------------------------------------------------------------------------- /pics/birds.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/iflytek/VLE/e322bb97b4ecd2c1eed11959416aed640a4bf76a/pics/birds.jpg -------------------------------------------------------------------------------- /pics/demo-banner.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/iflytek/VLE/e322bb97b4ecd2c1eed11959416aed640a4bf76a/pics/demo-banner.png -------------------------------------------------------------------------------- /pics/dogs.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/iflytek/VLE/e322bb97b4ecd2c1eed11959416aed640a4bf76a/pics/dogs.png -------------------------------------------------------------------------------- /pics/door.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/iflytek/VLE/e322bb97b4ecd2c1eed11959416aed640a4bf76a/pics/door.png -------------------------------------------------------------------------------- /pics/fishing.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/iflytek/VLE/e322bb97b4ecd2c1eed11959416aed640a4bf76a/pics/fishing.png -------------------------------------------------------------------------------- /pics/model.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/iflytek/VLE/e322bb97b4ecd2c1eed11959416aed640a4bf76a/pics/model.png -------------------------------------------------------------------------------- /pics/pink_tongues.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/iflytek/VLE/e322bb97b4ecd2c1eed11959416aed640a4bf76a/pics/pink_tongues.png -------------------------------------------------------------------------------- /pics/qrcode.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/iflytek/VLE/e322bb97b4ecd2c1eed11959416aed640a4bf76a/pics/qrcode.jpg -------------------------------------------------------------------------------- /pics/truck.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/iflytek/VLE/e322bb97b4ecd2c1eed11959416aed640a4bf76a/pics/truck.png --------------------------------------------------------------------------------
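Usage sketch (illustration only, not a file in the repository): the snippet below exercises the VLEProcessor.__call__ contract documented in models/VLE/processing_vle.py, where text goes through the DebertaV2 tokenizer, images go through the CLIP image processor, and the two outputs are merged into a single BatchEncoding. The checkpoint id "hfl/vle-base" is an assumption about the published weights, and the import assumes the repository root is on PYTHONPATH.

import requests
from PIL import Image

from models.VLE.processing_vle import VLEProcessor

# Hypothetical checkpoint id; substitute the actual published processor/model name.
processor = VLEProcessor.from_pretrained("hfl/vle-base")

# Any RGB image works; this COCO validation image is used only for illustration.
url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)

# Passing both text and images returns input_ids, attention_mask and pixel_values
# in one BatchEncoding, as described in the __call__ docstring above.
inputs = processor(text="two cats lying on a couch", images=image, return_tensors="pt")
print({k: tuple(v.shape) for k, v in inputs.items()})

Per the __call__ implementation above, passing only text returns the tokenizer output, passing only images returns a BatchEncoding containing pixel_values, and passing both merges them as shown; the resulting tensors are the inputs expected by the VLE model classes in models/VLE/modeling_vle.py.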