├── .DS_Store ├── README.md ├── average_chk.sh ├── cnn_text_discriminator.py ├── configs ├── config_discriminator_pretrain.yaml ├── config_gan_train.yaml ├── config_generate_sample.yaml ├── config_generator_pretrain_BPE.yaml ├── config_generator_train.yaml ├── config_generator_train_BPE.yaml └── config_text_discriminator_pretrain.yaml ├── data ├── .DS_Store ├── answer_test.txt ├── answer_train.txt ├── answer_val.txt ├── question_test.txt ├── question_train.txt └── question_val.txt ├── data_iterator.py ├── evaluate.py ├── evaluate.sh ├── gan_train.py ├── gan_train.sh ├── generate_samples.py ├── model.py ├── multi-bleu.perl ├── runDev.py ├── run_gan_dev.sh ├── share_function.py ├── shuffle.py ├── tensor2tensor ├── __init__.py ├── __init__.pyc ├── avg_checkpoints.py ├── common_attention.py ├── common_attention.pyc ├── common_layers.py ├── common_layers.pyc ├── expert_utils.py └── expert_utils.pyc ├── text_disc_pretrain.sh ├── text_discriminator_pretrain.py ├── train.py ├── train_qgen.py ├── train_unSuper_en2de.sh ├── train_unsupervised.py ├── utils.py └── vocab.py /.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vishwajeet93/clqg/75f7311983f48f164fd371641c9abaabbba9ef3d/.DS_Store -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | 2 | Configuration Files: 3 | 1. config_generator_train_BPE.yaml - Contains the configuration for training the supervised question generation task. 4 | 2. config_generator_pretrain_BPE.yaml - Contains the configuration for training the unsupervised task (with back-translation). 5 | 6 | 7 | 8 | To run unsupervised pretraining: 9 | ``` 10 | python train_unsupervised.py -c configs/config_generator_pretrain_BPE.yaml 11 | ``` 12 | 13 | To train the supervised question generation task: 14 | ``` 15 | python train_qgen.py -c configs/config_generator_train_BPE.yaml 16 | ``` 17 | 18 | 19 | 20 | -------------------------------------------------------------------------------- /average_chk.sh: -------------------------------------------------------------------------------- 1 | TRAIN_DIR=model3 2 | 3 | python avg_checkpoints.py --checkpoints="model.ckpt-109305, model.ckpt-109692, model.ckpt-110050, model.ckpt-110437" \ 4 | --prefix=$TRAIN_DIR \ 5 | --output_path="$TRAIN_DIR/averaged.ckpt" 6 | -------------------------------------------------------------------------------- /cnn_text_discriminator.py: -------------------------------------------------------------------------------- 1 | # this module implements a CNN discriminator that classifies sentences 2 | 3 | import tensorflow as tf 4 | from tensorflow.python.ops import math_ops 5 | from tensorflow.python.ops import variable_scope as vs 6 | 7 | #from data_iterator import domainTextIterator 8 | from data_iterator import unSuperGanTextIterator 9 | from data_iterator import disThreeTextIterator 10 | from data_iterator import disTextIterator 11 | from share_function import dis_length_prepare 12 | from share_function import average_clip_gradient 13 | from share_function import average_clip_gradient_by_value 14 | from share_function import dis_three_length_prepare 15 | from model import split_tensor 16 | 17 | 18 | import time 19 | import numpy 20 | import os 21 | 22 | from tensorflow.contrib.layers.python.layers import batch_norm as batch_norm 23 | 24 | def conv_batch_norm(x, is_train, scope='bn', decay=0.9, reuse_var = False): 25 | 26 | out =
batch_norm(x, 27 | decay=decay, 28 | center=True, 29 | scale=True, 30 | updates_collections=None, 31 | is_training=is_train, 32 | reuse=reuse_var, 33 | trainable=True, 34 | scope=scope) 35 | return out 36 | 37 | def linear(inputs, output_size, use_bias, scope='linear'): 38 | if not scope: 39 | scope=tf.get_variable_scope() 40 | 41 | input_size = inputs.get_shape()[1].value 42 | dtype=inputs.dtype 43 | 44 | with tf.variable_scope(scope): 45 | weights=tf.get_variable('weights', [input_size, output_size], dtype=dtype) 46 | res = tf.matmul(inputs, weights) 47 | if not use_bias: 48 | return res 49 | biases=tf.get_variable('biases', [output_size], dtype=dtype) 50 | return tf.add(res, biases) 51 | 52 | def highway(input_, size, layer_size=1, bias=-2, f=tf.nn.relu, reuse_var=False): 53 | output = input_ 54 | if reuse_var == True: 55 | tf.get_variable_scope().reuse_variables() 56 | for idx in xrange(layer_size): 57 | output = f(linear(output, size, 0, scope='output_lin_%d' %idx)) 58 | transform_gate = tf.sigmoid(linear(input_, size, 0, scope='transform_lin_%d'%idx) +bias) 59 | carry_gate = 1. - transform_gate 60 | output = transform_gate * output + carry_gate * input_ 61 | return output 62 | 63 | def highway_s(input_, size, layer_size=1, bias=-2, f=tf.nn.relu, reuse_var=False): 64 | output = input_ 65 | if reuse_var == True: 66 | tf.get_variable_scope().reuse_variables() 67 | for idx in xrange(layer_size): 68 | output = f(linear(output, size, 0, scope='output_s_lin_%d' %idx)) 69 | transform_gate = tf.sigmoid(linear(input_, size, 0, scope='transform_s_lin_%d'%idx) +bias) 70 | carry_gate = 1. - transform_gate 71 | output = transform_gate * output + carry_gate * input_ 72 | return output 73 | 74 | class cnn_layer(object): 75 | def __init__(self, filter_size, dim_word, num_filter, scope='cnn_layer', init_device='/cpu:0', reuse_var=False): 76 | self.filter_size = filter_size 77 | self.dim_word = dim_word 78 | self.num_filter = num_filter 79 | self.scope = scope 80 | self.reuse_var = reuse_var 81 | if reuse_var == False: 82 | with tf.variable_scope(self.scope or 'cnn_layer'): 83 | with tf.variable_scope('self_model'): 84 | with tf.device(init_device): 85 | filter_shape = [filter_size, dim_word, 1, num_filter] 86 | b = tf.get_variable('b', initializer = tf.constant(0.1, shape=[num_filter])) 87 | W = tf.get_variable('W', initializer = tf.truncated_normal(filter_shape, stddev=0.1)) 88 | 89 | ## convolutuon with batch normalization 90 | def conv_op(self, input_sen, stride, is_train, padding='VALID', is_batch_norm = True, f_activation=tf.nn.relu): 91 | with tf.variable_scope(self.scope): 92 | with tf.variable_scope('self_model'): 93 | tf.get_variable_scope().reuse_variables() 94 | b = tf.get_variable('b') 95 | W = tf.get_variable('W') 96 | conv = tf.nn.conv2d( 97 | input_sen, 98 | W, 99 | stride, 100 | padding, 101 | name='conv') 102 | bias_add = tf.nn.bias_add(conv, b) 103 | 104 | if is_batch_norm : 105 | with tf.variable_scope('conv_batch_norm'): 106 | conv_bn = conv_batch_norm(bias_add, is_train = is_train, scope='bn', reuse_var = self.reuse_var) 107 | h = f_activation(conv_bn, name='relu') 108 | else: 109 | h = f_activation(bias_add, name='relu') 110 | 111 | return h 112 | 113 | class text_DisCNN(object): 114 | """ 115 | A CNN for sentence classification 116 | Uses an embedding layer, followed by a convolutional layer, max_pooling and softmax layer. 
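Each input sentence is embedded, passed through one convolution + max-pooling branch per filter size, and the concatenated pooled features go through a highway layer, dropout and a softmax over num_classes (see build_model below).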
117 | vocab_size_s: the size of Chinese vocab 118 | vocab_size_t: the size of English vocab 119 | source_dict: Chinese dict 120 | target_dict: English dict 121 | s_domain_data: Chinese data 122 | t_domain_data: English data 123 | g_domain_data: generated data 124 | """ 125 | 126 | def __init__(self, sess, max_len, num_classes, vocab_size_s, batch_size, dim_word, filter_sizes, num_filters, source_dict, gpu_device, s_domain_data, s_domain_generated_data, dev_s_domain_data, dev_s_domain_generated_data=None, max_epoches=10, dispFreq = 1, saveFreq = 10, devFreq=1000, clip_c = 1.0, optimizer='adadelta', saveto='text_discriminator', reload=False, reshuffle = False, l2_reg_lambda=0.0, scope='text_discnn', init_device="/cpu:0", reuse_var=False): 127 | 128 | self.sess = sess 129 | self.max_len = max_len 130 | self.num_classes = num_classes 131 | self.vocab_size_s = vocab_size_s 132 | self.dim_word = dim_word 133 | self.filter_sizes = filter_sizes 134 | self.num_filters = num_filters 135 | self.l2_reg_lambda = l2_reg_lambda 136 | self.num_filters_total = sum(self.num_filters) 137 | self.scope = scope 138 | self.s_domain_data = s_domain_data 139 | self.s_domain_generated_data = s_domain_generated_data 140 | self.dev_s_domain_data = dev_s_domain_data 141 | self.dev_s_domain_generated_data = dev_s_domain_generated_data 142 | self.reshuffle = reshuffle 143 | self.batch_size = batch_size 144 | self.max_epoches = max_epoches 145 | self.dispFreq = dispFreq 146 | self.saveFreq = saveFreq 147 | self.devFreq = devFreq 148 | self.clip_c = clip_c 149 | self.saveto = saveto 150 | self.reload = reload 151 | 152 | print('num_filters_total is ', self.num_filters_total) 153 | 154 | if optimizer == 'adam': 155 | self.optimizer = tf.train.AdamOptimizer() 156 | print("using adam as the optimizer for the discriminator") 157 | elif optimizer == 'adadelta': 158 | self.optimizer = tf.train.AdadeltaOptimizer(learning_rate=1.,rho=0.95,epsilon=1e-6) 159 | print("using adadelta as the optimizer for the discriminator") 160 | elif optimizer == 'sgd': 161 | self.optimizer = tf.train.GradientDescentOptimizer(0.0001) 162 | print("using sgd as the optimizer for the discriminator") 163 | elif optimizer == 'rmsprop': 164 | self.optimizer = tf.train.RMSPropOptimizer(0.0001) 165 | print("using rmsprop as the optimizer for the discriminator") 166 | else : 167 | raise ValueError("optimizer must be adam, adadelta, sgd or rmsprop.") 168 | 169 | dictionaries=[] 170 | dictionaries.append(source_dict) 171 | self.dictionaries = dictionaries 172 | 173 | gpu_string = gpu_device 174 | gpu_devices = [] 175 | gpu_devices = gpu_string.split('-') 176 | self.gpu_devices = gpu_devices[1:] 177 | self.gpu_num = len(self.gpu_devices) 178 | #print('the gpu_num is ', self.gpu_num) 179 | 180 | self.build_placeholder() 181 | 182 | if reuse_var == False: 183 | with tf.variable_scope(self.scope or 'disCNN'): 184 | with tf.variable_scope('model_self'): 185 | with tf.device(init_device): 186 | embeddingtable = tf.get_variable('embeddingtable', initializer = tf.random_uniform([self.vocab_size_s, self.dim_word], -1.0, 1.0)) 187 | W = tf.get_variable('W', initializer = tf.truncated_normal([self.num_filters_total, self.num_classes], stddev=0.1)) 188 | b = tf.get_variable('b', initializer = tf.constant(0.1, shape=[self.num_classes])) 189 | 190 | ## build_model ########## 191 | print('building train model') 192 | self.build_train_model() 193 | print('done') 194 | print('build_discriminate ') 195 | #self.build_discriminate(gpu_device=self.gpu_devices[-1])
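# NOTE: build_discriminator_model replicates the prediction graph on every device
# in gpu_devices; split_tensor shards the input batch across GPUs and the
# per-device softmax outputs are concatenated back into dis_ypred_for_auc.
196 |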
self.build_discriminator_model(dis_devices=self.gpu_devices) 197 | print('done') 198 | 199 | params = [param for param in tf.global_variables() if self.scope in param.name] 200 | if not self.sess.run(tf.is_variable_initialized(params[0])): 201 | init_op = tf.variables_initializer(params) 202 | self.sess.run(init_op) 203 | 204 | saver = tf.train.Saver(params) 205 | self.saver = saver 206 | 207 | if self.reload: 208 | print('reloading file from %s' % self.saveto) 209 | self.saver.restore(self.sess, self.saveto) 210 | print('reloading file done') 211 | 212 | 213 | def build_placeholder(self, gpu_num = None): 214 | self.x_list = [] 215 | self.y_list = [] 216 | self.drop_list = [] 217 | if gpu_num is None: 218 | gpu_num = self.gpu_num 219 | 220 | for i in range(gpu_num): 221 | input_x = tf.placeholder(tf.int32, [self.max_len, None], name='input_x') 222 | input_y = tf.placeholder(tf.float32, [self.num_classes, None], name='input_y') 223 | drop_prob = tf.placeholder(tf.float32, name='dropout_prob') 224 | 225 | self.x_list.append(input_x) 226 | self.y_list.append(input_y) 227 | self.drop_list.append(drop_prob) 228 | 229 | def get_inputs(self, gpu_device): 230 | try: 231 | gpu_id = self.gpu_devices.index(gpu_device) 232 | except: 233 | raise ValueError('get inputs error!') 234 | return self.x_list[gpu_id], self.y_list[gpu_id], self.drop_list[gpu_id] 235 | 236 | 237 | def build_model(self, reuse_var=False, gpu_device='0'): 238 | with tf.variable_scope(self.scope): 239 | with tf.device('/gpu:%d' % int(gpu_device)): 240 | input_x, input_y, drop_keep_prob = self.get_inputs(gpu_device) 241 | 242 | input_x_trans = tf.transpose(input_x, [1,0]) 243 | input_y_trans = tf.transpose(input_y, [1,0]) 244 | 245 | with tf.variable_scope('model_self'): 246 | tf.get_variable_scope().reuse_variables() 247 | W = tf.get_variable('W') 248 | b = tf.get_variable('b') 249 | embeddingtable = tf.get_variable('embeddingtable') 250 | 251 | sentence_embed = tf.nn.embedding_lookup(embeddingtable, input_x_trans) 252 | sentence_embed_expanded = tf.expand_dims(sentence_embed, -1) 253 | pooled_outputs = [] 254 | for filter_size, num_filter in zip(self.filter_sizes, self.num_filters): 255 | scope = "conv_maxpool-%s" % filter_size 256 | filter_shape = [filter_size, self.dim_word, 1, num_filter] 257 | strides=[1,1,1,1] 258 | conv = cnn_layer(filter_size, self.dim_word, num_filter, scope=scope, reuse_var = reuse_var) 259 | is_train = True 260 | conv_out = conv.conv_op(sentence_embed_expanded, strides, is_train=is_train) 261 | pooled = tf.nn.max_pool(conv_out, ksize=[1, (self.max_len - filter_size +1), 1, 1], strides=strides, padding='VALID', name='pool') 262 | pooled_outputs.append(pooled) 263 | 264 | h_pool = tf.concat(axis=3, values=pooled_outputs) 265 | h_pool_flat = tf.reshape(h_pool, [-1, self.num_filters_total]) 266 | 267 | h_highway = highway(h_pool_flat, h_pool_flat.get_shape()[1], 1, 0, reuse_var=reuse_var) 268 | h_drop = tf.nn.dropout(h_highway, drop_keep_prob) 269 | 270 | 271 | scores = tf.nn.xw_plus_b(h_drop, W, b, name='scores') 272 | ypred_for_auc = tf.nn.softmax(scores) 273 | predictions = tf.argmax(scores, 1, name='prediction') 274 | losses = tf.nn.softmax_cross_entropy_with_logits(logits=scores, labels=input_y_trans) 275 | 276 | correct_predictions = tf.equal(predictions, tf.argmax(input_y_trans, 1)) 277 | accuracy = tf.reduce_mean(tf.cast(correct_predictions, 'float'), name='accuracy') 278 | 279 | params = [param for param in tf.trainable_variables() if self.scope in param.name] 280 | 281 | #for param in params: 282 | 
# print param.name 283 | 284 | #self.params = params 285 | 286 | grads_and_vars = self.optimizer.compute_gradients(losses, params) 287 | 288 | #for grad, var in grads_and_vars: 289 | # print (var.name, grad) 290 | 291 | l2_loss = tf.constant(0.0) 292 | l2_loss += tf.nn.l2_loss(W) 293 | l2_loss += tf.nn.l2_loss(b) 294 | loss = tf.reduce_mean(losses) + self.l2_reg_lambda * l2_loss 295 | 296 | return input_x, input_y, drop_keep_prob, ypred_for_auc, predictions, loss, correct_predictions, accuracy, grads_and_vars 297 | 298 | def build_discriminator_body(self, input_x, input_y, dropout_keep_prob, reuse_var=True): 299 | 300 | input_x_trans = input_x 301 | input_y_trans = input_y 302 | dis_dropout_keep_prob = dropout_keep_prob 303 | 304 | with tf.variable_scope('model_self'): 305 | tf.get_variable_scope().reuse_variables() 306 | W = tf.get_variable('W') 307 | b = tf.get_variable('b') 308 | embeddingtable = tf.get_variable('embeddingtable') 309 | 310 | sentence_embed = tf.nn.embedding_lookup(embeddingtable, input_x_trans) 311 | 312 | sentence_embed_expanded = tf.expand_dims(sentence_embed, -1) 313 | 314 | pooled_outputs = [] 315 | 316 | for filter_size, num_filter in zip(self.filter_sizes, self.num_filters): 317 | #print('the filter size is ', filter_size) 318 | scope = "conv_maxpool-%s" % filter_size 319 | filter_shape = [filter_size, self.dim_word, 1, num_filter] 320 | strides=[1,1,1,1] 321 | conv = cnn_layer(filter_size, self.dim_word, num_filter, scope=scope, reuse_var = reuse_var) 322 | is_train = False 323 | conv_out = conv.conv_op(sentence_embed_expanded, strides, is_train=is_train) 324 | pooled = tf.nn.max_pool(conv_out, ksize=[1, (self.max_len - filter_size +1), 1, 1], strides=strides, padding='VALID', name='pool') 325 | #print('the shape of the pooled is ', pooled.get_shape()) 326 | pooled_outputs.append(pooled) 327 | 328 | h_pool = tf.concat(axis=3, values=pooled_outputs) 329 | #print('the shape of h_pool is ', h_pool.get_shape()) 330 | #print('the num_filters_total is ', self.num_filters_total) 331 | 332 | h_pool_flat = tf.reshape(h_pool, [-1, self.num_filters_total]) 333 | 334 | #print('the shape of h_pool_flat is ', h_pool_flat.get_shape()) 335 | 336 | h_highway = highway(h_pool_flat, h_pool_flat.get_shape()[1], 1, 0, reuse_var=reuse_var) 337 | h_drop = tf.nn.dropout(h_highway, dis_dropout_keep_prob) 338 | 339 | scores = tf.nn.xw_plus_b(h_drop, W, b, name='scores') 340 | 341 | ypred_for_auc = tf.nn.softmax(scores) 342 | predictions = tf.argmax(scores, 1, name='prediction') 343 | losses = tf.nn.softmax_cross_entropy_with_logits(logits=scores, labels=input_y_trans) 344 | 345 | correct_predictions = tf.equal(predictions, tf.argmax(input_y_trans, 1)) 346 | accuracy = tf.reduce_mean(tf.cast(correct_predictions, 'float'), name='accuracy') 347 | 348 | grads_and_vars = self.optimizer.compute_gradients(losses) 349 | 350 | l2_loss = tf.constant(0.0) 351 | l2_loss += tf.nn.l2_loss(W) 352 | l2_loss += tf.nn.l2_loss(b) 353 | loss = tf.reduce_mean(losses) + self.l2_reg_lambda * l2_loss 354 | 355 | return ypred_for_auc 356 | 357 | def build_discriminator_model(self, dis_devices): 358 | with tf.variable_scope(self.scope): 359 | with tf.device('/cpu:0'): 360 | self.dis_input_x = tf.placeholder(tf.int32, [self.max_len, None], name='input_x') 361 | self.dis_input_y = tf.placeholder(tf.float32, [self.num_classes, None], name='input_y') 362 | self.dis_dropout_keep_prob = tf.placeholder(tf.float32, name='dropout_keep_prob') 363 | 364 | dis_input_x = tf.transpose(self.dis_input_x, [1, 0]) 365 | dis_input_y 
= tf.transpose(self.dis_input_y, [1, 0]) 366 | 367 | devices = ['/gpu:' + i for i in dis_devices] 368 | 369 | input_x_list = split_tensor(dis_input_x, len(devices)) 370 | input_y_list = split_tensor(dis_input_y, len(devices)) 371 | 372 | dis_dropout_keep_prob = [self.dis_dropout_keep_prob] * len(devices) 373 | 374 | batch_size_list = [tf.shape(x)[0] for x in input_x_list] 375 | 376 | pred_list = [None] * len(devices) 377 | for i, (input_x, input_y, drop, device) in enumerate(zip(input_x_list, input_y_list, dis_dropout_keep_prob, devices)): 378 | with tf.device(device): 379 | print("building discriminator model on device %s" % device) 380 | ypred_for_auc = self.build_discriminator_body(input_x, input_y, drop, reuse_var=True) 381 | pred_list[i] = ypred_for_auc 382 | 383 | self.dis_ypred_for_auc = tf.concat(pred_list, axis=0) 384 | 385 | 386 | def build_train_model(self): 387 | loss = tf.convert_to_tensor(0.) 388 | grads = [] 389 | accu = tf.convert_to_tensor(0.) 390 | 391 | reuse_var = False 392 | for i, gpu_device in enumerate(self.gpu_devices): 393 | #print('i is %d, gpu is %s' %(i, gpu_device)) 394 | if i > 0: 395 | reuse_var = True 396 | #print('reuse_var is ', reuse_var) 397 | _, _, _, ypred_for_auc, predictions, losses, correct_predictions, accuracy, grads_and_vars = self.build_model(reuse_var=reuse_var, gpu_device=gpu_device) 398 | loss += losses 399 | accu += accuracy 400 | grads.append(grads_and_vars) 401 | 402 | loss = loss / self.gpu_num 403 | accuracy = accu / self.gpu_num 404 | #grads_and_vars = average_clip_gradient(grads, self.clip_c) 405 | grads_and_vars = average_clip_gradient_by_value(grads, -1.0, 1.0) 406 | optm = self.optimizer.apply_gradients(grads_and_vars) 407 | 408 | clip_ops = [] 409 | 410 | var_s = [var for var in tf.trainable_variables() if self.scope in var.name] 411 | for var in var_s: 412 | clip_ops.append(tf.assign(var, tf.clip_by_value(var, -1., 1.))) 413 | 414 | clip_ops = tf.group(*clip_ops) 415 | 416 | self.clip_ops = clip_ops 417 | 418 | self.train_loss = loss 419 | self.train_accuracy = accuracy 420 | self.train_grads_and_vars = grads_and_vars 421 | self.train_optm = optm 422 | self.train_ypred = ypred_for_auc 423 | 424 | 425 | def train(self, max_epoch = None, s_domain_data=None, s_domain_generated_data=None): 426 | 427 | if s_domain_data is None or s_domain_generated_data is None: 428 | s_domain_data = self.s_domain_data 429 | s_domain_generated_data = self.s_domain_generated_data 430 | 431 | print('the s_domain, s_domain_generated_data is %s, %s' %(s_domain_data, s_domain_generated_data)) 432 | 433 | if max_epoch is None: 434 | max_epoch = self.max_epoches 435 | 436 | def train_iter(): 437 | Epoch = 0 438 | while True: 439 | if self.reshuffle: 440 | os.popen('python shuffle.py ' + s_domain_data +' ' + s_domain_generated_data) 441 | os.popen('mv ' + s_domain_data + '.shuf '+ s_domain_data) 442 | os.popen('mv ' + s_domain_generated_data + '.shuf ' + s_domain_generated_data) 443 | 444 | 445 | disTrain = disTextIterator(s_domain_data, s_domain_generated_data, 446 | self.dictionaries[0], 447 | batch = self.batch_size * self.gpu_num, 448 | maxlen = self.max_len, 449 | n_words_target = self.vocab_size_s) 450 | 451 | 452 | ExampleNum = 0 453 | print( 'Epoch :', Epoch) 454 | 455 | EpochStart = time.time() 456 | for x, y in disTrain: 457 | if len(x) < self.gpu_num: 458 | continue 459 | ExampleNum+=len(x) 460 | yield x, y, Epoch 461 | TimeCost = time.time() - EpochStart 462 | 463 | Epoch +=1 464 | print('Seen ', ExampleNum, ' examples for text_discriminator. 
Time Cost : ', TimeCost) 465 | 466 | train_it = train_iter() 467 | drop_prob = 1.0 468 | 469 | TrainStart = time.time() 470 | epoch = 0 471 | uidx = 0 472 | HourIdx = 0 473 | print('train begin') 474 | while epoch < max_epoch: 475 | if time.time() - TrainStart >= 3600 * HourIdx: 476 | print('------------------------------------------Hour %d --------------------' % HourIdx) 477 | HourIdx +=1 478 | 479 | BatchStart = time.time() 480 | x, y, epoch = next(train_it) 481 | uidx +=1 482 | #print('uidx is ', uidx) 483 | #print(len(x)) 484 | if x is None or len(x) % self.gpu_num != 0: 485 | print('the positive data is bad') 486 | continue 487 | x_data_list = numpy.split(numpy.array(x), self.gpu_num) 488 | y_data_list = numpy.split(numpy.array(y), self.gpu_num) 489 | 490 | myFeed_dict={} 491 | for i, x, y in zip(range(self.gpu_num), x_data_list, y_data_list): 492 | x = x.tolist() 493 | x, y = dis_length_prepare(x, y, self.num_classes, self.max_len) 494 | myFeed_dict[self.x_list[i]]=x 495 | myFeed_dict[self.y_list[i]]=y 496 | myFeed_dict[self.drop_list[i]]=drop_prob 497 | 498 | _, loss_out, accuracy_out, grads_out = self.sess.run([self.train_optm, self.train_loss, self.train_accuracy, self.train_grads_and_vars], feed_dict=myFeed_dict) 499 | 500 | if uidx == 1: 501 | _ = self.sess.run(self.clip_ops) 502 | #x_variable = [self.sess.run(tf.assign(x, tf.clip_by_value(x, -1.0, 1.0))) for x in tf.trainable_variables() if self.scope in x.name] # clip the value into -0.01 to 0.01 503 | 504 | #print('ypred_for_auc is ', ypred_out) 505 | BatchTime = time.time()-BatchStart 506 | 507 | if numpy.mod(uidx, self.dispFreq) == 0: 508 | print("epoch %d, samples %d, loss %f, accuracy %f BatchTime %f, for discriminator pretraining " % (epoch, uidx * self.gpu_num * self.batch_size, loss_out, accuracy_out, BatchTime)) 509 | 510 | if numpy.mod(uidx, self.saveFreq) == 0: 511 | print('save params when epoch %d, samples %d' %(epoch, uidx * self.gpu_num * self.batch_size)) 512 | self.saver.save(self.sess, self.saveto) 513 | 514 | if numpy.mod(uidx, self.devFreq) == 0: 515 | print('doing nothing on the evaluation sets') 516 | 517 | -------------------------------------------------------------------------------- /configs/config_discriminator_pretrain.yaml: -------------------------------------------------------------------------------- 1 | --- 2 | src_vocab: 3 | dst_vocab: 4 | src_vocab_size_a: 32000 5 | dst_vocab_size_a: 32000 6 | src_vocab_size_b: 32000 7 | dst_vocab_size_b: 32000 8 | hidden_units: 512 9 | scale_embedding: True 10 | attention_dropout_rate: 0.0 11 | residual_dropout_rate: 0.1 12 | num_blocks: 6 13 | num_heads: 8 14 | binding_embedding: False 15 | kl_weight: 0.0 16 | enc_layer_indep: 4 17 | enc_layer_share: 1 18 | dec_layer_indep: 4 19 | dec_layer_share: 1 20 | generate_maxlen: 50 21 | lock_enc_embed: True 22 | multi_channel_encoder: True 23 | 24 | train: 25 | logdir: './dis_log' 26 | dis_src_vocab: 27 | dis_dst_vocab: 28 | dis_max_epoches: 2 29 | dis_dispFreq: 1 30 | dis_saveFreq: 100 31 | dis_devFreq: 100 32 | dis_batch_size: 100 33 | dis_saveto: '../experience/unSupervisedNMT/ai_challenger/dis_pretrain/dis_mono' 34 | dis_reshuffle: True 35 | dis_gpu_device: 'gpu-0-1-2-3-4-5-6-7' 36 | dis_max_len: 50 37 | s_domain_data: 38 | t_domain_data: 39 | s_domain_generated_data: 40 | t_domain_generated_data: 41 | dev_s_domain_data: 42 | dev_t_domain_data: 43 | dev_s_domain_generated_data: 44 | dev_t_domain_generated_data: 45 | dis_reload: True 46 | dis_clip_c: 1.0 47 | dis_dim_word: 512 48 | dis_optimizer: 'rmsprop'
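# NOTE: dis_optimizer must be 'adam', 'adadelta', 'sgd' or 'rmsprop'; any other
# value makes text_DisCNN raise a ValueError (see cnn_text_discriminator.py).
# dis_gpu_device is split on '-' with the leading 'gpu' dropped, so
# 'gpu-0-1-2-3-4-5-6-7' selects GPU ids 0 through 7.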
49 | dis_scope: 'discnn' 50 | -------------------------------------------------------------------------------- /configs/config_gan_train.yaml: -------------------------------------------------------------------------------- 1 | --- 2 | src_vocab: '/home/user/zy/dl4mt/corpus/lf_50_for_unsupervised_NMT/data_for_gan_training/vocab.bpe.32000' 3 | dst_vocab: '/home/user/zy/dl4mt/corpus/lf_50_for_unsupervised_NMT/data_for_gan_training/vocab.bpe.32000' 4 | src_vocab_size_a: 32000 5 | dst_vocab_size_a: 32000 6 | src_vocab_size_b: 32000 7 | dst_vocab_size_b: 32000 8 | hidden_units: 512 9 | scale_embedding: True 10 | attention_dropout_rate: 0.0 11 | residual_dropout_rate: 0.1 12 | num_blocks: 6 13 | num_heads: 8 14 | binding_embedding: False 15 | kl_weight: 0.0 16 | enc_layer_indep: 4 17 | enc_layer_share: 1 18 | dec_layer_indep: 4 19 | dec_layer_share: 1 20 | generate_maxlen: 50 21 | lock_enc_embed: True 22 | multi_channel_encoder: True 23 | gan_iter_num: 8888 24 | gan_gen_iter_num: 1 25 | gan_dis_iter_num: 1 26 | generate_num: 5000 27 | rollnum: 20 28 | bias_num: 0.6 29 | logdir: './gan.log' 30 | 31 | train: 32 | devices: '0' 33 | src_path: 34 | dst_path: 35 | tokens_per_batch: 25000 36 | max_length: 50 37 | num_epochs: 500 38 | logdir: 'gan.log' 39 | save_freq: 1000 40 | summary_freq: 100 41 | grads_clip: 5 42 | optimizer: 'adam_decay' 43 | learning_rate: 0.00005 44 | gan_learning_rate: 0.00005 45 | learning_rate_warmup_steps: 4000 46 | label_smoothing: 0.1 47 | batch_size: 250 48 | shared_embedding: False 49 | 50 | generator: 51 | src_vocab: 52 | dst_vocab: 53 | devices: '0,1' 54 | src_path: 55 | dst_path: 56 | tokens_per_batch: 25000 57 | max_length: 50 58 | num_epochs: 500 59 | logdir: 'gan.log' 60 | save_freq: 1000 61 | summary_freq: 100 62 | grads_clip: 5 63 | optimizer: 'rmsprop' 64 | modelFile: 65 | learning_rate: 0.00001 66 | gan_learning_rate: 0.00005 67 | learning_rate_warmup_steps: 4000 68 | label_smoothing: 0.1 69 | batch_size: 100 70 | 71 | discriminator: 72 | dis_src_vocab: 73 | dis_dst_vocab: 74 | dis_max_epoches: 888 75 | dis_dispFreq: 1 76 | dis_saveFreq: 100 77 | dis_devFreq: 100 78 | dis_batch_size: 40 79 | dis_saveto: '../experience/unSupervisedNMT/ai_challenger/gan_training/dis_mono_first' 80 | dis_saveto_trg: '../experience/unSupervisedNMT/ai_challenger/gan_training/dis_mono_trg' 81 | dis_reshuffle: True 82 | dis_gpu_devices: 'gpu-2-3-4-5' 83 | dis_max_len: 50 84 | s_domain_data: 85 | t_domain_data: 86 | s_domain_generated_data: 87 | t_domain_generated_data: 88 | dev_s_domain_data: 89 | dev_t_domain_data: 90 | dev_s_domain_generated_data: 91 | dev_t_domain_generated_data: 92 | dis_reload: True 93 | dis_clip_c: 1.0 94 | dis_dim_word: 512 95 | dis_optimizer: 'rmsprop' 96 | dis_scope: 'discnn_src' 97 | dis_scope_trg: 'discnn_trg' 98 | 99 | test: 100 | mode: 'ab' 101 | src_path: './bleuTest/newstest2016/newstest2016.tok.bpe.32000.en' 102 | dst_path: './bleuTest/newstest2016/newstest2016.tok.bpe.32000.de' 103 | ori_dst_path: './bleuTest/newstest2016/newstest2016.tok.de' 104 | output_path: './bleuTest/newstest2016/newstest2016.ab.output' 105 | batch_size: 256 106 | max_target_length: 200 107 | beam_size: 4 108 | lp_alpha: 0.6 109 | devices: '0,1' 110 | modelFile: 111 | -------------------------------------------------------------------------------- /configs/config_generate_sample.yaml: -------------------------------------------------------------------------------- 1 | --- 2 | src_vocab: 3 | dst_vocab: 4 | src_vocab_size_a: 32000 5 | dst_vocab_size_a: 32000 6 | src_vocab_size_b: 
32000 7 | dst_vocab_size_b: 32000 8 | hidden_units: 512 9 | scale_embedding: True 10 | attention_dropout_rate: 0.0 11 | residual_dropout_rate: 0.1 12 | num_blocks: 6 13 | num_heads: 8 14 | binding_embedding: False 15 | kl_weight: 0.0 16 | enc_layer_indep: 4 17 | enc_layer_share: 1 18 | dec_layer_indep: 4 19 | dec_layer_share: 1 20 | generate_maxlen: 50 21 | lock_enc_embed: True 22 | multi_channel_encoder: True 23 | 24 | train: 25 | devices: '0,1,2,3,4,5,6,7' 26 | src_path: 27 | dst_path: 28 | s_domain_generated_data: 29 | t_domain_generated_data: 30 | tokens_per_batch: 5000 31 | max_length: 50 32 | num_epochs: 1 33 | logdir: './generate_samples.log' 34 | save_freq: 1000 35 | summary_freq: 100 36 | grads_clip: 5.0 37 | optimizer: 'adam_decay' 38 | learning_rate: 0.00001 39 | gan_learning_rate: 0.00001 40 | learning_rate_warmup_steps: 10000 41 | warm_steps: 1000 42 | label_smoothing: 0.1 43 | batch_size: 256 44 | shared_embedding: False 45 | shuffle_k: 5 46 | modelFile: '../experience/unSupervisedNMT/ai_challenger/ende-s4-i0_multiChannel_divide/model_epoch_1_step_307000' 47 | -------------------------------------------------------------------------------- /configs/config_generator_pretrain_BPE.yaml: -------------------------------------------------------------------------------- 1 | --- 2 | src_vocab: 3 | dst_vocab: 4 | src_gen: 5 | dst_gen: 6 | src_vocab_size_a: 39799 7 | dst_vocab_size_a: 39799 8 | src_vocab_size_b: 17278 9 | dst_vocab_size_b: 17278 10 | src_vocab_size: 0 11 | dst_vocab_size: 0 12 | hidden_units: 300 13 | scale_embedding: True 14 | attention_dropout_rate: 0.0 15 | residual_dropout_rate: 0.2 16 | num_blocks: 6 17 | num_heads: 6 18 | binding_embedding: False 19 | kl_weight: 0.00002 20 | enc_layer_indep: 4 21 | enc_layer_share: 1 22 | dec_layer_indep: 4 23 | dec_layer_share: 1 24 | generate_maxlen: 50 25 | lock_enc_embed: False 26 | vari_emb_scale: 0.0 27 | multi_channel_encoder: True 28 | 29 | train: 30 | devices: '2' 31 | src_path: 32 | dst_path: 33 | src_path_1: 34 | dst_path_1: 35 | tokens_per_batch: 5000 36 | max_length: 80 37 | num_epochs: 15 38 | logdir: 'out/48' 39 | save_freq: 500 40 | summary_freq: 100 41 | disp_freq: 100 42 | grads_clip: 5 43 | optimizer: 'adam_decay' 44 | learning_rate: 0.00001 45 | gan_learning_rate: 0.00001 46 | learning_rate_warmup_steps: 10000 47 | warm_steps: 1000 48 | label_smoothing: 0.1 49 | batch_size: 64 50 | shared_embedding: False 51 | shuffle_k: 5 52 | src_pretrain_wordemb_path: 53 | dst_pretrain_wordemb_path: 54 | restore_partial: False 55 | src_more: True 56 | restore_embed: True 57 | restore_decoder: False 58 | restore_decode_embed: True 59 | decoder_logdir: 'out/22' 60 | 61 | test: 62 | mode: 'ab' 63 | src_path: 64 | dst_path: 65 | ori_dst_path: 66 | output_path: 67 | batch_size: 32 68 | max_target_length: 200 69 | beam_size: 4 70 | lp_alpha: 0.6 71 | devices: '2' 72 | -------------------------------------------------------------------------------- /configs/config_generator_train.yaml: -------------------------------------------------------------------------------- 1 | --- 2 | src_vocab: 3 | dst_vocab: 4 | src_gen: 5 | dst_gen: 6 | src_vocab_size_a: 75567 7 | dst_vocab_size_a: 75567 8 | src_vocab_size_b: 11074 9 | dst_vocab_size_b: 11074 10 | src_vocab_size: 75567 11 | dst_vocab_size: 11074 12 | hidden_units: 512 13 | scale_embedding: True 14 | attention_dropout_rate: 0.0 15 | residual_dropout_rate: 0.2 16 | num_blocks: 6 17 | num_heads: 8 18 | binding_embedding: False 19 | kl_weight: 0.00002 20 | enc_layer_indep: 4 21 | 
enc_layer_share: 1 22 | dec_layer_indep: 4 23 | dec_layer_share: 1 24 | generate_maxlen: 50 25 | lock_enc_embed: False 26 | vari_emb_scale: 0.01 27 | multi_channel_encoder: True 28 | 29 | train: 30 | devices: '2' 31 | src_path: 32 | dst_path: 33 | src_path_1: 34 | dst_path_1: 35 | tokens_per_batch: 5000 36 | max_length: 80 37 | num_epochs: 10 38 | logdir: 'out/05' 39 | save_freq: 50 40 | summary_freq: 10 41 | disp_freq: 10 42 | grads_clip: 5 43 | optimizer: 'adam_decay' 44 | learning_rate: 0.00001 45 | gan_learning_rate: 0.00001 46 | learning_rate_warmup_steps: 10000 47 | warm_steps: 1000 48 | label_smoothing: 0.1 49 | batch_size: 64 50 | shared_embedding: False 51 | shuffle_k: 5 52 | src_pretrain_wordemb_path: './data/embeddings/vectors-en.txt' 53 | dst_pretrain_wordemb_path: './data/embeddings/vectors-hi.txt' 54 | test: 55 | mode: 'bb' 56 | src_path: 57 | dst_path: 58 | ori_dst_path: 59 | output_path: 60 | batch_size: 32 61 | max_target_length: 200 62 | beam_size: 4 63 | lp_alpha: 0.6 64 | devices: '2' 65 | -------------------------------------------------------------------------------- /configs/config_generator_train_BPE.yaml: -------------------------------------------------------------------------------- 1 | --- 2 | src_vocab: 3 | dst_vocab: 4 | src_gen: 5 | dst_gen: 6 | src_vocab_size_a: 39799 7 | dst_vocab_size_a: 39799 8 | src_vocab_size_b: 17278 9 | dst_vocab_size_b: 17278 10 | src_vocab_size: 0 11 | dst_vocab_size: 0 12 | hidden_units: 500 13 | scale_embedding: True 14 | attention_dropout_rate: 0.0 15 | residual_dropout_rate: 0.2 16 | num_blocks: 6 17 | num_heads: 10 18 | binding_embedding: False 19 | kl_weight: 0.00002 20 | enc_layer_indep: 4 21 | enc_layer_share: 1 22 | dec_layer_indep: 4 23 | dec_layer_share: 1 24 | generate_maxlen: 50 25 | lock_enc_embed: False 26 | vari_emb_scale: 0.00 27 | multi_channel_encoder: True 28 | train_ratio: 0 29 | 30 | train: 31 | devices: '2' 32 | src_path: 33 | dst_path: 34 | src_path_1: 35 | dst_path_1: 36 | tokens_per_batch: 5000 37 | max_length: 80 38 | num_epochs: 15 39 | logdir: 'out/48' 40 | save_freq: 100 41 | summary_freq: 10 42 | disp_freq: 10 43 | grads_clip: 5 44 | optimizer: 'adam_decay' 45 | learning_rate: 0.00001 46 | gan_learning_rate: 0.00001 47 | learning_rate_warmup_steps: 10000 48 | warm_steps: 1000 49 | label_smoothing: 0.1 50 | batch_size: 64 51 | shared_embedding: False 52 | shuffle_k: 5 53 | src_pretrain_wordemb_path: 54 | dst_pretrain_wordemb_path: 55 | src_more: True 56 | restore_partial: False 57 | restore_embed: True 58 | restore_decoder: False 59 | restore_decoder_embed: False 60 | decoder_logdir: 'out/32' 61 | 62 | test: 63 | mode: 'bb' 64 | src_path: 65 | dst_path: 66 | ori_dst_path: 67 | output_path: 68 | batch_size: 32 69 | max_target_length: 200 70 | beam_size: 4 71 | lp_alpha: 0.6 72 | devices: '2' 73 | -------------------------------------------------------------------------------- /configs/config_text_discriminator_pretrain.yaml: -------------------------------------------------------------------------------- 1 | --- 2 | src_vocab: 3 | dst_vocab: 4 | src_vocab_size_a: 32000 5 | dst_vocab_size_a: 32000 6 | src_vocab_size_b: 32000 7 | dst_vocab_size_b: 32000 8 | hidden_units: 512 9 | scale_embedding: True 10 | attention_dropout_rate: 0.0 11 | residual_dropout_rate: 0.1 12 | num_blocks: 6 13 | num_heads: 8 14 | binding_embedding: False 15 | kl_weight: 0.0 16 | enc_layer_indep: 4 17 | enc_layer_share: 1 18 | dec_layer_indep: 4 19 | dec_layer_share: 1 20 | generate_maxlen: 50 21 | lock_enc_embed: True 22 | 
multi_channel_encoder: True 23 | 24 | train: 25 | logdir: './dis_log' 26 | dis_src_vocab: 27 | dis_max_epoches: 2 28 | dis_dispFreq: 1 29 | dis_saveFreq: 100 30 | dis_devFreq: 100 31 | dis_batch_size: 100 32 | dis_saveto: '../experience/unSupervisedNMT/ai_challenger/dis_pretrain/dis_mono_trg' 33 | dis_reshuffle: True 34 | devices: 'gpu-0-1-2-3-4-5-6-7' 35 | dis_max_len: 50 36 | s_domain_data: 37 | s_domain_generated_data: 38 | dev_s_domain_data: 39 | dev_s_domain_generated_data: 40 | dis_reload: True 41 | dis_clip_c: 1.0 42 | dis_dim_word: 512 43 | dis_optimizer: 'rmsprop' 44 | text_scope: 'discnn_trg' 45 | -------------------------------------------------------------------------------- /data/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vishwajeet93/clqg/75f7311983f48f164fd371641c9abaabbba9ef3d/data/.DS_Store -------------------------------------------------------------------------------- /data_iterator.py: -------------------------------------------------------------------------------- 1 | import cPickle as pkl 2 | import gzip 3 | import numpy 4 | 5 | 6 | def fopen(filename, mode='r'): 7 | if filename.endswith('.gz'): 8 | return gzip.open(filename, mode) 9 | return open(filename, mode) 10 | 11 | class unSuperGanTextIterator: 12 | def __init__(self, s_domain_data, t_domain_data, s_domain_generated_data, t_domain_generated_data, dic_s, dic_t, batch=1, maxlen=50, n_words_source=-1, n_words_target=-1): 13 | self.s_domain_data = fopen(s_domain_data, 'r') 14 | self.s_domain_generated_data = fopen(s_domain_generated_data, 'r') 15 | self.t_domain_data = fopen(t_domain_data, 'r') 16 | self.t_domain_generated_data = fopen(t_domain_generated_data, 'r') 17 | 18 | with open(dic_t) as f_trg, open(dic_s) as s_trg: 19 | self.dic_target = pkl.load(f_trg) 20 | self.dic_source = pkl.load(s_trg) 21 | 22 | self.batch_size = batch 23 | assert self.batch_size % 2 == 0 24 | self.maxlen = maxlen 25 | self.n_words_trg = n_words_target 26 | self.n_words_src = n_words_source 27 | self.end_of_data = False 28 | 29 | def __iter__(self): 30 | return self 31 | 32 | def reset(self): 33 | self.s_domain_data.seek(0) 34 | self.s_domain_generated_data.seek(0) 35 | self.t_domain_data.seek(0) 36 | self.t_domain_generated_data.seek(0) 37 | 38 | def next(self): 39 | if self.end_of_data: 40 | self.end_of_data = False 41 | self.reset() 42 | raise StopIteration 43 | 44 | x = [] 45 | y = [] 46 | 47 | try: 48 | while True: 49 | ss = self.s_domain_data.readline() 50 | if ss == "": 51 | raise IOError 52 | ss = ss.strip().split() 53 | ss = [self.dic_source[w] if w in self.dic_source else 1 for w in ss] 54 | if self.n_words_src > 0: 55 | ss = [w if w < self.n_words_src else 1 for w in ss] 56 | 57 | tt = self.t_domain_data.readline() 58 | if tt == "": 59 | raise IOError 60 | tt = tt.strip().split() 61 | tt = [self.dic_target[w] if w in self.dic_target else 1 for w in tt] 62 | if self.n_words_trg > 0: 63 | tt = [w if w < self.n_words_trg else 1 for w in tt] 64 | 65 | 66 | sg = self.s_domain_generated_data.readline() 67 | if sg == "": 68 | raise IOError 69 | sg = sg.strip().split() 70 | sg = [self.dic_source[w] if w in self.dic_source else 1 for w in sg] 71 | if self.n_words_src > 0: 72 | sg = [w if w < self.n_words_src else 1 for w in sg] 73 | 74 | tg = self.t_domain_generated_data.readline() 75 | if tg == "": 76 | raise IOError 77 | tg = tg.strip().split() 78 | tg = [self.dic_target[w] if w in self.dic_target else 1 for w in tg] 79 | if self.n_words_trg > 0: 
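# word ids at or above the vocabulary cut-off are mapped to 1 (UNK), as for ss, tt and sg above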
80 | tg = [w if w < self.n_words_trg else 1 for w in tg] 81 | 82 | if len(ss) > self.maxlen or len(tt) >self.maxlen or len(sg) > self.maxlen or len(tg) > self.maxlen: 83 | continue 84 | 85 | x.append(ss) 86 | y.append([1,0,0]) 87 | 88 | x.append(tt) 89 | y.append([0,1,0]) 90 | 91 | x.append(sg) 92 | y.append([0,0,1]) 93 | x.append(tg) 94 | y.append([0,0,1]) 95 | 96 | if len(x) >= self.batch_size and len(y) >= self.batch_size: 97 | shuffle_indices = numpy.random.permutation(numpy.arange(len(x))) 98 | x_np = numpy.array(x) 99 | y_np = numpy.array(y) 100 | 101 | x_np_shuffled = x_np[shuffle_indices] 102 | y_np_shuffled = y_np[shuffle_indices] 103 | 104 | x_shuffled = x_np_shuffled.tolist() 105 | y_shuffled = y_np_shuffled.tolist() 106 | 107 | break 108 | except IOError: 109 | self.end_of_data = True 110 | 111 | if len(x) <=0 or len(y) <=0: 112 | self.end_of_data = False 113 | self.reset() 114 | raise StopIteration 115 | 116 | if len(x) >=self.batch_size: 117 | return x_shuffled[:self.batch_size], y_shuffled[:self.batch_size] 118 | else: 119 | return x, y 120 | 121 | class disThreeTextIterator: 122 | def __init__(self, positive_data, negative_data, source_data, dic_target, dic_source, batch=1, maxlen=50, n_words_target=-1, n_words_source=-1): 123 | self.positive = fopen(positive_data, 'r') 124 | self.negative = fopen(negative_data, 'r') 125 | self.source = fopen(source_data, 'r') 126 | 127 | with open(dic_target) as f_trg: 128 | self.dic_target = pkl.load(f_trg) 129 | with open(dic_source) as s_trg: 130 | self.dic_source = pkl.load(s_trg) 131 | 132 | self.batch_size = batch 133 | assert self.batch_size % 2 == 0 134 | self.maxlen = maxlen 135 | self.n_words_trg = n_words_target 136 | self.n_words_src = n_words_source 137 | self.end_of_data = False 138 | 139 | def __iter__(self): 140 | return self 141 | 142 | def reset(self): 143 | self.positive.seek(0) 144 | self.negative.seek(0) 145 | self.source.seek(0) 146 | def next(self): 147 | if self.end_of_data: 148 | self.end_of_data = False 149 | self.reset() 150 | raise StopIteration 151 | 152 | positive = [] 153 | negative = [] 154 | source = [] 155 | x = [] 156 | xs = [] 157 | y = [] 158 | 159 | try: 160 | while True: 161 | ss = self.positive.readline() 162 | if ss == "": 163 | raise IOError 164 | ss = ss.strip().split() 165 | ss = [self.dic_target[w] if w in self.dic_target else 1 for w in ss] 166 | if self.n_words_trg > 0: 167 | ss = [w if w < self.n_words_trg else 1 for w in ss] 168 | 169 | tt = self.negative.readline() 170 | if tt == "": 171 | raise IOError 172 | tt = tt.strip().split() 173 | tt = [self.dic_target[w] if w in self.dic_target else 1 for w in tt] 174 | if self.n_words_trg > 0: 175 | tt = [w if w < self.n_words_trg else 1 for w in tt] 176 | 177 | ll = self.source.readline() 178 | if ll == "": 179 | raise IOError 180 | ll = ll.strip().split() 181 | ll = [self.dic_source[w] if w in self.dic_source else 1 for w in ll] 182 | if self.n_words_src > 0: 183 | ll = [w if w < self.n_words_src else 1 for w in ll] 184 | 185 | if len(ss) > self.maxlen or len(tt) >self.maxlen or len(ll) > self.maxlen: 186 | continue 187 | 188 | positive.append(ss) 189 | negative.append(tt) 190 | source.append(ll) 191 | 192 | x = positive + negative 193 | 194 | positive_labels = [[0, 1] for _ in positive] 195 | negative_labels = [[1, 0] for _ in negative] 196 | y = positive_labels + negative_labels 197 | 198 | xs = source + source 199 | 200 | shuffle_indices = numpy.random.permutation(numpy.arange(len(x))) 201 | x_np = numpy.array(x) 202 | y_np = numpy.array(y) 
203 | xs_np =numpy.array(xs) 204 | 205 | x_np_shuffled = x_np[shuffle_indices] 206 | y_np_shuffled = y_np[shuffle_indices] 207 | xs_np_shuffled = xs_np[shuffle_indices] 208 | 209 | x_shuffled = x_np_shuffled.tolist() 210 | y_shuffled = y_np_shuffled.tolist() 211 | xs_shuffled =xs_np_shuffled.tolist() 212 | 213 | if len(x_shuffled) >= self.batch_size and len(y_shuffled) >= self.batch_size and len(xs_shuffled) >=self.batch_size: 214 | break 215 | except IOError: 216 | self.end_of_data = True 217 | 218 | if len(positive) <=0 or len(negative) <=0: 219 | self.end_of_data = False 220 | self.reset() 221 | raise StopIteration 222 | 223 | return x_shuffled, y_shuffled, xs_shuffled 224 | 225 | class disTextIterator: 226 | def __init__(self, positive_data, negative_data, dis_dict, batch=1, maxlen=30, n_words_target=-1): 227 | self.positive = fopen(positive_data, 'r') 228 | self.negative = fopen(negative_data, 'r') 229 | with open(dis_dict) as f: 230 | self.dis_dict = pkl.load(f) 231 | 232 | self.batch_size = batch 233 | assert self.batch_size % 2 == 0, 'the batch size of disTextIterator is not an even number' 234 | 235 | self.maxlen = maxlen 236 | self.n_words_target = n_words_target 237 | self.end_of_data = False 238 | 239 | def __iter__(self): 240 | return self 241 | 242 | def reset(self): 243 | self.positive.seek(0) 244 | self.negative.seek(0) 245 | 246 | def next(self): 247 | if self.end_of_data: 248 | self.end_of_data = False 249 | self.reset() 250 | raise StopIteration 251 | 252 | positive = [] 253 | negative = [] 254 | x = [] 255 | y = [] 256 | try: 257 | while True: 258 | ss = self.positive.readline() 259 | if ss == "": 260 | raise IOError 261 | ss = ss.strip().split() 262 | ss = [self.dis_dict[w] if w in self.dis_dict else 1 for w in ss] 263 | if self.n_words_target > 0: 264 | ss = [w if w < self.n_words_target else 1 for w in ss] 265 | 266 | 267 | tt = self.negative.readline() 268 | if tt == "": 269 | raise IOError 270 | tt = tt.strip().split() 271 | tt = [self.dis_dict[w] if w in self.dis_dict else 1 for w in tt] 272 | if self.n_words_target > 0: 273 | tt = [w if w < self.n_words_target else 1 for w in tt] 274 | 275 | if len(ss) > self.maxlen or len(tt) > self.maxlen: 276 | continue 277 | 278 | positive.append(ss) 279 | negative.append(tt) 280 | x = positive + negative 281 | positive_labels = [[0, 1] for _ in positive] 282 | negative_labels = [[1, 0] for _ in negative] 283 | y = positive_labels + negative_labels 284 | shuffle_indices = numpy.random.permutation(numpy.arange(len(x))) 285 | x_np = numpy.array(x) 286 | y_np = numpy.array(y) 287 | x_np_shuffled = x_np[shuffle_indices] 288 | y_np_shuffled = y_np[shuffle_indices] 289 | 290 | x_shuffled = x_np_shuffled.tolist() 291 | y_shuffled = y_np_shuffled.tolist() 292 | 293 | if len(x_shuffled) >= self.batch_size and len(y_shuffled) >= self.batch_size: 294 | break 295 | 296 | except IOError: 297 | self.end_of_data = True 298 | 299 | if len(positive) <= 0 or len(negative) <= 0: 300 | self.end_of_data = False 301 | self.reset() 302 | raise StopIteration 303 | 304 | return x_shuffled, y_shuffled 305 | 306 | 307 | class genTextIterator: 308 | def __init__(self, train_data, source_dict, batch_size=1, maxlen=30, n_words_source=-1): 309 | self.source = fopen(train_data, 'r') 310 | 311 | with open(source_dict, 'rb') as f: 312 | self.source_dict = pkl.load(f) 313 | 314 | self.batch_size = batch_size 315 | self.maxlen = maxlen 316 | 317 | self.n_words_source = n_words_source 318 | self.end_of_data = False 319 | 320 | def __iter__(self): 321 | return 
self 322 | 323 | def reset(self): 324 | self.source.seek(0) 325 | 326 | def next(self): 327 | if self.end_of_data: 328 | self.end_of_data= False 329 | self.reset() 330 | raise StopIteration 331 | 332 | source = [] 333 | try: 334 | while True: 335 | ss = self.source.readline() 336 | if ss == "": 337 | raise IOError 338 | ss = ss.strip().split() 339 | ss = [self.source_dict[w] if w in self.source_dict else 1 for w in ss] 340 | if self.n_words_source > 0: 341 | ss = [w if w < self.n_words_source else 1 for w in ss] 342 | 343 | if len(ss) > self.maxlen: 344 | continue 345 | 346 | source.append(ss) 347 | 348 | if len(source) >= self.batch_size: 349 | break 350 | except IOError: 351 | self.end_of_data=True 352 | 353 | if len(source)<=0: 354 | self.end_of_data = False 355 | self.reset() 356 | raise StopIteration 357 | 358 | return source 359 | 360 | class TextIterator: 361 | """Simple Bitext iterator.""" 362 | def __init__(self, source, target, 363 | source_dict, target_dict, 364 | batch_size=128, 365 | maxlen=100, 366 | n_words_source=-1, 367 | n_words_target=-1): 368 | self.source = fopen(source, 'r') 369 | self.target = fopen(target, 'r') 370 | with open(source_dict, 'rb') as f: 371 | self.source_dict = pkl.load(f) 372 | with open(target_dict, 'rb') as f: 373 | self.target_dict = pkl.load(f) 374 | 375 | self.batch_size = batch_size 376 | self.maxlen = maxlen 377 | 378 | self.n_words_source = n_words_source 379 | self.n_words_target = n_words_target 380 | 381 | self.end_of_data = False 382 | 383 | def __iter__(self): 384 | return self 385 | 386 | def reset(self): 387 | self.source.seek(0) 388 | self.target.seek(0) 389 | 390 | def next(self): 391 | if self.end_of_data: 392 | self.end_of_data = False 393 | self.reset() 394 | raise StopIteration 395 | 396 | source = [] 397 | target = [] 398 | 399 | try: 400 | 401 | # actual work here 402 | while True: 403 | 404 | # read from source file and map to word index 405 | ss = self.source.readline() 406 | if ss == "": 407 | raise IOError 408 | ss = ss.strip().split() 409 | ss = [self.source_dict[w] if w in self.source_dict else 1 410 | for w in ss] 411 | if self.n_words_source > 0: 412 | ss = [w if w < self.n_words_source else 1 for w in ss] 413 | 414 | # read from target file and map to word index 415 | tt = self.target.readline() 416 | if tt == "": 417 | raise IOError 418 | tt = tt.strip().split() 419 | tt = [self.target_dict[w] if w in self.target_dict else 1 420 | for w in tt] 421 | if self.n_words_target > 0: 422 | tt = [w if w < self.n_words_target else 1 for w in tt] 423 | 424 | if len(ss) > self.maxlen and len(tt) > self.maxlen: 425 | continue 426 | 427 | source.append(ss) 428 | target.append(tt) 429 | 430 | if len(source) >= self.batch_size or \ 431 | len(target) >= self.batch_size: 432 | break 433 | except IOError: 434 | self.end_of_data = True 435 | 436 | if len(source) <= 0 or len(target) <= 0: 437 | self.end_of_data = False 438 | self.reset() 439 | raise StopIteration 440 | 441 | return source, target 442 | -------------------------------------------------------------------------------- /evaluate.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | import codecs 3 | import os 4 | #import path 5 | import tensorflow as tf 6 | import numpy as np 7 | import yaml 8 | import time 9 | import logging 10 | from tempfile import mkstemp 11 | from argparse import ArgumentParser 12 | 13 | from model import Model, INT_TYPE 14 | from utils import DataUtil, AttrDict 15 | 16 | 17 | class Evaluator(object): 18 | """ 19 | Evaluate the model.
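Decoding restores the latest checkpoint from train.logdir and runs batched beam search; candidate scores are divided by the GNMT length penalty lp = ((5 + length) / 6) ** lp_alpha before the best hypothesis is picked (see beam_search below). For example, with lp_alpha = 0.6 a 19-token hypothesis is divided by ((5 + 19) / 6) ** 0.6 = 4 ** 0.6, roughly 2.30.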
20 | """ 21 | def __init__(self, config): 22 | self.config = config 23 | 24 | # Load model 25 | self.model = Model(config) 26 | # self.model.build_test_model() 27 | self.model.build_variational_test_model(mode=config.test.mode) 28 | logging.info('build_test_variational_model done!') 29 | self.du = DataUtil(config) 30 | self.du.load_vocab(src_vocab=config.src_vocab, 31 | dst_vocab=config.dst_vocab, 32 | src_vocab_size=config.src_vocab_size_a, 33 | dst_vocab_size=config.src_vocab_size_b) 34 | 35 | # Create session 36 | sess_config = tf.ConfigProto() 37 | sess_config.gpu_options.allow_growth = True 38 | sess_config.allow_soft_placement = True 39 | self.sess = tf.Session(config=sess_config, graph=self.model.graph) 40 | # Restore model. 41 | with self.model.graph.as_default(): 42 | saver = tf.train.Saver(tf.global_variables()) 43 | self.saver=saver 44 | saver.restore(self.sess, tf.train.latest_checkpoint(config.train.logdir)) 45 | 46 | def __del__(self): 47 | self.sess.close() 48 | 49 | def greedy_search(self, X): 50 | """ 51 | Greedy search. 52 | Args: 53 | X: A 2-d array with size [n, src_length], source sentence indices. 54 | 55 | Returns: 56 | A 2-d array with size [n, dst_length], destination sentence indices. 57 | """ 58 | encoder_output = self.sess.run(self.model.encoder_output, feed_dict={self.model.src_pl: X}) 59 | preds = np.ones([X.shape[0], 1], dtype=INT_TYPE) * 2 # 60 | finish = np.zeros(X.shape[0:1], dtype=np.bool) 61 | for i in xrange(config.test.max_target_length): 62 | last_preds = self.sess.run(self.model.preds, feed_dict={self.model.encoder_output: encoder_output, 63 | self.model.decoder_input: preds}) 64 | finish += last_preds == 3 # 65 | if finish.all(): 66 | break 67 | preds = np.concatenate((preds, last_preds[:, None]), axis=1) 68 | 69 | return preds[:, 1:] 70 | 71 | def beam_search(self, X): 72 | """ 73 | Beam search with batch inputs. 74 | Args: 75 | X: A 2-d array with size [n, src_length], source sentence indices. 76 | 77 | Returns: 78 | A 2-d array with size [n, dst_length], target sentence indices. 79 | """ 80 | 81 | beam_size, batch_size = self.config.test.beam_size, X.shape[0] 82 | inf = 1e10 83 | 84 | def get_bias_scores(scores, bias): 85 | """ 86 | If a sequence is finished, we only allow one alive branch. This function aims to give one branch a zero score 87 | and the rest -inf score. 88 | Args: 89 | scores: A real value array with shape [batch_size * beam_size, beam_size]. 90 | bias: A bool array with shape [batch_size * beam_size]. 91 | 92 | Returns: 93 | A real value array with shape [batch_size * beam_size, beam_size]. 94 | """ 95 | b = np.array([0.0] + [-inf] * (beam_size - 1)) 96 | b = np.repeat(b[None,:], batch_size * beam_size, axis=0) # [batch * beam_size, beam_size] 97 | return scores * (1 - bias[:, None]) + b * bias[:, None] 98 | 99 | def get_bias_preds(preds, bias): 100 | """ 101 | If a sequence is finished, all of its branch should be (3). 102 | Args: 103 | preds: A int array with shape [batch_size * beam_size, beam_size]. 104 | bias: A bool array with shape [batch_size * beam_size]. 105 | 106 | Returns: 107 | A int array with shape [batch_size * beam_size]. 108 | """ 109 | return preds * (1 - bias[:, None]) + bias[:, None] * 3 110 | 111 | # Get encoder outputs. 112 | encoder_output = self.sess.run(self.model.encoder_output, feed_dict={self.model.src_pl: X}) 113 | # Prepare beam search inputs. 
114 | encoder_output = np.repeat(encoder_output, beam_size, axis=0) # shape: [batch_size * beam_size, hidden_units] 115 | preds = np.ones([batch_size * beam_size, 1], dtype=INT_TYPE) * 2 # [[, , ..., ]], shape: [batch_size * beam_size, 1] 116 | scores = np.array(([0.0] + [-inf] * (beam_size - 1)) * batch_size) # [0, -inf, -inf ,..., 0, -inf, -inf, ...], shape: [batch_size * beam_size] 117 | for i in xrange(self.config.test.max_target_length): 118 | # Whether sequences finished. 119 | bias = np.equal(preds[:, -1], 3) # ? 120 | # If all sequences finished, break the loop. 121 | if bias.all(): 122 | break 123 | 124 | # Expand the nodes. 125 | last_k_preds, last_k_scores = \ 126 | self.sess.run([self.model.k_preds, self.model.k_scores], 127 | feed_dict={self.model.encoder_output: encoder_output, 128 | self.model.decoder_input: preds}) # [batch_size * beam_size, beam_size] 129 | 130 | last_k_preds = get_bias_preds(last_k_preds, bias) 131 | last_k_scores = get_bias_scores(last_k_scores, bias) 132 | # Shrink the search range. 133 | scores = scores[:, None] + last_k_scores # [batch_size * beam_size, beam_size] 134 | scores = scores.reshape([batch_size, beam_size**2]) # [batch_size, beam_size * beam_size] 135 | 136 | # Reserve beam_size nodes. 137 | k_indices = np.argsort(scores)[:, -beam_size:] # [batch_size, beam_size] 138 | k_indices = np.repeat(np.array(range(0, batch_size)), beam_size) * beam_size**2 + k_indices.flatten() # [batch_size * beam_size] 139 | scores = scores.flatten()[k_indices] # [batch_size * beam_size] 140 | last_k_preds = last_k_preds.flatten()[k_indices] 141 | preds = preds[k_indices // beam_size] 142 | preds = np.concatenate((preds, last_k_preds[:, None]), axis=1) # [batch_size * beam_size, i] 143 | 144 | scores = scores.reshape([batch_size, beam_size]) 145 | preds = preds.reshape([batch_size, beam_size, -1]) # [batch_size, beam_size, max_length] 146 | lengths = np.sum(np.not_equal(preds, 3), axis=-1) # [batch_size, beam_size] 147 | lp = ((5 + lengths) / (5 + 1)) ** self.config.test.lp_alpha # Length penalty 148 | scores /= lp # following GNMT. 149 | max_indices = np.argmax(scores, axis=-1) # [batch_size] 150 | max_indices += np.array(range(batch_size)) * beam_size 151 | preds = preds.reshape([batch_size * beam_size, -1]) 152 | logging.debug(scores.flatten()[max_indices]) 153 | return preds[max_indices][:, 1:] 154 | 155 | def loss(self, X, Y): 156 | return self.sess.run(self.model.loss_sum, feed_dict={self.model.src_pl: X, self.model.dst_pl: Y}) 157 | 158 | def translate(self, src_o='src', trg_o='dst'): 159 | logging.info('Translate %s.' % self.config.test.src_path) 160 | _, tmp = mkstemp() 161 | fd = codecs.open(tmp, 'w', 'utf8') 162 | count = 0 163 | start = time.time() 164 | 165 | for X in self.du.get_test_batches(o=src_o): 166 | Y = self.beam_search(X) 167 | sents = self.du.indices_to_words(Y, o=trg_o) 168 | for sent in sents: 169 | print(sent, file=fd) 170 | #print(sent) 171 | count += len(X) 172 | logging.info('%d sentences processed in %.2f minutes.' % (count, (time.time()-start) / 60)) 173 | fd.close() 174 | # Remove BPE flag, if have. 175 | os.system("sed -r 's/(@@ )|(@@ ?$)//g' %s > %s" % (tmp, self.config.test.output_path)) 176 | logging.info('The result file was saved in %s.' % self.config.test.output_path) 177 | 178 | def ppl(self, src_o='src', trg_o='trg'): 179 | if 'dst_path' not in self.config.test: 180 | logging.warning("Skip PPL calculation due to missing of parameter 'dst_path' in config file.") 181 | return 182 | logging.info('Calculate PPL for %s and %s.' 
% (self.config.test.src_path, self.config.test.dst_path)) 183 | token_count = 0 184 | loss_sum = 0 185 | for batch in self.du.get_test_batches_with_target(src_o=src_o,trg_o=trg_o): 186 | X, Y = batch 187 | loss_sum += self.loss(X, Y) 188 | token_count += np.sum(np.greater(Y, 0)) 189 | # Compute PPL 190 | logging.info('PPL: %.4f' % np.exp(loss_sum / token_count)) 191 | 192 | def evaluate(self): 193 | if self.model.mode == 'aa': 194 | src_o = trg_o = 'src' 195 | elif self.model.mode == 'ab': 196 | src_o = 'src' 197 | trg_o = 'dst' 198 | elif self.model.mode == 'bb': 199 | src_o = trg_o = 'dst' 200 | elif self.model.mode == 'ba': 201 | src_o = 'dst' 202 | trg_o = 'src' 203 | else: 204 | raise Exception('mode Error!') 205 | 206 | self.translate(src_o=src_o, trg_o=trg_o) 207 | if 'eval_script' in self.config.test: 208 | script_path = self.config.test.eval_script 209 | else: 210 | script_path = 'multi-bleu.perl' 211 | script_interpreter = script_path.rsplit('.', 1)[1] 212 | script_dir = os.path.dirname(script_path) or '.' 213 | os.chdir(script_dir) 214 | # Call a script to evaluate. 215 | os.system("%s %s %s < %s" % (script_interpreter, script_path, self.config.test.ori_dst_path, 216 | self.config.test.output_path)) 217 | self.ppl(src_o=src_o,trg_o=trg_o) 218 | 219 | def evaluate_translate_only(self): 220 | if self.model.mode == 'aa': 221 | src_o = trg_o = 'src' 222 | elif self.model.mode == 'ab': 223 | src_o = 'src' 224 | trg_o = 'dst' 225 | elif self.model.mode == 'bb': 226 | src_o = trg_o = 'dst' 227 | elif self.model.mode == 'ba': 228 | src_o = 'dst' 229 | trg_o = 'src' 230 | else: 231 | raise Exception('mode Error!') 232 | 233 | self.translate(src_o=src_o, trg_o=trg_o) 234 | 235 | if __name__ == '__main__': 236 | parser = ArgumentParser() 237 | parser.add_argument('-c', '--config', dest='config') 238 | args = parser.parse_args() 239 | # Read config 240 | config = AttrDict(yaml.load(open(args.config))) 241 | # Logger 242 | logging.basicConfig(level=logging.INFO) 243 | evaluator = Evaluator(config) 244 | evaluator.evaluate() 245 | logging.info("Done") 246 | -------------------------------------------------------------------------------- /evaluate.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | export CUDA_VISIBLE_DEVICES='2' 3 | 4 | python evaluate.py -c ./configs/config_generator_train.yaml 5 | -------------------------------------------------------------------------------- /gan_train.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | import yaml 3 | import time 4 | import os 5 | import sys 6 | import numpy as np 7 | import logging 8 | from argparse import ArgumentParser 9 | import tensorflow as tf 10 | 11 | from utils import DataUtil, AttrDict 12 | from model import Model 13 | from cnn_text_discriminator import text_DisCNN 14 | from share_function import deal_generated_samples 15 | from share_function import deal_generated_samples_to_maxlen 16 | from share_function import extend_sentence_to_maxlen 17 | from share_function import prepare_gan_dis_data 18 | from share_function import FlushFile 19 | 20 | def gan_train(config): 21 | sess_config = tf.ConfigProto() 22 | sess_config.gpu_options.allow_growth = True 23 | sess_config.allow_soft_placement = True 24 | 25 | default_graph=tf.Graph() 26 | with default_graph.as_default(): 27 | sess = tf.Session(config=sess_config, graph=default_graph) 28 | 29 | logger = logging.getLogger('') 30 | du = 
DataUtil(config=config) 31 | du.load_vocab(src_vocab=config.generator.src_vocab, 32 | dst_vocab=config.generator.dst_vocab, 33 | src_vocab_size=config.src_vocab_size_a, 34 | dst_vocab_size=config.src_vocab_size_b) 35 | 36 | generator = Model(config=config, graph=default_graph, sess=sess) 37 | generator.build_variational_train_model() 38 | 39 | generator.init_and_restore(modelFile=config.generator.modelFile) 40 | 41 | dis_filter_sizes = [i for i in range(1, config.discriminator.dis_max_len, 4)] 42 | dis_num_filters = [(100 + i * 10) for i in range(1, config.discriminator.dis_max_len, 4)] 43 | 44 | discriminator_src = text_DisCNN( 45 | sess=sess, 46 | max_len=config.discriminator.dis_max_len, 47 | num_classes=2, 48 | vocab_size_s=config.dst_vocab_size_a, 49 | batch_size=config.discriminator.dis_batch_size, 50 | dim_word=config.discriminator.dis_dim_word, 51 | filter_sizes=dis_filter_sizes, 52 | num_filters=dis_num_filters, 53 | source_dict=config.discriminator.dis_src_vocab, 54 | gpu_device=config.discriminator.dis_gpu_devices, 55 | s_domain_data=config.discriminator.s_domain_data, 56 | s_domain_generated_data=config.discriminator.s_domain_generated_data, 57 | dev_s_domain_data=config.discriminator.dev_s_domain_data, 58 | dev_s_domain_generated_data=config.discriminator.dev_s_domain_generated_data, 59 | max_epoches=config.discriminator.dis_max_epoches, 60 | dispFreq=config.discriminator.dis_dispFreq, 61 | saveFreq=config.discriminator.dis_saveFreq, 62 | saveto=config.discriminator.dis_saveto, 63 | reload=config.discriminator.dis_reload, 64 | clip_c=config.discriminator.dis_clip_c, 65 | optimizer=config.discriminator.dis_optimizer, 66 | reshuffle=config.discriminator.dis_reshuffle, 67 | scope=config.discriminator.dis_scope 68 | ) 69 | 70 | discriminator_trg = text_DisCNN( 71 | sess=sess, 72 | max_len=config.discriminator.dis_max_len, 73 | num_classes=2, 74 | vocab_size_s=config.dst_vocab_size_b, 75 | batch_size=config.discriminator.dis_batch_size, 76 | dim_word=config.discriminator.dis_dim_word, 77 | filter_sizes=dis_filter_sizes, 78 | num_filters=dis_num_filters, 79 | source_dict=config.discriminator.dis_dst_vocab, 80 | gpu_device=config.discriminator.dis_gpu_devices, 81 | s_domain_data=config.discriminator.t_domain_data, 82 | s_domain_generated_data=config.discriminator.t_domain_generated_data, 83 | dev_s_domain_data=config.discriminator.dev_t_domain_data, 84 | dev_s_domain_generated_data=config.discriminator.dev_t_domain_generated_data, 85 | max_epoches=config.discriminator.dis_max_epoches, 86 | dispFreq=config.discriminator.dis_dispFreq, 87 | saveFreq=config.discriminator.dis_saveFreq, 88 | saveto=config.discriminator.dis_saveto_trg, 89 | reload=config.discriminator.dis_reload, 90 | clip_c=config.discriminator.dis_clip_c, 91 | optimizer=config.discriminator.dis_optimizer, 92 | reshuffle=config.discriminator.dis_reshuffle, 93 | scope=config.discriminator.dis_scope_trg 94 | ) 95 | 96 | batch_iter = du.get_training_batches( 97 | set_train_src_path=config.generator.src_path, 98 | set_train_dst_path=config.generator.dst_path, 99 | set_batch_size=config.generator.batch_size, 100 | set_max_length=config.generator.max_length 101 | ) 102 | 103 | for epoch in range(1, config.gan_iter_num + 1): 104 | for gen_iter in range(config.gan_gen_iter_num): 105 | batch = next(batch_iter) 106 | x, y = batch[0], batch[1] 107 | generate_ab, generate_ba = generator.generate_step(x, y) 108 | 109 | logging.info("generate the samples") 110 | generate_ab_dealed, generate_ab_mask = deal_generated_samples(generate_ab, 
du.dst2idx) 111 | generate_ba_dealed, generate_ba_mask = deal_generated_samples(generate_ba, du.src2idx) 112 | # 113 | #### for debug 114 | ##print('the sample is ') 115 | ##sample_str=du.indices_to_words(y_sample_dealed, 'dst') 116 | ##print(sample_str) 117 | # 118 | 119 | x_to_maxlen = extend_sentence_to_maxlen(x) 120 | y_to_maxlen = extend_sentence_to_maxlen(y) 121 | 122 | logging.info("calculate the rewards") 123 | rewards_ab = generator.get_reward(x=x, 124 | x_to_maxlen=x_to_maxlen, 125 | y_sample=generate_ab_dealed, 126 | y_sample_mask=generate_ab_mask, 127 | rollnum=config.rollnum, 128 | disc=discriminator_trg, 129 | max_len=config.discriminator.dis_max_len, 130 | bias_num=config.bias_num, 131 | data_util=du, 132 | direction='ab') 133 | 134 | rewards_ba = generator.get_reward(x=y, 135 | x_to_maxlen=y_to_maxlen, 136 | y_sample=generate_ba_dealed, 137 | y_sample_mask=generate_ba_mask, 138 | rollnum=config.rollnum, 139 | disc=discriminator_src, 140 | max_len=config.discriminator.dis_max_len, 141 | bias_num=config.bias_num, 142 | data_util=du, 143 | direction='ba') 144 | 145 | 146 | loss_ab = generator.generate_step_and_update(x, generate_ab_dealed, rewards_ab) 147 | 148 | loss_ba = generator.generate_step_and_update(y, generate_ba_dealed, rewards_ba) 149 | 150 | print("the rewards for ab and ba are ", rewards_ab, rewards_ba) 151 | print("the losses for ab and ba are ", loss_ab, loss_ba) 152 | 153 | logging.info("save the model into %s" % config.generator.modelFile) 154 | generator.saver.save(generator.sess, config.generator.modelFile) 155 | 156 | 157 | #### modified to here, next starts from here 158 | 159 | logging.info("preparing the gan_dis_data begins") 160 | data_num = prepare_gan_dis_data( 161 | train_data_source=config.generator.src_path, 162 | train_data_target=config.generator.dst_path, 163 | gan_dis_source_data=config.discriminator.s_domain_data, 164 | gan_dis_positive_data=config.discriminator.t_domain_data, 165 | num=config.generate_num, 166 | reshuf=True 167 | ) 168 | 169 | logging.info("generate and save the t_domain_generated_data into %s." % config.discriminator.t_domain_generated_data) 170 | 171 | generator.generate_and_save(data_util=du, 172 | infile=config.discriminator.s_domain_data, 173 | generate_batch=config.discriminator.dis_batch_size, 174 | outfile=config.discriminator.t_domain_generated_data, 175 | direction='ab' 176 | ) 177 | 178 | logging.info("generate and save the s_domain_generated_data into %s." % config.discriminator.s_domain_generated_data) 179 | 180 | generator.generate_and_save(data_util=du, 181 | infile=config.discriminator.t_domain_data, 182 | generate_batch=config.discriminator.dis_batch_size, 183 | outfile=config.discriminator.s_domain_generated_data, 184 | direction='ba' 185 | ) 186 | 187 | logging.info("prepared %d gan_dis_data examples!"
% data_num) 188 | logging.info("finetuning the discriminator begins") 189 | 190 | discriminator_src.train(max_epoch=config.gan_dis_iter_num, 191 | s_domain_data=config.discriminator.s_domain_data, 192 | s_domain_generated_data=config.discriminator.s_domain_generated_data, 193 | ) 194 | discriminator_src.saver.save(discriminator_src.sess, discriminator_src.saveto) 195 | 196 | discriminator_trg.train(max_epoch=config.gan_dis_iter_num, 197 | s_domain_data=config.discriminator.t_domain_data, 198 | s_domain_generated_data=config.discriminator.t_domain_generated_data, 199 | ) 200 | discriminator_trg.saver.save(discriminator_trg.sess, discriminator_trg.saveto) 201 | logging.info("finetuning the discriminator done!") 202 | 203 | logging.info('reinforcement training done!') 204 | 205 | if __name__ == '__main__': 206 | sys.stdout = FlushFile(sys.stdout) 207 | parser = ArgumentParser() 208 | parser.add_argument('-c', '--config', dest='config') 209 | args = parser.parse_args() 210 | # Read config 211 | config = AttrDict(yaml.load(open(args.config))) 212 | # Logger 213 | if not os.path.exists(config.logdir): 214 | os.makedirs(config.logdir) 215 | logging.basicConfig(filename=config.logdir+'/train.log', level=logging.INFO) 216 | console = logging.StreamHandler() 217 | console.setLevel(logging.DEBUG) 218 | logging.getLogger('').addHandler(console) 219 | # Train 220 | gan_train(config) 221 | 222 | -------------------------------------------------------------------------------- /gan_train.sh: -------------------------------------------------------------------------------- 1 | export CUDA_VISIBLE_DEVICES='0,1,2,3,4,5' 2 | 3 | python gan_train.py -c ./configs/config_gan_train.yaml -------------------------------------------------------------------------------- /generate_samples.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | import yaml 3 | import time 4 | import os 5 | import logging 6 | from argparse import ArgumentParser 7 | import tensorflow as tf 8 | 9 | from utils import DataUtil, AttrDict 10 | from model import Model 11 | from cnn_discriminator import DisCNN 12 | from share_function import deal_generated_samples 13 | from share_function import extend_sentence_to_maxlen 14 | 15 | def generate_samples(config): 16 | sess_config = tf.ConfigProto() 17 | sess_config.gpu_options.allow_growth = True 18 | sess_config.allow_soft_placement = True 19 | 20 | default_graph = tf.Graph() 21 | with default_graph.as_default(): 22 | sess = tf.Session(config=sess_config, graph=default_graph) 23 | 24 | logger = logging.getLogger('') 25 | du = DataUtil(config=config) 26 | du.load_vocab(src_vocab=config.src_vocab, 27 | dst_vocab=config.dst_vocab, 28 | src_vocab_size=config.src_vocab_size_a, 29 | dst_vocab_size=config.dst_vocab_size_b) 30 | 31 | generator = Model(config=config, graph=default_graph, sess=sess) 32 | generator.build_variational_train_model() 33 | 34 | generator.init_and_restore(config.train.modelFile) 35 | 36 | print("begin generating the data and saving the negative samples") 37 | generator.generate_and_save(du, config.train.src_path, config.train.batch_size, config.train.t_domain_generated_data, direction='ab') 38 | generator.generate_and_save(du, config.train.dst_path, config.train.batch_size, config.train.s_domain_generated_data, direction='ba') 39 | print("generating the data done!") 40 | 41 | 42 | if __name__ == '__main__': 43 | parser = ArgumentParser() 44 | parser.add_argument('-c', '--config', dest='config') 45 | args =
parser.parse_args() 46 | # Read config 47 | config = AttrDict(yaml.load(open(args.config))) 48 | # Logger 49 | if not os.path.exists(config.train.logdir): 50 | os.makedirs(config.train.logdir) 51 | logging.basicConfig(filename=config.train.logdir+'/train.log', level=logging.DEBUG) 52 | console = logging.StreamHandler() 53 | console.setLevel(logging.INFO) 54 | logging.getLogger('').addHandler(console) 55 | 56 | # Train 57 | generate_samples(config) 58 | -------------------------------------------------------------------------------- /multi-bleu.perl: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env perl 2 | # 3 | # This file is part of moses. Its use is licensed under the GNU Lesser General 4 | # Public License version 2.1 or, at your option, any later version. 5 | 6 | # $Id$ 7 | use warnings; 8 | use strict; 9 | 10 | my $lowercase = 0; 11 | if ($ARGV[0] eq "-lc") { 12 | $lowercase = 1; 13 | shift; 14 | } 15 | 16 | my $stem = $ARGV[0]; 17 | if (!defined $stem) { 18 | print STDERR "usage: multi-bleu.pl [-lc] reference < hypothesis\n"; 19 | print STDERR "Reads the references from reference or reference0, reference1, ...\n"; 20 | exit(1); 21 | } 22 | 23 | $stem .= ".ref" if !-e $stem && !-e $stem."0" && -e $stem.".ref0"; 24 | 25 | my @REF; 26 | my $ref=0; 27 | while(-e "$stem$ref") { 28 | &add_to_ref("$stem$ref",\@REF); 29 | $ref++; 30 | } 31 | &add_to_ref($stem,\@REF) if -e $stem; 32 | die("ERROR: could not find reference file $stem") unless scalar @REF; 33 | 34 | sub add_to_ref { 35 | my ($file,$REF) = @_; 36 | my $s=0; 37 | open(REF,$file) or die "Can't read $file"; 38 | while(<REF>) { 39 | chop; 40 | push @{$$REF[$s++]}, $_; 41 | } 42 | close(REF); 43 | } 44 | 45 | my(@CORRECT,@TOTAL,$length_translation,$length_reference); 46 | my $s=0; 47 | while(<STDIN>) { 48 | chop; 49 | $_ = lc if $lowercase; 50 | my @WORD = split; 51 | my %REF_NGRAM = (); 52 | my $length_translation_this_sentence = scalar(@WORD); 53 | my ($closest_diff,$closest_length) = (9999,9999); 54 | foreach my $reference (@{$REF[$s]}) { 55 | # print "$s $_ <=> $reference\n"; 56 | $reference = lc($reference) if $lowercase; 57 | my @WORD = split(' ',$reference); 58 | my $length = scalar(@WORD); 59 | my $diff = abs($length_translation_this_sentence-$length); 60 | if ($diff < $closest_diff) { 61 | $closest_diff = $diff; 62 | $closest_length = $length; 63 | # print STDERR "$s: closest diff ".abs($length_translation_this_sentence-$length)." = abs($length_translation_this_sentence-$length), setting len: $closest_length\n"; 64 | } elsif ($diff == $closest_diff) { 65 | $closest_length = $length if $length < $closest_length; 66 | # from two references with the same closeness to me 67 | # take the *shorter* into account, not the "first" one. 68 | } 69 | for(my $n=1;$n<=4;$n++) { 70 | my %REF_NGRAM_N = (); 71 | for(my $start=0;$start<=$#WORD-($n-1);$start++) { 72 | my $ngram = "$n"; 73 | for(my $w=0;$w<$n;$w++) { 74 | $ngram .= " ".$WORD[$start+$w]; 75 | } 76 | $REF_NGRAM_N{$ngram}++; 77 | } 78 | foreach my $ngram (keys %REF_NGRAM_N) { 79 | if (!defined($REF_NGRAM{$ngram}) || 80 | $REF_NGRAM{$ngram} < $REF_NGRAM_N{$ngram}) { 81 | $REF_NGRAM{$ngram} = $REF_NGRAM_N{$ngram}; 82 | # print "$i: REF_NGRAM{$ngram} = $REF_NGRAM{$ngram}<BR>
\n"; 83 | } 84 | } 85 | } 86 | } 87 | $length_translation += $length_translation_this_sentence; 88 | $length_reference += $closest_length; 89 | for(my $n=1;$n<=4;$n++) { 90 | my %T_NGRAM = (); 91 | for(my $start=0;$start<=$#WORD-($n-1);$start++) { 92 | my $ngram = "$n"; 93 | for(my $w=0;$w<$n;$w++) { 94 | $ngram .= " ".$WORD[$start+$w]; 95 | } 96 | $T_NGRAM{$ngram}++; 97 | } 98 | foreach my $ngram (keys %T_NGRAM) { 99 | $ngram =~ /^(\d+) /; 100 | my $n = $1; 101 | # my $corr = 0; 102 | # print "$i e $ngram $T_NGRAM{$ngram}
\n"; 103 | $TOTAL[$n] += $T_NGRAM{$ngram}; 104 | if (defined($REF_NGRAM{$ngram})) { 105 | if ($REF_NGRAM{$ngram} >= $T_NGRAM{$ngram}) { 106 | $CORRECT[$n] += $T_NGRAM{$ngram}; 107 | # $corr = $T_NGRAM{$ngram}; 108 | # print "$i e correct1 $T_NGRAM{$ngram}
\n"; 109 | } 110 | else { 111 | $CORRECT[$n] += $REF_NGRAM{$ngram}; 112 | # $corr = $REF_NGRAM{$ngram}; 113 | # print "$i e correct2 $REF_NGRAM{$ngram}
\n"; 114 | } 115 | } 116 | # $REF_NGRAM{$ngram} = 0 if !defined $REF_NGRAM{$ngram}; 117 | # print STDERR "$ngram: {$s, $REF_NGRAM{$ngram}, $T_NGRAM{$ngram}, $corr}\n" 118 | } 119 | } 120 | $s++; 121 | } 122 | my $brevity_penalty = 1; 123 | my $bleu = 0; 124 | 125 | my @bleu=(); 126 | 127 | for(my $n=1;$n<=4;$n++) { 128 | if (defined ($TOTAL[$n])){ 129 | $bleu[$n]=($TOTAL[$n])?$CORRECT[$n]/$TOTAL[$n]:0; 130 | # print STDERR "CORRECT[$n]:$CORRECT[$n] TOTAL[$n]:$TOTAL[$n]\n"; 131 | }else{ 132 | $bleu[$n]=0; 133 | } 134 | } 135 | 136 | if ($length_reference==0){ 137 | printf "BLEU = 0, 0/0/0/0 (BP=0, ratio=0, hyp_len=0, ref_len=0)\n"; 138 | exit(1); 139 | } 140 | 141 | if ($length_translation<$length_reference) { 142 | $brevity_penalty = exp(1-$length_reference/$length_translation); 143 | } 144 | $bleu = $brevity_penalty * exp((my_log( $bleu[1] ) + 145 | my_log( $bleu[2] ) + 146 | my_log( $bleu[3] ) + 147 | my_log( $bleu[4] ) ) / 4) ; 148 | printf "BLEU = %.2f, %.1f/%.1f/%.1f/%.1f (BP=%.3f, ratio=%.3f, hyp_len=%d, ref_len=%d)\n", 149 | 100*$bleu, 150 | 100*$bleu[1], 151 | 100*$bleu[2], 152 | 100*$bleu[3], 153 | 100*$bleu[4], 154 | $brevity_penalty, 155 | $length_translation / $length_reference, 156 | $length_translation, 157 | $length_reference; 158 | 159 | sub my_log { 160 | return -9999999999 unless $_[0]; 161 | return log($_[0]); 162 | } 163 | -------------------------------------------------------------------------------- /runDev.py: -------------------------------------------------------------------------------- 1 | from evaluate import Evaluator 2 | import yaml 3 | import time 4 | import numpy as np 5 | import os 6 | import sys 7 | import logging 8 | from share_function import FlushFile 9 | from utils import DataUtil, AttrDict 10 | 11 | evaluate_config=sys.argv[1] 12 | config=AttrDict(yaml.load(open(evaluate_config))) 13 | evaluator=Evaluator(config) 14 | 15 | 16 | idx=0 17 | bleu=0.10 18 | logFile='DevResult' 19 | 20 | sys.stdout=FlushFile(sys.stdout) 21 | logging.basicConfig(level=logging.INFO, 22 | filename=logFile, 23 | filemode='a') 24 | 25 | if not os.path.isdir('./model_best'): 26 | os.system('mkdir model_best') 27 | else: 28 | logging.info("model_best dir exists") 29 | 30 | while True: 31 | idx +=1 32 | logging.info('idx: '+str(idx)) 33 | evaluator.saver.restore(evaluator.sess, config.test.modelFile) 34 | evaluator.evaluate_translate_only() 35 | command_line ="./multi-bleu.perl "+config.test.ori_dst_path+" < "+config.test.output_path 36 | fd = os.popen(command_line) 37 | output=fd.read() 38 | print(output) 39 | try: 40 | BLEUStrIdx = output.index('BLEU = ') 41 | bleu_new = float(output[BLEUStrIdx+7:BLEUStrIdx+12]) 42 | if bleu_new > bleu: 43 | bleu=bleu_new 44 | command_line="cp "+config.test.modelFile+".*"+ " ./model_best/" 45 | os.popen(command_line) 46 | logging.info('best update in idex '+str(idx)) 47 | 48 | logging.info('bleu: '+str(bleu_new)+'\n') 49 | except: 50 | logging.info('BLEU score not found!') 51 | 52 | fd.close() 53 | time.sleep(5000) 54 | 55 | 56 | -------------------------------------------------------------------------------- /run_gan_dev.sh: -------------------------------------------------------------------------------- 1 | export CUDA_VISIBLE_DEVICES='6,7' 2 | python runDev.py ./configs/config_gan_train.yaml 3 | -------------------------------------------------------------------------------- /share_function.py: -------------------------------------------------------------------------------- 1 | from __future__ import division 2 | 3 | import tensorflow as tf 
4 | import numpy 5 | import time 6 | import os 7 | from operator import mul 8 | 9 | from data_iterator import disTextIterator 10 | from data_iterator import genTextIterator 11 | from data_iterator import TextIterator 12 | 13 | from collections import defaultdict 14 | from math import exp 15 | 16 | 17 | def prepare_gan_dis_data(train_data_source, train_data_target, gan_dis_source_data, gan_dis_positive_data, 18 | num=None, reshuf=True): 19 | 20 | source = open(train_data_source, 'r') 21 | sourceLists = source.readlines() 22 | 23 | if num is None or num > len(sourceLists): 24 | num = len(sourceLists) 25 | 26 | if reshuf: 27 | os.popen('python shuffle.py ' +train_data_source+' '+train_data_target) 28 | os.popen('head -n ' + str(num) +' '+ train_data_source+'.shuf'+' >'+gan_dis_source_data) 29 | os.popen('head -n ' + str(num) +' '+ train_data_target+'.shuf'+' >'+gan_dis_positive_data) 30 | else: 31 | os.popen('head -n ' + str(num) +' '+ train_data_source + ' >'+gan_dis_source_data) 32 | os.popen('head -n ' + str(num) +' '+ train_data_target + ' >'+gan_dis_positive_data) 33 | 34 | os.popen('rm '+train_data_source+'.shuf') 35 | os.popen('rm '+train_data_target+'.shuf') 36 | return num 37 | 38 | def prepare_three_gan_dis_dev_data(gan_dis_positive_data, gan_dis_negative_data, gan_dis_source_data, dev_dis_positive_data, dev_dis_negative_data, dev_dis_source_data, num): 39 | gan_dis = open(gan_dis_positive_data, 'r') 40 | disLists = gan_dis.readlines() 41 | 42 | if num is None or num > len(disLists): 43 | num = len(disLists) 44 | 45 | os.popen('head -n '+ str(num) +' '+gan_dis_positive_data+' >'+dev_dis_positive_data) 46 | os.popen('head -n '+ str(num) +' '+gan_dis_negative_data+' >'+dev_dis_negative_data) 47 | os.popen('head -n '+ str(num) +' '+gan_dis_source_data+' >'+dev_dis_source_data) 48 | 49 | return num 50 | 51 | def prepare_gan_dis_dev_data(gan_dis_positive_data, gan_dis_negative_data, dev_dis_positive_data, dev_dis_negative_data, num): 52 | 53 | gan_dis = open(gan_dis_positive_data, 'r') 54 | disLists = gan_dis.readlines() 55 | 56 | if num is None or num > len(disLists): 57 | num = len(disLists) 58 | 59 | os.popen('head -n '+ str(num) +' '+gan_dis_positive_data+' >'+dev_dis_positive_data) 60 | os.popen('head -n '+ str(num) +' '+gan_dis_negative_data+' >'+dev_dis_negative_data) 61 | 62 | return num 63 | 64 | def print_string(src_or_trg, indexs, worddicts_r): 65 | sample_str = '' 66 | for index in indexs: 67 | if index > 0: 68 | if src_or_trg == 'y': 69 | word_str = worddicts_r[1][index] 70 | else: 71 | word_str = worddicts_r[0][index] 72 | sample_str = sample_str + word_str + ' ' 73 | return sample_str 74 | 75 | class FlushFile: 76 | """ 77 | A wrapper for File, allowing users to see results immediately.
78 | """ 79 | def __init__(self, f): 80 | self.f = f 81 | 82 | def write(self, x): 83 | self.f.write(x) 84 | self.f.flush() 85 | 86 | def _p(pp, name): 87 | return '%s_%s' % (pp, name) 88 | 89 | def dis_train_iter(dis_positive_data, dis_negative_data, reshuffle, dictionary, n_words_trg, batch_size, maxlen): 90 | iter = 0 91 | while True: 92 | if reshuffle: 93 | os.popen('python shuffle.py '+dis_positive_data+' '+dis_positive_data) 94 | os.popen('mv ' + dis_negative_data + '.shuf ' + dis_negtive_data) 95 | os.popen('mv ' + dis_negative_data + '.shuf ' + dis_negative_data) 96 | disTrain = disTextIterator(dis_positive_data, dis_negative_data, dictionary, batch_size, maxlen, n_words_trg) 97 | iter +=1 98 | ExampleNum = 0 99 | iterStart = time.time() 100 | for x, y in disTrain: 101 | ExampleNum += len(x) 102 | yield x, y, iter 103 | TimeCost = time.time() - EpochStart 104 | print('Seen ', ExampleNum, ' examples for discriminator. Time cost : ', TimeCost) 105 | 106 | 107 | def gen_train_iter(gen_file, reshuffle, dictionary, n_words, batch_size, maxlen): 108 | iter = 0 109 | while True: 110 | if reshuffle: 111 | os.popen('python shuffle.py '+ gen_file) 112 | os.popen('mv '+ gen_file +'.shuf ' + gen_file) 113 | gen_train = genTextIterator(gen_file, dictionary, n_words_source = n_words, batch_size = batch_size, maxlen=maxlen) 114 | ExampleNum = 0 115 | EpochStart = time.time() 116 | for x in gen_train: 117 | if len(x) < batch_size: 118 | continue 119 | ExampleNum +=len(x) 120 | yield x, iter 121 | TimeCost = time.time() - EpochStart 122 | iter +=1 123 | print('Seen ', ExampleNum, 'generator samples. Time cost is ', TimeCost) 124 | 125 | def gen_force_train_iter(source_data, target_data, reshuffle, source_dict, target_dict, batch_size, maxlen, n_words_src, n_words_trg): 126 | iter = 0 127 | while True: 128 | if reshuffle: 129 | os.popen('python shuffle.py '+ source_data + ' ' + target_data) 130 | os.popen('mv '+ source_data + '.shuf ' + source_data) 131 | os.popen('mv '+ target_data + '.shuf ' + target_data) 132 | gen_force_train = TextIterator(source_data, target_data, source_dict, target_dict, batch_size, maxlen, n_words_src, n_words_trg) 133 | ExampleNum = 0 134 | EpochStart = time.time() 135 | for x, y in gen_force_train: 136 | if len(x) < batch_size and len(y) < batch_size: 137 | continue 138 | ExampleNum += len(x) 139 | yield x, y, iter 140 | TimeCost = time.time() - EpochStart 141 | iter +=1 142 | print('Seen', ExampleNum, 'generator samples. 
Time cost is ', TimeCost) 143 | 144 | def prepare_data(seqs_x, seqs_y, maxlen=None, n_words_src=30000, 145 | n_words=30000, precision='float32'): 146 | # x: a list of sentences 147 | lengths_x = [len(s) for s in seqs_x] 148 | lengths_y = [len(s) for s in seqs_y] 149 | 150 | if maxlen is not None: 151 | new_seqs_x = [] 152 | new_seqs_y = [] 153 | new_lengths_x = [] 154 | new_lengths_y = [] 155 | for l_x, s_x, l_y, s_y in zip(lengths_x, seqs_x, lengths_y, seqs_y): 156 | if l_x < maxlen and l_y < maxlen: 157 | new_seqs_x.append(s_x) 158 | new_lengths_x.append(l_x) 159 | new_seqs_y.append(s_y) 160 | new_lengths_y.append(l_y) 161 | lengths_x = new_lengths_x 162 | seqs_x = new_seqs_x 163 | lengths_y = new_lengths_y 164 | seqs_y = new_seqs_y 165 | 166 | if len(lengths_x) < 1 or len(lengths_y) < 1: 167 | return None, None, None, None 168 | 169 | n_samples = len(seqs_x) 170 | maxlen_x = numpy.max(lengths_x) + 1 171 | maxlen_y = numpy.max(lengths_y) + 1 172 | 173 | x = numpy.zeros((maxlen_x, n_samples)).astype('int32') 174 | y = numpy.zeros((maxlen_y, n_samples)).astype('int32') 175 | x_mask = numpy.zeros((maxlen_x, n_samples)).astype(precision) 176 | y_mask = numpy.zeros((maxlen_y, n_samples)).astype(precision) 177 | for idx, [s_x, s_y] in enumerate(zip(seqs_x, seqs_y)): 178 | x[:lengths_x[idx], idx] = s_x 179 | x_mask[:lengths_x[idx]+1, idx] = 1. 180 | y[:lengths_y[idx], idx] = s_y 181 | y_mask[:lengths_y[idx]+1, idx] = 1. 182 | 183 | return x, x_mask, y, y_mask 184 | 185 | 186 | def dis_three_length_prepare(seqs_x, seqs_y, seqs_xs, maxlen=50): 187 | n_samples = len(seqs_x) 188 | x = numpy.zeros((maxlen, n_samples)).astype('int32') 189 | y = numpy.zeros((2, n_samples)).astype('int32') 190 | xs = numpy.zeros((maxlen, n_samples)).astype('int32') 191 | 192 | for idx, [s_x, s_y, s_xs] in enumerate(zip(seqs_x, seqs_y, seqs_xs)): 193 | x[:len(s_x), idx] = s_x 194 | y[:len(s_y), idx] = s_y 195 | xs[:len(s_xs), idx] = s_xs 196 | return x, y, xs 197 | 198 | def dis_length_prepare(seqs_x, seqs_y, num_classes=2, maxlen=50): 199 | n_samples = len(seqs_x) 200 | x = numpy.zeros((maxlen, n_samples)).astype('int32') 201 | y = numpy.zeros((num_classes, n_samples)).astype('int32') 202 | 203 | for idx, [s_x, s_y] in enumerate(zip(seqs_x, seqs_y)): 204 | x[:len(s_x), idx] = s_x 205 | y[:len(s_y), idx] = s_y 206 | return x, y 207 | 208 | def prepare_single_sentence(seqs_x, maxlen=50): 209 | n_samples = len(seqs_x) 210 | lens_x = [len(seq) for seq in seqs_x] 211 | maxlen_x = numpy.max(lens_x) + 1 212 | 213 | x = numpy.zeros((maxlen_x, n_samples)).astype('int32') 214 | for idx, s_x in enumerate(seqs_x): 215 | x[:len(s_x), idx] = s_x 216 | return x 217 | 218 | def prepare_multiple_sentence(seqs_x, maxlen=50, precision='float32'): 219 | n_samples = len(seqs_x) 220 | lens_x = [len(seq) for seq in seqs_x] 221 | maxlen_x = numpy.max(lens_x) + 1 222 | 223 | x = numpy.zeros((maxlen_x, n_samples)).astype('int32') 224 | x_mask = numpy.zeros((maxlen_x, n_samples)).astype(precision) 225 | 226 | for idx, s_x in enumerate(seqs_x): 227 | x[:len(s_x), idx] = s_x 228 | x_mask[:len(s_x), idx] = 1. 
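# x is time-major with shape (maxlen_x, n_samples); x_mask marks real token positions with 1.0 and padding with 0.0, so downstream losses can be masked before reduction.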
229 | 230 | return x, x_mask 231 | 232 | def prepare_sentence_to_maxlen(seqs_x, maxlen=50, precision='float32'): 233 | n_samples = len(seqs_x) 234 | x = numpy.zeros((maxlen, n_samples)).astype('int32') 235 | 236 | for idx, s_x in enumerate(seqs_x): 237 | x[:len(s_x), idx]=s_x 238 | return x 239 | 240 | def extend_sentence_to_maxlen(seqs, maxlen = 50): 241 | n_samples = len(seqs) 242 | x=numpy.zeros((n_samples, maxlen)).astype('int32') 243 | for idx, seq in enumerate(seqs): 244 | x[idx, :len(seq)]=seq 245 | return x 246 | 247 | 248 | def deal_generated_y_sentence(seqs_y, worddicts, precision='float32'): 249 | n_samples = len(seqs_y) 250 | lens_y = [len(seq) for seq in seqs_y] 251 | maxlen_y = numpy.max(lens_y) 252 | eosTag = '</S>' 253 | eosIndex = worddicts[1][eosTag] 254 | 255 | y = numpy.zeros((maxlen_y, n_samples)).astype('int32') 256 | y_mask = numpy.zeros((maxlen_y, n_samples)).astype(precision) 257 | 258 | for idy, s_y in enumerate(seqs_y): 259 | try: 260 | firstIndex = s_y.tolist().index(eosIndex)+1 261 | except ValueError: 262 | firstIndex = maxlen_y - 1 263 | 264 | y[:firstIndex, idy]=s_y[:firstIndex] 265 | y_mask[:firstIndex, idy]=1. 266 | 267 | return y, y_mask 268 | 269 | def deal_generated_samples(y_sample, dicts): 270 | 271 | eosTag='
</S>' 272 | eosIndex = dicts.get(eosTag) 273 | #print("eosIndex is", eosIndex) 274 | n_samples = len(y_sample) 275 | lens_y = [len(y) for y in y_sample] 276 | maxlen_y = numpy.max(lens_y) 277 | 278 | y = numpy.zeros((n_samples, maxlen_y)).astype('int32') 279 | y_mask = numpy.zeros((n_samples, maxlen_y)).astype('float32') 280 | 281 | for idy, s_y in enumerate(y_sample): 282 | try: 283 | firstIndex = s_y.tolist().index(eosIndex) ### </S>
not included 284 | except ValueError: 285 | firstIndex = len(s_y) 286 | y[idy, :firstIndex]=s_y[:firstIndex] 287 | y_mask[idy, :firstIndex]=1. 288 | 289 | return y, y_mask 290 | 291 | def deal_generated_samples_to_maxlen(y_sample, dicts, maxlen): 292 | 293 | eosTag='</S>' 294 | eosIndex = dicts.get(eosTag) 295 | #print("eosIndex is", eosIndex) 296 | n_samples = len(y_sample) 297 | 298 | y = numpy.zeros((n_samples, maxlen)).astype('int32') 299 | y_mask = numpy.zeros((n_samples, maxlen)).astype('float32') 300 | 301 | for idy, s_y in enumerate(y_sample): 302 | try: 303 | firstIndex = s_y.tolist().index(eosIndex) ### </S> not included 304 | except ValueError: 305 | firstIndex = len(s_y) 306 | y[idy, :firstIndex]=s_y[:firstIndex] 307 | y_mask[idy, :firstIndex]=1. 308 | 309 | return y, y_mask 310 | 311 | 312 | def remove_pad_tolist(seqs): 313 | seqs_removed_list=[] 314 | 315 | for ids, s_y in enumerate(seqs): 316 | try: 317 | firstIndex = s_y.tolist().index(0) 318 | except ValueError: 319 | firstIndex = len(s_y) - 1 320 | seqs_removed_list.append(s_y[:firstIndex]) 321 | return seqs_removed_list 322 | 323 | def ortho_weight(ndim, precision='float32'): 324 | W=numpy.random.randn(ndim, ndim) 325 | u,s,v=numpy.linalg.svd(W) 326 | return u.astype(precision) 327 | 328 | def norm_weight(nin, nout=None, scale=0.01, ortho=True, precision='float32'): 329 | if nout is None: 330 | nout=nin 331 | if nout == nin and ortho: 332 | W=ortho_weight(nin) 333 | else: 334 | W=scale * numpy.random.randn(nin,nout) 335 | return W.astype(precision) 336 | 337 | def tableLookup(vocab_size, embedding_size, scope="tableLookup", init_device='/cpu:0', reuse_var=False, prefix='tablelookup'): 338 | 339 | if not scope: 340 | scope=tf.get_variable_scope() 341 | 342 | with tf.variable_scope(scope) as vs: 343 | if not reuse_var: 344 | with tf.device(init_device): 345 | embeddings_init=norm_weight(vocab_size, embedding_size) 346 | embeddings=tf.get_variable('embeddings',shape=[vocab_size, embedding_size], initializer=tf.constant_initializer(embeddings_init)) 347 | else: 348 | tf.get_variable_scope().reuse_variables() 349 | embeddings=tf.get_variable('embeddings') 350 | return embeddings 351 | 352 | def FCLayer(state_below, input_size, output_size, is_3d = True, reuse_var = False, use_bias=True, activation=None, scope='ff', init_device='/cpu:0', prefix='ff', precision='float32'): 353 | 354 | if not scope: 355 | scope=tf.get_variable_scope() 356 | 357 | with tf.variable_scope(scope): 358 | if not reuse_var: 359 | with tf.device(init_device): 360 | W_init = norm_weight(input_size, output_size) 361 | matrix=tf.get_variable('W', [input_size, output_size], initializer=tf.constant_initializer(W_init), trainable=True) 362 | if use_bias: 363 | bias_init = numpy.zeros((output_size,)).astype(precision) 364 | bias = tf.get_variable('b', output_size, initializer=tf.constant_initializer(bias_init), trainable=True) 365 | else: 366 | tf.get_variable_scope().reuse_variables() 367 | matrix=tf.get_variable('W') 368 | if use_bias: 369 | bias=tf.get_variable('b') 370 | 371 | inputShape = tf.shape(state_below) 372 | if is_3d : 373 | state_below=tf.reshape(state_below, [-1, inputShape[2]]) 374 | output=tf.matmul(state_below, matrix) 375 | output=tf.reshape(output, [-1, inputShape[1] , output_size]) 376 | else : 377 | output=tf.matmul(state_below, matrix) 378 | if use_bias: 379 | output=tf.add(output, bias) 380 | if activation is not None: 381 | output = activation(output) 382 | return output 383 | 384 | 385 | def average_clip_gradient(tower_grads, clip_c): 386 | average_grads = [] 387 |
for grad_and_vars in zip(*tower_grads): 388 | # Note that each grad_and_vars looks like the following: 389 | # ((grad0_gpu0, var0_gpu0), ... , (grad0_gpuN, var0_gpuN)) 390 | grads = [] 391 | for g, _ in grad_and_vars: 392 | # Add 0 dimension to the gradients to represent the tower. 393 | expanded_g = tf.expand_dims(g, 0) 394 | #Append on a 'tower' dimension which we will average over below. 395 | grads.append(expanded_g) 396 | # Average over the 'tower' dimension. 397 | grad = tf.concat(axis=0, values=grads) 398 | grad = tf.reduce_mean(grad, 0) 399 | # Keep in mind that the Variables are redundant because they are shared 400 | # across towers. So .. we will just return the first tower's pointer to 401 | # the Variable. 402 | v = grad_and_vars[0][1] 403 | grad_and_var = (grad, v) 404 | average_grads.append(grad_and_var) 405 | if clip_c > 0: 406 | grad, value = zip(*average_grads) 407 | grad, global_norm = tf.clip_by_global_norm(grad, clip_c) 408 | average_grads = zip(grad,value) 409 | 410 | #self.average_grads = average_grads 411 | 412 | return average_grads 413 | 414 | def average_clip_gradient_by_value(tower_grads, clip_min, clip_max): 415 | average_grads = [] 416 | for grad_and_vars in zip(*tower_grads): 417 | # Note that each grad_and_vars looks like the following: 418 | # ((grad0_gpu0, var0_gpu0), ... , (grad0_gpuN, var0_gpuN)) 419 | grads = [] 420 | for g, _ in grad_and_vars: 421 | # Add 0 dimension to the gradients to represent the tower. 422 | expanded_g = tf.expand_dims(g, 0) 423 | #Append on a 'tower' dimension which we will average over below. 424 | grads.append(expanded_g) 425 | # Average over the 'tower' dimension. 426 | grad = tf.concat(axis=0, values=grads) 427 | grad = tf.reduce_mean(grad, 0) 428 | # Keep in mind that the Variables are redundant because they are shared 429 | # across towers. So .. we will just return the first tower's pointer to 430 | # the Variable. 
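# Note the difference from average_clip_gradient above: once the tower-averaging loop finishes, the gradients are clipped element-wise to [clip_min, clip_max] rather than by their global norm.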
431 | v = grad_and_vars[0][1] 432 | grad_and_var = (grad, v) 433 | average_grads.append(grad_and_var) 434 | if clip_max > 0: 435 | grad, value = zip(*average_grads) 436 | grad = [tf.clip_by_value(x, clip_min, clip_max) for x in grad] 437 | average_grads = zip(grad,value) 438 | 439 | #self.average_grads = average_grads 440 | 441 | return average_grads 442 | 443 | 444 | def get_ngrams(input_tokens, max_n=None): 445 | if max_n is None: 446 | max_n = 4 447 | 448 | n_grams=[] 449 | for n in range(1, max_n+1): 450 | n_grams.append(defaultdict(int)) 451 | for n_gram in zip(*[input_tokens[i:] for i in range(n)]): 452 | n_grams[n-1][n_gram] +=1 453 | return n_grams 454 | 455 | def score(ref_tokens, hypothesis_tokens, max_n=None): 456 | if max_n is None: 457 | max_n =4 458 | 459 | def product(iterable): 460 | return reduce(mul, iterable, 1) 461 | 462 | def n_gram_precision(ref_ngrams, hyp_ngrams): 463 | precision=[] 464 | for n in range(1, max_n + 1): 465 | overlap = 0 466 | for ref_ngram, ref_ngram_count in ref_ngrams[n-1].iteritems(): 467 | if ref_ngram in hyp_ngrams[n-1]: 468 | overlap += min(ref_ngram_count, hyp_ngrams[n-1][ref_ngram]) 469 | hyp_length = max(0, len(hypothesis_tokens)-n+1) 470 | if n >=2: 471 | overlap += 1 472 | hyp_length += 1 473 | precision.append(overlap/hyp_length if hyp_length > 0 else 0.0) 474 | return precision 475 | 476 | def brevity_penalty(ref_length, hyp_length): 477 | return min(1.0, exp(1-(ref_length/hyp_length if hyp_length > 0 else 0.0))) 478 | 479 | hypothesis_length = len(hypothesis_tokens) 480 | ref_length = len(ref_tokens) 481 | hypothesis_ngrams = get_ngrams(hypothesis_tokens) 482 | ref_ngrams = get_ngrams(ref_tokens) 483 | 484 | np = n_gram_precision(ref_ngrams, hypothesis_ngrams) 485 | bp = brevity_penalty(ref_length, hypothesis_length) 486 | 487 | return product(np)**(1 / max_n) * bp 488 | 489 | 490 | -------------------------------------------------------------------------------- /shuffle.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import random 4 | 5 | from tempfile import mkstemp 6 | from subprocess import call 7 | 8 | 9 | 10 | def main(files): 11 | 12 | tf_os, tpath = mkstemp() 13 | tf = open(tpath, 'w') 14 | 15 | fds = [open(ff) for ff in files] 16 | 17 | for l in fds[0]: 18 | lines = [l.strip()] + [ff.readline().strip() for ff in fds[1:]] 19 | print >>tf, "|||".join(lines) 20 | 21 | [ff.close() for ff in fds] 22 | tf.close() 23 | 24 | tf = open(tpath, 'r') 25 | lines = tf.readlines() 26 | random.shuffle(lines) 27 | 28 | fds = [open(ff+'.shuf','w') for ff in files] 29 | 30 | for l in lines: 31 | s = l.strip().split('|||') 32 | for ii, fd in enumerate(fds): 33 | print >>fd, s[ii] 34 | 35 | [ff.close() for ff in fds] 36 | 37 | os.remove(tpath) 38 | 39 | if __name__ == '__main__': 40 | main(sys.argv[1:]) 41 | 42 | 43 | 44 | 45 | -------------------------------------------------------------------------------- /tensor2tensor/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vishwajeet93/clqg/75f7311983f48f164fd371641c9abaabbba9ef3d/tensor2tensor/__init__.py -------------------------------------------------------------------------------- /tensor2tensor/__init__.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vishwajeet93/clqg/75f7311983f48f164fd371641c9abaabbba9ef3d/tensor2tensor/__init__.pyc 
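The score helper defined in share_function.py above computes a smoothed sentence-level BLEU: the 1- to 4-gram precisions (add-one smoothed for n >= 2) are multiplied together, the fourth root is taken, and a brevity penalty is applied, roughly mirroring multi-bleu.perl at sentence granularity. A minimal usage sketch follows; it is an illustration rather than repository code, the example sentences are made up, and it assumes the Python 2 environment the rest of the codebase targets:

```
from share_function import score

# Tokenized reference and hypothesis; score() takes plain lists of tokens.
ref = 'where is the railway station ?'.split()
hyp = 'where is the train station ?'.split()

# Returns a float in [0, 1]; higher means greater n-gram overlap with the reference.
print(score(ref, hyp))
```

The add-one smoothing on the higher-order n-grams keeps short hypotheses from collapsing to a zero score whenever a 3- or 4-gram match is missing.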
-------------------------------------------------------------------------------- /tensor2tensor/avg_checkpoints.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 Google Inc. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Script to average values of variables in a list of checkpoint files.""" 16 | from __future__ import absolute_import 17 | from __future__ import division 18 | from __future__ import print_function 19 | 20 | # Dependency imports 21 | 22 | import numpy as np 23 | import six 24 | from six.moves import zip # pylint: disable=redefined-builtin 25 | import tensorflow as tf 26 | 27 | flags = tf.flags 28 | FLAGS = flags.FLAGS 29 | 30 | flags.DEFINE_string("checkpoints", "", 31 | "Comma-separated list of checkpoints to average.") 32 | flags.DEFINE_string("prefix", "", 33 | "Prefix (e.g., directory) to append to each checkpoint.") 34 | flags.DEFINE_string("output_path", "/tmp/averaged.ckpt", 35 | "Path to output the averaged checkpoint to.") 36 | 37 | 38 | def checkpoint_exists(path): 39 | return (tf.gfile.Exists(path) or tf.gfile.Exists(path + ".meta") or 40 | tf.gfile.Exists(path + ".index")) 41 | 42 | 43 | def main(_): 44 | # Get the checkpoints list from flags and run some basic checks. 45 | checkpoints = [c.strip() for c in FLAGS.checkpoints.split(",")] 46 | checkpoints = [c for c in checkpoints if c] 47 | if not checkpoints: 48 | raise ValueError("No checkpoints provided for averaging.") 49 | if FLAGS.prefix: 50 | checkpoints = [FLAGS.prefix + c for c in checkpoints] 51 | checkpoints = [c for c in checkpoints if checkpoint_exists(c)] 52 | if not checkpoints: 53 | raise ValueError( 54 | "None of the provided checkpoints exist. %s" % FLAGS.checkpoints) 55 | 56 | # Read variables from all checkpoints and average them. 57 | tf.logging.info("Reading variables and averaging checkpoints:") 58 | for c in checkpoints: 59 | tf.logging.info("%s ", c) 60 | var_list = tf.contrib.framework.list_variables(checkpoints[0]) 61 | var_values, var_dtypes = {}, {} 62 | for (name, shape) in var_list: 63 | if not name.startswith("global_step"): 64 | var_values[name] = np.zeros(shape) 65 | for checkpoint in checkpoints: 66 | reader = tf.contrib.framework.load_checkpoint(checkpoint) 67 | for name in var_values: 68 | tensor = reader.get_tensor(name) 69 | var_dtypes[name] = tensor.dtype 70 | var_values[name] += tensor 71 | tf.logging.info("Read from checkpoint %s", checkpoint) 72 | for name in var_values: # Average. 
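# At this point var_values holds element-wise sums over all checkpoints; dividing by the checkpoint count below turns each entry into an arithmetic mean.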
73 | var_values[name] /= len(checkpoints) 74 | 75 | tf_vars = [ 76 | tf.get_variable(v, shape=var_values[v].shape, dtype=var_dtypes[name]) 77 | for v in var_values 78 | ] 79 | placeholders = [tf.placeholder(v.dtype, shape=v.shape) for v in tf_vars] 80 | assign_ops = [tf.assign(v, p) for (v, p) in zip(tf_vars, placeholders)] 81 | global_step = tf.Variable( 82 | 0, name="global_step", trainable=False, dtype=tf.int64) 83 | saver = tf.train.Saver(tf.all_variables()) 84 | 85 | # Build a model consisting only of variables, set them to the average values. 86 | with tf.Session() as sess: 87 | sess.run(tf.initialize_all_variables()) 88 | for p, assign_op, (name, value) in zip(placeholders, assign_ops, 89 | six.iteritems(var_values)): 90 | sess.run(assign_op, {p: value}) 91 | # Use the built saver to save the averaged checkpoint. 92 | saver.save(sess, FLAGS.output_path, global_step=global_step) 93 | 94 | tf.logging.info("Averaged checkpoints saved in %s", FLAGS.output_path) 95 | 96 | 97 | if __name__ == "__main__": 98 | tf.app.run() 99 | -------------------------------------------------------------------------------- /tensor2tensor/common_attention.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 Google Inc. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Utilities for attention.""" 16 | from __future__ import absolute_import 17 | from __future__ import division 18 | from __future__ import print_function 19 | 20 | import math 21 | 22 | import tensorflow as tf 23 | 24 | from tensor2tensor import common_layers 25 | 26 | 27 | def add_timing_signal_1d(x, min_timescale=1.0, max_timescale=1.0e4): 28 | """Adds a bunch of sinusoids of different frequencies to a Tensor. 29 | 30 | Each channel of the input Tensor is incremented by a sinusoid of a different 31 | frequency and phase. 32 | 33 | This allows attention to learn to use absolute and relative positions. 34 | Timing signals should be added to some precursors of both the query and the 35 | memory inputs to attention. 36 | 37 | The use of relative position is possible because sin(x+y) and cos(x+y) can be 38 | experessed in terms of y, sin(x) and cos(x). 39 | 40 | In particular, we use a geometric sequence of timescales starting with 41 | min_timescale and ending with max_timescale. The number of different 42 | timescales is equal to channels / 2. For each timescale, we 43 | generate the two sinusoidal signals sin(timestep/timescale) and 44 | cos(timestep/timescale). All of these sinusoids are concatenated in 45 | the channels dimension. 46 | 47 | Args: 48 | x: a Tensor with shape [batch, length, channels] 49 | min_timescale: a float 50 | max_timescale: a float 51 | 52 | Returns: 53 | a Tensor the same shape as x. 
54 | """ 55 | length = tf.shape(x)[1] 56 | channels = tf.shape(x)[2] 57 | position = tf.to_float(tf.range(length)) 58 | num_timescales = channels // 2 59 | log_timescale_increment = ( 60 | math.log(float(max_timescale) / float(min_timescale)) / 61 | (tf.to_float(num_timescales) - 1)) 62 | inv_timescales = min_timescale * tf.exp( 63 | tf.to_float(tf.range(num_timescales)) * -log_timescale_increment) 64 | scaled_time = tf.expand_dims(position, 1) * tf.expand_dims(inv_timescales, 0) 65 | signal = tf.concat([tf.sin(scaled_time), tf.cos(scaled_time)], axis=1) 66 | signal = tf.pad(signal, [[0, 0], [0, tf.mod(channels, 2)]]) 67 | signal = tf.reshape(signal, [1, length, channels]) 68 | return x + signal 69 | 70 | 71 | def add_timing_signal_nd(x, min_timescale=1.0, max_timescale=1.0e4): 72 | """Adds a bunch of sinusoids of different frequencies to a Tensor. 73 | 74 | Each channel of the input Tensor is incremented by a sinusoid of a different 75 | frequency and phase in one of the positional dimensions. 76 | 77 | This allows attention to learn to use absolute and relative positions. 78 | Timing signals should be added to some precursors of both the query and the 79 | memory inputs to attention. 80 | 81 | The use of relative position is possible because sin(a+b) and cos(a+b) can be 82 | experessed in terms of b, sin(a) and cos(a). 83 | 84 | x is a Tensor with n "positional" dimensions, e.g. one dimension for a 85 | sequence or two dimensions for an image 86 | 87 | We use a geometric sequence of timescales starting with 88 | min_timescale and ending with max_timescale. The number of different 89 | timescales is equal to channels // (n * 2). For each timescale, we 90 | generate the two sinusoidal signals sin(timestep/timescale) and 91 | cos(timestep/timescale). All of these sinusoids are concatenated in 92 | the channels dimension. 93 | 94 | Args: 95 | x: a Tensor with shape [batch, d1 ... dn, channels] 96 | min_timescale: a float 97 | max_timescale: a float 98 | 99 | Returns: 100 | a Tensor the same shape as x. 101 | """ 102 | static_shape = x.get_shape().as_list() 103 | num_dims = len(static_shape) - 2 104 | channels = tf.shape(x)[-1] 105 | num_timescales = channels // (num_dims * 2) 106 | log_timescale_increment = ( 107 | math.log(float(max_timescale) / float(min_timescale)) / 108 | (tf.to_float(num_timescales) - 1)) 109 | inv_timescales = min_timescale * tf.exp( 110 | tf.to_float(tf.range(num_timescales)) * -log_timescale_increment) 111 | for dim in xrange(num_dims): 112 | length = tf.shape(x)[dim + 1] 113 | position = tf.to_float(tf.range(length)) 114 | scaled_time = tf.expand_dims(position, 1) * tf.expand_dims( 115 | inv_timescales, 0) 116 | signal = tf.concat([tf.sin(scaled_time), tf.cos(scaled_time)], axis=1) 117 | prepad = dim * 2 * num_timescales 118 | postpad = channels - (dim + 1) * 2 * num_timescales 119 | signal = tf.pad(signal, [[0, 0], [prepad, postpad]]) 120 | for _ in xrange(1 + dim): 121 | signal = tf.expand_dims(signal, 0) 122 | for _ in xrange(num_dims - 1 - dim): 123 | signal = tf.expand_dims(signal, -2) 124 | x += signal 125 | return x 126 | 127 | 128 | def add_positional_embedding_nd(x, max_length, name): 129 | """Add n-dimensional positional embedding. 130 | 131 | Adds embeddings to represent the positional dimensions of the tensor. 132 | The input tensor has n positional dimensions - i.e. 1 for text, 2 for images, 133 | 3 for video, etc. 134 | 135 | Args: 136 | x: a Tensor with shape [batch, p1 ... pn, depth] 137 | max_length: an integer. static maximum size of any dimension. 
138 | name: a name for this layer. 139 | 140 | Returns: 141 | a Tensor the same shape as x. 142 | """ 143 | static_shape = x.get_shape().as_list() 144 | dynamic_shape = tf.shape(x) 145 | num_dims = len(static_shape) - 2 146 | depth = static_shape[-1] 147 | base_shape = [1] * (num_dims + 1) + [depth] 148 | base_start = [0] * (num_dims + 2) 149 | base_size = [-1] + [1] * num_dims + [depth] 150 | for i in xrange(num_dims): 151 | shape = base_shape[:] 152 | start = base_start[:] 153 | size = base_size[:] 154 | shape[i + 1] = max_length 155 | size[i + 1] = dynamic_shape[i + 1] 156 | var = (tf.get_variable( 157 | name + "_%d" % i, shape, 158 | initializer=tf.random_normal_initializer(0, depth ** -0.5)) 159 | * (depth ** 0.5)) 160 | x += tf.slice(var, start, size) 161 | return x 162 | 163 | 164 | def embedding_to_padding(emb): 165 | """Input embeddings -> is_padding. 166 | 167 | We have hacked symbol_modality to return all-zero embeddings for padding. 168 | 169 | Args: 170 | emb: a Tensor with shape [..., depth]. 171 | Returns: 172 | a boolean Tensor with shape [...]. 173 | """ 174 | emb_sum = tf.reduce_sum(tf.abs(emb), axis=-1) 175 | return tf.equal(emb_sum, 0.0) 176 | 177 | 178 | def attention_bias_lower_triangle(length): 179 | """Create an bias tensor to be added to attention logits. 180 | 181 | Args: 182 | length: a Scalar. 183 | 184 | Returns: 185 | a `Tensor` with shape [1, 1, length, length]. 186 | """ 187 | lower_triangle = tf.matrix_band_part(tf.ones([length, length]), -1, 0) 188 | ret = -1e9 * (1.0 - lower_triangle) 189 | return tf.reshape(ret, [1, 1, length, length]) 190 | 191 | 192 | # added by zhyang 193 | # 20180111 194 | 195 | def attention_bias_upper_triangle(length): 196 | upper_triangle = tf.matrix_band_part(tf.ones([length, length]), 0, -1) 197 | ret = -1e9 * (1.0 - upper_triangle) 198 | return tf.reshape(ret, [1, 1, length, length]) 199 | 200 | 201 | def attention_bias_ignore_padding(memory_padding): 202 | """Create an bias tensor to be added to attention logits. 203 | 204 | Args: 205 | memory_padding: a boolean `Tensor` with shape [batch, memory_length]. 206 | 207 | Returns: 208 | a `Tensor` with shape [batch, 1, 1, memory_length]. 209 | """ 210 | ret = tf.to_float(memory_padding) * -1e9 211 | return tf.expand_dims(tf.expand_dims(ret, 1), 1) 212 | 213 | def split_last_dimension(x, n): 214 | """Reshape x so that the last dimension becomes two dimensions. 215 | 216 | The first of these two dimensions is n. 217 | 218 | Args: 219 | x: a Tensor with shape [..., m] 220 | n: an integer. 221 | 222 | Returns: 223 | a Tensor with shape [..., n, m/n] 224 | """ 225 | old_shape = x.get_shape().dims 226 | last = old_shape[-1] 227 | new_shape = old_shape[:-1] + [n] + [last // n if last else None] 228 | ret = tf.reshape(x, tf.concat([tf.shape(x)[:-1], [n, -1]], 0)) 229 | ret.set_shape(new_shape) 230 | return ret 231 | 232 | 233 | def combine_last_two_dimensions(x): 234 | """Reshape x so that the last two dimension become one. 235 | 236 | Args: 237 | x: a Tensor with shape [..., a, b] 238 | 239 | Returns: 240 | a Tensor with shape [..., ab] 241 | """ 242 | old_shape = x.get_shape().dims 243 | a, b = old_shape[-2:] 244 | new_shape = old_shape[:-2] + [a * b if a and b else None] 245 | ret = tf.reshape(x, tf.concat([tf.shape(x)[:-2], [-1]], 0)) 246 | ret.set_shape(new_shape) 247 | return ret 248 | 249 | 250 | def split_heads(x, num_heads): 251 | """Split channels (dimension 3) into multiple heads (becomes dimension 1). 
252 | 253 | Args: 254 | x: a Tensor with shape [batch, length, channels] 255 | num_heads: an integer 256 | 257 | Returns: 258 | a Tensor with shape [batch, num_heads, length, channels / num_heads] 259 | """ 260 | return tf.transpose(split_last_dimension(x, num_heads), [0, 2, 1, 3]) 261 | 262 | 263 | def combine_heads(x): 264 | """Inverse of split_heads. 265 | 266 | Args: 267 | x: a Tensor with shape [batch, num_heads, length, channels / num_heads] 268 | 269 | Returns: 270 | a Tensor with shape [batch, length, channels] 271 | """ 272 | return combine_last_two_dimensions(tf.transpose(x, [0, 2, 1, 3])) 273 | 274 | 275 | def attention_image_summary(attn, image_shapes=None): 276 | """Compute color image summary. 277 | 278 | Args: 279 | attn: a Tensor with shape [batch, num_heads, query_length, memory_length] 280 | image_shapes: optional quadruple of integer scalars. 281 | If the query positions and memory positions represent the 282 | pixels of a flattened image, then pass in their dimensions: 283 | (query_rows, query_cols, memory_rows, memory_cols). 284 | """ 285 | num_heads = attn.get_shape().as_list()[1] 286 | # [batch, query_length, memory_length, num_heads] 287 | image = tf.transpose(attn, [0, 2, 3, 1]) 288 | image = tf.pow(image, 0.2) # for high-dynamic-range 289 | # Each head will correspond to one of RGB. 290 | # pad the heads to be a multiple of 3 291 | image = tf.pad(image, [[0, 0], [0, 0], [0, 0], [0, -num_heads % 3]]) 292 | image = split_last_dimension(image, 3) 293 | image = tf.reduce_max(image, 4) 294 | if image_shapes is not None: 295 | q_rows, q_cols, m_rows, m_cols = list(image_shapes) 296 | image = tf.reshape(image, [-1, q_rows, q_cols, m_rows, m_cols, 3]) 297 | image = tf.transpose(image, [0, 1, 3, 2, 4, 5]) 298 | image = tf.reshape(image, [-1, q_rows * m_rows, q_cols * m_cols, 3]) 299 | tf.summary.image("attention", image, max_outputs=1) 300 | 301 | 302 | def dot_product_attention(q, 303 | k, 304 | v, 305 | bias, 306 | dropout_rate=0.0, 307 | summaries=False, 308 | image_shapes=None, 309 | name=None): 310 | """dot-product attention. 311 | 312 | Args: 313 | q: a Tensor with shape [batch, heads, length_q, depth_k] 314 | k: a Tensor with shape [batch, heads, length_kv, depth_k] 315 | v: a Tensor with shape [batch, heads, length_kv, depth_v] 316 | bias: bias Tensor (see attention_bias()) 317 | dropout_rate: a floating point number 318 | summaries: a boolean 319 | image_shapes: optional quadruple of integer scalars for image summary. 320 | If the query positions and memory positions represent the 321 | pixels of a flattened image, then pass in their dimensions: 322 | (query_rows, query_cols, memory_rows, memory_cols). 323 | name: an optional string 324 | 325 | Returns: 326 | A Tensor. 
327 | """ 328 | with tf.variable_scope( 329 | name, default_name="dot_product_attention", values=[q, k, v]): 330 | # [batch, num_heads, query_length, memory_length] 331 | logits = tf.matmul(q, k, transpose_b=True) 332 | if bias is not None: 333 | logits += bias 334 | weights = tf.nn.softmax(logits, name="attention_weights") 335 | # dropping out the attention links for each of the heads 336 | weights = tf.nn.dropout(weights, 1.0 - dropout_rate) 337 | if summaries and not tf.get_variable_scope().reuse: 338 | attention_image_summary(weights, image_shapes) 339 | return tf.matmul(weights, v) 340 | 341 | 342 | def multihead_attention(query_antecedent, 343 | memory_antecedent, 344 | bias, 345 | total_key_depth, 346 | total_value_depth, 347 | output_depth, 348 | num_heads, 349 | dropout_rate, 350 | summaries=False, 351 | image_shapes=None, 352 | name=None): 353 | """Multihead scaled-dot-product attention with input/output transformations. 354 | 355 | Args: 356 | query_antecedent: a Tensor with shape [batch, length_q, channels] 357 | memory_antecedent: a Tensor with shape [batch, length_m, channels] 358 | bias: bias Tensor (see attention_bias()) 359 | total_key_depth: an integer 360 | total_value_depth: an integer 361 | output_depth: an integer 362 | num_heads: an integer dividing total_key_depth and total_value_depth 363 | dropout_rate: a floating point number 364 | summaries: a boolean 365 | image_shapes: optional quadruple of integer scalars for image summary. 366 | If the query positions and memory positions represent the 367 | pixels of a flattened image, then pass in their dimensions: 368 | (query_rows, query_cols, memory_rows, memory_cols). 369 | name: an optional string 370 | 371 | Returns: 372 | A Tensor. 373 | """ 374 | with tf.variable_scope( 375 | name, 376 | default_name="multihead_attention", 377 | values=[query_antecedent, memory_antecedent]): 378 | if memory_antecedent is None: 379 | # self attention 380 | combined = common_layers.conv1d( 381 | query_antecedent, 382 | total_key_depth * 2 + total_value_depth, 383 | 1, 384 | name="qkv_transform") 385 | q, k, v = tf.split( 386 | combined, [total_key_depth, total_key_depth, total_value_depth], 387 | axis=2) 388 | else: 389 | q = common_layers.conv1d( 390 | query_antecedent, total_key_depth, 1, name="q_transform") 391 | combined = common_layers.conv1d( 392 | memory_antecedent, 393 | total_key_depth + total_value_depth, 394 | 1, 395 | name="kv_transform") 396 | k, v = tf.split(combined, [total_key_depth, total_value_depth], axis=2) 397 | q = split_heads(q, num_heads) 398 | k = split_heads(k, num_heads) 399 | v = split_heads(v, num_heads) 400 | key_depth_per_head = total_key_depth // num_heads 401 | q *= key_depth_per_head**-0.5 402 | x = dot_product_attention( 403 | q, k, v, bias, dropout_rate, summaries, image_shapes) 404 | x = combine_heads(x) 405 | x = common_layers.conv1d(x, output_depth, 1, name="output_transform") 406 | return x 407 | -------------------------------------------------------------------------------- /tensor2tensor/common_attention.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vishwajeet93/clqg/75f7311983f48f164fd371641c9abaabbba9ef3d/tensor2tensor/common_attention.pyc -------------------------------------------------------------------------------- /tensor2tensor/common_layers.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/vishwajeet93/clqg/75f7311983f48f164fd371641c9abaabbba9ef3d/tensor2tensor/common_layers.pyc -------------------------------------------------------------------------------- /tensor2tensor/expert_utils.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vishwajeet93/clqg/75f7311983f48f164fd371641c9abaabbba9ef3d/tensor2tensor/expert_utils.pyc -------------------------------------------------------------------------------- /text_disc_pretrain.sh: -------------------------------------------------------------------------------- 1 | export CUDA_VISIBLE_DEVICES='0,1,2,3,4,5,6,7' 2 | 3 | python text_discriminator_pretrain.py -c ./configs/config_text_discriminator_pretrain.yaml 4 | -------------------------------------------------------------------------------- /text_discriminator_pretrain.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | import yaml 3 | import time 4 | import os 5 | import sys 6 | import logging 7 | from argparse import ArgumentParser 8 | import tensorflow as tf 9 | 10 | from utils import DataUtil, AttrDict 11 | from model import Model 12 | from cnn_text_discriminator import text_DisCNN 13 | from share_function import deal_generated_samples 14 | from share_function import extend_sentence_to_maxlen 15 | from share_function import FlushFile 16 | 17 | def gan_train(config): 18 | sess_config = tf.ConfigProto() 19 | sess_config.gpu_options.allow_growth = True 20 | sess_config.allow_soft_placement = True 21 | 22 | default_graph=tf.Graph() 23 | with default_graph.as_default(): 24 | 25 | sess = tf.Session(config=sess_config, graph=default_graph) 26 | logger = logging.getLogger('') 27 | 28 | dis_filter_sizes = [i for i in range(1, config.train.dis_max_len, 4)] 29 | dis_num_filters = [(100+i * 10) for i in range(1, config.train.dis_max_len, 4)] 30 | 31 | #print("the scope is ", config.train.dis_scope) 32 | 33 | discriminator = text_DisCNN( 34 | sess=sess, 35 | max_len=config.train.dis_max_len, 36 | num_classes=2, 37 | vocab_size_s=config.src_vocab_size_a, 38 | batch_size=config.train.dis_batch_size, 39 | dim_word=config.train.dis_dim_word, 40 | filter_sizes=dis_filter_sizes, 41 | num_filters=dis_num_filters, 42 | source_dict=config.train.dis_src_vocab, 43 | gpu_device=config.train.devices, 44 | s_domain_data=config.train.s_domain_data, 45 | s_domain_generated_data=config.train.s_domain_generated_data, 46 | dev_s_domain_data=config.train.dev_s_domain_data, 47 | dev_s_domain_generated_data=config.train.dev_s_domain_generated_data, 48 | max_epoches=config.train.dis_max_epoches, 49 | dispFreq=config.train.dis_dispFreq, 50 | saveFreq=config.train.dis_saveFreq, 51 | saveto=config.train.dis_saveto, 52 | reload=config.train.dis_reload, 53 | clip_c=config.train.dis_clip_c, 54 | optimizer=config.train.dis_optimizer, 55 | reshuffle=config.train.dis_reshuffle, 56 | scope=config.train.text_scope 57 | ) 58 | 59 | logging.info("text_discriminator pretrain begins!") 60 | discriminator.train() 61 | logging.info("text_discriminator pretrain done") 62 | 63 | 64 | if __name__ == '__main__': 65 | sys.stdout = FlushFile(sys.stdout) 66 | parser = ArgumentParser() 67 | parser.add_argument('-c', '--config', dest='config') 68 | args = parser.parse_args() 69 | # Read config 70 | config = AttrDict(yaml.load(open(args.config))) 71 | # Logger 72 | if not os.path.exists(config.train.logdir): 73 | os.makedirs(config.train.logdir) 74 | 
logging.basicConfig(filename=config.train.logdir+'/train.log', level=logging.DEBUG) 75 | console = logging.StreamHandler() 76 | console.setLevel(logging.INFO) 77 | logging.getLogger('').addHandler(console) 78 | # Train 79 | gan_train(config) 80 | 81 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | import yaml 3 | import time 4 | import os 5 | import logging 6 | from argparse import ArgumentParser 7 | import tensorflow as tf 8 | 9 | from utils import DataUtil, AttrDict 10 | from model import Model 11 | 12 | 13 | def train(config): 14 | logger = logging.getLogger('') 15 | 16 | """Train a model with a config file.""" 17 | du = DataUtil(config=config) 18 | du.load_vocab() 19 | 20 | model = Model(config=config) 21 | model.build_variational_train_model() 22 | 23 | sess_config = tf.ConfigProto() 24 | sess_config.gpu_options.allow_growth = True 25 | sess_config.allow_soft_placement = True 26 | 27 | with model.graph.as_default(): 28 | saver = tf.train.Saver(var_list=tf.global_variables()) 29 | summary_writer = tf.summary.FileWriter(config.train.logdir, graph=model.graph) 30 | # saver_partial = tf.train.Saver(var_list=[v for v in tf.trainable_variables() if 'Adam' not in v.name]) 31 | 32 | with tf.Session(config=sess_config) as sess: 33 | # Initialize all variables. 34 | sess.run(tf.global_variables_initializer()) 35 | try: 36 | # saver_partial.restore(sess, tf.train.latest_checkpoint(config.train.logdir)) 37 | # print('Restore partial model from %s.' % config.train.logdir) 38 | saver.restore(sess, tf.train.latest_checkpoint(config.train.logdir)) 39 | except: 40 | logger.info('Failed to reload model.') 41 | for epoch in range(1, config.train.num_epochs+1): 42 | for batch in du.get_training_batches_with_buckets(): 43 | start_time = time.time() 44 | step = sess.run(model.global_step) 45 | # Summary 46 | if step % config.train.summary_freq == 0: 47 | step, lr, gnorm, loss, acc, summary, _ = sess.run( 48 | [model.global_step, model.learning_rate, model.grads_norm, 49 | model.loss, model.acc, model.summary_op, model.train_op], 50 | feed_dict={model.src_a_pl: batch[0], model.dst_a_pl: batch[0], 51 | model.src_b_pl: batch[1], model.dst_b_pl: batch[1]}) 52 | summary_writer.add_summary(summary, global_step=step) 53 | else: 54 | step, lr, gnorm, loss, acc, _ = sess.run( 55 | [model.global_step, model.learning_rate, model.grads_norm, 56 | model.loss, model.acc, model.train_op], 57 | feed_dict={model.src_a_pl: batch[0], model.dst_a_pl: batch[0], 58 | model.src_b_pl: batch[1], model.dst_b_pl: batch[1]}) 59 | 60 | if step % config.train.disp_freq == 0: 61 | logger.info( 62 | 'epoch: {0}\tstep: {1}\tlr: {2:.6f}\tgnorm: {3:.4f}\tloss: {4:.4f}\tacc: {5:.4f}\ttime: {6:.4f}'. 63 | format(epoch, step, lr, gnorm, loss, acc, time.time() - start_time)) 64 | 65 | # Save model 66 | if step % config.train.save_freq == 0: 67 | mp = config.train.logdir + '/model_epoch_%d_step_%d' % (epoch, step) 68 | saver.save(sess, mp) 69 | logger.info('Save model in %s.' 
% mp) 70 | logger.info("Finish training.") 71 | 72 | 73 | if __name__ == '__main__': 74 | parser = ArgumentParser() 75 | parser.add_argument('-c', '--config', dest='config') 76 | args = parser.parse_args() 77 | # Read config 78 | config = AttrDict(yaml.load(open(args.config))) 79 | # Logger 80 | if not os.path.exists(config.train.logdir): 81 | os.makedirs(config.train.logdir) 82 | logging.basicConfig(filename=config.train.logdir+'/train.log', level=logging.DEBUG) 83 | console = logging.StreamHandler() 84 | console.setLevel(logging.INFO) 85 | logging.getLogger('').addHandler(console) 86 | # Train 87 | train(config) 88 | -------------------------------------------------------------------------------- /train_qgen.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | import yaml 3 | import time 4 | import os 5 | import logging 6 | import numpy as np 7 | from argparse import ArgumentParser 8 | import tensorflow as tf 9 | 10 | from utils import DataUtil, AttrDict 11 | from model import Model 12 | from share_function import deal_generated_samples 13 | import codecs 14 | 15 | 16 | def train(config): 17 | logger = logging.getLogger('') 18 | 19 | """Train a model with a config file.""" 20 | du = DataUtil(config=config) 21 | du.load_vocab(src_vocab=config.src_vocab, 22 | dst_vocab=config.dst_vocab, 23 | src_vocab_size=config.src_vocab_size_a, 24 | dst_vocab_size=config.src_vocab_size_b) 25 | 26 | model = Model(config=config) 27 | model.build_variational_train_model() 28 | 29 | sess_config = tf.ConfigProto() 30 | sess_config.gpu_options.allow_growth = True 31 | sess_config.allow_soft_placement = True 32 | 33 | with model.graph.as_default(): 34 | if config.train.restore_partial: 35 | saver_partial = tf.train.Saver(var_list=model.restore_param) 36 | if config.train.restore_decoder: 37 | saver_partial_decoder = tf.train.Saver(var_list=model.restore_decoder_param) 38 | saver = tf.train.Saver(var_list=tf.global_variables()) 39 | summary_writer = tf.summary.FileWriter(config.train.logdir, graph=model.graph) 40 | # saver_partial = tf.train.Saver(var_list=[v for v in tf.trainable_variables() if 'Adam' not in v.name]) 41 | 42 | with tf.Session(config=sess_config) as sess: 43 | # Initialize all variables. 44 | sess.run(tf.global_variables_initializer()) 45 | reload_pretrain_embedding=False 46 | try: 47 | #saver_partial.restore(sess, tf.train.latest_checkpoint(config.train.logdir)) 48 | #print('Restore partial model from %s.' 
% config.train.logdir) 49 | if config.train.restore_partial: 50 | saver_partial.restore(sess, tf.train.latest_checkpoint(config.train.logdir)) 51 | logger.info('Partial restore complete') 52 | if config.train.restore_decoder: 53 | saver_partial_decoder.restore(sess, tf.train.latest_checkpoint(config.train.decoder_logdir)) 54 | logger.info('Decoder restore complete') 55 | if not config.train.restore_embed and not config.train.restore_decoder_embed: 56 | reload_pretrain_embedding = True 57 | else: 58 | saver.restore(sess, tf.train.latest_checkpoint(config.train.logdir)) 59 | except: 60 | logger.info('Failed to reload model.') 61 | reload_pretrain_embedding=True 62 | 63 | if reload_pretrain_embedding: 64 | logger.info('reload the pretrained embeddings for the encoders') 65 | src_pretrained_embedding={} 66 | dst_pretrained_embedding={} 67 | try: 68 | 69 | for l in codecs.open(config.train.src_pretrain_wordemb_path, 'r', 'utf-8'): 70 | word_emb=l.strip().split() 71 | # print(word_emb) 72 | if len(word_emb)== config.hidden_units + 1: 73 | word, emb = word_emb[0], np.array(map(float, word_emb[1:])) 74 | src_pretrained_embedding[word]=emb 75 | 76 | for l in codecs.open(config.train.dst_pretrain_wordemb_path, 'r', 'utf-8'): 77 | word_emb=l.strip().split() 78 | if len(word_emb)==config.hidden_units + 1: 79 | word, emb = word_emb[0], np.array(map(float, word_emb[1:])) 80 | dst_pretrained_embedding[word]=emb 81 | 82 | logger.info('reload the word embedding done') 83 | 84 | tf.get_variable_scope().reuse_variables() 85 | src_embed_a=tf.get_variable('enc_aembedding/src_embedding/kernel') 86 | src_embed_b=tf.get_variable('enc_bembedding/src_embedding/kernel') 87 | 88 | dst_embed_a=tf.get_variable('dec_aembedding/dst_embedding/kernel') 89 | dst_embed_b=tf.get_variable('dec_bembedding/dst_embedding/kernel') 90 | 91 | count_a=0 92 | src_value_a=sess.run(src_embed_a) 93 | dst_value_a=sess.run(dst_embed_a) 94 | # print(src_value_a) 95 | for word in src_pretrained_embedding: 96 | if word in du.src2idx: 97 | id = du.src2idx[word] 98 | # print(id) 99 | src_value_a[id] = src_pretrained_embedding[word] 100 | dst_value_a[id] = src_pretrained_embedding[word] 101 | count_a += 1 102 | sess.run(src_embed_a.assign(src_value_a)) 103 | sess.run(dst_embed_a.assign(dst_value_a)) 104 | # print(sess.run(src_embed_a)) 105 | 106 | 107 | count_b=0 108 | src_value_b = sess.run(src_embed_b) 109 | dst_value_b = sess.run(dst_embed_b) 110 | for word in dst_pretrained_embedding: 111 | if word in du.dst2idx: 112 | id = du.dst2idx[word] 113 | # print(id) 114 | src_value_b[id] = dst_pretrained_embedding[word] 115 | dst_value_b[id] = dst_pretrained_embedding[word] 116 | count_b += 1 117 | sess.run(src_embed_b.assign(src_value_b)) 118 | sess.run(dst_embed_b.assign(dst_value_b)) 119 | 120 | logger.info('restore %d src_embedding and %d dst_embedding done' %(count_a, count_b)) 121 | 122 | except: 123 | logger.info('Failed to load the pretrained embeddings') 124 | 125 | # tmp_writer = codecs.open('tmp_test', 'w', 'utf-8') 126 | if config.train_ratio > 0: 127 | curr_english = 0 128 | 129 | for epoch in range(1, config.train.num_epochs+1): 130 | #for batch in du.get_training_batches_with_buckets(): 131 | for batch in du.get_training_batches(shuffle=True): 132 | # swap the batch[0] and batch[1] according to whether the length of the sequence is odd or even 133 | # batch_swap=[] 134 | # swap_0 = np.arange(batch[0].shape[1]) 135 | # swap_1 = np.arange(batch[1].shape[1]) 136 | # 137 | # if len(swap_0) % 2 == 0: 138 | # swap_0[0::2]+=1 139 | # 
swap_0[1::2]-=1 140 | # else: 141 | # swap_0[0:-1:2]+=1 142 | # swap_0[1::2]-=1 143 | # 144 | # if len(swap_1) % 2 == 0: 145 | # swap_1[0::2]+=1 146 | # swap_1[1::2]-=1 147 | # else: 148 | # swap_1[0:-1:2] += 1 149 | # swap_1[1::2] -= 1 150 | # 151 | # batch_swap.append(batch[0].transpose()[swap_0].transpose()) 152 | # batch_swap.append(batch[1].transpose()[swap_1].transpose()) 153 | 154 | #print("Len 1", len(batch[2])) 155 | #print("Len 2", len(batch[3])) 156 | #print(batch_swap[0]) 157 | 158 | # randomly shuffle the batch[0] and batch[1] 159 | #batch_shuffle=[] 160 | #shuffle_0_indices = np.random.permutation(np.arange(batch[0].shape[1])) 161 | #shuffle_1_indices = np.random.permutation(np.arange(batch[1].shape[1])) 162 | #batch_shuffle.append(batch[0].transpose()[shuffle_0_indices].transpose()) 163 | #batch_shuffle.append(batch[1].transpose()[shuffle_1_indices].transpose()) 164 | 165 | 166 | def get_shuffle_k_indices(length, shuffle_k): 167 | shuffle_k_indices = [] 168 | rand_start = np.random.randint(shuffle_k) 169 | 170 | indices_list_start = list(np.random.permutation(np.arange(0, rand_start))) 171 | shuffle_k_indices.extend(indices_list_start) 172 | 173 | for i in range(rand_start, length, shuffle_k): 174 | if i + shuffle_k > length: 175 | indices_list_i = list(np.random.permutation(np.arange(i, length))) 176 | else: 177 | indices_list_i = list(np.random.permutation(np.arange(i, i + shuffle_k))) 178 | 179 | shuffle_k_indices.extend(indices_list_i) 180 | 181 | return np.array(shuffle_k_indices) 182 | 183 | batch_shuffle=[] 184 | shuffle_0_indices = get_shuffle_k_indices(batch[0].shape[1], config.train.shuffle_k) 185 | shuffle_1_indices = get_shuffle_k_indices(batch[1].shape[1], config.train.shuffle_k) 186 | #print(shuffle_0_indices) 187 | batch_shuffle.append(batch[0].transpose()[shuffle_0_indices].transpose()) 188 | batch_shuffle.append(batch[1].transpose()[shuffle_1_indices].transpose()) 189 | 190 | start_time = time.time() 191 | step = sess.run(model.global_step) 192 | 193 | 194 | 195 | if config.train_ratio == 0: 196 | 197 | step, lr, gnorm_aa, loss_aa, acc_aa, _ = sess.run( 198 | [model.global_step, model.learning_rate, model.grads_norm_aa, 199 | model.loss_aa, model.acc_aa, model.train_op_aa], 200 | feed_dict={model.src_a_pl: batch[0], model.dst_a_pl: batch[2]}) 201 | 202 | 203 | step, lr, gnorm_bb, loss_bb, acc_bb, _ = sess.run( 204 | [model.global_step, model.learning_rate, model.grads_norm_bb, 205 | model.loss_bb, model.acc_bb, model.train_op_bb], 206 | feed_dict={model.src_b_pl: batch[1], model.dst_b_pl: batch[3]}) 207 | 208 | else: 209 | 210 | step, lr, gnorm_aa, loss_aa, acc_aa, _ = sess.run( 211 | [model.global_step, model.learning_rate, model.grads_norm_aa, 212 | model.loss_aa, model.acc_aa, model.train_op_aa], 213 | feed_dict={model.src_a_pl: batch[0], model.dst_a_pl: batch[2]}) 214 | 215 | curr_english += 1 216 | if curr_english == config.train_ratio: 217 | curr_english = 0 218 | step, lr, gnorm_bb, loss_bb, acc_bb, _ = sess.run( 219 | [model.global_step, model.learning_rate, model.grads_norm_bb, 220 | model.loss_bb, model.acc_bb, model.train_op_bb], 221 | feed_dict={model.src_b_pl: batch[1], model.dst_b_pl: batch[3]}) 222 | 223 | # this step takes too much time 224 | #generate_ab, generate_ba = sess.run( 225 | # [model.generate_ab, model.generate_ba], 226 | # feed_dict={model.src_a_pl: batch[0], model.src_b_pl: batch[1]}) 227 | 228 | #generate_ab_dealed, _ = deal_generated_samples(generate_ab, du.dst2idx) 229 | #generate_ba_dealed, _ = 
deal_generated_samples(generate_ba, du.src2idx) 230 | 231 | #for sent in du.indices_to_words(batch[0], o='src'): 232 | # print(sent, file=tmp_writer) 233 | #for sent in du.indices_to_words(generate_ab_dealed, o='dst'): 234 | # print(sent, file=tmp_writer) 235 | 236 | #step, acc_ab, loss_ab, _ = sess.run( 237 | # [model.global_step, model.acc_ab, model.loss_ab, model.train_op_ab], 238 | # feed_dict={model.src_a_pl:generate_ba_dealed, model.dst_b_pl: batch[3]}) 239 | 240 | #step, acc_ba, loss_ba, _ = sess.run( 241 | # [model.global_step, model.acc_ba, model.loss_ba, model.train_op_ba], 242 | # feed_dict={model.src_b_pl:generate_ab_dealed, model.dst_a_pl: batch[2]}) 243 | 244 | #if step % config.train.disp_freq == 0: 245 | # logger.info('epoch: {0}\tstep: {1}\tlr: {2:.6f}\tgnorm: {3:.4f}\tloss: {4:.4f}' 246 | # '\tacc: {5:.4f}\tcross_loss: {6:.4f}\tcross_acc: {7:.4f}\ttime: {8:.4f}' 247 | # .format(epoch, step, lr, gnorm_aa, loss_aa, acc_aa, loss_ab, acc_ab, 248 | # time.time() - start_time)) 249 | 250 | if step % config.train.disp_freq == 0: 251 | logger.info('epoch: {0}\tstep: {1}\tlr: {2:.6f}\tgnorm: {3:.4f}\tloss: {4:.4f}' 252 | '\tacc: {5:.4f}\ttime: {6:.4f}' 253 | .format(epoch, step, lr, gnorm_aa, loss_aa, acc_aa, 254 | time.time() - start_time)) 255 | 256 | 257 | # Save model 258 | if step % config.train.save_freq == 0: 259 | mp = config.train.logdir + '/model_epoch_%d_step_%d' % (epoch, step) 260 | saver.save(sess, mp) 261 | logger.info('Save model in %s.' % mp) 262 | 263 | logger.info("Finish training.") 264 | 265 | 266 | if __name__ == '__main__': 267 | parser = ArgumentParser() 268 | parser.add_argument('-c', '--config', dest='config') 269 | args = parser.parse_args() 270 | # Read config 271 | config = AttrDict(yaml.load(open(args.config))) 272 | # Logger 273 | if not os.path.exists(config.train.logdir): 274 | os.makedirs(config.train.logdir) 275 | logging.basicConfig(filename=config.train.logdir+'/train.log', level=logging.DEBUG) 276 | console = logging.StreamHandler() 277 | console.setLevel(logging.INFO) 278 | logging.getLogger('').addHandler(console) 279 | # Train 280 | train(config) 281 | -------------------------------------------------------------------------------- /train_unSuper_en2de.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | export CUDA_VISIBLE_DEVICES='0,1,2,3,4,5,6,7' 3 | 4 | python train_unsupervised.py -c ./configs/config_generator_train.yaml -------------------------------------------------------------------------------- /train_unsupervised.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | import yaml 3 | import time 4 | import os 5 | import logging 6 | import numpy as np 7 | from argparse import ArgumentParser 8 | import tensorflow as tf 9 | 10 | from utils import DataUtil, AttrDict 11 | from model import Model 12 | from share_function import deal_generated_samples 13 | import codecs 14 | 15 | 16 | def train(config): 17 | logger = logging.getLogger('') 18 | 19 | """Train a model with a config file.""" 20 | du = DataUtil(config=config) 21 | du.load_vocab(src_vocab=config.src_vocab, 22 | dst_vocab=config.dst_vocab, 23 | src_vocab_size=config.src_vocab_size_a, 24 | dst_vocab_size=config.src_vocab_size_b) 25 | 26 | model = Model(config=config) 27 | model.build_variational_train_model() 28 | 29 | sess_config = tf.ConfigProto() 30 | sess_config.gpu_options.allow_growth = True 31 | sess_config.allow_soft_placement = True 32 
| 33 | with model.graph.as_default(): 34 | if config.train.restore_partial: 35 | saver_partial = tf.train.Saver(var_list=model.restore_param) 36 | if config.train.restore_decoder: 37 | saver_partial_decoder = tf.train.Saver(var_list=model.restore_decoder_param) 38 | saver = tf.train.Saver(var_list=tf.global_variables()) 39 | summary_writer = tf.summary.FileWriter(config.train.logdir, graph=model.graph) 40 | # saver_partial = tf.train.Saver(var_list=[v for v in tf.trainable_variables() if 'Adam' not in v.name]) 41 | 42 | with tf.Session(config=sess_config) as sess: 43 | # Initialize all variables. 44 | sess.run(tf.global_variables_initializer()) 45 | reload_pretrain_embedding=False 46 | try: 47 | #saver_partial.restore(sess, tf.train.latest_checkpoint(config.train.logdir)) 48 | #print('Restore partial model from %s.' % config.train.logdir) 49 | if config.train.restore_partial: 50 | saver_partial.restore(sess, tf.train.latest_checkpoint(config.train.logdir)) 51 | logger.info('Partial restore complete') 52 | if config.train.restore_decoder: 53 | saver_partial_decoder.restore(sess, tf.train.latest_checkpoint(config.train.decoder_logdir)) 54 | logger.info('Decoder restore complete') 55 | if not config.train.restore_embed and not config.train.restore_decoder_embed: 56 | reload_pretrain_embedding = True 57 | 58 | else: 59 | saver.restore(sess, tf.train.latest_checkpoint(config.train.logdir)) 60 | except: 61 | logger.info('Failed to reload model.') 62 | reload_pretrain_embedding=True 63 | 64 | if reload_pretrain_embedding: 65 | logger.info('reload the pretrained embeddings for the encoders') 66 | src_pretrained_embedding={} 67 | dst_pretrained_embedding={} 68 | try: 69 | 70 | for l in codecs.open(config.train.src_pretrain_wordemb_path, 'r', 'utf-8'): 71 | word_emb=l.strip().split() 72 | # print(word_emb) 73 | if len(word_emb)== config.hidden_units + 1: 74 | word, emb = word_emb[0], np.array(map(float, word_emb[1:])) 75 | src_pretrained_embedding[word]=emb 76 | 77 | for l in codecs.open(config.train.dst_pretrain_wordemb_path, 'r', 'utf-8'): 78 | word_emb=l.strip().split() 79 | if len(word_emb)==config.hidden_units + 1: 80 | word, emb = word_emb[0], np.array(map(float, word_emb[1:])) 81 | dst_pretrained_embedding[word]=emb 82 | 83 | logger.info('reload the word embedding done') 84 | 85 | tf.get_variable_scope().reuse_variables() 86 | src_embed_a=tf.get_variable('enc_aembedding/src_embedding/kernel') 87 | src_embed_b=tf.get_variable('enc_bembedding/src_embedding/kernel') 88 | 89 | dst_embed_a=tf.get_variable('dec_aembedding/dst_embedding/kernel') 90 | dst_embed_b=tf.get_variable('dec_bembedding/dst_embedding/kernel') 91 | 92 | count_a=0 93 | src_value_a=sess.run(src_embed_a) 94 | dst_value_a=sess.run(dst_embed_a) 95 | # print(src_value_a) 96 | for word in src_pretrained_embedding: 97 | if word in du.src2idx: 98 | id = du.src2idx[word] 99 | # print(id) 100 | src_value_a[id] = src_pretrained_embedding[word] 101 | dst_value_a[id] = src_pretrained_embedding[word] 102 | count_a += 1 103 | sess.run(src_embed_a.assign(src_value_a)) 104 | sess.run(dst_embed_a.assign(dst_value_a)) 105 | # print(sess.run(src_embed_a)) 106 | 107 | 108 | count_b=0 109 | src_value_b = sess.run(src_embed_b) 110 | dst_value_b = sess.run(dst_embed_b) 111 | for word in dst_pretrained_embedding: 112 | if word in du.dst2idx: 113 | id = du.dst2idx[word] 114 | # print(id) 115 | src_value_b[id] = dst_pretrained_embedding[word] 116 | dst_value_b[id] = dst_pretrained_embedding[word] 117 | count_b += 1 118 | 
sess.run(src_embed_b.assign(src_value_b)) 119 | sess.run(dst_embed_b.assign(dst_value_b)) 120 | 121 | logger.info('restore %d src_embedding and %d dst_embedding done' %(count_a, count_b)) 122 | 123 | except: 124 | logger.info('Failed to load the pretrained embeddings') 125 | 126 | # tmp_writer = codecs.open('tmp_test', 'w', 'utf-8') 127 | 128 | for epoch in range(1, config.train.num_epochs+1): 129 | #for batch in du.get_training_batches_with_buckets(): 130 | for batch in du.get_training_batches(shuffle=True): 131 | # swap the batch[0] and batch[1] according to whether the length of the sequence is odd or even 132 | # batch_swap=[] 133 | # swap_0 = np.arange(batch[0].shape[1]) 134 | # swap_1 = np.arange(batch[1].shape[1]) 135 | # 136 | # if len(swap_0) % 2 == 0: 137 | # swap_0[0::2]+=1 138 | # swap_0[1::2]-=1 139 | # else: 140 | # swap_0[0:-1:2]+=1 141 | # swap_0[1::2]-=1 142 | # 143 | # if len(swap_1) % 2 == 0: 144 | # swap_1[0::2]+=1 145 | # swap_1[1::2]-=1 146 | # else: 147 | # swap_1[0:-1:2] += 1 148 | # swap_1[1::2] -= 1 149 | # 150 | # batch_swap.append(batch[0].transpose()[swap_0].transpose()) 151 | # batch_swap.append(batch[1].transpose()[swap_1].transpose()) 152 | 153 | #print("Len 1", len(batch[2])) 154 | #print("Len 2", len(batch[3])) 155 | #print(batch_swap[0]) 156 | 157 | # randomly shuffle the batch[0] and batch[1] 158 | #batch_shuffle=[] 159 | #shuffle_0_indices = np.random.permutation(np.arange(batch[0].shape[1])) 160 | #shuffle_1_indices = np.random.permutation(np.arange(batch[1].shape[1])) 161 | #batch_shuffle.append(batch[0].transpose()[shuffle_0_indices].transpose()) 162 | #batch_shuffle.append(batch[1].transpose()[shuffle_1_indices].transpose()) 163 | 164 | 165 | def get_shuffle_k_indices(length, shuffle_k): 166 | shuffle_k_indices = [] 167 | rand_start = np.random.randint(shuffle_k) 168 | 169 | indices_list_start = list(np.random.permutation(np.arange(0, rand_start))) 170 | shuffle_k_indices.extend(indices_list_start) 171 | 172 | for i in range(rand_start, length, shuffle_k): 173 | if i + shuffle_k > length: 174 | indices_list_i = list(np.random.permutation(np.arange(i, length))) 175 | else: 176 | indices_list_i = list(np.random.permutation(np.arange(i, i + shuffle_k))) 177 | 178 | shuffle_k_indices.extend(indices_list_i) 179 | 180 | return np.array(shuffle_k_indices) 181 | 182 | batch_shuffle=[] 183 | shuffle_0_indices = get_shuffle_k_indices(batch[0].shape[1], config.train.shuffle_k) 184 | shuffle_1_indices = get_shuffle_k_indices(batch[1].shape[1], config.train.shuffle_k) 185 | #print(shuffle_0_indices) 186 | batch_shuffle.append(batch[0].transpose()[shuffle_0_indices].transpose()) 187 | batch_shuffle.append(batch[1].transpose()[shuffle_1_indices].transpose()) 188 | 189 | start_time = time.time() 190 | step = sess.run(model.global_step) 191 | 192 | step, lr, gnorm_aa, loss_aa, acc_aa, _ = sess.run( 193 | [model.global_step, model.learning_rate, model.grads_norm_aa, 194 | model.loss_aa, model.acc_aa, model.train_op_aa], 195 | feed_dict={model.src_a_pl: batch_shuffle[0], model.dst_a_pl: batch[0]}) 196 | 197 | step, lr, gnorm_bb, loss_bb, acc_bb, _ = sess.run( 198 | [model.global_step, model.learning_rate, model.grads_norm_bb, 199 | model.loss_bb, model.acc_bb, model.train_op_bb], 200 | feed_dict={model.src_b_pl: batch_shuffle[1], model.dst_b_pl: batch[1]}) 201 | 202 | 203 | # this step takes too much time 204 | generate_ab, generate_ba = sess.run( 205 | [model.generate_ab, model.generate_ba], 206 | feed_dict={model.src_a_pl: batch[0], model.src_b_pl: batch[1]}) 207 | 
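# Back-translation step: deal_generated_samples (from share_function.py) is assumed
# to cut each generated sample at its first end-of-sentence token and re-pad it.
# The cleaned b->a generations then serve as pseudo-sources for the ab loss (with the
# real b sentences in batch[3] as targets), and the a->b generations likewise feed
# the ba loss against the real a sentences in batch[2].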
208 | generate_ab_dealed, _ = deal_generated_samples(generate_ab, du.dst2idx) 209 | generate_ba_dealed, _ = deal_generated_samples(generate_ba, du.src2idx) 210 | 211 | #for sent in du.indices_to_words(batch[0], o='src'): 212 | # print(sent, file=tmp_writer) 213 | #for sent in du.indices_to_words(generate_ab_dealed, o='dst'): 214 | # print(sent, file=tmp_writer) 215 | 216 | step, acc_ab, loss_ab, _ = sess.run( 217 | [model.global_step, model.acc_ab, model.loss_ab, model.train_op_ab], 218 | feed_dict={model.src_a_pl:generate_ba_dealed, model.dst_b_pl: batch[3]}) 219 | 220 | step, acc_ba, loss_ba, _ = sess.run( 221 | [model.global_step, model.acc_ba, model.loss_ba, model.train_op_ba], 222 | feed_dict={model.src_b_pl:generate_ab_dealed, model.dst_a_pl: batch[2]}) 223 | 224 | if step % config.train.disp_freq == 0: 225 | logger.info('epoch: {0}\tstep: {1}\tlr: {2:.6f}\tgnorm: {3:.4f}\tloss: {4:.4f}' 226 | '\tacc: {5:.4f}\tcross_loss: {6:.4f}\tcross_acc: {7:.4f}\ttime: {8:.4f}' 227 | .format(epoch, step, lr, gnorm_bb, loss_bb, acc_bb, loss_ab, acc_ab, 228 | time.time() - start_time)) 229 | 230 | #if step % config.train.disp_freq == 0: 231 | # logger.info('epoch: {0}\tstep: {1}\tlr: {2:.6f}\tgnorm: {3:.4f}\tloss: {4:.4f}' 232 | # '\tacc: {5:.4f}\ttime: {6:.4f}' 233 | # .format(epoch, step, lr, gnorm_bb, loss_bb, acc_bb, 234 | # time.time() - start_time)) 235 | 236 | 237 | # Save model 238 | if step % config.train.save_freq == 0: 239 | mp = config.train.logdir + '/model_epoch_%d_step_%d' % (epoch, step) 240 | saver.save(sess, mp) 241 | logger.info('Save model in %s.' % mp) 242 | 243 | logger.info("Finish training.") 244 | 245 | 246 | if __name__ == '__main__': 247 | parser = ArgumentParser() 248 | parser.add_argument('-c', '--config', dest='config') 249 | args = parser.parse_args() 250 | # Read config 251 | config = AttrDict(yaml.load(open(args.config))) 252 | # Logger 253 | if not os.path.exists(config.train.logdir): 254 | os.makedirs(config.train.logdir) 255 | logging.basicConfig(filename=config.train.logdir+'/train.log', level=logging.DEBUG) 256 | console = logging.StreamHandler() 257 | console.setLevel(logging.INFO) 258 | logging.getLogger('').addHandler(console) 259 | # Train 260 | train(config) 261 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | import numpy as np 3 | import os 4 | import codecs 5 | import logging 6 | from tempfile import mkstemp 7 | from itertools import izip 8 | 9 | 10 | class AttrDict(dict): 11 | """ 12 | Dictionary whose keys can be accessed as attributes. 13 | """ 14 | 15 | def __init__(self, *args, **kwargs): 16 | super(AttrDict, self).__init__(*args, **kwargs) 17 | 18 | def __getattr__(self, item): 19 | if type(self[item]) is dict: 20 | self[item] = AttrDict(self[item]) 21 | return self[item] 22 | 23 | 24 | class DataUtil(object): 25 | """ 26 | Util class for creating batches for training and testing. 27 | """ 28 | def __init__(self, config): 29 | self.config = config 30 | self._logger = logging.getLogger('util') 31 | #self.load_vocab() 32 | 33 | def load_vocab(self, 34 | src_vocab=None, 35 | dst_vocab=None, 36 | src_vocab_size=None, 37 | dst_vocab_size=None): 38 | """ 39 | Load vocab from disk. 
The first four items in the vocab should be <PAD>, <UNK>, <S>, </S> 40 | """ 41 | 42 | def load_vocab_(path, vocab_size): 43 | vocab = [line.split()[0] for line in codecs.open(path, 'r', 'utf-8')] 44 | vocab = vocab[:vocab_size] 45 | assert len(vocab) == vocab_size 46 | word2idx = {word: idx for idx, word in enumerate(vocab)} 47 | idx2word = {idx: word for idx, word in enumerate(vocab)} 48 | return word2idx, idx2word 49 | 50 | if src_vocab and dst_vocab and src_vocab_size and dst_vocab_size: 51 | self._logger.debug('Load set vocabularies as %s and %s.' % (src_vocab, dst_vocab)) 52 | self.src2idx, self.idx2src = load_vocab_(src_vocab, src_vocab_size) 53 | self.dst2idx, self.idx2dst = load_vocab_(dst_vocab, dst_vocab_size) 54 | else: 55 | self._logger.debug('Load vocabularies %s and %s.' % (self.config.src_vocab, self.config.dst_vocab)) 56 | self.src2idx, self.idx2src = load_vocab_(self.config.src_vocab, self.config.src_vocab_size) 57 | self.dst2idx, self.idx2dst = load_vocab_(self.config.dst_vocab, self.config.dst_vocab_size) 58 | 59 | def get_training_batches(self, 60 | shuffle=True, 61 | set_train_src_path=None, 62 | set_train_dst_path=None, 63 | set_batch_size=None, 64 | set_max_length=None): 65 | """ 66 | Generate batches with fixed batch size. 67 | """ 68 | if set_train_src_path and set_train_dst_path: 69 | src_path=set_train_src_path 70 | dst_path=set_train_dst_path 71 | else: 72 | src_path = self.config.train.src_path 73 | dst_path = self.config.train.dst_path 74 | 75 | if set_batch_size: 76 | batch_size=set_batch_size 77 | else: 78 | batch_size = self.config.train.batch_size 79 | 80 | if set_max_length: 81 | max_length=set_max_length 82 | else: 83 | max_length = self.config.train.max_length 84 | 85 | # Shuffle the training files. 86 | #Need to edit the shuffle for unison shuffling 87 | if shuffle: 88 | #src_shuf_path, dst_shuf_path = self.shuffle([src_path, dst_path]) 89 | #src_shuf_path, dst_shuf_path, src_1_shuf_path, dst_1_shuf_path = self.shuffle([src_path, dst_path, self.config.train.src_path_1, self.config.train.dst_path_1]) 90 | if self.config.train.src_more: 91 | dst_shuf_path, dst_1_shuf_path, src_shuf_path, src_1_shuf_path = self.shuffle([dst_path, self.config.train.dst_path_1, src_path, self.config.train.src_path_1]) 92 | else: 93 | src_shuf_path, dst_shuf_path, src_1_shuf_path, dst_1_shuf_path = self.shuffle([src_path, dst_path, self.config.train.src_path_1, self.config.train.dst_path_1]) 94 | 95 | else: 96 | src_shuf_path = src_path 97 | dst_shuf_path = dst_path 98 | src_1_shuf_path = self.config.train.src_path_1 99 | dst_1_shuf_path = self.config.train.dst_path_1 100 | 101 | src_sents, dst_sents = [], [] 102 | src_1_sents, dst_1_sents = [], [] 103 | for src_sent, src_1_sent, dst_sent, dst_1_sent in izip(codecs.open(src_shuf_path, 'r', 'utf8',errors='ignore'), codecs.open(src_1_shuf_path, 'r', 'utf-8', errors='ignore'), 104 | codecs.open(dst_shuf_path, 'r', 'utf8', errors='ignore'), codecs.open(dst_1_shuf_path, 'r', 'utf-8', errors='ignore')): 105 | # If any sentence exceeds the max length, abandon the whole group. 106 | src_sent = src_sent.split() 107 | dst_sent = dst_sent.split() 108 | src_1_sent = src_1_sent.split() 109 | dst_1_sent = dst_1_sent.split() 110 | if len(src_sent) > max_length or len(dst_sent) > max_length or len(src_1_sent) > max_length or len(dst_1_sent) > max_length: 111 | continue 112 | src_sents.append(src_sent) 113 | dst_sents.append(dst_sent) 114 | src_1_sents.append(src_1_sent) 115 | dst_1_sents.append(dst_1_sent) 116 | # Create a padded batch. 
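# Each yield below is a 4-tuple of padded index matrices built from src_path,
# dst_path, src_path_1 and dst_path_1 respectively; the training scripts
# consume them as batch[0] .. batch[3].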
117 | if len(src_sents) >= batch_size: 118 | yield self.create_batch(src_sents, o='src'), self.create_batch(dst_sents, o='dst'), self.create_batch(src_1_sents, o='src'), self.create_batch(dst_1_sents, o='dst') 119 | src_sents, dst_sents = [], [] 120 | src_1_sents, dst_1_sents = [], [] 121 | 122 | if src_sents and dst_sents: 123 | yield self.create_batch(src_sents, o='src'), self.create_batch(dst_sents, o='dst'), self.create_batch(src_1_sents, o='src'), self.create_batch(dst_1_sents, o='dst') 124 | 125 | # Remove shuffled files when epoch finished. 126 | if shuffle: 127 | os.remove(src_shuf_path); os.remove(src_1_shuf_path) 128 | os.remove(dst_shuf_path); os.remove(dst_1_shuf_path) 129 | 130 | def get_training_batches_with_buckets(self, shuffle=True): 131 | """ 132 | Generate batches according to bucket setting. 133 | """ 134 | 135 | buckets = [(i, i) for i in range(10, 100, 5)] + [(self.config.train.max_length, self.config.train.max_length)] 136 | 137 | def select_bucket(sl, dl): 138 | for l1, l2 in buckets: 139 | if sl < l1 and dl < l2: 140 | return (l1, l2) 141 | return None 142 | 143 | # Shuffle the training files. 144 | src_path = self.config.train.src_path 145 | dst_path = self.config.train.dst_path 146 | if shuffle: 147 | self._logger.debug('Shuffle files %s and %s.' % (src_path, dst_path)) 148 | src_shuf_path, dst_shuf_path = self.shuffle([src_path, dst_path]) 149 | else: 150 | src_shuf_path = src_path 151 | dst_shuf_path = dst_path 152 | 153 | caches = {} 154 | for bucket in buckets: 155 | caches[bucket] = [[], [], 0, 0] # src sentences, dst sentences, src tokens, dst tokens 156 | 157 | for src_sent, dst_sent in izip(codecs.open(src_shuf_path, 'r', 'utf8'), 158 | codecs.open(dst_shuf_path, 'r', 'utf8')): 159 | src_sent = src_sent.split() 160 | dst_sent = dst_sent.split() 161 | #print("src sent", src_sent) 162 | #print("dst_sent", dst_sent) 163 | 164 | bucket = select_bucket(len(src_sent), len(dst_sent)) 165 | if bucket is None: # No bucket is selected when the sentence length exceeds the max length. 166 | continue 167 | 168 | caches[bucket][0].append(src_sent) 169 | caches[bucket][1].append(dst_sent) 170 | caches[bucket][2] += len(src_sent) 171 | caches[bucket][3] += len(dst_sent) 172 | 173 | if max(caches[bucket][2], caches[bucket][3]) >= self.config.train.tokens_per_batch: 174 | batch = (self.create_batch(caches[bucket][0], o='src'), self.create_batch(caches[bucket][1], o='dst')) 175 | self._logger.debug( 176 | 'Yield batch with source shape %s and target shape %s.' % (batch[0].shape, batch[1].shape)) 177 | yield batch 178 | caches[bucket] = [[], [], 0, 0] 179 | 180 | # Clean up remaining sentences. 181 | for bucket in buckets: 182 | # Ensure each device gets at least one sample. 183 | #print("len", self.config.train.devices.split(',')) 184 | #print("bucket", len(caches[bucket][0])) 185 | if len(caches[bucket][0]) > len(self.config.train.devices.split(',')): 186 | batch = (self.create_batch(caches[bucket][0], o='src'), self.create_batch(caches[bucket][1], o='dst')) 187 | self._logger.debug( 188 | 'Yield batch with source shape %s and target shape %s.' % (batch[0].shape, batch[1].shape)) 189 | yield batch 190 | 191 | # Remove shuffled files when epoch finished. 
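# (The temp files removed here are produced by the shuffle() helper below, which
# shells out to the external `shuf` command via os.system, so GNU coreutils must
# be available on PATH for shuffling to work.)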
192 | if shuffle: 193 | os.remove(src_shuf_path) 194 | os.remove(dst_shuf_path) 195 | 196 | @staticmethod 197 | def shuffle(list_of_files): 198 | #print(list_of_files) 199 | tf_os, tpath = mkstemp() 200 | tf = open(tpath, 'w') 201 | 202 | fds = [open(ff) for ff in list_of_files] 203 | 204 | for l in fds[0]: 205 | lines = [l.strip()] + [ff.readline().strip() for ff in fds[1:]] 206 | print("|||||".join(lines), file=tf) 207 | 208 | [ff.close() for ff in fds] 209 | tf.close() 210 | 211 | os.system('shuf %s > %s' % (tpath, tpath + '.shuf')) 212 | 213 | fds = [open(ff + '.{}.shuf'.format(os.getpid()), 'w') for ff in list_of_files] 214 | 215 | for l in open(tpath + '.shuf'): 216 | s = l.strip().split('|||||') 217 | for i, fd in enumerate(fds): 218 | print(s[i], file=fd) 219 | 220 | [ff.close() for ff in fds] 221 | 222 | os.remove(tpath) 223 | os.remove(tpath + '.shuf') 224 | 225 | return [ff + '.{}.shuf'.format(os.getpid()) for ff in list_of_files] 226 | 227 | def get_test_batches(self, 228 | set_src_path=None, 229 | set_batch=None, 230 | o='src'): 231 | if set_src_path and set_batch: 232 | src_path=set_src_path 233 | batch_size=set_batch 234 | else: 235 | src_path = self.config.test.src_path 236 | #src_path = self.config.test.dst_path 237 | batch_size = self.config.test.batch_size 238 | 239 | # Read batches from test files. 240 | src_sents = [] 241 | for src_sent in codecs.open(src_path, 'r', 'utf8'): 242 | src_sent = src_sent.split() 243 | src_sents.append(src_sent) 244 | # Create a padded batch. 245 | if len(src_sents) >= batch_size: 246 | yield self.create_batch(src_sents, o=o) 247 | src_sents = [] 248 | if src_sents: 249 | yield self.create_batch(src_sents, o=o) 250 | 251 | def get_test_batches_with_target(self, 252 | set_test_src_path=None, 253 | set_test_dst_path=None, 254 | set_batch_size=None, 255 | src_o='src', 256 | trg_o='dst'): 257 | """ 258 | Usually we don't need target sentences for test unless we want to compute PPL. 259 | Returns: 260 | Paired source and target batches. 261 | """ 262 | if set_test_src_path and set_test_dst_path and set_batch_size: 263 | src_path=set_test_src_path 264 | dst_path=set_test_dst_path 265 | batch_size=set_batch_size 266 | 267 | else: 268 | src_path = self.config.test.src_path 269 | dst_path = self.config.test.dst_path 270 | batch_size = self.config.test.batch_size 271 | 272 | # Read batches from test files. 273 | src_sents, dst_sents = [], [] 274 | for src_sent, dst_sent in izip(codecs.open(src_path, 'r', 'utf8'), 275 | codecs.open(dst_path, 'r', 'utf8')): 276 | src_sent = src_sent.split() 277 | dst_sent = dst_sent.split() 278 | src_sents.append(src_sent) 279 | dst_sents.append(dst_sent) 280 | # Create a padded batch. 281 | if len(src_sents) >= batch_size: 282 | yield self.create_batch(src_sents, o=src_o), self.create_batch(dst_sents, o=trg_o) 283 | src_sents, dst_sents = [], [] 284 | if src_sents: 285 | yield self.create_batch(src_sents, o=src_o), self.create_batch(dst_sents, o=trg_o) 286 | 287 | def create_batch(self, sents, o): 288 | # Convert words to indices. 289 | assert o in ('src', 'dst') 290 | word2idx = self.src2idx if o == 'src' else self.dst2idx 291 | indices = [] 292 | for sent in sents: 293 | x = [word2idx.get(word, 1) for word in (sent + [u"</S>"])] # 1: OOV, </S>: End of Text 294 | indices.append(x) 295 | 296 | # Pad to the same length. 
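# For example, with <PAD>=0 and </S>=3, the index lists [[7, 9, 3], [8, 3]] become
# the int32 matrix [[7, 9, 3], [8, 3, 0]]: each row is right-padded with zeros up
# to the longest sentence in the batch.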
297 | maxlen = max([len(s) for s in indices]) 298 | X = np.zeros([len(indices), maxlen], np.int32) 299 | for i, x in enumerate(indices): 300 | X[i, :len(x)] = x 301 | 302 | return X 303 | 304 | def indices_to_words(self, Y, o='dst'): 305 | assert o in ('src', 'dst') 306 | idx2word = self.idx2src if o == 'src' else self.idx2dst 307 | sents = [] 308 | for y in Y: # for each sentence 309 | sent = [] 310 | for i in y: # For each word 311 | if i == 3: # </S> 312 | break 313 | w = idx2word[i] 314 | sent.append(w) 315 | sents.append(' '.join(sent)) 316 | return sents 317 | 318 | def indices_to_words_del_pad(self, Y, o='dst'): 319 | assert o in ('src', 'dst') 320 | idx2word = self.idx2src if o == 'src' else self.idx2dst 321 | # <PAD> has index 0; padded positions are skipped below (i > 0). 322 | sents=[] 323 | for y in Y: 324 | sent= [] 325 | for i in y: 326 | if i > 0: 327 | w = idx2word[i] 328 | sent.append(w) 329 | sents.append(' '.join(sent)) 330 | return sents 331 | 332 | 333 | 334 | -------------------------------------------------------------------------------- /vocab.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | #!/usr/bin/python2 3 | from __future__ import print_function 4 | import codecs 5 | import re as regex 6 | import yaml 7 | from argparse import ArgumentParser 8 | from collections import Counter 9 | 10 | from utils import AttrDict 11 | 12 | 13 | def make_vocab(fpath, fname): 14 | """Constructs vocabulary. 15 | 16 | Args: 17 | fpath: A string. Input file path. 18 | fname: A string. Output file name. 19 | 20 | Writes vocabulary line by line to `fname`. 21 | """ 22 | word2cnt = Counter() 23 | for l in codecs.open(fpath, 'r', 'utf-8'): 24 | words = l.split() 25 | word2cnt.update(Counter(words)) 26 | with codecs.open(fname, 'w', 'utf-8') as fout: 27 | fout.write("{}\t1000000000\n{}\t1000000000\n{}\t1000000000\n{}\t1000000000\n".format("<PAD>", "<UNK>", "<S>", "</S>")) 28 | for word, cnt in word2cnt.most_common(len(word2cnt)): 29 | fout.write(u"{}\t{}\n".format(word, cnt)) 30 | 31 | if __name__ == '__main__': 32 | parser = ArgumentParser() 33 | parser.add_argument('-c', '--config', dest='config') 34 | args = parser.parse_args() 35 | # Read config 36 | config = AttrDict(yaml.load(open(args.config))) 37 | make_vocab(config.src_gen, config.src_vocab) 38 | make_vocab(config.dst_gen, config.dst_vocab) 39 | print("Done") 40 | --------------------------------------------------------------------------------
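For reference, a minimal sketch of how `vocab.py` and `utils.py` fit together. This is not a file in the repository: it assumes the config defines the same keys the scripts above read (`src_gen`, `dst_gen`, the vocab paths, and `src_vocab_size_a`/`src_vocab_size_b`), and `data/question_train.txt` is just one of the data files shipped with the repo.

```
from __future__ import print_function
import yaml
from utils import AttrDict, DataUtil
from vocab import make_vocab

# Assumption: this config provides src_gen/dst_gen and the vocab paths/sizes.
config = AttrDict(yaml.load(open('configs/config_generator_train_BPE.yaml')))

# Write "<word>\t<count>" vocab files with <PAD>/<UNK>/<S>/</S> pinned to ids 0-3.
make_vocab(config.src_gen, config.src_vocab)
make_vocab(config.dst_gen, config.dst_vocab)

du = DataUtil(config=config)
du.load_vocab(src_vocab=config.src_vocab,
              dst_vocab=config.dst_vocab,
              src_vocab_size=config.src_vocab_size_a,
              dst_vocab_size=config.src_vocab_size_b)

# Round-trip one batch: text -> padded id matrix -> text.
for batch in du.get_test_batches(set_src_path='data/question_train.txt', set_batch=4):
    for sent in du.indices_to_words(batch, o='src'):
        print(sent)
    break
```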