├── README.md ├── cnn_models ├── char_cnn.py ├── vd_cnn.py └── word_cnn.py ├── data_utils.py ├── requirements.txt ├── rnn_models ├── attention_rnn.py ├── rcnn.py └── word_rnn.py ├── test.py └── train.py /README.md: -------------------------------------------------------------------------------- 1 | # Text Classification Models with Tensorflow 2 | Tensorflow implementation of Text Classification Models. 3 | 4 | Implemented Models: 5 | 6 | 1. Word-level CNN [[paper](https://arxiv.org/abs/1408.5882)] 7 | 2. Character-level CNN [[paper](https://arxiv.org/abs/1509.01626)] 8 | 3. Very Deep CNN [[paper](https://arxiv.org/abs/1606.01781)] 9 | 4. Word-level Bidirectional RNN 10 | 5. Attention-Based Bidirectional RNN [[paper](http://www.aclweb.org/anthology/P16-2034)] 11 | 6. RCNN [[paper](https://www.aaai.org/ocs/index.php/AAAI/AAAI15/paper/download/9745/9552)] 12 | 13 | Semi-supervised text classification(Transfer learning) models are implemented at [[dongjun-Lee/transfer-learning-text-tf]](https://github.com/dongjun-Lee/transfer-learning-text-tf). 14 | 15 | 16 | ## Requirements 17 | - Python3 18 | - Tensorflow 19 | - pip install -r requirements.txt 20 | 21 | 22 | ## Usage 23 | 24 | ### Train 25 | To train classification models for dbpedia dataset, 26 | ``` 27 | $ python train.py --model="" 28 | ``` 29 | (\: word_cnn | char_cnn | vd_cnn | word_rnn | att_rnn | rcnn) 30 | 31 | ### Test 32 | To test classification accuracy for test data after training, 33 | ``` 34 | $ python test.py --model="" 35 | ``` 36 | 37 | 38 | ### Sample Test Results 39 | Trained and tested with dbpedia dataset. (```dbpedia_csv/train.csv```, ```dbpedia_csv/test.csv```) 40 | 41 | Model | WordCNN | CharCNN | VDCNN | WordRNN | AttentionRNN | RCNN | *SA-LSTM | *LM-LSTM | 42 | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: | 43 | Accuracy | 98.42% | 98.05% | 97.60% | 98.57% | 98.61% | 98.68% | 98.88% | 98.86% | 44 | 45 | (SA-LSTM and LM-LSTM are implemented at [[dongjun-Lee/transfer-learning-text-tf]](https://github.com/dongjun-Lee/transfer-learning-text-tf).) 46 | 47 | 48 | ## Models 49 | 50 | ### 1. Word-level CNN 51 | Implementation of [Convolutional Neural Networks for Sentence Classification](https://arxiv.org/abs/1408.5882). 52 | 53 | 54 | 55 | 56 | ### 2. Char-level CNN 57 | Implementation of [Character-level Convolutional Networks for Text Classification](https://arxiv.org/abs/1509.01626). 58 | 59 | 60 | 61 | 62 | ### 3. Very Deep CNN (VDCNN) 63 | Implementation of [Very Deep Convolutional Networks for Text Classification](https://arxiv.org/abs/1606.01781). 64 | 65 | 66 | 67 | ### 4. Word-level Bi-RNN 68 | Bi-directional RNN for Text Classification. 69 | 70 | 1. Embedding layer 71 | 2. Bidirectional RNN layer 72 | 3. Concat all the outputs from RNN layer 73 | 4. Fully-connected layer 74 | 75 | 76 | ### 5. Attention-Based Bi-RNN 77 | Implementation of [Attention-Based Bidirectional Long Short-Term Memory Networks for Relation Classification](http://www.aclweb.org/anthology/P16-2034). 78 | 79 | 80 | 81 | ### 6. RCNN 82 | Implementation of [Recurrent Convolutional Neural Networks for Text Classification](https://www.aaai.org/ocs/index.php/AAAI/AAAI15/paper/download/9745/9552). 83 | 84 | 85 | 86 | ## References 87 | - [dennybritz/cnn-text-classification-tf](https://github.com/dennybritz/cnn-text-classification-tf) 88 | - [zonetrooper32/VDCNN](https://github.com/zonetrooper32/VDCNN) 89 | -------------------------------------------------------------------------------- /cnn_models/char_cnn.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | 3 | 4 | class CharCNN(object): 5 | def __init__(self, alphabet_size, document_max_len, num_class): 6 | self.learning_rate = 1e-3 7 | self.filter_sizes = [7, 7, 3, 3, 3, 3] 8 | self.num_filters = 256 9 | self.kernel_initializer = tf.truncated_normal_initializer(stddev=0.05) 10 | 11 | self.x = tf.placeholder(tf.int32, [None, document_max_len], name="x") 12 | self.y = tf.placeholder(tf.int32, [None], name="y") 13 | self.is_training = tf.placeholder(tf.bool, [], name="is_training") 14 | self.global_step = tf.Variable(0, trainable=False) 15 | self.keep_prob = tf.where(self.is_training, 0.5, 1.0) 16 | 17 | self.x_one_hot = tf.one_hot(self.x, alphabet_size) 18 | self.x_expanded = tf.expand_dims(self.x_one_hot, -1) 19 | 20 | # ============= Convolutional Layers ============= 21 | with tf.name_scope("conv-maxpool-1"): 22 | conv1 = tf.layers.conv2d( 23 | self.x_expanded, 24 | filters=self.num_filters, 25 | kernel_size=[self.filter_sizes[0], alphabet_size], 26 | kernel_initializer=self.kernel_initializer, 27 | activation=tf.nn.relu) 28 | pool1 = tf.layers.max_pooling2d( 29 | conv1, 30 | pool_size=(3, 1), 31 | strides=(3, 1)) 32 | pool1 = tf.transpose(pool1, [0, 1, 3, 2]) 33 | 34 | with tf.name_scope("conv-maxpool-2"): 35 | conv2 = tf.layers.conv2d( 36 | pool1, 37 | filters=self.num_filters, 38 | kernel_size=[self.filter_sizes[1], self.num_filters], 39 | kernel_initializer=self.kernel_initializer, 40 | activation=tf.nn.relu) 41 | pool2 = tf.layers.max_pooling2d( 42 | conv2, 43 | pool_size=(3, 1), 44 | strides=(3, 1)) 45 | pool2 = tf.transpose(pool2, [0, 1, 3, 2]) 46 | 47 | with tf.name_scope("conv-3"): 48 | conv3 = tf.layers.conv2d( 49 | pool2, 50 | filters=self.num_filters, 51 | kernel_size=[self.filter_sizes[2], self.num_filters], 52 | kernel_initializer=self.kernel_initializer, 53 | activation=tf.nn.relu) 54 | conv3 = tf.transpose(conv3, [0, 1, 3, 2]) 55 | 56 | with tf.name_scope("conv-4"): 57 | conv4 = tf.layers.conv2d( 58 | conv3, 59 | filters=self.num_filters, 60 | kernel_size=[self.filter_sizes[3], self.num_filters], 61 | kernel_initializer=self.kernel_initializer, 62 | activation=tf.nn.relu) 63 | conv4 = tf.transpose(conv4, [0, 1, 3, 2]) 64 | 65 | with tf.name_scope("conv-5"): 66 | conv5 = tf.layers.conv2d( 67 | conv4, 68 | filters=self.num_filters, 69 | kernel_size=[self.filter_sizes[4], self.num_filters], 70 | kernel_initializer=self.kernel_initializer, 71 | activation=tf.nn.relu) 72 | conv5 = tf.transpose(conv5, [0, 1, 3, 2]) 73 | 74 | with tf.name_scope("conv-maxpool-6"): 75 | conv6 = tf.layers.conv2d( 76 | conv5, 77 | filters=self.num_filters, 78 | kernel_size=[self.filter_sizes[5], self.num_filters], 79 | kernel_initializer=self.kernel_initializer, 80 | activation=tf.nn.relu) 81 | pool6 = tf.layers.max_pooling2d( 82 | conv6, 83 | pool_size=(3, 1), 84 | strides=(3, 1)) 85 | pool6 = tf.transpose(pool6, [0, 2, 1, 3]) 86 | h_pool = tf.reshape(pool6, [-1, 34 * self.num_filters]) 87 | 88 | # ============= Fully Connected Layers ============= 89 | with tf.name_scope("fc-1"): 90 | fc1_out = tf.layers.dense(h_pool, 1024, activation=tf.nn.relu, kernel_initializer=self.kernel_initializer) 91 | 92 | with tf.name_scope("fc-2"): 93 | fc2_out = tf.layers.dense(fc1_out, 1024, activation=tf.nn.relu, kernel_initializer=self.kernel_initializer) 94 | 95 | with tf.name_scope("fc-3"): 96 | self.logits = tf.layers.dense(fc2_out, num_class, activation=None, kernel_initializer=self.kernel_initializer) 97 | self.predictions = tf.argmax(self.logits, -1, output_type=tf.int32) 98 | 99 | # ============= Loss and Accuracy ============= 100 | with tf.name_scope("loss"): 101 | self.y_one_hot = tf.one_hot(self.y, num_class) 102 | self.loss = tf.reduce_mean( 103 | tf.nn.softmax_cross_entropy_with_logits_v2(logits=self.logits, labels=self.y_one_hot)) 104 | self.optimizer = tf.train.AdamOptimizer(self.learning_rate).minimize(self.loss, global_step=self.global_step) 105 | 106 | with tf.name_scope("accuracy"): 107 | correct_predictions = tf.equal(self.predictions, self.y) 108 | self.accuracy = tf.reduce_mean(tf.cast(correct_predictions, "float"), name="accuracy") 109 | -------------------------------------------------------------------------------- /cnn_models/vd_cnn.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | 3 | 4 | class VDCNN(object): 5 | def __init__(self, alphabet_size, document_max_len, num_class): 6 | self.embedding_size = 16 7 | self.filter_sizes = [3, 3, 3, 3, 3] 8 | self.num_filters = [64, 64, 128, 256, 512] 9 | self.num_blocks = [2, 2, 2, 2] 10 | self.learning_rate = 1e-3 11 | self.cnn_initializer = tf.keras.initializers.he_normal() 12 | self.fc_initializer = tf.truncated_normal_initializer(stddev=0.05) 13 | 14 | self.x = tf.placeholder(tf.int32, [None, document_max_len], name="x") 15 | self.y = tf.placeholder(tf.int32, [None], name="y") 16 | self.is_training = tf.placeholder(tf.bool, [], name="is_training") 17 | self.global_step = tf.Variable(0, trainable=False) 18 | 19 | # ============= Embedding Layer ============= 20 | with tf.name_scope("embedding"): 21 | init_embeddings = tf.random_uniform([alphabet_size, self.embedding_size], -1.0, 1.0) 22 | self.embeddings = tf.get_variable("embeddings", initializer=init_embeddings) 23 | x_emb = tf.nn.embedding_lookup(self.embeddings, self.x) 24 | self.x_expanded = tf.expand_dims(x_emb, -1) 25 | 26 | # ============= First Convolution Layer ============= 27 | with tf.variable_scope("conv-0"): 28 | conv0 = tf.layers.conv2d( 29 | self.x_expanded, 30 | filters=self.num_filters[0], 31 | kernel_size=[self.filter_sizes[0], self.embedding_size], 32 | kernel_initializer=self.cnn_initializer, 33 | activation=tf.nn.relu) 34 | conv0 = tf.transpose(conv0, [0, 1, 3, 2]) 35 | 36 | # ============= Convolution Blocks ============= 37 | with tf.name_scope("conv-block-1"): 38 | conv1 = self.conv_block(conv0, 1) 39 | 40 | with tf.name_scope("conv-block-2"): 41 | conv2 = self.conv_block(conv1, 2) 42 | 43 | with tf.name_scope("conv-block-3"): 44 | conv3 = self.conv_block(conv2, 3) 45 | 46 | with tf.name_scope("conv-block-4"): 47 | conv4 = self.conv_block(conv3, 4, max_pool=False) 48 | 49 | # ============= k-max Pooling ============= 50 | with tf.name_scope("k-max-pooling"): 51 | h = tf.transpose(tf.squeeze(conv4, -1), [0, 2, 1]) 52 | top_k = tf.nn.top_k(h, k=8, sorted=False).values 53 | h_flat = tf.reshape(top_k, [-1, 512 * 8]) 54 | 55 | # ============= Fully Connected Layers ============= 56 | with tf.name_scope("fc-1"): 57 | fc1_out = tf.layers.dense(h_flat, 2048, activation=tf.nn.relu, kernel_initializer=self.fc_initializer) 58 | 59 | with tf.name_scope("fc-2"): 60 | fc2_out = tf.layers.dense(fc1_out, 2048, activation=tf.nn.relu, kernel_initializer=self.fc_initializer) 61 | 62 | with tf.name_scope("fc-3"): 63 | self.logits = tf.layers.dense(fc2_out, num_class, activation=None, kernel_initializer=self.fc_initializer) 64 | self.predictions = tf.argmax(self.logits, -1, output_type=tf.int32) 65 | 66 | # ============= Loss and Accuracy ============= 67 | with tf.name_scope("loss"): 68 | y_one_hot = tf.one_hot(self.y, num_class) 69 | self.loss = tf.reduce_mean( 70 | tf.nn.softmax_cross_entropy_with_logits_v2(logits=self.logits, labels=y_one_hot)) 71 | 72 | update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS) 73 | with tf.control_dependencies(update_ops): 74 | self.optimizer = tf.train.AdamOptimizer(self.learning_rate).minimize(self.loss, global_step=self.global_step) 75 | 76 | with tf.name_scope("accuracy"): 77 | correct_predictions = tf.equal(self.predictions, self.y) 78 | self.accuracy = tf.reduce_mean(tf.cast(correct_predictions, "float"), name="accuracy") 79 | 80 | def conv_block(self, input, i, max_pool=True): 81 | with tf.variable_scope("conv-block-%s" % i): 82 | # Two "conv-batch_norm-relu" layers. 83 | for j in range(2): 84 | with tf.variable_scope("conv-%s" % j): 85 | # convolution 86 | conv = tf.layers.conv2d( 87 | input, 88 | filters=self.num_filters[i], 89 | kernel_size=[self.filter_sizes[i], self.num_filters[i-1]], 90 | kernel_initializer=self.cnn_initializer, 91 | activation=None) 92 | # batch normalization 93 | conv = tf.layers.batch_normalization(conv, training=self.is_training) 94 | # relu 95 | conv = tf.nn.relu(conv) 96 | conv = tf.transpose(conv, [0, 1, 3, 2]) 97 | 98 | if max_pool: 99 | # Max pooling 100 | pool = tf.layers.max_pooling2d( 101 | conv, 102 | pool_size=(3, 1), 103 | strides=(2, 1), 104 | padding="SAME") 105 | return pool 106 | else: 107 | return conv 108 | -------------------------------------------------------------------------------- /cnn_models/word_cnn.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | 3 | 4 | class WordCNN(object): 5 | def __init__(self, vocabulary_size, document_max_len, num_class): 6 | self.embedding_size = 128 7 | self.learning_rate = 1e-3 8 | self.filter_sizes = [3, 4, 5] 9 | self.num_filters = 100 10 | 11 | self.x = tf.placeholder(tf.int32, [None, document_max_len], name="x") 12 | self.y = tf.placeholder(tf.int32, [None], name="y") 13 | self.is_training = tf.placeholder(tf.bool, [], name="is_training") 14 | self.global_step = tf.Variable(0, trainable=False) 15 | self.keep_prob = tf.where(self.is_training, 0.5, 1.0) 16 | 17 | with tf.name_scope("embedding"): 18 | init_embeddings = tf.random_uniform([vocabulary_size, self.embedding_size]) 19 | self.embeddings = tf.get_variable("embeddings", initializer=init_embeddings) 20 | self.x_emb = tf.nn.embedding_lookup(self.embeddings, self.x) 21 | self.x_emb = tf.expand_dims(self.x_emb, -1) 22 | 23 | pooled_outputs = [] 24 | for filter_size in self.filter_sizes: 25 | conv = tf.layers.conv2d( 26 | self.x_emb, 27 | filters=self.num_filters, 28 | kernel_size=[filter_size, self.embedding_size], 29 | strides=(1, 1), 30 | padding="VALID", 31 | activation=tf.nn.relu) 32 | pool = tf.layers.max_pooling2d( 33 | conv, 34 | pool_size=[document_max_len - filter_size + 1, 1], 35 | strides=(1, 1), 36 | padding="VALID") 37 | pooled_outputs.append(pool) 38 | 39 | h_pool = tf.concat(pooled_outputs, 3) 40 | h_pool_flat = tf.reshape(h_pool, [-1, self.num_filters * len(self.filter_sizes)]) 41 | 42 | with tf.name_scope("dropout"): 43 | h_drop = tf.nn.dropout(h_pool_flat, self.keep_prob) 44 | 45 | with tf.name_scope("output"): 46 | self.logits = tf.layers.dense(h_drop, num_class, activation=None) 47 | self.predictions = tf.argmax(self.logits, -1, output_type=tf.int32) 48 | 49 | with tf.name_scope("loss"): 50 | self.loss = tf.reduce_mean( 51 | tf.nn.sparse_softmax_cross_entropy_with_logits(logits=self.logits, labels=self.y)) 52 | self.optimizer = tf.train.AdamOptimizer(self.learning_rate).minimize(self.loss, global_step=self.global_step) 53 | 54 | with tf.name_scope("accuracy"): 55 | correct_predictions = tf.equal(self.predictions, self.y) 56 | self.accuracy = tf.reduce_mean(tf.cast(correct_predictions, "float"), name="accuracy") 57 | -------------------------------------------------------------------------------- /data_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import wget 3 | import tarfile 4 | import re 5 | from nltk.tokenize import word_tokenize 6 | import collections 7 | import pandas as pd 8 | import pickle 9 | import numpy as np 10 | 11 | TRAIN_PATH = "dbpedia_csv/train.csv" 12 | TEST_PATH = "dbpedia_csv/test.csv" 13 | 14 | 15 | def download_dbpedia(): 16 | dbpedia_url = 'https://github.com/le-scientifique/torchDatasets/raw/master/dbpedia_csv.tar.gz' 17 | 18 | wget.download(dbpedia_url) 19 | with tarfile.open("dbpedia_csv.tar.gz", "r:gz") as tar: 20 | tar.extractall() 21 | 22 | 23 | def clean_str(text): 24 | text = re.sub(r"[^A-Za-z0-9(),!?\'\`\"]", " ", text) 25 | text = re.sub(r"\s{2,}", " ", text) 26 | text = text.strip().lower() 27 | 28 | return text 29 | 30 | 31 | def build_word_dict(): 32 | if not os.path.exists("word_dict.pickle"): 33 | train_df = pd.read_csv(TRAIN_PATH, names=["class", "title", "content"]) 34 | contents = train_df["content"] 35 | 36 | words = list() 37 | for content in contents: 38 | for word in word_tokenize(clean_str(content)): 39 | words.append(word) 40 | 41 | word_counter = collections.Counter(words).most_common() 42 | word_dict = dict() 43 | word_dict[""] = 0 44 | word_dict[""] = 1 45 | word_dict[""] = 2 46 | for word, _ in word_counter: 47 | word_dict[word] = len(word_dict) 48 | 49 | with open("word_dict.pickle", "wb") as f: 50 | pickle.dump(word_dict, f) 51 | 52 | else: 53 | with open("word_dict.pickle", "rb") as f: 54 | word_dict = pickle.load(f) 55 | 56 | return word_dict 57 | 58 | 59 | def build_word_dataset(step, word_dict, document_max_len): 60 | if step == "train": 61 | df = pd.read_csv(TRAIN_PATH, names=["class", "title", "content"]) 62 | else: 63 | df = pd.read_csv(TEST_PATH, names=["class", "title", "content"]) 64 | 65 | # Shuffle dataframe 66 | df = df.sample(frac=1) 67 | x = list(map(lambda d: word_tokenize(clean_str(d)), df["content"])) 68 | x = list(map(lambda d: list(map(lambda w: word_dict.get(w, word_dict[""]), d)), x)) 69 | x = list(map(lambda d: d + [word_dict[""]], x)) 70 | x = list(map(lambda d: d[:document_max_len], x)) 71 | x = list(map(lambda d: d + (document_max_len - len(d)) * [word_dict[""]], x)) 72 | 73 | y = list(map(lambda d: d - 1, list(df["class"]))) 74 | 75 | return x, y 76 | 77 | 78 | def build_char_dataset(step, model, document_max_len): 79 | alphabet = "abcdefghijklmnopqrstuvwxyz0123456789-,;.!?:’'\"/|_#$%ˆ&*˜‘+=<>()[]{} " 80 | if step == "train": 81 | df = pd.read_csv(TRAIN_PATH, names=["class", "title", "content"]) 82 | else: 83 | df = pd.read_csv(TEST_PATH, names=["class", "title", "content"]) 84 | 85 | # Shuffle dataframe 86 | df = df.sample(frac=1) 87 | 88 | char_dict = dict() 89 | char_dict[""] = 0 90 | char_dict[""] = 1 91 | for c in alphabet: 92 | char_dict[c] = len(char_dict) 93 | 94 | alphabet_size = len(alphabet) + 2 95 | 96 | x = list(map(lambda content: list(map(lambda d: char_dict.get(d, char_dict[""]), content.lower())), df["content"])) 97 | x = list(map(lambda d: d[:document_max_len], x)) 98 | x = list(map(lambda d: d + (document_max_len - len(d)) * [char_dict[""]], x)) 99 | 100 | y = list(map(lambda d: d - 1, list(df["class"]))) 101 | 102 | return x, y, alphabet_size 103 | 104 | 105 | def batch_iter(inputs, outputs, batch_size, num_epochs): 106 | inputs = np.array(inputs) 107 | outputs = np.array(outputs) 108 | 109 | num_batches_per_epoch = (len(inputs) - 1) // batch_size + 1 110 | for epoch in range(num_epochs): 111 | for batch_num in range(num_batches_per_epoch): 112 | start_index = batch_num * batch_size 113 | end_index = min((batch_num + 1) * batch_size, len(inputs)) 114 | yield inputs[start_index:end_index], outputs[start_index:end_index] 115 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | gensim==3.3.0 2 | wget==3.2 3 | nltk==3.2.5 4 | scikit-learn==0.19.1 5 | -------------------------------------------------------------------------------- /rnn_models/attention_rnn.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | from tensorflow.contrib import rnn 3 | 4 | 5 | class AttentionRNN(object): 6 | def __init__(self, vocabulary_size, document_max_len, num_class): 7 | self.embedding_size = 256 8 | self.num_hidden = 256 9 | self.num_layers = 2 10 | self.learning_rate = 1e-3 11 | 12 | self.x = tf.placeholder(tf.int32, [None, document_max_len], name="x") 13 | self.x_len = tf.reduce_sum(tf.sign(self.x), 1) 14 | self.y = tf.placeholder(tf.int32, [None], name="y") 15 | self.is_training = tf.placeholder(tf.bool, [], name="is_training") 16 | self.global_step = tf.Variable(0, trainable=False) 17 | self.keep_prob = tf.where(self.is_training, 0.5, 1.0) 18 | 19 | with tf.name_scope("embedding"): 20 | init_embeddings = tf.random_uniform([vocabulary_size, self.embedding_size]) 21 | self.embeddings = tf.get_variable("embeddings", initializer=init_embeddings) 22 | self.x_emb = tf.nn.embedding_lookup(self.embeddings, self.x) 23 | 24 | with tf.name_scope("birnn"): 25 | fw_cells = [rnn.BasicLSTMCell(self.num_hidden) for _ in range(self.num_layers)] 26 | bw_cells = [rnn.BasicLSTMCell(self.num_hidden) for _ in range(self.num_layers)] 27 | fw_cells = [rnn.DropoutWrapper(cell, output_keep_prob=self.keep_prob) for cell in fw_cells] 28 | bw_cells = [rnn.DropoutWrapper(cell, output_keep_prob=self.keep_prob) for cell in bw_cells] 29 | 30 | self.rnn_outputs, _, _ = rnn.stack_bidirectional_dynamic_rnn( 31 | fw_cells, bw_cells, self.x_emb, sequence_length=self.x_len, dtype=tf.float32) 32 | 33 | with tf.name_scope("attention"): 34 | self.attention_score = tf.nn.softmax(tf.layers.dense(self.rnn_outputs, 1, activation=tf.nn.tanh), axis=1) 35 | self.attention_out = tf.squeeze( 36 | tf.matmul(tf.transpose(self.rnn_outputs, perm=[0, 2, 1]), self.attention_score), 37 | axis=-1) 38 | 39 | with tf.name_scope("output"): 40 | self.logits = tf.layers.dense(self.attention_out, num_class, activation=None) 41 | self.predictions = tf.argmax(self.logits, -1, output_type=tf.int32) 42 | 43 | with tf.name_scope("loss"): 44 | self.loss = tf.reduce_mean( 45 | tf.nn.sparse_softmax_cross_entropy_with_logits(logits=self.logits, labels=self.y)) 46 | self.optimizer = tf.train.AdamOptimizer(self.learning_rate).minimize(self.loss, global_step=self.global_step) 47 | 48 | with tf.name_scope("accuracy"): 49 | correct_predictions = tf.equal(self.predictions, self.y) 50 | self.accuracy = tf.reduce_mean(tf.cast(correct_predictions, "float"), name="accuracy") 51 | -------------------------------------------------------------------------------- /rnn_models/rcnn.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | from tensorflow.contrib import rnn 3 | 4 | 5 | class RCNN(object): 6 | def __init__(self, vocabulary_size, document_max_len, num_class): 7 | self.embedding_size = 256 8 | self.rnn_num_hidden = 256 9 | self.fc_num_hidden = 256 10 | self.learning_rate = 1e-3 11 | 12 | self.x = tf.placeholder(tf.int32, [None, document_max_len], name="x") 13 | self.x_len = tf.reduce_sum(tf.sign(self.x), 1) 14 | self.y = tf.placeholder(tf.int32, [None], name="y") 15 | self.is_training = tf.placeholder(tf.bool, [], name="is_training") 16 | self.global_step = tf.Variable(0, trainable=False) 17 | self.keep_prob = tf.where(self.is_training, 0.5, 1.0) 18 | 19 | with tf.name_scope("embedding"): 20 | init_embeddings = tf.random_uniform([vocabulary_size, self.embedding_size]) 21 | self.embeddings = tf.get_variable("embeddings", initializer=init_embeddings) 22 | self.x_emb = tf.nn.embedding_lookup(self.embeddings, self.x) 23 | 24 | with tf.name_scope("birnn"): 25 | fw_cell = rnn.BasicLSTMCell(self.rnn_num_hidden) 26 | bw_cell = rnn.BasicLSTMCell(self.rnn_num_hidden) 27 | fw_cell = rnn.DropoutWrapper(fw_cell, output_keep_prob=self.keep_prob) 28 | bw_cell = rnn.DropoutWrapper(bw_cell, output_keep_prob=self.keep_prob) 29 | 30 | rnn_outputs, _ = tf.nn.bidirectional_dynamic_rnn(fw_cell, bw_cell, self.x_emb, 31 | sequence_length=self.x_len, dtype=tf.float32) 32 | self.fw_output, self.bw_output = rnn_outputs 33 | 34 | with tf.name_scope("word-representation"): 35 | x = tf.concat([self.fw_output, self.x_emb, self.bw_output], axis=2) 36 | self.y2 = tf.layers.dense(x, self.fc_num_hidden, activation=tf.nn.tanh) 37 | 38 | with tf.name_scope("text-representation"): 39 | self.y3 = tf.reduce_max(self.y2, axis=1) 40 | 41 | with tf.name_scope("output"): 42 | self.logits = tf.layers.dense(self.y3, num_class) 43 | self.predictions = tf.argmax(self.logits, -1, output_type=tf.int32) 44 | 45 | with tf.name_scope("loss"): 46 | self.loss = tf.reduce_mean( 47 | tf.nn.sparse_softmax_cross_entropy_with_logits(logits=self.logits, labels=self.y)) 48 | self.optimizer = tf.train.AdamOptimizer(self.learning_rate).minimize(self.loss, global_step=self.global_step) 49 | 50 | with tf.name_scope("accuracy"): 51 | correct_predictions = tf.equal(self.predictions, self.y) 52 | self.accuracy = tf.reduce_mean(tf.cast(correct_predictions, "float"), name="accuracy") 53 | -------------------------------------------------------------------------------- /rnn_models/word_rnn.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | from tensorflow.contrib import rnn 3 | 4 | 5 | class WordRNN(object): 6 | def __init__(self, vocabulary_size, document_max_len, num_class): 7 | self.embedding_size = 256 8 | self.num_hidden = 256 9 | self.num_layers = 2 10 | self.learning_rate = 1e-3 11 | 12 | self.x = tf.placeholder(tf.int32, [None, document_max_len], name="x") 13 | self.x_len = tf.reduce_sum(tf.sign(self.x), 1) 14 | self.y = tf.placeholder(tf.int32, [None], name="y") 15 | self.is_training = tf.placeholder(tf.bool, [], name="is_training") 16 | self.global_step = tf.Variable(0, trainable=False) 17 | self.keep_prob = tf.where(self.is_training, 0.5, 1.0) 18 | 19 | with tf.name_scope("embedding"): 20 | init_embeddings = tf.random_uniform([vocabulary_size, self.embedding_size]) 21 | self.embeddings = tf.get_variable("embeddings", initializer=init_embeddings) 22 | self.x_emb = tf.nn.embedding_lookup(self.embeddings, self.x) 23 | 24 | with tf.name_scope("birnn"): 25 | fw_cells = [rnn.BasicLSTMCell(self.num_hidden) for _ in range(self.num_layers)] 26 | bw_cells = [rnn.BasicLSTMCell(self.num_hidden) for _ in range(self.num_layers)] 27 | fw_cells = [rnn.DropoutWrapper(cell, output_keep_prob=self.keep_prob) for cell in fw_cells] 28 | bw_cells = [rnn.DropoutWrapper(cell, output_keep_prob=self.keep_prob) for cell in bw_cells] 29 | 30 | rnn_outputs, _, _ = rnn.stack_bidirectional_dynamic_rnn( 31 | fw_cells, bw_cells, self.x_emb, sequence_length=self.x_len, dtype=tf.float32) 32 | rnn_outputs_flat = tf.reshape(rnn_outputs, [-1, document_max_len * self.num_hidden * 2]) 33 | 34 | with tf.name_scope("output"): 35 | self.logits = tf.layers.dense(rnn_outputs_flat, num_class, activation=None) 36 | self.predictions = tf.argmax(self.logits, -1, output_type=tf.int32) 37 | 38 | with tf.name_scope("loss"): 39 | self.loss = tf.reduce_mean( 40 | tf.nn.sparse_softmax_cross_entropy_with_logits(logits=self.logits, labels=self.y)) 41 | self.optimizer = tf.train.AdamOptimizer(self.learning_rate).minimize(self.loss, global_step=self.global_step) 42 | 43 | with tf.name_scope("accuracy"): 44 | correct_predictions = tf.equal(self.predictions, self.y) 45 | self.accuracy = tf.reduce_mean(tf.cast(correct_predictions, "float"), name="accuracy") 46 | -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import argparse 3 | from data_utils import * 4 | 5 | 6 | parser = argparse.ArgumentParser() 7 | parser.add_argument("--model", type=str, default="word_cnn", 8 | help="word_cnn | char_cnn | vd_cnn | word_rnn | att_rnn | rcnn") 9 | args = parser.parse_args() 10 | 11 | BATCH_SIZE = 128 12 | WORD_MAX_LEN = 100 13 | CHAR_MAX_LEN = 1014 14 | 15 | if args.model == "char_cnn": 16 | test_x, test_y, alphabet_size = build_char_dataset("test", "char_cnn", CHAR_MAX_LEN) 17 | elif args.model == "vd_cnn": 18 | test_x, test_y, alphabet_size = build_char_dataset("test", "vdcnn", CHAR_MAX_LEN) 19 | else: 20 | word_dict = build_word_dict() 21 | test_x, test_y = build_word_dataset("test", word_dict, WORD_MAX_LEN) 22 | 23 | checkpoint_file = tf.train.latest_checkpoint(args.model) 24 | graph = tf.Graph() 25 | with graph.as_default(): 26 | with tf.Session() as sess: 27 | saver = tf.train.import_meta_graph("{}.meta".format(checkpoint_file)) 28 | saver.restore(sess, checkpoint_file) 29 | 30 | x = graph.get_operation_by_name("x").outputs[0] 31 | y = graph.get_operation_by_name("y").outputs[0] 32 | is_training = graph.get_operation_by_name("is_training").outputs[0] 33 | accuracy = graph.get_operation_by_name("accuracy/accuracy").outputs[0] 34 | 35 | batches = batch_iter(test_x, test_y, BATCH_SIZE, 1) 36 | sum_accuracy, cnt = 0, 0 37 | for batch_x, batch_y in batches: 38 | feed_dict = { 39 | x: batch_x, 40 | y: batch_y, 41 | is_training: False 42 | } 43 | 44 | accuracy_out = sess.run(accuracy, feed_dict=feed_dict) 45 | sum_accuracy += accuracy_out 46 | cnt += 1 47 | 48 | print("Test Accuracy : {0}".format(sum_accuracy / cnt)) 49 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import tensorflow as tf 3 | from data_utils import * 4 | from sklearn.model_selection import train_test_split 5 | from cnn_models.word_cnn import WordCNN 6 | from cnn_models.char_cnn import CharCNN 7 | from cnn_models.vd_cnn import VDCNN 8 | from rnn_models.word_rnn import WordRNN 9 | from rnn_models.attention_rnn import AttentionRNN 10 | from rnn_models.rcnn import RCNN 11 | 12 | 13 | parser = argparse.ArgumentParser() 14 | parser.add_argument("--model", type=str, default="word_cnn", 15 | help="word_cnn | char_cnn | vd_cnn | word_rnn | att_rnn | rcnn") 16 | args = parser.parse_args() 17 | 18 | if not os.path.exists("dbpedia_csv"): 19 | print("Downloading dbpedia dataset...") 20 | download_dbpedia() 21 | 22 | NUM_CLASS = 14 23 | BATCH_SIZE = 64 24 | NUM_EPOCHS = 10 25 | WORD_MAX_LEN = 100 26 | CHAR_MAX_LEN = 1014 27 | 28 | print("Building dataset...") 29 | if args.model == "char_cnn": 30 | x, y, alphabet_size = build_char_dataset("train", "char_cnn", CHAR_MAX_LEN) 31 | elif args.model == "vd_cnn": 32 | x, y, alphabet_size = build_char_dataset("train", "vdcnn", CHAR_MAX_LEN) 33 | else: 34 | word_dict = build_word_dict() 35 | vocabulary_size = len(word_dict) 36 | x, y = build_word_dataset("train", word_dict, WORD_MAX_LEN) 37 | 38 | train_x, valid_x, train_y, valid_y = train_test_split(x, y, test_size=0.15) 39 | 40 | 41 | with tf.Session() as sess: 42 | if args.model == "word_cnn": 43 | model = WordCNN(vocabulary_size, WORD_MAX_LEN, NUM_CLASS) 44 | elif args.model == "char_cnn": 45 | model = CharCNN(alphabet_size, CHAR_MAX_LEN, NUM_CLASS) 46 | elif args.model == "vd_cnn": 47 | model = VDCNN(alphabet_size, CHAR_MAX_LEN, NUM_CLASS) 48 | elif args.model == "word_rnn": 49 | model = WordRNN(vocabulary_size, WORD_MAX_LEN, NUM_CLASS) 50 | elif args.model == "att_rnn": 51 | model = AttentionRNN(vocabulary_size, WORD_MAX_LEN, NUM_CLASS) 52 | elif args.model == "rcnn": 53 | model = RCNN(vocabulary_size, WORD_MAX_LEN, NUM_CLASS) 54 | else: 55 | raise NotImplementedError() 56 | 57 | sess.run(tf.global_variables_initializer()) 58 | saver = tf.train.Saver(tf.global_variables()) 59 | 60 | train_batches = batch_iter(train_x, train_y, BATCH_SIZE, NUM_EPOCHS) 61 | num_batches_per_epoch = (len(train_x) - 1) // BATCH_SIZE + 1 62 | max_accuracy = 0 63 | 64 | for x_batch, y_batch in train_batches: 65 | train_feed_dict = { 66 | model.x: x_batch, 67 | model.y: y_batch, 68 | model.is_training: True 69 | } 70 | 71 | _, step, loss = sess.run([model.optimizer, model.global_step, model.loss], feed_dict=train_feed_dict) 72 | 73 | if step % 100 == 0: 74 | print("step {0}: loss = {1}".format(step, loss)) 75 | 76 | if step % 2000 == 0: 77 | # Test accuracy with validation data for each epoch. 78 | valid_batches = batch_iter(valid_x, valid_y, BATCH_SIZE, 1) 79 | sum_accuracy, cnt = 0, 0 80 | 81 | for valid_x_batch, valid_y_batch in valid_batches: 82 | valid_feed_dict = { 83 | model.x: valid_x_batch, 84 | model.y: valid_y_batch, 85 | model.is_training: False 86 | } 87 | 88 | accuracy = sess.run(model.accuracy, feed_dict=valid_feed_dict) 89 | sum_accuracy += accuracy 90 | cnt += 1 91 | valid_accuracy = sum_accuracy / cnt 92 | 93 | print("\nValidation Accuracy = {1}\n".format(step // num_batches_per_epoch, sum_accuracy / cnt)) 94 | 95 | # Save model 96 | if valid_accuracy > max_accuracy: 97 | max_accuracy = valid_accuracy 98 | saver.save(sess, "{0}/{1}.ckpt".format(args.model, args.model), global_step=step) 99 | print("Model is saved.\n") 100 | --------------------------------------------------------------------------------