├── .idea ├── Adversarial-Learning-for-Neural-Dialogue-Generation.iml ├── inspectionProfiles │ └── profiles_settings.xml ├── misc.xml ├── modules.xml ├── vcs.xml └── workspace.xml ├── README.md ├── config.py ├── dis_pre_train.py ├── discriminator.py ├── gen_data.py ├── gen_pre_train.py ├── generator.py ├── model ├── dis_model.py ├── gen_model.py └── seq2seq.py ├── test.py ├── train.py ├── util.py └── 说明文档.pdf /.idea/Adversarial-Learning-for-Neural-Dialogue-Generation.iml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 11 | -------------------------------------------------------------------------------- /.idea/inspectionProfiles/profiles_settings.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 7 | -------------------------------------------------------------------------------- /.idea/misc.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | -------------------------------------------------------------------------------- /.idea/modules.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | -------------------------------------------------------------------------------- /.idea/vcs.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | -------------------------------------------------------------------------------- /.idea/workspace.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | 27 | 28 | 30 | 31 | 32 | 33 | true 34 | DEFINITION_ORDER 35 | 36 | 37 | 42 | 43 | 44 | 45 | 46 | 47 | 48 | 49 | 50 | 51 | 52 | 53 | 54 | 55 | 56 | 57 | 58 | 59 | 60 | 61 | 62 | 63 | 64 | 65 | 68 | 69 | 72 | 73 | 74 | 75 | 76 | 77 | 78 | 79 | 80 | 81 | 82 | 83 | 84 | 94 | 95 | 96 | 
98 | 99 | 100 | 101 | 1552626005281 102 | 106 | 107 | 108 | 109 | 110 | 111 | 112 | 113 | 114 | 115 | 116 | 117 | 118 | 119 | 120 | 121 | 122 | 123 | 124 | 125 | 126 | 127 | 128 | 129 | 130 | 131 | 132 | 133 | 134 | 135 | 136 | 138 | 139 | 141 | 142 | 143 | 144 | 145 | 146 | 147 | 148 | 149 | 150 | 151 | 152 | 153 | 154 | 155 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Adversarial Learning for Neural Dialogue Generation 2 | 3 | the paper: Adversarial Learning for Neural Dialogue Generation https://arxiv.org/pdf/1701.06547.pdf 4 | 5 | 将数据集中的chitchat.train.query、chitchat.train.answer、chitchat.dev.query和chitchat.dev.answer放入gen_data文件夹下即可训练,dis_data初始为空,经过下面步骤2会自动生成。 6 | 7 | 训练步骤: 8 | 9 | 1.python gen_pre_train.py 预训练生成器。需要有gen_data/chitchat.dev.answer等四个文件才能运行,也可在config.py中修改各种设置。运行后,gen_data/checkpoints中将储存经过预训练的权值文件。 10 | 11 | 2.python gen_data.py 读取生成器预训练后的权值,为判别器预训练过程生成数据。运行后,将会生成dis_data/dev.answer等六个文件。 12 | 13 | 3.python dis_pre_train.py 预训练判别器。需要有上一步生成的六个文件才能运行。运行后,dis_data/checkpoints中将储存经过预训练的权值文件。 14 | 15 | 4.python train.py读取经过预训练的生成器和判别器权值,进行对抗训练,并将权值文件保存。 16 | 17 | 5.测试程序为test.py。执行python test.py后,程序将读取gen_data/checkpoints中训练好的权值文件,进行人机交互——程序将等待用户输入,然后根据用户输入,输出回应,直到用户输入Ctrl+Z退出。 -------------------------------------------------------------------------------- /config.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | # configuration options for discriminator network 4 | class disc_config(object): 5 | batch_size = 256 6 | lr = 0.2 7 | lr_decay = 0.9 8 | vocab_size = 35000 9 | embed_dim = 512 10 | steps_per_checkpoint = 100 11 | num_layers = 2 12 | train_dir = './dis_data/' 13 | name_model = "disc_model" 14 | name_loss = "disc_loss" 15 | max_len = 50 16 | piece_size = batch_size * steps_per_checkpoint 17 | piece_dir = "./dis_data/batch_piece/" 18 | valid_num = 100 
19 | init_scale = 0.1 20 | num_class = 2 21 | keep_prob = 0.5 22 | max_grad_norm = 5 23 | buckets = [(5, 10), (10, 15), (20, 25), (40, 50)] 24 | dis_pre_train_step = 80000 25 | 26 | 27 | # configuration options for generator network 28 | class gen_config(object): 29 | beam_size = 7 30 | learning_rate = 0.5 31 | learning_rate_decay_factor = 0.99 32 | max_gradient_norm = 5.0 33 | batch_size = 128 34 | emb_dim = 512 35 | num_layers = 2 36 | vocab_size = 35000 37 | train_dir = "./gen_data/" 38 | name_model = "st_model" 39 | name_loss = "gen_loss" 40 | teacher_loss = "teacher_loss" 41 | reward_name = "reward" 42 | max_train_data_size = 0 43 | steps_per_checkpoint = 100 44 | buckets = [(5, 10), (10, 15), (20, 25), (40, 50)] 45 | buckets_concat = [(5, 10), (10, 15), (20, 25), (40, 50), (100, 50)] 46 | gen_pre_train_step = 400000 -------------------------------------------------------------------------------- /dis_pre_train.py: -------------------------------------------------------------------------------- 1 | import config 2 | from discriminator import * 3 | import tensorflow as tf 4 | import numpy as np 5 | import os 6 | import time 7 | import random 8 | from six.moves import xrange 9 | import sys 10 | 11 | os.environ["CUDA_VISIBLE_DEVICES"] = '4' 12 | 13 | def dis_pre_train(config_disc, config_evl): 14 | config_evl.keep_prob = 1.0 15 | 16 | print("begin training") 17 | 18 | with tf.Session() as session: 19 | 20 | print("prepare_data") 21 | query_set, answer_set, gen_set = prepare_data(config_disc) 22 | 23 | train_bucket_sizes = [len(query_set[b]) for b in xrange(len(config_disc.buckets))] 24 | train_total_size = float(sum(train_bucket_sizes)) 25 | train_buckets_scale = [sum(train_bucket_sizes[:i + 1]) / train_total_size 26 | for i in xrange(len(train_bucket_sizes))] 27 | for set in query_set: 28 | print("set length: ", len(set)) 29 | 30 | model = create_model(session, config_disc, name_scope=config_disc.name_model) 31 | 32 | step_time, loss = 0.0, 0.0 33 | current_step 
= 0 34 | step_loss_summary = tf.Summary() 35 | 36 | train_step = config_disc.dis_pre_train_step 37 | while train_step>0: 38 | train_step -= 1 39 | random_number_01 = np.random.random_sample() 40 | bucket_id = min([i for i in xrange(len(train_buckets_scale)) 41 | if train_buckets_scale[i] > random_number_01]) 42 | 43 | start_time = time.time() 44 | 45 | b_query, b_answer, b_gen = query_set[bucket_id], answer_set[bucket_id], gen_set[bucket_id] 46 | 47 | train_query, train_answer, train_labels = hier_get_batch(config_disc, len(b_query)-1, b_query, b_answer, b_gen) 48 | 49 | train_query = np.transpose(train_query) 50 | train_answer = np.transpose(train_answer) 51 | 52 | feed_dict = {} 53 | for i in xrange(config_disc.buckets[bucket_id][0]): 54 | feed_dict[model.query[i].name] = train_query[i] 55 | for i in xrange(config_disc.buckets[bucket_id][1]): 56 | feed_dict[model.answer[i].name] = train_answer[i] 57 | feed_dict[model.target.name] = train_labels 58 | 59 | fetches = [model.b_train_op[bucket_id], model.b_logits[bucket_id], model.b_loss[bucket_id], model.target] 60 | train_op, logits, step_loss, target = session.run(fetches, feed_dict) 61 | 62 | step_time += (time.time() - start_time) / config_disc.steps_per_checkpoint 63 | loss += step_loss /config_disc.steps_per_checkpoint 64 | current_step += 1 65 | 66 | if current_step % config_disc.steps_per_checkpoint == 0: 67 | 68 | disc_loss_value = step_loss_summary.value.add() 69 | disc_loss_value.tag = config_disc.name_loss 70 | disc_loss_value.simple_value = float(loss) 71 | 72 | print("logits shape: ", np.shape(logits)) 73 | 74 | # softmax operation 75 | logits = np.transpose(softmax(np.transpose(logits))) 76 | 77 | reward = 0.0 78 | for logit, label in zip(logits, train_labels): 79 | reward += logit[1] # only for true probility 80 | reward = reward / len(train_labels) 81 | print("reward: ", reward) 82 | 83 | 84 | print("current_step: %d, step_loss: %.4f" %(current_step, step_loss)) 85 | 86 | 87 | if current_step % 
(config_disc.steps_per_checkpoint * 3) == 0: 88 | print("current_step: %d, save_model" % (current_step)) 89 | disc_ckpt_dir = os.path.abspath(os.path.join(config_disc.train_dir, "checkpoints")) 90 | if not os.path.exists(disc_ckpt_dir): 91 | os.makedirs(disc_ckpt_dir) 92 | disc_model_path = os.path.join(disc_ckpt_dir, "disc.model") 93 | model.saver.save(session, disc_model_path, global_step=model.global_step) 94 | 95 | 96 | step_time, loss = 0.0, 0.0 97 | sys.stdout.flush() 98 | 99 | if __name__ == "__main__": 100 | dis_pre_train(config.disc_config, config.disc_config) -------------------------------------------------------------------------------- /discriminator.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import numpy as np 3 | import os 4 | import time 5 | import random 6 | from six.moves import xrange 7 | from model.dis_model import Hier_rnn_model 8 | import util 9 | 10 | from tensorflow.python.platform import gfile 11 | import sys 12 | 13 | def hier_read_data(config, query_path, answer_path, gen_path): 14 | query_set = [[] for _ in config.buckets] 15 | answer_set = [[] for _ in config.buckets] 16 | gen_set = [[] for _ in config.buckets] 17 | with gfile.GFile(query_path, mode="r") as query_file: 18 | with gfile.GFile(answer_path, mode="r") as answer_file: 19 | with gfile.GFile(gen_path, mode="r") as gen_file: 20 | query, answer, gen = query_file.readline(), answer_file.readline(), gen_file.readline() 21 | counter = 0 22 | while query and answer and gen: 23 | counter += 1 24 | if counter % 100000 == 0: 25 | print(" reading disc_data line %d" % counter) 26 | query = [int(id) for id in query.strip().split()] 27 | answer = [int(id) for id in answer.strip().split()] 28 | gen = [int(id) for id in gen.strip().split()] 29 | for i, (query_size, answer_size) in enumerate(config.buckets): 30 | if len(query) <= query_size and len(answer) <= answer_size and len(gen) <= answer_size: 31 | query = 
query[:query_size] + [util.PAD_ID] * (query_size - len(query) if query_size > len(query) else 0) 32 | query_set[i].append(query) 33 | answer = answer[:answer_size] + [util.PAD_ID] * (answer_size - len(answer) if answer_size > len(answer) else 0) 34 | answer_set[i].append(answer) 35 | gen = gen[:answer_size] + [util.PAD_ID] * (answer_size - len(gen) if answer_size > len(gen) else 0) 36 | gen_set[i].append(gen) 37 | query, answer, gen = query_file.readline(), answer_file.readline(), gen_file.readline() 38 | 39 | return query_set, answer_set, gen_set 40 | 41 | 42 | def hier_get_batch(config, max_set, query_set, answer_set, gen_set): 43 | batch_size = config.batch_size 44 | if batch_size % 2 == 1: 45 | return IOError("Error") 46 | train_query = [] 47 | train_answer = [] 48 | train_labels = [] 49 | half_size = int(batch_size / 2) 50 | for _ in range(half_size): 51 | index = random.randint(0, max_set) 52 | train_query.append(query_set[index]) 53 | train_answer.append(answer_set[index]) 54 | train_labels.append(1) 55 | train_query.append(query_set[index]) 56 | train_answer.append(gen_set[index]) 57 | train_labels.append(0) 58 | return train_query, train_answer, train_labels 59 | 60 | 61 | def create_model(sess, config, name_scope, initializer=None): 62 | with tf.variable_scope(name_or_scope=name_scope, initializer=initializer): 63 | model = Hier_rnn_model(config=config, name_scope=name_scope) 64 | disc_ckpt_dir = os.path.abspath(os.path.join(config.train_dir, "checkpoints")) 65 | ckpt = tf.train.get_checkpoint_state(disc_ckpt_dir) 66 | if ckpt and tf.train.checkpoint_exists(ckpt.model_checkpoint_path): 67 | print("Reading Hier Disc model parameters from %s" % ckpt.model_checkpoint_path) 68 | model.saver.restore(sess, ckpt.model_checkpoint_path) 69 | else: 70 | print("Created Hier Disc model with fresh parameters.") 71 | disc_global_variables = [gv for gv in tf.global_variables() if name_scope in gv.name] 72 | sess.run(tf.variables_initializer(disc_global_variables)) 73 | 
return model 74 | 75 | 76 | def prepare_data(config): 77 | train_path = os.path.join(config.train_dir, "train") 78 | voc_file_path = [train_path + ".query", train_path + ".answer", train_path + ".gen"] 79 | vocab_path = os.path.join(config.train_dir, "vocab%d.all" % config.vocab_size) 80 | util.create_vocabulary(vocab_path, voc_file_path, config.vocab_size) 81 | vocab, rev_vocab = util.initialize_vocabulary(vocab_path) 82 | 83 | print("Preparing train disc_data in %s" % config.train_dir) 84 | train_query_path, train_answer_path, train_gen_path, dev_query_path, dev_answer_path, dev_gen_path = \ 85 | util.hier_prepare_disc_data(config.train_dir, vocab, config.vocab_size) 86 | query_set, answer_set, gen_set = hier_read_data(config, train_query_path, train_answer_path, train_gen_path) 87 | return query_set, answer_set, gen_set 88 | 89 | 90 | def softmax(x): 91 | prob = np.exp(x) / np.sum(np.exp(x), axis=0) 92 | return prob 93 | -------------------------------------------------------------------------------- /gen_data.py: -------------------------------------------------------------------------------- 1 | import config 2 | import tensorflow as tf 3 | from generator import * 4 | import random 5 | import numpy as np 6 | from six.moves import xrange 7 | import util 8 | 9 | os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" 10 | os.environ["CUDA_VISIBLE_DEVICES"] = "1" 11 | 12 | def gen_data(gen_config): 13 | vocab, rev_vocab, dev_set, train_set = prepare_data(gen_config) 14 | 15 | train_bucket_sizes = [len(train_set[b]) for b in xrange(len(gen_config.buckets))] 16 | train_total_size = float(sum(train_bucket_sizes)) 17 | train_buckets_scale = [sum(train_bucket_sizes[:i + 1]) / train_total_size 18 | for i in xrange(len(train_bucket_sizes))] 19 | 20 | with tf.Session() as sess: 21 | model = create_model(sess, gen_config, forward_only=True, name_scope=gen_config.name_model) 22 | 23 | disc_train_query = open("dis_data/train.query", "w", encoding='utf-8') 24 | disc_train_answer = 
open("dis_data/train.answer", "w", encoding='utf-8') 25 | disc_train_gen = open("dis_data/train.gen", "w", encoding='utf-8') 26 | 27 | disc_dev_query = open("dis_data/dev.query", "w", encoding='utf-8') 28 | disc_dev_answer = open("dis_data/dev.answer", "w", encoding='utf-8') 29 | disc_dev_gen = open("dis_data/dev.gen", "w", encoding='utf-8') 30 | 31 | num_step = 0 32 | while num_step < 10000: 33 | print("generating num_step: ", num_step) 34 | random_number_01 = np.random.random_sample() 35 | bucket_id = min([i for i in xrange(len(train_buckets_scale)) 36 | if train_buckets_scale[i] > random_number_01]) 37 | 38 | encoder_inputs, decoder_inputs, target_weights, batch_source_encoder, batch_source_decoder = \ 39 | model.get_batch(train_set, bucket_id, gen_config.batch_size) 40 | 41 | _, _, out_logits = model.step(sess, encoder_inputs, decoder_inputs, target_weights, bucket_id, 42 | forward_only=True) 43 | 44 | tokens = [] 45 | resps = [] 46 | for seq in out_logits: 47 | token = [] 48 | for t in seq: 49 | token.append(int(np.argmax(t, axis=0))) 50 | tokens.append(token) 51 | tokens_t = [] 52 | for col in range(len(tokens[0])): 53 | tokens_t.append([tokens[row][col] for row in range(len(tokens))]) 54 | 55 | for seq in tokens_t: 56 | if util.EOS_ID in seq: 57 | resps.append(seq[:seq.index(util.EOS_ID)][:gen_config.buckets[bucket_id][1]]) 58 | else: 59 | resps.append(seq[:gen_config.buckets[bucket_id][1]]) 60 | 61 | if num_step % 100 == 0: 62 | for query, answer, resp in zip(batch_source_encoder, batch_source_decoder, resps): 63 | 64 | answer_str = " ".join([str(rev_vocab[an]) for an in answer][:-1]) 65 | disc_dev_answer.write(answer_str) 66 | disc_dev_answer.write("\n") 67 | 68 | query_str = " ".join([str(rev_vocab[qu]) for qu in query]) 69 | disc_dev_query.write(query_str) 70 | disc_dev_query.write("\n") 71 | 72 | resp_str = " ".join([tf.compat.as_str(rev_vocab[output]) for output in resp]) 73 | 74 | disc_dev_gen.write(resp_str) 75 | disc_dev_gen.write("\n") 76 | else: 
77 | for query, answer, resp in zip(batch_source_encoder, batch_source_decoder, resps): 78 | 79 | answer_str = " ".join([str(rev_vocab[an]) for an in answer][:-1]) 80 | disc_train_answer.write(answer_str) 81 | disc_train_answer.write("\n") 82 | 83 | query_str = " ".join([str(rev_vocab[qu]) for qu in query]) 84 | disc_train_query.write(query_str) 85 | disc_train_query.write("\n") 86 | 87 | resp_str = " ".join([tf.compat.as_str(rev_vocab[output]) for output in resp]) 88 | 89 | disc_train_gen.write(resp_str) 90 | disc_train_gen.write("\n") 91 | 92 | num_step += 1 93 | 94 | disc_train_gen.close() 95 | disc_train_query.close() 96 | disc_train_answer.close() 97 | disc_dev_gen.close() 98 | disc_dev_query.close() 99 | disc_dev_answer.close() 100 | pass 101 | 102 | if __name__ == "__main__": 103 | gen_data(config.gen_config) -------------------------------------------------------------------------------- /gen_pre_train.py: -------------------------------------------------------------------------------- 1 | import config 2 | import os 3 | import tensorflow as tf 4 | from generator import * 5 | import numpy as np 6 | import time 7 | import sys 8 | import math 9 | import random 10 | from six.moves import xrange 11 | 12 | os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" 13 | os.environ["CUDA_VISIBLE_DEVICES"] = "1" 14 | 15 | def gen_pre_train(gen_config): 16 | vocab, rev_vocab, dev_set, train_set = prepare_data(gen_config) 17 | for b_set in train_set: 18 | print("b_set: ", len(b_set)) 19 | 20 | with tf.Session() as sess: 21 | # Create model. 22 | print("Creating %d layers of %d units." 
% (gen_config.num_layers, gen_config.emb_dim)) 23 | model = create_model(sess, gen_config, forward_only=False, name_scope=gen_config.name_model) 24 | 25 | train_bucket_sizes = [len(train_set[b]) for b in xrange(len(gen_config.buckets))] 26 | train_total_size = float(sum(train_bucket_sizes)) 27 | train_buckets_scale = [sum(train_bucket_sizes[:i + 1]) / train_total_size 28 | for i in xrange(len(train_bucket_sizes))] 29 | 30 | # This is the training loop. 31 | step_time, loss = 0.0, 0.0 32 | current_step = 0 33 | 34 | gen_loss_summary = tf.Summary() 35 | 36 | train_step = gen_config.gen_pre_train_step 37 | while train_step>0: 38 | train_step -= 1 39 | # Choose a bucket according to disc_data distribution. We pick a random number 40 | # in [0, 1] and use the corresponding interval in train_buckets_scale. 41 | random_number_01 = np.random.random_sample() 42 | bucket_id = min([i for i in xrange(len(train_buckets_scale)) if train_buckets_scale[i] > random_number_01]) 43 | 44 | # Get a batch and make a step. 45 | start_time = time.time() 46 | encoder_inputs, decoder_inputs, target_weights, batch_source_encoder, batch_source_decoder = model.get_batch( 47 | train_set, bucket_id, gen_config.batch_size) 48 | 49 | _, step_loss, _ = model.step(sess, encoder_inputs, decoder_inputs, target_weights, bucket_id, forward_only=False) 50 | 51 | step_time += (time.time() - start_time) / gen_config.steps_per_checkpoint 52 | loss += step_loss / gen_config.steps_per_checkpoint 53 | current_step += 1 54 | 55 | # Once in a while, we save checkpoint, print statistics, and run evals. 56 | if current_step % gen_config.steps_per_checkpoint == 0: 57 | 58 | bucket_value = gen_loss_summary.value.add() 59 | bucket_value.tag = gen_config.name_loss 60 | bucket_value.simple_value = float(loss) 61 | 62 | # Print statistics for the previous epoch. 
63 | perplexity = math.exp(loss) if loss < 300 else float('inf') 64 | print ("global step %d learning rate %.4f step-time %.2f perplexity " 65 | "%.2f" % (model.global_step.eval(), model.learning_rate.eval(), 66 | step_time, perplexity)) 67 | 68 | # Save checkpoint and zero timer and loss. 69 | if current_step % (gen_config.steps_per_checkpoint * 3) == 0: 70 | print("current_step: %d, save model" %(current_step)) 71 | gen_ckpt_dir = os.path.abspath(os.path.join(gen_config.train_dir, "checkpoints")) 72 | if not os.path.exists(gen_ckpt_dir): 73 | os.makedirs(gen_ckpt_dir) 74 | checkpoint_path = os.path.join(gen_ckpt_dir, "chitchat.model") 75 | model.saver.save(sess, checkpoint_path, global_step=model.global_step) 76 | 77 | 78 | step_time, loss = 0.0, 0.0 79 | sys.stdout.flush() 80 | 81 | if __name__ == "__main__": 82 | gen_pre_train(config.gen_config) -------------------------------------------------------------------------------- /generator.py: -------------------------------------------------------------------------------- 1 | import os 2 | import random 3 | import sys 4 | import time 5 | import heapq 6 | import tensorflow.python.platform 7 | import numpy as np 8 | from six.moves import xrange 9 | import tensorflow as tf 10 | import util 11 | import model.gen_model as seq2seq_model 12 | from tensorflow.python.platform import gfile 13 | 14 | def read_data(config, source_path, target_path, max_size=None): 15 | data_set = [[] for _ in config.buckets] 16 | with gfile.GFile(source_path, mode="r") as source_file: 17 | with gfile.GFile(target_path, mode="r") as target_file: 18 | source, target = source_file.readline(), target_file.readline() 19 | counter = 0 20 | while source and target and (not max_size or counter < max_size): 21 | counter += 1 22 | if counter % 100000 == 0: 23 | print(" reading disc_data line %d" % counter) 24 | sys.stdout.flush() 25 | source_ids = [int(x) for x in source.split()] 26 | target_ids = [int(x) for x in target.split()] 27 | 
target_ids.append(util.EOS_ID) 28 | for bucket_id, (source_size, target_size) in enumerate(config.buckets): 29 | if len(source_ids) < source_size and len(target_ids) < target_size: 30 | data_set[bucket_id].append([source_ids, target_ids]) 31 | break 32 | source, target = source_file.readline(), target_file.readline() 33 | return data_set 34 | 35 | 36 | def create_model(session, gen_config, forward_only, name_scope, initializer=None): 37 | """Create translation model and initialize or load parameters in session.""" 38 | with tf.variable_scope(name_or_scope=name_scope, initializer=initializer): 39 | model = seq2seq_model.Seq2SeqModel(gen_config, name_scope=name_scope, forward_only=forward_only) 40 | gen_ckpt_dir = os.path.abspath(os.path.join(gen_config.train_dir, "checkpoints")) 41 | ckpt = tf.train.get_checkpoint_state(gen_ckpt_dir) 42 | if ckpt and tf.train.checkpoint_exists(ckpt.model_checkpoint_path): 43 | print("Reading Gen model parameters from %s" % ckpt.model_checkpoint_path) 44 | model.saver.restore(session, ckpt.model_checkpoint_path) 45 | else: 46 | print("Created Gen model with fresh parameters.") 47 | gen_global_variables = [gv for gv in tf.global_variables() if name_scope in gv.name] 48 | session.run(tf.variables_initializer(gen_global_variables)) 49 | return model 50 | 51 | def prepare_data(gen_config): 52 | train_path = os.path.join(gen_config.train_dir, "chitchat.train") 53 | voc_file_path = [train_path+".answer", train_path+".query"] 54 | vocab_path = os.path.join(gen_config.train_dir, "vocab%d.all" % gen_config.vocab_size) 55 | util.create_vocabulary(vocab_path, voc_file_path, gen_config.vocab_size) 56 | vocab, rev_vocab = util.initialize_vocabulary(vocab_path) 57 | 58 | print("Preparing Chitchat gen_data in %s" % gen_config.train_dir) 59 | train_query, train_answer, dev_query, dev_answer = util.prepare_chitchat_data( 60 | gen_config.train_dir, vocab, gen_config.vocab_size) 61 | 62 | # Read disc_data into buckets and compute their sizes. 
63 | print ("Reading development and training gen_data (limit: %d)." 64 | % gen_config.max_train_data_size) 65 | dev_set = read_data(gen_config, dev_query, dev_answer) 66 | train_set = read_data(gen_config, train_query, train_answer, gen_config.max_train_data_size) 67 | 68 | return vocab, rev_vocab, dev_set, train_set 69 | -------------------------------------------------------------------------------- /model/dis_model.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import numpy as np 3 | from six.moves import xrange 4 | 5 | 6 | class Hier_rnn_model(object): 7 | def __init__(self, config, name_scope, dtype=tf.float32): 8 | emb_dim = config.embed_dim 9 | num_layers = config.num_layers 10 | vocab_size = config.vocab_size 11 | num_class = config.num_class 12 | buckets = config.buckets 13 | self.lr = config.lr 14 | self.global_step = tf.Variable(initial_value=0, trainable=False) 15 | 16 | self.query = [] 17 | self.answer = [] 18 | for i in range(buckets[-1][0]): 19 | self.query.append(tf.placeholder(dtype=tf.int32, shape=[None], name="query{0}".format(i))) 20 | for i in xrange(buckets[-1][1]): 21 | self.answer.append(tf.placeholder(dtype=tf.int32, shape=[None], name="answer{0}".format(i))) 22 | 23 | self.target = tf.placeholder(dtype=tf.int64, shape=[None], name="target") 24 | 25 | encoder_cell = tf.nn.rnn_cell.BasicLSTMCell(emb_dim) 26 | encoder_mutil = tf.contrib.rnn.MultiRNNCell([encoder_cell] * num_layers) 27 | encoder_emb = tf.contrib.rnn.EmbeddingWrapper(encoder_mutil, embedding_classes=vocab_size, embedding_size=emb_dim) 28 | 29 | context_cell = tf.nn.rnn_cell.BasicLSTMCell(num_units=emb_dim) 30 | context_multi = tf.contrib.rnn.MultiRNNCell([context_cell] * num_layers) 31 | 32 | self.b_query_state = [] 33 | self.b_answer_state = [] 34 | self.b_state = [] 35 | self.b_logits = [] 36 | self.b_loss = [] 37 | self.b_train_op = [] 38 | for i, bucket in enumerate(buckets): 39 | with 
tf.variable_scope(name_or_scope="Hier_RNN_encoder", reuse=True if i > 0 else None) as var_scope: 40 | query_output, query_state = tf.contrib.rnn.static_rnn(encoder_emb, inputs=self.query[:bucket[0]], dtype=tf.float32) 41 | var_scope.reuse_variables() 42 | answer_output, answer_state = tf.contrib.rnn.static_rnn(encoder_emb, inputs=self.answer[:bucket[1]], dtype=tf.float32) 43 | self.b_query_state.append(query_state) 44 | self.b_answer_state.append(answer_state) 45 | context_input = [query_state[-1][1], answer_state[-1][1]] 46 | 47 | with tf.variable_scope(name_or_scope="Hier_RNN_context", reuse=True if i > 0 else None): 48 | output, state = tf.contrib.rnn.static_rnn(context_multi, context_input, dtype=tf.float32) 49 | self.b_state.append(state) 50 | top_state = state[-1][1] # [batch_size, emb_dim] 51 | 52 | with tf.variable_scope("Softmax_layer_and_output", reuse=True if i > 0 else None): 53 | softmax_w = tf.get_variable("softmax_w", [emb_dim, num_class], dtype=tf.float32) 54 | softmax_b = tf.get_variable("softmax_b", [num_class], dtype=tf.float32) 55 | logits = tf.matmul(top_state, softmax_w) + softmax_b 56 | self.b_logits.append(logits) 57 | 58 | with tf.name_scope("loss"): 59 | loss = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=logits, labels=self.target) 60 | mean_loss = tf.reduce_mean(loss) 61 | self.b_loss.append(mean_loss) 62 | 63 | with tf.name_scope("gradient_descent"): 64 | disc_params = [var for var in tf.trainable_variables() if name_scope in var.name] 65 | grads, norm = tf.clip_by_global_norm(tf.gradients(mean_loss, disc_params), config.max_grad_norm) 66 | #optimizer = tf.train.GradientDescentOptimizer(self.lr) 67 | optimizer = tf.train.AdamOptimizer(learning_rate=0.001) 68 | train_op = optimizer.apply_gradients(zip(grads, disc_params), global_step=self.global_step) 69 | self.b_train_op.append(train_op) 70 | 71 | all_variables = [v for v in tf.global_variables() if name_scope in v.name] 72 | self.saver = tf.train.Saver(all_variables) 73 | 
-------------------------------------------------------------------------------- /model/gen_model.py: -------------------------------------------------------------------------------- 1 | import random 2 | import sys 3 | import numpy as np 4 | from six.moves import xrange 5 | import tensorflow as tf 6 | import util 7 | import model.seq2seq as rl_seq2seq 8 | from tensorflow.python.ops import variable_scope 9 | 10 | class Seq2SeqModel(object): 11 | 12 | def __init__(self, config, name_scope, forward_only=False, num_samples=256, dtype=tf.float32): 13 | 14 | source_vocab_size = config.vocab_size 15 | target_vocab_size = config.vocab_size 16 | emb_dim = config.emb_dim 17 | 18 | self.buckets = config.buckets 19 | self.learning_rate = tf.Variable(float(config.learning_rate), trainable=False, dtype=dtype) 20 | self.learning_rate_decay_op = self.learning_rate.assign(self.learning_rate * config.learning_rate_decay_factor) 21 | self.global_step = tf.Variable(0, trainable=False) 22 | self.batch_size = config.batch_size 23 | self.num_layers = config.num_layers 24 | self.max_gradient_norm = config.max_gradient_norm 25 | self.mc_search = tf.placeholder(tf.bool, name="mc_search") 26 | self.forward_only = tf.placeholder(tf.bool, name="forward_only") 27 | self.up_reward = tf.placeholder(tf.bool, name="up_reward") 28 | self.reward_bias = tf.get_variable("reward_bias", [1], dtype=tf.float32) 29 | # If we use sampled softmax, we need an output projection. 30 | output_projection = None 31 | softmax_loss_function = None 32 | # Sampled softmax only makes sense if we sample less than vocabulary size. 
33 | if num_samples > 0 and num_samples < target_vocab_size: 34 | w_t = tf.get_variable("proj_w", [target_vocab_size, emb_dim], dtype=dtype) 35 | w = tf.transpose(w_t) 36 | b = tf.get_variable("proj_b", [target_vocab_size], dtype=dtype) 37 | output_projection = (w, b) 38 | 39 | def sampled_loss(inputs, labels): 40 | labels = tf.reshape(labels, [-1, 1]) 41 | # We need to compute the sampled_softmax_loss using 32bit floats to 42 | # avoid numerical instabilities. 43 | local_w_t = tf.cast(w_t, tf.float32) 44 | local_b = tf.cast(b, tf.float32) 45 | local_inputs = tf.cast(inputs, tf.float32) 46 | return tf.cast( 47 | tf.nn.sampled_softmax_loss(local_w_t, local_b, labels, local_inputs, 48 | num_samples, target_vocab_size), dtype) 49 | 50 | softmax_loss_function = sampled_loss 51 | 52 | # Create the internal multi-layer cell for our RNN. 53 | single_cell = tf.contrib.rnn.GRUCell(emb_dim) 54 | cell = single_cell 55 | if self.num_layers > 1: 56 | cell = tf.contrib.rnn.MultiRNNCell([single_cell] * self.num_layers) 57 | 58 | # The seq2seq function: we use embedding for the input and attention. 59 | def seq2seq_f(encoder_inputs, decoder_inputs, do_decode): 60 | return rl_seq2seq.embedding_attention_seq2seq( 61 | encoder_inputs, 62 | decoder_inputs, 63 | cell, 64 | num_encoder_symbols= source_vocab_size, 65 | num_decoder_symbols= target_vocab_size, 66 | embedding_size= emb_dim, 67 | output_projection=output_projection, 68 | feed_previous=do_decode, 69 | mc_search=self.mc_search, 70 | dtype=dtype) 71 | 72 | # Feeds for inputs. 73 | self.encoder_inputs = [] 74 | self.decoder_inputs = [] 75 | self.target_weights = [] 76 | for i in xrange(self.buckets[-1][0]): # Last bucket is the biggest one. 
77 | self.encoder_inputs.append(tf.placeholder(tf.int32, shape=[None], name="encoder{0}".format(i))) 78 | for i in xrange(self.buckets[-1][1] + 1): 79 | self.decoder_inputs.append(tf.placeholder(tf.int32, shape=[None], name="decoder{0}".format(i))) 80 | self.target_weights.append(tf.placeholder(dtype, shape=[None], name="weight{0}".format(i))) 81 | self.reward = [tf.placeholder(tf.float32, name="reward_%i" % i) for i in range(len(self.buckets))] 82 | 83 | # Our targets are decoder inputs shifted by one. 84 | targets = [self.decoder_inputs[i + 1] for i in xrange(len(self.decoder_inputs) - 1)] 85 | 86 | self.outputs, self.losses, self.encoder_state = rl_seq2seq.model_with_buckets( 87 | self.encoder_inputs, self.decoder_inputs, targets, self.target_weights, 88 | self.buckets, source_vocab_size, self.batch_size, 89 | lambda x, y: seq2seq_f(x, y, tf.where(self.forward_only, True, False)), 90 | output_projection=output_projection, softmax_loss_function=softmax_loss_function) 91 | 92 | for b in xrange(len(self.buckets)): 93 | self.outputs[b] = [ 94 | tf.cond( 95 | self.forward_only, 96 | lambda: tf.matmul(output, output_projection[0]) + output_projection[1], 97 | lambda: output 98 | ) 99 | for output in self.outputs[b] 100 | ] 101 | 102 | if not forward_only: 103 | with tf.name_scope("gradient_descent"): 104 | self.gradient_norms = [] 105 | self.updates = [] 106 | self.aj_losses = [] 107 | self.gen_params = [p for p in tf.trainable_variables() if name_scope in p.name] 108 | opt = tf.train.AdamOptimizer() 109 | for b in xrange(len(self.buckets)): 110 | R = tf.subtract(self.reward[b], self.reward_bias) 111 | adjusted_loss = tf.cond(self.up_reward, 112 | lambda:tf.multiply(self.losses[b], self.reward[b]), 113 | lambda: self.losses[b]) 114 | 115 | self.aj_losses.append(adjusted_loss) 116 | gradients = tf.gradients(adjusted_loss, self.gen_params) 117 | clipped_gradients, norm = tf.clip_by_global_norm(gradients, self.max_gradient_norm) 118 | self.gradient_norms.append(norm) 119 
| self.updates.append(opt.apply_gradients( 120 | zip(clipped_gradients, self.gen_params), global_step=self.global_step)) 121 | 122 | self.gen_variables = [k for k in tf.global_variables() if name_scope in k.name] 123 | self.saver = tf.train.Saver(self.gen_variables) 124 | 125 | def step(self, session, encoder_inputs, decoder_inputs, target_weights, 126 | bucket_id, forward_only=True, reward=1, mc_search=False, up_reward=False, debug=True): 127 | # Check if the sizes match. 128 | encoder_size, decoder_size = self.buckets[bucket_id] 129 | if len(encoder_inputs) != encoder_size: 130 | raise ValueError("Encoder length must be equal to the one in bucket," 131 | " %d != %d." % (len(encoder_inputs), encoder_size)) 132 | if len(decoder_inputs) != decoder_size: 133 | raise ValueError("Decoder length must be equal to the one in bucket," 134 | " %d != %d." % (len(decoder_inputs), decoder_size)) 135 | if len(target_weights) != decoder_size: 136 | raise ValueError("Weights length must be equal to the one in bucket," 137 | " %d != %d." % (len(target_weights), decoder_size)) 138 | 139 | # Input feed: encoder inputs, decoder inputs, target_weights, as provided. 140 | 141 | input_feed = { 142 | self.forward_only.name: forward_only, 143 | self.up_reward.name: up_reward, 144 | self.mc_search.name: mc_search 145 | } 146 | for l in xrange(len(self.buckets)): 147 | input_feed[self.reward[l].name] = reward 148 | for l in xrange(encoder_size): 149 | input_feed[self.encoder_inputs[l].name] = encoder_inputs[l] 150 | for l in xrange(decoder_size): 151 | input_feed[self.decoder_inputs[l].name] = decoder_inputs[l] 152 | input_feed[self.target_weights[l].name] = target_weights[l] 153 | 154 | # Since our targets are decoder inputs shifted by one, we need one more. 155 | last_target = self.decoder_inputs[decoder_size].name 156 | input_feed[last_target] = np.zeros([self.batch_size], dtype=np.int32) 157 | 158 | # Output feed: depends on whether we do a backward step or not. 
159 | if not forward_only: # normal training 160 | output_feed = [self.updates[bucket_id], # Update Op that does SGD. 161 | self.aj_losses[bucket_id], # Gradient norm. 162 | self.losses[bucket_id]] # Loss for this batch. 163 | else: # testing or reinforcement learning 164 | output_feed = [self.encoder_state[bucket_id], self.losses[bucket_id]] # Loss for this batch. 165 | for l in xrange(decoder_size): # Output logits. 166 | output_feed.append(self.outputs[bucket_id][l]) 167 | 168 | outputs = session.run(output_feed, input_feed) 169 | if not forward_only: 170 | return outputs[1], outputs[2], outputs[0] # Gradient norm, loss, no outputs. 171 | else: 172 | return outputs[0], outputs[1], outputs[2:] # encoder_state, loss, outputs. 173 | 174 | def get_batch(self, train_data, bucket_id, batch_size, type=0): 175 | 176 | encoder_size, decoder_size = self.buckets[bucket_id] 177 | encoder_inputs, decoder_inputs = [], [] 178 | 179 | # pad them if needed, reverse encoder inputs and add GO to decoder. 180 | batch_source_encoder, batch_source_decoder = [], [] 181 | if type == 1: 182 | batch_size = 1 183 | for batch_i in xrange(batch_size): 184 | if type == 1: 185 | encoder_input, decoder_input = train_data[bucket_id] 186 | elif type == 2: 187 | encoder_input_a, decoder_input = train_data[bucket_id][0] 188 | encoder_input = encoder_input_a[batch_i] 189 | elif type == 0: 190 | encoder_input, decoder_input = random.choice(train_data[bucket_id]) 191 | 192 | batch_source_encoder.append(encoder_input) 193 | batch_source_decoder.append(decoder_input) 194 | # Encoder inputs are padded and then reversed. 195 | encoder_pad = [util.PAD_ID] * (encoder_size - len(encoder_input)) 196 | encoder_inputs.append(list(reversed(encoder_input + encoder_pad))) 197 | 198 | # Decoder inputs get an extra "GO" symbol, and are padded then. 
199 | decoder_pad_size = decoder_size - len(decoder_input) - 1 200 | decoder_inputs.append([util.GO_ID] + decoder_input + 201 | [util.PAD_ID] * decoder_pad_size) 202 | 203 | # Now we create batch-major vectors from the disc_data selected above. 204 | batch_encoder_inputs, batch_decoder_inputs, batch_weights = [], [], [] 205 | 206 | # Batch encoder inputs are just re-indexed encoder_inputs. 207 | for length_idx in xrange(encoder_size): 208 | batch_encoder_inputs.append( 209 | np.array([encoder_inputs[batch_idx][length_idx] 210 | for batch_idx in xrange(batch_size)], dtype=np.int32)) 211 | 212 | # Batch decoder inputs are re-indexed decoder_inputs, we create weights. 213 | for length_idx in xrange(decoder_size): 214 | batch_decoder_inputs.append( 215 | np.array([decoder_inputs[batch_idx][length_idx] 216 | for batch_idx in xrange(batch_size)], dtype=np.int32)) 217 | 218 | # Create target_weights to be 0 for targets that are padding. 219 | batch_weight = np.ones(batch_size, dtype=np.float32) 220 | for batch_idx in xrange(batch_size): 221 | # We set weight to 0 if the corresponding target is a PAD symbol. 222 | # The corresponding target is decoder_input shifted by 1 forward. 
223 | if length_idx < decoder_size - 1: 224 | target = decoder_inputs[batch_idx][length_idx + 1] 225 | if length_idx == decoder_size - 1 or target == util.PAD_ID: 226 | batch_weight[batch_idx] = 0.0 227 | batch_weights.append(batch_weight) 228 | 229 | return (batch_encoder_inputs, batch_decoder_inputs, batch_weights, batch_source_encoder, batch_source_decoder) 230 | -------------------------------------------------------------------------------- /model/seq2seq.py: -------------------------------------------------------------------------------- 1 | from __future__ import division 2 | from six.moves import xrange 3 | from six.moves import zip 4 | import tensorflow as tf 5 | from tensorflow.python import shape 6 | from tensorflow.python.framework import dtypes 7 | from tensorflow.python.framework import ops 8 | from tensorflow.python.ops import array_ops 9 | from tensorflow.python.ops import control_flow_ops 10 | from tensorflow.python.ops import embedding_ops 11 | from tensorflow.python.ops import math_ops 12 | from tensorflow.python.ops import nn_ops 13 | from tensorflow.python.ops import rnn 14 | from tensorflow.python.ops import rnn_cell 15 | from tensorflow.python.ops import variable_scope 16 | from tensorflow.python.util import nest 17 | 18 | try: 19 | from tensorflow.python.ops.rnn_cell_impl import _linear 20 | linear = _linear 21 | except: 22 | from tensorflow.contrib.rnn.python.ops import core_rnn_cell 23 | linear = core_rnn_cell._linear 24 | 25 | 26 | def _argmax_or_mcsearch(embedding, output_projection=None, update_embedding=True, mc_search=False): 27 | def loop_function(prev, _): 28 | if output_projection is not None: 29 | prev = nn_ops.xw_plus_b(prev, output_projection[0], output_projection[1]) 30 | 31 | 32 | if isinstance(mc_search, bool): 33 | prev_symbol = tf.reshape(tf.multinomial(prev, 1), [-1]) if mc_search else math_ops.argmax(prev, 1) 34 | else: 35 | prev_symbol = tf.cond(mc_search, lambda: tf.reshape(tf.multinomial(prev, 1), [-1]), lambda: 
tf.argmax(prev, 1)) 36 | 37 | 38 | emb_prev = embedding_ops.embedding_lookup(embedding, prev_symbol) 39 | if not update_embedding: 40 | emb_prev = array_ops.stop_gradient(emb_prev) 41 | return emb_prev 42 | return loop_function 43 | 44 | def attention_decoder(decoder_inputs, 45 | initial_state, 46 | attention_states, 47 | cell, 48 | output_size=None, 49 | num_heads=1, 50 | loop_function=None, 51 | dtype=None, 52 | scope=None, 53 | initial_state_attention=False): 54 | 55 | if not decoder_inputs: 56 | raise ValueError("Must provide at least 1 input to attention decoder.") 57 | if num_heads < 1: 58 | raise ValueError("With less than 1 heads, use a non-attention decoder.") 59 | if attention_states.get_shape()[2].value is None: 60 | raise ValueError("Shape[2] of attention_states must be known: %s" 61 | % attention_states.get_shape()) 62 | if output_size is None: 63 | output_size = cell.output_size 64 | 65 | with variable_scope.variable_scope( 66 | scope or "attention_decoder", dtype=dtype) as scope: 67 | dtype = scope.dtype 68 | 69 | batch_size = array_ops.shape(decoder_inputs[0])[0] # Needed for reshaping. 70 | attn_length = attention_states.get_shape()[1].value 71 | if attn_length is None: 72 | attn_length = shape(attention_states)[1] 73 | attn_size = attention_states.get_shape()[2].value 74 | 75 | # To calculate W1 * h_t we use a 1-by-1 convolution, need to reshape before. 76 | hidden = array_ops.reshape( 77 | attention_states, [-1, attn_length, 1, attn_size]) 78 | hidden_features = [] 79 | v = [] 80 | attention_vec_size = attn_size # Size of query vectors for attention. 
81 | for a in xrange(num_heads): 82 | k = variable_scope.get_variable("AttnW_%d" % a, 83 | [1, 1, attn_size, attention_vec_size]) 84 | hidden_features.append(nn_ops.conv2d(hidden, k, [1, 1, 1, 1], "SAME")) 85 | v.append( 86 | variable_scope.get_variable("AttnV_%d" % a, [attention_vec_size])) 87 | 88 | state = initial_state 89 | 90 | def attention(query): 91 | """Put attention masks on hidden using hidden_features and query.""" 92 | ds = [] # Results of attention reads will be stored here. 93 | if nest.is_sequence(query): # If the query is a tuple, flatten it. 94 | query_list = nest.flatten(query) 95 | for q in query_list: # Check that ndims == 2 if specified. 96 | ndims = q.get_shape().ndims 97 | if ndims: 98 | assert ndims == 2 99 | query = array_ops.concat(query_list, 1) 100 | for a in xrange(num_heads): 101 | with variable_scope.variable_scope("Attention_%d" % a): 102 | y = linear(query, attention_vec_size, True) 103 | y = array_ops.reshape(y, [-1, 1, 1, attention_vec_size]) 104 | # Attention mask is a softmax of v^T * tanh(...). 105 | s = math_ops.reduce_sum( 106 | v[a] * math_ops.tanh(hidden_features[a] + y), [2, 3]) 107 | a = nn_ops.softmax(s) 108 | # Now calculate the attention-weighted vector d. 109 | d = math_ops.reduce_sum( 110 | array_ops.reshape(a, [-1, attn_length, 1, 1]) * hidden, 111 | [1, 2]) 112 | ds.append(array_ops.reshape(d, [-1, attn_size])) 113 | return ds 114 | 115 | outputs = [] 116 | prev = None 117 | batch_attn_size = array_ops.stack([batch_size, attn_size]) 118 | attns = [array_ops.zeros(batch_attn_size, dtype=dtype) 119 | for _ in xrange(num_heads)] 120 | for a in attns: # Ensure the second shape of attention vectors is set. 121 | a.set_shape([None, attn_size]) 122 | if initial_state_attention: 123 | attns = attention(initial_state) 124 | for i, inp in enumerate(decoder_inputs): 125 | if i > 0: 126 | variable_scope.get_variable_scope().reuse_variables() 127 | # If loop_function is set, we use it instead of decoder_inputs. 
128 | if loop_function is not None and prev is not None: 129 | with variable_scope.variable_scope("loop_function", reuse=True): 130 | inp = loop_function(prev, i) 131 | # Merge input and previous attentions into one vector of the right size. 132 | input_size = inp.get_shape().with_rank(2)[1] 133 | if input_size.value is None: 134 | raise ValueError("Could not infer input size from input: %s" % inp.name) 135 | x = linear([inp] + attns, input_size, True) 136 | # Run the RNN. 137 | cell_output, state = cell(x, state) 138 | # Run the attention mechanism. 139 | if i == 0 and initial_state_attention: 140 | with variable_scope.variable_scope(variable_scope.get_variable_scope(), 141 | reuse=True): 142 | attns = attention(state) 143 | else: 144 | attns = attention(state) 145 | 146 | with variable_scope.variable_scope("AttnOutputProjection"): 147 | output = linear([cell_output] + attns, output_size, True) 148 | if loop_function is not None: 149 | prev = output 150 | outputs.append(output) 151 | 152 | return outputs, state 153 | 154 | def embedding_attention_decoder(decoder_inputs, 155 | initial_state, 156 | attention_states, 157 | cell, 158 | num_symbols, 159 | embedding_size, 160 | num_heads=1, 161 | output_size=None, 162 | output_projection=None, 163 | feed_previous=False, 164 | update_embedding_for_previous=True, 165 | dtype=None, 166 | scope=None, 167 | initial_state_attention=False, 168 | mc_search = False): 169 | 170 | if output_size is None: 171 | output_size = cell.output_size 172 | if output_projection is not None: 173 | proj_biases = ops.convert_to_tensor(output_projection[1], dtype=dtype) 174 | proj_biases.get_shape().assert_is_compatible_with([num_symbols]) 175 | 176 | with variable_scope.variable_scope( 177 | scope or "embedding_attention_decoder", dtype=dtype) as scope: 178 | 179 | embedding = variable_scope.get_variable("embedding", 180 | [num_symbols, embedding_size]) 181 | 182 | loop_function = None 183 | if feed_previous == True: 184 | loop_function = 
_argmax_or_mcsearch(embedding, output_projection, update_embedding_for_previous, mc_search) 185 | 186 | emb_inp = [ 187 | embedding_ops.embedding_lookup(embedding, i) for i in decoder_inputs] 188 | return attention_decoder( 189 | emb_inp, 190 | initial_state, 191 | attention_states, 192 | cell, 193 | output_size=output_size, 194 | num_heads=num_heads, 195 | loop_function=loop_function, 196 | initial_state_attention=initial_state_attention, 197 | scope=scope) 198 | 199 | def embedding_attention_seq2seq(encoder_inputs, 200 | decoder_inputs, 201 | cell, 202 | num_encoder_symbols, 203 | num_decoder_symbols, 204 | embedding_size, 205 | num_heads=1, 206 | output_projection=None, 207 | feed_previous=False, 208 | dtype=None, 209 | scope=None, 210 | initial_state_attention=False, 211 | mc_search=False): 212 | 213 | with variable_scope.variable_scope( 214 | scope or "embedding_attention_seq2seq", dtype=dtype) as scope: 215 | dtype = scope.dtype 216 | # Encoder. 217 | encoder_cell = tf.contrib.rnn.EmbeddingWrapper( 218 | cell, embedding_classes=num_encoder_symbols, 219 | embedding_size=embedding_size) 220 | encoder_outputs, encoder_state = tf.contrib.rnn.static_rnn( 221 | encoder_cell, encoder_inputs, dtype=dtype) 222 | 223 | # First calculate a concatenation of encoder outputs to put attention on. 224 | top_states = [array_ops.reshape(e, [-1, 1, cell.output_size]) 225 | for e in encoder_outputs] 226 | attention_states = array_ops.concat(top_states, 1) 227 | 228 | # Decoder. 
229 | output_size = None 230 | if output_projection is None: 231 | cell = tf.contrib.rnn.OutputProjectionWrapper(cell, num_decoder_symbols) 232 | output_size = num_decoder_symbols 233 | 234 | if isinstance(feed_previous, bool): 235 | outputs, state = embedding_attention_decoder( 236 | decoder_inputs, 237 | encoder_state, 238 | attention_states, 239 | cell, 240 | num_decoder_symbols, 241 | embedding_size, 242 | num_heads=num_heads, 243 | output_size=output_size, 244 | output_projection=output_projection, 245 | feed_previous=feed_previous, 246 | initial_state_attention=initial_state_attention, 247 | mc_search=mc_search, 248 | scope=scope) 249 | return outputs, state, encoder_state 250 | 251 | # If feed_previous is a Tensor, we construct 2 graphs and use cond. 252 | def decoder(feed_previous_bool): 253 | reuse = None if feed_previous_bool else True 254 | with variable_scope.variable_scope( 255 | variable_scope.get_variable_scope(), reuse=reuse) as scope: 256 | outputs, state = embedding_attention_decoder( 257 | decoder_inputs, 258 | encoder_state, 259 | attention_states, 260 | cell, 261 | num_decoder_symbols, 262 | embedding_size, 263 | num_heads=num_heads, 264 | output_size=output_size, 265 | output_projection=output_projection, 266 | feed_previous=feed_previous_bool, 267 | update_embedding_for_previous=False, 268 | initial_state_attention=initial_state_attention, 269 | mc_search=mc_search, 270 | scope=scope) 271 | state_list = [state] 272 | if nest.is_sequence(state): 273 | state_list = nest.flatten(state) 274 | return outputs + state_list 275 | 276 | outputs_and_state = control_flow_ops.cond(feed_previous, 277 | lambda: decoder(True), 278 | lambda: decoder(False)) 279 | outputs_len = len(decoder_inputs) # Outputs length same as decoder inputs. 
280 | state_list = outputs_and_state[outputs_len:] 281 | state = state_list[0] 282 | if nest.is_sequence(encoder_state): 283 | state = nest.pack_sequence_as(structure=encoder_state, 284 | flat_sequence=state_list) 285 | return outputs_and_state[:outputs_len], state, encoder_state 286 | 287 | def sequence_loss_by_example(logits, targets, weights, 288 | average_across_timesteps=True, 289 | softmax_loss_function=None, name=None): 290 | if len(targets) != len(logits) or len(weights) != len(logits): 291 | raise ValueError("Lengths of logits, weights, and targets must be the same " 292 | "%d, %d, %d." % (len(logits), len(weights), len(targets))) 293 | with ops.name_scope(name, "sequence_loss_by_example", 294 | logits + targets + weights): 295 | log_perp_list = [] 296 | for logit, target, weight in zip(logits, targets, weights): 297 | if softmax_loss_function is None: 298 | # sequence_loss_by_example is called with scalars sometimes, which 299 | # violates our general scalar strictness policy. 300 | target = array_ops.reshape(target, [-1]) 301 | crossent = nn_ops.sparse_softmax_cross_entropy_with_logits( 302 | logits=logit, labels=target) 303 | else: 304 | crossent = softmax_loss_function(logit, target) 305 | log_perp_list.append(crossent * weight) 306 | log_perps = math_ops.add_n(log_perp_list) 307 | if average_across_timesteps: 308 | total_size = math_ops.add_n(weights) 309 | total_size += 1e-12 # Just to avoid division by 0 for all-0 weights. 
310 | log_perps /= total_size 311 | return log_perps 312 | 313 | def sequence_loss(logits, targets, weights, 314 | average_across_timesteps=True, average_across_batch=True, 315 | softmax_loss_function=None, name=None): 316 | 317 | with ops.name_scope(name, "sequence_loss", logits + targets + weights): 318 | cost = math_ops.reduce_sum(sequence_loss_by_example( 319 | logits, targets, weights, 320 | average_across_timesteps=average_across_timesteps, 321 | softmax_loss_function=softmax_loss_function)) 322 | if average_across_batch: 323 | batch_size = array_ops.shape(targets[0])[0] 324 | return cost / math_ops.cast(batch_size, cost.dtype) 325 | else: 326 | return cost 327 | 328 | def model_with_buckets(encoder_inputs, decoder_inputs, targets, weights, buckets, vocab_size, batch_size, seq2seq, 329 | output_projection=None, softmax_loss_function=None, per_example_loss=False, name=None): 330 | if len(encoder_inputs) < buckets[-1][0]: 331 | raise ValueError("Length of encoder_inputs (%d) must be at least that of la" 332 | "st bucket (%d)." % (len(encoder_inputs), buckets[-1][0])) 333 | if len(targets) < buckets[-1][1]: 334 | raise ValueError("Length of targets (%d) must be at least that of last" 335 | "bucket (%d)." % (len(targets), buckets[-1][1])) 336 | if len(weights) < buckets[-1][1]: 337 | raise ValueError("Length of weights (%d) must be at least that of last" 338 | "bucket (%d)." 
% (len(weights), buckets[-1][1])) 339 | 340 | all_inputs = encoder_inputs + decoder_inputs + targets + weights 341 | losses = [] 342 | outputs = [] 343 | encoder_states = [] 344 | with ops.name_scope(name, "model_with_buckets", all_inputs): 345 | for j, bucket in enumerate(buckets): 346 | with variable_scope.variable_scope(variable_scope.get_variable_scope(), 347 | reuse=True if j > 0 else None): 348 | bucket_outputs, decoder_states, encoder_state = seq2seq(encoder_inputs[:bucket[0]], 349 | decoder_inputs[:bucket[1]]) 350 | outputs.append(bucket_outputs) 351 | encoder_states.append(encoder_state) 352 | if per_example_loss: 353 | losses.append(sequence_loss_by_example( 354 | outputs[-1], targets[:bucket[1]], weights[:bucket[1]], 355 | softmax_loss_function=softmax_loss_function)) 356 | else: 357 | losses.append(sequence_loss(outputs[-1], targets[:bucket[1]], weights[:bucket[1]], softmax_loss_function=softmax_loss_function)) 358 | 359 | return outputs, losses, encoder_states 360 | -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | import config 2 | import os 3 | import sys 4 | import numpy as np 5 | import tensorflow as tf 6 | import util 7 | from generator import * 8 | 9 | os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" 10 | os.environ["CUDA_VISIBLE_DEVICES"] = "2" 11 | 12 | 13 | def test(gen_config): 14 | with tf.Session() as sess: 15 | model = create_model(sess, gen_config, forward_only=True, name_scope=gen_config.name_model) 16 | model.batch_size = 1 17 | 18 | train_path = os.path.join(gen_config.train_dir, "chitchat.train") 19 | voc_file_path = [train_path + ".answer", train_path + ".query"] 20 | vocab_path = os.path.join(gen_config.train_dir, "vocab%d.all" % gen_config.vocab_size) 21 | util.create_vocabulary(vocab_path, voc_file_path, gen_config.vocab_size) 22 | vocab, rev_vocab = util.initialize_vocabulary(vocab_path) 23 | 24 | 
sys.stdout.write("> ") 25 | sys.stdout.flush() 26 | sentence = sys.stdin.readline() 27 | while sentence: 28 | token_ids = util.sentence_to_token_ids(tf.compat.as_str_any(sentence), vocab) 29 | # print("token_id: ", token_ids) 30 | bucket_id = len(gen_config.buckets) - 1 31 | for i, bucket in enumerate(gen_config.buckets): 32 | if bucket[0] >= len(token_ids): 33 | bucket_id = i 34 | break 35 | # else: 36 | # print("Sentence truncated: %s", sentence) 37 | 38 | encoder_inputs, decoder_inputs, target_weights, _, _ = model.get_batch({bucket_id: [(token_ids, [1])]}, 39 | bucket_id, model.batch_size, type=0) 40 | 41 | # print("bucket_id: ", bucket_id) 42 | # print("encoder_inputs:", encoder_inputs) 43 | # print("decoder_inputs:", decoder_inputs) 44 | # print("target_weights:", target_weights) 45 | 46 | _, _, output_logits = model.step(sess, encoder_inputs, decoder_inputs, target_weights, bucket_id, True) 47 | 48 | # print("output_logits", np.shape(output_logits)) 49 | 50 | outputs = [int(np.argmax(logit, axis=1)) for logit in output_logits] 51 | # print(outputs) 52 | if util.EOS_ID in outputs: 53 | outputs = outputs[:outputs.index(util.EOS_ID)] 54 | 55 | print(" ".join([tf.compat.as_str_any(rev_vocab[output]) for output in outputs])) 56 | print("> ", end="") 57 | sys.stdout.flush() 58 | sentence = sys.stdin.readline() 59 | 60 | if __name__ == '__main__': 61 | test(config.gen_config) 62 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import os 2 | import tensorflow as tf 3 | import numpy as np 4 | import sys 5 | import time 6 | from six.moves import xrange 7 | import generator as gens 8 | import discriminator as h_disc 9 | import random 10 | import config 11 | import util 12 | 13 | os.environ["CUDA_VISIBLE_DEVICES"] = '7' 14 | 15 | gen_config = config.gen_config 16 | disc_config = config.disc_config 17 | 18 | # prepare disc_data for 
# discriminator and generator


