├── assets ├── .gitkeep ├── catdog.png └── DeViL_teaser.png ├── .gitignore ├── environment.yml ├── src ├── language_model.py ├── milan_keys.py ├── viz_notebook.ipynb ├── vision_models.py ├── trainer.py ├── dataset.py ├── main.py ├── evaluation.py └── model.py ├── README.md └── LICENSE /assets/.gitkeep: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /assets/catdog.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ExplainableML/DeViL/HEAD/assets/catdog.png -------------------------------------------------------------------------------- /assets/DeViL_teaser.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ExplainableML/DeViL/HEAD/assets/DeViL_teaser.png -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | .ipynb_checkpoints/ 2 | __pycache__/ 3 | *.pt 4 | *.json 5 | log/ 6 | data/ 7 | venv/ 8 | *.jpg 9 | *.png -------------------------------------------------------------------------------- /environment.yml: -------------------------------------------------------------------------------- 1 | name: devil 2 | channels: 3 | - pytorch 4 | - nvidia 5 | - huggingface 6 | - conda-forge 7 | - defaults 8 | dependencies: 9 | - ftfy=6.1.1 10 | - matplotlib=3.7.1 11 | - openjdk=11.0.15 12 | - python=3.10.9 13 | - pytorch=2.0.0 14 | - pytorch-cuda=11.8 15 | - scipy=1.10.1 16 | - sentencepiece=0.1.97 17 | - torchvision=0.15.0 18 | - transformers=4.26.1 19 | - wandb=0.14.0 20 | - webdataset=0.2.48 21 | - pip: 22 | - bert_score==0.3.13 23 | - easydict==1.10 24 | - evaluate==0.4.0 25 | - pycocoevalcap==1.2 26 | - pycocotools==2.0.6 27 | - timm==0.8.13.dev0 28 | - git+https://github.com/openai/CLIP.git 29 | -------------------------------------------------------------------------------- /src/language_model.py: -------------------------------------------------------------------------------- 1 | from transformers import AutoModelForCausalLM, AutoTokenizer 2 | 3 | 4 | def get_language_model(nlp_model_name): 5 | # get tokenizer 6 | tokenizer = get_tokenizer(nlp_model_name) 7 | # get model 8 | language_model = AutoModelForCausalLM.from_pretrained(nlp_model_name) 9 | 10 | if "opt" in nlp_model_name: 11 | embedding_weights = language_model.model.decoder.embed_tokens 12 | embed_dim = language_model.config.word_embed_proj_dim 13 | elif "gpt2" in nlp_model_name: 14 | embedding_weights = language_model.transformer.wte 15 | embed_dim = language_model.config.n_embd 16 | elif "BERT" in nlp_model_name: 17 | language_model = AutoModelForCausalLM.from_pretrained( 18 | nlp_model_name, is_decoder=True 19 | ) 20 | embedding_weights = language_model.bert.embeddings.word_embeddings 21 | embed_dim = language_model.config.hidden_size 22 | else: 23 | raise NotImplementedError 24 | 25 | return language_model, tokenizer, embedding_weights, embed_dim 26 | 27 | 28 | def get_tokenizer(language_model_name): 29 | tokenizer = AutoTokenizer.from_pretrained(language_model_name, use_fast=False) 30 | if "BERT" not in language_model_name: 31 | tokenizer.pad_token = tokenizer.eos_token 32 | tokenizer.add_bos_token = True 33 | return tokenizer 34 | -------------------------------------------------------------------------------- /README.md: 
-------------------------------------------------------------------------------- 1 | # DeViL: Decoding Vision features into Language 2 | 3 | [[Paper]](https://arxiv.org/abs/2309.01617) 4 | 5 | This is the official repository for our **GCPR 2023 Oral** paper on Decoding Vision features into Language (DeViL). 6 | 7 | ![DeViL Teaser](./assets/DeViL_teaser.png) 8 | 9 | ## Getting started 10 | To ensure you have the right environment, create a conda environment from the provided environment.yml file: 11 | 12 | ``` 13 | conda env create -n devil --file environment.yml 14 | ``` 15 | 16 | ## Training DeViL on image-text pairs 17 | 18 | Datasets are implemented in the `src/dataset.py` file, which contains the code for data loading and data collation. 19 | 20 | A paired image-text dataset for training returns a dictionary with the items `{'id': image id, 'image': image, 'text': captions}` (see the sketch at the end of this README). 21 | 22 | In the `src/dataset.py` file you can find implementations for the CC3M and MILANNOTATIONS datasets. 23 | 24 | Once the dataset is prepared, you can run training with the `src/main.py` file. Don't forget to set the `data_root` and `logdir` arguments. 25 | 26 | For a detailed explanation of all arguments, see `src/main.py`. 27 | 28 | 29 | Example command: 30 | ``` 31 | python src/main.py --data_root ./data --dataset cc3m --logdir ./results --language_model facebook/opt-125m \ 32 | --vision_backbone timm_resnet50 --token_dropout 0.5 --feature_dropout 0.5 --vision_feat_layers -1 -2 -3 -4 33 | ``` 34 | 35 | ## Evaluation (NLP metrics) 36 | To evaluate a trained model, run: 37 | 38 | ``` 39 | python src/main.py \ 40 | --do_eval_nlp \ 41 | --model_ckpt <path_to_checkpoint> \ 42 | --data_root <path_to_data_root> 43 | ``` 44 | 45 | If you wish to evaluate descriptions of a specific layer, set the `layer` argument. 46 | 47 | ## Generation of textual descriptions 48 | 49 | To generate textual descriptions for different layers and feature locations, run: 50 | 51 | ``` 52 | python src/main.py \ 53 | --do_eval_qualitative \ 54 | --model_ckpt <path_to_checkpoint> \ 55 | --data_root <path_to_data_root> \ 56 | --loc_ids "-1: [[-1, -1]]" 57 | ``` 58 | 59 | For more details, see the `loc_ids`, `pool_locs`, and `kernel_size` arguments in `src/main.py`. 60 | 61 | ## Generation of open-vocabulary saliency maps 62 | 63 | To generate open-vocabulary saliency maps, see [viz_notebook.ipynb](./src/viz_notebook.ipynb). 64 | 65 | ## CC3M Dataset 66 | 67 | To obtain CC3M in webdataset format, you can use [img2dataset](https://github.com/rom1504/img2dataset/blob/main/dataset_examples/cc3m.md). 68 | 69 | ## MILAN 70 | 71 | To train on the [MILANNOTATIONS](https://github.com/evandez/neuron-descriptions) dataset, follow its instructions to download the data and change the `dataset` argument to one of the MILANNOTATIONS [keys](./src/milan_keys.py) (e.g. "imagenet"). 72 | 73 | You can also activate the `by_unit` argument so that this dataset is processed by neuron (aka unit) instead of by image.
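For reference (cf. "Training DeViL on image-text pairs" above), below is a minimal sketch of a compatible custom paired image-text dataset. The class name, constructor arguments, and file paths are hypothetical; it only illustrates the item format consumed by the data collator in `src/dataset.py`.

```
# Minimal sketch of a custom paired image-text dataset for DeViL training.
# The class name and constructor arguments are hypothetical; adapt them to your data.
from PIL import Image
from torch.utils.data import Dataset


class MyImageTextDataset(Dataset):
    def __init__(self, samples, transform):
        # samples: list of (image_path, caption) tuples; transform: e.g. model.transform
        self.samples = samples
        self.transform = transform

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, index):
        image_path, caption = self.samples[index]
        # The collator expects an 'id' (or '__key__'), a transformed 'image' tensor,
        # and a caption string under 'text'.
        image = self.transform(Image.open(image_path).convert("RGB"))
        return {"id": str(index), "image": image, "text": caption}
```

For evaluation, an item may additionally provide a `raw` key with a list of reference captions; otherwise the collator wraps `text` as the single reference.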
74 | -------------------------------------------------------------------------------- /src/milan_keys.py: -------------------------------------------------------------------------------- 1 | import easydict 2 | 3 | KEYS = easydict.EasyDict() 4 | KEYS.ALEXNET = "alexnet" 5 | KEYS.BIGGAN = "biggan" 6 | KEYS.DINO_VITS8 = "dino_vits8" 7 | KEYS.RESNET152 = "resnet152" 8 | 9 | KEYS.IMAGENET = "imagenet" 10 | KEYS.PLACES365 = "places365" 11 | 12 | KEYS.ALEXNET_IMAGENET = f"{KEYS.ALEXNET}/{KEYS.IMAGENET}" 13 | KEYS.BIGGAN_IMAGENET = f"{KEYS.BIGGAN}/{KEYS.IMAGENET}" 14 | KEYS.DINO_VITS8_IMAGENET = f"{KEYS.DINO_VITS8}/{KEYS.IMAGENET}" 15 | KEYS.RESNET152_IMAGENET = f"{KEYS.RESNET152}/{KEYS.IMAGENET}" 16 | 17 | KEYS.ALEXNET_PLACES365 = f"{KEYS.ALEXNET}/{KEYS.PLACES365}" 18 | KEYS.RESNET152_PLACES365 = f"{KEYS.RESNET152}/{KEYS.PLACES365}" 19 | KEYS.BIGGAN_PLACES365 = f"{KEYS.BIGGAN}/{KEYS.PLACES365}" 20 | 21 | KEYS.GENERATORS = "gen" 22 | KEYS.CLASSIFIERS = "cls" 23 | KEYS.BASE = "base" 24 | KEYS.NOT_ALEXNET_IMAGENET = f"not-{KEYS.ALEXNET}-{KEYS.IMAGENET}" 25 | KEYS.NOT_ALEXNET_PLACES365 = f"not-{KEYS.ALEXNET}-{KEYS.PLACES365}" 26 | KEYS.NOT_RESNET152_IMAGENET = f"not-{KEYS.RESNET152}-{KEYS.IMAGENET}" 27 | KEYS.NOT_RESNET152_PLACES365 = f"not-{KEYS.RESNET152}-{KEYS.PLACES365}" 28 | KEYS.NOT_BIGGAN_IMAGENET = f"not-{KEYS.BIGGAN}-{KEYS.IMAGENET}" 29 | KEYS.NOT_BIGGAN_PLACES365 = f"not-{KEYS.BIGGAN}-{KEYS.PLACES365}" 30 | 31 | # Different partitions of MILANNOTATIONS based on the generalization 32 | # experiments from the original paper. 33 | DATASET_GROUPINGS = { 34 | KEYS.BASE: ( 35 | KEYS.ALEXNET_IMAGENET, 36 | KEYS.ALEXNET_PLACES365, 37 | KEYS.RESNET152_IMAGENET, 38 | KEYS.RESNET152_PLACES365, 39 | KEYS.BIGGAN_IMAGENET, 40 | KEYS.BIGGAN_PLACES365, 41 | ), 42 | KEYS.CLASSIFIERS: ( 43 | KEYS.ALEXNET_IMAGENET, 44 | KEYS.ALEXNET_PLACES365, 45 | KEYS.RESNET152_IMAGENET, 46 | KEYS.RESNET152_PLACES365, 47 | ), 48 | KEYS.GENERATORS: ( 49 | KEYS.BIGGAN_IMAGENET, 50 | KEYS.BIGGAN_PLACES365, 51 | ), 52 | KEYS.IMAGENET: ( 53 | KEYS.ALEXNET_IMAGENET, 54 | KEYS.RESNET152_IMAGENET, 55 | KEYS.BIGGAN_IMAGENET, 56 | ), 57 | KEYS.PLACES365: ( 58 | KEYS.ALEXNET_PLACES365, 59 | KEYS.RESNET152_PLACES365, 60 | KEYS.BIGGAN_PLACES365, 61 | ), 62 | KEYS.ALEXNET: ( 63 | KEYS.ALEXNET_IMAGENET, 64 | KEYS.ALEXNET_PLACES365, 65 | ), 66 | KEYS.RESNET152: ( 67 | KEYS.RESNET152_IMAGENET, 68 | KEYS.RESNET152_PLACES365, 69 | ), 70 | KEYS.BIGGAN: ( 71 | KEYS.BIGGAN_IMAGENET, 72 | KEYS.BIGGAN_PLACES365, 73 | ), 74 | KEYS.NOT_ALEXNET_IMAGENET: ( 75 | KEYS.ALEXNET_PLACES365, 76 | KEYS.RESNET152_IMAGENET, 77 | KEYS.RESNET152_PLACES365, 78 | KEYS.BIGGAN_IMAGENET, 79 | KEYS.BIGGAN_PLACES365, 80 | ), 81 | KEYS.NOT_ALEXNET_PLACES365: ( 82 | KEYS.ALEXNET_IMAGENET, 83 | KEYS.RESNET152_IMAGENET, 84 | KEYS.RESNET152_PLACES365, 85 | KEYS.BIGGAN_IMAGENET, 86 | KEYS.BIGGAN_PLACES365, 87 | ), 88 | KEYS.NOT_RESNET152_IMAGENET: ( 89 | KEYS.ALEXNET_IMAGENET, 90 | KEYS.ALEXNET_PLACES365, 91 | KEYS.RESNET152_PLACES365, 92 | KEYS.BIGGAN_IMAGENET, 93 | KEYS.BIGGAN_PLACES365, 94 | ), 95 | KEYS.NOT_RESNET152_PLACES365: ( 96 | KEYS.ALEXNET_IMAGENET, 97 | KEYS.ALEXNET_PLACES365, 98 | KEYS.RESNET152_IMAGENET, 99 | KEYS.BIGGAN_IMAGENET, 100 | KEYS.BIGGAN_PLACES365, 101 | ), 102 | KEYS.NOT_BIGGAN_IMAGENET: ( 103 | KEYS.ALEXNET_IMAGENET, 104 | KEYS.ALEXNET_PLACES365, 105 | KEYS.RESNET152_IMAGENET, 106 | KEYS.RESNET152_PLACES365, 107 | KEYS.BIGGAN_PLACES365, 108 | ), 109 | KEYS.NOT_BIGGAN_PLACES365: ( 110 | KEYS.ALEXNET_IMAGENET, 111 | KEYS.ALEXNET_PLACES365, 112 | 
KEYS.RESNET152_IMAGENET, 113 | KEYS.RESNET152_PLACES365, 114 | KEYS.BIGGAN_IMAGENET, 115 | ), 116 | } 117 | 118 | TRAIN_TEST_PAIRS = { 119 | KEYS.CLASSIFIERS: KEYS.GENERATORS, 120 | KEYS.GENERATORS: KEYS.CLASSIFIERS, 121 | KEYS.IMAGENET: KEYS.PLACES365, 122 | KEYS.PLACES365: KEYS.IMAGENET, 123 | KEYS.ALEXNET: KEYS.RESNET152, 124 | KEYS.RESNET152: KEYS.ALEXNET, 125 | KEYS.BASE: KEYS.DINO_VITS8_IMAGENET, 126 | KEYS.NOT_ALEXNET_IMAGENET: KEYS.ALEXNET_IMAGENET, 127 | KEYS.NOT_ALEXNET_PLACES365: KEYS.ALEXNET_PLACES365, 128 | KEYS.NOT_RESNET152_IMAGENET: KEYS.RESNET152_IMAGENET, 129 | KEYS.NOT_RESNET152_PLACES365: KEYS.RESNET152_PLACES365, 130 | KEYS.NOT_BIGGAN_IMAGENET: KEYS.BIGGAN_IMAGENET, 131 | KEYS.NOT_BIGGAN_PLACES365: KEYS.BIGGAN_PLACES365, 132 | } 133 | 134 | WITHIN_NETWORK = [ 135 | KEYS.ALEXNET_IMAGENET, 136 | KEYS.ALEXNET_PLACES365, 137 | KEYS.BIGGAN_IMAGENET, 138 | KEYS.BIGGAN_PLACES365, 139 | KEYS.RESNET152_IMAGENET, 140 | KEYS.RESNET152_PLACES365, 141 | ] 142 | -------------------------------------------------------------------------------- /src/viz_notebook.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "from transformers import AutoConfig\n", 10 | "from model import Features2WordsModel, TranslationTransformerConfig\n", 11 | "import os\n", 12 | "from argparse import Namespace # needed for reading saved argparse parameters\n", 13 | "import torch\n", 14 | "from torch.nn import functional as F\n", 15 | "from PIL import Image\n", 16 | "from itertools import product\n", 17 | "from evaluation import get_saliency_map\n", 18 | "import matplotlib.pyplot as plt" 19 | ] 20 | }, 21 | { 22 | "cell_type": "code", 23 | "execution_count": null, 24 | "metadata": {}, 25 | "outputs": [], 26 | "source": [ 27 | "# PARAMETERS TO DEFINE\n", 28 | "MODEL_CKPT = # path to model checkpoint\n", 29 | "IMG = '../assets/catdog.png'\n", 30 | "QUERIES = ['cat', 'dog', 'animal']\n", 31 | "LAYER = -1 # e.g. -1; if left None will compute metrics with features for all layers.\n", 32 | "HRES = False # set to True if you want to generate saliency maps for a given layer but with the resolution of the lowest layer.\n", 33 | "POOL_LOCS = 'keep_dims' # choose between None, reduce_dims, keep_dims. Pools lower layer locations.\n", 34 | "KERNEL_SIZE = -1 # kernel size for pooling lower layer locations. 
We typically use 3 for layer -2 and 7 for layer -3.\n", 35 | "GPU_ID = 0" 36 | ] 37 | }, 38 | { 39 | "cell_type": "code", 40 | "execution_count": null, 41 | "metadata": {}, 42 | "outputs": [], 43 | "source": [ 44 | "# load model\n", 45 | "translation_config = TranslationTransformerConfig.from_json_file(os.path.join(MODEL_CKPT, 'translation_model_config.json'))\n", 46 | "lm_model_config = AutoConfig.from_pretrained(os.path.join(MODEL_CKPT, 'lm_model_config.json'))\n", 47 | "max_length = lm_model_config.max_length\n", 48 | "\n", 49 | "with open(os.path.join(os.path.join(MODEL_CKPT, '../train_params.txt')), 'r') as f:\n", 50 | " namespace_str = f.read()\n", 51 | "train_args = eval(namespace_str)\n", 52 | "vision_backbone = train_args.vision_backbone\n", 53 | "vision_feat_func = train_args.vision_feat_func\n", 54 | "vision_feat_layers = train_args.vision_feat_layers\n", 55 | "language_model = train_args.language_model\n", 56 | "model = Features2WordsModel(translation_config=translation_config, cnn_name=vision_backbone, vision_feat_func=vision_feat_func, vision_feat_layers=vision_feat_layers, lm_model_name=language_model, max_length=max_length)\n", 57 | "model_checkpoint = torch.load(os.path.join(MODEL_CKPT, 'model.pt'))\n", 58 | "model.translation.load_state_dict(model_checkpoint[\"MODEL_STATE\"])\n", 59 | "model.eval()\n", 60 | "model.to(GPU_ID)\n", 61 | "transform = model.transform" 62 | ] 63 | }, 64 | { 65 | "cell_type": "code", 66 | "execution_count": null, 67 | "metadata": {}, 68 | "outputs": [], 69 | "source": [ 70 | "def grid(n, dim):\n", 71 | " return sorted(list(set(product(range(n), repeat=dim))))\n", 72 | "\n", 73 | "def prepare_grid(model, layer, hres, pool_locs):\n", 74 | " stride = 1\n", 75 | " # prepare spatial locations\n", 76 | " gs = model.grid_sizes[::-1]\n", 77 | " if layer is None:\n", 78 | " layer_id = None\n", 79 | " layer_name = 'all'\n", 80 | " gs = gs[0] # get the size of the lowest layer the model was trained on\n", 81 | " else:\n", 82 | " layer_name = str(layer)\n", 83 | " layer_id = -layer - 1\n", 84 | "\n", 85 | " if hres:\n", 86 | " gs = gs[0] # get the size of the lowest layer the model was trained on\n", 87 | " layer_name += '_hres'\n", 88 | " else:\n", 89 | " if pool_locs == 'reduce_dims':\n", 90 | " gs = gs[-1]\n", 91 | " stride = int(gs[layer] / gs[-1])\n", 92 | " else:\n", 93 | " gs = gs[layer]\n", 94 | " \n", 95 | " loc_ids = grid(gs, 2)\n", 96 | " return gs, loc_ids, layer_id, layer_name, stride\n", 97 | "\n", 98 | "def get_saliency(model, layer, hres, pool_locs, kernel_size, img, query):\n", 99 | " gs, loc_ids, layer_id, layer_name, stride = prepare_grid(model, layer, hres, pool_locs)\n", 100 | " loss_saliency, _ = get_saliency_map(model, loc_ids, img.unsqueeze(0), query, layer_id, gpu_id=GPU_ID, hres=hres, pool_locs=pool_locs, kernel_size=kernel_size, stride=stride)\n", 101 | " loss_saliency_interp = loss_saliency.view(-1, gs, gs).unsqueeze(dim=1)\n", 102 | " loss_saliency_interp = F.interpolate(loss_saliency_interp, tuple(img.shape[1:]), mode='bilinear')\n", 103 | " smap_loss = loss_saliency_interp[0].permute(1, 2, 0).cpu().numpy()\n", 104 | " smap_loss = -smap_loss\n", 105 | "\n", 106 | " return smap_loss, layer_name" 107 | ] 108 | }, 109 | { 110 | "cell_type": "code", 111 | "execution_count": null, 112 | "metadata": {}, 113 | "outputs": [], 114 | "source": [ 115 | "# preprocess image\n", 116 | "img = Image.open(IMG).convert(\"RGB\")\n", 117 | "img_wo_transform = img.copy()\n", 118 | "img = transform(img)\n", 119 | "img = img.to(GPU_ID)\n", 120 | 
"\n", 121 | "for i, q in enumerate(QUERIES):\n", 122 | " smap, layer_name = get_saliency(model, LAYER, HRES, POOL_LOCS, KERNEL_SIZE, img, q)\n", 123 | " plt.figure(i)\n", 124 | " plt.imshow(img_wo_transform)\n", 125 | " plt.axis('off')\n", 126 | " plt.title(f'Layer {layer_name}')\n", 127 | " plt.imshow(smap, cmap='jet', alpha=0.5)\n", 128 | " plt.show()" 129 | ] 130 | } 131 | ], 132 | "metadata": { 133 | "kernelspec": { 134 | "display_name": "Python 3 (ipykernel)", 135 | "language": "python", 136 | "name": "python3" 137 | }, 138 | "language_info": { 139 | "codemirror_mode": { 140 | "name": "ipython", 141 | "version": 3 142 | }, 143 | "file_extension": ".py", 144 | "mimetype": "text/x-python", 145 | "name": "python", 146 | "nbconvert_exporter": "python", 147 | "pygments_lexer": "ipython3", 148 | "version": "3.10.9" 149 | } 150 | }, 151 | "nbformat": 4, 152 | "nbformat_minor": 4 153 | } 154 | -------------------------------------------------------------------------------- /src/vision_models.py: -------------------------------------------------------------------------------- 1 | import clip 2 | import timm 3 | import torch 4 | import torchvision 5 | from timm.data import resolve_data_config 6 | from timm.data.transforms_factory import create_transform 7 | from torch import nn 8 | from torchvision.transforms._presets import ImageClassification 9 | 10 | 11 | class WrapOutputInList(nn.Module): 12 | def __init__(self): 13 | super().__init__() 14 | 15 | def forward(self, x): 16 | return [x] 17 | 18 | 19 | class IndexOutput(nn.Module): 20 | def __init__( 21 | self, vision_feat_level, add_vit_embed_token=False, only_vit_embed_token=False 22 | ): 23 | super().__init__() 24 | self.vision_feat_level = vision_feat_level 25 | self.add_vit_embed_token = add_vit_embed_token 26 | self.only_vit_embed_token = only_vit_embed_token 27 | 28 | def select_indices(self, x): 29 | return [x[vfl] for vfl in self.vision_feat_level] 30 | 31 | def forward(self, x): 32 | if self.add_vit_embed_token and not self.only_vit_embed_token: 33 | # join lists 34 | return self.select_indices(x[0]) + self.select_indices(x[1]) 35 | else: 36 | return self.select_indices(x) 37 | 38 | 39 | def get_vision_model( 40 | vision_model_name, 41 | vision_feat_level=[-1], 42 | add_vit_embed_token=False, 43 | only_vit_embed_token=False, 44 | ): 45 | transform = None 46 | if vision_model_name.startswith("timm"): 47 | assert not add_vit_embed_token, "ViT embed token unsupported for timm models" 48 | timm_model_name = vision_model_name[5:] 49 | vision_model = timm.create_model( 50 | timm_model_name, features_only=True, pretrained=True 51 | ) 52 | 53 | # Get model dims 54 | vision_dim = [] 55 | grid_size = [] 56 | if vision_model.default_cfg["pool_size"] is not None: 57 | assert ( 58 | vision_model.default_cfg["pool_size"][0] 59 | == vision_model.default_cfg["pool_size"][1] 60 | ) 61 | total_pool_size = vision_model.default_cfg["pool_size"][0] 62 | else: 63 | total_pool_size = 224 // vision_model.feature_info[-1]["reduction"] 64 | 65 | for vfl in vision_feat_level: 66 | vision_dim.append(vision_model.feature_info[vfl]["num_chs"]) 67 | if timm_model_name.startswith(("resnet", "efficientformerv2")): 68 | level_factor = ( 69 | vision_model.feature_info[-1]["reduction"] 70 | // vision_model.feature_info[vfl]["reduction"] 71 | ) 72 | elif timm_model_name.startswith("davit"): 73 | assert vfl < 0 74 | level_factor = prod( 75 | [vision_model.feature_info[l]["reduction"] for l in range(vfl, -1)] 76 | ) 77 | else: 78 | raise NotImplementedError 79 | 
grid_size.append(total_pool_size * level_factor) 80 | 81 | config = resolve_data_config({}, model=vision_model) 82 | transform = create_transform(**config) 83 | 84 | vision_model = nn.Sequential(vision_model, IndexOutput(vision_feat_level)) 85 | 86 | elif vision_model_name.startswith("clip"): 87 | clip_model_name = vision_model_name[5:] 88 | vision_model, transform = clip.load(clip_model_name) 89 | vision_model = vision_model.visual 90 | vision_model.float() 91 | if clip_model_name.startswith("RN"): 92 | assert ( 93 | not add_vit_embed_token 94 | ), "ViT embed token unsupported for clip_RN models" 95 | new_forward = clip_resnet_forward 96 | vision_dim, grid_size = get_resnet_dims(vision_feat_level) 97 | elif clip_model_name.startswith("ViT"): 98 | vision_model.add_vit_embed_token = add_vit_embed_token 99 | vision_model.only_vit_embed_token = only_vit_embed_token 100 | new_forward = clip_vit_forward 101 | vision_dim = [] 102 | grid_size = [] 103 | if not only_vit_embed_token: 104 | vision_dim = [768] * len(vision_feat_level) 105 | grid_size = [7] * len(vision_feat_level) 106 | if add_vit_embed_token: 107 | vision_dim += [768] * len(vision_feat_level) 108 | grid_size += [1] * len(vision_feat_level) 109 | else: 110 | raise ValueError 111 | 112 | # add new_forward function to the model instance as a class method 113 | bound_method = new_forward.__get__(vision_model, vision_model.__class__) 114 | setattr(vision_model, "forward", bound_method) 115 | 116 | vision_model = nn.Sequential( 117 | vision_model, 118 | IndexOutput(vision_feat_level, add_vit_embed_token, only_vit_embed_token), 119 | ) 120 | else: 121 | raise NotImplementedError 122 | 123 | assert transform is not None 124 | unnormalize = adjust_resize_and_get_unnormalize(transform) 125 | 126 | return vision_model, vision_dim, grid_size, transform, unnormalize 127 | 128 | 129 | def get_resnet_dims(vision_feat_level): 130 | vision_dim = [256, 512, 1024, 2048] 131 | grid_size = [56, 28, 14, 7] 132 | vision_dim = [vision_dim[vfl] for vfl in vision_feat_level] 133 | grid_size = [grid_size[vfl] for vfl in vision_feat_level] 134 | return vision_dim, grid_size 135 | 136 | 137 | def clip_resnet_forward(self, x): 138 | def stem(x): 139 | x = self.relu1(self.bn1(self.conv1(x))) 140 | x = self.relu2(self.bn2(self.conv2(x))) 141 | x = self.relu3(self.bn3(self.conv3(x))) 142 | x = self.avgpool(x) 143 | return x 144 | 145 | x = x.type(self.conv1.weight.dtype) 146 | x = stem(x) 147 | x1 = self.layer1(x) 148 | x2 = self.layer2(x1) 149 | x3 = self.layer3(x2) 150 | x4 = self.layer4(x3) 151 | # x = self.attnpool(x) 152 | 153 | return [x1, x2, x3, x4] 154 | 155 | 156 | def clip_vit_forward(self, x): 157 | x = self.conv1(x) # shape = [*, width, grid, grid] 158 | x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid ** 2] 159 | x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width] 160 | x = torch.cat( 161 | [ 162 | self.class_embedding.to(x.dtype) 163 | + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), 164 | x, 165 | ], 166 | dim=1, 167 | ) # shape = [*, grid ** 2 + 1, width] 168 | x = x + self.positional_embedding.to(x.dtype) 169 | x = self.ln_pre(x) 170 | 171 | x = x.permute(1, 0, 2) # NLD -> LND 172 | outputs = [] 173 | embed_outputs = [] 174 | for mod in self.transformer.resblocks: 175 | x = mod(x) 176 | if not self.only_vit_embed_token: 177 | outputs.append(x[1:].permute(1, 2, 0).reshape(x.shape[1], x.shape[2], 7, 7)) 178 | if self.add_vit_embed_token: 179 | embed_outputs.append( 180 | x[:1].permute(1, 2, 
0).reshape(x.shape[1], x.shape[2], 1, 1) 181 | ) 182 | 183 | """ 184 | x = x.permute(1, 0, 2) # LND -> NLD 185 | 186 | x = self.ln_post(x[:, 0, :]) 187 | 188 | if self.proj is not None: 189 | x = x @ self.proj 190 | """ 191 | if self.only_vit_embed_token: 192 | return embed_outputs 193 | elif self.add_vit_embed_token: 194 | return (outputs, embed_outputs) 195 | else: 196 | return outputs 197 | 198 | 199 | def adjust_resize_and_get_unnormalize(transform): 200 | if isinstance(transform, ImageClassification): 201 | mean = transform.mean 202 | std = transform.std 203 | transform.resize_size = ( 204 | transform.crop_size 205 | ) # directly resize to final crop size (224x224) 206 | elif isinstance(transform, torchvision.transforms.Compose): 207 | has_normalize = False 208 | crop_size = None 209 | for t in transform.transforms: 210 | if isinstance(t, torchvision.transforms.Normalize): 211 | mean = t.mean 212 | std = t.std 213 | has_normalize = True 214 | if isinstance(t, torchvision.transforms.CenterCrop): 215 | crop_size = t.size if isinstance(t.size, int) else t.size[0] 216 | if crop_size is None: 217 | print("Crop size not found, leaving resizing transform unchanged.") 218 | else: 219 | for t in transform.transforms: 220 | if isinstance(t, torchvision.transforms.Resize): 221 | t.size = crop_size 222 | if not has_normalize: 223 | return nn.Identity() 224 | else: 225 | raise NotImplementedError 226 | 227 | if not isinstance(mean, torch.Tensor): 228 | mean = torch.tensor(mean) 229 | std = torch.tensor(std) 230 | unnormalize = torchvision.transforms.Normalize(mean=-mean / std, std=1 / std) 231 | return unnormalize 232 | -------------------------------------------------------------------------------- /src/trainer.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | from distutils.dir_util import copy_tree 4 | 5 | import torch 6 | import wandb 7 | from torch.cuda import amp 8 | from torch.nn.parallel import DistributedDataParallel as DDP 9 | from torch.optim import AdamW 10 | from torch.optim.lr_scheduler import LambdaLR 11 | from tqdm import tqdm 12 | 13 | from dataset import get_loader 14 | from evaluation import eval_nlp 15 | 16 | WANDB = True 17 | 18 | 19 | class Trainer: 20 | def __init__( 21 | self, model, train_dtset, eval_dtset, outdir, args, is_distributed=False 22 | ): 23 | super().__init__() 24 | self.is_distributed = is_distributed 25 | self.train_dtset = train_dtset 26 | self.eval_dtset = eval_dtset 27 | self.outdir = outdir 28 | self.args = args 29 | self.max_steps = self.args.max_steps 30 | self.eval_steps = self.args.eval_steps 31 | self.warmup_steps = self.args.warmup_steps 32 | self.bs = self.args.bs 33 | self.lr = self.args.lr 34 | self.min_lr = self.args.min_lr 35 | self.decay = self.args.decay 36 | self.fp16 = self.args.fp16 37 | self.num_workers = self.args.num_workers 38 | 39 | if self.is_distributed: 40 | self.gpu_id = int(os.environ["LOCAL_RANK"]) 41 | else: 42 | self.gpu_id = 0 43 | self.model = model.to(self.gpu_id) 44 | 45 | self.opt = self.get_optimizer() 46 | self.train_loader = get_loader( 47 | self.train_dtset, 48 | self.num_workers, 49 | tokenizer=self.model.tokenizer, 50 | nr_learnable_tokens=self.model.nr_learnable_tokens, 51 | is_distributed=self.is_distributed, 52 | bs=self.bs, 53 | is_train=True, 54 | ) 55 | self.eval_loader = get_loader( 56 | self.eval_dtset, 57 | self.num_workers, 58 | tokenizer=self.model.tokenizer, 59 | nr_learnable_tokens=self.model.nr_learnable_tokens, 60 | 
is_distributed=self.is_distributed, 61 | bs=self.bs, 62 | is_train=False, 63 | ) 64 | self.scheduler = self.get_scheduler( 65 | self.opt, self.warmup_steps, self.max_steps, self.min_lr 66 | ) 67 | 68 | if self.fp16: 69 | self.scaler = amp.GradScaler(enabled=True) 70 | 71 | if WANDB and self.gpu_id == 0: 72 | wandb.init( 73 | config=self.args, 74 | project="devil", 75 | name=self.outdir.split("/")[-1], 76 | dir=self.outdir, 77 | ) 78 | self.best_score = 0.0 79 | self.best_loss = sys.maxsize 80 | 81 | self.step = 0 82 | if args.resume: 83 | assert ( 84 | args.model_ckpt is not None 85 | ), "Please provide a pretrained model checkpoint to resume training." 86 | loc = f"cuda:{self.gpu_id}" 87 | model_snapshot = torch.load( 88 | os.path.join(args.model_ckpt, "model.pt"), map_location=loc 89 | ) 90 | self.model.translation.load_state_dict(model_snapshot["MODEL_STATE"]) 91 | self.opt.load_state_dict(model_snapshot["OPTIMIZER_STATE"]) 92 | self.scheduler.load_state_dict(model_snapshot["SCHEDULER_STATE"]) 93 | self.step = model_snapshot["STEPS_RUN"] 94 | self.best_score = model_snapshot["SCORE"] 95 | self.best_loss = model_snapshot["LOSS"] 96 | 97 | if self.is_distributed: 98 | self.model = DDP(self.model, device_ids=[self.gpu_id]) 99 | 100 | self.forbidden_keys = ["id", "raw"] 101 | 102 | def get_parameter_names(self, model, forbidden_layer_types): 103 | result = [] 104 | for name, child in model.named_children(): 105 | result += [ 106 | f"{name}.{n}" 107 | for n in self.get_parameter_names(child, forbidden_layer_types) 108 | if not isinstance(child, tuple(forbidden_layer_types)) 109 | ] 110 | # Add model specific parameters (defined with nn.Parameter) since they are not in any child. 111 | result += list(model._parameters.keys()) 112 | return result 113 | 114 | def get_optimizer(self): 115 | decay_parameters = self.get_parameter_names( 116 | self.model.translation, [torch.nn.LayerNorm] 117 | ) 118 | decay_parameters = [name for name in decay_parameters if "bias" not in name] 119 | optimizer_grouped_parameters = [ 120 | { 121 | "params": [ 122 | p 123 | for n, p in self.model.translation.named_parameters() 124 | if n in decay_parameters 125 | ], 126 | "weight_decay": self.decay, 127 | }, 128 | { 129 | "params": [ 130 | p 131 | for n, p in self.model.translation.named_parameters() 132 | if n not in decay_parameters 133 | ], 134 | "weight_decay": 0.0, 135 | }, 136 | ] 137 | return AdamW(optimizer_grouped_parameters, lr=self.lr) 138 | 139 | def get_scheduler(self, optimizer, num_warmup_steps, num_training_steps, min_lr): 140 | def lr_lambda(current_step: int): 141 | if current_step < num_warmup_steps: 142 | return float(current_step) / float(max(1, num_warmup_steps)) 143 | 144 | # min_lr / self.lr (aka initial lr) because lambda is multiplied by initial lr (can be thought of as a %) 145 | return max( 146 | min_lr / self.lr, 147 | float(num_training_steps - current_step) 148 | / float(max(1, num_training_steps - num_warmup_steps)), 149 | ) 150 | 151 | return LambdaLR(optimizer, lr_lambda, -1) 152 | 153 | def train(self): 154 | self.model.train() 155 | 156 | if self.gpu_id == 0: 157 | pbar = tqdm(self.train_loader, file=sys.stdout) 158 | pbar.set_description("training") 159 | data_iter = iter(pbar) 160 | else: 161 | data_iter = iter(self.train_loader) 162 | 163 | while True: 164 | try: 165 | batch = next(data_iter) 166 | except StopIteration: 167 | if self.gpu_id == 0: 168 | pbar = tqdm(self.train_loader, file=sys.stdout) 169 | pbar.set_description("training") 170 | data_iter = iter(pbar) 171 | else: 
172 | data_iter = iter(self.train_loader) 173 | batch = next(data_iter) 174 | batch = { 175 | k: v.to(self.gpu_id) 176 | for k, v in batch.items() 177 | if k not in self.forbidden_keys 178 | } 179 | 180 | with amp.autocast(enabled=self.fp16): 181 | loss = self.model(**batch) 182 | 183 | self.opt.zero_grad() 184 | if self.fp16: 185 | self.scaler.scale(loss).backward() 186 | self.scaler.step(self.opt) 187 | self.scaler.update() 188 | else: 189 | loss.backward() 190 | self.opt.step() 191 | 192 | if self.scheduler: 193 | self.scheduler.step() 194 | 195 | self.step += 1 196 | 197 | if self.gpu_id == 0: 198 | if WANDB: 199 | logdict = {"train/loss": loss.item(), "train_step": self.step} 200 | if self.scheduler: 201 | logdict.update( 202 | { 203 | "lr": self.scheduler.get_last_lr()[0], 204 | "lr_step": self.step, 205 | } 206 | ) 207 | 208 | if self.step % self.eval_steps == 0: 209 | val_loss, metrics, plot_data = eval_nlp( 210 | self.model, 211 | self.eval_loader, 212 | self.outdir, 213 | gpu_id=self.gpu_id, 214 | is_distributed=self.is_distributed, 215 | outname=f"step_{self.step:08d}", 216 | milan=self.args.milan, 217 | ) 218 | assert self.model.training is True 219 | 220 | if WANDB: 221 | logdict.update( 222 | { 223 | "val_step": self.step, 224 | "val/loss": val_loss, 225 | } 226 | ) 227 | for k, v in metrics.items(): 228 | logdict["val/" + k] = v 229 | 230 | for image_title, grid, cpt, _ in plot_data: 231 | logdict.update( 232 | {image_title: wandb.Image(grid, caption=cpt)} 233 | ) 234 | 235 | # save model with best bertscore 236 | bertscore = metrics["BERTScore"] 237 | is_best_score = bertscore > self.best_score 238 | self.best_score = max(self.best_score, bertscore) 239 | 240 | # save model with best val loss 241 | is_best_loss = val_loss < self.best_loss 242 | self.best_loss = min(val_loss, self.best_loss) 243 | 244 | # save latest checkpoint 245 | self.save_checkpoint( 246 | self.step, bertscore, val_loss, is_best_score, is_best_loss 247 | ) 248 | 249 | if WANDB: 250 | wandb.log(logdict) 251 | 252 | if self.step == self.max_steps: 253 | break 254 | 255 | def save_checkpoint( 256 | self, curr_step, score, loss, is_best_score=False, is_best_loss=False 257 | ): 258 | if self.is_distributed: 259 | model = self.model.module 260 | else: 261 | model = self.model 262 | 263 | snapshot = { 264 | "MODEL_STATE": model.translation.state_dict(), 265 | "OPTIMIZER_STATE": self.opt.state_dict(), 266 | "SCHEDULER_STATE": self.scheduler.state_dict(), 267 | "STEPS_RUN": curr_step, 268 | "SCORE": score, 269 | "LOSS": loss, 270 | } 271 | 272 | save_path = os.path.join(self.outdir, "latest_checkpoint") 273 | os.makedirs(save_path, exist_ok=True) 274 | 275 | torch.save(snapshot, os.path.join(save_path, "model.pt")) 276 | 277 | model.lm_model.config.to_json_file( 278 | os.path.join(save_path, "lm_model_config.json") 279 | ) 280 | model.translation.config.to_json_file( 281 | os.path.join(save_path, "translation_model_config.json") 282 | ) 283 | 284 | if hasattr(self, "tokenizer"): 285 | self.model.tokenizer.save_pretrained(save_path) 286 | 287 | if is_best_score: 288 | copy_tree(save_path, os.path.join(self.outdir, "best_checkpoint_score")) 289 | if is_best_loss: 290 | copy_tree(save_path, os.path.join(self.outdir, "best_checkpoint_loss")) 291 | -------------------------------------------------------------------------------- /src/dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import random 3 | from random import choice 4 | from typing import Any, Dict 
5 | 6 | import numpy as np 7 | import pandas as pd 8 | import torch 9 | import webdataset as wds 10 | from torch.utils.data import DataLoader, Dataset 11 | from torch.utils.data.distributed import DistributedSampler 12 | from torchvision.transforms import Compose, Normalize, ToTensor 13 | 14 | from milan_keys import DATASET_GROUPINGS, KEYS, TRAIN_TEST_PAIRS, WITHIN_NETWORK 15 | 16 | random.seed(0) 17 | 18 | 19 | class CustomDataCollator: 20 | def __init__(self, tokenizer, nr_learnable_tokens) -> None: 21 | self.tokenizer = tokenizer 22 | self.nr_learnable_tokens = nr_learnable_tokens 23 | 24 | def __call__(self, batch) -> Dict[str, Any]: 25 | out_batch = {} 26 | 27 | if "id" in batch[0]: 28 | ids = [i["id"] for i in batch] 29 | elif "__key__" in batch[0]: 30 | ids = [i["__key__"] for i in batch] 31 | else: 32 | raise ValueError 33 | 34 | imgs = [i["image"] for i in batch] 35 | imgs = torch.stack(imgs) 36 | 37 | texts = [i["text"] + self.tokenizer.eos_token for i in batch] 38 | texts = self.tokenizer( 39 | texts, 40 | padding="longest", 41 | truncation=True, 42 | max_length=100, 43 | return_tensors="pt", 44 | ) 45 | 46 | bs = len(ids) 47 | 48 | gt_captions = texts.input_ids.clone() 49 | gt_captions = torch.where( 50 | texts.attention_mask == 0, -100, gt_captions 51 | ) # ignore padding tokens in loss 52 | gt_captions = torch.cat( 53 | ( 54 | gt_captions[:, :1], 55 | torch.ones((bs, self.nr_learnable_tokens), dtype=torch.long) * (-100), 56 | gt_captions[:, 1:], 57 | ), 58 | dim=1, 59 | ) 60 | 61 | attention_mask = texts.attention_mask 62 | attention_mask = torch.cat( 63 | ( 64 | attention_mask[:, :1], 65 | torch.ones((bs, self.nr_learnable_tokens), dtype=torch.long), 66 | attention_mask[:, 1:], 67 | ), 68 | dim=1, 69 | ) 70 | 71 | out_batch = { 72 | "id": ids, 73 | "pixel_values": imgs, 74 | "input_ids": texts["input_ids"], 75 | "attention_mask": attention_mask, 76 | "gt_captions": gt_captions, 77 | } 78 | 79 | if "label" in batch[0]: 80 | out_batch["label"] = torch.tensor([i["label"] for i in batch]) 81 | 82 | # raw text for NLP metrics 83 | if "raw" in batch[0]: 84 | out_batch["raw"] = [x["raw"] for x in batch] 85 | else: 86 | out_batch["raw"] = [ 87 | [x["text"]] for x in batch 88 | ] # single caption (webdatasets) 89 | 90 | if "mask" in batch[0]: 91 | masks = [i["mask"] for i in batch] 92 | masks = torch.stack(masks) 93 | out_batch["mask"] = masks 94 | 95 | return out_batch 96 | 97 | 98 | def filter_no_caption_or_no_image(sample): 99 | has_caption = "txt" in sample 100 | has_image = "png" in sample or "jpg" in sample or "jpeg" in sample 101 | return has_caption and has_image 102 | 103 | 104 | def chunkIt(seq, num): 105 | avg = len(seq) / float(num) 106 | out = [] 107 | last = 0.0 108 | 109 | while last < len(seq): 110 | out.append(seq[int(last) : int(last + avg)]) 111 | last += avg 112 | 113 | return out 114 | 115 | 116 | def get_wds_dataset(data_root, shard, transform, batch_size, collator=None, val=False): 117 | """ 118 | return a dataset that returns an image, and text 119 | """ 120 | if val == False: 121 | if "LOCAL_RANK" in os.environ: 122 | world_size = int(os.environ["WORLD_SIZE"]) 123 | local_rank = int(os.environ["LOCAL_RANK"]) 124 | else: 125 | world_size = 1 126 | local_rank = 0 127 | 128 | shard_split = chunkIt(shard, world_size) 129 | shard = shard_split[local_rank] 130 | 131 | total_shards = len(shard) 132 | shard = "{" + f"{shard[0]}..{shard[-1]}" + "}.tar" 133 | 134 | input_shards = os.path.join(data_root, shard) 135 | 136 | pipeline = [ 137 | 
wds.SimpleShardList(input_shards), 138 | # at this point we have an iterator over all the shards 139 | ] 140 | 141 | if val == False: 142 | pipeline.extend( 143 | [ 144 | wds.shuffle(bufsize=total_shards, initial=total_shards), 145 | ] 146 | ) 147 | 148 | pipeline.extend( 149 | [ 150 | wds.split_by_worker, 151 | # at this point, we have an iterator over the shards assigned to each worker 152 | wds.tarfile_to_samples(), 153 | wds.select(filter_no_caption_or_no_image), 154 | wds.decode("pilrgb"), 155 | wds.rename(image="jpg;png;jpeg", text="txt"), 156 | wds.map_dict(image=transform, text=lambda text: text), 157 | ] 158 | ) 159 | 160 | if val == False: 161 | pipeline.extend([wds.shuffle(100 * batch_size)]) 162 | 163 | pipeline.extend([wds.batched(batch_size, partial=False, collation_fn=collator)]) 164 | 165 | dataset = wds.DataPipeline(*pipeline) 166 | return dataset 167 | 168 | 169 | def get_loader( 170 | dtset, 171 | num_workers, 172 | tokenizer=None, 173 | nr_learnable_tokens=None, 174 | is_distributed=None, 175 | bs=None, 176 | is_train=None, 177 | ): 178 | if isinstance(dtset, MILANDataset): 179 | collator = CustomDataCollator(tokenizer, nr_learnable_tokens) 180 | 181 | return DataLoader( 182 | dtset, 183 | batch_size=bs, 184 | num_workers=num_workers, 185 | drop_last=is_train, 186 | collate_fn=collator, 187 | pin_memory=True, 188 | shuffle=(is_train) and not is_distributed, 189 | sampler=DistributedSampler(dtset) 190 | if (is_train) and is_distributed 191 | else None, 192 | ) 193 | else: 194 | return wds.WebLoader( 195 | dtset, 196 | batch_size=None, 197 | num_workers=num_workers, 198 | pin_memory=True, 199 | shuffle=False, 200 | ) 201 | 202 | 203 | def get_milan_transform(transform): 204 | tfs = [ToTensor()] 205 | for t in transform.transforms: 206 | if isinstance(t, Normalize): 207 | tfs.append(t) 208 | break 209 | 210 | target_transform = Compose(tfs) 211 | return target_transform 212 | 213 | 214 | class MILANDataset(Dataset): 215 | def __init__(self, dtset, root, split="train", transform=None, by_unit=False): 216 | super().__init__() 217 | self.dtset = KEYS[dtset] 218 | self.root = root 219 | self.split = split 220 | self.transform = transform 221 | self.by_unit = by_unit 222 | self.nr_imgs_per_unit = 15 223 | 224 | self.idx_to_info = {} 225 | self.global_id = 0 226 | if self.dtset in DATASET_GROUPINGS: 227 | if self.split == "train": 228 | dtset_list = DATASET_GROUPINGS[self.dtset] 229 | else: 230 | dtset_list = TRAIN_TEST_PAIRS[self.dtset] 231 | if dtset_list in DATASET_GROUPINGS: 232 | dtset_list = DATASET_GROUPINGS[dtset_list] 233 | else: 234 | dtset_list = [dtset_list] 235 | for name in dtset_list: 236 | self.get_single_dataset(name) 237 | elif self.dtset in WITHIN_NETWORK: 238 | self.get_single_dataset(self.dtset, self.split) 239 | else: 240 | raise NotImplementedError 241 | 242 | def get_single_dataset(self, name, split=None): 243 | anns_csv = os.path.join(os.path.join(self.root, name), "annotations.csv") 244 | anns = pd.read_csv(anns_csv) 245 | anns["summary"] = anns["summary"].astype(str) 246 | anns["layer"] = anns["layer"].astype(str) 247 | layers = sorted(list(anns["layer"].unique())) 248 | 249 | if split is not None: 250 | # process split file 251 | split_units = {} 252 | split_idxs = torch.load( 253 | os.path.join(self.root, name, f'{name.replace("/", "_")}-splits.pth') 254 | )[split] 255 | for item in split_idxs: 256 | if item["layer"] not in split_units: 257 | split_units[item["layer"]] = [] 258 | split_units[item["layer"]].append(item["unit"]) 259 | 260 | for l in 
layers: 261 | units = anns[anns["layer"] == l]["unit"].unique() 262 | for u in units: 263 | # ignore units that are not in this split 264 | if (split is not None) and (u not in split_units[l]): 265 | continue 266 | 267 | cpts = anns.loc[(anns["layer"] == l) & (anns["unit"] == u)] 268 | cpts = list(cpts["summary"].values) 269 | 270 | if self.by_unit: 271 | dct = {"dataset": name, "layer": l, "unit": u, "captions": cpts} 272 | self.idx_to_info[self.global_id] = dct 273 | self.global_id += 1 274 | else: 275 | # repeat information for all images of unit 276 | for i in range(self.nr_imgs_per_unit): 277 | dct = { 278 | "dataset": name, 279 | "layer": l, 280 | "unit": u, 281 | "imgid": i, 282 | "captions": cpts, 283 | } 284 | self.idx_to_info[self.global_id] = dct 285 | self.global_id += 1 286 | 287 | def __len__(self): 288 | return len(self.idx_to_info) 289 | 290 | def __getitem__(self, index): 291 | info = self.idx_to_info[index] 292 | dataset, layer, unit = info["dataset"], info["layer"], info["unit"] 293 | 294 | unit_imgs = np.load( 295 | os.path.join(self.root, dataset, layer, f"images_{unit}.npy") 296 | ) 297 | unit_masks = np.load( 298 | os.path.join(self.root, dataset, layer, f"masks_{unit}.npy") 299 | ) 300 | 301 | if self.by_unit == False: 302 | imgid = info["imgid"] 303 | unit_imgs = unit_imgs[imgid] 304 | unit_imgs = np.swapaxes(unit_imgs, 0, 1) 305 | unit_imgs = np.swapaxes(unit_imgs, 1, 2) 306 | if self.transform: 307 | unit_imgs = self.transform(unit_imgs) 308 | unit_masks = unit_masks[imgid] 309 | else: 310 | if self.transform: 311 | transformed = [] 312 | for img in unit_imgs: 313 | img = np.swapaxes(img, 0, 1) 314 | img = np.swapaxes(img, 1, 2) 315 | transformed.append(self.transform(img)) 316 | unit_imgs = torch.stack(transformed) 317 | 318 | mask = torch.from_numpy(unit_masks).float() 319 | 320 | captions = info["captions"] 321 | captions = choice(captions) 322 | 323 | if self.split == "train": 324 | return { 325 | "image": unit_imgs, 326 | "mask": mask, 327 | "text": captions, 328 | "id": str(index), 329 | } 330 | return { 331 | "image": unit_imgs, 332 | "mask": mask, 333 | "text": captions, 334 | "raw": info["captions"], 335 | "id": str(index), 336 | } 337 | -------------------------------------------------------------------------------- /src/main.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import datetime 3 | import os 4 | import random 5 | from argparse import Namespace # needed for reading saved argparse parameters 6 | 7 | import torch 8 | import yaml 9 | from torch.distributed import destroy_process_group, init_process_group 10 | from transformers import AutoConfig 11 | 12 | from dataset import ( 13 | CustomDataCollator, 14 | MILANDataset, 15 | get_loader, 16 | get_milan_transform, 17 | get_wds_dataset, 18 | ) 19 | from evaluation import eval_nlp, eval_qualitative 20 | from milan_keys import KEYS 21 | from model import Features2WordsModel, TranslationTransformerConfig 22 | from trainer import Trainer 23 | 24 | random.seed(0) 25 | torch.manual_seed(0) 26 | 27 | KEYS = list(KEYS.keys()) 28 | 29 | DATASETS = ["cc3m"] 30 | DATASETS.extend(KEYS) 31 | 32 | 33 | def ddp_setup(): 34 | init_process_group(backend="nccl") 35 | 36 | 37 | def _create_folder(args): 38 | folder = args.logdir 39 | 40 | timestamp = datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S") 41 | vision_backbone = args.vision_backbone.replace("/", "_") 42 | language_model = args.language_model.split("/")[-1] 43 | vision_feat_layers = ( 44 | 
str(args.vision_feat_layers).replace("[", "").replace("]", "").replace(", ", "") 45 | ) 46 | 47 | fname = "" 48 | fname += f"{vision_backbone}_{language_model}_layers_{vision_feat_layers}" 49 | if args.token_dropout > 0: 50 | fname += f"_dropout_{args.token_dropout}" 51 | if args.feature_dropout > 0: 52 | fname += f"_featdropout_{args.feature_dropout}" 53 | fname += f"_tokens_{args.nr_learnable_tokens}_{args.dataset}_{timestamp}" 54 | results_path = os.path.join(folder, fname) 55 | 56 | if not os.path.exists(results_path): 57 | os.makedirs(results_path) 58 | 59 | return results_path 60 | 61 | 62 | if __name__ == "__main__": 63 | parser = argparse.ArgumentParser(description="Arguments to run the script.") 64 | 65 | parser.add_argument( 66 | "--num_workers", type=int, default=8, help="Number of workers for dataloader." 67 | ) 68 | parser.add_argument( 69 | "--fp16", 70 | dest="fp16", 71 | action="store_true", 72 | help="Use 16-bit floating-point precision.", 73 | ) 74 | parser.add_argument( 75 | "--no-fp16", 76 | dest="fp16", 77 | action="store_false", 78 | help="Do not use 16-bit floating-point precision.", 79 | ) 80 | parser.set_defaults(fp16=False) 81 | 82 | # Directories and paths 83 | parser.add_argument( 84 | "--logdir", 85 | type=str, 86 | default="./results", 87 | help="Directory where logs and models are to be stored.", 88 | ) 89 | 90 | # Actions 91 | parser.add_argument( 92 | "--do_eval_nlp", 93 | action="store_true", 94 | help="Evaluate model given by model_ckpt on nlp metrics.", 95 | ) 96 | parser.add_argument( 97 | "--do_eval_qualitative", 98 | action="store_true", 99 | help="Evaluate model given by model_ckpt qualitatively.", 100 | ) 101 | 102 | # Dataset 103 | parser.add_argument( 104 | "--dataset", 105 | type=str, 106 | default="cc3m", 107 | choices=DATASETS, 108 | help="Dataset.", 109 | ) 110 | parser.add_argument( 111 | "--by_unit", 112 | action="store_true", 113 | help="Use MILAN datasets by neuron unit (instead of by image).", 114 | ) 115 | parser.add_argument( 116 | "--data_root", 117 | type=str, 118 | default="./data", 119 | help="Root directory of the dataset.", 120 | ) 121 | 122 | # Model 123 | parser.add_argument( 124 | "--nr_layers", 125 | type=int, 126 | default=12, 127 | help="Number of hidden layers in transformer translation model.", 128 | ) 129 | parser.add_argument( 130 | "--nr_heads", 131 | type=int, 132 | default=12, 133 | help="Number of heads per attention layer.", 134 | ) 135 | parser.add_argument( 136 | "--intermediate_size", 137 | type=int, 138 | default=3072, 139 | help="Translation transformer intermediate size.", 140 | ) 141 | parser.add_argument( 142 | "--nr_learnable_tokens", 143 | type=int, 144 | default=10, 145 | help="Number of learnable tokens in transformer translation model.", 146 | ) 147 | parser.add_argument( 148 | "--language_model", 149 | type=str, 150 | default="facebook/opt-125m", 151 | choices=["gpt2", "facebook/opt-125m"], 152 | help="LM used to produce words.", 153 | ) 154 | parser.add_argument( 155 | "--vision_backbone", 156 | type=str, 157 | default="timm_resnet50", 158 | help="Vision backbone from which image features are produced.", 159 | ) 160 | parser.add_argument( 161 | "--vision_feat_func", 162 | type=str, 163 | default="avg_pooling", 164 | choices=["none", "avg_pooling"], 165 | help="Function applied to vision features.", 166 | ) 167 | parser.add_argument( 168 | "--vision_feat_layers", 169 | nargs="+", 170 | type=int, 171 | default=[-1], 172 | help=( 173 | "List of feature layers (indexed by order) of the vision model " 174 
| "to pass to the transformer translation model." 175 | ), 176 | ) 177 | parser.add_argument( 178 | "--token_dropout", 179 | type=float, 180 | default=0.5, 181 | help="Dropout probability for vision tokens in the translation model.", 182 | ) 183 | parser.add_argument( 184 | "--feature_dropout", 185 | type=float, 186 | default=0.5, 187 | help="Dropout probability for vision features in the backbone model.", 188 | ) 189 | parser.add_argument( 190 | "--add_vit_embed_token", 191 | action="store_true", 192 | help="Extracts ViT embedding token in addition to spatial tokens.", 193 | ) 194 | parser.add_argument( 195 | "--only_vit_embed_token", 196 | action="store_true", 197 | help="Extract only ViT embedding token. (Only supported for Clip ViT atm).", 198 | ) 199 | 200 | # Training 201 | parser.add_argument( 202 | "--model_ckpt", 203 | type=str, 204 | default=None, 205 | help="Pretrained model.", 206 | ) 207 | parser.add_argument( 208 | "--resume", 209 | dest="resume", 210 | action="store_true", 211 | help="Resume model training.", 212 | ) 213 | parser.add_argument( 214 | "--max_steps", 215 | type=int, 216 | default=1000000, 217 | help="Maximum number of training steps.", 218 | ) 219 | parser.add_argument( 220 | "--eval_steps", 221 | type=int, 222 | default=10000, 223 | help="Number of steps between evaluations.", 224 | ) 225 | parser.add_argument( 226 | "--warmup_steps", 227 | type=int, 228 | default=5000, 229 | help="Number of warmup steps for optimizer.", 230 | ) 231 | parser.add_argument( 232 | "--bs", 233 | type=int, 234 | default=64, 235 | help="Batch size.", 236 | ) 237 | parser.add_argument( 238 | "--lr", 239 | type=float, 240 | default=1e-4, 241 | help="Learning rate.", 242 | ) 243 | parser.add_argument( 244 | "--min_lr", 245 | type=float, 246 | default=1e-6, 247 | help="Minimum learning rate.", 248 | ) 249 | parser.add_argument( 250 | "--decay", 251 | type=float, 252 | default=1e-6, 253 | help="Learning rate decay.", 254 | ) 255 | 256 | # NLP eval 257 | parser.add_argument( 258 | "--layer", 259 | type=int, 260 | help="Layer id (eg. -1) for NLP eval." 261 | "If left None will compute metrics with features for all layers." 262 | "Only applicable to models trained with features for more than one layer.", 263 | ) 264 | 265 | # Qualitative eval 266 | parser.add_argument( 267 | "--loc_ids", 268 | type=yaml.safe_load, 269 | help="Location ids per layer for qualitative eval." 
270 | "Example: {-1: [[0, 0], [1, 6]], -3: []} (given between str quotes) will generate:" 271 | "- description for full image with features from all layers" 272 | "- description for full image with features from layers -1 and -3" 273 | "- description for image locations (0,0) and (1,6) for layer -1" 274 | "if, for example, -1: [[-1, -1]] is given it will generate descriptions for all locations in layer -1", 275 | ) 276 | parser.add_argument( 277 | "--pool_locs", 278 | type=str, 279 | default="None", 280 | choices=["reduce_dims", "keep_dims"], 281 | help="Pool lower layer locations (available for eval_qualitative).", 282 | ) 283 | parser.add_argument( 284 | "--kernel_size", 285 | type=int, 286 | default=3, 287 | help="Kernel size for pooling lower layer locations.", 288 | ) 289 | 290 | # Generation 291 | parser.add_argument( 292 | "--max_length", 293 | type=int, 294 | default=100, 295 | help="Max length of generated sequence.", 296 | ) 297 | 298 | # Logging 299 | parser.add_argument( 300 | "--wandb_online", 301 | action="store_true", 302 | help="Start WANDB in online sync mode.", 303 | ) 304 | 305 | args = parser.parse_args() 306 | 307 | args.milan = args.dataset in KEYS 308 | 309 | vargs = vars(args) 310 | n_do_evals = sum([vargs[arg] for arg in vargs.keys() if arg.startswith("do_eval")]) 311 | do_train = n_do_evals == 0 312 | 313 | if n_do_evals > 0: 314 | assert ( 315 | args.model_ckpt is not None 316 | ), "Please provide a checkpoint to perform evaluation." 317 | 318 | if args.do_eval_qualitative: 319 | assert args.dataset == "cc3m", "Invalid dataset for do_eval_qualitative." 320 | 321 | if args.pool_locs != "None": 322 | assert ( 323 | args.layer is not None and args.layer != "-1" 324 | ), "pool_locs is only available for layers other than -1." 325 | 326 | if args.only_vit_embed_token: 327 | assert ( 328 | args.vision_backbone == "clip_ViT-B/32" 329 | ), "only_vit_embed_token is only supported for Clip ViT models." 
330 | 331 | torch.set_num_threads(max(1, args.num_workers)) 332 | torch.set_num_interop_threads(max(1, args.num_workers)) 333 | if "LOCAL_RANK" in os.environ and do_train: 334 | ddp_setup() 335 | is_distributed = True 336 | else: 337 | is_distributed = False 338 | 339 | if not args.wandb_online: 340 | os.environ["WANDB_MODE"] = "offline" 341 | 342 | # ensure that vision_feat_layers are in decreasing order [-1, -2, -3] 343 | if len(args.vision_feat_layers) > 1: 344 | args.vision_feat_layers = sorted(args.vision_feat_layers)[::-1] 345 | 346 | if args.only_vit_embed_token: 347 | args.add_vit_embed_token = True 348 | 349 | if do_train: 350 | path = _create_folder(args) 351 | 352 | with open(os.path.join(path, "train_params.txt"), "w") as f: 353 | f.write(str(args)) 354 | 355 | translation_config = TranslationTransformerConfig( 356 | intermediate_size=args.intermediate_size, 357 | num_hidden_layers=args.nr_layers, 358 | num_attention_heads=args.nr_heads, 359 | nr_learnable_tokens=args.nr_learnable_tokens, 360 | token_dropout=args.token_dropout, 361 | add_vit_embed_token=args.add_vit_embed_token, 362 | ) 363 | 364 | else: 365 | translation_config = TranslationTransformerConfig.from_json_file( 366 | args.model_ckpt + "translation_model_config.json" 367 | ) 368 | lm_model_config = AutoConfig.from_pretrained( 369 | args.model_ckpt + "lm_model_config.json" 370 | ) 371 | args.max_length = lm_model_config.max_length 372 | 373 | with open(os.path.join(args.model_ckpt, "../train_params.txt"), "r") as f: 374 | namespace_str = f.read() 375 | train_args = eval(namespace_str) 376 | args.vision_backbone = train_args.vision_backbone 377 | args.vision_feat_func = train_args.vision_feat_func 378 | args.vision_feat_layers = train_args.vision_feat_layers 379 | args.language_model = train_args.language_model 380 | if "feature_dropout" in train_args: 381 | args.feature_dropout = train_args.feature_dropout 382 | if "by_unit" in train_args: 383 | args.by_unit = train_args.by_unit 384 | if "only_vit_embed_token" in train_args: 385 | args.only_vit_embed_token = train_args.only_vit_embed_token 386 | 387 | if args.do_eval_nlp and len(args.vision_feat_layers) == 1: 388 | assert ( 389 | args.layer is None 390 | ), "A layer different from None can only be given when the model has been trained on features from more than one layer." 
391 | 392 | model = Features2WordsModel( 393 | translation_config=translation_config, 394 | cnn_name=args.vision_backbone, 395 | vision_feat_func=args.vision_feat_func, 396 | vision_feat_layers=args.vision_feat_layers, 397 | lm_model_name=args.language_model, 398 | max_length=args.max_length, 399 | feature_dropout=args.feature_dropout, 400 | only_vit_embed_token=args.only_vit_embed_token, 401 | ) 402 | 403 | test_dtset = None 404 | if args.milan: 405 | transform = get_milan_transform(model.transform) 406 | if do_train: 407 | tr_dtset = MILANDataset( 408 | args.dataset, 409 | os.path.join(args.data_root, "milan", "data"), 410 | split="train", 411 | transform=transform, 412 | by_unit=args.by_unit, 413 | ) 414 | val_dtset = MILANDataset( 415 | args.dataset, 416 | os.path.join(args.data_root, "milan", "data"), 417 | split="test", 418 | transform=transform, 419 | by_unit=args.by_unit, 420 | ) 421 | else: 422 | collator = CustomDataCollator(model.tokenizer, model.nr_learnable_tokens) 423 | 424 | train_shard = [f"{i:05d}" for i in range(332)] 425 | val_shard = [f"{i:05d}" for i in range(2)] 426 | train_path = "cc3m/cc3m/train" 427 | val_path = "cc3m/cc3m/valid" 428 | 429 | if do_train: 430 | tr_dtset = get_wds_dataset( 431 | os.path.join(args.data_root, train_path), 432 | train_shard, 433 | model.transform, 434 | args.bs, 435 | collator=collator, 436 | ) 437 | val_dtset = get_wds_dataset( 438 | os.path.join(args.data_root, val_path), 439 | val_shard, 440 | model.transform, 441 | args.bs, 442 | collator=collator, 443 | val=True, 444 | ) 445 | 446 | if do_train: 447 | # Trainer 448 | trainer = Trainer(model, tr_dtset, val_dtset, path, args, is_distributed) 449 | 450 | trainer.train() 451 | else: 452 | model_checkpoint = torch.load(os.path.join(args.model_ckpt, "model.pt")) 453 | model.translation.load_state_dict(model_checkpoint["MODEL_STATE"]) 454 | model.eval() 455 | model.to(0) 456 | print(f'Loaded model from {os.path.join(args.model_ckpt, "model.pt")}') 457 | 458 | val_loader = get_loader( 459 | val_dtset, 460 | args.num_workers, 461 | tokenizer=model.tokenizer, 462 | nr_learnable_tokens=model.nr_learnable_tokens, 463 | is_distributed=False, 464 | bs=args.bs, 465 | is_train=False, 466 | ) 467 | if test_dtset is not None: 468 | test_loader = get_loader( 469 | test_dtset, 470 | args.num_workers, 471 | tokenizer=model.tokenizer, 472 | nr_learnable_tokens=model.nr_learnable_tokens, 473 | is_distributed=False, 474 | bs=args.bs, 475 | is_train=False, 476 | ) 477 | 478 | if args.pool_locs == "None": 479 | args.pool_locs = None 480 | 481 | if args.do_eval_nlp: 482 | outname = args.dataset 483 | if outname in KEYS: 484 | if args.by_unit: 485 | outname += "_by_unit" 486 | else: 487 | outname += "_by_imgs" 488 | eval_nlp( 489 | model, 490 | val_loader, 491 | args.model_ckpt, 492 | outname=outname, 493 | add_spice=True, 494 | layer=args.layer, 495 | viz=False, 496 | milan=args.milan, 497 | ) 498 | if test_dtset is not None: 499 | outname += "_TEST" 500 | eval_nlp( 501 | model, 502 | test_loader, 503 | args.model_ckpt, 504 | outname=outname, 505 | add_spice=True, 506 | layer=args.layer, 507 | viz=False, 508 | milan=args.milan, 509 | ) 510 | if args.do_eval_qualitative: 511 | if args.loc_ids is not None: 512 | for k, v in args.loc_ids.items(): 513 | args.loc_ids[k] = [tuple(l) for l in v] 514 | foldername = f"{args.dataset}_feat_select_results" 515 | if args.pool_locs is not None: 516 | foldername += f"_{args.pool_locs}_ks{args.kernel_size}" 517 | eval_qualitative( 518 | model, 519 | val_loader, 520 | 
os.path.join(args.model_ckpt, foldername), 521 | loc_ids=args.loc_ids, 522 | pool_locs=args.pool_locs, 523 | kernel_size=args.kernel_size, 524 | ) 525 | 526 | if is_distributed: 527 | destroy_process_group() 528 | -------------------------------------------------------------------------------- /src/evaluation.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | import sys 4 | from itertools import product 5 | 6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | from evaluate import load 10 | from pycocoevalcap.bleu.bleu import Bleu 11 | from pycocoevalcap.cider.cider import Cider 12 | from pycocoevalcap.meteor.meteor import Meteor 13 | from pycocoevalcap.rouge.rouge import Rouge 14 | from pycocoevalcap.spice.spice import Spice 15 | from pycocoevalcap.tokenizer.ptbtokenizer import PTBTokenizer 16 | from torchvision.transforms import ToPILImage 17 | from torchvision.utils import make_grid 18 | from tqdm import tqdm 19 | 20 | 21 | class RunningMean: 22 | def __init__(self): 23 | self.value = 0 24 | self.cnt = 0 25 | 26 | def update(self, val, n): 27 | if self.cnt == 0: 28 | self.value = val 29 | self.cnt = n 30 | else: 31 | self.cnt += n 32 | ratio = n / self.cnt 33 | self.value += (val - self.value) * ratio 34 | 35 | def get(self): 36 | return self.value 37 | 38 | 39 | def grid(n, dim): 40 | return sorted(list(set(product(range(n), repeat=dim)))) 41 | 42 | 43 | def compute_bert_score(res, gts, milan=False, return_individual=False): 44 | scorer = load("bertscore") # reset metric 45 | assert len(res) == len(gts) 46 | 47 | hyps = [] 48 | refs = [] 49 | for imgid in res.keys(): 50 | hyps.append(res[imgid].lower()) 51 | # support for multiple references per image 52 | refs.append([r.lower() for r in gts[imgid]]) 53 | 54 | if milan: 55 | results = scorer.compute( 56 | predictions=hyps, 57 | references=refs, 58 | lang="en", 59 | idf=True, 60 | rescale_with_baseline=True, 61 | use_fast_tokenizer=True, 62 | ) 63 | else: 64 | results = scorer.compute(predictions=hyps, references=refs, lang="en") 65 | 66 | if return_individual: 67 | return sum(results["f1"]) / len(results["f1"]), results["f1"] 68 | return sum(results["f1"]) / len(results["f1"]) 69 | 70 | 71 | def compute_coco_metrics(res, gts, add_spice=False, return_individual=False): 72 | def setImgToEvalImgs(imgToEval, scores, imgIds, method): 73 | for imgId, score in zip(imgIds, scores): 74 | if not imgId in imgToEval: 75 | imgToEval[imgId] = {} 76 | imgToEval[imgId]["image_id"] = imgId 77 | imgToEval[imgId][method] = score 78 | 79 | def setEvalImgs(imgToEval): 80 | return [eval for imgId, eval in imgToEval.items()] 81 | 82 | # ================================================= 83 | # Set up scorers 84 | # ================================================= 85 | print("tokenization...") 86 | tokenizer = PTBTokenizer() 87 | gts = tokenizer.tokenize(gts) 88 | res = tokenizer.tokenize(res) 89 | 90 | # ================================================= 91 | # Set up scorers 92 | # ================================================= 93 | print("setting up scorers...") 94 | scorers = [ 95 | (Bleu(4), ["Bleu_1", "Bleu_2", "Bleu_3", "Bleu_4"]), 96 | (Meteor(), "METEOR"), 97 | (Rouge(), "ROUGE_L"), 98 | (Cider(), "CIDEr"), 99 | ] 100 | if add_spice: 101 | scorers.append((Spice(), "SPICE")) 102 | 103 | # ================================================= 104 | # Compute scores 105 | # ================================================= 106 | eval_dict = {} 107 | if 
return_individual: 108 | imgToEval = {} 109 | for scorer, method in scorers: 110 | print("computing %s score..." % (scorer.method())) 111 | score, scores = scorer.compute_score(gts, res) 112 | if type(method) == list: 113 | for sc, scs, m in zip(score, scores, method): 114 | eval_dict[m] = sc 115 | if return_individual: 116 | setImgToEvalImgs(imgToEval, scs, gts.keys(), m) 117 | print("%s: %0.3f" % (m, sc)) 118 | else: 119 | eval_dict[method] = score 120 | if return_individual: 121 | setImgToEvalImgs(imgToEval, scores, gts.keys(), method) 122 | print("%s: %0.3f" % (method, score)) 123 | 124 | if return_individual: 125 | return eval_dict, setEvalImgs(imgToEval) 126 | return eval_dict 127 | 128 | 129 | @torch.no_grad() 130 | def eval_nlp( 131 | model, 132 | eval_loader, 133 | save_dir, 134 | gpu_id=0, 135 | is_distributed=False, 136 | outname=None, 137 | add_spice=False, 138 | layer=None, 139 | viz=True, 140 | milan=False, 141 | ): 142 | os.makedirs(save_dir, exist_ok=True) 143 | model.eval() 144 | 145 | avg_loss = RunningMean() 146 | gt_captions_coco = {} 147 | pred_captions_coco = {} 148 | gt_captions = {} 149 | pred_captions = {} 150 | 151 | pbar = tqdm(eval_loader, file=sys.stdout) 152 | pbar.set_description("validating") 153 | 154 | plot_data = [] 155 | for step, batch in enumerate(pbar): 156 | ids = batch["id"] 157 | gt_strs = batch["raw"] 158 | 159 | forbidden_keys = ["id", "raw"] 160 | batch = {k: v.to(gpu_id) for k, v in batch.items() if k not in forbidden_keys} 161 | 162 | loss = model(**batch) 163 | token_cnt = (batch["gt_captions"][:, 1:] != -100).sum() 164 | 165 | avg_loss.update(loss.item(), token_cnt.item()) 166 | 167 | # for NLG metrics computation 168 | mask = batch["mask"].to(gpu_id) if milan else None 169 | if is_distributed: 170 | gen_sents = model.module.generate( 171 | batch["pixel_values"], mask=mask, layer=layer 172 | ) 173 | pred_strs = model.module.tokenizer.batch_decode( 174 | gen_sents, skip_special_tokens=True 175 | ) 176 | else: 177 | gen_sents = model.generate(batch["pixel_values"], mask=mask, layer=layer) 178 | pred_strs = model.tokenizer.batch_decode( 179 | gen_sents, skip_special_tokens=True 180 | ) 181 | 182 | for out_idx in range(len(ids)): 183 | gen = pred_strs[out_idx] 184 | gt = gt_strs[out_idx] 185 | img_id = ids[out_idx] 186 | 187 | gt_captions_coco[img_id] = [{"image_id": img_id, "caption": g} for g in gt] 188 | gt_captions[img_id] = gt 189 | 190 | pred_captions_coco[img_id] = [{"image_id": img_id, "caption": gen}] 191 | pred_captions[img_id] = gen 192 | 193 | if ( 194 | viz and (step == 0 or step == 5) and out_idx < 10 195 | ): # log first 10 images of first/6th batch 196 | og_img = batch["pixel_values"][out_idx].cpu() 197 | grid = make_grid([og_img], padding=5, pad_value=1, nrow=2) 198 | image_title = f"image_{img_id}" 199 | cpt = "GT: " + gt[0] + "\nGen: " + gen 200 | plot_data.append((image_title, grid, cpt, step)) 201 | 202 | if outname is not None: 203 | save_dir = os.path.join(save_dir, outname + "-") 204 | else: 205 | if not save_dir.endswith("/"): 206 | save_dir += "/" 207 | 208 | if layer is None: 209 | save_dir += "all_features_" 210 | else: 211 | save_dir += f"layer{layer}_features_" 212 | 213 | eval_results = {} 214 | eval_results = compute_coco_metrics( 215 | pred_captions_coco, gt_captions_coco, add_spice=add_spice 216 | ) 217 | 218 | to_save = pred_captions 219 | if milan: 220 | new_pred_captions = {} 221 | for imgid, caption in pred_captions.items(): 222 | imginfo = eval_loader.dataset.idx_to_info[int(imgid)] 223 | dataset, layer, 
unit = imginfo["dataset"], imginfo["layer"], imginfo["unit"] 224 | new_pred_captions[imgid] = { 225 | "dataset": dataset, 226 | "layer": str(layer), 227 | "neuron": str(unit), 228 | "caption": caption, 229 | } 230 | to_save = new_pred_captions 231 | 232 | with open(save_dir + "val-captions.json", "w") as f: 233 | json.dump(to_save, f) 234 | 235 | bertscore = compute_bert_score(pred_captions, gt_captions, milan=milan) 236 | print(f"BERTScore: {bertscore}") 237 | 238 | eval_results["BERTScore"] = bertscore 239 | json.dump(eval_results, open(save_dir + "val-metrics-overall.json", "w")) 240 | model.train() 241 | 242 | return avg_loss.get(), eval_results, plot_data 243 | 244 | 245 | @torch.no_grad() 246 | def eval_qualitative( 247 | model, eval_loader, save_dir, gpu_id=0, loc_ids=None, pool_locs=None, kernel_size=3 248 | ): 249 | model.eval() 250 | os.makedirs(save_dir, exist_ok=True) 251 | convert_to_pil = ToPILImage() 252 | grid_sizes = model.grid_sizes[::-1] 253 | 254 | stride = 1 255 | if loc_ids is None: 256 | layers = model.vision_feat_layers 257 | loc_ids = {} 258 | for i in range(len(layers)): 259 | l = layers[i] 260 | if i > 0 and pool_locs == "reduce_dims": 261 | loc_ids[l] = loc_ids[layers[0]].copy() 262 | stride = int(gs[l] / gs[-1]) 263 | else: 264 | gs = grid_sizes[l] 265 | loc_ids[l] = grid(gs, 2) 266 | else: 267 | layers = list(loc_ids.keys()) 268 | for i in range(len(layers)): 269 | l = layers[i] 270 | if l not in model.vision_feat_layers: 271 | raise ValueError( 272 | f"Layer {l} not in layers that model trained with ({model.vision_feat_layers})" 273 | ) 274 | 275 | if len(loc_ids[l]) > 0: 276 | if loc_ids[l][0] == (-1, -1): 277 | if i > 0 and pool_locs == "reduce_dims": 278 | loc_ids[l] = loc_ids[layers[0]].copy() 279 | stride = int(gs[l] / gs[-1]) 280 | else: 281 | gs = grid_sizes[l] 282 | loc_ids[l] = grid(gs, 2) 283 | 284 | for l in layers: 285 | if len(loc_ids[l]) > 0: 286 | gs = grid_sizes[l] 287 | max_dim0 = max(loc_ids[l], key=lambda item: item[0]) 288 | max_dim1 = max(loc_ids[l], key=lambda item: item[1]) 289 | if max(max_dim0) >= gs: 290 | raise ValueError( 291 | f"Location {max_dim0} is out of bounds for grid size {gs}" 292 | ) 293 | if max(max_dim1) >= gs: 294 | raise ValueError( 295 | f"Location {max_dim1} is out of bounds for grid size {gs}" 296 | ) 297 | 298 | for step, batch in tqdm(enumerate(eval_loader)): 299 | pixel_values = batch["pixel_values"].to(gpu_id) 300 | bs = pixel_values.shape[0] 301 | 302 | # TODO: adapt for MILAN 303 | 304 | # get full image caption with features from all layers 305 | print("Generating full image description with features from all layers") 306 | gen = model.generate(pixel_values=pixel_values) 307 | all_layers_full_gen = model.tokenizer.batch_decode( 308 | gen, skip_special_tokens=True 309 | ) 310 | 311 | # model was trained on features from more than one layer 312 | # get full image caption with layer-wise features 313 | layer_full_gens = {} 314 | for layer in layers: 315 | print(f"Generating full image description with features from layer {layer}") 316 | gen = model.generate(pixel_values=pixel_values, layer=layer) 317 | layer_full_gens[layer] = model.tokenizer.batch_decode( 318 | gen, skip_special_tokens=True 319 | ) 320 | 321 | # spatial captions 322 | layer_partial_gens = {} 323 | for l in layers: 324 | partial_gens = {} 325 | print( 326 | f"Generating {len(loc_ids[l])} partial image descriptions with features from layer {l}" 327 | ) 328 | for t in loc_ids[l]: 329 | gen = model.generate( 330 | pixel_values=pixel_values, 331 | 
layer=l, 332 | feat_index=t, 333 | pool_locs=pool_locs, 334 | kernel_size=kernel_size, 335 | stride=stride, 336 | ) 337 | gen_str = model.tokenizer.batch_decode(gen, skip_special_tokens=True) 338 | partial_gens[t] = gen_str 339 | layer_partial_gens[l] = partial_gens 340 | 341 | # save results for each image in batch 342 | for bid in range(bs): 343 | img_id = batch["id"][bid] 344 | single_pixel_values = pixel_values[bid] 345 | fi_cpt_all_feat = all_layers_full_gen[bid].replace("\n", " ") 346 | if "raw" in batch: 347 | captions = batch["raw"][bid] 348 | 349 | # just resize and crop (do not normalize) to save image 350 | img_to_save = model.unnormalize(single_pixel_values) 351 | img_to_save = convert_to_pil(img_to_save) 352 | img_to_save.save(os.path.join(save_dir, str(img_id) + ".png")) 353 | 354 | text_file = open(os.path.join(save_dir, str(img_id) + ".txt"), "w") 355 | 356 | if "raw" in batch: 357 | print("GT Captions:\n", file=text_file) 358 | for c in captions: 359 | print(f"\t{c}\n", file=text_file) 360 | print("\n", file=text_file) 361 | 362 | # save full image caption with all features 363 | print( 364 | f"Full gen caption (with all features): {fi_cpt_all_feat}\n", 365 | file=text_file, 366 | ) 367 | 368 | # save full image caption with layer-wise features 369 | for k, v in layer_full_gens.items(): 370 | v_bid = v[bid].replace("\n", "") 371 | print(f"Full gen caption (layer {k}): {v_bid}\n", file=text_file) 372 | 373 | # save partial captions per layer 374 | print("Partial captions:\n", file=text_file) 375 | for k, v in layer_partial_gens.items(): 376 | if len(loc_ids[k]) > 0: 377 | print(f"\tLayer {k}:\n", file=text_file) 378 | for t in loc_ids[k]: 379 | pt_cpt = v[t][bid].replace("\n", " ") 380 | print(f"\t\t({t[0]}, {t[1]}): {pt_cpt}\n", file=text_file) 381 | 382 | text_file.close() 383 | 384 | if (step + 1) * bs > 100: 385 | return 386 | 387 | 388 | def prepare_input_ids(texts, tokenizer, add_prompt=False): 389 | tokens = tokenizer( 390 | texts, 391 | padding="longest", 392 | truncation=True, 393 | max_length=100, 394 | return_tensors="pt", 395 | ) 396 | 397 | input_ids = tokens.input_ids 398 | att_masks = tokens.attention_mask 399 | 400 | if add_prompt: 401 | prompt_id = tokenizer("the").input_ids 402 | assert len(prompt_id) == 2 403 | prompt_id = prompt_id[1] 404 | 405 | input_ids = torch.cat([input_ids[:, :1], input_ids], dim=1) 406 | input_ids[:, 1] = prompt_id 407 | att_masks = torch.cat([att_masks[:, :1], att_masks], dim=1) 408 | att_masks[:, 1] = 0 409 | 410 | gts_cap = input_ids.clone() 411 | gts_cap = torch.where( 412 | att_masks == 0, -100, gts_cap 413 | ) # ignore padding tokens in loss 414 | 415 | return input_ids, att_masks, gts_cap 416 | 417 | 418 | @torch.no_grad() 419 | def get_saliency_map( 420 | model, 421 | loc_ids, 422 | pixel_values, 423 | labels, 424 | layer_id, 425 | gpu_id=0, 426 | add_prompt=False, 427 | hres=False, 428 | pool_locs=None, 429 | kernel_size=3, 430 | stride=1, 431 | ): 432 | loss_fct = nn.CrossEntropyLoss(reduction="none") 433 | 434 | input_ids, attn_masks, gt_caps = prepare_input_ids( 435 | labels, model.tokenizer, add_prompt=add_prompt 436 | ) 437 | 438 | pixel_values = pixel_values.to(gpu_id) 439 | bs = pixel_values.shape[0] 440 | 441 | # compute image features 442 | # TODO: adapt for MILAN 443 | img_features_cnn, forced_token_mask = model.forward_vision_model( 444 | pixel_values, just_cnn_features=True 445 | ) 446 | 447 | dim_lowest_layer = img_features_cnn[-1].shape[ 448 | 2: 449 | ] # get shape of lowest layer (usually the third) 450 | 
if layer_id is None: 451 | # interpolate higher layers to dim of lowest layer 452 | for i in range(len(img_features_cnn[:-1])): 453 | img_features_cnn[i] = F.interpolate( 454 | img_features_cnn[i], tuple(dim_lowest_layer), mode="bilinear" 455 | ) 456 | else: 457 | img_features_layer = img_features_cnn[layer_id] 458 | if pool_locs is not None: 459 | img_features_layer = F.avg_pool2d( 460 | img_features_layer, 461 | kernel_size=kernel_size, 462 | padding=int((kernel_size - 1) / 2), 463 | stride=stride, 464 | count_include_pad=False, 465 | ) 466 | elif hres: 467 | img_features_layer = F.interpolate( 468 | img_features_layer, tuple(dim_lowest_layer), mode="bilinear" 469 | ) 470 | 471 | loss_per_loc = [] 472 | logits_per_loc = [] 473 | for loc in tqdm(loc_ids, desc="Location"): 474 | if layer_id is None: 475 | img_features = [] 476 | for i in range(len(img_features_cnn)): 477 | feats = img_features_cnn[i][:, :, loc[0], loc[1]] 478 | feats = feats[:, :, None, None] 479 | img_features.append(feats.to(gpu_id)) 480 | token_mask = None 481 | else: 482 | feats = img_features_layer[:, :, loc[0], loc[1]] 483 | feats = feats[:, :, None, None] 484 | img_features = [ 485 | torch.zeros((bs, f.shape[1], 1, 1)).to(gpu_id) for f in img_features_cnn 486 | ] 487 | img_features[layer_id] = feats 488 | 489 | token_mask = torch.ones((bs, model.feat_seq_len)).to( 490 | gpu_id 491 | ) # mask all layer features 492 | token_mask[:, layer_id] = torch.zeros((bs)).to( 493 | gpu_id 494 | ) # unmask corresponding layer features 495 | 496 | token_mask = torch.cat( 497 | (token_mask, torch.zeros((bs, model.nr_learnable_tokens)).to(gpu_id)), 498 | dim=1, 499 | ) 500 | token_mask = token_mask.bool() 501 | 502 | # batch-wise feature encoding 503 | encoded_features = model.translation( 504 | img_features, token_mask=token_mask, forced_token_mask=forced_token_mask 505 | ) 506 | encoded_features = encoded_features[:, -model.nr_learnable_tokens :, :] 507 | 508 | input_ids = input_ids.to(gpu_id) 509 | inputs_embeds = model.embedding_weights(input_ids) 510 | inputs_embeds = torch.cat( 511 | (inputs_embeds[:, :1, :], encoded_features, inputs_embeds[:, 1:, :]), dim=1 512 | ) 513 | 514 | attention_mask = torch.cat( 515 | ( 516 | attn_masks[:, :1], 517 | torch.ones( 518 | (encoded_features.shape[0], model.nr_learnable_tokens), 519 | dtype=torch.long, 520 | ), 521 | attn_masks[:, 1:], 522 | ), 523 | dim=1, 524 | ) 525 | 526 | gt_captions = torch.cat( 527 | ( 528 | gt_caps[:, :1], 529 | torch.ones( 530 | (encoded_features.shape[0], model.nr_learnable_tokens), 531 | dtype=torch.long, 532 | ) 533 | * (-100), 534 | gt_caps[:, 1:], 535 | ), 536 | dim=1, 537 | ) 538 | 539 | # loss calculation begins 540 | out = model.lm_model( 541 | inputs_embeds=inputs_embeds, attention_mask=attention_mask.to(gpu_id) 542 | ) 543 | logits = out.logits[:, model.nr_learnable_tokens + 1 :, :] 544 | input_ids_wo_bos = input_ids[:, 1:].unsqueeze(2) 545 | logits = torch.gather(logits, 2, input_ids_wo_bos).squeeze(2) 546 | 547 | # shift so that tokens < n predict n 548 | shift_logits = out.logits[..., :-1, :].permute(0, 2, 1).contiguous() 549 | shift_labels = gt_captions[..., 1:].contiguous().to(gpu_id) 550 | # flatten the tokens 551 | loss = loss_fct(shift_logits, shift_labels) 552 | loss_per_loc.append(loss.sum(1)) 553 | logits_per_loc.append(logits.sum(1)) 554 | 555 | loss_per_loc = torch.stack(loss_per_loc, dim=1) 556 | logits_per_loc = torch.stack(logits_per_loc, dim=1) 557 | return loss_per_loc, logits_per_loc 558 | 
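The `get_saliency_map` helper above returns, for every probed feature location, the language-modelling loss (and gathered logits) of a caption conditioned on that location alone. The full open-vocabulary saliency pipeline lives in `src/viz_notebook.ipynb`; the sketch below only illustrates one plausible way to reshape those per-location scores into an image-sized heatmap. The helper name `saliency_heatmap`, the `gs` and `image_size` arguments, the negated-loss scoring, and the min-max normalisation are assumptions for illustration, not code from this repository.

```python
import torch
import torch.nn.functional as F

from evaluation import get_saliency_map, grid


@torch.no_grad()
def saliency_heatmap(model, pixel_values, text, layer_id, gs, image_size=224):
    # Probe every (row, col) cell of a gs x gs feature grid with the same text.
    # gs is assumed to match the grid size of the features being probed, and
    # layer_id indexes the list of backbone feature maps (as in get_saliency_map).
    loc_ids = grid(gs, 2)
    bs = pixel_values.shape[0]
    loss_per_loc, _ = get_saliency_map(
        model, loc_ids, pixel_values, [text] * bs, layer_id
    )  # shape: (bs, gs * gs)

    # Lower loss -> the text fits this location better -> higher saliency.
    sal = (-loss_per_loc).reshape(bs, 1, gs, gs)
    sal_min = sal.amin(dim=(-1, -2), keepdim=True)
    sal_max = sal.amax(dim=(-1, -2), keepdim=True)
    sal = (sal - sal_min) / (sal_max - sal_min + 1e-8)

    # Upsample the normalised grid to image resolution for overlaying on the input.
    return F.interpolate(
        sal, size=(image_size, image_size), mode="bilinear", align_corners=False
    )
```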
-------------------------------------------------------------------------------- /src/model.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | 3 | import torch 4 | import torch.nn.functional as F 5 | from torch import nn 6 | from transformers import LogitsProcessorList, PretrainedConfig, ViTPreTrainedModel 7 | from transformers.utils import logging 8 | 9 | from language_model import get_language_model 10 | from vision_models import get_vision_model 11 | 12 | logger = logging.get_logger(__name__) 13 | 14 | 15 | class TranslationTransformerConfig(PretrainedConfig): 16 | model_type = "translationtransformer" 17 | 18 | def __init__( 19 | self, 20 | hidden_size=768, 21 | num_hidden_layers=12, 22 | num_attention_heads=12, 23 | intermediate_size=3072, 24 | hidden_act="gelu", 25 | initializer_range=0.02, 26 | dropout_prob=0.0, 27 | layer_norm_eps=1e-12, 28 | nr_learnable_tokens=10, 29 | feat_seq_len=49, 30 | token_dropout=0.5, 31 | add_vit_embed_token=False, 32 | **kwargs, 33 | ): 34 | super().__init__(**kwargs) 35 | 36 | self.hidden_size = hidden_size 37 | self.num_hidden_layers = num_hidden_layers 38 | self.num_attention_heads = num_attention_heads 39 | self.intermediate_size = intermediate_size 40 | self.hidden_act = hidden_act 41 | self.initializer_range = initializer_range 42 | self.dropout_prob = dropout_prob 43 | self.layer_norm_eps = layer_norm_eps 44 | 45 | self.nr_learnable_tokens = nr_learnable_tokens 46 | self.feat_seq_len = feat_seq_len 47 | self.token_dropout = token_dropout 48 | self.add_vit_embed_token = add_vit_embed_token 49 | 50 | 51 | class TranslationTransformer(ViTPreTrainedModel): 52 | config_class = TranslationTransformerConfig 53 | 54 | def __init__(self, config): 55 | super().__init__(config) 56 | 57 | self.add_vit_embed_token = config.add_vit_embed_token 58 | self.learnable_inputs = nn.Parameter( 59 | torch.randn((1, config.nr_learnable_tokens, config.hidden_size)) * 0.1 60 | ) 61 | self.position_embeddings = nn.Parameter( 62 | torch.randn(1, config.feat_seq_len, config.hidden_size) * 0.1 63 | ) 64 | 65 | bernoulli_prob = torch.cat( 66 | [ 67 | torch.ones((1, config.feat_seq_len)) * config.token_dropout, 68 | torch.zeros(1, config.nr_learnable_tokens), 69 | ], 70 | dim=1, 71 | ) 72 | self.register_buffer("bernoulli_prob", bernoulli_prob) 73 | 74 | encoder_layer = nn.TransformerEncoderLayer( 75 | d_model=config.hidden_size, 76 | nhead=config.num_attention_heads, 77 | dim_feedforward=config.intermediate_size, 78 | dropout=config.dropout_prob, 79 | activation=config.hidden_act, 80 | layer_norm_eps=config.layer_norm_eps, 81 | batch_first=True, 82 | norm_first=True, 83 | ) 84 | 85 | self.encoder = nn.TransformerEncoder( 86 | encoder_layer, num_layers=config.num_hidden_layers 87 | ) 88 | self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) 89 | 90 | self.projections = nn.ModuleList() 91 | for i in range(len(config.img_embed_dims)): 92 | self.projections.append( 93 | nn.Linear(config.img_embed_dims[i], config.hidden_size) 94 | ) 95 | self.out_proj = nn.Linear(config.hidden_size, config.lm_embed_dim) 96 | 97 | # Initialize weights and apply final processing 98 | self.post_init() 99 | 100 | def _apply_token_mask_mode(self, token_mask, mode): 101 | n_tokens = len(self.projections) 102 | # mode == 0: mask all vit embed tokens, only allow spatial embeds 103 | m = mode == 0 104 | token_mask[m, n_tokens // 2 : n_tokens] = 1.0 105 | # mode == 1: mask all spatial embeds, only allow vit embed 
tokens 106 | m = mode == 1 107 | token_mask[m, : n_tokens // 2] = 1.0 108 | # mode == 2: mask per layer, keep masking of vit embed token and spatial same for each layer 109 | m = mode == 2 110 | token_mask[m, :n_tokens:2] = token_mask[m, 1:n_tokens:2] 111 | 112 | return token_mask 113 | 114 | def forward( 115 | self, 116 | img_feats, 117 | token_mask=None, 118 | forced_token_mask=None, 119 | ): 120 | if hasattr(self, "projections"): 121 | for i in range(len(img_feats)): 122 | feats = img_feats[i] 123 | dims = feats.shape 124 | feats = torch.reshape(feats, (dims[0], dims[1], dims[2] * dims[3])) 125 | feats = torch.permute(feats, (0, 2, 1)) 126 | img_feats[i] = self.projections[i](feats) 127 | 128 | img_feats = torch.cat(img_feats, dim=1) 129 | 130 | img_feats = img_feats + self.position_embeddings 131 | 132 | bs = img_feats.shape[0] 133 | expanded_learnable_inputs = self.learnable_inputs.expand(bs, -1, -1) 134 | inputs = torch.cat((img_feats, expanded_learnable_inputs), dim=1) 135 | 136 | assert token_mask is None or forced_token_mask is None 137 | if forced_token_mask is not None and ( 138 | (self.training and token_mask is None and self.bernoulli_prob.sum() == 0.0) 139 | or (not self.training) 140 | ): 141 | assert not forced_token_mask.all(dim=1).any() 142 | token_mask = torch.cat( 143 | [ 144 | forced_token_mask, 145 | torch.zeros( 146 | (forced_token_mask.shape[0], self.config.nr_learnable_tokens), 147 | device=img_feats.device, 148 | dtype=torch.bool, 149 | ), 150 | ], 151 | dim=1, 152 | ) 153 | elif self.training and token_mask is None and self.bernoulli_prob.sum() > 0.0: 154 | # sample token mask 155 | bernoulli_prob = self.bernoulli_prob.expand(bs, -1) 156 | token_mask = torch.bernoulli(bernoulli_prob) 157 | if self.add_vit_embed_token: 158 | mask_mode = torch.randint(3, size=(bs,)) 159 | token_mask = self._apply_token_mask_mode(token_mask, mask_mode) 160 | 161 | if forced_token_mask is not None: 162 | token_mask[:, : forced_token_mask.shape[1]] += forced_token_mask.float() 163 | token_mask = torch.clamp(token_mask, max=1.0) 164 | 165 | # make sure at least one token is not masked 166 | all_masked = token_mask.sum(dim=1) == img_feats.shape[1] 167 | while all_masked.sum() > 0: 168 | all_masked = all_masked.unsqueeze(1) 169 | new_token_mask = torch.bernoulli(bernoulli_prob) 170 | if self.add_vit_embed_token: 171 | token_mask = self._apply_token_mask_mode(token_mask, mask_mode) 172 | token_mask = ~all_masked * token_mask + all_masked * new_token_mask 173 | if forced_token_mask is not None: 174 | token_mask[ 175 | :, : forced_token_mask.shape[1] 176 | ] += forced_token_mask.float() 177 | token_mask = torch.clamp(token_mask, max=1.0) 178 | all_masked = token_mask.sum(dim=1) == img_feats.shape[1] 179 | 180 | token_mask = token_mask.bool() 181 | 182 | sequence_output = self.encoder(inputs, src_key_padding_mask=token_mask) 183 | 184 | sequence_output = self.layernorm(sequence_output) 185 | sequence_output = self.out_proj(sequence_output) 186 | 187 | return sequence_output 188 | 189 | 190 | class LMBase(nn.Module): 191 | def __init__(self, **kwargs): 192 | super().__init__(**kwargs) 193 | 194 | def greedy( 195 | self, 196 | input_ids: torch.LongTensor, 197 | visual_features: torch.FloatTensor, 198 | logits_processor: Optional[LogitsProcessorList] = None, 199 | max_length: Optional[int] = None, 200 | pad_token_id: Optional[int] = None, 201 | eos_token_id: Optional[int] = None, 202 | **model_kwargs, 203 | ): 204 | # init values 205 | logits_processor = ( 206 | logits_processor if 
logits_processor is not None else LogitsProcessorList() 207 | ) 208 | max_length = ( 209 | max_length if max_length is not None else self.lm_model.config.max_length 210 | ) 211 | pad_token_id = ( 212 | pad_token_id 213 | if pad_token_id is not None 214 | else self.lm_model.config.pad_token_id 215 | ) 216 | eos_token_id = ( 217 | eos_token_id 218 | if eos_token_id is not None 219 | else self.lm_model.config.eos_token_id 220 | ) 221 | 222 | # keep track of which sequences are already finished 223 | unfinished_sequences = input_ids.new(input_ids.shape[0]).fill_(1) 224 | 225 | while True: 226 | # prepare model inputs 227 | model_inputs = self.lm_model.prepare_inputs_for_generation( 228 | input_ids, **model_kwargs 229 | ) 230 | if "attention_mask" in model_inputs: 231 | if model_inputs["attention_mask"] is not None: 232 | model_inputs["attention_mask"] = torch.ones( 233 | ( 234 | input_ids.shape[0], 235 | input_ids.shape[1] + self.nr_learnable_tokens, 236 | ), 237 | device=input_ids.device, 238 | dtype=torch.long, 239 | ) 240 | 241 | if visual_features is not None: 242 | inputs_embeds = self.embedding_weights(model_inputs["input_ids"]) 243 | inputs_embeds = torch.cat( 244 | (inputs_embeds[:, :1, :], visual_features, inputs_embeds[:, 1:, :]), 245 | dim=1, 246 | ) 247 | 248 | model_inputs.pop("input_ids") 249 | model_inputs["inputs_embeds"] = inputs_embeds 250 | 251 | # forward pass to get next token 252 | outputs = self.lm_model( 253 | **model_inputs, 254 | return_dict=True, 255 | output_hidden_states=True, 256 | ) 257 | 258 | next_token_logits = outputs.logits[:, -1, :] 259 | 260 | # pre-process distribution 261 | next_tokens_scores = logits_processor(input_ids, next_token_logits) 262 | 263 | next_tokens = torch.argmax(next_tokens_scores, dim=-1) 264 | 265 | # finished sentences should have their next token be a padding token 266 | if eos_token_id is not None: 267 | if pad_token_id is None: 268 | raise ValueError( 269 | "If `eos_token_id` is defined, make sure that `pad_token_id` is defined." 
270 | ) 271 | 272 | next_tokens = next_tokens * unfinished_sequences + pad_token_id * ( 273 | 1 - unfinished_sequences 274 | ) 275 | 276 | # update generated ids, model inputs, and length for next step 277 | input_ids = torch.cat([input_ids, next_tokens.detach()[:, None]], dim=1) 278 | if "past_key_values" in outputs: 279 | model_kwargs["past_key_values"] = outputs.past_key_values 280 | visual_features = None # are only needed the first time; afterwards we use past attentions and just the latest input id 281 | 282 | if "token_type_ids" in model_kwargs: 283 | token_type_ids = model_kwargs["token_type_ids"] 284 | model_kwargs["token_type_ids"] = torch.cat( 285 | [ 286 | token_type_ids, 287 | torch.ones( 288 | (token_type_ids.shape[0], 1), 289 | dtype=torch.long, 290 | device=token_type_ids.device, 291 | ), 292 | ], 293 | dim=-1, 294 | ) 295 | 296 | if "attention_mask" in model_kwargs: 297 | attention_mask = model_kwargs["attention_mask"] 298 | model_kwargs["attention_mask"] = torch.cat( 299 | [ 300 | attention_mask, 301 | attention_mask.new_ones((attention_mask.shape[0], 1)), 302 | ], 303 | dim=-1, 304 | ) 305 | 306 | # if eos_token was found in one sentence, set sentence to finished 307 | if eos_token_id is not None: 308 | unfinished_sequences = unfinished_sequences.mul( 309 | (next_tokens != eos_token_id).long() 310 | ) 311 | 312 | # stop when each sentence is finished, or if we exceed the maximum length 313 | if unfinished_sequences.max() == 0 or input_ids.shape[-1] >= max_length: 314 | break 315 | 316 | return input_ids 317 | 318 | 319 | class Features2WordsModel(LMBase): 320 | def __init__( 321 | self, 322 | translation_config, 323 | cnn_name="resnet50", 324 | vision_feat_func="none", 325 | vision_feat_layers=[-1], 326 | lm_model_name="gpt2", 327 | max_length=None, 328 | feature_dropout=0.0, 329 | only_vit_embed_token=False, 330 | **kwargs, 331 | ): 332 | super().__init__(**kwargs) 333 | self.vision_feat_func = vision_feat_func 334 | self.vision_feat_layers = vision_feat_layers 335 | self.only_vit_embed_token = only_vit_embed_token 336 | 337 | self.register_buffer( 338 | "feature_dropout", torch.tensor([feature_dropout]), persistent=False 339 | ) 340 | 341 | # VISION MODEL 342 | ( 343 | self.cnn, 344 | img_embed_dims, 345 | self.grid_sizes, 346 | self.transform, 347 | self.unnormalize, 348 | ) = get_vision_model( 349 | cnn_name, 350 | vision_feat_layers, 351 | add_vit_embed_token=translation_config.add_vit_embed_token, 352 | only_vit_embed_token=self.only_vit_embed_token, 353 | ) 354 | if self.vision_feat_func == "avg_pooling": 355 | self.feat_seq_len = len(self.grid_sizes) 356 | else: 357 | self.feat_seq_len = sum([gs**2 for gs in self.grid_size]) 358 | raise NotImplementedError 359 | 360 | # freeze cnn 361 | for p in self.cnn.parameters(): 362 | p.requires_grad = False 363 | 364 | # LANGUAGE MODEL 365 | ( 366 | self.lm_model, 367 | self.tokenizer, 368 | self.embedding_weights, 369 | lm_embed_dim, 370 | ) = get_language_model(lm_model_name) 371 | 372 | self.lm_model.config.pad_token_id = self.lm_model.config.eos_token_id 373 | if max_length is not None: 374 | self.lm_model.config.max_length = max_length 375 | 376 | # freeze language model 377 | for p in self.lm_model.parameters(): 378 | p.requires_grad = False 379 | 380 | # TRANSLATION MODEL 381 | self.nr_learnable_tokens = translation_config.nr_learnable_tokens 382 | translation_config.feat_seq_len = self.feat_seq_len 383 | translation_config.img_embed_dims = img_embed_dims 384 | translation_config.lm_embed_dim = lm_embed_dim 385 | 
self.translation = TranslationTransformer(translation_config) 386 | 387 | self.train() 388 | 389 | def train(self, mode=True): 390 | super().train(mode) 391 | if mode: 392 | self.cnn.eval() 393 | 394 | def forward_vision_model( 395 | self, 396 | pixel_values: torch.FloatTensor, 397 | mask: torch.LongTensor = None, 398 | just_cnn_features: bool = False, 399 | ): 400 | dims = None 401 | if pixel_values.dim() > 4: # MILAN dataset by neuron (bs, 15, C, H, W) 402 | dims = pixel_values.shape 403 | pixel_values = pixel_values.view( 404 | dims[0] * dims[1], dims[2], dims[3], dims[4] 405 | ) 406 | 407 | if mask is not None: 408 | mdims = mask.shape 409 | mask = mask.view(mdims[0] * mdims[1], mdims[2], mdims[3], mdims[4]) 410 | 411 | img_features = self.cnn(pixel_values) 412 | 413 | forced_token_mask = [] 414 | if just_cnn_features == False: 415 | if self.vision_feat_func == "avg_pooling": 416 | if ( 417 | self.feature_dropout.item() > 0.0 and self.training 418 | ) or mask is not None: 419 | pooled_img_features = [] 420 | for imgf in img_features: 421 | imgf_b4_flatten = imgf 422 | imgf = imgf.flatten(-2) 423 | if imgf.shape[-1] == 1 and mask is None: 424 | # don't apply feature dropout on vit embed token 425 | pooled_img_features.append(imgf.unsqueeze(-1)) 426 | continue 427 | # 1 means keep, 0 means drop 428 | if mask is not None: 429 | # 1 means keep, 0 means drop 430 | dropout_mask = F.interpolate( 431 | mask, 432 | imgf_b4_flatten.shape[-2:], 433 | mode="bilinear", 434 | align_corners=False, 435 | ) 436 | # Normalize the masks so they look more like attention. If any 437 | # of them are all zeros, we'll end up with divide-by-zero errors. 438 | zeros = torch.zeros_like(dropout_mask) 439 | valid = ( 440 | ~dropout_mask.isclose(zeros) 441 | .all(dim=-1) 442 | .all(dim=-1) 443 | .view(-1) 444 | ) 445 | indices = valid.nonzero().squeeze() 446 | dropout_mask[indices] /= dropout_mask[indices].sum( 447 | dim=(-1, -2), keepdim=True 448 | ) 449 | 450 | pooled_imgf = imgf_b4_flatten.mul(dropout_mask).sum( 451 | dim=(-1, -2), keepdim=True 452 | ) 453 | 454 | if dims is not None: 455 | feat_dims = pooled_imgf.shape 456 | pooled_imgf = pooled_imgf.view( 457 | dims[0], 458 | dims[1], 459 | feat_dims[1], 460 | feat_dims[2], 461 | feat_dims[3], 462 | ) 463 | mdims = dropout_mask.shape 464 | dropout_mask = dropout_mask.view( 465 | dims[0], dims[1], mdims[1], mdims[2], mdims[3] 466 | ) 467 | dropout_mask = dropout_mask.sum(dim=1).sum( 468 | dim=(-1, -2), keepdim=True 469 | ) 470 | indices = dropout_mask.view(dims[0], -1).sum(1) != 0.0 471 | dropout_mask[indices] = 1 / dropout_mask[indices] 472 | pooled_imgf = pooled_imgf.sum(dim=1) * dropout_mask 473 | forced_token_mask.append(~indices) 474 | pooled_img_features.append(pooled_imgf) 475 | else: 476 | dropout_mask = torch.bernoulli( 477 | (1.0 - self.feature_dropout).expand( 478 | imgf.shape[0], imgf.shape[2] 479 | ) 480 | ) 481 | while (dropout_mask.sum(1) == 0.0).any(): 482 | dropout_mask = torch.bernoulli( 483 | (1.0 - self.feature_dropout).expand( 484 | imgf.shape[0], imgf.shape[2] 485 | ) 486 | ) 487 | pooled_imgf = [] 488 | for img, dm in zip(imgf, dropout_mask): 489 | idx = dm.nonzero().squeeze(1) 490 | pimg = img.index_select(1, idx).mean(dim=1) 491 | pooled_imgf.append(pimg[..., None, None]) 492 | pooled_img_features.append(torch.stack(pooled_imgf)) 493 | img_features = pooled_img_features 494 | else: 495 | img_features = [ 496 | F.adaptive_avg_pool2d(feats, (1, 1)) for feats in img_features 497 | ] 498 | else: 499 | if dims is not None: # MILAN by_unit - 
average over the 15 image features 500 | new_dims = img_features.shape 501 | img_features = img_features.view( 502 | dims[0], dims[1], new_dims[-2], new_dims[-1] 503 | ) 504 | img_features = img_features.mean(dim=1) 505 | 506 | if len(forced_token_mask) > 0: 507 | forced_token_mask = torch.stack(forced_token_mask, dim=1) 508 | else: 509 | forced_token_mask = None 510 | 511 | return img_features, forced_token_mask 512 | 513 | def forward( 514 | self, 515 | pixel_values: torch.FloatTensor, 516 | input_ids: torch.LongTensor, 517 | attention_mask: torch.LongTensor, 518 | gt_captions: torch.LongTensor, 519 | mask: torch.LongTensor = None, 520 | ): 521 | img_features, forced_token_mask = self.forward_vision_model( 522 | pixel_values, mask=mask, just_cnn_features=False 523 | ) 524 | encoded_features = self.translation( 525 | img_features, forced_token_mask=forced_token_mask 526 | ) 527 | encoded_features = encoded_features[:, -self.nr_learnable_tokens :, :] 528 | 529 | inputs_embeds = self.embedding_weights(input_ids) 530 | 531 | inputs_embeds = torch.cat( 532 | (inputs_embeds[:, :1, :], encoded_features, inputs_embeds[:, 1:, :]), dim=1 533 | ) 534 | 535 | out = self.lm_model( 536 | inputs_embeds=inputs_embeds, 537 | attention_mask=attention_mask, 538 | labels=gt_captions, 539 | ) 540 | return out["loss"] 541 | 542 | @torch.no_grad() 543 | def generate( 544 | self, 545 | pixel_values=None, 546 | img_features=None, 547 | mask=None, 548 | forced_token_mask=None, 549 | max_length=None, 550 | layer=None, 551 | feat_index=None, 552 | pool_locs=None, 553 | kernel_size=3, 554 | stride=1, 555 | ): 556 | if pixel_values is None and img_features is None: 557 | raise ValueError("Please provide either pixel_values or img_features.") 558 | if self.vision_feat_func != "avg_pooling" and feat_index is not None: 559 | raise ValueError( 560 | "Index selection is only possible when vision_feat_func is avg_pooling" 561 | ) 562 | if layer is None and feat_index is not None: 563 | raise ValueError( 564 | "Index selection is only possible when a layer is specified" 565 | ) 566 | if mask is not None and feat_index is not None: 567 | raise NotImplementedError 568 | if pool_locs is not None and feat_index is None: 569 | raise ValueError( 570 | "pool_locs is only valid for a specific location, so feat_index cannot be None" 571 | ) 572 | 573 | if img_features is None: 574 | just_cnn_features = False 575 | if feat_index is not None: 576 | just_cnn_features = True 577 | img_features, forced_token_mask = self.forward_vision_model( 578 | pixel_values, just_cnn_features=just_cnn_features, mask=mask 579 | ) 580 | 581 | if isinstance(img_features, list): 582 | bs = img_features[0].shape[0] 583 | else: 584 | bs = img_features.shape[0] 585 | device = img_features[0].device 586 | 587 | if layer is not None: 588 | layer_id = -layer - 1 589 | 590 | if feat_index is not None: 591 | feats = img_features[layer_id] 592 | if pool_locs is not None: 593 | feats = F.avg_pool2d( 594 | feats, 595 | kernel_size=kernel_size, 596 | padding=int((kernel_size - 1) / 2), 597 | stride=stride, 598 | count_include_pad=False, 599 | ) 600 | feats = feats[:, :, feat_index[0], feat_index[1]] 601 | feats = feats[:, :, None, None] 602 | img_features = [ 603 | torch.zeros((bs, f.shape[1], 1, 1), device=device) for f in img_features 604 | ] 605 | img_features[layer_id] = feats 606 | 607 | token_mask = None 608 | if layer is not None: 609 | token_mask = torch.ones( 610 | (bs, self.feat_seq_len), device=device 611 | ) # mask all layer features 612 | token_mask[:, 
layer_id] = torch.zeros( 613 | (bs), device=device 614 | ) # unmask corresponding layer features 615 | 616 | token_mask = torch.cat( 617 | ( 618 | token_mask, 619 | torch.zeros((bs, self.nr_learnable_tokens), device=device), 620 | ), 621 | dim=1, 622 | ) 623 | token_mask = token_mask.bool() 624 | 625 | encoded_features = self.translation( 626 | img_features, token_mask=token_mask, forced_token_mask=forced_token_mask 627 | ) 628 | encoded_features = encoded_features[:, -self.nr_learnable_tokens :, :] 629 | 630 | input_ids = ( 631 | torch.ones((bs, 1), dtype=torch.long, device=device) 632 | * self.lm_model.config.bos_token_id 633 | ) 634 | 635 | generated = self.greedy( 636 | input_ids, visual_features=encoded_features, max_length=max_length 637 | ) 638 | return generated 639 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | GNU GENERAL PUBLIC LICENSE 2 | Version 3, 29 June 2007 3 | 4 | Copyright (C) 2007 Free Software Foundation, Inc. 5 | Everyone is permitted to copy and distribute verbatim copies 6 | of this license document, but changing it is not allowed. 7 | 8 | Preamble 9 | 10 | The GNU General Public License is a free, copyleft license for 11 | software and other kinds of works. 12 | 13 | The licenses for most software and other practical works are designed 14 | to take away your freedom to share and change the works. By contrast, 15 | the GNU General Public License is intended to guarantee your freedom to 16 | share and change all versions of a program--to make sure it remains free 17 | software for all its users. We, the Free Software Foundation, use the 18 | GNU General Public License for most of our software; it applies also to 19 | any other work released this way by its authors. You can apply it to 20 | your programs, too. 21 | 22 | When we speak of free software, we are referring to freedom, not 23 | price. Our General Public Licenses are designed to make sure that you 24 | have the freedom to distribute copies of free software (and charge for 25 | them if you wish), that you receive source code or can get it if you 26 | want it, that you can change the software or use pieces of it in new 27 | free programs, and that you know you can do these things. 28 | 29 | To protect your rights, we need to prevent others from denying you 30 | these rights or asking you to surrender the rights. Therefore, you have 31 | certain responsibilities if you distribute copies of the software, or if 32 | you modify it: responsibilities to respect the freedom of others. 33 | 34 | For example, if you distribute copies of such a program, whether 35 | gratis or for a fee, you must pass on to the recipients the same 36 | freedoms that you received. You must make sure that they, too, receive 37 | or can get the source code. And you must show them these terms so they 38 | know their rights. 39 | 40 | Developers that use the GNU GPL protect your rights with two steps: 41 | (1) assert copyright on the software, and (2) offer you this License 42 | giving you legal permission to copy, distribute and/or modify it. 43 | 44 | For the developers' and authors' protection, the GPL clearly explains 45 | that there is no warranty for this free software. For both users' and 46 | authors' sake, the GPL requires that modified versions be marked as 47 | changed, so that their problems will not be attributed erroneously to 48 | authors of previous versions. 
49 | 50 | Some devices are designed to deny users access to install or run 51 | modified versions of the software inside them, although the manufacturer 52 | can do so. This is fundamentally incompatible with the aim of 53 | protecting users' freedom to change the software. The systematic 54 | pattern of such abuse occurs in the area of products for individuals to 55 | use, which is precisely where it is most unacceptable. Therefore, we 56 | have designed this version of the GPL to prohibit the practice for those 57 | products. If such problems arise substantially in other domains, we 58 | stand ready to extend this provision to those domains in future versions 59 | of the GPL, as needed to protect the freedom of users. 60 | 61 | Finally, every program is threatened constantly by software patents. 62 | States should not allow patents to restrict development and use of 63 | software on general-purpose computers, but in those that do, we wish to 64 | avoid the special danger that patents applied to a free program could 65 | make it effectively proprietary. To prevent this, the GPL assures that 66 | patents cannot be used to render the program non-free. 67 | 68 | The precise terms and conditions for copying, distribution and 69 | modification follow. 70 | 71 | TERMS AND CONDITIONS 72 | 73 | 0. Definitions. 74 | 75 | "This License" refers to version 3 of the GNU General Public License. 76 | 77 | "Copyright" also means copyright-like laws that apply to other kinds of 78 | works, such as semiconductor masks. 79 | 80 | "The Program" refers to any copyrightable work licensed under this 81 | License. Each licensee is addressed as "you". "Licensees" and 82 | "recipients" may be individuals or organizations. 83 | 84 | To "modify" a work means to copy from or adapt all or part of the work 85 | in a fashion requiring copyright permission, other than the making of an 86 | exact copy. The resulting work is called a "modified version" of the 87 | earlier work or a work "based on" the earlier work. 88 | 89 | A "covered work" means either the unmodified Program or a work based 90 | on the Program. 91 | 92 | To "propagate" a work means to do anything with it that, without 93 | permission, would make you directly or secondarily liable for 94 | infringement under applicable copyright law, except executing it on a 95 | computer or modifying a private copy. Propagation includes copying, 96 | distribution (with or without modification), making available to the 97 | public, and in some countries other activities as well. 98 | 99 | To "convey" a work means any kind of propagation that enables other 100 | parties to make or receive copies. Mere interaction with a user through 101 | a computer network, with no transfer of a copy, is not conveying. 102 | 103 | An interactive user interface displays "Appropriate Legal Notices" 104 | to the extent that it includes a convenient and prominently visible 105 | feature that (1) displays an appropriate copyright notice, and (2) 106 | tells the user that there is no warranty for the work (except to the 107 | extent that warranties are provided), that licensees may convey the 108 | work under this License, and how to view a copy of this License. If 109 | the interface presents a list of user commands or options, such as a 110 | menu, a prominent item in the list meets this criterion. 111 | 112 | 1. Source Code. 113 | 114 | The "source code" for a work means the preferred form of the work 115 | for making modifications to it. 
"Object code" means any non-source 116 | form of a work. 117 | 118 | A "Standard Interface" means an interface that either is an official 119 | standard defined by a recognized standards body, or, in the case of 120 | interfaces specified for a particular programming language, one that 121 | is widely used among developers working in that language. 122 | 123 | The "System Libraries" of an executable work include anything, other 124 | than the work as a whole, that (a) is included in the normal form of 125 | packaging a Major Component, but which is not part of that Major 126 | Component, and (b) serves only to enable use of the work with that 127 | Major Component, or to implement a Standard Interface for which an 128 | implementation is available to the public in source code form. A 129 | "Major Component", in this context, means a major essential component 130 | (kernel, window system, and so on) of the specific operating system 131 | (if any) on which the executable work runs, or a compiler used to 132 | produce the work, or an object code interpreter used to run it. 133 | 134 | The "Corresponding Source" for a work in object code form means all 135 | the source code needed to generate, install, and (for an executable 136 | work) run the object code and to modify the work, including scripts to 137 | control those activities. However, it does not include the work's 138 | System Libraries, or general-purpose tools or generally available free 139 | programs which are used unmodified in performing those activities but 140 | which are not part of the work. For example, Corresponding Source 141 | includes interface definition files associated with source files for 142 | the work, and the source code for shared libraries and dynamically 143 | linked subprograms that the work is specifically designed to require, 144 | such as by intimate data communication or control flow between those 145 | subprograms and other parts of the work. 146 | 147 | The Corresponding Source need not include anything that users 148 | can regenerate automatically from other parts of the Corresponding 149 | Source. 150 | 151 | The Corresponding Source for a work in source code form is that 152 | same work. 153 | 154 | 2. Basic Permissions. 155 | 156 | All rights granted under this License are granted for the term of 157 | copyright on the Program, and are irrevocable provided the stated 158 | conditions are met. This License explicitly affirms your unlimited 159 | permission to run the unmodified Program. The output from running a 160 | covered work is covered by this License only if the output, given its 161 | content, constitutes a covered work. This License acknowledges your 162 | rights of fair use or other equivalent, as provided by copyright law. 163 | 164 | You may make, run and propagate covered works that you do not 165 | convey, without conditions so long as your license otherwise remains 166 | in force. You may convey covered works to others for the sole purpose 167 | of having them make modifications exclusively for you, or provide you 168 | with facilities for running those works, provided that you comply with 169 | the terms of this License in conveying all material for which you do 170 | not control copyright. Those thus making or running the covered works 171 | for you must do so exclusively on your behalf, under your direction 172 | and control, on terms that prohibit them from making any copies of 173 | your copyrighted material outside their relationship with you. 
174 | 175 | Conveying under any other circumstances is permitted solely under 176 | the conditions stated below. Sublicensing is not allowed; section 10 177 | makes it unnecessary. 178 | 179 | 3. Protecting Users' Legal Rights From Anti-Circumvention Law. 180 | 181 | No covered work shall be deemed part of an effective technological 182 | measure under any applicable law fulfilling obligations under article 183 | 11 of the WIPO copyright treaty adopted on 20 December 1996, or 184 | similar laws prohibiting or restricting circumvention of such 185 | measures. 186 | 187 | When you convey a covered work, you waive any legal power to forbid 188 | circumvention of technological measures to the extent such circumvention 189 | is effected by exercising rights under this License with respect to 190 | the covered work, and you disclaim any intention to limit operation or 191 | modification of the work as a means of enforcing, against the work's 192 | users, your or third parties' legal rights to forbid circumvention of 193 | technological measures. 194 | 195 | 4. Conveying Verbatim Copies. 196 | 197 | You may convey verbatim copies of the Program's source code as you 198 | receive it, in any medium, provided that you conspicuously and 199 | appropriately publish on each copy an appropriate copyright notice; 200 | keep intact all notices stating that this License and any 201 | non-permissive terms added in accord with section 7 apply to the code; 202 | keep intact all notices of the absence of any warranty; and give all 203 | recipients a copy of this License along with the Program. 204 | 205 | You may charge any price or no price for each copy that you convey, 206 | and you may offer support or warranty protection for a fee. 207 | 208 | 5. Conveying Modified Source Versions. 209 | 210 | You may convey a work based on the Program, or the modifications to 211 | produce it from the Program, in the form of source code under the 212 | terms of section 4, provided that you also meet all of these conditions: 213 | 214 | a) The work must carry prominent notices stating that you modified 215 | it, and giving a relevant date. 216 | 217 | b) The work must carry prominent notices stating that it is 218 | released under this License and any conditions added under section 219 | 7. This requirement modifies the requirement in section 4 to 220 | "keep intact all notices". 221 | 222 | c) You must license the entire work, as a whole, under this 223 | License to anyone who comes into possession of a copy. This 224 | License will therefore apply, along with any applicable section 7 225 | additional terms, to the whole of the work, and all its parts, 226 | regardless of how they are packaged. This License gives no 227 | permission to license the work in any other way, but it does not 228 | invalidate such permission if you have separately received it. 229 | 230 | d) If the work has interactive user interfaces, each must display 231 | Appropriate Legal Notices; however, if the Program has interactive 232 | interfaces that do not display Appropriate Legal Notices, your 233 | work need not make them do so. 
234 | 235 | A compilation of a covered work with other separate and independent 236 | works, which are not by their nature extensions of the covered work, 237 | and which are not combined with it such as to form a larger program, 238 | in or on a volume of a storage or distribution medium, is called an 239 | "aggregate" if the compilation and its resulting copyright are not 240 | used to limit the access or legal rights of the compilation's users 241 | beyond what the individual works permit. Inclusion of a covered work 242 | in an aggregate does not cause this License to apply to the other 243 | parts of the aggregate. 244 | 245 | 6. Conveying Non-Source Forms. 246 | 247 | You may convey a covered work in object code form under the terms 248 | of sections 4 and 5, provided that you also convey the 249 | machine-readable Corresponding Source under the terms of this License, 250 | in one of these ways: 251 | 252 | a) Convey the object code in, or embodied in, a physical product 253 | (including a physical distribution medium), accompanied by the 254 | Corresponding Source fixed on a durable physical medium 255 | customarily used for software interchange. 256 | 257 | b) Convey the object code in, or embodied in, a physical product 258 | (including a physical distribution medium), accompanied by a 259 | written offer, valid for at least three years and valid for as 260 | long as you offer spare parts or customer support for that product 261 | model, to give anyone who possesses the object code either (1) a 262 | copy of the Corresponding Source for all the software in the 263 | product that is covered by this License, on a durable physical 264 | medium customarily used for software interchange, for a price no 265 | more than your reasonable cost of physically performing this 266 | conveying of source, or (2) access to copy the 267 | Corresponding Source from a network server at no charge. 268 | 269 | c) Convey individual copies of the object code with a copy of the 270 | written offer to provide the Corresponding Source. This 271 | alternative is allowed only occasionally and noncommercially, and 272 | only if you received the object code with such an offer, in accord 273 | with subsection 6b. 274 | 275 | d) Convey the object code by offering access from a designated 276 | place (gratis or for a charge), and offer equivalent access to the 277 | Corresponding Source in the same way through the same place at no 278 | further charge. You need not require recipients to copy the 279 | Corresponding Source along with the object code. If the place to 280 | copy the object code is a network server, the Corresponding Source 281 | may be on a different server (operated by you or a third party) 282 | that supports equivalent copying facilities, provided you maintain 283 | clear directions next to the object code saying where to find the 284 | Corresponding Source. Regardless of what server hosts the 285 | Corresponding Source, you remain obligated to ensure that it is 286 | available for as long as needed to satisfy these requirements. 287 | 288 | e) Convey the object code using peer-to-peer transmission, provided 289 | you inform other peers where the object code and Corresponding 290 | Source of the work are being offered to the general public at no 291 | charge under subsection 6d. 292 | 293 | A separable portion of the object code, whose source code is excluded 294 | from the Corresponding Source as a System Library, need not be 295 | included in conveying the object code work. 
296 | 297 | A "User Product" is either (1) a "consumer product", which means any 298 | tangible personal property which is normally used for personal, family, 299 | or household purposes, or (2) anything designed or sold for incorporation 300 | into a dwelling. In determining whether a product is a consumer product, 301 | doubtful cases shall be resolved in favor of coverage. For a particular 302 | product received by a particular user, "normally used" refers to a 303 | typical or common use of that class of product, regardless of the status 304 | of the particular user or of the way in which the particular user 305 | actually uses, or expects or is expected to use, the product. A product 306 | is a consumer product regardless of whether the product has substantial 307 | commercial, industrial or non-consumer uses, unless such uses represent 308 | the only significant mode of use of the product. 309 | 310 | "Installation Information" for a User Product means any methods, 311 | procedures, authorization keys, or other information required to install 312 | and execute modified versions of a covered work in that User Product from 313 | a modified version of its Corresponding Source. The information must 314 | suffice to ensure that the continued functioning of the modified object 315 | code is in no case prevented or interfered with solely because 316 | modification has been made. 317 | 318 | If you convey an object code work under this section in, or with, or 319 | specifically for use in, a User Product, and the conveying occurs as 320 | part of a transaction in which the right of possession and use of the 321 | User Product is transferred to the recipient in perpetuity or for a 322 | fixed term (regardless of how the transaction is characterized), the 323 | Corresponding Source conveyed under this section must be accompanied 324 | by the Installation Information. But this requirement does not apply 325 | if neither you nor any third party retains the ability to install 326 | modified object code on the User Product (for example, the work has 327 | been installed in ROM). 328 | 329 | The requirement to provide Installation Information does not include a 330 | requirement to continue to provide support service, warranty, or updates 331 | for a work that has been modified or installed by the recipient, or for 332 | the User Product in which it has been modified or installed. Access to a 333 | network may be denied when the modification itself materially and 334 | adversely affects the operation of the network or violates the rules and 335 | protocols for communication across the network. 336 | 337 | Corresponding Source conveyed, and Installation Information provided, 338 | in accord with this section must be in a format that is publicly 339 | documented (and with an implementation available to the public in 340 | source code form), and must require no special password or key for 341 | unpacking, reading or copying. 342 | 343 | 7. Additional Terms. 344 | 345 | "Additional permissions" are terms that supplement the terms of this 346 | License by making exceptions from one or more of its conditions. 347 | Additional permissions that are applicable to the entire Program shall 348 | be treated as though they were included in this License, to the extent 349 | that they are valid under applicable law. 
If additional permissions 350 | apply only to part of the Program, that part may be used separately 351 | under those permissions, but the entire Program remains governed by 352 | this License without regard to the additional permissions. 353 | 354 | When you convey a copy of a covered work, you may at your option 355 | remove any additional permissions from that copy, or from any part of 356 | it. (Additional permissions may be written to require their own 357 | removal in certain cases when you modify the work.) You may place 358 | additional permissions on material, added by you to a covered work, 359 | for which you have or can give appropriate copyright permission. 360 | 361 | Notwithstanding any other provision of this License, for material you 362 | add to a covered work, you may (if authorized by the copyright holders of 363 | that material) supplement the terms of this License with terms: 364 | 365 | a) Disclaiming warranty or limiting liability differently from the 366 | terms of sections 15 and 16 of this License; or 367 | 368 | b) Requiring preservation of specified reasonable legal notices or 369 | author attributions in that material or in the Appropriate Legal 370 | Notices displayed by works containing it; or 371 | 372 | c) Prohibiting misrepresentation of the origin of that material, or 373 | requiring that modified versions of such material be marked in 374 | reasonable ways as different from the original version; or 375 | 376 | d) Limiting the use for publicity purposes of names of licensors or 377 | authors of the material; or 378 | 379 | e) Declining to grant rights under trademark law for use of some 380 | trade names, trademarks, or service marks; or 381 | 382 | f) Requiring indemnification of licensors and authors of that 383 | material by anyone who conveys the material (or modified versions of 384 | it) with contractual assumptions of liability to the recipient, for 385 | any liability that these contractual assumptions directly impose on 386 | those licensors and authors. 387 | 388 | All other non-permissive additional terms are considered "further 389 | restrictions" within the meaning of section 10. If the Program as you 390 | received it, or any part of it, contains a notice stating that it is 391 | governed by this License along with a term that is a further 392 | restriction, you may remove that term. If a license document contains 393 | a further restriction but permits relicensing or conveying under this 394 | License, you may add to a covered work material governed by the terms 395 | of that license document, provided that the further restriction does 396 | not survive such relicensing or conveying. 397 | 398 | If you add terms to a covered work in accord with this section, you 399 | must place, in the relevant source files, a statement of the 400 | additional terms that apply to those files, or a notice indicating 401 | where to find the applicable terms. 402 | 403 | Additional terms, permissive or non-permissive, may be stated in the 404 | form of a separately written license, or stated as exceptions; 405 | the above requirements apply either way. 406 | 407 | 8. Termination. 408 | 409 | You may not propagate or modify a covered work except as expressly 410 | provided under this License. Any attempt otherwise to propagate or 411 | modify it is void, and will automatically terminate your rights under 412 | this License (including any patent licenses granted under the third 413 | paragraph of section 11). 
414 | 415 | However, if you cease all violation of this License, then your 416 | license from a particular copyright holder is reinstated (a) 417 | provisionally, unless and until the copyright holder explicitly and 418 | finally terminates your license, and (b) permanently, if the copyright 419 | holder fails to notify you of the violation by some reasonable means 420 | prior to 60 days after the cessation. 421 | 422 | Moreover, your license from a particular copyright holder is 423 | reinstated permanently if the copyright holder notifies you of the 424 | violation by some reasonable means, this is the first time you have 425 | received notice of violation of this License (for any work) from that 426 | copyright holder, and you cure the violation prior to 30 days after 427 | your receipt of the notice. 428 | 429 | Termination of your rights under this section does not terminate the 430 | licenses of parties who have received copies or rights from you under 431 | this License. If your rights have been terminated and not permanently 432 | reinstated, you do not qualify to receive new licenses for the same 433 | material under section 10. 434 | 435 | 9. Acceptance Not Required for Having Copies. 436 | 437 | You are not required to accept this License in order to receive or 438 | run a copy of the Program. Ancillary propagation of a covered work 439 | occurring solely as a consequence of using peer-to-peer transmission 440 | to receive a copy likewise does not require acceptance. However, 441 | nothing other than this License grants you permission to propagate or 442 | modify any covered work. These actions infringe copyright if you do 443 | not accept this License. Therefore, by modifying or propagating a 444 | covered work, you indicate your acceptance of this License to do so. 445 | 446 | 10. Automatic Licensing of Downstream Recipients. 447 | 448 | Each time you convey a covered work, the recipient automatically 449 | receives a license from the original licensors, to run, modify and 450 | propagate that work, subject to this License. You are not responsible 451 | for enforcing compliance by third parties with this License. 452 | 453 | An "entity transaction" is a transaction transferring control of an 454 | organization, or substantially all assets of one, or subdividing an 455 | organization, or merging organizations. If propagation of a covered 456 | work results from an entity transaction, each party to that 457 | transaction who receives a copy of the work also receives whatever 458 | licenses to the work the party's predecessor in interest had or could 459 | give under the previous paragraph, plus a right to possession of the 460 | Corresponding Source of the work from the predecessor in interest, if 461 | the predecessor has it or can get it with reasonable efforts. 462 | 463 | You may not impose any further restrictions on the exercise of the 464 | rights granted or affirmed under this License. For example, you may 465 | not impose a license fee, royalty, or other charge for exercise of 466 | rights granted under this License, and you may not initiate litigation 467 | (including a cross-claim or counterclaim in a lawsuit) alleging that 468 | any patent claim is infringed by making, using, selling, offering for 469 | sale, or importing the Program or any portion of it. 470 | 471 | 11. Patents. 472 | 473 | A "contributor" is a copyright holder who authorizes use under this 474 | License of the Program or a work on which the Program is based. 
The 475 | work thus licensed is called the contributor's "contributor version". 476 | 477 | A contributor's "essential patent claims" are all patent claims 478 | owned or controlled by the contributor, whether already acquired or 479 | hereafter acquired, that would be infringed by some manner, permitted 480 | by this License, of making, using, or selling its contributor version, 481 | but do not include claims that would be infringed only as a 482 | consequence of further modification of the contributor version. For 483 | purposes of this definition, "control" includes the right to grant 484 | patent sublicenses in a manner consistent with the requirements of 485 | this License. 486 | 487 | Each contributor grants you a non-exclusive, worldwide, royalty-free 488 | patent license under the contributor's essential patent claims, to 489 | make, use, sell, offer for sale, import and otherwise run, modify and 490 | propagate the contents of its contributor version. 491 | 492 | In the following three paragraphs, a "patent license" is any express 493 | agreement or commitment, however denominated, not to enforce a patent 494 | (such as an express permission to practice a patent or covenant not to 495 | sue for patent infringement). To "grant" such a patent license to a 496 | party means to make such an agreement or commitment not to enforce a 497 | patent against the party. 498 | 499 | If you convey a covered work, knowingly relying on a patent license, 500 | and the Corresponding Source of the work is not available for anyone 501 | to copy, free of charge and under the terms of this License, through a 502 | publicly available network server or other readily accessible means, 503 | then you must either (1) cause the Corresponding Source to be so 504 | available, or (2) arrange to deprive yourself of the benefit of the 505 | patent license for this particular work, or (3) arrange, in a manner 506 | consistent with the requirements of this License, to extend the patent 507 | license to downstream recipients. "Knowingly relying" means you have 508 | actual knowledge that, but for the patent license, your conveying the 509 | covered work in a country, or your recipient's use of the covered work 510 | in a country, would infringe one or more identifiable patents in that 511 | country that you have reason to believe are valid. 512 | 513 | If, pursuant to or in connection with a single transaction or 514 | arrangement, you convey, or propagate by procuring conveyance of, a 515 | covered work, and grant a patent license to some of the parties 516 | receiving the covered work authorizing them to use, propagate, modify 517 | or convey a specific copy of the covered work, then the patent license 518 | you grant is automatically extended to all recipients of the covered 519 | work and works based on it. 520 | 521 | A patent license is "discriminatory" if it does not include within 522 | the scope of its coverage, prohibits the exercise of, or is 523 | conditioned on the non-exercise of one or more of the rights that are 524 | specifically granted under this License. 
You may not convey a covered 525 | work if you are a party to an arrangement with a third party that is 526 | in the business of distributing software, under which you make payment 527 | to the third party based on the extent of your activity of conveying 528 | the work, and under which the third party grants, to any of the 529 | parties who would receive the covered work from you, a discriminatory 530 | patent license (a) in connection with copies of the covered work 531 | conveyed by you (or copies made from those copies), or (b) primarily 532 | for and in connection with specific products or compilations that 533 | contain the covered work, unless you entered into that arrangement, 534 | or that patent license was granted, prior to 28 March 2007. 535 | 536 | Nothing in this License shall be construed as excluding or limiting 537 | any implied license or other defenses to infringement that may 538 | otherwise be available to you under applicable patent law. 539 | 540 | 12. No Surrender of Others' Freedom. 541 | 542 | If conditions are imposed on you (whether by court order, agreement or 543 | otherwise) that contradict the conditions of this License, they do not 544 | excuse you from the conditions of this License. If you cannot convey a 545 | covered work so as to satisfy simultaneously your obligations under this 546 | License and any other pertinent obligations, then as a consequence you may 547 | not convey it at all. For example, if you agree to terms that obligate you 548 | to collect a royalty for further conveying from those to whom you convey 549 | the Program, the only way you could satisfy both those terms and this 550 | License would be to refrain entirely from conveying the Program. 551 | 552 | 13. Use with the GNU Affero General Public License. 553 | 554 | Notwithstanding any other provision of this License, you have 555 | permission to link or combine any covered work with a work licensed 556 | under version 3 of the GNU Affero General Public License into a single 557 | combined work, and to convey the resulting work. The terms of this 558 | License will continue to apply to the part which is the covered work, 559 | but the special requirements of the GNU Affero General Public License, 560 | section 13, concerning interaction through a network will apply to the 561 | combination as such. 562 | 563 | 14. Revised Versions of this License. 564 | 565 | The Free Software Foundation may publish revised and/or new versions of 566 | the GNU General Public License from time to time. Such new versions will 567 | be similar in spirit to the present version, but may differ in detail to 568 | address new problems or concerns. 569 | 570 | Each version is given a distinguishing version number. If the 571 | Program specifies that a certain numbered version of the GNU General 572 | Public License "or any later version" applies to it, you have the 573 | option of following the terms and conditions either of that numbered 574 | version or of any later version published by the Free Software 575 | Foundation. If the Program does not specify a version number of the 576 | GNU General Public License, you may choose any version ever published 577 | by the Free Software Foundation. 578 | 579 | If the Program specifies that a proxy can decide which future 580 | versions of the GNU General Public License can be used, that proxy's 581 | public statement of acceptance of a version permanently authorizes you 582 | to choose that version for the Program. 
583 | 
584 |   Later license versions may give you additional or different
585 | permissions.  However, no additional obligations are imposed on any
586 | author or copyright holder as a result of your choosing to follow a
587 | later version.
588 | 
589 |   15. Disclaimer of Warranty.
590 | 
591 |   THERE IS NO WARRANTY FOR THE PROGRAM, TO THE EXTENT PERMITTED BY
592 | APPLICABLE LAW.  EXCEPT WHEN OTHERWISE STATED IN WRITING THE COPYRIGHT
593 | HOLDERS AND/OR OTHER PARTIES PROVIDE THE PROGRAM "AS IS" WITHOUT WARRANTY
594 | OF ANY KIND, EITHER EXPRESSED OR IMPLIED, INCLUDING, BUT NOT LIMITED TO,
595 | THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
596 | PURPOSE.  THE ENTIRE RISK AS TO THE QUALITY AND PERFORMANCE OF THE PROGRAM
597 | IS WITH YOU.  SHOULD THE PROGRAM PROVE DEFECTIVE, YOU ASSUME THE COST OF
598 | ALL NECESSARY SERVICING, REPAIR OR CORRECTION.
599 | 
600 |   16. Limitation of Liability.
601 | 
602 |   IN NO EVENT UNLESS REQUIRED BY APPLICABLE LAW OR AGREED TO IN WRITING
603 | WILL ANY COPYRIGHT HOLDER, OR ANY OTHER PARTY WHO MODIFIES AND/OR CONVEYS
604 | THE PROGRAM AS PERMITTED ABOVE, BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY
605 | GENERAL, SPECIAL, INCIDENTAL OR CONSEQUENTIAL DAMAGES ARISING OUT OF THE
606 | USE OR INABILITY TO USE THE PROGRAM (INCLUDING BUT NOT LIMITED TO LOSS OF
607 | DATA OR DATA BEING RENDERED INACCURATE OR LOSSES SUSTAINED BY YOU OR THIRD
608 | PARTIES OR A FAILURE OF THE PROGRAM TO OPERATE WITH ANY OTHER PROGRAMS),
609 | EVEN IF SUCH HOLDER OR OTHER PARTY HAS BEEN ADVISED OF THE POSSIBILITY OF
610 | SUCH DAMAGES.
611 | 
612 |   17. Interpretation of Sections 15 and 16.
613 | 
614 |   If the disclaimer of warranty and limitation of liability provided
615 | above cannot be given local legal effect according to their terms,
616 | reviewing courts shall apply local law that most closely approximates
617 | an absolute waiver of all civil liability in connection with the
618 | Program, unless a warranty or assumption of liability accompanies a
619 | copy of the Program in return for a fee.
620 | 
621 |                      END OF TERMS AND CONDITIONS
622 | 
623 |             How to Apply These Terms to Your New Programs
624 | 
625 |   If you develop a new program, and you want it to be of the greatest
626 | possible use to the public, the best way to achieve this is to make it
627 | free software which everyone can redistribute and change under these terms.
628 | 
629 |   To do so, attach the following notices to the program.  It is safest
630 | to attach them to the start of each source file to most effectively
631 | state the exclusion of warranty; and each file should have at least
632 | the "copyright" line and a pointer to where the full notice is found.
633 | 
634 |     <one line to give the program's name and a brief idea of what it does.>
635 |     Copyright (C) <year>  <name of author>
636 | 
637 |     This program is free software: you can redistribute it and/or modify
638 |     it under the terms of the GNU General Public License as published by
639 |     the Free Software Foundation, either version 3 of the License, or
640 |     (at your option) any later version.
641 | 
642 |     This program is distributed in the hope that it will be useful,
643 |     but WITHOUT ANY WARRANTY; without even the implied warranty of
644 |     MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
645 |     GNU General Public License for more details.
646 | 
647 |     You should have received a copy of the GNU General Public License
648 |     along with this program.  If not, see <https://www.gnu.org/licenses/>.
649 | 
650 | Also add information on how to contact you by electronic and paper mail.
651 | 
652 |   If the program does terminal interaction, make it output a short
653 | notice like this when it starts in an interactive mode:
654 | 
655 |     <program>  Copyright (C) <year>  <name of author>
656 |     This program comes with ABSOLUTELY NO WARRANTY; for details type `show w'.
657 |     This is free software, and you are welcome to redistribute it
658 |     under certain conditions; type `show c' for details.
659 | 
660 | The hypothetical commands `show w' and `show c' should show the appropriate
661 | parts of the General Public License.  Of course, your program's commands
662 | might be different; for a GUI interface, you would use an "about box".
663 | 
664 |   You should also get your employer (if you work as a programmer) or school,
665 | if any, to sign a "copyright disclaimer" for the program, if necessary.
666 | For more information on this, and how to apply and follow the GNU GPL, see
667 | <https://www.gnu.org/licenses/>.
668 | 
669 |   The GNU General Public License does not permit incorporating your program
670 | into proprietary programs.  If your program is a subroutine library, you
671 | may consider it more useful to permit linking proprietary applications with
672 | the library.  If this is what you want to do, use the GNU Lesser General
673 | Public License instead of this License.  But first, please read
674 | <https://www.gnu.org/licenses/why-not-lgpl.html>.
675 | 
--------------------------------------------------------------------------------