├── .gitignore
├── README.md
├── config
│   ├── check_tiny.yml
│   ├── kaggle_movie_review.yml
│   └── rt-polarity.yml
├── data
│   ├── kaggle_movie_reviews
│   │   ├── test.tsv.zip
│   │   └── train.tsv.zip
│   ├── rt-polaritydata
│   │   ├── rt-polarity.neg
│   │   └── rt-polarity.pos
│   └── tiny_processed_data
│       ├── test_X_ids
│       ├── test_y
│       ├── train_X_ids
│       ├── train_y
│       └── vocab
├── data_loader.py
├── hook.py
├── images
│   ├── category.png
│   ├── figure-1.png
│   ├── kaggle-loss_and_accuracy.jpg
│   └── rt-polarity_loss_and_accuracy.jpeg
├── main.py
├── model.py
├── predict.py
├── requirements.txt
├── scripts
│   └── prepare_kaggle_movie_reviews.sh
├── text_cnn
│   └── __init__.py
└── utils.py
/.gitignore:
--------------------------------------------------------------------------------
1 | __pycache__/
2 | logs/
3 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # text-cnn [hb-research](https://github.com/hb-research)
2 |
3 | This code implements the models from [Convolutional Neural Networks for Sentence Classification](http://arxiv.org/abs/1408.5882) (Kim, 2014).
4 |
5 | - Figure 1: Illustration of a CNN architecture for sentence classification
6 |
7 | ![figure-1](images/figure-1.png)
8 |
9 |
10 | ## Requirements
11 |
12 | - Python 3.6
13 | - TensorFlow 1.4
14 | - [hb-config](https://github.com/hb-research/hb-config) (Singleton Config)
15 | - tqdm
16 | - requests
17 | - [Slack Incoming Webhook URL](https://my.slack.com/services/new/incoming-webhook/)
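
Note that TensorFlow itself is not listed in `requirements.txt`; install a matching version separately, for example (CPU build assumed):

```
pip install tensorflow==1.4.0
```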
18 |
19 | ## Project Structure
20 |
21 | Project initialized with [hb-base](https://github.com/hb-research/hb-base).
22 |
23 |     .
24 |     ├── config                  # config files (.yml, .json), used with hb-config
25 |     ├── data                    # dataset path
26 |     ├── notebooks               # prototyping with numpy or tf.InteractiveSession
27 |     ├── scripts                 # download or prepare the dataset using shell scripts
28 |     ├── text_cnn                # text-cnn architecture graph (from input to logits)
29 |     │   └── __init__.py         # graph logic
30 |     ├── data_loader.py          # raw_data -> processed_data -> generate_batch (using tf.data.Dataset)
31 |     ├── hook.py                 # training or evaluation hooks (e.g. print_variables)
32 |     ├── main.py                 # define experiment_fn
33 |     ├── model.py                # define EstimatorSpec
34 |     └── predict.py              # test the trained model
35 |
36 | Reference: [hb-config](https://github.com/hb-research/hb-config), [Dataset](https://www.tensorflow.org/api_docs/python/tf/data/Dataset#from_generator), [Experiment](https://www.tensorflow.org/api_docs/python/tf/contrib/learn/Experiment), [EstimatorSpec](https://www.tensorflow.org/api_docs/python/tf/estimator/EstimatorSpec)
37 |
38 | - Dataset : [rt-polarity](https://github.com/yoonkim/CNN_sentence), [Sentiment Analysis on Movie Reviews](https://www.kaggle.com/c/sentiment-analysis-on-movie-reviews/data)
39 |
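How these pieces fit together, condensed from `main.py` (hooks and the debug option omitted, so this is a trimmed sketch rather than the full file):

```python
import tensorflow as tf
from hbconfig import Config

import data_loader
from model import Model


def experiment_fn(run_config, params):
    # Estimator wraps model_fn, which model.py turns into an EstimatorSpec.
    estimator = tf.estimator.Estimator(
        model_fn=Model().model_fn,
        model_dir=Config.train.model_dir,
        params=params,
        config=run_config)

    # The vocabulary size is needed to build the embedding matrix.
    vocab = data_loader.load_vocab("vocab")
    Config.data.vocab_size = len(vocab)

    # data_loader builds (features, labels) input_fns backed by tf.data.Dataset.
    train_data, test_data = data_loader.make_train_and_test_set()
    train_input_fn, train_hook = data_loader.make_batch(
        train_data, batch_size=Config.model.batch_size, scope="train")
    test_input_fn, test_hook = data_loader.make_batch(
        test_data, batch_size=Config.model.batch_size, scope="test")

    # Experiment interleaves training and evaluation according to --mode.
    return tf.contrib.learn.Experiment(
        estimator=estimator,
        train_input_fn=train_input_fn,
        eval_input_fn=test_input_fn,
        train_steps=Config.train.train_steps,
        min_eval_frequency=Config.train.min_eval_frequency,
        train_monitors=[train_hook],
        eval_hooks=[test_hook])
```
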
40 | ## Todo
41 |
42 | - apply embed_type (see the sketch below)
43 |   - CNN-rand
44 |   - CNN-static
45 |   - CNN-nonstatic
46 |   - CNN-multichannel
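
The four `embed_type` variants from the paper differ only in how the embedding matrix is initialized and whether it is trainable. As a hypothetical sketch (not implemented in this repo), a `static` embedding in `text_cnn/__init__.py` could look like this, where `pretrained_vectors` is assumed to be a `[vocab_size, embed_dim]` numpy array loaded from `Config.model.pretrained_embed`:

```python
# Hypothetical sketch only; `pretrained_vectors` is an assumed numpy array.
embedding = tf.get_variable(
    "embedding-static",
    initializer=tf.constant(pretrained_vectors, dtype=tf.float32),
    trainable=False)  # static: vectors are frozen during training
```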
47 |
48 | ## Config
49 |
50 | example: kaggle\_movie\_review.yml
51 |
52 | ```yml
53 | data:
54 |   type: 'kaggle_movie_review'
55 |   base_path: 'data/'
56 |   raw_data_path: 'kaggle_movie_reviews/'
57 |   processed_path: 'kaggle_processed_data'
58 |   testset_size: 25000
59 |   num_classes: 5
60 |   PAD_ID: 0
61 |
62 | model:
63 |   batch_size: 64
64 |   embed_type: 'rand'  # (rand, static, non-static, multichannel)
65 |   pretrained_embed: ""
66 |   embed_dim: 300
67 |   num_filters: 256
68 |   filter_sizes:
69 |     - 2
70 |     - 3
71 |     - 4
72 |     - 5
73 |   dropout: 0.5
74 |
75 | train:
76 |   learning_rate: 0.00005
77 |
78 |   train_steps: 100000
79 |   model_dir: 'logs/kaggle_movie_review'
80 |
81 |   save_checkpoints_steps: 1000
82 |   loss_hook_n_iter: 1000
83 |   check_hook_n_iter: 1000
84 |   min_eval_frequency: 1000
85 |
86 | slack:
87 |   webhook_url: ""  # after training, send a notification via the Slack webhook
88 | ```
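
These values are read at runtime through hb-config's singleton `Config` object; a minimal sketch of the access pattern used throughout this project:

```python
from hbconfig import Config

Config("kaggle_movie_review")     # loads config/kaggle_movie_review.yml
print(Config.model.batch_size)    # 64
print(Config.model.filter_sizes)  # [2, 3, 4, 5]
print(Config.train.model_dir)     # 'logs/kaggle_movie_review'
```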
89 |
90 |
91 | ## Usage
92 |
93 | Install requirements.
94 |
95 | ```pip install -r requirements.txt```
96 |
97 | Then, prepare the dataset and train the model:
98 |
99 | ```
100 | sh scripts/prepare_kaggle_movie_reviews.sh
101 | python main.py --config kaggle_movie_review --mode train_and_evaluate
102 | ```
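
The `rt-polarity` raw data is already included under `data/rt-polaritydata`, so it only needs the preprocessing step in `data_loader.py` before training:

```
python data_loader.py --config rt-polarity
python main.py --config rt-polarity --mode train_and_evaluate
```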
103 |
104 | After training, you can classify any sentence you type using `predict.py`.
105 |
106 | ```python predict.py --config rt-polarity```
107 |
108 | Prediction example:
109 |
110 | ```
111 | python predict.py --config rt-polarity
112 | Setting max_seq_length to Config : 62
113 | load vocab ...
114 | Typing anything :)
115 |
116 | > good
117 | 1
118 | > bad
119 | 0
120 | ```
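
Under the hood, `predict.py` tokenizes the input with the saved vocab, pads it to `max_seq_length`, and runs a single example through the Estimator. Roughly, using the helpers defined in `predict.py` and `data_loader.py`:

```python
vocab = data_loader.load_vocab("vocab")                     # word -> id mapping
ids = data_loader.sentence2id(vocab, sentence)              # tokens -> vocab ids
X = np.reshape(data_loader._pad_input(ids, Config.data.max_seq_length),
               (1, Config.data.max_seq_length))             # pad to fixed length

predict_input_fn = tf.estimator.inputs.numpy_input_fn(
    x={"input_data": X}, num_epochs=1, shuffle=False)
# `estimator` is built by _make_estimator() in predict.py
label = next(estimator.predict(input_fn=predict_input_fn))["prediction"]
```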
121 |
122 | ### Experiments modes
123 |
124 | :white_check_mark: : Working
125 | :white_medium_small_square: : Not tested yet.
126 |
127 | - :white_check_mark: `evaluate` : Evaluate on the evaluation data.
128 | - :white_medium_small_square: `extend_train_hooks` : Extends the hooks for training.
129 | - :white_medium_small_square: `reset_export_strategies` : Resets the export strategies with the new_export_strategies.
130 | - :white_medium_small_square: `run_std_server` : Starts a TensorFlow server and joins the serving thread.
131 | - :white_medium_small_square: `test` : Tests training, evaluating and exporting the estimator for a single step.
132 | - :white_check_mark: `train` : Fit the estimator using the training data.
133 | - :white_check_mark: `train_and_evaluate` : Interleaves training and evaluation.
134 |
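For a quick end-to-end sanity check of these modes, the bundled `check_tiny` config runs against the toy data in `data/tiny_processed_data`:

```
python main.py --config check_tiny --mode train_and_evaluate
```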
135 |
136 | ### Tensorboard
137 |
138 | ```tensorboard --logdir logs```
139 |
140 | - Category Color
141 |
142 | ![category](images/category.png)
143 |
144 | - rt-polarity (binary classification)
145 |
146 | ![rt-polarity loss and accuracy](images/rt-polarity_loss_and_accuracy.jpeg)
147 |
148 | - kaggle_movie_review (multiclass classification)
149 |
150 | ![kaggle_movie_review loss and accuracy](images/kaggle-loss_and_accuracy.jpg)
151 |
152 |
153 | ## Reference
154 |
155 | - [Implementing a CNN for Text Classification in TensorFlow](http://www.wildml.com/2015/12/implementing-a-cnn-for-text-classification-in-tensorflow/) by Denny Britz
156 | - [Paper - Convolutional Neural Networks for Sentence Classification](http://arxiv.org/abs/1408.5882) (2014) by Y. Kim
157 | - [Paper - A Sensitivity Analysis of (and Practitioners' Guide to) Convolutional Neural Networks for Sentence Classification](https://arxiv.org/pdf/1510.03820.pdf) (2015) by Y. Zhang and B. Wallace
158 |
--------------------------------------------------------------------------------
/config/check_tiny.yml:
--------------------------------------------------------------------------------
1 | data:
2 |   base_path: 'data/'
3 |   processed_path: 'tiny_processed_data'
4 |   testset_size: 2
5 |   num_classes: 2
6 |   PAD_ID: 0
7 |
8 | model:
9 |   batch_size: 3
10 |   embed_type: 'rand'
11 |   pretrained_embed: ""
12 |   embed_dim: 32
13 |   num_filters: 16
14 |   filter_sizes:
15 |     - 2
16 |     - 3
17 |     - 4
18 |   dropout: 0.5
19 |
20 | train:
21 |   learning_rate: 0.001
22 |
23 |   train_steps: 100
24 |   model_dir: 'logs/rt-check_tiny'
25 |
26 |   save_checkpoints_steps: 100
27 |   loss_hook_n_iter: 100
28 |   check_hook_n_iter: 10
29 |   min_eval_frequency: 10
30 |
31 |   print_verbose: True
32 |   debug: False
33 |
34 | slack:
35 |   webhook_url: ""
36 |
--------------------------------------------------------------------------------
/config/kaggle_movie_review.yml:
--------------------------------------------------------------------------------
1 | data:
2 |   type: 'kaggle_movie_review'
3 |   base_path: 'data/'
4 |   raw_data_path: 'kaggle_movie_reviews/'
5 |   processed_path: 'kaggle_processed_data'
6 |   testset_size: 25000
7 |   num_classes: 5
8 |   PAD_ID: 0
9 |
10 | model:
11 |   batch_size: 64
12 |   embed_type: 'rand'  # (rand, static, non-static, multichannel)
13 |   pretrained_embed: ""
14 |   embed_dim: 300
15 |   num_filters: 256
16 |   filter_sizes:
17 |     - 2
18 |     - 3
19 |     - 4
20 |     - 5
21 |   dropout: 0.5
22 |
23 | train:
24 |   learning_rate: 0.00005
25 |
26 |   train_steps: 100000
27 |   model_dir: 'logs/kaggle_movie_review'
28 |
29 |   save_checkpoints_steps: 1000
30 |   loss_hook_n_iter: 1000
31 |   check_hook_n_iter: 1000
32 |   min_eval_frequency: 1000
33 |
34 |   print_verbose: True
35 |   debug: False
36 |
37 | slack:
38 |   webhook_url: ""
39 |
--------------------------------------------------------------------------------
/config/rt-polarity.yml:
--------------------------------------------------------------------------------
1 | data:
2 |   type: 'rt-polarity'
3 |   base_path: 'data/'
4 |   raw_data_path: 'rt-polaritydata/'
5 |   processed_path: 'rt-polarity_processed_data'
6 |   testset_size: 2000
7 |   num_classes: 2
8 |   PAD_ID: 0
9 |
10 | model:
11 |   batch_size: 64
12 |   embed_type: 'rand'  # (rand, static, non-static, multichannel)
13 |   pretrained_embed: ""
14 |   embed_dim: 300
15 |   num_filters: 256
16 |   filter_sizes:
17 |     - 2
18 |     - 3
19 |     - 4
20 |     - 5
21 |   dropout: 0.5
22 |
23 | train:
24 |   learning_rate: 0.00001
25 |
26 |   train_steps: 20000
27 |   model_dir: 'logs/rt-polarity'
28 |
29 |   save_checkpoints_steps: 100
30 |   loss_hook_n_iter: 100
31 |   check_hook_n_iter: 100
32 |   min_eval_frequency: 100
33 |
34 |   print_verbose: True
35 |   debug: False
36 |
37 | slack:
38 |   webhook_url: ""
39 |
--------------------------------------------------------------------------------
/data/kaggle_movie_reviews/test.tsv.zip:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DongjunLee/text-cnn-tensorflow/1e9ebb02dd806550e717f2ace680117feaa532bf/data/kaggle_movie_reviews/test.tsv.zip
--------------------------------------------------------------------------------
/data/kaggle_movie_reviews/train.tsv.zip:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DongjunLee/text-cnn-tensorflow/1e9ebb02dd806550e717f2ace680117feaa532bf/data/kaggle_movie_reviews/train.tsv.zip
--------------------------------------------------------------------------------
/data/tiny_processed_data/test_X_ids:
--------------------------------------------------------------------------------
1 | 1 2 3 4 5 6 7 8 9 10
2 | 11 12 13 1 4 15 16
3 | 17 18 19 20
4 | 21 22 23 24 25 26
--------------------------------------------------------------------------------
/data/tiny_processed_data/test_y:
--------------------------------------------------------------------------------
1 | 1
2 | 0
3 | 1
4 | 0
--------------------------------------------------------------------------------
/data/tiny_processed_data/train_X_ids:
--------------------------------------------------------------------------------
1 | 1 2 3 4 5 6 7 8 9 10
2 | 11 12 13 1 4 15 16
3 | 17 18 19 20
4 | 21 22 23 24 25 26
--------------------------------------------------------------------------------
/data/tiny_processed_data/train_y:
--------------------------------------------------------------------------------
1 | 1
2 | 0
3 | 1
4 | 0
--------------------------------------------------------------------------------
/data/tiny_processed_data/vocab:
--------------------------------------------------------------------------------
1 | PAD
2 | v1
3 | v2
4 | v3
5 | v4
6 | v5
7 | v6
8 | v7
9 | v8
10 | v9
11 | v10
12 | v11
13 | v12
14 | v13
15 | v14
16 | v15
17 | v16
18 | v17
19 | v18
20 | v19
21 | v20
22 | v21
23 | v22
24 | v23
25 | v24
26 | v25
27 | v26
28 | v27
29 | v28
30 | v29
31 | v30
--------------------------------------------------------------------------------
/data_loader.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | from __future__ import print_function
3 |
4 | import argparse
5 | import csv
6 | import os
7 | import random
8 | import re
9 |
10 | import numpy as np
11 | from hbconfig import Config
12 | import tensorflow as tf
13 | from tqdm import tqdm
14 |
15 |
16 | def clean_str(string):
17 | """
18 | Tokenization/string cleaning for all datasets except for SST.
19 | Original taken from https://github.com/yoonkim/CNN_sentence/blob/master/process_data.py
20 | """
21 | string = string.decode('utf-8')
22 | string = re.sub(r"[^A-Za-z0-9(),!?\'\`]", " ", string)
23 | string = re.sub(r"\'s", " \'s", string)
24 | string = re.sub(r"\'ve", " \'ve", string)
25 | string = re.sub(r"n\'t", " n\'t", string)
26 | string = re.sub(r"\'re", " \'re", string)
27 | string = re.sub(r"\'d", " \'d", string)
28 | string = re.sub(r"\'ll", " \'ll", string)
29 | string = re.sub(r",", " , ", string)
30 | string = re.sub(r"!", " ! ", string)
31 | string = re.sub(r"\(", " \( ", string)
32 | string = re.sub(r"\)", " \) ", string)
33 | string = re.sub(r"\?", " \? ", string)
34 | string = re.sub(r"\s{2,}", " ", string)
35 | return string.strip().lower()
36 |
37 |
38 | def load_data_and_labels(positive_data_file, negative_data_file):
39 | """
40 | Loads MR polarity data from files, splits the data into words and generates labels.
41 | Returns split sentences and labels.
42 | """
43 | # Load data from files
44 | positive_examples = list(open(positive_data_file, "rb").readlines())
45 | positive_examples = [s.strip() for s in positive_examples]
46 | negative_examples = list(open(negative_data_file, "rb").readlines())
47 | negative_examples = [s.strip() for s in negative_examples]
48 | # Split by words
49 | x_text = positive_examples + negative_examples
50 | x_text = [clean_str(sent) for sent in x_text]
51 | # Generate labels
52 | positive_labels = ['1' for _ in positive_examples]
53 | negative_labels = ['0' for _ in negative_examples]
54 | y = positive_labels + negative_labels
55 | return x_text, y
56 |
57 |
58 | def prepare_raw_data():
59 | print('Preparing raw data into train set and test set ...')
60 | raw_data_path = os.path.join(Config.data.base_path, Config.data.raw_data_path)
61 |
62 | data_type = Config.data.type
63 | if data_type == "kaggle_movie_review":
64 | train_path = os.path.join(raw_data_path, 'train.tsv')
65 | train_reader = csv.reader(open(train_path), delimiter="\t")
66 |
67 | prepare_dataset(dataset=list(train_reader))
68 |
69 | elif data_type == "rt-polarity":
70 | pos_path = os.path.join(Config.data.base_path, Config.data.raw_data_path, "rt-polarity.pos")
71 | neg_path = os.path.join(Config.data.base_path, Config.data.raw_data_path, "rt-polarity.neg")
72 | x_text, y = load_data_and_labels(pos_path, neg_path)
73 |
74 | prepare_dataset(x_text=x_text, y=y)
75 |
76 |
77 | def prepare_dataset(dataset=None, x_text=None, y=None):
78 | make_dir(os.path.join(Config.data.base_path, Config.data.processed_path))
79 |
80 | filenames = ['train_X', 'train_y', 'test_X', 'test_y']
81 | files = []
82 | for filename in filenames:
83 | files.append(open(os.path.join(Config.data.base_path, Config.data.processed_path, filename), 'wb'))
84 |
85 | if dataset is not None:
86 |
87 | print("Total data length : ", len(dataset))
88 | test_ids = random.sample([i for i in range(len(dataset))], Config.data.testset_size)
89 |
90 | for i in tqdm(range(len(dataset))):
91 | if i == 0:
92 | continue
93 |
94 | data = dataset[i]
95 | X, y = data[2], data[3]
96 |
97 | if i in test_ids:
98 | files[2].write((X + "\n").encode('utf-8'))
99 | files[3].write((y + '\n').encode('utf-8'))
100 | else:
101 | files[0].write((X + '\n').encode('utf-8'))
102 | files[1].write((y + '\n').encode('utf-8'))
103 |
104 | else:
105 |
106 | print("Total data length : ", len(y))
107 | test_ids = random.sample([i for i in range(len(y))], Config.data.testset_size)
108 |
109 | for i in tqdm(range(len(y))):
110 | if i in test_ids:
111 | files[2].write((x_text[i] + "\n").encode('utf-8'))
112 | files[3].write((y[i] + '\n').encode('utf-8'))
113 | else:
114 | files[0].write((x_text[i] + '\n').encode('utf-8'))
115 | files[1].write((y[i] + '\n').encode('utf-8'))
116 |
117 | for file in files:
118 | file.close()
119 |
120 |
121 | def make_dir(path):
122 | """ Create a directory if there isn't one already. """
123 | try:
124 | os.mkdir(path)
125 | except OSError:
126 | pass
127 |
128 |
129 | def basic_tokenizer(line, normalize_digits=True):
130 | """ A basic tokenizer to tokenize text into tokens.
131 | Feel free to change this to suit your need. """
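# Example (hypothetical input): basic_tokenizer("Good 123 movie!") -> ['good', '###', 'movie', '!']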
132 | line = re.sub('<u>', '', line)
133 | line = re.sub('</u>', '', line)
134 | line = re.sub('\[', '', line)
135 | line = re.sub('\]', '', line)
136 | words = []
137 | _WORD_SPLIT = re.compile("([.,!?\"'-<>:;)(])")
138 | _DIGIT_RE = re.compile(r"\d")
139 | for fragment in line.strip().lower().split():
140 | for token in re.split(_WORD_SPLIT, fragment):
141 | if not token:
142 | continue
143 | if normalize_digits:
144 | token = re.sub(_DIGIT_RE, '#', token)
145 | words.append(token)
146 | return words
147 |
148 |
149 | def build_vocab(train_fname, test_fname, normalize_digits=True):
150 | vocab = {}
151 | def count_vocab(fname):
152 | with open(fname, 'rb') as f:
153 | for line in f.readlines():
154 | line = line.decode('utf-8')
155 | for token in basic_tokenizer(line):
156 | if not token in vocab:
157 | vocab[token] = 0
158 | vocab[token] += 1
159 |
160 | train_path = os.path.join(Config.data.base_path, Config.data.processed_path, train_fname)
161 | test_path = os.path.join(Config.data.base_path, Config.data.processed_path, test_fname)
162 |
163 | count_vocab(train_path)
164 | count_vocab(test_path)
165 |
166 | sorted_vocab = sorted(vocab, key=vocab.get, reverse=True)
167 |
168 | dest_path = os.path.join(Config.data.base_path, Config.data.processed_path, 'vocab')
169 | with open(dest_path, 'wb') as f:
170 | f.write(('<pad>' + '\n').encode('utf-8'))
171 | index = 1
172 | for word in sorted_vocab:
173 | f.write((word + '\n').encode('utf-8'))
174 | index += 1
175 |
176 |
177 | def load_vocab(vocab_fname):
178 | print("load vocab ...")
179 | with open(os.path.join(Config.data.base_path, Config.data.processed_path, vocab_fname), 'rb') as f:
180 | words = f.read().decode('utf-8').splitlines()
181 | return {words[i]: i for i in range(len(words))}
182 |
183 |
184 | def sentence2id(vocab, line):
185 | return [vocab.get(token, vocab['<pad>']) for token in basic_tokenizer(line)]
186 |
187 |
188 | def token2id(data):
189 | """ Convert all the tokens in the data into their corresponding
190 | index in the vocabulary. """
191 | vocab_path = 'vocab'
192 | in_path = data
193 | out_path = data + '_ids'
194 |
195 | vocab = load_vocab(vocab_path)
196 | in_file = open(os.path.join(Config.data.base_path, Config.data.processed_path, in_path), 'rb')
197 | out_file = open(os.path.join(Config.data.base_path, Config.data.processed_path, out_path), 'wb')
198 |
199 | lines = in_file.read().decode('utf-8').splitlines()
200 | for line in lines:
201 | ids = []
202 | sentence_ids = sentence2id(vocab, line)
203 | ids.extend(sentence_ids)
204 |
205 | out_file.write(b' '.join(str(id_).encode('utf-8') for id_ in ids) + b'\n')
206 |
207 |
208 | def process_data():
209 | print('Preparing data to be model-ready ...')
210 |
211 | build_vocab('train_X', 'test_X')
212 |
213 | token2id('train_X')
214 | token2id('test_X')
215 |
216 |
217 | def make_train_and_test_set(shuffle=True):
218 | print("make Training data and Test data Start....")
219 |
220 | if Config.data.get('max_seq_length', None) is None:
221 | set_max_seq_length(['train_X_ids', 'test_X_ids'])
222 |
223 | train_X, train_y = load_data('train_X_ids', 'train_y')
224 | test_X, test_y = load_data('test_X_ids', 'test_y')
225 |
226 | assert len(train_X) == len(train_y)
227 | assert len(test_X) == len(test_y)
228 |
229 | print(f"train data count : {len(train_y)}")
230 | print(f"test data count : {len(test_y)}")
231 |
232 | if shuffle:
233 | print("shuffle dataset ...")
234 | train_p = np.random.permutation(len(train_y))
235 | test_p = np.random.permutation(len(test_y))
236 |
237 | return ((train_X[train_p], train_y[train_p]),
238 | (test_X[test_p], test_y[test_p]))
239 | else:
240 | return ((train_X, train_y),
241 | (test_X, test_y))
242 |
243 |
244 | def load_data(X_fname, y_fname):
245 | X_input_data = open(os.path.join(Config.data.base_path, Config.data.processed_path, X_fname), 'r')
246 | y_input_data = open(os.path.join(Config.data.base_path, Config.data.processed_path, y_fname), 'r')
247 |
248 | X_data, y_data = [], []
249 | for X_line, y_line in zip(X_input_data.readlines(), y_input_data.readlines()):
250 | X_ids = [int(id_) for id_ in X_line.split()]
251 | y_id = int(y_line)
252 |
253 | if len(X_ids) == 0 or y_id >= Config.data.num_classes:
254 | continue
255 |
256 | if len(X_ids) <= Config.data.max_seq_length:
257 | X_data.append(_pad_input(X_ids, Config.data.max_seq_length))
258 |
259 | y_one_hot = np.zeros(Config.data.num_classes)
260 | y_one_hot[int(y_line)] = 1
261 | y_data.append(y_one_hot)
262 |
263 | print(f"load data from {X_fname}, {y_fname}...")
264 | return np.array(X_data, dtype=np.int32), np.array(y_data, dtype=np.int32)
265 |
266 |
267 | def _pad_input(input_, size):
268 | return input_ + [0] * (size - len(input_))
269 |
270 |
271 | def set_max_seq_length(dataset_fnames):
272 |
273 | max_seq_length = Config.data.get('max_seq_length', 10)
274 |
275 | for fname in dataset_fnames:
276 | input_data = open(os.path.join(Config.data.base_path, Config.data.processed_path, fname), 'r')
277 |
278 | for line in input_data.readlines():
279 | ids = [int(id_) for id_ in line.split()]
280 | seq_length = len(ids)
281 |
282 | if seq_length > max_seq_length:
283 | max_seq_length = seq_length
284 |
285 | Config.data.max_seq_length = max_seq_length
286 | print(f"Setting max_seq_length to Config : {max_seq_length}")
287 |
288 |
289 | def make_batch(data, buffer_size=10000, batch_size=64, scope="train"):
290 |
291 | class IteratorInitializerHook(tf.train.SessionRunHook):
292 | """Hook to initialise data iterator after Session is created."""
293 |
294 | def __init__(self):
295 | super(IteratorInitializerHook, self).__init__()
296 | self.iterator_initializer_func = None
297 |
298 | def after_create_session(self, session, coord):
299 | """Initialise the iterator after the session has been created."""
300 | self.iterator_initializer_func(session)
301 |
302 |
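# Note: the data is fed through placeholders and an initializable iterator
# (initialized by the hook above once the Session exists) instead of constant
# tensors, so large numpy arrays are not baked into the graph definition.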
303 | def get_inputs():
304 |
305 | iterator_initializer_hook = IteratorInitializerHook()
306 |
307 | def train_inputs():
308 | with tf.name_scope(scope):
309 |
310 | X, y = data
311 |
312 | # Define placeholders
313 | input_placeholder = tf.placeholder(
314 | tf.int32, [None, Config.data.max_seq_length])
315 | output_placeholder = tf.placeholder(
316 | tf.int32, [None, Config.data.num_classes])
317 |
318 | # Build dataset iterator
319 | dataset = tf.data.Dataset.from_tensor_slices(
320 | (input_placeholder, output_placeholder))
321 |
322 | if scope == "train":
323 | dataset = dataset.repeat(None) # Infinite iterations
324 | else:
325 | dataset = dataset.repeat(1) # 1 Epoch
326 | # dataset = dataset.shuffle(buffer_size=buffer_size)
327 | dataset = dataset.batch(batch_size)
328 |
329 | iterator = dataset.make_initializable_iterator()
330 | next_X, next_y = iterator.get_next()
331 |
332 | tf.identity(next_X[0], 'input_0')
333 | tf.identity(next_y[0], 'target_0')
334 |
335 | # Set runhook to initialize iterator
336 | iterator_initializer_hook.iterator_initializer_func = \
337 | lambda sess: sess.run(
338 | iterator.initializer,
339 | feed_dict={input_placeholder: X,
340 | output_placeholder: y})
341 |
342 | # Return batched (features, labels)
343 | return next_X, next_y
344 |
345 | # Return function and hook
346 | return train_inputs, iterator_initializer_hook
347 |
348 | return get_inputs()
349 |
350 | if __name__ == '__main__':
351 |
352 | parser = argparse.ArgumentParser(
353 | formatter_class=argparse.ArgumentDefaultsHelpFormatter)
354 | parser.add_argument('--config', type=str, default='config',
355 | help='config file name')
356 | args = parser.parse_args()
357 |
358 | Config(args.config)
359 |
360 | prepare_raw_data()
361 | process_data()
362 |
--------------------------------------------------------------------------------
/hook.py:
--------------------------------------------------------------------------------
1 |
2 | from hbconfig import Config
3 | import numpy as np
4 | import tensorflow as tf
5 |
6 |
7 |
8 |
9 | def print_variables(variables, rev_vocab=None, every_n_iter=100):
10 |
11 | return tf.train.LoggingTensorHook(
12 | variables,
13 | every_n_iter=every_n_iter,
14 | formatter=format_variable(variables, rev_vocab=rev_vocab))
15 |
16 |
17 | def format_variable(keys, rev_vocab=None):
18 |
19 | def to_str(sequence):
20 | if type(sequence) == np.ndarray:
21 | tokens = [
22 | rev_vocab.get(x, '') for x in sequence if x != Config.data.PAD_ID]
23 | return ' '.join(tokens)
24 | else:
25 | x = int(sequence)
26 | return rev_vocab[x]
27 |
28 | def format(values):
29 | result = []
30 | for key in keys:
31 | if rev_vocab is None:
32 | result.append(f"{key} = {values[key]}")
33 | else:
34 | result.append(f"{key} = {to_str(values[key])}")
35 |
36 | try:
37 | return '\n - '.join(result)
38 | except:
39 | pass
40 |
41 | return format
42 |
43 |
44 | def get_rev_vocab(vocab):
45 | if vocab is None:
46 | return None
47 | return {idx: key for key, idx in vocab.items()}
48 |
49 |
50 | def print_target(variables, every_n_iter=100):
51 |
52 | return tf.train.LoggingTensorHook(
53 | variables,
54 | every_n_iter=every_n_iter,
55 | formatter=print_pos_or_neg(variables))
56 |
57 |
58 | def print_pos_or_neg(keys):
59 |
60 | def format(values):
61 | result = []
62 | for key in keys:
63 | if type(values[key]) == np.ndarray:
64 | value = max(values[key])
65 | else:
66 | value = values[key]
67 | result.append(f"{key} = {value}")
68 |
69 | try:
70 | return ', '.join(result)
71 | except:
72 | pass
73 |
74 | return format
75 |
--------------------------------------------------------------------------------
/images/category.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DongjunLee/text-cnn-tensorflow/1e9ebb02dd806550e717f2ace680117feaa532bf/images/category.png
--------------------------------------------------------------------------------
/images/figure-1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DongjunLee/text-cnn-tensorflow/1e9ebb02dd806550e717f2ace680117feaa532bf/images/figure-1.png
--------------------------------------------------------------------------------
/images/kaggle-loss_and_accuracy.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DongjunLee/text-cnn-tensorflow/1e9ebb02dd806550e717f2ace680117feaa532bf/images/kaggle-loss_and_accuracy.jpg
--------------------------------------------------------------------------------
/images/rt-polarity_loss_and_accuracy.jpeg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DongjunLee/text-cnn-tensorflow/1e9ebb02dd806550e717f2ace680117feaa532bf/images/rt-polarity_loss_and_accuracy.jpeg
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | import argparse
4 | import atexit
5 | import logging
6 |
7 | from hbconfig import Config
8 | import tensorflow as tf
9 | from tensorflow.python import debug as tf_debug
10 | import data_loader
11 | import hook
12 | from model import Model
13 | import utils
14 |
15 |
16 | def experiment_fn(run_config, params):
17 |
18 | model = Model()
19 | estimator = tf.estimator.Estimator(
20 | model_fn=model.model_fn,
21 | model_dir=Config.train.model_dir,
22 | params=params,
23 | config=run_config)
24 |
25 | vocab = data_loader.load_vocab("vocab")
26 | Config.data.vocab_size = len(vocab)
27 |
28 | train_data, test_data = data_loader.make_train_and_test_set()
29 | train_input_fn, train_input_hook = data_loader.make_batch(train_data,
30 | batch_size=Config.model.batch_size,
31 | scope="train")
32 | test_input_fn, test_input_hook = data_loader.make_batch(test_data,
33 | batch_size=Config.model.batch_size,
34 | scope="test")
35 |
36 | train_hooks = [train_input_hook]
37 | if Config.train.print_verbose:
38 | train_hooks.append(hook.print_variables(
39 | variables=['train/input_0'],
40 | rev_vocab=get_rev_vocab(vocab),
41 | every_n_iter=Config.train.check_hook_n_iter))
42 | train_hooks.append(hook.print_target(
43 | variables=['train/target_0', 'train/pred_0'],
44 | every_n_iter=Config.train.check_hook_n_iter))
45 | if Config.train.debug:
46 | train_hooks.append(tf_debug.LocalCLIDebugHook())
47 |
48 | eval_hooks = [test_input_hook]
49 | if Config.train.debug:
50 | eval_hooks.append(tf_debug.LocalCLIDebugHook())
51 |
52 | experiment = tf.contrib.learn.Experiment(
53 | estimator=estimator,
54 | train_input_fn=train_input_fn,
55 | eval_input_fn=test_input_fn,
56 | train_steps=Config.train.train_steps,
57 | min_eval_frequency=Config.train.min_eval_frequency,
58 | train_monitors=train_hooks,
59 | eval_hooks=eval_hooks
60 | )
61 | return experiment
62 |
63 |
64 | def get_rev_vocab(vocab):
65 | if vocab is None:
66 | return None
67 | return {idx: key for key, idx in vocab.items()}
68 |
69 |
70 | def main(mode):
71 | params = tf.contrib.training.HParams(**Config.model.to_dict())
72 |
73 | run_config = tf.contrib.learn.RunConfig(
74 | model_dir=Config.train.model_dir,
75 | save_checkpoints_steps=Config.train.save_checkpoints_steps)
76 |
77 | tf.contrib.learn.learn_runner.run(
78 | experiment_fn=experiment_fn,
79 | run_config=run_config,
80 | schedule=mode,
81 | hparams=params
82 | )
83 |
84 |
85 | if __name__ == '__main__':
86 |
87 | parser = argparse.ArgumentParser(
88 | formatter_class=argparse.ArgumentDefaultsHelpFormatter)
89 | parser.add_argument('--config', type=str, default='config',
90 | help='config file name')
91 | parser.add_argument('--mode', type=str, default='train',
92 | help='Mode (train/test/train_and_evaluate)')
93 | args = parser.parse_args()
94 |
95 | tf.logging._logger.setLevel(logging.INFO)
96 |
97 | # Print Config setting
98 | Config(args.config)
99 | print("Config: ", Config)
100 | if Config.get("description", None):
101 | print("Config Description")
102 | for key, value in Config.description.items():
103 | print(f" - {key}: {value}")
104 |
105 | # After terminated Notification to Slack
106 | atexit.register(utils.send_message_to_slack, config_name=args.config)
107 |
108 | main(args.mode)
109 |
--------------------------------------------------------------------------------
/model.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function
2 |
3 |
4 | from hbconfig import Config
5 | import tensorflow as tf
6 | from tensorflow.contrib import layers
7 |
8 | import text_cnn
9 |
10 |
11 |
12 | class Model:
13 |
14 | def __init__(self):
15 | pass
16 |
17 | def model_fn(self, mode, features, labels, params):
18 | self.dtype = tf.float32
19 |
20 | self.mode = mode
21 | self.params = params
22 |
23 | self.loss, self.train_op, self.metrics, self.predictions = None, None, None, None
24 | self._init_placeholder(features, labels)
25 | self.build_graph()
26 |
27 | # train mode: required loss and train_op
28 | # eval mode: required loss
29 | # predict mode: required predictions
30 |
31 | return tf.estimator.EstimatorSpec(
32 | mode=mode,
33 | loss=self.loss,
34 | train_op=self.train_op,
35 | eval_metric_ops=self.metrics,
36 | predictions={"prediction": self.predictions})
37 |
38 | def _init_placeholder(self, features, labels):
39 | self.input_data = features
40 | if type(features) == dict:
41 | self.input_data = features["input_data"]
42 |
43 | self.targets = labels
44 |
45 | def build_graph(self):
46 | graph = text_cnn.Graph(self.mode)
47 | output = graph.build(self.input_data)
48 |
49 | self._build_prediction(output)
50 | if self.mode != tf.estimator.ModeKeys.PREDICT:
51 | self._build_loss(output)
52 | self._build_optimizer()
53 | self._build_metric()
54 |
55 | def _build_loss(self, output):
56 | self.loss = tf.losses.softmax_cross_entropy(
57 | self.targets,
58 | output,
59 | scope="loss")
60 |
61 | def _build_prediction(self, output):
62 | tf.argmax(output[0], name='train/pred_0') # for print_verbose
63 | self.predictions = tf.argmax(output, axis=1)
64 |
65 | def _build_optimizer(self):
66 | self.train_op = layers.optimize_loss(
67 | self.loss, tf.train.get_global_step(),
68 | optimizer='Adam',
69 | learning_rate=Config.train.learning_rate,
70 | summaries=['loss', 'learning_rate'],
71 | name="train_op")
72 |
73 | def _build_metric(self):
74 | self.metrics = {
75 | "accuracy": tf.metrics.accuracy(tf.argmax(self.targets, axis=1), self.predictions)
76 | }
77 |
--------------------------------------------------------------------------------
/predict.py:
--------------------------------------------------------------------------------
1 |
2 | #-*- coding: utf-8 -*-
3 |
4 | import argparse
5 | import os
6 | import sys
7 |
8 | from hbconfig import Config
9 | import numpy as np
10 | import tensorflow as tf
11 |
12 | import data_loader
13 | from model import Model
14 |
15 |
16 |
17 | def predict(ids):
18 |
19 | X = np.array(data_loader._pad_input(ids, Config.data.max_seq_length), dtype=np.int32)
20 | X = np.reshape(X, (1, Config.data.max_seq_length))
21 |
22 | predict_input_fn = tf.estimator.inputs.numpy_input_fn(
23 | x={"input_data": X},
24 | num_epochs=1,
25 | shuffle=False)
26 |
27 | estimator = _make_estimator()
28 | result = estimator.predict(input_fn=predict_input_fn)
29 |
30 | prediction = next(result)["prediction"]
31 | return prediction
32 |
33 |
34 | def _make_estimator():
35 | params = tf.contrib.training.HParams(**Config.model.to_dict())
36 | # Using CPU
37 | run_config = tf.contrib.learn.RunConfig(
38 | model_dir=Config.train.model_dir,
39 | session_config=tf.ConfigProto(
40 | device_count={'GPU': 0}
41 | ))
42 |
43 | model = Model()
44 | return tf.estimator.Estimator(
45 | model_fn=model.model_fn,
46 | model_dir=Config.train.model_dir,
47 | params=params,
48 | config=run_config)
49 |
50 |
51 | def _get_user_input():
52 | """ Get user's input, which will be transformed into encoder input later """
53 | print("> ", end="")
54 | sys.stdout.flush()
55 | return sys.stdin.readline()
56 |
57 |
58 | def main():
59 | data_loader.set_max_seq_length(['train_X_ids', 'test_X_ids'])
60 | vocab = data_loader.load_vocab("vocab")
61 | Config.data.vocab_size = len(vocab)
62 |
63 | print("Typing anything :) \n")
64 |
65 | while True:
66 | sentence = _get_user_input()
67 | ids = data_loader.sentence2id(vocab, sentence)
68 |
69 | if len(ids) > Config.data.max_seq_length:
70 | print(f"Max length I can handle is: {Config.data.max_seq_length}")
71 | continue
72 |
73 | result = predict(ids)
74 | print(result)
75 |
76 |
77 | if __name__ == '__main__':
78 |
79 | parser = argparse.ArgumentParser(
80 | formatter_class=argparse.ArgumentDefaultsHelpFormatter)
81 | parser.add_argument('--config', type=str, default='config',
82 | help='config file name')
83 | args = parser.parse_args()
84 |
85 | Config(args.config)
86 | Config.model.batch_size = 1
87 |
88 | os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
89 | tf.logging.set_verbosity(tf.logging.ERROR)
90 |
91 | main()
92 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | hb-config
2 | tqdm
3 | requests
--------------------------------------------------------------------------------
/scripts/prepare_kaggle_movie_reviews.sh:
--------------------------------------------------------------------------------
1 | cd data/kaggle_movie_reviews
2 |
3 | unzip train.tsv.zip
4 | unzip test.tsv.zip
5 |
6 | cd ../..
7 | python data_loader.py --config kaggle_movie_review
8 |
--------------------------------------------------------------------------------
/text_cnn/__init__.py:
--------------------------------------------------------------------------------
1 |
2 | from hbconfig import Config
3 | import tensorflow as tf
4 |
5 |
6 |
7 | class Graph:
8 |
9 | def __init__(self, mode, dtype=tf.float32):
10 | self.mode = mode
11 | self.dtype = dtype
12 |
13 | def build(self, input_data):
14 | embedding_input = self.build_embed(input_data)
15 | conv_output = self.build_conv_layers(embedding_input)
16 | return self.build_fully_connected_layers(conv_output)
17 |
18 | def build_embed(self, input_data):
19 | with tf.variable_scope("embeddings", dtype=self.dtype) as scope:
20 | embed_type = Config.model.embed_type
21 |
22 | if embed_type == "rand":
23 | embedding = tf.get_variable(
24 | "embedding-rand",
25 | [Config.data.vocab_size, Config.model.embed_dim],
26 | self.dtype)
27 | elif embed_type == "static":
28 | raise NotImplementedError("CNN-static not implemented yet.")
29 | elif embed_type == "non-static":
30 | raise NotImplementedError("CNN-non-static not implemented yet.")
31 | elif embed_type == "multichannel":
32 | raise NotImplementedError("CNN-multichannel not implemented yet.")
33 | else:
34 | raise ValueError(f"Unknown embed_type {self.embed_type}")
35 |
36 | return tf.expand_dims(tf.nn.embedding_lookup(embedding, input_data), -1)
37 |
38 | def build_conv_layers(self, embedding_input):
39 | with tf.variable_scope("convolutions", dtype=self.dtype) as scope:
40 | pooled_outputs = self._build_conv_maxpool(embedding_input)
41 |
42 | num_total_filters = Config.model.num_filters * len(Config.model.filter_sizes)
43 | concat_pooled = tf.concat(pooled_outputs, 3)
44 | flat_pooled = tf.reshape(concat_pooled, [-1, num_total_filters])
45 |
46 | if self.mode == tf.estimator.ModeKeys.TRAIN:
47 | h_dropout = tf.layers.dropout(flat_pooled, Config.model.dropout, training=True)  # apply dropout only in TRAIN mode
48 | else:
49 | h_dropout = tf.layers.dropout(flat_pooled, 0)
50 | return h_dropout
51 |
52 | def _build_conv_maxpool(self, embedding_input):
53 | pooled_outputs = []
54 | for filter_size in Config.model.filter_sizes:
55 | with tf.variable_scope(f"conv-maxpool-{filter_size}-filter"):
56 | conv = tf.layers.conv2d(
57 | embedding_input,
58 | Config.model.num_filters,
59 | (filter_size, Config.model.embed_dim),
60 | activation=tf.nn.relu)
61 |
62 | pool = tf.layers.max_pooling2d(
63 | conv,
64 | (Config.data.max_seq_length - filter_size + 1, 1),
65 | (1, 1))
66 |
67 | pooled_outputs.append(pool)
68 | return pooled_outputs
69 |
70 | def build_fully_connected_layers(self, conv_output):
71 | with tf.variable_scope("fully-connected", dtype=self.dtype) as scope:
72 | return tf.layers.dense(
73 | conv_output,
74 | Config.data.num_classes,
75 | kernel_initializer=tf.contrib.layers.xavier_initializer())
76 |
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
1 |
2 | import json
3 | import os.path
4 |
5 | from hbconfig import Config
6 | import requests
7 |
8 |
9 |
10 | def send_message_to_slack(config_name):
11 | project_name = os.path.basename(os.path.abspath("."))
12 |
13 | data = {
14 | "text": f"The learning is finished with *{project_name}* Project using `{config_name}` config."
15 | }
16 |
17 | webhook_url = Config.slack.webhook_url
18 | if webhook_url == "":
19 | print(data["text"])
20 | else:
21 | requests.post(Config.slack.webhook_url, data=json.dumps(data))
22 |
--------------------------------------------------------------------------------