├── .gitignore
├── LICENSE
├── data_helpers.py
├── eval.py
├── README.md
├── rnn.py
└── train.py

/.gitignore:
--------------------------------------------------------------------------------
w2v_model/
saved/
runs/
**/.idea/
**/__pycache__/
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2018 Joohong Lee

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
--------------------------------------------------------------------------------
/data_helpers.py:
--------------------------------------------------------------------------------
import numpy as np
import re


def clean_str(string):
    """
    Tokenization/string cleaning for all datasets except for SST.
    Original taken from https://github.com/yoonkim/CNN_sentence/blob/master/process_data.py
    """
    string = re.sub(r"[^A-Za-z0-9(),!?\'\`]", " ", string)
    string = re.sub(r"\'s", " \'s", string)
    string = re.sub(r"\'ve", " \'ve", string)
    string = re.sub(r"n\'t", " n\'t", string)
    string = re.sub(r"\'re", " \'re", string)
    string = re.sub(r"\'d", " \'d", string)
    string = re.sub(r"\'ll", " \'ll", string)
    string = re.sub(r",", " , ", string)
    string = re.sub(r"!", " ! ", string)
    string = re.sub(r"\(", " \( ", string)
    string = re.sub(r"\)", " \) ", string)
    string = re.sub(r"\?", " \? ", string)
    string = re.sub(r"\s{2,}", " ", string)
    return string.strip().lower()


def load_data_and_labels(positive_data_file, negative_data_file):
    """
    Loads MR polarity data from files, splits the data into words and generates labels.
    Returns split sentences and labels.
30 | """ 31 | # Load data from files 32 | positive_examples = list(open(positive_data_file, "r", encoding="UTF8").readlines()) 33 | positive_examples = [s.strip() for s in positive_examples] 34 | negative_examples = list(open(negative_data_file, "r", encoding="UTF8").readlines()) 35 | negative_examples = [s.strip() for s in negative_examples] 36 | # Split by words 37 | x_text = positive_examples + negative_examples 38 | x_text = [clean_str(sent) for sent in x_text] 39 | # Generate labels 40 | positive_labels = [[0, 1] for _ in positive_examples] 41 | negative_labels = [[1, 0] for _ in negative_examples] 42 | y = np.concatenate([positive_labels, negative_labels], 0) 43 | return [x_text, y] 44 | 45 | 46 | def batch_iter(data, batch_size, num_epochs, shuffle=True): 47 | """ 48 | Generates a batch iterator for a dataset. 49 | """ 50 | data = np.array(data) 51 | data_size = len(data) 52 | num_batches_per_epoch = int((len(data) - 1) / batch_size) + 1 53 | for epoch in range(num_epochs): 54 | # Shuffle the data at each epoch 55 | if shuffle: 56 | shuffle_indices = np.random.permutation(np.arange(data_size)) 57 | shuffled_data = data[shuffle_indices] 58 | else: 59 | shuffled_data = data 60 | for batch_num in range(num_batches_per_epoch): 61 | start_index = batch_num * batch_size 62 | end_index = min((batch_num + 1) * batch_size, data_size) 63 | yield shuffled_data[start_index:end_index] 64 | 65 | 66 | if __name__ == "__main__": 67 | pos_dir = "data/rt-polaritydata/rt-polarity.pos" 68 | neg_dir = "data/rt-polaritydata/rt-polarity.neg" 69 | 70 | load_data_and_labels(pos_dir, neg_dir) 71 | -------------------------------------------------------------------------------- /eval.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import numpy as np 3 | import os 4 | import data_helpers 5 | 6 | 7 | # Parameters 8 | # ================================================== 9 | 10 | # Data loading params 11 | tf.flags.DEFINE_string("pos_dir", "data/rt-polaritydata/rt-polarity.pos", "Path of positive data") 12 | tf.flags.DEFINE_string("neg_dir", "data/rt-polaritydata/rt-polarity.neg", "Path of negative data") 13 | 14 | # Eval Parameters 15 | tf.flags.DEFINE_integer("batch_size", 64, "Batch Size (Default: 64)") 16 | tf.flags.DEFINE_string("checkpoint_dir", "", "Checkpoint directory from training run") 17 | 18 | # Misc Parameters 19 | tf.flags.DEFINE_boolean("allow_soft_placement", True, "Allow device soft device placement") 20 | tf.flags.DEFINE_boolean("log_device_placement", False, "Log placement of ops on devices") 21 | 22 | 23 | FLAGS = tf.flags.FLAGS 24 | FLAGS._parse_flags() 25 | print("\nParameters:") 26 | for attr, value in sorted(FLAGS.__flags.items()): 27 | print("{} = {}".format(attr.upper(), value)) 28 | print("") 29 | 30 | 31 | def eval(): 32 | with tf.device('/cpu:0'): 33 | x_text, y = data_helpers.load_data_and_labels(FLAGS.pos_dir, FLAGS.neg_dir) 34 | 35 | # Map data into vocabulary 36 | text_path = os.path.join(FLAGS.checkpoint_dir, "..", "text_vocab") 37 | text_vocab_processor = tf.contrib.learn.preprocessing.VocabularyProcessor.restore(text_path) 38 | 39 | x_eval = np.array(list(text_vocab_processor.transform(x_text))) 40 | y_eval = np.argmax(y, axis=1) 41 | 42 | checkpoint_file = tf.train.latest_checkpoint(FLAGS.checkpoint_dir) 43 | 44 | graph = tf.Graph() 45 | with graph.as_default(): 46 | session_conf = tf.ConfigProto( 47 | allow_soft_placement=FLAGS.allow_soft_placement, 48 | log_device_placement=FLAGS.log_device_placement) 49 | 
/eval.py:
--------------------------------------------------------------------------------
import tensorflow as tf
import numpy as np
import os
import data_helpers


# Parameters
# ==================================================

# Data loading params
tf.flags.DEFINE_string("pos_dir", "data/rt-polaritydata/rt-polarity.pos", "Path of positive data")
tf.flags.DEFINE_string("neg_dir", "data/rt-polaritydata/rt-polarity.neg", "Path of negative data")

# Eval Parameters
tf.flags.DEFINE_integer("batch_size", 64, "Batch Size (Default: 64)")
tf.flags.DEFINE_string("checkpoint_dir", "", "Checkpoint directory from training run")

# Misc Parameters
tf.flags.DEFINE_boolean("allow_soft_placement", True, "Allow soft device placement")
tf.flags.DEFINE_boolean("log_device_placement", False, "Log placement of ops on devices")


FLAGS = tf.flags.FLAGS
FLAGS._parse_flags()
print("\nParameters:")
for attr, value in sorted(FLAGS.__flags.items()):
    print("{} = {}".format(attr.upper(), value))
print("")


def eval():
    with tf.device('/cpu:0'):
        x_text, y = data_helpers.load_data_and_labels(FLAGS.pos_dir, FLAGS.neg_dir)

    # Map data into vocabulary
    text_path = os.path.join(FLAGS.checkpoint_dir, "..", "text_vocab")
    text_vocab_processor = tf.contrib.learn.preprocessing.VocabularyProcessor.restore(text_path)

    x_eval = np.array(list(text_vocab_processor.transform(x_text)))
    y_eval = np.argmax(y, axis=1)

    checkpoint_file = tf.train.latest_checkpoint(FLAGS.checkpoint_dir)

    graph = tf.Graph()
    with graph.as_default():
        session_conf = tf.ConfigProto(
            allow_soft_placement=FLAGS.allow_soft_placement,
            log_device_placement=FLAGS.log_device_placement)
        sess = tf.Session(config=session_conf)
        with sess.as_default():
            # Load the saved meta graph and restore variables
            saver = tf.train.import_meta_graph("{}.meta".format(checkpoint_file))
            saver.restore(sess, checkpoint_file)

            # Get the placeholders from the graph by name
            input_text = graph.get_operation_by_name("input_text").outputs[0]
            # input_y = graph.get_operation_by_name("input_y").outputs[0]
            dropout_keep_prob = graph.get_operation_by_name("dropout_keep_prob").outputs[0]

            # Tensors we want to evaluate
            predictions = graph.get_operation_by_name("output/predictions").outputs[0]

            # Generate batches for one epoch
            batches = data_helpers.batch_iter(list(x_eval), FLAGS.batch_size, 1, shuffle=False)

            # Collect the predictions here
            all_predictions = []
            for x_batch in batches:
                batch_predictions = sess.run(predictions, {input_text: x_batch,
                                                           dropout_keep_prob: 1.0})
                all_predictions = np.concatenate([all_predictions, batch_predictions])

            correct_predictions = float(sum(all_predictions == y_eval))
            print("Total number of test examples: {}".format(len(y_eval)))
            print("Accuracy: {:g}".format(correct_predictions / float(len(y_eval))))


def main(_):
    eval()


if __name__ == "__main__":
    tf.app.run()
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# Recurrent Neural Network for Text Classification
TensorFlow implementation of a Recurrent Neural Network (RNN) for sentiment analysis, one of the classic text classification problems. Three types of RNN cell are supported: 1) vanilla RNN, 2) Long Short-Term Memory (LSTM) and 3) Gated Recurrent Unit (GRU). The code targets TensorFlow 1.x; it relies on `tf.flags` and `tf.contrib`, which were removed in TensorFlow 2.

![rnn](https://user-images.githubusercontent.com/15166794/39031786-370d0cae-44a5-11e8-8440-27102312274c.png)


## Data: Movie Review
* Movie reviews with one sentence per review. Classification involves detecting positive/negative reviews ([Pang and Lee, 2005](#reference)).
* Download the "*sentence polarity dataset v1.0*" from the [official download page](http://www.cs.cornell.edu/people/pabo/movie-review-data/).
* Located in *"data/rt-polaritydata/"* in my repository:
	* *rt-polarity.pos* contains 5331 positive snippets
	* *rt-polarity.neg* contains 5331 negative snippets


## Usage
### Train
* Positive data is located in *"data/rt-polaritydata/rt-polarity.pos"*.
* Negative data is located in *"data/rt-polaritydata/rt-polarity.neg"*.
* "[GoogleNews-vectors-negative300](https://code.google.com/archive/p/word2vec/)" is used as the pre-trained word2vec model.
* Display the help message:

```bash
$ python train.py --help
```

* **Train Examples:**

#### 1. Vanilla RNN
![vanilla](https://user-images.githubusercontent.com/15166794/39033685-30859e24-44ae-11e8-9d7d-860c75efe080.png)

```bash
$ python train.py --cell_type "vanilla" \
--pos_dir "data/rt-polaritydata/rt-polarity.pos" \
--neg_dir "data/rt-polaritydata/rt-polarity.neg" \
--word2vec "GoogleNews-vectors-negative300.bin"
```

#### 2. Long Short-Term Memory (LSTM) RNN
![lstm](https://user-images.githubusercontent.com/15166794/39033684-3053546e-44ae-11e8-893a-7fa685039ce2.png)

```bash
$ python train.py --cell_type "lstm" \
--pos_dir "data/rt-polaritydata/rt-polarity.pos" \
--neg_dir "data/rt-polaritydata/rt-polarity.neg" \
--word2vec "GoogleNews-vectors-negative300.bin"
```

#### 3. Gated Recurrent Unit (GRU) RNN
![gru](https://user-images.githubusercontent.com/15166794/39033683-3020ce04-44ae-11e8-821f-1a9652ff5025.png)

```bash
$ python train.py --cell_type "gru" \
--pos_dir "data/rt-polaritydata/rt-polarity.pos" \
--neg_dir "data/rt-polaritydata/rt-polarity.neg" \
--word2vec "GoogleNews-vectors-negative300.bin"
```


### Evaluation
* The Movie Review dataset has **no test data**.
* If you want to evaluate, you should hold out a test set from the training data or do cross-validation. However, cross-validation is not implemented in my project. A minimal sketch for creating a held-out split is given after the example below.
* The example below simply evaluates on the full rt-polarity dataset, i.e. the same data the model was trained on.
* **Evaluation Example:**

```bash
$ python eval.py \
--pos_dir "data/rt-polaritydata/rt-polarity.pos" \
--neg_dir "data/rt-polaritydata/rt-polarity.neg" \
--checkpoint_dir "runs/1523902663/checkpoints"
```

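Since the repository itself provides no splitting utility, the following is a minimal sketch of one way to carve out a held-out set. The `.train`/`.test` file suffixes and the 10% ratio are assumptions, not part of this project; pass the resulting paths to `train.py`/`eval.py` via `--pos_dir` and `--neg_dir`:

```python
# Hedged sketch: hold out 10% of each polarity file for evaluation.
import random

def split_file(path, test_ratio=0.1, seed=10):
    with open(path, "r", encoding="utf-8") as f:
        lines = [line.strip() for line in f if line.strip()]
    random.Random(seed).shuffle(lines)
    n_test = int(len(lines) * test_ratio)
    return lines[n_test:], lines[:n_test]  # (train, test)

for polarity in ("pos", "neg"):
    src = "data/rt-polaritydata/rt-polarity." + polarity
    train_lines, test_lines = split_file(src)
    for suffix, chunk in ((".train", train_lines), (".test", test_lines)):
        with open(src + suffix, "w", encoding="utf-8") as f:
            f.write("\n".join(chunk))
```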
## Reference
* **Seeing stars: Exploiting class relationships for sentiment categorization with respect to rating scales** (ACL 2005), B. Pang and L. Lee [[paper]](http://www.cs.cornell.edu/home/llee/papers/pang-lee-stars.pdf)
* **Long short-term memory** (Neural Computation 1997), S. Hochreiter and J. Schmidhuber [[paper]](https://www.mitpressjournals.org/doi/abs/10.1162/neco.1997.9.8.1735)
* **Learning Phrase Representations using RNN Encoder–Decoder for Statistical Machine Translation** (EMNLP 2014), K. Cho et al. [[paper]](https://arxiv.org/abs/1406.1078)
* Understanding LSTM Networks [[blog]](http://colah.github.io/posts/2015-08-Understanding-LSTMs/)
* Recurrent Neural Networks (RNN) – Part 2: Text Classification [[blog]](https://theneuralperspective.com/2016/10/06/recurrent-neural-networks-rnn-part-2-text-classification/)
--------------------------------------------------------------------------------
/rnn.py:
--------------------------------------------------------------------------------
import tensorflow as tf


class RNN:
    def __init__(self, sequence_length, num_classes, vocab_size, embedding_size,
                 cell_type, hidden_size, l2_reg_lambda=0.0):

        # Placeholders for input, output and dropout
        self.input_text = tf.placeholder(tf.int32, shape=[None, sequence_length], name='input_text')
        self.input_y = tf.placeholder(tf.float32, shape=[None, num_classes], name='input_y')
        self.dropout_keep_prob = tf.placeholder(tf.float32, name='dropout_keep_prob')

        l2_loss = tf.constant(0.0)
        text_length = self._length(self.input_text)

        # Embedding layer
        with tf.device('/cpu:0'), tf.name_scope("text-embedding"):
            self.W_text = tf.Variable(tf.random_uniform([vocab_size, embedding_size], -1.0, 1.0), name="W_text")
            self.embedded_chars = tf.nn.embedding_lookup(self.W_text, self.input_text)

        # Recurrent Neural Network
        with tf.name_scope("rnn"):
            cell = self._get_cell(hidden_size, cell_type)
            cell = tf.nn.rnn_cell.DropoutWrapper(cell, output_keep_prob=self.dropout_keep_prob)
            all_outputs, _ = tf.nn.dynamic_rnn(cell=cell,
                                               inputs=self.embedded_chars,
                                               sequence_length=text_length,
                                               dtype=tf.float32)
            self.h_outputs = self.last_relevant(all_outputs, text_length)

        # Final scores and predictions
        with tf.name_scope("output"):
            W = tf.get_variable("W", shape=[hidden_size, num_classes], initializer=tf.contrib.layers.xavier_initializer())
            b = tf.Variable(tf.constant(0.1, shape=[num_classes]), name="b")
            l2_loss += tf.nn.l2_loss(W)
            l2_loss += tf.nn.l2_loss(b)
            self.logits = tf.nn.xw_plus_b(self.h_outputs, W, b, name="logits")
            self.predictions = tf.argmax(self.logits, 1, name="predictions")

        # Calculate mean cross-entropy loss
        with tf.name_scope("loss"):
            losses = tf.nn.softmax_cross_entropy_with_logits(logits=self.logits, labels=self.input_y)
            self.loss = tf.reduce_mean(losses) + l2_reg_lambda * l2_loss

        # Accuracy
        with tf.name_scope("accuracy"):
            correct_predictions = tf.equal(self.predictions, tf.argmax(self.input_y, axis=1))
            self.accuracy = tf.reduce_mean(tf.cast(correct_predictions, tf.float32), name="accuracy")

    @staticmethod
    def _get_cell(hidden_size, cell_type):
        if cell_type == "vanilla":
            return tf.nn.rnn_cell.BasicRNNCell(hidden_size)
        elif cell_type == "lstm":
            return tf.nn.rnn_cell.BasicLSTMCell(hidden_size)
        elif cell_type == "gru":
            return tf.nn.rnn_cell.GRUCell(hidden_size)
        else:
            # Fail fast instead of returning None, which would crash later in DropoutWrapper
            raise ValueError("'{}' is not a valid cell type; choose 'vanilla', 'lstm' or 'gru'".format(cell_type))

    # Length of the sequence data (token id 0 is the padding id)
    @staticmethod
    def _length(seq):
        relevant = tf.sign(tf.abs(seq))
        length = tf.reduce_sum(relevant, reduction_indices=1)
        length = tf.cast(length, tf.int32)
        return length

    # Extract the output of the last cell of each sequence
    # Ex) "The movie is good" -> length = 4
    #     output = [ [1.314, -3.32, ..., 0.98]
    #                [0.287, -0.50, ..., 1.55]
    #                [2.194, -2.12, ..., 0.63]
    #                [1.938, -1.88, ..., 1.31]
    #                [  0.0,   0.0, ...,  0.0]
    #                ...
    #                [  0.0,   0.0, ...,  0.0] ]
    # The output we need is the 4th output of the cell, so we extract it.
    @staticmethod
    def last_relevant(seq, length):
        batch_size = tf.shape(seq)[0]
        max_length = int(seq.get_shape()[1])
        input_size = int(seq.get_shape()[2])
        index = tf.range(0, batch_size) * max_length + (length - 1)
        flat = tf.reshape(seq, [-1, input_size])
        return tf.gather(flat, index)
--------------------------------------------------------------------------------
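The index arithmetic in `last_relevant` can be sanity-checked with plain NumPy; this is a standalone sketch, not part of the repository:

```python
import numpy as np

batch_size, max_length, input_size = 2, 4, 3
# Fake RNN outputs: example 0 has 4 valid steps, example 1 only 2.
seq = np.arange(batch_size * max_length * input_size, dtype=np.float32)
seq = seq.reshape(batch_size, max_length, input_size)
length = np.array([4, 2])

# Same arithmetic as RNN.last_relevant: after flattening to a
# (batch * time, features) matrix, timestep t of example b lives in
# row b * max_length + t, so the last valid output of example b sits
# at row b * max_length + (length[b] - 1).
index = np.arange(batch_size) * max_length + (length - 1)
flat = seq.reshape(-1, input_size)
last = flat[index]

assert np.allclose(last[0], seq[0, 3])  # last step of example 0
assert np.allclose(last[1], seq[1, 1])  # last valid step of example 1
```

`_length` itself just counts the non-zero token ids per row, which works because `VocabularyProcessor` reserves id 0 for padding.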
/train.py:
--------------------------------------------------------------------------------
import tensorflow as tf
import numpy as np
import os
import datetime
import time
from rnn import RNN
import data_helpers

# Parameters
# ==================================================

# Data loading params
tf.flags.DEFINE_string("pos_dir", "data/rt-polaritydata/rt-polarity.pos", "Path of positive data")
tf.flags.DEFINE_string("neg_dir", "data/rt-polaritydata/rt-polarity.neg", "Path of negative data")
tf.flags.DEFINE_float("dev_sample_percentage", .1, "Percentage of the training data to use for validation")
tf.flags.DEFINE_integer("max_sentence_length", 100, "Max sentence length in train/test data (Default: 100)")

# Model Hyperparameters
tf.flags.DEFINE_string("cell_type", "vanilla", "Type of rnn cell. Choose 'vanilla' or 'lstm' or 'gru' (Default: vanilla)")
tf.flags.DEFINE_string("word2vec", None, "Word2vec file with pre-trained embeddings")
tf.flags.DEFINE_integer("embedding_dim", 300, "Dimensionality of word embedding (Default: 300)")
tf.flags.DEFINE_integer("hidden_size", 128, "Dimensionality of RNN hidden state (Default: 128)")
tf.flags.DEFINE_float("dropout_keep_prob", 0.5, "Dropout keep probability (Default: 0.5)")
tf.flags.DEFINE_float("l2_reg_lambda", 3.0, "L2 regularization lambda (Default: 3.0)")

# Training parameters
tf.flags.DEFINE_integer("batch_size", 64, "Batch Size (Default: 64)")
tf.flags.DEFINE_integer("num_epochs", 100, "Number of training epochs (Default: 100)")
tf.flags.DEFINE_integer("display_every", 10, "Number of iterations to display training info.")
tf.flags.DEFINE_integer("evaluate_every", 100, "Evaluate model on dev set after this many steps")
tf.flags.DEFINE_integer("checkpoint_every", 100, "Save model after this many steps")
tf.flags.DEFINE_integer("num_checkpoints", 5, "Number of checkpoints to store")
tf.flags.DEFINE_float("learning_rate", 1e-3, "Which learning rate to start with. (Default: 1e-3)")
(Default: 1e-3)") 34 | 35 | # Misc Parameters 36 | tf.flags.DEFINE_boolean("allow_soft_placement", True, "Allow device soft device placement") 37 | tf.flags.DEFINE_boolean("log_device_placement", False, "Log placement of ops on devices") 38 | 39 | 40 | FLAGS = tf.flags.FLAGS 41 | FLAGS._parse_flags() 42 | print("\nParameters:") 43 | for attr, value in sorted(FLAGS.__flags.items()): 44 | print("{} = {}".format(attr.upper(), value)) 45 | print("") 46 | 47 | 48 | def train(): 49 | with tf.device('/cpu:0'): 50 | x_text, y = data_helpers.load_data_and_labels(FLAGS.pos_dir, FLAGS.neg_dir) 51 | 52 | text_vocab_processor = tf.contrib.learn.preprocessing.VocabularyProcessor(FLAGS.max_sentence_length) 53 | x = np.array(list(text_vocab_processor.fit_transform(x_text))) 54 | print("Text Vocabulary Size: {:d}".format(len(text_vocab_processor.vocabulary_))) 55 | 56 | print("x = {0}".format(x.shape)) 57 | print("y = {0}".format(y.shape)) 58 | print("") 59 | 60 | # Randomly shuffle data 61 | np.random.seed(10) 62 | shuffle_indices = np.random.permutation(np.arange(len(y))) 63 | x_shuffled = x[shuffle_indices] 64 | y_shuffled = y[shuffle_indices] 65 | 66 | # Split train/test set 67 | # TODO: This is very crude, should use cross-validation 68 | dev_sample_index = -1 * int(FLAGS.dev_sample_percentage * float(len(y))) 69 | x_train, x_dev = x_shuffled[:dev_sample_index], x_shuffled[dev_sample_index:] 70 | y_train, y_dev = y_shuffled[:dev_sample_index], y_shuffled[dev_sample_index:] 71 | print("Train/Dev split: {:d}/{:d}\n".format(len(y_train), len(y_dev))) 72 | 73 | with tf.Graph().as_default(): 74 | session_conf = tf.ConfigProto( 75 | allow_soft_placement=FLAGS.allow_soft_placement, 76 | log_device_placement=FLAGS.log_device_placement) 77 | sess = tf.Session(config=session_conf) 78 | with sess.as_default(): 79 | rnn = RNN( 80 | sequence_length=x_train.shape[1], 81 | num_classes=y_train.shape[1], 82 | vocab_size=len(text_vocab_processor.vocabulary_), 83 | embedding_size=FLAGS.embedding_dim, 84 | cell_type=FLAGS.cell_type, 85 | hidden_size=FLAGS.hidden_size, 86 | l2_reg_lambda=FLAGS.l2_reg_lambda 87 | ) 88 | 89 | # Define Training procedure 90 | global_step = tf.Variable(0, name="global_step", trainable=False) 91 | train_op = tf.train.AdamOptimizer(FLAGS.learning_rate).minimize(rnn.loss, global_step=global_step) 92 | 93 | # Output directory for models and summaries 94 | timestamp = str(int(time.time())) 95 | out_dir = os.path.abspath(os.path.join(os.path.curdir, "runs", timestamp)) 96 | print("Writing to {}\n".format(out_dir)) 97 | 98 | # Summaries for loss and accuracy 99 | loss_summary = tf.summary.scalar("loss", rnn.loss) 100 | acc_summary = tf.summary.scalar("accuracy", rnn.accuracy) 101 | 102 | # Train Summaries 103 | train_summary_op = tf.summary.merge([loss_summary, acc_summary]) 104 | train_summary_dir = os.path.join(out_dir, "summaries", "train") 105 | train_summary_writer = tf.summary.FileWriter(train_summary_dir, sess.graph) 106 | 107 | # Dev summaries 108 | dev_summary_op = tf.summary.merge([loss_summary, acc_summary]) 109 | dev_summary_dir = os.path.join(out_dir, "summaries", "dev") 110 | dev_summary_writer = tf.summary.FileWriter(dev_summary_dir, sess.graph) 111 | 112 | # Checkpoint directory. 
            # Checkpoint directory. TensorFlow assumes this directory already exists, so we need to create it
            checkpoint_dir = os.path.abspath(os.path.join(out_dir, "checkpoints"))
            checkpoint_prefix = os.path.join(checkpoint_dir, "model")
            if not os.path.exists(checkpoint_dir):
                os.makedirs(checkpoint_dir)
            saver = tf.train.Saver(tf.global_variables(), max_to_keep=FLAGS.num_checkpoints)

            # Write vocabulary
            text_vocab_processor.save(os.path.join(out_dir, "text_vocab"))

            # Initialize all variables
            sess.run(tf.global_variables_initializer())

            # Pre-trained word2vec
            if FLAGS.word2vec:
                # Initialize the embedding matrix with random uniform values
                initW = np.random.uniform(-0.25, 0.25, (len(text_vocab_processor.vocabulary_), FLAGS.embedding_dim))
                # Load any vectors from the word2vec binary file
                print("Loading word2vec file {0}".format(FLAGS.word2vec))
                with open(FLAGS.word2vec, "rb") as f:
                    header = f.readline()
                    vocab_size, layer1_size = map(int, header.split())
                    binary_len = np.dtype('float32').itemsize * layer1_size
                    for line in range(vocab_size):
                        word = []
                        while True:
                            ch = f.read(1).decode('latin-1')
                            if ch == ' ':
                                word = ''.join(word)
                                break
                            if ch != '\n':
                                word.append(ch)
                        idx = text_vocab_processor.vocabulary_.get(word)
                        if idx != 0:
                            # np.frombuffer replaces the deprecated np.fromstring
                            initW[idx] = np.frombuffer(f.read(binary_len), dtype='float32')
                        else:
                            f.read(binary_len)
                sess.run(rnn.W_text.assign(initW))
                print("Successfully loaded the pre-trained word2vec model!\n")

            # Generate batches
            batches = data_helpers.batch_iter(
                list(zip(x_train, y_train)), FLAGS.batch_size, FLAGS.num_epochs)
            # Training loop. For each batch...
            for batch in batches:
                x_batch, y_batch = zip(*batch)
                # Train
                feed_dict = {
                    rnn.input_text: x_batch,
                    rnn.input_y: y_batch,
                    rnn.dropout_keep_prob: FLAGS.dropout_keep_prob
                }
                _, step, summaries, loss, accuracy = sess.run(
                    [train_op, global_step, train_summary_op, rnn.loss, rnn.accuracy], feed_dict)
                train_summary_writer.add_summary(summaries, step)

                # Training log display
                if step % FLAGS.display_every == 0:
                    time_str = datetime.datetime.now().isoformat()
                    print("{}: step {}, loss {:g}, acc {:g}".format(time_str, step, loss, accuracy))

                # Evaluation
                if step % FLAGS.evaluate_every == 0:
                    print("\nEvaluation:")
                    feed_dict_dev = {
                        rnn.input_text: x_dev,
                        rnn.input_y: y_dev,
                        rnn.dropout_keep_prob: 1.0
                    }
                    summaries_dev, loss, accuracy = sess.run(
                        [dev_summary_op, rnn.loss, rnn.accuracy], feed_dict_dev)
                    dev_summary_writer.add_summary(summaries_dev, step)

                    time_str = datetime.datetime.now().isoformat()
                    print("{}: step {}, loss {:g}, acc {:g}\n".format(time_str, step, loss, accuracy))

                # Model checkpoint
                if step % FLAGS.checkpoint_every == 0:
                    path = saver.save(sess, checkpoint_prefix, global_step=step)
                    print("Saved model checkpoint to {}\n".format(path))


def main(_):
    train()


if __name__ == "__main__":
    tf.app.run()
--------------------------------------------------------------------------------
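As a side note, the manual binary parsing in train.py's word2vec block can be replaced by gensim's loader. This is only a sketch under the assumption that gensim (version >= 4) is installed, which is not a dependency of this repository; it is a drop-in replacement for the `if FLAGS.word2vec:` block and reuses `text_vocab_processor`, `rnn` and `sess` from that scope:

```python
# Hedged alternative to the manual parser, assuming gensim >= 4 is installed.
from gensim.models import KeyedVectors
import numpy as np

w2v = KeyedVectors.load_word2vec_format("GoogleNews-vectors-negative300.bin", binary=True)
initW = np.random.uniform(-0.25, 0.25, (len(text_vocab_processor.vocabulary_), 300))
for word in w2v.index_to_key:
    idx = text_vocab_processor.vocabulary_.get(word)
    if idx != 0:  # id 0 is the padding/unknown bucket in VocabularyProcessor
        initW[idx] = w2v[word]
sess.run(rnn.W_text.assign(initW))
```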