├── .gitignore
├── README.md
├── data
│   └── rt-polaritydata
│       ├── rt-polarity.neg
│       └── rt-polarity.pos
├── data_helpers.py
├── eval.py
├── rcnn.py
└── train.py
/.gitignore:
--------------------------------------------------------------------------------
1 | w2v_model/
2 | saved/
3 | runs/
4 | **/.idea/
5 | **/__pycache__/
6 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Recurrent Convolutional Neural Network for Text Classification
2 | Tensorflow implementation of "[Recurrent Convolutional Neural Network for Text Classification](https://www.aaai.org/ocs/index.php/AAAI/AAAI15/paper/view/9745)".
3 |
4 | 
5 |
6 |
7 | ## Data: Movie Review
8 | * Movie reviews with one sentence per review. Classification involves detecting positive/negative reviews ([Pang and Lee, 2005](#reference)).
9 | * Download "*sentence polarity dataset v1.0*" at the [Official Download Page](http://www.cs.cornell.edu/people/pabo/movie-review-data/).
10 | * The dataset is included under *"data/rt-polaritydata/"* in this repository; a quick loading sketch follows the list below.
11 | * *rt-polarity.pos* contains 5331 positive snippets.
12 | * *rt-polarity.neg* contains 5331 negative snippets.
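
For a quick sanity check, the dataset can be loaded with the repository's own helper, which cleans the sentences and builds one-hot labels (`[0, 1]` = positive, `[1, 0]` = negative). A minimal sketch, run from the repository root:

```python
import data_helpers

# Loads both polarity files, cleans the text, and returns sentences plus one-hot labels.
x_text, y = data_helpers.load_data_and_labels(
    "data/rt-polaritydata/rt-polarity.pos",
    "data/rt-polaritydata/rt-polarity.neg")

print(len(x_text), y.shape)  # 10662 sentences, labels of shape (10662, 2)
```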
13 |
14 |
15 | ## Implementation of Recurrent Structure
16 |
17 | 
18 |
19 | * A bidirectional RNN (Bi-RNN) is used to produce the left and right context vectors.
20 | * Each context vector is obtained by shifting the corresponding Bi-RNN outputs by one time step and padding with a zero state that marks the start (or end) of the context, as shown in the sketch below.
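
A minimal NumPy sketch of this shift-and-pad construction (shapes are illustrative; the actual TensorFlow version lives in the `context` and `word-representation` scopes of `rcnn.py`):

```python
import numpy as np

# Toy shapes: batch of 2 sentences, 4 words each, context size 3, word embedding size 5.
output_fw = np.random.rand(2, 4, 3)   # forward RNN outputs, one vector per word
output_bw = np.random.rand(2, 4, 3)   # backward RNN outputs, one vector per word
zero = np.zeros((2, 1, 3))            # zero state marking the start/end of a context

# Left context of word i = forward output of word i-1 (zero state for the first word).
c_left = np.concatenate([zero, output_fw[:, :-1]], axis=1)
# Right context of word i = backward output of word i+1 (zero state for the last word).
c_right = np.concatenate([output_bw[:, 1:], zero], axis=1)

# Word representation x_i = [c_left(w_i); e(w_i); c_right(w_i)] on the feature axis.
embedded = np.random.rand(2, 4, 5)
x = np.concatenate([c_left, embedded, c_right], axis=2)
print(x.shape)                        # (2, 4, 3 + 5 + 3)
```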
21 |
22 |
23 | ## Usage
24 | ### Train
25 | * Positive data is located in *"data/rt-polaritydata/rt-polarity.pos"*.
26 | * Negative data is located in *"data/rt-polaritydata/rt-polarity.neg"*.
27 | * "[GoogleNews-vectors-negative300](https://code.google.com/archive/p/word2vec/)" is used as the pre-trained word2vec model.
28 | * Display help message:
29 |
30 | ```bash
31 | $ python train.py --help
32 | ```
33 |
34 | * **Train Example:**
35 |
36 | ```bash
37 | $ python train.py --cell_type "lstm" \
38 | --pos_dir "data/rt-polaritydata/rt-polarity.pos" \
39 | --neg_dir "data/rt-polaritydata/rt-polarity.neg" \
40 | --word2vec "GoogleNews-vectors-negative300.bin"
41 | ```
42 |
43 |
44 | ### Evaluation
45 | * The Movie Review dataset has **no test data**.
46 | * For a proper evaluation you should hold out a test set from the training data or run cross-validation; cross-validation is not implemented in this project.
47 | * The example below simply evaluates on the full rt-polarity dataset, i.e. the same data the model was trained on (a held-out split sketch follows the example).
48 | * **Evaluation Example:**
49 |
50 | ```bash
51 | $ python eval.py \
52 | --pos_dir "data/rt-polaritydata/rt-polarity.pos" \
53 | --neg_dir "data/rt-polaritydata/rt-polarity.neg" \
54 | --checkpoint_dir "runs/1523902663/checkpoints"
55 | ```
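
If you want a genuine held-out test set instead, one simple option is to split each polarity file before training. A rough sketch, assuming nothing beyond the data files in this repository (the output file names are only examples):

```python
import numpy as np

def split_file(src, train_dst, test_dst, test_ratio=0.1, seed=10):
    """Hold out `test_ratio` of the lines in `src` as a separate test file."""
    with open(src, "r", encoding="UTF-8") as f:
        lines = [line.strip() for line in f]
    idx = np.random.RandomState(seed).permutation(len(lines))
    n_test = int(test_ratio * len(lines))
    with open(test_dst, "w", encoding="UTF-8") as f:
        f.write("\n".join(lines[i] for i in idx[:n_test]))
    with open(train_dst, "w", encoding="UTF-8") as f:
        f.write("\n".join(lines[i] for i in idx[n_test:]))

split_file("data/rt-polaritydata/rt-polarity.pos", "rt-polarity.train.pos", "rt-polarity.test.pos")
split_file("data/rt-polaritydata/rt-polarity.neg", "rt-polarity.train.neg", "rt-polarity.test.neg")
```

Train on the `*.train.*` files, then point `eval.py --pos_dir/--neg_dir` at the `*.test.*` files.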
56 |
57 |
58 | ## Result
59 | * Comparison between the Recurrent Convolutional Neural Network and a plain Convolutional Neural Network.
60 | * dennybritz's [cnn-text-classification-tf](https://github.com/dennybritz/cnn-text-classification-tf) is used as the CNN baseline.
61 | * The same pre-trained word2vec embeddings are used for both models.
62 |
63 | #### Accuracy for validation set
64 | 
65 |
66 | #### Loss for validation set
67 | 
68 |
69 |
70 | ## Reference
71 | * Recurrent Convolutional Neural Network for Text Classification (AAAI 2015), S Lai et al. [[paper]](https://www.aaai.org/ocs/index.php/AAAI/AAAI15/paper/view/9745)
72 |
73 |
74 |
--------------------------------------------------------------------------------
/data_helpers.py:
--------------------------------------------------------------------------------
1 | import numpy as np
4 | import re
5 |
6 |
7 | def clean_str(string):
8 | """
9 | Tokenization/string cleaning for all datasets except for SST.
10 | Original taken from https://github.com/yoonkim/CNN_sentence/blob/master/process_data.py
11 | """
12 | string = re.sub(r"[^A-Za-z0-9(),!?\'\`]", " ", string)
13 | string = re.sub(r"\'s", " \'s", string)
14 | string = re.sub(r"\'ve", " \'ve", string)
15 | string = re.sub(r"n\'t", " n\'t", string)
16 | string = re.sub(r"\'re", " \'re", string)
17 | string = re.sub(r"\'d", " \'d", string)
18 | string = re.sub(r"\'ll", " \'ll", string)
19 | string = re.sub(r",", " , ", string)
20 | string = re.sub(r"!", " ! ", string)
21 | string = re.sub(r"\(", " \( ", string)
22 | string = re.sub(r"\)", " \) ", string)
23 | string = re.sub(r"\?", " \? ", string)
24 | string = re.sub(r"\s{2,}", " ", string)
25 | return string.strip().lower()
26 |
27 |
28 | def load_data_and_labels(positive_data_file, negative_data_file):
29 | """
30 | Loads MR polarity data from files, splits the data into words and generates labels.
31 | Returns split sentences and labels.
32 | """
33 | # Load data from files
34 | positive_examples = list(open(positive_data_file, "r", encoding="UTF-8").readlines())
35 | positive_examples = [s.strip() for s in positive_examples]
36 | negative_examples = list(open(negative_data_file, "r", encoding="UTF-8").readlines())
37 | negative_examples = [s.strip() for s in negative_examples]
38 | # Split by words
39 | x_text = positive_examples + negative_examples
40 | x_text = [clean_str(sent) for sent in x_text]
41 | # Generate labels
42 | positive_labels = [[0, 1] for _ in positive_examples]
43 | negative_labels = [[1, 0] for _ in negative_examples]
44 | y = np.concatenate([positive_labels, negative_labels], 0)
45 | return [x_text, y]
46 |
47 |
48 | def batch_iter(data, batch_size, num_epochs, shuffle=True):
49 | """
50 | Generates a batch iterator for a dataset.
51 | """
52 | data = np.array(data)
53 | data_size = len(data)
54 | num_batches_per_epoch = int((len(data) - 1) / batch_size) + 1
55 | for epoch in range(num_epochs):
56 | # Shuffle the data at each epoch
57 | if shuffle:
58 | shuffle_indices = np.random.permutation(np.arange(data_size))
59 | shuffled_data = data[shuffle_indices]
60 | else:
61 | shuffled_data = data
62 | for batch_num in range(num_batches_per_epoch):
63 | start_index = batch_num * batch_size
64 | end_index = min((batch_num + 1) * batch_size, data_size)
65 | yield shuffled_data[start_index:end_index]
66 |
67 |
68 | if __name__ == "__main__":
69 | pos_dir = "data/rt-polaritydata/rt-polarity.pos"
70 | neg_dir = "data/rt-polaritydata/rt-polarity.neg"
71 |
72 | load_data_and_labels(pos_dir, neg_dir)
73 |
--------------------------------------------------------------------------------
/eval.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 | import numpy as np
3 | import os
4 | import data_helpers
5 |
6 |
7 | # Parameters
8 | # ==================================================
9 |
10 | # Data loading params
11 | tf.flags.DEFINE_string("pos_dir", "data/rt-polaritydata/rt-polarity.pos", "Path of positive data")
12 | tf.flags.DEFINE_string("neg_dir", "data/rt-polaritydata/rt-polarity.neg", "Path of negative data")
13 |
14 | # Eval Parameters
15 | tf.flags.DEFINE_integer("batch_size", 64, "Batch Size (Default: 64)")
16 | tf.flags.DEFINE_string("checkpoint_dir", "", "Checkpoint directory from training run")
17 |
18 | # Misc Parameters
19 | tf.flags.DEFINE_boolean("allow_soft_placement", True, "Allow soft device placement")
20 | tf.flags.DEFINE_boolean("log_device_placement", False, "Log placement of ops on devices")
21 |
22 |
23 | FLAGS = tf.flags.FLAGS
24 | FLAGS._parse_flags()
25 | print("\nParameters:")
26 | for attr, value in sorted(FLAGS.__flags.items()):
27 | print("{} = {}".format(attr.upper(), value))
28 | print("")
29 |
30 |
31 | def eval():
32 | with tf.device('/cpu:0'):
33 | x_text, y = data_helpers.load_data_and_labels(FLAGS.pos_dir, FLAGS.neg_dir)
34 |
35 | # Map data into vocabulary
36 | text_path = os.path.join(FLAGS.checkpoint_dir, "..", "text_vocab")
37 | text_vocab_processor = tf.contrib.learn.preprocessing.VocabularyProcessor.restore(text_path)
38 |
39 | x_eval = np.array(list(text_vocab_processor.transform(x_text)))
40 | y_eval = np.argmax(y, axis=1)
41 |
42 | checkpoint_file = tf.train.latest_checkpoint(FLAGS.checkpoint_dir)
43 |
44 | graph = tf.Graph()
45 | with graph.as_default():
46 | session_conf = tf.ConfigProto(
47 | allow_soft_placement=FLAGS.allow_soft_placement,
48 | log_device_placement=FLAGS.log_device_placement)
49 | sess = tf.Session(config=session_conf)
50 | with sess.as_default():
51 | # Load the saved meta graph and restore variables
52 | saver = tf.train.import_meta_graph("{}.meta".format(checkpoint_file))
53 | saver.restore(sess, checkpoint_file)
54 |
55 | # Get the placeholders from the graph by name
56 | input_text = graph.get_operation_by_name("input_text").outputs[0]
57 | # input_y = graph.get_operation_by_name("input_y").outputs[0]
58 | dropout_keep_prob = graph.get_operation_by_name("dropout_keep_prob").outputs[0]
59 |
60 | # Tensors we want to evaluate
61 | predictions = graph.get_operation_by_name("output/predictions").outputs[0]
62 |
63 | # Generate batches for one epoch
64 | batches = data_helpers.batch_iter(list(x_eval), FLAGS.batch_size, 1, shuffle=False)
65 |
66 | # Collect the predictions here
67 | all_predictions = []
68 | for x_batch in batches:
69 | batch_predictions = sess.run(predictions, {input_text: x_batch,
70 | dropout_keep_prob: 1.0})
71 | all_predictions = np.concatenate([all_predictions, batch_predictions])
72 |
73 | correct_predictions = float(sum(all_predictions == y_eval))
74 | print("Total number of test examples: {}".format(len(y_eval)))
75 | print("Accuracy: {:g}".format(correct_predictions / float(len(y_eval))))
76 |
77 |
78 | def main(_):
79 | eval()
80 |
81 |
82 | if __name__ == "__main__":
83 | tf.app.run()
--------------------------------------------------------------------------------
/rcnn.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 |
3 |
4 | class TextRCNN:
5 | def __init__(self, sequence_length, num_classes, vocab_size, word_embedding_size, context_embedding_size,
6 | cell_type, hidden_size, l2_reg_lambda=0.0):
7 | # Placeholders for input, output and dropout
8 | self.input_text = tf.placeholder(tf.int32, shape=[None, sequence_length], name='input_text')
9 | self.input_y = tf.placeholder(tf.float32, shape=[None, num_classes], name='input_y')
10 | self.dropout_keep_prob = tf.placeholder(tf.float32, name='dropout_keep_prob')
11 |
12 | l2_loss = tf.constant(0.0)
13 | text_length = self._length(self.input_text)
14 |
15 | # Embeddings
16 | with tf.device('/cpu:0'), tf.name_scope("embedding"):
17 | self.W_text = tf.Variable(tf.random_uniform([vocab_size, word_embedding_size], -1.0, 1.0), name="W_text")
18 | self.embedded_chars = tf.nn.embedding_lookup(self.W_text, self.input_text)
19 |
20 | # Bidirectional(Left&Right) Recurrent Structure
21 | with tf.name_scope("bi-rnn"):
22 | fw_cell = self._get_cell(context_embedding_size, cell_type)
23 | fw_cell = tf.nn.rnn_cell.DropoutWrapper(fw_cell, output_keep_prob=self.dropout_keep_prob)
24 | bw_cell = self._get_cell(context_embedding_size, cell_type)
25 | bw_cell = tf.nn.rnn_cell.DropoutWrapper(bw_cell, output_keep_prob=self.dropout_keep_prob)
26 | (self.output_fw, self.output_bw), states = tf.nn.bidirectional_dynamic_rnn(cell_fw=fw_cell,
27 | cell_bw=bw_cell,
28 | inputs=self.embedded_chars,
29 | sequence_length=text_length,
30 | dtype=tf.float32)
31 |
32 | with tf.name_scope("context"):
33 | shape = [tf.shape(self.output_fw)[0], 1, tf.shape(self.output_fw)[2]]  # zero state marking the start/end of a context
34 | self.c_left = tf.concat([tf.zeros(shape), self.output_fw[:, :-1]], axis=1, name="context_left")  # forward outputs shifted right by one step
35 | self.c_right = tf.concat([self.output_bw[:, 1:], tf.zeros(shape)], axis=1, name="context_right")  # backward outputs shifted left by one step
36 |
37 | with tf.name_scope("word-representation"):
38 | self.x = tf.concat([self.c_left, self.embedded_chars, self.c_right], axis=2, name="x")
39 | embedding_size = 2*context_embedding_size + word_embedding_size
40 |
41 | with tf.name_scope("text-representation"):
42 | W2 = tf.Variable(tf.random_uniform([embedding_size, hidden_size], -1.0, 1.0), name="W2")
43 | b2 = tf.Variable(tf.constant(0.1, shape=[hidden_size]), name="b2")
44 | self.y2 = tf.tanh(tf.einsum('aij,jk->aik', self.x, W2) + b2)  # y2_i = tanh(W2 * x_i + b2), applied at every word position
45 |
46 | with tf.name_scope("max-pooling"):
47 | self.y3 = tf.reduce_max(self.y2, axis=1)
48 |
49 | with tf.name_scope("output"):
50 | W4 = tf.get_variable("W4", shape=[hidden_size, num_classes], initializer=tf.contrib.layers.xavier_initializer())
51 | b4 = tf.Variable(tf.constant(0.1, shape=[num_classes]), name="b4")
52 | l2_loss += tf.nn.l2_loss(W4)
53 | l2_loss += tf.nn.l2_loss(b4)
54 | self.logits = tf.nn.xw_plus_b(self.y3, W4, b4, name="logits")
55 | self.predictions = tf.argmax(self.logits, 1, name="predictions")
56 |
57 | # Calculate mean cross-entropy loss
58 | with tf.name_scope("loss"):
59 | losses = tf.nn.softmax_cross_entropy_with_logits(logits=self.logits, labels=self.input_y)
60 | self.loss = tf.reduce_mean(losses) + l2_reg_lambda * l2_loss
61 |
62 | # Accuracy
63 | with tf.name_scope("accuracy"):
64 | correct_predictions = tf.equal(self.predictions, tf.argmax(self.input_y, axis=1))
65 | self.accuracy = tf.reduce_mean(tf.cast(correct_predictions, tf.float32), name="accuracy")
66 |
67 | @staticmethod
68 | def _get_cell(hidden_size, cell_type):
69 | if cell_type == "vanilla":
70 | return tf.nn.rnn_cell.BasicRNNCell(hidden_size)
71 | elif cell_type == "lstm":
72 | return tf.nn.rnn_cell.BasicLSTMCell(hidden_size)
73 | elif cell_type == "gru":
74 | return tf.nn.rnn_cell.GRUCell(hidden_size)
75 | else:
76 | print("ERROR: '" + cell_type + "' is not a valid cell type.")
77 | return None
78 |
79 | # Length of the sequence data
80 | @staticmethod
81 | def _length(seq):
82 | relevant = tf.sign(tf.abs(seq))
83 | length = tf.reduce_sum(relevant, reduction_indices=1)
84 | length = tf.cast(length, tf.int32)
85 | return length
86 |
87 | # Extract the output of last cell of each sequence
88 | # Ex) The movie is good -> length = 4
89 | # output = [ [1.314, -3.32, ..., 0.98]
90 | # [0.287, -0.50, ..., 1.55]
91 | # [2.194, -2.12, ..., 0.63]
92 | # [1.938, -1.88, ..., 1.31]
93 | # [ 0.0, 0.0, ..., 0.0]
94 | # ...
95 | # [ 0.0, 0.0, ..., 0.0] ]
96 | # The output we need is 4th output of cell, so extract it.
97 | @staticmethod
98 | def last_relevant(seq, length):
99 | batch_size = tf.shape(seq)[0]
100 | max_length = int(seq.get_shape()[1])
101 | input_size = int(seq.get_shape()[2])
102 | index = tf.range(0, batch_size) * max_length + (length - 1)
103 | flat = tf.reshape(seq, [-1, input_size])
104 | return tf.gather(flat, index)
105 |
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 | import numpy as np
3 | import os
4 | import datetime
5 | import time
6 | from rcnn import TextRCNN
7 | import data_helpers
8 |
9 | # Parameters
10 | # ==================================================
11 |
12 | # Data loading params
13 | tf.flags.DEFINE_string("pos_dir", "data/rt-polaritydata/rt-polarity.pos", "Path of positive data")
14 | tf.flags.DEFINE_string("neg_dir", "data/rt-polaritydata/rt-polarity.neg", "Path of negative data")
15 | tf.flags.DEFINE_float("dev_sample_percentage", .1, "Percentage of the training data to use for validation")
16 | tf.flags.DEFINE_integer("max_sentence_length", 50, "Max sentence length in train/test data (Default: 50)")
17 |
18 | # Model Hyperparameters
19 | tf.flags.DEFINE_string("cell_type", "vanilla", "Type of RNN cell. Choose 'vanilla' or 'lstm' or 'gru' (Default: vanilla)")
20 | tf.flags.DEFINE_string("word2vec", None, "Word2vec file with pre-trained embeddings")
21 | tf.flags.DEFINE_integer("word_embedding_dim", 300, "Dimensionality of word embedding (Default: 300)")
22 | tf.flags.DEFINE_integer("context_embedding_dim", 512, "Dimensionality of context embedding(= RNN state size) (Default: 512)")
23 | tf.flags.DEFINE_integer("hidden_size", 512, "Size of hidden layer (Default: 512)")
24 | tf.flags.DEFINE_float("dropout_keep_prob", 0.7, "Dropout keep probability (Default: 0.7)")
25 | tf.flags.DEFINE_float("l2_reg_lambda", 0.5, "L2 regularization lambda (Default: 0.5)")
26 |
27 | # Training parameters
28 | tf.flags.DEFINE_integer("batch_size", 64, "Batch Size (Default: 64)")
29 | tf.flags.DEFINE_integer("num_epochs", 10, "Number of training epochs (Default: 10)")
30 | tf.flags.DEFINE_integer("display_every", 10, "Number of iterations to display training info.")
31 | tf.flags.DEFINE_integer("evaluate_every", 100, "Evaluate model on dev set after this many steps")
32 | tf.flags.DEFINE_integer("checkpoint_every", 100, "Save model after this many steps")
33 | tf.flags.DEFINE_integer("num_checkpoints", 5, "Number of checkpoints to store")
34 | tf.flags.DEFINE_float("learning_rate", 1e-3, "Which learning rate to start with. (Default: 1e-3)")
35 |
36 | # Misc Parameters
37 | tf.flags.DEFINE_boolean("allow_soft_placement", True, "Allow soft device placement")
38 | tf.flags.DEFINE_boolean("log_device_placement", False, "Log placement of ops on devices")
39 |
40 |
41 | FLAGS = tf.flags.FLAGS
42 | FLAGS._parse_flags()
43 | print("\nParameters:")
44 | for attr, value in sorted(FLAGS.__flags.items()):
45 | print("{} = {}".format(attr.upper(), value))
46 | print("")
47 |
48 |
49 | def train():
50 | with tf.device('/cpu:0'):
51 | x_text, y = data_helpers.load_data_and_labels(FLAGS.pos_dir, FLAGS.neg_dir)
52 |
53 | text_vocab_processor = tf.contrib.learn.preprocessing.VocabularyProcessor(FLAGS.max_sentence_length)
54 | x = np.array(list(text_vocab_processor.fit_transform(x_text)))
55 | print("Text Vocabulary Size: {:d}".format(len(text_vocab_processor.vocabulary_)))
56 |
57 | print("x = {0}".format(x.shape))
58 | print("y = {0}".format(y.shape))
59 | print("")
60 |
61 | # Randomly shuffle data
62 | np.random.seed(10)
63 | shuffle_indices = np.random.permutation(np.arange(len(y)))
64 | x_shuffled = x[shuffle_indices]
65 | y_shuffled = y[shuffle_indices]
66 |
67 | # Split train/test set
68 | # TODO: This is very crude, should use cross-validation
69 | dev_sample_index = -1 * int(FLAGS.dev_sample_percentage * float(len(y)))
70 | x_train, x_dev = x_shuffled[:dev_sample_index], x_shuffled[dev_sample_index:]
71 | y_train, y_dev = y_shuffled[:dev_sample_index], y_shuffled[dev_sample_index:]
72 | print("Train/Dev split: {:d}/{:d}\n".format(len(y_train), len(y_dev)))
73 |
74 | with tf.Graph().as_default():
75 | session_conf = tf.ConfigProto(
76 | allow_soft_placement=FLAGS.allow_soft_placement,
77 | log_device_placement=FLAGS.log_device_placement)
78 | sess = tf.Session(config=session_conf)
79 | with sess.as_default():
80 | rcnn = TextRCNN(
81 | sequence_length=x_train.shape[1],
82 | num_classes=y_train.shape[1],
83 | vocab_size=len(text_vocab_processor.vocabulary_),
84 | word_embedding_size=FLAGS.word_embedding_dim,
85 | context_embedding_size=FLAGS.context_embedding_dim,
86 | cell_type=FLAGS.cell_type,
87 | hidden_size=FLAGS.hidden_size,
88 | l2_reg_lambda=FLAGS.l2_reg_lambda
89 | )
90 |
91 | # Define Training procedure
92 | global_step = tf.Variable(0, name="global_step", trainable=False)
93 | train_op = tf.train.AdamOptimizer(FLAGS.learning_rate).minimize(rcnn.loss, global_step=global_step)
94 |
95 | # Output directory for models and summaries
96 | timestamp = str(int(time.time()))
97 | out_dir = os.path.abspath(os.path.join(os.path.curdir, "runs", timestamp))
98 | print("Writing to {}\n".format(out_dir))
99 |
100 | # Summaries for loss and accuracy
101 | loss_summary = tf.summary.scalar("loss", rcnn.loss)
102 | acc_summary = tf.summary.scalar("accuracy", rcnn.accuracy)
103 |
104 | # Train Summaries
105 | train_summary_op = tf.summary.merge([loss_summary, acc_summary])
106 | train_summary_dir = os.path.join(out_dir, "summaries", "train")
107 | train_summary_writer = tf.summary.FileWriter(train_summary_dir, sess.graph)
108 |
109 | # Dev summaries
110 | dev_summary_op = tf.summary.merge([loss_summary, acc_summary])
111 | dev_summary_dir = os.path.join(out_dir, "summaries", "dev")
112 | dev_summary_writer = tf.summary.FileWriter(dev_summary_dir, sess.graph)
113 |
114 | # Checkpoint directory. Tensorflow assumes this directory already exists so we need to create it
115 | checkpoint_dir = os.path.abspath(os.path.join(out_dir, "checkpoints"))
116 | checkpoint_prefix = os.path.join(checkpoint_dir, "model")
117 | if not os.path.exists(checkpoint_dir):
118 | os.makedirs(checkpoint_dir)
119 | saver = tf.train.Saver(tf.global_variables(), max_to_keep=FLAGS.num_checkpoints)
120 |
121 | # Write vocabulary
122 | text_vocab_processor.save(os.path.join(out_dir, "text_vocab"))
123 |
124 | # Initialize all variables
125 | sess.run(tf.global_variables_initializer())
126 |
127 | # Pre-trained word2vec
128 | if FLAGS.word2vec:
129 | # initial matrix with random uniform
130 | initW = np.random.uniform(-0.25, 0.25, (len(text_vocab_processor.vocabulary_), FLAGS.word_embedding_dim))
131 | # load any vectors from the word2vec
132 | print("Load word2vec file {0}".format(FLAGS.word2vec))
133 | with open(FLAGS.word2vec, "rb") as f:
134 | header = f.readline()
135 | vocab_size, layer1_size = map(int, header.split())
136 | binary_len = np.dtype('float32').itemsize * layer1_size
137 | for line in range(vocab_size):
138 | word = []
139 | while True:
140 | ch = f.read(1).decode('latin-1')
141 | if ch == ' ':
142 | word = ''.join(word)
143 | break
144 | if ch != '\n':
145 | word.append(ch)
146 | idx = text_vocab_processor.vocabulary_.get(word)
147 | if idx != 0:
148 | initW[idx] = np.frombuffer(f.read(binary_len), dtype='float32')
149 | else:
150 | f.read(binary_len)
151 | sess.run(rcnn.W_text.assign(initW))
152 | print("Successfully loaded the pre-trained word2vec model!\n")
153 |
154 | # Generate batches
155 | batches = data_helpers.batch_iter(
156 | list(zip(x_train, y_train)), FLAGS.batch_size, FLAGS.num_epochs)
157 | # Training loop. For each batch...
158 | for batch in batches:
159 | x_batch, y_batch = zip(*batch)
160 | # Train
161 | feed_dict = {
162 | rcnn.input_text: x_batch,
163 | rcnn.input_y: y_batch,
164 | rcnn.dropout_keep_prob: FLAGS.dropout_keep_prob
165 | }
166 | _, step, summaries, loss, accuracy = sess.run(
167 | [train_op, global_step, train_summary_op, rcnn.loss, rcnn.accuracy], feed_dict)
168 | train_summary_writer.add_summary(summaries, step)
169 |
170 | # Training log display
171 | if step % FLAGS.display_every == 0:
172 | time_str = datetime.datetime.now().isoformat()
173 | print("{}: step {}, loss {:g}, acc {:g}".format(time_str, step, loss, accuracy))
174 |
175 | # Evaluation
176 | if step % FLAGS.evaluate_every == 0:
177 | print("\nEvaluation:")
178 | feed_dict_dev = {
179 | rcnn.input_text: x_dev,
180 | rcnn.input_y: y_dev,
181 | rcnn.dropout_keep_prob: 1.0
182 | }
183 | summaries_dev, loss, accuracy = sess.run(
184 | [dev_summary_op, rcnn.loss, rcnn.accuracy], feed_dict_dev)
185 | dev_summary_writer.add_summary(summaries_dev, step)
186 |
187 | time_str = datetime.datetime.now().isoformat()
188 | print("{}: step {}, loss {:g}, acc {:g}\n".format(time_str, step, loss, accuracy))
189 |
190 | # Model checkpoint
191 | if step % FLAGS.checkpoint_every == 0:
192 | path = saver.save(sess, checkpoint_prefix, global_step=step)
193 | print("Saved model checkpoint to {}\n".format(path))
194 |
195 |
196 | def main(_):
197 | train()
198 |
199 |
200 | if __name__ == "__main__":
201 | tf.app.run()
202 |
--------------------------------------------------------------------------------