def disc_train_data(sess, gen_model, vocab, source_inputs, source_outputs,
                    encoder_inputs, decoder_inputs, target_weights, bucket_id, mc_search=False):
    """Build a labelled (query, answer) batch for the discriminator.

    Every real pair (source_inputs[i], source_outputs[i]) becomes a positive
    example (label 1). The generator is then decoded on the same batch and
    every decoded response becomes a negative example (label 0), paired with
    the corresponding real query.

    Args:
        sess: session used to run ``gen_model.step``.
        gen_model: generator model exposing ``step``.
        vocab: unused; kept so existing callers keep working.
        source_inputs: unpadded query token-id lists for this batch.
        source_outputs: unpadded answer token-id lists for this batch; the
            last token of each is dropped (treated as the EOS tag).
        encoder_inputs, decoder_inputs, target_weights: batch tensors passed
            straight to ``gen_model.step``.
        bucket_id: bucket whose (query_len, answer_len) fixes padded lengths.
        mc_search: if True, decode ``gen_config.beam_size`` times (Monte Carlo
            search) instead of once.

    Returns:
        (train_query, train_answer, train_labels): parallel lists. Queries are
        truncated/padded with PAD_ID to the bucket's query length, answers to
        the bucket's answer length.
    """
    query_len = gen_config.buckets[bucket_id][0]
    answer_len = gen_config.buckets[bucket_id][1]

    def _pad(tokens, length):
        # Truncate to `length`, then right-pad with PAD_ID up to `length`.
        tokens = list(tokens[:length])
        return tokens + [int(util.PAD_ID)] * (length - len(tokens))

    train_query, train_answer = [], []
    for query, answer in zip(source_inputs, source_outputs):
        train_query.append(_pad(query, query_len))
        # answer[:-1] drops the EOS tag; assumes every real answer ends with
        # EOS -- TODO confirm against the data prepared by get_batch.
        train_answer.append(_pad(answer[:-1], answer_len))
    train_labels = [1 for _ in source_inputs]

    num_roll = gen_config.beam_size if mc_search else 1
    for _ in range(num_roll):
        _, _, output_logits = gen_model.step(sess, encoder_inputs, decoder_inputs, target_weights, bucket_id,
                                             forward_only=True, mc_search=mc_search)

        # output_logits holds one (batch, vocab) array per decoder step: take
        # the argmax token at each step, then transpose to one row per example.
        token_rows = np.argmax(np.asarray(output_logits), axis=-1).T.tolist()

        for i, seq in enumerate(token_rows):
            if util.EOS_ID in seq:
                seq = seq[:seq.index(util.EOS_ID)]
            # Generated responses reuse the i-th real query (the first rows of
            # train_query are always the real batch).
            train_query.append(train_query[i])
            train_answer.append(_pad(seq, answer_len))
            train_labels.append(0)

    return train_query, train_answer, train_labels


