├── .idea
│   ├── attention-over-attention-tf-QA.iml
│   ├── misc.xml
│   └── modules.xml
├── README.md
├── counter.pickle
├── model.py
├── reader.py
└── util.py
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# Attention over Attention
This is adapted from another GitHub user's code; I modified only a small part of it so that it runs on TensorFlow 1.0 and above.
The original code is at https://github.com/OlavHN/attention-over-attention
To run the code, follow these steps:
1. Download the dataset.
2. Run reader.py to convert the raw dataset into .tfrecords files for efficient reading (the expected directory layout is shown below).
3. Run model.py to train the model.
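
The download link for the raw CNN data is given in the original README below. For reference, reader.py expects the unpacked question files in the following hard-coded locations:

    cnn/
    └── questions/
        ├── training/
        ├── validation/
        └── test/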

The original README from the linked repository follows.

Implementation of the paper [Attention-over-Attention Neural Networks for Reading Comprehension](https://arxiv.org/abs/1607.04423) in TensorFlow.

Some context on [my blog](http://olavnymoen.com/2016/10/30/attention-over-attention)

In cloze-style reading comprehension, a word is removed from an article summary; the task is to read the article and infer the missing word. This example works on the CNN news dataset.

With the same hyperparameters as reported in the paper, this implementation reached an accuracy of 74.3% on both the validation and test sets, compared with the 73.1% and 74.4% reported by the authors.

To train a new model: `python model.py --training=True --name=my_model`

To test accuracy: `python model.py --training=False --name=my_model --epochs=1 --dropout_keep_prob=1`

Note that the tfrecords and model files are stored with [git lfs](https://git-lfs.github.com/).

Raw data for use with `reader.py` to produce .tfrecords files was downloaded from http://cs.nyu.edu/~kcho/DMQA/

Interesting parts:
- Masked softmax implementation
- Example of batched sparse tensors with correct mask handling
- Example of pointer style attention
- Test/validation split part of the tf-graph
--------------------------------------------------------------------------------
/model.py:
--------------------------------------------------------------------------------
import os
import time
import random
import numpy as np
import tensorflow as tf
from tensorflow.python.ops import sparse_ops
from util import softmax, orthogonal_initializer

flags = tf.app.flags
FLAGS = flags.FLAGS
flags.DEFINE_integer('vocab_size', 119662, 'Vocabulary size')
flags.DEFINE_integer('embedding_size', 384, 'Embedding dimension')
flags.DEFINE_integer('hidden_size', 256, 'Hidden units')
flags.DEFINE_integer('batch_size', 32, 'Batch size')
flags.DEFINE_integer('epochs', 2, 'Number of epochs to train/test')
flags.DEFINE_boolean('training', True, 'Training or testing a model')
flags.DEFINE_string('name', 'lc_model', 'Model name (used for statistics and model path)')
flags.DEFINE_float('dropout_keep_prob', 0.9, 'Keep prob for embedding dropout')
flags.DEFINE_float('l2_reg', 0.0001, 'l2 regularization for embeddings')

model_path = 'models/' + FLAGS.name

if not os.path.exists(model_path):
    os.makedirs(model_path)

def read_records(index=0):
    # Three input pipelines; `index` selects which queue (train/validation/test) feeds the graph.
    train_queue = tf.train.string_input_producer(['training.tfrecords'], num_epochs=FLAGS.epochs)
    validation_queue = tf.train.string_input_producer(['validation.tfrecords'], num_epochs=FLAGS.epochs)
    test_queue = tf.train.string_input_producer(['test.tfrecords'], num_epochs=FLAGS.epochs)

    queue = tf.QueueBase.from_list(index, [train_queue, validation_queue, test_queue])
    reader = tf.TFRecordReader()
    _, serialized_example = reader.read(queue)
    features = tf.parse_single_example(
        serialized_example,
        features={
            'document': tf.VarLenFeature(tf.int64),
            'query': tf.VarLenFeature(tf.int64),
            'answer': tf.FixedLenFeature([], tf.int64)
        })

    document = sparse_ops.serialize_sparse(features['document'])
    query = sparse_ops.serialize_sparse(features['query'])
    answer = features['answer']

    document_batch_serialized, query_batch_serialized, answer_batch = tf.train.shuffle_batch(
        [document, query, answer],
        batch_size=FLAGS.batch_size,
        capacity=2000,
        min_after_dequeue=1000)

    sparse_document_batch = sparse_ops.deserialize_many_sparse(document_batch_serialized, dtype=tf.int64)
    sparse_query_batch = sparse_ops.deserialize_many_sparse(query_batch_serialized, dtype=tf.int64)

    document_batch = tf.sparse_tensor_to_dense(sparse_document_batch)
    document_weights = tf.sparse_to_dense(sparse_document_batch.indices, sparse_document_batch.dense_shape, 1)

    query_batch = tf.sparse_tensor_to_dense(sparse_query_batch)
    query_weights = tf.sparse_to_dense(sparse_query_batch.indices, sparse_query_batch.dense_shape, 1)

    return document_batch, document_weights, query_batch, query_weights, answer_batch

def inference(documents, doc_mask, query, query_mask):

    embedding = tf.get_variable('embedding',
        [FLAGS.vocab_size, FLAGS.embedding_size],
        initializer=tf.random_uniform_initializer(minval=-0.05, maxval=0.05))

    regularizer = tf.nn.l2_loss(embedding)

    doc_emb = tf.nn.dropout(tf.nn.embedding_lookup(embedding, documents), FLAGS.dropout_keep_prob)
    doc_emb.set_shape([None, None, FLAGS.embedding_size])

    query_emb = tf.nn.dropout(tf.nn.embedding_lookup(embedding, query), FLAGS.dropout_keep_prob)
    query_emb.set_shape([None, None, FLAGS.embedding_size])

    with tf.variable_scope('document', initializer=orthogonal_initializer()):
        fwd_cell = tf.contrib.rnn.GRUCell(FLAGS.hidden_size)
        back_cell = tf.contrib.rnn.GRUCell(FLAGS.hidden_size)

        doc_len = tf.reduce_sum(doc_mask, reduction_indices=1)
        h, _ = tf.nn.bidirectional_dynamic_rnn(
            fwd_cell, back_cell, doc_emb, sequence_length=tf.to_int64(doc_len), dtype=tf.float32)
        #h_doc = tf.nn.dropout(tf.concat(2, h), FLAGS.dropout_keep_prob)
        h_doc = tf.concat(h, 2)

    with tf.variable_scope('query', initializer=orthogonal_initializer()):
        fwd_cell = tf.contrib.rnn.GRUCell(FLAGS.hidden_size)
        back_cell = tf.contrib.rnn.GRUCell(FLAGS.hidden_size)

        query_len = tf.reduce_sum(query_mask, reduction_indices=1)
        h, _ = tf.nn.bidirectional_dynamic_rnn(
            fwd_cell, back_cell, query_emb, sequence_length=tf.to_int64(query_len), dtype=tf.float32)
        #h_query = tf.nn.dropout(tf.concat(2, h), FLAGS.dropout_keep_prob)
        h_query = tf.concat(h, 2)

    # Pairwise match scores between every document word and every query word.
    M = tf.matmul(h_doc, h_query, adjoint_b=True)
    M_mask = tf.to_float(tf.matmul(tf.expand_dims(doc_mask, -1), tf.expand_dims(query_mask, 1)))

    # alpha: attention over document words per query word; beta: attention over query words per document word.
    alpha = softmax(M, 1, M_mask)
    beta = softmax(M, 2, M_mask)

    #query_importance = tf.expand_dims(tf.reduce_mean(beta, reduction_indices=1), -1)
    query_importance = tf.expand_dims(tf.reduce_sum(beta, 1) / tf.to_float(tf.expand_dims(doc_len, -1)), -1)

    # Attention-over-attention: weight document-level attention by the averaged query-level attention.
    s = tf.squeeze(tf.matmul(alpha, query_importance), [2])

    # Pointer-sum attention: accumulate attention mass per vocabulary id.
    unpacked_s = zip(tf.unstack(s, FLAGS.batch_size), tf.unstack(documents, FLAGS.batch_size))
    y_hat = tf.stack([tf.unsorted_segment_sum(attentions, sentence_ids, FLAGS.vocab_size) for (attentions, sentence_ids) in unpacked_s])

    return y_hat, regularizer

def train(y_hat, regularizer, document, doc_weight, answer):
    # Trick while we wait for tf.gather_nd - https://github.com/tensorflow/tensorflow/issues/206
    # This unfortunately causes us to expand a sparse tensor into the full vocabulary
    index = tf.range(0, FLAGS.batch_size) * FLAGS.vocab_size + tf.to_int32(answer)
    flat = tf.reshape(y_hat, [-1])
    relevant = tf.gather(flat, index)

    # mean because the regularization term is independent of batch size
    loss = -tf.reduce_mean(tf.log(relevant)) + FLAGS.l2_reg * regularizer

    global_step = tf.Variable(0, name="global_step", trainable=False)

    accuracy = tf.reduce_mean(tf.to_float(tf.equal(tf.argmax(y_hat, 1), answer)))

    optimizer = tf.train.AdamOptimizer()
    grads_and_vars = optimizer.compute_gradients(loss)
    capped_grads_and_vars = [(tf.clip_by_value(grad, -5, 5), var) for (grad, var) in grads_and_vars]
    train_op = optimizer.apply_gradients(capped_grads_and_vars, global_step=global_step)

    tf.summary.scalar('loss', loss)
    tf.summary.scalar('accuracy', accuracy)
    return loss, train_op, global_step, accuracy

def main():
    # `dataset` selects the input pipeline: 0 = training, 1 = validation, 2 = test.
    dataset = tf.placeholder_with_default(0, [])
    document_batch, document_weights, query_batch, query_weights, answer_batch = read_records(dataset)

    y_hat, reg = inference(document_batch, document_weights, query_batch, query_weights)
    loss, train_op, global_step, accuracy = train(y_hat, reg, document_batch, document_weights, answer_batch)
    summary_op = tf.summary.merge_all()

    with tf.Session() as sess:
        summary_writer = tf.summary.FileWriter(model_path, sess.graph)
        saver_variables = tf.all_variables()
        if not FLAGS.training:
            saver_variables = filter(lambda var: var.name != 'input_producer/limit_epochs/epochs:0', saver_variables)
            saver_variables = filter(lambda var: var.name != 'smooth_acc:0', saver_variables)
            saver_variables = filter(lambda var: var.name != 'avg_acc:0', saver_variables)
        saver = tf.train.Saver(saver_variables)

        sess.run([
            tf.initialize_all_variables(),
            tf.initialize_local_variables()])
        model = tf.train.latest_checkpoint(model_path)
        if model:
            print('Restoring ' + model)
            saver.restore(sess, model)

        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(coord=coord)

        start_time = time.time()
        accumulated_accuracy = 0
        try:
            if FLAGS.training:
                while not coord.should_stop():
                    loss_t, _, step, acc = sess.run([loss, train_op, global_step, accuracy], feed_dict={dataset: 0})
                    elapsed_time, start_time = time.time() - start_time, time.time()
                    print(step, loss_t, acc, elapsed_time)
                    if step % 100 == 0:
                        summary_str = sess.run(summary_op)
                        summary_writer.add_summary(summary_str, step)
                    if step % 1000 == 0:
                        saver.save(sess, model_path + '/aoa', global_step=step)
            else:
                step = 0
                while not coord.should_stop():
                    acc = sess.run(accuracy, feed_dict={dataset: 2})
                    step += 1
                    # Running mean of accuracy over all evaluated batches.
                    accumulated_accuracy += (acc - accumulated_accuracy) / step
                    elapsed_time, start_time = time.time() - start_time, time.time()
                    print(accumulated_accuracy, acc, elapsed_time)
        except tf.errors.OutOfRangeError:
            print('Done!')
        finally:
            coord.request_stop()
            coord.join(threads)

'''
import pickle
with open('counter.pickle', 'r') as f:
    counter = pickle.load(f)
word, _ = zip(*counter.most_common())
'''

if __name__ == "__main__":
    main()
--------------------------------------------------------------------------------
/reader.py:
--------------------------------------------------------------------------------
import os
import pickle
from collections import Counter
import tensorflow as tf

def counts():
    # Count token occurrences across all CNN question files; cache the result in counter.pickle.
    cache = 'counter.pickle'
    if os.path.exists(cache):
        with open(cache, 'rb') as f:
            return pickle.load(f)

    directories = ['cnn/questions/training/', 'cnn/questions/validation/', 'cnn/questions/test/']
    files = [directory + file_name for directory in directories for file_name in os.listdir(directory)]
    counter = Counter()
    for file_name in files:
        with open(file_name, 'r') as f:
            lines = f.readlines()
            document = lines[2].split()
            query = lines[4].split()
            answer = lines[6].split()
            for token in document + query + answer:
                counter[token] += 1
    with open(cache, 'wb') as f:
        pickle.dump(counter, f)

    return counter

def tokenize(index, word):
    # Map each question file to integer token ids and write one .tfrecords file per split.
    directories = ['cnn/questions/training/', 'cnn/questions/validation/', 'cnn/questions/test/']
    for directory in directories:
        out_name = directory.split('/')[-2] + '.tfrecords'
        writer = tf.python_io.TFRecordWriter(out_name)
        files = map(lambda file_name: directory + file_name, os.listdir(directory))
        for file_name in files:
            with open(file_name, 'r') as f:
                lines = f.readlines()
                document = [index[token] for token in lines[2].split()]
                query = [index[token] for token in lines[4].split()]
                answer = [index[token] for token in lines[6].split()]
                example = tf.train.Example(
                    features=tf.train.Features(
                        feature={
                            'document': tf.train.Feature(
                                int64_list=tf.train.Int64List(value=document)),
                            'query': tf.train.Feature(
                                int64_list=tf.train.Int64List(value=query)),
                            'answer': tf.train.Feature(
                                int64_list=tf.train.Int64List(value=answer))
                        }))

                serialized = example.SerializeToString()
                writer.write(serialized)
        writer.close()

def main():
    counter = counts()
    print('num words', len(counter))
    word, _ = zip(*counter.most_common())
    index = {token: i for i, token in enumerate(word)}
    tokenize(index, word)
    print('DONE')

if __name__ == "__main__":
    main()
--------------------------------------------------------------------------------
/util.py:
--------------------------------------------------------------------------------
import numpy as np
import tensorflow as tf

# Softmax over a given axis; positions where mask == 0 receive (near-)zero probability.
def softmax(target, axis, mask, epsilon=1e-12, name=None):
    with tf.op_scope([target], name, 'softmax'):
        max_axis = tf.reduce_max(target, axis, keep_dims=True)
        target_exp = tf.exp(target - max_axis) * mask
        normalize = tf.reduce_sum(target_exp, axis, keep_dims=True)
        softmax = target_exp / (normalize + epsilon)
        return softmax

def orthogonal_initializer(scale=1.1):
    ''' From Lasagne and Keras. Reference: Saxe et al., http://arxiv.org/abs/1312.6120
    '''
    print('Warning -- You have opted to use the orthogonal_initializer function')
    def _initializer(shape, dtype=tf.float32, partition_info=None):
        flat_shape = (shape[0], np.prod(shape[1:]))
        a = np.random.normal(0.0, 1.0, flat_shape)
        u, _, v = np.linalg.svd(a, full_matrices=False)
        # pick the one with the correct shape
        q = u if u.shape == flat_shape else v
        q = q.reshape(shape)  # this needs to be corrected to float32
        print('you have initialized one orthogonal matrix.')
        return tf.constant(scale * q[:shape[0], :shape[1]], dtype=tf.float32)
    return _initializer
--------------------------------------------------------------------------------
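
A small illustration of the masked softmax above (not part of the original repository): the sketch below mirrors how model.py builds its attention mask, using a toy score matrix. The shapes and values are invented for the example, and it assumes a TF 1.x environment with util.py importable.

import numpy as np
import tensorflow as tf
from util import softmax

# Toy batch: 1 example, a document of length 3 whose last position is padding, and a query of length 2.
M = tf.constant(np.random.randn(1, 3, 2), dtype=tf.float32)  # pairwise match scores
doc_mask = tf.constant([[1., 1., 0.]])   # 0 marks padded document positions
query_mask = tf.constant([[1., 1.]])
M_mask = tf.matmul(tf.expand_dims(doc_mask, -1), tf.expand_dims(query_mask, 1))

alpha = softmax(M, 1, M_mask)  # distribution over document words, one per query word
beta = softmax(M, 2, M_mask)   # distribution over query words, one per document word

with tf.Session() as sess:
    # The padded document position gets ~0 probability; each column of alpha sums to ~1.
    print(sess.run(alpha))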