├── README.md ├── data_util.py ├── LICENSE ├── .gitignore ├── yelp.py ├── yelp_prepare.py ├── bn_lstm_test.py ├── model_components.py ├── bn_lstm.py ├── worker.py └── HAN_model.py /README.md: -------------------------------------------------------------------------------- 1 | # Deep Text Classifier 2 | 3 | Implementation of document classification model described in [Hierarchical Attention Networks for Document Classification (Yang et al., 2016)](https://www.cs.cmu.edu/~diyiy/docs/naacl16.pdf). 4 | 5 | ## How to run 6 | Get Yelp review dataset here: https://www.yelp.com/dataset_challenge 7 | ``` 8 | python3 yelp_prepare.py yelp_academic_dataset_review.json 9 | python3 worker.py --mode=train --device=/gpu:0 --batch-size=30 10 | ``` 11 | 12 | ## Results 13 | I am getting 65% accuracy on a dev set (16% of data) after 3 epochs. Results reported in the paper are 71% on Yelp'15. 14 | No systemic hyperparameter optimization was performed. -------------------------------------------------------------------------------- /data_util.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | def batch(inputs): 5 | batch_size = len(inputs) 6 | 7 | document_sizes = np.array([len(doc) for doc in inputs], dtype=np.int32) 8 | document_size = document_sizes.max() 9 | 10 | sentence_sizes_ = [[len(sent) for sent in doc] for doc in inputs] 11 | sentence_size = max(map(max, sentence_sizes_)) 12 | 13 | b = np.zeros(shape=[batch_size, document_size, sentence_size], dtype=np.int32) # == PAD 14 | 15 | sentence_sizes = np.zeros(shape=[batch_size, document_size], dtype=np.int32) 16 | for i, document in enumerate(inputs): 17 | for j, sentence in enumerate(document): 18 | sentence_sizes[i, j] = sentence_sizes_[i][j] 19 | for k, word in enumerate(sentence): 20 | b[i, j, k] = word 21 | 22 | return b, document_sizes, sentence_sizes -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2017 Matvey Ezhov 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
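To make the padding behaviour of batch() in data_util.py above concrete, here is a small usage example (the word ids are made up); 0 doubles as the PAD id:

```
from data_util import batch

docs = [
    [[7, 8, 9], [4, 5]],  # document with 2 sentences
    [[6]],                # document with 1 sentence
]
b, document_sizes, sentence_sizes = batch(docs)
print(b.shape)          # (2, 2, 3) -- zero-padded to 2 sentences x 3 words
print(document_sizes)   # [2 1]
print(sentence_sizes)   # [[3 2]
                        #  [1 0]]
```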
22 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | env/ 12 | build/ 13 | develop-eggs/ 14 | dist/ 15 | downloads/ 16 | eggs/ 17 | .eggs/ 18 | lib/ 19 | lib64/ 20 | parts/ 21 | sdist/ 22 | var/ 23 | *.egg-info/ 24 | .installed.cfg 25 | *.egg 26 | 27 | # PyInstaller 28 | # Usually these files are written by a python script from a template 29 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 30 | *.manifest 31 | *.spec 32 | 33 | # Installer logs 34 | pip-log.txt 35 | pip-delete-this-directory.txt 36 | 37 | # Unit test / coverage reports 38 | htmlcov/ 39 | .tox/ 40 | .coverage 41 | .coverage.* 42 | .cache 43 | nosetests.xml 44 | coverage.xml 45 | *,cover 46 | .hypothesis/ 47 | 48 | # Translations 49 | *.mo 50 | *.pot 51 | 52 | # Django stuff: 53 | *.log 54 | local_settings.py 55 | 56 | # Flask stuff: 57 | instance/ 58 | .webassets-cache 59 | 60 | # Scrapy stuff: 61 | .scrapy 62 | 63 | # Sphinx documentation 64 | docs/_build/ 65 | 66 | # PyBuilder 67 | target/ 68 | 69 | # IPython Notebook 70 | .ipynb_checkpoints 71 | 72 | # pyenv 73 | .python-version 74 | 75 | # celery beat schedule file 76 | celerybeat-schedule 77 | 78 | # dotenv 79 | .env 80 | 81 | # virtualenv 82 | venv/ 83 | ENV/ 84 | 85 | # Spyder project settings 86 | .spyderproject 87 | 88 | # Rope project settings 89 | .ropeproject 90 | -------------------------------------------------------------------------------- /yelp.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pickle 3 | 4 | train_dir = os.path.join(os.path.curdir, 'yelp') 5 | data_dir = os.path.join(train_dir, 'data') 6 | 7 | for dir in [train_dir, data_dir]: 8 | if not os.path.exists(dir): 9 | os.makedirs(dir) 10 | 11 | trainset_fn = os.path.join(data_dir, 'train.dataset') 12 | devset_fn = os.path.join(data_dir, 'dev.dataset') 13 | testset_fn = os.path.join(data_dir, 'test.dataset') 14 | vocab_fn = os.path.join(data_dir, 'vocab.pickle') 15 | 16 | reserved_tokens = 5 17 | unknown_id = 2 18 | 19 | vocab_size = 50001 20 | 21 | def _read_dataset(fn, review_max_sentences=30, sentence_max_length=30, epochs=1): 22 | c = 0 23 | while 1: 24 | c += 1 25 | if epochs > 0 and c > epochs: 26 | return 27 | print('epoch %s' % c) 28 | with open(fn, 'rb') as f: 29 | try: 30 | while 1: 31 | x, y = pickle.load(f) 32 | 33 | # clip review to specified max lengths 34 | x = x[:review_max_sentences] 35 | x = [sent[:sentence_max_length] for sent in x] 36 | 37 | y -= 1 38 | assert y >= 0 and y <= 4 39 | yield x, y 40 | except EOFError: 41 | continue 42 | 43 | def read_trainset(epochs=1): 44 | return _read_dataset(trainset_fn, epochs=epochs) 45 | 46 | def read_devset(epochs=1): 47 | return _read_dataset(devset_fn, epochs=epochs) 48 | 49 | def read_vocab(): 50 | with open(vocab_fn, 'rb') as f: 51 | return pickle.load(f) 52 | 53 | def read_labels(): 54 | return {i: i for i in range(5)} 55 | -------------------------------------------------------------------------------- /yelp_prepare.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | parser = argparse.ArgumentParser() 3 | parser.add_argument("review_path") 4 | args = parser.parse_args() 5 | 6 | import os 7 | 
import ujson as json 8 | import spacy 9 | import pickle 10 | import random 11 | from tqdm import tqdm 12 | from collections import defaultdict 13 | import numpy as np 14 | from yelp import * 15 | 16 | en = spacy.load('en') 17 | en.pipeline = [en.tagger, en.parser] 18 | 19 | def read_reviews(): 20 | with open(args.review_path, 'rb') as f: 21 | for line in f: 22 | yield json.loads(line) 23 | 24 | def build_word_frequency_distribution(): 25 | path = os.path.join(data_dir, 'word_freq.pickle') 26 | 27 | try: 28 | with open(path, 'rb') as freq_dist_f: 29 | freq_dist_f = pickle.load(freq_dist_f) 30 | print('frequency distribution loaded') 31 | return freq_dist_f 32 | except IOError: 33 | pass 34 | 35 | print('building frequency distribution') 36 | freq = defaultdict(int) 37 | for i, review in enumerate(read_reviews()): 38 | doc = en.tokenizer(review['text']) 39 | for token in doc: 40 | freq[token.orth_] += 1 41 | if i % 10000 == 0: 42 | with open(path, 'wb') as freq_dist_f: 43 | pickle.dump(freq, freq_dist_f) 44 | print('dump at {}'.format(i)) 45 | return freq 46 | 47 | def build_vocabulary(lower=3, n=50000): 48 | try: 49 | with open(vocab_fn, 'rb') as vocab_file: 50 | vocab = pickle.load(vocab_file) 51 | print('vocabulary loaded') 52 | return vocab 53 | except IOError: 54 | print('building vocabulary') 55 | freq = build_word_frequency_distribution() 56 | top_words = list(sorted(freq.items(), key=lambda x: -x[1]))[:n-lower+1] 57 | vocab = {} 58 | i = lower 59 | for w, freq in top_words: 60 | vocab[w] = i 61 | i += 1 62 | with open(vocab_fn, 'wb') as vocab_file: 63 | pickle.dump(vocab, vocab_file) 64 | return vocab 65 | 66 | UNKNOWN = 2 67 | 68 | def make_data(split_points=(0.8, 0.94)): 69 | train_ratio, dev_ratio = split_points 70 | vocab = build_vocabulary() 71 | train_f = open(trainset_fn, 'wb') 72 | dev_f = open(devset_fn, 'wb') 73 | test_f = open(testset_fn, 'wb') 74 | 75 | try: 76 | for review in tqdm(read_reviews()): 77 | x = [] 78 | for sent in en(review['text']).sents: 79 | x.append([vocab.get(tok.orth_, UNKNOWN) for tok in sent]) 80 | y = review['stars'] 81 | 82 | r = random.random() 83 | if r < train_ratio: 84 | f = train_f 85 | elif r < dev_ratio: 86 | f = dev_f 87 | else: 88 | f = test_f 89 | pickle.dump((x, y), f) 90 | except KeyboardInterrupt: 91 | pass 92 | 93 | train_f.close() 94 | dev_f.close() 95 | test_f.close() 96 | 97 | if __name__ == '__main__': 98 | make_data() -------------------------------------------------------------------------------- /bn_lstm_test.py: -------------------------------------------------------------------------------- 1 | import time 2 | import uuid 3 | import os 4 | import numpy as np 5 | import tensorflow as tf 6 | from tensorflow.python.ops.rnn import dynamic_rnn 7 | from bn_lstm import LSTMCell, BNLSTMCell, orthogonal_initializer 8 | from tensorflow.examples.tutorials.mnist import input_data 9 | 10 | batch_size = 100 11 | hidden_size = 100 12 | 13 | mnist = input_data.read_data_sets("MNIST_data/", one_hot=True) 14 | 15 | x = tf.placeholder(tf.float32, [None, 784]) 16 | training = tf.placeholder(tf.bool) 17 | 18 | x_inp = tf.expand_dims(x, -1) 19 | lstm = BNLSTMCell(hidden_size, training) #LSTMCell(hidden_size) 20 | 21 | #c, h 22 | initialState = ( 23 | tf.random_normal([batch_size, hidden_size], stddev=0.1), 24 | tf.random_normal([batch_size, hidden_size], stddev=0.1)) 25 | 26 | outputs, state = dynamic_rnn(lstm, x_inp, initial_state=initialState, dtype=tf.float32) 27 | 28 | _, final_hidden = state 29 | 30 | W = tf.get_variable('W', [hidden_size, 
10], initializer=orthogonal_initializer()) 31 | b = tf.get_variable('b', [10]) 32 | 33 | y = tf.nn.softmax(tf.matmul(final_hidden, W) + b) 34 | 35 | y_ = tf.placeholder(tf.float32, [None, 10]) 36 | 37 | cross_entropy = tf.reduce_mean(-tf.reduce_sum(y_ * tf.log(y), reduction_indices=[1])) 38 | 39 | optimizer = tf.train.AdamOptimizer() 40 | gvs = optimizer.compute_gradients(cross_entropy) 41 | capped_gvs = [(None if grad is None else tf.clip_by_value(grad, -1., 1.), var) for grad, var in gvs] 42 | train_step = optimizer.apply_gradients(capped_gvs) 43 | 44 | correct_prediction = tf.equal(tf.argmax(y,1), tf.argmax(y_,1)) 45 | accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32)) 46 | 47 | # Summaries 48 | tf.summary.scalar("accuracy", accuracy) 49 | tf.summary.scalar("xe_loss", cross_entropy) 50 | for (grad, var), (capped_grad, _) in zip(gvs, capped_gvs): 51 | if grad is not None: 52 | tf.summary.histogram('grad/{}'.format(var.name), capped_grad) 53 | tf.summary.histogram('capped_fraction/{}'.format(var.name), 54 | tf.nn.zero_fraction(grad - capped_grad)) 55 | tf.summary.histogram('weight/{}'.format(var.name), var) 56 | 57 | merged = tf.summary.merge_all() 58 | 59 | init = tf.global_variables_initializer() 60 | 61 | sess = tf.Session() 62 | sess.run(init) 63 | 64 | logdir = 'logs/' + str(uuid.uuid4()) 65 | os.makedirs(logdir) 66 | print('logging to ' + logdir) 67 | writer = tf.summary.FileWriter(logdir, sess.graph) 68 | 69 | current_time = time.time() 70 | print("Using population statistics (training: False) at test time gives worse results than batch statistics") 71 | 72 | for i in range(100000): 73 | batch_xs, batch_ys = mnist.train.next_batch(batch_size) 74 | loss, _ = sess.run([cross_entropy, train_step], feed_dict={x: batch_xs, y_: batch_ys, training: True}) 75 | step_time = time.time() - current_time 76 | current_time = time.time() 77 | if i % 100 == 0: 78 | batch_xs, batch_ys = mnist.validation.next_batch(batch_size) 79 | summary_str = sess.run(merged, feed_dict={x: batch_xs, y_: batch_ys, training: False}) 80 | writer.add_summary(summary_str, i) 81 | print(loss, step_time) 82 | --------------------------------------------------------------------------------
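The compute_gradients / clip_by_value / apply_gradients pattern in bn_lstm_test.py above also works as a standalone recipe. Below is a minimal, self-contained sketch of the same per-element gradient clipping on a toy least-squares problem (TF 1.x API; the toy placeholders and the variable name `w_toy` are illustrative, not part of the repo):

```
import tensorflow as tf

x = tf.placeholder(tf.float32, [None])
y = tf.placeholder(tf.float32, [None])
w = tf.get_variable('w_toy', [], initializer=tf.zeros_initializer())
loss = tf.reduce_mean(tf.square(w * x - y))

optimizer = tf.train.AdamOptimizer(0.1)
gvs = optimizer.compute_gradients(loss)
# clip each gradient element to [-1, 1]; keep (None, var) pairs untouched
capped_gvs = [(None if g is None else tf.clip_by_value(g, -1., 1.), v) for g, v in gvs]
train_step = optimizer.apply_gradients(capped_gvs)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    for _ in range(200):
        sess.run(train_step, feed_dict={x: [1., 2., 3.], y: [2., 4., 6.]})
    print(sess.run(w))  # approaches 2.0
```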
/model_components.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import tensorflow.contrib.layers as layers 3 | 4 | try: 5 | from tensorflow.contrib.rnn import LSTMStateTuple 6 | except ImportError: 7 | LSTMStateTuple = tf.nn.rnn_cell.LSTMStateTuple 8 | 9 | 10 | def bidirectional_rnn(cell_fw, cell_bw, inputs_embedded, input_lengths, 11 | scope=None): 12 | """Bidirectional RNN with concatenated outputs and states""" 13 | with tf.variable_scope(scope or "birnn") as scope: 14 | ((fw_outputs, 15 | bw_outputs), 16 | (fw_state, 17 | bw_state)) = ( 18 | tf.nn.bidirectional_dynamic_rnn(cell_fw=cell_fw, 19 | cell_bw=cell_bw, 20 | inputs=inputs_embedded, 21 | sequence_length=input_lengths, 22 | dtype=tf.float32, 23 | swap_memory=True, 24 | scope=scope)) 25 | outputs = tf.concat((fw_outputs, bw_outputs), 2) 26 | 27 | def concatenate_state(fw_state, bw_state): 28 | if isinstance(fw_state, LSTMStateTuple): 29 | state_c = tf.concat( 30 | (fw_state.c, bw_state.c), 1, name='bidirectional_concat_c') 31 | state_h = tf.concat( 32 | (fw_state.h, bw_state.h), 1, name='bidirectional_concat_h') 33 | state = LSTMStateTuple(c=state_c, h=state_h) 34 | return state 35 | elif isinstance(fw_state, tf.Tensor): 36 | state = tf.concat((fw_state, bw_state), 1, 37 | name='bidirectional_concat') 38 | return state 39 | elif (isinstance(fw_state, tuple) and 40 | isinstance(bw_state, tuple) and 41 | len(fw_state) == len(bw_state)): 42 | # multilayer 43 | state = tuple(concatenate_state(fw, bw) 44 | for fw, bw in zip(fw_state, bw_state)) 45 | return state 46 | 47 | else: 48 | raise ValueError( 49 | 'unknown state type: {}'.format((fw_state, bw_state))) 50 | 51 | 52 | state = concatenate_state(fw_state, bw_state) 53 | return outputs, state 54 | 55 | 56 | def task_specific_attention(inputs, output_size, 57 | initializer=layers.xavier_initializer(), 58 | activation_fn=tf.tanh, scope=None): 59 | """ 60 | Performs task-specific attention reduction, using learned 61 | attention context vector (constant within task of interest). 62 | 63 | Args: 64 | inputs: Tensor of shape [batch_size, units, input_size] 65 | `input_size` must be static (known) 66 | `units` axis will be attended over (reduced from output) 67 | `batch_size` will be preserved 68 | output_size: Size of output's inner (feature) dimension 69 | 70 | Returns: 71 | outputs: Tensor of shape [batch_size, output_size]. 72 | """ 73 | assert len(inputs.get_shape()) == 3 and inputs.get_shape()[-1].value is not None 74 | 75 | with tf.variable_scope(scope or 'attention') as scope: 76 | attention_context_vector = tf.get_variable(name='attention_context_vector', 77 | shape=[output_size], 78 | initializer=initializer, 79 | dtype=tf.float32) 80 | input_projection = layers.fully_connected(inputs, output_size, 81 | activation_fn=activation_fn, 82 | scope=scope) 83 | 84 | vector_attn = tf.reduce_sum(tf.multiply(input_projection, attention_context_vector), axis=2, keep_dims=True) 85 | attention_weights = tf.nn.softmax(vector_attn, dim=1) 86 | weighted_projection = tf.multiply(input_projection, attention_weights) 87 | 88 | outputs = tf.reduce_sum(weighted_projection, axis=1) 89 | 90 | return outputs 91 | --------------------------------------------------------------------------------
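To make the attention reduction in task_specific_attention concrete: each timestep is projected, scored against the learned context vector, the scores are softmaxed over the `units` axis, and the weighted vectors are summed. Note that this implementation sums the weighted projections rather than the raw RNN outputs. A NumPy sketch of the same arithmetic, with made-up shapes and random stand-ins for the learned parameters:

```
import numpy as np

def softmax(z, axis):
    z = z - z.max(axis=axis, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)

rng = np.random.RandomState(0)
batch, units, input_size, output_size = 2, 5, 8, 4

inputs = rng.randn(batch, units, input_size)   # e.g. word-level encoder outputs
W = rng.randn(input_size, output_size)         # stands in for the fully_connected projection
context = rng.randn(output_size)               # stands in for the attention context vector

projection = np.tanh(inputs @ W)                             # [batch, units, output_size]
scores = (projection * context).sum(axis=2, keepdims=True)   # [batch, units, 1]
weights = softmax(scores, axis=1)                            # attention over the `units` axis
outputs = (projection * weights).sum(axis=1)                 # [batch, output_size]
print(outputs.shape)  # (2, 4)
```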
72 | """ 73 | assert len(inputs.get_shape()) == 3 and inputs.get_shape()[-1].value is not None 74 | 75 | with tf.variable_scope(scope or 'attention') as scope: 76 | attention_context_vector = tf.get_variable(name='attention_context_vector', 77 | shape=[output_size], 78 | initializer=initializer, 79 | dtype=tf.float32) 80 | input_projection = layers.fully_connected(inputs, output_size, 81 | activation_fn=activation_fn, 82 | scope=scope) 83 | 84 | vector_attn = tf.reduce_sum(tf.multiply(input_projection, attention_context_vector), axis=2, keep_dims=True) 85 | attention_weights = tf.nn.softmax(vector_attn, dim=1) 86 | weighted_projection = tf.multiply(input_projection, attention_weights) 87 | 88 | outputs = tf.reduce_sum(weighted_projection, axis=1) 89 | 90 | return outputs 91 | -------------------------------------------------------------------------------- /bn_lstm.py: -------------------------------------------------------------------------------- 1 | # borrowed from https://github.com/OlavHN/bnlstm, updated for r1.0 2 | 3 | import math 4 | import numpy as np 5 | import tensorflow as tf 6 | 7 | try: 8 | from tensorflow.contrib.rnn import RNNCell 9 | except ImportError: 10 | RNNCell = tf.nn.rnn_cell.RNNCel 11 | 12 | 13 | class LSTMCell(RNNCell): 14 | """Vanilla LSTM implemented with same initializations as BN-LSTM""" 15 | def __init__(self, num_units): 16 | self.num_units = num_units 17 | 18 | @property 19 | def state_size(self): 20 | return (self.num_units, self.num_units) 21 | 22 | @property 23 | def output_size(self): 24 | return self.num_units 25 | 26 | def __call__(self, x, state, scope=None): 27 | with tf.variable_scope(scope or type(self).__name__): 28 | c, h = state 29 | 30 | # Keep W_xh and W_hh separate here as well to reuse initialization methods 31 | x_size = x.get_shape().as_list()[1] 32 | W_xh = tf.get_variable('W_xh', 33 | [x_size, 4 * self.num_units], 34 | initializer=orthogonal_initializer()) 35 | W_hh = tf.get_variable('W_hh', 36 | [self.num_units, 4 * self.num_units], 37 | initializer=bn_lstm_identity_initializer(0.95)) 38 | bias = tf.get_variable('bias', [4 * self.num_units]) 39 | 40 | # hidden = tf.matmul(x, W_xh) + tf.matmul(h, W_hh) + bias 41 | # improve speed by concat. 
42 | concat = tf.concat([x, h], 1) 43 | W_both = tf.concat([W_xh, W_hh], 0) 44 | hidden = tf.matmul(concat, W_both) + bias 45 | 46 | i, j, f, o = tf.split(hidden, 4, axis=1) 47 | 48 | new_c = c * tf.sigmoid(f) + tf.sigmoid(i) * tf.tanh(j) 49 | new_h = tf.tanh(new_c) * tf.sigmoid(o) 50 | 51 | return new_h, (new_c, new_h) 52 | 53 | class BNLSTMCell(RNNCell): 54 | """Batch normalized LSTM as described in http://arxiv.org/abs/1603.09025""" 55 | def __init__(self, num_units, training): 56 | self.num_units = num_units 57 | self.training = training 58 | 59 | @property 60 | def state_size(self): 61 | return (self.num_units, self.num_units) 62 | 63 | @property 64 | def output_size(self): 65 | return self.num_units 66 | 67 | def __call__(self, x, state, scope=None): 68 | with tf.variable_scope(scope or 'bn_lstm'): 69 | c, h = state 70 | 71 | x_size = x.get_shape().as_list()[1] 72 | W_xh = tf.get_variable('W_xh', 73 | [x_size, 4 * self.num_units], 74 | initializer=orthogonal_initializer()) 75 | W_hh = tf.get_variable('W_hh', 76 | [self.num_units, 4 * self.num_units], 77 | initializer=bn_lstm_identity_initializer(0.95)) 78 | bias = tf.get_variable('bias', [4 * self.num_units]) 79 | 80 | xh = tf.matmul(x, W_xh) 81 | hh = tf.matmul(h, W_hh) 82 | 83 | bn_xh = batch_norm(xh, 'xh', self.training) 84 | bn_hh = batch_norm(hh, 'hh', self.training) 85 | 86 | hidden = bn_xh + bn_hh + bias 87 | 88 | i, j, f, o = tf.split(hidden, 4, axis=1) 89 | 90 | new_c = c * tf.sigmoid(f) + tf.sigmoid(i) * tf.tanh(j) 91 | bn_new_c = batch_norm(new_c, 'c', self.training) 92 | 93 | new_h = tf.tanh(bn_new_c) * tf.sigmoid(o) 94 | 95 | return new_h, (new_c, new_h) 96 | 97 | def orthogonal(shape): 98 | flat_shape = (shape[0], np.prod(shape[1:])) 99 | a = np.random.normal(0.0, 1.0, flat_shape) 100 | u, _, v = np.linalg.svd(a, full_matrices=False) 101 | q = u if u.shape == flat_shape else v 102 | return q.reshape(shape) 103 | 104 | def bn_lstm_identity_initializer(scale): 105 | def _initializer(shape, dtype=tf.float32, partition_info=None): 106 | """Ugly cause LSTM params calculated in one matrix multiply""" 107 | size = shape[0] 108 | # gate (j) is identity 109 | t = np.zeros(shape) 110 | t[:, size:size * 2] = np.identity(size) * scale 111 | t[:, :size] = orthogonal([size, size]) 112 | t[:, size * 2:size * 3] = orthogonal([size, size]) 113 | t[:, size * 3:] = orthogonal([size, size]) 114 | return tf.constant(t, dtype=dtype) 115 | 116 | return _initializer 117 | 118 | def orthogonal_initializer(): 119 | def _initializer(shape, dtype=tf.float32, partition_info=None): 120 | return tf.constant(orthogonal(shape), dtype) 121 | return _initializer 122 | 123 | def batch_norm(x, name_scope, training, epsilon=1e-3, decay=0.999): 124 | """Assume 2d [batch, values] tensor""" 125 | 126 | with tf.variable_scope(name_scope): 127 | size = x.get_shape().as_list()[1] 128 | 129 | scale = tf.get_variable('scale', [size], 130 | initializer=tf.constant_initializer(0.1)) 131 | offset = tf.get_variable('offset', [size]) 132 | 133 | pop_mean = tf.get_variable('pop_mean', [size], 134 | initializer=tf.zeros_initializer(), 135 | trainable=False) 136 | pop_var = tf.get_variable('pop_var', [size], 137 | initializer=tf.ones_initializer(), 138 | trainable=False) 139 | batch_mean, batch_var = tf.nn.moments(x, [0]) 140 | 141 | train_mean_op = tf.assign( 142 | pop_mean, 143 | pop_mean * decay + batch_mean * (1 - decay)) 144 | train_var_op = tf.assign( 145 | pop_var, 146 | pop_var * decay + batch_var * (1 - decay)) 147 | 148 | def batch_statistics(): 149 | with 
tf.control_dependencies([train_mean_op, train_var_op]): 150 | return tf.nn.batch_normalization(x, batch_mean, batch_var, offset, scale, epsilon) 151 | 152 | def population_statistics(): 153 | return tf.nn.batch_normalization(x, pop_mean, pop_var, offset, scale, epsilon) 154 | 155 | return tf.cond(training, batch_statistics, population_statistics) 156 | -------------------------------------------------------------------------------- /worker.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | import argparse 3 | parser = argparse.ArgumentParser() 4 | parser.add_argument('--task', default='yelp', choices=['yelp']) 5 | parser.add_argument('--mode', default='train', choices=['train', 'eval']) 6 | parser.add_argument('--checkpoint-frequency', type=int, default=100) 7 | parser.add_argument('--eval-frequency', type=int, default=10000) 8 | parser.add_argument('--batch-size', type=int, default=30) 9 | parser.add_argument("--device", default="/cpu:0") 10 | parser.add_argument("--max-grad-norm", type=float, default=5.0) 11 | parser.add_argument("--lr", type=float, default=0.001) 12 | args = parser.parse_args() 13 | 14 | import importlib 15 | import os 16 | import pickle 17 | import random 18 | import time 19 | from collections import Counter, defaultdict 20 | 21 | import numpy as np 22 | import pandas as pd 23 | import spacy 24 | import tensorflow as tf 25 | from tensorflow.contrib.tensorboard.plugins import projector 26 | from tqdm import tqdm 27 | 28 | import ujson 29 | from data_util import batch 30 | 31 | task_name = args.task 32 | 33 | task = importlib.import_module(task_name) 34 | 35 | checkpoint_dir = os.path.join(task.train_dir, 'checkpoint') 36 | tflog_dir = os.path.join(task.train_dir, 'tflog') 37 | checkpoint_name = task_name + '-model' 38 | checkpoint_dir = os.path.join(task.train_dir, 'checkpoints') 39 | checkpoint_path = os.path.join(checkpoint_dir, checkpoint_name) 40 | 41 | # @TODO: move calculation into `task file` 42 | trainset = task.read_trainset(epochs=1) 43 | class_weights = pd.Series(Counter([l for _, l in trainset])) 44 | class_weights = 1/(class_weights/class_weights.mean()) 45 | class_weights = class_weights.to_dict() 46 | 47 | vocab = task.read_vocab() 48 | labels = task.read_labels() 49 | 50 | classes = max(labels.values())+1 51 | vocab_size = task.vocab_size 52 | 53 | labels_rev = {int(v): k for k, v in labels.items()} 54 | vocab_rev = {int(v): k for k, v in vocab.items()} 55 | 56 | 57 | def HAN_model_1(session, restore_only=False): 58 | """Hierarhical Attention Network""" 59 | import tensorflow as tf 60 | try: 61 | from tensorflow.contrib.rnn import GRUCell, MultiRNNCell, DropoutWrapper 62 | except ImportError: 63 | MultiRNNCell = tf.nn.rnn_cell.MultiRNNCell 64 | GRUCell = tf.nn.rnn_cell.GRUCell 65 | from bn_lstm import BNLSTMCell 66 | from HAN_model import HANClassifierModel 67 | 68 | is_training = tf.placeholder(dtype=tf.bool, name='is_training') 69 | 70 | cell = BNLSTMCell(80, is_training) # h-h batchnorm LSTMCell 71 | # cell = GRUCell(30) 72 | cell = MultiRNNCell([cell]*5) 73 | 74 | model = HANClassifierModel( 75 | vocab_size=vocab_size, 76 | embedding_size=200, 77 | classes=classes, 78 | word_cell=cell, 79 | sentence_cell=cell, 80 | word_output_size=100, 81 | sentence_output_size=100, 82 | device=args.device, 83 | learning_rate=args.lr, 84 | max_grad_norm=args.max_grad_norm, 85 | dropout_keep_proba=0.5, 86 | is_training=is_training, 87 | ) 88 | 89 | saver = tf.train.Saver(tf.global_variables()) 90 | 
checkpoint = tf.train.get_checkpoint_state(checkpoint_dir) 91 | if checkpoint: 92 | print("Reading model parameters from %s" % checkpoint.model_checkpoint_path) 93 | saver.restore(session, checkpoint.model_checkpoint_path) 94 | elif restore_only: 95 | raise FileNotFoundError("Cannot restore model") 96 | else: 97 | print("Created model with fresh parameters") 98 | session.run(tf.global_variables_initializer()) 99 | # tf.get_default_graph().finalize() 100 | return model, saver 101 | 102 | model_fn = HAN_model_1 103 | 104 | def decode(ex): 105 | print('text: ' + '\n'.join([' '.join([vocab_rev.get(wid, '') for wid in sent]) for sent in ex[0]])) 106 | print('label: ', labels_rev[ex[1]]) 107 | 108 | print('data loaded') 109 | 110 | def batch_iterator(dataset, batch_size, max_epochs): 111 | for i in range(max_epochs): 112 | xb = [] 113 | yb = [] 114 | for ex in dataset: 115 | x, y = ex 116 | xb.append(x) 117 | yb.append(y) 118 | if len(xb) == batch_size: 119 | yield xb, yb 120 | xb, yb = [], [] 121 | 122 | 123 | def ev(session, model, dataset): 124 | predictions = [] 125 | labels = [] 126 | examples = [] 127 | for x, y in tqdm(batch_iterator(dataset, args.batch_size, 1)): 128 | examples.extend(x) 129 | labels.extend(y) 130 | predictions.extend(session.run(model.prediction, model.get_feed_data(x, is_training=False))) 131 | 132 | df = pd.DataFrame({'predictions': predictions, 'labels': labels, 'examples': examples}) 133 | return df 134 | 135 | 136 | def evaluate(dataset): 137 | tf.reset_default_graph() 138 | config = tf.ConfigProto(allow_soft_placement=True) 139 | with tf.Session(config=config) as s: 140 | model, _ = model_fn(s, restore_only=True) 141 | df = ev(s, model, dataset) 142 | print((df['predictions'] == df['labels']).mean()) 143 | import IPython 144 | IPython.embed() 145 | 146 | 147 | def train(): 148 | tf.reset_default_graph() 149 | 150 | config = tf.ConfigProto(allow_soft_placement=True) 151 | 152 | with tf.Session(config=config) as s: 153 | model, saver = model_fn(s) 154 | summary_writer = tf.summary.FileWriter(tflog_dir, graph=tf.get_default_graph()) 155 | 156 | # Format: tensorflow/contrib/tensorboard/plugins/projector/projector_config.proto 157 | # pconf = projector.ProjectorConfig() 158 | 159 | # # You can add multiple embeddings. Here we add only one. 160 | # embedding = pconf.embeddings.add() 161 | # embedding.tensor_name = m.embedding_matrix.name 162 | 163 | # # Link this tensor to its metadata file (e.g. labels). 164 | # embedding.metadata_path = vocab_tsv 165 | 166 | # print(embedding.tensor_name) 167 | 168 | # Saves a configuration file that TensorBoard will read during startup. 
169 | 170 | for i, (x, y) in enumerate(batch_iterator(task.read_trainset(epochs=3), args.batch_size, 300)): 171 | fd = model.get_feed_data(x, y, class_weights=class_weights) 172 | 173 | # import IPython 174 | # IPython.embed() 175 | 176 | t0 = time.clock() 177 | step, summaries, loss, accuracy, _ = s.run([ 178 | model.global_step, 179 | model.summary_op, 180 | model.loss, 181 | model.accuracy, 182 | model.train_op, 183 | ], fd) 184 | td = time.clock() - t0 185 | 186 | summary_writer.add_summary(summaries, global_step=step) 187 | # projector.visualize_embeddings(summary_writer, pconf) 188 | 189 | if step % 1 == 0: 190 | print('step %s, loss=%s, accuracy=%s, t=%s, inputs=%s' % (step, loss, accuracy, round(td, 2), fd[model.inputs].shape)) 191 | if step != 0 and step % args.checkpoint_frequency == 0: 192 | print('checkpoint & graph meta') 193 | saver.save(s, checkpoint_path, global_step=step) 194 | print('checkpoint done') 195 | if step != 0 and step % args.eval_frequency == 0: 196 | print('evaluation at step %s' % i) 197 | dev_df = ev(s, model, task.read_devset(epochs=1)) 198 | print('dev accuracy: %.2f' % (dev_df['predictions'] == dev_df['labels']).mean()) 199 | 200 | def main(): 201 | if args.mode == 'train': 202 | train() 203 | elif args.mode == 'eval': 204 | evaluate(task.read_devset(epochs=1)) 205 | 206 | if __name__ == '__main__': 207 | main() 208 | -------------------------------------------------------------------------------- /HAN_model.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import tensorflow.contrib.layers as layers 3 | import numpy as np 4 | import data_util 5 | from model_components import task_specific_attention, bidirectional_rnn 6 | 7 | 8 | class HANClassifierModel(): 9 | """ Implementation of document classification model described in 10 | `Hierarchical Attention Networks for Document Classification (Yang et al., 2016)` 11 | (https://www.cs.cmu.edu/~diyiy/docs/naacl16.pdf)""" 12 | 13 | def __init__(self, 14 | vocab_size, 15 | embedding_size, 16 | classes, 17 | word_cell, 18 | sentence_cell, 19 | word_output_size, 20 | sentence_output_size, 21 | max_grad_norm, 22 | dropout_keep_proba, 23 | is_training=None, 24 | learning_rate=1e-4, 25 | device='/cpu:0', 26 | scope=None): 27 | self.vocab_size = vocab_size 28 | self.embedding_size = embedding_size 29 | self.classes = classes 30 | self.word_cell = word_cell 31 | self.word_output_size = word_output_size 32 | self.sentence_cell = sentence_cell 33 | self.sentence_output_size = sentence_output_size 34 | self.max_grad_norm = max_grad_norm 35 | self.dropout_keep_proba = dropout_keep_proba 36 | 37 | with tf.variable_scope(scope or 'tcm') as scope: 38 | self.global_step = tf.Variable(0, name='global_step', trainable=False) 39 | 40 | if is_training is not None: 41 | self.is_training = is_training 42 | else: 43 | self.is_training = tf.placeholder(dtype=tf.bool, name='is_training') 44 | 45 | self.sample_weights = tf.placeholder(shape=(None,), dtype=tf.float32, name='sample_weights') 46 | 47 | # [document x sentence x word] 48 | self.inputs = tf.placeholder(shape=(None, None, None), dtype=tf.int32, name='inputs') 49 | 50 | # [document x sentence] 51 | self.word_lengths = tf.placeholder(shape=(None, None), dtype=tf.int32, name='word_lengths') 52 | 53 | # [document] 54 | self.sentence_lengths = tf.placeholder(shape=(None,), dtype=tf.int32, name='sentence_lengths') 55 | 56 | # [document] 57 | self.labels = tf.placeholder(shape=(None,), dtype=tf.int32, name='labels') 58 
| 59 | (self.document_size, 60 | self.sentence_size, 61 | self.word_size) = tf.unstack(tf.shape(self.inputs)) 62 | 63 | self._init_embedding(scope) 64 | 65 | # embeddings cannot be placed on GPU 66 | with tf.device(device): 67 | self._init_body(scope) 68 | 69 | with tf.variable_scope('train'): 70 | self.cross_entropy = tf.nn.sparse_softmax_cross_entropy_with_logits(labels=self.labels, logits=self.logits) 71 | 72 | self.loss = tf.reduce_mean(tf.multiply(self.cross_entropy, self.sample_weights)) 73 | tf.summary.scalar('loss', self.loss) 74 | 75 | self.accuracy = tf.reduce_mean(tf.cast(tf.nn.in_top_k(self.logits, self.labels, 1), tf.float32)) 76 | tf.summary.scalar('accuracy', self.accuracy) 77 | 78 | tvars = tf.trainable_variables() 79 | 80 | grads, global_norm = tf.clip_by_global_norm( 81 | tf.gradients(self.loss, tvars), 82 | self.max_grad_norm) 83 | tf.summary.scalar('global_grad_norm', global_norm) 84 | 85 | opt = tf.train.AdamOptimizer(learning_rate) 86 | 87 | self.train_op = opt.apply_gradients( 88 | zip(grads, tvars), name='train_op', 89 | global_step=self.global_step) 90 | 91 | self.summary_op = tf.summary.merge_all() 92 | 93 | def _init_embedding(self, scope): 94 | with tf.variable_scope(scope): 95 | with tf.variable_scope("embedding") as scope: 96 | self.embedding_matrix = tf.get_variable( 97 | name="embedding_matrix", 98 | shape=[self.vocab_size, self.embedding_size], 99 | initializer=layers.xavier_initializer(), 100 | dtype=tf.float32) 101 | self.inputs_embedded = tf.nn.embedding_lookup( 102 | self.embedding_matrix, self.inputs) 103 | 104 | def _init_body(self, scope): 105 | with tf.variable_scope(scope): 106 | 107 | word_level_inputs = tf.reshape(self.inputs_embedded, [ 108 | self.document_size * self.sentence_size, 109 | self.word_size, 110 | self.embedding_size 111 | ]) 112 | word_level_lengths = tf.reshape( 113 | self.word_lengths, [self.document_size * self.sentence_size]) 114 | 115 | with tf.variable_scope('word') as scope: 116 | word_encoder_output, _ = bidirectional_rnn( 117 | self.word_cell, self.word_cell, 118 | word_level_inputs, word_level_lengths, 119 | scope=scope) 120 | 121 | with tf.variable_scope('attention') as scope: 122 | word_level_output = task_specific_attention( 123 | word_encoder_output, 124 | self.word_output_size, 125 | scope=scope) 126 | 127 | with tf.variable_scope('dropout'): 128 | word_level_output = layers.dropout( 129 | word_level_output, keep_prob=self.dropout_keep_proba, 130 | is_training=self.is_training, 131 | ) 132 | 133 | # sentence_level 134 | 135 | sentence_inputs = tf.reshape( 136 | word_level_output, [self.document_size, self.sentence_size, self.word_output_size]) 137 | 138 | with tf.variable_scope('sentence') as scope: 139 | sentence_encoder_output, _ = bidirectional_rnn( 140 | self.sentence_cell, self.sentence_cell, sentence_inputs, self.sentence_lengths, scope=scope) 141 | 142 | with tf.variable_scope('attention') as scope: 143 | sentence_level_output = task_specific_attention( 144 | sentence_encoder_output, self.sentence_output_size, scope=scope) 145 | 146 | with tf.variable_scope('dropout'): 147 | sentence_level_output = layers.dropout( 148 | sentence_level_output, keep_prob=self.dropout_keep_proba, 149 | is_training=self.is_training, 150 | ) 151 | 152 | with tf.variable_scope('classifier'): 153 | self.logits = layers.fully_connected( 154 | sentence_level_output, self.classes, activation_fn=None) 155 | 156 | self.prediction = tf.argmax(self.logits, axis=-1) 157 | 158 | def get_feed_data(self, x, y=None, class_weights=None, 
is_training=True): 159 | x_m, doc_sizes, sent_sizes = data_util.batch(x) 160 | fd = { 161 | self.inputs: x_m, 162 | self.sentence_lengths: doc_sizes, 163 | self.word_lengths: sent_sizes, 164 | } 165 | if y is not None: 166 | fd[self.labels] = y 167 | if class_weights is not None: 168 | fd[self.sample_weights] = [class_weights[yy] for yy in y] 169 | else: 170 | fd[self.sample_weights] = np.ones(shape=[len(x_m)], dtype=np.float32) 171 | fd[self.is_training] = is_training 172 | return fd 173 | 174 | 175 | if __name__ == '__main__': 176 | try: 177 | from tensorflow.contrib.rnn import LSTMCell, LSTMStateTuple, GRUCell 178 | except ImportError: 179 | LSTMCell = tf.nn.rnn_cell.LSTMCell 180 | LSTMStateTuple = tf.nn.rnn_cell.LSTMStateTuple 181 | GRUCell = tf.nn.rnn_cell.GRUCell 182 | 183 | tf.reset_default_graph() 184 | with tf.Session() as session: 185 | model = HANClassifierModel( 186 | vocab_size=10, 187 | embedding_size=5, 188 | classes=2, 189 | word_cell=GRUCell(10), 190 | sentence_cell=GRUCell(10), 191 | word_output_size=10, 192 | sentence_output_size=10, 193 | max_grad_norm=5.0, 194 | dropout_keep_proba=0.5, 195 | ) 196 | session.run(tf.global_variables_initializer()) 197 | 198 | fd = { 199 | model.is_training: False, 200 | model.inputs: [[ 201 | [5, 4, 1, 0], 202 | [3, 3, 6, 7], 203 | [6, 7, 0, 0] 204 | ], 205 | [ 206 | [2, 2, 1, 0], 207 | [3, 3, 6, 7], 208 | [0, 0, 0, 0] 209 | ]], 210 | model.word_lengths: [ 211 | [3, 4, 2], 212 | [3, 4, 0], 213 | ], 214 | model.sentence_lengths: [3, 2], 215 | model.labels: [0, 1], 216 | } 217 | 218 | print(session.run(model.logits, fd)) 219 | session.run(model.train_op, fd) 220 | --------------------------------------------------------------------------------
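For completeness, a rough inference sketch assembled from the pieces above: worker.HAN_model_1 rebuilds the graph and restores the latest checkpoint, yelp.read_devset supplies examples, and model.prediction returns class ids (stars = prediction + 1, since labels are shifted down by one in yelp.py). Importing worker also runs its module-level setup (argument parsing and class-weight computation over the trainset), so treat this as a sketch rather than a drop-in script:

```
import tensorflow as tf
import yelp
from worker import HAN_model_1  # note: triggers worker.py's module-level setup

dev = yelp.read_devset(epochs=1)
examples = [next(dev) for _ in range(8)]   # small batch of dev reviews
x, y = zip(*examples)

with tf.Session(config=tf.ConfigProto(allow_soft_placement=True)) as sess:
    model, _ = HAN_model_1(sess, restore_only=True)
    preds = sess.run(model.prediction,
                     model.get_feed_data(list(x), is_training=False))
    print('predicted stars:', [int(p) + 1 for p in preds])
    print('true stars:     ', [int(label) + 1 for label in y])
```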