├── .gitignore
├── LICENSE
├── README.md
├── data_prepare.py
├── data_reader.py
├── img
│   ├── model.png
│   └── train_log.png
├── layers.py
├── model.py
├── train.py
└── utils.py
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | *.egg-info/
24 | .installed.cfg
25 | *.egg
26 | MANIFEST
27 |
28 | # PyInstaller
29 | # Usually these files are written by a python script from a template
30 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
31 | *.manifest
32 | *.spec
33 |
34 | # Installer logs
35 | pip-log.txt
36 | pip-delete-this-directory.txt
37 |
38 | # Unit test / coverage reports
39 | htmlcov/
40 | .tox/
41 | .coverage
42 | .coverage.*
43 | .cache
44 | nosetests.xml
45 | coverage.xml
46 | *.cover
47 | .hypothesis/
48 | .pytest_cache/
49 |
50 | # Translations
51 | *.mo
52 | *.pot
53 |
54 | # Django stuff:
55 | *.log
56 | local_settings.py
57 | db.sqlite3
58 |
59 | # Flask stuff:
60 | instance/
61 | .webassets-cache
62 |
63 | # Scrapy stuff:
64 | .scrapy
65 |
66 | # Sphinx documentation
67 | docs/_build/
68 |
69 | # PyBuilder
70 | target/
71 |
72 | # Jupyter Notebook
73 | .ipynb_checkpoints
74 |
75 | # pyenv
76 | .python-version
77 |
78 | # celery beat schedule file
79 | celerybeat-schedule
80 |
81 | # SageMath parsed files
82 | *.sage.py
83 |
84 | # Environments
85 | .env
86 | .venv
87 | env/
88 | venv/
89 | ENV/
90 | env.bak/
91 | venv.bak/
92 |
93 | # Spyder project settings
94 | .spyderproject
95 | .spyproject
96 |
97 | # Rope project settings
98 | .ropeproject
99 |
100 | # mkdocs documentation
101 | /site
102 |
103 | # mypy
104 | .mypy_cache/
105 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2018 Quoc-Tuan Truong
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Hierarchical Attention Networks for Document Classification
2 |
3 | This is an implementation of the paper [Hierarchical Attention Networks for Document Classification](https://www.cs.cmu.edu/~hovy/papers/16HLT-hierarchical-attention-networks.pdf), NAACL 2016.
4 |
5 | 
6 |
7 | ## Requirements
8 |
9 | - Python 3
10 | - TensorFlow > 1.0
11 | - pandas
12 | - NLTK
13 | - tqdm
14 | - [GloVe pre-trained word embeddings](http://nlp.stanford.edu/data/glove.6B.zip)
15 |
16 | ## Data
17 |
18 | We use the [data](http://ir.hit.edu.cn/~dytang/paper/emnlp2015/emnlp-2015-data.7z) provided by [Tang et al. 2015](http://ir.hit.edu.cn/~dytang/paper/emnlp2015/emnlp2015.pdf), including 4 datasets:
19 |
20 | - IMDB
21 | - Yelp 2013
22 | - Yelp 2014
23 | - Yelp 2015
24 |
25 | **Note:**
26 | The original data seems to have an [issue](https://github.com/tqtg/hierarchical-attention-networks/issues/1) with unzipping. I re-uploaded the [data](https://drive.google.com/file/d/1OQ_ggjlNUWiTg_zFXc0_OpYXpJRwJP3y) to Google Drive for faster downloading. Please request access permission.
27 |
28 | ## Usage
29 |
30 | First, download the [datasets](#data) and unzip them into the `data` folder.
31 |
32 | Then, run the script to prepare the data *(the Yelp-2015 dataset is used by default)*:
33 |
34 | ```bash
35 | python data_prepare.py
36 | ```
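
Based on the defaults in `data_prepare.py`, the unzipped Yelp-2015 files are expected under `data/`, and the script writes its pickled outputs next to them (a sketch of the layout, not an exhaustive listing):

```
data/
├── yelp-2015-train.txt.ss    # inputs from the Tang et al. archive
├── yelp-2015-dev.txt.ss
├── yelp-2015-test.txt.ss
├── yelp-2015-w2i.pkl         # outputs written by data_prepare.py
├── yelp-2015-i2w.pkl
├── yelp-2015-train.pkl
├── yelp-2015-dev.pkl
└── yelp-2015-test.pkl
```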
37 |
38 | Train and evaluate the model:
39 |
40 | *(make sure the [GloVe embeddings](#requirements) are ready before training)*
41 | ```bash
42 | wget http://nlp.stanford.edu/data/glove.6B.zip
43 | unzip glove.6B.zip
44 | ```
45 | ```bash
46 | python train.py
47 | ```
48 |
49 | Print training arguments:
50 |
51 | ```bash
52 | python train.py --help
53 | ```
54 | ```
55 | optional arguments:
56 | -h, --help show this help message and exit
57 | --cell_dim CELL_DIM
58 | Hidden dimensions of GRU cells (default: 50)
59 | --att_dim ATTENTION_DIM
60 | Dimensionality of attention spaces (default: 100)
61 | --emb_size EMB_SIZE
62 | Dimensionality of word embeddings (default: 200)
63 | --learning_rate LEARNING_RATE
64 | Learning rate (default: 0.0005)
65 | --max_grad_norm MAX_GRAD_NORM
66 | Maximum value of the global norm of the gradients for clipping (default: 5.0)
67 | --dropout_rate DROPOUT_RATE
68 | Probability of dropping neurons (default: 0.5)
69 | --num_classes NUM_CLASSES
70 | Number of classes (default: 5)
71 | --num_checkpoints NUM_CHECKPOINTS
72 | Number of checkpoints to store (default: 1)
73 | --num_epochs NUM_EPOCHS
74 | Number of training epochs (default: 20)
75 | --batch_size BATCH_SIZE
76 | Batch size (default: 64)
77 | --display_step DISPLAY_STEP
78 | Number of steps between writing logs to TensorBoard (default: 20)
79 | --allow_soft_placement ALLOW_SOFT_PLACEMENT
80 | Allow soft device placement
81 | ```
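
The flags mirror those defined in `train.py`; for example, to train with larger GRU cells and a smaller batch:

```bash
python train.py --cell_dim 100 --batch_size 32 --num_epochs 10
```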
82 |
83 | ## Results
84 |
85 | With the *Yelp-2015* dataset, after 5 epochs, we achieved:
86 |
87 | - **69.79%** accuracy on the *dev set*
88 | - **69.62%** accuracy on the *test set*
89 |
90 | No systematic hyper-parameter tuning was performed. The accuracy reported in the paper for *Yelp-2015* is **71.0%**.
91 |
92 | 
93 |
--------------------------------------------------------------------------------
/data_prepare.py:
--------------------------------------------------------------------------------
1 | import pandas as pd
2 | import nltk
3 | import itertools
4 | import pickle
5 |
6 | # Hyperparameters
7 | WORD_CUT_OFF = 5
8 |
9 |
10 | def build_vocab(docs, save_path):
11 | print('Building vocab ...')
12 |
13 | sents = itertools.chain(*[text.split('<sssss>') for text in docs])  # sentences are delimited by the '<sssss>' token in the Tang et al. files
14 | tokenized_sents = [sent.split() for sent in sents]
15 |
16 | # Count the word frequencies
17 | word_freq = nltk.FreqDist(itertools.chain(*tokenized_sents))
18 | print("%d unique words found" % len(word_freq.items()))
19 |
20 | # Cut-off
21 | retained_words = [w for (w, f) in word_freq.items() if f > WORD_CUT_OFF]
22 | print("%d words retained" % len(retained_words))
23 |
24 | # Build the word_to_index and index_to_word mappings for the retained words
25 | # Word index starts from 2, 1 is reserved for UNK, 0 is reserved for padding
26 | word_to_index = {'PAD': 0, 'UNK': 1}
27 | for i, w in enumerate(retained_words):
28 | word_to_index[w] = i + 2
29 | index_to_word = {i: w for (w, i) in word_to_index.items()}
30 |
31 | print("Vocabulary size = %d" % len(word_to_index))
32 |
33 | with open('{}-w2i.pkl'.format(save_path), 'wb') as f:
34 | pickle.dump(word_to_index, f)
35 |
36 | with open('{}-i2w.pkl'.format(save_path), 'wb') as f:
37 | pickle.dump(index_to_word, f)
38 |
39 | return word_to_index
40 |
41 |
42 | def process_and_save(word_to_index, data, out_file):
43 | mapped_data = []
44 | for label, doc in zip(data[4], data[6]):
45 | mapped_doc = [[word_to_index.get(word, 1) for word in sent.split()] for sent in doc.split('<sssss>')]
46 | mapped_data.append((label, mapped_doc))
47 |
48 | with open(out_file, 'wb') as f:
49 | pickle.dump(mapped_data, f)
50 |
51 |
52 | def read_data(data_file):
53 | data = pd.read_csv(data_file, sep='\t', header=None, usecols=[4, 6])
54 | print('{}, shape={}'.format(data_file, data.shape))
55 | return data
56 |
57 |
58 | if __name__ == '__main__':
59 | train_data = read_data('data/yelp-2015-train.txt.ss')
60 | word_to_index = build_vocab(train_data[6], 'data/yelp-2015')
61 | process_and_save(word_to_index, train_data, 'data/yelp-2015-train.pkl')
62 |
63 | dev_data = read_data('data/yelp-2015-dev.txt.ss')
64 | process_and_save(word_to_index, dev_data, 'data/yelp-2015-dev.pkl')
65 |
66 | test_data = read_data('data/yelp-2015-test.txt.ss')
67 | process_and_save(word_to_index, test_data, 'data/yelp-2015-test.pkl')
68 |
--------------------------------------------------------------------------------
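
A minimal sketch of the mapping performed by `process_and_save`, using a toy vocabulary (hypothetical values) and the `<sssss>` sentence delimiter assumed above:

```python
# Each document becomes (label, [[word ids of sentence 1], [word ids of sentence 2], ...]).
word_to_index = {'PAD': 0, 'UNK': 1, 'the': 2, 'food': 3, 'was': 4, 'great': 5}

doc = 'the food was great <sssss> service was slow'
mapped_doc = [[word_to_index.get(w, 1) for w in sent.split()]
              for sent in doc.split('<sssss>')]
print(mapped_doc)  # [[2, 3, 4, 5], [1, 4, 1]] -- 'service' and 'slow' fall back to UNK (1)
```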
/data_reader.py:
--------------------------------------------------------------------------------
1 | import pickle
2 | from tqdm import tqdm
3 | import random
4 | import numpy as np
5 |
6 | class DataReader:
7 | def __init__(self, train_file, dev_file, test_file,
8 | max_word_length=50, max_sent_length=30, num_classes=5):
9 | self.max_word_length = max_word_length
10 | self.max_sent_length = max_sent_length
11 | self.num_classes = num_classes
12 |
13 | self.train_data = self._read_data(train_file)
14 | self.valid_data = self._read_data(dev_file)
15 | self.test_data = self._read_data(test_file)
16 |
17 | def _read_data(self, file_path):
18 | print('Reading data from %s' % file_path)
19 | new_data = []
20 | with open(file_path, 'rb') as f:
21 | data = pickle.load(f)
22 | random.shuffle(data)
23 | for label, doc in data:
24 | doc = doc[:self.max_sent_length]
25 | doc = [sent[:self.max_word_length] for sent in doc]
26 |
27 | label -= 1
28 | assert label >= 0 and label < self.num_classes
29 |
30 | new_data.append((doc, label))
31 |
32 | # sort documents by sentence count to speed up batching
33 | new_data = sorted(new_data, key=lambda x: len(x[0]))
34 | return new_data
35 |
36 | def _batch_iterator(self, data, batch_size, desc=None):
37 | num_batches = int(np.ceil(len(data) / batch_size))
38 | for b in tqdm(range(num_batches), desc):
39 | begin_offset = batch_size * b
40 | end_offset = batch_size * b + batch_size
41 | if end_offset > len(data):
42 | end_offset = len(data)
43 |
44 | doc_batch = []
45 | label_batch = []
46 | for offset in range(begin_offset, end_offset):
47 | doc_batch.append(data[offset][0])
48 | label_batch.append(data[offset][1])
49 |
50 | yield doc_batch, label_batch
51 |
52 | def read_train_set(self, batch_size, shuffle=False):
53 | if shuffle:
54 | random.shuffle(self.train_data)
55 | return self._batch_iterator(self.train_data, batch_size, desc='Training')
56 |
57 | def read_valid_set(self, batch_size):
58 | return self._batch_iterator(self.valid_data, batch_size, desc='Validating')
59 |
60 | def read_test_set(self, batch_size):
61 | return self._batch_iterator(self.test_data, batch_size, desc='Testing')
62 |
--------------------------------------------------------------------------------
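
A minimal usage sketch of `DataReader`, assuming the pickled files produced by `data_prepare.py` already exist under `data/`:

```python
from data_reader import DataReader

reader = DataReader(train_file='data/yelp-2015-train.pkl',
                    dev_file='data/yelp-2015-dev.pkl',
                    test_file='data/yelp-2015-test.pkl')

for doc_batch, label_batch in reader.read_train_set(batch_size=64, shuffle=True):
    # doc_batch: list of documents, each a list of sentences (lists of word ids)
    # label_batch: list of zero-based class labels
    print(len(doc_batch), label_batch[:5])
    break
```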
/img/model.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tqtg/hierarchical-attention-networks/0b36b64115137bebbb66760fe9e543a1298bc158/img/model.png
--------------------------------------------------------------------------------
/img/train_log.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tqtg/hierarchical-attention-networks/0b36b64115137bebbb66760fe9e543a1298bc158/img/train_log.png
--------------------------------------------------------------------------------
/layers.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 | import numpy as np
3 | from utils import get_shape
4 |
5 | try:
6 | from tensorflow.contrib.rnn import LSTMStateTuple
7 | except ImportError:
8 | LSTMStateTuple = tf.nn.rnn_cell.LSTMStateTuple
9 |
10 |
11 |
12 | def bidirectional_rnn(cell_fw, cell_bw, inputs, input_lengths,
13 | initial_state_fw=None, initial_state_bw=None,
14 | scope=None):
15 | with tf.variable_scope(scope or 'bi_rnn') as scope:
16 | (fw_outputs, bw_outputs), (fw_state, bw_state) = tf.nn.bidirectional_dynamic_rnn(
17 | cell_fw=cell_fw,
18 | cell_bw=cell_bw,
19 | inputs=inputs,
20 | sequence_length=input_lengths,
21 | initial_state_fw=initial_state_fw,
22 | initial_state_bw=initial_state_bw,
23 | dtype=tf.float32,
24 | scope=scope
25 | )
26 | outputs = tf.concat((fw_outputs, bw_outputs), axis=2)
27 |
28 | def concatenate_state(fw_state, bw_state):
29 | if isinstance(fw_state, LSTMStateTuple):
30 | state_c = tf.concat(
31 | (fw_state.c, bw_state.c), 1, name='bidirectional_concat_c')
32 | state_h = tf.concat(
33 | (fw_state.h, bw_state.h), 1, name='bidirectional_concat_h')
34 | state = LSTMStateTuple(c=state_c, h=state_h)
35 | return state
36 | elif isinstance(fw_state, tf.Tensor):
37 | state = tf.concat((fw_state, bw_state), 1,
38 | name='bidirectional_concat')
39 | return state
40 | elif (isinstance(fw_state, tuple) and
41 | isinstance(bw_state, tuple) and
42 | len(fw_state) == len(bw_state)):
43 | # multilayer
44 | state = tuple(concatenate_state(fw, bw)
45 | for fw, bw in zip(fw_state, bw_state))
46 | return state
47 |
48 | else:
49 | raise ValueError(
50 | 'unknown state type: {}'.format((fw_state, bw_state)))
51 |
52 | state = concatenate_state(fw_state, bw_state)
53 | return outputs, state
54 |
55 |
56 | def masking(scores, sequence_lengths, score_mask_value=tf.constant(-np.inf)):
57 | score_mask = tf.sequence_mask(sequence_lengths, maxlen=tf.shape(scores)[1])
58 | score_mask_values = score_mask_value * tf.ones_like(scores)
59 | return tf.where(score_mask, scores, score_mask_values)
60 |
61 |
62 | def attention(inputs, att_dim, sequence_lengths, scope=None):
63 | assert len(inputs.get_shape()) == 3 and inputs.get_shape()[-1].value is not None
64 |
65 | with tf.variable_scope(scope or 'attention'):
66 | word_att_W = tf.get_variable(name='att_W', shape=[att_dim, 1])
67 |
68 | projection = tf.layers.dense(inputs, att_dim, tf.nn.tanh, name='projection')
69 |
70 | alpha = tf.matmul(tf.reshape(projection, shape=[-1, att_dim]), word_att_W)
71 | alpha = tf.reshape(alpha, shape=[-1, get_shape(inputs)[1]])
72 | alpha = masking(alpha, sequence_lengths, tf.constant(-1e15, dtype=tf.float32))
73 | alpha = tf.nn.softmax(alpha)
74 |
75 | outputs = tf.reduce_sum(inputs * tf.expand_dims(alpha, 2), axis=1)
76 | return outputs, alpha
77 |
--------------------------------------------------------------------------------
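
To make the masking step concrete, here is a NumPy-only sketch (not part of the TensorFlow graph) of how padded positions are pushed to ~zero weight before the softmax, mirroring `masking()` followed by the softmax in `attention()`:

```python
import numpy as np

def masked_softmax(scores, lengths, mask_value=-1e15):
    # scores: [batch, max_len] unnormalized attention scores
    # lengths: [batch] true sequence lengths
    max_len = scores.shape[1]
    mask = np.arange(max_len)[None, :] < np.asarray(lengths)[:, None]
    scores = np.where(mask, scores, mask_value)  # same idea as masking()
    exp = np.exp(scores - scores.max(axis=1, keepdims=True))
    return exp / exp.sum(axis=1, keepdims=True)

alpha = masked_softmax(np.array([[0.2, 1.5, -0.3], [0.7, 0.1, 0.9]]), lengths=[2, 3])
print(alpha)  # the padded third position of the first row receives ~0 attention
```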
/model.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 | import tensorflow.contrib.rnn as rnn
3 | from layers import bidirectional_rnn, attention
4 | from utils import get_shape, batch_doc_normalize
5 |
6 |
7 | class Model:
8 | def __init__(self, cell_dim, att_dim, vocab_size, emb_size, num_classes, dropout_rate, pretrained_embs):
9 | self.cell_dim = cell_dim
10 | self.att_dim = att_dim
11 | self.emb_size = emb_size
12 | self.vocab_size = vocab_size
13 | self.num_classes = num_classes
14 | self.dropout_rate = dropout_rate
15 | self.pretrained_embs = pretrained_embs
16 |
17 | self.docs = tf.placeholder(shape=(None, None, None), dtype=tf.int32, name='docs')
18 | self.sent_lengths = tf.placeholder(shape=(None,), dtype=tf.int32, name='sent_lengths')
19 | self.word_lengths = tf.placeholder(shape=(None, None), dtype=tf.int32, name='word_lengths')
20 | self.max_word_length = tf.placeholder(dtype=tf.int32, name='max_word_length')
21 | self.max_sent_length = tf.placeholder(dtype=tf.int32, name='max_sent_length')
22 | self.labels = tf.placeholder(shape=(None,), dtype=tf.int32, name='labels')
23 | self.is_training = tf.placeholder(dtype=tf.bool, name='is_training')
24 |
25 | self._init_embedding()
26 | self._init_word_encoder()
27 | self._init_sent_encoder()
28 | self._init_classifier()
29 |
30 | def _init_embedding(self):
31 | with tf.variable_scope('embedding'):
32 | self.embedding_matrix = tf.get_variable(name='embedding_matrix',
33 | shape=[self.vocab_size, self.emb_size],
34 | initializer=tf.constant_initializer(self.pretrained_embs),
35 | dtype=tf.float32)
36 | self.embedded_inputs = tf.nn.embedding_lookup(self.embedding_matrix, self.docs)
37 |
38 | def _init_word_encoder(self):
39 | with tf.variable_scope('word-encoder') as scope:
40 | word_inputs = tf.reshape(self.embedded_inputs, [-1, self.max_word_length, self.emb_size])
41 | word_lengths = tf.reshape(self.word_lengths, [-1])
42 |
43 | # word encoder
44 | cell_fw = rnn.GRUCell(self.cell_dim, name='cell_fw')
45 | cell_bw = rnn.GRUCell(self.cell_dim, name='cell_bw')
46 |
47 | init_state_fw = tf.tile(tf.get_variable('init_state_fw',
48 | shape=[1, self.cell_dim],
49 | initializer=tf.constant_initializer(0)),
50 | multiples=[get_shape(word_inputs)[0], 1])
51 | init_state_bw = tf.tile(tf.get_variable('init_state_bw',
52 | shape=[1, self.cell_dim],
53 | initializer=tf.constant_initializer(0)),
54 | multiples=[get_shape(word_inputs)[0], 1])
55 |
56 | rnn_outputs, _ = bidirectional_rnn(cell_fw=cell_fw,
57 | cell_bw=cell_bw,
58 | inputs=word_inputs,
59 | input_lengths=word_lengths,
60 | initial_state_fw=init_state_fw,
61 | initial_state_bw=init_state_bw,
62 | scope=scope)
63 |
64 | word_outputs, word_att_weights = attention(inputs=rnn_outputs,
65 | att_dim=self.att_dim,
66 | sequence_lengths=word_lengths)
67 | self.word_outputs = tf.layers.dropout(word_outputs, self.dropout_rate, training=self.is_training)
68 |
69 | def _init_sent_encoder(self):
70 | with tf.variable_scope('sent-encoder') as scope:
71 | sent_inputs = tf.reshape(self.word_outputs, [-1, self.max_sent_length, 2 * self.cell_dim])
72 |
73 | # sentence encoder
74 | cell_fw = rnn.GRUCell(self.cell_dim, name='cell_fw')
75 | cell_bw = rnn.GRUCell(self.cell_dim, name='cell_bw')
76 |
77 | init_state_fw = tf.tile(tf.get_variable('init_state_fw',
78 | shape=[1, self.cell_dim],
79 | initializer=tf.constant_initializer(0)),
80 | multiples=[get_shape(sent_inputs)[0], 1])
81 | init_state_bw = tf.tile(tf.get_variable('init_state_bw',
82 | shape=[1, self.cell_dim],
83 | initializer=tf.constant_initializer(0)),
84 | multiples=[get_shape(sent_inputs)[0], 1])
85 |
86 | rnn_outputs, _ = bidirectional_rnn(cell_fw=cell_fw,
87 | cell_bw=cell_bw,
88 | inputs=sent_inputs,
89 | input_lengths=self.sent_lengths,
90 | initial_state_fw=init_state_fw,
91 | initial_state_bw=init_state_bw,
92 | scope=scope)
93 |
94 | sent_outputs, sent_att_weights = attention(inputs=rnn_outputs,
95 | att_dim=self.att_dim,
96 | sequence_lengths=self.sent_lengths)
97 | self.sent_outputs = tf.layers.dropout(sent_outputs, self.dropout_rate, training=self.is_training)
98 |
99 | def _init_classifier(self):
100 | with tf.variable_scope('classifier'):
101 | self.logits = tf.layers.dense(inputs=self.sent_outputs, units=self.num_classes, name='logits')
102 |
103 | def get_feed_dict(self, docs, labels, training=False):
104 | padded_docs, sent_lengths, max_sent_length, word_lengths, max_word_length = batch_doc_normalize(docs)
105 | fd = {
106 | self.docs: padded_docs,
107 | self.sent_lengths: sent_lengths,
108 | self.word_lengths: word_lengths,
109 | self.max_sent_length: max_sent_length,
110 | self.max_word_length: max_word_length,
111 | self.labels: labels,
112 | self.is_training: training
113 | }
114 | return fd
115 |
--------------------------------------------------------------------------------
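
A shape walkthrough of the hierarchical reshaping in `Model`, using the defaults from `train.py` (cell_dim=50, emb_size=200, num_classes=5) and an illustrative batch of B=64 documents padded to S=30 sentences of W=50 words (in practice S and W are the per-batch maxima computed by `batch_doc_normalize`, capped by `DataReader`):

```python
# docs:                                  [B, S, W]          = [64, 30, 50]      int32 word ids
# embedded_inputs:                       [B, S, W, emb]     = [64, 30, 50, 200]
# word encoder input (reshape):          [B*S, W, emb]      = [1920, 50, 200]
# word bi-GRU outputs:                   [B*S, W, 2*cell]   = [1920, 50, 100]
# word attention output:                 [B*S, 2*cell]      = [1920, 100]
# sentence encoder input (reshape back): [B, S, 2*cell]     = [64, 30, 100]
# sentence attention output:             [B, 2*cell]        = [64, 100]
# logits:                                [B, num_classes]   = [64, 5]
```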
/train.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 | from datetime import datetime
3 | from data_reader import DataReader
4 | from model import Model
5 | from utils import read_vocab, count_parameters, load_glove
6 |
7 | # Parameters
8 | # ==================================================
9 | FLAGS = tf.flags.FLAGS
10 |
11 | tf.flags.DEFINE_string("checkpoint_dir", 'checkpoints',
12 | """Path to checkpoint folder""")
13 | tf.flags.DEFINE_string("log_dir", 'logs',
14 | """Path to log folder""")
15 |
16 | tf.flags.DEFINE_integer("cell_dim", 50,
17 | """Hidden dimensions of GRU cells (default: 50)""")
18 | tf.flags.DEFINE_integer("att_dim", 100,
19 | """Dimensionality of attention spaces (default: 100)""")
20 | tf.flags.DEFINE_integer("emb_size", 200,
21 | """Dimensionality of word embedding (default: 200)""")
22 | tf.flags.DEFINE_integer("num_classes", 5,
23 | """Number of classes (default: 5)""")
24 |
25 | tf.flags.DEFINE_integer("num_checkpoints", 1,
26 | """Number of checkpoints to store (default: 1)""")
27 | tf.flags.DEFINE_integer("num_epochs", 20,
28 | """Number of training epochs (default: 20)""")
29 | tf.flags.DEFINE_integer("batch_size", 64,
30 | """Batch size (default: 64)""")
31 | tf.flags.DEFINE_integer("display_step", 20,
32 | """Number of steps to display log into TensorBoard (default: 20)""")
33 |
34 | tf.flags.DEFINE_float("learning_rate", 0.0005,
35 | """Learning rate (default: 0.0005)""")
36 | tf.flags.DEFINE_float("max_grad_norm", 5.0,
37 | """Maximum value of the global norm of the gradients for clipping (default: 5.0)""")
38 | tf.flags.DEFINE_float("dropout_rate", 0.5,
39 | """Probability of dropping neurons (default: 0.5)""")
40 |
41 | tf.flags.DEFINE_boolean("allow_soft_placement", True,
42 | """Allow device soft device placement""")
43 |
44 | if not tf.gfile.Exists(FLAGS.checkpoint_dir):
45 | tf.gfile.MakeDirs(FLAGS.checkpoint_dir)
46 |
47 | if not tf.gfile.Exists(FLAGS.log_dir):
48 | tf.gfile.MakeDirs(FLAGS.log_dir)
49 |
50 | train_writer = tf.summary.FileWriter(FLAGS.log_dir + '/train')
51 | valid_writer = tf.summary.FileWriter(FLAGS.log_dir + '/valid')
52 | test_writer = tf.summary.FileWriter(FLAGS.log_dir + '/test')
53 |
54 |
55 | def loss_fn(labels, logits):
56 | onehot_labels = tf.one_hot(labels, depth=FLAGS.num_classes)
57 | cross_entropy_loss = tf.losses.softmax_cross_entropy(onehot_labels=onehot_labels,
58 | logits=logits)
59 | tf.summary.scalar('loss', cross_entropy_loss)
60 | return cross_entropy_loss
61 |
62 |
63 | def train_fn(loss):
64 | trained_vars = tf.trainable_variables()
65 | count_parameters(trained_vars)
66 |
67 | # Gradient clipping
68 | gradients = tf.gradients(loss, trained_vars)
69 |
70 | clipped_grads, global_norm = tf.clip_by_global_norm(gradients, FLAGS.max_grad_norm)
71 | tf.summary.scalar('global_grad_norm', global_norm)
72 |
73 | # Add gradients and vars to summary
74 | # for gradient, var in list(zip(clipped_grads, trained_vars)):
75 | # if 'attention' in var.name:
76 | # tf.summary.histogram(var.name + '/gradient', gradient)
77 | # tf.summary.histogram(var.name, var)
78 |
79 | # Define optimizer
80 | global_step = tf.train.get_or_create_global_step()
81 | optimizer = tf.train.RMSPropOptimizer(FLAGS.learning_rate)
82 | train_op = optimizer.apply_gradients(zip(clipped_grads, trained_vars),
83 | name='train_op',
84 | global_step=global_step)
85 | return train_op, global_step
86 |
87 |
88 | def eval_fn(labels, logits):
89 | predictions = tf.argmax(logits, axis=-1)
90 | correct_preds = tf.equal(predictions, tf.cast(labels, tf.int64))
91 | batch_acc = tf.reduce_mean(tf.cast(correct_preds, tf.float32))
92 | tf.summary.scalar('accuracy', batch_acc)
93 |
94 | total_acc, acc_update = tf.metrics.accuracy(labels, predictions, name='metrics/acc')
95 | metrics_vars = tf.get_collection(tf.GraphKeys.LOCAL_VARIABLES, scope="metrics")
96 | metrics_init = tf.variables_initializer(var_list=metrics_vars)
97 |
98 | return batch_acc, total_acc, acc_update, metrics_init
99 |
100 |
101 | def main(_):
102 | vocab = read_vocab('data/yelp-2015-w2i.pkl')
103 | glove_embs = load_glove('glove.6B.{}d.txt'.format(FLAGS.emb_size), FLAGS.emb_size, vocab)
104 | data_reader = DataReader(train_file='data/yelp-2015-train.pkl',
105 | dev_file='data/yelp-2015-dev.pkl',
106 | test_file='data/yelp-2015-test.pkl')
107 |
108 | config = tf.ConfigProto(allow_soft_placement=FLAGS.allow_soft_placement)
109 | with tf.Session(config=config) as sess:
110 | model = Model(cell_dim=FLAGS.cell_dim,
111 | att_dim=FLAGS.att_dim,
112 | vocab_size=len(vocab),
113 | emb_size=FLAGS.emb_size,
114 | num_classes=FLAGS.num_classes,
115 | dropout_rate=FLAGS.dropout_rate,
116 | pretrained_embs=glove_embs)
117 |
118 | loss = loss_fn(model.labels, model.logits)
119 | train_op, global_step = train_fn(loss)
120 | batch_acc, total_acc, acc_update, metrics_init = eval_fn(model.labels, model.logits)
121 | summary_op = tf.summary.merge_all()
122 | sess.run(tf.global_variables_initializer())
123 |
124 | train_writer.add_graph(sess.graph)
125 | saver = tf.train.Saver(max_to_keep=FLAGS.num_checkpoints)
126 |
127 | print('\n{}> Start training'.format(datetime.now()))
128 |
129 | epoch = 0
130 | valid_step = 0
131 | test_step = 0
132 | train_test_prop = len(data_reader.train_data) / len(data_reader.test_data)
133 | test_batch_size = int(FLAGS.batch_size / train_test_prop)
134 | best_acc = float('-inf')
135 |
136 | while epoch < FLAGS.num_epochs:
137 | epoch += 1
138 | print('\n{}> Epoch: {}'.format(datetime.now(), epoch))
139 |
140 | sess.run(metrics_init)
141 | for batch_docs, batch_labels in data_reader.read_train_set(FLAGS.batch_size, shuffle=True):
142 | _step, _, _loss, _acc, _ = sess.run([global_step, train_op, loss, batch_acc, acc_update],
143 | feed_dict=model.get_feed_dict(batch_docs, batch_labels, training=True))
144 | if _step % FLAGS.display_step == 0:
145 | _summary = sess.run(summary_op, feed_dict=model.get_feed_dict(batch_docs, batch_labels))
146 | train_writer.add_summary(_summary, global_step=_step)
147 | print('Training accuracy = {:.2f}'.format(sess.run(total_acc) * 100))
148 |
149 | sess.run(metrics_init)
150 | for batch_docs, batch_labels in data_reader.read_valid_set(test_batch_size):
151 | _loss, _acc, _ = sess.run([loss, batch_acc, acc_update], feed_dict=model.get_feed_dict(batch_docs, batch_labels))
152 | valid_step += 1
153 | if valid_step % FLAGS.display_step == 0:
154 | _summary = sess.run(summary_op, feed_dict=model.get_feed_dict(batch_docs, batch_labels))
155 | valid_writer.add_summary(_summary, global_step=valid_step)
156 | print('Validation accuracy = {:.2f}'.format(sess.run(total_acc) * 100))
157 |
158 | sess.run(metrics_init)
159 | for batch_docs, batch_labels in data_reader.read_test_set(test_batch_size):
160 | _loss, _acc, _ = sess.run([loss, batch_acc, acc_update], feed_dict=model.get_feed_dict(batch_docs, batch_labels))
161 | test_step += 1
162 | if test_step % FLAGS.display_step == 0:
163 | _summary = sess.run(summary_op, feed_dict=model.get_feed_dict(batch_docs, batch_labels))
164 | test_writer.add_summary(_summary, global_step=test_step)
165 | test_acc = sess.run(total_acc) * 100
166 | print('Testing accuracy = {:.2f}'.format(test_acc))
167 |
168 | if test_acc > best_acc:
169 | best_acc = test_acc
170 | saver.save(sess, FLAGS.checkpoint_dir + '/model')  # save checkpoint files inside the checkpoint folder
171 | print('Best testing accuracy = {:.2f}'.format(test_acc))
172 |
173 | print("{} Optimization Finished!".format(datetime.now()))
174 | print('Best testing accuracy = {:.2f}'.format(best_acc))
175 |
176 |
177 | if __name__ == '__main__':
178 | tf.app.run()
179 |
--------------------------------------------------------------------------------
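
The summaries written by `train.py` go to `logs/train`, `logs/valid`, and `logs/test` and can be inspected with TensorBoard, e.g.:

```bash
python train.py --num_epochs 5
tensorboard --logdir logs
```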
/utils.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 | import numpy as np
3 | import pickle
4 |
5 |
6 | def get_shape(tensor):
7 | static_shape = tensor.shape.as_list()
8 | dynamic_shape = tf.unstack(tf.shape(tensor))
9 | dims = [s[1] if s[0] is None else s[0]
10 | for s in zip(static_shape, dynamic_shape)]
11 | return dims
12 |
13 |
14 | def count_parameters(trained_vars):
15 | total_parameters = 0
16 | print('=' * 100)
17 | for variable in trained_vars:
18 | variable_parameters = 1
19 | for dim in variable.get_shape():
20 | variable_parameters *= dim.value
21 | print('{:70} {:20} params'.format(variable.name, variable_parameters))
22 | print('-' * 100)
23 | total_parameters += variable_parameters
24 | print('=' * 100)
25 | print("Total trainable parameters: %d" % total_parameters)
26 | print('=' * 100)
27 |
28 |
29 | def read_vocab(vocab_file):
30 | print('Loading vocabulary ...')
31 | with open(vocab_file, 'rb') as f:
32 | word_to_index = pickle.load(f)
33 | print('Vocabulary size = %d' % len(word_to_index))
34 | return word_to_index
35 |
36 |
37 | def batch_doc_normalize(docs):
38 | sent_lengths = np.array([len(doc) for doc in docs], dtype=np.int32)
39 | max_sent_length = sent_lengths.max()
40 | word_lengths = [[len(sent) for sent in doc] for doc in docs]
41 | max_word_length = max(map(max, word_lengths))
42 |
43 | padded_docs = np.zeros(shape=[len(docs), max_sent_length, max_word_length], dtype=np.int32) # PADDING 0
44 | word_lengths = np.zeros(shape=[len(docs), max_sent_length], dtype=np.int32)
45 | for i, doc in enumerate(docs):
46 | for j, sent in enumerate(doc):
47 | word_lengths[i, j] = len(sent)
48 | for k, word in enumerate(sent):
49 | padded_docs[i, j, k] = word
50 |
51 | return padded_docs, sent_lengths, max_sent_length, word_lengths, max_word_length
52 |
53 |
54 | def load_glove(glove_file, emb_size, vocab):
55 | print('Loading Glove pre-trained word embeddings ...')
56 | embedding_weights = {}
57 | f = open(glove_file, encoding='utf-8')
58 | for line in f:
59 | values = line.split()
60 | word = values[0]
61 | vector = np.asarray(values[1:], dtype='float32')
62 | embedding_weights[word] = vector
63 | f.close()
64 | print('Total {} word vectors in {}'.format(len(embedding_weights), glove_file))
65 |
66 | embedding_matrix = np.random.uniform(-0.5, 0.5, (len(vocab), emb_size)) / emb_size
67 |
68 | oov_count = 0
69 | for word, i in vocab.items():
70 | embedding_vector = embedding_weights.get(word)
71 | if embedding_vector is not None:
72 | embedding_matrix[i] = embedding_vector
73 | else:
74 | oov_count += 1
75 | print('Number of OOV words = %d' % oov_count)
76 |
77 | return embedding_matrix
78 |
--------------------------------------------------------------------------------
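
A small sketch of what `batch_doc_normalize` returns for a toy batch of two documents (word ids are arbitrary; TensorFlow must be installed because `utils.py` imports it at module level):

```python
from utils import batch_doc_normalize

docs = [
    [[5, 6, 7], [8, 9]],   # 2 sentences, longest has 3 words
    [[10, 11]],            # 1 sentence of 2 words
]
padded, sent_lens, max_sent, word_lens, max_word = batch_doc_normalize(docs)
print(padded.shape)        # (2, 2, 3) -- zero-padded to the batch maxima
print(sent_lens)           # [2 1]
print(word_lens)           # [[3 2]
                           #  [2 0]]
print(max_sent, max_word)  # 2 3
```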