├── .gitignore
├── README.md
├── data
│   └── rt-polaritydata
│       ├── rt-polarity.neg
│       └── rt-polarity.pos
├── data_helpers.py
├── eval.py
├── rcnn.py
└── train.py
/.gitignore:
--------------------------------------------------------------------------------
1 | w2v_model/
2 | saved/
3 | runs/
4 | **/.idea/
5 | **/__pycache__/
6 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Recurrent Convolutional Neural Network for Text Classification
2 | TensorFlow implementation of "[Recurrent Convolutional Neural Network for Text Classification](https://www.aaai.org/ocs/index.php/AAAI/AAAI15/paper/view/9745)".
3 | 
4 | ![rcnn](https://user-images.githubusercontent.com/15166794/39769535-703d02c8-5327-11e8-99d8-44a060e63e48.PNG)
5 | 
6 | 
7 | ## Data: Movie Review
8 | * Movie reviews with one sentence per review. The task is to classify each review as positive or negative ([Pang and Lee, 2005](#reference)).
9 | * Download the "*sentence polarity dataset v1.0*" from the [official download page](http://www.cs.cornell.edu/people/pabo/movie-review-data/).
10 | * Located in *"data/rt-polaritydata/"* in this repository.
11 | * *rt-polarity.pos* contains 5331 positive snippets.
12 | * *rt-polarity.neg* contains 5331 negative snippets.
13 | 
14 | 
15 | ## Implementation of Recurrent Structure
16 | 
17 | ![recurrent_structure](https://user-images.githubusercontent.com/15166794/39777565-db89ca68-533e-11e8-8a87-785f98b3cfef.PNG)
18 | 
19 | * A bidirectional RNN (Bi-RNN) is used to build the left and right context vectors.
20 | * Each context vector is created by shifting the Bi-RNN output by one time step and concatenating a zero state that marks the boundary of the context, as sketched below.
21 | 
22 | 
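A minimal NumPy sketch of this shift-and-pad construction is shown below. It is illustrative only: the batch size, sequence length, context size, and random inputs are made-up placeholders. In the model itself the same operations are applied with `tf.zeros` and `tf.concat` to the outputs of `tf.nn.bidirectional_dynamic_rnn` (see the `context` name scope in `rcnn.py`).

```python
import numpy as np

# Toy shapes (illustrative only): batch of 2 sentences, 4 time steps,
# context (RNN state) size 3, word embedding size 5.
batch_size, seq_len, ctx_dim, emb_dim = 2, 4, 3, 5
output_fw = np.random.randn(batch_size, seq_len, ctx_dim)  # forward RNN outputs
output_bw = np.random.randn(batch_size, seq_len, ctx_dim)  # backward RNN outputs
word_emb = np.random.randn(batch_size, seq_len, emb_dim)   # word embeddings e(w_i)

zero_state = np.zeros((batch_size, 1, ctx_dim))

# Left context c_l(w_i): forward outputs shifted right by one step,
# with a zero state prepended for the first word.
c_left = np.concatenate([zero_state, output_fw[:, :-1]], axis=1)

# Right context c_r(w_i): backward outputs shifted left by one step,
# with a zero state appended for the last word.
c_right = np.concatenate([output_bw[:, 1:], zero_state], axis=1)

# Word representation x_i = [c_l(w_i); e(w_i); c_r(w_i)]
x = np.concatenate([c_left, word_emb, c_right], axis=2)
print(x.shape)  # (2, 4, 11) = (batch, time, 2*ctx_dim + emb_dim)
```

The resulting `x` corresponds to the word representation that `rcnn.py` feeds into the fully connected layer and max-pooling step.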
23 | ## Usage
24 | ### Train
25 | * Positive data is located in *"data/rt-polaritydata/rt-polarity.pos"*.
26 | * Negative data is located in *"data/rt-polaritydata/rt-polarity.neg"*.
27 | * "[GoogleNews-vectors-negative300](https://code.google.com/archive/p/word2vec/)" is used as the pre-trained word2vec model.
28 | * Display the help message:
29 | 
30 | ```bash
31 | $ python train.py --help
32 | ```
33 | 
34 | * **Train Example:**
35 | 
36 | ```bash
37 | $ python train.py --cell_type "lstm" \
38 | --pos_dir "data/rt-polaritydata/rt-polarity.pos" \
39 | --neg_dir "data/rt-polaritydata/rt-polarity.neg" \
40 | --word2vec "GoogleNews-vectors-negative300.bin"
41 | ```
42 | 
43 | 
44 | ### Evaluation
45 | * The Movie Review dataset has **no test data**.
46 | * If you want to evaluate, you should build a test set from the training data or use cross-validation. Cross-validation is not implemented in this project.
47 | * The example below simply evaluates on the full rt-polarity dataset, i.e. the same data used for training.
48 | * **Evaluation Example:**
49 | 
50 | ```bash
51 | $ python eval.py \
52 | --pos_dir "data/rt-polaritydata/rt-polarity.pos" \
53 | --neg_dir "data/rt-polaritydata/rt-polarity.neg" \
54 | --checkpoint_dir "runs/1523902663/checkpoints"
55 | ```
56 | 
57 | 
58 | ## Result
59 | * Comparison between the Recurrent Convolutional Neural Network (RCNN) and a plain Convolutional Neural Network (CNN).
60 | * dennybritz's [cnn-text-classification-tf](https://github.com/dennybritz/cnn-text-classification-tf) is used as the CNN baseline.
61 | * The same pre-trained word2vec embeddings are used for both models.
62 | 
63 | #### Accuracy for validation set
64 | ![accuracy](https://user-images.githubusercontent.com/15166794/39774365-9b8aa27e-5335-11e8-9710-515bc03dccb6.PNG)
65 | 
66 | #### Loss for validation set
67 | ![loss](https://user-images.githubusercontent.com/15166794/39774367-9bb2166a-5335-11e8-8d71-f06a61eee88a.PNG)
68 | 
69 | 
70 | ## Reference
71 | * Recurrent Convolutional Neural Network for Text Classification (AAAI 2015), S. Lai et al. [[paper]](https://www.aaai.org/ocs/index.php/AAAI/AAAI15/paper/view/9745)
72 | 
73 | 
74 | 
--------------------------------------------------------------------------------
/data_helpers.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import pandas as pd
3 | import nltk
4 | import re
5 | 
6 | 
7 | def clean_str(string):
8 |     """
9 |     Tokenization/string cleaning for all datasets except for SST.
10 |     Original taken from https://github.com/yoonkim/CNN_sentence/blob/master/process_data.py
11 |     """
12 |     string = re.sub(r"[^A-Za-z0-9(),!?\'\`]", " ", string)
13 |     string = re.sub(r"\'s", " \'s", string)
14 |     string = re.sub(r"\'ve", " \'ve", string)
15 |     string = re.sub(r"n\'t", " n\'t", string)
16 |     string = re.sub(r"\'re", " \'re", string)
17 |     string = re.sub(r"\'d", " \'d", string)
18 |     string = re.sub(r"\'ll", " \'ll", string)
19 |     string = re.sub(r",", " , ", string)
20 |     string = re.sub(r"!", " ! ", string)
21 |     string = re.sub(r"\(", " \( ", string)
22 |     string = re.sub(r"\)", " \) ", string)
23 |     string = re.sub(r"\?", " \? ", string)
24 |     string = re.sub(r"\s{2,}", " ", string)
25 |     return string.strip().lower()
26 | 
27 | 
28 | def load_data_and_labels(positive_data_file, negative_data_file):
29 |     """
30 |     Loads MR polarity data from files, splits the data into words and generates labels.
31 |     Returns split sentences and labels.
32 |     """
33 |     # Load data from files
34 |     positive_examples = list(open(positive_data_file, "r", encoding="UTF-8").readlines())
35 |     positive_examples = [s.strip() for s in positive_examples]
36 |     negative_examples = list(open(negative_data_file, "r", encoding="UTF-8").readlines())
37 |     negative_examples = [s.strip() for s in negative_examples]
38 |     # Split by words
39 |     x_text = positive_examples + negative_examples
40 |     x_text = [clean_str(sent) for sent in x_text]
41 |     # Generate labels
42 |     positive_labels = [[0, 1] for _ in positive_examples]
43 |     negative_labels = [[1, 0] for _ in negative_examples]
44 |     y = np.concatenate([positive_labels, negative_labels], 0)
45 |     return [x_text, y]
46 | 
47 | 
48 | def batch_iter(data, batch_size, num_epochs, shuffle=True):
49 |     """
50 |     Generates a batch iterator for a dataset.
51 | """ 52 | data = np.array(data) 53 | data_size = len(data) 54 | num_batches_per_epoch = int((len(data) - 1) / batch_size) + 1 55 | for epoch in range(num_epochs): 56 | # Shuffle the data at each epoch 57 | if shuffle: 58 | shuffle_indices = np.random.permutation(np.arange(data_size)) 59 | shuffled_data = data[shuffle_indices] 60 | else: 61 | shuffled_data = data 62 | for batch_num in range(num_batches_per_epoch): 63 | start_index = batch_num * batch_size 64 | end_index = min((batch_num + 1) * batch_size, data_size) 65 | yield shuffled_data[start_index:end_index] 66 | 67 | 68 | if __name__ == "__main__": 69 | pos_dir = "data/rt-polaritydata/rt-polarity.pos" 70 | neg_dir = "data/rt-polaritydata/rt-polarity.neg" 71 | 72 | load_data_and_labels(pos_dir, neg_dir) 73 | -------------------------------------------------------------------------------- /eval.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import numpy as np 3 | import os 4 | import data_helpers 5 | 6 | 7 | # Parameters 8 | # ================================================== 9 | 10 | # Data loading params 11 | tf.flags.DEFINE_string("pos_dir", "data/rt-polaritydata/rt-polarity.pos", "Path of positive data") 12 | tf.flags.DEFINE_string("neg_dir", "data/rt-polaritydata/rt-polarity.neg", "Path of negative data") 13 | 14 | # Eval Parameters 15 | tf.flags.DEFINE_integer("batch_size", 64, "Batch Size (Default: 64)") 16 | tf.flags.DEFINE_string("checkpoint_dir", "", "Checkpoint directory from training run") 17 | 18 | # Misc Parameters 19 | tf.flags.DEFINE_boolean("allow_soft_placement", True, "Allow device soft device placement") 20 | tf.flags.DEFINE_boolean("log_device_placement", False, "Log placement of ops on devices") 21 | 22 | 23 | FLAGS = tf.flags.FLAGS 24 | FLAGS._parse_flags() 25 | print("\nParameters:") 26 | for attr, value in sorted(FLAGS.__flags.items()): 27 | print("{} = {}".format(attr.upper(), value)) 28 | print("") 29 | 30 | 31 | def eval(): 32 | with tf.device('/cpu:0'): 33 | x_text, y = data_helpers.load_data_and_labels(FLAGS.pos_dir, FLAGS.neg_dir) 34 | 35 | # Map data into vocabulary 36 | text_path = os.path.join(FLAGS.checkpoint_dir, "..", "text_vocab") 37 | text_vocab_processor = tf.contrib.learn.preprocessing.VocabularyProcessor.restore(text_path) 38 | 39 | x_eval = np.array(list(text_vocab_processor.transform(x_text))) 40 | y_eval = np.argmax(y, axis=1) 41 | 42 | checkpoint_file = tf.train.latest_checkpoint(FLAGS.checkpoint_dir) 43 | 44 | graph = tf.Graph() 45 | with graph.as_default(): 46 | session_conf = tf.ConfigProto( 47 | allow_soft_placement=FLAGS.allow_soft_placement, 48 | log_device_placement=FLAGS.log_device_placement) 49 | sess = tf.Session(config=session_conf) 50 | with sess.as_default(): 51 | # Load the saved meta graph and restore variables 52 | saver = tf.train.import_meta_graph("{}.meta".format(checkpoint_file)) 53 | saver.restore(sess, checkpoint_file) 54 | 55 | # Get the placeholders from the graph by name 56 | input_text = graph.get_operation_by_name("input_text").outputs[0] 57 | # input_y = graph.get_operation_by_name("input_y").outputs[0] 58 | dropout_keep_prob = graph.get_operation_by_name("dropout_keep_prob").outputs[0] 59 | 60 | # Tensors we want to evaluate 61 | predictions = graph.get_operation_by_name("output/predictions").outputs[0] 62 | 63 | # Generate batches for one epoch 64 | batches = data_helpers.batch_iter(list(x_eval), FLAGS.batch_size, 1, shuffle=False) 65 | 66 | # Collect the predictions here 67 | 
all_predictions = [] 68 | for x_batch in batches: 69 | batch_predictions = sess.run(predictions, {input_text: x_batch, 70 | dropout_keep_prob: 1.0}) 71 | all_predictions = np.concatenate([all_predictions, batch_predictions]) 72 | 73 | correct_predictions = float(sum(all_predictions == y_eval)) 74 | print("Total number of test examples: {}".format(len(y_eval))) 75 | print("Accuracy: {:g}".format(correct_predictions / float(len(y_eval)))) 76 | 77 | 78 | def main(_): 79 | eval() 80 | 81 | 82 | if __name__ == "__main__": 83 | tf.app.run() -------------------------------------------------------------------------------- /rcnn.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | 3 | 4 | class TextRCNN: 5 | def __init__(self, sequence_length, num_classes, vocab_size, word_embedding_size, context_embedding_size, 6 | cell_type, hidden_size, l2_reg_lambda=0.0): 7 | # Placeholders for input, output and dropout 8 | self.input_text = tf.placeholder(tf.int32, shape=[None, sequence_length], name='input_text') 9 | self.input_y = tf.placeholder(tf.float32, shape=[None, num_classes], name='input_y') 10 | self.dropout_keep_prob = tf.placeholder(tf.float32, name='dropout_keep_prob') 11 | 12 | l2_loss = tf.constant(0.0) 13 | text_length = self._length(self.input_text) 14 | 15 | # Embeddings 16 | with tf.device('/cpu:0'), tf.name_scope("embedding"): 17 | self.W_text = tf.Variable(tf.random_uniform([vocab_size, word_embedding_size], -1.0, 1.0), name="W_text") 18 | self.embedded_chars = tf.nn.embedding_lookup(self.W_text, self.input_text) 19 | 20 | # Bidirectional(Left&Right) Recurrent Structure 21 | with tf.name_scope("bi-rnn"): 22 | fw_cell = self._get_cell(context_embedding_size, cell_type) 23 | fw_cell = tf.nn.rnn_cell.DropoutWrapper(fw_cell, output_keep_prob=self.dropout_keep_prob) 24 | bw_cell = self._get_cell(context_embedding_size, cell_type) 25 | bw_cell = tf.nn.rnn_cell.DropoutWrapper(bw_cell, output_keep_prob=self.dropout_keep_prob) 26 | (self.output_fw, self.output_bw), states = tf.nn.bidirectional_dynamic_rnn(cell_fw=fw_cell, 27 | cell_bw=bw_cell, 28 | inputs=self.embedded_chars, 29 | sequence_length=text_length, 30 | dtype=tf.float32) 31 | 32 | with tf.name_scope("context"): 33 | shape = [tf.shape(self.output_fw)[0], 1, tf.shape(self.output_fw)[2]] 34 | self.c_left = tf.concat([tf.zeros(shape), self.output_fw[:, :-1]], axis=1, name="context_left") 35 | self.c_right = tf.concat([self.output_bw[:, 1:], tf.zeros(shape)], axis=1, name="context_right") 36 | 37 | with tf.name_scope("word-representation"): 38 | self.x = tf.concat([self.c_left, self.embedded_chars, self.c_right], axis=2, name="x") 39 | embedding_size = 2*context_embedding_size + word_embedding_size 40 | 41 | with tf.name_scope("text-representation"): 42 | W2 = tf.Variable(tf.random_uniform([embedding_size, hidden_size], -1.0, 1.0), name="W2") 43 | b2 = tf.Variable(tf.constant(0.1, shape=[hidden_size]), name="b2") 44 | self.y2 = tf.tanh(tf.einsum('aij,jk->aik', self.x, W2) + b2) 45 | 46 | with tf.name_scope("max-pooling"): 47 | self.y3 = tf.reduce_max(self.y2, axis=1) 48 | 49 | with tf.name_scope("output"): 50 | W4 = tf.get_variable("W4", shape=[hidden_size, num_classes], initializer=tf.contrib.layers.xavier_initializer()) 51 | b4 = tf.Variable(tf.constant(0.1, shape=[num_classes]), name="b4") 52 | l2_loss += tf.nn.l2_loss(W4) 53 | l2_loss += tf.nn.l2_loss(b4) 54 | self.logits = tf.nn.xw_plus_b(self.y3, W4, b4, name="logits") 55 | self.predictions = tf.argmax(self.logits, 1, 
name="predictions") 56 | 57 | # Calculate mean cross-entropy loss 58 | with tf.name_scope("loss"): 59 | losses = tf.nn.softmax_cross_entropy_with_logits(logits=self.logits, labels=self.input_y) 60 | self.loss = tf.reduce_mean(losses) + l2_reg_lambda * l2_loss 61 | 62 | # Accuracy 63 | with tf.name_scope("accuracy"): 64 | correct_predictions = tf.equal(self.predictions, tf.argmax(self.input_y, axis=1)) 65 | self.accuracy = tf.reduce_mean(tf.cast(correct_predictions, tf.float32), name="accuracy") 66 | 67 | @staticmethod 68 | def _get_cell(hidden_size, cell_type): 69 | if cell_type == "vanilla": 70 | return tf.nn.rnn_cell.BasicRNNCell(hidden_size) 71 | elif cell_type == "lstm": 72 | return tf.nn.rnn_cell.BasicLSTMCell(hidden_size) 73 | elif cell_type == "gru": 74 | return tf.nn.rnn_cell.GRUCell(hidden_size) 75 | else: 76 | print("ERROR: '" + cell_type + "' is a wrong cell type !!!") 77 | return None 78 | 79 | # Length of the sequence data 80 | @staticmethod 81 | def _length(seq): 82 | relevant = tf.sign(tf.abs(seq)) 83 | length = tf.reduce_sum(relevant, reduction_indices=1) 84 | length = tf.cast(length, tf.int32) 85 | return length 86 | 87 | # Extract the output of last cell of each sequence 88 | # Ex) The movie is good -> length = 4 89 | # output = [ [1.314, -3.32, ..., 0.98] 90 | # [0.287, -0.50, ..., 1.55] 91 | # [2.194, -2.12, ..., 0.63] 92 | # [1.938, -1.88, ..., 1.31] 93 | # [ 0.0, 0.0, ..., 0.0] 94 | # ... 95 | # [ 0.0, 0.0, ..., 0.0] ] 96 | # The output we need is 4th output of cell, so extract it. 97 | @staticmethod 98 | def last_relevant(seq, length): 99 | batch_size = tf.shape(seq)[0] 100 | max_length = int(seq.get_shape()[1]) 101 | input_size = int(seq.get_shape()[2]) 102 | index = tf.range(0, batch_size) * max_length + (length - 1) 103 | flat = tf.reshape(seq, [-1, input_size]) 104 | return tf.gather(flat, index) 105 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import numpy as np 3 | import os 4 | import datetime 5 | import time 6 | from rcnn import TextRCNN 7 | import data_helpers 8 | 9 | # Parameters 10 | # ================================================== 11 | 12 | # Data loading params 13 | tf.flags.DEFINE_string("pos_dir", "data/rt-polaritydata/rt-polarity.pos", "Path of positive data") 14 | tf.flags.DEFINE_string("neg_dir", "data/rt-polaritydata/rt-polarity.neg", "Path of negative data") 15 | tf.flags.DEFINE_float("dev_sample_percentage", .1, "Percentage of the training data to use for validation") 16 | tf.flags.DEFINE_integer("max_sentence_length", 50, "Max sentence length in train/test data (Default: 50)") 17 | 18 | # Model Hyperparameters 19 | tf.flags.DEFINE_string("cell_type", "vanilla", "Type of RNN cell. 
Choose 'vanilla' or 'lstm' or 'gru' (Default: vanilla)") 20 | tf.flags.DEFINE_string("word2vec", None, "Word2vec file with pre-trained embeddings") 21 | tf.flags.DEFINE_integer("word_embedding_dim", 300, "Dimensionality of word embedding (Default: 300)") 22 | tf.flags.DEFINE_integer("context_embedding_dim", 512, "Dimensionality of context embedding(= RNN state size) (Default: 512)") 23 | tf.flags.DEFINE_integer("hidden_size", 512, "Size of hidden layer (Default: 512)") 24 | tf.flags.DEFINE_float("dropout_keep_prob", 0.7, "Dropout keep probability (Default: 0.7)") 25 | tf.flags.DEFINE_float("l2_reg_lambda", 0.5, "L2 regularization lambda (Default: 0.5)") 26 | 27 | # Training parameters 28 | tf.flags.DEFINE_integer("batch_size", 64, "Batch Size (Default: 64)") 29 | tf.flags.DEFINE_integer("num_epochs", 10, "Number of training epochs (Default: 10)") 30 | tf.flags.DEFINE_integer("display_every", 10, "Number of iterations to display training info.") 31 | tf.flags.DEFINE_integer("evaluate_every", 100, "Evaluate model on dev set after this many steps") 32 | tf.flags.DEFINE_integer("checkpoint_every", 100, "Save model after this many steps") 33 | tf.flags.DEFINE_integer("num_checkpoints", 5, "Number of checkpoints to store") 34 | tf.flags.DEFINE_float("learning_rate", 1e-3, "Which learning rate to start with. (Default: 1e-3)") 35 | 36 | # Misc Parameters 37 | tf.flags.DEFINE_boolean("allow_soft_placement", True, "Allow device soft device placement") 38 | tf.flags.DEFINE_boolean("log_device_placement", False, "Log placement of ops on devices") 39 | 40 | 41 | FLAGS = tf.flags.FLAGS 42 | FLAGS._parse_flags() 43 | print("\nParameters:") 44 | for attr, value in sorted(FLAGS.__flags.items()): 45 | print("{} = {}".format(attr.upper(), value)) 46 | print("") 47 | 48 | 49 | def train(): 50 | with tf.device('/cpu:0'): 51 | x_text, y = data_helpers.load_data_and_labels(FLAGS.pos_dir, FLAGS.neg_dir) 52 | 53 | text_vocab_processor = tf.contrib.learn.preprocessing.VocabularyProcessor(FLAGS.max_sentence_length) 54 | x = np.array(list(text_vocab_processor.fit_transform(x_text))) 55 | print("Text Vocabulary Size: {:d}".format(len(text_vocab_processor.vocabulary_))) 56 | 57 | print("x = {0}".format(x.shape)) 58 | print("y = {0}".format(y.shape)) 59 | print("") 60 | 61 | # Randomly shuffle data 62 | np.random.seed(10) 63 | shuffle_indices = np.random.permutation(np.arange(len(y))) 64 | x_shuffled = x[shuffle_indices] 65 | y_shuffled = y[shuffle_indices] 66 | 67 | # Split train/test set 68 | # TODO: This is very crude, should use cross-validation 69 | dev_sample_index = -1 * int(FLAGS.dev_sample_percentage * float(len(y))) 70 | x_train, x_dev = x_shuffled[:dev_sample_index], x_shuffled[dev_sample_index:] 71 | y_train, y_dev = y_shuffled[:dev_sample_index], y_shuffled[dev_sample_index:] 72 | print("Train/Dev split: {:d}/{:d}\n".format(len(y_train), len(y_dev))) 73 | 74 | with tf.Graph().as_default(): 75 | session_conf = tf.ConfigProto( 76 | allow_soft_placement=FLAGS.allow_soft_placement, 77 | log_device_placement=FLAGS.log_device_placement) 78 | sess = tf.Session(config=session_conf) 79 | with sess.as_default(): 80 | rcnn = TextRCNN( 81 | sequence_length=x_train.shape[1], 82 | num_classes=y_train.shape[1], 83 | vocab_size=len(text_vocab_processor.vocabulary_), 84 | word_embedding_size=FLAGS.word_embedding_dim, 85 | context_embedding_size=FLAGS.context_embedding_dim, 86 | cell_type=FLAGS.cell_type, 87 | hidden_size=FLAGS.hidden_size, 88 | l2_reg_lambda=FLAGS.l2_reg_lambda 89 | ) 90 | 91 | # Define Training procedure 
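# Note: passing global_step to AdamOptimizer.minimize() makes the optimizer increment
# the step counter on every update, so `step` in the loop below counts processed batches.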
92 | global_step = tf.Variable(0, name="global_step", trainable=False) 93 | train_op = tf.train.AdamOptimizer(FLAGS.learning_rate).minimize(rcnn.loss, global_step=global_step) 94 | 95 | # Output directory for models and summaries 96 | timestamp = str(int(time.time())) 97 | out_dir = os.path.abspath(os.path.join(os.path.curdir, "runs", timestamp)) 98 | print("Writing to {}\n".format(out_dir)) 99 | 100 | # Summaries for loss and accuracy 101 | loss_summary = tf.summary.scalar("loss", rcnn.loss) 102 | acc_summary = tf.summary.scalar("accuracy", rcnn.accuracy) 103 | 104 | # Train Summaries 105 | train_summary_op = tf.summary.merge([loss_summary, acc_summary]) 106 | train_summary_dir = os.path.join(out_dir, "summaries", "train") 107 | train_summary_writer = tf.summary.FileWriter(train_summary_dir, sess.graph) 108 | 109 | # Dev summaries 110 | dev_summary_op = tf.summary.merge([loss_summary, acc_summary]) 111 | dev_summary_dir = os.path.join(out_dir, "summaries", "dev") 112 | dev_summary_writer = tf.summary.FileWriter(dev_summary_dir, sess.graph) 113 | 114 | # Checkpoint directory. Tensorflow assumes this directory already exists so we need to create it 115 | checkpoint_dir = os.path.abspath(os.path.join(out_dir, "checkpoints")) 116 | checkpoint_prefix = os.path.join(checkpoint_dir, "model") 117 | if not os.path.exists(checkpoint_dir): 118 | os.makedirs(checkpoint_dir) 119 | saver = tf.train.Saver(tf.global_variables(), max_to_keep=FLAGS.num_checkpoints) 120 | 121 | # Write vocabulary 122 | text_vocab_processor.save(os.path.join(out_dir, "text_vocab")) 123 | 124 | # Initialize all variables 125 | sess.run(tf.global_variables_initializer()) 126 | 127 | # Pre-trained word2vec 128 | if FLAGS.word2vec: 129 | # initial matrix with random uniform 130 | initW = np.random.uniform(-0.25, 0.25, (len(text_vocab_processor.vocabulary_), FLAGS.word_embedding_dim)) 131 | # load any vectors from the word2vec 132 | print("Load word2vec file {0}".format(FLAGS.word2vec)) 133 | with open(FLAGS.word2vec, "rb") as f: 134 | header = f.readline() 135 | vocab_size, layer1_size = map(int, header.split()) 136 | binary_len = np.dtype('float32').itemsize * layer1_size 137 | for line in range(vocab_size): 138 | word = [] 139 | while True: 140 | ch = f.read(1).decode('latin-1') 141 | if ch == ' ': 142 | word = ''.join(word) 143 | break 144 | if ch != '\n': 145 | word.append(ch) 146 | idx = text_vocab_processor.vocabulary_.get(word) 147 | if idx != 0: 148 | initW[idx] = np.fromstring(f.read(binary_len), dtype='float32') 149 | else: 150 | f.read(binary_len) 151 | sess.run(rcnn.W_text.assign(initW)) 152 | print("Success to load pre-trained word2vec model!\n") 153 | 154 | # Generate batches 155 | batches = data_helpers.batch_iter( 156 | list(zip(x_train, y_train)), FLAGS.batch_size, FLAGS.num_epochs) 157 | # Training loop. For each batch... 
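# Each `batch` yielded by batch_iter is an array of (x, y) pairs, so zip(*batch)
# below splits it back into parallel input and label sequences.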
158 | for batch in batches: 159 | x_batch, y_batch = zip(*batch) 160 | # Train 161 | feed_dict = { 162 | rcnn.input_text: x_batch, 163 | rcnn.input_y: y_batch, 164 | rcnn.dropout_keep_prob: FLAGS.dropout_keep_prob 165 | } 166 | _, step, summaries, loss, accuracy = sess.run( 167 | [train_op, global_step, train_summary_op, rcnn.loss, rcnn.accuracy], feed_dict) 168 | train_summary_writer.add_summary(summaries, step) 169 | 170 | # Training log display 171 | if step % FLAGS.display_every == 0: 172 | time_str = datetime.datetime.now().isoformat() 173 | print("{}: step {}, loss {:g}, acc {:g}".format(time_str, step, loss, accuracy)) 174 | 175 | # Evaluation 176 | if step % FLAGS.evaluate_every == 0: 177 | print("\nEvaluation:") 178 | feed_dict_dev = { 179 | rcnn.input_text: x_dev, 180 | rcnn.input_y: y_dev, 181 | rcnn.dropout_keep_prob: 1.0 182 | } 183 | summaries_dev, loss, accuracy = sess.run( 184 | [dev_summary_op, rcnn.loss, rcnn.accuracy], feed_dict_dev) 185 | dev_summary_writer.add_summary(summaries_dev, step) 186 | 187 | time_str = datetime.datetime.now().isoformat() 188 | print("{}: step {}, loss {:g}, acc {:g}\n".format(time_str, step, loss, accuracy)) 189 | 190 | # Model checkpoint 191 | if step % FLAGS.checkpoint_every == 0: 192 | path = saver.save(sess, checkpoint_prefix, global_step=step) 193 | print("Saved model checkpoint to {}\n".format(path)) 194 | 195 | 196 | def main(_): 197 | train() 198 | 199 | 200 | if __name__ == "__main__": 201 | tf.app.run() 202 | --------------------------------------------------------------------------------