├── evaluation ├── bleu │ ├── __init__.py │ ├── bleu.py │ └── bleu_scorer.py ├── cider │ ├── __init__.py │ ├── cider.py │ └── cider_scorer.py ├── rouge │ ├── __init__.py │ └── rouge.py ├── meteor │ ├── __init__.py │ └── meteor.py ├── stanford-corenlp-3.4.1.jar ├── __init__.py └── tokenizer.py ├── models ├── beam_search │ ├── __init__.py │ └── beam_search.py ├── __init__.py ├── transformer │ ├── __init__.py │ ├── utils.py │ ├── transformer.py │ ├── encoders.py │ ├── decoders.py │ └── attention.py ├── captioning_model.py └── containers.py ├── vocab.pkl ├── images ├── m2.png └── results.png ├── utils ├── typing.py ├── __init__.py └── utils.py ├── output_logs └── meshed_memory_transformer_test_o ├── data ├── __init__.py ├── example.py ├── utils.py ├── dataset.py ├── field.py └── vocab.py ├── LICENSE ├── environment.yml ├── test.py ├── README.md └── train.py /evaluation/bleu/__init__.py: -------------------------------------------------------------------------------- 1 | from .bleu import Bleu -------------------------------------------------------------------------------- /evaluation/cider/__init__.py: -------------------------------------------------------------------------------- 1 | from .cider import Cider -------------------------------------------------------------------------------- /evaluation/rouge/__init__.py: -------------------------------------------------------------------------------- 1 | from .rouge import Rouge -------------------------------------------------------------------------------- /evaluation/meteor/__init__.py: -------------------------------------------------------------------------------- 1 | from .meteor import Meteor -------------------------------------------------------------------------------- /models/beam_search/__init__.py: -------------------------------------------------------------------------------- 1 | from .beam_search import BeamSearch 2 | -------------------------------------------------------------------------------- /vocab.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aimagelab/meshed-memory-transformer/HEAD/vocab.pkl -------------------------------------------------------------------------------- /images/m2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aimagelab/meshed-memory-transformer/HEAD/images/m2.png -------------------------------------------------------------------------------- /images/results.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aimagelab/meshed-memory-transformer/HEAD/images/results.png -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- 1 | from .transformer import Transformer 2 | from .captioning_model import CaptioningModel 3 | -------------------------------------------------------------------------------- /evaluation/stanford-corenlp-3.4.1.jar: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aimagelab/meshed-memory-transformer/HEAD/evaluation/stanford-corenlp-3.4.1.jar -------------------------------------------------------------------------------- /models/transformer/__init__.py: -------------------------------------------------------------------------------- 1 | from .transformer import * 2 | from .encoders 
import * 3 | from .decoders import * 4 | from .attention import * 5 | -------------------------------------------------------------------------------- /utils/typing.py: -------------------------------------------------------------------------------- 1 | from typing import Union, Sequence, Tuple 2 | import torch 3 | 4 | TensorOrSequence = Union[Sequence[torch.Tensor], torch.Tensor] 5 | TensorOrNone = Union[torch.Tensor, None] 6 | -------------------------------------------------------------------------------- /output_logs/meshed_memory_transformer_test_o: -------------------------------------------------------------------------------- 1 | Meshed-Memory Transformer Evaluation 2 | {'BLEU': [0.8076084272899184, 0.65337618312199, 0.5093125587687117, 0.3909357911782391], 'METEOR': 0.2918900660095916, 'ROUGE': 0.5863539878042495, 'CIDEr': 1.3119740267338893} 3 | -------------------------------------------------------------------------------- /data/__init__.py: -------------------------------------------------------------------------------- 1 | from .field import RawField, Merge, ImageDetectionsField, TextField 2 | from .dataset import COCO 3 | from torch.utils.data import DataLoader as TorchDataLoader 4 | 5 | class DataLoader(TorchDataLoader): 6 | def __init__(self, dataset, *args, **kwargs): 7 | super(DataLoader, self).__init__(dataset, *args, collate_fn=dataset.collate_fn(), **kwargs) 8 | -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .utils import download_from_url 2 | from .typing import * 3 | 4 | def get_batch_size(x: TensorOrSequence) -> int: 5 | if isinstance(x, torch.Tensor): 6 | b_s = x.size(0) 7 | else: 8 | b_s = x[0].size(0) 9 | return b_s 10 | 11 | 12 | def get_device(x: TensorOrSequence) -> int: 13 | if isinstance(x, torch.Tensor): 14 | b_s = x.device 15 | else: 16 | b_s = x[0].device 17 | return b_s 18 | -------------------------------------------------------------------------------- /evaluation/__init__.py: -------------------------------------------------------------------------------- 1 | from .bleu import Bleu 2 | from .meteor import Meteor 3 | from .rouge import Rouge 4 | from .cider import Cider 5 | from .tokenizer import PTBTokenizer 6 | 7 | def compute_scores(gts, gen): 8 | metrics = (Bleu(), Meteor(), Rouge(), Cider()) 9 | all_score = {} 10 | all_scores = {} 11 | for metric in metrics: 12 | score, scores = metric.compute_score(gts, gen) 13 | all_score[str(metric)] = score 14 | all_scores[str(metric)] = scores 15 | 16 | return all_score, all_scores 17 | -------------------------------------------------------------------------------- /data/example.py: -------------------------------------------------------------------------------- 1 | 2 | class Example(object): 3 | """Defines a single training or test example. 4 | Stores each column of the example as an attribute. 
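Attributes are fixed after construction: __setattr__ raises AttributeError, and __hash__/__eq__ operate on the stored attribute values.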
5 | """ 6 | @classmethod 7 | def fromdict(cls, data): 8 | ex = cls(data) 9 | return ex 10 | 11 | def __init__(self, data): 12 | for key, val in data.items(): 13 | super(Example, self).__setattr__(key, val) 14 | 15 | def __setattr__(self, key, value): 16 | raise AttributeError 17 | 18 | def __hash__(self): 19 | return hash(tuple(x for x in self.__dict__.values())) 20 | 21 | def __eq__(self, other): 22 | this = tuple(x for x in self.__dict__.values()) 23 | other = tuple(x for x in other.__dict__.values()) 24 | return this == other 25 | 26 | def __ne__(self, other): 27 | return not self.__eq__(other) 28 | -------------------------------------------------------------------------------- /utils/utils.py: -------------------------------------------------------------------------------- 1 | import requests 2 | 3 | def download_from_url(url, path): 4 | """Download file, with logic (from tensor2tensor) for Google Drive""" 5 | if 'drive.google.com' not in url: 6 | print('Downloading %s; may take a few minutes' % url) 7 | r = requests.get(url, headers={'User-Agent': 'Mozilla/5.0'}) 8 | with open(path, "wb") as file: 9 | file.write(r.content) 10 | return 11 | print('Downloading from Google Drive; may take a few minutes') 12 | confirm_token = None 13 | session = requests.Session() 14 | response = session.get(url, stream=True) 15 | for k, v in response.cookies.items(): 16 | if k.startswith("download_warning"): 17 | confirm_token = v 18 | 19 | if confirm_token: 20 | url = url + "&confirm=" + confirm_token 21 | response = session.get(url, stream=True) 22 | 23 | chunk_size = 16 * 1024 24 | with open(path, "wb") as f: 25 | for chunk in response.iter_content(chunk_size): 26 | if chunk: 27 | f.write(chunk) 28 | -------------------------------------------------------------------------------- /evaluation/bleu/bleu.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # 3 | # File Name : bleu.py 4 | # 5 | # Description : Wrapper for BLEU scorer. 6 | # 7 | # Creation Date : 06-01-2015 8 | # Last Modified : Thu 19 Mar 2015 09:13:28 PM PDT 9 | # Authors : Hao Fang and Tsung-Yi Lin 10 | 11 | from .bleu_scorer import BleuScorer 12 | 13 | 14 | class Bleu: 15 | def __init__(self, n=4): 16 | # default compute Blue score up to 4 17 | self._n = n 18 | self._hypo_for_image = {} 19 | self.ref_for_image = {} 20 | 21 | def compute_score(self, gts, res): 22 | 23 | assert(gts.keys() == res.keys()) 24 | imgIds = gts.keys() 25 | 26 | bleu_scorer = BleuScorer(n=self._n) 27 | for id in imgIds: 28 | hypo = res[id] 29 | ref = gts[id] 30 | 31 | # Sanity check. 32 | assert(type(hypo) is list) 33 | assert(len(hypo) == 1) 34 | assert(type(ref) is list) 35 | assert(len(ref) >= 1) 36 | 37 | bleu_scorer += (hypo[0], ref) 38 | 39 | # score, scores = bleu_scorer.compute_score(option='shortest') 40 | score, scores = bleu_scorer.compute_score(option='closest', verbose=0) 41 | # score, scores = bleu_scorer.compute_score(option='average', verbose=1) 42 | 43 | return score, scores 44 | 45 | def __str__(self): 46 | return 'BLEU' 47 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | BSD 3-Clause License 2 | 3 | Copyright (c) 2019, AImageLab 4 | All rights reserved. 5 | 6 | Redistribution and use in source and binary forms, with or without 7 | modification, are permitted provided that the following conditions are met: 8 | 9 | 1. 
Redistributions of source code must retain the above copyright notice, this 10 | list of conditions and the following disclaimer. 11 | 12 | 2. Redistributions in binary form must reproduce the above copyright notice, 13 | this list of conditions and the following disclaimer in the documentation 14 | and/or other materials provided with the distribution. 15 | 16 | 3. Neither the name of the copyright holder nor the names of its 17 | contributors may be used to endorse or promote products derived from 18 | this software without specific prior written permission. 19 | 20 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 21 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 22 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 23 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 24 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 25 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 26 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 27 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 28 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 29 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 30 | -------------------------------------------------------------------------------- /evaluation/cider/cider.py: -------------------------------------------------------------------------------- 1 | # Filename: cider.py 2 | # 3 | # Description: Describes the class to compute the CIDEr (Consensus-Based Image Description Evaluation) Metric 4 | # by Vedantam, Zitnick, and Parikh (http://arxiv.org/abs/1411.5726) 5 | # 6 | # Creation Date: Sun Feb 8 14:16:54 2015 7 | # 8 | # Authors: Ramakrishna Vedantam and Tsung-Yi Lin 9 | 10 | from .cider_scorer import CiderScorer 11 | 12 | class Cider: 13 | """ 14 | Main Class to compute the CIDEr metric 15 | 16 | """ 17 | def __init__(self, gts=None, n=4, sigma=6.0): 18 | # set cider to sum over 1 to 4-grams 19 | self._n = n 20 | # set the standard deviation parameter for gaussian penalty 21 | self._sigma = sigma 22 | self.doc_frequency = None 23 | self.ref_len = None 24 | if gts is not None: 25 | tmp_cider = CiderScorer(gts, n=self._n, sigma=self._sigma) 26 | self.doc_frequency = tmp_cider.doc_frequency 27 | self.ref_len = tmp_cider.ref_len 28 | 29 | def compute_score(self, gts, res): 30 | """ 31 | Main function to compute CIDEr score 32 | :param gts (dict) : dictionary with key and value 33 | res (dict) : dictionary with key and value 34 | :return: cider (float) : computed CIDEr score for the corpus 35 | """ 36 | assert(gts.keys() == res.keys()) 37 | cider_scorer = CiderScorer(gts, test=res, n=self._n, sigma=self._sigma, doc_frequency=self.doc_frequency, 38 | ref_len=self.ref_len) 39 | return cider_scorer.compute_score() 40 | 41 | def __str__(self): 42 | return 'CIDEr' 43 | -------------------------------------------------------------------------------- /models/transformer/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from torch.nn import functional as F 4 | 5 | 6 | def position_embedding(input, d_model): 7 | input = input.view(-1, 1) 8 | dim = torch.arange(d_model // 2, dtype=torch.float32, device=input.device).view(1, -1) 9 | sin = torch.sin(input / 10000 ** (2 * dim / d_model)) 10 | cos = torch.cos(input / 
10000 ** (2 * dim / d_model)) 11 | 12 | out = torch.zeros((input.shape[0], d_model), device=input.device) 13 | out[:, ::2] = sin 14 | out[:, 1::2] = cos 15 | return out 16 | 17 | 18 | def sinusoid_encoding_table(max_len, d_model, padding_idx=None): 19 | pos = torch.arange(max_len, dtype=torch.float32) 20 | out = position_embedding(pos, d_model) 21 | 22 | if padding_idx is not None: 23 | out[padding_idx] = 0 24 | return out 25 | 26 | 27 | class PositionWiseFeedForward(nn.Module): 28 | ''' 29 | Position-wise feed forward layer 30 | ''' 31 | 32 | def __init__(self, d_model=512, d_ff=2048, dropout=.1, identity_map_reordering=False): 33 | super(PositionWiseFeedForward, self).__init__() 34 | self.identity_map_reordering = identity_map_reordering 35 | self.fc1 = nn.Linear(d_model, d_ff) 36 | self.fc2 = nn.Linear(d_ff, d_model) 37 | self.dropout = nn.Dropout(p=dropout) 38 | self.dropout_2 = nn.Dropout(p=dropout) 39 | self.layer_norm = nn.LayerNorm(d_model) 40 | 41 | def forward(self, input): 42 | if self.identity_map_reordering: 43 | out = self.layer_norm(input) 44 | out = self.fc2(self.dropout_2(F.relu(self.fc1(out)))) 45 | out = input + self.dropout(torch.relu(out)) 46 | else: 47 | out = self.fc2(self.dropout_2(F.relu(self.fc1(input)))) 48 | out = self.dropout(out) 49 | out = self.layer_norm(input + out) 50 | return out -------------------------------------------------------------------------------- /environment.yml: -------------------------------------------------------------------------------- 1 | name: m2release 2 | channels: 3 | - anaconda 4 | - defaults 5 | dependencies: 6 | - _libgcc_mutex=0.1=main 7 | - asn1crypto=1.2.0=py36_0 8 | - blas=1.0=mkl 9 | - ca-certificates=2019.10.16=0 10 | - certifi=2019.9.11=py36_0 11 | - cffi=1.13.2=py36h2e261b9_0 12 | - chardet=3.0.4=py36_1003 13 | - cryptography=2.8=py36h1ba5d50_0 14 | - cython=0.29.14=py36he6710b0_0 15 | - dill=0.2.9=py36_0 16 | - idna=2.8=py36_0 17 | - intel-openmp=2019.5=281 18 | - libedit=3.1.20181209=hc058e9b_0 19 | - libffi=3.2.1=hd88cf55_4 20 | - libgcc-ng=9.1.0=hdf63c60_0 21 | - libgfortran-ng=7.3.0=hdf63c60_0 22 | - libstdcxx-ng=9.1.0=hdf63c60_0 23 | - mkl=2019.5=281 24 | - mkl-service=2.3.0=py36he904b0f_0 25 | - mkl_fft=1.0.15=py36ha843d7b_0 26 | - mkl_random=1.1.0=py36hd6b4f25_0 27 | - msgpack-numpy=0.4.4.3=py_0 28 | - msgpack-python=0.5.6=py36h6bb024c_1 29 | - ncurses=6.1=he6710b0_1 30 | - openjdk=8.0.152=h46b5887_1 31 | - openssl=1.1.1=h7b6447c_0 32 | - pip=19.3.1=py36_0 33 | - pycparser=2.19=py_0 34 | - pyopenssl=19.1.0=py36_0 35 | - pysocks=1.7.1=py36_0 36 | - python=3.6.9=h265db76_0 37 | - readline=7.0=h7b6447c_5 38 | - requests=2.22.0=py36_0 39 | - setuptools=41.6.0=py36_0 40 | - six=1.13.0=py36_0 41 | - spacy=2.0.11=py36h04863e7_2 42 | - sqlite=3.30.1=h7b6447c_0 43 | - termcolor=1.1.0=py36_1 44 | - thinc=6.11.2=py36hedc7406_1 45 | - tk=8.6.8=hbc83047_0 46 | - toolz=0.10.0=py_0 47 | - urllib3=1.24.2=py36_0 48 | - wheel=0.33.6=py36_0 49 | - xz=5.2.4=h14c3975_4 50 | - zlib=1.2.11=h7b6447c_3 51 | - pip: 52 | - absl-py==0.8.1 53 | - cycler==0.10.0 54 | - cymem==1.31.2 55 | - cytoolz==0.9.0.1 56 | - future==0.17.1 57 | - grpcio==1.25.0 58 | - h5py==2.8.0 59 | - kiwisolver==1.1.0 60 | - markdown==3.1.1 61 | - matplotlib==2.2.3 62 | - msgpack==0.6.2 63 | - multiprocess==0.70.9 64 | - murmurhash==0.28.0 65 | - numpy==1.16.4 66 | - pathlib==1.0.1 67 | - pathos==0.2.3 68 | - pillow==6.2.1 69 | - plac==0.9.6 70 | - pox==0.2.7 71 | - ppft==1.6.6.1 72 | - preshed==1.0.1 73 | - protobuf==3.10.0 74 | - pycocotools==2.0.0 75 | - 
pyparsing==2.4.5 76 | - python-dateutil==2.8.1 77 | - pytz==2019.3 78 | - regex==2017.4.5 79 | - tensorboard==1.14.0 80 | - torch==1.1.0 81 | - torchvision==0.3.0 82 | - tqdm==4.32.2 83 | - ujson==1.35 84 | - werkzeug==0.16.0 85 | - wrapt==1.10.11 86 | 87 | -------------------------------------------------------------------------------- /evaluation/tokenizer.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # 3 | # File Name : ptbtokenizer.py 4 | # 5 | # Description : Do the PTB Tokenization and remove punctuations. 6 | # 7 | # Creation Date : 29-12-2014 8 | # Last Modified : Thu Mar 19 09:53:35 2015 9 | # Authors : Hao Fang and Tsung-Yi Lin 10 | 11 | import os 12 | import subprocess 13 | import tempfile 14 | 15 | class PTBTokenizer(object): 16 | """Python wrapper of Stanford PTBTokenizer""" 17 | 18 | corenlp_jar = 'stanford-corenlp-3.4.1.jar' 19 | punctuations = ["''", "'", "``", "`", "-LRB-", "-RRB-", "-LCB-", "-RCB-", \ 20 | ".", "?", "!", ",", ":", "-", "--", "...", ";"] 21 | 22 | @classmethod 23 | def tokenize(cls, corpus): 24 | cmd = ['java', '-cp', cls.corenlp_jar, \ 25 | 'edu.stanford.nlp.process.PTBTokenizer', \ 26 | '-preserveLines', '-lowerCase'] 27 | 28 | if isinstance(corpus, list) or isinstance(corpus, tuple): 29 | if isinstance(corpus[0], list) or isinstance(corpus[0], tuple): 30 | corpus = {i:c for i, c in enumerate(corpus)} 31 | else: 32 | corpus = {i: [c, ] for i, c in enumerate(corpus)} 33 | 34 | # prepare data for PTB Tokenizer 35 | tokenized_corpus = {} 36 | image_id = [k for k, v in list(corpus.items()) for _ in range(len(v))] 37 | sentences = '\n'.join([c.replace('\n', ' ') for k, v in corpus.items() for c in v]) 38 | 39 | # save sentences to temporary file 40 | path_to_jar_dirname=os.path.dirname(os.path.abspath(__file__)) 41 | tmp_file = tempfile.NamedTemporaryFile(delete=False, dir=path_to_jar_dirname) 42 | tmp_file.write(sentences.encode()) 43 | tmp_file.close() 44 | 45 | # tokenize sentence 46 | cmd.append(os.path.basename(tmp_file.name)) 47 | p_tokenizer = subprocess.Popen(cmd, cwd=path_to_jar_dirname, \ 48 | stdout=subprocess.PIPE, stderr=open(os.devnull, 'w')) 49 | token_lines = p_tokenizer.communicate(input=sentences.rstrip())[0] 50 | token_lines = token_lines.decode() 51 | lines = token_lines.split('\n') 52 | # remove temp file 53 | os.remove(tmp_file.name) 54 | 55 | # create dictionary for tokenized captions 56 | for k, line in zip(image_id, lines): 57 | if not k in tokenized_corpus: 58 | tokenized_corpus[k] = [] 59 | tokenized_caption = ' '.join([w for w in line.rstrip().split(' ') \ 60 | if w not in cls.punctuations]) 61 | tokenized_corpus[k].append(tokenized_caption) 62 | 63 | return tokenized_corpus -------------------------------------------------------------------------------- /models/transformer/transformer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | import copy 4 | from models.containers import ModuleList 5 | from ..captioning_model import CaptioningModel 6 | 7 | 8 | class Transformer(CaptioningModel): 9 | def __init__(self, bos_idx, encoder, decoder): 10 | super(Transformer, self).__init__() 11 | self.bos_idx = bos_idx 12 | self.encoder = encoder 13 | self.decoder = decoder 14 | self.register_state('enc_output', None) 15 | self.register_state('mask_enc', None) 16 | self.init_weights() 17 | 18 | @property 19 | def d_model(self): 20 | return self.decoder.d_model 21 | 22 | def init_weights(self): 23 | 
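# Xavier-initialize every parameter with more than one dimension (weight matrices); 1-D parameters such as biases keep their default initialization.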
for p in self.parameters(): 24 | if p.dim() > 1: 25 | nn.init.xavier_uniform_(p) 26 | 27 | def forward(self, images, seq, *args): 28 | enc_output, mask_enc = self.encoder(images) 29 | dec_output = self.decoder(seq, enc_output, mask_enc) 30 | return dec_output 31 | 32 | def init_state(self, b_s, device): 33 | return [torch.zeros((b_s, 0), dtype=torch.long, device=device), 34 | None, None] 35 | 36 | def step(self, t, prev_output, visual, seq, mode='teacher_forcing', **kwargs): 37 | it = None 38 | if mode == 'teacher_forcing': 39 | raise NotImplementedError 40 | elif mode == 'feedback': 41 | if t == 0: 42 | self.enc_output, self.mask_enc = self.encoder(visual) 43 | if isinstance(visual, torch.Tensor): 44 | it = visual.data.new_full((visual.shape[0], 1), self.bos_idx).long() 45 | else: 46 | it = visual[0].data.new_full((visual[0].shape[0], 1), self.bos_idx).long() 47 | else: 48 | it = prev_output 49 | 50 | return self.decoder(it, self.enc_output, self.mask_enc) 51 | 52 | 53 | class TransformerEnsemble(CaptioningModel): 54 | def __init__(self, model: Transformer, weight_files): 55 | super(TransformerEnsemble, self).__init__() 56 | self.n = len(weight_files) 57 | self.models = ModuleList([copy.deepcopy(model) for _ in range(self.n)]) 58 | for i in range(self.n): 59 | state_dict_i = torch.load(weight_files[i])['state_dict'] 60 | self.models[i].load_state_dict(state_dict_i) 61 | 62 | def step(self, t, prev_output, visual, seq, mode='teacher_forcing', **kwargs): 63 | out_ensemble = [] 64 | for i in range(self.n): 65 | out_i = self.models[i].step(t, prev_output, visual, seq, mode, **kwargs) 66 | out_ensemble.append(out_i.unsqueeze(0)) 67 | 68 | return torch.mean(torch.cat(out_ensemble, 0), dim=0) 69 | -------------------------------------------------------------------------------- /models/captioning_model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import distributions 3 | import utils 4 | from models.containers import Module 5 | from models.beam_search import * 6 | 7 | 8 | class CaptioningModel(Module): 9 | def __init__(self): 10 | super(CaptioningModel, self).__init__() 11 | 12 | def init_weights(self): 13 | raise NotImplementedError 14 | 15 | def step(self, t, prev_output, visual, seq, mode='teacher_forcing', **kwargs): 16 | raise NotImplementedError 17 | 18 | def forward(self, images, seq, *args): 19 | device = images.device 20 | b_s = images.size(0) 21 | seq_len = seq.size(1) 22 | state = self.init_state(b_s, device) 23 | out = None 24 | 25 | outputs = [] 26 | for t in range(seq_len): 27 | out, state = self.step(t, state, out, images, seq, *args, mode='teacher_forcing') 28 | outputs.append(out) 29 | 30 | outputs = torch.cat([o.unsqueeze(1) for o in outputs], 1) 31 | return outputs 32 | 33 | def test(self, visual: utils.TensorOrSequence, max_len: int, eos_idx: int, **kwargs) -> utils.Tuple[torch.Tensor, torch.Tensor]: 34 | b_s = utils.get_batch_size(visual) 35 | device = utils.get_device(visual) 36 | outputs = [] 37 | log_probs = [] 38 | 39 | mask = torch.ones((b_s,), device=device) 40 | with self.statefulness(b_s): 41 | out = None 42 | for t in range(max_len): 43 | log_probs_t = self.step(t, out, visual, None, mode='feedback', **kwargs) 44 | out = torch.max(log_probs_t, -1)[1] 45 | mask = mask * (out.squeeze(-1) != eos_idx).float() 46 | log_probs.append(log_probs_t * mask.unsqueeze(-1).unsqueeze(-1)) 47 | outputs.append(out) 48 | 49 | return torch.cat(outputs, 1), torch.cat(log_probs, 1) 50 | 51 | def sample_rl(self, 
visual: utils.TensorOrSequence, max_len: int, **kwargs) -> utils.Tuple[torch.Tensor, torch.Tensor]: 52 | b_s = utils.get_batch_size(visual) 53 | outputs = [] 54 | log_probs = [] 55 | 56 | with self.statefulness(b_s): 57 | out = None 58 | for t in range(max_len): 59 | out = self.step(t, out, visual, None, mode='feedback', **kwargs) 60 | distr = distributions.Categorical(logits=out[:, 0]) 61 | out = distr.sample().unsqueeze(1) 62 | outputs.append(out) 63 | log_probs.append(distr.log_prob(out).unsqueeze(1)) 64 | 65 | return torch.cat(outputs, 1), torch.cat(log_probs, 1) 66 | 67 | def beam_search(self, visual: utils.TensorOrSequence, max_len: int, eos_idx: int, beam_size: int, out_size=1, 68 | return_probs=False, **kwargs): 69 | bs = BeamSearch(self, max_len, eos_idx, beam_size) 70 | return bs.apply(visual, out_size, return_probs, **kwargs) 71 | -------------------------------------------------------------------------------- /models/containers.py: -------------------------------------------------------------------------------- 1 | from contextlib import contextmanager 2 | from torch import nn 3 | from utils.typing import * 4 | 5 | 6 | class Module(nn.Module): 7 | def __init__(self): 8 | super(Module, self).__init__() 9 | self._is_stateful = False 10 | self._state_names = [] 11 | self._state_defaults = dict() 12 | 13 | def register_state(self, name: str, default: TensorOrNone): 14 | self._state_names.append(name) 15 | if default is None: 16 | self._state_defaults[name] = None 17 | else: 18 | self._state_defaults[name] = default.clone().detach() 19 | self.register_buffer(name, default) 20 | 21 | def states(self): 22 | for name in self._state_names: 23 | yield self._buffers[name] 24 | for m in self.children(): 25 | if isinstance(m, Module): 26 | yield from m.states() 27 | 28 | def apply_to_states(self, fn): 29 | for name in self._state_names: 30 | self._buffers[name] = fn(self._buffers[name]) 31 | for m in self.children(): 32 | if isinstance(m, Module): 33 | m.apply_to_states(fn) 34 | 35 | def _init_states(self, batch_size: int): 36 | for name in self._state_names: 37 | if self._state_defaults[name] is None: 38 | self._buffers[name] = None 39 | else: 40 | self._buffers[name] = self._state_defaults[name].clone().detach().to(self._buffers[name].device) 41 | self._buffers[name] = self._buffers[name].unsqueeze(0) 42 | self._buffers[name] = self._buffers[name].expand([batch_size, ] + list(self._buffers[name].shape[1:])) 43 | self._buffers[name] = self._buffers[name].contiguous() 44 | 45 | def _reset_states(self): 46 | for name in self._state_names: 47 | if self._state_defaults[name] is None: 48 | self._buffers[name] = None 49 | else: 50 | self._buffers[name] = self._state_defaults[name].clone().detach().to(self._buffers[name].device) 51 | 52 | def enable_statefulness(self, batch_size: int): 53 | for m in self.children(): 54 | if isinstance(m, Module): 55 | m.enable_statefulness(batch_size) 56 | self._init_states(batch_size) 57 | self._is_stateful = True 58 | 59 | def disable_statefulness(self): 60 | for m in self.children(): 61 | if isinstance(m, Module): 62 | m.disable_statefulness() 63 | self._reset_states() 64 | self._is_stateful = False 65 | 66 | @contextmanager 67 | def statefulness(self, batch_size: int): 68 | self.enable_statefulness(batch_size) 69 | try: 70 | yield 71 | finally: 72 | self.disable_statefulness() 73 | 74 | 75 | class ModuleList(nn.ModuleList, Module): 76 | pass 77 | 78 | 79 | class ModuleDict(nn.ModuleDict, Module): 80 | pass 81 | 
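The `Module` class above is what gives the decoder its step-by-step state during generation: `register_state` declares a named buffer with a default value, and the `statefulness(batch_size)` context manager expands every registered default to the current batch on entry and restores the defaults on exit. A minimal sketch of how a subclass uses it — the `RunningSum` module here is hypothetical, for illustration only:

```
import torch
from models.containers import Module


class RunningSum(Module):
    """Hypothetical stateful module that accumulates its inputs across steps."""

    def __init__(self):
        super(RunningSum, self).__init__()
        # Same pattern as MeshedDecoder's running_seq: a default buffer that is
        # cloned and expanded to the batch size when statefulness is enabled.
        self.register_state('total', torch.zeros((1,)))

    def forward(self, x):
        if self._is_stateful:
            # Assigning to the attribute updates the registered buffer.
            self.total = self.total + x
        return self.total


m = RunningSum()
with m.statefulness(batch_size=2):   # 'total' becomes a (2, 1) tensor of zeros
    m(torch.ones(2, 1))
    out = m(torch.ones(2, 1))        # out now holds 2.0 in every entry
# on exit, 'total' is reset to its (1,) default
```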
-------------------------------------------------------------------------------- /evaluation/meteor/meteor.py: -------------------------------------------------------------------------------- 1 | # Python wrapper for METEOR implementation, by Xinlei Chen 2 | # Acknowledge Michael Denkowski for the generous discussion and help 3 | 4 | import os 5 | import subprocess 6 | import threading 7 | import tarfile 8 | from utils import download_from_url 9 | 10 | METEOR_GZ_URL = 'http://aimagelab.ing.unimore.it/speaksee/data/meteor.tgz' 11 | METEOR_JAR = 'meteor-1.5.jar' 12 | 13 | class Meteor: 14 | def __init__(self): 15 | base_path = os.path.dirname(os.path.abspath(__file__)) 16 | jar_path = os.path.join(base_path, METEOR_JAR) 17 | gz_path = os.path.join(base_path, os.path.basename(METEOR_GZ_URL)) 18 | if not os.path.isfile(jar_path): 19 | if not os.path.isfile(gz_path): 20 | download_from_url(METEOR_GZ_URL, gz_path) 21 | tar = tarfile.open(gz_path, "r") 22 | tar.extractall(path=os.path.dirname(os.path.abspath(__file__))) 23 | tar.close() 24 | os.remove(gz_path) 25 | 26 | self.meteor_cmd = ['java', '-jar', '-Xmx2G', METEOR_JAR, \ 27 | '-', '-', '-stdio', '-l', 'en', '-norm'] 28 | self.meteor_p = subprocess.Popen(self.meteor_cmd, \ 29 | cwd=os.path.dirname(os.path.abspath(__file__)), \ 30 | stdin=subprocess.PIPE, \ 31 | stdout=subprocess.PIPE, \ 32 | stderr=subprocess.PIPE) 33 | # Used to guarantee thread safety 34 | self.lock = threading.Lock() 35 | 36 | def compute_score(self, gts, res): 37 | assert(gts.keys() == res.keys()) 38 | imgIds = gts.keys() 39 | scores = [] 40 | 41 | eval_line = 'EVAL' 42 | self.lock.acquire() 43 | for i in imgIds: 44 | assert(len(res[i]) == 1) 45 | stat = self._stat(res[i][0], gts[i]) 46 | eval_line += ' ||| {}'.format(stat) 47 | 48 | self.meteor_p.stdin.write('{}\n'.format(eval_line).encode()) 49 | self.meteor_p.stdin.flush() 50 | for i in range(0,len(imgIds)): 51 | scores.append(float(self.meteor_p.stdout.readline().strip())) 52 | score = float(self.meteor_p.stdout.readline().strip()) 53 | self.lock.release() 54 | 55 | return score, scores 56 | 57 | def _stat(self, hypothesis_str, reference_list): 58 | # SCORE ||| reference 1 words ||| reference n words ||| hypothesis words 59 | hypothesis_str = hypothesis_str.replace('|||','').replace(' ',' ') 60 | score_line = ' ||| '.join(('SCORE', ' ||| '.join(reference_list), hypothesis_str)) 61 | self.meteor_p.stdin.write('{}\n'.format(score_line).encode()) 62 | self.meteor_p.stdin.flush() 63 | raw = self.meteor_p.stdout.readline().decode().strip() 64 | numbers = [str(int(float(n))) for n in raw.split()] 65 | return ' '.join(numbers) 66 | 67 | def __del__(self): 68 | self.lock.acquire() 69 | self.meteor_p.stdin.close() 70 | self.meteor_p.kill() 71 | self.meteor_p.wait() 72 | self.lock.release() 73 | 74 | def __str__(self): 75 | return 'METEOR' 76 | -------------------------------------------------------------------------------- /data/utils.py: -------------------------------------------------------------------------------- 1 | import contextlib, sys 2 | 3 | class DummyFile(object): 4 | def write(self, x): pass 5 | 6 | @contextlib.contextmanager 7 | def nostdout(): 8 | save_stdout = sys.stdout 9 | sys.stdout = DummyFile() 10 | yield 11 | sys.stdout = save_stdout 12 | 13 | def reporthook(t): 14 | """https://github.com/tqdm/tqdm""" 15 | last_b = [0] 16 | 17 | def inner(b=1, bsize=1, tsize=None): 18 | """ 19 | b: int, optionala 20 | Number of blocks just transferred [default: 1]. 
21 | bsize: int, optional 22 | Size of each block (in tqdm units) [default: 1]. 23 | tsize: int, optional 24 | Total size (in tqdm units). If [default: None] remains unchanged. 25 | """ 26 | if tsize is not None: 27 | t.total = tsize 28 | t.update((b - last_b[0]) * bsize) 29 | last_b[0] = b 30 | return inner 31 | 32 | def get_tokenizer(tokenizer): 33 | if callable(tokenizer): 34 | return tokenizer 35 | if tokenizer == "spacy": 36 | try: 37 | import spacy 38 | spacy_en = spacy.load('en') 39 | return lambda s: [tok.text for tok in spacy_en.tokenizer(s)] 40 | except ImportError: 41 | print("Please install SpaCy and the SpaCy English tokenizer. " 42 | "See the docs at https://spacy.io for more information.") 43 | raise 44 | except AttributeError: 45 | print("Please install SpaCy and the SpaCy English tokenizer. " 46 | "See the docs at https://spacy.io for more information.") 47 | raise 48 | elif tokenizer == "moses": 49 | try: 50 | from nltk.tokenize.moses import MosesTokenizer 51 | moses_tokenizer = MosesTokenizer() 52 | return moses_tokenizer.tokenize 53 | except ImportError: 54 | print("Please install NLTK. " 55 | "See the docs at http://nltk.org for more information.") 56 | raise 57 | except LookupError: 58 | print("Please install the necessary NLTK corpora. " 59 | "See the docs at http://nltk.org for more information.") 60 | raise 61 | elif tokenizer == 'revtok': 62 | try: 63 | import revtok 64 | return revtok.tokenize 65 | except ImportError: 66 | print("Please install revtok.") 67 | raise 68 | elif tokenizer == 'subword': 69 | try: 70 | import revtok 71 | return lambda x: revtok.tokenize(x, decap=True) 72 | except ImportError: 73 | print("Please install revtok.") 74 | raise 75 | raise ValueError("Requested tokenizer {}, valid choices are a " 76 | "callable that takes a single string as input, " 77 | "\"revtok\" for the revtok reversible tokenizer, " 78 | "\"subword\" for the revtok caps-aware tokenizer, " 79 | "\"spacy\" for the SpaCy English tokenizer, or " 80 | "\"moses\" for the NLTK port of the Moses tokenization " 81 | "script.".format(tokenizer)) 82 | -------------------------------------------------------------------------------- /models/transformer/encoders.py: -------------------------------------------------------------------------------- 1 | from torch.nn import functional as F 2 | from models.transformer.utils import PositionWiseFeedForward 3 | import torch 4 | from torch import nn 5 | from models.transformer.attention import MultiHeadAttention 6 | 7 | 8 | class EncoderLayer(nn.Module): 9 | def __init__(self, d_model=512, d_k=64, d_v=64, h=8, d_ff=2048, dropout=.1, identity_map_reordering=False, 10 | attention_module=None, attention_module_kwargs=None): 11 | super(EncoderLayer, self).__init__() 12 | self.identity_map_reordering = identity_map_reordering 13 | self.mhatt = MultiHeadAttention(d_model, d_k, d_v, h, dropout, identity_map_reordering=identity_map_reordering, 14 | attention_module=attention_module, 15 | attention_module_kwargs=attention_module_kwargs) 16 | self.pwff = PositionWiseFeedForward(d_model, d_ff, dropout, identity_map_reordering=identity_map_reordering) 17 | 18 | def forward(self, queries, keys, values, attention_mask=None, attention_weights=None): 19 | att = self.mhatt(queries, keys, values, attention_mask, attention_weights) 20 | ff = self.pwff(att) 21 | return ff 22 | 23 | 24 | class MultiLevelEncoder(nn.Module): 25 | def __init__(self, N, padding_idx, d_model=512, d_k=64, d_v=64, h=8, d_ff=2048, dropout=.1, 26 | identity_map_reordering=False, 
attention_module=None, attention_module_kwargs=None): 27 | super(MultiLevelEncoder, self).__init__() 28 | self.d_model = d_model 29 | self.dropout = dropout 30 | self.layers = nn.ModuleList([EncoderLayer(d_model, d_k, d_v, h, d_ff, dropout, 31 | identity_map_reordering=identity_map_reordering, 32 | attention_module=attention_module, 33 | attention_module_kwargs=attention_module_kwargs) 34 | for _ in range(N)]) 35 | self.padding_idx = padding_idx 36 | 37 | def forward(self, input, attention_weights=None): 38 | # input (b_s, seq_len, d_in) 39 | attention_mask = (torch.sum(input, -1) == self.padding_idx).unsqueeze(1).unsqueeze(1) # (b_s, 1, 1, seq_len) 40 | 41 | outs = [] 42 | out = input 43 | for l in self.layers: 44 | out = l(out, out, out, attention_mask, attention_weights) 45 | outs.append(out.unsqueeze(1)) 46 | 47 | outs = torch.cat(outs, 1) 48 | return outs, attention_mask 49 | 50 | 51 | class MemoryAugmentedEncoder(MultiLevelEncoder): 52 | def __init__(self, N, padding_idx, d_in=2048, **kwargs): 53 | super(MemoryAugmentedEncoder, self).__init__(N, padding_idx, **kwargs) 54 | self.fc = nn.Linear(d_in, self.d_model) 55 | self.dropout = nn.Dropout(p=self.dropout) 56 | self.layer_norm = nn.LayerNorm(self.d_model) 57 | 58 | def forward(self, input, attention_weights=None): 59 | out = F.relu(self.fc(input)) 60 | out = self.dropout(out) 61 | out = self.layer_norm(out) 62 | return super(MemoryAugmentedEncoder, self).forward(out, attention_weights=attention_weights) 63 | -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | import random 2 | from data import ImageDetectionsField, TextField, RawField 3 | from data import COCO, DataLoader 4 | import evaluation 5 | from models.transformer import Transformer, MemoryAugmentedEncoder, MeshedDecoder, ScaledDotProductAttentionMemory 6 | import torch 7 | from tqdm import tqdm 8 | import argparse 9 | import pickle 10 | import numpy as np 11 | 12 | random.seed(1234) 13 | torch.manual_seed(1234) 14 | np.random.seed(1234) 15 | 16 | 17 | def predict_captions(model, dataloader, text_field): 18 | import itertools 19 | model.eval() 20 | gen = {} 21 | gts = {} 22 | with tqdm(desc='Evaluation', unit='it', total=len(dataloader)) as pbar: 23 | for it, (images, caps_gt) in enumerate(iter(dataloader)): 24 | images = images.to(device) 25 | with torch.no_grad(): 26 | out, _ = model.beam_search(images, 20, text_field.vocab.stoi[''], 5, out_size=1) 27 | caps_gen = text_field.decode(out, join_words=False) 28 | for i, (gts_i, gen_i) in enumerate(zip(caps_gt, caps_gen)): 29 | gen_i = ' '.join([k for k, g in itertools.groupby(gen_i)]) 30 | gen['%d_%d' % (it, i)] = [gen_i.strip(), ] 31 | gts['%d_%d' % (it, i)] = gts_i 32 | pbar.update() 33 | 34 | gts = evaluation.PTBTokenizer.tokenize(gts) 35 | gen = evaluation.PTBTokenizer.tokenize(gen) 36 | scores, _ = evaluation.compute_scores(gts, gen) 37 | 38 | return scores 39 | 40 | 41 | if __name__ == '__main__': 42 | device = torch.device('cuda') 43 | 44 | parser = argparse.ArgumentParser(description='Meshed-Memory Transformer') 45 | parser.add_argument('--batch_size', type=int, default=10) 46 | parser.add_argument('--workers', type=int, default=0) 47 | parser.add_argument('--features_path', type=str) 48 | parser.add_argument('--annotation_folder', type=str) 49 | args = parser.parse_args() 50 | 51 | print('Meshed-Memory Transformer Evaluation') 52 | 53 | # Pipeline for image regions 54 | image_field = 
ImageDetectionsField(detections_path=args.features_path, max_detections=50, load_in_tmp=False) 55 | 56 | # Pipeline for text 57 | text_field = TextField(init_token='', eos_token='', lower=True, tokenize='spacy', 58 | remove_punctuation=True, nopoints=False) 59 | 60 | # Create the dataset 61 | dataset = COCO(image_field, text_field, 'coco/images/', args.annotation_folder, args.annotation_folder) 62 | _, _, test_dataset = dataset.splits 63 | text_field.vocab = pickle.load(open('vocab.pkl', 'rb')) 64 | 65 | # Model and dataloaders 66 | encoder = MemoryAugmentedEncoder(3, 0, attention_module=ScaledDotProductAttentionMemory, 67 | attention_module_kwargs={'m': 40}) 68 | decoder = MeshedDecoder(len(text_field.vocab), 54, 3, text_field.vocab.stoi['']) 69 | model = Transformer(text_field.vocab.stoi[''], encoder, decoder).to(device) 70 | 71 | data = torch.load('meshed_memory_transformer.pth') 72 | model.load_state_dict(data['state_dict']) 73 | 74 | dict_dataset_test = test_dataset.image_dictionary({'image': image_field, 'text': RawField()}) 75 | dict_dataloader_test = DataLoader(dict_dataset_test, batch_size=args.batch_size, num_workers=args.workers) 76 | 77 | scores = predict_captions(model, dict_dataloader_test, text_field) 78 | print(scores) 79 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # M²: Meshed-Memory Transformer 2 | This repository contains the reference code for the paper _[Meshed-Memory Transformer for Image Captioning](https://arxiv.org/abs/1912.08226)_ (CVPR 2020). 3 | 4 | Please cite with the following BibTeX: 5 | 6 | ``` 7 | @inproceedings{cornia2020m2, 8 | title={{Meshed-Memory Transformer for Image Captioning}}, 9 | author={Cornia, Marcella and Stefanini, Matteo and Baraldi, Lorenzo and Cucchiara, Rita}, 10 | booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition}, 11 | year={2020} 12 | } 13 | ``` 14 |

15 | <img src="images/m2.png" alt="Meshed-Memory Transformer"/> 16 |
17 | 18 | ## Environment setup 19 | Clone the repository and create the `m2release` conda environment using the `environment.yml` file: 20 | ``` 21 | conda env create -f environment.yml 22 | conda activate m2release 23 | ``` 24 | 25 | Then download spacy data by executing the following command: 26 | ``` 27 | python -m spacy download en 28 | ``` 29 | 30 | Note: Python 3.6 is required to run our code. 31 | 32 | 33 | ## Data preparation 34 | To run the code, annotations and detection features for the COCO dataset are needed. Please download the annotations file [annotations.zip](https://ailb-web.ing.unimore.it/publicfiles/drive/meshed-memory-transformer/annotations.zip) and extract it. 35 | 36 | Detection features are computed with the code provided by [1]. To reproduce our result, please download the COCO features file [coco_detections.hdf5](https://ailb-web.ing.unimore.it/publicfiles/drive/show-control-and-tell/coco_detections.hdf5) (~53.5 GB), in which detections of each image are stored under the `_features` key. `` is the id of each COCO image, without leading zeros (e.g. the `` for `COCO_val2014_000000037209.jpg` is `37209`), and each value should be a `(N, 2048)` tensor, where `N` is the number of detections. 37 | 38 | 39 | ## Evaluation 40 | To reproduce the results reported in our paper, download the pretrained model file [meshed_memory_transformer.pth](https://ailb-web.ing.unimore.it/publicfiles/drive/meshed-memory-transformer/meshed_memory_transformer.pth) and place it in the code folder. 41 | 42 | Run `python test.py` using the following arguments: 43 | 44 | | Argument | Possible values | 45 | |------|------| 46 | | `--batch_size` | Batch size (default: 10) | 47 | | `--workers` | Number of workers (default: 0) | 48 | | `--features_path` | Path to detection features file | 49 | | `--annotation_folder` | Path to folder with COCO annotations | 50 | 51 | #### Expected output 52 | Under `output_logs/`, you may also find the expected output of the evaluation code. 53 | 54 | 55 | ## Training procedure 56 | Run `python train.py` using the following arguments: 57 | 58 | | Argument | Possible values | 59 | |------|------| 60 | | `--exp_name` | Experiment name| 61 | | `--batch_size` | Batch size (default: 10) | 62 | | `--workers` | Number of workers (default: 0) | 63 | | `--m` | Number of memory vectors (default: 40) | 64 | | `--head` | Number of heads (default: 8) | 65 | | `--warmup` | Warmup value for learning rate scheduling (default: 10000) | 66 | | `--resume_last` | If used, the training will be resumed from the last checkpoint. | 67 | | `--resume_best` | If used, the training will be resumed from the best checkpoint. | 68 | | `--features_path` | Path to detection features file | 69 | | `--annotation_folder` | Path to folder with COCO annotations | 70 | | `--logs_folder` | Path folder for tensorboard logs (default: "tensorboard_logs")| 71 | 72 | For example, to train our model with the parameters used in our experiments, use 73 | ``` 74 | python train.py --exp_name m2_transformer --batch_size 50 --m 40 --head 8 --warmup 10000 --features_path /path/to/features --annotation_folder /path/to/annotations 75 | ``` 76 | 77 |
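Before launching training or evaluation, it can help to sanity-check the detection features file described in the Data preparation section: each image's regions are stored under the `<image_id>_features` key as an `(N, 2048)` array. A minimal sketch using `h5py` (which `environment.yml` installs); the path is a placeholder, and the example id `37209` is the one mentioned above, assuming it is present in your file:

```
import h5py

# Placeholder path; point this at the downloaded coco_detections.hdf5
with h5py.File('/path/to/coco_detections.hdf5', 'r') as f:
    # Detections for COCO_val2014_000000037209.jpg live under '37209_features'
    feats = f['37209_features'][()]
    print(feats.shape)  # expected (N, 2048): one 2048-d vector per detected region
```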

78 | <img src="images/results.png" alt="Sample Results"/> 79 |
80 | 81 | #### References 82 | [1] P. Anderson, X. He, C. Buehler, D. Teney, M. Johnson, S. Gould, and L. Zhang. Bottom-up and top-down attention for image captioning and visual question answering. In _Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition_, 2018. 83 | -------------------------------------------------------------------------------- /evaluation/rouge/rouge.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # 3 | # File Name : rouge.py 4 | # 5 | # Description : Computes ROUGE-L metric as described by Lin and Hovey (2004) 6 | # 7 | # Creation Date : 2015-01-07 06:03 8 | # Author : Ramakrishna Vedantam 9 | 10 | import numpy as np 11 | import pdb 12 | 13 | 14 | def my_lcs(string, sub): 15 | """ 16 | Calculates longest common subsequence for a pair of tokenized strings 17 | :param string : list of str : tokens from a string split using whitespace 18 | :param sub : list of str : shorter string, also split using whitespace 19 | :returns: length (list of int): length of the longest common subsequence between the two strings 20 | 21 | Note: my_lcs only gives length of the longest common subsequence, not the actual LCS 22 | """ 23 | if (len(string) < len(sub)): 24 | sub, string = string, sub 25 | 26 | lengths = [[0 for i in range(0, len(sub) + 1)] for j in range(0, len(string) + 1)] 27 | 28 | for j in range(1, len(sub) + 1): 29 | for i in range(1, len(string) + 1): 30 | if (string[i - 1] == sub[j - 1]): 31 | lengths[i][j] = lengths[i - 1][j - 1] + 1 32 | else: 33 | lengths[i][j] = max(lengths[i - 1][j], lengths[i][j - 1]) 34 | 35 | return lengths[len(string)][len(sub)] 36 | 37 | 38 | class Rouge(): 39 | ''' 40 | Class for computing ROUGE-L score for a set of candidate sentences for the MS COCO test set 41 | 42 | ''' 43 | 44 | def __init__(self): 45 | # vrama91: updated the value below based on discussion with Hovey 46 | self.beta = 1.2 47 | 48 | def calc_score(self, candidate, refs): 49 | """ 50 | Compute ROUGE-L score given one candidate and references for an image 51 | :param candidate: str : candidate sentence to be evaluated 52 | :param refs: list of str : COCO reference sentences for the particular image to be evaluated 53 | :returns score: int (ROUGE-L score for the candidate evaluated against references) 54 | """ 55 | assert (len(candidate) == 1) 56 | assert (len(refs) > 0) 57 | prec = [] 58 | rec = [] 59 | 60 | # split into tokens 61 | token_c = candidate[0].split(" ") 62 | 63 | for reference in refs: 64 | # split into tokens 65 | token_r = reference.split(" ") 66 | # compute the longest common subsequence 67 | lcs = my_lcs(token_r, token_c) 68 | prec.append(lcs / float(len(token_c))) 69 | rec.append(lcs / float(len(token_r))) 70 | 71 | prec_max = max(prec) 72 | rec_max = max(rec) 73 | 74 | if (prec_max != 0 and rec_max != 0): 75 | score = ((1 + self.beta ** 2) * prec_max * rec_max) / float(rec_max + self.beta ** 2 * prec_max) 76 | else: 77 | score = 0.0 78 | return score 79 | 80 | def compute_score(self, gts, res): 81 | """ 82 | Computes Rouge-L score given a set of reference and candidate sentences for the dataset 83 | Invoked by evaluate_captions.py 84 | :param hypo_for_image: dict : candidate / test sentences with "image name" key and "tokenized sentences" as values 85 | :param ref_for_image: dict : reference MS-COCO sentences with "image name" key and "tokenized sentences" as values 86 | :returns: average_score: float (mean ROUGE-L score computed by averaging scores for all the images) 
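Each per-image score is the ROUGE-L F-measure computed by calc_score with beta = 1.2, using the best precision and recall over that image's references.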
87 | """ 88 | assert (gts.keys() == res.keys()) 89 | imgIds = gts.keys() 90 | 91 | score = [] 92 | for id in imgIds: 93 | hypo = res[id] 94 | ref = gts[id] 95 | 96 | score.append(self.calc_score(hypo, ref)) 97 | 98 | # Sanity check. 99 | assert (type(hypo) is list) 100 | assert (len(hypo) == 1) 101 | assert (type(ref) is list) 102 | assert (len(ref) > 0) 103 | 104 | average_score = np.mean(np.array(score)) 105 | return average_score, np.array(score) 106 | 107 | def __str__(self): 108 | return 'ROUGE' 109 | -------------------------------------------------------------------------------- /models/transformer/decoders.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from torch.nn import functional as F 4 | import numpy as np 5 | 6 | from models.transformer.attention import MultiHeadAttention 7 | from models.transformer.utils import sinusoid_encoding_table, PositionWiseFeedForward 8 | from models.containers import Module, ModuleList 9 | 10 | 11 | class MeshedDecoderLayer(Module): 12 | def __init__(self, d_model=512, d_k=64, d_v=64, h=8, d_ff=2048, dropout=.1, self_att_module=None, 13 | enc_att_module=None, self_att_module_kwargs=None, enc_att_module_kwargs=None): 14 | super(MeshedDecoderLayer, self).__init__() 15 | self.self_att = MultiHeadAttention(d_model, d_k, d_v, h, dropout, can_be_stateful=True, 16 | attention_module=self_att_module, 17 | attention_module_kwargs=self_att_module_kwargs) 18 | self.enc_att = MultiHeadAttention(d_model, d_k, d_v, h, dropout, can_be_stateful=False, 19 | attention_module=enc_att_module, 20 | attention_module_kwargs=enc_att_module_kwargs) 21 | self.pwff = PositionWiseFeedForward(d_model, d_ff, dropout) 22 | 23 | self.fc_alpha1 = nn.Linear(d_model + d_model, d_model) 24 | self.fc_alpha2 = nn.Linear(d_model + d_model, d_model) 25 | self.fc_alpha3 = nn.Linear(d_model + d_model, d_model) 26 | 27 | self.init_weights() 28 | 29 | def init_weights(self): 30 | nn.init.xavier_uniform_(self.fc_alpha1.weight) 31 | nn.init.xavier_uniform_(self.fc_alpha2.weight) 32 | nn.init.xavier_uniform_(self.fc_alpha3.weight) 33 | nn.init.constant_(self.fc_alpha1.bias, 0) 34 | nn.init.constant_(self.fc_alpha2.bias, 0) 35 | nn.init.constant_(self.fc_alpha3.bias, 0) 36 | 37 | def forward(self, input, enc_output, mask_pad, mask_self_att, mask_enc_att): 38 | self_att = self.self_att(input, input, input, mask_self_att) 39 | self_att = self_att * mask_pad 40 | 41 | enc_att1 = self.enc_att(self_att, enc_output[:, 0], enc_output[:, 0], mask_enc_att) * mask_pad 42 | enc_att2 = self.enc_att(self_att, enc_output[:, 1], enc_output[:, 1], mask_enc_att) * mask_pad 43 | enc_att3 = self.enc_att(self_att, enc_output[:, 2], enc_output[:, 2], mask_enc_att) * mask_pad 44 | 45 | alpha1 = torch.sigmoid(self.fc_alpha1(torch.cat([self_att, enc_att1], -1))) 46 | alpha2 = torch.sigmoid(self.fc_alpha2(torch.cat([self_att, enc_att2], -1))) 47 | alpha3 = torch.sigmoid(self.fc_alpha3(torch.cat([self_att, enc_att3], -1))) 48 | 49 | enc_att = (enc_att1 * alpha1 + enc_att2 * alpha2 + enc_att3 * alpha3) / np.sqrt(3) 50 | enc_att = enc_att * mask_pad 51 | 52 | ff = self.pwff(enc_att) 53 | ff = ff * mask_pad 54 | return ff 55 | 56 | 57 | class MeshedDecoder(Module): 58 | def __init__(self, vocab_size, max_len, N_dec, padding_idx, d_model=512, d_k=64, d_v=64, h=8, d_ff=2048, dropout=.1, 59 | self_att_module=None, enc_att_module=None, self_att_module_kwargs=None, enc_att_module_kwargs=None): 60 | super(MeshedDecoder, self).__init__() 61 | 
self.d_model = d_model 62 | self.word_emb = nn.Embedding(vocab_size, d_model, padding_idx=padding_idx) 63 | self.pos_emb = nn.Embedding.from_pretrained(sinusoid_encoding_table(max_len + 1, d_model, 0), freeze=True) 64 | self.layers = ModuleList( 65 | [MeshedDecoderLayer(d_model, d_k, d_v, h, d_ff, dropout, self_att_module=self_att_module, 66 | enc_att_module=enc_att_module, self_att_module_kwargs=self_att_module_kwargs, 67 | enc_att_module_kwargs=enc_att_module_kwargs) for _ in range(N_dec)]) 68 | self.fc = nn.Linear(d_model, vocab_size, bias=False) 69 | self.max_len = max_len 70 | self.padding_idx = padding_idx 71 | self.N = N_dec 72 | 73 | self.register_state('running_mask_self_attention', torch.zeros((1, 1, 0)).byte()) 74 | self.register_state('running_seq', torch.zeros((1,)).long()) 75 | 76 | def forward(self, input, encoder_output, mask_encoder): 77 | # input (b_s, seq_len) 78 | b_s, seq_len = input.shape[:2] 79 | mask_queries = (input != self.padding_idx).unsqueeze(-1).float() # (b_s, seq_len, 1) 80 | mask_self_attention = torch.triu(torch.ones((seq_len, seq_len), dtype=torch.uint8, device=input.device), 81 | diagonal=1) 82 | mask_self_attention = mask_self_attention.unsqueeze(0).unsqueeze(0) # (1, 1, seq_len, seq_len) 83 | mask_self_attention = mask_self_attention + (input == self.padding_idx).unsqueeze(1).unsqueeze(1).byte() 84 | mask_self_attention = mask_self_attention.gt(0) # (b_s, 1, seq_len, seq_len) 85 | if self._is_stateful: 86 | self.running_mask_self_attention = torch.cat([self.running_mask_self_attention, mask_self_attention], -1) 87 | mask_self_attention = self.running_mask_self_attention 88 | 89 | seq = torch.arange(1, seq_len + 1).view(1, -1).expand(b_s, -1).to(input.device) # (b_s, seq_len) 90 | seq = seq.masked_fill(mask_queries.squeeze(-1) == 0, 0) 91 | if self._is_stateful: 92 | self.running_seq.add_(1) 93 | seq = self.running_seq 94 | 95 | out = self.word_emb(input) + self.pos_emb(seq) 96 | for i, l in enumerate(self.layers): 97 | out = l(out, encoder_output, mask_queries, mask_self_attention, mask_encoder) 98 | 99 | out = self.fc(out) 100 | return F.log_softmax(out, dim=-1) 101 | -------------------------------------------------------------------------------- /evaluation/cider/cider_scorer.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # Tsung-Yi Lin 3 | # Ramakrishna Vedantam 4 | 5 | import copy 6 | from collections import defaultdict 7 | import numpy as np 8 | import math 9 | 10 | def precook(s, n=4): 11 | """ 12 | Takes a string as input and returns an object that can be given to 13 | either cook_refs or cook_test. This is optional: cook_refs and cook_test 14 | can take string arguments as well. 15 | :param s: string : sentence to be converted into ngrams 16 | :param n: int : number of ngrams for which representation is calculated 17 | :return: term frequency vector for occuring ngrams 18 | """ 19 | words = s.split() 20 | counts = defaultdict(int) 21 | for k in range(1,n+1): 22 | for i in range(len(words)-k+1): 23 | ngram = tuple(words[i:i+k]) 24 | counts[ngram] += 1 25 | return counts 26 | 27 | def cook_refs(refs, n=4): ## lhuang: oracle will call with "average" 28 | '''Takes a list of reference sentences for a single segment 29 | and returns an object that encapsulates everything that BLEU 30 | needs to know about them. 
31 | :param refs: list of string : reference sentences for some image 32 | :param n: int : number of ngrams for which (ngram) representation is calculated 33 | :return: result (list of dict) 34 | ''' 35 | return [precook(ref, n) for ref in refs] 36 | 37 | def cook_test(test, n=4): 38 | '''Takes a test sentence and returns an object that 39 | encapsulates everything that BLEU needs to know about it. 40 | :param test: list of string : hypothesis sentence for some image 41 | :param n: int : number of ngrams for which (ngram) representation is calculated 42 | :return: result (dict) 43 | ''' 44 | return precook(test, n) 45 | 46 | class CiderScorer(object): 47 | """CIDEr scorer. 48 | """ 49 | 50 | def __init__(self, refs, test=None, n=4, sigma=6.0, doc_frequency=None, ref_len=None): 51 | ''' singular instance ''' 52 | self.n = n 53 | self.sigma = sigma 54 | self.crefs = [] 55 | self.ctest = [] 56 | self.doc_frequency = defaultdict(float) 57 | self.ref_len = None 58 | 59 | for k in refs.keys(): 60 | self.crefs.append(cook_refs(refs[k])) 61 | if test is not None: 62 | self.ctest.append(cook_test(test[k][0])) ## N.B.: -1 63 | else: 64 | self.ctest.append(None) # lens of crefs and ctest have to match 65 | 66 | if doc_frequency is None and ref_len is None: 67 | # compute idf 68 | self.compute_doc_freq() 69 | # compute log reference length 70 | self.ref_len = np.log(float(len(self.crefs))) 71 | else: 72 | self.doc_frequency = doc_frequency 73 | self.ref_len = ref_len 74 | 75 | def compute_doc_freq(self): 76 | ''' 77 | Compute term frequency for reference data. 78 | This will be used to compute idf (inverse document frequency later) 79 | The term frequency is stored in the object 80 | :return: None 81 | ''' 82 | for refs in self.crefs: 83 | # refs, k ref captions of one image 84 | for ngram in set([ngram for ref in refs for (ngram,count) in ref.items()]): 85 | self.doc_frequency[ngram] += 1 86 | # maxcounts[ngram] = max(maxcounts.get(ngram,0), count) 87 | 88 | def compute_cider(self): 89 | def counts2vec(cnts): 90 | """ 91 | Function maps counts of ngram to vector of tfidf weights. 92 | The function returns vec, an array of dictionary that store mapping of n-gram and tf-idf weights. 93 | The n-th entry of array denotes length of n-grams. 94 | :param cnts: 95 | :return: vec (array of dict), norm (array of float), length (int) 96 | """ 97 | vec = [defaultdict(float) for _ in range(self.n)] 98 | length = 0 99 | norm = [0.0 for _ in range(self.n)] 100 | for (ngram,term_freq) in cnts.items(): 101 | # give word count 1 if it doesn't appear in reference corpus 102 | df = np.log(max(1.0, self.doc_frequency[ngram])) 103 | # ngram index 104 | n = len(ngram)-1 105 | # tf (term_freq) * idf (precomputed idf) for n-grams 106 | vec[n][ngram] = float(term_freq)*(self.ref_len - df) 107 | # compute norm for the vector. the norm will be used for computing similarity 108 | norm[n] += pow(vec[n][ngram], 2) 109 | 110 | if n == 1: 111 | length += term_freq 112 | norm = [np.sqrt(n) for n in norm] 113 | return vec, norm, length 114 | 115 | def sim(vec_hyp, vec_ref, norm_hyp, norm_ref, length_hyp, length_ref): 116 | ''' 117 | Compute the cosine similarity of two vectors. 
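For each n-gram order n, the score is the sum over shared n-grams of min(hyp_tfidf, ref_tfidf) * ref_tfidf, divided by the product of the two vector norms and scaled by the Gaussian length penalty exp(-(length_hyp - length_ref)^2 / (2 * sigma^2)).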
118 | :param vec_hyp: array of dictionary for vector corresponding to hypothesis 119 | :param vec_ref: array of dictionary for vector corresponding to reference 120 | :param norm_hyp: array of float for vector corresponding to hypothesis 121 | :param norm_ref: array of float for vector corresponding to reference 122 | :param length_hyp: int containing length of hypothesis 123 | :param length_ref: int containing length of reference 124 | :return: array of score for each n-grams cosine similarity 125 | ''' 126 | delta = float(length_hyp - length_ref) 127 | # measure consine similarity 128 | val = np.array([0.0 for _ in range(self.n)]) 129 | for n in range(self.n): 130 | # ngram 131 | for (ngram,count) in vec_hyp[n].items(): 132 | # vrama91 : added clipping 133 | val[n] += min(vec_hyp[n][ngram], vec_ref[n][ngram]) * vec_ref[n][ngram] 134 | 135 | if (norm_hyp[n] != 0) and (norm_ref[n] != 0): 136 | val[n] /= (norm_hyp[n]*norm_ref[n]) 137 | 138 | assert(not math.isnan(val[n])) 139 | # vrama91: added a length based gaussian penalty 140 | val[n] *= np.e**(-(delta**2)/(2*self.sigma**2)) 141 | return val 142 | 143 | scores = [] 144 | for test, refs in zip(self.ctest, self.crefs): 145 | # compute vector for test captions 146 | vec, norm, length = counts2vec(test) 147 | # compute vector for ref captions 148 | score = np.array([0.0 for _ in range(self.n)]) 149 | for ref in refs: 150 | vec_ref, norm_ref, length_ref = counts2vec(ref) 151 | score += sim(vec, vec_ref, norm, norm_ref, length, length_ref) 152 | # change by vrama91 - mean of ngram scores, instead of sum 153 | score_avg = np.mean(score) 154 | # divide by number of references 155 | score_avg /= len(refs) 156 | # multiply score by 10 157 | score_avg *= 10.0 158 | # append score of an image to the score list 159 | scores.append(score_avg) 160 | return scores 161 | 162 | def compute_score(self): 163 | # compute cider score 164 | score = self.compute_cider() 165 | # debug 166 | # print score 167 | return np.mean(np.array(score)), np.array(score) -------------------------------------------------------------------------------- /models/beam_search/beam_search.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import utils 3 | 4 | 5 | class BeamSearch(object): 6 | def __init__(self, model, max_len: int, eos_idx: int, beam_size: int): 7 | self.model = model 8 | self.max_len = max_len 9 | self.eos_idx = eos_idx 10 | self.beam_size = beam_size 11 | self.b_s = None 12 | self.device = None 13 | self.seq_mask = None 14 | self.seq_logprob = None 15 | self.outputs = None 16 | self.log_probs = None 17 | self.selected_words = None 18 | self.all_log_probs = None 19 | 20 | def _expand_state(self, selected_beam, cur_beam_size): 21 | def fn(s): 22 | shape = [int(sh) for sh in s.shape] 23 | beam = selected_beam 24 | for _ in shape[1:]: 25 | beam = beam.unsqueeze(-1) 26 | s = torch.gather(s.view(*([self.b_s, cur_beam_size] + shape[1:])), 1, 27 | beam.expand(*([self.b_s, self.beam_size] + shape[1:]))) 28 | s = s.view(*([-1, ] + shape[1:])) 29 | return s 30 | 31 | return fn 32 | 33 | def _expand_visual(self, visual: utils.TensorOrSequence, cur_beam_size: int, selected_beam: torch.Tensor): 34 | if isinstance(visual, torch.Tensor): 35 | visual_shape = visual.shape 36 | visual_exp_shape = (self.b_s, cur_beam_size) + visual_shape[1:] 37 | visual_red_shape = (self.b_s * self.beam_size,) + visual_shape[1:] 38 | selected_beam_red_size = (self.b_s, self.beam_size) + tuple(1 for _ in range(len(visual_exp_shape) - 2)) 39 | 
selected_beam_exp_size = (self.b_s, self.beam_size) + visual_exp_shape[2:] 40 | visual_exp = visual.view(visual_exp_shape) 41 | selected_beam_exp = selected_beam.view(selected_beam_red_size).expand(selected_beam_exp_size) 42 | visual = torch.gather(visual_exp, 1, selected_beam_exp).view(visual_red_shape) 43 | else: 44 | new_visual = [] 45 | for im in visual: 46 | visual_shape = im.shape 47 | visual_exp_shape = (self.b_s, cur_beam_size) + visual_shape[1:] 48 | visual_red_shape = (self.b_s * self.beam_size,) + visual_shape[1:] 49 | selected_beam_red_size = (self.b_s, self.beam_size) + tuple(1 for _ in range(len(visual_exp_shape) - 2)) 50 | selected_beam_exp_size = (self.b_s, self.beam_size) + visual_exp_shape[2:] 51 | visual_exp = im.view(visual_exp_shape) 52 | selected_beam_exp = selected_beam.view(selected_beam_red_size).expand(selected_beam_exp_size) 53 | new_im = torch.gather(visual_exp, 1, selected_beam_exp).view(visual_red_shape) 54 | new_visual.append(new_im) 55 | visual = tuple(new_visual) 56 | return visual 57 | 58 | def apply(self, visual: utils.TensorOrSequence, out_size=1, return_probs=False, **kwargs): 59 | self.b_s = utils.get_batch_size(visual) 60 | self.device = utils.get_device(visual) 61 | self.seq_mask = torch.ones((self.b_s, self.beam_size, 1), device=self.device) 62 | self.seq_logprob = torch.zeros((self.b_s, 1, 1), device=self.device) 63 | self.log_probs = [] 64 | self.selected_words = None 65 | if return_probs: 66 | self.all_log_probs = [] 67 | 68 | outputs = [] 69 | with self.model.statefulness(self.b_s): 70 | for t in range(self.max_len): 71 | visual, outputs = self.iter(t, visual, outputs, return_probs, **kwargs) 72 | 73 | # Sort result 74 | seq_logprob, sort_idxs = torch.sort(self.seq_logprob, 1, descending=True) 75 | outputs = torch.cat(outputs, -1) 76 | outputs = torch.gather(outputs, 1, sort_idxs.expand(self.b_s, self.beam_size, self.max_len)) 77 | log_probs = torch.cat(self.log_probs, -1) 78 | log_probs = torch.gather(log_probs, 1, sort_idxs.expand(self.b_s, self.beam_size, self.max_len)) 79 | if return_probs: 80 | all_log_probs = torch.cat(self.all_log_probs, 2) 81 | all_log_probs = torch.gather(all_log_probs, 1, sort_idxs.unsqueeze(-1).expand(self.b_s, self.beam_size, 82 | self.max_len, 83 | all_log_probs.shape[-1])) 84 | 85 | outputs = outputs.contiguous()[:, :out_size] 86 | log_probs = log_probs.contiguous()[:, :out_size] 87 | if out_size == 1: 88 | outputs = outputs.squeeze(1) 89 | log_probs = log_probs.squeeze(1) 90 | 91 | if return_probs: 92 | return outputs, log_probs, all_log_probs 93 | else: 94 | return outputs, log_probs 95 | 96 | def select(self, t, candidate_logprob, **kwargs): 97 | selected_logprob, selected_idx = torch.sort(candidate_logprob.view(self.b_s, -1), -1, descending=True) 98 | selected_logprob, selected_idx = selected_logprob[:, :self.beam_size], selected_idx[:, :self.beam_size] 99 | return selected_idx, selected_logprob 100 | 101 | def iter(self, t: int, visual: utils.TensorOrSequence, outputs, return_probs, **kwargs): 102 | cur_beam_size = 1 if t == 0 else self.beam_size 103 | 104 | word_logprob = self.model.step(t, self.selected_words, visual, None, mode='feedback', **kwargs) 105 | word_logprob = word_logprob.view(self.b_s, cur_beam_size, -1) 106 | candidate_logprob = self.seq_logprob + word_logprob 107 | 108 | # Mask sequence if it reaches EOS 109 | if t > 0: 110 | mask = (self.selected_words.view(self.b_s, cur_beam_size) != self.eos_idx).float().unsqueeze(-1) 111 | self.seq_mask = self.seq_mask * mask 112 | word_logprob = 
word_logprob * self.seq_mask.expand_as(word_logprob) 113 | old_seq_logprob = self.seq_logprob.expand_as(candidate_logprob).contiguous() 114 | old_seq_logprob[:, :, 1:] = -999 115 | candidate_logprob = self.seq_mask * candidate_logprob + old_seq_logprob * (1 - self.seq_mask) 116 | 117 | selected_idx, selected_logprob = self.select(t, candidate_logprob, **kwargs) 118 | selected_beam = selected_idx / candidate_logprob.shape[-1] 119 | selected_words = selected_idx - selected_beam * candidate_logprob.shape[-1] 120 | 121 | self.model.apply_to_states(self._expand_state(selected_beam, cur_beam_size)) 122 | visual = self._expand_visual(visual, cur_beam_size, selected_beam) 123 | 124 | self.seq_logprob = selected_logprob.unsqueeze(-1) 125 | self.seq_mask = torch.gather(self.seq_mask, 1, selected_beam.unsqueeze(-1)) 126 | outputs = list(torch.gather(o, 1, selected_beam.unsqueeze(-1)) for o in outputs) 127 | outputs.append(selected_words.unsqueeze(-1)) 128 | 129 | if return_probs: 130 | if t == 0: 131 | self.all_log_probs.append(word_logprob.expand((self.b_s, self.beam_size, -1)).unsqueeze(2)) 132 | else: 133 | self.all_log_probs.append(word_logprob.unsqueeze(2)) 134 | 135 | this_word_logprob = torch.gather(word_logprob, 1, 136 | selected_beam.unsqueeze(-1).expand(self.b_s, self.beam_size, 137 | word_logprob.shape[-1])) 138 | this_word_logprob = torch.gather(this_word_logprob, 2, selected_words.unsqueeze(-1)) 139 | self.log_probs = list( 140 | torch.gather(o, 1, selected_beam.unsqueeze(-1).expand(self.b_s, self.beam_size, 1)) for o in self.log_probs) 141 | self.log_probs.append(this_word_logprob) 142 | self.selected_words = selected_words.view(-1, 1) 143 | 144 | return visual, outputs 145 | -------------------------------------------------------------------------------- /models/transformer/attention.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from torch import nn 4 | from models.containers import Module 5 | 6 | 7 | class ScaledDotProductAttention(nn.Module): 8 | ''' 9 | Scaled dot-product attention 10 | ''' 11 | 12 | def __init__(self, d_model, d_k, d_v, h): 13 | ''' 14 | :param d_model: Output dimensionality of the model 15 | :param d_k: Dimensionality of queries and keys 16 | :param d_v: Dimensionality of values 17 | :param h: Number of heads 18 | ''' 19 | super(ScaledDotProductAttention, self).__init__() 20 | self.fc_q = nn.Linear(d_model, h * d_k) 21 | self.fc_k = nn.Linear(d_model, h * d_k) 22 | self.fc_v = nn.Linear(d_model, h * d_v) 23 | self.fc_o = nn.Linear(h * d_v, d_model) 24 | 25 | self.d_model = d_model 26 | self.d_k = d_k 27 | self.d_v = d_v 28 | self.h = h 29 | 30 | self.init_weights() 31 | 32 | def init_weights(self): 33 | nn.init.xavier_uniform_(self.fc_q.weight) 34 | nn.init.xavier_uniform_(self.fc_k.weight) 35 | nn.init.xavier_uniform_(self.fc_v.weight) 36 | nn.init.xavier_uniform_(self.fc_o.weight) 37 | nn.init.constant_(self.fc_q.bias, 0) 38 | nn.init.constant_(self.fc_k.bias, 0) 39 | nn.init.constant_(self.fc_v.bias, 0) 40 | nn.init.constant_(self.fc_o.bias, 0) 41 | 42 | def forward(self, queries, keys, values, attention_mask=None, attention_weights=None): 43 | ''' 44 | Computes 45 | :param queries: Queries (b_s, nq, d_model) 46 | :param keys: Keys (b_s, nk, d_model) 47 | :param values: Values (b_s, nk, d_model) 48 | :param attention_mask: Mask over attention values (b_s, h, nq, nk). True indicates masking. 
49 | :param attention_weights: Multiplicative weights for attention values (b_s, h, nq, nk). 50 | :return: 51 | ''' 52 | b_s, nq = queries.shape[:2] 53 | nk = keys.shape[1] 54 | q = self.fc_q(queries).view(b_s, nq, self.h, self.d_k).permute(0, 2, 1, 3) # (b_s, h, nq, d_k) 55 | k = self.fc_k(keys).view(b_s, nk, self.h, self.d_k).permute(0, 2, 3, 1) # (b_s, h, d_k, nk) 56 | v = self.fc_v(values).view(b_s, nk, self.h, self.d_v).permute(0, 2, 1, 3) # (b_s, h, nk, d_v) 57 | 58 | att = torch.matmul(q, k) / np.sqrt(self.d_k) # (b_s, h, nq, nk) 59 | if attention_weights is not None: 60 | att = att * attention_weights 61 | if attention_mask is not None: 62 | att = att.masked_fill(attention_mask, -np.inf) 63 | att = torch.softmax(att, -1) 64 | out = torch.matmul(att, v).permute(0, 2, 1, 3).contiguous().view(b_s, nq, self.h * self.d_v) # (b_s, nq, h*d_v) 65 | out = self.fc_o(out) # (b_s, nq, d_model) 66 | return out 67 | 68 | 69 | class ScaledDotProductAttentionMemory(nn.Module): 70 | ''' 71 | Scaled dot-product attention with memory 72 | ''' 73 | 74 | def __init__(self, d_model, d_k, d_v, h, m): 75 | ''' 76 | :param d_model: Output dimensionality of the model 77 | :param d_k: Dimensionality of queries and keys 78 | :param d_v: Dimensionality of values 79 | :param h: Number of heads 80 | :param m: Number of memory slots 81 | ''' 82 | super(ScaledDotProductAttentionMemory, self).__init__() 83 | self.fc_q = nn.Linear(d_model, h * d_k) 84 | self.fc_k = nn.Linear(d_model, h * d_k) 85 | self.fc_v = nn.Linear(d_model, h * d_v) 86 | self.fc_o = nn.Linear(h * d_v, d_model) 87 | self.m_k = nn.Parameter(torch.FloatTensor(1, m, h * d_k)) 88 | self.m_v = nn.Parameter(torch.FloatTensor(1, m, h * d_v)) 89 | 90 | self.d_model = d_model 91 | self.d_k = d_k 92 | self.d_v = d_v 93 | self.h = h 94 | self.m = m 95 | 96 | self.init_weights() 97 | 98 | def init_weights(self): 99 | nn.init.xavier_uniform_(self.fc_q.weight) 100 | nn.init.xavier_uniform_(self.fc_k.weight) 101 | nn.init.xavier_uniform_(self.fc_v.weight) 102 | nn.init.xavier_uniform_(self.fc_o.weight) 103 | nn.init.normal_(self.m_k, 0, 1 / self.d_k) 104 | nn.init.normal_(self.m_v, 0, 1 / self.m) 105 | nn.init.constant_(self.fc_q.bias, 0) 106 | nn.init.constant_(self.fc_k.bias, 0) 107 | nn.init.constant_(self.fc_v.bias, 0) 108 | nn.init.constant_(self.fc_o.bias, 0) 109 | 110 | def forward(self, queries, keys, values, attention_mask=None, attention_weights=None): 111 | ''' 112 | Computes 113 | :param queries: Queries (b_s, nq, d_model) 114 | :param keys: Keys (b_s, nk, d_model) 115 | :param values: Values (b_s, nk, d_model) 116 | :param attention_mask: Mask over attention values (b_s, h, nq, nk). True indicates masking. 117 | :param attention_weights: Multiplicative weights for attention values (b_s, h, nq, nk). 
118 | :return: 119 | ''' 120 | b_s, nq = queries.shape[:2] 121 | nk = keys.shape[1] 122 | 123 | m_k = np.sqrt(self.d_k) * self.m_k.expand(b_s, self.m, self.h * self.d_k) 124 | m_v = np.sqrt(self.m) * self.m_v.expand(b_s, self.m, self.h * self.d_v) 125 | 126 | q = self.fc_q(queries).view(b_s, nq, self.h, self.d_k).permute(0, 2, 1, 3) # (b_s, h, nq, d_k) 127 | k = torch.cat([self.fc_k(keys), m_k], 1).view(b_s, nk + self.m, self.h, self.d_k).permute(0, 2, 3, 1) # (b_s, h, d_k, nk) 128 | v = torch.cat([self.fc_v(values), m_v], 1).view(b_s, nk + self.m, self.h, self.d_v).permute(0, 2, 1, 3) # (b_s, h, nk, d_v) 129 | 130 | att = torch.matmul(q, k) / np.sqrt(self.d_k) # (b_s, h, nq, nk) 131 | if attention_weights is not None: 132 | att = torch.cat([att[:, :, :, :nk] * attention_weights, att[:, :, :, nk:]], -1) 133 | if attention_mask is not None: 134 | att[:, :, :, :nk] = att[:, :, :, :nk].masked_fill(attention_mask, -np.inf) 135 | att = torch.softmax(att, -1) 136 | out = torch.matmul(att, v).permute(0, 2, 1, 3).contiguous().view(b_s, nq, self.h * self.d_v) # (b_s, nq, h*d_v) 137 | out = self.fc_o(out) # (b_s, nq, d_model) 138 | return out 139 | 140 | 141 | class MultiHeadAttention(Module): 142 | ''' 143 | Multi-head attention layer with Dropout and Layer Normalization. 144 | ''' 145 | 146 | def __init__(self, d_model, d_k, d_v, h, dropout=.1, identity_map_reordering=False, can_be_stateful=False, 147 | attention_module=None, attention_module_kwargs=None): 148 | super(MultiHeadAttention, self).__init__() 149 | self.identity_map_reordering = identity_map_reordering 150 | if attention_module is not None: 151 | if attention_module_kwargs is not None: 152 | self.attention = attention_module(d_model=d_model, d_k=d_k, d_v=d_v, h=h, **attention_module_kwargs) 153 | else: 154 | self.attention = attention_module(d_model=d_model, d_k=d_k, d_v=d_v, h=h) 155 | else: 156 | self.attention = ScaledDotProductAttention(d_model=d_model, d_k=d_k, d_v=d_v, h=h) 157 | self.dropout = nn.Dropout(p=dropout) 158 | self.layer_norm = nn.LayerNorm(d_model) 159 | 160 | self.can_be_stateful = can_be_stateful 161 | if self.can_be_stateful: 162 | self.register_state('running_keys', torch.zeros((0, d_model))) 163 | self.register_state('running_values', torch.zeros((0, d_model))) 164 | 165 | def forward(self, queries, keys, values, attention_mask=None, attention_weights=None): 166 | if self.can_be_stateful and self._is_stateful: 167 | self.running_keys = torch.cat([self.running_keys, keys], 1) 168 | keys = self.running_keys 169 | 170 | self.running_values = torch.cat([self.running_values, values], 1) 171 | values = self.running_values 172 | 173 | if self.identity_map_reordering: 174 | q_norm = self.layer_norm(queries) 175 | k_norm = self.layer_norm(keys) 176 | v_norm = self.layer_norm(values) 177 | out = self.attention(q_norm, k_norm, v_norm, attention_mask, attention_weights) 178 | out = queries + self.dropout(torch.relu(out)) 179 | else: 180 | out = self.attention(queries, keys, values, attention_mask, attention_weights) 181 | out = self.dropout(out) 182 | out = self.layer_norm(queries + out) 183 | return out 184 | -------------------------------------------------------------------------------- /evaluation/bleu/bleu_scorer.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | # bleu_scorer.py 4 | # David Chiang 5 | 6 | # Copyright (c) 2004-2006 University of Maryland. All rights 7 | # reserved. Do not redistribute without permission from the 8 | # author. 
Not for commercial use. 9 | 10 | # Modified by: 11 | # Hao Fang 12 | # Tsung-Yi Lin 13 | 14 | ''' Provides: 15 | cook_refs(refs, n=4): Transform a list of reference sentences as strings into a form usable by cook_test(). 16 | cook_test(test, refs, n=4): Transform a test sentence as a string (together with the cooked reference sentences) into a form usable by score_cooked(). 17 | ''' 18 | 19 | import copy 20 | import sys, math, re 21 | from collections import defaultdict 22 | 23 | 24 | def precook(s, n=4, out=False): 25 | """Takes a string as input and returns an object that can be given to 26 | either cook_refs or cook_test. This is optional: cook_refs and cook_test 27 | can take string arguments as well.""" 28 | words = s.split() 29 | counts = defaultdict(int) 30 | for k in range(1, n + 1): 31 | for i in range(len(words) - k + 1): 32 | ngram = tuple(words[i:i + k]) 33 | counts[ngram] += 1 34 | return (len(words), counts) 35 | 36 | 37 | def cook_refs(refs, eff=None, n=4): ## lhuang: oracle will call with "average" 38 | '''Takes a list of reference sentences for a single segment 39 | and returns an object that encapsulates everything that BLEU 40 | needs to know about them.''' 41 | 42 | reflen = [] 43 | maxcounts = {} 44 | for ref in refs: 45 | rl, counts = precook(ref, n) 46 | reflen.append(rl) 47 | for (ngram, count) in counts.items(): 48 | maxcounts[ngram] = max(maxcounts.get(ngram, 0), count) 49 | 50 | # Calculate effective reference sentence length. 51 | if eff == "shortest": 52 | reflen = min(reflen) 53 | elif eff == "average": 54 | reflen = float(sum(reflen)) / len(reflen) 55 | 56 | ## lhuang: N.B.: leave reflen computaiton to the very end!! 57 | 58 | ## lhuang: N.B.: in case of "closest", keep a list of reflens!! (bad design) 59 | 60 | return (reflen, maxcounts) 61 | 62 | 63 | def cook_test(test, ref_tuple, eff=None, n=4): 64 | '''Takes a test sentence and returns an object that 65 | encapsulates everything that BLEU needs to know about it.''' 66 | 67 | testlen, counts = precook(test, n, True) 68 | reflen, refmaxcounts = ref_tuple 69 | 70 | result = {} 71 | 72 | # Calculate effective reference sentence length. 73 | 74 | if eff == "closest": 75 | result["reflen"] = min((abs(l - testlen), l) for l in reflen)[1] 76 | else: ## i.e., "average" or "shortest" or None 77 | result["reflen"] = reflen 78 | 79 | result["testlen"] = testlen 80 | 81 | result["guess"] = [max(0, testlen - k + 1) for k in range(1, n + 1)] 82 | 83 | result['correct'] = [0] * n 84 | for (ngram, count) in counts.items(): 85 | result["correct"][len(ngram) - 1] += min(refmaxcounts.get(ngram, 0), count) 86 | 87 | return result 88 | 89 | 90 | class BleuScorer(object): 91 | """Bleu scorer. 92 | """ 93 | 94 | __slots__ = "n", "crefs", "ctest", "_score", "_ratio", "_testlen", "_reflen", "special_reflen" 95 | 96 | # special_reflen is used in oracle (proportional effective ref len for a node). 
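# --- Editor's annotation (not part of the original bleu_scorer.py): a minimal usage sketch,
# --- with made-up caption strings, of how this class is driven by evaluation/bleu/bleu.py:
# ---     scorer = BleuScorer(n=4)
# ---     scorer += ("a man riding a horse", ["a person rides a horse", "a man on a horse"])
# ---     score, scores = scorer.compute_score(option='closest')  # corpus-level and per-caption BLEU-1..4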
97 | 98 | def copy(self): 99 | ''' copy the refs.''' 100 | new = BleuScorer(n=self.n) 101 | new.ctest = copy.copy(self.ctest) 102 | new.crefs = copy.copy(self.crefs) 103 | new._score = None 104 | return new 105 | 106 | def __init__(self, test=None, refs=None, n=4, special_reflen=None): 107 | ''' singular instance ''' 108 | 109 | self.n = n 110 | self.crefs = [] 111 | self.ctest = [] 112 | self.cook_append(test, refs) 113 | self.special_reflen = special_reflen 114 | 115 | def cook_append(self, test, refs): 116 | '''called by constructor and __iadd__ to avoid creating new instances.''' 117 | 118 | if refs is not None: 119 | self.crefs.append(cook_refs(refs)) 120 | if test is not None: 121 | cooked_test = cook_test(test, self.crefs[-1]) 122 | self.ctest.append(cooked_test) ## N.B.: -1 123 | else: 124 | self.ctest.append(None) # lens of crefs and ctest have to match 125 | 126 | self._score = None ## need to recompute 127 | 128 | def ratio(self, option=None): 129 | self.compute_score(option=option) 130 | return self._ratio 131 | 132 | def score_ratio(self, option=None): 133 | ''' 134 | return (bleu, len_ratio) pair 135 | ''' 136 | 137 | return self.fscore(option=option), self.ratio(option=option) 138 | 139 | def score_ratio_str(self, option=None): 140 | return "%.4f (%.2f)" % self.score_ratio(option) 141 | 142 | def reflen(self, option=None): 143 | self.compute_score(option=option) 144 | return self._reflen 145 | 146 | def testlen(self, option=None): 147 | self.compute_score(option=option) 148 | return self._testlen 149 | 150 | def retest(self, new_test): 151 | if type(new_test) is str: 152 | new_test = [new_test] 153 | assert len(new_test) == len(self.crefs), new_test 154 | self.ctest = [] 155 | for t, rs in zip(new_test, self.crefs): 156 | self.ctest.append(cook_test(t, rs)) 157 | self._score = None 158 | 159 | return self 160 | 161 | def rescore(self, new_test): 162 | ''' replace test(s) with new test(s), and returns the new score.''' 163 | 164 | return self.retest(new_test).compute_score() 165 | 166 | def size(self): 167 | assert len(self.crefs) == len(self.ctest), "refs/test mismatch! %d<>%d" % (len(self.crefs), len(self.ctest)) 168 | return len(self.crefs) 169 | 170 | def __iadd__(self, other): 171 | '''add an instance (e.g., from another sentence).''' 172 | 173 | if type(other) is tuple: 174 | ## avoid creating new BleuScorer instances 175 | self.cook_append(other[0], other[1]) 176 | else: 177 | assert self.compatible(other), "incompatible BLEUs." 
178 | self.ctest.extend(other.ctest) 179 | self.crefs.extend(other.crefs) 180 | self._score = None ## need to recompute 181 | 182 | return self 183 | 184 | def compatible(self, other): 185 | return isinstance(other, BleuScorer) and self.n == other.n 186 | 187 | def single_reflen(self, option="average"): 188 | return self._single_reflen(self.crefs[0][0], option) 189 | 190 | def _single_reflen(self, reflens, option=None, testlen=None): 191 | 192 | if option == "shortest": 193 | reflen = min(reflens) 194 | elif option == "average": 195 | reflen = float(sum(reflens)) / len(reflens) 196 | elif option == "closest": 197 | reflen = min((abs(l - testlen), l) for l in reflens)[1] 198 | else: 199 | assert False, "unsupported reflen option %s" % option 200 | 201 | return reflen 202 | 203 | def recompute_score(self, option=None, verbose=0): 204 | self._score = None 205 | return self.compute_score(option, verbose) 206 | 207 | def compute_score(self, option=None, verbose=0): 208 | n = self.n 209 | small = 1e-9 210 | tiny = 1e-15 ## so that if guess is 0 still return 0 211 | bleu_list = [[] for _ in range(n)] 212 | 213 | if self._score is not None: 214 | return self._score 215 | 216 | if option is None: 217 | option = "average" if len(self.crefs) == 1 else "closest" 218 | 219 | self._testlen = 0 220 | self._reflen = 0 221 | totalcomps = {'testlen': 0, 'reflen': 0, 'guess': [0] * n, 'correct': [0] * n} 222 | 223 | # for each sentence 224 | for comps in self.ctest: 225 | testlen = comps['testlen'] 226 | self._testlen += testlen 227 | 228 | if self.special_reflen is None: ## need computation 229 | reflen = self._single_reflen(comps['reflen'], option, testlen) 230 | else: 231 | reflen = self.special_reflen 232 | 233 | self._reflen += reflen 234 | 235 | for key in ['guess', 'correct']: 236 | for k in range(n): 237 | totalcomps[key][k] += comps[key][k] 238 | 239 | # append per image bleu score 240 | bleu = 1. 241 | for k in range(n): 242 | bleu *= (float(comps['correct'][k]) + tiny) \ 243 | / (float(comps['guess'][k]) + small) 244 | bleu_list[k].append(bleu ** (1. / (k + 1))) 245 | ratio = (testlen + tiny) / (reflen + small) ## N.B.: avoid zero division 246 | if ratio < 1: 247 | for k in range(n): 248 | bleu_list[k][-1] *= math.exp(1 - 1 / ratio) 249 | 250 | if verbose > 1: 251 | print(comps, reflen) 252 | 253 | totalcomps['reflen'] = self._reflen 254 | totalcomps['testlen'] = self._testlen 255 | 256 | bleus = [] 257 | bleu = 1. 258 | for k in range(n): 259 | bleu *= float(totalcomps['correct'][k] + tiny) \ 260 | / (totalcomps['guess'][k] + small) 261 | bleus.append(bleu ** (1. 
/ (k + 1))) 262 | ratio = (self._testlen + tiny) / (self._reflen + small) ## N.B.: avoid zero division 263 | if ratio < 1: 264 | for k in range(n): 265 | bleus[k] *= math.exp(1 - 1 / ratio) 266 | 267 | if verbose > 0: 268 | print(totalcomps) 269 | print("ratio:", ratio) 270 | 271 | self._score = bleus 272 | return self._score, bleu_list 273 | -------------------------------------------------------------------------------- /data/dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import itertools 4 | import collections 5 | import torch 6 | from .example import Example 7 | from .utils import nostdout 8 | from pycocotools.coco import COCO as pyCOCO 9 | 10 | 11 | class Dataset(object): 12 | def __init__(self, examples, fields): 13 | self.examples = examples 14 | self.fields = dict(fields) 15 | 16 | def collate_fn(self): 17 | def collate(batch): 18 | if len(self.fields) == 1: 19 | batch = [batch, ] 20 | else: 21 | batch = list(zip(*batch)) 22 | 23 | tensors = [] 24 | for field, data in zip(self.fields.values(), batch): 25 | tensor = field.process(data) 26 | if isinstance(tensor, collections.Sequence) and any(isinstance(t, torch.Tensor) for t in tensor): 27 | tensors.extend(tensor) 28 | else: 29 | tensors.append(tensor) 30 | 31 | if len(tensors) > 1: 32 | return tensors 33 | else: 34 | return tensors[0] 35 | 36 | return collate 37 | 38 | def __getitem__(self, i): 39 | example = self.examples[i] 40 | data = [] 41 | for field_name, field in self.fields.items(): 42 | data.append(field.preprocess(getattr(example, field_name))) 43 | 44 | if len(data) == 1: 45 | data = data[0] 46 | return data 47 | 48 | def __len__(self): 49 | return len(self.examples) 50 | 51 | def __getattr__(self, attr): 52 | if attr in self.fields: 53 | for x in self.examples: 54 | yield getattr(x, attr) 55 | 56 | 57 | class ValueDataset(Dataset): 58 | def __init__(self, examples, fields, dictionary): 59 | self.dictionary = dictionary 60 | super(ValueDataset, self).__init__(examples, fields) 61 | 62 | def collate_fn(self): 63 | def collate(batch): 64 | value_batch_flattened = list(itertools.chain(*batch)) 65 | value_tensors_flattened = super(ValueDataset, self).collate_fn()(value_batch_flattened) 66 | 67 | lengths = [0, ] + list(itertools.accumulate([len(x) for x in batch])) 68 | if isinstance(value_tensors_flattened, collections.Sequence) \ 69 | and any(isinstance(t, torch.Tensor) for t in value_tensors_flattened): 70 | value_tensors = [[vt[s:e] for (s, e) in zip(lengths[:-1], lengths[1:])] for vt in value_tensors_flattened] 71 | else: 72 | value_tensors = [value_tensors_flattened[s:e] for (s, e) in zip(lengths[:-1], lengths[1:])] 73 | 74 | return value_tensors 75 | return collate 76 | 77 | def __getitem__(self, i): 78 | if i not in self.dictionary: 79 | raise IndexError 80 | 81 | values_data = [] 82 | for idx in self.dictionary[i]: 83 | value_data = super(ValueDataset, self).__getitem__(idx) 84 | values_data.append(value_data) 85 | return values_data 86 | 87 | def __len__(self): 88 | return len(self.dictionary) 89 | 90 | 91 | class DictionaryDataset(Dataset): 92 | def __init__(self, examples, fields, key_fields): 93 | if not isinstance(key_fields, (tuple, list)): 94 | key_fields = (key_fields,) 95 | for field in key_fields: 96 | assert (field in fields) 97 | 98 | dictionary = collections.defaultdict(list) 99 | key_fields = {k: fields[k] for k in key_fields} 100 | value_fields = {k: fields[k] for k in fields.keys() if k not in key_fields} 101 | 
key_examples = [] 102 | key_dict = dict() 103 | value_examples = [] 104 | 105 | for i, e in enumerate(examples): 106 | key_example = Example.fromdict({k: getattr(e, k) for k in key_fields}) 107 | value_example = Example.fromdict({v: getattr(e, v) for v in value_fields}) 108 | if key_example not in key_dict: 109 | key_dict[key_example] = len(key_examples) 110 | key_examples.append(key_example) 111 | 112 | value_examples.append(value_example) 113 | dictionary[key_dict[key_example]].append(i) 114 | 115 | self.key_dataset = Dataset(key_examples, key_fields) 116 | self.value_dataset = ValueDataset(value_examples, value_fields, dictionary) 117 | super(DictionaryDataset, self).__init__(examples, fields) 118 | 119 | def collate_fn(self): 120 | def collate(batch): 121 | key_batch, value_batch = list(zip(*batch)) 122 | key_tensors = self.key_dataset.collate_fn()(key_batch) 123 | value_tensors = self.value_dataset.collate_fn()(value_batch) 124 | return key_tensors, value_tensors 125 | return collate 126 | 127 | def __getitem__(self, i): 128 | return self.key_dataset[i], self.value_dataset[i] 129 | 130 | def __len__(self): 131 | return len(self.key_dataset) 132 | 133 | 134 | def unique(sequence): 135 | seen = set() 136 | if isinstance(sequence[0], list): 137 | return [x for x in sequence if not (tuple(x) in seen or seen.add(tuple(x)))] 138 | else: 139 | return [x for x in sequence if not (x in seen or seen.add(x))] 140 | 141 | 142 | class PairedDataset(Dataset): 143 | def __init__(self, examples, fields): 144 | assert ('image' in fields) 145 | assert ('text' in fields) 146 | super(PairedDataset, self).__init__(examples, fields) 147 | self.image_field = self.fields['image'] 148 | self.text_field = self.fields['text'] 149 | 150 | def image_set(self): 151 | img_list = [e.image for e in self.examples] 152 | image_set = unique(img_list) 153 | examples = [Example.fromdict({'image': i}) for i in image_set] 154 | dataset = Dataset(examples, {'image': self.image_field}) 155 | return dataset 156 | 157 | def text_set(self): 158 | text_list = [e.text for e in self.examples] 159 | text_list = unique(text_list) 160 | examples = [Example.fromdict({'text': t}) for t in text_list] 161 | dataset = Dataset(examples, {'text': self.text_field}) 162 | return dataset 163 | 164 | def image_dictionary(self, fields=None): 165 | if not fields: 166 | fields = self.fields 167 | dataset = DictionaryDataset(self.examples, fields, key_fields='image') 168 | return dataset 169 | 170 | def text_dictionary(self, fields=None): 171 | if not fields: 172 | fields = self.fields 173 | dataset = DictionaryDataset(self.examples, fields, key_fields='text') 174 | return dataset 175 | 176 | @property 177 | def splits(self): 178 | raise NotImplementedError 179 | 180 | 181 | class COCO(PairedDataset): 182 | def __init__(self, image_field, text_field, img_root, ann_root, id_root=None, use_restval=True, 183 | cut_validation=False): 184 | roots = {} 185 | roots['train'] = { 186 | 'img': os.path.join(img_root, 'train2014'), 187 | 'cap': os.path.join(ann_root, 'captions_train2014.json') 188 | } 189 | roots['val'] = { 190 | 'img': os.path.join(img_root, 'val2014'), 191 | 'cap': os.path.join(ann_root, 'captions_val2014.json') 192 | } 193 | roots['test'] = { 194 | 'img': os.path.join(img_root, 'val2014'), 195 | 'cap': os.path.join(ann_root, 'captions_val2014.json') 196 | } 197 | roots['trainrestval'] = { 198 | 'img': (roots['train']['img'], roots['val']['img']), 199 | 'cap': (roots['train']['cap'], roots['val']['cap']) 200 | } 201 | 202 | if id_root is not 
None: 203 | ids = {} 204 | ids['train'] = np.load(os.path.join(id_root, 'coco_train_ids.npy')) 205 | ids['val'] = np.load(os.path.join(id_root, 'coco_dev_ids.npy')) 206 | if cut_validation: 207 | ids['val'] = ids['val'][:5000] 208 | ids['test'] = np.load(os.path.join(id_root, 'coco_test_ids.npy')) 209 | ids['trainrestval'] = ( 210 | ids['train'], 211 | np.load(os.path.join(id_root, 'coco_restval_ids.npy'))) 212 | 213 | if use_restval: 214 | roots['train'] = roots['trainrestval'] 215 | ids['train'] = ids['trainrestval'] 216 | else: 217 | ids = None 218 | 219 | with nostdout(): 220 | self.train_examples, self.val_examples, self.test_examples = self.get_samples(roots, ids) 221 | examples = self.train_examples + self.val_examples + self.test_examples 222 | super(COCO, self).__init__(examples, {'image': image_field, 'text': text_field}) 223 | 224 | @property 225 | def splits(self): 226 | train_split = PairedDataset(self.train_examples, self.fields) 227 | val_split = PairedDataset(self.val_examples, self.fields) 228 | test_split = PairedDataset(self.test_examples, self.fields) 229 | return train_split, val_split, test_split 230 | 231 | @classmethod 232 | def get_samples(cls, roots, ids_dataset=None): 233 | train_samples = [] 234 | val_samples = [] 235 | test_samples = [] 236 | 237 | for split in ['train', 'val', 'test']: 238 | if isinstance(roots[split]['cap'], tuple): 239 | coco_dataset = (pyCOCO(roots[split]['cap'][0]), pyCOCO(roots[split]['cap'][1])) 240 | root = roots[split]['img'] 241 | else: 242 | coco_dataset = (pyCOCO(roots[split]['cap']),) 243 | root = (roots[split]['img'],) 244 | 245 | if ids_dataset is None: 246 | ids = list(coco_dataset.anns.keys()) 247 | else: 248 | ids = ids_dataset[split] 249 | 250 | if isinstance(ids, tuple): 251 | bp = len(ids[0]) 252 | ids = list(ids[0]) + list(ids[1]) 253 | else: 254 | bp = len(ids) 255 | 256 | for index in range(len(ids)): 257 | if index < bp: 258 | coco = coco_dataset[0] 259 | img_root = root[0] 260 | else: 261 | coco = coco_dataset[1] 262 | img_root = root[1] 263 | 264 | ann_id = ids[index] 265 | caption = coco.anns[ann_id]['caption'] 266 | img_id = coco.anns[ann_id]['image_id'] 267 | filename = coco.loadImgs(img_id)[0]['file_name'] 268 | 269 | example = Example.fromdict({'image': os.path.join(img_root, filename), 'text': caption}) 270 | 271 | if split == 'train': 272 | train_samples.append(example) 273 | elif split == 'val': 274 | val_samples.append(example) 275 | elif split == 'test': 276 | test_samples.append(example) 277 | 278 | return train_samples, val_samples, test_samples 279 | 280 | -------------------------------------------------------------------------------- /data/field.py: -------------------------------------------------------------------------------- 1 | # coding: utf8 2 | from collections import Counter, OrderedDict 3 | from torch.utils.data.dataloader import default_collate 4 | from itertools import chain 5 | import six 6 | import torch 7 | import numpy as np 8 | import h5py 9 | import os 10 | import warnings 11 | import shutil 12 | 13 | from .dataset import Dataset 14 | from .vocab import Vocab 15 | from .utils import get_tokenizer 16 | 17 | 18 | class RawField(object): 19 | """ Defines a general datatype. 20 | 21 | Every dataset consists of one or more types of data. For instance, 22 | a machine translation dataset contains paired examples of text, while 23 | an image captioning dataset contains images and texts. 24 | Each of these types of data is represented by a RawField object. 
25 | An RawField object does not assume any property of the data type and 26 | it holds parameters relating to how a datatype should be processed. 27 | 28 | Attributes: 29 | preprocessing: The Pipeline that will be applied to examples 30 | using this field before creating an example. 31 | Default: None. 32 | postprocessing: A Pipeline that will be applied to a list of examples 33 | using this field before assigning to a batch. 34 | Function signature: (batch(list)) -> object 35 | Default: None. 36 | """ 37 | 38 | def __init__(self, preprocessing=None, postprocessing=None): 39 | self.preprocessing = preprocessing 40 | self.postprocessing = postprocessing 41 | 42 | def preprocess(self, x): 43 | """ Preprocess an example if the `preprocessing` Pipeline is provided. """ 44 | if self.preprocessing is not None: 45 | return self.preprocessing(x) 46 | else: 47 | return x 48 | 49 | def process(self, batch, *args, **kwargs): 50 | """ Process a list of examples to create a batch. 51 | 52 | Postprocess the batch with user-provided Pipeline. 53 | 54 | Args: 55 | batch (list(object)): A list of object from a batch of examples. 56 | Returns: 57 | object: Processed object given the input and custom 58 | postprocessing Pipeline. 59 | """ 60 | if self.postprocessing is not None: 61 | batch = self.postprocessing(batch) 62 | return default_collate(batch) 63 | 64 | 65 | class Merge(RawField): 66 | def __init__(self, *fields): 67 | super(Merge, self).__init__() 68 | self.fields = fields 69 | 70 | def preprocess(self, x): 71 | return tuple(f.preprocess(x) for f in self.fields) 72 | 73 | def process(self, batch, *args, **kwargs): 74 | if len(self.fields) == 1: 75 | batch = [batch, ] 76 | else: 77 | batch = list(zip(*batch)) 78 | 79 | out = list(f.process(b, *args, **kwargs) for f, b in zip(self.fields, batch)) 80 | return out 81 | 82 | 83 | class ImageDetectionsField(RawField): 84 | def __init__(self, preprocessing=None, postprocessing=None, detections_path=None, max_detections=100, 85 | sort_by_prob=False, load_in_tmp=True): 86 | self.max_detections = max_detections 87 | self.detections_path = detections_path 88 | self.sort_by_prob = sort_by_prob 89 | 90 | tmp_detections_path = os.path.join('/tmp', os.path.basename(detections_path)) 91 | 92 | if load_in_tmp: 93 | if not os.path.isfile(tmp_detections_path): 94 | if shutil.disk_usage("/tmp")[-1] < os.path.getsize(detections_path): 95 | warnings.warn('Loading from %s, because /tmp has no enough space.' 
% detections_path) 96 | else: 97 | warnings.warn("Copying detection file to /tmp") 98 | shutil.copyfile(detections_path, tmp_detections_path) 99 | warnings.warn("Done.") 100 | self.detections_path = tmp_detections_path 101 | else: 102 | self.detections_path = tmp_detections_path 103 | 104 | super(ImageDetectionsField, self).__init__(preprocessing, postprocessing) 105 | 106 | def preprocess(self, x, avoid_precomp=False): 107 | image_id = int(x.split('_')[-1].split('.')[0]) 108 | try: 109 | f = h5py.File(self.detections_path, 'r') 110 | precomp_data = f['%d_features' % image_id][()] 111 | if self.sort_by_prob: 112 | precomp_data = precomp_data[np.argsort(np.max(f['%d_cls_prob' % image_id][()], -1))[::-1]] 113 | except KeyError: 114 | warnings.warn('Could not find detections for %d' % image_id) 115 | precomp_data = np.random.rand(10,2048) 116 | 117 | delta = self.max_detections - precomp_data.shape[0] 118 | if delta > 0: 119 | precomp_data = np.concatenate([precomp_data, np.zeros((delta, precomp_data.shape[1]))], axis=0) 120 | elif delta < 0: 121 | precomp_data = precomp_data[:self.max_detections] 122 | 123 | return precomp_data.astype(np.float32) 124 | 125 | 126 | class TextField(RawField): 127 | vocab_cls = Vocab 128 | # Dictionary mapping PyTorch tensor dtypes to the appropriate Python 129 | # numeric type. 130 | dtypes = { 131 | torch.float32: float, 132 | torch.float: float, 133 | torch.float64: float, 134 | torch.double: float, 135 | torch.float16: float, 136 | torch.half: float, 137 | 138 | torch.uint8: int, 139 | torch.int8: int, 140 | torch.int16: int, 141 | torch.short: int, 142 | torch.int32: int, 143 | torch.int: int, 144 | torch.int64: int, 145 | torch.long: int, 146 | } 147 | punctuations = ["''", "'", "``", "`", "-LRB-", "-RRB-", "-LCB-", "-RCB-", \ 148 | ".", "?", "!", ",", ":", "-", "--", "...", ";"] 149 | 150 | def __init__(self, use_vocab=True, init_token=None, eos_token=None, fix_length=None, dtype=torch.long, 151 | preprocessing=None, postprocessing=None, lower=False, tokenize=(lambda s: s.split()), 152 | remove_punctuation=False, include_lengths=False, batch_first=True, pad_token="", 153 | unk_token="", pad_first=False, truncate_first=False, vectors=None, nopoints=True): 154 | self.use_vocab = use_vocab 155 | self.init_token = init_token 156 | self.eos_token = eos_token 157 | self.fix_length = fix_length 158 | self.dtype = dtype 159 | self.lower = lower 160 | self.tokenize = get_tokenizer(tokenize) 161 | self.remove_punctuation = remove_punctuation 162 | self.include_lengths = include_lengths 163 | self.batch_first = batch_first 164 | self.pad_token = pad_token 165 | self.unk_token = unk_token 166 | self.pad_first = pad_first 167 | self.truncate_first = truncate_first 168 | self.vocab = None 169 | self.vectors = vectors 170 | if nopoints: 171 | self.punctuations.append("..") 172 | 173 | super(TextField, self).__init__(preprocessing, postprocessing) 174 | 175 | def preprocess(self, x): 176 | if six.PY2 and isinstance(x, six.string_types) and not isinstance(x, six.text_type): 177 | x = six.text_type(x, encoding='utf-8') 178 | if self.lower: 179 | x = six.text_type.lower(x) 180 | x = self.tokenize(x.rstrip('\n')) 181 | if self.remove_punctuation: 182 | x = [w for w in x if w not in self.punctuations] 183 | if self.preprocessing is not None: 184 | return self.preprocessing(x) 185 | else: 186 | return x 187 | 188 | def process(self, batch, device=None): 189 | padded = self.pad(batch) 190 | tensor = self.numericalize(padded, device=device) 191 | return tensor 192 | 193 | def 
build_vocab(self, *args, **kwargs): 194 | counter = Counter() 195 | sources = [] 196 | for arg in args: 197 | if isinstance(arg, Dataset): 198 | sources += [getattr(arg, name) for name, field in arg.fields.items() if field is self] 199 | else: 200 | sources.append(arg) 201 | 202 | for data in sources: 203 | for x in data: 204 | x = self.preprocess(x) 205 | try: 206 | counter.update(x) 207 | except TypeError: 208 | counter.update(chain.from_iterable(x)) 209 | 210 | specials = list(OrderedDict.fromkeys([ 211 | tok for tok in [self.unk_token, self.pad_token, self.init_token, 212 | self.eos_token] 213 | if tok is not None])) 214 | self.vocab = self.vocab_cls(counter, specials=specials, **kwargs) 215 | 216 | def pad(self, minibatch): 217 | """Pad a batch of examples using this field. 218 | Pads to self.fix_length if provided, otherwise pads to the length of 219 | the longest example in the batch. Prepends self.init_token and appends 220 | self.eos_token if those attributes are not None. Returns a tuple of the 221 | padded list and a list containing lengths of each example if 222 | `self.include_lengths` is `True`, else just 223 | returns the padded list. 224 | """ 225 | minibatch = list(minibatch) 226 | if self.fix_length is None: 227 | max_len = max(len(x) for x in minibatch) 228 | else: 229 | max_len = self.fix_length + ( 230 | self.init_token, self.eos_token).count(None) - 2 231 | padded, lengths = [], [] 232 | for x in minibatch: 233 | if self.pad_first: 234 | padded.append( 235 | [self.pad_token] * max(0, max_len - len(x)) + 236 | ([] if self.init_token is None else [self.init_token]) + 237 | list(x[-max_len:] if self.truncate_first else x[:max_len]) + 238 | ([] if self.eos_token is None else [self.eos_token])) 239 | else: 240 | padded.append( 241 | ([] if self.init_token is None else [self.init_token]) + 242 | list(x[-max_len:] if self.truncate_first else x[:max_len]) + 243 | ([] if self.eos_token is None else [self.eos_token]) + 244 | [self.pad_token] * max(0, max_len - len(x))) 245 | lengths.append(len(padded[-1]) - max(0, max_len - len(x))) 246 | if self.include_lengths: 247 | return padded, lengths 248 | return padded 249 | 250 | def numericalize(self, arr, device=None): 251 | """Turn a batch of examples that use this field into a list of Variables. 252 | If the field has include_lengths=True, a tensor of lengths will be 253 | included in the return value. 254 | Arguments: 255 | arr (List[List[str]], or tuple of (List[List[str]], List[int])): 256 | List of tokenized and padded examples, or tuple of List of 257 | tokenized and padded examples and List of lengths of each 258 | example if self.include_lengths is True. 259 | device (str or torch.device): A string or instance of `torch.device` 260 | specifying which device the Variables are going to be created on. 261 | If left as default, the tensors will be created on cpu. Default: None. 
262 | """ 263 | if self.include_lengths and not isinstance(arr, tuple): 264 | raise ValueError("Field has include_lengths set to True, but " 265 | "input data is not a tuple of " 266 | "(data batch, batch lengths).") 267 | if isinstance(arr, tuple): 268 | arr, lengths = arr 269 | lengths = torch.tensor(lengths, dtype=self.dtype, device=device) 270 | 271 | if self.use_vocab: 272 | arr = [[self.vocab.stoi[x] for x in ex] for ex in arr] 273 | 274 | if self.postprocessing is not None: 275 | arr = self.postprocessing(arr, self.vocab) 276 | 277 | var = torch.tensor(arr, dtype=self.dtype, device=device) 278 | else: 279 | if self.vectors: 280 | arr = [[self.vectors[x] for x in ex] for ex in arr] 281 | if self.dtype not in self.dtypes: 282 | raise ValueError( 283 | "Specified Field dtype {} can not be used with " 284 | "use_vocab=False because we do not know how to numericalize it. " 285 | "Please raise an issue at " 286 | "https://github.com/pytorch/text/issues".format(self.dtype)) 287 | numericalization_func = self.dtypes[self.dtype] 288 | # It doesn't make sense to explictly coerce to a numeric type if 289 | # the data is sequential, since it's unclear how to coerce padding tokens 290 | # to a numeric type. 291 | arr = [numericalization_func(x) if isinstance(x, six.string_types) 292 | else x for x in arr] 293 | 294 | if self.postprocessing is not None: 295 | arr = self.postprocessing(arr, None) 296 | 297 | var = torch.cat([torch.cat([a.unsqueeze(0) for a in ar]).unsqueeze(0) for ar in arr]) 298 | 299 | # var = torch.tensor(arr, dtype=self.dtype, device=device) 300 | if not self.batch_first: 301 | var.t_() 302 | var = var.contiguous() 303 | 304 | if self.include_lengths: 305 | return var, lengths 306 | return var 307 | 308 | def decode(self, word_idxs, join_words=True): 309 | if isinstance(word_idxs, list) and len(word_idxs) == 0: 310 | return self.decode([word_idxs, ], join_words)[0] 311 | if isinstance(word_idxs, list) and isinstance(word_idxs[0], int): 312 | return self.decode([word_idxs, ], join_words)[0] 313 | elif isinstance(word_idxs, np.ndarray) and word_idxs.ndim == 1: 314 | return self.decode(word_idxs.reshape((1, -1)), join_words)[0] 315 | elif isinstance(word_idxs, torch.Tensor) and word_idxs.ndimension() == 1: 316 | return self.decode(word_idxs.unsqueeze(0), join_words)[0] 317 | 318 | captions = [] 319 | for wis in word_idxs: 320 | caption = [] 321 | for wi in wis: 322 | word = self.vocab.itos[int(wi)] 323 | if word == self.eos_token: 324 | break 325 | caption.append(word) 326 | if join_words: 327 | caption = ' '.join(caption) 328 | captions.append(caption) 329 | return captions 330 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import random 2 | from data import ImageDetectionsField, TextField, RawField 3 | from data import COCO, DataLoader 4 | import evaluation 5 | from evaluation import PTBTokenizer, Cider 6 | from models.transformer import Transformer, MemoryAugmentedEncoder, MeshedDecoder, ScaledDotProductAttentionMemory 7 | import torch 8 | from torch.optim import Adam 9 | from torch.optim.lr_scheduler import LambdaLR 10 | from torch.nn import NLLLoss 11 | from tqdm import tqdm 12 | from torch.utils.tensorboard import SummaryWriter 13 | import argparse, os, pickle 14 | import numpy as np 15 | import itertools 16 | import multiprocessing 17 | from shutil import copyfile 18 | 19 | random.seed(1234) 20 | torch.manual_seed(1234) 21 | 
np.random.seed(1234)
22 | 
23 | 
24 | def evaluate_loss(model, dataloader, loss_fn, text_field):
25 |     # Validation loss
26 |     model.eval()
27 |     running_loss = .0
28 |     with tqdm(desc='Epoch %d - validation' % e, unit='it', total=len(dataloader)) as pbar:
29 |         with torch.no_grad():
30 |             for it, (detections, captions) in enumerate(dataloader):
31 |                 detections, captions = detections.to(device), captions.to(device)
32 |                 out = model(detections, captions)
33 |                 captions = captions[:, 1:].contiguous()
34 |                 out = out[:, :-1].contiguous()
35 |                 loss = loss_fn(out.view(-1, len(text_field.vocab)), captions.view(-1))
36 |                 this_loss = loss.item()
37 |                 running_loss += this_loss
38 | 
39 |                 pbar.set_postfix(loss=running_loss / (it + 1))
40 |                 pbar.update()
41 | 
42 |     val_loss = running_loss / len(dataloader)
43 |     return val_loss
44 | 
45 | 
46 | def evaluate_metrics(model, dataloader, text_field):
47 |     import itertools
48 |     model.eval()
49 |     gen = {}
50 |     gts = {}
51 |     with tqdm(desc='Epoch %d - evaluation' % e, unit='it', total=len(dataloader)) as pbar:
52 |         for it, (images, caps_gt) in enumerate(iter(dataloader)):
53 |             images = images.to(device)
54 |             with torch.no_grad():
55 |                 out, _ = model.beam_search(images, 20, text_field.vocab.stoi['<eos>'], 5, out_size=1)
56 |             caps_gen = text_field.decode(out, join_words=False)
57 |             for i, (gts_i, gen_i) in enumerate(zip(caps_gt, caps_gen)):
58 |                 gen_i = ' '.join([k for k, g in itertools.groupby(gen_i)])
59 |                 gen['%d_%d' % (it, i)] = [gen_i, ]
60 |                 gts['%d_%d' % (it, i)] = gts_i
61 |             pbar.update()
62 | 
63 |     gts = evaluation.PTBTokenizer.tokenize(gts)
64 |     gen = evaluation.PTBTokenizer.tokenize(gen)
65 |     scores, _ = evaluation.compute_scores(gts, gen)
66 |     return scores
67 | 
68 | 
69 | def train_xe(model, dataloader, optim, text_field):
70 |     # Training with cross-entropy
71 |     model.train()
72 |     scheduler.step()
73 |     running_loss = .0
74 |     with tqdm(desc='Epoch %d - train' % e, unit='it', total=len(dataloader)) as pbar:
75 |         for it, (detections, captions) in enumerate(dataloader):
76 |             detections, captions = detections.to(device), captions.to(device)
77 |             out = model(detections, captions)
78 |             optim.zero_grad()
79 |             captions_gt = captions[:, 1:].contiguous()
80 |             out = out[:, :-1].contiguous()
81 |             loss = loss_fn(out.view(-1, len(text_field.vocab)), captions_gt.view(-1))
82 |             loss.backward()
83 | 
84 |             optim.step()
85 |             this_loss = loss.item()
86 |             running_loss += this_loss
87 | 
88 |             pbar.set_postfix(loss=running_loss / (it + 1))
89 |             pbar.update()
90 |             scheduler.step()
91 | 
92 |     loss = running_loss / len(dataloader)
93 |     return loss
94 | 
95 | 
96 | def train_scst(model, dataloader, optim, cider, text_field):
97 |     # Training with self-critical
98 |     tokenizer_pool = multiprocessing.Pool()
99 |     running_reward = .0
100 |     running_reward_baseline = .0
101 |     model.train()
102 |     running_loss = .0
103 |     seq_len = 20
104 |     beam_size = 5
105 | 
106 |     with tqdm(desc='Epoch %d - train' % e, unit='it', total=len(dataloader)) as pbar:
107 |         for it, (detections, caps_gt) in enumerate(dataloader):
108 |             detections = detections.to(device)
109 |             outs, log_probs = model.beam_search(detections, seq_len, text_field.vocab.stoi['<eos>'],
110 |                                                 beam_size, out_size=beam_size)
111 |             optim.zero_grad()
112 | 
113 |             # Rewards
114 |             caps_gen = text_field.decode(outs.view(-1, seq_len))
115 |             caps_gt = list(itertools.chain(*([c, ] * beam_size for c in caps_gt)))
116 |             caps_gen, caps_gt = tokenizer_pool.map(evaluation.PTBTokenizer.tokenize, [caps_gen, caps_gt])
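# --- Editor's annotation (not part of the original train.py): a sketch of the self-critical
# --- (SCST) update computed just below. Each image contributes beam_size sampled captions;
# --- every caption's CIDEr reward is compared with the mean reward of its own beam, which
# --- acts as the baseline. For example, per-beam rewards [1.2, 1.0, 0.8, 1.1, 0.9] give a
# --- baseline of 1.0 and advantages [+0.2, 0.0, -0.2, +0.1, -0.1]: captions that beat the
# --- baseline get their log-probabilities pushed up, the others pushed down.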
117 |             reward = cider.compute_score(caps_gt, caps_gen)[1].astype(np.float32)
118 |             reward = torch.from_numpy(reward).to(device).view(detections.shape[0], beam_size)
119 |             reward_baseline = torch.mean(reward, -1, keepdim=True)
120 |             loss = -torch.mean(log_probs, -1) * (reward - reward_baseline)
121 | 
122 |             loss = loss.mean()
123 |             loss.backward()
124 |             optim.step()
125 | 
126 |             running_loss += loss.item()
127 |             running_reward += reward.mean().item()
128 |             running_reward_baseline += reward_baseline.mean().item()
129 |             pbar.set_postfix(loss=running_loss / (it + 1), reward=running_reward / (it + 1),
130 |                              reward_baseline=running_reward_baseline / (it + 1))
131 |             pbar.update()
132 | 
133 |     loss = running_loss / len(dataloader)
134 |     reward = running_reward / len(dataloader)
135 |     reward_baseline = running_reward_baseline / len(dataloader)
136 |     return loss, reward, reward_baseline
137 | 
138 | 
139 | if __name__ == '__main__':
140 |     device = torch.device('cuda')
141 |     parser = argparse.ArgumentParser(description='Meshed-Memory Transformer')
142 |     parser.add_argument('--exp_name', type=str, default='m2_transformer')
143 |     parser.add_argument('--batch_size', type=int, default=10)
144 |     parser.add_argument('--workers', type=int, default=0)
145 |     parser.add_argument('--m', type=int, default=40)
146 |     parser.add_argument('--head', type=int, default=8)
147 |     parser.add_argument('--warmup', type=int, default=10000)
148 |     parser.add_argument('--resume_last', action='store_true')
149 |     parser.add_argument('--resume_best', action='store_true')
150 |     parser.add_argument('--features_path', type=str)
151 |     parser.add_argument('--annotation_folder', type=str)
152 |     parser.add_argument('--logs_folder', type=str, default='tensorboard_logs')
153 |     args = parser.parse_args()
154 |     print(args)
155 | 
156 |     print('Meshed-Memory Transformer Training')
157 | 
158 |     writer = SummaryWriter(log_dir=os.path.join(args.logs_folder, args.exp_name))
159 | 
160 |     # Pipeline for image regions
161 |     image_field = ImageDetectionsField(detections_path=args.features_path, max_detections=50, load_in_tmp=False)
162 | 
163 |     # Pipeline for text
164 |     text_field = TextField(init_token='<bos>', eos_token='<eos>', lower=True, tokenize='spacy',
165 |                            remove_punctuation=True, nopoints=False)
166 | 
167 |     # Create the dataset
168 |     dataset = COCO(image_field, text_field, 'coco/images/', args.annotation_folder, args.annotation_folder)
169 |     train_dataset, val_dataset, test_dataset = dataset.splits
170 | 
171 |     if not os.path.isfile('vocab_%s.pkl' % args.exp_name):
172 |         print("Building vocabulary")
173 |         text_field.build_vocab(train_dataset, val_dataset, min_freq=5)
174 |         pickle.dump(text_field.vocab, open('vocab_%s.pkl' % args.exp_name, 'wb'))
175 |     else:
176 |         text_field.vocab = pickle.load(open('vocab_%s.pkl' % args.exp_name, 'rb'))
177 | 
178 |     # Model and dataloaders
179 |     encoder = MemoryAugmentedEncoder(3, 0, attention_module=ScaledDotProductAttentionMemory,
180 |                                      attention_module_kwargs={'m': args.m})
181 |     decoder = MeshedDecoder(len(text_field.vocab), 54, 3, text_field.vocab.stoi['<pad>'])
182 |     model = Transformer(text_field.vocab.stoi['<bos>'], encoder, decoder).to(device)
183 | 
184 |     dict_dataset_train = train_dataset.image_dictionary({'image': image_field, 'text': RawField()})
185 |     ref_caps_train = list(train_dataset.text)
186 |     cider_train = Cider(PTBTokenizer.tokenize(ref_caps_train))
187 |     dict_dataset_val = val_dataset.image_dictionary({'image': image_field, 'text': RawField()})
188 |     dict_dataset_test = test_dataset.image_dictionary({'image': image_field, 'text': RawField()})
189 | 
190 | 
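# --- Editor's annotation (not part of the original train.py): lambda_lr below is the warmup
# --- schedule of "Attention Is All You Need", lr(s) = d_model^-0.5 * min(s^-0.5, s * warmup^-1.5),
# --- applied through LambdaLR on top of the base learning rate of 1 passed to Adam. The factor
# --- rises until s = warmup and then decays as s^-0.5; assuming the default d_model = 512 and
# --- warmup = 10000, the peak value is roughly 512^-0.5 * 10000^-0.5 ≈ 4.4e-4.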
191 |     def lambda_lr(s):
192 |         warm_up = args.warmup
193 |         s += 1
194 |         return (model.d_model ** -.5) * min(s ** -.5, s * warm_up ** -1.5)
195 | 
196 | 
197 |     # Initial conditions
198 |     optim = Adam(model.parameters(), lr=1, betas=(0.9, 0.98))
199 |     scheduler = LambdaLR(optim, lambda_lr)
200 |     loss_fn = NLLLoss(ignore_index=text_field.vocab.stoi['<pad>'])
201 |     use_rl = False
202 |     best_cider = .0
203 |     patience = 0
204 |     start_epoch = 0
205 | 
206 |     if args.resume_last or args.resume_best:
207 |         if args.resume_last:
208 |             fname = 'saved_models/%s_last.pth' % args.exp_name
209 |         else:
210 |             fname = 'saved_models/%s_best.pth' % args.exp_name
211 | 
212 |         if os.path.exists(fname):
213 |             data = torch.load(fname)
214 |             torch.set_rng_state(data['torch_rng_state'])
215 |             torch.cuda.set_rng_state(data['cuda_rng_state'])
216 |             np.random.set_state(data['numpy_rng_state'])
217 |             random.setstate(data['random_rng_state'])
218 |             model.load_state_dict(data['state_dict'], strict=False)
219 |             optim.load_state_dict(data['optimizer'])
220 |             scheduler.load_state_dict(data['scheduler'])
221 |             start_epoch = data['epoch'] + 1
222 |             best_cider = data['best_cider']
223 |             patience = data['patience']
224 |             use_rl = data['use_rl']
225 |             print('Resuming from epoch %d, validation loss %f, and best cider %f' % (
226 |                 data['epoch'], data['val_loss'], data['best_cider']))
227 | 
228 |     print("Training starts")
229 |     for e in range(start_epoch, start_epoch + 100):
230 |         dataloader_train = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True, num_workers=args.workers,
231 |                                       drop_last=True)
232 |         dataloader_val = DataLoader(val_dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.workers)
233 |         dict_dataloader_train = DataLoader(dict_dataset_train, batch_size=args.batch_size // 5, shuffle=True,
234 |                                            num_workers=args.workers)
235 |         dict_dataloader_val = DataLoader(dict_dataset_val, batch_size=args.batch_size // 5)
236 |         dict_dataloader_test = DataLoader(dict_dataset_test, batch_size=args.batch_size // 5)
237 | 
238 |         if not use_rl:
239 |             train_loss = train_xe(model, dataloader_train, optim, text_field)
240 |             writer.add_scalar('data/train_loss', train_loss, e)
241 |         else:
242 |             train_loss, reward, reward_baseline = train_scst(model, dict_dataloader_train, optim, cider_train, text_field)
243 |             writer.add_scalar('data/train_loss', train_loss, e)
244 |             writer.add_scalar('data/reward', reward, e)
245 |             writer.add_scalar('data/reward_baseline', reward_baseline, e)
246 | 
247 |         # Validation loss
248 |         val_loss = evaluate_loss(model, dataloader_val, loss_fn, text_field)
249 |         writer.add_scalar('data/val_loss', val_loss, e)
250 | 
251 |         # Validation scores
252 |         scores = evaluate_metrics(model, dict_dataloader_val, text_field)
253 |         print("Validation scores", scores)
254 |         val_cider = scores['CIDEr']
255 |         writer.add_scalar('data/val_cider', val_cider, e)
256 |         writer.add_scalar('data/val_bleu1', scores['BLEU'][0], e)
257 |         writer.add_scalar('data/val_bleu4', scores['BLEU'][3], e)
258 |         writer.add_scalar('data/val_meteor', scores['METEOR'], e)
259 |         writer.add_scalar('data/val_rouge', scores['ROUGE'], e)
260 | 
261 |         # Test scores
262 |         scores = evaluate_metrics(model, dict_dataloader_test, text_field)
263 |         print("Test scores", scores)
264 |         writer.add_scalar('data/test_cider', scores['CIDEr'], e)
265 |         writer.add_scalar('data/test_bleu1', scores['BLEU'][0], e)
266 |         writer.add_scalar('data/test_bleu4', scores['BLEU'][3], e)
267 |         writer.add_scalar('data/test_meteor', scores['METEOR'], e)
268 | 
writer.add_scalar('data/test_rouge', scores['ROUGE'], e) 269 | 270 | # Prepare for next epoch 271 | best = False 272 | if val_cider >= best_cider: 273 | best_cider = val_cider 274 | patience = 0 275 | best = True 276 | else: 277 | patience += 1 278 | 279 | switch_to_rl = False 280 | exit_train = False 281 | if patience == 5: 282 | if not use_rl: 283 | use_rl = True 284 | switch_to_rl = True 285 | patience = 0 286 | optim = Adam(model.parameters(), lr=5e-6) 287 | print("Switching to RL") 288 | else: 289 | print('patience reached.') 290 | exit_train = True 291 | 292 | if switch_to_rl and not best: 293 | data = torch.load('saved_models/%s_best.pth' % args.exp_name) 294 | torch.set_rng_state(data['torch_rng_state']) 295 | torch.cuda.set_rng_state(data['cuda_rng_state']) 296 | np.random.set_state(data['numpy_rng_state']) 297 | random.setstate(data['random_rng_state']) 298 | model.load_state_dict(data['state_dict']) 299 | print('Resuming from epoch %d, validation loss %f, and best cider %f' % ( 300 | data['epoch'], data['val_loss'], data['best_cider'])) 301 | 302 | torch.save({ 303 | 'torch_rng_state': torch.get_rng_state(), 304 | 'cuda_rng_state': torch.cuda.get_rng_state(), 305 | 'numpy_rng_state': np.random.get_state(), 306 | 'random_rng_state': random.getstate(), 307 | 'epoch': e, 308 | 'val_loss': val_loss, 309 | 'val_cider': val_cider, 310 | 'state_dict': model.state_dict(), 311 | 'optimizer': optim.state_dict(), 312 | 'scheduler': scheduler.state_dict(), 313 | 'patience': patience, 314 | 'best_cider': best_cider, 315 | 'use_rl': use_rl, 316 | }, 'saved_models/%s_last.pth' % args.exp_name) 317 | 318 | if best: 319 | copyfile('saved_models/%s_last.pth' % args.exp_name, 'saved_models/%s_best.pth' % args.exp_name) 320 | 321 | if exit_train: 322 | writer.close() 323 | break 324 | -------------------------------------------------------------------------------- /data/vocab.py: -------------------------------------------------------------------------------- 1 | from __future__ import unicode_literals 2 | import array 3 | from collections import defaultdict 4 | from functools import partial 5 | import io 6 | import logging 7 | import os 8 | import zipfile 9 | 10 | import six 11 | from six.moves.urllib.request import urlretrieve 12 | import torch 13 | from tqdm import tqdm 14 | import tarfile 15 | 16 | from .utils import reporthook 17 | 18 | logger = logging.getLogger(__name__) 19 | 20 | 21 | class Vocab(object): 22 | """Defines a vocabulary object that will be used to numericalize a field. 23 | 24 | Attributes: 25 | freqs: A collections.Counter object holding the frequencies of tokens 26 | in the data used to build the Vocab. 27 | stoi: A collections.defaultdict instance mapping token strings to 28 | numerical identifiers. 29 | itos: A list of token strings indexed by their numerical identifiers. 30 | """ 31 | def __init__(self, counter, max_size=None, min_freq=1, specials=['<pad>'], 32 | vectors=None, unk_init=None, vectors_cache=None): 33 | """Create a Vocab object from a collections.Counter. 34 | 35 | Arguments: 36 | counter: collections.Counter object holding the frequencies of 37 | each value found in the data. 38 | max_size: The maximum size of the vocabulary, or None for no 39 | maximum. Default: None. 40 | min_freq: The minimum frequency needed to include a token in the 41 | vocabulary. Values less than 1 will be set to 1. Default: 1. 42 | specials: The list of special tokens (e.g., padding or eos) that 43 | will be prepended to the vocabulary in addition to an <unk> 44 | token. 
Default: ['<pad>'] 45 | vectors: One of either the available pretrained vectors 46 | or custom pretrained vectors (see Vocab.load_vectors); 47 | or a list of aforementioned vectors 48 | unk_init (callback): by default, initialize out-of-vocabulary word vectors 49 | to zero vectors; can be any function that takes in a Tensor and 50 | returns a Tensor of the same size. Default: torch.Tensor.zero_ 51 | vectors_cache: directory for cached vectors. Default: '.vector_cache' 52 | """ 53 | self.freqs = counter 54 | counter = counter.copy() 55 | min_freq = max(min_freq, 1) 56 | 57 | self.itos = list(specials) 58 | # frequencies of special tokens are not counted when building vocabulary 59 | # in frequency order 60 | for tok in specials: 61 | del counter[tok] 62 | 63 | max_size = None if max_size is None else max_size + len(self.itos) 64 | 65 | # sort by frequency, then alphabetically 66 | words_and_frequencies = sorted(counter.items(), key=lambda tup: tup[0]) 67 | words_and_frequencies.sort(key=lambda tup: tup[1], reverse=True) 68 | 69 | for word, freq in words_and_frequencies: 70 | if freq < min_freq or len(self.itos) == max_size: 71 | break 72 | self.itos.append(word) 73 | 74 | self.stoi = defaultdict(_default_unk_index) 75 | # stoi is simply a reverse dict for itos 76 | self.stoi.update({tok: i for i, tok in enumerate(self.itos)}) 77 | 78 | self.vectors = None 79 | if vectors is not None: 80 | self.load_vectors(vectors, unk_init=unk_init, cache=vectors_cache) 81 | else: 82 | assert unk_init is None and vectors_cache is None 83 | 84 | def __eq__(self, other): 85 | if self.freqs != other.freqs: 86 | return False 87 | if self.stoi != other.stoi: 88 | return False 89 | if self.itos != other.itos: 90 | return False 91 | if self.vectors != other.vectors: 92 | return False 93 | return True 94 | 95 | def __len__(self): 96 | return len(self.itos) 97 | 98 | def extend(self, v, sort=False): 99 | words = sorted(v.itos) if sort else v.itos 100 | for w in words: 101 | if w not in self.stoi: 102 | self.itos.append(w) 103 | self.stoi[w] = len(self.itos) - 1 104 | 105 | def load_vectors(self, vectors, **kwargs): 106 | """ 107 | Arguments: 108 | vectors: one of or a list containing instantiations of the 109 | GloVe, CharNGram, or Vectors classes. Alternatively, one 110 | of or a list of available pretrained vectors: 111 | charngram.100d 112 | fasttext.en.300d 113 | fasttext.simple.300d 114 | glove.42B.300d 115 | glove.840B.300d 116 | glove.twitter.27B.25d 117 | glove.twitter.27B.50d 118 | glove.twitter.27B.100d 119 | glove.twitter.27B.200d 120 | glove.6B.50d 121 | glove.6B.100d 122 | glove.6B.200d 123 | glove.6B.300d 124 | Remaining keyword arguments: Passed to the constructor of Vectors classes. 
125 | """ 126 | if not isinstance(vectors, list): 127 | vectors = [vectors] 128 | for idx, vector in enumerate(vectors): 129 | if six.PY2 and isinstance(vector, str): 130 | vector = six.text_type(vector) 131 | if isinstance(vector, six.string_types): 132 | # Convert the string pretrained vector identifier 133 | # to a Vectors object 134 | if vector not in pretrained_aliases: 135 | raise ValueError( 136 | "Got string input vector {}, but allowed pretrained " 137 | "vectors are {}".format( 138 | vector, list(pretrained_aliases.keys()))) 139 | vectors[idx] = pretrained_aliases[vector](**kwargs) 140 | elif not isinstance(vector, Vectors): 141 | raise ValueError( 142 | "Got input vectors of type {}, expected str or " 143 | "Vectors object".format(type(vector))) 144 | 145 | tot_dim = sum(v.dim for v in vectors) 146 | self.vectors = torch.Tensor(len(self), tot_dim) 147 | for i, token in enumerate(self.itos): 148 | start_dim = 0 149 | for v in vectors: 150 | end_dim = start_dim + v.dim 151 | self.vectors[i][start_dim:end_dim] = v[token.strip()] 152 | start_dim = end_dim 153 | assert(start_dim == tot_dim) 154 | 155 | def set_vectors(self, stoi, vectors, dim, unk_init=torch.Tensor.zero_): 156 | """ 157 | Set the vectors for the Vocab instance from a collection of Tensors. 158 | 159 | Arguments: 160 | stoi: A dictionary of string to the index of the associated vector 161 | in the `vectors` input argument. 162 | vectors: An indexed iterable (or other structure supporting __getitem__) that 163 | given an input index, returns a FloatTensor representing the vector 164 | for the token associated with the index. For example, 165 | vector[stoi["string"]] should return the vector for "string". 166 | dim: The dimensionality of the vectors. 167 | unk_init (callback): by default, initialize out-of-vocabulary word vectors 168 | to zero vectors; can be any function that takes in a Tensor and 169 | returns a Tensor of the same size. 
Default: torch.Tensor.zero_ 170 | """ 171 | self.vectors = torch.Tensor(len(self), dim) 172 | for i, token in enumerate(self.itos): 173 | wv_index = stoi.get(token, None) 174 | if wv_index is not None: 175 | self.vectors[i] = vectors[wv_index] 176 | else: 177 | self.vectors[i] = unk_init(self.vectors[i]) 178 | 179 | 180 | class Vectors(object): 181 | 182 | def __init__(self, name, cache=None, 183 | url=None, unk_init=None): 184 | """ 185 | Arguments: 186 | name: name of the file that contains the vectors 187 | cache: directory for cached vectors 188 | url: url for download if vectors not found in cache 189 | unk_init (callback): by default, initialize out-of-vocabulary word vectors 190 | to zero vectors; can be any function that takes in a Tensor and 191 | returns a Tensor of the same size 192 | """ 193 | cache = '.vector_cache' if cache is None else cache 194 | self.unk_init = torch.Tensor.zero_ if unk_init is None else unk_init 195 | self.cache(name, cache, url=url) 196 | 197 | def __getitem__(self, token): 198 | if token in self.stoi: 199 | return self.vectors[self.stoi[token]] 200 | else: 201 | return self.unk_init(torch.Tensor(self.dim)) # self.unk_init(torch.Tensor(1, self.dim)) 202 | 203 | def cache(self, name, cache, url=None): 204 | if os.path.isfile(name): 205 | path = name 206 | path_pt = os.path.join(cache, os.path.basename(name)) + '.pt' 207 | else: 208 | path = os.path.join(cache, name) 209 | path_pt = path + '.pt' 210 | 211 | if not os.path.isfile(path_pt): 212 | if not os.path.isfile(path) and url: 213 | logger.info('Downloading vectors from {}'.format(url)) 214 | if not os.path.exists(cache): 215 | os.makedirs(cache) 216 | dest = os.path.join(cache, os.path.basename(url)) 217 | if not os.path.isfile(dest): 218 | with tqdm(unit='B', unit_scale=True, miniters=1, desc=dest) as t: 219 | try: 220 | urlretrieve(url, dest, reporthook=reporthook(t)) 221 | except KeyboardInterrupt as e: # remove the partial zip file 222 | os.remove(dest) 223 | raise e 224 | logger.info('Extracting vectors into {}'.format(cache)) 225 | ext = os.path.splitext(dest)[1][1:] 226 | if ext == 'zip': 227 | with zipfile.ZipFile(dest, "r") as zf: 228 | zf.extractall(cache) 229 | elif ext == 'gz': 230 | with tarfile.open(dest, 'r:gz') as tar: 231 | tar.extractall(path=cache) 232 | if not os.path.isfile(path): 233 | raise RuntimeError('no vectors found at {}'.format(path)) 234 | 235 | # str call is necessary for Python 2/3 compatibility, since 236 | # argument must be Python 2 str (Python 3 bytes) or 237 | # Python 3 str (Python 2 unicode) 238 | itos, vectors, dim = [], array.array(str('d')), None 239 | 240 | # Try to read the whole file with utf-8 encoding. 241 | binary_lines = False 242 | try: 243 | with io.open(path, encoding="utf8") as f: 244 | lines = [line for line in f] 245 | # If there are malformed lines, read in binary mode 246 | # and manually decode each word from utf-8 247 | except: 248 | logger.warning("Could not read {} as UTF8 file, " 249 | "reading file as bytes and skipping " 250 | "words with malformed UTF8.".format(path)) 251 | with open(path, 'rb') as f: 252 | lines = [line for line in f] 253 | binary_lines = True 254 | 255 | logger.info("Loading vectors from {}".format(path)) 256 | for line in tqdm(lines, total=len(lines)): 257 | # Explicitly splitting on " " is important, so we don't 258 | # get rid of Unicode non-breaking spaces in the vectors. 
259 | entries = line.rstrip().split(b" " if binary_lines else " ") 260 | 261 | word, entries = entries[0], entries[1:] 262 | if dim is None and len(entries) > 1: 263 | dim = len(entries) 264 | elif len(entries) == 1: 265 | logger.warning("Skipping token {} with 1-dimensional " 266 | "vector {}; likely a header".format(word, entries)) 267 | continue 268 | elif dim != len(entries): 269 | raise RuntimeError( 270 | "Vector for token {} has {} dimensions, but previously " 271 | "read vectors have {} dimensions. All vectors must have " 272 | "the same number of dimensions.".format(word, len(entries), dim)) 273 | 274 | if binary_lines: 275 | try: 276 | if isinstance(word, six.binary_type): 277 | word = word.decode('utf-8') 278 | except: 279 | logger.info("Skipping non-UTF8 token {}".format(repr(word))) 280 | continue 281 | vectors.extend(float(x) for x in entries) 282 | itos.append(word) 283 | 284 | self.itos = itos 285 | self.stoi = {word: i for i, word in enumerate(itos)} 286 | self.vectors = torch.Tensor(vectors).view(-1, dim) 287 | self.dim = dim 288 | logger.info('Saving vectors to {}'.format(path_pt)) 289 | if not os.path.exists(cache): 290 | os.makedirs(cache) 291 | torch.save((self.itos, self.stoi, self.vectors, self.dim), path_pt) 292 | else: 293 | logger.info('Loading vectors from {}'.format(path_pt)) 294 | self.itos, self.stoi, self.vectors, self.dim = torch.load(path_pt) 295 | 296 | 297 | class GloVe(Vectors): 298 | url = { 299 | '42B': 'http://nlp.stanford.edu/data/glove.42B.300d.zip', 300 | '840B': 'http://nlp.stanford.edu/data/glove.840B.300d.zip', 301 | 'twitter.27B': 'http://nlp.stanford.edu/data/glove.twitter.27B.zip', 302 | '6B': 'http://nlp.stanford.edu/data/glove.6B.zip', 303 | } 304 | 305 | def __init__(self, name='840B', dim=300, **kwargs): 306 | url = self.url[name] 307 | name = 'glove.{}.{}d.txt'.format(name, str(dim)) 308 | super(GloVe, self).__init__(name, url=url, **kwargs) 309 | 310 | 311 | class FastText(Vectors): 312 | 313 | url_base = 'https://s3-us-west-1.amazonaws.com/fasttext-vectors/wiki.{}.vec' 314 | 315 | def __init__(self, language="en", **kwargs): 316 | url = self.url_base.format(language) 317 | name = os.path.basename(url) 318 | super(FastText, self).__init__(name, url=url, **kwargs) 319 | 320 | 321 | class CharNGram(Vectors): 322 | 323 | name = 'charNgram.txt' 324 | url = ('http://www.logos.t.u-tokyo.ac.jp/~hassy/publications/arxiv2016jmt/' 325 | 'jmt_pre-trained_embeddings.tar.gz') 326 | 327 | def __init__(self, **kwargs): 328 | super(CharNGram, self).__init__(self.name, url=self.url, **kwargs) 329 | 330 | def __getitem__(self, token): 331 | vector = torch.Tensor(1, self.dim).zero_() 332 | if token == "<unk>": 333 | return self.unk_init(vector) 334 | # These literals need to be coerced to unicode for Python 2 compatibility 335 | # when we try to join them with read ngrams from the files. 
336 | chars = ['#BEGIN#'] + list(token) + ['#END#'] 337 | num_vectors = 0 338 | for n in [2, 3, 4]: 339 | end = len(chars) - n + 1 340 | grams = [chars[i:(i + n)] for i in range(end)] 341 | for gram in grams: 342 | gram_key = '{}gram-{}'.format(n, ''.join(gram)) 343 | if gram_key in self.stoi: 344 | vector += self.vectors[self.stoi[gram_key]] 345 | num_vectors += 1 346 | if num_vectors > 0: 347 | vector /= num_vectors 348 | else: 349 | vector = self.unk_init(vector) 350 | return vector 351 | 352 | 353 | def _default_unk_index(): 354 | return 0 355 | 356 | 357 | pretrained_aliases = { 358 | "charngram.100d": partial(CharNGram), 359 | "fasttext.en.300d": partial(FastText, language="en"), 360 | "fasttext.simple.300d": partial(FastText, language="simple"), 361 | "glove.42B.300d": partial(GloVe, name="42B", dim="300"), 362 | "glove.840B.300d": partial(GloVe, name="840B", dim="300"), 363 | "glove.twitter.27B.25d": partial(GloVe, name="twitter.27B", dim="25"), 364 | "glove.twitter.27B.50d": partial(GloVe, name="twitter.27B", dim="50"), 365 | "glove.twitter.27B.100d": partial(GloVe, name="twitter.27B", dim="100"), 366 | "glove.twitter.27B.200d": partial(GloVe, name="twitter.27B", dim="200"), 367 | "glove.6B.50d": partial(GloVe, name="6B", dim="50"), 368 | "glove.6B.100d": partial(GloVe, name="6B", dim="100"), 369 | "glove.6B.200d": partial(GloVe, name="6B", dim="200"), 370 | "glove.6B.300d": partial(GloVe, name="6B", dim="300") 371 | } 372 | """Mapping from string name to factory function""" 373 | --------------------------------------------------------------------------------
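A brief usage sketch (not part of the repository) of the Vocab class listed above: specials are stored first in itos, the remaining tokens follow in frequency order, and unseen tokens map to index 0 via _default_unk_index. The tokens, counts, and import path below are illustrative assumptions only; in this project the vocabulary is actually built through text_field.build_vocab(...) in train.py.

from collections import Counter
from data.vocab import Vocab  # hypothetical import path when running from the repository root

# Toy counter standing in for token frequencies collected from the captions
counter = Counter(['a', 'cat', 'on', 'a', 'mat', 'cat'])
vocab = Vocab(counter, min_freq=1, specials=['<unk>', '<pad>'])

print(vocab.itos[:4])       # ['<unk>', '<pad>', 'a', 'cat'] -- specials first, then frequency order
print(vocab.stoi['cat'])    # 3
print(vocab.stoi['zebra'])  # 0 -- unknown tokens fall back to _default_unk_index()

# Pretrained vectors could be attached the same way (note: downloads GloVe on first use):
# vocab_with_vectors = Vocab(counter, specials=['<unk>', '<pad>'], vectors=GloVe(name='6B', dim=50))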