├── .gitignore
├── README.md
├── assets
│   ├── poster-v2.pdf
│   └── spec_overview.png
├── docs
│   └── evaluate_custom_model.md
├── notebooks
│   ├── evaluate_example_colab.ipynb
│   ├── evaluate_example_local.ipynb
│   ├── explore_spec_colab.ipynb
│   └── explore_spec_local.ipynb
├── setup.py
└── spec
    ├── __init__.py
    ├── dataset.py
    ├── eval.py
    ├── models
    │   ├── __init__.py
    │   ├── base_wrapper.py
    │   ├── blip_utils
    │   │   ├── .ipynb_checkpoints
    │   │   │   └── README-checkpoint.md
    │   │   ├── README.md
    │   │   ├── __pycache__
    │   │   │   ├── blip.cpython-38.pyc
    │   │   │   ├── blip_retrieval.cpython-38.pyc
    │   │   │   ├── med.cpython-38.pyc
    │   │   │   └── vit.cpython-38.pyc
    │   │   ├── blip.py
    │   │   ├── blip_itm.py
    │   │   ├── blip_pretrain.py
    │   │   ├── blip_retrieval.py
    │   │   ├── med.py
    │   │   ├── utils.py
    │   │   └── vit.py
    │   ├── blip_wrapper.py
    │   ├── clip_wrapper.py
    │   └── flava_wrapper.py
    └── run_eval.sh
/.gitignore:
--------------------------------------------------------------------------------
1 | .idea/
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
2 | # SPEC: Synthesize, Diagnose, and Optimize: Towards Fine-Grained Vision-Language Understanding
3 |
25 | † Corresponding author
26 |
30 | ## :fire: News
31 | * `Apr. 14, 2024` We have released a [preview](https://wjpoom.github.io/preview/) of a more advanced dataset version; the full version will come soon.
32 | * `Apr. 13, 2024` We released the SPEC dataset and the evaluation code; sorry for the delay :relaxed:.
33 | * `Feb. 28, 2024` Our work has been accepted by [CVPR 2024](https://cvpr.thecvf.com/) :tada:.
34 |
35 | ## :rocket: A more advanced version is coming!
36 | We are building a new version with a larger data scale, more object categories, higher-quality images and text, and more.
37 | You can preview it at [this website](https://wjpoom.github.io/preview/), and the full version will come soon.
38 |
39 | ## :mag: SPEC Benchmark
40 | To evaluate how well vision-language models understand fine-grained concepts, we propose a new benchmark, SPEC,
41 | which consists of six distinct subsets distributed across the dimensions of **S**ize, **P**osition, **E**xistence, and **C**ount.
42 | Each test case consists of an image candidate set, whose images differ only in a certain visual concept, and a text candidate set,
43 | whose texts differ only in the corresponding language concept.
44 |
45 |
46 |
47 |
48 |
49 | ## :wrench: Usage
50 | ### install
51 | ``` shell
52 | git clone https://github.com/wjpoom/SPEC.git
53 | cd SPEC/
54 | pip install -e .
55 | ```
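To quickly check that the installation worked, you can try importing the package (a minimal sanity check, assuming the editable install above finished without errors):
``` shell
python -c "from spec import get_data, get_model; print('spec is installed')"
```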
56 | ### prepare data
57 | * Run the following code in a Python shell, replacing `/path/to/save/data` with the directory where you want to store the data.
58 | ```python
59 | import zipfile
60 | import os
61 | from huggingface_hub import hf_hub_download
62 |
63 | data_root = '/path/to/save/data'
64 | hf_hub_download(repo_id='wjpoom/SPEC', repo_type='dataset', filename='data.zip', local_dir=data_root)
65 |
66 | with zipfile.ZipFile(os.path.join(data_root, 'data.zip'), 'r') as zip_ref:
67 | zip_ref.extractall(os.path.join(data_root))
68 |
69 | os.remove(os.path.join(data_root, 'data.zip'))
70 | ```
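* After extraction, you can optionally verify that the six subsets are in place. The snippet below is a minimal sketch; it assumes the archive unpacks one directory per subset directly under `data_root`, so adjust it if the actual layout differs.
```python
import os

data_root = '/path/to/save/data'  # same directory used above
expected_subsets = ['absolute_size', 'relative_size', 'absolute_spatial',
                    'relative_spatial', 'existence', 'count']

# report which subset directories are present or missing
for name in expected_subsets:
    status = 'ok' if os.path.isdir(os.path.join(data_root, name)) else 'missing'
    print(f'{name}: {status}')
```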
71 | ### explore the dataset
72 | * We provide a 📓notebook that enables you to visually explore the test samples in the SPEC dataset.
73 | * Run this notebook either [locally](https://github.com/wjpoom/SPEC/blob/main/notebooks/explore_spec_local.ipynb) or online using [Colab](https://colab.research.google.com/github/wjpoom/SPEC/blob/main/notebooks/explore_spec_colab.ipynb).
74 |
75 | ### reproduce the results
76 | * In our paper, we evaluated four popular VLMs using our SPEC dataset, namely: CLIP, BLIP, FLAVA and CoCa.
77 | * To reproduce the results with these VLMs, you can run [this script](https://github.com/wjpoom/SPEC/blob/main/spec/run_eval.sh).
78 | * You can also reproduce the results with this [local notebook](https://github.com/wjpoom/SPEC/blob/main/notebooks/evaluate_example_local.ipynb) or the online [Colab notebook](https://colab.research.google.com/github/wjpoom/SPEC/blob/main/notebooks/evaluate_example_colab.ipynb).
79 |
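* If you prefer to run the evaluation programmatically rather than through the script or notebooks, the loop below mirrors the code used in the evaluation notebooks (shown for CLIP; swap `model_name` to `blip`, `flava`, or `coca` for the other models):
```python
import os
import torch
from spec import get_data, get_model

data_root = '/path/to/save/data'    # where you extracted the SPEC data
out_dir = '/path/to/save/results'   # where to store the evaluation results

device = 'cuda' if torch.cuda.is_available() else 'cpu'
model, image_preprocess = get_model(model_name='clip', cache_dir='/path/to/cache/models', device=device)

subset_names = ['absolute_size', 'relative_size', 'absolute_spatial',
                'relative_spatial', 'existence', 'count']
subsets = get_data(data_root=data_root, subset_names=subset_names,
                   image_preprocess=image_preprocess, batch_size=64, num_workers=8)

# evaluate each subset and collect per-direction results
result = {}
for subset_name, dataloaders in subsets.items():
    result[subset_name] = model.evaluate(subset_name=subset_name, dataloaders=dataloaders)

os.makedirs(out_dir, exist_ok=True)
torch.save(result, os.path.join(out_dir, 'clip_result.pth'))
```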
80 | ### evaluate custom VLMs
81 | * If you want to evaluate your custom model on SPEC, you can follow the instructions in [this document](https://github.com/wjpoom/SPEC/blob/main/docs/evaluate_custom_model.md).
82 |
83 | ## :memo: TODO
84 | - [ ] Release the newly built version of the dataset
85 | - [ ] Release the code of our data synthesis pipeline
86 | - [x] Release the test set of the SPEC benchmark
87 | - [x] Release the evaluation code of SPEC
88 |
89 | ## :clap: Acknowledgement
90 | Part of this repository is built upon [ARO](https://github.com/mertyg/vision-language-models-are-bows); thanks for the well-organized codebase.
91 |
92 | ## Contact Us
93 | Feel free to contact us if you have any questions or suggestions.
94 |
95 | Email (Wujian Peng): wjpeng24@m.fudan.edu.cn
96 |
97 | ## :black_nib: Citation
98 | If you use our code or data in this repo, or find our work helpful, please consider citing our paper:
99 |
100 | ``` bibtex
101 | @inproceedings{peng2024synthesize,
102 |   title={Synthesize, diagnose, and optimize: Towards fine-grained vision-language understanding},
103 | author={Peng, Wujian and Xie, Sicheng and You, Zuyao and Lan, Shiyi and Wu, Zuxuan},
104 | booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition},
105 | pages={13279--13288},
106 | year={2024}
107 | }
108 | ```
109 |
--------------------------------------------------------------------------------
/assets/poster-v2.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wjpoom/SPEC/8dd6cdcc0bb4f47ea3551b1c0558ee554656ca7c/assets/poster-v2.pdf
--------------------------------------------------------------------------------
/assets/spec_overview.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wjpoom/SPEC/8dd6cdcc0bb4f47ea3551b1c0558ee554656ca7c/assets/spec_overview.png
--------------------------------------------------------------------------------
/docs/evaluate_custom_model.md:
--------------------------------------------------------------------------------
1 | # Evaluate Custom Vision Language Model on SPEC
2 | We have implemented the testing code for four VLMs, namely [CLIP](https://arxiv.org/abs/2103.00020), [BLIP](https://arxiv.org/abs/2201.12086), [FLAVA](https://arxiv.org/abs/2112.04482), and [CoCa](https://arxiv.org/abs/2205.01917).
3 | If you want to test other custom models, you need to complete the following steps.
4 |
5 | ## Step 1. Implement custom `model wrapper`
6 | First, create a file named `custom_wrapper.py` under [`spec/models/`](https://github.com/wjpoom/SPEC/tree/main/spec/models). Next, define your own `CustomWrapper` class within this file.
7 | `CustomWrapper` needs to inherit from the base class [`BaseWrapper`](https://github.com/wjpoom/SPEC/blob/d1048b57b4f64a813624ce6575ececa86a9178ea/spec/models/base_wrapper.py#L6)
8 | and implement the `i2t_evaluate` and `t2i_evaluate` methods; you can also add any other methods you need. Your `CustomWrapper` should look like the following:
9 | ```python
10 | class CustomWrapper(BaseWrapper):
11 | def __init__(self):
12 | pass
13 | @torch.no_grad()
14 | def i2t_evaluate(self, subset_name, dataloader):
15 | pass
16 | @torch.no_grad()
17 | def t2i_evaluate(self, subset_name, dataloader):
18 | pass
19 | ```
20 | **Note**: pay attention to the return format of `i2t_evaluate` and `t2i_evaluate`. Please refer to the reference implementations in [`CLIPWrapper`](https://github.com/wjpoom/SPEC/blob/d1048b57b4f64a813624ce6575ececa86a9178ea/spec/models/clip_wrapper.py#L9C7-L9C18), [`BLIPWrapper`](https://github.com/wjpoom/SPEC/blob/d1048b57b4f64a813624ce6575ececa86a9178ea/spec/models/blip_wrapper.py#L30) or
21 | [`FLAVAWrapper`](https://github.com/wjpoom/SPEC/blob/d1048b57b4f64a813624ce6575ececa86a9178ea/spec/models/flava_wrapper.py#L8) when implementing your code.
22 |
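For concreteness, below is a rough sketch of what such a wrapper could look like for a hypothetical CLIP-style dual encoder. Treat it as an illustration only: the batch layout (`batch['image']`, `batch['texts']`), the `encode_image`/`encode_text` methods, and the shape of the returned scores are assumptions made for this example, and must be adapted to match `spec/dataset.py` and the reference wrappers linked above.
```python
import torch
from .base_wrapper import BaseWrapper


class CustomWrapper(BaseWrapper):
    """Illustrative wrapper around a hypothetical CLIP-style dual encoder."""

    def __init__(self, model, tokenizer, image_preprocess, device='cuda'):
        self.model = model.to(device).eval()
        self.tokenizer = tokenizer
        # exposed so that get_model() can pass it to the dataset for image preprocessing
        self.image_preprocess = image_preprocess
        self.device = device

    @torch.no_grad()
    def i2t_evaluate(self, subset_name, dataloader):
        """Score every candidate caption for each image in the subset."""
        all_scores = []
        for batch in dataloader:
            # NOTE: the field names and shapes below are assumptions; check
            # spec/dataset.py and the reference wrappers for the real layout.
            images = batch['image'].to(self.device)                # (B, 3, H, W)
            image_feats = self.model.encode_image(images)          # (B, D)
            image_feats = torch.nn.functional.normalize(image_feats, dim=-1)
            candidate_scores = []
            for captions in batch['texts']:                        # K candidate groups of B captions
                tokens = self.tokenizer(captions).to(self.device)
                text_feats = self.model.encode_text(tokens)        # (B, D)
                text_feats = torch.nn.functional.normalize(text_feats, dim=-1)
                candidate_scores.append((image_feats * text_feats).sum(dim=-1))  # (B,)
            all_scores.append(torch.stack(candidate_scores, dim=1).cpu())        # (B, K)
        scores = torch.cat(all_scores, dim=0)
        # The exact return format must follow the reference wrappers linked in the note above.
        return scores

    @torch.no_grad()
    def t2i_evaluate(self, subset_name, dataloader):
        # Symmetric to i2t_evaluate: score every candidate image for each caption.
        raise NotImplementedError
```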
23 | ## Step 2. Add your model in `get_model()`
24 | We define a function named `get_model()` in [`spec/models/__init__.py`](https://github.com/wjpoom/SPEC/blob/main/spec/models/__init__.py),
25 | which handles model loading. You need to add the code that loads your custom model within this function;
26 | simply add the following code block at the end:
27 | ```python
28 | elif model_name == CUSTOM_MODEL_NAME:
29 | from .custom_wrapper import CUSTOMWrapper
30 | model = CUSTOMWrapper(...)
31 | image_preprocess = model.image_preprocess
32 | return model, image_preprocess
33 | ```
34 | where `CUSTOM_MODEL_NAME` is a string that identifies your custom model, and `custom_wrapper` and `CustomWrapper` are the filename and wrapper class name you defined in the first step.
35 |
36 | **Note**: You need to return `image_preprocess`, which will be used during dataset construction to preprocess the input images (e.g., cropping, converting to tensors). If your model doesn't need this step, return `None`.
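Once registered, your custom model can be loaded exactly like the built-in ones. A minimal usage sketch (assuming `'CUSTOM_MODEL_NAME'` is the string you registered above):
```python
import torch
from spec import get_model

device = 'cuda' if torch.cuda.is_available() else 'cpu'
# 'CUSTOM_MODEL_NAME' is a placeholder for the name you added to get_model()
model, image_preprocess = get_model(model_name='CUSTOM_MODEL_NAME',
                                    cache_dir='/path/to/cache/models', device=device)
```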
37 | ## Step 3. Evaluate your custom model
38 | Run the following script to evaluate your custom model on SPEC!
39 | ```shell
40 | model=CUSTOM_MODEL_NAME
41 | model_dir='/path/to/cache/models'
42 | data_dir='/path/to/data'
43 | out_dir='/path/to/save/results'
44 |
45 | python eval.py \
46 | --model-name $model \
47 | --model-cache-dir $model_dir \
48 | --subset-names absolute_size relative_size absolute_spatial relative_spatial existence count \
49 | --data-root $data_dir \
50 | --out-path $out_dir \
51 | --batch-size 64 \
52 | --num-workers 8 \
53 | --seed 1
54 | ```
55 |
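The evaluation results are written under `--out-path`. As a rough sketch of how you might inspect them afterwards, the snippet below assumes the saved file is a dict mapping each subset name to its results, including an `accuracy` entry with `i2t_accuracy` and `t2i_accuracy` (the layout used in the provided evaluation notebooks); the filename here is a hypothetical example, so check your output directory for the actual name.
```python
import os
import torch

out_dir = '/path/to/save/results'
result_file = 'CUSTOM_MODEL_NAME_result.pth'   # hypothetical filename; check out_dir for the real one
result = torch.load(os.path.join(out_dir, result_file))

# print per-subset accuracies for both retrieval directions
for subset_name, subset_result in result.items():
    acc = subset_result['accuracy']
    print(f"{subset_name}: I2T {acc['i2t_accuracy']:.2f} %, T2I {acc['t2i_accuracy']:.2f} %")
```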
--------------------------------------------------------------------------------
/notebooks/evaluate_example_colab.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "id": "048e9a63-4c0e-4c4e-8d90-34f2af8049f9",
6 | "metadata": {},
7 | "source": [
8 | "# Evaluate Popular Vision Language Models on SPEC\n",
9 | "In [our paper](https://arxiv.org/abs/2312.00081), we evaluated four popular VLMs using our SPEC dataset, namely: [CLIP](https://arxiv.org/abs/2103.00020), [BLIP](https://arxiv.org/abs/2201.12086), [FLAVA](https://arxiv.org/abs/2112.04482), and [CoCa](https://arxiv.org/abs/2205.01917). \\\n",
10 | "This notebook will guide readers to reproduce these results step by step, let's go!"
11 | ]
12 | },
13 | {
14 | "cell_type": "markdown",
15 | "id": "c3c22764-d4af-4ab9-8460-d6e4dfc79301",
16 | "metadata": {},
17 | "source": [
18 | "## 1. Prepare the environment"
19 | ]
20 | },
21 | {
22 | "cell_type": "code",
23 | "execution_count": null,
24 | "outputs": [],
25 | "source": [
26 | "!git clone https://github.com/wjpoom/SPEC.git\n",
27 | "%cd SPEC\n",
28 | "!pip install -e . --quiet"
29 | ],
30 | "metadata": {
31 | "collapsed": false
32 | }
33 | },
34 | {
35 | "cell_type": "markdown",
36 | "id": "fb66cde2-cca5-45c1-ac3d-706b3ff4662d",
37 | "metadata": {},
38 | "source": [
39 | "## 2. Import Packages"
40 | ]
41 | },
42 | {
43 | "cell_type": "code",
44 | "execution_count": 1,
45 | "id": "c5a66c29-7726-4126-bdeb-0a20a1eeeac4",
46 | "metadata": {},
47 | "outputs": [],
48 | "source": [
49 | "import zipfile\n",
50 | "import os\n",
51 | "import torch\n",
52 | "import warnings\n",
53 | "warnings.filterwarnings('ignore')\n",
54 | "from spec import get_data, get_model\n",
55 | "from huggingface_hub import hf_hub_download"
56 | ]
57 | },
58 | {
59 | "cell_type": "markdown",
60 | "id": "f8d7636e-5a5e-462e-b99a-45feccf1e3ba",
61 | "metadata": {},
62 | "source": [
63 | "## 3. Prepare the testing dataset\n",
64 | "We store the data on HuggingFace. Before starting, you need to download and decompress the data as following:"
65 | ]
66 | },
67 | {
68 | "cell_type": "code",
69 | "execution_count": 2,
70 | "id": "856550c0-da71-4c09-9533-2df5bd75ba19",
71 | "metadata": {},
72 | "outputs": [],
73 | "source": [
74 | "# specify the path to save the downloaded and extracted the data\n",
75 | "data_root = 'data'\n",
76 | "# download *.zip files\n",
77 | "hf_hub_download(repo_id='wjpoom/SPEC', repo_type='dataset', filename='data.zip', local_dir=data_root)\n",
78 | "# extract *.zip files\n",
79 | "with zipfile.ZipFile(os.path.join(data_root, 'data.zip'), 'r') as zip_ref:\n",
80 | " zip_ref.extractall(os.path.join(data_root))\n",
81 | "# remove the *.zip files\n",
82 | "os.remove(os.path.join(data_root, 'data.zip'))\n",
83 | "print(f'The SPEC dataset is prepared at: {data_root}.')"
84 | ]
85 | },
86 | {
87 | "cell_type": "markdown",
88 | "id": "860f7834-e471-4f9a-9d4b-b4a152ffc133",
89 | "metadata": {},
90 | "source": [
91 | "## 4. Let's Evaluate VLMs on SPEC dataset!"
92 | ]
93 | },
94 | {
95 | "cell_type": "markdown",
96 | "id": "b06a98b7-2b4d-47d3-a13b-794dca37a403",
97 | "metadata": {},
98 | "source": [
99 | "### 4.1 Evaluate CLIP\n",
100 | "We use the `ViT/B-32` variant of [CLIP](https://arxiv.org/abs/2103.00020) with weights resumed from the checkpoint release by OpenAI."
101 | ]
102 | },
103 | {
104 | "cell_type": "code",
105 | "execution_count": 3,
106 | "id": "8e687eb6-08a7-47d8-9927-30b174ea7a32",
107 | "metadata": {
108 | "scrolled": true
109 | },
110 | "outputs": [
111 | {
112 | "name": "stderr",
113 | "output_type": "stream",
114 | "text": [
115 | "Image to Text retrieval on : 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████| 16/16 [00:07<00:00, 2.27it/s]\n"
116 | ]
117 | },
118 | {
119 | "name": "stdout",
120 | "output_type": "stream",
121 | "text": [
122 | "existence subset: Image2Text Accuracy: 57.00 %\n"
123 | ]
124 | },
125 | {
126 | "name": "stderr",
127 | "output_type": "stream",
128 | "text": [
129 | "Text to Image retrieval on : 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████| 16/16 [00:10<00:00, 1.46it/s]\n"
130 | ]
131 | },
132 | {
133 | "name": "stdout",
134 | "output_type": "stream",
135 | "text": [
136 | "existence subset: Text2Image Accuracy: 52.00 %\n"
137 | ]
138 | },
139 | {
140 | "name": "stderr",
141 | "output_type": "stream",
142 | "text": [
143 | "Image to Text retrieval on : 100%|███████████████████████████████████████████████████████████████████████████████████████████████| 32/32 [00:12<00:00, 2.66it/s]\n"
144 | ]
145 | },
146 | {
147 | "name": "stdout",
148 | "output_type": "stream",
149 | "text": [
150 | "relative_spatial subset: Image2Text Accuracy: 27.10 %\n"
151 | ]
152 | },
153 | {
154 | "name": "stderr",
155 | "output_type": "stream",
156 | "text": [
157 | "Text to Image retrieval on : 100%|███████████████████████████████████████████████████████████████████████████████████████████████| 32/32 [00:36<00:00, 1.15s/it]\n"
158 | ]
159 | },
160 | {
161 | "name": "stdout",
162 | "output_type": "stream",
163 | "text": [
164 | "relative_spatial subset: Text2Image Accuracy: 26.75 %\n"
165 | ]
166 | },
167 | {
168 | "name": "stderr",
169 | "output_type": "stream",
170 | "text": [
171 | "Image to Text retrieval on : 100%|██████████████████████████████████████████████████████████████████████████████████████████████████| 24/24 [00:08<00:00, 2.69it/s]\n"
172 | ]
173 | },
174 | {
175 | "name": "stdout",
176 | "output_type": "stream",
177 | "text": [
178 | "absolute_size subset: Image2Text Accuracy: 44.27 %\n"
179 | ]
180 | },
181 | {
182 | "name": "stderr",
183 | "output_type": "stream",
184 | "text": [
185 | "Text to Image retrieval on : 100%|██████████████████████████████████████████████████████████████████████████████████████████████████| 24/24 [00:22<00:00, 1.08it/s]\n"
186 | ]
187 | },
188 | {
189 | "name": "stdout",
190 | "output_type": "stream",
191 | "text": [
192 | "absolute_size subset: Text2Image Accuracy: 36.27 %\n"
193 | ]
194 | },
195 | {
196 | "name": "stderr",
197 | "output_type": "stream",
198 | "text": [
199 | "Image to Text retrieval on : 100%|██████████████████████████████████████████████████████████████████████████████████████████████████| 24/24 [00:08<00:00, 2.67it/s]\n"
200 | ]
201 | },
202 | {
203 | "name": "stdout",
204 | "output_type": "stream",
205 | "text": [
206 | "relative_size subset: Image2Text Accuracy: 34.07 %\n"
207 | ]
208 | },
209 | {
210 | "name": "stderr",
211 | "output_type": "stream",
212 | "text": [
213 | "Text to Image retrieval on : 100%|██████████████████████████████████████████████████████████████████████████████████████████████████| 24/24 [00:22<00:00, 1.08it/s]\n"
214 | ]
215 | },
216 | {
217 | "name": "stdout",
218 | "output_type": "stream",
219 | "text": [
220 | "relative_size subset: Text2Image Accuracy: 32.47 %\n"
221 | ]
222 | },
223 | {
224 | "name": "stderr",
225 | "output_type": "stream",
226 | "text": [
227 | "Image to Text retrieval on : 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████| 71/71 [00:43<00:00, 1.62it/s]\n"
228 | ]
229 | },
230 | {
231 | "name": "stdout",
232 | "output_type": "stream",
233 | "text": [
234 | "count subset: Image2Text Accuracy: 25.27 %\n"
235 | ]
236 | },
237 | {
238 | "name": "stderr",
239 | "output_type": "stream",
240 | "text": [
241 | "Text to Image retrieval on : 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████| 71/71 [02:59<00:00, 2.52s/it]\n"
242 | ]
243 | },
244 | {
245 | "name": "stdout",
246 | "output_type": "stream",
247 | "text": [
248 | "count subset: Text2Image Accuracy: 23.62 %\n"
249 | ]
250 | },
251 | {
252 | "name": "stderr",
253 | "output_type": "stream",
254 | "text": [
255 | "Image to Text retrieval on : 100%|███████████████████████████████████████████████████████████████████████████████████████████████| 71/71 [00:45<00:00, 1.57it/s]\n"
256 | ]
257 | },
258 | {
259 | "name": "stdout",
260 | "output_type": "stream",
261 | "text": [
262 | "absolute_spatial subset: Image2Text Accuracy: 12.64 %\n"
263 | ]
264 | },
265 | {
266 | "name": "stderr",
267 | "output_type": "stream",
268 | "text": [
269 | "Text to Image retrieval on : 100%|███████████████████████████████████████████████████████████████████████████████████████████████| 71/71 [02:55<00:00, 2.48s/it]"
270 | ]
271 | },
272 | {
273 | "name": "stdout",
274 | "output_type": "stream",
275 | "text": [
276 | "absolute_spatial subset: Text2Image Accuracy: 12.20 %\n",
277 | "\n",
278 | "############# finished the evaluation on all selected subsets ###############\n",
279 | "average of all subset: Image2Text Accuracy: 33.39 %\n",
280 | "average of all subset: Text2Image Accuracy: 30.55 %\n",
281 | "result saved to clip_openai_evaluate_result.pth.\n"
282 | ]
283 | },
284 | {
285 | "name": "stderr",
286 | "output_type": "stream",
287 | "text": [
288 | "\n"
289 | ]
290 | }
291 | ],
292 | "source": [
293 | "# load model\n",
294 | "model_cache_dir = 'models/clip' # specify the path to save the downloaded model checkpoint\n",
295 | "device = 'cuda' if torch.cuda.is_available() else 'cpu'\n",
296 | "model, image_preprocess = get_model(model_name='clip', cache_dir=model_cache_dir, device=device)\n",
297 | "# load datasets\n",
298 | "subset_names = ['absolute_size', 'relative_size', 'absolute_spatial', 'relative_spatial', 'existence', 'count']\n",
299 | "subsets = get_data(data_root=data_root, subset_names=subset_names, image_preprocess=image_preprocess, batch_size=64, num_workers=8)\n",
300 | "# evaluate\n",
301 | "result = {}\n",
302 | "i2t_acc = 0.\n",
303 | "t2i_acc = 0.\n",
304 | "subset_num = 0\n",
305 | "for subset_name, dataloaders in subsets.items():\n",
306 | " subset_result = model.evaluate(subset_name=subset_name, dataloaders=dataloaders)\n",
307 | " result[subset_name] = subset_result\n",
308 | " i2t_acc += subset_result['accuracy']['i2t_accuracy']\n",
309 | " t2i_acc += subset_result['accuracy']['t2i_accuracy']\n",
310 | " subset_num += 1\n",
311 | "# print and save results\n",
312 | "print(f'\\n############# finished the evaluation on all selected subsets ###############')\n",
313 | "print(f'average of all subset: Image2Text Accuracy: {i2t_acc/subset_num:.2f} %')\n",
314 | "print(f'average of all subset: Text2Image Accuracy: {t2i_acc/subset_num:.2f} %')\n",
315 | "out_path = 'results' # specify the path to save the evaluation results\n",
316 | "os.makedirs(out_path, exist_ok=True)\n",
317 | "out_fn = f\"clip_result.pth\" # specify the filename according to the model you used\n",
318 | "torch.save(result, os.path.join(out_path, out_fn))\n",
319 | "print(f'result saved to {out_fn}.')"
320 | ]
321 | },
322 | {
323 | "cell_type": "markdown",
324 | "id": "0614e38d-646b-4d7e-a6e7-0a6b05307da9",
325 | "metadata": {},
326 | "source": [
327 | "### 4.2 Evaluate BLIP\n",
328 | "We use the `ViT-B` variant of [BLIP](https://arxiv.org/abs/2201.12086) with weights resumed from the checkpoint released in this [link](https://github.com/salesforce/BLIP), which is finetuned on COCO for image-text retrieval."
329 | ]
330 | },
331 | {
332 | "cell_type": "code",
333 | "execution_count": 4,
334 | "id": "ec215787-f3ea-4890-af47-7e83abc3a9a6",
335 | "metadata": {},
336 | "outputs": [
337 | {
338 | "name": "stdout",
339 | "output_type": "stream",
340 | "text": [
341 | "load checkpoint from ~/.cache/blip/blip-coco-base.pth\n",
342 | "missing keys:\n",
343 | "[]\n"
344 | ]
345 | },
346 | {
347 | "name": "stderr",
348 | "output_type": "stream",
349 | "text": [
350 | "Image to Text retrieval on : 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████| 16/16 [00:36<00:00, 2.29s/it]\n"
351 | ]
352 | },
353 | {
354 | "name": "stdout",
355 | "output_type": "stream",
356 | "text": [
357 | "existence subset: Image2Text Accuracy: 55.50 %\n"
358 | ]
359 | },
360 | {
361 | "name": "stderr",
362 | "output_type": "stream",
363 | "text": [
364 | "Text to Image retrieval on : 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████| 16/16 [00:38<00:00, 2.39s/it]\n"
365 | ]
366 | },
367 | {
368 | "name": "stdout",
369 | "output_type": "stream",
370 | "text": [
371 | "existence subset: Text2Image Accuracy: 50.10 %\n"
372 | ]
373 | },
374 | {
375 | "name": "stderr",
376 | "output_type": "stream",
377 | "text": [
378 | "Image to Text retrieval on : 100%|███████████████████████████████████████████████████████████████████████████████████████████████| 32/32 [01:11<00:00, 2.23s/it]\n"
379 | ]
380 | },
381 | {
382 | "name": "stdout",
383 | "output_type": "stream",
384 | "text": [
385 | "relative_spatial subset: Image2Text Accuracy: 30.65 %\n"
386 | ]
387 | },
388 | {
389 | "name": "stderr",
390 | "output_type": "stream",
391 | "text": [
392 | "Text to Image retrieval on : 100%|███████████████████████████████████████████████████████████████████████████████████████████████| 32/32 [02:17<00:00, 4.31s/it]\n"
393 | ]
394 | },
395 | {
396 | "name": "stdout",
397 | "output_type": "stream",
398 | "text": [
399 | "relative_spatial subset: Text2Image Accuracy: 29.60 %\n"
400 | ]
401 | },
402 | {
403 | "name": "stderr",
404 | "output_type": "stream",
405 | "text": [
406 | "Image to Text retrieval on : 100%|██████████████████████████████████████████████████████████████████████████████████████████████████| 24/24 [00:55<00:00, 2.30s/it]\n"
407 | ]
408 | },
409 | {
410 | "name": "stdout",
411 | "output_type": "stream",
412 | "text": [
413 | "absolute_size subset: Image2Text Accuracy: 43.20 %\n"
414 | ]
415 | },
416 | {
417 | "name": "stderr",
418 | "output_type": "stream",
419 | "text": [
420 | "Text to Image retrieval on : 100%|██████████████████████████████████████████████████████████████████████████████████████████████████| 24/24 [01:20<00:00, 3.36s/it]\n"
421 | ]
422 | },
423 | {
424 | "name": "stdout",
425 | "output_type": "stream",
426 | "text": [
427 | "absolute_size subset: Text2Image Accuracy: 43.07 %\n"
428 | ]
429 | },
430 | {
431 | "name": "stderr",
432 | "output_type": "stream",
433 | "text": [
434 | "Image to Text retrieval on : 100%|██████████████████████████████████████████████████████████████████████████████████████████████████| 24/24 [00:54<00:00, 2.26s/it]\n"
435 | ]
436 | },
437 | {
438 | "name": "stdout",
439 | "output_type": "stream",
440 | "text": [
441 | "relative_size subset: Image2Text Accuracy: 34.33 %\n"
442 | ]
443 | },
444 | {
445 | "name": "stderr",
446 | "output_type": "stream",
447 | "text": [
448 | "Text to Image retrieval on : 100%|██████████████████████████████████████████████████████████████████████████████████████████████████| 24/24 [01:20<00:00, 3.37s/it]\n"
449 | ]
450 | },
451 | {
452 | "name": "stdout",
453 | "output_type": "stream",
454 | "text": [
455 | "relative_size subset: Text2Image Accuracy: 33.27 %\n"
456 | ]
457 | },
458 | {
459 | "name": "stderr",
460 | "output_type": "stream",
461 | "text": [
462 | "Image to Text retrieval on : 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████| 71/71 [02:44<00:00, 2.32s/it]\n"
463 | ]
464 | },
465 | {
466 | "name": "stdout",
467 | "output_type": "stream",
468 | "text": [
469 | "count subset: Image2Text Accuracy: 36.87 %\n"
470 | ]
471 | },
472 | {
473 | "name": "stderr",
474 | "output_type": "stream",
475 | "text": [
476 | "Text to Image retrieval on : 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████| 71/71 [10:56<00:00, 9.25s/it]\n"
477 | ]
478 | },
479 | {
480 | "name": "stdout",
481 | "output_type": "stream",
482 | "text": [
483 | "count subset: Text2Image Accuracy: 37.40 %\n"
484 | ]
485 | },
486 | {
487 | "name": "stderr",
488 | "output_type": "stream",
489 | "text": [
490 | "Image to Text retrieval on : 100%|███████████████████████████████████████████████████████████████████████████████████████████████| 71/71 [02:53<00:00, 2.44s/it]\n"
491 | ]
492 | },
493 | {
494 | "name": "stdout",
495 | "output_type": "stream",
496 | "text": [
497 | "absolute_spatial subset: Image2Text Accuracy: 12.07 %\n"
498 | ]
499 | },
500 | {
501 | "name": "stderr",
502 | "output_type": "stream",
503 | "text": [
504 | "Text to Image retrieval on : 100%|███████████████████████████████████████████████████████████████████████████████████████████████| 71/71 [10:56<00:00, 9.25s/it]\n"
505 | ]
506 | },
507 | {
508 | "name": "stdout",
509 | "output_type": "stream",
510 | "text": [
511 | "absolute_spatial subset: Text2Image Accuracy: 11.58 %\n",
512 | "\n",
513 | "############# finished the evaluation on all selected subsets ###############\n",
514 | "average of all subset: Image2Text Accuracy: 35.44 %\n",
515 | "average of all subset: Text2Image Accuracy: 34.17 %\n",
516 | "result saved to blip_evaluate_result.pth.\n"
517 | ]
518 | }
519 | ],
520 | "source": [
521 | "# load model\n",
522 | "model_cache_dir = 'models/blip' # specify the path to save the downloaded model checkpoint\n",
523 | "device = 'cuda' if torch.cuda.is_available() else 'cpu'\n",
524 | "model, image_preprocess = get_model(model_name='blip', cache_dir=model_cache_dir, device=device)\n",
525 | "# load datasets\n",
526 | "subset_names = ['absolute_size', 'relative_size', 'absolute_spatial', 'relative_spatial', 'existence', 'count']\n",
527 | "subsets = get_data(data_root=data_root, subset_names=subset_names, image_preprocess=image_preprocess, batch_size=64, num_workers=8)\n",
528 | "# evaluate\n",
529 | "result = {}\n",
530 | "i2t_acc = 0.\n",
531 | "t2i_acc = 0.\n",
532 | "subset_num = 0\n",
533 | "for subset_name, dataloaders in subsets.items():\n",
534 | " subset_result = model.evaluate(subset_name=subset_name, dataloaders=dataloaders)\n",
535 | " result[subset_name] = subset_result\n",
536 | " i2t_acc += subset_result['accuracy']['i2t_accuracy']\n",
537 | " t2i_acc += subset_result['accuracy']['t2i_accuracy']\n",
538 | " subset_num += 1\n",
539 | "# print and save results\n",
540 | "print(f'\\n############# finished the evaluation on all selected subsets ###############')\n",
541 | "print(f'average of all subset: Image2Text Accuracy: {i2t_acc/subset_num:.2f} %')\n",
542 | "print(f'average of all subset: Text2Image Accuracy: {t2i_acc/subset_num:.2f} %')\n",
543 | "out_path = 'results' # specify the path to save the evaluation results\n",
544 | "os.makedirs(out_path, exist_ok=True)\n",
545 | "out_fn = f\"blip_result.pth\" # specify the filename according to the model you used\n",
546 | "torch.save(result, os.path.join(out_path, out_fn))\n",
547 | "print(f'result saved to {out_fn}.')"
548 | ]
549 | },
550 | {
551 | "cell_type": "markdown",
552 | "id": "33b537d7-3343-4aa5-a632-ef07ed8cfa6d",
553 | "metadata": {},
554 | "source": [
555 | "### 4.3 Evaluate FLAVA\n",
556 | "We use the `full` version of [FLAVA](https://arxiv.org/abs/2112.04482) with weights resumed from this [checkpoint](https://huggingface.co/facebook/flava-full)."
557 | ]
558 | },
559 | {
560 | "cell_type": "code",
561 | "execution_count": null,
562 | "id": "cc335f88-09f7-445b-b548-23bac460e052",
563 | "metadata": {},
564 | "outputs": [
565 | {
566 | "name": "stderr",
567 | "output_type": "stream",
568 | "text": [
569 | "Image to Text retrieval on : 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████| 16/16 [01:14<00:00, 4.65s/it]\n"
570 | ]
571 | },
572 | {
573 | "name": "stdout",
574 | "output_type": "stream",
575 | "text": [
576 | "existence subset: Image2Text Accuracy: 57.90 %\n"
577 | ]
578 | },
579 | {
580 | "name": "stderr",
581 | "output_type": "stream",
582 | "text": [
583 | "Text to Image retrieval on : 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████| 16/16 [02:07<00:00, 7.99s/it]\n"
584 | ]
585 | },
586 | {
587 | "name": "stdout",
588 | "output_type": "stream",
589 | "text": [
590 | "existence subset: Text2Image Accuracy: 51.80 %\n"
591 | ]
592 | },
593 | {
594 | "name": "stderr",
595 | "output_type": "stream",
596 | "text": [
597 | "Image to Text retrieval on : 100%|███████████████████████████████████████████████████████████████████████████████████████████████| 32/32 [02:33<00:00, 4.78s/it]\n"
598 | ]
599 | },
600 | {
601 | "name": "stdout",
602 | "output_type": "stream",
603 | "text": [
604 | "relative_spatial subset: Image2Text Accuracy: 25.80 %\n"
605 | ]
606 | },
607 | {
608 | "name": "stderr",
609 | "output_type": "stream",
610 | "text": [
611 | "Text to Image retrieval on : 100%|███████████████████████████████████████████████████████████████████████████████████████████████| 32/32 [07:53<00:00, 14.80s/it]\n"
612 | ]
613 | },
614 | {
615 | "name": "stdout",
616 | "output_type": "stream",
617 | "text": [
618 | "relative_spatial subset: Text2Image Accuracy: 25.85 %\n"
619 | ]
620 | },
621 | {
622 | "name": "stderr",
623 | "output_type": "stream",
624 | "text": [
625 | "Image to Text retrieval on : 100%|██████████████████████████████████████████████████████████████████████████████████████████████████| 24/24 [01:46<00:00, 4.42s/it]\n"
626 | ]
627 | },
628 | {
629 | "name": "stdout",
630 | "output_type": "stream",
631 | "text": [
632 | "absolute_size subset: Image2Text Accuracy: 37.07 %\n"
633 | ]
634 | },
635 | {
636 | "name": "stderr",
637 | "output_type": "stream",
638 | "text": [
639 | "Text to Image retrieval on : 100%|██████████████████████████████████████████████████████████████████████████████████████████████████| 24/24 [04:31<00:00, 11.29s/it]\n"
640 | ]
641 | },
642 | {
643 | "name": "stdout",
644 | "output_type": "stream",
645 | "text": [
646 | "absolute_size subset: Text2Image Accuracy: 36.67 %\n"
647 | ]
648 | },
649 | {
650 | "name": "stderr",
651 | "output_type": "stream",
652 | "text": [
653 | "Image to Text retrieval on : 100%|██████████████████████████████████████████████████████████████████████████████████████████████████| 24/24 [01:46<00:00, 4.45s/it]\n"
654 | ]
655 | },
656 | {
657 | "name": "stdout",
658 | "output_type": "stream",
659 | "text": [
660 | "relative_size subset: Image2Text Accuracy: 33.53 %\n"
661 | ]
662 | },
663 | {
664 | "name": "stderr",
665 | "output_type": "stream",
666 | "text": [
667 | "Text to Image retrieval on : 100%|██████████████████████████████████████████████████████████████████████████████████████████████████| 24/24 [04:44<00:00, 11.86s/it]\n"
668 | ]
669 | },
670 | {
671 | "name": "stdout",
672 | "output_type": "stream",
673 | "text": [
674 | "relative_size subset: Text2Image Accuracy: 33.07 %\n"
675 | ]
676 | },
677 | {
678 | "name": "stderr",
679 | "output_type": "stream",
680 | "text": [
681 | "Image to Text retrieval on : 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████| 71/71 [05:46<00:00, 4.88s/it]\n"
682 | ]
683 | },
684 | {
685 | "name": "stdout",
686 | "output_type": "stream",
687 | "text": [
688 | "count subset: Image2Text Accuracy: 14.00 %\n"
689 | ]
690 | },
691 | {
692 | "name": "stderr",
693 | "output_type": "stream",
694 | "text": [
695 | "Text to Image retrieval on : 1%|█▍ | 1/71 [01:43<2:00:24, 103.21s/it]"
696 | ]
697 | }
698 | ],
699 | "source": [
700 | "# load model\n",
701 | "model_cache_dir = 'models/flava' # specify the path to save the downloaded model checkpoint\n",
702 | "device = 'cuda' if torch.cuda.is_available() else 'cpu'\n",
703 | "model, image_preprocess = get_model(model_name='flava', cache_dir=model_cache_dir, device=device)\n",
704 | "# load datasets\n",
705 | "subset_names = ['absolute_size', 'relative_size', 'absolute_spatial', 'relative_spatial', 'existence', 'count']\n",
706 | "subsets = get_data(data_root=data_root, subset_names=subset_names, image_preprocess=image_preprocess, batch_size=64, num_workers=8)\n",
707 | "# evaluate\n",
708 | "result = {}\n",
709 | "i2t_acc = 0.\n",
710 | "t2i_acc = 0.\n",
711 | "subset_num = 0\n",
712 | "for subset_name, dataloaders in subsets.items():\n",
713 | " subset_result = model.evaluate(subset_name=subset_name, dataloaders=dataloaders)\n",
714 | " result[subset_name] = subset_result\n",
715 | " i2t_acc += subset_result['accuracy']['i2t_accuracy']\n",
716 | " t2i_acc += subset_result['accuracy']['t2i_accuracy']\n",
717 | " subset_num += 1\n",
718 | "# print and save results\n",
719 | "print(f'\\n############# finished the evaluation on all selected subsets ###############')\n",
720 | "print(f'average of all subset: Image2Text Accuracy: {i2t_acc/subset_num:.2f} %')\n",
721 | "print(f'average of all subset: Text2Image Accuracy: {t2i_acc/subset_num:.2f} %')\n",
722 | "out_path = 'results' # specify the path to save the evluation results\n",
723 | "os.makedirs(out_path, exist_ok=True)\n",
724 | "out_fn = f\"flava_result.pth\" # specify the filename according to the model you used\n",
725 | "torch.save(result, os.path.join(out_path, out_fn))\n",
726 | "print(f'result saved to {out_fn}.')"
727 | ]
728 | },
729 | {
730 | "cell_type": "markdown",
731 | "id": "e3528ed5-6ab2-456f-b712-0f7250d9d557",
732 | "metadata": {},
733 | "source": [
734 | "### 4.4 Evaluate CoCa\n",
735 | "We used the `ViT/B-32` variant of [CoCa](https://arxiv.org/abs/2205.01917) model with weights resumed from the [checkpoint](https://github.com/mlfoundations/open_clip) that pretrained on LAION-2B dataset."
736 | ]
737 | },
738 | {
739 | "cell_type": "code",
740 | "execution_count": null,
741 | "id": "2fde7225-1399-45a5-b59d-a71921f140be",
742 | "metadata": {},
743 | "outputs": [],
744 | "source": [
745 | "# load model\n",
746 | "model_cache_dir = 'models/coca' # specify the path to save the downloaded model checkpoint\n",
747 | "device = 'cuda' if torch.cuda.is_available() else 'cpu'\n",
748 | "model, image_preprocess = get_model(model_name='coca', cache_dir=model_cache_dir, device=device)\n",
749 | "# load datasets\n",
750 | "subset_names = ['absolute_size', 'relative_size', 'absolute_spatial', 'relative_spatial', 'existence', 'count']\n",
751 | "subsets = get_data(data_root=data_root, subset_names=subset_names, image_preprocess=image_preprocess, batch_size=64, num_workers=8)\n",
752 | "# evaluate\n",
753 | "result = {}\n",
754 | "i2t_acc = 0.\n",
755 | "t2i_acc = 0.\n",
756 | "subset_num = 0\n",
757 | "for subset_name, dataloaders in subsets.items():\n",
758 | " subset_result = model.evaluate(subset_name=subset_name, dataloaders=dataloaders)\n",
759 | " result[subset_name] = subset_result\n",
760 | " i2t_acc += subset_result['accuracy']['i2t_accuracy']\n",
761 | " t2i_acc += subset_result['accuracy']['t2i_accuracy']\n",
762 | " subset_num += 1\n",
763 | "# print and save results\n",
764 | "print(f'\\n############# finished the evaluation on all selected subsets ###############')\n",
765 | "print(f'average of all subset: Image2Text Accuracy: {i2t_acc/subset_num:.2f} %')\n",
766 | "print(f'average of all subset: Text2Image Accuracy: {t2i_acc/subset_num:.2f} %')\n",
767 | "out_path = 'results' # specify the path to save the evluation results\n",
768 | "os.makedirs(out_path, exist_ok=True)\n",
769 | "out_fn = f\"coca_result.pth\" # specify the filename according to the model you used\n",
770 | "torch.save(result, os.path.join(out_path, out_fn))\n",
771 | "print(f'result saved to {out_fn}.')"
772 | ]
773 | },
774 | {
775 | "cell_type": "markdown",
776 | "id": "b46eaa47-3b2a-4e14-b366-8a567f7e1e54",
777 | "metadata": {},
778 | "source": [
779 | "## What's Next?\n",
780 | "Want to test your own visual language model on SPEC? We have provided a [tutorial](https://github.com/wjpoom/SPEC/blob/main/docs/evaluate_custom_model.md) to help evaluate custom models, feel free to have a try."
781 | ]
782 | }
783 | ],
784 | "metadata": {
785 | "kernelspec": {
786 | "display_name": "spec",
787 | "language": "python",
788 | "name": "spec"
789 | },
790 | "language_info": {
791 | "codemirror_mode": {
792 | "name": "ipython",
793 | "version": 3
794 | },
795 | "file_extension": ".py",
796 | "mimetype": "text/x-python",
797 | "name": "python",
798 | "nbconvert_exporter": "python",
799 | "pygments_lexer": "ipython3",
800 | "version": "3.8.18"
801 | }
802 | },
803 | "nbformat": 4,
804 | "nbformat_minor": 5
805 | }
806 |
--------------------------------------------------------------------------------
/notebooks/evaluate_example_local.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "id": "048e9a63-4c0e-4c4e-8d90-34f2af8049f9",
6 | "metadata": {},
7 | "source": [
8 | "# Evaluate Popular Vision Language Models on SPEC\n",
9 | "In [our paper](https://arxiv.org/abs/2312.00081), we evaluated four popular VLMs using our SPEC dataset, namely: [CLIP](https://arxiv.org/abs/2103.00020), [BLIP](https://arxiv.org/abs/2201.12086), [FLAVA](https://arxiv.org/abs/2112.04482), and [CoCa](https://arxiv.org/abs/2205.01917). \\\n",
10 | "This notebook will guide readers to reproduce these results step by step, let's go!"
11 | ]
12 | },
13 | {
14 | "cell_type": "markdown",
15 | "id": "c3c22764-d4af-4ab9-8460-d6e4dfc79301",
16 | "metadata": {},
17 | "source": [
18 | "## 1. How to use this notebook?\n",
19 | "You can run this notebook locally, before running, make sure that you have prepared the environment. \\\n",
20 | "You can also directly run this online notebook: [](https://colab.research.google.com/github/wjpoom/SPEC/blob/main/notebooks/evaluate_example_colab.ipynb)."
21 | ]
22 | },
23 | {
24 | "cell_type": "markdown",
25 | "id": "fb66cde2-cca5-45c1-ac3d-706b3ff4662d",
26 | "metadata": {},
27 | "source": [
28 | "## 2. Import Packages"
29 | ]
30 | },
31 | {
32 | "cell_type": "code",
33 | "execution_count": 1,
34 | "id": "c5a66c29-7726-4126-bdeb-0a20a1eeeac4",
35 | "metadata": {},
36 | "outputs": [],
37 | "source": [
38 | "import zipfile\n",
39 | "import os\n",
40 | "import torch\n",
41 | "import warnings\n",
42 | "warnings.filterwarnings('ignore')\n",
43 | "from spec import get_data, get_model\n",
44 | "from huggingface_hub import hf_hub_download"
45 | ]
46 | },
47 | {
48 | "cell_type": "markdown",
49 | "id": "f8d7636e-5a5e-462e-b99a-45feccf1e3ba",
50 | "metadata": {},
51 | "source": [
52 | "## 3. Prepare the testing dataset\n",
53 | "We store the data on HuggingFace. Before starting, you need to download and decompress the data as following:"
54 | ]
55 | },
56 | {
57 | "cell_type": "code",
58 | "execution_count": 2,
59 | "id": "856550c0-da71-4c09-9533-2df5bd75ba19",
60 | "metadata": {},
61 | "outputs": [],
62 | "source": [
63 | "# specify the path to save the downloaded and extracted the data\n",
64 | "data_root = '/path/to/save/data'\n",
65 | "# download *.zip files\n",
66 | "hf_hub_download(repo_id='wjpoom/SPEC', repo_type='dataset', filename='data.zip', local_dir=data_root)\n",
67 | "# extract *.zip files\n",
68 | "with zipfile.ZipFile(os.path.join(data_root, 'data.zip'), 'r') as zip_ref:\n",
69 | " zip_ref.extractall(os.path.join(data_root))\n",
70 | "# remove the *.zip files\n",
71 | "os.remove(os.path.join(data_root, 'data.zip'))\n",
72 | "print(f'The SPEC dataset is prepared at: {data_root}.')"
73 | ]
74 | },
75 | {
76 | "cell_type": "markdown",
77 | "id": "860f7834-e471-4f9a-9d4b-b4a152ffc133",
78 | "metadata": {},
79 | "source": [
80 | "## 4. Let's Evaluate VLMs on SPEC dataset!"
81 | ]
82 | },
83 | {
84 | "cell_type": "markdown",
85 | "id": "b06a98b7-2b4d-47d3-a13b-794dca37a403",
86 | "metadata": {},
87 | "source": [
88 | "### 4.1 Evaluate CLIP\n",
89 | "We use the `ViT/B-32` variant of [CLIP](https://arxiv.org/abs/2103.00020) with weights resumed from the checkpoint release by OpenAI."
90 | ]
91 | },
92 | {
93 | "cell_type": "code",
94 | "execution_count": 3,
95 | "id": "8e687eb6-08a7-47d8-9927-30b174ea7a32",
96 | "metadata": {
97 | "scrolled": true
98 | },
99 | "outputs": [
100 | {
101 | "name": "stderr",
102 | "output_type": "stream",
103 | "text": [
104 | "Image to Text retrieval on : 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████| 16/16 [00:07<00:00, 2.27it/s]\n"
105 | ]
106 | },
107 | {
108 | "name": "stdout",
109 | "output_type": "stream",
110 | "text": [
111 | "existence subset: Image2Text Accuracy: 57.00 %\n"
112 | ]
113 | },
114 | {
115 | "name": "stderr",
116 | "output_type": "stream",
117 | "text": [
118 | "Text to Image retrieval on : 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████| 16/16 [00:10<00:00, 1.46it/s]\n"
119 | ]
120 | },
121 | {
122 | "name": "stdout",
123 | "output_type": "stream",
124 | "text": [
125 | "existence subset: Text2Image Accuracy: 52.00 %\n"
126 | ]
127 | },
128 | {
129 | "name": "stderr",
130 | "output_type": "stream",
131 | "text": [
132 | "Image to Text retrieval on : 100%|███████████████████████████████████████████████████████████████████████████████████████████████| 32/32 [00:12<00:00, 2.66it/s]\n"
133 | ]
134 | },
135 | {
136 | "name": "stdout",
137 | "output_type": "stream",
138 | "text": [
139 | "relative_spatial subset: Image2Text Accuracy: 27.10 %\n"
140 | ]
141 | },
142 | {
143 | "name": "stderr",
144 | "output_type": "stream",
145 | "text": [
146 | "Text to Image retrieval on : 100%|███████████████████████████████████████████████████████████████████████████████████████████████| 32/32 [00:36<00:00, 1.15s/it]\n"
147 | ]
148 | },
149 | {
150 | "name": "stdout",
151 | "output_type": "stream",
152 | "text": [
153 | "relative_spatial subset: Text2Image Accuracy: 26.75 %\n"
154 | ]
155 | },
156 | {
157 | "name": "stderr",
158 | "output_type": "stream",
159 | "text": [
160 | "Image to Text retrieval on : 100%|██████████████████████████████████████████████████████████████████████████████████████████████████| 24/24 [00:08<00:00, 2.69it/s]\n"
161 | ]
162 | },
163 | {
164 | "name": "stdout",
165 | "output_type": "stream",
166 | "text": [
167 | "absolute_size subset: Image2Text Accuracy: 44.27 %\n"
168 | ]
169 | },
170 | {
171 | "name": "stderr",
172 | "output_type": "stream",
173 | "text": [
174 | "Text to Image retrieval on : 100%|██████████████████████████████████████████████████████████████████████████████████████████████████| 24/24 [00:22<00:00, 1.08it/s]\n"
175 | ]
176 | },
177 | {
178 | "name": "stdout",
179 | "output_type": "stream",
180 | "text": [
181 | "absolute_size subset: Text2Image Accuracy: 36.27 %\n"
182 | ]
183 | },
184 | {
185 | "name": "stderr",
186 | "output_type": "stream",
187 | "text": [
188 | "Image to Text retrieval on : 100%|██████████████████████████████████████████████████████████████████████████████████████████████████| 24/24 [00:08<00:00, 2.67it/s]\n"
189 | ]
190 | },
191 | {
192 | "name": "stdout",
193 | "output_type": "stream",
194 | "text": [
195 | "relative_size subset: Image2Text Accuracy: 34.07 %\n"
196 | ]
197 | },
198 | {
199 | "name": "stderr",
200 | "output_type": "stream",
201 | "text": [
202 | "Text to Image retrieval on : 100%|██████████████████████████████████████████████████████████████████████████████████████████████████| 24/24 [00:22<00:00, 1.08it/s]\n"
203 | ]
204 | },
205 | {
206 | "name": "stdout",
207 | "output_type": "stream",
208 | "text": [
209 | "relative_size subset: Text2Image Accuracy: 32.47 %\n"
210 | ]
211 | },
212 | {
213 | "name": "stderr",
214 | "output_type": "stream",
215 | "text": [
216 | "Image to Text retrieval on : 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████| 71/71 [00:43<00:00, 1.62it/s]\n"
217 | ]
218 | },
219 | {
220 | "name": "stdout",
221 | "output_type": "stream",
222 | "text": [
223 | "count subset: Image2Text Accuracy: 25.27 %\n"
224 | ]
225 | },
226 | {
227 | "name": "stderr",
228 | "output_type": "stream",
229 | "text": [
230 | "Text to Image retrieval on : 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████| 71/71 [02:59<00:00, 2.52s/it]\n"
231 | ]
232 | },
233 | {
234 | "name": "stdout",
235 | "output_type": "stream",
236 | "text": [
237 | "count subset: Text2Image Accuracy: 23.62 %\n"
238 | ]
239 | },
240 | {
241 | "name": "stderr",
242 | "output_type": "stream",
243 | "text": [
244 | "Image to Text retrieval on : 100%|███████████████████████████████████████████████████████████████████████████████████████████████| 71/71 [00:45<00:00, 1.57it/s]\n"
245 | ]
246 | },
247 | {
248 | "name": "stdout",
249 | "output_type": "stream",
250 | "text": [
251 | "absolute_spatial subset: Image2Text Accuracy: 12.64 %\n"
252 | ]
253 | },
254 | {
255 | "name": "stderr",
256 | "output_type": "stream",
257 | "text": [
258 | "Text to Image retrieval on : 100%|███████████████████████████████████████████████████████████████████████████████████████████████| 71/71 [02:55<00:00, 2.48s/it]"
259 | ]
260 | },
261 | {
262 | "name": "stdout",
263 | "output_type": "stream",
264 | "text": [
265 | "absolute_spatial subset: Text2Image Accuracy: 12.20 %\n",
266 | "\n",
267 | "############# finished the evaluation on all selected subsets ###############\n",
268 | "average of all subset: Image2Text Accuracy: 33.39 %\n",
269 | "average of all subset: Text2Image Accuracy: 30.55 %\n",
270 | "result saved to clip_openai_evaluate_result.pth.\n"
271 | ]
272 | },
273 | {
274 | "name": "stderr",
275 | "output_type": "stream",
276 | "text": [
277 | "\n"
278 | ]
279 | }
280 | ],
281 | "source": [
282 | "# load model\n",
283 | "model_cache_dir = '/path/to/cache/models' # specify the path to save the downloaded model checkpoint\n",
284 | "device = 'cuda' if torch.cuda.is_available() else 'cpu'\n",
285 | "model, image_preprocess = get_model(model_name='clip', cache_dir=model_cache_dir, device=device)\n",
286 | "# load datasets\n",
287 | "subset_names = ['absolute_size', 'relative_size', 'absolute_spatial', 'relative_spatial', 'existence', 'count']\n",
288 | "subsets = get_data(data_root=data_root, subset_names=subset_names, image_preprocess=image_preprocess, batch_size=64, num_workers=8)\n",
289 | "# evaluate\n",
290 | "result = {}\n",
291 | "i2t_acc = 0.\n",
292 | "t2i_acc = 0.\n",
293 | "subset_num = 0\n",
294 | "for subset_name, dataloaders in subsets.items():\n",
295 | " subset_result = model.evaluate(subset_name=subset_name, dataloaders=dataloaders)\n",
296 | " result[subset_name] = subset_result\n",
297 | " i2t_acc += subset_result['accuracy']['i2t_accuracy']\n",
298 | " t2i_acc += subset_result['accuracy']['t2i_accuracy']\n",
299 | " subset_num += 1\n",
300 | "# print and save results\n",
301 | "print(f'\\n############# finished the evaluation on all selected subsets ###############')\n",
302 | "print(f'average of all subset: Image2Text Accuracy: {i2t_acc/subset_num:.2f} %')\n",
303 | "print(f'average of all subset: Text2Image Accuracy: {t2i_acc/subset_num:.2f} %')\n",
304 | "out_path = '/path/to/save/results' # specify the path to save the evaluation results\n",
305 | "os.makedirs(out_path, exist_ok=True)\n",
306 | "out_fn = f\"clip_result.pth\" # specify the filename according to the model you used\n",
307 | "torch.save(result, os.path.join(out_path, out_fn))\n",
308 | "print(f'result saved to {out_fn}.')"
309 | ]
310 | },
311 | {
312 | "cell_type": "markdown",
313 | "id": "0614e38d-646b-4d7e-a6e7-0a6b05307da9",
314 | "metadata": {},
315 | "source": [
316 | "### 4.2 Evaluate BLIP\n",
317 | "We use the `ViT-B` variant of [BLIP](https://arxiv.org/abs/2201.12086) with weights resumed from the checkpoint released in this [link](https://github.com/salesforce/BLIP), which is finetuned on COCO for image-text retrieval."
318 | ]
319 | },
320 | {
321 | "cell_type": "code",
322 | "execution_count": 4,
323 | "id": "ec215787-f3ea-4890-af47-7e83abc3a9a6",
324 | "metadata": {},
325 | "outputs": [
326 | {
327 | "name": "stdout",
328 | "output_type": "stream",
329 | "text": [
330 | "load checkpoint from ~/.cache/blip/blip-coco-base.pth\n",
331 | "missing keys:\n",
332 | "[]\n"
333 | ]
334 | },
335 | {
336 | "name": "stderr",
337 | "output_type": "stream",
338 | "text": [
339 | "Image to Text retrieval on : 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████| 16/16 [00:36<00:00, 2.29s/it]\n"
340 | ]
341 | },
342 | {
343 | "name": "stdout",
344 | "output_type": "stream",
345 | "text": [
346 | "existence subset: Image2Text Accuracy: 55.50 %\n"
347 | ]
348 | },
349 | {
350 | "name": "stderr",
351 | "output_type": "stream",
352 | "text": [
353 | "Text to Image retrieval on : 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████| 16/16 [00:38<00:00, 2.39s/it]\n"
354 | ]
355 | },
356 | {
357 | "name": "stdout",
358 | "output_type": "stream",
359 | "text": [
360 | "existence subset: Text2Image Accuracy: 50.10 %\n"
361 | ]
362 | },
363 | {
364 | "name": "stderr",
365 | "output_type": "stream",
366 | "text": [
367 | "Image to Text retrieval on : 100%|███████████████████████████████████████████████████████████████████████████████████████████████| 32/32 [01:11<00:00, 2.23s/it]\n"
368 | ]
369 | },
370 | {
371 | "name": "stdout",
372 | "output_type": "stream",
373 | "text": [
374 | "relative_spatial subset: Image2Text Accuracy: 30.65 %\n"
375 | ]
376 | },
377 | {
378 | "name": "stderr",
379 | "output_type": "stream",
380 | "text": [
381 | "Text to Image retrieval on : 100%|███████████████████████████████████████████████████████████████████████████████████████████████| 32/32 [02:17<00:00, 4.31s/it]\n"
382 | ]
383 | },
384 | {
385 | "name": "stdout",
386 | "output_type": "stream",
387 | "text": [
388 | "relative_spatial subset: Text2Image Accuracy: 29.60 %\n"
389 | ]
390 | },
391 | {
392 | "name": "stderr",
393 | "output_type": "stream",
394 | "text": [
395 | "Image to Text retrieval on : 100%|██████████████████████████████████████████████████████████████████████████████████████████████████| 24/24 [00:55<00:00, 2.30s/it]\n"
396 | ]
397 | },
398 | {
399 | "name": "stdout",
400 | "output_type": "stream",
401 | "text": [
402 | "absolute_size subset: Image2Text Accuracy: 43.20 %\n"
403 | ]
404 | },
405 | {
406 | "name": "stderr",
407 | "output_type": "stream",
408 | "text": [
409 | "Text to Image retrieval on : 100%|██████████████████████████████████████████████████████████████████████████████████████████████████| 24/24 [01:20<00:00, 3.36s/it]\n"
410 | ]
411 | },
412 | {
413 | "name": "stdout",
414 | "output_type": "stream",
415 | "text": [
416 | "absolute_size subset: Text2Image Accuracy: 43.07 %\n"
417 | ]
418 | },
419 | {
420 | "name": "stderr",
421 | "output_type": "stream",
422 | "text": [
423 | "Image to Text retrieval on : 100%|██████████████████████████████████████████████████████████████████████████████████████████████████| 24/24 [00:54<00:00, 2.26s/it]\n"
424 | ]
425 | },
426 | {
427 | "name": "stdout",
428 | "output_type": "stream",
429 | "text": [
430 | "relative_size subset: Image2Text Accuracy: 34.33 %\n"
431 | ]
432 | },
433 | {
434 | "name": "stderr",
435 | "output_type": "stream",
436 | "text": [
437 | "Text to Image retrieval on : 100%|██████████████████████████████████████████████████████████████████████████████████████████████████| 24/24 [01:20<00:00, 3.37s/it]\n"
438 | ]
439 | },
440 | {
441 | "name": "stdout",
442 | "output_type": "stream",
443 | "text": [
444 | "relative_size subset: Text2Image Accuracy: 33.27 %\n"
445 | ]
446 | },
447 | {
448 | "name": "stderr",
449 | "output_type": "stream",
450 | "text": [
451 | "Image to Text retrieval on : 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████| 71/71 [02:44<00:00, 2.32s/it]\n"
452 | ]
453 | },
454 | {
455 | "name": "stdout",
456 | "output_type": "stream",
457 | "text": [
458 | "count subset: Image2Text Accuracy: 36.87 %\n"
459 | ]
460 | },
461 | {
462 | "name": "stderr",
463 | "output_type": "stream",
464 | "text": [
465 | "Text to Image retrieval on : 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████| 71/71 [10:56<00:00, 9.25s/it]\n"
466 | ]
467 | },
468 | {
469 | "name": "stdout",
470 | "output_type": "stream",
471 | "text": [
472 | "count subset: Text2Image Accuracy: 37.40 %\n"
473 | ]
474 | },
475 | {
476 | "name": "stderr",
477 | "output_type": "stream",
478 | "text": [
479 | "Image to Text retrieval on : 100%|███████████████████████████████████████████████████████████████████████████████████████████████| 71/71 [02:53<00:00, 2.44s/it]\n"
480 | ]
481 | },
482 | {
483 | "name": "stdout",
484 | "output_type": "stream",
485 | "text": [
486 | "absolute_spatial subset: Image2Text Accuracy: 12.07 %\n"
487 | ]
488 | },
489 | {
490 | "name": "stderr",
491 | "output_type": "stream",
492 | "text": [
493 | "Text to Image retrieval on : 100%|███████████████████████████████████████████████████████████████████████████████████████████████| 71/71 [10:56<00:00, 9.25s/it]\n"
494 | ]
495 | },
496 | {
497 | "name": "stdout",
498 | "output_type": "stream",
499 | "text": [
500 | "absolute_spatial subset: Text2Image Accuracy: 11.58 %\n",
501 | "\n",
502 | "############# finished the evaluation on all selected subsets ###############\n",
503 | "average of all subset: Image2Text Accuracy: 35.44 %\n",
504 | "average of all subset: Text2Image Accuracy: 34.17 %\n",
505 | "result saved to blip_evaluate_result.pth.\n"
506 | ]
507 | }
508 | ],
509 | "source": [
510 | "# load model\n",
511 | "model_cache_dir = '/path/to/cache/models' # specify the path to save the downloaded model checkpoint\n",
512 | "device = 'cuda' if torch.cuda.is_available() else 'cpu'\n",
513 | "model, image_preprocess = get_model(model_name='blip', cache_dir=model_cache_dir, device=device)\n",
514 | "# load datasets\n",
515 | "subset_names = ['absolute_size', 'relative_size', 'absolute_spatial', 'relative_spatial', 'existence', 'count']\n",
516 | "subsets = get_data(data_root=data_root, subset_names=subset_names, image_preprocess=image_preprocess, batch_size=64, num_workers=8)\n",
517 | "# evaluate\n",
518 | "result = {}\n",
519 | "i2t_acc = 0.\n",
520 | "t2i_acc = 0.\n",
521 | "subset_num = 0\n",
522 | "for subset_name, dataloaders in subsets.items():\n",
523 | " subset_result = model.evaluate(subset_name=subset_name, dataloaders=dataloaders)\n",
524 | " result[subset_name] = subset_result\n",
525 | " i2t_acc += subset_result['accuracy']['i2t_accuracy']\n",
526 | " t2i_acc += subset_result['accuracy']['t2i_accuracy']\n",
527 | " subset_num += 1\n",
528 | "# print and save results\n",
529 | "print(f'\\n############# finished the evaluation on all selected subsets ###############')\n",
530 | "print(f'average of all subset: Image2Text Accuracy: {i2t_acc/subset_num:.2f} %')\n",
531 | "print(f'average of all subset: Text2Image Accuracy: {t2i_acc/subset_num:.2f} %')\n",
532 | "out_path = '/path/to/save/results' # specify the path to save the evaluation results\n",
533 | "os.makedirs(out_path, exist_ok=True)\n",
534 | "out_fn = f\"blip_result.pth\" # specify the filename according to the model you used\n",
535 | "torch.save(result, os.path.join(out_path, out_fn))\n",
536 | "print(f'result saved to {out_fn}.')"
537 | ]
538 | },
539 | {
540 | "cell_type": "markdown",
541 | "id": "33b537d7-3343-4aa5-a632-ef07ed8cfa6d",
542 | "metadata": {},
543 | "source": [
544 | "### 4.3 Evaluate FLAVA\n",
545 |     "We use the `full` version of [FLAVA](https://arxiv.org/abs/2112.04482), with weights loaded from this [checkpoint](https://huggingface.co/facebook/flava-full)."
546 | ]
547 | },
548 | {
549 | "cell_type": "code",
550 | "execution_count": null,
551 | "id": "cc335f88-09f7-445b-b548-23bac460e052",
552 | "metadata": {},
553 | "outputs": [
554 | {
555 | "name": "stderr",
556 | "output_type": "stream",
557 | "text": [
558 | "Image to Text retrieval on : 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████| 16/16 [01:14<00:00, 4.65s/it]\n"
559 | ]
560 | },
561 | {
562 | "name": "stdout",
563 | "output_type": "stream",
564 | "text": [
565 | "existence subset: Image2Text Accuracy: 57.90 %\n"
566 | ]
567 | },
568 | {
569 | "name": "stderr",
570 | "output_type": "stream",
571 | "text": [
572 | "Text to Image retrieval on : 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████| 16/16 [02:07<00:00, 7.99s/it]\n"
573 | ]
574 | },
575 | {
576 | "name": "stdout",
577 | "output_type": "stream",
578 | "text": [
579 | "existence subset: Text2Image Accuracy: 51.80 %\n"
580 | ]
581 | },
582 | {
583 | "name": "stderr",
584 | "output_type": "stream",
585 | "text": [
586 | "Image to Text retrieval on : 100%|███████████████████████████████████████████████████████████████████████████████████████████████| 32/32 [02:33<00:00, 4.78s/it]\n"
587 | ]
588 | },
589 | {
590 | "name": "stdout",
591 | "output_type": "stream",
592 | "text": [
593 | "relative_spatial subset: Image2Text Accuracy: 25.80 %\n"
594 | ]
595 | },
596 | {
597 | "name": "stderr",
598 | "output_type": "stream",
599 | "text": [
600 | "Text to Image retrieval on : 100%|███████████████████████████████████████████████████████████████████████████████████████████████| 32/32 [07:53<00:00, 14.80s/it]\n"
601 | ]
602 | },
603 | {
604 | "name": "stdout",
605 | "output_type": "stream",
606 | "text": [
607 | "relative_spatial subset: Text2Image Accuracy: 25.85 %\n"
608 | ]
609 | },
610 | {
611 | "name": "stderr",
612 | "output_type": "stream",
613 | "text": [
614 | "Image to Text retrieval on : 100%|██████████████████████████████████████████████████████████████████████████████████████████████████| 24/24 [01:46<00:00, 4.42s/it]\n"
615 | ]
616 | },
617 | {
618 | "name": "stdout",
619 | "output_type": "stream",
620 | "text": [
621 | "absolute_size subset: Image2Text Accuracy: 37.07 %\n"
622 | ]
623 | },
624 | {
625 | "name": "stderr",
626 | "output_type": "stream",
627 | "text": [
628 | "Text to Image retrieval on : 100%|██████████████████████████████████████████████████████████████████████████████████████████████████| 24/24 [04:31<00:00, 11.29s/it]\n"
629 | ]
630 | },
631 | {
632 | "name": "stdout",
633 | "output_type": "stream",
634 | "text": [
635 | "absolute_size subset: Text2Image Accuracy: 36.67 %\n"
636 | ]
637 | },
638 | {
639 | "name": "stderr",
640 | "output_type": "stream",
641 | "text": [
642 | "Image to Text retrieval on : 100%|██████████████████████████████████████████████████████████████████████████████████████████████████| 24/24 [01:46<00:00, 4.45s/it]\n"
643 | ]
644 | },
645 | {
646 | "name": "stdout",
647 | "output_type": "stream",
648 | "text": [
649 | "relative_size subset: Image2Text Accuracy: 33.53 %\n"
650 | ]
651 | },
652 | {
653 | "name": "stderr",
654 | "output_type": "stream",
655 | "text": [
656 | "Text to Image retrieval on : 100%|██████████████████████████████████████████████████████████████████████████████████████████████████| 24/24 [04:44<00:00, 11.86s/it]\n"
657 | ]
658 | },
659 | {
660 | "name": "stdout",
661 | "output_type": "stream",
662 | "text": [
663 | "relative_size subset: Text2Image Accuracy: 33.07 %\n"
664 | ]
665 | },
666 | {
667 | "name": "stderr",
668 | "output_type": "stream",
669 | "text": [
670 | "Image to Text retrieval on : 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████| 71/71 [05:46<00:00, 4.88s/it]\n"
671 | ]
672 | },
673 | {
674 | "name": "stdout",
675 | "output_type": "stream",
676 | "text": [
677 | "count subset: Image2Text Accuracy: 14.00 %\n"
678 | ]
679 | },
680 | {
681 | "name": "stderr",
682 | "output_type": "stream",
683 | "text": [
684 | "Text to Image retrieval on : 1%|█▍ | 1/71 [01:43<2:00:24, 103.21s/it]"
685 | ]
686 | }
687 | ],
688 | "source": [
689 | "# load model\n",
690 | "model_cache_dir = '/path/to/cache/models' # specify the path to save the downloaded model checkpoint\n",
691 | "device = 'cuda' if torch.cuda.is_available() else 'cpu'\n",
692 | "model, image_preprocess = get_model(model_name='flava', cache_dir=model_cache_dir, device=device)\n",
693 | "# load datasets\n",
694 | "subset_names = ['absolute_size', 'relative_size', 'absolute_spatial', 'relative_spatial', 'existence', 'count']\n",
695 | "subsets = get_data(data_root=data_root, subset_names=subset_names, image_preprocess=image_preprocess, batch_size=64, num_workers=8)\n",
696 | "# evaluate\n",
697 | "result = {}\n",
698 | "i2t_acc = 0.\n",
699 | "t2i_acc = 0.\n",
700 | "subset_num = 0\n",
701 | "for subset_name, dataloaders in subsets.items():\n",
702 | " subset_result = model.evaluate(subset_name=subset_name, dataloaders=dataloaders)\n",
703 | " result[subset_name] = subset_result\n",
704 | " i2t_acc += subset_result['accuracy']['i2t_accuracy']\n",
705 | " t2i_acc += subset_result['accuracy']['t2i_accuracy']\n",
706 | " subset_num += 1\n",
707 | "# print and save results\n",
708 | "print(f'\\n############# finished the evaluation on all selected subsets ###############')\n",
709 | "print(f'average of all subset: Image2Text Accuracy: {i2t_acc/subset_num:.2f} %')\n",
710 | "print(f'average of all subset: Text2Image Accuracy: {t2i_acc/subset_num:.2f} %')\n",
711 | "out_path = '/path/to/save/results' # specify the path to save the evaluation results\n",
712 | "os.makedirs(out_path, exist_ok=True)\n",
713 | "out_fn = f\"flava_result.pth\" # specify the filename according to the model you used\n",
714 | "torch.save(result, os.path.join(out_path, out_fn))\n",
715 | "print(f'result saved to {out_fn}.')"
716 | ]
717 | },
718 | {
719 | "cell_type": "markdown",
720 | "id": "e3528ed5-6ab2-456f-b712-0f7250d9d557",
721 | "metadata": {},
722 | "source": [
723 | "### 4.4 Evaluate CoCa\n",
724 |     "We use the `ViT-B/32` variant of the [CoCa](https://arxiv.org/abs/2205.01917) model, with weights loaded from the [checkpoint](https://github.com/mlfoundations/open_clip) pretrained on the LAION-2B dataset."
725 | ]
726 | },
727 | {
728 | "cell_type": "code",
729 | "execution_count": null,
730 | "id": "2fde7225-1399-45a5-b59d-a71921f140be",
731 | "metadata": {},
732 | "outputs": [],
733 | "source": [
734 | "# load model\n",
735 | "model_cache_dir = '/path/to/cache/models' # specify the path to save the downloaded model checkpoint\n",
736 | "device = 'cuda' if torch.cuda.is_available() else 'cpu'\n",
737 | "model, image_preprocess = get_model(model_name='coca', cache_dir=model_cache_dir, device=device)\n",
738 | "# load datasets\n",
739 | "subset_names = ['absolute_size', 'relative_size', 'absolute_spatial', 'relative_spatial', 'existence', 'count']\n",
740 | "subsets = get_data(data_root=data_root, subset_names=subset_names, image_preprocess=image_preprocess, batch_size=64, num_workers=8)\n",
741 | "# evaluate\n",
742 | "result = {}\n",
743 | "i2t_acc = 0.\n",
744 | "t2i_acc = 0.\n",
745 | "subset_num = 0\n",
746 | "for subset_name, dataloaders in subsets.items():\n",
747 | " subset_result = model.evaluate(subset_name=subset_name, dataloaders=dataloaders)\n",
748 | " result[subset_name] = subset_result\n",
749 | " i2t_acc += subset_result['accuracy']['i2t_accuracy']\n",
750 | " t2i_acc += subset_result['accuracy']['t2i_accuracy']\n",
751 | " subset_num += 1\n",
752 | "# print and save results\n",
753 | "print(f'\\n############# finished the evaluation on all selected subsets ###############')\n",
754 | "print(f'average of all subset: Image2Text Accuracy: {i2t_acc/subset_num:.2f} %')\n",
755 | "print(f'average of all subset: Text2Image Accuracy: {t2i_acc/subset_num:.2f} %')\n",
756 | "out_path = '/path/to/save/results' # specify the path to save the evaluation results\n",
757 | "os.makedirs(out_path, exist_ok=True)\n",
758 | "out_fn = f\"coca_result.pth\" # specify the filename according to the model you used\n",
759 | "torch.save(result, os.path.join(out_path, out_fn))\n",
760 | "print(f'result saved to {out_fn}.')"
761 | ]
762 | },
763 | {
764 | "cell_type": "markdown",
765 | "id": "b46eaa47-3b2a-4e14-b366-8a567f7e1e54",
766 | "metadata": {},
767 | "source": [
768 | "## What's Next?\n",
769 |     "Want to test your own vision-language model on SPEC? We provide a [tutorial](https://github.com/wjpoom/SPEC/blob/main/docs/evaluate_custom_model.md) to help you evaluate custom models, so feel free to give it a try."
770 | ]
771 | }
772 | ],
773 | "metadata": {
774 | "kernelspec": {
775 | "display_name": "spec",
776 | "language": "python",
777 | "name": "spec"
778 | },
779 | "language_info": {
780 | "codemirror_mode": {
781 | "name": "ipython",
782 | "version": 3
783 | },
784 | "file_extension": ".py",
785 | "mimetype": "text/x-python",
786 | "name": "python",
787 | "nbconvert_exporter": "python",
788 | "pygments_lexer": "ipython3",
789 | "version": "3.8.18"
790 | }
791 | },
792 | "nbformat": 4,
793 | "nbformat_minor": 5
794 | }
795 |
--------------------------------------------------------------------------------
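The evaluation cells above persist their per-subset results with `torch.save`. As a minimal sketch (the result path below is a placeholder for whatever `out_path`/`out_fn` you chose when saving), the stored dictionary can be loaded back and inspected like this:

```python
import torch

# minimal sketch: load a saved SPEC evaluation result and print the per-subset accuracies
# (the path is a placeholder for the out_path/out_fn used when saving)
result = torch.load('/path/to/save/results/blip_result.pth')

for subset_name, subset_result in result.items():
    acc = subset_result['accuracy']
    print(f"{subset_name}: I2T {acc['i2t_accuracy']:.2f} %, T2I {acc['t2i_accuracy']:.2f} %")

# the raw matching scores are kept as well:
#   subset_result['scores']['i2t_scores'] has shape N x L (N samples, L candidate texts)
#   subset_result['scores']['t2i_scores'] has shape N x K (N samples, K candidate images)
```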
/setup.py:
--------------------------------------------------------------------------------
1 | # setup.py
2 |
3 | from setuptools import setup, find_packages
4 |
5 | setup(
6 | name='spec',
7 | version='1.0.0',
8 | packages=find_packages(),
9 | install_requires=[
10 | 'einops>=0.7.0',
11 | 'open_clip_torch==2.24.0',
12 | 'PyYAML==6.0.1',
13 | 'setuptools==69.2.0',
14 | 'timm==0.9.16',
15 | 'torch==2.2.1',
16 | 'torchvision==0.17.1',
17 | 'tqdm==4.66.2',
18 | 'transformers==4.38.2',
19 | 'huggingface-hub==0.21.4',
20 | 'jedi>=0.16',
21 | 'fairscale==0.4.13',
22 | ]
23 | )
--------------------------------------------------------------------------------
/spec/__init__.py:
--------------------------------------------------------------------------------
1 | from .dataset import get_data
2 | from .models import get_model
--------------------------------------------------------------------------------
/spec/dataset.py:
--------------------------------------------------------------------------------
1 | import os
2 | import json
3 | import torch
4 |
5 | from torch.utils.data import Dataset, DataLoader
6 | from PIL import Image
7 |
8 |
9 | class Image2TextDataset(Dataset):
10 | def __init__(self, subset_root, image_preprocess=None):
11 | """
12 | Args:
13 |             subset_root: the path to the root dir of a subset (e.g. `absolute_size`)
14 | """
15 | self.subset_root = subset_root
16 | self.image_preprocess = image_preprocess
17 |
18 | ann = os.path.join(subset_root, 'image2text.json')
19 | with open(ann, 'r') as f:
20 | self.sample_list = json.load(f)
21 | f.close()
22 |
23 | def __len__(self):
24 | return len(self.sample_list)
25 |
26 | def __getitem__(self, idx):
27 | sample_info = self.sample_list[idx]
28 |
29 | # query image
30 | image_path = os.path.join(self.subset_root, sample_info['query'])
31 | query_image = Image.open(image_path).convert('RGB')
32 | if self.image_preprocess is not None:
33 | query_image = self.image_preprocess(query_image)
34 |
35 | # candidate texts
36 | candidate_texts = sample_info['keys']
37 |
38 | # label
39 | label = sample_info['label']
40 |
41 | sample = {
42 | "query_image": query_image,
43 | "candidate_texts": candidate_texts,
44 | "label": label
45 | }
46 |
47 | return sample
48 |
49 | def collate_fn(self, batch):
50 | query_image = []
51 | candidate_texts = []
52 | label = []
53 | for sample in batch:
54 | query_image.append(sample['query_image'])
55 | candidate_texts.append(sample['candidate_texts'])
56 | label.append(sample['label'])
57 | if self.image_preprocess is not None:
58 | query_image = torch.stack(query_image, dim=0)
59 | batch = {
60 | 'query_image': query_image,
61 | 'candidate_texts': candidate_texts,
62 | 'label': torch.tensor(label)
63 | }
64 | return batch
65 |
66 |
67 | class Text2ImageDataset(Dataset):
68 | def __init__(self, subset_root, image_preprocess=None):
69 | """
70 | Args:
71 |             subset_root: the path to the root dir of a subset (e.g. `absolute_size`)
72 | """
73 | self.subset_root = subset_root
74 | self.image_preprocess = image_preprocess
75 |
76 | ann = os.path.join(subset_root, 'text2image.json')
77 | with open(ann, 'r') as f:
78 | self.sample_list = json.load(f)
79 | f.close()
80 |
81 | def __len__(self):
82 | return len(self.sample_list)
83 |
84 | def __getitem__(self, idx):
85 | sample_info = self.sample_list[idx]
86 |
87 | # query text
88 | query_text = sample_info['query']
89 |
90 | # candidate images
91 | candidate_images = []
92 | for img in sample_info['keys']:
93 | img = Image.open(os.path.join(self.subset_root, img)).convert('RGB')
94 | if self.image_preprocess is not None:
95 | img = self.image_preprocess(img)
96 | candidate_images.append(img)
97 | if self.image_preprocess is not None:
98 | candidate_images = torch.stack(candidate_images, dim=0)
99 |
100 | # label
101 | label = sample_info['label']
102 |
103 | sample = {
104 | "query_text": query_text,
105 | "candidate_images": candidate_images,
106 | "label": label
107 | }
108 |
109 | return sample
110 |
111 | def collate_fn(self, batch):
112 | query_text = []
113 | candidate_images = []
114 | label = []
115 | for sample in batch:
116 | query_text.append(sample['query_text'])
117 | candidate_images.append(sample['candidate_images'])
118 | label.append(sample['label'])
119 | if self.image_preprocess is not None:
120 | candidate_images = torch.stack(candidate_images, dim=0)
121 |
122 | batch = {
123 | 'query_text': query_text,
124 | 'candidate_images': candidate_images,
125 | 'label': torch.tensor(label)
126 | }
127 | return batch
128 |
129 |
130 | def get_data(data_root, subset_names, image_preprocess, batch_size, num_workers):
131 | """
132 | Create SPEC datasets
133 | Args:
134 | data_root: the path to the dir contains all the subsets' data
135 | subset_names: selected subsets names that you wish to evaluate your model on
136 |         image_preprocess: the image preprocessing function returned by `get_model` (can be None)
137 |         batch_size: batch size for the dataloaders
138 |         num_workers: number of worker processes for the dataloaders
139 | Return:
140 |         A dict that maps each selected subset name to its `i2t_dataloader` and `t2i_dataloader`
141 | """
142 |
143 | data = {}
144 | for subset_nm in subset_names:
145 |
146 | subset_root_path = os.path.join(data_root, subset_nm)
147 | i2t_dataset = Image2TextDataset(subset_root=subset_root_path,
148 | image_preprocess=image_preprocess)
149 | t2i_dataset = Text2ImageDataset(subset_root=subset_root_path,
150 | image_preprocess=image_preprocess)
151 |
152 | i2t_loader = DataLoader(dataset=i2t_dataset,
153 | batch_size=batch_size,
154 | num_workers=num_workers,
155 | collate_fn=i2t_dataset.collate_fn,
156 | shuffle=False,
157 | drop_last=False)
158 |
159 | t2i_loader = DataLoader(dataset=t2i_dataset,
160 | batch_size=batch_size,
161 | num_workers=num_workers,
162 | collate_fn=t2i_dataset.collate_fn,
163 | shuffle=False,
164 | drop_last=False)
165 |
166 | data[subset_nm] = {
167 | 'i2t_dataloader': i2t_loader,
168 | 't2i_dataloader': t2i_loader
169 | }
170 |
171 | return data
172 |
--------------------------------------------------------------------------------
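To make the structure returned by `get_data` concrete, here is a minimal sketch that loads a single subset and inspects one image-to-text batch; the data root is a placeholder, and no image preprocessing is applied, so the images stay as PIL objects:

```python
from spec import get_data

# minimal sketch: build the dataloaders for one subset and peek at a batch
# ('/path/to/data' is a placeholder; image_preprocess=None keeps raw PIL images)
data = get_data(data_root='/path/to/data',
                subset_names=['absolute_size'],
                image_preprocess=None,
                batch_size=4,
                num_workers=0)

i2t_loader = data['absolute_size']['i2t_dataloader']
batch = next(iter(i2t_loader))

print(batch.keys())                 # dict_keys(['query_image', 'candidate_texts', 'label'])
print(len(batch['query_image']))    # 4 query images (stacked into a tensor only if a preprocess is given)
print(batch['candidate_texts'][0])  # candidate captions for the first query image
print(batch['label'])               # tensor of ground-truth indices, one per sample
```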
/spec/eval.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | import torch
4 | import random
5 | import numpy as np
6 |
7 | from models import get_model
8 | from dataset import get_data
9 |
10 |
11 | def seed_all(seed):
12 | random.seed(seed)
13 | np.random.seed(seed)
14 | torch.manual_seed(seed)
15 | torch.cuda.manual_seed_all(seed)
16 |
17 |
18 | @torch.no_grad()
19 | def main():
20 | os.makedirs(args.out_path, exist_ok=True)
21 |
22 | # set random seeds
23 | seed_all(args.seed)
24 |
25 | device = 'cuda' if torch.cuda.is_available() else 'cpu'
26 | # load model
27 | model, image_preprocess = get_model(model_name=args.model_name,
28 | cache_dir=args.model_cache_dir,
29 | device=device)
30 |
31 | # load data
32 | data = get_data(data_root=args.data_root,
33 | subset_names=args.subset_names,
34 | image_preprocess=image_preprocess,
35 | batch_size=args.batch_size,
36 | num_workers=args.num_workers)
37 |
38 | # evaluate on each subset
39 | print(f'\nBegin the evaluation of {args.model_name} on all selected subsets.')
40 | result = {}
41 | i2t_acc = 0.
42 | t2i_acc = 0.
43 | subset_num = 0
44 | for subset_name, dataloaders in data.items():
45 | subset_result = model.evaluate(subset_name=subset_name, dataloaders=dataloaders)
46 | result[subset_name] = subset_result
47 | i2t_acc += subset_result['accuracy']['i2t_accuracy']
48 | t2i_acc += subset_result['accuracy']['t2i_accuracy']
49 | subset_num += 1
50 | print(f'\nFinished the evaluation of {args.model_name} on all selected subsets.')
51 | print(f'average all subset: Image2Text Accuracy: {i2t_acc/subset_num:.2f} %')
52 | print(f'average all subset: Text2Image Accuracy: {t2i_acc/subset_num:.2f} %')
53 |
54 | # save results
55 | out_fn = f"{args.model_name}_result.pth"
56 | out_fn = os.path.join(args.out_path, out_fn)
57 | torch.save(result, out_fn)
58 | print(f'result saved to {out_fn}.')
59 |
60 |
61 | if __name__ == '__main__':
62 | parser = argparse.ArgumentParser('Vision Language Models Evaluation Pipeline')
63 | parser.add_argument('--model-name',
64 | type=str,
65 | default='clip')
66 | parser.add_argument('--pretrained',
67 | type=str,
68 | help="the pretrained model checkpoint")
69 | parser.add_argument('--model-cache-dir',
70 | type=str,
71 | default='~/.cache',
72 | help='the path to cache the downloaded model checkpoints')
73 | parser.add_argument('--subset-names',
74 | type=str,
75 | nargs='+',
76 | choices=['count', 'relative_size', 'absolute_size', 'relative_spatial', 'absolute_spatial',
77 | 'existence'],
78 |                         help='names of the SPEC subsets to evaluate on')
79 | parser.add_argument('--data-root',
80 | type=str,
81 |                         help='the path to the root dir of the data')
82 | parser.add_argument('--out-path',
83 | type=str,
84 | default='out',
85 |                         help="path to save the evaluation results")
86 | parser.add_argument('--batch-size',
87 | type=int,
88 | default=64)
89 | parser.add_argument('--num-workers',
90 | type=int,
91 | default=8)
92 | parser.add_argument('--seed',
93 | type=int,
94 | default=1,
95 | help="random seed for reproducibility")
96 |
97 | args = parser.parse_args()
98 |
99 | main()
100 |
--------------------------------------------------------------------------------
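For completeness, `spec/eval.py` is meant to be run as a script with the flags defined above. A minimal sketch of launching it from Python is shown below; all paths are placeholders, and it assumes the current working directory is the repository root:

```python
import subprocess
import sys

# minimal sketch: run the evaluation script with explicit flags
# (paths are placeholders; assumes the SPEC repo root as the working directory)
subprocess.run([
    sys.executable, 'spec/eval.py',
    '--model-name', 'clip',
    '--model-cache-dir', '/path/to/cache/models',
    '--data-root', '/path/to/data',
    '--subset-names', 'absolute_size', 'relative_size', 'absolute_spatial',
    'relative_spatial', 'existence', 'count',
    '--out-path', 'out',
    '--batch-size', '64',
    '--num-workers', '8',
], check=True)
```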
/spec/models/__init__.py:
--------------------------------------------------------------------------------
1 | """code credit: https://github.com/mertyg/vision-language-models-are-bows/tree/main/model_zoo"""
2 | import os
3 |
4 | def get_model(model_name, cache_dir='~/.cache', device='cuda'):
5 | """
6 |     Helper function that returns a pretrained model and its image preprocessing function.
7 | Args:
8 | model_name: the model that you want to create
9 |         cache_dir: the path to cache the downloaded model checkpoints
10 |         device: the device to load the model on (e.g. 'cuda' or 'cpu')
11 | Returns:
12 | pretrained_model, image_preprocess
13 | """
14 | os.makedirs(cache_dir, exist_ok=True)
15 |
16 | if model_name == 'clip':
17 | from .clip_wrapper import CLIPWrapper
18 | clip_model = CLIPWrapper(device=device, variant='ViT-B-32', pretrained='openai')
19 | image_preprocess = clip_model.image_preprocess
20 | return clip_model, image_preprocess
21 |
22 | elif model_name == 'blip':
23 | from .blip_wrapper import BLIPModelWrapper
24 | blip_model = BLIPModelWrapper(cache_dir=cache_dir, device=device, variant="blip-coco-base")
25 | image_preprocess = blip_model.image_preprocess
26 | return blip_model, image_preprocess
27 |
28 | elif model_name == "flava":
29 | from .flava_wrapper import FlavaWrapper
30 | flava_model = FlavaWrapper(cache_dir=cache_dir, device=device)
31 | image_preprocess = None
32 | return flava_model, image_preprocess
33 |
34 | elif model_name == "coca":
35 | from .clip_wrapper import CLIPWrapper
36 | coca_model = CLIPWrapper(device=device, variant="coca_ViT-B-32", pretrained="laion2B-s13B-b90k")
37 | image_preprocess = coca_model.image_preprocess
38 | return coca_model, image_preprocess
39 |
40 | else:
41 | raise ValueError(f"Unknown model {model_name}")
42 |
--------------------------------------------------------------------------------
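A minimal sketch of using the factory in isolation (the cache directory is a placeholder); note that for FLAVA the returned `image_preprocess` is `None`, which `get_data` accepts by returning raw PIL images instead of tensors:

```python
import torch
from spec import get_model

device = 'cuda' if torch.cuda.is_available() else 'cpu'

# minimal sketch: the factory returns a wrapper plus its preprocessing function
model, image_preprocess = get_model(model_name='clip',
                                    cache_dir='/path/to/cache/models',
                                    device=device)
print(type(model).__name__)      # CLIPWrapper
print(image_preprocess is None)  # False for CLIP/BLIP/CoCa, True for FLAVA
```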
/spec/models/base_wrapper.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | from abc import ABCMeta, abstractmethod
4 |
5 |
6 | class BaseWrapper(metaclass=ABCMeta):
7 | """
8 |     This is the base model wrapper. To evaluate a custom model on SPEC, subclass it and implement `i2t_evaluate` and `t2i_evaluate`.
9 | """
10 |
11 | @abstractmethod
12 | @torch.no_grad()
13 | def i2t_evaluate(self, subset_name, dataloader):
14 | """
15 | Performing `Image-to-Text` retrieval evaluation.
16 | Args:
17 | subset_name: the name of subset that you are running on
18 | dataloader: the dataloader that contains all the image to text samples
19 |
20 | Returns:
21 |             i2t_scores: tensor of shape NxL with the image-text matching scores
22 |             i2t_acc: the Image2Text retrieval accuracy (in %)
23 | """
24 | pass
25 |
26 | @abstractmethod
27 | @torch.no_grad()
28 | def t2i_evaluate(self, subset_name, dataloader):
29 | """
30 | Performing `Text-to-Image` retrieval evaluation.
31 | Args:
32 | subset_name: the name of subset that you are running on
33 |             dataloader: the dataloader that contains all the text to image samples
34 |
35 | Returns:
36 |             t2i_scores: tensor of shape NxK with the image-text matching scores
37 |             t2i_acc: the Text2Image retrieval accuracy (in %)
38 | """
39 | pass
40 |
41 | @torch.no_grad()
42 | def evaluate(self, subset_name, dataloaders):
43 | """Computes the image-text matching scores and the image-to-text and text-to-image accuracy on a given subset
44 | Args:
45 | subset_name: the name of the subset
46 | dataloaders (Dict): include an "i2t_dataloader" and a "t2i_dataloader"
47 | Returns:
48 | scores(Dict of Tensor): `i2t_scores`, `t2i_scores`
49 |             accuracy(Dict of Scalar): `i2t_accuracy`, `t2i_accuracy`
50 | """
51 | # image to text retrieval
52 | i2t_scores, i2t_acc = self.i2t_evaluate(subset_name, dataloaders['i2t_dataloader'])
53 | print(f'{subset_name} subset: Image2Text Accuracy: {i2t_acc:.2f} %')
54 |
55 | # text to image retrieval
56 | t2i_scores, t2i_acc = self.t2i_evaluate(subset_name, dataloaders['t2i_dataloader'])
57 | print(f'{subset_name} subset: Text2Image Accuracy: {t2i_acc:.2f} %')
58 |
59 | """
60 | `i2t_scores`: tensor of shape NxL, N is the number of testing samples, L is the number of candidate texts per sample
61 | `t2i_scores`: tensor of shape NxK, N is the number of testing samples, K is the number of candidate images per sample
62 | """
63 | scores = {
64 | 'i2t_scores': i2t_scores,
65 | 't2i_scores': t2i_scores
66 | }
67 |
68 | accuracy = {
69 | 'i2t_accuracy': i2t_acc,
70 | 't2i_accuracy': t2i_acc,
71 | }
72 |
73 | return {
74 | "accuracy": accuracy,
75 | "scores": scores
76 | }
77 |
--------------------------------------------------------------------------------
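To make the `BaseWrapper` contract concrete, here is a hedged sketch of a custom wrapper; `DummyWrapper` is hypothetical and its random scores are a stand-in for whatever image-text similarity your model computes, but the batch keys and return values follow the SPEC dataloaders and the `evaluate` method above:

```python
import torch
from tqdm import tqdm

from spec.models.base_wrapper import BaseWrapper


class DummyWrapper(BaseWrapper):
    """Hypothetical wrapper: replace the random scores with your model's image-text similarities."""

    @torch.no_grad()
    def i2t_evaluate(self, subset_name, dataloader):
        scores, correct, total = [], 0, 0
        for batch in tqdm(dataloader, desc=f'Image to Text retrieval on {subset_name}'):
            texts = batch['candidate_texts']   # list (batch) of candidate caption lists
            labels = batch['label']            # ground-truth caption index per query image
            # placeholder scores of shape (batch, num_candidate_texts)
            batch_scores = torch.rand(len(texts), len(texts[0]))
            scores.append(batch_scores)
            correct += (batch_scores.argmax(dim=-1) == labels).sum().item()
            total += labels.numel()
        return torch.cat(scores, dim=0), 100.0 * correct / total

    @torch.no_grad()
    def t2i_evaluate(self, subset_name, dataloader):
        scores, correct, total = [], 0, 0
        for batch in tqdm(dataloader, desc=f'Text to Image retrieval on {subset_name}'):
            images = batch['candidate_images']  # candidate images per query caption
            labels = batch['label']             # ground-truth image index per query caption
            # placeholder scores of shape (batch, num_candidate_images)
            batch_scores = torch.rand(len(labels), len(images[0]))
            scores.append(batch_scores)
            correct += (batch_scores.argmax(dim=-1) == labels).sum().item()
            total += labels.numel()
        return torch.cat(scores, dim=0), 100.0 * correct / total
```

An instance of such a wrapper can then go through the same `get_data`/`evaluate` loop used in the notebooks above.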
/spec/models/blip_utils/README.md:
--------------------------------------------------------------------------------
1 | Most of the utilities here are obtained from https://github.com/salesforce/BLIP/ .
--------------------------------------------------------------------------------
/spec/models/blip_utils/blip.py:
--------------------------------------------------------------------------------
1 | '''
2 | * Copyright (c) 2022, salesforce.com, inc.
3 | * All rights reserved.
4 | * SPDX-License-Identifier: BSD-3-Clause
5 | * For full license text, see LICENSE.txt file in the repo root or https://opensource.org/licenses/BSD-3-Clause
6 | * By Junnan Li
7 | '''
8 | import warnings
9 | warnings.filterwarnings("ignore")
10 |
11 | from .vit import VisionTransformer, interpolate_pos_embed
12 | from .med import BertConfig, BertModel, BertLMHeadModel
13 | from transformers import BertTokenizer
14 |
15 | import torch
16 | from torch import nn
17 | import torch.nn.functional as F
18 |
19 | import os
20 | from urllib.parse import urlparse
21 | from timm.models.hub import download_cached_file
22 |
23 | class BLIP_Base(nn.Module):
24 | def __init__(self,
25 | med_config = 'configs/med_config.json',
26 | image_size = 224,
27 | vit = 'base',
28 | vit_grad_ckpt = False,
29 | vit_ckpt_layer = 0,
30 | ):
31 | """
32 | Args:
33 | med_config (str): path for the mixture of encoder-decoder model's configuration file
34 | image_size (int): input image size
35 | vit (str): model size of vision transformer
36 | """
37 | super().__init__()
38 |
39 | self.visual_encoder, vision_width = create_vit(vit,image_size, vit_grad_ckpt, vit_ckpt_layer)
40 | self.tokenizer = init_tokenizer()
41 | med_config = BertConfig.from_json_file(med_config)
42 | med_config.encoder_width = vision_width
43 | self.text_encoder = BertModel(config=med_config, add_pooling_layer=False)
44 |
45 |
46 | def forward(self, image, caption, mode):
47 |
48 | assert mode in ['image', 'text', 'multimodal'], "mode parameter must be image, text, or multimodal"
49 | text = self.tokenizer(caption, return_tensors="pt").to(image.device)
50 |
51 | if mode=='image':
52 | # return image features
53 | image_embeds = self.visual_encoder(image)
54 | return image_embeds
55 |
56 | elif mode=='text':
57 | # return text features
58 | text_output = self.text_encoder(text.input_ids, attention_mask = text.attention_mask,
59 | return_dict = True, mode = 'text')
60 | return text_output.last_hidden_state
61 |
62 | elif mode=='multimodal':
63 |             # return multimodal features
64 | image_embeds = self.visual_encoder(image)
65 | image_atts = torch.ones(image_embeds.size()[:-1],dtype=torch.long).to(image.device)
66 |
67 | text.input_ids[:,0] = self.tokenizer.enc_token_id
68 | output = self.text_encoder(text.input_ids,
69 | attention_mask = text.attention_mask,
70 | encoder_hidden_states = image_embeds,
71 | encoder_attention_mask = image_atts,
72 | return_dict = True,
73 | )
74 | return output.last_hidden_state
75 |
76 |
77 |
78 | class BLIP_Decoder(nn.Module):
79 | def __init__(self,
80 | med_config = 'configs/med_config.json',
81 | image_size = 384,
82 | vit = 'base',
83 | vit_grad_ckpt = False,
84 | vit_ckpt_layer = 0,
85 | prompt = 'a picture of ',
86 | ):
87 | """
88 | Args:
89 | med_config (str): path for the mixture of encoder-decoder model's configuration file
90 | image_size (int): input image size
91 | vit (str): model size of vision transformer
92 | """
93 | super().__init__()
94 |
95 | self.visual_encoder, vision_width = create_vit(vit,image_size, vit_grad_ckpt, vit_ckpt_layer)
96 | self.tokenizer = init_tokenizer()
97 | med_config = BertConfig.from_json_file(med_config)
98 | med_config.encoder_width = vision_width
99 | self.text_decoder = BertLMHeadModel(config=med_config)
100 |
101 | self.prompt = prompt
102 | self.prompt_length = len(self.tokenizer(self.prompt).input_ids)-1
103 |
104 |
105 | def forward(self, image, caption):
106 |
107 | image_embeds = self.visual_encoder(image)
108 | image_atts = torch.ones(image_embeds.size()[:-1],dtype=torch.long).to(image.device)
109 |
110 | text = self.tokenizer(caption, padding='longest', truncation=True, max_length=40, return_tensors="pt").to(image.device)
111 |
112 | text.input_ids[:,0] = self.tokenizer.bos_token_id
113 |
114 | decoder_targets = text.input_ids.masked_fill(text.input_ids == self.tokenizer.pad_token_id, -100)
115 | decoder_targets[:,:self.prompt_length] = -100
116 |
117 | decoder_output = self.text_decoder(text.input_ids,
118 | attention_mask = text.attention_mask,
119 | encoder_hidden_states = image_embeds,
120 | encoder_attention_mask = image_atts,
121 | labels = decoder_targets,
122 | return_dict = True,
123 | )
124 | loss_lm = decoder_output.loss
125 |
126 | return loss_lm
127 |
128 | def generate(self, image, sample=False, num_beams=3, max_length=30, min_length=10, top_p=0.9, repetition_penalty=1.0):
129 | image_embeds = self.visual_encoder(image)
130 |
131 | if not sample:
132 | image_embeds = image_embeds.repeat_interleave(num_beams,dim=0)
133 |
134 | image_atts = torch.ones(image_embeds.size()[:-1],dtype=torch.long).to(image.device)
135 | model_kwargs = {"encoder_hidden_states": image_embeds, "encoder_attention_mask":image_atts}
136 |
137 | prompt = [self.prompt] * image.size(0)
138 | input_ids = self.tokenizer(prompt, return_tensors="pt").input_ids.to(image.device)
139 | input_ids[:,0] = self.tokenizer.bos_token_id
140 | input_ids = input_ids[:, :-1]
141 |
142 | if sample:
143 | #nucleus sampling
144 | outputs = self.text_decoder.generate(input_ids=input_ids,
145 | max_length=max_length,
146 | min_length=min_length,
147 | do_sample=True,
148 | top_p=top_p,
149 | num_return_sequences=1,
150 | eos_token_id=self.tokenizer.sep_token_id,
151 | pad_token_id=self.tokenizer.pad_token_id,
152 | repetition_penalty=1.1,
153 | **model_kwargs)
154 | else:
155 | #beam search
156 | outputs = self.text_decoder.generate(input_ids=input_ids,
157 | max_length=max_length,
158 | min_length=min_length,
159 | num_beams=num_beams,
160 | eos_token_id=self.tokenizer.sep_token_id,
161 | pad_token_id=self.tokenizer.pad_token_id,
162 | repetition_penalty=repetition_penalty,
163 | **model_kwargs)
164 |
165 | captions = []
166 | for output in outputs:
167 | caption = self.tokenizer.decode(output, skip_special_tokens=True)
168 | captions.append(caption[len(self.prompt):])
169 | return captions
170 |
171 |
172 | def blip_decoder(pretrained='',**kwargs):
173 | model = BLIP_Decoder(**kwargs)
174 | if pretrained:
175 | model,msg = load_checkpoint(model,pretrained)
176 | assert(len(msg.missing_keys)==0)
177 | return model
178 |
179 | def blip_feature_extractor(pretrained='',**kwargs):
180 | model = BLIP_Base(**kwargs)
181 | if pretrained:
182 | model,msg = load_checkpoint(model,pretrained)
183 | assert(len(msg.missing_keys)==0)
184 | return model
185 |
186 | def init_tokenizer():
187 | tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
188 | tokenizer.add_special_tokens({'bos_token':'[DEC]'})
189 | tokenizer.add_special_tokens({'additional_special_tokens':['[ENC]']})
190 | tokenizer.enc_token_id = tokenizer.additional_special_tokens_ids[0]
191 | return tokenizer
192 |
193 |
194 | def create_vit(vit, image_size, use_grad_checkpointing=False, ckpt_layer=0, drop_path_rate=0):
195 |
196 | assert vit in ['base', 'large'], "vit parameter must be base or large"
197 | if vit=='base':
198 | vision_width = 768
199 | visual_encoder = VisionTransformer(img_size=image_size, patch_size=16, embed_dim=vision_width, depth=12,
200 | num_heads=12, use_grad_checkpointing=use_grad_checkpointing, ckpt_layer=ckpt_layer,
201 | drop_path_rate=0 or drop_path_rate
202 | )
203 | elif vit=='large':
204 | vision_width = 1024
205 | visual_encoder = VisionTransformer(img_size=image_size, patch_size=16, embed_dim=vision_width, depth=24,
206 | num_heads=16, use_grad_checkpointing=use_grad_checkpointing, ckpt_layer=ckpt_layer,
207 | drop_path_rate=0.1 or drop_path_rate
208 | )
209 | return visual_encoder, vision_width
210 |
211 | def is_url(url_or_filename):
212 | parsed = urlparse(url_or_filename)
213 | return parsed.scheme in ("http", "https")
214 |
215 | def load_checkpoint(model,url_or_filename):
216 | if is_url(url_or_filename):
217 | cached_file = download_cached_file(url_or_filename, check_hash=False, progress=True)
218 | checkpoint = torch.load(cached_file, map_location='cpu')
219 | elif os.path.isfile(url_or_filename):
220 | checkpoint = torch.load(url_or_filename, map_location='cpu')
221 | else:
222 | raise RuntimeError('checkpoint url or path is invalid')
223 |
224 | state_dict = checkpoint['model']
225 |
226 | state_dict['visual_encoder.pos_embed'] = interpolate_pos_embed(state_dict['visual_encoder.pos_embed'],model.visual_encoder)
227 | if 'visual_encoder_m.pos_embed' in model.state_dict().keys():
228 | state_dict['visual_encoder_m.pos_embed'] = interpolate_pos_embed(state_dict['visual_encoder_m.pos_embed'],
229 | model.visual_encoder_m)
230 | for key in model.state_dict().keys():
231 | if key in state_dict.keys():
232 | if state_dict[key].shape!=model.state_dict()[key].shape:
233 | del state_dict[key]
234 |
235 | msg = model.load_state_dict(state_dict,strict=False)
236 | print('load checkpoint from %s'%url_or_filename)
237 | return model,msg
238 |
239 |
--------------------------------------------------------------------------------
/spec/models/blip_utils/blip_itm.py:
--------------------------------------------------------------------------------
1 | from .med import BertConfig, BertModel
2 | from transformers import BertTokenizer
3 |
4 | import torch
5 | from torch import nn
6 | import torch.nn.functional as F
7 |
8 | from .blip import create_vit, init_tokenizer, load_checkpoint
9 |
10 | class BLIP_ITM(nn.Module):
11 | def __init__(self,
12 | med_config = 'configs/med_config.json',
13 | image_size = 384,
14 | vit = 'base',
15 | vit_grad_ckpt = False,
16 | vit_ckpt_layer = 0,
17 | embed_dim = 256,
18 | ):
19 | """
20 | Args:
21 | med_config (str): path for the mixture of encoder-decoder model's configuration file
22 | image_size (int): input image size
23 | vit (str): model size of vision transformer
24 | """
25 | super().__init__()
26 |
27 | self.visual_encoder, vision_width = create_vit(vit,image_size, vit_grad_ckpt, vit_ckpt_layer)
28 | self.tokenizer = init_tokenizer()
29 | med_config = BertConfig.from_json_file(med_config)
30 | med_config.encoder_width = vision_width
31 | self.text_encoder = BertModel(config=med_config, add_pooling_layer=False)
32 |
33 | text_width = self.text_encoder.config.hidden_size
34 |
35 | self.vision_proj = nn.Linear(vision_width, embed_dim)
36 | self.text_proj = nn.Linear(text_width, embed_dim)
37 |
38 | self.itm_head = nn.Linear(text_width, 2)
39 |
40 |
41 | def forward(self, image, caption, match_head='itm'):
42 |
43 | image_embeds = self.visual_encoder(image)
44 | image_atts = torch.ones(image_embeds.size()[:-1],dtype=torch.long).to(image.device)
45 |
46 | text = self.tokenizer(caption, padding='max_length', truncation=True, max_length=35,
47 | return_tensors="pt").to(image.device)
48 |
49 |
50 | if match_head=='itm':
51 | output = self.text_encoder(text.input_ids,
52 | attention_mask = text.attention_mask,
53 | encoder_hidden_states = image_embeds,
54 | encoder_attention_mask = image_atts,
55 | return_dict = True,
56 | )
57 | itm_output = self.itm_head(output.last_hidden_state[:,0,:])
58 | return itm_output
59 |
60 | elif match_head=='itc':
61 | text_output = self.text_encoder(text.input_ids, attention_mask = text.attention_mask,
62 | return_dict = True, mode = 'text')
63 | image_feat = F.normalize(self.vision_proj(image_embeds[:,0,:]),dim=-1)
64 | text_feat = F.normalize(self.text_proj(text_output.last_hidden_state[:,0,:]),dim=-1)
65 |
66 | sim = image_feat @ text_feat.t()
67 | return sim
68 |
69 |
70 | def blip_itm(pretrained='',**kwargs):
71 | model = BLIP_ITM(**kwargs)
72 | if pretrained:
73 | model,msg = load_checkpoint(model,pretrained)
74 | assert(len(msg.missing_keys)==0)
75 | return model
76 |
--------------------------------------------------------------------------------
/spec/models/blip_utils/blip_pretrain.py:
--------------------------------------------------------------------------------
1 | '''
2 | * Copyright (c) 2022, salesforce.com, inc.
3 | * All rights reserved.
4 | * SPDX-License-Identifier: BSD-3-Clause
5 | * For full license text, see LICENSE.txt file in the repo root or https://opensource.org/licenses/BSD-3-Clause
6 | * By Junnan Li
7 | '''
8 | from .med import BertConfig, BertModel, BertLMHeadModel
9 | from transformers import BertTokenizer
10 | import transformers
11 | transformers.logging.set_verbosity_error()
12 |
13 | import torch
14 | from torch import nn
15 | import torch.nn.functional as F
16 |
17 | from .blip import create_vit, init_tokenizer, load_checkpoint
18 |
19 | class BLIP_Pretrain(nn.Module):
20 | def __init__(self,
21 | med_config = 'configs/bert_config.json',
22 | image_size = 224,
23 | vit = 'base',
24 | vit_grad_ckpt = False,
25 | vit_ckpt_layer = 0,
26 | embed_dim = 256,
27 | queue_size = 57600,
28 | momentum = 0.995,
29 | ):
30 | """
31 | Args:
32 | med_config (str): path for the mixture of encoder-decoder model's configuration file
33 | image_size (int): input image size
34 | vit (str): model size of vision transformer
35 | """
36 | super().__init__()
37 |
38 | self.visual_encoder, vision_width = create_vit(vit,image_size, vit_grad_ckpt, vit_ckpt_layer, 0)
39 |
40 | if vit=='base':
41 | checkpoint = torch.hub.load_state_dict_from_url(
42 | url="https://dl.fbaipublicfiles.com/deit/deit_base_patch16_224-b5f2ef4d.pth",
43 | map_location="cpu", check_hash=True)
44 | state_dict = checkpoint["model"]
45 | msg = self.visual_encoder.load_state_dict(state_dict,strict=False)
46 | elif vit=='large':
47 | from timm.models.helpers import load_custom_pretrained
48 | from timm.models.vision_transformer import default_cfgs
49 | load_custom_pretrained(self.visual_encoder,default_cfgs['vit_large_patch16_224_in21k'])
50 |
51 | self.tokenizer = init_tokenizer()
52 | encoder_config = BertConfig.from_json_file(med_config)
53 | encoder_config.encoder_width = vision_width
54 | self.text_encoder = BertModel.from_pretrained('bert-base-uncased',config=encoder_config, add_pooling_layer=False)
55 | self.text_encoder.resize_token_embeddings(len(self.tokenizer))
56 |
57 | text_width = self.text_encoder.config.hidden_size
58 |
59 | self.vision_proj = nn.Linear(vision_width, embed_dim)
60 | self.text_proj = nn.Linear(text_width, embed_dim)
61 |
62 | self.itm_head = nn.Linear(text_width, 2)
63 |
64 | # create momentum encoders
65 | self.visual_encoder_m, vision_width = create_vit(vit,image_size)
66 | self.vision_proj_m = nn.Linear(vision_width, embed_dim)
67 | self.text_encoder_m = BertModel(config=encoder_config, add_pooling_layer=False)
68 | self.text_proj_m = nn.Linear(text_width, embed_dim)
69 |
70 | self.model_pairs = [[self.visual_encoder,self.visual_encoder_m],
71 | [self.vision_proj,self.vision_proj_m],
72 | [self.text_encoder,self.text_encoder_m],
73 | [self.text_proj,self.text_proj_m],
74 | ]
75 | self.copy_params()
76 |
77 | # create the queue
78 | self.register_buffer("image_queue", torch.randn(embed_dim, queue_size))
79 | self.register_buffer("text_queue", torch.randn(embed_dim, queue_size))
80 | self.register_buffer("queue_ptr", torch.zeros(1, dtype=torch.long))
81 |
82 | self.image_queue = nn.functional.normalize(self.image_queue, dim=0)
83 | self.text_queue = nn.functional.normalize(self.text_queue, dim=0)
84 |
85 | self.queue_size = queue_size
86 | self.momentum = momentum
87 | self.temp = nn.Parameter(0.07*torch.ones([]))
88 |
89 | # create the decoder
90 | decoder_config = BertConfig.from_json_file(med_config)
91 | decoder_config.encoder_width = vision_width
92 | self.text_decoder = BertLMHeadModel.from_pretrained('bert-base-uncased',config=decoder_config)
93 | self.text_decoder.resize_token_embeddings(len(self.tokenizer))
94 | tie_encoder_decoder_weights(self.text_encoder,self.text_decoder.bert,'','/attention')
95 |
96 |
97 | def forward(self, image, caption, alpha):
98 | with torch.no_grad():
99 | self.temp.clamp_(0.001,0.5)
100 |
101 | image_embeds = self.visual_encoder(image)
102 | image_atts = torch.ones(image_embeds.size()[:-1],dtype=torch.long).to(image.device)
103 | image_feat = F.normalize(self.vision_proj(image_embeds[:,0,:]),dim=-1)
104 |
105 | text = self.tokenizer(caption, padding='max_length', truncation=True, max_length=30,
106 | return_tensors="pt").to(image.device)
107 | text_output = self.text_encoder(text.input_ids, attention_mask = text.attention_mask,
108 | return_dict = True, mode = 'text')
109 | text_feat = F.normalize(self.text_proj(text_output.last_hidden_state[:,0,:]),dim=-1)
110 |
111 | # get momentum features
112 | with torch.no_grad():
113 | self._momentum_update()
114 | image_embeds_m = self.visual_encoder_m(image)
115 | image_feat_m = F.normalize(self.vision_proj_m(image_embeds_m[:,0,:]),dim=-1)
116 | image_feat_all = torch.cat([image_feat_m.t(),self.image_queue.clone().detach()],dim=1)
117 |
118 | text_output_m = self.text_encoder_m(text.input_ids, attention_mask = text.attention_mask,
119 | return_dict = True, mode = 'text')
120 | text_feat_m = F.normalize(self.text_proj_m(text_output_m.last_hidden_state[:,0,:]),dim=-1)
121 | text_feat_all = torch.cat([text_feat_m.t(),self.text_queue.clone().detach()],dim=1)
122 |
123 | sim_i2t_m = image_feat_m @ text_feat_all / self.temp
124 | sim_t2i_m = text_feat_m @ image_feat_all / self.temp
125 |
126 | sim_targets = torch.zeros(sim_i2t_m.size()).to(image.device)
127 | sim_targets.fill_diagonal_(1)
128 |
129 | sim_i2t_targets = alpha * F.softmax(sim_i2t_m, dim=1) + (1 - alpha) * sim_targets
130 | sim_t2i_targets = alpha * F.softmax(sim_t2i_m, dim=1) + (1 - alpha) * sim_targets
131 |
132 | sim_i2t = image_feat @ text_feat_all / self.temp
133 | sim_t2i = text_feat @ image_feat_all / self.temp
134 |
135 | loss_i2t = -torch.sum(F.log_softmax(sim_i2t, dim=1)*sim_i2t_targets,dim=1).mean()
136 | loss_t2i = -torch.sum(F.log_softmax(sim_t2i, dim=1)*sim_t2i_targets,dim=1).mean()
137 |
138 | loss_ita = (loss_i2t+loss_t2i)/2
139 |
140 | self._dequeue_and_enqueue(image_feat_m, text_feat_m)
141 |
142 | ###============== Image-text Matching ===================###
143 | encoder_input_ids = text.input_ids.clone()
144 | encoder_input_ids[:,0] = self.tokenizer.enc_token_id
145 |
146 |     # forward the positive image-text pair
147 | bs = image.size(0)
148 | output_pos = self.text_encoder(encoder_input_ids,
149 | attention_mask = text.attention_mask,
150 | encoder_hidden_states = image_embeds,
151 | encoder_attention_mask = image_atts,
152 | return_dict = True,
153 | )
154 | with torch.no_grad():
155 | weights_t2i = F.softmax(sim_t2i[:,:bs],dim=1)+1e-4
156 | weights_t2i.fill_diagonal_(0)
157 | weights_i2t = F.softmax(sim_i2t[:,:bs],dim=1)+1e-4
158 | weights_i2t.fill_diagonal_(0)
159 |
160 | # select a negative image for each text
161 | image_embeds_neg = []
162 | for b in range(bs):
163 | neg_idx = torch.multinomial(weights_t2i[b], 1).item()
164 | image_embeds_neg.append(image_embeds[neg_idx])
165 | image_embeds_neg = torch.stack(image_embeds_neg,dim=0)
166 |
167 | # select a negative text for each image
168 | text_ids_neg = []
169 | text_atts_neg = []
170 | for b in range(bs):
171 | neg_idx = torch.multinomial(weights_i2t[b], 1).item()
172 | text_ids_neg.append(encoder_input_ids[neg_idx])
173 | text_atts_neg.append(text.attention_mask[neg_idx])
174 |
175 | text_ids_neg = torch.stack(text_ids_neg,dim=0)
176 | text_atts_neg = torch.stack(text_atts_neg,dim=0)
177 |
178 | text_ids_all = torch.cat([encoder_input_ids, text_ids_neg],dim=0)
179 | text_atts_all = torch.cat([text.attention_mask, text_atts_neg],dim=0)
180 |
181 | image_embeds_all = torch.cat([image_embeds_neg,image_embeds],dim=0)
182 | image_atts_all = torch.cat([image_atts,image_atts],dim=0)
183 |
184 | output_neg = self.text_encoder(text_ids_all,
185 | attention_mask = text_atts_all,
186 | encoder_hidden_states = image_embeds_all,
187 | encoder_attention_mask = image_atts_all,
188 | return_dict = True,
189 | )
190 |
191 | vl_embeddings = torch.cat([output_pos.last_hidden_state[:,0,:], output_neg.last_hidden_state[:,0,:]],dim=0)
192 | vl_output = self.itm_head(vl_embeddings)
193 |
194 | itm_labels = torch.cat([torch.ones(bs,dtype=torch.long),torch.zeros(2*bs,dtype=torch.long)],
195 | dim=0).to(image.device)
196 | loss_itm = F.cross_entropy(vl_output, itm_labels)
197 |
198 | ##================= LM ========================##
199 | decoder_input_ids = text.input_ids.clone()
200 | decoder_input_ids[:,0] = self.tokenizer.bos_token_id
201 | decoder_targets = decoder_input_ids.masked_fill(decoder_input_ids == self.tokenizer.pad_token_id, -100)
202 |
203 | decoder_output = self.text_decoder(decoder_input_ids,
204 | attention_mask = text.attention_mask,
205 | encoder_hidden_states = image_embeds,
206 | encoder_attention_mask = image_atts,
207 | labels = decoder_targets,
208 | return_dict = True,
209 | )
210 |
211 | loss_lm = decoder_output.loss
212 | return loss_ita, loss_itm, loss_lm
213 |
214 |
215 |
216 | @torch.no_grad()
217 | def copy_params(self):
218 | for model_pair in self.model_pairs:
219 | for param, param_m in zip(model_pair[0].parameters(), model_pair[1].parameters()):
220 | param_m.data.copy_(param.data) # initialize
221 | param_m.requires_grad = False # not update by gradient
222 |
223 |
224 | @torch.no_grad()
225 | def _momentum_update(self):
226 | for model_pair in self.model_pairs:
227 | for param, param_m in zip(model_pair[0].parameters(), model_pair[1].parameters()):
228 | param_m.data = param_m.data * self.momentum + param.data * (1. - self.momentum)
229 |
230 |
231 | @torch.no_grad()
232 | def _dequeue_and_enqueue(self, image_feat, text_feat):
233 | # gather keys before updating queue
234 | image_feats = concat_all_gather(image_feat)
235 | text_feats = concat_all_gather(text_feat)
236 |
237 | batch_size = image_feats.shape[0]
238 |
239 | ptr = int(self.queue_ptr)
240 | assert self.queue_size % batch_size == 0 # for simplicity
241 |
242 | # replace the keys at ptr (dequeue and enqueue)
243 | self.image_queue[:, ptr:ptr + batch_size] = image_feats.T
244 | self.text_queue[:, ptr:ptr + batch_size] = text_feats.T
245 | ptr = (ptr + batch_size) % self.queue_size # move pointer
246 |
247 | self.queue_ptr[0] = ptr
248 |
249 |
250 | def blip_pretrain(**kwargs):
251 | model = BLIP_Pretrain(**kwargs)
252 | return model
253 |
254 |
255 | @torch.no_grad()
256 | def concat_all_gather(tensor):
257 | """
258 | Performs all_gather operation on the provided tensors.
259 | *** Warning ***: torch.distributed.all_gather has no gradient.
260 | """
261 | tensors_gather = [torch.ones_like(tensor)
262 | for _ in range(torch.distributed.get_world_size())]
263 | torch.distributed.all_gather(tensors_gather, tensor, async_op=False)
264 |
265 | output = torch.cat(tensors_gather, dim=0)
266 | return output
267 |
268 |
269 | from typing import List
270 | def tie_encoder_decoder_weights(encoder: nn.Module, decoder: nn.Module, base_model_prefix: str, skip_key:str):
271 | uninitialized_encoder_weights: List[str] = []
272 | if decoder.__class__ != encoder.__class__:
273 |         print(  # `logger` is not defined in this module, so fall back to print
274 | f"{decoder.__class__} and {encoder.__class__} are not equal. In this case make sure that all encoder weights are correctly initialized."
275 | )
276 |
277 | def tie_encoder_to_decoder_recursively(
278 | decoder_pointer: nn.Module,
279 | encoder_pointer: nn.Module,
280 | module_name: str,
281 | uninitialized_encoder_weights: List[str],
282 | skip_key: str,
283 | depth=0,
284 | ):
285 | assert isinstance(decoder_pointer, nn.Module) and isinstance(
286 | encoder_pointer, nn.Module
287 | ), f"{decoder_pointer} and {encoder_pointer} have to be of type torch.nn.Module"
288 | if hasattr(decoder_pointer, "weight") and skip_key not in module_name:
289 | assert hasattr(encoder_pointer, "weight")
290 | encoder_pointer.weight = decoder_pointer.weight
291 | if hasattr(decoder_pointer, "bias"):
292 | assert hasattr(encoder_pointer, "bias")
293 | encoder_pointer.bias = decoder_pointer.bias
294 | print(module_name+' is tied')
295 | return
296 |
297 | encoder_modules = encoder_pointer._modules
298 | decoder_modules = decoder_pointer._modules
299 | if len(decoder_modules) > 0:
300 | assert (
301 | len(encoder_modules) > 0
302 | ), f"Encoder module {encoder_pointer} does not match decoder module {decoder_pointer}"
303 |
304 | all_encoder_weights = set([module_name + "/" + sub_name for sub_name in encoder_modules.keys()])
305 | encoder_layer_pos = 0
306 | for name, module in decoder_modules.items():
307 | if name.isdigit():
308 | encoder_name = str(int(name) + encoder_layer_pos)
309 | decoder_name = name
310 | if not isinstance(decoder_modules[decoder_name], type(encoder_modules[encoder_name])) and len(
311 | encoder_modules
312 | ) != len(decoder_modules):
313 |                     # this can happen if the name corresponds to the position in a module list of layers
314 | # in this case the decoder has added a cross-attention that the encoder does not have
315 | # thus skip this step and subtract one layer pos from encoder
316 | encoder_layer_pos -= 1
317 | continue
318 | elif name not in encoder_modules:
319 | continue
320 | elif depth > 500:
321 | raise ValueError(
322 | "Max depth of recursive function `tie_encoder_to_decoder` reached. It seems that there is a circular dependency between two or more `nn.Modules` of your model."
323 | )
324 | else:
325 | decoder_name = encoder_name = name
326 | tie_encoder_to_decoder_recursively(
327 | decoder_modules[decoder_name],
328 | encoder_modules[encoder_name],
329 | module_name + "/" + name,
330 | uninitialized_encoder_weights,
331 | skip_key,
332 | depth=depth + 1,
333 | )
334 | all_encoder_weights.remove(module_name + "/" + encoder_name)
335 |
336 | uninitialized_encoder_weights += list(all_encoder_weights)
337 |
338 | # tie weights recursively
339 | tie_encoder_to_decoder_recursively(decoder, encoder, base_model_prefix, uninitialized_encoder_weights, skip_key)
--------------------------------------------------------------------------------
/spec/models/blip_utils/blip_retrieval.py:
--------------------------------------------------------------------------------
1 | '''
2 | * Copyright (c) 2022, salesforce.com, inc.
3 | * All rights reserved.
4 | * SPDX-License-Identifier: BSD-3-Clause
5 | * For full license text, see LICENSE.txt file in the repo root or https://opensource.org/licenses/BSD-3-Clause
6 | * By Junnan Li
7 | '''
8 | from .med import BertConfig, BertModel
9 | from transformers import BertTokenizer
10 |
11 | import torch
12 | from torch import nn
13 | import torch.nn.functional as F
14 |
15 | from .blip import create_vit, init_tokenizer, load_checkpoint
16 |
17 | class BLIP_Retrieval(nn.Module):
18 | def __init__(self,
19 | med_config = 'configs/med_config.json',
20 | image_size = 384,
21 | vit = 'base',
22 | vit_grad_ckpt = False,
23 | vit_ckpt_layer = 0,
24 | embed_dim = 256,
25 | queue_size = 57600,
26 | momentum = 0.995,
27 | negative_all_rank = False,
28 | ):
29 | """
30 | Args:
31 | med_config (str): path for the mixture of encoder-decoder model's configuration file
32 | image_size (int): input image size
33 | vit (str): model size of vision transformer
34 | """
35 | super().__init__()
36 |
37 | self.visual_encoder, vision_width = create_vit(vit,image_size, vit_grad_ckpt, vit_ckpt_layer)
38 | self.tokenizer = init_tokenizer()
39 | med_config = BertConfig.from_json_file(med_config)
40 | med_config.encoder_width = vision_width
41 | self.text_encoder = BertModel(config=med_config, add_pooling_layer=False)
42 |
43 | text_width = self.text_encoder.config.hidden_size
44 |
45 | self.vision_proj = nn.Linear(vision_width, embed_dim)
46 | self.text_proj = nn.Linear(text_width, embed_dim)
47 |
48 | self.itm_head = nn.Linear(text_width, 2)
49 |
50 | # create momentum encoders
51 | self.visual_encoder_m, vision_width = create_vit(vit,image_size)
52 | self.vision_proj_m = nn.Linear(vision_width, embed_dim)
53 | self.text_encoder_m = BertModel(config=med_config, add_pooling_layer=False)
54 | self.text_proj_m = nn.Linear(text_width, embed_dim)
55 |
56 | self.model_pairs = [[self.visual_encoder,self.visual_encoder_m],
57 | [self.vision_proj,self.vision_proj_m],
58 | [self.text_encoder,self.text_encoder_m],
59 | [self.text_proj,self.text_proj_m],
60 | ]
61 | self.copy_params()
62 |
63 | # create the queue
64 | self.register_buffer("image_queue", torch.randn(embed_dim, queue_size))
65 | self.register_buffer("text_queue", torch.randn(embed_dim, queue_size))
66 | self.register_buffer("idx_queue", torch.full((1,queue_size),-100))
67 | self.register_buffer("ptr_queue", torch.zeros(1, dtype=torch.long))
68 |
69 | self.image_queue = nn.functional.normalize(self.image_queue, dim=0)
70 | self.text_queue = nn.functional.normalize(self.text_queue, dim=0)
71 |
72 | self.queue_size = queue_size
73 | self.momentum = momentum
74 | self.temp = nn.Parameter(0.07*torch.ones([]))
75 |
76 | self.negative_all_rank = negative_all_rank
77 |
78 |
79 | def forward(self, image, caption, alpha, idx):
80 | with torch.no_grad():
81 | self.temp.clamp_(0.001,0.5)
82 |
83 | image_embeds = self.visual_encoder(image)
84 | image_atts = torch.ones(image_embeds.size()[:-1],dtype=torch.long).to(image.device)
85 | image_feat = F.normalize(self.vision_proj(image_embeds[:,0,:]),dim=-1)
86 |
87 | text = self.tokenizer(caption, padding='max_length', truncation=True, max_length=35,
88 | return_tensors="pt").to(image.device)
89 |
90 | text_output = self.text_encoder(text.input_ids, attention_mask = text.attention_mask,
91 | return_dict = True, mode = 'text')
92 | text_feat = F.normalize(self.text_proj(text_output.last_hidden_state[:,0,:]),dim=-1)
93 |
94 | ###============== Image-text Contrastive Learning ===================###
95 | idx = idx.view(-1,1)
96 | idx_all = torch.cat([idx.t(), self.idx_queue.clone().detach()],dim=1)
97 | pos_idx = torch.eq(idx, idx_all).float()
98 | sim_targets = pos_idx / pos_idx.sum(1,keepdim=True)
99 |
100 | # get momentum features
101 | with torch.no_grad():
102 | self._momentum_update()
103 | image_embeds_m = self.visual_encoder_m(image)
104 | image_feat_m = F.normalize(self.vision_proj_m(image_embeds_m[:,0,:]),dim=-1)
105 | image_feat_m_all = torch.cat([image_feat_m.t(),self.image_queue.clone().detach()],dim=1)
106 |
107 | text_output_m = self.text_encoder_m(text.input_ids, attention_mask = text.attention_mask,
108 | return_dict = True, mode = 'text')
109 | text_feat_m = F.normalize(self.text_proj_m(text_output_m.last_hidden_state[:,0,:]),dim=-1)
110 | text_feat_m_all = torch.cat([text_feat_m.t(),self.text_queue.clone().detach()],dim=1)
111 |
112 | sim_i2t_m = image_feat_m @ text_feat_m_all / self.temp
113 | sim_t2i_m = text_feat_m @ image_feat_m_all / self.temp
114 |
115 | sim_i2t_targets = alpha * F.softmax(sim_i2t_m, dim=1) + (1 - alpha) * sim_targets
116 | sim_t2i_targets = alpha * F.softmax(sim_t2i_m, dim=1) + (1 - alpha) * sim_targets
117 |
118 | sim_i2t = image_feat @ text_feat_m_all / self.temp
119 | sim_t2i = text_feat @ image_feat_m_all / self.temp
120 |
121 | loss_i2t = -torch.sum(F.log_softmax(sim_i2t, dim=1)*sim_i2t_targets,dim=1).mean()
122 | loss_t2i = -torch.sum(F.log_softmax(sim_t2i, dim=1)*sim_t2i_targets,dim=1).mean()
123 |
124 | loss_ita = (loss_i2t+loss_t2i)/2
125 |
126 | idxs = concat_all_gather(idx)
127 | self._dequeue_and_enqueue(image_feat_m, text_feat_m, idxs)
128 |
129 | ###============== Image-text Matching ===================###
130 | encoder_input_ids = text.input_ids.clone()
131 | encoder_input_ids[:,0] = self.tokenizer.enc_token_id
132 |
133 |     # forward the positive image-text pair
134 | bs = image.size(0)
135 | output_pos = self.text_encoder(encoder_input_ids,
136 | attention_mask = text.attention_mask,
137 | encoder_hidden_states = image_embeds,
138 | encoder_attention_mask = image_atts,
139 | return_dict = True,
140 | )
141 |
142 |
143 | if self.negative_all_rank:
144 | # compute sample similarity
145 | with torch.no_grad():
146 | mask = torch.eq(idx, idxs.t())
147 |
148 | image_feat_world = concat_all_gather(image_feat)
149 | text_feat_world = concat_all_gather(text_feat)
150 |
151 | sim_i2t = image_feat @ text_feat_world.t() / self.temp
152 | sim_t2i = text_feat @ image_feat_world.t() / self.temp
153 |
154 | weights_i2t = F.softmax(sim_i2t,dim=1)
155 | weights_i2t.masked_fill_(mask, 0)
156 |
157 | weights_t2i = F.softmax(sim_t2i,dim=1)
158 | weights_t2i.masked_fill_(mask, 0)
159 |
160 | image_embeds_world = all_gather_with_grad(image_embeds)
161 |
162 | # select a negative image (from all ranks) for each text
163 | image_embeds_neg = []
164 | for b in range(bs):
165 | neg_idx = torch.multinomial(weights_t2i[b], 1).item()
166 | image_embeds_neg.append(image_embeds_world[neg_idx])
167 | image_embeds_neg = torch.stack(image_embeds_neg,dim=0)
168 |
169 | # select a negative text (from all ranks) for each image
170 | input_ids_world = concat_all_gather(encoder_input_ids)
171 | att_mask_world = concat_all_gather(text.attention_mask)
172 |
173 | text_ids_neg = []
174 | text_atts_neg = []
175 | for b in range(bs):
176 | neg_idx = torch.multinomial(weights_i2t[b], 1).item()
177 | text_ids_neg.append(input_ids_world[neg_idx])
178 | text_atts_neg.append(att_mask_world[neg_idx])
179 |
180 | else:
181 | with torch.no_grad():
182 | mask = torch.eq(idx, idx.t())
183 |
184 | sim_i2t = image_feat @ text_feat.t() / self.temp
185 | sim_t2i = text_feat @ image_feat.t() / self.temp
186 |
187 | weights_i2t = F.softmax(sim_i2t,dim=1)
188 | weights_i2t.masked_fill_(mask, 0)
189 |
190 | weights_t2i = F.softmax(sim_t2i,dim=1)
191 | weights_t2i.masked_fill_(mask, 0)
192 |
193 | # select a negative image (from same rank) for each text
194 | image_embeds_neg = []
195 | for b in range(bs):
196 | neg_idx = torch.multinomial(weights_t2i[b], 1).item()
197 | image_embeds_neg.append(image_embeds[neg_idx])
198 | image_embeds_neg = torch.stack(image_embeds_neg,dim=0)
199 |
200 | # select a negative text (from same rank) for each image
201 | text_ids_neg = []
202 | text_atts_neg = []
203 | for b in range(bs):
204 | neg_idx = torch.multinomial(weights_i2t[b], 1).item()
205 | text_ids_neg.append(encoder_input_ids[neg_idx])
206 | text_atts_neg.append(text.attention_mask[neg_idx])
207 |
208 | text_ids_neg = torch.stack(text_ids_neg,dim=0)
209 | text_atts_neg = torch.stack(text_atts_neg,dim=0)
210 |
211 | text_ids_all = torch.cat([encoder_input_ids, text_ids_neg],dim=0)
212 | text_atts_all = torch.cat([text.attention_mask, text_atts_neg],dim=0)
213 |
214 | image_embeds_all = torch.cat([image_embeds_neg,image_embeds],dim=0)
215 | image_atts_all = torch.cat([image_atts,image_atts],dim=0)
216 |
217 | output_neg = self.text_encoder(text_ids_all,
218 | attention_mask = text_atts_all,
219 | encoder_hidden_states = image_embeds_all,
220 | encoder_attention_mask = image_atts_all,
221 | return_dict = True,
222 | )
223 |
224 |
225 | vl_embeddings = torch.cat([output_pos.last_hidden_state[:,0,:], output_neg.last_hidden_state[:,0,:]],dim=0)
226 | vl_output = self.itm_head(vl_embeddings)
227 |
228 | itm_labels = torch.cat([torch.ones(bs,dtype=torch.long),torch.zeros(2*bs,dtype=torch.long)],
229 | dim=0).to(image.device)
230 | loss_itm = F.cross_entropy(vl_output, itm_labels)
231 |
232 | return loss_ita, loss_itm
233 |
234 |
235 | @torch.no_grad()
236 | def copy_params(self):
237 | for model_pair in self.model_pairs:
238 | for param, param_m in zip(model_pair[0].parameters(), model_pair[1].parameters()):
239 | param_m.data.copy_(param.data) # initialize
240 | param_m.requires_grad = False # not update by gradient
241 |
242 |
243 | @torch.no_grad()
244 | def _momentum_update(self):
245 | for model_pair in self.model_pairs:
246 | for param, param_m in zip(model_pair[0].parameters(), model_pair[1].parameters()):
247 | param_m.data = param_m.data * self.momentum + param.data * (1. - self.momentum)
248 |
249 |
250 | @torch.no_grad()
251 | def _dequeue_and_enqueue(self, image_feat, text_feat, idxs):
252 | # gather keys before updating queue
253 | image_feats = concat_all_gather(image_feat)
254 | text_feats = concat_all_gather(text_feat)
255 |
256 |
257 | batch_size = image_feats.shape[0]
258 |
259 | ptr = int(self.ptr_queue)
260 | assert self.queue_size % batch_size == 0 # for simplicity
261 |
262 | # replace the keys at ptr (dequeue and enqueue)
263 | self.image_queue[:, ptr:ptr + batch_size] = image_feats.T
264 | self.text_queue[:, ptr:ptr + batch_size] = text_feats.T
265 | self.idx_queue[:, ptr:ptr + batch_size] = idxs.T
266 | ptr = (ptr + batch_size) % self.queue_size # move pointer
267 |
268 | self.ptr_queue[0] = ptr
269 |
270 |
271 | def blip_retrieval(pretrained='',**kwargs):
272 | model = BLIP_Retrieval(**kwargs)
273 | if pretrained:
274 | model,msg = load_checkpoint(model,pretrained)
275 | print("missing keys:")
276 | print(msg.missing_keys)
277 | return model
278 |
279 |
280 | @torch.no_grad()
281 | def concat_all_gather(tensor):
282 | """
283 | Performs all_gather operation on the provided tensors.
284 | *** Warning ***: torch.distributed.all_gather has no gradient.
285 | """
286 | tensors_gather = [torch.ones_like(tensor)
287 | for _ in range(torch.distributed.get_world_size())]
288 | torch.distributed.all_gather(tensors_gather, tensor, async_op=False)
289 |
290 | output = torch.cat(tensors_gather, dim=0)
291 | return output
292 |
293 |
294 | class GatherLayer(torch.autograd.Function):
295 | """
296 | Gather tensors from all workers with support for backward propagation:
297 | This implementation does not cut the gradients as torch.distributed.all_gather does.
298 | """
299 |
300 | @staticmethod
301 | def forward(ctx, x):
302 | output = [torch.zeros_like(x) for _ in range(torch.distributed.get_world_size())]
303 | torch.distributed.all_gather(output, x)
304 | return tuple(output)
305 |
306 | @staticmethod
307 | def backward(ctx, *grads):
308 | all_gradients = torch.stack(grads)
309 | torch.distributed.all_reduce(all_gradients)
310 | return all_gradients[torch.distributed.get_rank()]
311 |
312 |
313 | def all_gather_with_grad(tensors):
314 | """
315 | Performs all_gather operation on the provided tensors.
316 | Graph remains connected for backward grad computation.
317 | """
318 | # Queue the gathered tensors
319 | world_size = torch.distributed.get_world_size()
320 | # There is no need for reduction in the single-proc case
321 | if world_size == 1:
322 | return tensors
323 |
324 | tensor_all = GatherLayer.apply(tensors)
325 |
326 | return torch.cat(tensor_all, dim=0)
327 |
--------------------------------------------------------------------------------
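Note: the `_dequeue_and_enqueue` buffers above form a circular queue of momentum features: each step overwrites the oldest `batch_size` columns at the write pointer and advances the pointer modulo `queue_size` (hence the assertion that the queue size is divisible by the batch size). Below is a minimal single-process sketch of just this bookkeeping; the toy sizes and the absence of distributed gathering are assumptions for illustration only.

```python
import torch
import torch.nn.functional as F

# Toy sizes for illustration only; the real model uses a much larger queue.
embed_dim, queue_size, batch_size = 256, 8, 4

image_queue = F.normalize(torch.randn(embed_dim, queue_size), dim=0)
ptr = 0  # write pointer, playing the role of self.ptr_queue

for step in range(3):
    # Newest normalized features for this step (stand-ins for image_feat_m).
    image_feat_m = F.normalize(torch.randn(batch_size, embed_dim), dim=-1)
    # Enqueue: overwrite the oldest columns at the pointer.
    image_queue[:, ptr:ptr + batch_size] = image_feat_m.T
    # Dequeue: advance the pointer modulo the queue size.
    ptr = (ptr + batch_size) % queue_size

print(ptr)  # 4: after three steps of 4, the pointer has wrapped once around the 8-slot queue
```

In the model itself the same pointer update is applied to `text_queue` and `idx_queue`, and the enqueued features are the momentum-encoder outputs gathered from all ranks.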
/spec/models/blip_utils/med.py:
--------------------------------------------------------------------------------
1 | '''
2 | * Copyright (c) 2022, salesforce.com, inc.
3 | * All rights reserved.
4 | * SPDX-License-Identifier: BSD-3-Clause
5 | * For full license text, see LICENSE.txt file in the repo root or https://opensource.org/licenses/BSD-3-Clause
6 | * By Junnan Li
7 | * Based on huggingface code base
8 | * https://github.com/huggingface/transformers/blob/v4.15.0/src/transformers/models/bert
9 | '''
10 |
11 | import math
12 | import os
13 | import warnings
14 | from dataclasses import dataclass
15 | from typing import Optional, Tuple
16 |
17 | import torch
18 | from torch import Tensor, device, dtype, nn
19 | import torch.utils.checkpoint
20 | from torch import nn
21 | from torch.nn import CrossEntropyLoss
22 | import torch.nn.functional as F
23 |
24 | from transformers.activations import ACT2FN
25 | from transformers.file_utils import (
26 | ModelOutput,
27 | )
28 | from transformers.modeling_outputs import (
29 | BaseModelOutputWithPastAndCrossAttentions,
30 | BaseModelOutputWithPoolingAndCrossAttentions,
31 | CausalLMOutputWithCrossAttentions,
32 | MaskedLMOutput,
33 | MultipleChoiceModelOutput,
34 | NextSentencePredictorOutput,
35 | QuestionAnsweringModelOutput,
36 | SequenceClassifierOutput,
37 | TokenClassifierOutput,
38 | )
39 | from transformers.modeling_utils import (
40 | PreTrainedModel,
41 | apply_chunking_to_forward,
42 | find_pruneable_heads_and_indices,
43 | prune_linear_layer,
44 | )
45 | from transformers.utils import logging
46 | from transformers.models.bert.configuration_bert import BertConfig
47 |
48 |
49 | logger = logging.get_logger(__name__)
50 |
51 |
52 | class BertEmbeddings(nn.Module):
53 | """Construct the embeddings from word and position embeddings."""
54 |
55 | def __init__(self, config):
56 | super().__init__()
57 | self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
58 | self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)
59 |
60 | # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load
61 | # any TensorFlow checkpoint file
62 | self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
63 | self.dropout = nn.Dropout(config.hidden_dropout_prob)
64 |
65 | # position_ids (1, len position emb) is contiguous in memory and exported when serialized
66 | self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))
67 | self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
68 |
69 | self.config = config
70 |
71 | def forward(
72 | self, input_ids=None, position_ids=None, inputs_embeds=None, past_key_values_length=0
73 | ):
74 | if input_ids is not None:
75 | input_shape = input_ids.size()
76 | else:
77 | input_shape = inputs_embeds.size()[:-1]
78 |
79 | seq_length = input_shape[1]
80 |
81 | if position_ids is None:
82 | position_ids = self.position_ids[:, past_key_values_length : seq_length + past_key_values_length]
83 |
84 | if inputs_embeds is None:
85 | inputs_embeds = self.word_embeddings(input_ids)
86 |
87 | embeddings = inputs_embeds
88 |
89 | if self.position_embedding_type == "absolute":
90 | position_embeddings = self.position_embeddings(position_ids)
91 | embeddings += position_embeddings
92 | embeddings = self.LayerNorm(embeddings)
93 | embeddings = self.dropout(embeddings)
94 | return embeddings
95 |
96 |
97 | class BertSelfAttention(nn.Module):
98 | def __init__(self, config, is_cross_attention):
99 | super().__init__()
100 | self.config = config
101 | if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
102 | raise ValueError(
103 | "The hidden size (%d) is not a multiple of the number of attention "
104 | "heads (%d)" % (config.hidden_size, config.num_attention_heads)
105 | )
106 |
107 | self.num_attention_heads = config.num_attention_heads
108 | self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
109 | self.all_head_size = self.num_attention_heads * self.attention_head_size
110 |
111 | self.query = nn.Linear(config.hidden_size, self.all_head_size)
112 | if is_cross_attention:
113 | self.key = nn.Linear(config.encoder_width, self.all_head_size)
114 | self.value = nn.Linear(config.encoder_width, self.all_head_size)
115 | else:
116 | self.key = nn.Linear(config.hidden_size, self.all_head_size)
117 | self.value = nn.Linear(config.hidden_size, self.all_head_size)
118 |
119 | self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
120 | self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
121 | if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
122 | self.max_position_embeddings = config.max_position_embeddings
123 | self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size)
124 | self.save_attention = False
125 |
126 | def save_attn_gradients(self, attn_gradients):
127 | self.attn_gradients = attn_gradients
128 |
129 | def get_attn_gradients(self):
130 | return self.attn_gradients
131 |
132 | def save_attention_map(self, attention_map):
133 | self.attention_map = attention_map
134 |
135 | def get_attention_map(self):
136 | return self.attention_map
137 |
138 | def transpose_for_scores(self, x):
139 | new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
140 | x = x.view(*new_x_shape)
141 | return x.permute(0, 2, 1, 3)
142 |
143 | def forward(
144 | self,
145 | hidden_states,
146 | attention_mask=None,
147 | head_mask=None,
148 | encoder_hidden_states=None,
149 | encoder_attention_mask=None,
150 | past_key_value=None,
151 | output_attentions=False,
152 | ):
153 | mixed_query_layer = self.query(hidden_states)
154 |
155 | # If this is instantiated as a cross-attention module, the keys
156 | # and values come from an encoder; the attention mask needs to be
157 | # such that the encoder's padding tokens are not attended to.
158 | is_cross_attention = encoder_hidden_states is not None
159 |
160 | if is_cross_attention:
161 | key_layer = self.transpose_for_scores(self.key(encoder_hidden_states))
162 | value_layer = self.transpose_for_scores(self.value(encoder_hidden_states))
163 | attention_mask = encoder_attention_mask
164 | elif past_key_value is not None:
165 | key_layer = self.transpose_for_scores(self.key(hidden_states))
166 | value_layer = self.transpose_for_scores(self.value(hidden_states))
167 | key_layer = torch.cat([past_key_value[0], key_layer], dim=2)
168 | value_layer = torch.cat([past_key_value[1], value_layer], dim=2)
169 | else:
170 | key_layer = self.transpose_for_scores(self.key(hidden_states))
171 | value_layer = self.transpose_for_scores(self.value(hidden_states))
172 |
173 | query_layer = self.transpose_for_scores(mixed_query_layer)
174 |
175 | past_key_value = (key_layer, value_layer)
176 |
177 | # Take the dot product between "query" and "key" to get the raw attention scores.
178 | attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
179 |
180 | if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
181 | seq_length = hidden_states.size()[1]
182 | position_ids_l = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(-1, 1)
183 | position_ids_r = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(1, -1)
184 | distance = position_ids_l - position_ids_r
185 | positional_embedding = self.distance_embedding(distance + self.max_position_embeddings - 1)
186 | positional_embedding = positional_embedding.to(dtype=query_layer.dtype) # fp16 compatibility
187 |
188 | if self.position_embedding_type == "relative_key":
189 | relative_position_scores = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
190 | attention_scores = attention_scores + relative_position_scores
191 | elif self.position_embedding_type == "relative_key_query":
192 | relative_position_scores_query = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
193 | relative_position_scores_key = torch.einsum("bhrd,lrd->bhlr", key_layer, positional_embedding)
194 | attention_scores = attention_scores + relative_position_scores_query + relative_position_scores_key
195 |
196 | attention_scores = attention_scores / math.sqrt(self.attention_head_size)
197 | if attention_mask is not None:
198 |             # Apply the attention mask (precomputed for all layers in the BertModel forward() function)
199 | attention_scores = attention_scores + attention_mask
200 |
201 | # Normalize the attention scores to probabilities.
202 | attention_probs = nn.Softmax(dim=-1)(attention_scores)
203 |
204 | if is_cross_attention and self.save_attention:
205 | self.save_attention_map(attention_probs)
206 | attention_probs.register_hook(self.save_attn_gradients)
207 |
208 | # This is actually dropping out entire tokens to attend to, which might
209 | # seem a bit unusual, but is taken from the original Transformer paper.
210 | attention_probs_dropped = self.dropout(attention_probs)
211 |
212 | # Mask heads if we want to
213 | if head_mask is not None:
214 | attention_probs_dropped = attention_probs_dropped * head_mask
215 |
216 | context_layer = torch.matmul(attention_probs_dropped, value_layer)
217 |
218 | context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
219 | new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
220 | context_layer = context_layer.view(*new_context_layer_shape)
221 |
222 | outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
223 |
224 | outputs = outputs + (past_key_value,)
225 | return outputs
226 |
227 |
228 | class BertSelfOutput(nn.Module):
229 | def __init__(self, config):
230 | super().__init__()
231 | self.dense = nn.Linear(config.hidden_size, config.hidden_size)
232 | self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
233 | self.dropout = nn.Dropout(config.hidden_dropout_prob)
234 |
235 | def forward(self, hidden_states, input_tensor):
236 | hidden_states = self.dense(hidden_states)
237 | hidden_states = self.dropout(hidden_states)
238 | hidden_states = self.LayerNorm(hidden_states + input_tensor)
239 | return hidden_states
240 |
241 |
242 | class BertAttention(nn.Module):
243 | def __init__(self, config, is_cross_attention=False):
244 | super().__init__()
245 | self.self = BertSelfAttention(config, is_cross_attention)
246 | self.output = BertSelfOutput(config)
247 | self.pruned_heads = set()
248 |
249 | def prune_heads(self, heads):
250 | if len(heads) == 0:
251 | return
252 | heads, index = find_pruneable_heads_and_indices(
253 | heads, self.self.num_attention_heads, self.self.attention_head_size, self.pruned_heads
254 | )
255 |
256 | # Prune linear layers
257 | self.self.query = prune_linear_layer(self.self.query, index)
258 | self.self.key = prune_linear_layer(self.self.key, index)
259 | self.self.value = prune_linear_layer(self.self.value, index)
260 | self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)
261 |
262 | # Update hyper params and store pruned heads
263 | self.self.num_attention_heads = self.self.num_attention_heads - len(heads)
264 | self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads
265 | self.pruned_heads = self.pruned_heads.union(heads)
266 |
267 | def forward(
268 | self,
269 | hidden_states,
270 | attention_mask=None,
271 | head_mask=None,
272 | encoder_hidden_states=None,
273 | encoder_attention_mask=None,
274 | past_key_value=None,
275 | output_attentions=False,
276 | ):
277 | self_outputs = self.self(
278 | hidden_states,
279 | attention_mask,
280 | head_mask,
281 | encoder_hidden_states,
282 | encoder_attention_mask,
283 | past_key_value,
284 | output_attentions,
285 | )
286 | attention_output = self.output(self_outputs[0], hidden_states)
287 | outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them
288 | return outputs
289 |
290 |
291 | class BertIntermediate(nn.Module):
292 | def __init__(self, config):
293 | super().__init__()
294 | self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
295 | if isinstance(config.hidden_act, str):
296 | self.intermediate_act_fn = ACT2FN[config.hidden_act]
297 | else:
298 | self.intermediate_act_fn = config.hidden_act
299 |
300 | def forward(self, hidden_states):
301 | hidden_states = self.dense(hidden_states)
302 | hidden_states = self.intermediate_act_fn(hidden_states)
303 | return hidden_states
304 |
305 |
306 | class BertOutput(nn.Module):
307 | def __init__(self, config):
308 | super().__init__()
309 | self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
310 | self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
311 | self.dropout = nn.Dropout(config.hidden_dropout_prob)
312 |
313 | def forward(self, hidden_states, input_tensor):
314 | hidden_states = self.dense(hidden_states)
315 | hidden_states = self.dropout(hidden_states)
316 | hidden_states = self.LayerNorm(hidden_states + input_tensor)
317 | return hidden_states
318 |
319 |
320 | class BertLayer(nn.Module):
321 | def __init__(self, config, layer_num):
322 | super().__init__()
323 | self.config = config
324 | self.chunk_size_feed_forward = config.chunk_size_feed_forward
325 | self.seq_len_dim = 1
326 | self.attention = BertAttention(config)
327 | self.layer_num = layer_num
328 | if self.config.add_cross_attention:
329 | self.crossattention = BertAttention(config, is_cross_attention=self.config.add_cross_attention)
330 | self.intermediate = BertIntermediate(config)
331 | self.output = BertOutput(config)
332 |
333 | def forward(
334 | self,
335 | hidden_states,
336 | attention_mask=None,
337 | head_mask=None,
338 | encoder_hidden_states=None,
339 | encoder_attention_mask=None,
340 | past_key_value=None,
341 | output_attentions=False,
342 | mode=None,
343 | ):
344 | # decoder uni-directional self-attention cached key/values tuple is at positions 1,2
345 | self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None
346 | self_attention_outputs = self.attention(
347 | hidden_states,
348 | attention_mask,
349 | head_mask,
350 | output_attentions=output_attentions,
351 | past_key_value=self_attn_past_key_value,
352 | )
353 | attention_output = self_attention_outputs[0]
354 |
355 | outputs = self_attention_outputs[1:-1]
356 | present_key_value = self_attention_outputs[-1]
357 |
358 | if mode=='multimodal':
359 | assert encoder_hidden_states is not None, "encoder_hidden_states must be given for cross-attention layers"
360 |
361 | cross_attention_outputs = self.crossattention(
362 | attention_output,
363 | attention_mask,
364 | head_mask,
365 | encoder_hidden_states,
366 | encoder_attention_mask,
367 | output_attentions=output_attentions,
368 | )
369 | attention_output = cross_attention_outputs[0]
370 | outputs = outputs + cross_attention_outputs[1:-1] # add cross attentions if we output attention weights
371 | layer_output = apply_chunking_to_forward(
372 | self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output
373 | )
374 | outputs = (layer_output,) + outputs
375 |
376 | outputs = outputs + (present_key_value,)
377 |
378 | return outputs
379 |
380 | def feed_forward_chunk(self, attention_output):
381 | intermediate_output = self.intermediate(attention_output)
382 | layer_output = self.output(intermediate_output, attention_output)
383 | return layer_output
384 |
385 |
386 | class BertEncoder(nn.Module):
387 | def __init__(self, config):
388 | super().__init__()
389 | self.config = config
390 | self.layer = nn.ModuleList([BertLayer(config,i) for i in range(config.num_hidden_layers)])
391 | self.gradient_checkpointing = False
392 |
393 | def forward(
394 | self,
395 | hidden_states,
396 | attention_mask=None,
397 | head_mask=None,
398 | encoder_hidden_states=None,
399 | encoder_attention_mask=None,
400 | past_key_values=None,
401 | use_cache=None,
402 | output_attentions=False,
403 | output_hidden_states=False,
404 | return_dict=True,
405 | mode='multimodal',
406 | ):
407 | all_hidden_states = () if output_hidden_states else None
408 | all_self_attentions = () if output_attentions else None
409 | all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None
410 |
411 | next_decoder_cache = () if use_cache else None
412 |
413 | for i in range(self.config.num_hidden_layers):
414 | layer_module = self.layer[i]
415 | if output_hidden_states:
416 | all_hidden_states = all_hidden_states + (hidden_states,)
417 |
418 | layer_head_mask = head_mask[i] if head_mask is not None else None
419 | past_key_value = past_key_values[i] if past_key_values is not None else None
420 |
421 | if self.gradient_checkpointing and self.training:
422 |
423 | if use_cache:
424 | logger.warn(
425 | "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
426 | )
427 | use_cache = False
428 |
429 | def create_custom_forward(module):
430 | def custom_forward(*inputs):
431 | return module(*inputs, past_key_value, output_attentions)
432 |
433 | return custom_forward
434 |
435 | layer_outputs = torch.utils.checkpoint.checkpoint(
436 | create_custom_forward(layer_module),
437 | hidden_states,
438 | attention_mask,
439 | layer_head_mask,
440 | encoder_hidden_states,
441 | encoder_attention_mask,
442 | mode=mode,
443 | )
444 | else:
445 | layer_outputs = layer_module(
446 | hidden_states,
447 | attention_mask,
448 | layer_head_mask,
449 | encoder_hidden_states,
450 | encoder_attention_mask,
451 | past_key_value,
452 | output_attentions,
453 | mode=mode,
454 | )
455 |
456 | hidden_states = layer_outputs[0]
457 | if use_cache:
458 | next_decoder_cache += (layer_outputs[-1],)
459 | if output_attentions:
460 | all_self_attentions = all_self_attentions + (layer_outputs[1],)
461 |
462 | if output_hidden_states:
463 | all_hidden_states = all_hidden_states + (hidden_states,)
464 |
465 | if not return_dict:
466 | return tuple(
467 | v
468 | for v in [
469 | hidden_states,
470 | next_decoder_cache,
471 | all_hidden_states,
472 | all_self_attentions,
473 | all_cross_attentions,
474 | ]
475 | if v is not None
476 | )
477 | return BaseModelOutputWithPastAndCrossAttentions(
478 | last_hidden_state=hidden_states,
479 | past_key_values=next_decoder_cache,
480 | hidden_states=all_hidden_states,
481 | attentions=all_self_attentions,
482 | cross_attentions=all_cross_attentions,
483 | )
484 |
485 |
486 | class BertPooler(nn.Module):
487 | def __init__(self, config):
488 | super().__init__()
489 | self.dense = nn.Linear(config.hidden_size, config.hidden_size)
490 | self.activation = nn.Tanh()
491 |
492 | def forward(self, hidden_states):
493 | # We "pool" the model by simply taking the hidden state corresponding
494 | # to the first token.
495 | first_token_tensor = hidden_states[:, 0]
496 | pooled_output = self.dense(first_token_tensor)
497 | pooled_output = self.activation(pooled_output)
498 | return pooled_output
499 |
500 |
501 | class BertPredictionHeadTransform(nn.Module):
502 | def __init__(self, config):
503 | super().__init__()
504 | self.dense = nn.Linear(config.hidden_size, config.hidden_size)
505 | if isinstance(config.hidden_act, str):
506 | self.transform_act_fn = ACT2FN[config.hidden_act]
507 | else:
508 | self.transform_act_fn = config.hidden_act
509 | self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
510 |
511 | def forward(self, hidden_states):
512 | hidden_states = self.dense(hidden_states)
513 | hidden_states = self.transform_act_fn(hidden_states)
514 | hidden_states = self.LayerNorm(hidden_states)
515 | return hidden_states
516 |
517 |
518 | class BertLMPredictionHead(nn.Module):
519 | def __init__(self, config):
520 | super().__init__()
521 | self.transform = BertPredictionHeadTransform(config)
522 |
523 | # The output weights are the same as the input embeddings, but there is
524 | # an output-only bias for each token.
525 | self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
526 |
527 | self.bias = nn.Parameter(torch.zeros(config.vocab_size))
528 |
529 | # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings`
530 | self.decoder.bias = self.bias
531 |
532 | def forward(self, hidden_states):
533 | hidden_states = self.transform(hidden_states)
534 | hidden_states = self.decoder(hidden_states)
535 | return hidden_states
536 |
537 |
538 | class BertOnlyMLMHead(nn.Module):
539 | def __init__(self, config):
540 | super().__init__()
541 | self.predictions = BertLMPredictionHead(config)
542 |
543 | def forward(self, sequence_output):
544 | prediction_scores = self.predictions(sequence_output)
545 | return prediction_scores
546 |
547 |
548 | class BertPreTrainedModel(PreTrainedModel):
549 | """
550 | An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
551 | models.
552 | """
553 |
554 | config_class = BertConfig
555 | base_model_prefix = "bert"
556 | _keys_to_ignore_on_load_missing = [r"position_ids"]
557 |
558 | def _init_weights(self, module):
559 | """ Initialize the weights """
560 | if isinstance(module, (nn.Linear, nn.Embedding)):
561 | # Slightly different from the TF version which uses truncated_normal for initialization
562 | # cf https://github.com/pytorch/pytorch/pull/5617
563 | module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
564 | elif isinstance(module, nn.LayerNorm):
565 | module.bias.data.zero_()
566 | module.weight.data.fill_(1.0)
567 | if isinstance(module, nn.Linear) and module.bias is not None:
568 | module.bias.data.zero_()
569 |
570 |
571 | class BertModel(BertPreTrainedModel):
572 | """
573 | The model can behave as an encoder (with only self-attention) as well as a decoder, in which case a layer of
574 | cross-attention is added between the self-attention layers, following the architecture described in `Attention is
575 |     all you need <https://arxiv.org/abs/1706.03762>`__ by Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit,
576 | Llion Jones, Aidan N. Gomez, Lukasz Kaiser and Illia Polosukhin.
577 |     To use cross-attention, the model needs to be initialized with the :obj:`add_cross_attention` config argument set to :obj:`True`; an :obj:`encoder_hidden_states` is then expected as an
578 | input to the forward pass.
579 | """
580 |
581 | def __init__(self, config, add_pooling_layer=True):
582 | super().__init__(config)
583 | self.config = config
584 |
585 | self.embeddings = BertEmbeddings(config)
586 |
587 | self.encoder = BertEncoder(config)
588 |
589 | self.pooler = BertPooler(config) if add_pooling_layer else None
590 |
591 | self.init_weights()
592 |
593 |
594 | def get_input_embeddings(self):
595 | return self.embeddings.word_embeddings
596 |
597 | def set_input_embeddings(self, value):
598 | self.embeddings.word_embeddings = value
599 |
600 | def _prune_heads(self, heads_to_prune):
601 | """
602 | Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
603 | class PreTrainedModel
604 | """
605 | for layer, heads in heads_to_prune.items():
606 | self.encoder.layer[layer].attention.prune_heads(heads)
607 |
608 |
609 | def get_extended_attention_mask(self, attention_mask: Tensor, input_shape: Tuple[int], device: device, is_decoder: bool) -> Tensor:
610 | """
611 | Makes broadcastable attention and causal masks so that future and masked tokens are ignored.
612 |
613 | Arguments:
614 | attention_mask (:obj:`torch.Tensor`):
615 | Mask with ones indicating tokens to attend to, zeros for tokens to ignore.
616 | input_shape (:obj:`Tuple[int]`):
617 | The shape of the input to the model.
618 | device: (:obj:`torch.device`):
619 | The device of the input to the model.
620 |
621 | Returns:
622 |             :obj:`torch.Tensor` The extended attention mask, with the same dtype as :obj:`attention_mask.dtype`.
623 | """
624 | # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
625 | # ourselves in which case we just need to make it broadcastable to all heads.
626 | if attention_mask.dim() == 3:
627 | extended_attention_mask = attention_mask[:, None, :, :]
628 | elif attention_mask.dim() == 2:
629 | # Provided a padding mask of dimensions [batch_size, seq_length]
630 | # - if the model is a decoder, apply a causal mask in addition to the padding mask
631 | # - if the model is an encoder, make the mask broadcastable to [batch_size, num_heads, seq_length, seq_length]
632 | if is_decoder:
633 | batch_size, seq_length = input_shape
634 |
635 | seq_ids = torch.arange(seq_length, device=device)
636 | causal_mask = seq_ids[None, None, :].repeat(batch_size, seq_length, 1) <= seq_ids[None, :, None]
637 | # in case past_key_values are used we need to add a prefix ones mask to the causal mask
638 | # causal and attention masks must have same type with pytorch version < 1.3
639 | causal_mask = causal_mask.to(attention_mask.dtype)
640 |
641 | if causal_mask.shape[1] < attention_mask.shape[1]:
642 | prefix_seq_len = attention_mask.shape[1] - causal_mask.shape[1]
643 | causal_mask = torch.cat(
644 | [
645 | torch.ones((batch_size, seq_length, prefix_seq_len), device=device, dtype=causal_mask.dtype),
646 | causal_mask,
647 | ],
648 | axis=-1,
649 | )
650 |
651 | extended_attention_mask = causal_mask[:, None, :, :] * attention_mask[:, None, None, :]
652 | else:
653 | extended_attention_mask = attention_mask[:, None, None, :]
654 | else:
655 | raise ValueError(
656 | "Wrong shape for input_ids (shape {}) or attention_mask (shape {})".format(
657 | input_shape, attention_mask.shape
658 | )
659 | )
660 |
661 | # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
662 | # masked positions, this operation will create a tensor which is 0.0 for
663 | # positions we want to attend and -10000.0 for masked positions.
664 | # Since we are adding it to the raw scores before the softmax, this is
665 | # effectively the same as removing these entirely.
666 | extended_attention_mask = extended_attention_mask.to(dtype=self.dtype) # fp16 compatibility
667 | extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
668 | return extended_attention_mask
669 |
670 | def forward(
671 | self,
672 | input_ids=None,
673 | attention_mask=None,
674 | position_ids=None,
675 | head_mask=None,
676 | inputs_embeds=None,
677 | encoder_embeds=None,
678 | encoder_hidden_states=None,
679 | encoder_attention_mask=None,
680 | past_key_values=None,
681 | use_cache=None,
682 | output_attentions=None,
683 | output_hidden_states=None,
684 | return_dict=None,
685 | is_decoder=False,
686 | mode='multimodal',
687 | ):
688 | r"""
689 | encoder_hidden_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
690 | Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
691 | the model is configured as a decoder.
692 | encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
693 | Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in
694 | the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``:
695 | - 1 for tokens that are **not masked**,
696 | - 0 for tokens that are **masked**.
697 | past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
698 | Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
699 | If :obj:`past_key_values` are used, the user can optionally input only the last :obj:`decoder_input_ids`
700 | (those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)`
701 | instead of all :obj:`decoder_input_ids` of shape :obj:`(batch_size, sequence_length)`.
702 | use_cache (:obj:`bool`, `optional`):
703 | If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up
704 | decoding (see :obj:`past_key_values`).
705 | """
706 | output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
707 | output_hidden_states = (
708 | output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
709 | )
710 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict
711 |
712 | if is_decoder:
713 | use_cache = use_cache if use_cache is not None else self.config.use_cache
714 | else:
715 | use_cache = False
716 |
717 | if input_ids is not None and inputs_embeds is not None:
718 | raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
719 | elif input_ids is not None:
720 | input_shape = input_ids.size()
721 | batch_size, seq_length = input_shape
722 | device = input_ids.device
723 | elif inputs_embeds is not None:
724 | input_shape = inputs_embeds.size()[:-1]
725 | batch_size, seq_length = input_shape
726 | device = inputs_embeds.device
727 | elif encoder_embeds is not None:
728 | input_shape = encoder_embeds.size()[:-1]
729 | batch_size, seq_length = input_shape
730 | device = encoder_embeds.device
731 | else:
732 | raise ValueError("You have to specify either input_ids or inputs_embeds or encoder_embeds")
733 |
734 | # past_key_values_length
735 | past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0
736 |
737 | if attention_mask is None:
738 | attention_mask = torch.ones(((batch_size, seq_length + past_key_values_length)), device=device)
739 |
740 | # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
741 | # ourselves in which case we just need to make it broadcastable to all heads.
742 | extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape,
743 | device, is_decoder)
744 |
745 | # If a 2D or 3D attention mask is provided for the cross-attention
746 |         # we need to make it broadcastable to [batch_size, num_heads, seq_length, seq_length]
747 | if encoder_hidden_states is not None:
748 | if type(encoder_hidden_states) == list:
749 | encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states[0].size()
750 | else:
751 | encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
752 | encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
753 |
754 | if type(encoder_attention_mask) == list:
755 | encoder_extended_attention_mask = [self.invert_attention_mask(mask) for mask in encoder_attention_mask]
756 | elif encoder_attention_mask is None:
757 | encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)
758 | encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
759 | else:
760 | encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
761 | else:
762 | encoder_extended_attention_mask = None
763 |
764 | # Prepare head mask if needed
765 |         # 1.0 in head_mask indicates we keep the head
766 | # attention_probs has shape bsz x n_heads x N x N
767 | # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
768 | # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
769 | head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
770 |
771 | if encoder_embeds is None:
772 | embedding_output = self.embeddings(
773 | input_ids=input_ids,
774 | position_ids=position_ids,
775 | inputs_embeds=inputs_embeds,
776 | past_key_values_length=past_key_values_length,
777 | )
778 | else:
779 | embedding_output = encoder_embeds
780 |
781 | encoder_outputs = self.encoder(
782 | embedding_output,
783 | attention_mask=extended_attention_mask,
784 | head_mask=head_mask,
785 | encoder_hidden_states=encoder_hidden_states,
786 | encoder_attention_mask=encoder_extended_attention_mask,
787 | past_key_values=past_key_values,
788 | use_cache=use_cache,
789 | output_attentions=output_attentions,
790 | output_hidden_states=output_hidden_states,
791 | return_dict=return_dict,
792 | mode=mode,
793 | )
794 | sequence_output = encoder_outputs[0]
795 | pooled_output = self.pooler(sequence_output) if self.pooler is not None else None
796 |
797 | if not return_dict:
798 | return (sequence_output, pooled_output) + encoder_outputs[1:]
799 |
800 | return BaseModelOutputWithPoolingAndCrossAttentions(
801 | last_hidden_state=sequence_output,
802 | pooler_output=pooled_output,
803 | past_key_values=encoder_outputs.past_key_values,
804 | hidden_states=encoder_outputs.hidden_states,
805 | attentions=encoder_outputs.attentions,
806 | cross_attentions=encoder_outputs.cross_attentions,
807 | )
808 |
809 |
810 |
811 | class BertLMHeadModel(BertPreTrainedModel):
812 |
813 | _keys_to_ignore_on_load_unexpected = [r"pooler"]
814 | _keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"]
815 |
816 | def __init__(self, config):
817 | super().__init__(config)
818 |
819 | self.bert = BertModel(config, add_pooling_layer=False)
820 | self.cls = BertOnlyMLMHead(config)
821 |
822 | self.init_weights()
823 |
824 | def get_output_embeddings(self):
825 | return self.cls.predictions.decoder
826 |
827 | def set_output_embeddings(self, new_embeddings):
828 | self.cls.predictions.decoder = new_embeddings
829 |
830 | def forward(
831 | self,
832 | input_ids=None,
833 | attention_mask=None,
834 | position_ids=None,
835 | head_mask=None,
836 | inputs_embeds=None,
837 | encoder_hidden_states=None,
838 | encoder_attention_mask=None,
839 | labels=None,
840 | past_key_values=None,
841 | use_cache=None,
842 | output_attentions=None,
843 | output_hidden_states=None,
844 | return_dict=None,
845 | return_logits=False,
846 | is_decoder=True,
847 | reduction='mean',
848 | mode='multimodal',
849 | ):
850 | r"""
851 | encoder_hidden_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
852 | Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
853 | the model is configured as a decoder.
854 | encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
855 | Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in
856 | the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``:
857 | - 1 for tokens that are **not masked**,
858 | - 0 for tokens that are **masked**.
859 | labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
860 | Labels for computing the left-to-right language modeling loss (next word prediction). Indices should be in
861 |             ``[-100, 0, ..., config.vocab_size]`` (see ``input_ids`` docstring). Tokens with indices set to ``-100`` are
862 |             ignored (masked), the loss is only computed for the tokens with labels in ``[0, ..., config.vocab_size]``.
863 | past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
864 | Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
865 | If :obj:`past_key_values` are used, the user can optionally input only the last :obj:`decoder_input_ids`
866 | (those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)`
867 | instead of all :obj:`decoder_input_ids` of shape :obj:`(batch_size, sequence_length)`.
868 | use_cache (:obj:`bool`, `optional`):
869 | If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up
870 | decoding (see :obj:`past_key_values`).
871 | Returns:
872 | Example::
873 | >>> from transformers import BertTokenizer, BertLMHeadModel, BertConfig
874 | >>> import torch
875 | >>> tokenizer = BertTokenizer.from_pretrained('bert-base-cased')
876 | >>> config = BertConfig.from_pretrained("bert-base-cased")
877 | >>> model = BertLMHeadModel.from_pretrained('bert-base-cased', config=config)
878 | >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
879 | >>> outputs = model(**inputs)
880 | >>> prediction_logits = outputs.logits
881 | """
882 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict
883 | if labels is not None:
884 | use_cache = False
885 |
886 | outputs = self.bert(
887 | input_ids,
888 | attention_mask=attention_mask,
889 | position_ids=position_ids,
890 | head_mask=head_mask,
891 | inputs_embeds=inputs_embeds,
892 | encoder_hidden_states=encoder_hidden_states,
893 | encoder_attention_mask=encoder_attention_mask,
894 | past_key_values=past_key_values,
895 | use_cache=use_cache,
896 | output_attentions=output_attentions,
897 | output_hidden_states=output_hidden_states,
898 | return_dict=return_dict,
899 | is_decoder=is_decoder,
900 | mode=mode,
901 | )
902 |
903 | sequence_output = outputs[0]
904 | prediction_scores = self.cls(sequence_output)
905 |
906 | if return_logits:
907 | return prediction_scores[:, :-1, :].contiguous()
908 |
909 | lm_loss = None
910 | if labels is not None:
911 | # we are doing next-token prediction; shift prediction scores and input ids by one
912 | shifted_prediction_scores = prediction_scores[:, :-1, :].contiguous()
913 | labels = labels[:, 1:].contiguous()
914 | loss_fct = CrossEntropyLoss(reduction=reduction, label_smoothing=0.1)
915 | lm_loss = loss_fct(shifted_prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))
916 | if reduction=='none':
917 | lm_loss = lm_loss.view(prediction_scores.size(0),-1).sum(1)
918 |
919 | if not return_dict:
920 | output = (prediction_scores,) + outputs[2:]
921 | return ((lm_loss,) + output) if lm_loss is not None else output
922 |
923 | return CausalLMOutputWithCrossAttentions(
924 | loss=lm_loss,
925 | logits=prediction_scores,
926 | past_key_values=outputs.past_key_values,
927 | hidden_states=outputs.hidden_states,
928 | attentions=outputs.attentions,
929 | cross_attentions=outputs.cross_attentions,
930 | )
931 |
932 | def prepare_inputs_for_generation(self, input_ids, past=None, attention_mask=None, **model_kwargs):
933 | input_shape = input_ids.shape
934 |         # if the model is used as a decoder in an encoder-decoder model, the decoder attention mask is created on the fly
935 | if attention_mask is None:
936 | attention_mask = input_ids.new_ones(input_shape)
937 |
938 | # cut decoder_input_ids if past is used
939 | if past is not None:
940 | input_ids = input_ids[:, -1:]
941 |
942 | return {
943 | "input_ids": input_ids,
944 | "attention_mask": attention_mask,
945 | "past_key_values": past,
946 | "encoder_hidden_states": model_kwargs.get("encoder_hidden_states", None),
947 | "encoder_attention_mask": model_kwargs.get("encoder_attention_mask", None),
948 | "is_decoder": True,
949 | }
950 |
951 | def _reorder_cache(self, past, beam_idx):
952 | reordered_past = ()
953 | for layer_past in past:
954 | reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),)
955 | return reordered_past
956 |
--------------------------------------------------------------------------------
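Note: the `BertModel` above is a standard BERT encoder extended with a `mode` switch; in `'multimodal'` mode every layer cross-attends to precomputed visual features passed through `encoder_hidden_states`. The sketch below is a hypothetical, minimal way to drive it: the small config values, the manually set `encoder_width` attribute, and the import path are assumptions, not taken from the repo's own configs.

```python
import torch
from transformers import BertConfig
from spec.models.blip_utils.med import BertModel  # import path assumed from the repo layout

config = BertConfig(
    vocab_size=30524, hidden_size=768, num_hidden_layers=2,
    num_attention_heads=12, intermediate_size=3072,
    add_cross_attention=True,          # needed so each BertLayer builds a crossattention block
)
config.encoder_width = 768             # width of the visual features fed to cross-attention

text_encoder = BertModel(config, add_pooling_layer=False)

input_ids = torch.randint(0, config.vocab_size, (2, 16))
attention_mask = torch.ones(2, 16, dtype=torch.long)
image_embeds = torch.randn(2, 197, 768)                      # e.g. ViT patch tokens + [CLS]
image_atts = torch.ones(image_embeds.shape[:-1], dtype=torch.long)

out = text_encoder(input_ids,
                   attention_mask=attention_mask,
                   encoder_hidden_states=image_embeds,
                   encoder_attention_mask=image_atts,
                   return_dict=True,
                   mode='multimodal')
print(out.last_hidden_state.shape)  # torch.Size([2, 16, 768])
```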
/spec/models/blip_utils/utils.py:
--------------------------------------------------------------------------------
1 | import math
2 | def cosine_lr_schedule(optimizer, epoch, max_epoch, init_lr, min_lr):
3 | """Decay the learning rate"""
4 | lr = (init_lr - min_lr) * 0.5 * (1. + math.cos(math.pi * epoch / max_epoch)) + min_lr
5 | for param_group in optimizer.param_groups:
6 | param_group['lr'] = lr
7 |
8 | def warmup_lr_schedule(optimizer, step, max_step, init_lr, max_lr):
9 | """Warmup the learning rate"""
10 | lr = min(max_lr, init_lr + (max_lr - init_lr) * step / max_step)
11 | for param_group in optimizer.param_groups:
12 | param_group['lr'] = lr
13 |
14 | def step_lr_schedule(optimizer, epoch, init_lr, min_lr, decay_rate):
15 | """Decay the learning rate"""
16 | lr = max(min_lr, init_lr * (decay_rate**epoch))
17 | for param_group in optimizer.param_groups:
18 | param_group['lr'] = lr
19 |
20 | import numpy as np
21 | import io
22 | import os
23 | import time
24 | from collections import defaultdict, deque
25 | import datetime
26 |
27 | import torch
28 | import torch.distributed as dist
29 |
30 | class SmoothedValue(object):
31 | """Track a series of values and provide access to smoothed values over a
32 | window or the global series average.
33 | """
34 |
35 | def __init__(self, window_size=20, fmt=None):
36 | if fmt is None:
37 | fmt = "{median:.4f} ({global_avg:.4f})"
38 | self.deque = deque(maxlen=window_size)
39 | self.total = 0.0
40 | self.count = 0
41 | self.fmt = fmt
42 |
43 | def update(self, value, n=1):
44 | self.deque.append(value)
45 | self.count += n
46 | self.total += value * n
47 |
48 | def synchronize_between_processes(self):
49 | """
50 | Warning: does not synchronize the deque!
51 | """
52 | if not is_dist_avail_and_initialized():
53 | return
54 | t = torch.tensor([self.count, self.total], dtype=torch.float64, device='cuda')
55 | dist.barrier()
56 | dist.all_reduce(t)
57 | t = t.tolist()
58 | self.count = int(t[0])
59 | self.total = t[1]
60 |
61 | @property
62 | def median(self):
63 | d = torch.tensor(list(self.deque))
64 | return d.median().item()
65 |
66 | @property
67 | def avg(self):
68 | d = torch.tensor(list(self.deque), dtype=torch.float32)
69 | return d.mean().item()
70 |
71 | @property
72 | def global_avg(self):
73 | return self.total / self.count
74 |
75 | @property
76 | def max(self):
77 | return max(self.deque)
78 |
79 | @property
80 | def value(self):
81 | return self.deque[-1]
82 |
83 | def __str__(self):
84 | return self.fmt.format(
85 | median=self.median,
86 | avg=self.avg,
87 | global_avg=self.global_avg,
88 | max=self.max,
89 | value=self.value)
90 |
91 |
92 | class MetricLogger(object):
93 | def __init__(self, delimiter="\t"):
94 | self.meters = defaultdict(SmoothedValue)
95 | self.delimiter = delimiter
96 |
97 | def update(self, **kwargs):
98 | for k, v in kwargs.items():
99 | if isinstance(v, torch.Tensor):
100 | v = v.item()
101 | assert isinstance(v, (float, int))
102 | self.meters[k].update(v)
103 |
104 | def __getattr__(self, attr):
105 | if attr in self.meters:
106 | return self.meters[attr]
107 | if attr in self.__dict__:
108 | return self.__dict__[attr]
109 | raise AttributeError("'{}' object has no attribute '{}'".format(
110 | type(self).__name__, attr))
111 |
112 | def __str__(self):
113 | loss_str = []
114 | for name, meter in self.meters.items():
115 | loss_str.append(
116 | "{}: {}".format(name, str(meter))
117 | )
118 | return self.delimiter.join(loss_str)
119 |
120 | def global_avg(self):
121 | loss_str = []
122 | for name, meter in self.meters.items():
123 | loss_str.append(
124 | "{}: {:.4f}".format(name, meter.global_avg)
125 | )
126 | return self.delimiter.join(loss_str)
127 |
128 | def synchronize_between_processes(self):
129 | for meter in self.meters.values():
130 | meter.synchronize_between_processes()
131 |
132 | def add_meter(self, name, meter):
133 | self.meters[name] = meter
134 |
135 | def log_every(self, iterable, print_freq, header=None):
136 | i = 0
137 | if not header:
138 | header = ''
139 | start_time = time.time()
140 | end = time.time()
141 | iter_time = SmoothedValue(fmt='{avg:.4f}')
142 | data_time = SmoothedValue(fmt='{avg:.4f}')
143 | space_fmt = ':' + str(len(str(len(iterable)))) + 'd'
144 | log_msg = [
145 | header,
146 | '[{0' + space_fmt + '}/{1}]',
147 | 'eta: {eta}',
148 | '{meters}',
149 | 'time: {time}',
150 | 'data: {data}'
151 | ]
152 | if torch.cuda.is_available():
153 | log_msg.append('max mem: {memory:.0f}')
154 | log_msg = self.delimiter.join(log_msg)
155 | MB = 1024.0 * 1024.0
156 | for obj in iterable:
157 | data_time.update(time.time() - end)
158 | yield obj
159 | iter_time.update(time.time() - end)
160 | if i % print_freq == 0 or i == len(iterable) - 1:
161 | eta_seconds = iter_time.global_avg * (len(iterable) - i)
162 | eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
163 | if torch.cuda.is_available():
164 | print(log_msg.format(
165 | i, len(iterable), eta=eta_string,
166 | meters=str(self),
167 | time=str(iter_time), data=str(data_time),
168 | memory=torch.cuda.max_memory_allocated() / MB))
169 | else:
170 | print(log_msg.format(
171 | i, len(iterable), eta=eta_string,
172 | meters=str(self),
173 | time=str(iter_time), data=str(data_time)))
174 | i += 1
175 | end = time.time()
176 | total_time = time.time() - start_time
177 | total_time_str = str(datetime.timedelta(seconds=int(total_time)))
178 | print('{} Total time: {} ({:.4f} s / it)'.format(
179 | header, total_time_str, total_time / len(iterable)))
180 |
181 |
182 | class AttrDict(dict):
183 | def __init__(self, *args, **kwargs):
184 | super(AttrDict, self).__init__(*args, **kwargs)
185 | self.__dict__ = self
186 |
187 |
188 | def compute_acc(logits, label, reduction='mean'):
189 | ret = (torch.argmax(logits, dim=1) == label).float()
190 | if reduction == 'none':
191 | return ret.detach()
192 | elif reduction == 'mean':
193 | return ret.mean().item()
194 |
195 | def compute_n_params(model, return_str=True):
196 | tot = 0
197 | for p in model.parameters():
198 | w = 1
199 | for x in p.shape:
200 | w *= x
201 | tot += w
202 | if return_str:
203 | if tot >= 1e6:
204 | return '{:.1f}M'.format(tot / 1e6)
205 | else:
206 | return '{:.1f}K'.format(tot / 1e3)
207 | else:
208 | return tot
209 |
210 | def setup_for_distributed(is_master):
211 | """
212 |     This function disables printing when not in the master process
213 | """
214 | import builtins as __builtin__
215 | builtin_print = __builtin__.print
216 |
217 | def print(*args, **kwargs):
218 | force = kwargs.pop('force', False)
219 | if is_master or force:
220 | builtin_print(*args, **kwargs)
221 |
222 | __builtin__.print = print
223 |
224 |
225 | def is_dist_avail_and_initialized():
226 | if not dist.is_available():
227 | return False
228 | if not dist.is_initialized():
229 | return False
230 | return True
231 |
232 |
233 | def get_world_size():
234 | if not is_dist_avail_and_initialized():
235 | return 1
236 | return dist.get_world_size()
237 |
238 |
239 | def get_rank():
240 | if not is_dist_avail_and_initialized():
241 | return 0
242 | return dist.get_rank()
243 |
244 |
245 | def is_main_process():
246 | return get_rank() == 0
247 |
248 |
249 | def save_on_master(*args, **kwargs):
250 | if is_main_process():
251 | torch.save(*args, **kwargs)
252 |
253 |
254 | def init_distributed_mode(args):
255 | if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ:
256 | args.rank = int(os.environ["RANK"])
257 | args.world_size = int(os.environ['WORLD_SIZE'])
258 | args.gpu = int(os.environ['LOCAL_RANK'])
259 | elif 'SLURM_PROCID' in os.environ:
260 | args.rank = int(os.environ['SLURM_PROCID'])
261 | args.gpu = args.rank % torch.cuda.device_count()
262 | else:
263 | print('Not using distributed mode')
264 | args.distributed = False
265 | return
266 |
267 | args.distributed = True
268 |
269 | torch.cuda.set_device(args.gpu)
270 | args.dist_backend = 'nccl'
271 |     print('| distributed init (rank {}, world {}): {}'.format(
272 | args.rank, args.world_size, args.dist_url), flush=True)
273 | torch.distributed.init_process_group(backend=args.dist_backend, init_method=args.dist_url,
274 | world_size=args.world_size, rank=args.rank)
275 | torch.distributed.barrier()
276 | setup_for_distributed(args.rank == 0)
277 |
278 |
--------------------------------------------------------------------------------
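Note: `MetricLogger`, `SmoothedValue`, and the LR schedules above are generic training-loop helpers. A small hypothetical usage sketch follows; the toy model, data, and hyper-parameters are made up for illustration.

```python
import torch
import torch.nn.functional as F
from spec.models.blip_utils.utils import MetricLogger, SmoothedValue, cosine_lr_schedule

model = torch.nn.Linear(10, 1)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-5)
loader = [(torch.randn(4, 10), torch.randn(4, 1)) for _ in range(20)]  # stand-in dataloader

metric_logger = MetricLogger(delimiter="  ")
metric_logger.add_meter('lr', SmoothedValue(window_size=1, fmt='{value:.6f}'))

for epoch in range(2):
    # Anneal the learning rate once per epoch with the cosine schedule defined above.
    cosine_lr_schedule(optimizer, epoch, max_epoch=2, init_lr=1e-5, min_lr=1e-6)
    # log_every wraps the dataloader and prints the smoothed meters every print_freq iterations.
    for images, targets in metric_logger.log_every(loader, print_freq=10, header=f'Epoch {epoch}:'):
        loss = F.mse_loss(model(images), targets)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        metric_logger.update(loss=loss.item(), lr=optimizer.param_groups[0]['lr'])
```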
/spec/models/blip_utils/vit.py:
--------------------------------------------------------------------------------
1 | '''
2 | * Copyright (c) 2022, salesforce.com, inc.
3 | * All rights reserved.
4 | * SPDX-License-Identifier: BSD-3-Clause
5 | * For full license text, see LICENSE.txt file in the repo root or https://opensource.org/licenses/BSD-3-Clause
6 | * By Junnan Li
7 | * Based on timm code base
8 | * https://github.com/rwightman/pytorch-image-models/tree/master/timm
9 | '''
10 |
11 | import torch
12 | import torch.nn as nn
13 | import torch.nn.functional as F
14 | from functools import partial
15 |
16 | from timm.models.vision_transformer import _cfg, PatchEmbed
17 | from timm.models.registry import register_model
18 | from timm.models.layers import trunc_normal_, DropPath
19 | from timm.models.helpers import named_apply, adapt_input_conv
20 |
21 | from fairscale.nn.checkpoint.checkpoint_activations import checkpoint_wrapper
22 |
23 | class Mlp(nn.Module):
24 | """ MLP as used in Vision Transformer, MLP-Mixer and related networks
25 | """
26 | def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
27 | super().__init__()
28 | out_features = out_features or in_features
29 | hidden_features = hidden_features or in_features
30 | self.fc1 = nn.Linear(in_features, hidden_features)
31 | self.act = act_layer()
32 | self.fc2 = nn.Linear(hidden_features, out_features)
33 | self.drop = nn.Dropout(drop)
34 |
35 | def forward(self, x):
36 | x = self.fc1(x)
37 | x = self.act(x)
38 | x = self.drop(x)
39 | x = self.fc2(x)
40 | x = self.drop(x)
41 | return x
42 |
43 |
44 | class Attention(nn.Module):
45 | def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.):
46 | super().__init__()
47 | self.num_heads = num_heads
48 | head_dim = dim // num_heads
49 | # NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights
50 | self.scale = qk_scale or head_dim ** -0.5
51 | self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
52 | self.attn_drop = nn.Dropout(attn_drop)
53 | self.proj = nn.Linear(dim, dim)
54 | self.proj_drop = nn.Dropout(proj_drop)
55 | self.attn_gradients = None
56 | self.attention_map = None
57 |
58 | def save_attn_gradients(self, attn_gradients):
59 | self.attn_gradients = attn_gradients
60 |
61 | def get_attn_gradients(self):
62 | return self.attn_gradients
63 |
64 | def save_attention_map(self, attention_map):
65 | self.attention_map = attention_map
66 |
67 | def get_attention_map(self):
68 | return self.attention_map
69 |
70 | def forward(self, x, register_hook=False):
71 | B, N, C = x.shape
72 | qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
73 | q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple)
74 |
75 | attn = (q @ k.transpose(-2, -1)) * self.scale
76 | attn = attn.softmax(dim=-1)
77 | attn = self.attn_drop(attn)
78 |
79 | if register_hook:
80 | self.save_attention_map(attn)
81 | attn.register_hook(self.save_attn_gradients)
82 |
83 | x = (attn @ v).transpose(1, 2).reshape(B, N, C)
84 | x = self.proj(x)
85 | x = self.proj_drop(x)
86 | return x
87 |
88 |
89 | class Block(nn.Module):
90 |
91 | def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
92 | drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, use_grad_checkpointing=False):
93 | super().__init__()
94 | self.norm1 = norm_layer(dim)
95 | self.attn = Attention(
96 | dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
97 | # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
98 | self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
99 | self.norm2 = norm_layer(dim)
100 | mlp_hidden_dim = int(dim * mlp_ratio)
101 | self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
102 |
103 | if use_grad_checkpointing:
104 | self.attn = checkpoint_wrapper(self.attn)
105 | self.mlp = checkpoint_wrapper(self.mlp)
106 |
107 | def forward(self, x, register_hook=False):
108 | x = x + self.drop_path(self.attn(self.norm1(x), register_hook=register_hook))
109 | x = x + self.drop_path(self.mlp(self.norm2(x)))
110 | return x
111 |
112 |
113 | class VisionTransformer(nn.Module):
114 | """ Vision Transformer
115 | A PyTorch impl of : `An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale` -
116 | https://arxiv.org/abs/2010.11929
117 | """
118 | def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dim=768, depth=12,
119 | num_heads=12, mlp_ratio=4., qkv_bias=True, qk_scale=None, representation_size=None,
120 | drop_rate=0., attn_drop_rate=0., drop_path_rate=0., norm_layer=None,
121 | use_grad_checkpointing=False, ckpt_layer=0):
122 | """
123 | Args:
124 | img_size (int, tuple): input image size
125 | patch_size (int, tuple): patch size
126 | in_chans (int): number of input channels
127 | num_classes (int): number of classes for classification head
128 | embed_dim (int): embedding dimension
129 | depth (int): depth of transformer
130 | num_heads (int): number of attention heads
131 | mlp_ratio (int): ratio of mlp hidden dim to embedding dim
132 | qkv_bias (bool): enable bias for qkv if True
133 | qk_scale (float): override default qk scale of head_dim ** -0.5 if set
134 | representation_size (Optional[int]): enable and set representation layer (pre-logits) to this value if set
135 | drop_rate (float): dropout rate
136 | attn_drop_rate (float): attention dropout rate
137 | drop_path_rate (float): stochastic depth rate
138 | norm_layer: (nn.Module): normalization layer
139 | """
140 | super().__init__()
141 | self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models
142 | norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6)
143 |
144 | self.patch_embed = PatchEmbed(
145 | img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
146 |
147 | num_patches = self.patch_embed.num_patches
148 |
149 | self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
150 | self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
151 | self.pos_drop = nn.Dropout(p=drop_rate)
152 |
153 | dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule
154 | self.blocks = nn.ModuleList([
155 | Block(
156 | dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
157 | drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer,
158 | use_grad_checkpointing=(use_grad_checkpointing and i>=depth-ckpt_layer)
159 | )
160 | for i in range(depth)])
161 | self.norm = norm_layer(embed_dim)
162 |
163 | trunc_normal_(self.pos_embed, std=.02)
164 | trunc_normal_(self.cls_token, std=.02)
165 | self.apply(self._init_weights)
166 |
167 | def _init_weights(self, m):
168 | if isinstance(m, nn.Linear):
169 | trunc_normal_(m.weight, std=.02)
170 | if isinstance(m, nn.Linear) and m.bias is not None:
171 | nn.init.constant_(m.bias, 0)
172 | elif isinstance(m, nn.LayerNorm):
173 | nn.init.constant_(m.bias, 0)
174 | nn.init.constant_(m.weight, 1.0)
175 |
176 | @torch.jit.ignore
177 | def no_weight_decay(self):
178 | return {'pos_embed', 'cls_token'}
179 |
180 | def forward(self, x, register_blk=-1):
181 | B = x.shape[0]
182 | x = self.patch_embed(x)
183 |
184 | cls_tokens = self.cls_token.expand(B, -1, -1) # stole cls_tokens impl from Phil Wang, thanks
185 | x = torch.cat((cls_tokens, x), dim=1)
186 |
187 | x = x + self.pos_embed[:,:x.size(1),:]
188 | x = self.pos_drop(x)
189 |
190 | for i,blk in enumerate(self.blocks):
191 | x = blk(x, register_blk==i)
192 | x = self.norm(x)
193 |
194 | return x
195 |
196 | @torch.jit.ignore()
197 | def load_pretrained(self, checkpoint_path, prefix=''):
198 | _load_weights(self, checkpoint_path, prefix)
199 |
200 |
201 | @torch.no_grad()
202 | def _load_weights(model: VisionTransformer, checkpoint_path: str, prefix: str = ''):
203 | """ Load weights from .npz checkpoints for official Google Brain Flax implementation
204 | """
205 | import numpy as np
206 |
207 | def _n2p(w, t=True):
208 | if w.ndim == 4 and w.shape[0] == w.shape[1] == w.shape[2] == 1:
209 | w = w.flatten()
210 | if t:
211 | if w.ndim == 4:
212 | w = w.transpose([3, 2, 0, 1])
213 | elif w.ndim == 3:
214 | w = w.transpose([2, 0, 1])
215 | elif w.ndim == 2:
216 | w = w.transpose([1, 0])
217 | return torch.from_numpy(w)
218 |
219 | w = np.load(checkpoint_path)
220 | if not prefix and 'opt/target/embedding/kernel' in w:
221 | prefix = 'opt/target/'
222 |
223 | if hasattr(model.patch_embed, 'backbone'):
224 | # hybrid
225 | backbone = model.patch_embed.backbone
226 | stem_only = not hasattr(backbone, 'stem')
227 | stem = backbone if stem_only else backbone.stem
228 | stem.conv.weight.copy_(adapt_input_conv(stem.conv.weight.shape[1], _n2p(w[f'{prefix}conv_root/kernel'])))
229 | stem.norm.weight.copy_(_n2p(w[f'{prefix}gn_root/scale']))
230 | stem.norm.bias.copy_(_n2p(w[f'{prefix}gn_root/bias']))
231 | if not stem_only:
232 | for i, stage in enumerate(backbone.stages):
233 | for j, block in enumerate(stage.blocks):
234 | bp = f'{prefix}block{i + 1}/unit{j + 1}/'
235 | for r in range(3):
236 | getattr(block, f'conv{r + 1}').weight.copy_(_n2p(w[f'{bp}conv{r + 1}/kernel']))
237 | getattr(block, f'norm{r + 1}').weight.copy_(_n2p(w[f'{bp}gn{r + 1}/scale']))
238 | getattr(block, f'norm{r + 1}').bias.copy_(_n2p(w[f'{bp}gn{r + 1}/bias']))
239 | if block.downsample is not None:
240 | block.downsample.conv.weight.copy_(_n2p(w[f'{bp}conv_proj/kernel']))
241 | block.downsample.norm.weight.copy_(_n2p(w[f'{bp}gn_proj/scale']))
242 | block.downsample.norm.bias.copy_(_n2p(w[f'{bp}gn_proj/bias']))
243 | embed_conv_w = _n2p(w[f'{prefix}embedding/kernel'])
244 | else:
245 | embed_conv_w = adapt_input_conv(
246 | model.patch_embed.proj.weight.shape[1], _n2p(w[f'{prefix}embedding/kernel']))
247 | model.patch_embed.proj.weight.copy_(embed_conv_w)
248 | model.patch_embed.proj.bias.copy_(_n2p(w[f'{prefix}embedding/bias']))
249 | model.cls_token.copy_(_n2p(w[f'{prefix}cls'], t=False))
250 | pos_embed_w = _n2p(w[f'{prefix}Transformer/posembed_input/pos_embedding'], t=False)
251 | if pos_embed_w.shape != model.pos_embed.shape:
252 | pos_embed_w = resize_pos_embed( # resize pos embedding when different size from pretrained weights
253 | pos_embed_w, model.pos_embed, getattr(model, 'num_tokens', 1), model.patch_embed.grid_size)
254 | model.pos_embed.copy_(pos_embed_w)
255 | model.norm.weight.copy_(_n2p(w[f'{prefix}Transformer/encoder_norm/scale']))
256 | model.norm.bias.copy_(_n2p(w[f'{prefix}Transformer/encoder_norm/bias']))
257 | # if isinstance(model.head, nn.Linear) and model.head.bias.shape[0] == w[f'{prefix}head/bias'].shape[-1]:
258 | # model.head.weight.copy_(_n2p(w[f'{prefix}head/kernel']))
259 | # model.head.bias.copy_(_n2p(w[f'{prefix}head/bias']))
260 | # if isinstance(getattr(model.pre_logits, 'fc', None), nn.Linear) and f'{prefix}pre_logits/bias' in w:
261 | # model.pre_logits.fc.weight.copy_(_n2p(w[f'{prefix}pre_logits/kernel']))
262 | # model.pre_logits.fc.bias.copy_(_n2p(w[f'{prefix}pre_logits/bias']))
263 | for i, block in enumerate(model.blocks.children()):
264 | block_prefix = f'{prefix}Transformer/encoderblock_{i}/'
265 | mha_prefix = block_prefix + 'MultiHeadDotProductAttention_1/'
266 | block.norm1.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/scale']))
267 | block.norm1.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/bias']))
268 | block.attn.qkv.weight.copy_(torch.cat([
269 | _n2p(w[f'{mha_prefix}{n}/kernel'], t=False).flatten(1).T for n in ('query', 'key', 'value')]))
270 | block.attn.qkv.bias.copy_(torch.cat([
271 | _n2p(w[f'{mha_prefix}{n}/bias'], t=False).reshape(-1) for n in ('query', 'key', 'value')]))
272 | block.attn.proj.weight.copy_(_n2p(w[f'{mha_prefix}out/kernel']).flatten(1))
273 | block.attn.proj.bias.copy_(_n2p(w[f'{mha_prefix}out/bias']))
274 | for r in range(2):
275 | getattr(block.mlp, f'fc{r + 1}').weight.copy_(_n2p(w[f'{block_prefix}MlpBlock_3/Dense_{r}/kernel']))
276 | getattr(block.mlp, f'fc{r + 1}').bias.copy_(_n2p(w[f'{block_prefix}MlpBlock_3/Dense_{r}/bias']))
277 | block.norm2.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_2/scale']))
278 | block.norm2.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_2/bias']))
279 |
280 |
281 | def interpolate_pos_embed(pos_embed_checkpoint, visual_encoder):
282 | # interpolate position embedding
283 | embedding_size = pos_embed_checkpoint.shape[-1]
284 | num_patches = visual_encoder.patch_embed.num_patches
285 | num_extra_tokens = visual_encoder.pos_embed.shape[-2] - num_patches
286 | # height (== width) for the checkpoint position embedding
287 | orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5)
288 | # height (== width) for the new position embedding
289 | new_size = int(num_patches ** 0.5)
290 |
291 | if orig_size!=new_size:
292 | # class_token and dist_token are kept unchanged
293 | extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens]
294 | # only the position tokens are interpolated
295 | pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:]
296 | pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2)
297 | pos_tokens = torch.nn.functional.interpolate(
298 | pos_tokens, size=(new_size, new_size), mode='bicubic', align_corners=False)
299 | pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2)
300 | new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1)
301 | print('reshape position embedding from %d to %d'%(orig_size ** 2,new_size ** 2))
302 |
303 | return new_pos_embed
304 | else:
305 | return pos_embed_checkpoint
--------------------------------------------------------------------------------
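
For orientation, here is a small, hedged sketch of instantiating the `VisionTransformer` defined above with ViT-B/16-style settings at the 384-pixel resolution used by the BLIP retrieval configs, and checking the token shape it returns. The concrete numbers are illustrative; the module also requires the `timm` and `fairscale` packages it imports.

```python
# Illustrative sketch: build the ViT defined above and inspect its output.
import torch
from spec.models.blip_utils.vit import VisionTransformer

# ViT-B/16 at 384x384 (the class defaults already encode the ViT-B sizes).
vit = VisionTransformer(img_size=384, patch_size=16, embed_dim=768,
                        depth=12, num_heads=12).eval()

with torch.no_grad():
    images = torch.randn(2, 3, 384, 384)   # B, C, H, W
    tokens = vit(images)                   # B, 1 + (384 // 16) ** 2, embed_dim

print(tokens.shape)  # torch.Size([2, 577, 768]); tokens[:, 0] is the [CLS] embedding
```
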
/spec/models/blip_wrapper.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import yaml
4 | import subprocess
5 | import torch.nn.functional as F
6 |
7 | from tqdm import tqdm
8 | from einops import rearrange
9 | from .blip_utils.blip_retrieval import blip_retrieval
10 | from .base_wrapper import BaseWrapper
11 | from torchvision import transforms
12 |
13 |
14 | # All of the URLs below are taken from, and most of the implementation is heavily inspired by, the wonderful https://github.com/salesforce/BLIP repo.
15 | download_urls = {
16 | "blip-flickr-base": {
17 | "model_url": "https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_retrieval_flickr.pth",
18 | "config_url": "https://raw.githubusercontent.com/salesforce/BLIP/0480d94d5725a3d1aac66f21e6bf138ac17d323d/configs/retrieval_flickr.yaml",
19 | "bert_config_url": "https://raw.githubusercontent.com/salesforce/BLIP/0480d94d5725a3d1aac66f21e6bf138ac17d323d/configs/med_config.json"
20 | },
21 |
22 | "blip-coco-base": {
23 | "model_url": "https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_retrieval_coco.pth",
24 | "config_url": "https://github.com/salesforce/BLIP/raw/0480d94d5725a3d1aac66f21e6bf138ac17d323d/configs/retrieval_coco.yaml",
25 | "bert_config_url": "https://raw.githubusercontent.com/salesforce/BLIP/0480d94d5725a3d1aac66f21e6bf138ac17d323d/configs/med_config.json"
26 | },
27 | }
28 |
29 |
30 | class BLIPModelWrapper(BaseWrapper):
31 | def __init__(self, cache_dir, device, variant="blip-coco-base"):
32 | self.cache_dir = cache_dir
33 | self.device = device
34 | self.variant = variant
35 | self.image_preprocess = transforms.Compose([
36 | transforms.Resize((384, 384), interpolation=transforms.functional.InterpolationMode.BICUBIC),
37 | transforms.ToTensor(),
38 | transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711))])
39 | # download and load model
40 | self.config_path = os.path.join(cache_dir, f"{variant}-config")
41 | self.model_path = os.path.join(cache_dir, f"{variant}.pth")
42 | self.bert_config_path = os.path.join(cache_dir, "configs", f"{variant}_med_config.json")
43 | if not (os.path.exists(self.config_path) and os.path.exists(self.model_path) and os.path.exists(
44 | self.bert_config_path)):
45 | self.download()
46 | config = yaml.load(open(self.config_path, 'r'), Loader=yaml.Loader)
47 | self.config = config
48 | self.config['k_test'] = 128
49 | config['med_config'] = self.bert_config_path
50 | model = blip_retrieval(pretrained=self.model_path, image_size=config['image_size'], vit=config['vit'],
51 | vit_grad_ckpt=config['vit_grad_ckpt'], vit_ckpt_layer=config['vit_ckpt_layer'],
52 | queue_size=config['queue_size'], negative_all_rank=config['negative_all_rank'],
53 | med_config=config['med_config']).to(device)
54 | self.model = model.eval()
55 |
56 |
57 | def download(self):
58 | print(f"Downloading BLIP model to {self.cache_dir}...")
59 | model_url = download_urls[self.variant]["model_url"]
60 | config_url = download_urls[self.variant]["config_url"]
61 | bert_config_url = download_urls[self.variant]["bert_config_url"]
62 | os.makedirs(os.path.join(self.cache_dir, "configs"), exist_ok=True)
63 | subprocess.call(["wget", "-cq", model_url, "-O", self.model_path])
64 | subprocess.call(["wget", "-cq", config_url, "-O", self.config_path])
65 | subprocess.call(["wget", "-cq", bert_config_url, "-O", self.bert_config_path])
66 |
67 | @torch.no_grad()
68 | def i2t_evaluate(self, subset_name, dataloader):
69 | tqdm_i2t_loader = tqdm(dataloader)
70 | tqdm_i2t_loader.set_description(f"Image to Text retrieval on <{subset_name}>")
71 | i2t_scores = []
72 | i2t_correct_num = 0
73 | total_num = 0
74 | for batch in tqdm_i2t_loader:
75 | bs = len(batch['label'])
76 | # get query images
77 | query_images = batch['query_image'].to(self.device) # B,C,H,W (B:batch size)
78 |
79 | # compute normalized image embeddings
80 | image_feats = self.model.visual_encoder(query_images)
81 | image_embeddings = self.model.vision_proj(image_feats[:, 0, :])
82 | image_embeddings = F.normalize(image_embeddings, dim=-1) # B, D (D:feature dim)
83 |
84 | # get candidate texts
85 | candidate_texts = batch['candidate_texts']
86 |
87 | # compute normalized text embeddings
88 | text_embeddings = []
89 | for texts in candidate_texts:
90 | text_input = self.model.tokenizer(texts, padding='max_length', truncation=True, max_length=35, return_tensors="pt").to(self.device)
91 | text_feat = self.model.text_encoder(text_input.input_ids, attention_mask=text_input.attention_mask, mode='text')
92 | text_embed = F.normalize(self.model.text_proj(text_feat.last_hidden_state[:, 0, :]))
93 | text_embeddings.append(text_embed)
94 | text_embeddings = torch.stack(text_embeddings, dim=0) # B, L, D
95 |
96 | # calculate matching result
97 | batch_i2t_scores = torch.einsum('BD,BLD->BL', [image_embeddings, text_embeddings]).cpu()
98 | i2t_scores.append(batch_i2t_scores)
99 | gt_labels = batch['label']
100 | pred_labels = batch_i2t_scores.argmax(dim=-1)
101 | correct_num = (gt_labels == pred_labels).sum()
102 | i2t_correct_num += correct_num.item()
103 | total_num += bs
104 |
105 | i2t_scores = torch.cat(i2t_scores, dim=0)
106 | i2t_acc = 100 * i2t_correct_num / total_num
107 |
108 | return i2t_scores, i2t_acc
109 |
110 | @torch.no_grad()
111 | def t2i_evaluate(self, subset_name, dataloader):
112 | tqdm_t2i_loader = tqdm(dataloader)
113 | tqdm_t2i_loader.set_description(f"Text to Image retrieval on <{subset_name}>")
114 | t2i_scores = []
115 | t2i_correct_num = 0
116 | total_num = 0
117 | for batch in tqdm_t2i_loader:
118 | bs = len(batch['label'])
119 | # get query texts
120 | query_texts = batch['query_text'] # B (B:batch size, list)
121 | query_texts = self.model.tokenizer(query_texts, padding='max_length', truncation=True, max_length=35, return_tensors="pt").to(self.device)
122 |
123 | # compute normalized text embeddings
124 | text_feats = self.model.text_encoder(query_texts.input_ids, attention_mask=query_texts.attention_mask, mode='text')
125 | text_embeddings = F.normalize(self.model.text_proj(text_feats.last_hidden_state[:, 0, :])) # B,D (D:feature dim)
126 |
127 | # get candidate images
128 | candidate_images = batch['candidate_images'].to(self.device) # B,K,C,H,W (K:num of candidate images per case)
129 | candidate_images = rearrange(candidate_images, 'B K C H W -> (B K) C H W')
130 |
131 | # compute normalized image embeddings
132 | image_feats = self.model.visual_encoder(candidate_images)
133 | image_embeddings = self.model.vision_proj(image_feats[:, 0, :])
134 | image_embeddings = F.normalize(image_embeddings, dim=-1) # BK, D (D:feature dim)
135 | image_embeddings = rearrange(image_embeddings, '(B K) D -> B K D', B=bs)
136 |
137 | # calculate matching result
138 | batch_t2i_scores = torch.einsum('BD,BKD->BK', [text_embeddings, image_embeddings]).cpu()
139 | t2i_scores.append(batch_t2i_scores)
140 | gt_labels = batch['label']
141 | pred_labels = batch_t2i_scores.argmax(dim=-1)
142 | correct_num = (gt_labels == pred_labels).sum()
143 | t2i_correct_num += correct_num.item()
144 | total_num += bs
145 |
146 | t2i_scores = torch.cat(t2i_scores, dim=0)
147 | t2i_acc = 100 * t2i_correct_num / total_num
148 |
149 | return t2i_scores, t2i_acc
150 |
--------------------------------------------------------------------------------
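
`i2t_evaluate` and `t2i_evaluate` only assume that each batch is a dict with `query_image` (B,C,H,W), `candidate_texts` (B lists of L captions), `query_text`, `candidate_images` (B,K,C,H,W) and integer `label` entries, which is what the SPEC dataloader yields. As a rough sketch of the interface (the cache path and the tiny hand-built batch below are illustrative assumptions, not repo code), one could score a single image-to-text batch like this:

```python
# Illustrative sketch: scoring one hand-built i2t batch with BLIPModelWrapper.
import torch
from PIL import Image
from spec.models.blip_wrapper import BLIPModelWrapper

device = 'cuda' if torch.cuda.is_available() else 'cpu'
wrapper = BLIPModelWrapper(cache_dir='/path/to/cache/models', device=device,
                           variant='blip-coco-base')

pil_images = [Image.new('RGB', (384, 384)) for _ in range(2)]  # stand-ins for SPEC images
batch = {
    'query_image': torch.stack([wrapper.image_preprocess(im) for im in pil_images]),  # B,C,H,W
    'candidate_texts': [['a large dog', 'a small dog']] * len(pil_images),            # B lists of L captions
    'label': torch.zeros(len(pil_images), dtype=torch.long),                          # index of the correct caption
}

# Any iterable of such batches works; a real run passes a torch DataLoader.
scores, acc = wrapper.i2t_evaluate('toy_subset', [batch])
print(scores.shape, acc)  # (B, L) similarity scores and accuracy in percent
```
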
/spec/models/clip_wrapper.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | from einops import rearrange
4 | from tqdm import tqdm
5 | from .base_wrapper import BaseWrapper
6 | from open_clip import create_model_and_transforms, get_tokenizer
7 |
8 |
9 | class CLIPWrapper(BaseWrapper):
10 | def __init__(self, device, variant='ViT-B-32', pretrained='openai'):
11 | self.device = device
12 | model, _, image_preprocess = create_model_and_transforms(variant,
13 | device=self.device,
14 | pretrained=pretrained)
15 | self.model = model.eval()
16 | self.tokenizer = get_tokenizer(variant)
17 | self.image_preprocess = image_preprocess
18 |
19 | @torch.no_grad()
20 | def i2t_evaluate(self, subset_name, dataloader):
21 | tqdm_i2t_loader = tqdm(dataloader)
22 | tqdm_i2t_loader.set_description(f"Image to Text retrieval on <{subset_name}>")
23 | i2t_scores = []
24 | i2t_correct_num = 0
25 | total_num = 0
26 | for batch in tqdm_i2t_loader:
27 | bs = len(batch['label'])
28 | # get query images
29 | query_images = batch['query_image'].to(self.device) # B,C,H,W (B:batch size)
30 |
31 | # compute normalized image embeddings
32 | image_embeddings = self.model.encode_image(query_images) # B,D (D:feature dim)
33 | image_embeddings /= image_embeddings.norm(dim=-1, keepdim=True)
34 |
35 | # get candidate texts
36 | candidate_texts = batch['candidate_texts']
37 | candidate_texts = [self.tokenizer(texts) for texts in candidate_texts]
38 | candidate_texts = torch.stack(candidate_texts, dim=0).to(self.device) # B,L,S (L:num of candidate texts, S:sentence length)
39 |
40 | # compute normalized text embeddings
41 | candidate_texts = rearrange(candidate_texts, 'B L S -> (B L) S')
42 | text_embeddings = self.model.encode_text(candidate_texts) # BL,D (D:feature dim)
43 | text_embeddings /= text_embeddings.norm(dim=-1, keepdim=True)
44 | text_embeddings = rearrange(text_embeddings, '(B L) D -> B L D', B=bs) # B, L, D
45 |
46 | # calculate matching result
47 | batch_i2t_scores = torch.einsum('BD,BLD->BL', [image_embeddings, text_embeddings]).cpu()
48 | i2t_scores.append(batch_i2t_scores)
49 | gt_labels = batch['label']
50 | pred_labels = batch_i2t_scores.argmax(dim=-1)
51 | correct_num = (gt_labels == pred_labels).sum()
52 | i2t_correct_num += correct_num.item()
53 | total_num += bs
54 |
55 | i2t_scores = torch.cat(i2t_scores, dim=0)
56 | i2t_acc = 100 * i2t_correct_num / total_num
57 |
58 | return i2t_scores, i2t_acc
59 |
60 | @torch.no_grad()
61 | def t2i_evaluate(self, subset_name, dataloader):
62 | tqdm_t2i_loader = tqdm(dataloader)
63 | tqdm_t2i_loader.set_description(f"Text to Image retrieval on <{subset_name}>")
64 | t2i_scores = []
65 | t2i_correct_num = 0
66 | total_num = 0
67 | for batch in tqdm_t2i_loader:
68 | bs = len(batch['label'])
69 | # get query texts
70 | query_texts = batch['query_text'] # list of B raw query strings (B:batch size)
71 | query_texts = self.tokenizer(query_texts).to(self.device) # B, S (B:batch size, S:sentence length)
72 |
73 | # compute normalized text embeddings
74 | text_embeddings = self.model.encode_text(query_texts) # B,D (D:feature dim)
75 | text_embeddings /= text_embeddings.norm(dim=-1, keepdim=True)
76 |
77 | # get candidate images
78 | candidate_images = batch['candidate_images'].to(self.device) # B,K,C,H,W (K:num of candidate images per case)
79 | candidate_images = rearrange(candidate_images, 'B K C H W -> (B K) C H W')
80 |
81 | # compute normalized image embeddings
82 | image_embeddings = self.model.encode_image(
83 | candidate_images) # BK,D (K:num of candidate images per case, D:feature dim)
84 | image_embeddings /= image_embeddings.norm(dim=-1, keepdim=True)
85 | image_embeddings = rearrange(image_embeddings, '(B K) D -> B K D', B=bs)
86 |
87 | # calculate matching result
88 | batch_t2i_scores = torch.einsum('BD,BKD->BK', [text_embeddings, image_embeddings]).cpu()
89 | t2i_scores.append(batch_t2i_scores)
90 | gt_labels = batch['label']
91 | pred_labels = batch_t2i_scores.argmax(dim=-1)
92 | correct_num = (gt_labels == pred_labels).sum()
93 | t2i_correct_num += correct_num.item()
94 | total_num += bs
95 |
96 | t2i_scores = torch.cat(t2i_scores, dim=0)
97 | t2i_acc = 100 * t2i_correct_num / total_num
98 |
99 | return t2i_scores, t2i_acc
100 |
--------------------------------------------------------------------------------
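
Both wrappers reduce matching to the same operation: a dot product between L2-normalised query and candidate embeddings, followed by an argmax over candidates. The tiny self-contained check below (illustrative, using random tensors) shows that the `einsum` pattern used above is just a batched dot product:

```python
# Illustrative check: the 'BD,BLD->BL' einsum used by the wrappers is a
# batched dot product between normalised embeddings.
import torch
import torch.nn.functional as F

B, L, D = 4, 9, 512
image_embeddings = F.normalize(torch.randn(B, D), dim=-1)      # one query per case
text_embeddings = F.normalize(torch.randn(B, L, D), dim=-1)    # L candidates per case

scores_einsum = torch.einsum('BD,BLD->BL', image_embeddings, text_embeddings)
scores_matmul = torch.bmm(text_embeddings, image_embeddings.unsqueeze(-1)).squeeze(-1)
assert torch.allclose(scores_einsum, scores_matmul, atol=1e-6)

pred_labels = scores_einsum.argmax(dim=-1)  # predicted candidate index per query
```
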
/spec/models/flava_wrapper.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | from tqdm import tqdm
4 | from transformers import FlavaProcessor, FlavaForPreTraining, BertTokenizer, FlavaFeatureExtractor
5 | from .base_wrapper import BaseWrapper
6 |
7 |
8 | class FlavaWrapper(BaseWrapper):
9 | def __init__(self, cache_dir, device):
10 | # load model, tokenizer, processor
11 | self.model = FlavaForPreTraining.from_pretrained("facebook/flava-full", cache_dir=cache_dir).eval()
12 | self.model = self.model.to(device)
13 | self.feature_extractor = FlavaFeatureExtractor.from_pretrained("facebook/flava-full", cache_dir=cache_dir)
14 | self.tokenizer = BertTokenizer.from_pretrained("facebook/flava-full", cache_dir=cache_dir)
15 | self.processor = FlavaProcessor.from_pretrained("facebook/flava-full", cache_dir=cache_dir)
16 | self.device = device
17 |
18 | @torch.no_grad()
19 | def i2t_evaluate(self, subset_name, dataloader):
20 | tqdm_i2t_loader = tqdm(dataloader)
21 | tqdm_i2t_loader.set_description(f"Image to Text retrieval on <{subset_name}>")
22 | i2t_scores = []
23 | i2t_correct_num = 0
24 | total_num = 0
25 | for batch in tqdm_i2t_loader:
26 | bs = len(batch['label'])
27 | # get query images
28 | query_images = batch['query_image'] # [B x PIL.Image] (B:batch size)
29 |
30 | # compute normalized image embeddings
31 | inputs = self.feature_extractor(images=query_images, return_tensors="pt").to(self.device)
32 | image_embeddings = self.model.flava.get_image_features(**inputs).cpu().numpy()[:, 0, :]
33 | image_embeddings = image_embeddings / np.linalg.norm(image_embeddings, axis=1, keepdims=True)
34 | image_embeddings = torch.tensor(image_embeddings) # (B, D)
35 |
36 | # get candidate texts
37 | candidate_texts = batch['candidate_texts'] # B, L
38 |
39 | # compute normalized text embeddings
40 | text_embeddings = []
41 | for texts in candidate_texts:
42 | text_input = self.tokenizer(text=texts, return_tensors="pt", padding="max_length", max_length=77).to(self.device)
43 | text_feats = self.model.flava.get_text_features(**text_input).cpu().numpy()[:, 0, :]
44 | text_feats = text_feats / np.linalg.norm(text_feats, axis=1, keepdims=True)
45 | text_feats = torch.tensor(text_feats)
46 | text_embeddings.append(text_feats)
47 | text_embeddings = torch.stack(text_embeddings, dim=0) # (B, L, D)
48 |
49 | # calculate matching result
50 | batch_i2t_scores = torch.einsum('BD,BLD->BL', [image_embeddings, text_embeddings]).cpu()
51 | i2t_scores.append(batch_i2t_scores)
52 | gt_labels = batch['label']
53 | pred_labels = batch_i2t_scores.argmax(dim=-1)
54 | correct_num = (gt_labels == pred_labels).sum()
55 | i2t_correct_num += correct_num.item()
56 | total_num += bs
57 |
58 | i2t_scores = torch.cat(i2t_scores, dim=0)
59 | i2t_acc = 100 * i2t_correct_num / total_num
60 |
61 | return i2t_scores, i2t_acc
62 |
63 | @torch.no_grad()
64 | def t2i_evaluate(self, subset_name, dataloader):
65 | tqdm_t2i_loader = tqdm(dataloader)
66 | tqdm_t2i_loader.set_description(f"Text to Image retrieval on <{subset_name}>")
67 | t2i_scores = []
68 | t2i_correct_num = 0
69 | total_num = 0
70 | for batch in tqdm_t2i_loader:
71 | bs = len(batch['label'])
72 | # get query texts
73 | query_texts = batch['query_text'] # [B x STR]
74 |
75 | # compute normalized text embeddings
76 | text_input = self.tokenizer(text=query_texts, return_tensors="pt", padding="max_length", max_length=77).to(self.device)
77 | text_embeddings = self.model.flava.get_text_features(**text_input).cpu().numpy()[:, 0, :]
78 | text_embeddings = text_embeddings / np.linalg.norm(text_embeddings, axis=1, keepdims=True)
79 | text_embeddings = torch.tensor(text_embeddings) # B,D (D:feature dim)
80 |
81 | # get candidate images
82 | candidate_images = batch['candidate_images'] # B lists of K candidate PIL images (K:num of candidate images per case)
83 |
84 | # compute normalized image embeddings
85 | image_embeddings = []
86 | for images in candidate_images:
87 | image_inputs = self.feature_extractor(images=images, return_tensors="pt").to(self.device)
88 | image_embed = self.model.flava.get_image_features(**image_inputs).cpu().numpy()[:, 0, :]
89 | image_embed = image_embed / np.linalg.norm(image_embed, axis=1, keepdims=True)
90 | image_embed = torch.tensor(image_embed) # (K, D)
91 | image_embeddings.append(image_embed)
92 | image_embeddings = torch.stack(image_embeddings, dim=0) # (B, K, D)
93 |
94 | # calculate matching result
95 | batch_t2i_scores = torch.einsum('BD,BKD->BK', [text_embeddings, image_embeddings]).cpu()
96 | t2i_scores.append(batch_t2i_scores)
97 | gt_labels = batch['label']
98 | pred_labels = batch_t2i_scores.argmax(dim=-1)
99 | correct_num = (gt_labels == pred_labels).sum()
100 | t2i_correct_num += correct_num.item()
101 | total_num += bs
102 |
103 | t2i_scores = torch.cat(t2i_scores, dim=0)
104 | t2i_acc = 100 * t2i_correct_num / total_num
105 |
106 | return t2i_scores, t2i_acc
107 |
--------------------------------------------------------------------------------
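
Note that the FLAVA wrapper normalises the [CLS] embedding by round-tripping through NumPy; a torch-only equivalent (shown below as a hedged alternative, not how the repo does it) produces the same result without the extra conversion:

```python
# Hedged alternative (not the repo's code): L2-normalise the [CLS] row of
# FLAVA's (B, N, D) token features directly in torch.
import torch
import torch.nn.functional as F

def normalize_cls(token_features: torch.Tensor) -> torch.Tensor:
    """token_features: (B, N, D); returns L2-normalised CLS embeddings of shape (B, D)."""
    return F.normalize(token_features[:, 0, :], dim=-1).cpu()

# Equivalent to: feats.cpu().numpy()[:, 0, :] / np.linalg.norm(..., axis=1, keepdims=True)
```
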
/spec/run_eval.sh:
--------------------------------------------------------------------------------
1 | model_dir='/path/to/cache/models'
2 | data_dir='/path/to/data'
3 | out_dir='/path/to/save/results'
4 |
5 | for model in clip blip flava coca
6 | do
7 | python eval.py \
8 | --model-name $model \
9 | --model-cache-dir $model_dir \
10 | --subset-names absolute_size relative_size absolute_spatial relative_spatial existence count \
11 | --data-root $data_dir \
12 | --out-path $out_dir \
13 | --batch-size 64 \
14 | --num-workers 8 \
15 | --seed 1
16 | done
--------------------------------------------------------------------------------