def softmax(x):
    """Return the softmax of ``x`` taken along axis 0.

    The per-column maximum is subtracted before exponentiating, so large
    logits no longer overflow to ``inf`` and produce ``nan``; the result is
    mathematically unchanged.
    """
    x = np.asarray(x)
    exps = np.exp(x - np.max(x, axis=0, keepdims=True))
    return exps / np.sum(exps, axis=0, keepdims=True)


# discriminator api
def disc_step(sess, bucket_id, disc_model, train_query, train_answer, train_labels, forward_only=False):
    """Run the discriminator on one batch and compute the generator reward.

    Args:
        sess: session used to run the discriminator graph.
        bucket_id: which bucketed discriminator op to run.
        disc_model: discriminator exposing ``query``/``answer`` placeholder
            lists, ``target``, and per-bucket ``b_logits``/``b_loss``/``b_train_op``.
        train_query, train_answer: indexed position-first; row i feeds
            ``disc_model.query[i]`` / ``disc_model.answer[i]`` (al_train passes
            the transposed batch).
        train_labels: 1 for real pairs, 0 for generated ones.
        forward_only: if True, only compute logits (no training update).

    Returns:
        (reward, loss): reward is the mean softmax probability of class index 1
        over the label-0 (generated) examples, or 0.0 if there are none; loss is
        the discriminator loss, or 0.0 when ``forward_only``.
    """
    feed_dict = {}
    for i in range(len(train_query)):
        feed_dict[disc_model.query[i].name] = train_query[i]
    for i in range(len(train_answer)):
        feed_dict[disc_model.answer[i].name] = train_answer[i]
    feed_dict[disc_model.target.name] = train_labels

    loss = 0.0
    if forward_only:
        logits = sess.run([disc_model.b_logits[bucket_id]], feed_dict)[0]
    else:
        fetches = [disc_model.b_train_op[bucket_id], disc_model.b_loss[bucket_id], disc_model.b_logits[bucket_id]]
        _, loss, logits = sess.run(fetches, feed_dict)

    # Row-wise softmax over classes (softmax() works along axis 0).
    probs = np.transpose(softmax(np.transpose(logits)))

    reward, gen_num = 0.0, 0
    for prob, label in zip(probs, train_labels):
        if int(label) == 0:
            reward += prob[1]
            gen_num += 1
    # Guard against a batch with no generated examples (previously a
    # ZeroDivisionError); a zero reward leaves the reward-weighted loss at 0.
    reward = reward / gen_num if gen_num else 0.0

    return reward, loss


