├── .gitignore ├── LICENSE ├── README.md ├── attention.py ├── train.py ├── utils.py ├── visualization.html └── visualize.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | env/ 12 | build/ 13 | develop-eggs/ 14 | dist/ 15 | downloads/ 16 | eggs/ 17 | .eggs/ 18 | lib/ 19 | lib64/ 20 | parts/ 21 | sdist/ 22 | var/ 23 | *.egg-info/ 24 | .installed.cfg 25 | *.egg 26 | 27 | # PyInstaller 28 | # Usually these files are written by a python script from a template 29 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 30 | *.manifest 31 | *.spec 32 | 33 | # Installer logs 34 | pip-log.txt 35 | pip-delete-this-directory.txt 36 | 37 | # Unit test / coverage reports 38 | htmlcov/ 39 | .tox/ 40 | .coverage 41 | .coverage.* 42 | .cache 43 | nosetests.xml 44 | coverage.xml 45 | *,cover 46 | .hypothesis/ 47 | 48 | # Translations 49 | *.mo 50 | *.pot 51 | 52 | # Django stuff: 53 | *.log 54 | local_settings.py 55 | 56 | # Flask stuff: 57 | instance/ 58 | .webassets-cache 59 | 60 | # Scrapy stuff: 61 | .scrapy 62 | 63 | # Sphinx documentation 64 | docs/_build/ 65 | 66 | # PyBuilder 67 | target/ 68 | 69 | # IPython Notebook 70 | .ipynb_checkpoints 71 | 72 | # pyenv 73 | .python-version 74 | 75 | # celery beat schedule file 76 | celerybeat-schedule 77 | 78 | # dotenv 79 | .env 80 | 81 | # virtualenv 82 | venv/ 83 | ENV/ 84 | 85 | # Spyder project settings 86 | .spyderproject 87 | 88 | # Rope project settings 89 | .ropeproject 90 | 91 | # Tensorboard logs 92 | logdir/ 93 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2017 Ilya Ivanov 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | Tensorflow implementation of attention mechanism for text classification tasks. 2 | Inspired by "Hierarchical Attention Networks for Document Classification", Zichao Yang et al. (http://www.aclweb.org/anthology/N16-1174). 
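A minimal usage sketch of the `attention` layer on top of a bidirectional GRU (TensorFlow 1.x; the placeholder shape and the sizes below are illustrative — `train.py` contains the full training example):

```python
import tensorflow as tf
from tensorflow.contrib.rnn import GRUCell
from tensorflow.python.ops.rnn import bidirectional_dynamic_rnn as bi_rnn

from attention import attention

HIDDEN_SIZE, ATTENTION_SIZE = 150, 50                     # illustrative sizes
embedded = tf.placeholder(tf.float32, [None, 250, 100])   # (batch, time, embedding_dim)

# Bi-GRU over the embedded sequence; `rnn_outputs` is a (forward, backward) tuple.
rnn_outputs, _ = bi_rnn(GRUCell(HIDDEN_SIZE), GRUCell(HIDDEN_SIZE),
                        inputs=embedded, dtype=tf.float32)

# Attention pools the outputs over time: `output` is (batch, 2 * HIDDEN_SIZE),
# `alphas` is (batch, time) and sums to 1 over the time axis.
output, alphas = attention(rnn_outputs, ATTENTION_SIZE, return_alphas=True)
```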
3 | 
4 | ### Requirements
5 | - Python >= 2.6
6 | - TensorFlow >= 1.0
7 | - Keras (for the IMDB dataset)
8 | - tqdm
9 | 
10 | To view a visualization example, visit http://htmlpreview.github.io/?https://github.com/ilivans/tf-rnn-attention/blob/master/visualization.html
11 | 
12 | My bachelor's thesis on sentiment classification of Russian texts using a Bi-RNN with an attention mechanism: https://github.com/ilivans/attention-sentiment
13 | 
--------------------------------------------------------------------------------
/attention.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 | 
3 | 
4 | def attention(inputs, attention_size, time_major=False, return_alphas=False):
5 |     """
6 |     Attention mechanism layer which reduces RNN/Bi-RNN outputs with an attention vector.
7 | 
8 |     The idea was proposed in the article by Z. Yang et al., "Hierarchical Attention Networks
9 |     for Document Classification", 2016: http://www.aclweb.org/anthology/N16-1174.
10 |     Variable notation is also inherited from the article.
11 | 
12 |     Args:
13 |         inputs: The Attention inputs.
14 |             Matches outputs of RNN/Bi-RNN layer (not final state):
15 |                 In case of RNN, this must be RNN outputs `Tensor`:
16 |                     If time_major == False (default), this must be a tensor of shape:
17 |                         `[batch_size, max_time, cell.output_size]`.
18 |                     If time_major == True, this must be a tensor of shape:
19 |                         `[max_time, batch_size, cell.output_size]`.
20 |                 In case of Bidirectional RNN, this must be a tuple (outputs_fw, outputs_bw) containing the forward and
21 |                 the backward RNN outputs `Tensor`.
22 |                     If time_major == False (default),
23 |                         outputs_fw is a `Tensor` shaped:
24 |                         `[batch_size, max_time, cell_fw.output_size]`
25 |                         and outputs_bw is a `Tensor` shaped:
26 |                         `[batch_size, max_time, cell_bw.output_size]`.
27 |                     If time_major == True,
28 |                         outputs_fw is a `Tensor` shaped:
29 |                         `[max_time, batch_size, cell_fw.output_size]`
30 |                         and outputs_bw is a `Tensor` shaped:
31 |                         `[max_time, batch_size, cell_bw.output_size]`.
32 |         attention_size: Linear size of the Attention weights.
33 |         time_major: The shape format of the `inputs` Tensors.
34 |             If true, these `Tensors` must be shaped `[max_time, batch_size, depth]`.
35 |             If false, these `Tensors` must be shaped `[batch_size, max_time, depth]`.
36 |             Using `time_major = True` is a bit more efficient because it avoids
37 |             transposes at the beginning and end of the RNN calculation. However,
38 |             most TensorFlow data is batch-major, so by default this function
39 |             accepts input and emits output in batch-major form.
40 |         return_alphas: Whether to return the attention coefficients variable along with the layer's output.
41 |             Used for visualization purposes.
42 |     Returns:
43 |         The Attention output `Tensor` (and, if return_alphas is True, also the alphas `Tensor` of shape `[batch_size, max_time]`).
44 |         In case of RNN, this will be a `Tensor` shaped:
45 |             `[batch_size, cell.output_size]`.
46 |         In case of Bidirectional RNN, this will be a `Tensor` shaped:
47 |             `[batch_size, cell_fw.output_size + cell_bw.output_size]`.
48 |     """
49 | 
50 |     if isinstance(inputs, tuple):
51 |         # In case of Bi-RNN, concatenate the forward and the backward RNN outputs.
52 |         inputs = tf.concat(inputs, 2)
53 | 
54 |     if time_major:
55 |         # (T,B,D) => (B,T,D)
56 |         inputs = tf.transpose(inputs, [1, 0, 2])
57 | 
58 |     hidden_size = inputs.shape[2].value  # D value - hidden size of the RNN layer
59 | 
60 |     initializer = tf.random_normal_initializer(stddev=0.1)
61 | 
62 |     # Trainable parameters
63 |     w_omega = tf.get_variable(name="w_omega", shape=[hidden_size, attention_size], initializer=initializer)
64 |     b_omega = tf.get_variable(name="b_omega", shape=[attention_size], initializer=initializer)
65 |     u_omega = tf.get_variable(name="u_omega", shape=[attention_size], initializer=initializer)
66 | 
67 |     with tf.name_scope('v'):
68 |         # Applying fully connected layer with non-linear activation to each of the B*T timestamps;
69 |         # the shape of `v` is (B,T,D)*(D,A)=(B,T,A), where A=attention_size
70 |         v = tf.tanh(tf.tensordot(inputs, w_omega, axes=1) + b_omega)
71 | 
72 |     # For each of the timestamps its vector of size A from `v` is reduced with `u` vector
73 |     vu = tf.tensordot(v, u_omega, axes=1, name='vu')  # (B,T) shape
74 |     alphas = tf.nn.softmax(vu, name='alphas')         # (B,T) shape
75 | 
76 |     # Output of (Bi-)RNN is reduced with attention vector; the result has (B,D) shape
77 |     output = tf.reduce_sum(inputs * tf.expand_dims(alphas, -1), 1)
78 | 
79 |     if not return_alphas:
80 |         return output
81 |     else:
82 |         return output, alphas
83 | 
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/python
2 | """
3 | Toy example of attention layer use
4 | 
5 | Train a bidirectional RNN (GRU) on the IMDB dataset (binary classification)
6 | Learning and hyper-parameters were not tuned; the script serves as an example
7 | """
8 | from __future__ import print_function, division
9 | 
10 | import numpy as np
11 | import tensorflow as tf
12 | from keras.datasets import imdb
13 | from tensorflow.contrib.rnn import GRUCell
14 | from tensorflow.python.ops.rnn import bidirectional_dynamic_rnn as bi_rnn
15 | from tqdm import tqdm
16 | 
17 | from attention import attention
18 | from utils import get_vocabulary_size, fit_in_vocabulary, zero_pad, batch_generator
19 | 
20 | NUM_WORDS = 10000
21 | INDEX_FROM = 3
22 | SEQUENCE_LENGTH = 250
23 | EMBEDDING_DIM = 100
24 | HIDDEN_SIZE = 150
25 | ATTENTION_SIZE = 50
26 | KEEP_PROB = 0.8
27 | BATCH_SIZE = 256
28 | NUM_EPOCHS = 3  # Model easily overfits without pre-trained word embeddings, so we train for only a few epochs
29 | DELTA = 0.5
30 | MODEL_PATH = './model'
31 | 
32 | # Load the data set
33 | (X_train, y_train), (X_test, y_test) = imdb.load_data(num_words=NUM_WORDS, index_from=INDEX_FROM)
34 | 
35 | # Sequences pre-processing
36 | vocabulary_size = get_vocabulary_size(X_train)
37 | X_test = fit_in_vocabulary(X_test, vocabulary_size)
38 | X_train = zero_pad(X_train, SEQUENCE_LENGTH)
39 | X_test = zero_pad(X_test, SEQUENCE_LENGTH)
40 | 
41 | # Input placeholders
42 | with tf.name_scope('Inputs'):
43 |     batch_ph = tf.placeholder(tf.int32, [None, SEQUENCE_LENGTH], name='batch_ph')
44 |     target_ph = tf.placeholder(tf.float32, [None], name='target_ph')
45 |     seq_len_ph = tf.placeholder(tf.int32, [None], name='seq_len_ph')
46 |     keep_prob_ph = tf.placeholder(tf.float32, name='keep_prob_ph')
47 | 
48 | # Embedding layer
49 | with tf.name_scope('Embedding_layer'):
50 |     embeddings_var = tf.Variable(tf.random_uniform([vocabulary_size, EMBEDDING_DIM], -1.0, 1.0), trainable=True)
51 |     tf.summary.histogram('embeddings_var', embeddings_var)
52 | 
batch_embedded = tf.nn.embedding_lookup(embeddings_var, batch_ph) 53 | 54 | # (Bi-)RNN layer(-s) 55 | rnn_outputs, _ = bi_rnn(GRUCell(HIDDEN_SIZE), GRUCell(HIDDEN_SIZE), 56 | inputs=batch_embedded, sequence_length=seq_len_ph, dtype=tf.float32) 57 | tf.summary.histogram('RNN_outputs', rnn_outputs) 58 | 59 | # Attention layer 60 | with tf.name_scope('Attention_layer'): 61 | attention_output, alphas = attention(rnn_outputs, ATTENTION_SIZE, return_alphas=True) 62 | tf.summary.histogram('alphas', alphas) 63 | 64 | # Dropout 65 | drop = tf.nn.dropout(attention_output, keep_prob_ph) 66 | 67 | # Fully connected layer 68 | with tf.name_scope('Fully_connected_layer'): 69 | W = tf.Variable(tf.truncated_normal([HIDDEN_SIZE * 2, 1], stddev=0.1)) # Hidden size is multiplied by 2 for Bi-RNN 70 | b = tf.Variable(tf.constant(0., shape=[1])) 71 | y_hat = tf.nn.xw_plus_b(drop, W, b) 72 | y_hat = tf.squeeze(y_hat) 73 | tf.summary.histogram('W', W) 74 | 75 | with tf.name_scope('Metrics'): 76 | # Cross-entropy loss and optimizer initialization 77 | loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=y_hat, labels=target_ph)) 78 | tf.summary.scalar('loss', loss) 79 | optimizer = tf.train.AdamOptimizer(learning_rate=1e-3).minimize(loss) 80 | 81 | # Accuracy metric 82 | accuracy = tf.reduce_mean(tf.cast(tf.equal(tf.round(tf.sigmoid(y_hat)), target_ph), tf.float32)) 83 | tf.summary.scalar('accuracy', accuracy) 84 | 85 | merged = tf.summary.merge_all() 86 | 87 | # Batch generators 88 | train_batch_generator = batch_generator(X_train, y_train, BATCH_SIZE) 89 | test_batch_generator = batch_generator(X_test, y_test, BATCH_SIZE) 90 | 91 | train_writer = tf.summary.FileWriter('./logdir/train', accuracy.graph) 92 | test_writer = tf.summary.FileWriter('./logdir/test', accuracy.graph) 93 | 94 | session_conf = tf.ConfigProto(gpu_options=tf.GPUOptions(allow_growth=True)) 95 | 96 | saver = tf.train.Saver() 97 | 98 | if __name__ == "__main__": 99 | with tf.Session(config=session_conf) as sess: 100 | sess.run(tf.global_variables_initializer()) 101 | print("Start learning...") 102 | for epoch in range(NUM_EPOCHS): 103 | loss_train = 0 104 | loss_test = 0 105 | accuracy_train = 0 106 | accuracy_test = 0 107 | 108 | print("epoch: {}\t".format(epoch), end="") 109 | 110 | # Training 111 | num_batches = X_train.shape[0] // BATCH_SIZE 112 | for b in tqdm(range(num_batches)): 113 | x_batch, y_batch = next(train_batch_generator) 114 | seq_len = np.array([list(x).index(0) + 1 for x in x_batch]) # actual lengths of sequences 115 | loss_tr, acc, _, summary = sess.run([loss, accuracy, optimizer, merged], 116 | feed_dict={batch_ph: x_batch, 117 | target_ph: y_batch, 118 | seq_len_ph: seq_len, 119 | keep_prob_ph: KEEP_PROB}) 120 | accuracy_train += acc 121 | loss_train = loss_tr * DELTA + loss_train * (1 - DELTA) 122 | train_writer.add_summary(summary, b + num_batches * epoch) 123 | accuracy_train /= num_batches 124 | 125 | # Testing 126 | num_batches = X_test.shape[0] // BATCH_SIZE 127 | for b in tqdm(range(num_batches)): 128 | x_batch, y_batch = next(test_batch_generator) 129 | seq_len = np.array([list(x).index(0) + 1 for x in x_batch]) # actual lengths of sequences 130 | loss_test_batch, acc, summary = sess.run([loss, accuracy, merged], 131 | feed_dict={batch_ph: x_batch, 132 | target_ph: y_batch, 133 | seq_len_ph: seq_len, 134 | keep_prob_ph: 1.0}) 135 | accuracy_test += acc 136 | loss_test += loss_test_batch 137 | test_writer.add_summary(summary, b + num_batches * epoch) 138 | accuracy_test /= num_batches 139 | loss_test 
/= num_batches 140 | 141 | print("loss: {:.3f}, val_loss: {:.3f}, acc: {:.3f}, val_acc: {:.3f}".format( 142 | loss_train, loss_test, accuracy_train, accuracy_test 143 | )) 144 | train_writer.close() 145 | test_writer.close() 146 | saver.save(sess, MODEL_PATH) 147 | print("Run 'tensorboard --logdir=./logdir' to checkout tensorboard logs.") 148 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | 3 | import numpy as np 4 | 5 | 6 | def zero_pad(X, seq_len): 7 | return np.array([x[:seq_len - 1] + [0] * max(seq_len - len(x), 1) for x in X]) 8 | 9 | 10 | def get_vocabulary_size(X): 11 | return max([max(x) for x in X]) + 1 # plus the 0th word 12 | 13 | 14 | def fit_in_vocabulary(X, voc_size): 15 | return [[w for w in x if w < voc_size] for x in X] 16 | 17 | 18 | def batch_generator(X, y, batch_size): 19 | """Primitive batch generator 20 | """ 21 | size = X.shape[0] 22 | X_copy = X.copy() 23 | y_copy = y.copy() 24 | indices = np.arange(size) 25 | np.random.shuffle(indices) 26 | X_copy = X_copy[indices] 27 | y_copy = y_copy[indices] 28 | i = 0 29 | while True: 30 | if i + batch_size <= size: 31 | yield X_copy[i:i + batch_size], y_copy[i:i + batch_size] 32 | i += batch_size 33 | else: 34 | i = 0 35 | indices = np.arange(size) 36 | np.random.shuffle(indices) 37 | X_copy = X_copy[indices] 38 | y_copy = y_copy[indices] 39 | continue 40 | 41 | 42 | if __name__ == "__main__": 43 | # Test batch generator 44 | gen = batch_generator(np.array(['a', 'b', 'c', 'd']), np.array([1, 2, 3, 4]), 2) 45 | for _ in range(8): 46 | xx, yy = next(gen) 47 | print(xx, yy) 48 | -------------------------------------------------------------------------------- /visualization.html: -------------------------------------------------------------------------------- 1 | how 2 | his 3 | :UNK: 4 | evolved 5 | as 6 | both 7 | man 8 | and 9 | ape 10 | was 11 | outstanding 12 | not 13 | to 14 | mention 15 | the 16 | scenery 17 | of 18 | the 19 | film 20 | christopher 21 | :UNK: 22 | was 23 | astonishing 24 | as 25 | lord 26 | of 27 | :UNK: 28 | christopher 29 | is 30 | the 31 | soul 32 | to 33 | this 34 | masterpiece 35 | i 36 | became 37 | so 38 | with 39 | his 40 | performance 41 | i 42 | could 43 | feel 44 | my 45 | heart 46 | :UNK: 47 | the 48 | of 49 | the 50 | movie 51 | still 52 | moves 53 | me 54 | to 55 | this 56 | day 57 | his 58 | portrayal 59 | of 60 | john 61 | was 62 | oscar 63 | worthy 64 | as 65 | he 66 | should 67 | have 68 | been 69 | nominated 70 | for 71 | it 72 | -------------------------------------------------------------------------------- /visualize.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | """ 3 | Example of attention coefficients visualization 4 | 5 | Uses saved model, so it should be executed after train.py 6 | """ 7 | from train import * 8 | 9 | saver = tf.train.Saver() 10 | 11 | # Calculate alpha coefficients for the first test example 12 | with tf.Session() as sess: 13 | saver.restore(sess, MODEL_PATH) 14 | 15 | x_batch_test, y_batch_test = X_test[:1], y_test[:1] 16 | seq_len_test = np.array([list(x).index(0) + 1 for x in x_batch_test]) 17 | alphas_test = sess.run([alphas], feed_dict={batch_ph: x_batch_test, target_ph: y_batch_test, 18 | seq_len_ph: seq_len_test, keep_prob_ph: 1.0}) 19 | alphas_values = alphas_test[0][0] 20 | 21 | # Build correct mapping from word to index 
and inverse
22 |     word_index = imdb.get_word_index()
23 |     word_index = {word: index + INDEX_FROM for word, index in word_index.items()}
24 |     word_index[":PAD:"] = 0
25 |     word_index[":START:"] = 1
26 |     word_index[":UNK:"] = 2
27 |     index_word = {value: key for key, value in word_index.items()}
28 |     # Represent the sample by words rather than indices
29 |     words = list(map(index_word.get, x_batch_test[0]))
30 | 
31 |     # Save visualization as HTML: each word is highlighted in proportion to its attention weight
32 |     with open("visualization.html", "w") as html_file:
33 |         for word, alpha in zip(words, alphas_values / alphas_values.max()):
34 |             if word == ":START:":
35 |                 continue
36 |             elif word == ":PAD:":
37 |                 break
38 |             html_file.write('<font style="background: rgba(255, 255, 0, %f)">%s</font>\n' % (alpha, word))
39 | 
40 | print('\nOpen visualization.html to check out the attention coefficients visualization.')
41 | 
--------------------------------------------------------------------------------
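For reference, the computation performed by `attention.py` can be written out in plain NumPy. The sketch below is not part of the repository; `attention_numpy` and the sizes `B, T, D, A` are made up purely to check the shapes of the same equations (a tanh scoring layer, a softmax over timesteps, and a weighted sum of the RNN outputs):

```python
import numpy as np

def attention_numpy(inputs, w_omega, b_omega, u_omega):
    """NumPy sketch of attention.py; `inputs` is (B, T, D)."""
    v = np.tanh(np.tensordot(inputs, w_omega, axes=1) + b_omega)   # (B, T, A)
    vu = np.tensordot(v, u_omega, axes=1)                          # (B, T) scores
    exp = np.exp(vu - vu.max(axis=1, keepdims=True))               # numerically stable softmax
    alphas = exp / exp.sum(axis=1, keepdims=True)                  # (B, T), each row sums to 1
    output = (inputs * alphas[..., None]).sum(axis=1)              # (B, D) weighted sum over time
    return output, alphas

B, T, D, A = 2, 5, 300, 50  # e.g. D = 2 * HIDDEN_SIZE for the Bi-GRU in train.py
rng = np.random.RandomState(0)
out, al = attention_numpy(rng.randn(B, T, D),
                          rng.randn(D, A) * 0.1, rng.randn(A) * 0.1, rng.randn(A) * 0.1)
assert out.shape == (B, D) and al.shape == (B, T)
```

With D = 2 * HIDDEN_SIZE, the (B, 300) output matches the attention output that train.py feeds into its fully connected layer, and `alphas` corresponds to the coefficients that visualize.py turns into the HTML highlighting.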