├── Code ├── BuildImgOrderClusterVocab.py ├── Evaluator │ ├── BLEU │ │ ├── bleu.py │ │ └── bleu_scorer.py │ ├── CIDEr │ │ ├── cider.py │ │ └── cider_scorer.py │ ├── METEOR │ │ ├── data │ │ │ └── paraphrase-en.gz │ │ ├── meteor-1.5.jar │ │ └── meteor.py │ └── ROUGE │ │ └── rouge.py ├── FeatureExtractorPytorch.py ├── Generator.py ├── OrderClusterStream.py ├── Pipeline.py ├── Retriever.py └── Trainer.py ├── Data └── raw │ └── coco │ ├── coco_test_v2.csv │ ├── coco_train_v2.csv │ └── coco_val_v2.csv └── README.md /Code/BuildImgOrderClusterVocab.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import argparse 3 | import csv 4 | import os 5 | import pickle 6 | from collections import Counter 7 | import nltk 8 | import numpy as np 9 | sys.path.append('..') 10 | from FeatureExtractorPytorch import ResNet152Extractor as extractor 11 | os.environ['CUDA_VISIBLE_DEVICES'] = '3' 12 | 13 | 14 | class BuildVocab(object): 15 | def __init__(self, opts): 16 | self._options = opts 17 | self.model = extractor() 18 | self._GO, self._EOS, self._PAD = 'GO', 'EOS', 'PAD' 19 | self.spec_words = [self._PAD, self._GO, self._EOS] 20 | train_img_idx_sequence, train_filename_sequence, train_captions_sequence, data_split_sequence = \ 21 | self.read_csv(opts.train_data_path) 22 | freq_words = self.get_freq_word(train_captions_sequence) 23 | self.idx2word, self.word2idx, self.word2vec, self.word_freq = self.load_word_info(freq_words) 24 | self.train_data = self.interpreter(opts.train_data_path, data_split='train') 25 | self.val_data = self.interpreter(opts.val_data_path, data_split='val') 26 | self.test_data = self.interpreter(opts.test_data_path, data_split='test') 27 | train_sequence_data = [self.train_data.get('Idx'), self.train_data.get('Caption'), 28 | self.train_data.get('ImgFeature'), self.train_data.get('NumCaption')] 29 | val_sequence_data = [self.val_data.get('Idx'), self.val_data.get('Caption'), 30 | self.val_data.get('ImgFeature'), self.val_data.get('NumCaption')] 31 | test_sequence_data = [self.test_data.get('Idx'), self.test_data.get('Caption'), 32 | self.test_data.get('ImgFeature'), self.test_data.get('NumCaption')] 33 | self.save_pickle_file(train_sequence_data, os.path.join(opts.train_dir, opts.train_data_name)) 34 | self.save_pickle_file(val_sequence_data, os.path.join(opts.train_dir, opts.val_data_name)) 35 | self.save_pickle_file(test_sequence_data, os.path.join(opts.train_dir, opts.test_data_name)) 36 | self.save_pickle_file([self.word_freq, self.idx2word, self.word2idx, self.word2vec], 37 | os.path.join(opts.train_dir, opts.info_name)) 38 | 39 | @staticmethod 40 | def read_csv(path): 41 | img_idx_sequence, filename_sequence, captions_sequence, data_split_sequence = [], [], [], [] 42 | with open(path, 'r') as f: 43 | reader = csv.reader(f) 44 | index = 0 45 | for line in reader: 46 | if index == 0: 47 | index += 1 48 | continue 49 | else: 50 | img_idx, filename, data_split, captions = line 51 | captions = captions.split('---') 52 | new_captions = [] 53 | for caption in captions: 54 | token_caption = nltk.tokenize.word_tokenize(caption) 55 | new_captions.append(token_caption) 56 | captions_sequence.append(new_captions) 57 | filename_sequence.append(filename) 58 | img_idx_sequence.append(img_idx) 59 | data_split_sequence.append(data_split) 60 | f.close() 61 | print('[The number of records in {}: {}]'.format(path, len(img_idx_sequence))) 62 | return img_idx_sequence, filename_sequence, captions_sequence, data_split_sequence 63 | 64 | @staticmethod 65 
| def get_freq_word(question_list_list): 66 | raw_word_freq = Counter() 67 | for question_list in question_list_list: 68 | for question in question_list: 69 | raw_word_freq.update(question) 70 | freq_words = sorted(raw_word_freq.items(), key=lambda x: x[1], reverse=True) 71 | print('[FreqWord][The total of word in data set: {}][The number of word: {}][That of removed ones: {}]'. 72 | format(sum(raw_word_freq.values()), len(freq_words), len(raw_word_freq) - len(freq_words))) 73 | return freq_words 74 | 75 | def load_word_info(self, freq_words): 76 | opts = self._options 77 | raw_word2vec = dict() 78 | with open(opts.external_word_vectors_file_path, 'r') as f: 79 | lines = f.readlines() 80 | for line in lines: 81 | v = line.strip().split(" ") 82 | raw_word2vec[v[0]] = np.array([float(nv) for nv in v[1:]]) 83 | f.close() 84 | common_id2word, word2vec, invalid_num, word_freq = [], [], 0, [] 85 | invalid_word, invalid_counter = [], Counter() 86 | freq_words = [(self._PAD, 0), (self._GO, 0), (self._EOS, 0)] + freq_words 87 | for word, freq in freq_words: 88 | if raw_word2vec.get(word) is not None: 89 | common_id2word.append(word) 90 | word2vec.append(raw_word2vec[word]) 91 | word_freq.append((word, freq)) 92 | else: 93 | word_freq.append((word, freq)) 94 | common_id2word.append(word) 95 | random_word2vec = 0.2 * (np.random.random(opts.word_emb_size) - 0.5) 96 | word2vec.append(random_word2vec) 97 | invalid_counter.update([freq]) 98 | invalid_word.append(word) 99 | invalid_num += 1 100 | id2word = common_id2word 101 | id2word = {i: w for i, w in enumerate(id2word)} 102 | word2id = {w: i for i, w in id2word.items()} 103 | word2vec = np.array(word2vec) 104 | print('[Invalid number: {}, invalid_counter: {}]\n' 105 | '[The shape of word2vec: {}, the length of word2idx: {}, that of word_freq: {}]'. 106 | format(invalid_num, invalid_counter, word2vec.shape, len(word2id), len(word_freq))) 107 | return id2word, word2id, word2vec, word_freq 108 | 109 | def interpreter(self, path=None, data_split=None): 110 | img_idx_sequence, filename_sequence, captions_sequence, data_split_sequence = self.read_csv(path) 111 | num_captions_sequence = self.word2id_mapper(captions_sequence) 112 | print('[Interpreter][Word2id mapper has finished!]') 113 | print('[Size of file_sequence: {}][Size of image_idx_sequence: {}]'. 
114 | format(len(filename_sequence), len(img_idx_sequence))) 115 | feature_sequence = self.get_image_feature(filename_sequence, data_split) 116 | print('[We have got image features!]') 117 | print('[ImgIdxSequence:{}][FileNameSequence:{}][CaptionSequence:{}][FeatureSequence:{}]]' 118 | '[NumCaptionsSequence:{}]'.format(len(img_idx_sequence), len(filename_sequence), len(captions_sequence), 119 | len(feature_sequence), len(num_captions_sequence))) 120 | data = {'Idx': filename_sequence, 'Caption': captions_sequence, 'ImgFeature': feature_sequence, 121 | 'NumCaption': num_captions_sequence} 122 | return data 123 | 124 | def word2id_mapper(self, captions_sequence): 125 | opts = self._options 126 | word2idx = self.word2idx 127 | min_length = opts.min_sentence_length if opts.min_sentence_length is not None else 0 128 | num_captions_sequence = [] 129 | for captions in captions_sequence: 130 | num_captions = [] 131 | for caption in captions: 132 | num_caption = [] 133 | for word in caption: 134 | if word2idx.get(word) is not None: 135 | num_caption.append(word2idx.get(word)) 136 | if len(num_caption) < min_length: 137 | continue 138 | num_captions.append(num_caption) 139 | num_captions_sequence.append(num_captions) 140 | return num_captions_sequence 141 | 142 | def get_image_feature(self, filename_sequence, data_split=None): 143 | opts = self._options 144 | file_path_sequence = [] 145 | for file_name in filename_sequence: 146 | if data_split == 'train': 147 | sub_folder = 'train2014' 148 | elif data_split == 'val': 149 | sub_folder = 'val2014' 150 | elif data_split == 'test': 151 | sub_folder = 'val2014' 152 | else: 153 | raise Exception('DataSplit {} is illegal'.format(data_split)) 154 | if data_split == 'train': 155 | if 'val' in file_name: 156 | sub_folder = 'val2014' 157 | file_path = os.path.join(opts.image_dir, sub_folder+'/'+file_name) 158 | file_path_sequence.append(file_path) 159 | print('[FeatureExtractStart!][ImgFileSize:{}]'.format(len(file_path_sequence))) 160 | features = self.model.get_feature(file_path_sequence) 161 | return features 162 | 163 | @staticmethod 164 | def save_pickle_file(data, path): 165 | with open(path, 'wb') as f: 166 | pickle.dump(data, f) 167 | f.close() 168 | 169 | 170 | def read_commands(): 171 | parser = argparse.ArgumentParser(usage='Pre-processing data set') 172 | root = os.path.abspath('..') 173 | raw_root = os.path.join(root, 'Data/raw') 174 | parser.add_argument('--data_name', type=str, default='coco') 175 | parser.add_argument('--data_id', type=int, default=100) 176 | parser.add_argument('--train_data_path', type=str, default=os.path.join(raw_root, 'coco/coco_train_v2.csv')) 177 | parser.add_argument('--val_data_path', type=str, default=os.path.join(raw_root, 'coco/coco_val_v2.csv')) 178 | parser.add_argument('--test_data_path', type=str, default=os.path.join(raw_root, 'coco/coco_test_v2.csv')) 179 | parser.add_argument('--image_dir', type=str, default=os.path.join(raw_root, 'img')) 180 | parser.add_argument('--external_word_vectors_file_path', type=str, 181 | default=os.path.join(root, 'Data/word2vec/glove.6B.300d.txt')) 182 | parser.add_argument('--max_sentence_length', type=int, default=16) 183 | parser.add_argument('--word_emb_size', type=int, default=300) 184 | parser.add_argument('--min_sentence_length', type=int, default=0) 185 | parser.add_argument('--min_count', type=int, default=3) 186 | parser.add_argument('--batch_size', type=int, default=32) 187 | parser.add_argument('--word_num', type=int, default=None) 188 | 
parser.add_argument('--min_freq', type=int, default=5)
189 |     args = parser.parse_args()
190 |     args.train_dir = os.path.join(root, 'Data/train/{}_v{}'.format(args.data_name, args.data_id))
191 |     args.train_data_name = '{}_train_v{}.pkl'.format(args.data_name, args.data_id)
192 |     args.val_data_name = '{}_val_v{}.pkl'.format(args.data_name, args.data_id)
193 |     args.test_data_name = '{}_test_v{}.pkl'.format(args.data_name, args.data_id)
194 |     args.info_name = '{}_info_v{}.pkl'.format(args.data_name, args.data_id)
195 |     return args
196 | 
197 | 
198 | def main():
199 |     opts = read_commands()
200 |     BuildVocab(opts)
201 | 
202 | 
203 | if __name__ == '__main__':
204 |     main()
205 | 
--------------------------------------------------------------------------------
/Code/Evaluator/BLEU/bleu.py:
--------------------------------------------------------------------------------
1 | import sys
2 | sys.path.append('../..')
3 | import nltk
4 | from Evaluator.BLEU.bleu_scorer import BleuScorer
5 | 
6 | 
7 | class BLEUEvaluator(object):
8 |     def __init__(self, n=4):
9 |         # by default, compute BLEU scores up to 4-grams
10 |         self._n = n
11 |         self._hypo_for_image = {}
12 |         self.ref_for_image = {}
13 | 
14 |     def compute_score(self, gts, res, mode='all'):
15 | 
16 |         assert(gts.keys() == res.keys())
17 |         imgIds = gts.keys()
18 | 
19 |         bleu_scorer = BleuScorer(n=self._n)
20 |         for id in imgIds:
21 |             hypo = res[id]
22 |             ref = gts[id]
23 | 
24 |             # Sanity check.
25 |             assert(type(hypo) is list)
26 |             assert(len(hypo) == 1)
27 |             assert(type(ref) is list)
28 |             assert(len(ref) >= 1)
29 | 
30 |             bleu_scorer += (hypo[0], ref)
31 | 
32 |         # score, scores = bleu_scorer.compute_score(option='shortest')
33 |         score, scores = bleu_scorer.compute_score(option='closest', verbose=0)
34 |         # score, scores = bleu_scorer.compute_score(option='average', verbose=1)
35 | 
36 |         # return (bleu, bleu_info)
37 |         if mode == 'all':
38 |             return score
39 |         elif mode == 'every':
40 |             return scores
41 |         else:
42 |             return score, scores
43 | 
44 |     def method(self):
45 |         return "Bleu"
46 | 
47 | 
48 | def main():
49 |     hypo = {'1': ['I like it !'], '2': ['I completely do not know !'],
50 |             '3': ['how about you ?'], '4': ['what is this ?'], '5': ['this is amazing !']}
51 |     ref = {'1': ['I love you !', 'I love myself !'], '2': ['I do not know !'], '3': ['how are you ?'],
52 |            '4': ['what is this animal ?'], '5': ['this is awkward !']}
53 |     bleu = BLEUEvaluator(n=4)
54 |     score = bleu.compute_score(ref, hypo, 'every')
55 |     print(len(score))
56 |     for val in score:
57 |         print(val)
58 | 
59 | 
60 | if __name__ == '__main__':
61 |     main()
62 | 
--------------------------------------------------------------------------------
/Code/Evaluator/BLEU/bleu_scorer.py:
--------------------------------------------------------------------------------
1 | import copy
2 | import sys, math, re
3 | from collections import defaultdict
4 | 
5 | def precook(s, n=4, out=False):
6 |     """Takes a string as input and returns an object that can be given to
7 |     either cook_refs or cook_test.
This is optional: cook_refs and cook_test
8 |     can take string arguments as well."""
9 |     words = s.split()
10 |     counts = defaultdict(int)
11 |     for k in range(1,n+1):
12 |         for i in range(len(words)-k+1):
13 |             ngram = tuple(words[i:i+k])
14 |             counts[ngram] += 1
15 |     return (len(words), counts)
16 | 
17 | def cook_refs(refs, eff=None, n=4): ## lhuang: oracle will call with "average"
18 |     '''Takes a list of reference sentences for a single segment
19 |     and returns an object that encapsulates everything that BLEU
20 |     needs to know about them.'''
21 | 
22 |     reflen = []
23 |     maxcounts = {}
24 |     for ref in refs:
25 |         rl, counts = precook(ref, n)
26 |         reflen.append(rl)
27 |         for (ngram,count) in counts.items():
28 |             maxcounts[ngram] = max(maxcounts.get(ngram,0), count)
29 | 
30 |     # Calculate effective reference sentence length.
31 |     if eff == "shortest":
32 |         reflen = min(reflen)
33 |     elif eff == "average":
34 |         reflen = float(sum(reflen))/len(reflen)
35 | 
36 |     ## lhuang: N.B.: leave reflen computation to the very end!!
37 | 
38 |     ## lhuang: N.B.: in case of "closest", keep a list of reflens!! (bad design)
39 | 
40 |     return (reflen, maxcounts)
41 | 
42 | def cook_test(test, cooked_refs, eff=None, n=4):
43 |     '''Takes a test sentence and returns an object that
44 |     encapsulates everything that BLEU needs to know about it.'''
45 |     reflen, refmaxcounts = cooked_refs
46 |     testlen, counts = precook(test, n, True)
47 | 
48 |     result = {}
49 | 
50 |     # Calculate effective reference sentence length.
51 | 
52 |     if eff == "closest":
53 |         result["reflen"] = min((abs(l-testlen), l) for l in reflen)[1]
54 |     else: ## i.e., "average" or "shortest" or None
55 |         result["reflen"] = reflen
56 | 
57 |     result["testlen"] = testlen
58 | 
59 |     result["guess"] = [max(0,testlen-k+1) for k in range(1,n+1)]
60 | 
61 |     result['correct'] = [0]*n
62 |     for (ngram, count) in counts.items():
63 |         result["correct"][len(ngram)-1] += min(refmaxcounts.get(ngram,0), count)
64 | 
65 |     return result
66 | 
67 | class BleuScorer(object):
68 |     """Bleu scorer.
69 |     """
70 | 
71 |     __slots__ = "n", "crefs", "ctest", "_score", "_ratio", "_testlen", "_reflen", "special_reflen"
72 |     # special_reflen is used in oracle (proportional effective ref len for a node).
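    # A worked example of the cooking helpers above (illustrative values, not
    # part of the pipeline): precook("the cat the cat", n=2) returns
    # (4, {('the',): 2, ('cat',): 2, ('the', 'cat'): 2, ('cat', 'the'): 1}),
    # and cook_refs keeps, for every n-gram, the maximum count over all
    # references, which is exactly what the clipped "correct" counts in
    # cook_test need.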
73 | 74 | def copy(self): 75 | ''' copy the refs.''' 76 | new = BleuScorer(n=self.n) 77 | new.ctest = copy.copy(self.ctest) 78 | new.crefs = copy.copy(self.crefs) 79 | new._score = None 80 | return new 81 | 82 | def __init__(self, test=None, refs=None, n=4, special_reflen=None): 83 | ''' singular instance ''' 84 | 85 | self.n = n 86 | self.crefs = [] 87 | self.ctest = [] 88 | self.cook_append(test, refs) 89 | self.special_reflen = special_reflen 90 | 91 | def cook_append(self, test, refs): 92 | '''called by constructor and __iadd__ to avoid creating new instances.''' 93 | 94 | if refs is not None: 95 | self.crefs.append(cook_refs(refs)) 96 | if test is not None: 97 | cooked_test = cook_test(test, self.crefs[-1]) 98 | self.ctest.append(cooked_test) ## N.B.: -1 99 | else: 100 | self.ctest.append(None) # lens of crefs and ctest have to match 101 | 102 | self._score = None ## need to recompute 103 | 104 | def ratio(self, option=None): 105 | self.compute_score(option=option) 106 | return self._ratio 107 | 108 | def score_ratio(self, option=None): 109 | '''return (bleu, len_ratio) pair''' 110 | return (self.fscore(option=option), self.ratio(option=option)) 111 | 112 | def score_ratio_str(self, option=None): 113 | return "%.4f (%.2f)" % self.score_ratio(option) 114 | 115 | def reflen(self, option=None): 116 | self.compute_score(option=option) 117 | return self._reflen 118 | 119 | def testlen(self, option=None): 120 | self.compute_score(option=option) 121 | return self._testlen 122 | 123 | def retest(self, new_test): 124 | if type(new_test) is str: 125 | new_test = [new_test] 126 | assert len(new_test) == len(self.crefs), new_test 127 | self.ctest = [] 128 | for t, rs in zip(new_test, self.crefs): 129 | self.ctest.append(cook_test(t, rs)) 130 | self._score = None 131 | 132 | return self 133 | 134 | def rescore(self, new_test): 135 | ''' replace test(s) with new test(s), and returns the new score.''' 136 | 137 | return self.retest(new_test).compute_score() 138 | 139 | def size(self): 140 | assert len(self.crefs) == len(self.ctest), "refs/test mismatch! %d<>%d" % (len(self.crefs), len(self.ctest)) 141 | return len(self.crefs) 142 | 143 | def __iadd__(self, other): 144 | '''add an instance (e.g., from another sentence).''' 145 | 146 | if type(other) is tuple: 147 | ## avoid creating new BleuScorer instances 148 | self.cook_append(other[0], other[1]) 149 | else: 150 | assert self.compatible(other), "incompatible BLEUs." 
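            # Merging path: both scorers' cooked tests/refs are concatenated
            # below. The tuple path above is how bleu.py drives this class
            # (a sketch of typical usage):
            #   scorer = BleuScorer(n=4)
            #   scorer += ('a cat sat on the mat', ['the cat sat on the mat'])
            #   score, scores = scorer.compute_score(option='closest')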
151 | self.ctest.extend(other.ctest) 152 | self.crefs.extend(other.crefs) 153 | self._score = None ## need to recompute 154 | 155 | return self 156 | 157 | def compatible(self, other): 158 | return isinstance(other, BleuScorer) and self.n == other.n 159 | 160 | def single_reflen(self, option="average"): 161 | return self._single_reflen(self.crefs[0][0], option) 162 | 163 | def _single_reflen(self, reflens, option=None, testlen=None): 164 | 165 | if option == "shortest": 166 | reflen = min(reflens) 167 | elif option == "average": 168 | reflen = float(sum(reflens))/len(reflens) 169 | elif option == "closest": 170 | reflen = min((abs(l-testlen), l) for l in reflens)[1] 171 | else: 172 | assert False, "unsupported reflen option %s" % option 173 | 174 | return reflen 175 | 176 | def recompute_score(self, option=None, verbose=0): 177 | self._score = None 178 | return self.compute_score(option, verbose) 179 | 180 | def compute_score(self, option=None, verbose=0): 181 | n = self.n 182 | small = 1e-9 183 | tiny = 1e-15 ## so that if guess is 0 still return 0 184 | bleu_list = [[] for _ in range(n)] 185 | 186 | if self._score is not None: 187 | return self._score 188 | 189 | if option is None: 190 | option = "average" if len(self.crefs) == 1 else "closest" 191 | 192 | self._testlen = 0 193 | self._reflen = 0 194 | totalcomps = {'testlen':0, 'reflen':0, 'guess':[0]*n, 'correct':[0]*n} 195 | 196 | # for each sentence 197 | for comps in self.ctest: 198 | testlen = comps['testlen'] 199 | self._testlen += testlen 200 | 201 | if self.special_reflen is None: ## need computation 202 | reflen = self._single_reflen(comps['reflen'], option, testlen) 203 | else: 204 | reflen = self.special_reflen 205 | 206 | self._reflen += reflen 207 | 208 | for key in ['guess','correct']: 209 | for k in range(n): 210 | totalcomps[key][k] += comps[key][k] 211 | 212 | # append per image bleu score 213 | bleu = 1. 214 | for k in range(n): 215 | bleu *= (float(comps['correct'][k]) + tiny) \ 216 | /(float(comps['guess'][k]) + small) 217 | bleu_list[k].append(bleu ** (1./(k+1))) 218 | ratio = (testlen + tiny) / (reflen + small) ## N.B.: avoid zero division 219 | if ratio < 1: 220 | for k in range(n): 221 | bleu_list[k][-1] *= math.exp(1 - 1/ratio) 222 | 223 | if verbose > 1: 224 | print(comps, reflen) 225 | 226 | totalcomps['reflen'] = self._reflen 227 | totalcomps['testlen'] = self._testlen 228 | 229 | bleus = [] 230 | bleu = 1. 
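        # The corpus-level scores assembled below follow the usual BLEU
        # definition: BLEU_k = (p_1 * ... * p_k)^(1/k) * BP, where p_i are the
        # clipped n-gram precisions accumulated in totalcomps, and the brevity
        # penalty BP = exp(1 - reflen/testlen) is applied when ratio < 1
        # (hypotheses shorter than the effective reference length).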
231 | for k in range(n): 232 | bleu *= float(totalcomps['correct'][k] + tiny) \ 233 | / (totalcomps['guess'][k] + small) 234 | bleus.append(bleu ** (1./(k+1))) 235 | ratio = (self._testlen + tiny) / (self._reflen + small) ## N.B.: avoid zero division 236 | if ratio < 1: 237 | for k in range(n): 238 | bleus[k] *= math.exp(1 - 1/ratio) 239 | 240 | if verbose > 0: 241 | print(totalcomps) 242 | print("ratio:", ratio) 243 | 244 | self._score = bleus 245 | return self._score, bleu_list -------------------------------------------------------------------------------- /Code/Evaluator/CIDEr/cider.py: -------------------------------------------------------------------------------- 1 | # Filename: cider.py 2 | # 3 | # Description: Describes the class to compute the CIDEr (Consensus-Based Image Description Evaluation) Metric 4 | # by Vedantam, Zitnick, and Parikh (http://arxiv.org/abs/1411.5726) 5 | # 6 | # Creation Date: Sun Feb 8 14:16:54 2015 7 | # 8 | # Authors: Ramakrishna Vedantam and Tsung-Yi Lin 9 | import sys 10 | sys.path.append('../..') 11 | from Evaluator.CIDEr.cider_scorer import CiderScorer 12 | 13 | 14 | class CIDErEvaluator: 15 | """ 16 | Main Class to compute the CIDEr metric 17 | 18 | """ 19 | def __init__(self, test=None, refs=None, n=4, sigma=6.0): 20 | # set cider to sum over 1 to 4-grams 21 | self._n = n 22 | # set the standard deviation parameter for gaussian penalty 23 | self._sigma = sigma 24 | 25 | def compute_score(self, gts, res, mode='all'): 26 | """ 27 | Main function to compute CIDEr score 28 | :param hypo_for_image (dict) : dictionary with key and value 29 | ref_for_image (dict) : dictionary with key and value 30 | :return: cider (float) : computed CIDEr score for the corpus 31 | """ 32 | 33 | assert(gts.keys() == res.keys()) 34 | imgIds = gts.keys() 35 | 36 | cider_scorer = CiderScorer(n=self._n, sigma=self._sigma) 37 | 38 | for id in imgIds: 39 | hypo = res[id] 40 | ref = gts[id] 41 | # Sanity check. 
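            # Expected input shapes at this point (an illustrative example):
            #   res = {'img1': ['a man rides a horse']}        # one hypothesis
            #   gts = {'img1': ['a man riding a horse',
            #                   'a person rides a horse']}     # >= 1 references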
42 |             assert(type(hypo) is list)
43 |             # assert(len(hypo) == 1)
44 |             assert(type(ref) is list)
45 |             assert(len(ref) > 0)
46 |             # for h in hypo:
47 |             cider_scorer += (hypo[0], ref)
48 | 
49 |         (score, scores) = cider_scorer.compute_score()
50 |         if mode == 'all':
51 |             return [score]
52 |         elif mode == 'every':
53 |             return scores
54 |         else:
55 |             return score, scores
56 | 
57 |     def method(self):
58 |         return "CIDEr"
59 | 
60 | 
61 | def main():
62 |     ref = {'1': ['I like it !', 'I love it !'], '2': ['I completely do not know !'],
63 |            '3': ['how about you ?'], '4': ['what is this ?'], '5': ['this is amazing !']}
64 |     hypo = {'1': ['I love you !'], '2': ['I do not know !'], '3': ['how are you ?'],
65 |             '4': ['what is this animal ?'], '5': ['this is awkward !']}
66 |     ref1 = {'1': ['I like it !', 'I love it !'], '2': ['I completely do not know !']}
67 |     hypo1 = {'1': ['I love you !'], '2': ['I do not know !']}
68 |     evaluator = CIDErEvaluator()
69 |     print(evaluator.compute_score(ref, hypo, 'every'))
70 |     print(evaluator.compute_score(ref1, hypo1, 'every'))
71 | 
72 | 
73 | if __name__ == '__main__':
74 |     main()
75 | 
--------------------------------------------------------------------------------
/Code/Evaluator/CIDEr/cider_scorer.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # Tsung-Yi Lin
3 | # Ramakrishna Vedantam
4 | 
5 | import copy
6 | from collections import defaultdict
7 | import numpy as np
8 | import pdb
9 | import math
10 | 
11 | 
12 | def precook(s, n=4, out=False):
13 |     """
14 |     Takes a string as input and returns an object that can be given to
15 |     either cook_refs or cook_test. This is optional: cook_refs and cook_test
16 |     can take string arguments as well.
17 |     :param s: string : sentence to be converted into ngrams
18 |     :param n: int : number of ngrams for which representation is calculated
19 |     :return: term frequency vector for occurring ngrams
20 |     """
21 |     words = s.split()
22 |     counts = defaultdict(int)
23 |     for k in range(1,n+1):
24 |         for i in range(len(words)-k+1):
25 |             ngram = tuple(words[i:i+k])
26 |             counts[ngram] += 1
27 |     return counts
28 | 
29 | 
30 | def cook_refs(refs, n=4): ## lhuang: oracle will call with "average"
31 |     '''Takes a list of reference sentences for a single segment
32 |     and returns an object that encapsulates everything that BLEU
33 |     needs to know about them.
34 |     :param refs: list of string : reference sentences for some image
35 |     :param n: int : number of ngrams for which (ngram) representation is calculated
36 |     :return: result (list of dict)
37 |     '''
38 |     return [precook(ref, n) for ref in refs]
39 | 
40 | 
41 | def cook_test(test, n=4):
42 |     '''Takes a test sentence and returns an object that
43 |     encapsulates everything that BLEU needs to know about it.
44 |     :param test: list of string : hypothesis sentence for some image
45 |     :param n: int : number of ngrams for which (ngram) representation is calculated
46 |     :return: result (dict)
47 |     '''
48 |     return precook(test, n, True)
49 | 
50 | 
51 | class CiderScorer(object):
52 |     """CIDEr scorer.
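    Works in two passes: cook_append() accumulates raw n-gram counts for each
    (test, refs) pair, and compute_score() then derives document frequencies
    from the references and scores every hypothesis by tf-idf cosine
    similarity with a Gaussian length penalty (see compute_cider below).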
53 | """ 54 | 55 | def copy(self): 56 | ''' copy the refs.''' 57 | new = CiderScorer(n=self.n) 58 | new.ctest = copy.copy(self.ctest) 59 | new.crefs = copy.copy(self.crefs) 60 | return new 61 | 62 | def __init__(self, test=None, refs=None, n=4, sigma=6.0): 63 | ''' singular instance ''' 64 | self.n = n 65 | self.sigma = sigma 66 | self.crefs = [] 67 | self.ctest = [] 68 | self.document_frequency = defaultdict(float) 69 | self.cook_append(test, refs) 70 | self.ref_len = None 71 | 72 | def cook_append(self, test, refs): 73 | '''called by constructor and __iadd__ to avoid creating new instances.''' 74 | 75 | if refs is not None: 76 | self.crefs.append(cook_refs(refs)) 77 | if test is not None: 78 | self.ctest.append(cook_test(test)) ## N.B.: -1 79 | else: 80 | self.ctest.append(None) # lens of crefs and ctest have to match 81 | 82 | def size(self): 83 | assert len(self.crefs) == len(self.ctest), "refs/test mismatch! %d<>%d" % (len(self.crefs), len(self.ctest)) 84 | return len(self.crefs) 85 | 86 | def __iadd__(self, other): 87 | '''add an instance (e.g., from another sentence).''' 88 | 89 | if type(other) is tuple: 90 | ## avoid creating new CiderScorer instances 91 | self.cook_append(other[0], other[1]) 92 | else: 93 | self.ctest.extend(other.ctest) 94 | self.crefs.extend(other.crefs) 95 | 96 | return self 97 | 98 | def compute_doc_freq(self): 99 | ''' 100 | Compute term frequency for reference data. 101 | This will be used to compute idf (inverse document frequency later) 102 | The term frequency is stored in the object 103 | :return: None 104 | ''' 105 | for refs in self.crefs: 106 | # refs, k ref captions of one image 107 | for ngram in set([ngram for ref in refs for (ngram,count) in ref.items()]): 108 | self.document_frequency[ngram] += 1 109 | # maxcounts[ngram] = max(maxcounts.get(ngram,0), count) 110 | 111 | def compute_cider(self): 112 | def counts2vec(cnts): 113 | """ 114 | Function maps counts of ngram to vector of tfidf weights. 115 | The function returns vec, an array of dictionary that store mapping of n-gram and tf-idf weights. 116 | The n-th entry of array denotes length of n-grams. 117 | :param cnts: 118 | :return: vec (array of dict), norm (array of float), length (int) 119 | """ 120 | vec = [defaultdict(float) for _ in range(self.n)] 121 | length = 0 122 | norm = [0.0 for _ in range(self.n)] 123 | for (ngram, term_freq) in cnts.items(): 124 | # give word count 1 if it doesn't appear in reference corpus 125 | df = np.log(max(1.0, self.document_frequency[ngram])) 126 | # ngram index 127 | n = len(ngram)-1 128 | # tf (term_freq) * idf (precomputed idf) for n-grams 129 | vec[n][ngram] = float(term_freq)*(self.ref_len - df) 130 | # compute norm for the vector. the norm will be used for computing similarity 131 | norm[n] += pow(vec[n][ngram], 2) 132 | 133 | if n == 1: 134 | length += term_freq 135 | norm = [np.sqrt(n) for n in norm] 136 | return vec, norm, length 137 | 138 | def sim(vec_hyp, vec_ref, norm_hyp, norm_ref, length_hyp, length_ref): 139 | ''' 140 | Compute the cosine similarity of two vectors. 
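            Per n-gram order n, the quantity accumulated below is
                sim_n = sum_g min(hyp_g, ref_g) * ref_g / (norm_hyp * norm_ref)
                        * exp(-delta^2 / (2 * sigma^2)),
            a clipped tf-idf dot product, normalised by both vector norms and
            damped by a Gaussian penalty on the length difference delta.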
141 |             :param vec_hyp: array of dictionary for vector corresponding to hypothesis
142 |             :param vec_ref: array of dictionary for vector corresponding to reference
143 |             :param norm_hyp: array of float for vector corresponding to hypothesis
144 |             :param norm_ref: array of float for vector corresponding to reference
145 |             :param length_hyp: int containing length of hypothesis
146 |             :param length_ref: int containing length of reference
147 |             :return: array of score for each n-grams cosine similarity
148 |             '''
149 |             delta = float(length_hyp - length_ref)
150 |             # measure cosine similarity
151 |             val = np.array([0.0 for _ in range(self.n)])
152 |             for n in range(self.n):
153 |                 # ngram
154 |                 for (ngram,count) in vec_hyp[n].items():
155 |                     # vrama91 : added clipping
156 |                     val[n] += min(vec_hyp[n][ngram], vec_ref[n][ngram]) * vec_ref[n][ngram]
157 | 
158 |                 if (norm_hyp[n] != 0) and (norm_ref[n] != 0):
159 |                     val[n] /= (norm_hyp[n]*norm_ref[n])
160 | 
161 |                 assert(not math.isnan(val[n]))
162 |                 # vrama91: added a length based gaussian penalty
163 |                 val[n] *= np.e**(-(delta**2)/(2*self.sigma**2))
164 |             return val
165 | 
166 |         # compute log reference length
167 |         self.ref_len = np.log(float(len(self.crefs)))
168 | 
169 |         scores = []
170 |         for test, refs in zip(self.ctest, self.crefs):
171 |             # compute vector for test captions
172 |             vec, norm, length = counts2vec(test)
173 |             # compute vector for ref captions
174 |             score = np.array([0.0 for _ in range(self.n)])
175 |             for ref in refs:
176 |                 vec_ref, norm_ref, length_ref = counts2vec(ref)
177 |                 score += sim(vec, vec_ref, norm, norm_ref, length, length_ref)
178 |             # change by vrama91 - mean of ngram scores, instead of sum
179 |             score_avg = np.mean(score)
180 |             # divide by number of references
181 |             score_avg /= len(refs)
182 |             # multiply score by 10
183 |             score_avg *= 10.0
184 |             # append score of an image to the score list
185 |             scores.append(score_avg)
186 |         return scores
187 | 
188 |     def compute_score(self, option=None, verbose=0):
189 |         # compute idf
190 |         self.compute_doc_freq()
191 |         # assert to check document frequency
192 |         assert(len(self.ctest) >= max(self.document_frequency.values()))
193 |         # compute cider score
194 |         score = self.compute_cider()
195 |         # debug
196 |         # print score
197 |         return np.mean(np.array(score)), np.array(score)
--------------------------------------------------------------------------------
/Code/Evaluator/METEOR/data/paraphrase-en.gz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LibertFan/ImageCaption/b43b3701dba8db1dad70eab2de438258b7dc7847/Code/Evaluator/METEOR/data/paraphrase-en.gz
--------------------------------------------------------------------------------
/Code/Evaluator/METEOR/meteor-1.5.jar:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LibertFan/ImageCaption/b43b3701dba8db1dad70eab2de438258b7dc7847/Code/Evaluator/METEOR/meteor-1.5.jar
--------------------------------------------------------------------------------
/Code/Evaluator/METEOR/meteor.py:
--------------------------------------------------------------------------------
1 | # Python wrapper for METEOR implementation, by Xinlei Chen
2 | # Acknowledge Michael Denkowski for the generous discussion and help
3 | 
4 | import os
5 | import sys
6 | import subprocess
7 | import threading
8 | 
9 | # Assumes meteor-1.5.jar is in the same directory as meteor.py. Change as needed.
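# A sketch of the '-stdio' exchange the wrapper below drives (line formats
# taken from the calls in _stat, _score and compute_score):
#   write: SCORE ||| reference 1 ||| ... ||| reference n ||| hypothesis
#   read:  one statistics line per segment
#   write: EVAL ||| stats_1 ||| stats_2 ||| ...
#   read:  one score line per segment, followed by the aggregate score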
10 | METEOR_JAR = 'meteor-1.5.jar'
11 | 
12 | 
13 | # print METEOR_JAR
14 | 
15 | class METEOREvaluator:
16 |     def __init__(self):
17 |         self.meteor_cmd = ['java', '-jar', '-Xmx2G', METEOR_JAR, '-', '-', '-stdio', '-l', 'en', '-norm']
18 |         self.meteor_p = subprocess.Popen(self.meteor_cmd,
19 |                                          cwd=os.path.dirname(os.path.abspath(__file__)),
20 |                                          stdin=subprocess.PIPE,
21 |                                          stdout=subprocess.PIPE,
22 |                                          stderr=subprocess.PIPE, encoding='utf-8', bufsize=0)
23 |         # Used to guarantee thread safety
24 |         self.lock = threading.Lock()
25 | 
26 |     def compute_score(self, gts, res, mode='all'):
27 |         assert (gts.keys() == res.keys())
28 |         imgIds = gts.keys()
29 |         scores = []
30 | 
31 |         eval_line = 'EVAL'
32 |         self.lock.acquire()
33 |         for i in imgIds:
34 |             assert (len(res[i]) == 1)
35 |             stat, score_line = self._stat(res[i][0], gts[i])
36 |             eval_line += ' ||| {}'.format(stat)
37 | 
38 |         self.meteor_p.stdin.write('{}\n'.format(eval_line))
39 |         for i in range(0, len(imgIds)):
40 |             score = self.meteor_p.stdout.readline().strip()
41 |             scores.append(float(score))
42 |         score = float(self.meteor_p.stdout.readline().strip())
43 |         self.lock.release()
44 |         if mode == 'all':
45 |             return [score]
46 |         elif mode == 'every':
47 |             return scores
48 |         else:
49 |             return score, scores
50 | 
51 |     def method(self):
52 |         return "METEOR"
53 | 
54 |     def _stat(self, hypothesis_str, reference_list):
55 |         # SCORE ||| reference 1 words ||| reference n words ||| hypothesis words
56 |         hypothesis_str = hypothesis_str.replace('|||', '').replace('  ', ' ')
57 |         score_line = ' ||| '.join(('SCORE', ' ||| '.join(reference_list), hypothesis_str))
58 |         self.meteor_p.stdin.write('{}\n'.format(score_line))
59 |         out = self.meteor_p.stdout.readline().strip()
60 |         return out, score_line
61 | 
62 |     def _score(self, hypothesis_str, reference_list):
63 |         self.lock.acquire()
64 |         # SCORE ||| reference 1 words ||| reference n words ||| hypothesis words
65 |         hypothesis_str = hypothesis_str.replace('|||', '').replace('  ', ' ')
66 |         score_line = ' ||| '.join(('SCORE', ' ||| '.join(reference_list), hypothesis_str))
67 |         self.meteor_p.stdin.write('{}\n'.format(score_line))
68 |         stats = self.meteor_p.stdout.readline().strip()
69 |         eval_line = 'EVAL ||| {}'.format(stats)
70 |         # EVAL ||| stats
71 |         self.meteor_p.stdin.write('{}\n'.format(eval_line))
72 |         score = float(self.meteor_p.stdout.readline().strip())
73 |         # bug fix: there are two values returned by the jar file, one average and one all, so read twice
74 |         # thanks to Andrej for pointing this out
75 |         score = float(self.meteor_p.stdout.readline().strip())
76 |         self.lock.release()
77 |         return score
78 | 
79 |     def __del__(self):
80 |         self.lock.acquire()
81 |         self.meteor_p.stdin.close()
82 |         self.meteor_p.kill()
83 |         self.meteor_p.wait()
84 |         self.lock.release()
85 | 
86 | 
87 | def main():
88 |     hypo = {'1': ['I like it.i Love You..next'], '2': ['I completely do not know !']}
89 |     ref = {'1': ['I love you !'], '2': ['I do not know !']}
90 |     for i in range(30):
91 |         hypo[i] = ['The dismantling of the Punggye-ri site, the exact date of which will depend on ' \
92 |                    'weather conditions, will involve the collapsing of all tunnels using explosives ' \
93 |                    'and the removal of all observation facilities, ' \
94 |                    'research buildings and security posts. Journalists from South Korea, China, the US, ' \
95 |                    'the UK and Russia will be asked to attend to witness the event.' \
96 |                    'North Korea said the intention was to allow ' \
97 |                    'not only the local press but also journalists of other countries to conduct on-the-spot coverage ' \
98 |                    'in order to show in a transparent manner the dismantlement of the northern nuclear test ground.' \
99 |                    'The reason officials gave for limiting the number of countries invited to send journalists ' \
100 |                    'was due to the small space of the test ground... located in the uninhabited deep mountain area.']
101 |         ref[i] = ['There is a "sense of optimism" among North Korea\'s leaders, the head of the UN\'s ' \
102 |                   'World Food Programme (WFP) said on Saturday after enjoying what he said was ' \
103 |                   'unprecedented access to the country. David Beasley spent two days in the capital, Pyongyang, ' \
104 |                   'and two outside it, accompanied by government minders. He said the country was working hard ' \
105 |                   'to meet nutritional standards, and hunger was not as high as in the 1990s. Mr Beasley\'s visit, ' \
106 |                   'from 8-11 May, included trips to WFP-funded projects - a children\'s ' \
107 |                   'nursery in South Hwanghae province and a fortified biscuit factory in North North Pyongyan province.']
108 |     meteor = METEOREvaluator()
109 |     score = meteor.compute_score(ref, hypo, 'every')
110 |     print(score)
111 |     pass
112 | 
113 | 
114 | if __name__ == '__main__':
115 |     main()
116 | 
--------------------------------------------------------------------------------
/Code/Evaluator/ROUGE/rouge.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | #
3 | # File Name : rouge.py
4 | #
5 | # Description : Computes ROUGE-L metric as described by Lin (2004)
6 | #
7 | # Creation Date : 2015-01-07 06:03
8 | # Author : Ramakrishna Vedantam
9 | 
10 | import numpy as np
11 | import pdb
12 | 
13 | 
14 | def my_lcs(string, sub):
15 |     """
16 |     Calculates longest common subsequence for a pair of tokenized strings
17 |     :param string : list of str : tokens from a string split using whitespace
18 |     :param sub : list of str : shorter string, also split using whitespace
19 |     :returns: length (int): length of the longest common subsequence between the two strings
20 | 
21 |     Note: my_lcs only gives length of the longest common subsequence, not the actual LCS
22 |     """
23 |     if len(string) < len(sub):
24 |         sub, string = string, sub
25 | 
26 |     lengths = [[0 for i in range(0, len(sub)+1)] for j in range(0, len(string)+1)]
27 | 
28 |     for j in range(1, len(sub)+1):
29 |         for i in range(1, len(string)+1):
30 |             if string[i-1] == sub[j-1]:
31 |                 lengths[i][j] = lengths[i-1][j-1] + 1
32 |             else:
33 |                 lengths[i][j] = max(lengths[i-1][j], lengths[i][j-1])
34 | 
35 |     return lengths[len(string)][len(sub)]
36 | 
37 | 
38 | class RougeEvaluator(object):
39 |     """ Class for computing ROUGE-L score for a set of candidate sentences for the MS COCO test set
40 |     """
41 |     def __init__(self):
42 |         # vrama91: updated the value below based on discussion with Hovy
43 |         self.beta = 1.2
44 | 
45 |     def calc_score(self, candidate, refs):
46 |         """
47 |         Compute ROUGE-L score given one candidate and references for an image
48 |         :param candidate: list of str : single candidate sentence to be evaluated
49 |         :param refs: list of str : COCO reference sentences for the particular image to be evaluated
50 |         :returns score: float (ROUGE-L score for the candidate evaluated against references)
51 |         """
52 |         assert(len(candidate) == 1)
53 |         assert(len(refs) > 0)
54 |         prec = []
55 |         rec = []
56 | 
57 |         # split into tokens
58 |         token_c = candidate[0].split(" ")
59 | 
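        # Worked example (illustrative): candidate "the cat sat" against
        # reference "the cat was sat" gives lcs = 3, precision 3/3 = 1.0 and
        # recall 3/4 = 0.75, so with beta = 1.2 the F-score below is
        # ((1 + 1.44) * 1.0 * 0.75) / (0.75 + 1.44 * 1.0) ~= 0.836.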
60 |         for reference in refs:
61 |             # split into tokens
62 |             token_r = reference.split(" ")
63 |             # compute the longest common subsequence
64 |             lcs = my_lcs(token_r, token_c)
65 |             prec.append(lcs/float(len(token_c)))
66 |             rec.append(lcs/float(len(token_r)))
67 | 
68 |         prec_max = max(prec)
69 |         rec_max = max(rec)
70 | 
71 |         if prec_max != 0 and rec_max != 0:
72 |             score = ((1 + self.beta**2)*prec_max*rec_max)/float(rec_max + self.beta**2*prec_max)
73 |         else:
74 |             score = 0.0
75 |         score = [score]
76 |         return score
77 | 
78 |     def compute_score(self, gts, res, mode='all'):
79 |         """
80 |         Computes Rouge-L score given a set of reference and candidate sentences for the dataset
81 |         Invoked by evaluate_captions.py
82 |         :param gts: dict : reference MS-COCO sentences with "image name" key and "tokenized sentences" as values
83 |         :param res: dict : candidate / test sentences with "image name" key and "tokenized sentences" as values
84 |         :returns: average_score: float (mean ROUGE-L score computed by averaging scores for all the images)
85 |         """
86 |         assert(gts.keys() == res.keys())
87 |         imgIds = gts.keys()
88 | 
89 |         score = []
90 |         for id in imgIds:
91 |             hypo = res[id]
92 |             ref = gts[id]
93 |             # Sanity check.
94 |             assert(type(hypo) is list)
95 |             # assert(len(hypo) == 1)
96 |             assert(type(ref) is list)
97 |             assert(len(ref) > 0)
98 |             for h in hypo:
99 |                 score.append(self.calc_score([h], ref))
100 |         average_score = np.mean(np.array(score))
101 |         if mode == 'all':
102 |             return [float(average_score)]
103 |         else:
104 |             return average_score, score
105 | 
106 |     def method(self):
107 |         return "Rouge"
108 | 
109 | 
110 | def main():
111 |     hypo = {'1': ['I like it !'], '2': ['I completely do not know !'],
112 |             '3': ['how about you ?'], '4': ['what is this ?'], '5': ['this is amazing !']}
113 |     ref = {'1': ['I love you !'], '2': ['I do not know !'], '3': ['how are you ?'],
114 |            '4': ['what is this animal ?'], '5': ['this is awkward !']}
115 |     rouge = RougeEvaluator()
116 |     score = rouge.compute_score(ref, hypo)
117 |     print(score)
118 | 
119 | 
120 | if __name__ == '__main__':
121 |     main()
--------------------------------------------------------------------------------
/Code/FeatureExtractorPytorch.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | from torchvision import models
4 | import numpy as np
5 | from keras.preprocessing import image
6 | import multiprocessing as mp
7 | os.environ["CUDA_VISIBLE_DEVICES"] = '3'
8 | 
9 | 
10 | def get_img(path):
11 |     img = image.load_img(path)
12 |     x = image.img_to_array(img, data_format='channels_first')
13 |     channel, s1, s2 = x.shape
14 |     if s1 < 224 or s2 < 224:
15 |         target_size = (max(224, s1), max(224, s2))
16 |         img = image.load_img(path, target_size=target_size)
17 |         x = image.img_to_array(img, data_format='channels_first')
18 |     if len(x.shape) != 3:
19 |         print('[Image\'s shape gets wrong.
FileName: {}, Shape: {}]'.format(path, x.shape)) 20 | return False, x 21 | else: 22 | return True, x 23 | 24 | 25 | class ResNet152Extractor: 26 | def __init__(self): 27 | base_model = models.resnet152(pretrained=True) 28 | modules = list(base_model.children())[:-1] 29 | self.base_model = torch.nn.Sequential(*modules) 30 | for p in self.base_model.parameters(): 31 | p.requires_grad = False 32 | if torch.cuda.is_available(): 33 | self.base_model = self.base_model.cuda() 34 | torch.set_default_tensor_type('torch.cuda.FloatTensor') 35 | self.base_model.eval() 36 | 37 | def get_feature(self, img_path_sequence): 38 | self.base_model.eval() 39 | sequence_size = len(img_path_sequence) 40 | cpu_num = mp.cpu_count() - 1 41 | chunk_size = 8 42 | index = 0 43 | feature_sequence = [] 44 | while True: 45 | start_idx, end_idx = 512 * index, 512 * (index + 1) 46 | index += 1 47 | if start_idx >= sequence_size: 48 | break 49 | img_paths = img_path_sequence[start_idx: end_idx] 50 | img_sequence = [] 51 | with mp.Pool(processes=cpu_num) as pool: 52 | records = pool.map(get_img, img_paths, chunk_size) 53 | pool.close() 54 | for i, record in enumerate(records): 55 | correct_or, img = record 56 | if correct_or: 57 | img_sequence.append(img) 58 | else: 59 | raise Exception('LoadImgError') 60 | for img in img_sequence: 61 | x = np.expand_dims(img, 0) 62 | x = x / 255. 63 | mean = [0.485, 0.456, 0.406] 64 | std = [0.229, 0.224, 0.225] 65 | x[:, 0, :, :] -= mean[0] 66 | x[:, 1, :, :] -= mean[1] 67 | x[:, 2, :, :] -= mean[2] 68 | x[:, 0, :, :] /= std[0] 69 | x[:, 1, :, :] /= std[1] 70 | x[:, 2, :, :] /= std[2] 71 | x = torch.from_numpy(x).float() 72 | if torch.cuda.is_available(): 73 | x = x.cuda() 74 | feature = self.base_model(x) 75 | feature = feature.mean(-1).mean(-1) 76 | feature = feature.data.cpu().numpy() 77 | feature_sequence.append(feature) 78 | feature_sequence = np.concatenate(feature_sequence, 0) 79 | return feature_sequence 80 | -------------------------------------------------------------------------------- /Code/Generator.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as func 3 | import numpy as np 4 | 5 | 6 | class Generate(torch.nn.Module): 7 | def __init__(self, opts, streamer): 8 | super(Generate, self).__init__() 9 | self._options = opts 10 | self.word2vec = streamer.word2vec 11 | self.translator = streamer.tokens2sentence 12 | self.measurer = streamer.every_measure_score 13 | self.img_encoder = ImgEncoder(opts) 14 | self.decoder = Decoder(opts, self.word2vec) 15 | self.beam_searcher = BeamSearch(opts) 16 | self.optimizer = torch.optim.Adam(params=self.parameters(), lr=opts.learning_rate) 17 | 18 | def forward(self, img, concept_score, vocab, caption=None, str_caption=None, mode='Generate_MLE'): 19 | self.train() 20 | opts = self._options 21 | caption = caption.view(-1, opts.caption_size) 22 | img_vec = self.img_encoder(img) 23 | if mode == 'Generate_MLE': 24 | if img_vec.size(0) < caption.size(0): 25 | img_vec_size = img_vec.size(-1) 26 | rep_size = caption.size(0) // img_vec.size(0) 27 | img_vec = img_vec.unsqueeze(1).repeat(1, rep_size, 1).view(-1, img_vec_size) 28 | init_vec = [img_vec] 29 | gen_captions, word_prob = self.professor_decode(init_vec, concept_score, caption) 30 | word_loss = self.solve_word_loss(caption, word_prob) 31 | reg_loss = self.solve_reg_loss() 32 | loss = word_loss + reg_loss 33 | self.backward(loss) 34 | if opts.display: 35 | 
print('[WordLoss:{:.5f}][RegLoss:{:.5f}][Loss:{:.5f}]'.format(word_loss, reg_loss, loss)) 36 | else: 37 | assert str_caption is not None 38 | batch_size = len(str_caption) 39 | self.eval() 40 | greedy_result, greedy_prob = self.greedy_decode([img_vec], concept_score, vocab) 41 | self.train() 42 | random_result, random_prob = self.sample_decode([img_vec], concept_score, vocab, opts.sample_decode_size) 43 | result = torch.cat([greedy_result.view(batch_size, -1, opts.caption_size), 44 | random_result.view(batch_size, -1, opts.caption_size)], 1) 45 | score = self.get_reward(result, str_caption).data 46 | base_score = score[:, 0] 47 | random_score = score[:, 1:].view(-1) 48 | if base_score.size(0) < random_score.size(0): 49 | rep_size = random_score.size(0) // base_score.size(0) 50 | reward = random_score.sub(base_score.repeat(rep_size)) 51 | else: 52 | reward = random_score.sub(base_score) 53 | reinforce_loss = self.solve_reinforce_loss(random_prob, random_result, reward) 54 | reg_loss = self.solve_reg_loss() 55 | loss = reinforce_loss + reg_loss 56 | self.backward(loss) 57 | if opts.display: 58 | print('[ReinforceLoss:{:.5f}][RegLoss:{:.5f}][Loss:{:.5f}]'. 59 | format(reinforce_loss, reg_loss, loss)) 60 | 61 | def professor_decode(self, init_vec, concept_score, captions): 62 | opts = self._options 63 | assert isinstance(init_vec, list) 64 | captions = captions.view(-1, opts.caption_size) 65 | batch_size = captions.size(0) 66 | caption_sequence = captions.unbind(-1) 67 | cell_state = self.decoder.cell_init(batch_size, concept_score) 68 | word_prob_sequence = [] 69 | for t in range(opts.caption_size+len(init_vec)): 70 | if t < len(init_vec): 71 | cell_input = init_vec[t] 72 | repeat_size = batch_size // cell_input.size(0) 73 | cell_input_size = cell_input.size(-1) 74 | cell_input = cell_input.unsqueeze(1).repeat(1, repeat_size, 1).view(-1, cell_input_size) 75 | elif t == len(init_vec): 76 | start_token = torch.ones(batch_size).long() 77 | if torch.cuda.is_available(): 78 | start_token = start_token.cuda() 79 | cell_input = start_token 80 | else: 81 | cell_input = caption_sequence[t-len(init_vec)-1] 82 | word_emb = False if t < len(init_vec) else True 83 | word_project = False if t < len(init_vec) else True 84 | word_prob, cell_state = self.decoder( 85 | cell_input, cell_state, word_emb=word_emb, word_project=word_project) 86 | if t >= len(init_vec): 87 | word_prob_sequence.append(word_prob) 88 | word_prob = torch.stack(word_prob_sequence, -1) 89 | return captions, word_prob 90 | 91 | def greedy_decode(self, init_vec, concept_score, vocab): 92 | opts = self._options 93 | assert isinstance(init_vec, list) 94 | batch_size = init_vec[0].size(0) 95 | cell_state = self.decoder.cell_init(batch_size, concept_score, vocab) 96 | word_score_sequence = [] 97 | caption_sequence = [] 98 | for t in range(opts.caption_size+len(init_vec)): 99 | if t < len(init_vec): 100 | cell_input = init_vec[t] 101 | elif t == len(init_vec): 102 | start_token = torch.ones(batch_size).long() 103 | if torch.cuda.is_available(): 104 | start_token = start_token.cuda() 105 | cell_input = start_token 106 | else: 107 | cell_input = caption_sequence[-1] 108 | word_emb = False if t < len(init_vec) else True 109 | word_project = False if t < len(init_vec) else True 110 | word_score, cell_state = self.decoder( 111 | cell_input, cell_state, word_emb=word_emb, word_project=word_project) 112 | if t >= len(init_vec): 113 | cell_word_score, cell_input = word_score.max(-1) 114 | caption_sequence.append(cell_input) 115 | 
word_score_sequence.append(cell_word_score) 116 | word_scores = torch.stack(word_score_sequence, -1) 117 | captions = torch.stack(caption_sequence, -1) 118 | masks = captions.le(2.5).float().cumsum(-1).cumsum(-1).le(1.5).long() 119 | captions = captions.mul(masks) 120 | word_scores = word_scores.mul(masks.float()) 121 | return captions, word_scores 122 | 123 | def sample_decode(self, init_vec, concept_score, vocab, gen_size=1): 124 | opts = self._options 125 | assert isinstance(init_vec, list) 126 | batch_size = init_vec[0].size(0) * gen_size 127 | cell_state = self.decoder.cell_init(batch_size, concept_score, vocab) 128 | word_score_sequence = [] 129 | caption_sequence = [] 130 | for t in range(opts.caption_size+len(init_vec)): 131 | if t < len(init_vec): 132 | cell_input = init_vec[t] 133 | cell_input_size = cell_input.size(-1) 134 | cell_input = cell_input.unsqueeze(1).repeat(1, gen_size, 1).view(-1, cell_input_size) 135 | elif t == len(init_vec): 136 | start_token = torch.ones(batch_size).long() 137 | if torch.cuda.is_available(): 138 | start_token = start_token.cuda() 139 | cell_input = start_token 140 | else: 141 | cell_input = caption_sequence[-1] 142 | word_emb = False if t < len(init_vec) else True 143 | word_project = False if t < len(init_vec) else True 144 | word_score, cell_state = self.decoder( 145 | cell_input, cell_state, word_emb=word_emb, word_project=word_project) 146 | if t >= len(init_vec): 147 | word_prob = word_score.data.log().div(opts.temperature).exp() 148 | cell_input = torch.multinomial(word_prob, 1).data 149 | cell_word_score = word_score.gather(1, cell_input).squeeze(-1) 150 | word_score_sequence.append(cell_word_score) 151 | cell_input = cell_input.squeeze(-1) 152 | caption_sequence.append(cell_input) 153 | word_scores = torch.stack(word_score_sequence, -1) 154 | captions = torch.stack(caption_sequence, -1) 155 | masks = captions.le(2.5).float().cumsum(-1).cumsum(-1).le(1.5).long() 156 | captions = captions.mul(masks) 157 | word_scores = word_scores.mul(masks.float()) 158 | return captions, word_scores 159 | 160 | def get_reward(self, captions, str_captions): 161 | opts = self._options 162 | captions = captions.view(-1, opts.caption_size) 163 | hypos, refs = dict(), dict() 164 | hypo_size = captions.size(0) 165 | ref_size = len(str_captions) 166 | group_size = hypo_size // ref_size 167 | for idx, num_caption in enumerate(captions.data.cpu().numpy()): 168 | gen_token = self.translator(num_caption) 169 | hypos[idx] = [gen_token] 170 | raw_caption = str_captions[idx // group_size] 171 | if isinstance(raw_caption, str): 172 | raw_caption = [raw_caption] 173 | elif not isinstance(raw_caption, list): 174 | raw_caption = list(raw_caption) 175 | refs[idx] = raw_caption 176 | scores = self.measurer(hypos, refs) 177 | scores = torch.from_numpy(scores).float() 178 | if torch.cuda.is_available(): 179 | scores = scores.cuda() 180 | if group_size is not None: 181 | scores = scores.view(-1, group_size) 182 | return scores 183 | 184 | def backward(self, loss): 185 | opts = self._options 186 | optimizer = self.optimizer 187 | optimizer.zero_grad() 188 | loss.backward() 189 | if opts.grad_value_clip: 190 | torch.nn.utils.clip_grad_value_(self.parameters(), opts.grad_clip_value) 191 | elif opts.grad_norm_clip: 192 | torch.nn.utils.clip_grad_norm_(self.parameters(), opts.grad_norm_clip_value) 193 | optimizer.step() 194 | 195 | def generate(self, img, concept_score, vocab=None): 196 | self.eval() 197 | opts = self._options 198 | # iterative 199 | img_vec = self.img_encoder(img) 
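        # Decoding runs in three phases below: the encoded image vector is fed
        # through the cell without word embedding or projection, a GO token
        # (id 1, per the PAD/GO/EOS vocabulary ordering) then starts word
        # emission, and BeamSearch keeps opts.beam_size candidates per step
        # for opts.caption_size steps.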
200 | init_vec = [img_vec] 201 | batch_size = img_vec.size(0) 202 | cell_state = self.decoder.cell_init(batch_size, concept_score, vocab) 203 | for vec in init_vec: 204 | if vec.size(0) < batch_size: 205 | group_size = batch_size // vec.size(0) 206 | vec_size = vec.size(-1) 207 | vec = vec.unsqueeze(1).repeat(1, group_size, 1).view(-1, vec_size) 208 | _, cell_state = self.decoder(vec, cell_state, False, False) 209 | start_token = torch.ones(batch_size).long() 210 | if torch.cuda.is_available(): 211 | start_token = start_token.cuda() 212 | gen_tokens = self.beam_searcher.forward( 213 | opts.caption_size, self.decoder, start_token, cell_state) 214 | gen_tokens = gen_tokens.view(batch_size, opts.caption_size) 215 | return gen_tokens 216 | 217 | @staticmethod 218 | def solve_reinforce_loss(word_prob, caption, reward): 219 | mask = caption.le(2.5).float().cumsum(-1).le(1.5).float().data 220 | every_loss = word_prob.clamp(1e-10).log().mul(reward.unsqueeze(-1)).neg() 221 | loss = every_loss.mul(mask).sum(-1).div(mask.sum(-1)).mean(0) 222 | return loss 223 | 224 | @staticmethod 225 | def solve_word_loss(ref_words, word_prob): 226 | word_size = word_prob.size(-1) 227 | ref_words = ref_words.index_select(-1, index=torch.arange(word_size).long()) 228 | word_loss = func.nll_loss(word_prob.log(), ref_words, reduction='none') 229 | mask = (ref_words > 0.5).float() 230 | word_loss = word_loss.mul(mask).sum().div(mask.sum()) 231 | return word_loss 232 | 233 | def solve_reg_loss(self, scope=None): 234 | opts, l1_loss_sum, l2_loss_sum = self._options, 0.0, 0.0 235 | if scope is None: 236 | named_parameters = self.named_parameters() 237 | else: 238 | named_parameters = scope.named_parameters() 239 | for name, param in named_parameters: 240 | if 'word_emb' not in name: 241 | l1_loss_sum += param.abs().sum() 242 | l2_loss_sum += param.pow(2).sum() 243 | reg_loss = opts.l1_factor * l1_loss_sum + opts.l2_factor * l2_loss_sum 244 | return reg_loss 245 | 246 | 247 | class ImgEncoder(torch.nn.Module): 248 | def __init__(self, opts): 249 | super(ImgEncoder, self).__init__() 250 | self._options = opts 251 | self.bn_l0 = torch.nn.BatchNorm1d(num_features=opts.img_size, eps=1e-05, momentum=0.1, 252 | affine=True, track_running_stats=True) 253 | self.fc_l1 = torch.nn.Linear(opts.img_size, opts.img_size, bias=False) 254 | self.bn_l1 = torch.nn.BatchNorm1d(num_features=opts.img_size, eps=1e-05, momentum=0.1, 255 | affine=True, track_running_stats=True) 256 | self.fc_l2 = torch.nn.Linear(opts.img_size, opts.word_emb_size, bias=False) 257 | self.bn_l2 = torch.nn.BatchNorm1d(num_features=opts.word_emb_size, eps=1e-05, momentum=0.1, 258 | affine=True, track_running_stats=True) 259 | self.var_init() 260 | 261 | def forward(self, img): 262 | img_vec = self.bn_l2(func.relu(self.fc_l2( 263 | self.bn_l1(func.relu(self.fc_l1( 264 | self.bn_l0(img) 265 | ))) 266 | ))) 267 | return img_vec 268 | 269 | def var_init(self): 270 | torch.nn.init.normal_(tensor=self.bn_l0.weight) 271 | torch.nn.init.constant_(tensor=self.bn_l0.bias, val=0.0) 272 | torch.nn.init.xavier_uniform_(tensor=self.fc_l1.weight) 273 | torch.nn.init.normal_(tensor=self.bn_l1.weight) 274 | torch.nn.init.constant_(tensor=self.bn_l1.bias, val=0.0) 275 | torch.nn.init.xavier_uniform_(tensor=self.fc_l2.weight) 276 | torch.nn.init.normal_(tensor=self.bn_l2.weight) 277 | torch.nn.init.constant_(tensor=self.bn_l2.bias, val=0.0) 278 | 279 | 280 | class Decoder(torch.nn.Module): 281 | def __init__(self, opts, word2vec=None): 282 | super(Decoder, self).__init__() 283 | 
self._options = opts 284 | self.word2vec = word2vec 285 | self.word_emb = torch.nn.Embedding(opts.word_num, opts.word_emb_size) 286 | self.cell = SCNCore(opts, opts.word_emb_size, opts.rnn_size, opts.word_num) 287 | self.word_project = torch.nn.Linear(opts.rnn_size, opts.word_num, True) 288 | self.dropout = torch.nn.Dropout(opts.dropout_rate) 289 | self.concept_score = torch.zeros(opts.batch_size, opts.word_num) 290 | self.vocab = None 291 | self.var_init() 292 | 293 | def forward(self, cell_input, cell_state, word_emb=True, word_project=True): 294 | if word_emb: 295 | cell_input = self.word_emb(cell_input) 296 | cell_input = self.dropout(cell_input) 297 | cell_state = self.cell(cell_input, cell_state) 298 | if word_project: 299 | feature = cell_state[0] 300 | word_score = self.word_project(feature) 301 | if self.vocab is not None: 302 | vocab = self.vocab 303 | if feature.size(0) > vocab.size(0): 304 | rep_size = feature.size(0) // vocab.size(0) 305 | vocab_size = vocab.size(-1) 306 | vocab = vocab.unsqueeze(1).repeat(1, rep_size, 1).view(-1, vocab_size) 307 | word_prob = func.softmax(word_score.masked_fill(vocab.le(0.5), -np.inf), -1).clamp(1e-20) 308 | else: 309 | word_prob = func.softmax(word_score, -1) 310 | return word_prob, cell_state 311 | return None, cell_state 312 | 313 | def var_init(self): 314 | if self.word2vec is not None: 315 | self.word_emb.from_pretrained( 316 | embeddings=torch.from_numpy(self.word2vec), freeze=False) 317 | else: 318 | torch.nn.init.uniform_(self.word_emb.weight, -0.08, 0.08) 319 | torch.nn.init.xavier_uniform_(tensor=self.word_project.weight) 320 | 321 | def cell_init(self, batch_size, concept, vocab=None): 322 | opts = self._options 323 | if vocab is not None: 324 | vocab = torch.zeros(batch_size, opts.word_num).scatter(1, vocab, 1).long() 325 | self.vocab = vocab 326 | return self.cell.cell_init(batch_size, concept) 327 | 328 | 329 | class SCNCore(torch.nn.Module): 330 | def __init__(self, opts, input_size, hidden_size, concept_size): 331 | super(SCNCore, self).__init__() 332 | self.hidden_size = hidden_size 333 | self.mix_input_w = torch.nn.Linear(4*hidden_size, 4*hidden_size, True) 334 | self.mix_state_w = torch.nn.Linear(4*hidden_size, 4*hidden_size, True) 335 | self.input_w = torch.nn.Linear(input_size, 4*hidden_size, False) 336 | self.state_w = torch.nn.Linear(hidden_size, 4*hidden_size, False) 337 | self.concept_input_w = torch.nn.Linear(concept_size, 4*hidden_size, False) 338 | self.concept_state_w = torch.nn.Linear(concept_size, 4*hidden_size, False) 339 | self.concept = None 340 | self.dropout = torch.nn.Dropout(opts.dropout_rate) 341 | self.concept_dropout = torch.nn.Dropout(opts.concept_dropout_rate) 342 | self.var_init() 343 | 344 | def var_init(self): 345 | torch.nn.init.xavier_normal_(self.mix_input_w.weight) 346 | torch.nn.init.xavier_normal_(self.mix_state_w.weight) 347 | torch.nn.init.constant_(self.mix_input_w.bias, 0.0) 348 | torch.nn.init.constant_(self.mix_state_w.bias, 0.0) 349 | torch.nn.init.xavier_normal_(self.input_w.weight) 350 | torch.nn.init.xavier_normal_(self.state_w.weight) 351 | torch.nn.init.xavier_normal_(self.concept_input_w.weight) 352 | torch.nn.init.xavier_normal_(self.concept_state_w.weight) 353 | 354 | def forward(self, core_input, state): 355 | assert (isinstance(state, tuple) or isinstance(state, list)) and len(state) == 2 356 | h_state, c_state = state 357 | concept = self.concept 358 | if concept.size(0) < core_input.size(0): 359 | rep_size, concept_size = core_input.size(0) // concept.size(0), 
359 | rep_size, concept_size = core_input.size(0) // concept.size(0), concept.size(-1) 360 | concept = concept.unsqueeze(1).repeat(1, rep_size, 1).view(-1, concept_size)  # unsqueeze along dim 1 (not -1) so each row is repeated whole; unsqueeze(-1) interleaved feature values across rows 361 | hidden_vec = self.concept_input_w(concept).mul(self.input_w(core_input)). \ 362 | add(self.concept_state_w(concept).mul(self.state_w(h_state))).sigmoid() 363 | i_t, f_t, o_t, c_t = hidden_vec.split(hidden_vec.size(1) // 4, 1) 364 | c_t = i_t.mul(c_t).add(f_t.mul(c_state)) 365 | h_t = o_t.mul(c_t.tanh()) 366 | state = (h_t, c_t) 367 | return state 368 | 369 | def cell_init(self, batch_size, concept): 370 | self.concept = self.concept_dropout(concept.detach()) 371 | return (torch.zeros(batch_size, self.hidden_size, requires_grad=True).float(), 372 | torch.zeros(batch_size, self.hidden_size, requires_grad=True).float()) 373 | 374 | 375 | class BeamSearch(object): 376 | def __init__(self, opts): 377 | self._options = opts 378 | self.word_length, self.stops, self.prob = None, None, None 379 | self.batch_size = None 380 | self.time = None 381 | self.prev_index_sequence = None 382 | 383 | def init(self, batch_size): 384 | self.batch_size = batch_size 385 | self.word_length = torch.zeros(batch_size).to(torch.int64) 386 | self.stops = torch.zeros(batch_size).to(torch.int64) 387 | self.prob = torch.ones(batch_size) 388 | self.prev_index_sequence = list() 389 | 390 | def forward(self, length, cell, word, state, **kwargs): 391 | self.init(word.size(0)) 392 | word_list = [] 393 | for i in range(length): 394 | self.time = i 395 | word_prob, next_state = cell(word, state) 396 | word, state = self.step(next_state, word_prob) 397 | word_list.append(word) 398 | word = self.get_output_words(word_list) 399 | return word 400 | 401 | def get_output_words(self, word_list): 402 | opts = self._options 403 | word_sequence = [] 404 | index = torch.arange(self.batch_size).mul(opts.beam_size).long() 405 | prev_index_sequence = self.prev_index_sequence 406 | for word, prev_index in zip(word_list[::-1], prev_index_sequence[::-1]): 407 | output_word = word.index_select(0, index) 408 | index = prev_index.index_select(0, index) 409 | word_sequence.append(output_word) 410 | return torch.stack(word_sequence[::-1], 1) 411 | 412 | def step(self, next_state, word_prob): 413 | word_prob = self.solve_prob(word_prob) 414 | word_length = self.solve_length() 415 | next_word, prev_index = self.solve_score(word_prob, word_length) 416 | next_state = self.update(prev_index, next_word, next_state, word_prob) 417 | return next_word, next_state 418 | 419 | def solve_prob(self, word_prob): 420 | opts = self._options 421 | stops = self.stops 422 | stops = stops.unsqueeze(dim=-1) 423 | unstop_word_prob = torch.mul(word_prob, (1 - stops).float()) 424 | batch_size = self.batch_size if self.time == 0 else self.batch_size * opts.beam_size 425 | pad = torch.tensor([[opts.pad_id]]).long() 426 | if torch.cuda.is_available(): 427 | pad = pad.cuda() 428 | stop_prob = torch.zeros(1, opts.word_num).scatter_(1, pad, 1.0).repeat(batch_size, 1) 429 | stop_word_prob = stop_prob.mul(stops.float()) 430 | word_prob = unstop_word_prob.add(stop_word_prob) 431 | prob = self.prob 432 | prob = prob.unsqueeze(-1) 433 | word_prob = prob.mul(word_prob) 434 | return word_prob 435 | 436 | def solve_length(self): 437 | opts, stops, word_length = self._options, self.stops, self.word_length 438 | stops = stops.unsqueeze(dim=-1) 439 | word_length = word_length.unsqueeze(dim=-1) 440 | batch_size = self.batch_size if self.time == 0 else self.batch_size * opts.beam_size 441 | pad = torch.tensor([[opts.eos_id, opts.pad_id]]).long() 442 | if torch.cuda.is_available(): 443 | 
pad = pad.cuda() 444 | unstop_tokens = torch.ones(1, opts.word_num).scatter_(1, pad, 0.0).\ 445 | repeat(batch_size, 1).long() 446 | add_length = unstop_tokens.mul(1 - stops) 447 | word_length = word_length.add(add_length) 448 | return word_length 449 | 450 | def solve_score(self, word_prob, word_length): 451 | opts = self._options 452 | beam_size = 1 if self.time == 0 else opts.beam_size 453 | length_penalty = ((word_length + 5).float().pow(opts.length_penalty_factor)).\ 454 | div((torch.tensor([6.0])).pow(opts.length_penalty_factor)) 455 | word_score = word_prob.clamp(1e-20, 1.0).log().div(length_penalty) 456 | # mini = word_score.min() 457 | word_score = word_score.view(-1, beam_size * opts.word_num) 458 | beam_score, beam_words = word_score.topk(opts.beam_size) 459 | prev_index = torch.arange(self.batch_size).long().mul(beam_size).view(-1, 1).\ 460 | add(beam_words.div(opts.word_num)).view(-1)  # long-tensor div floors in PyTorch 1.0, recovering each candidate's source beam 461 | next_words = beam_words.fmod(opts.word_num).view(-1).long() 462 | self.prev_index_sequence.append(prev_index) 463 | return next_words, prev_index 464 | 465 | def update(self, index, word, state, prob): 466 | opts = self._options 467 | next_state = (state[0].index_select(0, index), state[1].index_select(0, index)) 468 | self.stops = word.le(opts.eos_id).long() 469 | self.prob = prob.index_select(0, index).gather(1, word.view(-1, 1)).squeeze(1) 470 | self.word_length = self.word_length.gather(0, index).add(1-self.stops) 471 | return next_state 472 | -------------------------------------------------------------------------------- /Code/OrderClusterStream.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import pickle 3 | import random 4 | import numpy as np 5 | import copy 6 | sys.path.append('..') 7 | from Evaluator.BLEU.bleu import BLEUEvaluator 8 | from Evaluator.CIDEr.cider import CIDErEvaluator 9 | from Evaluator.METEOR.meteor import METEOREvaluator 10 | from Evaluator.ROUGE.rouge import RougeEvaluator 11 | 12 | 13 | class DataStream(object): 14 | def __init__(self, opts): 15 | self._options = opts 16 | self.word_freq, self.id2word, self.word2id, self.raw_word2vec = self.get_vocab_info() 17 | (self._PAD, self.pad_id), (self._GO, self.go_id), (self._EOS, self.eos_id) = ("PAD", 0), ('GO', 1), ('EOS', 2) 18 | opts.pad_id, opts.go_id, opts.eos_id = self.pad_id, self.go_id, self.eos_id 19 | self.spec_tokens = [self._PAD, self._GO, self._EOS] 20 | opts.word_num = self.num_units = self.get_num_units() 21 | print('[Stream][Number of words: {}]'.format(self.num_units)) 22 | self.word2vec = self.raw_word2vec[:self.num_units] 23 | # load pickle file 24 | train_data_set, val_data_set, test_data_set = \ 25 | list(map(self.load_pickle_file, [opts.train_data_path, opts.val_data_path, opts.test_data_path])) 26 | print('[Stream][Data has been loaded]') 27 | self.raw_train_data_set = self.erase_pad(train_data_set) 28 | self.raw_val_data_set = self.erase_pad(val_data_set) 29 | self.raw_test_data_set = self.erase_pad(test_data_set) 30 | print('[Stream][Erasing caption padding has finished]') 31 | self.train_data_set = self.reshape(self.raw_train_data_set, plain=True, align=True) 32 | self.val_data_set = self.reshape(self.raw_val_data_set, plain=opts.data_plain, align=True, train=False) 33 | self.test_data_set = self.reshape(self.raw_test_data_set, plain=opts.data_plain, align=True, train=False) 34 | self.full_train_data_set = self.reshape(self.raw_train_data_set, plain=False, align=True, train=True) 35 | self.train_iter = self.val_iter = self.test_iter = self.train_full_iter = -1
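# note: the four cursors above track batch positions independently for the train / val / test / full-train splits; -1 marks a fresh pass, so the train fetchers reshuffle when they see it, and the val/test fetchers set it (returning None) once a pass is exhausted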
36 | self.train_data_size, self.val_data_size, self.test_data_size, self.full_train_data_size = \ 37 | len(self.train_data_set), len(self.val_data_set), len(self.test_data_set), len(self.full_train_data_set) 38 | self.evaluators = [BLEUEvaluator(4), METEOREvaluator(), RougeEvaluator(), CIDErEvaluator()] 39 | 40 | def get_num_units(self): 41 | opts = self._options 42 | if opts.word_num is not None: 43 | return opts.word_num 44 | elif opts.min_freq is not None: 45 | word_num = -1 46 | for i, (word, freq) in enumerate(self.word_freq): 47 | if i > len(self.spec_tokens) and freq < opts.min_freq: 48 | word_num = i 49 | break 50 | return word_num 51 | else: 52 | word_num = len(self.word_freq) 53 | return word_num 54 | 55 | @staticmethod 56 | def load_pickle_file(path): 57 | with open(path, 'rb') as f: 58 | data = pickle.load(f) 59 | f.close() 60 | return data 61 | 62 | def erase_pad(self, data_set): 63 | opts, num_units, eos_id, pad_id = self._options, self.num_units, self.eos_id, self.pad_id 64 | max_length = opts.caption_size 65 | img_idx_seq, captions_seq, feature_seq, num_captions_seq = data_set 66 | 67 | def remove_pad(num_captions): 68 | removed_captions = [] 69 | for num_caption in num_captions: 70 | removed_captions.append(remove_pad_single(num_caption)) 71 | num_captions = np.array(removed_captions) 72 | return num_captions 73 | 74 | def remove_pad_single(num_caption): 75 | removed_caption = [] 76 | for num in num_caption: 77 | if num < num_units: 78 | removed_caption.append(num) 79 | if len(removed_caption) > max_length - 1: 80 | removed_caption = removed_caption[:max_length - 1] + [eos_id] 81 | else: 82 | pad_length = max_length - 1 - len(removed_caption) 83 | removed_caption = removed_caption + [eos_id] + [pad_id] * pad_length 84 | assert len(removed_caption) == max_length 85 | return removed_caption 86 | 87 | captions_sequence = [[' '.join(caption) for caption in captions] for captions in captions_seq] 88 | num_captions_sequence = np.array(list(map(remove_pad, num_captions_seq))) 89 | new_data_set = [img_idx_seq, captions_sequence, feature_seq, num_captions_sequence] 90 | return new_data_set 91 | 92 | @staticmethod 93 | def reshape(data_set, plain=False, align=False, train=True): 94 | records = [] 95 | for img_idx, captions, feature, num_captions in \ 96 | zip(data_set[0], data_set[1], data_set[2], data_set[3]): 97 | if plain: 98 | for caption, num_caption in zip(captions, num_captions): 99 | records.append([img_idx, caption, feature, num_caption]) 100 | else: 101 | # the former 'align and train' and 'align' branches appended identical records, so they are merged here; the train flag stays in the signature for call-site compatibility 102 | if align: 103 | records.append([img_idx, captions[:5], feature, num_captions[:5]]) 104 | 105 | else: 106 | records.append([img_idx, captions, feature, num_captions]) 107 | return records 108 | 109 | @staticmethod 110 | def convert(records): 111 | column_num = len(records[0]) 112 | data_set = [[] for _ in range(column_num)] 113 | for record in records: 114 | for i, v in enumerate(record): 115 | data_set[i].append(v) 116 | records = [] 117 | for data in data_set: 118 | np_data = np.array(data) 119 | records.append(np_data) 120 | return records 121 | 122 | def get_next_train_batch(self): 123 | opts = self._options 124 | if self.train_iter == -1: 125 | random.shuffle(self.train_data_set) 126 | self.train_iter += 1 127 | start_idx = self.train_iter * opts.batch_size 128 | end_idx = (self.train_iter + 1) * opts.batch_size 129 | if start_idx >= self.train_data_size: 130 | self.train_iter = -1 131 | return 
self.get_next_train_batch() 132 | else: 133 | curr_train_batch = self.train_data_set[start_idx:end_idx] 134 | return self.convert(curr_train_batch) 135 | 136 | def get_next_val_batch(self): 137 | opts = self._options 138 | self.val_iter += 1 139 | start_idx = self.val_iter * opts.batch_size 140 | end_idx = (self.val_iter + 1) * opts.batch_size 141 | if start_idx >= self.val_data_size: 142 | self.val_iter = -1 143 | return None 144 | else: 145 | curr_val_batch = self.val_data_set[start_idx:end_idx] 146 | return self.convert(curr_val_batch) 147 | 148 | def get_next_test_batch(self): 149 | opts = self._options 150 | self.test_iter += 1 151 | start_idx = self.test_iter * opts.batch_size 152 | end_idx = (self.test_iter + 1) * opts.batch_size 153 | if start_idx >= self.test_data_size: 154 | self.test_iter = -1 155 | return None 156 | else: 157 | curr_test_batch = self.test_data_set[start_idx:end_idx] 158 | return self.convert(curr_test_batch) 159 | 160 | def get_next_full_train_batch(self): 161 | opts = self._options 162 | if self.train_full_iter == -1: 163 | random.shuffle(self.full_train_data_set) 164 | self.train_full_iter += 1 165 | start_idx = self.train_full_iter * opts.batch_size 166 | end_idx = (self.train_full_iter + 1) * opts.batch_size 167 | if end_idx >= self.full_train_data_size: 168 | self.train_full_iter = -1 169 | return self.get_next_full_train_batch() 170 | else: 171 | curr_train_batch = self.full_train_data_set[start_idx:end_idx] 172 | return self.convert(curr_train_batch) 173 | 174 | @staticmethod 175 | def data_shuffle(data): 176 | raw_idx = data[0] 177 | copy_data = copy.copy(data) 178 | num_seq = np.arange(len(raw_idx)) 179 | while True: 180 | random.shuffle(num_seq) 181 | copy_idx = raw_idx[num_seq] 182 | sign = (raw_idx == copy_idx).astype(np.int32).sum() 183 | if sign == 0: 184 | break 185 | for i, rec in enumerate(copy_data): 186 | copy_data[i] = rec[num_seq] 187 | return copy_data 188 | 189 | def get_vocab_info(self): 190 | opts = self._options 191 | return self.load_pickle_file(opts.vocab_info_path) 192 | 193 | def tokens2sentence(self, tokens): 194 | gen_question = [] 195 | for token in tokens: 196 | if token == self.pad_id or token == self.eos_id: 197 | break 198 | word = self.id2word.get(token) 199 | if word is None: 200 | break 201 | gen_question.append(word) 202 | if len(gen_question) == 0 or gen_question[-1] != '.': 203 | gen_question.append('.') 204 | return ' '.join(gen_question) 205 | 206 | def translate(self, gen_qs_tokens): 207 | sentences = [] 208 | for qs_tokens in gen_qs_tokens: 209 | try: 210 | sentences.append(self.tokens2sentence(qs_tokens)) 211 | except TypeError as e: 212 | print(qs_tokens) 213 | print(self.tokens2sentence(qs_tokens)) 214 | raise e 215 | return sentences 216 | 217 | @staticmethod 218 | def build_hypo_ref(image_ids, hypos, refs): 219 | gen_qs_num = int(len(hypos)/len(refs)) 220 | hypo_ref = [dict(), dict()] 221 | for n, image_id in enumerate(image_ids): 222 | image_name = str(image_id) 223 | hypo_ref[0][image_name] = [] 224 | hypo_ref[1][image_name] = list(refs[n]) 225 | for i in range(gen_qs_num): 226 | index = n * gen_qs_num + i 227 | hypo_ref[0][image_name].append(hypos[index]) 228 | return hypo_ref 229 | 230 | def measure_score(self, hypos, refs, mode='all'): 231 | scores = [] 232 | for evaluator in self.evaluators: 233 | scores.extend(evaluator.compute_score(refs, hypos, mode)) 234 | score_names = ['BLEU-1', 'BLEU-2', 'BLEU-3', 'BLEU-4', 'METEOR', 'ROUGE', 'CIDEr'] 235 | scores_dict = dict(zip(score_names, scores)) 236 | return 
scores_dict 237 | 238 | def every_measure_score(self, hypos, refs): 239 | cider_eval = self.evaluators[-1] 240 | cider_scores = cider_eval.compute_score(refs, hypos, 'every') 241 | scores = np.squeeze(np.array(cider_scores)) 242 | return scores 243 | 244 | def quick_measure_score(self, hypos, refs): 245 | meteor_score = self.evaluators[1].compute_score(refs, hypos)[0] 246 | cider_score = self.evaluators[3].compute_score(refs, hypos)[0] 247 | return {'METEOR': meteor_score, 'CIDEr': cider_score} 248 | 249 | @staticmethod 250 | def write_record(file_path, line): 251 | with open(file_path, 'a') as f: 252 | f.write(line) 253 | f.close() 254 | -------------------------------------------------------------------------------- /Code/Pipeline.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as func 3 | import numpy as np 4 | import sys 5 | import nltk 6 | sys.path.append('../..') 7 | from Retriever import Hierarchical as Retriever 8 | from Generator import Generate as Generator 9 | 10 | 11 | class Pipeline(torch.nn.Module): 12 | def __init__(self, opts, streamer): 13 | super(Pipeline, self).__init__() 14 | self._options = opts 15 | self.streamer = streamer 16 | word2vec = streamer.word2vec 17 | self.high_stat = self.set_high_words() 18 | high2vec = self.high_stat.get('High2Vec') 19 | opts.high_num, opts.high_emb_size = high2vec.shape 20 | opts.word_num, opts.word_emb_size = word2vec.shape 21 | print('[Word2Vec:{}]'.format(self.streamer.word2vec.shape)) 22 | print('[High2Vec:{}]'.format(high2vec.shape)) 23 | self.retriever = Retriever(opts, high2vec, word2vec) 24 | self.generator = Generator(opts, streamer) 25 | 26 | def set_high_words(self): 27 | opts = self._options 28 | from nltk.corpus import stopwords 29 | stopwords = stopwords.words('english') 30 | opts.high_freq_num = 400 31 | streamer = self.streamer 32 | high_freq_words = [] 33 | for i, (word, freq) in enumerate(streamer.word_freq): 34 | tag = nltk.pos_tag([word])[0][-1] 35 | if freq >= opts.high_freq_num and word not in stopwords and tag[0] in ['N', 'V', 'J', 'R']: 36 | high_freq_words.append(word) 37 | high_indices = [self.streamer.word2id.get(word) for word in high_freq_words] 38 | high2idx = {i: idx for i, idx in enumerate(high_indices)} 39 | idx2high = {idx: i for i, idx in high2idx.items()} 40 | high2vec = np.array([self.streamer.word2vec[idx] for idx in high_indices]) 41 | high_stat = {'Idx2High': idx2high, 'High2Idx': high2idx, 'High2Vec': high2vec} 42 | print('[SizeOfHighFreqWords:{}]'.format(len(high_indices))) 43 | return high_stat 44 | 45 | def solve_words(self, batch_captions): 46 | idx2high = self.high_stat.get('Idx2High') 47 | high_words, whole_words = [], [] 48 | high_size, whole_size = 0, 0 49 | for captions in batch_captions: 50 | if isinstance(captions[0], list) or isinstance(captions[0], np.ndarray): 51 | words = set() 52 | for sentence in captions: 53 | words = words | set(sentence) 54 | else: 55 | words = set(captions) 56 | whole_words.append(list(words)) 57 | high_word = [] 58 | for word in list(words): 59 | idx = idx2high.get(word) 60 | if idx is not None: 61 | high_word.append(idx) 62 | high_words.append(high_word) 63 | if len(high_word) > high_size: 64 | high_size = len(high_word) 65 | if len(words) > whole_size: 66 | whole_size = len(words) 67 | for i, (high_word, whole_word) in enumerate(zip(high_words, whole_words)): 68 | high_word = high_word + [high_word[-1]] * (high_size - len(high_word)) 69 | whole_word = whole_word + [0] * 
(whole_size - len(whole_word)) 70 | high_words[i] = high_word 71 | whole_words[i] = whole_word 72 | return np.array(high_words).astype(np.int64), np.array(whole_words).astype(np.int64) 73 | 74 | def forward(self, img, caption, str_caption=None, mode=None): 75 | opts = self._options 76 | if mode == 'Retrieve': 77 | img = torch.from_numpy(img).float() 78 | if torch.cuda.is_available(): 79 | img = img.cuda() 80 | self_high_vocabs, self_vocabs = self.solve_words(caption) 81 | self_high_vocabs = torch.from_numpy(self_high_vocabs).long() 82 | self_vocabs = torch.from_numpy(self_vocabs).long() 83 | if torch.cuda.is_available(): 84 | self_high_vocabs = self_high_vocabs.cuda() 85 | self_vocabs = self_vocabs.cuda() 86 | self.retriever.train() 87 | self.retriever(img, self_high_vocabs, self_vocabs) 88 | elif mode.startswith('Generate'): 89 | img = torch.from_numpy(img).float() 90 | caption = torch.from_numpy(caption).long() 91 | if torch.cuda.is_available(): 92 | img = img.cuda() 93 | caption = caption.cuda() 94 | self.retriever.eval() 95 | concept_score = self.retriever.full_forward(img) 96 | vocab = concept_score.topk(opts.select_word_size, -1)[1] 97 | self.generator.train() 98 | self.generator(img, concept_score, vocab, caption, str_caption, mode=mode) 99 | else: 100 | raise Exception('Current mode is not supported') 101 | 102 | def eval_retriever(self, img, caption, ret_word=False): 103 | self.retriever.eval() 104 | img = torch.from_numpy(img).float() 105 | if torch.cuda.is_available(): 106 | img = img.cuda() 107 | if not ret_word: 108 | self_high_vocabs, self_vocabs = self.solve_words(caption) 109 | stat = self.retriever.predict(img, self_high_vocabs, self_vocabs, all_ret=False) 110 | return stat 111 | else: 112 | sel_high_words, sel_words = self.retriever.predict(img) 113 | return sel_high_words, sel_words 114 | 115 | def eval_generator(self, img, caption=None): 116 | opts = self._options 117 | self.generator.eval() 118 | self.retriever.eval() 119 | img = torch.from_numpy(img).float() 120 | if torch.cuda.is_available(): 121 | img = img.cuda() 122 | concept_score = self.retriever.full_forward(img) 123 | vocab = concept_score.topk(opts.select_word_size, -1)[1] 124 | gen_captions = self.generator.generate(img, concept_score) 125 | sel_gen_captions = self.generator.generate(img, concept_score, vocab) 126 | return sel_gen_captions, gen_captions 127 | -------------------------------------------------------------------------------- /Code/Retriever.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as func 3 | import numpy as np 4 | 5 | 6 | # In this model we still use a hierarchical retrieval model. 7 | # The first level retrieves high-frequency words and also yields a score over all high-frequency words. 8 | # The second level performs a further retrieval on top of those high-frequency word scores. 9 | # The information passed from the first level into the second level is a score vector. 10 | # The overall model follows the Semantic Compositional Network. 11 | class Hierarchical(torch.nn.Module): 12 | def __init__(self, opts, high2vec=None, word2vec=None): 13 | super(Hierarchical, self).__init__() 14 | self._options = opts 15 | self.high2vec = high2vec 16 | self.word2vec = word2vec 17 | opts.high_num, opts.high_emb_size = self.high2vec.shape 18 | opts.word_num, opts.word_emb_size = self.word2vec.shape 19 | self.high_bn_l1 = torch.nn.BatchNorm1d(opts.img_size) 20 | self.high_fc_l1 = torch.nn.Linear(opts.img_size, opts.high_num) 21 | self.high_score_fc_l1 = torch.nn.Linear(opts.high_num, opts.hidden_size) 22 | self.img_high_fc_l1 = torch.nn.Linear(opts.img_size, opts.hidden_size) 23 | self.word_fc_l1 = torch.nn.Linear(opts.hidden_size, opts.hidden_size) 24 | self.word_bn_l2 = 
torch.nn.BatchNorm1d(opts.hidden_size+opts.img_size) 25 | self.word_fc_l2 = torch.nn.Linear(opts.hidden_size+opts.img_size, opts.word_num) 26 | self.dropout = torch.nn.Dropout(opts.dropout_rate) 27 | self.var_init() 28 | self.optimizer = torch.optim.Adam(params=self.parameters(), lr=opts.learning_rate) 29 | 30 | def var_init(self): 31 | opts = self._options 32 | torch.nn.init.xavier_normal_(self.high_fc_l1.weight) 33 | torch.nn.init.xavier_normal_(self.high_score_fc_l1.weight) 34 | torch.nn.init.xavier_normal_(self.img_high_fc_l1.weight) 35 | torch.nn.init.xavier_normal_(self.word_fc_l1.weight) 36 | torch.nn.init.xavier_normal_(self.word_fc_l2.weight) 37 | torch.nn.init.normal_(self.word_bn_l2.weight) 38 | torch.nn.init.constant_(self.word_bn_l2.bias, 0.0) 39 | 40 | def forward(self, img, high_words, words): 41 | opts = self._options 42 | self.train() 43 | batch_size = img.size(0) 44 | high_prob = self.high_fc_l1(self.high_bn_l1(img)).sigmoid() 45 | joint_vec = self.img_high_fc_l1(self.dropout(img)).mul(self.high_score_fc_l1(high_prob)) 46 | word_prob = self.word_fc_l2(self.dropout( 47 | self.word_bn_l2(torch.cat([func.relu(self.word_fc_l1(self.dropout(joint_vec))), img], -1)) 48 | )).sigmoid() 49 | high_word_labels = torch.zeros(batch_size, opts.high_num).scatter(1, high_words, 1).long() 50 | high_word_loss = self.solve_high_loss(high_prob, high_word_labels) 51 | word_labels = torch.zeros(batch_size, opts.word_num).scatter(1, words, 1).long() 52 | word_loss = self.solve_word_loss(word_prob, word_labels) 53 | reg_loss = self.solve_reg_loss() 54 | loss = opts.high_word_factor * high_word_loss + \ 55 | opts.word_factor * word_loss + reg_loss 56 | self.backward(loss) 57 | if opts.display: 58 | print('[Loss:{:.4f}][HighWordLoss:{:.4f}][WordLoss:{:.4f}][RegLoss:{:.4f}][SelHighSize:{}][SelSize:{}]'. 
59 | format(loss, high_word_loss, word_loss, reg_loss, opts.select_high_size, opts.select_word_size)) 60 | sel_high_words = high_prob.topk(opts.select_high_size)[1] 61 | sel_words = word_prob.topk(opts.select_word_size)[1] 62 | high_recall, high_precision = \ 63 | self.check_cover(high_words.data.cpu().numpy(), sel_high_words.data.cpu().numpy(), mode='high') 64 | recall, precision = \ 65 | self.check_cover(words.data.cpu().numpy(), sel_words.data.cpu().numpy(), mode='all') 66 | print('[HighStat Recall:{}]'.format(high_recall)) 67 | print('[HighStat Precision:{}]'.format(high_precision)) 68 | print('[Stat Recall:{}]'.format(recall)) 69 | print('[Stat Precision:{}]'.format(precision)) 70 | print('-'*100) 71 | 72 | def full_forward(self, img): 73 | high_prob = self.high_fc_l1(self.high_bn_l1(img)).sigmoid() 74 | joint_vec = self.img_high_fc_l1(self.dropout(img)).mul(self.high_score_fc_l1(high_prob)) 75 | word_prob = self.word_fc_l2(self.dropout( 76 | self.word_bn_l2(torch.cat([func.relu(self.word_fc_l1(self.dropout(joint_vec))), img], -1)) 77 | )).sigmoid() 78 | return word_prob 79 | 80 | def backward(self, loss): 81 | opts = self._options 82 | optimizer = self.optimizer 83 | optimizer.zero_grad() 84 | loss.backward() 85 | if opts.grad_value_clip: 86 | torch.nn.utils.clip_grad_value_(self.parameters(), opts.grad_clip_value) 87 | elif opts.grad_norm_clip: 88 | torch.nn.utils.clip_grad_norm_(self.parameters(), opts.grad_norm_clip_value) 89 | optimizer.step() 90 | 91 | def predict(self, img, high_words=None, words=None, all_ret=False): 92 | opts = self._options 93 | self.eval() 94 | high_prob = self.high_fc_l1(self.high_bn_l1(img)).sigmoid() 95 | joint_vec = self.img_high_fc_l1(img).mul(self.high_score_fc_l1(high_prob)) 96 | word_prob = self.word_fc_l2( 97 | self.word_bn_l2(torch.cat([func.relu(self.word_fc_l1(joint_vec)), img], -1)) 98 | ).sigmoid() 99 | sel_high_words = high_prob.topk(opts.select_high_size)[1] 100 | sel_words = word_prob.topk(opts.select_word_size)[1] 101 | if high_words is not None and words is not None: 102 | high_recall, high_precision = self.check_cover(high_words, sel_high_words.data.cpu().numpy(), mode='high') 103 | recall, precision = self.check_cover(words, sel_words.data.cpu().numpy(), mode='all') 104 | if all_ret: 105 | return sel_high_words, sel_words, high_recall, high_precision, recall, precision 106 | else: 107 | return high_recall, high_precision, recall, precision 108 | 109 | else: 110 | return sel_high_words, sel_words 111 | 112 | def check_cover(self, words, selects, mode='high'): 113 | opts = self._options 114 | if mode == 'high': 115 | indices = [2, 3, 4, 6, 8, 2*opts.select_high_size, opts.select_high_size] 116 | else: 117 | indices = [16, 32, 64, 128, 256, opts.select_word_size] 118 | recall = [[] for _ in range(len(indices))] 119 | precision = [[] for _ in range(len(indices))] 120 | for word, select in zip(words, selects): 121 | word = set(word) 122 | for i, idx in enumerate(indices): 123 | idx_select = set(select[:idx]) 124 | cover = word.intersection(idx_select) 125 | recall[i].append(len(cover)/float(len(word))) 126 | precision[i].append(len(cover)/float(len(idx_select))) 127 | recall = np.round(np.array(recall).mean(-1), 4) 128 | precision = np.round(np.array(precision).mean(-1), 4) 129 | return recall, precision 130 | 131 | def solve_high_loss(self, word_prob, label): 132 | opts = self._options 133 | pos_mask = label.float().ge(0.5).float() 134 | neg_mask = torch.ones_like(pos_mask).sub(pos_mask) 135 | every_loss = 
func.binary_cross_entropy(word_prob, label.float(), reduction='none') 136 | loss = every_loss.sum(-1).mean() 137 | if opts.display: 138 | print('[HighWord][PosSampleNum:{}][NegSampleNum:{}]'. 139 | format(pos_mask.sum().data.cpu().numpy(), neg_mask.sum().data.cpu().numpy())) 140 | mean_pos_prob = pos_mask.mul(word_prob).sum().div(pos_mask.sum()) 141 | mean_neg_prob = neg_mask.mul(word_prob).sum().div(neg_mask.sum()) 142 | print('[HighWord][MeanPosProb:{:.4f}][MeanNegProb:{:.4f}]'.format(mean_pos_prob, mean_neg_prob)) 143 | return loss 144 | 145 | def solve_word_loss(self, word_prob, label): 146 | opts = self._options 147 | pos_mask = label.float().ge(0.5).float() 148 | neg_mask = torch.ones_like(pos_mask).sub(pos_mask) 149 | pos_prob = word_prob.mul(pos_mask) 150 | neg_prob = word_prob.mul(neg_mask) 151 | neg_mask = neg_prob.ge(pos_prob.add(neg_mask).min(-1, keepdim=True)[0].sub(0.0)).float().mul(neg_mask) 152 | every_loss = func.binary_cross_entropy(word_prob, label.float(), reduction='none') 153 | loss = every_loss.sum(-1).mean() 154 | if opts.display: 155 | print('[Word][PosSampleNum:{}][NegSampleNum:{}]'. 156 | format(pos_mask.sum().data.cpu().numpy(), neg_mask.sum().data.cpu().numpy())) 157 | mean_pos_prob = pos_mask.mul(word_prob).sum().div(pos_mask.sum()) 158 | mean_neg_prob = neg_mask.mul(word_prob).sum().div(neg_mask.sum()) 159 | print('[Word][MeanPosProb:{:.4f}][MeanNegProb:{:.4f}]'.format(mean_pos_prob, mean_neg_prob)) 160 | return loss 161 | 162 | def solve_reg_loss(self, scope=None): 163 | opts, l1_loss_sum, l2_loss_sum = self._options, 0.0, 0.0 164 | if scope is None: 165 | named_parameters = self.named_parameters() 166 | else: 167 | named_parameters = scope.named_parameters() 168 | for name, param in named_parameters: 169 | if 'word_emb' not in name: 170 | l1_loss_sum += param.abs().sum() 171 | l2_loss_sum += param.pow(2).sum() 172 | reg_loss = opts.l1_factor * l1_loss_sum + opts.l2_factor * l2_loss_sum 173 | return reg_loss 174 | 175 | -------------------------------------------------------------------------------- /Code/Trainer.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import shutil 4 | import time 5 | import argparse 6 | import torch 7 | import numpy as np 8 | from collections import Counter 9 | import nltk 10 | import pickle 11 | sys.path.append('../..') 12 | from OrderClusterStream import DataStream 13 | from Pipeline import Pipeline as Model 14 | 15 | 16 | def main(): 17 | opts = read_commands() 18 | os.environ["CUDA_VISIBLE_DEVICES"] = opts.device 19 | trainer = Trainer(opts) 20 | if opts.is_training: 21 | trainer.process() 22 | else: 23 | trainer.test(mode='Retrieve') 24 | trainer.test(mode='Generate_MLE') 25 | trainer.test(mode='Generate_Reinforce') 26 | 27 | 28 | def read_commands(): 29 | data_root = os.path.abspath('../Data') 30 | train_root = os.path.join(data_root, 'train') 31 | model_root = os.path.join(data_root, 'model') 32 | parser = argparse.ArgumentParser(usage='MS COCO Data Train Parameters') 33 | parser.add_argument('--is_training', type=int, default=1) 34 | parser.add_argument('--model_name', type=str, default='RLSCNV3') 35 | parser.add_argument('--device', type=str, default='2') 36 | parser.add_argument('--data_name', type=str, default='coco') 37 | parser.add_argument('--data_id', type=int, default=100) 38 | parser.add_argument('--tag_id', type=int, default=100) 39 | parser.add_argument('--data_plain', type=int, default=0) 40 | parser.add_argument('--log_dir', type=str, 
default=os.path.join(data_root, 'log')) 41 | parser.add_argument('--save_dir', type=str, default=os.path.join(data_root, 'model')) 42 | parser.add_argument('--save_folder', type=str, default=None) 43 | parser.add_argument('--util_dir', type=str, default=os.path.join(data_root, 'util')) 44 | parser.add_argument('--util_folder', type=str, default=None) 45 | parser.add_argument('--pre_reinforce_path', type=str, default=None) 46 | parser.add_argument('--pre_reinforce_constraint_path', type=str, default=None) 47 | parser.add_argument('--pre_mle_path', type=str, default=None) 48 | parser.add_argument('--pre_ret_path', type=str, default=None) 49 | parser.add_argument('--word_num', type=int, default=None) 50 | parser.add_argument('--min_freq', type=int, default=5) 51 | parser.add_argument('--beam_size', type=int, default=3) 52 | parser.add_argument('--learning_rate', type=float, default=0.0005) 53 | parser.add_argument('--grad_value_clip', type=int, default=1) 54 | parser.add_argument('--grad_clip_value', type=float, default=0.2) 55 | parser.add_argument('--grad_norm_clip', type=int, default=0) 56 | parser.add_argument('--grad_norm_clip_value', type=float, default=2.0) 57 | parser.add_argument('--grad_global_norm_clip', type=int, default=0) 58 | parser.add_argument('--grad_global_norm_clip_value', type=float, default=5.0) 59 | parser.add_argument('--l1_reg', type=float, default=1e-7) 60 | parser.add_argument('--l2_reg', type=float, default=1e-7) 61 | parser.add_argument('--epochs', type=int, default=500000) 62 | parser.add_argument('--word_emb_size', type=int, default=300) 63 | parser.add_argument('--batch_size', type=int, default=64) 64 | parser.add_argument('--caption_size', type=int, default=17) 65 | parser.add_argument('--group_size', type=int, default=5) 66 | parser.add_argument('--img_size', type=int, default=2048) 67 | parser.add_argument('--rnn_size', type=int, default=512) 68 | parser.add_argument('--transpose_size', type=int, default=1024) 69 | parser.add_argument('--hidden_size', type=int, default=2048) 70 | parser.add_argument('--reinforce_size', type=int, default=1) 71 | parser.add_argument('--temperature', type=float, default=1.0) 72 | parser.add_argument('--dropout_rate', type=float, default=0.3) 73 | parser.add_argument('--concept_dropout_rate', type=float, default=0.1) 74 | parser.add_argument('--word_dropout_keep', type=float, default=1.0) 75 | parser.add_argument('--length_penalty_factor', type=float, default=0.6) 76 | parser.add_argument('--display_every', type=int, default=10) 77 | parser.add_argument('--save_every', type=int, default=50) 78 | args = parser.parse_args() 79 | args.tag = 'THV{}_V{}'.format(args.data_id, args.tag_id) 80 | args.is_training = True if args.is_training == 1 else False 81 | args.data_plain = True if args.data_plain == 1 else False 82 | args.train_data_path = os.path.join(train_root, '{}_v{}/{}_train_v{}.pkl'.format( 83 | args.data_name, args.data_id, args.data_name, args.data_id)) 84 | args.val_data_path = os.path.join(train_root, '{}_v{}/{}_val_v{}.pkl'.format( 85 | args.data_name, args.data_id, args.data_name, args.data_id)) 86 | args.test_data_path = os.path.join(train_root, '{}_v{}/{}_test_v{}.pkl'.format( 87 | args.data_name, args.data_id, args.data_name, args.data_id)) 88 | args.vocab_info_path = os.path.join(train_root, '{}_v{}/{}_info_v{}.pkl'.format( 89 | args.data_name, args.data_id, args.data_name, args.data_id)) 90 | args.grad_value_clip = True if args.grad_value_clip == 1 else False 91 | args.grad_norm_clip = True if args.grad_norm_clip 
== 1 else False 92 | args.grad_global_norm_clip = True if args.grad_global_norm_clip == 1 else False 93 | # args.data_plain was already normalized to a bool above, so the duplicate conversion is dropped 94 | return args 95 | 96 | 97 | class Trainer(object): 98 | def __init__(self, opts): 99 | self._options = opts 100 | self.model_name = opts.model_name + '_' + opts.tag 101 | self.log_file = os.path.join(opts.log_dir, self.model_name+'_{}.txt'.format( 102 | time.strftime('%Y%m%d_%H%M%S', time.localtime(time.time())))) 103 | self.save_folder = os.path.join(opts.save_dir, self.model_name) \ 104 | if opts.save_folder is None else opts.save_folder 105 | self.util_folder = os.path.join(opts.util_dir, self.model_name) \ 106 | if opts.util_folder is None else opts.util_folder 107 | if opts.is_training: 108 | if os.path.exists(self.log_file): 109 | del_cmd = bool(eval(input('[Warning][LogFile {} exists][Delete it?]'.format(self.log_file))))  # made consistent with the two prompts below; expects a Python literal such as 1 or 0 110 | if del_cmd: 111 | os.remove(self.log_file) 112 | if os.path.exists(self.save_folder): 113 | del_cmd = bool(eval(input('[Warning][SaveFile {} exists][Delete it?]'.format(self.save_folder)))) 114 | if del_cmd: 115 | shutil.rmtree(self.save_folder) 116 | os.makedirs(self.save_folder, exist_ok=True)  # previously mkdir only ran after deleting an existing folder, so a fresh run crashed at save time 117 | if os.path.exists(self.util_folder): 118 | del_cmd = bool(eval(input('[Warning][UtilFile {} exists][Delete it?]'.format(self.util_folder)))) 119 | if del_cmd: 120 | shutil.rmtree(self.util_folder) 121 | os.makedirs(self.util_folder, exist_ok=True) 122 | self.streamer = DataStream(opts) 123 | if torch.cuda.is_available(): 124 | torch.set_default_tensor_type('torch.cuda.FloatTensor') 125 | self.model = Model(opts, self.streamer) 126 | self.epoch = 0 127 | self.g_best_score = self.r_best_score = 0.0 128 | 129 | def process(self): 130 | opts = self._options 131 | self.epoch = 0 132 | opts.epochs = 100 133 | self.train(mode='Retrieve') 134 | self.epoch = 0 135 | opts.epochs = 100 136 | self.train(mode='Generate_MLE') 137 | self.epoch = 0 138 | opts.epochs = 100 139 | self.train(mode='Generate_Reinforce') 140 | self.epoch = 0 141 | opts.epochs = 100 142 | self.train(mode='Generate_Reinforce_Constraint') 143 | 144 | def train(self, mode=None): 145 | if mode not in ['Retrieve', 'Generate_MLE', 'Generate_Reinforce', 'Generate_Reinforce_Constraint']: 146 | raise Exception('Current train mode is not supported') 147 | opts, model, streamer = self._options, self.model, self.streamer 148 | self.adjust(mode=mode) 149 | for e in range(opts.epochs): 150 | self.epoch += 1 151 | self.adjust(mode=mode) 152 | image_idx, batch_caption, batch_feature, batch_num_caption = streamer.get_next_full_train_batch() 153 | model(batch_feature, batch_num_caption, batch_caption, mode=mode) 154 | if opts.save: 155 | self.val(mode=mode) 156 | 157 | def adjust(self, mode=None): 158 | if mode not in ['Retrieve', 'Generate_MLE', 'Generate_Reinforce', 'Generate_Reinforce_Constraint']: 159 | print('[Mode:{}]'.format(mode)) 160 | raise Exception('Current train mode is not supported') 161 | epoch, opts = self.epoch, self._options 162 | if mode == 'Retrieve': 163 | opts.learning_rate = 1e-4 164 | opts.high_word_factor = 0.1 165 | opts.word_factor = 1.0 166 | opts.l1_factor = 1e-5 167 | opts.l2_factor = 1e-5 168 | opts.batch_size = 64 169 | opts.select_high_size = 5 170 | opts.select_word_size = 64 171 | param_groups = self.model.retriever.optimizer.param_groups 172 | else: 173 | opts.schedule_sampling = False 174 | if mode == 'Generate_MLE': 175 | opts.l1_factor = 1e-6 176 | opts.l2_factor = 1e-6 177 | opts.learning_rate = 5e-4 * (0.8 ** (epoch // 20000)) 178 | opts.batch_size = 64 179 | 
param_groups = self.model.generator.optimizer.param_groups 180 | elif mode == 'Generate_Reinforce': 181 | opts.sample_decode_size = 1 182 | opts.l1_factor = 1e-7 183 | opts.l2_factor = 1e-7 184 | opts.temperature = 1.0 185 | opts.mle_factor = 0.0 186 | opts.reinforce_factor = 1.0 187 | opts.learning_rate = 5e-5 * (0.8 ** (epoch // 30000)) 188 | opts.batch_size = 64 189 | opts.select_high_size = 5 190 | opts.select_word_size = opts.word_num 191 | param_groups = self.model.generator.optimizer.param_groups 192 | else: 193 | opts.sample_decode_size = 1 194 | opts.l1_factor = 1e-7 195 | opts.l2_factor = 1e-7 196 | opts.temperature = 1.0 197 | opts.learning_rate = 5e-5 * (0.8 ** (epoch // 30000)) 198 | opts.batch_size = 64 199 | opts.select_high_size = 5 200 | opts.select_word_size = 64 201 | param_groups = self.model.generator.optimizer.param_groups 202 | opts.display = (epoch % opts.display_every) == 0 203 | opts.save = (epoch % opts.save_every) == 0 204 | opts.pe = getattr(opts, 'pe', -1) + 1 205 | for param_group in param_groups: 206 | param_group['lr'] = opts.learning_rate 207 | if opts.display: 208 | print('[Adjust][{}][Epoch:{}][LearningRate:{:.6f}][L1:{}][L2:{}][Dropout:{}][SelectWordsSize:{}]'. 209 | format(mode, epoch, opts.learning_rate, opts.l1_factor, opts.l2_factor, opts.dropout_rate, 210 | opts.select_word_size)) 211 | 212 | def val(self, mode=None): 213 | opts, model, streamer, epoch = self._options, self.model, self.streamer, self.epoch 214 | if mode not in ['Retrieve', 'Generate_MLE', 'Generate_Reinforce', 'Generate_Reinforce_Constraint']: 215 | raise Exception('Current train mode is not supported') 216 | elif mode != 'Retrieve': 217 | sel_hypos, hypos, refs = dict(), dict(), dict() 218 | val_idx = 0 219 | opts.batch_size = 32 220 | while True: 221 | val_idx += 1 222 | data = streamer.get_next_val_batch() 223 | if data is None: 224 | break 225 | image_idx, batch_caption, batch_feature, batch_num_caption = data 226 | sel_gen_tokens, gen_tokens = model.eval_generator(batch_feature) 227 | if not isinstance(gen_tokens, np.ndarray): 228 | gen_tokens = gen_tokens.data.cpu().numpy() 229 | if not isinstance(sel_gen_tokens, np.ndarray): 230 | sel_gen_tokens = sel_gen_tokens.data.cpu().numpy() 231 | for i, (idx, sel_gen_token, gen_token, captions) in \ 232 | enumerate(zip(image_idx, sel_gen_tokens, gen_tokens, list(batch_caption))): 233 | index = str(idx) 234 | refs[index] = list(captions) 235 | gen_caption = streamer.tokens2sentence(gen_token) 236 | hypos[index] = [gen_caption] 237 | sel_gen_caption = streamer.tokens2sentence(sel_gen_token) 238 | sel_hypos[index] = [sel_gen_caption] 239 | print('[Val][Epoch:{}]'.format(epoch)) 240 | base_scores = streamer.quick_measure_score(hypos, refs) 241 | print('[Val][Similarity Metrics]') 242 | for name, score in base_scores.items(): 243 | print('[Base][{}: {}]'.format(name, score)) 244 | sel_scores = streamer.quick_measure_score(sel_hypos, refs) 245 | for name, score in sel_scores.items(): 246 | print('[Sel][{}: {}]'.format(name, score)) 247 | val_score = sel_scores.get('METEOR') + sel_scores.get('CIDEr') 248 | print('[ValScore:{}]'.format(val_score)) 249 | if val_score > self.g_best_score and opts.is_training: 250 | if mode == 'Generate_Reinforce': 251 | self.g_best_score = val_score 252 | path = os.path.join(self.save_folder, self.model_name + '_BestRLGModel.pkl') 253 | print('[Val][Reinforce_Generate][NewBestScore: {}][SavePath: {}]'.format(val_score, path)) 254 | torch.save(self.model.state_dict(), path) 255 | path = 
os.path.join(self.save_folder, self.model_name + '_Single_BestRLGModel.pkl') 256 | torch.save(self.model.generator.state_dict(), path) 257 | elif mode == 'Generate_Reinforce_Constraint': 258 | self.g_best_score = val_score 259 | path = os.path.join(self.save_folder, self.model_name + '_BestRLGCModel.pkl') 260 | print('[Val][Reinforce_Generate][NewBestScore: {}][SavePath: {}]'.format(val_score, path)) 261 | torch.save(self.model.state_dict(), path) 262 | path = os.path.join(self.save_folder, self.model_name + '_Single_BestRLGCModel.pkl') 263 | torch.save(self.model.generator.state_dict(), path) 264 | else: 265 | self.g_best_score = val_score 266 | path = os.path.join(self.save_folder, self.model_name + '_BestGModel.pkl') 267 | print('[Val][Generate][NewBestScore: {}][SavePath: {}]'.format(val_score, path)) 268 | torch.save(self.model.state_dict(), path) 269 | path = os.path.join(self.save_folder, self.model_name + '_Single_BestGModel.pkl') 270 | torch.save(self.model.generator.state_dict(), path) 271 | else: 272 | val_idx = 0 273 | opts.batch_size = 64 274 | opts.display = False 275 | num_sum, recall_sum, precision_sum, high_recall_sum, high_precision_sum = 0, 0.0, 0.0, 0.0, 0.0 276 | while True: 277 | val_idx += 1 278 | data = streamer.get_next_val_batch() 279 | if data is None: 280 | break 281 | image_idx, batch_caption, batch_feature, batch_num_caption = data 282 | batch_size = len(image_idx) 283 | high_recall, high_precision, recall, precision = model.eval_retriever(batch_feature, batch_num_caption) 284 | recall_sum += recall * batch_size 285 | precision_sum += precision * batch_size 286 | high_recall_sum += high_recall * batch_size 287 | high_precision_sum += high_precision * batch_size 288 | num_sum += batch_size 289 | recall = np.round(recall_sum / float(num_sum), 4) 290 | precision = np.round(precision_sum / float(num_sum), 4) 291 | high_recall = np.round(high_recall_sum / float(num_sum), 4) 292 | high_precision = np.round(high_precision_sum / float(num_sum), 4) 293 | print('[Val][Epoch:{}]'.format(epoch)) 294 | print('[Recall]', recall) 295 | print('[Precision]', precision) 296 | print('[F1]', (2 * np.array(recall) * np.array(precision)) / (np.array(recall) + np.array(precision))) 297 | print('[HighRecall]', high_recall) 298 | print('[HighPrecision]', high_precision) 299 | print('[HighF1]', (2 * np.array(high_recall) * np.array(high_precision)) / 300 | (np.array(high_recall) + np.array(high_precision))) 301 | val_score = recall[-1] 302 | print('[ValScore:{}]'.format(val_score)) 303 | if val_score > self.r_best_score and opts.is_training: 304 | self.r_best_score = val_score 305 | path = os.path.join(self.save_folder, self.model_name+'_BestRModel.pkl') 306 | print('[Val][Retrieve][NewBestScore: {}][SavePath: {}]'.format(val_score, path)) 307 | torch.save(self.model.state_dict(), path) 308 | path = os.path.join(self.save_folder, self.model_name + '_Single_BestRModel.pkl') 309 | torch.save(self.model.retriever.state_dict(), path) 310 | 311 | def test(self, mode=None, from_file=True): 312 | opts, model, streamer = self._options, self.model, self.streamer 313 | if mode not in ['Retrieve', 'Generate_MLE', 'Generate_Reinforce', 'Generate_Reinforce_Constraint']: 314 | raise Exception('Current train mode is not supported') 315 | elif mode != 'Retrieve': 316 | self.adjust(mode=mode) 317 | if from_file: 318 | if mode == 'Generate_MLE': 319 | if opts.pre_mle_path is None: 320 | path = os.path.join(self.save_folder, self.model_name + '_BestGModel.pkl') 321 | else: 322 | path = opts.pre_mle_path 
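# note: the '_Single_*' checkpoints saved in val() hold only the generator's state dict, while the unprefixed ones hold the whole Pipeline; the try/except RuntimeError blocks below therefore attempt a generator-only load first and fall back to loading the full model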
323 | print('[Test][MLE][LoadPath: {}]'.format(path)) 324 | try: 325 | model.generator.load_state_dict(torch.load(path)) 326 | except RuntimeError: 327 | model.load_state_dict(torch.load(path)) 328 | elif mode == 'Generate_Reinforce': 329 | if opts.pre_reinforce_path is None: 330 | path = os.path.join(self.save_folder, self.model_name + '_BestRLGModel.pkl') 331 | else: 332 | path = opts.pre_reinforce_path 333 | print('[Test][Reinforce][LoadPath: {}]'.format(path)) 334 | try: 335 | model.generator.load_state_dict(torch.load(path)) 336 | except RuntimeError: 337 | model.load_state_dict(torch.load(path)) 338 | else: 339 | if opts.pre_reinforce_constraint_path is None: 340 | path = os.path.join(self.save_folder, self.model_name + '_BestRLGCModel.pkl') 341 | else: 342 | path = opts.pre_reinforce_constraint_path 343 | print('[Test][Reinforce][LoadPath: {}]'.format(path)) 344 | try: 345 | model.generator.load_state_dict(torch.load(path)) 346 | except RuntimeError: 347 | model.load_state_dict(torch.load(path)) 348 | 349 | sel_hypos, hypos, refs = dict(), dict(), dict() 350 | val_idx = 0 351 | opts.batch_size = 32 352 | while True: 353 | val_idx += 1 354 | if val_idx > 1000: 355 | break 356 | data = streamer.get_next_test_batch() 357 | if data is None: 358 | break 359 | image_idx, batch_caption, batch_feature, batch_num_caption = data 360 | sel_gen_tokens, gen_tokens = model.eval_generator(batch_feature) 361 | if not isinstance(gen_tokens, np.ndarray): 362 | gen_tokens = gen_tokens.data.cpu().numpy() 363 | if not isinstance(sel_gen_tokens, np.ndarray): 364 | sel_gen_tokens = sel_gen_tokens.data.cpu().numpy() 365 | for i, (idx, sel_gen_token, gen_token, captions) in \ 366 | enumerate(zip(image_idx, sel_gen_tokens, gen_tokens, list(batch_caption))): 367 | index = str(idx) 368 | refs[index] = list(captions) 369 | gen_caption = streamer.tokens2sentence(gen_token) 370 | hypos[index] = [gen_caption] 371 | sel_gen_caption = streamer.tokens2sentence(sel_gen_token) 372 | sel_hypos[index] = [sel_gen_caption] 373 | print('[Test]') 374 | sel_scores = streamer.measure_score(sel_hypos, refs) 375 | for name, score in sel_scores.items(): 376 | print('[Sel][{}: {}]'.format(name, score)) 377 | base_scores = streamer.measure_score(hypos, refs) 378 | for name, score in base_scores.items(): 379 | print('[Base][{}: {}]'.format(name, score)) 380 | else: 381 | self.adjust(mode=mode) 382 | if from_file: 383 | if opts.pre_ret_path is None: 384 | path = os.path.join(self.save_folder, self.model_name + '_BestRModel.pkl') 385 | else: 386 | path = opts.pre_ret_path 387 | print('[Test][LoadPath: {}]'.format(path)) 388 | try: 389 | model.retriever.load_state_dict(torch.load(path)) 390 | except RuntimeError: 391 | model.load_state_dict(torch.load(path)) 392 | val_idx = 0 393 | opts.batch_size = 64 394 | opts.display = False 395 | num_sum, recall_sum, precision_sum, high_recall_sum, high_precision_sum = 0, 0.0, 0.0, 0.0, 0.0 396 | while True: 397 | val_idx += 1 398 | data = streamer.get_next_test_batch() 399 | if data is None: 400 | break 401 | image_idx, batch_caption, batch_feature, batch_num_caption = data 402 | batch_size = len(image_idx) 403 | high_recall, high_precision, recall, precision = model.eval_retriever(batch_feature, batch_num_caption) 404 | recall_sum += recall * batch_size 405 | precision_sum += precision * batch_size 406 | high_recall_sum += high_recall * batch_size 407 | high_precision_sum += high_precision * batch_size 408 | num_sum += batch_size 409 | recall = np.round(recall_sum / float(num_sum), 4) 410 | 
precision = np.round(precision_sum / float(num_sum), 4) 411 | high_recall = np.round(high_recall_sum / float(num_sum), 4) 412 | high_precision = np.round(high_precision_sum / float(num_sum), 4) 413 | print('[Test]') 414 | print('[Recall]', recall) 415 | print('[Precision]', precision) 416 | print('[F1]', (2 * np.array(recall) * np.array(precision)) / (np.array(recall) + np.array(precision))) 417 | print('[HighRecall]', high_recall) 418 | print('[HighPrecision]', high_precision) 419 | print('[HighF1]', (2 * np.array(high_recall) * np.array(high_precision)) / 420 | (np.array(high_recall) + np.array(high_precision))) 421 | 422 | 423 | if __name__ == '__main__': 424 | main() 425 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Image Caption 2 | 3 | This is an implementation of *Bridging by Word: Image-Grounded Vocabulary Construction for Visual Captioning* based on PyTorch 1.0. 4 | 5 | ## Setup 6 | 0. Install Python 3.6 and PyTorch 1.0. 7 | 1. Download the MS COCO image data to *Data/raw/img*. 8 | 9 | 2. Data processing. Create a folder named *coco_v?* under *Data/train* to store the processed data. 10 | ```bash 11 | cd Code 12 | python BuildImgOrderClusterVocab.py 13 | ``` 14 | 15 | 3. Model training. 16 | ```bash 17 | cd Code 18 | python Trainer.py 19 | ``` 20 | 
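21 | 4. Model testing. Running `Trainer.py` with `--is_training 0` loads the best checkpoints saved during training and reports test-split scores for the retriever and the generator variants (see `main()` in *Code/Trainer.py*). 22 | ```bash 23 | cd Code 24 | python Trainer.py --is_training 0 25 | ``` 26 | --------------------------------------------------------------------------------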