# Adversarial Learning for Neural Dialogue Generation
def al_train():
    """Adversarial training loop alternating discriminator and generator updates.

    Runs until interrupted. Each iteration samples a bucket in proportion to
    its size, then:
      1. updates D on real vs. greedily decoded answers;
      2. scores Monte Carlo samples with D to get a reward;
      3. updates G with the reward-weighted loss, then with a teacher-forcing step.
    Every ``gen_config.steps_per_checkpoint`` steps the averaged statistics are
    printed; every second such interval both models are saved.
    """
    with tf.Session() as sess:

        vocab, rev_vocab, dev_set, train_set = gens.prepare_data(gen_config)
        for bucket_set in train_set:
            print("al train len: ", len(bucket_set))

        train_bucket_sizes = [len(train_set[b]) for b in range(len(gen_config.buckets))]
        train_total_size = float(sum(train_bucket_sizes))
        # Cumulative bucket-size fractions, used to sample buckets by size.
        train_buckets_scale = [sum(train_bucket_sizes[:i + 1]) / train_total_size
                               for i in range(len(train_bucket_sizes))]

        disc_model = h_disc.create_model(sess, disc_config, disc_config.name_model)
        gen_model = gens.create_model(sess, gen_config, forward_only=False, name_scope=gen_config.name_model)

        current_step = 0
        step_time, disc_loss, gen_loss, t_loss, batch_reward = 0.0, 0.0, 0.0, 0.0, 0.0
        # NOTE(review): these summaries are filled in but never written to a
        # FileWriter in this function -- hook one up or drop them.
        gen_loss_summary = tf.Summary()
        disc_loss_summary = tf.Summary()

        while True:
            current_step += 1
            start_time = time.time()
            random_number_01 = np.random.random_sample()
            bucket_id = min([i for i in range(len(train_buckets_scale))
                             if train_buckets_scale[i] > random_number_01])

            print("==================Update Discriminator: %d=====================" % current_step)
            # 1.Sample (X,Y) from real disc_data
            encoder_inputs, decoder_inputs, target_weights, source_inputs, source_outputs = gen_model.get_batch(
                train_set, bucket_id, gen_config.batch_size)

            # 2.Sample (X,Y) and (X, ^Y) through ^Y ~ G(*|X)
            train_query, train_answer, train_labels = disc_train_data(
                sess, gen_model, vocab, source_inputs, source_outputs,
                encoder_inputs, decoder_inputs, target_weights, bucket_id, mc_search=False)
            if current_step % 200 == 0:
                print("train_query: ", len(train_query))
                print("train_answer: ", len(train_answer))
                print("train_labels: ", len(train_labels))
                for i in range(len(train_query)):
                    print("label: ", train_labels[i])
                    print("train_answer_sentence: ", train_answer[i])
                    print(" ".join([tf.compat.as_str(rev_vocab[output]) for output in train_answer[i]]))

            train_query = np.transpose(train_query)
            train_answer = np.transpose(train_answer)

            # 3.Update D using (X, Y ) as positive examples and(X, ^Y) as negative examples
            _, disc_step_loss = disc_step(sess, bucket_id, disc_model, train_query, train_answer, train_labels,
                                          forward_only=False)
            disc_loss += disc_step_loss / disc_config.steps_per_checkpoint

            print("==================Update Generator: %d=========================" % current_step)
            # 1.Sample (X,Y) from real disc_data
            update_gen_data = gen_model.get_batch(train_set, bucket_id, gen_config.batch_size)
            encoder, decoder, weights, source_inputs, source_outputs = update_gen_data

            # 2.Sample (X,Y) and (X, ^Y) through ^Y ~ G(*|X) with Monte Carlo search
            train_query, train_answer, train_labels = disc_train_data(
                sess, gen_model, vocab, source_inputs, source_outputs,
                encoder, decoder, weights, bucket_id, mc_search=True)

            if current_step % 200 == 0:
                for i in range(len(train_query)):
                    print("label: ", train_labels[i])
                    print(" ".join([tf.compat.as_str(rev_vocab[output]) for output in train_answer[i]]))

            train_query = np.transpose(train_query)
            train_answer = np.transpose(train_answer)

            # 3.Compute Reward r for (X, ^Y ) using D.---based on Monte Carlo search
            reward, _ = disc_step(sess, bucket_id, disc_model, train_query, train_answer, train_labels,
                                  forward_only=True)
            batch_reward += reward / gen_config.steps_per_checkpoint
            print("step_reward: ", reward)

            # 4.Update G on (X, ^Y ) using reward r
            # NOTE(review): this step is fed the real batch (encoder, decoder,
            # weights), so it reward-weights the loss on the real answers rather
            # than on the sampled ^Y.
            gan_adjusted_loss, gen_step_loss, _ = gen_model.step(sess, encoder, decoder, weights, bucket_id,
                                                                 forward_only=False, reward=reward,
                                                                 up_reward=True, debug=True)
            gen_loss += gen_step_loss / gen_config.steps_per_checkpoint

            print("gen_step_loss: ", gen_step_loss)
            print("gen_step_adjusted_loss: ", gan_adjusted_loss)

            # 5.Teacher-Forcing: Update G on (X, Y )
            t_adjusted_loss, t_step_loss, _ = gen_model.step(sess, encoder, decoder, weights, bucket_id,
                                                             forward_only=False)
            t_loss += t_step_loss / gen_config.steps_per_checkpoint

            print("t_step_loss: ", t_step_loss)
            print("t_adjusted_loss", t_adjusted_loss)

            # Accumulate the average step time on EVERY step (previously this
            # was only done on checkpoint steps, so the report was ~1/N of it).
            step_time += (time.time() - start_time) / gen_config.steps_per_checkpoint

            if current_step % gen_config.steps_per_checkpoint == 0:

                print("current_steps: %d, step time: %.4f, disc_loss: %.3f, gen_loss: %.3f, t_loss: %.3f, reward: %.3f"
                      % (current_step, step_time, disc_loss, gen_loss, t_loss, batch_reward))

                disc_loss_value = disc_loss_summary.value.add()
                disc_loss_value.tag = disc_config.name_loss
                disc_loss_value.simple_value = float(disc_loss)

                gen_loss_value = gen_loss_summary.value.add()
                gen_loss_value.tag = gen_config.name_loss
                gen_loss_value.simple_value = float(gen_loss)
                t_loss_value = gen_loss_summary.value.add()
                t_loss_value.tag = gen_config.teacher_loss
                t_loss_value.simple_value = float(t_loss)
                batch_reward_value = gen_loss_summary.value.add()
                batch_reward_value.tag = gen_config.reward_name
                batch_reward_value.simple_value = float(batch_reward)

                if current_step % (gen_config.steps_per_checkpoint * 2) == 0:
                    print("current_steps: %d, save disc model" % current_step)
                    disc_ckpt_dir = os.path.abspath(os.path.join(disc_config.train_dir, "checkpoints"))
                    if not os.path.exists(disc_ckpt_dir):
                        os.makedirs(disc_ckpt_dir)
                    disc_model_path = os.path.join(disc_ckpt_dir, "disc.model")
                    disc_model.saver.save(sess, disc_model_path, global_step=disc_model.global_step)

                    print("current_steps: %d, save gen model" % current_step)
                    gen_ckpt_dir = os.path.abspath(os.path.join(gen_config.train_dir, "checkpoints"))
                    if not os.path.exists(gen_ckpt_dir):
                        os.makedirs(gen_ckpt_dir)
                    gen_model_path = os.path.join(gen_ckpt_dir, "gen.model")
                    gen_model.saver.save(sess, gen_model_path, global_step=gen_model.global_step)

                step_time, disc_loss, gen_loss, t_loss, batch_reward = 0.0, 0.0, 0.0, 0.0, 0.0
                sys.stdout.flush()


