├── .gitignore ├── LICENSE ├── README.md ├── data_prepare.py ├── data_reader.py ├── img ├── model.png └── train_log.png ├── layers.py ├── model.py ├── train.py └── utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | *.egg-info/ 24 | .installed.cfg 25 | *.egg 26 | MANIFEST 27 | 28 | # PyInstaller 29 | # Usually these files are written by a python script from a template 30 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 31 | *.manifest 32 | *.spec 33 | 34 | # Installer logs 35 | pip-log.txt 36 | pip-delete-this-directory.txt 37 | 38 | # Unit test / coverage reports 39 | htmlcov/ 40 | .tox/ 41 | .coverage 42 | .coverage.* 43 | .cache 44 | nosetests.xml 45 | coverage.xml 46 | *.cover 47 | .hypothesis/ 48 | .pytest_cache/ 49 | 50 | # Translations 51 | *.mo 52 | *.pot 53 | 54 | # Django stuff: 55 | *.log 56 | local_settings.py 57 | db.sqlite3 58 | 59 | # Flask stuff: 60 | instance/ 61 | .webassets-cache 62 | 63 | # Scrapy stuff: 64 | .scrapy 65 | 66 | # Sphinx documentation 67 | docs/_build/ 68 | 69 | # PyBuilder 70 | target/ 71 | 72 | # Jupyter Notebook 73 | .ipynb_checkpoints 74 | 75 | # pyenv 76 | .python-version 77 | 78 | # celery beat schedule file 79 | celerybeat-schedule 80 | 81 | # SageMath parsed files 82 | *.sage.py 83 | 84 | # Environments 85 | .env 86 | .venv 87 | env/ 88 | venv/ 89 | ENV/ 90 | env.bak/ 91 | venv.bak/ 92 | 93 | # Spyder project settings 94 | .spyderproject 95 | .spyproject 96 | 97 | # Rope project settings 98 | .ropeproject 99 | 100 | # mkdocs documentation 101 | /site 102 | 103 | # mypy 104 | .mypy_cache/ 105 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2018 Quoc-Tuan Truong 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Hierarchical Attention Networks for Document Classification 2 | 3 | This is an implementation of the paper [Hierarchical Attention Networks for Document Classification](https://www.cs.cmu.edu/~hovy/papers/16HLT-hierarchical-attention-networks.pdf), NAACL 2016. 4 | 5 | ![alt tag](img/model.png) 6 | 7 | ## Requirements 8 | 9 | - Python 3 10 | - TensorFlow > 1.0 11 | - pandas 12 | - NLTK 13 | - tqdm 14 | - [GloVe pre-trained word embeddings](http://nlp.stanford.edu/data/glove.6B.zip) 15 | 16 | ## Data 17 | 18 | We use the [data](http://ir.hit.edu.cn/~dytang/paper/emnlp2015/emnlp-2015-data.7z) provided by [Tang et al. 2015](http://ir.hit.edu.cn/~dytang/paper/emnlp2015/emnlp2015.pdf), which includes 4 datasets: 19 | 20 | - IMDB 21 | - Yelp 2013 22 | - Yelp 2014 23 | - Yelp 2015 24 | 25 | **Note:** 26 | The original archive seems to have an [issue](https://github.com/tqtg/hierarchical-attention-networks/issues/1) with unzipping. I re-uploaded the [data](https://drive.google.com/file/d/1OQ_ggjlNUWiTg_zFXc0_OpYXpJRwJP3y) to Google Drive for a faster download. Please request access permission. 27 | 28 | ## Usage 29 | 30 | First, download the [datasets](#data) and unzip them into the `data` folder. 31 |
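A quick sketch of the layout `data_prepare.py` reads from after unzipping (these exact paths are hard-coded in the script for Yelp-2015):

```
data/
├── yelp-2015-train.txt.ss
├── yelp-2015-dev.txt.ss
└── yelp-2015-test.txt.ss
```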
32 | Then, run the script to prepare the data *(the Yelp-2015 dataset is used by default)*: 33 | 34 | ```bash 35 | python data_prepare.py 36 | ``` 37 | 38 | Train and evaluate the model: 39 |
40 | *(make sure the [GloVe embeddings](#requirements) are ready before training)* 41 | ```bash 42 | wget http://nlp.stanford.edu/data/glove.6B.zip 43 | unzip glove.6B.zip 44 | ``` 45 | ```bash 46 | python train.py 47 | ``` 48 | 49 | Print training arguments: 50 | 51 | ```bash 52 | python train.py --help 53 | ``` 54 | ``` 55 | optional arguments: 56 | -h, --help show this help message and exit 57 | --cell_dim CELL_DIM 58 | Hidden dimensions of GRU cells (default: 50) 59 | --att_dim ATTENTION_DIM 60 | Dimensionality of attention spaces (default: 100) 61 | --emb_size EMB_SIZE 62 | Dimensionality of word embedding (default: 200) 63 | --learning_rate LEARNING_RATE 64 | Learning rate (default: 0.0005) 65 | --max_grad_norm MAX_GRAD_NORM 66 | Maximum value of the global norm of the gradients for clipping (default: 5.0) 67 | --dropout_rate DROPOUT_RATE 68 | Probability of dropping neurons (default: 0.5) 69 | --num_classes NUM_CLASSES 70 | Number of classes (default: 5) 71 | --num_checkpoints NUM_CHECKPOINTS 72 | Number of checkpoints to store (default: 1) 73 | --num_epochs NUM_EPOCHS 74 | Number of training epochs (default: 20) 75 | --batch_size BATCH_SIZE 76 | Batch size (default: 64) 77 | --display_step DISPLAY_STEP 78 | Number of steps to display log into TensorBoard (default: 20) 79 | --allow_soft_placement ALLOW_SOFT_PLACEMENT 80 | Allow device soft device placement 81 | ``` 82 | 83 | ## Results 84 | 85 | With the *Yelp-2015* dataset, after 5 epochs, we achieved: 86 | 87 | - **69.79%** accuracy on the *dev set* 88 | - **69.62%** accuracy on the *test set* 89 | 90 | No systematic hyper-parameter tuning was performed. The result reported in the paper is **71.0%** for *Yelp-2015*. 91 | 92 | ![alt tag](img/train_log.png) 93 | -------------------------------------------------------------------------------- /data_prepare.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | import nltk 3 | import itertools 4 | import pickle 5 | 6 | # Hyper-parameters 7 | WORD_CUT_OFF = 5 8 | 9 | 10 | def build_vocab(docs, save_path): 11 | print('Building vocab ...') 12 | 13 | sents = itertools.chain(*[text.split('<sssss>') for text in docs]) 14 | tokenized_sents = [sent.split() for sent in sents] 15 | 16 | # Count the word frequencies 17 | word_freq = nltk.FreqDist(itertools.chain(*tokenized_sents)) 18 | print("%d unique words found" % len(word_freq.items())) 19 | 20 | # Cut-off 21 | retained_words = [w for (w, f) in word_freq.items() if f > WORD_CUT_OFF] 22 | print("%d words retained" % len(retained_words)) 23 | 24 | # Build the word_to_index and index_to_word mappings 25 | # Word index starts from 2, 1 is reserved for UNK, 0 is reserved for padding 26 | word_to_index = {'PAD': 0, 'UNK': 1} 27 | for i, w in enumerate(retained_words): 28 | word_to_index[w] = i + 2 29 | index_to_word = {i: w for (w, i) in word_to_index.items()} 30 | 31 | print("Vocabulary size = %d" % len(word_to_index)) 32 | 33 | with open('{}-w2i.pkl'.format(save_path), 'wb') as f: 34 | pickle.dump(word_to_index, f) 35 | 36 | with open('{}-i2w.pkl'.format(save_path), 'wb') as f: 37 | pickle.dump(index_to_word, f) 38 | 39 | return word_to_index 40 | 41 | 42 | def process_and_save(word_to_index, data, out_file): 43 | mapped_data = [] 44 | for label, doc in zip(data[4], data[6]): 45 | mapped_doc = [[word_to_index.get(word, 1) for word in sent.split()] for sent in doc.split('<sssss>')] 46 | mapped_data.append((label, mapped_doc)) 47 | 48 | with open(out_file, 'wb') as f: 49 |
pickle.dump(mapped_data, f) 50 | 51 | 52 | def read_data(data_file): 53 | data = pd.read_csv(data_file, sep='\t', header=None, usecols=[4, 6]) 54 | print('{}, shape={}'.format(data_file, data.shape)) 55 | return data 56 | 57 | 58 | if __name__ == '__main__': 59 | train_data = read_data('data/yelp-2015-train.txt.ss') 60 | word_to_index = build_vocab(train_data[6], 'data/yelp-2015') 61 | process_and_save(word_to_index, train_data, 'data/yelp-2015-train.pkl') 62 | 63 | dev_data = read_data('data/yelp-2015-dev.txt.ss') 64 | process_and_save(word_to_index, dev_data, 'data/yelp-2015-dev.pkl') 65 | 66 | test_data = read_data('data/yelp-2015-test.txt.ss') 67 | process_and_save(word_to_index, test_data, 'data/yelp-2015-test.pkl') 68 | -------------------------------------------------------------------------------- /data_reader.py: -------------------------------------------------------------------------------- 1 | import pickle 2 | from tqdm import tqdm 3 | import random 4 | import numpy as np 5 | 6 | class DataReader: 7 | def __init__(self, train_file, dev_file, test_file, 8 | max_word_length=50, max_sent_length=30, num_classes=5): 9 | self.max_word_length = max_word_length 10 | self.max_sent_length = max_sent_length 11 | self.num_classes = num_classes 12 | 13 | self.train_data = self._read_data(train_file) 14 | self.valid_data = self._read_data(dev_file) 15 | self.test_data = self._read_data(test_file) 16 | 17 | def _read_data(self, file_path): 18 | print('Reading data from %s' % file_path) 19 | new_data = [] 20 | with open(file_path, 'rb') as f: 21 | data = pickle.load(f) 22 | random.shuffle(data) 23 | for label, doc in data: 24 | doc = doc[:self.max_sent_length] 25 | doc = [sent[:self.max_word_length] for sent in doc] 26 | 27 | label -= 1 28 | assert label >= 0 and label < self.num_classes 29 | 30 | new_data.append((doc, label)) 31 | 32 | # sort data by sent lengths to speed up 33 | new_data = sorted(new_data, key=lambda x: len(x[0])) 34 | return new_data 35 | 36 | def _batch_iterator(self, data, batch_size, desc=None): 37 | num_batches = int(np.ceil(len(data) / batch_size)) 38 | for b in tqdm(range(num_batches), desc): 39 | begin_offset = batch_size * b 40 | end_offset = batch_size * b + batch_size 41 | if end_offset > len(data): 42 | end_offset = len(data) 43 | 44 | doc_batch = [] 45 | label_batch = [] 46 | for offset in range(begin_offset, end_offset): 47 | doc_batch.append(data[offset][0]) 48 | label_batch.append(data[offset][1]) 49 | 50 | yield doc_batch, label_batch 51 | 52 | def read_train_set(self, batch_size, shuffle=False): 53 | if shuffle: 54 | random.shuffle(self.train_data) 55 | return self._batch_iterator(self.train_data, batch_size, desc='Training') 56 | 57 | def read_valid_set(self, batch_size): 58 | return self._batch_iterator(self.valid_data, batch_size, desc='Validating') 59 | 60 | def read_test_set(self, batch_size): 61 | return self._batch_iterator(self.test_data, batch_size, desc='Testing') 62 | -------------------------------------------------------------------------------- /img/model.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tqtg/hierarchical-attention-networks/0b36b64115137bebbb66760fe9e543a1298bc158/img/model.png -------------------------------------------------------------------------------- /img/train_log.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/tqtg/hierarchical-attention-networks/0b36b64115137bebbb66760fe9e543a1298bc158/img/train_log.png -------------------------------------------------------------------------------- /layers.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import numpy as np 3 | from utils import get_shape 4 | 5 | try: 6 | from tensorflow.contrib.rnn import LSTMStateTuple 7 | except ImportError: 8 | LSTMStateTuple = tf.nn.rnn_cell.LSTMStateTuple 9 | 10 | 11 | 12 | def bidirectional_rnn(cell_fw, cell_bw, inputs, input_lengths, 13 | initial_state_fw=None, initial_state_bw=None, 14 | scope=None): 15 | with tf.variable_scope(scope or 'bi_rnn') as scope: 16 | (fw_outputs, bw_outputs), (fw_state, bw_state) = tf.nn.bidirectional_dynamic_rnn( 17 | cell_fw=cell_fw, 18 | cell_bw=cell_bw, 19 | inputs=inputs, 20 | sequence_length=input_lengths, 21 | initial_state_fw=initial_state_fw, 22 | initial_state_bw=initial_state_bw, 23 | dtype=tf.float32, 24 | scope=scope 25 | ) 26 | outputs = tf.concat((fw_outputs, bw_outputs), axis=2) 27 | 28 | def concatenate_state(fw_state, bw_state): 29 | if isinstance(fw_state, LSTMStateTuple): 30 | state_c = tf.concat( 31 | (fw_state.c, bw_state.c), 1, name='bidirectional_concat_c') 32 | state_h = tf.concat( 33 | (fw_state.h, bw_state.h), 1, name='bidirectional_concat_h') 34 | state = LSTMStateTuple(c=state_c, h=state_h) 35 | return state 36 | elif isinstance(fw_state, tf.Tensor): 37 | state = tf.concat((fw_state, bw_state), 1, 38 | name='bidirectional_concat') 39 | return state 40 | elif (isinstance(fw_state, tuple) and 41 | isinstance(bw_state, tuple) and 42 | len(fw_state) == len(bw_state)): 43 | # multilayer 44 | state = tuple(concatenate_state(fw, bw) 45 | for fw, bw in zip(fw_state, bw_state)) 46 | return state 47 | 48 | else: 49 | raise ValueError( 50 | 'unknown state type: {}'.format((fw_state, bw_state))) 51 | 52 | state = concatenate_state(fw_state, bw_state) 53 | return outputs, state 54 | 55 | 56 | def masking(scores, sequence_lengths, score_mask_value=tf.constant(-np.inf)): 57 | score_mask = tf.sequence_mask(sequence_lengths, maxlen=tf.shape(scores)[1]) 58 | score_mask_values = score_mask_value * tf.ones_like(scores) 59 | return tf.where(score_mask, scores, score_mask_values) 60 | 61 | 62 | def attention(inputs, att_dim, sequence_lengths, scope=None): 63 | assert len(inputs.get_shape()) == 3 and inputs.get_shape()[-1].value is not None 64 | 65 | with tf.variable_scope(scope or 'attention'): 66 | word_att_W = tf.get_variable(name='att_W', shape=[att_dim, 1]) 67 | 68 | projection = tf.layers.dense(inputs, att_dim, tf.nn.tanh, name='projection') 69 | 70 | alpha = tf.matmul(tf.reshape(projection, shape=[-1, att_dim]), word_att_W) 71 | alpha = tf.reshape(alpha, shape=[-1, get_shape(inputs)[1]]) 72 | alpha = masking(alpha, sequence_lengths, tf.constant(-1e15, dtype=tf.float32)) 73 | alpha = tf.nn.softmax(alpha) 74 | 75 | outputs = tf.reduce_sum(inputs * tf.expand_dims(alpha, 2), axis=1) 76 | return outputs, alpha 77 | -------------------------------------------------------------------------------- /model.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import tensorflow.contrib.rnn as rnn 3 | from layers import bidirectional_rnn, attention 4 | from utils import get_shape, batch_doc_normalize 5 | 6 | 7 | class Model: 8 | def __init__(self, cell_dim, att_dim, vocab_size, emb_size, num_classes, dropout_rate, 
pretrained_embs): 9 | self.cell_dim = cell_dim 10 | self.att_dim = att_dim 11 | self.emb_size = emb_size 12 | self.vocab_size = vocab_size 13 | self.num_classes = num_classes 14 | self.dropout_rate = dropout_rate 15 | self.pretrained_embs = pretrained_embs 16 | 17 | self.docs = tf.placeholder(shape=(None, None, None), dtype=tf.int32, name='docs') 18 | self.sent_lengths = tf.placeholder(shape=(None,), dtype=tf.int32, name='sent_lengths') 19 | self.word_lengths = tf.placeholder(shape=(None, None), dtype=tf.int32, name='word_lengths') 20 | self.max_word_length = tf.placeholder(dtype=tf.int32, name='max_word_length') 21 | self.max_sent_length = tf.placeholder(dtype=tf.int32, name='max_sent_length') 22 | self.labels = tf.placeholder(shape=(None), dtype=tf.int32, name='labels') 23 | self.is_training = tf.placeholder(dtype=tf.bool, name='is_training') 24 | 25 | self._init_embedding() 26 | self._init_word_encoder() 27 | self._init_sent_encoder() 28 | self._init_classifier() 29 | 30 | def _init_embedding(self): 31 | with tf.variable_scope('embedding'): 32 | self.embedding_matrix = tf.get_variable(name='embedding_matrix', 33 | shape=[self.vocab_size, self.emb_size], 34 | initializer=tf.constant_initializer(self.pretrained_embs), 35 | dtype=tf.float32) 36 | self.embedded_inputs = tf.nn.embedding_lookup(self.embedding_matrix, self.docs) 37 | 38 | def _init_word_encoder(self): 39 | with tf.variable_scope('word-encoder') as scope: 40 | word_inputs = tf.reshape(self.embedded_inputs, [-1, self.max_word_length, self.emb_size]) 41 | word_lengths = tf.reshape(self.word_lengths, [-1]) 42 | 43 | # word encoder 44 | cell_fw = rnn.GRUCell(self.cell_dim, name='cell_fw') 45 | cell_bw = rnn.GRUCell(self.cell_dim, name='cell_bw') 46 | 47 | init_state_fw = tf.tile(tf.get_variable('init_state_fw', 48 | shape=[1, self.cell_dim], 49 | initializer=tf.constant_initializer(0)), 50 | multiples=[get_shape(word_inputs)[0], 1]) 51 | init_state_bw = tf.tile(tf.get_variable('init_state_bw', 52 | shape=[1, self.cell_dim], 53 | initializer=tf.constant_initializer(0)), 54 | multiples=[get_shape(word_inputs)[0], 1]) 55 | 56 | rnn_outputs, _ = bidirectional_rnn(cell_fw=cell_fw, 57 | cell_bw=cell_bw, 58 | inputs=word_inputs, 59 | input_lengths=word_lengths, 60 | initial_state_fw=init_state_fw, 61 | initial_state_bw=init_state_bw, 62 | scope=scope) 63 | 64 | word_outputs, word_att_weights = attention(inputs=rnn_outputs, 65 | att_dim=self.att_dim, 66 | sequence_lengths=word_lengths) 67 | self.word_outputs = tf.layers.dropout(word_outputs, self.dropout_rate, training=self.is_training) 68 | 69 | def _init_sent_encoder(self): 70 | with tf.variable_scope('sent-encoder') as scope: 71 | sent_inputs = tf.reshape(self.word_outputs, [-1, self.max_sent_length, 2 * self.cell_dim]) 72 | 73 | # sentence encoder 74 | cell_fw = rnn.GRUCell(self.cell_dim, name='cell_fw') 75 | cell_bw = rnn.GRUCell(self.cell_dim, name='cell_bw') 76 | 77 | init_state_fw = tf.tile(tf.get_variable('init_state_fw', 78 | shape=[1, self.cell_dim], 79 | initializer=tf.constant_initializer(0)), 80 | multiples=[get_shape(sent_inputs)[0], 1]) 81 | init_state_bw = tf.tile(tf.get_variable('init_state_bw', 82 | shape=[1, self.cell_dim], 83 | initializer=tf.constant_initializer(0)), 84 | multiples=[get_shape(sent_inputs)[0], 1]) 85 | 86 | rnn_outputs, _ = bidirectional_rnn(cell_fw=cell_fw, 87 | cell_bw=cell_bw, 88 | inputs=sent_inputs, 89 | input_lengths=self.sent_lengths, 90 | initial_state_fw=init_state_fw, 91 | initial_state_bw=init_state_bw, 92 | scope=scope) 93 | 94 | 
sent_outputs, sent_att_weights = attention(inputs=rnn_outputs, 95 | att_dim=self.att_dim, 96 | sequence_lengths=self.sent_lengths) 97 | self.sent_outputs = tf.layers.dropout(sent_outputs, self.dropout_rate, training=self.is_training) 98 | 99 | def _init_classifier(self): 100 | with tf.variable_scope('classifier'): 101 | self.logits = tf.layers.dense(inputs=self.sent_outputs, units=self.num_classes, name='logits') 102 | 103 | def get_feed_dict(self, docs, labels, training=False): 104 | padded_docs, sent_lengths, max_sent_length, word_lengths, max_word_length = batch_doc_normalize(docs) 105 | fd = { 106 | self.docs: padded_docs, 107 | self.sent_lengths: sent_lengths, 108 | self.word_lengths: word_lengths, 109 | self.max_sent_length: max_sent_length, 110 | self.max_word_length: max_word_length, 111 | self.labels: labels, 112 | self.is_training: training 113 | } 114 | return fd 115 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | from datetime import datetime 3 | from data_reader import DataReader 4 | from model import Model 5 | from utils import read_vocab, count_parameters, load_glove 6 | 7 | # Parameters 8 | # ================================================== 9 | FLAGS = tf.flags.FLAGS 10 | 11 | tf.flags.DEFINE_string("checkpoint_dir", 'checkpoints', 12 | """Path to checkpoint folder""") 13 | tf.flags.DEFINE_string("log_dir", 'logs', 14 | """Path to log folder""") 15 | 16 | tf.flags.DEFINE_integer("cell_dim", 50, 17 | """Hidden dimensions of GRU cells (default: 50)""") 18 | tf.flags.DEFINE_integer("att_dim", 100, 19 | """Dimensionality of attention spaces (default: 100)""") 20 | tf.flags.DEFINE_integer("emb_size", 200, 21 | """Dimensionality of word embedding (default: 200)""") 22 | tf.flags.DEFINE_integer("num_classes", 5, 23 | """Number of classes (default: 5)""") 24 | 25 | tf.flags.DEFINE_integer("num_checkpoints", 1, 26 | """Number of checkpoints to store (default: 1)""") 27 | tf.flags.DEFINE_integer("num_epochs", 20, 28 | """Number of training epochs (default: 20)""") 29 | tf.flags.DEFINE_integer("batch_size", 64, 30 | """Batch size (default: 64)""") 31 | tf.flags.DEFINE_integer("display_step", 20, 32 | """Number of steps to display log into TensorBoard (default: 20)""") 33 | 34 | tf.flags.DEFINE_float("learning_rate", 0.0005, 35 | """Learning rate (default: 0.0005)""") 36 | tf.flags.DEFINE_float("max_grad_norm", 5.0, 37 | """Maximum value of the global norm of the gradients for clipping (default: 5.0)""") 38 | tf.flags.DEFINE_float("dropout_rate", 0.5, 39 | """Probability of dropping neurons (default: 0.5)""") 40 | 41 | tf.flags.DEFINE_boolean("allow_soft_placement", True, 42 | """Allow device soft device placement""") 43 | 44 | if not tf.gfile.Exists(FLAGS.checkpoint_dir): 45 | tf.gfile.MakeDirs(FLAGS.checkpoint_dir) 46 | 47 | if not tf.gfile.Exists(FLAGS.log_dir): 48 | tf.gfile.MakeDirs(FLAGS.log_dir) 49 | 50 | train_writer = tf.summary.FileWriter(FLAGS.log_dir + '/train') 51 | valid_writer = tf.summary.FileWriter(FLAGS.log_dir + '/valid') 52 | test_writer = tf.summary.FileWriter(FLAGS.log_dir + '/test') 53 | 54 | 55 | def loss_fn(labels, logits): 56 | onehot_labels = tf.one_hot(labels, depth=FLAGS.num_classes) 57 | cross_entropy_loss = tf.losses.softmax_cross_entropy(onehot_labels=onehot_labels, 58 | logits=logits) 59 | tf.summary.scalar('loss', cross_entropy_loss) 60 | return cross_entropy_loss 61 | 62 | 63 | def 
train_fn(loss): 64 | trained_vars = tf.trainable_variables() 65 | count_parameters(trained_vars) 66 | 67 | # Gradient clipping 68 | gradients = tf.gradients(loss, trained_vars) 69 | 70 | clipped_grads, global_norm = tf.clip_by_global_norm(gradients, FLAGS.max_grad_norm) 71 | tf.summary.scalar('global_grad_norm', global_norm) 72 | 73 | # Add gradients and vars to summary 74 | # for gradient, var in list(zip(clipped_grads, trained_vars)): 75 | # if 'attention' in var.name: 76 | # tf.summary.histogram(var.name + '/gradient', gradient) 77 | # tf.summary.histogram(var.name, var) 78 | 79 | # Define optimizer 80 | global_step = tf.train.get_or_create_global_step() 81 | optimizer = tf.train.RMSPropOptimizer(FLAGS.learning_rate) 82 | train_op = optimizer.apply_gradients(zip(clipped_grads, trained_vars), 83 | name='train_op', 84 | global_step=global_step) 85 | return train_op, global_step 86 | 87 | 88 | def eval_fn(labels, logits): 89 | predictions = tf.argmax(logits, axis=-1) 90 | correct_preds = tf.equal(predictions, tf.cast(labels, tf.int64)) 91 | batch_acc = tf.reduce_mean(tf.cast(correct_preds, tf.float32)) 92 | tf.summary.scalar('accuracy', batch_acc) 93 | 94 | total_acc, acc_update = tf.metrics.accuracy(labels, predictions, name='metrics/acc') 95 | metrics_vars = tf.get_collection(tf.GraphKeys.LOCAL_VARIABLES, scope="metrics") 96 | metrics_init = tf.variables_initializer(var_list=metrics_vars) 97 | 98 | return batch_acc, total_acc, acc_update, metrics_init 99 | 100 | 101 | def main(_): 102 | vocab = read_vocab('data/yelp-2015-w2i.pkl') 103 | glove_embs = load_glove('glove.6B.{}d.txt'.format(FLAGS.emb_size), FLAGS.emb_size, vocab) 104 | data_reader = DataReader(train_file='data/yelp-2015-train.pkl', 105 | dev_file='data/yelp-2015-dev.pkl', 106 | test_file='data/yelp-2015-test.pkl') 107 | 108 | config = tf.ConfigProto(allow_soft_placement=FLAGS.allow_soft_placement) 109 | with tf.Session(config=config) as sess: 110 | model = Model(cell_dim=FLAGS.cell_dim, 111 | att_dim=FLAGS.att_dim, 112 | vocab_size=len(vocab), 113 | emb_size=FLAGS.emb_size, 114 | num_classes=FLAGS.num_classes, 115 | dropout_rate=FLAGS.dropout_rate, 116 | pretrained_embs=glove_embs) 117 | 118 | loss = loss_fn(model.labels, model.logits) 119 | train_op, global_step = train_fn(loss) 120 | batch_acc, total_acc, acc_update, metrics_init = eval_fn(model.labels, model.logits) 121 | summary_op = tf.summary.merge_all() 122 | sess.run(tf.global_variables_initializer()) 123 | 124 | train_writer.add_graph(sess.graph) 125 | saver = tf.train.Saver(max_to_keep=FLAGS.num_checkpoints) 126 | 127 | print('\n{}> Start training'.format(datetime.now())) 128 | 129 | epoch = 0 130 | valid_step = 0 131 | test_step = 0 132 | train_test_prop = len(data_reader.train_data) / len(data_reader.test_data) 133 | test_batch_size = int(FLAGS.batch_size / train_test_prop) 134 | best_acc = float('-inf') 135 | 136 | while epoch < FLAGS.num_epochs: 137 | epoch += 1 138 | print('\n{}> Epoch: {}'.format(datetime.now(), epoch)) 139 | 140 | sess.run(metrics_init) 141 | for batch_docs, batch_labels in data_reader.read_train_set(FLAGS.batch_size, shuffle=True): 142 | _step, _, _loss, _acc, _ = sess.run([global_step, train_op, loss, batch_acc, acc_update], 143 | feed_dict=model.get_feed_dict(batch_docs, batch_labels, training=True)) 144 | if _step % FLAGS.display_step == 0: 145 | _summary = sess.run(summary_op, feed_dict=model.get_feed_dict(batch_docs, batch_labels)) 146 | train_writer.add_summary(_summary, global_step=_step) 147 | print('Training accuracy = 
{:.2f}'.format(sess.run(total_acc) * 100)) 148 | 149 | sess.run(metrics_init) 150 | for batch_docs, batch_labels in data_reader.read_valid_set(test_batch_size): 151 | _loss, _acc, _ = sess.run([loss, batch_acc, acc_update], feed_dict=model.get_feed_dict(batch_docs, batch_labels)) 152 | valid_step += 1 153 | if valid_step % FLAGS.display_step == 0: 154 | _summary = sess.run(summary_op, feed_dict=model.get_feed_dict(batch_docs, batch_labels)) 155 | valid_writer.add_summary(_summary, global_step=valid_step) 156 | print('Validation accuracy = {:.2f}'.format(sess.run(total_acc) * 100)) 157 | 158 | sess.run(metrics_init) 159 | for batch_docs, batch_labels in data_reader.read_test_set(test_batch_size): 160 | _loss, _acc, _ = sess.run([loss, batch_acc, acc_update], feed_dict=model.get_feed_dict(batch_docs, batch_labels)) 161 | test_step += 1 162 | if test_step % FLAGS.display_step == 0: 163 | _summary = sess.run(summary_op, feed_dict=model.get_feed_dict(batch_docs, batch_labels)) 164 | test_writer.add_summary(_summary, global_step=test_step) 165 | test_acc = sess.run(total_acc) * 100 166 | print('Testing accuracy = {:.2f}'.format(test_acc)) 167 | 168 | if test_acc > best_acc: 169 | best_acc = test_acc 170 | saver.save(sess, FLAGS.checkpoint_dir) 171 | print('Best testing accuracy = {:.2f}'.format(test_acc)) 172 | 173 | print("{} Optimization Finished!".format(datetime.now())) 174 | print('Best testing accuracy = {:.2f}'.format(best_acc)) 175 | 176 | 177 | if __name__ == '__main__': 178 | tf.app.run() 179 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import numpy as np 3 | import pickle 4 | 5 | 6 | def get_shape(tensor): 7 | static_shape = tensor.shape.as_list() 8 | dynamic_shape = tf.unstack(tf.shape(tensor)) 9 | dims = [s[1] if s[0] is None else s[0] 10 | for s in zip(static_shape, dynamic_shape)] 11 | return dims 12 | 13 | 14 | def count_parameters(trained_vars): 15 | total_parameters = 0 16 | print('=' * 100) 17 | for variable in trained_vars: 18 | variable_parameters = 1 19 | for dim in variable.get_shape(): 20 | variable_parameters *= dim.value 21 | print('{:70} {:20} params'.format(variable.name, variable_parameters)) 22 | print('-' * 100) 23 | total_parameters += variable_parameters 24 | print('=' * 100) 25 | print("Total trainable parameters: %d" % total_parameters) 26 | print('=' * 100) 27 | 28 | 29 | def read_vocab(vocab_file): 30 | print('Loading vocabulary ...') 31 | with open(vocab_file, 'rb') as f: 32 | word_to_index = pickle.load(f) 33 | print('Vocabulary size = %d' % len(word_to_index)) 34 | return word_to_index 35 | 36 | 37 | def batch_doc_normalize(docs): 38 | sent_lengths = np.array([len(doc) for doc in docs], dtype=np.int32) 39 | max_sent_length = sent_lengths.max() 40 | word_lengths = [[len(sent) for sent in doc] for doc in docs] 41 | max_word_length = max(map(max, word_lengths)) 42 | 43 | padded_docs = np.zeros(shape=[len(docs), max_sent_length, max_word_length], dtype=np.int32) # PADDING 0 44 | word_lengths = np.zeros(shape=[len(docs), max_sent_length], dtype=np.int32) 45 | for i, doc in enumerate(docs): 46 | for j, sent in enumerate(doc): 47 | word_lengths[i, j] = len(sent) 48 | for k, word in enumerate(sent): 49 | padded_docs[i, j, k] = word 50 | 51 | return padded_docs, sent_lengths, max_sent_length, word_lengths, max_word_length 52 | 53 | 54 | def load_glove(glove_file, emb_size, vocab): 55 | print('Loading Glove 
pre-trained word embeddings ...') 56 | embedding_weights = {} 57 | f = open(glove_file, encoding='utf-8') 58 | for line in f: 59 | values = line.split() 60 | word = values[0] 61 | vector = np.asarray(values[1:], dtype='float32') 62 | embedding_weights[word] = vector 63 | f.close() 64 | print('Total {} word vectors in {}'.format(len(embedding_weights), glove_file)) 65 | 66 | embedding_matrix = np.random.uniform(-0.5, 0.5, (len(vocab), emb_size)) / emb_size 67 | 68 | oov_count = 0 69 | for word, i in vocab.items(): 70 | embedding_vector = embedding_weights.get(word) 71 | if embedding_vector is not None: 72 | embedding_matrix[i] = embedding_vector 73 | else: 74 | oov_count += 1 75 | print('Number of OOV words = %d' % oov_count) 76 | 77 | return embedding_matrix 78 | --------------------------------------------------------------------------------
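As a small usage sketch (illustration only, not part of the repo), the snippet below traces `utils.batch_doc_normalize` on a toy batch of two documents; the shapes and values in the comments follow directly from the function above and match what `Model.get_feed_dict` feeds into the placeholders:

```python
# Toy walkthrough of batch_doc_normalize (illustration only, not repo code).
from utils import batch_doc_normalize

docs = [
    [[3, 4, 5], [6, 7]],  # doc 1: two sentences with 3 and 2 word ids
    [[8, 9]],             # doc 2: one sentence with 2 word ids
]

padded, sent_lens, max_sents, word_lens, max_words = batch_doc_normalize(docs)

# padded has shape (2, 2, 3) -- word ids zero-padded to the longest sentence and document:
#   [[[3, 4, 5], [6, 7, 0]],
#    [[8, 9, 0], [0, 0, 0]]]
# sent_lens            == [2, 1]            sentences per document
# word_lens            == [[3, 2], [2, 0]]  words per sentence (0 for padded sentences)
# max_sents, max_words == 2, 3
```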