├── README.md
├── test.py
├── attention_sum_reader.py
└── data_utils.py
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
Attention Sum Reader -- TensorFlow implementation of http://arxiv.org/abs/1603.01547
--------------------------------------------------------------------------------
/test.py:
--------------------------------------------------------------------------------
#coding=utf-8

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
import random
import logging
import tensorflow as tf

from attention_sum_reader import Attention_sum_reader
from data_utils import load_vocab, gen_embeddings, read_cbt_data, data_provider

os.environ["CUDA_VISIBLE_DEVICES"] = "2"
logging.basicConfig(level=logging.INFO)
random.seed(1)

d_len = 1000
q_len = 150
A_len = 10
lr_init = 0.005
lr_decay = 2000
hidden_size = 128
batch_size = 128
step_num = 20000
embed_dim = 100
embed_file = '/data1/flashlin/data/glove/glove.6B.100d.txt'

word_dict = load_vocab('out/vocab')
embedding_matrix = gen_embeddings(word_dict, embed_dim, embed_file)
embedding_matrix = embedding_matrix.astype('float32')
asr = Attention_sum_reader('pig', d_len, q_len, A_len, lr_init, lr_decay,
                           embedding_matrix, hidden_size)

#src_data = read_cbt_data('out/cbtest_NE_train.txt.idx', [100, d_len], [10, q_len])
#provider = data_provider(src_data, batch_size, d_len, q_len, step_num=step_num)

src_data = read_cbt_data('out/cbtest_NE_valid_2000ex.txt.idx', [100, d_len], [10, q_len])
provider = data_provider(src_data, batch_size, d_len, q_len, None, 1)

with tf.Session() as sess:
    #asr.train(sess, provider, 'pig', 500)
    asr.test(sess, provider, 'gru_model/gru-19800')
--------------------------------------------------------------------------------
/attention_sum_reader.py:
--------------------------------------------------------------------------------
#coding=utf-8

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
import logging
import numpy as np
import tensorflow as tf

from tensorflow.contrib.rnn import LSTMCell, GRUCell, MultiRNNCell, DropoutWrapper

class Attention_sum_reader(object):
    def __init__(self, name, d_len, q_len, A_len, lr_init, lr_decay,
                 embedding_matrix, hidden_size, num_layers=1):
        self._name = name
        self._d_len = d_len
        self._q_len = q_len
        self._A_len = A_len
        self._lr_init = lr_init
        self._lr_decay = lr_decay
        #self._embedding_matrix = tf.Variable(embedding_matrix, dtype=tf.float32)
        self._embedding_matrix = embedding_matrix
        self._hidden_size = hidden_size
        self._num_layers = num_layers  # kept for MultiRNNCell experiments, currently unused

        self._d_input = tf.placeholder(dtype=tf.int32, shape=(None, d_len), name='d_input')
        self._q_input = tf.placeholder(dtype=tf.int32, shape=(None, q_len), name='q_input')
        self._context_mask = tf.placeholder(dtype=tf.int8, shape=(None, d_len), name='context_mask')
        self._ca = tf.placeholder(dtype=tf.int32, shape=(None, A_len), name='ca')
        self._y = tf.placeholder(dtype=tf.int32, shape=(None,), name='y')

        self._build_network()

        self._saver = tf.train.Saver()

    def train(self, sess, provider, save_dir, save_period, model_path=None):
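        # Training loop: each batch from `provider` yields (document ids, question ids,
        # context mask, candidate ids, labels). The label is always index 0 because
        # data_utils.read_cbt_data moves the correct answer to the front of the
        # candidate list. Loss and accuracy are averaged and logged every 100 steps,
        # and a checkpoint is written every `save_period` steps.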
        sess.run(tf.global_variables_initializer())

        if model_path:
            logging.info('[restore] {}'.format(model_path))
            self._saver.restore(sess, model_path)

        losses = []
        predictions = []
        for data in provider:
            d_input, q_input, context_mask, ca, y = data
            _, loss, prediction = sess.run(
                [self._train_op, self._loss, self._prediction],
                feed_dict={self._d_input: d_input, self._q_input: q_input,
                           self._context_mask: context_mask,
                           self._ca: ca, self._y: y})
            losses.append(loss)
            predictions.append(prediction / len(d_input))

            step = sess.run(self._global_step)

            if step % 100 == 0:
                logging.info('[Train] step: {}, loss: {}, accuracy: {}, lr: {}'.format(
                    step,
                    np.sum(losses) / len(losses),
                    np.sum(predictions) / len(predictions),
                    sess.run(self._lr)))
                losses = []
                predictions = []

            if step % save_period == 0 and step > 0:
                save_path = os.path.join(save_dir, self._name)
                logging.info('[Save] {} {}'.format(save_path, step))
                self._saver.save(sess, save_path, global_step=self._global_step)

    def test(self, sess, provider, model_path):
        logging.info('[restore] {}'.format(model_path))
        self._saver.restore(sess, model_path)

        q_num = 0.0
        p_num = 0.0
        for (i, data) in enumerate(provider):
            d_input, q_input, context_mask, ca, y = data
            prediction = sess.run(
                self._prediction,
                feed_dict={self._d_input: d_input, self._q_input: q_input,
                           self._context_mask: context_mask,
                           self._ca: ca, self._y: y})

            q_num += len(d_input)
            p_num += prediction

            if i % 50 == 0:
                logging.info('[test] q_num: {}, p_num: {}, accuracy: {}'.format(q_num, p_num, float(p_num) / q_num))

        logging.info('[test] q_num: {}, p_num: {}, accuracy: {}'.format(q_num, p_num, float(p_num) / q_num))

    def _RNNCell(self):
        cell = GRUCell(self._hidden_size)
        #cell = LSTMCell(self._hidden_size)
        return DropoutWrapper(cell, input_keep_prob=0.8, output_keep_prob=0.8)
        #return cell

    def _Optimizer(self, global_step):
        # Halve the learning rate every `lr_decay` steps.
        self._lr = tf.train.exponential_decay(self._lr_init, global_step, self._lr_decay, 0.5, staircase=True)

        #return tf.train.GradientDescentOptimizer(self._lr)
        return tf.train.AdamOptimizer(self._lr)

    def _build_network(self):
        with tf.variable_scope('q_encoder'):
            # Encode the question with a bidirectional GRU and use the concatenated
            # final states as a fixed-size query vector.
            q_embed = tf.nn.embedding_lookup(self._embedding_matrix, self._q_input)
            q_lens = tf.reduce_sum(tf.sign(tf.abs(self._q_input)), 1)
            outputs, final_states = tf.nn.bidirectional_dynamic_rnn(
                cell_bw=self._RNNCell(), cell_fw=self._RNNCell(),
                inputs=q_embed, dtype=tf.float32, sequence_length=q_lens)
            q_encode = tf.concat([final_states[0], final_states[1]], axis=-1)
            #q_encode = tf.concat([final_states[0][-1][1], final_states[1][-1][1]], axis=-1)

            # [batch_size, hidden_size * 2]
            logging.info('q_encode shape {}'.format(q_encode.get_shape()))

        with tf.variable_scope('d_encoder'):
            # Encode the document with a bidirectional GRU and keep the per-token outputs.
            d_embed = tf.nn.embedding_lookup(self._embedding_matrix, self._d_input)
            d_lens = tf.reduce_sum(tf.sign(tf.abs(self._d_input)), 1)
            outputs, final_states = tf.nn.bidirectional_dynamic_rnn(
                cell_bw=self._RNNCell(), cell_fw=self._RNNCell(),
                inputs=d_embed, dtype=tf.float32, sequence_length=d_lens)
            d_encode = tf.concat(outputs, axis=-1)

            # [batch_size, d_len, hidden_size * 2]
            logging.info('d_encode shape {}'.format(d_encode.get_shape()))
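        # The next block is the attention-sum readout: each document position gets a
        # score q . d_t, the scores are softmax-normalised, and a candidate's
        # probability is the sum of the softmax weights at every position where that
        # candidate's token id occurs. For example, if candidate id 7 appears at
        # positions 3 and 8, its score is weights[3] + weights[8].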
        with tf.variable_scope('dot_sum'):
            def reduce_attention_sum(data):
                at, d, ca = data
                def reduce_attention_sum_by_ans(aid):
                    # Sum the attention weights at every position where token `aid` occurs.
                    return tf.reduce_sum(tf.multiply(at, tf.cast(tf.equal(d, aid), tf.float32)))
                return tf.map_fn(reduce_attention_sum_by_ans, ca, dtype=tf.float32)

            # Dot product between the query vector and every document position.
            attention_value = tf.map_fn(
                lambda v: tf.reduce_sum(tf.multiply(v[0], v[1]), -1),
                (q_encode, d_encode),
                dtype=tf.float32)
            attention_value_masked = tf.multiply(attention_value, tf.cast(self._context_mask, tf.float32))
            attention_value_softmax = tf.nn.softmax(attention_value_masked)
            self._attention_sum = tf.map_fn(reduce_attention_sum,
                                            (attention_value_softmax, self._d_input, self._ca), dtype=tf.float32)

            # [batch_size, A_len]
            logging.info('attention_sum shape {}'.format(self._attention_sum.get_shape()))

        with tf.variable_scope('prediction'):
            # A prediction counts as correct when the highest-scoring candidate is the labelled one.
            self._prediction = tf.reduce_sum(tf.cast(
                tf.equal(tf.cast(self._y, dtype=tf.int64), tf.argmax(self._attention_sum, 1)), tf.float32))

        with tf.variable_scope('loss'):
            # The correct answer is always candidate 0 (see data_utils.read_cbt_data),
            # so the loss is the negative log of its attention mass.
            label = tf.constant([1., 0., 0., 0., 0., 0., 0., 0., 0., 0.], dtype=tf.float32)
            #self._output = self._attention_sum / tf.reduce_sum(self._attention_sum, -1, keep_dims=True)
            self._loss = tf.reduce_mean(-tf.log(tf.reduce_sum(self._attention_sum * label, -1)))

        with tf.variable_scope('train'):
            self._global_step = tf.contrib.framework.get_or_create_global_step()
            optimizer = self._Optimizer(self._global_step)
            self._train_op = optimizer.minimize(self._loss, global_step=self._global_step)

if __name__ == '__main__':
    logging.basicConfig(level=logging.INFO)

    embedded = tf.zeros((1000, 100), dtype=tf.float32)
    Attention_sum_reader(name='miao', d_len=600, q_len=60, A_len=10, lr_init=0.1, lr_decay=2000,
                         embedding_matrix=embedded, hidden_size=128, num_layers=2)
--------------------------------------------------------------------------------
/data_utils.py:
--------------------------------------------------------------------------------
#coding=utf-8

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
import io
import random
import re
import logging
import nltk
import numpy as np
import tensorflow as tf

from collections import Counter

_PAD = u'_PAD'
_UNK = u'_UNK'
_EOS = u'_EOS'
_START_VOCAB = [_PAD, _UNK, _EOS]

PAD_ID = 0
UNK_ID = 1
EOS_ID = 2

def writeWrapper(fh, context):
    # Python 2: io.open files expect unicode, so decode byte strings first.
    if not type(context) is unicode:
        assert(type(context) is str)
        context = unicode(context, 'utf8')
    fh.write(context)

def tokenizer(sentence):
    sentence = ' '.join(sentence.split("|"))
    tokens = nltk.word_tokenize(sentence.lower())
    return tokens

def gen_vocab(data_file, word_dict=None):
    word_dict = word_dict if word_dict else Counter()

    for i, line in enumerate(io.open(data_file, 'r', encoding='utf8')):
        if len(line.strip()) == 0:
            continue

        # Each non-empty CBT line starts with a line number; strip it before tokenizing.
        tokens = tokenizer(line[line.index(' ') + 1:])
        word_dict.update(tokens)

        if i % 100000 == 0:
            logging.info('[gen_vocab] data_file: %s, i: %d' % (data_file, i))

    return word_dict

def save_vocab(word_dict, vocab_file):
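    # The vocab file is one word per line; load_vocab assigns ids by line order,
    # so _PAD/_UNK/_EOS land on ids 0/1/2 ahead of the corpus words.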
    with io.open(vocab_file, 'w', encoding='utf8') as f:
        for word in _START_VOCAB:
            writeWrapper(f, word + u'\n')
        for word, _ in word_dict.most_common():
            writeWrapper(f, word + u'\n')

def load_vocab(vocab_file):
    assert(vocab_file is not None)
    word_dict = {}
    with io.open(vocab_file, 'r', encoding='utf8') as f:
        for wid, word in enumerate(f):
            word = word.strip()
            word_dict[word] = wid
    return word_dict

def sentence_to_token_ids(sentence, word_dict):
    return [word_dict.get(token, UNK_ID) for token in tokenizer(sentence)]

def cbt_data_to_token_ids(data_file, target_file, vocab_file):
    word_dict = load_vocab(vocab_file)

    with io.open(data_file, 'r', encoding='utf8') as data_file, io.open(target_file, 'w', encoding='utf8') as tokens_file:
        for line in data_file:
            if len(line.strip()) == 0:
                writeWrapper(tokens_file, u'\n')
                continue

            num = int(line[:line.index(' ')])
            line = line[line.index(' ') + 1:]

            if num == 21:
                # Line 21 of a CBT block holds the question, the answer and the
                # candidate list, separated by tabs.
                q, a, _, A = line.split('\t')
                tokens_ids_q = sentence_to_token_ids(q, word_dict)
                tokens_ids_A = [word_dict.get(Ai.lower(), UNK_ID) for Ai in A.rstrip('\n').split('|')]
                context = (' '.join([str(token) for token in tokens_ids_q]) + '\t' +
                           str(word_dict.get(a.lower(), UNK_ID)) + '\t' +
                           '|'.join([str(token) for token in tokens_ids_A]) + '\n')
                context = unicode(context, 'utf8')
                writeWrapper(tokens_file, context)
            else:
                tokens_ids = sentence_to_token_ids(line, word_dict)
                context = ' '.join([str(token) for token in tokens_ids]) + '\n'
                writeWrapper(tokens_file, context)

def prepare_cbt_data(data_dir, out_dir, train_file, valid_file, test_file):
    if not tf.gfile.Exists(out_dir):
        os.mkdir(out_dir)

    src_train_file = os.path.join(data_dir, train_file)
    src_valid_file = os.path.join(data_dir, valid_file)
    src_test_file = os.path.join(data_dir, test_file)
    idx_train_file = os.path.join(out_dir, train_file + ".idx")
    idx_valid_file = os.path.join(out_dir, valid_file + ".idx")
    idx_test_file = os.path.join(out_dir, test_file + ".idx")
    vocab_file = os.path.join(out_dir, "vocab")

    wd = gen_vocab(src_train_file)
    wd = gen_vocab(src_valid_file, wd)
    wd = gen_vocab(src_test_file, wd)
    save_vocab(wd, vocab_file)
    logging.info('Total words: %d' % sum(wd.values()))
    logging.info('Total distinct words: %d' % len(wd))

    cbt_data_to_token_ids(src_train_file, idx_train_file, vocab_file)
    cbt_data_to_token_ids(src_valid_file, idx_valid_file, vocab_file)
    cbt_data_to_token_ids(src_test_file, idx_test_file, vocab_file)

    return idx_train_file, idx_valid_file, idx_test_file, vocab_file

def read_cbt_data(idx_file, d_len_range=None, q_len_range=None, max_count=None):
    def ok(d_len, q_len, A_len):
        d_con = (not d_len_range) or (d_len_range[0] < d_len < d_len_range[1])
        q_con = (not q_len_range) or (q_len_range[0] < q_len < q_len_range[1])
        A_con = (A_len == 10)
        return d_con and q_con and A_con

    skip = 0
    documents, questions, answers, candidates = [], [], [], []
    with io.open(idx_file, 'r', encoding='utf8') as f:
        cnt = 0
        d, q, a, A = [], [], [], []
        for line in f:
            cnt += 1
            if cnt <= 20:
                # The first 20 lines of a block form the document context.
                d.extend([int(t) for t in line.strip().split(' ')] + [EOS_ID])
            elif cnt == 21:
                tmp = line.strip().split('\t')
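                # tmp holds the question token ids, the answer id, and the
                # candidate ids joined by '|' (see cbt_data_to_token_ids).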
                q = [int(t) for t in tmp[0].split(' ')] + [EOS_ID]
                a = [1 if int(tmp[1]) == wid else 0 for wid in d]
                A = [int(Ai) for Ai in tmp[2].split('|')]
                # Move the correct answer to the front so that label 0 always
                # points at it (see data_provider and the model's loss).
                A.remove(int(tmp[1]))
                A.insert(0, int(tmp[1]))

                if ok(len(d), len(q), len(A)):
                    documents.append(d)
                    questions.append(q)
                    answers.append(a)
                    candidates.append(A)
                else:
                    skip += 1
            elif cnt == 22:
                d, q, a, A = [], [], [], []
                cnt = 0

            if max_count and len(questions) >= max_count:
                break

    logging.info('[read_cbt_data] skip: {}, read: {}'.format(skip, len(questions)))

    return documents, questions, answers, candidates

def get_embed_dim(embed_file):
    line = io.open(embed_file, 'r', encoding='utf8').readline()
    return len(line.split()) - 1

def gen_embeddings(word_dict, embed_dim, embed_file=None):
    num_words = len(word_dict)
    #return tf.random_uniform([num_words, embed_dim], -0.1, 0.1)

    # Start from a random matrix and overwrite the rows of words that have a
    # pre-trained vector in embed_file.
    embedding_matrix = np.random.uniform(-0.1, 0.1, [num_words, embed_dim])
    if embed_file:
        pre_trained = 0
        for line in io.open(embed_file, 'r', encoding='utf8'):
            items = line.split()
            word = items[0]
            assert(embed_dim + 1 == len(items))
            if word in word_dict:
                pre_trained += 1
                embedding_matrix[word_dict[word]] = [float(x) for x in items[1:]]

        logging.info('Embedding file: %s, pre_trained_rate: %.2f%%' % (embed_file, 100.0 * pre_trained / num_words))

    return embedding_matrix

def data_provider(src_data, batch_size, d_len, q_len, step_num=None, epoch_num=None):
    documents, questions, answers, candidates = src_data
    N = len(documents)

    logging.info('[data_provider] N: {}, batch_size: {}, step_num: {}, d_len: {}, q_len: {}'.format(
        N, batch_size, step_num, d_len, q_len))
    assert(len(questions) == N and len(answers) == N and len(candidates) == N)
    assert(N > batch_size * 10)

    context_masks = []
    ys = []
    for i in range(N):
        # 1 for real tokens, 0 for padding.
        context_mask = [1] * len(documents[i]) + [0] * (d_len - len(documents[i]))
        context_masks.append(context_mask)

        assert(len(documents[i]) <= d_len)
        documents[i] += [PAD_ID] * (d_len - len(documents[i]))

        assert(len(questions[i]) <= q_len)
        questions[i] += [PAD_ID] * (q_len - len(questions[i]))

        # The correct answer is always the first candidate (see read_cbt_data).
        ys.append(0)

    if not step_num:
        assert(epoch_num)
        step_num = N // batch_size
        if N % batch_size != 0:
            step_num += 1
        step_num *= epoch_num

    h = N
    idx = [i for i in range(N)]
    for _ in range(step_num):
        if h == N:
            random.shuffle(idx)
            h = 0
            logging.info('[data_provider] new epoch')

        d_input = []
        q_input = []
        context_mask = []
        ca = []
        y = []

        data_len = batch_size if h + batch_size <= N else N - h
        for i in range(data_len):
            d_input.append(documents[idx[h + i]])
            q_input.append(questions[idx[h + i]])
            context_mask.append(context_masks[idx[h + i]])
            ca.append(candidates[idx[h + i]])
            y.append(ys[idx[h + i]])
        h += data_len

        yield d_input, q_input, context_mask, ca, y

if __name__ == '__main__':
    logging.basicConfig(level=logging.INFO)

    idx_train, idx_valid, idx_test, vocab = prepare_cbt_data(
        '/data1/flashlin/data/CBTest/data/', 'out', 'cbtest_NE_train.txt', 'cbtest_NE_valid_2000ex.txt', 'cbtest_NE_test_2500ex.txt')
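    # After this, test.py expects the generated files under out/ (out/vocab and
    # the *.idx files read by read_cbt_data).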
--------------------------------------------------------------------------------
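For reference, the core of the model is the attention-sum readout in the dot_sum scope of attention_sum_reader.py. A minimal NumPy sketch of that readout is shown below; the token ids, scores and the file name attention_sum_demo.py are made up for illustration and are not part of the repository.

    # attention_sum_demo.py -- illustrative sketch only, not part of the repo above.
    import numpy as np

    d_ids = np.array([5, 7, 9, 7, 2])              # toy document token ids (2 = EOS)
    scores = np.array([0.1, 2.0, 0.3, 1.0, 0.0])   # toy q . d_t dot products
    mask = np.ones_like(scores)                    # all positions are real tokens here

    # softmax over the (masked) position scores
    weights = np.exp(scores * mask)
    weights /= weights.sum()

    # a candidate's probability mass is the sum of the weights at every
    # position where its token id occurs (id 7 appears twice here)
    candidates = np.array([7, 5, 9])
    attention_sum = np.array([weights[d_ids == c].sum() for c in candidates])
    print(attention_sum, candidates[attention_sum.argmax()])  # candidate 7 wins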