├── common ├── models │ ├── backbone │ │ └── __init__.py │ ├── beam_search │ │ ├── __init__.py │ │ └── beam_search.py │ ├── transformer │ │ ├── __init__.py │ │ ├── PolarRPE.py │ │ ├── transformer.py │ │ ├── encoders.py │ │ ├── decoders.py │ │ └── utils.py │ ├── __init__.py │ ├── captioning_model.py │ └── containers.py ├── evaluation │ ├── bleu │ │ ├── __init__.py │ │ ├── bleu.py │ │ └── bleu_scorer.py │ ├── cider │ │ ├── __init__.py │ │ ├── cider.py │ │ └── cider_scorer.py │ ├── meteor │ │ ├── __init__.py │ │ └── meteor.py │ ├── rouge │ │ ├── __init__.py │ │ └── rouge.py │ ├── spice │ │ ├── __init__.py │ │ └── spice.py │ ├── tokenizer │ │ ├── __init__.py │ │ ├── stanford-corenlp-3.4.1.jar │ │ └── ptbtokenizer.py │ ├── __init__.py │ └── eval.py ├── utils │ ├── typing.py │ ├── __init__.py │ ├── tsv_file.py │ └── utils.py ├── data │ ├── __init__.py │ ├── example.py │ ├── utils.py │ ├── dataset.py │ ├── vocab.py │ └── field.py ├── visualization.py ├── online_test.py ├── test.py └── train.py ├── cache └── vocab.pkl ├── models ├── __init__.py ├── transformer.py ├── encoder_gce.py └── decoder_cca.py ├── .gitignore ├── tools ├── extract_clip_feature.py └── transform_vinvl_feature.py ├── train.py ├── README.md ├── requirements.txt └── environment.yml /common/models/backbone/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /common/evaluation/bleu/__init__.py: -------------------------------------------------------------------------------- 1 | from .bleu import Bleu -------------------------------------------------------------------------------- /common/evaluation/cider/__init__.py: -------------------------------------------------------------------------------- 1 | from .cider import Cider -------------------------------------------------------------------------------- /common/evaluation/meteor/__init__.py: -------------------------------------------------------------------------------- 1 | from .meteor import Meteor -------------------------------------------------------------------------------- /common/evaluation/rouge/__init__.py: -------------------------------------------------------------------------------- 1 | from .rouge import Rouge -------------------------------------------------------------------------------- /common/evaluation/spice/__init__.py: -------------------------------------------------------------------------------- 1 | from .spice import Spice -------------------------------------------------------------------------------- /common/evaluation/tokenizer/__init__.py: -------------------------------------------------------------------------------- 1 | from .ptbtokenizer import PTBTokenizer -------------------------------------------------------------------------------- /common/models/beam_search/__init__.py: -------------------------------------------------------------------------------- 1 | from .beam_search import BeamSearch 2 | -------------------------------------------------------------------------------- /cache/vocab.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/weimingboya/DFT/HEAD/cache/vocab.pkl -------------------------------------------------------------------------------- /common/models/transformer/__init__.py: -------------------------------------------------------------------------------- 1 | from .transformer import * 2 | from .encoders import * 3 | from 
.decoders import * 4 | from .attention import * 5 | -------------------------------------------------------------------------------- /common/evaluation/tokenizer/stanford-corenlp-3.4.1.jar: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/weimingboya/DFT/HEAD/common/evaluation/tokenizer/stanford-corenlp-3.4.1.jar -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- 1 | from .encoder_gce import build_encoder 2 | from .decoder_cca import build_decoder 3 | from .transformer import Transformer, TransformerEnsemble -------------------------------------------------------------------------------- /common/models/__init__.py: -------------------------------------------------------------------------------- 1 | from .transformer import Transformer 2 | from .captioning_model import CaptioningModel 3 | from .containers import Module, ModuleList, ModuleDict -------------------------------------------------------------------------------- /common/utils/typing.py: -------------------------------------------------------------------------------- 1 | from typing import Union, Sequence, Tuple 2 | import torch 3 | 4 | TensorOrSequence = Union[Sequence[torch.Tensor], torch.Tensor] 5 | TensorOrNone = Union[torch.Tensor, None] 6 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | coco/* 2 | cache/cider_cache.pkl 3 | common/evaluation/meteor/data 4 | common/evaluation/spice/cache 5 | common/evaluation/spice/lib 6 | tensorboard_logs 7 | __pycache__ 8 | .DS_Store 9 | tmp* 10 | *.jar 11 | *.tgz 12 | *.pth 13 | !stanford-corenlp-3.4.1.jar 14 | -------------------------------------------------------------------------------- /common/data/__init__.py: -------------------------------------------------------------------------------- 1 | from .field import * 2 | from .dataset import COCODataset 3 | from torch.utils.data import DataLoader as TorchDataLoader 4 | 5 | class DataLoader(TorchDataLoader): 6 | def __init__(self, dataset, *args, **kwargs): 7 | super(DataLoader, self).__init__(dataset, *args, collate_fn=dataset.collate_fn(), **kwargs) 8 | -------------------------------------------------------------------------------- /common/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .utils import * 2 | from .typing import * 3 | 4 | def get_batch_size(x: TensorOrSequence) -> int: 5 | if isinstance(x, Sequence): 6 | b_s = x[0].size(0) 7 | else: 8 | b_s = x.size(0) 9 | return b_s 10 | 11 | 12 | def get_device(x: TensorOrSequence) -> int: 13 | if isinstance(x, Sequence): 14 | device = x[0].device 15 | else: 16 | device = x.device 17 | return device 18 | -------------------------------------------------------------------------------- /common/evaluation/__init__.py: -------------------------------------------------------------------------------- 1 | from .bleu import Bleu 2 | from .meteor import Meteor 3 | from .rouge import Rouge 4 | from .cider import Cider 5 | from .spice import Spice 6 | from .tokenizer import PTBTokenizer 7 | from .eval import COCOEvalCap 8 | 9 | def compute_scores(gts, gen): 10 | metrics = (Bleu(), Meteor(), Rouge(), Cider(), Spice()) 11 | # metrics = (Bleu(), Rouge(), Cider()) 12 | all_score = {} 13 | all_scores = {} 14 | for 
metric in metrics: 15 | score, scores = metric.compute_score(gts, gen) 16 | all_score[str(metric)] = score 17 | all_scores[str(metric)] = scores 18 | 19 | return all_score, all_scores 20 | -------------------------------------------------------------------------------- /common/data/example.py: -------------------------------------------------------------------------------- 1 | 2 | class Example(object): 3 | """Defines a single training or test example. 4 | Stores each column of the example as an attribute. 5 | """ 6 | @classmethod 7 | def fromdict(cls, data): 8 | ex = cls(data) 9 | return ex 10 | 11 | def __init__(self, data): 12 | for key, val in data.items(): 13 | super(Example, self).__setattr__(key, val) 14 | 15 | def __setattr__(self, key, value): 16 | raise AttributeError 17 | 18 | def __hash__(self): 19 | return hash(tuple(x['image_id'] if isinstance(x, dict) else x for x in self.__dict__.values())) 20 | 21 | def __eq__(self, other): 22 | this = tuple(x for x in self.__dict__.values()) 23 | other = tuple(x for x in other.__dict__.values()) 24 | return this == other 25 | 26 | def __ne__(self, other): 27 | return not self.__eq__(other) 28 | -------------------------------------------------------------------------------- /common/evaluation/bleu/bleu.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # 3 | # File Name : bleu.py 4 | # 5 | # Description : Wrapper for BLEU scorer. 6 | # 7 | # Creation Date : 06-01-2015 8 | # Last Modified : Thu 19 Mar 2015 09:13:28 PM PDT 9 | # Authors : Hao Fang and Tsung-Yi Lin 10 | 11 | from .bleu_scorer import BleuScorer 12 | 13 | 14 | class Bleu: 15 | def __init__(self, n=4): 16 | # default compute Blue score up to 4 17 | self._n = n 18 | self._hypo_for_image = {} 19 | self.ref_for_image = {} 20 | 21 | def compute_score(self, gts, res): 22 | 23 | assert(gts.keys() == res.keys()) 24 | imgIds = gts.keys() 25 | 26 | bleu_scorer = BleuScorer(n=self._n) 27 | for id in imgIds: 28 | hypo = res[id] 29 | ref = gts[id] 30 | 31 | # Sanity check. 
32 | assert(type(hypo) is list) 33 | assert(len(hypo) == 1) 34 | assert(type(ref) is list) 35 | assert(len(ref) >= 1) 36 | 37 | bleu_scorer += (hypo[0], ref) 38 | 39 | # score, scores = bleu_scorer.compute_score(option='shortest') 40 | score, scores = bleu_scorer.compute_score(option='closest', verbose=0) 41 | # score, scores = bleu_scorer.compute_score(option='average', verbose=1) 42 | 43 | return score, scores 44 | 45 | def __str__(self): 46 | return 'BLEU' 47 | -------------------------------------------------------------------------------- /tools/extract_clip_feature.py: -------------------------------------------------------------------------------- 1 | import h5py 2 | import os 3 | import torch 4 | from tqdm import tqdm 5 | from PIL import Image 6 | from pycocotools.coco import COCO 7 | from common.models.backbone import clip 8 | 9 | def extract_feature(clip_variant, device): 10 | image_dir = 'coco/images/%s2014' 11 | base_path = 'coco/annotations/captions_%s2014.json' 12 | 13 | save_path = os.path.join('coco/features', 'COCO2014_%s.hdf5' % clip_variant) 14 | f = h5py.File(save_path, mode='w') 15 | 16 | clip_model, transform = clip.load(clip_variant, device=device ,jit=False) 17 | image_model = clip_model.visual.to(device).eval() 18 | image_model.forward = image_model.intermediate_features 19 | 20 | for split in ['train', 'val']: 21 | ann_path = base_path % split 22 | coco = COCO(ann_path) 23 | 24 | with torch.no_grad(): 25 | for img_id, img in tqdm(coco.imgs.items(), split): 26 | image_path = os.path.join(image_dir % split ,img['file_name']) 27 | 28 | image = Image.open(image_path).convert('RGB') 29 | image = transform(image) 30 | 31 | image = image.to(device).unsqueeze(0) 32 | gird, x = image_model.forward2(image) 33 | 34 | gird = gird.squeeze(0).cpu().numpy() 35 | x = x.squeeze(0).cpu().numpy() 36 | f.create_dataset('%s_features' % img_id, data=gird) 37 | f.create_dataset('%s_global' % img_id, data=x) 38 | 39 | f.close() 40 | 41 | if __name__=='__main__': 42 | clip_variant = 'RN50x4' 43 | device = 'cuda:0' 44 | extract_feature(clip_variant, device) -------------------------------------------------------------------------------- /common/evaluation/cider/cider.py: -------------------------------------------------------------------------------- 1 | # Filename: cider.py 2 | # 3 | # Description: Describes the class to compute the CIDEr (Consensus-Based Image Description Evaluation) Metric 4 | # by Vedantam, Zitnick, and Parikh (http://arxiv.org/abs/1411.5726) 5 | # 6 | # Creation Date: Sun Feb 8 14:16:54 2015 7 | # 8 | # Authors: Ramakrishna Vedantam and Tsung-Yi Lin 9 | 10 | import six 11 | from six.moves import cPickle 12 | from .cider_scorer import CiderScorer 13 | import numpy as np 14 | class Cider: 15 | """ 16 | Main Class to compute the CIDEr metric 17 | 18 | """ 19 | def __init__(self, gts=None, n=4, sigma=6.0, df_path=None): 20 | # set cider to sum over 1 to 4-grams 21 | self._n = n 22 | # set the standard deviation parameter for gaussian penalty 23 | self._sigma = sigma 24 | self.doc_frequency = None 25 | self.ref_len = None 26 | 27 | if df_path is not None: 28 | pkl_file = cPickle.load(open(df_path,'rb'), **(dict(encoding='latin1') if six.PY3 else {})) 29 | self.doc_frequency = pkl_file['document_frequency'] 30 | self.ref_len = np.log(float(pkl_file['ref_len'])) 31 | elif gts is not None: 32 | tmp_cider = CiderScorer(gts, n=self._n, sigma=self._sigma) 33 | self.doc_frequency = tmp_cider.doc_frequency 34 | self.ref_len = tmp_cider.ref_len 35 | 36 | def compute_score(self, gts, res): 37 |
""" 38 | Main function to compute CIDEr score 39 | :param gts (dict) : dictionary with key and value 40 | res (dict) : dictionary with key and value 41 | :return: cider (float) : computed CIDEr score for the corpus 42 | """ 43 | assert(gts.keys() == res.keys()) 44 | cider_scorer = CiderScorer(gts, test=res, n=self._n, sigma=self._sigma, doc_frequency=self.doc_frequency, 45 | ref_len=self.ref_len) 46 | return cider_scorer.compute_score() 47 | 48 | def __str__(self): 49 | return 'CIDEr' 50 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from common.data.field import DualImageField, TextField 3 | from common.train import train 4 | from common.utils.utils import create_dataset 5 | from models import build_encoder, build_decoder, Transformer 6 | 7 | 8 | def parse_args(): 9 | parser = argparse.ArgumentParser(description='Dual Transformer') 10 | parser.add_argument('--output', type=str, default='DualModel') 11 | parser.add_argument('--exp_name', type=str, default='dft') 12 | parser.add_argument('--device', type=str, default='cuda:0') 13 | parser.add_argument('--batch_size', type=int, default=20) 14 | parser.add_argument('--workers', type=int, default=8) 15 | 16 | parser.add_argument('--N_enc', type=int, default=3) 17 | parser.add_argument('--N_dec', type=int, default=3) 18 | 19 | parser.add_argument('--xe_base_lr', type=float, default=1e-4) 20 | parser.add_argument('--rl_base_lr', type=float, default=5e-6) 21 | parser.add_argument('--use_rl', action='store_true') 22 | parser.add_argument('--resume_last', action='store_true') 23 | parser.add_argument('--resume_best', action='store_true') 24 | parser.add_argument('--clip_path', type=str, default='coco/features/COCO2014_RN50x4_GLOBAL.hdf5') 25 | parser.add_argument('--vinvl_path', type=str, default='coco/features/COCO2014_VinVL.hdf5') 26 | parser.add_argument('--image_folder', type=str, default='coco/images') 27 | parser.add_argument('--annotation_folder', type=str, default='coco/annotations') 28 | args = parser.parse_args() 29 | print(args) 30 | 31 | return args 32 | 33 | 34 | def main(args): 35 | print('Dual Transformer Training') 36 | 37 | # Pipeline for image features 38 | image_field = DualImageField(args.clip_path, args.vinvl_path, max_detections=50) 39 | 40 | # Pipeline for text 41 | text_field = TextField(init_token='', eos_token='', lower=True, tokenize='spacy', 42 | remove_punctuation=True, nopoints=False) 43 | 44 | datasets = create_dataset(args, image_field, text_field) 45 | 46 | encoder = build_encoder(args.N_enc, device=args.device) 47 | decoder = build_decoder(len(text_field.vocab), 54, args.N_dec, text_field.vocab.stoi['']) 48 | model = Transformer(text_field.vocab.stoi[''], encoder, decoder).to(args.device) 49 | 50 | train(args, model, datasets, image_field, text_field) 51 | 52 | if __name__ == "__main__": 53 | args = parse_args() 54 | main(args) 55 | -------------------------------------------------------------------------------- /common/evaluation/tokenizer/ptbtokenizer.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # 3 | # File Name : ptbtokenizer.py 4 | # 5 | # Description : Do the PTB Tokenization and remove punctuations. 
6 | # 7 | # Creation Date : 29-12-2014 8 | # Last Modified : Thu Mar 19 09:53:35 2015 9 | # Authors : Hao Fang and Tsung-Yi Lin 10 | 11 | import os 12 | import subprocess 13 | import tempfile 14 | 15 | class PTBTokenizer(object): 16 | """Python wrapper of Stanford PTBTokenizer""" 17 | 18 | corenlp_jar = 'stanford-corenlp-3.4.1.jar' 19 | punctuations = ["''", "'", "``", "`", "-LRB-", "-RRB-", "-LCB-", "-RCB-", \ 20 | ".", "?", "!", ",", ":", "-", "--", "...", ";"] 21 | 22 | @classmethod 23 | def tokenize(cls, corpus): 24 | cmd = ['java', '-cp', cls.corenlp_jar, \ 25 | 'edu.stanford.nlp.process.PTBTokenizer', \ 26 | '-preserveLines', '-lowerCase'] 27 | 28 | if isinstance(corpus, list) or isinstance(corpus, tuple): 29 | if isinstance(corpus[0], list) or isinstance(corpus[0], tuple): 30 | corpus = {i:c for i, c in enumerate(corpus)} 31 | else: 32 | corpus = {i: [c, ] for i, c in enumerate(corpus)} 33 | 34 | # prepare data for PTB Tokenizer 35 | tokenized_corpus = {} 36 | image_id = [k for k, v in list(corpus.items()) for _ in range(len(v))] 37 | sentences = '\n'.join([c.replace('\n', ' ') for k, v in corpus.items() for c in v]) 38 | 39 | # save sentences to temporary file 40 | path_to_jar_dirname=os.path.dirname(os.path.abspath(__file__)) 41 | tmp_file = tempfile.NamedTemporaryFile(delete=False, dir=path_to_jar_dirname) 42 | tmp_file.write(sentences.encode()) 43 | tmp_file.close() 44 | 45 | # tokenize sentence 46 | cmd.append(os.path.basename(tmp_file.name)) 47 | p_tokenizer = subprocess.Popen(cmd, cwd=path_to_jar_dirname, \ 48 | stdout=subprocess.PIPE, stderr=open(os.devnull, 'w')) 49 | token_lines = p_tokenizer.communicate(input=sentences.rstrip())[0] 50 | token_lines = token_lines.decode() 51 | lines = token_lines.split('\n') 52 | # remove temp file 53 | os.remove(tmp_file.name) 54 | 55 | # create dictionary for tokenized captions 56 | for k, line in zip(image_id, lines): 57 | if not k in tokenized_corpus: 58 | tokenized_corpus[k] = [] 59 | tokenized_caption = ' '.join([w for w in line.rstrip().split(' ') \ 60 | if w not in cls.punctuations]) 61 | tokenized_corpus[k].append(tokenized_caption) 62 | 63 | return tokenized_corpus -------------------------------------------------------------------------------- /common/visualization.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import h5py 3 | import os, pickle 4 | from common.data.field import TextField 5 | from models import build_encoder, build_decoder, Transformer, TransformerEnsemble 6 | 7 | def visualize(model, dual_features, text_field, device): 8 | model.eval() 9 | with torch.no_grad(): 10 | dual_features = [x.to(device) for x in dual_features] 11 | out, _ = model.beam_search(dual_features, 20, text_field.vocab.stoi[''], 1, out_size=1) 12 | 13 | caps_gen = text_field.decode(out, join_words=False) 14 | caps_gen = ' '.join(caps_gen[0]).strip() 15 | return caps_gen 16 | 17 | def get_features_by_id(image_id, max_detections=50): 18 | clip_path = 'coco/features/COCO2014_RN50x4_GLOBAL.hdf5' 19 | vinvl_path = 'coco/features/COCO2014_VinVL.hdf5' 20 | clip_file = h5py.File(clip_path, 'r') 21 | vinvl_file = h5py.File(vinvl_path, 'r') 22 | 23 | feature_key = '%d_features' % image_id 24 | boxs_key = '%d_boxes' % image_id 25 | gird_feature = torch.from_numpy(clip_file[feature_key][()]) 26 | region_feature = torch.from_numpy(vinvl_file[feature_key][()]) 27 | boxes = torch.from_numpy(vinvl_file[boxs_key][()]) 28 | 29 | delta = max_detections - region_feature.shape[0] 30 | if delta > 0: 31 | 
region_feature = torch.cat([region_feature, torch.zeros((delta, region_feature.shape[1]))], 0) 32 | elif delta < 0: 33 | region_feature = region_feature[:max_detections] 34 | 35 | return gird_feature, region_feature, boxes 36 | 37 | def test(): 38 | image_id = 108982 39 | device = 'cuda:0' 40 | # model_path = 'coco/checkpoints/DualModel/dual_add_best.pth' 41 | # model_path = 'coco/checkpoints/DualModel/dual_fuse_global_best.pth' 42 | model_path = 'coco/checkpoints/DualModel/dual_fuse_global5_rl_best.pth' 43 | 44 | text_field = TextField(init_token='', eos_token='', lower=True, tokenize='spacy', 45 | remove_punctuation=True, nopoints=False) 46 | vocab_path = 'cache/vocab.pkl' 47 | text_field.vocab = pickle.load(open(vocab_path, 'rb')) 48 | 49 | encoder = build_encoder(3, device=device) 50 | decoder = build_decoder(len(text_field.vocab), 54, 3, text_field.vocab.stoi['']) 51 | model = Transformer(text_field.vocab.stoi[''], encoder, decoder).to(device) 52 | 53 | data = torch.load(model_path, map_location=device) 54 | model.load_state_dict(data['state_dict']) 55 | print(data['best_cider']) 56 | 57 | gird_feature, region_feature, boxes = get_features_by_id(image_id) 58 | dual_features = [gird_feature.unsqueeze(0), region_feature.unsqueeze(0) ] 59 | caps_gen = visualize(model, dual_features, text_field, device) 60 | print(caps_gen) 61 | 62 | -------------------------------------------------------------------------------- /common/evaluation/eval.py: -------------------------------------------------------------------------------- 1 | __author__ = 'tylin' 2 | from .bleu import Bleu 3 | # from .meteor import Meteor 4 | from .rouge import Rouge 5 | from .cider import Cider 6 | from .tokenizer import PTBTokenizer 7 | # from spice.spice import Spice 8 | 9 | class COCOEvalCap: 10 | def __init__(self, coco, cocoRes): 11 | self.evalImgs = [] 12 | self.eval = {} 13 | self.imgToEval = {} 14 | self.coco = coco 15 | self.cocoRes = cocoRes 16 | self.params = {'image_id': coco.getImgIds()} 17 | 18 | def evaluate(self): 19 | imgIds = self.params['image_id'] 20 | # imgIds = self.coco.getImgIds() 21 | gts = {} 22 | res = {} 23 | for imgId in imgIds: 24 | gts[imgId] = [item['caption'] for item in self.coco.imgToAnns[imgId]] 25 | res[imgId] = [item['caption'] for item in self.cocoRes.imgToAnns[imgId]] 26 | 27 | # ================================================= 28 | # Set up scorers 29 | # ================================================= 30 | tokenizer = PTBTokenizer() 31 | gts = tokenizer.tokenize(gts) 32 | res = tokenizer.tokenize(res) 33 | 34 | # ================================================= 35 | # Set up scorers 36 | # ================================================= 37 | scorers = [ 38 | (Bleu(4), ["Bleu_1", "Bleu_2", "Bleu_3", "Bleu_4"]), 39 | # (Meteor(),"METEOR"), 40 | (Rouge(), "ROUGE_L"), 41 | (Cider(), "CIDEr"), 42 | # (Spice(), "SPICE") 43 | ] 44 | 45 | # ================================================= 46 | # Compute scores 47 | # ================================================= 48 | for scorer, method in scorers: 49 | score, scores = scorer.compute_score(gts, res) 50 | if type(method) == list: 51 | for sc, scs, m in zip(score, scores, method): 52 | self.setEval(sc, m) 53 | self.setImgToEvalImgs(scs, gts.keys(), m) 54 | else: 55 | self.setEval(score, method) 56 | self.setImgToEvalImgs(scores, gts.keys(), method) 57 | self.setEvalImgs() 58 | 59 | def setEval(self, score, method): 60 | self.eval[method] = score 61 | 62 | def setImgToEvalImgs(self, scores, imgIds, method): 63 | for imgId, score in 
zip(imgIds, scores): 64 | if not imgId in self.imgToEval: 65 | self.imgToEval[imgId] = {} 66 | self.imgToEval[imgId]["image_id"] = imgId 67 | self.imgToEval[imgId][method] = score 68 | 69 | def setEvalImgs(self): 70 | self.evalImgs = [eval for imgId, eval in self.imgToEval.items()] -------------------------------------------------------------------------------- /common/models/captioning_model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import distributions 3 | import common.utils as utils 4 | from .containers import Module 5 | from .beam_search import * 6 | 7 | 8 | class CaptioningModel(Module): 9 | def __init__(self): 10 | super(CaptioningModel, self).__init__() 11 | 12 | def init_weights(self): 13 | raise NotImplementedError 14 | 15 | def step(self, t, prev_output, visual, seq, mode='teacher_forcing', **kwargs): 16 | raise NotImplementedError 17 | 18 | def forward(self, images, seq, *args): 19 | device = images.device 20 | b_s = images.size(0) 21 | seq_len = seq.size(1) 22 | state = self.init_state(b_s, device) 23 | out = None 24 | 25 | outputs = [] 26 | for t in range(seq_len): 27 | out, state = self.step(t, state, out, images, seq, *args, mode='teacher_forcing') 28 | outputs.append(out) 29 | 30 | outputs = torch.cat([o.unsqueeze(1) for o in outputs], 1) 31 | return outputs 32 | 33 | def test(self, visual: utils.TensorOrSequence, max_len: int, eos_idx: int, **kwargs) -> utils.Tuple[torch.Tensor, torch.Tensor]: 34 | b_s = utils.get_batch_size(visual) 35 | device = utils.get_device(visual) 36 | outputs = [] 37 | log_probs = [] 38 | 39 | mask = torch.ones((b_s,), device=device) 40 | with self.statefulness(b_s): 41 | out = None 42 | for t in range(max_len): 43 | log_probs_t = self.step(t, out, visual, None, mode='feedback', **kwargs) 44 | out = torch.max(log_probs_t, -1)[1] 45 | mask = mask * (out.squeeze(-1) != eos_idx).float() 46 | log_probs.append(log_probs_t * mask.unsqueeze(-1).unsqueeze(-1)) 47 | outputs.append(out) 48 | 49 | return torch.cat(outputs, 1), torch.cat(log_probs, 1) 50 | 51 | def sample_rl(self, visual: utils.TensorOrSequence, max_len: int, **kwargs) -> utils.Tuple[torch.Tensor, torch.Tensor]: 52 | b_s = utils.get_batch_size(visual) 53 | outputs = [] 54 | log_probs = [] 55 | 56 | with self.statefulness(b_s): 57 | out = None 58 | for t in range(max_len): 59 | out = self.step(t, out, visual, None, mode='feedback', **kwargs) 60 | distr = distributions.Categorical(logits=out[:, 0]) 61 | out = distr.sample().unsqueeze(1) 62 | outputs.append(out) 63 | log_probs.append(distr.log_prob(out).unsqueeze(1)) 64 | 65 | return torch.cat(outputs, 1), torch.cat(log_probs, 1) 66 | 67 | def beam_search(self, visual: utils.TensorOrSequence, max_len: int, eos_idx: int, beam_size: int, out_size=1, 68 | return_probs=False, **kwargs): 69 | bs = BeamSearch(self, max_len, eos_idx, beam_size) 70 | return bs.apply(visual, out_size, return_probs, **kwargs) 71 | -------------------------------------------------------------------------------- /common/models/containers.py: -------------------------------------------------------------------------------- 1 | from contextlib import contextmanager 2 | from torch import nn 3 | from ..utils.typing import * 4 | 5 | 6 | class Module(nn.Module): 7 | def __init__(self): 8 | super(Module, self).__init__() 9 | self._is_stateful = False 10 | self._state_names = [] 11 | self._state_defaults = dict() 12 | 13 | def register_state(self, name: str, default: TensorOrNone): 14 | 
self._state_names.append(name) 15 | if default is None: 16 | self._state_defaults[name] = None 17 | else: 18 | self._state_defaults[name] = default.clone().detach() 19 | self.register_buffer(name, default) 20 | 21 | def states(self): 22 | for name in self._state_names: 23 | yield self._buffers[name] 24 | for m in self.children(): 25 | if isinstance(m, Module): 26 | yield from m.states() 27 | 28 | def apply_to_states(self, fn): 29 | for name in self._state_names: 30 | self._buffers[name] = fn(self._buffers[name]) 31 | for m in self.children(): 32 | if isinstance(m, Module): 33 | m.apply_to_states(fn) 34 | 35 | def _init_states(self, batch_size: int): 36 | for name in self._state_names: 37 | if self._state_defaults[name] is None: 38 | self._buffers[name] = None 39 | else: 40 | self._buffers[name] = self._state_defaults[name].clone().detach().to(self._buffers[name].device) 41 | self._buffers[name] = self._buffers[name].unsqueeze(0) 42 | self._buffers[name] = self._buffers[name].expand([batch_size, ] + list(self._buffers[name].shape[1:])) 43 | self._buffers[name] = self._buffers[name].contiguous() 44 | 45 | def _reset_states(self): 46 | for name in self._state_names: 47 | if self._state_defaults[name] is None: 48 | self._buffers[name] = None 49 | else: 50 | self._buffers[name] = self._state_defaults[name].clone().detach().to(self._buffers[name].device) 51 | 52 | def enable_statefulness(self, batch_size: int): 53 | for m in self.children(): 54 | if isinstance(m, Module): 55 | m.enable_statefulness(batch_size) 56 | self._init_states(batch_size) 57 | self._is_stateful = True 58 | 59 | def disable_statefulness(self): 60 | for m in self.children(): 61 | if isinstance(m, Module): 62 | m.disable_statefulness() 63 | self._reset_states() 64 | self._is_stateful = False 65 | 66 | @contextmanager 67 | def statefulness(self, batch_size: int): 68 | self.enable_statefulness(batch_size) 69 | try: 70 | yield 71 | finally: 72 | self.disable_statefulness() 73 | 74 | 75 | class ModuleList(nn.ModuleList, Module): 76 | pass 77 | 78 | 79 | class ModuleDict(nn.ModuleDict, Module): 80 | pass 81 | -------------------------------------------------------------------------------- /common/evaluation/meteor/meteor.py: -------------------------------------------------------------------------------- 1 | # Python wrapper for METEOR implementation, by Xinlei Chen 2 | # Acknowledge Michael Denkowski for the generous discussion and help 3 | 4 | import os 5 | import subprocess 6 | import threading 7 | import tarfile 8 | # from common.utils.utils import download_from_url 9 | 10 | METEOR_GZ_URL = 'http://aimagelab.ing.unimore.it/speaksee/data/meteor.tgz' 11 | METEOR_JAR = 'meteor-1.5.jar' 12 | 13 | class Meteor: 14 | def __init__(self): 15 | base_path = os.path.dirname(os.path.abspath(__file__)) 16 | jar_path = os.path.join(base_path, METEOR_JAR) 17 | gz_path = os.path.join(base_path, os.path.basename(METEOR_GZ_URL)) 18 | # if not os.path.isfile(jar_path): 19 | # if not os.path.isfile(gz_path): 20 | # download_from_url(METEOR_GZ_URL, gz_path) 21 | # tar = tarfile.open(gz_path, "r") 22 | # tar.extractall(path=os.path.dirname(os.path.abspath(__file__))) 23 | # tar.close() 24 | # os.remove(gz_path) 25 | 26 | self.meteor_cmd = ['java', '-jar', '-Xmx2G', METEOR_JAR, \ 27 | '-', '-', '-stdio', '-l', 'en', '-norm'] 28 | self.meteor_p = subprocess.Popen(self.meteor_cmd, \ 29 | cwd=os.path.dirname(os.path.abspath(__file__)), \ 30 | stdin=subprocess.PIPE, \ 31 | stdout=subprocess.PIPE, \ 32 | stderr=subprocess.PIPE) 33 | # Used to guarantee 
thread safety 34 | self.lock = threading.Lock() 35 | 36 | def compute_score(self, gts, res): 37 | assert(gts.keys() == res.keys()) 38 | imgIds = gts.keys() 39 | scores = [] 40 | 41 | eval_line = 'EVAL' 42 | self.lock.acquire() 43 | for i in imgIds: 44 | assert(len(res[i]) == 1) 45 | stat = self._stat(res[i][0], gts[i]) 46 | eval_line += ' ||| {}'.format(stat) 47 | 48 | self.meteor_p.stdin.write('{}\n'.format(eval_line).encode()) 49 | self.meteor_p.stdin.flush() 50 | for i in range(0,len(imgIds)): 51 | scores.append(float(self.meteor_p.stdout.readline().strip())) 52 | score = float(self.meteor_p.stdout.readline().strip()) 53 | self.lock.release() 54 | 55 | return score, scores 56 | 57 | def _stat(self, hypothesis_str, reference_list): 58 | # SCORE ||| reference 1 words ||| reference n words ||| hypothesis words 59 | hypothesis_str = hypothesis_str.replace('|||','').replace(' ',' ') 60 | score_line = ' ||| '.join(('SCORE', ' ||| '.join(reference_list), hypothesis_str)) 61 | self.meteor_p.stdin.write('{}\n'.format(score_line).encode()) 62 | self.meteor_p.stdin.flush() 63 | raw = self.meteor_p.stdout.readline().decode().strip() 64 | numbers = [str(int(float(n))) for n in raw.split()] 65 | return ' '.join(numbers) 66 | 67 | def __del__(self): 68 | self.lock.acquire() 69 | self.meteor_p.stdin.close() 70 | self.meteor_p.kill() 71 | self.meteor_p.wait() 72 | self.lock.release() 73 | 74 | def __str__(self): 75 | return 'METEOR' 76 | -------------------------------------------------------------------------------- /common/models/transformer/PolarRPE.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | from torch import nn 4 | 5 | # Our proposed grid-based polar-coordinate relative position encoding 6 | class PolarRPE(nn.Module): 7 | def __init__(self, k=3, h=8, d_k=64, d_r=256, window_size = (9, 9), device='cuda:0'): 8 | super(PolarRPE, self).__init__() 9 | Wh, Ww = window_size 10 | self.h = h 11 | self.d_k = d_k 12 | self.num_seq = Wh * Ww 13 | # num_direction = 4 * k + 1 14 | num_direction = 4 * k 15 | num_distance = math.floor(math.sqrt(Wh*Wh + Ww*Ww)) 16 | 17 | # define a parameter table of relative position 18 | self.relative_table = nn.Embedding(num_direction * num_distance, d_r) 19 | self.projection = nn.Linear(d_r, h * d_k) 20 | 21 | # get pair-wise relative position index for each token inside the window 22 | coords_h, coords_w = torch.arange(Wh), torch.arange(Ww) 23 | coords = torch.stack(torch.meshgrid([coords_h, coords_w]), dim=-1) # Wh, Ww, 2 24 | coords_flatten = coords.view(-1, 2) # Wh*Ww, 2 25 | relative_coords = coords_flatten.unsqueeze(1) - coords_flatten.unsqueeze(0) # Wh*Ww, Wh*Ww, 2 26 | relative_coords = relative_coords.view(-1, 2).float() # N*N, 2 27 | 28 | # relative_distance_pos 29 | norm_relative_distance = torch.norm(relative_coords, dim=-1) 30 | relative_distance_pos = norm_relative_distance.int() # N*N 31 | 32 | # relative_direction_pos 33 | unit_direction_x = torch.cos(torch.arange(num_direction - 1) * math.pi / 2 / k) 34 | unit_direction_y = torch.sin(torch.arange(num_direction - 1) * math.pi / 2 / k) 35 | unit_direction = torch.stack([unit_direction_x, unit_direction_y]) # 2, 4k 36 | 37 | relative_direction = torch.matmul(relative_coords, unit_direction) 38 | relative_direction_pos = torch.argmax(relative_direction, dim=-1) # N*N 39 | # relative_direction_pos = relative_direction_pos.masked_fill(norm_relative_distance == 0, num_direction-1) 40 | 41 | relative_pos = relative_direction_pos * num_distance + relative_distance_pos 42 | # relative_pos =
relative_pos.masked_fill(norm_relative_distance == 0, num_direction * num_distance) 43 | 44 | self.relative_pos = relative_pos.to(device) 45 | 46 | self.init_weights() 47 | 48 | def init_weights(self): 49 | nn.init.uniform_(self.relative_table.weight, b=0.2) 50 | nn.init.xavier_uniform_(self.projection.weight) 51 | nn.init.constant_(self.projection.bias, 0) 52 | 53 | def forward(self, bs): 54 | 55 | relative_emb = self.relative_table(self.relative_pos) 56 | relative_emb = self.projection(relative_emb).view(-1, self.h, self.d_k) # (n*n, h, d_k) 57 | 58 | relative_emb = relative_emb.view(self.num_seq, self.num_seq, self.h, self.d_k).permute(2, 0, 1, 3) 59 | relative_emb = relative_emb.unsqueeze(0).expand(bs, self.h, self.num_seq, self.num_seq, self.d_k) # (b_s, h, n, n, d_k) 60 | 61 | return relative_emb 62 | 63 | if __name__ == '__main__': 64 | rpe = PolarRPE(device='cpu') 65 | rpe(2) 66 | 67 | 68 | 69 | 70 | 71 | 72 | 73 | -------------------------------------------------------------------------------- /common/data/utils.py: -------------------------------------------------------------------------------- 1 | import contextlib, sys 2 | 3 | class DummyFile(object): 4 | def write(self, x): pass 5 | 6 | @contextlib.contextmanager 7 | def nostdout(): 8 | save_stdout = sys.stdout 9 | sys.stdout = DummyFile() 10 | yield 11 | sys.stdout = save_stdout 12 | 13 | def reporthook(t): 14 | """https://github.com/tqdm/tqdm""" 15 | last_b = [0] 16 | 17 | def inner(b=1, bsize=1, tsize=None): 18 | """ 19 | b: int, optionala 20 | Number of blocks just transferred [default: 1]. 21 | bsize: int, optional 22 | Size of each block (in tqdm units) [default: 1]. 23 | tsize: int, optional 24 | Total size (in tqdm units). If [default: None] remains unchanged. 25 | """ 26 | if tsize is not None: 27 | t.total = tsize 28 | t.update((b - last_b[0]) * bsize) 29 | last_b[0] = b 30 | return inner 31 | 32 | def get_tokenizer(tokenizer): 33 | if callable(tokenizer): 34 | return tokenizer 35 | if tokenizer == "spacy": 36 | try: 37 | from spacy.lang.en import English 38 | spacy_en = English() 39 | return lambda s: [tok.text for tok in spacy_en.tokenizer(s)] 40 | except ImportError: 41 | print("Please install SpaCy and the SpaCy English tokenizer. " 42 | "See the docs at https://spacy.io for more information.") 43 | raise 44 | except AttributeError: 45 | print("Please install SpaCy and the SpaCy English tokenizer. " 46 | "See the docs at https://spacy.io for more information.") 47 | raise 48 | elif tokenizer == "moses": 49 | try: 50 | from nltk.tokenize.moses import MosesTokenizer 51 | moses_tokenizer = MosesTokenizer() 52 | return moses_tokenizer.tokenize 53 | except ImportError: 54 | print("Please install NLTK. " 55 | "See the docs at http://nltk.org for more information.") 56 | raise 57 | except LookupError: 58 | print("Please install the necessary NLTK corpora. 
" 59 | "See the docs at http://nltk.org for more information.") 60 | raise 61 | elif tokenizer == 'revtok': 62 | try: 63 | import revtok 64 | return revtok.tokenize 65 | except ImportError: 66 | print("Please install revtok.") 67 | raise 68 | elif tokenizer == 'subword': 69 | try: 70 | import revtok 71 | return lambda x: revtok.tokenize(x, decap=True) 72 | except ImportError: 73 | print("Please install revtok.") 74 | raise 75 | raise ValueError("Requested tokenizer {}, valid choices are a " 76 | "callable that takes a single string as input, " 77 | "\"revtok\" for the revtok reversible tokenizer, " 78 | "\"subword\" for the revtok caps-aware tokenizer, " 79 | "\"spacy\" for the SpaCy English tokenizer, or " 80 | "\"moses\" for the NLTK port of the Moses tokenization " 81 | "script.".format(tokenizer)) 82 | -------------------------------------------------------------------------------- /models/transformer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | import copy 4 | 5 | from common.models.captioning_model import CaptioningModel 6 | from common.models.containers import ModuleList 7 | 8 | class Transformer(CaptioningModel): 9 | def __init__(self, bos_idx, encoder, decoder): 10 | super(Transformer, self).__init__() 11 | self.bos_idx = bos_idx 12 | self.encoder = encoder 13 | self.decoder = decoder 14 | self.register_state('grid_features', None) 15 | self.register_state('object_features', None) 16 | self.register_state('mask_enc', None) 17 | self.init_weights() 18 | 19 | @property 20 | def d_model(self): 21 | return self.decoder.d_model 22 | 23 | def init_weights(self): 24 | for p in self.parameters(): 25 | if p.dim() > 1: 26 | nn.init.xavier_uniform_(p) 27 | 28 | def forward(self, images, seq, *args): 29 | if not isinstance(images, tuple) and not isinstance(images, list): 30 | images = [images] 31 | 32 | enc_output = self.encoder(*images) 33 | 34 | if not isinstance(enc_output, tuple) and not isinstance(enc_output, list): 35 | enc_output = [enc_output] 36 | if not isinstance(seq, tuple) and not isinstance(seq, list): 37 | seq = [seq] 38 | 39 | dec_output = self.decoder(*seq, *enc_output) 40 | return dec_output 41 | 42 | def init_state(self, b_s, device): 43 | return [torch.zeros((b_s, 0), dtype=torch.long, device=device), 44 | None, None] 45 | 46 | def step(self, t, prev_output, visual, seq, mode='teacher_forcing', **kwargs): 47 | it = None 48 | if mode == 'teacher_forcing': 49 | raise NotImplementedError 50 | elif mode == 'feedback': 51 | if t == 0: 52 | if not isinstance(visual, tuple) and not isinstance(visual, list): 53 | visual = [visual] 54 | grid_features, region_features = visual 55 | self.grid_features, self.object_features, self.mask_enc = self.encoder(grid_features, region_features) 56 | 57 | it = visual[0].data.new_full((visual[0].shape[0], 1), self.bos_idx).long() # (b_s,1) 58 | else: 59 | it = prev_output 60 | 61 | return self.decoder(it, self.grid_features, self.object_features, self.mask_enc) 62 | 63 | class TransformerEnsemble(CaptioningModel): 64 | def __init__(self, model: Transformer, weight_files, device): 65 | super(TransformerEnsemble, self).__init__() 66 | self.n = len(weight_files) 67 | self.models = ModuleList([copy.deepcopy(model) for _ in range(self.n)]) 68 | for i in range(self.n): 69 | state_dict_i = torch.load(weight_files[i], map_location=device)['state_dict'] 70 | self.models[i].load_state_dict(state_dict_i) 71 | 72 | def step(self, t, prev_output, visual, seq, 
mode='teacher_forcing', **kwargs): 73 | out_ensemble = [] 74 | for i in range(self.n): 75 | out_i = self.models[i].step(t, prev_output, visual, seq, mode, **kwargs) 76 | out_ensemble.append(out_i.unsqueeze(0)) 77 | 78 | return torch.mean(torch.cat(out_ensemble, 0), dim=0) 79 | -------------------------------------------------------------------------------- /common/models/transformer/transformer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | import copy 4 | 5 | from ..captioning_model import CaptioningModel 6 | from ..containers import ModuleList 7 | 8 | 9 | class Transformer(CaptioningModel): 10 | def __init__(self, bos_idx, encoder, decoder): 11 | super(Transformer, self).__init__() 12 | self.bos_idx = bos_idx 13 | self.encoder = encoder 14 | self.decoder = decoder 15 | self.register_state('enc_output', None) 16 | self.register_state('mask_enc', None) 17 | self.init_weights() 18 | 19 | @property 20 | def d_model(self): 21 | return self.decoder.d_model 22 | 23 | def init_weights(self): 24 | for p in self.parameters(): 25 | if p.dim() > 1: 26 | nn.init.xavier_uniform_(p) 27 | 28 | def forward(self, images, seq, *args): 29 | if not isinstance(images, tuple) and not isinstance(images, list): 30 | images = [images] 31 | 32 | enc_output = self.encoder(*images) 33 | 34 | if not isinstance(enc_output, tuple) and not isinstance(enc_output, list): 35 | enc_output = [enc_output] 36 | 37 | dec_output = self.decoder(seq, *enc_output) 38 | return dec_output 39 | 40 | def init_state(self, b_s, device): 41 | return [torch.zeros((b_s, 0), dtype=torch.long, device=device), 42 | None, None] 43 | 44 | def step(self, t, prev_output, visual, seq, mode='teacher_forcing', **kwargs): 45 | it = None 46 | if mode == 'teacher_forcing': 47 | raise NotImplementedError 48 | elif mode == 'feedback': 49 | if t == 0: 50 | if not isinstance(visual, tuple) and not isinstance(visual, list): 51 | visual = [visual] 52 | # self.enc_output, self.mask_enc = self.encoder(*visual) 53 | enc_output = self.encoder(*visual) 54 | if isinstance(enc_output, tuple) or isinstance(enc_output, list): 55 | self.enc_output, self.mask_enc = enc_output[0], enc_output[1] 56 | else: 57 | self.enc_output, self.mask_enc = enc_output, None 58 | 59 | it = visual[0].data.new_full((visual[0].shape[0], 1), self.bos_idx).long() 60 | else: 61 | it = prev_output 62 | 63 | return self.decoder(it, self.enc_output, self.mask_enc) 64 | 65 | 66 | class TransformerEnsemble(CaptioningModel): 67 | def __init__(self, model: Transformer, weight_files): 68 | super(TransformerEnsemble, self).__init__() 69 | self.n = len(weight_files) 70 | self.models = ModuleList([copy.deepcopy(model) for _ in range(self.n)]) 71 | for i in range(self.n): 72 | state_dict_i = torch.load(weight_files[i])['state_dict'] 73 | self.models[i].load_state_dict(state_dict_i) 74 | 75 | def step(self, t, prev_output, visual, seq, mode='teacher_forcing', **kwargs): 76 | out_ensemble = [] 77 | for i in range(self.n): 78 | out_i = self.models[i].step(t, prev_output, visual, seq, mode, **kwargs) 79 | out_ensemble.append(out_i.unsqueeze(0)) 80 | 81 | return torch.mean(torch.cat(out_ensemble, 0), dim=0) 82 | -------------------------------------------------------------------------------- /common/models/transformer/encoders.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from .attention import MultiHeadAttention 4 | from .utils import 
PositionWiseFeedForward 5 | 6 | 7 | class EncoderLayer(nn.Module): 8 | def __init__(self, d_model=512, d_k=64, d_v=64, h=8, d_ff=2048, dropout=.1, identity_map_reordering=False, 9 | attention_module=None, attention_module_kwargs=None): 10 | super(EncoderLayer, self).__init__() 11 | self.identity_map_reordering = identity_map_reordering 12 | self.mhatt = MultiHeadAttention(d_model, d_k, d_v, h, dropout, identity_map_reordering=identity_map_reordering, 13 | attention_module=attention_module, 14 | attention_module_kwargs=attention_module_kwargs) 15 | self.pwff = PositionWiseFeedForward(d_model, d_ff, dropout, identity_map_reordering=identity_map_reordering) 16 | 17 | def forward(self, queries, keys, values, attention_mask=None, attention_weights=None, pos=None): 18 | if pos is not None: 19 | queries = queries + pos 20 | keys = keys + pos 21 | att = self.mhatt(queries, keys, values, attention_mask, attention_weights=attention_weights) 22 | ff = self.pwff(att) 23 | return ff 24 | 25 | 26 | class TransformerEncoder(nn.Module): 27 | def __init__(self, N, padding_idx = None, d_in=2048, d_model=512, d_k=64, d_v=64, h=8, d_ff=2048, dropout=.1, multi_level = False, 28 | identity_map_reordering=False, attention_module=None, attention_module_kwargs=None): 29 | super(TransformerEncoder, self).__init__() 30 | self.d_model = d_model 31 | self.dropout = dropout 32 | self.padding_idx = padding_idx 33 | self.multi_level = multi_level 34 | 35 | self.in_proj_model = nn.Sequential( 36 | nn.Linear(d_in, self.d_model), 37 | nn.ReLU(), 38 | nn.Dropout(p=self.dropout), 39 | nn.LayerNorm(self.d_model) 40 | ) 41 | 42 | self.layers = nn.ModuleList([EncoderLayer(d_model, d_k, d_v, h, d_ff, dropout, 43 | identity_map_reordering=identity_map_reordering, 44 | attention_module=attention_module, 45 | attention_module_kwargs=attention_module_kwargs) 46 | for _ in range(N)]) 47 | 48 | def forward(self, input, attention_weights=None, pos = None): 49 | # input (b_s, seq_len, d_in) 50 | attention_mask = None 51 | if self.padding_idx is not None: 52 | # (b_s, 1, 1, seq_len) 53 | attention_mask = (torch.sum(input, -1) == self.padding_idx).unsqueeze(1).unsqueeze(1) 54 | 55 | out = self.in_proj_model(input) 56 | 57 | if self.multi_level: 58 | outs = [] 59 | for l in self.layers: 60 | out = l(out, out, out, attention_mask, attention_weights) 61 | outs.append(out.unsqueeze(1)) 62 | 63 | outs = torch.cat(outs, 1) 64 | return outs, attention_mask 65 | 66 | else: 67 | for l in self.layers: 68 | out = l(out, out, out, attention_mask, attention_weights, pos) 69 | 70 | return out, attention_mask -------------------------------------------------------------------------------- /tools/transform_vinvl_feature.py: -------------------------------------------------------------------------------- 1 | import h5py 2 | import os 3 | import json 4 | import base64 5 | import numpy as np 6 | from tqdm import tqdm 7 | from common.utils.tsv_file import TSVFile 8 | 9 | def extract_feature(detections_path): 10 | save_path = os.path.join('coco/features', 'COCO2014_VinVL_TEST.hdf5') 11 | f = h5py.File(save_path, mode='w') 12 | 13 | file_feat = TSVFile(os.path.join(detections_path, 'features.tsv')) 14 | file_label = TSVFile(os.path.join(detections_path, 'predictions.tsv')) 15 | 16 | for i in tqdm(range(file_feat.num_rows())): 17 | row_feat = file_feat.seek(i) 18 | row_label = file_label.seek(i) 19 | 20 | assert row_feat[0] == row_label[0] 21 | image_id = int(row_feat[0].split('_')[-1]) 22 | 23 | feature = np.frombuffer(base64.b64decode(row_feat[2]), np.float32) 
24 | feature = feature.reshape((int(row_feat[1]), -1))[:,:-6] 25 | 26 | row_label = json.loads(row_label[1]) 27 | objects = row_label['objects'] 28 | size = np.array([row_label['image_h'], row_label['image_w']]) 29 | boxes = np.array([]) 30 | cls_prob = np.array([]) 31 | if len(objects): 32 | boxes = np.stack([np.array(l['rect']) for l in objects]) 33 | cls_prob = np.stack([np.array(l['conf']) for l in objects]) 34 | 35 | f.create_dataset('%s_features' % image_id, data=feature) 36 | f.create_dataset('%s_boxes' % image_id, data=boxes) 37 | f.create_dataset('%s_size' % image_id, data=size) 38 | f.create_dataset('%s_cls_prob' % image_id, data=cls_prob) 39 | 40 | f.close() 41 | 42 | def extract_label(detections_path): 43 | res = dict() 44 | cnt = 0 45 | 46 | bi_words = [] 47 | for split in ('train', 'val', 'test'): 48 | file_label = TSVFile(os.path.join(detections_path, '%s.label.tsv' % split)) 49 | 50 | for i in tqdm(range(file_label.num_rows())): 51 | row_label = file_label.seek(i) 52 | image_id = row_label[0] 53 | 54 | labels = json.loads(row_label[1]) 55 | 56 | arr = [] 57 | s = set() 58 | for l in labels: 59 | x = l['class'] 60 | 61 | words = x.split() 62 | if 'and' in words: 63 | words.remove('and') 64 | if '&' in words: 65 | words.remove('&') 66 | if 'Human' in words: 67 | continue 68 | 69 | for word in words: 70 | if word not in s: 71 | s.add(word) 72 | arr.append(word) 73 | 74 | if len(x.split()) > 1: 75 | bi_words.append(x) 76 | 77 | if len(arr) == 0: 78 | cnt += 1 79 | print('image_id: %s' % image_id) 80 | 81 | line = ' '.join(arr) 82 | res[image_id] = line 83 | 84 | with open('coco/features/COCO2014_VinVL_labels.json','w') as fp: 85 | json.dump(res,fp) 86 | 87 | print('nums of Nill:%s' % cnt) 88 | print(set(bi_words)) 89 | 90 | if __name__=='__main__': 91 | # Please go to https://github.com/pzzhang/VinVL/blob/main/DOWNLOAD.md to download the original VinVL features 92 | 93 | detections_path = 'Folder path of the downloaded VinVL features' 94 | # extract_label(detections_path) 95 | extract_feature(detections_path) -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Cross on Cross Attention: Deep Fusion Transformer for Image Captioning 2 | 3 | ## Environment setup 4 | Clone the repository and create the `dft` conda environment using the `environment.yml` file: 5 | ``` 6 | conda env create -f environment.yml 7 | conda activate dft 8 | ``` 9 | 10 | Then download spacy data by executing the following command: 11 | ``` 12 | python -m spacy download en 13 | ``` 14 | 15 | Note: Python 3.8 is required to run our code. 16 | 17 | 18 | ## Data preparation 19 | To run the code, annotations and visual features for the COCO dataset are needed. Please download the annotations file [annotations.zip](https://pan.baidu.com/s/17ik-2OZGFaQ5-AzCCWkL9w) (Extraction code: ska0) and extract it. 20 | 21 | To reproduce our result, please generate the corresponding feature files (`COCO2014_RN50x4_GLOBAL.hdf5`, `COCO2014_VinVL.hdf5`) using the code in the tools folder, in which features of each image are stored under the `_features` key. `` is the id of each COCO image, without leading zeros (e.g. the `` for `COCO_val2014_000000037209.jpg` is `37209`). VinVL region feature dimension is (N, 2048), N is the number of region features; CLIP grid feature dimension is (M, 2560), M is the number of grid features. 
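As a quick sanity check of the generated feature files, the snippet below (illustrative only, not a script shipped with this repo) reads the features of a single image with `h5py`; the paths follow the defaults used in `train.py`, and the key names follow `common/visualization.py` and the scripts in `tools/`.
```
import h5py

image_id = 37209  # e.g. COCO_val2014_000000037209.jpg, as in the example above

with h5py.File('coco/features/COCO2014_RN50x4_GLOBAL.hdf5', 'r') as clip_f, \
     h5py.File('coco/features/COCO2014_VinVL.hdf5', 'r') as vinvl_f:
    grid = clip_f['%d_features' % image_id][()]     # CLIP grid features, expected shape (M, 2560)
    region = vinvl_f['%d_features' % image_id][()]  # VinVL region features, expected shape (N, 2048)
    boxes = vinvl_f['%d_boxes' % image_id][()]      # region boxes written by tools/transform_vinvl_feature.py
    print(grid.shape, region.shape, boxes.shape)
```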
22 | 23 | 24 | ## Evaluation 25 | To reproduce the results reported in our paper, download the pretrained model file [dft.pth](https://pan.baidu.com/s/17ik-2OZGFaQ5-AzCCWkL9w) (Extraction code: ska0) and place it in the code folder. 26 | 27 | 28 | ## Training procedure 29 | Run `python train.py` using the following arguments: 30 | 31 | | Argument | Possible values | 32 | |------|------| 33 | | `--output` | Output path | 34 | | `--exp_name` | Experiment name | 35 | | `--batch_size` | Batch size (default: 20) | 36 | | `--workers` | Number of workers (default: 8) | 37 | | `--warmup` | Warmup value for learning rate scheduling (default: 10000) | 38 | | `--N_enc` | Number of encoder layers | 39 | | `--N_dec` | Number of decoder layers | 40 | | `--resume_last` | If used, the training will be resumed from the last checkpoint. | 41 | | `--resume_best` | If used, the training will be resumed from the best checkpoint. | 42 | | `--use_rl` | Whether to turn on reinforcement learning | 43 | | `--clip_path` | CLIP grid feature path | 44 | | `--vinvl_path` | VinVL region feature path | 45 | | `--features_path` | Path to detection features file | 46 | | `--annotation_folder` | Path to folder with COCO annotations | 47 | 48 | For example, to train our model with the parameters used in our experiments, use 49 | ``` 50 | python train.py --exp_name dft --batch_size 20 --clip_path /path/to/clip_grid_features --vinvl_path /path/to/vinvl_region_features --annotation_folder /path/to/annotations 51 | ``` 52 | 53 | #### References 54 | [1] Cornia M, Stefanini M, Baraldi L, et al. Meshed-memory transformer for image captioning. In Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition. 55 | [2] Radford A, Kim J W, Hallacy C, et al. Learning transferable visual models from natural language supervision. In International Conference on Machine Learning. 56 | [3] Zhang P, Li X, Hu X, et al. VinVL: Revisiting visual representations in vision-language models. In Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition. 57 | 58 | #### Acknowledgements 59 | Thanks to Cornia _et al._ for their open-source code ([meshed-memory-transformer 60 | ](https://github.com/aimagelab/meshed-memory-transformer)), on which our implementation is based. 61 | Thanks to Zhang et al. for the powerful region features ([VinVL](https://github.com/pzzhang/VinVL)).
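As a complement to the Evaluation section above, here is a minimal sketch of loading the pretrained checkpoint for inference, modeled on `common/visualization.py`; the checkpoint filename `dft.pth` and the special-token strings `<bos>`, `<eos>`, `<pad>` are assumptions, not guaranteed by this README.
```
import pickle
import torch
from common.data.field import TextField
from models import build_encoder, build_decoder, Transformer

device = 'cuda:0'

# Text pipeline and cached vocabulary (special-token strings assumed)
text_field = TextField(init_token='<bos>', eos_token='<eos>', lower=True, tokenize='spacy',
                       remove_punctuation=True, nopoints=False)
text_field.vocab = pickle.load(open('cache/vocab.pkl', 'rb'))

# Same model construction as in common/visualization.py
encoder = build_encoder(3, device=device)
decoder = build_decoder(len(text_field.vocab), 54, 3, text_field.vocab.stoi['<pad>'])
model = Transformer(text_field.vocab.stoi['<bos>'], encoder, decoder).to(device)

# Load the pretrained weights (the checkpoint dict stores them under 'state_dict')
data = torch.load('dft.pth', map_location=device)
model.load_state_dict(data['state_dict'])
model.eval()
```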
62 | -------------------------------------------------------------------------------- /models/encoder_gce.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from torch.nn import functional as F 4 | from common.models.transformer.attention import MultiHeadAttention, NormSelfAttention 5 | from common.models.transformer.utils import PolarRPE, PositionWiseFeedForward, RelationalEmbedding 6 | 7 | 8 | class EncoderLayer(nn.Module): 9 | def __init__(self, d_model=512, d_k=64, d_v=64, h=8, d_ff=2048, dropout=.1): 10 | super(EncoderLayer, self).__init__() 11 | 12 | self.self_grid = MultiHeadAttention(d_model, d_k, d_v, h, dropout) 13 | 14 | self.self_region = MultiHeadAttention(d_model, d_k, d_v, h, dropout) 15 | self.global_grid = MultiHeadAttention(d_model, d_k, d_v, h, dropout, shortcut=False) 16 | self.global_region = MultiHeadAttention(d_model, d_k, d_v, h, dropout, shortcut=False) 17 | 18 | self.cls_grid = nn.Parameter(torch.randn(1, 1, d_model), requires_grad=True) 19 | self.cls_region = nn.Parameter(torch.randn(1, 1, d_model), requires_grad=True) 20 | 21 | self.pwff_grid = PositionWiseFeedForward(d_model, d_ff, dropout) 22 | self.pwff_region = PositionWiseFeedForward(d_model, d_ff, dropout) 23 | 24 | def forward(self, gird_features, region_features, attention_mask): 25 | b_s = region_features.shape[0] 26 | cls_grid = self.cls_grid.expand(b_s, 1, -1) 27 | cls_region = self.cls_region.expand(b_s, 1, -1) 28 | 29 | cls_grid = self.global_grid(cls_grid, gird_features, gird_features) 30 | cls_region = self.global_region(cls_region, region_features, region_features, attention_mask=attention_mask) 31 | 32 | gird_features = torch.cat([cls_region, gird_features], dim=1) 33 | region_features = torch.cat([cls_grid, region_features], dim=1) 34 | 35 | add_mask = torch.zeros(b_s, 1, 1, 1).bool().to(region_features.device) 36 | attention_mask = torch.cat([add_mask, attention_mask], dim=-1) 37 | grid_att = self.self_grid(gird_features, gird_features, gird_features) 38 | region_att = self.self_region(region_features, region_features, region_features, attention_mask=attention_mask) 39 | 40 | gird_ff = self.pwff_grid(grid_att) 41 | region_ff = self.pwff_region(region_att) 42 | 43 | gird_ff = gird_ff[:,1:] 44 | region_ff = region_ff[:,1:] 45 | 46 | return gird_ff, region_ff 47 | 48 | class TransformerEncoder(nn.Module): 49 | def __init__(self, N, device='cuda', d_model=512, d_k=64, d_v=64, h=8, d_ff=2048, dropout=.1): 50 | super(TransformerEncoder, self).__init__() 51 | self.d_model = d_model 52 | self.dropout = dropout 53 | self.device = device 54 | 55 | self.grid_proj = nn.Sequential( 56 | nn.Linear(2560, self.d_model), 57 | nn.ReLU(), 58 | nn.Dropout(p=self.dropout), 59 | nn.LayerNorm(self.d_model) 60 | ) 61 | 62 | self.region_proj = nn.Sequential( 63 | nn.Linear(2048, self.d_model), 64 | nn.ReLU(), 65 | nn.Dropout(p=self.dropout), 66 | nn.LayerNorm(self.d_model) 67 | ) 68 | 69 | self.layers = nn.ModuleList([EncoderLayer(d_model, d_k, d_v, h, d_ff, dropout) for _ in range(N)]) 70 | 71 | def forward(self, grid_features, region_features): 72 | # input (b_s, seq_len) 73 | b_s = region_features.shape[0] 74 | attention_mask = (torch.sum(torch.abs(region_features), -1) == 0).unsqueeze(1).unsqueeze(1) 75 | grid_features = self.grid_proj(grid_features) 76 | region_features = self.region_proj(region_features) 77 | 78 | for l in self.layers: 79 | grid_features, region_features = l(grid_features, region_features, attention_mask) 80 | 81 | 
return grid_features, region_features, attention_mask 82 | 83 | def build_encoder(N, device='cuda', d_model=512, d_k=64, d_v=64, h=8, d_ff=2048, dropout=.1): 84 | Encoder = TransformerEncoder(N, device, d_model, d_k, d_v, h, d_ff, dropout) 85 | 86 | return Encoder -------------------------------------------------------------------------------- /common/evaluation/rouge/rouge.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # 3 | # File Name : rouge.py 4 | # 5 | # Description : Computes ROUGE-L metric as described by Lin and Hovey (2004) 6 | # 7 | # Creation Date : 2015-01-07 06:03 8 | # Author : Ramakrishna Vedantam 9 | 10 | import numpy as np 11 | import pdb 12 | 13 | 14 | def my_lcs(string, sub): 15 | """ 16 | Calculates longest common subsequence for a pair of tokenized strings 17 | :param string : list of str : tokens from a string split using whitespace 18 | :param sub : list of str : shorter string, also split using whitespace 19 | :returns: length (list of int): length of the longest common subsequence between the two strings 20 | 21 | Note: my_lcs only gives length of the longest common subsequence, not the actual LCS 22 | """ 23 | if (len(string) < len(sub)): 24 | sub, string = string, sub 25 | 26 | lengths = [[0 for i in range(0, len(sub) + 1)] for j in range(0, len(string) + 1)] 27 | 28 | for j in range(1, len(sub) + 1): 29 | for i in range(1, len(string) + 1): 30 | if (string[i - 1] == sub[j - 1]): 31 | lengths[i][j] = lengths[i - 1][j - 1] + 1 32 | else: 33 | lengths[i][j] = max(lengths[i - 1][j], lengths[i][j - 1]) 34 | 35 | return lengths[len(string)][len(sub)] 36 | 37 | 38 | class Rouge(): 39 | ''' 40 | Class for computing ROUGE-L score for a set of candidate sentences for the MS COCO test set 41 | 42 | ''' 43 | 44 | def __init__(self): 45 | # vrama91: updated the value below based on discussion with Hovey 46 | self.beta = 1.2 47 | 48 | def calc_score(self, candidate, refs): 49 | """ 50 | Compute ROUGE-L score given one candidate and references for an image 51 | :param candidate: str : candidate sentence to be evaluated 52 | :param refs: list of str : COCO reference sentences for the particular image to be evaluated 53 | :returns score: int (ROUGE-L score for the candidate evaluated against references) 54 | """ 55 | assert (len(candidate) == 1) 56 | assert (len(refs) > 0) 57 | prec = [] 58 | rec = [] 59 | 60 | # split into tokens 61 | token_c = candidate[0].split(" ") 62 | 63 | for reference in refs: 64 | # split into tokens 65 | token_r = reference.split(" ") 66 | # compute the longest common subsequence 67 | lcs = my_lcs(token_r, token_c) 68 | prec.append(lcs / float(len(token_c))) 69 | rec.append(lcs / float(len(token_r))) 70 | 71 | prec_max = max(prec) 72 | rec_max = max(rec) 73 | 74 | if (prec_max != 0 and rec_max != 0): 75 | score = ((1 + self.beta ** 2) * prec_max * rec_max) / float(rec_max + self.beta ** 2 * prec_max) 76 | else: 77 | score = 0.0 78 | return score 79 | 80 | def compute_score(self, gts, res): 81 | """ 82 | Computes Rouge-L score given a set of reference and candidate sentences for the dataset 83 | Invoked by evaluate_captions.py 84 | :param hypo_for_image: dict : candidate / test sentences with "image name" key and "tokenized sentences" as values 85 | :param ref_for_image: dict : reference MS-COCO sentences with "image name" key and "tokenized sentences" as values 86 | :returns: average_score: float (mean ROUGE-L score computed by averaging scores for all the images) 87 | """ 88 | assert 
(gts.keys() == res.keys()) 89 | imgIds = gts.keys() 90 | 91 | score = [] 92 | for id in imgIds: 93 | hypo = res[id] 94 | ref = gts[id] 95 | 96 | score.append(self.calc_score(hypo, ref)) 97 | 98 | # Sanity check. 99 | assert (type(hypo) is list) 100 | assert (len(hypo) == 1) 101 | assert (type(ref) is list) 102 | assert (len(ref) > 0) 103 | 104 | average_score = np.mean(np.array(score)) 105 | return average_score, np.array(score) 106 | 107 | def __str__(self): 108 | return 'ROUGE' 109 | -------------------------------------------------------------------------------- /common/online_test.py: -------------------------------------------------------------------------------- 1 | import json 2 | import random 3 | import torch 4 | from tqdm import tqdm 5 | import argparse 6 | import numpy as np 7 | import os, pickle 8 | from common.utils.utils import setup_seed 9 | from common.data import OnlineTestDataset, DataLoader 10 | from common.data.field import RawField, TextField, DualImageField 11 | from models import build_encoder, build_decoder, Transformer, TransformerEnsemble 12 | 13 | setup_seed(1234) 14 | torch.backends.cudnn.benchmark = True 15 | 16 | def online_test(model, dataloader, text_field, device): 17 | import itertools 18 | model.eval() 19 | result = [] 20 | with tqdm(desc='Evaluation', unit='it', total=len(dataloader)) as pbar: 21 | for it, (images, image_ids) in enumerate(iter(dataloader)): 22 | with torch.no_grad(): 23 | if isinstance(images, tuple) or isinstance(images, list): 24 | images = [x.to(device) for x in images] 25 | else: 26 | images = images.to(device) 27 | out, _ = model.beam_search(images, 20, text_field.vocab.stoi[''], 5, out_size=1) 28 | 29 | caps_gen = text_field.decode(out, join_words=False) 30 | for i, (image_id, gen_i) in enumerate(zip(image_ids, caps_gen)): 31 | gen_i = ' '.join([k for k, g in itertools.groupby(gen_i)]) 32 | result.append({"image_id": image_id.item(), "caption": gen_i.strip()}) 33 | pbar.update() 34 | return result 35 | 36 | def test(): 37 | parser = argparse.ArgumentParser(description='Dual Transformer') 38 | parser.add_argument('--batch_size', type=int, default=100) 39 | parser.add_argument('--workers', type=int, default=8) 40 | parser.add_argument('--device', type=str, default='cuda:0') 41 | parser.add_argument('--model_path', type=str, default='coco/checkpoints/DualModel/dual_fuse_global_rl_best2.pth') 42 | parser.add_argument('--feature_type', type=str, default='clip') 43 | parser.add_argument('--features_path', type=str, default='coco/features/COCO2014_RN50x4_GLOBAL.hdf5') 44 | parser.add_argument('--image_folder', type=str, default='coco/images') 45 | parser.add_argument('--annotation_folder', type=str, default='coco/annotations') 46 | args = parser.parse_args() 47 | 48 | print('Dual Transformer Evaluation') 49 | 50 | # Pipeline for image regions 51 | # image_field = ImageDetectionsField(feature_type=args.feature_type, detections_path=args.features_path, max_detections=50) 52 | image_field = DualImageField(max_detections=50, global_feature=False, online_test=True) 53 | 54 | # Pipeline for text 55 | text_field = TextField(init_token='', eos_token='', lower=True, tokenize='spacy', 56 | remove_punctuation=True, nopoints=False) 57 | vocab_path = 'cache/vocab.pkl' 58 | text_field.vocab = pickle.load(open(vocab_path, 'rb')) 59 | 60 | ann_path = 'coco/annotations/image_info_test2014.json' 61 | # ann_path = 'coco/annotations/captions_val2014.json' 62 | dataset_test = OnlineTestDataset(ann_path, {'image': image_field, 'image_id': RawField()}) 63 | 
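    # The loader below yields (images, image_ids) batches; online_test() above decodes each
    # batch with beam search and collects records of the form
    #   {"image_id": 391895, "caption": "a man riding a motorcycle"}   # illustrative values only
    # which are written out as a single JSON list in the COCO online-evaluation submission format.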
dict_dataloader_test = DataLoader(dataset_test, batch_size=args.batch_size, num_workers=args.workers, pin_memory=True, drop_last=False) 64 | 65 | # Model and dataloaders 66 | encoder = build_encoder(3, device=args.device) 67 | decoder = build_decoder(len(text_field.vocab), 54, 3, text_field.vocab.stoi['']) 68 | 69 | model = Transformer(text_field.vocab.stoi[''], encoder, decoder) 70 | weight_files = ['coco/checkpoints/DualModel/dual_fuse_global_rl_best2.pth', 71 | 'coco/checkpoints/DualModel/dual_fuse_global2_rl_best.pth', 72 | 'coco/checkpoints/DualModel/dual_fuse_global3_rl_best.pth', 73 | 'coco/checkpoints/DualModel/dual_fuse_global5_rl_best.pth'] 74 | model = TransformerEnsemble(model, weight_files, args.device).to(args.device) 75 | 76 | result = online_test(model, dict_dataloader_test, text_field, args.device) 77 | 78 | output_path = 'outputs/result/captions_test2014_DFT2_results.json' 79 | # output_path = 'outputs/result/captions_val2014_DFT2_results.json' 80 | with open(output_path, 'w') as fp: 81 | json.dump(result, fp) 82 | 83 | print("Online test result is over") -------------------------------------------------------------------------------- /common/evaluation/spice/spice.py: -------------------------------------------------------------------------------- 1 | from __future__ import division 2 | import os 3 | import subprocess 4 | import json 5 | import numpy as np 6 | import tempfile 7 | import tarfile 8 | # from common.utils.utils import download_from_url 9 | 10 | # Assumes spice.jar is in the same directory as spice.py. Change as needed. 11 | SPICE_GZ_URL = 'http://aimagelab.ing.unimore.it/speaksee/data/spice.tgz' 12 | SPICE_JAR = 'spice-1.0.jar' 13 | TEMP_DIR = 'tmp' 14 | CACHE_DIR = 'cache' 15 | 16 | 17 | class Spice: 18 | """ 19 | Main Class to compute the SPICE metric 20 | """ 21 | 22 | def __init__(self): 23 | base_path = os.path.dirname(os.path.abspath(__file__)) 24 | jar_path = os.path.join(base_path, SPICE_JAR) 25 | gz_path = os.path.join(base_path, os.path.basename(SPICE_GZ_URL)) 26 | # if not os.path.isfile(jar_path): 27 | # if not os.path.isfile(gz_path): 28 | # download_from_url(SPICE_GZ_URL, gz_path) 29 | # tar = tarfile.open(gz_path, "r") 30 | # tar.extractall(path=os.path.dirname(os.path.abspath(__file__))) 31 | # tar.close() 32 | # os.remove(gz_path) 33 | 34 | 35 | def float_convert(self, obj): 36 | try: 37 | return float(obj) 38 | except: 39 | return np.nan 40 | 41 | def compute_score(self, gts, res): 42 | assert (sorted(gts.keys()) == sorted(res.keys())) 43 | imgIds = sorted(gts.keys()) 44 | 45 | # Prepare temp input file for the SPICE scorer 46 | input_data = [] 47 | for id in imgIds: 48 | hypo = res[id] 49 | ref = gts[id] 50 | 51 | # Sanity check. 
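            # Each record handed to the SPICE jar pairs one candidate caption with all of its
            # references; an illustrative (made-up) example of the shape built just below:
            #   {"image_id": 42, "test": "a dog runs in the grass",
            #    "refs": ["a dog running on grass", "a brown dog in a field"]}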
52 | assert (type(hypo) is list) 53 | assert (len(hypo) == 1) 54 | assert (type(ref) is list) 55 | assert (len(ref) >= 1) 56 | 57 | input_data.append({ 58 | "image_id": id, 59 | "test": hypo[0], 60 | "refs": ref 61 | }) 62 | 63 | cwd = os.path.dirname(os.path.abspath(__file__)) 64 | temp_dir = os.path.join(cwd, TEMP_DIR) 65 | if not os.path.exists(temp_dir): 66 | os.makedirs(temp_dir) 67 | in_file = tempfile.NamedTemporaryFile('w+', delete=False, dir=temp_dir, encoding='utf8') 68 | json.dump(input_data, in_file, indent=2) 69 | in_file.close() 70 | 71 | # Start job 72 | out_file = tempfile.NamedTemporaryFile('w+', delete=False, dir=temp_dir, encoding='utf8') 73 | out_file.close() 74 | cache_dir = os.path.join(cwd, CACHE_DIR) 75 | if not os.path.exists(cache_dir): 76 | os.makedirs(cache_dir) 77 | spice_cmd = ['java', '-jar', '-Xmx8G', SPICE_JAR, in_file.name, 78 | '-cache', cache_dir, 79 | '-out', out_file.name, 80 | '-subset', 81 | '-silent' 82 | ] 83 | 84 | try: 85 | from subprocess import DEVNULL # Python 3. 86 | except ImportError: 87 | DEVNULL = open(os.devnull, 'wb') 88 | subprocess.check_call(spice_cmd, 89 | cwd=os.path.dirname(os.path.abspath(__file__)), 90 | stdout=DEVNULL, stderr=DEVNULL) 91 | 92 | # Read and process results 93 | with open(out_file.name) as data_file: 94 | results = json.load(data_file) 95 | os.remove(in_file.name) 96 | os.remove(out_file.name) 97 | 98 | imgId_to_scores = {} 99 | spice_scores = [] 100 | for item in results: 101 | imgId_to_scores[item['image_id']] = item['scores'] 102 | spice_scores.append(self.float_convert(item['scores']['All']['f'])) 103 | average_score = np.mean(np.array(spice_scores)) 104 | scores = [] 105 | for image_id in imgIds: 106 | # Convert none to NaN before saving scores over subcategories 107 | score_set = {} 108 | for category, score_tuple in imgId_to_scores[image_id].items(): 109 | score_set[category] = {k: self.float_convert(v) for k, v in score_tuple.items()} 110 | scores.append(score_set) 111 | return average_score, scores 112 | 113 | def __str__(self): 114 | return 'SPICE' -------------------------------------------------------------------------------- /common/test.py: -------------------------------------------------------------------------------- 1 | import random 2 | import torch 3 | from tqdm import tqdm 4 | import argparse 5 | import numpy as np 6 | import common.evaluation as evaluation 7 | from common.utils.utils import create_dataset, setup_seed 8 | from common.data import DataLoader 9 | from common.data.field import RawField, TextField, DualImageField 10 | from models import build_encoder, build_decoder, Transformer, TransformerEnsemble 11 | 12 | setup_seed(1234) 13 | torch.backends.cudnn.benchmark = True 14 | 15 | def predict_captions(model, dataloader, text_field, device): 16 | import itertools 17 | model.eval() 18 | gen = {} 19 | gts = {} 20 | with tqdm(desc='Evaluation', unit='it', total=len(dataloader)) as pbar: 21 | for it, (images, caps_gt) in enumerate(iter(dataloader)): 22 | with torch.no_grad(): 23 | if isinstance(images, tuple) or isinstance(images, list): 24 | images = [x.to(device) for x in images] 25 | else: 26 | images = images.to(device) 27 | out, _ = model.beam_search(images, 20, text_field.vocab.stoi[''], 5, out_size=1) 28 | 29 | caps_gen = text_field.decode(out, join_words=False) 30 | for i, (gts_i, gen_i) in enumerate(zip(caps_gt, caps_gen)): 31 | gen_i = ' '.join([k for k, g in itertools.groupby(gen_i)]) 32 | gen['%d_%d' % (it, i)] = [gen_i.strip(), ] 33 | gts['%d_%d' % (it, i)] = gts_i 34 | 
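                    # itertools.groupby collapses consecutive duplicate tokens, e.g. ['a', 'a', 'dog'] -> 'a dog';
                    # the '%d_%d' keys only need to match between gen and gts for the scorers to align the pairs.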
pbar.update() 35 | 36 | gts = evaluation.PTBTokenizer.tokenize(gts) 37 | gen = evaluation.PTBTokenizer.tokenize(gen) 38 | scores, _ = evaluation.compute_scores(gts, gen) 39 | return scores 40 | 41 | 42 | def test(): 43 | parser = argparse.ArgumentParser(description='Dual Transformer') 44 | parser.add_argument('--batch_size', type=int, default=100) 45 | parser.add_argument('--workers', type=int, default=8) 46 | parser.add_argument('--device', type=str, default='cuda:0') 47 | parser.add_argument('--model_path', type=str, default='Camel/saved_models/Camel_best.pth') 48 | parser.add_argument('--feature_type', type=str, default='clip') 49 | parser.add_argument('--features_path', type=str, default='coco/features/COCO2014_RN50x4_GLOBAL.hdf5') 50 | parser.add_argument('--image_folder', type=str, default='coco/images') 51 | parser.add_argument('--annotation_folder', type=str, default='coco/annotations') 52 | args = parser.parse_args() 53 | 54 | print('Dual Transformer Evaluation') 55 | 56 | # Pipeline for image regions 57 | image_field = DualImageField(max_detections=50, global_feature=False) 58 | 59 | # Pipeline for text 60 | text_field = TextField(init_token='', eos_token='', lower=True, tokenize='spacy', 61 | remove_punctuation=True, nopoints=False) 62 | 63 | # Create the dataset 64 | datasets = create_dataset(args, image_field, text_field) 65 | _, val_dataset, test_dataset = datasets 66 | 67 | dict_dataset_val = val_dataset.image_dictionary({'image': image_field, 'text': RawField()}) 68 | dict_dataset_test = test_dataset.image_dictionary({'image': image_field, 'text': RawField()}) 69 | dict_dataloader_val = DataLoader(dict_dataset_val, batch_size=args.batch_size // 5, num_workers=args.workers, pin_memory=True, drop_last=False) 70 | dict_dataloader_test = DataLoader(dict_dataset_test, batch_size=args.batch_size // 5, num_workers=args.workers, pin_memory=True, drop_last=False) 71 | 72 | # Model and dataloaders 73 | encoder = build_encoder(3, device=args.device) 74 | decoder = build_decoder(len(text_field.vocab), 54, 3, text_field.vocab.stoi['']) 75 | 76 | model = Transformer(text_field.vocab.stoi[''], encoder, decoder).to(args.device) 77 | 78 | # weight_files = ['coco/checkpoints/DualModel/dual_fuse_global_rl_best2.pth', 79 | # 'coco/checkpoints/DualModel/dual_fuse_global2_rl_best.pth', 80 | # 'coco/checkpoints/DualModel/dual_fuse_global3_rl_last.pth', 81 | # 'coco/checkpoints/DualModel/dual_fuse_global5_rl_best.pth'] 82 | 83 | weight_files = ['coco/checkpoints/DualModel/dual_fuse_global_best.pth', 84 | 'coco/checkpoints/DualModel/dual_fuse_global2_best.pth', 85 | 'coco/checkpoints/DualModel/dual_fuse_global3_last.pth', 86 | 'coco/checkpoints/DualModel/dual_fuse_global5_best.pth'] 87 | 88 | model = TransformerEnsemble(model, weight_files, args.device).to(args.device) 89 | 90 | # data = torch.load(args.model_path, map_location=args.device) 91 | # model.load_state_dict(data['state_dict_t']) 92 | # print(data['best_cider']) 93 | 94 | scores = predict_captions(model, dict_dataloader_val, text_field, args.device) 95 | print("Validation scores", scores) 96 | scores = predict_captions(model, dict_dataloader_test, text_field, args.device) 97 | print("Test scores", scores) -------------------------------------------------------------------------------- /common/models/transformer/decoders.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from torch.nn import functional as F 4 | import numpy as np 5 | from .utils import 
PositionWiseFeedForward, sinusoid_encoding_table 6 | from .attention import MultiHeadAttention 7 | from ..containers import Module, ModuleList 8 | 9 | 10 | class DecoderLayer(Module): 11 | def __init__(self, d_model=512, d_k=64, d_v=64, h=8, d_ff=2048, dropout=.1, self_att_module=None, 12 | enc_att_module=None, self_att_module_kwargs=None, enc_att_module_kwargs=None): 13 | super(DecoderLayer, self).__init__() 14 | self.self_att = MultiHeadAttention(d_model, d_k, d_v, h, dropout, can_be_stateful=True, 15 | attention_module=self_att_module, 16 | attention_module_kwargs=self_att_module_kwargs) 17 | self.enc_att = MultiHeadAttention(d_model, d_k, d_v, h, dropout, can_be_stateful=False, 18 | attention_module=enc_att_module, 19 | attention_module_kwargs=enc_att_module_kwargs) 20 | self.pwff = PositionWiseFeedForward(d_model, d_ff, dropout) 21 | 22 | def forward(self, input, enc_output, mask_pad, mask_self_att, mask_enc_att, pos=None): 23 | self_att = self.self_att(input, input, input, mask_self_att) 24 | self_att = self_att * mask_pad 25 | 26 | if pos is not None: 27 | enc_output = enc_output + pos 28 | 29 | enc_att = self.enc_att(self_att, enc_output, enc_output, mask_enc_att) 30 | enc_att = enc_att * mask_pad 31 | 32 | ff = self.pwff(enc_att) 33 | ff = ff * mask_pad 34 | return ff 35 | 36 | 37 | class TransformerDecoder(Module): 38 | def __init__(self, vocab_size, max_len, N_dec, padding_idx, d_model=512, d_k=64, d_v=64, h=8, d_ff=2048, dropout=.1, return_logits = False, 39 | enc_dim=None, self_att_module=None, enc_att_module=None, self_att_module_kwargs=None, enc_att_module_kwargs=None): 40 | super(TransformerDecoder, self).__init__() 41 | self.d_model = d_model 42 | self.pro_flag = False 43 | if enc_dim is not None and enc_dim != d_model: 44 | self.in_proj_model = nn.Sequential( 45 | nn.Linear(enc_dim, self.d_model), 46 | nn.ReLU(), 47 | nn.Dropout(p=dropout), 48 | nn.LayerNorm(self.d_model) 49 | ) 50 | self.pro_flag = True 51 | 52 | self.word_emb = nn.Embedding(vocab_size, d_model, padding_idx=padding_idx) 53 | self.pos_emb = nn.Embedding.from_pretrained(sinusoid_encoding_table(max_len + 1, d_model, 0), freeze=True) 54 | self.layers = ModuleList( 55 | [DecoderLayer(d_model, d_k, d_v, h, d_ff, dropout, self_att_module=self_att_module, 56 | enc_att_module=enc_att_module, self_att_module_kwargs=self_att_module_kwargs, 57 | enc_att_module_kwargs=enc_att_module_kwargs) for _ in range(N_dec)]) 58 | self.fc = nn.Linear(d_model, vocab_size, bias=False) 59 | self.max_len = max_len 60 | self.padding_idx = padding_idx 61 | self.N = N_dec 62 | self.return_logits = return_logits 63 | 64 | self.register_state('running_mask_self_attention', torch.zeros((1, 1, 0)).byte()) 65 | self.register_state('running_seq', torch.zeros((1,)).long()) 66 | 67 | def forward(self, input, encoder_output, mask_encoder=None): 68 | if self.pro_flag: 69 | encoder_output = self.in_proj_model(encoder_output) 70 | 71 | # input (b_s, seq_len) 72 | b_s, seq_len = input.shape[:2] 73 | mask_queries = (input != self.padding_idx).unsqueeze(-1) # (b_s, seq_len, 1) 74 | mask_self_attention = torch.triu(torch.ones((seq_len, seq_len), dtype=torch.uint8, device=input.device), 75 | diagonal=1) 76 | mask_self_attention = mask_self_attention.unsqueeze(0).unsqueeze(0) # (1, 1, seq_len, seq_len) 77 | mask_self_attention = mask_self_attention + (input == self.padding_idx).unsqueeze(1).unsqueeze(1).byte() 78 | mask_self_attention = mask_self_attention.gt(0) # (b_s, 1, seq_len, seq_len) 79 | 80 | seq = torch.arange(1, seq_len + 1).view(1, 
-1).expand(b_s, -1).to(input.device) # (b_s, seq_len) 81 | seq = seq.masked_fill(mask_queries.squeeze(-1) == 0, 0) 82 | 83 | if self._is_stateful: 84 | self.running_mask_self_attention = torch.cat([self.running_mask_self_attention, mask_self_attention], -1) 85 | mask_self_attention = self.running_mask_self_attention 86 | self.running_seq.add_(1) 87 | seq = self.running_seq 88 | 89 | out = self.word_emb(input) + self.pos_emb(seq) 90 | for i, l in enumerate(self.layers): 91 | out = l(out, encoder_output, mask_queries, mask_self_attention, mask_encoder) 92 | 93 | out = self.fc(out) 94 | if self.return_logits: 95 | return out 96 | else: 97 | return F.log_softmax(out, dim=-1) -------------------------------------------------------------------------------- /models/decoder_cca.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from torch.nn import functional as F 4 | from common.models.containers import Module, ModuleList 5 | from common.models.transformer.attention import MultiHeadAttention, OSAttention 6 | from common.models.transformer.utils import PositionWiseFeedForward, sinusoid_encoding_table 7 | 8 | class DecoderLayer(Module): 9 | def __init__(self, d_model=512, d_k=64, d_v=64, h=8, d_ff=2048, dropout=.1): 10 | super(DecoderLayer, self).__init__() 11 | self.self_att = MultiHeadAttention(d_model, d_k, d_v, h, dropout, can_be_stateful=True) 12 | self.grid_att = MultiHeadAttention(d_model, d_k, d_v, h, dropout) 13 | self.region_att = MultiHeadAttention(d_model, d_k, d_v, h, dropout) 14 | self.grid_cross_region = MultiHeadAttention(d_model, d_k, d_v, h, dropout) 15 | self.region_cross_grid = MultiHeadAttention(d_model, d_k, d_v, h, dropout) 16 | self.self_cross = MultiHeadAttention(d_model, d_k, d_v, h, dropout, attention_module=OSAttention) 17 | self.pwff = PositionWiseFeedForward(d_model, d_ff, dropout) 18 | 19 | def forward(self, input, grid_features, region_features, mask_pad, mask_self_att, mask_enc_att=None): 20 | self_att = self.self_att(input, input, input, mask_self_att) 21 | self_att = self_att * mask_pad 22 | 23 | grid_att = self.grid_att(self_att, grid_features, grid_features) 24 | grid_att = grid_att * mask_pad 25 | 26 | region_att = self.region_att(self_att, region_features, region_features, mask_enc_att) 27 | region_att = region_att * mask_pad 28 | 29 | grid_cross_att = self.grid_cross_region(grid_att, region_features, region_features, mask_enc_att) 30 | grid_cross_att = grid_cross_att * mask_pad 31 | 32 | region_cross_att = self.region_cross_grid(region_att, grid_features, grid_features) 33 | region_cross_att = region_cross_att * mask_pad 34 | 35 | enc_features = torch.stack([grid_att, region_att, grid_cross_att, region_cross_att], dim=-2) 36 | enc_att = self.self_cross(self_att, enc_features, enc_features) 37 | enc_att = enc_att * mask_pad 38 | 39 | ff = self.pwff(enc_att) 40 | ff = ff * mask_pad 41 | return ff 42 | 43 | class TransformerDecoder(Module): 44 | def __init__(self, vocab_size, max_len, N_dec, padding_idx, d_model=512, d_k=64, d_v=64, h=8, d_ff=2048, dropout=.1, return_logits = False): 45 | super(TransformerDecoder, self).__init__() 46 | self.d_model = d_model 47 | self.num_queries = 5 48 | 49 | self.word_emb = nn.Embedding(vocab_size, d_model, padding_idx=padding_idx) 50 | self.pos_emb = nn.Embedding.from_pretrained(sinusoid_encoding_table(max_len + 1, d_model, 0), freeze=True) 51 | self.layers = ModuleList([DecoderLayer(d_model, d_k, d_v, h, d_ff, dropout) for _ in range(N_dec)]) 52 
| self.fc = nn.Linear(d_model, vocab_size, bias=False) 53 | self.max_len = max_len 54 | self.padding_idx = padding_idx 55 | self.N = N_dec 56 | self.return_logits = return_logits 57 | 58 | self.register_state('running_mask_self_attention', torch.zeros((1, 1, 0)).byte()) 59 | self.register_state('running_seq', torch.zeros((1,)).long()) 60 | 61 | def forward(self, input, grid_features, region_features, mask_encoder=None): 62 | 63 | # input (b_s, seq_len) 64 | b_s, seq_len = input.shape[:2] 65 | mask_queries = (input != self.padding_idx).unsqueeze(-1) # (b_s, seq_len, 1) 66 | mask_self_attention = torch.triu(torch.ones((seq_len, seq_len), dtype=torch.uint8, device=input.device), diagonal=1) 67 | mask_self_attention = mask_self_attention.unsqueeze(0).unsqueeze(0) # (1, 1, seq_len, seq_len) 68 | mask_self_attention = mask_self_attention + (input == self.padding_idx).unsqueeze(1).unsqueeze(1).byte() 69 | mask_self_attention = mask_self_attention.gt(0) # (b_s, 1, seq_len, seq_len) 70 | 71 | seq = torch.arange(1, seq_len + 1).view(1, -1).expand(b_s, -1).to(input.device) # (b_s, seq_len) 72 | seq = seq.masked_fill(mask_queries.squeeze(-1) == 0, 0) 73 | 74 | if self._is_stateful: 75 | self.running_mask_self_attention = torch.cat([self.running_mask_self_attention, mask_self_attention], -1) 76 | mask_self_attention = self.running_mask_self_attention 77 | self.running_seq.add_(1) 78 | seq = self.running_seq 79 | 80 | out = self.word_emb(input) + self.pos_emb(seq) 81 | for i, l in enumerate(self.layers): 82 | out = l(out, grid_features, region_features, mask_queries, mask_self_attention, mask_encoder) 83 | 84 | out = self.fc(out) 85 | if self.return_logits: 86 | return out 87 | else: 88 | return F.log_softmax(out, dim=-1) 89 | 90 | def build_decoder(vocab_size, max_len, N_dec, padding_idx, 91 | d_model=512, d_k=64, d_v=64, h=8, d_ff=2048, dropout=.1, **kwargs): 92 | Decoder = TransformerDecoder(vocab_size, max_len, N_dec, padding_idx, 93 | d_model, d_k, d_v, h, d_ff, dropout) 94 | 95 | return Decoder -------------------------------------------------------------------------------- /common/utils/tsv_file.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2020 Microsoft Corporation. Licensed under the MIT license. 2 | 3 | import logging 4 | import shutil 5 | import numpy as np 6 | import os 7 | import os.path as op 8 | 9 | 10 | def mkdir(path): 11 | # if it is the current folder, skip. 12 | if path == '': 13 | return 14 | try: 15 | os.makedirs(path) 16 | except OSError as e: 17 | if e.errno != errno.EEXIST: 18 | raise 19 | 20 | 21 | def generate_lineidx_file(filein, idxout): 22 | idxout_tmp = idxout + '.tmp' 23 | with open(filein, 'r') as tsvin, open(idxout_tmp,'w') as tsvout: 24 | fsize = os.fstat(tsvin.fileno()).st_size 25 | fpos = 0 26 | while fpos!=fsize: 27 | tsvout.write(str(fpos)+"\n") 28 | tsvin.readline() 29 | fpos = tsvin.tell() 30 | os.rename(idxout_tmp, idxout) 31 | 32 | 33 | class TSVFile(object): 34 | def __init__(self, tsv_file, generate_lineidx=False): 35 | self.tsv_file = tsv_file 36 | self.lineidx = op.splitext(tsv_file)[0] + '.lineidx' 37 | self._fp = None 38 | self._lineidx = None 39 | # the process always keeps the process which opens the file. 40 | # If the pid is not equal to the currrent pid, we will re-open the file. 
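        # (The .lineidx file stores one byte offset per TSV row, so seek(idx) can jump straight to a
        # line with a single fp.seek; re-opening when os.getpid() changes keeps forked readers,
        # e.g. DataLoader worker processes, from sharing one file handle and its read position.)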
41 | self.pid = None 42 | # generate lineidx if not exist 43 | if not op.isfile(self.lineidx) and generate_lineidx: 44 | generate_lineidx_file(self.tsv_file, self.lineidx) 45 | 46 | def __del__(self): 47 | if self._fp: 48 | self._fp.close() 49 | 50 | def __str__(self): 51 | return "TSVFile(tsv_file='{}')".format(self.tsv_file) 52 | 53 | def __repr__(self): 54 | return str(self) 55 | 56 | def num_rows(self): 57 | self._ensure_lineidx_loaded() 58 | return len(self._lineidx) 59 | 60 | def seek(self, idx): 61 | self._ensure_tsv_opened() 62 | self._ensure_lineidx_loaded() 63 | try: 64 | pos = self._lineidx[idx] 65 | except: 66 | logging.info('{}-{}'.format(self.tsv_file, idx)) 67 | raise 68 | self._fp.seek(pos) 69 | return [s.strip() for s in self._fp.readline().split('\t')] 70 | 71 | def seek_first_column(self, idx): 72 | self._ensure_tsv_opened() 73 | self._ensure_lineidx_loaded() 74 | pos = self._lineidx[idx] 75 | self._fp.seek(pos) 76 | return read_to_character(self._fp, '\t') 77 | 78 | def __getitem__(self, index): 79 | return self.seek(index) 80 | 81 | def __len__(self): 82 | return self.num_rows() 83 | 84 | def _ensure_lineidx_loaded(self): 85 | if self._lineidx is None: 86 | logging.info('loading lineidx: {}'.format(self.lineidx)) 87 | with open(self.lineidx, 'r') as fp: 88 | self._lineidx = [int(i.strip()) for i in fp.readlines()] 89 | 90 | def _ensure_tsv_opened(self): 91 | if self._fp is None: 92 | self._fp = open(self.tsv_file, 'r') 93 | self.pid = os.getpid() 94 | 95 | if self.pid != os.getpid(): 96 | logging.info('re-open {} because the process id changed'.format(self.tsv_file)) 97 | self._fp = open(self.tsv_file, 'r') 98 | self.pid = os.getpid() 99 | 100 | 101 | def tsv_writer(values, tsv_file_name, sep='\t'): 102 | mkdir(os.path.dirname(tsv_file_name)) 103 | tsv_file_name_tmp = tsv_file_name + '.tmp' 104 | with open(tsv_file_name_tmp, 'wb') as fp: 105 | assert values is not None 106 | for value in values: 107 | assert value is not None 108 | v = sep.join(map(lambda v: v.decode() if type(v) == bytes else str(v), value)) + '\n' 109 | v = v.encode() 110 | fp.write(v) 111 | os.rename(tsv_file_name_tmp, tsv_file_name) 112 | 113 | 114 | def concat_files(ins, out): 115 | out_tmp = out + '.tmp' 116 | with open(out_tmp, 'wb') as fp_out: 117 | for i, f in enumerate(ins): 118 | with open(f, 'rb') as fp_in: 119 | shutil.copyfileobj(fp_in, fp_out, 1024*1024*10) 120 | os.rename(out_tmp, out) 121 | 122 | 123 | def concat_tsv_files(tsvs, out_tsv, generate_lineidx=False): 124 | concat_files(tsvs, out_tsv) 125 | if generate_lineidx: 126 | sizes = [os.stat(t).st_size for t in tsvs] 127 | sizes = np.cumsum(sizes) 128 | all_idx = [] 129 | for i, t in enumerate(tsvs): 130 | for idx in load_list_file(op.splitext(t)[0] + '.lineidx'): 131 | if i == 0: 132 | all_idx.append(idx) 133 | else: 134 | all_idx.append(str(int(idx) + sizes[i - 1])) 135 | with open(op.splitext(out_tsv)[0] + '.lineidx', 'w') as f: 136 | f.write('\n'.join(all_idx)) 137 | 138 | 139 | def load_list_file(fname): 140 | with open(fname, 'r') as fp: 141 | lines = fp.readlines() 142 | result = [line.strip() for line in lines] 143 | if len(result) > 0 and result[-1] == '': 144 | result = result[:-1] 145 | return result 146 | 147 | 148 | def reorder_tsv_keys(in_tsv_file, ordered_keys, out_tsv_file): 149 | tsv = TSVFile(in_tsv_file, generate_lineidx=True) 150 | keys = [tsv.seek(i)[0] for i in range(len(tsv))] 151 | key_to_idx = {key: i for i, key in enumerate(keys)} 152 | def gen_rows(): 153 | for key in ordered_keys: 154 | idx = 
key_to_idx[key] 155 | yield tsv.seek(idx) 156 | tsv_writer(gen_rows(), out_tsv_file) 157 | 158 | 159 | def delete_tsv_files(tsvs): 160 | for t in tsvs: 161 | if op.isfile(t): 162 | try_delete(t) 163 | line = op.splitext(t)[0] + '.lineidx' 164 | if op.isfile(line): 165 | try_delete(line) 166 | 167 | 168 | def try_once(func): 169 | def func_wrapper(*args, **kwargs): 170 | try: 171 | return func(*args, **kwargs) 172 | except Exception as e: 173 | logging.info('ignore error \n{}'.format(str(e))) 174 | return func_wrapper 175 | 176 | 177 | @try_once 178 | def try_delete(f): 179 | os.remove(f) -------------------------------------------------------------------------------- /common/evaluation/cider/cider_scorer.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # Tsung-Yi Lin 3 | # Ramakrishna Vedantam 4 | 5 | import copy 6 | from collections import defaultdict 7 | import numpy as np 8 | import math 9 | 10 | def precook(s, n=4): 11 | """ 12 | Takes a string as input and returns an object that can be given to 13 | either cook_refs or cook_test. This is optional: cook_refs and cook_test 14 | can take string arguments as well. 15 | :param s: string : sentence to be converted into ngrams 16 | :param n: int : number of ngrams for which representation is calculated 17 | :return: term frequency vector for occuring ngrams 18 | """ 19 | words = s.split() 20 | counts = defaultdict(int) 21 | for k in range(1,n+1): 22 | for i in range(len(words)-k+1): 23 | ngram = tuple(words[i:i+k]) 24 | counts[ngram] += 1 25 | return counts 26 | 27 | def cook_refs(refs, n=4): ## lhuang: oracle will call with "average" 28 | '''Takes a list of reference sentences for a single segment 29 | and returns an object that encapsulates everything that BLEU 30 | needs to know about them. 31 | :param refs: list of string : reference sentences for some image 32 | :param n: int : number of ngrams for which (ngram) representation is calculated 33 | :return: result (list of dict) 34 | ''' 35 | return [precook(ref, n) for ref in refs] 36 | 37 | def cook_test(test, n=4): 38 | '''Takes a test sentence and returns an object that 39 | encapsulates everything that BLEU needs to know about it. 40 | :param test: list of string : hypothesis sentence for some image 41 | :param n: int : number of ngrams for which (ngram) representation is calculated 42 | :return: result (dict) 43 | ''' 44 | return precook(test, n) 45 | 46 | class CiderScorer(object): 47 | """CIDEr scorer. 48 | """ 49 | 50 | def __init__(self, refs, test=None, n=4, sigma=6.0, doc_frequency=None, ref_len=None): 51 | ''' singular instance ''' 52 | self.n = n 53 | self.sigma = sigma 54 | self.crefs = [] 55 | self.ctest = [] 56 | self.doc_frequency = defaultdict(float) 57 | self.ref_len = None 58 | 59 | for k in refs.keys(): 60 | self.crefs.append(cook_refs(refs[k])) 61 | if test is not None: 62 | self.ctest.append(cook_test(test[k][0])) ## N.B.: -1 63 | else: 64 | self.ctest.append(None) # lens of crefs and ctest have to match 65 | 66 | if doc_frequency is None and ref_len is None: 67 | # compute idf 68 | self.compute_doc_freq() 69 | # compute log reference length 70 | self.ref_len = np.log(float(len(self.crefs))) 71 | else: 72 | self.doc_frequency = doc_frequency 73 | self.ref_len = ref_len 74 | 75 | def compute_doc_freq(self): 76 | ''' 77 | Compute term frequency for reference data. 
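        Document frequency here counts, for every n-gram, how many images have at least one
        reference containing it; e.g. if the bigram ('a', 'dog') appears in the references of
        120 images (an illustrative count), its document frequency is 120.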
78 | This will be used to compute idf (inverse document frequency later) 79 | The term frequency is stored in the object 80 | :return: None 81 | ''' 82 | for refs in self.crefs: 83 | # refs, k ref captions of one image 84 | for ngram in set([ngram for ref in refs for (ngram,count) in ref.items()]): 85 | self.doc_frequency[ngram] += 1 86 | # maxcounts[ngram] = max(maxcounts.get(ngram,0), count) 87 | 88 | def compute_cider(self): 89 | def counts2vec(cnts): 90 | """ 91 | Function maps counts of ngram to vector of tfidf weights. 92 | The function returns vec, an array of dictionary that store mapping of n-gram and tf-idf weights. 93 | The n-th entry of array denotes length of n-grams. 94 | :param cnts: 95 | :return: vec (array of dict), norm (array of float), length (int) 96 | """ 97 | vec = [defaultdict(float) for _ in range(self.n)] 98 | length = 0 99 | norm = [0.0 for _ in range(self.n)] 100 | for (ngram,term_freq) in cnts.items(): 101 | # give word count 1 if it doesn't appear in reference corpus 102 | df = np.log(max(1.0, self.doc_frequency[ngram])) 103 | # ngram index 104 | n = len(ngram)-1 105 | # tf (term_freq) * idf (precomputed idf) for n-grams 106 | vec[n][ngram] = float(term_freq)*(self.ref_len - df) 107 | # compute norm for the vector. the norm will be used for computing similarity 108 | norm[n] += pow(vec[n][ngram], 2) 109 | 110 | if n == 1: 111 | length += term_freq 112 | norm = [np.sqrt(n) for n in norm] 113 | return vec, norm, length 114 | 115 | def sim(vec_hyp, vec_ref, norm_hyp, norm_ref, length_hyp, length_ref): 116 | ''' 117 | Compute the cosine similarity of two vectors. 118 | :param vec_hyp: array of dictionary for vector corresponding to hypothesis 119 | :param vec_ref: array of dictionary for vector corresponding to reference 120 | :param norm_hyp: array of float for vector corresponding to hypothesis 121 | :param norm_ref: array of float for vector corresponding to reference 122 | :param length_hyp: int containing length of hypothesis 123 | :param length_ref: int containing length of reference 124 | :return: array of score for each n-grams cosine similarity 125 | ''' 126 | delta = float(length_hyp - length_ref) 127 | # measure consine similarity 128 | val = np.array([0.0 for _ in range(self.n)]) 129 | for n in range(self.n): 130 | # ngram 131 | for (ngram,count) in vec_hyp[n].items(): 132 | # vrama91 : added clipping 133 | val[n] += min(vec_hyp[n][ngram], vec_ref[n][ngram]) * vec_ref[n][ngram] 134 | 135 | if (norm_hyp[n] != 0) and (norm_ref[n] != 0): 136 | val[n] /= (norm_hyp[n]*norm_ref[n]) 137 | 138 | assert(not math.isnan(val[n])) 139 | # vrama91: added a length based gaussian penalty 140 | val[n] *= np.e**(-(delta**2)/(2*self.sigma**2)) 141 | return val 142 | 143 | scores = [] 144 | for test, refs in zip(self.ctest, self.crefs): 145 | # compute vector for test captions 146 | vec, norm, length = counts2vec(test) 147 | # compute vector for ref captions 148 | score = np.array([0.0 for _ in range(self.n)]) 149 | for ref in refs: 150 | vec_ref, norm_ref, length_ref = counts2vec(ref) 151 | score += sim(vec, vec_ref, norm, norm_ref, length, length_ref) 152 | # change by vrama91 - mean of ngram scores, instead of sum 153 | score_avg = np.mean(score) 154 | # divide by number of references 155 | score_avg /= len(refs) 156 | # multiply score by 10 157 | score_avg *= 10.0 158 | # append score of an image to the score list 159 | scores.append(score_avg) 160 | return scores 161 | 162 | def compute_score(self): 163 | # compute cider score 164 | score = self.compute_cider() 165 | 
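        # compute_cider() returns one score per image, each already scaled by 10 above; the corpus
        # CIDEr is simply their mean, e.g. per-image scores of 1.2 and 0.8 (illustrative) average to 1.0.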
# debug 166 | # print score 167 | return np.mean(np.array(score)), np.array(score) -------------------------------------------------------------------------------- /common/models/beam_search/beam_search.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import common.utils as utils 3 | 4 | 5 | class BeamSearch(object): 6 | def __init__(self, model, max_len: int, eos_idx: int, beam_size: int): 7 | self.model = model 8 | self.max_len = max_len 9 | self.eos_idx = eos_idx 10 | self.beam_size = beam_size 11 | self.b_s = None 12 | self.device = None 13 | self.seq_mask = None 14 | self.seq_logprob = None 15 | self.outputs = None 16 | self.log_probs = None 17 | self.selected_words = None 18 | self.all_log_probs = None 19 | 20 | def _expand_state(self, selected_beam, cur_beam_size): 21 | def fn(s): 22 | if s is not None: 23 | shape = [int(sh) for sh in s.shape] 24 | beam = selected_beam 25 | for _ in shape[1:]: 26 | beam = beam.unsqueeze(-1) 27 | s = torch.gather(s.view(*([self.b_s, cur_beam_size] + shape[1:])), 1, 28 | beam.expand(*([self.b_s, self.beam_size] + shape[1:]))) 29 | s = s.view(*([-1, ] + shape[1:])) 30 | return s 31 | 32 | return fn 33 | 34 | def _expand_visual(self, visual: utils.TensorOrSequence, cur_beam_size: int, selected_beam: torch.Tensor): 35 | if isinstance(visual, torch.Tensor): 36 | visual_shape = visual.shape 37 | visual_exp_shape = (self.b_s, cur_beam_size) + visual_shape[1:] 38 | visual_red_shape = (self.b_s * self.beam_size,) + visual_shape[1:] 39 | selected_beam_red_size = (self.b_s, self.beam_size) + tuple(1 for _ in range(len(visual_exp_shape) - 2)) 40 | selected_beam_exp_size = (self.b_s, self.beam_size) + visual_exp_shape[2:] 41 | visual_exp = visual.view(visual_exp_shape) 42 | selected_beam_exp = selected_beam.view(selected_beam_red_size).expand(selected_beam_exp_size) 43 | visual = torch.gather(visual_exp, 1, selected_beam_exp).view(visual_red_shape) 44 | else: 45 | new_visual = [] 46 | for im in visual: 47 | visual_shape = im.shape 48 | visual_exp_shape = (self.b_s, cur_beam_size) + visual_shape[1:] 49 | visual_red_shape = (self.b_s * self.beam_size,) + visual_shape[1:] 50 | selected_beam_red_size = (self.b_s, self.beam_size) + tuple(1 for _ in range(len(visual_exp_shape) - 2)) 51 | selected_beam_exp_size = (self.b_s, self.beam_size) + visual_exp_shape[2:] 52 | visual_exp = im.view(visual_exp_shape) 53 | selected_beam_exp = selected_beam.view(selected_beam_red_size).expand(selected_beam_exp_size) 54 | new_im = torch.gather(visual_exp, 1, selected_beam_exp).view(visual_red_shape) 55 | new_visual.append(new_im) 56 | visual = tuple(new_visual) 57 | return visual 58 | 59 | def apply(self, visual: utils.TensorOrSequence, out_size=1, return_probs=False, **kwargs): 60 | self.b_s = utils.get_batch_size(visual) 61 | self.device = utils.get_device(visual) 62 | self.seq_mask = torch.ones((self.b_s, self.beam_size, 1), device=self.device) 63 | self.seq_logprob = torch.zeros((self.b_s, 1, 1), device=self.device) 64 | self.log_probs = [] 65 | self.selected_words = None 66 | if return_probs: 67 | self.all_log_probs = [] 68 | 69 | outputs = [] 70 | with self.model.statefulness(self.b_s): 71 | for t in range(self.max_len): 72 | visual, outputs = self.iter(t, visual, outputs, return_probs, **kwargs) 73 | 74 | # Sort result 75 | seq_logprob, sort_idxs = torch.sort(self.seq_logprob, 1, descending=True) 76 | outputs = torch.cat(outputs, -1) 77 | outputs = torch.gather(outputs, 1, sort_idxs.expand(self.b_s, self.beam_size, 
self.max_len)) 78 | log_probs = torch.cat(self.log_probs, -1) 79 | log_probs = torch.gather(log_probs, 1, sort_idxs.expand(self.b_s, self.beam_size, self.max_len)) 80 | if return_probs: 81 | all_log_probs = torch.cat(self.all_log_probs, 2) 82 | all_log_probs = torch.gather(all_log_probs, 1, sort_idxs.unsqueeze(-1).expand(self.b_s, self.beam_size, 83 | self.max_len, 84 | all_log_probs.shape[-1])) 85 | 86 | outputs = outputs.contiguous()[:, :out_size] 87 | log_probs = log_probs.contiguous()[:, :out_size] 88 | if out_size == 1: 89 | outputs = outputs.squeeze(1) 90 | log_probs = log_probs.squeeze(1) 91 | 92 | if return_probs: 93 | return outputs, log_probs, all_log_probs 94 | else: 95 | return outputs, log_probs 96 | 97 | def select(self, t, candidate_logprob, **kwargs): 98 | selected_logprob, selected_idx = torch.sort(candidate_logprob.view(self.b_s, -1), -1, descending=True) 99 | selected_logprob, selected_idx = selected_logprob[:, :self.beam_size], selected_idx[:, :self.beam_size] 100 | return selected_idx, selected_logprob 101 | 102 | def iter(self, t: int, visual: utils.TensorOrSequence, outputs, return_probs, **kwargs): 103 | cur_beam_size = 1 if t == 0 else self.beam_size 104 | 105 | word_logprob = self.model.step(t, self.selected_words, visual, None, mode='feedback', **kwargs) 106 | word_logprob = word_logprob.view(self.b_s, cur_beam_size, -1) 107 | candidate_logprob = self.seq_logprob + word_logprob 108 | 109 | # Mask sequence if it reaches EOS 110 | if t > 0: 111 | mask = (self.selected_words.view(self.b_s, cur_beam_size) != self.eos_idx).float().unsqueeze(-1) 112 | self.seq_mask = self.seq_mask * mask 113 | word_logprob = word_logprob * self.seq_mask.expand_as(word_logprob) 114 | old_seq_logprob = self.seq_logprob.expand_as(candidate_logprob).contiguous() 115 | old_seq_logprob[:, :, 1:] = -999 116 | candidate_logprob = self.seq_mask * candidate_logprob + old_seq_logprob * (1 - self.seq_mask) 117 | 118 | selected_idx, selected_logprob = self.select(t, candidate_logprob, **kwargs) 119 | selected_beam = torch.div(selected_idx, candidate_logprob.shape[-1], rounding_mode='floor') 120 | selected_words = selected_idx - selected_beam * candidate_logprob.shape[-1] 121 | 122 | self.model.apply_to_states(self._expand_state(selected_beam, cur_beam_size)) 123 | # visual = self._expand_visual(visual, cur_beam_size, selected_beam) 124 | 125 | self.seq_logprob = selected_logprob.unsqueeze(-1) 126 | self.seq_mask = torch.gather(self.seq_mask, 1, selected_beam.unsqueeze(-1)) 127 | outputs = list(torch.gather(o, 1, selected_beam.unsqueeze(-1)) for o in outputs) 128 | outputs.append(selected_words.unsqueeze(-1)) 129 | 130 | if return_probs: 131 | if t == 0: 132 | self.all_log_probs.append(word_logprob.expand((self.b_s, self.beam_size, -1)).unsqueeze(2)) 133 | else: 134 | self.all_log_probs.append(word_logprob.unsqueeze(2)) 135 | 136 | this_word_logprob = torch.gather(word_logprob, 1, 137 | selected_beam.unsqueeze(-1).expand(self.b_s, self.beam_size, 138 | word_logprob.shape[-1])) 139 | this_word_logprob = torch.gather(this_word_logprob, 2, selected_words.unsqueeze(-1)) 140 | self.log_probs = list( 141 | torch.gather(o, 1, selected_beam.unsqueeze(-1).expand(self.b_s, self.beam_size, 1)) for o in self.log_probs) 142 | self.log_probs.append(this_word_logprob) 143 | self.selected_words = selected_words.view(-1, 1) 144 | 145 | return visual, outputs -------------------------------------------------------------------------------- /requirements.txt: 
-------------------------------------------------------------------------------- 1 | absl-py==1.0.0 2 | antlr4-python3-runtime==4.9.3 3 | argon2-cffi==21.3.0 4 | argon2-cffi-bindings==21.2.0 5 | -e git+https://github.com/Vision-CAIR/artemis-v2.git@ea53c802622dcc8e8b9cedfaaff602e7034de20c#egg=artemis&subdirectory=neural_speaker/sat 6 | asttokens @ file:///opt/conda/conda-bld/asttokens_1646925590279/work 7 | attrs==21.4.0 8 | backcall @ file:///home/ktietz/src/ci/backcall_1611930011877/work 9 | beautifulsoup4==4.11.1 10 | bleach==5.0.1 11 | blis @ file:///opt/conda/conda-bld/cython-blis_1651219782567/work 12 | boto3==1.24.21 13 | botocore==1.27.21 14 | brotlipy==0.7.0 15 | bs4==0.0.1 16 | bypy==1.8.1 17 | cachetools==5.2.0 18 | catalogue @ file:///opt/conda/conda-bld/catalogue_1651218742349/work 19 | certifi @ file:///opt/conda/conda-bld/certifi_1655968806487/work/certifi 20 | cffi @ file:///opt/conda/conda-bld/cffi_1642701102775/work 21 | charset-normalizer @ file:///tmp/build/80754af9/charset-normalizer_1630003229654/work 22 | click @ file:///tmp/build/80754af9/click_1646056590078/work 23 | clip @ git+https://github.com/openai/CLIP.git@b46f5ac7587d2e1862f8b7b1573179d80dcdd620 24 | cloudpickle==2.1.0 25 | colorama @ file:///tmp/build/80754af9/colorama_1607707115595/work 26 | cryptography @ file:///tmp/build/80754af9/cryptography_1652101588893/work 27 | cycler @ file:///tmp/build/80754af9/cycler_1637851556182/work 28 | cymem @ file:///opt/conda/conda-bld/cymem_1651237256138/work 29 | Cython @ file:///tmp/build/80754af9/cython_1647850345254/work 30 | dask==2022.7.1 31 | debugpy @ file:///tmp/build/80754af9/debugpy_1637091799509/work 32 | decorator @ file:///opt/conda/conda-bld/decorator_1643638310831/work 33 | defusedxml==0.7.1 34 | dill==0.3.5.1 35 | editdistpy==0.1.3 36 | einops==0.4.1 37 | en-core-web-sm @ file:///home/xys/code/captioning/en_core_web_sm-3.2.0.tar.gz 38 | entrypoints @ file:///tmp/build/80754af9/entrypoints_1649926439650/work 39 | executing @ file:///opt/conda/conda-bld/executing_1646925071911/work 40 | faiss==1.7.2 41 | fastjsonschema==2.16.1 42 | filelock @ file:///opt/conda/conda-bld/filelock_1647002191454/work 43 | fonttools==4.25.0 44 | fsspec==2022.3.0 45 | ftfy @ file:///home/xys/code/captioning/ftfy-6.1.0-py3-none-any.whl 46 | fvcore==0.1.5.post20220512 47 | gdown==4.5.1 48 | google-auth==2.9.0 49 | google-auth-oauthlib==0.4.6 50 | grpcio==1.47.0 51 | h5py @ file:///tmp/build/80754af9/h5py_1637138488546/work 52 | huggingface-hub @ file:///tmp/build/80754af9/huggingface_hub_1639662742275/work 53 | hydra-core==1.2.0 54 | idna @ file:///tmp/build/80754af9/idna_1637925883363/work 55 | imageio==2.20.0 56 | importlib-metadata @ file:///tmp/build/80754af9/importlib-metadata_1648544546694/work 57 | iopath==0.1.9 58 | ipykernel @ file:///tmp/build/80754af9/ipykernel_1647000773790/work/dist/ipykernel-6.9.1-py3-none-any.whl 59 | ipython @ file:///opt/conda/conda-bld/ipython_1651600145335/work 60 | ipython-genutils==0.2.0 61 | ipywidgets==7.7.1 62 | jedi @ file:///tmp/build/80754af9/jedi_1644297102865/work 63 | Jinja2 @ file:///opt/conda/conda-bld/jinja2_1647436528585/work 64 | jmespath==1.0.1 65 | joblib @ file:///tmp/build/80754af9/joblib_1635411271373/work 66 | json-lines==0.5.0 67 | jsonlines==3.0.0 68 | jsonschema==4.8.0 69 | jupyter==1.0.0 70 | jupyter-client @ file:///opt/conda/conda-bld/jupyter_client_1650622202839/work 71 | jupyter-console==6.4.4 72 | jupyter-core @ file:///opt/conda/conda-bld/jupyter_core_1651671229925/work 73 | jupyterlab-pygments==0.2.2 74 | 
jupyterlab-widgets==1.1.1 75 | kiwisolver @ file:///opt/conda/conda-bld/kiwisolver_1638569886207/work 76 | langcodes @ file:///opt/conda/conda-bld/langcodes_1643477751144/work 77 | lmdb==1.3.0 78 | locket==1.0.0 79 | Markdown==3.3.7 80 | MarkupSafe @ file:///tmp/build/80754af9/markupsafe_1621523467000/work 81 | matplotlib @ file:///tmp/build/80754af9/matplotlib-suite_1647441664166/work 82 | matplotlib-inline @ file:///tmp/build/80754af9/matplotlib-inline_1628242447089/work 83 | mistune==0.8.4 84 | mkl-fft==1.3.1 85 | mkl-random @ file:///tmp/build/80754af9/mkl_random_1626186066731/work 86 | mkl-service==2.4.0 87 | multiprocess==0.70.13 88 | # Editable install with no version control (MultiScaleDeformableAttention==1.0) 89 | -e /home/xys/code/grit/models/ops 90 | munkres==1.1.4 91 | murmurhash @ file:///opt/conda/conda-bld/murmurhash_1651237169273/work 92 | nbclient==0.6.6 93 | nbconvert==6.5.0 94 | nbformat==5.4.0 95 | nest-asyncio @ file:///tmp/build/80754af9/nest-asyncio_1649847906199/work 96 | networkx==2.8.5 97 | nltk==3.7 98 | notebook==6.4.12 99 | numpy @ file:///opt/conda/conda-bld/numpy_and_numpy_base_1651563629415/work 100 | oauthlib==3.2.0 101 | omegaconf==2.2.2 102 | opencv-python==4.6.0.66 103 | packaging @ file:///tmp/build/80754af9/packaging_1637314298585/work 104 | pandas==1.4.3 105 | pandocfilters==1.5.0 106 | panopticapi @ git+https://github.com/cocodataset/panopticapi.git@7bb4655548f98f3fedc07bf37e9040a992b054b0 107 | parso @ file:///opt/conda/conda-bld/parso_1641458642106/work 108 | partd==1.2.0 109 | pathy @ file:///opt/conda/conda-bld/pathy_1651566172310/work 110 | pexpect @ file:///tmp/build/80754af9/pexpect_1605563209008/work 111 | pickleshare @ file:///tmp/build/80754af9/pickleshare_1606932040724/work 112 | Pillow==9.0.1 113 | pkginfo @ file:///home/xys/code/captioning/pkginfo-1.8.2-py2.py3-none-any.whl 114 | plotly==5.9.0 115 | portalocker==2.4.0 116 | preshed @ file:///opt/conda/conda-bld/preshed_1651240927559/work 117 | prometheus-client==0.14.1 118 | prompt-toolkit @ file:///tmp/build/80754af9/prompt-toolkit_1633440160888/work 119 | protobuf==3.19.0 120 | ptyprocess @ file:///tmp/build/80754af9/ptyprocess_1609355006118/work/dist/ptyprocess-0.7.0-py2.py3-none-any.whl 121 | pure-eval @ file:///opt/conda/conda-bld/pure_eval_1646925070566/work 122 | pyasn1==0.4.8 123 | pyasn1-modules==0.2.8 124 | pycocoevalcap==1.2 125 | pycocotools @ file:///home/conda/feedstock_root/build_artifacts/pycocotools_1641707022464/work 126 | pycparser @ file:///tmp/build/80754af9/pycparser_1636541352034/work 127 | pydantic @ file:///opt/conda/conda-bld/pydantic_1636617910351/work 128 | pyDeprecate==0.3.2 129 | Pygments @ file:///opt/conda/conda-bld/pygments_1644249106324/work 130 | pyOpenSSL @ file:///opt/conda/conda-bld/pyopenssl_1643788558760/work 131 | pyparsing @ file:///tmp/build/80754af9/pyparsing_1635766073266/work 132 | pyrsistent==0.18.1 133 | PySocks @ file:///tmp/build/80754af9/pysocks_1605305812635/work 134 | python-dateutil @ file:///tmp/build/80754af9/python-dateutil_1626374649649/work 135 | pytorch-lightning==1.6.3 136 | pytorch-transformers==1.2.0 137 | pytz==2022.1 138 | PyWavelets==1.3.0 139 | PyYAML==6.0 140 | pyzmq @ file:///tmp/build/80754af9/pyzmq_1638434985866/work 141 | qtconsole==5.3.1 142 | QtPy==2.1.0 143 | regex @ file:///tmp/build/80754af9/regex_1648447707500/work 144 | requests @ file:///opt/conda/conda-bld/requests_1641824580448/work 145 | requests-oauthlib==1.3.1 146 | requests-toolbelt==0.9.1 147 | rsa==4.8 148 | s3transfer==0.6.0 149 | sacremoses @ 
file:///tmp/build/80754af9/sacremoses_1633107328213/work 150 | scikit-image==0.19.3 151 | scikit-learn==1.1.1 152 | scipy==1.8.1 153 | seaborn==0.11.2 154 | Send2Trash==1.8.0 155 | sentencepiece==0.1.96 156 | shellingham @ file:///Users/ktietz/demo/mc3/conda-bld/shellingham_1629144685686/work 157 | six @ file:///tmp/build/80754af9/six_1644875935023/work 158 | smart-open @ file:///opt/conda/conda-bld/smart_open_1651563547610/work 159 | soupsieve==2.3.2.post1 160 | spacy @ file:///opt/conda/conda-bld/spacy_1643537927879/work 161 | spacy-legacy @ file:///opt/conda/conda-bld/spacy-legacy_1651648750835/work 162 | spacy-loggers @ file:///opt/conda/conda-bld/spacy-loggers_1643478552797/work 163 | srsly @ file:///opt/conda/conda-bld/srsly_1651584738433/work 164 | stack-data @ file:///opt/conda/conda-bld/stack_data_1646927590127/work 165 | symspellpy==6.7.6 166 | tabulate==0.8.10 167 | tenacity==8.0.1 168 | tensorboard==2.9.0 169 | tensorboard-data-server==0.6.1 170 | tensorboard-plugin-wit==1.8.1 171 | termcolor==1.1.0 172 | terminado==0.15.0 173 | thinc @ file:///opt/conda/conda-bld/thinc_1651591454033/work 174 | threadpoolctl==3.1.0 175 | tifffile==2022.7.28 176 | timm @ file:///home/conda/feedstock_root/build_artifacts/timm_1625085613679/work 177 | tinycss2==1.1.1 178 | tokenizers @ file:///opt/conda/conda-bld/tokenizers_1651822590771/work 179 | toolz==0.12.0 180 | torch==1.8.2 181 | torch-tb-profiler==0.4.0 182 | torchmetrics==0.8.2 183 | torchvision==0.9.2 184 | tornado @ file:///tmp/build/80754af9/tornado_1606942317143/work 185 | tqdm @ file:///tmp/build/80754af9/tqdm_1635330843403/work 186 | traitlets==5.3.0 187 | transformers @ file:///opt/conda/conda-bld/transformers_1651834311006/work 188 | typer @ file:///opt/conda/conda-bld/typer_1651237163820/work 189 | typing_extensions==4.2.0 190 | urllib3 @ file:///opt/conda/conda-bld/urllib3_1650639997961/work 191 | wasabi @ file:///opt/conda/conda-bld/wasabi_1651237317563/work 192 | wcwidth @ file:///Users/ktietz/demo/mc3/conda-bld/wcwidth_1629357192024/work 193 | webencodings==0.5.1 194 | Werkzeug==2.1.2 195 | widgetsnbextension==3.6.1 196 | yacs==0.1.8 197 | zipp @ file:///opt/conda/conda-bld/zipp_1652341764480/work 198 | -------------------------------------------------------------------------------- /environment.yml: -------------------------------------------------------------------------------- 1 | name: caption 2 | channels: 3 | - pytorch-lts 4 | - https://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/free 5 | - nvidia 6 | - conda-forge 7 | - defaults 8 | dependencies: 9 | - _libgcc_mutex=0.1=conda_forge 10 | - _openmp_mutex=4.5=2_gnu 11 | - asttokens=2.0.5=pyhd3eb1b0_0 12 | - backcall=0.2.0=pyhd3eb1b0_0 13 | - blas=1.0=mkl 14 | - brotli=1.0.9=he6710b0_2 15 | - brotlipy=0.7.0=py39h27cfd23_1003 16 | - bzip2=1.0.8=h7b6447c_0 17 | - ca-certificates=2022.4.26=h06a4308_0 18 | - catalogue=2.0.7=py39h06a4308_0 19 | - certifi=2022.6.15=py39h06a4308_0 20 | - cffi=1.15.0=py39hd667e15_1 21 | - charset-normalizer=2.0.4=pyhd3eb1b0_0 22 | - click=8.0.4=py39h06a4308_0 23 | - colorama=0.4.4=pyhd3eb1b0_0 24 | - cryptography=37.0.1=py39h9ce1e76_0 25 | - cudatoolkit=11.1.74=h6bb024c_0 26 | - cycler=0.11.0=pyhd3eb1b0_0 27 | - cymem=2.0.6=py39h295c915_0 28 | - cython=0.29.28=py39h295c915_0 29 | - cython-blis=0.7.7=py39hce1f21e_0 30 | - debugpy=1.5.1=py39h295c915_0 31 | - decorator=5.1.1=pyhd3eb1b0_0 32 | - entrypoints=0.4=py39h06a4308_0 33 | - executing=0.8.3=pyhd3eb1b0_0 34 | - faiss=1.7.2=py39cuda111h03baf68_0_cuda 35 | - ffmpeg=4.2.2=h20bf706_0 36 | - 
filelock=3.6.0=pyhd3eb1b0_0 37 | - fonttools=4.25.0=pyhd3eb1b0_0 38 | - freetype=2.11.0=h70c0345_0 39 | - giflib=5.2.1=h7b6447c_0 40 | - gmp=6.2.1=h2531618_2 41 | - gnutls=3.6.15=he1e5248_0 42 | - h5py=3.6.0=py39ha0f2276_0 43 | - hdf5=1.10.6=hb1b8bf9_0 44 | - huggingface_hub=0.2.1=pyhd3eb1b0_0 45 | - idna=3.3=pyhd3eb1b0_0 46 | - importlib-metadata=4.11.3=py39h06a4308_0 47 | - importlib_metadata=4.11.3=hd3eb1b0_0 48 | - intel-openmp=2021.4.0=h06a4308_3561 49 | - ipykernel=6.9.1=py39h06a4308_0 50 | - ipython=8.3.0=py39h06a4308_0 51 | - jedi=0.18.1=py39h06a4308_1 52 | - jinja2=3.0.3=pyhd3eb1b0_0 53 | - joblib=1.1.0=pyhd3eb1b0_0 54 | - jpeg=9b=0 55 | - jupyter_client=7.2.2=py39h06a4308_0 56 | - jupyter_core=4.10.0=py39h06a4308_0 57 | - kiwisolver=1.3.2=py39h295c915_0 58 | - lame=3.100=h7b6447c_0 59 | - langcodes=3.3.0=pyhd3eb1b0_0 60 | - lcms2=2.12=h3be6417_0 61 | - ld_impl_linux-64=2.38=h1181459_0 62 | - libblas=3.9.0=12_linux64_mkl 63 | - libfaiss=1.7.2=cuda111h7721031_0_cuda 64 | - libfaiss-avx2=1.7.2=cuda111h1234567_0_cuda 65 | - libffi=3.3=he6710b0_2 66 | - libgcc-ng=12.1.0=h8d9b700_16 67 | - libgfortran-ng=7.5.0=ha8ba4b0_17 68 | - libgfortran4=7.5.0=ha8ba4b0_17 69 | - libgomp=12.1.0=h8d9b700_16 70 | - libidn2=2.3.2=h7f8727e_0 71 | - liblapack=3.9.0=12_linux64_mkl 72 | - libopus=1.3.1=h7b6447c_0 73 | - libpng=1.6.37=hbc83047_0 74 | - libsodium=1.0.18=h7b6447c_0 75 | - libstdcxx-ng=12.1.0=ha89aaad_16 76 | - libtasn1=4.16.0=h27cfd23_0 77 | - libtiff=4.2.0=h85742a9_0 78 | - libunistring=0.9.10=h27cfd23_0 79 | - libuv=1.40.0=h7b6447c_0 80 | - libvpx=1.7.0=h439df22_0 81 | - libwebp=1.2.0=h89dd481_0 82 | - libwebp-base=1.2.0=h27cfd23_0 83 | - lz4-c=1.9.3=h295c915_1 84 | - markupsafe=2.0.1=py39h27cfd23_0 85 | - matplotlib-base=3.5.1=py39ha18d171_1 86 | - matplotlib-inline=0.1.2=pyhd3eb1b0_2 87 | - mkl=2021.4.0=h06a4308_640 88 | - mkl-service=2.4.0=py39h7f8727e_0 89 | - mkl_fft=1.3.1=py39hd3c417c_0 90 | - mkl_random=1.2.2=py39h51133e4_0 91 | - munkres=1.1.4=py_0 92 | - murmurhash=1.0.7=py39h295c915_0 93 | - ncurses=6.3=h7f8727e_2 94 | - nest-asyncio=1.5.5=py39h06a4308_0 95 | - nettle=3.7.3=hbbd107a_1 96 | - ninja=1.10.2=h06a4308_5 97 | - ninja-base=1.10.2=hd09550d_5 98 | - numpy=1.21.5=py39he7a7128_2 99 | - numpy-base=1.21.5=py39hf524024_2 100 | - openh264=2.1.1=h4ff587b_0 101 | - openssl=1.1.1p=h5eee18b_0 102 | - packaging=21.3=pyhd3eb1b0_0 103 | - parso=0.8.3=pyhd3eb1b0_0 104 | - pathy=0.6.1=py39h06a4308_0 105 | - pexpect=4.8.0=pyhd3eb1b0_3 106 | - pickleshare=0.7.5=pyhd3eb1b0_1003 107 | - pillow=9.0.1=py39h22f2fdc_0 108 | - pip=21.2.4=py39h06a4308_0 109 | - preshed=3.0.6=py39h295c915_0 110 | - prompt-toolkit=3.0.20=pyhd3eb1b0_0 111 | - ptyprocess=0.7.0=pyhd3eb1b0_2 112 | - pure_eval=0.2.2=pyhd3eb1b0_0 113 | - pycocotools=2.0.4=py39hce5d2b2_0 114 | - pycparser=2.21=pyhd3eb1b0_0 115 | - pydantic=1.8.2=py39h7f8727e_0 116 | - pygments=2.11.2=pyhd3eb1b0_0 117 | - pyopenssl=22.0.0=pyhd3eb1b0_0 118 | - pyparsing=3.0.4=pyhd3eb1b0_0 119 | - pysocks=1.7.1=py39h06a4308_0 120 | - python=3.9.12=h12debd9_0 121 | - python-dateutil=2.8.2=pyhd3eb1b0_0 122 | - python_abi=3.9=2_cp39 123 | - pytorch=1.8.2=py3.9_cuda11.1_cudnn8.0.5_0 124 | - pyzmq=22.3.0=py39h295c915_2 125 | - readline=8.1.2=h7f8727e_1 126 | - requests=2.27.1=pyhd3eb1b0_0 127 | - sacremoses=0.0.43=pyhd3eb1b0_0 128 | - setuptools=61.2.0=py39h06a4308_0 129 | - shellingham=1.3.1=pyhd3eb1b0_0 130 | - six=1.16.0=pyhd3eb1b0_1 131 | - smart_open=5.2.1=py39h06a4308_0 132 | - spacy=3.2.1=py39hae6d005_0 133 | - spacy-legacy=3.0.9=py39h06a4308_0 134 | - 
spacy-loggers=1.0.1=pyhd3eb1b0_0 135 | - sqlite=3.38.3=hc218d9a_0 136 | - srsly=2.4.3=py39h295c915_0 137 | - stack_data=0.2.0=pyhd3eb1b0_0 138 | - thinc=8.0.15=py39hae6d005_0 139 | - timm=0.4.12=pyhd8ed1ab_0 140 | - tk=8.6.11=h1ccaba5_1 141 | - tokenizers=0.11.4=py39h3dcd8bd_1 142 | - torchvision=0.9.2=py39_cu111 143 | - tornado=6.1=py39h27cfd23_0 144 | - tqdm=4.62.3=pyhd3eb1b0_1 145 | - transformers=4.18.0=py39h06a4308_0 146 | - typer=0.4.1=py39h06a4308_0 147 | - typing_extensions=4.1.1=pyh06a4308_0 148 | - tzdata=2022a=hda174b7_0 149 | - urllib3=1.26.9=py39h06a4308_0 150 | - wasabi=0.9.1=py39h06a4308_0 151 | - wcwidth=0.2.5=pyhd3eb1b0_0 152 | - wheel=0.37.1=pyhd3eb1b0_0 153 | - x264=1!157.20191217=h7b6447c_0 154 | - xz=5.2.5=h7f8727e_1 155 | - yaml=0.2.5=h7b6447c_0 156 | - zeromq=4.3.4=h2531618_0 157 | - zipp=3.8.0=py39h06a4308_0 158 | - zlib=1.2.12=h7f8727e_2 159 | - zstd=1.4.9=haebb681_0 160 | - pip: 161 | - absl-py==1.0.0 162 | - antlr4-python3-runtime==4.9.3 163 | - argon2-cffi==21.3.0 164 | - argon2-cffi-bindings==21.2.0 165 | - attrs==21.4.0 166 | - beautifulsoup4==4.11.1 167 | - bleach==5.0.1 168 | - boto3==1.24.21 169 | - botocore==1.27.21 170 | - bs4==0.0.1 171 | - bypy==1.8.1 172 | - cachetools==5.2.0 173 | - clip==1.0 174 | - cloudpickle==2.1.0 175 | - dask==2022.7.1 176 | - defusedxml==0.7.1 177 | - dill==0.3.5.1 178 | - editdistpy==0.1.3 179 | - einops==0.4.1 180 | - en-core-web-sm==3.2.0 181 | - fastjsonschema==2.16.1 182 | - fsspec==2022.3.0 183 | - ftfy==6.1.0 184 | - fvcore==0.1.5.post20220512 185 | - gdown==4.5.1 186 | - google-auth==2.9.0 187 | - google-auth-oauthlib==0.4.6 188 | - grpcio==1.47.0 189 | - hydra-core==1.2.0 190 | - imageio==2.20.0 191 | - iopath==0.1.9 192 | - ipython-genutils==0.2.0 193 | - ipywidgets==7.7.1 194 | - jmespath==1.0.1 195 | - json-lines==0.5.0 196 | - jsonlines==3.0.0 197 | - jsonschema==4.8.0 198 | - jupyter==1.0.0 199 | - jupyter-console==6.4.4 200 | - jupyterlab-pygments==0.2.2 201 | - jupyterlab-widgets==1.1.1 202 | - lmdb==1.3.0 203 | - locket==1.0.0 204 | - markdown==3.3.7 205 | - mistune==0.8.4 206 | - multiprocess==0.70.13 207 | - nbclient==0.6.6 208 | - nbconvert==6.5.0 209 | - nbformat==5.4.0 210 | - networkx==2.8.5 211 | - nltk==3.7 212 | - notebook==6.4.12 213 | - oauthlib==3.2.0 214 | - omegaconf==2.2.2 215 | - opencv-python==4.6.0.66 216 | - pandas==1.4.3 217 | - pandocfilters==1.5.0 218 | - panopticapi==0.1 219 | - partd==1.2.0 220 | - pkginfo==1.8.2 221 | - plotly==5.9.0 222 | - portalocker==2.4.0 223 | - prometheus-client==0.14.1 224 | - protobuf==3.19.0 225 | - pyasn1==0.4.8 226 | - pyasn1-modules==0.2.8 227 | - pycocoevalcap==1.2 228 | - pydeprecate==0.3.2 229 | - pyrsistent==0.18.1 230 | - pytorch-lightning==1.6.3 231 | - pytorch-transformers==1.2.0 232 | - pytz==2022.1 233 | - pywavelets==1.3.0 234 | - pyyaml==6.0 235 | - qtconsole==5.3.1 236 | - qtpy==2.1.0 237 | - regex==2022.3.2 238 | - requests-oauthlib==1.3.1 239 | - requests-toolbelt==0.9.1 240 | - rsa==4.8 241 | - s3transfer==0.6.0 242 | - scikit-image==0.19.3 243 | - scikit-learn==1.1.1 244 | - scipy==1.8.1 245 | - seaborn==0.11.2 246 | - send2trash==1.8.0 247 | - sentencepiece==0.1.96 248 | - soupsieve==2.3.2.post1 249 | - symspellpy==6.7.6 250 | - tabulate==0.8.10 251 | - tenacity==8.0.1 252 | - tensorboard==2.9.0 253 | - tensorboard-data-server==0.6.1 254 | - tensorboard-plugin-wit==1.8.1 255 | - termcolor==1.1.0 256 | - terminado==0.15.0 257 | - threadpoolctl==3.1.0 258 | - tifffile==2022.7.28 259 | - tinycss2==1.1.1 260 | - toolz==0.12.0 261 | - 
torch-tb-profiler==0.4.0 262 | - torchmetrics==0.8.2 263 | - traitlets==5.3.0 264 | - typing-extensions==4.2.0 265 | - webencodings==0.5.1 266 | - werkzeug==2.1.2 267 | - widgetsnbextension==3.6.1 268 | - yacs==0.1.8 269 | prefix: /home/xys/miniconda3/envs/caption 270 | -------------------------------------------------------------------------------- /common/train.py: -------------------------------------------------------------------------------- 1 | import random 2 | import torch 3 | import time 4 | import numpy as np 5 | from torch.optim import Adam 6 | from torch.optim.lr_scheduler import LambdaLR 7 | from torch.nn import NLLLoss 8 | import os 9 | from shutil import copyfile 10 | from torch.utils.tensorboard import SummaryWriter 11 | from common.evaluation import PTBTokenizer, Cider 12 | from common.data import DataLoader 13 | from common.data.field import RawField 14 | from common.utils import evaluate_loss, evaluate_metrics, train_xe, train_scst 15 | from common.utils.utils import setup_seed 16 | 17 | setup_seed(123456) 18 | 19 | def train(args, model, datasets, image_field, text_field, optim=None, scheduler=None, 20 | train_xe_fn = train_xe, evaluate_loss_fn = evaluate_loss): 21 | 22 | device = args.device 23 | output = args.output 24 | use_rl = args.use_rl 25 | 26 | date = time.strftime("%Y-%m-%d", time.localtime()) 27 | writer = SummaryWriter(log_dir=os.path.join(output, 'tensorboard_logs', args.exp_name, date)) 28 | 29 | train_dataset, val_dataset, test_dataset = datasets 30 | 31 | dict_dataset_val = val_dataset.image_dictionary({'image': image_field, 'text': RawField()}) 32 | dict_dataset_test = test_dataset.image_dictionary({'image': image_field, 'text': RawField()}) 33 | 34 | cider_cache_path = 'cache/cider_cache.pkl' 35 | if use_rl: 36 | if os.path.exists(cider_cache_path): 37 | cider_train = torch.load(cider_cache_path) 38 | else: 39 | ref_caps_train = list(train_dataset.text) 40 | cider_train = Cider(PTBTokenizer.tokenize(ref_caps_train)) 41 | torch.save(cider_train, cider_cache_path) 42 | 43 | train_dataset = train_dataset.image_dictionary({'image': image_field, 'text': RawField()}) 44 | 45 | 46 | train_batch_size = args.batch_size // 5 if use_rl else args.batch_size 47 | dataloader_train = DataLoader(train_dataset, batch_size=train_batch_size, shuffle=True, num_workers=args.workers, pin_memory=True, drop_last=False) 48 | dataloader_val = DataLoader(val_dataset, batch_size=train_batch_size, num_workers=args.workers, pin_memory=True, drop_last=False) 49 | dict_dataloader_val = DataLoader(dict_dataset_val, batch_size=train_batch_size, num_workers=args.workers, pin_memory=True, drop_last=False) 50 | dict_dataloader_test = DataLoader(dict_dataset_test, batch_size=train_batch_size, num_workers=args.workers, pin_memory=True, drop_last=False) 51 | 52 | # def lambda_lr(s): 53 | # warm_up = args.warmup 54 | # s += 1 55 | # return (model.d_model ** -.5) * min(s ** -.5, s * warm_up ** -1.5) 56 | 57 | def lambda_lr(s): 58 | if s <= 3: 59 | lr = args.xe_base_lr * s / 4 60 | elif s <= 10: 61 | lr = args.xe_base_lr 62 | elif s <= 12: 63 | lr = args.xe_base_lr * 0.2 64 | else: 65 | lr = args.xe_base_lr * 0.2 * 0.2 66 | return lr 67 | 68 | def lambda_lr_rl(s): 69 | refine_epoch = 8 70 | if s <= refine_epoch: 71 | lr = args.rl_base_lr 72 | elif s <= refine_epoch + 3: 73 | lr = args.rl_base_lr * 0.2 74 | elif s <= refine_epoch + 6: 75 | lr = args.rl_base_lr * 0.2 * 0.2 76 | else: 77 | lr = args.rl_base_lr * 0.2 * 0.2 * 0.2 78 | return lr 79 | 80 | # Initial conditions 81 | if use_rl: 82 | 
optim = Adam(model.parameters(), lr=1, betas=(0.9, 0.98)) 83 | scheduler = LambdaLR(optim, lambda_lr_rl) 84 | else: 85 | optim = Adam(model.parameters(), lr=1, betas=(0.9, 0.98)) if optim is None else optim 86 | scheduler = LambdaLR(optim, lambda_lr) if scheduler is None else scheduler 87 | 88 | 89 | loss_fn = NLLLoss(ignore_index=text_field.vocab.stoi['']) 90 | best_cider = .0 91 | patience = 0 92 | start_epoch = 0 93 | 94 | last_saved = os.path.join('coco/checkpoints', output, '%s_last.pth' % args.exp_name) 95 | best_saved = os.path.join('coco/checkpoints', output, '%s_best.pth' % args.exp_name) 96 | 97 | if args.resume_last or args.resume_best: 98 | if use_rl: 99 | last_saved = os.path.join('coco/checkpoints', output, '%s_rl_last.pth' % args.exp_name) 100 | best_saved = os.path.join('coco/checkpoints', output, '%s_rl_best.pth' % args.exp_name) 101 | 102 | if args.resume_last: 103 | fname = last_saved 104 | else: 105 | fname = best_saved 106 | 107 | if os.path.exists(fname): 108 | data = torch.load(fname) 109 | torch.set_rng_state(data['torch_rng_state']) 110 | torch.cuda.set_rng_state(data['cuda_rng_state']) 111 | np.random.set_state(data['numpy_rng_state']) 112 | random.setstate(data['random_rng_state']) 113 | model.load_state_dict(data['state_dict'], strict=False) 114 | optim.load_state_dict(data['optimizer']) 115 | scheduler.load_state_dict(data['scheduler']) 116 | start_epoch = data['epoch'] + 1 117 | best_cider = data['best_cider'] 118 | patience = data['patience'] 119 | print('Resuming from epoch %d, validation loss %f, and best cider %f' % ( 120 | data['epoch'], data['val_loss'], data['best_cider'])) 121 | 122 | elif use_rl: 123 | data = torch.load(best_saved, map_location=device) 124 | model.load_state_dict(data['state_dict'], strict=False) 125 | best_cider = data['best_cider'] 126 | start_epoch = 0 127 | patience = 0 128 | print('Resuming from XE epoch %d, validation loss %f, and best cider %f' % ( 129 | data['epoch'], data['val_loss'], data['best_cider'])) 130 | 131 | last_saved = os.path.join('coco/checkpoints', output, '%s_rl_last.pth' % args.exp_name) 132 | best_saved = os.path.join('coco/checkpoints', output, '%s_rl_best.pth' % args.exp_name) 133 | 134 | print("Training starts") 135 | for e in range(start_epoch, start_epoch + 100): 136 | if not use_rl: 137 | train_loss = train_xe_fn(model, dataloader_train, optim, loss_fn, text_field, e, device, scheduler, args) 138 | writer.add_scalar('data/train_loss', train_loss, e) 139 | else: 140 | train_loss, reward, reward_baseline = train_scst(model, dataloader_train, optim, 141 | cider_train, text_field, e, device, scheduler, args) 142 | writer.add_scalar('data/train_loss', train_loss, e) 143 | writer.add_scalar('data/reward', reward, e) 144 | writer.add_scalar('data/reward_baseline', reward_baseline, e) 145 | 146 | # Validation loss 147 | val_loss = evaluate_loss_fn(model, dataloader_val, loss_fn, text_field, e, device, args) 148 | writer.add_scalar('data/val_loss', val_loss, e) 149 | 150 | # Validation scores 151 | scores = evaluate_metrics(model, dict_dataloader_val, text_field, e, device, args) 152 | print("Validation scores", scores) 153 | val_cider = scores['CIDEr'] 154 | writer.add_scalar('data/val_cider', val_cider, e) 155 | writer.add_scalar('data/val_bleu1', scores['BLEU'][0], e) 156 | writer.add_scalar('data/val_bleu4', scores['BLEU'][3], e) 157 | writer.add_scalar('data/val_meteor', scores['METEOR'], e) 158 | writer.add_scalar('data/val_rouge', scores['ROUGE'], e) 159 | 160 | # Test scores 161 | scores = 
evaluate_metrics(model, dict_dataloader_test, text_field, e, device, args) 162 | print("Test scores", scores) 163 | writer.add_scalar('data/test_cider', scores['CIDEr'], e) 164 | writer.add_scalar('data/test_bleu1', scores['BLEU'][0], e) 165 | writer.add_scalar('data/test_bleu4', scores['BLEU'][3], e) 166 | writer.add_scalar('data/test_meteor', scores['METEOR'], e) 167 | writer.add_scalar('data/test_rouge', scores['ROUGE'], e) 168 | 169 | # Prepare for next epoch 170 | best = False 171 | if val_cider >= best_cider: 172 | best_cider = val_cider 173 | patience = 0 174 | best = True 175 | else: 176 | patience += 1 177 | 178 | # switch_to_rl = False 179 | exit_train = False 180 | # automatic training strategy 181 | if patience == 5: 182 | if e < 15: 183 | patience = 0 184 | else: 185 | print('patience reached.') 186 | exit_train = True 187 | 188 | saved_dir = os.path.join('coco/checkpoints', output) 189 | if not os.path.exists(saved_dir): 190 | os.makedirs(saved_dir) 191 | 192 | torch.save({ 193 | 'torch_rng_state': torch.get_rng_state(), 194 | 'cuda_rng_state': torch.cuda.get_rng_state(), 195 | 'numpy_rng_state': np.random.get_state(), 196 | 'random_rng_state': random.getstate(), 197 | 'epoch': e, 198 | 'val_loss': val_loss, 199 | 'val_cider': val_cider, 200 | 'state_dict': model.state_dict(), 201 | 'optimizer': optim.state_dict(), 202 | 'scheduler': scheduler.state_dict(), 203 | 'patience': patience, 204 | 'best_cider': best_cider, 205 | 'use_rl': use_rl, 206 | }, last_saved) 207 | 208 | if best: 209 | copyfile(last_saved, best_saved) 210 | 211 | if exit_train: 212 | writer.close() 213 | break -------------------------------------------------------------------------------- /common/evaluation/bleu/bleu_scorer.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | # bleu_scorer.py 4 | # David Chiang 5 | 6 | # Copyright (c) 2004-2006 University of Maryland. All rights 7 | # reserved. Do not redistribute without permission from the 8 | # author. Not for commercial use. 9 | 10 | # Modified by: 11 | # Hao Fang 12 | # Tsung-Yi Lin 13 | 14 | ''' Provides: 15 | cook_refs(refs, n=4): Transform a list of reference sentences as strings into a form usable by cook_test(). 16 | cook_test(test, refs, n=4): Transform a test sentence as a string (together with the cooked reference sentences) into a form usable by score_cooked(). 17 | ''' 18 | 19 | import copy 20 | import sys, math, re 21 | from collections import defaultdict 22 | 23 | 24 | def precook(s, n=4, out=False): 25 | """Takes a string as input and returns an object that can be given to 26 | either cook_refs or cook_test. This is optional: cook_refs and cook_test 27 | can take string arguments as well.""" 28 | words = s.split() 29 | counts = defaultdict(int) 30 | for k in range(1, n + 1): 31 | for i in range(len(words) - k + 1): 32 | ngram = tuple(words[i:i + k]) 33 | counts[ngram] += 1 34 | return (len(words), counts) 35 | 36 | 37 | def cook_refs(refs, eff=None, n=4): ## lhuang: oracle will call with "average" 38 | '''Takes a list of reference sentences for a single segment 39 | and returns an object that encapsulates everything that BLEU 40 | needs to know about them.''' 41 | 42 | reflen = [] 43 | maxcounts = {} 44 | for ref in refs: 45 | rl, counts = precook(ref, n) 46 | reflen.append(rl) 47 | for (ngram, count) in counts.items(): 48 | maxcounts[ngram] = max(maxcounts.get(ngram, 0), count) 49 | 50 | # Calculate effective reference sentence length. 
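# Illustrative worked example (added for clarity, not part of the original scorer):
# with three references of lengths [12, 15, 20], eff == "shortest" collapses reflen
# to 12 and eff == "average" to (12 + 15 + 20) / 3 = 15.67 (rounded); when eff is
# None, the list is kept as-is and the reference length nearest to the candidate
# length is chosen later ("closest" option in cook_test() / _single_reflen(), the
# default in compute_score() when more than one segment is scored).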
51 | if eff == "shortest": 52 | reflen = min(reflen) 53 | elif eff == "average": 54 | reflen = float(sum(reflen)) / len(reflen) 55 | 56 | ## lhuang: N.B.: leave reflen computaiton to the very end!! 57 | 58 | ## lhuang: N.B.: in case of "closest", keep a list of reflens!! (bad design) 59 | 60 | return (reflen, maxcounts) 61 | 62 | 63 | def cook_test(test, ref_tuple, eff=None, n=4): 64 | '''Takes a test sentence and returns an object that 65 | encapsulates everything that BLEU needs to know about it.''' 66 | 67 | testlen, counts = precook(test, n, True) 68 | reflen, refmaxcounts = ref_tuple 69 | 70 | result = {} 71 | 72 | # Calculate effective reference sentence length. 73 | 74 | if eff == "closest": 75 | result["reflen"] = min((abs(l - testlen), l) for l in reflen)[1] 76 | else: ## i.e., "average" or "shortest" or None 77 | result["reflen"] = reflen 78 | 79 | result["testlen"] = testlen 80 | 81 | result["guess"] = [max(0, testlen - k + 1) for k in range(1, n + 1)] 82 | 83 | result['correct'] = [0] * n 84 | for (ngram, count) in counts.items(): 85 | result["correct"][len(ngram) - 1] += min(refmaxcounts.get(ngram, 0), count) 86 | 87 | return result 88 | 89 | 90 | class BleuScorer(object): 91 | """Bleu scorer. 92 | """ 93 | 94 | __slots__ = "n", "crefs", "ctest", "_score", "_ratio", "_testlen", "_reflen", "special_reflen" 95 | 96 | # special_reflen is used in oracle (proportional effective ref len for a node). 97 | 98 | def copy(self): 99 | ''' copy the refs.''' 100 | new = BleuScorer(n=self.n) 101 | new.ctest = copy.copy(self.ctest) 102 | new.crefs = copy.copy(self.crefs) 103 | new._score = None 104 | return new 105 | 106 | def __init__(self, test=None, refs=None, n=4, special_reflen=None): 107 | ''' singular instance ''' 108 | 109 | self.n = n 110 | self.crefs = [] 111 | self.ctest = [] 112 | self.cook_append(test, refs) 113 | self.special_reflen = special_reflen 114 | 115 | def cook_append(self, test, refs): 116 | '''called by constructor and __iadd__ to avoid creating new instances.''' 117 | 118 | if refs is not None: 119 | self.crefs.append(cook_refs(refs)) 120 | if test is not None: 121 | cooked_test = cook_test(test, self.crefs[-1]) 122 | self.ctest.append(cooked_test) ## N.B.: -1 123 | else: 124 | self.ctest.append(None) # lens of crefs and ctest have to match 125 | 126 | self._score = None ## need to recompute 127 | 128 | def ratio(self, option=None): 129 | self.compute_score(option=option) 130 | return self._ratio 131 | 132 | def score_ratio(self, option=None): 133 | ''' 134 | return (bleu, len_ratio) pair 135 | ''' 136 | 137 | return self.fscore(option=option), self.ratio(option=option) 138 | 139 | def score_ratio_str(self, option=None): 140 | return "%.4f (%.2f)" % self.score_ratio(option) 141 | 142 | def reflen(self, option=None): 143 | self.compute_score(option=option) 144 | return self._reflen 145 | 146 | def testlen(self, option=None): 147 | self.compute_score(option=option) 148 | return self._testlen 149 | 150 | def retest(self, new_test): 151 | if type(new_test) is str: 152 | new_test = [new_test] 153 | assert len(new_test) == len(self.crefs), new_test 154 | self.ctest = [] 155 | for t, rs in zip(new_test, self.crefs): 156 | self.ctest.append(cook_test(t, rs)) 157 | self._score = None 158 | 159 | return self 160 | 161 | def rescore(self, new_test): 162 | ''' replace test(s) with new test(s), and returns the new score.''' 163 | 164 | return self.retest(new_test).compute_score() 165 | 166 | def size(self): 167 | assert len(self.crefs) == len(self.ctest), "refs/test mismatch! 
%d<>%d" % (len(self.crefs), len(self.ctest)) 168 | return len(self.crefs) 169 | 170 | def __iadd__(self, other): 171 | '''add an instance (e.g., from another sentence).''' 172 | 173 | if type(other) is tuple: 174 | ## avoid creating new BleuScorer instances 175 | self.cook_append(other[0], other[1]) 176 | else: 177 | assert self.compatible(other), "incompatible BLEUs." 178 | self.ctest.extend(other.ctest) 179 | self.crefs.extend(other.crefs) 180 | self._score = None ## need to recompute 181 | 182 | return self 183 | 184 | def compatible(self, other): 185 | return isinstance(other, BleuScorer) and self.n == other.n 186 | 187 | def single_reflen(self, option="average"): 188 | return self._single_reflen(self.crefs[0][0], option) 189 | 190 | def _single_reflen(self, reflens, option=None, testlen=None): 191 | 192 | if option == "shortest": 193 | reflen = min(reflens) 194 | elif option == "average": 195 | reflen = float(sum(reflens)) / len(reflens) 196 | elif option == "closest": 197 | reflen = min((abs(l - testlen), l) for l in reflens)[1] 198 | else: 199 | assert False, "unsupported reflen option %s" % option 200 | 201 | return reflen 202 | 203 | def recompute_score(self, option=None, verbose=0): 204 | self._score = None 205 | return self.compute_score(option, verbose) 206 | 207 | def compute_score(self, option=None, verbose=0): 208 | n = self.n 209 | small = 1e-9 210 | tiny = 1e-15 ## so that if guess is 0 still return 0 211 | bleu_list = [[] for _ in range(n)] 212 | 213 | if self._score is not None: 214 | return self._score 215 | 216 | if option is None: 217 | option = "average" if len(self.crefs) == 1 else "closest" 218 | 219 | self._testlen = 0 220 | self._reflen = 0 221 | totalcomps = {'testlen': 0, 'reflen': 0, 'guess': [0] * n, 'correct': [0] * n} 222 | 223 | # for each sentence 224 | for comps in self.ctest: 225 | testlen = comps['testlen'] 226 | self._testlen += testlen 227 | 228 | if self.special_reflen is None: ## need computation 229 | reflen = self._single_reflen(comps['reflen'], option, testlen) 230 | else: 231 | reflen = self.special_reflen 232 | 233 | self._reflen += reflen 234 | 235 | for key in ['guess', 'correct']: 236 | for k in range(n): 237 | totalcomps[key][k] += comps[key][k] 238 | 239 | # append per image bleu score 240 | bleu = 1. 241 | for k in range(n): 242 | bleu *= (float(comps['correct'][k]) + tiny) \ 243 | / (float(comps['guess'][k]) + small) 244 | bleu_list[k].append(bleu ** (1. / (k + 1))) 245 | ratio = (testlen + tiny) / (reflen + small) ## N.B.: avoid zero division 246 | if ratio < 1: 247 | for k in range(n): 248 | bleu_list[k][-1] *= math.exp(1 - 1 / ratio) 249 | 250 | if verbose > 1: 251 | print(comps, reflen) 252 | 253 | totalcomps['reflen'] = self._reflen 254 | totalcomps['testlen'] = self._testlen 255 | 256 | bleus = [] 257 | bleu = 1. 258 | for k in range(n): 259 | bleu *= float(totalcomps['correct'][k] + tiny) \ 260 | / (totalcomps['guess'][k] + small) 261 | bleus.append(bleu ** (1. 
/ (k + 1))) 262 | ratio = (self._testlen + tiny) / (self._reflen + small) ## N.B.: avoid zero division 263 | if ratio < 1: 264 | for k in range(n): 265 | bleus[k] *= math.exp(1 - 1 / ratio) 266 | 267 | if verbose > 0: 268 | print(totalcomps) 269 | print("ratio:", ratio) 270 | 271 | self._score = bleus 272 | return self._score, bleu_list 273 | -------------------------------------------------------------------------------- /common/utils/utils.py: -------------------------------------------------------------------------------- 1 | import random 2 | import torch 3 | import os, pickle 4 | import numpy as np 5 | from tqdm import tqdm 6 | import requests 7 | import itertools 8 | import multiprocessing 9 | from common.data.dataset import COCODataset 10 | import common.evaluation as evaluation 11 | from torch import Tensor 12 | 13 | def setup_seed(seed): 14 | torch.manual_seed(seed) 15 | torch.cuda.manual_seed_all(seed) 16 | np.random.seed(seed) 17 | random.seed(seed) 18 | # torch.backends.cudnn.deterministic = True 19 | 20 | def one_hot_to_index(one_hot: Tensor) -> Tensor: 21 | """ 22 | Converts a one-hot tensor into a tensor with corresponding indexes 23 | """ 24 | device, dtype = one_hot.device, one_hot.dtype 25 | vocab_size = one_hot.shape[-1] 26 | oh2idx = torch.tensor(range(vocab_size), dtype=dtype, device=device) 27 | return (one_hot @ oh2idx.unsqueeze(dim=1)).long().squeeze(dim=-1) 28 | 29 | def download_from_url(url, path): 30 | """Download file, with logic (from tensor2tensor) for Google Drive""" 31 | if 'drive.google.com' not in url: 32 | print('Downloading %s; may take a few minutes' % url) 33 | r = requests.get(url, headers={'User-Agent': 'Mozilla/5.0'}) 34 | with open(path, "wb") as file: 35 | file.write(r.content) 36 | return 37 | print('Downloading from Google Drive; may take a few minutes') 38 | confirm_token = None 39 | session = requests.Session() 40 | response = session.get(url, stream=True) 41 | for k, v in response.cookies.items(): 42 | if k.startswith("download_warning"): 43 | confirm_token = v 44 | 45 | if confirm_token: 46 | url = url + "&confirm=" + confirm_token 47 | response = session.get(url, stream=True) 48 | 49 | chunk_size = 16 * 1024 50 | with open(path, "wb") as f: 51 | for chunk in response.iter_content(chunk_size): 52 | if chunk: 53 | f.write(chunk) 54 | 55 | 56 | def create_dataset(args, image_field, text_field): 57 | # Create the dataset 58 | dataset = COCODataset(image_field, text_field, args.image_folder, args.annotation_folder, args.annotation_folder) 59 | train_dataset, val_dataset, test_dataset = dataset.splits 60 | 61 | vocab_path = 'cache/vocab.pkl' 62 | if not os.path.isfile(vocab_path): 63 | print("Building vocabulary") 64 | text_field.build_vocab(train_dataset, val_dataset, min_freq=5) 65 | pickle.dump(text_field.vocab, open(vocab_path, 'wb')) 66 | else: 67 | text_field.vocab = pickle.load(open(vocab_path, 'rb')) 68 | 69 | return (train_dataset, val_dataset, test_dataset) 70 | 71 | 72 | def evaluate_loss(model, dataloader, loss_fn, text_field, epoch, device = 'cuda', args=None): 73 | # Validation loss 74 | model.eval() 75 | running_loss = .0 76 | with tqdm(desc='Epoch %d - validation' % epoch, unit='it', total=len(dataloader)) as pbar: 77 | with torch.no_grad(): 78 | for it, (images, captions) in enumerate(dataloader): 79 | # if it == 10: 80 | # break 81 | captions = captions.to(device) 82 | if isinstance(images, tuple) or isinstance(images, list): 83 | images = [x.to(device) for x in images] 84 | else: 85 | images = images.to(device) 86 | out = 
model(images, captions) 87 | captions = captions[:, 1:].contiguous() 88 | out = out[:, :-1].contiguous() 89 | loss = loss_fn(out.view(-1, len(text_field.vocab)), captions.view(-1)) 90 | this_loss = loss.item() 91 | running_loss += this_loss 92 | 93 | pbar.set_postfix(loss=running_loss / (it + 1)) 94 | pbar.update() 95 | 96 | val_loss = running_loss / len(dataloader) 97 | return val_loss 98 | 99 | def dict_to_cuda(input_dict, deivce): 100 | for key in input_dict: 101 | if isinstance(input_dict[key], list): 102 | input_dict[key] = [ val.to(deivce) for val in input_dict[key]] 103 | elif isinstance(input_dict[key], dict): 104 | dict_to_cuda(input_dict[key], deivce) 105 | else: 106 | input_dict[key] = input_dict[key].to(deivce) 107 | 108 | def evaluate_metrics(model, dataloader, text_field, epoch, device = 'cuda', args=None): 109 | import itertools 110 | model.eval() 111 | gen = {} 112 | gts = {} 113 | with tqdm(desc='Epoch %d - evaluation' % epoch, unit='it', total=len(dataloader)) as pbar: 114 | for it, (images, caps_gt) in enumerate(iter(dataloader)): 115 | # if it == 10: 116 | # break 117 | with torch.no_grad(): 118 | if isinstance(images, tuple) or isinstance(images, list): 119 | images = [x.to(device) for x in images] 120 | else: 121 | images = images.to(device) 122 | # images[0] = images[0].to(device) 123 | # dict_to_cuda(images[1], device) 124 | # images[0] = images[0].to(device) 125 | # images[1] = images[1].to(device) 126 | # images[2] = { 127 | # k1: { 128 | # k2: v2.to(device) 129 | # for k2, v2 in v1.items() 130 | # } 131 | # for k1, v1 in images[2].items() 132 | # } 133 | 134 | out, _ = model.beam_search(images, 20, text_field.vocab.stoi[''], 5, out_size=1) 135 | 136 | caps_gen = text_field.decode(out, join_words=False) 137 | for i, (gts_i, gen_i) in enumerate(zip(caps_gt, caps_gen)): 138 | gen_i = ' '.join([k for k, g in itertools.groupby(gen_i)]) 139 | gen['%d_%d' % (it, i)] = [gen_i.strip(), ] 140 | gts['%d_%d' % (it, i)] = gts_i 141 | pbar.update() 142 | 143 | gts = evaluation.PTBTokenizer.tokenize(gts) 144 | gen = evaluation.PTBTokenizer.tokenize(gen) 145 | scores, _ = evaluation.compute_scores(gts, gen) 146 | return scores 147 | 148 | 149 | def train_xe(model, dataloader, optim, loss_fn, text_field, epoch, device = 'cuda', scheduler = None, args=None): 150 | # Training with cross-entropy 151 | model.train() 152 | if scheduler is not None: 153 | scheduler.step() 154 | # print('lr0 = ', optim.state_dict()['param_groups'][0]['lr']) 155 | # print('lr1 = ', optim.state_dict()['param_groups'][1]['lr']) 156 | running_loss = .0 157 | with tqdm(desc='Epoch %d - train' % epoch, unit='it', total=len(dataloader)) as pbar: 158 | for it, (images, captions) in enumerate(dataloader): 159 | # if it == 10: 160 | # break 161 | captions = captions.to(device) 162 | if isinstance(images, tuple) or isinstance(images, list): 163 | images = [x.to(device) for x in images] 164 | else: 165 | images = images.to(device) 166 | out = model(images, captions) 167 | optim.zero_grad() 168 | captions_gt = captions[:, 1:].contiguous() 169 | out = out[:, :-1].contiguous() 170 | loss = loss_fn(out.view(-1, out.shape[-1]), captions_gt.view(-1)) 171 | loss.backward() 172 | 173 | optim.step() 174 | this_loss = loss.item() 175 | running_loss += this_loss 176 | 177 | pbar.set_postfix(loss=running_loss / (it + 1)) 178 | pbar.update() 179 | # if scheduler is not None: 180 | # scheduler.step() 181 | 182 | loss = running_loss / len(dataloader) 183 | return loss 184 | 185 | 186 | def train_scst(model, dataloader, optim, 
cider, text_field, epoch, device = 'cuda', scheduler = None, args=None): 187 | # Training with self-critical 188 | model.train() 189 | if scheduler is not None: 190 | scheduler.step() 191 | lr = optim.state_dict()['param_groups'][0]['lr'] 192 | 193 | tokenizer_pool = multiprocessing.Pool() 194 | running_reward = .0 195 | running_reward_baseline = .0 196 | running_loss = .0 197 | seq_len = 20 198 | beam_size = 5 199 | 200 | with tqdm(desc='Epoch %d - train' % epoch, unit='it', total=len(dataloader)) as pbar: 201 | for it, (images, caps_gt) in enumerate(dataloader): 202 | # if it == 2: 203 | # break 204 | if isinstance(images, tuple) or isinstance(images, list): 205 | images = [x.to(device) for x in images] 206 | bs = images[0].shape[0] 207 | else: 208 | images = images.to(device) 209 | bs = images.shape[0] 210 | outs, log_probs = model.beam_search(images, seq_len, text_field.vocab.stoi[''], 211 | beam_size, out_size=beam_size) 212 | optim.zero_grad() 213 | 214 | # Rewards 215 | caps_gen = text_field.decode(outs.view(-1, seq_len)) 216 | caps_gt = list(itertools.chain(*([c, ] * beam_size for c in caps_gt))) 217 | caps_gen, caps_gt = tokenizer_pool.map(evaluation.PTBTokenizer.tokenize, [caps_gen, caps_gt]) 218 | reward = cider.compute_score(caps_gt, caps_gen)[1].astype(np.float32) 219 | reward = torch.from_numpy(reward).to(device).view(bs, beam_size) 220 | reward_baseline = torch.mean(reward, -1, keepdim=True) 221 | loss = -torch.mean(log_probs, -1) * (reward - reward_baseline) 222 | 223 | loss = loss.mean() 224 | loss.backward() 225 | optim.step() 226 | 227 | running_loss += loss.item() 228 | running_reward += reward.mean().item() 229 | running_reward_baseline += reward_baseline.mean().item() 230 | pbar.set_postfix(loss=running_loss / (it + 1), reward=running_reward / (it + 1), lr=lr) 231 | pbar.update() 232 | 233 | loss = running_loss / len(dataloader) 234 | reward = running_reward / len(dataloader) 235 | reward_baseline = running_reward_baseline / len(dataloader) 236 | return loss, reward, reward_baseline -------------------------------------------------------------------------------- /common/data/dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import itertools 4 | import collections 5 | import torch 6 | 7 | from .example import Example 8 | from .utils import nostdout 9 | from pycocotools.coco import COCO as pyCOCO 10 | 11 | class Dataset(object): 12 | def __init__(self, examples, fields): 13 | self.examples = examples 14 | self.fields = dict(fields) 15 | 16 | def collate_fn(self): 17 | def collate(batch): 18 | if len(self.fields) == 1: 19 | batch = [batch, ] 20 | else: 21 | batch = list(zip(*batch)) 22 | 23 | tensors = [] 24 | for field, data in zip(self.fields.values(), batch): 25 | tensor = field.process(data) 26 | # if isinstance(tensor, collections.Sequence) and any(isinstance(t, torch.Tensor) for t in tensor): 27 | # tensors.extend(tensor) 28 | # else: 29 | tensors.append(tensor) 30 | 31 | if len(tensors) > 1: 32 | return tensors 33 | else: 34 | return tensors[0] 35 | 36 | return collate 37 | 38 | def __getitem__(self, i): 39 | example = self.examples[i] 40 | data = [] 41 | for field_name, field in self.fields.items(): 42 | data.append(field.preprocess(getattr(example, field_name))) 43 | 44 | if len(data) == 1: 45 | data = data[0] 46 | return data 47 | 48 | def __len__(self): 49 | return len(self.examples) 50 | 51 | def __getattr__(self, attr): 52 | if attr in self.fields: 53 | for x in 
self.examples: 54 | yield getattr(x, attr) 55 | 56 | 57 | class ValueDataset(Dataset): 58 | def __init__(self, examples, fields, dictionary): 59 | self.dictionary = dictionary 60 | super(ValueDataset, self).__init__(examples, fields) 61 | 62 | def collate_fn(self): 63 | def collate(batch): 64 | value_batch_flattened = list(itertools.chain(*batch)) 65 | value_tensors_flattened = super(ValueDataset, self).collate_fn()(value_batch_flattened) 66 | 67 | lengths = [0, ] + list(itertools.accumulate([len(x) for x in batch])) 68 | if isinstance(value_tensors_flattened, collections.Sequence) \ 69 | and any(isinstance(t, torch.Tensor) for t in value_tensors_flattened): 70 | value_tensors = [[vt[s:e] for (s, e) in zip(lengths[:-1], lengths[1:])] for vt in value_tensors_flattened] 71 | else: 72 | value_tensors = [value_tensors_flattened[s:e] for (s, e) in zip(lengths[:-1], lengths[1:])] 73 | 74 | return value_tensors 75 | return collate 76 | 77 | def __getitem__(self, i): 78 | if i not in self.dictionary: 79 | raise IndexError 80 | 81 | values_data = [] 82 | for idx in self.dictionary[i]: 83 | value_data = super(ValueDataset, self).__getitem__(idx) 84 | values_data.append(value_data) 85 | return values_data 86 | 87 | def __len__(self): 88 | return len(self.dictionary) 89 | 90 | 91 | class DictionaryDataset(Dataset): 92 | def __init__(self, examples, fields, key_fields, val_fields=None): 93 | if not isinstance(key_fields, (tuple, list)): 94 | key_fields = (key_fields,) 95 | if (val_fields is not None) and (not isinstance(val_fields, (tuple, list))): 96 | val_fields = (val_fields,) 97 | for field in key_fields: 98 | assert (field in fields) 99 | 100 | dictionary = collections.defaultdict(list) 101 | key_fields = {k: fields[k] for k in key_fields} 102 | if val_fields is not None: 103 | value_fields = {k: fields[k] for k in val_fields} 104 | else: 105 | value_fields = {k: fields[k] for k in fields.keys() if k not in key_fields} 106 | key_examples = [] 107 | key_dict = dict() 108 | value_examples = [] 109 | 110 | for i, e in enumerate(examples): 111 | key_example = Example.fromdict({k: getattr(e, k) for k in key_fields}) 112 | value_example = Example.fromdict({v: getattr(e, v) for v in value_fields}) 113 | if key_example not in key_dict: 114 | key_dict[key_example] = len(key_examples) 115 | key_examples.append(key_example) 116 | 117 | value_examples.append(value_example) 118 | dictionary[key_dict[key_example]].append(i) 119 | 120 | self.key_dataset = Dataset(key_examples, key_fields) 121 | self.value_dataset = ValueDataset(value_examples, value_fields, dictionary) 122 | super(DictionaryDataset, self).__init__(examples, fields) 123 | 124 | def collate_fn(self): 125 | def collate(batch): 126 | key_batch, value_batch = list(zip(*batch)) 127 | key_tensors = self.key_dataset.collate_fn()(key_batch) 128 | value_tensors = self.value_dataset.collate_fn()(value_batch) 129 | return key_tensors, value_tensors 130 | return collate 131 | 132 | def __getitem__(self, i): 133 | # TODO 134 | captions = [] 135 | for item in self.value_dataset[i]: 136 | captions.append(item['caption']) 137 | return self.key_dataset[i], captions 138 | # return self.key_dataset[i], self.value_dataset[i] 139 | 140 | # arr = [] 141 | # for item in self.value_dataset[i]: 142 | # arr.append([item[0]['caption'], item[1]]) 143 | # return self.key_dataset[i], arr 144 | 145 | def __len__(self): 146 | return len(self.key_dataset) 147 | 148 | 149 | def unique(sequence): 150 | seen = set() 151 | if isinstance(sequence[0], list): 152 | return [x for x in 
sequence if not (tuple(x) in seen or seen.add(tuple(x)))] 153 | else: 154 | return [x for x in sequence if not (x in seen or seen.add(x))] 155 | 156 | 157 | class PairedDataset(Dataset): 158 | def __init__(self, examples, fields): 159 | assert ('image' in fields) 160 | assert ('text' in fields) 161 | super(PairedDataset, self).__init__(examples, fields) 162 | self.image_field = self.fields['image'] 163 | self.text_field = self.fields['text'] 164 | 165 | def image_set(self): 166 | img_list = [e.image for e in self.examples] 167 | image_set = unique(img_list) 168 | examples = [Example.fromdict({'image': i}) for i in image_set] 169 | dataset = Dataset(examples, {'image': self.image_field}) 170 | return dataset 171 | 172 | def text_set(self): 173 | text_list = [e.text for e in self.examples] 174 | text_list = unique(text_list) 175 | examples = [Example.fromdict({'text': t}) for t in text_list] 176 | dataset = Dataset(examples, {'text': self.text_field}) 177 | return dataset 178 | 179 | def image_dictionary(self, fields=None): 180 | if not fields: 181 | fields = self.fields 182 | dataset = DictionaryDataset(self.examples, fields, key_fields='image') 183 | return dataset 184 | 185 | def text_dictionary(self, fields=None): 186 | if not fields: 187 | fields = self.fields 188 | dataset = DictionaryDataset(self.examples, fields, key_fields='text') 189 | return dataset 190 | 191 | @property 192 | def splits(self): 193 | raise NotImplementedError 194 | 195 | 196 | class COCODataset(PairedDataset): 197 | def __init__(self, image_field, text_field, img_root, ann_root, id_root=None, use_restval=True, 198 | cut_validation=False): 199 | roots = {} 200 | roots['train'] = { 201 | 'img': os.path.join(img_root, 'train2014'), 202 | 'cap': os.path.join(ann_root, 'captions_train2014.json') 203 | } 204 | roots['val'] = { 205 | 'img': os.path.join(img_root, 'val2014'), 206 | 'cap': os.path.join(ann_root, 'captions_val2014.json') 207 | } 208 | roots['test'] = { 209 | 'img': os.path.join(img_root, 'val2014'), 210 | 'cap': os.path.join(ann_root, 'captions_val2014.json') 211 | } 212 | roots['trainrestval'] = { 213 | 'img': (roots['train']['img'], roots['val']['img']), 214 | 'cap': (roots['train']['cap'], roots['val']['cap']) 215 | } 216 | 217 | if id_root is not None: 218 | ids = {} 219 | ids['train'] = np.load(os.path.join(id_root, 'coco_train_ids.npy')) 220 | ids['val'] = np.load(os.path.join(id_root, 'coco_dev_ids.npy')) 221 | if cut_validation: 222 | ids['val'] = ids['val'][:5000] 223 | ids['test'] = np.load(os.path.join(id_root, 'coco_test_ids.npy')) 224 | ids['trainrestval'] = ( 225 | ids['train'], 226 | np.load(os.path.join(id_root, 'coco_restval_ids.npy'))) 227 | 228 | if use_restval: 229 | roots['train'] = roots['trainrestval'] 230 | ids['train'] = ids['trainrestval'] 231 | else: 232 | ids = None 233 | 234 | with nostdout(): 235 | self.train_examples, self.val_examples, self.test_examples = self.get_samples(roots, ids) 236 | examples = self.train_examples + self.val_examples + self.test_examples 237 | super(COCODataset, self).__init__(examples, {'image': image_field, 'text': text_field}) 238 | 239 | @property 240 | def splits(self): 241 | train_split = PairedDataset(self.train_examples, self.fields) 242 | val_split = PairedDataset(self.val_examples, self.fields) 243 | test_split = PairedDataset(self.test_examples, self.fields) 244 | return train_split, val_split, test_split 245 | 246 | @classmethod 247 | def get_samples(cls, roots, ids_dataset=None): 248 | train_samples = [] 249 | val_samples = [] 250 | 
test_samples = [] 251 | 252 | for split in ['train', 'val', 'test']: 253 | if isinstance(roots[split]['cap'], tuple): 254 | coco_dataset = (pyCOCO(roots[split]['cap'][0]), pyCOCO(roots[split]['cap'][1])) 255 | root = roots[split]['img'] 256 | else: 257 | coco_dataset = (pyCOCO(roots[split]['cap']),) 258 | root = (roots[split]['img'],) 259 | 260 | if ids_dataset is None: 261 | ids = list(coco_dataset.anns.keys()) 262 | else: 263 | ids = ids_dataset[split] 264 | 265 | if isinstance(ids, tuple): 266 | bp = len(ids[0]) 267 | ids = list(ids[0]) + list(ids[1]) 268 | else: 269 | bp = len(ids) 270 | 271 | for index in range(len(ids)): 272 | # for index in range(100): 273 | if index < bp: 274 | coco = coco_dataset[0] 275 | img_root = root[0] 276 | else: 277 | coco = coco_dataset[1] 278 | img_root = root[1] 279 | 280 | ann_id = ids[index] 281 | ann = coco.anns[ann_id] 282 | caption = ann['caption'] 283 | img_id = ann['image_id'] 284 | image = coco.imgs[img_id] 285 | 286 | image_path = os.path.join(img_root, image['file_name']) 287 | orig_size = (image['width'], image['height']) 288 | 289 | image_des = {'image_id': img_id, 'image_path': image_path, 'split': split, 'orig_size': orig_size} 290 | text = {'image_id': img_id, 'ann_id': ann_id, 'caption': caption} 291 | 292 | example = Example.fromdict({'image': image_des, 'text': text}) 293 | 294 | if split == 'train': 295 | train_samples.append(example) 296 | elif split == 'val': 297 | val_samples.append(example) 298 | elif split == 'test': 299 | test_samples.append(example) 300 | 301 | return train_samples, val_samples, test_samples -------------------------------------------------------------------------------- /common/data/vocab.py: -------------------------------------------------------------------------------- 1 | from __future__ import unicode_literals 2 | import array 3 | from collections import defaultdict 4 | from functools import partial 5 | import io 6 | import logging 7 | import os 8 | import zipfile 9 | 10 | import six 11 | from six.moves.urllib.request import urlretrieve 12 | import torch 13 | from tqdm import tqdm 14 | import tarfile 15 | 16 | from .utils import reporthook 17 | 18 | logger = logging.getLogger(__name__) 19 | 20 | 21 | class Vocab(object): 22 | """Defines a vocabulary object that will be used to numericalize a field. 23 | 24 | Attributes: 25 | freqs: A collections.Counter object holding the frequencies of tokens 26 | in the data used to build the Vocab. 27 | stoi: A collections.defaultdict instance mapping token strings to 28 | numerical identifiers. 29 | itos: A list of token strings indexed by their numerical identifiers. 30 | """ 31 | def __init__(self, counter, max_size=None, min_freq=1, specials=[''], 32 | vectors=None, unk_init=None, vectors_cache=None): 33 | """Create a Vocab object from a collections.Counter. 34 | 35 | Arguments: 36 | counter: collections.Counter object holding the frequencies of 37 | each value found in the data. 38 | max_size: The maximum size of the vocabulary, or None for no 39 | maximum. Default: None. 40 | min_freq: The minimum frequency needed to include a token in the 41 | vocabulary. Values less than 1 will be set to 1. Default: 1. 42 | specials: The list of special tokens (e.g., padding or eos) that 43 | will be prepended to the vocabulary in addition to an 44 | token. 
Default: [''] 45 | vectors: One of either the available pretrained vectors 46 | or custom pretrained vectors (see Vocab.load_vectors); 47 | or a list of aforementioned vectors 48 | unk_init (callback): by default, initialize out-of-vocabulary word vectors 49 | to zero vectors; can be any function that takes in a Tensor and 50 | returns a Tensor of the same size. Default: torch.Tensor.zero_ 51 | vectors_cache: directory for cached vectors. Default: '.vector_cache' 52 | """ 53 | self.freqs = counter 54 | counter = counter.copy() 55 | min_freq = max(min_freq, 1) 56 | 57 | self.itos = list(specials) 58 | # frequencies of special tokens are not counted when building vocabulary 59 | # in frequency order 60 | for tok in specials: 61 | del counter[tok] 62 | 63 | max_size = None if max_size is None else max_size + len(self.itos) 64 | 65 | # sort by frequency, then alphabetically 66 | words_and_frequencies = sorted(counter.items(), key=lambda tup: tup[0]) 67 | words_and_frequencies.sort(key=lambda tup: tup[1], reverse=True) 68 | 69 | for word, freq in words_and_frequencies: 70 | if freq < min_freq or len(self.itos) == max_size: 71 | break 72 | self.itos.append(word) 73 | 74 | self.stoi = defaultdict(_default_unk_index) 75 | # stoi is simply a reverse dict for itos 76 | self.stoi.update({tok: i for i, tok in enumerate(self.itos)}) 77 | 78 | self.vectors = None 79 | if vectors is not None: 80 | self.load_vectors(vectors, unk_init=unk_init, cache=vectors_cache) 81 | else: 82 | assert unk_init is None and vectors_cache is None 83 | 84 | def __eq__(self, other): 85 | if self.freqs != other.freqs: 86 | return False 87 | if self.stoi != other.stoi: 88 | return False 89 | if self.itos != other.itos: 90 | return False 91 | if self.vectors != other.vectors: 92 | return False 93 | return True 94 | 95 | def __len__(self): 96 | return len(self.itos) 97 | 98 | def extend(self, v, sort=False): 99 | words = sorted(v.itos) if sort else v.itos 100 | for w in words: 101 | if w not in self.stoi: 102 | self.itos.append(w) 103 | self.stoi[w] = len(self.itos) - 1 104 | 105 | def load_vectors(self, vectors, **kwargs): 106 | """ 107 | Arguments: 108 | vectors: one of or a list containing instantiations of the 109 | GloVe, CharNGram, or Vectors classes. Alternatively, one 110 | of or a list of available pretrained vectors: 111 | charngram.100d 112 | fasttext.en.300d 113 | fasttext.simple.300d 114 | glove.42B.300d 115 | glove.840B.300d 116 | glove.twitter.27B.25d 117 | glove.twitter.27B.50d 118 | glove.twitter.27B.100d 119 | glove.twitter.27B.200d 120 | glove.6B.50d 121 | glove.6B.100d 122 | glove.6B.200d 123 | glove.6B.300d 124 | Remaining keyword arguments: Passed to the constructor of Vectors classes. 
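            Illustrative usage sketch (the vector name and shapes below are an
            assumption for illustration; they require the GloVe archive to be
            downloadable or already cached in '.vector_cache'):
                vocab.load_vectors('glove.6B.300d')
                # vocab.vectors becomes a FloatTensor of shape (len(vocab), 300);
                # out-of-vocabulary tokens are filled by unk_init (zeros by default).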
125 | """ 126 | if not isinstance(vectors, list): 127 | vectors = [vectors] 128 | for idx, vector in enumerate(vectors): 129 | if six.PY2 and isinstance(vector, str): 130 | vector = six.text_type(vector) 131 | if isinstance(vector, six.string_types): 132 | # Convert the string pretrained vector identifier 133 | # to a Vectors object 134 | if vector not in pretrained_aliases: 135 | raise ValueError( 136 | "Got string input vector {}, but allowed pretrained " 137 | "vectors are {}".format( 138 | vector, list(pretrained_aliases.keys()))) 139 | vectors[idx] = pretrained_aliases[vector](**kwargs) 140 | elif not isinstance(vector, Vectors): 141 | raise ValueError( 142 | "Got input vectors of type {}, expected str or " 143 | "Vectors object".format(type(vector))) 144 | 145 | tot_dim = sum(v.dim for v in vectors) 146 | self.vectors = torch.Tensor(len(self), tot_dim) 147 | for i, token in enumerate(self.itos): 148 | start_dim = 0 149 | for v in vectors: 150 | end_dim = start_dim + v.dim 151 | self.vectors[i][start_dim:end_dim] = v[token.strip()] 152 | start_dim = end_dim 153 | assert(start_dim == tot_dim) 154 | 155 | def set_vectors(self, stoi, vectors, dim, unk_init=torch.Tensor.zero_): 156 | """ 157 | Set the vectors for the Vocab instance from a collection of Tensors. 158 | 159 | Arguments: 160 | stoi: A dictionary of string to the index of the associated vector 161 | in the `vectors` input argument. 162 | vectors: An indexed iterable (or other structure supporting __getitem__) that 163 | given an input index, returns a FloatTensor representing the vector 164 | for the token associated with the index. For example, 165 | vector[stoi["string"]] should return the vector for "string". 166 | dim: The dimensionality of the vectors. 167 | unk_init (callback): by default, initialize out-of-vocabulary word vectors 168 | to zero vectors; can be any function that takes in a Tensor and 169 | returns a Tensor of the same size. 
Default: torch.Tensor.zero_ 170 | """ 171 | self.vectors = torch.Tensor(len(self), dim) 172 | for i, token in enumerate(self.itos): 173 | wv_index = stoi.get(token, None) 174 | if wv_index is not None: 175 | self.vectors[i] = vectors[wv_index] 176 | else: 177 | self.vectors[i] = unk_init(self.vectors[i]) 178 | 179 | 180 | class Vectors(object): 181 | 182 | def __init__(self, name, cache=None, 183 | url=None, unk_init=None): 184 | """ 185 | Arguments: 186 | name: name of the file that contains the vectors 187 | cache: directory for cached vectors 188 | url: url for download if vectors not found in cache 189 | unk_init (callback): by default, initalize out-of-vocabulary word vectors 190 | to zero vectors; can be any function that takes in a Tensor and 191 | returns a Tensor of the same size 192 | """ 193 | cache = '.vector_cache' if cache is None else cache 194 | self.unk_init = torch.Tensor.zero_ if unk_init is None else unk_init 195 | self.cache(name, cache, url=url) 196 | 197 | def __getitem__(self, token): 198 | if token in self.stoi: 199 | return self.vectors[self.stoi[token]] 200 | else: 201 | return self.unk_init(torch.Tensor(self.dim)) # self.unk_init(torch.Tensor(1, self.dim)) 202 | 203 | def cache(self, name, cache, url=None): 204 | if os.path.isfile(name): 205 | path = name 206 | path_pt = os.path.join(cache, os.path.basename(name)) + '.pt' 207 | else: 208 | path = os.path.join(cache, name) 209 | path_pt = path + '.pt' 210 | 211 | if not os.path.isfile(path_pt): 212 | if not os.path.isfile(path) and url: 213 | logger.info('Downloading vectors from {}'.format(url)) 214 | if not os.path.exists(cache): 215 | os.makedirs(cache) 216 | dest = os.path.join(cache, os.path.basename(url)) 217 | if not os.path.isfile(dest): 218 | with tqdm(unit='B', unit_scale=True, miniters=1, desc=dest) as t: 219 | try: 220 | urlretrieve(url, dest, reporthook=reporthook(t)) 221 | except KeyboardInterrupt as e: # remove the partial zip file 222 | os.remove(dest) 223 | raise e 224 | logger.info('Extracting vectors into {}'.format(cache)) 225 | ext = os.path.splitext(dest)[1][1:] 226 | if ext == 'zip': 227 | with zipfile.ZipFile(dest, "r") as zf: 228 | zf.extractall(cache) 229 | elif ext == 'gz': 230 | with tarfile.open(dest, 'r:gz') as tar: 231 | tar.extractall(path=cache) 232 | if not os.path.isfile(path): 233 | raise RuntimeError('no vectors found at {}'.format(path)) 234 | 235 | # str call is necessary for Python 2/3 compatibility, since 236 | # argument must be Python 2 str (Python 3 bytes) or 237 | # Python 3 str (Python 2 unicode) 238 | itos, vectors, dim = [], array.array(str('d')), None 239 | 240 | # Try to read the whole file with utf-8 encoding. 241 | binary_lines = False 242 | try: 243 | with io.open(path, encoding="utf8") as f: 244 | lines = [line for line in f] 245 | # If there are malformed lines, read in binary mode 246 | # and manually decode each word from utf-8 247 | except: 248 | logger.warning("Could not read {} as UTF8 file, " 249 | "reading file as bytes and skipping " 250 | "words with malformed UTF8.".format(path)) 251 | with open(path, 'rb') as f: 252 | lines = [line for line in f] 253 | binary_lines = True 254 | 255 | logger.info("Loading vectors from {}".format(path)) 256 | for line in tqdm(lines, total=len(lines)): 257 | # Explicitly splitting on " " is important, so we don't 258 | # get rid of Unicode non-breaking spaces in the vectors. 
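# Illustrative example (hypothetical GloVe-style line, added for clarity): a line
# such as "the 0.418 0.24968 -0.41242 ..." is split on a single space into
# word = "the" (or b"the" when binary_lines is True) plus the remaining numeric
# strings, which are converted with float() further below and extended onto the
# flat `vectors` array; splitting on " " explicitly, rather than calling split()
# with no argument, keeps tokens containing non-breaking spaces intact.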
259 | entries = line.rstrip().split(b" " if binary_lines else " ") 260 | 261 | word, entries = entries[0], entries[1:] 262 | if dim is None and len(entries) > 1: 263 | dim = len(entries) 264 | elif len(entries) == 1: 265 | logger.warning("Skipping token {} with 1-dimensional " 266 | "vector {}; likely a header".format(word, entries)) 267 | continue 268 | elif dim != len(entries): 269 | raise RuntimeError( 270 | "Vector for token {} has {} dimensions, but previously " 271 | "read vectors have {} dimensions. All vectors must have " 272 | "the same number of dimensions.".format(word, len(entries), dim)) 273 | 274 | if binary_lines: 275 | try: 276 | if isinstance(word, six.binary_type): 277 | word = word.decode('utf-8') 278 | except: 279 | logger.info("Skipping non-UTF8 token {}".format(repr(word))) 280 | continue 281 | vectors.extend(float(x) for x in entries) 282 | itos.append(word) 283 | 284 | self.itos = itos 285 | self.stoi = {word: i for i, word in enumerate(itos)} 286 | self.vectors = torch.Tensor(vectors).view(-1, dim) 287 | self.dim = dim 288 | logger.info('Saving vectors to {}'.format(path_pt)) 289 | if not os.path.exists(cache): 290 | os.makedirs(cache) 291 | torch.save((self.itos, self.stoi, self.vectors, self.dim), path_pt) 292 | else: 293 | logger.info('Loading vectors from {}'.format(path_pt)) 294 | self.itos, self.stoi, self.vectors, self.dim = torch.load(path_pt) 295 | 296 | 297 | class GloVe(Vectors): 298 | url = { 299 | '42B': 'http://nlp.stanford.edu/data/glove.42B.300d.zip', 300 | '840B': 'http://nlp.stanford.edu/data/glove.840B.300d.zip', 301 | 'twitter.27B': 'http://nlp.stanford.edu/data/glove.twitter.27B.zip', 302 | '6B': 'http://nlp.stanford.edu/data/glove.6B.zip', 303 | } 304 | 305 | def __init__(self, name='840B', dim=300, **kwargs): 306 | url = self.url[name] 307 | name = 'glove.{}.{}d.txt'.format(name, str(dim)) 308 | super(GloVe, self).__init__(name, url=url, **kwargs) 309 | 310 | 311 | class FastText(Vectors): 312 | 313 | url_base = 'https://s3-us-west-1.amazonaws.com/fasttext-vectors/wiki.{}.vec' 314 | 315 | def __init__(self, language="en", **kwargs): 316 | url = self.url_base.format(language) 317 | name = os.path.basename(url) 318 | super(FastText, self).__init__(name, url=url, **kwargs) 319 | 320 | 321 | class CharNGram(Vectors): 322 | 323 | name = 'charNgram.txt' 324 | url = ('http://www.logos.t.u-tokyo.ac.jp/~hassy/publications/arxiv2016jmt/' 325 | 'jmt_pre-trained_embeddings.tar.gz') 326 | 327 | def __init__(self, **kwargs): 328 | super(CharNGram, self).__init__(self.name, url=self.url, **kwargs) 329 | 330 | def __getitem__(self, token): 331 | vector = torch.Tensor(1, self.dim).zero_() 332 | if token == "": 333 | return self.unk_init(vector) 334 | # These literals need to be coerced to unicode for Python 2 compatibility 335 | # when we try to join them with read ngrams from the files. 
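# Illustrative worked example (added for clarity): for the token "cat" the padded
# character list built below is ['#BEGIN#', 'c', 'a', 't', '#END#'], and the lookup
# keys include '2gram-ca', '3gram-cat' and '4gram-#BEGIN#cat'; the vectors of every
# key present in self.stoi are summed and divided by num_vectors, falling back to
# unk_init when no n-gram matches.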
336 | chars = ['#BEGIN#'] + list(token) + ['#END#'] 337 | num_vectors = 0 338 | for n in [2, 3, 4]: 339 | end = len(chars) - n + 1 340 | grams = [chars[i:(i + n)] for i in range(end)] 341 | for gram in grams: 342 | gram_key = '{}gram-{}'.format(n, ''.join(gram)) 343 | if gram_key in self.stoi: 344 | vector += self.vectors[self.stoi[gram_key]] 345 | num_vectors += 1 346 | if num_vectors > 0: 347 | vector /= num_vectors 348 | else: 349 | vector = self.unk_init(vector) 350 | return vector 351 | 352 | 353 | def _default_unk_index(): 354 | return 0 355 | 356 | 357 | pretrained_aliases = { 358 | "charngram.100d": partial(CharNGram), 359 | "fasttext.en.300d": partial(FastText, language="en"), 360 | "fasttext.simple.300d": partial(FastText, language="simple"), 361 | "glove.42B.300d": partial(GloVe, name="42B", dim="300"), 362 | "glove.840B.300d": partial(GloVe, name="840B", dim="300"), 363 | "glove.twitter.27B.25d": partial(GloVe, name="twitter.27B", dim="25"), 364 | "glove.twitter.27B.50d": partial(GloVe, name="twitter.27B", dim="50"), 365 | "glove.twitter.27B.100d": partial(GloVe, name="twitter.27B", dim="100"), 366 | "glove.twitter.27B.200d": partial(GloVe, name="twitter.27B", dim="200"), 367 | "glove.6B.50d": partial(GloVe, name="6B", dim="50"), 368 | "glove.6B.100d": partial(GloVe, name="6B", dim="100"), 369 | "glove.6B.200d": partial(GloVe, name="6B", dim="200"), 370 | "glove.6B.300d": partial(GloVe, name="6B", dim="300") 371 | } 372 | """Mapping from string name to factory function""" 373 | -------------------------------------------------------------------------------- /common/models/transformer/utils.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | from torch import nn 4 | import torch.nn.functional as F 5 | 6 | # scale the pos to (0 ~ 1) 7 | def get_relative_pos(x, batch_size, norm_len): 8 | x = x.view(1, -1, 1).expand(batch_size, -1, -1) 9 | return x / norm_len 10 | 11 | def get_grids_pos(batch_size, seq_len, grid_size=(7, 7), device='cuda'): 12 | assert seq_len == grid_size[0] * grid_size[1] 13 | 14 | # record the pos of each grid according to the form of region box 15 | x = torch.arange(0, grid_size[0]).float().to(device) 16 | y = torch.arange(0, grid_size[1]).float().to(device) 17 | 18 | px_min = x.view(-1, 1).expand(-1, grid_size[0]).contiguous().view(-1) 19 | py_min = y.view(1, -1).expand(grid_size[1], -1).contiguous().view(-1) 20 | 21 | px_max = px_min + 1 22 | py_max = py_min + 1 23 | 24 | # scale pos 25 | rpx_min = get_relative_pos(px_min, batch_size, grid_size[0]) 26 | rpy_min = get_relative_pos(py_min, batch_size, grid_size[1]) 27 | 28 | rpx_max = get_relative_pos(px_max, batch_size, grid_size[0]) 29 | rpy_max = get_relative_pos(py_max, batch_size, grid_size[1]) 30 | 31 | return rpx_min, rpy_min, rpx_max, rpy_max 32 | 33 | # 适用于bbox或者grid网格的相对位置编码 34 | def RelationalEmbedding(f_g, dim_g=64, wave_len=1000, is_gird=False, trignometric_embedding=True): 35 | """ 36 | Given a tensor with bbox coordinates for detected objects on each batch image, 37 | this function computes a matrix for each image 38 | with entry (i,j) given by a vector representation of the 39 | displacement between the coordinates of bbox_i, and bbox_j 40 | input: np.array of shape=(batch_size, max_nr_bounding_boxes, 4) 41 | output: np.array of shape=(batch_size, max_nr_bounding_boxes, max_nr_bounding_boxes, 64) 42 | """ 43 | # returns a relational embedding for each pair of bboxes, with dimension = dim_g 44 | # follow implementation of 
https://github.com/heefe92/Relation_Networks-pytorch/blob/master/model.py#L1014-L1055 45 | device = f_g.device 46 | 47 | if is_gird: 48 | batch_size, seq_len = f_g.shape[:2] 49 | gs = int(math.sqrt(seq_len)) 50 | x_min, y_min, x_max, y_max = get_grids_pos(batch_size, seq_len, grid_size=(gs, gs), device=device) 51 | else: 52 | batch_size = f_g.size(0) 53 | x_min, y_min, x_max, y_max = torch.chunk(f_g, 4, dim=-1) 54 | 55 | cx = (x_min + x_max) * 0.5 56 | cy = (y_min + y_max) * 0.5 57 | w = (x_max - x_min) + 1. 58 | h = (y_max - y_min) + 1. 59 | 60 | # cx.view(1,-1) transposes the vector cx, and so dim(delta_x) = (dim(cx), dim(cx)) 61 | delta_x = cx - cx.view(batch_size, 1, -1) 62 | delta_x = torch.clamp(torch.abs(delta_x / w), min=1e-3) 63 | delta_x = torch.log(delta_x) 64 | 65 | delta_y = cy - cy.view(batch_size, 1, -1) 66 | delta_y = torch.clamp(torch.abs(delta_y / h), min=1e-3) 67 | delta_y = torch.log(delta_y) 68 | 69 | delta_w = torch.log(w / w.view(batch_size, 1, -1)) 70 | delta_h = torch.log(h / h.view(batch_size, 1, -1)) 71 | 72 | matrix_size = delta_h.size() 73 | delta_x = delta_x.view(batch_size, matrix_size[1], matrix_size[2], 1) 74 | delta_y = delta_y.view(batch_size, matrix_size[1], matrix_size[2], 1) 75 | delta_w = delta_w.view(batch_size, matrix_size[1], matrix_size[2], 1) 76 | delta_h = delta_h.view(batch_size, matrix_size[1], matrix_size[2], 1) 77 | 78 | position_mat = torch.cat((delta_x, delta_y, delta_w, delta_h), -1) # bs * r * r * 4 79 | 80 | if trignometric_embedding == True: 81 | feat_range = torch.arange(dim_g / 8).to(device) 82 | dim_mat = feat_range / (dim_g / 8) 83 | dim_mat = 1. / (torch.pow(wave_len, dim_mat)) 84 | 85 | dim_mat = dim_mat.view(1, 1, 1, -1) 86 | position_mat = position_mat.view(batch_size, matrix_size[1], matrix_size[2], 4, -1) 87 | position_mat = 100. * position_mat 88 | 89 | mul_mat = position_mat * dim_mat 90 | mul_mat = mul_mat.view(batch_size, matrix_size[1], matrix_size[2], -1) 91 | sin_mat = torch.sin(mul_mat) 92 | cos_mat = torch.cos(mul_mat) 93 | embedding = torch.cat((sin_mat, cos_mat), -1) 94 | else: 95 | embedding = position_mat 96 | return embedding 97 | 98 | 99 | def position_embedding(input, d_model): 100 | device = input.device 101 | input = input.view(-1, 1) 102 | dim = torch.arange(d_model // 2, dtype=torch.float32, device=device).view(1, -1) 103 | sin = torch.sin(input / 10000 ** (2 * dim / d_model)) 104 | cos = torch.cos(input / 10000 ** (2 * dim / d_model)) 105 | 106 | out = torch.zeros((input.shape[0], d_model), device=device) 107 | out[:, ::2] = sin 108 | out[:, 1::2] = cos 109 | return out 110 | 111 | # sin cos绝对位置编码 112 | def sinusoid_encoding_table(max_len, d_model, padding_idx=None): 113 | pos = torch.arange(max_len, dtype=torch.float32) 114 | out = position_embedding(pos, d_model) 115 | 116 | if padding_idx is not None: 117 | out[padding_idx] = 0 118 | return out 119 | 120 | # 基于Grid网格的sin cos绝对位置编码 121 | class GridPESine(nn.Module): 122 | """ 123 | This is a more standard version of the position embedding, very similar to the one 124 | used by the Attention is all you need paper, generalized to work on images. 
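    Illustrative usage sketch (shapes assume a 9x9 feature grid and a model
    dimension of 512; these are assumptions for illustration, not values fixed
    by this class):
        pe = GridPESine(num_pos_feats=256, normalize=True)
        pos = pe(x)   # x: (batch, 9, 9, d) -> pos: (batch, 81, 512)
    Only the spatial shape of x is used to build the sine/cosine encoding, so
    pos can be added to the flattened grid features.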
125 | """ 126 | 127 | def __init__(self, num_pos_feats=64, temperature=10000, normalize=False, scale=None): 128 | super().__init__() 129 | self.num_pos_feats = num_pos_feats 130 | self.temperature = temperature 131 | self.normalize = normalize 132 | if scale is not None and normalize is False: 133 | raise ValueError("normalize should be True if scale is passed") 134 | if scale is None: 135 | scale = 2 * math.pi 136 | self.scale = scale 137 | 138 | def forward(self, x, mask=None): 139 | device = x.device 140 | if mask is None: 141 | mask = torch.zeros(x.shape[:-1], dtype=torch.bool, device=device) 142 | not_mask = (mask == False) 143 | y_embed = not_mask.cumsum(1, dtype=torch.float32) 144 | x_embed = not_mask.cumsum(2, dtype=torch.float32) 145 | if self.normalize: 146 | eps = 1e-6 147 | y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale 148 | x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale 149 | 150 | dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=device) 151 | dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats) 152 | 153 | pos_x = x_embed[:, :, :, None] / dim_t 154 | pos_y = y_embed[:, :, :, None] / dim_t 155 | pos_x = torch.stack((pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4).flatten(3) 156 | pos_y = torch.stack((pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4).flatten(3) 157 | pos = torch.cat((pos_y, pos_x), dim=3) # .permute(0, 3, 1, 2) 158 | pos = pos.flatten(1, 2) 159 | return pos 160 | 161 | 162 | class PositionWiseFeedForward(nn.Module): 163 | ''' 164 | Position-wise feed forward layer 165 | ''' 166 | 167 | def __init__(self, d_model=512, d_ff=2048, dropout=.1, act_fn='ReLU', identity_map_reordering=False, local=False): 168 | super(PositionWiseFeedForward, self).__init__() 169 | self.local = local 170 | self.identity_map_reordering = identity_map_reordering 171 | if local: 172 | self.dwconv = DWConv(d_ff, gird_size=(9, 9)) 173 | self.fc1 = nn.Linear(d_model, d_ff) 174 | self.fc2 = nn.Linear(d_ff, d_model) 175 | self.dropout = nn.Dropout(p=dropout) 176 | self.dropout_2 = nn.Dropout(p=dropout) 177 | self.layer_norm = nn.LayerNorm(d_model) 178 | self.act = getattr(nn, act_fn)() 179 | 180 | def forward(self, input): 181 | if self.identity_map_reordering: 182 | x = self.layer_norm(input) 183 | x = self.fc1(x) 184 | if self.local: 185 | x = x + self.dwconv(x) 186 | x = self.act(x) 187 | x = self.dropout_2(x) 188 | x = self.fc2(x) 189 | x = input + self.dropout(self.act(x)) 190 | else: 191 | x = self.fc1(input) 192 | if self.local: 193 | x = self.dwconv(x) 194 | x = self.act(x) 195 | x = self.dropout_2(x) 196 | x = self.fc2(x) 197 | x = self.dropout(x) 198 | x = self.layer_norm(input + x) 199 | return x 200 | 201 | 202 | class FFNWithPrivateLN(nn.Module): 203 | ''' 204 | Position-wise feed forward layer 205 | ''' 206 | 207 | def __init__(self, d_model=512, d_ff=2048, dropout=.1): 208 | super(FFNWithPrivateLN, self).__init__() 209 | self.fc1 = nn.Linear(d_model, d_ff) 210 | self.fc2 = nn.Linear(d_ff, d_model) 211 | self.dropout = nn.Dropout(p=dropout) 212 | self.dropout_2 = nn.Dropout(p=dropout) 213 | self.layer_norm = nn.LayerNorm(d_model) 214 | self.layer_norm1 = nn.LayerNorm(d_model) 215 | self.layer_norm2 = nn.LayerNorm(d_model) 216 | 217 | def forward(self, input, m=0): 218 | out = self.fc2(self.dropout_2(F.relu(self.fc1(input)))) 219 | out = self.dropout(out) 220 | if m == 0: 221 | out = self.layer_norm(input + out) 222 | elif m == 1: 223 | out = self.layer_norm1(input + out) 224 | else: 225 | 
out = self.layer_norm2(input + out)
226 |         return out
227 | 
228 | 
229 | class LocalFeedForward(nn.Module):
230 | 
231 |     def __init__(self, d_model=512, d_ff=2048, dropout=.1):
232 |         super(LocalFeedForward, self).__init__()
233 |         self.dwconv = DWConv(d_ff, gird_size=(9, 9))
234 |         self.fc1 = nn.Linear(d_model, d_ff)
235 |         self.fc2 = nn.Linear(d_ff, d_model)
236 |         self.dropout = nn.Dropout(p=dropout)
237 |         self.dropout_2 = nn.Dropout(p=dropout)
238 |         self.layer_norm = nn.LayerNorm(d_model)
239 |         self.act = nn.ReLU()
240 | 
241 |     def forward(self, input):
242 |         x = self.fc1(input)
243 |         x = self.dwconv(x)
244 |         x = self.act(x)
245 |         x = self.dropout_2(x)
246 |         x = self.fc2(x)
247 |         x = self.dropout(x)
248 |         x = self.layer_norm(input + x)
249 |         return x
250 | 
251 | 
252 | class Adapter(nn.Module):
253 |     def __init__(self, d_model=512, d_v=64, h=8, mid_dim=40, dropout=.1, act_fn='ReLU'):
254 |         super(Adapter, self).__init__()
255 | 
256 |         self.fc_dalta_o = nn.Linear(h * d_v, d_model)
257 | 
258 |         self.mh_adapters = nn.ModuleList([
259 |             nn.Sequential(
260 |                 nn.Linear(d_model, d_v),
261 |                 # nn.Linear(d_v, mid_dim),
262 |                 # getattr(nn, act_fn)(),
263 |                 # nn.Linear(mid_dim, d_v)
264 |                 DWConv(d_v, gird_size=(9, 9)),
265 |                 # nn.ReLU(),
266 |                 # DWConv(d_v, gird_size=(9, 9))
267 |             )
268 |             for _ in range(h)])
269 | 
270 |         # self.act = nn.ReLU()
271 |         # self.dropout = nn.Dropout(p=dropout)
272 |         # self.layer_norm = nn.LayerNorm(d_model)
273 | 
274 |         self.init_weights()
275 | 
276 |     def init_weights(self):
277 |         for module in self.modules():
278 |             if isinstance(module, nn.Linear):
279 |                 nn.init.xavier_uniform_(module.weight)
280 |                 nn.init.constant_(module.bias, 0)
281 | 
282 |     def forward(self, input):
283 | 
284 |         delta_hs = [l(input) for l in self.mh_adapters]
285 |         delta_h = torch.cat(delta_hs, dim=-1) # (b_s, nq, h*d_v)
286 |         delta_h = self.fc_dalta_o(delta_h)
287 | 
288 |         # delta_h = self.act(delta_h)
289 | 
290 |         # delta_h = self.dropout(delta_h)
291 |         # delta_h = self.layer_norm(input + delta_h)
292 |         # delta_h = input + delta_h
293 | 
294 |         return delta_h
295 | 
296 | 
297 | class DWConv(nn.Module):
298 |     def __init__(self, dim=64, gird_size=(9, 9)):
299 |         super(DWConv, self).__init__()
300 |         self.dwconv = nn.Conv2d(dim, dim, 3, 1, 1, bias=True, groups=dim)
301 |         self.gird_size = gird_size
302 |         self.act = nn.ReLU()
303 | 
304 |         # self.init_weights()
305 | 
306 |     def init_weights(self):
307 |         nn.init.kaiming_normal_(self.dwconv.weight)
308 | 
309 |     def forward(self, x):
310 |         B, N, C = x.shape
311 |         H, W = self.gird_size
312 |         x = x.transpose(1, 2).view(B, C, H, W)
313 |         x = self.dwconv(x)
314 |         x = x.flatten(2).transpose(1, 2)
315 |         x = self.act(x)
316 | 
317 |         return x
318 | 
319 | class MlpBlock(nn.Module):
320 |     def __init__(self, in_dim, mlp_dim):
321 |         super(MlpBlock, self).__init__()
322 |         self.ff1 = nn.Linear(in_dim, mlp_dim)
323 |         self.act = nn.ReLU()
324 |         self.ff2 = nn.Linear(mlp_dim, in_dim)
325 | 
326 |     def forward(self, x):
327 |         x = x.transpose(1, 2)
328 |         x = self.ff1(x)
329 |         x = self.act(x)
330 |         x = self.ff2(x)
331 |         x = x.transpose(1, 2)
332 |         return x
333 | 
334 | # Grid-based polar-coordinate relative position encoding proposed in this work
335 | class PolarRPE(nn.Module):
336 |     def __init__(self, k=3, h=8, d_k=64, d_r=256, window_size = (9, 9), device='cuda:0'):
337 |         super(PolarRPE, self).__init__()
338 |         Wh, Ww = window_size
339 |         self.h = h
340 |         self.d_k = d_k
341 |         self.num_seq = Wh * Ww
342 |         # num_direction = 4 * k + 1
343 |         num_direction = 4 * k
344 |         num_distance = math.floor(math.sqrt(Wh*Wh + Ww*Ww))
345 | 
346 |         # define a parameter
table of relative position 347 | # self.relative_direction_table = nn.Embedding(num_direction, d_r) 348 | # self.relative_distance_table = nn.Embedding(num_distance, d_r) 349 | self.relative_table = nn.Embedding(num_direction * num_distance, d_r) 350 | self.projection = nn.Linear(d_r, h * d_k) 351 | # self.projection = nn.Linear(d_r, h) 352 | # self.act = nn.ReLU() 353 | # self.projection = nn.Linear(d_r * 2, h * d_k) 354 | 355 | # get pair-wise relative position index for each token inside the window 356 | coords_h, coords_w = torch.arange(Wh), torch.arange(Ww) 357 | coords = torch.stack(torch.meshgrid([coords_h, coords_w]), dim=-1) # Wh, Ww, 2 358 | coords_flatten = coords.view(-1, 2) # Wh*Ww, 2 359 | relative_coords = coords_flatten.unsqueeze(1) - coords_flatten.unsqueeze(0) # Wh*Ww, Wh*Ww, 2 360 | relative_coords = relative_coords.view(-1, 2).float() # N*N, 2 361 | 362 | # relative_distance_pos 363 | norm_relative_distance = torch.norm(relative_coords, dim=-1) 364 | relative_distance_pos = norm_relative_distance.int() # N*N 365 | 366 | # relative_direction_pos 367 | unit_direction_x = torch.cos(torch.arange(num_direction - 1) * math.pi / 2 / k) 368 | unit_direction_y = torch.sin(torch.arange(num_direction - 1) * math.pi / 2 / k) 369 | unit_direction = torch.stack([unit_direction_x, unit_direction_y]) # 2, 4k 370 | 371 | relative_direction = torch.matmul(relative_coords, unit_direction) 372 | relative_direction_pos = torch.argmax(relative_direction, dim=-1) # N*N 373 | # relative_direction_pos = relative_direction_pos.masked_fill(norm_relative_distance == 0, num_direction-1) 374 | 375 | relative_pos = relative_direction_pos * num_distance + relative_distance_pos 376 | # relative_pos = relative_pos.masked_fill(norm_relative_distance == 0, num_direction * num_distance) 377 | 378 | # self.relative_direction_pos = relative_direction_pos.to(device) 379 | # self.relative_distance_pos = relative_distance_pos.to(device) 380 | self.relative_pos = relative_pos.to(device) 381 | 382 | self.init_weights() 383 | 384 | def init_weights(self): 385 | # nn.init.uniform_(self.relative_direction_table.weight, b=0.2) 386 | # nn.init.uniform_(self.relative_distance_table.weight, b=0.2) 387 | nn.init.uniform_(self.relative_table.weight, b=0.2) 388 | nn.init.xavier_uniform_(self.projection.weight) 389 | nn.init.constant_(self.projection.bias, 0) 390 | 391 | def forward(self, bs): 392 | 393 | # direction + distance 394 | # relative_direction_emb = self.relative_direction_table(self.relative_direction_pos) 395 | # relative_distance_emb = self.relative_distance_table(self.relative_distance_pos) 396 | # relative_emb = relative_direction_emb + relative_distance_emb # (n*n, d_r) 397 | 398 | # relative_emb = torch.cat([relative_direction_emb, relative_distance_emb], dim=-1) # (n*n, d_r * 2) 399 | relative_emb = self.relative_table(self.relative_pos) 400 | relative_emb = self.projection(relative_emb).view(-1, self.h, self.d_k) # (n*n, h, d_k) 401 | 402 | # relative_emb = self.projection(relative_emb) # (n*n, h) 403 | # relative_emb = self.act(relative_emb) 404 | 405 | # direction 406 | # relative_direction_emb = self.relative_direction_table(self.relative_direction_pos) # (n*n, d_r) 407 | # relative_emb = self.projection(relative_direction_emb).view(-1, self.h, self.d_k) # (n*n, h, d_k) 408 | 409 | # distance 410 | # relative_distance_emb = self.relative_distance_table(self.relative_distance_pos) # (n*n, d_r) 411 | # relative_emb = self.projection(relative_distance_emb).view(-1, self.h, self.d_k) # (n*n, h, d_k) 412 | 
413 | relative_emb = relative_emb.view(self.num_seq, self.num_seq, self.h, self.d_k).permute(2, 0, 1, 3) 414 | relative_emb = relative_emb.unsqueeze(0).expand(bs, self.h, self.num_seq, self.num_seq, self.d_k) # (b_s, h, n, n, d_k) 415 | 416 | # relative_emb = relative_emb.view(self.num_seq, self.num_seq, self.h).permute(2, 0, 1) 417 | # relative_emb = relative_emb.unsqueeze(0).expand(bs, self.h, self.num_seq, self.num_seq) # (b_s, h, n, n) 418 | 419 | return relative_emb 420 | 421 | if __name__ == '__main__': 422 | rpe = PolarRPE(device='cpu') 423 | rpe(2) 424 | 425 | 426 | 427 | 428 | 429 | 430 | 431 | -------------------------------------------------------------------------------- /common/data/field.py: -------------------------------------------------------------------------------- 1 | # coding: utf8 2 | from collections import Counter, OrderedDict 3 | from torch.utils.data.dataloader import default_collate 4 | from torchvision.datasets.folder import default_loader 5 | from itertools import chain 6 | import six 7 | import torch 8 | import numpy as np 9 | import h5py 10 | import os 11 | import warnings 12 | import shutil 13 | 14 | from .vocab import Vocab 15 | from .utils import get_tokenizer 16 | 17 | 18 | class RawField(object): 19 | """ Defines a general datatype. 20 | 21 | Every dataset consists of one or more types of data. For instance, 22 | a machine translation dataset contains paired examples of text, while 23 | an image captioning dataset contains images and texts. 24 | Each of these types of data is represented by a RawField object. 25 | An RawField object does not assume any property of the data type and 26 | it holds parameters relating to how a datatype should be processed. 27 | 28 | Attributes: 29 | preprocessing: The Pipeline that will be applied to examples 30 | using this field before creating an example. 31 | Default: None. 32 | postprocessing: A Pipeline that will be applied to a list of examples 33 | using this field before assigning to a batch. 34 | Function signature: (batch(list)) -> object 35 | Default: None. 36 | """ 37 | 38 | def __init__(self, preprocessing=None, postprocessing=None): 39 | self.preprocessing = preprocessing 40 | self.postprocessing = postprocessing 41 | 42 | def preprocess(self, x): 43 | """ Preprocess an example if the `preprocessing` Pipeline is provided. """ 44 | if self.preprocessing is not None: 45 | return self.preprocessing(x) 46 | else: 47 | return x 48 | 49 | def process(self, batch, *args, **kwargs): 50 | """ Process a list of examples to create a batch. 51 | 52 | Postprocess the batch with user-provided Pipeline. 53 | 54 | Args: 55 | batch (list(object)): A list of object from a batch of examples. 56 | Returns: 57 | object: Processed object given the input and custom 58 | postprocessing Pipeline. 
59 |         """
60 |         if self.postprocessing is not None:
61 |             batch = self.postprocessing(batch)
62 |         return default_collate(batch)
63 | 
64 | 
65 | class Merge(RawField):
66 |     def __init__(self, *fields):
67 |         super(Merge, self).__init__()
68 |         self.fields = fields
69 | 
70 |     def preprocess(self, x):
71 |         return tuple(f.preprocess(x) for f in self.fields)
72 | 
73 |     def process(self, batch, *args, **kwargs):
74 |         if len(self.fields) == 1:
75 |             batch = [batch, ]
76 |         else:
77 |             batch = list(zip(*batch))
78 | 
79 |         out = list(f.process(b, *args, **kwargs) for f, b in zip(self.fields, batch))
80 |         return out
81 | 
82 | 
83 | class ImageDetectionsField(RawField):
84 |     def __init__(self, preprocessing=None, postprocessing=None, feature_type='butd', detections_path=None, max_detections=100,
85 |                  with_pe=False, sort_by_prob=False, load_in_tmp=False, global_feature=False):
86 |         self.max_detections = max_detections
87 |         self.detections_path = detections_path
88 |         self.feature_type = feature_type
89 |         self.sort_by_prob = sort_by_prob
90 |         self.with_pe = with_pe
91 |         self.global_feature = global_feature
92 | 
93 |         tmp_detections_path = os.path.join('/tmp', os.path.basename(detections_path))
94 | 
95 |         if load_in_tmp:
96 |             if not os.path.isfile(tmp_detections_path):
97 |                 if shutil.disk_usage("/tmp")[-1] < os.path.getsize(detections_path):
98 |                     warnings.warn('Loading from %s, because /tmp does not have enough space.' % detections_path)
99 |                 else:
100 |                     warnings.warn("Copying detection file to /tmp")
101 |                     shutil.copyfile(detections_path, tmp_detections_path)
102 |                     warnings.warn("Done.")
103 |                     self.detections_path = tmp_detections_path
104 |             else:
105 |                 self.detections_path = tmp_detections_path
106 | 
107 |         available_features = ['butd', 'clip', 'vinvl', 'tokens']
108 |         assert self.feature_type in available_features, \
109 |             "feature type not supported, please select one of ['butd', 'clip', 'vinvl', 'tokens']"
110 | 
111 |         if self.feature_type in ['butd', 'vinvl', 'clip', 'tokens']:
112 |             self.f = h5py.File(self.detections_path, 'r')
113 | 
114 |         super(ImageDetectionsField, self).__init__(preprocessing, postprocessing)
115 | 
116 |     def preprocess(self, x, avoid_precomp=False):
117 |         image_id, split, orig_size = x['image_id'], x['split'], x['orig_size']
118 |         try:
119 |             if self.feature_type in ['butd', 'vinvl']:
120 |                 precomp_data = torch.from_numpy(self.f['%d_features' % image_id][()])
121 |                 if self.with_pe:
122 |                     boxes = torch.from_numpy(self.f['%d_boxes' % image_id][()])
123 |                     if len(boxes):
124 |                         precomp_data = precomp_data[:len(boxes),:]
125 | 
126 |                 if self.sort_by_prob:
127 |                     idxs = torch.from_numpy(np.argsort(np.max(self.f['%d_cls_prob' % image_id][()], -1))[::-1].copy())  # copy() drops the negative stride left by [::-1], which torch.from_numpy rejects
128 |                     precomp_data = precomp_data[idxs]
129 |                     if self.with_pe:
130 |                         boxes = boxes[idxs]
131 | 
132 |             elif self.feature_type == 'clip':
133 |                 precomp_data = torch.from_numpy(self.f['%d_features' % image_id][()])
134 |                 if self.global_feature:
135 |                     global_feature = torch.from_numpy(self.f['%d_global' % image_id][()])
136 |                     return precomp_data, global_feature
137 |                 return precomp_data
138 | 
139 |             elif self.feature_type == 'tokens':
140 |                 precomp_data = torch.from_numpy(self.f['%d_tokens' % image_id][()])
141 |                 return precomp_data
142 | 
143 |             if self.with_pe:
144 |                 size = torch.tensor(orig_size).repeat(len(boxes), 2)
145 |                 relative_boxes = boxes / size
146 | 
147 |         except KeyError:
148 |             warnings.warn('Could not find detections for %d' % image_id)
149 |             precomp_data = torch.rand(10,2048)
150 |             relative_boxes = torch.rand((10, 4))
151 | 
152 |         delta = self.max_detections -
precomp_data.shape[0]
153 |         if delta > 0:
154 |             precomp_data = torch.cat([precomp_data, torch.zeros((delta, precomp_data.shape[1]))], 0)
155 |         elif delta < 0:
156 |             precomp_data = precomp_data[:self.max_detections]
157 | 
158 |         if self.with_pe:
159 |             delta_boxes = self.max_detections - len(relative_boxes)
160 |             if delta_boxes > 0:
161 |                 relative_boxes = torch.cat([relative_boxes, torch.zeros((delta_boxes, relative_boxes.shape[1]))], 0)
162 |             elif delta_boxes < 0:
163 |                 relative_boxes = relative_boxes[:self.max_detections]
164 | 
165 |             return (precomp_data, relative_boxes)
166 | 
167 |         return precomp_data
168 | 
169 | class DualImageField(RawField):
170 |     def __init__(self, clip_path, vinvl_path, preprocessing=None, postprocessing=None, max_detections=100, global_feature=False,
171 |                  with_pe=False, sort_by_prob=False, load_in_tmp=False):
172 | 
173 |         self.clip_field = ImageDetectionsField(preprocessing, postprocessing, 'clip', clip_path, global_feature=global_feature)
174 |         self.vinvl_field = ImageDetectionsField(preprocessing, postprocessing, 'vinvl', vinvl_path,
175 |                                                 max_detections, with_pe, sort_by_prob, load_in_tmp)
176 |         self.global_feature = global_feature
177 |         super().__init__(preprocessing, postprocessing)
178 | 
179 |     def preprocess(self, x):
180 |         region_features = self.vinvl_field.preprocess(x)
181 |         if self.global_feature:
182 |             grid_features, global_feature = self.clip_field.preprocess(x)
183 |             return (grid_features, region_features, global_feature)
184 |         else:
185 |             grid_features = self.clip_field.preprocess(x)
186 |             return (grid_features, region_features)
187 | 
188 | class TextField(RawField):
189 |     vocab_cls = Vocab
190 |     # Dictionary mapping PyTorch tensor dtypes to the appropriate Python
191 |     # numeric type.
192 |     dtypes = {
193 |         torch.float32: float,
194 |         torch.float: float,
195 |         torch.float64: float,
196 |         torch.double: float,
197 |         torch.float16: float,
198 |         torch.half: float,
199 | 
200 |         torch.uint8: int,
201 |         torch.int8: int,
202 |         torch.int16: int,
203 |         torch.short: int,
204 |         torch.int32: int,
205 |         torch.int: int,
206 |         torch.int64: int,
207 |         torch.long: int,
208 |     }
209 |     punctuations = ["''", "'", "``", "`", "-LRB-", "-RRB-", "-LCB-", "-RCB-", \
210 |                     ".", "?", "!", ",", ":", "-", "--", "...", ";"]
211 | 
212 |     def __init__(self, use_vocab=True, init_token=None, eos_token=None, fix_length=None, dtype=torch.long,
213 |                  preprocessing=None, postprocessing=None, lower=False, tokenize=(lambda s: s.split()),
214 |                  remove_punctuation=False, include_lengths=False, batch_first=True, pad_token="<pad>",
215 |                  unk_token="<unk>", pad_first=False, truncate_first=False, vectors=None, nopoints=True, reverse=False):
216 |         self.use_vocab = use_vocab
217 |         self.init_token = init_token
218 |         self.eos_token = eos_token
219 |         self.fix_length = fix_length
220 |         self.dtype = dtype
221 |         self.lower = lower
222 |         self.tokenize = get_tokenizer(tokenize)
223 |         self.remove_punctuation = remove_punctuation
224 |         self.include_lengths = include_lengths
225 |         self.batch_first = batch_first
226 |         self.pad_token = pad_token
227 |         self.unk_token = unk_token
228 |         self.pad_first = pad_first
229 |         self.truncate_first = truncate_first
230 |         self.vocab = None
231 |         self.vectors = vectors
232 |         self.reverse = reverse
233 |         if nopoints:
234 |             self.punctuations.append("..")
235 | 
236 |         super(TextField, self).__init__(preprocessing, postprocessing)
237 | 
238 |     def preprocess(self, x):
239 |         x = x['caption']
240 |         if six.PY2 and isinstance(x, six.string_types) and not isinstance(x, six.text_type):
241 |             x =
six.text_type(x, encoding='utf-8') 242 | if self.lower: 243 | x = six.text_type.lower(x) 244 | x = self.tokenize(x.rstrip('\n')) 245 | if self.remove_punctuation: 246 | x = [w for w in x if w not in self.punctuations] 247 | if self.preprocessing is not None: 248 | x = self.preprocessing(x) 249 | 250 | if self.reverse: 251 | return x, list(reversed(x)) 252 | else: 253 | return x 254 | 255 | def process(self, batch, device=None): 256 | if self.reverse: 257 | batch = list(zip(*batch)) 258 | padded_1 = self.pad(batch[0]) 259 | padded_2 = self.pad(batch[1], reverse=True) 260 | tensor_1 = self.numericalize(padded_1, device=device) 261 | tensor_2 = self.numericalize(padded_2, device=device) 262 | return tensor_1, tensor_2 263 | # padded = self.pad(batch, reverse=True) 264 | # tensor = self.numericalize(padded, device=device) 265 | # return tensor 266 | else: 267 | padded = self.pad(batch) 268 | tensor = self.numericalize(padded, device=device) 269 | return tensor 270 | 271 | def build_vocab(self, *args, **kwargs): 272 | from .dataset import Dataset 273 | 274 | counter = Counter() 275 | sources = [] 276 | for arg in args: 277 | if isinstance(arg, Dataset): 278 | sources += [getattr(arg, name) for name, field in arg.fields.items() if field is self] 279 | else: 280 | sources.append(arg) 281 | 282 | for data in sources: 283 | for x in data: 284 | x = self.preprocess(x) 285 | try: 286 | counter.update(x) 287 | except TypeError: 288 | counter.update(chain.from_iterable(x)) 289 | 290 | specials = list(OrderedDict.fromkeys([ 291 | tok for tok in [self.unk_token, self.pad_token, self.init_token, 292 | self.eos_token] 293 | if tok is not None])) 294 | self.vocab = self.vocab_cls(counter, specials=specials, **kwargs) 295 | 296 | def pad(self, minibatch, reverse=False): 297 | """Pad a batch of examples using this field. 298 | Pads to self.fix_length if provided, otherwise pads to the length of 299 | the longest example in the batch. Prepends self.init_token and appends 300 | self.eos_token if those attributes are not None. Returns a tuple of the 301 | padded list and a list containing lengths of each example if 302 | `self.include_lengths` is `True`, else just 303 | returns the padded list. 
304 | """ 305 | minibatch = list(minibatch) 306 | if self.fix_length is None: 307 | max_len = max(len(x) for x in minibatch) 308 | else: 309 | max_len = self.fix_length + ( 310 | self.init_token, self.eos_token).count(None) - 2 311 | padded, lengths = [], [] 312 | for x in minibatch: 313 | if self.pad_first: 314 | padded.append( 315 | [self.pad_token] * max(0, max_len - len(x)) + 316 | ([] if self.init_token is None else [self.init_token]) + 317 | list(x[-max_len:] if self.truncate_first else x[:max_len]) + 318 | ([] if self.eos_token is None else [self.eos_token])) 319 | elif reverse: 320 | padded.append( 321 | ([] if self.eos_token is None else [self.eos_token]) + 322 | list(x[-max_len:] if self.truncate_first else x[:max_len]) + 323 | ([] if self.init_token is None else [self.init_token]) + 324 | [self.pad_token] * max(0, max_len - len(x))) 325 | else: 326 | padded.append( 327 | ([] if self.init_token is None else [self.init_token]) + 328 | list(x[-max_len:] if self.truncate_first else x[:max_len]) + 329 | ([] if self.eos_token is None else [self.eos_token]) + 330 | [self.pad_token] * max(0, max_len - len(x))) 331 | lengths.append(len(padded[-1]) - max(0, max_len - len(x))) 332 | if self.include_lengths: 333 | return padded, lengths 334 | return padded 335 | 336 | def numericalize(self, arr, device=None): 337 | """Turn a batch of examples that use this field into a list of Variables. 338 | If the field has include_lengths=True, a tensor of lengths will be 339 | included in the return value. 340 | Arguments: 341 | arr (List[List[str]], or tuple of (List[List[str]], List[int])): 342 | List of tokenized and padded examples, or tuple of List of 343 | tokenized and padded examples and List of lengths of each 344 | example if self.include_lengths is True. 345 | device (str or torch.device): A string or instance of `torch.device` 346 | specifying which device the Variables are going to be created on. 347 | If left as default, the tensors will be created on cpu. Default: None. 348 | """ 349 | if self.include_lengths and not isinstance(arr, tuple): 350 | raise ValueError("Field has include_lengths set to True, but " 351 | "input data is not a tuple of " 352 | "(data batch, batch lengths).") 353 | if isinstance(arr, tuple): 354 | arr, lengths = arr 355 | lengths = torch.tensor(lengths, dtype=self.dtype, device=device) 356 | 357 | if self.use_vocab: 358 | arr = [[self.vocab.stoi[x] for x in ex] for ex in arr] 359 | 360 | if self.postprocessing is not None: 361 | arr = self.postprocessing(arr, self.vocab) 362 | 363 | var = torch.tensor(arr, dtype=self.dtype, device=device) 364 | else: 365 | if self.vectors: 366 | arr = [[self.vectors[x] for x in ex] for ex in arr] 367 | if self.dtype not in self.dtypes: 368 | raise ValueError( 369 | "Specified Field dtype {} can not be used with " 370 | "use_vocab=False because we do not know how to numericalize it. " 371 | "Please raise an issue at " 372 | "https://github.com/pytorch/text/issues".format(self.dtype)) 373 | numericalization_func = self.dtypes[self.dtype] 374 | # It doesn't make sense to explictly coerce to a numeric type if 375 | # the data is sequential, since it's unclear how to coerce padding tokens 376 | # to a numeric type. 
377 | arr = [numericalization_func(x) if isinstance(x, six.string_types) 378 | else x for x in arr] 379 | 380 | if self.postprocessing is not None: 381 | arr = self.postprocessing(arr, None) 382 | 383 | var = torch.cat([torch.cat([a.unsqueeze(0) for a in ar]).unsqueeze(0) for ar in arr]) 384 | 385 | # var = torch.tensor(arr, dtype=self.dtype, device=device) 386 | if not self.batch_first: 387 | var.t_() 388 | var = var.contiguous() 389 | 390 | if self.include_lengths: 391 | return var, lengths 392 | return var 393 | 394 | def decode(self, word_idxs, join_words=True): 395 | if isinstance(word_idxs, list) and len(word_idxs) == 0: 396 | return self.decode([word_idxs, ], join_words)[0] 397 | if isinstance(word_idxs, list) and isinstance(word_idxs[0], int): 398 | return self.decode([word_idxs, ], join_words)[0] 399 | elif isinstance(word_idxs, np.ndarray) and word_idxs.ndim == 1: 400 | return self.decode(word_idxs.reshape((1, -1)), join_words)[0] 401 | elif isinstance(word_idxs, torch.Tensor) and word_idxs.ndimension() == 1: 402 | return self.decode(word_idxs.unsqueeze(0), join_words)[0] 403 | 404 | captions = [] 405 | for wis in word_idxs: 406 | caption = [] 407 | for wi in wis: 408 | word = self.vocab.itos[int(wi)] 409 | if word == self.eos_token: 410 | break 411 | caption.append(word) 412 | if join_words: 413 | caption = ' '.join(caption) 414 | captions.append(caption) 415 | return captions --------------------------------------------------------------------------------
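The encoder and attention code that actually consumes PolarRPE is not reproduced in this listing, so the following is only a minimal sketch, not the repository's implementation. It assumes the (b_s, h, n, n, d_k) tensor returned by PolarRPE.forward(b_s) is used as a query-conditioned additive bias on the attention logits over the 9x9 grid (n = 81, h = 8, d_k = 64 with the defaults above); the function name polar_rpe_attention and this particular wiring are illustrative assumptions.

import math
import torch

def polar_rpe_attention(q, k, v, relative_emb):
    # q, k, v: (b_s, h, n, d_k) per-head projections of the n = 81 grid features
    # relative_emb: (b_s, h, n, n, d_k), as returned by PolarRPE.forward(b_s)
    d_k = q.size(-1)
    content = torch.matmul(q, k.transpose(-2, -1))                # (b_s, h, n, n)
    # geometric bias: inner product of each query with the embedding of its
    # polar (direction, distance) relation to every other grid cell
    geometry = torch.einsum('bhid,bhijd->bhij', q, relative_emb)  # (b_s, h, n, n)
    att = torch.softmax((content + geometry) / math.sqrt(d_k), dim=-1)
    return torch.matmul(att, v)                                   # (b_s, h, n, d_k)

As a shape check, rel = PolarRPE(device='cpu')(2) pairs with q, k, v of shape (2, 8, 81, 64), mirroring the __main__ smoke test at the end of that module.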