├── .gitignore ├── README.md ├── assets ├── poster-v2.pdf └── spec_overview.png ├── docs └── evaluate_custom_model.md ├── notebooks ├── evaluate_example_colab.ipynb ├── evaluate_example_local.ipynb ├── explore_spec_colab.ipynb └── explore_spec_local.ipynb ├── setup.py └── spec ├── __init__.py ├── dataset.py ├── eval.py ├── models ├── __init__.py ├── base_wrapper.py ├── blip_utils │ ├── .ipynb_checkpoints │ │ └── README-checkpoint.md │ ├── README.md │ ├── __pycache__ │ │ ├── blip.cpython-38.pyc │ │ ├── blip_retrieval.cpython-38.pyc │ │ ├── med.cpython-38.pyc │ │ └── vit.cpython-38.pyc │ ├── blip.py │ ├── blip_itm.py │ ├── blip_pretrain.py │ ├── blip_retrieval.py │ ├── med.py │ ├── utils.py │ └── vit.py ├── blip_wrapper.py ├── clip_wrapper.py └── flava_wrapper.py └── run_eval.sh /.gitignore: -------------------------------------------------------------------------------- 1 | .idea/ -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |
2 |

SPEC: Synthesize, Diagnose, and Optimize: Towards Fine-Grained Vision-Language Understanding

3 | 4 | arXiv 5 | 6 | 7 | HF Dataset: SPEC 8 | 9 | 10 | HF Dataset: Inst-It-Dataset 11 | 12 | 13 | Poster 14 | 15 | 16 |
17 | Wujian Peng, 18 | Sicheng Xie, 19 | Zuyao You, 20 | Shiyi Lan, 21 | Zuxuan Wu, 22 |
23 | 24 |
25 | Corresponding author 26 |
27 | 28 |
29 | 30 | ## :fire: News 31 | * `Apr. 14, 2024` We have released a [preview](https://wjpoom.github.io/preview/) of a more advanced version of the dataset; the full version will come soon. 32 | * `Apr. 13, 2024` We released the SPEC dataset and the evaluation code; sorry for the delay :relaxed:. 33 | * `Feb. 28, 2024` Our work has been accepted by [CVPR 2024](https://cvpr.thecvf.com/) :tada:. 34 | 35 | ## :rocket: A more advanced version is coming! 36 | We are building a new version with a larger data scale, more object categories, higher-quality images and text, and more. 37 | You can preview it at [this website](https://wjpoom.github.io/preview/), and the full version will come soon. 38 | 39 | ## :mag: SPEC Benchmark 40 | To evaluate the understanding capability of vision-language models on fine-grained concepts, we propose a new benchmark, SPEC, 41 | which consists of six distinct subsets distributed across the dimensions of **S**ize, **P**osition, **E**xistence, and **C**ount. 42 | Each test case consists of a set of candidate images that differ only in a certain visual concept, and a set of candidate texts 43 | that differ only in the corresponding language concept. 44 |

45 | 46 | 47 |
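Each subset is evaluated as retrieval in both directions over these candidate sets, and the evaluation code reports Image2Text (I2T) and Text2Image (T2I) accuracy. As a purely illustrative sketch of this protocol (the actual logic lives in each model wrapper's `i2t_evaluate` / `t2i_evaluate` method and may differ in detail), retrieval accuracy over a candidate set can be computed as follows:

```python
import torch

def i2t_accuracy(scores: torch.Tensor, gt_text_idx: torch.Tensor) -> float:
    # Illustrative only: `scores` is a [num_images, num_text_candidates] similarity
    # matrix for a batch of test cases, and `gt_text_idx` holds the index of the
    # matching caption for each image; a prediction counts as correct if the
    # top-scoring caption is the matching one.
    pred = scores.argmax(dim=-1)
    return (pred == gt_text_idx).float().mean().item() * 100.0
```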

48 | 49 | ## :wrench: Usage 50 | ### install 51 | ``` shell 52 | git clone https://github.com/wjpoom/SPEC.git 53 | cd SPEC/ 54 | pip install -e . 55 | ``` 56 | ### prepare data 57 | * Run the following code in a Python shell, replacing `/path/to/save/data` with the directory where you want to store the data. 58 | ```python 59 | import zipfile 60 | import os 61 | from huggingface_hub import hf_hub_download 62 | 63 | data_root = '/path/to/save/data' 64 | hf_hub_download(repo_id='wjpoom/SPEC', repo_type='dataset', filename='data.zip', local_dir=data_root) 65 | 66 | with zipfile.ZipFile(os.path.join(data_root, 'data.zip'), 'r') as zip_ref: 67 | zip_ref.extractall(data_root) 68 | 69 | os.remove(os.path.join(data_root, 'data.zip')) 70 | ``` 71 | ### explore the dataset 72 | * We provide a 📓 notebook that lets you visually explore the test samples in the SPEC dataset. 73 | * Run this notebook either [locally](https://github.com/wjpoom/SPEC/blob/main/notebooks/explore_spec_local.ipynb) or online using [Colab](https://colab.research.google.com/github/wjpoom/SPEC/blob/main/notebooks/explore_spec_colab.ipynb). 74 | 75 | ### reproduce the results 76 | * In our paper, we evaluated four popular VLMs using our SPEC dataset, namely CLIP, BLIP, FLAVA, and CoCa. 77 | * To reproduce the results with these VLMs, you can run [this script](https://github.com/wjpoom/SPEC/blob/main/spec/run_eval.sh). 78 | * You can also reproduce the results with this [local notebook](https://github.com/wjpoom/SPEC/blob/main/notebooks/evaluate_example_local.ipynb) or the online [Colab notebook](https://colab.research.google.com/github/wjpoom/SPEC/blob/main/notebooks/evaluate_example_colab.ipynb); a condensed Python sketch of the same evaluation flow is also included after the Acknowledgement section below. 79 | 80 | ### evaluate custom VLMs 81 | * If you want to evaluate your custom model on SPEC, you can follow the instructions in [this document](https://github.com/wjpoom/SPEC/blob/main/docs/evaluate_custom_model.md). 82 | 83 | ## :memo: TODO 84 | - [ ] Release the newly built version of the dataset 85 | - [ ] Release the code of our data synthesis pipeline 86 | - [x] Release the testing set of the SPEC benchmark 87 | - [x] Release the evaluation code of SPEC 88 | 89 | ## :clap: Acknowledgement 90 | Part of this repository is built upon [ARO](https://github.com/mertyg/vision-language-models-are-bows); thanks for the well-organized codebase.
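For readers who want a quick programmatic entry point, the snippet below is a condensed sketch of the evaluation flow used in the evaluation notebooks linked above: load a model wrapper with `get_model`, build per-subset dataloaders with `get_data`, then call `evaluate` on each subset. The paths and the model name are placeholders to adapt to your setup.

```python
import os
import torch
from spec import get_data, get_model

# placeholders: adjust the paths and the model name for your setup
data_root = '/path/to/data'
model_cache_dir = '/path/to/cache/models'
out_path = '/path/to/save/results'

device = 'cuda' if torch.cuda.is_available() else 'cpu'
model, image_preprocess = get_model(model_name='clip', cache_dir=model_cache_dir, device=device)

subset_names = ['absolute_size', 'relative_size', 'absolute_spatial',
                'relative_spatial', 'existence', 'count']
subsets = get_data(data_root=data_root, subset_names=subset_names,
                   image_preprocess=image_preprocess, batch_size=64, num_workers=8)

# evaluate each subset and average the Image2Text / Text2Image accuracies
result, i2t_acc, t2i_acc = {}, 0., 0.
for subset_name, dataloaders in subsets.items():
    subset_result = model.evaluate(subset_name=subset_name, dataloaders=dataloaders)
    result[subset_name] = subset_result
    i2t_acc += subset_result['accuracy']['i2t_accuracy']
    t2i_acc += subset_result['accuracy']['t2i_accuracy']

print(f'average Image2Text Accuracy: {i2t_acc / len(subsets):.2f} %')
print(f'average Text2Image Accuracy: {t2i_acc / len(subsets):.2f} %')

os.makedirs(out_path, exist_ok=True)
torch.save(result, os.path.join(out_path, 'clip_result.pth'))
```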
91 | 92 | ## Contact Us 93 | Feel free to contact us if you have any questions or suggestions. 94 | 95 | Email (Wujian Peng): wjpeng24@m.fudan.edu.cn 96 | 97 | ## :black_nib: Citation 98 | If you use the code or data in this repo, or find our work helpful, please consider citing our paper: 99 | 100 | ``` bibtex 101 | @inproceedings{peng2024synthesize, 102 | title={Synthesize diagnose and optimize: Towards fine-grained vision-language understanding}, 103 | author={Peng, Wujian and Xie, Sicheng and You, Zuyao and Lan, Shiyi and Wu, Zuxuan}, 104 | booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition}, 105 | pages={13279--13288}, 106 | year={2024} 107 | } 108 | ``` 109 | -------------------------------------------------------------------------------- /assets/poster-v2.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wjpoom/SPEC/8dd6cdcc0bb4f47ea3551b1c0558ee554656ca7c/assets/poster-v2.pdf -------------------------------------------------------------------------------- /assets/spec_overview.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wjpoom/SPEC/8dd6cdcc0bb4f47ea3551b1c0558ee554656ca7c/assets/spec_overview.png -------------------------------------------------------------------------------- /docs/evaluate_custom_model.md: -------------------------------------------------------------------------------- 1 | # Evaluate Custom Vision Language Model on SPEC 2 | We have implemented the testing code for four VLMs, namely [CLIP](https://arxiv.org/abs/2103.00020), [BLIP](https://arxiv.org/abs/2201.12086), [FLAVA](https://arxiv.org/abs/2112.04482), and [CoCa](https://arxiv.org/abs/2205.01917). 3 | If you want to test other custom models, you need to complete the following steps. 4 | 5 | ## Step 1. Implement a custom `model wrapper` 6 | First, create a file named `custom_wrapper.py` under [`spec/models/`](https://github.com/wjpoom/SPEC/tree/main/spec/models). Next, define your own `CustomWrapper` class within this file. 7 | `CustomWrapper` needs to inherit from the base class [`BaseWrapper`](https://github.com/wjpoom/SPEC/blob/d1048b57b4f64a813624ce6575ececa86a9178ea/spec/models/base_wrapper.py#L6) 8 | and implement the `i2t_evaluate` and `t2i_evaluate` methods; you can also add any other methods you need. Your `CustomWrapper` should look like the following: 9 | ```python 10 | class CustomWrapper(BaseWrapper): 11 | def __init__(self): 12 | pass 13 | @torch.no_grad() 14 | def i2t_evaluate(self, subset_name, dataloader): 15 | pass 16 | @torch.no_grad() 17 | def t2i_evaluate(self, subset_name, dataloader): 18 | pass 19 | ``` 20 | **Note**: pay attention to the return format of `i2t_evaluate` and `t2i_evaluate`. Please refer to the implementations in [`CLIPWrapper`](https://github.com/wjpoom/SPEC/blob/d1048b57b4f64a813624ce6575ececa86a9178ea/spec/models/clip_wrapper.py#L9C7-L9C18), [`BLIPWrapper`](https://github.com/wjpoom/SPEC/blob/d1048b57b4f64a813624ce6575ececa86a9178ea/spec/models/blip_wrapper.py#L30), or 21 | [`FLAVAWrapper`](https://github.com/wjpoom/SPEC/blob/d1048b57b4f64a813624ce6575ececa86a9178ea/spec/models/flava_wrapper.py#L8) when implementing your code. 22 | 23 | ## Step 2. Add your model to `get_model()` 24 | We defined a method named `get_model()` in [`spec/models/__init__.py`](https://github.com/wjpoom/SPEC/blob/main/spec/models/__init__.py), 25 | which handles model loading. 
You need to add the code to load your custom model within this function, 26 | simply add the following code block at the end: 27 | ```python 28 | elif model_name == CUSTOM_MODEL_NAME: 29 | from .custom_wrapper import CUSTOMWrapper 30 | model = CUSTOMWrapper(...) 31 | image_preprocess = model.image_preprocess 32 | return model, image_preprocess 33 | ``` 34 | where `CUSTOM_MODEL_NAME` is a string that distinguishes your custom model, `custom_wrapper` and `CUSTOMWrapper` are the filename and wrapper class name you defined in the first step. 35 | 36 | **Note**: You need to return `image_preprocess`, which will be used in the dataset construction to process the input image (e.g., cropping, converting to tensor, etc.). If you don't need this operation, please return None. 37 | ## Step 3. Evaluate your custom model 38 | Run the following script to evaluate your custom model on SPEC ! 39 | ```shell 40 | model=CUSTOM_MODEL_NAME 41 | model_dir='/path/to/cache/models' 42 | data_dir='/path/to/data' 43 | out_dir='/path/to/save/results' 44 | 45 | python eval.py \ 46 | --model-name $model \ 47 | --model-cache-dir $model_dir \ 48 | --subset-names absolute_size relative_size absolute_spatial relative_spatial existence count \ 49 | --data-root $data_dir \ 50 | --out-path $out_dir \ 51 | --batch-size 64 \ 52 | --num-workers 8 \ 53 | --seed 1 54 | ``` 55 | -------------------------------------------------------------------------------- /notebooks/evaluate_example_colab.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "id": "048e9a63-4c0e-4c4e-8d90-34f2af8049f9", 6 | "metadata": {}, 7 | "source": [ 8 | "# Evaluate Popular Vision Language Models on SPEC\n", 9 | "In [our paper](https://arxiv.org/abs/2312.00081), we evaluated four popular VLMs using our SPEC dataset, namely: [CLIP](https://arxiv.org/abs/2103.00020), [BLIP](https://arxiv.org/abs/2201.12086), [FLAVA](https://arxiv.org/abs/2112.04482), and [CoCa](https://arxiv.org/abs/2205.01917). \\\n", 10 | "This notebook will guide readers to reproduce these results step by step, let's go!" 11 | ] 12 | }, 13 | { 14 | "cell_type": "markdown", 15 | "id": "c3c22764-d4af-4ab9-8460-d6e4dfc79301", 16 | "metadata": {}, 17 | "source": [ 18 | "## 1. Prepare the environment" 19 | ] 20 | }, 21 | { 22 | "cell_type": "code", 23 | "execution_count": null, 24 | "outputs": [], 25 | "source": [ 26 | "!git clone https://github.com/wjpoom/SPEC.git\n", 27 | "%cd SPEC\n", 28 | "!pip install -e . --quiet" 29 | ], 30 | "metadata": { 31 | "collapsed": false 32 | } 33 | }, 34 | { 35 | "cell_type": "markdown", 36 | "id": "fb66cde2-cca5-45c1-ac3d-706b3ff4662d", 37 | "metadata": {}, 38 | "source": [ 39 | "## 2. Import Packages" 40 | ] 41 | }, 42 | { 43 | "cell_type": "code", 44 | "execution_count": 1, 45 | "id": "c5a66c29-7726-4126-bdeb-0a20a1eeeac4", 46 | "metadata": {}, 47 | "outputs": [], 48 | "source": [ 49 | "import zipfile\n", 50 | "import os\n", 51 | "import torch\n", 52 | "import warnings\n", 53 | "warnings.filterwarnings('ignore')\n", 54 | "from spec import get_data, get_model\n", 55 | "from huggingface_hub import hf_hub_download" 56 | ] 57 | }, 58 | { 59 | "cell_type": "markdown", 60 | "id": "f8d7636e-5a5e-462e-b99a-45feccf1e3ba", 61 | "metadata": {}, 62 | "source": [ 63 | "## 3. Prepare the testing dataset\n", 64 | "We store the data on HuggingFace. 
Before starting, you need to download and decompress the data as following:" 65 | ] 66 | }, 67 | { 68 | "cell_type": "code", 69 | "execution_count": 2, 70 | "id": "856550c0-da71-4c09-9533-2df5bd75ba19", 71 | "metadata": {}, 72 | "outputs": [], 73 | "source": [ 74 | "# specify the path to save the downloaded and extracted the data\n", 75 | "data_root = 'data'\n", 76 | "# download *.zip files\n", 77 | "hf_hub_download(repo_id='wjpoom/SPEC', repo_type='dataset', filename='data.zip', local_dir=data_root)\n", 78 | "# extract *.zip files\n", 79 | "with zipfile.ZipFile(os.path.join(data_root, 'data.zip'), 'r') as zip_ref:\n", 80 | " zip_ref.extractall(os.path.join(data_root))\n", 81 | "# remove the *.zip files\n", 82 | "os.remove(os.path.join(data_root, 'data.zip'))\n", 83 | "print(f'The SPEC dataset is prepared at: {data_root}.')" 84 | ] 85 | }, 86 | { 87 | "cell_type": "markdown", 88 | "id": "860f7834-e471-4f9a-9d4b-b4a152ffc133", 89 | "metadata": {}, 90 | "source": [ 91 | "## 4. Let's Evaluate VLMs on SPEC dataset!" 92 | ] 93 | }, 94 | { 95 | "cell_type": "markdown", 96 | "id": "b06a98b7-2b4d-47d3-a13b-794dca37a403", 97 | "metadata": {}, 98 | "source": [ 99 | "### 4.1 Evaluate CLIP\n", 100 | "We use the `ViT/B-32` variant of [CLIP](https://arxiv.org/abs/2103.00020) with weights resumed from the checkpoint release by OpenAI." 101 | ] 102 | }, 103 | { 104 | "cell_type": "code", 105 | "execution_count": 3, 106 | "id": "8e687eb6-08a7-47d8-9927-30b174ea7a32", 107 | "metadata": { 108 | "scrolled": true 109 | }, 110 | "outputs": [ 111 | { 112 | "name": "stderr", 113 | "output_type": "stream", 114 | "text": [ 115 | "Image to Text retrieval on : 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████| 16/16 [00:07<00:00, 2.27it/s]\n" 116 | ] 117 | }, 118 | { 119 | "name": "stdout", 120 | "output_type": "stream", 121 | "text": [ 122 | "existence subset: Image2Text Accuracy: 57.00 %\n" 123 | ] 124 | }, 125 | { 126 | "name": "stderr", 127 | "output_type": "stream", 128 | "text": [ 129 | "Text to Image retrieval on : 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████| 16/16 [00:10<00:00, 1.46it/s]\n" 130 | ] 131 | }, 132 | { 133 | "name": "stdout", 134 | "output_type": "stream", 135 | "text": [ 136 | "existence subset: Text2Image Accuracy: 52.00 %\n" 137 | ] 138 | }, 139 | { 140 | "name": "stderr", 141 | "output_type": "stream", 142 | "text": [ 143 | "Image to Text retrieval on : 100%|███████████████████████████████████████████████████████████████████████████████████████████████| 32/32 [00:12<00:00, 2.66it/s]\n" 144 | ] 145 | }, 146 | { 147 | "name": "stdout", 148 | "output_type": "stream", 149 | "text": [ 150 | "relative_spatial subset: Image2Text Accuracy: 27.10 %\n" 151 | ] 152 | }, 153 | { 154 | "name": "stderr", 155 | "output_type": "stream", 156 | "text": [ 157 | "Text to Image retrieval on : 100%|███████████████████████████████████████████████████████████████████████████████████████████████| 32/32 [00:36<00:00, 1.15s/it]\n" 158 | ] 159 | }, 160 | { 161 | "name": "stdout", 162 | "output_type": "stream", 163 | "text": [ 164 | "relative_spatial subset: Text2Image Accuracy: 26.75 %\n" 165 | ] 166 | }, 167 | { 168 | "name": "stderr", 169 | "output_type": "stream", 170 | "text": [ 171 | "Image to Text retrieval on : 100%|██████████████████████████████████████████████████████████████████████████████████████████████████| 24/24 [00:08<00:00, 2.69it/s]\n" 172 | ] 173 | }, 174 | { 175 | "name": 
"stdout", 176 | "output_type": "stream", 177 | "text": [ 178 | "absolute_size subset: Image2Text Accuracy: 44.27 %\n" 179 | ] 180 | }, 181 | { 182 | "name": "stderr", 183 | "output_type": "stream", 184 | "text": [ 185 | "Text to Image retrieval on : 100%|██████████████████████████████████████████████████████████████████████████████████████████████████| 24/24 [00:22<00:00, 1.08it/s]\n" 186 | ] 187 | }, 188 | { 189 | "name": "stdout", 190 | "output_type": "stream", 191 | "text": [ 192 | "absolute_size subset: Text2Image Accuracy: 36.27 %\n" 193 | ] 194 | }, 195 | { 196 | "name": "stderr", 197 | "output_type": "stream", 198 | "text": [ 199 | "Image to Text retrieval on : 100%|██████████████████████████████████████████████████████████████████████████████████████████████████| 24/24 [00:08<00:00, 2.67it/s]\n" 200 | ] 201 | }, 202 | { 203 | "name": "stdout", 204 | "output_type": "stream", 205 | "text": [ 206 | "relative_size subset: Image2Text Accuracy: 34.07 %\n" 207 | ] 208 | }, 209 | { 210 | "name": "stderr", 211 | "output_type": "stream", 212 | "text": [ 213 | "Text to Image retrieval on : 100%|██████████████████████████████████████████████████████████████████████████████████████████████████| 24/24 [00:22<00:00, 1.08it/s]\n" 214 | ] 215 | }, 216 | { 217 | "name": "stdout", 218 | "output_type": "stream", 219 | "text": [ 220 | "relative_size subset: Text2Image Accuracy: 32.47 %\n" 221 | ] 222 | }, 223 | { 224 | "name": "stderr", 225 | "output_type": "stream", 226 | "text": [ 227 | "Image to Text retrieval on : 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████| 71/71 [00:43<00:00, 1.62it/s]\n" 228 | ] 229 | }, 230 | { 231 | "name": "stdout", 232 | "output_type": "stream", 233 | "text": [ 234 | "count subset: Image2Text Accuracy: 25.27 %\n" 235 | ] 236 | }, 237 | { 238 | "name": "stderr", 239 | "output_type": "stream", 240 | "text": [ 241 | "Text to Image retrieval on : 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████| 71/71 [02:59<00:00, 2.52s/it]\n" 242 | ] 243 | }, 244 | { 245 | "name": "stdout", 246 | "output_type": "stream", 247 | "text": [ 248 | "count subset: Text2Image Accuracy: 23.62 %\n" 249 | ] 250 | }, 251 | { 252 | "name": "stderr", 253 | "output_type": "stream", 254 | "text": [ 255 | "Image to Text retrieval on : 100%|███████████████████████████████████████████████████████████████████████████████████████████████| 71/71 [00:45<00:00, 1.57it/s]\n" 256 | ] 257 | }, 258 | { 259 | "name": "stdout", 260 | "output_type": "stream", 261 | "text": [ 262 | "absolute_spatial subset: Image2Text Accuracy: 12.64 %\n" 263 | ] 264 | }, 265 | { 266 | "name": "stderr", 267 | "output_type": "stream", 268 | "text": [ 269 | "Text to Image retrieval on : 100%|███████████████████████████████████████████████████████████████████████████████████████████████| 71/71 [02:55<00:00, 2.48s/it]" 270 | ] 271 | }, 272 | { 273 | "name": "stdout", 274 | "output_type": "stream", 275 | "text": [ 276 | "absolute_spatial subset: Text2Image Accuracy: 12.20 %\n", 277 | "\n", 278 | "############# finished the evaluation on all selected subsets ###############\n", 279 | "average of all subset: Image2Text Accuracy: 33.39 %\n", 280 | "average of all subset: Text2Image Accuracy: 30.55 %\n", 281 | "result saved to clip_openai_evaluate_result.pth.\n" 282 | ] 283 | }, 284 | { 285 | "name": "stderr", 286 | "output_type": "stream", 287 | "text": [ 288 | "\n" 289 | ] 290 | } 291 | ], 292 | "source": [ 293 | "# 
load model\n", 294 | "model_cache_dir = 'models/clip' # specify the path to save the downloaded model checkpoint\n", 295 | "device = 'cuda' if torch.cuda.is_available() else 'cpu'\n", 296 | "model, image_preprocess = get_model(model_name='clip', cache_dir=model_cache_dir, device=device)\n", 297 | "# load datasets\n", 298 | "subset_names = ['absolute_size', 'relative_size', 'absolute_spatial', 'relative_spatial', 'existence', 'count']\n", 299 | "subsets = get_data(data_root=data_root, subset_names=subset_names, image_preprocess=image_preprocess, batch_size=64, num_workers=8)\n", 300 | "# evaluate\n", 301 | "result = {}\n", 302 | "i2t_acc = 0.\n", 303 | "t2i_acc = 0.\n", 304 | "subset_num = 0\n", 305 | "for subset_name, dataloaders in subsets.items():\n", 306 | " subset_result = model.evaluate(subset_name=subset_name, dataloaders=dataloaders)\n", 307 | " result[subset_name] = subset_result\n", 308 | " i2t_acc += subset_result['accuracy']['i2t_accuracy']\n", 309 | " t2i_acc += subset_result['accuracy']['t2i_accuracy']\n", 310 | " subset_num += 1\n", 311 | "# print and save results\n", 312 | "print(f'\\n############# finished the evaluation on all selected subsets ###############')\n", 313 | "print(f'average of all subset: Image2Text Accuracy: {i2t_acc/subset_num:.2f} %')\n", 314 | "print(f'average of all subset: Text2Image Accuracy: {t2i_acc/subset_num:.2f} %')\n", 315 | "out_path = 'results' # specify the path to save the evaluation results\n", 316 | "os.makedirs(out_path, exist_ok=True)\n", 317 | "out_fn = f\"clip_result.pth\" # specify the filename according to the model you used\n", 318 | "torch.save(result, os.path.join(out_path, out_fn))\n", 319 | "print(f'result saved to {out_fn}.')" 320 | ] 321 | }, 322 | { 323 | "cell_type": "markdown", 324 | "id": "0614e38d-646b-4d7e-a6e7-0a6b05307da9", 325 | "metadata": {}, 326 | "source": [ 327 | "### 4.2 Evaluate BLIP\n", 328 | "We use the `ViT-B` variant of [BLIP](https://arxiv.org/abs/2201.12086) with weights resumed from the checkpoint released in this [link](https://github.com/salesforce/BLIP), which is finetuned on COCO for image-text retrieval." 
329 | ] 330 | }, 331 | { 332 | "cell_type": "code", 333 | "execution_count": 4, 334 | "id": "ec215787-f3ea-4890-af47-7e83abc3a9a6", 335 | "metadata": {}, 336 | "outputs": [ 337 | { 338 | "name": "stdout", 339 | "output_type": "stream", 340 | "text": [ 341 | "load checkpoint from ~/.cache/blip/blip-coco-base.pth\n", 342 | "missing keys:\n", 343 | "[]\n" 344 | ] 345 | }, 346 | { 347 | "name": "stderr", 348 | "output_type": "stream", 349 | "text": [ 350 | "Image to Text retrieval on : 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████| 16/16 [00:36<00:00, 2.29s/it]\n" 351 | ] 352 | }, 353 | { 354 | "name": "stdout", 355 | "output_type": "stream", 356 | "text": [ 357 | "existence subset: Image2Text Accuracy: 55.50 %\n" 358 | ] 359 | }, 360 | { 361 | "name": "stderr", 362 | "output_type": "stream", 363 | "text": [ 364 | "Text to Image retrieval on : 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████| 16/16 [00:38<00:00, 2.39s/it]\n" 365 | ] 366 | }, 367 | { 368 | "name": "stdout", 369 | "output_type": "stream", 370 | "text": [ 371 | "existence subset: Text2Image Accuracy: 50.10 %\n" 372 | ] 373 | }, 374 | { 375 | "name": "stderr", 376 | "output_type": "stream", 377 | "text": [ 378 | "Image to Text retrieval on : 100%|███████████████████████████████████████████████████████████████████████████████████████████████| 32/32 [01:11<00:00, 2.23s/it]\n" 379 | ] 380 | }, 381 | { 382 | "name": "stdout", 383 | "output_type": "stream", 384 | "text": [ 385 | "relative_spatial subset: Image2Text Accuracy: 30.65 %\n" 386 | ] 387 | }, 388 | { 389 | "name": "stderr", 390 | "output_type": "stream", 391 | "text": [ 392 | "Text to Image retrieval on : 100%|███████████████████████████████████████████████████████████████████████████████████████████████| 32/32 [02:17<00:00, 4.31s/it]\n" 393 | ] 394 | }, 395 | { 396 | "name": "stdout", 397 | "output_type": "stream", 398 | "text": [ 399 | "relative_spatial subset: Text2Image Accuracy: 29.60 %\n" 400 | ] 401 | }, 402 | { 403 | "name": "stderr", 404 | "output_type": "stream", 405 | "text": [ 406 | "Image to Text retrieval on : 100%|██████████████████████████████████████████████████████████████████████████████████████████████████| 24/24 [00:55<00:00, 2.30s/it]\n" 407 | ] 408 | }, 409 | { 410 | "name": "stdout", 411 | "output_type": "stream", 412 | "text": [ 413 | "absolute_size subset: Image2Text Accuracy: 43.20 %\n" 414 | ] 415 | }, 416 | { 417 | "name": "stderr", 418 | "output_type": "stream", 419 | "text": [ 420 | "Text to Image retrieval on : 100%|██████████████████████████████████████████████████████████████████████████████████████████████████| 24/24 [01:20<00:00, 3.36s/it]\n" 421 | ] 422 | }, 423 | { 424 | "name": "stdout", 425 | "output_type": "stream", 426 | "text": [ 427 | "absolute_size subset: Text2Image Accuracy: 43.07 %\n" 428 | ] 429 | }, 430 | { 431 | "name": "stderr", 432 | "output_type": "stream", 433 | "text": [ 434 | "Image to Text retrieval on : 100%|██████████████████████████████████████████████████████████████████████████████████████████████████| 24/24 [00:54<00:00, 2.26s/it]\n" 435 | ] 436 | }, 437 | { 438 | "name": "stdout", 439 | "output_type": "stream", 440 | "text": [ 441 | "relative_size subset: Image2Text Accuracy: 34.33 %\n" 442 | ] 443 | }, 444 | { 445 | "name": "stderr", 446 | "output_type": "stream", 447 | "text": [ 448 | "Text to Image retrieval on : 
100%|██████████████████████████████████████████████████████████████████████████████████████████████████| 24/24 [01:20<00:00, 3.37s/it]\n" 449 | ] 450 | }, 451 | { 452 | "name": "stdout", 453 | "output_type": "stream", 454 | "text": [ 455 | "relative_size subset: Text2Image Accuracy: 33.27 %\n" 456 | ] 457 | }, 458 | { 459 | "name": "stderr", 460 | "output_type": "stream", 461 | "text": [ 462 | "Image to Text retrieval on : 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████| 71/71 [02:44<00:00, 2.32s/it]\n" 463 | ] 464 | }, 465 | { 466 | "name": "stdout", 467 | "output_type": "stream", 468 | "text": [ 469 | "count subset: Image2Text Accuracy: 36.87 %\n" 470 | ] 471 | }, 472 | { 473 | "name": "stderr", 474 | "output_type": "stream", 475 | "text": [ 476 | "Text to Image retrieval on : 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████| 71/71 [10:56<00:00, 9.25s/it]\n" 477 | ] 478 | }, 479 | { 480 | "name": "stdout", 481 | "output_type": "stream", 482 | "text": [ 483 | "count subset: Text2Image Accuracy: 37.40 %\n" 484 | ] 485 | }, 486 | { 487 | "name": "stderr", 488 | "output_type": "stream", 489 | "text": [ 490 | "Image to Text retrieval on : 100%|███████████████████████████████████████████████████████████████████████████████████████████████| 71/71 [02:53<00:00, 2.44s/it]\n" 491 | ] 492 | }, 493 | { 494 | "name": "stdout", 495 | "output_type": "stream", 496 | "text": [ 497 | "absolute_spatial subset: Image2Text Accuracy: 12.07 %\n" 498 | ] 499 | }, 500 | { 501 | "name": "stderr", 502 | "output_type": "stream", 503 | "text": [ 504 | "Text to Image retrieval on : 100%|███████████████████████████████████████████████████████████████████████████████████████████████| 71/71 [10:56<00:00, 9.25s/it]\n" 505 | ] 506 | }, 507 | { 508 | "name": "stdout", 509 | "output_type": "stream", 510 | "text": [ 511 | "absolute_spatial subset: Text2Image Accuracy: 11.58 %\n", 512 | "\n", 513 | "############# finished the evaluation on all selected subsets ###############\n", 514 | "average of all subset: Image2Text Accuracy: 35.44 %\n", 515 | "average of all subset: Text2Image Accuracy: 34.17 %\n", 516 | "result saved to blip_evaluate_result.pth.\n" 517 | ] 518 | } 519 | ], 520 | "source": [ 521 | "# load model\n", 522 | "model_cache_dir = 'models/blip' # specify the path to save the downloaded model checkpoint\n", 523 | "device = 'cuda' if torch.cuda.is_available() else 'cpu'\n", 524 | "model, image_preprocess = get_model(model_name='blip', cache_dir=model_cache_dir, device=device)\n", 525 | "# load datasets\n", 526 | "subset_names = ['absolute_size', 'relative_size', 'absolute_spatial', 'relative_spatial', 'existence', 'count']\n", 527 | "subsets = get_data(data_root=data_root, subset_names=subset_names, image_preprocess=image_preprocess, batch_size=64, num_workers=8)\n", 528 | "# evaluate\n", 529 | "result = {}\n", 530 | "i2t_acc = 0.\n", 531 | "t2i_acc = 0.\n", 532 | "subset_num = 0\n", 533 | "for subset_name, dataloaders in subsets.items():\n", 534 | " subset_result = model.evaluate(subset_name=subset_name, dataloaders=dataloaders)\n", 535 | " result[subset_name] = subset_result\n", 536 | " i2t_acc += subset_result['accuracy']['i2t_accuracy']\n", 537 | " t2i_acc += subset_result['accuracy']['t2i_accuracy']\n", 538 | " subset_num += 1\n", 539 | "# print and save results\n", 540 | "print(f'\\n############# finished the evaluation on all selected subsets ###############')\n", 541 | 
"print(f'average of all subset: Image2Text Accuracy: {i2t_acc/subset_num:.2f} %')\n", 542 | "print(f'average of all subset: Text2Image Accuracy: {t2i_acc/subset_num:.2f} %')\n", 543 | "out_path = 'results' # specify the path to save the evaluation results\n", 544 | "os.makedirs(out_path, exist_ok=True)\n", 545 | "out_fn = f\"blip_result.pth\" # specify the filename according to the model you used\n", 546 | "torch.save(result, os.path.join(out_path, out_fn))\n", 547 | "print(f'result saved to {out_fn}.')" 548 | ] 549 | }, 550 | { 551 | "cell_type": "markdown", 552 | "id": "33b537d7-3343-4aa5-a632-ef07ed8cfa6d", 553 | "metadata": {}, 554 | "source": [ 555 | "### 4.3 Evaluate FLAVA\n", 556 | "We use the `full` version of [FLAVA](https://arxiv.org/abs/2112.04482) with weights resumed from this [checkpoint](https://huggingface.co/facebook/flava-full)." 557 | ] 558 | }, 559 | { 560 | "cell_type": "code", 561 | "execution_count": null, 562 | "id": "cc335f88-09f7-445b-b548-23bac460e052", 563 | "metadata": {}, 564 | "outputs": [ 565 | { 566 | "name": "stderr", 567 | "output_type": "stream", 568 | "text": [ 569 | "Image to Text retrieval on : 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████| 16/16 [01:14<00:00, 4.65s/it]\n" 570 | ] 571 | }, 572 | { 573 | "name": "stdout", 574 | "output_type": "stream", 575 | "text": [ 576 | "existence subset: Image2Text Accuracy: 57.90 %\n" 577 | ] 578 | }, 579 | { 580 | "name": "stderr", 581 | "output_type": "stream", 582 | "text": [ 583 | "Text to Image retrieval on : 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████| 16/16 [02:07<00:00, 7.99s/it]\n" 584 | ] 585 | }, 586 | { 587 | "name": "stdout", 588 | "output_type": "stream", 589 | "text": [ 590 | "existence subset: Text2Image Accuracy: 51.80 %\n" 591 | ] 592 | }, 593 | { 594 | "name": "stderr", 595 | "output_type": "stream", 596 | "text": [ 597 | "Image to Text retrieval on : 100%|███████████████████████████████████████████████████████████████████████████████████████████████| 32/32 [02:33<00:00, 4.78s/it]\n" 598 | ] 599 | }, 600 | { 601 | "name": "stdout", 602 | "output_type": "stream", 603 | "text": [ 604 | "relative_spatial subset: Image2Text Accuracy: 25.80 %\n" 605 | ] 606 | }, 607 | { 608 | "name": "stderr", 609 | "output_type": "stream", 610 | "text": [ 611 | "Text to Image retrieval on : 100%|███████████████████████████████████████████████████████████████████████████████████████████████| 32/32 [07:53<00:00, 14.80s/it]\n" 612 | ] 613 | }, 614 | { 615 | "name": "stdout", 616 | "output_type": "stream", 617 | "text": [ 618 | "relative_spatial subset: Text2Image Accuracy: 25.85 %\n" 619 | ] 620 | }, 621 | { 622 | "name": "stderr", 623 | "output_type": "stream", 624 | "text": [ 625 | "Image to Text retrieval on : 100%|██████████████████████████████████████████████████████████████████████████████████████████████████| 24/24 [01:46<00:00, 4.42s/it]\n" 626 | ] 627 | }, 628 | { 629 | "name": "stdout", 630 | "output_type": "stream", 631 | "text": [ 632 | "absolute_size subset: Image2Text Accuracy: 37.07 %\n" 633 | ] 634 | }, 635 | { 636 | "name": "stderr", 637 | "output_type": "stream", 638 | "text": [ 639 | "Text to Image retrieval on : 100%|██████████████████████████████████████████████████████████████████████████████████████████████████| 24/24 [04:31<00:00, 11.29s/it]\n" 640 | ] 641 | }, 642 | { 643 | "name": "stdout", 644 | "output_type": "stream", 645 | "text": [ 646 | "absolute_size subset: 
Text2Image Accuracy: 36.67 %\n" 647 | ] 648 | }, 649 | { 650 | "name": "stderr", 651 | "output_type": "stream", 652 | "text": [ 653 | "Image to Text retrieval on : 100%|██████████████████████████████████████████████████████████████████████████████████████████████████| 24/24 [01:46<00:00, 4.45s/it]\n" 654 | ] 655 | }, 656 | { 657 | "name": "stdout", 658 | "output_type": "stream", 659 | "text": [ 660 | "relative_size subset: Image2Text Accuracy: 33.53 %\n" 661 | ] 662 | }, 663 | { 664 | "name": "stderr", 665 | "output_type": "stream", 666 | "text": [ 667 | "Text to Image retrieval on : 100%|██████████████████████████████████████████████████████████████████████████████████████████████████| 24/24 [04:44<00:00, 11.86s/it]\n" 668 | ] 669 | }, 670 | { 671 | "name": "stdout", 672 | "output_type": "stream", 673 | "text": [ 674 | "relative_size subset: Text2Image Accuracy: 33.07 %\n" 675 | ] 676 | }, 677 | { 678 | "name": "stderr", 679 | "output_type": "stream", 680 | "text": [ 681 | "Image to Text retrieval on : 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████| 71/71 [05:46<00:00, 4.88s/it]\n" 682 | ] 683 | }, 684 | { 685 | "name": "stdout", 686 | "output_type": "stream", 687 | "text": [ 688 | "count subset: Image2Text Accuracy: 14.00 %\n" 689 | ] 690 | }, 691 | { 692 | "name": "stderr", 693 | "output_type": "stream", 694 | "text": [ 695 | "Text to Image retrieval on : 1%|█▍ | 1/71 [01:43<2:00:24, 103.21s/it]" 696 | ] 697 | } 698 | ], 699 | "source": [ 700 | "# load model\n", 701 | "model_cache_dir = 'models/flava' # specify the path to save the downloaded model checkpoint\n", 702 | "device = 'cuda' if torch.cuda.is_available() else 'cpu'\n", 703 | "model, image_preprocess = get_model(model_name='flava', cache_dir=model_cache_dir, device=device)\n", 704 | "# load datasets\n", 705 | "subset_names = ['absolute_size', 'relative_size', 'absolute_spatial', 'relative_spatial', 'existence', 'count']\n", 706 | "subsets = get_data(data_root=data_root, subset_names=subset_names, image_preprocess=image_preprocess, batch_size=64, num_workers=8)\n", 707 | "# evaluate\n", 708 | "result = {}\n", 709 | "i2t_acc = 0.\n", 710 | "t2i_acc = 0.\n", 711 | "subset_num = 0\n", 712 | "for subset_name, dataloaders in subsets.items():\n", 713 | " subset_result = model.evaluate(subset_name=subset_name, dataloaders=dataloaders)\n", 714 | " result[subset_name] = subset_result\n", 715 | " i2t_acc += subset_result['accuracy']['i2t_accuracy']\n", 716 | " t2i_acc += subset_result['accuracy']['t2i_accuracy']\n", 717 | " subset_num += 1\n", 718 | "# print and save results\n", 719 | "print(f'\\n############# finished the evaluation on all selected subsets ###############')\n", 720 | "print(f'average of all subset: Image2Text Accuracy: {i2t_acc/subset_num:.2f} %')\n", 721 | "print(f'average of all subset: Text2Image Accuracy: {t2i_acc/subset_num:.2f} %')\n", 722 | "out_path = 'results' # specify the path to save the evluation results\n", 723 | "os.makedirs(out_path, exist_ok=True)\n", 724 | "out_fn = f\"flava_result.pth\" # specify the filename according to the model you used\n", 725 | "torch.save(result, os.path.join(out_path, out_fn))\n", 726 | "print(f'result saved to {out_fn}.')" 727 | ] 728 | }, 729 | { 730 | "cell_type": "markdown", 731 | "id": "e3528ed5-6ab2-456f-b712-0f7250d9d557", 732 | "metadata": {}, 733 | "source": [ 734 | "### 4.4 Evaluate CoCa\n", 735 | "We used the `ViT/B-32` variant of [CoCa](https://arxiv.org/abs/2205.01917) model with weights resumed from 
the [checkpoint](https://github.com/mlfoundations/open_clip) that pretrained on LAION-2B dataset." 736 | ] 737 | }, 738 | { 739 | "cell_type": "code", 740 | "execution_count": null, 741 | "id": "2fde7225-1399-45a5-b59d-a71921f140be", 742 | "metadata": {}, 743 | "outputs": [], 744 | "source": [ 745 | "# load model\n", 746 | "model_cache_dir = 'models/coca' # specify the path to save the downloaded model checkpoint\n", 747 | "device = 'cuda' if torch.cuda.is_available() else 'cpu'\n", 748 | "model, image_preprocess = get_model(model_name='coca', cache_dir=model_cache_dir, device=device)\n", 749 | "# load datasets\n", 750 | "subset_names = ['absolute_size', 'relative_size', 'absolute_spatial', 'relative_spatial', 'existence', 'count']\n", 751 | "subsets = get_data(data_root=data_root, subset_names=subset_names, image_preprocess=image_preprocess, batch_size=64, num_workers=8)\n", 752 | "# evaluate\n", 753 | "result = {}\n", 754 | "i2t_acc = 0.\n", 755 | "t2i_acc = 0.\n", 756 | "subset_num = 0\n", 757 | "for subset_name, dataloaders in subsets.items():\n", 758 | " subset_result = model.evaluate(subset_name=subset_name, dataloaders=dataloaders)\n", 759 | " result[subset_name] = subset_result\n", 760 | " i2t_acc += subset_result['accuracy']['i2t_accuracy']\n", 761 | " t2i_acc += subset_result['accuracy']['t2i_accuracy']\n", 762 | " subset_num += 1\n", 763 | "# print and save results\n", 764 | "print(f'\\n############# finished the evaluation on all selected subsets ###############')\n", 765 | "print(f'average of all subset: Image2Text Accuracy: {i2t_acc/subset_num:.2f} %')\n", 766 | "print(f'average of all subset: Text2Image Accuracy: {t2i_acc/subset_num:.2f} %')\n", 767 | "out_path = 'results' # specify the path to save the evluation results\n", 768 | "os.makedirs(out_path, exist_ok=True)\n", 769 | "out_fn = f\"coca_result.pth\" # specify the filename according to the model you used\n", 770 | "torch.save(result, os.path.join(out_path, out_fn))\n", 771 | "print(f'result saved to {out_fn}.')" 772 | ] 773 | }, 774 | { 775 | "cell_type": "markdown", 776 | "id": "b46eaa47-3b2a-4e14-b366-8a567f7e1e54", 777 | "metadata": {}, 778 | "source": [ 779 | "## What's Next?\n", 780 | "Want to test your own visual language model on SPEC? We have provided a [tutorial](https://github.com/wjpoom/SPEC/blob/main/docs/evaluate_custom_model.md) to help evaluate custom models, feel free to have a try." 
781 | ] 782 | } 783 | ], 784 | "metadata": { 785 | "kernelspec": { 786 | "display_name": "spec", 787 | "language": "python", 788 | "name": "spec" 789 | }, 790 | "language_info": { 791 | "codemirror_mode": { 792 | "name": "ipython", 793 | "version": 3 794 | }, 795 | "file_extension": ".py", 796 | "mimetype": "text/x-python", 797 | "name": "python", 798 | "nbconvert_exporter": "python", 799 | "pygments_lexer": "ipython3", 800 | "version": "3.8.18" 801 | } 802 | }, 803 | "nbformat": 4, 804 | "nbformat_minor": 5 805 | } 806 | -------------------------------------------------------------------------------- /notebooks/evaluate_example_local.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "id": "048e9a63-4c0e-4c4e-8d90-34f2af8049f9", 6 | "metadata": {}, 7 | "source": [ 8 | "# Evaluate Popular Vision Language Models on SPEC\n", 9 | "In [our paper](https://arxiv.org/abs/2312.00081), we evaluated four popular VLMs using our SPEC dataset, namely: [CLIP](https://arxiv.org/abs/2103.00020), [BLIP](https://arxiv.org/abs/2201.12086), [FLAVA](https://arxiv.org/abs/2112.04482), and [CoCa](https://arxiv.org/abs/2205.01917). \\\n", 10 | "This notebook will guide readers to reproduce these results step by step, let's go!" 11 | ] 12 | }, 13 | { 14 | "cell_type": "markdown", 15 | "id": "c3c22764-d4af-4ab9-8460-d6e4dfc79301", 16 | "metadata": {}, 17 | "source": [ 18 | "## 1. How to use this notebook?\n", 19 | "You can run this notebook locally, before running, make sure that you have prepared the environment. \\\n", 20 | "You can also directly run this online notebook: [![online notebook](https://img.shields.io/badge/colab-notebook-yellow)](https://colab.research.google.com/github/wjpoom/SPEC/blob/main/notebooks/evaluate_example_colab.ipynb)." 21 | ] 22 | }, 23 | { 24 | "cell_type": "markdown", 25 | "id": "fb66cde2-cca5-45c1-ac3d-706b3ff4662d", 26 | "metadata": {}, 27 | "source": [ 28 | "## 2. Import Packages" 29 | ] 30 | }, 31 | { 32 | "cell_type": "code", 33 | "execution_count": 1, 34 | "id": "c5a66c29-7726-4126-bdeb-0a20a1eeeac4", 35 | "metadata": {}, 36 | "outputs": [], 37 | "source": [ 38 | "import zipfile\n", 39 | "import os\n", 40 | "import torch\n", 41 | "import warnings\n", 42 | "warnings.filterwarnings('ignore')\n", 43 | "from spec import get_data, get_model\n", 44 | "from huggingface_hub import hf_hub_download" 45 | ] 46 | }, 47 | { 48 | "cell_type": "markdown", 49 | "id": "f8d7636e-5a5e-462e-b99a-45feccf1e3ba", 50 | "metadata": {}, 51 | "source": [ 52 | "## 3. Prepare the testing dataset\n", 53 | "We store the data on HuggingFace. 
Before starting, you need to download and decompress the data as following:" 54 | ] 55 | }, 56 | { 57 | "cell_type": "code", 58 | "execution_count": 2, 59 | "id": "856550c0-da71-4c09-9533-2df5bd75ba19", 60 | "metadata": {}, 61 | "outputs": [], 62 | "source": [ 63 | "# specify the path to save the downloaded and extracted the data\n", 64 | "data_root = '/path/to/save/data'\n", 65 | "# download *.zip files\n", 66 | "hf_hub_download(repo_id='wjpoom/SPEC', repo_type='dataset', filename='data.zip', local_dir=data_root)\n", 67 | "# extract *.zip files\n", 68 | "with zipfile.ZipFile(os.path.join(data_root, 'data.zip'), 'r') as zip_ref:\n", 69 | " zip_ref.extractall(os.path.join(data_root))\n", 70 | "# remove the *.zip files\n", 71 | "os.remove(os.path.join(data_root, 'data.zip'))\n", 72 | "print(f'The SPEC dataset is prepared at: {data_root}.')" 73 | ] 74 | }, 75 | { 76 | "cell_type": "markdown", 77 | "id": "860f7834-e471-4f9a-9d4b-b4a152ffc133", 78 | "metadata": {}, 79 | "source": [ 80 | "## 4. Let's Evaluate VLMs on SPEC dataset!" 81 | ] 82 | }, 83 | { 84 | "cell_type": "markdown", 85 | "id": "b06a98b7-2b4d-47d3-a13b-794dca37a403", 86 | "metadata": {}, 87 | "source": [ 88 | "### 4.1 Evaluate CLIP\n", 89 | "We use the `ViT/B-32` variant of [CLIP](https://arxiv.org/abs/2103.00020) with weights resumed from the checkpoint release by OpenAI." 90 | ] 91 | }, 92 | { 93 | "cell_type": "code", 94 | "execution_count": 3, 95 | "id": "8e687eb6-08a7-47d8-9927-30b174ea7a32", 96 | "metadata": { 97 | "scrolled": true 98 | }, 99 | "outputs": [ 100 | { 101 | "name": "stderr", 102 | "output_type": "stream", 103 | "text": [ 104 | "Image to Text retrieval on : 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████| 16/16 [00:07<00:00, 2.27it/s]\n" 105 | ] 106 | }, 107 | { 108 | "name": "stdout", 109 | "output_type": "stream", 110 | "text": [ 111 | "existence subset: Image2Text Accuracy: 57.00 %\n" 112 | ] 113 | }, 114 | { 115 | "name": "stderr", 116 | "output_type": "stream", 117 | "text": [ 118 | "Text to Image retrieval on : 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████| 16/16 [00:10<00:00, 1.46it/s]\n" 119 | ] 120 | }, 121 | { 122 | "name": "stdout", 123 | "output_type": "stream", 124 | "text": [ 125 | "existence subset: Text2Image Accuracy: 52.00 %\n" 126 | ] 127 | }, 128 | { 129 | "name": "stderr", 130 | "output_type": "stream", 131 | "text": [ 132 | "Image to Text retrieval on : 100%|███████████████████████████████████████████████████████████████████████████████████████████████| 32/32 [00:12<00:00, 2.66it/s]\n" 133 | ] 134 | }, 135 | { 136 | "name": "stdout", 137 | "output_type": "stream", 138 | "text": [ 139 | "relative_spatial subset: Image2Text Accuracy: 27.10 %\n" 140 | ] 141 | }, 142 | { 143 | "name": "stderr", 144 | "output_type": "stream", 145 | "text": [ 146 | "Text to Image retrieval on : 100%|███████████████████████████████████████████████████████████████████████████████████████████████| 32/32 [00:36<00:00, 1.15s/it]\n" 147 | ] 148 | }, 149 | { 150 | "name": "stdout", 151 | "output_type": "stream", 152 | "text": [ 153 | "relative_spatial subset: Text2Image Accuracy: 26.75 %\n" 154 | ] 155 | }, 156 | { 157 | "name": "stderr", 158 | "output_type": "stream", 159 | "text": [ 160 | "Image to Text retrieval on : 100%|██████████████████████████████████████████████████████████████████████████████████████████████████| 24/24 [00:08<00:00, 2.69it/s]\n" 161 | ] 162 | }, 163 | { 164 | "name": 
"stdout", 165 | "output_type": "stream", 166 | "text": [ 167 | "absolute_size subset: Image2Text Accuracy: 44.27 %\n" 168 | ] 169 | }, 170 | { 171 | "name": "stderr", 172 | "output_type": "stream", 173 | "text": [ 174 | "Text to Image retrieval on : 100%|██████████████████████████████████████████████████████████████████████████████████████████████████| 24/24 [00:22<00:00, 1.08it/s]\n" 175 | ] 176 | }, 177 | { 178 | "name": "stdout", 179 | "output_type": "stream", 180 | "text": [ 181 | "absolute_size subset: Text2Image Accuracy: 36.27 %\n" 182 | ] 183 | }, 184 | { 185 | "name": "stderr", 186 | "output_type": "stream", 187 | "text": [ 188 | "Image to Text retrieval on : 100%|██████████████████████████████████████████████████████████████████████████████████████████████████| 24/24 [00:08<00:00, 2.67it/s]\n" 189 | ] 190 | }, 191 | { 192 | "name": "stdout", 193 | "output_type": "stream", 194 | "text": [ 195 | "relative_size subset: Image2Text Accuracy: 34.07 %\n" 196 | ] 197 | }, 198 | { 199 | "name": "stderr", 200 | "output_type": "stream", 201 | "text": [ 202 | "Text to Image retrieval on : 100%|██████████████████████████████████████████████████████████████████████████████████████████████████| 24/24 [00:22<00:00, 1.08it/s]\n" 203 | ] 204 | }, 205 | { 206 | "name": "stdout", 207 | "output_type": "stream", 208 | "text": [ 209 | "relative_size subset: Text2Image Accuracy: 32.47 %\n" 210 | ] 211 | }, 212 | { 213 | "name": "stderr", 214 | "output_type": "stream", 215 | "text": [ 216 | "Image to Text retrieval on : 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████| 71/71 [00:43<00:00, 1.62it/s]\n" 217 | ] 218 | }, 219 | { 220 | "name": "stdout", 221 | "output_type": "stream", 222 | "text": [ 223 | "count subset: Image2Text Accuracy: 25.27 %\n" 224 | ] 225 | }, 226 | { 227 | "name": "stderr", 228 | "output_type": "stream", 229 | "text": [ 230 | "Text to Image retrieval on : 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████| 71/71 [02:59<00:00, 2.52s/it]\n" 231 | ] 232 | }, 233 | { 234 | "name": "stdout", 235 | "output_type": "stream", 236 | "text": [ 237 | "count subset: Text2Image Accuracy: 23.62 %\n" 238 | ] 239 | }, 240 | { 241 | "name": "stderr", 242 | "output_type": "stream", 243 | "text": [ 244 | "Image to Text retrieval on : 100%|███████████████████████████████████████████████████████████████████████████████████████████████| 71/71 [00:45<00:00, 1.57it/s]\n" 245 | ] 246 | }, 247 | { 248 | "name": "stdout", 249 | "output_type": "stream", 250 | "text": [ 251 | "absolute_spatial subset: Image2Text Accuracy: 12.64 %\n" 252 | ] 253 | }, 254 | { 255 | "name": "stderr", 256 | "output_type": "stream", 257 | "text": [ 258 | "Text to Image retrieval on : 100%|███████████████████████████████████████████████████████████████████████████████████████████████| 71/71 [02:55<00:00, 2.48s/it]" 259 | ] 260 | }, 261 | { 262 | "name": "stdout", 263 | "output_type": "stream", 264 | "text": [ 265 | "absolute_spatial subset: Text2Image Accuracy: 12.20 %\n", 266 | "\n", 267 | "############# finished the evaluation on all selected subsets ###############\n", 268 | "average of all subset: Image2Text Accuracy: 33.39 %\n", 269 | "average of all subset: Text2Image Accuracy: 30.55 %\n", 270 | "result saved to clip_openai_evaluate_result.pth.\n" 271 | ] 272 | }, 273 | { 274 | "name": "stderr", 275 | "output_type": "stream", 276 | "text": [ 277 | "\n" 278 | ] 279 | } 280 | ], 281 | "source": [ 282 | "# 
load model\n", 283 | "model_cache_dir = '/path/to/cache/models' # specify the path to save the downloaded model checkpoint\n", 284 | "device = 'cuda' if torch.cuda.is_available() else 'cpu'\n", 285 | "model, image_preprocess = get_model(model_name='clip', cache_dir=model_cache_dir, device=device)\n", 286 | "# load datasets\n", 287 | "subset_names = ['absolute_size', 'relative_size', 'absolute_spatial', 'relative_spatial', 'existence', 'count']\n", 288 | "subsets = get_data(data_root=data_root, subset_names=subset_names, image_preprocess=image_preprocess, batch_size=64, num_workers=8)\n", 289 | "# evaluate\n", 290 | "result = {}\n", 291 | "i2t_acc = 0.\n", 292 | "t2i_acc = 0.\n", 293 | "subset_num = 0\n", 294 | "for subset_name, dataloaders in subsets.items():\n", 295 | " subset_result = model.evaluate(subset_name=subset_name, dataloaders=dataloaders)\n", 296 | " result[subset_name] = subset_result\n", 297 | " i2t_acc += subset_result['accuracy']['i2t_accuracy']\n", 298 | " t2i_acc += subset_result['accuracy']['t2i_accuracy']\n", 299 | " subset_num += 1\n", 300 | "# print and save results\n", 301 | "print(f'\\n############# finished the evaluation on all selected subsets ###############')\n", 302 | "print(f'average of all subset: Image2Text Accuracy: {i2t_acc/subset_num:.2f} %')\n", 303 | "print(f'average of all subset: Text2Image Accuracy: {t2i_acc/subset_num:.2f} %')\n", 304 | "out_path = '/path/to/save/results' # specify the path to save the evaluation results\n", 305 | "os.makedirs(out_path, exist_ok=True)\n", 306 | "out_fn = f\"clip_result.pth\" # specify the filename according to the model you used\n", 307 | "torch.save(result, os.path.join(out_path, out_fn))\n", 308 | "print(f'result saved to {out_fn}.')" 309 | ] 310 | }, 311 | { 312 | "cell_type": "markdown", 313 | "id": "0614e38d-646b-4d7e-a6e7-0a6b05307da9", 314 | "metadata": {}, 315 | "source": [ 316 | "### 4.2 Evaluate BLIP\n", 317 | "We use the `ViT-B` variant of [BLIP](https://arxiv.org/abs/2201.12086) with weights resumed from the checkpoint released in this [link](https://github.com/salesforce/BLIP), which is finetuned on COCO for image-text retrieval." 
318 | ] 319 | }, 320 | { 321 | "cell_type": "code", 322 | "execution_count": 4, 323 | "id": "ec215787-f3ea-4890-af47-7e83abc3a9a6", 324 | "metadata": {}, 325 | "outputs": [ 326 | { 327 | "name": "stdout", 328 | "output_type": "stream", 329 | "text": [ 330 | "load checkpoint from ~/.cache/blip/blip-coco-base.pth\n", 331 | "missing keys:\n", 332 | "[]\n" 333 | ] 334 | }, 335 | { 336 | "name": "stderr", 337 | "output_type": "stream", 338 | "text": [ 339 | "Image to Text retrieval on : 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████| 16/16 [00:36<00:00, 2.29s/it]\n" 340 | ] 341 | }, 342 | { 343 | "name": "stdout", 344 | "output_type": "stream", 345 | "text": [ 346 | "existence subset: Image2Text Accuracy: 55.50 %\n" 347 | ] 348 | }, 349 | { 350 | "name": "stderr", 351 | "output_type": "stream", 352 | "text": [ 353 | "Text to Image retrieval on : 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████| 16/16 [00:38<00:00, 2.39s/it]\n" 354 | ] 355 | }, 356 | { 357 | "name": "stdout", 358 | "output_type": "stream", 359 | "text": [ 360 | "existence subset: Text2Image Accuracy: 50.10 %\n" 361 | ] 362 | }, 363 | { 364 | "name": "stderr", 365 | "output_type": "stream", 366 | "text": [ 367 | "Image to Text retrieval on : 100%|███████████████████████████████████████████████████████████████████████████████████████████████| 32/32 [01:11<00:00, 2.23s/it]\n" 368 | ] 369 | }, 370 | { 371 | "name": "stdout", 372 | "output_type": "stream", 373 | "text": [ 374 | "relative_spatial subset: Image2Text Accuracy: 30.65 %\n" 375 | ] 376 | }, 377 | { 378 | "name": "stderr", 379 | "output_type": "stream", 380 | "text": [ 381 | "Text to Image retrieval on : 100%|███████████████████████████████████████████████████████████████████████████████████████████████| 32/32 [02:17<00:00, 4.31s/it]\n" 382 | ] 383 | }, 384 | { 385 | "name": "stdout", 386 | "output_type": "stream", 387 | "text": [ 388 | "relative_spatial subset: Text2Image Accuracy: 29.60 %\n" 389 | ] 390 | }, 391 | { 392 | "name": "stderr", 393 | "output_type": "stream", 394 | "text": [ 395 | "Image to Text retrieval on : 100%|██████████████████████████████████████████████████████████████████████████████████████████████████| 24/24 [00:55<00:00, 2.30s/it]\n" 396 | ] 397 | }, 398 | { 399 | "name": "stdout", 400 | "output_type": "stream", 401 | "text": [ 402 | "absolute_size subset: Image2Text Accuracy: 43.20 %\n" 403 | ] 404 | }, 405 | { 406 | "name": "stderr", 407 | "output_type": "stream", 408 | "text": [ 409 | "Text to Image retrieval on : 100%|██████████████████████████████████████████████████████████████████████████████████████████████████| 24/24 [01:20<00:00, 3.36s/it]\n" 410 | ] 411 | }, 412 | { 413 | "name": "stdout", 414 | "output_type": "stream", 415 | "text": [ 416 | "absolute_size subset: Text2Image Accuracy: 43.07 %\n" 417 | ] 418 | }, 419 | { 420 | "name": "stderr", 421 | "output_type": "stream", 422 | "text": [ 423 | "Image to Text retrieval on : 100%|██████████████████████████████████████████████████████████████████████████████████████████████████| 24/24 [00:54<00:00, 2.26s/it]\n" 424 | ] 425 | }, 426 | { 427 | "name": "stdout", 428 | "output_type": "stream", 429 | "text": [ 430 | "relative_size subset: Image2Text Accuracy: 34.33 %\n" 431 | ] 432 | }, 433 | { 434 | "name": "stderr", 435 | "output_type": "stream", 436 | "text": [ 437 | "Text to Image retrieval on : 
100%|██████████████████████████████████████████████████████████████████████████████████████████████████| 24/24 [01:20<00:00, 3.37s/it]\n" 438 | ] 439 | }, 440 | { 441 | "name": "stdout", 442 | "output_type": "stream", 443 | "text": [ 444 | "relative_size subset: Text2Image Accuracy: 33.27 %\n" 445 | ] 446 | }, 447 | { 448 | "name": "stderr", 449 | "output_type": "stream", 450 | "text": [ 451 | "Image to Text retrieval on : 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████| 71/71 [02:44<00:00, 2.32s/it]\n" 452 | ] 453 | }, 454 | { 455 | "name": "stdout", 456 | "output_type": "stream", 457 | "text": [ 458 | "count subset: Image2Text Accuracy: 36.87 %\n" 459 | ] 460 | }, 461 | { 462 | "name": "stderr", 463 | "output_type": "stream", 464 | "text": [ 465 | "Text to Image retrieval on : 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████| 71/71 [10:56<00:00, 9.25s/it]\n" 466 | ] 467 | }, 468 | { 469 | "name": "stdout", 470 | "output_type": "stream", 471 | "text": [ 472 | "count subset: Text2Image Accuracy: 37.40 %\n" 473 | ] 474 | }, 475 | { 476 | "name": "stderr", 477 | "output_type": "stream", 478 | "text": [ 479 | "Image to Text retrieval on : 100%|███████████████████████████████████████████████████████████████████████████████████████████████| 71/71 [02:53<00:00, 2.44s/it]\n" 480 | ] 481 | }, 482 | { 483 | "name": "stdout", 484 | "output_type": "stream", 485 | "text": [ 486 | "absolute_spatial subset: Image2Text Accuracy: 12.07 %\n" 487 | ] 488 | }, 489 | { 490 | "name": "stderr", 491 | "output_type": "stream", 492 | "text": [ 493 | "Text to Image retrieval on : 100%|███████████████████████████████████████████████████████████████████████████████████████████████| 71/71 [10:56<00:00, 9.25s/it]\n" 494 | ] 495 | }, 496 | { 497 | "name": "stdout", 498 | "output_type": "stream", 499 | "text": [ 500 | "absolute_spatial subset: Text2Image Accuracy: 11.58 %\n", 501 | "\n", 502 | "############# finished the evaluation on all selected subsets ###############\n", 503 | "average of all subset: Image2Text Accuracy: 35.44 %\n", 504 | "average of all subset: Text2Image Accuracy: 34.17 %\n", 505 | "result saved to blip_evaluate_result.pth.\n" 506 | ] 507 | } 508 | ], 509 | "source": [ 510 | "# load model\n", 511 | "model_cache_dir = '/path/to/cache/models' # specify the path to save the downloaded model checkpoint\n", 512 | "device = 'cuda' if torch.cuda.is_available() else 'cpu'\n", 513 | "model, image_preprocess = get_model(model_name='blip', cache_dir=model_cache_dir, device=device)\n", 514 | "# load datasets\n", 515 | "subset_names = ['absolute_size', 'relative_size', 'absolute_spatial', 'relative_spatial', 'existence', 'count']\n", 516 | "subsets = get_data(data_root=data_root, subset_names=subset_names, image_preprocess=image_preprocess, batch_size=64, num_workers=8)\n", 517 | "# evaluate\n", 518 | "result = {}\n", 519 | "i2t_acc = 0.\n", 520 | "t2i_acc = 0.\n", 521 | "subset_num = 0\n", 522 | "for subset_name, dataloaders in subsets.items():\n", 523 | " subset_result = model.evaluate(subset_name=subset_name, dataloaders=dataloaders)\n", 524 | " result[subset_name] = subset_result\n", 525 | " i2t_acc += subset_result['accuracy']['i2t_accuracy']\n", 526 | " t2i_acc += subset_result['accuracy']['t2i_accuracy']\n", 527 | " subset_num += 1\n", 528 | "# print and save results\n", 529 | "print(f'\\n############# finished the evaluation on all selected subsets ###############')\n", 530 | 
"print(f'average of all subset: Image2Text Accuracy: {i2t_acc/subset_num:.2f} %')\n", 531 | "print(f'average of all subset: Text2Image Accuracy: {t2i_acc/subset_num:.2f} %')\n", 532 | "out_path = '/path/to/save/results' # specify the path to save the evaluation results\n", 533 | "os.makedirs(out_path, exist_ok=True)\n", 534 | "out_fn = f\"blip_result.pth\" # specify the filename according to the model you used\n", 535 | "torch.save(result, os.path.join(out_path, out_fn))\n", 536 | "print(f'result saved to {out_fn}.')" 537 | ] 538 | }, 539 | { 540 | "cell_type": "markdown", 541 | "id": "33b537d7-3343-4aa5-a632-ef07ed8cfa6d", 542 | "metadata": {}, 543 | "source": [ 544 | "### 4.3 Evaluate FLAVA\n", 545 | "We use the `full` version of [FLAVA](https://arxiv.org/abs/2112.04482) with weights resumed from this [checkpoint](https://huggingface.co/facebook/flava-full)." 546 | ] 547 | }, 548 | { 549 | "cell_type": "code", 550 | "execution_count": null, 551 | "id": "cc335f88-09f7-445b-b548-23bac460e052", 552 | "metadata": {}, 553 | "outputs": [ 554 | { 555 | "name": "stderr", 556 | "output_type": "stream", 557 | "text": [ 558 | "Image to Text retrieval on : 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████| 16/16 [01:14<00:00, 4.65s/it]\n" 559 | ] 560 | }, 561 | { 562 | "name": "stdout", 563 | "output_type": "stream", 564 | "text": [ 565 | "existence subset: Image2Text Accuracy: 57.90 %\n" 566 | ] 567 | }, 568 | { 569 | "name": "stderr", 570 | "output_type": "stream", 571 | "text": [ 572 | "Text to Image retrieval on : 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████| 16/16 [02:07<00:00, 7.99s/it]\n" 573 | ] 574 | }, 575 | { 576 | "name": "stdout", 577 | "output_type": "stream", 578 | "text": [ 579 | "existence subset: Text2Image Accuracy: 51.80 %\n" 580 | ] 581 | }, 582 | { 583 | "name": "stderr", 584 | "output_type": "stream", 585 | "text": [ 586 | "Image to Text retrieval on : 100%|███████████████████████████████████████████████████████████████████████████████████████████████| 32/32 [02:33<00:00, 4.78s/it]\n" 587 | ] 588 | }, 589 | { 590 | "name": "stdout", 591 | "output_type": "stream", 592 | "text": [ 593 | "relative_spatial subset: Image2Text Accuracy: 25.80 %\n" 594 | ] 595 | }, 596 | { 597 | "name": "stderr", 598 | "output_type": "stream", 599 | "text": [ 600 | "Text to Image retrieval on : 100%|███████████████████████████████████████████████████████████████████████████████████████████████| 32/32 [07:53<00:00, 14.80s/it]\n" 601 | ] 602 | }, 603 | { 604 | "name": "stdout", 605 | "output_type": "stream", 606 | "text": [ 607 | "relative_spatial subset: Text2Image Accuracy: 25.85 %\n" 608 | ] 609 | }, 610 | { 611 | "name": "stderr", 612 | "output_type": "stream", 613 | "text": [ 614 | "Image to Text retrieval on : 100%|██████████████████████████████████████████████████████████████████████████████████████████████████| 24/24 [01:46<00:00, 4.42s/it]\n" 615 | ] 616 | }, 617 | { 618 | "name": "stdout", 619 | "output_type": "stream", 620 | "text": [ 621 | "absolute_size subset: Image2Text Accuracy: 37.07 %\n" 622 | ] 623 | }, 624 | { 625 | "name": "stderr", 626 | "output_type": "stream", 627 | "text": [ 628 | "Text to Image retrieval on : 100%|██████████████████████████████████████████████████████████████████████████████████████████████████| 24/24 [04:31<00:00, 11.29s/it]\n" 629 | ] 630 | }, 631 | { 632 | "name": "stdout", 633 | "output_type": "stream", 634 | "text": [ 635 | 
"absolute_size subset: Text2Image Accuracy: 36.67 %\n" 636 | ] 637 | }, 638 | { 639 | "name": "stderr", 640 | "output_type": "stream", 641 | "text": [ 642 | "Image to Text retrieval on : 100%|██████████████████████████████████████████████████████████████████████████████████████████████████| 24/24 [01:46<00:00, 4.45s/it]\n" 643 | ] 644 | }, 645 | { 646 | "name": "stdout", 647 | "output_type": "stream", 648 | "text": [ 649 | "relative_size subset: Image2Text Accuracy: 33.53 %\n" 650 | ] 651 | }, 652 | { 653 | "name": "stderr", 654 | "output_type": "stream", 655 | "text": [ 656 | "Text to Image retrieval on : 100%|██████████████████████████████████████████████████████████████████████████████████████████████████| 24/24 [04:44<00:00, 11.86s/it]\n" 657 | ] 658 | }, 659 | { 660 | "name": "stdout", 661 | "output_type": "stream", 662 | "text": [ 663 | "relative_size subset: Text2Image Accuracy: 33.07 %\n" 664 | ] 665 | }, 666 | { 667 | "name": "stderr", 668 | "output_type": "stream", 669 | "text": [ 670 | "Image to Text retrieval on : 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████| 71/71 [05:46<00:00, 4.88s/it]\n" 671 | ] 672 | }, 673 | { 674 | "name": "stdout", 675 | "output_type": "stream", 676 | "text": [ 677 | "count subset: Image2Text Accuracy: 14.00 %\n" 678 | ] 679 | }, 680 | { 681 | "name": "stderr", 682 | "output_type": "stream", 683 | "text": [ 684 | "Text to Image retrieval on : 1%|█▍ | 1/71 [01:43<2:00:24, 103.21s/it]" 685 | ] 686 | } 687 | ], 688 | "source": [ 689 | "# load model\n", 690 | "model_cache_dir = '/path/to/cache/models' # specify the path to save the downloaded model checkpoint\n", 691 | "device = 'cuda' if torch.cuda.is_available() else 'cpu'\n", 692 | "model, image_preprocess = get_model(model_name='flava', cache_dir=model_cache_dir, device=device)\n", 693 | "# load datasets\n", 694 | "subset_names = ['absolute_size', 'relative_size', 'absolute_spatial', 'relative_spatial', 'existence', 'count']\n", 695 | "subsets = get_data(data_root=data_root, subset_names=subset_names, image_preprocess=image_preprocess, batch_size=64, num_workers=8)\n", 696 | "# evaluate\n", 697 | "result = {}\n", 698 | "i2t_acc = 0.\n", 699 | "t2i_acc = 0.\n", 700 | "subset_num = 0\n", 701 | "for subset_name, dataloaders in subsets.items():\n", 702 | " subset_result = model.evaluate(subset_name=subset_name, dataloaders=dataloaders)\n", 703 | " result[subset_name] = subset_result\n", 704 | " i2t_acc += subset_result['accuracy']['i2t_accuracy']\n", 705 | " t2i_acc += subset_result['accuracy']['t2i_accuracy']\n", 706 | " subset_num += 1\n", 707 | "# print and save results\n", 708 | "print(f'\\n############# finished the evaluation on all selected subsets ###############')\n", 709 | "print(f'average of all subset: Image2Text Accuracy: {i2t_acc/subset_num:.2f} %')\n", 710 | "print(f'average of all subset: Text2Image Accuracy: {t2i_acc/subset_num:.2f} %')\n", 711 | "out_path = '/path/to/save/results' # specify the path to save the evaluation results\n", 712 | "os.makedirs(out_path, exist_ok=True)\n", 713 | "out_fn = f\"flava_result.pth\" # specify the filename according to the model you used\n", 714 | "torch.save(result, os.path.join(out_path, out_fn))\n", 715 | "print(f'result saved to {out_fn}.')" 716 | ] 717 | }, 718 | { 719 | "cell_type": "markdown", 720 | "id": "e3528ed5-6ab2-456f-b712-0f7250d9d557", 721 | "metadata": {}, 722 | "source": [ 723 | "### 4.4 Evaluate CoCa\n", 724 | "We used the `ViT/B-32` variant of 
[CoCa](https://arxiv.org/abs/2205.01917) model with weights resumed from the [checkpoint](https://github.com/mlfoundations/open_clip) that pretrained on LAION-2B dataset." 725 | ] 726 | }, 727 | { 728 | "cell_type": "code", 729 | "execution_count": null, 730 | "id": "2fde7225-1399-45a5-b59d-a71921f140be", 731 | "metadata": {}, 732 | "outputs": [], 733 | "source": [ 734 | "# load model\n", 735 | "model_cache_dir = '/path/to/cache/models' # specify the path to save the downloaded model checkpoint\n", 736 | "device = 'cuda' if torch.cuda.is_available() else 'cpu'\n", 737 | "model, image_preprocess = get_model(model_name='coca', cache_dir=model_cache_dir, device=device)\n", 738 | "# load datasets\n", 739 | "subset_names = ['absolute_size', 'relative_size', 'absolute_spatial', 'relative_spatial', 'existence', 'count']\n", 740 | "subsets = get_data(data_root=data_root, subset_names=subset_names, image_preprocess=image_preprocess, batch_size=64, num_workers=8)\n", 741 | "# evaluate\n", 742 | "result = {}\n", 743 | "i2t_acc = 0.\n", 744 | "t2i_acc = 0.\n", 745 | "subset_num = 0\n", 746 | "for subset_name, dataloaders in subsets.items():\n", 747 | " subset_result = model.evaluate(subset_name=subset_name, dataloaders=dataloaders)\n", 748 | " result[subset_name] = subset_result\n", 749 | " i2t_acc += subset_result['accuracy']['i2t_accuracy']\n", 750 | " t2i_acc += subset_result['accuracy']['t2i_accuracy']\n", 751 | " subset_num += 1\n", 752 | "# print and save results\n", 753 | "print(f'\\n############# finished the evaluation on all selected subsets ###############')\n", 754 | "print(f'average of all subset: Image2Text Accuracy: {i2t_acc/subset_num:.2f} %')\n", 755 | "print(f'average of all subset: Text2Image Accuracy: {t2i_acc/subset_num:.2f} %')\n", 756 | "out_path = '/path/to/save/results' # specify the path to save the evaluation results\n", 757 | "os.makedirs(out_path, exist_ok=True)\n", 758 | "out_fn = f\"coca_result.pth\" # specify the filename according to the model you used\n", 759 | "torch.save(result, os.path.join(out_path, out_fn))\n", 760 | "print(f'result saved to {out_fn}.')" 761 | ] 762 | }, 763 | { 764 | "cell_type": "markdown", 765 | "id": "b46eaa47-3b2a-4e14-b366-8a567f7e1e54", 766 | "metadata": {}, 767 | "source": [ 768 | "## What's Next?\n", 769 | "Want to test your own visual language model on SPEC? We have provided a [tutorial](https://github.com/wjpoom/SPEC/blob/main/docs/evaluate_custom_model.md) to help evaluate custom models, feel free to have a try." 
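For orientation, here is a rough sketch of the shape such a wrapper can take, following the `BaseWrapper` interface in `spec/models/base_wrapper.py` (each retrieval method returns a score matrix and an accuracy in percent). `MyVLMWrapper`, `encode_image`, and `encode_text` are hypothetical placeholders for your own model's API; the tutorial linked above remains the authoritative reference.

```python
import torch
from tqdm import tqdm
from spec.models.base_wrapper import BaseWrapper

class MyVLMWrapper(BaseWrapper):
    """Illustrative sketch only: `encode_image`/`encode_text` stand in for your model's API."""

    def __init__(self, my_model, device='cuda'):
        self.model = my_model.to(device).eval()
        self.device = device

    @torch.no_grad()
    def i2t_evaluate(self, subset_name, dataloader):
        scores, labels = [], []
        for batch in tqdm(dataloader, desc=f'Image to Text retrieval on {subset_name}'):
            # (B, D) embeddings of the query images (hypothetical model method)
            img_feats = self.model.encode_image(batch['query_image'].to(self.device))
            for img_feat, texts in zip(img_feats, batch['candidate_texts']):
                # (L, D) embeddings of the L candidate captions for this image
                txt_feats = self.model.encode_text(texts).to(self.device)
                scores.append(img_feat @ txt_feats.t())       # (L,) matching scores
            labels.append(batch['label'])
        scores = torch.stack(scores).cpu()                    # (N, L)
        labels = torch.cat(labels).cpu()
        i2t_acc = (scores.argmax(dim=-1) == labels).float().mean().item() * 100
        return scores, i2t_acc

    @torch.no_grad()
    def t2i_evaluate(self, subset_name, dataloader):
        scores, labels = [], []
        for batch in tqdm(dataloader, desc=f'Text to Image retrieval on {subset_name}'):
            for text, images in zip(batch['query_text'], batch['candidate_images']):
                txt_feat = self.model.encode_text([text]).to(self.device)[0]  # (D,)
                img_feats = self.model.encode_image(images.to(self.device))   # (K, D)
                scores.append(img_feats @ txt_feat)                           # (K,) matching scores
            labels.append(batch['label'])
        scores = torch.stack(scores).cpu()                    # (N, K)
        labels = torch.cat(labels).cpu()
        t2i_acc = (scores.argmax(dim=-1) == labels).float().mean().item() * 100
        return scores, t2i_acc
```

Because `BaseWrapper.evaluate` only relies on these two methods, everything else in this notebook (`get_data`, the per-subset loop, and the averaging) can be reused unchanged once the wrapper is plugged in.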
770 | ] 771 | } 772 | ], 773 | "metadata": { 774 | "kernelspec": { 775 | "display_name": "spec", 776 | "language": "python", 777 | "name": "spec" 778 | }, 779 | "language_info": { 780 | "codemirror_mode": { 781 | "name": "ipython", 782 | "version": 3 783 | }, 784 | "file_extension": ".py", 785 | "mimetype": "text/x-python", 786 | "name": "python", 787 | "nbconvert_exporter": "python", 788 | "pygments_lexer": "ipython3", 789 | "version": "3.8.18" 790 | } 791 | }, 792 | "nbformat": 4, 793 | "nbformat_minor": 5 794 | } 795 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | # setup.py 2 | 3 | from setuptools import setup, find_packages 4 | 5 | setup( 6 | name='spec', 7 | version='1.0.0', 8 | packages=find_packages(), 9 | install_requires=[ 10 | 'einops>=0.7.0', 11 | 'open_clip_torch==2.24.0', 12 | 'PyYAML==6.0.1', 13 | 'setuptools==69.2.0', 14 | 'timm==0.9.16', 15 | 'torch==2.2.1', 16 | 'torchvision==0.17.1', 17 | 'tqdm==4.66.2', 18 | 'transformers==4.38.2', 19 | 'huggingface-hub==0.21.4', 20 | 'jedi>=0.16', 21 | 'fairscale==0.4.13', 22 | ] 23 | ) -------------------------------------------------------------------------------- /spec/__init__.py: -------------------------------------------------------------------------------- 1 | from .dataset import get_data 2 | from .models import get_model -------------------------------------------------------------------------------- /spec/dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import torch 4 | 5 | from torch.utils.data import Dataset, DataLoader 6 | from PIL import Image 7 | 8 | 9 | class Image2TextDataset(Dataset): 10 | def __init__(self, subset_root, image_preprocess=None): 11 | """ 12 | Args: 13 | subset_root: the path to the root dir of a subset, (e.g. `absolute_size`) 14 | """ 15 | self.subset_root = subset_root 16 | self.image_preprocess = image_preprocess 17 | 18 | ann = os.path.join(subset_root, 'image2text.json') 19 | with open(ann, 'r') as f: 20 | self.sample_list = json.load(f) 21 | f.close() 22 | 23 | def __len__(self): 24 | return len(self.sample_list) 25 | 26 | def __getitem__(self, idx): 27 | sample_info = self.sample_list[idx] 28 | 29 | # query image 30 | image_path = os.path.join(self.subset_root, sample_info['query']) 31 | query_image = Image.open(image_path).convert('RGB') 32 | if self.image_preprocess is not None: 33 | query_image = self.image_preprocess(query_image) 34 | 35 | # candidate texts 36 | candidate_texts = sample_info['keys'] 37 | 38 | # label 39 | label = sample_info['label'] 40 | 41 | sample = { 42 | "query_image": query_image, 43 | "candidate_texts": candidate_texts, 44 | "label": label 45 | } 46 | 47 | return sample 48 | 49 | def collate_fn(self, batch): 50 | query_image = [] 51 | candidate_texts = [] 52 | label = [] 53 | for sample in batch: 54 | query_image.append(sample['query_image']) 55 | candidate_texts.append(sample['candidate_texts']) 56 | label.append(sample['label']) 57 | if self.image_preprocess is not None: 58 | query_image = torch.stack(query_image, dim=0) 59 | batch = { 60 | 'query_image': query_image, 61 | 'candidate_texts': candidate_texts, 62 | 'label': torch.tensor(label) 63 | } 64 | return batch 65 | 66 | 67 | class Text2ImageDataset(Dataset): 68 | def __init__(self, subset_root, image_preprocess=None): 69 | """ 70 | Args: 71 | subset_root: the path to the root dir of a subset, (e.g. 
`absolute_size`) 72 | """ 73 | self.subset_root = subset_root 74 | self.image_preprocess = image_preprocess 75 | 76 | ann = os.path.join(subset_root, 'text2image.json') 77 | with open(ann, 'r') as f: 78 | self.sample_list = json.load(f) 79 | f.close() 80 | 81 | def __len__(self): 82 | return len(self.sample_list) 83 | 84 | def __getitem__(self, idx): 85 | sample_info = self.sample_list[idx] 86 | 87 | # query text 88 | query_text = sample_info['query'] 89 | 90 | # candidate images 91 | candidate_images = [] 92 | for img in sample_info['keys']: 93 | img = Image.open(os.path.join(self.subset_root, img)).convert('RGB') 94 | if self.image_preprocess is not None: 95 | img = self.image_preprocess(img) 96 | candidate_images.append(img) 97 | if self.image_preprocess is not None: 98 | candidate_images = torch.stack(candidate_images, dim=0) 99 | 100 | # label 101 | label = sample_info['label'] 102 | 103 | sample = { 104 | "query_text": query_text, 105 | "candidate_images": candidate_images, 106 | "label": label 107 | } 108 | 109 | return sample 110 | 111 | def collate_fn(self, batch): 112 | query_text = [] 113 | candidate_images = [] 114 | label = [] 115 | for sample in batch: 116 | query_text.append(sample['query_text']) 117 | candidate_images.append(sample['candidate_images']) 118 | label.append(sample['label']) 119 | if self.image_preprocess is not None: 120 | candidate_images = torch.stack(candidate_images, dim=0) 121 | 122 | batch = { 123 | 'query_text': query_text, 124 | 'candidate_images': candidate_images, 125 | 'label': torch.tensor(label) 126 | } 127 | return batch 128 | 129 | 130 | def get_data(data_root, subset_names, image_preprocess, batch_size, num_workers): 131 | """ 132 | Create SPEC datasets 133 | Args: 134 | data_root: the path to the dir contains all the subsets' data 135 | subset_names: selected subsets names that you wish to evaluate your model on 136 | image_preprocess: image_preprocess 137 | batch_size: batch_size for dataloader 138 | num_workers: num_workers for dataloader 139 | Return: 140 | A list contains selected sub-datasets 141 | """ 142 | 143 | data = {} 144 | for subset_nm in subset_names: 145 | 146 | subset_root_path = os.path.join(data_root, subset_nm) 147 | i2t_dataset = Image2TextDataset(subset_root=subset_root_path, 148 | image_preprocess=image_preprocess) 149 | t2i_dataset = Text2ImageDataset(subset_root=subset_root_path, 150 | image_preprocess=image_preprocess) 151 | 152 | i2t_loader = DataLoader(dataset=i2t_dataset, 153 | batch_size=batch_size, 154 | num_workers=num_workers, 155 | collate_fn=i2t_dataset.collate_fn, 156 | shuffle=False, 157 | drop_last=False) 158 | 159 | t2i_loader = DataLoader(dataset=t2i_dataset, 160 | batch_size=batch_size, 161 | num_workers=num_workers, 162 | collate_fn=t2i_dataset.collate_fn, 163 | shuffle=False, 164 | drop_last=False) 165 | 166 | data[subset_nm] = { 167 | 'i2t_dataloader': i2t_loader, 168 | 't2i_dataloader': t2i_loader 169 | } 170 | 171 | return data 172 | -------------------------------------------------------------------------------- /spec/eval.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import torch 4 | import random 5 | import numpy as np 6 | 7 | from models import get_model 8 | from dataset import get_data 9 | 10 | 11 | def seed_all(seed): 12 | random.seed(seed) 13 | np.random.seed(seed) 14 | torch.manual_seed(seed) 15 | torch.cuda.manual_seed_all(seed) 16 | 17 | 18 | @torch.no_grad() 19 | def main(): 20 | os.makedirs(args.out_path, 
exist_ok=True) 21 | 22 | # set random seeds 23 | seed_all(args.seed) 24 | 25 | device = 'cuda' if torch.cuda.is_available() else 'cpu' 26 | # load model 27 | model, image_preprocess = get_model(model_name=args.model_name, 28 | cache_dir=args.model_cache_dir, 29 | device=device) 30 | 31 | # load data 32 | data = get_data(data_root=args.data_root, 33 | subset_names=args.subset_names, 34 | image_preprocess=image_preprocess, 35 | batch_size=args.batch_size, 36 | num_workers=args.num_workers) 37 | 38 | # evaluate on each subset 39 | print(f'\nBegin the evaluation of {args.model_name} on all selected subsets.') 40 | result = {} 41 | i2t_acc = 0. 42 | t2i_acc = 0. 43 | subset_num = 0 44 | for subset_name, dataloaders in data.items(): 45 | subset_result = model.evaluate(subset_name=subset_name, dataloaders=dataloaders) 46 | result[subset_name] = subset_result 47 | i2t_acc += subset_result['accuracy']['i2t_accuracy'] 48 | t2i_acc += subset_result['accuracy']['t2i_accuracy'] 49 | subset_num += 1 50 | print(f'\nFinished the evaluation of {args.model_name} on all selected subsets.') 51 | print(f'average all subset: Image2Text Accuracy: {i2t_acc/subset_num:.2f} %') 52 | print(f'average all subset: Text2Image Accuracy: {t2i_acc/subset_num:.2f} %') 53 | 54 | # save results 55 | out_fn = f"{args.model_name}_result.pth" 56 | out_fn = os.path.join(args.out_path, out_fn) 57 | torch.save(result, out_fn) 58 | print(f'result saved to {out_fn}.') 59 | 60 | 61 | if __name__ == '__main__': 62 | parser = argparse.ArgumentParser('Vision Language Models Evaluation Pipeline') 63 | parser.add_argument('--model-name', 64 | type=str, 65 | default='clip') 66 | parser.add_argument('--pretrained', 67 | type=str, 68 | help="the pretrained model checkpoint") 69 | parser.add_argument('--model-cache-dir', 70 | type=str, 71 | default='~/.cache', 72 | help='the path to cache the downloaded model checkpoints') 73 | parser.add_argument('--subset-names', 74 | type=str, 75 | nargs='+', 76 | choices=['count', 'relative_size', 'absolute_size', 'relative_spatial', 'absolute_spatial', 77 | 'existence'], 78 | help='type of generated dataset type for enhanced ability') 79 | parser.add_argument('--data-root', 80 | type=str, 81 | help='the path the the root dir of data') 82 | parser.add_argument('--out-path', 83 | type=str, 84 | default='out', 85 | help="path to save evaluation-real results") 86 | parser.add_argument('--batch-size', 87 | type=int, 88 | default=64) 89 | parser.add_argument('--num-workers', 90 | type=int, 91 | default=8) 92 | parser.add_argument('--seed', 93 | type=int, 94 | default=1, 95 | help="random seed for reproducibility") 96 | 97 | args = parser.parse_args() 98 | 99 | main() 100 | -------------------------------------------------------------------------------- /spec/models/__init__.py: -------------------------------------------------------------------------------- 1 | """code credict: https://github.com/mertyg/vision-language-models-are-bows/tree/main/model_zoo""" 2 | import os 3 | 4 | def get_model(model_name, cache_dir='~/.cache', device='cuda'): 5 | """ 6 | Helper function that returns a model and an image preprocessing function and text tokenizer. 
7 | Args: 8 | model_name: the model that you want to create 9 | cache_dir: the path to cache the downloader model checkpoints 10 | device 11 | Returns: 12 | pretrained_model, image_preprocess 13 | """ 14 | os.makedirs(cache_dir, exist_ok=True) 15 | 16 | if model_name == 'clip': 17 | from .clip_wrapper import CLIPWrapper 18 | clip_model = CLIPWrapper(device=device, variant='ViT-B-32', pretrained='openai') 19 | image_preprocess = clip_model.image_preprocess 20 | return clip_model, image_preprocess 21 | 22 | elif model_name == 'blip': 23 | from .blip_wrapper import BLIPModelWrapper 24 | blip_model = BLIPModelWrapper(cache_dir=cache_dir, device=device, variant="blip-coco-base") 25 | image_preprocess = blip_model.image_preprocess 26 | return blip_model, image_preprocess 27 | 28 | elif model_name == "flava": 29 | from .flava_wrapper import FlavaWrapper 30 | flava_model = FlavaWrapper(cache_dir=cache_dir, device=device) 31 | image_preprocess = None 32 | return flava_model, image_preprocess 33 | 34 | elif model_name == "coca": 35 | from .clip_wrapper import CLIPWrapper 36 | coca_model = CLIPWrapper(device=device, variant="coca_ViT-B-32", pretrained="laion2B-s13B-b90k") 37 | image_preprocess = coca_model.image_preprocess 38 | return coca_model, image_preprocess 39 | 40 | else: 41 | raise ValueError(f"Unknown model {model_name}") 42 | -------------------------------------------------------------------------------- /spec/models/base_wrapper.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from abc import ABCMeta, abstractmethod 4 | 5 | 6 | class BaseWrapper(metaclass=ABCMeta): 7 | """ 8 | This is the base model wrapper, if you want to evluate 9 | """ 10 | 11 | @abstractmethod 12 | @torch.no_grad() 13 | def i2t_evaluate(self, subset_name, dataloader): 14 | """ 15 | Performing `Image-to-Text` retrieval evaluation. 16 | Args: 17 | subset_name: the name of subset that you are running on 18 | dataloader: the dataloader that contains all the image to text samples 19 | 20 | Returns: 21 | i2t_scores: 22 | i2t_acc: 23 | """ 24 | pass 25 | 26 | @abstractmethod 27 | @torch.no_grad() 28 | def t2i_evaluate(self, subset_name, dataloader): 29 | """ 30 | Performing `Text-to-Image` retrieval evaluation. 
31 | Args: 32 | subset_name: the name of subset that you are running on 33 | dataloader: the dataloader that contains all the image to text samples 34 | 35 | Returns: 36 | t2i_scores: 37 | t2i_acc: 38 | """ 39 | pass 40 | 41 | @torch.no_grad() 42 | def evaluate(self, subset_name, dataloaders): 43 | """Computes the image-text matching scores and the image-to-text and text-to-image accuracy on a given subset 44 | Args: 45 | subset_name: the name of the subset 46 | dataloaders (Dict): include an "i2t_dataloader" and a "t2i_dataloader" 47 | Returns: 48 | scores(Dict of Tensor): `i2t_scores`, `t2i_scores` 49 | accuracy(Dict of Scalar): `i2t_accuracy`, t2i_accuracy` 50 | """ 51 | # image to text retrieval 52 | i2t_scores, i2t_acc = self.i2t_evaluate(subset_name, dataloaders['i2t_dataloader']) 53 | print(f'{subset_name} subset: Image2Text Accuracy: {i2t_acc:.2f} %') 54 | 55 | # text to image retrieval 56 | t2i_scores, t2i_acc = self.t2i_evaluate(subset_name, dataloaders['t2i_dataloader']) 57 | print(f'{subset_name} subset: Text2Image Accuracy: {t2i_acc:.2f} %') 58 | 59 | """ 60 | `i2t_scores`: tensor of shape NxL, N is the number of testing samples, L is the number of candidate texts per sample 61 | `t2i_scores`: tensor of shape NxK, N is the number of testing samples, K is the number of candidate images per sample 62 | """ 63 | scores = { 64 | 'i2t_scores': i2t_scores, 65 | 't2i_scores': t2i_scores 66 | } 67 | 68 | accuracy = { 69 | 'i2t_accuracy': i2t_acc, 70 | 't2i_accuracy': t2i_acc, 71 | } 72 | 73 | return { 74 | "accuracy": accuracy, 75 | "scores": scores 76 | } 77 | -------------------------------------------------------------------------------- /spec/models/blip_utils/.ipynb_checkpoints/README-checkpoint.md: -------------------------------------------------------------------------------- 1 | Most of the utilities here are obtained from https://github.com/salesforce/BLIP/ . -------------------------------------------------------------------------------- /spec/models/blip_utils/README.md: -------------------------------------------------------------------------------- 1 | Most of the utilities here are obtained from https://github.com/salesforce/BLIP/ . 
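For reference, the sketch below shows one way these vendored modules can be wired together directly. It is only a sketch: the checkpoint path and the `med_config` location are placeholders, not values pinned by this repository.

```python
import torch
from spec.models.blip_utils.blip_retrieval import blip_retrieval

# Placeholders: point `pretrained` at a BLIP retrieval checkpoint (local .pth file
# or a URL) and `med_config` at the BERT/med config JSON distributed with BLIP.
# Assumes the repo has been installed with `pip install -e .` as described in the README.
model = blip_retrieval(
    pretrained='/path/to/blip_retrieval_checkpoint.pth',
    med_config='/path/to/med_config.json',
    image_size=384,
    vit='base',
)
model = model.eval()
```

Within SPEC itself you normally do not call these modules directly: `get_model(model_name='blip', ...)` builds them through `spec/models/blip_wrapper.py` and also returns the matching image preprocessing function.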
-------------------------------------------------------------------------------- /spec/models/blip_utils/__pycache__/blip.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wjpoom/SPEC/8dd6cdcc0bb4f47ea3551b1c0558ee554656ca7c/spec/models/blip_utils/__pycache__/blip.cpython-38.pyc -------------------------------------------------------------------------------- /spec/models/blip_utils/__pycache__/blip_retrieval.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wjpoom/SPEC/8dd6cdcc0bb4f47ea3551b1c0558ee554656ca7c/spec/models/blip_utils/__pycache__/blip_retrieval.cpython-38.pyc -------------------------------------------------------------------------------- /spec/models/blip_utils/__pycache__/med.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wjpoom/SPEC/8dd6cdcc0bb4f47ea3551b1c0558ee554656ca7c/spec/models/blip_utils/__pycache__/med.cpython-38.pyc -------------------------------------------------------------------------------- /spec/models/blip_utils/__pycache__/vit.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wjpoom/SPEC/8dd6cdcc0bb4f47ea3551b1c0558ee554656ca7c/spec/models/blip_utils/__pycache__/vit.cpython-38.pyc -------------------------------------------------------------------------------- /spec/models/blip_utils/blip.py: -------------------------------------------------------------------------------- 1 | ''' 2 | * Copyright (c) 2022, salesforce.com, inc. 3 | * All rights reserved. 4 | * SPDX-License-Identifier: BSD-3-Clause 5 | * For full license text, see LICENSE.txt file in the repo root or https://opensource.org/licenses/BSD-3-Clause 6 | * By Junnan Li 7 | ''' 8 | import warnings 9 | warnings.filterwarnings("ignore") 10 | 11 | from .vit import VisionTransformer, interpolate_pos_embed 12 | from .med import BertConfig, BertModel, BertLMHeadModel 13 | from transformers import BertTokenizer 14 | 15 | import torch 16 | from torch import nn 17 | import torch.nn.functional as F 18 | 19 | import os 20 | from urllib.parse import urlparse 21 | from timm.models.hub import download_cached_file 22 | 23 | class BLIP_Base(nn.Module): 24 | def __init__(self, 25 | med_config = 'configs/med_config.json', 26 | image_size = 224, 27 | vit = 'base', 28 | vit_grad_ckpt = False, 29 | vit_ckpt_layer = 0, 30 | ): 31 | """ 32 | Args: 33 | med_config (str): path for the mixture of encoder-decoder model's configuration file 34 | image_size (int): input image size 35 | vit (str): model size of vision transformer 36 | """ 37 | super().__init__() 38 | 39 | self.visual_encoder, vision_width = create_vit(vit,image_size, vit_grad_ckpt, vit_ckpt_layer) 40 | self.tokenizer = init_tokenizer() 41 | med_config = BertConfig.from_json_file(med_config) 42 | med_config.encoder_width = vision_width 43 | self.text_encoder = BertModel(config=med_config, add_pooling_layer=False) 44 | 45 | 46 | def forward(self, image, caption, mode): 47 | 48 | assert mode in ['image', 'text', 'multimodal'], "mode parameter must be image, text, or multimodal" 49 | text = self.tokenizer(caption, return_tensors="pt").to(image.device) 50 | 51 | if mode=='image': 52 | # return image features 53 | image_embeds = self.visual_encoder(image) 54 | return image_embeds 55 | 56 | elif mode=='text': 57 | # return text features 58 | 
text_output = self.text_encoder(text.input_ids, attention_mask = text.attention_mask, 59 | return_dict = True, mode = 'text') 60 | return text_output.last_hidden_state 61 | 62 | elif mode=='multimodal': 63 | # return multimodel features 64 | image_embeds = self.visual_encoder(image) 65 | image_atts = torch.ones(image_embeds.size()[:-1],dtype=torch.long).to(image.device) 66 | 67 | text.input_ids[:,0] = self.tokenizer.enc_token_id 68 | output = self.text_encoder(text.input_ids, 69 | attention_mask = text.attention_mask, 70 | encoder_hidden_states = image_embeds, 71 | encoder_attention_mask = image_atts, 72 | return_dict = True, 73 | ) 74 | return output.last_hidden_state 75 | 76 | 77 | 78 | class BLIP_Decoder(nn.Module): 79 | def __init__(self, 80 | med_config = 'configs/med_config.json', 81 | image_size = 384, 82 | vit = 'base', 83 | vit_grad_ckpt = False, 84 | vit_ckpt_layer = 0, 85 | prompt = 'a picture of ', 86 | ): 87 | """ 88 | Args: 89 | med_config (str): path for the mixture of encoder-decoder model's configuration file 90 | image_size (int): input image size 91 | vit (str): model size of vision transformer 92 | """ 93 | super().__init__() 94 | 95 | self.visual_encoder, vision_width = create_vit(vit,image_size, vit_grad_ckpt, vit_ckpt_layer) 96 | self.tokenizer = init_tokenizer() 97 | med_config = BertConfig.from_json_file(med_config) 98 | med_config.encoder_width = vision_width 99 | self.text_decoder = BertLMHeadModel(config=med_config) 100 | 101 | self.prompt = prompt 102 | self.prompt_length = len(self.tokenizer(self.prompt).input_ids)-1 103 | 104 | 105 | def forward(self, image, caption): 106 | 107 | image_embeds = self.visual_encoder(image) 108 | image_atts = torch.ones(image_embeds.size()[:-1],dtype=torch.long).to(image.device) 109 | 110 | text = self.tokenizer(caption, padding='longest', truncation=True, max_length=40, return_tensors="pt").to(image.device) 111 | 112 | text.input_ids[:,0] = self.tokenizer.bos_token_id 113 | 114 | decoder_targets = text.input_ids.masked_fill(text.input_ids == self.tokenizer.pad_token_id, -100) 115 | decoder_targets[:,:self.prompt_length] = -100 116 | 117 | decoder_output = self.text_decoder(text.input_ids, 118 | attention_mask = text.attention_mask, 119 | encoder_hidden_states = image_embeds, 120 | encoder_attention_mask = image_atts, 121 | labels = decoder_targets, 122 | return_dict = True, 123 | ) 124 | loss_lm = decoder_output.loss 125 | 126 | return loss_lm 127 | 128 | def generate(self, image, sample=False, num_beams=3, max_length=30, min_length=10, top_p=0.9, repetition_penalty=1.0): 129 | image_embeds = self.visual_encoder(image) 130 | 131 | if not sample: 132 | image_embeds = image_embeds.repeat_interleave(num_beams,dim=0) 133 | 134 | image_atts = torch.ones(image_embeds.size()[:-1],dtype=torch.long).to(image.device) 135 | model_kwargs = {"encoder_hidden_states": image_embeds, "encoder_attention_mask":image_atts} 136 | 137 | prompt = [self.prompt] * image.size(0) 138 | input_ids = self.tokenizer(prompt, return_tensors="pt").input_ids.to(image.device) 139 | input_ids[:,0] = self.tokenizer.bos_token_id 140 | input_ids = input_ids[:, :-1] 141 | 142 | if sample: 143 | #nucleus sampling 144 | outputs = self.text_decoder.generate(input_ids=input_ids, 145 | max_length=max_length, 146 | min_length=min_length, 147 | do_sample=True, 148 | top_p=top_p, 149 | num_return_sequences=1, 150 | eos_token_id=self.tokenizer.sep_token_id, 151 | pad_token_id=self.tokenizer.pad_token_id, 152 | repetition_penalty=1.1, 153 | **model_kwargs) 154 | else: 155 | 
#beam search 156 | outputs = self.text_decoder.generate(input_ids=input_ids, 157 | max_length=max_length, 158 | min_length=min_length, 159 | num_beams=num_beams, 160 | eos_token_id=self.tokenizer.sep_token_id, 161 | pad_token_id=self.tokenizer.pad_token_id, 162 | repetition_penalty=repetition_penalty, 163 | **model_kwargs) 164 | 165 | captions = [] 166 | for output in outputs: 167 | caption = self.tokenizer.decode(output, skip_special_tokens=True) 168 | captions.append(caption[len(self.prompt):]) 169 | return captions 170 | 171 | 172 | def blip_decoder(pretrained='',**kwargs): 173 | model = BLIP_Decoder(**kwargs) 174 | if pretrained: 175 | model,msg = load_checkpoint(model,pretrained) 176 | assert(len(msg.missing_keys)==0) 177 | return model 178 | 179 | def blip_feature_extractor(pretrained='',**kwargs): 180 | model = BLIP_Base(**kwargs) 181 | if pretrained: 182 | model,msg = load_checkpoint(model,pretrained) 183 | assert(len(msg.missing_keys)==0) 184 | return model 185 | 186 | def init_tokenizer(): 187 | tokenizer = BertTokenizer.from_pretrained('bert-base-uncased') 188 | tokenizer.add_special_tokens({'bos_token':'[DEC]'}) 189 | tokenizer.add_special_tokens({'additional_special_tokens':['[ENC]']}) 190 | tokenizer.enc_token_id = tokenizer.additional_special_tokens_ids[0] 191 | return tokenizer 192 | 193 | 194 | def create_vit(vit, image_size, use_grad_checkpointing=False, ckpt_layer=0, drop_path_rate=0): 195 | 196 | assert vit in ['base', 'large'], "vit parameter must be base or large" 197 | if vit=='base': 198 | vision_width = 768 199 | visual_encoder = VisionTransformer(img_size=image_size, patch_size=16, embed_dim=vision_width, depth=12, 200 | num_heads=12, use_grad_checkpointing=use_grad_checkpointing, ckpt_layer=ckpt_layer, 201 | drop_path_rate=0 or drop_path_rate 202 | ) 203 | elif vit=='large': 204 | vision_width = 1024 205 | visual_encoder = VisionTransformer(img_size=image_size, patch_size=16, embed_dim=vision_width, depth=24, 206 | num_heads=16, use_grad_checkpointing=use_grad_checkpointing, ckpt_layer=ckpt_layer, 207 | drop_path_rate=0.1 or drop_path_rate 208 | ) 209 | return visual_encoder, vision_width 210 | 211 | def is_url(url_or_filename): 212 | parsed = urlparse(url_or_filename) 213 | return parsed.scheme in ("http", "https") 214 | 215 | def load_checkpoint(model,url_or_filename): 216 | if is_url(url_or_filename): 217 | cached_file = download_cached_file(url_or_filename, check_hash=False, progress=True) 218 | checkpoint = torch.load(cached_file, map_location='cpu') 219 | elif os.path.isfile(url_or_filename): 220 | checkpoint = torch.load(url_or_filename, map_location='cpu') 221 | else: 222 | raise RuntimeError('checkpoint url or path is invalid') 223 | 224 | state_dict = checkpoint['model'] 225 | 226 | state_dict['visual_encoder.pos_embed'] = interpolate_pos_embed(state_dict['visual_encoder.pos_embed'],model.visual_encoder) 227 | if 'visual_encoder_m.pos_embed' in model.state_dict().keys(): 228 | state_dict['visual_encoder_m.pos_embed'] = interpolate_pos_embed(state_dict['visual_encoder_m.pos_embed'], 229 | model.visual_encoder_m) 230 | for key in model.state_dict().keys(): 231 | if key in state_dict.keys(): 232 | if state_dict[key].shape!=model.state_dict()[key].shape: 233 | del state_dict[key] 234 | 235 | msg = model.load_state_dict(state_dict,strict=False) 236 | print('load checkpoint from %s'%url_or_filename) 237 | return model,msg 238 | 239 | -------------------------------------------------------------------------------- /spec/models/blip_utils/blip_itm.py: 
-------------------------------------------------------------------------------- 1 | from .med import BertConfig, BertModel 2 | from transformers import BertTokenizer 3 | 4 | import torch 5 | from torch import nn 6 | import torch.nn.functional as F 7 | 8 | from .blip import create_vit, init_tokenizer, load_checkpoint 9 | 10 | class BLIP_ITM(nn.Module): 11 | def __init__(self, 12 | med_config = 'configs/med_config.json', 13 | image_size = 384, 14 | vit = 'base', 15 | vit_grad_ckpt = False, 16 | vit_ckpt_layer = 0, 17 | embed_dim = 256, 18 | ): 19 | """ 20 | Args: 21 | med_config (str): path for the mixture of encoder-decoder model's configuration file 22 | image_size (int): input image size 23 | vit (str): model size of vision transformer 24 | """ 25 | super().__init__() 26 | 27 | self.visual_encoder, vision_width = create_vit(vit,image_size, vit_grad_ckpt, vit_ckpt_layer) 28 | self.tokenizer = init_tokenizer() 29 | med_config = BertConfig.from_json_file(med_config) 30 | med_config.encoder_width = vision_width 31 | self.text_encoder = BertModel(config=med_config, add_pooling_layer=False) 32 | 33 | text_width = self.text_encoder.config.hidden_size 34 | 35 | self.vision_proj = nn.Linear(vision_width, embed_dim) 36 | self.text_proj = nn.Linear(text_width, embed_dim) 37 | 38 | self.itm_head = nn.Linear(text_width, 2) 39 | 40 | 41 | def forward(self, image, caption, match_head='itm'): 42 | 43 | image_embeds = self.visual_encoder(image) 44 | image_atts = torch.ones(image_embeds.size()[:-1],dtype=torch.long).to(image.device) 45 | 46 | text = self.tokenizer(caption, padding='max_length', truncation=True, max_length=35, 47 | return_tensors="pt").to(image.device) 48 | 49 | 50 | if match_head=='itm': 51 | output = self.text_encoder(text.input_ids, 52 | attention_mask = text.attention_mask, 53 | encoder_hidden_states = image_embeds, 54 | encoder_attention_mask = image_atts, 55 | return_dict = True, 56 | ) 57 | itm_output = self.itm_head(output.last_hidden_state[:,0,:]) 58 | return itm_output 59 | 60 | elif match_head=='itc': 61 | text_output = self.text_encoder(text.input_ids, attention_mask = text.attention_mask, 62 | return_dict = True, mode = 'text') 63 | image_feat = F.normalize(self.vision_proj(image_embeds[:,0,:]),dim=-1) 64 | text_feat = F.normalize(self.text_proj(text_output.last_hidden_state[:,0,:]),dim=-1) 65 | 66 | sim = image_feat @ text_feat.t() 67 | return sim 68 | 69 | 70 | def blip_itm(pretrained='',**kwargs): 71 | model = BLIP_ITM(**kwargs) 72 | if pretrained: 73 | model,msg = load_checkpoint(model,pretrained) 74 | assert(len(msg.missing_keys)==0) 75 | return model 76 | -------------------------------------------------------------------------------- /spec/models/blip_utils/blip_pretrain.py: -------------------------------------------------------------------------------- 1 | ''' 2 | * Copyright (c) 2022, salesforce.com, inc. 3 | * All rights reserved. 
4 | * SPDX-License-Identifier: BSD-3-Clause 5 | * For full license text, see LICENSE.txt file in the repo root or https://opensource.org/licenses/BSD-3-Clause 6 | * By Junnan Li 7 | ''' 8 | from .med import BertConfig, BertModel, BertLMHeadModel 9 | from transformers import BertTokenizer 10 | import transformers 11 | transformers.logging.set_verbosity_error() 12 | 13 | import torch 14 | from torch import nn 15 | import torch.nn.functional as F 16 | 17 | from .blip import create_vit, init_tokenizer, load_checkpoint 18 | 19 | class BLIP_Pretrain(nn.Module): 20 | def __init__(self, 21 | med_config = 'configs/bert_config.json', 22 | image_size = 224, 23 | vit = 'base', 24 | vit_grad_ckpt = False, 25 | vit_ckpt_layer = 0, 26 | embed_dim = 256, 27 | queue_size = 57600, 28 | momentum = 0.995, 29 | ): 30 | """ 31 | Args: 32 | med_config (str): path for the mixture of encoder-decoder model's configuration file 33 | image_size (int): input image size 34 | vit (str): model size of vision transformer 35 | """ 36 | super().__init__() 37 | 38 | self.visual_encoder, vision_width = create_vit(vit,image_size, vit_grad_ckpt, vit_ckpt_layer, 0) 39 | 40 | if vit=='base': 41 | checkpoint = torch.hub.load_state_dict_from_url( 42 | url="https://dl.fbaipublicfiles.com/deit/deit_base_patch16_224-b5f2ef4d.pth", 43 | map_location="cpu", check_hash=True) 44 | state_dict = checkpoint["model"] 45 | msg = self.visual_encoder.load_state_dict(state_dict,strict=False) 46 | elif vit=='large': 47 | from timm.models.helpers import load_custom_pretrained 48 | from timm.models.vision_transformer import default_cfgs 49 | load_custom_pretrained(self.visual_encoder,default_cfgs['vit_large_patch16_224_in21k']) 50 | 51 | self.tokenizer = init_tokenizer() 52 | encoder_config = BertConfig.from_json_file(med_config) 53 | encoder_config.encoder_width = vision_width 54 | self.text_encoder = BertModel.from_pretrained('bert-base-uncased',config=encoder_config, add_pooling_layer=False) 55 | self.text_encoder.resize_token_embeddings(len(self.tokenizer)) 56 | 57 | text_width = self.text_encoder.config.hidden_size 58 | 59 | self.vision_proj = nn.Linear(vision_width, embed_dim) 60 | self.text_proj = nn.Linear(text_width, embed_dim) 61 | 62 | self.itm_head = nn.Linear(text_width, 2) 63 | 64 | # create momentum encoders 65 | self.visual_encoder_m, vision_width = create_vit(vit,image_size) 66 | self.vision_proj_m = nn.Linear(vision_width, embed_dim) 67 | self.text_encoder_m = BertModel(config=encoder_config, add_pooling_layer=False) 68 | self.text_proj_m = nn.Linear(text_width, embed_dim) 69 | 70 | self.model_pairs = [[self.visual_encoder,self.visual_encoder_m], 71 | [self.vision_proj,self.vision_proj_m], 72 | [self.text_encoder,self.text_encoder_m], 73 | [self.text_proj,self.text_proj_m], 74 | ] 75 | self.copy_params() 76 | 77 | # create the queue 78 | self.register_buffer("image_queue", torch.randn(embed_dim, queue_size)) 79 | self.register_buffer("text_queue", torch.randn(embed_dim, queue_size)) 80 | self.register_buffer("queue_ptr", torch.zeros(1, dtype=torch.long)) 81 | 82 | self.image_queue = nn.functional.normalize(self.image_queue, dim=0) 83 | self.text_queue = nn.functional.normalize(self.text_queue, dim=0) 84 | 85 | self.queue_size = queue_size 86 | self.momentum = momentum 87 | self.temp = nn.Parameter(0.07*torch.ones([])) 88 | 89 | # create the decoder 90 | decoder_config = BertConfig.from_json_file(med_config) 91 | decoder_config.encoder_width = vision_width 92 | self.text_decoder = 
BertLMHeadModel.from_pretrained('bert-base-uncased',config=decoder_config) 93 | self.text_decoder.resize_token_embeddings(len(self.tokenizer)) 94 | tie_encoder_decoder_weights(self.text_encoder,self.text_decoder.bert,'','/attention') 95 | 96 | 97 | def forward(self, image, caption, alpha): 98 | with torch.no_grad(): 99 | self.temp.clamp_(0.001,0.5) 100 | 101 | image_embeds = self.visual_encoder(image) 102 | image_atts = torch.ones(image_embeds.size()[:-1],dtype=torch.long).to(image.device) 103 | image_feat = F.normalize(self.vision_proj(image_embeds[:,0,:]),dim=-1) 104 | 105 | text = self.tokenizer(caption, padding='max_length', truncation=True, max_length=30, 106 | return_tensors="pt").to(image.device) 107 | text_output = self.text_encoder(text.input_ids, attention_mask = text.attention_mask, 108 | return_dict = True, mode = 'text') 109 | text_feat = F.normalize(self.text_proj(text_output.last_hidden_state[:,0,:]),dim=-1) 110 | 111 | # get momentum features 112 | with torch.no_grad(): 113 | self._momentum_update() 114 | image_embeds_m = self.visual_encoder_m(image) 115 | image_feat_m = F.normalize(self.vision_proj_m(image_embeds_m[:,0,:]),dim=-1) 116 | image_feat_all = torch.cat([image_feat_m.t(),self.image_queue.clone().detach()],dim=1) 117 | 118 | text_output_m = self.text_encoder_m(text.input_ids, attention_mask = text.attention_mask, 119 | return_dict = True, mode = 'text') 120 | text_feat_m = F.normalize(self.text_proj_m(text_output_m.last_hidden_state[:,0,:]),dim=-1) 121 | text_feat_all = torch.cat([text_feat_m.t(),self.text_queue.clone().detach()],dim=1) 122 | 123 | sim_i2t_m = image_feat_m @ text_feat_all / self.temp 124 | sim_t2i_m = text_feat_m @ image_feat_all / self.temp 125 | 126 | sim_targets = torch.zeros(sim_i2t_m.size()).to(image.device) 127 | sim_targets.fill_diagonal_(1) 128 | 129 | sim_i2t_targets = alpha * F.softmax(sim_i2t_m, dim=1) + (1 - alpha) * sim_targets 130 | sim_t2i_targets = alpha * F.softmax(sim_t2i_m, dim=1) + (1 - alpha) * sim_targets 131 | 132 | sim_i2t = image_feat @ text_feat_all / self.temp 133 | sim_t2i = text_feat @ image_feat_all / self.temp 134 | 135 | loss_i2t = -torch.sum(F.log_softmax(sim_i2t, dim=1)*sim_i2t_targets,dim=1).mean() 136 | loss_t2i = -torch.sum(F.log_softmax(sim_t2i, dim=1)*sim_t2i_targets,dim=1).mean() 137 | 138 | loss_ita = (loss_i2t+loss_t2i)/2 139 | 140 | self._dequeue_and_enqueue(image_feat_m, text_feat_m) 141 | 142 | ###============== Image-text Matching ===================### 143 | encoder_input_ids = text.input_ids.clone() 144 | encoder_input_ids[:,0] = self.tokenizer.enc_token_id 145 | 146 | # forward the positve image-text pair 147 | bs = image.size(0) 148 | output_pos = self.text_encoder(encoder_input_ids, 149 | attention_mask = text.attention_mask, 150 | encoder_hidden_states = image_embeds, 151 | encoder_attention_mask = image_atts, 152 | return_dict = True, 153 | ) 154 | with torch.no_grad(): 155 | weights_t2i = F.softmax(sim_t2i[:,:bs],dim=1)+1e-4 156 | weights_t2i.fill_diagonal_(0) 157 | weights_i2t = F.softmax(sim_i2t[:,:bs],dim=1)+1e-4 158 | weights_i2t.fill_diagonal_(0) 159 | 160 | # select a negative image for each text 161 | image_embeds_neg = [] 162 | for b in range(bs): 163 | neg_idx = torch.multinomial(weights_t2i[b], 1).item() 164 | image_embeds_neg.append(image_embeds[neg_idx]) 165 | image_embeds_neg = torch.stack(image_embeds_neg,dim=0) 166 | 167 | # select a negative text for each image 168 | text_ids_neg = [] 169 | text_atts_neg = [] 170 | for b in range(bs): 171 | neg_idx = 
torch.multinomial(weights_i2t[b], 1).item() 172 | text_ids_neg.append(encoder_input_ids[neg_idx]) 173 | text_atts_neg.append(text.attention_mask[neg_idx]) 174 | 175 | text_ids_neg = torch.stack(text_ids_neg,dim=0) 176 | text_atts_neg = torch.stack(text_atts_neg,dim=0) 177 | 178 | text_ids_all = torch.cat([encoder_input_ids, text_ids_neg],dim=0) 179 | text_atts_all = torch.cat([text.attention_mask, text_atts_neg],dim=0) 180 | 181 | image_embeds_all = torch.cat([image_embeds_neg,image_embeds],dim=0) 182 | image_atts_all = torch.cat([image_atts,image_atts],dim=0) 183 | 184 | output_neg = self.text_encoder(text_ids_all, 185 | attention_mask = text_atts_all, 186 | encoder_hidden_states = image_embeds_all, 187 | encoder_attention_mask = image_atts_all, 188 | return_dict = True, 189 | ) 190 | 191 | vl_embeddings = torch.cat([output_pos.last_hidden_state[:,0,:], output_neg.last_hidden_state[:,0,:]],dim=0) 192 | vl_output = self.itm_head(vl_embeddings) 193 | 194 | itm_labels = torch.cat([torch.ones(bs,dtype=torch.long),torch.zeros(2*bs,dtype=torch.long)], 195 | dim=0).to(image.device) 196 | loss_itm = F.cross_entropy(vl_output, itm_labels) 197 | 198 | ##================= LM ========================## 199 | decoder_input_ids = text.input_ids.clone() 200 | decoder_input_ids[:,0] = self.tokenizer.bos_token_id 201 | decoder_targets = decoder_input_ids.masked_fill(decoder_input_ids == self.tokenizer.pad_token_id, -100) 202 | 203 | decoder_output = self.text_decoder(decoder_input_ids, 204 | attention_mask = text.attention_mask, 205 | encoder_hidden_states = image_embeds, 206 | encoder_attention_mask = image_atts, 207 | labels = decoder_targets, 208 | return_dict = True, 209 | ) 210 | 211 | loss_lm = decoder_output.loss 212 | return loss_ita, loss_itm, loss_lm 213 | 214 | 215 | 216 | @torch.no_grad() 217 | def copy_params(self): 218 | for model_pair in self.model_pairs: 219 | for param, param_m in zip(model_pair[0].parameters(), model_pair[1].parameters()): 220 | param_m.data.copy_(param.data) # initialize 221 | param_m.requires_grad = False # not update by gradient 222 | 223 | 224 | @torch.no_grad() 225 | def _momentum_update(self): 226 | for model_pair in self.model_pairs: 227 | for param, param_m in zip(model_pair[0].parameters(), model_pair[1].parameters()): 228 | param_m.data = param_m.data * self.momentum + param.data * (1. - self.momentum) 229 | 230 | 231 | @torch.no_grad() 232 | def _dequeue_and_enqueue(self, image_feat, text_feat): 233 | # gather keys before updating queue 234 | image_feats = concat_all_gather(image_feat) 235 | text_feats = concat_all_gather(text_feat) 236 | 237 | batch_size = image_feats.shape[0] 238 | 239 | ptr = int(self.queue_ptr) 240 | assert self.queue_size % batch_size == 0 # for simplicity 241 | 242 | # replace the keys at ptr (dequeue and enqueue) 243 | self.image_queue[:, ptr:ptr + batch_size] = image_feats.T 244 | self.text_queue[:, ptr:ptr + batch_size] = text_feats.T 245 | ptr = (ptr + batch_size) % self.queue_size # move pointer 246 | 247 | self.queue_ptr[0] = ptr 248 | 249 | 250 | def blip_pretrain(**kwargs): 251 | model = BLIP_Pretrain(**kwargs) 252 | return model 253 | 254 | 255 | @torch.no_grad() 256 | def concat_all_gather(tensor): 257 | """ 258 | Performs all_gather operation on the provided tensors. 259 | *** Warning ***: torch.distributed.all_gather has no gradient. 
260 | """ 261 | tensors_gather = [torch.ones_like(tensor) 262 | for _ in range(torch.distributed.get_world_size())] 263 | torch.distributed.all_gather(tensors_gather, tensor, async_op=False) 264 | 265 | output = torch.cat(tensors_gather, dim=0) 266 | return output 267 | 268 | 269 | from typing import List 270 | def tie_encoder_decoder_weights(encoder: nn.Module, decoder: nn.Module, base_model_prefix: str, skip_key:str): 271 | uninitialized_encoder_weights: List[str] = [] 272 | if decoder.__class__ != encoder.__class__: 273 | logger.info( 274 | f"{decoder.__class__} and {encoder.__class__} are not equal. In this case make sure that all encoder weights are correctly initialized." 275 | ) 276 | 277 | def tie_encoder_to_decoder_recursively( 278 | decoder_pointer: nn.Module, 279 | encoder_pointer: nn.Module, 280 | module_name: str, 281 | uninitialized_encoder_weights: List[str], 282 | skip_key: str, 283 | depth=0, 284 | ): 285 | assert isinstance(decoder_pointer, nn.Module) and isinstance( 286 | encoder_pointer, nn.Module 287 | ), f"{decoder_pointer} and {encoder_pointer} have to be of type torch.nn.Module" 288 | if hasattr(decoder_pointer, "weight") and skip_key not in module_name: 289 | assert hasattr(encoder_pointer, "weight") 290 | encoder_pointer.weight = decoder_pointer.weight 291 | if hasattr(decoder_pointer, "bias"): 292 | assert hasattr(encoder_pointer, "bias") 293 | encoder_pointer.bias = decoder_pointer.bias 294 | print(module_name+' is tied') 295 | return 296 | 297 | encoder_modules = encoder_pointer._modules 298 | decoder_modules = decoder_pointer._modules 299 | if len(decoder_modules) > 0: 300 | assert ( 301 | len(encoder_modules) > 0 302 | ), f"Encoder module {encoder_pointer} does not match decoder module {decoder_pointer}" 303 | 304 | all_encoder_weights = set([module_name + "/" + sub_name for sub_name in encoder_modules.keys()]) 305 | encoder_layer_pos = 0 306 | for name, module in decoder_modules.items(): 307 | if name.isdigit(): 308 | encoder_name = str(int(name) + encoder_layer_pos) 309 | decoder_name = name 310 | if not isinstance(decoder_modules[decoder_name], type(encoder_modules[encoder_name])) and len( 311 | encoder_modules 312 | ) != len(decoder_modules): 313 | # this can happen if the name corresponds to the position in a list module list of layers 314 | # in this case the decoder has added a cross-attention that the encoder does not have 315 | # thus skip this step and subtract one layer pos from encoder 316 | encoder_layer_pos -= 1 317 | continue 318 | elif name not in encoder_modules: 319 | continue 320 | elif depth > 500: 321 | raise ValueError( 322 | "Max depth of recursive function `tie_encoder_to_decoder` reached. It seems that there is a circular dependency between two or more `nn.Modules` of your model." 
323 | ) 324 | else: 325 | decoder_name = encoder_name = name 326 | tie_encoder_to_decoder_recursively( 327 | decoder_modules[decoder_name], 328 | encoder_modules[encoder_name], 329 | module_name + "/" + name, 330 | uninitialized_encoder_weights, 331 | skip_key, 332 | depth=depth + 1, 333 | ) 334 | all_encoder_weights.remove(module_name + "/" + encoder_name) 335 | 336 | uninitialized_encoder_weights += list(all_encoder_weights) 337 | 338 | # tie weights recursively 339 | tie_encoder_to_decoder_recursively(decoder, encoder, base_model_prefix, uninitialized_encoder_weights, skip_key) -------------------------------------------------------------------------------- /spec/models/blip_utils/blip_retrieval.py: -------------------------------------------------------------------------------- 1 | ''' 2 | * Copyright (c) 2022, salesforce.com, inc. 3 | * All rights reserved. 4 | * SPDX-License-Identifier: BSD-3-Clause 5 | * For full license text, see LICENSE.txt file in the repo root or https://opensource.org/licenses/BSD-3-Clause 6 | * By Junnan Li 7 | ''' 8 | from .med import BertConfig, BertModel 9 | from transformers import BertTokenizer 10 | 11 | import torch 12 | from torch import nn 13 | import torch.nn.functional as F 14 | 15 | from .blip import create_vit, init_tokenizer, load_checkpoint 16 | 17 | class BLIP_Retrieval(nn.Module): 18 | def __init__(self, 19 | med_config = 'configs/med_config.json', 20 | image_size = 384, 21 | vit = 'base', 22 | vit_grad_ckpt = False, 23 | vit_ckpt_layer = 0, 24 | embed_dim = 256, 25 | queue_size = 57600, 26 | momentum = 0.995, 27 | negative_all_rank = False, 28 | ): 29 | """ 30 | Args: 31 | med_config (str): path for the mixture of encoder-decoder model's configuration file 32 | image_size (int): input image size 33 | vit (str): model size of vision transformer 34 | """ 35 | super().__init__() 36 | 37 | self.visual_encoder, vision_width = create_vit(vit,image_size, vit_grad_ckpt, vit_ckpt_layer) 38 | self.tokenizer = init_tokenizer() 39 | med_config = BertConfig.from_json_file(med_config) 40 | med_config.encoder_width = vision_width 41 | self.text_encoder = BertModel(config=med_config, add_pooling_layer=False) 42 | 43 | text_width = self.text_encoder.config.hidden_size 44 | 45 | self.vision_proj = nn.Linear(vision_width, embed_dim) 46 | self.text_proj = nn.Linear(text_width, embed_dim) 47 | 48 | self.itm_head = nn.Linear(text_width, 2) 49 | 50 | # create momentum encoders 51 | self.visual_encoder_m, vision_width = create_vit(vit,image_size) 52 | self.vision_proj_m = nn.Linear(vision_width, embed_dim) 53 | self.text_encoder_m = BertModel(config=med_config, add_pooling_layer=False) 54 | self.text_proj_m = nn.Linear(text_width, embed_dim) 55 | 56 | self.model_pairs = [[self.visual_encoder,self.visual_encoder_m], 57 | [self.vision_proj,self.vision_proj_m], 58 | [self.text_encoder,self.text_encoder_m], 59 | [self.text_proj,self.text_proj_m], 60 | ] 61 | self.copy_params() 62 | 63 | # create the queue 64 | self.register_buffer("image_queue", torch.randn(embed_dim, queue_size)) 65 | self.register_buffer("text_queue", torch.randn(embed_dim, queue_size)) 66 | self.register_buffer("idx_queue", torch.full((1,queue_size),-100)) 67 | self.register_buffer("ptr_queue", torch.zeros(1, dtype=torch.long)) 68 | 69 | self.image_queue = nn.functional.normalize(self.image_queue, dim=0) 70 | self.text_queue = nn.functional.normalize(self.text_queue, dim=0) 71 | 72 | self.queue_size = queue_size 73 | self.momentum = momentum 74 | self.temp = nn.Parameter(0.07*torch.ones([])) 75 | 76 | 
self.negative_all_rank = negative_all_rank 77 | 78 | 79 | def forward(self, image, caption, alpha, idx): 80 | with torch.no_grad(): 81 | self.temp.clamp_(0.001,0.5) 82 | 83 | image_embeds = self.visual_encoder(image) 84 | image_atts = torch.ones(image_embeds.size()[:-1],dtype=torch.long).to(image.device) 85 | image_feat = F.normalize(self.vision_proj(image_embeds[:,0,:]),dim=-1) 86 | 87 | text = self.tokenizer(caption, padding='max_length', truncation=True, max_length=35, 88 | return_tensors="pt").to(image.device) 89 | 90 | text_output = self.text_encoder(text.input_ids, attention_mask = text.attention_mask, 91 | return_dict = True, mode = 'text') 92 | text_feat = F.normalize(self.text_proj(text_output.last_hidden_state[:,0,:]),dim=-1) 93 | 94 | ###============== Image-text Contrastive Learning ===================### 95 | idx = idx.view(-1,1) 96 | idx_all = torch.cat([idx.t(), self.idx_queue.clone().detach()],dim=1) 97 | pos_idx = torch.eq(idx, idx_all).float() 98 | sim_targets = pos_idx / pos_idx.sum(1,keepdim=True) 99 | 100 | # get momentum features 101 | with torch.no_grad(): 102 | self._momentum_update() 103 | image_embeds_m = self.visual_encoder_m(image) 104 | image_feat_m = F.normalize(self.vision_proj_m(image_embeds_m[:,0,:]),dim=-1) 105 | image_feat_m_all = torch.cat([image_feat_m.t(),self.image_queue.clone().detach()],dim=1) 106 | 107 | text_output_m = self.text_encoder_m(text.input_ids, attention_mask = text.attention_mask, 108 | return_dict = True, mode = 'text') 109 | text_feat_m = F.normalize(self.text_proj_m(text_output_m.last_hidden_state[:,0,:]),dim=-1) 110 | text_feat_m_all = torch.cat([text_feat_m.t(),self.text_queue.clone().detach()],dim=1) 111 | 112 | sim_i2t_m = image_feat_m @ text_feat_m_all / self.temp 113 | sim_t2i_m = text_feat_m @ image_feat_m_all / self.temp 114 | 115 | sim_i2t_targets = alpha * F.softmax(sim_i2t_m, dim=1) + (1 - alpha) * sim_targets 116 | sim_t2i_targets = alpha * F.softmax(sim_t2i_m, dim=1) + (1 - alpha) * sim_targets 117 | 118 | sim_i2t = image_feat @ text_feat_m_all / self.temp 119 | sim_t2i = text_feat @ image_feat_m_all / self.temp 120 | 121 | loss_i2t = -torch.sum(F.log_softmax(sim_i2t, dim=1)*sim_i2t_targets,dim=1).mean() 122 | loss_t2i = -torch.sum(F.log_softmax(sim_t2i, dim=1)*sim_t2i_targets,dim=1).mean() 123 | 124 | loss_ita = (loss_i2t+loss_t2i)/2 125 | 126 | idxs = concat_all_gather(idx) 127 | self._dequeue_and_enqueue(image_feat_m, text_feat_m, idxs) 128 | 129 | ###============== Image-text Matching ===================### 130 | encoder_input_ids = text.input_ids.clone() 131 | encoder_input_ids[:,0] = self.tokenizer.enc_token_id 132 | 133 | # forward the positve image-text pair 134 | bs = image.size(0) 135 | output_pos = self.text_encoder(encoder_input_ids, 136 | attention_mask = text.attention_mask, 137 | encoder_hidden_states = image_embeds, 138 | encoder_attention_mask = image_atts, 139 | return_dict = True, 140 | ) 141 | 142 | 143 | if self.negative_all_rank: 144 | # compute sample similarity 145 | with torch.no_grad(): 146 | mask = torch.eq(idx, idxs.t()) 147 | 148 | image_feat_world = concat_all_gather(image_feat) 149 | text_feat_world = concat_all_gather(text_feat) 150 | 151 | sim_i2t = image_feat @ text_feat_world.t() / self.temp 152 | sim_t2i = text_feat @ image_feat_world.t() / self.temp 153 | 154 | weights_i2t = F.softmax(sim_i2t,dim=1) 155 | weights_i2t.masked_fill_(mask, 0) 156 | 157 | weights_t2i = F.softmax(sim_t2i,dim=1) 158 | weights_t2i.masked_fill_(mask, 0) 159 | 160 | image_embeds_world = 
all_gather_with_grad(image_embeds) 161 | 162 | # select a negative image (from all ranks) for each text 163 | image_embeds_neg = [] 164 | for b in range(bs): 165 | neg_idx = torch.multinomial(weights_t2i[b], 1).item() 166 | image_embeds_neg.append(image_embeds_world[neg_idx]) 167 | image_embeds_neg = torch.stack(image_embeds_neg,dim=0) 168 | 169 | # select a negative text (from all ranks) for each image 170 | input_ids_world = concat_all_gather(encoder_input_ids) 171 | att_mask_world = concat_all_gather(text.attention_mask) 172 | 173 | text_ids_neg = [] 174 | text_atts_neg = [] 175 | for b in range(bs): 176 | neg_idx = torch.multinomial(weights_i2t[b], 1).item() 177 | text_ids_neg.append(input_ids_world[neg_idx]) 178 | text_atts_neg.append(att_mask_world[neg_idx]) 179 | 180 | else: 181 | with torch.no_grad(): 182 | mask = torch.eq(idx, idx.t()) 183 | 184 | sim_i2t = image_feat @ text_feat.t() / self.temp 185 | sim_t2i = text_feat @ image_feat.t() / self.temp 186 | 187 | weights_i2t = F.softmax(sim_i2t,dim=1) 188 | weights_i2t.masked_fill_(mask, 0) 189 | 190 | weights_t2i = F.softmax(sim_t2i,dim=1) 191 | weights_t2i.masked_fill_(mask, 0) 192 | 193 | # select a negative image (from same rank) for each text 194 | image_embeds_neg = [] 195 | for b in range(bs): 196 | neg_idx = torch.multinomial(weights_t2i[b], 1).item() 197 | image_embeds_neg.append(image_embeds[neg_idx]) 198 | image_embeds_neg = torch.stack(image_embeds_neg,dim=0) 199 | 200 | # select a negative text (from same rank) for each image 201 | text_ids_neg = [] 202 | text_atts_neg = [] 203 | for b in range(bs): 204 | neg_idx = torch.multinomial(weights_i2t[b], 1).item() 205 | text_ids_neg.append(encoder_input_ids[neg_idx]) 206 | text_atts_neg.append(text.attention_mask[neg_idx]) 207 | 208 | text_ids_neg = torch.stack(text_ids_neg,dim=0) 209 | text_atts_neg = torch.stack(text_atts_neg,dim=0) 210 | 211 | text_ids_all = torch.cat([encoder_input_ids, text_ids_neg],dim=0) 212 | text_atts_all = torch.cat([text.attention_mask, text_atts_neg],dim=0) 213 | 214 | image_embeds_all = torch.cat([image_embeds_neg,image_embeds],dim=0) 215 | image_atts_all = torch.cat([image_atts,image_atts],dim=0) 216 | 217 | output_neg = self.text_encoder(text_ids_all, 218 | attention_mask = text_atts_all, 219 | encoder_hidden_states = image_embeds_all, 220 | encoder_attention_mask = image_atts_all, 221 | return_dict = True, 222 | ) 223 | 224 | 225 | vl_embeddings = torch.cat([output_pos.last_hidden_state[:,0,:], output_neg.last_hidden_state[:,0,:]],dim=0) 226 | vl_output = self.itm_head(vl_embeddings) 227 | 228 | itm_labels = torch.cat([torch.ones(bs,dtype=torch.long),torch.zeros(2*bs,dtype=torch.long)], 229 | dim=0).to(image.device) 230 | loss_itm = F.cross_entropy(vl_output, itm_labels) 231 | 232 | return loss_ita, loss_itm 233 | 234 | 235 | @torch.no_grad() 236 | def copy_params(self): 237 | for model_pair in self.model_pairs: 238 | for param, param_m in zip(model_pair[0].parameters(), model_pair[1].parameters()): 239 | param_m.data.copy_(param.data) # initialize 240 | param_m.requires_grad = False # not update by gradient 241 | 242 | 243 | @torch.no_grad() 244 | def _momentum_update(self): 245 | for model_pair in self.model_pairs: 246 | for param, param_m in zip(model_pair[0].parameters(), model_pair[1].parameters()): 247 | param_m.data = param_m.data * self.momentum + param.data * (1. 
- self.momentum) 248 | 249 | 250 | @torch.no_grad() 251 | def _dequeue_and_enqueue(self, image_feat, text_feat, idxs): 252 | # gather keys before updating queue 253 | image_feats = concat_all_gather(image_feat) 254 | text_feats = concat_all_gather(text_feat) 255 | 256 | 257 | batch_size = image_feats.shape[0] 258 | 259 | ptr = int(self.ptr_queue) 260 | assert self.queue_size % batch_size == 0 # for simplicity 261 | 262 | # replace the keys at ptr (dequeue and enqueue) 263 | self.image_queue[:, ptr:ptr + batch_size] = image_feats.T 264 | self.text_queue[:, ptr:ptr + batch_size] = text_feats.T 265 | self.idx_queue[:, ptr:ptr + batch_size] = idxs.T 266 | ptr = (ptr + batch_size) % self.queue_size # move pointer 267 | 268 | self.ptr_queue[0] = ptr 269 | 270 | 271 | def blip_retrieval(pretrained='',**kwargs): 272 | model = BLIP_Retrieval(**kwargs) 273 | if pretrained: 274 | model,msg = load_checkpoint(model,pretrained) 275 | print("missing keys:") 276 | print(msg.missing_keys) 277 | return model 278 | 279 | 280 | @torch.no_grad() 281 | def concat_all_gather(tensor): 282 | """ 283 | Performs all_gather operation on the provided tensors. 284 | *** Warning ***: torch.distributed.all_gather has no gradient. 285 | """ 286 | tensors_gather = [torch.ones_like(tensor) 287 | for _ in range(torch.distributed.get_world_size())] 288 | torch.distributed.all_gather(tensors_gather, tensor, async_op=False) 289 | 290 | output = torch.cat(tensors_gather, dim=0) 291 | return output 292 | 293 | 294 | class GatherLayer(torch.autograd.Function): 295 | """ 296 | Gather tensors from all workers with support for backward propagation: 297 | This implementation does not cut the gradients as torch.distributed.all_gather does. 298 | """ 299 | 300 | @staticmethod 301 | def forward(ctx, x): 302 | output = [torch.zeros_like(x) for _ in range(torch.distributed.get_world_size())] 303 | torch.distributed.all_gather(output, x) 304 | return tuple(output) 305 | 306 | @staticmethod 307 | def backward(ctx, *grads): 308 | all_gradients = torch.stack(grads) 309 | torch.distributed.all_reduce(all_gradients) 310 | return all_gradients[torch.distributed.get_rank()] 311 | 312 | 313 | def all_gather_with_grad(tensors): 314 | """ 315 | Performs all_gather operation on the provided tensors. 316 | Graph remains connected for backward grad computation. 317 | """ 318 | # Queue the gathered tensors 319 | world_size = torch.distributed.get_world_size() 320 | # There is no need for reduction in the single-proc case 321 | if world_size == 1: 322 | return tensors 323 | 324 | tensor_all = GatherLayer.apply(tensors) 325 | 326 | return torch.cat(tensor_all, dim=0) 327 | -------------------------------------------------------------------------------- /spec/models/blip_utils/med.py: -------------------------------------------------------------------------------- 1 | ''' 2 | * Copyright (c) 2022, salesforce.com, inc. 3 | * All rights reserved. 
4 | * SPDX-License-Identifier: BSD-3-Clause 5 | * For full license text, see LICENSE.txt file in the repo root or https://opensource.org/licenses/BSD-3-Clause 6 | * By Junnan Li 7 | * Based on huggingface code base 8 | * https://github.com/huggingface/transformers/blob/v4.15.0/src/transformers/models/bert 9 | ''' 10 | 11 | import math 12 | import os 13 | import warnings 14 | from dataclasses import dataclass 15 | from typing import Optional, Tuple 16 | 17 | import torch 18 | from torch import Tensor, device, dtype, nn 19 | import torch.utils.checkpoint 20 | from torch import nn 21 | from torch.nn import CrossEntropyLoss 22 | import torch.nn.functional as F 23 | 24 | from transformers.activations import ACT2FN 25 | from transformers.file_utils import ( 26 | ModelOutput, 27 | ) 28 | from transformers.modeling_outputs import ( 29 | BaseModelOutputWithPastAndCrossAttentions, 30 | BaseModelOutputWithPoolingAndCrossAttentions, 31 | CausalLMOutputWithCrossAttentions, 32 | MaskedLMOutput, 33 | MultipleChoiceModelOutput, 34 | NextSentencePredictorOutput, 35 | QuestionAnsweringModelOutput, 36 | SequenceClassifierOutput, 37 | TokenClassifierOutput, 38 | ) 39 | from transformers.modeling_utils import ( 40 | PreTrainedModel, 41 | apply_chunking_to_forward, 42 | find_pruneable_heads_and_indices, 43 | prune_linear_layer, 44 | ) 45 | from transformers.utils import logging 46 | from transformers.models.bert.configuration_bert import BertConfig 47 | 48 | 49 | logger = logging.get_logger(__name__) 50 | 51 | 52 | class BertEmbeddings(nn.Module): 53 | """Construct the embeddings from word and position embeddings.""" 54 | 55 | def __init__(self, config): 56 | super().__init__() 57 | self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id) 58 | self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size) 59 | 60 | # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load 61 | # any TensorFlow checkpoint file 62 | self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) 63 | self.dropout = nn.Dropout(config.hidden_dropout_prob) 64 | 65 | # position_ids (1, len position emb) is contiguous in memory and exported when serialized 66 | self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1))) 67 | self.position_embedding_type = getattr(config, "position_embedding_type", "absolute") 68 | 69 | self.config = config 70 | 71 | def forward( 72 | self, input_ids=None, position_ids=None, inputs_embeds=None, past_key_values_length=0 73 | ): 74 | if input_ids is not None: 75 | input_shape = input_ids.size() 76 | else: 77 | input_shape = inputs_embeds.size()[:-1] 78 | 79 | seq_length = input_shape[1] 80 | 81 | if position_ids is None: 82 | position_ids = self.position_ids[:, past_key_values_length : seq_length + past_key_values_length] 83 | 84 | if inputs_embeds is None: 85 | inputs_embeds = self.word_embeddings(input_ids) 86 | 87 | embeddings = inputs_embeds 88 | 89 | if self.position_embedding_type == "absolute": 90 | position_embeddings = self.position_embeddings(position_ids) 91 | embeddings += position_embeddings 92 | embeddings = self.LayerNorm(embeddings) 93 | embeddings = self.dropout(embeddings) 94 | return embeddings 95 | 96 | 97 | class BertSelfAttention(nn.Module): 98 | def __init__(self, config, is_cross_attention): 99 | super().__init__() 100 | self.config = config 101 | if config.hidden_size % config.num_attention_heads != 0 and 
not hasattr(config, "embedding_size"): 102 | raise ValueError( 103 | "The hidden size (%d) is not a multiple of the number of attention " 104 | "heads (%d)" % (config.hidden_size, config.num_attention_heads) 105 | ) 106 | 107 | self.num_attention_heads = config.num_attention_heads 108 | self.attention_head_size = int(config.hidden_size / config.num_attention_heads) 109 | self.all_head_size = self.num_attention_heads * self.attention_head_size 110 | 111 | self.query = nn.Linear(config.hidden_size, self.all_head_size) 112 | if is_cross_attention: 113 | self.key = nn.Linear(config.encoder_width, self.all_head_size) 114 | self.value = nn.Linear(config.encoder_width, self.all_head_size) 115 | else: 116 | self.key = nn.Linear(config.hidden_size, self.all_head_size) 117 | self.value = nn.Linear(config.hidden_size, self.all_head_size) 118 | 119 | self.dropout = nn.Dropout(config.attention_probs_dropout_prob) 120 | self.position_embedding_type = getattr(config, "position_embedding_type", "absolute") 121 | if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query": 122 | self.max_position_embeddings = config.max_position_embeddings 123 | self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size) 124 | self.save_attention = False 125 | 126 | def save_attn_gradients(self, attn_gradients): 127 | self.attn_gradients = attn_gradients 128 | 129 | def get_attn_gradients(self): 130 | return self.attn_gradients 131 | 132 | def save_attention_map(self, attention_map): 133 | self.attention_map = attention_map 134 | 135 | def get_attention_map(self): 136 | return self.attention_map 137 | 138 | def transpose_for_scores(self, x): 139 | new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size) 140 | x = x.view(*new_x_shape) 141 | return x.permute(0, 2, 1, 3) 142 | 143 | def forward( 144 | self, 145 | hidden_states, 146 | attention_mask=None, 147 | head_mask=None, 148 | encoder_hidden_states=None, 149 | encoder_attention_mask=None, 150 | past_key_value=None, 151 | output_attentions=False, 152 | ): 153 | mixed_query_layer = self.query(hidden_states) 154 | 155 | # If this is instantiated as a cross-attention module, the keys 156 | # and values come from an encoder; the attention mask needs to be 157 | # such that the encoder's padding tokens are not attended to. 158 | is_cross_attention = encoder_hidden_states is not None 159 | 160 | if is_cross_attention: 161 | key_layer = self.transpose_for_scores(self.key(encoder_hidden_states)) 162 | value_layer = self.transpose_for_scores(self.value(encoder_hidden_states)) 163 | attention_mask = encoder_attention_mask 164 | elif past_key_value is not None: 165 | key_layer = self.transpose_for_scores(self.key(hidden_states)) 166 | value_layer = self.transpose_for_scores(self.value(hidden_states)) 167 | key_layer = torch.cat([past_key_value[0], key_layer], dim=2) 168 | value_layer = torch.cat([past_key_value[1], value_layer], dim=2) 169 | else: 170 | key_layer = self.transpose_for_scores(self.key(hidden_states)) 171 | value_layer = self.transpose_for_scores(self.value(hidden_states)) 172 | 173 | query_layer = self.transpose_for_scores(mixed_query_layer) 174 | 175 | past_key_value = (key_layer, value_layer) 176 | 177 | # Take the dot product between "query" and "key" to get the raw attention scores. 
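# query_layer / key_layer have shape (batch, num_heads, seq_len, head_size), so the matmul
# below produces raw scores of shape (batch, num_heads, query_len, key_len); they are scaled
# by 1/sqrt(head_size) and shifted by the additive attention mask further down, before the softmax.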
178 | attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) 179 | 180 | if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query": 181 | seq_length = hidden_states.size()[1] 182 | position_ids_l = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(-1, 1) 183 | position_ids_r = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(1, -1) 184 | distance = position_ids_l - position_ids_r 185 | positional_embedding = self.distance_embedding(distance + self.max_position_embeddings - 1) 186 | positional_embedding = positional_embedding.to(dtype=query_layer.dtype) # fp16 compatibility 187 | 188 | if self.position_embedding_type == "relative_key": 189 | relative_position_scores = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding) 190 | attention_scores = attention_scores + relative_position_scores 191 | elif self.position_embedding_type == "relative_key_query": 192 | relative_position_scores_query = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding) 193 | relative_position_scores_key = torch.einsum("bhrd,lrd->bhlr", key_layer, positional_embedding) 194 | attention_scores = attention_scores + relative_position_scores_query + relative_position_scores_key 195 | 196 | attention_scores = attention_scores / math.sqrt(self.attention_head_size) 197 | if attention_mask is not None: 198 | # Apply the attention mask is (precomputed for all layers in BertModel forward() function) 199 | attention_scores = attention_scores + attention_mask 200 | 201 | # Normalize the attention scores to probabilities. 202 | attention_probs = nn.Softmax(dim=-1)(attention_scores) 203 | 204 | if is_cross_attention and self.save_attention: 205 | self.save_attention_map(attention_probs) 206 | attention_probs.register_hook(self.save_attn_gradients) 207 | 208 | # This is actually dropping out entire tokens to attend to, which might 209 | # seem a bit unusual, but is taken from the original Transformer paper. 
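# The dropped-out probabilities are only used to weight the value vectors below; the
# un-dropped `attention_probs` is what was saved above for visualization and what is
# returned when `output_attentions=True`.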
210 | attention_probs_dropped = self.dropout(attention_probs) 211 | 212 | # Mask heads if we want to 213 | if head_mask is not None: 214 | attention_probs_dropped = attention_probs_dropped * head_mask 215 | 216 | context_layer = torch.matmul(attention_probs_dropped, value_layer) 217 | 218 | context_layer = context_layer.permute(0, 2, 1, 3).contiguous() 219 | new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) 220 | context_layer = context_layer.view(*new_context_layer_shape) 221 | 222 | outputs = (context_layer, attention_probs) if output_attentions else (context_layer,) 223 | 224 | outputs = outputs + (past_key_value,) 225 | return outputs 226 | 227 | 228 | class BertSelfOutput(nn.Module): 229 | def __init__(self, config): 230 | super().__init__() 231 | self.dense = nn.Linear(config.hidden_size, config.hidden_size) 232 | self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) 233 | self.dropout = nn.Dropout(config.hidden_dropout_prob) 234 | 235 | def forward(self, hidden_states, input_tensor): 236 | hidden_states = self.dense(hidden_states) 237 | hidden_states = self.dropout(hidden_states) 238 | hidden_states = self.LayerNorm(hidden_states + input_tensor) 239 | return hidden_states 240 | 241 | 242 | class BertAttention(nn.Module): 243 | def __init__(self, config, is_cross_attention=False): 244 | super().__init__() 245 | self.self = BertSelfAttention(config, is_cross_attention) 246 | self.output = BertSelfOutput(config) 247 | self.pruned_heads = set() 248 | 249 | def prune_heads(self, heads): 250 | if len(heads) == 0: 251 | return 252 | heads, index = find_pruneable_heads_and_indices( 253 | heads, self.self.num_attention_heads, self.self.attention_head_size, self.pruned_heads 254 | ) 255 | 256 | # Prune linear layers 257 | self.self.query = prune_linear_layer(self.self.query, index) 258 | self.self.key = prune_linear_layer(self.self.key, index) 259 | self.self.value = prune_linear_layer(self.self.value, index) 260 | self.output.dense = prune_linear_layer(self.output.dense, index, dim=1) 261 | 262 | # Update hyper params and store pruned heads 263 | self.self.num_attention_heads = self.self.num_attention_heads - len(heads) 264 | self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads 265 | self.pruned_heads = self.pruned_heads.union(heads) 266 | 267 | def forward( 268 | self, 269 | hidden_states, 270 | attention_mask=None, 271 | head_mask=None, 272 | encoder_hidden_states=None, 273 | encoder_attention_mask=None, 274 | past_key_value=None, 275 | output_attentions=False, 276 | ): 277 | self_outputs = self.self( 278 | hidden_states, 279 | attention_mask, 280 | head_mask, 281 | encoder_hidden_states, 282 | encoder_attention_mask, 283 | past_key_value, 284 | output_attentions, 285 | ) 286 | attention_output = self.output(self_outputs[0], hidden_states) 287 | outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them 288 | return outputs 289 | 290 | 291 | class BertIntermediate(nn.Module): 292 | def __init__(self, config): 293 | super().__init__() 294 | self.dense = nn.Linear(config.hidden_size, config.intermediate_size) 295 | if isinstance(config.hidden_act, str): 296 | self.intermediate_act_fn = ACT2FN[config.hidden_act] 297 | else: 298 | self.intermediate_act_fn = config.hidden_act 299 | 300 | def forward(self, hidden_states): 301 | hidden_states = self.dense(hidden_states) 302 | hidden_states = self.intermediate_act_fn(hidden_states) 303 | return hidden_states 304 | 305 | 306 | class 
BertOutput(nn.Module): 307 | def __init__(self, config): 308 | super().__init__() 309 | self.dense = nn.Linear(config.intermediate_size, config.hidden_size) 310 | self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) 311 | self.dropout = nn.Dropout(config.hidden_dropout_prob) 312 | 313 | def forward(self, hidden_states, input_tensor): 314 | hidden_states = self.dense(hidden_states) 315 | hidden_states = self.dropout(hidden_states) 316 | hidden_states = self.LayerNorm(hidden_states + input_tensor) 317 | return hidden_states 318 | 319 | 320 | class BertLayer(nn.Module): 321 | def __init__(self, config, layer_num): 322 | super().__init__() 323 | self.config = config 324 | self.chunk_size_feed_forward = config.chunk_size_feed_forward 325 | self.seq_len_dim = 1 326 | self.attention = BertAttention(config) 327 | self.layer_num = layer_num 328 | if self.config.add_cross_attention: 329 | self.crossattention = BertAttention(config, is_cross_attention=self.config.add_cross_attention) 330 | self.intermediate = BertIntermediate(config) 331 | self.output = BertOutput(config) 332 | 333 | def forward( 334 | self, 335 | hidden_states, 336 | attention_mask=None, 337 | head_mask=None, 338 | encoder_hidden_states=None, 339 | encoder_attention_mask=None, 340 | past_key_value=None, 341 | output_attentions=False, 342 | mode=None, 343 | ): 344 | # decoder uni-directional self-attention cached key/values tuple is at positions 1,2 345 | self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None 346 | self_attention_outputs = self.attention( 347 | hidden_states, 348 | attention_mask, 349 | head_mask, 350 | output_attentions=output_attentions, 351 | past_key_value=self_attn_past_key_value, 352 | ) 353 | attention_output = self_attention_outputs[0] 354 | 355 | outputs = self_attention_outputs[1:-1] 356 | present_key_value = self_attention_outputs[-1] 357 | 358 | if mode=='multimodal': 359 | assert encoder_hidden_states is not None, "encoder_hidden_states must be given for cross-attention layers" 360 | 361 | cross_attention_outputs = self.crossattention( 362 | attention_output, 363 | attention_mask, 364 | head_mask, 365 | encoder_hidden_states, 366 | encoder_attention_mask, 367 | output_attentions=output_attentions, 368 | ) 369 | attention_output = cross_attention_outputs[0] 370 | outputs = outputs + cross_attention_outputs[1:-1] # add cross attentions if we output attention weights 371 | layer_output = apply_chunking_to_forward( 372 | self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output 373 | ) 374 | outputs = (layer_output,) + outputs 375 | 376 | outputs = outputs + (present_key_value,) 377 | 378 | return outputs 379 | 380 | def feed_forward_chunk(self, attention_output): 381 | intermediate_output = self.intermediate(attention_output) 382 | layer_output = self.output(intermediate_output, attention_output) 383 | return layer_output 384 | 385 | 386 | class BertEncoder(nn.Module): 387 | def __init__(self, config): 388 | super().__init__() 389 | self.config = config 390 | self.layer = nn.ModuleList([BertLayer(config,i) for i in range(config.num_hidden_layers)]) 391 | self.gradient_checkpointing = False 392 | 393 | def forward( 394 | self, 395 | hidden_states, 396 | attention_mask=None, 397 | head_mask=None, 398 | encoder_hidden_states=None, 399 | encoder_attention_mask=None, 400 | past_key_values=None, 401 | use_cache=None, 402 | output_attentions=False, 403 | output_hidden_states=False, 404 | return_dict=True, 405 | 
mode='multimodal', 406 | ): 407 | all_hidden_states = () if output_hidden_states else None 408 | all_self_attentions = () if output_attentions else None 409 | all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None 410 | 411 | next_decoder_cache = () if use_cache else None 412 | 413 | for i in range(self.config.num_hidden_layers): 414 | layer_module = self.layer[i] 415 | if output_hidden_states: 416 | all_hidden_states = all_hidden_states + (hidden_states,) 417 | 418 | layer_head_mask = head_mask[i] if head_mask is not None else None 419 | past_key_value = past_key_values[i] if past_key_values is not None else None 420 | 421 | if self.gradient_checkpointing and self.training: 422 | 423 | if use_cache: 424 | logger.warn( 425 | "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." 426 | ) 427 | use_cache = False 428 | 429 | def create_custom_forward(module): 430 | def custom_forward(*inputs): 431 | return module(*inputs, past_key_value, output_attentions) 432 | 433 | return custom_forward 434 | 435 | layer_outputs = torch.utils.checkpoint.checkpoint( 436 | create_custom_forward(layer_module), 437 | hidden_states, 438 | attention_mask, 439 | layer_head_mask, 440 | encoder_hidden_states, 441 | encoder_attention_mask, 442 | mode=mode, 443 | ) 444 | else: 445 | layer_outputs = layer_module( 446 | hidden_states, 447 | attention_mask, 448 | layer_head_mask, 449 | encoder_hidden_states, 450 | encoder_attention_mask, 451 | past_key_value, 452 | output_attentions, 453 | mode=mode, 454 | ) 455 | 456 | hidden_states = layer_outputs[0] 457 | if use_cache: 458 | next_decoder_cache += (layer_outputs[-1],) 459 | if output_attentions: 460 | all_self_attentions = all_self_attentions + (layer_outputs[1],) 461 | 462 | if output_hidden_states: 463 | all_hidden_states = all_hidden_states + (hidden_states,) 464 | 465 | if not return_dict: 466 | return tuple( 467 | v 468 | for v in [ 469 | hidden_states, 470 | next_decoder_cache, 471 | all_hidden_states, 472 | all_self_attentions, 473 | all_cross_attentions, 474 | ] 475 | if v is not None 476 | ) 477 | return BaseModelOutputWithPastAndCrossAttentions( 478 | last_hidden_state=hidden_states, 479 | past_key_values=next_decoder_cache, 480 | hidden_states=all_hidden_states, 481 | attentions=all_self_attentions, 482 | cross_attentions=all_cross_attentions, 483 | ) 484 | 485 | 486 | class BertPooler(nn.Module): 487 | def __init__(self, config): 488 | super().__init__() 489 | self.dense = nn.Linear(config.hidden_size, config.hidden_size) 490 | self.activation = nn.Tanh() 491 | 492 | def forward(self, hidden_states): 493 | # We "pool" the model by simply taking the hidden state corresponding 494 | # to the first token. 
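# `hidden_states` has shape (batch_size, seq_len, hidden_size); indexing position 0 keeps
# the first ([CLS]-style) token embedding, which is then projected and passed through tanh.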
495 | first_token_tensor = hidden_states[:, 0] 496 | pooled_output = self.dense(first_token_tensor) 497 | pooled_output = self.activation(pooled_output) 498 | return pooled_output 499 | 500 | 501 | class BertPredictionHeadTransform(nn.Module): 502 | def __init__(self, config): 503 | super().__init__() 504 | self.dense = nn.Linear(config.hidden_size, config.hidden_size) 505 | if isinstance(config.hidden_act, str): 506 | self.transform_act_fn = ACT2FN[config.hidden_act] 507 | else: 508 | self.transform_act_fn = config.hidden_act 509 | self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) 510 | 511 | def forward(self, hidden_states): 512 | hidden_states = self.dense(hidden_states) 513 | hidden_states = self.transform_act_fn(hidden_states) 514 | hidden_states = self.LayerNorm(hidden_states) 515 | return hidden_states 516 | 517 | 518 | class BertLMPredictionHead(nn.Module): 519 | def __init__(self, config): 520 | super().__init__() 521 | self.transform = BertPredictionHeadTransform(config) 522 | 523 | # The output weights are the same as the input embeddings, but there is 524 | # an output-only bias for each token. 525 | self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False) 526 | 527 | self.bias = nn.Parameter(torch.zeros(config.vocab_size)) 528 | 529 | # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings` 530 | self.decoder.bias = self.bias 531 | 532 | def forward(self, hidden_states): 533 | hidden_states = self.transform(hidden_states) 534 | hidden_states = self.decoder(hidden_states) 535 | return hidden_states 536 | 537 | 538 | class BertOnlyMLMHead(nn.Module): 539 | def __init__(self, config): 540 | super().__init__() 541 | self.predictions = BertLMPredictionHead(config) 542 | 543 | def forward(self, sequence_output): 544 | prediction_scores = self.predictions(sequence_output) 545 | return prediction_scores 546 | 547 | 548 | class BertPreTrainedModel(PreTrainedModel): 549 | """ 550 | An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained 551 | models. 552 | """ 553 | 554 | config_class = BertConfig 555 | base_model_prefix = "bert" 556 | _keys_to_ignore_on_load_missing = [r"position_ids"] 557 | 558 | def _init_weights(self, module): 559 | """ Initialize the weights """ 560 | if isinstance(module, (nn.Linear, nn.Embedding)): 561 | # Slightly different from the TF version which uses truncated_normal for initialization 562 | # cf https://github.com/pytorch/pytorch/pull/5617 563 | module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) 564 | elif isinstance(module, nn.LayerNorm): 565 | module.bias.data.zero_() 566 | module.weight.data.fill_(1.0) 567 | if isinstance(module, nn.Linear) and module.bias is not None: 568 | module.bias.data.zero_() 569 | 570 | 571 | class BertModel(BertPreTrainedModel): 572 | """ 573 | The model can behave as an encoder (with only self-attention) as well as a decoder, in which case a layer of 574 | cross-attention is added between the self-attention layers, following the architecture described in `Attention is 575 | all you need `__ by Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, 576 | Llion Jones, Aidan N. Gomez, Lukasz Kaiser and Illia Polosukhin. 577 | argument and :obj:`add_cross_attention` set to :obj:`True`; an :obj:`encoder_hidden_states` is then expected as an 578 | input to the forward pass. 
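    In this adaptation, decoder behaviour is selected per call through the :obj:`is_decoder`
    argument of :meth:`forward`, and cross-attention layers are only created when
    :obj:`config.add_cross_attention` is :obj:`True`.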
579 | """ 580 | 581 | def __init__(self, config, add_pooling_layer=True): 582 | super().__init__(config) 583 | self.config = config 584 | 585 | self.embeddings = BertEmbeddings(config) 586 | 587 | self.encoder = BertEncoder(config) 588 | 589 | self.pooler = BertPooler(config) if add_pooling_layer else None 590 | 591 | self.init_weights() 592 | 593 | 594 | def get_input_embeddings(self): 595 | return self.embeddings.word_embeddings 596 | 597 | def set_input_embeddings(self, value): 598 | self.embeddings.word_embeddings = value 599 | 600 | def _prune_heads(self, heads_to_prune): 601 | """ 602 | Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base 603 | class PreTrainedModel 604 | """ 605 | for layer, heads in heads_to_prune.items(): 606 | self.encoder.layer[layer].attention.prune_heads(heads) 607 | 608 | 609 | def get_extended_attention_mask(self, attention_mask: Tensor, input_shape: Tuple[int], device: device, is_decoder: bool) -> Tensor: 610 | """ 611 | Makes broadcastable attention and causal masks so that future and masked tokens are ignored. 612 | 613 | Arguments: 614 | attention_mask (:obj:`torch.Tensor`): 615 | Mask with ones indicating tokens to attend to, zeros for tokens to ignore. 616 | input_shape (:obj:`Tuple[int]`): 617 | The shape of the input to the model. 618 | device: (:obj:`torch.device`): 619 | The device of the input to the model. 620 | 621 | Returns: 622 | :obj:`torch.Tensor` The extended attention mask, with a the same dtype as :obj:`attention_mask.dtype`. 623 | """ 624 | # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length] 625 | # ourselves in which case we just need to make it broadcastable to all heads. 626 | if attention_mask.dim() == 3: 627 | extended_attention_mask = attention_mask[:, None, :, :] 628 | elif attention_mask.dim() == 2: 629 | # Provided a padding mask of dimensions [batch_size, seq_length] 630 | # - if the model is a decoder, apply a causal mask in addition to the padding mask 631 | # - if the model is an encoder, make the mask broadcastable to [batch_size, num_heads, seq_length, seq_length] 632 | if is_decoder: 633 | batch_size, seq_length = input_shape 634 | 635 | seq_ids = torch.arange(seq_length, device=device) 636 | causal_mask = seq_ids[None, None, :].repeat(batch_size, seq_length, 1) <= seq_ids[None, :, None] 637 | # in case past_key_values are used we need to add a prefix ones mask to the causal mask 638 | # causal and attention masks must have same type with pytorch version < 1.3 639 | causal_mask = causal_mask.to(attention_mask.dtype) 640 | 641 | if causal_mask.shape[1] < attention_mask.shape[1]: 642 | prefix_seq_len = attention_mask.shape[1] - causal_mask.shape[1] 643 | causal_mask = torch.cat( 644 | [ 645 | torch.ones((batch_size, seq_length, prefix_seq_len), device=device, dtype=causal_mask.dtype), 646 | causal_mask, 647 | ], 648 | axis=-1, 649 | ) 650 | 651 | extended_attention_mask = causal_mask[:, None, :, :] * attention_mask[:, None, None, :] 652 | else: 653 | extended_attention_mask = attention_mask[:, None, None, :] 654 | else: 655 | raise ValueError( 656 | "Wrong shape for input_ids (shape {}) or attention_mask (shape {})".format( 657 | input_shape, attention_mask.shape 658 | ) 659 | ) 660 | 661 | # Since attention_mask is 1.0 for positions we want to attend and 0.0 for 662 | # masked positions, this operation will create a tensor which is 0.0 for 663 | # positions we want to attend and -10000.0 for masked positions. 
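# For example, a 2D padding mask [1, 1, 0] for one sequence becomes [0.0, 0.0, -10000.0]
# after the (1.0 - mask) * -10000.0 transform applied below.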
664 | # Since we are adding it to the raw scores before the softmax, this is 665 | # effectively the same as removing these entirely. 666 | extended_attention_mask = extended_attention_mask.to(dtype=self.dtype) # fp16 compatibility 667 | extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0 668 | return extended_attention_mask 669 | 670 | def forward( 671 | self, 672 | input_ids=None, 673 | attention_mask=None, 674 | position_ids=None, 675 | head_mask=None, 676 | inputs_embeds=None, 677 | encoder_embeds=None, 678 | encoder_hidden_states=None, 679 | encoder_attention_mask=None, 680 | past_key_values=None, 681 | use_cache=None, 682 | output_attentions=None, 683 | output_hidden_states=None, 684 | return_dict=None, 685 | is_decoder=False, 686 | mode='multimodal', 687 | ): 688 | r""" 689 | encoder_hidden_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`): 690 | Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if 691 | the model is configured as a decoder. 692 | encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): 693 | Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in 694 | the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``: 695 | - 1 for tokens that are **not masked**, 696 | - 0 for tokens that are **masked**. 697 | past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`): 698 | Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding. 699 | If :obj:`past_key_values` are used, the user can optionally input only the last :obj:`decoder_input_ids` 700 | (those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)` 701 | instead of all :obj:`decoder_input_ids` of shape :obj:`(batch_size, sequence_length)`. 702 | use_cache (:obj:`bool`, `optional`): 703 | If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up 704 | decoding (see :obj:`past_key_values`). 
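        mode (:obj:`str`, `optional`, defaults to :obj:`'multimodal'`):
            If set to :obj:`'text'`, the layers run text-only self-attention; if set to
            :obj:`'multimodal'`, cross-attention over :obj:`encoder_hidden_states` (e.g. image
            features) is applied in addition.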
705 | """ 706 | output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions 707 | output_hidden_states = ( 708 | output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states 709 | ) 710 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict 711 | 712 | if is_decoder: 713 | use_cache = use_cache if use_cache is not None else self.config.use_cache 714 | else: 715 | use_cache = False 716 | 717 | if input_ids is not None and inputs_embeds is not None: 718 | raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") 719 | elif input_ids is not None: 720 | input_shape = input_ids.size() 721 | batch_size, seq_length = input_shape 722 | device = input_ids.device 723 | elif inputs_embeds is not None: 724 | input_shape = inputs_embeds.size()[:-1] 725 | batch_size, seq_length = input_shape 726 | device = inputs_embeds.device 727 | elif encoder_embeds is not None: 728 | input_shape = encoder_embeds.size()[:-1] 729 | batch_size, seq_length = input_shape 730 | device = encoder_embeds.device 731 | else: 732 | raise ValueError("You have to specify either input_ids or inputs_embeds or encoder_embeds") 733 | 734 | # past_key_values_length 735 | past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0 736 | 737 | if attention_mask is None: 738 | attention_mask = torch.ones(((batch_size, seq_length + past_key_values_length)), device=device) 739 | 740 | # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length] 741 | # ourselves in which case we just need to make it broadcastable to all heads. 742 | extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape, 743 | device, is_decoder) 744 | 745 | # If a 2D or 3D attention mask is provided for the cross-attention 746 | # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length] 747 | if encoder_hidden_states is not None: 748 | if type(encoder_hidden_states) == list: 749 | encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states[0].size() 750 | else: 751 | encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size() 752 | encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length) 753 | 754 | if type(encoder_attention_mask) == list: 755 | encoder_extended_attention_mask = [self.invert_attention_mask(mask) for mask in encoder_attention_mask] 756 | elif encoder_attention_mask is None: 757 | encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device) 758 | encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask) 759 | else: 760 | encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask) 761 | else: 762 | encoder_extended_attention_mask = None 763 | 764 | # Prepare head mask if needed 765 | # 1.0 in head_mask indicate we keep the head 766 | # attention_probs has shape bsz x n_heads x N x N 767 | # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads] 768 | # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length] 769 | head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers) 770 | 771 | if encoder_embeds is None: 772 | embedding_output = self.embeddings( 773 | input_ids=input_ids, 774 | position_ids=position_ids, 775 | inputs_embeds=inputs_embeds, 776 | past_key_values_length=past_key_values_length, 777 
| ) 778 | else: 779 | embedding_output = encoder_embeds 780 | 781 | encoder_outputs = self.encoder( 782 | embedding_output, 783 | attention_mask=extended_attention_mask, 784 | head_mask=head_mask, 785 | encoder_hidden_states=encoder_hidden_states, 786 | encoder_attention_mask=encoder_extended_attention_mask, 787 | past_key_values=past_key_values, 788 | use_cache=use_cache, 789 | output_attentions=output_attentions, 790 | output_hidden_states=output_hidden_states, 791 | return_dict=return_dict, 792 | mode=mode, 793 | ) 794 | sequence_output = encoder_outputs[0] 795 | pooled_output = self.pooler(sequence_output) if self.pooler is not None else None 796 | 797 | if not return_dict: 798 | return (sequence_output, pooled_output) + encoder_outputs[1:] 799 | 800 | return BaseModelOutputWithPoolingAndCrossAttentions( 801 | last_hidden_state=sequence_output, 802 | pooler_output=pooled_output, 803 | past_key_values=encoder_outputs.past_key_values, 804 | hidden_states=encoder_outputs.hidden_states, 805 | attentions=encoder_outputs.attentions, 806 | cross_attentions=encoder_outputs.cross_attentions, 807 | ) 808 | 809 | 810 | 811 | class BertLMHeadModel(BertPreTrainedModel): 812 | 813 | _keys_to_ignore_on_load_unexpected = [r"pooler"] 814 | _keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"] 815 | 816 | def __init__(self, config): 817 | super().__init__(config) 818 | 819 | self.bert = BertModel(config, add_pooling_layer=False) 820 | self.cls = BertOnlyMLMHead(config) 821 | 822 | self.init_weights() 823 | 824 | def get_output_embeddings(self): 825 | return self.cls.predictions.decoder 826 | 827 | def set_output_embeddings(self, new_embeddings): 828 | self.cls.predictions.decoder = new_embeddings 829 | 830 | def forward( 831 | self, 832 | input_ids=None, 833 | attention_mask=None, 834 | position_ids=None, 835 | head_mask=None, 836 | inputs_embeds=None, 837 | encoder_hidden_states=None, 838 | encoder_attention_mask=None, 839 | labels=None, 840 | past_key_values=None, 841 | use_cache=None, 842 | output_attentions=None, 843 | output_hidden_states=None, 844 | return_dict=None, 845 | return_logits=False, 846 | is_decoder=True, 847 | reduction='mean', 848 | mode='multimodal', 849 | ): 850 | r""" 851 | encoder_hidden_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`): 852 | Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if 853 | the model is configured as a decoder. 854 | encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): 855 | Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in 856 | the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``: 857 | - 1 for tokens that are **not masked**, 858 | - 0 for tokens that are **masked**. 859 | labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): 860 | Labels for computing the left-to-right language modeling loss (next word prediction). 
Indices should be in 861 | ``[-100, 0, ..., config.vocab_size]`` (see ``input_ids`` docstring) Tokens with indices set to ``-100`` are 862 | ignored (masked), the loss is only computed for the tokens with labels n ``[0, ..., config.vocab_size]`` 863 | past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`): 864 | Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding. 865 | If :obj:`past_key_values` are used, the user can optionally input only the last :obj:`decoder_input_ids` 866 | (those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)` 867 | instead of all :obj:`decoder_input_ids` of shape :obj:`(batch_size, sequence_length)`. 868 | use_cache (:obj:`bool`, `optional`): 869 | If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up 870 | decoding (see :obj:`past_key_values`). 871 | Returns: 872 | Example:: 873 | >>> from transformers import BertTokenizer, BertLMHeadModel, BertConfig 874 | >>> import torch 875 | >>> tokenizer = BertTokenizer.from_pretrained('bert-base-cased') 876 | >>> config = BertConfig.from_pretrained("bert-base-cased") 877 | >>> model = BertLMHeadModel.from_pretrained('bert-base-cased', config=config) 878 | >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt") 879 | >>> outputs = model(**inputs) 880 | >>> prediction_logits = outputs.logits 881 | """ 882 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict 883 | if labels is not None: 884 | use_cache = False 885 | 886 | outputs = self.bert( 887 | input_ids, 888 | attention_mask=attention_mask, 889 | position_ids=position_ids, 890 | head_mask=head_mask, 891 | inputs_embeds=inputs_embeds, 892 | encoder_hidden_states=encoder_hidden_states, 893 | encoder_attention_mask=encoder_attention_mask, 894 | past_key_values=past_key_values, 895 | use_cache=use_cache, 896 | output_attentions=output_attentions, 897 | output_hidden_states=output_hidden_states, 898 | return_dict=return_dict, 899 | is_decoder=is_decoder, 900 | mode=mode, 901 | ) 902 | 903 | sequence_output = outputs[0] 904 | prediction_scores = self.cls(sequence_output) 905 | 906 | if return_logits: 907 | return prediction_scores[:, :-1, :].contiguous() 908 | 909 | lm_loss = None 910 | if labels is not None: 911 | # we are doing next-token prediction; shift prediction scores and input ids by one 912 | shifted_prediction_scores = prediction_scores[:, :-1, :].contiguous() 913 | labels = labels[:, 1:].contiguous() 914 | loss_fct = CrossEntropyLoss(reduction=reduction, label_smoothing=0.1) 915 | lm_loss = loss_fct(shifted_prediction_scores.view(-1, self.config.vocab_size), labels.view(-1)) 916 | if reduction=='none': 917 | lm_loss = lm_loss.view(prediction_scores.size(0),-1).sum(1) 918 | 919 | if not return_dict: 920 | output = (prediction_scores,) + outputs[2:] 921 | return ((lm_loss,) + output) if lm_loss is not None else output 922 | 923 | return CausalLMOutputWithCrossAttentions( 924 | loss=lm_loss, 925 | logits=prediction_scores, 926 | past_key_values=outputs.past_key_values, 927 | hidden_states=outputs.hidden_states, 928 | attentions=outputs.attentions, 929 | cross_attentions=outputs.cross_attentions, 930 | ) 931 | 932 | def prepare_inputs_for_generation(self, input_ids, past=None, attention_mask=None, **model_kwargs): 933 | 
input_shape = input_ids.shape 934 | # if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly 935 | if attention_mask is None: 936 | attention_mask = input_ids.new_ones(input_shape) 937 | 938 | # cut decoder_input_ids if past is used 939 | if past is not None: 940 | input_ids = input_ids[:, -1:] 941 | 942 | return { 943 | "input_ids": input_ids, 944 | "attention_mask": attention_mask, 945 | "past_key_values": past, 946 | "encoder_hidden_states": model_kwargs.get("encoder_hidden_states", None), 947 | "encoder_attention_mask": model_kwargs.get("encoder_attention_mask", None), 948 | "is_decoder": True, 949 | } 950 | 951 | def _reorder_cache(self, past, beam_idx): 952 | reordered_past = () 953 | for layer_past in past: 954 | reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),) 955 | return reordered_past 956 | -------------------------------------------------------------------------------- /spec/models/blip_utils/utils.py: -------------------------------------------------------------------------------- 1 | import math 2 | def cosine_lr_schedule(optimizer, epoch, max_epoch, init_lr, min_lr): 3 | """Decay the learning rate""" 4 | lr = (init_lr - min_lr) * 0.5 * (1. + math.cos(math.pi * epoch / max_epoch)) + min_lr 5 | for param_group in optimizer.param_groups: 6 | param_group['lr'] = lr 7 | 8 | def warmup_lr_schedule(optimizer, step, max_step, init_lr, max_lr): 9 | """Warmup the learning rate""" 10 | lr = min(max_lr, init_lr + (max_lr - init_lr) * step / max_step) 11 | for param_group in optimizer.param_groups: 12 | param_group['lr'] = lr 13 | 14 | def step_lr_schedule(optimizer, epoch, init_lr, min_lr, decay_rate): 15 | """Decay the learning rate""" 16 | lr = max(min_lr, init_lr * (decay_rate**epoch)) 17 | for param_group in optimizer.param_groups: 18 | param_group['lr'] = lr 19 | 20 | import numpy as np 21 | import io 22 | import os 23 | import time 24 | from collections import defaultdict, deque 25 | import datetime 26 | 27 | import torch 28 | import torch.distributed as dist 29 | 30 | class SmoothedValue(object): 31 | """Track a series of values and provide access to smoothed values over a 32 | window or the global series average. 33 | """ 34 | 35 | def __init__(self, window_size=20, fmt=None): 36 | if fmt is None: 37 | fmt = "{median:.4f} ({global_avg:.4f})" 38 | self.deque = deque(maxlen=window_size) 39 | self.total = 0.0 40 | self.count = 0 41 | self.fmt = fmt 42 | 43 | def update(self, value, n=1): 44 | self.deque.append(value) 45 | self.count += n 46 | self.total += value * n 47 | 48 | def synchronize_between_processes(self): 49 | """ 50 | Warning: does not synchronize the deque! 
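        Only `count` and `total` are all-reduced across processes, so `global_avg` becomes
        a global statistic while `median`, `avg`, `max` and `value` stay rank-local.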
51 | """ 52 | if not is_dist_avail_and_initialized(): 53 | return 54 | t = torch.tensor([self.count, self.total], dtype=torch.float64, device='cuda') 55 | dist.barrier() 56 | dist.all_reduce(t) 57 | t = t.tolist() 58 | self.count = int(t[0]) 59 | self.total = t[1] 60 | 61 | @property 62 | def median(self): 63 | d = torch.tensor(list(self.deque)) 64 | return d.median().item() 65 | 66 | @property 67 | def avg(self): 68 | d = torch.tensor(list(self.deque), dtype=torch.float32) 69 | return d.mean().item() 70 | 71 | @property 72 | def global_avg(self): 73 | return self.total / self.count 74 | 75 | @property 76 | def max(self): 77 | return max(self.deque) 78 | 79 | @property 80 | def value(self): 81 | return self.deque[-1] 82 | 83 | def __str__(self): 84 | return self.fmt.format( 85 | median=self.median, 86 | avg=self.avg, 87 | global_avg=self.global_avg, 88 | max=self.max, 89 | value=self.value) 90 | 91 | 92 | class MetricLogger(object): 93 | def __init__(self, delimiter="\t"): 94 | self.meters = defaultdict(SmoothedValue) 95 | self.delimiter = delimiter 96 | 97 | def update(self, **kwargs): 98 | for k, v in kwargs.items(): 99 | if isinstance(v, torch.Tensor): 100 | v = v.item() 101 | assert isinstance(v, (float, int)) 102 | self.meters[k].update(v) 103 | 104 | def __getattr__(self, attr): 105 | if attr in self.meters: 106 | return self.meters[attr] 107 | if attr in self.__dict__: 108 | return self.__dict__[attr] 109 | raise AttributeError("'{}' object has no attribute '{}'".format( 110 | type(self).__name__, attr)) 111 | 112 | def __str__(self): 113 | loss_str = [] 114 | for name, meter in self.meters.items(): 115 | loss_str.append( 116 | "{}: {}".format(name, str(meter)) 117 | ) 118 | return self.delimiter.join(loss_str) 119 | 120 | def global_avg(self): 121 | loss_str = [] 122 | for name, meter in self.meters.items(): 123 | loss_str.append( 124 | "{}: {:.4f}".format(name, meter.global_avg) 125 | ) 126 | return self.delimiter.join(loss_str) 127 | 128 | def synchronize_between_processes(self): 129 | for meter in self.meters.values(): 130 | meter.synchronize_between_processes() 131 | 132 | def add_meter(self, name, meter): 133 | self.meters[name] = meter 134 | 135 | def log_every(self, iterable, print_freq, header=None): 136 | i = 0 137 | if not header: 138 | header = '' 139 | start_time = time.time() 140 | end = time.time() 141 | iter_time = SmoothedValue(fmt='{avg:.4f}') 142 | data_time = SmoothedValue(fmt='{avg:.4f}') 143 | space_fmt = ':' + str(len(str(len(iterable)))) + 'd' 144 | log_msg = [ 145 | header, 146 | '[{0' + space_fmt + '}/{1}]', 147 | 'eta: {eta}', 148 | '{meters}', 149 | 'time: {time}', 150 | 'data: {data}' 151 | ] 152 | if torch.cuda.is_available(): 153 | log_msg.append('max mem: {memory:.0f}') 154 | log_msg = self.delimiter.join(log_msg) 155 | MB = 1024.0 * 1024.0 156 | for obj in iterable: 157 | data_time.update(time.time() - end) 158 | yield obj 159 | iter_time.update(time.time() - end) 160 | if i % print_freq == 0 or i == len(iterable) - 1: 161 | eta_seconds = iter_time.global_avg * (len(iterable) - i) 162 | eta_string = str(datetime.timedelta(seconds=int(eta_seconds))) 163 | if torch.cuda.is_available(): 164 | print(log_msg.format( 165 | i, len(iterable), eta=eta_string, 166 | meters=str(self), 167 | time=str(iter_time), data=str(data_time), 168 | memory=torch.cuda.max_memory_allocated() / MB)) 169 | else: 170 | print(log_msg.format( 171 | i, len(iterable), eta=eta_string, 172 | meters=str(self), 173 | time=str(iter_time), data=str(data_time))) 174 | i += 1 175 | end = 
time.time() 176 | total_time = time.time() - start_time 177 | total_time_str = str(datetime.timedelta(seconds=int(total_time))) 178 | print('{} Total time: {} ({:.4f} s / it)'.format( 179 | header, total_time_str, total_time / len(iterable))) 180 | 181 | 182 | class AttrDict(dict): 183 | def __init__(self, *args, **kwargs): 184 | super(AttrDict, self).__init__(*args, **kwargs) 185 | self.__dict__ = self 186 | 187 | 188 | def compute_acc(logits, label, reduction='mean'): 189 | ret = (torch.argmax(logits, dim=1) == label).float() 190 | if reduction == 'none': 191 | return ret.detach() 192 | elif reduction == 'mean': 193 | return ret.mean().item() 194 | 195 | def compute_n_params(model, return_str=True): 196 | tot = 0 197 | for p in model.parameters(): 198 | w = 1 199 | for x in p.shape: 200 | w *= x 201 | tot += w 202 | if return_str: 203 | if tot >= 1e6: 204 | return '{:.1f}M'.format(tot / 1e6) 205 | else: 206 | return '{:.1f}K'.format(tot / 1e3) 207 | else: 208 | return tot 209 | 210 | def setup_for_distributed(is_master): 211 | """ 212 | This function disables printing when not in master process 213 | """ 214 | import builtins as __builtin__ 215 | builtin_print = __builtin__.print 216 | 217 | def print(*args, **kwargs): 218 | force = kwargs.pop('force', False) 219 | if is_master or force: 220 | builtin_print(*args, **kwargs) 221 | 222 | __builtin__.print = print 223 | 224 | 225 | def is_dist_avail_and_initialized(): 226 | if not dist.is_available(): 227 | return False 228 | if not dist.is_initialized(): 229 | return False 230 | return True 231 | 232 | 233 | def get_world_size(): 234 | if not is_dist_avail_and_initialized(): 235 | return 1 236 | return dist.get_world_size() 237 | 238 | 239 | def get_rank(): 240 | if not is_dist_avail_and_initialized(): 241 | return 0 242 | return dist.get_rank() 243 | 244 | 245 | def is_main_process(): 246 | return get_rank() == 0 247 | 248 | 249 | def save_on_master(*args, **kwargs): 250 | if is_main_process(): 251 | torch.save(*args, **kwargs) 252 | 253 | 254 | def init_distributed_mode(args): 255 | if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ: 256 | args.rank = int(os.environ["RANK"]) 257 | args.world_size = int(os.environ['WORLD_SIZE']) 258 | args.gpu = int(os.environ['LOCAL_RANK']) 259 | elif 'SLURM_PROCID' in os.environ: 260 | args.rank = int(os.environ['SLURM_PROCID']) 261 | args.gpu = args.rank % torch.cuda.device_count() 262 | else: 263 | print('Not using distributed mode') 264 | args.distributed = False 265 | return 266 | 267 | args.distributed = True 268 | 269 | torch.cuda.set_device(args.gpu) 270 | args.dist_backend = 'nccl' 271 | print('| distributed init (rank {}, word {}): {}'.format( 272 | args.rank, args.world_size, args.dist_url), flush=True) 273 | torch.distributed.init_process_group(backend=args.dist_backend, init_method=args.dist_url, 274 | world_size=args.world_size, rank=args.rank) 275 | torch.distributed.barrier() 276 | setup_for_distributed(args.rank == 0) 277 | 278 | -------------------------------------------------------------------------------- /spec/models/blip_utils/vit.py: -------------------------------------------------------------------------------- 1 | ''' 2 | * Copyright (c) 2022, salesforce.com, inc. 3 | * All rights reserved. 
4 | * SPDX-License-Identifier: BSD-3-Clause 5 | * For full license text, see LICENSE.txt file in the repo root or https://opensource.org/licenses/BSD-3-Clause 6 | * By Junnan Li 7 | * Based on timm code base 8 | * https://github.com/rwightman/pytorch-image-models/tree/master/timm 9 | ''' 10 | 11 | import torch 12 | import torch.nn as nn 13 | import torch.nn.functional as F 14 | from functools import partial 15 | 16 | from timm.models.vision_transformer import _cfg, PatchEmbed 17 | from timm.models.registry import register_model 18 | from timm.models.layers import trunc_normal_, DropPath 19 | from timm.models.helpers import named_apply, adapt_input_conv 20 | 21 | from fairscale.nn.checkpoint.checkpoint_activations import checkpoint_wrapper 22 | 23 | class Mlp(nn.Module): 24 | """ MLP as used in Vision Transformer, MLP-Mixer and related networks 25 | """ 26 | def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): 27 | super().__init__() 28 | out_features = out_features or in_features 29 | hidden_features = hidden_features or in_features 30 | self.fc1 = nn.Linear(in_features, hidden_features) 31 | self.act = act_layer() 32 | self.fc2 = nn.Linear(hidden_features, out_features) 33 | self.drop = nn.Dropout(drop) 34 | 35 | def forward(self, x): 36 | x = self.fc1(x) 37 | x = self.act(x) 38 | x = self.drop(x) 39 | x = self.fc2(x) 40 | x = self.drop(x) 41 | return x 42 | 43 | 44 | class Attention(nn.Module): 45 | def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.): 46 | super().__init__() 47 | self.num_heads = num_heads 48 | head_dim = dim // num_heads 49 | # NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights 50 | self.scale = qk_scale or head_dim ** -0.5 51 | self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) 52 | self.attn_drop = nn.Dropout(attn_drop) 53 | self.proj = nn.Linear(dim, dim) 54 | self.proj_drop = nn.Dropout(proj_drop) 55 | self.attn_gradients = None 56 | self.attention_map = None 57 | 58 | def save_attn_gradients(self, attn_gradients): 59 | self.attn_gradients = attn_gradients 60 | 61 | def get_attn_gradients(self): 62 | return self.attn_gradients 63 | 64 | def save_attention_map(self, attention_map): 65 | self.attention_map = attention_map 66 | 67 | def get_attention_map(self): 68 | return self.attention_map 69 | 70 | def forward(self, x, register_hook=False): 71 | B, N, C = x.shape 72 | qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) 73 | q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple) 74 | 75 | attn = (q @ k.transpose(-2, -1)) * self.scale 76 | attn = attn.softmax(dim=-1) 77 | attn = self.attn_drop(attn) 78 | 79 | if register_hook: 80 | self.save_attention_map(attn) 81 | attn.register_hook(self.save_attn_gradients) 82 | 83 | x = (attn @ v).transpose(1, 2).reshape(B, N, C) 84 | x = self.proj(x) 85 | x = self.proj_drop(x) 86 | return x 87 | 88 | 89 | class Block(nn.Module): 90 | 91 | def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0., 92 | drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, use_grad_checkpointing=False): 93 | super().__init__() 94 | self.norm1 = norm_layer(dim) 95 | self.attn = Attention( 96 | dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop) 97 | # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here 98 | 
self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() 99 | self.norm2 = norm_layer(dim) 100 | mlp_hidden_dim = int(dim * mlp_ratio) 101 | self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) 102 | 103 | if use_grad_checkpointing: 104 | self.attn = checkpoint_wrapper(self.attn) 105 | self.mlp = checkpoint_wrapper(self.mlp) 106 | 107 | def forward(self, x, register_hook=False): 108 | x = x + self.drop_path(self.attn(self.norm1(x), register_hook=register_hook)) 109 | x = x + self.drop_path(self.mlp(self.norm2(x))) 110 | return x 111 | 112 | 113 | class VisionTransformer(nn.Module): 114 | """ Vision Transformer 115 | A PyTorch impl of : `An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale` - 116 | https://arxiv.org/abs/2010.11929 117 | """ 118 | def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dim=768, depth=12, 119 | num_heads=12, mlp_ratio=4., qkv_bias=True, qk_scale=None, representation_size=None, 120 | drop_rate=0., attn_drop_rate=0., drop_path_rate=0., norm_layer=None, 121 | use_grad_checkpointing=False, ckpt_layer=0): 122 | """ 123 | Args: 124 | img_size (int, tuple): input image size 125 | patch_size (int, tuple): patch size 126 | in_chans (int): number of input channels 127 | num_classes (int): number of classes for classification head 128 | embed_dim (int): embedding dimension 129 | depth (int): depth of transformer 130 | num_heads (int): number of attention heads 131 | mlp_ratio (int): ratio of mlp hidden dim to embedding dim 132 | qkv_bias (bool): enable bias for qkv if True 133 | qk_scale (float): override default qk scale of head_dim ** -0.5 if set 134 | representation_size (Optional[int]): enable and set representation layer (pre-logits) to this value if set 135 | drop_rate (float): dropout rate 136 | attn_drop_rate (float): attention dropout rate 137 | drop_path_rate (float): stochastic depth rate 138 | norm_layer: (nn.Module): normalization layer 139 | """ 140 | super().__init__() 141 | self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models 142 | norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6) 143 | 144 | self.patch_embed = PatchEmbed( 145 | img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim) 146 | 147 | num_patches = self.patch_embed.num_patches 148 | 149 | self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) 150 | self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim)) 151 | self.pos_drop = nn.Dropout(p=drop_rate) 152 | 153 | dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule 154 | self.blocks = nn.ModuleList([ 155 | Block( 156 | dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, 157 | drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer, 158 | use_grad_checkpointing=(use_grad_checkpointing and i>=depth-ckpt_layer) 159 | ) 160 | for i in range(depth)]) 161 | self.norm = norm_layer(embed_dim) 162 | 163 | trunc_normal_(self.pos_embed, std=.02) 164 | trunc_normal_(self.cls_token, std=.02) 165 | self.apply(self._init_weights) 166 | 167 | def _init_weights(self, m): 168 | if isinstance(m, nn.Linear): 169 | trunc_normal_(m.weight, std=.02) 170 | if isinstance(m, nn.Linear) and m.bias is not None: 171 | nn.init.constant_(m.bias, 0) 172 | elif isinstance(m, nn.LayerNorm): 173 | nn.init.constant_(m.bias, 0) 174 | nn.init.constant_(m.weight, 
1.0) 175 | 176 | @torch.jit.ignore 177 | def no_weight_decay(self): 178 | return {'pos_embed', 'cls_token'} 179 | 180 | def forward(self, x, register_blk=-1): 181 | B = x.shape[0] 182 | x = self.patch_embed(x) 183 | 184 | cls_tokens = self.cls_token.expand(B, -1, -1) # stole cls_tokens impl from Phil Wang, thanks 185 | x = torch.cat((cls_tokens, x), dim=1) 186 | 187 | x = x + self.pos_embed[:,:x.size(1),:] 188 | x = self.pos_drop(x) 189 | 190 | for i,blk in enumerate(self.blocks): 191 | x = blk(x, register_blk==i) 192 | x = self.norm(x) 193 | 194 | return x 195 | 196 | @torch.jit.ignore() 197 | def load_pretrained(self, checkpoint_path, prefix=''): 198 | _load_weights(self, checkpoint_path, prefix) 199 | 200 | 201 | @torch.no_grad() 202 | def _load_weights(model: VisionTransformer, checkpoint_path: str, prefix: str = ''): 203 | """ Load weights from .npz checkpoints for official Google Brain Flax implementation 204 | """ 205 | import numpy as np 206 | 207 | def _n2p(w, t=True): 208 | if w.ndim == 4 and w.shape[0] == w.shape[1] == w.shape[2] == 1: 209 | w = w.flatten() 210 | if t: 211 | if w.ndim == 4: 212 | w = w.transpose([3, 2, 0, 1]) 213 | elif w.ndim == 3: 214 | w = w.transpose([2, 0, 1]) 215 | elif w.ndim == 2: 216 | w = w.transpose([1, 0]) 217 | return torch.from_numpy(w) 218 | 219 | w = np.load(checkpoint_path) 220 | if not prefix and 'opt/target/embedding/kernel' in w: 221 | prefix = 'opt/target/' 222 | 223 | if hasattr(model.patch_embed, 'backbone'): 224 | # hybrid 225 | backbone = model.patch_embed.backbone 226 | stem_only = not hasattr(backbone, 'stem') 227 | stem = backbone if stem_only else backbone.stem 228 | stem.conv.weight.copy_(adapt_input_conv(stem.conv.weight.shape[1], _n2p(w[f'{prefix}conv_root/kernel']))) 229 | stem.norm.weight.copy_(_n2p(w[f'{prefix}gn_root/scale'])) 230 | stem.norm.bias.copy_(_n2p(w[f'{prefix}gn_root/bias'])) 231 | if not stem_only: 232 | for i, stage in enumerate(backbone.stages): 233 | for j, block in enumerate(stage.blocks): 234 | bp = f'{prefix}block{i + 1}/unit{j + 1}/' 235 | for r in range(3): 236 | getattr(block, f'conv{r + 1}').weight.copy_(_n2p(w[f'{bp}conv{r + 1}/kernel'])) 237 | getattr(block, f'norm{r + 1}').weight.copy_(_n2p(w[f'{bp}gn{r + 1}/scale'])) 238 | getattr(block, f'norm{r + 1}').bias.copy_(_n2p(w[f'{bp}gn{r + 1}/bias'])) 239 | if block.downsample is not None: 240 | block.downsample.conv.weight.copy_(_n2p(w[f'{bp}conv_proj/kernel'])) 241 | block.downsample.norm.weight.copy_(_n2p(w[f'{bp}gn_proj/scale'])) 242 | block.downsample.norm.bias.copy_(_n2p(w[f'{bp}gn_proj/bias'])) 243 | embed_conv_w = _n2p(w[f'{prefix}embedding/kernel']) 244 | else: 245 | embed_conv_w = adapt_input_conv( 246 | model.patch_embed.proj.weight.shape[1], _n2p(w[f'{prefix}embedding/kernel'])) 247 | model.patch_embed.proj.weight.copy_(embed_conv_w) 248 | model.patch_embed.proj.bias.copy_(_n2p(w[f'{prefix}embedding/bias'])) 249 | model.cls_token.copy_(_n2p(w[f'{prefix}cls'], t=False)) 250 | pos_embed_w = _n2p(w[f'{prefix}Transformer/posembed_input/pos_embedding'], t=False) 251 | if pos_embed_w.shape != model.pos_embed.shape: 252 | pos_embed_w = resize_pos_embed( # resize pos embedding when different size from pretrained weights 253 | pos_embed_w, model.pos_embed, getattr(model, 'num_tokens', 1), model.patch_embed.grid_size) 254 | model.pos_embed.copy_(pos_embed_w) 255 | model.norm.weight.copy_(_n2p(w[f'{prefix}Transformer/encoder_norm/scale'])) 256 | model.norm.bias.copy_(_n2p(w[f'{prefix}Transformer/encoder_norm/bias'])) 257 | # if isinstance(model.head, 
nn.Linear) and model.head.bias.shape[0] == w[f'{prefix}head/bias'].shape[-1]: 258 | # model.head.weight.copy_(_n2p(w[f'{prefix}head/kernel'])) 259 | # model.head.bias.copy_(_n2p(w[f'{prefix}head/bias'])) 260 | # if isinstance(getattr(model.pre_logits, 'fc', None), nn.Linear) and f'{prefix}pre_logits/bias' in w: 261 | # model.pre_logits.fc.weight.copy_(_n2p(w[f'{prefix}pre_logits/kernel'])) 262 | # model.pre_logits.fc.bias.copy_(_n2p(w[f'{prefix}pre_logits/bias'])) 263 | for i, block in enumerate(model.blocks.children()): 264 | block_prefix = f'{prefix}Transformer/encoderblock_{i}/' 265 | mha_prefix = block_prefix + 'MultiHeadDotProductAttention_1/' 266 | block.norm1.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/scale'])) 267 | block.norm1.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/bias'])) 268 | block.attn.qkv.weight.copy_(torch.cat([ 269 | _n2p(w[f'{mha_prefix}{n}/kernel'], t=False).flatten(1).T for n in ('query', 'key', 'value')])) 270 | block.attn.qkv.bias.copy_(torch.cat([ 271 | _n2p(w[f'{mha_prefix}{n}/bias'], t=False).reshape(-1) for n in ('query', 'key', 'value')])) 272 | block.attn.proj.weight.copy_(_n2p(w[f'{mha_prefix}out/kernel']).flatten(1)) 273 | block.attn.proj.bias.copy_(_n2p(w[f'{mha_prefix}out/bias'])) 274 | for r in range(2): 275 | getattr(block.mlp, f'fc{r + 1}').weight.copy_(_n2p(w[f'{block_prefix}MlpBlock_3/Dense_{r}/kernel'])) 276 | getattr(block.mlp, f'fc{r + 1}').bias.copy_(_n2p(w[f'{block_prefix}MlpBlock_3/Dense_{r}/bias'])) 277 | block.norm2.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_2/scale'])) 278 | block.norm2.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_2/bias'])) 279 | 280 | 281 | def interpolate_pos_embed(pos_embed_checkpoint, visual_encoder): 282 | # interpolate position embedding 283 | embedding_size = pos_embed_checkpoint.shape[-1] 284 | num_patches = visual_encoder.patch_embed.num_patches 285 | num_extra_tokens = visual_encoder.pos_embed.shape[-2] - num_patches 286 | # height (== width) for the checkpoint position embedding 287 | orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5) 288 | # height (== width) for the new position embedding 289 | new_size = int(num_patches ** 0.5) 290 | 291 | if orig_size!=new_size: 292 | # class_token and dist_token are kept unchanged 293 | extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens] 294 | # only the position tokens are interpolated 295 | pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:] 296 | pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2) 297 | pos_tokens = torch.nn.functional.interpolate( 298 | pos_tokens, size=(new_size, new_size), mode='bicubic', align_corners=False) 299 | pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2) 300 | new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1) 301 | print('reshape position embedding from %d to %d'%(orig_size ** 2,new_size ** 2)) 302 | 303 | return new_pos_embed 304 | else: 305 | return pos_embed_checkpoint -------------------------------------------------------------------------------- /spec/models/blip_wrapper.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import yaml 4 | import subprocess 5 | import torch.nn.functional as F 6 | 7 | from tqdm import tqdm 8 | from einops import rearrange 9 | from .blip_utils.blip_retrieval import blip_retrieval 10 | from .base_wrapper import BaseWrapper 11 | from torchvision import transforms 12 | 13 | 14 | # All the below URLs are taken from, and most of the 
implementation is heavily inspired by, the wonderful https://github.com/salesforce/BLIP repo. 15 | download_urls = { 16 | "blip-flickr-base": { 17 | "model_url": "https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_retrieval_flickr.pth", 18 | "config_url": "https://raw.githubusercontent.com/salesforce/BLIP/0480d94d5725a3d1aac66f21e6bf138ac17d323d/configs/retrieval_flickr.yaml", 19 | "bert_config_url": "https://raw.githubusercontent.com/salesforce/BLIP/0480d94d5725a3d1aac66f21e6bf138ac17d323d/configs/med_config.json" 20 | }, 21 | 22 | "blip-coco-base": { 23 | "model_url": "https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_retrieval_coco.pth", 24 | "config_url": "https://github.com/salesforce/BLIP/raw/0480d94d5725a3d1aac66f21e6bf138ac17d323d/configs/retrieval_coco.yaml", 25 | "bert_config_url": "https://raw.githubusercontent.com/salesforce/BLIP/0480d94d5725a3d1aac66f21e6bf138ac17d323d/configs/med_config.json" 26 | }, 27 | } 28 | 29 | 30 | class BLIPModelWrapper(BaseWrapper): 31 | def __init__(self, cache_dir, device, variant="blip-coco-base"): 32 | self.cache_dir = cache_dir 33 | self.device = device 34 | self.variant = variant 35 | self.image_preprocess = transforms.Compose([ 36 | transforms.Resize((384, 384), interpolation=transforms.functional.InterpolationMode.BICUBIC), 37 | transforms.ToTensor(), 38 | transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711))]) 39 | # download and load model 40 | self.config_path = os.path.join(cache_dir, f"{variant}-config") 41 | self.model_path = os.path.join(cache_dir, f"{variant}.pth") 42 | self.bert_config_path = os.path.join(cache_dir, "configs", f"{variant}_med_config.json") 43 | if not (os.path.exists(self.config_path) and os.path.exists(self.model_path) and os.path.exists( 44 | self.bert_config_path)): 45 | self.download() 46 | config = yaml.load(open(self.config_path, 'r'), Loader=yaml.Loader) 47 | self.config = config 48 | self.config['k_test'] = 128 49 | config['med_config'] = self.bert_config_path 50 | model = blip_retrieval(pretrained=self.model_path, image_size=config['image_size'], vit=config['vit'], 51 | vit_grad_ckpt=config['vit_grad_ckpt'], vit_ckpt_layer=config['vit_ckpt_layer'], 52 | queue_size=config['queue_size'], negative_all_rank=config['negative_all_rank'], 53 | med_config=config['med_config']).to(device) 54 | self.model = model.eval() 55 | 56 | 57 | def download(self): 58 | print(f"Downloading BLIP model to {self.cache_dir}...") 59 | model_url = download_urls[self.variant]["model_url"] 60 | config_url = download_urls[self.variant]["config_url"] 61 | bert_config_url = download_urls[self.variant]["bert_config_url"] 62 | os.makedirs(os.path.join(self.cache_dir, "configs"), exist_ok=True) 63 | subprocess.call(["wget", "-cq", model_url, "-O", self.model_path]) 64 | subprocess.call(["wget", "-cq", config_url, "-O", self.config_path]) 65 | subprocess.call(["wget", "-cq", bert_config_url, "-O", self.bert_config_path]) 66 | 67 | @torch.no_grad() 68 | def i2t_evaluate(self, subset_name, dataloader): 69 | tqdm_i2t_loader = tqdm(dataloader) 70 | tqdm_i2t_loader.set_description(f"Image to Text retrieval on <{subset_name}>") 71 | i2t_scores = [] 72 | i2t_correct_num = 0 73 | total_num = 0 74 | for batch in tqdm_i2t_loader: 75 | bs = len(batch['label']) 76 | # get query images 77 | query_images = batch['query_image'].to(self.device) # B,C,H,W (B:batch size) 78 | 79 | # compute normalized image embeddings 80 | image_feats =
self.model.visual_encoder(query_images) 81 | image_embeddings = self.model.vision_proj(image_feats[:, 0, :]) 82 | image_embeddings = F.normalize(image_embeddings, dim=-1) # B, D (D:feature dim) 83 | 84 | # get candidate texts 85 | candidate_texts = batch['candidate_texts'] 86 | 87 | # compute normalized text embeddings 88 | text_embeddings = [] 89 | for texts in candidate_texts: 90 | text_input = self.model.tokenizer(texts, padding='max_length', truncation=True, max_length=35, return_tensors="pt").to(self.device) 91 | text_feat = self.model.text_encoder(text_input.input_ids, attention_mask=text_input.attention_mask, mode='text') 92 | text_embed = F.normalize(self.model.text_proj(text_feat.last_hidden_state[:, 0, :])) 93 | text_embeddings.append(text_embed) 94 | text_embeddings = torch.stack(text_embeddings, dim=0) # B, L, D 95 | 96 | # calculate matching result 97 | batch_i2t_scores = torch.einsum('BD,BLD->BL', [image_embeddings, text_embeddings]).cpu() 98 | i2t_scores.append(batch_i2t_scores) 99 | gt_labels = batch['label'] 100 | pred_labels = batch_i2t_scores.argmax(dim=-1) 101 | correct_num = (gt_labels == pred_labels).sum() 102 | i2t_correct_num += correct_num.item() 103 | total_num += bs 104 | 105 | i2t_scores = torch.cat(i2t_scores, dim=0) 106 | i2t_acc = 100 * i2t_correct_num / total_num 107 | 108 | return i2t_scores, i2t_acc 109 | 110 | @torch.no_grad() 111 | def t2i_evaluate(self, subset_name, dataloader): 112 | tqdm_t2i_loader = tqdm(dataloader) 113 | tqdm_t2i_loader.set_description(f"Text to Image retrieval on <{subset_name}>") 114 | t2i_scores = [] 115 | t2i_correct_num = 0 116 | total_num = 0 117 | for batch in tqdm_t2i_loader: 118 | bs = len(batch['label']) 119 | # get query texts 120 | query_texts = batch['query_text'] # B (B:batch size, list) 121 | query_texts = self.model.tokenizer(query_texts, padding='max_length', truncation=True, max_length=35, return_tensors="pt").to(self.device) 122 | 123 | # compute normalized text embeddings 124 | text_feats = self.model.text_encoder(query_texts.input_ids, attention_mask=query_texts.attention_mask, mode='text') 125 | text_embeddings = F.normalize(self.model.text_proj(text_feats.last_hidden_state[:, 0, :])) # B,D (D:feature dim) 126 | 127 | # get candidate images 128 | candidate_images = batch['candidate_images'].to(self.device) # B,K,C,H,W (K:num of candidate images per case, S:sentence length) 129 | candidate_images = rearrange(candidate_images, 'B K C H W -> (B K) C H W') 130 | 131 | # compute normalized image embeddings 132 | image_feats = self.model.visual_encoder(candidate_images) 133 | image_embeddings = self.model.vision_proj(image_feats[:, 0, :]) 134 | image_embeddings = F.normalize(image_embeddings, dim=-1) # B, D (D:feature dim) 135 | image_embeddings = rearrange(image_embeddings, '(B K) D -> B K D', B=bs) 136 | 137 | # calculate matching result 138 | batch_t2i_scores = torch.einsum('BD,BKD->BK', [text_embeddings, image_embeddings]).cpu() 139 | t2i_scores.append(batch_t2i_scores) 140 | gt_labels = batch['label'] 141 | pred_labels = batch_t2i_scores.argmax(dim=-1) 142 | correct_num = (gt_labels == pred_labels).sum() 143 | t2i_correct_num += correct_num.item() 144 | total_num += bs 145 | 146 | t2i_scores = torch.cat(t2i_scores, dim=0) 147 | t2i_acc = 100 * t2i_correct_num / total_num 148 | 149 | return t2i_scores, t2i_acc 150 | -------------------------------------------------------------------------------- /spec/models/clip_wrapper.py: -------------------------------------------------------------------------------- 1 | 
import torch 2 | 3 | from einops import rearrange 4 | from tqdm import tqdm 5 | from .base_wrapper import BaseWrapper 6 | from open_clip import create_model_and_transforms, get_tokenizer 7 | 8 | 9 | class CLIPWrapper(BaseWrapper): 10 | def __init__(self, device, variant='ViT-B-32', pretrained='openai'): 11 | self.device = device 12 | model, _, image_preprocess = create_model_and_transforms(variant, 13 | device=self.device, 14 | pretrained=pretrained) 15 | self.model = model.eval() 16 | self.tokenizer = get_tokenizer(variant) 17 | self.image_preprocess = image_preprocess 18 | 19 | @torch.no_grad() 20 | def i2t_evaluate(self, subset_name, dataloader): 21 | tqdm_i2t_loader = tqdm(dataloader) 22 | tqdm_i2t_loader.set_description(f"Image to Text retrieval on <{subset_name}>") 23 | i2t_scores = [] 24 | i2t_correct_num = 0 25 | total_num = 0 26 | for batch in tqdm_i2t_loader: 27 | bs = len(batch['label']) 28 | # get query images 29 | query_images = batch['query_image'].to(self.device) # B,C,H,W (B:batch size) 30 | 31 | # compute normalized image embeddings 32 | image_embeddings = self.model.encode_image(query_images) # B,D (D:feature dim) 33 | image_embeddings /= image_embeddings.norm(dim=-1, keepdim=True) 34 | 35 | # get candidate texts 36 | candidate_texts = batch['candidate_texts'] 37 | candidate_texts = [self.tokenizer(texts) for texts in candidate_texts] 38 | candidate_texts = torch.stack(candidate_texts, dim=0).to(self.device) # B,L,S (L:num of candidate texts, S:sentence length) 39 | 40 | # compute normalized text embeddings 41 | candidate_texts = rearrange(candidate_texts, 'B L S -> (B L) S') 42 | text_embeddings = self.model.encode_text(candidate_texts) # BL,D (D:feature dim) 43 | text_embeddings /= text_embeddings.norm(dim=-1, keepdim=True) 44 | text_embeddings = rearrange(text_embeddings, '(B L) D -> B L D', B=bs) # B, L, D 45 | 46 | # calculate matching result 47 | batch_i2t_scores = torch.einsum('BD,BLD->BL', [image_embeddings, text_embeddings]).cpu() 48 | i2t_scores.append(batch_i2t_scores) 49 | gt_labels = batch['label'] 50 | pred_labels = batch_i2t_scores.argmax(dim=-1) 51 | correct_num = (gt_labels == pred_labels).sum() 52 | i2t_correct_num += correct_num.item() 53 | total_num += bs 54 | 55 | i2t_scores = torch.cat(i2t_scores, dim=0) 56 | i2t_acc = 100 * i2t_correct_num / total_num 57 | 58 | return i2t_scores, i2t_acc 59 | 60 | @torch.no_grad() 61 | def t2i_evaluate(self, subset_name, dataloader): 62 | tqdm_t2i_loader = tqdm(dataloader) 63 | tqdm_t2i_loader.set_description(f"Text to Image retrieval on <{subset_name}>") 64 | t2i_scores = [] 65 | t2i_correct_num = 0 66 | total_num = 0 67 | for batch in tqdm_t2i_loader: 68 | bs = len(batch['label']) 69 | # get query texts 70 | query_texts = batch['query_text'] # B (B:batch size, list of strings) 71 | query_texts = self.tokenizer(query_texts).to(self.device) # B, S (B:batch size, S:sentence length) 72 | 73 | # compute normalized text embeddings 74 | text_embeddings = self.model.encode_text(query_texts) # B,D (D:feature dim) 75 | text_embeddings /= text_embeddings.norm(dim=-1, keepdim=True) 76 | 77 | # get candidate images 78 | candidate_images = batch['candidate_images'].to(self.device) # B,K,C,H,W (K:num of candidate images per case) 79 | candidate_images = rearrange(candidate_images, 'B K C H W -> (B K) C H W') 80 | 81 | # compute normalized image embeddings 82 | image_embeddings = self.model.encode_image( 83 | candidate_images) # BK,D (K:num of candidate images per case, D:feature dim) 84 |
image_embeddings /= image_embeddings.norm(dim=-1, keepdim=True) 85 | image_embeddings = rearrange(image_embeddings, '(B K) D -> B K D', B=bs) 86 | 87 | # calculate matching result 88 | batch_t2i_scores = torch.einsum('BD,BKD->BK', [text_embeddings, image_embeddings]).cpu() 89 | t2i_scores.append(batch_t2i_scores) 90 | gt_labels = batch['label'] 91 | pred_labels = batch_t2i_scores.argmax(dim=-1) 92 | correct_num = (gt_labels == pred_labels).sum() 93 | t2i_correct_num += correct_num.item() 94 | total_num += bs 95 | 96 | t2i_scores = torch.cat(t2i_scores, dim=0) 97 | t2i_acc = 100 * t2i_correct_num / total_num 98 | 99 | return t2i_scores, t2i_acc 100 | -------------------------------------------------------------------------------- /spec/models/flava_wrapper.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | from tqdm import tqdm 4 | from transformers import FlavaProcessor, FlavaForPreTraining, BertTokenizer, FlavaFeatureExtractor 5 | from .base_wrapper import BaseWrapper 6 | 7 | 8 | class FlavaWrapper(BaseWrapper): 9 | def __init__(self, cache_dir, device): 10 | # load model, tokenizer, processor 11 | self.model = FlavaForPreTraining.from_pretrained("facebook/flava-full", cache_dir=cache_dir).eval() 12 | self.model = self.model.to(device) 13 | self.feature_extractor = FlavaFeatureExtractor.from_pretrained("facebook/flava-full", cache_dir=cache_dir) 14 | self.tokenizer = BertTokenizer.from_pretrained("facebook/flava-full", cache_dir=cache_dir) 15 | self.processor = FlavaProcessor.from_pretrained("facebook/flava-full", cache_dir=cache_dir) 16 | self.device = device 17 | 18 | @torch.no_grad() 19 | def i2t_evaluate(self, subset_name, dataloader): 20 | tqdm_i2t_loader = tqdm(dataloader) 21 | tqdm_i2t_loader.set_description(f"Image to Text retrieval on <{subset_name}>") 22 | i2t_scores = [] 23 | i2t_correct_num = 0 24 | total_num = 0 25 | for batch in tqdm_i2t_loader: 26 | bs = len(batch['label']) 27 | # get query images 28 | query_images = batch['query_image'] # [B x PIL.Image] (B:batch size) 29 | 30 | # compute normalized image embeddings 31 | inputs = self.feature_extractor(images=query_images, return_tensors="pt").to(self.device) 32 | image_embeddings = self.model.flava.get_image_features(**inputs).cpu().numpy()[:, 0, :] 33 | image_embeddings = image_embeddings / np.linalg.norm(image_embeddings, axis=1, keepdims=True) 34 | image_embeddings = torch.tensor(image_embeddings) # (B, D) 35 | 36 | # get candidate texts 37 | candidate_texts = batch['candidate_texts'] # B, L 38 | 39 | # compute normalized text embeddings 40 | text_embeddings = [] 41 | for texts in candidate_texts: 42 | text_input = self.tokenizer(text=texts, return_tensors="pt", padding="max_length", max_length=77).to(self.device) 43 | text_feats = self.model.flava.get_text_features(**text_input).cpu().numpy()[:, 0, :] 44 | text_feats = text_feats / np.linalg.norm(text_feats, axis=1, keepdims=True) 45 | text_feats = torch.tensor(text_feats) 46 | text_embeddings.append(text_feats) 47 | text_embeddings = torch.stack(text_embeddings, dim=0) # (B, L, D) 48 | 49 | # calculate matching result 50 | batch_i2t_scores = torch.einsum('BD,BLD->BL', [image_embeddings, text_embeddings]).cpu() 51 | i2t_scores.append(batch_i2t_scores) 52 | gt_labels = batch['label'] 53 | pred_labels = batch_i2t_scores.argmax(dim=-1) 54 | correct_num = (gt_labels == pred_labels).sum() 55 | i2t_correct_num += correct_num.item() 56 | total_num += bs 57 | 58 | i2t_scores = torch.cat(i2t_scores, 
dim=0) 59 | i2t_acc = 100 * i2t_correct_num / total_num 60 | 61 | return i2t_scores, i2t_acc 62 | 63 | @torch.no_grad() 64 | def t2i_evaluate(self, subset_name, dataloader): 65 | tqdm_t2i_loader = tqdm(dataloader) 66 | tqdm_t2i_loader.set_description(f"Text to Image retrieval on <{subset_name}>") 67 | t2i_scores = [] 68 | t2i_correct_num = 0 69 | total_num = 0 70 | for batch in tqdm_t2i_loader: 71 | bs = len(batch['label']) 72 | # get query texts 73 | query_texts = batch['query_text'] # [B x STR] 74 | 75 | # compute normalized text embeddings 76 | text_input = self.tokenizer(text=query_texts, return_tensors="pt", padding="max_length", max_length=77).to(self.device) 77 | text_embeddings = self.model.flava.get_text_features(**text_input).cpu().numpy()[:, 0, :] 78 | text_embeddings = text_embeddings / np.linalg.norm(text_embeddings, axis=1, keepdims=True) 79 | text_embeddings = torch.tensor(text_embeddings) # B,D (D:feature dim) 80 | 81 | # get candidate images 82 | candidate_images = batch['candidate_images'] # [BxL, PIL.Image] 83 | 84 | # compute normalized image embeddings 85 | image_embeddings = [] 86 | for images in candidate_images: 87 | image_inputs = self.feature_extractor(images=images, return_tensors="pt").to(self.device) 88 | image_embed = self.model.flava.get_image_features(**image_inputs).cpu().numpy()[:, 0, :] 89 | image_embed = image_embed / np.linalg.norm(image_embed, axis=1, keepdims=True) 90 | image_embed = torch.tensor(image_embed) # (L, D) 91 | image_embeddings.append(image_embed) 92 | image_embeddings = torch.stack(image_embeddings, dim=0) # (B, L, D) 93 | 94 | # calculate matching result 95 | batch_t2i_scores = torch.einsum('BD,BKD->BK', [text_embeddings, image_embeddings]).cpu() 96 | t2i_scores.append(batch_t2i_scores) 97 | gt_labels = batch['label'] 98 | pred_labels = batch_t2i_scores.argmax(dim=-1) 99 | correct_num = (gt_labels == pred_labels).sum() 100 | t2i_correct_num += correct_num.item() 101 | total_num += bs 102 | 103 | t2i_scores = torch.cat(t2i_scores, dim=0) 104 | t2i_acc = 100 * t2i_correct_num / total_num 105 | 106 | return t2i_scores, t2i_acc 107 | -------------------------------------------------------------------------------- /spec/run_eval.sh: -------------------------------------------------------------------------------- 1 | model_dir='/path/to/cache/models' 2 | data_dir='/path/to/data' 3 | out_dir='/path/to/save/results' 4 | 5 | for model in clip blip flava coca 6 | do 7 | python eval.py \ 8 | --model-name $model \ 9 | --model-cache-dir $model_dir \ 10 | --subset-names absolute_size relative_size absolute_spatial relative_spatial existence count \ 11 | --data-root $data_dir \ 12 | --out-path $out_dir \ 13 | --batch-size 64 \ 14 | --num-workers 8 \ 15 | --seed 1 16 | done --------------------------------------------------------------------------------
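
The three wrappers above (CLIP, BLIP, FLAVA) share one evaluation contract: `i2t_evaluate(subset_name, dataloader)` and `t2i_evaluate(subset_name, dataloader)` read batches with the keys `query_image`, `candidate_texts`, `query_text`, `candidate_images`, and `label`, and return a score matrix plus an accuracy percentage. As a hedged illustration of that contract (not code from the repository), the sketch below shows how a custom dual-encoder model might be wrapped; the `BaseWrapper` import path mirrors the wrappers above, while `encode_image_fn`, `encode_text_fn`, `tokenizer`, and `image_preprocess` are hypothetical stand-ins for the custom model's own components.

```python
import torch
from tqdm import tqdm
from einops import rearrange

from spec.models.base_wrapper import BaseWrapper  # assumed import path, mirroring the wrappers above


class CustomWrapper(BaseWrapper):
    """Illustrative sketch: wrap an arbitrary dual-encoder VLM for SPEC-style retrieval."""

    def __init__(self, device, encode_image_fn, encode_text_fn, tokenizer, image_preprocess):
        self.device = device
        self.encode_image_fn = encode_image_fn    # hypothetical: (B,C,H,W) image tensor -> (B,D) embeddings
        self.encode_text_fn = encode_text_fn      # hypothetical: (N,S) token ids -> (N,D) embeddings
        self.tokenizer = tokenizer                # hypothetical: list of strings -> (N,S) token ids
        self.image_preprocess = image_preprocess  # torchvision-style transform applied by the dataset

    @torch.no_grad()
    def i2t_evaluate(self, subset_name, dataloader):
        scores, correct, total = [], 0, 0
        for batch in tqdm(dataloader, desc=f"Image to Text retrieval on <{subset_name}>"):
            bs = len(batch['label'])
            # normalized embeddings of the query images: (B, D)
            img = self.encode_image_fn(batch['query_image'].to(self.device))
            img = img / img.norm(dim=-1, keepdim=True)
            # normalized embeddings of the L candidate texts per case: (B, L, D)
            txt_ids = torch.stack([self.tokenizer(t) for t in batch['candidate_texts']]).to(self.device)
            txt = self.encode_text_fn(rearrange(txt_ids, 'B L S -> (B L) S'))
            txt = txt / txt.norm(dim=-1, keepdim=True)
            txt = rearrange(txt, '(B L) D -> B L D', B=bs)
            # cosine similarity of each query image against its own candidate texts
            batch_scores = torch.einsum('BD,BLD->BL', img, txt).cpu()
            scores.append(batch_scores)
            correct += (batch_scores.argmax(dim=-1) == batch['label']).sum().item()
            total += bs
        return torch.cat(scores, dim=0), 100 * correct / total

    @torch.no_grad()
    def t2i_evaluate(self, subset_name, dataloader):
        scores, correct, total = [], 0, 0
        for batch in tqdm(dataloader, desc=f"Text to Image retrieval on <{subset_name}>"):
            bs = len(batch['label'])
            # normalized embedding of the query text: (B, D)
            txt = self.encode_text_fn(self.tokenizer(batch['query_text']).to(self.device))
            txt = txt / txt.norm(dim=-1, keepdim=True)
            # normalized embeddings of the K candidate images per case: (B, K, D)
            imgs = rearrange(batch['candidate_images'].to(self.device), 'B K C H W -> (B K) C H W')
            img = self.encode_image_fn(imgs)
            img = img / img.norm(dim=-1, keepdim=True)
            img = rearrange(img, '(B K) D -> B K D', B=bs)
            # cosine similarity of each query text against its own candidate images
            batch_scores = torch.einsum('BD,BKD->BK', txt, img).cpu()
            scores.append(batch_scores)
            correct += (batch_scores.argmax(dim=-1) == batch['label']).sum().item()
            total += bs
        return torch.cat(scores, dim=0), 100 * correct / total
```

Such a wrapper would still need to be registered wherever `eval.py` constructs its models before `run_eval.sh` could loop over it; that wiring is not shown here.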