├── README.md
└── qa system
    ├── LSTMInferrence_qa.py
    ├── concatQA.py
    ├── dataSupport_qa.py
    ├── dumpText.py
    ├── evaluateQA.py
    ├── interactiveQA.py
    ├── lstm_qa.py
    ├── param
    │   └── params
    ├── preprocessing.py
    └── train_lstm_qa.py

/README.md:
--------------------------------------------------------------------------------
# Visual-Question-Answering
This is a demo for VQA.

The data this program needs can be downloaded from http://visualqa.org/download.html

This program depends on MXNet, which you need to install first.
The build instructions are at: http://mxnet.readthedocs.io/en/latest/how_to/build.html
--------------------------------------------------------------------------------
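A quick way to confirm the dependency is importable before running anything in this repo (a minimal sketch; the version string will vary with your install):

import mxnet as mx
print(mx.__version__)        # any recent version string
print(mx.nd.zeros((2, 3)))   # allocates a small NDArray on the CPU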
/qa system/LSTMInferrence_qa.py:
--------------------------------------------------------------------------------
import mxnet as mx
import numpy as np

from lstm_qa import lstm_inference_symbol

class LSTMInferrenceModel(object):

    def __init__(self, num_lstm_layer, input_size, num_hidden,
                 image_feats, num_embed, label_dim, arg_params,
                 images=None, feat_id_map=None, ctx=mx.cpu(), dropout=0.):

        self.sym = lstm_inference_symbol(
            num_lstm_layer, input_size, num_hidden, num_embed, label_dim, dropout
        )
        batch_size = 1
        self.image_feats = image_feats
        self.batch_size = batch_size
        # all image names, and the map from image name to feature column
        if images is not None:
            self.images = images
        if feat_id_map is not None:
            self.feat_id_map = feat_id_map

        init_c = [('l{0}_init_c'.format(i), (batch_size, num_hidden)) for i in range(num_lstm_layer)]
        init_h = [('l{0}_init_h'.format(i), (batch_size, num_hidden)) for i in range(num_lstm_layer)]
        img_batch_feats = ('img_batch_feats', (batch_size, image_feats.shape[0]))

        data_shape = [('data', (batch_size,))]
        input_shape = dict(init_c + init_h + data_shape + [img_batch_feats])

        self.executor = self.sym.simple_bind(ctx=ctx, **input_shape)

        for key in self.executor.arg_dict.keys():
            if key in arg_params:
                print('key: {0}\nvalue: {1}, executor\'s shape: {2}'.format(
                    key, arg_params[key].shape, self.executor.arg_dict[key].shape))
                arg_params[key].copyto(self.executor.arg_dict[key])

        state_name = []
        for i in range(num_lstm_layer):
            state_name.append('l{0}_init_c'.format(i))
            state_name.append('l{0}_init_h'.format(i))
        self.state_dict = dict(zip(state_name, self.executor.outputs[1:]))

        self.input_arr = mx.nd.zeros(data_shape[0][1])

    # Each forward call accepts one word index and returns the probability
    # distribution over the next word.
    def forward(self, input_data, img_id=None, new_seq=False):
        if new_seq:
            # reset the LSTM states at the start of a new sequence
            for key in self.state_dict.keys():
                self.executor.arg_dict[key][:] = 0.

        if img_id is not None:
            batch_img_feat = mx.nd.array(self.get_image_feats(img_id))
        else:
            # a single feature vector of shape (dim, 1) becomes one batch row
            batch_img_feat = mx.nd.array(self.image_feats.reshape((1, -1)))
        # feed in the image feature corresponding to the question
        batch_img_feat.copyto(self.executor.arg_dict['img_batch_feats'])
        input_data.copyto(self.executor.arg_dict['data'])

        self.executor.forward()
        # carry the new LSTM states over to the next time step
        for key in self.state_dict.keys():
            self.state_dict[key].copyto(self.executor.arg_dict[key])
        prob = self.executor.outputs[0].asnumpy()
        return prob

    def get_image_feats(self, idx):
        img_feats = np.zeros((self.batch_size, self.image_feats.shape[0]))
        if not isinstance(idx, list):  # if idx is a single index, wrap it in a list
            idx = [idx]
        for i in range(len(idx)):
            # image index -> image name -> feature column
            img_feats[i, :] = self.image_feats[:, self.feat_id_map[self.images[idx[i]]]]
        return img_feats
--------------------------------------------------------------------------------
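A minimal sketch of how this inference model is driven, mirroring the greedy loop in evaluateQA.py below. It assumes `vocabulary`, `re_vocab`, `img_feats`, `image_id`, `img_val` and `arg_params` have been built exactly as in that script, and that 'what' is in the vocabulary:

import mxnet as mx
import numpy as np
from LSTMInferrence_qa import LSTMInferrenceModel

model = LSTMInferrenceModel(
    num_lstm_layer=2, input_size=len(vocabulary), num_hidden=512,
    image_feats=img_feats, num_embed=256, label_dim=len(vocabulary),
    arg_params=arg_params, images=img_val, feat_id_map=image_id)

indata = mx.nd.zeros((1,))
indata[:] = vocabulary['what']                          # one question word per step
prob = model.forward(indata, img_id=0, new_seq=True)    # shape (1, vocab_size)
print(re_vocab[int(np.argmax(prob))])                   # most likely next word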
/qa system/concatQA.py:
--------------------------------------------------------------------------------
def concate(questionPath, answerPath, concationPath):
    questions = [q.strip() for q in open(questionPath).readlines()]
    answers = [a.strip() for a in open(answerPath).readlines()]

    i = 0
    with open(concationPath, 'wt') as writehandle:
        for q, a in zip(questions, answers):
            writehandle.write(q + ' ' + a + '\n')
            i += 1
    print('Concatenation finished! Total number: {0}.'.format(i))

if __name__ == '__main__':

    questionPath = '../data/coco_qa/questions/val/questions_val2014.txt'
    answerPath = '../data/coco_qa/answers/val/answers_val2014_modal.txt'
    concationPath = '../data/coco_qa/concateQA/concateqa_val.txt'

    concate(questionPath, answerPath, concationPath)
--------------------------------------------------------------------------------
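For illustration, one output line is simply the question and its answer joined by a space (the file contents here are hypothetical):

q = 'What is the cat doing ?'   # one line of questions_val2014.txt
a = 'sleeping'                  # the matching line of answers_val2014_modal.txt
print(q + ' ' + a)              # -> What is the cat doing ? sleeping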
/qa system/dataSupport_qa.py:
--------------------------------------------------------------------------------
# pylint: disable=C0111,too-many-arguments,too-many-instance-attributes,too-many-locals,redefined-outer-name,fixme
# pylint: disable=superfluous-parens, no-member, invalid-name
import sys
sys.path.insert(0, "../../python")
import numpy as np
import mxnet as mx

# The interface of a data iter that works for bucketing
#
# DataIter
#   - default_bucket_key: the bucket key for the default symbol.
#
# DataBatch
#   - provide_data: same as DataIter, but specific to this batch
#   - provide_label: same as DataIter, but specific to this batch
#   - bucket_key: the key for the bucket that should be used for this batch

# vocab maps each unique word to its index; this helper builds the reverse
# map (index -> word).
def revocab(words):
    revocab = {v: k for k, v in words.items()}
    return revocab

def default_text2id(sentence, the_vocab):
    words = sentence.split(' ')
    words = [the_vocab[w] for w in words if len(w) > 0]
    return words

# Default reader: each non-empty line of the file becomes one element.
# (SequencesIterQA below refers to this, but it was missing from the original file.)
def default_read_content(path):
    return [line.strip() for line in open(path).readlines() if line.strip()]

def default_gen_buckets(sentences, batch_size, the_vocab):
    len_dict = {}
    max_len = -1
    # count the number of sentences for each unique length
    for sentence in sentences:
        words = default_text2id(sentence, the_vocab)
        if len(words) == 0:
            continue
        if len(words) > max_len:
            max_len = len(words)
        if len(words) in len_dict:
            len_dict[len(words)] += 1
        else:
            len_dict[len(words)] = 1
    print('the default generated buckets:\n', len_dict)

    tl = 0
    buckets = []
    # Build buckets greedily: sentence lengths whose counts do not reach
    # batch_size are merged into the next bucket.
    for l, n in len_dict.items():  # TODO: there are better heuristics for this
        if n + tl >= batch_size:
            buckets.append(l)
            tl = 0
        else:
            tl += n
    if tl > 0:
        buckets.append(max_len)
    return buckets
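# A toy run of the bucket generation above (hypothetical sentences; lengths 2
# and 3 with batch_size=4):
#   >>> default_gen_buckets(['a b', 'a b', 'a b', 'a b', 'a b c', 'a b c'], 4,
#   ...                     {'a': 1, 'b': 2, 'c': 3})
#   [2, 3]
# Length 2 fills a whole bucket immediately; the two leftover 3-word sentences
# fall into a final bucket of size max_len = 3.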
class SimpleBatch(object):
    def __init__(self, data_names, data, label_names, label, bucket_key):
        self.data = data
        self.label = label
        self.data_names = data_names
        self.label_names = label_names
        self.bucket_key = bucket_key

        self.pad = 0
        self.index = None  # TODO: what is index?

    @property
    def provide_data(self):
        return [(n, x.shape) for n, x in zip(self.data_names, self.data)]

    @property
    def provide_label(self):
        return [(n, x.shape) for n, x in zip(self.label_names, self.label)]

class SequencesIterQA(mx.io.DataIter):
    def __init__(self, QApath, imagePath, vocab, re_vocab, image_feat,
                 feat_id_map, buckets, batch_size, init_states, batch_img_feats,
                 data_name='data', label_name='label',
                 text2id=None, read_content=None, id2text=None):
        super(SequencesIterQA, self).__init__()

        if text2id is None:
            self.text2id = default_text2id
        else:
            self.text2id = text2id
        if id2text is not None:
            self.id2text = id2text
        if read_content is None:
            self.read_content = default_read_content
        else:
            self.read_content = read_content

        # read the training set data
        sentences = self.read_content(QApath)
        images = self.read_content(imagePath)

        print('Build buckets!')
        if len(buckets) == 0:
            # generate one bucket per unique sentence length in the training data
            buckets = default_gen_buckets(sentences, batch_size, vocab)
        # the number of unique words
        self.vocab_size = len(vocab)
        self.data_name = data_name
        self.label_name = label_name
        self.image_feats = image_feat
        self.feat_id_map = feat_id_map
        self.images = images
        self.batch_img_feats = batch_img_feats

        # Each bucket corresponds to a different sentence length, so sort the buckets by length.
        buckets.sort()
        self.buckets = buckets
        # create one data list per bucket
        self.data = [[] for _ in buckets]

        # pre-allocate with the largest bucket for better memory sharing
        self.default_bucket_key = max(buckets)

        for sentence in sentences:
            sentence = self.text2id(sentence, vocab)
            if len(sentence) == 0:
                continue
            # Put each sentence into the first bucket that can hold it; after
            # this loop, self.data contains len(buckets) sentence lists.
            for i, bkt in enumerate(buckets):
                if bkt >= len(sentence):
                    self.data[i].append(sentence)
                    break
            # sentences longer than the largest bucket are simply ignored

        # convert data into ndarrays for better speed during training
        data = [np.zeros((len(x), buckets[i])) for i, x in enumerate(self.data)]
        for i_bucket in range(len(self.buckets)):
            for j in range(len(self.data[i_bucket])):
                sentence = self.data[i_bucket][j]
                data[i_bucket][j, :len(sentence)] = sentence
        self.data = data

        # Get the size of each bucket, so that we can sample uniformly from it.
        # In other words, the number of sentences in each data bucket.
        bucket_sizes = [len(x) for x in self.data]

        print("Summary of dataset ==================")
        for bkt, size in zip(buckets, bucket_sizes):
            # print the number of sentences in each bucket
            print("bucket of len %3d : %d samples" % (bkt, size))

        self.batch_size = batch_size
        self.make_data_iter_plan()

        self.init_states = init_states
        print('init_states: {0}'.format(init_states))
        self.init_state_arrays = [mx.nd.zeros(x[1]) for x in init_states]

        self.provide_data = [('data', (batch_size, self.default_bucket_key))] + init_states + [batch_img_feats]
        self.provide_label = [('softmax_label', (self.batch_size, self.default_bucket_key))]

    def make_data_iter_plan(self):
        "make a random data iteration plan"
        # truncate each bucket into a multiple of batch_size
        bucket_n_batches = []
        for i in range(len(self.data)):
            # number of whole batches available in each bucket
            bucket_n_batches.append(len(self.data[i]) // self.batch_size)
            self.data[i] = self.data[i][:int(bucket_n_batches[i] * self.batch_size)]

        # The bucket plan has the form [0...0, 1...1, 2...2, ...], where index i
        # appears once per batch drawn from bucket i; it is then shuffled.
        bucket_plan = np.hstack([np.zeros(n, int) + i for i, n in enumerate(bucket_n_batches)])
        print('-------bucket_plan---------: \n{0}'.format(bucket_plan))
        np.random.shuffle(bucket_plan)
        # For each sentence set x in self.data, generate the indices 0..len(x)-1
        # in random order; each index selects one sentence.
        bucket_idx_all = [np.random.permutation(len(x)) for x in self.data]

        self.bucket_plan = bucket_plan
        self.bucket_idx_all = bucket_idx_all
        # cursor into each bucket, starting from 0
        self.bucket_curr_idx = [0 for x in self.data]

        self.data_buffer = []
        self.label_buffer = []
        # for each bucket, initialize the data and label buffers
        for i_bucket in range(len(self.data)):
            data = np.zeros((self.batch_size, self.buckets[i_bucket]))
            label = np.zeros((self.batch_size, self.buckets[i_bucket]))
            self.data_buffer.append(data)
            self.label_buffer.append(label)
    def __iter__(self):

        # yield one batch of sentences per entry in the bucket plan
        for i_bucket in self.bucket_plan:
            data = self.data_buffer[i_bucket]
            i_idx = self.bucket_curr_idx[i_bucket]
            # take batch_size sentence indices from the shuffled index list
            idx = self.bucket_idx_all[i_bucket][i_idx:i_idx + self.batch_size]
            self.bucket_curr_idx[i_bucket] += self.batch_size

            init_state_names = [x[0] for x in self.init_states]
            data[:] = self.data[i_bucket][idx]
            self.batch_img_feats_array = self.get_image_feats(idx)
            label = self.label_buffer[i_bucket]
            # the label is the input sequence shifted left by one word
            label[:, :-1] = data[:, 1:]
            label[:, -1] = 0

            for sentence in data:
                assert len(sentence) == self.buckets[i_bucket]

            data_all = [mx.nd.array(data)] + self.init_state_arrays + [mx.nd.array(self.batch_img_feats_array)]
            label_all = [mx.nd.array(label)]
            data_names = ['data'] + init_state_names + [self.batch_img_feats[0]]
            label_names = ['softmax_label']

            data_batch = SimpleBatch(data_names, data_all, label_names, label_all, self.buckets[i_bucket])
            yield data_batch


    def reset(self):
        self.bucket_curr_idx = [0 for x in self.data]

    def get_image_feats(self, idx):
        img_feats = np.zeros((self.batch_size, self.image_feats.shape[0]))
        for i in range(len(idx)):
            # image index -> image name -> feature column
            img_feats[i, :] = self.image_feats[:, self.feat_id_map[self.images[idx[i]]]]
        return img_feats
--------------------------------------------------------------------------------
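To make the next-word target concrete, here is a tiny numpy sketch of the label shift done in `__iter__` (the token ids are made up):

import numpy as np

# One padded sentence of token ids. The label at step t is the input token at
# step t + 1; the final slot is padded with 0, the '#' end symbol.
data = np.array([[5, 9, 2, 7]])
label = np.zeros_like(data)
label[:, :-1] = data[:, 1:]
label[:, -1] = 0
print(label)   # [[9 2 7 0]]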
/qa system/dumpText.py:
--------------------------------------------------------------------------------
import operator
import argparse
import re
import json

def getModalAnswer(answers):
    # count how often each of the 10 human answers occurs and return the most frequent one
    candidates = {}
    for i in range(10):
        candidates[answers[i]['answer']] = candidates.get(answers[i]['answer'], 0) + 1
    return max(candidates.items(), key=operator.itemgetter(1))[0]

def getAllAnswer(answers):
    answer_list = []
    for i in range(10):
        answer_list.append(answers[i]['answer'])
    return ';'.join(answer_list)

def counTokens(line):
    tokens = [word for word in re.split(r'\s+|[,.!?;"()]', line) if word.strip()]
    return tokens


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('-split', type=str, default='train',
                        help='Specify which part of the dataset you want to dump to text. Your options are: train, val, test, test-dev')
    parser.add_argument('-answers', type=str, default='modal',
                        help='Specify if you want to dump just the most frequent answer for each question (modal), or all the answers (all)')
    args = parser.parse_args()

    if args.split == 'train':
        annFile = '../data/coco_qa/answers/train/mscoco_train2014_annotations.json'
        quesFile = '../data/coco_qa/questions/train/OpenEnded_mscoco_train2014_questions.json'
        questions_file = open('../data/coco_qa/questions/train/questions_train2014.txt', 'w')
        questions_id_file = open('../data/coco_qa/questions/train/questions_id_train2014.txt', 'w')
        questions_lengths_file = open('../data/coco_qa/questions/train/questions_lengths_train2014.txt', 'w')
        if args.answers == 'modal':
            answers_file = open('../data/coco_qa/answers/train/answers_train2014_modal.txt', 'w')
        elif args.answers == 'all':
            answers_file = open('../data/coco_qa/answers/train/answers_train2014_all.txt', 'w')
        coco_image_id = open('../data/coco_qa/images/train/images_train2014.txt', 'w')
        data_split = 'training data'
    elif args.split == 'val':
        annFile = '../data/coco_qa/answers/val/mscoco_val2014_annotations.json'
        quesFile = '../data/coco_qa/questions/val/OpenEnded_mscoco_val2014_questions.json'
        questions_file = open('../data/coco_qa/questions/val/questions_val2014.txt', 'w')
        questions_id_file = open('../data/coco_qa/questions/val/questions_id_val2014.txt', 'w')
        questions_lengths_file = open('../data/coco_qa/questions/val/questions_lengths_val2014.txt', 'w')
        if args.answers == 'modal':
            answers_file = open('../data/coco_qa/answers/val/answers_val2014_modal.txt', 'w')
        elif args.answers == 'all':
            answers_file = open('../data/coco_qa/answers/val/answers_val2014_all.txt', 'w')
        coco_image_id = open('../data/coco_qa/images/val/images_val2014_all.txt', 'w')
        data_split = 'validation data'
    elif args.split == 'test-dev':
        quesFile = '../data/coco_qa/questions/test_dev/OpenEnded_mscoco_test-dev2015_questions.json'
        questions_file = open('../data/coco_qa/questions/test_dev/questions_test-dev2015.txt', 'w')
        questions_id_file = open('../data/coco_qa/questions/test_dev/questions_id_test-dev2015.txt', 'w')
        questions_lengths_file = open('../data/coco_qa/questions/test_dev/questions_lengths_test-dev2015.txt', 'w')
        coco_image_id = open('../data/coco_qa/images/test_dev/images_test-dev2015.txt', 'w')
        data_split = 'test-dev data'
    elif args.split == 'test':
        quesFile = '../data/coco_qa/questions/test/OpenEnded_mscoco_test2015_questions.json'
        questions_file = open('../data/coco_qa/questions/test/questions_test2015.txt', 'w')
        questions_id_file = open('../data/coco_qa/questions/test/questions_id_test2015.txt', 'w')
        questions_lengths_file = open('../data/coco_qa/questions/test/questions_lengths_test2015.txt', 'w')
        coco_image_id = open('../data/coco_qa/images/test/images_test2015.txt', 'w')
        data_split = 'test data'
    else:
        raise RuntimeError('Incorrect split. Your choices are:\ntrain\nval\ntest-dev\ntest')
    # initialize questions and answers
    questions = json.load(open(quesFile, 'r'))
    ques = questions['questions']
    print('number of questions: {0}'.format(len(ques)))
    if args.split == 'train' or args.split == 'val':
        qa = json.load(open(annFile, 'r'))
        ans = qa['annotations']

    iterator = zip(range(len(ques)), ques)
    print('Dumping questions, answers, question IDs, image IDs, and question lengths to text files...')
    for i, q in iterator:
        questions_file.write(q['question'] + '\n')
        questions_id_file.write(str(q['question_id']) + '\n')
        questions_lengths_file.write(str(len(counTokens(q['question']))) + '\n')
        coco_image_id.write(str(q['image_id']) + '\n')
        print('id: {0}, ques_id: {1}, question: {2}, ques_lengths: {3}, image_id: {4}'.format(
            i, q['question_id'], q['question'], len(counTokens(q['question'])), q['image_id']))
        if args.split == 'train' or args.split == 'val':
            if args.answers == 'modal':
                answers_file.write(getModalAnswer(ans[i]['answers']) + '\n')
            elif args.answers == 'all':
                answers_file.write(getAllAnswer(ans[i]['answers']) + '\n')

    print('completed dumping ', data_split)

if __name__ == '__main__':
    main()
--------------------------------------------------------------------------------
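Typical invocations, run from the "qa system" directory (a sketch; the JSON files listed above must already be in place):

# python dumpText.py -split train -answers modal
# python dumpText.py -split val -answers all
# python dumpText.py -split test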
/qa system/evaluateQA.py:
--------------------------------------------------------------------------------
import mxnet as mx
import numpy as np
import matplotlib.pyplot as plt
import scipy.io
import os
import re

from PIL import Image
from LSTMInferrence_qa import LSTMInferrenceModel
from dataSupport_qa import revocab

# read each line from the target file and collect the unique words
def gen_vocab(readfile, vocab=None):
    if vocab is None:
        vocab = {}
        vocab['#'] = 0
    idx = len(vocab)
    content = read_content(readfile)
    words = []
    for line in content:
        if len(re.findall('[0-9]', line)) > 0 or len(re.findall(r"[()']", line)) > 0:
            tmp = [word for word in re.split(r'\s+|[,!?;]', line) if word.strip()]
        else:
            tmp = [word for word in re.split(r'\s+|[,!?;"()]', line) if word.strip()]
        words.extend(tmp)

    for character in words:
        if len(character) == 0:
            continue
        if character not in vocab:
            vocab[character] = idx
            idx += 1
    return vocab

# read from the file; each non-empty line is one element of the returned list
def read_content(path):
    lines = [line.strip() for line in open(path).readlines() if line.strip()]
    return lines

def text2id(sentence, the_vocab):
    if len(sentence) == 0:
        return []
    if len(re.findall('[0-9]', sentence)) > 0 or len(re.findall(r"[()']", sentence)) > 0:
        words = [word for word in re.split(r'\s+|[,!?;]', sentence) if word.strip()]
    else:
        words = [word for word in re.split(r'\s+|[,!?;"()]', sentence) if word.strip()]
    words = [the_vocab[w] for w in words if len(w) > 0]
    return words

def load_dict_data(loadfile, separator=':'):
    params = {}
    for line in open(loadfile, 'r').readlines():
        k_v = line.strip().split(separator)
        params[k_v[0]] = int(k_v[1])
    return params

def makeInput(word, word2num, arr):
    if isinstance(word, int):
        ind = word
    else:
        if len(word) == 0:
            return -1
        ind = word2num[word]

    tmp = np.zeros((1,))
    tmp[0] = ind
    arr[:] = tmp

def makeOutput(prob, num2word):
    ind = int(np.argmax(prob))
    try:
        char = num2word[ind]
    except KeyError:
        char = ''
    return char

def isEqual(correct, predict):
    if len(correct) == 0 or len(predict) == 0:
        return False
    if not isinstance(correct, list):
        if len(re.findall('[0-9]', correct)) > 0 or len(re.findall(r"[()']", correct)) > 0:
            correct = [word for word in re.split(r'\s+|[,!?;]', correct) if word.strip()]
        else:
            correct = [word for word in re.split(r'\s+|[,!?;"()]', correct) if word.strip()]
    for c, p in zip(correct, predict):
        if c != p:
            return False
    return True
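# A toy round trip through the helpers above (hand-built vocabulary, so the
# indices are illustrative):
#   >>> vocab = {'#': 0, 'what': 1, 'is': 2, 'this': 3, 'cat': 4}
#   >>> text2id('what is this ?', vocab)      # '?' is stripped by the regex
#   [1, 2, 3]
#   >>> prob = np.zeros((1, 5)); prob[0, 4] = 0.9
#   >>> makeOutput(prob, {v: k for k, v in vocab.items()})
#   'cat'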
if __name__ == '__main__':

    # these two QA files are used to build the dictionary
    qa_train = '../data/coco_qa/concateQA/concateqa_train.txt'
    qa_val = '../data/coco_qa/concateQA/concateqa_val.txt'
    questions_val = '../data/coco_qa/questions/val/questions_val2014.txt'
    answers_val = '../data/coco_qa/answers/val/answers_val2014_modal.txt'
    image_val = '../data/coco_qa/images/val/images_val2014_all.txt'
    real_images = '/media/leo/qa images/val2014/'
    vgg_feats_path = '../data/coco_qa/image_fatures/vgg_feats.mat'
    image_map_ids = '../data/coco_qa/image_fatures/coco_vgg_IDMap.txt'
    params_file = './param/params'

    ques_val = read_content(questions_val)
    ans_val = read_content(answers_val)
    img_val = read_content(image_val)
    image_names = os.listdir(real_images)

    # build the vocabulary from the questions and answers of the training and validation sets
    vocabulary = gen_vocab(qa_train)
    print('only train: {0}'.format(len(vocabulary)))
    vocabulary = gen_vocab(qa_val, vocabulary)
    print('train and val: {0}'.format(len(vocabulary)))
    re_vocab = revocab(vocabulary)
    params = load_dict_data(params_file)

    # load image features
    img_feats = scipy.io.loadmat(vgg_feats_path)['feats']
    images_ids = read_content(image_map_ids)
    image_id = {}  # key is str, value is int
    for line in images_ids:
        tmp = line.split(' ')
        image_id[tmp[0]] = int(tmp[1])

    print('vocabulary:\n{0}'.format(len(vocabulary)))

    # load the model from a checkpoint; the epoch must match a saved checkpoint
    # (11 here, as in interactiveQA.py)
    load_epoch = 11
    __, arg_params, __ = mx.model.load_checkpoint('./model/QAmodel', load_epoch)

    # build an inference model
    model = LSTMInferrenceModel(
        num_lstm_layer=params['num_lstm_layer'], input_size=len(vocabulary),
        num_hidden=params['num_hidden'], image_feats=img_feats, num_embed=params['num_embed'],
        images=img_val, feat_id_map=image_id, label_dim=len(vocabulary),
        arg_params=arg_params, dropout=0.5
    )

    endchar = '#'
    accuracy = 0
    for i in range(len(ques_val)):
        q = ques_val[i]
        a = ans_val[i]
        img = img_val[i]

        indata = mx.nd.zeros((1,))
        next_char = ''
        t = 0
        newSentence = True
        ques = text2id(q, vocabulary)
        outputs = []
        ignore_length = len(ques)
        # produce the predicted answer
        while next_char != endchar:

            if t <= ignore_length - 1:
                next_char = ques[t]
            else:
                next_char = outputs[-1]
            makeInput(next_char, vocabulary, indata)

            # i is the index of the current validation image
            prob = model.forward(indata, i, newSentence)
            newSentence = False

            if t >= ignore_length - 1:
                next_char = makeOutput(prob, re_vocab)
                if next_char == '#':
                    break
                outputs.append(next_char)
            t += 1
        # count the correct predictions
        if isEqual(a, outputs):
            accuracy += 1

        # show the current validation image
        target_real_image = real_images
        for img_name in image_names:
            if img in img_name:
                target_real_image += img_name
                break
        image = Image.open(target_real_image)
        plt.figure(target_real_image.split('.')[0])
        plt.title('answers: {0}\npredicted answers: {1}'.format(a, ' '.join(outputs)))
        plt.imshow(image)
        plt.show()

    # print the accuracy on the validation set
    print('The accuracy on validation: {0}%'.format(100.0 * accuracy / len(ques_val)))
--------------------------------------------------------------------------------
/qa system/interactiveQA.py:
--------------------------------------------------------------------------------
import mxnet as mx
import numpy as np
import matplotlib.pyplot as plt
from preprocessing import extract_feature
import re

from PIL import Image
from LSTMInferrence_qa import LSTMInferrenceModel
from dataSupport_qa import revocab

# read each line from the target file and collect the unique words
def gen_vocab(readfile, vocab=None):
    if vocab is None:
        vocab = {}
        vocab['#'] = 0
    idx = len(vocab)
    content = read_content(readfile)
    words = []
    for line in content:
        if len(re.findall('[0-9]', line)) > 0 or len(re.findall(r"[()']", line)) > 0:
            tmp = [word for word in re.split(r'\s+|[,!?;]', line) if word.strip()]
        else:
            tmp = [word for word in re.split(r'\s+|[,!?;"()]', line) if word.strip()]
        words.extend(tmp)

    for character in words:
        if len(character) == 0:
            continue
        if character not in vocab:
            vocab[character] = idx
            idx += 1
    return vocab

# read from the file; each non-empty line is one element of the returned list
def read_content(path):
    lines = [line.strip() for line in open(path).readlines() if line.strip()]
    return lines

def text2id(sentence, the_vocab):
    if len(sentence) == 0:
        return []
    if len(re.findall('[0-9]', sentence)) > 0 or len(re.findall(r"[()']", sentence)) > 0:
        words = [word for word in re.split(r'\s+|[,!?;]', sentence) if word.strip()]
    else:
        words = [word for word in re.split(r'\s+|[,!?;"()]', sentence) if word.strip()]
    words = [the_vocab[w] for w in words if len(w) > 0]
    return words

def load_dict_data(loadfile, separator=':'):
    params = {}
    for line in open(loadfile, 'r').readlines():
        k_v = line.strip().split(separator)
        params[k_v[0]] = int(k_v[1])
    return params

def makeInput(word, word2num, arr):
    if isinstance(word, int):
        ind = word
    else:
        if len(word) == 0:
            return -1
        ind = word2num[word]

    tmp = np.zeros((1,))
    tmp[0] = ind
    arr[:] = tmp

def makeOutput(prob, num2word):
    ind = int(np.argmax(prob))
    try:
        char = num2word[ind]
    except KeyError:
        char = ''
    return char


if __name__ == '__main__':

    # these two QA files are used to build the dictionary
    qa_train = '../data/coco_qa/concateQA/concateqa_train.txt'
    qa_val = '../data/coco_qa/concateQA/concateqa_val.txt'
    params_file = './param/params'

    # build the vocabulary from the questions and answers of the training and validation sets
    vocabulary = gen_vocab(qa_train)
    print('only train: {0}'.format(len(vocabulary)))
    vocabulary = gen_vocab(qa_val, vocabulary)
    print('train and val: {0}'.format(len(vocabulary)))
    re_vocab = revocab(vocabulary)
    params = load_dict_data(params_file)

    # load the CNN model used to extract the image feature
    load_cnn_epoch = 1
    model_prefix = './model/cnn_model/Inception-7'
    cnn_model = mx.model.FeedForward.load(model_prefix, load_cnn_epoch, ctx=mx.gpu())
    internals = cnn_model.symbol.get_internals()
    fea_symbol = internals['flatten_output']
    feature_extractor = mx.model.FeedForward(ctx=mx.gpu(), symbol=fea_symbol, numpy_batch_size=1,
                                             arg_params=cnn_model.arg_params, aux_params=cnn_model.aux_params,
                                             allow_extra_params=True)

    image_path = input('please input the image path: ').strip()
    # path = '/media/leo/新加卷/qa images/train2014/COCO_train2014_000000000009.jpg'  # used for testing
    img_feat = extract_feature(image_path, feature_extractor)
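    # A guard for the feature-dimension mismatch described in the comments
    # below (assumption: the checkpoint was trained on 4096-dimensional VGG
    # features; adjust expected_dim if your checkpoint differs).
    expected_dim = 4096
    assert img_feat.shape[0] == expected_dim, \
        'feature dim {0} != trained dim {1}; re-extract features or retrain'.format(
            img_feat.shape[0], expected_dim)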
    # load the LSTM model from a checkpoint
    load_lstm_epoch = 11
    __, lstm_arg_params, __ = mx.model.load_checkpoint('./model/QAmodel', load_lstm_epoch)

    # build an inference model
    model = LSTMInferrenceModel(
        num_lstm_layer=params['num_lstm_layer'], input_size=len(vocabulary),
        num_hidden=params['num_hidden'], image_feats=img_feat, num_embed=params['num_embed'],
        label_dim=len(vocabulary), arg_params=lstm_arg_params, dropout=0.5
    )
    # This may raise an error: the trained LSTM network used image features
    # produced by VGGNet, whose feature dimension is 4096, while this program
    # uses GoogLeNet (Inception-7) to extract image features, which produces a
    # 2048-dimensional feature vector.

    # If you want to run this program interactively, I suggest you first use
    # this GoogLeNet to extract all training and testing image features, then
    # train a new LSTM network on those features. After that, this interactive
    # program will work correctly.

    while True:
        question = input('please input a question: ').strip()
        endchar = '#'

        indata = mx.nd.zeros((1,))
        next_char = ''
        i = 0
        newSentence = True
        ques = text2id(question, vocabulary)
        outputs = []
        ignore_length = len(ques)
        # produce the predicted answer
        while next_char != endchar:

            if i <= ignore_length - 1:
                next_char = ques[i]
            else:
                next_char = outputs[-1]
            makeInput(next_char, vocabulary, indata)

            prob = model.forward(indata, new_seq=newSentence)
            newSentence = False

            if i >= ignore_length - 1:
                next_char = makeOutput(prob, re_vocab)
                if next_char == '#':
                    break
                outputs.append(next_char)
            i += 1

        # show the current image and the predicted answer
        img = Image.open(image_path)
        plt.figure(image_path.split('.')[0])
        plt.title('predicted answer: {0}'.format(' '.join(outputs)))
        plt.imshow(img)
        plt.show()
--------------------------------------------------------------------------------
/qa system/lstm_qa.py:
--------------------------------------------------------------------------------
# pylint:skip-file
import sys
sys.path.insert(0, "../../python")
import mxnet as mx

from collections import namedtuple

LSTMState = namedtuple("LSTMState", ["c", "h"])
LSTMParam2 = namedtuple("LSTMParam2", ["i2h_weight_f", "i2h_bias_f", "h2h_weight_f", "h2h_bias_f",
                                       "i2h_weight_i", "i2h_bias_i", "h2h_weight_i", "h2h_bias_i",
                                       "i2h_weight_t", "i2h_bias_t", "h2h_weight_t", "h2h_bias_t",
                                       "i2h_weight_o", "i2h_bias_o", "h2h_weight_o", "h2h_bias_o",
                                       ])

def lstm(num_hidden, indata, prev_state, param, seqidx, layeridx, dropout=0.):
    """LSTM cell symbol"""
    if dropout > 0:
        indata = mx.sym.Dropout(data=indata, p=dropout)
    # forget gate
    i2h_f = mx.sym.FullyConnected(data=indata, weight=param.i2h_weight_f, bias=param.i2h_bias_f,
                                  num_hidden=num_hidden, name="t{0}_l{1}_i2h_f".format(seqidx, layeridx))
    h2h_f = mx.sym.FullyConnected(data=prev_state.h, weight=param.h2h_weight_f, bias=param.h2h_bias_f,
                                  num_hidden=num_hidden, name="t{0}_l{1}_h2h_f".format(seqidx, layeridx))
    forget_gate = mx.sym.Activation(i2h_f + h2h_f, act_type="sigmoid")
    # input gate
    i2h_i = mx.sym.FullyConnected(data=indata, weight=param.i2h_weight_i, bias=param.i2h_bias_i,
                                  num_hidden=num_hidden, name="t{0}_l{1}_i2h_i".format(seqidx, layeridx))
    h2h_i = mx.sym.FullyConnected(data=prev_state.h, weight=param.h2h_weight_i, bias=param.h2h_bias_i,
                                  num_hidden=num_hidden, name="t{0}_l{1}_h2h_i".format(seqidx, layeridx))
    in_gate = mx.sym.Activation(i2h_i + h2h_i, act_type="sigmoid")
    # transform (candidate) gate -- note: a sigmoid is used here, where a
    # textbook LSTM would use tanh
    i2h_t = mx.sym.FullyConnected(data=indata, weight=param.i2h_weight_t, bias=param.i2h_bias_t,
                                  num_hidden=num_hidden, name="t{0}_l{1}_i2h_t".format(seqidx, layeridx))
    h2h_t = mx.sym.FullyConnected(data=prev_state.h, weight=param.h2h_weight_t, bias=param.h2h_bias_t,
                                  num_hidden=num_hidden, name="t{0}_l{1}_h2h_t".format(seqidx, layeridx))
    transform_gate = mx.sym.Activation(i2h_t + h2h_t, act_type="sigmoid")
    # output gate
    i2h_o = mx.sym.FullyConnected(data=indata, weight=param.i2h_weight_o, bias=param.i2h_bias_o,
                                  num_hidden=num_hidden, name="t{0}_l{1}_i2h_o".format(seqidx, layeridx))
    h2h_o = mx.sym.FullyConnected(data=prev_state.h, weight=param.h2h_weight_o, bias=param.h2h_bias_o,
                                  num_hidden=num_hidden, name="t{0}_l{1}_h2h_o".format(seqidx, layeridx))
    out_gate = mx.sym.Activation(i2h_o + h2h_o, act_type="sigmoid")
    next_c = (forget_gate * prev_state.c) + (in_gate * transform_gate)
    next_h = out_gate * mx.sym.Activation(next_c, act_type="tanh")
    return LSTMState(c=next_c, h=next_h)
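# For reference, one step of the cell above computes (note the sigmoid on the
# candidate g_t, as flagged in the comment above):
#   f_t = sigmoid(W_xf x_t + W_hf h_{t-1} + b_f)    forget gate
#   i_t = sigmoid(W_xi x_t + W_hi h_{t-1} + b_i)    input gate
#   g_t = sigmoid(W_xg x_t + W_hg h_{t-1} + b_g)    transform (candidate) gate
#   o_t = sigmoid(W_xo x_t + W_ho h_{t-1} + b_o)    output gate
#   c_t = f_t * c_{t-1} + i_t * g_t
#   h_t = o_t * tanh(c_t)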

# We define a new unrolling function here because the original one in lstm.py
# concats all the labels at the last layer together, making the mini-batch
# size of the label different from the data. I think the existing
# data-parallelization code needs some modification for this situation to work
# properly.
def unroll_lstm(num_lstm_layer, seq_len, input_size, batch_size,
                num_hidden, num_embed, num_label, dropout=0.):

    embed_weight = mx.sym.Variable("embed_weight")
    pred_weight = mx.sym.Variable("pred_weight")
    pred_bias = mx.sym.Variable("pred_bias")
    param_cells = []
    last_states = []
    for i in range(num_lstm_layer):
        param_cells.append(LSTMParam2(i2h_weight_f=mx.sym.Variable("l{0}_i2h_f_weight".format(i)),
                                      i2h_bias_f=mx.sym.Variable("l{0}_i2h_f_bias".format(i)),
                                      h2h_weight_f=mx.sym.Variable("l{0}_h2h_f_weight".format(i)),
                                      h2h_bias_f=mx.sym.Variable("l{0}_h2h_f_bias".format(i)),
                                      i2h_weight_i=mx.sym.Variable("l{0}_i2h_i_weight".format(i)),
                                      i2h_bias_i=mx.sym.Variable("l{0}_i2h_i_bias".format(i)),
                                      h2h_weight_i=mx.sym.Variable("l{0}_h2h_i_weight".format(i)),
                                      h2h_bias_i=mx.sym.Variable("l{0}_h2h_i_bias".format(i)),
                                      i2h_weight_t=mx.sym.Variable("l{0}_i2h_t_weight".format(i)),
                                      i2h_bias_t=mx.sym.Variable("l{0}_i2h_t_bias".format(i)),
                                      h2h_weight_t=mx.sym.Variable("l{0}_h2h_t_weight".format(i)),
                                      h2h_bias_t=mx.sym.Variable("l{0}_h2h_t_bias".format(i)),
                                      i2h_weight_o=mx.sym.Variable("l{0}_i2h_o_weight".format(i)),
                                      i2h_bias_o=mx.sym.Variable("l{0}_i2h_o_bias".format(i)),
                                      h2h_weight_o=mx.sym.Variable("l{0}_h2h_o_weight".format(i)),
                                      h2h_bias_o=mx.sym.Variable("l{0}_h2h_o_bias".format(i)),
                                      ))

        state = LSTMState(c=mx.sym.Variable("l%d_init_c" % i),
                          h=mx.sym.Variable("l%d_init_h" % i))
        last_states.append(state)
    assert(len(last_states) == num_lstm_layer)

    # embedding layer
    data = mx.sym.Variable('data')
    label = mx.sym.Variable('softmax_label')
    img_feat = mx.sym.Variable('img_batch_feats')
    embed = mx.sym.Embedding(data=data, input_dim=input_size,
                             weight=embed_weight, output_dim=num_embed, name='embed')
    wordvec = mx.sym.SliceChannel(data=embed, num_outputs=seq_len, squeeze_axis=1)

    hidden_all = []
    # hidden_all collects the hidden states in input order
    for seqidx in range(seq_len):
        # the image feature is concatenated with the word vector at every time step
        hidden = mx.sym.Concat(*[img_feat, wordvec[seqidx]], dim=1)

        # stack LSTM
        for i in range(num_lstm_layer):
            if i == 0:
                dp_ratio = 0.
            else:
                dp_ratio = dropout
            next_state = lstm(num_hidden, indata=hidden,
                              prev_state=last_states[i],
                              param=param_cells[i],
                              seqidx=seqidx, layeridx=i, dropout=dp_ratio)
            hidden = next_state.h
            last_states[i] = next_state
        # decoder
        if dropout > 0.:
            hidden = mx.sym.Dropout(data=hidden, p=dropout)
        hidden_all.append(hidden)

    hidden_concat = mx.sym.Concat(*hidden_all, dim=0)
    pred = mx.sym.FullyConnected(data=hidden_concat, num_hidden=num_label,
                                 weight=pred_weight, bias=pred_bias, name='pred')

    ################################################################################
    # Make the label the same shape as the produced data path.
    # I did not observe a big speed difference between the following two ways:
    # label = mx.sym.transpose(data=label)
    # label = mx.sym.Reshape(data=label, shape=(-1,))

    # To stay consistent with the (time-major) order of hidden_all, concatenate
    # the label slices in the same order.
    label_slice = mx.sym.SliceChannel(data=label, num_outputs=seq_len)
    label = [label_slice[t] for t in range(seq_len)]
    label = mx.sym.Concat(*label, dim=0)
    label = mx.sym.Reshape(data=label, shape=(-1,))
    ################################################################################

    sm = mx.sym.SoftmaxOutput(data=pred, label=label, name='softmax')

    return sm
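# A quick shape sanity check of the unrolled symbol (hypothetical sizes:
# vocabulary 7000, 4096-dimensional VGG features; hyper-parameters as in
# train_lstm_qa.py):
#   >>> sym = unroll_lstm(num_lstm_layer=2, seq_len=25, input_size=7000,
#   ...                   batch_size=32, num_hidden=512, num_embed=256,
#   ...                   num_label=7000, dropout=0.5)
#   >>> _, out_shapes, _ = sym.infer_shape(
#   ...     data=(32, 25), softmax_label=(32, 25), img_batch_feats=(32, 4096),
#   ...     l0_init_c=(32, 512), l0_init_h=(32, 512),
#   ...     l1_init_c=(32, 512), l1_init_h=(32, 512))
#   >>> out_shapes
#   [(800, 7000)]
# i.e. one softmax row per (time step, batch element) pair: 25 * 32 = 800.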
def lstm_inference_symbol(num_lstm_layer, input_size,
                          num_hidden, num_embed, num_label, dropout=0.):
    seqidx = 0
    embed_weight = mx.sym.Variable("embed_weight")
    pred_weight = mx.sym.Variable("pred_weight")
    pred_bias = mx.sym.Variable("pred_bias")
    param_cells = []
    last_states = []
    for i in range(num_lstm_layer):
        param_cells.append(LSTMParam2(i2h_weight_f=mx.sym.Variable("l{0}_i2h_f_weight".format(i)),
                                      i2h_bias_f=mx.sym.Variable("l{0}_i2h_f_bias".format(i)),
                                      h2h_weight_f=mx.sym.Variable("l{0}_h2h_f_weight".format(i)),
                                      h2h_bias_f=mx.sym.Variable("l{0}_h2h_f_bias".format(i)),
                                      i2h_weight_i=mx.sym.Variable("l{0}_i2h_i_weight".format(i)),
                                      i2h_bias_i=mx.sym.Variable("l{0}_i2h_i_bias".format(i)),
                                      h2h_weight_i=mx.sym.Variable("l{0}_h2h_i_weight".format(i)),
                                      h2h_bias_i=mx.sym.Variable("l{0}_h2h_i_bias".format(i)),
                                      i2h_weight_t=mx.sym.Variable("l{0}_i2h_t_weight".format(i)),
                                      i2h_bias_t=mx.sym.Variable("l{0}_i2h_t_bias".format(i)),
                                      h2h_weight_t=mx.sym.Variable("l{0}_h2h_t_weight".format(i)),
                                      h2h_bias_t=mx.sym.Variable("l{0}_h2h_t_bias".format(i)),
                                      i2h_weight_o=mx.sym.Variable("l{0}_i2h_o_weight".format(i)),
                                      i2h_bias_o=mx.sym.Variable("l{0}_i2h_o_bias".format(i)),
                                      h2h_weight_o=mx.sym.Variable("l{0}_h2h_o_weight".format(i)),
                                      h2h_bias_o=mx.sym.Variable("l{0}_h2h_o_bias".format(i)),
                                      ))

        state = LSTMState(c=mx.sym.Variable("l%d_init_c" % i),
                          h=mx.sym.Variable("l%d_init_h" % i))
        last_states.append(state)
    assert(len(last_states) == num_lstm_layer)
    data = mx.sym.Variable("data")
    img_feat = mx.sym.Variable('img_batch_feats')

    wordvec = mx.sym.Embedding(data=data,
                               input_dim=input_size,
                               output_dim=num_embed,
                               weight=embed_weight,
                               name="embed")
    # concatenate the image feature with the word vector, as in training
    hidden = mx.sym.Concat(*[img_feat, wordvec], dim=1)
    # stack LSTM
    for i in range(num_lstm_layer):
        if i == 0:
            dp = 0.
        else:
            dp = dropout
        next_state = lstm(num_hidden, indata=hidden,
                          prev_state=last_states[i],
                          param=param_cells[i],
                          seqidx=seqidx, layeridx=i, dropout=dp)
        hidden = next_state.h
        last_states[i] = next_state
    # decoder
    if dropout > 0.:
        hidden = mx.sym.Dropout(data=hidden, p=dropout)
    fc = mx.sym.FullyConnected(data=hidden, num_hidden=num_label,
                               weight=pred_weight, bias=pred_bias, name='pred')
    sm = mx.sym.SoftmaxOutput(data=fc, name='softmax')
    output = [sm]
    for state in last_states:
        output.append(state.c)
        output.append(state.h)
    return mx.sym.Group(output)
--------------------------------------------------------------------------------
/qa system/param/params:
--------------------------------------------------------------------------------
num_embed:256
num_hidden:512
num_lstm_layer:2
--------------------------------------------------------------------------------
/qa system/preprocessing.py:
--------------------------------------------------------------------------------
import numpy as np
import os
from skimage import io, transform

def PreprocessImage(path):
    # load image
    img = io.imread(path)
    print("Original Image Shape: ", img.shape)
    # crop the image around its center
    short_edge = min(img.shape[:2])
    yy = int((img.shape[0] - short_edge) / 2)
    xx = int((img.shape[1] - short_edge) / 2)
    crop_img = img[yy:yy + short_edge, xx:xx + short_edge]
    # resize to 299 x 299
    resized_img = transform.resize(crop_img, (299, 299))
    # convert to numpy.ndarray and rescale to [0, 256)
    sample = np.asarray(resized_img) * 256
    # swap axes to turn the image from (299, 299, 3) into (3, 299, 299)
    sample = np.swapaxes(sample, 0, 2)
    sample = np.swapaxes(sample, 1, 2)
    # subtract the mean and scale to roughly [-1, 1]
    normed_img = sample - 128.
    normed_img /= 128.

    return np.reshape(normed_img, (1, 3, 299, 299))
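# Example (hypothetical path): the result is a batch of one normalized CHW image.
#   >>> batch = PreprocessImage('/path/to/COCO_val2014_000000000042.jpg')
#   >>> batch.shape
#   (1, 3, 299, 299)
# Pixel values end up roughly in [-1, 1] after the mean/scale normalization.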
# extract features for all images under the given path
def extract_all_features(path, cnn_model):
    image_names = os.listdir(path)
    image_feats = []
    # the order of these image names should correspond to the questions
    for image_name in image_names:
        feature = extract_feature(os.path.join(path, image_name), cnn_model)
        if len(feature) == 0:
            image_feats.append(None)
        else:
            image_feats.append(feature)

    return image_feats

# path is the image's path
def extract_feature(path, cnn_model):
    image = PreprocessImage(path)
    flatten_output = cnn_model.predict(image)
    feat = flatten_output[0]
    return np.reshape(feat, (len(feat), 1))
--------------------------------------------------------------------------------
/qa system/train_lstm_qa.py:
--------------------------------------------------------------------------------
import mxnet as mx
import numpy as np
import scipy.io
import re
import logging

import lstm_qa as lstm
from dataSupport_qa import SequencesIterQA, revocab

logger = logging.getLogger()
logger.setLevel(logging.DEBUG)


# read each line from the target file and collect the unique words
def gen_vocab(readfile, vocab=None):
    if vocab is None:
        vocab = {}
        vocab['#'] = 0
    print('vocab len: {0}'.format(len(vocab)))
    idx = len(vocab)
    content = read_content(readfile)
    words = []
    for line in content:
        if len(re.findall('[0-9]', line)) > 0 or len(re.findall(r"[()']", line)) > 0:
            tmp = [word for word in re.split(r'\s+|[,!?;]', line) if word.strip()]
        else:
            tmp = [word for word in re.split(r'\s+|[,!?;"()]', line) if word.strip()]
        words.extend(tmp)

    print('text: {0}'.format(len(words)))

    for character in words:
        if len(character) == 0:
            continue
        if character not in vocab:
            vocab[character] = idx
            idx += 1
    return vocab

# read from the file; each non-empty line is one element of the returned list
def read_content(path):
    lines = [line.strip() for line in open(path).readlines() if line.strip()]
    return lines


def text2id(sentence, the_vocab):
    if len(sentence) == 0:
        return []
    if len(re.findall('[0-9]', sentence)) > 0 or len(re.findall(r"[()']", sentence)) > 0:
        words = [word for word in re.split(r'\s+|[,!?;]', sentence) if word.strip()]
    else:
        words = [word for word in re.split(r'\s+|[,!?;"()]', sentence) if word.strip()]
    words = [the_vocab[w] for w in words if len(w) > 0]
    return words

def splitwords(sentencelist):
    if len(sentencelist) == 0:
        return []
    wordslist = []
    for sentence in sentencelist:
        if len(sentence) == 0:
            continue
        if len(re.findall('[0-9]', sentence)) > 0 or len(re.findall(r"[()']", sentence)) > 0:
            words = [word for word in re.split(r'\s+|[,!?;]', sentence) if word.strip()]
        else:
            words = [word for word in re.split(r'\s+|[,!?;"()]', sentence) if word.strip()]
        wordslist.append(words)
    return wordslist

# Evaluation metric 1
def Perplexity(label, pred):
    # labels come in batch-major order; transpose to match the time-major predictions
    label = label.T.reshape((-1,))
    loss = 0.
    for i in range(pred.shape[0]):
        loss += -np.log(max(1e-10, pred[i][int(label[i])]))
    return np.exp(loss / label.size)

# Evaluation metric 2
def CrossEntropySoftmax(label, pred):
    label = label.T.reshape((-1,))
    loss = 0.
    for i in range(pred.shape[0]):
        loss += -np.log(pred[i][int(label[i])] + 1e-8)
    return loss / pred.shape[0]
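# A hand-worked check of the metric above (two samples, three classes, true
# labels [0, 2]; with seq_len = 1 the transpose is a no-op):
#   >>> CrossEntropySoftmax(np.array([[0.], [2.]]),
#   ...                     np.array([[0.7, 0.2, 0.1], [0.1, 0.1, 0.8]]))
#   0.2899...   # = -(log 0.7 + log 0.8) / 2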
def idx2text(re_vocab, nums):
    text = [re_vocab[num] for num in nums]
    return text

# save the vocabulary
def save_vocab(savefile, vocabulary):
    with open(savefile, 'wt') as writehandle:
        for word in vocabulary:
            writehandle.write(word + ' ')

# save dict data
def save_dict(savefile, data):
    if not isinstance(data, dict):
        raise TypeError('This method needs a dict.')
    with open(savefile, 'wt') as writehandle:
        for k, v in data.items():
            record = k + ':' + str(v) + '\n'
            writehandle.write(record)

if __name__ == '__main__':
    print('************Start Training*************')

    qa_train = '../data/coco_qa/concateQA/concateqa_train.txt'
    qa_val = '../data/coco_qa/concateQA/concateqa_val.txt'
    image_train = '../data/coco_qa/images/train/images_train2014.txt'
    vgg_feats_path = '../data/coco_qa/image_fatures/vgg_feats.mat'
    image_map_ids = '../data/coco_qa/image_fatures/coco_vgg_IDMap.txt'

    # build the vocabulary from the questions and answers of the training and validation sets
    vocabulary = gen_vocab(qa_train)
    print('only train: {0}'.format(len(vocabulary)))
    vocabulary = gen_vocab(qa_val, vocabulary)
    print('train and val: {0}'.format(len(vocabulary)))
    re_vocab = revocab(vocabulary)

    # load image features
    image_feats = scipy.io.loadmat(vgg_feats_path)['feats']
    images_ids = read_content(image_map_ids)
    image_id = {}  # key is str, value is int
    for line in images_ids:
        tmp = line.split(' ')
        image_id[tmp[0]] = int(tmp[1])

    batch_size = 32
    num_hidden = 512
    num_embed = 256
    num_lstm_layer = 2
    content = read_content(qa_train)
    content = [len(words) for words in splitwords(content)]
    # a single bucket that holds the longest question-answer pair
    buckets = [max(content) + 1]
    # buckets = []

    print('maximum of qa: {0}'.format(buckets[0]))

    num_epoch = 75
    lr = 0.00005
    momentum = 0.9

    params = dict()
    params['num_hidden'] = num_hidden
    params['num_lstm_layer'] = num_lstm_layer
    params['num_embed'] = num_embed
    # save the network's hyper-parameters where evaluateQA.py and
    # interactiveQA.py expect them
    save_dict('./param/params', params)

    devs = [mx.context.gpu(i) for i in range(1)]
    save_vocab('./param/vocabulary', vocabulary)  # save the unique words
    print('len of vocabulary: ', len(vocabulary))

    init_c = [('l{0}_init_c'.format(l), (batch_size, num_hidden)) for l in range(num_lstm_layer)]
    init_h = [('l{0}_init_h'.format(l), (batch_size, num_hidden)) for l in range(num_lstm_layer)]
    batch_img_feats = ('img_batch_feats', (batch_size, image_feats.shape[0]))
    init_states = init_c + init_h

    trainIter = SequencesIterQA(qa_train, image_train, vocabulary, re_vocab,
                                image_feats, image_id, buckets, batch_size, init_states, batch_img_feats,
                                text2id=text2id, read_content=read_content, id2text=idx2text)
    def gen_sym(seq_len):
        # seq_len covers the full padded QA sequence: the input drops the last
        # word and the label drops the first word (used when buckets != [])
        return lstm.unroll_lstm(num_lstm_layer, seq_len, len(vocabulary), batch_size,
                                num_hidden, num_embed, len(vocabulary), dropout=0.5)

    if len(buckets) == 1:
        symbol = gen_sym(buckets[0])
    else:
        symbol = gen_sym

    # load a pre-trained model if one exists
    model_prefix = './model/QAmodel'
    load_epoch = 11
    model_args = {}
    if model_prefix is not None and load_epoch is not None:
        print('load previous model.')
        tmp = mx.model.FeedForward.load(model_prefix, load_epoch)
        model_args = {'arg_params': tmp.arg_params,
                      'aux_params': tmp.aux_params}
    rescale_grad = 1. / batch_size

    optimizer = mx.optimizer.Adam(learning_rate=lr, wd=0.0002, rescale_grad=rescale_grad)
    model = mx.model.FeedForward(
        ctx=mx.gpu(),
        symbol=symbol,
        optimizer=optimizer,
        num_epoch=num_epoch,
        initializer=mx.init.Xavier(factor_type="in", magnitude=2.34),
        **model_args
    )

    head = '%(asctime)-15s %(message)s'
    logging.basicConfig(level=logging.DEBUG, format=head)

    save_model_prefix = './model/QAmodel'
    checkpoint = mx.callback.do_checkpoint(save_model_prefix)

    model.fit(
        X=trainIter,
        eval_metric=mx.metric.np(CrossEntropySoftmax),
        epoch_end_callback=checkpoint,
        batch_end_callback=mx.callback.Speedometer(batch_size, 50)
    )
--------------------------------------------------------------------------------