├── .gitignore
├── LICENSE
├── README.md
├── attention.py
├── train.py
├── utils.py
├── visualization.html
└── visualize.py
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | env/
12 | build/
13 | develop-eggs/
14 | dist/
15 | downloads/
16 | eggs/
17 | .eggs/
18 | lib/
19 | lib64/
20 | parts/
21 | sdist/
22 | var/
23 | *.egg-info/
24 | .installed.cfg
25 | *.egg
26 |
27 | # PyInstaller
28 | # Usually these files are written by a python script from a template
29 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
30 | *.manifest
31 | *.spec
32 |
33 | # Installer logs
34 | pip-log.txt
35 | pip-delete-this-directory.txt
36 |
37 | # Unit test / coverage reports
38 | htmlcov/
39 | .tox/
40 | .coverage
41 | .coverage.*
42 | .cache
43 | nosetests.xml
44 | coverage.xml
45 | *,cover
46 | .hypothesis/
47 |
48 | # Translations
49 | *.mo
50 | *.pot
51 |
52 | # Django stuff:
53 | *.log
54 | local_settings.py
55 |
56 | # Flask stuff:
57 | instance/
58 | .webassets-cache
59 |
60 | # Scrapy stuff:
61 | .scrapy
62 |
63 | # Sphinx documentation
64 | docs/_build/
65 |
66 | # PyBuilder
67 | target/
68 |
69 | # IPython Notebook
70 | .ipynb_checkpoints
71 |
72 | # pyenv
73 | .python-version
74 |
75 | # celery beat schedule file
76 | celerybeat-schedule
77 |
78 | # dotenv
79 | .env
80 |
81 | # virtualenv
82 | venv/
83 | ENV/
84 |
85 | # Spyder project settings
86 | .spyderproject
87 |
88 | # Rope project settings
89 | .ropeproject
90 |
91 | # Tensorboard logs
92 | logdir/
93 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2017 Ilya Ivanov
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | TensorFlow implementation of an attention mechanism for text classification tasks.
2 | Inspired by "Hierarchical Attention Networks for Document Classification" by Zichao Yang et al. (http://www.aclweb.org/anthology/N16-1174).
3 |
4 | ### Requirements
5 | - Python >= 2.6
6 | - TensorFlow 1.x (>= 1.0; the code relies on `tf.contrib`, which is not available in TensorFlow 2)
7 | - Keras (used only to load the IMDB dataset)
8 | - tqdm
9 |
10 | To view a visualization example, visit http://htmlpreview.github.io/?https://github.com/ilivans/tf-rnn-attention/blob/master/visualization.html
11 |
12 | My bachelor's thesis on sentiment classification of Russian texts using a Bi-RNN with an attention mechanism: https://github.com/ilivans/attention-sentiment
13 |
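14 | ### Usage
15 | The layer computes `v_t = tanh(W h_t + b)` for each RNN output `h_t`, attention weights `alpha_t = softmax(v_t . u)` over the time steps, and returns the weighted sum `sum_t alpha_t h_t`.
16 | A minimal sketch of plugging it into a TensorFlow 1.x graph on top of a bidirectional GRU (the shapes and sizes below are illustrative; `train.py` contains the full IMDB pipeline):
17 | 
18 | ```python
19 | import tensorflow as tf
20 | from tensorflow.contrib.rnn import GRUCell
21 | from tensorflow.python.ops.rnn import bidirectional_dynamic_rnn as bi_rnn
22 | 
23 | from attention import attention
24 | 
25 | # A batch of already-embedded sequences: (batch, time, embedding_dim)
26 | embedded = tf.placeholder(tf.float32, [None, 250, 100])
27 | 
28 | # Bi-RNN outputs come as a (forward, backward) tuple; attention() concatenates them internally
29 | rnn_outputs, _ = bi_rnn(GRUCell(150), GRUCell(150), inputs=embedded, dtype=tf.float32)
30 | 
31 | # Weighted sum over time steps, shape (batch, 2 * 150); alphas holds the per-step weights, shape (batch, time)
32 | sentence_vector, alphas = attention(rnn_outputs, attention_size=50, return_alphas=True)
33 | ```
34 | 
35 | Run `python train.py` to train the toy IMDB classifier, then `python visualize.py` to write the attention weights of a test example into `visualization.html`.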
--------------------------------------------------------------------------------
/attention.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 |
3 |
4 | def attention(inputs, attention_size, time_major=False, return_alphas=False):
5 | """
6 |     Attention mechanism layer which reduces RNN/Bi-RNN outputs with an attention vector.
7 |
8 | The idea was proposed in the article by Z. Yang et al., "Hierarchical Attention Networks
9 | for Document Classification", 2016: http://www.aclweb.org/anthology/N16-1174.
10 |     Variable notation is also inherited from the article.
11 |
12 | Args:
13 | inputs: The Attention inputs.
14 | Matches outputs of RNN/Bi-RNN layer (not final state):
15 | In case of RNN, this must be RNN outputs `Tensor`:
16 | If time_major == False (default), this must be a tensor of shape:
17 | `[batch_size, max_time, cell.output_size]`.
18 | If time_major == True, this must be a tensor of shape:
19 | `[max_time, batch_size, cell.output_size]`.
20 | In case of Bidirectional RNN, this must be a tuple (outputs_fw, outputs_bw) containing the forward and
21 | the backward RNN outputs `Tensor`.
22 | If time_major == False (default),
23 | outputs_fw is a `Tensor` shaped:
24 | `[batch_size, max_time, cell_fw.output_size]`
25 | and outputs_bw is a `Tensor` shaped:
26 | `[batch_size, max_time, cell_bw.output_size]`.
27 | If time_major == True,
28 | outputs_fw is a `Tensor` shaped:
29 | `[max_time, batch_size, cell_fw.output_size]`
30 | and outputs_bw is a `Tensor` shaped:
31 | `[max_time, batch_size, cell_bw.output_size]`.
32 | attention_size: Linear size of the Attention weights.
33 | time_major: The shape format of the `inputs` Tensors.
34 | If true, these `Tensors` must be shaped `[max_time, batch_size, depth]`.
35 | If false, these `Tensors` must be shaped `[batch_size, max_time, depth]`.
36 | Using `time_major = True` is a bit more efficient because it avoids
37 | transposes at the beginning and end of the RNN calculation. However,
38 | most TensorFlow data is batch-major, so by default this function
39 | accepts input and emits output in batch-major form.
40 |         return_alphas: Whether to return the attention coefficients along with the layer's output.
41 |             Used for visualization purposes.
42 | Returns:
43 | The Attention output `Tensor`.
44 | In case of RNN, this will be a `Tensor` shaped:
45 | `[batch_size, cell.output_size]`.
46 | In case of Bidirectional RNN, this will be a `Tensor` shaped:
47 | `[batch_size, cell_fw.output_size + cell_bw.output_size]`.
48 | """
49 |
50 | if isinstance(inputs, tuple):
51 | # In case of Bi-RNN, concatenate the forward and the backward RNN outputs.
52 | inputs = tf.concat(inputs, 2)
53 |
54 | if time_major:
55 | # (T,B,D) => (B,T,D)
56 |         inputs = tf.transpose(inputs, [1, 0, 2])
57 |
58 | hidden_size = inputs.shape[2].value # D value - hidden size of the RNN layer
59 |
60 | initializer = tf.random_normal_initializer(stddev=0.1)
61 |
62 | # Trainable parameters
63 | w_omega = tf.get_variable(name="w_omega", shape=[hidden_size, attention_size], initializer=initializer)
64 | b_omega = tf.get_variable(name="b_omega", shape=[attention_size], initializer=initializer)
65 | u_omega = tf.get_variable(name="u_omega", shape=[attention_size], initializer=initializer)
66 |
67 | with tf.name_scope('v'):
68 | # Applying fully connected layer with non-linear activation to each of the B*T timestamps;
69 | # the shape of `v` is (B,T,D)*(D,A)=(B,T,A), where A=attention_size
70 | v = tf.tanh(tf.tensordot(inputs, w_omega, axes=1) + b_omega)
71 |
72 | # For each of the timestamps its vector of size A from `v` is reduced with `u` vector
73 | vu = tf.tensordot(v, u_omega, axes=1, name='vu') # (B,T) shape
74 | alphas = tf.nn.softmax(vu, name='alphas') # (B,T) shape
75 |
76 | # Output of (Bi-)RNN is reduced with attention vector; the result has (B,D) shape
77 | output = tf.reduce_sum(inputs * tf.expand_dims(alphas, -1), 1)
78 |
79 | if not return_alphas:
80 | return output
81 | else:
82 | return output, alphas
83 |
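84 | 
85 | if __name__ == "__main__":
86 |     # Minimal shape-check sketch with illustrative sizes (batch=4, time=10, hidden=8 per direction):
87 |     # random tensors stand in for the forward and backward outputs of a Bi-RNN.
88 |     fw_outputs = tf.random_normal([4, 10, 8])
89 |     bw_outputs = tf.random_normal([4, 10, 8])
90 |     output, alphas = attention((fw_outputs, bw_outputs), attention_size=5, return_alphas=True)
91 |     with tf.Session() as sess:
92 |         sess.run(tf.global_variables_initializer())
93 |         output_val, alphas_val = sess.run([output, alphas])
94 |         print(output_val.shape)  # expected: (4, 16), i.e. (batch, 2 * hidden)
95 |         print(alphas_val.shape)  # expected: (4, 10), i.e. (batch, time)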
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/python
2 | """
3 | Toy example of the attention layer in use.
4 | 
5 | Trains a bidirectional RNN (GRU) on the IMDB dataset (binary classification).
6 | Learning settings and hyper-parameters were not tuned; the script serves as an example.
7 | """
8 | from __future__ import print_function, division
9 |
10 | import numpy as np
11 | import tensorflow as tf
12 | from keras.datasets import imdb
13 | from tensorflow.contrib.rnn import GRUCell
14 | from tensorflow.python.ops.rnn import bidirectional_dynamic_rnn as bi_rnn
15 | from tqdm import tqdm
16 |
17 | from attention import attention
18 | from utils import get_vocabulary_size, fit_in_vocabulary, zero_pad, batch_generator
19 |
20 | NUM_WORDS = 10000
21 | INDEX_FROM = 3
22 | SEQUENCE_LENGTH = 250
23 | EMBEDDING_DIM = 100
24 | HIDDEN_SIZE = 150
25 | ATTENTION_SIZE = 50
26 | KEEP_PROB = 0.8
27 | BATCH_SIZE = 256
28 | NUM_EPOCHS = 3  # The model easily overfits without pre-trained word embeddings, so we train for only a few epochs
29 | DELTA = 0.5
30 | MODEL_PATH = './model'
31 |
32 | # Load the data set
33 | (X_train, y_train), (X_test, y_test) = imdb.load_data(num_words=NUM_WORDS, index_from=INDEX_FROM)
34 |
35 | # Sequence pre-processing
36 | vocabulary_size = get_vocabulary_size(X_train)
37 | X_test = fit_in_vocabulary(X_test, vocabulary_size)
38 | X_train = zero_pad(X_train, SEQUENCE_LENGTH)
39 | X_test = zero_pad(X_test, SEQUENCE_LENGTH)
40 |
41 | # Different placeholders
42 | with tf.name_scope('Inputs'):
43 | batch_ph = tf.placeholder(tf.int32, [None, SEQUENCE_LENGTH], name='batch_ph')
44 | target_ph = tf.placeholder(tf.float32, [None], name='target_ph')
45 | seq_len_ph = tf.placeholder(tf.int32, [None], name='seq_len_ph')
46 | keep_prob_ph = tf.placeholder(tf.float32, name='keep_prob_ph')
47 |
48 | # Embedding layer
49 | with tf.name_scope('Embedding_layer'):
50 | embeddings_var = tf.Variable(tf.random_uniform([vocabulary_size, EMBEDDING_DIM], -1.0, 1.0), trainable=True)
51 | tf.summary.histogram('embeddings_var', embeddings_var)
52 | batch_embedded = tf.nn.embedding_lookup(embeddings_var, batch_ph)
53 |
54 | # (Bi-)RNN layer(-s)
55 | rnn_outputs, _ = bi_rnn(GRUCell(HIDDEN_SIZE), GRUCell(HIDDEN_SIZE),
56 | inputs=batch_embedded, sequence_length=seq_len_ph, dtype=tf.float32)
57 | tf.summary.histogram('RNN_outputs', rnn_outputs)
58 |
59 | # Attention layer
60 | with tf.name_scope('Attention_layer'):
61 | attention_output, alphas = attention(rnn_outputs, ATTENTION_SIZE, return_alphas=True)
62 | tf.summary.histogram('alphas', alphas)
63 |
64 | # Dropout
65 | drop = tf.nn.dropout(attention_output, keep_prob_ph)
66 |
67 | # Fully connected layer
68 | with tf.name_scope('Fully_connected_layer'):
69 | W = tf.Variable(tf.truncated_normal([HIDDEN_SIZE * 2, 1], stddev=0.1)) # Hidden size is multiplied by 2 for Bi-RNN
70 | b = tf.Variable(tf.constant(0., shape=[1]))
71 | y_hat = tf.nn.xw_plus_b(drop, W, b)
72 | y_hat = tf.squeeze(y_hat)
73 | tf.summary.histogram('W', W)
74 |
75 | with tf.name_scope('Metrics'):
76 | # Cross-entropy loss and optimizer initialization
77 | loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=y_hat, labels=target_ph))
78 | tf.summary.scalar('loss', loss)
79 | optimizer = tf.train.AdamOptimizer(learning_rate=1e-3).minimize(loss)
80 |
81 | # Accuracy metric
82 | accuracy = tf.reduce_mean(tf.cast(tf.equal(tf.round(tf.sigmoid(y_hat)), target_ph), tf.float32))
83 | tf.summary.scalar('accuracy', accuracy)
84 |
85 | merged = tf.summary.merge_all()
86 |
87 | # Batch generators
88 | train_batch_generator = batch_generator(X_train, y_train, BATCH_SIZE)
89 | test_batch_generator = batch_generator(X_test, y_test, BATCH_SIZE)
90 |
91 | train_writer = tf.summary.FileWriter('./logdir/train', accuracy.graph)
92 | test_writer = tf.summary.FileWriter('./logdir/test', accuracy.graph)
93 |
94 | session_conf = tf.ConfigProto(gpu_options=tf.GPUOptions(allow_growth=True))
95 |
96 | saver = tf.train.Saver()
97 |
98 | if __name__ == "__main__":
99 | with tf.Session(config=session_conf) as sess:
100 | sess.run(tf.global_variables_initializer())
101 | print("Start learning...")
102 | for epoch in range(NUM_EPOCHS):
103 | loss_train = 0
104 | loss_test = 0
105 | accuracy_train = 0
106 | accuracy_test = 0
107 |
108 | print("epoch: {}\t".format(epoch), end="")
109 |
110 | # Training
111 | num_batches = X_train.shape[0] // BATCH_SIZE
112 | for b in tqdm(range(num_batches)):
113 | x_batch, y_batch = next(train_batch_generator)
114 | seq_len = np.array([list(x).index(0) + 1 for x in x_batch]) # actual lengths of sequences
115 | loss_tr, acc, _, summary = sess.run([loss, accuracy, optimizer, merged],
116 | feed_dict={batch_ph: x_batch,
117 | target_ph: y_batch,
118 | seq_len_ph: seq_len,
119 | keep_prob_ph: KEEP_PROB})
120 | accuracy_train += acc
121 | loss_train = loss_tr * DELTA + loss_train * (1 - DELTA)
122 | train_writer.add_summary(summary, b + num_batches * epoch)
123 | accuracy_train /= num_batches
124 |
125 | # Testing
126 | num_batches = X_test.shape[0] // BATCH_SIZE
127 | for b in tqdm(range(num_batches)):
128 | x_batch, y_batch = next(test_batch_generator)
129 | seq_len = np.array([list(x).index(0) + 1 for x in x_batch]) # actual lengths of sequences
130 | loss_test_batch, acc, summary = sess.run([loss, accuracy, merged],
131 | feed_dict={batch_ph: x_batch,
132 | target_ph: y_batch,
133 | seq_len_ph: seq_len,
134 | keep_prob_ph: 1.0})
135 | accuracy_test += acc
136 | loss_test += loss_test_batch
137 | test_writer.add_summary(summary, b + num_batches * epoch)
138 | accuracy_test /= num_batches
139 | loss_test /= num_batches
140 |
141 | print("loss: {:.3f}, val_loss: {:.3f}, acc: {:.3f}, val_acc: {:.3f}".format(
142 | loss_train, loss_test, accuracy_train, accuracy_test
143 | ))
144 | train_writer.close()
145 | test_writer.close()
146 | saver.save(sess, MODEL_PATH)
147 |         print("Run 'tensorboard --logdir=./logdir' to check out the TensorBoard logs.")
148 |
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function
2 |
3 | import numpy as np
4 |
5 |
6 | def zero_pad(X, seq_len):
7 |     return np.array([x[:seq_len - 1] + [0] * max(seq_len - len(x), 1) for x in X])  # pad/truncate each row to seq_len; every row ends with at least one 0
8 |
9 |
10 | def get_vocabulary_size(X):
11 | return max([max(x) for x in X]) + 1 # plus the 0th word
12 |
13 |
14 | def fit_in_vocabulary(X, voc_size):
15 | return [[w for w in x if w < voc_size] for x in X]
16 |
17 |
18 | def batch_generator(X, y, batch_size):
19 | """Primitive batch generator
20 | """
21 | size = X.shape[0]
22 | X_copy = X.copy()
23 | y_copy = y.copy()
24 | indices = np.arange(size)
25 | np.random.shuffle(indices)
26 | X_copy = X_copy[indices]
27 | y_copy = y_copy[indices]
28 | i = 0
29 | while True:
30 | if i + batch_size <= size:
31 | yield X_copy[i:i + batch_size], y_copy[i:i + batch_size]
32 | i += batch_size
33 | else:
34 | i = 0
35 | indices = np.arange(size)
36 | np.random.shuffle(indices)
37 | X_copy = X_copy[indices]
38 | y_copy = y_copy[indices]
39 | continue
40 |
41 |
42 | if __name__ == "__main__":
43 | # Test batch generator
44 | gen = batch_generator(np.array(['a', 'b', 'c', 'd']), np.array([1, 2, 3, 4]), 2)
45 | for _ in range(8):
46 | xx, yy = next(gen)
47 | print(xx, yy)
48 |
--------------------------------------------------------------------------------
/visualization.html:
--------------------------------------------------------------------------------
1 | (Markup not preserved in this plain-text listing; the file highlights each word of a sample IMDB test review in proportion to its attention weight, as written by visualize.py. The review text, with out-of-vocabulary words shown as :UNK:, reads:)
2 | how his :UNK: evolved as both man and ape was outstanding not to mention the scenery of the film christopher :UNK: was astonishing as lord of :UNK: christopher is the soul to this masterpiece i became so with his performance i could feel my heart :UNK: the of the movie still moves me to this day his portrayal of john was oscar worthy as he should have been nominated for it
--------------------------------------------------------------------------------
/visualize.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/python
2 | """
3 | Example of attention-coefficient visualization.
4 | 
5 | Uses the saved model, so it should be executed after train.py.
6 | """
7 | from train import *
8 |
9 | saver = tf.train.Saver()
10 |
11 | # Calculate alpha coefficients for the first test example
12 | with tf.Session() as sess:
13 | saver.restore(sess, MODEL_PATH)
14 |
15 | x_batch_test, y_batch_test = X_test[:1], y_test[:1]
16 | seq_len_test = np.array([list(x).index(0) + 1 for x in x_batch_test])
17 | alphas_test = sess.run([alphas], feed_dict={batch_ph: x_batch_test, target_ph: y_batch_test,
18 | seq_len_ph: seq_len_test, keep_prob_ph: 1.0})
19 | alphas_values = alphas_test[0][0]
20 |
21 | # Build correct mapping from word to index and inverse
22 | word_index = imdb.get_word_index()
23 | word_index = {word: index + INDEX_FROM for word, index in word_index.items()}
24 | word_index[":PAD:"] = 0
25 | word_index[":START:"] = 1
26 | word_index[":UNK:"] = 2
27 | index_word = {value: key for key, value in word_index.items()}
28 | # Represent the sample by words rather than indices
29 | words = list(map(index_word.get, x_batch_test[0]))
30 |
31 | # Save visualization as HTML
32 | with open("visualization.html", "w") as html_file:
33 | for word, alpha in zip(words, alphas_values / alphas_values.max()):
34 | if word == ":START:":
35 | continue
36 | elif word == ":PAD:":
37 | break
38 |         html_file.write('<font style="background: rgba(255, 255, 0, %f)">%s</font>\n' % (alpha, word))  # yellow highlight whose opacity encodes the attention weight
39 |
40 | print('\nOpen visualization.html to check out the attention-coefficient visualization.')
41 |
--------------------------------------------------------------------------------