def main(_):
    """Entry point for ``tf.app.run``; ignores argv and starts adversarial training."""
    al_train()


if __name__ == "__main__":
    tf.app.run()

# -------------------------------------------------------------------------------- /util.py:
"""Vocabulary building and tokenization helpers.

Text files are tokenized with a simple regex tokenizer, mapped to integer ids
through a vocabulary file (one token per line; the 0-based line number is the
token id), and written out as space-separated id files next to the source data.
"""
import os
import re
from collections import Counter

from six.moves import urllib
from tensorflow.python.platform import gfile
import tensorflow as tf

# Special vocabulary symbols - we always put them at the start.
_PAD = "_PAD"
_GO = "_GO"
_EOS = "_EOS"
_UNK = "_UNK"
_START_VOCAB = [_PAD, _GO, _EOS, _UNK]

# Ids of the special symbols; they equal their positions in _START_VOCAB,
# which create_vocabulary writes first into every vocabulary file.
PAD_ID = 0
GO_ID = 1
EOS_ID = 2
UNK_ID = 3

# Regular expressions used to tokenize.
_WORD_SPLIT = re.compile(r"([.,!?\"':;)(])")
_DIGIT_RE = re.compile(r"\d")


def basic_tokenizer(sentence):
    """Split *sentence* on whitespace, then split punctuation off each piece.

    The capturing group in ``_WORD_SPLIT`` keeps every punctuation mark as its
    own token; empty strings produced by ``re.split`` are dropped. Byte
    fragments are decoded with the default codec before splitting.
    """
    words = []
    for fragment in sentence.strip().split():
        if isinstance(fragment, bytes):
            fragment = fragment.decode()
        words.extend(_WORD_SPLIT.split(fragment))
    return [w for w in words if w]


