├── README.md
├── dataset
    └── readme.txt
├── get_label.py
├── model-img.png
├── model.py
├── output
    └── readme.txt
├── sample-train.sh
└── sample.config


/README.md:
--------------------------------------------------------------------------------
1 | # DAZER
2 | The Tensorflow implementation of our ACL 2018 paper:
3 | ***A Deep Relevance Model for Zero-Shot Document Filtering, Chenliang Li, Wei Zhou, Feng Ji, Yu Duan, Haiqing Chen***
4 | Paper url: http://aclweb.org/anthology/P18-1214
5 |
6 |
7 |
8 |
116 | - BaseNN.embedding_size: word embedding dimension
117 | - BaseNN.max_q_len: max query length
118 | - BaseNN.max_d_len: max document length
119 | - DataGenerator.max_q_len: max query length. Should be the same as BaseNN.max_q_len
120 | - DataGenerator.max_d_len: max document length. Should be the same as BaseNN.max_d_len
121 | - BaseNN.vocabulary_size: vocabulary size
122 | - DataGenerator.vocabulary_size: vocabulary size
123 | - BaseNN.batch_size: batch size
124 | - BaseNN.max_epochs: max number of epochs to train
125 | - BaseNN.eval_frequency: evaluate the model on the validation set every this many epochs
126 | - BaseNN.checkpoint_steps: save the model every this many epochs
127 |
128 |
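
Because the list above asks that the `DataGenerator` lengths match their `BaseNN` counterparts, a quick check before training can catch a mismatch early. The following is a minimal sketch, assuming the settings are held in plain dictionaries; `check_shared_settings` and the dictionary layout are hypothetical and not part of this repository, and `vocabulary_size` is included only on the assumption that both sections should agree, as they do in `sample.config`.

```python
# Hypothetical helper: not part of the DAZER repository.
SHARED_KEYS = ("max_q_len", "max_d_len", "vocabulary_size")


def check_shared_settings(base_nn, data_generator, keys=SHARED_KEYS):
    """Return the names of settings whose values differ between the two sections."""
    return [k for k in keys if base_nn.get(k) != data_generator.get(k)]


# Values copied from sample.config.
base_nn = {"max_q_len": 8, "max_d_len": 500, "vocabulary_size": 253988}
data_generator = {"max_q_len": 8, "max_d_len": 500, "vocabulary_size": 253988}

mismatched = check_shared_settings(base_nn, data_generator)
print("mismatched settings:", mismatched)
```

An empty list means the two sections agree on every shared key.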
129 | **Data**
130 | - DAZER.emb_in: path of initial embeddings file
131 | - DAZER.label_dict_path: path of label dict file
132 | - DAZER.word2id_path: path of word2id file
133 |
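
The line formats of the word2id and label dict files can be read off `get_label.py`: each word2id line is `<word> <id>` separated by one space, and each label dict line is `<class name>/<seed words separated by spaces>`. The sketch below mirrors that parsing on in-memory lines; the sample words and ids are invented for illustration, and the function names `parse_word2id` and `parse_labels` are hypothetical.

```python
# Illustrative contents only; the real files are not included in this repository.
word2id_lines = ["space 11", "orbit 12", "graphics 21", "image 22"]
label_lines = ["sci.space/space orbit", "comp.graphics/graphics image"]


def parse_word2id(lines):
    """Each line is '<word> <id>', separated by a single space."""
    word2id = {}
    for line in lines:
        word, idx = line.strip().split(" ")
        word2id[word] = int(idx)
    return word2id


def parse_labels(lines, word2id):
    """Each line is '<class name>/<seed words separated by spaces>'."""
    label_dict, reverse_label_dict, label_list = {}, {}, []
    for line in lines:
        name, words = line.strip().split("/")
        ids = [word2id[w] for w in words.split(" ")]
        label_dict[name] = ids
        label_list.append(name)
        # The comma-joined id string maps a query back to its class name.
        reverse_label_dict[",".join(str(i) for i in ids)] = name
    return label_dict, reverse_label_dict, label_list


word2id = parse_word2id(word2id_lines)
label_dict, reverse_label_dict, label_list = parse_labels(label_lines, word2id)
print(label_dict)
print(reverse_label_dict)
```

Note that `get_word2id` in the repository opens the word2id file with `encoding='gbk'`, so a file in another encoding may need converting first.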
134 |
135 | **Training Parameters**
136 | - DAZER.epsilon: epsilon for the Adam optimizer
137 | - DAZER.embedding_size: word embedding dimension
138 | - DAZER.vocabulary_size: vocabulary size of the dataset
139 | - DAZER.kernal_width: width of the kernel
140 | - DAZER.kernal_num: number of kernels
141 | - DAZER.regular_term: weight of the L2 loss
142 | - DAZER.maxpooling_num: number of values kept by K-max pooling
143 | - DAZER.decoder_mlp1_num: number of hidden units of the first MLP in the relevance aggregation part
144 | - DAZER.decoder_mlp2_num: number of hidden units of the second MLP in the relevance aggregation part
145 | - DAZER.model_learning_rate: learning rate for the model, excluding the adversarial classifier
146 | - DAZER.adv_learning_rate: learning rate for the adversarial classifier
147 | - DAZER.train_class_num: number of classes at training time
148 | - DAZER.adv_term: weight of the adversarial loss when updating the model's parameters
149 | - DAZER.zsl_num: number of zero-shot labels
150 | - DAZER.zsl_type: which zero-shot label setting to use (there may be several settings with the same number of zero-shot labels; this selects the one used for the experiment, see [get_label.py](https://github.com/WHUIR/DAZER/blob/master/get_label.py) for more details)
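
To make `DAZER.maxpooling_num` and the ranking objective concrete: `model.py` keeps the k largest gated activations of each kernel with `tf.nn.top_k` (so positions are discarded) and trains with a pairwise hinge loss of the form `max(0, 1 - score_pos + score_neg)`, averaged over the batch. The following pure-Python sketch illustrates both on toy numbers; the function names are hypothetical and the values are invented.

```python
def k_max_pooling(activations, k):
    """Return the k largest activations of one kernel, largest first."""
    return sorted(activations, reverse=True)[:k]


def pairwise_hinge_loss(pos_scores, neg_scores, margin=1.0):
    """Mean of max(0, margin - s_pos + s_neg) over the batch."""
    losses = [max(0.0, margin - p + n) for p, n in zip(pos_scores, neg_scores)]
    return sum(losses) / len(losses)


# One kernel's activations over four document positions (toy values).
pooled = k_max_pooling([0.1, 0.9, 0.3, 0.7], k=3)
print(pooled)

# Scores of two (relevant, irrelevant) document pairs (toy values).
loss = pairwise_hinge_loss([0.8, 0.2], [0.1, 0.5])
print(loss)
```

The second pair is ranked the wrong way round, so it contributes most of the loss.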
151 |
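
How `zsl_num` and `zsl_type` pick a setting can be sketched as follows. This mirrors the indexing in `get_label_index` from `get_label.py`, reproducing only the one-label group of settings; the function name `training_label_index` and the shortened label list are assumptions made for illustration.

```python
# Settings copied from get_label.py (only the one-label group is reproduced).
# Each entry is [zero-shot labels, additionally excluded labels].
zeroshot_labels_1 = [
    [['sci.space'], ['comp.graphics']],
    [['rec.sport.baseball'], ['talk.politics.misc']],
    [['sci.med'], ['rec.autos']],
    [['comp.sys.ibm.pc.hardware'], ['rec.sport.hockey']],
]
zeroshot_labels = [zeroshot_labels_1]


def training_label_index(label_list, zsl_num, zsl_type):
    """Map each remaining label to a class index, skipping the excluded ones."""
    unseen, held_out = zeroshot_labels[zsl_num - 1][zsl_type - 1]
    excluded = set(unseen + held_out)
    kept = [name for name in label_list if name not in excluded]
    return {name: i for i, name in enumerate(kept)}


# Shortened label list for illustration; the real one comes from the label dict file.
label_list = ['sci.space', 'comp.graphics', 'sci.med', 'rec.autos']
print(training_label_index(label_list, zsl_num=1, zsl_type=1))
```

Both `zsl_num` and `zsl_type` are 1-based, and the size of the resulting mapping is what `DAZER.train_class_num` has to cover.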
--------------------------------------------------------------------------------
/dataset/readme.txt:
--------------------------------------------------------------------------------
1 | Dataset is saved in this directory.
2 |
--------------------------------------------------------------------------------
/get_label.py:
--------------------------------------------------------------------------------
1 |
2 | def get_word2id(word2id_path):
3 |     word2id = {}
4 |     with open(word2id_path,'r',encoding='gbk') as f:
5 |         for line in f:
6 |             w,idx = line.strip().split(' ')
7 |             word2id[w] = int(idx)
8 |     return word2id
9 |
10 | def get_labels(label_dict_path,word2id_path):
11 |     #use the label-dict file and word2id file to get label_dict, reverse_label_dict and label_list,
12 |     #which are used by our DAZER model
13 |     label_dict = {}
14 |     reverse_label_dict = {}
15 |     label_list = []
16 |     word2id = get_word2id(word2id_path)
17 |     with open(label_dict_path,'r') as f:
18 |         for line in f:
19 |             c_name,words = line.strip().split('/')
20 |             ids = [word2id[w] for w in words.split(' ')]
21 |             label_dict[c_name] = ids
22 |             label_list.append(c_name)
23 |             ids_str = ','.join([str(x) for x in ids])
24 |             reverse_label_dict[ids_str] = c_name
25 |     return label_dict, reverse_label_dict, label_list
26 |
27 | def get_label_index(label_list, zsl_num,zsl_type):
28 |     #get the index of each training label (zero-shot labels are excluded)
29 |     #below are the 20NG experiment settings from our ACL paper; change them for your own dataset
30 |
31 |     #e.g., zeroshot_labels_1[0] = [['sci.space'],['comp.graphics']]
32 |     #it means we use label "sci.space" for zeroshot experiments
33 |     #and randomly pick label 'comp.graphics' to prevent overfitting
34 |     #please refer to the "Evaluation protocol" part of our paper
35 |
36 |     zeroshot_labels_1 = [
37 |         [['sci.space'],['comp.graphics']],
38 |         [['rec.sport.baseball'],['talk.politics.misc']],
39 |         [['sci.med'],['rec.autos']],
40 |         [['comp.sys.ibm.pc.hardware'],['rec.sport.hockey']],
41 |     ]
42 |
43 |     zeroshot_labels_2= [
44 |         [['sci.med','sci.space'],['talk.politics.guns']],
45 |         [['alt.atheism','sci.electronics'],['comp.sys.ibm.pc.hardware']],
46 |         [['soc.religion.christian','talk.politics.mideast'],['rec.sport.baseball']],
47 |         [['rec.sport.baseball','rec.sport.hockey'],['comp.sys.mac.hardware']]
48 |     ]
49 |
50 |     zeroshot_labels_3 = [
51 |         [['comp.sys.ibm.pc.hardware','comp.windows.x','sci.electronics'],['talk.politics.mideast']],
52 |     ]
53 |
54 |     zeroshot_labels = [zeroshot_labels_1,zeroshot_labels_2,zeroshot_labels_3]
55 |
56 |     z_labels = zeroshot_labels[zsl_num-1][zsl_type-1][0] + zeroshot_labels[zsl_num-1][zsl_type-1][1]
57 |     label_test = []
58 |     for _l in label_list:
59 |         if _l not in z_labels:
60 |             label_test.append(_l)
61 |     indexs = list(range(len(label_test)))
62 |     zip_label_index = zip(label_test, indexs)
63 |     return dict(list(zip_label_index))
64 |
65 |
66 |
67 |
--------------------------------------------------------------------------------
/model-img.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/WHUIR/DAZER/cc028184b120148eb45bba875b7f3f4c7f0e5294/model-img.png
--------------------------------------------------------------------------------
/model.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 | import numpy as np
3 | import time
4 | import get_label
5 |
6 | import sys
7 | import argparse
8 | from traitlets.config.loader import PyFileConfigLoader
9 | from traitlets.config import Configurable
10 | from traitlets import (
11 | Int,
12 | Float,
13 | Bool,
14 | Unicode,
15 | )
16 |
17 | class DataGenerator(Configurable):
18 | #params for data generator
19 | max_q_len = Int(10, help='max q len').tag(config=True)
20 | max_d_len = Int(500, help='max document len').tag(config=True)
21 | q_name = Unicode('q')
22 | d_name = Unicode('d')
23 | q_str_name = Unicode('q_str')
24 | q_lens_name = Unicode('q_lens')
25 | aux_d_name = Unicode('d_aux')
26 | vocabulary_size = Int(2000000).tag(config=True)
27 |
28 | def __init__(self, **kwargs):
29 | #init the data generator
30 | super(DataGenerator, self).__init__(**kwargs)
31 | print ("generator's vocabulary size: ", self.vocabulary_size)
32 |
33 | def pairwise_reader(self, pair_stream, batch_size, with_idf=False):
34 | #generate the batch of x,y in training time
35 | l_q = []
36 | l_q_str = []
37 | l_d = []
38 | l_d_aux = []
39 | l_y = []
40 | l_q_lens = []
41 | for line in pair_stream:
42 | cols = line.strip().split('\t')
43 | y = float(1.0)
44 | l_q_str.append(cols[0])
45 | q = np.array([int(t) for t in cols[0].split(',') if int(t) < self.vocabulary_size])
46 | t1 = np.array([int(t) for t in cols[1].split(',') if int(t) < self.vocabulary_size])
47 | t2 = np.array([int(t) for t in cols[2].split(',') if int(t) < self.vocabulary_size])
48 |
49 | #padding
50 | v_q = np.zeros(self.max_q_len)
51 | v_d = np.zeros(self.max_d_len)
52 | v_d_aux = np.zeros(self.max_d_len)
53 |
54 | v_q[:min(q.shape[0], self.max_q_len)] = q[:min(q.shape[0], self.max_q_len)]
55 | v_d[:min(t1.shape[0], self.max_d_len)] = t1[:min(t1.shape[0], self.max_d_len)]
56 | v_d_aux[:min(t2.shape[0], self.max_d_len)] = t2[:min(t2.shape[0], self.max_d_len)]
57 |
58 | l_q.append(v_q)
59 | l_d.append(v_d)
60 | l_d_aux.append(v_d_aux)
61 | l_y.append(y)
62 | l_q_lens.append(len(q))
63 |
64 | if len(l_q) >= batch_size:
65 | Q = np.array(l_q, dtype=int,)
66 | D = np.array(l_d, dtype=int,)
67 | D_aux = np.array(l_d_aux, dtype=int,)
68 | Q_lens = np.array(l_q_lens, dtype=int,)
69 | Y = np.array(l_y, dtype=int,)
70 | X = {self.q_name: Q, self.d_name: D, self.aux_d_name: D_aux, self.q_lens_name: Q_lens, self.q_str_name: l_q_str}
71 | yield X, Y
72 | l_q, l_q_str, l_d, l_d_aux, l_y, l_q_lens = [], [], [], [], [], []
73 | if l_q:
74 | Q = np.array(l_q, dtype=int,)
75 | D = np.array(l_d, dtype=int,)
76 | D_aux = np.array(l_d_aux, dtype=int,)
77 | Q_lens = np.array(l_q_lens, dtype=int,)
78 | Y = np.array(l_y, dtype=int,)
79 | X = {self.q_name: Q, self.d_name: D, self.aux_d_name: D_aux, self.q_lens_name: Q_lens, self.q_str_name: l_q_str}
80 | yield X, Y
81 |
82 | def test_pairwise_reader(self, pair_stream, batch_size):
83 | #generate the batch of x,y in test time
84 | l_q = []
85 | l_q_lens = []
86 | l_d = []
87 |
88 | for line in pair_stream:
89 | cols = line.strip().split('\t')
90 | q = np.array([int(t) for t in cols[0].split(',') if int(t) < self.vocabulary_size])
91 | t = np.array([int(t) for t in cols[1].split(',') if int(t) < self.vocabulary_size])
92 |
93 | v_q = np.zeros(self.max_q_len)
94 | v_d = np.zeros(self.max_d_len)
95 |
96 | v_q[:min(q.shape[0], self.max_q_len)] = q[:min(q.shape[0], self.max_q_len)]
97 | v_d[:min(t.shape[0], self.max_d_len)] = t[:min(t.shape[0], self.max_d_len)]
98 |
99 | l_q.append(v_q)
100 | l_d.append(v_d)
101 | l_q_lens.append(len(q))
102 |
103 | if len(l_q) >= batch_size:
104 | Q = np.array(l_q, dtype=int,)
105 | D = np.array(l_d, dtype=int,)
106 | Q_lens = np.array(l_q_lens, dtype=int,)
107 | X = {self.q_name: Q, self.d_name: D, self.q_lens_name: Q_lens}
108 | yield X
109 | l_q, l_d, l_q_lens = [], [], []
110 | if l_q:
111 | Q = np.array(l_q, dtype=int,)
112 | D = np.array(l_d, dtype=int,)
113 | Q_lens = np.array(l_q_lens, dtype=int,)
114 | X = {self.q_name: Q, self.d_name: D, self.q_lens_name: Q_lens}
115 | yield X
116 |
117 | class BaseNN(Configurable):
118 | #params of base deeprank model
119 | max_q_len = Int(10, help='max q len').tag(config=True)
120 | max_d_len = Int(50, help='max document len').tag(config=True)
121 | batch_size = Int(16, help="minibatch size").tag(config=True)
122 | max_epochs = Float(10, help="maximum number of epochs").tag(config=True)
123 | eval_frequency = Int(10000, help="evaluate on the validation set every * epochs").tag(config=True)
124 | checkpoint_steps = Int(10000, help="save the trained model every * epochs").tag(config=True)
125 |
126 | def __init__(self, **kwargs):
127 | super(BaseNN, self).__init__(**kwargs)
128 | # generator
129 | self.data_generator = DataGenerator(config=self.config)
130 | self.val_data_generator = DataGenerator(config=self.config) #validation in training stage is full test data in 20ng
131 | self.test_data_generator = DataGenerator(config=self.config) #test is zero-shot test data in 20ng (delete docs of zero-shot label)
132 |
133 | @staticmethod
134 | def weight_variable(shape,name):
135 | tmp = np.sqrt(3.0) / np.sqrt(shape[0] + shape[1])
136 | initial = tf.random_uniform(shape, minval=-tmp, maxval=tmp)
137 | return tf.Variable(initial_value=initial,name=name)
138 |
139 | def gen_query_mask(self, Q):
140 | mask = np.zeros((self.batch_size, self.max_q_len))
141 | for b in range(len(Q)):
142 | for q in range(len(Q[b])):
143 | if Q[b][q] > 0:
144 | mask[b][q] = 1
145 |
146 | return mask
147 |
148 | def gen_doc_mask(self, D):
149 | mask = np.zeros((self.batch_size, self.max_d_len))
150 | for b in range(len(D)):
151 | for q in range(len(D[b])):
152 | if D[b][q] > 0:
153 | mask[b][q] = 1
154 |
155 | return mask
156 |
157 | class DAZER(BaseNN):
158 | #params of zeroshot document filtering model
159 | embedding_size = Int(300, help="embedding dimension").tag(config=True)
160 | vocabulary_size = Int(2000000, help="vocabulary size").tag(config=True)
161 | kernal_width = Int(5, help='kernel width').tag(config=True)
162 | kernal_num = Int(50, help='number of kernels').tag(config=True)
163 | regular_term = Float(0.01, help='weight of the L2 loss').tag(config=True)
164 | maxpooling_num = Int(3, help='number of k-maxpooling').tag(config=True)
165 | decoder_mlp1_num = Int(75, help='number of hidden units of first mlp in relevance aggregation part').tag(config=True)
166 | decoder_mlp2_num = Int(1, help='number of hidden units of second mlp in relevance aggregation part').tag(config=True)
167 | emb_in = Unicode('None', help="initial embedding. Terms should be hashed to ids.").tag(config=True)
168 | model_learning_rate = Float(0.001, help="learning rate of model").tag(config=True)
169 | adv_learning_rate = Float(0.001, help='learning rate of adv classifier').tag(config=True)
170 | epsilon = Float(0.00001, help="Epsilon for Adam").tag(config=True)
171 | label_dict_path = Unicode('None', help='label dict path').tag(config=True)
172 | word2id_path = Unicode('None', help='word2id path').tag(config=True)
173 | train_class_num = Int(16, help='num of class in training data').tag(config=True)
174 | adv_term = Float(0.2, help='weight of the adversarial loss').tag(config=True)
175 | zsl_num = Int(1, help='num of zeroshot label').tag(config=True)
176 | zsl_type = Int(1, help='type of zeroshot label setting').tag(config=True)
177 |
178 | def __init__(self, **kwargs):
179 | #init the DAZER model
180 | super(DAZER, self).__init__(**kwargs)
181 | print ("trying to load initial embeddings from: ", self.emb_in)
182 | if self.emb_in != 'None':
183 | self.emb = self.load_word2vec(self.emb_in)
184 | self.embeddings = tf.Variable(tf.constant(self.emb, dtype='float32', shape=[self.vocabulary_size + 1, self.embedding_size]),trainable=False)
185 | print ("Initialized embeddings with {0}".format(self.emb_in))
186 | else:
187 | self.embeddings = tf.Variable(tf.random_uniform([self.vocabulary_size + 1, self.embedding_size], -1.0, 1.0))
188 |
189 | #variables of the DAZER model
190 | self.query_gate_weight = BaseNN.weight_variable((self.embedding_size, self.kernal_num),'gate_weight')
191 | self.query_gate_bias = tf.Variable(initial_value=tf.zeros((self.kernal_num)),name='gate_bias')
192 | self.adv_weight = BaseNN.weight_variable((self.decoder_mlp1_num,self.train_class_num),name='adv_weight')
193 | self.adv_bias = tf.Variable(initial_value=tf.zeros((1,self.train_class_num)),name='adv_bias')
194 | #get the label information to help adversarial learning
195 | self.label_dict, self.reverse_label_dict, self.label_list = get_label.get_labels(self.label_dict_path, self.word2id_path)
196 | self.label_index_dict = get_label.get_label_index(self.label_list, self.zsl_num, self.zsl_type)
197 |
198 | def load_word2vec(self, emb_file_path):
199 | emb = np.zeros((self.vocabulary_size + 1, self.embedding_size))
200 | nlines = 0
201 | with open(emb_file_path) as f:
202 | for line in f:
203 | nlines += 1
204 | if nlines == 1:
205 | continue
206 | items = line.split()
207 | tid = int(items[0])
208 | if tid > self.vocabulary_size:
209 | print (tid)
210 | continue
211 | vec = np.array([float(t) for t in items[1:]])
212 | emb[tid, :] = vec
213 | if nlines % 20000 == 0:
214 | print ("load {0} vectors...".format(nlines))
215 | return emb
216 |
217 | def gen_adv_query_mask(self, q_ids):
218 | q_mask = np.zeros((self.batch_size, self.train_class_num))
219 | for batch_num, b_q_id in enumerate(q_ids):
220 | c_name = self.reverse_label_dict[b_q_id]
221 | c_index = self.label_index_dict[c_name]
222 | q_mask[batch_num][c_index] = 1
223 | return q_mask
224 |
225 | def get_class_gate(self,class_vec, emb_d):
226 | '''
227 | compute the gate in kernel space
228 | :param class_vec: avg emb of seed words, shape [batch_size, embedding_size]
229 | :param emb_d: emb of doc (currently unused)
230 | :return: the class gate, shape [batch_size, 1, kernal_num], broadcast over document positions
231 | '''
232 | gate1 = tf.expand_dims(tf.matmul(class_vec, self.query_gate_weight), axis=1)
233 | bias = tf.expand_dims(self.query_gate_bias,axis=0)
234 | gate = tf.add(gate1, bias)
235 | return tf.sigmoid(gate)
236 |
237 | def L2_model_loss(self):
238 | all_para = [v for v in tf.trainable_variables() if 'b' not in v.name and 'adv' not in v.name]
239 | loss = 0.
240 | for each in all_para:
241 | loss += tf.nn.l2_loss(each)
242 | return loss
243 |
244 | def L2_adv_loss(self):
245 | all_para = [v for v in tf.trainable_variables() if 'b' not in v.name and 'adv' in v.name]
246 | loss = 0.
247 | for each in all_para:
248 | loss += tf.nn.l2_loss(each)
249 | return loss
250 |
251 | def train(self, train_pair_file_path, val_pair_file_path, checkpoint_dir, load_model=False):
252 |
253 | input_q = tf.placeholder(tf.int32, shape=[self.batch_size,self.max_q_len])
254 | input_pos_d = tf.placeholder(tf.int32, shape=[self.batch_size,self.max_d_len])
255 | input_neg_d = tf.placeholder(tf.int32, shape=[self.batch_size,self.max_d_len])
256 | q_lens = tf.placeholder(tf.float32, shape=[self.batch_size,])
257 | q_mask = tf.placeholder(tf.float32, shape=[self.batch_size,self.max_q_len])
258 | pos_d_mask = tf.placeholder(tf.float32, shape=[self.batch_size,self.max_d_len])
259 | neg_d_mask = tf.placeholder(tf.float32, shape=[self.batch_size,self.max_d_len])
260 | input_q_index = tf.placeholder(tf.int32, shape=[self.batch_size,self.train_class_num])
261 |
262 | emb_q = tf.nn.embedding_lookup(self.embeddings,input_q)
263 | class_vec_sum = tf.reduce_sum(
264 | tf.multiply(emb_q,tf.expand_dims(q_mask,axis=-1)),
265 | axis=1
266 | )
267 |
268 | #get class vec
269 | class_vec = tf.div(class_vec_sum,tf.expand_dims(q_lens,-1))
270 | emb_pos_d = tf.nn.embedding_lookup(self.embeddings,input_pos_d)
271 | emb_neg_d = tf.nn.embedding_lookup(self.embeddings,input_neg_d)
272 |
273 | #get query gate
274 | pos_query_gate = self.get_class_gate(class_vec, emb_pos_d)
275 | neg_query_gate = self.get_class_gate(class_vec, emb_neg_d)
276 |
277 | # CNN for document
278 | pos_mult_info = tf.multiply(tf.expand_dims(class_vec, axis=1), emb_pos_d)
279 | pos_sub_info = tf.expand_dims(class_vec,axis=1) - emb_pos_d
280 | pos_conv_input = tf.concat([emb_pos_d,pos_mult_info,pos_sub_info], axis=-1)
281 |
282 | neg_mult_info = tf.multiply(tf.expand_dims(class_vec, axis=1), emb_neg_d)
283 | neg_sub_info = tf.expand_dims(class_vec,axis=1) - emb_neg_d
284 | neg_conv_input = tf.concat([emb_neg_d,neg_mult_info,neg_sub_info], axis=-1)
285 |
286 |
287 | #in fact that's 1D conv, but we implement it by conv2d
288 | pos_conv = tf.layers.conv2d(
289 | inputs = tf.expand_dims(pos_conv_input,axis=-1),
290 | filters = self.kernal_num,
291 | kernel_size=[self.kernal_width,self.embedding_size*3],
292 | strides = [1,self.embedding_size*3],
293 | padding = 'SAME',
294 | trainable = True,
295 | name='doc_conv'
296 | )
297 |
298 | neg_conv = tf.layers.conv2d(
299 | inputs = tf.expand_dims(neg_conv_input,axis=-1),
300 | filters = self.kernal_num,
301 | kernel_size=[self.kernal_width,self.embedding_size*3],
302 | strides = [1,self.embedding_size*3],
303 | padding = 'SAME',
304 | trainable = True,
305 | name='doc_conv',
306 | reuse=True
307 | )
308 | #shape=[batch,max_dlen,1,kernal_num]
309 | #reshape to [batch,max_dlen,kernal_num]
310 | rs_pos_conv = tf.squeeze(pos_conv)
311 | rs_neg_conv = tf.squeeze(neg_conv)
312 |
313 | #query_gate element-wise multiply rs_pos_conv
314 | pos_gate_conv = tf.multiply(pos_query_gate, rs_pos_conv)
315 | neg_gate_conv = tf.multiply(neg_query_gate, rs_neg_conv)
316 |
317 | #K-max_pooling
318 | #transpose to [batch,knum,dlen],then get max k in each kernal filter
319 | transpose_pos_gate_conv = tf.transpose(pos_gate_conv, perm=[0,2,1])
320 | transpose_neg_gate_conv = tf.transpose(neg_gate_conv, perm=[0,2,1])
321 |
322 | #shape = [batch,k_num,maxpooling_num]
323 | #the k-max pooling here is implemented by function top_k, so the relative position information is ignored
324 | pos_kmaxpooling,_ = tf.nn.top_k(
325 | input=transpose_pos_gate_conv,
326 | k=self.maxpooling_num,
327 | )
328 | neg_kmaxpooling,_ = tf.nn.top_k(
329 | input=transpose_neg_gate_conv,
330 | k=self.maxpooling_num,
331 | )
332 |
333 | pos_encoder = tf.reshape(pos_kmaxpooling, shape=(self.batch_size,-1))
334 | neg_encoder = tf.reshape(neg_kmaxpooling, shape=(self.batch_size,-1))
335 |
336 | pos_decoder_mlp1 = tf.layers.dense(
337 | inputs=pos_encoder,
338 | units=self.decoder_mlp1_num,
339 | activation=tf.nn.tanh,
340 | trainable=True,
341 | name='decoder_mlp1'
342 | )
343 |
344 | neg_decoder_mlp1 = tf.layers.dense(
345 | inputs=neg_encoder,
346 | units=self.decoder_mlp1_num,
347 | activation=tf.nn.tanh,
348 | trainable=True,
349 | name='decoder_mlp1',
350 | reuse=True
351 | )
352 |
353 | pos_decoder_mlp2 = tf.layers.dense(
354 | inputs=pos_decoder_mlp1,
355 | units=self.decoder_mlp2_num,
356 | activation=tf.nn.tanh,
357 | trainable=True,
358 | name='decoder_mlp2'
359 | )
360 |
361 | neg_decoder_mlp2 = tf.layers.dense(
362 | inputs=neg_decoder_mlp1,
363 | units=self.decoder_mlp2_num,
364 | activation=tf.nn.tanh,
365 | trainable=True,
366 | name='decoder_mlp2',
367 | reuse=True
368 | )
369 |
370 | score_pos = pos_decoder_mlp2
371 | score_neg = neg_decoder_mlp2
372 |
373 | hinge_loss = tf.reduce_mean(tf.maximum(0.0, 1 - score_pos + score_neg))
374 | adv_prob = tf.nn.softmax(tf.add(tf.matmul(pos_decoder_mlp1, self.adv_weight), self.adv_bias))
375 | log_adv_prob = tf.log(adv_prob)
376 | adv_loss = tf.reduce_mean(tf.reduce_sum(tf.multiply(log_adv_prob, tf.cast(input_q_index,tf.float32)), axis=1, keep_dims=True))
377 | L2_adv_loss = self.regular_term*self.L2_adv_loss()
378 |
379 | #to apply GRL, we use two separate optimizers for the adversarial classifier and the rest of DAZER
380 | #optimizer for adversarial classifier
381 | adv_var_list = [v for v in tf.trainable_variables() if 'adv' in v.name]
382 | adv_opt = tf.train.AdamOptimizer(learning_rate=self.adv_learning_rate, epsilon=self.epsilon).minimize(loss=(-1 * adv_loss + L2_adv_loss), var_list=adv_var_list)
383 |
384 | #optimizer for rest part of DAZER model
385 | L2_model_loss = self.regular_term*self.L2_model_loss()
386 | model_var_list = [v for v in tf.trainable_variables() if 'adv' not in v.name]
387 | loss = hinge_loss + L2_model_loss + (adv_loss * self.adv_term)
388 | model_opt = tf.train.AdamOptimizer(learning_rate=self.model_learning_rate, epsilon=self.epsilon).minimize(loss = loss, var_list = model_var_list)
389 |
390 | config = tf.ConfigProto()
391 | config.gpu_options.allow_growth = True
392 | val_results = []
393 | save_num = 0
394 | save_var = [v for v in tf.trainable_variables()]
395 |
396 | # Create a local session to run the training.
397 | with tf.Session(config=config) as sess:
398 | saver = tf.train.Saver(max_to_keep=50,var_list=save_var)
399 | start_time = time.time()
400 | if not load_model:
401 | print ("Initializing a new model...")
402 | init = tf.global_variables_initializer()
403 | sess.run(init)
404 | print('New model initialized!')
405 | else:
406 | #to load trained model, and keep training
407 | #remember to change the name of ckpt file
408 | init = tf.global_variables_initializer()
409 | sess.run(init)
410 | saver.restore(sess, checkpoint_dir+'/zsl25.ckpt')
411 | print ("model loaded!")
412 |
413 | # Loop through training steps.
414 | step = 0
415 | loss_list = []
416 | for epoch in range(int(self.max_epochs)):
417 | epoch_val_loss = 0
418 | epoch_loss = 0
419 | epoch_hinge_loss = 0.
420 | epoch_adv_loss = 0
421 | epoch_s = time.time()
422 | pair_stream = open(train_pair_file_path)
423 |
424 | for BATCH in self.data_generator.pairwise_reader(pair_stream, self.batch_size):
425 | step += 1
426 | X, Y = BATCH
427 | query = X[u'q']
428 | str_query = X[u'q_str']
429 | q_index = self.gen_adv_query_mask(str_query)
430 | pos_doc = X[u'd']
431 | neg_doc = X[u'd_aux']
432 | train_q_lens = X[u'q_lens']
433 | M_query = self.gen_query_mask(query)
434 | M_pos = self.gen_doc_mask(pos_doc)
435 | M_neg = self.gen_doc_mask(neg_doc)
436 |
437 | if X[u'q_lens'].shape[0] != self.batch_size:
438 | continue
439 | train_feed_dict = {input_q:query,
440 | input_pos_d:pos_doc,
441 | q_lens:train_q_lens,
442 | input_neg_d:neg_doc,
443 | q_mask:M_query,
444 | pos_d_mask:M_pos,
445 | neg_d_mask:M_neg,
446 | input_q_index: q_index}
447 |
448 | _1,l,hinge_l,_2,adv_l = sess.run([model_opt,loss,hinge_loss,adv_opt,adv_loss], feed_dict=train_feed_dict)
449 | epoch_loss += l
450 | epoch_hinge_loss += hinge_l
451 | epoch_adv_loss += adv_l
452 |
453 | if (epoch + 1) % self.eval_frequency == 0:
454 | #after eval_frequency epochs we run model on val dataset
455 | val_start = time.time()
456 | val_pair_stream = open(val_pair_file_path)
457 | for BATCH in self.val_data_generator.pairwise_reader(val_pair_stream, self.batch_size):
458 | X_val,Y_val = BATCH
459 | query = X_val[u'q']
460 | pos_doc = X_val[u'd']
461 | neg_doc = X_val[u'd_aux']
462 | val_q_lens = X_val[u'q_lens']
463 | M_query = self.gen_query_mask(query)
464 | M_pos = self.gen_doc_mask(pos_doc)
465 | M_neg = self.gen_doc_mask(neg_doc)
466 | if X_val[u'q'].shape[0] != self.batch_size:
467 | continue
468 | train_feed_dict = {input_q:query,
469 | input_pos_d:pos_doc,
470 | input_neg_d:neg_doc,
471 | q_lens:val_q_lens,
472 | q_mask:M_query,
473 | pos_d_mask:M_pos,
474 | neg_d_mask:M_neg}
475 |
476 | # Run the graph and fetch some of the nodes.
477 | v_loss = sess.run(hinge_loss, feed_dict=train_feed_dict)
478 | epoch_val_loss += v_loss
479 | val_results.append(epoch_val_loss)
480 |
481 | val_end = time.time()
482 | print('---Validation: epoch %d, %.1f s, val_loss is %f' % (epoch+1,val_end-val_start,epoch_val_loss))
483 | sys.stdout.flush()
484 | loss_list.append(epoch_loss)
485 | epoch_e = time.time()
486 | print('---Train: %d epochs cost %f seconds, hinge cost = %f, model cost = %f, adv cost = %f...'%(epoch+1,epoch_e-epoch_s,epoch_hinge_loss, epoch_loss,epoch_adv_loss))
487 | # save model after checkpoint_steps epochs
488 | if (epoch+1)%self.checkpoint_steps == 0:
489 | save_num += 1
490 | saver.save(sess, checkpoint_dir + 'zsl'+str(epoch+1)+'.ckpt')
491 | pair_stream.close()
492 |
493 | with open('save_training_loss.txt','w') as f:
494 | for index,_loss in enumerate(loss_list):
495 | f.write('epoch'+str(index+1)+', loss:'+str(_loss)+'\n')
496 |
497 | with open('save_val_cost.txt','w') as f:
498 | for index, v_l in enumerate(val_results):
499 | f.write('epoch'+str((index+1)*self.eval_frequency)+' val loss:'+str(v_l)+'\n')
500 |
501 | # end training
502 | end_time = time.time()
503 | print('All costs %f seconds...'%(end_time-start_time))
504 |
505 | def test(self, test_point_file_path, test_size, output_file_path, checkpoint_dir=None, load_model=False):
506 |
507 | input_q = tf.placeholder(tf.int32, shape=[self.batch_size,self.max_q_len])
508 | input_pos_d = tf.placeholder(tf.int32, shape=[self.batch_size,self.max_d_len])
509 | q_lens = tf.placeholder(tf.float32, shape=[self.batch_size,])
510 | q_mask = tf.placeholder(tf.float32, shape=[self.batch_size,self.max_q_len])
511 | pos_d_mask = tf.placeholder(tf.float32, shape=[self.batch_size,self.max_d_len])
512 |
513 | emb_q = tf.nn.embedding_lookup(self.embeddings,input_q)
514 | class_vec_sum = tf.reduce_sum(
515 | tf.multiply(emb_q,tf.expand_dims(q_mask,axis=-1)),
516 | axis=1
517 | )
518 |
519 | class_vec = tf.div(class_vec_sum,tf.expand_dims(q_lens,axis=-1))
520 | emb_pos_d = tf.nn.embedding_lookup(self.embeddings,input_pos_d)
521 |
522 | #get query gate
523 | query_gate = self.get_class_gate(class_vec, emb_pos_d)
524 | pos_mult_info = tf.multiply(tf.expand_dims(class_vec, axis=1), emb_pos_d)
525 | pos_sub_info = tf.expand_dims(class_vec, axis=1) - emb_pos_d
526 | pos_conv_input = tf.concat([emb_pos_d,pos_mult_info, pos_sub_info], axis=-1)
527 |
528 | # CNN for document
529 | pos_conv = tf.layers.conv2d(
530 | inputs = tf.expand_dims(pos_conv_input,axis=-1),
531 | filters = self.kernal_num,
532 | kernel_size=[self.kernal_width,self.embedding_size*3],
533 | strides = [1,self.embedding_size*3],
534 | padding = 'SAME',
535 | trainable = True,
536 | name='doc_conv'
537 | )
538 |
539 | #shape=[batch,max_dlen,1,kernal_num]
540 | #reshape to [batch,max_dlen,kernal_num]
541 | rs_pos_conv = tf.squeeze(pos_conv)
542 |
543 | #query_gate element-wise multiply rs_pos_conv
544 | #[batch,1,kernal_num] , [batch,max_dlen,kernal_num]
545 | pos_gate_conv = tf.multiply(query_gate, rs_pos_conv)
546 |
547 | #K-max_pooling
548 | #transpose to [batch,knum,dlen],then get max k in each kernal filter
549 | transpose_pos_gate_conv = tf.transpose(pos_gate_conv, perm=[0,2,1])
550 |
551 | #[batch,k_num,maxpooling_num]
552 | pos_kmaxpooling,_ = tf.nn.top_k(
553 | input=transpose_pos_gate_conv,
554 | k=self.maxpooling_num,
555 | )
556 | pos_encoder = tf.reshape(pos_kmaxpooling, shape=(self.batch_size,-1))
557 |
558 | pos_decoder_mlp1 = tf.layers.dense(
559 | inputs=pos_encoder,
560 | units=self.decoder_mlp1_num,
561 | activation=tf.nn.tanh,
562 | trainable=True,
563 | name='decoder_mlp1'
564 | )
565 |
566 | pos_decoder_mlp2 = tf.layers.dense(
567 | inputs=pos_decoder_mlp1,
568 | units=self.decoder_mlp2_num,
569 | activation=tf.nn.tanh,
570 | trainable=True,
571 | name='decoder_mlp2'
572 | )
573 |
574 | score_pos = pos_decoder_mlp2
575 | config = tf.ConfigProto()
576 | config.gpu_options.allow_growth = True
577 | save_var = [v for v in tf.trainable_variables()]
578 | # Create a local session to run the testing.
579 | for i in range(int(self.max_epochs/self.checkpoint_steps)):
580 | with tf.Session(config=config) as sess:
581 | test_point_stream = open(test_point_file_path)
582 | outfile = open(output_file_path+'-epoch'+str(self.checkpoint_steps*(i+1))+'.txt', 'w')
583 | saver = tf.train.Saver(var_list=save_var)
584 |
585 | if load_model:
586 | p = checkpoint_dir + 'zsl'+str(self.checkpoint_steps*(i+1))+'.ckpt'
587 | init = tf.global_variables_initializer()
588 | sess.run(init)
589 | saver.restore(sess, p)
590 | print ("model loaded!")
591 | else:
592 | init = tf.global_variables_initializer()
593 | sess.run(init)
594 |
595 | # Loop through test batches.
596 | for b in range(int(np.ceil(float(test_size)/self.batch_size))):
597 | X = next(self.test_data_generator.test_pairwise_reader(test_point_stream, self.batch_size))
598 | if(X[u'q'].shape[0] != self.batch_size):
599 | continue
600 | query = X[u'q']
601 | pos_doc = X[u'd']
602 | test_q_lens = X[u'q_lens']
603 | M_query = self.gen_query_mask(query)
604 | M_pos = self.gen_doc_mask(pos_doc)
605 | test_feed_dict = {input_q: query,
606 | input_pos_d: pos_doc,
607 | q_lens: test_q_lens,
608 | q_mask: M_query,
609 | pos_d_mask: M_pos}
610 |
611 | # Run the graph and fetch some of the nodes.
612 | scores = sess.run(score_pos, feed_dict=test_feed_dict)
613 |
614 | for score in scores:
615 | outfile.write('{0}\n'.format(score[0]))
616 |
617 | outfile.close()
618 | test_point_stream.close()
619 |
620 | if __name__ == '__main__':
621 | parser = argparse.ArgumentParser()
622 | parser.add_argument("config_file_path")
623 |
624 | parser.add_argument("--train", action='store_true')
625 | parser.add_argument("--train_file", '-f', help="train_pair_file_path")
626 | parser.add_argument("--validation_file", '-v', help="val_pair_file_path")
627 | parser.add_argument("--train_size", '-z', type=int, help="number of train samples")
628 | parser.add_argument("--load_model", '-l', action='store_true')
629 |
630 | parser.add_argument("--test", action="store_true")
631 | parser.add_argument("--test_file")
632 | parser.add_argument("--test_size", type=int, default=0)
633 | parser.add_argument("--output_score_file", '-o')
634 | parser.add_argument("--emb_file_path", '-e')
635 | parser.add_argument("--checkpoint_dir", '-s', help="directory where model checkpoints are saved and loaded")
636 |
637 | args = parser.parse_args()
638 |
639 | conf = PyFileConfigLoader(args.config_file_path).load_config()
640 |
641 | if args.train:
642 | nn = DAZER(config=conf)
643 | nn.train(train_pair_file_path=args.train_file,
644 | val_pair_file_path=args.validation_file,
645 | checkpoint_dir=args.checkpoint_dir,
646 | load_model=args.load_model)
647 | else:
648 | nn = DAZER(config=conf)
649 | nn.test(test_point_file_path=args.test_file,
650 | test_size=args.test_size,
651 | output_file_path=args.output_score_file,
652 | load_model=True,
653 | checkpoint_dir=args.checkpoint_dir)
654 |
655 |
--------------------------------------------------------------------------------
/output/readme.txt:
--------------------------------------------------------------------------------
1 | Output file is saved in this directory
2 |
--------------------------------------------------------------------------------
/sample-train.sh:
--------------------------------------------------------------------------------
1 | CUDA_VISIBLE_DEVICES='0' python model.py sample.config --train --train_file dataset/20ng_train.txt --validation_file dataset/20ng_val.txt --checkpoint_dir output/
2 |
--------------------------------------------------------------------------------
/sample.config:
--------------------------------------------------------------------------------
1 | c = get_config()
2 |
3 | c.DataGenerator.max_q_len=8
4 | c.DataGenerator.max_d_len=500
5 | c.DataGenerator.vocabulary_size=253988
6 |
7 | c.BaseNN.vocabulary_size=253988
8 | c.BaseNN.embedding_size=300
9 | c.BaseNN.max_q_len=8
10 | c.BaseNN.max_d_len=500
11 | c.BaseNN.max_epochs=50
12 | c.BaseNN.eval_frequency=5
13 | c.BaseNN.checkpoint_steps=5
14 | c.BaseNN.batch_size=16
15 |
16 | c.DAZER.emb_in = 'glove_20ng_knrm.txt'
17 | c.DAZER.kernal_width=5
18 | c.DAZER.kernal_num=50
19 | c.DAZER.regular_term=0.0001
20 | c.DAZER.adv_term = 0.1
21 | c.DAZER.train_class_num = 17
22 | c.DAZER.model_learning_rate=0.00001
23 | c.DAZER.adv_learning_rate=0.00001
24 | c.DAZER.maxpooling_num=3
25 | c.DAZER.decoder_mlp1_num=75
26 | c.DAZER.decoder_mlp2_num=1
27 | c.DAZER.word2id_path = 'knrm_word2id.txt'
28 | c.DAZER.zsl_num=1
29 | c.DAZER.zsl_type=1
30 | c.DAZER.label_dict_path = '20ng seedwords.txt'
31 |
--------------------------------------------------------------------------------