├── README.md
├── dataset
│   └── readme.txt
├── get_label.py
├── model-img.png
├── model.py
├── output
│   └── readme.txt
├── sample-train.sh
└── sample.config


/README.md:
--------------------------------------------------------------------------------
# DAZER
The TensorFlow implementation of our ACL 2018 paper:
***A Deep Relevance Model for Zero-Shot Document Filtering, Chenliang Li, Wei Zhou, Feng Ji, Yu Duan, Haiqing Chen***
Paper url: http://aclweb.org/anthology/P18-1214

### Requirements
- Python 3.5
- Tensorflow 1.2
- Numpy
- Traitlets

### Guide To Use

**Prepare your dataset**: first, prepare your own data.
See [Data Preparation](#data-preparation).

**Configure**: then, configure the model through the config file. Configurable parameters are listed [here](#configurations).

See the example: [sample.config](https://github.com/WHUIR/DAZER/blob/master/sample.config)

In addition, you need to change the zero-shot label settings in [get_label.py](https://github.com/WHUIR/DAZER/blob/master/get_label.py).

(Make sure `get_label.py` and `model.py` are in the same directory, because `model.py` imports `get_label`.)

**Training**: pass the config file, training data and validation data as
```shell
python model.py <config-file> \
    --train \
    --train_file <path to training data> \
    --validation_file <path to validation data> \
    --checkpoint_dir <directory to store/load model checkpoints> \
    [--load_model]
```

`--checkpoint_dir` is concatenated directly with the checkpoint file name, so it should end with `/` (e.g. `output/`). `--load_model` is a flag: omit it to start with a new model, or add it to continue training from an existing checkpoint. The checkpoint file loaded in that case (`zsl25.ckpt`) is hard-coded in `model.py`, so change it to match your own checkpoint.

See example: [sample-train.sh](https://github.com/WHUIR/DAZER/blob/master/sample-train.sh)

**Testing**: pass the config file and testing data as
```shell
python model.py <config-file> \
    --test \
    --test_file <path to testing data> \
    --test_size <number of testing samples> \
    --checkpoint_dir <directory containing the trained checkpoints> \
    --output_score_file <prefix of the output score files>
```

The model is evaluated with every checkpoint saved during training (one every `BaseNN.checkpoint_steps` epochs, up to `BaseNN.max_epochs`). For each checkpoint, relevance scores are written to `<output_score_file>-epoch<N>.txt`, one score per line, in the same order as `test_file`. Samples in a final batch smaller than `BaseNN.batch_size` are skipped and get no score.


### Data Preparation


All seed words and documents must be mapped into sequences of integer term ids. Term ids start at 1 (id 0 is reserved for padding). Ids greater than or equal to `DataGenerator.vocabulary_size` are dropped when the data is read.
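The id mapping above can be sketched in a few lines of Python. This is a minimal illustration, not part of the repository: the helper names (`to_ids`, `ids_field`, `format_train_line`, `format_test_line`) and the toy vocabulary are hypothetical, and plain whitespace tokenization with unknown words dropped is an assumption (DAZER does not prescribe a tokenizer). For training and validation data, the seed-word field must reproduce exactly the comma-joined id sequence of one label in the label dict file, because `model.py` looks that string up to find the sample's label.

```python
from typing import Dict, List


def to_ids(text: str, word2id: Dict[str, int]) -> List[int]:
    """Map whitespace-separated tokens to term ids, dropping unknown words."""
    # Assumption: plain whitespace tokenization (hypothetical helper).
    return [word2id[w] for w in text.split() if w in word2id]


def ids_field(ids: List[int]) -> str:
    # Comma-separated ids with no spaces, since the reader splits on ','.
    return ",".join(str(i) for i in ids)


def format_train_line(seed_words: str, pos_doc: str, neg_doc: str,
                      word2id: Dict[str, int]) -> str:
    # Three tab-separated fields: seed words, positive doc, negative doc.
    fields = (seed_words, pos_doc, neg_doc)
    return "\t".join(ids_field(to_ids(f, word2id)) for f in fields)


def format_test_line(seed_words: str, doc: str,
                     word2id: Dict[str, int]) -> str:
    # Two tab-separated fields: seed words, document.
    return "\t".join(ids_field(to_ids(f, word2id)) for f in (seed_words, doc))


if __name__ == "__main__":
    # Toy vocabulary (hypothetical); ids start at 1 because 0 is padding.
    word2id = {"atheist": 1, "god": 2, "world": 3, "space": 4, "orbit": 5}
    print(repr(format_train_line("atheist god", "god world", "space orbit", word2id)))
    # '1,2\t2,3\t4,5'
    print(repr(format_test_line("atheist god", "space world", word2id)))
    # '1,2\t4,3'
```

A document whose words are all unknown yields an empty field, which the data reader cannot parse as integers, so such samples should be filtered out before writing the file.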

**Training Data Format**

Each training sample is a tuple of (seed words, positive document, negative document). Fields are separated by a single tab character (`\t`) with no surrounding spaces; ids within a field are separated by commas.

`seed_words\tpositive_document\tnegative_document`

Example: `334,453,768\t123,435,657,878,6,556\t443,554,534,3,67,8,12,2,7,9`

The seed-word field must be exactly the comma-joined id sequence of one label in the label dict file: `model.py` uses this string to look up the sample's training label for the adversarial classifier.


**Testing Data Format**

Each testing sample is a tuple of (seed words, document):

`seed_words\tdocument`

Example: `334,453,768\t123,435,657,878,6,556`


**Validation Data Format**

The format is the same as the training data format.


**Label Dict File Format**

Each line is a tuple of (label_name, seed_words), separated by `/`; the seed words are separated by single spaces:

`label_name/seed_words`

Example: `alt.atheism/atheist christian atheism god islamic`


**Word2id File Format**

Each line is a tuple of (word, id), separated by a single space (`get_label.py` reads this file with GBK encoding):

`word id`

Example: `world 123`


**Embedding File Format**

Each line is a tuple of (id, embedding), separated by whitespace; each vector must have `embedding_size` values. The first line of the file is skipped (treated as a header).

`id embedding`

Example: `1 0.3 0.4 0.5 0.6 -0.4 -0.2`


### Configurations


**Model Configurations**
- BaseNN.embedding_size: embedding dimension of words
- BaseNN.max_q_len: max query (seed word) length
- BaseNN.max_d_len: max document length
- DataGenerator.max_q_len: max query length. Should be the same as BaseNN.max_q_len
- DataGenerator.max_d_len: max document length. Should be the same as BaseNN.max_d_len
- BaseNN.vocabulary_size: vocabulary size
- DataGenerator.vocabulary_size: vocabulary size
- BaseNN.batch_size: batch size
- BaseNN.max_epochs: max number of epochs to train
- BaseNN.eval_frequency: number of epochs between evaluations on the validation set
- BaseNN.checkpoint_steps: number of epochs between model checkpoints


**Data**
- DAZER.emb_in: path of the initial embeddings file
- DAZER.label_dict_path: path of the label dict file
- DAZER.word2id_path: path of the word2id file


**Training Parameters**
- DAZER.epsilon: epsilon for the Adam optimizer
- DAZER.embedding_size: embedding dimension of words
- DAZER.vocabulary_size: vocabulary size of the dataset
- DAZER.kernal_width: width of the convolution kernel
- DAZER.kernal_num: number of convolution kernels
- DAZER.regular_term: weight of the L2 loss
- DAZER.maxpooling_num: k of the K-max pooling
- DAZER.decoder_mlp1_num: number of hidden units of the first MLP in the relevance aggregation part
- DAZER.decoder_mlp2_num: number of hidden units of the second MLP in the relevance aggregation part
- DAZER.model_learning_rate: learning rate for the model (all parameters except the adversarial classifier)
- DAZER.adv_learning_rate: learning rate for the adversarial classifier
- DAZER.train_class_num: number of classes at training time
- DAZER.adv_term: weight of the adversarial loss when updating the model's parameters
- DAZER.zsl_num: number of zero-shot labels
- DAZER.zsl_type: type of zero-shot label setting (you may have multiple zero-shot settings with the same number of zero-shot labels; this picks which setting to use for the experiment, see [get_label.py](https://github.com/WHUIR/DAZER/blob/master/get_label.py) for more details)

--------------------------------------------------------------------------------
/dataset/readme.txt:
--------------------------------------------------------------------------------
Dataset is saved in this directory.

--------------------------------------------------------------------------------
/get_label.py:
--------------------------------------------------------------------------------

def get_word2id(word2id_path):
    word2id = {}
    with open(word2id_path,'r',encoding='gbk') as f:
        for line in f:
            w,id = line.strip().split(' ')
            word2id[w] = int(id)
    return word2id

def get_labels(label_dict_path,word2id_path):
    #use the label-dict file and word2id file to get label_dict, reverse_label_dict and label_list,
    #which are used by the DAZER model
    label_dict = {}
    reverse_label_dict = {}
    label_list = []
    word2id = get_word2id(word2id_path)
    with open(label_dict_path,'r') as f:
        for line in f:
            c_name,words = line.strip().split('/')
            ids = [word2id[w] for w in words.split(' ')]
            label_dict[c_name] = ids
            label_list.append(c_name)
            ids_str = ','.join([str(x) for x in ids])
            reverse_label_dict[ids_str] = c_name
    return label_dict, reverse_label_dict, label_list

def get_label_index(label_list, zsl_num,zsl_type):
    #map every training label (neither a zero-shot label nor the extra held-out label)
    #to a class index used by the adversarial classifier
    #zsl_num picks one of the lists below, zsl_type picks a row in it (both 1-based)
    #below are the experiment settings of 20NG in our ACL paper; change them for your own dataset

    #e.g., zeroshot_labels_1[0] = [['sci.space'],['comp.graphics']]
    #it means we use label "sci.space" for zeroshot experiments
    #and randomly pick label 'comp.graphics' to prevent overfitting
    #please refer to the "Evaluation protocol" part of our paper

    zeroshot_labels_1 = [
        [['sci.space'],['comp.graphics']],
        [['rec.sport.baseball'],['talk.politics.misc']],
        [['sci.med'],['rec.autos']],
        [['comp.sys.ibm.pc.hardware'],['rec.sport.hockey']],
    ]

    zeroshot_labels_2= [
        [['sci.med','sci.space'],['talk.politics.guns']],
        [['alt.atheism','sci.electronics'],['comp.sys.ibm.pc.hardware']],
        [['soc.religion.christian','talk.politics.mideast'],['rec.sport.baseball']],
        [['rec.sport.baseball','rec.sport.hockey'],['comp.sys.mac.hardware']]
    ]

    zeroshot_labels_3 = [
        [['comp.sys.ibm.pc.hardware','comp.windows.x','sci.electronics'],['talk.politics.mideast']],
    ]

    zeroshot_labels = [zeroshot_labels_1,zeroshot_labels_2,zeroshot_labels_3]

    z_labels = zeroshot_labels[zsl_num-1][zsl_type-1][0] + zeroshot_labels[zsl_num-1][zsl_type-1][1]
    label_test = []
    for _l in label_list:
        if _l not in z_labels:
            label_test.append(_l)
    indexs = list(range(len(label_test)))
    zip_label_index = zip(label_test, indexs)
    return dict(list(zip_label_index))



--------------------------------------------------------------------------------
/model-img.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/WHUIR/DAZER/cc028184b120148eb45bba875b7f3f4c7f0e5294/model-img.png
--------------------------------------------------------------------------------
/model.py:
--------------------------------------------------------------------------------
import tensorflow as tf
import numpy as np
import time
import get_label

import sys
import argparse
from traitlets.config.loader import PyFileConfigLoader
from traitlets.config import Configurable
from traitlets import (
    Int,
    Float,
    Bool,
    Unicode,
)

class DataGenerator(Configurable):
    #params for data generator
    max_q_len = Int(10, help='max q len').tag(config=True)
    max_d_len = Int(500, help='max document len').tag(config=True)
    q_name = Unicode('q')
    d_name = Unicode('d')
    q_str_name = Unicode('q_str')
    q_lens_name = Unicode('q_lens')
    aux_d_name = Unicode('d_aux')
    vocabulary_size = Int(2000000).tag(config=True)

    def __init__(self, **kwargs):
        #init the data generator
        super(DataGenerator, self).__init__(**kwargs)
        print ("generator's vocabulary size: ", self.vocabulary_size)

    def pairwise_reader(self, pair_stream, batch_size, with_idf=False):
        #generate the batch of x,y in training time
        l_q = []
        l_q_str = []
        l_d = []
        l_d_aux = []
        l_y = []
        l_q_lens = []
        for line in pair_stream:
            cols = line.strip().split('\t')
            y = float(1.0)
            l_q_str.append(cols[0])
            q = np.array([int(t) for t in cols[0].split(',') if int(t) < self.vocabulary_size])
            t1 = np.array([int(t) for t in cols[1].split(',') if int(t) < self.vocabulary_size])
            t2 = np.array([int(t) for t in cols[2].split(',') if int(t) < self.vocabulary_size])

            #padding (id 0)
            v_q = np.zeros(self.max_q_len)
            v_d = np.zeros(self.max_d_len)
            v_d_aux = np.zeros(self.max_d_len)

            v_q[:min(q.shape[0], self.max_q_len)] = q[:min(q.shape[0], self.max_q_len)]
            v_d[:min(t1.shape[0], self.max_d_len)] = t1[:min(t1.shape[0], self.max_d_len)]
            v_d_aux[:min(t2.shape[0], self.max_d_len)] = t2[:min(t2.shape[0], self.max_d_len)]

            l_q.append(v_q)
            l_d.append(v_d)
            l_d_aux.append(v_d_aux)
            l_y.append(y)
            l_q_lens.append(len(q))

            if len(l_q) >= batch_size:
                Q = np.array(l_q, dtype=int,)
                D = np.array(l_d, dtype=int,)
                D_aux = np.array(l_d_aux, dtype=int,)
                Q_lens = np.array(l_q_lens, dtype=int,)
                Y = np.array(l_y, dtype=int,)
                X = {self.q_name: Q, self.d_name: D, self.aux_d_name: D_aux, self.q_lens_name: Q_lens, self.q_str_name: l_q_str}
                yield X, Y
                l_q, l_d, l_d_aux, l_y, l_q_lens, l_ids, l_q_str = [], [], [], [], [], [], []
        if l_q:
            Q = np.array(l_q, dtype=int,)
            D = np.array(l_d, dtype=int,)
            D_aux = np.array(l_d_aux, dtype=int,)
            Q_lens = np.array(l_q_lens, dtype=int,)
            Y = np.array(l_y, dtype=int,)
            X = {self.q_name: Q, self.d_name: D, self.aux_d_name: D_aux, self.q_lens_name: Q_lens, self.q_str_name: l_q_str}
            yield X, Y

    def test_pairwise_reader(self, pair_stream, batch_size):
        #generate the batch of x in test time
        l_q = []
        l_q_lens = []
        l_d = []

        for line in pair_stream:
            cols = line.strip().split('\t')
            q = np.array([int(t) for t in cols[0].split(',') if int(t) < self.vocabulary_size])
            t = np.array([int(t) for t in cols[1].split(',') if int(t) < self.vocabulary_size])

            v_q = np.zeros(self.max_q_len)
            v_d = np.zeros(self.max_d_len)

            v_q[:min(q.shape[0], self.max_q_len)] = q[:min(q.shape[0], self.max_q_len)]
            v_d[:min(t.shape[0], self.max_d_len)] = t[:min(t.shape[0], self.max_d_len)]

            l_q.append(v_q)
            l_d.append(v_d)
            l_q_lens.append(len(q))

            if len(l_q) >= batch_size:
                Q = np.array(l_q, dtype=int,)
                D = np.array(l_d, dtype=int,)
                Q_lens = np.array(l_q_lens, dtype=int,)
                X = {self.q_name: Q, self.d_name: D, self.q_lens_name: Q_lens}
                yield X
                l_q, l_d, l_q_lens = [], [], []
        if l_q:
            Q = np.array(l_q, dtype=int,)
            D = np.array(l_d, dtype=int,)
            Q_lens = np.array(l_q_lens, dtype=int,)
            X = {self.q_name: Q, self.d_name: D, self.q_lens_name: Q_lens}
            yield X

class BaseNN(Configurable):
    #params of base deeprank model
    max_q_len = Int(10, help='max q len').tag(config=True)
    max_d_len = Int(50, help='max document len').tag(config=True)
    batch_size = Int(16, help="minibatch size").tag(config=True)
    max_epochs = Float(10, help="maximum number of epochs").tag(config=True)
    eval_frequency = Int(10000, help="evaluate on the validation set every * epochs").tag(config=True)
    checkpoint_steps = Int(10000, help="store trained model every * epochs").tag(config=True)

    def __init__(self, **kwargs):
        super(BaseNN, self).__init__(**kwargs)
        # generator
        self.data_generator = DataGenerator(config=self.config)
        self.val_data_generator = DataGenerator(config=self.config) #validation in training stage is full test data in 20ng
        self.test_data_generator = DataGenerator(config=self.config) #test is zero-shot test data in 20ng (delete docs of zero-shot label)

    @staticmethod
    def weight_variable(shape,name):
        tmp = np.sqrt(3.0) / np.sqrt(shape[0] + shape[1])
        initial = tf.random_uniform(shape, minval=-tmp, maxval=tmp)
        return tf.Variable(initial_value=initial,name=name)

    def gen_query_mask(self, Q):
        mask = np.zeros((self.batch_size, self.max_q_len))
        for b in range(len(Q)):
            for q in range(len(Q[b])):
                if Q[b][q] > 0:
                    mask[b][q] = 1

        return mask

    def gen_doc_mask(self, D):
        mask = np.zeros((self.batch_size, self.max_d_len))
        for b in range(len(D)):
            for q in range(len(D[b])):
                if D[b][q] > 0:
                    mask[b][q] = 1

        return mask

class DAZER(BaseNN):
    #params of zeroshot document filtering model
    embedding_size = Int(300, help="embedding dimension").tag(config=True)
    vocabulary_size = Int(2000000, help="vocabulary size").tag(config=True)
    kernal_width = Int(5, help='kernel width').tag(config=True)
    kernal_num = Int(50, help='number of kernels').tag(config=True)
    regular_term = Float(0.01, help='param for controlling weight of L2 loss').tag(config=True)
    maxpooling_num = Int(3, help='number of k-maxpooling').tag(config=True)
    decoder_mlp1_num = Int(75, help='number of hidden units of first mlp in relevance aggregation part').tag(config=True)
    decoder_mlp2_num = Int(1, help='number of hidden units of second mlp in relevance aggregation part').tag(config=True)
    emb_in = Unicode('None', help="initial embedding. Terms should be hashed to ids.").tag(config=True)
    model_learning_rate = Float(0.001, help="learning rate of model").tag(config=True)
    adv_learning_rate = Float(0.001, help='learning rate of adv classifier').tag(config=True)
    epsilon = Float(0.00001, help="Epsilon for Adam").tag(config=True)
    label_dict_path = Unicode('None', help='label dict path').tag(config=True)
    word2id_path = Unicode('None', help='word2id path').tag(config=True)
    train_class_num = Int(16, help='num of class in training data').tag(config=True)
    adv_term = Float(0.2, help='regular term of adversarial loss').tag(config=True)
    zsl_num = Int(1, help='num of zeroshot label').tag(config=True)
    zsl_type = Int(1, help='type of zeroshot label setting').tag(config=True)

    def __init__(self, **kwargs):
        #init the DAZER model
        super(DAZER, self).__init__(**kwargs)
        print ("trying to load initial embeddings from: ", self.emb_in)
        if self.emb_in != 'None':
            self.emb = self.load_word2vec(self.emb_in)
            self.embeddings = tf.Variable(tf.constant(self.emb, dtype='float32', shape=[self.vocabulary_size + 1, self.embedding_size]),trainable=False)
            print ("Initialized embeddings with {0}".format(self.emb_in))
        else:
            self.embeddings = tf.Variable(tf.random_uniform([self.vocabulary_size + 1, self.embedding_size], -1.0, 1.0))

        #variables of the DAZER model
        self.query_gate_weight = BaseNN.weight_variable((self.embedding_size, self.kernal_num),'gate_weight')
        self.query_gate_bias = tf.Variable(initial_value=tf.zeros((self.kernal_num)),name='gate_bias')
        self.adv_weight = BaseNN.weight_variable((self.decoder_mlp1_num,self.train_class_num),name='adv_weight')
        self.adv_bias = tf.Variable(initial_value=tf.zeros((1,self.train_class_num)),name='adv_bias')
        #get the label information to help adversarial learning
        self.label_dict, self.reverse_label_dict, self.label_list = get_label.get_labels(self.label_dict_path, self.word2id_path)
        self.label_index_dict = get_label.get_label_index(self.label_list, self.zsl_num, self.zsl_type)

    def load_word2vec(self, emb_file_path):
        #row tid of the returned matrix holds the vector of term id tid
        emb = np.zeros((self.vocabulary_size + 1, self.embedding_size))
        nlines = 0
        with open(emb_file_path) as f:
            for line in f:
                nlines += 1
                if nlines == 1:
                    #the first line is treated as a header and skipped
                    continue
                items = line.split()
                tid = int(items[0])
                if tid > self.vocabulary_size:
                    print (tid)
                    continue
                vec = np.array([float(t) for t in items[1:]])
                emb[tid, :] = vec
                if nlines % 20000 == 0:
                    print ("load {0} vectors...".format(nlines))
        return emb

    def gen_adv_query_mask(self, q_ids):
        #one-hot training label of each sample, looked up from its seed-word id string
        q_mask = np.zeros((self.batch_size, self.train_class_num))
        for batch_num, b_q_id in enumerate(q_ids):
            c_name = self.reverse_label_dict[b_q_id]
            c_index = self.label_index_dict[c_name]
            q_mask[batch_num][c_index] = 1
        return q_mask

    def get_class_gate(self,class_vec, emb_d):
        '''
        compute the gate in kernel space
        :param class_vec: avg emb of seed words
        :param emb_d: emb of doc (not used; the gate depends only on class_vec)
        :return: the class gate [batch_size,1,kernal_num], broadcast over d_len
        '''
        gate1 = tf.expand_dims(tf.matmul(class_vec, self.query_gate_weight), axis=1)
        bias = tf.expand_dims(self.query_gate_bias,axis=0)
        gate = tf.add(gate1, bias)
        return tf.sigmoid(gate)

    def L2_model_loss(self):
        #L2 loss over non-adversarial variables; the 'b' test is meant to skip biases,
        #but it skips every variable whose name contains the letter 'b'
        all_para = [v for v in tf.trainable_variables() if 'b' not in v.name and 'adv' not in v.name]
        loss = 0.
        for each in all_para:
            loss += tf.nn.l2_loss(each)
        return loss

    def L2_adv_loss(self):
        all_para = [v for v in tf.trainable_variables() if 'b' not in v.name and 'adv' in v.name]
        loss = 0.
        for each in all_para:
            loss += tf.nn.l2_loss(each)
        return loss

    def train(self, train_pair_file_path, val_pair_file_path, checkpoint_dir, load_model=False):

        input_q = tf.placeholder(tf.int32, shape=[self.batch_size,self.max_q_len])
        input_pos_d = tf.placeholder(tf.int32, shape=[self.batch_size,self.max_d_len])
        input_neg_d = tf.placeholder(tf.int32, shape=[self.batch_size,self.max_d_len])
        q_lens = tf.placeholder(tf.float32, shape=[self.batch_size,])
        q_mask = tf.placeholder(tf.float32, shape=[self.batch_size,self.max_q_len])
        pos_d_mask = tf.placeholder(tf.float32, shape=[self.batch_size,self.max_d_len])
        neg_d_mask = tf.placeholder(tf.float32, shape=[self.batch_size,self.max_d_len])
        input_q_index = tf.placeholder(tf.int32, shape=[self.batch_size,self.train_class_num])

        emb_q = tf.nn.embedding_lookup(self.embeddings,input_q)
        class_vec_sum = tf.reduce_sum(
            tf.multiply(emb_q,tf.expand_dims(q_mask,axis=-1)),
            axis=1
        )

        #get class vec
        class_vec = tf.div(class_vec_sum,tf.expand_dims(q_lens,-1))
        emb_pos_d = tf.nn.embedding_lookup(self.embeddings,input_pos_d)
        emb_neg_d = tf.nn.embedding_lookup(self.embeddings,input_neg_d)

        #get query gate
        pos_query_gate = self.get_class_gate(class_vec, emb_pos_d)
        neg_query_gate = self.get_class_gate(class_vec, emb_neg_d)

        # CNN for document
        pos_mult_info = tf.multiply(tf.expand_dims(class_vec, axis=1), emb_pos_d)
        pos_sub_info = tf.expand_dims(class_vec,axis=1) - emb_pos_d
        pos_conv_input = tf.concat([emb_pos_d,pos_mult_info,pos_sub_info], axis=-1)

        neg_mult_info = tf.multiply(tf.expand_dims(class_vec, axis=1), emb_neg_d)
        neg_sub_info = tf.expand_dims(class_vec,axis=1) - emb_neg_d
        neg_conv_input = tf.concat([emb_neg_d,neg_mult_info,neg_sub_info], axis=-1)


        #in fact that's 1D conv, but we implement it by conv2d
        pos_conv = tf.layers.conv2d(
            inputs = tf.expand_dims(pos_conv_input,axis=-1),
            filters = self.kernal_num,
            kernel_size=[self.kernal_width,self.embedding_size*3],
            strides = [1,self.embedding_size*3],
            padding = 'SAME',
            trainable = True,
            name='doc_conv'
        )

        neg_conv = tf.layers.conv2d(
            inputs = tf.expand_dims(neg_conv_input,axis=-1),
            filters = self.kernal_num,
            kernel_size=[self.kernal_width,self.embedding_size*3],
            strides = [1,self.embedding_size*3],
            padding = 'SAME',
            trainable = True,
            name='doc_conv',
            reuse=True
        )
        #shape=[batch,max_dlen,1,kernal_num]
        #reshape to [batch,max_dlen,kernal_num]
        rs_pos_conv = tf.squeeze(pos_conv)
        rs_neg_conv = tf.squeeze(neg_conv)

        #query_gate element-wise multiply rs_pos_conv
        pos_gate_conv = tf.multiply(pos_query_gate, rs_pos_conv)
        neg_gate_conv = tf.multiply(neg_query_gate, rs_neg_conv)

        #K-max_pooling
        #transpose to [batch,knum,dlen],then get max k in each kernel filter
        transpose_pos_gate_conv = tf.transpose(pos_gate_conv, perm=[0,2,1])
        transpose_neg_gate_conv = tf.transpose(neg_gate_conv, perm=[0,2,1])

        #shape = [batch,k_num,maxpooling_num]
        #the k-max pooling here is implemented by function top_k, so the relative position information is ignored
        pos_kmaxpooling,_ = tf.nn.top_k(
            input=transpose_pos_gate_conv,
            k=self.maxpooling_num,
        )
        neg_kmaxpooling,_ = tf.nn.top_k(
            input=transpose_neg_gate_conv,
            k=self.maxpooling_num,
        )

        pos_encoder = tf.reshape(pos_kmaxpooling, shape=(self.batch_size,-1))
        neg_encoder = tf.reshape(neg_kmaxpooling, shape=(self.batch_size,-1))

        pos_decoder_mlp1 = tf.layers.dense(
            inputs=pos_encoder,
            units=self.decoder_mlp1_num,
            activation=tf.nn.tanh,
            trainable=True,
            name='decoder_mlp1'
        )

        neg_decoder_mlp1 = tf.layers.dense(
            inputs=neg_encoder,
            units=self.decoder_mlp1_num,
            activation=tf.nn.tanh,
            trainable=True,
            name='decoder_mlp1',
            reuse=True
        )

        pos_decoder_mlp2 = tf.layers.dense(
            inputs=pos_decoder_mlp1,
            units=self.decoder_mlp2_num,
            activation=tf.nn.tanh,
            trainable=True,
            name='decoder_mlp2'
        )

        neg_decoder_mlp2 = tf.layers.dense(
            inputs=neg_decoder_mlp1,
            units=self.decoder_mlp2_num,
            activation=tf.nn.tanh,
            trainable=True,
            name='decoder_mlp2',
            reuse=True
        )

        score_pos = pos_decoder_mlp2
        score_neg = neg_decoder_mlp2

        hinge_loss = tf.reduce_mean(tf.maximum(0.0, 1 - score_pos + score_neg))
        adv_prob = tf.nn.softmax(tf.add(tf.matmul(pos_decoder_mlp1, self.adv_weight), self.adv_bias))
        log_adv_prob = tf.log(adv_prob)
        #mean log-likelihood of the true training label under the adversarial classifier
        adv_loss = tf.reduce_mean(tf.reduce_sum(tf.multiply(log_adv_prob, tf.cast(input_q_index,tf.float32)), axis=1, keep_dims=True))
        L2_adv_loss = self.regular_term*self.L2_adv_loss()

        #to apply GRL, we use two separate optimizers for the adversarial classifier and the rest of DAZER
        #optimizer for adversarial classifier (maximizes adv_loss)
        adv_var_list = [v for v in tf.trainable_variables() if 'adv' in v.name]
        adv_opt = tf.train.AdamOptimizer(learning_rate=self.adv_learning_rate, epsilon=self.epsilon).minimize(loss=(-1 * adv_loss + L2_adv_loss), var_list=adv_var_list)

        #optimizer for the rest of the DAZER model (minimizes adv_loss, weighted by adv_term)
        L2_model_loss = self.regular_term*self.L2_model_loss()
        model_var_list = [v for v in tf.trainable_variables() if 'adv' not in v.name]
        loss = hinge_loss + L2_model_loss + (adv_loss * self.adv_term)
        model_opt = tf.train.AdamOptimizer(learning_rate=self.model_learning_rate, epsilon=self.epsilon).minimize(loss = loss, var_list = model_var_list)

        config = tf.ConfigProto()
        config.gpu_options.allow_growth = True
        val_results = []
        save_num = 0
        save_var = [v for v in tf.trainable_variables()]

        # Create a local session to run the training.
        with tf.Session(config=config) as sess:
            saver = tf.train.Saver(max_to_keep=50,var_list=save_var)
            start_time = time.time()
            if not load_model:
                print ("Initializing a new model...")
                init = tf.global_variables_initializer()
                sess.run(init)
                print('New model initialized!')
            else:
                #to load trained model, and keep training
                #remember to change the name of ckpt file
                init = tf.global_variables_initializer()
                sess.run(init)
                saver.restore(sess, checkpoint_dir+'/zsl25.ckpt')
                print ("model loaded!")

            # Loop through training steps.
            step = 0
            loss_list = []
            for epoch in range(int(self.max_epochs)):
                epoch_val_loss = 0
                epoch_loss = 0
                epoch_hinge_loss = 0.
                epoch_adv_loss = 0
                epoch_s = time.time()
                pair_stream = open(train_pair_file_path)

                for BATCH in self.data_generator.pairwise_reader(pair_stream, self.batch_size):
                    step += 1
                    X, Y = BATCH
                    query = X[u'q']
                    str_query = X[u'q_str']
                    q_index = self.gen_adv_query_mask(str_query)
                    pos_doc = X[u'd']
                    neg_doc = X[u'd_aux']
                    train_q_lens = X[u'q_lens']
                    M_query = self.gen_query_mask(query)
                    M_pos = self.gen_doc_mask(pos_doc)
                    M_neg = self.gen_doc_mask(neg_doc)

                    if X[u'q_lens'].shape[0] != self.batch_size:
                        continue
                    train_feed_dict = {input_q:query,
                                       input_pos_d:pos_doc,
                                       q_lens:train_q_lens,
                                       input_neg_d:neg_doc,
                                       q_mask:M_query,
                                       pos_d_mask:M_pos,
                                       neg_d_mask:M_neg,
                                       input_q_index: q_index}

                    _1,l,hinge_l,_2,adv_l = sess.run([model_opt,loss,hinge_loss,adv_opt,adv_loss], feed_dict=train_feed_dict)
                    epoch_loss += l
                    epoch_hinge_loss += hinge_l
                    epoch_adv_loss += adv_l

                if (epoch + 1) % self.eval_frequency == 0:
                    #after eval_frequency epochs we run model on val dataset
                    val_start = time.time()
                    val_pair_stream = open(val_pair_file_path)
                    for BATCH in self.val_data_generator.pairwise_reader(val_pair_stream, self.batch_size):
                        X_val,Y_val = BATCH
                        query = X_val[u'q']
                        pos_doc = X_val[u'd']
                        neg_doc = X_val[u'd_aux']
                        val_q_lens = X_val[u'q_lens']
                        M_query = self.gen_query_mask(query)
                        M_pos = self.gen_doc_mask(pos_doc)
                        M_neg = self.gen_doc_mask(neg_doc)
                        if X_val[u'q'].shape[0] != self.batch_size:
                            continue
                        val_feed_dict = {input_q:query,
                                         input_pos_d:pos_doc,
                                         input_neg_d:neg_doc,
                                         q_lens:val_q_lens,
                                         q_mask:M_query,
                                         pos_d_mask:M_pos,
                                         neg_d_mask:M_neg}

                        # Run the graph and fetch some of the nodes.
                        v_loss = sess.run(hinge_loss, feed_dict=val_feed_dict)
                        epoch_val_loss += v_loss
                    val_pair_stream.close()
                    #one entry per validation run (every eval_frequency epochs)
                    val_results.append(epoch_val_loss)

                    val_end = time.time()
                    print('---Validation:epoch %d, %.1f s , val_loss is %f' % (epoch+1,val_end-val_start,epoch_val_loss))
                    sys.stdout.flush()
                loss_list.append(epoch_loss)
                epoch_e = time.time()
                print('---Train: epoch %d took %f seconds, hinge cost = %f model cost = %f, adv cost = %f...'%(epoch+1,epoch_e-epoch_s,epoch_hinge_loss, epoch_loss,epoch_adv_loss))
                # save model after checkpoint_steps epochs (checkpoint_dir is expected to end with '/')
                if (epoch+1)%self.checkpoint_steps == 0:
                    save_num += 1
                    saver.save(sess, checkpoint_dir + 'zsl'+str(epoch+1)+'.ckpt')
                pair_stream.close()

            with open('save_training_loss.txt','w') as f:
                for index,_loss in enumerate(loss_list):
                    f.write('epoch'+str(index+1)+', loss:'+str(_loss)+'\n')

            with open('save_val_cost.txt','w') as f:
                for index, v_l in enumerate(val_results):
                    f.write('epoch'+str((index+1)*self.eval_frequency)+' val loss:'+str(v_l)+'\n')

            # end training
            end_time = time.time()
            print('All costs %f seconds...'%(end_time-start_time))

    def test(self, test_point_file_path, test_size, output_file_path, checkpoint_dir=None, load_model=False):

        input_q = tf.placeholder(tf.int32, shape=[self.batch_size,self.max_q_len])
        input_pos_d = tf.placeholder(tf.int32, shape=[self.batch_size,self.max_d_len])
        q_lens = tf.placeholder(tf.float32, shape=[self.batch_size,])
        q_mask = tf.placeholder(tf.float32, shape=[self.batch_size,self.max_q_len])
        pos_d_mask = tf.placeholder(tf.float32, shape=[self.batch_size,self.max_d_len])

        emb_q = tf.nn.embedding_lookup(self.embeddings,input_q)
        class_vec_sum = tf.reduce_sum(
            tf.multiply(emb_q,tf.expand_dims(q_mask,axis=-1)),
            axis=1
        )

        class_vec = tf.div(class_vec_sum,tf.expand_dims(q_lens,axis=-1))
        emb_pos_d = tf.nn.embedding_lookup(self.embeddings,input_pos_d)

        #get query gate
        query_gate = self.get_class_gate(class_vec, emb_pos_d)
        pos_mult_info = tf.multiply(tf.expand_dims(class_vec, axis=1), emb_pos_d)
        pos_sub_info = tf.expand_dims(class_vec, axis=1) - emb_pos_d
        pos_conv_input = tf.concat([emb_pos_d,pos_mult_info, pos_sub_info], axis=-1)

        # CNN for document
        pos_conv = tf.layers.conv2d(
            inputs = tf.expand_dims(pos_conv_input,axis=-1),
            filters = self.kernal_num,
            kernel_size=[self.kernal_width,self.embedding_size*3],
            strides = [1,self.embedding_size*3],
            padding = 'SAME',
            trainable = True,
            name='doc_conv'
        )

        #shape=[batch,max_dlen,1,kernal_num]
        #reshape to [batch,max_dlen,kernal_num]
        rs_pos_conv = tf.squeeze(pos_conv)

        #query_gate element-wise multiply rs_pos_conv
        #[batch,1,kernal_num] * [batch,max_dlen,kernal_num]
        pos_gate_conv = tf.multiply(query_gate, rs_pos_conv)

        #K-max_pooling
        #transpose to [batch,knum,dlen],then get max k in each kernel filter
        transpose_pos_gate_conv = tf.transpose(pos_gate_conv, perm=[0,2,1])

        #[batch,k_num,maxpooling_num]
        pos_kmaxpooling,_ = tf.nn.top_k(
            input=transpose_pos_gate_conv,
            k=self.maxpooling_num,
        )
        pos_encoder = tf.reshape(pos_kmaxpooling, shape=(self.batch_size,-1))

        pos_decoder_mlp1 = tf.layers.dense(
            inputs=pos_encoder,
            units=self.decoder_mlp1_num,
            activation=tf.nn.tanh,
            trainable=True,
            name='decoder_mlp1'
        )

        pos_decoder_mlp2 = tf.layers.dense(
            inputs=pos_decoder_mlp1,
            units=self.decoder_mlp2_num,
            activation=tf.nn.tanh,
            trainable=True,
            name='decoder_mlp2'
        )

        score_pos = pos_decoder_mlp2
        config = tf.ConfigProto()
        config.gpu_options.allow_growth = True
        save_var = [v for v in tf.trainable_variables()]
        # Create a local session to run the testing.
        # One pass over the test file for every checkpoint saved during training.
        for i in range(int(self.max_epochs/self.checkpoint_steps)):
            with tf.Session(config=config) as sess:
                test_point_stream = open(test_point_file_path)
                outfile = open(output_file_path+'-epoch'+str(self.checkpoint_steps*(i+1))+'.txt', 'w')
                saver = tf.train.Saver(var_list=save_var)

                if load_model:
                    p = checkpoint_dir + 'zsl'+str(self.checkpoint_steps*(i+1))+'.ckpt'
                    init = tf.global_variables_initializer()
                    sess.run(init)
                    saver.restore(sess, p)
                    print ("model loaded from {0}".format(p))
                else:
                    init = tf.global_variables_initializer()
                    sess.run(init)

                # Loop through test batches.
                for b in range(int(np.ceil(float(test_size)/self.batch_size))):
                    X = next(self.test_data_generator.test_pairwise_reader(test_point_stream, self.batch_size))
                    # The placeholders have a fixed batch dimension, so a batch with fewer
                    # than batch_size rows is skipped and its documents get no score.
                    if X['q'].shape[0] != self.batch_size:
                        continue
                    query = X['q']
                    pos_doc = X['d']
                    test_q_lens = X['q_lens']
                    M_query = self.gen_query_mask(query)
                    M_pos = self.gen_doc_mask(pos_doc)
                    test_feed_dict = {input_q: query,
                                      input_pos_d: pos_doc,
                                      q_lens: test_q_lens,
                                      q_mask: M_query,
                                      pos_d_mask: M_pos}

                    # Run the graph and fetch the relevance scores for this batch.
                    scores = sess.run(score_pos, feed_dict=test_feed_dict)

                    for score in scores:
                        outfile.write('{0}\n'.format(score[0]))

                outfile.close()
                test_point_stream.close()

if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument("config_file_path")

    parser.add_argument("--train", action='store_true')
    parser.add_argument("--train_file", '-f', help="train_pair_file_path")
    parser.add_argument("--validation_file", '-v', help="val_pair_file_path")
    parser.add_argument("--train_size", '-z', type=int, help="number of train samples")
    parser.add_argument("--load_model", '-l', action='store_true')

    parser.add_argument("--test", action="store_true")
    parser.add_argument("--test_file")
    parser.add_argument("--test_size", type=int, default=0)
    parser.add_argument("--output_score_file", '-o')
    parser.add_argument("--emb_file_path", '-e')
    parser.add_argument("--checkpoint_dir", '-s', help="directory to store/load model checkpoints")

    args = parser.parse_args()

    conf = PyFileConfigLoader(args.config_file_path).load_config()

    if args.train:
        nn = DAZER(config=conf)
        nn.train(train_pair_file_path=args.train_file,
                 val_pair_file_path=args.validation_file,
                 checkpoint_dir=args.checkpoint_dir,
                 load_model=args.load_model)
    # Any run without --train falls through to testing; the --test flag itself is not checked.
    else:
        nn = DAZER(config=conf)
        nn.test(test_point_file_path=args.test_file,
                test_size=args.test_size,
                output_file_path=args.output_score_file,
                load_model=True,
                checkpoint_dir=args.checkpoint_dir)


--------------------------------------------------------------------------------
/output/readme.txt:
--------------------------------------------------------------------------------
Output files are saved in this directory.

--------------------------------------------------------------------------------
/sample-train.sh:
--------------------------------------------------------------------------------
CUDA_VISIBLE_DEVICES='0' python model.py sample.config --train --train_file dataset/20ng_train.txt --validation_file dataset/20ng_val.txt --checkpoint_dir output/

--------------------------------------------------------------------------------
/sample.config:
--------------------------------------------------------------------------------
c = get_config()

c.DataGenerator.max_q_len=8
c.DataGenerator.max_d_len=500
c.DataGenerator.vocabulary_size=253988

c.BaseNN.vocabulary_size=253988
c.BaseNN.embedding_size=300
c.BaseNN.max_q_len=8
c.BaseNN.max_d_len=500
c.BaseNN.max_epochs=50
c.BaseNN.eval_frequency=5
c.BaseNN.checkpoint_steps=5
c.BaseNN.batch_size=16

c.DAZER.emb_in = 'glove_20ng_knrm.txt'
c.DAZER.kernal_width=5
c.DAZER.kernal_num=50
c.DAZER.regular_term=0.0001
c.DAZER.adv_term = 0.1
c.DAZER.train_class_num = 17
c.DAZER.model_learning_rate=0.00001
c.DAZER.adv_learning_rate=0.00001
c.DAZER.maxpooling_num=3
c.DAZER.decoder_mlp1_num=75
c.DAZER.decoder_mlp2_num=1
c.DAZER.word2id_path = 'knrm_word2id.txt'
c.DAZER.zsl_num=1
c.DAZER.zsl_type=1
c.DAZER.label_dict_path = '20ng seedwords.txt'

--------------------------------------------------------------------------------
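The first step of `test()` in model.py builds a class vector from the seed words and pairs it with every document word. The sketch below restates that step in plain Python for a single example (no batch dimension), so the shapes can be checked by hand. The helper names `class_vector` and `interaction_features` are hypothetical and do not exist in model.py; it also assumes that `mask` holds 1.0 for real seed-word positions and 0.0 for padding (as `q_mask` is used) and that `q_len` is the count of real seed words.

```python
from typing import List

Vector = List[float]


def class_vector(seed_embs: List[Vector], mask: List[float], q_len: float) -> Vector:
    """Masked mean of the seed-word embeddings (mirrors class_vec in test()).

    Hypothetical helper: padded positions have mask 0.0 and contribute nothing.
    """
    dim = len(seed_embs[0])
    summed = [0.0] * dim
    for emb, m in zip(seed_embs, mask):
        for j in range(dim):
            summed[j] += emb[j] * m
    return [s / q_len for s in summed]


def interaction_features(class_vec: Vector, doc_embs: List[Vector]) -> List[Vector]:
    """For each document word d, build [d, c * d, c - d] of length 3 * dim.

    Mirrors pos_conv_input = concat([emb_pos_d, pos_mult_info, pos_sub_info]).
    """
    rows = []
    for d in doc_embs:
        mult = [c * x for c, x in zip(class_vec, d)]  # element-wise product
        sub = [c - x for c, x in zip(class_vec, d)]   # element-wise difference
        rows.append(list(d) + mult + sub)
    return rows


if __name__ == "__main__":
    c = class_vector([[1.0, 2.0], [3.0, 4.0], [9.0, 9.0]], [1.0, 1.0, 0.0], 2.0)
    print(c)
    print(interaction_features(c, [[1.0, 1.0]]))
```

The third seed position is padding, so it is masked out and the mean is taken over two words only. Each document word ends up with `3 * embedding_size` features, which is why the document convolution in `test()` uses a kernel width and stride of `self.embedding_size*3` along that axis.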
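The gating and K-max pooling in `test()` (multiply the squeezed convolution output by `query_gate`, transpose to `[batch,kernal_num,max_dlen]`, take `tf.nn.top_k`, then reshape) can be traced on a tiny feature map. This is a plain-Python sketch for one document, not the model's code; the name `gate_and_kmax_pool` is hypothetical, and it assumes, following the shape comments in model.py, that the gate holds one value per kernel filter applied at every document position.

```python
from typing import List

Matrix = List[List[float]]


def gate_and_kmax_pool(conv: Matrix, gate: List[float], k: int) -> List[float]:
    """Gate a [max_dlen][kernal_num] feature map and K-max pool each filter.

    Hypothetical helper. Returns a flat vector of length kernal_num * k,
    filter by filter, each filter's k largest gated activations in
    descending order (tf.nn.top_k sorts its output by default).
    """
    pooled: List[float] = []
    for f in range(len(gate)):
        # Column f of the gated map = one row of the transposed [kernal_num, max_dlen] tensor.
        column = [row[f] * gate[f] for row in conv]
        pooled.extend(sorted(column, reverse=True)[:k])
    return pooled


if __name__ == "__main__":
    feature_map = [[1.0, 5.0],
                   [3.0, 2.0],
                   [2.0, 4.0]]  # 3 positions, 2 filters
    print(gate_and_kmax_pool(feature_map, [1.0, 0.5], k=2))
```

Keeping the top k values per filter, rather than a single max, preserves the strongest few matches of each filter anywhere in the document; the order of positions is discarded. The flat result has `kernal_num * maxpooling_num` entries (150 with sample.config), which is the input width of `decoder_mlp1`.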