├── evaluation
├── bleu
│   ├── __init__.py
│   ├── __pycache__
│   │   ├── bleu.cpython-39.pyc
│   │   ├── __init__.cpython-39.pyc
│   │   └── bleu_scorer.cpython-39.pyc
│   ├── bleu.py
│   └── bleu_scorer.py
├── cider
│   ├── __init__.py
│   ├── __pycache__
│   │   ├── cider.cpython-39.pyc
│   │   ├── __init__.cpython-39.pyc
│   │   └── cider_scorer.cpython-39.pyc
│   ├── cider.py
│   └── cider_scorer.py
├── rouge
│   ├── __init__.py
│   ├── __pycache__
│   │   ├── rouge.cpython-39.pyc
│   │   └── __init__.cpython-39.pyc
│   └── rouge.py
├── meteor
│   ├── __init__.py
│   ├── __pycache__
│   │   ├── meteor.cpython-39.pyc
│   │   └── __init__.cpython-39.pyc
│   └── meteor.py
├── spice
│   ├── __init__.py
│   ├── __pycache__
│   │   ├── spice.cpython-39.pyc
│   │   └── __init__.cpython-39.pyc
│   └── spice.py
├── stanford-corenlp-3.4.1.jar
├── __pycache__
│   ├── __init__.cpython-39.pyc
│   └── tokenizer.cpython-39.pyc
├── __init__.py
└── tokenizer.py
├── models
├── beam_search
│   ├── __init__.py
│   ├── __pycache__
│   │   ├── __init__.cpython-39.pyc
│   │   └── beam_search.cpython-39.pyc
│   └── beam_search.py
├── __init__.py
├── transformer
│   ├── __init__.py
│   ├── __pycache__
│   │   ├── utils.cpython-39.pyc
│   │   ├── __init__.cpython-39.pyc
│   │   ├── attention.cpython-39.pyc
│   │   ├── captioner.cpython-39.pyc
│   │   ├── decoders.cpython-39.pyc
│   │   └── encoders.cpython-39.pyc
│   ├── utils.py
│   ├── transformer.py
│   ├── encoders.py
│   ├── attention.py
│   └── decoders.py
├── __pycache__
│   ├── deecap.cpython-39.pyc
│   ├── model.cpython-39.pyc
│   ├── utils.cpython-39.pyc
│   ├── __init__.cpython-39.pyc
│   ├── configure.cpython-39.pyc
│   ├── containers.cpython-39.pyc
│   ├── autoregressive.cpython-39.pyc
│   └── non_autoregressive.cpython-39.pyc
├── utils.py
├── configure.py
├── model.py
├── containers.py
└── deecap.py
├── data
├── train.pkl
├── __init__.py
├── example.py
├── utils.py
├── dataset.py
├── field.py
└── vocab.py
├── __pycache__
└── dataset.cpython-39.pyc
├── utils
├── __pycache__
│   ├── logger.cpython-39.pyc
│   ├── typing.cpython-39.pyc
│   ├── utils.cpython-39.pyc
│   └── __init__.cpython-39.pyc
├── typing.py
├── __init__.py
├── utils.py
└── logger.py
├── README.md
├── test.py
├── inference.py
├── dataset.py
├── train_deecap.py
└── train_tic.py
/evaluation/bleu/__init__.py:
--------------------------------------------------------------------------------
1 | from .bleu import Bleu
--------------------------------------------------------------------------------
/evaluation/cider/__init__.py:
--------------------------------------------------------------------------------
1 | from .cider import Cider
--------------------------------------------------------------------------------
/evaluation/rouge/__init__.py:
--------------------------------------------------------------------------------
1 | from .rouge import Rouge
--------------------------------------------------------------------------------
/evaluation/meteor/__init__.py:
--------------------------------------------------------------------------------
1 | from .meteor import Meteor
--------------------------------------------------------------------------------
/evaluation/spice/__init__.py:
--------------------------------------------------------------------------------
1 | from .spice import Spice
2 |
--------------------------------------------------------------------------------
/models/beam_search/__init__.py:
--------------------------------------------------------------------------------
1 | from .beam_search import BeamSearch
2 |
--------------------------------------------------------------------------------
/data/train.pkl:
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/feizc/DeeCap/HEAD/data/train.pkl -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- 1 | from .model import * 2 | from .configure import * 3 | from .deecap import * -------------------------------------------------------------------------------- /__pycache__/dataset.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/feizc/DeeCap/HEAD/__pycache__/dataset.cpython-39.pyc -------------------------------------------------------------------------------- /models/transformer/__init__.py: -------------------------------------------------------------------------------- 1 | from .attention import * 2 | from .encoders import * 3 | from .decoders import * 4 | -------------------------------------------------------------------------------- /evaluation/stanford-corenlp-3.4.1.jar: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/feizc/DeeCap/HEAD/evaluation/stanford-corenlp-3.4.1.jar -------------------------------------------------------------------------------- /models/__pycache__/deecap.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/feizc/DeeCap/HEAD/models/__pycache__/deecap.cpython-39.pyc -------------------------------------------------------------------------------- /models/__pycache__/model.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/feizc/DeeCap/HEAD/models/__pycache__/model.cpython-39.pyc -------------------------------------------------------------------------------- /models/__pycache__/utils.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/feizc/DeeCap/HEAD/models/__pycache__/utils.cpython-39.pyc -------------------------------------------------------------------------------- /utils/__pycache__/logger.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/feizc/DeeCap/HEAD/utils/__pycache__/logger.cpython-39.pyc -------------------------------------------------------------------------------- /utils/__pycache__/typing.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/feizc/DeeCap/HEAD/utils/__pycache__/typing.cpython-39.pyc -------------------------------------------------------------------------------- /utils/__pycache__/utils.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/feizc/DeeCap/HEAD/utils/__pycache__/utils.cpython-39.pyc -------------------------------------------------------------------------------- /models/__pycache__/__init__.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/feizc/DeeCap/HEAD/models/__pycache__/__init__.cpython-39.pyc -------------------------------------------------------------------------------- /utils/__pycache__/__init__.cpython-39.pyc: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/feizc/DeeCap/HEAD/utils/__pycache__/__init__.cpython-39.pyc -------------------------------------------------------------------------------- /models/__pycache__/configure.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/feizc/DeeCap/HEAD/models/__pycache__/configure.cpython-39.pyc -------------------------------------------------------------------------------- /models/__pycache__/containers.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/feizc/DeeCap/HEAD/models/__pycache__/containers.cpython-39.pyc -------------------------------------------------------------------------------- /evaluation/__pycache__/__init__.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/feizc/DeeCap/HEAD/evaluation/__pycache__/__init__.cpython-39.pyc -------------------------------------------------------------------------------- /evaluation/__pycache__/tokenizer.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/feizc/DeeCap/HEAD/evaluation/__pycache__/tokenizer.cpython-39.pyc -------------------------------------------------------------------------------- /evaluation/bleu/__pycache__/bleu.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/feizc/DeeCap/HEAD/evaluation/bleu/__pycache__/bleu.cpython-39.pyc -------------------------------------------------------------------------------- /evaluation/cider/__pycache__/cider.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/feizc/DeeCap/HEAD/evaluation/cider/__pycache__/cider.cpython-39.pyc -------------------------------------------------------------------------------- /evaluation/rouge/__pycache__/rouge.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/feizc/DeeCap/HEAD/evaluation/rouge/__pycache__/rouge.cpython-39.pyc -------------------------------------------------------------------------------- /evaluation/spice/__pycache__/spice.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/feizc/DeeCap/HEAD/evaluation/spice/__pycache__/spice.cpython-39.pyc -------------------------------------------------------------------------------- /models/__pycache__/autoregressive.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/feizc/DeeCap/HEAD/models/__pycache__/autoregressive.cpython-39.pyc -------------------------------------------------------------------------------- /evaluation/bleu/__pycache__/__init__.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/feizc/DeeCap/HEAD/evaluation/bleu/__pycache__/__init__.cpython-39.pyc -------------------------------------------------------------------------------- /evaluation/cider/__pycache__/__init__.cpython-39.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/feizc/DeeCap/HEAD/evaluation/cider/__pycache__/__init__.cpython-39.pyc -------------------------------------------------------------------------------- /evaluation/meteor/__pycache__/meteor.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/feizc/DeeCap/HEAD/evaluation/meteor/__pycache__/meteor.cpython-39.pyc -------------------------------------------------------------------------------- /evaluation/rouge/__pycache__/__init__.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/feizc/DeeCap/HEAD/evaluation/rouge/__pycache__/__init__.cpython-39.pyc -------------------------------------------------------------------------------- /evaluation/spice/__pycache__/__init__.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/feizc/DeeCap/HEAD/evaluation/spice/__pycache__/__init__.cpython-39.pyc -------------------------------------------------------------------------------- /models/__pycache__/non_autoregressive.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/feizc/DeeCap/HEAD/models/__pycache__/non_autoregressive.cpython-39.pyc -------------------------------------------------------------------------------- /models/transformer/__pycache__/utils.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/feizc/DeeCap/HEAD/models/transformer/__pycache__/utils.cpython-39.pyc -------------------------------------------------------------------------------- /evaluation/bleu/__pycache__/bleu_scorer.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/feizc/DeeCap/HEAD/evaluation/bleu/__pycache__/bleu_scorer.cpython-39.pyc -------------------------------------------------------------------------------- /evaluation/meteor/__pycache__/__init__.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/feizc/DeeCap/HEAD/evaluation/meteor/__pycache__/__init__.cpython-39.pyc -------------------------------------------------------------------------------- /models/beam_search/__pycache__/__init__.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/feizc/DeeCap/HEAD/models/beam_search/__pycache__/__init__.cpython-39.pyc -------------------------------------------------------------------------------- /models/transformer/__pycache__/__init__.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/feizc/DeeCap/HEAD/models/transformer/__pycache__/__init__.cpython-39.pyc -------------------------------------------------------------------------------- /models/transformer/__pycache__/attention.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/feizc/DeeCap/HEAD/models/transformer/__pycache__/attention.cpython-39.pyc -------------------------------------------------------------------------------- /models/transformer/__pycache__/captioner.cpython-39.pyc: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/feizc/DeeCap/HEAD/models/transformer/__pycache__/captioner.cpython-39.pyc -------------------------------------------------------------------------------- /models/transformer/__pycache__/decoders.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/feizc/DeeCap/HEAD/models/transformer/__pycache__/decoders.cpython-39.pyc -------------------------------------------------------------------------------- /models/transformer/__pycache__/encoders.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/feizc/DeeCap/HEAD/models/transformer/__pycache__/encoders.cpython-39.pyc -------------------------------------------------------------------------------- /evaluation/cider/__pycache__/cider_scorer.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/feizc/DeeCap/HEAD/evaluation/cider/__pycache__/cider_scorer.cpython-39.pyc -------------------------------------------------------------------------------- /models/beam_search/__pycache__/beam_search.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/feizc/DeeCap/HEAD/models/beam_search/__pycache__/beam_search.cpython-39.pyc -------------------------------------------------------------------------------- /utils/typing.py: -------------------------------------------------------------------------------- 1 | from typing import Union, Sequence 2 | 3 | import torch 4 | 5 | TensorOrSequence = Union[Sequence[torch.Tensor], torch.Tensor] 6 | TensorOrNone = Union[torch.Tensor, None] 7 | -------------------------------------------------------------------------------- /data/__init__.py: -------------------------------------------------------------------------------- 1 | from .field import RawField, Merge, ImageDetectionsField, TextField 2 | from .dataset import COCO 3 | from torch.utils.data import DataLoader as TorchDataLoader 4 | 5 | class DataLoader(TorchDataLoader): 6 | def __init__(self, dataset, *args, **kwargs): 7 | super(DataLoader, self).__init__(dataset, *args, collate_fn=dataset.collate_fn(), **kwargs) 8 | -------------------------------------------------------------------------------- /models/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import Tensor 3 | 4 | 5 | def one_hot_to_index(one_hot: Tensor) -> Tensor: 6 | """ 7 | Converts a one-hot tensor into a tensor with corresponding indexes 8 | """ 9 | device, dtype = one_hot.device, one_hot.dtype 10 | vocab_size = one_hot.shape[-1] 11 | oh2idx = torch.tensor(range(vocab_size), dtype=dtype, device=device) 12 | return (one_hot @ oh2idx.unsqueeze(dim=1)).long().squeeze(dim=-1) 13 | -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .logger import * 2 | from .typing import * 3 | from .utils import * 4 | 5 | 6 | def get_batch_size(x: TensorOrSequence) -> int: 7 | if isinstance(x, torch.Tensor): 8 | b_s = x.size(0) 9 | else: 10 | b_s = x[0].size(0) 11 | return b_s 12 | 13 | 14 | def get_device(x: TensorOrSequence) -> int: 15 | if isinstance(x, torch.Tensor): 16 | b_s = x.device 
17 | else: 18 | b_s = x[0].device 19 | return b_s 20 | -------------------------------------------------------------------------------- /evaluation/__init__.py: -------------------------------------------------------------------------------- 1 | from .bleu import Bleu 2 | from .meteor import Meteor 3 | from .rouge import Rouge 4 | from .cider import Cider 5 | from .spice import Spice 6 | from .tokenizer import PTBTokenizer 7 | 8 | 9 | def compute_all_scores(gts, gen): 10 | # metrics = (Bleu(), Meteor(), Rouge(), Cider(), Spice()) 11 | metrics = (Bleu(), Cider()) 12 | all_score = {} 13 | all_scores = {} 14 | for metric in metrics: 15 | score, scores = metric.compute_score(gts, gen) 16 | all_score[str(metric)] = score 17 | all_scores[str(metric)] = scores 18 | 19 | return all_score, all_scores 20 | -------------------------------------------------------------------------------- /data/example.py: -------------------------------------------------------------------------------- 1 | 2 | class Example(object): 3 | """Defines a single training or test example. 4 | Stores each column of the example as an attribute. 5 | """ 6 | @classmethod 7 | def fromdict(cls, data): 8 | ex = cls(data) 9 | return ex 10 | 11 | def __init__(self, data): 12 | for key, val in data.items(): 13 | super(Example, self).__setattr__(key, val) 14 | 15 | def __setattr__(self, key, value): 16 | raise AttributeError 17 | 18 | def __hash__(self): 19 | return hash(tuple(x for x in self.__dict__.values())) 20 | 21 | def __eq__(self, other): 22 | this = tuple(x for x in self.__dict__.values()) 23 | other = tuple(x for x in other.__dict__.values()) 24 | return this == other 25 | 26 | def __ne__(self, other): 27 | return not self.__eq__(other) 28 | -------------------------------------------------------------------------------- /evaluation/bleu/bleu.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # 3 | # File Name : bleu.py 4 | # 5 | # Description : Wrapper for BLEU scorer. 6 | # 7 | # Creation Date : 06-01-2015 8 | # Last Modified : Thu 19 Mar 2015 09:13:28 PM PDT 9 | # Authors : Hao Fang and Tsung-Yi Lin 10 | 11 | from .bleu_scorer import BleuScorer 12 | 13 | 14 | class Bleu: 15 | def __init__(self, n=4): 16 | # default compute Blue score up to 4 17 | self._n = n 18 | self._hypo_for_image = {} 19 | self.ref_for_image = {} 20 | 21 | def compute_score(self, gts, res): 22 | 23 | assert(gts.keys() == res.keys()) 24 | imgIds = gts.keys() 25 | 26 | bleu_scorer = BleuScorer(n=self._n) 27 | for id in imgIds: 28 | hypo = res[id] 29 | ref = gts[id] 30 | 31 | # Sanity check. 
32 | assert(type(hypo) is list) 33 | assert(len(hypo) == 1) 34 | assert(type(ref) is list) 35 | assert(len(ref) >= 1) 36 | 37 | bleu_scorer += (hypo[0], ref) 38 | 39 | # score, scores = bleu_scorer.compute_score(option='shortest') 40 | score, scores = bleu_scorer.compute_score(option='closest', verbose=0) 41 | # score, scores = bleu_scorer.compute_score(option='average', verbose=1) 42 | 43 | return score, scores 44 | 45 | def __str__(self): 46 | return 'BLEU' 47 | -------------------------------------------------------------------------------- /utils/utils.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | import shutil 4 | import warnings 5 | import requests 6 | import pidfile 7 | from contextlib import contextmanager 8 | from time import sleep 9 | 10 | @contextmanager 11 | def exclusive(pidname): 12 | done = False 13 | while not done: 14 | try: 15 | with pidfile.PIDFile(pidname): 16 | yield 17 | done = True 18 | except pidfile.AlreadyRunningError: 19 | sleep(5) 20 | 21 | 22 | def download_from_url(url, path): 23 | """Download file, with logic (from tensor2tensor) for Google Drive""" 24 | if 'drive.google.com' not in url: 25 | print('Downloading %s; may take a few minutes' % url) 26 | r = requests.get(url, headers={'User-Agent': 'Mozilla/5.0'}) 27 | with open(path, "wb") as file: 28 | file.write(r.content) 29 | return 30 | print('Downloading from Google Drive; may take a few minutes') 31 | confirm_token = None 32 | session = requests.Session() 33 | response = session.get(url, stream=True) 34 | for k, v in response.cookies.items(): 35 | if k.startswith("download_warning"): 36 | confirm_token = v 37 | 38 | if confirm_token: 39 | url = url + "&confirm=" + confirm_token 40 | response = session.get(url, stream=True) 41 | 42 | chunk_size = 16 * 1024 43 | with open(path, "wb") as f: 44 | for chunk in response.iter_content(chunk_size): 45 | if chunk: 46 | f.write(chunk) 47 | 48 | 49 | class DummyFile(object): 50 | def write(self, x): pass 51 | 52 | 53 | @contextmanager 54 | def nostdout(): 55 | save_stdout = sys.stdout 56 | sys.stdout = DummyFile() 57 | yield 58 | sys.stdout = save_stdout 59 | -------------------------------------------------------------------------------- /evaluation/cider/cider.py: -------------------------------------------------------------------------------- 1 | # Filename: cider.py 2 | # 3 | # Description: Describes the class to compute the CIDEr (Consensus-Based Image Description Evaluation) Metric 4 | # by Vedantam, Zitnick, and Parikh (http://arxiv.org/abs/1411.5726) 5 | # 6 | # Creation Date: Sun Feb 8 14:16:54 2015 7 | # 8 | # Authors: Ramakrishna Vedantam and Tsung-Yi Lin 9 | 10 | from .cider_scorer import CiderScorer 11 | 12 | class Cider: 13 | """ 14 | Main Class to compute the CIDEr metric 15 | 16 | """ 17 | def __init__(self, gts=None, n=4, sigma=6.0): 18 | # set cider to sum over 1 to 4-grams 19 | self._n = n 20 | # set the standard deviation parameter for gaussian penalty 21 | self._sigma = sigma 22 | self.doc_frequency = None 23 | self.ref_len = None 24 | if gts is not None: 25 | tmp_cider = CiderScorer(gts, n=self._n, sigma=self._sigma) 26 | self.doc_frequency = tmp_cider.doc_frequency 27 | self.ref_len = tmp_cider.ref_len 28 | 29 | def compute_score(self, gts, res): 30 | """ 31 | Main function to compute CIDEr score 32 | :param gts (dict) : dictionary with key and value 33 | res (dict) : dictionary with key and value 34 | :return: cider (float) : computed CIDEr score for the corpus 35 | """ 36 
| assert(gts.keys() == res.keys()) 37 | cider_scorer = CiderScorer(gts, test=res, n=self._n, sigma=self._sigma, doc_frequency=self.doc_frequency, 38 | ref_len=self.ref_len) 39 | return cider_scorer.compute_score() 40 | 41 | def __str__(self): 42 | return 'CIDEr' 43 | -------------------------------------------------------------------------------- /models/transformer/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from torch.nn import functional as F 4 | 5 | 6 | def position_embedding(input, d_model): 7 | input = input.view(-1, 1) 8 | dim = torch.arange(d_model // 2, dtype=input.dtype, device=input.device).view(1, -1) 9 | sin = torch.sin(input / 10000 ** (2 * dim / d_model)) 10 | cos = torch.cos(input / 10000 ** (2 * dim / d_model)) 11 | 12 | out = torch.zeros((input.shape[0], d_model), device=input.device) 13 | out[:, ::2] = sin 14 | out[:, 1::2] = cos 15 | return out 16 | 17 | 18 | def sinusoid_encoding_table(max_len, d_model, padding_idx=None, dtype=torch.float32): 19 | pos = torch.arange(max_len, dtype=dtype) 20 | out = position_embedding(pos, d_model) 21 | 22 | if padding_idx is not None: 23 | out[padding_idx] = 0 24 | return out 25 | 26 | 27 | class PositionWiseFeedForward(nn.Module): 28 | """ 29 | Position-wise feed forward layer 30 | """ 31 | 32 | def __init__(self, d_model=512, d_ff=2048, dropout=.1, identity_map_reordering=False): 33 | super(PositionWiseFeedForward, self).__init__() 34 | self.identity_map_reordering = identity_map_reordering 35 | self.fc1 = nn.Linear(d_model, d_ff) 36 | self.fc2 = nn.Linear(d_ff, d_model) 37 | self.dropout = nn.Dropout(p=dropout) 38 | self.dropout_2 = nn.Dropout(p=dropout) 39 | self.layer_norm = nn.LayerNorm(d_model) 40 | 41 | def forward(self, input): 42 | if self.identity_map_reordering: 43 | out = self.layer_norm(input) 44 | out = self.fc2(self.dropout_2(F.relu(self.fc1(out)))) 45 | out = input + self.dropout(torch.relu(out)) 46 | else: 47 | out = self.fc2(self.dropout_2(F.relu(self.fc1(input)))) 48 | out = self.dropout(out) 49 | out = self.layer_norm(input + out) 50 | return out 51 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # DeeCap 2 | 3 | This repository includes the reference code for paper: 4 | 5 | > **Dynamic Early Exit for Efficient Image Captioning** 6 | 7 | 8 | ## Data 9 | 10 | To run the code, annotations and images for the COCO dataset are needed. 11 | Please download the zip files including the images ([train2014.zip](http://images.cocodataset.org/zips/train2014.zip), [val2014.zip](http://images.cocodataset.org/zips/val2014.zip)), 12 | the zip file containing the annotations ([annotations_trainval2014.zip](http://images.cocodataset.org/annotations/annotations_trainval2014.zip)) and extract them. These paths will be set as arguments later. 13 | Our code supports the image features extracted from conventional [Faster-RCNN](https://drive.google.com/file/d/1MV6dSnqViQfyvgyHrmAT_lLpFbkzp3mx/view) or [CLIP](https://github.com/openai/CLIP) model. 
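For reference, a minimal sketch of extracting a CLIP image feature, using the same `ViT-B/32` setup that appears in `dataset.py` and assuming the `clip` package from the repository linked above; the image path is a placeholder:

```python
import torch
import clip
from PIL import Image

device = "cuda" if torch.cuda.is_available() else "cpu"
clip_model, preprocess = clip.load("ViT-B/32", device=device)  # as in dataset.py

image = preprocess(Image.open("path/to/image.jpg")).unsqueeze(0).to(device)  # placeholder path
with torch.no_grad():
    # (1, 512) feature vector; 512 matches clip_dim in models/configure.py
    img_features = clip_model.encode_image(image)
```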
14 | 15 | 16 | ## Training Procedure 17 | 18 | Run `python train_deecap.py` using the following arguments: 19 | 20 | | Argument | Possible values | 21 | |------|------| 22 | | `--exp_name` | Experiment name (default: deecap)| 23 | | `--train_data_path` | Path to the training dataset | 24 | | `--features_path` | Path to detection features file (optional) | 25 | | `--annotation_folder` | Path to folder with annotations (optional) | 26 | | `--tokenizer_path` | Path to the tokenizer | 27 | | `--out_dir` | Path to the saved checkpoint | 28 | | `--batch_size` | Batch size (default: 10) | 29 | | `--lr` | Learning rate (default: 1e-4) | 30 | 31 | 32 | 33 | 34 | 35 | ## Evaluation 36 | 37 | To reproduce the results reported in our paper, download the checkpoint model file and place it in the ckpt folder. 38 | 39 | Run `python test.py` using the following arguments: 40 | 41 | | Argument | Possible values | 42 | |------|------| 43 | | `--batch_size` | Batch size (default: 10) | 44 | | `--features_path` | Path to detection features file | 45 | | `--annotation_folder` | Path to folder with COCO annotations | 46 | 47 | 48 | ## Acknowledgment 49 | This repository refers to [Transformer Image Captioning](https://github.com/aimagelab/meshed-memory-transformer) and [huggingface DeeBERT](https://github.com/huggingface/transformers/tree/master/examples/research_projects/deebert). 50 | 51 | 52 | 53 | -------------------------------------------------------------------------------- /models/configure.py: -------------------------------------------------------------------------------- 1 | # parameters configuration for transformer image captioning model 2 | 3 | class TransformerConfig: 4 | def __init__( 5 | self, 6 | clip_dim=512, 7 | clip_length=16, 8 | vocab_size=50257, 9 | n_positions=1024, 10 | n_embd=512, 11 | n_layer=6, 12 | n_head=8, 13 | n_inner=None, 14 | activation_function="gelu_new", 15 | resid_pdrop=0.1, 16 | embd_pdrop=0.1, 17 | attn_pdrop=0.1, 18 | layer_norm_epsilon=1e-5, 19 | initializer_range=0.02, 20 | patience=0.5, 21 | summary_type="cls_index", 22 | summary_use_proj=True, 23 | summary_activation=None, 24 | summary_proj_to_labels=True, 25 | summary_first_dropout=0.1, 26 | scale_attn_weights=True, 27 | use_cache=True, 28 | bos_token_id=50256, 29 | eos_token_id=50256, 30 | scale_attn_by_inverse_layer_idx=False, 31 | reorder_and_upcast_attn=False, 32 | ): 33 | self.clip_dim=clip_dim 34 | self.clip_length=clip_length 35 | self.vocab_size = vocab_size 36 | self.n_positions = n_positions 37 | self.n_embd = n_embd 38 | self.n_layer = n_layer 39 | self.n_head = n_head 40 | self.n_inner = n_inner 41 | self.activation_function = activation_function 42 | self.resid_pdrop = resid_pdrop 43 | self.embd_pdrop = embd_pdrop 44 | self.attn_pdrop = attn_pdrop 45 | self.layer_norm_epsilon = layer_norm_epsilon 46 | self.initializer_range = initializer_range 47 | self.summary_type = summary_type 48 | self.summary_use_proj = summary_use_proj 49 | self.summary_activation = summary_activation 50 | self.summary_first_dropout = summary_first_dropout 51 | self.summary_proj_to_labels = summary_proj_to_labels 52 | self.scale_attn_weights = scale_attn_weights 53 | self.use_cache = use_cache 54 | self.scale_attn_by_inverse_layer_idx = scale_attn_by_inverse_layer_idx 55 | self.reorder_and_upcast_attn = reorder_and_upcast_attn 56 | self.patience = patience 57 | 58 | self.bos_token_id = bos_token_id 59 | self.eos_token_id = eos_token_id 60 | -------------------------------------------------------------------------------- 
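For orientation, a minimal sketch of how this configuration is consumed when building the captioning model, mirroring the construction in `test.py`; the tokenizer path is a placeholder:

```python
from transformers import GPT2Tokenizer
from models import DeeCapModel, TransformerConfig

# vocab_size must match the tokenizer used for the captions;
# all other fields fall back to the defaults defined above
tokenizer = GPT2Tokenizer.from_pretrained("./ckpt/gpt2")  # placeholder path, as in test.py
config = TransformerConfig(vocab_size=len(tokenizer))
model = DeeCapModel(config)
```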
/models/model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from torch import Tensor 4 | from .transformer import Encoder, Decoder, ScaledDotProductAttentionMemory, MeshedDecoder 5 | from .containers import Module 6 | from .beam_search import * 7 | 8 | 9 | # original transformer image captioning model 10 | class TICModel(Module): 11 | def __init__(self, config): 12 | super(TICModel, self).__init__() 13 | self.model_d = config.n_embd 14 | self.clip_dim = config.clip_dim 15 | self.clip_length = config.clip_length 16 | self.feature_project = nn.Linear(config.clip_dim, config.clip_length*config.n_embd) 17 | self.visual_encoder = Encoder(config.n_layer, config.clip_length, config.n_embd) 18 | self.language_decoder = Decoder(config.vocab_size) 19 | 20 | self.bos_idx = config.bos_token_id 21 | self.eos_idx = config.eos_token_id 22 | self.vocab_size = config.vocab_size 23 | self.max_generation_length = self.language_decoder.max_len 24 | 25 | self.register_state('enc_output', None) 26 | self.register_state('mask_enc', None) 27 | self.init_weights() 28 | 29 | def init_weights(self): 30 | for p in self.visual_encoder.parameters(): 31 | if p.dim() > 1: 32 | nn.init.xavier_uniform_(p) 33 | for p in self.language_decoder.parameters(): 34 | if p.dim() > 1: 35 | nn.init.xavier_uniform_(p) 36 | 37 | def forward(self, images, seq): 38 | images = self.feature_project(images).view(-1, self.clip_length, self.clip_dim) 39 | enc_output, mask_enc = self.visual_encoder(images) 40 | dec_output = self.language_decoder(seq, enc_output, mask_enc) 41 | return dec_output 42 | 43 | def step(self, t: int, prev_output: Tensor, visual: Tensor) -> Tensor: 44 | if t == 0: 45 | visual = self.feature_project(visual).view(-1, self.clip_length, self.clip_dim) 46 | self.enc_output, self.mask_enc = self.visual_encoder(visual) 47 | input = visual.data.new_full((visual.shape[0], 1), self.bos_idx, dtype=torch.long) 48 | else: 49 | input = prev_output 50 | logits = self.language_decoder(input, self.enc_output, self.mask_enc) 51 | return logits 52 | 53 | def beam_search(self, visual, beam_size: int, out_size=1, 54 | return_logits=False, **kwargs): 55 | bs = BeamSearch(self, self.max_generation_length, self.eos_idx, beam_size) 56 | return bs.apply(visual, out_size, return_logits, **kwargs) 57 | 58 | 59 | 60 | -------------------------------------------------------------------------------- /evaluation/tokenizer.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # 3 | # File Name : ptbtokenizer.py 4 | # 5 | # Description : Do the PTB Tokenization and remove punctuations. 
6 | # 7 | # Creation Date : 29-12-2014 8 | # Last Modified : Thu Mar 19 09:53:35 2015 9 | # Authors : Hao Fang and Tsung-Yi Lin 10 | 11 | import os 12 | import subprocess 13 | import tempfile 14 | 15 | class PTBTokenizer(object): 16 | """Python wrapper of Stanford PTBTokenizer""" 17 | 18 | corenlp_jar = 'stanford-corenlp-3.4.1.jar' 19 | punctuations = ["''", "'", "``", "`", "-LRB-", "-RRB-", "-LCB-", "-RCB-", \ 20 | ".", "?", "!", ",", ":", "-", "--", "...", ";"] 21 | 22 | @classmethod 23 | def tokenize(cls, corpus): 24 | cmd = ['java', '-cp', cls.corenlp_jar, \ 25 | 'edu.stanford.nlp.process.PTBTokenizer', \ 26 | '-preserveLines', '-lowerCase'] 27 | 28 | if isinstance(corpus, list) or isinstance(corpus, tuple): 29 | if isinstance(corpus[0], list) or isinstance(corpus[0], tuple): 30 | corpus = {i:c for i, c in enumerate(corpus)} 31 | else: 32 | corpus = {i: [c, ] for i, c in enumerate(corpus)} 33 | 34 | # prepare data for PTB Tokenizer 35 | tokenized_corpus = {} 36 | image_id = [k for k, v in list(corpus.items()) for _ in range(len(v))] 37 | sentences = '\n'.join([c.replace('\n', ' ') for k, v in corpus.items() for c in v]) 38 | 39 | # save sentences to temporary file 40 | path_to_jar_dirname=os.path.dirname(os.path.abspath(__file__)) 41 | tmp_file = tempfile.NamedTemporaryFile(delete=False, dir=path_to_jar_dirname) 42 | tmp_file.write(sentences.encode()) 43 | tmp_file.close() 44 | 45 | # tokenize sentence 46 | cmd.append(os.path.basename(tmp_file.name)) 47 | p_tokenizer = subprocess.Popen(cmd, cwd=path_to_jar_dirname, \ 48 | stdout=subprocess.PIPE, stderr=open(os.devnull, 'w')) 49 | token_lines = p_tokenizer.communicate(input=sentences.rstrip())[0] 50 | token_lines = token_lines.decode() 51 | lines = token_lines.split('\n') 52 | # remove temp file 53 | os.remove(tmp_file.name) 54 | 55 | # create dictionary for tokenized captions 56 | for k, line in zip(image_id, lines): 57 | if not k in tokenized_corpus: 58 | tokenized_corpus[k] = [] 59 | tokenized_caption = ' '.join([w for w in line.rstrip().split(' ') \ 60 | if w not in cls.punctuations]) 61 | tokenized_corpus[k].append(tokenized_caption) 62 | 63 | return tokenized_corpus -------------------------------------------------------------------------------- /models/transformer/transformer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | import copy 4 | from models.containers import ModuleList 5 | from ..captioning_model import CaptioningModel 6 | 7 | 8 | class Transformer(CaptioningModel): 9 | def __init__(self, bos_idx, encoder, decoder): 10 | super(Transformer, self).__init__() 11 | self.bos_idx = bos_idx 12 | self.encoder = encoder 13 | self.decoder = decoder 14 | self.register_state('enc_output', None) 15 | self.register_state('mask_enc', None) 16 | self.init_weights() 17 | 18 | @property 19 | def d_model(self): 20 | return self.decoder.d_model 21 | 22 | def init_weights(self): 23 | for p in self.parameters(): 24 | if p.dim() > 1: 25 | nn.init.xavier_uniform_(p) 26 | 27 | def forward(self, images, seq, *args): 28 | enc_output, mask_enc = self.encoder(images) 29 | dec_output = self.decoder(seq, enc_output, mask_enc) 30 | return dec_output 31 | 32 | def init_state(self, b_s, device): 33 | return [torch.zeros((b_s, 0), dtype=torch.long, device=device), 34 | None, None] 35 | 36 | def step(self, t, prev_output, visual, seq, mode='teacher_forcing', **kwargs): 37 | it = None 38 | if mode == 'teacher_forcing': 39 | raise NotImplementedError 40 | elif mode == 
'feedback': 41 | if t == 0: 42 | self.enc_output, self.mask_enc = self.encoder(visual) 43 | if isinstance(visual, torch.Tensor): 44 | it = visual.data.new_full((visual.shape[0], 1), self.bos_idx).long() 45 | else: 46 | it = visual[0].data.new_full((visual[0].shape[0], 1), self.bos_idx).long() 47 | else: 48 | it = prev_output 49 | 50 | return self.decoder(it, self.enc_output, self.mask_enc) 51 | 52 | 53 | class TransformerEnsemble(CaptioningModel): 54 | def __init__(self, model: Transformer, weight_files): 55 | super(TransformerEnsemble, self).__init__() 56 | self.n = len(weight_files) 57 | self.models = ModuleList([copy.deepcopy(model) for _ in range(self.n)]) 58 | for i in range(self.n): 59 | state_dict_i = torch.load(weight_files[i])['state_dict'] 60 | self.models[i].load_state_dict(state_dict_i) 61 | 62 | def step(self, t, prev_output, visual, seq, mode='teacher_forcing', **kwargs): 63 | out_ensemble = [] 64 | for i in range(self.n): 65 | out_i = self.models[i].step(t, prev_output, visual, seq, mode, **kwargs) 66 | out_ensemble.append(out_i.unsqueeze(0)) 67 | 68 | return torch.mean(torch.cat(out_ensemble, 0), dim=0) 69 | -------------------------------------------------------------------------------- /models/containers.py: -------------------------------------------------------------------------------- 1 | from contextlib import contextmanager 2 | from torch import nn 3 | from utils.typing import * 4 | 5 | 6 | class Module(nn.Module): 7 | def __init__(self): 8 | super(Module, self).__init__() 9 | self._is_stateful = False 10 | self._state_names = [] 11 | self._state_defaults = dict() 12 | 13 | def register_state(self, name: str, default: TensorOrNone): 14 | self._state_names.append(name) 15 | if default is None: 16 | self._state_defaults[name] = None 17 | else: 18 | self._state_defaults[name] = default.clone().detach() 19 | self.register_buffer(name, default) 20 | 21 | def states(self): 22 | for name in self._state_names: 23 | yield self._buffers[name] 24 | for m in self.children(): 25 | if isinstance(m, Module): 26 | yield from m.states() 27 | 28 | def apply_to_states(self, fn): 29 | for name in self._state_names: 30 | if self._buffers[name] is not None: 31 | self._buffers[name] = fn(self._buffers[name]) 32 | for m in self.children(): 33 | if isinstance(m, Module): 34 | m.apply_to_states(fn) 35 | 36 | def _init_states(self, batch_size: int): 37 | for name in self._state_names: 38 | if self._state_defaults[name] is None: 39 | self._buffers[name] = None 40 | else: 41 | self._buffers[name] = self._state_defaults[name].clone().detach().to(self._buffers[name].device) 42 | self._buffers[name] = self._buffers[name].unsqueeze(0) 43 | self._buffers[name] = self._buffers[name].expand([batch_size, ] + list(self._buffers[name].shape[1:])) 44 | self._buffers[name] = self._buffers[name].contiguous() 45 | 46 | def _reset_states(self): 47 | for name in self._state_names: 48 | if self._state_defaults[name] is None: 49 | self._buffers[name] = None 50 | else: 51 | self._buffers[name] = self._state_defaults[name].clone().detach().to(self._buffers[name].device) 52 | 53 | def enable_statefulness(self, batch_size: int): 54 | for m in self.children(): 55 | if isinstance(m, Module): 56 | m.enable_statefulness(batch_size) 57 | self._init_states(batch_size) 58 | self._is_stateful = True 59 | 60 | def disable_statefulness(self): 61 | for m in self.children(): 62 | if isinstance(m, Module): 63 | m.disable_statefulness() 64 | self._reset_states() 65 | self._is_stateful = False 66 | 67 | @contextmanager 68 | def 
statefulness(self, batch_size: int): 69 | self.enable_statefulness(batch_size) 70 | try: 71 | yield 72 | finally: 73 | self.disable_statefulness() 74 | 75 | 76 | class ModuleList(nn.ModuleList, Module): 77 | pass 78 | 79 | 80 | class ModuleDict(nn.ModuleDict, Module): 81 | pass 82 | -------------------------------------------------------------------------------- /evaluation/meteor/meteor.py: -------------------------------------------------------------------------------- 1 | # Python wrapper for METEOR implementation, by Xinlei Chen 2 | # Acknowledge Michael Denkowski for the generous discussion and help 3 | 4 | import os 5 | import subprocess 6 | import threading 7 | import tarfile 8 | from utils import download_from_url 9 | 10 | METEOR_GZ_URL = 'http://aimagelab.ing.unimore.it/speaksee/data/meteor.tgz' 11 | METEOR_JAR = 'meteor-1.5.jar' 12 | 13 | class Meteor: 14 | def __init__(self): 15 | base_path = os.path.dirname(os.path.abspath(__file__)) 16 | jar_path = os.path.join(base_path, METEOR_JAR) 17 | gz_path = os.path.join(base_path, os.path.basename(METEOR_GZ_URL)) 18 | if not os.path.isfile(jar_path): 19 | if not os.path.isfile(gz_path): 20 | download_from_url(METEOR_GZ_URL, gz_path) 21 | tar = tarfile.open(gz_path, "r") 22 | tar.extractall(path=os.path.dirname(os.path.abspath(__file__))) 23 | tar.close() 24 | os.remove(gz_path) 25 | 26 | self.meteor_cmd = ['java', '-jar', '-Xmx2G', METEOR_JAR, \ 27 | '-', '-', '-stdio', '-l', 'en', '-norm'] 28 | self.meteor_p = subprocess.Popen(self.meteor_cmd, \ 29 | cwd=os.path.dirname(os.path.abspath(__file__)), \ 30 | stdin=subprocess.PIPE, \ 31 | stdout=subprocess.PIPE, \ 32 | stderr=subprocess.PIPE) 33 | # Used to guarantee thread safety 34 | self.lock = threading.Lock() 35 | 36 | def compute_score(self, gts, res): 37 | assert(gts.keys() == res.keys()) 38 | imgIds = gts.keys() 39 | scores = [] 40 | 41 | eval_line = 'EVAL' 42 | self.lock.acquire() 43 | for i in imgIds: 44 | assert(len(res[i]) == 1) 45 | stat = self._stat(res[i][0], gts[i]) 46 | eval_line += ' ||| {}'.format(stat) 47 | 48 | self.meteor_p.stdin.write('{}\n'.format(eval_line).encode()) 49 | self.meteor_p.stdin.flush() 50 | for i in range(0,len(imgIds)): 51 | scores.append(float(self.meteor_p.stdout.readline().strip())) 52 | score = float(self.meteor_p.stdout.readline().strip()) 53 | self.lock.release() 54 | 55 | return score, scores 56 | 57 | def _stat(self, hypothesis_str, reference_list): 58 | # SCORE ||| reference 1 words ||| reference n words ||| hypothesis words 59 | hypothesis_str = hypothesis_str.replace('|||','').replace(' ',' ') 60 | score_line = ' ||| '.join(('SCORE', ' ||| '.join(reference_list), hypothesis_str)) 61 | self.meteor_p.stdin.write('{}\n'.format(score_line).encode()) 62 | self.meteor_p.stdin.flush() 63 | raw = self.meteor_p.stdout.readline().decode().strip() 64 | numbers = [str(int(float(n))) for n in raw.split()] 65 | return ' '.join(numbers) 66 | 67 | def __del__(self): 68 | self.lock.acquire() 69 | self.meteor_p.stdin.close() 70 | self.meteor_p.kill() 71 | self.meteor_p.wait() 72 | self.lock.release() 73 | 74 | def __str__(self): 75 | return 'METEOR' 76 | -------------------------------------------------------------------------------- /data/utils.py: -------------------------------------------------------------------------------- 1 | import contextlib, sys 2 | 3 | class DummyFile(object): 4 | def write(self, x): pass 5 | 6 | @contextlib.contextmanager 7 | def nostdout(): 8 | save_stdout = sys.stdout 9 | sys.stdout = DummyFile() 10 | yield 11 | sys.stdout = 
save_stdout 12 | 13 | def reporthook(t): 14 | """https://github.com/tqdm/tqdm""" 15 | last_b = [0] 16 | 17 | def inner(b=1, bsize=1, tsize=None): 18 | """ 19 | b: int, optionala 20 | Number of blocks just transferred [default: 1]. 21 | bsize: int, optional 22 | Size of each block (in tqdm units) [default: 1]. 23 | tsize: int, optional 24 | Total size (in tqdm units). If [default: None] remains unchanged. 25 | """ 26 | if tsize is not None: 27 | t.total = tsize 28 | t.update((b - last_b[0]) * bsize) 29 | last_b[0] = b 30 | return inner 31 | 32 | def get_tokenizer(tokenizer): 33 | if callable(tokenizer): 34 | return tokenizer 35 | if tokenizer == "spacy": 36 | try: 37 | import spacy 38 | spacy_en = spacy.load('en') 39 | return lambda s: [tok.text for tok in spacy_en.tokenizer(s)] 40 | except ImportError: 41 | print("Please install SpaCy and the SpaCy English tokenizer. " 42 | "See the docs at https://spacy.io for more information.") 43 | raise 44 | except AttributeError: 45 | print("Please install SpaCy and the SpaCy English tokenizer. " 46 | "See the docs at https://spacy.io for more information.") 47 | raise 48 | elif tokenizer == "moses": 49 | try: 50 | from nltk.tokenize.moses import MosesTokenizer 51 | moses_tokenizer = MosesTokenizer() 52 | return moses_tokenizer.tokenize 53 | except ImportError: 54 | print("Please install NLTK. " 55 | "See the docs at http://nltk.org for more information.") 56 | raise 57 | except LookupError: 58 | print("Please install the necessary NLTK corpora. " 59 | "See the docs at http://nltk.org for more information.") 60 | raise 61 | elif tokenizer == 'revtok': 62 | try: 63 | import revtok 64 | return revtok.tokenize 65 | except ImportError: 66 | print("Please install revtok.") 67 | raise 68 | elif tokenizer == 'subword': 69 | try: 70 | import revtok 71 | return lambda x: revtok.tokenize(x, decap=True) 72 | except ImportError: 73 | print("Please install revtok.") 74 | raise 75 | raise ValueError("Requested tokenizer {}, valid choices are a " 76 | "callable that takes a single string as input, " 77 | "\"revtok\" for the revtok reversible tokenizer, " 78 | "\"subword\" for the revtok caps-aware tokenizer, " 79 | "\"spacy\" for the SpaCy English tokenizer, or " 80 | "\"moses\" for the NLTK port of the Moses tokenization " 81 | "script.".format(tokenizer)) 82 | -------------------------------------------------------------------------------- /models/transformer/encoders.py: -------------------------------------------------------------------------------- 1 | from torch.nn import functional as F 2 | from models.transformer.utils import sinusoid_encoding_table, PositionWiseFeedForward 3 | import torch 4 | from torch import nn 5 | from models.transformer.attention import MultiHeadAttention 6 | 7 | 8 | class EncoderLayer(nn.Module): 9 | def __init__(self, d_model=512, d_k=64, d_v=64, h=8, d_ff=2048, dropout=.1, identity_map_reordering=False, 10 | attention_module=None, attention_module_kwargs=None): 11 | super(EncoderLayer, self).__init__() 12 | self.identity_map_reordering = identity_map_reordering 13 | self.mhatt = MultiHeadAttention(d_model, d_k, d_v, h, dropout, identity_map_reordering=identity_map_reordering, 14 | attention_module=attention_module, 15 | attention_module_kwargs=attention_module_kwargs) 16 | self.pwff = PositionWiseFeedForward(d_model, d_ff, dropout, identity_map_reordering=identity_map_reordering) 17 | 18 | def forward(self, queries, keys, values, attention_mask=None, attention_weights=None): 19 | att = self.mhatt(queries, keys, values, 
attention_mask, attention_weights) 20 | ff = self.pwff(att) 21 | return ff 22 | 23 | 24 | class Encoder(nn.Module): 25 | def __init__(self, N=6, max_len=16, d_in=512, d_model=512, d_k=64, d_v=64, h=8, d_ff=2048, dropout=.1, 26 | identity_map_reordering=False, attention_module=None, attention_module_kwargs=None, 27 | with_pe=False, with_mesh=False): 28 | super(Encoder, self).__init__() 29 | self.d_in = d_in 30 | self.d_model = d_model 31 | self.dropout = dropout 32 | self.layers = nn.ModuleList([EncoderLayer(d_model, d_k, d_v, h, d_ff, dropout, 33 | identity_map_reordering=identity_map_reordering, 34 | attention_module=attention_module, 35 | attention_module_kwargs=attention_module_kwargs) 36 | for _ in range(N)]) 37 | self.pos_emb = nn.Embedding.from_pretrained(sinusoid_encoding_table(max_len + 1, self.d_in, 0), freeze=True) 38 | self.fc = nn.Linear(d_in, self.d_model) 39 | self.dropout = nn.Dropout(p=self.dropout) 40 | self.layer_norm = nn.LayerNorm(self.d_model) 41 | self.with_pe = with_pe 42 | self.with_mesh = with_mesh 43 | 44 | def forward(self, input): 45 | # input (b_s, seq_len, d_in) 46 | b_s, seq_len = input.shape[:2] 47 | seq = torch.arange(1, seq_len + 1, device=input.device).view(1, -1).expand(b_s, -1) # (b_s, seq_len) 48 | 49 | out = input 50 | if self.with_pe: 51 | out = out + self.pos_emb(seq) 52 | out = F.relu(self.fc(out)) 53 | out = self.dropout(out) 54 | out = self.layer_norm(out) 55 | outs = list() 56 | for l in self.layers: 57 | out = l(out, out, out) 58 | if self.with_mesh: 59 | outs.append(out.unsqueeze(1)) 60 | 61 | if self.with_mesh: 62 | outs = torch.cat(outs, 1) 63 | return outs, None 64 | return out, None 65 | 66 | -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import os 3 | import argparse 4 | import random 5 | import numpy as np 6 | from transformers import GPT2Tokenizer 7 | from torch.utils.data import Dataset, DataLoader 8 | from tqdm import tqdm 9 | 10 | 11 | from models.transformer import Encoder, Decoder 12 | from models import entropy, DeeCapModel, TransformerConfig 13 | from dataset import ClipCocoDataset 14 | import evaluation 15 | 16 | 17 | use_device = torch.cuda.is_available() 18 | device = torch.device('cuda:0' if use_device else 'cpu') 19 | torch.backends.cudnn.benchmark = True 20 | 21 | random.seed(1234) 22 | torch.manual_seed(1234) 23 | np.random.seed(1234) 24 | 25 | SPECIAL_TOKENS = ["", ""] 26 | SPECIAL_TOKENS_DICT = {'bos_token': "", 'eos_token': "", } 27 | max_length = 20 28 | 29 | 30 | def greedy_decode(img_features, model, tokenizer): 31 | special_token_ids = tokenizer.convert_tokens_to_ids(SPECIAL_TOKENS) 32 | 33 | gen_i = [special_token_ids[0]] 34 | 35 | for i in range(max_length): 36 | tokens = torch.tensor(gen_i).long().unsqueeze(0) 37 | logits = model.step(img_features, tokens) 38 | logits = logits[0].cpu().numpy() 39 | next_word = np.argsort(logits)[-1] 40 | if next_word == special_token_ids[1]: 41 | break 42 | gen_i.append(next_word) 43 | return gen_i 44 | 45 | 46 | 47 | def predict_captions(model, test_dataloader, tokenizer): 48 | import itertools 49 | model.eval() 50 | gen = {} 51 | gts = {} 52 | progress = tqdm(total=len(test_dataloader), desc='DeeCapModel') 53 | with torch.no_grad(): 54 | for idx, (tokens, _, img_features) in enumerate(test_dataloader): 55 | tokens, img_features = tokens.to(device), img_features.to(device, dtype=torch.float32) 56 | gen_i = greedy_decode(img_features, 
model, tokenizer) 57 | caps_gen = tokenizer.batch_decode([gen_i]) 58 | caps_gt = tokenizer.batch_decode(tokens) 59 | for i, (gts_i, gen_i) in enumerate(zip(caps_gt, caps_gen)): 60 | gen_i = ' '.join([k for k, g in itertools.groupby(gen_i)]) 61 | gen['%d_%d' % (idx, i)] = [gen_i.strip(), ] 62 | gts['%d_%d' % (idx, i)] = gts_i 63 | progress.update() 64 | 65 | gts = evaluation.PTBTokenizer.tokenize(gts) 66 | gen = evaluation.PTBTokenizer.tokenize(gen) 67 | scores, _ = evaluation.compute_all_scores(gts, gen) 68 | return scores 69 | 70 | 71 | 72 | if __name__ == '__main__': 73 | parser = argparse.ArgumentParser(description='deecap') 74 | parser.add_argument('--batch_size', type=int, default=1) 75 | parser.add_argument('--test_data_path', default='./data/test.pkl') 76 | parser.add_argument('--tokenizer_path', default='./ckpt/gpt2') 77 | args = parser.parse_args() 78 | 79 | tokenizer = GPT2Tokenizer.from_pretrained(args.tokenizer_path) 80 | tokenizer.add_special_tokens(SPECIAL_TOKENS_DICT) 81 | test_dataset = ClipCocoDataset(args.test_data_path, tokenizer) 82 | test_dataloader = DataLoader(test_dataset, batch_size=args.batch_size, shuffle=False, drop_last=False) 83 | 84 | config = TransformerConfig(vocab_size=len(tokenizer)) 85 | model = DeeCapModel(config).to(device) 86 | 87 | scores = predict_captions(model, test_dataloader, tokenizer) 88 | print(scores) 89 | 90 | 91 | ''' 92 | input_f = torch.randn((5,512)) 93 | input_l = torch.ones((5,20)).long() 94 | 95 | 96 | configuration = TransformerConfig() 97 | model = DeeCapModel(configuration) 98 | model.step(input_f, input_l) 99 | ''' 100 | -------------------------------------------------------------------------------- /inference.py: -------------------------------------------------------------------------------- 1 | from transformers import * 2 | import torch 3 | import torch.nn.functional as F 4 | import numpy as np 5 | import copy 6 | 7 | from model import DeeCapModel 8 | from dataset import build_input_from_segments, DeeCapDataset 9 | 10 | SPECIAL_TOKENS = ['[BOS]', '[EOS]', '[SEP]', '[IMG]', '[TXT]', '[PAD]'] 11 | SPECIAL_TOKENS_DICT = {'bos_token':'[BOS]', 'eos_token':'[EOS]', 'additional_special_tokens':['[IMG]', '[TXT]', '[SEP]'], 'pad_token':'[PAD]'} 12 | 13 | 14 | def beam_search(img_features, model, tokenizer, max_length=25, beam_size=5): 15 | special_tokens_ids = tokenizer.convert_tokens_to_ids(SPECIAL_TOKENS) 16 | current_output = '' 17 | hyplist = [([], 0., current_output)] 18 | comp_hyplist = [] 19 | img_emb = model.img_ff(img_features) 20 | for i in range(max_length): 21 | new_hyplist = [] 22 | argmin = 0 23 | for out, lp, st in hyplist: 24 | instance = build_input_from_segments(img_features, st, tokenizer, label_flag=False) 25 | input_ids = torch.tensor(instance['input_ids']).long() 26 | token_type_ids = torch.tensor(instance['token_type_ids']).long() 27 | input_emb = model.transformer.wte(input_ids) 28 | input_embs = torch.cat([img_emb, input_emb], dim=0) 29 | # print(input_embs.size(), toekn_type_ids.size()) 30 | 31 | logits = model(input_embs=input_embs, token_type_ids=token_type_ids)[0] 32 | logp = F.log_softmax(logits, dim=-1)[-1, :] 33 | lp_vec = logp.cpu().data.numpy() + lp 34 | 35 | for o in np.argsort(lp_vec)[::-1]: 36 | if o == tokenizer.unk_token_id or o == tokenizer.eos_token_id: 37 | continue 38 | new_lp = lp_vec[o] 39 | if len(new_hyplist) == beam_size: 40 | if new_hyplist[argmin][1] < new_lp: 41 | new_st = copy.deepcopy(st) 42 | new_st += ' ' 43 | new_st += tokenizer.convert_ids_to_tokens([o])[0] 44 | 
new_hyplist[argmin] = (out+[o], new_lp, new_st) 45 | argmin = min(enumerate(new_hyplist), key=lambda h: h[1])[0] 46 | else: 47 | break 48 | else: 49 | new_st = copy.deepcopy(st) 50 | new_st += ' ' 51 | new_st += tokenizer.convert_ids_to_tokens([o])[0] 52 | new_hyplist.append((out+[o], new_lp, new_st)) 53 | if len(new_hyplist) == beam_size: 54 | argmin = min(enumerate(new_hyplist), key=lambda h: h[1])[0] 55 | hyplist = new_hyplist 56 | maxhyps = sorted(hyplist, key=lambda h: -h[1])[:1] 57 | print(maxhyps) 58 | 59 | 60 | 61 | def generate_caption(model, tokenizer, data): 62 | model.eval() 63 | with torch.no_grad(): 64 | for instance in data: 65 | img_features = instance[0] 66 | hypstr = beam_search(img_features, model, tokenizer) 67 | break 68 | 69 | 70 | if __name__ == '__main__': 71 | ckpt_model_path = 'model/ckpt' 72 | tokenizer = GPT2Tokenizer.from_pretrained(ckpt_model_path, do_lower_case=True) 73 | model_config = GPT2Config.from_pretrained(ckpt_model_path) 74 | model = DeeCapModel(model_config) 75 | 76 | ckpt = torch.load('model/ckpt/epoch_1', map_location='cpu') 77 | model.load_state_dict(ckpt['model']) 78 | tokenizer.add_special_tokens(SPECIAL_TOKENS_DICT) 79 | 80 | use_cuda = torch.cuda.is_available() 81 | device = torch.device('cuda' if use_cuda else 'cpu') 82 | model = model.to(device) 83 | model.eval() 84 | 85 | test_dataset_path = 'data' 86 | test_dataset = DeeCapDataset(test_dataset_path, tokenizer) 87 | 88 | generate_caption(model, tokenizer, test_dataset) 89 | 90 | 91 | 92 | 93 | -------------------------------------------------------------------------------- /evaluation/rouge/rouge.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # 3 | # File Name : rouge.py 4 | # 5 | # Description : Computes ROUGE-L metric as described by Lin and Hovey (2004) 6 | # 7 | # Creation Date : 2015-01-07 06:03 8 | # Author : Ramakrishna Vedantam 9 | 10 | import numpy as np 11 | import pdb 12 | 13 | 14 | def my_lcs(string, sub): 15 | """ 16 | Calculates longest common subsequence for a pair of tokenized strings 17 | :param string : list of str : tokens from a string split using whitespace 18 | :param sub : list of str : shorter string, also split using whitespace 19 | :returns: length (list of int): length of the longest common subsequence between the two strings 20 | 21 | Note: my_lcs only gives length of the longest common subsequence, not the actual LCS 22 | """ 23 | if (len(string) < len(sub)): 24 | sub, string = string, sub 25 | 26 | lengths = [[0 for i in range(0, len(sub) + 1)] for j in range(0, len(string) + 1)] 27 | 28 | for j in range(1, len(sub) + 1): 29 | for i in range(1, len(string) + 1): 30 | if (string[i - 1] == sub[j - 1]): 31 | lengths[i][j] = lengths[i - 1][j - 1] + 1 32 | else: 33 | lengths[i][j] = max(lengths[i - 1][j], lengths[i][j - 1]) 34 | 35 | return lengths[len(string)][len(sub)] 36 | 37 | 38 | class Rouge(): 39 | ''' 40 | Class for computing ROUGE-L score for a set of candidate sentences for the MS COCO test set 41 | 42 | ''' 43 | 44 | def __init__(self): 45 | # vrama91: updated the value below based on discussion with Hovey 46 | self.beta = 1.2 47 | 48 | def calc_score(self, candidate, refs): 49 | """ 50 | Compute ROUGE-L score given one candidate and references for an image 51 | :param candidate: str : candidate sentence to be evaluated 52 | :param refs: list of str : COCO reference sentences for the particular image to be evaluated 53 | :returns score: int (ROUGE-L score for the candidate evaluated 
against references) 54 | """ 55 | assert (len(candidate) == 1) 56 | assert (len(refs) > 0) 57 | prec = [] 58 | rec = [] 59 | 60 | # split into tokens 61 | token_c = candidate[0].split(" ") 62 | 63 | for reference in refs: 64 | # split into tokens 65 | token_r = reference.split(" ") 66 | # compute the longest common subsequence 67 | lcs = my_lcs(token_r, token_c) 68 | prec.append(lcs / float(len(token_c))) 69 | rec.append(lcs / float(len(token_r))) 70 | 71 | prec_max = max(prec) 72 | rec_max = max(rec) 73 | 74 | if (prec_max != 0 and rec_max != 0): 75 | score = ((1 + self.beta ** 2) * prec_max * rec_max) / float(rec_max + self.beta ** 2 * prec_max) 76 | else: 77 | score = 0.0 78 | return score 79 | 80 | def compute_score(self, gts, res): 81 | """ 82 | Computes Rouge-L score given a set of reference and candidate sentences for the dataset 83 | Invoked by evaluate_captions.py 84 | :param hypo_for_image: dict : candidate / test sentences with "image name" key and "tokenized sentences" as values 85 | :param ref_for_image: dict : reference MS-COCO sentences with "image name" key and "tokenized sentences" as values 86 | :returns: average_score: float (mean ROUGE-L score computed by averaging scores for all the images) 87 | """ 88 | assert (gts.keys() == res.keys()) 89 | imgIds = gts.keys() 90 | 91 | score = [] 92 | for id in imgIds: 93 | hypo = res[id] 94 | ref = gts[id] 95 | 96 | score.append(self.calc_score(hypo, ref)) 97 | 98 | # Sanity check. 99 | assert (type(hypo) is list) 100 | assert (len(hypo) == 1) 101 | assert (type(ref) is list) 102 | assert (len(ref) > 0) 103 | 104 | average_score = np.mean(np.array(score)) 105 | return average_score, np.array(score) 106 | 107 | def __str__(self): 108 | return 'ROUGE' 109 | -------------------------------------------------------------------------------- /evaluation/spice/spice.py: -------------------------------------------------------------------------------- 1 | from __future__ import division 2 | import os 3 | import subprocess 4 | import json 5 | import numpy as np 6 | import tempfile 7 | import tarfile 8 | from utils import download_from_url 9 | 10 | # Assumes spice.jar is in the same directory as spice.py. Change as needed. 11 | SPICE_GZ_URL = 'http://aimagelab.ing.unimore.it/speaksee/data/spice.tgz' 12 | SPICE_JAR = 'spice-1.0.jar' 13 | TEMP_DIR = 'tmp' 14 | CACHE_DIR = 'cache' 15 | 16 | 17 | class Spice: 18 | """ 19 | Main Class to compute the SPICE metric 20 | """ 21 | 22 | def __init__(self): 23 | base_path = os.path.dirname(os.path.abspath(__file__)) 24 | jar_path = os.path.join(base_path, SPICE_JAR) 25 | gz_path = os.path.join(base_path, os.path.basename(SPICE_GZ_URL)) 26 | if not os.path.isfile(jar_path): 27 | if not os.path.isfile(gz_path): 28 | download_from_url(SPICE_GZ_URL, gz_path) 29 | tar = tarfile.open(gz_path, "r") 30 | tar.extractall(path=os.path.dirname(os.path.abspath(__file__))) 31 | tar.close() 32 | os.remove(gz_path) 33 | 34 | 35 | def float_convert(self, obj): 36 | try: 37 | return float(obj) 38 | except: 39 | return np.nan 40 | 41 | def compute_score(self, gts, res): 42 | assert (sorted(gts.keys()) == sorted(res.keys())) 43 | imgIds = sorted(gts.keys()) 44 | 45 | # Prepare temp input file for the SPICE scorer 46 | input_data = [] 47 | for id in imgIds: 48 | hypo = res[id] 49 | ref = gts[id] 50 | 51 | # Sanity check. 
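# Each entry appended to input_data follows the JSON shape consumed by the SPICE
# jar: one hypothesis caption under "test" and the reference captions under "refs".
# A hypothetical record might look like
#   {"image_id": "42_0",
#    "test": "a dog runs on the beach",
#    "refs": ["a dog running along the shore", "a brown dog on a beach"]}
# The assertions just below enforce exactly one hypothesis and at least one
# reference per image before the record is built.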
52 | assert (type(hypo) is list) 53 | assert (len(hypo) == 1) 54 | assert (type(ref) is list) 55 | assert (len(ref) >= 1) 56 | 57 | input_data.append({ 58 | "image_id": id, 59 | "test": hypo[0], 60 | "refs": ref 61 | }) 62 | 63 | cwd = os.path.dirname(os.path.abspath(__file__)) 64 | temp_dir = os.path.join(cwd, TEMP_DIR) 65 | if not os.path.exists(temp_dir): 66 | os.makedirs(temp_dir) 67 | in_file = tempfile.NamedTemporaryFile('w+', delete=False, dir=temp_dir, encoding='utf8') 68 | json.dump(input_data, in_file, indent=2) 69 | in_file.close() 70 | 71 | # Start job 72 | out_file = tempfile.NamedTemporaryFile('w+', delete=False, dir=temp_dir, encoding='utf8') 73 | out_file.close() 74 | cache_dir = os.path.join(cwd, CACHE_DIR) 75 | if not os.path.exists(cache_dir): 76 | os.makedirs(cache_dir) 77 | spice_cmd = ['java', '-jar', '-Xmx8G', SPICE_JAR, in_file.name, 78 | '-cache', cache_dir, 79 | '-out', out_file.name, 80 | '-subset', 81 | '-silent' 82 | ] 83 | 84 | try: 85 | from subprocess import DEVNULL # Python 3. 86 | except ImportError: 87 | DEVNULL = open(os.devnull, 'wb') 88 | subprocess.check_call(spice_cmd, 89 | cwd=os.path.dirname(os.path.abspath(__file__)), 90 | stdout=DEVNULL, stderr=DEVNULL) 91 | 92 | # Read and process results 93 | with open(out_file.name) as data_file: 94 | results = json.load(data_file) 95 | os.remove(in_file.name) 96 | os.remove(out_file.name) 97 | 98 | imgId_to_scores = {} 99 | spice_scores = [] 100 | for item in results: 101 | imgId_to_scores[item['image_id']] = item['scores'] 102 | spice_scores.append(self.float_convert(item['scores']['All']['f'])) 103 | average_score = np.mean(np.array(spice_scores)) 104 | scores = [] 105 | for image_id in imgIds: 106 | # Convert none to NaN before saving scores over subcategories 107 | score_set = {} 108 | for category, score_tuple in imgId_to_scores[image_id].items(): 109 | score_set[category] = {k: self.float_convert(v) for k, v in score_tuple.items()} 110 | scores.append(score_set) 111 | return average_score, scores 112 | 113 | def __str__(self): 114 | return 'SPICE' 115 | -------------------------------------------------------------------------------- /dataset.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import skimage.io as io 3 | import clip 4 | from PIL import Image 5 | import pickle 6 | import json 7 | import os 8 | from tqdm import tqdm 9 | from torch.utils.data import Dataset, DataLoader 10 | from transformers import GPT2Tokenizer 11 | import sys 12 | 13 | 14 | def dataset_split(dataset_path, output_path): 15 | device = "cuda" if torch.cuda.is_available() else "cpu" 16 | clip_model, preprocess = clip.load("ViT-B/32", device=device) 17 | 18 | annotation_path = os.path.join(dataset_path, 'annotations/captions_train2014.json') 19 | with open(annotation_path, 'r') as f: 20 | data = json.load(f)['annotations'] 21 | print("%0d captions loaded from json." 
%len(data)) 22 | 23 | all_embeddings = [] 24 | all_captions = [] 25 | for i in tqdm(range(len(data))): 26 | d = data[i] 27 | img_id = d['image_id'] 28 | file_name = os.path.join(dataset_path, f"train2014/COCO_train2014_{int(img_id):012d}.jpg") 29 | image = io.imread(file_name) 30 | image = preprocess(Image.fromarray(image)).unsqueeze(0).to(device) 31 | with torch.no_grad(): 32 | img_features = clip_model.encode_image(image).cpu() 33 | d['clip_embedding'] = i 34 | all_embeddings.append(img_features) 35 | all_captions.append(d) 36 | if i == 20: 37 | break 38 | 39 | with open(output_path, 'wb') as f: 40 | pickle.dump({"clip_embedding": torch.cat(all_embeddings, dim=0), "captions": all_captions}, f) 41 | return 0 42 | 43 | 44 | 45 | class ClipCocoDataset(Dataset): 46 | 47 | def __len__(self) -> int: 48 | return len(self.captions_tokens) 49 | 50 | def pad_tokens(self, item: int): 51 | tokens = self.captions_tokens[item] 52 | if self.padding == False: 53 | padding = 0 54 | else: 55 | padding = self.max_seq_len - tokens.shape[0] 56 | if padding > 0: 57 | tokens = torch.cat((tokens, torch.zeros(padding, dtype=torch.int64) - 1)) 58 | self.captions_tokens[item] = tokens 59 | elif padding < 0: 60 | tokens = tokens[:self.max_seq_len] 61 | self.captions_tokens[item] = tokens 62 | mask = tokens.ge(0) # mask is zero where we out of sequence 63 | tokens[~mask] = 0 64 | mask = mask.float() 65 | return tokens, mask 66 | 67 | def __getitem__(self, item: int): 68 | tokens, mask = self.pad_tokens(item) 69 | features = self.features[self.caption2embedding[item]] 70 | if self.normalize_prefix: 71 | features = features.float() 72 | features = features / features.norm(2, -1) 73 | return tokens, mask, features 74 | 75 | def __init__(self, data_path: str, tokenizer, padding=True, normalize_features=False): 76 | self.tokenizer = tokenizer 77 | self.normalize_prefix = normalize_features 78 | with open(data_path, 'rb') as f: 79 | all_data = pickle.load(f) 80 | print("Data size is %0d" % len(all_data["clip_embedding"])) 81 | sys.stdout.flush() 82 | self.features = all_data["clip_embedding"] 83 | captions_raw = all_data["captions"] 84 | self.image_ids = [caption["image_id"] for caption in captions_raw] 85 | self.captions = [caption['caption'] for caption in captions_raw] 86 | self.padding=padding 87 | 88 | self.captions_tokens = [] 89 | self.caption2embedding = [] 90 | max_seq_len = 0 91 | for caption in captions_raw: 92 | self.captions_tokens.append(torch.tensor(self.tokenizer.encode(caption['caption']), dtype=torch.int64)) 93 | self.caption2embedding.append(caption["clip_embedding"]) 94 | max_seq_len = max(max_seq_len, self.captions_tokens[-1].shape[0]) 95 | # self.max_seq_len = max_seq_len 96 | all_len = torch.tensor([len(self.captions_tokens[i]) for i in range(len(self))]).float() 97 | self.max_seq_len = min(int(all_len.mean() + all_len.std() * 10), int(all_len.max())) 98 | 99 | 100 | 101 | if __name__ == '__main__': 102 | dataset_path='/Users/feizhengcong/Desktop/COCO' 103 | output_path = './data/train.pkl' 104 | # dataset_split(dataset_path, output_path) 105 | tokenizer = GPT2Tokenizer.from_pretrained('ckpt/gpt2') 106 | dataset = ClipCocoDataset('data/train.pkl', tokenizer) 107 | tokens, mask, features = dataset[0] 108 | print(tokens, mask) 109 | print(features.size()) 110 | -------------------------------------------------------------------------------- /utils/logger.py: -------------------------------------------------------------------------------- 1 | from collections import defaultdict, deque 2 | import 
datetime 3 | import time 4 | import torch 5 | import torch.distributed as dist 6 | 7 | 8 | class SmoothedValue(object): 9 | """Track a series of values and provide access to smoothed values over a 10 | window or the global series average. 11 | """ 12 | 13 | def __init__(self, window_size=20, fmt=None): 14 | if fmt is None: 15 | fmt = "{median:.4f} ({global_avg:.4f})" 16 | self.deque = deque(maxlen=window_size) 17 | self.total = 0.0 18 | self.count = 0 19 | self.fmt = fmt 20 | 21 | def update(self, value, n=1): 22 | self.deque.append(value) 23 | self.count += n 24 | self.total += value * n 25 | 26 | def synchronize_between_processes(self): 27 | """ 28 | Warning: does not synchronize the deque! 29 | """ 30 | if not is_dist_avail_and_initialized(): 31 | return 32 | t = torch.tensor([self.count, self.total], dtype=torch.float64, device='cuda') 33 | dist.barrier() 34 | dist.all_reduce(t) 35 | t = t.tolist() 36 | self.count = int(t[0]) 37 | self.total = t[1] 38 | 39 | @property 40 | def median(self): 41 | d = torch.tensor(list(self.deque)) 42 | return d.median().item() 43 | 44 | @property 45 | def avg(self): 46 | d = torch.tensor(list(self.deque), dtype=torch.float32) 47 | return d.mean().item() 48 | 49 | @property 50 | def global_avg(self): 51 | return self.total / self.count 52 | 53 | @property 54 | def max(self): 55 | return max(self.deque) 56 | 57 | @property 58 | def value(self): 59 | return self.deque[-1] 60 | 61 | def __str__(self): 62 | return self.fmt.format( 63 | median=self.median, 64 | avg=self.avg, 65 | global_avg=self.global_avg, 66 | max=self.max, 67 | value=self.value) 68 | 69 | 70 | class MetricLogger(object): 71 | def __init__(self, delimiter="\t"): 72 | self.meters = defaultdict(SmoothedValue) 73 | self.delimiter = delimiter 74 | 75 | def update(self, **kwargs): 76 | for k, v in kwargs.items(): 77 | if isinstance(v, torch.Tensor): 78 | v = v.item() 79 | assert isinstance(v, (float, int)) 80 | self.meters[k].update(v) 81 | 82 | def __getattr__(self, attr): 83 | if attr in self.meters: 84 | return self.meters[attr] 85 | if attr in self.__dict__: 86 | return self.__dict__[attr] 87 | raise AttributeError("'{}' object has no attribute '{}'".format( 88 | type(self).__name__, attr)) 89 | 90 | def __str__(self): 91 | loss_str = [] 92 | for name, meter in self.meters.items(): 93 | loss_str.append( 94 | "{}: {}".format(name, str(meter)) 95 | ) 96 | return self.delimiter.join(loss_str) 97 | 98 | def synchronize_between_processes(self): 99 | for meter in self.meters.values(): 100 | meter.synchronize_between_processes() 101 | 102 | def add_meter(self, name, meter): 103 | self.meters[name] = meter 104 | 105 | def log_every(self, iterable, print_freq, header=None): 106 | i = 0 107 | if not header: 108 | header = '' 109 | start_time = time.time() 110 | end = time.time() 111 | iter_time = SmoothedValue(fmt='{avg:.4f}') 112 | data_time = SmoothedValue(fmt='{avg:.4f}') 113 | space_fmt = ':' + str(len(str(len(iterable)))) + 'd' 114 | if torch.cuda.is_available(): 115 | log_msg = self.delimiter.join([ 116 | header, 117 | '[{0' + space_fmt + '}/{1}]', 118 | 'eta: {eta}', 119 | '{meters}', 120 | 'time: {time}', 121 | 'data: {data}', 122 | 'max mem: {memory:.0f}' 123 | ]) 124 | else: 125 | log_msg = self.delimiter.join([ 126 | header, 127 | '[{0' + space_fmt + '}/{1}]', 128 | 'eta: {eta}', 129 | '{meters}', 130 | 'time: {time}', 131 | 'data: {data}' 132 | ]) 133 | MB = 1024.0 * 1024.0 134 | for obj in iterable: 135 | data_time.update(time.time() - end) 136 | yield obj 137 | 
iter_time.update(time.time() - end) 138 | if i % print_freq == 0: 139 | eta_seconds = iter_time.global_avg * (len(iterable) - i) 140 | eta_string = str(datetime.timedelta(seconds=int(eta_seconds))) 141 | if torch.cuda.is_available(): 142 | print(log_msg.format( 143 | i, len(iterable), eta=eta_string, 144 | meters=str(self), 145 | time=str(iter_time), data=str(data_time), 146 | memory=torch.cuda.max_memory_allocated() / MB)) 147 | else: 148 | print(log_msg.format( 149 | i, len(iterable), eta=eta_string, 150 | meters=str(self), 151 | time=str(iter_time), data=str(data_time))) 152 | i += 1 153 | end = time.time() 154 | total_time = time.time() - start_time 155 | total_time_str = str(datetime.timedelta(seconds=int(total_time))) 156 | print('{} Total time: {}'.format(header, total_time_str)) 157 | 158 | 159 | def is_dist_avail_and_initialized(): 160 | if not dist.is_available(): 161 | return False 162 | if not dist.is_initialized(): 163 | return False 164 | return True 165 | -------------------------------------------------------------------------------- /train_deecap.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import os 3 | import argparse 4 | import random 5 | import numpy as np 6 | from transformers import GPT2Tokenizer, AdamW, get_linear_schedule_with_warmup 7 | from torch.nn import CrossEntropyLoss, CosineEmbeddingLoss 8 | 9 | 10 | from models import DeeCapModel, TransformerConfig 11 | from dataset import ClipCocoDataset 12 | from torch.utils.data import Dataset, DataLoader 13 | from tqdm import tqdm 14 | 15 | 16 | 17 | SPECIAL_TOKENS = ["", ""] 18 | SPECIAL_TOKENS_DICT = {'bos_token': "", 'eos_token': "", } 19 | 20 | 21 | 22 | use_device = torch.cuda.is_available() 23 | device = torch.device('cuda:0' if use_device else 'cpu') 24 | torch.backends.cudnn.benchmark = True 25 | 26 | random.seed(1234) 27 | torch.manual_seed(1234) 28 | np.random.seed(1234) 29 | 30 | 31 | def LayerWeightLoss(model, outputs, labels, args): 32 | total_loss = None 33 | total_weights = 0 34 | loss_fct = CrossEntropyLoss() 35 | 36 | # weight cross-entropy for different language layers 37 | for idx, logits_item in enumerate(outputs): 38 | loss = loss_fct(logits_item.view(-1, model.config.vocab_size), labels.view(-1)) 39 | if total_loss is None: 40 | total_loss = loss 41 | else: 42 | total_loss += loss * (idx + 1) 43 | total_weights += idx + 1 44 | 45 | # cos similarity loss for hidden representation prediction 46 | cos_loss_total = 0 47 | if args.predictor_training: 48 | hidden_states_list = model.hidden_states_list 49 | hidden_states_list = torch.stack(hidden_states_list, dim=2) # (bsz, seq_len, num_layer, model_d) 50 | hidden_states_proj_list = model.hidden_states_proj_list # (bsz, seq_len, num_layer, model_d) 51 | 52 | cos_loss_fct = CosineEmbeddingLoss() 53 | 54 | for i in range(len(hidden_states_proj_list) - 1): 55 | hidden_states = hidden_states_list[:, :, i+1:, :].reshape(-1, model.model_d) 56 | hidden_states_proj = hidden_states_proj_list[i][:, :, i+1:, :].reshape(-1, model.model_d) 57 | target = torch.ones(hidden_states.shape[0], device=hidden_states.device) 58 | cos_loss_total += cos_loss_fct(hidden_states, hidden_states_proj, target) 59 | 60 | cos_loss_total /= len(hidden_states_proj_list) - 1 61 | 62 | return total_loss / total_weights + cos_loss_total 63 | 64 | 65 | 66 | def train(model, train_dataloader, args, optimizer, scheduler, epoch): 67 | model.train() 68 | running_loss = .0 69 | print('Num Training Epochs = ', epoch) 70 | progress = 
tqdm(total=len(train_dataloader), desc='DeeCapModel') 71 | for idx, (tokens, _, img_features) in enumerate(train_dataloader): 72 | model.zero_grad() 73 | tokens, img_features = tokens.to(device), img_features.to(device, dtype=torch.float32) 74 | outputs = model(img_features, tokens) 75 | loss = LayerWeightLoss(model, outputs, tokens, args) 76 | loss.backward() 77 | optimizer.step() 78 | scheduler.step() 79 | optimizer.zero_grad() 80 | running_loss += loss.item() 81 | progress.set_postfix({"loss": running_loss / (idx + 1)}) 82 | progress.update() 83 | 84 | progress.close() 85 | return running_loss / len(train_dataloader) 86 | 87 | 88 | 89 | def evaluate_loss(model, test_dataloader, args): 90 | model.eval() 91 | running_loss = .0 92 | progress = tqdm(total=len(test_dataloader), desc='DeeCapModel') 93 | with torch.no_grad(): 94 | for idx, (tokens, _, img_features) in enumerate(test_dataloader): 95 | tokens, img_features = tokens.to(device), img_features.to(device, dtype=torch.float32) 96 | outputs = model(img_features, tokens) 97 | loss = LayerWeightLoss(model, outputs, tokens, args) 98 | 99 | running_loss += loss.item() 100 | progress.set_postfix({"loss": running_loss / (idx + 1)}) 101 | progress.update() 102 | val_loss = running_loss / len(test_dataloader) 103 | return val_loss 104 | 105 | 106 | 107 | def main(): 108 | parser = argparse.ArgumentParser() 109 | parser.add_argument('--train_data_path', default='./data/train.pkl') 110 | parser.add_argument('--test_data_path', default='./data/test.pkl') 111 | parser.add_argument('--tokenizer_path', default='./ckpt/gpt2') 112 | parser.add_argument('--batch_size', default=5) 113 | parser.add_argument('--lr', default=1e-4) 114 | parser.add_argument('--epochs', default=8) 115 | parser.add_argument('--warmup_steps', default=5000) 116 | parser.add_argument('--out_dir', default='./ckpt') 117 | parser.add_argument('--model_type', default='deecap') 118 | parser.add_argument('--predictor_training', default=True) 119 | args = parser.parse_args() 120 | 121 | tokenizer = GPT2Tokenizer.from_pretrained(args.tokenizer_path) 122 | tokenizer.add_special_tokens(SPECIAL_TOKENS_DICT) 123 | 124 | train_dataset = ClipCocoDataset(args.train_data_path, tokenizer) 125 | test_dataset = ClipCocoDataset(args.test_data_path, tokenizer) 126 | train_dataloader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True, drop_last=True) 127 | test_dataloader = DataLoader(test_dataset, batch_size=1, shuffle=False, drop_last=False) 128 | 129 | config = TransformerConfig(vocab_size=len(tokenizer)) 130 | model = DeeCapModel(config).to(device) 131 | 132 | optimizer = AdamW(model.parameters(), lr=args.lr) 133 | scheduler = get_linear_schedule_with_warmup( 134 | optimizer, num_warmup_steps=args.warmup_steps, num_training_steps=args.epochs * len(train_dataloader) 135 | ) 136 | 137 | for epoch in range(args.epochs): 138 | train(model, train_dataloader, args, optimizer, scheduler, epoch) 139 | val_loss = evaluate_loss(model, test_dataloader, args) 140 | 141 | torch.save( 142 | model.state_dict(), 143 | os.path.join(args.out_dir, '%s_last.pth' %args.model_type) 144 | ) 145 | break 146 | 147 | 148 | 149 | if __name__ == "__main__": 150 | main() -------------------------------------------------------------------------------- /evaluation/cider/cider_scorer.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # Tsung-Yi Lin 3 | # Ramakrishna Vedantam 4 | 5 | import copy 6 | from collections import defaultdict 7 | import 
numpy as np 8 | import math 9 | 10 | def precook(s, n=4): 11 | """ 12 | Takes a string as input and returns an object that can be given to 13 | either cook_refs or cook_test. This is optional: cook_refs and cook_test 14 | can take string arguments as well. 15 | :param s: string : sentence to be converted into ngrams 16 | :param n: int : number of ngrams for which representation is calculated 17 | :return: term frequency vector for occuring ngrams 18 | """ 19 | words = s.split() 20 | counts = defaultdict(int) 21 | for k in range(1,n+1): 22 | for i in range(len(words)-k+1): 23 | ngram = tuple(words[i:i+k]) 24 | counts[ngram] += 1 25 | return counts 26 | 27 | def cook_refs(refs, n=4): ## lhuang: oracle will call with "average" 28 | '''Takes a list of reference sentences for a single segment 29 | and returns an object that encapsulates everything that BLEU 30 | needs to know about them. 31 | :param refs: list of string : reference sentences for some image 32 | :param n: int : number of ngrams for which (ngram) representation is calculated 33 | :return: result (list of dict) 34 | ''' 35 | return [precook(ref, n) for ref in refs] 36 | 37 | def cook_test(test, n=4): 38 | '''Takes a test sentence and returns an object that 39 | encapsulates everything that BLEU needs to know about it. 40 | :param test: list of string : hypothesis sentence for some image 41 | :param n: int : number of ngrams for which (ngram) representation is calculated 42 | :return: result (dict) 43 | ''' 44 | return precook(test, n) 45 | 46 | class CiderScorer(object): 47 | """CIDEr scorer. 48 | """ 49 | 50 | def __init__(self, refs, test=None, n=4, sigma=6.0, doc_frequency=None, ref_len=None): 51 | ''' singular instance ''' 52 | self.n = n 53 | self.sigma = sigma 54 | self.crefs = [] 55 | self.ctest = [] 56 | self.doc_frequency = defaultdict(float) 57 | self.ref_len = None 58 | 59 | for k in refs.keys(): 60 | self.crefs.append(cook_refs(refs[k])) 61 | if test is not None: 62 | self.ctest.append(cook_test(test[k][0])) ## N.B.: -1 63 | else: 64 | self.ctest.append(None) # lens of crefs and ctest have to match 65 | 66 | if doc_frequency is None and ref_len is None: 67 | # compute idf 68 | self.compute_doc_freq() 69 | # compute log reference length 70 | self.ref_len = np.log(float(len(self.crefs))) 71 | else: 72 | self.doc_frequency = doc_frequency 73 | self.ref_len = ref_len 74 | 75 | def compute_doc_freq(self): 76 | ''' 77 | Compute term frequency for reference data. 78 | This will be used to compute idf (inverse document frequency later) 79 | The term frequency is stored in the object 80 | :return: None 81 | ''' 82 | for refs in self.crefs: 83 | # refs, k ref captions of one image 84 | for ngram in set([ngram for ref in refs for (ngram,count) in ref.items()]): 85 | self.doc_frequency[ngram] += 1 86 | # maxcounts[ngram] = max(maxcounts.get(ngram,0), count) 87 | 88 | def compute_cider(self): 89 | def counts2vec(cnts): 90 | """ 91 | Function maps counts of ngram to vector of tfidf weights. 92 | The function returns vec, an array of dictionary that store mapping of n-gram and tf-idf weights. 93 | The n-th entry of array denotes length of n-grams. 
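For intuition, with hypothetical numbers: if the corpus holds 100 reference sets
(ref_len = log 100 ~ 4.6) and a given n-gram occurs in 10 of them
(df = log 10 ~ 2.3), then a caption containing that n-gram twice receives
vec[n][ngram] = 2 * (4.6 - 2.3) ~ 4.6, i.e. a tf * idf weight. The returned norm
is the per-order Euclidean norm of these weights, used later for the cosine
similarity.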
94 | :param cnts: 95 | :return: vec (array of dict), norm (array of float), length (int) 96 | """ 97 | vec = [defaultdict(float) for _ in range(self.n)] 98 | length = 0 99 | norm = [0.0 for _ in range(self.n)] 100 | for (ngram,term_freq) in cnts.items(): 101 | # give word count 1 if it doesn't appear in reference corpus 102 | df = np.log(max(1.0, self.doc_frequency[ngram])) 103 | # ngram index 104 | n = len(ngram)-1 105 | # tf (term_freq) * idf (precomputed idf) for n-grams 106 | vec[n][ngram] = float(term_freq)*(self.ref_len - df) 107 | # compute norm for the vector. the norm will be used for computing similarity 108 | norm[n] += pow(vec[n][ngram], 2) 109 | 110 | if n == 1: 111 | length += term_freq 112 | norm = [np.sqrt(n) for n in norm] 113 | return vec, norm, length 114 | 115 | def sim(vec_hyp, vec_ref, norm_hyp, norm_ref, length_hyp, length_ref): 116 | ''' 117 | Compute the cosine similarity of two vectors. 118 | :param vec_hyp: array of dictionary for vector corresponding to hypothesis 119 | :param vec_ref: array of dictionary for vector corresponding to reference 120 | :param norm_hyp: array of float for vector corresponding to hypothesis 121 | :param norm_ref: array of float for vector corresponding to reference 122 | :param length_hyp: int containing length of hypothesis 123 | :param length_ref: int containing length of reference 124 | :return: array of score for each n-grams cosine similarity 125 | ''' 126 | delta = float(length_hyp - length_ref) 127 | # measure consine similarity 128 | val = np.array([0.0 for _ in range(self.n)]) 129 | for n in range(self.n): 130 | # ngram 131 | for (ngram,count) in vec_hyp[n].items(): 132 | # vrama91 : added clipping 133 | val[n] += min(vec_hyp[n][ngram], vec_ref[n][ngram]) * vec_ref[n][ngram] 134 | 135 | if (norm_hyp[n] != 0) and (norm_ref[n] != 0): 136 | val[n] /= (norm_hyp[n]*norm_ref[n]) 137 | 138 | assert(not math.isnan(val[n])) 139 | # vrama91: added a length based gaussian penalty 140 | val[n] *= np.e**(-(delta**2)/(2*self.sigma**2)) 141 | return val 142 | 143 | scores = [] 144 | for test, refs in zip(self.ctest, self.crefs): 145 | # compute vector for test captions 146 | vec, norm, length = counts2vec(test) 147 | # compute vector for ref captions 148 | score = np.array([0.0 for _ in range(self.n)]) 149 | for ref in refs: 150 | vec_ref, norm_ref, length_ref = counts2vec(ref) 151 | score += sim(vec, vec_ref, norm, norm_ref, length, length_ref) 152 | # change by vrama91 - mean of ngram scores, instead of sum 153 | score_avg = np.mean(score) 154 | # divide by number of references 155 | score_avg /= len(refs) 156 | # multiply score by 10 157 | score_avg *= 10.0 158 | # append score of an image to the score list 159 | scores.append(score_avg) 160 | return scores 161 | 162 | def compute_score(self): 163 | # compute cider score 164 | score = self.compute_cider() 165 | # debug 166 | # print score 167 | return np.mean(np.array(score)), np.array(score) -------------------------------------------------------------------------------- /models/beam_search/beam_search.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import utils 3 | 4 | 5 | class BeamSearch(object): 6 | def __init__(self, model, max_len: int, eos_idx: int, beam_size: int): 7 | self.model = model 8 | self.max_len = max_len 9 | self.eos_idx = eos_idx 10 | self.beam_size = beam_size 11 | self.b_s = None 12 | self.device = None 13 | self.seq_mask = None 14 | self.seq_logprob = None 15 | self.outputs = None 16 | self.log_probs = None 
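        # State kept across decoding steps (shapes assuming batch size b_s and
        # beam_size beams): seq_logprob holds the cumulative log-probability of each
        # beam, (b_s, beam_size, 1); seq_mask zeroes out beams that have already
        # emitted the <eos> token; outputs / log_probs collect, per timestep, the
        # selected word index and its log-probability.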
17 | self.selected_words = None 18 | self.all_logits = None 19 | 20 | def _expand_state(self, selected_beam, cur_beam_size): 21 | def fn(s): 22 | shape = [int(sh) for sh in s.shape] 23 | beam = selected_beam 24 | for _ in shape[1:]: 25 | beam = beam.unsqueeze(-1) 26 | s = torch.gather(s.view(*([self.b_s, cur_beam_size] + shape[1:])), 1, 27 | beam.expand(*([self.b_s, self.beam_size] + shape[1:]))) 28 | s = s.view(*([-1, ] + shape[1:])) 29 | return s 30 | 31 | return fn 32 | 33 | def _expand_visual(self, visual: utils.TensorOrSequence, cur_beam_size: int, selected_beam: torch.Tensor): 34 | if isinstance(visual, torch.Tensor): 35 | visual_shape = visual.shape 36 | visual_exp_shape = (self.b_s, cur_beam_size) + visual_shape[1:] 37 | visual_red_shape = (self.b_s * self.beam_size,) + visual_shape[1:] 38 | selected_beam_red_size = (self.b_s, self.beam_size) + tuple(1 for _ in range(len(visual_exp_shape) - 2)) 39 | selected_beam_exp_size = (self.b_s, self.beam_size) + visual_exp_shape[2:] 40 | visual_exp = visual.view(visual_exp_shape) 41 | selected_beam_exp = selected_beam.view(selected_beam_red_size).expand(selected_beam_exp_size) 42 | visual = torch.gather(visual_exp, 1, selected_beam_exp).view(visual_red_shape) 43 | else: 44 | new_visual = [] 45 | for im in visual: 46 | visual_shape = im.shape 47 | visual_exp_shape = (self.b_s, cur_beam_size) + visual_shape[1:] 48 | visual_red_shape = (self.b_s * self.beam_size,) + visual_shape[1:] 49 | selected_beam_red_size = (self.b_s, self.beam_size) + tuple(1 for _ in range(len(visual_exp_shape) - 2)) 50 | selected_beam_exp_size = (self.b_s, self.beam_size) + visual_exp_shape[2:] 51 | visual_exp = im.view(visual_exp_shape) 52 | selected_beam_exp = selected_beam.view(selected_beam_red_size).expand(selected_beam_exp_size) 53 | new_im = torch.gather(visual_exp, 1, selected_beam_exp).view(visual_red_shape) 54 | new_visual.append(new_im) 55 | visual = tuple(new_visual) 56 | return visual 57 | 58 | def apply(self, visual: utils.TensorOrSequence, out_size=1, return_logits=False, **kwargs): 59 | self.b_s = utils.get_batch_size(visual) 60 | self.device = utils.get_device(visual) 61 | self.seq_mask = torch.ones((self.b_s, self.beam_size, 1), device=self.device) 62 | self.seq_logprob = torch.zeros((self.b_s, 1, 1), device=self.device) 63 | self.log_probs = [] 64 | self.selected_words = None 65 | if return_logits: 66 | self.all_logits = [] 67 | 68 | outputs = [] 69 | with self.model.statefulness(self.b_s): 70 | for t in range(self.max_len): 71 | visual, outputs = self.iter(t, visual, outputs, return_logits, **kwargs) 72 | 73 | # Sort result 74 | seq_logprob, sort_idxs = torch.sort(self.seq_logprob, 1, descending=True) 75 | outputs = torch.cat(outputs, -1) 76 | outputs = torch.gather(outputs, 1, sort_idxs.expand(self.b_s, self.beam_size, self.max_len)) 77 | log_probs = torch.cat(self.log_probs, -1) 78 | log_probs = torch.gather(log_probs, 1, sort_idxs.expand(self.b_s, self.beam_size, self.max_len)) 79 | outputs = outputs.contiguous()[:, :out_size] 80 | log_probs = log_probs.contiguous()[:, :out_size] 81 | 82 | if return_logits: 83 | all_logits = torch.cat(self.all_logits, 2) 84 | all_logits = torch.gather(all_logits, 1, sort_idxs.unsqueeze(-1).expand(self.b_s, self.beam_size, 85 | self.max_len, 86 | all_logits.shape[-1])) 87 | all_logits = all_logits.contiguous()[:, :out_size] 88 | 89 | if out_size == 1: 90 | outputs = outputs.squeeze(1) 91 | log_probs = log_probs.squeeze(1) 92 | if return_logits: 93 | all_logits = all_logits.squeeze(1) 94 | 95 | if return_logits: 96 
| return outputs, log_probs, all_logits 97 | else: 98 | return outputs, log_probs 99 | 100 | def select(self, t, candidate_logprob, **kwargs): 101 | selected_logprob, selected_idx = torch.sort(candidate_logprob.view(self.b_s, -1), -1, descending=True) 102 | selected_logprob, selected_idx = selected_logprob[:, :self.beam_size], selected_idx[:, :self.beam_size] 103 | return selected_idx, selected_logprob 104 | 105 | def iter(self, t: int, visual: utils.TensorOrSequence, outputs, return_logits, **kwargs): 106 | cur_beam_size = 1 if t == 0 else self.beam_size 107 | 108 | word_logits = self.model.step(t, self.selected_words, visual, **kwargs) 109 | word_logits = word_logits.view(self.b_s, cur_beam_size, -1) 110 | word_logprob = torch.log_softmax(word_logits, dim=-1) 111 | candidate_logprob = self.seq_logprob + word_logprob 112 | 113 | # Mask sequence if it reaches EOS 114 | if t > 0: 115 | mask = (self.selected_words.view(self.b_s, cur_beam_size) != self.eos_idx).type(visual.dtype).unsqueeze(-1) 116 | self.seq_mask = self.seq_mask * mask 117 | word_logprob = word_logprob * self.seq_mask.expand_as(word_logprob) 118 | old_seq_logprob = self.seq_logprob.expand_as(candidate_logprob).contiguous() 119 | old_seq_logprob[:, :, 1:] = -999 120 | candidate_logprob = self.seq_mask * candidate_logprob + old_seq_logprob * (1 - self.seq_mask) 121 | 122 | selected_idx, selected_logprob = self.select(t, candidate_logprob, **kwargs) 123 | selected_beam = torch.floor_divide(selected_idx, candidate_logprob.shape[-1]) 124 | selected_words = selected_idx - selected_beam * candidate_logprob.shape[-1] 125 | 126 | self.model.apply_to_states(self._expand_state(selected_beam, cur_beam_size)) 127 | visual = self._expand_visual(visual, cur_beam_size, selected_beam) 128 | 129 | self.seq_logprob = selected_logprob.unsqueeze(-1) 130 | self.seq_mask = torch.gather(self.seq_mask, 1, selected_beam.unsqueeze(-1)) 131 | outputs = list(torch.gather(o, 1, selected_beam.unsqueeze(-1)) for o in outputs) 132 | outputs.append(selected_words.unsqueeze(-1)) 133 | 134 | if return_logits: 135 | if t == 0: 136 | self.all_logits.append(word_logits.expand((self.b_s, self.beam_size, -1)).unsqueeze(2)) 137 | else: 138 | self.all_logits.append(word_logits.unsqueeze(2)) 139 | 140 | this_word_logprob = torch.gather(word_logprob, 1, 141 | selected_beam.unsqueeze(-1).expand(self.b_s, self.beam_size, 142 | word_logprob.shape[-1])) 143 | this_word_logprob = torch.gather(this_word_logprob, 2, selected_words.unsqueeze(-1)) 144 | self.log_probs = list( 145 | torch.gather(o, 1, selected_beam.unsqueeze(-1).expand(self.b_s, self.beam_size, 1)) for o in self.log_probs) 146 | self.log_probs.append(this_word_logprob) 147 | self.selected_words = selected_words.view(-1, 1) 148 | 149 | return visual, outputs 150 | -------------------------------------------------------------------------------- /train_tic.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from transformers import GPT2Tokenizer, AdamW, get_linear_schedule_with_warmup 3 | import torch 4 | from tqdm import tqdm 5 | from torch.nn import functional as nnf 6 | import os 7 | import multiprocessing 8 | import itertools 9 | import numpy as np 10 | import random 11 | from torch.optim import Adam 12 | 13 | from models import TICModel, TransformerConfig 14 | from dataset import ClipCocoDataset 15 | from torch.utils.data import Dataset, DataLoader 16 | import evaluation 17 | from evaluation import PTBTokenizer, Cider 18 | 19 | 20 | import warnings 21 | 
warnings.filterwarnings("ignore") 22 | 23 | use_device = torch.cuda.is_available() 24 | device = torch.device('cuda:0' if use_device else 'cpu') 25 | torch.backends.cudnn.benchmark = True 26 | 27 | random.seed(1234) 28 | torch.manual_seed(1234) 29 | np.random.seed(1234) 30 | 31 | 32 | 33 | def evaluate_metrics(model, test_dataloader, tokenizer, epoch): 34 | model.eval() 35 | gen = {} 36 | gts = {} 37 | with tqdm(desc='Epoch %d - evaluation' % epoch, unit='it', total=len(test_dataloader)) as pbar: 38 | for idx, (tokens, _, img_features) in enumerate(test_dataloader): 39 | img_features = img_features.to(device) 40 | with torch.no_grad(): 41 | text, _ = model.beam_search(img_features, beam_size=5, out_size=1) 42 | 43 | caps_gt = tokenizer.batch_decode(tokens) 44 | caps_gen = tokenizer.batch_decode(text) 45 | 46 | for i, (gts_i, gen_i) in enumerate(zip(caps_gt, caps_gen)): 47 | gen_i = ' '.join([k for k, g in itertools.groupby(gen_i)]) 48 | gen['%d_%d' % (idx, i)] = [gen_i, ] 49 | gts['%d_%d' % (idx, i)] = gts_i 50 | pbar.update() 51 | break 52 | gts = evaluation.PTBTokenizer.tokenize(gts) 53 | gen = evaluation.PTBTokenizer.tokenize(gen) 54 | scores, _ = evaluation.compute_all_scores(gts, gen) 55 | print(scores) 56 | return scores 57 | 58 | 59 | 60 | 61 | def train_xe(model, train_dataloader, args, optimizer, scheduler, epoch): 62 | model.train() 63 | running_loss = .0 64 | progress = tqdm(total=len(train_dataloader), desc='TICModel') 65 | for idx, (tokens, _, img_features) in enumerate(train_dataloader): 66 | model.zero_grad() 67 | tokens, img_features = tokens.to(device), img_features.to(device, dtype=torch.float32) 68 | outputs = model(img_features, tokens) 69 | loss = nnf.cross_entropy(outputs.reshape(-1, outputs.shape[-1]), tokens.flatten(), ignore_index=0) 70 | loss.backward() 71 | optimizer.step() 72 | scheduler.step() 73 | optimizer.zero_grad() 74 | running_loss += loss.item() 75 | progress.set_postfix({"loss": running_loss / (idx + 1)}) 76 | progress.update() 77 | break 78 | progress.close() 79 | return running_loss / len(train_dataloader) 80 | 81 | 82 | 83 | 84 | def train_scst(model, train_dataloader, cider_train, args, optimizer, scheduler, epoch, tokenizer): 85 | tokenizer_pool = multiprocessing.Pool() 86 | running_reward = .0 87 | running_reward_baseline = .0 88 | model.train() 89 | seq_len = model.language_decoder.max_len 90 | running_loss = .0 91 | beam_size = 5 92 | with tqdm(desc='Epoch %d - train' % epoch, unit='it', total=len(train_dataloader)) as pbar: 93 | for it, (caps_gt, _, img_features) in enumerate(train_dataloader): 94 | img_features = img_features.to(device) 95 | outs, log_probs, logits = model.beam_search(img_features, beam_size=beam_size, out_size=beam_size, return_logits=True) 96 | optimizer.zero_grad() 97 | 98 | # Rewards 99 | caps_gen = tokenizer.batch_decode(outs.view(-1, seq_len)) 100 | caps_gt = list(itertools.chain(*([c, ] * beam_size for c in caps_gt))) 101 | caps_gt = tokenizer.batch_decode(caps_gt) 102 | 103 | caps_gen, caps_gt = tokenizer_pool.map(evaluation.PTBTokenizer.tokenize, [caps_gen, caps_gt]) 104 | reward = cider_train.compute_score(caps_gt, caps_gen)[1].astype(np.float32) 105 | reward = torch.from_numpy(reward).to(device).view(img_features.shape[0], beam_size) 106 | reward_baseline = torch.mean(reward, -1, keepdim=True) 107 | loss = -torch.mean(log_probs, -1) * (reward - reward_baseline) 108 | 109 | loss = loss.mean() 110 | loss.backward() 111 | optimizer.step() 112 | scheduler.step() 113 | 114 | running_loss += loss.item() 115 | 
running_reward += reward.mean().item() 116 | running_reward_baseline += reward_baseline.mean().item() 117 | pbar.set_postfix(loss=running_loss / (it + 1), reward=running_reward / (it + 1), 118 | reward_baseline=running_reward_baseline / (it + 1)) 119 | pbar.update() 120 | break 121 | 122 | loss = running_loss / len(train_dataloader) 123 | reward = running_reward / len(train_dataloader) 124 | reward_baseline = running_reward_baseline / len(train_dataloader) 125 | return loss, reward, reward_baseline 126 | 127 | 128 | 129 | 130 | def main(): 131 | parser = argparse.ArgumentParser() 132 | parser.add_argument('--train_data_path', default='./data/train.pkl') 133 | parser.add_argument('--test_data_path', default='./data/test.pkl') 134 | parser.add_argument('--tokenizer_path', default='./ckpt/gpt2') 135 | parser.add_argument('--batch_size', default=5) 136 | parser.add_argument('--lr', default=1e-2) 137 | parser.add_argument('--epochs', default=10) 138 | parser.add_argument('--warmup_steps', default=5000) 139 | parser.add_argument('--out_dir', default='./ckpt') 140 | parser.add_argument('--model_type', default='tic') 141 | parser.add_argument('--phase', type=str, default='xs', choices=('xe', 'scst')) 142 | args = parser.parse_args() 143 | 144 | tokenizer = GPT2Tokenizer.from_pretrained(args.tokenizer_path) 145 | train_dataset = ClipCocoDataset(args.train_data_path, tokenizer) 146 | test_dataset = ClipCocoDataset(args.test_data_path, tokenizer) 147 | 148 | train_dataloader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True, drop_last=True) 149 | test_dataloader = DataLoader(test_dataset, batch_size=args.batch_size, shuffle=False, drop_last=False) 150 | ref_caps_train = list(tokenizer.decode(text) for text in test_dataset.captions_tokens) 151 | cider_train = Cider(PTBTokenizer.tokenize(ref_caps_train)) 152 | 153 | config = TransformerConfig() 154 | model = TICModel(config).to(device) 155 | 156 | optimizer = AdamW(model.parameters(), lr=args.lr) 157 | scheduler = get_linear_schedule_with_warmup( 158 | optimizer, num_warmup_steps=args.warmup_steps, num_training_steps=args.epochs * len(train_dataloader) 159 | ) 160 | 161 | use_rl = False 162 | best_cider = .0 163 | patience = 0 164 | 165 | 166 | for epoch in range(args.epochs): 167 | if not use_rl: 168 | train_loss = train_xe(model, train_dataloader, args, optimizer, scheduler, epoch) 169 | else: 170 | train_loss, reward, reward_baseline = train_scst(model, train_dataloader, cider_train, args, optimizer, scheduler, epoch, tokenizer) 171 | 172 | scores = evaluate_metrics(model, test_dataloader, tokenizer, epoch) 173 | val_cider = scores['CIDEr'] 174 | 175 | best = False 176 | if val_cider >= best_cider: 177 | best_cider = val_cider 178 | patience = 0 179 | best = True 180 | else: 181 | patience += 1 182 | 183 | switch_to_rl = False 184 | exit_train = False 185 | if patience == 5: 186 | if not use_rl: 187 | use_rl = True 188 | switch_to_rl = True 189 | patience = 0 190 | optim = Adam(model.parameters(), lr=5e-6) 191 | print("Switching to RL") 192 | else: 193 | print('patience reached.') 194 | exit_train = True 195 | 196 | 197 | torch.save( 198 | model.state_dict(), 199 | os.path.join(args.out_dir, f"{args.model_type}-{epoch:02d}.pt") 200 | ) 201 | break 202 | 203 | 204 | 205 | if __name__ == '__main__': 206 | main() 207 | -------------------------------------------------------------------------------- /models/transformer/attention.py: -------------------------------------------------------------------------------- 1 | import numpy as np 
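# The attention modules in this file all compute softmax(Q @ K^T / sqrt(d_k)) @ V
# on a per-head basis. As a hypothetical shape walk-through with b_s=2, nq=10,
# nk=15, h=8, d_k=d_v=64, d_model=512: queries (2, 10, 512) are projected and
# reshaped to (2, 8, 10, 64), keys to (2, 8, 64, 15), so the attention matrix is
# (2, 8, 10, 15) and the re-assembled output returns to (2, 10, 512).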
2 | import torch 3 | from torch import nn 4 | from models.containers import Module 5 | 6 | 7 | class ScaledDotProductAttention(nn.Module): 8 | """ 9 | Scaled dot-product attention 10 | """ 11 | 12 | def __init__(self, d_model, d_k, d_v, h): 13 | ''' 14 | :param d_model: Output dimensionality of the model 15 | :param d_k: Dimensionality of queries and keys 16 | :param d_v: Dimensionality of values 17 | :param h: Number of heads 18 | ''' 19 | super(ScaledDotProductAttention, self).__init__() 20 | self.fc_q = nn.Linear(d_model, h * d_k) 21 | self.fc_k = nn.Linear(d_model, h * d_k) 22 | self.fc_v = nn.Linear(d_model, h * d_v) 23 | self.fc_o = nn.Linear(h * d_v, d_model) 24 | 25 | self.d_model = d_model 26 | self.d_k = d_k 27 | self.d_v = d_v 28 | self.h = h 29 | 30 | self.init_weights() 31 | 32 | def init_weights(self): 33 | nn.init.xavier_uniform_(self.fc_q.weight) 34 | nn.init.xavier_uniform_(self.fc_k.weight) 35 | nn.init.xavier_uniform_(self.fc_v.weight) 36 | nn.init.xavier_uniform_(self.fc_o.weight) 37 | nn.init.constant_(self.fc_q.bias, 0) 38 | nn.init.constant_(self.fc_k.bias, 0) 39 | nn.init.constant_(self.fc_v.bias, 0) 40 | nn.init.constant_(self.fc_o.bias, 0) 41 | 42 | def forward(self, queries, keys, values, attention_mask=None, attention_weights=None): 43 | """ 44 | Computes 45 | :param queries: Queries (b_s, nq, d_model) 46 | :param keys: Keys (b_s, nk, d_model) 47 | :param values: Values (b_s, nk, d_model) 48 | :param attention_mask: Mask over attention values (b_s, h, nq, nk). True indicates masking. 49 | :param attention_weights: Multiplicative weights for attention values (b_s, h, nq, nk). 50 | :return: 51 | """ 52 | b_s, nq = queries.shape[:2] 53 | nk = keys.shape[1] 54 | q = self.fc_q(queries).view(b_s, nq, self.h, self.d_k).permute(0, 2, 1, 3) # (b_s, h, nq, d_k) 55 | k = self.fc_k(keys).view(b_s, nk, self.h, self.d_k).permute(0, 2, 3, 1) # (b_s, h, d_k, nk) 56 | v = self.fc_v(values).view(b_s, nk, self.h, self.d_v).permute(0, 2, 1, 3) # (b_s, h, nk, d_v) 57 | 58 | att = torch.matmul(q, k) / np.sqrt(self.d_k) # (b_s, h, nq, nk) 59 | if attention_weights is not None: 60 | att = att * attention_weights 61 | if attention_mask is not None: 62 | att = att.masked_fill(attention_mask, -np.inf) 63 | att = torch.softmax(att, -1) 64 | out = torch.matmul(att, v).permute(0, 2, 1, 3).contiguous().view(b_s, nq, self.h * self.d_v) # (b_s, nq, h*d_v) 65 | out = self.fc_o(out) # (b_s, nq, d_model) 66 | return out 67 | 68 | 69 | class ScaledDotProductAttentionMemory(nn.Module): 70 | """ 71 | Scaled dot-product attention with memory 72 | """ 73 | 74 | def __init__(self, d_model, d_k, d_v, h, m): 75 | """ 76 | :param d_model: Output dimensionality of the model 77 | :param d_k: Dimensionality of queries and keys 78 | :param d_v: Dimensionality of values 79 | :param h: Number of heads 80 | :param m: Number of memory slots 81 | """ 82 | super(ScaledDotProductAttentionMemory, self).__init__() 83 | self.fc_q = nn.Linear(d_model, h * d_k) 84 | self.fc_k = nn.Linear(d_model, h * d_k) 85 | self.fc_v = nn.Linear(d_model, h * d_v) 86 | self.fc_o = nn.Linear(h * d_v, d_model) 87 | self.d_model = d_model 88 | self.d_k = d_k 89 | self.d_v = d_v 90 | self.h = h 91 | self.m = m 92 | 93 | if self.m > 0: 94 | self.m_k = nn.Parameter(torch.FloatTensor(1, m, h * d_k)) 95 | self.m_v = nn.Parameter(torch.FloatTensor(1, m, h * d_v)) 96 | 97 | self.init_weights() 98 | 99 | def init_weights(self): 100 | nn.init.xavier_uniform_(self.fc_q.weight) 101 | nn.init.xavier_uniform_(self.fc_k.weight) 102 | 
nn.init.xavier_uniform_(self.fc_v.weight) 103 | nn.init.xavier_uniform_(self.fc_o.weight) 104 | nn.init.constant_(self.fc_q.bias, 0) 105 | nn.init.constant_(self.fc_k.bias, 0) 106 | nn.init.constant_(self.fc_v.bias, 0) 107 | nn.init.constant_(self.fc_o.bias, 0) 108 | 109 | if self.m > 0: 110 | nn.init.normal_(self.m_k, 0, 1 / self.d_k) 111 | nn.init.normal_(self.m_v, 0, 1 / self.m) 112 | 113 | def forward(self, queries, keys, values, attention_mask=None, attention_weights=None): 114 | """ 115 | Computes 116 | :param queries: Queries (b_s, nq, d_model) 117 | :param keys: Keys (b_s, nk, d_model) 118 | :param values: Values (b_s, nk, d_model) 119 | :param attention_mask: Mask over attention values (b_s, h, nq, nk). True indicates masking. 120 | :param attention_weights: Multiplicative weights for attention values (b_s, h, nq, nk). 121 | :return: 122 | """ 123 | b_s, nq = queries.shape[:2] 124 | nk = keys.shape[1] 125 | 126 | q = self.fc_q(queries).view(b_s, nq, self.h, self.d_k).permute(0, 2, 1, 3) # (b_s, h, nq, d_k) 127 | 128 | if self.m > 0: 129 | m_k = np.sqrt(self.d_k) * self.m_k.expand(b_s, self.m, self.h * self.d_k) 130 | m_v = np.sqrt(self.m) * self.m_v.expand(b_s, self.m, self.h * self.d_v) 131 | k = torch.cat([self.fc_k(keys), m_k], 1) 132 | v = torch.cat([self.fc_v(values), m_v], 1) 133 | else: 134 | k = self.fc_k(keys) 135 | v = self.fc_v(values) 136 | 137 | k = k.view(b_s, nk + self.m, self.h, self.d_k).permute(0, 2, 3, 1) # (b_s, h, d_k, nk) 138 | v = v.view(b_s, nk + self.m, self.h, self.d_v).permute(0, 2, 1, 3) # (b_s, h, nk, d_v) 139 | 140 | att = torch.matmul(q, k) / np.sqrt(self.d_k) # (b_s, h, nq, nk) 141 | if attention_weights is not None: 142 | att = torch.cat([att[:, :, :, :nk] * attention_weights, att[:, :, :, nk:]], -1) 143 | if attention_mask is not None: 144 | att[:, :, :, :nk] = att[:, :, :, :nk].masked_fill(attention_mask, -np.inf) 145 | att = torch.softmax(att, -1) 146 | out = torch.matmul(att, v).permute(0, 2, 1, 3).contiguous().view(b_s, nq, self.h * self.d_v) # (b_s, nq, h*d_v) 147 | out = self.fc_o(out) # (b_s, nq, d_model) 148 | return out 149 | 150 | 151 | class MultiHeadAttention(Module): 152 | """ 153 | Multi-head attention layer with Dropout and Layer Normalization. 
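    When can_be_stateful is True and the module runs inside the model's
    statefulness context, keys and values from previous decoding steps are cached
    in running_keys / running_values, so autoregressive decoding can feed one
    token at a time. identity_map_reordering switches to a pre-norm variant:
    LayerNorm is applied to the inputs and the residual adds the (dropped-out)
    ReLU of the attention output. A minimal usage sketch (hypothetical sizes):

        mha = MultiHeadAttention(d_model=512, d_k=64, d_v=64, h=8)
        x = torch.randn(2, 10, 512)
        out = mha(x, x, x)   # self-attention; out has shape (2, 10, 512)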
154 | """ 155 | 156 | def __init__(self, d_model, d_k, d_v, h, dropout=.1, identity_map_reordering=False, can_be_stateful=False, 157 | attention_module=None, attention_module_kwargs=None): 158 | super(MultiHeadAttention, self).__init__() 159 | self.identity_map_reordering = identity_map_reordering 160 | if attention_module is not None: 161 | if attention_module_kwargs is not None: 162 | self.attention = attention_module(d_model=d_model, d_k=d_k, d_v=d_v, h=h, **attention_module_kwargs) 163 | else: 164 | self.attention = attention_module(d_model=d_model, d_k=d_k, d_v=d_v, h=h) 165 | else: 166 | self.attention = ScaledDotProductAttention(d_model=d_model, d_k=d_k, d_v=d_v, h=h) 167 | self.dropout = nn.Dropout(p=dropout) 168 | self.layer_norm = nn.LayerNorm(d_model) 169 | 170 | self.can_be_stateful = can_be_stateful 171 | if self.can_be_stateful: 172 | self.register_state('running_keys', None) 173 | self.register_state('running_values', None) 174 | 175 | def forward(self, queries, keys, values, attention_mask=None, attention_weights=None): 176 | if self.can_be_stateful and self._is_stateful: 177 | if self.running_keys is None: 178 | self.running_keys = keys 179 | self.running_values = values 180 | else: 181 | self.running_keys = torch.cat([self.running_keys, keys], 1) 182 | self.running_values = torch.cat([self.running_values, values], 1) 183 | keys = self.running_keys 184 | values = self.running_values 185 | 186 | if self.identity_map_reordering: 187 | q_norm = self.layer_norm(queries) 188 | k_norm = self.layer_norm(keys) 189 | v_norm = self.layer_norm(values) 190 | out = self.attention(q_norm, k_norm, v_norm, attention_mask, attention_weights) 191 | out = queries + self.dropout(torch.relu(out)) 192 | else: 193 | out = self.attention(queries, keys, values, attention_mask, attention_weights) 194 | out = self.dropout(out) 195 | out = self.layer_norm(queries + out) 196 | return out 197 | -------------------------------------------------------------------------------- /evaluation/bleu/bleu_scorer.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | # bleu_scorer.py 4 | # David Chiang 5 | 6 | # Copyright (c) 2004-2006 University of Maryland. All rights 7 | # reserved. Do not redistribute without permission from the 8 | # author. Not for commercial use. 9 | 10 | # Modified by: 11 | # Hao Fang 12 | # Tsung-Yi Lin 13 | 14 | ''' Provides: 15 | cook_refs(refs, n=4): Transform a list of reference sentences as strings into a form usable by cook_test(). 16 | cook_test(test, refs, n=4): Transform a test sentence as a string (together with the cooked reference sentences) into a form usable by score_cooked(). 17 | ''' 18 | 19 | import copy 20 | import sys, math, re 21 | from collections import defaultdict 22 | 23 | 24 | def precook(s, n=4, out=False): 25 | """Takes a string as input and returns an object that can be given to 26 | either cook_refs or cook_test. 
This is optional: cook_refs and cook_test 27 | can take string arguments as well.""" 28 | words = s.split() 29 | counts = defaultdict(int) 30 | for k in range(1, n + 1): 31 | for i in range(len(words) - k + 1): 32 | ngram = tuple(words[i:i + k]) 33 | counts[ngram] += 1 34 | return (len(words), counts) 35 | 36 | 37 | def cook_refs(refs, eff=None, n=4): ## lhuang: oracle will call with "average" 38 | '''Takes a list of reference sentences for a single segment 39 | and returns an object that encapsulates everything that BLEU 40 | needs to know about them.''' 41 | 42 | reflen = [] 43 | maxcounts = {} 44 | for ref in refs: 45 | rl, counts = precook(ref, n) 46 | reflen.append(rl) 47 | for (ngram, count) in counts.items(): 48 | maxcounts[ngram] = max(maxcounts.get(ngram, 0), count) 49 | 50 | # Calculate effective reference sentence length. 51 | if eff == "shortest": 52 | reflen = min(reflen) 53 | elif eff == "average": 54 | reflen = float(sum(reflen)) / len(reflen) 55 | 56 | ## lhuang: N.B.: leave reflen computaiton to the very end!! 57 | 58 | ## lhuang: N.B.: in case of "closest", keep a list of reflens!! (bad design) 59 | 60 | return (reflen, maxcounts) 61 | 62 | 63 | def cook_test(test, ref_tuple, eff=None, n=4): 64 | '''Takes a test sentence and returns an object that 65 | encapsulates everything that BLEU needs to know about it.''' 66 | 67 | testlen, counts = precook(test, n, True) 68 | reflen, refmaxcounts = ref_tuple 69 | 70 | result = {} 71 | 72 | # Calculate effective reference sentence length. 73 | 74 | if eff == "closest": 75 | result["reflen"] = min((abs(l - testlen), l) for l in reflen)[1] 76 | else: ## i.e., "average" or "shortest" or None 77 | result["reflen"] = reflen 78 | 79 | result["testlen"] = testlen 80 | 81 | result["guess"] = [max(0, testlen - k + 1) for k in range(1, n + 1)] 82 | 83 | result['correct'] = [0] * n 84 | for (ngram, count) in counts.items(): 85 | result["correct"][len(ngram) - 1] += min(refmaxcounts.get(ngram, 0), count) 86 | 87 | return result 88 | 89 | 90 | class BleuScorer(object): 91 | """Bleu scorer. 92 | """ 93 | 94 | __slots__ = "n", "crefs", "ctest", "_score", "_ratio", "_testlen", "_reflen", "special_reflen" 95 | 96 | # special_reflen is used in oracle (proportional effective ref len for a node). 
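    # For each hypothesis the scorer keeps, per n-gram order k, the number of
    # candidate n-grams ("guess") and how many of them are clipped matches against
    # the references ("correct"); corpus-level BLEU-k is then the geometric mean of
    # the order-1..k precisions times a brevity penalty. With hypothetical lengths
    # testlen=9 and reflen=10, ratio=0.9 < 1, so every BLEU-k is multiplied by
    # exp(1 - 1/0.9) ~ 0.895 in compute_score().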
97 | 98 | def copy(self): 99 | ''' copy the refs.''' 100 | new = BleuScorer(n=self.n) 101 | new.ctest = copy.copy(self.ctest) 102 | new.crefs = copy.copy(self.crefs) 103 | new._score = None 104 | return new 105 | 106 | def __init__(self, test=None, refs=None, n=4, special_reflen=None): 107 | ''' singular instance ''' 108 | 109 | self.n = n 110 | self.crefs = [] 111 | self.ctest = [] 112 | self.cook_append(test, refs) 113 | self.special_reflen = special_reflen 114 | 115 | def cook_append(self, test, refs): 116 | '''called by constructor and __iadd__ to avoid creating new instances.''' 117 | 118 | if refs is not None: 119 | self.crefs.append(cook_refs(refs)) 120 | if test is not None: 121 | cooked_test = cook_test(test, self.crefs[-1]) 122 | self.ctest.append(cooked_test) ## N.B.: -1 123 | else: 124 | self.ctest.append(None) # lens of crefs and ctest have to match 125 | 126 | self._score = None ## need to recompute 127 | 128 | def ratio(self, option=None): 129 | self.compute_score(option=option) 130 | return self._ratio 131 | 132 | def score_ratio(self, option=None): 133 | ''' 134 | return (bleu, len_ratio) pair 135 | ''' 136 | 137 | return self.fscore(option=option), self.ratio(option=option) 138 | 139 | def score_ratio_str(self, option=None): 140 | return "%.4f (%.2f)" % self.score_ratio(option) 141 | 142 | def reflen(self, option=None): 143 | self.compute_score(option=option) 144 | return self._reflen 145 | 146 | def testlen(self, option=None): 147 | self.compute_score(option=option) 148 | return self._testlen 149 | 150 | def retest(self, new_test): 151 | if type(new_test) is str: 152 | new_test = [new_test] 153 | assert len(new_test) == len(self.crefs), new_test 154 | self.ctest = [] 155 | for t, rs in zip(new_test, self.crefs): 156 | self.ctest.append(cook_test(t, rs)) 157 | self._score = None 158 | 159 | return self 160 | 161 | def rescore(self, new_test): 162 | ''' replace test(s) with new test(s), and returns the new score.''' 163 | 164 | return self.retest(new_test).compute_score() 165 | 166 | def size(self): 167 | assert len(self.crefs) == len(self.ctest), "refs/test mismatch! %d<>%d" % (len(self.crefs), len(self.ctest)) 168 | return len(self.crefs) 169 | 170 | def __iadd__(self, other): 171 | '''add an instance (e.g., from another sentence).''' 172 | 173 | if type(other) is tuple: 174 | ## avoid creating new BleuScorer instances 175 | self.cook_append(other[0], other[1]) 176 | else: 177 | assert self.compatible(other), "incompatible BLEUs." 
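            # Typical usage is scorer += (hypothesis, references) once per image: a
            # tuple is cooked and appended in the branch above, while merging another
            # BleuScorer, as here, simply concatenates its cooked statistics. Either
            # way the cached score is invalidated and recomputed lazily by the next
            # compute_score() call.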
178 | self.ctest.extend(other.ctest) 179 | self.crefs.extend(other.crefs) 180 | self._score = None ## need to recompute 181 | 182 | return self 183 | 184 | def compatible(self, other): 185 | return isinstance(other, BleuScorer) and self.n == other.n 186 | 187 | def single_reflen(self, option="average"): 188 | return self._single_reflen(self.crefs[0][0], option) 189 | 190 | def _single_reflen(self, reflens, option=None, testlen=None): 191 | 192 | if option == "shortest": 193 | reflen = min(reflens) 194 | elif option == "average": 195 | reflen = float(sum(reflens)) / len(reflens) 196 | elif option == "closest": 197 | reflen = min((abs(l - testlen), l) for l in reflens)[1] 198 | else: 199 | assert False, "unsupported reflen option %s" % option 200 | 201 | return reflen 202 | 203 | def recompute_score(self, option=None, verbose=0): 204 | self._score = None 205 | return self.compute_score(option, verbose) 206 | 207 | def compute_score(self, option=None, verbose=0): 208 | n = self.n 209 | small = 1e-9 210 | tiny = 1e-15 ## so that if guess is 0 still return 0 211 | bleu_list = [[] for _ in range(n)] 212 | 213 | if self._score is not None: 214 | return self._score 215 | 216 | if option is None: 217 | option = "average" if len(self.crefs) == 1 else "closest" 218 | 219 | self._testlen = 0 220 | self._reflen = 0 221 | totalcomps = {'testlen': 0, 'reflen': 0, 'guess': [0] * n, 'correct': [0] * n} 222 | 223 | # for each sentence 224 | for comps in self.ctest: 225 | testlen = comps['testlen'] 226 | self._testlen += testlen 227 | 228 | if self.special_reflen is None: ## need computation 229 | reflen = self._single_reflen(comps['reflen'], option, testlen) 230 | else: 231 | reflen = self.special_reflen 232 | 233 | self._reflen += reflen 234 | 235 | for key in ['guess', 'correct']: 236 | for k in range(n): 237 | totalcomps[key][k] += comps[key][k] 238 | 239 | # append per image bleu score 240 | bleu = 1. 241 | for k in range(n): 242 | bleu *= (float(comps['correct'][k]) + tiny) \ 243 | / (float(comps['guess'][k]) + small) 244 | bleu_list[k].append(bleu ** (1. / (k + 1))) 245 | ratio = (testlen + tiny) / (reflen + small) ## N.B.: avoid zero division 246 | if ratio < 1: 247 | for k in range(n): 248 | bleu_list[k][-1] *= math.exp(1 - 1 / ratio) 249 | 250 | if verbose > 1: 251 | print(comps, reflen) 252 | 253 | totalcomps['reflen'] = self._reflen 254 | totalcomps['testlen'] = self._testlen 255 | 256 | bleus = [] 257 | bleu = 1. 258 | for k in range(n): 259 | bleu *= float(totalcomps['correct'][k] + tiny) \ 260 | / (totalcomps['guess'][k] + small) 261 | bleus.append(bleu ** (1. 
/ (k + 1))) 262 | ratio = (self._testlen + tiny) / (self._reflen + small) ## N.B.: avoid zero division 263 | if ratio < 1: 264 | for k in range(n): 265 | bleus[k] *= math.exp(1 - 1 / ratio) 266 | 267 | if verbose > 0: 268 | print(totalcomps) 269 | print("ratio:", ratio) 270 | 271 | self._score = bleus 272 | return self._score, bleu_list 273 | -------------------------------------------------------------------------------- /models/transformer/decoders.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | import numpy as np 4 | 5 | from models.transformer.attention import MultiHeadAttention 6 | from models.transformer.utils import sinusoid_encoding_table, PositionWiseFeedForward 7 | from models.containers import Module, ModuleList 8 | from models.utils import one_hot_to_index 9 | 10 | 11 | class MeshedDecoderLayer(Module): 12 | def __init__(self, N_enc, d_model=512, d_k=64, d_v=64, h=8, d_ff=2048, dropout=.1, self_att_module=None, 13 | enc_att_module=None, self_att_module_kwargs=None, enc_att_module_kwargs=None): 14 | super(MeshedDecoderLayer, self).__init__() 15 | self.N_enc = N_enc 16 | self.self_att = MultiHeadAttention(d_model, d_k, d_v, h, dropout, can_be_stateful=True, 17 | attention_module=self_att_module, 18 | attention_module_kwargs=self_att_module_kwargs) 19 | self.enc_att = MultiHeadAttention(d_model, d_k, d_v, h, dropout, can_be_stateful=False, 20 | attention_module=enc_att_module, 21 | attention_module_kwargs=enc_att_module_kwargs) 22 | self.pwff = PositionWiseFeedForward(d_model, d_ff, dropout) 23 | 24 | self.fc_alpha = ModuleList([nn.Linear(d_model + d_model, d_model) for _ in range(N_enc)]) 25 | 26 | self.init_weights() 27 | 28 | def init_weights(self): 29 | for fc_alpha in self.fc_alpha: 30 | nn.init.xavier_uniform_(fc_alpha.weight) 31 | nn.init.constant_(fc_alpha.bias, 0) 32 | 33 | def forward(self, input, enc_output, mask_pad, mask_self_att, mask_enc_att): 34 | self_att = self.self_att(input, input, input, mask_self_att) 35 | self_att = self_att * mask_pad 36 | 37 | enc_att = None 38 | for i in range(self.N_enc): 39 | enc_att_i = self.enc_att(self_att, enc_output[:, i], enc_output[:, i], mask_enc_att) * mask_pad 40 | alpha_i = torch.sigmoid(self.fc_alpha[i](torch.cat([self_att, enc_att_i], -1))) 41 | if enc_att is None: 42 | enc_att = enc_att_i * alpha_i 43 | else: 44 | enc_att += enc_att_i * alpha_i 45 | 46 | enc_att /= np.sqrt(self.N_enc) 47 | enc_att *= mask_pad 48 | 49 | ff = self.pwff(enc_att) 50 | ff = ff * mask_pad 51 | return ff 52 | 53 | 54 | class DecoderLayer(Module): 55 | def __init__(self, d_model=512, d_k=64, d_v=64, h=8, d_ff=2048, dropout=.1, self_att_module=None, 56 | enc_att_module=None, self_att_module_kwargs=None, enc_att_module_kwargs=None): 57 | super(DecoderLayer, self).__init__() 58 | self.self_att = MultiHeadAttention(d_model, d_k, d_v, h, dropout, can_be_stateful=True, 59 | attention_module=self_att_module, 60 | attention_module_kwargs=self_att_module_kwargs) 61 | self.enc_att = MultiHeadAttention(d_model, d_k, d_v, h, dropout, can_be_stateful=False, 62 | attention_module=enc_att_module, 63 | attention_module_kwargs=enc_att_module_kwargs) 64 | self.pwff = PositionWiseFeedForward(d_model, d_ff, dropout) 65 | 66 | def forward(self, input, enc_output, mask_pad, mask_self_att, mask_enc_att): 67 | self_att = self.self_att(input, input, input, mask_self_att) 68 | enc_att = self.enc_att(self_att, enc_output, enc_output, mask_enc_att) 69 | ff = self.pwff(enc_att) 70 | ff = ff * 
mask_pad 71 | 72 | return ff 73 | 74 | 75 | class MeshedDecoder(Module): 76 | def __init__(self, vocab_size, max_len, N_dec, N_enc, padding_idx=0, d_model=512, d_k=64, d_v=64, h=8, d_ff=2048, 77 | dropout=.1, 78 | self_att_module=None, enc_att_module=None, self_att_module_kwargs=None, enc_att_module_kwargs=None): 79 | super(MeshedDecoder, self).__init__() 80 | self.d_model = d_model 81 | self.vocab_size = vocab_size 82 | self.word_emb = nn.Embedding(vocab_size, d_model, padding_idx=padding_idx) 83 | self.pos_emb = nn.Embedding.from_pretrained(sinusoid_encoding_table(max_len + 1, d_model, 0), freeze=True) 84 | self.layers = ModuleList( 85 | [MeshedDecoderLayer(N_enc, d_model, d_k, d_v, h, d_ff, dropout, self_att_module=self_att_module, 86 | enc_att_module=enc_att_module, self_att_module_kwargs=self_att_module_kwargs, 87 | enc_att_module_kwargs=enc_att_module_kwargs) for _ in range(N_dec)]) 88 | self.fc = nn.Linear(d_model, vocab_size, bias=False) 89 | self.max_len = max_len 90 | self.padding_idx = padding_idx 91 | self.N = N_dec 92 | 93 | self.register_state('running_mask_self_attention', None) 94 | self.register_state('running_seq', torch.zeros((1,)).long()) 95 | 96 | def forward(self, input, encoder_output_list, mask_encoder): 97 | # input (b_s, seq_len) 98 | input = input[:, :self.max_len] 99 | b_s, seq_len = input.shape[:2] 100 | 101 | if input.dtype in [torch.long, torch.int]: 102 | input_index = input 103 | else: 104 | input_index = one_hot_to_index(input) 105 | 106 | mask_queries = (input_index != self.padding_idx).unsqueeze(-1).type(input.dtype) # (b_s, seq_len, 1) 107 | mask_self_attention = torch.triu(torch.ones((seq_len, seq_len), dtype=torch.bool, device=input.device), 108 | diagonal=1) 109 | mask_self_attention = mask_self_attention.unsqueeze(0).unsqueeze(0) # (1, 1, seq_len, seq_len) 110 | mask_self_attention = mask_self_attention + (input_index == self.padding_idx).unsqueeze(1).unsqueeze(1).bool() 111 | mask_self_attention = mask_self_attention.gt(0) # (b_s, 1, seq_len, seq_len) 112 | if self._is_stateful: 113 | if self.running_mask_self_attention is None: 114 | self.running_mask_self_attention = mask_self_attention 115 | else: 116 | self.running_mask_self_attention = torch.cat([self.running_mask_self_attention, mask_self_attention], 117 | -1) 118 | mask_self_attention = self.running_mask_self_attention 119 | 120 | seq = torch.arange(1, seq_len + 1).view(1, -1).expand(b_s, -1).to(input.device) # (b_s, seq_len) 121 | seq = seq.masked_fill(mask_queries.squeeze(-1) == 0, 0) 122 | if self._is_stateful: 123 | self.running_seq.add_(1) 124 | seq = self.running_seq 125 | 126 | if input.dtype in [torch.long, torch.int]: 127 | out = self.word_emb(input) 128 | else: 129 | out = input @ self.word_emb.weight 130 | 131 | out = out + self.pos_emb(seq) 132 | for i, l in enumerate(self.layers): 133 | out = l(out, encoder_output_list, mask_queries, mask_self_attention, mask_encoder) 134 | 135 | out = self.fc(out) 136 | return out 137 | 138 | 139 | class Decoder(Module): 140 | def __init__(self, vocab_size=50257, max_len=40, N_dec=6, padding_idx=0, d_model=512, d_k=64, d_v=64, h=8, d_ff=2048, 141 | dropout=.1, self_att_module=None, enc_att_module=None, self_att_module_kwargs=None, 142 | enc_att_module_kwargs=None): 143 | super(Decoder, self).__init__() 144 | self.d_model = d_model 145 | self.vocab_size = vocab_size 146 | self.word_emb = nn.Embedding(vocab_size, d_model, padding_idx=padding_idx) 147 | self.pos_emb = nn.Embedding.from_pretrained(sinusoid_encoding_table(max_len + 1, d_model, 0), 
freeze=True) 148 | self.layers = ModuleList( 149 | [DecoderLayer(d_model, d_k, d_v, h, d_ff, dropout, self_att_module=self_att_module, 150 | enc_att_module=enc_att_module, self_att_module_kwargs=self_att_module_kwargs, 151 | enc_att_module_kwargs=enc_att_module_kwargs) for _ in range(N_dec)]) 152 | self.fc = nn.Linear(d_model, vocab_size, bias=False) 153 | self.max_len = max_len 154 | self.padding_idx = padding_idx 155 | self.N = N_dec 156 | 157 | self.register_state('running_mask_self_attention', None) 158 | self.register_state('running_seq', torch.zeros((1,)).long()) 159 | 160 | def forward(self, input, encoder_output, mask_encoder): 161 | # input (b_s, seq_len) 162 | input = input[:, :self.max_len] 163 | b_s, seq_len = input.shape[:2] 164 | 165 | if input.dtype in [torch.long, torch.int]: 166 | input_index = input 167 | else: 168 | input_index = one_hot_to_index(input) 169 | 170 | mask_queries = (input_index != self.padding_idx).unsqueeze(-1).type(input.dtype) 171 | # (b_s, seq_len, 1) 172 | mask_self_attention = torch.triu(torch.ones((seq_len, seq_len), dtype=torch.bool, device=input.device), 173 | diagonal=1) 174 | # print(mask_self_attention) (seq_len, seq_len) 175 | mask_self_attention = mask_self_attention.unsqueeze(0).unsqueeze(0) # (1, 1, seq_len, seq_len) 176 | mask_self_attention = mask_self_attention + (input_index == self.padding_idx).unsqueeze(1).unsqueeze(1).bool() 177 | mask_self_attention = mask_self_attention.gt(0) # (b_s, 1, seq_len, seq_len) 178 | if self._is_stateful: 179 | if self.running_mask_self_attention is None: 180 | self.running_mask_self_attention = mask_self_attention 181 | else: 182 | self.running_mask_self_attention = torch.cat([self.running_mask_self_attention, mask_self_attention], 183 | -1) 184 | mask_self_attention = self.running_mask_self_attention 185 | 186 | seq = torch.arange(1, seq_len + 1).view(1, -1).expand(b_s, -1).to(input.device) # (b_s, seq_len) 187 | seq = seq.masked_fill(mask_queries.squeeze(-1) == 0, 0) 188 | if self._is_stateful: 189 | self.running_seq.add_(1) 190 | seq = self.running_seq 191 | 192 | if input.dtype in [torch.long, torch.int]: 193 | out = self.word_emb(input) 194 | else: 195 | out = input @ self.word_emb.weight 196 | 197 | out = out + self.pos_emb(seq) 198 | for i, l in enumerate(self.layers): 199 | out = l(out, encoder_output, mask_queries, mask_self_attention, mask_encoder) 200 | 201 | out = self.fc(out) 202 | return out 203 | 204 | -------------------------------------------------------------------------------- /data/dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import itertools 4 | import collections 5 | import torch 6 | from .example import Example 7 | from .utils import nostdout 8 | from pycocotools.coco import COCO as pyCOCO 9 | 10 | 11 | class Dataset(object): 12 | def __init__(self, examples, fields): 13 | self.examples = examples 14 | self.fields = dict(fields) 15 | 16 | def collate_fn(self): 17 | def collate(batch): 18 | if len(self.fields) == 1: 19 | batch = [batch, ] 20 | else: 21 | batch = list(zip(*batch)) 22 | 23 | tensors = [] 24 | for field, data in zip(self.fields.values(), batch): 25 | tensor = field.process(data) 26 | if isinstance(tensor, collections.Sequence) and any(isinstance(t, torch.Tensor) for t in tensor): 27 | tensors.extend(tensor) 28 | else: 29 | tensors.append(tensor) 30 | 31 | if len(tensors) > 1: 32 | return tensors 33 | else: 34 | return tensors[0] 35 | 36 | return collate 37 | 38 | def 
__getitem__(self, i): 39 | example = self.examples[i] 40 | data = [] 41 | for field_name, field in self.fields.items(): 42 | data.append(field.preprocess(getattr(example, field_name))) 43 | 44 | if len(data) == 1: 45 | data = data[0] 46 | return data 47 | 48 | def __len__(self): 49 | return len(self.examples) 50 | 51 | def __getattr__(self, attr): 52 | if attr in self.fields: 53 | for x in self.examples: 54 | yield getattr(x, attr) 55 | 56 | 57 | class ValueDataset(Dataset): 58 | def __init__(self, examples, fields, dictionary): 59 | self.dictionary = dictionary 60 | super(ValueDataset, self).__init__(examples, fields) 61 | 62 | def collate_fn(self): 63 | def collate(batch): 64 | value_batch_flattened = list(itertools.chain(*batch)) 65 | value_tensors_flattened = super(ValueDataset, self).collate_fn()(value_batch_flattened) 66 | 67 | lengths = [0, ] + list(itertools.accumulate([len(x) for x in batch])) 68 | if isinstance(value_tensors_flattened, collections.Sequence) \ 69 | and any(isinstance(t, torch.Tensor) for t in value_tensors_flattened): 70 | value_tensors = [[vt[s:e] for (s, e) in zip(lengths[:-1], lengths[1:])] for vt in value_tensors_flattened] 71 | else: 72 | value_tensors = [value_tensors_flattened[s:e] for (s, e) in zip(lengths[:-1], lengths[1:])] 73 | 74 | return value_tensors 75 | return collate 76 | 77 | def __getitem__(self, i): 78 | if i not in self.dictionary: 79 | raise IndexError 80 | 81 | values_data = [] 82 | for idx in self.dictionary[i]: 83 | value_data = super(ValueDataset, self).__getitem__(idx) 84 | values_data.append(value_data) 85 | return values_data 86 | 87 | def __len__(self): 88 | return len(self.dictionary) 89 | 90 | 91 | class DictionaryDataset(Dataset): 92 | def __init__(self, examples, fields, key_fields): 93 | if not isinstance(key_fields, (tuple, list)): 94 | key_fields = (key_fields,) 95 | for field in key_fields: 96 | assert (field in fields) 97 | 98 | dictionary = collections.defaultdict(list) 99 | key_fields = {k: fields[k] for k in key_fields} 100 | value_fields = {k: fields[k] for k in fields.keys() if k not in key_fields} 101 | key_examples = [] 102 | key_dict = dict() 103 | value_examples = [] 104 | 105 | for i, e in enumerate(examples): 106 | key_example = Example.fromdict({k: getattr(e, k) for k in key_fields}) 107 | value_example = Example.fromdict({v: getattr(e, v) for v in value_fields}) 108 | if key_example not in key_dict: 109 | key_dict[key_example] = len(key_examples) 110 | key_examples.append(key_example) 111 | 112 | value_examples.append(value_example) 113 | dictionary[key_dict[key_example]].append(i) 114 | 115 | self.key_dataset = Dataset(key_examples, key_fields) 116 | self.value_dataset = ValueDataset(value_examples, value_fields, dictionary) 117 | super(DictionaryDataset, self).__init__(examples, fields) 118 | 119 | def collate_fn(self): 120 | def collate(batch): 121 | key_batch, value_batch = list(zip(*batch)) 122 | key_tensors = self.key_dataset.collate_fn()(key_batch) 123 | value_tensors = self.value_dataset.collate_fn()(value_batch) 124 | return key_tensors, value_tensors 125 | return collate 126 | 127 | def __getitem__(self, i): 128 | return self.key_dataset[i], self.value_dataset[i] 129 | 130 | def __len__(self): 131 | return len(self.key_dataset) 132 | 133 | 134 | def unique(sequence): 135 | seen = set() 136 | if isinstance(sequence[0], list): 137 | return [x for x in sequence if not (tuple(x) in seen or seen.add(tuple(x)))] 138 | else: 139 | return [x for x in sequence if not (x in seen or seen.add(x))] 140 | 141 | 142 | 
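# PairedDataset (below) couples an 'image' field with a 'text' field;
# image_set()/text_set() de-duplicate one side of the pair, while
# image_dictionary()/text_dictionary() use DictionaryDataset to group all
# examples that share the same key (e.g. every caption of a single image).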
class PairedDataset(Dataset): 143 | def __init__(self, examples, fields): 144 | assert ('image' in fields) 145 | assert ('text' in fields) 146 | super(PairedDataset, self).__init__(examples, fields) 147 | self.image_field = self.fields['image'] 148 | self.text_field = self.fields['text'] 149 | 150 | def image_set(self): 151 | img_list = [e.image for e in self.examples] 152 | image_set = unique(img_list) 153 | examples = [Example.fromdict({'image': i}) for i in image_set] 154 | dataset = Dataset(examples, {'image': self.image_field}) 155 | return dataset 156 | 157 | def text_set(self): 158 | text_list = [e.text for e in self.examples] 159 | text_list = unique(text_list) 160 | examples = [Example.fromdict({'text': t}) for t in text_list] 161 | dataset = Dataset(examples, {'text': self.text_field}) 162 | return dataset 163 | 164 | def image_dictionary(self, fields=None): 165 | if not fields: 166 | fields = self.fields 167 | dataset = DictionaryDataset(self.examples, fields, key_fields='image') 168 | return dataset 169 | 170 | def text_dictionary(self, fields=None): 171 | if not fields: 172 | fields = self.fields 173 | dataset = DictionaryDataset(self.examples, fields, key_fields='text') 174 | return dataset 175 | 176 | @property 177 | def splits(self): 178 | raise NotImplementedError 179 | 180 | 181 | class COCO(PairedDataset): 182 | def __init__(self, image_field, text_field, img_root, ann_root, id_root=None, use_restval=True, 183 | cut_validation=False): 184 | roots = {} 185 | roots['train'] = { 186 | 'img': os.path.join(img_root, 'train2014'), 187 | 'cap': os.path.join(ann_root, 'captions_train2014.json') 188 | } 189 | roots['val'] = { 190 | 'img': os.path.join(img_root, 'val2014'), 191 | 'cap': os.path.join(ann_root, 'captions_val2014.json') 192 | } 193 | roots['test'] = { 194 | 'img': os.path.join(img_root, 'val2014'), 195 | 'cap': os.path.join(ann_root, 'captions_val2014.json') 196 | } 197 | roots['trainrestval'] = { 198 | 'img': (roots['train']['img'], roots['val']['img']), 199 | 'cap': (roots['train']['cap'], roots['val']['cap']) 200 | } 201 | 202 | if id_root is not None: 203 | ids = {} 204 | ids['train'] = np.load(os.path.join(id_root, 'coco_train_ids.npy')) 205 | ids['val'] = np.load(os.path.join(id_root, 'coco_dev_ids.npy')) 206 | if cut_validation: 207 | ids['val'] = ids['val'][:5000] 208 | ids['test'] = np.load(os.path.join(id_root, 'coco_test_ids.npy')) 209 | ids['trainrestval'] = ( 210 | ids['train'], 211 | np.load(os.path.join(id_root, 'coco_restval_ids.npy'))) 212 | 213 | if use_restval: 214 | roots['train'] = roots['trainrestval'] 215 | ids['train'] = ids['trainrestval'] 216 | else: 217 | ids = None 218 | 219 | with nostdout(): 220 | self.train_examples, self.val_examples, self.test_examples = self.get_samples(roots, ids) 221 | examples = self.train_examples + self.val_examples + self.test_examples 222 | super(COCO, self).__init__(examples, {'image': image_field, 'text': text_field}) 223 | 224 | @property 225 | def splits(self): 226 | train_split = PairedDataset(self.train_examples, self.fields) 227 | val_split = PairedDataset(self.val_examples, self.fields) 228 | test_split = PairedDataset(self.test_examples, self.fields) 229 | return train_split, val_split, test_split 230 | 231 | @classmethod 232 | def get_samples(cls, roots, ids_dataset=None): 233 | train_samples = [] 234 | val_samples = [] 235 | test_samples = [] 236 | 237 | for split in ['train', 'val', 'test']: 238 | if isinstance(roots[split]['cap'], tuple): 239 | coco_dataset = (pyCOCO(roots[split]['cap'][0]), 
pyCOCO(roots[split]['cap'][1])) 240 | root = roots[split]['img'] 241 | else: 242 | coco_dataset = (pyCOCO(roots[split]['cap']),) 243 | root = (roots[split]['img'],) 244 | 245 | if ids_dataset is None: 246 | ids = list(coco_dataset.anns.keys()) 247 | else: 248 | ids = ids_dataset[split] 249 | 250 | if isinstance(ids, tuple): 251 | bp = len(ids[0]) 252 | ids = list(ids[0]) + list(ids[1]) 253 | else: 254 | bp = len(ids) 255 | 256 | for index in range(len(ids)): 257 | if index < bp: 258 | coco = coco_dataset[0] 259 | img_root = root[0] 260 | else: 261 | coco = coco_dataset[1] 262 | img_root = root[1] 263 | 264 | ann_id = ids[index] 265 | caption = coco.anns[ann_id]['caption'] 266 | img_id = coco.anns[ann_id]['image_id'] 267 | filename = coco.loadImgs(img_id)[0]['file_name'] 268 | 269 | example = Example.fromdict({'image': os.path.join(img_root, filename), 'text': caption}) 270 | 271 | if split == 'train': 272 | train_samples.append(example) 273 | elif split == 'val': 274 | val_samples.append(example) 275 | elif split == 'test': 276 | test_samples.append(example) 277 | 278 | return train_samples, val_samples, test_samples 279 | 280 | -------------------------------------------------------------------------------- /data/field.py: -------------------------------------------------------------------------------- 1 | # coding: utf8 2 | from collections import Counter, OrderedDict 3 | from torch.utils.data.dataloader import default_collate 4 | from itertools import chain 5 | import six 6 | import torch 7 | import numpy as np 8 | import h5py 9 | import os 10 | import warnings 11 | import shutil 12 | 13 | from .dataset import Dataset 14 | from .vocab import Vocab 15 | from .utils import get_tokenizer 16 | 17 | 18 | class RawField(object): 19 | """ Defines a general datatype. 20 | 21 | Every dataset consists of one or more types of data. For instance, 22 | a machine translation dataset contains paired examples of text, while 23 | an image captioning dataset contains images and texts. 24 | Each of these types of data is represented by a RawField object. 25 | An RawField object does not assume any property of the data type and 26 | it holds parameters relating to how a datatype should be processed. 27 | 28 | Attributes: 29 | preprocessing: The Pipeline that will be applied to examples 30 | using this field before creating an example. 31 | Default: None. 32 | postprocessing: A Pipeline that will be applied to a list of examples 33 | using this field before assigning to a batch. 34 | Function signature: (batch(list)) -> object 35 | Default: None. 36 | """ 37 | 38 | def __init__(self, preprocessing=None, postprocessing=None): 39 | self.preprocessing = preprocessing 40 | self.postprocessing = postprocessing 41 | 42 | def preprocess(self, x): 43 | """ Preprocess an example if the `preprocessing` Pipeline is provided. """ 44 | if self.preprocessing is not None: 45 | return self.preprocessing(x) 46 | else: 47 | return x 48 | 49 | def process(self, batch, *args, **kwargs): 50 | """ Process a list of examples to create a batch. 51 | 52 | Postprocess the batch with user-provided Pipeline. 53 | 54 | Args: 55 | batch (list(object)): A list of object from a batch of examples. 56 | Returns: 57 | object: Processed object given the input and custom 58 | postprocessing Pipeline. 
59 | """ 60 | if self.postprocessing is not None: 61 | batch = self.postprocessing(batch) 62 | return default_collate(batch) 63 | 64 | 65 | class Merge(RawField): 66 | def __init__(self, *fields): 67 | super(Merge, self).__init__() 68 | self.fields = fields 69 | 70 | def preprocess(self, x): 71 | return tuple(f.preprocess(x) for f in self.fields) 72 | 73 | def process(self, batch, *args, **kwargs): 74 | if len(self.fields) == 1: 75 | batch = [batch, ] 76 | else: 77 | batch = list(zip(*batch)) 78 | 79 | out = list(f.process(b, *args, **kwargs) for f, b in zip(self.fields, batch)) 80 | return out 81 | 82 | 83 | class ImageDetectionsField(RawField): 84 | def __init__(self, preprocessing=None, postprocessing=None, detections_path=None, max_detections=100, 85 | sort_by_prob=False, load_in_tmp=True): 86 | self.max_detections = max_detections 87 | self.detections_path = detections_path 88 | self.sort_by_prob = sort_by_prob 89 | 90 | tmp_detections_path = os.path.join('/tmp', os.path.basename(detections_path)) 91 | 92 | if load_in_tmp: 93 | if not os.path.isfile(tmp_detections_path): 94 | if shutil.disk_usage("/tmp")[-1] < os.path.getsize(detections_path): 95 | warnings.warn('Loading from %s, because /tmp has no enough space.' % detections_path) 96 | else: 97 | warnings.warn("Copying detection file to /tmp") 98 | shutil.copyfile(detections_path, tmp_detections_path) 99 | warnings.warn("Done.") 100 | self.detections_path = tmp_detections_path 101 | else: 102 | self.detections_path = tmp_detections_path 103 | 104 | super(ImageDetectionsField, self).__init__(preprocessing, postprocessing) 105 | 106 | def preprocess(self, x, avoid_precomp=False): 107 | image_id = int(x.split('_')[-1].split('.')[0]) 108 | try: 109 | f = h5py.File(self.detections_path, 'r') 110 | precomp_data = f['%d_features' % image_id][()] 111 | if self.sort_by_prob: 112 | precomp_data = precomp_data[np.argsort(np.max(f['%d_cls_prob' % image_id][()], -1))[::-1]] 113 | except KeyError: 114 | warnings.warn('Could not find detections for %d' % image_id) 115 | precomp_data = np.random.rand(10,2048) 116 | 117 | delta = self.max_detections - precomp_data.shape[0] 118 | if delta > 0: 119 | precomp_data = np.concatenate([precomp_data, np.zeros((delta, precomp_data.shape[1]))], axis=0) 120 | elif delta < 0: 121 | precomp_data = precomp_data[:self.max_detections] 122 | 123 | return precomp_data.astype(np.float32) 124 | 125 | 126 | class TextField(RawField): 127 | vocab_cls = Vocab 128 | # Dictionary mapping PyTorch tensor dtypes to the appropriate Python 129 | # numeric type. 
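    # (used by numericalize() below when use_vocab=False, to coerce string
    # entries to the matching Python numeric type before building tensors)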
130 | dtypes = { 131 | torch.float32: float, 132 | torch.float: float, 133 | torch.float64: float, 134 | torch.double: float, 135 | torch.float16: float, 136 | torch.half: float, 137 | 138 | torch.uint8: int, 139 | torch.int8: int, 140 | torch.int16: int, 141 | torch.short: int, 142 | torch.int32: int, 143 | torch.int: int, 144 | torch.int64: int, 145 | torch.long: int, 146 | } 147 | punctuations = ["''", "'", "``", "`", "-LRB-", "-RRB-", "-LCB-", "-RCB-", \ 148 | ".", "?", "!", ",", ":", "-", "--", "...", ";"] 149 | 150 | def __init__(self, use_vocab=True, init_token=None, eos_token=None, fix_length=None, dtype=torch.long, 151 | preprocessing=None, postprocessing=None, lower=False, tokenize=(lambda s: s.split()), 152 | remove_punctuation=False, include_lengths=False, batch_first=True, pad_token="", 153 | unk_token="", pad_first=False, truncate_first=False, vectors=None, nopoints=True): 154 | self.use_vocab = use_vocab 155 | self.init_token = init_token 156 | self.eos_token = eos_token 157 | self.fix_length = fix_length 158 | self.dtype = dtype 159 | self.lower = lower 160 | self.tokenize = get_tokenizer(tokenize) 161 | self.remove_punctuation = remove_punctuation 162 | self.include_lengths = include_lengths 163 | self.batch_first = batch_first 164 | self.pad_token = pad_token 165 | self.unk_token = unk_token 166 | self.pad_first = pad_first 167 | self.truncate_first = truncate_first 168 | self.vocab = None 169 | self.vectors = vectors 170 | if nopoints: 171 | self.punctuations.append("..") 172 | 173 | super(TextField, self).__init__(preprocessing, postprocessing) 174 | 175 | def preprocess(self, x): 176 | if six.PY2 and isinstance(x, six.string_types) and not isinstance(x, six.text_type): 177 | x = six.text_type(x, encoding='utf-8') 178 | if self.lower: 179 | x = six.text_type.lower(x) 180 | x = self.tokenize(x.rstrip('\n')) 181 | if self.remove_punctuation: 182 | x = [w for w in x if w not in self.punctuations] 183 | if self.preprocessing is not None: 184 | return self.preprocessing(x) 185 | else: 186 | return x 187 | 188 | def process(self, batch, device=None): 189 | padded = self.pad(batch) 190 | tensor = self.numericalize(padded, device=device) 191 | return tensor 192 | 193 | def build_vocab(self, *args, **kwargs): 194 | counter = Counter() 195 | sources = [] 196 | for arg in args: 197 | if isinstance(arg, Dataset): 198 | sources += [getattr(arg, name) for name, field in arg.fields.items() if field is self] 199 | else: 200 | sources.append(arg) 201 | 202 | for data in sources: 203 | for x in data: 204 | x = self.preprocess(x) 205 | try: 206 | counter.update(x) 207 | except TypeError: 208 | counter.update(chain.from_iterable(x)) 209 | 210 | specials = list(OrderedDict.fromkeys([ 211 | tok for tok in [self.unk_token, self.pad_token, self.init_token, 212 | self.eos_token] 213 | if tok is not None])) 214 | self.vocab = self.vocab_cls(counter, specials=specials, **kwargs) 215 | 216 | def pad(self, minibatch): 217 | """Pad a batch of examples using this field. 218 | Pads to self.fix_length if provided, otherwise pads to the length of 219 | the longest example in the batch. Prepends self.init_token and appends 220 | self.eos_token if those attributes are not None. Returns a tuple of the 221 | padded list and a list containing lengths of each example if 222 | `self.include_lengths` is `True`, else just 223 | returns the padded list. 
224 | """ 225 | minibatch = list(minibatch) 226 | if self.fix_length is None: 227 | max_len = max(len(x) for x in minibatch) 228 | else: 229 | max_len = self.fix_length + ( 230 | self.init_token, self.eos_token).count(None) - 2 231 | padded, lengths = [], [] 232 | for x in minibatch: 233 | if self.pad_first: 234 | padded.append( 235 | [self.pad_token] * max(0, max_len - len(x)) + 236 | ([] if self.init_token is None else [self.init_token]) + 237 | list(x[-max_len:] if self.truncate_first else x[:max_len]) + 238 | ([] if self.eos_token is None else [self.eos_token])) 239 | else: 240 | padded.append( 241 | ([] if self.init_token is None else [self.init_token]) + 242 | list(x[-max_len:] if self.truncate_first else x[:max_len]) + 243 | ([] if self.eos_token is None else [self.eos_token]) + 244 | [self.pad_token] * max(0, max_len - len(x))) 245 | lengths.append(len(padded[-1]) - max(0, max_len - len(x))) 246 | if self.include_lengths: 247 | return padded, lengths 248 | return padded 249 | 250 | def numericalize(self, arr, device=None): 251 | """Turn a batch of examples that use this field into a list of Variables. 252 | If the field has include_lengths=True, a tensor of lengths will be 253 | included in the return value. 254 | Arguments: 255 | arr (List[List[str]], or tuple of (List[List[str]], List[int])): 256 | List of tokenized and padded examples, or tuple of List of 257 | tokenized and padded examples and List of lengths of each 258 | example if self.include_lengths is True. 259 | device (str or torch.device): A string or instance of `torch.device` 260 | specifying which device the Variables are going to be created on. 261 | If left as default, the tensors will be created on cpu. Default: None. 262 | """ 263 | if self.include_lengths and not isinstance(arr, tuple): 264 | raise ValueError("Field has include_lengths set to True, but " 265 | "input data is not a tuple of " 266 | "(data batch, batch lengths).") 267 | if isinstance(arr, tuple): 268 | arr, lengths = arr 269 | lengths = torch.tensor(lengths, dtype=self.dtype, device=device) 270 | 271 | if self.use_vocab: 272 | arr = [[self.vocab.stoi[x] for x in ex] for ex in arr] 273 | 274 | if self.postprocessing is not None: 275 | arr = self.postprocessing(arr, self.vocab) 276 | 277 | var = torch.tensor(arr, dtype=self.dtype, device=device) 278 | else: 279 | if self.vectors: 280 | arr = [[self.vectors[x] for x in ex] for ex in arr] 281 | if self.dtype not in self.dtypes: 282 | raise ValueError( 283 | "Specified Field dtype {} can not be used with " 284 | "use_vocab=False because we do not know how to numericalize it. " 285 | "Please raise an issue at " 286 | "https://github.com/pytorch/text/issues".format(self.dtype)) 287 | numericalization_func = self.dtypes[self.dtype] 288 | # It doesn't make sense to explictly coerce to a numeric type if 289 | # the data is sequential, since it's unclear how to coerce padding tokens 290 | # to a numeric type. 
291 | arr = [numericalization_func(x) if isinstance(x, six.string_types) 292 | else x for x in arr] 293 | 294 | if self.postprocessing is not None: 295 | arr = self.postprocessing(arr, None) 296 | 297 | var = torch.cat([torch.cat([a.unsqueeze(0) for a in ar]).unsqueeze(0) for ar in arr]) 298 | 299 | # var = torch.tensor(arr, dtype=self.dtype, device=device) 300 | if not self.batch_first: 301 | var.t_() 302 | var = var.contiguous() 303 | 304 | if self.include_lengths: 305 | return var, lengths 306 | return var 307 | 308 | def decode(self, word_idxs, join_words=True): 309 | if isinstance(word_idxs, list) and len(word_idxs) == 0: 310 | return self.decode([word_idxs, ], join_words)[0] 311 | if isinstance(word_idxs, list) and isinstance(word_idxs[0], int): 312 | return self.decode([word_idxs, ], join_words)[0] 313 | elif isinstance(word_idxs, np.ndarray) and word_idxs.ndim == 1: 314 | return self.decode(word_idxs.reshape((1, -1)), join_words)[0] 315 | elif isinstance(word_idxs, torch.Tensor) and word_idxs.ndimension() == 1: 316 | return self.decode(word_idxs.unsqueeze(0), join_words)[0] 317 | 318 | captions = [] 319 | for wis in word_idxs: 320 | caption = [] 321 | for wi in wis: 322 | word = self.vocab.itos[int(wi)] 323 | if word == self.eos_token: 324 | break 325 | caption.append(word) 326 | if join_words: 327 | caption = ' '.join(caption) 328 | captions.append(caption) 329 | return captions 330 | -------------------------------------------------------------------------------- /models/deecap.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from .containers import Module, ModuleList 4 | from .transformer import DecoderLayer, sinusoid_encoding_table, Encoder 5 | from .utils import one_hot_to_index 6 | 7 | 8 | 9 | def entropy(x): 10 | prob = torch.nn.functional.softmax(x, dim=1) 11 | return -torch.sum(prob * torch.log(prob), dim=1) 12 | 13 | 14 | 15 | 16 | class DeeCapPooler(nn.Module): 17 | def __init__(self, config): 18 | super().__init__() 19 | self.dense = nn.Linear(config.n_embd, config.n_embd) 20 | self.activation = nn.Tanh() 21 | 22 | def forward(self, hidden_states): 23 | # We "pool" the model by simply taking the last hidden state 24 | # last_token_tensor = hidden_states[:, -1] 25 | last_token_tensor = hidden_states 26 | pooled_output = self.dense(last_token_tensor) 27 | pooled_output = self.activation(pooled_output) 28 | return pooled_output 29 | 30 | 31 | 32 | class ImitationNet(nn.Module): 33 | def __init__(self, config): 34 | super().__init__() 35 | self.weight = nn.Parameter(torch.zeros(config.n_layer, config.n_embd, config.n_embd). 
36 | normal_(mean=0.0, std=config.initializer_range)) 37 | self.bias = nn.Parameter(torch.zeros(config.n_layer, config.n_embd)) 38 | self.act = nn.Tanh() 39 | 40 | def forward(self, hidden_representation): 41 | approximate_representation = self.act(hidden_representation.matmul(self.weight).permute(1, 0, 2) + self.bias) 42 | return approximate_representation 43 | 44 | 45 | 46 | 47 | class InternelClassifierWithGate(nn.Module): 48 | def __init__(self, config): 49 | super().__init__() 50 | self.config = config 51 | 52 | self.proj_act = nn.Tanh() 53 | self.gate_proj = nn.Linear(config.n_embd, config.n_embd) 54 | self.gate_act = nn.Sigmoid() 55 | 56 | def forward(self, hidden_representation, current_layer): # hidden_representation (bsz, num_layers, hidden_size) 57 | 58 | prev_logits = torch.sum(hidden_representation[:, :, :current_layer+1, :], dim=2) 59 | future_logits = torch.sum(hidden_representation[:, :, current_layer+1:, :], dim=2) 60 | 61 | prev_gate = self.proj_act(torch.sum(hidden_representation[:, :, :current_layer+1, :], dim=2) / (current_layer + 1)) 62 | prev_lamb = self.gate_act(self.gate_proj(prev_gate)) 63 | 64 | _logits = 2 * prev_lamb * prev_logits + (2 - 2 * prev_lamb) * future_logits 65 | 66 | return _logits 67 | 68 | 69 | 70 | 71 | class DeeCapDecoder(Module): 72 | def __init__(self, vocab_size=50257, max_len=40, N_dec=3, padding_idx=0, d_model=512, d_k=64, d_v=64, h=8, d_ff=2048, 73 | dropout=.1, self_att_module=None, enc_att_module=None, self_att_module_kwargs=None, 74 | enc_att_module_kwargs=None): 75 | super(DeeCapDecoder, self).__init__() 76 | self.d_model = d_model 77 | self.vocab_size = vocab_size 78 | self.word_emb = nn.Embedding(vocab_size, d_model, padding_idx=padding_idx) 79 | self.pos_emb = nn.Embedding.from_pretrained(sinusoid_encoding_table(max_len + 1, d_model, 0), freeze=True) 80 | self.layers = ModuleList( 81 | [DecoderLayer(d_model, d_k, d_v, h, d_ff, dropout, self_att_module=self_att_module, 82 | enc_att_module=enc_att_module, self_att_module_kwargs=self_att_module_kwargs, 83 | enc_att_module_kwargs=enc_att_module_kwargs) for _ in range(N_dec)]) 84 | self.fc = nn.Linear(d_model, vocab_size, bias=False) 85 | self.max_len = max_len 86 | self.padding_idx = padding_idx 87 | self.N = N_dec 88 | 89 | self.register_state('running_mask_self_attention', None) 90 | self.register_state('running_seq', torch.zeros((1,)).long()) 91 | 92 | def forward(self, input, encoder_output, mask_encoder): 93 | # input (b_s, seq_len) 94 | input = input[:, :self.max_len] 95 | b_s, seq_len = input.shape[:2] 96 | 97 | if input.dtype in [torch.long, torch.int]: 98 | input_index = input 99 | else: 100 | input_index = one_hot_to_index(input) 101 | 102 | mask_queries = (input_index != self.padding_idx).unsqueeze(-1).type(input.dtype) 103 | # (b_s, seq_len, 1) 104 | mask_self_attention = torch.triu(torch.ones((seq_len, seq_len), dtype=torch.bool, device=input.device), 105 | diagonal=1) 106 | # print(mask_self_attention) (seq_len, seq_len) 107 | mask_self_attention = mask_self_attention.unsqueeze(0).unsqueeze(0) # (1, 1, seq_len, seq_len) 108 | mask_self_attention = mask_self_attention + (input_index == self.padding_idx).unsqueeze(1).unsqueeze(1).bool() 109 | mask_self_attention = mask_self_attention.gt(0) # (b_s, 1, seq_len, seq_len) 110 | if self._is_stateful: 111 | if self.running_mask_self_attention is None: 112 | self.running_mask_self_attention = mask_self_attention 113 | else: 114 | self.running_mask_self_attention = torch.cat([self.running_mask_self_attention, mask_self_attention], 
115 | -1) 116 | mask_self_attention = self.running_mask_self_attention 117 | 118 | seq = torch.arange(1, seq_len + 1).view(1, -1).expand(b_s, -1).to(input.device) # (b_s, seq_len) 119 | seq = seq.masked_fill(mask_queries.squeeze(-1) == 0, 0) 120 | if self._is_stateful: 121 | self.running_seq.add_(1) 122 | seq = self.running_seq 123 | 124 | if input.dtype in [torch.long, torch.int]: 125 | out = self.word_emb(input) 126 | else: 127 | out = input @ self.word_emb.weight 128 | 129 | out = out + self.pos_emb(seq) 130 | for i, l in enumerate(self.layers): 131 | out = l(out, encoder_output, mask_queries, mask_self_attention, mask_encoder) 132 | 133 | out = self.fc(out) 134 | return out 135 | 136 | # only forward propagation with transformer block 137 | def adaptive_forward(self, hidden_states, current_layer, encoder_output, mask_queries, mask_self_attention, mask_encoder): 138 | layer_outputs = self.layers[current_layer](hidden_states, encoder_output, mask_queries, mask_self_attention, mask_encoder) 139 | return layer_outputs 140 | 141 | def word_forward(self, input): 142 | # input (b_s, seq_len) 143 | input = input[:, :self.max_len] 144 | b_s, seq_len = input.shape[:2] 145 | 146 | if input.dtype in [torch.long, torch.int]: 147 | input_index = input 148 | else: 149 | input_index = one_hot_to_index(input) 150 | 151 | mask_queries = (input_index != self.padding_idx).unsqueeze(-1).type(input.dtype) 152 | # (b_s, seq_len, 1) 153 | mask_self_attention = torch.triu(torch.ones((seq_len, seq_len), dtype=torch.bool, device=input.device), 154 | diagonal=1) 155 | # print(mask_self_attention) (seq_len, seq_len) 156 | mask_self_attention = mask_self_attention.unsqueeze(0).unsqueeze(0) # (1, 1, seq_len, seq_len) 157 | mask_self_attention = mask_self_attention + (input_index == self.padding_idx).unsqueeze(1).unsqueeze(1).bool() 158 | mask_self_attention = mask_self_attention.gt(0) # (b_s, 1, seq_len, seq_len) 159 | if self._is_stateful: 160 | if self.running_mask_self_attention is None: 161 | self.running_mask_self_attention = mask_self_attention 162 | else: 163 | self.running_mask_self_attention = torch.cat([self.running_mask_self_attention, mask_self_attention], 164 | -1) 165 | mask_self_attention = self.running_mask_self_attention 166 | 167 | seq = torch.arange(1, seq_len + 1).view(1, -1).expand(b_s, -1).to(input.device) # (b_s, seq_len) 168 | seq = seq.masked_fill(mask_queries.squeeze(-1) == 0, 0) 169 | if self._is_stateful: 170 | self.running_seq.add_(1) 171 | seq = self.running_seq 172 | 173 | if input.dtype in [torch.long, torch.int]: 174 | out = self.word_emb(input) 175 | else: 176 | out = input @ self.word_emb.weight 177 | 178 | out = out + self.pos_emb(seq) 179 | return out, mask_queries, mask_self_attention 180 | 181 | 182 | 183 | class DeeCapModel(Module): 184 | def __init__(self, config): 185 | super(DeeCapModel, self).__init__() 186 | self.config = config 187 | self.model_d = config.n_embd 188 | self.clip_dim = config.clip_dim 189 | self.clip_length = config.clip_length 190 | self.feature_project = nn.Linear(config.clip_dim, config.clip_length*config.n_embd) 191 | self.visual_encoder = Encoder(config.n_layer, config.clip_length, config.n_embd) 192 | self.language_decoder = DeeCapDecoder(config.vocab_size, N_dec=config.n_layer) 193 | self.imitation_net = ImitationNet(config) 194 | self.poolers = nn.ModuleList([DeeCapPooler(config) for _ in range(config.n_layer)]) 195 | self.fusion_net = InternelClassifierWithGate(config) 196 | 197 | self.bos_idx = config.bos_token_id 198 | self.eos_idx = 
config.eos_token_id 199 | self.vocab_size = config.vocab_size 200 | self.max_generation_length = self.language_decoder.max_len 201 | self.freezed_lower_layer = 3 202 | 203 | self.hidden_states_list = [] 204 | self.hidden_states_proj_list = [] 205 | 206 | # compute the acceleration ration 207 | self.patience = config.patience 208 | self.inference_words_num = 0 209 | self.inference_layers_num = 0 210 | 211 | self.register_state('enc_output', None) 212 | self.register_state('mask_enc', None) 213 | self.init_weights() 214 | 215 | 216 | def init_weights(self): 217 | for p in self.visual_encoder.parameters(): 218 | if p.dim() > 1: 219 | nn.init.xavier_uniform_(p) 220 | for p in self.language_decoder.parameters(): 221 | if p.dim() > 1: 222 | nn.init.xavier_uniform_(p) 223 | 224 | 225 | def forward(self, images, seq): 226 | bsz, seq_len = seq.size()[:2] 227 | 228 | images = self.feature_project(images).view(-1, self.clip_length, self.clip_dim) 229 | enc_output, mask_enc = self.visual_encoder(images) 230 | 231 | 232 | self.hidden_states_list.clear() 233 | self.hidden_states_proj_list.clear() 234 | 235 | hidden_states, mask_queries, mask_self_attention = self.language_decoder.word_forward(seq) 236 | # (bsz, seq_len, model_d) 237 | 238 | res = [] 239 | all_pool = [] 240 | for i in range(self.config.n_layer): 241 | 242 | hidden_states = self.language_decoder.adaptive_forward(hidden_states, i, enc_output, mask_queries, mask_self_attention, mask_enc) 243 | # (bsz, seq_len, model_d) 244 | if i < self.freezed_lower_layer: 245 | hidden_states = hidden_states.detach() 246 | pooled_output = self.poolers[i](hidden_states) 247 | confidence_token = pooled_output 248 | # (bsz, seq_len, model_d) 249 | 250 | self.hidden_states_list.append(confidence_token.detach()) 251 | # approximate high-level hidden representation 252 | self.hidden_states_proj_list.append(self.imitation_net(confidence_token.view(-1, self.config.n_embd)).contiguous().view(bsz, seq_len, self.config.n_layer, -1)) 253 | 254 | if all_pool: 255 | all_pool[-1] = all_pool[-1].detach() 256 | all_pool.append(pooled_output) 257 | pooled_output = torch.stack(all_pool, dim=2) 258 | # last year does not incorporate hidden representation prediction 259 | if i < self.config.n_layer - 1: 260 | pred_hidden_representation = self.hidden_states_proj_list[-1][:, :, i+1, :] 261 | pred_hidden_representation = pred_hidden_representation.unsqueeze(2) 262 | pooled_output = torch.cat([pooled_output, pred_hidden_representation], dim=2) #reshape(pooled_output.shape[0], -1) 263 | # (bsz, seq_len, layer, d_model) 264 | logits = self.language_decoder.fc(self.fusion_net(pooled_output, i)) 265 | # (bsz, seq_len, d_model) 266 | res.append(logits) 267 | 268 | return res 269 | 270 | 271 | def step(self, images, prev_seq): 272 | bsz, seq_len = prev_seq.size()[:2] 273 | 274 | images = self.feature_project(images).view(-1, self.clip_length, self.clip_dim) 275 | enc_output, mask_enc = self.visual_encoder(images) 276 | 277 | self.hidden_states_list.clear() 278 | self.hidden_states_proj_list.clear() 279 | all_pool = [] 280 | 281 | hidden_states, mask_queries, mask_self_attention = self.language_decoder.word_forward(prev_seq) 282 | calculated_layer_num = 0 283 | patient_result = None 284 | 285 | for i in range(self.config.n_layer): 286 | calculated_layer_num += 1 287 | hidden_states = self.language_decoder.adaptive_forward(hidden_states, i, enc_output, mask_queries, mask_self_attention, mask_enc) 288 | # (bsz, seq_len, model_d) 289 | pooled_output = self.poolers[i](hidden_states) 290 | 
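            # early-exit decoding step: the imitation net approximates the
            # hidden states of the deeper, not-yet-computed layers; the gated
            # fusion_net mixes the real shallow-layer representations with the
            # predicted next-layer one and projects to vocabulary logits, and
            # the layer loop stops as soon as the softmax entropy of the last
            # position falls below the confidence threshold (self.patience)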
confidence_token = pooled_output 291 | 292 | self.hidden_states_proj_list.append(self.imitation_net(confidence_token.view(-1, self.config.n_embd)).contiguous().view(bsz, seq_len, self.config.n_layer, -1)) 293 | 294 | all_pool.append(pooled_output) 295 | pooled_output = torch.stack(all_pool, dim=2) 296 | 297 | if i < self.config.n_layer - 1: 298 | pred_hidden_representation = self.hidden_states_proj_list[-1][:, :, i+1, :] 299 | pred_hidden_representation = pred_hidden_representation.unsqueeze(2) 300 | pooled_output = torch.cat([pooled_output, pred_hidden_representation], dim=2) #reshape(pooled_output.shape[0], -1) 301 | # (bsz, seq_len, layer, d_model) 302 | logits = self.language_decoder.fc(self.fusion_net(pooled_output, i))[:, -1, :] 303 | prob = torch.nn.functional.softmax(logits, dim=1) 304 | entropy = -torch.sum(prob * torch.log(prob), dim=1) 305 | patient_result = logits 306 | 307 | # bsz = 1 for testing image-text pair 308 | if torch.all((entropy) < self.patience): 309 | break 310 | self.inference_layers_num += calculated_layer_num 311 | self.inference_words_num += 1 312 | return logits 313 | 314 | 315 | 316 | 317 | 318 | 319 | 320 | 321 | 322 | 323 | 324 | -------------------------------------------------------------------------------- /data/vocab.py: -------------------------------------------------------------------------------- 1 | from __future__ import unicode_literals 2 | import array 3 | from collections import defaultdict 4 | from functools import partial 5 | import io 6 | import logging 7 | import os 8 | import zipfile 9 | 10 | import six 11 | from six.moves.urllib.request import urlretrieve 12 | import torch 13 | from tqdm import tqdm 14 | import tarfile 15 | 16 | from .utils import reporthook 17 | 18 | logger = logging.getLogger(__name__) 19 | 20 | 21 | class Vocab(object): 22 | """Defines a vocabulary object that will be used to numericalize a field. 23 | 24 | Attributes: 25 | freqs: A collections.Counter object holding the frequencies of tokens 26 | in the data used to build the Vocab. 27 | stoi: A collections.defaultdict instance mapping token strings to 28 | numerical identifiers. 29 | itos: A list of token strings indexed by their numerical identifiers. 30 | """ 31 | def __init__(self, counter, max_size=None, min_freq=1, specials=[''], 32 | vectors=None, unk_init=None, vectors_cache=None): 33 | """Create a Vocab object from a collections.Counter. 34 | 35 | Arguments: 36 | counter: collections.Counter object holding the frequencies of 37 | each value found in the data. 38 | max_size: The maximum size of the vocabulary, or None for no 39 | maximum. Default: None. 40 | min_freq: The minimum frequency needed to include a token in the 41 | vocabulary. Values less than 1 will be set to 1. Default: 1. 42 | specials: The list of special tokens (e.g., padding or eos) that 43 | will be prepended to the vocabulary in addition to an 44 | token. Default: [''] 45 | vectors: One of either the available pretrained vectors 46 | or custom pretrained vectors (see Vocab.load_vectors); 47 | or a list of aforementioned vectors 48 | unk_init (callback): by default, initialize out-of-vocabulary word vectors 49 | to zero vectors; can be any function that takes in a Tensor and 50 | returns a Tensor of the same size. Default: torch.Tensor.zero_ 51 | vectors_cache: directory for cached vectors. 
Default: '.vector_cache' 52 | """ 53 | self.freqs = counter 54 | counter = counter.copy() 55 | min_freq = max(min_freq, 1) 56 | 57 | self.itos = list(specials) 58 | # frequencies of special tokens are not counted when building vocabulary 59 | # in frequency order 60 | for tok in specials: 61 | del counter[tok] 62 | 63 | max_size = None if max_size is None else max_size + len(self.itos) 64 | 65 | # sort by frequency, then alphabetically 66 | words_and_frequencies = sorted(counter.items(), key=lambda tup: tup[0]) 67 | words_and_frequencies.sort(key=lambda tup: tup[1], reverse=True) 68 | 69 | for word, freq in words_and_frequencies: 70 | if freq < min_freq or len(self.itos) == max_size: 71 | break 72 | self.itos.append(word) 73 | 74 | self.stoi = defaultdict(_default_unk_index) 75 | # stoi is simply a reverse dict for itos 76 | self.stoi.update({tok: i for i, tok in enumerate(self.itos)}) 77 | 78 | self.vectors = None 79 | if vectors is not None: 80 | self.load_vectors(vectors, unk_init=unk_init, cache=vectors_cache) 81 | else: 82 | assert unk_init is None and vectors_cache is None 83 | 84 | def __eq__(self, other): 85 | if self.freqs != other.freqs: 86 | return False 87 | if self.stoi != other.stoi: 88 | return False 89 | if self.itos != other.itos: 90 | return False 91 | if self.vectors != other.vectors: 92 | return False 93 | return True 94 | 95 | def __len__(self): 96 | return len(self.itos) 97 | 98 | def extend(self, v, sort=False): 99 | words = sorted(v.itos) if sort else v.itos 100 | for w in words: 101 | if w not in self.stoi: 102 | self.itos.append(w) 103 | self.stoi[w] = len(self.itos) - 1 104 | 105 | def load_vectors(self, vectors, **kwargs): 106 | """ 107 | Arguments: 108 | vectors: one of or a list containing instantiations of the 109 | GloVe, CharNGram, or Vectors classes. Alternatively, one 110 | of or a list of available pretrained vectors: 111 | charngram.100d 112 | fasttext.en.300d 113 | fasttext.simple.300d 114 | glove.42B.300d 115 | glove.840B.300d 116 | glove.twitter.27B.25d 117 | glove.twitter.27B.50d 118 | glove.twitter.27B.100d 119 | glove.twitter.27B.200d 120 | glove.6B.50d 121 | glove.6B.100d 122 | glove.6B.200d 123 | glove.6B.300d 124 | Remaining keyword arguments: Passed to the constructor of Vectors classes. 
125 | """ 126 | if not isinstance(vectors, list): 127 | vectors = [vectors] 128 | for idx, vector in enumerate(vectors): 129 | if six.PY2 and isinstance(vector, str): 130 | vector = six.text_type(vector) 131 | if isinstance(vector, six.string_types): 132 | # Convert the string pretrained vector identifier 133 | # to a Vectors object 134 | if vector not in pretrained_aliases: 135 | raise ValueError( 136 | "Got string input vector {}, but allowed pretrained " 137 | "vectors are {}".format( 138 | vector, list(pretrained_aliases.keys()))) 139 | vectors[idx] = pretrained_aliases[vector](**kwargs) 140 | elif not isinstance(vector, Vectors): 141 | raise ValueError( 142 | "Got input vectors of type {}, expected str or " 143 | "Vectors object".format(type(vector))) 144 | 145 | tot_dim = sum(v.dim for v in vectors) 146 | self.vectors = torch.Tensor(len(self), tot_dim) 147 | for i, token in enumerate(self.itos): 148 | start_dim = 0 149 | for v in vectors: 150 | end_dim = start_dim + v.dim 151 | self.vectors[i][start_dim:end_dim] = v[token.strip()] 152 | start_dim = end_dim 153 | assert(start_dim == tot_dim) 154 | 155 | def set_vectors(self, stoi, vectors, dim, unk_init=torch.Tensor.zero_): 156 | """ 157 | Set the vectors for the Vocab instance from a collection of Tensors. 158 | 159 | Arguments: 160 | stoi: A dictionary of string to the index of the associated vector 161 | in the `vectors` input argument. 162 | vectors: An indexed iterable (or other structure supporting __getitem__) that 163 | given an input index, returns a FloatTensor representing the vector 164 | for the token associated with the index. For example, 165 | vector[stoi["string"]] should return the vector for "string". 166 | dim: The dimensionality of the vectors. 167 | unk_init (callback): by default, initialize out-of-vocabulary word vectors 168 | to zero vectors; can be any function that takes in a Tensor and 169 | returns a Tensor of the same size. 
Default: torch.Tensor.zero_ 170 | """ 171 | self.vectors = torch.Tensor(len(self), dim) 172 | for i, token in enumerate(self.itos): 173 | wv_index = stoi.get(token, None) 174 | if wv_index is not None: 175 | self.vectors[i] = vectors[wv_index] 176 | else: 177 | self.vectors[i] = unk_init(self.vectors[i]) 178 | 179 | 180 | class Vectors(object): 181 | 182 | def __init__(self, name, cache=None, 183 | url=None, unk_init=None): 184 | """ 185 | Arguments: 186 | name: name of the file that contains the vectors 187 | cache: directory for cached vectors 188 | url: url for download if vectors not found in cache 189 | unk_init (callback): by default, initalize out-of-vocabulary word vectors 190 | to zero vectors; can be any function that takes in a Tensor and 191 | returns a Tensor of the same size 192 | """ 193 | cache = '.vector_cache' if cache is None else cache 194 | self.unk_init = torch.Tensor.zero_ if unk_init is None else unk_init 195 | self.cache(name, cache, url=url) 196 | 197 | def __getitem__(self, token): 198 | if token in self.stoi: 199 | return self.vectors[self.stoi[token]] 200 | else: 201 | return self.unk_init(torch.Tensor(self.dim)) # self.unk_init(torch.Tensor(1, self.dim)) 202 | 203 | def cache(self, name, cache, url=None): 204 | if os.path.isfile(name): 205 | path = name 206 | path_pt = os.path.join(cache, os.path.basename(name)) + '.pt' 207 | else: 208 | path = os.path.join(cache, name) 209 | path_pt = path + '.pt' 210 | 211 | if not os.path.isfile(path_pt): 212 | if not os.path.isfile(path) and url: 213 | logger.info('Downloading vectors from {}'.format(url)) 214 | if not os.path.exists(cache): 215 | os.makedirs(cache) 216 | dest = os.path.join(cache, os.path.basename(url)) 217 | if not os.path.isfile(dest): 218 | with tqdm(unit='B', unit_scale=True, miniters=1, desc=dest) as t: 219 | try: 220 | urlretrieve(url, dest, reporthook=reporthook(t)) 221 | except KeyboardInterrupt as e: # remove the partial zip file 222 | os.remove(dest) 223 | raise e 224 | logger.info('Extracting vectors into {}'.format(cache)) 225 | ext = os.path.splitext(dest)[1][1:] 226 | if ext == 'zip': 227 | with zipfile.ZipFile(dest, "r") as zf: 228 | zf.extractall(cache) 229 | elif ext == 'gz': 230 | with tarfile.open(dest, 'r:gz') as tar: 231 | tar.extractall(path=cache) 232 | if not os.path.isfile(path): 233 | raise RuntimeError('no vectors found at {}'.format(path)) 234 | 235 | # str call is necessary for Python 2/3 compatibility, since 236 | # argument must be Python 2 str (Python 3 bytes) or 237 | # Python 3 str (Python 2 unicode) 238 | itos, vectors, dim = [], array.array(str('d')), None 239 | 240 | # Try to read the whole file with utf-8 encoding. 241 | binary_lines = False 242 | try: 243 | with io.open(path, encoding="utf8") as f: 244 | lines = [line for line in f] 245 | # If there are malformed lines, read in binary mode 246 | # and manually decode each word from utf-8 247 | except: 248 | logger.warning("Could not read {} as UTF8 file, " 249 | "reading file as bytes and skipping " 250 | "words with malformed UTF8.".format(path)) 251 | with open(path, 'rb') as f: 252 | lines = [line for line in f] 253 | binary_lines = True 254 | 255 | logger.info("Loading vectors from {}".format(path)) 256 | for line in tqdm(lines, total=len(lines)): 257 | # Explicitly splitting on " " is important, so we don't 258 | # get rid of Unicode non-breaking spaces in the vectors. 
259 | entries = line.rstrip().split(b" " if binary_lines else " ") 260 | 261 | word, entries = entries[0], entries[1:] 262 | if dim is None and len(entries) > 1: 263 | dim = len(entries) 264 | elif len(entries) == 1: 265 | logger.warning("Skipping token {} with 1-dimensional " 266 | "vector {}; likely a header".format(word, entries)) 267 | continue 268 | elif dim != len(entries): 269 | raise RuntimeError( 270 | "Vector for token {} has {} dimensions, but previously " 271 | "read vectors have {} dimensions. All vectors must have " 272 | "the same number of dimensions.".format(word, len(entries), dim)) 273 | 274 | if binary_lines: 275 | try: 276 | if isinstance(word, six.binary_type): 277 | word = word.decode('utf-8') 278 | except: 279 | logger.info("Skipping non-UTF8 token {}".format(repr(word))) 280 | continue 281 | vectors.extend(float(x) for x in entries) 282 | itos.append(word) 283 | 284 | self.itos = itos 285 | self.stoi = {word: i for i, word in enumerate(itos)} 286 | self.vectors = torch.Tensor(vectors).view(-1, dim) 287 | self.dim = dim 288 | logger.info('Saving vectors to {}'.format(path_pt)) 289 | if not os.path.exists(cache): 290 | os.makedirs(cache) 291 | torch.save((self.itos, self.stoi, self.vectors, self.dim), path_pt) 292 | else: 293 | logger.info('Loading vectors from {}'.format(path_pt)) 294 | self.itos, self.stoi, self.vectors, self.dim = torch.load(path_pt) 295 | 296 | 297 | class GloVe(Vectors): 298 | url = { 299 | '42B': 'http://nlp.stanford.edu/data/glove.42B.300d.zip', 300 | '840B': 'http://nlp.stanford.edu/data/glove.840B.300d.zip', 301 | 'twitter.27B': 'http://nlp.stanford.edu/data/glove.twitter.27B.zip', 302 | '6B': 'http://nlp.stanford.edu/data/glove.6B.zip', 303 | } 304 | 305 | def __init__(self, name='840B', dim=300, **kwargs): 306 | url = self.url[name] 307 | name = 'glove.{}.{}d.txt'.format(name, str(dim)) 308 | super(GloVe, self).__init__(name, url=url, **kwargs) 309 | 310 | 311 | class FastText(Vectors): 312 | 313 | url_base = 'https://s3-us-west-1.amazonaws.com/fasttext-vectors/wiki.{}.vec' 314 | 315 | def __init__(self, language="en", **kwargs): 316 | url = self.url_base.format(language) 317 | name = os.path.basename(url) 318 | super(FastText, self).__init__(name, url=url, **kwargs) 319 | 320 | 321 | class CharNGram(Vectors): 322 | 323 | name = 'charNgram.txt' 324 | url = ('http://www.logos.t.u-tokyo.ac.jp/~hassy/publications/arxiv2016jmt/' 325 | 'jmt_pre-trained_embeddings.tar.gz') 326 | 327 | def __init__(self, **kwargs): 328 | super(CharNGram, self).__init__(self.name, url=self.url, **kwargs) 329 | 330 | def __getitem__(self, token): 331 | vector = torch.Tensor(1, self.dim).zero_() 332 | if token == "": 333 | return self.unk_init(vector) 334 | # These literals need to be coerced to unicode for Python 2 compatibility 335 | # when we try to join them with read ngrams from the files. 
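        # build 2-, 3- and 4-character n-grams over the token (with #BEGIN#
        # and #END# markers) and average the vectors of every n-gram that
        # exists in the vocabulary; fall back to unk_init if none is found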
336 | chars = ['#BEGIN#'] + list(token) + ['#END#'] 337 | num_vectors = 0 338 | for n in [2, 3, 4]: 339 | end = len(chars) - n + 1 340 | grams = [chars[i:(i + n)] for i in range(end)] 341 | for gram in grams: 342 | gram_key = '{}gram-{}'.format(n, ''.join(gram)) 343 | if gram_key in self.stoi: 344 | vector += self.vectors[self.stoi[gram_key]] 345 | num_vectors += 1 346 | if num_vectors > 0: 347 | vector /= num_vectors 348 | else: 349 | vector = self.unk_init(vector) 350 | return vector 351 | 352 | 353 | def _default_unk_index(): 354 | return 0 355 | 356 | 357 | pretrained_aliases = { 358 | "charngram.100d": partial(CharNGram), 359 | "fasttext.en.300d": partial(FastText, language="en"), 360 | "fasttext.simple.300d": partial(FastText, language="simple"), 361 | "glove.42B.300d": partial(GloVe, name="42B", dim="300"), 362 | "glove.840B.300d": partial(GloVe, name="840B", dim="300"), 363 | "glove.twitter.27B.25d": partial(GloVe, name="twitter.27B", dim="25"), 364 | "glove.twitter.27B.50d": partial(GloVe, name="twitter.27B", dim="50"), 365 | "glove.twitter.27B.100d": partial(GloVe, name="twitter.27B", dim="100"), 366 | "glove.twitter.27B.200d": partial(GloVe, name="twitter.27B", dim="200"), 367 | "glove.6B.50d": partial(GloVe, name="6B", dim="50"), 368 | "glove.6B.100d": partial(GloVe, name="6B", dim="100"), 369 | "glove.6B.200d": partial(GloVe, name="6B", dim="200"), 370 | "glove.6B.300d": partial(GloVe, name="6B", dim="300") 371 | } 372 | """Mapping from string name to factory function""" 373 | --------------------------------------------------------------------------------
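A minimal usage sketch, not a file from the repository, showing how TextField (data/field.py) and Vocab (data/vocab.py) above fit together: build a vocabulary from raw captions, then pad and numericalize a batch. The caption strings and the '<bos>'/'<eos>'/'<pad>'/'<unk>' token choices are placeholder assumptions, and it assumes data.utils.get_tokenizer returns a plain callable tokenizer unchanged.

from data.field import TextField

# text field with explicitly chosen special tokens (placeholder names)
text_field = TextField(init_token='<bos>', eos_token='<eos>',
                       pad_token='<pad>', unk_token='<unk>',
                       lower=True, remove_punctuation=True)

captions = ['A dog runs on the grass.', 'Two dogs play with a ball.']

# count tokens and build the Vocab; special tokens are prepended automatically
text_field.build_vocab(captions, min_freq=1)

# tokenize, pad to the longest caption, and map tokens to vocabulary indices
batch = text_field.process([text_field.preprocess(c) for c in captions])
print(batch.shape)               # (batch_size, padded_len) LongTensor
print(text_field.decode(batch))  # indices back to strings, stopping at eos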