def create_vocabulary(vocabulary_path, data_path_list, max_vocabulary_size,
                      tokenizer=None, normalize_digits=True):
    """Build a frequency-sorted vocabulary file from the given data files.

    Does nothing if *vocabulary_path* already exists. Otherwise every line of
    every file in *data_path_list* is tokenized (with *tokenizer*, or
    :func:`basic_tokenizer` when it is falsy), digits are optionally replaced
    by ``0``, and the tokens are written one per line: the special symbols of
    ``_START_VOCAB`` first, then the remaining tokens by descending count.
    The list is truncated to *max_vocabulary_size* entries, special symbols
    included.
    """
    if gfile.Exists(vocabulary_path):
        return
    print("Creating vocabulary %s from disc_data %s" % (vocabulary_path, data_path_list))
    vocab = Counter()
    for data_path in data_path_list:
        with gfile.GFile(data_path, mode="r") as f:
            for counter, line in enumerate(f, 1):
                if counter % 100000 == 0:
                    print(" processing line %d" % counter)
                line = tf.compat.as_str_any(line)
                tokens = tokenizer(line) if tokenizer else basic_tokenizer(line)
                for w in tokens:
                    word = _DIGIT_RE.sub("0", w) if normalize_digits else w
                    vocab[word] += 1

    # sorted() is stable even with reverse=True, so tokens with equal counts
    # keep their first-seen order.
    vocab_list = _START_VOCAB + sorted(vocab, key=vocab.get, reverse=True)
    vocab_list = vocab_list[:max_vocabulary_size]
    with gfile.GFile(vocabulary_path, mode="w") as vocab_file:
        for w in vocab_list:
            vocab_file.write(w + "\n")


