├── .DS_Store ├── .gitattributes ├── .gitignore ├── .idea ├── Topic-Guided Attention-For-Image-Captioning.iml ├── misc.xml ├── modules.xml ├── vcs.xml └── workspace.xml ├── README.md ├── attribute ├── attribute_model.py └── prepro_get_attribute.py ├── beamsearch.py ├── core ├── __init__.py ├── __init__.pyc ├── __pycache__ │ └── __init__.cpython-36.pyc ├── beam_search.py ├── beam_search.pyc ├── bleu.py ├── bleu.pyc ├── model.py ├── model.pyc ├── model_0.py ├── model_2.py ├── solver.py ├── solver.pyc ├── solver_0.py ├── solver_2.py ├── utils.py ├── utils.pyc ├── utils_0.py ├── utils_2.py ├── vggnet.py └── vggnet.pyc ├── h5test.py ├── prepro.py ├── prepro_f8.py ├── resize.py ├── test.py ├── test_model.py ├── topic ├── lda_topic.py ├── split_h5.py └── topic_model.py ├── train.py └── zzzz ├── 0 ├── 0.png ├── 1.png ├── 2.png ├── 3.png ├── 4.png ├── 5.png ├── 6.png ├── 7.png ├── 8.png ├── caption.txt └── original.jpg ├── 1 ├── 0.png ├── 1.png ├── 2.png ├── 3.png ├── 4.png ├── 5.png ├── 6.png ├── 7.png ├── 8.png ├── caption.txt └── original.jpg ├── 2 ├── 0.png ├── 1.png ├── 2.png ├── 3.png ├── 4.png ├── 5.png ├── 6.png ├── 7.png ├── 8.png ├── 9.png ├── caption.txt └── original.jpg ├── 3 ├── 0.png ├── 1.png ├── 10.png ├── 11.png ├── 12.png ├── 13.png ├── 2.png ├── 3.png ├── 4.png ├── 5.png ├── 6.png ├── 7.png ├── 8.png ├── 9.png ├── caption.txt └── original.jpg ├── 4 ├── 0.png ├── 1.png ├── 10.png ├── 11.png ├── 2.png ├── 3.png ├── 4.png ├── 5.png ├── 6.png ├── 7.png ├── 8.png ├── 9.png ├── caption.txt └── original.jpg ├── 5 ├── 0.png ├── 1.png ├── 10.png ├── 11.png ├── 2.png ├── 3.png ├── 4.png ├── 5.png ├── 6.png ├── 7.png ├── 8.png ├── 9.png ├── caption.txt └── original.jpg ├── 6 ├── 0.png ├── 1.png ├── 2.png ├── 3.png ├── 4.png ├── 5.png ├── 6.png ├── 7.png ├── 8.png ├── 9.png ├── caption.txt └── original.jpg ├── 7 ├── 0.png ├── 1.png ├── 2.png ├── 3.png ├── 4.png ├── 5.png ├── 6.png ├── 7.png ├── 8.png ├── 9.png ├── caption.txt └── original.jpg ├── 9 ├── 0.png ├── 1.png ├── 2.png ├── 3.png ├── 4.png ├── 5.png ├── 6.png ├── 7.png ├── 8.png ├── caption.txt └── original.jpg ├── 12 ├── 0.png ├── 1.png ├── 2.png ├── 3.png ├── 4.png ├── 5.png ├── 6.png ├── 7.png ├── 8.png ├── 9.png ├── caption.txt └── original.jpg ├── 13 ├── 0.png ├── 1.png ├── 10.png ├── 11.png ├── 2.png ├── 3.png ├── 4.png ├── 5.png ├── 6.png ├── 7.png ├── 8.png ├── 9.png ├── caption.txt └── original.jpg └── .DS_Store /.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ZhihaoZhu/Topic-Guided--Attention-For-Image-Captioning/c56cedc8cca73fd0598c60de73a344bed32d6c78/.DS_Store -------------------------------------------------------------------------------- /.gitattributes: -------------------------------------------------------------------------------- 1 | # Auto detect text files and perform LF normalization 2 | * text=auto 3 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | /zzzz/ 2 | zzzz/ 3 | /zzzz 4 | -------------------------------------------------------------------------------- /.idea/Topic-Guided Attention-For-Image-Captioning.iml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 11 | -------------------------------------------------------------------------------- /.idea/misc.xml: 
-------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | -------------------------------------------------------------------------------- /.idea/modules.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | -------------------------------------------------------------------------------- /.idea/vcs.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | -------------------------------------------------------------------------------- /.idea/workspace.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | 27 | 28 | 29 | 30 | 31 | 32 | 33 | 34 | 35 | 36 | 37 | 38 | 39 | 40 | 42 | 43 | 48 | 49 | 50 | 55 | 56 | 57 | 58 | 59 | 60 | 61 | 62 | 63 | 64 | 65 | 66 | 67 | 68 | 69 | 70 | 71 | 72 | 73 | 74 | 75 | 78 | 79 | 82 | 83 | 84 | 85 | 88 | 89 | 92 | 93 | 96 | 97 | 98 | 99 | 102 | 103 | 106 | 107 | 110 | 111 | 112 | 113 | 116 | 117 | 120 | 121 | 124 | 125 | 126 | 127 | 130 | 131 | 134 | 135 | 138 | 139 | 140 | 141 | 142 | 143 | 144 | 145 | 146 | 147 | 148 | 158 | 159 | 160 | 162 | 163 | 164 | 165 | 1556420839035 166 | 170 | 171 | 172 | 173 | 174 | 175 | 176 | 177 | 178 | 179 | 180 | 181 | 182 | 183 | 184 | 185 | 186 | 187 | 188 | 189 | 190 | 191 | 192 | 193 | 194 | 195 | 196 | 197 | 199 | 200 | 201 | 202 | 203 | 204 | 205 | 206 | 207 | 208 | 209 | 210 | 211 | 212 | 213 | 214 | 215 | 216 | 217 | 218 | 219 | 220 | 221 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Topic-Guided-Attention-For-Image-Captioning 2 | -------------------------------------------------------------------------------- /attribute/attribute_model.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python2 2 | # -*- coding: utf-8 -*- 3 | """ 4 | Created on Sat Feb 3 16:36:21 2018 5 | 6 | @author: xz 7 | """ 8 | #!/usr/bin/env python2 9 | # -*- coding: utf-8 -*- 10 | """ 11 | Created on Thu Jan 11 10:31:21 2018 12 | 13 | @author: xz 14 | """ 15 | 16 | 17 | # Imports 18 | import numpy as np 19 | import tensorflow as tf 20 | import pickle 21 | import os 22 | import hickle 23 | import h5py 24 | 25 | def sample_coco_minibatch(topic_data, feature, batch_size): 26 | data_size = feature.shape[0] 27 | mask = np.random.choice(data_size, batch_size) 28 | features = feature[mask] 29 | file_names = topic_data[mask] 30 | return features, file_names 31 | 32 | 33 | def inference(input): 34 | flat = tf.reshape(input, [-1, 14 * 14 * 512]) 35 | 36 | logits = tf.layers.dense(inputs=flat, units=80) 37 | logits = tf.nn.dropout(logits, 0.5) 38 | #logits = tf.layers.dense(inputs=flat, units=80,activation=tf.nn.sigmoid) 39 | logits=tf.nn.softmax(logits) 40 | return logits 41 | 42 | def train(): 43 | 44 | #TODO 45 | image_topic = [] 46 | topic_path = './val.topics.h5' 47 | with h5py.File(topic_path, 'r') as f: 48 | image_topic = np.asarray(f['topics']) 49 | print ('image_topic ok!') 50 | # TODO 51 | features = [] 52 | feature_path = '../data/coco_data/val/val.h5' 53 | with h5py.File(feature_path, 'r') as f: 54 | features = np.asarray(f['features']) 55 | #features = hickle.load(feature_path) 56 | print ('features ok!') 57 | 58 | # features = np.random.rand(5000, 196, 512) 59 | # image_topic = 
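A minimal sketch of what the inference() defined above computes, using random stand-in data (the data and session code here are purely illustrative, not part of the original script): each 196 x 512 conv5_3 feature map is flattened and mapped by a single dense layer, dropout and softmax to an 80-way topic distribution.

import numpy as np
import tensorflow as tf

x = tf.placeholder(tf.float32, [None, 196, 512], name='x-input')
y = inference(x)  # (None, 80) topic probabilities after the softmax

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    dummy = np.random.rand(2, 196, 512).astype(np.float32)
    probs = sess.run(y, feed_dict={x: dummy})
    print(probs.shape)        # (2, 80)
    print(probs.sum(axis=1))  # each row sums to ~1.0 because of the final softmax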
np.random.rand(5000, 80) 60 | 61 | 62 | log_path = './log/' 63 | model_path = './model/' 64 | 65 | n_examples = len(features) 66 | print(n_examples) 67 | batch_size = 100 68 | n_epoch = 20 69 | save_every = 1 70 | 71 | x = tf.placeholder(tf.float32, [None, 196, 512], name='x-input') 72 | _y = tf.placeholder(tf.float32, [None, 80], name='y-input') 73 | y = inference(x) 74 | 75 | loss = tf.reduce_sum(tf.nn.sigmoid_cross_entropy_with_logits(labels=_y, logits=y)) / batch_size 76 | optimizer = tf.train.AdamOptimizer(learning_rate=0.000001) 77 | # grads = tf.gradients(loss, tf.trainable_variables()) 78 | # grads_and_vars = list(zip(grads, tf.trainable_variables())) 79 | train_op = optimizer.minimize(loss=loss) 80 | n_iters_per_epoch = int(np.ceil(float(n_examples) / batch_size)) 81 | 82 | tf.summary.scalar('loss', loss) 83 | for var in tf.trainable_variables(): 84 | tf.summary.histogram(var.op.name, var) 85 | # for grad, var in grads_and_vars: 86 | # tf.summary.histogram(var.op.name+'/gradient', grad) 87 | summary_op = tf.summary.merge_all() 88 | 89 | saver = tf.train.Saver() 90 | 91 | with tf.Session() as sess: 92 | tf.global_variables_initializer().run() 93 | summary_writer = tf.summary.FileWriter(log_path, graph=tf.get_default_graph()) 94 | 95 | for e in range(n_epoch): 96 | rand_idxs = np.random.permutation(n_examples) 97 | 98 | for i in range(n_iters_per_epoch): 99 | xs = features[rand_idxs[i * batch_size:(i + 1) * batch_size]] 100 | ys = image_topic[rand_idxs[i * batch_size:(i + 1) * batch_size]] 101 | feed_dict={x: xs, _y: ys} 102 | _, l = sess.run([train_op, loss], feed_dict) 103 | 104 | if i % 40 == 0: 105 | summary = sess.run(summary_op, feed_dict) 106 | summary_writer.add_summary(summary, e * n_iters_per_epoch + i) 107 | #print ("Processed %d features.." % (e * n_iters_per_epoch + i*batch_size)) 108 | 109 | if (e + 1) % save_every == 0: 110 | saver.save(sess, model_path+'model.ckpt', global_step=e + 1) 111 | print("model-%s saved." 
% (e + 1)) 112 | 113 | def test(): 114 | x = tf.placeholder(tf.float32, [None, 196,512], name='x-input') 115 | # _y = tf.placeholder(tf.float32, [None, 80], name='y-input') 116 | y = inference(x) 117 | #y = tf.sigmoid(y) 118 | #ys = tf.nn.softmax(y) 119 | 120 | features = [] 121 | feature_path = '../data/coco_data/val/val.h5' 122 | with h5py.File(feature_path, 'r') as f: 123 | features = np.asarray(f['features']) 124 | 125 | image_topic = [] 126 | topic_path = './val.topics.h5' 127 | with h5py.File(topic_path, 'r') as f: 128 | image_topic = np.asarray(f['topics']) 129 | 130 | 131 | logs_train_dir='./model/' 132 | saver = tf.train.Saver() 133 | with tf.Session() as sess: 134 | print("Reading checkpoints...") 135 | ckpt = tf.train.get_checkpoint_state(logs_train_dir) 136 | if ckpt and ckpt.model_checkpoint_path: 137 | global_step = ckpt.model_checkpoint_path.split('/')[-1].split('-')[-1] 138 | saver.restore(sess, ckpt.model_checkpoint_path) 139 | print('Loading success, global_step is %s' % global_step) 140 | else: 141 | print('No checkpoint file found') 142 | 143 | feed_dict = {x: features} 144 | y = sess.run(y,feed_dict) 145 | #print(sess.run(op_to_restore, feed_dict)) 146 | print(y[10]) 147 | print(image_topic[10]) 148 | 149 | 150 | def main(): 151 | test() 152 | 153 | if __name__ == "__main__": 154 | main() 155 | 156 | -------------------------------------------------------------------------------- /attribute/prepro_get_attribute.py: -------------------------------------------------------------------------------- 1 | from scipy import ndimage 2 | from collections import Counter 3 | from core.vggnet import Vgg19 4 | from core.utils import * 5 | 6 | import tensorflow as tf 7 | import numpy as np 8 | import pandas as pd 9 | import hickle 10 | import os 11 | import json 12 | import h5py 13 | 14 | 15 | def _build_vocab1000(annotations, word_to_idx, threshold=1): 16 | counter = Counter() 17 | max_len = 0 18 | for i, caption in enumerate(annotations['caption']): 19 | words = caption.split(' ') # caption contrains only lower-case words 20 | for w in words: 21 | counter[w] += 1 22 | 23 | if len(caption.split(" ")) > max_len: 24 | max_len = len(caption.split(" ")) 25 | 26 | vocab = [word for word in counter if counter[word] >= threshold] 27 | print('Filtered %d words to %d words with word count threshold %d.' 
% (len(counter), len(vocab), threshold)) 28 | 29 | vocab1000 = counter.most_common(150) 30 | vocab1000_idx = np.ndarray(100).astype(np.int32) 31 | 32 | for i in range(100): 33 | vocab1000_idx[i] = word_to_idx[vocab1000[i+50][0]] 34 | return vocab1000_idx 35 | 36 | def _get_attribute(annotations, word_to_idx, split, vocab1000): 37 | image_path = list(annotations['file_name'].unique()) 38 | n_examples = len(image_path) 39 | length = n_examples 40 | image_ids = {} 41 | 42 | 43 | 44 | image_attribute = np.zeros((length, 100)).astype(np.int32) 45 | i = -1 46 | 47 | for caption, image_id in zip(annotations['caption'], annotations['image_id']): 48 | if not image_id in image_ids: 49 | image_ids[image_id] = 0 50 | i += 1 51 | 52 | words = caption.split(" ") 53 | 54 | for word in words: 55 | if word in word_to_idx: 56 | if word_to_idx[word] in vocab1000: 57 | pos = np.argwhere(vocab1000 == word_to_idx[word]) 58 | image_attribute[i][pos] = 1 59 | return image_attribute 60 | 61 | 62 | 63 | 64 | def main(): 65 | max_length = 15 66 | word_count_threshold = 1 67 | 68 | for split in ['train','test','val']: 69 | annotations = load_pickle('../data/f8_data/%s/%s.annotations.pkl' % (split, split)) 70 | 71 | word_to_idx = load_pickle('../data/f8_data/%s/word_to_idx.pkl' % ('train')) 72 | 73 | if split == 'train': 74 | word_to_idx = load_pickle('../data/f8_data/%s/word_to_idx.pkl' % ('train')) 75 | vocab1000 = _build_vocab1000(annotations=annotations, word_to_idx=word_to_idx, threshold=word_count_threshold) 76 | 77 | save_pickle(vocab1000, '../data/f8_data/%s/%s.pkl' % ('train', 'vocab1000')) 78 | 79 | attributes = _get_attribute(annotations, word_to_idx, split, vocab1000) 80 | 81 | save_path = h5py.File('../data/f8_data/%s/%s.attributes.h5' %(split, split),'w') 82 | save_path.create_dataset("attributes", data=attributes) 83 | 84 | 85 | 86 | if __name__ == '__main__': 87 | 88 | main() -------------------------------------------------------------------------------- /beamsearch.py: -------------------------------------------------------------------------------- 1 | #import os 2 | #import shutil 3 | # 4 | #oldname="./image/COCO_val2014_000000081264.jpg" 5 | # 6 | #newname= "./COCO_val2014_000000081264.jpg" 7 | # 8 | #shutil.copyfile(oldname,newname) 9 | import shutil 10 | import os 11 | 12 | txtName = "./zzzz/caption.txt" 13 | f=file(txtName, "a+") 14 | f.write('1') 15 | f.close() 16 | 17 | 18 | 19 | #import numpy as np 20 | #import tensorflow as tf 21 | # 22 | #xz1=[[0.9],[0.7]] 23 | # 24 | #xz=[0.9,0.7] 25 | #A = [0.8,0.6,0.3] 26 | #B = [0.9,0.5,0.7] 27 | #C = [A,B] 28 | #C = tf.convert_to_tensor(C) 29 | #k=2 30 | #a = {} 31 | #with tf.Session() as sess: 32 | # 33 | # 34 | # out,xx= tf.nn.top_k(C[0], 1) 35 | # print out 36 | # sess.run(tf.initialize_all_variables()) 37 | # print sess.run(out) 38 | # print sess.run(xx) 39 | #with tf.Session() as sess: 40 | # for i in range(k): 41 | # 42 | # out,xx= tf.nn.top_k(C[i], k) 43 | # print out 44 | # sess.run(tf.initialize_all_variables()) 45 | # print sess.run(out[0]) 46 | # 47 | # for j in range(k): 48 | # if j==0: 49 | # a[xz[i]]=[list([out[j],xx[j]])] 50 | # else: 51 | # a[xz[i]].append(list([out[j],xx[j]])) 52 | # 53 | # aaa = sess.run(a) 54 | ## print (aaa) 55 | ## for key,value in aaa.items(): 56 | ## print ('key is %s,value is %s'%(key,value)) 57 | # 58 | ## print (aaa[xz[0]][0][0]) 59 | ## print (aaa[xz[1]]) 60 | # 61 | # sort = [] 62 | # sor = {} 63 | # for x in range(k): 64 | # for y in range(k): 65 | # sort.append(xz[x]*(aaa[xz[x]][y][0])) 66 | # 
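The commented-out experiment above is sketching the bookkeeping of one beam-search step: extend each kept prefix with its k most probable next words, score every extension by the product of probabilities, and keep the best k overall. The same idea in plain NumPy, reusing the toy numbers from the comments (purely illustrative):

import numpy as np

k = 2
prefix_scores = [0.9, 0.7]                  # scores of the current k prefixes (xz above)
next_probs = np.array([[0.8, 0.6, 0.3],     # next-word probabilities for prefix 0 (A above)
                       [0.9, 0.5, 0.7]])    # next-word probabilities for prefix 1 (B above)

candidates = []
for i, p in enumerate(prefix_scores):
    top_idx = np.argsort(next_probs[i])[::-1][:k]         # k best next words for this prefix
    for j in top_idx:
        candidates.append((p * next_probs[i][j], i, j))   # (combined score, prefix, word)

best = sorted(candidates, reverse=True)[:k]
print(best)   # [(0.72, 0, 0), (0.63, 1, 0)] up to floating-point rounding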
#sor[aaa[xz[x]][y][1]]=xz[x]*(aaa[xz[x]][y][0]) 67 | # sor[xz[x]*(aaa[xz[x]][y][0])]=[xz[x],aaa[xz[x]][y][1]] 68 | # for key,value in sor.items(): 69 | # print key 70 | # xzz = sorted(sor.items(), key=lambda item:item[0],reverse=True)# sorted return a list[] 71 | # print xzz 72 | # for qq in range(len(xz1)): 73 | # for kk in range(k): 74 | # if xz1[qq][0] == xzz[kk][1][0]: 75 | # xz1[qq].append(xzz[kk][1][1]) 76 | # print xz1 77 | -------------------------------------------------------------------------------- /core/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ZhihaoZhu/Topic-Guided--Attention-For-Image-Captioning/c56cedc8cca73fd0598c60de73a344bed32d6c78/core/__init__.py -------------------------------------------------------------------------------- /core/__init__.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ZhihaoZhu/Topic-Guided--Attention-For-Image-Captioning/c56cedc8cca73fd0598c60de73a344bed32d6c78/core/__init__.pyc -------------------------------------------------------------------------------- /core/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ZhihaoZhu/Topic-Guided--Attention-For-Image-Captioning/c56cedc8cca73fd0598c60de73a344bed32d6c78/core/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /core/beam_search.py: -------------------------------------------------------------------------------- 1 | #coding: utf-8 2 | #demo of beam search for seq2seq model 3 | import numpy as np 4 | import random 5 | vocab = { 6 | 0: 'a', 7 | 1: 'b', 8 | 2: 'c', 9 | 3: 'd', 10 | 4: 'e', 11 | 5: 'BOS', 12 | 6: 'EOS' 13 | } 14 | reverse_vocab = dict([(v,k) for k,v in vocab.items()]) 15 | vocab_size = len(vocab.items()) 16 | def softmax(x): 17 | """Compute softmax values for each sets of scores in x.""" 18 | e_x = np.exp(x - np.max(x)) 19 | return e_x / e_x.sum() 20 | def reduce_mul(l): 21 | out = 1.0 22 | for x in l: 23 | out *= x 24 | return out 25 | def check_all_done(seqs): 26 | for seq in seqs: 27 | if not seq[-1]: 28 | return False 29 | return True 30 | 31 | def decode_step(encoder_context, input_seq): 32 | #encoder_context contains infortaion of encoder 33 | #ouput_step contains the words' probability 34 | #these two varibles should be generated by seq2seq model 35 | words_prob = [random.random() for _ in range(vocab_size)] 36 | #downvote BOS 37 | words_prob[reverse_vocab['BOS']] = 0.0 38 | words_prob = softmax(words_prob) 39 | ouput_step = [(idx,prob) for idx,prob in enumerate(words_prob)] 40 | ouput_step = sorted(ouput_step, key=lambda x: x[1], reverse=True) 41 | return ouput_step 42 | #seq: [[word,word],[word,word],[word,word]] 43 | #output: [[word,word,word],[word,word,word],[word,word,word]] 44 | def beam_search_step(encoder_context, top_seqs, k): 45 | all_seqs = [] 46 | for seq in top_seqs: 47 | seq_score = reduce_mul([_score for _,_score in seq]) 48 | if seq[-1][0] == reverse_vocab['EOS']: 49 | all_seqs.append((seq, seq_score, True)) 50 | continue 51 | #get current step using encoder_context & seq 52 | current_step = decode_step(encoder_context, seq) 53 | for i,word in enumerate(current_step): 54 | if i >= k: 55 | break 56 | word_index = word[0] 57 | word_score = word[1] 58 | score = seq_score * word_score 59 | rs_seq = seq + [word] 60 | done = (word_index == 
reverse_vocab['EOS']) 61 | all_seqs.append((rs_seq, score, done)) 62 | all_seqs = sorted(all_seqs, key = lambda seq: seq[1], reverse=True) 63 | topk_seqs = [seq for seq,_,_ in all_seqs[:k]] 64 | all_done = check_all_done(topk_seqs) 65 | return topk_seqs, all_done 66 | def beam_search(encoder_context): 67 | beam_size = 3 68 | max_len = 10 69 | #START 70 | top_seqs = [[(reverse_vocab['BOS'],1.0)]] 71 | #loop 72 | for _ in range(max_len): 73 | top_seqs, all_done = beam_search_step(encoder_context, top_seqs, beam_size) 74 | if all_done: 75 | break 76 | return top_seqs 77 | if __name__ == '__main__': 78 | #encoder_context is not inportant in this demo 79 | encoder_context = None 80 | top_seqs = beam_search(encoder_context) 81 | for i,seq in enumerate(top_seqs): 82 | print 'Path[%d]: ' % i 83 | for word in seq[1:]: 84 | word_index = word[0] 85 | word_prob = word[1] 86 | print '%s(%.4f)' % (vocab[word_index], word_prob), 87 | if word_index == reverse_vocab['EOS']: 88 | break 89 | print '\n' -------------------------------------------------------------------------------- /core/beam_search.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ZhihaoZhu/Topic-Guided--Attention-For-Image-Captioning/c56cedc8cca73fd0598c60de73a344bed32d6c78/core/beam_search.pyc -------------------------------------------------------------------------------- /core/bleu.py: -------------------------------------------------------------------------------- 1 | import cPickle as pickle 2 | import os 3 | import sys 4 | sys.path.append('../cococaption') 5 | from cococaption.pycocoevalcap.bleu.bleu import Bleu 6 | from cococaption.pycocoevalcap.rouge.rouge import Rouge 7 | from cococaption.pycocoevalcap.cider.cider import Cider 8 | from cococaption.pycocoevalcap.meteor.meteor import Meteor 9 | 10 | def score(ref, hypo): 11 | scorers = [ 12 | (Bleu(4),["Bleu_1","Bleu_2","Bleu_3","Bleu_4"]), 13 | (Meteor(),"METEOR"), 14 | (Rouge(),"ROUGE_L"), 15 | (Cider(),"CIDEr") 16 | ] 17 | final_scores = {} 18 | for scorer,method in scorers: 19 | score,scores = scorer.compute_score(ref,hypo) 20 | if type(score)==list: 21 | for m,s in zip(method,score): 22 | final_scores[m] = s 23 | else: 24 | final_scores[method] = score 25 | 26 | return final_scores 27 | 28 | 29 | def evaluate(data_path='./data', split='val', get_scores=False): 30 | reference_path = os.path.join(data_path, "%s/%s.references.pkl" %(split, split)) 31 | candidate_path = os.path.join(data_path, "%s/%s.candidate.captions.pkl" %(split, split)) 32 | 33 | # load caption data 34 | with open(reference_path, 'rb') as f: 35 | ref = pickle.load(f) 36 | with open(candidate_path, 'rb') as f: 37 | cand = pickle.load(f) 38 | 39 | # make dictionary 40 | hypo = {} 41 | for i, caption in enumerate(cand): 42 | hypo[i] = [caption] 43 | 44 | # compute bleu score 45 | final_scores = score(ref, hypo) 46 | 47 | # print out scores 48 | print 'Bleu_1:\t',final_scores['Bleu_1'] 49 | print 'Bleu_2:\t',final_scores['Bleu_2'] 50 | print 'Bleu_3:\t',final_scores['Bleu_3'] 51 | print 'Bleu_4:\t',final_scores['Bleu_4'] 52 | print 'METEOR:\t',final_scores['METEOR'] 53 | print 'ROUGE_L:',final_scores['ROUGE_L'] 54 | print 'CIDEr:\t',final_scores['CIDEr'] 55 | 56 | if get_scores: 57 | return final_scores 58 | 59 | 60 | 61 | 62 | 63 | 64 | 65 | 66 | 67 | 68 | 69 | 70 | 71 | 72 | -------------------------------------------------------------------------------- /core/bleu.pyc: 
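For reference, a small sketch of the dictionary format that score() above expects, mirroring how evaluate() builds it: both arguments map an image id to a list of caption strings, with exactly one candidate per image. It assumes the cococaption dependencies imported above are installed; the caption strings are made up for illustration.

ref = {0: ['a man riding a horse', 'a person on a horse'],
       1: ['two dogs playing in the grass']}
hypo = {0: ['a man riding a horse'],
        1: ['a dog running on grass']}

final_scores = score(ref, hypo)
print(final_scores['Bleu_1'], final_scores['CIDEr'])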
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/ZhihaoZhu/Topic-Guided--Attention-For-Image-Captioning/c56cedc8cca73fd0598c60de73a344bed32d6c78/core/bleu.pyc -------------------------------------------------------------------------------- /core/model.py: -------------------------------------------------------------------------------- 1 | # ========================================================================================= 2 | # Implementation of "Show, Attend and Tell: Neural Caption Generator With Visual Attention". 3 | # There are some notations. 4 | # N is batch size. 5 | # L is spacial size of feature vector (196). 6 | # D is dimension of image feature vector (512). 7 | # T is the number of time step which is equal to caption's length-1 (16). 8 | # V is vocabulary size (about 10000). 9 | # M is dimension of word vector which is embedding size (default is 512). 10 | # H is dimension of hidden state (default is 1024). 11 | # ========================================================================================= 12 | 13 | from __future__ import division 14 | 15 | import tensorflow as tf 16 | 17 | 18 | class CaptionGenerator(object): 19 | def __init__(self, word_to_idx, dim_feature=[196, 512], att_len=1000, len_T=80, dim_embed=512, dim_hidden=1024, 20 | n_time_step=16, 21 | prev2out=True, ctx2out=True, alpha_c=0.0, selector=True, dropout=True): 22 | """ 23 | Args: 24 | word_to_idx: word-to-index mapping dictionary. 25 | dim_feature: (optional) Dimension of vggnet19 conv5_3 feature vectors. 26 | dim_embed: (optional) Dimension of word embedding. 27 | dim_hidden: (optional) Dimension of all hidden state. 28 | n_time_step: (optional) Time step size of LSTM. 29 | prev2out: (optional) previously generated word to hidden state. (see Eq (7) for explanation) 30 | ctx2out: (optional) context to hidden state (see Eq (7) for explanation) 31 | alpha_c: (optional) Doubly stochastic regularization coefficient. (see Section (4.2.1) for explanation) 32 | selector: (optional) gating scalar for context vector. (see Section (4.2.1) for explanation) 33 | dropout: (optional) If true then dropout layer is added. 
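With the notation above, the topic-guided attention implemented in _attention_layer below scores each of the L spatial locations from the projected features, the LSTM hidden state and the topic vector, then forms a context vector as the attention-weighted sum of the features. A NumPy sketch of the shapes involved (random values, purely illustrative):

import numpy as np

N, L, D, H, T_len = 2, 196, 512, 1024, 80
features = np.random.rand(N, L, D)          # conv5_3 feature maps
features_proj = np.random.rand(N, L, D)     # stand-in for the projected features
h = np.random.rand(N, H)                    # LSTM hidden state
topic = np.random.rand(N, T_len)            # topic distribution for each image
w = np.random.rand(H, D)
w_topic = np.random.rand(T_len, D)
w_att = np.random.rand(D, 1)
b = np.zeros(D)

h_att = np.maximum(0.0, features_proj + h.dot(w)[:, None, :]
                   + topic.dot(w_topic)[:, None, :] + b)          # (N, L, D), ReLU
out_att = h_att.reshape(-1, D).dot(w_att).reshape(N, L)           # one score per location
alpha = np.exp(out_att - out_att.max(axis=1, keepdims=True))
alpha /= alpha.sum(axis=1, keepdims=True)                         # softmax over the L locations
context = (features * alpha[:, :, None]).sum(axis=1)              # (N, D) attended context
print(alpha.shape, context.shape)                                 # (2, 196) (2, 512)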
34 | """ 35 | 36 | self.word_to_idx = word_to_idx 37 | self.idx_to_word = {i: w for w, i in word_to_idx.iteritems()} 38 | self.prev2out = prev2out 39 | self.ctx2out = ctx2out 40 | self.alpha_c = alpha_c 41 | self.selector = selector 42 | self.dropout = dropout 43 | self.len_T = len_T 44 | self.A = att_len 45 | self.V = len(word_to_idx) 46 | self.L = dim_feature[0] 47 | self.D = dim_feature[1] 48 | self.M = dim_embed 49 | self.H = dim_hidden 50 | self.T = n_time_step 51 | self._start = word_to_idx[''] 52 | self._null = word_to_idx[''] 53 | 54 | self.weight_initializer = tf.contrib.layers.xavier_initializer() 55 | self.const_initializer = tf.constant_initializer(0.0) 56 | self.emb_initializer = tf.random_uniform_initializer(minval=-1.0, maxval=1.0) 57 | 58 | # Place holder for features and captions 59 | self.features = tf.placeholder(tf.float32, [None, self.L, self.D]) 60 | self.topics = tf.placeholder(tf.float32, [None, self.len_T]) 61 | self.attributes = tf.placeholder(tf.float32, [None, self.A]) 62 | self.captions = tf.placeholder(tf.int32, [None, self.T + 1]) 63 | 64 | def _get_initial_lstm(self, features): 65 | with tf.variable_scope('initial_lstm'): 66 | features_mean = tf.reduce_mean(features, 1) 67 | 68 | w_h = tf.get_variable('w_h', [self.D, self.H], initializer=self.weight_initializer) 69 | b_h = tf.get_variable('b_h', [self.H], initializer=self.const_initializer) 70 | h = tf.nn.tanh(tf.matmul(features_mean, w_h) + b_h) 71 | 72 | w_c = tf.get_variable('w_c', [self.D, self.H], initializer=self.weight_initializer) 73 | b_c = tf.get_variable('b_c', [self.H], initializer=self.const_initializer) 74 | c = tf.nn.tanh(tf.matmul(features_mean, w_c) + b_c) 75 | return c, h 76 | 77 | 78 | def _word_embedding(self, inputs, reuse=None): 79 | with tf.variable_scope('word_embedding', reuse=reuse): 80 | w = tf.get_variable('w', [self.V, self.M], initializer=self.emb_initializer) 81 | x = tf.nn.embedding_lookup(w, inputs, name='word_vector') # (N, T, M) or (N, M) 82 | return x 83 | 84 | 85 | def _project_features(self, features): 86 | with tf.variable_scope('project_features'): 87 | w = tf.get_variable('w', [self.D, self.D], initializer=self.weight_initializer) 88 | features_flat = tf.reshape(features, [-1, self.D]) 89 | features_proj = tf.matmul(features_flat, w) 90 | features_proj = tf.reshape(features_proj, [-1, self.L, self.D]) 91 | return features_proj 92 | 93 | 94 | def _project_attributes(self, attributes): 95 | with tf.variable_scope('project_attributes'): 96 | w = tf.get_variable('w', [self.A, self.A], initializer=self.weight_initializer) 97 | attributes_flat = tf.reshape(attributes, [-1, self.A]) 98 | attributes_proj = tf.matmul(attributes_flat, w) 99 | attributes_proj = tf.reshape(attributes_proj, [-1, self.A]) 100 | return attributes_proj 101 | 102 | 103 | # original_model 104 | 105 | def _attention_layer(self, features, features_proj, topic, h, reuse=False): 106 | with tf.variable_scope('attention_layer', reuse=reuse): 107 | w = tf.get_variable('w', [self.H, self.D], initializer=self.weight_initializer) 108 | b = tf.get_variable('b', [self.D], initializer=self.const_initializer) 109 | w_topic = tf.get_variable('w_topic', [self.len_T, self.D], initializer=self.weight_initializer) 110 | 111 | w_att = tf.get_variable('w_att', [self.D, 1], initializer=self.weight_initializer) 112 | 113 | h_att = tf.nn.relu( 114 | features_proj + tf.expand_dims(tf.matmul(h, w), 1) + tf.expand_dims(tf.matmul(topic, w_topic), 115 | 1) + b) 116 | out_att = tf.reshape(tf.matmul(tf.reshape(h_att, [-1, self.D]), 
w_att), [-1, self.L]) # (N, L) 117 | alpha = tf.nn.softmax(out_att) 118 | context = tf.reduce_sum(features * tf.expand_dims(alpha, 2), 1, name='context') # (N, D) 119 | return context, alpha 120 | 121 | def _attributes_attention_layer(self, attributes, attributes_proj, topic, h, reuse=False): 122 | with tf.variable_scope('attribute_attention_layer', reuse=reuse): 123 | w = tf.get_variable('w', [self.H, self.A], initializer=self.weight_initializer) 124 | b = tf.get_variable('b', [self.A], initializer=self.const_initializer) 125 | w_topic = tf.get_variable('w_topic', [self.len_T, self.A], initializer=self.weight_initializer) 126 | w_att = tf.get_variable('w_att', [self.A, 1], initializer=self.weight_initializer) 127 | h_att = tf.nn.relu( 128 | attributes_proj + tf.matmul(h, w) + tf.matmul(topic, w_topic) + b) 129 | out_att = tf.matmul(h_att, w_att) 130 | alpha = tf.nn.softmax(out_att) 131 | Attributes_context = attributes * alpha 132 | return Attributes_context 133 | 134 | def f_attention_layer(self, f_decoded, h, reuse=False): 135 | with tf.variable_scope('f_attention_layer', reuse=reuse): 136 | w = tf.get_variable('w', [self.H, self.M], initializer=self.weight_initializer) 137 | b = tf.get_variable('b', [self.M], initializer=self.const_initializer) 138 | 139 | w_att = tf.get_variable('w_att', [self.M, 1], initializer=self.weight_initializer) 140 | 141 | h_att = tf.nn.relu( 142 | f_decoded + tf.expand_dims(tf.matmul(h, w), 1) + b) 143 | out_att = tf.reshape(tf.matmul(tf.reshape(h_att, [-1, self.M]), w_att), [-1, self.T]) # (N, L) 144 | alpha = tf.nn.softmax(out_att) 145 | f_context = tf.reduce_sum(f_decoded * tf.expand_dims(alpha, 2), 1, name='f_context') # (N, D) 146 | return f_context 147 | 148 | 149 | # base_model 150 | # def _attention_layer(self, features, features_proj, topic, h, reuse=False): 151 | # with tf.variable_scope('attention_layer', reuse=reuse): 152 | # w = tf.get_variable('w', [self.H, self.D], initializer=self.weight_initializer) 153 | # b = tf.get_variable('b', [self.D], initializer=self.const_initializer) 154 | # #w_topic = tf.get_variable('w_topic', [self.len_T, self.D], initializer=self.weight_initializer) 155 | # 156 | # w_att = tf.get_variable('w_att', [self.D, 1], initializer=self.weight_initializer) 157 | # 158 | # h_att = tf.nn.relu( 159 | # features_proj + tf.expand_dims(tf.matmul(h, w), 1) + b) 160 | # out_att = tf.reshape(tf.matmul(tf.reshape(h_att, [-1, self.D]), w_att), [-1, self.L]) # (N, L) 161 | # alpha = tf.nn.softmax(out_att) 162 | # context = tf.reduce_sum(features * tf.expand_dims(alpha, 2), 1, name='context') # (N, D) 163 | # return context, alpha 164 | # 165 | # def _attributes_attention_layer(self, attributes, attributes_proj, topic, h, reuse=False): 166 | # with tf.variable_scope('attribute_attention_layer', reuse=reuse): 167 | # w = tf.get_variable('w', [self.H, self.A], initializer=self.weight_initializer) 168 | # b = tf.get_variable('b', [self.A], initializer=self.const_initializer) 169 | # w_topic = tf.get_variable('w_topic', [self.len_T, self.A], initializer=self.weight_initializer) 170 | # 171 | # w_att = tf.get_variable('w_att', [self.A, 1], initializer=self.weight_initializer) 172 | # 173 | # h_att = tf.nn.relu( 174 | # attributes_proj + tf.matmul(h, w) + tf.matmul(topic, w_topic) + b) 175 | # out_att = tf.matmul(h_att, w_att) 176 | # alpha = tf.nn.softmax(out_att) 177 | # Attributes_context = attributes * alpha 178 | # return Attributes_context 179 | 180 | def _selector(self, context, h, reuse=False): 181 | with 
tf.variable_scope('selector', reuse=reuse): 182 | w = tf.get_variable('w', [self.H, 1], initializer=self.weight_initializer) 183 | b = tf.get_variable('b', [1], initializer=self.const_initializer) 184 | beta = tf.nn.sigmoid(tf.matmul(h, w) + b, 'beta') # (N, 1) 185 | context = tf.multiply(beta, context, name='selected_context') 186 | return context, beta 187 | 188 | 189 | def _decode_lstm(self, x, h, context, dropout=False, reuse=False): 190 | with tf.variable_scope('logits', reuse=reuse): 191 | w_h = tf.get_variable('w_h', [self.H, self.M], initializer=self.weight_initializer) 192 | b_h = tf.get_variable('b_h', [self.M], initializer=self.const_initializer) 193 | w_out = tf.get_variable('w_out', [self.M, self.V], initializer=self.weight_initializer) 194 | b_out = tf.get_variable('b_out', [self.V], initializer=self.const_initializer) 195 | 196 | if dropout: 197 | h = tf.nn.dropout(h, 0.5) 198 | h_logits = tf.matmul(h, w_h) + b_h 199 | 200 | if self.ctx2out: 201 | w_ctx2out = tf.get_variable('w_ctx2out', [self.D, self.M], initializer=self.weight_initializer) 202 | h_logits += tf.matmul(context, w_ctx2out) 203 | 204 | if self.prev2out: 205 | h_logits += x 206 | h_logits = tf.nn.tanh(h_logits) 207 | 208 | if dropout: 209 | h_logits = tf.nn.dropout(h_logits, 0.5) 210 | out_logits = tf.matmul(h_logits, w_out) + b_out 211 | return out_logits 212 | 213 | 214 | def _batch_norm(self, x, mode='train', name=None): 215 | return tf.contrib.layers.batch_norm(inputs=x, 216 | decay=0.95, 217 | center=True, 218 | scale=True, 219 | is_training=(mode == 'train'), 220 | updates_collections=None, 221 | scope=(name + 'batch_norm')) 222 | 223 | 224 | def build_model(self): 225 | features = self.features 226 | attributes = self.attributes 227 | topic = self.topics 228 | captions = self.captions 229 | batch_size = tf.shape(features)[0] 230 | 231 | captions_in = captions[:, :self.T] 232 | captions_out = captions[:, 1:] 233 | mask = tf.to_float(tf.not_equal(captions_out, self._null)) 234 | 235 | # batch normalize feature vectors 236 | features = self._batch_norm(features, mode='train', name='conv_features') 237 | 238 | c, h = self._get_initial_lstm(features=features) 239 | x = self._word_embedding(inputs=captions_in) 240 | features_proj = self._project_features(features=features) 241 | attributes_proj = self._project_attributes(attributes=attributes) 242 | 243 | loss = 0.0 244 | alpha_list = [] 245 | lstm_cell = tf.nn.rnn_cell.BasicLSTMCell(num_units=self.H) 246 | 247 | for t in range(self.T): 248 | context, alpha = self._attention_layer(features, features_proj, topic, h, reuse=(t != 0)) 249 | attributes_context = self._attributes_attention_layer(attributes, attributes_proj, topic, h, reuse=(t != 0)) 250 | 251 | f_context = self.f_attention_layer(x, h, reuse=(t != 0)) 252 | alpha_list.append(alpha) 253 | 254 | if self.selector: 255 | context, beta = self._selector(context, h, reuse=(t != 0)) 256 | 257 | with tf.variable_scope('lstm', reuse=(t != 0)): 258 | _, (c, h) = lstm_cell(inputs=tf.concat([x[:, t, :], context, attributes_context, f_context], 1), state=[c, h]) 259 | print("yes") 260 | logits = self._decode_lstm(x[:, t, :], h, context, dropout=self.dropout, reuse=(t != 0)) 261 | loss += tf.reduce_sum( 262 | tf.nn.sparse_softmax_cross_entropy_with_logits(logits=logits, labels=captions_out[:, t]) * mask[:, t]) 263 | 264 | if self.alpha_c > 0: 265 | alphas = tf.transpose(tf.stack(alpha_list), (1, 0, 2)) # (N, T, L) 266 | alphas_all = tf.reduce_sum(alphas, 1) # (N, L) 267 | alpha_reg = self.alpha_c * 
tf.reduce_sum((16. / 196 - alphas_all) ** 2) 268 | loss += alpha_reg 269 | 270 | return loss / tf.to_float(batch_size) 271 | 272 | 273 | def build_sampler(self, max_len=20): 274 | features = self.features 275 | attributes = self.attributes 276 | topic = self.topics 277 | captions = self.captions 278 | captions_in = captions[:, :self.T] 279 | 280 | 281 | # batch normalize feature vectors 282 | features = self._batch_norm(features, mode='test', name='conv_features') 283 | 284 | c, h = self._get_initial_lstm(features=features) 285 | features_proj = self._project_features(features=features) 286 | attributes_proj = self._project_attributes(attributes=attributes) 287 | 288 | sampled_word_list = [] 289 | alpha_list = [] 290 | beta_list = [] 291 | lstm_cell = tf.nn.rnn_cell.BasicLSTMCell(num_units=self.H) 292 | 293 | for t in range(max_len): 294 | if t == 0: 295 | x = self._word_embedding(inputs=tf.fill([tf.shape(features)[0]], self._start)) 296 | else: 297 | x = self._word_embedding(inputs=sampled_word, reuse=True) 298 | 299 | context, alpha = self._attention_layer(features, features_proj, topic, h, reuse=(t != 0)) 300 | attributes_context = self._attributes_attention_layer(attributes, attributes_proj, topic, h, reuse=(t != 0)) 301 | 302 | f_context = self.f_attention_layer(self._word_embedding(inputs=captions_in, reuse = True), h, reuse=(t != 0)) 303 | 304 | alpha_list.append(alpha) 305 | 306 | if self.selector: 307 | context, beta = self._selector(context, h, reuse=(t != 0)) 308 | beta_list.append(beta) 309 | 310 | with tf.variable_scope('lstm', reuse=(t != 0)): 311 | _, (c, h) = lstm_cell(inputs=tf.concat([x, context, attributes_context, f_context], 1), state=[c, h]) 312 | 313 | logits = self._decode_lstm(x, h, context, reuse=(t != 0)) 314 | 315 | sampled_word = tf.argmax(logits, 1) 316 | sampled_word_list.append(sampled_word) 317 | 318 | alphas = tf.transpose(tf.stack(alpha_list), (1, 0, 2)) # (N, T, L) 319 | betas = tf.transpose(tf.squeeze(beta_list), (1, 0)) # (N, T) 320 | sampled_captions = tf.transpose(tf.stack(sampled_word_list), (1, 0)) # (N, max_len) 321 | return alphas, betas, sampled_captions 322 | 323 | -------------------------------------------------------------------------------- /core/model.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ZhihaoZhu/Topic-Guided--Attention-For-Image-Captioning/c56cedc8cca73fd0598c60de73a344bed32d6c78/core/model.pyc -------------------------------------------------------------------------------- /core/model_0.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import matplotlib.pyplot as plt 3 | import skimage.transform 4 | import numpy as np 5 | import time 6 | import os 7 | import cPickle as pickle 8 | from scipy import ndimage 9 | from utils import * 10 | from bleu import evaluate 11 | 12 | 13 | class CaptioningSolver(object): 14 | def __init__(self, model, data, val_data, **kwargs): 15 | """ 16 | Required Arguments: 17 | - model: Show Attend and Tell caption generating model 18 | - data: Training data; dictionary with the following keys: 19 | - features: Feature vectors of shape (82783, 196, 512) 20 | - file_names: Image file names of shape (82783, ) 21 | - captions: Captions of shape (400000, 17) 22 | - image_idxs: Indices for mapping caption to image of shape (400000, ) 23 | - word_to_idx: Mapping dictionary from word to index 24 | - val_data: validation data; for print out BLEU scores for each epoch. 
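Besides the keys listed above, train() below also reads data['attributes'] and data['topics']. A sketch of the expected structure with made-up sizes (the real arrays come from the prepro scripts; the special-token names in word_to_idx are assumed, and the attribute width must match the model's att_len):

import numpy as np

n_images, n_captions = 2000, 10000
data = {
    'features':   np.random.rand(n_images, 196, 512),             # conv5_3 feature maps
    'attributes': np.random.rand(n_images, 1000),                 # per-image attribute scores
    'topics':     np.random.rand(n_images, 80),                   # per-image topic distribution
    'captions':   np.random.randint(0, 10000, (n_captions, 17)),  # encoded captions, T+1 wide
    'image_idxs': np.random.randint(0, n_images, n_captions),     # caption -> image index
    'word_to_idx': {'<NULL>': 0, '<START>': 1, '<END>': 2},       # token names assumed
}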
25 | Optional Arguments: 26 | - n_epochs: The number of epochs to run for training. 27 | - batch_size: Mini batch size. 28 | - update_rule: A string giving the name of an update rule 29 | - learning_rate: Learning rate; default value is 0.01. 30 | - print_every: Integer; training losses will be printed every print_every iterations. 31 | - save_every: Integer; model variables will be saved every save_every epoch. 32 | - pretrained_model: String; pretrained model path 33 | - model_path: String; model path for saving 34 | - test_model: String; model path for test 35 | """ 36 | 37 | self.model = model 38 | self.data = data 39 | self.val_data = val_data 40 | self.n_epochs = kwargs.pop('n_epochs', 10) 41 | self.batch_size = kwargs.pop('batch_size', 100) 42 | self.update_rule = kwargs.pop('update_rule', 'adam') 43 | self.learning_rate = kwargs.pop('learning_rate', 0.01) 44 | self.print_bleu = kwargs.pop('print_bleu', False) 45 | self.print_every = kwargs.pop('print_every', 100) 46 | self.save_every = kwargs.pop('save_every', 1) 47 | self.log_path = kwargs.pop('log_path', './log/') 48 | self.model_path = kwargs.pop('model_path', './model/') 49 | self.pretrained_model = kwargs.pop('pretrained_model', None) 50 | self.test_model = kwargs.pop('test_model', './model/lstm/model-1') 51 | 52 | # set an optimizer by update rule 53 | if self.update_rule == 'adam': 54 | self.optimizer = tf.train.AdamOptimizer 55 | elif self.update_rule == 'momentum': 56 | self.optimizer = tf.train.MomentumOptimizer 57 | elif self.update_rule == 'rmsprop': 58 | self.optimizer = tf.train.RMSPropOptimizer 59 | 60 | if not os.path.exists(self.model_path): 61 | os.makedirs(self.model_path) 62 | if not os.path.exists(self.log_path): 63 | os.makedirs(self.log_path) 64 | 65 | 66 | def train(self): 67 | # train/val dataset 68 | n_examples = self.data['captions'].shape[0] 69 | n_iters_per_epoch = int(np.ceil(float(n_examples)/self.batch_size)) 70 | features = self.data['features'] 71 | attributes = self.data['attributes'] 72 | topics = self.data['topics'] 73 | captions = self.data['captions'] 74 | image_idxs = self.data['image_idxs'] 75 | val_features = self.val_data['features'] 76 | val_attributes = self.val_data['attributes'] 77 | val_topics = self.val_data['topics'] 78 | 79 | n_iters_val = int(np.ceil(float(val_features.shape[0])/self.batch_size)) 80 | 81 | # build graphs for training model and sampling captions 82 | loss = self.model.build_model() 83 | # tf.get_variable_scope().reuse_variables() 84 | # _, _, generated_captions = self.model.build_sampler(max_len=20) 85 | # 86 | # # train op 87 | # with tf.name_scope('optimizer'): 88 | # optimizer = self.optimizer(learning_rate=self.learning_rate) 89 | # grads = tf.gradients(loss, tf.trainable_variables()) 90 | # grads_and_vars = list(zip(grads, tf.trainable_variables())) 91 | # train_op = optimizer.apply_gradients(grads_and_vars=grads_and_vars) 92 | with tf.variable_scope(tf.get_variable_scope()) as scope: 93 | with tf.name_scope('optimizer'): 94 | tf.get_variable_scope().reuse_variables() 95 | _, _, generated_captions = self.model.build_sampler(max_len=20) 96 | optimizer = self.optimizer(learning_rate=self.learning_rate) 97 | grads = tf.gradients(loss, tf.trainable_variables()) 98 | grads_and_vars = list(zip(grads, tf.trainable_variables())) 99 | train_op = optimizer.apply_gradients(grads_and_vars=grads_and_vars) 100 | 101 | # summary op 102 | tf.summary.scalar('batch_loss', loss) 103 | for var in tf.trainable_variables(): 104 | tf.summary.histogram(var.op.name, var) 105 | for 
grad, var in grads_and_vars: 106 | tf.summary.histogram(var.op.name+'/gradient', grad) 107 | 108 | summary_op = tf.summary.merge_all() 109 | 110 | print "The number of epoch: %d" %self.n_epochs 111 | print "Data size: %d" %n_examples 112 | print "Batch size: %d" %self.batch_size 113 | print "Iterations per epoch: %d" %n_iters_per_epoch 114 | 115 | config = tf.ConfigProto(allow_soft_placement = True) 116 | #config.gpu_options.per_process_gpu_memory_fraction=0.9 117 | config.gpu_options.allow_growth = True 118 | with tf.Session(config=config) as sess: 119 | tf.initialize_all_variables().run() 120 | summary_writer = tf.summary.FileWriter(self.log_path, graph=tf.get_default_graph()) 121 | saver = tf.train.Saver(max_to_keep=40) 122 | 123 | if self.pretrained_model is not None: 124 | print "Start training with pretrained Model.." 125 | saver.restore(sess, self.pretrained_model) 126 | 127 | prev_loss = -1 128 | curr_loss = 0 129 | start_t = time.time() 130 | 131 | for e in range(self.n_epochs): 132 | rand_idxs = np.random.permutation(n_examples) 133 | captions = captions[rand_idxs] 134 | image_idxs = image_idxs[rand_idxs] 135 | 136 | for i in range(n_iters_per_epoch): 137 | captions_batch = captions[i*self.batch_size:(i+1)*self.batch_size] 138 | image_idxs_batch = image_idxs[i*self.batch_size:(i+1)*self.batch_size] 139 | features_batch = features[image_idxs_batch] 140 | attributes_batch = attributes[image_idxs_batch] 141 | topics_batch = topics[image_idxs_batch] 142 | feed_dict = {self.model.features: features_batch, self.model.attributes: attributes_batch, self.model.topics: topics_batch, self.model.captions: captions_batch} 143 | 144 | _, l = sess.run([train_op, loss], feed_dict) 145 | curr_loss += l 146 | 147 | # write summary for tensorboard visualization 148 | if i % 10 == 0: 149 | summary = sess.run(summary_op, feed_dict) 150 | summary_writer.add_summary(summary, e*n_iters_per_epoch + i) 151 | 152 | if (i+1) % self.print_every == 0: 153 | print "\nTrain loss at epoch %d & iteration %d (mini-batch): %.5f" %(e+1, i+1, l) 154 | ground_truths = captions[image_idxs == image_idxs_batch[0]] 155 | decoded = decode_captions(ground_truths, self.model.idx_to_word) 156 | for j, gt in enumerate(decoded): 157 | print "Ground truth %d: %s" %(j+1, gt) 158 | gen_caps = sess.run(generated_captions, feed_dict) 159 | decoded = decode_captions(gen_caps, self.model.idx_to_word) 160 | print "Generated caption: %s\n" %decoded[0] 161 | 162 | print "Previous epoch loss: ", prev_loss 163 | print "Current epoch loss: ", curr_loss 164 | print "Elapsed time: ", time.time() - start_t 165 | prev_loss = curr_loss 166 | curr_loss = 0 167 | 168 | # print out BLEU scores and file write 169 | # if self.print_bleu: 170 | # all_gen_cap = np.ndarray((val_features.shape[0], 20)) 171 | # 172 | # for i in range(n_iters_val): 173 | # features_batch = val_features[i*self.batch_size:(i+1)*self.batch_size] 174 | # attributes_batch = val_attributes[i * self.batch_size:(i + 1) * self.batch_size] 175 | # topics_batch = val_topics[i * self.batch_size:(i + 1) * self.batch_size] 176 | # feed_dict = {self.model.features: features_batch, self.model.attributes: attributes_batch, self.model.topics: topics_batch} 177 | # gen_cap = sess.run(generated_captions, feed_dict=feed_dict) 178 | # all_gen_cap[i*self.batch_size:(i+1)*self.batch_size] = gen_cap 179 | # 180 | # all_decoded = decode_captions(all_gen_cap, self.model.idx_to_word) 181 | # save_pickle(all_decoded, "./data/val/val.candidate.captions.pkl") 182 | # scores = 
evaluate(data_path='./data', split='val', get_scores=True) 183 | # write_bleu(scores=scores, path=self.model_path, epoch=e) 184 | # 185 | # save model's parameters 186 | if (e+1) % self.save_every == 0: 187 | saver.save(sess, os.path.join(self.model_path, 'model'), global_step=e+1) 188 | print "model-%s saved." %(e+1) 189 | 190 | 191 | def test(self, data, split='train', attention_visualization=True, save_sampled_captions=True): 192 | ''' 193 | Args: 194 | - data: dictionary with the following keys: 195 | - features: Feature vectors of shape (5000, 196, 512) 196 | - file_names: Image file names of shape (5000, ) 197 | - captions: Captions of shape (24210, 17) 198 | - image_idxs: Indices for mapping caption to image of shape (24210, ) 199 | - features_to_captions: Mapping feature to captions (5000, 4~5) 200 | - split: 'train', 'val' or 'test' 201 | - attention_visualization: If True, visualize attention weights with images for each sampled word. (ipthon notebook) 202 | - save_sampled_captions: If True, save sampled captions to pkl file for computing BLEU scores. 203 | ''' 204 | 205 | features = data['features'] 206 | attributes = data['attributes'] 207 | topics = data['topics'] 208 | 209 | # build a graph to sample captions 210 | alphas, betas, sampled_captions = self.model.build_sampler(max_len=20) # (N, max_len, L), (N, max_len) 211 | 212 | config = tf.ConfigProto(allow_soft_placement=True) 213 | config.gpu_options.allow_growth = True 214 | with tf.Session(config=config) as sess: 215 | 216 | saver = tf.train.Saver() 217 | 218 | print("Reading checkpoints...")# 219 | # ckpt = tf.train.get_checkpoint_state(self.model_path)# 220 | # if ckpt and ckpt.model_checkpoint_path:# 221 | # print 11 222 | # global_step = ckpt.model_checkpoint_path.split('/')[-1].split('-')[-1]# 223 | # print global_step 224 | # #saver.restore(sess, ckpt.model_checkpoint_path)# 225 | # saver.restore(sess,self.test_model) 226 | # print('Loading success, global_step is %s' % global_step)# 227 | # else:# 228 | # print('No checkpoint file found')# 229 | 230 | saver.restore(sess,self.test_model) 231 | print "ok" 232 | features_batch, attributes_batch, topics_batch, image_files = sample_coco_minibatch(data, self.batch_size) 233 | feed_dict = {self.model.features: features_batch, self.model.attributes: attributes_batch, 234 | self.model.topics: topics_batch} 235 | alps, bts, sam_cap = sess.run([alphas, betas, sampled_captions], feed_dict) # (N, max_len, L), (N, max_len) 236 | decoded = decode_captions(sam_cap, self.model.idx_to_word) 237 | 238 | if attention_visualization: 239 | for n in range(10): 240 | print "Sampled Caption: %s" %decoded[n] 241 | 242 | # Plot original image 243 | img = ndimage.imread(image_files[n]) 244 | plt.subplot(4, 5, 1) 245 | plt.imshow(img) 246 | plt.axis('off') 247 | 248 | # Plot images with attention weights 249 | words = decoded[n].split(" ") 250 | for t in range(len(words)): 251 | if t > 18: 252 | break 253 | plt.subplot(4, 5, t+2) 254 | plt.text(0, 1, '%s(%.2f)'%(words[t], bts[n,t]) , color='black', backgroundcolor='white', fontsize=8) 255 | plt.imshow(img) 256 | alp_curr = alps[n,t,:].reshape(14,14) 257 | alp_img = skimage.transform.pyramid_expand(alp_curr, upscale=16, sigma=20) 258 | plt.imshow(alp_img, alpha=0.85) 259 | plt.axis('off') 260 | plt.show() 261 | 262 | if save_sampled_captions: 263 | all_sam_cap = np.ndarray((features.shape[0], 20)) 264 | num_iter = int(np.ceil(float(features.shape[0]) / self.batch_size)) 265 | for i in range(num_iter): 266 | features_batch = 
features[i*self.batch_size:(i+1)*self.batch_size] 267 | attributes_batch = attributes[i*self.batch_size:(i+1)*self.batch_size] 268 | topics_batch = topics[i*self.batch_size:(i+1)*self.batch_size] 269 | feed_dict = { self.model.features: features_batch, self.model.attributes: attributes_batch, self.model.topics: topics_batch } 270 | all_sam_cap[i*self.batch_size:(i+1)*self.batch_size] = sess.run(sampled_captions, feed_dict) 271 | all_decoded = decode_captions(all_sam_cap, self.model.idx_to_word) 272 | save_pickle(all_decoded, "./data/coco_data/%s/%s.candidate.captions.pkl" %(split,split)) -------------------------------------------------------------------------------- /core/model_2.py: -------------------------------------------------------------------------------- 1 | # ========================================================================================= 2 | # Implementation of "Show, Attend and Tell: Neural Caption Generator With Visual Attention". 3 | # There are some notations. 4 | # N is batch size. 5 | # L is spacial size of feature vector (196). 6 | # D is dimension of image feature vector (512). 7 | # T is the number of time step which is equal to caption's length-1 (16). 8 | # V is vocabulary size (about 10000). 9 | # M is dimension of word vector which is embedding size (default is 512). 10 | # H is dimension of hidden state (default is 1024). 11 | # ========================================================================================= 12 | 13 | from __future__ import division 14 | 15 | import tensorflow as tf 16 | 17 | 18 | class CaptionGenerator(object): 19 | def __init__(self, word_to_idx, dim_feature=[196, 512], att_len=1000, len_T=80, dim_embed=512, dim_hidden=1024, 20 | n_time_step=16, 21 | prev2out=True, ctx2out=True, alpha_c=0.0, selector=True, dropout=True): 22 | """ 23 | Args: 24 | word_to_idx: word-to-index mapping dictionary. 25 | dim_feature: (optional) Dimension of vggnet19 conv5_3 feature vectors. 26 | dim_embed: (optional) Dimension of word embedding. 27 | dim_hidden: (optional) Dimension of all hidden state. 28 | n_time_step: (optional) Time step size of LSTM. 29 | prev2out: (optional) previously generated word to hidden state. (see Eq (7) for explanation) 30 | ctx2out: (optional) context to hidden state (see Eq (7) for explanation) 31 | alpha_c: (optional) Doubly stochastic regularization coefficient. (see Section (4.2.1) for explanation) 32 | selector: (optional) gating scalar for context vector. (see Section (4.2.1) for explanation) 33 | dropout: (optional) If true then dropout layer is added. 
34 | """ 35 | 36 | self.word_to_idx = word_to_idx 37 | self.idx_to_word = {i: w for w, i in word_to_idx.iteritems()} 38 | self.prev2out = prev2out 39 | self.ctx2out = ctx2out 40 | self.alpha_c = alpha_c 41 | self.selector = selector 42 | self.dropout = dropout 43 | self.len_T = len_T 44 | self.A = att_len 45 | self.V = len(word_to_idx) 46 | self.L = dim_feature[0] 47 | self.D = dim_feature[1] 48 | self.M = dim_embed 49 | self.H = dim_hidden 50 | self.T = n_time_step 51 | self._start = word_to_idx[''] 52 | self._null = word_to_idx[''] 53 | 54 | self.weight_initializer = tf.contrib.layers.xavier_initializer() 55 | self.const_initializer = tf.constant_initializer(0.0) 56 | self.emb_initializer = tf.random_uniform_initializer(minval=-1.0, maxval=1.0) 57 | 58 | # Place holder for features and captions 59 | self.features = tf.placeholder(tf.float32, [None, self.L, self.D]) 60 | self.topics = tf.placeholder(tf.float32, [None, self.len_T]) 61 | self.attributes = tf.placeholder(tf.float32, [None, self.A]) 62 | self.captions = tf.placeholder(tf.int32, [None, self.T + 1]) 63 | 64 | def _get_initial_lstm(self, features): 65 | with tf.variable_scope('initial_lstm'): 66 | features_mean = tf.reduce_mean(features, 1) 67 | 68 | w_h = tf.get_variable('w_h', [self.D, self.H], initializer=self.weight_initializer) 69 | b_h = tf.get_variable('b_h', [self.H], initializer=self.const_initializer) 70 | h = tf.nn.tanh(tf.matmul(features_mean, w_h) + b_h) 71 | 72 | w_c = tf.get_variable('w_c', [self.D, self.H], initializer=self.weight_initializer) 73 | b_c = tf.get_variable('b_c', [self.H], initializer=self.const_initializer) 74 | c = tf.nn.tanh(tf.matmul(features_mean, w_c) + b_c) 75 | return c, h 76 | 77 | 78 | def _word_embedding(self, inputs, reuse=None): 79 | with tf.variable_scope('word_embedding', reuse=reuse): 80 | w = tf.get_variable('w', [self.V, self.M], initializer=self.emb_initializer) 81 | x = tf.nn.embedding_lookup(w, inputs, name='word_vector') # (N, T, M) or (N, M) 82 | return x 83 | 84 | 85 | def _project_features(self, features): 86 | with tf.variable_scope('project_features'): 87 | w = tf.get_variable('w', [self.D, self.D], initializer=self.weight_initializer) 88 | features_flat = tf.reshape(features, [-1, self.D]) 89 | features_proj = tf.matmul(features_flat, w) 90 | features_proj = tf.reshape(features_proj, [-1, self.L, self.D]) 91 | return features_proj 92 | 93 | 94 | def _project_attributes(self, attributes): 95 | with tf.variable_scope('project_attributes'): 96 | w = tf.get_variable('w', [self.A, self.A], initializer=self.weight_initializer) 97 | attributes_flat = tf.reshape(attributes, [-1, self.A]) 98 | attributes_proj = tf.matmul(attributes_flat, w) 99 | attributes_proj = tf.reshape(attributes_proj, [-1, self.A]) 100 | return attributes_proj 101 | 102 | 103 | # original_model 104 | 105 | def _attention_layer(self, features, features_proj, topic, h, reuse=False): 106 | with tf.variable_scope('attention_layer', reuse=reuse): 107 | w = tf.get_variable('w', [self.H, self.D], initializer=self.weight_initializer) 108 | b = tf.get_variable('b', [self.D], initializer=self.const_initializer) 109 | w_topic = tf.get_variable('w_topic', [self.len_T, self.D], initializer=self.weight_initializer) 110 | 111 | w_att = tf.get_variable('w_att', [self.D, 1], initializer=self.weight_initializer) 112 | 113 | h_att = tf.nn.relu( 114 | features_proj + tf.expand_dims(tf.matmul(h, w), 1) + tf.expand_dims(tf.matmul(topic, w_topic), 115 | 1) + b) 116 | out_att = tf.reshape(tf.matmul(tf.reshape(h_att, [-1, self.D]), 
w_att), [-1, self.L]) # (N, L) 117 | alpha = tf.nn.softmax(out_att) 118 | context = tf.reduce_sum(features * tf.expand_dims(alpha, 2), 1, name='context') # (N, D) 119 | return context, alpha 120 | 121 | def _attributes_attention_layer(self, attributes, attributes_proj, topic, h, reuse=False): 122 | with tf.variable_scope('attribute_attention_layer', reuse=reuse): 123 | w = tf.get_variable('w', [self.H, self.A], initializer=self.weight_initializer) 124 | b = tf.get_variable('b', [self.A], initializer=self.const_initializer) 125 | w_topic = tf.get_variable('w_topic', [self.len_T, self.A], initializer=self.weight_initializer) 126 | w_att = tf.get_variable('w_att', [self.A, 1], initializer=self.weight_initializer) 127 | h_att = tf.nn.relu( 128 | attributes_proj + tf.matmul(h, w) + tf.matmul(topic, w_topic) + b) 129 | out_att = tf.matmul(h_att, w_att) 130 | alpha = tf.nn.softmax(out_att) 131 | Attributes_context = attributes * alpha 132 | return Attributes_context 133 | 134 | def f_attention_layer(self, f_decoded, h, reuse=False): 135 | with tf.variable_scope('f_attention_layer', reuse=reuse): 136 | w = tf.get_variable('w', [self.H, self.M], initializer=self.weight_initializer) 137 | b = tf.get_variable('b', [self.M], initializer=self.const_initializer) 138 | 139 | w_att = tf.get_variable('w_att', [self.M, 1], initializer=self.weight_initializer) 140 | 141 | h_att = tf.nn.relu( 142 | f_decoded + tf.expand_dims(tf.matmul(h, w), 1) + b) 143 | out_att = tf.reshape(tf.matmul(tf.reshape(h_att, [-1, self.M]), w_att), [-1, self.T]) # (N, L) 144 | alpha = tf.nn.softmax(out_att) 145 | f_context = tf.reduce_sum(f_decoded * tf.expand_dims(alpha, 2), 1, name='f_context') # (N, D) 146 | return f_context 147 | 148 | 149 | # base_model 150 | # def _attention_layer(self, features, features_proj, topic, h, reuse=False): 151 | # with tf.variable_scope('attention_layer', reuse=reuse): 152 | # w = tf.get_variable('w', [self.H, self.D], initializer=self.weight_initializer) 153 | # b = tf.get_variable('b', [self.D], initializer=self.const_initializer) 154 | # #w_topic = tf.get_variable('w_topic', [self.len_T, self.D], initializer=self.weight_initializer) 155 | # 156 | # w_att = tf.get_variable('w_att', [self.D, 1], initializer=self.weight_initializer) 157 | # 158 | # h_att = tf.nn.relu( 159 | # features_proj + tf.expand_dims(tf.matmul(h, w), 1) + b) 160 | # out_att = tf.reshape(tf.matmul(tf.reshape(h_att, [-1, self.D]), w_att), [-1, self.L]) # (N, L) 161 | # alpha = tf.nn.softmax(out_att) 162 | # context = tf.reduce_sum(features * tf.expand_dims(alpha, 2), 1, name='context') # (N, D) 163 | # return context, alpha 164 | # 165 | # def _attributes_attention_layer(self, attributes, attributes_proj, topic, h, reuse=False): 166 | # with tf.variable_scope('attribute_attention_layer', reuse=reuse): 167 | # w = tf.get_variable('w', [self.H, self.A], initializer=self.weight_initializer) 168 | # b = tf.get_variable('b', [self.A], initializer=self.const_initializer) 169 | # w_topic = tf.get_variable('w_topic', [self.len_T, self.A], initializer=self.weight_initializer) 170 | # 171 | # w_att = tf.get_variable('w_att', [self.A, 1], initializer=self.weight_initializer) 172 | # 173 | # h_att = tf.nn.relu( 174 | # attributes_proj + tf.matmul(h, w) + tf.matmul(topic, w_topic) + b) 175 | # out_att = tf.matmul(h_att, w_att) 176 | # alpha = tf.nn.softmax(out_att) 177 | # Attributes_context = attributes * alpha 178 | # return Attributes_context 179 | 180 | def _selector(self, context, h, reuse=False): 181 | with 
tf.variable_scope('selector', reuse=reuse): 182 | w = tf.get_variable('w', [self.H, 1], initializer=self.weight_initializer) 183 | b = tf.get_variable('b', [1], initializer=self.const_initializer) 184 | beta = tf.nn.sigmoid(tf.matmul(h, w) + b, 'beta') # (N, 1) 185 | context = tf.multiply(beta, context, name='selected_context') 186 | return context, beta 187 | 188 | 189 | def _decode_lstm(self, x, h, context, dropout=False, reuse=False): 190 | with tf.variable_scope('logits', reuse=reuse): 191 | w_h = tf.get_variable('w_h', [self.H, self.M], initializer=self.weight_initializer) 192 | b_h = tf.get_variable('b_h', [self.M], initializer=self.const_initializer) 193 | w_out = tf.get_variable('w_out', [self.M, self.V], initializer=self.weight_initializer) 194 | b_out = tf.get_variable('b_out', [self.V], initializer=self.const_initializer) 195 | 196 | if dropout: 197 | h = tf.nn.dropout(h, 0.5) 198 | h_logits = tf.matmul(h, w_h) + b_h 199 | 200 | if self.ctx2out: 201 | w_ctx2out = tf.get_variable('w_ctx2out', [self.D, self.M], initializer=self.weight_initializer) 202 | h_logits += tf.matmul(context, w_ctx2out) 203 | 204 | if self.prev2out: 205 | h_logits += x 206 | h_logits = tf.nn.tanh(h_logits) 207 | 208 | if dropout: 209 | h_logits = tf.nn.dropout(h_logits, 0.5) 210 | out_logits = tf.matmul(h_logits, w_out) + b_out 211 | return out_logits 212 | 213 | 214 | def _batch_norm(self, x, mode='train', name=None): 215 | return tf.contrib.layers.batch_norm(inputs=x, 216 | decay=0.95, 217 | center=True, 218 | scale=True, 219 | is_training=(mode == 'train'), 220 | updates_collections=None, 221 | scope=(name + 'batch_norm')) 222 | 223 | 224 | def build_model(self): 225 | features = self.features 226 | attributes = self.attributes 227 | topic = self.topics 228 | captions = self.captions 229 | batch_size = tf.shape(features)[0] 230 | 231 | captions_in = captions[:, :self.T] 232 | captions_out = captions[:, 1:] 233 | mask = tf.to_float(tf.not_equal(captions_out, self._null)) 234 | 235 | # batch normalize feature vectors 236 | features = self._batch_norm(features, mode='train', name='conv_features') 237 | 238 | c, h = self._get_initial_lstm(features=features) 239 | x = self._word_embedding(inputs=captions_in) 240 | features_proj = self._project_features(features=features) 241 | attributes_proj = self._project_attributes(attributes=attributes) 242 | 243 | loss = 0.0 244 | alpha_list = [] 245 | lstm_cell = tf.nn.rnn_cell.BasicLSTMCell(num_units=self.H) 246 | 247 | for t in range(self.T): 248 | context, alpha = self._attention_layer(features, features_proj, topic, h, reuse=(t != 0)) 249 | attributes_context = self._attributes_attention_layer(attributes, attributes_proj, topic, h, reuse=(t != 0)) 250 | 251 | f_context = self.f_attention_layer(x, h, reuse=(t != 0)) 252 | alpha_list.append(alpha) 253 | 254 | if self.selector: 255 | context, beta = self._selector(context, h, reuse=(t != 0)) 256 | 257 | with tf.variable_scope('lstm', reuse=(t != 0)): 258 | _, (c, h) = lstm_cell(inputs=tf.concat([x[:, t, :], context, attributes_context, f_context], 1), state=[c, h]) 259 | 260 | logits = self._decode_lstm(x[:, t, :], h, context, dropout=self.dropout, reuse=(t != 0)) 261 | loss += tf.reduce_sum( 262 | tf.nn.sparse_softmax_cross_entropy_with_logits(logits=logits, labels=captions_out[:, t]) * mask[:, t]) 263 | 264 | if self.alpha_c > 0: 265 | alphas = tf.transpose(tf.stack(alpha_list), (1, 0, 2)) # (N, T, L) 266 | alphas_all = tf.reduce_sum(alphas, 1) # (N, L) 267 | alpha_reg = self.alpha_c * tf.reduce_sum((16. 
/ 196 - alphas_all) ** 2) 268 | loss += alpha_reg 269 | 270 | return loss / tf.to_float(batch_size) 271 | 272 | 273 | def build_sampler(self, max_len=20): 274 | features = self.features 275 | attributes = self.attributes 276 | topic = self.topics 277 | captions = self.captions 278 | captions_in = captions[:, :self.T] 279 | 280 | 281 | # batch normalize feature vectors 282 | features = self._batch_norm(features, mode='test', name='conv_features') 283 | 284 | c, h = self._get_initial_lstm(features=features) 285 | features_proj = self._project_features(features=features) 286 | attributes_proj = self._project_attributes(attributes=attributes) 287 | 288 | sampled_word_list = [] 289 | alpha_list = [] 290 | beta_list = [] 291 | lstm_cell = tf.nn.rnn_cell.BasicLSTMCell(num_units=self.H) 292 | 293 | for t in range(max_len): 294 | if t == 0: 295 | x = self._word_embedding(inputs=tf.fill([tf.shape(features)[0]], self._start)) 296 | else: 297 | x = self._word_embedding(inputs=sampled_word, reuse=True) 298 | 299 | context, alpha = self._attention_layer(features, features_proj, topic, h, reuse=(t != 0)) 300 | attributes_context = self._attributes_attention_layer(attributes, attributes_proj, topic, h, reuse=(t != 0)) 301 | 302 | f_context = self.f_attention_layer(self._word_embedding(inputs=captions_in, reuse = True), h, reuse=(t != 0)) 303 | 304 | alpha_list.append(alpha) 305 | 306 | if self.selector: 307 | context, beta = self._selector(context, h, reuse=(t != 0)) 308 | beta_list.append(beta) 309 | 310 | with tf.variable_scope('lstm', reuse=(t != 0)): 311 | _, (c, h) = lstm_cell(inputs=tf.concat([x, context, attributes_context, f_context], 1), state=[c, h]) 312 | 313 | logits = self._decode_lstm(x, h, context, reuse=(t != 0)) 314 | 315 | sampled_word = tf.argmax(logits, 1) 316 | sampled_word_list.append(sampled_word) 317 | 318 | alphas = tf.transpose(tf.stack(alpha_list), (1, 0, 2)) # (N, T, L) 319 | betas = tf.transpose(tf.squeeze(beta_list), (1, 0)) # (N, T) 320 | sampled_captions = tf.transpose(tf.stack(sampled_word_list), (1, 0)) # (N, max_len) 321 | return alphas, betas, sampled_captions 322 | 323 | -------------------------------------------------------------------------------- /core/solver.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ZhihaoZhu/Topic-Guided--Attention-For-Image-Captioning/c56cedc8cca73fd0598c60de73a344bed32d6c78/core/solver.pyc -------------------------------------------------------------------------------- /core/solver_0.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import matplotlib.pyplot as plt 3 | import skimage.transform 4 | import numpy as np 5 | import time 6 | import os 7 | import cPickle as pickle 8 | from scipy import ndimage 9 | from utils import * 10 | from bleu import evaluate 11 | 12 | 13 | class CaptioningSolver(object): 14 | def __init__(self, model, data, val_data, **kwargs): 15 | """ 16 | Required Arguments: 17 | - model: Show Attend and Tell caption generating model 18 | - data: Training data; dictionary with the following keys: 19 | - features: Feature vectors of shape (82783, 196, 512) 20 | - file_names: Image file names of shape (82783, ) 21 | - captions: Captions of shape (400000, 17) 22 | - image_idxs: Indices for mapping caption to image of shape (400000, ) 23 | - word_to_idx: Mapping dictionary from word to index 24 | - val_data: validation data; for print out BLEU scores for each epoch. 
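             Illustration (toy sizes): train() below also reads 'attributes' and 'topics' from the
             same dictionary, so a minimal stand-in for `data` could be built as follows; the
             attribute length (80) and topic length (20) are placeholders standing in for the
             model's att_len and len_T, and the file names and vocabulary entries are made up.

                 import numpy as np
                 toy_data = {
                     'features':    np.zeros((2, 196, 512), dtype=np.float32),  # conv5 feature maps
                     'captions':    np.zeros((4, 17), dtype=np.int32),          # <START> + up to 15 words + <END>/<NULL>
                     'image_idxs':  np.array([0, 0, 1, 1], dtype=np.int32),     # caption -> image index
                     'attributes':  np.zeros((2, 80), dtype=np.float32),        # attribute scores (length att_len)
                     'topics':      np.zeros((2, 20), dtype=np.float32),        # LDA topic vector (length len_T)
                     'file_names':  np.array(['img_0.jpg', 'img_1.jpg']),       # made-up image paths
                     'word_to_idx': {'<NULL>': 0, '<START>': 1, '<END>': 2},    # plus the real vocabulary
                 }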
25 | Optional Arguments: 26 | - n_epochs: The number of epochs to run for training. 27 | - batch_size: Mini batch size. 28 | - update_rule: A string giving the name of an update rule 29 | - learning_rate: Learning rate; default value is 0.01. 30 | - print_every: Integer; training losses will be printed every print_every iterations. 31 | - save_every: Integer; model variables will be saved every save_every epoch. 32 | - pretrained_model: String; pretrained model path 33 | - model_path: String; model path for saving 34 | - test_model: String; model path for test 35 | """ 36 | 37 | self.model = model 38 | self.data = data 39 | self.val_data = val_data 40 | self.n_epochs = kwargs.pop('n_epochs', 10) 41 | self.batch_size = kwargs.pop('batch_size', 100) 42 | self.update_rule = kwargs.pop('update_rule', 'adam') 43 | self.learning_rate = kwargs.pop('learning_rate', 0.01) 44 | self.print_bleu = kwargs.pop('print_bleu', False) 45 | self.print_every = kwargs.pop('print_every', 100) 46 | self.save_every = kwargs.pop('save_every', 1) 47 | self.log_path = kwargs.pop('log_path', './log/') 48 | self.model_path = kwargs.pop('model_path', './model/') 49 | self.pretrained_model = kwargs.pop('pretrained_model', None) 50 | self.test_model = kwargs.pop('test_model', './model/lstm/model-1') 51 | 52 | # set an optimizer by update rule 53 | if self.update_rule == 'adam': 54 | self.optimizer = tf.train.AdamOptimizer 55 | elif self.update_rule == 'momentum': 56 | self.optimizer = tf.train.MomentumOptimizer 57 | elif self.update_rule == 'rmsprop': 58 | self.optimizer = tf.train.RMSPropOptimizer 59 | 60 | if not os.path.exists(self.model_path): 61 | os.makedirs(self.model_path) 62 | if not os.path.exists(self.log_path): 63 | os.makedirs(self.log_path) 64 | 65 | 66 | def train(self): 67 | # train/val dataset 68 | n_examples = self.data['captions'].shape[0] 69 | n_iters_per_epoch = int(np.ceil(float(n_examples)/self.batch_size)) 70 | features = self.data['features'] 71 | attributes = self.data['attributes'] 72 | topics = self.data['topics'] 73 | captions = self.data['captions'] 74 | image_idxs = self.data['image_idxs'] 75 | val_features = self.val_data['features'] 76 | val_attributes = self.val_data['attributes'] 77 | val_topics = self.val_data['topics'] 78 | 79 | n_iters_val = int(np.ceil(float(val_features.shape[0])/self.batch_size)) 80 | 81 | # build graphs for training model and sampling captions 82 | loss = self.model.build_model() 83 | # tf.get_variable_scope().reuse_variables() 84 | # _, _, generated_captions = self.model.build_sampler(max_len=20) 85 | # 86 | # # train op 87 | # with tf.name_scope('optimizer'): 88 | # optimizer = self.optimizer(learning_rate=self.learning_rate) 89 | # grads = tf.gradients(loss, tf.trainable_variables()) 90 | # grads_and_vars = list(zip(grads, tf.trainable_variables())) 91 | # train_op = optimizer.apply_gradients(grads_and_vars=grads_and_vars) 92 | with tf.variable_scope(tf.get_variable_scope()) as scope: 93 | with tf.name_scope('optimizer'): 94 | tf.get_variable_scope().reuse_variables() 95 | _, _, generated_captions = self.model.build_sampler(max_len=20) 96 | optimizer = self.optimizer(learning_rate=self.learning_rate) 97 | grads = tf.gradients(loss, tf.trainable_variables()) 98 | grads_and_vars = list(zip(grads, tf.trainable_variables())) 99 | train_op = optimizer.apply_gradients(grads_and_vars=grads_and_vars) 100 | 101 | # summary op 102 | tf.summary.scalar('batch_loss', loss) 103 | for var in tf.trainable_variables(): 104 | tf.summary.histogram(var.op.name, var) 105 | for 
grad, var in grads_and_vars: 106 | tf.summary.histogram(var.op.name+'/gradient', grad) 107 | 108 | summary_op = tf.summary.merge_all() 109 | 110 | print "The number of epoch: %d" %self.n_epochs 111 | print "Data size: %d" %n_examples 112 | print "Batch size: %d" %self.batch_size 113 | print "Iterations per epoch: %d" %n_iters_per_epoch 114 | 115 | config = tf.ConfigProto(allow_soft_placement = True) 116 | #config.gpu_options.per_process_gpu_memory_fraction=0.9 117 | config.gpu_options.allow_growth = True 118 | with tf.Session(config=config) as sess: 119 | tf.initialize_all_variables().run() 120 | summary_writer = tf.summary.FileWriter(self.log_path, graph=tf.get_default_graph()) 121 | saver = tf.train.Saver(max_to_keep=40) 122 | 123 | if self.pretrained_model is not None: 124 | print "Start training with pretrained Model.." 125 | saver.restore(sess, self.pretrained_model) 126 | 127 | prev_loss = -1 128 | curr_loss = 0 129 | start_t = time.time() 130 | 131 | for e in range(self.n_epochs): 132 | rand_idxs = np.random.permutation(n_examples) 133 | captions = captions[rand_idxs] 134 | image_idxs = image_idxs[rand_idxs] 135 | 136 | for i in range(n_iters_per_epoch): 137 | captions_batch = captions[i*self.batch_size:(i+1)*self.batch_size] 138 | image_idxs_batch = image_idxs[i*self.batch_size:(i+1)*self.batch_size] 139 | features_batch = features[image_idxs_batch] 140 | attributes_batch = attributes[image_idxs_batch] 141 | topics_batch = topics[image_idxs_batch] 142 | feed_dict = {self.model.features: features_batch, self.model.attributes: attributes_batch, self.model.topics: topics_batch, self.model.captions: captions_batch} 143 | 144 | _, l = sess.run([train_op, loss], feed_dict) 145 | curr_loss += l 146 | 147 | # write summary for tensorboard visualization 148 | if i % 10 == 0: 149 | summary = sess.run(summary_op, feed_dict) 150 | summary_writer.add_summary(summary, e*n_iters_per_epoch + i) 151 | 152 | if (i+1) % self.print_every == 0: 153 | print "\nTrain loss at epoch %d & iteration %d (mini-batch): %.5f" %(e+1, i+1, l) 154 | ground_truths = captions[image_idxs == image_idxs_batch[0]] 155 | decoded = decode_captions(ground_truths, self.model.idx_to_word) 156 | for j, gt in enumerate(decoded): 157 | print "Ground truth %d: %s" %(j+1, gt) 158 | gen_caps = sess.run(generated_captions, feed_dict) 159 | decoded = decode_captions(gen_caps, self.model.idx_to_word) 160 | print "Generated caption: %s\n" %decoded[0] 161 | 162 | print "Previous epoch loss: ", prev_loss 163 | print "Current epoch loss: ", curr_loss 164 | print "Elapsed time: ", time.time() - start_t 165 | prev_loss = curr_loss 166 | curr_loss = 0 167 | 168 | # print out BLEU scores and file write 169 | # if self.print_bleu: 170 | # all_gen_cap = np.ndarray((val_features.shape[0], 20)) 171 | # 172 | # for i in range(n_iters_val): 173 | # features_batch = val_features[i*self.batch_size:(i+1)*self.batch_size] 174 | # attributes_batch = val_attributes[i * self.batch_size:(i + 1) * self.batch_size] 175 | # topics_batch = val_topics[i * self.batch_size:(i + 1) * self.batch_size] 176 | # feed_dict = {self.model.features: features_batch, self.model.attributes: attributes_batch, self.model.topics: topics_batch} 177 | # gen_cap = sess.run(generated_captions, feed_dict=feed_dict) 178 | # all_gen_cap[i*self.batch_size:(i+1)*self.batch_size] = gen_cap 179 | # 180 | # all_decoded = decode_captions(all_gen_cap, self.model.idx_to_word) 181 | # save_pickle(all_decoded, "./data/val/val.candidate.captions.pkl") 182 | # scores = 
evaluate(data_path='./data', split='val', get_scores=True) 183 | # write_bleu(scores=scores, path=self.model_path, epoch=e) 184 | # 185 | # save model's parameters 186 | if (e+1) % self.save_every == 0: 187 | saver.save(sess, os.path.join(self.model_path, 'model'), global_step=e+1) 188 | print "model-%s saved." %(e+1) 189 | 190 | 191 | def test(self, data, split='train', attention_visualization=True, save_sampled_captions=True): 192 | ''' 193 | Args: 194 | - data: dictionary with the following keys: 195 | - features: Feature vectors of shape (5000, 196, 512) 196 | - file_names: Image file names of shape (5000, ) 197 | - captions: Captions of shape (24210, 17) 198 | - image_idxs: Indices for mapping caption to image of shape (24210, ) 199 | - features_to_captions: Mapping feature to captions (5000, 4~5) 200 | - split: 'train', 'val' or 'test' 201 | - attention_visualization: If True, visualize attention weights with images for each sampled word. (ipthon notebook) 202 | - save_sampled_captions: If True, save sampled captions to pkl file for computing BLEU scores. 203 | ''' 204 | 205 | features = data['features'] 206 | attributes = data['attributes'] 207 | topics = data['topics'] 208 | 209 | # build a graph to sample captions 210 | alphas, betas, sampled_captions = self.model.build_sampler(max_len=20) # (N, max_len, L), (N, max_len) 211 | 212 | config = tf.ConfigProto(allow_soft_placement=True) 213 | config.gpu_options.allow_growth = True 214 | with tf.Session(config=config) as sess: 215 | 216 | saver = tf.train.Saver() 217 | 218 | print("Reading checkpoints...")# 219 | # ckpt = tf.train.get_checkpoint_state(self.model_path)# 220 | # if ckpt and ckpt.model_checkpoint_path:# 221 | # print 11 222 | # global_step = ckpt.model_checkpoint_path.split('/')[-1].split('-')[-1]# 223 | # print global_step 224 | # #saver.restore(sess, ckpt.model_checkpoint_path)# 225 | # saver.restore(sess,self.test_model) 226 | # print('Loading success, global_step is %s' % global_step)# 227 | # else:# 228 | # print('No checkpoint file found')# 229 | 230 | saver.restore(sess,self.test_model) 231 | print "ok" 232 | features_batch, attributes_batch, topics_batch, image_files = sample_coco_minibatch(data, self.batch_size) 233 | feed_dict = {self.model.features: features_batch, self.model.attributes: attributes_batch, 234 | self.model.topics: topics_batch} 235 | alps, bts, sam_cap = sess.run([alphas, betas, sampled_captions], feed_dict) # (N, max_len, L), (N, max_len) 236 | decoded = decode_captions(sam_cap, self.model.idx_to_word) 237 | 238 | if attention_visualization: 239 | for n in range(10): 240 | print "Sampled Caption: %s" %decoded[n] 241 | 242 | # Plot original image 243 | img = ndimage.imread(image_files[n]) 244 | plt.subplot(4, 5, 1) 245 | plt.imshow(img) 246 | plt.axis('off') 247 | 248 | # Plot images with attention weights 249 | words = decoded[n].split(" ") 250 | for t in range(len(words)): 251 | if t > 18: 252 | break 253 | plt.subplot(4, 5, t+2) 254 | plt.text(0, 1, '%s(%.2f)'%(words[t], bts[n,t]) , color='black', backgroundcolor='white', fontsize=8) 255 | plt.imshow(img) 256 | alp_curr = alps[n,t,:].reshape(14,14) 257 | alp_img = skimage.transform.pyramid_expand(alp_curr, upscale=16, sigma=20) 258 | plt.imshow(alp_img, alpha=0.85) 259 | plt.axis('off') 260 | plt.show() 261 | 262 | if save_sampled_captions: 263 | all_sam_cap = np.ndarray((features.shape[0], 20)) 264 | num_iter = int(np.ceil(float(features.shape[0]) / self.batch_size)) 265 | for i in range(num_iter): 266 | features_batch = 
features[i*self.batch_size:(i+1)*self.batch_size] 267 | attributes_batch = attributes[i*self.batch_size:(i+1)*self.batch_size] 268 | topics_batch = topics[i*self.batch_size:(i+1)*self.batch_size] 269 | feed_dict = { self.model.features: features_batch, self.model.attributes: attributes_batch, self.model.topics: topics_batch } 270 | all_sam_cap[i*self.batch_size:(i+1)*self.batch_size] = sess.run(sampled_captions, feed_dict) 271 | all_decoded = decode_captions(all_sam_cap, self.model.idx_to_word) 272 | save_pickle(all_decoded, "./data/f8_data/%s/%s.candidate.captions.pkl" %(split,split)) -------------------------------------------------------------------------------- /core/solver_2.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import matplotlib.pyplot as plt 3 | import skimage.transform 4 | import numpy as np 5 | import time 6 | import os 7 | import cPickle as pickle 8 | from scipy import ndimage 9 | from utils import * 10 | from bleu import evaluate 11 | import shutil 12 | 13 | 14 | class CaptioningSolver(object): 15 | def __init__(self, model, data, val_data, **kwargs): 16 | """ 17 | Required Arguments: 18 | - model: Show Attend and Tell caption generating model 19 | - data: Training data; dictionary with the following keys: 20 | - features: Feature vectors of shape (82783, 196, 512) 21 | - file_names: Image file names of shape (82783, ) 22 | - captions: Captions of shape (400000, 17) 23 | - image_idxs: Indices for mapping caption to image of shape (400000, ) 24 | - word_to_idx: Mapping dictionary from word to index 25 | - val_data: validation data; for print out BLEU scores for each epoch. 26 | Optional Arguments: 27 | - n_epochs: The number of epochs to run for training. 28 | - batch_size: Mini batch size. 29 | - update_rule: A string giving the name of an update rule 30 | - learning_rate: Learning rate; default value is 0.01. 31 | - print_every: Integer; training losses will be printed every print_every iterations. 32 | - save_every: Integer; model variables will be saved every save_every epoch. 
33 | - pretrained_model: String; pretrained model path 34 | - model_path: String; model path for saving 35 | - test_model: String; model path for test 36 | """ 37 | 38 | self.model = model 39 | self.data = data 40 | self.val_data = val_data 41 | self.n_epochs = kwargs.pop('n_epochs', 10) 42 | self.batch_size = kwargs.pop('batch_size', 100) 43 | self.update_rule = kwargs.pop('update_rule', 'adam') 44 | self.learning_rate = kwargs.pop('learning_rate', 0.01) 45 | self.print_bleu = kwargs.pop('print_bleu', False) 46 | self.print_every = kwargs.pop('print_every', 100) 47 | self.save_every = kwargs.pop('save_every', 1) 48 | self.log_path = kwargs.pop('log_path', './log/') 49 | self.model_path = kwargs.pop('model_path', './model/') 50 | self.pretrained_model = kwargs.pop('pretrained_model', None) 51 | self.test_model = kwargs.pop('test_model', './model/lstm/model-1') 52 | 53 | # set an optimizer by update rule 54 | if self.update_rule == 'adam': 55 | self.optimizer = tf.train.AdamOptimizer 56 | elif self.update_rule == 'momentum': 57 | self.optimizer = tf.train.MomentumOptimizer 58 | elif self.update_rule == 'rmsprop': 59 | self.optimizer = tf.train.RMSPropOptimizer 60 | 61 | if not os.path.exists(self.model_path): 62 | os.makedirs(self.model_path) 63 | if not os.path.exists(self.log_path): 64 | os.makedirs(self.log_path) 65 | 66 | 67 | def train(self): 68 | # train/val dataset 69 | n_examples = self.data['captions'].shape[0] 70 | n_iters_per_epoch = int(np.ceil(float(n_examples)/self.batch_size)) 71 | features = self.data['features'] 72 | attributes = self.data['attributes'] 73 | topics = self.data['topics'] 74 | captions = self.data['captions'] 75 | image_idxs = self.data['image_idxs'] 76 | val_features = self.val_data['features'] 77 | val_attributes = self.val_data['attributes'] 78 | val_topics = self.val_data['topics'] 79 | 80 | n_iters_val = int(np.ceil(float(val_features.shape[0])/self.batch_size)) 81 | 82 | # build graphs for training model and sampling captions 83 | loss = self.model.build_model() 84 | # tf.get_variable_scope().reuse_variables() 85 | # _, _, generated_captions = self.model.build_sampler(max_len=20) 86 | # 87 | # # train op 88 | # with tf.name_scope('optimizer'): 89 | # optimizer = self.optimizer(learning_rate=self.learning_rate) 90 | # grads = tf.gradients(loss, tf.trainable_variables()) 91 | # grads_and_vars = list(zip(grads, tf.trainable_variables())) 92 | # train_op = optimizer.apply_gradients(grads_and_vars=grads_and_vars) 93 | with tf.variable_scope(tf.get_variable_scope()) as scope: 94 | with tf.name_scope('optimizer'): 95 | tf.get_variable_scope().reuse_variables() 96 | _, _, generated_captions = self.model.build_sampler(max_len=20) 97 | optimizer = self.optimizer(learning_rate=self.learning_rate) 98 | grads = tf.gradients(loss, tf.trainable_variables()) 99 | grads_and_vars = list(zip(grads, tf.trainable_variables())) 100 | train_op = optimizer.apply_gradients(grads_and_vars=grads_and_vars) 101 | 102 | # summary op 103 | tf.summary.scalar('batch_loss', loss) 104 | for var in tf.trainable_variables(): 105 | tf.summary.histogram(var.op.name, var) 106 | for grad, var in grads_and_vars: 107 | tf.summary.histogram(var.op.name+'/gradient', grad) 108 | 109 | summary_op = tf.summary.merge_all() 110 | 111 | print "The number of epoch: %d" %self.n_epochs 112 | print "Data size: %d" %n_examples 113 | print "Batch size: %d" %self.batch_size 114 | print "Iterations per epoch: %d" %n_iters_per_epoch 115 | 116 | config = tf.ConfigProto(allow_soft_placement = True) 117 | 
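        # allow_soft_placement lets TensorFlow fall back to a CPU kernel when an op has no GPU
        # implementation; the commented-out option below would cap this process at 90% of GPU
        # memory, while allow_growth starts with a small allocation and grows it on demand
        # instead of reserving the whole GPU up front.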
#config.gpu_options.per_process_gpu_memory_fraction=0.9 118 | config.gpu_options.allow_growth = True 119 | with tf.Session(config=config) as sess: 120 | tf.initialize_all_variables().run() 121 | summary_writer = tf.summary.FileWriter(self.log_path, graph=tf.get_default_graph()) 122 | saver = tf.train.Saver(max_to_keep=40) 123 | 124 | if self.pretrained_model is not None: 125 | print "Start training with pretrained Model.." 126 | saver.restore(sess, self.pretrained_model) 127 | 128 | prev_loss = -1 129 | curr_loss = 0 130 | start_t = time.time() 131 | 132 | for e in range(self.n_epochs): 133 | rand_idxs = np.random.permutation(n_examples) 134 | captions = captions[rand_idxs] 135 | image_idxs = image_idxs[rand_idxs] 136 | 137 | for i in range(n_iters_per_epoch): 138 | captions_batch = captions[i*self.batch_size:(i+1)*self.batch_size] 139 | image_idxs_batch = image_idxs[i*self.batch_size:(i+1)*self.batch_size] 140 | features_batch = features[image_idxs_batch] 141 | attributes_batch = attributes[image_idxs_batch] 142 | topics_batch = topics[image_idxs_batch] 143 | feed_dict = {self.model.features: features_batch, self.model.attributes: attributes_batch, 144 | self.model.topics: topics_batch, self.model.captions: captions_batch} 145 | 146 | _, l = sess.run([train_op, loss], feed_dict) 147 | curr_loss += l 148 | 149 | # write summary for tensorboard visualization 150 | if i % 10 == 0: 151 | summary = sess.run(summary_op, feed_dict) 152 | summary_writer.add_summary(summary, e*n_iters_per_epoch + i) 153 | 154 | if (i+1) % self.print_every == 0: 155 | print "\nTrain loss at epoch %d & iteration %d (mini-batch): %.5f" %(e+1, i+1, l) 156 | ground_truths = captions[image_idxs == image_idxs_batch[0]] 157 | decoded = decode_captions(ground_truths, self.model.idx_to_word) 158 | for j, gt in enumerate(decoded): 159 | print "Ground truth %d: %s" %(j+1, gt) 160 | gen_caps = sess.run(generated_captions, feed_dict) 161 | decoded = decode_captions(gen_caps, self.model.idx_to_word) 162 | print "Generated caption: %s\n" %decoded[0] 163 | 164 | print "Previous epoch loss: ", prev_loss 165 | print "Current epoch loss: ", curr_loss 166 | print "Elapsed time: ", time.time() - start_t 167 | prev_loss = curr_loss 168 | curr_loss = 0 169 | 170 | # print out BLEU scores and file write 171 | # if self.print_bleu: 172 | # all_gen_cap = np.ndarray((val_features.shape[0], 20)) 173 | # 174 | # for i in range(n_iters_val): 175 | # features_batch = val_features[i*self.batch_size:(i+1)*self.batch_size] 176 | # attributes_batch = val_attributes[i * self.batch_size:(i + 1) * self.batch_size] 177 | # topics_batch = val_topics[i * self.batch_size:(i + 1) * self.batch_size] 178 | # feed_dict = {self.model.features: features_batch, self.model.attributes: attributes_batch, self.model.topics: topics_batch} 179 | # gen_cap = sess.run(generated_captions, feed_dict=feed_dict) 180 | # all_gen_cap[i*self.batch_size:(i+1)*self.batch_size] = gen_cap 181 | # 182 | # all_decoded = decode_captions(all_gen_cap, self.model.idx_to_word) 183 | # save_pickle(all_decoded, "./data/val/val.candidate.captions.pkl") 184 | # scores = evaluate(data_path='./data', split='val', get_scores=True) 185 | # write_bleu(scores=scores, path=self.model_path, epoch=e) 186 | # 187 | # save model's parameters 188 | if (e+1) % self.save_every == 0: 189 | saver.save(sess, os.path.join(self.model_path, 'model'), global_step=e+1) 190 | print "model-%s saved." 
%(e+1) 191 | 192 | 193 | def test(self, data, split='train', attention_visualization=True, save_sampled_captions=True): 194 | ''' 195 | Args: 196 | - data: dictionary with the following keys: 197 | - features: Feature vectors of shape (5000, 196, 512) 198 | - file_names: Image file names of shape (5000, ) 199 | - captions: Captions of shape (24210, 17) 200 | - image_idxs: Indices for mapping caption to image of shape (24210, ) 201 | - features_to_captions: Mapping feature to captions (5000, 4~5) 202 | - split: 'train', 'val' or 'test' 203 | - attention_visualization: If True, visualize attention weights with images for each sampled word. (ipthon notebook) 204 | - save_sampled_captions: If True, save sampled captions to pkl file for computing BLEU scores. 205 | ''' 206 | 207 | features = data['features'] 208 | attributes = data['attributes'] 209 | topics = data['topics'] 210 | captions = data['captions'] 211 | 212 | 213 | captions_in_use = np.ndarray((features.shape[0], 17)) 214 | image_idxs = self.data['image_idxs'] 215 | start_index = image_idxs[0] 216 | captions_in_use[0] = captions[0] 217 | for i in range(len(image_idxs)): 218 | if start_index != image_idxs[i]: 219 | start_index = image_idxs[i] 220 | captions_in_use[start_index] = captions[i] 221 | 222 | 223 | # build a graph to sample captions 224 | alphas, betas, sampled_captions = self.model.build_sampler(max_len=20) # (N, max_len, L), (N, max_len) 225 | 226 | config = tf.ConfigProto(allow_soft_placement=True) 227 | config.gpu_options.allow_growth = True 228 | with tf.Session(config=config) as sess: 229 | 230 | saver = tf.train.Saver() 231 | 232 | print("Reading checkpoints...")# 233 | # ckpt = tf.train.get_checkpoint_state(self.model_path)# 234 | # if ckpt and ckpt.model_checkpoint_path:# 235 | # print 11 236 | # global_step = ckpt.model_checkpoint_path.split('/')[-1].split('-')[-1]# 237 | # print global_step 238 | # #saver.restore(sess, ckpt.model_checkpoint_path)# 239 | # saver.restore(sess,self.test_model) 240 | # print('Loading success, global_step is %s' % global_step)# 241 | # else:# 242 | # print('No checkpoint file found')# 243 | 244 | saver.restore(sess,self.test_model) 245 | print "ok" 246 | 247 | if save_sampled_captions: 248 | all_sam_cap = np.ndarray((features.shape[0], 20)) 249 | num_iter = int(np.ceil(float(features.shape[0]) / self.batch_size)) 250 | for i in range(num_iter): 251 | features_batch = features[i*self.batch_size:(i+1)*self.batch_size] 252 | attributes_batch = attributes[i*self.batch_size:(i+1)*self.batch_size] 253 | topics_batch = topics[i*self.batch_size:(i+1)*self.batch_size] 254 | captions_batch = captions_in_use[i*self.batch_size:(i+1)*self.batch_size] 255 | feed_dict = { self.model.features: features_batch, self.model.attributes: attributes_batch, 256 | self.model.topics: topics_batch, self.model.captions: captions_batch } 257 | all_sam_cap[i*self.batch_size:(i+1)*self.batch_size] = sess.run(sampled_captions, feed_dict) 258 | all_decoded = decode_captions(all_sam_cap, self.model.idx_to_word) 259 | save_pickle(all_decoded, "./data/coco_data/%s/%s.candidate.captions.pkl" %(split,split)) 260 | 261 | 262 | 263 | 264 | # features_batch, attributes_batch, topics_batch, image_files, captions_batch = sample_coco_minibatch(data, self.batch_size) 265 | # 266 | # feed_dict = {self.model.features: features_batch, self.model.attributes: attributes_batch, 267 | # self.model.topics: topics_batch, self.model.captions: captions_batch} 268 | # alps, bts, sam_cap = sess.run([alphas, betas, sampled_captions], 
feed_dict) # (N, max_len, L), (N, max_len) 269 | # 270 | # decoded = decode_captions(sam_cap, self.model.idx_to_word) 271 | # 272 | # if attention_visualization: 273 | # for n in range(10): 274 | # print "Sampled Caption: %s" %decoded[n] 275 | # 276 | # # Plot original image 277 | # 278 | # img = ndimage.imread(image_files[n]) 279 | # print image_files[n] 280 | # plt.subplot(4, 5, 1) 281 | # plt.imshow(img) 282 | # plt.axis('off') 283 | # 284 | # # Plot images with attention weights 285 | # 286 | # words = decoded[n].split(" ") 287 | # 288 | # for t in range(len(words)): 289 | # if t > 18: 290 | # break 291 | # plt.subplot(4, 5, t+2) 292 | # plt.text(0, 1, '%s(%.2f)'%(words[t], bts[n,t]) , color='black', backgroundcolor='white', fontsize=8) 293 | # plt.imshow(img) 294 | # alp_curr = alps[n,t,:].reshape(14,14) 295 | # alp_img = skimage.transform.pyramid_expand(alp_curr, upscale=16, sigma=10) 296 | # plt.imshow(alp_img, alpha=0.85) 297 | # plt.axis('off') 298 | # 299 | # plt.show() 300 | 301 | 302 | 303 | 304 | ''' 305 | find pic 306 | ######################### clear files and copy original pic 307 | shutil.rmtree('./zzzz/'+str(n)) 308 | 309 | 310 | os.mkdir('./zzzz/'+str(n)) 311 | 312 | 313 | txtName = './zzzz/'+str(n)+'/caption.txt' 314 | f=file(txtName, "a+") 315 | f.write(decoded[n]) 316 | f.close() 317 | 318 | 319 | oldname="./"+image_files[n] 320 | newname= './zzzz/'+str(n)+'/'+'original'+'.jpg' 321 | shutil.copyfile(oldname,newname) 322 | ########################## 323 | 324 | for t in range(len(words)): 325 | if t > 18: 326 | break 327 | plt.imshow(img) 328 | alp_curr = alps[n,t,:].reshape(14,14) 329 | alp_img = skimage.transform.pyramid_expand(alp_curr, upscale=16, sigma=20) 330 | plt.imshow(alp_img, alpha=0.85) 331 | plt.axis('off') 332 | plt.savefig('./zzzz/'+str(n)+'/'+str(t)+'.png') 333 | ''' 334 | 335 | 336 | 337 | 338 | 339 | -------------------------------------------------------------------------------- /core/utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import cPickle as pickle 3 | import hickle 4 | import time 5 | import os 6 | import h5py 7 | 8 | 9 | def load_coco_data(data_path='./data', split='train'): 10 | data_path = os.path.join(data_path, split) 11 | start_t = time.time() 12 | data = {} 13 | 14 | #data['features'] = hickle.load(os.path.join(data_path, '%s.features.hkl' %split)) 15 | 16 | feature_file = os.path.join(data_path, '%s.h5' %split) 17 | with h5py.File(feature_file, 'r') as f: 18 | data['features'] = np.asarray(f['features']) 19 | 20 | 21 | with open(os.path.join(data_path, '%s.file.names.pkl' %split), 'rb') as f: 22 | data['file_names'] = pickle.load(f) 23 | with open(os.path.join(data_path, '%s.captions.pkl' %split), 'rb') as f: 24 | data['captions'] = pickle.load(f) 25 | with open(os.path.join(data_path, '%s.image.idxs.pkl' %split), 'rb') as f: 26 | data['image_idxs'] = pickle.load(f) 27 | 28 | attributes_file = os.path.join(data_path, '%s.attributes.h5' %split) 29 | with h5py.File(attributes_file, 'r') as f: 30 | data['attributes'] = np.asarray(f['attributes']) 31 | 32 | topics_file = os.path.join(data_path, '%s.topics.h5' %split) 33 | with h5py.File(topics_file, 'r') as f: 34 | data['topics'] = np.asarray(f['topics']) 35 | 36 | 37 | 38 | if split == 'train': 39 | with open(os.path.join(data_path, 'word_to_idx.pkl'), 'rb') as f: 40 | data['word_to_idx'] = pickle.load(f) 41 | 42 | for k, v in data.iteritems(): 43 | if type(v) == np.ndarray: 44 | print k, type(v), v.shape, v.dtype 45 
| else: 46 | print k, type(v), len(v) 47 | end_t = time.time() 48 | print "Elapse time: %.2f" %(end_t - start_t) 49 | return data 50 | 51 | def decode_captions(captions, idx_to_word): 52 | if captions.ndim == 1: 53 | T = captions.shape[0] 54 | N = 1 55 | else: 56 | N, T = captions.shape 57 | 58 | decoded = [] 59 | for i in range(N): 60 | words = [] 61 | for t in range(T): 62 | if captions.ndim == 1: 63 | word = idx_to_word[captions[t]] 64 | else: 65 | word = idx_to_word[captions[i, t]] 66 | if word == '': 67 | words.append('.') 68 | break 69 | if word != '': 70 | words.append(word) 71 | decoded.append(' '.join(words)) 72 | return decoded 73 | 74 | def sample_coco_minibatch(data, captions_1, batch_size): 75 | data_size = data['features'].shape[0] 76 | mask = np.random.choice(data_size, batch_size) 77 | #mask = np.linspace(1810,1820, 11).astype(np.int32) 78 | features = data['features'][mask] 79 | attributes = data['attributes'][mask] 80 | topics = data['topics'][mask] 81 | file_names = data['file_names'][mask] 82 | captions = captions_1[mask] 83 | 84 | return features, attributes, topics, file_names, captions 85 | 86 | def write_bleu(scores, path, epoch): 87 | if epoch == 0: 88 | file_mode = 'w' 89 | else: 90 | file_mode = 'a' 91 | with open(os.path.join(path, 'val.bleu.scores.txt'), file_mode) as f: 92 | f.write('Epoch %d\n' %(epoch+1)) 93 | f.write('Bleu_1: %f\n' %scores['Bleu_1']) 94 | f.write('Bleu_2: %f\n' %scores['Bleu_2']) 95 | f.write('Bleu_3: %f\n' %scores['Bleu_3']) 96 | f.write('Bleu_4: %f\n' %scores['Bleu_4']) 97 | f.write('METEOR: %f\n' %scores['METEOR']) 98 | f.write('ROUGE_L: %f\n' %scores['ROUGE_L']) 99 | f.write('CIDEr: %f\n\n' %scores['CIDEr']) 100 | 101 | def load_pickle(path): 102 | with open(path, 'rb') as f: 103 | file = pickle.load(f) 104 | print ('Loaded %s..' %path) 105 | return file 106 | 107 | def save_pickle(data, path): 108 | with open(path, 'wb') as f: 109 | pickle.dump(data, f, pickle.HIGHEST_PROTOCOL) 110 | print ('Saved %s..' 
%path) 111 | 112 | 113 | -------------------------------------------------------------------------------- /core/utils.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ZhihaoZhu/Topic-Guided--Attention-For-Image-Captioning/c56cedc8cca73fd0598c60de73a344bed32d6c78/core/utils.pyc -------------------------------------------------------------------------------- /core/utils_0.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import cPickle as pickle 3 | import hickle 4 | import time 5 | import os 6 | import h5py 7 | 8 | 9 | def load_coco_data(data_path='./data', split='train'): 10 | data_path = os.path.join(data_path, split) 11 | start_t = time.time() 12 | data = {} 13 | 14 | #data['features'] = hickle.load(os.path.join(data_path, '%s.features.hkl' %split)) 15 | 16 | feature_file = os.path.join(data_path, '%s.h5' %split) 17 | with h5py.File(feature_file, 'r') as f: 18 | data['features'] = np.asarray(f['features']) 19 | 20 | 21 | with open(os.path.join(data_path, '%s.file.names.pkl' %split), 'rb') as f: 22 | data['file_names'] = pickle.load(f) 23 | with open(os.path.join(data_path, '%s.captions.pkl' %split), 'rb') as f: 24 | data['captions'] = pickle.load(f) 25 | with open(os.path.join(data_path, '%s.image.idxs.pkl' %split), 'rb') as f: 26 | data['image_idxs'] = pickle.load(f) 27 | 28 | attributes_file = os.path.join(data_path, '%s.attributes.h5' %split) 29 | with h5py.File(attributes_file, 'r') as f: 30 | data['attributes'] = np.asarray(f['attributes']) 31 | 32 | topics_file = os.path.join(data_path, '%s.topics.h5' %split) 33 | with h5py.File(topics_file, 'r') as f: 34 | data['topics'] = np.asarray(f['topics']) 35 | 36 | 37 | 38 | if split == 'train': 39 | with open(os.path.join(data_path, 'word_to_idx.pkl'), 'rb') as f: 40 | data['word_to_idx'] = pickle.load(f) 41 | 42 | for k, v in data.iteritems(): 43 | if type(v) == np.ndarray: 44 | print k, type(v), v.shape, v.dtype 45 | else: 46 | print k, type(v), len(v) 47 | end_t = time.time() 48 | print "Elapse time: %.2f" %(end_t - start_t) 49 | return data 50 | 51 | def decode_captions(captions, idx_to_word): 52 | if captions.ndim == 1: 53 | T = captions.shape[0] 54 | N = 1 55 | else: 56 | N, T = captions.shape 57 | 58 | decoded = [] 59 | for i in range(N): 60 | words = [] 61 | for t in range(T): 62 | if captions.ndim == 1: 63 | word = idx_to_word[captions[t]] 64 | else: 65 | word = idx_to_word[captions[i, t]] 66 | if word == '': 67 | words.append('.') 68 | break 69 | if word != '': 70 | words.append(word) 71 | decoded.append(' '.join(words)) 72 | return decoded 73 | 74 | def sample_coco_minibatch(data, batch_size): 75 | data_size = data['features'].shape[0] 76 | mask = np.random.choice(data_size, batch_size) 77 | mask = np.linspace(1810,1820, 11).astype(np.int32) 78 | features = data['features'][mask] 79 | attributes = data['attributes'][mask] 80 | topics = data['topics'][mask] 81 | file_names = data['file_names'][mask] 82 | return features, attributes, topics, file_names 83 | 84 | def write_bleu(scores, path, epoch): 85 | if epoch == 0: 86 | file_mode = 'w' 87 | else: 88 | file_mode = 'a' 89 | with open(os.path.join(path, 'val.bleu.scores.txt'), file_mode) as f: 90 | f.write('Epoch %d\n' %(epoch+1)) 91 | f.write('Bleu_1: %f\n' %scores['Bleu_1']) 92 | f.write('Bleu_2: %f\n' %scores['Bleu_2']) 93 | f.write('Bleu_3: %f\n' %scores['Bleu_3']) 94 | f.write('Bleu_4: %f\n' %scores['Bleu_4']) 95 | f.write('METEOR: %f\n' 
%scores['METEOR']) 96 | f.write('ROUGE_L: %f\n' %scores['ROUGE_L']) 97 | f.write('CIDEr: %f\n\n' %scores['CIDEr']) 98 | 99 | def load_pickle(path): 100 | with open(path, 'rb') as f: 101 | file = pickle.load(f) 102 | print ('Loaded %s..' %path) 103 | return file 104 | 105 | def save_pickle(data, path): 106 | with open(path, 'wb') as f: 107 | pickle.dump(data, f, pickle.HIGHEST_PROTOCOL) 108 | print ('Saved %s..' %path) -------------------------------------------------------------------------------- /core/utils_2.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import cPickle as pickle 3 | import hickle 4 | import time 5 | import os 6 | import h5py 7 | 8 | 9 | def load_coco_data(data_path='./data', split='train'): 10 | data_path = os.path.join(data_path, split) 11 | start_t = time.time() 12 | data = {} 13 | 14 | #data['features'] = hickle.load(os.path.join(data_path, '%s.features.hkl' %split)) 15 | 16 | feature_file = os.path.join(data_path, '%s.h5' %split) 17 | with h5py.File(feature_file, 'r') as f: 18 | data['features'] = np.asarray(f['features']) 19 | 20 | 21 | with open(os.path.join(data_path, '%s.file.names.pkl' %split), 'rb') as f: 22 | data['file_names'] = pickle.load(f) 23 | with open(os.path.join(data_path, '%s.captions.pkl' %split), 'rb') as f: 24 | data['captions'] = pickle.load(f) 25 | with open(os.path.join(data_path, '%s.image.idxs.pkl' %split), 'rb') as f: 26 | data['image_idxs'] = pickle.load(f) 27 | 28 | attributes_file = os.path.join(data_path, '%s.attributes.h5' %split) 29 | with h5py.File(attributes_file, 'r') as f: 30 | data['attributes'] = np.asarray(f['attributes']) 31 | 32 | topics_file = os.path.join(data_path, '%s.topics.h5' %split) 33 | with h5py.File(topics_file, 'r') as f: 34 | data['topics'] = np.asarray(f['topics']) 35 | 36 | 37 | 38 | if split == 'train': 39 | with open(os.path.join(data_path, 'word_to_idx.pkl'), 'rb') as f: 40 | data['word_to_idx'] = pickle.load(f) 41 | 42 | for k, v in data.iteritems(): 43 | if type(v) == np.ndarray: 44 | print k, type(v), v.shape, v.dtype 45 | else: 46 | print k, type(v), len(v) 47 | end_t = time.time() 48 | print "Elapse time: %.2f" %(end_t - start_t) 49 | return data 50 | 51 | def decode_captions(captions, idx_to_word): 52 | if captions.ndim == 1: 53 | T = captions.shape[0] 54 | N = 1 55 | else: 56 | N, T = captions.shape 57 | 58 | decoded = [] 59 | for i in range(N): 60 | words = [] 61 | for t in range(T): 62 | if captions.ndim == 1: 63 | word = idx_to_word[captions[t]] 64 | else: 65 | word = idx_to_word[captions[i, t]] 66 | if word == '': 67 | words.append('.') 68 | break 69 | if word != '': 70 | words.append(word) 71 | decoded.append(' '.join(words)) 72 | return decoded 73 | 74 | def sample_coco_minibatch(data, batch_size): 75 | data_size = data['features'].shape[0] 76 | mask = np.random.choice(data_size, batch_size) 77 | #mask = np.linspace(1810,1820, 11).astype(np.int32) 78 | features = data['features'][mask] 79 | attributes = data['attributes'][mask] 80 | topics = data['topics'][mask] 81 | file_names = data['file_names'][mask] 82 | captions = data['captions'][mask] 83 | 84 | return features, attributes, topics, file_names, captions 85 | 86 | def write_bleu(scores, path, epoch): 87 | if epoch == 0: 88 | file_mode = 'w' 89 | else: 90 | file_mode = 'a' 91 | with open(os.path.join(path, 'val.bleu.scores.txt'), file_mode) as f: 92 | f.write('Epoch %d\n' %(epoch+1)) 93 | f.write('Bleu_1: %f\n' %scores['Bleu_1']) 94 | f.write('Bleu_2: %f\n' %scores['Bleu_2']) 
95 | f.write('Bleu_3: %f\n' %scores['Bleu_3']) 96 | f.write('Bleu_4: %f\n' %scores['Bleu_4']) 97 | f.write('METEOR: %f\n' %scores['METEOR']) 98 | f.write('ROUGE_L: %f\n' %scores['ROUGE_L']) 99 | f.write('CIDEr: %f\n\n' %scores['CIDEr']) 100 | 101 | def load_pickle(path): 102 | with open(path, 'rb') as f: 103 | file = pickle.load(f) 104 | print ('Loaded %s..' %path) 105 | return file 106 | 107 | def save_pickle(data, path): 108 | with open(path, 'wb') as f: 109 | pickle.dump(data, f, pickle.HIGHEST_PROTOCOL) 110 | print ('Saved %s..' %path) 111 | 112 | 113 | -------------------------------------------------------------------------------- /core/vggnet.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import scipy.io 3 | 4 | 5 | vgg_layers = ['conv1_1', 'relu1_1', 'conv1_2', 'relu1_2', 'pool1', 6 | 'conv2_1', 'relu2_1', 'conv2_2', 'relu2_2', 'pool2', 7 | 'conv3_1', 'relu3_1', 'conv3_2', 'relu3_2', 'conv3_3', 'relu3_3', 'conv3_4', 'relu3_4', 'pool3', 8 | 'conv4_1', 'relu4_1', 'conv4_2', 'relu4_2', 'conv4_3', 'relu4_3', 'conv4_4', 'relu4_4', 'pool4', 9 | 'conv5_1', 'relu5_1', 'conv5_2', 'relu5_2', 'conv5_3', 'relu5_3', 'conv5_4', 'relu5_4'] 10 | 11 | class Vgg19(object): 12 | def __init__(self, vgg_path): 13 | self.vgg_path = vgg_path 14 | 15 | def build_inputs(self): 16 | self.images = tf.placeholder(tf.float32, [None, 224, 224, 3], 'images') 17 | 18 | def build_params(self): 19 | model = scipy.io.loadmat(self.vgg_path) 20 | layers = model['layers'][0] 21 | self.params = {} 22 | with tf.variable_scope('encoder'): 23 | for i, layer in enumerate(layers): 24 | layer_name = layer[0][0][0][0] 25 | layer_type = layer[0][0][1][0] 26 | if layer_type == 'conv': 27 | w = layer[0][0][2][0][0].transpose(1, 0, 2, 3) 28 | b = layer[0][0][2][0][1].reshape(-1) 29 | self.params[layer_name] = {} 30 | self.params[layer_name]['w'] = tf.get_variable(layer_name+'/w', initializer=tf.constant(w)) 31 | self.params[layer_name]['b'] = tf.get_variable(layer_name+'/b',initializer=tf.constant(b)) 32 | 33 | def _conv(self, x, w, b): 34 | return tf.nn.bias_add(tf.nn.conv2d(x, w, strides=[1, 1, 1, 1], padding='SAME'), b) 35 | 36 | def _relu(self, x): 37 | return tf.nn.relu(x) 38 | 39 | def _pool(self, x): 40 | return tf.nn.max_pool(x, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='VALID') 41 | 42 | def build_model(self): 43 | for i, layer in enumerate(vgg_layers): 44 | layer_type = layer[:4] 45 | if layer_type == 'conv': 46 | if layer == 'conv1_1': 47 | h = self.images 48 | h = self._conv(h, self.params[layer]['w'], self.params[layer]['b']) 49 | elif layer_type == 'relu': 50 | h = self._relu(h) 51 | elif layer_type == 'pool': 52 | h = self._pool(h) 53 | if layer == 'conv5_3': 54 | self.features = tf.reshape(h, [-1, 196, 512]) 55 | 56 | 57 | def build(self): 58 | self.build_inputs() 59 | self.build_params() 60 | self.build_model() -------------------------------------------------------------------------------- /core/vggnet.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ZhihaoZhu/Topic-Guided--Attention-For-Image-Captioning/c56cedc8cca73fd0598c60de73a344bed32d6c78/core/vggnet.pyc -------------------------------------------------------------------------------- /h5test.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python2 2 | # -*- coding: utf-8 -*- 3 | """ 4 | Created on Wed Dec 20 16:55:09 2017 5 | 6 | @author: xz 7 | """ 8 | 9 | 
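# A minimal, working version of the commented-out inspection below (kept commented so this
# scratch file still only runs the beam-search demo at the bottom); the path and dataset key
# are the ones used in the original snippet, and the expected shape is an assumption:
#
#     import h5py
#     import numpy as np
#     with h5py.File('./data/train/train.attributes.h5', 'r') as f:
#         attributes = np.asarray(f['attributes'])
#     print(attributes.shape)   # e.g. (n_images, 80)
#     print(attributes[1])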
#import h5py #导入工具包 10 | #import json 11 | #import numpy as np 12 | #import pickle 13 | ##all_feats = np.ndarray([40000, 196, 512], dtype=np.float32) 14 | ##print all_feats.shape[0] 15 | ##image_topic = [] 16 | #topic_path = './data/train/train.attributes.h5' 17 | #with h5py.File(topic_path, 'r') as f: 18 | # image_topic = np.asarray(f['attributes']) 19 | # 20 | #print image_topic[1] 21 | 22 | 23 | #caption_file = './data/train/data/mscoco/processed_freq5.json' 24 | #caption_file1 = './data/annotations/captions_train2014.json' 25 | #with open(caption_file) as f: 26 | # caption_data = json.load(f) 27 | ## print caption_data['train_image_ids'][:5] 28 | ## print caption_data['train_captions'][:5] 29 | # print caption_data.keys() 30 | # print caption_data['train_captions'][:10] 31 | 32 | #with open('/home/Fdisk/imagecaption/lda/coco_topic.pkl', 'rb') as f: 33 | #file = pickle.load(f) 34 | #print file 35 | #x = load_pickle('/home/Fdisk/imagecaption/lda/coco_topic.pkl') 36 | #print x 37 | 38 | 39 | #coding: utf-8 40 | #demo of beam search for seq2seq model 41 | import numpy as np 42 | import random 43 | vocab = { 44 | 0: 'a', 45 | 1: 'b', 46 | 2: 'c', 47 | 3: 'd', 48 | 4: 'e', 49 | 5: 'BOS', 50 | 6: 'EOS' 51 | } 52 | reverse_vocab = dict([(v,k) for k,v in vocab.items()]) 53 | vocab_size = len(vocab.items()) 54 | def softmax(x): 55 | """Compute softmax values for each sets of scores in x.""" 56 | e_x = np.exp(x - np.max(x)) 57 | return e_x / e_x.sum() 58 | def reduce_mul(l): 59 | out = 1.0 60 | for x in l: 61 | out *= x 62 | return out 63 | def check_all_done(seqs): 64 | for seq in seqs: 65 | if not seq[-1]: 66 | return False 67 | return True 68 | 69 | def decode_step(encoder_context, input_seq): 70 | #encoder_context contains infortaion of encoder 71 | #ouput_step contains the words' probability 72 | #these two varibles should be generated by seq2seq model 73 | words_prob = [random.random() for _ in range(vocab_size)] 74 | #downvote BOS 75 | words_prob[reverse_vocab['BOS']] = 0.0 76 | words_prob = softmax(words_prob) 77 | ouput_step = [(idx,prob) for idx,prob in enumerate(words_prob)] 78 | ouput_step = sorted(ouput_step, key=lambda x: x[1], reverse=True) 79 | return ouput_step 80 | #seq: [[word,word],[word,word],[word,word]] 81 | #output: [[word,word,word],[word,word,word],[word,word,word]] 82 | def beam_search_step(encoder_context, top_seqs, k): 83 | all_seqs = [] 84 | for seq in top_seqs: 85 | seq_score = reduce_mul([_score for _,_score in seq]) 86 | if seq[-1][0] == reverse_vocab['EOS']: 87 | all_seqs.append((seq, seq_score, True)) 88 | continue 89 | #get current step using encoder_context & seq 90 | current_step = decode_step(encoder_context, seq) 91 | for i,word in enumerate(current_step): 92 | if i >= k: 93 | break 94 | word_index = word[0] 95 | word_score = word[1] 96 | score = seq_score * word_score 97 | rs_seq = seq + [word] 98 | done = (word_index == reverse_vocab['EOS']) 99 | all_seqs.append((rs_seq, score, done)) 100 | all_seqs = sorted(all_seqs, key = lambda seq: seq[1], reverse=True) 101 | topk_seqs = [seq for seq,_,_ in all_seqs[:k]] 102 | all_done = check_all_done(topk_seqs) 103 | return topk_seqs, all_done 104 | def beam_search(encoder_context): 105 | beam_size = 3 106 | max_len = 10 107 | #START 108 | top_seqs = [[(reverse_vocab['BOS'],1.0)]] 109 | #loop 110 | for _ in range(max_len): 111 | top_seqs, all_done = beam_search_step(encoder_context, top_seqs, beam_size) 112 | if all_done: 113 | break 114 | return top_seqs 115 | if __name__ == '__main__': 116 | #encoder_context is 
not inportant in this demo 117 | encoder_context = None 118 | top_seqs = beam_search(encoder_context) 119 | for i,seq in enumerate(top_seqs): 120 | print 'Path[%d]: ' % i 121 | for word in seq[1:]: 122 | word_index = word[0] 123 | word_prob = word[1] 124 | print '%s(%.4f)' % (vocab[word_index], word_prob), 125 | if word_index == reverse_vocab['EOS']: 126 | break 127 | print '\n' -------------------------------------------------------------------------------- /prepro.py: -------------------------------------------------------------------------------- 1 | from scipy import ndimage 2 | from collections import Counter 3 | from core.vggnet import Vgg19 4 | from core.utils import * 5 | 6 | import tensorflow as tf 7 | import numpy as np 8 | import pandas as pd 9 | import hickle 10 | import os 11 | import json 12 | import h5py 13 | 14 | 15 | def _process_caption_data(caption_file, image_dir, max_length): 16 | with open(caption_file) as f: 17 | caption_data = json.load(f) 18 | 19 | # id_to_filename is a dictionary such as {image_id: filename]} 20 | id_to_filename = {image['id']: image['file_name'] for image in caption_data['images']} 21 | 22 | # data is a list of dictionary which contains 'captions', 'file_name' and 'image_id' as key. 23 | data = [] 24 | for annotation in caption_data['annotations']: 25 | image_id = annotation['image_id'] 26 | annotation['file_name'] = os.path.join(image_dir, id_to_filename[image_id]) 27 | data += [annotation] 28 | 29 | # convert to pandas dataframe (for later visualization or debugging) 30 | caption_data = pd.DataFrame.from_dict(data) 31 | del caption_data['id'] 32 | caption_data.sort_values(by='image_id', inplace=True) 33 | caption_data = caption_data.reset_index(drop=True) 34 | 35 | del_idx = [] 36 | for i, caption in enumerate(caption_data['caption']): 37 | caption = caption.replace('.','').replace(',','').replace("'","").replace('"','') 38 | caption = caption.replace('&','and').replace('(','').replace(")","").replace('-',' ') 39 | caption = " ".join(caption.split()) # replace multiple spaces 40 | 41 | caption_data.set_value(i, 'caption', caption.lower()) 42 | if len(caption.split(" ")) > max_length: 43 | del_idx.append(i) 44 | 45 | # delete captions if size is larger than max_length 46 | print "The number of captions before deletion: %d" %len(caption_data) 47 | caption_data = caption_data.drop(caption_data.index[del_idx]) 48 | caption_data = caption_data.reset_index(drop=True) 49 | print "The number of captions after deletion: %d" %len(caption_data) 50 | return caption_data 51 | 52 | 53 | def _build_vocab(annotations, threshold=1): 54 | counter = Counter() 55 | max_len = 0 56 | for i, caption in enumerate(annotations['caption']): 57 | words = caption.split(' ') # caption contrains only lower-case words 58 | for w in words: 59 | counter[w] +=1 60 | 61 | if len(caption.split(" ")) > max_len: 62 | max_len = len(caption.split(" ")) 63 | 64 | vocab = [word for word in counter if counter[word] >= threshold] 65 | print ('Filtered %d words to %d words with word count threshold %d.' 
% (len(counter), len(vocab), threshold)) 66 | 67 | word_to_idx = {u'': 0, u'': 1, u'': 2} 68 | idx = 3 69 | for word in vocab: 70 | word_to_idx[word] = idx 71 | idx += 1 72 | print "Max length of caption: ", max_len 73 | return word_to_idx 74 | 75 | 76 | def _build_caption_vector(annotations, word_to_idx, max_length=15): 77 | n_examples = len(annotations) 78 | captions = np.ndarray((n_examples,max_length+2)).astype(np.int32) 79 | 80 | for i, caption in enumerate(annotations['caption']): 81 | words = caption.split(" ") # caption contrains only lower-case words 82 | cap_vec = [] 83 | cap_vec.append(word_to_idx['']) 84 | for word in words: 85 | if word in word_to_idx: 86 | cap_vec.append(word_to_idx[word]) 87 | cap_vec.append(word_to_idx['']) 88 | 89 | # pad short caption with the special null token '' to make it fixed-size vector 90 | if len(cap_vec) < (max_length + 2): 91 | for j in range(max_length + 2 - len(cap_vec)): 92 | cap_vec.append(word_to_idx['']) 93 | 94 | captions[i, :] = np.asarray(cap_vec) 95 | print "Finished building caption vectors" 96 | return captions 97 | 98 | 99 | def _build_file_names(annotations): 100 | image_file_names = [] 101 | id_to_idx = {} 102 | idx = 0 103 | image_ids = annotations['image_id'] 104 | file_names = annotations['file_name'] 105 | for image_id, file_name in zip(image_ids, file_names): 106 | if not image_id in id_to_idx: 107 | id_to_idx[image_id] = idx 108 | image_file_names.append(file_name) 109 | idx += 1 110 | 111 | file_names = np.asarray(image_file_names) 112 | return file_names, id_to_idx 113 | 114 | 115 | def _build_image_idxs(annotations, id_to_idx): 116 | image_idxs = np.ndarray(len(annotations), dtype=np.int32) 117 | image_ids = annotations['image_id'] 118 | for i, image_id in enumerate(image_ids): 119 | image_idxs[i] = id_to_idx[image_id] 120 | return image_idxs 121 | 122 | 123 | def main(): 124 | # batch size for extracting feature vectors from vggnet. 125 | batch_size = 100 126 | # maximum length of caption(number of word). if caption is longer than max_length, deleted. 127 | max_length = 15 128 | # if word occurs less than word_count_threshold in training dataset, the word index is special unknown token. 
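    # Illustration with made-up indices: with max_length = 15, _build_caption_vector() turns
    # "a dog runs" into a length-17 row of the form
    #     [<START>, a, dog, runs, <END>, <NULL>, ..., <NULL>]  ->  [1, 4, 212, 87, 2, 0, ..., 0]
    # and words that were filtered out of the vocabulary by the count threshold are simply
    # skipped when the vector is built.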
129 | word_count_threshold = 1 130 | # vgg model path 131 | vgg_model_path = './data/imagenet-vgg-verydeep-19.mat' 132 | 133 | caption_file = 'data/annotations/captions_train2014.json' 134 | image_dir = 'image/%2014_resized/' 135 | 136 | # about 80000 images and 400000 captions for train dataset 137 | train_dataset = _process_caption_data(caption_file='data/annotations/captions_train2014.json', 138 | image_dir='image/train2014_resized/', 139 | max_length=max_length) 140 | 141 | # about 40000 images and 200000 captions 142 | val_dataset = _process_caption_data(caption_file='data/annotations/captions_val2014.json', 143 | image_dir='image/val2014_resized/', 144 | max_length=max_length) 145 | #train_cutoff = int(0.7*len(train_dataset)) 146 | # about 4000 images and 20000 captions for val / test dataset 147 | val_cutoff = int(0.1 * len(val_dataset)) 148 | test_cutoff = int(0.2 * len(val_dataset)) 149 | print 'Finished processing caption data' 150 | 151 | save_pickle(train_dataset, 'data/train/train.annotations.pkl') 152 | #save_pickle(train_dataset[:train_cutoff], 'data/train/train.annotations.pkl') 153 | save_pickle(val_dataset[:val_cutoff], 'data/val/val.annotations.pkl') 154 | save_pickle(val_dataset[val_cutoff:test_cutoff].reset_index(drop=True), 'data/test/test.annotations.pkl') 155 | 156 | for split in ['train', 'val', 'test']: 157 | annotations = load_pickle('./data/%s/%s.annotations.pkl' % (split, split)) 158 | 159 | if split == 'train': 160 | word_to_idx = _build_vocab(annotations=annotations, threshold=word_count_threshold) 161 | save_pickle(word_to_idx, './data/%s/word_to_idx.pkl' % split) 162 | 163 | captions = _build_caption_vector(annotations=annotations, word_to_idx=word_to_idx, max_length=max_length) 164 | save_pickle(captions, './data/%s/%s.captions.pkl' % (split, split)) 165 | 166 | file_names, id_to_idx = _build_file_names(annotations) 167 | save_pickle(file_names, './data/%s/%s.file.names.pkl' % (split, split)) 168 | 169 | image_idxs = _build_image_idxs(annotations, id_to_idx) 170 | save_pickle(image_idxs, './data/%s/%s.image.idxs.pkl' % (split, split)) 171 | 172 | # prepare reference captions to compute bleu scores later 173 | image_ids = {} 174 | feature_to_captions = {} 175 | i = -1 176 | for caption, image_id in zip(annotations['caption'], annotations['image_id']): 177 | if not image_id in image_ids: 178 | image_ids[image_id] = 0 179 | i += 1 180 | feature_to_captions[i] = [] 181 | feature_to_captions[i].append(caption.lower() + ' .') 182 | save_pickle(feature_to_captions, './data/%s/%s.references.pkl' % (split, split)) 183 | print "Finished building %s caption dataset" %split 184 | 185 | # extract conv5_3 feature vectors 186 | vggnet = Vgg19(vgg_model_path) 187 | vggnet.build() 188 | with tf.Session() as sess: 189 | tf.initialize_all_variables().run() 190 | for split in [ 'train','val', 'test']: 191 | anno_path = './data/%s/%s.annotations.pkl' % (split, split) 192 | #save_path = './data/%s/%s.features.hkl' % (split, split) 193 | save_path = h5py.File('./data/%s/%s.h5' %(split, split),'w') 194 | #save_path1 = './data/%s/%s.features1.hkl' % (split, split)# 195 | annotations = load_pickle(anno_path) 196 | image_path = list(annotations['file_name'].unique()) 197 | n_examples = len(image_path) 198 | 199 | all_feats = np.ndarray([n_examples, 196, 512], dtype=np.float32) 200 | #all_feats1 = np.ndarray([n_examples-45000, 196, 512], dtype=np.float32)# 201 | 202 | for start, end in zip(range(0, n_examples, batch_size), 203 | range(batch_size, n_examples + batch_size, 
batch_size)): 204 | image_batch_file = image_path[start:end] 205 | image_batch = np.array(map(lambda x: ndimage.imread(x, mode='RGB'), image_batch_file)).astype( 206 | np.float32) 207 | feats = sess.run(vggnet.features, feed_dict={vggnet.images: image_batch}) 208 | 209 | all_feats[start:end, :] = feats 210 | 211 | print ("Processed %d %s features.." % (end, split)) 212 | 213 | # use hickle to save huge feature vectors 214 | #hickle.dump(all_feats, save_path) 215 | save_path.create_dataset('features', data=all_feats) 216 | 217 | print ("Saved %s.." % (save_path)) 218 | 219 | 220 | if __name__ == "__main__": 221 | #main() 222 | annotations = load_pickle('./data/coco_data/train/train.annotations.pkl' ) 223 | 224 | #annotations = load_pickle('./data/coco_data/test/test.annotations.pkl' ) 225 | 226 | #word_to_idx = load_pickle('./data/train/word_to_idx.pkl' ) 227 | # word_to_idx = _build_vocab(annotations=annotations, threshold=1) 228 | 229 | 230 | 231 | # file_names, id_to_idx = _build_file_names(annotations) 232 | # print len(file_names) 233 | # 234 | # 235 | # image_idxs = _build_image_idxs(annotations, id_to_idx) 236 | # print image_idxs[:20] 237 | # print len(annotations) 238 | 239 | 240 | 241 | batch_size = 2 242 | vgg_model_path = './data/imagenet-vgg-verydeep-19.mat' 243 | vggnet = Vgg19(vgg_model_path) 244 | vggnet.build() 245 | with tf.Session() as sess: 246 | tf.initialize_all_variables().run() 247 | 248 | 249 | #save_path = './data/%s/%s.features.hkl' % (split, split) 250 | save_path = h5py.File('./data/my_pic.h5','w') 251 | #save_path1 = './data/%s/%s.features1.hkl' % (split, split)# 252 | 253 | image_path = list('image/train2014_resized/COCO_train2014_000000013140.jpg') 254 | image_path = list(annotations['file_name'][:3].unique()) 255 | 256 | 257 | n_examples = 1 258 | 259 | all_feats = np.ndarray([n_examples, 196, 512], dtype=np.float32) 260 | #all_feats1 = np.ndarray([n_examples-45000, 196, 512], dtype=np.float32)# 261 | 262 | for start, end in zip(range(0, n_examples, batch_size), 263 | range(batch_size, n_examples + batch_size, batch_size)): 264 | image_batch_file = image_path[start:end] 265 | image_batch = np.array(map(lambda x: ndimage.imread(x, mode='RGB'), image_batch_file)).astype( 266 | np.float32) 267 | feats = sess.run(vggnet.features, feed_dict={vggnet.images: image_batch}) 268 | 269 | all_feats[start:end, :] = feats 270 | 271 | print ("Processed %d features.." % (end)) 272 | 273 | # use hickle to save huge feature vectors 274 | #hickle.dump(all_feats, save_path) 275 | save_path.create_dataset('features', data=all_feats) 276 | 277 | print ("Saved %s.." 
% (save_path)) 278 | -------------------------------------------------------------------------------- /prepro_f8.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python2 2 | # -*- coding: utf-8 -*- 3 | """ 4 | Created on Wed Jan 17 14:06:42 2018 5 | 6 | @author: xz 7 | """ 8 | from scipy import ndimage 9 | from collections import Counter 10 | from core.vggnet import Vgg19 11 | from core.utils import * 12 | 13 | import tensorflow as tf 14 | import numpy as np 15 | import pandas as pd 16 | import hickle 17 | import os 18 | import json 19 | import h5py 20 | import shutil 21 | 22 | def split_dataset(): 23 | 24 | train=[] 25 | val=[] 26 | test=[] 27 | for split in ['train']: 28 | fd = file( '/home/Fdisk/flcikr8/drive-download-20171213T072617Z-001/Flickr8k_text/Flickr_8k.%s.txt'% (split), "r" ) 29 | 30 | for line in fd.readlines(): 31 | train.append(line.strip()) 32 | 33 | for split in ['test']: 34 | fd = file( '/home/Fdisk/flcikr8/drive-download-20171213T072617Z-001/Flickr8k_text/Flickr_8k.%s.txt'% (split), "r" ) 35 | 36 | for line in fd.readlines(): 37 | test.append(line.strip()) 38 | for split in ['val']: 39 | fd = file( '/home/Fdisk/flcikr8/drive-download-20171213T072617Z-001/Flickr8k_text/Flickr_8k.%s.txt'% (split), "r" ) 40 | 41 | for line in fd.readlines(): 42 | val.append(line.strip()) 43 | return train,val,test 44 | # fileDir='/home/Fdisk/flcikr8/drive-download-20171213T072617Z-001/Flickr8k_Dataset/Flicker8k_Dataset' 45 | # for root, dirs, files in os.walk(fileDir): 46 | # for xx in files: 47 | # old_path=os.path.join(root, xx) 48 | # 49 | # if xx in train: 50 | # new_path=os.path.join('/home/Fdisk/flcikr8/f8_train',xx) 51 | # shutil.copyfile(old_path,new_path) 52 | # if xx in val: 53 | # new_path=os.path.join('/home/Fdisk/flcikr8/f8_val',xx) 54 | # shutil.copyfile(old_path,new_path) 55 | # if xx in test: 56 | # new_path=os.path.join('/home/Fdisk/flcikr8/f8_test',xx) 57 | # shutil.copyfile(old_path,new_path) 58 | 59 | 60 | 61 | def load_doc(filename): 62 | file = open(filename, 'r') 63 | text = file.read() 64 | file.close() 65 | return text 66 | 67 | def load_data(doc): 68 | id_=0 69 | data_train = [] 70 | data_val = [] 71 | data_test = [] 72 | mapping = dict() 73 | for line in doc.split('\n'): 74 | tokens = line.split() 75 | if len(line) < 2: 76 | continue 77 | image_id, image_desc = tokens[0], tokens[1:] 78 | image_id = image_id.split('.')[0] 79 | file_name = image_id+'.jpg' 80 | image_desc = ' '.join(image_desc) 81 | train = [] 82 | val = [] 83 | test = [] 84 | train,val,test = split_dataset() 85 | if file_name in train: 86 | file_name = os.path.join('image/trainf8_resized/',file_name) 87 | mapping['image_id']=image_id 88 | mapping['caption']=image_desc 89 | mapping['file_name']=file_name 90 | mapping['id']=id_ 91 | id_ +=1 92 | data_train += [mapping] 93 | mapping={} 94 | if file_name in val: 95 | file_name = os.path.join('image/valf8_resized/',file_name) 96 | mapping['image_id']=image_id 97 | mapping['caption']=image_desc 98 | mapping['file_name']=file_name 99 | mapping['id']=id_ 100 | id_ +=1 101 | data_val += [mapping] 102 | mapping={} 103 | if file_name in test: 104 | file_name = os.path.join('image/testf8_resized/',file_name) 105 | mapping['image_id']=image_id 106 | mapping['caption']=image_desc 107 | mapping['file_name']=file_name 108 | mapping['id']=id_ 109 | id_ +=1 110 | data_test += [mapping] 111 | mapping={} 112 | return data_train,data_val,data_test 113 | 114 | def _process_caption_data(data,max_length): 115 | 116 | 
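# --- Editor's note ------------------------------------------------------------
# A minimal, self-contained sketch (not part of the original repository) of what
# _process_caption_data, _build_vocab and _build_caption_vector in this file do
# end to end: clean a raw caption, build a vocabulary with a count threshold,
# and turn each caption into a fixed-length index vector.  The special-token
# names <NULL>/<START>/<END> are assumed from the usual show-and-tell
# convention; in this dump they appear as empty strings.
import re
from collections import Counter

def _demo_caption_pipeline(raw_captions, max_length=15, threshold=1):
    # clean: lowercase, drop punctuation, collapse whitespace, enforce max_length
    cleaned = [" ".join(re.sub("[^a-z0-9 ]", " ", c.lower()).split())
               for c in raw_captions]
    cleaned = [c for c in cleaned if len(c.split()) <= max_length]
    # vocabulary: every word seen at least `threshold` times
    counts = Counter(w for c in cleaned for w in c.split())
    word_to_idx = {'<NULL>': 0, '<START>': 1, '<END>': 2}
    for w, n in counts.items():
        if n >= threshold:
            word_to_idx.setdefault(w, len(word_to_idx))
    # index vectors: <START> w1 ... wn <END>, padded with <NULL> to max_length+2
    vectors = []
    for c in cleaned:
        vec = ([word_to_idx['<START>']]
               + [word_to_idx[w] for w in c.split() if w in word_to_idx]
               + [word_to_idx['<END>']])
        vec += [word_to_idx['<NULL>']] * (max_length + 2 - len(vec))
        vectors.append(vec)
    return word_to_idx, vectors
# e.g. _demo_caption_pipeline(["A man riding a wave on a surfboard ."])
# --- end editor's note ---------------------------------------------------------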
caption_data = pd.DataFrame.from_dict(data) 117 | del caption_data['id'] 118 | caption_data.sort_values(by='image_id', inplace=True) 119 | caption_data = caption_data.reset_index(drop=True) 120 | 121 | del_idx = [] 122 | for i, caption in enumerate(caption_data['caption']): 123 | caption = caption.replace('.','').replace(',','').replace("'","").replace('"','') 124 | caption = caption.replace('&','and').replace('(','').replace(")","").replace('-',' ') 125 | caption = " ".join(caption.split()) # replace multiple spaces 126 | 127 | caption_data.set_value(i, 'caption', caption.lower()) 128 | if len(caption.split(" ")) > max_length: 129 | del_idx.append(i) 130 | 131 | 132 | # delete captions if size is larger than max_length 133 | print "The number of captions before deletion: %d" %len(caption_data) 134 | caption_data = caption_data.drop(caption_data.index[del_idx]) 135 | caption_data = caption_data.reset_index(drop=True) 136 | print "The number of captions after deletion: %d" %len(caption_data) 137 | return caption_data 138 | def _build_vocab(annotations, threshold=1): 139 | counter = Counter() 140 | max_len = 0 141 | for i, caption in enumerate(annotations['caption']): 142 | words = caption.split(' ') # caption contrains only lower-case words 143 | for w in words: 144 | counter[w] +=1 145 | 146 | if len(caption.split(" ")) > max_len: 147 | max_len = len(caption.split(" ")) 148 | 149 | vocab = [word for word in counter if counter[word] >= threshold] 150 | print ('Filtered %d words to %d words with word count threshold %d.' % (len(counter), len(vocab), threshold)) 151 | 152 | word_to_idx = {u'': 0, u'': 1, u'': 2} 153 | idx = 3 154 | for word in vocab: 155 | word_to_idx[word] = idx 156 | idx += 1 157 | print "Max length of caption: ", max_len 158 | return word_to_idx 159 | 160 | 161 | def _build_caption_vector(annotations, word_to_idx, max_length=15): 162 | n_examples = len(annotations) 163 | captions = np.ndarray((n_examples,max_length+2)).astype(np.int32) 164 | 165 | for i, caption in enumerate(annotations['caption']): 166 | words = caption.split(" ") # caption contrains only lower-case words 167 | cap_vec = [] 168 | cap_vec.append(word_to_idx['']) 169 | for word in words: 170 | if word in word_to_idx: 171 | cap_vec.append(word_to_idx[word]) 172 | cap_vec.append(word_to_idx['']) 173 | 174 | # pad short caption with the special null token '' to make it fixed-size vector 175 | if len(cap_vec) < (max_length + 2): 176 | for j in range(max_length + 2 - len(cap_vec)): 177 | cap_vec.append(word_to_idx['']) 178 | 179 | captions[i, :] = np.asarray(cap_vec) 180 | print "Finished building caption vectors" 181 | return captions 182 | 183 | 184 | def _build_file_names(annotations): 185 | image_file_names = [] 186 | id_to_idx = {} 187 | idx = 0 188 | image_ids = annotations['image_id'] 189 | file_names = annotations['file_name'] 190 | for image_id, file_name in zip(image_ids, file_names): 191 | if not image_id in id_to_idx: 192 | id_to_idx[image_id] = idx 193 | image_file_names.append(file_name) 194 | idx += 1 195 | 196 | file_names = np.asarray(image_file_names) 197 | return file_names, id_to_idx 198 | 199 | 200 | def _build_image_idxs(annotations, id_to_idx): 201 | image_idxs = np.ndarray(len(annotations), dtype=np.int32) 202 | image_ids = annotations['image_id'] 203 | for i, image_id in enumerate(image_ids): 204 | image_idxs[i] = id_to_idx[image_id] 205 | return image_idxs 206 | 207 | def load_pickle(path): 208 | with open(path, 'rb') as f: 209 | file = pickle.load(f) 210 | print('Loaded %s..' 
% path) 211 | return file 212 | def main(): 213 | 214 | 215 | 216 | 217 | xx = load_pickle('/home/Fdisk/imagecaption/data/f8_data/train/train.file.names.pkl') 218 | train = [] 219 | val = [] 220 | test = [] 221 | train,val,test = split_dataset() 222 | 223 | train_1= [] 224 | for aa in train: 225 | train_1.append(os.path.join('image/trainf8_resized/',aa)) 226 | 227 | for i in train_1: 228 | if i not in xx: 229 | print i 230 | 231 | 232 | 233 | 234 | # batch_size = 100 235 | # max_length = 15 236 | # word_count_threshold = 1 237 | # vgg_model_path = './data/imagenet-vgg-verydeep-19.mat' 238 | # filename = '/home/Fdisk/flcikr8/drive-download-20171213T072617Z-001/Flickr8k_text/Flickr8k.token.txt' 239 | # doc = load_doc(filename) 240 | # train_data=[] 241 | # val_data= [] 242 | # test_data=[] 243 | # train_data,val_data,test_data= load_data(doc) 244 | # print len(train_data) 245 | # train_dataset =_process_caption_data(train_data,max_length=max_length) 246 | # val_dataset =_process_caption_data(val_data,max_length=max_length) 247 | # test_dataset =_process_caption_data(test_data,max_length=max_length) 248 | # print 'Finished processing caption data' 249 | # save_pickle(train_dataset, 'data/f8_data/train/train.annotations.pkl') 250 | # save_pickle(val_dataset, 'data/f8_data/val/val.annotations.pkl') 251 | # save_pickle(test_dataset, 'data/f8_data/test/test.annotations.pkl') 252 | # 253 | # for split in ['train', 'val', 'test']: 254 | # annotations = load_pickle('./data/f8_data/%s/%s.annotations.pkl' % (split, split)) 255 | # 256 | # if split == 'train': 257 | # word_to_idx = _build_vocab(annotations=annotations, threshold=word_count_threshold) 258 | # save_pickle(word_to_idx, './data/f8_data/%s/word_to_idx.pkl' % split) 259 | # 260 | # captions = _build_caption_vector(annotations=annotations, word_to_idx=word_to_idx, max_length=max_length) 261 | # save_pickle(captions, './data/f8_data/%s/%s.captions.pkl' % (split, split)) 262 | # 263 | # file_names, id_to_idx = _build_file_names(annotations) 264 | # save_pickle(file_names, './data/f8_data/%s/%s.file.names.pkl' % (split, split)) 265 | # 266 | # image_idxs = _build_image_idxs(annotations, id_to_idx) 267 | # save_pickle(image_idxs, './data/f8_data/%s/%s.image.idxs.pkl' % (split, split)) 268 | # 269 | # # prepare reference captions to compute bleu scores later 270 | # image_ids = {} 271 | # feature_to_captions = {} 272 | # i = -1 273 | # for caption, image_id in zip(annotations['caption'], annotations['image_id']): 274 | # if not image_id in image_ids: 275 | # image_ids[image_id] = 0 276 | # i += 1 277 | # feature_to_captions[i] = [] 278 | # feature_to_captions[i].append(caption.lower() + ' .') 279 | # save_pickle(feature_to_captions, './data/f8_data/%s/%s.references.pkl' % (split, split)) 280 | # print "Finished building %s caption dataset" %split 281 | # 282 | # # extract conv5_3 feature vectors 283 | # vggnet = Vgg19(vgg_model_path) 284 | # vggnet.build() 285 | # with tf.Session() as sess: 286 | # tf.initialize_all_variables().run() 287 | # for split in [ 'train','val', 'test']: 288 | # anno_path = './data/f8_data/%s/%s.annotations.pkl' % (split, split) 289 | # #save_path = './data/%s/%s.features.hkl' % (split, split) 290 | # save_path = h5py.File('./data/f8_data/%s/%s.h5' %(split, split),'w') 291 | # #save_path1 = './data/%s/%s.features1.hkl' % (split, split)# 292 | # annotations = load_pickle(anno_path) 293 | # image_path = list(annotations['file_name'].unique()) 294 | # n_examples = len(image_path) 295 | # 296 | # all_feats = 
np.ndarray([n_examples, 196, 512], dtype=np.float32) 297 | # #all_feats1 = np.ndarray([n_examples-45000, 196, 512], dtype=np.float32)# 298 | # 299 | # for start, end in zip(range(0, n_examples, batch_size), 300 | # range(batch_size, n_examples + batch_size, batch_size)): 301 | # image_batch_file = image_path[start:end] 302 | # image_batch = np.array(map(lambda x: ndimage.imread(x, mode='RGB'), image_batch_file)).astype( 303 | # np.float32) 304 | # feats = sess.run(vggnet.features, feed_dict={vggnet.images: image_batch}) 305 | # 306 | # all_feats[start:end, :] = feats 307 | # 308 | # print ("Processed %d %s features.." % (end, split)) 309 | # 310 | # # use hickle to save huge feature vectors 311 | # #hickle.dump(all_feats, save_path) 312 | # save_path.create_dataset('features', data=all_feats) 313 | # 314 | # print ("Saved %s.." % (save_path)) 315 | 316 | 317 | 318 | 319 | if __name__ == "__main__": 320 | main() -------------------------------------------------------------------------------- /resize.py: -------------------------------------------------------------------------------- 1 | from PIL import Image 2 | import os 3 | 4 | 5 | def resize_image(image): 6 | width, height = image.size 7 | if width > height: 8 | left = (width - height) / 2 9 | right = width - left 10 | top = 0 11 | bottom = height 12 | else: 13 | top = (height - width) / 2 14 | bottom = height - top 15 | left = 0 16 | right = width 17 | image = image.crop((left, top, right, bottom)) 18 | image = image.resize([224, 224], Image.ANTIALIAS) 19 | return image 20 | 21 | def main(): 22 | splits = ['train', 'val','test'] 23 | #splits = ['val'] 24 | for split in splits: 25 | folder = '/home/Fdisk/flcikr8/f8_%s/' %split 26 | resized_folder = './image/%sf8_resized/' %split 27 | if not os.path.exists(resized_folder): 28 | os.makedirs(resized_folder) 29 | print 'Start resizing %s images.' %split 30 | image_files = os.listdir(folder) 31 | num_images = len(image_files) 32 | for i, image_file in enumerate(image_files): 33 | with open(os.path.join(folder, image_file), 'r+b') as f: 34 | with Image.open(f) as image: 35 | image = resize_image(image) 36 | image.save(os.path.join(resized_folder, image_file), image.format) 37 | if i % 100 == 0: 38 | print 'Resized images: %d/%d' %(i, num_images) 39 | 40 | 41 | if __name__ == '__main__': 42 | main() -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import cPickle as pickle 3 | import hickle 4 | import time 5 | import os 6 | import h5py 7 | 8 | def load_pickle(path): 9 | with open(path, 'rb') as f: 10 | file = pickle.load(f) 11 | print ('Loaded %s..' 
%path) 12 | return file 13 | 14 | 15 | data_path='./data/coco_data' 16 | split='test' 17 | data_path = os.path.join(data_path, split) 18 | 19 | data = {} 20 | 21 | 22 | with open(os.path.join(data_path, '%s.file.names.pkl' %split), 'rb') as f: 23 | data['file_names'] = pickle.load(f) 24 | with open(os.path.join(data_path, '%s.captions.pkl' %split), 'rb') as f: 25 | data['captions'] = pickle.load(f) 26 | 27 | with open(os.path.join(data_path, '%s.image.idxs.pkl' %split), 'rb') as f: 28 | data['image_idxs'] = pickle.load(f) 29 | 30 | with open(os.path.join(data_path, '%s.candidate.captions_seq.pkl' %split), 'rb') as f: 31 | data['captions_seq'] = pickle.load(f) 32 | 33 | attributes_file = os.path.join(data_path, '%s.attributes.h5' %split) 34 | with h5py.File(attributes_file, 'r') as f: 35 | data['attributes'] = np.asarray(f['attributes']) 36 | 37 | 38 | 39 | print len(data['captions_seq']) 40 | for i in range(10): 41 | for j in range(20): 42 | if data['captions_seq'][i][j]//1 != 0: 43 | print 'false' 44 | print data['captions_seq'][i][j] 45 | 46 | 47 | 48 | 49 | 50 | 51 | 52 | -------------------------------------------------------------------------------- /test_model.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python2 2 | # -*- coding: utf-8 -*- 3 | """ 4 | Created on Tue Dec 12 15:23:36 2017 5 | 6 | @author: xz 7 | """ 8 | 9 | import matplotlib.pyplot as plt 10 | import cPickle as pickle 11 | import tensorflow as tf 12 | from core.solver import CaptioningSolver 13 | from core.model import CaptionGenerator 14 | from core.utils import load_coco_data 15 | from core.bleu import evaluate 16 | 17 | 18 | plt.rcParams['figure.figsize'] = (8.0, 6.0) # set default size of plots 19 | plt.rcParams['image.interpolation'] = 'nearest' 20 | plt.rcParams['image.cmap'] = 'gray' 21 | 22 | 23 | data = load_coco_data(data_path='./data/coco_data/', split='test') 24 | with open('./data/coco_data/train/word_to_idx.pkl', 'rb') as f: 25 | word_to_idx = pickle.load(f) 26 | 27 | #print '~~~~~~~~~~~~~~~~~~~~~~~' 28 | # 29 | #for i in range(data['features'].shape[0]): 30 | # 31 | # if data['file_names'][i] =='image/train2014_resized/COCO_train2014_000000013140.jpg': 32 | # print i 33 | # print data['file_names'][i] 34 | #print data['file_names'][1813] 35 | 36 | 37 | model = CaptionGenerator(word_to_idx, dim_feature=[196, 512], dim_embed=512, 38 | dim_hidden=1024, n_time_step=16, prev2out=True, 39 | ctx2out=True, alpha_c=1.0, selector=True, dropout=True) 40 | 41 | solver = CaptioningSolver(model, data, data, n_epochs=20, batch_size=128, update_rule='adam', 42 | learning_rate=0.0025, print_every=2000, save_every=1, image_path='./image/val2014_resized', 43 | pretrained_model=None, model_path='./model/preview_model/', test_model='./model/preview_model/model-20', 44 | print_bleu=False, log_path='./log/') 45 | 46 | 47 | #solver.test(data, split='val') 48 | #test = load_coco_data(data_path='./data/coco_data', split='test') 49 | #tf.get_variable_scope().reuse_variables() 50 | solver.test(data, split='test') 51 | #evaluate(data_path='./data/coco_data', split='val') 52 | evaluate(data_path='./data/coco_data', split='test') 53 | 54 | #solver.test(data, split='test') 55 | # 56 | #evaluate(data_path='./data', split='test') 57 | -------------------------------------------------------------------------------- /topic/lda_topic.py: -------------------------------------------------------------------------------- 1 | import matplotlib.pyplot as plt 2 | import numpy as np 3 | 
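# --- Editor's note ------------------------------------------------------------
# A compact sketch (an illustration, not part of the original script) of the
# topic-extraction step implemented below: all captions are folded into a
# per-image bag-of-words count matrix, an 80-topic LDA model is fit on it, and
# model.doc_topic_ (shape [n_images, n_topics], rows summing to 1) becomes the
# per-image topic distribution that the topic network is later trained to
# predict.  Relies on the numpy and lda imports of this script.
def _demo_topic_extraction(image_idx_per_caption, caption_word_ids,
                           n_images, vocab_size, n_topics=80):
    counts = np.zeros((n_images, vocab_size), dtype=np.int32)
    for img, words in zip(image_idx_per_caption, caption_word_ids):
        for w in words:                      # one count per word occurrence
            counts[img, w] += 1
    model = lda.LDA(n_topics=n_topics, n_iter=500, random_state=1)
    model.fit(counts)                        # expects non-negative integer counts
    return model.doc_topic_
# --- end editor's note ---------------------------------------------------------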
import lda 4 | import pickle 5 | from collections import defaultdict 6 | import h5py 7 | 8 | def load_pickle(path): 9 | with open(path, 'rb') as f: 10 | file = pickle.load(f) 11 | print('Loaded %s..' % path) 12 | return file 13 | 14 | 15 | # word_to_idx: Mapping dictionary from word to index 16 | def _build_caption_vector(annotations, word_to_idx, max_length=15): 17 | 18 | captions = {} 19 | for i, caption in enumerate(annotations['caption']): 20 | words = caption.split(" ") 21 | cap_vec = [] 22 | for word in words: 23 | if word in word_to_idx: 24 | cap_vec.append(word_to_idx[word]) 25 | captions[i] = np.asarray(cap_vec) 26 | print("Finished building caption vectors") 27 | return captions 28 | 29 | 30 | # file_names: Image file names of shape (82783, ) 31 | def _build_file_names(annotations): 32 | image_file_names = [] 33 | id_to_idx = {} 34 | idx = 0 35 | image_ids = annotations['image_id'] 36 | file_names = annotations['file_name'] 37 | print len(annotations['file_name']) 38 | for image_id, file_name in zip(image_ids, file_names): 39 | if not image_id in id_to_idx: 40 | id_to_idx[image_id] = idx 41 | image_file_names.append(file_name) 42 | idx += 1 43 | 44 | print idx 45 | 46 | file_names = np.asarray(image_file_names) 47 | return file_names, id_to_idx 48 | 49 | 50 | def _build_image_idxs(annotations, id_to_idx): 51 | image_idxs = np.ndarray(len(annotations['image_id']), dtype=np.int32) 52 | 53 | image_ids = annotations['image_id'] 54 | for i, image_id in enumerate(image_ids): 55 | image_idxs[i] = id_to_idx[image_id] 56 | return image_idxs 57 | 58 | 59 | def save_pickle(data, path): 60 | with open(path, 'wb') as f: 61 | pickle.dump(data, f, pickle.HIGHEST_PROTOCOL) 62 | print('Saved %s..' % path) 63 | 64 | 65 | def main(): 66 | annotations_train = load_pickle('./f8_topic/%s.annotations.pkl' % ('train')) 67 | annotations_val = load_pickle('./f8_topic/%s.annotations.pkl' % ('val')) 68 | annotations_test = load_pickle('./f8_topic/%s.annotations.pkl' % ('test')) 69 | word_to_idx = load_pickle('./f8_topic/word_to_idx.pkl') 70 | 71 | # xx = load_pickle('/home/Fdisk/imagecaption/data/f8_data/test/test.file.names.pkl') 72 | # xx1 = load_pickle('/home/Fdisk/imagecaption/data/f8_data/val/val.file.names.pkl') 73 | 74 | 75 | 76 | 77 | 78 | annotations = {} 79 | for split in ['image_id', 'file_name', 'caption']: 80 | x = annotations_train[split] 81 | x = x.tolist() 82 | y = annotations_val[split] 83 | y = y.tolist() 84 | z = annotations_test[split] 85 | z = z.tolist() 86 | for i in range(len(y)): 87 | x.append(y[i]) 88 | for j in range(len(z)): 89 | x.append(z[j]) 90 | annotations[split] = x 91 | 92 | 93 | 94 | len_vocab = len(word_to_idx) 95 | print len_vocab 96 | captions = _build_caption_vector(annotations, word_to_idx, 15) 97 | 98 | file_names, id_to_idx = _build_file_names(annotations) 99 | image_idxs = _build_image_idxs(annotations, id_to_idx) 100 | 101 | 102 | #ldac = np.zeros((82783 + 4052 + 4047, len_vocab)).astype(np.int32) 103 | ldac = np.zeros((5999 + 1000 + 1000, len_vocab)).astype(np.int32) 104 | for i in range(len(captions)): 105 | for j in range(len(captions[i])): 106 | if captions[i][j] < len_vocab: 107 | ldac[image_idxs[i]][captions[i][j]] += 1 108 | 109 | X = ldac 110 | model = lda.LDA(n_topics=80, n_iter=2000, random_state=1) 111 | model.fit(X) 112 | plt.plot(model.loglikelihoods_[5:]) 113 | doc_topic = model.doc_topic_ 114 | 115 | save_path = h5py.File('./f8_topic/f8_topic.h5','w') 116 | save_path.create_dataset('all_topic', data=doc_topic) 117 | print ("Saved %s.." 
% (save_path)) 118 | 119 | 120 | if __name__ == "__main__": 121 | main() -------------------------------------------------------------------------------- /topic/split_h5.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | import h5py 4 | import numpy as np 5 | 6 | 7 | def main(): 8 | with h5py.File('./f8_topic/f8_topic.h5', 'r') as f: 9 | topic_train = f['all_topic'][0:5999] 10 | topic_val = f['all_topic'][5999 : 5999+1000] 11 | topic_test = f['all_topic'][5999+1000 : 5999+2000] 12 | 13 | train = h5py.File("../data/f8_data/train/train.topics.h5", "w") 14 | train.create_dataset("topics", data=topic_train) 15 | 16 | val = h5py.File("../data/f8_data/val/val.topics.h5", "w") 17 | val.create_dataset("topics", data=topic_val) 18 | 19 | test = h5py.File("../data/f8_data/test/test.topics.h5", "w") 20 | test.create_dataset("topics", data=topic_test) 21 | 22 | if __name__ == '__main__': 23 | main() -------------------------------------------------------------------------------- /topic/topic_model.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python2 2 | # -*- coding: utf-8 -*- 3 | """ 4 | Created on Thu Jan 11 10:31:21 2018 5 | 6 | @author: xz 7 | """ 8 | 9 | 10 | # Imports 11 | import numpy as np 12 | import tensorflow as tf 13 | import pickle 14 | import os 15 | import hickle 16 | import h5py 17 | 18 | def sample_coco_minibatch(topic_data, feature, batch_size): 19 | data_size = feature.shape[0] 20 | mask = np.random.choice(data_size, batch_size) 21 | features = feature[mask] 22 | file_names = topic_data[mask] 23 | return features, file_names 24 | 25 | 26 | #def convLayer(x, kHeight, kWidth, strideX, strideY, 27 | # featureNum, name, padding = "SAME"): 28 | # """convlutional""" 29 | # channel = int(x.get_shape()[-1]) #获取channel数 30 | # with tf.variable_scope(name) as scope: 31 | # w = tf.get_variable("w", shape = [kHeight, kWidth, channel, featureNum]) 32 | # b = tf.get_variable("b", shape = [featureNum]) 33 | # 34 | ## x = tf.nn.conv2d(x, w, stride, padding='SAME', name='conv') 35 | ## x = tf.nn.bias_add(x, b, name='bias_add') 36 | ## x = tf.nn.relu(x, name='relu') 37 | # 38 | # 39 | # featureMap = tf.nn.conv2d(x, w, strides = [1, strideY, strideX, 1], padding = padding) 40 | # out = tf.nn.bias_add(featureMap, b) 41 | # return tf.nn.relu(out, name = scope.name) 42 | # #return tf.nn.relu(tf.reshape(out, featureMap.get_shape().as_list()), name = scope.name) 43 | # 44 | # 45 | #def maxPoolLayer(x, kHeight, kWidth, strideX, strideY, name, padding = "SAME"): 46 | # """max-pooling""" 47 | # return tf.nn.max_pool(x, ksize = [1, kHeight, kWidth, 1], 48 | # strides = [1, strideX, strideY, 1], padding = padding, name = name) 49 | # 50 | def dropout(x, keepPro, name = None): 51 | """dropout""" 52 | return tf.nn.dropout(x, keepPro, name) 53 | # 54 | #def fcLayer(x, inputD, outputD, reluFlag, name): 55 | # """fully-connect""" 56 | # with tf.variable_scope(name) as scope: 57 | # w = tf.get_variable("w", shape = [inputD, outputD], dtype = "float") 58 | # b = tf.get_variable("b", [outputD], dtype = "float") 59 | # out = tf.nn.xw_plus_b(x, w, b, name = scope.name) 60 | # if reluFlag: 61 | # return tf.nn.relu(out) 62 | # else: 63 | # return out 64 | 65 | 66 | def conv(layer_name, x, out_channels, kernel_size=[3,3], stride=[1,1,1,1], is_pretrain=True): 67 | '''Convolution op wrapper, use RELU activation after convolution 68 | Args: 69 | layer_name: e.g. conv1, pool1... 
70 | x: input tensor, [batch_size, height, width, channels] 71 | out_channels: number of output channels (or comvolutional kernels) 72 | kernel_size: the size of convolutional kernel, VGG paper used: [3,3] 73 | stride: A list of ints. 1-D of length 4. VGG paper used: [1, 1, 1, 1] 74 | is_pretrain: if load pretrained parameters, freeze all conv layers. 75 | Depending on different situations, you can just set part of conv layers to be freezed. 76 | the parameters of freezed layers will not change when training. 77 | Returns: 78 | 4D tensor 79 | ''' 80 | 81 | in_channels = x.get_shape()[-1] 82 | with tf.variable_scope(layer_name): 83 | w = tf.get_variable(name='weights', 84 | trainable=is_pretrain, 85 | shape=[kernel_size[0], kernel_size[1], in_channels, out_channels], 86 | initializer=tf.contrib.layers.xavier_initializer()) # default is uniform distribution initialization 87 | b = tf.get_variable(name='biases', 88 | trainable=is_pretrain, 89 | shape=[out_channels], 90 | initializer=tf.constant_initializer(0.0)) 91 | x = tf.nn.conv2d(x, w, stride, padding='SAME', name='conv') 92 | x = tf.nn.bias_add(x, b, name='bias_add') 93 | x = tf.nn.relu(x, name='relu') 94 | return x 95 | 96 | 97 | def pool(layer_name, x, kernel=[1,2,2,1], stride=[1,2,2,1], is_max_pool=True): 98 | '''Pooling op 99 | Args: 100 | x: input tensor 101 | kernel: pooling kernel, VGG paper used [1,2,2,1], the size of kernel is 2X2 102 | stride: stride size, VGG paper used [1,2,2,1] 103 | padding: 104 | is_max_pool: boolen 105 | if True: use max pooling 106 | else: use avg pooling 107 | ''' 108 | if is_max_pool: 109 | x = tf.nn.max_pool(x, kernel, strides=stride, padding='SAME', name=layer_name) 110 | else: 111 | x = tf.nn.avg_pool(x, kernel, strides=stride, padding='SAME', name=layer_name) 112 | return x 113 | 114 | 115 | def batch_norm(x): 116 | '''Batch normlization(I didn't include the offset and scale) 117 | ''' 118 | epsilon = 1e-3 119 | batch_mean, batch_var = tf.nn.moments(x, [0]) 120 | x = tf.nn.batch_normalization(x, 121 | mean=batch_mean, 122 | variance=batch_var, 123 | offset=None, 124 | scale=None, 125 | variance_epsilon=epsilon) 126 | return x 127 | 128 | 129 | def FC_layer(layer_name, x, out_nodes): 130 | '''Wrapper for fully connected layers with RELU activation as default 131 | Args: 132 | layer_name: e.g. 
'FC1', 'FC2' 133 | x: input feature map 134 | out_nodes: number of neurons for current FC layer 135 | ''' 136 | shape = x.get_shape() 137 | if len(shape) == 4: 138 | size = shape[1].value * shape[2].value * shape[3].value 139 | else: 140 | size = shape[-1].value 141 | 142 | with tf.variable_scope(layer_name): 143 | w = tf.get_variable('weights', 144 | shape=[size, out_nodes], 145 | initializer=tf.contrib.layers.xavier_initializer()) 146 | b = tf.get_variable('biases', 147 | shape=[out_nodes], 148 | initializer=tf.constant_initializer(0.0)) 149 | flat_x = tf.reshape(x, [-1, size]) # flatten into 1D 150 | 151 | x = tf.nn.bias_add(tf.matmul(flat_x, w), b) 152 | x = tf.nn.relu(x) 153 | return x 154 | 155 | KEEPPRO = 0.5 156 | is_pretrain = True 157 | 158 | def inference(input): 159 | #input = tf.convert_to_tensor(input) 160 | flat = tf.reshape(input, [-1,196,512,3]) 161 | x = conv('conv5_4', flat, 512, kernel_size=[3,3], stride=[1,1,1,1], is_pretrain=is_pretrain) 162 | x = pool('pool3', x, kernel=[1,2,2,1], stride=[1,2,2,1], is_max_pool=True) 163 | 164 | x = FC_layer('fc6', x, out_nodes=4096) 165 | x = dropout(x,KEEPPRO) 166 | x = batch_norm(x) 167 | x = FC_layer('fc7', x, out_nodes=1000) 168 | x = dropout(x,KEEPPRO) 169 | x = batch_norm(x) 170 | x = FC_layer('fc8', x, out_nodes=80) 171 | 172 | 173 | # fcIn = tf.reshape(pool5, [-1, 7*7*512]) 174 | # fc6 = fcLayer(fcIn, 7*7*512, 4096, True, "fc6") 175 | # dropout1 = dropout(fc6, KEEPPRO) 176 | # 177 | # fc7 = fcLayer(dropout1, 4096, 4096, True, "fc7") 178 | # dropout2 = dropout(fc7, KEEPPRO) 179 | # 180 | # 181 | # fc8 = fcLayer(dropout2, 4096, 1000, True, "fc8") 182 | # dropout3 = dropout(fc8, KEEPPRO) 183 | # 184 | # fc9 = fcLayer(dropout3,1000,80,True,"fcout") 185 | 186 | 187 | #logits = tf.layers.dense(inputs=pool5, units=80) 188 | #logits = tf.nn.dropout(logits, 0.5) 189 | #logits = tf.layers.dense(inputs=flat, units=80,activation=tf.nn.sigmoid) 190 | #logits=tf.nn.softmax(logits) 191 | return x 192 | 193 | def train(): 194 | 195 | #TODO 196 | # image_topic = [] 197 | # topic_path = './val.topics.h5' 198 | # with h5py.File(topic_path, 'r') as f: 199 | # image_topic = np.asarray(f['topics']) 200 | # print ('image_topic ok!') 201 | # # TODO 202 | # features = [] 203 | # feature_path = '../data/coco_data/val/val.h5' 204 | # with h5py.File(feature_path, 'r') as f: 205 | # features = np.asarray(f['features']) 206 | # #features = hickle.load(feature_path) 207 | # print ('features ok!') 208 | 209 | features = np.random.rand(5000, 196, 512) 210 | image_topic = np.random.rand(5000, 80) 211 | 212 | 213 | log_path = './log/' 214 | model_path = './model/' 215 | 216 | n_examples = len(features) 217 | print(n_examples) 218 | batch_size = 64 219 | n_epoch = 10 220 | save_every = 1 221 | 222 | x = tf.placeholder(tf.float32, [batch_size, 196, 512,3], name='x-input') 223 | _y = tf.placeholder(tf.float32, [batch_size, 80], name='y-input') 224 | y = inference(x) 225 | 226 | loss = tf.reduce_sum(tf.nn.sigmoid_cross_entropy_with_logits(labels=_y, logits=y)) / batch_size 227 | optimizer = tf.train.AdamOptimizer(learning_rate=0.01) 228 | # grads = tf.gradients(loss, tf.trainable_variables()) 229 | # grads_and_vars = list(zip(grads, tf.trainable_variables())) 230 | train_op = optimizer.minimize(loss=loss) 231 | n_iters_per_epoch = int(np.ceil(float(n_examples) / batch_size)) 232 | print (n_iters_per_epoch) 233 | 234 | tf.summary.scalar('loss', loss) 235 | for var in tf.trainable_variables(): 236 | tf.summary.histogram(var.op.name, var) 237 | # for grad, var in 
grads_and_vars: 238 | # tf.summary.histogram(var.op.name+'/gradient', grad) 239 | summary_op = tf.summary.merge_all() 240 | 241 | saver = tf.train.Saver() 242 | 243 | 244 | # config = tf.ConfigProto(allow_soft_placement = True) 245 | # #config.gpu_options.per_process_gpu_memory_fraction=0.9 246 | # config.gpu_options.allow_growth = True 247 | with tf.Session() as sess: 248 | tf.initialize_all_variables().run() 249 | 250 | summary_writer = tf.summary.FileWriter(log_path, graph=tf.get_default_graph()) 251 | print '-.-' 252 | for e in range(n_epoch): 253 | rand_idxs = np.random.permutation(n_examples) 254 | 255 | for i in range(n_iters_per_epoch): 256 | xs = features[rand_idxs[i * batch_size:(i + 1) * batch_size]] 257 | ys = image_topic[rand_idxs[i * batch_size:(i + 1) * batch_size]] 258 | feed_dict={x: xs, _y: ys} 259 | _, l = sess.run([train_op, loss], feed_dict) 260 | 261 | if i % 40 == 0: 262 | summary = sess.run(summary_op, feed_dict) 263 | summary_writer.add_summary(summary, e * n_iters_per_epoch + i) 264 | #print ("Processed %d features.." % (e * n_iters_per_epoch + i*batch_size)) 265 | 266 | if (e + 1) % save_every == 0: 267 | saver.save(sess, model_path+'model.ckpt', global_step=e + 1) 268 | print("model-%s saved." % (e + 1)) 269 | 270 | def test(): 271 | x = tf.placeholder(tf.float32, [None, 196,512], name='x-input') 272 | # _y = tf.placeholder(tf.float32, [None, 80], name='y-input') 273 | y = inference(x) 274 | #y = tf.sigmoid(y) 275 | #ys = tf.nn.softmax(y) 276 | 277 | features = [] 278 | feature_path = '../data/coco_data/val/val.h5' 279 | with h5py.File(feature_path, 'r') as f: 280 | features = np.asarray(f['features']) 281 | 282 | image_topic = [] 283 | topic_path = './val.topics.h5' 284 | with h5py.File(topic_path, 'r') as f: 285 | image_topic = np.asarray(f['topics']) 286 | 287 | 288 | logs_train_dir='./model/' 289 | saver = tf.train.Saver() 290 | with tf.Session() as sess: 291 | print("Reading checkpoints...") 292 | ckpt = tf.train.get_checkpoint_state(logs_train_dir) 293 | if ckpt and ckpt.model_checkpoint_path: 294 | global_step = ckpt.model_checkpoint_path.split('/')[-1].split('-')[-1] 295 | saver.restore(sess, ckpt.model_checkpoint_path) 296 | print('Loading success, global_step is %s' % global_step) 297 | else: 298 | print('No checkpoint file found') 299 | 300 | feed_dict = {x: features} 301 | y = sess.run(y,feed_dict) 302 | #print(sess.run(op_to_restore, feed_dict)) 303 | print(y[10]) 304 | print(image_topic[10]) 305 | 306 | 307 | def main(): 308 | train() 309 | 310 | if __name__ == "__main__": 311 | main() 312 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | from core.solver import CaptioningSolver 2 | from core.model import CaptionGenerator 3 | from core.utils import load_coco_data 4 | 5 | 6 | def main(): 7 | # load train dataset 8 | data = load_coco_data(data_path='./data/coco_data', split='train') 9 | word_to_idx = data['word_to_idx'] 10 | # load val dataset to print out bleu scores every epoch 11 | val_data = load_coco_data(data_path='./data/coco_data', split='val') 12 | 13 | model = CaptionGenerator(word_to_idx, dim_feature=[196, 512], dim_embed=512, 14 | dim_hidden=1024, n_time_step=16, prev2out=True, 15 | ctx2out=True, alpha_c=1.0, selector=True, dropout=True) 16 | 17 | solver = CaptioningSolver(model, data, val_data, n_epochs=20, batch_size=128, update_rule='adam', 18 | learning_rate=0.001, print_every=1000, save_every=10, 
image_path='./image/', 19 | pretrained_model=None, model_path='model/preview_model', test_model='model/lstm/model-10', 20 | print_bleu=True, log_path='log/preview_model_log/') 21 | 22 | solver.train() 23 | 24 | if __name__ == "__main__": 25 | main() -------------------------------------------------------------------------------- /zzzz/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ZhihaoZhu/Topic-Guided--Attention-For-Image-Captioning/c56cedc8cca73fd0598c60de73a344bed32d6c78/zzzz/.DS_Store -------------------------------------------------------------------------------- /zzzz/0/0.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ZhihaoZhu/Topic-Guided--Attention-For-Image-Captioning/c56cedc8cca73fd0598c60de73a344bed32d6c78/zzzz/0/0.png -------------------------------------------------------------------------------- /zzzz/0/1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ZhihaoZhu/Topic-Guided--Attention-For-Image-Captioning/c56cedc8cca73fd0598c60de73a344bed32d6c78/zzzz/0/1.png -------------------------------------------------------------------------------- /zzzz/0/2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ZhihaoZhu/Topic-Guided--Attention-For-Image-Captioning/c56cedc8cca73fd0598c60de73a344bed32d6c78/zzzz/0/2.png -------------------------------------------------------------------------------- /zzzz/0/3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ZhihaoZhu/Topic-Guided--Attention-For-Image-Captioning/c56cedc8cca73fd0598c60de73a344bed32d6c78/zzzz/0/3.png -------------------------------------------------------------------------------- /zzzz/0/4.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ZhihaoZhu/Topic-Guided--Attention-For-Image-Captioning/c56cedc8cca73fd0598c60de73a344bed32d6c78/zzzz/0/4.png -------------------------------------------------------------------------------- /zzzz/0/5.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ZhihaoZhu/Topic-Guided--Attention-For-Image-Captioning/c56cedc8cca73fd0598c60de73a344bed32d6c78/zzzz/0/5.png -------------------------------------------------------------------------------- /zzzz/0/6.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ZhihaoZhu/Topic-Guided--Attention-For-Image-Captioning/c56cedc8cca73fd0598c60de73a344bed32d6c78/zzzz/0/6.png -------------------------------------------------------------------------------- /zzzz/0/7.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ZhihaoZhu/Topic-Guided--Attention-For-Image-Captioning/c56cedc8cca73fd0598c60de73a344bed32d6c78/zzzz/0/7.png -------------------------------------------------------------------------------- /zzzz/0/8.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ZhihaoZhu/Topic-Guided--Attention-For-Image-Captioning/c56cedc8cca73fd0598c60de73a344bed32d6c78/zzzz/0/8.png -------------------------------------------------------------------------------- 
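Editor's note on topic_model.py above: its train() step pairs each image's 196x512 conv5_3 feature map (14x14 spatial positions x 512 channels) with the 80-dimensional LDA topic vector and minimizes a sigmoid cross-entropy between the predicted topic logits and those proportions, using Adam. Below is a minimal TF 1.x sketch of that single training step on toy random data, with a hypothetical one-layer stand-in for inference() (the repository's version adds an extra conv/pool block and three fully connected layers with dropout and batch norm):

import numpy as np
import tensorflow as tf  # TF 1.x, as used throughout this repository

def demo_inference(feats):                        # stand-in for inference()
    flat = tf.reshape(feats, [-1, 196 * 512])
    return tf.layers.dense(flat, 80)              # 80 topic logits

features = np.random.rand(64, 196, 512).astype(np.float32)   # toy conv5_3 maps
topics = np.random.rand(64, 80).astype(np.float32)           # toy LDA targets

x = tf.placeholder(tf.float32, [None, 196, 512])
y_ = tf.placeholder(tf.float32, [None, 80])
logits = demo_inference(x)
loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(labels=y_, logits=logits))
train_op = tf.train.AdamOptimizer(1e-4).minimize(loss)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    _, l = sess.run([train_op, loss], feed_dict={x: features, y_: topics})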
/zzzz/0/caption.txt: -------------------------------------------------------------------------------- 1 | a cat laying on top of a keyboard . -------------------------------------------------------------------------------- /zzzz/0/original.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ZhihaoZhu/Topic-Guided--Attention-For-Image-Captioning/c56cedc8cca73fd0598c60de73a344bed32d6c78/zzzz/0/original.jpg -------------------------------------------------------------------------------- /zzzz/1/0.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ZhihaoZhu/Topic-Guided--Attention-For-Image-Captioning/c56cedc8cca73fd0598c60de73a344bed32d6c78/zzzz/1/0.png -------------------------------------------------------------------------------- /zzzz/1/1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ZhihaoZhu/Topic-Guided--Attention-For-Image-Captioning/c56cedc8cca73fd0598c60de73a344bed32d6c78/zzzz/1/1.png -------------------------------------------------------------------------------- /zzzz/1/2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ZhihaoZhu/Topic-Guided--Attention-For-Image-Captioning/c56cedc8cca73fd0598c60de73a344bed32d6c78/zzzz/1/2.png -------------------------------------------------------------------------------- /zzzz/1/3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ZhihaoZhu/Topic-Guided--Attention-For-Image-Captioning/c56cedc8cca73fd0598c60de73a344bed32d6c78/zzzz/1/3.png -------------------------------------------------------------------------------- /zzzz/1/4.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ZhihaoZhu/Topic-Guided--Attention-For-Image-Captioning/c56cedc8cca73fd0598c60de73a344bed32d6c78/zzzz/1/4.png -------------------------------------------------------------------------------- /zzzz/1/5.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ZhihaoZhu/Topic-Guided--Attention-For-Image-Captioning/c56cedc8cca73fd0598c60de73a344bed32d6c78/zzzz/1/5.png -------------------------------------------------------------------------------- /zzzz/1/6.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ZhihaoZhu/Topic-Guided--Attention-For-Image-Captioning/c56cedc8cca73fd0598c60de73a344bed32d6c78/zzzz/1/6.png -------------------------------------------------------------------------------- /zzzz/1/7.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ZhihaoZhu/Topic-Guided--Attention-For-Image-Captioning/c56cedc8cca73fd0598c60de73a344bed32d6c78/zzzz/1/7.png -------------------------------------------------------------------------------- /zzzz/1/8.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ZhihaoZhu/Topic-Guided--Attention-For-Image-Captioning/c56cedc8cca73fd0598c60de73a344bed32d6c78/zzzz/1/8.png -------------------------------------------------------------------------------- /zzzz/1/caption.txt: -------------------------------------------------------------------------------- 1 | a bathroom 
with a sink mirror and shower . -------------------------------------------------------------------------------- /zzzz/1/original.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ZhihaoZhu/Topic-Guided--Attention-For-Image-Captioning/c56cedc8cca73fd0598c60de73a344bed32d6c78/zzzz/1/original.jpg -------------------------------------------------------------------------------- /zzzz/12/0.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ZhihaoZhu/Topic-Guided--Attention-For-Image-Captioning/c56cedc8cca73fd0598c60de73a344bed32d6c78/zzzz/12/0.png -------------------------------------------------------------------------------- /zzzz/12/1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ZhihaoZhu/Topic-Guided--Attention-For-Image-Captioning/c56cedc8cca73fd0598c60de73a344bed32d6c78/zzzz/12/1.png -------------------------------------------------------------------------------- /zzzz/12/2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ZhihaoZhu/Topic-Guided--Attention-For-Image-Captioning/c56cedc8cca73fd0598c60de73a344bed32d6c78/zzzz/12/2.png -------------------------------------------------------------------------------- /zzzz/12/3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ZhihaoZhu/Topic-Guided--Attention-For-Image-Captioning/c56cedc8cca73fd0598c60de73a344bed32d6c78/zzzz/12/3.png -------------------------------------------------------------------------------- /zzzz/12/4.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ZhihaoZhu/Topic-Guided--Attention-For-Image-Captioning/c56cedc8cca73fd0598c60de73a344bed32d6c78/zzzz/12/4.png -------------------------------------------------------------------------------- /zzzz/12/5.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ZhihaoZhu/Topic-Guided--Attention-For-Image-Captioning/c56cedc8cca73fd0598c60de73a344bed32d6c78/zzzz/12/5.png -------------------------------------------------------------------------------- /zzzz/12/6.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ZhihaoZhu/Topic-Guided--Attention-For-Image-Captioning/c56cedc8cca73fd0598c60de73a344bed32d6c78/zzzz/12/6.png -------------------------------------------------------------------------------- /zzzz/12/7.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ZhihaoZhu/Topic-Guided--Attention-For-Image-Captioning/c56cedc8cca73fd0598c60de73a344bed32d6c78/zzzz/12/7.png -------------------------------------------------------------------------------- /zzzz/12/8.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ZhihaoZhu/Topic-Guided--Attention-For-Image-Captioning/c56cedc8cca73fd0598c60de73a344bed32d6c78/zzzz/12/8.png -------------------------------------------------------------------------------- /zzzz/12/9.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/ZhihaoZhu/Topic-Guided--Attention-For-Image-Captioning/c56cedc8cca73fd0598c60de73a344bed32d6c78/zzzz/12/9.png -------------------------------------------------------------------------------- /zzzz/12/caption.txt: -------------------------------------------------------------------------------- 1 | a man hitting a tennis ball with a racket . -------------------------------------------------------------------------------- /zzzz/12/original.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ZhihaoZhu/Topic-Guided--Attention-For-Image-Captioning/c56cedc8cca73fd0598c60de73a344bed32d6c78/zzzz/12/original.jpg -------------------------------------------------------------------------------- /zzzz/13/0.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ZhihaoZhu/Topic-Guided--Attention-For-Image-Captioning/c56cedc8cca73fd0598c60de73a344bed32d6c78/zzzz/13/0.png -------------------------------------------------------------------------------- /zzzz/13/1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ZhihaoZhu/Topic-Guided--Attention-For-Image-Captioning/c56cedc8cca73fd0598c60de73a344bed32d6c78/zzzz/13/1.png -------------------------------------------------------------------------------- /zzzz/13/10.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ZhihaoZhu/Topic-Guided--Attention-For-Image-Captioning/c56cedc8cca73fd0598c60de73a344bed32d6c78/zzzz/13/10.png -------------------------------------------------------------------------------- /zzzz/13/11.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ZhihaoZhu/Topic-Guided--Attention-For-Image-Captioning/c56cedc8cca73fd0598c60de73a344bed32d6c78/zzzz/13/11.png -------------------------------------------------------------------------------- /zzzz/13/2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ZhihaoZhu/Topic-Guided--Attention-For-Image-Captioning/c56cedc8cca73fd0598c60de73a344bed32d6c78/zzzz/13/2.png -------------------------------------------------------------------------------- /zzzz/13/3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ZhihaoZhu/Topic-Guided--Attention-For-Image-Captioning/c56cedc8cca73fd0598c60de73a344bed32d6c78/zzzz/13/3.png -------------------------------------------------------------------------------- /zzzz/13/4.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ZhihaoZhu/Topic-Guided--Attention-For-Image-Captioning/c56cedc8cca73fd0598c60de73a344bed32d6c78/zzzz/13/4.png -------------------------------------------------------------------------------- /zzzz/13/5.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ZhihaoZhu/Topic-Guided--Attention-For-Image-Captioning/c56cedc8cca73fd0598c60de73a344bed32d6c78/zzzz/13/5.png -------------------------------------------------------------------------------- /zzzz/13/6.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/ZhihaoZhu/Topic-Guided--Attention-For-Image-Captioning/c56cedc8cca73fd0598c60de73a344bed32d6c78/zzzz/13/6.png

--------------------------------------------------------------------------------

The remaining files under /zzzz/ are binary images: the numbered .png frames and original.jpg in each example folder (13, 2, 3, 4, 5, 6, 7, 9), each hosted at
https://raw.githubusercontent.com/ZhihaoZhu/Topic-Guided--Attention-For-Image-Captioning/c56cedc8cca73fd0598c60de73a344bed32d6c78/zzzz/<folder>/<file>.
Each folder's caption.txt holds the generated caption for its example:

/zzzz/13/caption.txt:
--------------------------------------------------------------------------------
1 | a vase filled with flowers sitting on top of a table .
--------------------------------------------------------------------------------
/zzzz/2/caption.txt:
--------------------------------------------------------------------------------
1 | a large boat on the water near a lighthouse .
--------------------------------------------------------------------------------
/zzzz/3/caption.txt:
--------------------------------------------------------------------------------
1 | a man is skiing in the water with a man in the background .
--------------------------------------------------------------------------------
/zzzz/4/caption.txt:
--------------------------------------------------------------------------------
1 | a cat is sticking its head out of a red door .
--------------------------------------------------------------------------------
/zzzz/5/caption.txt:
--------------------------------------------------------------------------------
1 | a woman sitting at a table with a large pizza box .
--------------------------------------------------------------------------------
/zzzz/6/caption.txt:
--------------------------------------------------------------------------------
1 | a large white refrigerator freezer sitting in a kitchen .
--------------------------------------------------------------------------------
/zzzz/7/caption.txt:
--------------------------------------------------------------------------------
1 | a dark room with a large bed in it .
--------------------------------------------------------------------------------
/zzzz/9/caption.txt:
--------------------------------------------------------------------------------
1 | a giraffe standing next to a tree branch .
--------------------------------------------------------------------------------