├── .gitignore
├── figures
│   ├── alignment.jpg
│   ├── retrieval.png
│   └── architecture.png
├── utils.py
├── evaluate_utils
│   ├── get_stanford_models.sh
│   ├── rouge.py
│   ├── ptbtokenizer.py
│   ├── spice.py
│   ├── compute_relevance.py
│   └── dcg.py
├── configs
│   ├── teran_f30k_MrSw.yaml
│   ├── teran_f30k_MwSr.yaml
│   ├── teran_f30k_symm.yaml
│   ├── teran_f30k_MrSw_sharedTE.yaml
│   ├── teran_coco_MrSw.yaml
│   ├── teran_coco_MwSr.yaml
│   ├── teran_coco_symm.yaml
│   └── teran_coco_MrSw_sharedTE.yaml
├── test.py
├── README.md
├── environment.yml
├── models
│   ├── text.py
│   ├── utils.py
│   ├── loss.py
│   ├── teran.py
│   └── visual.py
├── LICENSE
├── evaluation.py
├── train.py
├── data.py
└── features.py
/.gitignore: -------------------------------------------------------------------------------- 1 | *.pyc 2 | *.swp 3 | *.ipynb_checkpoints 4 | *.json 5 | *.pth.tar 6 | -------------------------------------------------------------------------------- /figures/alignment.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mesnico/TERAN/HEAD/figures/alignment.jpg -------------------------------------------------------------------------------- /figures/retrieval.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mesnico/TERAN/HEAD/figures/retrieval.png -------------------------------------------------------------------------------- /figures/architecture.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mesnico/TERAN/HEAD/figures/architecture.png -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import numpy 2 | 3 | from models.teran import TERAN 4 | 5 | 6 | def get_model(config): 7 | model = TERAN(config) 8 | return model 9 | 10 | 11 | def dot_sim(x, y): 12 | return numpy.dot(x, y.T) 13 | 14 | 15 | def cosine_sim(x, y): 16 | x = x / numpy.expand_dims(numpy.linalg.norm(x, axis=1), 1) 17 | y = y / numpy.expand_dims(numpy.linalg.norm(y, axis=1), 1) 18 | return numpy.dot(x, y.T) 19 | -------------------------------------------------------------------------------- /evaluate_utils/get_stanford_models.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env sh 2 | # This script downloads the Stanford CoreNLP models. 3 | 4 | CORENLP=stanford-corenlp-full-2015-12-09 5 | SPICELIB=./lib 6 | JAR=stanford-corenlp-3.6.0 7 | 8 | DIR="$( cd "$(dirname "$0")" ; pwd -P )" 9 | cd $DIR 10 | 11 | if [ -f $SPICELIB/$JAR.jar ]; then 12 | echo "Found Stanford CoreNLP." 13 | else 14 | echo "Downloading..." 15 | wget http://nlp.stanford.edu/software/$CORENLP.zip 16 | echo "Unzipping..." 17 | unzip $CORENLP.zip -d $SPICELIB/ 18 | mv $SPICELIB/$CORENLP/$JAR.jar $SPICELIB/ 19 | mv $SPICELIB/$CORENLP/$JAR-models.jar $SPICELIB/ 20 | rm -f $CORENLP.zip 21 | rm -rf $SPICELIB/$CORENLP/ 22 | echo "Done." 
23 | fi -------------------------------------------------------------------------------- /configs/teran_f30k_MrSw.yaml: -------------------------------------------------------------------------------- 1 | dataset: 2 | name: 'f30k' 3 | images-path: 'data/f30k/images' # needed for sizes.pkl 4 | data: 'data' 5 | restval: True 6 | pre-extracted-features: False 7 | 8 | text-model: 9 | name: 'bert' 10 | pretrain: 'bert-base-uncased' 11 | word-dim: 768 12 | extraction-hidden-layer: 6 13 | fine-tune: True 14 | pre-extracted: False 15 | layers: 0 16 | dropout: 0.1 17 | 18 | image-model: 19 | name: 'bottomup' 20 | pre-extracted-features-root: 'data/f30k/features_36' 21 | transformer-layers: 4 22 | dropout: 0.1 23 | pos-encoding: 'concat-and-process' 24 | crop-size: 224 # not used 25 | fine-tune: False 26 | feat-dim: 2048 27 | norm: True 28 | 29 | model: 30 | name: 'teran' 31 | embed-size: 1024 32 | text-aggregation: 'first' # IMPORTANT 33 | image-aggregation: 'first' 34 | layers: 2 35 | exclude-stopwords: False 36 | shared-transformer: False # IMPorTANT 37 | dropout: 0.1 38 | 39 | training: 40 | lr: 0.00001 # 0.000006 41 | grad-clip: 2.0 42 | max-violation: True # IMPORTANT 43 | loss-type: 'alignment' 44 | alignment-mode: 'MrSw' 45 | measure: 'dot' 46 | margin: 0.2 47 | bs: 30 # IMPORTANT 48 | scheduler: 'steplr' 49 | gamma: 0.1 50 | step-size: 20 51 | warmup: 'linear' 52 | warmup-period: 1000 53 | -------------------------------------------------------------------------------- /configs/teran_f30k_MwSr.yaml: -------------------------------------------------------------------------------- 1 | dataset: 2 | name: 'f30k' 3 | images-path: 'data/f30k/images' # needed for sizes.pkl 4 | data: 'data' 5 | restval: True 6 | pre-extracted-features: False 7 | 8 | text-model: 9 | name: 'bert' 10 | pretrain: 'bert-base-uncased' 11 | word-dim: 768 12 | extraction-hidden-layer: 6 13 | fine-tune: True 14 | pre-extracted: False 15 | layers: 0 16 | dropout: 0.1 17 | 18 | #text-model: 19 | # name: 'gru' 20 | # word-dim: 300 21 | # fine-tune: True 22 | # pre-extracted: False 23 | # layers: 1 24 | 25 | image-model: 26 | name: 'bottomup' 27 | pre-extracted-features-root: 'data/f30k/features_36' 28 | transformer-layers: 4 29 | dropout: 0.1 30 | pos-encoding: 'concat-and-process' 31 | crop-size: 224 # not used 32 | fine-tune: False 33 | feat-dim: 2048 34 | norm: True 35 | 36 | model: 37 | name: 'teran' 38 | embed-size: 1024 39 | text-aggregation: 'first' 40 | image-aggregation: 'first' 41 | layers: 2 42 | exclude-stopwords: False 43 | shared-transformer: False 44 | dropout: 0.1 45 | 46 | training: 47 | lr: 0.00001 # 0.000006 48 | grad-clip: 2.0 49 | max-violation: True 50 | loss-type: 'alignment' 51 | alignment-mode: 'MwSr' 52 | measure: 'dot' 53 | margin: 0.2 54 | bs: 30 55 | scheduler: 'steplr' 56 | gamma: 0.1 57 | step-size: 20 58 | warmup: 'linear' 59 | warmup-period: 1000 60 | -------------------------------------------------------------------------------- /configs/teran_f30k_symm.yaml: -------------------------------------------------------------------------------- 1 | dataset: 2 | name: 'f30k' 3 | images-path: 'data/f30k/images' # needed for sizes.pkl 4 | data: 'data' 5 | restval: True 6 | pre-extracted-features: False 7 | 8 | text-model: 9 | name: 'bert' 10 | pretrain: 'bert-base-uncased' 11 | word-dim: 768 12 | extraction-hidden-layer: 6 13 | fine-tune: True 14 | pre-extracted: False 15 | layers: 0 16 | dropout: 0.1 17 | 18 | #text-model: 19 | # name: 'gru' 20 | # word-dim: 300 21 | # fine-tune: True 22 | # 
pre-extracted: False 23 | # layers: 1 24 | 25 | image-model: 26 | name: 'bottomup' 27 | pre-extracted-features-root: 'data/f30k/features_36' 28 | transformer-layers: 4 29 | dropout: 0.1 30 | pos-encoding: 'concat-and-process' 31 | crop-size: 224 # not used 32 | fine-tune: False 33 | feat-dim: 2048 34 | norm: True 35 | 36 | model: 37 | name: 'teran' 38 | embed-size: 1024 39 | text-aggregation: 'first' 40 | image-aggregation: 'first' 41 | layers: 2 42 | exclude-stopwords: False 43 | shared-transformer: False 44 | dropout: 0.1 45 | 46 | training: 47 | lr: 0.00001 # 0.000006 48 | grad-clip: 2.0 49 | max-violation: True 50 | loss-type: 'alignment' 51 | alignment-mode: 'symm' 52 | measure: 'dot' 53 | margin: 0.2 54 | bs: 30 55 | scheduler: 'steplr' 56 | gamma: 0.1 57 | step-size: 20 58 | warmup: 'linear' 59 | warmup-period: 1000 60 | -------------------------------------------------------------------------------- /configs/teran_f30k_MrSw_sharedTE.yaml: -------------------------------------------------------------------------------- 1 | dataset: 2 | name: 'f30k' 3 | images-path: 'data/f30k/images' # needed for sizes.pkl 4 | data: 'data' 5 | restval: True 6 | pre-extracted-features: False 7 | 8 | text-model: 9 | name: 'bert' 10 | pretrain: 'bert-base-uncased' 11 | word-dim: 768 12 | extraction-hidden-layer: 6 13 | fine-tune: True 14 | pre-extracted: False 15 | layers: 0 16 | dropout: 0.1 17 | 18 | #text-model: 19 | # name: 'gru' 20 | # word-dim: 300 21 | # fine-tune: True 22 | # pre-extracted: False 23 | # layers: 1 24 | 25 | image-model: 26 | name: 'bottomup' 27 | pre-extracted-features-root: 'data/f30k/features_36' 28 | transformer-layers: 4 29 | dropout: 0.1 30 | pos-encoding: 'concat-and-process' 31 | crop-size: 224 # not used 32 | fine-tune: False 33 | feat-dim: 2048 34 | norm: True 35 | 36 | model: 37 | name: 'teran' 38 | embed-size: 1024 39 | text-aggregation: 'first' 40 | image-aggregation: 'first' 41 | layers: 2 42 | exclude-stopwords: False 43 | shared-transformer: True 44 | dropout: 0.1 45 | 46 | training: 47 | lr: 0.00001 # 0.000006 48 | grad-clip: 2.0 49 | max-violation: True 50 | loss-type: 'alignment' 51 | alignment-mode: 'MrSw' 52 | measure: 'dot' 53 | margin: 0.2 54 | bs: 30 55 | scheduler: 'steplr' 56 | gamma: 0.1 57 | step-size: 20 58 | warmup: 'linear' 59 | warmup-period: 1000 60 | -------------------------------------------------------------------------------- /configs/teran_coco_MrSw.yaml: -------------------------------------------------------------------------------- 1 | dataset: 2 | name: 'coco' 3 | images-path: 'data/coco/images' # not needed if using pre-extracted bottom-up features 4 | data: 'data' 5 | restval: True 6 | pre-extracted-features: False 7 | 8 | text-model: 9 | name: 'bert' 10 | pretrain: 'bert-base-uncased' 11 | word-dim: 768 12 | extraction-hidden-layer: 6 13 | fine-tune: True 14 | pre-extracted: False 15 | layers: 0 16 | dropout: 0.1 17 | 18 | #text-model: 19 | # name: 'gru' 20 | # word-dim: 300 21 | # fine-tune: True 22 | # pre-extracted: False 23 | # layers: 1 24 | 25 | image-model: 26 | name: 'bottomup' 27 | pre-extracted-features-root: 'data/coco/features_36' 28 | transformer-layers: 4 29 | dropout: 0.1 30 | pos-encoding: 'concat-and-process' 31 | crop-size: 224 # not used 32 | fine-tune: False 33 | feat-dim: 2048 34 | norm: True 35 | 36 | model: 37 | name: 'teran' 38 | embed-size: 1024 39 | text-aggregation: 'first' 40 | image-aggregation: 'first' 41 | layers: 2 42 | exclude-stopwords: False 43 | shared-transformer: False 44 | dropout: 0.1 45 | 46 | 
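# Note on the 'training' section that follows: the 'alignment-mode' values used across these configs
# ('MrSw', 'MwSr', 'symm') correspond to the aggregation modes of AlignmentContrastiveLoss in models/loss.py
# (max over regions then sum over words, max over words then sum over regions, and the sum of the two, respectively).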
training: 47 | lr: 0.00001 # 0.000006 48 | grad-clip: 2.0 49 | max-violation: True 50 | loss-type: 'alignment' 51 | alignment-mode: 'MrSw' 52 | measure: 'dot' 53 | margin: 0.2 54 | bs: 40 55 | scheduler: 'steplr' 56 | gamma: 0.1 57 | step-size: 20 58 | warmup: null 59 | warmup-period: 1000 60 | -------------------------------------------------------------------------------- /configs/teran_coco_MwSr.yaml: -------------------------------------------------------------------------------- 1 | dataset: 2 | name: 'coco' 3 | images-path: 'data/coco/images' # not needed if using pre-extracted bottom-up features 4 | data: 'data' 5 | restval: True 6 | pre-extracted-features: False 7 | 8 | text-model: 9 | name: 'bert' 10 | pretrain: 'bert-base-uncased' 11 | word-dim: 768 12 | extraction-hidden-layer: 6 13 | fine-tune: True 14 | pre-extracted: False 15 | layers: 0 16 | dropout: 0.1 17 | 18 | #text-model: 19 | # name: 'gru' 20 | # word-dim: 300 21 | # fine-tune: True 22 | # pre-extracted: False 23 | # layers: 1 24 | 25 | image-model: 26 | name: 'bottomup' 27 | pre-extracted-features-root: 'data/coco/features_36' 28 | transformer-layers: 4 29 | dropout: 0.1 30 | pos-encoding: 'concat-and-process' 31 | crop-size: 224 # not used 32 | fine-tune: False 33 | feat-dim: 2048 34 | norm: True 35 | 36 | model: 37 | name: 'teran' 38 | embed-size: 1024 39 | text-aggregation: 'first' 40 | image-aggregation: 'first' 41 | layers: 2 42 | exclude-stopwords: False 43 | shared-transformer: False 44 | dropout: 0.1 45 | 46 | training: 47 | lr: 0.00001 # 0.000006 48 | grad-clip: 2.0 49 | max-violation: True 50 | loss-type: 'alignment' 51 | alignment-mode: 'MwSr' 52 | measure: 'dot' 53 | margin: 0.2 54 | bs: 40 55 | scheduler: 'steplr' 56 | gamma: 0.1 57 | step-size: 20 58 | warmup: null 59 | warmup-period: 1000 60 | -------------------------------------------------------------------------------- /configs/teran_coco_symm.yaml: -------------------------------------------------------------------------------- 1 | dataset: 2 | name: 'coco' 3 | images-path: 'data/coco/images' # not needed if using pre-extracted bottom-up features 4 | data: 'data' 5 | restval: True 6 | pre-extracted-features: False 7 | 8 | text-model: 9 | name: 'bert' 10 | pretrain: 'bert-base-uncased' 11 | word-dim: 768 12 | extraction-hidden-layer: 6 13 | fine-tune: True 14 | pre-extracted: False 15 | layers: 0 16 | dropout: 0.1 17 | 18 | #text-model: 19 | # name: 'gru' 20 | # word-dim: 300 21 | # fine-tune: True 22 | # pre-extracted: False 23 | # layers: 1 24 | 25 | image-model: 26 | name: 'bottomup' 27 | pre-extracted-features-root: 'data/coco/features_36' 28 | transformer-layers: 4 29 | dropout: 0.1 30 | pos-encoding: 'concat-and-process' 31 | crop-size: 224 # not used 32 | fine-tune: False 33 | feat-dim: 2048 34 | norm: True 35 | 36 | model: 37 | name: 'teran' 38 | embed-size: 1024 39 | text-aggregation: 'first' 40 | image-aggregation: 'first' 41 | layers: 2 42 | exclude-stopwords: False 43 | shared-transformer: False 44 | dropout: 0.1 45 | 46 | training: 47 | lr: 0.00001 # 0.000006 48 | grad-clip: 2.0 49 | max-violation: True 50 | loss-type: 'alignment' 51 | alignment-mode: 'symm' 52 | measure: 'dot' 53 | margin: 0.2 54 | bs: 40 55 | scheduler: 'steplr' 56 | gamma: 0.1 57 | step-size: 20 58 | warmup: null 59 | warmup-period: 1000 60 | -------------------------------------------------------------------------------- /configs/teran_coco_MrSw_sharedTE.yaml: -------------------------------------------------------------------------------- 1 | dataset: 2 | name: 
'coco' 3 | images-path: 'data/coco/images' # not needed if using pre-extracted bottom-up features 4 | data: 'data' 5 | restval: True 6 | pre-extracted-features: False 7 | 8 | text-model: 9 | name: 'bert' 10 | pretrain: 'bert-base-uncased' 11 | word-dim: 768 12 | extraction-hidden-layer: 6 13 | fine-tune: True 14 | pre-extracted: False 15 | layers: 0 16 | dropout: 0.1 17 | 18 | #text-model: 19 | # name: 'gru' 20 | # word-dim: 300 21 | # fine-tune: True 22 | # pre-extracted: False 23 | # layers: 1 24 | 25 | image-model: 26 | name: 'bottomup' 27 | pre-extracted-features-root: 'data/coco/features_36' 28 | transformer-layers: 4 29 | dropout: 0.1 30 | pos-encoding: 'concat-and-process' 31 | crop-size: 224 # not used 32 | fine-tune: False 33 | feat-dim: 2048 34 | norm: True 35 | 36 | model: 37 | name: 'teran' 38 | embed-size: 1024 39 | text-aggregation: 'first' 40 | image-aggregation: 'first' 41 | layers: 2 42 | exclude-stopwords: False 43 | shared-transformer: True 44 | dropout: 0.1 45 | 46 | training: 47 | lr: 0.00001 # 0.000006 48 | grad-clip: 2.0 49 | max-violation: True 50 | loss-type: 'alignment' 51 | alignment-mode: 'MrSw' 52 | measure: 'dot' 53 | margin: 0.2 54 | bs: 40 55 | scheduler: 'steplr' 56 | gamma: 0.1 57 | step-size: 20 58 | warmup: null 59 | warmup-period: 1000 60 | -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | import evaluation 4 | import yaml 5 | import torch 6 | 7 | def main(opt, current_config): 8 | model_checkpoint = opt.checkpoint 9 | 10 | checkpoint = torch.load(model_checkpoint) 11 | print('Checkpoint loaded from {}'.format(model_checkpoint)) 12 | loaded_config = checkpoint['config'] 13 | 14 | if opt.size == "1k": 15 | fold5 = True 16 | elif opt.size == "5k": 17 | fold5 = False 18 | else: 19 | raise ValueError('Test split size not recognized!') 20 | 21 | # Override some mandatory things in the configuration (paths) 22 | if current_config is not None: 23 | loaded_config['dataset']['images-path'] = current_config['dataset']['images-path'] 24 | loaded_config['dataset']['data'] = current_config['dataset']['data'] 25 | loaded_config['image-model']['pre-extracted-features-root'] = current_config['image-model']['pre-extracted-features-root'] 26 | 27 | evaluation.evalrank(loaded_config, checkpoint, split="test", fold5=fold5) 28 | 29 | if __name__ == '__main__': 30 | parser = argparse.ArgumentParser() 31 | parser.add_argument('checkpoint', type=str, help="Checkpoint to load") 32 | parser.add_argument('--size', type=str, choices=['1k', '5k'], default='1k') 33 | parser.add_argument('--config', type=str, default=None, help="Which configuration to use for overriding the checkpoint configuration. 
See into 'config' folder") 34 | 35 | opt = parser.parse_args() 36 | if opt.config is not None: 37 | with open(opt.config, 'r') as ymlfile: 38 | config = yaml.load(ymlfile) 39 | else: 40 | config = None 41 | main(opt, config) -------------------------------------------------------------------------------- /evaluate_utils/rouge.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # 3 | # File Name : rouge.py 4 | # 5 | # Description : Computes ROUGE-L metric as described by Lin and Hovey (2004) 6 | # 7 | # Creation Date : 2015-01-07 06:03 8 | # Author : Ramakrishna Vedantam 9 | 10 | import numpy as np 11 | import pdb 12 | 13 | 14 | def my_lcs(string, sub): 15 | """ 16 | Calculates longest common subsequence for a pair of tokenized strings 17 | :param string : list of str : tokens from a string split using whitespace 18 | :param sub : list of str : shorter string, also split using whitespace 19 | :returns: length (list of int): length of the longest common subsequence between the two strings 20 | Note: my_lcs only gives length of the longest common subsequence, not the actual LCS 21 | """ 22 | if len(string) < len(sub): 23 | sub, string = string, sub 24 | 25 | lengths = [[0 for i in range(0, len(sub) + 1)] for j in range(0, len(string) + 1)] 26 | 27 | for j in range(1, len(sub) + 1): 28 | for i in range(1, len(string) + 1): 29 | if string[i - 1] == sub[j - 1]: 30 | lengths[i][j] = lengths[i - 1][j - 1] + 1 31 | else: 32 | lengths[i][j] = max(lengths[i - 1][j], lengths[i][j - 1]) 33 | 34 | return lengths[len(string)][len(sub)] 35 | 36 | 37 | class Rouge: 38 | ''' 39 | Class for computing ROUGE-L score for a set of candidate sentences for the MS COCO test set 40 | ''' 41 | 42 | def __init__(self): 43 | # vrama91: updated the value below based on discussion with Hovey 44 | self.beta = 1.2 45 | 46 | def score(self, candidate, refs): 47 | """ 48 | Compute ROUGE-L score given one candidate and references for an image 49 | :param candidate: str : candidate sentence to be evaluated 50 | :param refs: list of str : COCO reference sentences for the particular image to be evaluated 51 | :returns score: int (ROUGE-L score for the candidate evaluated against references) 52 | """ 53 | assert (len(candidate) == 1) 54 | assert (len(refs) > 0) 55 | prec = [] 56 | rec = [] 57 | 58 | # split into tokens 59 | token_c = candidate[0].lower().split(" ") 60 | 61 | for reference in refs: 62 | # split into tokens 63 | token_r = reference.lower().split(" ") 64 | # compute the longest common subsequence 65 | lcs = my_lcs(token_r, token_c) 66 | prec.append(lcs / float(len(token_c))) 67 | rec.append(lcs / float(len(token_r))) 68 | 69 | prec_max = max(prec) 70 | rec_max = max(rec) 71 | 72 | if prec_max != 0 and rec_max != 0: 73 | score = ((1 + self.beta ** 2) * prec_max * rec_max) / float(rec_max + self.beta ** 2 * prec_max) 74 | else: 75 | score = 0.0 76 | return score 77 | -------------------------------------------------------------------------------- /evaluate_utils/ptbtokenizer.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # 3 | # File Name : ptbtokenizer.py 4 | # 5 | # Description : Do the PTB Tokenization and remove punctuations. 
6 | # 7 | # Creation Date : 29-12-2014 8 | # Last Modified : Thu Mar 19 09:53:35 2015 9 | # Authors : Hao Fang and Tsung-Yi Lin 10 | 11 | import os 12 | import sys 13 | import subprocess 14 | import tempfile 15 | import itertools 16 | 17 | # path to the stanford corenlp jar 18 | STANFORD_CORENLP_3_4_1_JAR = 'lib/stanford-corenlp-3.4.1.jar' 19 | 20 | # punctuations to be removed from the sentences 21 | PUNCTUATIONS = ["''", "'", "``", "`", "-LRB-", "-RRB-", "-LCB-", "-RCB-", \ 22 | ".", "?", "!", ",", ":", "-", "--", "...", ";"] 23 | 24 | class PTBTokenizer: 25 | """Python wrapper of Stanford PTBTokenizer""" 26 | 27 | def tokenize(self, captions_for_image): 28 | cmd = ['java', '-cp', STANFORD_CORENLP_3_4_1_JAR, \ 29 | 'edu.stanford.nlp.process.PTBTokenizer', \ 30 | '-preserveLines', '-lowerCase'] 31 | 32 | # ====================================================== 33 | # prepare data for PTB Tokenizer 34 | # ====================================================== 35 | final_tokenized_captions_for_image = {} 36 | image_id = [k for k, v in captions_for_image.items() for _ in range(len(v))] 37 | sentences = '\n'.join([c.replace('\n', ' ') for k, v in captions_for_image.items() for c in v]) 38 | 39 | # ====================================================== 40 | # save sentences to temporary file 41 | # ====================================================== 42 | path_to_jar_dirname=os.path.dirname(os.path.abspath(__file__)) 43 | tmp_file = tempfile.NamedTemporaryFile(mode='w', delete=False, dir=path_to_jar_dirname) 44 | tmp_file.write(sentences) 45 | tmp_file.close() 46 | 47 | # ====================================================== 48 | # tokenize sentence 49 | # ====================================================== 50 | cmd.append(os.path.basename(tmp_file.name)) 51 | p_tokenizer = subprocess.Popen(cmd, cwd=path_to_jar_dirname, \ 52 | stdout=subprocess.PIPE) 53 | token_lines = p_tokenizer.communicate(input=sentences.rstrip())[0] 54 | lines = token_lines.decode().split('\n') 55 | # remove temp file 56 | os.remove(tmp_file.name) 57 | 58 | # ====================================================== 59 | # create dictionary for tokenized captions 60 | # ====================================================== 61 | for k, line in zip(image_id, lines): 62 | if not k in final_tokenized_captions_for_image: 63 | final_tokenized_captions_for_image[k] = [] 64 | tokenized_caption = ' '.join([w for w in line.rstrip().split(' ') \ 65 | if w not in PUNCTUATIONS]) 66 | final_tokenized_captions_for_image[k].append(tokenized_caption) 67 | 68 | return final_tokenized_captions_for_image -------------------------------------------------------------------------------- /evaluate_utils/spice.py: -------------------------------------------------------------------------------- 1 | from __future__ import division 2 | import os 3 | import sys 4 | import subprocess 5 | import threading 6 | import json 7 | import numpy as np 8 | import ast 9 | import tempfile 10 | from .ptbtokenizer import PTBTokenizer 11 | 12 | # Assumes spice.jar is in the same directory as spice.py. Change as needed. 
13 | SPICE_JAR = 'spice-1.0.jar' 14 | TEMP_DIR = 'spice_tmp' 15 | CACHE_DIR = 'spice_cache' 16 | 17 | 18 | class Spice: 19 | """ 20 | Main Class to compute the SPICE metric 21 | """ 22 | 23 | def float_convert(self, obj): 24 | try: 25 | return float(obj) 26 | except: 27 | return np.nan 28 | 29 | def compute_score(self, gts, res): 30 | """ 31 | gts: list of (list of 5 strings) 32 | res: string 33 | """ 34 | 35 | # convert to dictionary of image_id -> strings 36 | imgIds = list(range(len(gts))) 37 | gts = {k:v for k, v in zip(imgIds, gts)} 38 | res = {0: res} 39 | 40 | tokenizer = PTBTokenizer() 41 | gts = tokenizer.tokenize(gts) 42 | res = tokenizer.tokenize(res) 43 | 44 | # imgIds = sorted(gts.keys()) 45 | 46 | # Prepare temp input file for the SPICE scorer 47 | input_data = [] 48 | for id in imgIds: 49 | hypo = res[0] 50 | ref = gts[id] 51 | 52 | # Sanity check. 53 | assert (type(hypo) is list) 54 | assert (len(hypo) == 1) 55 | assert (type(ref) is list) 56 | assert (len(ref) >= 1) 57 | 58 | input_data.append({ 59 | "image_id": id, 60 | "test": hypo[0], 61 | "refs": ref 62 | }) 63 | 64 | cwd = os.path.dirname(os.path.abspath(__file__)) 65 | temp_dir = os.path.join(cwd, TEMP_DIR) 66 | if not os.path.exists(temp_dir): 67 | os.makedirs(temp_dir) 68 | in_file = tempfile.NamedTemporaryFile(mode='w', delete=False, dir=temp_dir) 69 | json.dump(input_data, in_file, indent=2) 70 | in_file.close() 71 | 72 | # Start job 73 | out_file = tempfile.NamedTemporaryFile(delete=False, dir=temp_dir) 74 | out_file.close() 75 | cache_dir = os.path.join(cwd, CACHE_DIR) 76 | if not os.path.exists(cache_dir): 77 | os.makedirs(cache_dir) 78 | spice_cmd = ['java', '-jar', '-Xmx8G', SPICE_JAR, in_file.name, 79 | '-cache', cache_dir, 80 | '-out', out_file.name, 81 | '-subset', 82 | '-silent' 83 | ] 84 | subprocess.check_call(spice_cmd, 85 | cwd=os.path.dirname(os.path.abspath(__file__))) 86 | 87 | # Read and process results 88 | with open(out_file.name) as data_file: 89 | results = json.load(data_file) 90 | os.remove(in_file.name) 91 | os.remove(out_file.name) 92 | 93 | imgId_to_scores = {} 94 | spice_scores = [] 95 | for item in results: 96 | imgId_to_scores[item['image_id']] = item['scores'] 97 | spice_scores.append(self.float_convert(item['scores']['All']['f'])) 98 | average_score = np.mean(np.array(spice_scores)) 99 | scores = [] 100 | for image_id in imgIds: 101 | # Convert none to NaN before saving scores over subcategories 102 | score_set = {} 103 | for category, score_tuple in imgId_to_scores[image_id].items(): 104 | score_set[category] = {k: self.float_convert(v) for k, v in score_tuple.items()} 105 | scores.append(score_set) 106 | return average_score, scores 107 | 108 | def method(self): 109 | return "SPICE" -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Transformer Encoder Reasoning and Alignment Network (TERAN) 2 | 3 | ## Updates 4 | 5 | - :fire: 09/2022: The extension to this work (**ALADIN: Distilling Fine-grained Alignment Scores for Efficient Image-Text Matching and Retrieval**) has been published in proceedings of CBMI 2022. Check out [code](https://github.com/mesnico/ALADIN) and [paper](https://arxiv.org/abs/2207.14757)! 
6 | 7 | ## Introduction 8 | 9 | Code for the cross-modal visual-linguistic retrieval method from "Fine-grained Visual Textual Alignment for Cross-modal Retrieval using Transformer Encoders", accepted for publication in ACM Transactions on Multimedia Computing, Communications, and Applications (TOMM) [[Pre-print PDF](https://arxiv.org/abs/2008.05231)]. 10 | 11 | This work is an extension of our previous approach TERN, accepted at ICPR 2020. 12 | 13 | This repo is built on top of [VSE++](https://github.com/fartashf/vsepp) and [TERN](https://github.com/mesnico/TERN). 14 | 15 |

16 | ![Fine-grained Alignment for Precise Matching](figures/alignment.jpg)
 17 | 
 21 | ![Retrieval](figures/retrieval.png)
 22 | 
24 | 25 | 26 | ## Setup 27 | 28 | 1. Clone the repo and move into it: 29 | ``` 30 | git clone https://github.com/mesnico/TERAN 31 | cd TERAN 32 | ``` 33 | 34 | 2. Set up the python environment using conda: 35 | ``` 36 | conda env create --file environment.yml 37 | conda activate teran 38 | export PYTHONPATH=. 39 | ``` 40 | 41 | ## Get the data 42 | Data and pretrained models can be downloaded from this [OneDrive link](https://cnrsc-my.sharepoint.com/:f:/g/personal/nicola_messina_cnr_it/EnsuSFo-rG5Pmf2FhQDPe7EBCHrNtR1ujSIOEcgaj5Xrwg?e=Ger6Sl) (see the steps below to understand which files you need): 43 | 44 | 1. Download the data folder, containing the annotations, the splits by Karpathy et al., and the ROUGE-L and SPICE precomputed relevances for both the COCO and Flickr30K datasets. Extract it: 45 | 46 | ``` 47 | tar -xvf data.tgz 48 | ``` 49 | 50 | 2. Download the bottom-up features for both COCO and Flickr30K. We use the code by [Anderson et al.](https://github.com/peteanderson80/bottom-up-attention) to extract them. 51 | The following commands extract them under `data/coco/` and `data/f30k/`. If you prefer another location, be sure to adjust the configuration files accordingly. 52 | ``` 53 | # for MS-COCO 54 | tar -xvf features_36_coco.tgz -C data/coco 55 | 56 | # for Flickr30k 57 | tar -xvf features_36_f30k.tgz -C data/f30k 58 | ``` 59 | 60 | ## Evaluate 61 | Extract our pre-trained TERAN models: 62 | ``` 63 | tar -xvf TERAN_pretrained_models.tgz 64 | ``` 65 | 66 | Then, issue the following commands to evaluate a given model on the 1k (5-fold cross-validation) or 5k test sets. 67 | ``` 68 | python3 test.py pretrained_models/[model].pth --size 1k 69 | python3 test.py pretrained_models/[model].pth --size 5k 70 | ``` 71 | 72 | Please note that if you changed some default paths (e.g. the features are in a folder other than `data/coco/features_36`), you will need to use the `--config` option and provide the corresponding yaml configuration file containing the right paths. 73 | ## Train 74 | In order to train the model using a given TERAN configuration, issue the following command: 75 | ``` 76 | python3 train.py --config configs/[config].yaml --logger_name runs/teran 77 | ``` 78 | `runs/teran` is where the output files (tensorboard logs, checkpoints) will be stored during this training session.
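To monitor a training session, you can point TensorBoard at the logger directory (this assumes the `tensorboard` command installed by the conda environment is on your PATH, and that you used `--logger_name runs/teran` as in the example above):
```
tensorboard --logdir runs/teran
```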
79 | 80 | ## Visualization 81 | 82 | WIP 83 | 84 | ## Reference 85 | If you found this code useful, please cite the following paper: 86 | 87 | @article{messina2021fine, 88 | title={Fine-grained visual textual alignment for cross-modal retrieval using transformer encoders}, 89 | author={Messina, Nicola and Amato, Giuseppe and Esuli, Andrea and Falchi, Fabrizio and Gennaro, Claudio and Marchand-Maillet, St{\'e}phane}, 90 | journal={ACM Transactions on Multimedia Computing, Communications, and Applications (TOMM)}, 91 | volume={17}, 92 | number={4}, 93 | pages={1--23}, 94 | year={2021}, 95 | publisher={ACM New York, NY} 96 | } 97 | 98 | ## License 99 | 100 | [Apache License 2.0](http://www.apache.org/licenses/LICENSE-2.0) 101 | -------------------------------------------------------------------------------- /environment.yml: -------------------------------------------------------------------------------- 1 | name: teran 2 | channels: 3 | - pytorch 4 | - defaults 5 | dependencies: 6 | - _libgcc_mutex=0.1=main 7 | - absl-py=0.8.1=py36_0 8 | - blas=1.0=mkl 9 | - c-ares=1.15.0=h7b6447c_1001 10 | - ca-certificates=2020.1.1=0 11 | - certifi=2019.11.28=py36_0 12 | - cffi=1.13.2=py36h2e261b9_0 13 | - cudatoolkit=10.0.130=0 14 | - freetype=2.9.1=h8a8886c_1 15 | - grpcio=1.16.1=py36hf8bcb03_1 16 | - intel-openmp=2019.4=243 17 | - jpeg=9b=h024ee3a_2 18 | - libedit=3.1.20181209=hc058e9b_0 19 | - libffi=3.2.1=hd88cf55_4 20 | - libgcc-ng=9.1.0=hdf63c60_0 21 | - libgfortran-ng=7.3.0=hdf63c60_0 22 | - libpng=1.6.37=hbc83047_0 23 | - libprotobuf=3.11.2=hd408876_0 24 | - libstdcxx-ng=9.1.0=hdf63c60_0 25 | - libtiff=4.1.0=h2733197_0 26 | - markdown=3.1.1=py36_0 27 | - mkl=2019.4=243 28 | - mkl-service=2.3.0=py36he904b0f_0 29 | - mkl_fft=1.0.15=py36ha843d7b_0 30 | - mkl_random=1.1.0=py36hd6b4f25_0 31 | - ncurses=6.1=he6710b0_1 32 | - ninja=1.9.0=py36hfd86e86_0 33 | - numpy=1.17.4=py36hc1035e2_0 34 | - numpy-base=1.17.4=py36hde5b4d6_0 35 | - olefile=0.46=py_0 36 | - openssl=1.1.1d=h7b6447c_4 37 | - pillow=6.2.1=py36h34e0f95_0 38 | - pip=19.3.1=py36_0 39 | - protobuf=3.11.2=py36he6710b0_0 40 | - pycparser=2.19=py_0 41 | - python=3.6.9=h265db76_0 42 | - readline=7.0=h7b6447c_5 43 | - setuptools=42.0.2=py36_0 44 | - six=1.13.0=py36_0 45 | - sqlite=3.30.1=h7b6447c_0 46 | - tensorboard=2.0.0=pyhb38c66f_1 47 | - tk=8.6.8=hbc83047_0 48 | - werkzeug=0.16.0=py_0 49 | - wheel=0.33.6=py36_0 50 | - xz=5.2.4=h14c3975_4 51 | - zlib=1.2.11=h7b6447c_3 52 | - zstd=1.3.7=h0b5b093_0 53 | - pytorch=1.3.1=py3.6_cuda10.0.130_cudnn7.6.3_0 54 | - torchvision=0.4.2=py36_cu100 55 | - pip: 56 | - astor==0.8.1 57 | - atomicwrites==1.2.1 58 | - attrs==19.3.0 59 | - awscli==1.16.180 60 | - backcall==0.1.0 61 | - bleach==3.1.0 62 | - boto3==1.10.46 63 | - botocore==1.13.46 64 | - cachetools==3.1.1 65 | - chardet==3.0.4 66 | - click==7.0 67 | - colorama==0.4.1 68 | - cupy-cuda100==6.1.0 69 | - cycler==0.10.0 70 | - cython==0.29.10 71 | - decorator==4.4.1 72 | - defusedxml==0.6.0 73 | - dgl==0.1.3 74 | - dill==0.3.1.1 75 | - docopt==0.6.2 76 | - docutils==0.15.2 77 | - dominate==2.4.0 78 | - easydict==1.9 79 | - entrypoints==0.3 80 | - fastrlock==0.4 81 | - filelock==3.0.10 82 | - fire==0.1.3 83 | - flatbuffers==1.10 84 | - funcsigs==1.0.2 85 | - gast==0.3.2 86 | - google-auth==1.7.1 87 | - google-auth-oauthlib==0.4.1 88 | - google-pasta==0.1.8 89 | - hunspell==0.5.5 90 | - idna==2.8 91 | - imageio==2.6.1 92 | - imdirect==0.5.0 93 | - imgaug==0.3.0 94 | - importlib-metadata==0.23 95 | - ipykernel==5.1.3 96 | - ipython==7.9.0 97 | - 
ipython-genutils==0.2.0 98 | - isodate==0.6.0 99 | - jedi==0.15.1 100 | - jinja2==2.10.3 101 | - jmespath==0.9.4 102 | - joblib==0.14.1 103 | - json5==0.8.5 104 | - jsonschema==3.1.1 105 | - jupyter-client==5.3.4 106 | - jupyter-core==4.6.1 107 | - jupyterlab==1.2.1 108 | - jupyterlab-server==1.0.6 109 | - keras-applications==1.0.8 110 | - keras-preprocessing==1.1.0 111 | - kiwisolver==1.1.0 112 | - markupsafe==1.1.1 113 | - matplotlib==3.1.2 114 | - mistune==0.8.4 115 | - more-itertools==7.2.0 116 | - nbconvert==5.6.1 117 | - nbformat==4.4.0 118 | - networkx==2.4 119 | - nltk==3.4.5 120 | - notebook==6.0.1 121 | - oauthlib==3.1.0 122 | - opencv-python==4.1.1.26 123 | - pandas==0.24.1 124 | - pandocfilters==1.4.2 125 | - parso==0.5.1 126 | - pexpect==4.7.0 127 | - pickleshare==0.7.5 128 | - piexif==1.1.3 129 | - pipreqs==0.4.10 130 | - pluggy==0.8.0 131 | - progressbar==2.5 132 | - prometheus-client==0.7.1 133 | - prompt-toolkit==2.0.10 134 | - ptyprocess==0.6.0 135 | - pudb==2019.1 136 | - py==1.7.0 137 | - pyasn1==0.4.8 138 | - pyasn1-modules==0.2.7 139 | - pycocotools==2.0.0 140 | - pygments==2.4.2 141 | - pyparsing==2.4.5 142 | - pyrsistent==0.15.5 143 | - pytest==4.1.0 144 | - python-dateutil==2.8.1 145 | - pytorch-nlp==0.3.7.post1 146 | - pytorch-warmup==0.0.4 147 | - pytz==2018.9 148 | - pywavelets==1.1.1 149 | - pyyaml==5.3 150 | - pyzmq==18.1.0 151 | - ray==0.6.1 152 | - rdflib==4.2.2 153 | - redis==3.0.1 154 | - regex==2019.12.20 155 | - requests==2.22.0 156 | - requests-oauthlib==1.3.0 157 | - rouge-score==0.0.3 158 | - rsa==4.0 159 | - s3transfer==0.2.1 160 | - sacremoses==0.0.35 161 | - scikit-image==0.16.2 162 | - scipy==1.4.1 163 | - send2trash==1.5.0 164 | - sentencepiece==0.1.85 165 | - shapely==1.6.4.post2 166 | - sklearn==0.0 167 | - tensorboard-logger==0.1.0 168 | - tensorflow-estimator==1.14.0 169 | - tensorflow-gpu==1.14.0 170 | - termcolor==1.1.0 171 | - terminado==0.8.2 172 | - testpath==0.4.2 173 | - torch==1.3.1 174 | - torch-two-sample==0.1 175 | - torchfile==0.1.0 176 | - torchnet==0.0.4 177 | - tornado==6.0.3 178 | - tqdm==4.42.1 179 | - traitlets==4.3.3 180 | - transformers==2.3.0 181 | - urllib3==1.25.7 182 | - urwid==2.0.1 183 | - virtualenv==16.1.0 184 | - visdom==0.1.8.8 185 | - wcwidth==0.1.7 186 | - webencodings==0.5.1 187 | - websocket-client==0.56.0 188 | - wrapt==1.11.2 189 | - yacs==0.1.6 190 | - yarg==0.1.9 191 | - zipp==0.6.0 192 | prefix: /home/nicola/anaconda3/envs/teran -------------------------------------------------------------------------------- /evaluate_utils/compute_relevance.py: -------------------------------------------------------------------------------- 1 | from torchvision.datasets.coco import CocoCaptions 2 | import numpy as np 3 | import os 4 | import yaml 5 | import tqdm 6 | import argparse 7 | import multiprocessing 8 | from functools import partial 9 | from torch.utils.data import DataLoader 10 | from rouge_score import rouge_scorer 11 | from evaluate_utils import rouge, spice 12 | import utils 13 | import data 14 | import copy 15 | 16 | 17 | def compute_relevances_wrt_query(query): 18 | i, (_, query_caption, _, _) = query 19 | row_dataloader = DataLoader(compute_relevances_wrt_query.dataset, 20 | num_workers=0, 21 | batch_size=5, 22 | shuffle=False, 23 | collate_fn=my_collate 24 | ) 25 | if compute_relevances_wrt_query.method == 'rougeL': 26 | # scorer = rouge_scorer.RougeScorer(['rougeL'], use_stemmer=True) 27 | scorer = rouge.Rouge() 28 | 29 | for j, (_, cur_captions, _, _) in enumerate(row_dataloader): 30 | if 
compute_relevances_wrt_query.npy_file[i, j] < 0: 31 | # fine-grain check on negative values. If not negative, this value has already been computed 32 | relevance = scorer.score(query_caption, cur_captions) 33 | compute_relevances_wrt_query.npy_file[i, j] = relevance 34 | 35 | elif compute_relevances_wrt_query.method == 'spice': 36 | if any(compute_relevances_wrt_query.npy_file[i, :] < 0): 37 | scorer = spice.Spice() 38 | # accumulate all captions 39 | all_captions = [] 40 | for j, (_, cur_captions, _, _) in enumerate(row_dataloader): 41 | all_captions.append(cur_captions) 42 | 43 | _, scores = scorer.compute_score(all_captions, query_caption) 44 | relevances = [s['All']['f'] for s in scores] 45 | relevances = np.array(relevances) 46 | compute_relevances_wrt_query.npy_file[i, :] = relevances 47 | 48 | 49 | def parallel_worker_init(npy_file, dataset, method): 50 | compute_relevances_wrt_query.npy_file = npy_file 51 | compute_relevances_wrt_query.dataset = dataset 52 | compute_relevances_wrt_query.method = method 53 | 54 | 55 | def get_dataset(config, split): 56 | roots, ids = data.get_paths(config) 57 | 58 | data_name = config['dataset']['name'] 59 | if 'coco' in data_name: 60 | # COCO custom dataset 61 | dataset = data.CocoDataset(root=roots[split]['img'], json=roots[split]['cap'], ids=ids[split], get_images=False) 62 | elif 'f8k' in data_name or 'f30k' in data_name: 63 | dataset = data.FlickrDataset(root=roots[split]['img'], split=split, json=roots[split]['cap'], get_images=False) 64 | return dataset 65 | 66 | 67 | def my_collate(batch): 68 | transposed_batch = list(zip(*batch)) 69 | return transposed_batch 70 | 71 | 72 | def main(args, config): 73 | dataset = get_dataset(config, args.split) 74 | queries_dataloader = DataLoader(dataset, num_workers=0, 75 | batch_size=1, 76 | shuffle=False, 77 | collate_fn=my_collate 78 | ) 79 | 80 | relevance_dir = os.path.join(config['dataset']['data'], config['dataset']['name'], 'relevances') 81 | if not os.path.exists(relevance_dir): 82 | os.makedirs(relevance_dir) 83 | relevance_filename = os.path.join(relevance_dir, '{}-{}-{}.npy'.format(config['dataset']['name'], args.split, args.method)) 84 | if os.path.isfile(relevance_filename): 85 | answ = input("Relevances for {} already existing in {}. Continue? (y/n)".format(args.method, relevance_filename)) 86 | if answ != 'y': 87 | quit() 88 | 89 | # filename = os.path.join(cache_dir,'d_{}.npy'.format(query_img_index)) 90 | n_queries = len(queries_dataloader) 91 | n_images = len(queries_dataloader) // 5 92 | if os.path.isfile(relevance_filename): 93 | # print('Graph distances file existing for image {}, cache {}! 
Loading...'.format(query_img_index, cache_name)) 94 | print('Loading existing file {} with shape {} x {}'.format(relevance_filename, n_queries, n_images)) 95 | npy_file = np.memmap(relevance_filename, dtype=np.float32, shape=(n_queries, n_images), mode='r+') 96 | else: 97 | print('Creating new file {} with shape {} x {}'.format(relevance_filename, n_queries, n_images)) 98 | npy_file = np.memmap(relevance_filename, dtype=np.float32, shape=(n_queries, n_images), mode='w+') 99 | npy_file[:, :] = -1 100 | 101 | # print('Computing {} distances for image {}, cache {}...'.format(n,query_img_index,cache_name)) 102 | 103 | # pbar = ProgressBar(widgets=[Percentage(), Bar(), AdaptiveETA()], maxval=n).start() 104 | print('Starting relevance computation...') 105 | with multiprocessing.Pool(processes=args.ncpus, initializer=parallel_worker_init, 106 | initargs=(npy_file, dataset, args.method)) as pool: 107 | for _ in tqdm.tqdm(pool.imap_unordered(compute_relevances_wrt_query, enumerate(queries_dataloader)), total=n_queries): 108 | pass 109 | 110 | 111 | if __name__ == '__main__': 112 | arg_parser = argparse.ArgumentParser(description='Extract captioning scores for use as relevance') 113 | arg_parser.add_argument('--config', type=str, help="Which configuration to use. See into 'config' folder") 114 | arg_parser.add_argument('--method', type=str, default="rougeL", help="Scoring method") 115 | arg_parser.add_argument('--split', type=str, default="val", help="Dataset split to use") 116 | arg_parser.add_argument('--ncpus', type=int, default=12, help="How many gpus to use") 117 | 118 | args = arg_parser.parse_args() 119 | with open(args.config, 'r') as ymlfile: 120 | config = yaml.load(ymlfile) 121 | main(args, config) 122 | -------------------------------------------------------------------------------- /models/text.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn as nn 3 | from torch.autograd import Variable 4 | from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence 5 | 6 | from models.utils import l2norm 7 | from transformers import BertTokenizer, BertModel, BertConfig 8 | 9 | 10 | def EncoderText(config): 11 | use_abs = config['training']['measure'] == 'order' 12 | num_layers = config['text-model']['layers'] 13 | order_embeddings = config['training']['measure'] == 'order' 14 | if config['text-model']['name'] == 'gru': 15 | print('Using GRU text encoder') 16 | vocab_size = config['text-model']['vocab-size'] 17 | word_dim = config['text-model']['word-dim'] 18 | embed_size = config['model']['embed-size'] 19 | model = EncoderTextGRU(vocab_size, word_dim, embed_size, num_layers, order_embeddings=order_embeddings) 20 | elif config['text-model']['name'] == 'bert': 21 | print('Using BERT text encoder') 22 | model = EncoderTextBERT(config, order_embeddings=order_embeddings, post_transformer_layers=num_layers) 23 | return model 24 | 25 | 26 | # tutorials/08 - Language Model 27 | # RNN Based Language Model 28 | class EncoderTextGRU(nn.Module): 29 | 30 | def __init__(self, vocab_size, word_dim, embed_size, num_layers, 31 | order_embeddings=False): 32 | super(EncoderTextGRU, self).__init__() 33 | self.order_embeddings = order_embeddings 34 | self.embed_size = embed_size 35 | self.vocab_size = vocab_size 36 | 37 | # word embedding 38 | self.word_embeddings = nn.Embedding(vocab_size, word_dim) 39 | 40 | # caption embedding 41 | self.rnn = nn.GRU(word_dim, embed_size, num_layers, batch_first=True) 42 | 43 | self.init_weights() 
44 | 45 | def init_weights(self): 46 | self.word_embeddings.weight.data.uniform_(-0.1, 0.1) 47 | 48 | def forward(self, x, lengths): 49 | """Handles variable size captions 50 | """ 51 | # Embed word ids to vectors 52 | x = self.word_embeddings(x) 53 | packed = pack_padded_sequence(x, lengths, batch_first=True, enforce_sorted=False) 54 | 55 | # Forward propagate RNN 56 | out, _ = self.rnn(packed) 57 | 58 | # Reshape *final* output to (batch_size, hidden_size) 59 | padded = pad_packed_sequence(out, batch_first=True) 60 | I = torch.LongTensor(lengths).view(-1, 1, 1) 61 | I = (I.expand(x.size(0), 1, self.embed_size)-1).to(x.device) 62 | out = torch.gather(padded[0], 1, I).squeeze(1) 63 | 64 | # normalization in the joint embedding space 65 | # out = l2norm(out) 66 | 67 | # take absolute value, used by order embeddings 68 | if self.order_embeddings: 69 | out = torch.abs(out) 70 | 71 | return out, padded[0] 72 | 73 | def get_finetuning_params(self): 74 | return [] 75 | 76 | 77 | class EncoderTextBERT(nn.Module): 78 | def __init__(self, config, order_embeddings=False, mean=True, post_transformer_layers=0): 79 | super().__init__() 80 | self.preextracted = config['text-model']['pre-extracted'] 81 | bert_config = BertConfig.from_pretrained(config['text-model']['pretrain'], 82 | output_hidden_states=True, 83 | num_hidden_layers=config['text-model']['extraction-hidden-layer']) 84 | bert_model = BertModel.from_pretrained(config['text-model']['pretrain'], config=bert_config) 85 | self.order_embeddings = order_embeddings 86 | self.vocab_size = bert_model.config.vocab_size 87 | self.hidden_layer = config['text-model']['extraction-hidden-layer'] 88 | if not self.preextracted: 89 | self.tokenizer = BertTokenizer.from_pretrained(config['text-model']['pretrain']) 90 | self.bert_model = bert_model 91 | self.word_embeddings = self.bert_model.get_input_embeddings() 92 | if post_transformer_layers > 0: 93 | transformer_layer = nn.TransformerEncoderLayer(d_model=config['text-model']['word-dim'], nhead=4, 94 | dim_feedforward=2048, 95 | dropout=config['text-model']['dropout'], activation='relu') 96 | self.transformer_encoder = nn.TransformerEncoder(transformer_layer, 97 | num_layers=post_transformer_layers) 98 | self.post_transformer_layers = post_transformer_layers 99 | self.map = nn.Linear(config['text-model']['word-dim'], config['model']['embed-size']) 100 | self.mean = mean 101 | 102 | def forward(self, x, lengths): 103 | ''' 104 | x: tensor of indexes (LongTensor) obtained with tokenizer.encode() of size B x ? 
105 | lengths: tensor of lengths (LongTensor) of size B 106 | ''' 107 | if not self.preextracted or self.post_transformer_layers > 0: 108 | max_len = max(lengths) 109 | attention_mask = torch.ones(x.shape[0], max_len) 110 | for e, l in zip(attention_mask, lengths): 111 | e[l:] = 0 112 | attention_mask = attention_mask.to(x.device) 113 | 114 | if self.preextracted: 115 | outputs = x 116 | else: 117 | outputs = self.bert_model(x, attention_mask=attention_mask) 118 | outputs = outputs[2][-1] 119 | 120 | if self.post_transformer_layers > 0: 121 | outputs = outputs.permute(1, 0, 2) 122 | outputs = self.transformer_encoder(outputs, src_key_padding_mask=(attention_mask - 1).bool()) 123 | outputs = outputs.permute(1, 0, 2) 124 | if self.mean: 125 | x = outputs.mean(dim=1) 126 | else: 127 | x = outputs[:, 0, :] # from the last layer take only the first word 128 | 129 | out = self.map(x) 130 | 131 | # normalization in the joint embedding space 132 | # out = l2norm(out) 133 | 134 | # take absolute value, used by order embeddings 135 | if self.order_embeddings: 136 | out = torch.abs(out) 137 | return out, outputs 138 | 139 | def get_finetuning_params(self): 140 | return list(self.bert_model.parameters()) 141 | -------------------------------------------------------------------------------- /models/utils.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import torch 4 | from torch import nn, nn as nn 5 | 6 | 7 | class PositionalEncodingText(nn.Module): 8 | 9 | def __init__(self, d_model, dropout=0.1, max_len=5000): 10 | super(PositionalEncodingText, self).__init__() 11 | self.dropout = nn.Dropout(p=dropout) 12 | 13 | pe = torch.zeros(max_len, d_model) 14 | position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1) 15 | div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model)) 16 | pe[:, 0::2] = torch.sin(position * div_term) 17 | pe[:, 1::2] = torch.cos(position * div_term) 18 | pe = pe.unsqueeze(0).transpose(0, 1) 19 | self.register_buffer('pe', pe) 20 | 21 | def forward(self, x): 22 | x = x + self.pe[:x.size(0), :] 23 | return self.dropout(x) 24 | 25 | 26 | class PositionalEncodingImageGrid(nn.Module): 27 | def __init__(self, d_model, n_regions=(4, 4)): 28 | super().__init__() 29 | assert n_regions[0] == n_regions[1] 30 | self.map = nn.Linear(2, d_model) 31 | self.n_regions = n_regions 32 | self.coord_tensor = self.build_coord_tensor(n_regions[0]) 33 | 34 | @staticmethod 35 | def build_coord_tensor(d): 36 | coords = torch.linspace(-1., 1., d) 37 | x = coords.unsqueeze(0).repeat(d, 1) 38 | y = coords.unsqueeze(1).repeat(1, d) 39 | ct = torch.stack((x, y), dim=2) 40 | if torch.cuda.is_available(): 41 | ct = ct.cuda() 42 | return ct 43 | 44 | def forward(self, x, start_token=False): # x is seq_len x B x dim 45 | assert not (start_token and self.n_regions[0] == math.sqrt(x.shape[0])) 46 | bs = x.shape[1] 47 | ct = self.coord_tensor.view(self.n_regions[0]**2, -1) # 16 x 2 48 | 49 | ct = self.map(ct).unsqueeze(1) # 16 x d_model 50 | if start_token: 51 | x[1:] = x[1:] + ct.expand(-1, bs, -1) 52 | out_grid_point = torch.FloatTensor([-1. 
- 2/self.n_regions[0], -1.]).unsqueeze(0) 53 | if torch.cuda.is_available(): 54 | out_grid_point = out_grid_point.cuda() 55 | x[0:1] = x[0:1] + self.map(out_grid_point) 56 | else: 57 | x = x + ct.expand(-1, bs, -1) 58 | return x 59 | 60 | 61 | class PositionalEncodingImageBoxes(nn.Module): 62 | def __init__(self, d_model, mode='project-and-sum'): 63 | super().__init__() 64 | self.mode = mode 65 | if mode == 'project-and-sum': 66 | self.map = nn.Linear(5, d_model) 67 | elif mode == 'concat-and-process': 68 | self.map = nn.Sequential( 69 | nn.Linear(d_model + 5, d_model), 70 | nn.ReLU(), 71 | nn.Linear(d_model, d_model) 72 | ) 73 | 74 | 75 | def forward(self, x, boxes): # x is seq_len x B x dim 76 | bs = x.shape[1] 77 | area = (boxes[:, :, 2] - boxes[:, :, 0]) * (boxes[:, :, 3] - boxes[:, :, 1]) 78 | area = area.unsqueeze(2) 79 | s_infos = torch.cat([boxes, area], dim=2) 80 | if self.mode == 'project-and-sum': 81 | ct = self.map(s_infos).permute(1, 0, 2) # S x B x dim 82 | x = x + ct.expand(-1, bs, -1) 83 | elif self.mode == 'concat-and-process': 84 | x = torch.cat([x, s_infos.permute(1, 0, 2)], dim=2) 85 | x = self.map(x) 86 | return x 87 | 88 | 89 | def l2norm(X): 90 | """L2-normalize columns of X 91 | """ 92 | norm = torch.pow(X, 2).sum(dim=1, keepdim=True).sqrt() 93 | X = torch.div(X, norm) 94 | return X 95 | 96 | 97 | class GatedAggregation(nn.Module): 98 | def __init__(self, feat_dim): 99 | super().__init__() 100 | self.gate_fn = nn.Sequential( 101 | nn.Linear(feat_dim, feat_dim), 102 | nn.ReLU(), 103 | nn.Linear(feat_dim, 1) 104 | ) 105 | self.node_fn = nn.Sequential( 106 | nn.Linear(feat_dim, feat_dim), 107 | nn.ReLU(), 108 | nn.Linear(feat_dim, feat_dim) 109 | ) 110 | 111 | def forward(self, x, mask): 112 | out = x.permute(1, 0, 2) 113 | gate = self.gate_fn(out) 114 | gate = gate.masked_fill_(mask.unsqueeze(2), - float('inf')) 115 | m = torch.sigmoid(gate) # B x S x 1 116 | v = self.node_fn(out) # B x S x dim 117 | out = torch.bmm(m.permute(0, 2, 1), v) # B x 1 x dim 118 | out = out.squeeze(1) # B x dim 119 | return out 120 | 121 | 122 | class Aggregator(nn.Module): 123 | def __init__(self, embed_size, aggregation_type='sum'): 124 | super().__init__() 125 | self.aggregation = aggregation_type 126 | if self.aggregation == 'gated': 127 | self.gated_aggr = GatedAggregation(embed_size) 128 | if self.aggregation == 'gru': 129 | self.gru_aggr = nn.GRU(embed_size, embed_size, batch_first=True) 130 | if self.aggregation == 'sum-and-map': 131 | self.map = nn.Sequential( 132 | nn.Linear(embed_size, embed_size), 133 | nn.ReLU(), 134 | nn.Linear(embed_size, embed_size) 135 | ) 136 | 137 | def forward(self, x, lengths, mask): 138 | if self.aggregation == 'first': 139 | out = x[0, :, :] 140 | elif self.aggregation == 'sum': 141 | x = x.permute(1, 0, 2) 142 | for o, c_len in zip(x, lengths): 143 | o[c_len:] = 0 144 | out = x.sum(dim=1) 145 | elif self.aggregation == 'gated': 146 | out = self.gated_aggr(x, mask) 147 | elif self.aggregation == 'gru': 148 | packed_sequence = nn.utils.rnn.pack_padded_sequence(x, lengths, batch_first=False, enforce_sorted=False) 149 | _, out = self.gru_aggr(packed_sequence) 150 | out = out.squeeze(0) 151 | elif self.aggregation == 'sum-and-map': 152 | x = x.permute(1, 0, 2) 153 | for o, c_len in zip(x, lengths): 154 | o[c_len:] = 0 155 | out = x.sum(dim=1) 156 | out = self.map(out) 157 | else: 158 | raise ValueError('Final aggregation not defined!') 159 | 160 | return out 161 | 162 | 163 | def generate_square_subsequent_mask(sz): 164 | r"""Generate a square mask for 
the sequence. The masked positions are filled with float('-inf'). 165 | Unmasked positions are filled with float(0.0). 166 | """ 167 | mask = (torch.triu(torch.ones(sz, sz)) == 1).transpose(0, 1) 168 | mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0)) 169 | return mask -------------------------------------------------------------------------------- /models/loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn as nn 3 | from torch.nn import functional as F 4 | from .utils import l2norm 5 | 6 | 7 | def dot_sim(im, s): 8 | """Cosine similarity between all the image and sentence pairs 9 | """ 10 | return im.mm(s.t()) 11 | 12 | def cosine_sim(im, s): 13 | """Cosine similarity between all the image and sentence pairs 14 | """ 15 | im = l2norm(im) 16 | s = l2norm(s) 17 | return im.mm(s.t()) 18 | 19 | def order_sim(im, s): 20 | """Order embeddings similarity measure $max(0, s-im)$ 21 | """ 22 | YmX = (s.unsqueeze(1).expand(s.size(0), im.size(0), s.size(1)) 23 | - im.unsqueeze(0).expand(s.size(0), im.size(0), s.size(1))) 24 | score = -YmX.clamp(min=0).pow(2).sum(2).sqrt().t() 25 | return score 26 | 27 | 28 | class Contrastive(nn.Module): 29 | def __init__(self, margin=0, measure=False, max_violation=False): 30 | super(Contrastive, self).__init__() 31 | self.margin = margin 32 | if measure == 'order': 33 | self.sim = order_sim 34 | elif measure == 'cosine': 35 | self.sim = cosine_sim 36 | elif measure == 'dot': 37 | self.sim = dot_sim 38 | 39 | self.max_violation = max_violation 40 | 41 | def compute_contrastive_loss(self, scores): 42 | diagonal = scores.diag().view(scores.size(0), 1) 43 | d1 = diagonal.expand_as(scores) 44 | d2 = diagonal.t().expand_as(scores) 45 | 46 | # compare every diagonal score to scores in its column 47 | # caption retrieval 48 | cost_s = (self.margin + scores - d1).clamp(min=0) 49 | # compare every diagonal score to scores in its row 50 | # image retrieval 51 | cost_im = (self.margin + scores - d2).clamp(min=0) 52 | 53 | # clear diagonals 54 | mask = torch.eye(scores.size(0)) > .5 55 | I = mask 56 | if torch.cuda.is_available(): 57 | I = I.cuda() 58 | cost_s = cost_s.masked_fill_(I, 0) 59 | cost_im = cost_im.masked_fill_(I, 0) 60 | 61 | # keep the maximum violating negative for each query 62 | if self.max_violation: 63 | cost_s = cost_s.max(1)[0] 64 | cost_im = cost_im.max(0)[0] 65 | 66 | return cost_s.sum() + cost_im.sum() 67 | 68 | 69 | class AlignmentContrastiveLoss(Contrastive): 70 | """ 71 | Compute contrastive loss 72 | """ 73 | 74 | def __init__(self, margin=0, measure=False, max_violation=False, aggregation='sum-max-sentences', return_similarity_mat=False): 75 | super(AlignmentContrastiveLoss, self).__init__(margin, measure, max_violation) 76 | self.aggregation = aggregation 77 | self.return_similarity_mat = return_similarity_mat 78 | 79 | def forward(self, im_set, s_seq, im_len, s_len): 80 | # im_set = im_set.permute(1, 0, 2) # B x S_im x dim 81 | # s_seq = s_seq.permute(1, 0, 2) # B x S_s x dim 82 | 83 | # do not consider cls and eos tokens 84 | im_set = im_set[:, 1:, :] 85 | s_seq = s_seq[:, 1:-2, :] 86 | im_len = [l - 1 for l in im_len] 87 | s_len = [l - 3 for l in s_len] 88 | 89 | im_set_batch = im_set.size(0) 90 | im_set_len = im_set.size(1) 91 | s_seq_batch = s_seq.size(0) 92 | s_seq_len = s_seq.size(1) 93 | 94 | im_set = im_set.unsqueeze(1).expand(-1, s_seq_batch, -1, -1) # B x B x S_im x dim 95 | s_seq = 
s_seq.unsqueeze(0).expand(im_set_batch, -1, -1, -1) # B x B x S_s x dim 96 | alignments = torch.matmul(im_set, s_seq.permute(0, 1, 3, 2)) # B x B x S_im x S_s 97 | # alignments = F.relu(alignments) 98 | 99 | # compute mask for the alignments tensor 100 | im_len_mask = torch.zeros(im_set_batch, im_set_len).bool() 101 | im_len_mask = im_len_mask.to(im_set.device) 102 | for im, l in zip(im_len_mask, im_len): 103 | im[l:] = True 104 | im_len_mask = im_len_mask.unsqueeze(2).unsqueeze(1).expand(-1, s_seq_batch, -1, s_seq_len) 105 | 106 | s_len_mask = torch.zeros(s_seq_batch, s_seq_len).bool() 107 | s_len_mask = s_len_mask.to(im_set.device) 108 | for sm, l in zip(s_len_mask, s_len): 109 | sm[l:] = True 110 | s_len_mask = s_len_mask.unsqueeze(1).unsqueeze(0).expand(im_set_batch, -1, im_set_len, -1) 111 | 112 | alignment_mask = im_len_mask | s_len_mask 113 | alignments.masked_fill_(alignment_mask, value=0) 114 | # alignments = F.relu(alignments) 115 | # alignments = F.normalize(alignments,p=2, dim=2) 116 | 117 | if self.aggregation == 'sum': 118 | aggr_similarity = alignments.sum(dim=(2,3)) 119 | elif self.aggregation == 'mean': 120 | aggr_similarity = alignments.mean(dim=(2,3)) 121 | elif self.aggregation == 'MrSw': 122 | aggr_similarity = alignments.max(2)[0].sum(2) 123 | elif self.aggregation == 'MrAVGw': 124 | aggr_similarity = alignments.max(2)[0].sum(2) 125 | expanded_len = torch.FloatTensor(s_len).to(alignments.device).unsqueeze(0).expand(len(im_len), -1) 126 | aggr_similarity /= expanded_len 127 | elif self.aggregation == 'symm': 128 | im = alignments.max(2)[0].sum(2) 129 | s = alignments.max(3)[0].sum(2) 130 | aggr_similarity = im + s 131 | elif self.aggregation == 'MwSr': 132 | aggr_similarity = alignments.max(3)[0].sum(2) 133 | elif self.aggregation == 'scan-sentences': 134 | norm_alignments = F.relu(alignments) 135 | norm_alignments = F.normalize(norm_alignments,p=2, dim=2) 136 | weights = norm_alignments.masked_fill(alignment_mask, value=float('-inf')) 137 | weights = torch.softmax(weights, dim=3) 138 | 139 | weights = weights.unsqueeze(3) # B x B x im x 1 x s 140 | s_seq_ext = s_seq.unsqueeze(2).expand(-1, -1, im_set_len, -1, -1) 141 | att_vector = torch.matmul(weights, s_seq_ext) # B x B x im x 1 x dim 142 | att_vector = att_vector.squeeze(3) 143 | new_alignments = F.cosine_similarity(im_set, att_vector, dim=3) # B x B x im 144 | new_alignments.masked_fill_(im_len_mask[:, :, :, 0], value=0) 145 | 146 | aggr_similarity = new_alignments.sum(2) 147 | 148 | if self.return_similarity_mat: 149 | return aggr_similarity 150 | else: 151 | loss = self.compute_contrastive_loss(aggr_similarity) 152 | return loss 153 | 154 | 155 | class ContrastiveLoss(Contrastive): 156 | """ 157 | Compute contrastive loss 158 | """ 159 | 160 | def __init__(self, margin=0, measure=False, max_violation=False): 161 | super(ContrastiveLoss, self).__init__() 162 | self.margin = margin 163 | if measure == 'order': 164 | self.sim = order_sim 165 | elif measure == 'cosine': 166 | self.sim = cosine_sim 167 | elif measure == 'dot': 168 | self.sim = dot_sim 169 | 170 | self.max_violation = max_violation 171 | 172 | def forward(self, im, s): 173 | # compute image-sentence score matrix 174 | scores = self.sim(im, s) 175 | return self.compute_contrastive_loss(scores) 176 | 177 | 178 | class PermInvMatchingLoss(nn.Module): 179 | def __init__(self): 180 | super().__init__() 181 | 182 | # @staticmethod 183 | # def batched_cosine_sim(im, s): 184 | # """Cosine similarity between all the image and sentence pairs 185 | # """ 186 | # 
im = F.normalize(im, p=2, dim=2) 187 | # s = F.normalize(s, p=2, dim=2) 188 | # return im.mm(s.permute(0, 2, 1)) 189 | 190 | def forward(self, im, s): 191 | dist_matrix = torch.cdist(im, s, p=2) 192 | row_sum = F.softmin(dist_matrix, dim=2).max(dim=2)[0].sum(dim=1) 193 | col_sum = F.softmin(dist_matrix, dim=1).max(dim=1)[0].sum(dim=1) 194 | loss = 2*torch.Tensor([dist_matrix.shape[1]]).to(im.device) - row_sum - col_sum 195 | loss = loss.mean() 196 | return loss 197 | -------------------------------------------------------------------------------- /evaluate_utils/dcg.py: -------------------------------------------------------------------------------- 1 | # (C) Mathieu Blondel, November 2013 2 | # License: BSD 3 clause 3 | 4 | import numpy as np 5 | import os 6 | 7 | class DCG: 8 | def __init__(self, config, n_queries, split, rank=25, relevance_methods=['rougeL']): 9 | self.rank = rank 10 | self.relevance_methods = relevance_methods 11 | relevance_dir = os.path.join(config['dataset']['data'], config['dataset']['name'], 'relevances') 12 | relevance_filenames = [os.path.join(relevance_dir, '{}-{}-{}.npy'.format(config['dataset']['name'], 13 | split, m)) 14 | for m in relevance_methods] 15 | self.relevances = [np.memmap(f, dtype=np.float32, mode='r') for f in relevance_filenames] 16 | for r in self.relevances: 17 | r.shape = (n_queries, -1) 18 | 19 | def compute_ndcg(self, npts, query_id, sorted_indexes, fold_index=0, retrieval='image'): 20 | sorted_indexes = sorted_indexes[:self.rank] 21 | # npts = self.relevances[0].shape[1] // 5 22 | if retrieval == 'image': 23 | query_base = npts * 5 * fold_index 24 | # sorted_indexes += npts * fold_index 25 | relevances = [r[query_base + query_id, fold_index * npts : (fold_index + 1) * npts] for r in self.relevances] 26 | elif retrieval == 'sentence': 27 | query_base = npts * fold_index 28 | # sorted_indexes += npts * 5 * fold_index 29 | relevances = [r[fold_index * npts * 5 : (fold_index + 1) * npts * 5, query_base + query_id] for r in self.relevances] 30 | 31 | ndcg_scores = [ndcg_from_ranking(r, sorted_indexes) for r in relevances] 32 | out = {k: v for k, v in zip(self.relevance_methods, ndcg_scores)} 33 | return out 34 | 35 | # def compute_dcg(self, query_id, sorted_img_indexes): 36 | # sorted_img_indexes = sorted_img_indexes[:self.rank] 37 | # dcg_score = dcg_from_ranking(self.relevances[query_id], sorted_img_indexes) 38 | # return dcg_score 39 | 40 | 41 | 42 | def ranking_precision_score(y_true, y_score, k=10): 43 | """Precision at rank k 44 | Parameters 45 | ---------- 46 | y_true : array-like, shape = [n_samples] 47 | Ground truth (true relevance labels). 48 | y_score : array-like, shape = [n_samples] 49 | Predicted scores. 50 | k : int 51 | Rank. 52 | Returns 53 | ------- 54 | precision @k : float 55 | """ 56 | unique_y = np.unique(y_true) 57 | 58 | if len(unique_y) > 2: 59 | raise ValueError("Only supported for two relevance levels.") 60 | 61 | pos_label = unique_y[1] 62 | n_pos = np.sum(y_true == pos_label) 63 | 64 | order = np.argsort(y_score)[::-1] 65 | y_true = np.take(y_true, order[:k]) 66 | n_relevant = np.sum(y_true == pos_label) 67 | 68 | # Divide by min(n_pos, k) such that the best achievable score is always 1.0. 69 | return float(n_relevant) / min(n_pos, k) 70 | 71 | 72 | def average_precision_score(y_true, y_score, k=10): 73 | """Average precision at rank k 74 | Parameters 75 | ---------- 76 | y_true : array-like, shape = [n_samples] 77 | Ground truth (true relevance labels). 
78 | y_score : array-like, shape = [n_samples] 79 | Predicted scores. 80 | k : int 81 | Rank. 82 | Returns 83 | ------- 84 | average precision @k : float 85 | """ 86 | unique_y = np.unique(y_true) 87 | 88 | if len(unique_y) > 2: 89 | raise ValueError("Only supported for two relevance levels.") 90 | 91 | pos_label = unique_y[1] 92 | n_pos = np.sum(y_true == pos_label) 93 | 94 | order = np.argsort(y_score)[::-1][:min(n_pos, k)] 95 | y_true = np.asarray(y_true)[order] 96 | 97 | score = 0 98 | for i in xrange(len(y_true)): 99 | if y_true[i] == pos_label: 100 | # Compute precision up to document i 101 | # i.e, percentage of relevant documents up to document i. 102 | prec = 0 103 | for j in xrange(0, i + 1): 104 | if y_true[j] == pos_label: 105 | prec += 1.0 106 | prec /= (i + 1.0) 107 | score += prec 108 | 109 | if n_pos == 0: 110 | return 0 111 | 112 | return score / n_pos 113 | 114 | 115 | def dcg_score(y_true, y_score, k=10, gains="exponential"): 116 | """Discounted cumulative gain (DCG) at rank k 117 | Parameters 118 | ---------- 119 | y_true : array-like, shape = [n_samples] 120 | Ground truth (true relevance labels). 121 | y_score : array-like, shape = [n_samples] 122 | Predicted scores. 123 | k : int 124 | Rank. 125 | gains : str 126 | Whether gains should be "exponential" (default) or "linear". 127 | Returns 128 | ------- 129 | DCG @k : float 130 | """ 131 | order = np.argsort(y_score)[::-1] 132 | y_true = np.take(y_true, order[:k]) 133 | 134 | if gains == "exponential": 135 | gains = 2 ** y_true - 1 136 | elif gains == "linear": 137 | gains = y_true 138 | else: 139 | raise ValueError("Invalid gains option.") 140 | 141 | # highest rank is 1 so +2 instead of +1 142 | discounts = np.log2(np.arange(len(y_true)) + 2) 143 | return np.sum(gains / discounts) 144 | 145 | 146 | def ndcg_score(y_true, y_score, k=10, gains="exponential"): 147 | """Normalized discounted cumulative gain (NDCG) at rank k 148 | Parameters 149 | ---------- 150 | y_true : array-like, shape = [n_samples] 151 | Ground truth (true relevance labels). 152 | y_score : array-like, shape = [n_samples] 153 | Predicted scores. 154 | k : int 155 | Rank. 156 | gains : str 157 | Whether gains should be "exponential" (default) or "linear". 158 | Returns 159 | ------- 160 | NDCG @k : float 161 | """ 162 | best = dcg_score(y_true, y_true, k, gains) 163 | actual = dcg_score(y_true, y_score, k, gains) 164 | return actual / best 165 | 166 | 167 | # Alternative API. 168 | 169 | def dcg_from_ranking(y_true, ranking): 170 | """Discounted cumulative gain (DCG) at rank k 171 | Parameters 172 | ---------- 173 | y_true : array-like, shape = [n_samples] 174 | Ground truth (true relevance labels). 175 | ranking : array-like, shape = [k] 176 | Document indices, i.e., 177 | ranking[0] is the index of top-ranked document, 178 | ranking[1] is the index of second-ranked document, 179 | ... 180 | k : int 181 | Rank. 182 | Returns 183 | ------- 184 | DCG @k : float 185 | """ 186 | y_true = np.asarray(y_true) 187 | ranking = np.asarray(ranking) 188 | rel = y_true[ranking] 189 | gains = 2 ** rel - 1 190 | discounts = np.log2(np.arange(len(ranking)) + 2) 191 | return np.sum(gains / discounts) 192 | 193 | 194 | def ndcg_from_ranking(y_true, ranking): 195 | """Normalized discounted cumulative gain (NDCG) at rank k 196 | Parameters 197 | ---------- 198 | y_true : array-like, shape = [n_samples] 199 | Ground truth (true relevance labels). 
200 | ranking : array-like, shape = [k] 201 | Document indices, i.e., 202 | ranking[0] is the index of top-ranked document, 203 | ranking[1] is the index of second-ranked document, 204 | ... 205 | k : int 206 | Rank. 207 | Returns 208 | ------- 209 | NDCG @k : float 210 | """ 211 | k = len(ranking) 212 | best_ranking = np.argsort(y_true)[::-1] 213 | best = dcg_from_ranking(y_true, best_ranking[:k]) 214 | if best == 0: 215 | return 0 216 | return dcg_from_ranking(y_true, ranking) / best 217 | 218 | 219 | if __name__ == '__main__': 220 | 221 | # Check that some rankings are better than others 222 | assert dcg_score([5, 3, 2], [2, 1, 0]) > dcg_score([4, 3, 2], [2, 1, 0]) 223 | assert dcg_score([4, 3, 2], [2, 1, 0]) > dcg_score([1, 3, 2], [2, 1, 0]) 224 | 225 | assert dcg_score([5, 3, 2], [2, 1, 0], k=2) > dcg_score([4, 3, 2], [2, 1, 0], k=2) 226 | assert dcg_score([4, 3, 2], [2, 1, 0], k=2) > dcg_score([1, 3, 2], [2, 1, 0], k=2) 227 | 228 | # Perfect rankings 229 | assert ndcg_score([5, 3, 2], [2, 1, 0]) == 1.0 230 | assert ndcg_score([2, 3, 5], [0, 1, 2]) == 1.0 231 | assert ndcg_from_ranking([5, 3, 2], [0, 1, 2]) == 1.0 232 | 233 | assert ndcg_score([5, 3, 2], [2, 1, 0], k=2) == 1.0 234 | assert ndcg_score([2, 3, 5], [0, 1, 2], k=2) == 1.0 235 | assert ndcg_from_ranking([5, 3, 2], [0, 1]) == 1.0 236 | 237 | # Check that sample order is irrelevant 238 | assert dcg_score([5, 3, 2], [2, 1, 0]) == dcg_score([2, 3, 5], [0, 1, 2]) 239 | 240 | assert dcg_score([5, 3, 2], [2, 1, 0], k=2) == dcg_score([2, 3, 5], [0, 1, 2], k=2) 241 | 242 | # Check equivalence between two interfaces. 243 | assert dcg_score([5, 3, 2], [2, 1, 0]) == dcg_from_ranking([5, 3, 2], [0, 1, 2]) 244 | assert dcg_score([1, 3, 2], [2, 1, 0]) == dcg_from_ranking([1, 3, 2], [0, 1, 2]) 245 | assert dcg_score([1, 3, 2], [0, 2, 1]) == dcg_from_ranking([1, 3, 2], [1, 2, 0]) 246 | assert ndcg_score([1, 3, 2], [2, 1, 0]) == ndcg_from_ranking([1, 3, 2], [0, 1, 2]) 247 | 248 | assert dcg_score([5, 3, 2], [2, 1, 0], k=2) == dcg_from_ranking([5, 3, 2], [0, 1]) 249 | assert dcg_score([1, 3, 2], [2, 1, 0], k=2) == dcg_from_ranking([1, 3, 2], [0, 1]) 250 | assert dcg_score([1, 3, 2], [0, 2, 1], k=2) == dcg_from_ranking([1, 3, 2], [1, 2]) 251 | assert ndcg_score([1, 3, 2], [2, 1, 0], k=2) == \ 252 | ndcg_from_ranking([1, 3, 2], [0, 1]) 253 | 254 | # Precision 255 | assert ranking_precision_score([1, 1, 0], [3, 2, 1], k=2) == 1.0 256 | assert ranking_precision_score([1, 1, 0], [1, 0, 0.5], k=2) == 0.5 257 | assert ranking_precision_score([1, 1, 0], [3, 2, 1], k=3) == \ 258 | ranking_precision_score([1, 1, 0], [1, 0, 0.5], k=3) 259 | 260 | # Average precision 261 | from sklearn.metrics import average_precision_score as ap 262 | assert average_precision_score([1, 1, 0], [3, 2, 1]) == ap([1, 1, 0], [3, 2, 1]) 263 | assert average_precision_score([1, 1, 0], [3, 1, 0]) == ap([1, 1, 0], [3, 1, 0]) -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 
14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. 
Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 
134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 
193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /models/teran.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.init 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | import torch.backends.cudnn as cudnn 6 | from transformers import BertTokenizer 7 | 8 | from models.loss import ContrastiveLoss, PermInvMatchingLoss, AlignmentContrastiveLoss 9 | from models.text import EncoderTextBERT, EncoderText 10 | from models.visual import TransformerPostProcessing, EncoderImage 11 | 12 | from .utils import l2norm, PositionalEncodingImageBoxes, PositionalEncodingText, Aggregator, generate_square_subsequent_mask 13 | from nltk.corpus import stopwords, words as nltk_words 14 | 15 | 16 | class JointTextImageTransformerEncoder(nn.Module): 17 | """ 18 | This is a bert caption encoder - transformer image encoder (using bottomup features). 19 | If process the encoder outputs through a transformer, like VilBERT and outputs two different graph embeddings 20 | """ 21 | def __init__(self, config): 22 | super().__init__() 23 | self.txt_enc = EncoderText(config) 24 | 25 | visual_feat_dim = config['image-model']['feat-dim'] 26 | caption_feat_dim = config['text-model']['word-dim'] 27 | dropout = config['model']['dropout'] 28 | layers = config['model']['layers'] 29 | embed_size = config['model']['embed-size'] 30 | self.order_embeddings = config['training']['measure'] == 'order' 31 | self.img_enc = EncoderImage(config) 32 | 33 | self.img_proj = nn.Linear(visual_feat_dim, embed_size) 34 | self.cap_proj = nn.Linear(caption_feat_dim, embed_size) 35 | self.embed_size = embed_size 36 | self.shared_transformer = config['model']['shared-transformer'] 37 | 38 | transformer_layer_1 = nn.TransformerEncoderLayer(d_model=embed_size, nhead=4, 39 | dim_feedforward=2048, 40 | dropout=dropout, activation='relu') 41 | self.transformer_encoder_1 = nn.TransformerEncoder(transformer_layer_1, 42 | num_layers=layers) 43 | if not self.shared_transformer: 44 | transformer_layer_2 = nn.TransformerEncoderLayer(d_model=embed_size, nhead=4, 45 | dim_feedforward=2048, 46 | dropout=dropout, activation='relu') 47 | self.transformer_encoder_2 = nn.TransformerEncoder(transformer_layer_2, 48 | num_layers=layers) 49 | self.text_aggregation = Aggregator(embed_size, aggregation_type=config['model']['text-aggregation']) 50 | self.image_aggregation = Aggregator(embed_size, aggregation_type=config['model']['image-aggregation']) 51 | self.text_aggregation_type = config['model']['text-aggregation'] 52 | self.img_aggregation_type = config['model']['image-aggregation'] 53 | 54 | def forward(self, features, captions, feat_len, cap_len, boxes): 55 | # process captions by using bert 56 | full_cap_emb_aggr, c_emb = self.txt_enc(captions, cap_len) # B x S x cap_dim 57 | 58 | # process image regions using a two-layer transformer 59 | full_img_emb_aggr, i_emb = self.img_enc(features, feat_len, boxes) # B x S x vis_dim 60 | # i_emb = i_emb.permute(1, 0, 2) # B x S x vis_dim 61 | 62 | bs = features.shape[0] 63 | 
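        # After the commented-out single-stream variant just below, captions and
        # image regions are each pushed through (optionally shared) nn.TransformerEncoder
        # stacks, using boolean key-padding masks built from cap_len / feat_len
        # (True marks padded positions).  As a sketch, the same mask that the
        # per-sample loops below produce could also be built in vectorized form, e.g.:
        #   lengths = torch.tensor(cap_len, device=features.device)
        #   mask = torch.arange(max(cap_len), device=features.device).unsqueeze(0) >= lengths.unsqueeze(1)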
64 | # if False: 65 | # # concatenate the embeddings together 66 | # max_summed_lengths = max([x + y for x, y in zip(feat_len, cap_len)]) 67 | # i_c_emb = torch.zeros(bs, max_summed_lengths, self.embed_size) 68 | # i_c_emb = i_c_emb.to(features.device) 69 | # mask = torch.zeros(bs, max_summed_lengths).bool() 70 | # mask = mask.to(features.device) 71 | # for i_c, m, i, c, i_len, c_len in zip(i_c_emb, mask, i_emb, c_emb, feat_len, cap_len): 72 | # i_c[:c_len] = c[:c_len] 73 | # i_c[c_len:c_len + i_len] = i[:i_len] 74 | # m[c_len + i_len:] = True 75 | # 76 | # i_c_emb = i_c_emb.permute(1, 0, 2) # S_vis + S_txt x B x dim 77 | # out = self.transformer_encoder(i_c_emb, src_key_padding_mask=mask) # S_vis + S_txt x B x dim 78 | # 79 | # full_cap_emb = out[0, :, :] 80 | # I = torch.LongTensor(cap_len).view(1, -1, 1) 81 | # I = I.expand(1, bs, self.embed_size).to(features.device) 82 | # full_img_emb = torch.gather(out, dim=0, index=I).squeeze(0) 83 | # else: 84 | 85 | # forward the captions 86 | if self.text_aggregation_type is not None: 87 | c_emb = self.cap_proj(c_emb) 88 | 89 | mask = torch.zeros(bs, max(cap_len)).bool() 90 | mask = mask.to(features.device) 91 | for m, c_len in zip(mask, cap_len): 92 | m[c_len:] = True 93 | full_cap_emb = self.transformer_encoder_1(c_emb.permute(1, 0, 2), src_key_padding_mask=mask) # S_txt x B x dim 94 | full_cap_emb_aggr = self.text_aggregation(full_cap_emb, cap_len, mask) 95 | # else use the embedding output by the txt model 96 | else: 97 | full_cap_emb = None 98 | 99 | # forward the regions 100 | if self.img_aggregation_type is not None: 101 | i_emb = self.img_proj(i_emb) 102 | 103 | mask = torch.zeros(bs, max(feat_len)).bool() 104 | mask = mask.to(features.device) 105 | for m, v_len in zip(mask, feat_len): 106 | m[v_len:] = True 107 | if self.shared_transformer: 108 | full_img_emb = self.transformer_encoder_1(i_emb.permute(1, 0, 2), src_key_padding_mask=mask) # S_txt x B x dim 109 | else: 110 | full_img_emb = self.transformer_encoder_2(i_emb.permute(1, 0, 2), src_key_padding_mask=mask) # S_txt x B x dim 111 | full_img_emb_aggr = self.image_aggregation(full_img_emb, feat_len, mask) 112 | else: 113 | full_img_emb = None 114 | 115 | full_cap_emb_aggr = l2norm(full_cap_emb_aggr) 116 | full_img_emb_aggr = l2norm(full_img_emb_aggr) 117 | 118 | # normalize even every vector of the set 119 | full_img_emb = F.normalize(full_img_emb, p=2, dim=2) 120 | full_cap_emb = F.normalize(full_cap_emb, p=2, dim=2) 121 | 122 | if self.order_embeddings: 123 | full_cap_emb_aggr = torch.abs(full_cap_emb_aggr) 124 | full_img_emb_aggr = torch.abs(full_img_emb_aggr) 125 | return full_img_emb_aggr, full_cap_emb_aggr, full_img_emb, full_cap_emb 126 | 127 | 128 | class TERAN(torch.nn.Module): 129 | """ 130 | rkiros/uvs model 131 | """ 132 | 133 | def __init__(self, config): 134 | # tutorials/09 - Image Captioning 135 | # Build Models 136 | super().__init__() 137 | self.img_txt_enc = JointTextImageTransformerEncoder(config) 138 | if torch.cuda.is_available(): 139 | self.img_txt_enc.cuda() 140 | cudnn.benchmark = True 141 | 142 | # Loss and Optimizer 143 | 144 | loss_type = config['training']['loss-type'] 145 | if 'alignment' in loss_type: 146 | self.alignment_criterion = AlignmentContrastiveLoss(margin=config['training']['margin'], 147 | measure=config['training']['measure'], 148 | max_violation=config['training']['max-violation'], aggregation=config['training']['alignment-mode']) 149 | if 'matching' in loss_type: 150 | self.matching_criterion = 
ContrastiveLoss(margin=config['training']['margin'], 151 | measure=config['training']['measure'], 152 | max_violation=config['training']['max-violation']) 153 | 154 | self.Eiters = 0 155 | self.config = config 156 | 157 | if 'exclude-stopwords' in config['model'] and config['model']['exclude-stopwords']: 158 | self.en_stops = set(stopwords.words('english')) 159 | self.tokenizer = BertTokenizer.from_pretrained(config['text-model']['pretrain']) 160 | else: 161 | self.tokenizer = None 162 | 163 | # def state_dict(self): 164 | # state_dict = [self.img_enc.state_dict(), self.txt_enc.state_dict()] 165 | # return state_dict 166 | # 167 | # def load_state_dict(self, state_dict): 168 | # self.img_enc.load_state_dict(state_dict[0]) 169 | # self.txt_enc.load_state_dict(state_dict[1]) 170 | # 171 | # def train_start(self): 172 | # """switch to train mode 173 | # """ 174 | # self.img_enc.train() 175 | # self.txt_enc.train() 176 | # 177 | # def val_start(self): 178 | # """switch to evaluate mode 179 | # """ 180 | # self.img_enc.eval() 181 | # self.txt_enc.eval() 182 | 183 | def forward_emb(self, images, captions, img_len, cap_len, boxes): 184 | """Compute the image and caption embeddings 185 | """ 186 | # Set mini-batch dataset 187 | if torch.cuda.is_available(): 188 | images = images.cuda() 189 | captions = captions.cuda() 190 | boxes = boxes.cuda() 191 | 192 | # Forward 193 | img_emb_aggr, cap_emb_aggr, img_feats, cap_feats = self.img_txt_enc(images, captions, img_len, cap_len, boxes) 194 | 195 | if self.tokenizer is not None: 196 | # remove stopwords 197 | # keep only word indexes that are not stopwords 198 | good_word_indexes = [[i for i, (tok, w) in enumerate(zip(self.tokenizer.convert_ids_to_tokens(ids), ids)) if 199 | tok not in self.en_stops or w == 0] for ids in captions] # keeps the padding 200 | cap_len = [len(w) - (cap_feats.shape[0] - orig_len) for w, orig_len in zip(good_word_indexes, cap_len)] 201 | min_cut_len = min([len(w) for w in good_word_indexes]) 202 | good_word_indexes = [words[:min_cut_len] for words in good_word_indexes] 203 | good_word_indexes = torch.LongTensor(good_word_indexes).to(cap_feats.device) # B x S 204 | good_word_indexes = good_word_indexes.t().unsqueeze(2).expand(-1, -1, cap_feats.shape[2]) # S x B x dim 205 | cap_feats = cap_feats.gather(dim=0, index=good_word_indexes) 206 | 207 | return img_emb_aggr, cap_emb_aggr, img_feats, cap_feats, cap_len 208 | 209 | def get_parameters(self): 210 | lr_multiplier = 1.0 if self.config['text-model']['fine-tune'] else 0.0 211 | 212 | ret = [] 213 | params = list(self.img_txt_enc.img_enc.parameters()) 214 | params += list(self.img_txt_enc.img_proj.parameters()) 215 | params += list(self.img_txt_enc.cap_proj.parameters()) 216 | params += list(self.img_txt_enc.transformer_encoder_1.parameters()) 217 | 218 | params += list(self.img_txt_enc.image_aggregation.parameters()) 219 | params += list(self.img_txt_enc.text_aggregation.parameters()) 220 | 221 | if not self.config['model']['shared-transformer']: 222 | params += list(self.img_txt_enc.transformer_encoder_2.parameters()) 223 | 224 | ret.append(params) 225 | 226 | ret.append(list(self.img_txt_enc.txt_enc.parameters())) 227 | 228 | return ret, lr_multiplier 229 | 230 | def forward_loss(self, img_emb, cap_emb, img_emb_set, cap_emb_seq, img_lengths, cap_lengths): 231 | """Compute the loss given pairs of image and caption embeddings 232 | """ 233 | # bs = img_emb.shape[0] 234 | losses = {} 235 | 236 | if 'matching' in self.config['training']['loss-type']: 237 | matching_loss = 
self.matching_criterion(img_emb, cap_emb) 238 | losses.update({'matching-loss': matching_loss}) 239 | self.logger.update('matching_loss', matching_loss.item(), img_emb.size(0)) 240 | 241 | if 'alignment' in self.config['training']['loss-type']: 242 | img_emb_set = img_emb_set.permute(1, 0, 2) 243 | cap_emb_seq = cap_emb_seq.permute(1, 0, 2) 244 | alignment_loss = self.alignment_criterion(img_emb_set, cap_emb_seq, img_lengths, cap_lengths) 245 | losses.update({'alignment-loss': alignment_loss}) 246 | self.logger.update('alignment_loss', alignment_loss.item(), img_emb_set.size(0)) 247 | 248 | # self.logger.update('Le', matching_loss.item() + alignment_loss.item(), img_emb.size(0) if img_emb is not None else img_emb_set.size(1)) 249 | return losses 250 | 251 | def forward(self, images, targets, img_lengths, cap_lengths, boxes=None, ids=None, *args): 252 | """One training step given images and captions. 253 | """ 254 | # assert self.training() 255 | self.Eiters += 1 256 | self.logger.update('Eit', self.Eiters) 257 | 258 | if type(targets) == tuple or type(targets) == list: 259 | captions, features, wembeddings = targets 260 | # captions = features # Very weird, I know 261 | text = features 262 | else: 263 | text = targets 264 | captions = targets 265 | wembeddings = self.img_txt_enc.txt_enc.word_embeddings(captions.cuda() if torch.cuda.is_available() else captions) 266 | 267 | # compute the embeddings 268 | img_emb_aggr, cap_emb_aggr, img_feats, cap_feats, cap_lengths = self.forward_emb(images, text, img_lengths, cap_lengths, boxes) 269 | # NOTE: img_feats and cap_feats are S x B x dim 270 | 271 | loss_dict = self.forward_loss(img_emb_aggr, cap_emb_aggr, img_feats, cap_feats, img_lengths, cap_lengths) 272 | return loss_dict 273 | -------------------------------------------------------------------------------- /models/visual.py: -------------------------------------------------------------------------------- 1 | from collections import OrderedDict 2 | 3 | import numpy as np 4 | import torch 5 | from torch import nn as nn 6 | from torchvision import models as models 7 | from models.utils import PositionalEncodingImageBoxes, l2norm 8 | 9 | 10 | def EncoderImage(config): 11 | 12 | # data_name, img_dim, embed_size, finetune=False, 13 | # cnn_type='vgg19', use_abs=False, no_imgnorm=False): 14 | """A wrapper to image encoders. Chooses between an encoder that uses 15 | precomputed image features, `EncoderImagePrecomp`, or an encoder that 16 | computes image features on the fly `EncoderImageFull`. 
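    In this version the choice is driven by the config: pre-extracted features
    select `EncoderImagePrecomp`, image-model name 'cnn' selects `EncoderImageFull`,
    'bottomup' selects `TransformerPostProcessing`, and 'gcn' selects
    `GCNVisualReasoning`.  A minimal usage sketch for the bottom-up case
    (shapes and variable names illustrative, not part of this function):

        img_enc = EncoderImage(config)                  # -> TransformerPostProcessing
        aggr, region_feats = img_enc(features, feat_len, boxes)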
17 | """ 18 | 19 | embed_size = config['model']['embed-size'] 20 | order_embeddings = config['training']['measure'] == 'order' 21 | no_imgnorm = not config['image-model']['norm'] 22 | if config['dataset']['pre-extracted-features']: 23 | img_dim = config['image-model']['feat-dim'] 24 | img_enc = EncoderImagePrecomp( 25 | img_dim, embed_size, order_embeddings, no_imgnorm) 26 | else: 27 | if config['image-model']['name'] == 'cnn': 28 | finetune = config['image-model']['fine-tune'] 29 | cnn_type = config['image-model']['model'] 30 | use_transformer = config['image-model']['use-transformer'] 31 | img_enc = EncoderImageFull( 32 | embed_size, finetune, cnn_type, order_embeddings, no_imgnorm, use_transformer=use_transformer) 33 | elif config['image-model']['name'] == 'bottomup': 34 | transformer_layers = config['image-model']['transformer-layers'] 35 | pos_encoding = config['image-model']['pos-encoding'] 36 | visual_feat_dim = config['image-model']['feat-dim'] 37 | dropout = config['image-model']['dropout'] 38 | img_enc = TransformerPostProcessing(transformer_layers, visual_feat_dim, embed_size, n_head=4, aggr='mean', pos_encoding=pos_encoding, dropout=dropout, order_embeddings=order_embeddings) 39 | elif config['image-model']['name'] == 'gcn': 40 | img_dim = config['image-model']['feat-dim'] 41 | img_enc = GCNVisualReasoning(img_dim, embed_size, data_name='coco', use_abs = False, no_imgnorm = False) 42 | else: 43 | img_enc = None 44 | 45 | return img_enc 46 | 47 | 48 | class TransformerPostProcessing(nn.Module): 49 | def __init__(self, num_transformer_layers, feat_dim, embed_size, n_head=4, aggr='mean', pos_encoding=None, dropout=0.1, order_embeddings=False): 50 | super().__init__() 51 | transformer_layer = nn.TransformerEncoderLayer(d_model=feat_dim, nhead=n_head, 52 | dim_feedforward=2048, 53 | dropout=dropout, activation='relu') 54 | self.transformer_encoder = nn.TransformerEncoder(transformer_layer, 55 | num_layers=num_transformer_layers) 56 | if pos_encoding is not None: 57 | self.pos_encoding_image = PositionalEncodingImageBoxes(feat_dim, pos_encoding) 58 | self.fc = nn.Linear(feat_dim, embed_size) 59 | self.aggr = aggr 60 | self.order_embeddings = order_embeddings 61 | if aggr == 'gated': 62 | self.gate_fn = nn.Sequential( 63 | nn.Linear(feat_dim, feat_dim), 64 | nn.ReLU(), 65 | nn.Linear(feat_dim, 1) 66 | ) 67 | self.node_fn = nn.Sequential( 68 | nn.Linear(feat_dim, feat_dim), 69 | nn.ReLU(), 70 | nn.Linear(feat_dim, feat_dim) 71 | ) 72 | self.pos_encoding = pos_encoding 73 | 74 | def forward(self, visual_feats, visual_feats_len=None, boxes=None): 75 | """ 76 | Takes an variable len batch of visual features and preprocess them through a transformer. Output a tensor 77 | with the same shape as visual_feats passed in input. 78 | :param visual_feats: 79 | :param visual_feats_len: 80 | :return: a tensor with the same shape as visual_feats passed in input. 
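        As called from models/teran.py with bottom-up features (an observation from
        the call sites, not a contract): visual_feats enters as B x S x feat_dim,
        is permuted to S x B x feat_dim for nn.TransformerEncoder, and the method
        returns a pair: an aggregated B x embed_size tensor (after self.fc) and the
        processed per-region features permuted back to B x S x feat_dim.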
81 | """ 82 | # max_len = max(visual_feats_len) 83 | # bs = visual_feats.shape[1] 84 | # attention_mask = torch.zeros(bs, max_len).bool() 85 | # for e, l in zip(attention_mask, visual_feats_len): 86 | # e[l:] = True 87 | # attention_mask = attention_mask.to(visual_feats.device) 88 | 89 | visual_feats = visual_feats.permute(1, 0, 2) 90 | if self.pos_encoding is not None: 91 | visual_feats = self.pos_encoding_image(visual_feats, boxes) 92 | 93 | if visual_feats_len is not None: 94 | bs = visual_feats.shape[1] 95 | # construct the attention mask 96 | max_len = max(visual_feats_len) 97 | mask = torch.zeros(bs, max_len).bool() 98 | for e, l in zip(mask, visual_feats_len): 99 | e[l:] = True 100 | mask = mask.to(visual_feats.device) 101 | else: 102 | mask = None 103 | 104 | visual_feats = self.transformer_encoder(visual_feats, src_key_padding_mask=mask) 105 | # visual_feats = visual_feats.permute(1, 0, 2) 106 | 107 | if self.aggr == 'mean': 108 | out = visual_feats.mean(dim=0) 109 | elif self.aggr == 'gated': 110 | out = visual_feats.permute(1, 0, 2) 111 | m = torch.sigmoid(self.gate_fn(out)) # B x S x 1 112 | v = self.node_fn(out) # B x S x dim 113 | out = torch.bmm(m.permute(0, 2, 1), v) # B x 1 x dim 114 | out = out.squeeze(1) # B x dim 115 | else: 116 | out = visual_feats[0] 117 | 118 | out = self.fc(out) 119 | if self.order_embeddings: 120 | out = torch.abs(out) 121 | 122 | return out, visual_feats.permute(1, 0, 2) 123 | 124 | 125 | def find_nhead(feat_dim, higher=8): 126 | # find the right n_head value (the highest value lower than 'higher') 127 | for i in reversed(range(higher + 1)): 128 | if feat_dim % i == 0: 129 | return i 130 | return 1 131 | 132 | 133 | class GCNVisualReasoning(nn.Module): 134 | 135 | def __init__(self, img_dim, embed_size, data_name, use_abs=False, no_imgnorm=False): 136 | super(GCNVisualReasoning, self).__init__() 137 | self.embed_size = embed_size 138 | self.no_imgnorm = no_imgnorm 139 | self.use_abs = use_abs 140 | self.data_name = data_name 141 | 142 | self.fc = nn.Linear(img_dim, embed_size) 143 | 144 | self.init_weights() 145 | 146 | # GSR 147 | self.img_rnn = nn.GRU(embed_size, embed_size, 1, batch_first=True) 148 | 149 | # GCN reasoning 150 | self.Rs_GCN_1 = Rs_GCN(in_channels=embed_size, inter_channels=embed_size) 151 | self.Rs_GCN_2 = Rs_GCN(in_channels=embed_size, inter_channels=embed_size) 152 | self.Rs_GCN_3 = Rs_GCN(in_channels=embed_size, inter_channels=embed_size) 153 | self.Rs_GCN_4 = Rs_GCN(in_channels=embed_size, inter_channels=embed_size) 154 | 155 | if self.data_name == 'f30k_precomp': 156 | self.bn = nn.BatchNorm1d(embed_size) 157 | 158 | def init_weights(self): 159 | """Xavier initialization for the fully connected layer 160 | """ 161 | r = np.sqrt(6.) 
/ np.sqrt(self.fc.in_features + 162 | self.fc.out_features) 163 | self.fc.weight.data.uniform_(-r, r) 164 | self.fc.bias.data.fill_(0) 165 | 166 | def forward(self, images, img_len=None, boxes=None): 167 | assert not any(np.array(img_len) - img_len[0]) 168 | """Extract image feature vectors.""" 169 | 170 | fc_img_emd = self.fc(images) 171 | if self.data_name != 'f30k_precomp': 172 | fc_img_emd = l2norm(fc_img_emd) 173 | 174 | # GCN reasoning 175 | # -> B,D,N 176 | GCN_img_emd = fc_img_emd.permute(0, 2, 1) 177 | GCN_img_emd = self.Rs_GCN_1(GCN_img_emd) 178 | GCN_img_emd = self.Rs_GCN_2(GCN_img_emd) 179 | GCN_img_emd = self.Rs_GCN_3(GCN_img_emd) 180 | GCN_img_emd = self.Rs_GCN_4(GCN_img_emd) 181 | # -> B,N,D 182 | GCN_img_emd = GCN_img_emd.permute(0, 2, 1) 183 | 184 | GCN_img_emd = l2norm(GCN_img_emd) 185 | 186 | rnn_img, hidden_state = self.img_rnn(GCN_img_emd) 187 | 188 | # features = torch.mean(rnn_img,dim=1) 189 | features = hidden_state[0] 190 | 191 | if self.data_name == 'f30k_precomp': 192 | features = self.bn(features) 193 | 194 | # normalize in the joint embedding space 195 | if not self.no_imgnorm: 196 | features = l2norm(features) 197 | 198 | # take the absolute value of embedding (used in order embeddings) 199 | if self.use_abs: 200 | features = torch.abs(features) 201 | 202 | return features, GCN_img_emd 203 | 204 | def load_state_dict(self, state_dict): 205 | """Copies parameters. overwritting the default one to 206 | accept state_dict from Full model 207 | """ 208 | own_state = self.state_dict() 209 | new_state = OrderedDict() 210 | for name, param in state_dict.items(): 211 | if name in own_state: 212 | new_state[name] = param 213 | 214 | super().load_state_dict(new_state) 215 | 216 | 217 | class EncoderImageFull(nn.Module): 218 | 219 | def __init__(self, embed_size, finetune=False, cnn_type='vgg19', 220 | use_abs=False, no_imgnorm=False, avgpool_size=(4, 4), use_transformer=False): 221 | """Load pretrained VGG19 and replace top fc layer.""" 222 | super(EncoderImageFull, self).__init__() 223 | self.embed_size = embed_size 224 | self.no_imgnorm = no_imgnorm 225 | self.use_abs = use_abs 226 | 227 | # Load a pre-trained model 228 | self.cnn = self.get_cnn(cnn_type, True) 229 | 230 | # For efficient memory usage. 
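        # requires_grad is set to the value of `finetune`: when finetune is False
        # the whole CNN backbone is frozen and no gradients are computed for its
        # parameters, so only the projection (and optional transformer) layers train.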
231 | for param in self.cnn.parameters(): 232 | param.requires_grad = finetune 233 | 234 | # Replace the last fully connected layer of CNN with a new one 235 | if cnn_type.startswith('vgg'): 236 | raise NotImplementedError 237 | self.fc = nn.Linear(self.cnn.classifier._modules['6'].in_features, 238 | embed_size) 239 | self.cnn.classifier = nn.Sequential( 240 | *list(self.cnn.classifier.children())[:-1]) 241 | elif cnn_type.startswith('resnet'): 242 | self.spatial_feats_dim = self.cnn.module.fc.in_features 243 | modules = list(self.cnn.module.children())[:-2] 244 | self.cnn = torch.nn.Sequential(*modules) 245 | self.avgpool = nn.AdaptiveAvgPool2d(avgpool_size) 246 | self.glob_avgpool = nn.AdaptiveAvgPool2d((1, 1)) 247 | self.fc = nn.Linear(self.spatial_feats_dim, embed_size) 248 | 249 | # self.cnn.module.fc = nn.Sequential() 250 | 251 | self.use_transformer = use_transformer 252 | if use_transformer: 253 | self.transformer = TransformerPostProcessing(2, self.spatial_feats_dim, embed_size, n_head=4) 254 | self.init_weights() 255 | 256 | def get_cnn(self, arch, pretrained): 257 | """Load a pretrained CNN and parallelize over GPUs 258 | """ 259 | if pretrained: 260 | print("=> using pre-trained model '{}'".format(arch)) 261 | model = models.__dict__[arch](pretrained=True) 262 | else: 263 | print("=> creating model '{}'".format(arch)) 264 | model = models.__dict__[arch]() 265 | 266 | if arch.startswith('alexnet') or arch.startswith('vgg'): 267 | model.features = nn.DataParallel(model.features) 268 | else: 269 | model = nn.DataParallel(model) 270 | 271 | if torch.cuda.is_available(): 272 | model.cuda() 273 | 274 | return model 275 | 276 | def load_state_dict(self, state_dict): 277 | """ 278 | Handle the models saved before commit pytorch/vision@989d52a 279 | """ 280 | if 'cnn.classifier.1.weight' in state_dict: 281 | state_dict['cnn.classifier.0.weight'] = state_dict[ 282 | 'cnn.classifier.1.weight'] 283 | del state_dict['cnn.classifier.1.weight'] 284 | state_dict['cnn.classifier.0.bias'] = state_dict[ 285 | 'cnn.classifier.1.bias'] 286 | del state_dict['cnn.classifier.1.bias'] 287 | state_dict['cnn.classifier.3.weight'] = state_dict[ 288 | 'cnn.classifier.4.weight'] 289 | del state_dict['cnn.classifier.4.weight'] 290 | state_dict['cnn.classifier.3.bias'] = state_dict[ 291 | 'cnn.classifier.4.bias'] 292 | del state_dict['cnn.classifier.4.bias'] 293 | 294 | super(EncoderImageFull, self).load_state_dict(state_dict) 295 | 296 | def init_weights(self): 297 | """Xavier initialization for the fully connected layer 298 | """ 299 | r = np.sqrt(6.) 
/ np.sqrt(self.fc.in_features + 300 | self.fc.out_features) 301 | self.fc.weight.data.uniform_(-r, r) 302 | self.fc.bias.data.fill_(0) 303 | 304 | def forward(self, images): 305 | """Extract image feature vectors.""" 306 | spatial_features = self.cnn(images) 307 | features = self.glob_avgpool(spatial_features) # compute a single feature 308 | spatial_features = self.avgpool(spatial_features) # fix the size of the spatial grid 309 | 310 | if not self.use_transformer: 311 | features = torch.flatten(features, 1) 312 | # normalization in the image embedding space 313 | features = l2norm(features) 314 | # linear projection to the joint embedding space 315 | features = self.fc(features) 316 | else: 317 | # transformer + fc projection to the joint embedding space 318 | features, _ = self.transformer(spatial_features.view(spatial_features.shape[0], spatial_features.shape[1], -1).permute(2, 0, 1)) 319 | 320 | # normalization in the joint embedding space 321 | if not self.no_imgnorm: 322 | features = l2norm(features) 323 | 324 | # take the absolute value of the embedding (used in order embeddings) 325 | if self.use_abs: 326 | features = torch.abs(features) 327 | 328 | return features, spatial_features 329 | 330 | def get_finetuning_params(self): 331 | return list(self.cnn.parameters()) 332 | 333 | 334 | class EncoderImagePrecomp(nn.Module): 335 | 336 | def __init__(self, img_dim, embed_size, use_abs=False, no_imgnorm=False): 337 | super(EncoderImagePrecomp, self).__init__() 338 | self.embed_size = embed_size 339 | self.no_imgnorm = no_imgnorm 340 | self.use_abs = use_abs 341 | 342 | self.fc = nn.Linear(img_dim, embed_size) 343 | 344 | self.init_weights() 345 | 346 | def init_weights(self): 347 | """Xavier initialization for the fully connected layer 348 | """ 349 | r = np.sqrt(6.) / np.sqrt(self.fc.in_features + 350 | self.fc.out_features) 351 | self.fc.weight.data.uniform_(-r, r) 352 | self.fc.bias.data.fill_(0) 353 | 354 | def forward(self, images): 355 | """Extract image feature vectors.""" 356 | # assuming that the precomputed features are already l2-normalized 357 | 358 | features = self.fc(images) 359 | 360 | # normalize in the joint embedding space 361 | if not self.no_imgnorm: 362 | features = l2norm(features) 363 | 364 | # take the absolute value of embedding (used in order embeddings) 365 | if self.use_abs: 366 | features = torch.abs(features) 367 | 368 | return features 369 | 370 | def load_state_dict(self, state_dict): 371 | """Copies parameters. 
overwritting the default one to 372 | accept state_dict from Full model 373 | """ 374 | own_state = self.state_dict() 375 | new_state = OrderedDict() 376 | for name, param in state_dict.items(): 377 | if name in own_state: 378 | new_state[name] = param 379 | 380 | super(EncoderImagePrecomp, self).load_state_dict(new_state) -------------------------------------------------------------------------------- /evaluation.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | 3 | import numpy 4 | 5 | from data import get_test_loader 6 | import time 7 | import numpy as np 8 | import torch 9 | import tqdm 10 | from collections import OrderedDict 11 | from utils import dot_sim, get_model 12 | from evaluate_utils.dcg import DCG 13 | from models.loss import order_sim, AlignmentContrastiveLoss 14 | 15 | 16 | class AverageMeter(object): 17 | """Computes and stores the average and current value""" 18 | 19 | def __init__(self): 20 | self.reset() 21 | 22 | def reset(self): 23 | self.val = 0 24 | self.avg = 0 25 | self.sum = 0 26 | self.count = 0 27 | 28 | def update(self, val, n=0): 29 | self.val = val 30 | self.sum += val * n 31 | self.count += n 32 | self.avg = self.sum / (.0001 + self.count) 33 | 34 | def __str__(self): 35 | """String representation for logging 36 | """ 37 | # for values that should be recorded exactly e.g. iteration number 38 | if self.count == 0: 39 | return str(self.val) 40 | # for stats 41 | return '%.4f (%.4f)' % (self.val, self.avg) 42 | 43 | 44 | class LogCollector(object): 45 | """A collection of logging objects that can change from train to val""" 46 | 47 | def __init__(self): 48 | # to keep the order of logged variables deterministic 49 | self.meters = OrderedDict() 50 | 51 | def update(self, k, v, n=0): 52 | # create a new meter if previously not recorded 53 | if k not in self.meters: 54 | self.meters[k] = AverageMeter() 55 | self.meters[k].update(v, n) 56 | 57 | def __str__(self): 58 | """Concatenate the meters in one log line 59 | """ 60 | s = '' 61 | for i, (k, v) in enumerate(self.meters.items()): 62 | if i > 0: 63 | s += ' ' 64 | s += k + ' ' + str(v) 65 | return s 66 | 67 | def tb_log(self, tb_logger, prefix='', step=None): 68 | """Log using tensorboard 69 | """ 70 | for k, v in self.meters.items(): 71 | tb_logger.add_scalar(prefix + k, v.val, global_step=step) 72 | 73 | 74 | def encode_data(model, data_loader, log_step=10, logging=print): 75 | """Encode all images and captions loadable by `data_loader` 76 | """ 77 | batch_time = AverageMeter() 78 | val_logger = LogCollector() 79 | 80 | # switch to evaluate mode 81 | model.eval() 82 | 83 | end = time.time() 84 | 85 | # numpy array to keep all the embeddings 86 | img_embs = None 87 | cap_embs = None 88 | img_lengths = [] 89 | cap_lengths = [] 90 | 91 | # compute maximum lenghts in the whole dataset 92 | max_cap_len = 88 93 | max_img_len = 37 94 | # for _, _, img_length, cap_length, _, _ in data_loader: 95 | # max_cap_len = max(max_cap_len, max(cap_length)) 96 | # max_img_len = max(max_img_len, max(img_length)) 97 | 98 | for i, (images, targets, img_length, cap_length, boxes, ids) in enumerate(data_loader): 99 | # make sure val logger is used 100 | model.logger = val_logger 101 | 102 | if type(targets) == tuple or type(targets) == list: 103 | captions, features, wembeddings = targets 104 | # captions = features # Very weird, I know 105 | text = features 106 | else: 107 | text = targets 108 | captions = targets 109 | wembeddings = 
model.img_txt_enc.txt_enc.word_embeddings(captions.cuda() if torch.cuda.is_available() else captions) 110 | 111 | # compute the embeddings 112 | with torch.no_grad(): 113 | _, _, img_emb, cap_emb, cap_length = model.forward_emb(images, text, img_length, cap_length, boxes) 114 | 115 | # initialize the numpy arrays given the size of the embeddings 116 | if img_embs is None: 117 | img_embs = torch.zeros((len(data_loader.dataset), max_img_len, img_emb.size(2))) 118 | cap_embs = torch.zeros((len(data_loader.dataset), max_cap_len, cap_emb.size(2))) 119 | 120 | # preserve the embeddings by copying from gpu and converting to numpy 121 | img_embs[ids, :img_emb.size(0), :] = img_emb.cpu().permute(1, 0, 2) 122 | cap_embs[ids, :cap_emb.size(0), :] = cap_emb.cpu().permute(1, 0, 2) 123 | img_lengths.extend(img_length) 124 | cap_lengths.extend(cap_length) 125 | 126 | # measure accuracy and record loss 127 | # model.forward_loss(None, None, img_emb, cap_emb, img_length, cap_length) 128 | 129 | # measure elapsed time 130 | batch_time.update(time.time() - end) 131 | end = time.time() 132 | 133 | if i % log_step == 0: 134 | logging('Test: [{0}/{1}]\t' 135 | '{e_log}\t' 136 | 'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t' 137 | .format( 138 | i, len(data_loader), batch_time=batch_time, 139 | e_log=str(model.logger))) 140 | del images, captions 141 | 142 | # p = np.random.permutation(len(data_loader.dataset) // 5) * 5 143 | # p = np.transpose(np.tile(p, (5, 1))) 144 | # p = p + np.array([0, 1, 2, 3, 4]) 145 | # p = p.flatten() 146 | # img_embs = img_embs[p] 147 | # cap_embs = cap_embs[p] 148 | 149 | return img_embs, cap_embs, img_lengths, cap_lengths 150 | 151 | 152 | def evalrank(config, checkpoint, split='dev', fold5=False): 153 | """ 154 | Evaluate a trained model on either dev or test. If `fold5=True`, 5 fold 155 | cross-validation is done (only for MSCOCO). Otherwise, the full data is 156 | used for evaluation. 
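    A minimal calling sketch (paths are placeholders, not files shipped with the repo):

        import yaml, torch
        config = yaml.safe_load(open('path/to/config.yaml'))
        checkpoint = torch.load('path/to/model_checkpoint.pth.tar')
        evalrank(config, checkpoint, split='test', fold5=True)  # fold5 is meaningful only for MSCOCO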
157 | """ 158 | # load model and options 159 | # checkpoint = torch.load(model_path) 160 | data_path = config['dataset']['data'] 161 | measure = config['training']['measure'] 162 | 163 | # construct model 164 | model = get_model(config) 165 | 166 | # load model state 167 | model.load_state_dict(checkpoint['model'], strict=False) 168 | 169 | print('Loading dataset') 170 | data_loader = get_test_loader(config, workers=4, split_name=split) 171 | 172 | # initialize ndcg scorer 173 | ndcg_val_scorer = DCG(config, len(data_loader.dataset), split, rank=25, relevance_methods=['rougeL', 'spice']) 174 | 175 | # initialize similarity matrix evaluator 176 | sim_matrix_fn = AlignmentContrastiveLoss(aggregation=config['training']['alignment-mode'], return_similarity_mat=True) if config['training']['loss-type'] == 'alignment' else None 177 | 178 | print('Computing results...') 179 | img_embs, cap_embs, img_lenghts, cap_lenghts = encode_data(model, data_loader) 180 | torch.cuda.empty_cache() 181 | 182 | # if checkpoint2 is not None: 183 | # # construct model 184 | # model2 = get_model(config2) 185 | # # load model state 186 | # model2.load_state_dict(checkpoint2['model'], strict=False) 187 | # img_embs2, cap_embs2 = encode_data(model2, data_loader) 188 | # print('Using 2-model ensemble') 189 | # else: 190 | # img_embs2, cap_embs2 = None, None 191 | # print('Using NO ensemble') 192 | 193 | print('Images: %d, Captions: %d' % 194 | (img_embs.shape[0] / 5, cap_embs.shape[0])) 195 | 196 | if not fold5: 197 | # no cross-validation, full evaluation 198 | r, rt = i2t(img_embs, cap_embs, img_lenghts, cap_lenghts, return_ranks=True, ndcg_scorer=ndcg_val_scorer, sim_function=sim_matrix_fn, cap_batches=5) 199 | ri, rti = t2i(img_embs, cap_embs, img_lenghts, cap_lenghts, return_ranks=True, ndcg_scorer=ndcg_val_scorer, sim_function=sim_matrix_fn, im_batches=5) 200 | ar = (r[0] + r[1] + r[2]) / 3 201 | ari = (ri[0] + ri[1] + ri[2]) / 3 202 | rsum = r[0] + r[1] + r[2] + ri[0] + ri[1] + ri[2] 203 | print("rsum: %.1f" % rsum) 204 | print("Average i2t Recall: %.1f" % ar) 205 | print("Image to text: %.1f %.1f %.1f %.1f %.1f, ndcg_rouge=%.4f, ndcg_spice=%.4f" % r) 206 | print("Average t2i Recall: %.1f" % ari) 207 | print("Text to image: %.1f %.1f %.1f %.1f %.1f, ndcg_rouge=%.4f, ndcg_spice=%.4f" % ri) 208 | else: 209 | # 5fold cross-validation, only for MSCOCO 210 | results = [] 211 | for i in range(5): 212 | r, rt0 = i2t(img_embs[i * 5000:(i + 1) * 5000], cap_embs[i * 5000:(i + 1) * 5000], 213 | img_lenghts[i * 5000:(i + 1) * 5000], cap_lenghts[i * 5000:(i + 1) * 5000], 214 | return_ranks=True, ndcg_scorer=ndcg_val_scorer, fold_index=i, sim_function=sim_matrix_fn, cap_batches=1) 215 | print("Image to text: %.1f, %.1f, %.1f, %.1f, %.1f, ndcg_rouge=%.4f ndcg_spice=%.4f" % r) 216 | ri, rti0 = t2i(img_embs[i * 5000:(i + 1) * 5000], cap_embs[i * 5000:(i + 1) * 5000], 217 | img_lenghts[i * 5000:(i + 1) * 5000], cap_lenghts[i * 5000:(i + 1) * 5000], 218 | return_ranks=True, ndcg_scorer=ndcg_val_scorer, fold_index=i, sim_function=sim_matrix_fn, im_batches=1) 219 | if i == 0: 220 | rt, rti = rt0, rti0 221 | print("Text to image: %.1f, %.1f, %.1f, %.1f, %.1f, ndcg_rouge=%.4f, ndcg_spice=%.4f" % ri) 222 | ar = (r[0] + r[1] + r[2]) / 3 223 | ari = (ri[0] + ri[1] + ri[2]) / 3 224 | rsum = r[0] + r[1] + r[2] + ri[0] + ri[1] + ri[2] 225 | print("rsum: %.1f ar: %.1f ari: %.1f" % (rsum, ar, ari)) 226 | results += [list(r) + list(ri) + [ar, ari, rsum]] 227 | 228 | print("-----------------------------------") 229 | print("Mean metrics: ") 230 
| mean_metrics = tuple(np.array(results).mean(axis=0).flatten()) 231 | print("rsum: %.1f" % (mean_metrics[16] * 6)) 232 | print("Average i2t Recall: %.1f" % mean_metrics[14]) 233 | print("Image to text: %.1f %.1f %.1f %.1f %.1f ndcg_rouge=%.4f ndcg_spice=%.4f" % 234 | mean_metrics[:7]) 235 | print("Average t2i Recall: %.1f" % mean_metrics[15]) 236 | print("Text to image: %.1f %.1f %.1f %.1f %.1f ndcg_rouge=%.4f ndcg_spice=%.4f" % 237 | mean_metrics[7:14]) 238 | 239 | torch.save({'rt': rt, 'rti': rti}, 'ranks.pth.tar') 240 | 241 | 242 | def i2t(images, captions, img_lenghts, cap_lenghts, npts=None, return_ranks=False, ndcg_scorer=None, fold_index=0, measure='dot', sim_function=None, cap_batches=1): 243 | """ 244 | Images->Text (Image Annotation) 245 | Images: (5N, K) matrix of images 246 | Captions: (5N, K) matrix of captions 247 | """ 248 | if npts is None: 249 | npts = images.shape[0] // 5 250 | index_list = [] 251 | 252 | ranks = numpy.zeros(npts) 253 | top1 = numpy.zeros(npts) 254 | rougel_ndcgs = numpy.zeros(npts) 255 | spice_ndcgs = numpy.zeros(npts) 256 | # captions = captions.cuda() 257 | captions_per_batch = captions.shape[0] // cap_batches 258 | 259 | for index in tqdm.trange(npts): 260 | 261 | # Get query image 262 | im = images[5 * index].reshape(1, images.shape[1], images.shape[2]) 263 | im = im.cuda() if sim_function is not None else im 264 | im_len = [img_lenghts[5 * index]] 265 | 266 | d = None 267 | 268 | # Compute scores 269 | if measure == 'order': 270 | bs = 100 271 | if index % bs == 0: 272 | mx = min(images.shape[0], 5 * (index + bs)) 273 | im2 = images[5 * index:mx:5] 274 | d2 = order_sim(torch.Tensor(im2).cuda(), 275 | torch.Tensor(captions).cuda()) 276 | d2 = d2.cpu().numpy() 277 | d = d2[index % bs] 278 | else: 279 | if sim_function is None: 280 | d = torch.mm(im[:, 0, :], captions[:, 0, :].t()) 281 | d = d.cpu().numpy().flatten() 282 | else: 283 | for i in range(cap_batches): 284 | captions_now = captions[i*captions_per_batch:(i+1)*captions_per_batch] 285 | cap_lenghts_now = cap_lenghts[i*captions_per_batch:(i+1)*captions_per_batch] 286 | captions_now = captions_now.cuda() 287 | 288 | d_align = sim_function(im, captions_now, im_len, cap_lenghts_now) 289 | d_align = d_align.cpu().numpy().flatten() 290 | # d_matching = torch.mm(im[:, 0, :], captions[:, 0, :].t()) 291 | # d_matching = d_matching.cpu().numpy().flatten() 292 | if d is None: 293 | d = d_align # + d_matching 294 | else: 295 | d = numpy.concatenate([d, d_align], axis=0) 296 | 297 | inds = numpy.argsort(d)[::-1] 298 | index_list.append(inds[0]) 299 | 300 | # Score 301 | rank = 1e20 302 | for i in range(5 * index, 5 * index + 5, 1): 303 | tmp = numpy.where(inds == i)[0][0] 304 | if tmp < rank: 305 | rank = tmp 306 | ranks[index] = rank 307 | top1[index] = inds[0] 308 | 309 | if ndcg_scorer is not None: 310 | rougel_ndcgs[index], spice_ndcgs[index] = ndcg_scorer.compute_ndcg(npts, index, inds.astype(int), 311 | fold_index=fold_index, 312 | retrieval='sentence').values() 313 | 314 | # Compute metrics 315 | r1 = 100.0 * len(numpy.where(ranks < 1)[0]) / len(ranks) 316 | r5 = 100.0 * len(numpy.where(ranks < 5)[0]) / len(ranks) 317 | r10 = 100.0 * len(numpy.where(ranks < 10)[0]) / len(ranks) 318 | medr = numpy.floor(numpy.median(ranks)) + 1 319 | meanr = ranks.mean() + 1 320 | mean_rougel_ndcg = np.mean(rougel_ndcgs[rougel_ndcgs != 0]) 321 | mean_spice_ndcg = np.mean(spice_ndcgs[spice_ndcgs != 0]) 322 | if return_ranks: 323 | return (r1, r5, r10, medr, meanr, mean_rougel_ndcg, mean_spice_ndcg), (ranks, top1) 324 
| else: 325 | return (r1, r5, r10, medr, meanr, mean_rougel_ndcg, mean_spice_ndcg) 326 | 327 | 328 | def t2i(images, captions, img_lenghts, cap_lenghts, npts=None, return_ranks=False, ndcg_scorer=None, fold_index=0, measure='dot', sim_function=None, im_batches=1): 329 | """ 330 | Text->Images (Image Search) 331 | Images: (5N, K) matrix of images 332 | Captions: (5N, K) matrix of captions 333 | """ 334 | if npts is None: 335 | npts = images.shape[0] // 5 336 | ims = torch.stack([images[i] for i in range(0, len(images), 5)], dim=0) 337 | # ims = ims.cuda() 338 | ims_len = [img_lenghts[i] for i in range(0, len(images), 5)] 339 | 340 | ranks = numpy.zeros(5 * npts) 341 | top50 = numpy.zeros((5 * npts, 50)) 342 | rougel_ndcgs = numpy.zeros(5 * npts) 343 | spice_ndcgs = numpy.zeros(5 * npts) 344 | 345 | images_per_batch = ims.shape[0] // im_batches 346 | 347 | for index in tqdm.trange(npts): 348 | 349 | # Get query captions 350 | queries = captions[5 * index:5 * index + 5] 351 | queries = queries.cuda() if sim_function is not None else queries 352 | queries_len = cap_lenghts[5 * index:5 * index + 5] 353 | 354 | d = None 355 | 356 | # Compute scores 357 | if measure == 'order': 358 | bs = 100 359 | if 5 * index % bs == 0: 360 | mx = min(captions.shape[0], 5 * index + bs) 361 | q2 = captions[5 * index:mx] 362 | d2 = order_sim(torch.Tensor(ims).cuda(), 363 | torch.Tensor(q2).cuda()) 364 | d2 = d2.cpu().numpy() 365 | 366 | d = d2[:, (5 * index) % bs:(5 * index) % bs + 5].T 367 | else: 368 | if sim_function is None: 369 | d = torch.mm(queries[:, 0, :], ims[:, 0, :].t()) 370 | d = d.cpu().numpy() 371 | else: 372 | for i in range(im_batches): 373 | ims_now = ims[i * images_per_batch:(i+1) * images_per_batch] 374 | ims_len_now = ims_len[i * images_per_batch:(i+1) * images_per_batch] 375 | ims_now = ims_now.cuda() 376 | 377 | # d = numpy.dot(queries, ims.T) 378 | d_align = sim_function(ims_now, queries, ims_len_now, queries_len).t() 379 | d_align = d_align.cpu().numpy() 380 | # d_matching = torch.mm(queries[:, 0, :], ims[:, 0, :].t()) 381 | # d_matching = d_matching.cpu().numpy() 382 | if d is None: 383 | d = d_align # + d_matching 384 | else: 385 | d = numpy.concatenate([d, d_align], axis=1) 386 | 387 | inds = numpy.zeros(d.shape) 388 | for i in range(len(inds)): 389 | inds[i] = numpy.argsort(d[i])[::-1] 390 | ranks[5 * index + i] = numpy.where(inds[i] == index)[0][ 391 | 0] # position at which the image (index) corresponding to this caption (5*index + i) is ranked 392 | top50[5 * index + i] = inds[i][0:50] 393 | # calculate ndcg 394 | if ndcg_scorer is not None: 395 | rougel_ndcgs[5 * index + i], spice_ndcgs[5 * index + i] = \ 396 | ndcg_scorer.compute_ndcg(npts, 5 * index + i, inds[i].astype(int), 397 | fold_index=fold_index, retrieval='image').values() 398 | 399 | # Compute metrics 400 | r1 = 100.0 * len(numpy.where(ranks < 1)[0]) / len(ranks) 401 | r5 = 100.0 * len(numpy.where(ranks < 5)[0]) / len(ranks) 402 | r10 = 100.0 * len(numpy.where(ranks < 10)[0]) / len(ranks) 403 | medr = numpy.floor(numpy.median(ranks)) + 1 404 | meanr = ranks.mean() + 1 405 | mean_rougel_ndcg = np.mean(rougel_ndcgs) 406 | mean_spice_ndcg = np.mean(spice_ndcgs) 407 | 408 | if return_ranks: 409 | return (r1, r5, r10, medr, meanr, mean_rougel_ndcg, mean_spice_ndcg), (ranks, top50) 410 | else: 411 | return (r1, r5, r10, medr, meanr, mean_rougel_ndcg, mean_spice_ndcg) 412 | 413 | -------------------------------------------------------------------------------- /train.py:
-------------------------------------------------------------------------------- 1 | import pickle 2 | import os 3 | import time 4 | import shutil 5 | import yaml 6 | import numpy as np 7 | 8 | import torch 9 | import pytorch_warmup as warmup 10 | 11 | import data 12 | from models.loss import AlignmentContrastiveLoss 13 | from utils import get_model, cosine_sim, dot_sim 14 | from evaluation import i2t, t2i, AverageMeter, LogCollector, encode_data 15 | from evaluate_utils.dcg import DCG 16 | 17 | import logging 18 | from torch.utils.tensorboard import SummaryWriter 19 | 20 | import argparse 21 | 22 | 23 | def main(): 24 | # Hyper Parameters 25 | parser = argparse.ArgumentParser() 26 | # parser.add_argument('--data_path', default='/w/31/faghri/vsepp_data/', 27 | # help='path to datasets') 28 | # parser.add_argument('--data_name', default='precomp', 29 | # help='{coco,f8k,f30k,10crop}_precomp|coco|f8k|f30k') 30 | parser.add_argument('--num_epochs', default=30, type=int, 31 | help='Number of training epochs.') 32 | # parser.add_argument('--crop_size', default=224, type=int, 33 | # help='Size of an image crop as the CNN input.') 34 | parser.add_argument('--lr_update', default=15, type=int, 35 | help='Number of epochs to update the learning rate.') 36 | parser.add_argument('--workers', default=10, type=int, 37 | help='Number of data loader workers.') 38 | parser.add_argument('--log_step', default=10, type=int, 39 | help='Number of steps to print and record the log.') 40 | parser.add_argument('--val_step', default=500, type=int, 41 | help='Number of steps to run validation.') 42 | parser.add_argument('--test_step', default=100000000, type=int, 43 | help='Number of steps to run validation.') 44 | parser.add_argument('--logger_name', default='runs/runX', 45 | help='Path to save the model and Tensorboard log.') 46 | parser.add_argument('--resume', default='', type=str, metavar='PATH', 47 | help='path to latest checkpoint (default: none). Loads model, optimizer, scheduler') 48 | parser.add_argument('--load-model', default='', type=str, metavar='PATH', 49 | help='path to latest checkpoint (default: none). Loads only the model') 50 | parser.add_argument('--use_restval', action='store_true', 51 | help='Use the restval data for training on MSCOCO.') 52 | parser.add_argument('--reinitialize-scheduler', action='store_true', help='Reinitialize scheduler. To use with --resume') 53 | parser.add_argument('--config', type=str, help="Which configuration to use. 
See into 'config' folder") 54 | 55 | opt = parser.parse_args() 56 | print(opt) 57 | 58 | # torch.cuda.set_enabled_lms(True) 59 | # if (torch.cuda.get_enabled_lms()): 60 | # torch.cuda.set_limit_lms(11000 * 1024 * 1024) 61 | # print('[LMS=On limit=' + str(torch.cuda.get_limit_lms()) + ']') 62 | 63 | with open(opt.config, 'r') as ymlfile: 64 | config = yaml.load(ymlfile) 65 | 66 | logging.basicConfig(format='%(asctime)s %(message)s', level=logging.INFO) 67 | tb_logger = SummaryWriter(log_dir=opt.logger_name, comment='') 68 | 69 | # Load data loaders 70 | train_loader, val_loader = data.get_loaders( 71 | config, opt.workers) 72 | # test_loader = data.get_test_loader(config, vocab=vocab, workers=4, split_name='test') 73 | 74 | # Construct the model 75 | model = get_model(config) 76 | if torch.cuda.is_available() and not (opt.resume or opt.load_model): 77 | model.cuda() 78 | 79 | assert not ((config['image-model']['fine-tune'] or config['text-model']['fine-tune']) and config['dataset']['pre-extracted-features']) 80 | # Construct the optimizer 81 | 82 | # if config['model']['name'] == 'transformthem': # TODO: handle better 83 | # params = model.parameters() 84 | # for p in model.img_txt_enc.txt_enc.parameters(): 85 | # p.requires_grad = False 86 | # else: 87 | # params = list(model.txt_enc.parameters()) 88 | # params += list(model.img_enc.fc.parameters()) 89 | # 90 | # if not config['dataset']['pre-extracted-features']: 91 | # if config['image-model']['fine-tune']: 92 | # print('Finetuning image encoder') 93 | # params += list(model.img_enc.get_finetuning_params()) 94 | # if config['text-model']['fine-tune']: 95 | # print('Finetuning text encoder') 96 | # params += list(model.txt_enc.get_finetuning_params()) 97 | 98 | params, secondary_lr_multip = model.get_parameters() 99 | # validity check 100 | all_params = params[0] + params[1] 101 | if len(all_params) != len(list(model.parameters())): 102 | raise ValueError('Not all parameters are being returned! 
Correct get_parameters() method') 103 | 104 | if secondary_lr_multip > 0: 105 | optimizer = torch.optim.Adam([{'params': params[0]}, 106 | {'params': params[1], 'lr': config['training']['lr']*secondary_lr_multip}], 107 | lr=config['training']['lr']) 108 | else: 109 | optimizer = torch.optim.Adam(params[0], lr=config['training']['lr']) 110 | 111 | # LR scheduler 112 | scheduler_name = config['training']['scheduler'] 113 | if scheduler_name == 'steplr': 114 | scheduler = torch.optim.lr_scheduler.StepLR(optimizer, config['training']['step-size'], gamma=config['training']['gamma']) 115 | elif scheduler_name is None: 116 | scheduler = None 117 | else: 118 | raise ValueError('{} scheduler is not available'.format(scheduler_name)) 119 | 120 | # Warmup scheduler 121 | warmup_scheduler_name = config['training']['warmup'] if not opt.resume else None 122 | if warmup_scheduler_name == 'linear': 123 | warmup_scheduler = warmup.LinearWarmup(optimizer, warmup_period=config['training']['warmup-period']) 124 | elif warmup_scheduler_name is None: 125 | warmup_scheduler = None 126 | else: 127 | raise ValueError('{} warmup scheduler is not available'.format(warmup_scheduler_name)) 128 | 129 | # optionally resume from a checkpoint 130 | start_epoch = 0 131 | if opt.resume or opt.load_model: 132 | filename = opt.resume if opt.resume else opt.load_model 133 | if os.path.isfile(filename): 134 | print("=> loading checkpoint '{}'".format(filename)) 135 | checkpoint = torch.load(filename, map_location='cpu') 136 | model.load_state_dict(checkpoint['model'], strict=False) 137 | if torch.cuda.is_available(): 138 | model.cuda() 139 | if opt.resume: 140 | start_epoch = checkpoint['epoch'] 141 | # best_rsum = checkpoint['best_rsum'] 142 | optimizer.load_state_dict(checkpoint['optimizer']) 143 | if checkpoint['scheduler'] is not None and not opt.reinitialize_scheduler: 144 | scheduler.load_state_dict(checkpoint['scheduler']) 145 | # Eiters is used to show logs as the continuation of another 146 | # training 147 | model.Eiters = checkpoint['Eiters'] 148 | print("=> loaded checkpoint '{}' (epoch {})" 149 | .format(opt.resume, start_epoch)) 150 | else: 151 | print("=> loaded only model from checkpoint '{}'" 152 | .format(opt.load_model)) 153 | else: 154 | print("=> no checkpoint found at '{}'".format(opt.resume)) 155 | 156 | if torch.cuda.is_available(): 157 | model.cuda() 158 | model.train() 159 | 160 | # load the ndcg scorer 161 | ndcg_val_scorer = DCG(config, len(val_loader.dataset), 'val', rank=25, relevance_methods=['rougeL', 'spice']) 162 | # ndcg_test_scorer = DCG(config, len(test_loader.dataset), 'test', rank=25, relevance_methods=['rougeL', 'spice']) 163 | 164 | # Train the Model 165 | best_rsum = 0 166 | best_ndcg_sum = 0 167 | alignment_mode = config['training']['alignment-mode'] if config['training']['loss-type'] == 'alignment' else None 168 | 169 | # validate(val_loader, model, tb_logger, measure=config['training']['measure'], log_step=opt.log_step, 170 | # ndcg_scorer=ndcg_val_scorer, alignment_mode=alignment_mode) 171 | 172 | for epoch in range(start_epoch, opt.num_epochs): 173 | # train for one epoch 174 | train(opt, train_loader, model, optimizer, epoch, tb_logger, val_loader, None, 175 | measure=config['training']['measure'], grad_clip=config['training']['grad-clip'], 176 | scheduler=scheduler, warmup_scheduler=warmup_scheduler, ndcg_val_scorer=ndcg_val_scorer, ndcg_test_scorer=None, alignment_mode=alignment_mode) 177 | 178 | # evaluate on validation set 179 | rsum, ndcg_sum = validate(val_loader, model, 
tb_logger, measure=config['training']['measure'], log_step=opt.log_step, 180 | ndcg_scorer=ndcg_val_scorer, alignment_mode=alignment_mode) 181 | 182 | # remember best R@ sum and save checkpoint 183 | is_best_rsum = rsum > best_rsum 184 | best_rsum = max(rsum, best_rsum) 185 | 186 | is_best_ndcg = ndcg_sum > best_ndcg_sum 187 | best_ndcg_sum = max(ndcg_sum, best_ndcg_sum) 188 | # 189 | # is_best_r1 = r1 > best_r1 190 | # best_r1 = max(r1, best_r1) 191 | 192 | # is_best_val_loss = val_loss < best_val_loss 193 | # best_val_loss = min(val_loss, best_val_loss) 194 | 195 | save_checkpoint({ 196 | 'epoch': epoch + 1, 197 | 'model': model.state_dict(), 198 | 'optimizer': optimizer.state_dict(), 199 | 'scheduler': scheduler.state_dict() if scheduler is not None else None, 200 | 'opt': opt, 201 | 'config': config, 202 | 'Eiters': model.Eiters, 203 | }, is_best_rsum, is_best_ndcg, prefix=opt.logger_name + '/') 204 | 205 | 206 | def train(opt, train_loader, model, optimizer, epoch, tb_logger, val_loader, test_loader, measure='cosine', grad_clip=-1, scheduler=None, warmup_scheduler=None, ndcg_val_scorer=None, ndcg_test_scorer=None, alignment_mode=None): 207 | # average meters to record the training statistics 208 | batch_time = AverageMeter() 209 | data_time = AverageMeter() 210 | train_logger = LogCollector() 211 | 212 | end = time.time() 213 | for i, train_data in enumerate(train_loader): 214 | model.train() 215 | if scheduler is not None: 216 | scheduler.step(epoch) 217 | 218 | if warmup_scheduler is not None: 219 | warmup_scheduler.dampen() 220 | 221 | optimizer.zero_grad() 222 | 223 | # measure data loading time 224 | data_time.update(time.time() - end) 225 | 226 | # make sure train logger is used 227 | model.logger = train_logger 228 | 229 | # Update the model 230 | loss_dict = model(*train_data) 231 | loss = sum(loss for loss in loss_dict.values()) 232 | 233 | # compute gradient and do SGD step 234 | loss.backward() 235 | if grad_clip > 0: 236 | torch.nn.utils.clip_grad.clip_grad_norm_(model.parameters(), grad_clip) 237 | optimizer.step() 238 | 239 | # measure elapsed time 240 | batch_time.update(time.time() - end) 241 | end = time.time() 242 | 243 | # Print log info 244 | if model.Eiters % opt.log_step == 0: 245 | logging.info( 246 | 'Epoch: [{0}][{1}/{2}]\t' 247 | '{e_log}\t' 248 | 'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t' 249 | 'Data {data_time.val:.3f} ({data_time.avg:.3f})\t' 250 | .format( 251 | epoch, i, len(train_loader), batch_time=batch_time, 252 | data_time=data_time, e_log=str(model.logger))) 253 | 254 | # Record logs in tensorboard 255 | tb_logger.add_scalar('epoch', epoch, model.Eiters) 256 | tb_logger.add_scalar('step', i, model.Eiters) 257 | tb_logger.add_scalar('batch_time', batch_time.val, model.Eiters) 258 | tb_logger.add_scalar('data_time', data_time.val, model.Eiters) 259 | tb_logger.add_scalar('lr', optimizer.param_groups[0]['lr'], model.Eiters) 260 | model.logger.tb_log(tb_logger, step=model.Eiters) 261 | 262 | # validate at every val_step 263 | if model.Eiters % opt.val_step == 0: 264 | validate(val_loader, model, tb_logger, measure=measure, log_step=opt.log_step, ndcg_scorer=ndcg_val_scorer, alignment_mode=alignment_mode) 265 | 266 | # if model.Eiters % opt.test_step == 0: 267 | # test(test_loader, model, tb_logger, measure=measure, log_step=opt.log_step, ndcg_scorer=ndcg_test_scorer) 268 | 269 | 270 | def validate(val_loader, model, tb_logger, measure='cosine', log_step=10, ndcg_scorer=None, alignment_mode=None): 271 | # compute the encoding for all the 
validation images and captions 272 | img_embs, cap_embs, img_lenghts, cap_lenghts = encode_data( 273 | model, val_loader, log_step, logging.info) 274 | 275 | # initialize similarity matrix evaluator 276 | sim_matrix_fn = AlignmentContrastiveLoss(aggregation=alignment_mode, return_similarity_mat=True) if alignment_mode is not None else None 277 | 278 | if measure == 'cosine': 279 | sim_fn = cosine_sim 280 | elif measure == 'dot': 281 | sim_fn = dot_sim 282 | 283 | # caption retrieval 284 | (r1, r5, r10, medr, meanr, mean_rougel_ndcg, mean_spice_ndcg) = i2t(img_embs, cap_embs, img_lenghts, cap_lenghts, measure=measure, ndcg_scorer=ndcg_scorer, sim_function=sim_matrix_fn) 285 | logging.info("Image to text: %.1f, %.1f, %.1f, %.1f, %.1f, ndcg_rouge=%.4f ndcg_spice=%.4f" % 286 | (r1, r5, r10, medr, meanr, mean_rougel_ndcg, mean_spice_ndcg)) 287 | # image retrieval 288 | (r1i, r5i, r10i, medri, meanr, mean_rougel_ndcg_i, mean_spice_ndcg_i) = t2i( 289 | img_embs, cap_embs, img_lenghts, cap_lenghts, ndcg_scorer=ndcg_scorer, measure=measure, sim_function=sim_matrix_fn) 290 | 291 | logging.info("Text to image: %.1f, %.1f, %.1f, %.1f, %.1f, ndcg_rouge=%.4f ndcg_spice=%.4f" % 292 | (r1i, r5i, r10i, medri, meanr, mean_rougel_ndcg_i, mean_spice_ndcg_i)) 293 | # sum of recalls to be used for early stopping 294 | currscore = r1 + r5 + r10 + r1i + r5i + r10i 295 | spice_ndcg_sum = mean_spice_ndcg + mean_spice_ndcg_i 296 | 297 | # record metrics in tensorboard 298 | tb_logger.add_scalar('r1', r1, model.Eiters) 299 | tb_logger.add_scalar('r5', r5, model.Eiters) 300 | tb_logger.add_scalar('r10', r10, model.Eiters) 301 | tb_logger.add_scalars('mean_ndcg', {'rougeL': mean_rougel_ndcg, 'spice': mean_spice_ndcg}, model.Eiters) 302 | tb_logger.add_scalar('medr', medr, model.Eiters) 303 | tb_logger.add_scalar('meanr', meanr, model.Eiters) 304 | tb_logger.add_scalar('r1i', r1i, model.Eiters) 305 | tb_logger.add_scalar('r5i', r5i, model.Eiters) 306 | tb_logger.add_scalar('r10i', r10i, model.Eiters) 307 | tb_logger.add_scalars('mean_ndcg_i', {'rougeL': mean_rougel_ndcg_i, 'spice': mean_spice_ndcg_i}, model.Eiters) 308 | tb_logger.add_scalar('medri', medri, model.Eiters) 309 | tb_logger.add_scalar('meanr', meanr, model.Eiters) 310 | tb_logger.add_scalar('rsum', currscore, model.Eiters) 311 | tb_logger.add_scalar('spice_ndcg_sum', spice_ndcg_sum, model.Eiters) 312 | 313 | return currscore, spice_ndcg_sum 314 | 315 | 316 | def test(test_loader, model, tb_logger, measure='cosine', log_step=10, ndcg_scorer=None): 317 | # compute the encoding for all the validation images and captions 318 | img_embs, cap_embs = encode_data( 319 | model, test_loader, log_step, logging.info) 320 | 321 | if measure == 'cosine': 322 | sim_fn = cosine_sim 323 | elif measure == 'dot': 324 | sim_fn = dot_sim 325 | 326 | results = [] 327 | for i in range(5): 328 | r, rt0 = i2t(img_embs[i * 5000:(i + 1) * 5000], cap_embs[i * 5000:(i + 1) * 5000], 329 | None, None, 330 | return_ranks=True, ndcg_scorer=ndcg_scorer, fold_index=i) 331 | print("Image to text: %.1f, %.1f, %.1f, %.1f, %.1f, ndcg_rouge=%.4f ndcg_spice=%.4f" % r) 332 | ri, rti0 = t2i(img_embs[i * 5000:(i + 1) * 5000], cap_embs[i * 5000:(i + 1) * 5000], 333 | None, None, 334 | return_ranks=True, ndcg_scorer=ndcg_scorer, fold_index=i) 335 | if i == 0: 336 | rt, rti = rt0, rti0 337 | print("Text to image: %.1f, %.1f, %.1f, %.1f, %.1f, ndcg_rouge=%.4f, ndcg_spice=%.4f" % ri) 338 | ar = (r[0] + r[1] + r[2]) / 3 339 | ari = (ri[0] + ri[1] + ri[2]) / 3 340 | rsum = r[0] + r[1] + r[2] + ri[0] + ri[1] 
+ ri[2] 341 | print("rsum: %.1f ar: %.1f ari: %.1f" % (rsum, ar, ari)) 342 | results += [list(r) + list(ri) + [ar, ari, rsum]] 343 | 344 | print("-----------------------------------") 345 | print("Mean metrics: ") 346 | mean_metrics = tuple(np.array(results).mean(axis=0).flatten()) 347 | print("rsum: %.1f" % (mean_metrics[16] * 6)) 348 | print("Average i2t Recall: %.1f" % mean_metrics[14]) 349 | print("Image to text: %.1f %.1f %.1f %.1f %.1f ndcg_rouge=%.4f ndcg_spice=%.4f" % 350 | mean_metrics[:7]) 351 | print("Average t2i Recall: %.1f" % mean_metrics[15]) 352 | print("Text to image: %.1f %.1f %.1f %.1f %.1f ndcg_rouge=%.4f ndcg_spice=%.4f" % 353 | mean_metrics[7:14]) 354 | 355 | # record metrics in tensorboard 356 | tb_logger.add_scalar('test/r1', mean_metrics[0], model.Eiters) 357 | tb_logger.add_scalar('test/r5', mean_metrics[1], model.Eiters) 358 | tb_logger.add_scalar('test/r10', mean_metrics[2], model.Eiters) 359 | tb_logger.add_scalars('test/mean_ndcg', {'rougeL': mean_metrics[5], 'spice': mean_metrics[6]}, model.Eiters) 360 | tb_logger.add_scalar('test/r1i', mean_metrics[7], model.Eiters) 361 | tb_logger.add_scalar('test/r5i', mean_metrics[8], model.Eiters) 362 | tb_logger.add_scalar('test/r10i', mean_metrics[9], model.Eiters) 363 | tb_logger.add_scalars('test/mean_ndcg_i', {'rougeL': mean_metrics[12], 'spice': mean_metrics[13]}, model.Eiters) 364 | 365 | 366 | def save_checkpoint(state, is_best_rsum, is_best_ndcg, filename='checkpoint.pth.tar', prefix=''): 367 | torch.save(state, prefix + filename) 368 | if is_best_rsum: 369 | shutil.copyfile(prefix + filename, prefix + 'model_best_rsum.pth.tar') 370 | if is_best_ndcg: 371 | shutil.copyfile(prefix + filename, prefix + 'model_best_ndcgspice.pth.tar') 372 | 373 | 374 | # def adjust_learning_rate(opt, optimizer, epoch): 375 | # """Sets the learning rate to the initial LR 376 | # decayed by 10 every 30 epochs""" 377 | # lr = opt.learning_rate * (0.1 ** (epoch // opt.lr_update)) 378 | # for param_group in optimizer.param_groups: 379 | # param_group['lr'] = lr 380 | 381 | 382 | def accuracy(output, target, topk=(1,)): 383 | """Computes the precision@k for the specified values of k""" 384 | maxk = max(topk) 385 | batch_size = target.size(0) 386 | 387 | _, pred = output.topk(maxk, 1, True, True) 388 | pred = pred.t() 389 | correct = pred.eq(target.view(1, -1).expand_as(pred)) 390 | 391 | res = [] 392 | for k in topk: 393 | correct_k = correct[:k].view(-1).float().sum(0) 394 | res.append(correct_k.mul_(100.0 / batch_size)) 395 | return res 396 | 397 | 398 | if __name__ == '__main__': 399 | main() 400 | -------------------------------------------------------------------------------- /data.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.utils.data as data 3 | import torchvision.transforms as transforms 4 | import os 5 | import nltk 6 | from PIL import Image 7 | from pycocotools.coco import COCO 8 | import numpy as np 9 | import json as jsonmod 10 | from collections.abc import Sequence 11 | import shelve 12 | from transformers import BertTokenizer 13 | import pickle 14 | import tqdm 15 | 16 | from features import HuggingFaceTransformerExtractor 17 | 18 | 19 | def get_paths(config): 20 | """ 21 | Returns paths to images and annotations for the given datasets. For MSCOCO 22 | indices are also returned to control the data split being used. 23 | The indices are extracted from the Karpathy et al. 
splits using this 24 | snippet: 25 | 26 | >>> import json 27 | >>> D=json.load(open('dataset_coco.json','r')) 28 | >>> A=[] 29 | >>> for i in range(len(D['images'])): 30 | ... if D['images'][i]['split'] == 'val': 31 | ... A+=D['images'][i]['sentids'][:5] 32 | ... 33 | 34 | :param name: Dataset names 35 | :param use_restval: If True, the `restval` data is included in train. 36 | """ 37 | name = config['dataset']['name'] 38 | annotations_path = os.path.join(config['dataset']['data'], name, 'annotations') 39 | use_restval = config['dataset']['restval'] 40 | 41 | roots = {} 42 | ids = {} 43 | if 'coco' == name: 44 | imgdir = config['dataset']['images-path'] 45 | capdir = annotations_path 46 | roots['train'] = { 47 | 'img': os.path.join(imgdir, 'train2014'), 48 | 'cap': os.path.join(capdir, 'captions_train2014.json') 49 | } 50 | roots['val'] = { 51 | 'img': os.path.join(imgdir, 'val2014'), 52 | 'cap': os.path.join(capdir, 'captions_val2014.json') 53 | } 54 | roots['test'] = { 55 | 'img': os.path.join(imgdir, 'val2014'), 56 | 'cap': os.path.join(capdir, 'captions_val2014.json') 57 | } 58 | roots['trainrestval'] = { 59 | 'img': (roots['train']['img'], roots['val']['img']), 60 | 'cap': (roots['train']['cap'], roots['val']['cap']) 61 | } 62 | ids['train'] = np.load(os.path.join(annotations_path, 'coco_train_ids.npy')) 63 | ids['val'] = np.load(os.path.join(annotations_path, 'coco_dev_ids.npy'))[:5000] 64 | ids['test'] = np.load(os.path.join(annotations_path, 'coco_test_ids.npy')) 65 | ids['trainrestval'] = ( 66 | ids['train'], 67 | np.load(os.path.join(annotations_path, 'coco_restval_ids.npy'))) 68 | if use_restval: 69 | roots['train'] = roots['trainrestval'] 70 | ids['train'] = ids['trainrestval'] 71 | elif 'f30k' == name: 72 | imgdir = config['dataset']['images-path'] 73 | cap = os.path.join(annotations_path, 'dataset_flickr30k.json') 74 | roots['train'] = {'img': imgdir, 'cap': cap} 75 | roots['val'] = {'img': imgdir, 'cap': cap} 76 | roots['test'] = {'img': imgdir, 'cap': cap} 77 | ids = {'train': None, 'val': None, 'test': None} 78 | 79 | return roots, ids 80 | 81 | 82 | class CocoDataset(data.Dataset): 83 | """COCO Custom Dataset compatible with torch.utils.data.DataLoader.""" 84 | 85 | def __init__(self, root, json, transform=None, ids=None, get_images=True): 86 | """ 87 | Args: 88 | root: image directory. 89 | json: coco annotation file path. 90 | transform: transformer for image.
91 | """ 92 | self.root = root 93 | self.get_images = get_images 94 | # when using `restval`, two json files are needed 95 | if isinstance(json, tuple): 96 | self.coco = (COCO(json[0]), COCO(json[1])) 97 | else: 98 | self.coco = (COCO(json),) 99 | self.root = (root,) 100 | # if ids provided by get_paths, use split-specific ids 101 | if ids is None: 102 | self.ids = list(self.coco.anns.keys()) 103 | else: 104 | self.ids = ids 105 | 106 | # if `restval` data is to be used, record the break point for ids 107 | if isinstance(self.ids, tuple): 108 | self.bp = len(self.ids[0]) 109 | self.ids = list(self.ids[0]) + list(self.ids[1]) 110 | else: 111 | self.bp = len(self.ids) 112 | self.transform = transform 113 | 114 | def __getitem__(self, index): 115 | """This function returns a tuple that is further passed to collate_fn 116 | """ 117 | root, caption, img_id, path, image, _ = self.get_raw_item(index, self.get_images) 118 | 119 | if self.transform is not None: 120 | image = self.transform(image) 121 | 122 | target = caption 123 | return image, target, index, img_id 124 | 125 | def get_raw_item(self, index, load_image=True): 126 | if index < self.bp: 127 | coco = self.coco[0] 128 | root = self.root[0] 129 | else: 130 | coco = self.coco[1] 131 | root = self.root[1] 132 | ann_id = self.ids[index] 133 | caption = coco.anns[ann_id]['caption'] 134 | img_id = coco.anns[ann_id]['image_id'] 135 | img = coco.imgs[img_id] 136 | img_size = np.array([img['width'], img['height']]) 137 | if load_image: 138 | path = coco.loadImgs(img_id)[0]['file_name'] 139 | image = Image.open(os.path.join(root, path)).convert('RGB') 140 | 141 | return root, caption, img_id, path, image, img_size 142 | else: 143 | return root, caption, img_id, None, None, img_size 144 | 145 | def __len__(self): 146 | return len(self.ids) 147 | 148 | 149 | class BottomUpFeaturesDataset: 150 | def __init__(self, root, json, features_path, split, ids=None, **kwargs): 151 | # which dataset? 
152 | r = root[0] if type(root) == tuple else root 153 | r = r.lower() 154 | if 'coco' in r: 155 | self.underlying_dataset = CocoDataset(root, json, ids=ids) 156 | elif 'f30k' in r or 'flickr30k' in r: 157 | self.underlying_dataset = FlickrDataset(root, json, split) 158 | 159 | # data_path = config['image-model']['data-path'] 160 | self.feats_data_path = os.path.join(features_path, 'bu_att') 161 | self.box_data_path = os.path.join(features_path, 'bu_box') 162 | config = kwargs['config'] 163 | self.load_preextracted = config['text-model']['pre-extracted'] 164 | if self.load_preextracted: 165 | # TODO: handle different types of preextracted features, not only BERT 166 | text_extractor = HuggingFaceTransformerExtractor(config, split, finetuned=config['text-model']['fine-tune']) 167 | self.text_features_db = FeatureSequence(text_extractor) 168 | 169 | def __getitem__(self, index): 170 | """This function returns a tuple that is further passed to collate_fn 171 | """ 172 | root, caption, img_id, _, _, img_size = self.underlying_dataset.get_raw_item(index, load_image=False) 173 | img_feat_path = os.path.join(self.feats_data_path, '{}.npz'.format(img_id)) 174 | img_box_path = os.path.join(self.box_data_path, '{}.npy'.format(img_id)) 175 | 176 | img_feat = np.load(img_feat_path)['feat'] 177 | img_boxes = np.load(img_box_path) 178 | 179 | # normalize boxes 180 | img_boxes = img_boxes / np.tile(img_size, 2) 181 | 182 | img_feat = torch.Tensor(img_feat) 183 | img_boxes = torch.Tensor(img_boxes) 184 | 185 | if self.load_preextracted: 186 | record = self.text_features_db[index] 187 | features = record['features'] 188 | captions = record['captions'] 189 | wembeddings = record['wembeddings'] 190 | target = (captions, features, wembeddings) 191 | else: 192 | target = caption 193 | # image = (img_feat, img_boxes) 194 | return img_feat, img_boxes, target, index, img_id 195 | 196 | def __len__(self): 197 | return len(self.underlying_dataset) 198 | 199 | 200 | class FlickrDataset(data.Dataset): 201 | """ 202 | Dataset loader for Flickr30k and Flickr8k full datasets. 203 | """ 204 | 205 | def __init__(self, root, json, split, transform=None, get_images=True): 206 | self.root = root 207 | self.split = split 208 | self.get_images = get_images 209 | self.transform = transform 210 | self.dataset = jsonmod.load(open(json, 'r'))['images'] 211 | self.ids = [] 212 | for i, d in enumerate(self.dataset): 213 | if d['split'] == split: 214 | self.ids += [(i, x) for x in range(len(d['sentences']))] 215 | 216 | # dump flickr images sizes on files for later use 217 | size_file = os.path.join(root, 'sizes.pkl') 218 | if os.path.isfile(size_file): 219 | # load it 220 | with open(size_file, 'rb') as f: 221 | self.sizes = pickle.load(f) 222 | else: 223 | # build it 224 | sizes = [] 225 | for im in tqdm.tqdm(self.dataset): 226 | path = im['filename'] 227 | image = Image.open(os.path.join(root, path)) 228 | sizes.append(image.size) 229 | 230 | with open(size_file, 'wb') as f: 231 | pickle.dump(sizes, f) 232 | self.sizes = sizes 233 | 234 | def __getitem__(self, index): 235 | """This function returns a tuple that is further passed to collate_fn 236 | """ 237 | root, caption, img_id, path, image, _ = self.get_raw_item(index, self.get_images) 238 | if self.transform is not None: 239 | image = self.transform(image) 240 | 241 | # Convert caption (string) to word ids. 
242 | target = caption 243 | return image, target, index, img_id 244 | 245 | def get_raw_item(self, index, load_image=True): 246 | root = self.root 247 | ann_id = self.ids[index] 248 | img_id = ann_id[0] 249 | caption = self.dataset[img_id]['sentences'][ann_id[1]]['raw'] 250 | img_size = self.sizes[img_id] 251 | 252 | if load_image: 253 | path = self.dataset[img_id]['filename'] 254 | image = Image.open(os.path.join(root, path)).convert('RGB') 255 | return root, caption, img_id, path, image, img_size 256 | else: 257 | return root, caption, img_id, None, None, img_size 258 | 259 | 260 | 261 | def __len__(self): 262 | return len(self.ids) 263 | 264 | 265 | class Collate: 266 | def __init__(self, config): 267 | self.vocab_type = config['text-model']['name'] 268 | if self.vocab_type == 'bert': 269 | self.tokenizer = BertTokenizer.from_pretrained(config['text-model']['pretrain']) 270 | 271 | def __call__(self, data): 272 | """Build mini-batch tensors from a list of (image, caption) tuples. 273 | Args: 274 | data: list of (image, caption) tuple. 275 | - image: torch tensor of shape (3, 256, 256) or (? > 3, 2048) 276 | - caption: torch tensor of shape (?); variable length. 277 | 278 | Returns: 279 | images: torch tensor of shape (batch_size, 3, 256, 256). 280 | targets: torch tensor of shape (batch_size, padded_length). 281 | lengths: list; valid length for each padded caption. 282 | """ 283 | # Sort a data list by caption length 284 | # data.sort(key=lambda x: len(x[1]), reverse=True) 285 | if len(data[0]) == 5: # TODO: find a better way to distinguish the two 286 | images, boxes, captions, ids, img_ids = zip(*data) 287 | elif len(data[0]) == 4: 288 | images, captions, ids, img_ids = zip(*data) 289 | 290 | preextracted_captions = type(captions[0]) is tuple 291 | if preextracted_captions: 292 | # they are pre-extracted features 293 | captions, cap_features, wembeddings = zip(*captions) 294 | cap_lengths = [len(cap) for cap in cap_features] 295 | captions = [torch.LongTensor(c) for c in captions] 296 | cap_features = [torch.FloatTensor(f) for f in cap_features] 297 | wembeddings = [torch.FloatTensor(w) for w in wembeddings] 298 | else: 299 | if self.vocab_type == 'bert': 300 | cap_lengths = [len(self.tokenizer.tokenize(c)) + 2 for c in 301 | captions] # + 2 in order to account for begin and end tokens 302 | max_len = max(cap_lengths) 303 | captions_ids = [torch.LongTensor(self.tokenizer.encode(c, max_length=max_len, pad_to_max_length=True)) 304 | for c in captions] 305 | 306 | captions = captions_ids 307 | # Merge images (convert tuple of 3D tensor to 4D tensor) 308 | preextracted_images = not (images[0].shape[0] == 3) 309 | if not preextracted_images: 310 | # they are images 311 | images = torch.stack(images, 0) 312 | else: 313 | # they are image features, variable length 314 | feat_lengths = [f.shape[0] + 1 for f in images] # +1 because the first region feature is reserved as CLS 315 | feat_dim = images[0].shape[1] 316 | img_features = torch.zeros(len(images), max(feat_lengths), feat_dim) 317 | for i, img in enumerate(images): 318 | end = feat_lengths[i] 319 | img_features[i, 1:end] = img 320 | 321 | box_lengths = [b.shape[0] + 1 for b in boxes] # +1 because the first region feature is reserved as CLS 322 | assert box_lengths == feat_lengths 323 | out_boxes = torch.zeros(len(boxes), max(box_lengths), 4) 324 | for i, box in enumerate(boxes): 325 | end = box_lengths[i] 326 | out_boxes[i, 1:end] = box 327 | 328 | # Merget captions (convert tuple of 1D tensor to 2D tensor) 329 | if 
preextracted_captions: 330 | captions_t = torch.zeros(len(captions), max(cap_lengths)).long() 331 | features_t = torch.zeros(len(cap_features), max(cap_lengths), cap_features[0].shape[1]) 332 | wembeddings_t = torch.zeros(len(wembeddings), max(cap_lengths), wembeddings[0].shape[1]) 333 | for i, (cap, feats, wembs, l) in enumerate(zip(captions, cap_features, wembeddings, cap_lengths)): 334 | captions_t[i, :l] = cap[:l] 335 | features_t[i, :l] = feats[:l] 336 | wembeddings_t[i, :l] = wembs[:l] 337 | targets = (captions_t, features_t, wembeddings_t) 338 | else: 339 | targets = torch.zeros(len(captions), max(cap_lengths)).long() 340 | for i, cap in enumerate(captions): 341 | end = cap_lengths[i] 342 | targets[i, :end] = cap[:end] 343 | 344 | if not preextracted_images: 345 | return images, targets, None, cap_lengths, None, ids 346 | else: 347 | # features = features.permute(0, 2, 1) 348 | return img_features, targets, feat_lengths, cap_lengths, out_boxes, ids 349 | 350 | 351 | def get_loader_single(data_name, split, root, json, transform, preextracted_root=None, 352 | batch_size=100, shuffle=True, 353 | num_workers=2, ids=None, collate_fn=None, **kwargs): 354 | """Returns torch.utils.data.DataLoader for custom coco dataset.""" 355 | if 'coco' in data_name: 356 | if preextracted_root is not None: 357 | dataset = BottomUpFeaturesDataset(root=root, 358 | json=json, 359 | features_path=preextracted_root, split=split, 360 | ids=ids, **kwargs) 361 | else: 362 | # COCO custom dataset 363 | dataset = CocoDataset(root=root, 364 | json=json, 365 | transform=transform, ids=ids) 366 | elif 'f8k' in data_name or 'f30k' in data_name: 367 | if preextracted_root is not None: 368 | dataset = BottomUpFeaturesDataset(root=root, 369 | json=json, 370 | features_path=preextracted_root, split=split, 371 | ids=ids, **kwargs) 372 | else: 373 | dataset = FlickrDataset(root=root, 374 | split=split, 375 | json=json, 376 | transform=transform) 377 | 378 | # Data loader 379 | data_loader = torch.utils.data.DataLoader(dataset=dataset, 380 | batch_size=batch_size, 381 | shuffle=shuffle, 382 | pin_memory=True, 383 | num_workers=num_workers, 384 | collate_fn=collate_fn) 385 | return data_loader 386 | 387 | 388 | def get_transform(data_name, split_name, config): 389 | normalizer = transforms.Normalize(mean=[0.485, 0.456, 0.406], 390 | std=[0.229, 0.224, 0.225]) 391 | t_list = [] 392 | # if split_name == 'train': 393 | # t_list = [transforms.RandomResizedCrop(config['image-model']['crop-size']), 394 | # transforms.RandomHorizontalFlip()] 395 | # elif split_name == 'val': 396 | # t_list = [transforms.Resize(256), transforms.CenterCrop(224)] 397 | # elif split_name == 'test': 398 | # t_list = [transforms.Resize(256), transforms.CenterCrop(224)] 399 | 400 | t_end = [transforms.ToTensor(), normalizer] 401 | transform = transforms.Compose(t_list + t_end) 402 | return transform 403 | 404 | 405 | def get_loaders(config, workers, batch_size=None): 406 | data_name = config['dataset']['name'] 407 | if batch_size is None: 408 | batch_size = config['training']['bs'] 409 | collate_fn = Collate(config) 410 | roots, ids = get_paths(config) 411 | 412 | transform = get_transform(data_name, 'train', config) 413 | preextracted_root = config['image-model']['pre-extracted-features-root'] \ 414 | if 'pre-extracted-features-root' in config['image-model'] else None 415 | 416 | train_loader = get_loader_single(data_name, 'train', 417 | roots['train']['img'], 418 | roots['train']['cap'], 419 | transform, ids=ids['train'], 420 | 
preextracted_root=preextracted_root, 421 | batch_size=batch_size, shuffle=True, 422 | num_workers=workers, 423 | collate_fn=collate_fn, config=config) 424 | 425 | transform = get_transform(data_name, 'val', config) 426 | val_loader = get_loader_single(data_name, 'val', 427 | roots['val']['img'], 428 | roots['val']['cap'], 429 | transform, ids=ids['val'], 430 | preextracted_root=preextracted_root, 431 | batch_size=batch_size, shuffle=False, 432 | num_workers=workers, 433 | collate_fn=collate_fn, config=config) 434 | 435 | return train_loader, val_loader 436 | 437 | 438 | def get_test_loader(config, workers, split_name='test', batch_size=None): 439 | data_name = config['dataset']['name'] 440 | if batch_size is None: 441 | batch_size = config['training']['bs'] 442 | collate_fn = Collate(config) 443 | # Build Dataset Loader 444 | roots, ids = get_paths(config) 445 | 446 | preextracted_root = config['image-model']['pre-extracted-features-root'] \ 447 | if 'pre-extracted-features-root' in config['image-model'] else None 448 | 449 | transform = get_transform(data_name, split_name, config) 450 | test_loader = get_loader_single(data_name, split_name, 451 | roots[split_name]['img'], 452 | roots[split_name]['cap'], 453 | transform, ids=ids[split_name], 454 | preextracted_root=preextracted_root, 455 | batch_size=batch_size, shuffle=False, 456 | num_workers=workers, 457 | collate_fn=collate_fn, config=config) 458 | return test_loader 459 | -------------------------------------------------------------------------------- /features.py: -------------------------------------------------------------------------------- 1 | # This script extract features and put them in shelve format 2 | 3 | import os 4 | import torch 5 | import tqdm 6 | import argparse 7 | import yaml 8 | import re 9 | import itertools 10 | import pickle 11 | import numpy as np 12 | from torch.utils.data import DataLoader 13 | # from graphrcnn.extract_features import extract_visual_features 14 | # from torchvision.datasets.coco import CocoCaptions 15 | # from datasets import CocoCaptionsOnly 16 | from torchvision import transforms 17 | from torchvision.models import resnet18, resnet50, resnet101, resnet152, vgg11_bn, vgg13_bn, vgg16_bn, vgg19_bn 18 | from transformers import BertTokenizer, BertModel 19 | # from datasets import TextCollator 20 | import shelve 21 | import data 22 | from models.text import EncoderTextBERT 23 | 24 | 25 | class TextCollator(object): 26 | """ 27 | From a list of samples from the dataset, 28 | returns the batched captions. 29 | This should be passed to the DataLoader 30 | """ 31 | 32 | def __call__(self, batch): 33 | transposed_batch = list(zip(*batch)) 34 | # images = transposed_batch[0] 35 | captions = transposed_batch[1] 36 | return captions 37 | 38 | 39 | class FeatureExtractor(object): 40 | def __init__(self, config, split, bs=1, collate_fn=torch.utils.data.dataloader.default_collate): 41 | self.config = config 42 | self.split = split 43 | self.output_feat_fld = os.path.join(config['dataset']['data'], '{}_precomp'.format(config['dataset']['name'])) 44 | if not os.path.exists(self.output_feat_fld): 45 | os.makedirs(self.output_feat_fld) 46 | 47 | def extract(self): 48 | """ 49 | Extracts features and dump them on a db file. 
50 | For text extractors: each file record is a dictionary with keys: 51 | 'image_id' (int) and 'features' (np.array K x dim) 52 | For image extractors: each file record is a dictionary with keys: 53 | 'boxes' (np.array K x 4), 'scores' (np.array K x 1), 'features' (np.array K x dim) 54 | :return: void 55 | """ 56 | raise NotImplementedError 57 | 58 | def get_db_file(self): 59 | """ 60 | :return: the path to the db file for these features 61 | """ 62 | raise NotImplementedError 63 | 64 | 65 | class HuggingFaceTransformerExtractor(FeatureExtractor): 66 | def __init__(self, config, split, model_name='bert', pretrained='bert-base-uncased', finetuned=None): 67 | super(HuggingFaceTransformerExtractor, self).__init__(config, split, bs=5, collate_fn=TextCollator()) 68 | self.pretrained = pretrained 69 | self.finetuned = finetuned 70 | self.model_name = model_name 71 | self.config = config 72 | 73 | roots, ids = data.get_paths(config) 74 | 75 | data_name = config['dataset']['name'] 76 | transform = data.get_transform(data_name, 'val', config) 77 | collate_fn = data.Collate(config) 78 | self.loader = data.get_loader_single(data_name, split, 79 | roots[split]['img'], 80 | roots[split]['cap'], 81 | transform, ids=ids[split], 82 | batch_size=32, shuffle=False, 83 | num_workers=4, collate_fn=collate_fn) 84 | 85 | def get_db_file(self): 86 | finetuned_str = "" if not self.finetuned else '_finetuned' 87 | feat_dst_filename = os.path.join(self.output_feat_fld, 88 | '{}_{}_{}{}.db'.format(self.split, self.model_name, self.pretrained, finetuned_str)) 89 | print('Hugging Face BERT features filename: {}'.format(feat_dst_filename)) 90 | return feat_dst_filename 91 | 92 | def extract(self, device='cuda'): 93 | # Load pre-trained model tokenizer (vocabulary) and model itself 94 | if self.model_name == 'bert': 95 | self.config['text-model']['layers'] = 0 96 | self.config['text-model']['pre-extracted'] = False 97 | model = EncoderTextBERT(self.config) 98 | else: 99 | raise ValueError('{} model is not known'.format(self.model_name)) 100 | 101 | if self.finetuned: 102 | if torch.cuda.is_available(): 103 | device = torch.device('cuda') 104 | else: 105 | device = torch.device('cpu') 106 | checkpoint = torch.load(self.finetuned, map_location=device)['model'] 107 | checkpoint = {k[k.find('.txt_enc.'):].replace('.txt_enc.', ''): v for k, v in checkpoint.items() if '.txt_enc.'
in k} 108 | model.load_state_dict(checkpoint, strict=False) 109 | print('BERT model extracted from trained model at {}'.format(self.finetuned)) 110 | 111 | model.to(device) 112 | model.eval() 113 | 114 | feat_dst_filename = self.get_db_file() 115 | prog_id = 0 116 | with shelve.open(feat_dst_filename, flag='n') as db: 117 | for images, captions, img_lengths, cap_lengths, boxes, ids in tqdm.tqdm(self.loader): 118 | captions = captions.cuda() 119 | 120 | with torch.no_grad(): 121 | _, feats = model(captions, cap_lengths) 122 | # get the features from the last hidden state 123 | feats = feats.cpu().numpy() 124 | word_embs = model.word_embeddings(captions) 125 | word_embs = word_embs.cpu().numpy() 126 | for c, f, w, l, i in zip(captions.cpu().numpy(), feats, word_embs, cap_lengths, ids): 127 | # dump_feats.append(f[:l]) 128 | dump_dict = {'image_id': i, 'captions': c, 'features': f[:l], 'wembeddings': w[:l]} 129 | db[str(prog_id)] = dump_dict 130 | prog_id += 1 131 | 132 | 133 | class TextWordIndexesExtractor(FeatureExtractor): 134 | def __init__(self, dataset, root, split): 135 | super().__init__(dataset, root, split) 136 | 137 | def get_db_file(self): 138 | feat_dst_filename = os.path.join(self.output_feat_fld, 139 | '{}_{}_word_indexes.db'.format(self.dataset, self.split)) 140 | return feat_dst_filename 141 | 142 | def extract(self, device='cuda'): 143 | if self.dataset == 'coco': 144 | dataset_root = os.path.join(self.root, '{}2014'.format(self.split)) 145 | dataset_json = os.path.join(self.root, 146 | 'stanford_split_annots', 'captions_{}2014.json'.format(self.split)) 147 | dataset = CocoCaptionsOnly(dataset_root, dataset_json, indexing='images') 148 | 149 | dataloader = DataLoader(dataset, 150 | num_workers=4, 151 | batch_size=1, 152 | shuffle=False, 153 | ) 154 | else: 155 | raise ValueError('{} dataset is not implemented!'.format(self.dataset)) 156 | 157 | # Build dictionary 158 | dict_file_path = os.path.join(self.output_feat_fld, 'word_dict_{}.pkl'.format(self.dataset)) 159 | if not os.path.isfile(dict_file_path): 160 | if not self.split == 'train': 161 | raise ValueError('Dictionary should be built on the train set. 
Rerun with split=train') 162 | else: 163 | print('Building dictionary ...') 164 | wdict = {} 165 | wfreq = {} 166 | counter = 2 # 1 is the unknown label 167 | for i, captions in enumerate(tqdm.tqdm(dataloader)): 168 | captions = [c[0] for c in captions] 169 | tokenized_captions = [re.sub('[!#?,.:";]', '', c).strip().replace("'", ' ').lower().split(' ') for c in captions] 170 | words = itertools.chain.from_iterable(tokenized_captions) 171 | for w in words: 172 | # create dict 173 | if w not in wdict: 174 | wdict[w] = counter 175 | counter += 1 176 | 177 | # handle word frequencies 178 | if w not in wfreq: 179 | wfreq[w] = 1 180 | else: 181 | wfreq[w] += 1 182 | 183 | print('Filtering dictionary ...') 184 | # Filter dictionary based on frequencies 185 | for w, f in wfreq.items(): 186 | if f == 1: 187 | wdict[w] = 1 # 1 is the unknown label 188 | 189 | with open(dict_file_path, 'wb') as f: 190 | pickle.dump(wdict, f) 191 | else: 192 | print('Loading dict from {}'.format(dict_file_path)) 193 | with open(dict_file_path, 'rb') as f: 194 | wdict = pickle.load(f) 195 | 196 | feat_dst_filename = self.get_db_file() 197 | prog_id = 0 198 | with shelve.open(feat_dst_filename, flag='n') as db: 199 | for i, captions in enumerate(tqdm.tqdm(dataloader)): 200 | captions = [c[0] for c in captions] 201 | tokenized_captions = [re.sub('[!#?,.:";]', '', c).strip().replace("'", ' ').lower().split(' ') for c in captions] 202 | lengths = [len(c) for c in tokenized_captions] 203 | max_len = max(lengths) 204 | 205 | for tc, l in zip(tokenized_captions, lengths): 206 | indexes = [wdict[w] if w in wdict else 1 for w in tc] 207 | # dump_feats.append(f[:l]) 208 | dump_dict = {'image_id': i, 'features': np.expand_dims(np.asarray(indexes), axis=1)} 209 | db[str(prog_id)] = dump_dict 210 | prog_id += 1 211 | 212 | 213 | class ResnetFeatureExtractor(FeatureExtractor): 214 | def __init__(self, dataset, root, split, resnet_depth, output_dims=(1, 1)): 215 | super().__init__(dataset, root, split) 216 | self.resnet_depth = resnet_depth 217 | self.avgpool = torch.nn.AdaptiveAvgPool2d(output_dims) 218 | self.output_dims = output_dims 219 | 220 | def extract(self, device='cuda'): 221 | if self.dataset == 'coco': 222 | dataset_root = os.path.join(self.root, '{}2014'.format(self.split)) 223 | dataset_json = os.path.join(self.root, 224 | 'stanford_split_annots', 'captions_{}2014.json'.format(self.split)) 225 | if self.split == 'train': 226 | transform = transforms.Compose( 227 | [transforms.Resize(256), 228 | transforms.FiveCrop(224), 229 | transforms.Lambda(lambda crops: torch.stack([ 230 | transforms.Compose([ 231 | transforms.ToTensor(), 232 | transforms.Normalize(mean=[0.485, 0.456, 0.406], 233 | std=[0.229, 0.224, 0.225]) 234 | ])(crop) for crop in crops]) 235 | )]) 236 | elif self.split == 'test' or self.split == 'val': 237 | transform = transforms.Compose( 238 | [transforms.Resize(256), 239 | transforms.CenterCrop(224), 240 | transforms.ToTensor(), 241 | transforms.Normalize(mean=[0.485, 0.456, 0.406], 242 | std=[0.229, 0.224, 0.225]), 243 | transforms.Lambda(lambda imgt: imgt.unsqueeze(0)) 244 | ] 245 | ) 246 | 247 | dataset = CocoCaptions(dataset_root, dataset_json, 248 | transform=transform) 249 | dataloader = DataLoader(dataset, 250 | num_workers=4, 251 | batch_size=1, 252 | shuffle=False 253 | ) 254 | if self.resnet_depth == 18: 255 | model = resnet18(pretrained=True) 256 | elif self.resnet_depth == 50: 257 | model = resnet50(pretrained=True) 258 | elif self.resnet_depth == 101: 259 | model = resnet101(pretrained=True) 260 
| elif self.resnet_depth == 152: 261 | model = resnet152(pretrained=True) 262 | 263 | # delete the classification and the pooling layers 264 | modules = list(model.children())[:-2] 265 | model = torch.nn.Sequential(*modules) 266 | model.to(device) 267 | model.eval() 268 | 269 | feat_dst_filename = self.get_db_file() 270 | with shelve.open(feat_dst_filename, flag='n') as db: 271 | for idx, (img, _) in enumerate(tqdm.tqdm(dataloader)): 272 | with torch.no_grad(): 273 | img = img.to(device) 274 | img = img.squeeze(0) 275 | feats = model(img) 276 | feats = self.avgpool(feats) 277 | feats = feats.view(feats.shape[0], feats.shape[1], -1) 278 | feats = feats.permute(0, 2, 1).squeeze(0) 279 | if idx == 0: 280 | print('Features have shape {}'.format(feats.shape)) 281 | dump_dict = {'scores': None, 'boxes': None, 'features': feats.cpu().numpy()} 282 | db[str(idx)] = dump_dict 283 | 284 | def get_db_file(self): 285 | feat_dst_filename = os.path.join(self.output_feat_fld, 286 | '{}_{}_resnet{}_{}x{}.db'.format(self.dataset, self.split, 287 | self.resnet_depth, self.output_dims[0], 288 | self.output_dims[1])) 289 | return feat_dst_filename 290 | 291 | 292 | class VGGFeatureExtractor(FeatureExtractor): 293 | def __init__(self, dataset, root, split, vgg_depth): # , output_dims=(1, 1)): 294 | super().__init__(dataset, root, split) 295 | self.vgg_depth = vgg_depth 296 | # self.avgpool = torch.nn.AdaptiveAvgPool2d(output_dims) 297 | # self.output_dims = output_dims 298 | 299 | def extract(self, device='cuda'): 300 | if self.dataset == 'coco': 301 | dataset_root = os.path.join(self.root, '{}2014'.format(self.split)) 302 | dataset_json = os.path.join(self.root, 303 | 'stanford_split_annots', 'captions_{}2014.json'.format(self.split)) 304 | if self.split == 'train': 305 | transform = transforms.Compose( 306 | [transforms.Resize(256), 307 | transforms.FiveCrop(224), 308 | transforms.Lambda(lambda crops: torch.stack([ 309 | transforms.Compose([ 310 | transforms.ToTensor(), 311 | transforms.Normalize(mean=[0.485, 0.456, 0.406], 312 | std=[0.229, 0.224, 0.225]) 313 | ])(crop) for crop in crops]) 314 | )]) 315 | elif self.split == 'test' or self.split == 'val': 316 | transform = transforms.Compose( 317 | [transforms.Resize(256), 318 | transforms.CenterCrop(224), 319 | transforms.ToTensor(), 320 | transforms.Normalize(mean=[0.485, 0.456, 0.406], 321 | std=[0.229, 0.224, 0.225]), 322 | transforms.Lambda(lambda imgt: imgt.unsqueeze(0)) 323 | ] 324 | ) 325 | 326 | dataset = CocoCaptions(dataset_root, dataset_json, 327 | transform=transform) 328 | dataloader = DataLoader(dataset, 329 | num_workers=4, 330 | batch_size=1, 331 | shuffle=False 332 | ) 333 | if self.vgg_depth == 11: 334 | model = vgg11_bn(pretrained=True) 335 | elif self.vgg_depth == 13: 336 | model = vgg13_bn(pretrained=True) 337 | elif self.vgg_depth == 16: 338 | model = vgg16_bn(pretrained=True) 339 | elif self.vgg_depth == 19: 340 | model = vgg19_bn(pretrained=True) 341 | 342 | # delete the classification and the pooling layers 343 | modules = list(model.classifier.children())[:-3] 344 | model.classifier = torch.nn.Sequential(*modules) 345 | model.to(device) 346 | model.eval() 347 | 348 | feat_dst_filename = self.get_db_file() 349 | with shelve.open(feat_dst_filename, flag='n') as db: 350 | for idx, (img, _) in enumerate(tqdm.tqdm(dataloader)): 351 | with torch.no_grad(): 352 | img = img.to(device) 353 | img = img.squeeze(0) 354 | feats = model(img) 355 | # feats = self.avgpool(feats) 356 | # feats = feats.view(feats.shape[0], feats.shape[1], -1) 357 | # 
feats = feats.permute(0, 2, 1).squeeze(0) 358 | if self.split == 'train': 359 | feats = feats.unsqueeze(1) 360 | if idx == 0: 361 | print('Features have shape {}'.format(feats.shape)) 362 | dump_dict = {'scores': None, 'boxes': None, 'features': feats.cpu().numpy()} 363 | db[str(idx)] = dump_dict 364 | 365 | def get_db_file(self): 366 | feat_dst_filename = os.path.join(self.output_feat_fld, 367 | '{}_{}_vgg{}_bn.db'.format(self.dataset, self.split, 368 | self.vgg_depth)) 369 | return feat_dst_filename 370 | 371 | 372 | # class GraphRcnnFeatureExtractor(FeatureExtractor): 373 | # def __init__(self, dataset, root, split, algorithm='sg_imp'): 374 | # super().__init__(dataset, root, split) 375 | # self.algorithm = algorithm 376 | # 377 | # def extract(self): 378 | # # use the graphrcnn package to extract visual relational features 379 | # extract_visual_features(self.dataset, self.root, self.algorithm, self.split) 380 | # 381 | # def get_db_file(self): 382 | # feat_dst_filename = os.path.join(self.output_feat_fld, 383 | # '{}_{}_{}.db'.format(self.dataset, self.split, 384 | # self.algorithm)) 385 | # return feat_dst_filename 386 | 387 | 388 | def get_features_extractor(config, split, method=None, finetuned=None): 389 | if method == 'transformer-bert': 390 | config['text-model']['pre-extracted'] = False 391 | extractor = HuggingFaceTransformerExtractor(config, split, finetuned=finetuned) 392 | 393 | # elif method == 'graphrcnn': 394 | # extractor = GraphRcnnFeatureExtractor(dataset_name, dataset_root, split, 395 | # extractor_config['algorithm']) 396 | # elif method == 'resnet': 397 | # extractor = ResnetFeatureExtractor(dataset_name, dataset_root, split, 398 | # extractor_config['depth'], (extractor_config['output-h'], 399 | # extractor_config['output-w'])) 400 | # elif method == 'vgg': 401 | # extractor = VGGFeatureExtractor(dataset_name, dataset_root, split, 402 | # extractor_config['depth']) 403 | else: 404 | raise ValueError('Extraction method {} not known!'.format(args.method)) 405 | return extractor 406 | 407 | 408 | def main(args, config): 409 | extractor = get_features_extractor(config, args.split, args.method, args.finetuned) 410 | if os.path.isfile(extractor.get_db_file() + '.dat'): 411 | answ = input("Features {} for {} already existing. Overwrite? (y/n)".format(extractor.get_db_file(), extractor)) 412 | if answ == 'y': 413 | print('Using extractor: {}'.format(extractor)) 414 | extractor.extract() 415 | else: 416 | print('Skipping {}'.format(extractor)) 417 | else: 418 | extractor.extract() 419 | 420 | print('DONE') 421 | 422 | 423 | if __name__ == '__main__': 424 | arg_parser = argparse.ArgumentParser(description='Extract captioning scores for use as relevance') 425 | arg_parser.add_argument('--config', type=str, help="Which configuration to use. 
See the 'configs' folder") 426 | arg_parser.add_argument('--split', type=str, default="val", help="Dataset split to use") 427 | arg_parser.add_argument('--finetuned', type=str, default=None, help="Optional finetuning checkpoint") 428 | arg_parser.add_argument('method', type=str, help="Which kind of feature you want to extract") 429 | # arg_parser.add_argument('type', type=str, choices=['image','text'], help="Method type") 430 | 431 | args = arg_parser.parse_args() 432 | 433 | if args.finetuned is not None: 434 | config = torch.load(args.finetuned)['config'] 435 | print('Configuration read from checkpoint') 436 | else: 437 | with open(args.config, 'r') as ymlfile: 438 | config = yaml.load(ymlfile) 439 | main(args, config) 440 | 441 | --------------------------------------------------------------------------------
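Editor's note (not part of the repository sources): features.py stores one record per caption in a shelve database, and for the BERT extractor each record is a dictionary with the keys 'image_id', 'captions', 'features' and 'wembeddings', as written in HuggingFaceTransformerExtractor.extract(). The minimal sketch below shows one way such a database could be inspected; the db path is only an illustrative example (it depends on config['dataset'] and on the chosen split), and the 768-dimensional feature size assumes the 'bert-base-uncased' model used in the provided configs.

import shelve
import numpy as np

# Illustrative path: '<data>/<dataset>_precomp/<split>_bert_bert-base-uncased.db' (assumed example)
db_path = 'data/coco_precomp/val_bert_bert-base-uncased.db'

with shelve.open(db_path, flag='r') as db:
    record = db['0']  # records are keyed by a progressive integer stored as a string
    print(record['image_id'])                        # id of the image this caption refers to
    print(np.asarray(record['features']).shape)      # (caption_length, hidden_dim), e.g. hidden_dim=768
    print(np.asarray(record['wembeddings']).shape)   # (caption_length, hidden_dim) input word embeddings

Such a database could, for instance, be produced for the COCO validation split with a command along the lines of: python features.py transformer-bert --config configs/teran_coco_MrSw.yaml --split val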