def initialize_vocabulary(vocabulary_path):
    """Load a vocabulary file such as the one written by create_vocabulary.

    Returns:
        (vocab, rev_vocab): ``vocab`` maps token -> id (its 0-based line
        number); ``rev_vocab`` is the list of tokens indexed by id.

    Raises:
        ValueError: if *vocabulary_path* does not exist.
    """
    if not gfile.Exists(vocabulary_path):
        # The path must be formatted into the message; passing it as a second
        # argument would leave a literal "%s" in the error text.
        raise ValueError("Vocabulary file %s not found." % vocabulary_path)
    with gfile.GFile(vocabulary_path, mode="r") as f:
        rev_vocab = [line.strip() for line in f.readlines()]
    vocab = {token: idx for idx, token in enumerate(rev_vocab)}
    return vocab, rev_vocab


def sentence_to_token_ids(sentence, vocabulary,
                          tokenizer=None, normalize_digits=True):
    """Convert *sentence* to a list of token ids.

    *vocabulary* maps token -> id; tokens missing from it map to ``UNK_ID``.
    With *normalize_digits*, every digit is replaced by ``0`` before lookup,
    matching how create_vocabulary counts tokens.
    """
    words = tokenizer(sentence) if tokenizer else basic_tokenizer(sentence)
    if normalize_digits:
        words = [_DIGIT_RE.sub("0", w) for w in words]
    return [vocabulary.get(w, UNK_ID) for w in words]


