├── .gitignore
├── README.md
├── config
│   ├── check_tiny.yml
│   ├── kaggle_movie_review.yml
│   └── rt-polarity.yml
├── data
│   ├── kaggle_movie_reviews
│   │   ├── test.tsv.zip
│   │   └── train.tsv.zip
│   ├── rt-polaritydata
│   │   ├── rt-polarity.neg
│   │   └── rt-polarity.pos
│   └── tiny_processed_data
│       ├── test_X_ids
│       ├── test_y
│       ├── train_X_ids
│       ├── train_y
│       └── vocab
├── data_loader.py
├── hook.py
├── images
│   ├── category.png
│   ├── figure-1.png
│   ├── kaggle-loss_and_accuracy.jpg
│   └── rt-polarity_loss_and_accuracy.jpeg
├── main.py
├── model.py
├── predict.py
├── requirements.txt
├── scripts
│   └── prepare_kaggle_movie_reviews.sh
├── text_cnn
│   └── __init__.py
└── utils.py

/.gitignore:
--------------------------------------------------------------------------------
1 | __pycache__/
2 | logs/
3 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # text-cnn [![hb-research](https://img.shields.io/badge/hb--research-experiment-green.svg?style=flat&colorA=448C57&colorB=555555)](https://github.com/hb-research)
2 | 
3 | This code implements the models from [Convolutional Neural Networks for Sentence Classification](http://arxiv.org/abs/1408.5882).
4 | 
5 | - Figure 1: Illustration of a CNN architecture for sentence classification
6 | 
7 | ![figure-1](images/figure-1.png)
8 | 
9 | 
10 | ## Requirements
11 | 
12 | - Python 3.6
13 | - TensorFlow 1.4
14 | - [hb-config](https://github.com/hb-research/hb-config) (Singleton Config)
15 | - tqdm
16 | - requests
17 | - [Slack Incoming Webhook URL](https://my.slack.com/services/new/incoming-webhook/)
18 | 
19 | ## Project Structure
20 | 
21 | Project initialized with [hb-base](https://github.com/hb-research/hb-base)
22 | 
23 |     .
24 |     ├── config                  # Config files (.yml, .json) used with hb-config
25 |     ├── data                    # dataset path
26 |     ├── notebooks               # Prototyping with numpy or tf.InteractiveSession
27 |     ├── scripts                 # download or prepare the dataset using shell scripts
28 |     ├── text_cnn                # text-cnn architecture graphs (from input to logits)
29 |     │   └── __init__.py         # Graph logic
30 |     ├── data_loader.py          # raw_data -> processed_data -> generate_batch (using Dataset)
31 |     ├── hook.py                 # training or test hook features (e.g. print_variables)
32 |     ├── main.py                 # define experiment_fn
33 |     ├── model.py                # define EstimatorSpec
34 |     └── predict.py              # test the trained model
35 | 
36 | Reference : [hb-config](https://github.com/hb-research/hb-config), [Dataset](https://www.tensorflow.org/api_docs/python/tf/data/Dataset#from_generator), [Experiment](https://www.tensorflow.org/api_docs/python/tf/contrib/learn/Experiment), [EstimatorSpec](https://www.tensorflow.org/api_docs/python/tf/estimator/EstimatorSpec)
37 | 
38 | - Dataset : [rt-polarity](https://github.com/yoonkim/CNN_sentence), [Sentiment Analysis on Movie Reviews](https://www.kaggle.com/c/sentiment-analysis-on-movie-reviews/data)
39 | 
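hb-config loads the YAML files under `config/` into a process-wide singleton, so every module reads the same settings after a single `Config(...)` call. A minimal sketch of the pattern this project relies on (values shown come from `config/rt-polarity.yml`; the snippet itself is not part of the repository):

```python
# Sketch of hb-config usage, following main.py and predict.py.
from hbconfig import Config

Config("rt-polarity")              # loads config/rt-polarity.yml once for the whole process

print(Config.data.num_classes)     # 2
print(Config.model.batch_size)     # 64
print(Config.train.model_dir)      # 'logs/rt-polarity'

Config.model.batch_size = 1        # settings can be overridden at runtime (predict.py does this)
```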
40 | ## Todo
41 | 
42 | - apply embed_type
43 |     - CNN-rand
44 |     - CNN-static
45 |     - CNN-nonstatic
46 |     - CNN-multichannel
47 | 
48 | ## Config
49 | 
50 | example: kaggle\_movie\_review.yml
51 | 
52 | ```yml
53 | data:
54 |   type: 'kaggle_movie_review'
55 |   base_path: 'data/'
56 |   raw_data_path: 'kaggle_movie_reviews/'
57 |   processed_path: 'kaggle_processed_data'
58 |   testset_size: 25000
59 |   num_classes: 5
60 |   PAD_ID: 0
61 | 
62 | model:
63 |   batch_size: 64
64 |   embed_type: 'rand'  # (rand, static, non-static, multichannel)
65 |   pretrained_embed: ""
66 |   embed_dim: 300
67 |   num_filters: 256
68 |   filter_sizes:
69 |     - 2
70 |     - 3
71 |     - 4
72 |     - 5
73 |   dropout: 0.5
74 | 
75 | train:
76 |   learning_rate: 0.00005
77 | 
78 |   train_steps: 100000
79 |   model_dir: 'logs/kaggle_movie_review'
80 | 
81 |   save_checkpoints_steps: 1000
82 |   loss_hook_n_iter: 1000
83 |   check_hook_n_iter: 1000
84 |   min_eval_frequency: 1000
85 | 
86 | slack:
87 |   webhook_url: ""  # notify via Slack webhook after training finishes
88 | ```
89 | 
90 | 
91 | ## Usage
92 | 
93 | Install the requirements.
94 | 
95 | ```pip install -r requirements.txt```
96 | 
97 | Then prepare the dataset and train:
98 | 
99 | ```
100 | sh scripts/prepare_kaggle_movie_reviews.sh
101 | python main.py --config kaggle_movie_review --mode train_and_evaluate
102 | ```
103 | 
104 | After training, you can classify sentences you type with `predict.py`.
105 | 
106 | ```python predict.py --config rt-polarity```
107 | 
108 | Prediction example
109 | 
110 | ```
111 | python predict.py --config rt-polarity
112 | Setting max_seq_length to Config : 62
113 | load vocab ...
114 | Typing anything :)
115 | 
116 | > good
117 | 1
118 | > bad
119 | 0
120 | ```
121 | 
122 | ### Experiment modes
123 | 
124 | :white_check_mark: : Working
125 | :white_medium_small_square: : Not tested yet.
126 | 
127 | - :white_check_mark: `evaluate` : Evaluate on the evaluation data.
128 | - :white_medium_small_square: `extend_train_hooks` : Extends the hooks for training.
129 | - :white_medium_small_square: `reset_export_strategies` : Resets the export strategies with the new_export_strategies.
130 | - :white_medium_small_square: `run_std_server` : Starts a TensorFlow server and joins the serving thread.
131 | - :white_medium_small_square: `test` : Tests training, evaluating and exporting the estimator for a single step.
132 | - :white_check_mark: `train` : Fit the estimator using the training data.
133 | - :white_check_mark: `train_and_evaluate` : Interleaves training and evaluation.
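Behind the prediction example above, `predict.py` turns each typed sentence into a padded id matrix before calling the estimator. A simplified sketch of that conversion (it assumes the rt-polarity data has already been processed so the `vocab` and `*_ids` files exist; the snippet is not part of the repository):

```python
# Sketch of the sentence -> padded id matrix step used by predict.py.
import numpy as np
from hbconfig import Config

import data_loader

Config("rt-polarity")
data_loader.set_max_seq_length(['train_X_ids', 'test_X_ids'])     # sets Config.data.max_seq_length
vocab = data_loader.load_vocab("vocab")

ids = data_loader.sentence2id(vocab, "good")                      # token ids; values depend on the built vocab
padded = data_loader._pad_input(ids, Config.data.max_seq_length)  # right-pad with 0 (PAD_ID)
X = np.reshape(np.array(padded, dtype=np.int32), (1, Config.data.max_seq_length))
# predict.predict() wraps X as {"input_data": X} in a numpy_input_fn and returns the argmax of the logits
```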
134 | 135 | 136 | ### Tensorboard 137 | 138 | ```tensorboard --logdir logs``` 139 | 140 | - Category Color 141 | 142 | ![category_image](images/category.png) 143 | 144 | - rt-polarity (binary classification) 145 | 146 | ![images](images/rt-polarity_loss_and_accuracy.jpeg) 147 | 148 | - kaggle_movie_review (multiclass classification) 149 | 150 | ![images](images/kaggle-loss_and_accuracy.jpg) 151 | 152 | 153 | ## Reference 154 | 155 | - [Implementing a CNN for Text Classification in TensorFlow](http://www.wildml.com/2015/12/implementing-a-cnn-for-text-classification-in-tensorflow/) by Denny Britz 156 | - [Paper - Convolutional Neural Networks for Sentence Classification](http://arxiv.org/abs/1408.5882) (2014) by Y Kim 157 | - [Paper - A Sensitivity Analysis of (and Practitioners' Guide to) Convolutional Neural Networks for Sentence Classification](https://arxiv.org/pdf/1510.03820.pdf) (2015) Y Zhang 158 | -------------------------------------------------------------------------------- /config/check_tiny.yml: -------------------------------------------------------------------------------- 1 | data: 2 | base_path: 'data/' 3 | processed_path: 'tiny_processed_data' 4 | testset_size: 2 5 | num_classes: 2 6 | PAD_ID: 0 7 | 8 | model: 9 | batch_size: 3 10 | embed_type: 'rand' 11 | pretrained_embed: "" 12 | embed_dim: 32 13 | num_filters: 16 14 | filter_sizes: 15 | - 2 16 | - 3 17 | - 4 18 | dropout: 0.5 19 | 20 | train: 21 | learning_rate: 0.001 22 | 23 | train_steps: 100 24 | model_dir: 'logs/rt-check_tiny' 25 | 26 | save_checkpoints_steps: 100 27 | loss_hook_n_iter: 100 28 | check_hook_n_iter: 10 29 | min_eval_frequency: 10 30 | 31 | print_verbose: True 32 | debug: False 33 | 34 | slack: 35 | webhook_url: "" 36 | -------------------------------------------------------------------------------- /config/kaggle_movie_review.yml: -------------------------------------------------------------------------------- 1 | data: 2 | type: 'kaggle_movie_review' 3 | base_path: 'data/' 4 | raw_data_path: 'kaggle_movie_reviews/' 5 | processed_path: 'kaggle_processed_data' 6 | testset_size: 25000 7 | num_classes: 5 8 | PAD_ID: 0 9 | 10 | model: 11 | batch_size: 64 12 | embed_type: 'rand' #(rand, static, non-static, multichannel) 13 | pretrained_embed: "" 14 | embed_dim: 300 15 | num_filters: 256 16 | filter_sizes: 17 | - 2 18 | - 3 19 | - 4 20 | - 5 21 | dropout: 0.5 22 | 23 | train: 24 | learning_rate: 0.00005 25 | 26 | train_steps: 100000 27 | model_dir: 'logs/kaggle_movie_review' 28 | 29 | save_checkpoints_steps: 1000 30 | loss_hook_n_iter: 1000 31 | check_hook_n_iter: 1000 32 | min_eval_frequency: 1000 33 | 34 | print_verbose: True 35 | debug: False 36 | 37 | slack: 38 | webhook_url: "" 39 | -------------------------------------------------------------------------------- /config/rt-polarity.yml: -------------------------------------------------------------------------------- 1 | data: 2 | type: 'rt-polarity' 3 | base_path: 'data/' 4 | raw_data_path: 'rt-polaritydata/' 5 | processed_path: 'rt-polarity_processed_data' 6 | testset_size: 2000 7 | num_classes: 2 8 | PAD_ID: 0 9 | 10 | model: 11 | batch_size: 64 12 | embed_type: 'rand' #(rand, static, non-static, multichannel) 13 | pretrained_embed: "" 14 | embed_dim: 300 15 | num_filters: 256 16 | filter_sizes: 17 | - 2 18 | - 3 19 | - 4 20 | - 5 21 | dropout: 0.5 22 | 23 | train: 24 | learning_rate: 0.00001 25 | 26 | train_steps: 20000 27 | model_dir: 'logs/rt-polarity' 28 | 29 | save_checkpoints_steps: 100 30 | loss_hook_n_iter: 100 31 | check_hook_n_iter: 100 32 | 
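  # minimum number of global steps between evaluations when running train_and_evaluate (tf.contrib.learn.Experiment)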
min_eval_frequency: 100 33 | 34 | print_verbose: True 35 | debug: False 36 | 37 | slack: 38 | webhook_url: "" 39 | -------------------------------------------------------------------------------- /data/kaggle_movie_reviews/test.tsv.zip: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DongjunLee/text-cnn-tensorflow/1e9ebb02dd806550e717f2ace680117feaa532bf/data/kaggle_movie_reviews/test.tsv.zip -------------------------------------------------------------------------------- /data/kaggle_movie_reviews/train.tsv.zip: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DongjunLee/text-cnn-tensorflow/1e9ebb02dd806550e717f2ace680117feaa532bf/data/kaggle_movie_reviews/train.tsv.zip -------------------------------------------------------------------------------- /data/tiny_processed_data/test_X_ids: -------------------------------------------------------------------------------- 1 | 1 2 3 4 5 6 7 8 9 10 2 | 11 12 13 1 4 15 16 3 | 17 18 19 20 4 | 21 22 23 24 25 26 -------------------------------------------------------------------------------- /data/tiny_processed_data/test_y: -------------------------------------------------------------------------------- 1 | 1 2 | 0 3 | 1 4 | 0 -------------------------------------------------------------------------------- /data/tiny_processed_data/train_X_ids: -------------------------------------------------------------------------------- 1 | 1 2 3 4 5 6 7 8 9 10 2 | 11 12 13 1 4 15 16 3 | 17 18 19 20 4 | 21 22 23 24 25 26 -------------------------------------------------------------------------------- /data/tiny_processed_data/train_y: -------------------------------------------------------------------------------- 1 | 1 2 | 0 3 | 1 4 | 0 -------------------------------------------------------------------------------- /data/tiny_processed_data/vocab: -------------------------------------------------------------------------------- 1 | PAD 2 | v1 3 | v2 4 | v3 5 | v4 6 | v5 7 | v6 8 | v7 9 | v8 10 | v9 11 | v10 12 | v11 13 | v12 14 | v13 15 | v14 16 | v15 17 | v16 18 | v17 19 | v18 20 | v19 21 | v20 22 | v21 23 | v22 24 | v23 25 | v24 26 | v25 27 | v26 28 | v27 29 | v28 30 | v29 31 | v30 -------------------------------------------------------------------------------- /data_loader.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | from __future__ import print_function 3 | 4 | import argparse 5 | import csv 6 | import os 7 | import random 8 | import re 9 | 10 | import numpy as np 11 | from hbconfig import Config 12 | import tensorflow as tf 13 | from tqdm import tqdm 14 | 15 | 16 | def clean_str(string): 17 | """ 18 | Tokenization/string cleaning for all datasets except for SST. 19 | Original taken from https://github.com/yoonkim/CNN_sentence/blob/master/process_data.py 20 | """ 21 | string = string.decode('utf-8') 22 | string = re.sub(r"[^A-Za-z0-9(),!?\'\`]", " ", string) 23 | string = re.sub(r"\'s", " \'s", string) 24 | string = re.sub(r"\'ve", " \'ve", string) 25 | string = re.sub(r"n\'t", " n\'t", string) 26 | string = re.sub(r"\'re", " \'re", string) 27 | string = re.sub(r"\'d", " \'d", string) 28 | string = re.sub(r"\'ll", " \'ll", string) 29 | string = re.sub(r",", " , ", string) 30 | string = re.sub(r"!", " ! ", string) 31 | string = re.sub(r"\(", " \( ", string) 32 | string = re.sub(r"\)", " \) ", string) 33 | string = re.sub(r"\?", " \? 
", string) 34 | string = re.sub(r"\s{2,}", " ", string) 35 | return string.strip().lower() 36 | 37 | 38 | def load_data_and_labels(positive_data_file, negative_data_file): 39 | """ 40 | Loads MR polarity data from files, splits the data into words and generates labels. 41 | Returns split sentences and labels. 42 | """ 43 | # Load data from files 44 | positive_examples = list(open(positive_data_file, "rb").readlines()) 45 | positive_examples = [s.strip() for s in positive_examples] 46 | negative_examples = list(open(negative_data_file, "rb").readlines()) 47 | negative_examples = [s.strip() for s in negative_examples] 48 | # Split by words 49 | x_text = positive_examples + negative_examples 50 | x_text = [clean_str(sent) for sent in x_text] 51 | # Generate labels 52 | positive_labels = ['1' for _ in positive_examples] 53 | negative_labels = ['0' for _ in negative_examples] 54 | y = positive_labels + negative_labels 55 | return x_text, y 56 | 57 | 58 | def prepare_raw_data(): 59 | print('Preparing raw data into train set and test set ...') 60 | raw_data_path = os.path.join(Config.data.base_path, Config.data.raw_data_path) 61 | 62 | data_type = Config.data.type 63 | if data_type == "kaggle_movie_review": 64 | train_path = os.path.join(raw_data_path, 'train.tsv') 65 | train_reader = csv.reader(open(train_path), delimiter="\t") 66 | 67 | prepare_dataset(dataset=list(train_reader)) 68 | 69 | elif data_type == "rt-polarity": 70 | pos_path = os.path.join(Config.data.base_path, Config.data.raw_data_path, "rt-polarity.pos") 71 | neg_path = os.path.join(Config.data.base_path, Config.data.raw_data_path, "rt-polarity.neg") 72 | x_text, y = load_data_and_labels(pos_path, neg_path) 73 | 74 | prepare_dataset(x_text=x_text, y=y) 75 | 76 | 77 | def prepare_dataset(dataset=None, x_text=None, y=None): 78 | make_dir(os.path.join(Config.data.base_path, Config.data.processed_path)) 79 | 80 | filenames = ['train_X', 'train_y', 'test_X', 'test_y'] 81 | files = [] 82 | for filename in filenames: 83 | files.append(open(os.path.join(Config.data.base_path, Config.data.processed_path, filename), 'wb')) 84 | 85 | if dataset is not None: 86 | 87 | print("Total data length : ", len(dataset)) 88 | test_ids = random.sample([i for i in range(len(dataset))], Config.data.testset_size) 89 | 90 | for i in tqdm(range(len(dataset))): 91 | if i == 0: 92 | continue 93 | 94 | data = dataset[i] 95 | X, y = data[2], data[3] 96 | 97 | if i in test_ids: 98 | files[2].write((X + "\n").encode('utf-8')) 99 | files[3].write((y + '\n').encode('utf-8')) 100 | else: 101 | files[0].write((X + '\n').encode('utf-8')) 102 | files[1].write((y + '\n').encode('utf-8')) 103 | 104 | else: 105 | 106 | print("Total data length : ", len(y)) 107 | test_ids = random.sample([i for i in range(len(y))], Config.data.testset_size) 108 | 109 | for i in tqdm(range(len(y))): 110 | if i in test_ids: 111 | files[2].write((x_text[i] + "\n").encode('utf-8')) 112 | files[3].write((y[i] + '\n').encode('utf-8')) 113 | else: 114 | files[0].write((x_text[i] + '\n').encode('utf-8')) 115 | files[1].write((y[i] + '\n').encode('utf-8')) 116 | 117 | for file in files: 118 | file.close() 119 | 120 | 121 | def make_dir(path): 122 | """ Create a directory if there isn't one already. """ 123 | try: 124 | os.mkdir(path) 125 | except OSError: 126 | pass 127 | 128 | 129 | def basic_tokenizer(line, normalize_digits=True): 130 | """ A basic tokenizer to tokenize text into tokens. 131 | Feel free to change this to suit your need. 
""" 132 | line = re.sub('', '', line) 133 | line = re.sub('', '', line) 134 | line = re.sub('\[', '', line) 135 | line = re.sub('\]', '', line) 136 | words = [] 137 | _WORD_SPLIT = re.compile("([.,!?\"'-<>:;)(])") 138 | _DIGIT_RE = re.compile(r"\d") 139 | for fragment in line.strip().lower().split(): 140 | for token in re.split(_WORD_SPLIT, fragment): 141 | if not token: 142 | continue 143 | if normalize_digits: 144 | token = re.sub(_DIGIT_RE, '#', token) 145 | words.append(token) 146 | return words 147 | 148 | 149 | def build_vocab(train_fname, test_fname, normalize_digits=True): 150 | vocab = {} 151 | def count_vocab(fname): 152 | with open(fname, 'rb') as f: 153 | for line in f.readlines(): 154 | line = line.decode('utf-8') 155 | for token in basic_tokenizer(line): 156 | if not token in vocab: 157 | vocab[token] = 0 158 | vocab[token] += 1 159 | 160 | train_path = os.path.join(Config.data.base_path, Config.data.processed_path, train_fname) 161 | test_path = os.path.join(Config.data.base_path, Config.data.processed_path, test_fname) 162 | 163 | count_vocab(train_path) 164 | count_vocab(test_path) 165 | 166 | sorted_vocab = sorted(vocab, key=vocab.get, reverse=True) 167 | 168 | dest_path = os.path.join(Config.data.base_path, Config.data.processed_path, 'vocab') 169 | with open(dest_path, 'wb') as f: 170 | f.write(('' + '\n').encode('utf-8')) 171 | index = 1 172 | for word in sorted_vocab: 173 | f.write((word + '\n').encode('utf-8')) 174 | index += 1 175 | 176 | 177 | def load_vocab(vocab_fname): 178 | print("load vocab ...") 179 | with open(os.path.join(Config.data.base_path, Config.data.processed_path, vocab_fname), 'rb') as f: 180 | words = f.read().decode('utf-8').splitlines() 181 | return {words[i]: i for i in range(len(words))} 182 | 183 | 184 | def sentence2id(vocab, line): 185 | return [vocab.get(token, vocab['']) for token in basic_tokenizer(line)] 186 | 187 | 188 | def token2id(data): 189 | """ Convert all the tokens in the data into their corresponding 190 | index in the vocabulary. 
""" 191 | vocab_path = 'vocab' 192 | in_path = data 193 | out_path = data + '_ids' 194 | 195 | vocab = load_vocab(vocab_path) 196 | in_file = open(os.path.join(Config.data.base_path, Config.data.processed_path, in_path), 'rb') 197 | out_file = open(os.path.join(Config.data.base_path, Config.data.processed_path, out_path), 'wb') 198 | 199 | lines = in_file.read().decode('utf-8').splitlines() 200 | for line in lines: 201 | ids = [] 202 | sentence_ids = sentence2id(vocab, line) 203 | ids.extend(sentence_ids) 204 | 205 | out_file.write(b' '.join(str(id_).encode('utf-8') for id_ in ids) + b'\n') 206 | 207 | 208 | def process_data(): 209 | print('Preparing data to be model-ready ...') 210 | 211 | build_vocab('train_X', 'test_X') 212 | 213 | token2id('train_X') 214 | token2id('test_X') 215 | 216 | 217 | def make_train_and_test_set(shuffle=True): 218 | print("make Training data and Test data Start....") 219 | 220 | if Config.data.get('max_seq_length', None) is None: 221 | set_max_seq_length(['train_X_ids', 'test_X_ids']) 222 | 223 | train_X, train_y = load_data('train_X_ids', 'train_y') 224 | test_X, test_y = load_data('test_X_ids', 'test_y') 225 | 226 | assert len(train_X) == len(train_y) 227 | assert len(test_X) == len(test_y) 228 | 229 | print(f"train data count : {len(train_y)}") 230 | print(f"test data count : {len(test_y)}") 231 | 232 | if shuffle: 233 | print("shuffle dataset ...") 234 | train_p = np.random.permutation(len(train_y)) 235 | test_p = np.random.permutation(len(test_y)) 236 | 237 | return ((train_X[train_p], train_y[train_p]), 238 | (test_X[test_p], test_y[test_p])) 239 | else: 240 | return ((train_X, train_y), 241 | (test_X, test_y)) 242 | 243 | 244 | def load_data(X_fname, y_fname): 245 | X_input_data = open(os.path.join(Config.data.base_path, Config.data.processed_path, X_fname), 'r') 246 | y_input_data = open(os.path.join(Config.data.base_path, Config.data.processed_path, y_fname), 'r') 247 | 248 | X_data, y_data = [], [] 249 | for X_line, y_line in zip(X_input_data.readlines(), y_input_data.readlines()): 250 | X_ids = [int(id_) for id_ in X_line.split()] 251 | y_id = int(y_line) 252 | 253 | if len(X_ids) == 0 or y_id >= Config.data.num_classes: 254 | continue 255 | 256 | if len(X_ids) <= Config.data.max_seq_length: 257 | X_data.append(_pad_input(X_ids, Config.data.max_seq_length)) 258 | 259 | y_one_hot = np.zeros(Config.data.num_classes) 260 | y_one_hot[int(y_line)] = 1 261 | y_data.append(y_one_hot) 262 | 263 | print(f"load data from {X_fname}, {y_fname}...") 264 | return np.array(X_data, dtype=np.int32), np.array(y_data, dtype=np.int32) 265 | 266 | 267 | def _pad_input(input_, size): 268 | return input_ + [0] * (size - len(input_)) 269 | 270 | 271 | def set_max_seq_length(dataset_fnames): 272 | 273 | max_seq_length = Config.data.get('max_seq_length', 10) 274 | 275 | for fname in dataset_fnames: 276 | input_data = open(os.path.join(Config.data.base_path, Config.data.processed_path, fname), 'r') 277 | 278 | for line in input_data.readlines(): 279 | ids = [int(id_) for id_ in line.split()] 280 | seq_length = len(ids) 281 | 282 | if seq_length > max_seq_length: 283 | max_seq_length = seq_length 284 | 285 | Config.data.max_seq_length = max_seq_length 286 | print(f"Setting max_seq_length to Config : {max_seq_length}") 287 | 288 | 289 | def make_batch(data, buffer_size=10000, batch_size=64, scope="train"): 290 | 291 | class IteratorInitializerHook(tf.train.SessionRunHook): 292 | """Hook to initialise data iterator after Session is created.""" 293 | 294 | def __init__(self): 295 
| super(IteratorInitializerHook, self).__init__() 296 | self.iterator_initializer_func = None 297 | 298 | def after_create_session(self, session, coord): 299 | """Initialise the iterator after the session has been created.""" 300 | self.iterator_initializer_func(session) 301 | 302 | 303 | def get_inputs(): 304 | 305 | iterator_initializer_hook = IteratorInitializerHook() 306 | 307 | def train_inputs(): 308 | with tf.name_scope(scope): 309 | 310 | X, y = data 311 | 312 | # Define placeholders 313 | input_placeholder = tf.placeholder( 314 | tf.int32, [None, Config.data.max_seq_length]) 315 | output_placeholder = tf.placeholder( 316 | tf.int32, [None, Config.data.num_classes]) 317 | 318 | # Build dataset iterator 319 | dataset = tf.data.Dataset.from_tensor_slices( 320 | (input_placeholder, output_placeholder)) 321 | 322 | if scope == "train": 323 | dataset = dataset.repeat(None) # Infinite iterations 324 | else: 325 | dataset = dataset.repeat(1) # 1 Epoch 326 | # dataset = dataset.shuffle(buffer_size=buffer_size) 327 | dataset = dataset.batch(batch_size) 328 | 329 | iterator = dataset.make_initializable_iterator() 330 | next_X, next_y = iterator.get_next() 331 | 332 | tf.identity(next_X[0], 'input_0') 333 | tf.identity(next_y[0], 'target_0') 334 | 335 | # Set runhook to initialize iterator 336 | iterator_initializer_hook.iterator_initializer_func = \ 337 | lambda sess: sess.run( 338 | iterator.initializer, 339 | feed_dict={input_placeholder: X, 340 | output_placeholder: y}) 341 | 342 | # Return batched (features, labels) 343 | return next_X, next_y 344 | 345 | # Return function and hook 346 | return train_inputs, iterator_initializer_hook 347 | 348 | return get_inputs() 349 | 350 | if __name__ == '__main__': 351 | 352 | parser = argparse.ArgumentParser( 353 | formatter_class=argparse.ArgumentDefaultsHelpFormatter) 354 | parser.add_argument('--config', type=str, default='config', 355 | help='config file name') 356 | args = parser.parse_args() 357 | 358 | Config(args.config) 359 | 360 | prepare_raw_data() 361 | process_data() 362 | -------------------------------------------------------------------------------- /hook.py: -------------------------------------------------------------------------------- 1 | 2 | from hbconfig import Config 3 | import numpy as np 4 | import tensorflow as tf 5 | 6 | 7 | 8 | 9 | def print_variables(variables, rev_vocab=None, every_n_iter=100): 10 | 11 | return tf.train.LoggingTensorHook( 12 | variables, 13 | every_n_iter=every_n_iter, 14 | formatter=format_variable(variables, rev_vocab=rev_vocab)) 15 | 16 | 17 | def format_variable(keys, rev_vocab=None): 18 | 19 | def to_str(sequence): 20 | if type(sequence) == np.ndarray: 21 | tokens = [ 22 | rev_vocab.get(x, '') for x in sequence if x != Config.data.PAD_ID] 23 | return ' '.join(tokens) 24 | else: 25 | x = int(sequence) 26 | return rev_vocab[x] 27 | 28 | def format(values): 29 | result = [] 30 | for key in keys: 31 | if rev_vocab is None: 32 | result.append(f"{key} = {values[key]}") 33 | else: 34 | result.append(f"{key} = {to_str(values[key])}") 35 | 36 | try: 37 | return '\n - '.join(result) 38 | except: 39 | pass 40 | 41 | return format 42 | 43 | 44 | def get_rev_vocab(vocab): 45 | if vocab is None: 46 | return None 47 | return {idx: key for key, idx in vocab.items()} 48 | 49 | 50 | def print_target(variables, every_n_iter=100): 51 | 52 | return tf.train.LoggingTensorHook( 53 | variables, 54 | every_n_iter=every_n_iter, 55 | formatter=print_pos_or_neg(variables)) 56 | 57 | 58 | def print_pos_or_neg(keys): 59 | 60 | 
def format(values): 61 | result = [] 62 | for key in keys: 63 | if type(values[key]) == np.ndarray: 64 | value = max(values[key]) 65 | else: 66 | value = values[key] 67 | result.append(f"{key} = {value}") 68 | 69 | try: 70 | return ', '.join(result) 71 | except: 72 | pass 73 | 74 | return format 75 | -------------------------------------------------------------------------------- /images/category.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DongjunLee/text-cnn-tensorflow/1e9ebb02dd806550e717f2ace680117feaa532bf/images/category.png -------------------------------------------------------------------------------- /images/figure-1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DongjunLee/text-cnn-tensorflow/1e9ebb02dd806550e717f2ace680117feaa532bf/images/figure-1.png -------------------------------------------------------------------------------- /images/kaggle-loss_and_accuracy.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DongjunLee/text-cnn-tensorflow/1e9ebb02dd806550e717f2ace680117feaa532bf/images/kaggle-loss_and_accuracy.jpg -------------------------------------------------------------------------------- /images/rt-polarity_loss_and_accuracy.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DongjunLee/text-cnn-tensorflow/1e9ebb02dd806550e717f2ace680117feaa532bf/images/rt-polarity_loss_and_accuracy.jpeg -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | #-- coding: utf-8 -*- 2 | 3 | import argparse 4 | import atexit 5 | import logging 6 | 7 | from hbconfig import Config 8 | import tensorflow as tf 9 | 10 | import data_loader 11 | import hook 12 | from model import Model 13 | import utils 14 | 15 | 16 | def experiment_fn(run_config, params): 17 | 18 | model = Model() 19 | estimator = tf.estimator.Estimator( 20 | model_fn=model.model_fn, 21 | model_dir=Config.train.model_dir, 22 | params=params, 23 | config=run_config) 24 | 25 | vocab = data_loader.load_vocab("vocab") 26 | Config.data.vocab_size = len(vocab) 27 | 28 | train_data, test_data = data_loader.make_train_and_test_set() 29 | train_input_fn, train_input_hook = data_loader.make_batch(train_data, 30 | batch_size=Config.model.batch_size, 31 | scope="train") 32 | test_input_fn, test_input_hook = data_loader.make_batch(test_data, 33 | batch_size=Config.model.batch_size, 34 | scope="test") 35 | 36 | train_hooks = [train_input_hook] 37 | if Config.train.print_verbose: 38 | train_hooks.append(hook.print_variables( 39 | variables=['train/input_0'], 40 | rev_vocab=get_rev_vocab(vocab), 41 | every_n_iter=Config.train.check_hook_n_iter)) 42 | train_hooks.append(hook.print_target( 43 | variables=['train/target_0', 'train/pred_0'], 44 | every_n_iter=Config.train.check_hook_n_iter)) 45 | if Config.train.debug: 46 | train_hooks.append(tf_debug.LocalCLIDebugHook()) 47 | 48 | eval_hooks = [test_input_hook] 49 | if Config.train.debug: 50 | eval_hooks.append(tf_debug.LocalCLIDebugHook()) 51 | 52 | experiment = tf.contrib.learn.Experiment( 53 | estimator=estimator, 54 | train_input_fn=train_input_fn, 55 | eval_input_fn=test_input_fn, 56 | train_steps=Config.train.train_steps, 57 | min_eval_frequency=Config.train.min_eval_frequency, 58 | 
train_monitors=train_hooks, 59 | eval_hooks=eval_hooks 60 | ) 61 | return experiment 62 | 63 | 64 | def get_rev_vocab(vocab): 65 | if vocab is None: 66 | return None 67 | return {idx: key for key, idx in vocab.items()} 68 | 69 | 70 | def main(mode): 71 | params = tf.contrib.training.HParams(**Config.model.to_dict()) 72 | 73 | run_config = tf.contrib.learn.RunConfig( 74 | model_dir=Config.train.model_dir, 75 | save_checkpoints_steps=Config.train.save_checkpoints_steps) 76 | 77 | tf.contrib.learn.learn_runner.run( 78 | experiment_fn=experiment_fn, 79 | run_config=run_config, 80 | schedule=mode, 81 | hparams=params 82 | ) 83 | 84 | 85 | if __name__ == '__main__': 86 | 87 | parser = argparse.ArgumentParser( 88 | formatter_class=argparse.ArgumentDefaultsHelpFormatter) 89 | parser.add_argument('--config', type=str, default='config', 90 | help='config file name') 91 | parser.add_argument('--mode', type=str, default='train', 92 | help='Mode (train/test/train_and_evaluate)') 93 | args = parser.parse_args() 94 | 95 | tf.logging._logger.setLevel(logging.INFO) 96 | 97 | # Print Config setting 98 | Config(args.config) 99 | print("Config: ", Config) 100 | if Config.get("description", None): 101 | print("Config Description") 102 | for key, value in Config.description.items(): 103 | print(f" - {key}: {value}") 104 | 105 | # After terminated Notification to Slack 106 | atexit.register(utils.send_message_to_slack, config_name=args.config) 107 | 108 | main(args.mode) 109 | -------------------------------------------------------------------------------- /model.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | 3 | 4 | from hbconfig import Config 5 | import tensorflow as tf 6 | from tensorflow.contrib import layers 7 | 8 | import text_cnn 9 | 10 | 11 | 12 | class Model: 13 | 14 | def __init__(self): 15 | pass 16 | 17 | def model_fn(self, mode, features, labels, params): 18 | self.dtype = tf.float32 19 | 20 | self.mode = mode 21 | self.params = params 22 | 23 | self.loss, self.train_op, self.metrics, self.predictions = None, None, None, None 24 | self._init_placeholder(features, labels) 25 | self.build_graph() 26 | 27 | # train mode: required loss and train_op 28 | # eval mode: required loss 29 | # predict mode: required predictions 30 | 31 | return tf.estimator.EstimatorSpec( 32 | mode=mode, 33 | loss=self.loss, 34 | train_op=self.train_op, 35 | eval_metric_ops=self.metrics, 36 | predictions={"prediction": self.predictions}) 37 | 38 | def _init_placeholder(self, features, labels): 39 | self.input_data = features 40 | if type(features) == dict: 41 | self.input_data = features["input_data"] 42 | 43 | self.targets = labels 44 | 45 | def build_graph(self): 46 | graph = text_cnn.Graph(self.mode) 47 | output = graph.build(self.input_data) 48 | 49 | self._build_prediction(output) 50 | if self.mode != tf.estimator.ModeKeys.PREDICT: 51 | self._build_loss(output) 52 | self._build_optimizer() 53 | self._build_metric() 54 | 55 | def _build_loss(self, output): 56 | self.loss = tf.losses.softmax_cross_entropy( 57 | self.targets, 58 | output, 59 | scope="loss") 60 | 61 | def _build_prediction(self, output): 62 | tf.argmax(output[0], name='train/pred_0') # for print_verbose 63 | self.predictions = tf.argmax(output, axis=1) 64 | 65 | def _build_optimizer(self): 66 | self.train_op = layers.optimize_loss( 67 | self.loss, tf.train.get_global_step(), 68 | optimizer='Adam', 69 | learning_rate=Config.train.learning_rate, 70 | summaries=['loss', 
'learning_rate'], 71 | name="train_op") 72 | 73 | def _build_metric(self): 74 | self.metrics = { 75 | "accuracy": tf.metrics.accuracy(tf.argmax(self.targets, axis=1), self.predictions) 76 | } 77 | -------------------------------------------------------------------------------- /predict.py: -------------------------------------------------------------------------------- 1 | 2 | #-*- coding: utf-8 -*- 3 | 4 | import argparse 5 | import os 6 | import sys 7 | 8 | from hbconfig import Config 9 | import numpy as np 10 | import tensorflow as tf 11 | 12 | import data_loader 13 | from model import Model 14 | 15 | 16 | 17 | def predict(ids): 18 | 19 | X = np.array(data_loader._pad_input(ids, Config.data.max_seq_length), dtype=np.int32) 20 | X = np.reshape(X, (1, Config.data.max_seq_length)) 21 | 22 | predict_input_fn = tf.estimator.inputs.numpy_input_fn( 23 | x={"input_data": X}, 24 | num_epochs=1, 25 | shuffle=False) 26 | 27 | estimator = _make_estimator() 28 | result = estimator.predict(input_fn=predict_input_fn) 29 | 30 | prediction = next(result)["prediction"] 31 | return prediction 32 | 33 | 34 | def _make_estimator(): 35 | params = tf.contrib.training.HParams(**Config.model.to_dict()) 36 | # Using CPU 37 | run_config = tf.contrib.learn.RunConfig( 38 | model_dir=Config.train.model_dir, 39 | session_config=tf.ConfigProto( 40 | device_count={'GPU': 0} 41 | )) 42 | 43 | model = Model() 44 | return tf.estimator.Estimator( 45 | model_fn=model.model_fn, 46 | model_dir=Config.train.model_dir, 47 | params=params, 48 | config=run_config) 49 | 50 | 51 | def _get_user_input(): 52 | """ Get user's input, which will be transformed into encoder input later """ 53 | print("> ", end="") 54 | sys.stdout.flush() 55 | return sys.stdin.readline() 56 | 57 | 58 | def main(): 59 | data_loader.set_max_seq_length(['train_X_ids', 'test_X_ids']) 60 | vocab = data_loader.load_vocab("vocab") 61 | Config.data.vocab_size = len(vocab) 62 | 63 | print("Typing anything :) \n") 64 | 65 | while True: 66 | sentence = _get_user_input() 67 | ids = data_loader.sentence2id(vocab, sentence) 68 | 69 | if len(ids) > Config.data.max_seq_length: 70 | print(f"Max length I can handle is: {Config.data.max_seq_length}") 71 | continue 72 | 73 | result = predict(ids) 74 | print(result) 75 | 76 | 77 | if __name__ == '__main__': 78 | 79 | parser = argparse.ArgumentParser( 80 | formatter_class=argparse.ArgumentDefaultsHelpFormatter) 81 | parser.add_argument('--config', type=str, default='config', 82 | help='config file name') 83 | args = parser.parse_args() 84 | 85 | Config(args.config) 86 | Config.model.batch_size = 1 87 | 88 | os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3' 89 | tf.logging.set_verbosity(tf.logging.ERROR) 90 | 91 | main() 92 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | hb-config 2 | tqdm 3 | requests -------------------------------------------------------------------------------- /scripts/prepare_kaggle_movie_reviews.sh: -------------------------------------------------------------------------------- 1 | cd data/kaggle_movie_reviews 2 | 3 | unzip train.tsv.zip 4 | unzip test.tsv.zip 5 | 6 | cd ../.. 
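# build the train/test splits, the vocab file, and the *_ids files under data/kaggle_processed_data/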
7 | python data_loader.py --config kaggle_movie_review 8 | -------------------------------------------------------------------------------- /text_cnn/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | from hbconfig import Config 3 | import tensorflow as tf 4 | 5 | 6 | 7 | class Graph: 8 | 9 | def __init__(self, mode, dtype=tf.float32): 10 | self.mode = mode 11 | self.dtype = dtype 12 | 13 | def build(self, input_data): 14 | embedding_input = self.build_embed(input_data) 15 | conv_output = self.build_conv_layers(embedding_input) 16 | return self.build_fully_connected_layers(conv_output) 17 | 18 | def build_embed(self, input_data): 19 | with tf.variable_scope("embeddings", dtype=self.dtype) as scope: 20 | embed_type = Config.model.embed_type 21 | 22 | if embed_type == "rand": 23 | embedding = tf.get_variable( 24 | "embedding-rand", 25 | [Config.data.vocab_size, Config.model.embed_dim], 26 | self.dtype) 27 | elif embed_type == "static": 28 | raise NotImplementedError("CNN-static not implemented yet.") 29 | elif embed_type == "non-static": 30 | raise NotImplementedError("CNN-non-static not implemented yet.") 31 | elif embed_type == "multichannel": 32 | raise NotImplementedError("CNN-multichannel not implemented yet.") 33 | else: 34 | raise ValueError(f"Unknown embed_type {self.embed_type}") 35 | 36 | return tf.expand_dims(tf.nn.embedding_lookup(embedding, input_data), -1) 37 | 38 | def build_conv_layers(self, embedding_input): 39 | with tf.variable_scope("convolutions", dtype=self.dtype) as scope: 40 | pooled_outputs = self._build_conv_maxpool(embedding_input) 41 | 42 | num_total_filters = Config.model.num_filters * len(Config.model.filter_sizes) 43 | concat_pooled = tf.concat(pooled_outputs, 3) 44 | flat_pooled = tf.reshape(concat_pooled, [-1, num_total_filters]) 45 | 46 | if self.mode == tf.estimator.ModeKeys.TRAIN: 47 | h_dropout = tf.layers.dropout(flat_pooled, Config.model.dropout) 48 | else: 49 | h_dropout = tf.layers.dropout(flat_pooled, 0) 50 | return h_dropout 51 | 52 | def _build_conv_maxpool(self, embedding_input): 53 | pooled_outputs = [] 54 | for filter_size in Config.model.filter_sizes: 55 | with tf.variable_scope(f"conv-maxpool-{filter_size}-filter"): 56 | conv = tf.layers.conv2d( 57 | embedding_input, 58 | Config.model.num_filters, 59 | (filter_size, Config.model.embed_dim), 60 | activation=tf.nn.relu) 61 | 62 | pool = tf.layers.max_pooling2d( 63 | conv, 64 | (Config.data.max_seq_length - filter_size + 1, 1), 65 | (1, 1)) 66 | 67 | pooled_outputs.append(pool) 68 | return pooled_outputs 69 | 70 | def build_fully_connected_layers(self, conv_output): 71 | with tf.variable_scope("fully-connected", dtype=self.dtype) as scope: 72 | return tf.layers.dense( 73 | conv_output, 74 | Config.data.num_classes, 75 | kernel_initializer=tf.contrib.layers.xavier_initializer()) 76 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | 2 | import json 3 | import os.path 4 | 5 | from hbconfig import Config 6 | import requests 7 | 8 | 9 | 10 | def send_message_to_slack(config_name): 11 | project_name = os.path.basename(os.path.abspath(".")) 12 | 13 | data = { 14 | "text": f"The learning is finished with *{project_name}* Project using `{config_name}` config." 
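        # Slack incoming webhooks accept a JSON payload with a "text" field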
15 | } 16 | 17 | webhook_url = Config.slack.webhook_url 18 | if webhook_url == "": 19 | print(data["text"]) 20 | else: 21 | requests.post(Config.slack.webhook_url, data=json.dumps(data)) 22 | --------------------------------------------------------------------------------
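For reference, `utils.send_message_to_slack` above is registered with `atexit` in `main.py`, so the notification fires when training exits. A minimal sketch of calling it by hand to check a webhook (assumes `slack.webhook_url` is set in the chosen config; the snippet is not part of the repository):

```python
# Sketch only: exercise the Slack notification outside of a training run.
from hbconfig import Config

import utils

Config("rt-polarity")                       # loads config/rt-polarity.yml, including slack.webhook_url
utils.send_message_to_slack("rt-polarity")  # posts {"text": ...} to the webhook, or prints it if the URL is empty
```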