├── README.md
├── data_utils.py
├── mixture_utils.py
├── model_utils.py
├── models.py
├── process_imdb.py
├── process_other.py
├── process_yelp.py
├── run_me.py
└── run_se_ce_cae.py

/README.md:
--------------------------------------------------------------------------------
1 | # Word Cluster Embeddings
2 | Code for ["Smaller Text Classifiers with Discriminative Cluster Embeddings"](http://ttic.uchicago.edu/~mchen/papers/mchen+kgimpel.naacl18.pdf) (NAACL 2018)
3 | 
4 | ## Citation
5 | If you use our code, please cite:
6 | 
7 | ```
8 | @inproceedings{chen2018smaller,
9 |   author = {Mingda Chen and Kevin Gimpel},
10 |   booktitle = {North American Association for Computational Linguistics (NAACL)},
11 |   title = {Smaller Text Classifiers with Discriminative Cluster Embeddings},
12 |   year = {2018}
13 | }
14 | ```
15 | 
16 | ## Dependencies
17 | 
18 | - Python 3.5
19 | - TensorFlow 1.3
20 | - NLTK (for tokenizing the IMDB dataset)
21 | 
22 | ## Prepare Data
23 | 
24 | You can download AG News, DBpedia, Yelp Review Full, and Yelp Review Polarity from [here](http://goo.gl/JyCnZq) and the IMDB data from [here](http://ai.stanford.edu/~amaas/data/sentiment/). Then run the corresponding data processing script to generate the data files.
25 | 
26 | Note that in this code Yelp Review Full and Yelp Review Polarity are referred to as yelp-1 and yelp-2, respectively. If you want to use other names, please modify the code accordingly.
27 | 
28 | ## Training
29 | 
30 | Use `run_se_ce_cae.py` to train standard embeddings, cluster embeddings, or cluster adjustment embeddings. Use `run_me.py` to train mixture embeddings. (Example invocations are sketched at the end of this listing.)
31 | 
--------------------------------------------------------------------------------
/data_utils.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import pickle
3 | 
4 | import numpy as np
5 | 
6 | from collections import Counter
7 | 
8 | 
9 | def get_dict(data):
10 |     word_count = Counter()
11 |     for sent in data:
12 |         for word in sent:
13 |             word_count[word] += 1
14 |     return word_count
15 | 
16 | 
17 | def get_n_class(dataset):
18 |     if dataset.lower() == "yelp-1":
19 |         return 5
20 |     elif dataset.lower() == "yahoo":
21 |         return 10
22 |     elif dataset.lower() == "agnews":
23 |         return 4
24 |     elif dataset.lower() == "dbpedia":
25 |         return 14
26 |     else:
27 |         return 2
28 | 
29 | 
30 | def make_dict(
31 |         vocab, vocab_file, max_words=None, save_vocab=False):
32 |     if max_words is None:
33 |         max_words = len(vocab)
34 | 
35 |     ls = vocab.most_common(max_words)
36 | 
37 |     logging.info('#Words: %d -> %d' % (len(vocab), len(ls)))
38 |     for key in ls[:5]:
39 |         logging.info(key)
40 |     logging.info('...')
41 |     for key in ls[-5:]:
42 |         logging.info(key)
43 | 
44 |     vocab = {w[0]: index + 1 for (index, w) in enumerate(ls)}
45 |     if save_vocab:
46 |         logging.info("vocab saving to {}".format(vocab_file))
47 |         with open(vocab_file, "wb+") as vocab_fp:
48 |             pickle.dump(vocab, vocab_fp, protocol=-1)
49 |     vocab[""] = 0
50 |     return vocab
51 | 
52 | 
53 | def load_dict(vocab_file):
54 |     logging.info("loading vocabularies from " + vocab_file + " ...")
55 |     with open(vocab_file, "rb") as vocab_fp:
56 |         vocab = pickle.load(vocab_fp)
57 |     vocab[""] = 0
58 |     logging.info("vocab size: {}".format(len(vocab)))
59 |     return vocab
60 | 
61 | 
62 | def data_to_idx(data, vocab):
63 |     return [to_idxs(sent, vocab) for sent in data]
64 | 
65 | 
66 | def prepare_data(revs, vocab):
67 |     data = []
68 |     label = []
69 |     for rev in revs:
70 |         data.append(to_idxs(rev["text"], vocab))
71 |         label.append(rev["y"])
72 |     return np.asarray(data), np.asarray(label)
73 | 
74 | 
75 | def make_batch(revs, labels, batch_size, shuffle=True):
76 |     n = len(revs)
77 |     revs = np.asarray(revs)
78 |     labels = np.asarray(labels)
79 |     if shuffle:
80 |         perm = np.arange(n)
81 |         np.random.shuffle(perm)
82 |         revs = revs[perm]
83 |         labels = labels[perm]
84 |     idx_list = np.arange(0, n, batch_size)
85 |     batch_data = []
86 |     batch_label = []
87 |     for idx in idx_list:
88 |         batch_data.append(
89 |             revs[np.arange(idx, min(idx + batch_size, n))])
90 |         batch_label.append(
91 |             labels[np.arange(idx, min(idx + batch_size, n))])
92 |     return batch_data, batch_label
93 | 
94 | 
95 | def pad_seq(data):
96 |     data_len = [len(data_) for data_ in data]
97 |     max_len = np.max(data_len)
98 |     data_holder = np.zeros((len(data), max_len))
99 |     mask = np.zeros_like(data_holder)
100 | 
101 |     for i, data_ in enumerate(data):
102 |         data_holder[i, :len(data_)] = np.asarray(data_)
103 |         mask[i, :len(data_)] = 1
104 |     return data_holder, mask
105 | 
106 | 
107 | def show_data(seqs, inv_dict):
108 |     def inv_vocab(x):
109 |         return inv_dict[x]
110 |     tmp = ""
111 |     for seq in seqs:
112 |         if isinstance(seq, np.int32) or isinstance(seq, int):
113 |             tmp += inv_dict[seq] + ' '
114 |         else:
115 |             print(' '.join(list(map(inv_vocab, seq))))
116 |     if tmp != "":
117 |         print(tmp)
118 | 
119 | 
120 | # given a word, return its index (0 = unknown)
121 | def to_index(word, vocab):
122 |     return vocab.get(word, 0)
123 | 
124 | 
125 | # given a sequence of words, return their indices.
126 | def to_idxs(words, vocab):
127 |     idxs = [to_index(word, vocab) for word in words]
128 |     return idxs
129 | 
130 | 
131 | def cal_unk(sents):
132 |     unk_count = 0
133 |     total_count = 0
134 |     for sent in sents:
135 |         for w in sent:
136 |             if w == 0:
137 |                 unk_count += 1
138 |             total_count += 1
139 |     return unk_count, total_count, unk_count / total_count
140 | 
--------------------------------------------------------------------------------
/mixture_utils.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import pickle
3 | 
4 | import numpy as np
5 | 
6 | 
7 | def make_dict(
8 |         vocab, vocab_file, n_word_mask, max_words=None, save_vocab=False):
9 |     if max_words is None:
10 |         max_words = len(vocab)
11 |     ls = vocab.most_common(max_words)
12 | 
13 |     logging.info('#Words: %d -> %d' % (len(vocab), len(ls)))
14 |     for key in ls[:5]:
15 |         logging.info(key)
16 |     logging.info('...')
17 |     for key in ls[-5:]:
18 |         logging.info(key)
19 | 
20 |     word_to_keep = vocab.most_common(n_word_mask)
21 |     word_to_keep_vocab = [w[0] for w in word_to_keep]
22 |     logging.info("#unique embedding vectors: {}".format(len(word_to_keep)))
23 | 
24 |     vocab = {w[0]: index + 1 for (index, w) in enumerate(ls)}
25 |     vocab[""] = 0
26 |     mask = np.zeros(len(vocab)).astype("float32")
27 |     for w in word_to_keep_vocab:
28 |         mask[vocab[w]] = 1.
29 |     if save_vocab:
30 |         logging.info("vocab saving to {}".format(vocab_file))
31 |         with open(vocab_file, "wb+") as vocab_fp:
32 |             pickle.dump([vocab, mask], vocab_fp, protocol=-1)
33 |     return vocab, mask
34 | 
35 | 
36 | def load_dict(vocab_file):
37 |     logging.info("loading vocabularies from " + vocab_file + " ...")
38 |     with open(vocab_file, "rb") as vf:
39 |         vocab, mask = pickle.load(vf)
40 |     logging.info("#unique embedding vectors: {}".format(mask.sum()))
41 |     logging.info("vocab size: {}".format(len(vocab)))
42 |     return vocab, mask
43 | 
--------------------------------------------------------------------------------
/model_utils.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 | 
3 | """mostly borrowed from https://github.com/ericjang/gumbel-softmax/blob/master/Categorical%20VAE.ipynb"""
4 | 
5 | 
6 | def sample_gumbel(shape, eps=1e-20):
7 |     """Sample from Gumbel(0, 1)"""
8 |     U = tf.random_uniform(shape, minval=0, maxval=1)
9 |     return -tf.log(-tf.log(U + eps) + eps)
10 | 
11 | 
12 | def gumbel_softmax_sample(logits, temperature):
13 |     """ Draw a sample from the Gumbel-Softmax distribution"""
14 |     y = logits + sample_gumbel(tf.shape(logits))
15 |     return tf.nn.softmax(y / temperature)
16 | 
17 | 
18 | def gumbel_softmax(logits, temperature, hard=False):
19 |     """Sample from the Gumbel-Softmax distribution and optionally discretize.
20 |     Args:
21 |         logits: [batch_size, n_class] unnormalized log-probs
22 |         temperature: non-negative scalar
23 |         hard: if True, take argmax, but differentiate w.r.t. soft sample y
24 |     Returns:
25 |         [batch_size, n_class] sample from the Gumbel-Softmax distribution.
26 |         If hard=True, then the returned sample will be one-hot, otherwise it will
27 |         be a probability distribution that sums to 1 across classes.
28 |     """
29 |     y = gumbel_softmax_sample(logits, temperature)
30 |     if hard:
31 |         y_hard = tf.cast(
32 |             tf.equal(y, tf.reduce_max(y, 1, keep_dims=True)), y.dtype)
33 |         y = tf.stop_gradient(y_hard - y) + y
34 |     return y
35 | 
36 | 
37 | def softmax_with_temperature(logits, temperature, hard=False):
38 |     y = tf.nn.softmax(logits / temperature)
39 |     if hard:
40 |         y_hard = tf.cast(
41 |             tf.equal(y, tf.reduce_max(y, 1, keep_dims=True)), y.dtype)
42 |         y = tf.stop_gradient(y_hard - y) + y
43 |     return y
44 | 
--------------------------------------------------------------------------------
/models.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import os
3 | 
4 | import tensorflow as tf
5 | 
6 | from model_utils import gumbel_softmax, softmax_with_temperature
7 | 
8 | 
9 | class base_classifier:
10 |     def __init__(self, args):
11 |         # model configurations
12 |         self.hidden_size = args.hidden_size
13 |         self.embed_dim = args.embed_dim
14 |         self.n_embed = args.n_embed
15 |         self.n_class = args.n_class
16 |         self.bidir = args.bidir
17 |         self.proj = args.proj
18 | 
19 |         # training configurations
20 |         self.trian_emb = args.train_emb
21 |         self.grad_clip = args.grad_clip
22 |         self.lr = args.learning_rate
23 |         self.opt = args.opt
24 |         self.l2 = args.l2
25 | 
26 |         if args.rnn_type.lower() == 'lstm':
27 |             self.cell = tf.contrib.rnn.BasicLSTMCell
28 |         elif args.rnn_type.lower() == 'gru':
29 |             self.cell = tf.contrib.rnn.GRUCell
30 |         elif args.rnn_type.lower() == 'rnn':
31 |             self.cell = tf.contrib.rnn.core_rnn_cell.BasicRNNCell
32 |         else:
33 |             raise NotImplementedError('Invalid rnn type: %s' % args.rnn_type)
34 | 
35 |     def _build_graph(self):
36 |         raise NotImplementedError()
37 | 
38 |     def _build_optimizer(self, loss):
39 |         if self.opt.lower() == 'sgd':
40 |             opt = tf.train.GradientDescentOptimizer(self.lr)
41 |         elif self.opt.lower() == 'adam':
42 |             opt = tf.train.AdamOptimizer(self.lr)
43 |         elif self.opt.lower() == 'rmsprop':
44 |             opt = tf.train.RMSPropOptimizer(self.lr)
45 |         else:
46 |             raise NotImplementedError("Invalid type of optimizer: {}"
47 |                                       .format(self.opt))
48 | 
49 |         vars_list = tf.trainable_variables()
50 |         regularizer = tf.contrib.layers.l2_regularizer(scale=self.l2)
51 |         reg_term = tf.contrib.layers.apply_regularization(
52 |             regularizer,
53 |             [v for v in vars_list if v.name.split(":")[0] != "gs_param"])
54 |         grads, _ = tf.clip_by_global_norm(
55 |             tf.gradients(loss + reg_term, vars_list), self.grad_clip)
56 |         self.updates = opt.apply_gradients(zip(grads, vars_list))
57 | 
58 |     def train(self, sess, inputs, mask, labels, temp):
59 |         feed_dict = {self.inputs: inputs, self.mask: mask,
60 |                      self.tau: temp, self.labels: labels}
61 |         _, loss = sess.run(
62 |             [self.updates, self.loss], feed_dict)
63 |         return loss
64 | 
65 |     def evaluate(self, sess, inputs, mask, labels, temp):
66 |         feed_dict = {self.inputs: inputs, self.mask: mask,
67 |                      self.tau: temp, self.labels: labels}
68 |         acc = sess.run(self.acc, feed_dict)
69 |         return acc
70 | 
71 |     def save(self, sess, saver, save_dir):
72 |         save_path = os.path.join(save_dir, 'model.ckpt')
73 |         saver.save(sess, save_path)
74 |         logging.info("model saved to {}".format(save_path))
75 | 
76 |     def restore(self, sess, save_dir):
77 |         """
78 |         restore model
79 |         """
80 |         save_path = os.path.join(save_dir, 'model.ckpt')
81 |         loader = tf.train.import_meta_graph(save_path + '.meta')
82 |         loader.restore(sess, save_path)
83 |         logging.info("model restored from {}".format(save_path))
84 | 
85 | 
86 | class sentiment_classifier(base_classifier):
87 |     def __init__(self, vocab_size, args):
88 |         super(sentiment_classifier, self).__init__(args)
89 |         self.vocab_size = vocab_size
90 |         self.aux_dim = args.aux_dim
91 |         self._build_graph()
92 |         self._build_optimizer(self.loss)
93 | 
94 |     def _build_graph(self):
95 |         self.tau = tf.placeholder(tf.float32, name="temperature")
96 |         self.inputs = tf.placeholder(tf.int32, [None, None], name="inputs")
97 |         self.labels = tf.placeholder(tf.int32, [None, ], name="labels")
98 |         self.mask = tf.placeholder(tf.int32, [None, None], name="mask")
99 | 
100 |         if self.proj.lower() == "gumbel":
101 |             self.embedding = tf.get_variable(
102 |                 "embedding", [self.n_embed, self.embed_dim],
103 |                 trainable=self.trian_emb,
104 |                 initializer=tf.random_uniform_initializer(
105 |                     minval=-0.1, maxval=0.1))
106 |             self.gs_param = tf.get_variable(
107 |                 "gs_param", [self.vocab_size, self.n_embed],
108 |                 initializer=tf.random_uniform_initializer(maxval=1))
109 | 
110 |             logits = tf.nn.embedding_lookup(self.gs_param, self.inputs)
111 |             batch_size, batch_len = \
112 |                 tf.shape(self.inputs)[0], tf.shape(self.inputs)[1]
113 |             embed_prob = gumbel_softmax(
114 |                 tf.reshape(logits, [batch_size * batch_len, -1]),
115 |                 self.tau, hard=False)
116 | 
117 |             inputs_embed = tf.matmul(embed_prob, self.embedding)
118 |             inputs_embed = tf.reshape(
119 |                 inputs_embed, [batch_size, batch_len, self.embed_dim])
120 | 
121 |             test_embed_prob = softmax_with_temperature(
122 |                 tf.reshape(logits, [batch_size * batch_len, -1]),
123 |                 self.tau, hard=True)
124 |             test_inputs_embed = tf.matmul(
125 |                 test_embed_prob, self.embedding)
126 |             test_inputs_embed = tf.reshape(
127 |                 test_inputs_embed, [batch_size, batch_len, self.embed_dim])
128 | 
129 |         elif self.proj.lower() == 
"standard": 130 | self.embedding = tf.get_variable( 131 | "embedding", [self.vocab_size, self.embed_dim], 132 | trainable=self.trian_emb, 133 | initializer=tf.random_uniform_initializer( 134 | minval=-0.1, maxval=0.1)) 135 | inputs_embed = tf.nn.embedding_lookup(self.embedding, self.inputs) 136 | test_inputs_embed = inputs_embed 137 | else: 138 | raise NotImplementedError( 139 | "invalid projection type: {}".format(self.proj)) 140 | 141 | if self.aux_dim: 142 | self.aux_embedding = tf.get_variable( 143 | "aux_embedding", [self.vocab_size, self.aux_dim], 144 | trainable=self.trian_emb, 145 | initializer=tf.random_uniform_initializer( 146 | minval=-0.1, maxval=0.1)) 147 | aux_inputs_emb = \ 148 | tf.nn.embedding_lookup(self.aux_embedding, self.inputs) 149 | 150 | inputs_embed = tf.concat([inputs_embed, aux_inputs_emb], axis=-1) 151 | test_inputs_embed = tf.concat( 152 | [test_inputs_embed, aux_inputs_emb], axis=-1) 153 | 154 | cell = self.cell(self.hidden_size) 155 | seq_length = tf.reduce_sum(self.mask, axis=1) 156 | 157 | with tf.variable_scope("dynamic_rnn") as scope: 158 | self.states, final_state = \ 159 | tf.nn.dynamic_rnn( 160 | cell=cell, 161 | inputs=inputs_embed, 162 | sequence_length=seq_length, 163 | dtype=tf.float32, 164 | scope="dynamic_rnn") 165 | if type(final_state) is tf.nn.rnn_cell.LSTMStateTuple: 166 | self.final_state = final_state.h 167 | else: 168 | self.final_state = final_state 169 | scope.reuse_variables() 170 | self.test_states, final_state = \ 171 | tf.nn.dynamic_rnn( 172 | cell=cell, 173 | inputs=test_inputs_embed, 174 | sequence_length=seq_length, 175 | dtype=tf.float32, 176 | scope="dynamic_rnn") 177 | if type(final_state) is tf.nn.rnn_cell.LSTMStateTuple: 178 | self.test_final_state = final_state.h 179 | else: 180 | self.test_final_state = final_state 181 | 182 | prob = tf.layers.dense( 183 | self.final_state, self.n_class, 184 | activation=None, name="prob") 185 | log_y_given_h = tf.nn.sparse_softmax_cross_entropy_with_logits( 186 | labels=self.labels, 187 | logits=prob, 188 | name="cross_entropy") 189 | test_prob = tf.layers.dense( 190 | self.test_final_state, self.n_class, 191 | activation=None, name="prob", reuse=True) 192 | self.loss = tf.reduce_mean(log_y_given_h) 193 | 194 | self.pred = tf.cast(tf.argmax(test_prob, axis=1), tf.int32) 195 | self.acc = tf.reduce_sum( 196 | tf.cast(tf.equal(self.labels, self.pred), tf.float32)) 197 | 198 | 199 | class mixture_classifier(base_classifier): 200 | def __init__(self, vocab_size, vocab_mask, args): 201 | super(mixture_classifier, self).__init__(args) 202 | self.vocab_size = vocab_size 203 | self.vocab_mask = vocab_mask 204 | 205 | self._build_graph() 206 | self._build_optimizer(self.loss) 207 | 208 | def _build_graph(self): 209 | self.tau = tf.placeholder(tf.float32, name="temperature") 210 | self.inputs = tf.placeholder(tf.int32, [None, None], name="inputs") 211 | self.labels = tf.placeholder(tf.int32, [None, ], name="labels") 212 | self.mask = tf.placeholder(tf.int32, [None, None], name="mask") 213 | self.vocab_mask = tf.Variable( 214 | self.vocab_mask, trainable=False, name="vocab_mask") 215 | 216 | self.embedding = tf.get_variable( 217 | "embedding", [self.vocab_size, self.embed_dim], 218 | trainable=self.trian_emb, 219 | initializer=tf.random_uniform_initializer( 220 | minval=-0.1, maxval=0.1)) 221 | 222 | if self.proj.lower() == "gumbel": 223 | self.cluster_embedding = tf.get_variable( 224 | "cluster_embedding", [self.n_embed, self.embed_dim], 225 | trainable=self.trian_emb, 226 | 
initializer=tf.random_uniform_initializer(
227 |                     minval=-0.1, maxval=0.1))
228 |             self.gs_param = tf.get_variable(
229 |                 "gs_param", [self.vocab_size, self.n_embed],
230 |                 initializer=tf.random_uniform_initializer(maxval=1))
231 | 
232 |             logits = tf.nn.embedding_lookup(self.gs_param, self.inputs)
233 |             batch_size, batch_len = \
234 |                 tf.shape(self.inputs)[0], tf.shape(self.inputs)[1]
235 | 
236 |             embed_prob = gumbel_softmax(
237 |                 tf.reshape(logits, [batch_size * batch_len, -1]),
238 |                 self.tau, hard=False)
239 |             inputs_embed = tf.matmul(embed_prob, self.cluster_embedding)
240 |             inputs_embed = tf.reshape(
241 |                 inputs_embed, [batch_size, batch_len, self.embed_dim])
242 | 
243 |             test_embed_prob = softmax_with_temperature(
244 |                 tf.reshape(logits, [batch_size * batch_len, -1]),
245 |                 self.tau, hard=True)
246 |             test_inputs_embed = tf.matmul(
247 |                 test_embed_prob, self.cluster_embedding)
248 |             test_inputs_embed = tf.reshape(
249 |                 test_inputs_embed, [batch_size, batch_len, self.embed_dim])
250 |         else:
251 |             raise NotImplementedError(
252 |                 "invalid projection type: {}".format(self.proj))
253 | 
254 |         v_mask = tf.nn.embedding_lookup(self.vocab_mask, self.inputs)
255 |         full_emb = tf.nn.embedding_lookup(self.embedding, self.inputs)
256 | 
257 |         v_mask = tf.expand_dims(v_mask, axis=-1)
258 | 
259 |         inputs_embed = full_emb * v_mask + inputs_embed * (1. - v_mask)
260 |         test_inputs_embed = \
261 |             full_emb * v_mask + test_inputs_embed * (1. - v_mask)
262 | 
263 |         cell = self.cell(self.hidden_size)
264 |         seq_length = tf.reduce_sum(self.mask, axis=1)
265 | 
266 |         with tf.variable_scope("dynamic_rnn") as scope:
267 |             self.states, final_state = \
268 |                 tf.nn.dynamic_rnn(
269 |                     cell=cell,
270 |                     inputs=inputs_embed,
271 |                     sequence_length=seq_length,
272 |                     dtype=tf.float32,
273 |                     scope="dynamic_rnn")
274 |             if type(final_state) is tf.nn.rnn_cell.LSTMStateTuple:
275 |                 self.final_state = final_state.h
276 |             else:
277 |                 self.final_state = final_state
278 | 
279 |             scope.reuse_variables()
280 |             self.test_states, final_state = \
281 |                 tf.nn.dynamic_rnn(
282 |                     cell=cell,
283 |                     inputs=test_inputs_embed,
284 |                     sequence_length=seq_length,
285 |                     dtype=tf.float32,
286 |                     scope="dynamic_rnn")
287 |             if type(final_state) is tf.nn.rnn_cell.LSTMStateTuple:
288 |                 self.test_final_state = final_state.h
289 |             else:
290 |                 self.test_final_state = final_state
291 | 
292 |         prob = tf.layers.dense(
293 |             self.final_state, self.n_class,
294 |             activation=None, name="prob")
295 |         log_y_given_h = tf.nn.sparse_softmax_cross_entropy_with_logits(
296 |             labels=self.labels,
297 |             logits=prob,
298 |             name="cross_entropy")
299 | 
300 |         test_prob = tf.layers.dense(
301 |             self.test_final_state, self.n_class,
302 |             activation=None, name="prob", reuse=True)
303 |         self.loss = tf.reduce_mean(log_y_given_h)
304 | 
305 |         self.pred = tf.cast(tf.argmax(test_prob, axis=1), tf.int32)
306 |         self.acc = tf.reduce_sum(
307 |             tf.cast(tf.equal(self.labels, self.pred), tf.float32))
308 | 
--------------------------------------------------------------------------------
/process_imdb.py:
--------------------------------------------------------------------------------
1 | import pickle
2 | import logging
3 | import argparse
4 | import nltk
5 | import os
6 | 
7 | from collections import Counter
8 | from sklearn.model_selection import train_test_split
9 | 
10 | 
11 | MAX_SENT_LEN = 400
12 | 
13 | 
14 | def str2bool(v):
15 |     return v.lower() in ('yes', 'true', 't', '1', 'y')
16 | 
17 | 
18 | def get_args():
19 |     parser = argparse.ArgumentParser(
20 |         description='Data Processing for \
21 | 
\"Smaller Text Classifiers with Discriminative Cluster Embeddings\"') 22 | parser.register('type', 'bool', str2bool) 23 | 24 | parser.add_argument('--type', type=str, default="imdb", 25 | help='data type: imdb') 26 | parser.add_argument('--path_train_pos', type=str, default=None, 27 | help='positive train data path') 28 | parser.add_argument('--path_train_neg', type=str, default=None, 29 | help='negative train data path') 30 | parser.add_argument('--path_test_pos', type=str, default=None, 31 | help='positive test data path') 32 | parser.add_argument('--path_test_neg', type=str, default=None, 33 | help='negative test data path') 34 | parser.add_argument('--dev_size', type=float, default=2000, 35 | help='dev set size') 36 | parser.add_argument('--train_size', type=float, default=1.0, 37 | help='train set ratio') 38 | args = parser.parse_args() 39 | return args 40 | 41 | 42 | def clean_review(line): 43 | return nltk.wordpunct_tokenize(line.strip())[: MAX_SENT_LEN] 44 | 45 | 46 | def process_text_files(file_dir, label): 47 | logging.info("loading data from {} ...".format(file_dir)) 48 | vocab = Counter() 49 | revs = [] 50 | for filename in os.listdir(file_dir): 51 | filepath = os.path.join(file_dir, filename) 52 | with open(filepath, 'r') as f: 53 | words = clean_review(f.read().lower()) 54 | for w in words: 55 | vocab[w] += 1 56 | datum = {"y": label, 57 | "text": words, 58 | "num_words": len(words)} 59 | revs.append(datum) 60 | return revs, vocab 61 | 62 | 63 | if __name__ == "__main__": 64 | logging.basicConfig(level=logging.DEBUG, 65 | format='%(asctime)s %(message)s', 66 | datefmt='%m-%d %H:%M') 67 | args = get_args() 68 | 69 | train_pos_revs, train_pos_vocab = \ 70 | process_text_files(args.path_train_pos, 1) 71 | train_neg_revs, train_neg_vocab = \ 72 | process_text_files(args.path_train_neg, 0) 73 | test_pos_revs, test_pos_vocab = \ 74 | process_text_files(args.path_test_pos, 1) 75 | test_neg_revs, test_neg_vocab = \ 76 | process_text_files(args.path_test_neg, 0) 77 | logging.info("data loaded!") 78 | 79 | train_pos_revs, val_pos_revs = \ 80 | train_test_split( 81 | train_pos_revs, test_size=args.dev_size // 2) 82 | train_neg_revs, val_neg_revs = \ 83 | train_test_split( 84 | train_neg_revs, test_size=args.dev_size // 2) 85 | 86 | all_val = val_pos_revs + val_neg_revs 87 | 88 | val_vocab = Counter() 89 | for d in all_val: 90 | for w in d["text"]: 91 | val_vocab[w] += 1 92 | if args.train_size != 1: 93 | train_vocab = Counter() 94 | _, train_pos_revs = \ 95 | train_test_split( 96 | train_pos_revs, 97 | test_size=args.train_size) 98 | for d in train_pos_revs: 99 | for w in d["text"]: 100 | train_vocab[w] += 1 101 | _, train_neg_revs = \ 102 | train_test_split( 103 | train_neg_revs, 104 | test_size=args.train_size) 105 | for d in train_neg_revs: 106 | for w in d["text"]: 107 | train_vocab[w] += 1 108 | vocab = {"train": train_vocab, 109 | "dev": val_vocab, 110 | "test": test_pos_vocab + test_neg_vocab} 111 | else: 112 | vocab = {"train": train_pos_vocab + train_neg_vocab - val_vocab, 113 | "dev": val_vocab, "test": test_pos_vocab + test_neg_vocab} 114 | 115 | revs = {"train": train_pos_revs + train_neg_revs, 116 | "dev": all_val, "test": test_pos_revs + test_neg_revs} 117 | 118 | for split in ["train", "dev", "test"]: 119 | if revs.get(split) is not None: 120 | logging.info(split + " " + "-" * 50) 121 | logging.info("number of sentences: " + str(len(revs[split]))) 122 | logging.info("vocab size: {}".format(len(vocab[split]))) 123 | logging.info("-" * 50) 124 | logging.info("total vocab 
size: {}".format( 125 | len(sum(vocab.values(), Counter())))) 126 | logging.info("total data size: {}".format(len(sum(revs.values(), [])))) 127 | save_path = args.type.lower() + str(args.train_size) + ".pkl" 128 | pickle.dump([revs, vocab], open(save_path, "wb+"), protocol=-1) 129 | logging.info("dataset saved to {}".format(save_path)) 130 | -------------------------------------------------------------------------------- /process_other.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import logging 3 | import pickle 4 | import csv 5 | import re 6 | import os 7 | 8 | from sklearn.model_selection import train_test_split 9 | from collections import Counter 10 | 11 | 12 | def str2bool(v): 13 | return v.lower() in ('yes', 'true', 't', '1', 'y') 14 | 15 | 16 | def get_args(): 17 | parser = argparse.ArgumentParser( 18 | description='Data Processing for \ 19 | \"Smaller Text Classifiers with Discriminative Cluster Embeddings\"') 20 | parser.register('type', 'bool', str2bool) 21 | 22 | parser.add_argument('--type', type=str, default="agnews", 23 | help='data type: agnews, dbpedia') 24 | parser.add_argument('--path', type=str, default=None, 25 | help='data path') 26 | parser.add_argument('--clean_str', type="bool", default=True, 27 | help='whether to tokenize data') 28 | args = parser.parse_args() 29 | return args 30 | 31 | 32 | def load_split_data(data_file, dataset, clean_str): 33 | logging.info("loading data from {} ...".format(data_file)) 34 | revs = [] 35 | vocab = Counter() 36 | with open(data_file, 'rt', encoding="utf8") as f: 37 | d = csv.reader(f, delimiter=",") 38 | for line in d: 39 | label_ = line[0] 40 | if len(line) == 2: 41 | data_ = line[1] 42 | elif len(line) > 2: 43 | data_ = "" 44 | for line_ in line[1:]: 45 | data_ += line_.strip(r'\"') 46 | else: 47 | raise ValueError("incorrect data length: {}".format(len(line))) 48 | label_ = int(label_.strip("\"")) - 1 49 | if clean_str: 50 | rev = [] 51 | rev.append(data_.strip(r'\"')) 52 | data_ = clean_string(" ".join(rev)) 53 | else: 54 | data_ = data_.strip(r'\"') 55 | for i, word in enumerate(data_.split(" ")): 56 | vocab[word] += 1 57 | datum = {"y": label_, 58 | "text": data_.split(" "), 59 | "num_words": len(data_.split(" "))} 60 | revs.append(datum) 61 | return revs, vocab 62 | 63 | 64 | def load_data(path, data_folder, dataset, clean_str): 65 | """ 66 | Loads data. 
67 | """ 68 | revs = {} 69 | vocabs = {} 70 | train_file = os.path.join(path, data_folder[0]) 71 | if data_folder[1] is not None: 72 | dev_file = os.path.join(path, data_folder[1]) 73 | else: 74 | dev_file = None 75 | test_file = os.path.join(path, data_folder[2]) 76 | 77 | revs_split, vocab_split = \ 78 | load_split_data(train_file, dataset, clean_str) 79 | revs["train"] = revs_split 80 | vocabs["train"] = vocab_split 81 | if dev_file is None: 82 | train_split, test_split = \ 83 | train_test_split(revs_split, test_size=5000) 84 | word_count = Counter() 85 | for data in test_split: 86 | for word in data["text"]: 87 | word_count[word] += 1 88 | vocabs["train"] = vocabs["train"] - word_count 89 | revs["train"] = train_split 90 | 91 | vocabs["dev"] = word_count 92 | revs["dev"] = test_split 93 | else: 94 | revs_split, vocab_split = \ 95 | load_split_data(dev_file, dataset, clean_str) 96 | revs["dev"] = revs_split 97 | vocabs["dev"] = vocab_split 98 | revs_split, vocab_split = \ 99 | load_split_data(test_file, dataset, clean_str) 100 | revs["test"] = revs_split 101 | vocabs["test"] = vocab_split 102 | return revs, vocabs 103 | 104 | 105 | def clean_string(string): 106 | """ 107 | Tokenization/string cleaning for yelp data set 108 | Based on https://github.com/yoonkim/CNN_sentence/blob/master/process_data.py 109 | """ 110 | string = re.sub(r"[^A-Za-z0-9(),!?\'\`\"]", " ", string) 111 | string = re.sub(r"\'s", " \'s", string) 112 | string = re.sub(r"\'ve", " \'ve", string) 113 | string = re.sub(r"n\'t", " n\'t", string) 114 | string = re.sub(r"\'re", " \'re", string) 115 | string = re.sub(r"\'d", " \'d", string) 116 | string = re.sub(r"\'ll", " \'ll", string) 117 | string = re.sub(r",", " , ", string) 118 | string = re.sub(r"!", " ! ", string) 119 | string = re.sub(r"\(", " \( ", string) 120 | string = re.sub(r"\"\"", " \" ", string) 121 | string = re.sub(r"\)", " \) ", string) 122 | string = re.sub(r"\?", " \? 
", string) 123 | string = re.sub(r"\s{2,}", " ", string) 124 | return string.lower() 125 | 126 | 127 | if __name__ == "__main__": 128 | logging.basicConfig(level=logging.DEBUG, 129 | format='%(asctime)s %(message)s', 130 | datefmt='%m-%d %H:%M') 131 | args = get_args() 132 | # train, dev, test 133 | data_folder = ["train.csv", 134 | None, 135 | "test.csv"] 136 | revs, vocab = \ 137 | load_data(args.path, data_folder, args.type, args.clean_str) 138 | logging.info("data loaded!") 139 | for split in ["train", "dev", "test"]: 140 | if revs.get(split) is not None: 141 | logging.info(split + " " + "-" * 50) 142 | logging.info("number of sentences: " + str(len(revs[split]))) 143 | logging.info("vocab size: {}".format(len(vocab[split]))) 144 | logging.info("-" * 50) 145 | logging.info("total vocab size: {}".format( 146 | len(sum(vocab.values(), Counter())))) 147 | logging.info("total data size: {}".format(len(sum(revs.values(), [])))) 148 | pickle.dump([revs, vocab], 149 | open(args.type.lower() + ".pkl", "wb+"), 150 | protocol=-1) 151 | logging.info("dataset created!") 152 | -------------------------------------------------------------------------------- /process_yelp.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import logging 3 | import pickle 4 | import re 5 | import os 6 | 7 | from sklearn.model_selection import train_test_split 8 | from collections import Counter 9 | 10 | 11 | def str2bool(v): 12 | return v.lower() in ('yes', 'true', 't', '1', 'y') 13 | 14 | 15 | def get_args(): 16 | parser = argparse.ArgumentParser( 17 | description='Data Processing for \ 18 | \"Smaller Text Classifiers with Discriminative Cluster Embeddings\"') 19 | parser.register('type', 'bool', str2bool) 20 | 21 | parser.add_argument('--type', type=str, default="yelp-1", 22 | help='data type: yelp-1, yelp-2') 23 | parser.add_argument('--path', type=str, default=None, 24 | help='data path') 25 | parser.add_argument('--clean_str', type="bool", default=True, 26 | help='whether to tokenize data') 27 | args = parser.parse_args() 28 | return args 29 | 30 | 31 | def load_split_data(data_file, dataset, clean_str): 32 | logging.info("loading data from {} ...".format(data_file)) 33 | revs = [] 34 | vocab = Counter() 35 | with open(data_file, "rb") as f: 36 | for line in f: 37 | label_, data_ = \ 38 | line.decode('unicode_escape').strip("\n").split(",", 1) 39 | label_ = int(label_.strip("\"")) - 1 40 | if clean_str: 41 | rev = [] 42 | rev.append(data_.strip(r'\"')) 43 | data_ = clean_string(" ".join(rev)) 44 | else: 45 | data_ = data_.strip(r'\"') 46 | for i, word in enumerate(data_.split(" ")): 47 | vocab[word] += 1 48 | datum = {"y": label_, 49 | "text": data_.split(" "), 50 | "num_words": len(data_.split(" "))} 51 | revs.append(datum) 52 | return revs, vocab 53 | 54 | 55 | def load_data(path, data_folder, dataset, clean_str): 56 | """ 57 | Loads data. 
58 | """ 59 | revs = {} 60 | vocabs = {} 61 | train_file = os.path.join(path, data_folder[0]) 62 | if data_folder[1] is not None: 63 | dev_file = os.path.join(path, data_folder[1]) 64 | else: 65 | dev_file = None 66 | test_file = os.path.join(path, data_folder[2]) 67 | 68 | revs_split, vocab_split = load_split_data(train_file, dataset, clean_str) 69 | revs["train"] = revs_split 70 | vocabs["train"] = vocab_split 71 | if dev_file is None: 72 | train_split, test_split = \ 73 | train_test_split(revs_split, test_size=5000) 74 | word_count = Counter() 75 | for data in test_split: 76 | for word in data["text"]: 77 | word_count[word] += 1 78 | vocabs["train"] = vocabs["train"] - word_count 79 | revs["train"] = train_split 80 | 81 | vocabs["dev"] = word_count 82 | revs["dev"] = test_split 83 | else: 84 | revs_split, vocab_split = \ 85 | load_split_data(dev_file, dataset, clean_str) 86 | revs["dev"] = revs_split 87 | vocabs["dev"] = vocab_split 88 | revs_split, vocab_split = load_split_data(test_file, dataset, clean_str) 89 | revs["test"] = revs_split 90 | vocabs["test"] = vocab_split 91 | return revs, vocabs 92 | 93 | 94 | def clean_string(string): 95 | """ 96 | Tokenization/string cleaning for yelp data set 97 | Based on https://github.com/yoonkim/CNN_sentence/blob/master/process_data.py 98 | """ 99 | string = re.sub(r"[^A-Za-z0-9(),!?\'\`\"]", " ", string) 100 | string = re.sub(r"\'s", " \'s", string) 101 | string = re.sub(r"\'ve", " \'ve", string) 102 | string = re.sub(r"n\'t", " n\'t", string) 103 | string = re.sub(r"\'re", " \'re", string) 104 | string = re.sub(r"\'d", " \'d", string) 105 | string = re.sub(r"\'ll", " \'ll", string) 106 | string = re.sub(r",", " , ", string) 107 | string = re.sub(r"!", " ! ", string) 108 | string = re.sub(r"\(", " \( ", string) 109 | string = re.sub(r"\"\"", " \" ", string) 110 | string = re.sub(r"\)", " \) ", string) 111 | string = re.sub(r"\?", " \? 
", string) 112 | string = re.sub(r"\s{2,}", " ", string) 113 | return string.lower() 114 | 115 | 116 | if __name__ == "__main__": 117 | logging.basicConfig(level=logging.DEBUG, 118 | format='%(asctime)s %(message)s', 119 | datefmt='%m-%d %H:%M') 120 | args = get_args() 121 | # train, dev, test 122 | if args.type.lower() == "yelp-1": 123 | data_folder = ["train.csv", 124 | None, 125 | "test.csv"] 126 | elif args.type.lower() == "yelp-2": 127 | data_folder = ["train.csv", 128 | None, 129 | "test.csv"] 130 | else: 131 | raise ValueError("invalid dataset type: {}".format(args.type)) 132 | revs, vocab = \ 133 | load_data(args.path, data_folder, args.type, args.clean_str) 134 | logging.info("data loaded!") 135 | for split in ["train", "dev", "test"]: 136 | if revs.get(split) is not None: 137 | logging.info(split + " " + "-" * 50) 138 | logging.info("number of sentences: " + str(len(revs[split]))) 139 | logging.info("vocab size: {}".format(len(vocab[split]))) 140 | logging.info("-" * 50) 141 | logging.info("total vocab size: {}".format( 142 | len(sum(vocab.values(), Counter())))) 143 | logging.info("total data size: {}".format(len(sum(revs.values(), [])))) 144 | pickle.dump([revs, vocab], 145 | open(args.type.lower() + ".pkl", "wb+"), 146 | protocol=-1) 147 | logging.info("dataset created!") 148 | -------------------------------------------------------------------------------- /run_me.py: -------------------------------------------------------------------------------- 1 | import mixture_utils 2 | import data_utils 3 | import argparse 4 | import logging 5 | import pickle 6 | import time 7 | import os 8 | 9 | import tensorflow as tf 10 | import numpy as np 11 | 12 | from models import mixture_classifier 13 | 14 | 15 | def str2bool(v): 16 | return v.lower() in ('yes', 'true', 't', '1', 'y') 17 | 18 | 19 | def get_args(): 20 | parser = argparse.ArgumentParser( 21 | description='TensorFlow implementation of \ 22 | \"Smaller Text Classifiers with Discriminative Cluster Embeddings\"') 23 | parser.register('type', 'bool', str2bool) 24 | # Basics 25 | parser.add_argument('--debug', type="bool", default=False, 26 | help='whether to activate debug mode (default: False)') 27 | parser.add_argument('--random_seed', type=int, default=0, 28 | help='Random seed (default: 0)') 29 | # Data file 30 | parser.add_argument('--dataset', type=str, default=None, 31 | help='Types of dataset: yelp-1, yelp-2, \ 32 | DBPedia, AGNews, IMDB (default: None)') 33 | parser.add_argument('--data_path', type=str, default=None, 34 | help='Data path (default: None)') 35 | parser.add_argument('--vocab_file', type=str, default=None, 36 | help='vocab file path (default: None)') 37 | parser.add_argument('--embed_file', type=str, default=None, 38 | help='embedding file path (default: None)') 39 | # model detail 40 | parser.add_argument('--embed_dim', type=int, default=100, 41 | help='embedding dimension (default: 100)') 42 | parser.add_argument('--n_embed', type=int, default=10, 43 | help='number of embedding vector (default: 10)') 44 | parser.add_argument('--hidden_size', type=int, default=50, 45 | help='hidden dimension of RNN (default: 50)') 46 | parser.add_argument('--learning_rate', type=float, default=1e-3, 47 | help='learning rate (default: 1e-3)') 48 | parser.add_argument('--opt', type=str, default='adam', 49 | help='types of optimizer: adam (default), \ 50 | sgd, rmsprop') 51 | parser.add_argument('--rnn_type', type=str, default='lstm', 52 | help='types of optimizer: lstm (default), \ 53 | gru, rnn') 54 | 
parser.add_argument('--bidir', type="bool", default=False, 55 | help='whether to use bidirectional \ 56 | (default: False)') 57 | # train detail 58 | parser.add_argument('--save_vocab', type="bool", default=True, 59 | help='whether to save vocabulary \ 60 | (default: True)') 61 | parser.add_argument('--train_emb', type="bool", default=True, 62 | help='whether to train embedding vectors \ 63 | (default: True)') 64 | parser.add_argument('--n_epoch', type=int, default=10, 65 | help='number of epochs (default: 10)') 66 | parser.add_argument('--batch_size', type=int, default=20, 67 | help='batch size (default: 20)') 68 | parser.add_argument('--max_words', type=int, default=50000, 69 | help='maximum number of words in vocabulary \ 70 | (default: 50000)') 71 | parser.add_argument('--keep_word', type=int, default=1000, 72 | help='number of words that use standard embeddings in vocabulary \ 73 | (default: 1000)') 74 | parser.add_argument('--proj', type=str, default="gumbel", 75 | help='types of embedding projection: gumbel (default)') 76 | parser.add_argument('--grad_clip', type=float, default=10., 77 | help='gradient clipping (default: 10)') 78 | parser.add_argument('--init_temp', type=float, default=0.9, 79 | help='initial temperature for gumbel softmax \ 80 | (default: 0.9)') 81 | parser.add_argument('--anneal_rate', type=float, default=0., 82 | help='annealing rate for temperature (default: 0)') 83 | parser.add_argument('--min_temp', type=float, default=0.9, 84 | help='minimum temperature (default: 0.9)') 85 | parser.add_argument('--l2', type=float, default=0., 86 | help='l2 regularizer (default: 0)') 87 | # misc 88 | parser.add_argument('--print_every', type=int, default=500, 89 | help='print training details after \ 90 | this number of iterations (default: 500)') 91 | parser.add_argument('--eval_every', type=int, default=5000, 92 | help='evaluate model after \ 93 | this number of iterations (default: 5000)') 94 | return parser.parse_args() 95 | 96 | 97 | def run(args): 98 | dp = os.path.join(args.data_path, args.dataset.lower() + ".pkl") 99 | logging.info("loading data from {} ...".format(dp)) 100 | with open(dp, "rb+") as infile: 101 | revs, vocabs = pickle.load(infile) 102 | # make vocab 103 | assert args.proj.lower() == "gumbel", "only gumbel is supported!" 
104 | if not os.path.isfile(args.vocab_file): 105 | vocab, vocab_mask = mixture_utils.make_dict( 106 | vocabs.get("train"), args.vocab_file, args.keep_word, 107 | args.max_words, args.save_vocab) 108 | else: 109 | vocab, vocab_mask = mixture_utils.load_dict(args.vocab_file) 110 | 111 | train_data, train_label = data_utils.prepare_data(revs["train"], vocab) 112 | test_data, test_label = data_utils.prepare_data(revs["test"], vocab) 113 | dev_data, dev_label = data_utils.prepare_data(revs["dev"], vocab) 114 | 115 | logging.info("#training data: {}".format(len(train_data))) 116 | logging.info("#dev data: {}".format(len(dev_data))) 117 | logging.info("#test data: {}".format(len(test_data))) 118 | 119 | logging.info("#unk words in train data: {}".format( 120 | data_utils.cal_unk(train_data))) 121 | logging.info("#unk words in dev data: {}".format( 122 | data_utils.cal_unk(dev_data))) 123 | logging.info("#unk words in test data: {}".format( 124 | data_utils.cal_unk(test_data))) 125 | 126 | logging.info("initializing model ...") 127 | model = mixture_classifier( 128 | vocab_size=len(vocab), 129 | vocab_mask=vocab_mask, 130 | args=args) 131 | init = tf.global_variables_initializer() 132 | saver = tf.train.Saver(tf.global_variables()) 133 | logging.info("model successfully initialized") 134 | 135 | test_d, test_l = data_utils.make_batch( 136 | test_data, test_label, args.batch_size) 137 | dev_d, dev_l = data_utils.make_batch( 138 | dev_data, dev_label, args.batch_size) 139 | logging.info("-" * 50) 140 | 141 | # training phase 142 | it = best_dev_pred = 0. 143 | logging.info("Training start ...") 144 | with tf.Session() as sess: 145 | sess.run(init) 146 | for epoch in range(args.n_epoch): 147 | train_d, train_l = data_utils.make_batch( 148 | train_data, train_label, args.batch_size) 149 | loss = n_example = 0 150 | start_time = time.time() 151 | for train_doc_, train_label_ in zip(train_d, train_l): 152 | train_doc_, train_mask_ = data_utils.pad_seq(train_doc_) 153 | 154 | if it % 500 == 0: 155 | temp = np.maximum( 156 | args.init_temp * np.exp(-args.anneal_rate * it), 157 | args.min_temp) 158 | 159 | loss_ = model.train( 160 | sess, train_doc_, train_mask_, train_label_, temp) 161 | loss += loss_ * len(train_doc_) 162 | n_example += len(train_doc_) 163 | it += 1 164 | 165 | if it % args.print_every == 0: 166 | logging.info("epoch: {}, it: {} (max: {}), " 167 | "loss: {:.5f}, " 168 | "time: {:.5f}(s), temp: {:.5f}" 169 | .format(epoch, it, len(train_d), 170 | loss / n_example, 171 | time.time() - start_time, 172 | temp)) 173 | loss = n_example = 0 174 | start_time = time.time() 175 | 176 | if it % args.eval_every == 0 or it % len(train_d) == 0: 177 | start_time = time.time() 178 | pred = 0 179 | n_dev = 0 180 | for dev_doc_, dev_label_ in zip(dev_d, dev_l): 181 | dev_doc_, dev_mask_ = data_utils.pad_seq(dev_doc_) 182 | pred_ = model.evaluate( 183 | sess, dev_doc_, dev_mask_, dev_label_, temp) 184 | pred += pred_ 185 | n_dev += len(dev_doc_) 186 | pred /= n_dev 187 | logging.info("Dev acc: {:.5f}, #pred: {}, " 188 | "elapsed time: {:.5f}(s)" 189 | .format(pred, n_dev, 190 | time.time() - start_time)) 191 | start_time = time.time() 192 | n_example = 0 193 | 194 | if best_dev_pred < pred: 195 | best_dev_pred = pred 196 | best_temp = temp 197 | model.save(sess, saver, args.save_dir) 198 | logging.info("Best dev acc: {:.5f}" 199 | .format(best_dev_pred)) 200 | start_time = time.time() 201 | 202 | logging.info("-" * 50) 203 | start_time = time.time() 204 | test_temp = temp 205 | 206 | if epoch == 
args.n_epoch - 1: 207 | logging.info("final testing ...") 208 | model.restore(sess, args.save_dir) 209 | test_temp = best_temp 210 | pred = n_test = 0 211 | 212 | for test_doc_, test_label_ in zip(test_d, test_l): 213 | test_doc_, test_mask_ = data_utils.pad_seq(test_doc_) 214 | pred_ = model.evaluate( 215 | sess, test_doc_, test_mask_, test_label_, test_temp) 216 | pred += pred_ 217 | n_test += len(test_doc_) 218 | pred /= n_test 219 | 220 | logging.info("-" * 50) 221 | logging.info("test acc: {:.5f}, #pred: {}, elapsed time: {:.5f}(s)" 222 | .format(pred, n_test, time.time() - start_time)) 223 | logging.info("best dev acc: {:.5f}, best temp: {}" 224 | .format(best_dev_pred, best_temp)) 225 | logging.info("vocab size: {}".format(len(vocab))) 226 | 227 | 228 | if __name__ == '__main__': 229 | args = get_args() 230 | np.random.seed(args.random_seed) 231 | tf.set_random_seed(args.random_seed) 232 | args.save_dir = "experiments" + "/" + args.dataset.lower() + "/" \ 233 | + args.opt.lower() + "_mix" + str(args.proj) \ 234 | + "_edim" + str(args.embed_dim) \ 235 | + "_nembed" + str(args.n_embed) + "_temb" + str(args.train_emb) \ 236 | + "_hsize" + str(args.hidden_size) \ 237 | + "_itemp" + str(args.init_temp) + "_min_temp" + str(args.min_temp) \ 238 | + "_anneal" + str(args.anneal_rate) + "_lr" + str(args.learning_rate) \ 239 | + "_epoch" + str(args.n_epoch) + "_l2" + str(args.l2) \ 240 | + "_kw" + str(args.keep_word) + "_words" + str(args.max_words) 241 | args.n_class = data_utils.get_n_class(args.dataset.lower()) 242 | if args.debug: 243 | args.save_dir = "./mix_exp" 244 | if not os.path.exists(args.save_dir): 245 | os.makedirs(args.save_dir) 246 | if args.debug: 247 | logging.basicConfig( 248 | level=logging.DEBUG, 249 | format='%(asctime)s %(message)s', 250 | datefmt='%m-%d %H:%M') 251 | else: 252 | log_file = os.path.join(args.save_dir, 'log') 253 | print("log saving to", log_file) 254 | logging.basicConfig(filename=log_file, 255 | filemode='w+', level=logging.INFO, 256 | format='%(asctime)s %(message)s', 257 | datefmt='%m-%d %H:%M') 258 | if args.dataset is None: 259 | raise ValueError('dataset is not specified.') 260 | if args.vocab_file is None: 261 | raise ValueError('vocab_file is not specified.') 262 | logging.info(args) 263 | run(args) 264 | -------------------------------------------------------------------------------- /run_se_ce_cae.py: -------------------------------------------------------------------------------- 1 | import data_utils 2 | import argparse 3 | import logging 4 | import pickle 5 | import time 6 | import os 7 | 8 | import tensorflow as tf 9 | import numpy as np 10 | 11 | from models import sentiment_classifier 12 | 13 | 14 | def str2bool(v): 15 | return v.lower() in ('yes', 'true', 't', '1', 'y') 16 | 17 | 18 | def get_args(): 19 | parser = argparse.ArgumentParser( 20 | description='TensorFlow implementation of \ 21 | \"Smaller Text Classifiers with Discriminative Cluster Embeddings\"') 22 | parser.register('type', 'bool', str2bool) 23 | # basics 24 | parser.add_argument('--debug', type="bool", default=False, 25 | help='whether to activate debug mode (default: False)') 26 | parser.add_argument('--random_seed', type=int, default=0, 27 | help='Random seed (default: 0)') 28 | # data file 29 | parser.add_argument('--dataset', type=str, default=None, 30 | help='Types of dataset: yelp-1, yelp-2, \ 31 | DBPedia, AGNews, IMDB') 32 | parser.add_argument('--data_path', type=str, default=None, 33 | help='Data path') 34 | parser.add_argument('--vocab_file', type=str, 
default=None, 35 | help='vocab file path') 36 | parser.add_argument('--embed_file', type=str, default=None, 37 | help='embedding file path') 38 | # model detail 39 | parser.add_argument('--embed_dim', type=int, default=100, 40 | help='embedding dimension (default: 100)') 41 | parser.add_argument('--aux_dim', type=int, default=0, 42 | help='dimension of auxiliary vector for \ 43 | each word (default: 0)') 44 | parser.add_argument('--n_embed', type=int, default=10, 45 | help='number of embedding vector (default: 10)') 46 | parser.add_argument('--hidden_size', type=int, default=50, 47 | help='hidden dimension of RNN (default: 50)') 48 | parser.add_argument('--learning_rate', type=float, default=1e-3, 49 | help='learning rate (default: 1e-3)') 50 | parser.add_argument('--opt', type=str, default='adam', 51 | help='types of optimizer: adam (default), \ 52 | sgd, rmsprop') 53 | parser.add_argument('--rnn_type', type=str, default='lstm', 54 | help='types of optimizer: lstm (default), \ 55 | gru, rnn') 56 | parser.add_argument('--bidir', type="bool", default=False, 57 | help='whether to use bidirectional RNN \ 58 | (default: False)') 59 | # train detail 60 | parser.add_argument('--save_vocab', type="bool", default=True, 61 | help='whether to save vocabulary \ 62 | (default: True)') 63 | parser.add_argument('--train_emb', type="bool", default=True, 64 | help='whether to train embedding vectors \ 65 | (default: True)') 66 | parser.add_argument('--n_epoch', type=int, default=10, 67 | help='number of epochs (default: 10)') 68 | parser.add_argument('--batch_size', type=int, default=20, 69 | help='batch size (default: 20)') 70 | parser.add_argument('--max_words', type=int, default=50000, 71 | help='maximum number of words in vocabulary \ 72 | (default: 50000)') 73 | parser.add_argument('--proj', type=str, default="standard", 74 | help='types of embedding projection: standard (default), \ 75 | gumbel') 76 | parser.add_argument('--grad_clip', type=float, default=10., 77 | help='gradient clipping (default: 10)') 78 | parser.add_argument('--init_temp', type=float, default=0.9, 79 | help='initial temperature for gumbel softmax \ 80 | (default: 0.9)') 81 | parser.add_argument('--anneal_rate', type=float, default=0.0, 82 | help='annealing rate for temperature (default: 0.0)') 83 | parser.add_argument('--min_temp', type=float, default=0.9, 84 | help='minimum temperature (default: 0.9)') 85 | parser.add_argument('--l2', type=float, default=0., 86 | help='l2 regularizer (default: 0)') 87 | # misc 88 | parser.add_argument('--print_every', type=int, default=500, 89 | help='print training details after \ 90 | this number of iterations (default: 500)') 91 | parser.add_argument('--eval_every', type=int, default=5000, 92 | help='evaluate model after \ 93 | this number of iterations (default: 5000)') 94 | return parser.parse_args() 95 | 96 | 97 | def run(args): 98 | dp = os.path.join(args.data_path, args.dataset.lower() + ".pkl") 99 | logging.info("loading data from {} ...".format(dp)) 100 | with open(dp, "rb+") as infile: 101 | revs, vocabs = pickle.load(infile) 102 | # make vocab 103 | if not os.path.isfile(args.vocab_file): 104 | vocab = data_utils.make_dict( 105 | vocabs.get("train"), args.vocab_file, 106 | args.max_words, args.save_vocab) 107 | else: 108 | vocab = data_utils.load_dict(args.vocab_file) 109 | 110 | train_data, train_label = data_utils.prepare_data(revs["train"], vocab) 111 | test_data, test_label = data_utils.prepare_data(revs["test"], vocab) 112 | dev_data, dev_label = 
data_utils.prepare_data(revs["dev"], vocab) 113 | 114 | logging.info("#training data: {}".format(len(train_data))) 115 | logging.info("#dev data: {}".format(len(dev_data))) 116 | logging.info("#test data: {}".format(len(test_data))) 117 | 118 | logging.info("initializing model ...") 119 | model = sentiment_classifier( 120 | vocab_size=len(vocab), 121 | args=args) 122 | init = tf.global_variables_initializer() 123 | saver = tf.train.Saver(tf.global_variables()) 124 | logging.info("model successfully initialized") 125 | 126 | logging.info("preparing data ...") 127 | logging.info("#unk words in train data: {}".format( 128 | data_utils.cal_unk(train_data))) 129 | logging.info("#unk words in dev data: {}".format( 130 | data_utils.cal_unk(dev_data))) 131 | logging.info("#unk words in test data: {}".format( 132 | data_utils.cal_unk(test_data))) 133 | test_d, test_l = data_utils.make_batch( 134 | test_data, test_label, args.batch_size) 135 | dev_d, dev_l = data_utils.make_batch( 136 | dev_data, dev_label, args.batch_size) 137 | logging.info("-" * 50) 138 | # training phase 139 | 140 | best_dev_pred = it = 0 141 | with tf.Session() as sess: 142 | sess.run(init) 143 | logging.info("Training start ...") 144 | for epoch in range(args.n_epoch): 145 | train_d, train_l = data_utils.make_batch( 146 | train_data, train_label, args.batch_size) 147 | loss = n_example = 0 148 | start_time = time.time() 149 | for train_doc_, train_label_ in zip(train_d, train_l): 150 | train_doc_, train_mask_ = data_utils.pad_seq(train_doc_) 151 | 152 | if it % 500 == 0: 153 | temp = np.maximum( 154 | args.init_temp * np.exp(-args.anneal_rate * it), 155 | args.min_temp) 156 | 157 | loss_ = model.train( 158 | sess, train_doc_, train_mask_, train_label_, temp) 159 | loss += loss_ * len(train_doc_) 160 | n_example += len(train_doc_) 161 | it += 1 162 | 163 | if it % args.print_every == 0: 164 | logging.info("epoch: {}, it: {} (max: {}), " 165 | "loss: {:.5f}, temp: {:.5f}, time: {:.1f}(s)" 166 | .format(epoch, it, len(train_d), 167 | loss / n_example, 168 | temp, time.time() - start_time)) 169 | loss = n_example = 0 170 | start_time = time.time() 171 | 172 | if it % args.eval_every == 0 or it % len(train_d) == 0: 173 | start_time = time.time() 174 | pred = n_dev = 0 175 | for dev_doc_, dev_label_ in zip(dev_d, dev_l): 176 | dev_doc_, dev_mask_ = data_utils.pad_seq(dev_doc_) 177 | pred_ = model.evaluate( 178 | sess, dev_doc_, dev_mask_, dev_label_, temp) 179 | pred += pred_ 180 | n_dev += len(dev_doc_) 181 | pred /= n_dev 182 | logging.info("Dev acc: {:.5f}, #pred: {}, " 183 | "elapsed time: {:.1f}(s)" 184 | .format(pred, n_dev, 185 | time.time() - start_time)) 186 | start_time = time.time() 187 | n_example = 0 188 | 189 | if best_dev_pred < pred: 190 | best_dev_pred = pred 191 | best_temp = temp 192 | model.save(sess, saver, args.save_dir) 193 | logging.info("Best dev acc: {:.5f}" 194 | .format(best_dev_pred)) 195 | start_time = time.time() 196 | 197 | logging.info("-" * 50) 198 | start_time = time.time() 199 | test_temp = temp 200 | 201 | if epoch == args.n_epoch - 1: 202 | logging.info("final testing ...") 203 | model.restore(sess, args.save_dir) 204 | test_temp = best_temp 205 | pred = n_test = 0 206 | 207 | for test_doc_, test_label_ in zip(test_d, test_l): 208 | test_doc_, test_mask_ = data_utils.pad_seq(test_doc_) 209 | pred_ = model.evaluate( 210 | sess, test_doc_, test_mask_, test_label_, test_temp) 211 | pred += pred_ 212 | n_test += len(test_doc_) 213 | pred /= n_test 214 | 215 | logging.info("test acc: {:.5f}, #pred: 
{}, elapsed time: {:.1f}(s)" 216 | .format(pred, n_test, time.time() - start_time)) 217 | logging.info("best dev acc: {:.5f}, best temp: {}" 218 | .format(best_dev_pred, best_temp)) 219 | logging.info("vocab size: {}".format(len(vocab))) 220 | 221 | 222 | if __name__ == '__main__': 223 | args = get_args() 224 | np.random.seed(args.random_seed) 225 | tf.set_random_seed(args.random_seed) 226 | args.save_dir = "experiments" + "/" + args.dataset.lower() + "/" \ 227 | + args.opt.lower() + "_ce" + str(args.proj) \ 228 | + "_edim" + str(args.embed_dim) \ 229 | + "_nembed" + str(args.n_embed) + "_temb" + str(args.train_emb) \ 230 | + "_hsize" + str(args.hidden_size) \ 231 | + "_itemp" + str(args.init_temp) + "_min_temp" + str(args.min_temp) \ 232 | + "_anneal" + str(args.anneal_rate) + "_lr" + str(args.learning_rate) \ 233 | + "_epoch" + str(args.n_epoch) + "_l2" + str(args.l2) \ 234 | + "_words" + str(args.max_words) 235 | args.n_class = data_utils.get_n_class(args.dataset.lower()) 236 | if args.debug: 237 | args.save_dir = "./" 238 | if not os.path.exists(args.save_dir): 239 | os.makedirs(args.save_dir) 240 | if args.debug: 241 | logging.basicConfig( 242 | level=logging.DEBUG, 243 | format='%(asctime)s %(message)s', 244 | datefmt='%m-%d %H:%M') 245 | else: 246 | log_file = os.path.join(args.save_dir, 'log') 247 | print("log saving to", log_file) 248 | logging.basicConfig(filename=log_file, 249 | filemode='w+', level=logging.INFO, 250 | format='%(asctime)s %(message)s', 251 | datefmt='%m-%d %H:%M') 252 | if args.dataset is None: 253 | raise ValueError('dataset is not specified.') 254 | if args.vocab_file is None: 255 | raise ValueError('vocab_file is not specified.') 256 | logging.info(args) 257 | run(args) 258 | --------------------------------------------------------------------------------
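Example invocations (referenced from the README's Training section above). This quick-start is a hedged sketch rather than commands taken from the repository: the `aclImdb/...` directories assume the layout of the Stanford IMDB download, `imdb1.0.pkl` is the name `process_imdb.py` writes with its defaults (`--type imdb --train_size 1.0`), and the vocab file names and flag values are illustrative choices.

```
# Tokenize the IMDB data; writes imdb1.0.pkl to the current directory.
python process_imdb.py \
    --path_train_pos aclImdb/train/pos --path_train_neg aclImdb/train/neg \
    --path_test_pos aclImdb/test/pos --path_test_neg aclImdb/test/neg

# Standard embeddings use --proj standard; cluster embeddings use --proj gumbel.
python run_se_ce_cae.py --dataset imdb1.0 --data_path . \
    --vocab_file imdb_vocab.pkl --proj gumbel --n_embed 10 --embed_dim 100

# Mixture embeddings: the --keep_word most frequent words keep standard embeddings,
# the rest fall back to cluster embeddings.
python run_me.py --dataset imdb1.0 --data_path . \
    --vocab_file imdb_mix_vocab.pkl --keep_word 1000
```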