def data_to_token_ids(data_path, target_path, vocabulary,
                      tokenizer=None, normalize_digits=True):
    """Tokenize *data_path* line by line into *target_path*.

    Each output line holds the space-separated ids of the matching input line.
    Nothing is done if *target_path* already exists, so an existing id file
    is reused as-is.
    """
    if gfile.Exists(target_path):
        return
    print("Tokenizing disc_data in %s" % data_path)
    with gfile.GFile(data_path, mode="r") as data_file:
        with gfile.GFile(target_path, mode="w") as tokens_file:
            for counter, line in enumerate(data_file, 1):
                if counter % 100000 == 0:
                    print(" tokenizing line %d" % counter)
                token_ids = sentence_to_token_ids(line, vocabulary, tokenizer,
                                                  normalize_digits)
                tokens_file.write(" ".join([str(tok) for tok in token_ids]) + "\n")


def _tokenize_files(base_path, suffixes, vocabulary, vocabulary_size, tokenizer):
    """Tokenize ``<base_path>.<suffix>`` into ``<base_path>.ids<V>.<suffix>``, in order.

    Returns a dict mapping each suffix to its id-file path.
    """
    ids_paths = {}
    for suffix in suffixes:
        ids_path = base_path + (".ids%d.%s" % (vocabulary_size, suffix))
        data_to_token_ids(base_path + "." + suffix, ids_path, vocabulary, tokenizer)
        ids_paths[suffix] = ids_path
    return ids_paths


def prepare_chitchat_data(data_dir, vocabulary, vocabulary_size, tokenizer=None):
    """Tokenize chitchat.{train,dev}.{answer,query} under *data_dir*.

    Returns:
        (query_train_ids_path, answer_train_ids_path,
         query_dev_ids_path, answer_dev_ids_path)
    """
    suffixes = ("answer", "query")  # processing order of the original code
    train = _tokenize_files(os.path.join(data_dir, "chitchat.train"), suffixes,
                            vocabulary, vocabulary_size, tokenizer)
    dev = _tokenize_files(os.path.join(data_dir, "chitchat.dev"), suffixes,
                          vocabulary, vocabulary_size, tokenizer)
    return (train["query"], train["answer"],
            dev["query"], dev["answer"])


def hier_prepare_disc_data(data_dir, vocabulary, vocabulary_size, tokenizer=None):
    """Tokenize {train,dev}.{query,answer,gen} under *data_dir*.

    Returns:
        (query_train_ids_path, answer_train_ids_path, gen_train_ids_path,
         query_dev_ids_path, answer_dev_ids_path, gen_dev_ids_path)
    """
    suffixes = ("query", "answer", "gen")
    train = _tokenize_files(os.path.join(data_dir, "train"), suffixes,
                            vocabulary, vocabulary_size, tokenizer)
    dev = _tokenize_files(os.path.join(data_dir, "dev"), suffixes,
                          vocabulary, vocabulary_size, tokenizer)
    return (train["query"], train["answer"], train["gen"],
            dev["query"], dev["answer"], dev["gen"])

# --------------------------------------------------------------------------------
# /说明文档.pdf:
# https://raw.githubusercontent.com/aqzheng/Adversarial-Learning-for-Neural-Dialogue-Generation/57007ed96251030a469367d14ba49340f1b2362e/说明文档.pdf
# --------------------------------------------------------------------------------