├── .gitignore
├── LICENSE
├── data_helpers.py
├── eval.py
├── README.md
├── rnn.py
└── train.py
/.gitignore:
--------------------------------------------------------------------------------
1 | w2v_model/
2 | saved/
3 | runs/
4 | **/.idea/
5 | **/__pycache__/
6 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2018 Joohong Lee
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/data_helpers.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import re
3 |
4 |
5 | def clean_str(string):
6 | """
7 | Tokenization/string cleaning for all datasets except for SST.
8 | Original taken from https://github.com/yoonkim/CNN_sentence/blob/master/process_data.py
9 | """
10 | string = re.sub(r"[^A-Za-z0-9(),!?\'\`]", " ", string)
11 | string = re.sub(r"\'s", " \'s", string)
12 | string = re.sub(r"\'ve", " \'ve", string)
13 | string = re.sub(r"n\'t", " n\'t", string)
14 | string = re.sub(r"\'re", " \'re", string)
15 | string = re.sub(r"\'d", " \'d", string)
16 | string = re.sub(r"\'ll", " \'ll", string)
17 | string = re.sub(r",", " , ", string)
18 | string = re.sub(r"!", " ! ", string)
19 | string = re.sub(r"\(", " \( ", string)
20 | string = re.sub(r"\)", " \) ", string)
21 | string = re.sub(r"\?", " \? ", string)
22 | string = re.sub(r"\s{2,}", " ", string)
23 | return string.strip().lower()
24 |
25 |
26 | def load_data_and_labels(positive_data_file, negative_data_file):
27 | """
28 | Loads MR polarity data from files, splits the data into words and generates labels.
29 | Returns split sentences and labels.
30 | """
31 | # Load data from files
32 |     with open(positive_data_file, "r", encoding="utf-8") as f:
33 |         positive_examples = [s.strip() for s in f.readlines()]
34 |     with open(negative_data_file, "r", encoding="utf-8") as f:
35 |         negative_examples = [s.strip() for s in f.readlines()]
36 | # Split by words
37 | x_text = positive_examples + negative_examples
38 | x_text = [clean_str(sent) for sent in x_text]
39 | # Generate labels
40 | positive_labels = [[0, 1] for _ in positive_examples]
41 | negative_labels = [[1, 0] for _ in negative_examples]
42 | y = np.concatenate([positive_labels, negative_labels], 0)
43 | return [x_text, y]
44 |
45 |
46 | def batch_iter(data, batch_size, num_epochs, shuffle=True):
47 | """
48 | Generates a batch iterator for a dataset.
49 | """
50 | data = np.array(data)
51 | data_size = len(data)
52 | num_batches_per_epoch = int((len(data) - 1) / batch_size) + 1
53 | for epoch in range(num_epochs):
54 | # Shuffle the data at each epoch
55 | if shuffle:
56 | shuffle_indices = np.random.permutation(np.arange(data_size))
57 | shuffled_data = data[shuffle_indices]
58 | else:
59 | shuffled_data = data
60 | for batch_num in range(num_batches_per_epoch):
61 | start_index = batch_num * batch_size
62 | end_index = min((batch_num + 1) * batch_size, data_size)
63 | yield shuffled_data[start_index:end_index]
64 |
65 |
66 | if __name__ == "__main__":
67 | pos_dir = "data/rt-polaritydata/rt-polarity.pos"
68 | neg_dir = "data/rt-polaritydata/rt-polarity.neg"
69 |
70 | load_data_and_labels(pos_dir, neg_dir)
71 |
--------------------------------------------------------------------------------
/eval.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 | import numpy as np
3 | import os
4 | import data_helpers
5 |
6 |
7 | # Parameters
8 | # ==================================================
9 |
10 | # Data loading params
11 | tf.flags.DEFINE_string("pos_dir", "data/rt-polaritydata/rt-polarity.pos", "Path of positive data")
12 | tf.flags.DEFINE_string("neg_dir", "data/rt-polaritydata/rt-polarity.neg", "Path of negative data")
13 |
14 | # Eval Parameters
15 | tf.flags.DEFINE_integer("batch_size", 64, "Batch Size (Default: 64)")
16 | tf.flags.DEFINE_string("checkpoint_dir", "", "Checkpoint directory from training run")
17 |
18 | # Misc Parameters
19 | tf.flags.DEFINE_boolean("allow_soft_placement", True, "Allow device soft device placement")
20 | tf.flags.DEFINE_boolean("log_device_placement", False, "Log placement of ops on devices")
21 |
22 |
23 | FLAGS = tf.flags.FLAGS
24 | FLAGS._parse_flags()
25 | print("\nParameters:")
26 | for attr, value in sorted(FLAGS.__flags.items()):
27 | print("{} = {}".format(attr.upper(), value))
28 | print("")
29 |
30 |
31 | def eval():
32 | with tf.device('/cpu:0'):
33 | x_text, y = data_helpers.load_data_and_labels(FLAGS.pos_dir, FLAGS.neg_dir)
34 |
35 | # Map data into vocabulary
36 | text_path = os.path.join(FLAGS.checkpoint_dir, "..", "text_vocab")
37 | text_vocab_processor = tf.contrib.learn.preprocessing.VocabularyProcessor.restore(text_path)
38 |
39 | x_eval = np.array(list(text_vocab_processor.transform(x_text)))
40 | y_eval = np.argmax(y, axis=1)
41 |
42 | checkpoint_file = tf.train.latest_checkpoint(FLAGS.checkpoint_dir)
43 |
44 | graph = tf.Graph()
45 | with graph.as_default():
46 | session_conf = tf.ConfigProto(
47 | allow_soft_placement=FLAGS.allow_soft_placement,
48 | log_device_placement=FLAGS.log_device_placement)
49 | sess = tf.Session(config=session_conf)
50 | with sess.as_default():
51 | # Load the saved meta graph and restore variables
52 | saver = tf.train.import_meta_graph("{}.meta".format(checkpoint_file))
53 | saver.restore(sess, checkpoint_file)
54 |
55 | # Get the placeholders from the graph by name
56 | input_text = graph.get_operation_by_name("input_text").outputs[0]
57 | # input_y = graph.get_operation_by_name("input_y").outputs[0]
58 | dropout_keep_prob = graph.get_operation_by_name("dropout_keep_prob").outputs[0]
59 |
60 | # Tensors we want to evaluate
61 | predictions = graph.get_operation_by_name("output/predictions").outputs[0]
62 |
63 | # Generate batches for one epoch
64 | batches = data_helpers.batch_iter(list(x_eval), FLAGS.batch_size, 1, shuffle=False)
65 |
66 | # Collect the predictions here
67 | all_predictions = []
68 | for x_batch in batches:
69 | batch_predictions = sess.run(predictions, {input_text: x_batch,
70 | dropout_keep_prob: 1.0})
71 | all_predictions = np.concatenate([all_predictions, batch_predictions])
72 |
73 | correct_predictions = float(sum(all_predictions == y_eval))
74 | print("Total number of test examples: {}".format(len(y_eval)))
75 | print("Accuracy: {:g}".format(correct_predictions / float(len(y_eval))))
76 |
77 |
78 | def main(_):
79 | eval()
80 |
81 |
82 | if __name__ == "__main__":
83 | tf.app.run()
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Recurrent Neural Network for Text Classification
2 | TensorFlow implementation of an RNN (Recurrent Neural Network) for sentiment analysis, one of the text classification problems. Three types of RNN cells are supported, selected with the `--cell_type` flag (see the sketch below): 1) vanilla RNN, 2) Long Short-Term Memory (LSTM) and 3) Gated Recurrent Unit (GRU).
3 |
4 | 
5 |
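The `--cell_type` flag in `train.py` picks which of the three cells backs the model. The snippet below is only a minimal sketch of how `rnn.py` resolves that flag (see `RNN._get_cell`); the helper name `make_cell` is mine and is not part of the repository.

```python
import tensorflow as tf

def make_cell(cell_type, hidden_size):
    """Map a --cell_type string to a TF1 RNN cell (mirrors RNN._get_cell in rnn.py)."""
    cells = {
        "vanilla": tf.nn.rnn_cell.BasicRNNCell,  # plain tanh RNN
        "lstm": tf.nn.rnn_cell.BasicLSTMCell,    # Long Short-Term Memory
        "gru": tf.nn.rnn_cell.GRUCell,           # Gated Recurrent Unit
    }
    if cell_type not in cells:
        raise ValueError("cell_type must be 'vanilla', 'lstm' or 'gru'")
    return cells[cell_type](hidden_size)
```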
6 |
7 | ## Data: Movie Review
8 | * Movie reviews with one sentence per review. Classification involves detecting positive/negative reviews ([Pang and Lee, 2005](#reference))
9 | * Download "*sentence polarity dataset v1.0*" at the [Official Download Page](http://www.cs.cornell.edu/people/pabo/movie-review-data/)
10 | * Located in *"data/rt-polaritydata/"* in my repository (loading these files is sketched below)
11 | * *rt-polarity.pos* contains 5331 positive snippets
12 | * *rt-polarity.neg* contains 5331 negative snippets
13 |
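A minimal sketch of how these two files are read by the repository's own loader, `data_helpers.load_data_and_labels`, which returns cleaned sentences plus one-hot labels (paths assume the default layout above):

```python
from data_helpers import load_data_and_labels

pos_path = "data/rt-polaritydata/rt-polarity.pos"
neg_path = "data/rt-polaritydata/rt-polarity.neg"

# x_text: cleaned, lower-cased sentences; y: one-hot labels ([0, 1] = positive, [1, 0] = negative)
x_text, y = load_data_and_labels(pos_path, neg_path)
print(len(x_text), y.shape)  # expected: 10662 and (10662, 2) for 5331 + 5331 snippets
```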
14 |
15 | ## Usage
16 | ### Train
17 | * positive data is located in *"data/rt-polaritydata/rt-polarity.pos"*
18 | * negative data is located in *"data/rt-polaritydata/rt-polarity.neg"*
19 | * "[GoogleNews-vectors-negative300](https://code.google.com/archive/p/word2vec/)" is used as the pre-trained word2vec model (a quick header check is sketched after the help command below)
20 | * Display help message:
21 |
22 | ```bash
23 | $ python train.py --help
24 | ```
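
Before a long training run it can help to confirm that the word2vec binary's dimensionality matches `--embedding_dim`. The snippet below is only a sanity check of the file header (the same `<vocab_size> <vector_dim>` header that `train.py` parses); it is not part of the training code.

```python
# Read just the header of the word2vec binary: b"<vocab_size> <vector_dim>\n"
with open("GoogleNews-vectors-negative300.bin", "rb") as f:
    vocab_size, layer1_size = map(int, f.readline().split())
print("word2vec vocabulary: {}, vector dim: {}".format(vocab_size, layer1_size))
# the vector dim should equal --embedding_dim (300 by default)
```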
25 |
26 | * **Train Example:**
27 |
28 | #### 1. Vanilla RNN
29 | 
30 |
31 | ```bash
32 | $ python train.py --cell_type "vanilla" \
33 | --pos_dir "data/rt-polaritydata/rt-polarity.pos" \
34 | --neg_dir "data/rt-polaritydata/rt-polarity.neg" \
35 | --word2vec "GoogleNews-vectors-negative300.bin"
36 | ```
37 |
38 | #### 2. Long Short-Term Memory (LSTM) RNN
39 | 
40 |
41 | ```bash
42 | $ python train.py --cell_type "lstm" \
43 | --pos_dir "data/rt-polaritydata/rt-polarity.pos" \
44 | --neg_dir "data/rt-polaritydata/rt-polarity.neg" \
45 | --word2vec "GoogleNews-vectors-negative300.bin"
46 | ```
47 |
48 | #### 3. Gated Recurrent Unit (GRU) RNN
49 | 
50 |
51 | ```bash
52 | $ python train.py --cell_type "gru" \
53 | --pos_dir "data/rt-polaritydata/rt-polarity.pos" \
54 | --neg_dir "data/rt-polaritydata/rt-polarity.neg" \
55 | --word2vec "GoogleNews-vectors-negative300.bin"
56 | ```
57 |
58 |
59 | ### Evaluation
60 | * The Movie Review dataset has **no separate test set**.
61 | * To evaluate properly, you should hold out a test set from the training data or use cross-validation; cross-validation is not implemented in my project (a simple file-split sketch follows the example below).
62 | * The example below simply evaluates on the full rt-polarity dataset, i.e. the same data used for training.
63 | * **Evaluation Example:**
64 |
65 | ```bash
66 | $ python eval.py \
67 | --pos_dir "data/rt-polaritydata/rt-polarity.pos" \
68 | --neg_dir "data/rt-polaritydata/rt-polarity.neg" \
69 | --checkpoint_dir "runs/1523902663/checkpoints"
70 | ```
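
If you do want a held-out test set, one simple option (my own sketch, not part of this repository) is to split each polarity file once, train with the resulting `.train` files, and point `eval.py` at the `.test` files via `--pos_dir`/`--neg_dir`:

```python
import random

def split_polarity_file(path, test_ratio=0.1, seed=10):
    """Write <path>.train and <path>.test with a random held-out split (one snippet per line)."""
    with open(path, "rb") as f:  # binary mode: no assumption about the file's encoding
        lines = f.readlines()
    random.Random(seed).shuffle(lines)
    n_test = int(len(lines) * test_ratio)
    with open(path + ".test", "wb") as f:
        f.writelines(lines[:n_test])
    with open(path + ".train", "wb") as f:
        f.writelines(lines[n_test:])

for p in ["data/rt-polaritydata/rt-polarity.pos", "data/rt-polaritydata/rt-polarity.neg"]:
    split_polarity_file(p)
```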
71 |
72 |
73 | ## Reference
74 | * **Seeing stars: Exploiting class relationships for sentiment categorization with
75 | respect to rating scales** (ACL 2005), B. Pang and L. Lee [[paper]](http://www.cs.cornell.edu/home/llee/papers/pang-lee-stars.pdf)
76 | * **Long short-term memory** (Neural Computation 1997), S. Hochreiter and J. Schmidhuber [[paper]](https://www.mitpressjournals.org/doi/abs/10.1162/neco.1997.9.8.1735)
77 | * **Learning Phrase Representations using RNN Encoder–Decoder for Statistical Machine Translation** (EMNLP 2014), K Cho et al. [[paper]](https://arxiv.org/abs/1406.1078)
78 | * Understanding LSTM Networks [[blog]](http://colah.github.io/posts/2015-08-Understanding-LSTMs/)
79 | * RECURRENT NEURAL NETWORKS (RNN) – PART 2: TEXT CLASSIFICATION [[blog]](https://theneuralperspective.com/2016/10/06/recurrent-neural-networks-rnn-part-2-text-classification/)
80 |
81 |
82 |
--------------------------------------------------------------------------------
/rnn.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 |
3 |
4 | class RNN:
5 | def __init__(self, sequence_length, num_classes, vocab_size, embedding_size,
6 | cell_type, hidden_size, l2_reg_lambda=0.0):
7 |
8 | # Placeholders for input, output and dropout
9 | self.input_text = tf.placeholder(tf.int32, shape=[None, sequence_length], name='input_text')
10 | self.input_y = tf.placeholder(tf.float32, shape=[None, num_classes], name='input_y')
11 | self.dropout_keep_prob = tf.placeholder(tf.float32, name='dropout_keep_prob')
12 |
13 | l2_loss = tf.constant(0.0)
14 | text_length = self._length(self.input_text)
15 |
16 | # Embedding layer
17 | with tf.device('/cpu:0'), tf.name_scope("text-embedding"):
18 | self.W_text = tf.Variable(tf.random_uniform([vocab_size, embedding_size], -1.0, 1.0), name="W_text")
19 | self.embedded_chars = tf.nn.embedding_lookup(self.W_text, self.input_text)
20 |
21 | # Recurrent Neural Network
22 | with tf.name_scope("rnn"):
23 | cell = self._get_cell(hidden_size, cell_type)
24 | cell = tf.nn.rnn_cell.DropoutWrapper(cell, output_keep_prob=self.dropout_keep_prob)
25 | all_outputs, _ = tf.nn.dynamic_rnn(cell=cell,
26 | inputs=self.embedded_chars,
27 | sequence_length=text_length,
28 | dtype=tf.float32)
29 | self.h_outputs = self.last_relevant(all_outputs, text_length)
30 |
31 | # Final scores and predictions
32 | with tf.name_scope("output"):
33 | W = tf.get_variable("W", shape=[hidden_size, num_classes], initializer=tf.contrib.layers.xavier_initializer())
34 | b = tf.Variable(tf.constant(0.1, shape=[num_classes]), name="b")
35 | l2_loss += tf.nn.l2_loss(W)
36 | l2_loss += tf.nn.l2_loss(b)
37 | self.logits = tf.nn.xw_plus_b(self.h_outputs, W, b, name="logits")
38 | self.predictions = tf.argmax(self.logits, 1, name="predictions")
39 |
40 | # Calculate mean cross-entropy loss
41 | with tf.name_scope("loss"):
42 | losses = tf.nn.softmax_cross_entropy_with_logits(logits=self.logits, labels=self.input_y)
43 | self.loss = tf.reduce_mean(losses) + l2_reg_lambda * l2_loss
44 |
45 | # Accuracy
46 | with tf.name_scope("accuracy"):
47 | correct_predictions = tf.equal(self.predictions, tf.argmax(self.input_y, axis=1))
48 | self.accuracy = tf.reduce_mean(tf.cast(correct_predictions, tf.float32), name="accuracy")
49 |
50 | @staticmethod
51 | def _get_cell(hidden_size, cell_type):
52 | if cell_type == "vanilla":
53 | return tf.nn.rnn_cell.BasicRNNCell(hidden_size)
54 | elif cell_type == "lstm":
55 | return tf.nn.rnn_cell.BasicLSTMCell(hidden_size)
56 | elif cell_type == "gru":
57 | return tf.nn.rnn_cell.GRUCell(hidden_size)
58 |         else:
59 |             raise ValueError("'{}' is not a valid cell type. "
60 |                              "Choose 'vanilla', 'lstm' or 'gru'.".format(cell_type))
61 |
62 |     # Actual (un-padded) length of each sequence: counts non-zero token ids, since padding id 0 contributes nothing to the sum
63 | @staticmethod
64 | def _length(seq):
65 | relevant = tf.sign(tf.abs(seq))
66 |         length = tf.reduce_sum(relevant, axis=1)
67 | length = tf.cast(length, tf.int32)
68 | return length
69 |
70 | # Extract the output of last cell of each sequence
71 | # Ex) The movie is good -> length = 4
72 | # output = [ [1.314, -3.32, ..., 0.98]
73 | # [0.287, -0.50, ..., 1.55]
74 | # [2.194, -2.12, ..., 0.63]
75 | # [1.938, -1.88, ..., 1.31]
76 | # [ 0.0, 0.0, ..., 0.0]
77 | # ...
78 | # [ 0.0, 0.0, ..., 0.0] ]
79 |     # The output we need is the 4th output of the cell, so extract it.
80 | @staticmethod
81 | def last_relevant(seq, length):
82 | batch_size = tf.shape(seq)[0]
83 | max_length = int(seq.get_shape()[1])
84 | input_size = int(seq.get_shape()[2])
85 | index = tf.range(0, batch_size) * max_length + (length - 1)
86 | flat = tf.reshape(seq, [-1, input_size])
87 | return tf.gather(flat, index)
88 |
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 | import numpy as np
3 | import os
4 | import datetime
5 | import time
6 | from rnn import RNN
7 | import data_helpers
8 |
9 | # Parameters
10 | # ==================================================
11 |
12 | # Data loading params
13 | tf.flags.DEFINE_string("pos_dir", "data/rt-polaritydata/rt-polarity.pos", "Path of positive data")
14 | tf.flags.DEFINE_string("neg_dir", "data/rt-polaritydata/rt-polarity.neg", "Path of negative data")
15 | tf.flags.DEFINE_float("dev_sample_percentage", .1, "Percentage of the training data to use for validation")
16 | tf.flags.DEFINE_integer("max_sentence_length", 100, "Max sentence length in train/test data (Default: 100)")
17 |
18 | # Model Hyperparameters
19 | tf.flags.DEFINE_string("cell_type", "vanilla", "Type of rnn cell. Choose 'vanilla' or 'lstm' or 'gru' (Default: vanilla)")
20 | tf.flags.DEFINE_string("word2vec", None, "Word2vec file with pre-trained embeddings")
21 | tf.flags.DEFINE_integer("embedding_dim", 300, "Dimensionality of word embedding (Default: 300)")
22 | tf.flags.DEFINE_integer("hidden_size", 128, "Size of the RNN hidden state (Default: 128)")
23 | tf.flags.DEFINE_float("dropout_keep_prob", 0.5, "Dropout keep probability (Default: 0.5)")
24 | tf.flags.DEFINE_float("l2_reg_lambda", 3.0, "L2 regularization lambda (Default: 3.0)")
25 |
26 | # Training parameters
27 | tf.flags.DEFINE_integer("batch_size", 64, "Batch Size (Default: 64)")
28 | tf.flags.DEFINE_integer("num_epochs", 100, "Number of training epochs (Default: 100)")
29 | tf.flags.DEFINE_integer("display_every", 10, "Number of iterations to display training info.")
30 | tf.flags.DEFINE_integer("evaluate_every", 100, "Evaluate model on dev set after this many steps")
31 | tf.flags.DEFINE_integer("checkpoint_every", 100, "Save model after this many steps")
32 | tf.flags.DEFINE_integer("num_checkpoints", 5, "Number of checkpoints to store")
33 | tf.flags.DEFINE_float("learning_rate", 1e-3, "Which learning rate to start with. (Default: 1e-3)")
34 |
35 | # Misc Parameters
36 | tf.flags.DEFINE_boolean("allow_soft_placement", True, "Allow device soft device placement")
37 | tf.flags.DEFINE_boolean("log_device_placement", False, "Log placement of ops on devices")
38 |
39 |
40 | FLAGS = tf.flags.FLAGS
41 | FLAGS._parse_flags()
42 | print("\nParameters:")
43 | for attr, value in sorted(FLAGS.__flags.items()):
44 | print("{} = {}".format(attr.upper(), value))
45 | print("")
46 |
47 |
48 | def train():
49 | with tf.device('/cpu:0'):
50 | x_text, y = data_helpers.load_data_and_labels(FLAGS.pos_dir, FLAGS.neg_dir)
51 |
52 | text_vocab_processor = tf.contrib.learn.preprocessing.VocabularyProcessor(FLAGS.max_sentence_length)
53 | x = np.array(list(text_vocab_processor.fit_transform(x_text)))
54 | print("Text Vocabulary Size: {:d}".format(len(text_vocab_processor.vocabulary_)))
55 |
56 | print("x = {0}".format(x.shape))
57 | print("y = {0}".format(y.shape))
58 | print("")
59 |
60 | # Randomly shuffle data
61 | np.random.seed(10)
62 | shuffle_indices = np.random.permutation(np.arange(len(y)))
63 | x_shuffled = x[shuffle_indices]
64 | y_shuffled = y[shuffle_indices]
65 |
66 | # Split train/test set
67 | # TODO: This is very crude, should use cross-validation
68 | dev_sample_index = -1 * int(FLAGS.dev_sample_percentage * float(len(y)))
69 | x_train, x_dev = x_shuffled[:dev_sample_index], x_shuffled[dev_sample_index:]
70 | y_train, y_dev = y_shuffled[:dev_sample_index], y_shuffled[dev_sample_index:]
71 | print("Train/Dev split: {:d}/{:d}\n".format(len(y_train), len(y_dev)))
72 |
73 | with tf.Graph().as_default():
74 | session_conf = tf.ConfigProto(
75 | allow_soft_placement=FLAGS.allow_soft_placement,
76 | log_device_placement=FLAGS.log_device_placement)
77 | sess = tf.Session(config=session_conf)
78 | with sess.as_default():
79 | rnn = RNN(
80 | sequence_length=x_train.shape[1],
81 | num_classes=y_train.shape[1],
82 | vocab_size=len(text_vocab_processor.vocabulary_),
83 | embedding_size=FLAGS.embedding_dim,
84 | cell_type=FLAGS.cell_type,
85 | hidden_size=FLAGS.hidden_size,
86 | l2_reg_lambda=FLAGS.l2_reg_lambda
87 | )
88 |
89 | # Define Training procedure
90 | global_step = tf.Variable(0, name="global_step", trainable=False)
91 | train_op = tf.train.AdamOptimizer(FLAGS.learning_rate).minimize(rnn.loss, global_step=global_step)
92 |
93 | # Output directory for models and summaries
94 | timestamp = str(int(time.time()))
95 | out_dir = os.path.abspath(os.path.join(os.path.curdir, "runs", timestamp))
96 | print("Writing to {}\n".format(out_dir))
97 |
98 | # Summaries for loss and accuracy
99 | loss_summary = tf.summary.scalar("loss", rnn.loss)
100 | acc_summary = tf.summary.scalar("accuracy", rnn.accuracy)
101 |
102 | # Train Summaries
103 | train_summary_op = tf.summary.merge([loss_summary, acc_summary])
104 | train_summary_dir = os.path.join(out_dir, "summaries", "train")
105 | train_summary_writer = tf.summary.FileWriter(train_summary_dir, sess.graph)
106 |
107 | # Dev summaries
108 | dev_summary_op = tf.summary.merge([loss_summary, acc_summary])
109 | dev_summary_dir = os.path.join(out_dir, "summaries", "dev")
110 | dev_summary_writer = tf.summary.FileWriter(dev_summary_dir, sess.graph)
111 |
112 | # Checkpoint directory. Tensorflow assumes this directory already exists so we need to create it
113 | checkpoint_dir = os.path.abspath(os.path.join(out_dir, "checkpoints"))
114 | checkpoint_prefix = os.path.join(checkpoint_dir, "model")
115 | if not os.path.exists(checkpoint_dir):
116 | os.makedirs(checkpoint_dir)
117 | saver = tf.train.Saver(tf.global_variables(), max_to_keep=FLAGS.num_checkpoints)
118 |
119 | # Write vocabulary
120 | text_vocab_processor.save(os.path.join(out_dir, "text_vocab"))
121 |
122 | # Initialize all variables
123 | sess.run(tf.global_variables_initializer())
124 |
125 | # Pre-trained word2vec
126 | if FLAGS.word2vec:
127 | # initial matrix with random uniform
128 | initW = np.random.uniform(-0.25, 0.25, (len(text_vocab_processor.vocabulary_), FLAGS.embedding_dim))
129 | # load any vectors from the word2vec
130 | print("Load word2vec file {0}".format(FLAGS.word2vec))
131 | with open(FLAGS.word2vec, "rb") as f:
132 | header = f.readline()
133 | vocab_size, layer1_size = map(int, header.split())
134 | binary_len = np.dtype('float32').itemsize * layer1_size
135 | for line in range(vocab_size):
136 | word = []
137 | while True:
138 | ch = f.read(1).decode('latin-1')
139 | if ch == ' ':
140 | word = ''.join(word)
141 | break
142 | if ch != '\n':
143 | word.append(ch)
144 | idx = text_vocab_processor.vocabulary_.get(word)
145 | if idx != 0:
146 | initW[idx] = np.fromstring(f.read(binary_len), dtype='float32')
147 | else:
148 | f.read(binary_len)
149 | sess.run(rnn.W_text.assign(initW))
150 |                 print("Successfully loaded the pre-trained word2vec model!\n")
151 |
152 | # Generate batches
153 | batches = data_helpers.batch_iter(
154 | list(zip(x_train, y_train)), FLAGS.batch_size, FLAGS.num_epochs)
155 | # Training loop. For each batch...
156 | for batch in batches:
157 | x_batch, y_batch = zip(*batch)
158 | # Train
159 | feed_dict = {
160 | rnn.input_text: x_batch,
161 | rnn.input_y: y_batch,
162 | rnn.dropout_keep_prob: FLAGS.dropout_keep_prob
163 | }
164 | _, step, summaries, loss, accuracy = sess.run(
165 | [train_op, global_step, train_summary_op, rnn.loss, rnn.accuracy], feed_dict)
166 | train_summary_writer.add_summary(summaries, step)
167 |
168 | # Training log display
169 | if step % FLAGS.display_every == 0:
170 | time_str = datetime.datetime.now().isoformat()
171 | print("{}: step {}, loss {:g}, acc {:g}".format(time_str, step, loss, accuracy))
172 |
173 | # Evaluation
174 | if step % FLAGS.evaluate_every == 0:
175 | print("\nEvaluation:")
176 | feed_dict_dev = {
177 | rnn.input_text: x_dev,
178 | rnn.input_y: y_dev,
179 | rnn.dropout_keep_prob: 1.0
180 | }
181 | summaries_dev, loss, accuracy = sess.run(
182 | [dev_summary_op, rnn.loss, rnn.accuracy], feed_dict_dev)
183 | dev_summary_writer.add_summary(summaries_dev, step)
184 |
185 | time_str = datetime.datetime.now().isoformat()
186 | print("{}: step {}, loss {:g}, acc {:g}\n".format(time_str, step, loss, accuracy))
187 |
188 | # Model checkpoint
189 | if step % FLAGS.checkpoint_every == 0:
190 | path = saver.save(sess, checkpoint_prefix, global_step=step)
191 | print("Saved model checkpoint to {}\n".format(path))
192 |
193 |
194 | def main(_):
195 | train()
196 |
197 |
198 | if __name__ == "__main__":
199 | tf.app.run()
200 |
--------------------------------------------------------------------------------