├── .github
│   └── stale.yml
├── .gitignore
├── LICENSE
├── README.md
├── README_ZH.md
├── examples
│   ├── VCR
│   │   ├── README.md
│   │   └── vcr_sample_images
│   │       ├── .gitkeep
│   │       └── lsmdc_1054_Harry_Potter_and_the_prisoner_of_azkaban
│   │           ├── .gitkeep
│   │           ├── 1054_Harry_Potter_and_the_prisoner_of_azkaban_00.01.46.736-00.01.50.168@0.jpg
│   │           └── 1054_Harry_Potter_and_the_prisoner_of_azkaban_00.01.46.736-00.01.50.168@0.json
│   └── VQA
│       ├── README.md
│       ├── convert_checkpoint_after_ft.py
│       ├── run_vqav2_ft.py
│       ├── vqa_train_config.json
│       ├── vqav2_datamodule.py
│       ├── vqav2_train_module.py
│       ├── write_vqa.py
│       └── zero_to_fp32.py
├── models
│   └── VLE
│       ├── __init__.py
│       ├── configuration_vle.py
│       ├── modeling_vle.py
│       ├── pipeline_vle.py
│       └── processing_vle.py
└── pics
    ├── VQALLM_workflow.png
    ├── banner.png
    ├── birds.jpg
    ├── demo-banner.png
    ├── dogs.png
    ├── door.png
    ├── fishing.png
    ├── model.png
    ├── pink_tongues.png
    ├── qrcode.jpg
    └── truck.png
/.github/stale.yml:
--------------------------------------------------------------------------------
1 | # Number of days of inactivity before an issue becomes stale
2 | daysUntilStale: 4
3 | # Number of days of inactivity before a stale issue is closed
4 | daysUntilClose: 4
5 | # Issues with these labels will never be considered stale
6 | exemptLabels:
7 | - pinned
8 | - security
9 | # Label to use when marking an issue as stale
10 | staleLabel: stale
11 | # Comment to post when marking an issue as stale. Set to `false` to disable
12 | markComment: >
13 | This issue has been automatically marked as stale because it has not had
14 | recent activity. It will be closed if no further activity occurs. Thank you
15 | for your contributions.
16 | # Comment to post when closing a stale issue. Set to `false` to disable
17 | closeComment: >
18 | Closing the issue since no further updates have been observed.
19 | Feel free to re-open if you need any further assistance.
20 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | .DS_Store
2 | */.DS_Store
3 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 |
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner. For the purposes of this definition, "submitted"
54 | means any form of electronic, verbal, or written communication sent
55 | to the Licensor or its representatives, including but not limited to
56 | communication on electronic mailing lists, source code control systems,
57 | and issue tracking systems that are managed by, or on behalf of, the
58 | Licensor for the purpose of discussing and improving the Work, but
59 | excluding communication that is conspicuously marked or otherwise
60 | designated in writing by the copyright owner as "Not a Contribution."
61 |
62 | "Contributor" shall mean Licensor and any individual or Legal Entity
63 | on behalf of whom a Contribution has been received by Licensor and
64 | subsequently incorporated within the Work.
65 |
66 | 2. Grant of Copyright License. Subject to the terms and conditions of
67 | this License, each Contributor hereby grants to You a perpetual,
68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 | copyright license to reproduce, prepare Derivative Works of,
70 | publicly display, publicly perform, sublicense, and distribute the
71 | Work and such Derivative Works in Source or Object form.
72 |
73 | 3. Grant of Patent License. Subject to the terms and conditions of
74 | this License, each Contributor hereby grants to You a perpetual,
75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 | (except as stated in this section) patent license to make, have made,
77 | use, offer to sell, sell, import, and otherwise transfer the Work,
78 | where such license applies only to those patent claims licensable
79 | by such Contributor that are necessarily infringed by their
80 | Contribution(s) alone or by combination of their Contribution(s)
81 | with the Work to which such Contribution(s) was submitted. If You
82 | institute patent litigation against any entity (including a
83 | cross-claim or counterclaim in a lawsuit) alleging that the Work
84 | or a Contribution incorporated within the Work constitutes direct
85 | or contributory patent infringement, then any patent licenses
86 | granted to You under this License for that Work shall terminate
87 | as of the date such litigation is filed.
88 |
89 | 4. Redistribution. You may reproduce and distribute copies of the
90 | Work or Derivative Works thereof in any medium, with or without
91 | modifications, and in Source or Object form, provided that You
92 | meet the following conditions:
93 |
94 | (a) You must give any other recipients of the Work or
95 | Derivative Works a copy of this License; and
96 |
97 | (b) You must cause any modified files to carry prominent notices
98 | stating that You changed the files; and
99 |
100 | (c) You must retain, in the Source form of any Derivative Works
101 | that You distribute, all copyright, patent, trademark, and
102 | attribution notices from the Source form of the Work,
103 | excluding those notices that do not pertain to any part of
104 | the Derivative Works; and
105 |
106 | (d) If the Work includes a "NOTICE" text file as part of its
107 | distribution, then any Derivative Works that You distribute must
108 | include a readable copy of the attribution notices contained
109 | within such NOTICE file, excluding those notices that do not
110 | pertain to any part of the Derivative Works, in at least one
111 | of the following places: within a NOTICE text file distributed
112 | as part of the Derivative Works; within the Source form or
113 | documentation, if provided along with the Derivative Works; or,
114 | within a display generated by the Derivative Works, if and
115 | wherever such third-party notices normally appear. The contents
116 | of the NOTICE file are for informational purposes only and
117 | do not modify the License. You may add Your own attribution
118 | notices within Derivative Works that You distribute, alongside
119 | or as an addendum to the NOTICE text from the Work, provided
120 | that such additional attribution notices cannot be construed
121 | as modifying the License.
122 |
123 | You may add Your own copyright statement to Your modifications and
124 | may provide additional or different license terms and conditions
125 | for use, reproduction, or distribution of Your modifications, or
126 | for any such Derivative Works as a whole, provided Your use,
127 | reproduction, and distribution of the Work otherwise complies with
128 | the conditions stated in this License.
129 |
130 | 5. Submission of Contributions. Unless You explicitly state otherwise,
131 | any Contribution intentionally submitted for inclusion in the Work
132 | by You to the Licensor shall be under the terms and conditions of
133 | this License, without any additional terms or conditions.
134 | Notwithstanding the above, nothing herein shall supersede or modify
135 | the terms of any separate license agreement you may have executed
136 | with Licensor regarding such Contributions.
137 |
138 | 6. Trademarks. This License does not grant permission to use the trade
139 | names, trademarks, service marks, or product names of the Licensor,
140 | except as required for reasonable and customary use in describing the
141 | origin of the Work and reproducing the content of the NOTICE file.
142 |
143 | 7. Disclaimer of Warranty. Unless required by applicable law or
144 | agreed to in writing, Licensor provides the Work (and each
145 | Contributor provides its Contributions) on an "AS IS" BASIS,
146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 | implied, including, without limitation, any warranties or conditions
148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 | PARTICULAR PURPOSE. You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
178 | APPENDIX: How to apply the Apache License to your work.
179 |
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "[]"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 |
189 | Copyright 2023 HY_cog9
190 |
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 |
195 | http://www.apache.org/licenses/LICENSE-2.0
196 |
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
202 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | [**中文**](README_ZH.md) | [**English**](https://github.com/iflytek/VLE)
2 |
3 |
4 |
5 |
6 |
7 |
8 |
9 |
10 |
11 |
12 |
13 |
14 |
15 |
16 |
17 |
18 |
19 |
20 | # VLE: Vision-Language Encoder
21 |
22 | Multimodal pre-trained models are trained on massive multimodal data, and they can utilize information from different modalities and perform various cross-modal tasks.
23 |
24 | In this repository, we introduce **VLE** (**V**ision-**L**anguage **E**ncoder), an image-text multimodal understanding model built on pre-trained text and image encoders. It can be used for multimodal discriminative tasks such as visual question answering and image-text retrieval. In particular, on the visual commonsense reasoning (VCR) task, which requires high-level language understanding and reasoning skills, VLE achieves the best performance among publicly available methods.
25 |
26 | Recently, LLMs (Large Language Models) have achieved great success and have been used for a wide range of text tasks, including translation, question answering, text summarization, etc. While LLMs are unimodal, their abilities can be leveraged for multimodal understanding tasks. We propose a VQA+LLM pipeline that integrates multimodal models with LLMs for the visual question answering task. It helps the VQA model generate more accurate and fluent answers.
27 |
28 | We open-source VLE-related resources to promote academic research and better serve our community.
29 |
30 | **Try our VLE-based [VQA Demo](https://huggingface.co/spaces/hfl/VQA_VLE_LLM) at 🤗Space 👇👇👇**
31 |
32 |
33 |
34 | ----
35 |
36 | [Chinese LERT](https://github.com/ymcui/LERT) | [Chinese and English PERT](https://github.com/ymcui/PERT) | [Chinese MacBERT](https://github.com/ymcui/MacBERT) | [Chinese MiniRBT](https://github.com/iflytek/MiniRBT) | [Chinese ELECTRA](https://github.com/ymcui/Chinese-ELECTRA) | [Chinese XLNet](https://github.com/ymcui/Chinese-XLNet) | [Chinese BERT](https://github.com/ymcui/Chinese-BERT-wwm) | [Knowledge distillation tool TextBrewer](https://github.com/airaria/TextBrewer) | [Model pruning tool TextPruner](https://github.com/airaria/TextPruner)
37 |
38 | More resources released by HFL: https://github.com/iflytek/HFL-Anthology
39 |
40 | ## Table of Contents
41 |
42 | | Section | Description |
43 | | ----------------------------- | ----------------------------------- |
44 | | [Introduction](#introduction) | Introduction to VLE |
45 | | [Downloads](#downloads) | Download links for VLE |
46 | | [Comparison](#comparison) | Comparison of VLE with other models |
47 | | [VQA with LLM](#vqa-with-llm) | Visual question answering with LLM |
48 | | [Usage](#usage) | How to load VLE for different tasks |
49 |
50 | ## Introduction
51 |
52 | ### Structure
53 |
54 | The structure of VLE is similar to [METER](https://arxiv.org/abs/2111.02387): it consists of two unimodal encoders, one for text and one for images, followed by a crossmodal fusion module. However, there are several structural differences between VLE and METER:
55 |
56 | * VLE uses DeBERTa-v3 as the text encoder, which is stronger than RoBERTa-base used in METER.
57 | * In the large version of VLE (VLE-large), the hidden size of the crossmodal co-attention fusion module is scaled up to 1024 to increase model capacity.
58 | * During fine-tuning, VLE introduces additional token_type_embeddings.
59 |
60 | ### Pre-training
61 |
62 | VLE is pre-trained with image-caption pairs. There are four objectives applied during the pre-training stage:
63 | * **MLM** (Masked Language Modeling): Given an image-caption pair, we randomly mask some input text tokens, and the model is trained to reconstruct the original tokens.
64 | * **ITM** (Image-Text Matching): Given a batch of matched or mismatched image-caption pairs, the model needs to identify which images and captions correspond to each other.
65 | * **MPC** (Masked Patch-box Classification): Given an image-caption pair with some patches masked, the model needs to predict the classes of the objects in the masked patches.
66 | * **PBC** (Patch-box Classification): Given an image-caption pair, the model needs to identify which patches are related to the caption.
67 |
68 | VLE models are pre-trained on 14M public English image-caption pairs for 25k steps with a batch size of 2048.
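   | 
   | To make the objectives above more concrete, here is a minimal, illustrative PyTorch sketch of how the MLM and ITM losses might be computed and combined. This is not the project's actual pre-training code; the tensor shapes and the simple unweighted sum of losses are assumptions.
   | 
   | ```python
   | import torch.nn.functional as F
   | 
   | def mlm_itm_loss(mlm_logits, mlm_labels, itm_logits, itm_labels):
   |     """Illustrative combination of two of the four pre-training objectives.
   |     mlm_logits: (batch, seq_len, vocab_size); mlm_labels uses -100 at unmasked positions.
   |     itm_logits: (batch, 2); itm_labels is 1 for matched and 0 for mismatched pairs.
   |     """
   |     # MLM: cross-entropy over the masked token positions only
   |     mlm = F.cross_entropy(mlm_logits.view(-1, mlm_logits.size(-1)),
   |                           mlm_labels.view(-1), ignore_index=-100)
   |     # ITM: binary matched / mismatched classification per image-caption pair
   |     itm = F.cross_entropy(itm_logits, itm_labels)
   |     return mlm + itm
   | ```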
69 |
70 | The following figure illustrates the VLE structure and the pre-training objectives (for simplicity, we omit the PBC objective in the figure).
71 |
72 |
73 |
74 | ### Adaptation for downstream tasks
75 |
76 | #### Visual Question Answering (VQA)
77 |
78 | * We follow standard practice and train the models on VQA with both the training and validation data, and test the models on the test-dev set. The pooler output from the last layer of the fusion module is used for classification, as in the sketch below.
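   | 
   | A minimal sketch of such a classification head (the hidden size of 1024 and the 3,129 VQAv2 answer labels are illustrative assumptions, not necessarily the values used in the released checkpoints):
   | 
   | ```python
   | import torch.nn as nn
   | 
   | # Illustrative VQA head: maps the fusion module's pooler output to answer logits.
   | vqa_head = nn.Sequential(
   |     nn.Linear(1024, 2048),
   |     nn.LayerNorm(2048),
   |     nn.GELU(),
   |     nn.Linear(2048, 3129),
   | )
   | ```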
79 |
80 | #### Visual Commonsense Reasoning (VCR)
81 |
82 | * We format VCR as a multiple-choice task similar to RACE. For each object in the image of each example, we append the average of the patch embeddings that cover the object to the image feature embeddings before the fusion module (see the sketch below). We also assign token_type_ids to the objects in the image and text to improve alignment between the modalities.
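   | 
   | A minimal sketch of the patch-averaging step, assuming a square patch grid and a pixel-space bounding box (the image size, patch size, and exact pooling used in the released models may differ):
   | 
   | ```python
   | import torch
   | 
   | def object_patch_embedding(patch_embeds, box, image_size=384, patch_size=16):
   |     """Average the embeddings of all patches overlapped by an object box.
   |     patch_embeds: (num_patches, hidden), row-major over a (grid x grid) layout.
   |     box: (x1, y1, x2, y2) in pixels. Illustrative only.
   |     """
   |     grid = image_size // patch_size
   |     x1, y1, x2, y2 = [int(v) // patch_size for v in box]
   |     rows = range(max(y1, 0), min(y2, grid - 1) + 1)
   |     cols = range(max(x1, 0), min(x2, grid - 1) + 1)
   |     idx = [r * grid + c for r in rows for c in cols]
   |     # The pooled vector is appended to the image feature sequence before fusion.
   |     return patch_embeds[torch.tensor(idx)].mean(dim=0)
   | ```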
83 |
84 |
85 | ## Downloads
86 |
87 | The model weights are in PyTorch format and can be downloaded through the 🤗 transformers model hub. You can either download the weights and configurations manually or initialize a VLE model with the `from_pretrained(model_name)` method in your code. See [Usage](#usage) for details.
88 |
89 | ### Pre-trained Checkpoints
90 |
91 | | Model | Text Encoder | Image Encoder | # Params* | MODEL_NAME | Link |
92 | | --------- | ---------------- | ---------------------- | -------------------- | ------------- | -------------------------------------------- |
93 | | VLE-base | DeBERTa-v3-base | CLIP-ViT-base-patch16 | 378M | hfl/vle-base | [link](https://huggingface.co/hfl/vle-base) |
94 | | VLE-large | DeBERTa-v3-large | CLIP-ViT-large-patch14 | 930M | hfl/vle-large | [link](https://huggingface.co/hfl/vle-large) |
95 |
96 | * : We exclude task heads when counting the number of parameters.
97 |
98 | ### Fine-tuned Checkpoints
99 |
100 | | Model | Text Encoder | Image Encoder | MODEL_NAME | Link |
101 | | ---------------------- | ---------------- | ---------------------- | -------------------------- | --------------------------------------------------------- |
102 | | VLE-base-for-VQA | DeBERTa-v3-base | CLIP-ViT-base-patch16 | hfl/vle-base-for-vqa | [link](https://huggingface.co/hfl/vle-base-for-vqa) |
103 | | VLE-large-for-VQA | DeBERTa-v3-large | CLIP-ViT-large-patch14 | hfl/vle-large-for-vqa | [link](https://huggingface.co/hfl/vle-large-for-vqa) |
104 | | VLE-base-for-VCR-q2a | DeBERTa-v3-base | CLIP-ViT-base-patch16 | hfl/vle-base-for-vcr-q2a | [link](https://huggingface.co/hfl/vle-base-for-vcr-q2a) |
105 | | VLE-large-for-VCR-q2a | DeBERTa-v3-large | CLIP-ViT-large-patch14 | hfl/vle-large-for-vcr-q2a | [link](https://huggingface.co/hfl/vle-large-for-vcr-q2a) |
106 | | VLE-base-for-VCR-qa2r | DeBERTa-v3-base | CLIP-ViT-base-patch16 | hfl/vle-base-for-vcr-qa2r | [link](https://huggingface.co/hfl/vle-base-for-vcr-qa2r) |
107 | | VLE-large-for-VCR-qa2r | DeBERTa-v3-large | CLIP-ViT-large-patch14 | hfl/vle-large-for-vcr-qa2r | [link](https://huggingface.co/hfl/vle-large-for-vcr-qa2r) |
108 |
109 | ## Comparison
110 |
111 | In the following table, we compare the performance of VLE with METER and other multimodal models. The VQA results are on the test-dev set, and the VCR results are on the dev set.
112 |
113 | | Model | VQA | VCR (QA2R) | VCR (Q2A) | #Params | #PT data* |
114 | | ------------------- | ---------------- | -------------- | ------------- | ------------ | ------- |
115 | | CoCa | 82.3 | - | - | 2.1 B | unknown |
116 | | BeiT-3 | 84.2 | - | - | 1.9 B | 21M(I-T) + 14M(I) + 160G(T) |
117 | | OFA | 82.0 | - | - | 930M | 20M(I-T) + 39M(I) + 140G(T) |
118 | | BLIP | 78.3 | - | - | 385M | ~130M(I-T) |
119 | | METER-base | 77.7 (76.8†‡) | 79.8§ | 77.6§ | 345M | 9M(I-T) |
120 | | METER-Huge | 80.3 | - | - | 878M | 20M(I-T) |
121 | | VLE-base | 77.6‡ | 83.7§ | 79.9§ | 378M | 15M(I-T) |
122 | | VLE-large | 79.3‡ | 87.5§ | 84.3§ | 930M | 15M(I-T) |
123 |
124 | † : Result from our reimplementation.
125 |
126 | ‡ : Fine-tuning hyperparameters: lr=7e-6, batch_size={256, 512}, num_epochs=10
127 |
128 | § : Fine-tuning hyperparameters: lr=1e-5, batch_size=128, num_epochs=5
129 |
130 | * : Pre-training data. I-T: Image-caption pairs. I: Images. T: Text.
131 |
132 | From the above results, we can see that:
133 |
134 | * **VLE is pre-training efficient**. Compared to models of similar size, VLE achieves comparable or even better performance on VQA with much less pre-training data.
135 |
136 | * **VLE shows stronger reasoning ability**. In particular, it significantly outperforms METER on Visual Commonsense Reasoning (VCR), which requires higher-level language understanding and reasoning skills than VQA.
137 |
138 | ## VQA with LLM
139 |
140 | ### Generating Accurate and Fluent VQA Answers
141 |
142 | LLMs have achieved great success on a wide range of text tasks, and their abilities can also be leveraged for multimodal understanding tasks. Specifically, we present a VQA+LLM pipeline that integrates multimodal models with LLMs for the visual question answering task, which helps the VQA model generate more accurate and fluent answers.
143 |
144 | The workflows are shown in the figure below.
145 |
146 |
147 |
148 | (a) VQA: This is the standard way to perform the VQA task with a discriminative model. The question and the image are fed into the multimodal model, and the model is trained to predict the correct answer labels.
149 |
150 | (b) VQA + LLM: A captioning model first generates a caption of the image. The caption, the question, and the answer candidates predicted by the VQA model are concatenated and fed to the LLM, which is asked to give the most reasonable answer (see the sketch below).
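   | 
   | A minimal sketch of how such a prompt might be assembled (the prompt template, the candidate format, and the LLM used in the demo are assumptions for illustration):
   | 
   | ```python
   | def build_vqa_llm_prompt(caption, question, vqa_answers):
   |     """vqa_answers: candidate answers with scores from the VQA model, given here as
   |     (answer, score) pairs; the actual pipeline output format may differ."""
   |     candidates = ", ".join(f"{answer} ({score:.2f})" for answer, score in vqa_answers)
   |     return (
   |         f"Image description: {caption}\n"
   |         f"Question: {question}\n"
   |         f"Candidate answers with confidence scores: {candidates}\n"
   |         "Based on the description and the candidates, give the most reasonable "
   |         "answer to the question in a complete sentence."
   |     )
   | ```
   | 
   | The assembled prompt is then sent to the LLM, and its reply is used as the final answer.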
151 |
152 | We find that VQA+LLM not only makes more accurate predictions, but also generates more fluent and readable answers. Some examples are listed below:
153 |
154 |
155 |
156 |
157 |
158 |
159 |
160 | The demo is available at: https://huggingface.co/spaces/hfl/VQA_VLE_LLM
161 |
162 |
163 | ## Usage
164 |
165 | **Requirements**
166 |
167 | * PIL
168 | * Transformers >= 4.25
169 | * PyTorch Lightning (only required for running fine-tuning scripts)
170 |
171 | The model classes and utilities are defined in the `*.py` files in [models/VLE](models/VLE). To import VLE into your code, just copy the [models](models) directory into your project.
172 |
173 | To run the following demo code, `git clone` the repository and `cd` into it, ensuring you are in the repository's root directory.
174 |
175 | ### Load the VLEModel
176 |
177 | ```python
178 | from models.VLE import VLEModel, VLEProcessor
179 | from PIL import Image
180 | import torch
181 |
182 | model_name="hfl/vle-large"
183 | images = [Image.open('pics/dogs.png')]
184 | text = ["There are dogs on the grass."]
185 |
186 | model = VLEModel.from_pretrained(model_name)
187 | vle_processor = VLEProcessor.from_pretrained(model_name)
188 | multimodal_inputs = vle_processor(text=text,images=images, return_tensors='pt',padding=True)
189 |
190 | #forward
191 | vle_output = model(**multimodal_inputs)
192 | ```
193 |
194 | ### Inference
195 |
196 | #### Visual Question Answering (VQA)
197 |
198 | ```python
199 | from models.VLE import VLEForVQA, VLEProcessor, VLEForVQAPipeline
200 | from PIL import Image
201 |
202 | model_name="hfl/vle-base-for-vqa"
203 | text= "What is the color of the floor?"
204 | image = Image.open("pics/door.png")
205 |
206 | model = VLEForVQA.from_pretrained(model_name)
207 | vle_processor = VLEProcessor.from_pretrained(model_name)
208 | vqa_pipeline = VLEForVQAPipeline(model=model, device='cpu', vle_processor=vle_processor)
209 |
210 | vqa_answers = vqa_pipeline(image=image, question=text, top_k=5)
211 | print(f"Question: {text}. Answers: {vqa_answers}")
212 | ```
213 |
214 | #### Image-Text Matching
215 |
216 | ```python
217 | from models.VLE import VLEForITM, VLEProcessor, VLEForITMPipeline
218 | from PIL import Image
219 |
220 | model_dir = 'hfl/vle-base'
221 | itm_text = ["a photo of a cat.", "a photo of dogs."]
222 | itm_images = Image.open("pics/dogs.png")
223 |
224 | print("Init ITM model")
225 | model = VLEForITM.from_pretrained(model_dir)
226 | vle_processor = VLEProcessor.from_pretrained(model_dir)
227 |
228 | print("init ITM pipeline")
229 | itm_pipeline = VLEForITMPipeline(model=model, device='cpu', vle_processor=vle_processor)
230 | itm_pred = itm_pipeline([{"image": itm_images, "text": itm_text[0]},
231 | {"image": itm_images, "text": itm_text[1]}])
232 |
233 | for t, pred in zip(itm_text,itm_pred):
234 | print(t,pred)
235 | ```
236 |
237 | #### Patch Box Classification
238 |
239 | ```python
240 | from models.VLE import VLEForPBC, VLEProcessor, VLEForPBCPipeline
241 | from PIL import Image
242 |
243 | model_dir = 'hfl/vle-base'
244 | pbc_text = "pink tongues"
245 | pbc_image = Image.open("pics/dogs.png")
246 |
247 | print("Init PBC model")
248 | model = VLEForPBC.from_pretrained(model_dir)
249 | vle_processor = VLEProcessor.from_pretrained(model_dir)
250 |
251 | print("init PBC pipeline")
252 | pbc_pipeline = VLEForPBCPipeline(model=model, device='cpu', vle_processor=vle_processor)
253 | pbc_pred = pbc_pipeline(image=pbc_image,text=pbc_text)
254 | print(pbc_text)
255 | pbc_pred['image'].save('pics/pink_tongues.png')
256 | ```
257 |
258 |
259 | #### Visual Commonsense Reasoning (VCR)
260 |
261 | Please follow the instructions in [examples/VCR/README.md](examples/VCR/README.md)
262 |
263 | ### Fine-tuning
264 |
265 | #### Fine-tuning on VQA
266 |
267 | Please follow the instructions in [examples/VQA/README.md](examples/VQA/README.md)
268 |
269 | ## Follow us
270 |
271 | You are welcome to follow the official WeChat account of HFL to keep up with the latest technical developments.
272 |
273 | 
274 |
275 | ## Disclaimer
276 | This repository's resources are solely intended for academic purposes, and we assume no responsibility for any unforeseen damages or losses that may result from their use.
277 |
278 | This is not an official product by iFLYTEK Co., Ltd.
279 |
280 | ## Issues
281 |
282 | If you have questions, please submit them in a GitHub Issue.
283 |
284 | - Before submitting an issue, please check whether the FAQ answers your question, and search existing issues to see whether your problem has already been solved.
285 | - Duplicate and unrelated issues will be handled by the [stale bot](https://github.com/marketplace/stale).
286 | - We will try our best to answer your questions, but we cannot guarantee that every question will be answered.
287 | - Please ask questions politely and help build a friendly discussion community.
288 |
--------------------------------------------------------------------------------
/README_ZH.md:
--------------------------------------------------------------------------------
1 | [**中文**](README_ZH.md) | [**English**](https://github.com/iflytek/VLE)
2 |
3 |
4 |
5 |
6 |
7 |
8 |
9 |
10 |
11 |
12 |
13 |
14 |
15 |
16 |
17 |
18 | # VLE:视觉-语言多模态预训练模型
19 |
20 | 多模态预训练模型通过在多种模态的大规模数据上的预训练,可以综合利用来自不同模态的信息,执行各种跨模态任务。在本项目中,我们推出了**VLE** (**V**ision-**L**anguage **E**ncoder),一种基于预训练文本和图像编码器的图像-文本多模态理解模型,可应用于如视觉问答、图像-文本检索等多模态判别任务。特别地,在对语言理解和推理能力有更强要求的视觉常识推理(VCR)任务中,VLE取得了公开模型中的最佳效果。
21 |
22 | 最近,大型语言模型(LLM)取得了巨大成功,并被用于翻译、问答、摘要等文本任务。虽然LLM是单模态模型,但它们的能力也可用于辅助多模态理解任务。借助LLM的zero-shot能力,我们设计了一种VQA+LLM方案,将大型语言模型集成到视觉问答任务中,实现了帮助视觉问答模型生成更准确和流畅的答案。
23 |
24 | 我们开源VLE相关资源以供学术研究参考。
25 |
26 | 在线演示地址:https://huggingface.co/spaces/hfl/VQA_VLE_LLM
27 |
28 | ----
29 |
30 | [中文LERT](https://github.com/ymcui/LERT) | [中英文PERT](https://github.com/ymcui/PERT) | [中文MacBERT](https://github.com/ymcui/MacBERT) | [中文MiniRBT](https://github.com/iflytek/MiniRBT) | [中文ELECTRA](https://github.com/ymcui/Chinese-ELECTRA) | [中文XLNet](https://github.com/ymcui/Chinese-XLNet) | [中文BERT](https://github.com/ymcui/Chinese-BERT-wwm) | [知识蒸馏工具TextBrewer](https://github.com/airaria/TextBrewer) | [模型裁剪工具TextPruner](https://github.com/airaria/TextPruner)
31 |
32 | 查看更多哈工大讯飞联合实验室(HFL)发布的资源:https://github.com/iflytek/HFL-Anthology
33 |
34 | ## 内容导引
35 |
36 | | 章节 | 描述 |
37 | | ----------------------------- | ----------------------------------------------------------- |
38 | | [简介](#简介) | VLE的结构和训练技术 |
39 | | [模型下载](#模型下载) | VLE预训练模型下载地址 |
40 | | [模型对比](#模型对比) | 多模态任务上VLE与其他模型的比较 |
41 | | [结合大模型的视觉问答](#结合大模型的视觉问答) | 结合大模型的视觉问答策略 |
42 | | [模型使用](#模型使用) | 模型的加载与使用方法 |
43 |
44 | ## 简介
45 |
46 | ### 模型结构
47 |
48 | VLE模型采用双流结构,与METER模型结构类似,由两个单模态编码器(图像编码器和文本编码器)和一个跨模态融合模块构成。VLE与METER在结构上的差异在于:
49 |
50 | * VLE使用DeBERTa-v3作为文本编码器,其性能优于METER中使用的RoBERTa-base。
51 | * 在VLE-large中,跨模态融合模块的隐层维度增加至1024,以增加模型的容量。
52 | * 在精调阶段,VLE引入了额外的token类型向量表示。
53 |
54 | ### 预训练
55 |
56 | VLE使用图文对数据进行预训练。在预训练阶段,VLE采用了四个预训练任务:
57 |
58 | * **MLM** (Masked Language Modeling):掩码预测任务。给定图文对,随机遮掩文本中的部分单词,训练模型还原遮掩的文本。
59 | * **ITM** (Image-Text Matching):图文匹配预测任务。给定图文对,训练模型判断图像和文本是否匹配。
60 | * **MPC** (Masked Patch-box Classification):遮掩Patch分类任务,给定图文对,并遮掩掉图片中包含具体对象的patch,训练模型预测被遮掩的对象种类。
61 | * **PBC** (Patch-box classification):Patch分类任务。给定图文对,预测图片中的哪些patch与文本描述相关。
62 |
63 | VLE在14M的英文图文对数据上进行了25000步的预训练,batch大小为2048。下图展示了VLE的模型结构和部分预训练任务(MLM、ITM和MPC)。
64 |
65 |
66 |
67 |
68 | ### 下游任务适配
69 |
70 | #### 视觉问答 (VQA)
71 |
72 | * 我们遵循标准做法,使用VQA的训练集(training set)和验证集(validation set)训练模型,在test-dev集上进行验证。我们采用模型的融合层的pooler的输出进行分类任务的训练。
73 |
74 | #### 视觉常识推理 (VCR)
75 |
76 | * 我们将VCR格式化为一个类似于RACE的选择题任务,并对于每张图像中的对象,将覆盖该对象的patch的表示的平均池化值添加到融合模块之前的图像特征序列中。我们还为图像和文本中的对象添加额外的token_type_ids,以注入不同模态之间的对齐信息,提升模型的对齐性能。
77 |
78 |
79 | ## 模型下载
80 |
81 | 本次发布了VLE-base和VLE-large两个版本的预训练模型,模型权重为PyTorch格式,可以选择手动从🤗 transformers模型库下载权重和配置文件,或者在代码中使用 `from_pretrained(model_name)` 以自动加载模型。详细方法参见[模型使用](#模型使用)。
82 |
83 | ### 预训练权重
84 |
85 | | 模型 | 文本编码器 | 图像编码器 | 参数量* | MODEL_NAME | 链接 |
86 | | --------- | ---------------- | ---------------------- | ------------------ | ------------- | -------------------------------------------- |
87 | | VLE-base | DeBERTa-v3-base | CLIP-ViT-base-patch16 | 378M | hfl/vle-base | [link](https://huggingface.co/hfl/vle-base) |
88 | | VLE-large | DeBERTa-v3-large | CLIP-ViT-large-patch14 | 930M | hfl/vle-large | [link](https://huggingface.co/hfl/vle-large) |
89 |
90 | * : 仅计算encoder和embeddings的参数。特定任务的预测层的参数量未计入。
91 |
92 | ### 精调权重
93 |
94 | | 模型 | 文本编码器 | 图像编码器 | MODEL_NAME | 链接 |
95 | | ---------------------- | ---------------- | ---------------------- | -------------------------- | --------------------------------------------------------- |
96 | | VLE-base-for-VQA | DeBERTa-v3-base | CLIP-ViT-base-patch16 | hfl/vle-base-for-vqa | [link](https://huggingface.co/hfl/vle-base-for-vqa) |
97 | | VLE-large-for-VQA | DeBERTa-v3-large | CLIP-ViT-large-patch14 | hfl/vle-large-for-vqa | [link](https://huggingface.co/hfl/vle-large-for-vqa) |
98 | | VLE-base-for-VCR-q2a | DeBERTa-v3-base | CLIP-ViT-base-patch16 | hfl/vle-base-for-vcr-q2a | [link](https://huggingface.co/hfl/vle-base-for-vcr-q2a) |
99 | | VLE-large-for-VCR-q2a | DeBERTa-v3-large | CLIP-ViT-large-patch14 | hfl/vle-large-for-vcr-q2a | [link](https://huggingface.co/hfl/vle-large-for-vcr-q2a) |
100 | | VLE-base-for-VCR-qa2r | DeBERTa-v3-base | CLIP-ViT-base-patch16 | hfl/vle-base-for-vcr-qa2r | [link](https://huggingface.co/hfl/vle-base-for-vcr-qa2r) |
101 | | VLE-large-for-VCR-qa2r | DeBERTa-v3-large | CLIP-ViT-large-patch14 | hfl/vle-large-for-vcr-qa2r | [link](https://huggingface.co/hfl/vle-large-for-vcr-qa2r) |
102 |
103 | ## 模型对比
104 |
105 | 在下表中,我们比较了VLE、METER以及其他多模态模型的参数量、预训练数据和下游任务效果。其中VQA展示的是test-dev集上的效果;VCR展示的是dev集上的效果。
106 |
107 | | 模型 | VQA | VCR (QA2R) | VCR (Q2A) | 参数量 | 预训练数据量* |
108 | | ------------------- | ---------------- | -------------- | ------------- | ------------ | ------- |
109 | | CoCa | 82.3 | - | - | 2.1 B | 未知 |
110 | | BeiT-3 | 84.2 | - | - | 1.9 B | 21M(I-T) + 14M(I) + 160G(T) |
111 | | OFA | 82.0 | - | - | 930M | 20M(I-T) + 39M(I) + 140G(T) |
112 | | BLIP | 78.3 | - | - | 385M | ~130M(I-T) |
113 | | METER-base | 77.7 (76.8†‡) | 79.8§ | 77.6§ | 345M | 9M(I-T) |
114 | | METER-Huge | 80.3 | - | - | 878M | 20M(I-T) |
115 | | VLE-base | 77.6‡ | 83.7§ | 79.9§ | 378M | 15M(I-T) |
116 | | VLE-large | 79.3‡ | 87.5§ | 84.3§ | 930M | 15M(I-T) |
117 |
118 | † : 复现效果
119 |
120 | ‡ : 精调参数: lr=7e-6, batch_size={256, 512}, num_epochs=10
121 |
122 | § : 精调参数: lr=1e-5, batch_size=128, num_epochs=5
123 |
124 | * : I-T: 图文对. I: 图像. T: 文本.
125 |
126 | 观察上表可以发现:
127 |
128 | * **VLE的预训练更高效**:与大小相近的模型相比,VLE使用了更少的预训练数据,并在视觉问答上取得了相当甚至更好的效果。
129 | * **VLE有更强的推理能力**: 特别地,在对推理能力要求更高的视觉常识推理(VCR)任务上,VLE显著地超过了具有相似结构的METER。
130 |
131 | ## 结合大模型的视觉问答
132 |
133 | 最近,随着指令微调、RLHF等技术的发展,LLM在多种文本任务中取得了巨大的成功。尽管LLM是单模态模型,但它们的能力也可用于辅助多模态理解任务。具体而言,我们提出一种VQA + LLM方案,将多模态模型与LLM集成到视觉问答任务中,从而帮助VQA模型生成更准确和流畅的答案。下图展示了系统流程。
134 |
135 | 
136 |
137 | (a) VQA: 这是使用判别模型执行VQA任务的标准方式。输入问题和图像到多模态模型中,训练模型预测正确的答案标签。
138 |
139 | (b) VQA + LLM: 首先利用captioning模型生成图片的描述;将图片描述、问题以及VQA模型的详细预测结果拼接,组合成合适的prompt的形式送入LLM,最后要求LLM模型回复最合理的答案。
140 |
141 | VQA+LLM生成的答案更准确,也有更高的可读性。下面是一些例子:
142 |
143 | 
144 |
145 |
146 |
147 | 
148 |
149 | Demo地址(仅供学术研究):https://huggingface.co/spaces/hfl/VQA_VLE_LLM
150 |
151 |
152 |
153 |
154 | ## 模型使用
155 |
156 | **环境要求**
157 |
158 | * PIL
159 | * Transformers >= 4.25
160 | * PyTorch Lightning (仅用于运行精调脚本)
161 |
162 | 模型相关代码位于[models/VLE](models/VLE)目录下的`*.py`文件中。因此,要使用VLE模型,仅需把[models](models)目录复制到你的项目代码目录即可。
163 |
164 | 要运行以下演示代码,请使用`git clone`命令下载本仓库至本地,并进入仓库的根目录。
165 |
166 | ### 加载VLEModel
167 |
168 | ```python
169 | from models.VLE import VLEModel, VLEProcessor
170 | from PIL import Image
171 | import torch
172 |
173 | model_name="hfl/vle-large"
174 | images = [Image.open('pics/dogs.png')]
175 | text = ["There are dogs on the grass."]
176 |
177 | model = VLEModel.from_pretrained(model_name)
178 | vle_processor = VLEProcessor.from_pretrained(model_name)
179 | multimodal_inputs = vle_processor(text=text,images=images, return_tensors='pt',padding=True)
180 |
181 | #forward
182 | vle_output = model(**multimodal_inputs)
183 | ```
184 |
185 | ### 推理
186 |
187 | #### 视觉问答 (VQA)
188 |
189 | ```python
190 | from models.VLE import VLEForVQA, VLEProcessor, VLEForVQAPipeline
191 | from PIL import Image
192 |
193 | model_name="hfl/vle-base-for-vqa"
194 | text= "What is the color of the floor?"
195 | image = Image.open("pics/door.png")
196 |
197 | model = VLEForVQA.from_pretrained(model_name)
198 | vle_processor = VLEProcessor.from_pretrained(model_name)
199 | vqa_pipeline = VLEForVQAPipeline(model=model, device='cpu', vle_processor=vle_processor)
200 |
201 | vqa_answers = vqa_pipeline(image=image, question=text, top_k=5)
202 | print(f"Question: {text}. Answers: {vqa_answers}")
203 | ```
204 |
205 | #### 图文匹配(ITM)
206 |
207 | ```python
208 | from models.VLE import VLEForITM, VLEProcessor, VLEForITMPipeline
209 | from PIL import Image
210 |
211 | model_dir = 'hfl/vle-base'
212 | itm_text = ["a photo of a cat.", "a photo of dogs."]
213 | itm_images = Image.open("pics/dogs.png")
214 |
215 | print("Init ITM model")
216 | model = VLEForITM.from_pretrained(model_dir)
217 | vle_processor = VLEProcessor.from_pretrained(model_dir)
218 |
219 | print("init ITM pipeline")
220 | itm_pipeline = VLEForITMPipeline(model=model, device='cpu', vle_processor=vle_processor)
221 | itm_pred = itm_pipeline([{"image": itm_images, "text": itm_text[0]},
222 | {"image": itm_images, "text": itm_text[1]}])
223 |
224 | for t, pred in zip(itm_text,itm_pred):
225 | print(t,pred)
226 | ```
227 |
228 | #### Patch分类(PBC)
229 |
230 | ```python
231 | from models.VLE import VLEForPBC, VLEProcessor, VLEForPBCPipeline
232 | from PIL import Image
233 |
234 | model_dir = 'hfl/vle-base'
235 | pbc_text = "pink tongues"
236 | pbc_image = Image.open("pics/dogs.png")
237 |
238 | print("Init PBC model")
239 | model = VLEForPBC.from_pretrained(model_dir)
240 | vle_processor = VLEProcessor.from_pretrained(model_dir)
241 |
242 | print("init PBC pipeline")
243 | pbc_pipeline = VLEForPBCPipeline(model=model, device='cpu', vle_processor=vle_processor)
244 | pbc_pred = pbc_pipeline(image=pbc_image,text=pbc_text)
245 | print(pbc_text)
246 | pbc_pred['image'].save('pics/pink_tongues.png')
247 | ```
248 |
249 |
250 | #### 视觉常识推理(VCR)
251 |
252 | 详细步骤参见 [examples/VCR/README.md](examples/VCR/README.md)
253 |
254 | ### 精调
255 |
256 | #### VQA任务精调
257 |
258 | 详细步骤参见 [examples/VQA/README.md](examples/VQA/README.md)
259 |
260 | ## 关注我们
261 | 欢迎关注哈工大讯飞联合实验室官方微信公众号,了解最新的技术动态。
262 |
263 | 
264 |
265 | ## 免责声明
266 | 本项目相关资源应仅用于学术研究。我们不对使用相关资源产生的任何损失负责。
267 |
268 | 本目录相关内容不属于科大讯飞官方产品。
269 |
270 | ## 问题反馈
271 | 如有问题,请在GitHub Issue中提交。
272 |
273 | - 在提交问题之前,请先查看FAQ能否解决问题,同时建议查阅以往的issue是否能解决你的问题。
274 | - 重复以及与本项目无关的issue会被[stale bot](https://github.com/marketplace/stale)处理,敬请谅解。
275 | - 我们会尽可能的解答你的问题,但无法保证你的问题一定会被解答。
276 | - 礼貌地提出问题,构建和谐的讨论社区。
277 |
--------------------------------------------------------------------------------
/examples/VCR/README.md:
--------------------------------------------------------------------------------
1 | # Inference on Visual Commonsense Reasoning (VCR)
2 |
3 | ## Dataset Preparation for VCR
4 |
5 | Download the VCR dataset from [VCR official site](https://visualcommonsense.com/download/), including [Annotations](https://s3.us-west-2.amazonaws.com/ai2-rowanz/vcr1annots.zip) and [Images](https://s3.us-west-2.amazonaws.com/ai2-rowanz/vcr1images.zip).
6 | Unzip the downloaded files.
7 |
8 | ## Inference with VCR pipeline
9 |
10 | Here are two examples of using the fine-tuned VCR Q2A and QA2R models to run inference on a VCR sample.
11 | The sample image is placed in `vcr_sample_images/` as shown below, and the annotation `meta_data` is taken from the VCR validation set.
12 |
13 | vcr_sample_images
14 | └── lsmdc_1054_Harry_Potter_and_the_prisoner_of_azkaban
15 |     ├── 1054_Harry_Potter_and_the_prisoner_of_azkaban_00.01.46.736-00.01.50.168@0.jpg
16 |     └── 1054_Harry_Potter_and_the_prisoner_of_azkaban_00.01.46.736-00.01.50.168@0.json
17 |
18 | ### VCR Q2A
19 | ```python
20 | from models.VLE import VLEForVCRQ2A, VLEProcessor, VLEForVCRQ2APipeline
21 |
22 | model_name = 'hfl/vle-large-for-vcr-q2a'
23 | model = VLEForVCRQ2A.from_pretrained(model_name)
24 | vle_processor = VLEProcessor.from_pretrained(model_name)
25 | vcr_q2a_pipeline = VLEForVCRQ2APipeline(model=model, device='cpu', vle_processor=vle_processor)
26 |
27 | vcr_image_root = 'examples/VCR/vcr_sample_images'
28 | meta_data = {"movie": "1054_Harry_Potter_and_the_prisoner_of_azkaban", "objects": ["person", "person", "person", "car", "cellphone", "clock"], "interesting_scores": [-1, 0], "answer_likelihood": "possible", "img_fn": "lsmdc_1054_Harry_Potter_and_the_prisoner_of_azkaban/1054_Harry_Potter_and_the_prisoner_of_azkaban_00.01.46.736-00.01.50.168@0.jpg", "metadata_fn": "lsmdc_1054_Harry_Potter_and_the_prisoner_of_azkaban/1054_Harry_Potter_and_the_prisoner_of_azkaban_00.01.46.736-00.01.50.168@0.json", "answer_orig": "No, 1 is a visitor.", "question_orig": "Does 1 live in this house?", "rationale_orig": "1 is wearing outerwear, holding an umbrella, and there is a car outside.", "question": ["Does", [0], "live", "in", "this", "house", "?"], "answer_match_iter": [2, 3, 0, 1], "answer_sources": [10104, 5332, 1, 16646], "answer_choices": [["No", ",", [0], "lives", "nowhere", "close", "."], ["Yes", ",", [0], "works", "there", "."], ["No", ",", [0], "is", "a", "visitor", "."], ["No", [1], "does", "not", "belong", "here", "."]], "answer_label": 2, "rationale_choices": [[[0], "is", "nicely", "dressed", "with", "a", "tie", ".", "people", "dress", "up", "when", "they", "visit", "someone", "else", "."], [[2], "sits", "comfortably", "in", "a", "chair", ",", "reading", "papers", ",", "while", "it", "seems", [0], "has", "just", "arrived", "and", "is", "settling", "in", "."], [[1], "is", "wearing", "a", "coat", "and", "muff", "and", "is", "sitting", "as", "if", "a", "visitor", "."], [[0], "is", "wearing", "outerwear", ",", "holding", "an", "umbrella", ",", "and", "there", "is", "a", "car", "outside", "."]], "rationale_sources": [26162, 12999, 6661, 1], "rationale_match_iter": [1, 3, 2, 0], "rationale_label": 3, "img_id": "val-0", "question_number": 1, "annot_id": "val-1", "match_fold": "val-0", "match_index": 1}
29 |
30 | vcr_outputs = vcr_q2a_pipeline(vcr_image_root=vcr_image_root, meta_inputs=meta_data)
31 | pred = vcr_outputs[0]["pred"]
32 | print(f'Q: {meta_data["question"]}')
33 | print(f'A1: {meta_data["answer_choices"][0]}')
34 | print(f'A2: {meta_data["answer_choices"][1]}')
35 | print(f'A3: {meta_data["answer_choices"][2]}')
36 | print(f'A4: {meta_data["answer_choices"][3]}')
37 | print(f'Label: {meta_data["answer_label"] + 1}')
38 | print(f'predict: {pred[0] + 1}')
39 | ```
40 |
41 | ### VCR QA2R
42 | ```python
43 | from models.VLE import VLEForVCRQA2R, VLEProcessor, VLEForVCRQA2RPipeline
44 |
45 | model_name = 'hfl/vle-large-for-vcr-qa2r'
46 | model = VLEForVCRQA2R.from_pretrained(model_name)
47 | vle_processor = VLEProcessor.from_pretrained(model_name)
48 | vcr_qa2r_pipeline = VLEForVCRQA2RPipeline(model=model, device='cpu', vle_processor=vle_processor)
49 |
50 | vcr_image_root = 'examples/VCR/vcr_sample_images'
51 | meta_data = {"movie": "1054_Harry_Potter_and_the_prisoner_of_azkaban", "objects": ["person", "person", "person", "car", "cellphone", "clock"], "interesting_scores": [-1, 0], "answer_likelihood": "possible", "img_fn": "lsmdc_1054_Harry_Potter_and_the_prisoner_of_azkaban/1054_Harry_Potter_and_the_prisoner_of_azkaban_00.01.46.736-00.01.50.168@0.jpg", "metadata_fn": "lsmdc_1054_Harry_Potter_and_the_prisoner_of_azkaban/1054_Harry_Potter_and_the_prisoner_of_azkaban_00.01.46.736-00.01.50.168@0.json", "answer_orig": "No, 1 is a visitor.", "question_orig": "Does 1 live in this house?", "rationale_orig": "1 is wearing outerwear, holding an umbrella, and there is a car outside.", "question": ["Does", [0], "live", "in", "this", "house", "?"], "answer_match_iter": [2, 3, 0, 1], "answer_sources": [10104, 5332, 1, 16646], "answer_choices": [["No", ",", [0], "lives", "nowhere", "close", "."], ["Yes", ",", [0], "works", "there", "."], ["No", ",", [0], "is", "a", "visitor", "."], ["No", [1], "does", "not", "belong", "here", "."]], "answer_label": 2, "rationale_choices": [[[0], "is", "nicely", "dressed", "with", "a", "tie", ".", "people", "dress", "up", "when", "they", "visit", "someone", "else", "."], [[2], "sits", "comfortably", "in", "a", "chair", ",", "reading", "papers", ",", "while", "it", "seems", [0], "has", "just", "arrived", "and", "is", "settling", "in", "."], [[1], "is", "wearing", "a", "coat", "and", "muff", "and", "is", "sitting", "as", "if", "a", "visitor", "."], [[0], "is", "wearing", "outerwear", ",", "holding", "an", "umbrella", ",", "and", "there", "is", "a", "car", "outside", "."]], "rationale_sources": [26162, 12999, 6661, 1], "rationale_match_iter": [1, 3, 2, 0], "rationale_label": 3, "img_id": "val-0", "question_number": 1, "annot_id": "val-1", "match_fold": "val-0", "match_index": 1}
52 |
53 | vcr_outputs = vcr_qa2r_pipeline(vcr_image_root=vcr_image_root, meta_inputs=meta_data)
54 | pred = vcr_outputs[0]["pred"]
55 | print(f'Q: {meta_data["question"]}')
56 | print(f'A: {meta_data["answer_choices"][meta_data["answer_label"]]}')
57 | print(f'R1: {meta_data["rationale_choices"][0]}')
58 | print(f'R2: {meta_data["rationale_choices"][1]}')
59 | print(f'R3: {meta_data["rationale_choices"][2]}')
60 | print(f'R4: {meta_data["rationale_choices"][3]}')
61 | print(f'Label: {meta_data["rationale_label"] + 1}')
62 | print(f'predict: {pred[0] + 1}')
63 | ```
64 |
--------------------------------------------------------------------------------
/examples/VCR/vcr_sample_images/.gitkeep:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/iflytek/VLE/e322bb97b4ecd2c1eed11959416aed640a4bf76a/examples/VCR/vcr_sample_images/.gitkeep
--------------------------------------------------------------------------------
/examples/VCR/vcr_sample_images/lsmdc_1054_Harry_Potter_and_the_prisoner_of_azkaban/.gitkeep:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/iflytek/VLE/e322bb97b4ecd2c1eed11959416aed640a4bf76a/examples/VCR/vcr_sample_images/lsmdc_1054_Harry_Potter_and_the_prisoner_of_azkaban/.gitkeep
--------------------------------------------------------------------------------
/examples/VCR/vcr_sample_images/lsmdc_1054_Harry_Potter_and_the_prisoner_of_azkaban/1054_Harry_Potter_and_the_prisoner_of_azkaban_00.01.46.736-00.01.50.168@0.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/iflytek/VLE/e322bb97b4ecd2c1eed11959416aed640a4bf76a/examples/VCR/vcr_sample_images/lsmdc_1054_Harry_Potter_and_the_prisoner_of_azkaban/1054_Harry_Potter_and_the_prisoner_of_azkaban_00.01.46.736-00.01.50.168@0.jpg
--------------------------------------------------------------------------------
/examples/VCR/vcr_sample_images/lsmdc_1054_Harry_Potter_and_the_prisoner_of_azkaban/1054_Harry_Potter_and_the_prisoner_of_azkaban_00.01.46.736-00.01.50.168@0.json:
--------------------------------------------------------------------------------
1 | {"boxes": [[955.7418212890625, 52.329559326171875, 1551.5677490234375, 789.0325927734375, 0.9993261098861694], [20.42578125, 32.5933837890625, 916.5200805664062, 787.1964721679688, 0.9991201758384705], [902.0066528320312, 111.65625, 1035.2879638671875, 701.32861328125, 0.9876241683959961], [1403.113037109375, 282.20465087890625, 1542.410888671875, 557.3927612304688, 0.7517433762550354], [785.5349731445312, 527.7738647460938, 841.3390502929688, 657.4290161132812, 0.8897306323051453], [366.7726745605469, 0.0, 487.1645812988281, 79.29119873046875, 0.9390438199043274]], "segms": [[[[1133, 64], [1126, 66], [1115, 69], [1112, 70], [1106, 73], [1100, 75], [1081, 79], [1074, 81], [1072, 82], [1070, 83], [1065, 88], [1061, 95], [1050, 105], [1044, 111], [1042, 114], [1040, 117], [1039, 120], [1038, 127], [1035, 136], [1031, 144], [1030, 152], [1029, 161], [1028, 169], [1028, 171], [1029, 172], [1031, 181], [1033, 192], [1034, 199], [1035, 207], [1035, 215], [1036, 224], [1039, 236], [1042, 247], [1045, 253], [1048, 256], [1048, 257], [1050, 259], [1063, 271], [1066, 275], [1066, 281], [1065, 290], [1064, 299], [1063, 302], [1056, 309], [1053, 311], [1043, 316], [1040, 318], [1028, 330], [1025, 334], [1022, 338], [1016, 347], [1013, 352], [1010, 359], [1004, 367], [1002, 371], [999, 377], [996, 386], [994, 395], [990, 407], [985, 416], [983, 421], [979, 433], [978, 438], [973, 467], [970, 479], [967, 486], [965, 495], [964, 500], [963, 506], [962, 512], [961, 521], [960, 532], [959, 548], [959, 624], [960, 643], [961, 654], [962, 664], [963, 672], [964, 678], [965, 683], [966, 687], [968, 692], [971, 697], [973, 703], [978, 720], [979, 722], [980, 724], [982, 727], [988, 733], [992, 738], [995, 742], [1000, 750], [1005, 755], [1011, 760], [1023, 769], [1026, 771], [1029, 773], [1031, 774], [1034, 775], [1037, 776], [1041, 777], [1046, 778], [1051, 779], [1066, 779], [1067, 780], [1081, 780], [1120, 779], [1153, 777], [1159, 777], [1171, 779], [1193, 778], [1203, 778], [1218, 780], [1227, 780], [1260, 780], [1268, 780], [1325, 778], [1394, 778], [1429, 775], [1458, 775], [1476, 778], [1497, 778], [1514, 780], [1522, 780], [1523, 779], [1532, 779], [1540, 778], [1542, 776], [1543, 774], [1544, 772], [1547, 763], [1548, 759], [1549, 754], [1550, 749], [1550, 587], [1549, 575], [1548, 565], [1546, 555], [1545, 550], [1543, 543], [1542, 540], [1541, 538], [1540, 536], [1536, 531], [1530, 519], [1527, 509], [1521, 497], [1517, 492], [1514, 486], [1512, 480], [1506, 462], [1504, 451], [1501, 442], [1498, 435], [1493, 426], [1490, 417], [1488, 410], [1484, 390], [1481, 381], [1480, 379], [1479, 377], [1473, 370], [1471, 366], [1467, 358], [1466, 353], [1463, 347], [1461, 344], [1459, 341], [1451, 333], [1449, 330], [1447, 325], [1445, 322], [1441, 318], [1441, 317], [1438, 314], [1427, 304], [1423, 297], [1417, 291], [1413, 289], [1409, 286], [1405, 282], [1402, 278], [1401, 275], [1399, 272], [1392, 265], [1388, 263], [1378, 253], [1376, 249], [1368, 241], [1360, 236], [1356, 233], [1351, 229], [1339, 217], [1336, 215], [1328, 211], [1323, 210], [1310, 208], [1301, 206], [1297, 205], [1293, 204], [1290, 203], [1287, 202], [1285, 201], [1283, 200], [1268, 186], [1260, 178], [1258, 175], [1257, 173], [1256, 171], [1255, 166], [1254, 158], [1254, 149], [1253, 141], [1252, 135], [1251, 130], [1250, 125], [1249, 121], [1248, 117], [1246, 114], [1233, 101], [1231, 98], [1228, 92], [1225, 88], [1221, 84], [1218, 83], [1203, 78], [1199, 77], [1191, 76], [1183, 74], [1176, 71], [1169, 70], [1159, 66], [1155, 
65], [1145, 64]]], [[[253, 38], [244, 39], [236, 40], [230, 41], [227, 42], [224, 43], [211, 50], [206, 52], [203, 53], [176, 61], [171, 63], [166, 65], [164, 66], [159, 70], [157, 72], [157, 73], [154, 77], [149, 83], [134, 97], [128, 103], [125, 107], [123, 110], [121, 114], [118, 120], [117, 123], [116, 126], [115, 133], [112, 143], [110, 148], [104, 163], [103, 165], [102, 167], [100, 169], [99, 171], [98, 173], [96, 178], [95, 181], [94, 186], [92, 199], [91, 208], [92, 218], [93, 228], [94, 236], [95, 243], [97, 254], [98, 259], [99, 262], [100, 265], [105, 275], [108, 284], [109, 288], [110, 292], [111, 297], [114, 315], [115, 323], [116, 341], [119, 362], [120, 368], [123, 374], [125, 377], [135, 386], [138, 389], [138, 390], [142, 394], [143, 396], [144, 398], [147, 406], [152, 423], [154, 427], [157, 433], [159, 436], [161, 439], [169, 447], [171, 450], [171, 452], [170, 467], [169, 477], [168, 479], [167, 481], [163, 486], [159, 494], [157, 499], [154, 508], [153, 512], [152, 516], [150, 526], [148, 541], [146, 550], [145, 554], [144, 558], [143, 560], [141, 563], [124, 580], [122, 583], [121, 585], [119, 591], [113, 611], [112, 613], [103, 622], [96, 626], [89, 633], [87, 636], [86, 638], [85, 640], [84, 645], [81, 652], [80, 654], [79, 656], [76, 661], [74, 664], [58, 680], [55, 684], [52, 689], [50, 694], [48, 697], [46, 700], [30, 716], [28, 719], [27, 721], [26, 728], [25, 736], [24, 745], [24, 748], [26, 755], [29, 764], [32, 771], [33, 773], [36, 776], [38, 777], [45, 779], [57, 782], [62, 783], [68, 784], [76, 784], [106, 783], [121, 782], [134, 781], [142, 780], [149, 779], [163, 776], [174, 776], [200, 777], [218, 779], [238, 780], [263, 781], [357, 782], [500, 782], [554, 782], [584, 781], [593, 780], [601, 779], [609, 778], [614, 777], [623, 774], [631, 770], [637, 768], [643, 767], [646, 766], [651, 764], [658, 761], [678, 751], [685, 745], [688, 743], [691, 741], [697, 738], [702, 736], [712, 733], [720, 729], [723, 727], [726, 725], [734, 717], [738, 714], [744, 711], [755, 707], [761, 704], [770, 699], [773, 697], [776, 695], [784, 687], [789, 681], [792, 677], [795, 673], [797, 670], [799, 667], [800, 665], [803, 658], [809, 645], [810, 642], [811, 639], [812, 633], [813, 626], [814, 618], [815, 606], [815, 581], [814, 574], [813, 568], [812, 562], [811, 558], [805, 552], [803, 551], [798, 550], [791, 549], [783, 548], [773, 547], [769, 549], [766, 551], [752, 565], [748, 568], [743, 571], [736, 574], [733, 576], [729, 579], [719, 589], [716, 591], [708, 595], [704, 596], [700, 597], [686, 600], [676, 601], [671, 602], [662, 604], [655, 606], [645, 610], [643, 610], [637, 609], [631, 607], [625, 606], [619, 606], [613, 605], [608, 603], [600, 599], [584, 591], [578, 587], [574, 583], [565, 573], [558, 567], [548, 560], [539, 551], [533, 546], [517, 535], [511, 529], [509, 525], [505, 517], [501, 505], [501, 488], [502, 474], [503, 470], [504, 466], [506, 459], [507, 456], [508, 453], [511, 447], [513, 444], [520, 436], [522, 433], [524, 429], [527, 423], [528, 419], [530, 409], [531, 400], [532, 395], [536, 383], [538, 378], [542, 370], [550, 363], [554, 359], [558, 354], [561, 349], [564, 343], [565, 338], [566, 326], [567, 320], [568, 314], [571, 302], [576, 289], [577, 285], [577, 279], [581, 262], [582, 255], [583, 246], [584, 237], [584, 234], [583, 231], [581, 226], [577, 219], [575, 214], [574, 211], [574, 208], [569, 196], [566, 180], [565, 177], [563, 172], [561, 167], [560, 165], [559, 163], [556, 158], [554, 155], [548, 149], [541, 144], [536, 139], 
[531, 133], [527, 126], [522, 121], [518, 118], [514, 116], [507, 113], [502, 110], [498, 107], [493, 103], [485, 95], [479, 91], [466, 83], [459, 78], [455, 74], [452, 72], [449, 70], [445, 68], [439, 65], [434, 63], [425, 60], [421, 59], [411, 58], [388, 57], [378, 56], [370, 55], [362, 54], [355, 53], [348, 52], [342, 51], [336, 50], [331, 49], [326, 48], [312, 44], [297, 41], [290, 40], [281, 39], [270, 38]]], [[[941, 111], [926, 112], [922, 113], [920, 114], [915, 119], [912, 124], [909, 129], [908, 131], [907, 134], [906, 137], [905, 141], [904, 150], [903, 171], [903, 201], [904, 221], [906, 236], [907, 251], [907, 367], [908, 390], [910, 406], [911, 423], [912, 446], [912, 494], [913, 515], [914, 528], [917, 550], [918, 562], [919, 577], [920, 596], [920, 617], [922, 638], [922, 676], [924, 692], [926, 693], [931, 694], [939, 694], [944, 693], [947, 692], [949, 691], [949, 678], [948, 668], [947, 664], [945, 660], [944, 656], [943, 652], [943, 636], [944, 624], [947, 612], [947, 583], [948, 576], [949, 569], [950, 564], [954, 554], [955, 550], [955, 546], [956, 542], [960, 531], [966, 507], [970, 481], [974, 468], [975, 459], [981, 436], [983, 429], [986, 423], [992, 402], [994, 394], [997, 382], [998, 377], [999, 366], [1002, 355], [1002, 332], [999, 316], [999, 309], [1002, 302], [1003, 295], [1004, 293], [1005, 291], [1009, 287], [1010, 284], [1012, 278], [1013, 274], [1013, 269], [1014, 267], [1016, 265], [1017, 263], [1019, 258], [1020, 253], [1021, 247], [1022, 238], [1023, 229], [1024, 217], [1025, 205], [1025, 174], [1024, 159], [1023, 154], [1022, 149], [1021, 145], [1020, 141], [1019, 139], [1018, 137], [1016, 134], [1009, 127], [1006, 126], [1004, 125], [1003, 124], [1003, 123], [1001, 120], [999, 118], [990, 114], [986, 113], [980, 112], [966, 111]]], [[[1437, 282], [1425, 283], [1422, 284], [1420, 285], [1418, 286], [1417, 287], [1416, 289], [1415, 292], [1415, 297], [1422, 304], [1423, 307], [1425, 309], [1428, 311], [1436, 317], [1439, 319], [1442, 320], [1445, 322], [1448, 325], [1448, 326], [1451, 329], [1454, 330], [1458, 334], [1460, 338], [1461, 339], [1466, 343], [1467, 345], [1469, 350], [1471, 355], [1472, 358], [1474, 366], [1477, 377], [1478, 382], [1479, 391], [1480, 395], [1482, 398], [1484, 405], [1484, 411], [1485, 416], [1488, 423], [1489, 430], [1493, 438], [1499, 458], [1502, 463], [1503, 467], [1503, 470], [1504, 474], [1509, 483], [1511, 488], [1512, 491], [1514, 501], [1515, 504], [1517, 507], [1517, 510], [1518, 513], [1523, 522], [1524, 528], [1525, 531], [1526, 533], [1534, 542], [1535, 545], [1540, 543], [1540, 539], [1541, 538], [1542, 525], [1542, 478], [1541, 417], [1539, 358], [1539, 350], [1540, 306], [1539, 299], [1538, 295], [1537, 293], [1535, 290], [1533, 290], [1532, 289], [1528, 289], [1527, 288], [1524, 288], [1520, 290], [1495, 290], [1478, 287], [1473, 285], [1466, 284], [1460, 283], [1450, 282]]], [[[807, 529], [805, 530], [803, 532], [800, 536], [799, 539], [798, 540], [797, 540], [795, 542], [794, 546], [793, 550], [793, 566], [792, 568], [791, 570], [786, 575], [786, 601], [787, 604], [791, 608], [793, 611], [794, 613], [795, 617], [796, 633], [796, 651], [797, 653], [799, 656], [801, 656], [802, 657], [818, 656], [830, 657], [833, 657], [834, 656], [839, 655], [840, 651], [840, 573], [839, 544], [838, 541], [836, 539], [835, 534], [833, 531], [831, 530], [826, 529]]], [[[387, 2], [379, 3], [375, 4], [371, 5], [370, 6], [368, 6], [368, 8], [367, 9], [367, 35], [368, 36], [368, 44], [372, 47], [377, 49], [382, 53], [391, 57], 
[396, 59], [399, 60], [402, 60], [411, 64], [429, 69], [436, 73], [448, 76], [452, 76], [455, 75], [459, 72], [461, 72], [465, 70], [466, 67], [470, 66], [472, 66], [474, 64], [474, 62], [477, 59], [481, 60], [483, 57], [484, 55], [485, 53], [486, 45], [486, 21], [485, 10], [482, 6], [481, 5], [479, 4], [474, 3], [464, 2]]]], "names": ["person", "person", "person", "car", "cellphone", "clock"], "width": 1920, "height": 797}
--------------------------------------------------------------------------------
/examples/VQA/README.md:
--------------------------------------------------------------------------------
1 | # Fine-tuning on VQA
2 |
3 | ## Requirements
4 |
5 | We use `Pytorch-Lightning` to fine-tune the pre-trained VLEModel on VQA. To speed up training, we use `DeepSpeed`. The main packages are as follows:
6 |
7 | ```bash
8 | pytorch_lightning==1.5.10
9 | transformers==4.26.0
10 | deepspeed==0.7.7
11 | Pillow==8.1.0
12 | tqdm==4.64.1
13 | ipdb==0.13.4
14 | numpy==1.21.6
15 | einops==0.3.0
16 | pyarrow==2.0.0
17 | sacred==0.8.2
18 | pandas==1.1.5
19 | timm==0.4.12
20 | ftfy
21 | torchvision~=0.8.2
22 | torch~=1.7.1
23 | ```
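
If you are starting from a clean environment, the pinned versions above can be installed with `pip`. This is only a sketch; choose the `torch` / `torchvision` builds that match your CUDA setup:

```bash
# Install the pinned training dependencies (adjust the torch builds to your CUDA version).
pip install pytorch_lightning==1.5.10 transformers==4.26.0 deepspeed==0.7.7 \
    Pillow==8.1.0 tqdm==4.64.1 ipdb==0.13.4 numpy==1.21.6 einops==0.3.0 \
    pyarrow==2.0.0 sacred==0.8.2 pandas==1.1.5 timm==0.4.12 ftfy \
    "torchvision~=0.8.2" "torch~=1.7.1"
```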
24 |
25 | ## Dataset Preparation for VQAv2
26 |
27 | Download the VQAv2 dataset from [VQA official site](https://visualqa.org/download.html), including COCO [2014 train images](http://images.cocodataset.org/zips/train2014.zip), [2014 val images](http://images.cocodataset.org/zips/val2014.zip), [2015 test images](http://images.cocodataset.org/zips/test2015.zip), annotations ([train](https://s3.amazonaws.com/cvmlp/vqa/mscoco/vqa/v2_Annotations_Train_mscoco.zip), [val](https://s3.amazonaws.com/cvmlp/vqa/mscoco/vqa/v2_Annotations_Val_mscoco.zip)), and questions ([train](https://s3.amazonaws.com/cvmlp/vqa/mscoco/vqa/v2_Questions_Train_mscoco.zip), [val](https://s3.amazonaws.com/cvmlp/vqa/mscoco/vqa/v2_Questions_Val_mscoco.zip), [test](https://s3.amazonaws.com/cvmlp/vqa/mscoco/vqa/v2_Questions_Test_mscoco.zip)).
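
For reference, a download sketch using the links above might look like the following (assuming `wget` and `unzip` are available, and that `root` is the dataset directory shown in the layout below):

```bash
cd root   # the dataset root directory described below
# COCO images
wget http://images.cocodataset.org/zips/train2014.zip
wget http://images.cocodataset.org/zips/val2014.zip
wget http://images.cocodataset.org/zips/test2015.zip
# VQAv2 annotations and questions
wget https://s3.amazonaws.com/cvmlp/vqa/mscoco/vqa/v2_Annotations_Train_mscoco.zip
wget https://s3.amazonaws.com/cvmlp/vqa/mscoco/vqa/v2_Annotations_Val_mscoco.zip
wget https://s3.amazonaws.com/cvmlp/vqa/mscoco/vqa/v2_Questions_Train_mscoco.zip
wget https://s3.amazonaws.com/cvmlp/vqa/mscoco/vqa/v2_Questions_Val_mscoco.zip
wget https://s3.amazonaws.com/cvmlp/vqa/mscoco/vqa/v2_Questions_Test_mscoco.zip
unzip '*.zip'
```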
28 |
29 | Please unzip and organize the dataset as follows:
30 |
31 | root
32 | ├── train2014
33 | │ ├── COCO_train2014_000000000009.jpg
34 | | └── ...
35 | ├── val2014
36 | | ├── COCO_val2014_000000000042.jpg
37 | | └── ...
38 | ├── test2015
39 | | ├── COCO_test2015_000000000001.jpg
40 | | └── ...
41 | ├── v2_OpenEnded_mscoco_train2014_questions.json
42 | ├── v2_OpenEnded_mscoco_val2014_questions.json
43 | ├── v2_OpenEnded_mscoco_test2015_questions.json
44 | ├── v2_OpenEnded_mscoco_test-dev2015_questions.json
45 | ├── v2_mscoco_train2014_annotations.json
46 | └── v2_mscoco_val2014_annotations.json
47 |
48 | We use `pyarrow` to serialize the datasets; the conversion script is `write_vqa.py`.
49 | Please replace the value of `id2label` in the VLE model's config with the mapping generated in `label2answer.json` (see `config.json` of `hfl/vle-base-for-vqa` for an example).
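
A minimal sketch of this config update is shown below. The paths are assumptions based on the defaults in `write_vqa.py`; point them at your own `label2answer.json` and the local copy of the model's `config.json`:

```python
import json

# label2answer.json is a list produced by write_vqa.py: index -> answer string.
with open("./data/vqav2/vqav2_arrow/label2answer.json", "r", encoding="utf8") as f:
    label2ans = json.load(f)

# config.json is the local copy of the VLE model config to be updated.
with open("config.json", "r", encoding="utf8") as f:
    config = json.load(f)

# Map each label id to its answer string so that predictions can be decoded.
config["id2label"] = {str(i): ans for i, ans in enumerate(label2ans)}

with open("config.json", "w", encoding="utf8") as f:
    json.dump(config, f, ensure_ascii=False, indent=2)
```

Compare the result against `config.json` of `hfl/vle-base-for-vqa` to make sure the format matches.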
50 |
51 | ## Fine-tuning VLE on VQAv2
52 |
53 | Hyperparameters for training are set in `vqa_train_config.json`.
54 |
55 | Move the training-related files to the same directory level as `models`, as follows:
56 |
57 | root
58 | ├── models
59 | │ └── VLE
60 | | └── ...
61 | ├── run_vqav2_ft.py
62 | ├── vqav2_datamodule.py
63 | └── vqav2_train_module.py
64 |
65 | Specify the config file with `--train_config_file` and run the training script `run_vqav2_ft.py`. Here is an example:
66 |
67 | ```bash
68 | export MASTER_ADDR=$DIST_0_IP
69 | export MASTER_PORT=$DIST_0_PORT
70 | export NODE_RANK=$DIST_RANK
71 | python run_vqav2_ft.py --train_config_file=vqa_train_config.json
72 | ```
73 |
74 | ## Postprocess the checkpoint
75 |
76 | After training, we convert the saved checkpoint so that it can be loaded by `VLEModel`.
77 |
78 | We first convert the checkpoint saved by DeepSpeed to a standard PyTorch checkpoint using the conversion script `zero_to_fp32.py`. If you did not use `DeepSpeed` when training the model, this step can be skipped.
79 |
80 | ```bash
81 | python zero_to_fp32.py
82 | # for example:
83 | python zero_to_fp32.py ./logs/VQAv2_seed0_from_vle-base-ft-vqa/version_0/checkpoints/epoch\=0-step\=0.ckpt step\=0.ckpt global_step0
84 | ```
85 |
86 | Then we rename the parameters to the format expected by `VLEModel` using the conversion script `convert_checkpoint_after_ft.py`.
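
For example (a sketch; first edit `pl_ckpt` and `output_dir` at the top of the script so they point at your converted checkpoint and the desired output directory):

```bash
python convert_checkpoint_after_ft.py
```

The resulting `pytorch_model.bin` can then be placed next to the updated `config.json` and loaded with `VLEForVQA.from_pretrained(<your_model_dir>)`.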
87 |
--------------------------------------------------------------------------------
/examples/VQA/convert_checkpoint_after_ft.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | pl_ckpt = "./step=0.ckpt"
4 | output_dir = './'
5 | hf_ckpt = {}
6 |
7 | loaded = torch.load(pl_ckpt,map_location='cpu')
8 | if 'state_dict' in loaded:
9 | sd = loaded['state_dict']
10 | else:
11 | sd = loaded
12 | for k,v in sd.items():
13 | if k.startswith('model.'):
14 | new_key = k.replace('model.', '')
15 | hf_ckpt[new_key] = v
16 | elif k.startswith('module.model.'):
17 | new_key = k.replace('module.model.', '')
18 | hf_ckpt[new_key] = v
19 | else:
20 | print("unhandled keys:",k)
21 |
22 | torch.save(hf_ckpt, output_dir + 'pytorch_model.bin')
--------------------------------------------------------------------------------
/examples/VQA/run_vqav2_ft.py:
--------------------------------------------------------------------------------
1 | import pytorch_lightning as pl
2 | from pytorch_lightning.plugins import DeepSpeedPlugin
3 | import torch
4 | import json
5 | import copy
6 | import os
7 | os.environ["NCCL_DEBUG"] = "INFO"
8 | import argparse
9 |
10 | from vqav2_datamodule import VQAv2DataModule
11 | from vqav2_train_module import VLEForVQA_PL
12 | from models.VLE.processing_vle import VLEProcessor
13 |
14 |
15 | def main(_config):
16 | _config = copy.deepcopy(_config)
17 | pl.seed_everything(_config["seed"], workers=True)
18 | vle_processor = VLEProcessor.from_pretrained(_config["model_dir"])
19 | if vle_processor.image_processor.size["shortest_edge"] != _config["image_size"]:
20 | vle_processor.image_processor.crop_size["height"] = _config["image_size"]
21 | vle_processor.image_processor.crop_size["width"] = _config["image_size"]
22 | vle_processor.image_processor.size["shortest_edge"] = _config["image_size"]
23 | dm = VQAv2DataModule(vle_processor, _config)
24 |
25 | model = VLEForVQA_PL(_config)
26 | exp_name = 'VQAv2'
27 |
28 | os.makedirs(_config["log_dir"], exist_ok=True)
29 | checkpoint_callback = pl.callbacks.ModelCheckpoint(
30 | save_top_k=_config["save_top_k"],
31 | verbose=True,
32 | monitor="val/the_metric",
33 | mode="max",
34 | save_last=True,
35 | save_weights_only=_config["save_weights_only"]
36 | )
37 | logger = pl.loggers.TensorBoardLogger(
38 | _config["log_dir"],
39 | name=f'{exp_name}_seed{_config["seed"]}_from_{_config["model_dir"].split("/")[-1]}',
40 | )
41 |
42 | lr_callback = pl.callbacks.LearningRateMonitor(logging_interval="step")
43 | callbacks = [checkpoint_callback, lr_callback]
44 |
45 | num_gpus = (
46 | _config["num_gpus"]
47 | if isinstance(_config["num_gpus"], int)
48 | else len(_config["num_gpus"])
49 | )
50 |
51 | grad_steps = max(_config["batch_size"] // (
52 | _config["per_gpu_batchsize"] * num_gpus * _config["num_nodes"]
53 | ), 1)
54 |
55 | max_steps = _config["max_steps"] if _config["max_steps"] is not None else -1
56 |
57 | if _config["use_deepspeed"]:
58 | deepspeed_config = _config["deepspeed_config"]
59 | if _config["precision"] != "16" and _config["precision"] != 16:
60 | deepspeed_config["fp16"]["enabled"] = False
61 | if _config["precision"] == "bf16":
62 | deepspeed_config["bf16"] = {"enabled": True}
63 | ds_plugin = DeepSpeedPlugin(config=deepspeed_config)
64 | strategy = ds_plugin
65 | else:
66 | strategy = "ddp"
67 |
68 | trainer = pl.Trainer(
69 | gpus=_config["num_gpus"],
70 | num_nodes=_config["num_nodes"],
71 | precision=_config["precision"],
72 | accelerator="gpu",
73 | strategy=strategy,
74 | benchmark=True,
75 | deterministic=False,
76 | max_epochs=_config["max_epoch"] if max_steps == -1 else 1000,
77 | max_steps=max_steps,
78 | callbacks=callbacks,
79 | logger=logger,
80 | prepare_data_per_node=True,
81 | replace_sampler_ddp=False,
82 | accumulate_grad_batches=grad_steps,
83 | log_every_n_steps=10,
84 | flush_logs_every_n_steps=10,
85 | resume_from_checkpoint=_config["resume_from"],
86 | weights_summary="top",
87 | fast_dev_run=_config["fast_dev_run"],
88 | val_check_interval=_config["val_check_interval"],
89 | num_sanity_val_steps=0,
90 | )
91 |
92 | if not _config["test_only"]:
93 | trainer.fit(model, datamodule=dm)
94 | torch.cuda.empty_cache()
95 | trainer.test(ckpt_path="best", datamodule=dm)
96 | else:
97 | trainer.test(model, datamodule=dm)
98 |
99 |
100 | if __name__ == '__main__':
101 | parser = argparse.ArgumentParser(description="Args for finetuning VLE on VQAv2.")
102 | parser.add_argument("--train_config_file", type=str, default="vqa_train_config.json", help="Config file for training.")
103 | args = parser.parse_args()
104 | train_config_file = args.train_config_file
105 | train_config = json.load(open(train_config_file, 'r'))
106 |
107 | main(train_config)
--------------------------------------------------------------------------------
/examples/VQA/vqa_train_config.json:
--------------------------------------------------------------------------------
1 | {
2 |
3 | "seed": 0,
4 | "model_dir": "hfl/vle-base",
5 | "data_root": "./data/vqav2/vqav2_arrow",
6 | "num_workers": 4,
7 | "batch_size": 16,
8 | "per_gpu_batchsize": 4,
9 | "image_size": 384,
10 | "max_text_len": 50,
11 | "draw_false_image": 0,
12 | "draw_false_text": 0,
13 | "image_only": false,
14 | "log_dir": "logs",
15 | "num_gpus": 1,
16 | "num_nodes": 1,
17 | "max_epoch": 10,
18 | "max_steps": -1,
19 | "precision": 16,
20 | "resume_from": "",
21 | "fast_dev_run": false,
22 | "val_check_interval": 1.0,
23 | "save_top_k": 1,
24 | "save_weights_only": true,
25 | "test_only": false,
26 | "learning_rate":1e-5,
27 | "weight_decay": 0.01,
28 | "lr_mult_head": 50,
29 | "lr_mult_cross_modal": 5,
30 | "end_lr": 0,
31 | "decay_power": 1,
32 | "optim_type": "adamw",
33 | "warmup_steps": 0.1,
34 | "use_deepspeed": true,
35 | "deepspeed_config":{
36 | "fp16": {
37 | "enabled": true,
38 | "initial_scale_power": 12,
39 | "min_loss_scale": 2e-10,
40 | "loss_scale_window": 128
41 | },
42 | "zero_optimization": {
43 | "stage": 2,
44 | "reduce_bucket_size": 5e7,
45 | "allgather_bucket_size": 1.25e9,
46 | "overlap_comm": true,
47 | "contiguous_gradients": true
48 | },
49 | "zero_allow_untested_optimizer": true
50 | }
51 | }
52 |
--------------------------------------------------------------------------------
/examples/VQA/vqav2_datamodule.py:
--------------------------------------------------------------------------------
1 | import io
2 | import pyarrow as pa
3 | import os
4 | from copy import deepcopy
5 | from PIL import Image
6 | Image.MAX_IMAGE_PIXELS = 1000000000
7 |
8 | import torch
9 | from pytorch_lightning import LightningDataModule
10 | from torch.utils.data import DataLoader
11 | from torch.utils.data.distributed import DistributedSampler
12 |
13 |
14 | class BaseDataModule(LightningDataModule):
15 | def __init__(self, feature_processor, _config):
16 | super().__init__()
17 |
18 | self.data_dir = _config["data_root"]
19 |
20 | self.num_workers = _config["num_workers"]
21 | self.batch_size = _config["per_gpu_batchsize"]
22 | self.eval_batch_size = self.batch_size
23 |
24 | self.image_size = _config["image_size"]
25 | self.max_text_len = _config["max_text_len"]
26 | self.draw_false_image = _config["draw_false_image"]
27 | self.draw_false_text = _config["draw_false_text"]
28 | self.image_only = _config["image_only"]
29 |
30 | self.feature_processor = feature_processor
31 | self.vocab_size = self.feature_processor.tokenizer.vocab_size
32 | self.setup_flag = False
33 | self.config = _config
34 |
35 | @property
36 | def dataset_cls(self):
37 | raise NotImplementedError("return tuple of dataset class")
38 |
39 | @property
40 | def dataset_name(self):
41 | raise NotImplementedError("return name of dataset")
42 |
43 | def set_train_dataset(self):
44 | self.train_dataset = self.dataset_cls(
45 | self.data_dir,
46 | split="train",
47 | image_size=self.image_size,
48 | max_text_len=self.max_text_len,
49 | draw_false_image=self.draw_false_image,
50 | draw_false_text=self.draw_false_text,
51 | image_only=self.image_only,
52 | )
53 |
54 | def set_val_dataset(self):
55 | self.val_dataset = self.dataset_cls(
56 | self.data_dir,
57 | split="val",
58 | image_size=self.image_size,
59 | max_text_len=self.max_text_len,
60 | draw_false_image=self.draw_false_image,
61 | draw_false_text=self.draw_false_text,
62 | image_only=self.image_only,
63 | )
64 |
65 |
66 | def set_test_dataset(self):
67 | self.test_dataset = self.dataset_cls(
68 | self.data_dir,
69 | split="test",
70 | image_size=self.image_size,
71 | max_text_len=self.max_text_len,
72 | draw_false_image=self.draw_false_image,
73 | draw_false_text=self.draw_false_text,
74 | image_only=self.image_only,
75 | )
76 |
77 | def setup(self, stage):
78 | if not self.setup_flag:
79 | self.set_train_dataset()
80 | self.set_val_dataset()
81 | self.set_test_dataset()
82 |
83 | self.train_dataset.feature_processor = self.feature_processor
84 | self.val_dataset.feature_processor = self.feature_processor
85 | self.test_dataset.feature_processor = self.feature_processor
86 |
87 | self.setup_flag = True
88 |
89 | self.train_sampler = DistributedSampler(self.train_dataset, shuffle=True)
90 | self.val_sampler = DistributedSampler(self.val_dataset, shuffle=False)
91 | self.test_sampler = DistributedSampler(self.test_dataset, shuffle=False)
92 |
93 | def train_dataloader(self):
94 | loader = DataLoader(
95 | self.train_dataset,
96 | batch_size=self.batch_size,
97 | sampler=self.train_sampler,
98 | num_workers=self.num_workers,
99 | collate_fn=self.train_dataset.collate,
100 | )
101 | return loader
102 |
103 | def val_dataloader(self):
104 | loader = DataLoader(
105 | self.val_dataset,
106 | batch_size=self.batch_size,
107 | sampler=self.val_sampler,
108 | num_workers=self.num_workers,
109 | collate_fn=self.val_dataset.collate,
110 | )
111 | return loader
112 |
113 | def test_dataloader(self):
114 | loader = DataLoader(
115 | self.test_dataset,
116 | batch_size=self.batch_size,
117 | sampler=self.test_sampler,
118 | num_workers=self.num_workers,
119 | collate_fn=self.test_dataset.collate,
120 | )
121 | return loader
122 |
123 |
124 | class VQAv2DataModule(BaseDataModule):
125 | def __init__(self, *args, **kwargs):
126 | super().__init__(*args, **kwargs)
127 |
128 | @property
129 | def dataset_cls(self):
130 | return VQAv2Dataset
131 |
132 | @property
133 | def dataset_name(self):
134 | return "vqa"
135 |
136 | def setup(self, stage):
137 | super().setup(stage)
138 |
139 |
140 | class BaseDataset(torch.utils.data.Dataset):
141 | def __init__(
142 | self,
143 | data_dir: str,
144 | image_size: int,
145 | names: list,
146 | text_column_name: str = "",
147 | remove_duplicate=True,
148 | max_text_len=40,
149 | draw_false_image=0,
150 | draw_false_text=0,
151 | image_only=False,
152 | tokenizer=None,
153 | ):
154 | """
155 | data_dir : where dataset file *.arrow lives; existence should be guaranteed via DataModule.prepare_data
156 | text_column_name : pyarrow table column name that has list of strings as elements
157 | """
158 | super().__init__()
159 |
160 | self.text_column_name = text_column_name
161 | self.names = names
162 | self.max_text_len = max_text_len
163 | self.draw_false_image = draw_false_image
164 | self.draw_false_text = draw_false_text
165 | self.image_only = image_only
166 | self.data_dir = data_dir
167 | print(names)
168 | if len(names) != 0:
169 | tables = [
170 | pa.ipc.RecordBatchFileReader(
171 | pa.memory_map(f"{data_dir}/{name}.arrow", "r")
172 | ).read_all()
173 | if os.path.isfile(f"{data_dir}/{name}.arrow")
174 | else print(f"{data_dir}/{name}.arrow" + " not found.")
175 | for name in names
176 | ]
177 |
178 | self.table_names = list()
179 | for i, name in enumerate(names):
180 | self.table_names += [name] * len(tables[i])
181 |
182 | self.table = pa.concat_tables(tables, promote=True)
183 | if text_column_name != "":
184 | self.text_column_name = text_column_name
185 | self.all_texts = self.table[text_column_name].to_pandas().tolist()
186 | self.all_texts = (
187 | [list(set(texts)) for texts in self.all_texts]
188 | if remove_duplicate
189 | else self.all_texts
190 | )
191 | else:
192 | self.all_texts = list()
193 | else:
194 | self.all_texts = list()
195 |
196 | self.index_mapper = dict()
197 |
198 | if text_column_name != "" and not self.image_only:
199 | j = 0
200 | for i, texts in enumerate(self.all_texts):
201 | for _j in range(len(texts)):
202 | self.index_mapper[j] = (i, _j)
203 | j += 1
204 | else:
205 | for i in range(len(self.table)):
206 | self.index_mapper[i] = (i, None)
207 |
208 | @property
209 | def corpus(self):
210 | return [text for texts in self.all_texts for text in texts]
211 |
212 | def __len__(self):
213 | return len(self.index_mapper)
214 |
215 | def get_raw_image(self, index, image_key="image"):
216 | index, caption_index = self.index_mapper[index]
217 | image_bytes = io.BytesIO(deepcopy(self.table[image_key][index]).as_py())
218 | image_bytes.seek(0)
219 | return Image.open(image_bytes).convert("RGBA")
220 |
221 | def collate(self, batch):
222 | batch_size = len(batch)
223 | keys = set([key for b in batch for key in b.keys()])
224 | dict_batch = {k: [dic[k] if k in dic else None for dic in batch] for k in keys}
225 | inputs_batch = {k: torch.cat([dict_batch["inputs"][i][k] for i in range(batch_size)]) for k in dict_batch["inputs"][0].keys()}
226 | dict_batch["inputs"] = inputs_batch
227 |
228 | return dict_batch
229 |
230 |
231 | class VQAv2Dataset(BaseDataset):
232 | def __init__(self, *args, split="", **kwargs):
233 | assert split in ["train", "val", "test"]
234 | self.split = split
235 |
236 | if split == "train":
237 | names = ["vqav2_train", "vqav2_trainable_val"]
238 | # names = ["vqav2_rest_val"]
239 | elif split == "val":
240 | names = ["vqav2_rest_val"]
241 | elif split == "test":
242 | names = ["vqav2_rest_val"]
243 |
244 | super().__init__(
245 | *args,
246 | **kwargs,
247 | names=names,
248 | text_column_name="questions",
249 | remove_duplicate=False,
250 | )
251 |
252 | def __getitem__(self, index):
253 | image = self.get_raw_image(index)
254 | image_index, question_index = self.index_mapper[index]
255 | text = self.all_texts[image_index][question_index]
256 | model_inputs = self.feature_processor(text=text, images=image, return_tensors="pt",padding="max_length", max_length=self.max_text_len)
257 |
258 | qid = self.table["question_id"][image_index][question_index].as_py()
259 |
260 | if self.split != "test":
261 | answers = self.table["answers"][image_index][question_index].as_py()
262 | labels = self.table["answer_labels"][image_index][question_index].as_py()
263 | scores = self.table["answer_scores"][image_index][question_index].as_py()
264 | else:
265 | answers = list()
266 | labels = list()
267 | scores = list()
268 |
269 | return {
270 | "inputs": model_inputs,
271 | "text": text,
272 | "vqa_answer": answers,
273 | "vqa_labels": labels,
274 | "vqa_scores": scores,
275 | "qid": qid,
276 | }
--------------------------------------------------------------------------------
/examples/VQA/vqav2_train_module.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import pytorch_lightning as pl
4 | import json
5 | import os
6 | import glob
7 |
8 | from torchmetrics.metric import Metric
9 | from transformers.optimization import AdamW
10 | from transformers import (
11 | get_polynomial_decay_schedule_with_warmup,
12 | get_cosine_schedule_with_warmup,
13 | )
14 |
15 | from models.VLE import VLEForVQA
16 | from models.VLE.modeling_vle import extend_position_embedding
17 |
18 |
19 | class Scalar(Metric):
20 | def __init__(self, dist_sync_on_step=False):
21 | super().__init__(dist_sync_on_step=dist_sync_on_step)
22 | self.add_state("scalar", default=torch.tensor(0.0), dist_reduce_fx="sum")
23 | self.add_state("total", default=torch.tensor(0.0), dist_reduce_fx="sum")
24 |
25 | def update(self, scalar):
26 | if isinstance(scalar, torch.Tensor):
27 | scalar = scalar.detach().to(self.scalar.device)
28 | else:
29 | scalar = torch.tensor(scalar).float().to(self.scalar.device)
30 | self.scalar += scalar.item()
31 | self.total += 1
32 |
33 | def compute(self):
34 | if self.total.item() == 0:
35 | return 0
36 | return self.scalar.item() / self.total.item()
37 |
38 | class VQAScore(Metric):
39 | def __init__(self, dist_sync_on_step=False):
40 | super().__init__(dist_sync_on_step=dist_sync_on_step)
41 | self.add_state("score", default=torch.tensor(0.0), dist_reduce_fx="sum")
42 | self.add_state("total", default=torch.tensor(0.0), dist_reduce_fx="sum")
43 |
44 | def update(self, logits, target):
45 | logits, target = (
46 | logits.detach().float().to(self.score.device),
47 | target.detach().float().to(self.score.device),
48 | )
49 | logits = torch.max(logits, 1)[1]
50 | one_hots = torch.zeros(*target.size()).to(target)
51 | one_hots.scatter_(1, logits.view(-1, 1), 1)
52 | scores = one_hots * target
53 |
54 | self.score += scores.sum().item()
55 | self.total += len(logits)
56 |
57 | def compute(self):
58 | return self.score / self.total
59 |
60 |
61 | class VLEForVQA_PL(pl.LightningModule):
62 | def __init__(self, config):
63 | super().__init__()
64 | self.save_hyperparameters()
65 |
66 | self.model = VLEForVQA.from_pretrained(config["model_dir"])
67 |
68 | if config["image_size"] != self.model.config.vision_config.image_size:
69 | patch_size = self.model.config.vision_config.patch_size
70 | position_length_after = (config["image_size"]//self.model.config.vision_config.patch_size)**2 + 1
71 | position_embed_dim = self.model.vle.vision_model.vision_model.embeddings.position_embedding.embedding_dim
72 |
73 | new_state_dict = extend_position_embedding(self.model.state_dict(), patch_size, config["image_size"])
74 | self.model.vle.vision_model.vision_model.embeddings.position_embedding = nn.Embedding(position_length_after, position_embed_dim)
75 | self.model.vle.vision_model.vision_model.embeddings.register_buffer("position_ids", torch.arange(position_length_after).expand((1, -1)))
76 | self.model.load_state_dict(new_state_dict)
77 |
78 | for split in ["train", "val"]:
79 | setattr(self, f"{split}_vqa_score", VQAScore())
80 | setattr(self, f"{split}_vqa_loss", Scalar())
81 |
82 | def forward(self, batch):
83 | ret = dict()
84 | model_inputs = batch["inputs"]
85 | model_outputs = self.model(**model_inputs,vqa_labels=batch["vqa_labels"], vqa_scores=batch["vqa_scores"], return_loss=True)
86 | vqa_logits = model_outputs["logits"]
87 | vqa_loss = model_outputs["loss"]
88 |
89 | vqa_targets = torch.zeros(vqa_logits.size()[:2]).to(vqa_logits.device)
90 | vqa_labels = batch["vqa_labels"]
91 | vqa_scores = batch["vqa_scores"]
92 | for i, (_label, _score) in enumerate(zip(vqa_labels, vqa_scores)):
93 | for l, s in zip(_label, _score):
94 | vqa_targets[i, l] = s
95 |
96 | ret = {
97 | "vqa_loss": vqa_loss,
98 | "vqa_logits": vqa_logits,
99 | "vqa_targets": vqa_targets,
100 | "vqa_labels": vqa_labels,
101 | "vqa_scores": vqa_scores,
102 | }
103 |
104 | phase = "train" if self.training else "val"
105 | loss = getattr(self, f"{phase}_vqa_loss")(ret["vqa_loss"])
106 | score = getattr(self, f"{phase}_vqa_score")(
107 | ret["vqa_logits"], ret["vqa_targets"]
108 | )
109 | self.log(f"vqa/{phase}/loss", loss, batch_size=self.hparams.config["per_gpu_batchsize"])
110 | self.log(f"vqa/{phase}/score", score, batch_size=self.hparams.config["per_gpu_batchsize"])
111 |
112 | return ret
113 |
114 | def training_step(self, batch, batch_idx):
115 | output = self(batch)
116 | total_loss = sum([v for k, v in output.items() if "loss" in k])
117 |
118 | return total_loss
119 |
120 | def training_epoch_end(self, outs):
121 | self.epoch_wrapup()
122 |
123 | def validation_step(self, batch, batch_idx):
124 | output = self(batch)
125 |
126 | def validation_epoch_end(self, outs):
127 | self.epoch_wrapup()
128 |
129 | def test_step(self, batch, batch_idx):
130 | output = self(batch)
131 | ret = dict()
132 | # update vqa answer
133 | id2label = self.model.config.id2label
134 | vqa_logits = output["vqa_logits"]
135 | vqa_preds = vqa_logits.argmax(dim=-1)
136 | vqa_preds = [id2label[pred.item()] for pred in vqa_preds]
137 | questions = batch["text"]
138 | qids = batch["qid"]
139 | ret.update({"qids": qids, "questions": questions, "preds": vqa_preds})
140 |
141 | return ret
142 |
143 | def test_epoch_end(self, outs):
144 | model_name = self.hparams.config["model_dir"].split("/")[-1]
145 | save_dir = self.trainer.logger.log_dir
146 | rank = torch.distributed.get_rank()
147 | if rank != 0:
148 | save_dir_id = int(save_dir.split('_')[-1]) - 1
149 | save_dir = '_'.join(save_dir.split('_')[:-1]+[str(save_dir_id)])
150 | qids, preds = list(), list()
151 | for out in outs:
152 | qids += out["qids"]
153 | preds += out["preds"]
154 |
155 | rets = list()
156 | for qid, pred in zip(qids, preds):
157 | rets.append({"question_id": qid, "answer": pred})
158 | with open(os.path.join(save_dir, f"vqa_submit_{rank}.json"), "w") as fp:
159 | json.dump(rets, fp, indent=4)
160 |
161 | torch.distributed.barrier()
162 |
163 | if rank == 0:
164 | jsons = list()
165 | paths = list(glob.glob(os.path.join(save_dir,"vqa_submit_*.json")))
166 | for path in paths:
167 | with open(path, "r") as fp:
168 | jsons += json.load(fp)
169 | os.makedirs(os.path.join(save_dir,"result"), exist_ok=True)
170 | with open(os.path.join(save_dir, f"result/vqa_submit_{model_name}.json"), "w") as fp:
171 | json.dump(jsons, fp, indent=4)
172 |
173 | torch.distributed.barrier()
174 | os.remove(os.path.join(save_dir, f"vqa_submit_{rank}.json"))
175 |
176 | self.epoch_wrapup(test_mode=True)
177 |
178 | def configure_optimizers(self):
179 | lr = self.hparams.config["learning_rate"]
180 | wd = self.hparams.config["weight_decay"]
181 |
182 | no_decay = [
183 | "bias",
184 | "LayerNorm.bias",
185 | "LayerNorm.weight",
186 | "norm.bias",
187 | "norm.weight",
188 | "norm1.bias",
189 | "norm1.weight",
190 | "norm2.bias",
191 | "norm2.weight",
192 | ]
193 | head_names = ["vqa_classifier"]
194 | cross_modal_names = ['cross_modal']
195 | lr_mult_head = self.hparams.config["lr_mult_head"]
196 | lr_mult_cross_modal = self.hparams.config["lr_mult_cross_modal"]
197 | end_lr = self.hparams.config["end_lr"]
198 | decay_power = self.hparams.config["decay_power"]
199 | optim_type = self.hparams.config["optim_type"]
200 | all_grad_parameters = [(n,p) for n,p in self.named_parameters()]
201 | optimizer_grouped_parameters = [
202 | {
203 | "params": [
204 | p
205 | for n, p in all_grad_parameters
206 | if not any(nd in n for nd in no_decay)
207 | and not any(bb in n for bb in head_names)
208 | and not any(ht in n for ht in cross_modal_names)
209 | ],
210 | "weight_decay": wd,
211 | "lr": lr,
212 | },
213 | {
214 | "params": [
215 | p
216 | for n, p in all_grad_parameters
217 | if any(nd in n for nd in no_decay)
218 | and not any(bb in n for bb in head_names)
219 | and not any(ht in n for ht in cross_modal_names)
220 | ],
221 | "weight_decay": 0.0,
222 | "lr": lr,
223 | },
224 | {
225 | "params": [
226 | p
227 | for n, p in all_grad_parameters
228 | if not any(nd in n for nd in no_decay)
229 | and any(bb in n for bb in head_names)
230 | and not any(ht in n for ht in cross_modal_names)
231 | ],
232 | "weight_decay": wd,
233 | "lr": lr * lr_mult_head,
234 | },
235 | {
236 | "params": [
237 | p
238 | for n, p in all_grad_parameters
239 | if any(nd in n for nd in no_decay)
240 | and any(bb in n for bb in head_names)
241 | and not any(ht in n for ht in cross_modal_names)
242 | ],
243 | "weight_decay": 0.0,
244 | "lr": lr * lr_mult_head,
245 | },
246 | {
247 | "params": [
248 | p
249 | for n, p in all_grad_parameters
250 | if not any(nd in n for nd in no_decay)
251 | and not any(bb in n for bb in head_names)
252 | and any(ht in n for ht in cross_modal_names)
253 | ],
254 | "weight_decay": wd,
255 | "lr": lr * lr_mult_cross_modal,
256 | },
257 | {
258 | "params": [
259 | p
260 | for n, p in all_grad_parameters
261 | if any(nd in n for nd in no_decay)
262 | and not any(bb in n for bb in head_names)
263 | and any(ht in n for ht in cross_modal_names)
264 | ],
265 | "weight_decay": 0.0,
266 | "lr": lr * lr_mult_cross_modal,
267 | },
268 | ]
269 |
270 | if optim_type == "adamw":
271 | optimizer = AdamW(
272 | optimizer_grouped_parameters, lr=lr, eps=1e-8, betas=(0.9, 0.98)
273 | )
274 | elif optim_type == "adam":
275 | optimizer = torch.optim.Adam(optimizer_grouped_parameters, lr=lr)
276 | elif optim_type == "sgd":
277 | optimizer = torch.optim.SGD(optimizer_grouped_parameters, lr=lr, momentum=0.9)
278 |
279 | if self.trainer.max_steps == -1:
280 | max_steps = (
281 | len(self.trainer.datamodule.train_dataloader())
282 | * self.trainer.max_epochs
283 | // self.trainer.accumulate_grad_batches
284 | )
285 | else:
286 | max_steps = self.trainer.max_steps
287 |
288 | warmup_steps = self.hparams.config["warmup_steps"]
289 | if isinstance(self.hparams.config["warmup_steps"], float):
290 | warmup_steps = int(max_steps * warmup_steps)
291 |
292 | if decay_power == "cosine":
293 | scheduler = get_cosine_schedule_with_warmup(
294 | optimizer, num_warmup_steps=warmup_steps, num_training_steps=max_steps,
295 | )
296 | else:
297 | scheduler = get_polynomial_decay_schedule_with_warmup(
298 | optimizer,
299 | num_warmup_steps=warmup_steps,
300 | num_training_steps=max_steps,
301 | lr_end=end_lr,
302 | power=decay_power,
303 | )
304 |
305 | sched = {"scheduler": scheduler, "interval": "step"}
306 |
307 | return (
308 | [optimizer],
309 | [sched],
310 | )
311 |
312 | def epoch_wrapup(self, test_mode=False):
313 | phase = "train" if self.training else "val"
314 | loss_name = 'vqa'
315 | value = getattr(self, f"{phase}_{loss_name}_score").compute()
316 | self.log(f"{loss_name}/{phase}/score_epoch", value)
317 | getattr(self, f"{phase}_{loss_name}_score").reset()
318 | self.log(
319 | f"{loss_name}/{phase}/loss_epoch",
320 | getattr(self, f"{phase}_{loss_name}_loss").compute(),
321 | )
322 | getattr(self, f"{phase}_{loss_name}_loss").reset()
323 |
324 | self.log(f"{phase}/the_metric", value)
--------------------------------------------------------------------------------
/examples/VQA/write_vqa.py:
--------------------------------------------------------------------------------
1 | import json
2 | import pandas as pd
3 | import pyarrow as pa
4 | import random
5 | import os
6 |
7 | from tqdm import tqdm
8 | from glob import glob
9 | from collections import defaultdict, Counter
10 |
11 | import re
12 |
13 | contractions = {
14 | "aint": "ain't",
15 | "arent": "aren't",
16 | "cant": "can't",
17 | "couldve": "could've",
18 | "couldnt": "couldn't",
19 | "couldn'tve": "couldn't've",
20 | "couldnt've": "couldn't've",
21 | "didnt": "didn't",
22 | "doesnt": "doesn't",
23 | "dont": "don't",
24 | "hadnt": "hadn't",
25 | "hadnt've": "hadn't've",
26 | "hadn'tve": "hadn't've",
27 | "hasnt": "hasn't",
28 | "havent": "haven't",
29 | "hed": "he'd",
30 | "hed've": "he'd've",
31 | "he'dve": "he'd've",
32 | "hes": "he's",
33 | "howd": "how'd",
34 | "howll": "how'll",
35 | "hows": "how's",
36 | "Id've": "I'd've",
37 | "I'dve": "I'd've",
38 | "Im": "I'm",
39 | "Ive": "I've",
40 | "isnt": "isn't",
41 | "itd": "it'd",
42 | "itd've": "it'd've",
43 | "it'dve": "it'd've",
44 | "itll": "it'll",
45 | "let's": "let's",
46 | "maam": "ma'am",
47 | "mightnt": "mightn't",
48 | "mightnt've": "mightn't've",
49 | "mightn'tve": "mightn't've",
50 | "mightve": "might've",
51 | "mustnt": "mustn't",
52 | "mustve": "must've",
53 | "neednt": "needn't",
54 | "notve": "not've",
55 | "oclock": "o'clock",
56 | "oughtnt": "oughtn't",
57 | "ow's'at": "'ow's'at",
58 | "'ows'at": "'ow's'at",
59 | "'ow'sat": "'ow's'at",
60 | "shant": "shan't",
61 | "shed've": "she'd've",
62 | "she'dve": "she'd've",
63 | "she's": "she's",
64 | "shouldve": "should've",
65 | "shouldnt": "shouldn't",
66 | "shouldnt've": "shouldn't've",
67 | "shouldn'tve": "shouldn't've",
68 | "somebody'd": "somebodyd",
69 | "somebodyd've": "somebody'd've",
70 | "somebody'dve": "somebody'd've",
71 | "somebodyll": "somebody'll",
72 | "somebodys": "somebody's",
73 | "someoned": "someone'd",
74 | "someoned've": "someone'd've",
75 | "someone'dve": "someone'd've",
76 | "someonell": "someone'll",
77 | "someones": "someone's",
78 | "somethingd": "something'd",
79 | "somethingd've": "something'd've",
80 | "something'dve": "something'd've",
81 | "somethingll": "something'll",
82 | "thats": "that's",
83 | "thered": "there'd",
84 | "thered've": "there'd've",
85 | "there'dve": "there'd've",
86 | "therere": "there're",
87 | "theres": "there's",
88 | "theyd": "they'd",
89 | "theyd've": "they'd've",
90 | "they'dve": "they'd've",
91 | "theyll": "they'll",
92 | "theyre": "they're",
93 | "theyve": "they've",
94 | "twas": "'twas",
95 | "wasnt": "wasn't",
96 | "wed've": "we'd've",
97 | "we'dve": "we'd've",
98 | "weve": "we've",
99 | "werent": "weren't",
100 | "whatll": "what'll",
101 | "whatre": "what're",
102 | "whats": "what's",
103 | "whatve": "what've",
104 | "whens": "when's",
105 | "whered": "where'd",
106 | "wheres": "where's",
107 | "whereve": "where've",
108 | "whod": "who'd",
109 | "whod've": "who'd've",
110 | "who'dve": "who'd've",
111 | "wholl": "who'll",
112 | "whos": "who's",
113 | "whove": "who've",
114 | "whyll": "why'll",
115 | "whyre": "why're",
116 | "whys": "why's",
117 | "wont": "won't",
118 | "wouldve": "would've",
119 | "wouldnt": "wouldn't",
120 | "wouldnt've": "wouldn't've",
121 | "wouldn'tve": "wouldn't've",
122 | "yall": "y'all",
123 | "yall'll": "y'all'll",
124 | "y'allll": "y'all'll",
125 | "yall'd've": "y'all'd've",
126 | "y'alld've": "y'all'd've",
127 | "y'all'dve": "y'all'd've",
128 | "youd": "you'd",
129 | "youd've": "you'd've",
130 | "you'dve": "you'd've",
131 | "youll": "you'll",
132 | "youre": "you're",
133 | "youve": "you've",
134 | }
135 |
136 | manual_map = {
137 | "none": "0",
138 | "zero": "0",
139 | "one": "1",
140 | "two": "2",
141 | "three": "3",
142 | "four": "4",
143 | "five": "5",
144 | "six": "6",
145 | "seven": "7",
146 | "eight": "8",
147 | "nine": "9",
148 | "ten": "10",
149 | }
150 | articles = ["a", "an", "the"]
151 | period_strip = re.compile("(?!<=\d)(\.)(?!\d)")
152 | comma_strip = re.compile("(\d)(\,)(\d)")
153 | punct = [
154 | ";",
155 | r"/",
156 | "[",
157 | "]",
158 | '"',
159 | "{",
160 | "}",
161 | "(",
162 | ")",
163 | "=",
164 | "+",
165 | "\\",
166 | "_",
167 | "-",
168 | ">",
169 | "<",
170 | "@",
171 | "`",
172 | ",",
173 | "?",
174 | "!",
175 | ]
176 |
177 |
178 | def normalize_word(token):
179 | _token = token
180 | for p in punct:
181 | if (p + " " in token or " " + p in token) or (
182 | re.search(comma_strip, token) != None
183 | ):
184 | _token = _token.replace(p, "")
185 | else:
186 | _token = _token.replace(p, " ")
187 | token = period_strip.sub("", _token, re.UNICODE)
188 |
189 | _token = []
190 | temp = token.lower().split()
191 | for word in temp:
192 | word = manual_map.setdefault(word, word)
193 | if word not in articles:
194 | _token.append(word)
195 | for i, word in enumerate(_token):
196 | if word in contractions:
197 | _token[i] = contractions[word]
198 | token = " ".join(_token)
199 | token = token.replace(",", "")
200 | return token
201 |
202 |
203 | def get_score(occurences):
204 | if occurences == 0:
205 | return 0.0
206 | elif occurences == 1:
207 | return 0.3
208 | elif occurences == 2:
209 | return 0.6
210 | elif occurences == 3:
211 | return 0.9
212 | else:
213 | return 1.0
214 |
215 |
216 | def path2rest(path, split, annotations, label2ans):
217 | iid = int(path.split("/")[-1].split("_")[-1][:-4])
218 |
219 | with open(path, "rb") as fp:
220 | binary = fp.read()
221 |
222 | _annot = annotations[split][iid]
223 | _annot = list(_annot.items())
224 | qids, qas = [a[0] for a in _annot], [a[1] for a in _annot]
225 | questions = [qa[0] for qa in qas]
226 | answers = [qa[1] for qa in qas] if "test" not in split else list(list())
227 | answer_labels = (
228 | [a["labels"] for a in answers] if "test" not in split else list(list())
229 | )
230 | answer_scores = (
231 | [a["scores"] for a in answers] if "test" not in split else list(list())
232 | )
233 | answers = (
234 | [[label2ans[l] for l in al] for al in answer_labels]
235 | if "test" not in split
236 | else list(list())
237 | )
238 |
239 | return [binary, questions, answers, answer_labels, answer_scores, iid, qids, split]
240 |
241 |
242 | def make_arrow(root, dataset_root):
243 | with open(f"{root}/v2_OpenEnded_mscoco_train2014_questions.json", "r") as fp:
244 | questions_train2014 = json.load(fp)["questions"]
245 | with open(f"{root}/v2_OpenEnded_mscoco_val2014_questions.json", "r") as fp:
246 | questions_val2014 = json.load(fp)["questions"]
247 | with open(f"{root}/v2_OpenEnded_mscoco_test2015_questions.json", "r") as fp:
248 | questions_test2015 = json.load(fp)["questions"]
249 | with open(f"{root}/v2_OpenEnded_mscoco_test-dev2015_questions.json", "r") as fp:
250 | questions_test_dev2015 = json.load(fp)["questions"]
251 |
252 | with open(f"{root}/v2_mscoco_train2014_annotations.json", "r") as fp:
253 | annotations_train2014 = json.load(fp)["annotations"]
254 | with open(f"{root}/v2_mscoco_val2014_annotations.json", "r") as fp:
255 | annotations_val2014 = json.load(fp)["annotations"]
256 |
257 | annotations = dict()
258 |
259 | for split, questions in zip(
260 | ["train", "val", "test", "test-dev"],
261 | [
262 | questions_train2014,
263 | questions_val2014,
264 | questions_test2015,
265 | questions_test_dev2015,
266 | ],
267 | ):
268 | _annot = defaultdict(dict)
269 | for q in tqdm(questions):
270 | _annot[q["image_id"]][q["question_id"]] = [q["question"]]
271 |
272 | annotations[split] = _annot
273 |
274 | all_major_answers = list()
275 |
276 | for split, annots in zip(
277 | ["train", "val"], [annotations_train2014, annotations_val2014],
278 | ):
279 | _annot = annotations[split]
280 | for q in tqdm(annots):
281 | all_major_answers.append(q["multiple_choice_answer"])
282 |
283 | all_major_answers = [normalize_word(word) for word in tqdm(all_major_answers)]
284 | counter = {k: v for k, v in Counter(all_major_answers).items() if v >= 9}
285 | ans2label = {k: i for i, k in enumerate(counter.keys())}
286 | label2ans = list(counter.keys())
287 |
288 | with open(os.path.join(dataset_root, "answer2label.json"), 'w', encoding='utf8') as f:
289 | json.dump(ans2label, f)
290 | with open(os.path.join(dataset_root, "label2answer.json"), 'w', encoding='utf8') as f:
291 | json.dump(label2ans, f)
292 |
293 | for split, annots in zip(
294 | ["train", "val"], [annotations_train2014, annotations_val2014],
295 | ):
296 | _annot = annotations[split]
297 | for q in tqdm(annots):
298 | answers = q["answers"]
299 | answer_count = {}
300 | for answer in answers:
301 | answer_ = answer["answer"]
302 | answer_count[answer_] = answer_count.get(answer_, 0) + 1
303 |
304 | labels = []
305 | scores = []
306 | for answer in answer_count:
307 | if answer not in ans2label:
308 | continue
309 | labels.append(ans2label[answer])
310 | score = get_score(answer_count[answer])
311 | scores.append(score)
312 |
313 | _annot[q["image_id"]][q["question_id"]].append(
314 | {"labels": labels, "scores": scores,}
315 | )
316 |
317 | for split in ["train", "val"]:
318 | filtered_annot = dict()
319 | for ik, iv in annotations[split].items():
320 | new_q = dict()
321 | for qk, qv in iv.items():
322 | if len(qv[1]["labels"]) != 0:
323 | new_q[qk] = qv
324 | if len(new_q) != 0:
325 | filtered_annot[ik] = new_q
326 | annotations[split] = filtered_annot
327 |
328 | for split in [
329 | "train",
330 | "val",
331 | "test",
332 | "test-dev",
333 | ]:
334 | annot = annotations[split]
335 | split_name = {
336 | "train": "train2014",
337 | "val": "val2014",
338 | "test": "test2015",
339 | "test-dev": "test2015",
340 | }[split]
341 | paths = list(glob(f"{root}/{split_name}/*.jpg"))
342 | random.shuffle(paths)
343 | annot_paths = [
344 | path
345 | for path in paths
346 | if int(path.split("/")[-1].split("_")[-1][:-4]) in annot
347 | ]
348 |
349 | if len(paths) == len(annot_paths):
350 | print("all images have caption annotations")
351 | else:
352 | print("not all images have caption annotations")
353 | print(
354 | len(paths), len(annot_paths), len(annot),
355 | )
356 |
357 | bs = [
358 | path2rest(path, split, annotations, label2ans) for path in tqdm(annot_paths)
359 | ]
360 |
361 | dataframe = pd.DataFrame(
362 | bs,
363 | columns=[
364 | "image",
365 | "questions",
366 | "answers",
367 | "answer_labels",
368 | "answer_scores",
369 | "image_id",
370 | "question_id",
371 | "split",
372 | ],
373 | )
374 |
375 | table = pa.Table.from_pandas(dataframe)
376 |
377 | os.makedirs(dataset_root, exist_ok=True)
378 | with pa.OSFile(f"{dataset_root}/vqav2_{split}.arrow", "wb") as sink:
379 | with pa.RecordBatchFileWriter(sink, table.schema) as writer:
380 | writer.write_table(table)
381 |
382 | table = pa.ipc.RecordBatchFileReader(
383 | pa.memory_map(f"{dataset_root}/vqav2_val.arrow", "r")
384 | ).read_all()
385 |
386 | pdtable = table.to_pandas()
387 |
388 | df1 = pdtable[:-1000]
389 | df2 = pdtable[-1000:]
390 |
391 | df1 = pa.Table.from_pandas(df1)
392 | df2 = pa.Table.from_pandas(df2)
393 |
394 | with pa.OSFile(f"{dataset_root}/vqav2_trainable_val.arrow", "wb") as sink:
395 | with pa.RecordBatchFileWriter(sink, df1.schema) as writer:
396 | writer.write_table(df1)
397 |
398 | with pa.OSFile(f"{dataset_root}/vqav2_rest_val.arrow", "wb") as sink:
399 | with pa.RecordBatchFileWriter(sink, df2.schema) as writer:
400 | writer.write_table(df2)
401 |
402 |
403 | if __name__ == "__main__":
404 | root = "./data/vqav2" # directory where the raw dataset is stored
405 | dataset_root = "./data/vqav2/vqav2_arrow" # directory where the pyarrow-format dataset is written
406 | make_arrow(root, dataset_root)
--------------------------------------------------------------------------------
/examples/VQA/zero_to_fp32.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 |
3 | # This script extracts fp32 consolidated weights from ZeRO 2 and 3 DeepSpeed checkpoints. It gets
4 | # copied into the top level checkpoint dir, so the user can easily do the conversion at any point in
5 | # the future. Once extracted, the weights don't require DeepSpeed and can be used in any
6 | # application.
7 | #
8 | # example: python zero_to_fp32.py . pytorch_model.bin global_step14
9 |
10 | import argparse
11 | import torch
12 | import glob
13 | import math
14 | import os
15 | import re
16 | from collections import OrderedDict
17 |
18 | # while this script doesn't use deepspeed to recover data, since the checkpoints are pickled with
19 | # DeepSpeed data structures it has to be available in the current python environment.
20 | from deepspeed.utils import logger
21 | from deepspeed.checkpoint.constants import (DS_VERSION,
22 | OPTIMIZER_STATE_DICT,
23 | SINGLE_PARTITION_OF_FP32_GROUPS,
24 | FP32_FLAT_GROUPS,
25 | ZERO_STAGE,
26 | PARTITION_COUNT,
27 | PARAM_SHAPES,
28 | BUFFER_NAMES)
29 |
30 | debug = 0
31 |
32 | # load to cpu
33 | device = torch.device('cpu')
34 |
35 |
36 | def atoi(text):
37 | return int(text) if text.isdigit() else text
38 |
39 |
40 | def natural_keys(text):
41 | '''
42 | alist.sort(key=natural_keys) sorts in human order
43 | http://nedbatchelder.com/blog/200712/human_sorting.html
44 | (See Toothy's implementation in the comments)
45 | '''
46 | return [atoi(c) for c in re.split(r'(\d+)', text)]
47 |
48 |
49 | def get_model_state_file(checkpoint_dir, zero_stage):
50 | if not os.path.isdir(checkpoint_dir):
51 | raise FileNotFoundError(f"Directory '{checkpoint_dir}' doesn't exist")
52 |
53 | # there should be only one file
54 | if zero_stage == 2:
55 | file = os.path.join(checkpoint_dir, "mp_rank_00_model_states.pt")
56 | elif zero_stage == 3:
57 | file = os.path.join(checkpoint_dir, "zero_pp_rank_0_mp_rank_00_model_states.pt")
58 |
59 | if not os.path.exists(file):
60 | raise FileNotFoundError(f"can't find model states file at '{file}'")
61 |
62 | return file
63 |
64 |
65 | def get_optim_files(checkpoint_dir):
66 | # XXX: need to test that this simple glob rule works for multi-node setup too
67 | optim_files = sorted(glob.glob(os.path.join(checkpoint_dir,
68 | "*_optim_states.pt")),
69 | key=natural_keys)
70 |
71 | if len(optim_files) == 0:
72 | raise FileNotFoundError(
73 | f"can't find '*_optim_states.pt' files in directory '{checkpoint_dir}'")
74 |
75 | return optim_files
76 |
77 |
78 | def parse_model_state(file):
79 | state_dict = torch.load(file, map_location=device)
80 |
81 | if BUFFER_NAMES not in state_dict:
82 | raise ValueError(f"{file} is not a model state checkpoint")
83 | buffer_names = state_dict[BUFFER_NAMES]
84 | if debug:
85 | print("Found buffers:", buffer_names)
86 |
87 | # recover just the buffers while restoring them to fp32 if they were saved in fp16
88 | buffers = {
89 | k: v.float()
90 | for k,
91 | v in state_dict["module"].items() if k in buffer_names
92 | }
93 | param_shapes = state_dict[PARAM_SHAPES]
94 |
95 | ds_version = state_dict.get(DS_VERSION, None)
96 |
97 | return buffers, param_shapes, ds_version
98 |
99 |
100 | def parse_optim_states(files, ds_checkpoint_dir):
101 |
102 | total_files = len(files)
103 | state_dicts = []
104 | for f in files:
105 | state_dicts.append(torch.load(f, map_location=device))
106 |
107 | if ZERO_STAGE not in state_dicts[0][OPTIMIZER_STATE_DICT]:
108 | raise ValueError(f"{files[0]} is not a zero checkpoint")
109 | zero_stage = state_dicts[0][OPTIMIZER_STATE_DICT][ZERO_STAGE]
110 | world_size = state_dicts[0][OPTIMIZER_STATE_DICT][PARTITION_COUNT]
111 |
112 | # For ZeRO-2 each param group can have different partition_count as data parallelism for expert
113 | # parameters can be different from data parallelism for non-expert parameters. So we can just
114 | # use the max of the partition_count to get the dp world_size.
115 |
116 | if type(world_size) is list:
117 | world_size = max(world_size)
118 |
119 | if world_size != total_files:
120 | raise ValueError(
121 | f"Expected {world_size} of '*_optim_states.pt' under '{ds_checkpoint_dir}' but found {total_files} files. "
122 | "Possibly due to an overwrite of an old checkpoint, or a checkpoint didn't get saved by one or more processes."
123 | )
124 |
125 | # the groups are named differently in each stage
126 | if zero_stage == 2:
127 | fp32_groups_key = SINGLE_PARTITION_OF_FP32_GROUPS
128 | elif zero_stage == 3:
129 | fp32_groups_key = FP32_FLAT_GROUPS
130 | else:
131 | raise ValueError(f"unknown zero stage {zero_stage}")
132 |
133 | if zero_stage == 2:
134 | fp32_flat_groups = [
135 | state_dicts[i][OPTIMIZER_STATE_DICT][fp32_groups_key]
136 | for i in range(len(state_dicts))
137 | ]
138 | elif zero_stage == 3:
139 | # if there is more than one param group, there will be multiple flattened tensors - one
140 | # flattened tensor per group - for simplicity merge them into a single tensor
141 | #
142 | # XXX: could make the script more memory efficient for when there are multiple groups - it
143 | # will require matching the sub-lists of param_shapes for each param group flattened tensor
144 |
145 | fp32_flat_groups = [
146 | torch.cat(state_dicts[i][OPTIMIZER_STATE_DICT][fp32_groups_key],
147 | 0) for i in range(len(state_dicts))
148 | ]
149 |
150 | return zero_stage, world_size, fp32_flat_groups
151 |
152 |
153 | def _get_fp32_state_dict_from_zero_checkpoint(ds_checkpoint_dir):
154 | """
155 | Returns fp32 state_dict reconstructed from ds checkpoint
156 |
157 | Args:
158 | - ``ds_checkpoint_dir``: path to the deepspeed checkpoint folder (where the optimizer files are)
159 |
160 | """
161 | print(f"Processing zero checkpoint '{ds_checkpoint_dir}'")
162 |
163 | optim_files = get_optim_files(ds_checkpoint_dir)
164 | zero_stage, world_size, fp32_flat_groups = parse_optim_states(optim_files, ds_checkpoint_dir)
165 | print(
166 | f"Detected checkpoint of type zero stage {zero_stage}, world_size: {world_size}")
167 |
168 | model_file = get_model_state_file(ds_checkpoint_dir, zero_stage)
169 | buffers, param_shapes, ds_version = parse_model_state(model_file)
170 | print(f'Parsing checkpoint created by deepspeed=={ds_version}')
171 |
172 | if zero_stage == 2:
173 | return _get_fp32_state_dict_from_zero2_checkpoint(world_size,
174 | param_shapes,
175 | fp32_flat_groups,
176 | buffers)
177 | elif zero_stage == 3:
178 | return _get_fp32_state_dict_from_zero3_checkpoint(world_size,
179 | param_shapes,
180 | fp32_flat_groups,
181 | buffers)
182 |
183 |
184 | def _get_fp32_state_dict_from_zero2_checkpoint(world_size,
185 | param_shapes,
186 | fp32_flat_groups,
187 | buffers):
188 |
189 | # Reconstruction protocol:
190 | #
191 | # XXX: document this
192 |
193 | if debug:
194 | for i in range(world_size):
195 | for j in range(len(fp32_flat_groups[0])):
196 | print(
197 | f"{FP32_FLAT_GROUPS}[{i}][{j}].shape={fp32_flat_groups[i][j].shape}")
198 |
199 | # XXX: memory usage doubles here (zero2)
200 | num_param_groups = len(fp32_flat_groups[0])
201 | merged_single_partition_of_fp32_groups = []
202 | for i in range(num_param_groups):
203 | merged_partitions = [sd[i] for sd in fp32_flat_groups]
204 | full_single_fp32_vector = torch.cat(merged_partitions, 0)
205 | merged_single_partition_of_fp32_groups.append(full_single_fp32_vector)
206 | avail_numel = sum([
207 | full_single_fp32_vector.numel()
208 | for full_single_fp32_vector in merged_single_partition_of_fp32_groups
209 | ])
210 |
211 | if debug:
212 | wanted_params = sum([len(shapes) for shapes in param_shapes])
213 | wanted_numel = sum(
214 | [sum(shape.numel() for shape in shapes.values()) for shapes in param_shapes])
215 | # not asserting if there is a mismatch due to possible padding
216 | print(f"Have {avail_numel} numels to process.")
217 | print(f"Need {wanted_numel} numels in {wanted_params} params.")
218 |
219 | state_dict = OrderedDict()
220 |
221 | # buffers
222 | state_dict.update(buffers)
223 | if debug:
224 | print(f"added {len(buffers)} buffers")
225 |
226 | # params
227 | # XXX: for huge models that can't fit into the host's RAM we will have to recode this to support
228 | # out-of-core computing solution
229 | total_numel = 0
230 | total_params = 0
231 | for shapes, full_single_fp32_vector in zip(param_shapes, merged_single_partition_of_fp32_groups):
232 | offset = 0
233 | avail_numel = full_single_fp32_vector.numel()
234 | for name, shape in shapes.items():
235 |
236 | unpartitioned_numel = shape.numel()
237 | total_numel += unpartitioned_numel
238 | total_params += 1
239 |
240 | if debug:
241 | print(
242 | f"{name} full shape: {shape} unpartitioned numel {unpartitioned_numel} "
243 | )
244 | state_dict[name] = full_single_fp32_vector.narrow(
245 | 0,
246 | offset,
247 | unpartitioned_numel).view(shape)
248 | offset += unpartitioned_numel
249 |
250 | # Z2 started to align to 2*world_size to improve nccl performance. Therefore both offset and
251 | # avail_numel can differ by anywhere between 0..2*world_size. Due to two unrelated complex
252 | # paddings performed in the code it's almost impossible to predict the exact numbers w/o the
253 | # live optimizer object, so we are checking that the numbers are within the right range
254 | align_to = 2 * world_size
255 |
256 | def zero2_align(x):
257 | return align_to * math.ceil(x / align_to)
258 |
259 | if debug:
260 | print(f"original offset={offset}, avail_numel={avail_numel}")
261 |
262 | offset = zero2_align(offset)
263 | avail_numel = zero2_align(avail_numel)
264 |
265 | if debug:
266 | print(f"aligned offset={offset}, avail_numel={avail_numel}")
267 |
268 | # Sanity check
269 | if offset != avail_numel:
270 | raise ValueError(
271 | f"consumed {offset} numels out of {avail_numel} - something is wrong")
272 |
273 | print(
274 | f"Reconstructed fp32 state dict with {total_params} params {total_numel} elements"
275 | )
276 |
277 | return state_dict
278 |
279 |
280 | def zero3_partitioned_param_info(unpartitioned_numel, world_size):
281 | remainder = unpartitioned_numel % world_size
282 | padding_numel = (world_size - remainder) if remainder else 0
283 | partitioned_numel = math.ceil(unpartitioned_numel / world_size)
284 | return partitioned_numel, padding_numel
285 |
286 |
287 | def _get_fp32_state_dict_from_zero3_checkpoint(world_size,
288 | param_shapes,
289 | fp32_flat_groups,
290 | buffers):
291 |
292 | # Reconstruction protocol: For zero3 we need to zip the partitions together at boundary of each
293 | # param, re-consolidating each param, while dealing with padding if any
294 |
295 | avail_numel = fp32_flat_groups[0].numel() * world_size
296 | # merge list of dicts, preserving order
297 | param_shapes = {k: v for d in param_shapes for k, v in d.items()}
298 |
299 | if debug:
300 | for i in range(world_size):
301 | print(f"{FP32_FLAT_GROUPS}[{i}].shape={fp32_flat_groups[i].shape}")
302 |
303 | wanted_params = len(param_shapes)
304 | wanted_numel = sum(shape.numel() for shape in param_shapes.values())
305 | # not asserting if there is a mismatch due to possible padding
306 | print(f"Have {avail_numel} numels to process.")
307 | print(f"Need {wanted_numel} numels in {wanted_params} params.")
308 |
309 | state_dict = OrderedDict()
310 |
311 | # buffers
312 | state_dict.update(buffers)
313 | if debug:
314 | print(f"added {len(buffers)} buffers")
315 |
316 | # params
317 | # XXX: for huge models that can't fit into the host's RAM we will have to recode this to support
318 | # out-of-core computing solution
319 | offset = 0
320 | total_numel = 0
321 | total_params = 0
322 | for name, shape in param_shapes.items():
323 |
324 | unpartitioned_numel = shape.numel()
325 | total_numel += unpartitioned_numel
326 | total_params += 1
327 |
328 | partitioned_numel, partitioned_padding_numel = zero3_partitioned_param_info(unpartitioned_numel, world_size)
329 |
330 | if debug:
331 | print(
332 | f"{total_params} {name} full shape: {shape} partition0 numel={partitioned_numel} partitioned_padding_numel={partitioned_padding_numel}"
333 | )
334 |
335 | # XXX: memory usage doubles here
336 | state_dict[name] = torch.cat(
337 | tuple(fp32_flat_groups[i].narrow(0,
338 | offset,
339 | partitioned_numel)
340 | for i in range(world_size)),
341 | 0).narrow(0,
342 | 0,
343 | unpartitioned_numel).view(shape)
344 | offset += partitioned_numel
345 |
346 | offset *= world_size
347 |
348 | # Sanity check
349 | if offset != avail_numel:
350 | raise ValueError(
351 | f"consumed {offset} numels out of {avail_numel} - something is wrong")
352 |
353 | print(
354 | f"Reconstructed fp32 state dict with {total_params} params {total_numel} elements"
355 | )
356 |
357 | return state_dict
358 |
359 |
360 | def get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir, tag=None):
361 | """
362 | Convert ZeRO 2 or 3 checkpoint into a single fp32 consolidated state_dict that can be loaded with
363 | ``load_state_dict()`` and used for training without DeepSpeed or shared with others, for example
364 | via a model hub.
365 |
366 | Args:
367 | - ``checkpoint_dir``: path to the desired checkpoint folder
368 | - ``tag``: checkpoint tag used as a unique identifier for checkpoint. If not provided will attempt to load tag in 'latest' file. e.g., ``global_step14``
369 |
370 | Returns:
371 | - pytorch ``state_dict``
372 |
373 | Note: this approach may not work if your application doesn't have sufficient free CPU memory and
374 | you may need to use the offline approach using the ``zero_to_fp32.py`` script that is saved with
375 | the checkpoint.
376 |
377 | A typical usage might be ::
378 |
379 | from deepspeed.utils.zero_to_fp32 import get_fp32_state_dict_from_zero_checkpoint
380 | # do the training and checkpoint saving
381 | state_dict = get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir) # already on cpu
382 | model = model.cpu() # move to cpu
383 | model.load_state_dict(state_dict)
384 | # submit to model hub or save the model to share with others
385 |
386 | In this example the ``model`` will no longer be usable in the deepspeed context of the same
387 | application. i.e. you will need to re-initialize the deepspeed engine, since
388 | ``model.load_state_dict(state_dict)`` will remove all the deepspeed magic from it.
389 |
390 | If you want it all done for you, use ``load_state_dict_from_zero_checkpoint`` instead.
391 |
392 | """
393 | if tag is None:
394 | latest_path = os.path.join(checkpoint_dir, 'latest')
395 | if os.path.isfile(latest_path):
396 | with open(latest_path, 'r') as fd:
397 | tag = fd.read().strip()
398 | else:
399 | raise ValueError(f"Unable to find 'latest' file at {latest_path}")
400 |
401 | ds_checkpoint_dir = os.path.join(checkpoint_dir, tag)
402 |
403 | if not os.path.isdir(ds_checkpoint_dir):
404 | raise FileNotFoundError(f"Directory '{ds_checkpoint_dir}' doesn't exist")
405 |
406 | return _get_fp32_state_dict_from_zero_checkpoint(ds_checkpoint_dir)
407 |
408 |
409 | def convert_zero_checkpoint_to_fp32_state_dict(checkpoint_dir, output_file, tag=None):
410 | """
411 | Convert ZeRO 2 or 3 checkpoint into a single fp32 consolidated ``state_dict`` file that can be
412 | loaded with ``torch.load(file)`` + ``load_state_dict()`` and used for training without DeepSpeed.
413 |
414 | Args:
415 |         - ``checkpoint_dir``: path to the desired checkpoint folder (the one that contains the tag folder, e.g. ``global_step14``)
416 | - ``output_file``: path to the pytorch fp32 state_dict output file (e.g. path/pytorch_model.bin)
417 | - ``tag``: checkpoint tag used as a unique identifier for checkpoint. If not provided will attempt to load tag in the file named ``latest`` in the checkpoint folder, e.g., ``global_step14``
418 | """
419 |
420 | state_dict = get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir, tag)
421 | print(f"Saving fp32 state dict to {output_file}")
422 | torch.save(state_dict, output_file)
423 |
424 |
425 | def load_state_dict_from_zero_checkpoint(model, checkpoint_dir, tag=None):
426 | """
427 |     1. Move the provided model to CPU
428 | 2. Convert ZeRO 2 or 3 checkpoint into a single fp32 consolidated ``state_dict``
429 | 3. Load it into the provided model
430 |
431 | Args:
432 | - ``model``: the model object to update
433 |     - ``checkpoint_dir``: path to the desired checkpoint folder (the one that contains the tag folder, e.g. ``global_step14``)
434 | - ``tag``: checkpoint tag used as a unique identifier for checkpoint. If not provided will attempt to load tag in the file named ``latest`` in the checkpoint folder, e.g., ``global_step14``
435 |
436 | Returns:
437 |         - ``model``: modified model
438 |
439 | Make sure you have plenty of CPU memory available before you call this function. If you don't
440 | have enough use the ``zero_to_fp32.py`` utility to do the conversion. You will find it
441 | conveniently placed for you in the checkpoint folder.
442 |
443 | A typical usage might be ::
444 |
445 | from deepspeed.utils.zero_to_fp32 import load_state_dict_from_zero_checkpoint
446 | model = load_state_dict_from_zero_checkpoint(trainer.model, checkpoint_dir)
447 | # submit to model hub or save the model to share with others
448 |
449 |     Note that once this has been run, the ``model`` will no longer be usable in the deepspeed context
450 |     of the same application, i.e. you will need to re-initialize the deepspeed engine, since
451 | ``model.load_state_dict(state_dict)`` will remove all the deepspeed magic from it.
452 |
453 | """
454 | logger.info(f"Extracting fp32 weights")
455 | state_dict = get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir, tag)
456 |
457 | logger.info(f"Overwriting model with fp32 weights")
458 | model = model.cpu()
459 | model.load_state_dict(state_dict, strict=False)
460 |
461 | return model
462 |
463 |
464 | if __name__ == "__main__":
465 |
466 | parser = argparse.ArgumentParser()
467 | parser.add_argument(
468 | "checkpoint_dir",
469 | type=str,
470 | help="path to the desired checkpoint folder, e.g., path/checkpoint-12")
471 | parser.add_argument(
472 | "output_file",
473 | type=str,
474 | help=
475 | "path to the pytorch fp32 state_dict output file (e.g. path/checkpoint-12/pytorch_model.bin)"
476 | )
477 | parser.add_argument(
478 | "tag",
479 | type=str,
480 | help=
481 | "checkpoint tag used as a unique identifier for checkpoint (e.g. global_step14)"
482 | )
483 | parser.add_argument("-d", "--debug", action='store_true', help="enable debug")
484 | args = parser.parse_args()
485 |
486 | debug = args.debug
487 |
488 | print(args.checkpoint_dir)
489 | print(args.output_file)
490 | print(args.tag)
491 |
492 | convert_zero_checkpoint_to_fp32_state_dict(args.checkpoint_dir, args.output_file, args.tag)
493 |
--------------------------------------------------------------------------------
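A quick usage sketch for the script above: it can be driven offline through the CLI wired up in the ``__main__`` block, or in-process via ``load_state_dict_from_zero_checkpoint`` right after training. The checkpoint path ``path/checkpoint-12``, the tag ``global_step14``, and the ``model``/``checkpoint_dir`` variables below are placeholders, not values defined by this repo.

    # Offline: consolidate the sharded ZeRO checkpoint into a single fp32 file
    #   python zero_to_fp32.py path/checkpoint-12 path/checkpoint-12/pytorch_model.bin global_step14

    # In-process: load the consolidated fp32 weights back into the live model
    from zero_to_fp32 import load_state_dict_from_zero_checkpoint

    model = load_state_dict_from_zero_checkpoint(model, checkpoint_dir, tag="global_step14")
    # After this call the model holds fp32 weights on CPU and is detached from DeepSpeed,
    # so the DeepSpeed engine must be re-initialized before any further DeepSpeed training.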
/models/VLE/__init__.py:
--------------------------------------------------------------------------------
1 | from .modeling_vle import (
2 | VLEModel,
3 | VLEForVQA,
4 | VLEForITM,
5 | VLEForMLM,
6 | VLEForPBC,
7 | VLEForVCRQ2A,
8 | VLEForVCRQA2R
9 | )
10 |
11 | from .configuration_vle import VLEConfig
12 | from .processing_vle import VLEProcessor
13 | from .pipeline_vle import VLEForVQAPipeline, VLEForITMPipeline, VLEForPBCPipeline, VLEForVCRQA2RPipeline, VLEForVCRQ2APipeline
14 |
--------------------------------------------------------------------------------
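For reference, the exports above cover the public surface of the package; a minimal import sketch, assuming the repository root is on the Python path (that layout is an assumption, not enforced by the code):

    from models.VLE import (
        VLEModel, VLEForVQA, VLEForITM, VLEForMLM, VLEForPBC,     # base model and task heads
        VLEConfig, VLEProcessor,                                  # configuration and processor
        VLEForVQAPipeline, VLEForITMPipeline, VLEForPBCPipeline,  # inference pipelines
    )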
/models/VLE/configuration_vle.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright The HuggingFace Inc. team. All rights reserved.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | """ VLE model configuration"""
16 |
17 | import copy
18 |
19 | from transformers.configuration_utils import PretrainedConfig
20 | from transformers.utils import logging
21 | from transformers.models.auto.configuration_auto import AutoConfig
22 | from transformers.models.clip.configuration_clip import CLIPVisionConfig
23 | from typing import Union, Dict
24 |
25 | logger = logging.get_logger(__name__)
26 |
27 |
28 | class VLEConfig(PretrainedConfig):
29 |
30 | model_type = "vle"
31 | is_composition = True
32 |
33 | def __init__(
34 | self,
35 | text_config: Union[PretrainedConfig, Dict],
36 | vision_config: Union[PretrainedConfig, Dict],
37 | num_token_types=2,
38 | hidden_size=768,
39 | num_hidden_layers=6,
40 | num_attention_heads=12,
41 | intermediate_size=3072,
42 | hidden_act="gelu",
43 | hidden_dropout_prob=0.1,
44 | attention_probs_dropout_prob=0.1,
45 | initializer_range=0.02,
46 | layer_norm_eps=1e-12,
47 | classifier_dropout=None,
48 | **kwargs):
49 | super().__init__(**kwargs)
50 |
51 | if not isinstance(text_config,PretrainedConfig):
52 | text_model_type = text_config.pop('model_type')
53 | text_config = AutoConfig.for_model(text_model_type, **text_config)
54 | self.text_config = text_config
55 |
56 | if not isinstance(vision_config, PretrainedConfig):
57 | vision_model_type = vision_config.pop('model_type')
58 | if vision_model_type == "clip":
59 | vision_config = AutoConfig.for_model(vision_model_type, **vision_config).vision_config
60 | elif vision_model_type == "clip_vision_model":
61 | vision_config = CLIPVisionConfig(**vision_config)
62 | else:
63 | vision_config = AutoConfig.for_model(vision_model_type, **vision_config)
64 | self.vision_config = vision_config
65 | else:
66 | vision_model_type = vision_config.model_type
67 |             if vision_model_type == "clip":
68 | vision_config = vision_config.vision_config
69 | self.vision_config = vision_config
70 |
71 |
72 |
73 | # co-attention
74 | self.num_token_types=num_token_types
75 | self.hidden_size=hidden_size
76 | self.num_hidden_layers=num_hidden_layers
77 | self.num_attention_heads=num_attention_heads
78 | self.intermediate_size=intermediate_size
79 | self.hidden_act=hidden_act
80 | self.hidden_dropout_prob=hidden_dropout_prob
81 | self.attention_probs_dropout_prob=attention_probs_dropout_prob
82 | self.initializer_range=initializer_range
83 | self.layer_norm_eps=layer_norm_eps
84 | self.classifier_dropout=classifier_dropout
85 |
86 |
87 | def to_dict(self):
88 | """
89 | Serializes this instance to a Python dictionary. Override the default [`~PretrainedConfig.to_dict`].
90 |
91 | Returns:
92 |         `Dict[str, any]`: Dictionary of all the attributes that make up this configuration instance.
93 | """
94 | output = copy.deepcopy(self.__dict__)
95 | output["vision_config"] = self.vision_config.to_dict()
96 | output["text_config"] = self.text_config.to_dict()
97 | output["model_type"] = self.__class__.model_type
98 | return output
99 |
--------------------------------------------------------------------------------
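``VLEConfig`` accepts either ``PretrainedConfig`` objects or plain dicts for the two branches; a dict must carry a ``model_type`` key so the constructor can resolve the right config class, and a full ``clip`` config is reduced to its ``vision_config``. A minimal sketch of both paths; the checkpoint identifiers are generic Hugging Face models used only for illustration, not weights tied to this repo:

    from transformers import AutoConfig
    from models.VLE import VLEConfig  # assumes the repository root is on the Python path

    # Path 1: from instantiated configs (the CLIP config is reduced to its vision part)
    text_cfg = AutoConfig.from_pretrained("microsoft/deberta-v3-base")
    vision_cfg = AutoConfig.from_pretrained("openai/clip-vit-base-patch16")
    config = VLEConfig(text_config=text_cfg, vision_config=vision_cfg, num_hidden_layers=6)

    # Path 2: from plain dicts, each carrying a "model_type" key
    config = VLEConfig(
        text_config={"model_type": "bert", "hidden_size": 768},
        vision_config={"model_type": "clip_vision_model", "hidden_size": 768},
    )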
/models/VLE/modeling_vle.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2021 The HuggingFace Inc. team. All rights reserved.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | """ PyTorch VLE model."""
16 |
17 |
18 | from typing import Optional, Tuple, Union
19 |
20 | import torch
21 | from torch import nn
22 | import torch.nn.functional as F
23 | from torch.nn.utils.rnn import pad_sequence
24 |
25 | from transformers.modeling_utils import PreTrainedModel
26 | from transformers.utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings, ModelOutput
27 | from transformers.models.auto.configuration_auto import AutoConfig
28 | from transformers.models.auto.modeling_auto import AutoModel
29 |
30 | from transformers.models.bert.modeling_bert import BertAttention, BertIntermediate, BertOutput, apply_chunking_to_forward
31 | from transformers.models.clip.modeling_clip import CLIPOutput, CLIPVisionConfig, CLIPVisionModel
32 | from transformers.models.deberta_v2.modeling_deberta_v2 import DebertaV2OnlyMLMHead
33 | from .configuration_vle import VLEConfig
34 | from dataclasses import dataclass
35 |
36 | logger = logging.get_logger(__name__)
37 |
38 | _CONFIG_FOR_DOC = "VLEConfig"
39 |
40 |
41 | @dataclass
42 | class VLEModelOutput(ModelOutput):
43 |
44 | pooler_output: torch.FloatTensor = None
45 | text_embeds: torch.FloatTensor = None
46 | image_embeds: torch.FloatTensor = None
47 |
48 |
49 | @dataclass
50 | class VLEForITMOutput(ModelOutput):
51 |
52 | loss: torch.FloatTensor = None
53 | logits: torch.FloatTensor = None
54 |
55 | @dataclass
56 | class VLEForPBCOutput(ModelOutput):
57 |
58 | loss: torch.FloatTensor = None
59 | logits: torch.FloatTensor = None
60 |
61 | @dataclass
62 | class VLEForMLMOutput(ModelOutput):
63 |
64 | loss: torch.FloatTensor = None
65 | logits: torch.FloatTensor = None
66 |
67 | @dataclass
68 | class VLEForVQAOutput(ModelOutput):
69 |
70 | loss : torch.FloatTensor = None
71 | logits: torch.FloatTensor = None
72 |
73 | @dataclass
74 | class VLEForVCRQ2AOutput(ModelOutput):
75 |
76 | loss : torch.FloatTensor = None
77 | logits: torch.FloatTensor = None
78 |
79 | @dataclass
80 | class VLEForVCRQA2ROutput(ModelOutput):
81 |
82 | loss : torch.FloatTensor = None
83 | logits: torch.FloatTensor = None
84 |
85 | class ITMHead(nn.Module):
86 | def __init__(self, hidden_size):
87 | super().__init__()
88 | self.fc = nn.Linear(hidden_size, 2)
89 |
90 | def forward(self, x):
91 | x = self.fc(x)
92 | return x
93 |
94 |
95 | def extend_position_embedding(state_dict, patch_size, after):
96 | """
97 |     Modify state_dict in place for longer position embeddings: interpolate the vision position embedding to the grid implied by the new image size `after` and `patch_size`.
98 | """
99 | keys = {}
100 | for k,v in state_dict.items():
101 | if k.endswith('vision_model.embeddings.position_embedding.weight'):
102 | assert k not in keys
103 | keys['pe'] = (k,v)
104 | if k.endswith('vision_model.embeddings.position_ids'):
105 | assert k not in keys
106 | keys['pi'] = (k,v)
107 |
108 | pe_weight = keys['pe'][1]
109 | position_length_before = pe_weight.shape[0]
110 | embed_dim = pe_weight.shape[1]
111 | grid_before = int((position_length_before - 1)**(1/2))
112 | position_length_after = (after // patch_size) ** 2 + 1
113 | grid_after = int((position_length_after - 1)**(1/2))
114 |
115 | new_pe_weight = pe_weight[1:].reshape((grid_before,grid_before,-1))
116 | new_pe_weight = torch.nn.functional.interpolate(
117 | new_pe_weight.permute(2,0,1).unsqueeze(0),
118 | size = (grid_after,grid_after), mode = 'bicubic')
119 | new_pe_weight = new_pe_weight.squeeze(0).permute(1,2,0).reshape(grid_after*grid_after, -1)
120 | new_pe_weight = torch.cat((pe_weight[0:1],new_pe_weight), dim=0)
121 | assert new_pe_weight.shape == (grid_after*grid_after + 1, embed_dim)
122 |
123 | state_dict[keys['pe'][0]] = new_pe_weight
124 | state_dict[keys['pi'][0]] = torch.arange(grid_after*grid_after + 1).unsqueeze(0)
125 | return state_dict
126 |
127 |
128 | class Pooler(nn.Module):
129 | def __init__(self, hidden_size):
130 | super().__init__()
131 | self.dense = nn.Linear(hidden_size, hidden_size)
132 | self.activation = nn.Tanh()
133 |
134 | def forward(self, hidden_states):
135 | first_token_tensor = hidden_states[:, 0]
136 | pooled_output = self.dense(first_token_tensor)
137 | pooled_output = self.activation(pooled_output)
138 | return pooled_output
139 |
140 |
141 | class BertCrossLayer(nn.Module):
142 | def __init__(self, config):
143 | super().__init__()
144 | self.chunk_size_feed_forward = config.chunk_size_feed_forward
145 | self.seq_len_dim = 1
146 | self.attention = BertAttention(config)
147 | self.is_decoder = config.is_decoder
148 | self.add_cross_attention = config.add_cross_attention
149 | self.crossattention = BertAttention(config)
150 | self.intermediate = BertIntermediate(config)
151 | self.output = BertOutput(config)
152 |
153 | def forward(
154 | self,
155 | hidden_states,
156 | encoder_hidden_states,
157 | attention_mask=None,
158 | encoder_attention_mask=None,
159 | output_attentions=False,
160 | ):
161 | # decoder uni-directional self-attention cached key/values tuple is at positions 1,2
162 | self_attn_past_key_value = None #past_key_value[:2] if past_key_value is not None else None
163 | self_attention_outputs = self.attention(
164 | hidden_states,
165 | attention_mask,
166 | head_mask=None,
167 | output_attentions=output_attentions,
168 | past_key_value=None,
169 | )
170 | attention_output = self_attention_outputs[0]
171 |
172 | # if decoder, the last output is tuple of self-attn cache
173 | outputs = self_attention_outputs[1:] # add self attentions if we output attention weights
174 |
175 | cross_attn_present_key_value = None
176 | cross_attention_outputs = self.crossattention(
177 | attention_output,
178 | attention_mask,
179 | None,
180 | encoder_hidden_states,
181 | encoder_attention_mask,
182 | None,
183 | output_attentions,
184 | )
185 | attention_output = cross_attention_outputs[0]
186 | outputs = outputs + cross_attention_outputs[1:] # add cross attentions if we output attention weights
187 |
188 | layer_output = apply_chunking_to_forward(
189 | self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output
190 | )
191 | outputs = (layer_output,) + outputs
192 |
193 | return outputs
194 |
195 | def feed_forward_chunk(self, attention_output):
196 | intermediate_output = self.intermediate(attention_output)
197 | layer_output = self.output(intermediate_output, attention_output)
198 | return layer_output
199 |
200 |
201 | class VLEPreTrainedModel(PreTrainedModel):
202 | """
203 | An abstract class to handle weights initialization.
204 | """
205 |
206 | config_class = VLEConfig
207 | base_model_prefix = "vle"
208 | supports_gradient_checkpointing = False
209 | _keys_to_ignore_on_load_missing = [r"position_ids"]
210 |
211 | def _init_weights(self, module):
212 | """Initialize the weights"""
213 | if isinstance(module, nn.Linear):
214 | module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
215 | if module.bias is not None:
216 | module.bias.data.zero_()
217 | elif isinstance(module, nn.Embedding):
218 | module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
219 | if module.padding_idx is not None:
220 | module.weight.data[module.padding_idx].zero_()
221 | elif isinstance(module, nn.LayerNorm):
222 | module.bias.data.zero_()
223 | module.weight.data.fill_(1.0)
224 |     # not supported
225 | # def _set_gradient_checkpointing(self, module, value=False):
226 | # if isinstance(module, BertEncoder):
227 | # module.gradient_checkpointing = value
228 |
229 | class VLEModel(VLEPreTrainedModel):
230 | def __init__(
231 | self,
232 | config: Optional[VLEConfig] = None,
233 | vision_model: Optional[PreTrainedModel] = None,
234 | text_model: Optional[PreTrainedModel] = None,
235 | ):
236 |
237 | if config is None and (vision_model is None or text_model is None):
238 |             raise ValueError("Either a configuration or a vision and a text model has to be provided")
239 |
240 | if config is None:
241 | config = VLEConfig(text_config=text_model.config, vision_config=vision_model.config)
242 | else:
243 | if not isinstance(config, self.config_class):
244 | raise ValueError(f"config: {config} has to be of type {self.config_class}")
245 |
246 | # initialize with config
247 | super().__init__(config)
248 |
249 | if vision_model is None:
250 | if isinstance(config.vision_config, CLIPVisionConfig):
251 | vision_model = CLIPVisionModel(config.vision_config)
252 | else:
253 | vision_model = AutoModel.from_config(config.vision_config)
254 |
255 | if text_model is None:
256 | text_model = AutoModel.from_config(config.text_config)
257 |
258 | self.vision_model = vision_model
259 | self.text_model = text_model
260 |
261 | # make sure that the individual model's config refers to the shared config
262 | # so that the updates to the config will be synced
263 | self.vision_model.config = self.config.vision_config
264 | self.text_model.config = self.config.text_config
265 |
266 | self.vision_embed_dim = config.vision_config.hidden_size
267 | self.text_embed_dim = config.text_config.hidden_size
268 | self.coattention_dim = config.hidden_size
269 |
270 | # add projection layers
271 | self.text_projection_layer = nn.Linear(self.text_embed_dim, self.coattention_dim)
272 | self.image_projection_layer = nn.Linear(self.vision_embed_dim, self.coattention_dim)
273 |
274 | #self.logit_scale = nn.Parameter(torch.ones([]) * self.config.logit_scale_init_value)
275 | self.token_type_embeddings = nn.Embedding(config.num_token_types, config.hidden_size)
276 |
277 | self.cross_modal_image_layers = nn.ModuleList([BertCrossLayer(config) for _ in range(config.num_hidden_layers)])
278 | self.cross_modal_text_layers = nn.ModuleList([BertCrossLayer(config) for _ in range(config.num_hidden_layers)])
279 | self.cross_modal_image_pooler = Pooler(config.hidden_size)
280 | self.cross_modal_text_pooler = Pooler(config.hidden_size)
281 |
282 | # Initialize weights and apply final processing
283 | self.token_type_embeddings.apply(self._init_weights)
284 | self.cross_modal_image_layers.apply(self._init_weights)
285 | self.cross_modal_text_layers.apply(self._init_weights)
286 | self.cross_modal_image_pooler.apply(self._init_weights)
287 | self.cross_modal_text_pooler.apply(self._init_weights)
288 | if hasattr(self,"text_projection_layer"):
289 | self.text_projection_layer.apply(self._init_weights)
290 | if hasattr(self,"image_projection_layer"):
291 | self.image_projection_layer.apply(self._init_weights)
292 |
293 |
294 | def forward(
295 | self,
296 | input_ids: Optional[torch.LongTensor] = None,
297 | pixel_values: Optional[torch.FloatTensor] = None,
298 | attention_mask: Optional[torch.Tensor] = None,
299 | position_ids: Optional[torch.LongTensor] = None,
300 | token_type_ids: Optional[torch.LongTensor] = None,
301 | patch_ids = None,
302 | extend_token_type_ids = None,
303 | return_loss: Optional[bool] = None,
304 | return_dict: Optional[bool] = None,
305 | ) -> Union[Tuple[torch.Tensor], VLEModelOutput]:
306 |
307 | return_dict = return_dict if return_dict is not None else self.config.return_dict
308 |
309 | vision_outputs = self.vision_model(
310 | pixel_values=pixel_values,
311 | return_dict=return_dict,
312 | )
313 |
314 | text_outputs = self.text_model(
315 | input_ids=input_ids,
316 | attention_mask=attention_mask,
317 | token_type_ids=token_type_ids,
318 | position_ids=position_ids,
319 | return_dict=return_dict,
320 | )
321 |
322 | image_embeds = self.vision_model.vision_model.post_layernorm(vision_outputs[0]) # last_hidden_state
323 | image_embeds = self.image_projection_layer(image_embeds)
324 |
325 | text_embeds = text_outputs[0] # last_hidden_state
326 | text_embeds = self.text_projection_layer(text_embeds)
327 |
328 | if patch_ids is not None:
329 | # add box image embeddings (mean)
330 | # image_embeds : batch_size * (num_patch+1) * dims
331 | image_embeds_size_1 = image_embeds.size(1)
332 | new_image_embeds = []
333 | for item_image_embeds, item_patch_ids in zip(image_embeds, patch_ids):
334 | add_item_image_embeds = []
335 | for i_, box_patch_ids in enumerate(item_patch_ids):
336 | # skip cls embedding
337 | box_image_embeds = item_image_embeds[torch.as_tensor(box_patch_ids) + 1]
338 | box_image_embeds = torch.mean(box_image_embeds, dim=0, keepdim=True)
339 | add_item_image_embeds.append(box_image_embeds)
340 | new_image_embeds.append(torch.cat([item_image_embeds] + add_item_image_embeds))
341 | image_embeds = pad_sequence(new_image_embeds, batch_first=True)
342 |
343 | len_of_ones = torch.as_tensor([len(box_p_ids) for box_p_ids in patch_ids], dtype=torch.int64, device=image_embeds.device) + image_embeds_size_1
344 | image_mask_ones = [torch.ones(l, dtype=torch.long, device=image_embeds.device) for l in len_of_ones]
345 | image_mask_zeros = [torch.zeros(image_embeds.size(1) - l, dtype=torch.long, device=image_embeds.device) for l in len_of_ones]
346 | image_masks = torch.cat([torch.cat([one, zero]).unsqueeze(0) for one, zero in zip(image_mask_ones, image_mask_zeros)], dim=0)
347 |
348 | text_token_type_ids = pad_sequence(list(map(lambda x: torch.as_tensor(x, device=image_embeds.device, dtype=torch.long),extend_token_type_ids[1])),batch_first=True)
349 | text_token_type_ids = torch.cat([text_token_type_ids, torch.zeros(text_embeds.size(0), text_embeds.size(1) - text_token_type_ids.size(1), device=image_embeds.device, dtype=torch.long)], dim=1)
350 | image_added_token_type_ids = pad_sequence(list(map(lambda x: torch.as_tensor(x, device=image_embeds.device, dtype=torch.long),extend_token_type_ids[0])),batch_first=True)
351 | image_token_type_ids = torch.cat([torch.ones(image_embeds.size(0), image_embeds_size_1, dtype=torch.long, device=image_embeds.device), image_added_token_type_ids], dim=1)
352 | else:
353 | image_masks = torch.ones((image_embeds.size(0), image_embeds.size(1)), dtype=torch.long, device=image_embeds.device)
354 | extend_image_masks = self.text_model.get_extended_attention_mask(image_masks, image_masks.size())
355 | extend_text_masks = self.text_model.get_extended_attention_mask(attention_mask, attention_mask.size())
356 |
357 | if patch_ids is not None and extend_token_type_ids is not None:
358 | text_embeds = text_embeds + self.token_type_embeddings(text_token_type_ids)
359 | image_embeds = image_embeds + self.token_type_embeddings(image_token_type_ids)
360 | else:
361 | image_embeds = image_embeds + self.token_type_embeddings(torch.full_like(image_masks, 1))
362 | text_embeds = text_embeds + self.token_type_embeddings(torch.zeros_like(attention_mask))
363 |
364 | x, y = text_embeds, image_embeds
365 | for text_layer, image_layer in zip(self.cross_modal_text_layers, self.cross_modal_image_layers):
366 | x1 = text_layer(x, y, extend_text_masks, extend_image_masks)
367 | y1 = image_layer(y, x, extend_image_masks, extend_text_masks)
368 | x, y = x1[0], y1[0]
369 |
370 | text_embeds, image_embeds = x, y
371 | text_pooler_output = self.cross_modal_text_pooler(x)
372 | image_pooler_output = self.cross_modal_image_pooler(y)
373 | pooler_output = torch.cat([text_pooler_output, image_pooler_output], dim=-1)
374 |
375 | if not return_dict:
376 | output = (pooler_output, text_embeds, image_embeds)
377 | return output
378 | return VLEModelOutput(
379 | pooler_output = pooler_output,
380 | text_embeds = text_embeds,
381 | image_embeds = image_embeds
382 | )
383 |
384 |
385 | @classmethod
386 | def from_pretrained(cls, *args, **kwargs):
387 | # At the moment fast initialization is not supported
388 | # for composite models
389 | kwargs["_fast_init"] = False
390 | return super().from_pretrained(*args, **kwargs)
391 |
392 | @classmethod
393 | def from_vision_text_pretrained(
394 | cls,
395 | vision_model_name_or_path: str = None,
396 | text_model_name_or_path: str = None,
397 | *model_args,
398 | **kwargs,
399 | ) -> PreTrainedModel:
400 |
401 | kwargs_vision = {
402 | argument[len("vision_") :]: value for argument, value in kwargs.items() if argument.startswith("vision_")
403 | }
404 |
405 | kwargs_text = {
406 | argument[len("text_") :]: value for argument, value in kwargs.items() if argument.startswith("text_")
407 | }
408 |
409 | # remove vision, text kwargs from kwargs
410 | for key in kwargs_vision.keys():
411 | del kwargs["vision_" + key]
412 | for key in kwargs_text.keys():
413 | del kwargs["text_" + key]
414 |
415 | # Load and initialize the vision and text model
416 | vision_model = kwargs_vision.pop("model", None)
417 | if vision_model is None:
418 | if vision_model_name_or_path is None:
419 | raise ValueError(
420 | "If `vision_model` is not defined as an argument, a `vision_model_name_or_path` has to be defined"
421 | )
422 |
423 | if "config" not in kwargs_vision:
424 | vision_config = AutoConfig.from_pretrained(vision_model_name_or_path)
425 |
426 | if vision_config.model_type == "clip":
427 | kwargs_vision["config"] = vision_config.vision_config
428 | vision_model = CLIPVisionModel.from_pretrained(vision_model_name_or_path, *model_args, **kwargs_vision)
429 | else:
430 | kwargs_vision["config"] = vision_config
431 | vision_model = AutoModel.from_pretrained(vision_model_name_or_path, *model_args, **kwargs_vision)
432 |
433 | text_model = kwargs_text.pop("model", None)
434 | if text_model is None:
435 | if text_model_name_or_path is None:
436 | raise ValueError(
437 | "If `text_model` is not defined as an argument, a `text_model_name_or_path` has to be defined"
438 | )
439 |
440 | if "config" not in kwargs_text:
441 | text_config = AutoConfig.from_pretrained(text_model_name_or_path)
442 | kwargs_text["config"] = text_config
443 |
444 | text_model = AutoModel.from_pretrained(text_model_name_or_path, *model_args, **kwargs_text)
445 |
446 | # instantiate config with corresponding kwargs
447 | config = VLEConfig(text_config=text_model.config, vision_config=vision_model.config, **kwargs)
448 |
449 | # init model
450 | model = cls(config=config, vision_model=vision_model, text_model=text_model)
451 |
452 | # the projection layers are always newly initialized when loading the model
453 | # using pre-trained vision and text model.
454 | logger.warning(
455 | "The coattention layers and projection layers are newly initialized. You should probably TRAIN this model on a down-stream task to be"
456 | " able to use it for predictions and inference."
457 | )
458 | return model
459 |
460 |
461 | def get_text_features(
462 | self,
463 | input_ids=None,
464 | attention_mask=None,
465 | position_ids=None,
466 | token_type_ids=None,
467 | output_attentions=None,
468 | output_hidden_states=None,
469 | return_dict=None,
470 | ):
471 | text_outputs = self.text_model(
472 | input_ids=input_ids,
473 | attention_mask=attention_mask,
474 | position_ids=position_ids,
475 | token_type_ids=token_type_ids,
476 | #output_attentions=output_attentions,
477 | #output_hidden_states=output_hidden_states,
478 | return_dict=return_dict,
479 | )
480 | return text_outputs[0] # last_hidden_state
481 |
482 | def get_image_features(
483 | self,
484 | pixel_values=None,
485 | output_attentions=None,
486 | output_hidden_states=None,
487 | return_dict=None,
488 | ):
489 | r"""
490 | Returns:
491 |             image_features (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): The image embeddings obtained by
492 |             applying the post-layernorm to the last hidden state of the vision model.
493 |
494 | Examples:
495 |
496 | ```python
497 | >>> from PIL import Image
498 | >>> import requests
499 | >>> from transformers import VLEModel, AutoImageProcessor
500 |
501 | >>> model = VLEModel.from_pretrained("clip-italian/clip-italian")
502 | >>> image_processor = AutoImageProcessor.from_pretrained("google/vit-base-patch16-224")
503 |
504 | >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
505 | >>> image = Image.open(requests.get(url, stream=True).raw)
506 |
507 | >>> inputs = image_processor(images=image, return_tensors="pt")
508 |
509 | >>> image_features = model.get_image_features(**inputs)
510 | ```"""
511 | vision_outputs = self.vision_model(
512 | pixel_values=pixel_values,
513 | #output_attentions=output_attentions,
514 | #output_hidden_states=output_hidden_states,
515 | return_dict=return_dict,
516 | )
517 | last_hidden_state = self.vision_model.vision_model.post_layernorm(vision_outputs[0])
518 | return last_hidden_state
519 | def get_input_embeddings(self):
520 | return self.text_model.embeddings.word_embeddings
521 |
522 | def set_input_embeddings(self, new_embeddings):
523 | self.text_model.embeddings.word_embeddings = new_embeddings
524 |
525 | class VLEForVQA(VLEPreTrainedModel):
526 | def __init__(
527 | self,
528 | config: Optional[VLEConfig] = None,
529 | vision_model: Optional[PreTrainedModel] = None,
530 | text_model: Optional[PreTrainedModel] = None,
531 | ):
532 | super().__init__(config)
533 | self.vle = VLEModel(config, vision_model, text_model)
534 |
535 | hidden_size = config.hidden_size
536 | self.num_vqa_labels = len(self.config.id2label)
537 | self.vqa_classifier = nn.Sequential(
538 | nn.Linear(hidden_size * 2, hidden_size * 2),
539 | nn.LayerNorm(hidden_size * 2),
540 | nn.GELU(),
541 | nn.Linear(hidden_size * 2, self.num_vqa_labels),
542 | )
543 | self.vqa_classifier.apply(self._init_weights)
544 |
545 | def forward(self,
546 | input_ids: Optional[torch.LongTensor],
547 | pixel_values: Optional[torch.FloatTensor],
548 | attention_mask: Optional[torch.Tensor] = None,
549 | position_ids: Optional[torch.LongTensor] = None,
550 | token_type_ids: Optional[torch.LongTensor] = None,
551 | patch_ids = None,
552 | vqa_labels = None,
553 | vqa_scores = None,
554 | return_loss: Optional[bool] = None,
555 | return_dict: Optional[bool] = None,
556 | ) -> Union[Tuple[torch.Tensor], VLEForVQAOutput]:
557 |
558 | return_dict = return_dict if return_dict is not None else self.config.return_dict
559 |
560 | vle_output = self.vle(
561 | input_ids = input_ids,
562 | pixel_values = pixel_values,
563 | attention_mask = attention_mask,
564 | position_ids = position_ids,
565 | token_type_ids = token_type_ids,
566 | patch_ids = patch_ids,)
567 | pooler_output = vle_output[0]
568 | vqa_logits = self.vqa_classifier(pooler_output)
569 |
570 |
571 | vqa_loss = None
572 | if return_loss and vqa_labels is not None and vqa_scores is not None:
573 | vqa_targets = torch.zeros(len(vqa_logits), self.num_vqa_labels,device=vqa_logits.device)
574 | for i, (_label, _score) in enumerate(zip(vqa_labels, vqa_scores)):
575 | for l, s in zip(_label, _score):
576 | vqa_targets[i, l] = s
577 | vqa_loss = F.binary_cross_entropy_with_logits(vqa_logits, vqa_targets) * vqa_targets.shape[1]
578 | # https://github.com/jnhwkim/ban-vqa/blob/master/train.py#L19
579 |
580 | if not return_dict:
581 | output = (vqa_logits,)
582 | return ((vqa_loss,) + output) if vqa_loss is not None else output
583 | return VLEForVQAOutput(
584 | loss = vqa_loss,
585 | logits = vqa_logits
586 | )
587 |
588 |
589 | class VLEForITM(VLEPreTrainedModel):
590 | def __init__(
591 | self,
592 | config: Optional[VLEConfig] = None,
593 | vision_model: Optional[PreTrainedModel] = None,
594 | text_model: Optional[PreTrainedModel] = None,
595 | ):
596 | super().__init__(config)
597 | self.vle = VLEModel(config, vision_model, text_model)
598 |
599 | hidden_size = config.hidden_size
600 | self.itm_score = ITMHead(hidden_size*2)
601 | self.itm_score.apply(self._init_weights)
602 |
603 | def forward(self,
604 | input_ids: Optional[torch.LongTensor],
605 | pixel_values: Optional[torch.FloatTensor],
606 | attention_mask: Optional[torch.Tensor] = None,
607 | position_ids: Optional[torch.LongTensor] = None,
608 | token_type_ids: Optional[torch.LongTensor] = None,
609 | patch_ids = None,
610 | itm_labels = None,
611 | return_loss: Optional[bool] = None,
612 | return_dict: Optional[bool] = None,
613 | ) -> Union[Tuple[torch.Tensor], VLEForITMOutput]:
614 |
615 | return_dict = return_dict if return_dict is not None else self.config.return_dict
616 |
617 | vle_output = self.vle(
618 | input_ids = input_ids,
619 | pixel_values = pixel_values,
620 | attention_mask = attention_mask,
621 | position_ids = position_ids,
622 | token_type_ids = token_type_ids,
623 | patch_ids = patch_ids,)
624 | pooler_output = vle_output[0]
625 |
626 | itm_logits = self.itm_score(pooler_output)
627 | itm_loss = None
628 | if return_loss and itm_labels is not None:
629 | itm_loss = nn.functional.cross_entropy(itm_logits, torch.tensor(itm_labels).long().to(itm_logits.device))
630 | if not return_dict:
631 | output = (itm_logits,)
632 | return ((itm_loss,) + output) if itm_loss is not None else output
633 | return VLEForITMOutput(loss = itm_loss, logits = itm_logits)
634 |
635 |
636 | class VLEForPBC(VLEPreTrainedModel):
637 | def __init__(
638 | self,
639 | config: Optional[VLEConfig] = None,
640 | vision_model: Optional[PreTrainedModel] = None,
641 | text_model: Optional[PreTrainedModel] = None,
642 | ):
643 | super().__init__(config)
644 | self.vle = VLEModel(config, vision_model, text_model)
645 |
646 | hidden_size = config.hidden_size
647 | self.pbc_classifier = nn.Sequential(
648 | nn.Linear(hidden_size, hidden_size),
649 | nn.LayerNorm(hidden_size),
650 | nn.GELU(),
651 | nn.Linear(hidden_size, 2),
652 | )
653 | self.pbc_classifier.apply(self._init_weights)
654 |
655 | def forward(self,
656 | input_ids: Optional[torch.LongTensor],
657 | pixel_values: Optional[torch.FloatTensor],
658 | attention_mask: Optional[torch.Tensor] = None,
659 | position_ids: Optional[torch.LongTensor] = None,
660 | token_type_ids: Optional[torch.LongTensor] = None,
661 | patch_ids = None,
662 | pbc_labels = None,
663 | return_loss: Optional[bool] = None,
664 | return_dict: Optional[bool] = None,
665 | ) -> Union[Tuple[torch.Tensor], VLEForPBCOutput]:
666 |
667 | return_dict = return_dict if return_dict is not None else self.config.return_dict
668 |
669 | vle_output = self.vle(
670 | input_ids = input_ids,
671 | pixel_values = pixel_values,
672 | attention_mask = attention_mask,
673 | position_ids = position_ids,
674 | token_type_ids = token_type_ids,
675 | patch_ids = patch_ids,)
676 | image_embeds = vle_output['image_embeds']
677 | pbc_logits = self.pbc_classifier(image_embeds[:,1:,:])
678 |
679 | pbc_loss = None
680 | if return_loss and pbc_labels is not None:
681 | pbc_loss = F.cross_entropy(pbc_logits, torch.tensor(pbc_labels).long().to(pbc_logits.device))
682 |
683 | if not return_dict:
684 | output = (pbc_logits,)
685 | return ((pbc_loss,) + output) if pbc_loss is not None else output
686 | return VLEForPBCOutput(loss = pbc_loss, logits = pbc_logits)
687 |
688 |
689 | class VLEForMLM(VLEPreTrainedModel):
690 | _keys_to_ignore_on_load_missing = [r"mlm_score.1.predictions.decoder.weight",r"mlm_score.1.predictions.decoder.bias"]
691 | def __init__(
692 | self,
693 | config: Optional[VLEConfig] = None,
694 | vision_model: Optional[PreTrainedModel] = None,
695 | text_model: Optional[PreTrainedModel] = None,
696 | ):
697 | super().__init__(config)
698 | self.vle = VLEModel(config, vision_model, text_model)
699 |
700 | hidden_size = config.hidden_size
701 | mlm_head = DebertaV2OnlyMLMHead(self.config.text_config)
702 | mlm_transform = nn.Linear(hidden_size, self.config.text_config.hidden_size)
703 | self.mlm_score = nn.Sequential(
704 | mlm_transform,
705 | mlm_head,
706 | )
707 |
708 | def forward(self,
709 | input_ids: Optional[torch.LongTensor],
710 | pixel_values: Optional[torch.FloatTensor],
711 | attention_mask: Optional[torch.Tensor] = None,
712 | position_ids: Optional[torch.LongTensor] = None,
713 | token_type_ids: Optional[torch.LongTensor] = None,
714 | patch_ids = None,
715 | mlm_labels = None,
716 | return_loss: Optional[bool] = None,
717 | return_dict: Optional[bool] = None,
718 | ) -> Union[Tuple[torch.Tensor], VLEForMLMOutput]:
719 |
720 | return_dict = return_dict if return_dict is not None else self.config.return_dict
721 |
722 | vle_output = self.vle(
723 | input_ids = input_ids,
724 | pixel_values = pixel_values,
725 | attention_mask = attention_mask,
726 | position_ids = position_ids,
727 | token_type_ids = token_type_ids,
728 | patch_ids = patch_ids,)
729 | text_feats = vle_output.text_embeds
730 |
731 | mlm_logits = self.mlm_score(text_feats)
732 | mlm_loss = None
733 | if return_loss and mlm_labels is not None:
734 | mlm_loss = F.cross_entropy(
735 | mlm_logits.view(-1, self.config.text_config.vocab_size),
736 | mlm_labels.view(-1),
737 | ignore_index=-100,
738 | )
739 | if not return_dict:
740 | output = (mlm_logits,)
741 | return ((mlm_loss,) + output) if mlm_loss is not None else output
742 | return VLEForMLMOutput(loss = mlm_loss, logits = mlm_logits)
743 |
744 |
745 | def get_output_embeddings(self):
746 | return self.mlm_score[1].predictions.decoder
747 |
748 | def set_output_embeddings(self, new_embeddings):
749 | self.mlm_score[1].predictions.decoder = new_embeddings
750 |
751 |
752 | class VLEForVCRQ2A(VLEPreTrainedModel):
753 | def __init__(
754 | self,
755 | config: Optional[VLEConfig] = None,
756 | vision_model: Optional[PreTrainedModel] = None,
757 | text_model: Optional[PreTrainedModel] = None,
758 | ):
759 | super().__init__(config)
760 | self.vle = VLEModel(config, vision_model, text_model)
761 |
762 | hidden_size = config.hidden_size
763 | self.vcr_q2a_logit = nn.Sequential(
764 | nn.Linear(hidden_size * 2, hidden_size * 2),
765 | nn.LayerNorm(hidden_size * 2),
766 | nn.GELU(),
767 | nn.Linear(hidden_size * 2, 1),
768 | )
769 | self.vcr_q2a_logit.apply(self._init_weights)
770 |
771 | def forward(self,
772 | input_ids: Optional[torch.LongTensor],
773 | pixel_values: Optional[torch.FloatTensor],
774 | attention_mask: Optional[torch.Tensor] = None,
775 | position_ids: Optional[torch.LongTensor] = None,
776 | token_type_ids: Optional[torch.LongTensor] = None,
777 | patch_ids = None,
778 | vcr_labels = None,
779 | extend_token_type_ids = None,
780 | return_loss: Optional[bool] = None,
781 | return_dict: Optional[bool] = None,
782 | ) -> Union[Tuple[torch.Tensor], VLEForVQAOutput]:
783 |
784 | return_dict = return_dict if return_dict is not None else self.config.return_dict
785 |
786 | infers = []
787 | for i in range(4):
788 | vle_output = self.vle(
789 | input_ids = input_ids[i],
790 | pixel_values = pixel_values,
791 | attention_mask = attention_mask[i],
792 | position_ids = position_ids,
793 | token_type_ids = token_type_ids[i],
794 | patch_ids = patch_ids[i],
795 | extend_token_type_ids = extend_token_type_ids[i])
796 | pooler_output = vle_output[0]
797 | logits = self.vcr_q2a_logit(pooler_output)
798 | infers.append(logits)
799 |
800 | vcr_logits = torch.cat(infers, dim=-1)
801 | vcr_loss = None
802 | if return_loss and vcr_labels is not None:
803 | vcr_targets = torch.zeros(len(vcr_logits), dtype=torch.long).to(self.device)
804 | for i, _label in enumerate(vcr_labels):
805 | vcr_targets[i] = _label
806 | vcr_loss = F.cross_entropy(vcr_logits, vcr_targets.view(-1))
807 |
808 | if not return_dict:
809 | output = (vcr_logits,)
810 | return ((vcr_loss,) + output) if vcr_loss is not None else output
811 | return VLEForVCRQ2AOutput(
812 | loss = vcr_loss,
813 | logits = vcr_logits
814 | )
815 |
816 |
817 | class VLEForVCRQA2R(VLEPreTrainedModel):
818 | def __init__(
819 | self,
820 | config: Optional[VLEConfig] = None,
821 | vision_model: Optional[PreTrainedModel] = None,
822 | text_model: Optional[PreTrainedModel] = None,
823 | ):
824 | super().__init__(config)
825 | self.vle = VLEModel(config, vision_model, text_model)
826 |
827 | hidden_size = config.hidden_size
828 | self.vcr_qa2r_logit = nn.Sequential(
829 | nn.Linear(hidden_size * 2, hidden_size * 2),
830 | nn.LayerNorm(hidden_size * 2),
831 | nn.GELU(),
832 | nn.Linear(hidden_size * 2, 1),
833 | )
834 | self.vcr_qa2r_logit.apply(self._init_weights)
835 |
836 | def forward(self,
837 | input_ids: Optional[torch.LongTensor],
838 | pixel_values: Optional[torch.FloatTensor],
839 | attention_mask: Optional[torch.Tensor] = None,
840 | position_ids: Optional[torch.LongTensor] = None,
841 | token_type_ids: Optional[torch.LongTensor] = None,
842 | patch_ids = None,
843 | vcr_labels = None,
844 | extend_token_type_ids = None,
845 | return_loss: Optional[bool] = None,
846 | return_dict: Optional[bool] = None,
847 | ) -> Union[Tuple[torch.Tensor], VLEForVQAOutput]:
848 |
849 | return_dict = return_dict if return_dict is not None else self.config.return_dict
850 |
851 | infers = []
852 | for i in range(4):
853 | vle_output = self.vle(
854 | input_ids = input_ids[i],
855 | pixel_values = pixel_values,
856 | attention_mask = attention_mask[i],
857 | position_ids = position_ids,
858 | token_type_ids = token_type_ids[i],
859 | patch_ids = patch_ids[i],
860 | extend_token_type_ids = extend_token_type_ids[i])
861 | pooler_output = vle_output[0]
862 | logits = self.vcr_qa2r_logit(pooler_output)
863 | infers.append(logits)
864 |
865 | vcr_logits = torch.cat(infers, dim=-1)
866 | vcr_loss = None
867 | if return_loss and vcr_labels is not None:
868 | vcr_targets = torch.zeros(len(vcr_logits), dtype=torch.long).to(self.device)
869 | for i, _label in enumerate(vcr_labels):
870 | vcr_targets[i] = _label
871 | vcr_loss = F.cross_entropy(vcr_logits, vcr_targets.view(-1))
872 |
873 | if not return_dict:
874 | output = (vcr_logits,)
875 | return ((vcr_loss,) + output) if vcr_loss is not None else output
876 | return VLEForVCRQA2ROutput(
877 | loss = vcr_loss,
878 | logits = vcr_logits
879 | )
880 |
--------------------------------------------------------------------------------
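The classes above follow a dual-tower design: pretrained vision and text encoders plus newly initialized co-attention and projection layers, which is why ``from_vision_text_pretrained`` logs a warning that the assembled model still needs downstream training. A minimal sketch of building and running the base model; the two checkpoint names are generic Hugging Face models used only for illustration, not weights released with this repo:

    import torch
    from models.VLE import VLEModel  # assumes the repository root is on the Python path

    # Vision tower first, text tower second; co-attention and projection layers are randomly initialized
    model = VLEModel.from_vision_text_pretrained(
        "openai/clip-vit-base-patch16",   # CLIP vision encoder
        "microsoft/deberta-v3-base",      # text encoder
    )

    # Dummy inputs, only to show the expected shapes
    input_ids = torch.tensor([[1, 100, 2]])      # (batch, seq_len)
    attention_mask = torch.ones_like(input_ids)
    pixel_values = torch.randn(1, 3, 224, 224)   # matches the CLIP ViT-B/16 image size
    outputs = model(input_ids=input_ids, pixel_values=pixel_values,
                    attention_mask=attention_mask, return_dict=True)
    print(outputs.pooler_output.shape)           # (1, 2 * hidden_size): concatenated text and image poolers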
/models/VLE/pipeline_vle.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from transformers import Pipeline, BatchEncoding
3 | from PIL import Image
4 | from typing import Union
5 | from copy import deepcopy
6 | import matplotlib.pyplot as plt
7 | import io
8 | import json
9 | import unicodedata
10 | import os
11 |
12 | class VLEForVQAPipeline(Pipeline):
13 |
14 | def __init__(self, vle_processor, *args, **kwargs):
15 | self.vle_processor = vle_processor
16 | super().__init__(*args, **kwargs)
17 |
18 | def _sanitize_parameters(self, top_k=None, **kwargs):
19 | preprocess_params, forward_params, postprocess_params = {}, {}, {}
20 | if top_k is not None:
21 | postprocess_params["top_k"] = top_k
22 | return preprocess_params, forward_params, postprocess_params
23 |
24 | def __call__(self, image: Union["Image.Image", str], question: str = None, **kwargs):
25 |
26 | if isinstance(image, (Image.Image, str)) and isinstance(question, str):
27 | inputs = {"image": image, "question": question}
28 | else:
29 | """
30 | Supports the following format
31 | - {"image": image, "question": question}
32 | - [{"image": image, "question": question}]
33 | - Generator and datasets
34 | """
35 | inputs = image
36 | results = super().__call__(inputs, **kwargs)
37 | return results
38 |
39 | def preprocess(self, inputs):
40 | model_inputs = self.vle_processor(text=inputs['question'], images=inputs['image'], return_tensors="pt",padding=True)
41 | return model_inputs
42 |
43 | def _forward(self, model_inputs):
44 | model_outputs = self.model(**model_inputs)
45 | return model_outputs
46 |
47 | def postprocess(self, model_outputs, top_k=1):
48 | if top_k > self.model.num_vqa_labels:
49 | top_k = self.model.num_vqa_labels
50 | probs = torch.softmax(model_outputs['logits'], dim=-1)
51 | probs, preds = torch.sort(probs, descending=True)
52 | probs = probs[:,:top_k].tolist()[0]
53 | preds = preds[:,:top_k].tolist()[0]
54 |
55 | return [{"score": score, "answer": self.model.config.id2label[pred]} for score, pred in zip(probs, preds)]
56 |
57 |
58 |
59 | class VLEForPBCPipeline(Pipeline):
60 | def __init__(self, vle_processor, *args, **kwargs):
61 | self.vle_processor = vle_processor
62 | self.id2label = {0:"False",1:"True"}
63 | super().__init__(*args, **kwargs)
64 |
65 | def _sanitize_parameters(self, **kwargs):
66 | preprocess_params, forward_params, postprocess_params = {}, {}, {}
67 | return preprocess_params, forward_params, postprocess_params
68 |
69 | def __call__(self, image: Union["Image.Image", str], text: str = None, **kwargs):
70 | if isinstance(image, (Image.Image, str)) and isinstance(text, str):
71 | inputs = {"image": image, "text": text}
72 | else:
73 | """
74 | Supports the following format
75 | - {"image": image, "text": text}
76 | - [{"image": image, "text": text}]
77 | - Generator and datasets
78 | """
79 | inputs = image
80 | results = super().__call__(inputs, **kwargs)
81 | return results
82 |
83 | def preprocess(self, inputs):
84 | model_inputs = self.vle_processor(text=inputs['text'], images=inputs['image'], return_tensors="pt",padding=True)
85 | return model_inputs, inputs['image']
86 |
87 | def _forward(self, model_inputs):
88 | model_outputs = self.model(**model_inputs[0])
89 | return model_outputs, model_inputs[1]
90 |
91 | def postprocess(self, model_outputs):
92 | probs = torch.softmax(model_outputs[0]['logits'], dim=-1)
93 | probs = probs.tolist()[0]
94 | new_image = self.paint_in_image(model_outputs[0]['logits'], model_outputs[1])
95 | return {"score": probs, "image": new_image}
96 |
97 | def paint_in_image(self, logits, raw_image):
98 | image_back = deepcopy(raw_image)
99 | raw_image_size = image_back.size
100 | resized_image_size = self.model.config.vision_config.image_size
101 | patch_size = self.model.config.vision_config.patch_size
102 | probs = torch.softmax(logits.detach()[0,:,1].to('cpu'),dim=-1).numpy().reshape(-1, resized_image_size//patch_size)
103 |
104 | plt.close('all')
105 | plt.axis('off')
106 | plt.imshow(probs, cmap='gray', interpolation='None', vmin=(probs.max()-probs.min())*2/5+probs.min(),alpha=0.7)
107 | plt.xticks([])
108 | plt.yticks([])
109 | buf = io.BytesIO()
110 | plt.savefig(buf, dpi=100, transparent=True, bbox_inches='tight', pad_inches=0)
111 | image_front = Image.open(buf)
112 |
113 | def filter_image_front(img: Image.Image):
114 | width, height = img.width, img.height
115 | for x in range(width):
116 | for y in range(height):
117 | r,g,b,a = img.getpixel((x,y))
118 | a = int (a * (1-r/255))
119 | img.putpixel((x,y), (r,g,b,a))
120 | return img
121 |
122 | image_front = filter_image_front(image_front).resize(raw_image_size)
123 | image_back.paste(image_front, (0,0), image_front)
124 | mixed_image = image_back.resize(raw_image_size)
125 | buf.close()
126 |
127 | return mixed_image
128 |
129 |
130 |
131 | class VLEForITMPipeline(Pipeline):
132 | def __init__(self, vle_processor, *args, **kwargs):
133 | self.vle_processor = vle_processor
134 | self.id2label = {0:"False",1:"True"}
135 | super().__init__(*args, **kwargs)
136 |
137 | def _sanitize_parameters(self, **kwargs):
138 | preprocess_params, forward_params, postprocess_params = {}, {}, {}
139 | return preprocess_params, forward_params, postprocess_params
140 |
141 | def __call__(self, image: Union["Image.Image", str], text: str = None, **kwargs):
142 | if isinstance(image, (Image.Image, str)) and isinstance(text, str):
143 | inputs = {"image": image, "text": text}
144 | else:
145 | """
146 | Supports the following format
147 | - {"image": image, "text": text}
148 | - [{"image": image, "text": text}]
149 | - Generator and datasets
150 | """
151 | inputs = image
152 | results = super().__call__(inputs, **kwargs)
153 | return results
154 |
155 | def preprocess(self, inputs):
156 | model_inputs = self.vle_processor(text=inputs['text'], images=inputs['image'], return_tensors="pt",padding=True)
157 | return model_inputs
158 |
159 | def _forward(self, model_inputs):
160 | model_outputs = self.model(**model_inputs)
161 | return model_outputs
162 |
163 | def postprocess(self, model_outputs):
164 | probs = torch.softmax(model_outputs['logits'], dim=-1)
165 | preds = torch.argmax(probs, dim=-1)
166 | probs = probs.tolist()[0]
167 | preds = self.id2label[preds.tolist()[0]]
168 |
169 | return {"score": probs, "match": preds}
170 |
171 |
172 | class VLEForVCRQ2APipeline(Pipeline):
173 |
174 | def __init__(self, vle_processor, *args, **kwargs):
175 | self.vle_processor = vle_processor
176 | self.vle_tokenizer = self.vle_processor.tokenizer
177 | self.GENDER_NEUTRAL_NAMES = ['Casey', 'Riley', 'Jessie', 'Jackie', 'Avery', 'Jaime', 'Peyton', 'Kerry', 'Jody', 'Kendall',
178 | 'Payton', 'Skyler', 'Frankie', 'Pat', 'Quinn']
179 | self.person_name_id = 0
180 | self.max_text_len = 80
181 | super().__init__(*args, **kwargs)
182 |
183 | def _sanitize_parameters(self, **kwargs):
184 | preprocess_params, forward_params, postprocess_params = {}, {}, {}
185 | return preprocess_params, forward_params, postprocess_params
186 |
187 | def __call__(self, vcr_image_root: str, meta_inputs: dict, **kwargs):
188 |
189 | inputs = {"vcr_image_root": vcr_image_root, "meta_inputs": meta_inputs}
190 | results = super().__call__(inputs, **kwargs)
191 | return results
192 |
193 | def preprocess(self, inputs):
194 | model_inputs = self.vcr_q2a_preprocess(inputs["vcr_image_root"], inputs["meta_inputs"])
195 | return model_inputs
196 |
197 | def _forward(self, model_inputs):
198 | model_outputs = self.model(**model_inputs)
199 | return model_outputs
200 |
201 | def postprocess(self, model_outputs, top_k=1):
202 | logits = model_outputs["logits"]
203 | loss = model_outputs["loss"]
204 | preds = torch.argmax(logits, dim=-1, keepdim=True)
205 |
206 | return [{"score": score, "pred": pred} for score, pred in zip(logits, preds)]
207 |
208 | def vcr_q2a_preprocess(self, vcr_image_root, data):
209 | image_fn = data["img_fn"]
210 | objects = data["objects"]
211 | metadata_fn = data["metadata_fn"]
212 | question = data["question"]
213 | answer_choices = data["answer_choices"]
214 | rationale_choices = data["rationale_choices"]
215 | answer_label = data["answer_label"]
216 | rationale_label = data["rationale_label"]
217 |
218 | question_text = question
219 | answer_text = answer_choices
220 | text_tokens, text_ids, obj_tags, text_raw = self.build_text(question_text, answer_text, objects, self.vle_tokenizer)
221 | encoding = [self.vle_tokenizer(
222 | ''.join(self.vle_tokenizer.convert_ids_to_tokens(text_ids_)),
223 | padding="max_length",
224 | truncation=True,
225 | max_length=self.max_text_len,
226 | return_special_tokens_mask=True,
227 | ) for text_ids_ in text_ids]
228 |
229 | obj_tags = [[-1] + tags + [-1] for tags in obj_tags]
230 |
231 | image = Image.open(os.path.join(vcr_image_root, image_fn))
232 | image_feature = self.vle_processor.image_processor(image)
233 | width = image.size[0]
234 | height = image.size[1]
235 | with open(os.path.join(vcr_image_root, metadata_fn), 'r') as f:
236 | vcr_image_metadata = json.load(f)
237 | boxes = vcr_image_metadata['boxes']
238 | patch_boxes = self.get_patch_box(boxes, width, height, self.model.config.vision_config.patch_size, self.model.config.vision_config.image_size)
239 | related_box_ids, image_added_token_type_ids, text_token_type_ids = list(zip(*[self.get_related_box_ids_and_token_type_ids(tags) for tags in obj_tags]))
240 | related_patch_boxes = [[patch_boxes[i] for i in related_box_ids[j]] for j in range(4)]
241 |
242 | processed_data = {
243 | "image": image_feature,
244 | "text": (text_raw, encoding),
245 | "obj_tags": obj_tags,
246 | "label": answer_label,
247 | "patch_ids": [[patch_box[1] for patch_box in related_patch_boxes[i]] for i in range(4)],
248 | "extend_token_type_ids": [(image_added_token_type_ids[i], text_token_type_ids[i]) for i in range(4)],
249 | }
250 |
251 | model_inputs = {
252 | "input_ids": [[processed_data["text"][1][i]["input_ids"]] for i in range(4)],
253 | "attention_mask": [[processed_data["text"][1][i]["attention_mask"]] for i in range(4)],
254 | "token_type_ids": [[processed_data["text"][1][i]["token_type_ids"]] for i in range(4)],
255 | "pixel_values": torch.Tensor([processed_data["image"]["pixel_values"][0]]),
256 | }
257 | model_inputs = BatchEncoding(model_inputs, tensor_type='pt')
258 | model_inputs.update({
259 | "patch_ids": [[processed_data["patch_ids"][i]] for i in range(4)],
260 | "vcr_labels": [processed_data["label"]],
261 | "extend_token_type_ids": [list(zip(processed_data["extend_token_type_ids"][i])) for i in range(4)],
262 | "return_loss": True,
263 | })
264 | return model_inputs
265 |
266 | def retokenize_and_convert_to_ids_with_tag(self, raw_tokens, objects_replace_name, tokenizer, non_obj_tag=-1, add_space_b4_first_token=False):
267 | parsed_tokens = []
268 | tags = []
269 | align_ids = []
270 | raw = []
271 | align_id = 0
272 | for idx, mixed_token in enumerate(raw_tokens):
273 | if isinstance(mixed_token, list):
274 | tokens = [" " + objects_replace_name[o] for o in mixed_token]
275 | if idx == 0 and not add_space_b4_first_token:
276 | tokens[0] = tokens[0].lstrip()
277 | retokenized_tokens = tokenizer.tokenize(tokens[0])
278 | raw.append(tokens[0])
279 | tags.extend([mixed_token[0] + non_obj_tag + 1 for _ in retokenized_tokens])
280 | align_ids.extend([align_id for _ in retokenized_tokens])
281 | align_id += 1
282 | for token, o in zip(tokens[1:], mixed_token[1:]):
283 | retokenized_tokens.append(tokenizer.tokenize(' and')[0])
284 | tags.append(non_obj_tag)
285 | align_ids.append(align_id)
286 | align_id += 1
287 | re_tokens = tokenizer.tokenize(token)
288 | retokenized_tokens.extend(re_tokens)
289 | tags.extend([o + non_obj_tag + 1 for _ in re_tokens])
290 | align_ids.extend([align_id for _ in re_tokens])
291 | align_id += 1
292 | raw.extend([' and', token])
293 | parsed_tokens.extend(retokenized_tokens)
294 | else:
295 | # fully align to original tokens
296 | if True in [unicodedata.category(str_) == 'Co' for str_ in mixed_token]:
297 | continue
298 | if idx != 0 or add_space_b4_first_token:
299 | mixed_token = " " + mixed_token
300 | raw.append(mixed_token)
301 | retokenized_tokens = tokenizer.tokenize(mixed_token)
302 | parsed_tokens.extend(retokenized_tokens)
303 | align_ids.extend([align_id for _ in retokenized_tokens])
304 | tags.extend([non_obj_tag for _ in retokenized_tokens])
305 | align_id += 1
306 | ids = tokenizer.convert_tokens_to_ids(parsed_tokens)
307 | ids_with_tag = list(zip(parsed_tokens, ids, tags, align_ids))
308 |
309 | return ids_with_tag, raw
310 |
311 | def build_text(self, question_text, answer_text, objects, tokenizer):
312 | objects_replace_name = []
313 | for o in objects:
314 | if o == 'person':
315 | objects_replace_name.append(self.GENDER_NEUTRAL_NAMES[self.person_name_id])
316 | self.person_name_id = (self.person_name_id + 1) % len(self.GENDER_NEUTRAL_NAMES)
317 | else:
318 | objects_replace_name.append(o)
319 |
320 | non_obj_tag = -1
321 | question_text, question_text_raw = self.retokenize_and_convert_to_ids_with_tag(question_text, objects_replace_name, tokenizer, non_obj_tag=non_obj_tag)
322 | answer_text = [self.retokenize_and_convert_to_ids_with_tag(a_t, objects_replace_name, tokenizer, non_obj_tag=non_obj_tag) for a_t in answer_text]
323 | answer_text, answer_text_raw = list(zip(*answer_text))
324 | for a_t, a_t_raw in zip(answer_text, answer_text_raw):
325 | while len(question_text) + len(a_t) > self.max_text_len - 3:
326 | if len(question_text) > len(a_t):
327 | question_text.pop()
328 | else:
329 | a_t.pop()
330 |
331 | text_tokens = [[q_t[0] for q_t in question_text] + [self.vle_tokenizer.sep_token] + [a_t_t[0] for a_t_t in a_t] for a_t in answer_text]
332 | text_ids = [[q_t[1] for q_t in question_text] + [self.vle_tokenizer.sep_token_id] + [a_t_t[1] for a_t_t in a_t] for a_t in answer_text]
333 | obj_tags = [[q_t[2] for q_t in question_text] + [-1] + [a_t_t[2] for a_t_t in a_t] for a_t in answer_text]
334 | text_raw = [question_text_raw + answer_text_raw_ for answer_text_raw_ in answer_text_raw]
335 |
336 | return text_tokens, text_ids, obj_tags, text_raw
337 |
338 | def get_patch_box(self, boxes, width, height, patch_size, image_size):
339 | patch_count_w = image_size // patch_size
340 | patch_count_h = image_size // patch_size
341 | patch_width = width / patch_count_w
342 | patch_height = height / patch_count_h
343 |
344 | patch_boxes = []
345 | for box in boxes:
346 | box = box[:4]
347 | patch_x1 = int(box[0] // patch_width)
348 | patch_y1 = int(box[1] // patch_height)
349 | patch_x2 = int(box[2] // patch_width)
350 | patch_y2 = int(box[3] // patch_height)
351 |
352 | patch_x1 = patch_x1 if patch_x1 >= 0 else 0
353 | patch_y1 = patch_y1 if patch_y1 >= 0 else 0
354 | patch_x2 = patch_x2 + 1 if patch_x2 < patch_count_w else patch_count_w
355 | patch_y2 = patch_y2 + 1 if patch_y2 < patch_count_h else patch_count_h
356 |
357 | patch_box = [
358 | patch_x1 * patch_width,
359 | patch_y1 * patch_height,
360 | patch_x2 * patch_width,
361 | patch_y2 * patch_height
362 | ]
363 |
364 | patch_ids = [patch_count_w * y + x for y in range(patch_y1, patch_y2) for x in range(patch_x1, patch_x2)]
365 | patch_boxes.append([patch_box, patch_ids])
366 |
367 | return patch_boxes
368 |
369 | def get_related_box_ids_and_token_type_ids(self, obj_tags):
370 | no_obj_tag = -1
371 | obj_tags_set = set()
372 | for tag in obj_tags:
373 | if tag != no_obj_tag:
374 | obj_tags_set.add(tag)
375 | obj_tag_remap = {t: i + 2 for i, t in enumerate(obj_tags_set)}
376 | text_token_type_ids = [obj_tag_remap[tag] if tag != no_obj_tag else 0 for tag in obj_tags]
377 | related_box_ids = list(obj_tag_remap.keys())
378 | image_added_token_type_ids = list(obj_tag_remap.values())
379 |
380 | return related_box_ids, image_added_token_type_ids, text_token_type_ids
381 |
382 |
383 | class VLEForVCRQA2RPipeline(Pipeline):
384 |
385 | def __init__(self, vle_processor, *args, **kwargs):
386 | self.vle_processor = vle_processor
387 | self.vle_tokenizer = self.vle_processor.tokenizer
388 | self.GENDER_NEUTRAL_NAMES = ['Casey', 'Riley', 'Jessie', 'Jackie', 'Avery', 'Jaime', 'Peyton', 'Kerry', 'Jody', 'Kendall',
389 | 'Payton', 'Skyler', 'Frankie', 'Pat', 'Quinn']
390 | self.person_name_id = 0
391 | self.max_text_len = 80
392 | super().__init__(*args, **kwargs)
393 |
394 | def _sanitize_parameters(self, **kwargs):
395 | preprocess_params, forward_params, postprocess_params = {}, {}, {}
396 | return preprocess_params, forward_params, postprocess_params
397 |
398 | def __call__(self, vcr_image_root: str, meta_inputs: dict, **kwargs):
399 |
400 | inputs = {"vcr_image_root": vcr_image_root, "meta_inputs": meta_inputs}
401 | results = super().__call__(inputs, **kwargs)
402 | return results
403 |
404 | def preprocess(self, inputs):
405 | model_inputs = self.vcr_qa2r_preprocess(inputs["vcr_image_root"], inputs["meta_inputs"])
406 | return model_inputs
407 |
408 | def _forward(self, model_inputs):
409 | model_outputs = self.model(**model_inputs)
410 | return model_outputs
411 |
412 | def postprocess(self, model_outputs, top_k=1):
413 | logits = model_outputs["logits"]
414 | loss = model_outputs["loss"]
415 | preds = torch.argmax(logits, dim=-1, keepdim=True)
416 |
417 | return [{"score": score, "pred": pred} for score, pred in zip(logits, preds)]
418 |
419 | def vcr_qa2r_preprocess(self, vcr_image_root, data):
420 | image_fn = data["img_fn"]
421 | objects = data["objects"]
422 | metadata_fn = data["metadata_fn"]
423 | question = data["question"]
424 | answer_choices = data["answer_choices"]
425 | rationale_choices = data["rationale_choices"]
426 | answer_label = data["answer_label"]
427 | rationale_label = data["rationale_label"]
428 |
429 | question_text = question
430 | answer_text = answer_choices[answer_label]
431 | rationale_text = rationale_choices
432 | text_tokens, text_ids, obj_tags, text_raw = self.build_text(question_text, answer_text, rationale_text, objects, self.vle_tokenizer)
433 | encoding = [self.vle_tokenizer(
434 | ''.join(self.vle_tokenizer.convert_ids_to_tokens(text_ids_)),
435 | padding="max_length",
436 | truncation=True,
437 | max_length=self.max_text_len,
438 | return_special_tokens_mask=True,
439 | ) for text_ids_ in text_ids]
440 |
441 | obj_tags = [[-1] + tags + [-1] for tags in obj_tags]
442 |
443 | image = Image.open(os.path.join(vcr_image_root, image_fn))
444 | image_feature = self.vle_processor.image_processor(image)
445 | width = image.size[0]
446 | height = image.size[1]
447 | with open(os.path.join(vcr_image_root, metadata_fn), 'r') as f:
448 | vcr_image_metadata = json.load(f)
449 | boxes = vcr_image_metadata['boxes']
450 | patch_boxes = self.get_patch_box(boxes, width, height, self.model.config.vision_config.patch_size, self.model.config.vision_config.image_size)
451 | related_box_ids, image_added_token_type_ids, text_token_type_ids = list(zip(*[self.get_related_box_ids_and_token_type_ids(tags) for tags in obj_tags]))
452 | related_patch_boxes = [[patch_boxes[i] for i in related_box_ids[j]] for j in range(4)]
453 |
454 | processed_data = {
455 | "image": image_feature,
456 | "text": (text_raw, encoding),
457 | "obj_tags": obj_tags,
458 | "label": rationale_label,
459 | "patch_ids": [[patch_box[1] for patch_box in related_patch_boxes[i]] for i in range(4)],
460 | "extend_token_type_ids": [(image_added_token_type_ids[i], text_token_type_ids[i]) for i in range(4)],
461 | }
462 |
463 | model_inputs = {
464 | "input_ids": [[processed_data["text"][1][i]["input_ids"]] for i in range(4)],
465 | "attention_mask": [[processed_data["text"][1][i]["attention_mask"]] for i in range(4)],
466 | "token_type_ids": [[processed_data["text"][1][i]["token_type_ids"]] for i in range(4)],
467 | "pixel_values": torch.Tensor([processed_data["image"]["pixel_values"][0]]),
468 | }
469 | model_inputs = BatchEncoding(model_inputs, tensor_type='pt')
470 | model_inputs.update({
471 | "patch_ids": [[processed_data["patch_ids"][i]] for i in range(4)],
472 | "vcr_labels": [processed_data["label"]],
473 | "extend_token_type_ids": [list(zip(processed_data["extend_token_type_ids"][i])) for i in range(4)],
474 | "return_loss": True,
475 | })
476 | return model_inputs
477 |
478 | def retokenize_and_convert_to_ids_with_tag(self, raw_tokens, objects_replace_name, tokenizer, non_obj_tag=-1, add_space_b4_first_token=False):
479 | parsed_tokens = []
480 | tags = []
481 | align_ids = []
482 | raw = []
483 | align_id = 0
484 | for idx, mixed_token in enumerate(raw_tokens):
485 | if isinstance(mixed_token, list):
486 | tokens = [" " + objects_replace_name[o] for o in mixed_token]
487 | if idx == 0 and not add_space_b4_first_token:
488 | tokens[0] = tokens[0].lstrip()
489 | retokenized_tokens = tokenizer.tokenize(tokens[0])
490 | raw.append(tokens[0])
491 | tags.extend([mixed_token[0] + non_obj_tag + 1 for _ in retokenized_tokens])
492 | align_ids.extend([align_id for _ in retokenized_tokens])
493 | align_id += 1
494 | for token, o in zip(tokens[1:], mixed_token[1:]):
495 | retokenized_tokens.append(tokenizer.tokenize(' and')[0])
496 | tags.append(non_obj_tag)
497 | align_ids.append(align_id)
498 | align_id += 1
499 | re_tokens = tokenizer.tokenize(token)
500 | retokenized_tokens.extend(re_tokens)
501 | tags.extend([o + non_obj_tag + 1 for _ in re_tokens])
502 | align_ids.extend([align_id for _ in re_tokens])
503 | align_id += 1
504 | raw.extend([' and', token])
505 | parsed_tokens.extend(retokenized_tokens)
506 | else:
507 | # fully align to original tokens
508 |                 if any(unicodedata.category(str_) == 'Co' for str_ in mixed_token):
509 | continue
510 | if idx != 0 or add_space_b4_first_token:
511 | mixed_token = " " + mixed_token
512 | raw.append(mixed_token)
513 | retokenized_tokens = tokenizer.tokenize(mixed_token)
514 | parsed_tokens.extend(retokenized_tokens)
515 | align_ids.extend([align_id for _ in retokenized_tokens])
516 | tags.extend([non_obj_tag for _ in retokenized_tokens])
517 | align_id += 1
518 | ids = tokenizer.convert_tokens_to_ids(parsed_tokens)
519 | ids_with_tag = list(zip(parsed_tokens, ids, tags, align_ids))
520 |
521 | return ids_with_tag, raw
522 |
523 | def build_text(self, question_text, answer_text, rationale_text, objects, tokenizer):
524 | objects_replace_name = []
525 | for o in objects:
526 | if o == 'person':
527 | objects_replace_name.append(self.GENDER_NEUTRAL_NAMES[self.person_name_id])
528 | self.person_name_id = (self.person_name_id + 1) % len(self.GENDER_NEUTRAL_NAMES)
529 | else:
530 | objects_replace_name.append(o)
531 |
532 | non_obj_tag = -1
533 | question_text, question_text_raw = self.retokenize_and_convert_to_ids_with_tag(question_text, objects_replace_name, tokenizer, non_obj_tag=non_obj_tag)
534 | answer_text, answer_text_raw = self.retokenize_and_convert_to_ids_with_tag(answer_text, objects_replace_name, tokenizer, non_obj_tag=non_obj_tag)
535 | rationale_text = [self.retokenize_and_convert_to_ids_with_tag(r_t, objects_replace_name, tokenizer, non_obj_tag=non_obj_tag) for r_t in rationale_text]
536 | rationale_text, rationale_text_raw = list(zip(*rationale_text))
537 | for r_t, r_t_raw in zip(rationale_text, rationale_text_raw):
538 | while len(question_text) + len(answer_text) + len(r_t) > self.max_text_len - 4:
539 | if len(r_t) > len(question_text) + len(answer_text):
540 | r_t.pop()
541 | elif len(question_text) > 1:
542 | question_text.pop()
543 | else:
544 | answer_text.pop()
545 |
546 | text_tokens = [[q_t[0] for q_t in question_text] + [tokenizer.sep_token] + [a_t[0] for a_t in answer_text] + [tokenizer.sep_token] + [r_t_t[0] for r_t_t in r_t] for r_t in rationale_text]
547 | text_ids = [[q_t[1] for q_t in question_text] + [tokenizer.sep_token_id] + [a_t[1] for a_t in answer_text] + [tokenizer.sep_token_id] + [r_t_t[1] for r_t_t in r_t] for r_t in rationale_text]
548 | obj_tags = [[q_t[2] for q_t in question_text] + [-1] + [a_t[2] for a_t in answer_text] + [-1] + [r_t_t[2] for r_t_t in r_t] for r_t in rationale_text]
549 | text_raw = [question_text_raw + answer_text_raw + rationale_text_raw_ for rationale_text_raw_ in rationale_text_raw]
550 |
551 | return text_tokens, text_ids, obj_tags, text_raw
552 |
553 | def get_patch_box(self, boxes, width, height, patch_size, image_size):
554 | patch_count_w = image_size // patch_size
555 | patch_count_h = image_size // patch_size
556 | patch_width = width / patch_count_w
557 | patch_height = height / patch_count_h
558 |
559 | patch_boxes = []
560 | for box in boxes:
561 | box = box[:4]
562 | patch_x1 = int(box[0] // patch_width)
563 | patch_y1 = int(box[1] // patch_height)
564 | patch_x2 = int(box[2] // patch_width)
565 | patch_y2 = int(box[3] // patch_height)
566 |
567 | patch_x1 = patch_x1 if patch_x1 >= 0 else 0
568 | patch_y1 = patch_y1 if patch_y1 >= 0 else 0
569 | patch_x2 = patch_x2 + 1 if patch_x2 < patch_count_w else patch_count_w
570 | patch_y2 = patch_y2 + 1 if patch_y2 < patch_count_h else patch_count_h
571 |
572 | patch_box = [
573 | patch_x1 * patch_width,
574 | patch_y1 * patch_height,
575 | patch_x2 * patch_width,
576 | patch_y2 * patch_height
577 | ]
578 |
579 | patch_ids = [patch_count_w * y + x for y in range(patch_y1, patch_y2) for x in range(patch_x1, patch_x2)]
580 | patch_boxes.append([patch_box, patch_ids])
581 |
582 | return patch_boxes
583 |
584 | def get_related_box_ids_and_token_type_ids(self, obj_tags):
585 | no_obj_tag = -1
586 | obj_tags_set = set()
587 | for tag in obj_tags:
588 | if tag != no_obj_tag:
589 | obj_tags_set.add(tag)
590 | obj_tag_remap = {t: i + 2 for i, t in enumerate(obj_tags_set)}
591 | text_token_type_ids = [obj_tag_remap[tag] if tag != no_obj_tag else 0 for tag in obj_tags]
592 | related_box_ids = list(obj_tag_remap.keys())
593 | image_added_token_type_ids = list(obj_tag_remap.values())
594 |
595 | return related_box_ids, image_added_token_type_ids, text_token_type_ids
--------------------------------------------------------------------------------
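Note: `VLEForVCRQA2RPipeline.__call__` takes the root directory of the VCR images plus a `meta_inputs` dict carrying the fields read in `vcr_qa2r_preprocess` (`img_fn`, `objects`, `metadata_fn`, `question`, `answer_choices`, `rationale_choices`, `answer_label`, `rationale_label`), i.e. one VCR annotation record. A minimal usage sketch follows; the checkpoint path, data paths and the `load_vcr_model` helper are placeholders/assumptions, not part of the repository.

# Usage sketch (assumptions: all paths are placeholders; load_vcr_model is a
# hypothetical loader returning a VLE model with a VCR head whose forward pass
# produces "logits" and "loss", as _forward/postprocess above expect).
import json

from models.VLE.pipeline_vle import VLEForVCRQA2RPipeline
from models.VLE.processing_vle import VLEProcessor

checkpoint = "path/to/vle-vcr-qa2r-checkpoint"        # placeholder
processor = VLEProcessor.from_pretrained(checkpoint)  # ProcessorMixin API
model = load_vcr_model(checkpoint)                    # hypothetical loader

pipe = VLEForVCRQA2RPipeline(vle_processor=processor, model=model)

# One VCR annotation record, e.g. a single line from a VCR annotation file
with open("path/to/val.jsonl") as f:                  # placeholder
    meta_inputs = json.loads(f.readline())

result = pipe(vcr_image_root="path/to/vcr1images", meta_inputs=meta_inputs)
print(result[0]["pred"])  # index of the rationale choice with the highest logit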
/models/VLE/processing_vle.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2021 The HuggingFace Inc. team.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | """
16 | Processor class for VLE
17 | """
18 |
19 | import warnings
20 |
21 | from transformers.processing_utils import ProcessorMixin
22 | from transformers.tokenization_utils_base import BatchEncoding
23 |
24 |
25 | class VLEProcessor(ProcessorMixin):
26 | r"""
27 | Constructs a VLE processor which wraps an image processor and a tokenizer into a single
28 | processor.
29 |
30 | [`VLEProcessor`] offers all the functionalities of [`AutoImageProcessor`] and [`AutoTokenizer`].
31 | See the [`~VLEProcessor.__call__`] and [`~VLEProcessor.decode`] for more
32 | information.
33 |
34 | Args:
35 | image_processor ([`AutoImageProcessor`]):
36 | The image processor is a required input.
37 | tokenizer ([`PreTrainedTokenizer`]):
38 | The tokenizer is a required input.
39 | """
40 | attributes = ["image_processor", "tokenizer"]
41 | image_processor_class = "CLIPImageProcessor"
42 | tokenizer_class = "DebertaV2Tokenizer"
43 |
44 |     def __init__(self, image_processor=None, tokenizer=None, **kwargs):
45 |         feature_extractor = None
46 |         if "feature_extractor" in kwargs:
47 |             warnings.warn(
48 |                 "The `feature_extractor` argument is deprecated and will be removed in v5, use `image_processor`"
49 |                 " instead.",
50 |                 FutureWarning,
51 |             )
52 |             feature_extractor = kwargs.pop("feature_extractor")
53 |         image_processor = image_processor if image_processor is not None else feature_extractor
54 | if image_processor is None:
55 | raise ValueError("You need to specify an `image_processor`.")
56 | if tokenizer is None:
57 | raise ValueError("You need to specify a `tokenizer`.")
58 |
59 | super().__init__(image_processor, tokenizer)
60 | self.current_processor = self.image_processor
61 |
62 | def __call__(self, text=None, images=None, return_tensors=None, **kwargs):
63 | """
64 |         Main method to prepare one or several sequence(s) and image(s) for the model. This method forwards the `text`
65 |         and `kwargs` arguments to VLETokenizer's [`~PreTrainedTokenizer.__call__`] if `text` is not
66 |         `None` to encode the text. To prepare the image(s), this method forwards the `images` and `kwargs` arguments to
67 |         AutoImageProcessor's [`~AutoImageProcessor.__call__`] if `images` is not `None`. Please refer to the docstring
68 |         of the above two methods for more information.
69 |
70 | Args:
71 | text (`str`, `List[str]`, `List[List[str]]`):
72 | The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings
73 | (pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set
74 | `is_split_into_words=True` (to lift the ambiguity with a batch of sequences).
75 | images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`, `List[torch.Tensor]`):
76 | The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch
77 | tensor. In case of a NumPy array/PyTorch tensor, each image should be of shape (C, H, W), where C is a
78 | number of channels, H and W are image height and width.
79 |
80 | return_tensors (`str` or [`~utils.TensorType`], *optional*):
81 | If set, will return tensors of a particular framework. Acceptable values are:
82 |
83 | - `'tf'`: Return TensorFlow `tf.constant` objects.
84 | - `'pt'`: Return PyTorch `torch.Tensor` objects.
85 | - `'np'`: Return NumPy `np.ndarray` objects.
86 | - `'jax'`: Return JAX `jnp.ndarray` objects.
87 |
88 | Returns:
89 | [`BatchEncoding`]: A [`BatchEncoding`] with the following fields:
90 |
91 | - **input_ids** -- List of token ids to be fed to a model. Returned when `text` is not `None`.
92 | - **attention_mask** -- List of indices specifying which tokens should be attended to by the model (when
93 | `return_attention_mask=True` or if *"attention_mask"* is in `self.model_input_names` and if `text` is not
94 | `None`).
95 | - **pixel_values** -- Pixel values to be fed to a model. Returned when `images` is not `None`.
96 | """
97 |
98 | if text is None and images is None:
99 | raise ValueError("You have to specify either text or images. Both cannot be none.")
100 |
101 | if text is not None:
102 | encoding = self.tokenizer(text, return_tensors=return_tensors, **kwargs)
103 |
104 | if images is not None:
105 | image_features = self.image_processor(images, return_tensors=return_tensors, **kwargs)
106 |
107 | if text is not None and images is not None:
108 | encoding["pixel_values"] = image_features.pixel_values
109 | return encoding
110 | elif text is not None:
111 | return encoding
112 | else:
113 | return BatchEncoding(data=dict(**image_features), tensor_type=return_tensors)
114 |
115 | def batch_decode(self, *args, **kwargs):
116 | """
117 | This method forwards all its arguments to VLETokenizer's
118 | [`~PreTrainedTokenizer.batch_decode`]. Please refer to the docstring of this method for more information.
119 | """
120 | return self.tokenizer.batch_decode(*args, **kwargs)
121 |
122 | def decode(self, *args, **kwargs):
123 | """
124 | This method forwards all its arguments to VLETokenizer's [`~PreTrainedTokenizer.decode`].
125 | Please refer to the docstring of this method for more information.
126 | """
127 | return self.tokenizer.decode(*args, **kwargs)
128 |
129 | @property
130 | def model_input_names(self):
131 | tokenizer_input_names = self.tokenizer.model_input_names
132 | image_processor_input_names = self.image_processor.model_input_names
133 | return list(dict.fromkeys(tokenizer_input_names + image_processor_input_names))
134 |
135 | @property
136 | def feature_extractor_class(self):
137 | warnings.warn(
138 | "`feature_extractor_class` is deprecated and will be removed in v5. Use `image_processor_class` instead.",
139 | FutureWarning,
140 | )
141 | return self.image_processor_class
142 |
143 | @property
144 | def feature_extractor(self):
145 | warnings.warn(
146 | "`feature_extractor` is deprecated and will be removed in v5. Use `image_processor` instead.",
147 | FutureWarning,
148 | )
149 | return self.image_processor
150 |
--------------------------------------------------------------------------------
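For reference, a minimal sketch of how `VLEProcessor` is typically used; the checkpoint path is a placeholder, and any VLE checkpoint that bundles a CLIPImageProcessor and a DebertaV2Tokenizer should work. The image file is one of the samples under pics/.

# Usage sketch; the checkpoint path is a placeholder, not part of the repository.
from PIL import Image

from models.VLE.processing_vle import VLEProcessor

processor = VLEProcessor.from_pretrained("path/to/vle-checkpoint")  # placeholder

image = Image.open("pics/dogs.png")
text = "a photo of two dogs"

# Joint text+image call: tokenizer fields plus pixel_values, batched as PyTorch tensors
inputs = processor(text=text, images=image, return_tensors="pt")
print(list(inputs.keys()))  # e.g. input_ids, token_type_ids, attention_mask, pixel_values

# Text-only and image-only calls are also supported
text_only = processor(text=text, return_tensors="pt")
image_only = processor(images=image, return_tensors="pt")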
/pics/VQALLM_workflow.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/iflytek/VLE/e322bb97b4ecd2c1eed11959416aed640a4bf76a/pics/VQALLM_workflow.png
--------------------------------------------------------------------------------
/pics/banner.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/iflytek/VLE/e322bb97b4ecd2c1eed11959416aed640a4bf76a/pics/banner.png
--------------------------------------------------------------------------------
/pics/birds.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/iflytek/VLE/e322bb97b4ecd2c1eed11959416aed640a4bf76a/pics/birds.jpg
--------------------------------------------------------------------------------
/pics/demo-banner.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/iflytek/VLE/e322bb97b4ecd2c1eed11959416aed640a4bf76a/pics/demo-banner.png
--------------------------------------------------------------------------------
/pics/dogs.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/iflytek/VLE/e322bb97b4ecd2c1eed11959416aed640a4bf76a/pics/dogs.png
--------------------------------------------------------------------------------
/pics/door.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/iflytek/VLE/e322bb97b4ecd2c1eed11959416aed640a4bf76a/pics/door.png
--------------------------------------------------------------------------------
/pics/fishing.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/iflytek/VLE/e322bb97b4ecd2c1eed11959416aed640a4bf76a/pics/fishing.png
--------------------------------------------------------------------------------
/pics/model.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/iflytek/VLE/e322bb97b4ecd2c1eed11959416aed640a4bf76a/pics/model.png
--------------------------------------------------------------------------------
/pics/pink_tongues.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/iflytek/VLE/e322bb97b4ecd2c1eed11959416aed640a4bf76a/pics/pink_tongues.png
--------------------------------------------------------------------------------
/pics/qrcode.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/iflytek/VLE/e322bb97b4ecd2c1eed11959416aed640a4bf76a/pics/qrcode.jpg
--------------------------------------------------------------------------------
/pics/truck.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/iflytek/VLE/e322bb97b4ecd2c1eed11959416aed640a4bf76a/pics/truck.png
--------------------------------------------------------------------------------