├── data ├── missing_glove_ques.pkl └── Readme.md ├── README.md ├── LICENSE ├── word2glove.py ├── util.py ├── main.py ├── vgg_model.py ├── abcnn_model.py ├── data_loader.py └── blind_model.ipynb /data/missing_glove_ques.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lmelvix/visual-question-answering-tensorflow/HEAD/data/missing_glove_ques.pkl -------------------------------------------------------------------------------- /data/Readme.md: -------------------------------------------------------------------------------- 1 | # To load data files for VQA: 2 | 3 | ``` 4 | sh data_download.sh 5 | ``` 6 | 7 | Download the pre-trained pickle files from [here](https://drive.google.com/drive/folders/0BzPcu5uhlGR3Z2tQUmhEQ1ZUSG8?usp=sharing) 8 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Stacked Attention Network for VQA 2 | 3 | ## Dependencies 4 | 1. TensorFlow 5 | 2. scikit-image (skimage) 6 | 3. NumPy 7 | 8 | ## Dataset download instructions 9 | 1. cd data/ 10 | 2. Download the raw data using the data_download.sh script 11 | 3. Download the pre-trained model pickle files from [here](https://drive.google.com/drive/folders/0BzPcu5uhlGR3Z2tQUmhEQ1ZUSG8?usp=sharing) 12 | 13 | ## Execution instructions 14 | ``` 15 | python main.py 16 | ``` 17 | 18 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2018 Lenord Melvix 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /word2glove.py: -------------------------------------------------------------------------------- 1 | import os,sys 2 | import numpy as np 3 | import tensorflow as tf 4 | import data_loader as dl 5 | import pickle 6 | import skimage 7 | import skimage.io 8 | import skimage.transform 9 | from itertools import cycle 10 | from tensorflow.contrib import rnn 11 | import matplotlib.pyplot as plt 12 | 13 | def build_glove_dict(datapath, dest_file): 14 | 15 | # Check if Glove model exists 16 | if os.path.isfile(datapath): 17 | print "Glove file exists. 
Preparing dictionary" 18 | else: 19 | print "ERROR: Glove dict not found" 20 | return -1 21 | 22 | # Read glove file as dictionary 23 | glove_dict = {} 24 | progress = 0 25 | with open(datapath) as f: 26 | for line in f.readlines(): 27 | word = line.split() 28 | glove_dict[word[0]] = [float(vec) for vec in word[1:]] 29 | progress += 1 30 | 31 | if progress%1000==0: 32 | print "Processed %d vectors" % (progress) 33 | 34 | # Saving glove dictionary 35 | with open(dest_file, 'wb') as handle: 36 | pickle.dump(glove_dict, handle, protocol=pickle.HIGHEST_PROTOCOL) 37 | 38 | print "Saved glove dictionary in %s" % (dest_file) 39 | return 0 40 | 41 | def get_glove_dict(src_file): 42 | 43 | # Check if Glove dict exists 44 | if os.path.isfile(src_file): 45 | print "Glove dictionary exists. Retrieving data" 46 | else: 47 | print "ERROR: Glove dict not found" 48 | return -1 49 | 50 | with open(src_file, 'rb') as handle: 51 | glove_dict = pickle.load(handle) 52 | 53 | print "Completed Glove dict retrieval" 54 | return glove_dict 55 | 56 | def build_missing_w2g(word_dict, glove_dict, dest_file): 57 | 58 | missing_words = [] 59 | total_words = len(word_dict.keys()) 60 | progress = 0 61 | 62 | # Append words not found in Glove dict to list 63 | for word in word_dict.keys(): 64 | progress += 1 65 | if word.lower() not in glove_dict: 66 | missing_words.append(word_dict[word]) 67 | if progress%1000 == 0: 68 | print "Processed %d out of %d words" %(progress, total_words) 69 | 70 | with open(dest_file, 'wb') as handle: 71 | pickle.dump(missing_words, handle, protocol=pickle.HIGHEST_PROTOCOL) 72 | 73 | print "Saved missing words in %s" %(dest_file) 74 | return 0 75 | 76 | def is_missing_encoding(words, missing_pkl): 77 | 78 | # Load list of missing words 79 | with open(missing_pkl, 'rb') as handle: 80 | missing_list = pickle.load(handle) 81 | 82 | # Return True if any word in question does not have Glove vector 83 | if not set(words).isdisjoint(missing_list): 84 | return True 85 | else: 86 | return False 87 | 88 | def encode(word_idx, vocab, glove_dict): 89 | 90 | # Handle padding separately 91 | if word_idx not in vocab.values(): 92 | return np.zeros((len(glove_dict['the']))) 93 | else: 94 | word = vocab.keys()[vocab.values().index(word_idx)] 95 | return np.array(glove_dict[word.lower()]) 96 | 97 | def encode_onehot(word_idx, vocab, dummy): 98 | return np.eye(len(vocab.keys()))[int(word_idx)] 99 | -------------------------------------------------------------------------------- /util.py: -------------------------------------------------------------------------------- 1 | import os,sys 2 | os.environ['TF_CPP_MIN_LOG_LEVEL']='2' 3 | import numpy as np 4 | import tensorflow as tf 5 | import skimage 6 | import skimage.io 7 | import skimage.transform 8 | from itertools import cycle 9 | import word2glove as w2g 10 | import data_loader as dl 11 | import abcnn_model as abc 12 | 13 | def one_hot_ans(vqa, ans): 14 | return np.eye(len(vqa['answer_vocab'].keys()))[ans] 15 | 16 | def get_batch(sess, vqa, batch_size, mode='training'): 17 | 18 | missing_pkl = 'data/missing_glove_ques.pkl' 19 | glove_pkl = 'data/glove_6B_50.pkl' 20 | image_batch = np.array([]) 21 | image_vgg_batch = np.array([]) 22 | answer_batch = np.array([]) 23 | question_batch = np.array([]) 24 | vgg, images = dl.getVGGhandle() 25 | batch_count = 0 26 | glove_dict = w2g.get_glove_dict(glove_pkl) 27 | 28 | if mode == 'training': 29 | purpose = 'train' 30 | img_datapath = 'data/train2014' 31 | else: 32 | purpose = 'val' 33 | mode = 'validation' 34 | 
img_datapath = 'data/val2014' 35 | 36 | for data in cycle(vqa[mode]): # Changed to smaller subset 37 | vqa_id = data['image_id'] 38 | vqa_ans = data['answer'] 39 | vqa_ques = data['question'] 40 | 41 | # Skip question if it does not have Glove vector encoding 42 | if w2g.is_missing_encoding(vqa_ques, missing_pkl) == True: 43 | continue 44 | 45 | # Filter non-YES/NO/2 answers to avoid skew 46 | if vqa_ans < 3: 47 | continue 48 | 49 | # Get image and build batch 50 | img = dl.getImage(img_datapath, vqa_id, purpose) 51 | if( len(img.shape) < 3 or img.shape[2] < 3 ): 52 | continue 53 | img = skimage.transform.resize(img, (224, 224)) 54 | img = img.reshape((1, 224, 224, 3)) 55 | if image_batch.size==0: 56 | image_batch = abc.rgb_histogram(img) 57 | else: 58 | image_batch = np.concatenate((image_batch, abc.rgb_histogram(img)), axis=0) 59 | 60 | # Get VGG image features and build batch 61 | img_vgg = dl.getImageFeatures(sess, vgg, images, img.reshape((224,224,3))) 62 | if image_vgg_batch.size==0: 63 | image_vgg_batch = img_vgg 64 | else: 65 | image_vgg_batch = np.concatenate((image_vgg_batch,img_vgg),0) 66 | 67 | # Get answer and build batch 68 | if answer_batch.size==0: 69 | answer_batch = one_hot_ans(vqa, vqa_ans) 70 | else: 71 | answer_batch = np.vstack((answer_batch, one_hot_ans(vqa, vqa_ans))) 72 | 73 | # Get Glove encoded question 74 | question_word = np.array([]) 75 | for word in vqa_ques: 76 | encoded_word = w2g.encode(word, vqa['question_vocab'], glove_dict) 77 | if question_word.size==0: 78 | question_word = encoded_word 79 | else: 80 | question_word = np.dstack((question_word, encoded_word)) 81 | 82 | if question_batch.size==0: 83 | question_batch = question_word 84 | else: 85 | question_batch = np.concatenate((question_batch, question_word), 0) 86 | 87 | batch_count += 1 88 | if batch_count==batch_size: 89 | yield image_batch, image_vgg_batch, np.transpose(question_batch,(0,2,1)), answer_batch 90 | batch_count = 0 91 | image_batch = np.zeros((0, 224, 224, 3)) 92 | question_batch = np.array([]) 93 | answer_batch = np.array([]) 94 | image_vgg_batch = np.array([]) 95 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import os,sys 2 | import numpy as np 3 | import pickle 4 | import warnings 5 | 6 | import tensorflow as tf 7 | from tensorflow.contrib import rnn 8 | 9 | import data_loader as dl 10 | import word2glove as w2g 11 | import abcnn_model as abc 12 | import util 13 | 14 | def preprocess_question(): 15 | glove_source = 'data/glove.6B.50d.txt' 16 | glove_pkl = 'data/glove_6B_50.pkl' 17 | missing_pkl = 'data/missing_glove_ques.pkl' 18 | 19 | # Prepare Glove Dictionary 20 | if os.path.isfile(glove_pkl): 21 | print "Glove dictionary already exists!" 22 | else: 23 | if w2g.build_glove_dict(glove_source, glove_pkl) == 0: 24 | print "COMPLETED: Glove dictionary parsing" 25 | 26 | # Identify missing words in Glove dictionary 27 | if os.path.isfile(missing_pkl): 28 | print "Missing question vectors already processed!" 
29 | else: 30 | # Load Glove Dictionary 31 | glove_dict = w2g.get_glove_dict(glove_pkl) 32 | 33 | # Load VQA training data 34 | vqa_data = dl.load_questions_answers('data') 35 | print "COMPLETED: VQA data retrieval" 36 | ques_vocab = vqa_data['question_vocab'] 37 | if w2g.build_missing_w2g(ques_vocab, glove_dict, missing_pkl) == 0: 38 | print "COMPLETED Missing question words identification" 39 | 40 | def train(): 41 | batch_size = 10 42 | print "Starting ABC-CNN training" 43 | vqa = dl.load_questions_answers('data') 44 | 45 | # Create subset of data for over-fitting 46 | sub_vqa = {} 47 | sub_vqa['training'] = vqa['training'][:10] 48 | sub_vqa['validation'] = vqa['validation'][:10] 49 | sub_vqa['answer_vocab'] = vqa['answer_vocab'] 50 | sub_vqa['question_vocab'] = vqa['question_vocab'] 51 | sub_vqa['max_question_length'] = vqa['max_question_length'] 52 | 53 | train_size = len(vqa['training']) 54 | max_itr = (train_size // batch_size) * 10 55 | 56 | with tf.Session() as sess: 57 | image, ques, ans, optimizer, loss, accuracy = abc.model(sess, batch_size) 58 | print "Defined ABC model" 59 | 60 | train_loader = util.get_batch(sess, vqa, batch_size, 'training') 61 | print "Created train dataset generator" 62 | 63 | valid_loader = util.get_batch(sess, vqa, batch_size, 'validation') 64 | print "Created validation dataset generator" 65 | 66 | writer = abc.write_tensorboard(sess) 67 | init = tf.global_variables_initializer() 68 | merged = tf.summary.merge_all() 69 | sess.run(init) 70 | print "Initialized Tensor variables" 71 | 72 | itr = 1 73 | 74 | while itr < max_itr: 75 | run_options = tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE) 76 | run_metadata = tf.RunMetadata() 77 | 78 | _, vgg_batch, ques_batch, answer_batch = train_loader.next() 79 | _, valid_vgg_batch, valid_ques_batch, valid_answer_batch = valid_loader.next() 80 | sess.run(optimizer, feed_dict={image: vgg_batch, ques: ques_batch, ans: answer_batch}) 81 | [train_summary, train_loss, train_accuracy] = sess.run([merged, loss, accuracy], 82 | feed_dict={image: vgg_batch, ques: ques_batch, ans: answer_batch}, 83 | options=run_options, 84 | run_metadata=run_metadata) 85 | [valid_loss, valid_accuracy] = sess.run([loss, accuracy], 86 | feed_dict={image: valid_vgg_batch, 87 | ques: valid_ques_batch, 88 | ans: valid_answer_batch}) 89 | 90 | writer.add_run_metadata(run_metadata, 'step%03d' % itr) 91 | writer.add_summary(train_summary, itr) 92 | writer.flush() 93 | print "Iteration:%d\tTraining Loss:%f\tTraining Accuracy:%f\tValidation Loss:%f\tValidation Accuracy:%f"%( 94 | itr, train_loss, 100.*train_accuracy, valid_loss, 100.*valid_accuracy) 95 | itr += 1 96 | 97 | if __name__ == '__main__': 98 | warnings.filterwarnings("ignore") 99 | preprocess_question() 100 | train() 101 | 102 | -------------------------------------------------------------------------------- /vgg_model.py: -------------------------------------------------------------------------------- 1 | # source of code 2 | # https://github.com/machrisaa/tensorflow-vgg/blob/master/vgg16.py 3 | 4 | import inspect 5 | import os 6 | 7 | import numpy as np 8 | import tensorflow as tf 9 | import time 10 | 11 | VGG_MEAN = [103.939, 116.779, 123.68] 12 | 13 | 14 | class Vgg16: 15 | def __init__(self, vgg16_npy_path='data/vgg16.npy'): 16 | if vgg16_npy_path is None: 17 | path = inspect.getfile(Vgg16) 18 | path = os.path.abspath(os.path.join(path, os.pardir)) 19 | path = os.path.join(path, "vgg16.npy") 20 | vgg16_npy_path = path 21 | print(path) 22 | 23 | self.data_dict = 
np.load(vgg16_npy_path, encoding='latin1').item() 24 | print("npy file loaded") 25 | 26 | def build(self, rgb): 27 | """ 28 | load variable from npy to build the VGG 29 | :param rgb: rgb image [batch, height, width, 3] values scaled [0, 1] 30 | """ 31 | 32 | start_time = time.time() 33 | print("build model started") 34 | rgb_scaled = rgb * 255.0 35 | 36 | # Convert RGB to BGR 37 | red, green, blue = tf.split(axis=3, num_or_size_splits=3, value=rgb_scaled) 38 | assert red.get_shape().as_list()[1:] == [224, 224, 1] 39 | assert green.get_shape().as_list()[1:] == [224, 224, 1] 40 | assert blue.get_shape().as_list()[1:] == [224, 224, 1] 41 | bgr = tf.concat(axis=3, values=[ 42 | blue - VGG_MEAN[0], 43 | green - VGG_MEAN[1], 44 | red - VGG_MEAN[2], 45 | ]) 46 | assert bgr.get_shape().as_list()[1:] == [224, 224, 3] 47 | 48 | self.conv1_1 = self.conv_layer(bgr, "conv1_1") 49 | self.conv1_2 = self.conv_layer(self.conv1_1, "conv1_2") 50 | self.pool1 = self.max_pool(self.conv1_2, 'pool1') 51 | 52 | self.conv2_1 = self.conv_layer(self.pool1, "conv2_1") 53 | self.conv2_2 = self.conv_layer(self.conv2_1, "conv2_2") 54 | self.pool2 = self.max_pool(self.conv2_2, 'pool2') 55 | 56 | self.conv3_1 = self.conv_layer(self.pool2, "conv3_1") 57 | self.conv3_2 = self.conv_layer(self.conv3_1, "conv3_2") 58 | self.conv3_3 = self.conv_layer(self.conv3_2, "conv3_3") 59 | self.pool3 = self.max_pool(self.conv3_3, 'pool3') 60 | 61 | self.conv4_1 = self.conv_layer(self.pool3, "conv4_1") 62 | self.conv4_2 = self.conv_layer(self.conv4_1, "conv4_2") 63 | self.conv4_3 = self.conv_layer(self.conv4_2, "conv4_3") 64 | self.pool4 = self.max_pool(self.conv4_3, 'pool4') 65 | 66 | self.conv5_1 = self.conv_layer(self.pool4, "conv5_1") 67 | self.conv5_2 = self.conv_layer(self.conv5_1, "conv5_2") 68 | self.conv5_3 = self.conv_layer(self.conv5_2, "conv5_3") 69 | self.pool5 = self.max_pool(self.conv5_3, 'pool5') 70 | 71 | self.fc6 = self.fc_layer(self.pool5, "fc6") 72 | assert self.fc6.get_shape().as_list()[1:] == [4096] 73 | self.relu6 = tf.nn.relu(self.fc6) 74 | 75 | self.fc7 = self.fc_layer(self.relu6, "fc7") 76 | self.relu7 = tf.nn.relu(self.fc7) 77 | 78 | self.fc8 = self.fc_layer(self.relu7, "fc8") 79 | 80 | self.prob = tf.nn.softmax(self.fc8, name="prob") 81 | 82 | self.data_dict = None 83 | print(("build model finished: %ds" % (time.time() - start_time))) 84 | 85 | def avg_pool(self, bottom, name): 86 | return tf.nn.avg_pool(bottom, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='SAME', name=name) 87 | 88 | def max_pool(self, bottom, name): 89 | return tf.nn.max_pool(bottom, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='SAME', name=name) 90 | 91 | def conv_layer(self, bottom, name): 92 | with tf.variable_scope(name): 93 | filt = self.get_conv_filter(name) 94 | 95 | conv = tf.nn.conv2d(bottom, filt, [1, 1, 1, 1], padding='SAME') 96 | 97 | conv_biases = self.get_bias(name) 98 | bias = tf.nn.bias_add(conv, conv_biases) 99 | 100 | relu = tf.nn.relu(bias) 101 | return relu 102 | 103 | def fc_layer(self, bottom, name): 104 | with tf.variable_scope(name): 105 | shape = bottom.get_shape().as_list() 106 | dim = 1 107 | for d in shape[1:]: 108 | dim *= d 109 | x = tf.reshape(bottom, [-1, dim]) 110 | 111 | weights = self.get_fc_weight(name) 112 | biases = self.get_bias(name) 113 | 114 | # Fully connected layer. Note that the '+' operation automatically 115 | # broadcasts the biases. 
116 | fc = tf.nn.bias_add(tf.matmul(x, weights), biases) 117 | 118 | return fc 119 | 120 | def get_conv_filter(self, name): 121 | return tf.constant(self.data_dict[name][0], name="filter") 122 | 123 | def get_bias(self, name): 124 | return tf.constant(self.data_dict[name][1], name="biases") 125 | 126 | def get_fc_weight(self, name): 127 | return tf.constant(self.data_dict[name][0], name="weights") -------------------------------------------------------------------------------- /abcnn_model.py: -------------------------------------------------------------------------------- 1 | import os,sys 2 | import numpy as np 3 | import tensorflow as tf 4 | import data_loader as dl 5 | import pickle 6 | import skimage 7 | import skimage.io 8 | import skimage.transform 9 | from itertools import cycle 10 | import word2glove as w2g 11 | from tensorflow.contrib import rnn 12 | 13 | def model(sess, batch_size): 14 | 15 | # Placeholders for input image, question and output answer 16 | image = tf.placeholder("float", [None, 14, 14, 512], name="VGG") 17 | ques = tf.placeholder("float", [None, 22, 50], name="ques") 18 | ans = tf.placeholder("float", [None, 1000], name="ans") 19 | 20 | with tf.variable_scope('Image') as scope: 21 | img_w = tf.Variable(tf.random_normal([1,1,512,128]), name="weight") 22 | img_b = tf.Variable(tf.random_normal([128]), name="bias") 23 | img_conv = tf.nn.conv2d(image, img_w, strides=[1,1,1,1], padding='SAME') 24 | img_feature = tf.nn.relu(img_conv + img_b, name="activation") 25 | variable_summaries(img_w) 26 | 27 | # Learn semantics of question using LSTM 28 | with tf.variable_scope('Question') as scope: 29 | ques_w = tf.Variable(tf.random_normal([256, 128]), name="W_qa") 30 | ques_b = tf.Variable(tf.random_normal([128]), name="b_A") 31 | ques_sem = ques_semantics(ques, ques_w, ques_b) 32 | ques_vec = tf.reshape(ques_sem, shape=[-1, 1, 1, 128]) 33 | variable_summaries(ques_w) 34 | 35 | # Get attention probabilities 36 | with tf.variable_scope('Attention') as scope: 37 | with tf.variable_scope('embedding') as scope: 38 | visual_embed = tf.multiply(img_feature, ques_vec) 39 | variable_summaries(visual_embed) 40 | 41 | with tf.variable_scope('probability') as scope: 42 | embed_squeeze = tf.reduce_sum(tf.reduce_sum(visual_embed, axis=2), axis=1) 43 | attention_prob = tf.nn.softmax(embed_squeeze, name="p_I") 44 | 45 | # Build reduced Image feature using attention map 46 | with tf.variable_scope('ReducedImage') as scope: 47 | attention_prob = tf.reshape(attention_prob, shape=[-1,1,1,128]) 48 | rimage_feature = tf.multiply(img_feature, attention_prob) 49 | variable_summaries(rimage_feature) 50 | reduced_image_visualize = tf.reshape(tf.reduce_sum(rimage_feature, axis=3), shape=[-1, 14, 14, 1]) 51 | attention_summary = tf.summary.image('weighted-features', reduced_image_visualize) 52 | 53 | # Adding a convolution layer to reduce dimensions 54 | with tf.variable_scope('RImage') as scope: 55 | red_img_w = tf.Variable(tf.random_normal([1,1,128,8]), name="weight") 56 | red_img_b = tf.Variable(tf.random_normal([8]), name="bias") 57 | red_img_conv = tf.nn.conv2d(rimage_feature, red_img_w, strides=[1,1,1,1], padding='SAME') 58 | red_img_activ = tf.nn.relu(red_img_conv + red_img_b, name="activation") 59 | reduced_image_feature = tf.reshape(red_img_activ, shape=[-1, 1568]) 60 | variable_summaries(red_img_w) 61 | 62 | with tf.variable_scope('Image') as scope: 63 | img_w = tf.Variable(tf.random_normal([1,1,128,8]), name="weight") 64 | img_b = tf.Variable(tf.random_normal([8]), name="bias") 65 | img_conv 
= tf.nn.conv2d(img_feature, img_w, strides=[1,1,1,1], padding='SAME') 66 | img_activ = tf.nn.relu(img_conv + img_b, name="activation") 67 | image_feature = tf.reshape(img_activ, shape=[-1, 1568]) 68 | variable_summaries(img_w) 69 | 70 | 71 | # Combine all three features to build dense layer 72 | with tf.variable_scope('Dense') as scope: 73 | with tf.variable_scope('question') as scope: 74 | sem_dense_w = tf.Variable(tf.random_normal([128, 1000]), name="q_weight") 75 | variable_summaries(sem_dense_w) 76 | 77 | with tf.variable_scope('image') as scope: 78 | img_dense_w = tf.Variable(tf.random_normal([1568, 1000]), name="i_weight") 79 | variable_summaries(img_dense_w) 80 | 81 | with tf.variable_scope('attention') as scope: 82 | reduced_img_dense_w = tf.Variable(tf.random_normal([1568, 1000]), name="ri_weight") 83 | variable_summaries(reduced_img_dense_w) 84 | 85 | with tf.variable_scope('bias') as scope: 86 | dense_b = tf.Variable(tf.random_normal([1000]), name="ans_bias") 87 | variable_summaries(dense_b) 88 | 89 | dense = tf.matmul(ques_sem, sem_dense_w) + \ 90 | tf.matmul(reduced_image_feature, reduced_img_dense_w) + \ 91 | tf.matmul(image_feature, img_dense_w) + \ 92 | dense_b 93 | 94 | # Apply softmax on dense layer to get answers 95 | cross_entropy = tf.nn.softmax_cross_entropy_with_logits(logits=dense, labels=ans) 96 | mean_cross_entropy = tf.reduce_mean(cross_entropy) 97 | tf.summary.scalar('mean_cross_entropy', mean_cross_entropy) 98 | 99 | # Create Optimizer for reducing loss 100 | optimizer = tf.train.AdamOptimizer(learning_rate=0.01).minimize(mean_cross_entropy) 101 | 102 | # Evaluate model 103 | correct_prediction = tf.equal(tf.argmax(ans, 1), tf.argmax(dense, 1)) 104 | accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32)) 105 | tf.summary.scalar('accuracy', accuracy) 106 | 107 | # Return Model parameters 108 | return image, ques, ans, optimizer, mean_cross_entropy, accuracy 109 | 110 | def rgb_histogram(im): 111 | rhist,_ = np.histogram(np.reshape(im[:,:,0], (1,-1)), bins=np.arange(257)) 112 | ghist,_ = np.histogram(np.reshape(im[:,:,1], (1,-1)), bins=np.arange(257)) 113 | bhist,_ = np.histogram(np.reshape(im[:,:,2], (1, -1)), bins=np.arange(257)) 114 | hist = np.append((rhist, ghist, bhist), 0) 115 | return hist 116 | 117 | def ques_semantics(word, weight, bias): 118 | with tf.variable_scope('LSTM') as scope: 119 | word = tf.unstack(word, 22, 1) 120 | lstm_cell = rnn.BasicLSTMCell(256, forget_bias=1.0) 121 | output, states = rnn.static_rnn(lstm_cell, word, dtype=tf.float32) 122 | ques_sem = tf.matmul(states[-1], weight) + bias 123 | return tf.nn.relu(ques_sem, "ques-semantics-acitvation") 124 | 125 | def write_tensorboard(sess): 126 | writer = tf.summary.FileWriter('graph/log/', sess.graph) 127 | return writer 128 | 129 | def variable_summaries(var): 130 | with tf.name_scope('summaries'): 131 | mean = tf.reduce_mean(var) 132 | tf.summary.scalar('mean', mean) 133 | with tf.name_scope('stddev'): 134 | stddev = tf.sqrt(tf.reduce_mean(tf.square(var - mean))) 135 | tf.summary.scalar('stddev', stddev) 136 | tf.summary.scalar('max', tf.reduce_max(var)) 137 | tf.summary.scalar('min', tf.reduce_min(var)) 138 | tf.summary.histogram('histogram', var) 139 | -------------------------------------------------------------------------------- /data_loader.py: -------------------------------------------------------------------------------- 1 | import json 2 | import argparse 3 | from os.path import isfile, join 4 | import re 5 | import numpy as np 6 | import pprint 7 | import pickle 8 
| from itertools import cycle 9 | import vgg_model as vgg_model 10 | import tensorflow as tf 11 | import skimage 12 | import skimage.io 13 | import skimage.transform 14 | import os 15 | import random 16 | import re 17 | import time 18 | 19 | def load_questions_answers(data_dir): 20 | 21 | questions = None 22 | answers = None 23 | 24 | t_q_json_file = join(data_dir, 'MultipleChoice_mscoco_train2014_questions.json') 25 | t_a_json_file = join(data_dir, 'mscoco_train2014_annotations.json') 26 | 27 | v_q_json_file = join(data_dir, 'MultipleChoice_mscoco_val2014_questions.json') 28 | v_a_json_file = join(data_dir, 'mscoco_val2014_annotations.json') 29 | qa_data_file = join(data_dir, 'qa_data_file.pkl') 30 | vocab_file = join(data_dir, 'vocab_file.pkl') 31 | 32 | # IF ALREADY EXTRACTED 33 | if isfile(qa_data_file): 34 | with open(qa_data_file) as f: 35 | data = pickle.load(f) 36 | return data 37 | 38 | print "Loading Training questions" 39 | with open(t_q_json_file) as f: 40 | t_questions = json.loads(f.read()) 41 | 42 | print "Loading Training anwers" 43 | with open(t_a_json_file) as f: 44 | t_answers = json.loads(f.read()) 45 | 46 | print "Loading Val questions" 47 | with open(v_q_json_file) as f: 48 | v_questions = json.loads(f.read()) 49 | 50 | print "Loading Val answers" 51 | with open(v_a_json_file) as f: 52 | v_answers = json.loads(f.read()) 53 | 54 | 55 | print "Ans", len(t_answers['annotations']), len(v_answers['annotations']) 56 | print "Qu", len(t_questions['questions']), len(v_questions['questions']) 57 | 58 | answers = t_answers['annotations'] + v_answers['annotations'] 59 | questions = t_questions['questions'] + v_questions['questions'] 60 | 61 | answer_vocab = make_answer_vocab(answers) 62 | question_vocab, max_question_length = make_questions_vocab(questions, answers, answer_vocab) 63 | print "Max Question Length", max_question_length 64 | word_regex = re.compile(r'\w+') 65 | training_data = [] 66 | for i,question in enumerate( t_questions['questions']): 67 | ans = t_answers['annotations'][i]['multiple_choice_answer'] 68 | if ans in answer_vocab: 69 | training_data.append({ 70 | 'image_id' : t_answers['annotations'][i]['image_id'], 71 | 'question' : np.zeros(max_question_length), 72 | 'answer' : answer_vocab[ans] 73 | }) 74 | question_words = re.findall(word_regex, question['question']) 75 | 76 | base = max_question_length - len(question_words) 77 | for i in range(0, len(question_words)): 78 | training_data[-1]['question'][base + i] = question_vocab[ question_words[i] ] 79 | 80 | print "Training Data", len(training_data) 81 | val_data = [] 82 | for i,question in enumerate( v_questions['questions']): 83 | ans = v_answers['annotations'][i]['multiple_choice_answer'] 84 | if ans in answer_vocab: 85 | val_data.append({ 86 | 'image_id' : v_answers['annotations'][i]['image_id'], 87 | 'question' : np.zeros(max_question_length), 88 | 'answer' : answer_vocab[ans] 89 | }) 90 | question_words = re.findall(word_regex, question['question']) 91 | 92 | base = max_question_length - len(question_words) 93 | for i in range(0, len(question_words)): 94 | val_data[-1]['question'][base + i] = question_vocab[ question_words[i] ] 95 | 96 | print "Validation Data", len(val_data) 97 | 98 | data = { 99 | 'training' : training_data, 100 | 'validation' : val_data, 101 | 'answer_vocab' : answer_vocab, 102 | 'question_vocab' : question_vocab, 103 | 'max_question_length' : max_question_length 104 | } 105 | 106 | print "Saving qa_data" 107 | with open(qa_data_file, 'wb') as f: 108 | pickle.dump(data, f) 109 | 110 
| with open(vocab_file, 'wb') as f: 111 | vocab_data = { 112 | 'answer_vocab' : data['answer_vocab'], 113 | 'question_vocab' : data['question_vocab'], 114 | 'max_question_length' : data['max_question_length'] 115 | } 116 | pickle.dump(vocab_data, f) 117 | 118 | return data 119 | 120 | def get_question_answer_vocab(data_dir): 121 | vocab_file = join(data_dir, 'vocab_file.pkl') 122 | vocab_data = pickle.load(open(vocab_file)) 123 | return vocab_data 124 | 125 | def make_answer_vocab(answers): 126 | top_n = 1000 127 | answer_frequency = {} 128 | for annotation in answers: 129 | answer = annotation['multiple_choice_answer'] 130 | if answer in answer_frequency: 131 | answer_frequency[answer] += 1 132 | else: 133 | answer_frequency[answer] = 1 134 | 135 | answer_frequency_tuples = [ (-frequency, answer) for answer, frequency in answer_frequency.iteritems()] 136 | answer_frequency_tuples.sort() 137 | answer_frequency_tuples = answer_frequency_tuples[0:top_n-1] 138 | 139 | answer_vocab = {} 140 | for i, ans_freq in enumerate(answer_frequency_tuples): 141 | # print i, ans_freq 142 | ans = ans_freq[1] 143 | answer_vocab[ans] = i 144 | 145 | answer_vocab['UNK'] = top_n - 1 146 | return answer_vocab 147 | 148 | def make_questions_vocab(questions, answers, answer_vocab): 149 | word_regex = re.compile(r'\w+') 150 | question_frequency = {} 151 | 152 | max_question_length = 0 153 | for i,question in enumerate(questions): 154 | ans = answers[i]['multiple_choice_answer'] 155 | count = 0 156 | if ans in answer_vocab: 157 | question_words = re.findall(word_regex, question['question']) 158 | for qw in question_words: 159 | if qw in question_frequency: 160 | question_frequency[qw] += 1 161 | else: 162 | question_frequency[qw] = 1 163 | count += 1 164 | if count > max_question_length: 165 | max_question_length = count 166 | 167 | 168 | qw_freq_threhold = 0 169 | qw_tuples = [ (-frequency, qw) for qw, frequency in question_frequency.iteritems()] 170 | # qw_tuples.sort() 171 | 172 | qw_vocab = {} 173 | for i, qw_freq in enumerate(qw_tuples): 174 | frequency = -qw_freq[0] 175 | qw = qw_freq[1] 176 | # print frequency, qw 177 | if frequency > qw_freq_threhold: 178 | # +1 for accounting the zero padding for batc training 179 | qw_vocab[qw] = i + 1 180 | else: 181 | break 182 | 183 | qw_vocab['UNK'] = len(qw_vocab) + 1 184 | 185 | return qw_vocab, max_question_length 186 | 187 | def getImage(datapath, imageID, purpose='train'): 188 | name_3 = str(imageID) 189 | name_2 = '0' * (12-len(name_3)) 190 | name_1 = 'COCO_' + purpose + '2014_' 191 | fileName = name_1 + name_2 + name_3 + '.jpg' 192 | filepath = join(datapath,fileName) 193 | img = skimage.io.imread(filepath) 194 | return(img) 195 | 196 | def getImageFeatures(sess, vgg, images ,img ): 197 | # load image 198 | img = img / 255.0 199 | assert (0 <= img).all() and (img <= 1.0).all() 200 | # print "Original Image Shape: ", img.shape 201 | # we crop image from center 202 | short_edge = min(img.shape[:2]) 203 | yy = int((img.shape[0] - short_edge) / 2) 204 | xx = int((img.shape[1] - short_edge) / 2) 205 | crop_img = img[yy: yy + short_edge, xx: xx + short_edge] 206 | # resize to 224, 224 207 | resized_img = skimage.transform.resize(crop_img, (224, 224)) 208 | img_reshape = resized_img.reshape((1, 224, 224, 3)) 209 | img_feature = sess.run(vgg.pool4, feed_dict={images:img_reshape}) 210 | return(img_feature) 211 | 212 | def getVGGhandle(): 213 | images = tf.placeholder("float", [None, 224, 224, 3]) 214 | vgg = vgg_model.Vgg16() 215 | with tf.name_scope("content_vgg"): 216 
| vgg.build(images) 217 | return(vgg,images) -------------------------------------------------------------------------------- /blind_model.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": { 7 | "collapsed": true, 8 | "deletable": true, 9 | "editable": true 10 | }, 11 | "outputs": [], 12 | "source": [ 13 | "import os, sys\n", 14 | "import numpy as np\n", 15 | "from itertools import cycle\n", 16 | "import data_loader\n", 17 | "import tensorflow as tf\n", 18 | "from tensorflow.contrib import rnn\n", 19 | "import word2glove as w2g" 20 | ] 21 | }, 22 | { 23 | "cell_type": "code", 24 | "execution_count": 2, 25 | "metadata": { 26 | "collapsed": false, 27 | "deletable": true, 28 | "editable": true 29 | }, 30 | "outputs": [], 31 | "source": [ 32 | "qa_data = data_loader.load_questions_answers('data')" 33 | ] 34 | }, 35 | { 36 | "cell_type": "code", 37 | "execution_count": 3, 38 | "metadata": { 39 | "collapsed": false 40 | }, 41 | "outputs": [ 42 | { 43 | "data": { 44 | "text/plain": [ 45 | "['training',\n", 46 | " 'validation',\n", 47 | " 'answer_vocab',\n", 48 | " 'question_vocab',\n", 49 | " 'max_question_length']" 50 | ] 51 | }, 52 | "execution_count": 3, 53 | "metadata": {}, 54 | "output_type": "execute_result" 55 | } 56 | ], 57 | "source": [ 58 | "qa_data.keys()" 59 | ] 60 | }, 61 | { 62 | "cell_type": "code", 63 | "execution_count": 4, 64 | "metadata": { 65 | "collapsed": true, 66 | "deletable": true, 67 | "editable": true 68 | }, 69 | "outputs": [], 70 | "source": [ 71 | "sub_qa_data = {}\n", 72 | "sub_qa_data['training'] = qa_data['training'][:10]\n", 73 | "sub_qa_data['validation'] = qa_data['validation'][:10]\n", 74 | "s" 75 | ] 76 | }, 77 | { 78 | "cell_type": "code", 79 | "execution_count": 5, 80 | "metadata": { 81 | "collapsed": false 82 | }, 83 | "outputs": [ 84 | { 85 | "data": { 86 | "text/plain": [ 87 | "10" 88 | ] 89 | }, 90 | "execution_count": 5, 91 | "metadata": {}, 92 | "output_type": "execute_result" 93 | } 94 | ], 95 | "source": [ 96 | "len(sub_qa_data['training'])" 97 | ] 98 | }, 99 | { 100 | "cell_type": "code", 101 | "execution_count": 12, 102 | "metadata": { 103 | "collapsed": false, 104 | "deletable": true, 105 | "editable": true 106 | }, 107 | "outputs": [ 108 | { 109 | "data": { 110 | "text/plain": [ 111 | "15182" 112 | ] 113 | }, 114 | "execution_count": 12, 115 | "metadata": {}, 116 | "output_type": "execute_result" 117 | } 118 | ], 119 | "source": [ 120 | "len(w2g.encode_onehot(0., qa_data['question_vocab'],0))" 121 | ] 122 | }, 123 | { 124 | "cell_type": "code", 125 | "execution_count": 3, 126 | "metadata": { 127 | "collapsed": false, 128 | "deletable": true, 129 | "editable": true 130 | }, 131 | "outputs": [], 132 | "source": [ 133 | "#Get questions in batches\n", 134 | "def one_hot_qa(ques, ans):\n", 135 | " ques_holder = np.eye(len(qa_data['question_vocab'].keys()))\n", 136 | " ans_holder = np.eye(len(qa_data['answer_vocab'].keys()))\n", 137 | " enc_ques = np.array([])\n", 138 | " enc_ans = np.array([])\n", 139 | " \n", 140 | " for word in ques:\n", 141 | " if enc_ques.size == 0:\n", 142 | " enc_ques = ques_holder[int(word)]\n", 143 | " else:\n", 144 | " enc_ques = np.vstack((enc_ques, ques_holder[int(word)]))\n", 145 | "\n", 146 | " enc_ans = ans_holder[ans] \n", 147 | " return enc_ques, enc_ans\n", 148 | "\n", 149 | "def get_qa_batches(batch_size):\n", 150 | " ques_arr = np.array([])\n", 151 | " ans_arr = np.array([])\n", 152 | 
" counter = 0\n", 153 | " \n", 154 | " for entry in qa_data['training']:\n", 155 | " if ques_arr.size == 0:\n", 156 | " ques_arr, ans_arr = one_hot_qa(entry['question'], entry['answer'])\n", 157 | " else:\n", 158 | " next_ques, next_ans = one_hot_qa(entry['question'], entry['answer'])\n", 159 | " ques_arr = np.dstack((ques_arr, next_ques))\n", 160 | " ans_arr = np.vstack((ans_arr, next_ans))\n", 161 | " \n", 162 | " counter += 1\n", 163 | " if counter == batch_size:\n", 164 | " yield np.transpose(ques_arr, (2, 0, 1)), np.transpose(ans_arr, (0, 1))\n", 165 | " counter = 0\n", 166 | " ques_arr = np.array([])\n", 167 | " ans_arr = np.array([])" 168 | ] 169 | }, 170 | { 171 | "cell_type": "code", 172 | "execution_count": 4, 173 | "metadata": { 174 | "collapsed": false, 175 | "deletable": true, 176 | "editable": true 177 | }, 178 | "outputs": [], 179 | "source": [ 180 | "# Parameters\n", 181 | "learning_rate = 0.001\n", 182 | "training_iters = 100000\n", 183 | "batch_size = 128\n", 184 | "display_step = 10\n", 185 | "\n", 186 | "# Network Parameters\n", 187 | "n_input = 15182 # MNIST data input (img shape: 28*28)\n", 188 | "n_steps = 22 # timesteps\n", 189 | "n_hidden = 128 # hidden layer num of features\n", 190 | "n_classes = 1000 # MNIST total classes (0-9 digits)\n", 191 | "\n", 192 | "# tf Graph input\n", 193 | "x = tf.placeholder(\"float\", [None, n_steps, n_input])\n", 194 | "y = tf.placeholder(\"float\", [None, n_classes])\n", 195 | "\n", 196 | "# Define weights\n", 197 | "weights = {\n", 198 | " 'out': tf.Variable(tf.random_normal([n_hidden, n_classes]))\n", 199 | "}\n", 200 | "biases = {\n", 201 | " 'out': tf.Variable(tf.random_normal([n_classes]))\n", 202 | "}" 203 | ] 204 | }, 205 | { 206 | "cell_type": "code", 207 | "execution_count": 5, 208 | "metadata": { 209 | "collapsed": false, 210 | "deletable": true, 211 | "editable": true 212 | }, 213 | "outputs": [ 214 | { 215 | "name": "stdout", 216 | "output_type": "stream", 217 | "text": [ 218 | "(?, 22, 15182)\n", 219 | "Iter 2, Minibatch Loss= 6.661595, Training Accuracy= 0.00000\n", 220 | "Iter 3, Minibatch Loss= 9.913521, Training Accuracy= 0.34375\n", 221 | "Iter 4, Minibatch Loss= 6.378217, Training Accuracy= 0.00000\n", 222 | "Iter 5, Minibatch Loss= 6.724869, Training Accuracy= 0.00000\n", 223 | "Iter 6, Minibatch Loss= 6.659103, Training Accuracy= 0.00000\n", 224 | "Iter 7, Minibatch Loss= 6.684987, Training Accuracy= 0.03906\n", 225 | "Iter 8, Minibatch Loss= 6.183796, Training Accuracy= 0.15625\n", 226 | "Iter 9, Minibatch Loss= 7.943097, Training Accuracy= 0.22656\n", 227 | "Iter 10, Minibatch Loss= 5.878778, Training Accuracy= 0.21094\n", 228 | "Iter 11, Minibatch Loss= 6.150809, Training Accuracy= 0.01562\n", 229 | "Iter 12, Minibatch Loss= 6.257290, Training Accuracy= 0.00781\n", 230 | "Iter 13, Minibatch Loss= 6.238634, Training Accuracy= 0.00000\n", 231 | "Iter 14, Minibatch Loss= 6.305418, Training Accuracy= 0.00000\n", 232 | "Iter 15, Minibatch Loss= 6.109542, Training Accuracy= 0.02344\n", 233 | "Iter 16, Minibatch Loss= 5.931391, Training Accuracy= 0.04688\n", 234 | "Iter 17, Minibatch Loss= 5.627723, Training Accuracy= 0.05469\n", 235 | "Iter 18, Minibatch Loss= 5.325613, Training Accuracy= 0.25781\n", 236 | "Iter 19, Minibatch Loss= 5.620074, Training Accuracy= 0.16406\n", 237 | "Iter 20, Minibatch Loss= 4.658114, Training Accuracy= 0.25781\n", 238 | "Iter 21, Minibatch Loss= 4.547112, Training Accuracy= 0.32812\n", 239 | "Iter 22, Minibatch Loss= 3.832364, Training Accuracy= 0.36719\n", 240 | "Iter 23, Minibatch 
Loss= 4.039901, Training Accuracy= 0.28125\n", 241 | "Iter 24, Minibatch Loss= 4.230416, Training Accuracy= 0.22656\n", 242 | "Iter 25, Minibatch Loss= 3.941765, Training Accuracy= 0.25000\n", 243 | "Iter 26, Minibatch Loss= 3.797354, Training Accuracy= 0.28906\n", 244 | "Iter 27, Minibatch Loss= 3.363014, Training Accuracy= 0.39062\n", 245 | "Iter 28, Minibatch Loss= 3.233534, Training Accuracy= 0.39844\n", 246 | "Iter 29, Minibatch Loss= 3.338089, Training Accuracy= 0.35938\n", 247 | "Iter 30, Minibatch Loss= 3.595165, Training Accuracy= 0.30469\n", 248 | "Iter 31, Minibatch Loss= 3.467433, Training Accuracy= 0.28125\n", 249 | "Iter 32, Minibatch Loss= 3.148545, Training Accuracy= 0.32812\n", 250 | "Iter 33, Minibatch Loss= 2.911647, Training Accuracy= 0.44531\n", 251 | "Iter 34, Minibatch Loss= 3.302559, Training Accuracy= 0.28125\n", 252 | "Iter 35, Minibatch Loss= 3.269233, Training Accuracy= 0.32812\n", 253 | "Iter 36, Minibatch Loss= 3.282627, Training Accuracy= 0.39844\n", 254 | "Iter 37, Minibatch Loss= 3.386882, Training Accuracy= 0.40625\n", 255 | "Iter 38, Minibatch Loss= 3.511612, Training Accuracy= 0.32812\n", 256 | "Iter 39, Minibatch Loss= 3.732874, Training Accuracy= 0.28125\n" 257 | ] 258 | }, 259 | { 260 | "ename": "KeyboardInterrupt", 261 | "evalue": "", 262 | "output_type": "error", 263 | "traceback": [ 264 | "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", 265 | "\u001b[0;31mKeyboardInterrupt\u001b[0m Traceback (most recent call last)", 266 | "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m()\u001b[0m\n\u001b[1;32m 26\u001b[0m \u001b[0mitr\u001b[0m \u001b[0;34m+=\u001b[0m \u001b[0;36m1\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 27\u001b[0m \u001b[0mbatch_question\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0mbatch_answer\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtrain_data_generator\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mnext\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 28\u001b[0;31m \u001b[0msess\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mrun\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0moptimizer\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mfeed_dict\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m{\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mbatch_question\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0my\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mbatch_answer\u001b[0m\u001b[0;34m}\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 29\u001b[0m \u001b[0macc\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0msess\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mrun\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0maccuracy\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mfeed_dict\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m{\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mbatch_question\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0my\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mbatch_answer\u001b[0m\u001b[0;34m}\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 30\u001b[0m \u001b[0mloss\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0msess\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mrun\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mcost\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mfeed_dict\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m{\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mbatch_question\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0my\u001b[0m\u001b[0;34m:\u001b[0m 
\u001b[0mbatch_answer\u001b[0m\u001b[0;34m}\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", 267 | "\u001b[0;32m/usr/local/lib/python2.7/dist-packages/tensorflow/python/client/session.pyc\u001b[0m in \u001b[0;36mrun\u001b[0;34m(self, fetches, feed_dict, options, run_metadata)\u001b[0m\n\u001b[1;32m 776\u001b[0m \u001b[0;32mtry\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 777\u001b[0m result = self._run(None, fetches, feed_dict, options_ptr,\n\u001b[0;32m--> 778\u001b[0;31m run_metadata_ptr)\n\u001b[0m\u001b[1;32m 779\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mrun_metadata\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 780\u001b[0m \u001b[0mproto_data\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtf_session\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mTF_GetBuffer\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mrun_metadata_ptr\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", 268 | "\u001b[0;32m/usr/local/lib/python2.7/dist-packages/tensorflow/python/client/session.pyc\u001b[0m in \u001b[0;36m_run\u001b[0;34m(self, handle, fetches, feed_dict, options, run_metadata)\u001b[0m\n\u001b[1;32m 980\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mfinal_fetches\u001b[0m \u001b[0;32mor\u001b[0m \u001b[0mfinal_targets\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 981\u001b[0m results = self._do_run(handle, final_targets, final_fetches,\n\u001b[0;32m--> 982\u001b[0;31m feed_dict_string, options, run_metadata)\n\u001b[0m\u001b[1;32m 983\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 984\u001b[0m \u001b[0mresults\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", 269 | "\u001b[0;32m/usr/local/lib/python2.7/dist-packages/tensorflow/python/client/session.pyc\u001b[0m in \u001b[0;36m_do_run\u001b[0;34m(self, handle, target_list, fetch_list, feed_dict, options, run_metadata)\u001b[0m\n\u001b[1;32m 1030\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mhandle\u001b[0m \u001b[0;32mis\u001b[0m \u001b[0mNone\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1031\u001b[0m return self._do_call(_run_fn, self._session, feed_dict, fetch_list,\n\u001b[0;32m-> 1032\u001b[0;31m target_list, options, run_metadata)\n\u001b[0m\u001b[1;32m 1033\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1034\u001b[0m return self._do_call(_prun_fn, self._session, handle, feed_dict,\n", 270 | "\u001b[0;32m/usr/local/lib/python2.7/dist-packages/tensorflow/python/client/session.pyc\u001b[0m in \u001b[0;36m_do_call\u001b[0;34m(self, fn, *args)\u001b[0m\n\u001b[1;32m 1037\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0m_do_call\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mfn\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1038\u001b[0m \u001b[0;32mtry\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1039\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mfn\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 1040\u001b[0m \u001b[0;32mexcept\u001b[0m \u001b[0merrors\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mOpError\u001b[0m \u001b[0;32mas\u001b[0m 
\u001b[0me\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1041\u001b[0m \u001b[0mmessage\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mcompat\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mas_text\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0me\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmessage\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", 271 | "\u001b[0;32m/usr/local/lib/python2.7/dist-packages/tensorflow/python/client/session.pyc\u001b[0m in \u001b[0;36m_run_fn\u001b[0;34m(session, feed_dict, fetch_list, target_list, options, run_metadata)\u001b[0m\n\u001b[1;32m 1019\u001b[0m return tf_session.TF_Run(session, options,\n\u001b[1;32m 1020\u001b[0m \u001b[0mfeed_dict\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mfetch_list\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtarget_list\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1021\u001b[0;31m status, run_metadata)\n\u001b[0m\u001b[1;32m 1022\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1023\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0m_prun_fn\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0msession\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mhandle\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mfeed_dict\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mfetch_list\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", 272 | "\u001b[0;31mKeyboardInterrupt\u001b[0m: " 273 | ] 274 | } 275 | ], 276 | "source": [ 277 | "def RNN(x, weights, biases):\n", 278 | " print x.get_shape()\n", 279 | " x = tf.unstack(x, n_steps, 1)\n", 280 | " lstm_cell = rnn.BasicLSTMCell(n_hidden, forget_bias=1.0)\n", 281 | " outputs, states = rnn.static_rnn(lstm_cell, x, dtype=tf.float32)\n", 282 | " return tf.matmul(outputs[-1], weights['out']) + biases['out']\n", 283 | "\n", 284 | "pred = RNN(x, weights, biases)\n", 285 | "\n", 286 | "# Define loss and optimizer\n", 287 | "cost = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=pred, labels=y))\n", 288 | "optimizer = tf.train.AdamOptimizer(learning_rate=0.01).minimize(cost)\n", 289 | "\n", 290 | "# # Evaluate model\n", 291 | "correct_pred = tf.equal(tf.argmax(pred,1), tf.argmax(y,1))\n", 292 | "accuracy = tf.reduce_mean(tf.cast(correct_pred, tf.float32))\n", 293 | "\n", 294 | "# # Initializing the variables\n", 295 | "init = tf.global_variables_initializer()\n", 296 | "\n", 297 | "with tf.Session() as sess:\n", 298 | " sess.run(init)\n", 299 | " train_data_generator = get_qa_batches(128)\n", 300 | " itr = 1\n", 301 | " while itr < 10000:\n", 302 | " itr += 1\n", 303 | " batch_question,batch_answer = train_data_generator.next()\n", 304 | " sess.run(optimizer, feed_dict={x: batch_question, y: batch_answer})\n", 305 | " acc = sess.run(accuracy, feed_dict={x: batch_question, y: batch_answer})\n", 306 | " loss = sess.run(cost, feed_dict={x: batch_question, y: batch_answer})\n", 307 | " print \"Iter \" + str(itr) + \", Minibatch Loss= \" + \"{:.6f}\".format(loss) + \", Training Accuracy= \" + \\\n", 308 | " \"{:.5f}\".format(acc)\n" 309 | ] 310 | } 311 | ], 312 | "metadata": { 313 | "kernelspec": { 314 | "display_name": "Python 2", 315 | "language": "python", 316 | "name": "python2" 317 | }, 318 | "language_info": { 319 | "codemirror_mode": { 320 | "name": "ipython", 321 | "version": 2 322 | }, 323 | "file_extension": ".py", 324 | "mimetype": "text/x-python", 325 | "name": "python", 326 | "nbconvert_exporter": "python", 327 | "pygments_lexer": "ipython2", 328 | "version": "2.7.12" 329 | } 330 | }, 331 | "nbformat": 4, 332 | "nbformat_minor": 2 333 | } 334 | 
--------------------------------------------------------------------------------
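
Usage note: the snippet below is a minimal end-to-end sketch of how the modules above fit together — essentially a condensed version of what `main.py` does. It assumes a Python 2 / TensorFlow 1.x environment and that `data/` has already been populated (raw VQA data via `data_download.sh`, plus the `glove_6B_50.pkl` and `missing_glove_ques.pkl` pickles produced by `preprocess_question()` in `main.py` or downloaded from the link above). The `batch_size` of 10 is just an illustrative value, not a recommendation.

```python
# Minimal single-step sketch (assumes Python 2, TF 1.x, and a populated data/ folder).
import tensorflow as tf

import data_loader as dl
import abcnn_model as abc
import util

batch_size = 10                                  # illustrative value, as in main.py
vqa = dl.load_questions_answers('data')          # parses the VQA JSONs or loads qa_data_file.pkl

with tf.Session() as sess:
    # Build the attention model and get its placeholders and training ops
    image, ques, ans, optimizer, loss, accuracy = abc.model(sess, batch_size)
    sess.run(tf.global_variables_initializer())

    # get_batch is a generator: each next() yields
    # (RGB histograms, VGG pool4 features, GloVe-encoded questions, one-hot answers)
    train_loader = util.get_batch(sess, vqa, batch_size, 'training')
    _, vgg_batch, ques_batch, ans_batch = train_loader.next()

    # One optimization step, then a loss/accuracy readout on the same mini-batch
    sess.run(optimizer, feed_dict={image: vgg_batch, ques: ques_batch, ans: ans_batch})
    step_loss, step_acc = sess.run([loss, accuracy],
                                   feed_dict={image: vgg_batch, ques: ques_batch, ans: ans_batch})
    print "loss %f, accuracy %f" % (step_loss, step_acc)
```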