├── .idea ├── dictionaries │ └── liuchong.xml ├── misc.xml ├── modules.xml └── seq2seq_chatbot.iml ├── README.md ├── __init__.py ├── chatbot.py ├── data └── dataset-cornell-length10-filter1-vocabSize40000.pkl ├── data_utils.py ├── seq2seq.py └── seq2seq_model.py /.idea/dictionaries/liuchong.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | -------------------------------------------------------------------------------- /.idea/misc.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | ApexVCS 5 | 6 | 7 | -------------------------------------------------------------------------------- /.idea/modules.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | -------------------------------------------------------------------------------- /.idea/seq2seq_chatbot.iml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 12 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | =================================================更新=========================================================== 2 | 训练好的模型已经上传到百度云网盘,如果大家有需要可以前去下载。模型训练速度的话,CPU,16G内存,一天即可训练完成~~~ 3 | 4 | 链接:https://pan.baidu.com/s/1hrNxaSk 密码:d2sn 5 | 6 | =================================================分割线,下面是正文=============================================== 7 | 8 | 本文是一个简单的基于seq2seq模型的chatbot对话系统的tensorflow实现。 9 | 10 | 代码的讲解可以参考我的知乎专栏文章: 11 | 12 | [从头实现深度学习的对话系统--简单chatbot代码实现](https://zhuanlan.zhihu.com/p/32455898) 13 | 14 | 代码参考了DeepQA,在其基础上添加了beam search的功能和attention的机制。 15 | 16 | 最终的效果如下图所示: 17 | 18 | ![](https://i.imgur.com/pN7AfAB.png) 19 | 20 | ![](https://i.imgur.com/RnvBDwO.png) 21 | 22 | 测试效果,根据用户输入回复概率最大的前beam_size个句子: 23 | 24 | ![](https://i.imgur.com/EdsQ5FE.png) 
25 | 26 | #使用方法 27 | 28 | 1,下载代码到本地(data文件夹下已经包含了处理好的数据集,所以无需额外下载数据集) 29 | 30 | 2,训练模型,将chatbot.py文件第34行的decode参数修改为False,进行训练模型 31 | 32 | (之后我会把我这里训练好的模型上传到网上方便大家使用) 33 | 34 | 3,训练完之后(大概要一天左右的时间,30个epoches),再将decode参数修改为True 35 | 36 | 就可以进行测试了。输入你想问的话看他回复什么吧== 37 | 38 | 这里还需要注意的就是要记得修改数据集和最后模型文件的绝对路径,不然可能会报错。 39 | 40 | 分别在44行,57行,82行三处。好了,接下来就可以愉快的玩耍了~~ 41 | 42 | -------------------------------------------------------------------------------- /__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lc222/seq2seq_chatbot/7a419e4e9587d9a87b4acb0e141789b329146fe1/__init__.py -------------------------------------------------------------------------------- /chatbot.py: -------------------------------------------------------------------------------- 1 | """Most of the code comes from seq2seq tutorial. Binary for training conversation models and decoding from them. 2 | 3 | Running this program without --decode will tokenize it in a very basic way, 4 | and then start training a model saving checkpoints to --train_dir. 5 | 6 | Running with --decode starts an interactive loop so you can see how 7 | the current checkpoint performs 8 | 9 | See the following papers for more information on neural translation models. 
10 | * http://arxiv.org/abs/1409.3215 11 | * http://arxiv.org/abs/1409.0473 12 | * http://arxiv.org/abs/1412.2007 13 | """ 14 | 15 | import math 16 | import sys 17 | import time 18 | from data_utils import * 19 | from seq2seq_model import * 20 | from tqdm import tqdm 21 | 22 | tf.app.flags.DEFINE_float("learning_rate", 0.001, "Learning rate.") 23 | tf.app.flags.DEFINE_integer("batch_size", 256, "Batch size to use during training.") 24 | tf.app.flags.DEFINE_integer("numEpochs", 30, "Batch size to use during training.") 25 | tf.app.flags.DEFINE_integer("size", 512, "Size of each model layer.") 26 | tf.app.flags.DEFINE_integer("num_layers", 3, "Number of layers in the model.") 27 | tf.app.flags.DEFINE_integer("en_vocab_size", 40000, "English vocabulary size.") 28 | tf.app.flags.DEFINE_integer("en_de_seq_len", 20, "English vocabulary size.") 29 | tf.app.flags.DEFINE_integer("max_train_data_size", 0, "Limit on the size of training data (0: no limit).") 30 | tf.app.flags.DEFINE_integer("steps_per_checkpoint", 100, "How many training steps to do per checkpoint.") 31 | tf.app.flags.DEFINE_string("train_dir", './tmp', "How many training steps to do per checkpoint.") 32 | tf.app.flags.DEFINE_integer("beam_size", 5, "How many training steps to do per checkpoint.") 33 | tf.app.flags.DEFINE_boolean("beam_search", True, "Set to True for beam_search.") 34 | tf.app.flags.DEFINE_boolean("decode", True, "Set to True for interactive decoding.") 35 | FLAGS = tf.app.flags.FLAGS 36 | 37 | def create_model(session, forward_only, beam_search, beam_size = 5): 38 | """Create translation model and initialize or load parameters in session.""" 39 | model = Seq2SeqModel( 40 | FLAGS.en_vocab_size, FLAGS.en_vocab_size, [10, 10], 41 | FLAGS.size, FLAGS.num_layers, FLAGS.batch_size, 42 | FLAGS.learning_rate, forward_only=forward_only, beam_search=beam_search, beam_size=beam_size) 43 | ckpt = tf.train.latest_checkpoint(FLAGS.train_dir) 44 | model_path = 
'E:\PycharmProjects\Seq-to-Seq\seq2seq_chatbot\\tmp\chat_bot.ckpt-0' 45 | if forward_only: 46 | model.saver.restore(session, model_path) 47 | elif ckpt and tf.gfile.Exists(ckpt.model_checkpoint_path): 48 | print("Reading model parameters from %s" % ckpt.model_checkpoint_path) 49 | model.saver.restore(session, ckpt.model_checkpoint_path) 50 | else: 51 | print("Created model with fresh parameters.") 52 | session.run(tf.initialize_all_variables()) 53 | return model 54 | 55 | def train(): 56 | # prepare dataset 57 | data_path = 'E:\PycharmProjects\Seq-to-Seq\seq2seq_chatbot\data\dataset-cornell-length10-filter1-vocabSize40000.pkl' 58 | word2id, id2word, trainingSamples = loadDataset(data_path) 59 | with tf.Session() as sess: 60 | print("Creating %d layers of %d units." % (FLAGS.num_layers, FLAGS.size)) 61 | model = create_model(sess, False, beam_search=False, beam_size=5) 62 | current_step = 0 63 | for e in range(FLAGS.numEpochs): 64 | print("----- Epoch {}/{} -----".format(e + 1, FLAGS.numEpochs)) 65 | batches = getBatches(trainingSamples, FLAGS.batch_size, model.en_de_seq_len) 66 | for nextBatch in tqdm(batches, desc="Training"): 67 | _, step_loss = model.step(sess, nextBatch.encoderSeqs, nextBatch.decoderSeqs, nextBatch.targetSeqs, 68 | nextBatch.weights, goToken) 69 | current_step += 1 70 | if current_step % FLAGS.steps_per_checkpoint == 0: 71 | perplexity = math.exp(float(step_loss)) if step_loss < 300 else float('inf') 72 | tqdm.write("----- Step %d -- Loss %.2f -- Perplexity %.2f" % (current_step, step_loss, perplexity)) 73 | checkpoint_path = os.path.join(FLAGS.train_dir, "chat_bot.ckpt") 74 | model.saver.save(sess, checkpoint_path, global_step=model.global_step) 75 | 76 | def decode(): 77 | with tf.Session() as sess: 78 | beam_size = FLAGS.beam_size 79 | beam_search = FLAGS.beam_search 80 | model = create_model(sess, True, beam_search=beam_search, beam_size=beam_size) 81 | model.batch_size = 1 82 | data_path = 
'E:\PycharmProjects\Seq-to-Seq\seq2seq_chatbot\data\dataset-cornell-length10-filter1-vocabSize40000.pkl' 83 | word2id, id2word, trainingSamples = loadDataset(data_path) 84 | 85 | if beam_search: 86 | sys.stdout.write("> ") 87 | sys.stdout.flush() 88 | sentence = sys.stdin.readline() 89 | while sentence: 90 | batch = sentence2enco(sentence, word2id, model.en_de_seq_len) 91 | beam_path, beam_symbol = model.step(sess, batch.encoderSeqs, batch.decoderSeqs, batch.targetSeqs, 92 | batch.weights, goToken) 93 | paths = [[] for _ in range(beam_size)] 94 | curr = [i for i in range(beam_size)] 95 | num_steps = len(beam_path) 96 | for i in range(num_steps-1, -1, -1): 97 | for kk in range(beam_size): 98 | paths[kk].append(beam_symbol[i][curr[kk]]) 99 | curr[kk] = beam_path[i][curr[kk]] 100 | recos = set() 101 | print("Replies --------------------------------------->") 102 | for kk in range(beam_size): 103 | foutputs = [int(logit) for logit in paths[kk][::-1]] 104 | if eosToken in foutputs: 105 | foutputs = foutputs[:foutputs.index(eosToken)] 106 | rec = " ".join([tf.compat.as_str(id2word[output]) for output in foutputs if output in id2word]) 107 | if rec not in recos: 108 | recos.add(rec) 109 | print(rec) 110 | print("> ", "") 111 | sys.stdout.flush() 112 | sentence = sys.stdin.readline() 113 | # else: 114 | # sys.stdout.write("> ") 115 | # sys.stdout.flush() 116 | # sentence = sys.stdin.readline() 117 | # 118 | # while sentence: 119 | # # Get token-ids for the input sentence. 120 | # token_ids = sentence_to_token_ids(tf.compat.as_bytes(sentence), vocab) 121 | # # Which bucket does it belong to? 122 | # bucket_id = min([b for b in xrange(len(_buckets)) 123 | # if _buckets[b][0] > len(token_ids)]) 124 | # # for loc in locs: 125 | # # Get a 1-element batch to feed the sentence to the model. 
126 | # encoder_inputs, decoder_inputs, target_weights = model.get_batch( 127 | # {bucket_id: [(token_ids, [],)]}, bucket_id) 128 | # 129 | # _, _, output_logits = model.step(sess, encoder_inputs, decoder_inputs, 130 | # target_weights, bucket_id, True,beam_search) 131 | # # This is a greedy decoder - outputs are just argmaxes of output_logits. 132 | # 133 | # outputs = [int(np.argmax(logit, axis=1)) for logit in output_logits] 134 | # # If there is an EOS symbol in outputs, cut them at that point. 135 | # if EOS_ID in outputs: 136 | # # print outputs 137 | # outputs = outputs[:outputs.index(EOS_ID)] 138 | # 139 | # print(" ".join([tf.compat.as_str(rev_vocab[output]) for output in outputs])) 140 | # print("> ", "") 141 | # sys.stdout.flush() 142 | # sentence = sys.stdin.readline() 143 | 144 | def main(_): 145 | if FLAGS.decode: 146 | decode() 147 | else: 148 | train() 149 | 150 | if __name__ == "__main__": 151 | tf.app.run() 152 | -------------------------------------------------------------------------------- /data/dataset-cornell-length10-filter1-vocabSize40000.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lc222/seq2seq_chatbot/7a419e4e9587d9a87b4acb0e141789b329146fe1/data/dataset-cornell-length10-filter1-vocabSize40000.pkl -------------------------------------------------------------------------------- /data_utils.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | import os 6 | import nltk 7 | 8 | import pickle 9 | import random 10 | 11 | padToken, goToken, eosToken, unknownToken = 0, 1, 2, 3 12 | 13 | class Batch: 14 | #batch类,里面包含了encoder输入,decoder输入,decoder标签,decoder样本长度mask 15 | def __init__(self): 16 | self.encoderSeqs = [] 17 | self.decoderSeqs = [] 18 | self.targetSeqs = [] 19 | self.weights = [] 20 | 21 | def 
loadDataset(filename): 22 | ''' 23 | 读取样本数据 24 | :param filename: 文件路径,是一个字典,包含word2id、id2word分别是单词与索引对应的字典和反序字典, 25 | trainingSamples样本数据,每一条都是QA对 26 | :return: word2id, id2word, trainingSamples 27 | ''' 28 | dataset_path = os.path.join(filename) 29 | print('Loading dataset from {}'.format(dataset_path)) 30 | with open(dataset_path, 'rb') as handle: 31 | data = pickle.load(handle) # Warning: If adding something here, also modifying saveDataset 32 | word2id = data['word2id'] 33 | id2word = data['id2word'] 34 | trainingSamples = data['trainingSamples'] 35 | return word2id, id2word, trainingSamples 36 | 37 | def createBatch(samples, en_de_seq_len): 38 | ''' 39 | 根据给出的samples(就是一个batch的数据),进行padding并构造成placeholder所需要的数据形式 40 | :param samples: 一个batch的样本数据,列表,每个元素都是[question, answer]的形式,id 41 | :param en_de_seq_len: 列表,第一个元素表示source端序列的最大长度,第二个元素表示target端序列的最大长度 42 | :return: 处理完之后可以直接传入feed_dict的数据格式 43 | ''' 44 | batch = Batch() 45 | #根据样本长度获得batch size大小 46 | batchSize = len(samples) 47 | #将每条数据的问题和答案分开传入到相应的变量中 48 | for i in range(batchSize): 49 | sample = samples[i] 50 | batch.encoderSeqs.append(list(reversed(sample[0]))) # 将输入反序,可提高模型效果 51 | batch.decoderSeqs.append([goToken] + sample[1] + [eosToken]) # Add the and tokens 52 | batch.targetSeqs.append(batch.decoderSeqs[-1][1:]) # Same as decoder, but shifted to the left (ignore the ) 53 | # 将每个元素PAD到指定长度,并构造weights序列长度mask标志 54 | batch.encoderSeqs[i] = [padToken] * (en_de_seq_len[0] - len(batch.encoderSeqs[i])) + batch.encoderSeqs[i] 55 | batch.weights.append([1.0] * len(batch.targetSeqs[i]) + [0.0] * (en_de_seq_len[1] - len(batch.targetSeqs[i]))) 56 | batch.decoderSeqs[i] = batch.decoderSeqs[i] + [padToken] * (en_de_seq_len[1] - len(batch.decoderSeqs[i])) 57 | batch.targetSeqs[i] = batch.targetSeqs[i] + [padToken] * (en_de_seq_len[1] - len(batch.targetSeqs[i])) 58 | 59 | #--------------------接下来就是将数据进行reshape操作,变成序列长度*batch_size格式的数据------------------------ 60 | encoderSeqsT = [] # Corrected orientation 61 | for 
i in range(en_de_seq_len[0]): 62 | encoderSeqT = [] 63 | for j in range(batchSize): 64 | encoderSeqT.append(batch.encoderSeqs[j][i]) 65 | encoderSeqsT.append(encoderSeqT) 66 | batch.encoderSeqs = encoderSeqsT 67 | 68 | decoderSeqsT = [] 69 | targetSeqsT = [] 70 | weightsT = [] 71 | for i in range(en_de_seq_len[1]): 72 | decoderSeqT = [] 73 | targetSeqT = [] 74 | weightT = [] 75 | for j in range(batchSize): 76 | decoderSeqT.append(batch.decoderSeqs[j][i]) 77 | targetSeqT.append(batch.targetSeqs[j][i]) 78 | weightT.append(batch.weights[j][i]) 79 | decoderSeqsT.append(decoderSeqT) 80 | targetSeqsT.append(targetSeqT) 81 | weightsT.append(weightT) 82 | batch.decoderSeqs = decoderSeqsT 83 | batch.targetSeqs = targetSeqsT 84 | batch.weights = weightsT 85 | 86 | return batch 87 | 88 | def getBatches(data, batch_size, en_de_seq_len): 89 | ''' 90 | 根据读取出来的所有数据和batch_size将原始数据分成不同的小batch。对每个batch索引的样本调用createBatch函数进行处理 91 | :param data: loadDataset函数读取之后的trainingSamples,就是QA对的列表 92 | :param batch_size: batch大小 93 | :param en_de_seq_len: 列表,第一个元素表示source端序列的最大长度,第二个元素表示target端序列的最大长度 94 | :return: 列表,每个元素都是一个batch的样本数据,可直接传入feed_dict进行训练 95 | ''' 96 | #每个epoch之前都要进行样本的shuffle 97 | random.shuffle(data) 98 | batches = [] 99 | data_len = len(data) 100 | def genNextSamples(): 101 | for i in range(0, data_len, batch_size): 102 | yield data[i:min(i + batch_size, data_len)] 103 | 104 | for samples in genNextSamples(): 105 | batch = createBatch(samples, en_de_seq_len) 106 | batches.append(batch) 107 | return batches 108 | 109 | def sentence2enco(sentence, word2id, en_de_seq_len): 110 | ''' 111 | 测试的时候将用户输入的句子转化为可以直接feed进模型的数据,现将句子转化成id,然后调用createBatch处理 112 | :param sentence: 用户输入的句子 113 | :param word2id: 单词与id之间的对应关系字典 114 | :param en_de_seq_len: 列表,第一个元素表示source端序列的最大长度,第二个元素表示target端序列的最大长度 115 | :return: 处理之后的数据,可直接feed进模型进行预测 116 | ''' 117 | if sentence == '': 118 | return None 119 | #分词 120 | tokens = nltk.word_tokenize(sentence) 121 | if len(tokens) > en_de_seq_len[0]: 122 | 
return None 123 | #将每个单词转化为id 124 | wordIds = [] 125 | for token in tokens: 126 | wordIds.append(word2id.get(token, unknownToken)) 127 | #调用createBatch构造batch 128 | batch = createBatch([[wordIds, []]], en_de_seq_len) 129 | return batch 130 | -------------------------------------------------------------------------------- /seq2seq.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | import tensorflow as tf 6 | import copy 7 | 8 | # We disable pylint because we need python3 compatibility. 9 | from six.moves import xrange # pylint: disable=redefined-builtin 10 | from six.moves import zip # pylint: disable=redefined-builtin 11 | 12 | from tensorflow.contrib.rnn.python.ops import core_rnn_cell 13 | from tensorflow.python.framework import dtypes 14 | from tensorflow.python.framework import ops 15 | from tensorflow.python.ops import array_ops 16 | from tensorflow.python.ops import control_flow_ops 17 | from tensorflow.python.ops import embedding_ops 18 | from tensorflow.python.ops import math_ops 19 | from tensorflow.python.ops import nn_ops 20 | from tensorflow.python.ops import rnn 21 | from tensorflow.python.ops import rnn_cell_impl 22 | from tensorflow.python.ops import variable_scope 23 | from tensorflow.python.util import nest 24 | 25 | Linear = rnn_cell_impl._Linear # pylint: disable=protected-access,invalid-name 26 | 27 | def _extract_beam_search(embedding, beam_size, num_symbols, embedding_size, output_projection=None): 28 | 29 | def loop_function(prev, i, log_beam_probs, beam_path, beam_symbols): 30 | if output_projection is not None: 31 | prev = nn_ops.xw_plus_b(prev, output_projection[0], output_projection[1]) 32 | # 对输出概率进行归一化和取log,这样序列概率相乘就可以变成概率相加 33 | probs = tf.log(tf.nn.softmax(prev)) 34 | if i == 1: 35 | probs = tf.reshape(probs[0, :], [-1, num_symbols]) 36 | if i > 1: 37 | # 
将当前序列的概率与之前序列概率相加得到结果之前有beam_szie个序列,本次产生num_symbols个结果, 38 | # 所以reshape成这样的tensor 39 | probs = tf.reshape(probs + log_beam_probs[-1], [-1, beam_size * num_symbols]) 40 | # 选出概率最大的前beam_size个序列,从beam_size * num_symbols个元素中选出beam_size个 41 | best_probs, indices = tf.nn.top_k(probs, beam_size) 42 | indices = tf.stop_gradient(tf.squeeze(tf.reshape(indices, [-1, 1]))) 43 | best_probs = tf.stop_gradient(tf.reshape(best_probs, [-1, 1])) 44 | 45 | # beam_size * num_symbols,看对应的是哪个序列和单词 46 | symbols = indices % num_symbols # Which word in vocabulary. 47 | beam_parent = indices // num_symbols # Which hypothesis it came from. 48 | beam_symbols.append(symbols) 49 | beam_path.append(beam_parent) 50 | log_beam_probs.append(best_probs) 51 | 52 | # 对beam-search选出的beam size个单词进行embedding,得到相应的词向量 53 | emb_prev = embedding_ops.embedding_lookup(embedding, symbols) 54 | emb_prev = tf.reshape(emb_prev, [-1, embedding_size]) 55 | return emb_prev 56 | 57 | return loop_function 58 | 59 | def beam_attention_decoder(decoder_inputs, 60 | initial_state, 61 | attention_states, 62 | cell, 63 | embedding, 64 | output_size=None, 65 | num_heads=1, 66 | loop_function=None, 67 | dtype=None, 68 | scope=None, 69 | initial_state_attention=False, output_projection=None, beam_size=10): 70 | if not decoder_inputs: 71 | raise ValueError("Must provide at least 1 input to attention decoder.") 72 | if num_heads < 1: 73 | raise ValueError("With less than 1 heads, use a non-attention decoder.") 74 | if not attention_states.get_shape()[1:2].is_fully_defined(): 75 | raise ValueError("Shape[1] and [2] of attention_states must be known: %s" 76 | % attention_states.get_shape()) 77 | if output_size is None: 78 | output_size = cell.output_size 79 | 80 | with variable_scope.variable_scope(scope or "attention_decoder", dtype=dtype) as scope: 81 | dtype = scope.dtype 82 | # batch_size = array_ops.shape(decoder_inputs[0])[0] # Needed for reshaping. 
83 | attn_length = attention_states.get_shape()[1].value 84 | if attn_length is None: 85 | attn_length = array_ops.shape(attention_states)[1] 86 | attn_size = attention_states.get_shape()[2].value 87 | 88 | # To calculate W1 * h_t we use a 1-by-1 convolution, need to reshape before. 89 | hidden = array_ops.reshape(attention_states, [-1, attn_length, 1, attn_size]) 90 | hidden_features = [] 91 | v = [] 92 | attention_vec_size = attn_size # Size of query vectors for attention. 93 | for a in xrange(num_heads): 94 | k = variable_scope.get_variable("AttnW_%d" % a, [1, 1, attn_size, attention_vec_size]) 95 | hidden_features.append(nn_ops.conv2d(hidden, k, [1, 1, 1, 1], "SAME")) 96 | v.append(variable_scope.get_variable("AttnV_%d" % a, [attention_vec_size])) 97 | 98 | state = [] 99 | # 将encoder的最后一个隐层状态扩展成beam_size维,因为decoder阶段的batch_size是beam_size。 100 | # initial_state是一个列表,RNN有多少层就有多少个元素,每个元素都是一个LSTMStateTuple,包含h,c两个隐层状态 101 | # 所以要将其扩展成beam_size维,其实是把c和h进行扩展,最后再合成LSTMStateTuple就可以了 102 | for layers in initial_state: 103 | c = [layers.c] * beam_size 104 | h = [layers.h] * beam_size 105 | c = tf.concat(c, 0) 106 | h = tf.concat(h, 0) 107 | state.append(rnn_cell_impl.LSTMStateTuple(c, h)) 108 | state = tuple(state) 109 | # state_size = int(initial_state.get_shape().with_rank(2)[1]) 110 | # states = [] 111 | # for kk in range(beam_size): 112 | # states.append(initial_state) 113 | # state = tf.concat(states, 0) 114 | # state = initial_state 115 | 116 | def attention(query): 117 | ds = [] # Results of attention reads will be stored here. 118 | if nest.is_sequence(query): # If the query is a tuple, flatten it. 119 | query_list = nest.flatten(query) 120 | for q in query_list: # Check that ndims == 2 if specified. 
121 | ndims = q.get_shape().ndims 122 | if ndims: 123 | assert ndims == 2 124 | query = array_ops.concat(query_list, 1) 125 | for a in xrange(num_heads): 126 | with variable_scope.variable_scope("Attention_%d" % a): 127 | y = Linear(query, attention_vec_size, True)(query) 128 | y = array_ops.reshape(y, [-1, 1, 1, attention_vec_size]) 129 | # Attention mask is a softmax of v^T * tanh(...). 130 | s = math_ops.reduce_sum(v[a] * math_ops.tanh(hidden_features[a] + y), [2, 3]) 131 | a = nn_ops.softmax(s) 132 | # Now calculate the attention-weighted vector d. 133 | d = math_ops.reduce_sum(array_ops.reshape(a, [-1, attn_length, 1, 1]) * hidden, [1, 2]) 134 | ds.append(array_ops.reshape(d, [-1, attn_size])) 135 | return ds 136 | 137 | outputs = [] 138 | prev = None 139 | # attention也要定义成beam_size为的tensor 140 | batch_attn_size = array_ops.stack([beam_size, attn_size]) 141 | attns = [array_ops.zeros(batch_attn_size, dtype=dtype) for _ in xrange(num_heads)] 142 | for a in attns: # Ensure the second shape of attention vectors is set. 143 | a.set_shape([None, attn_size]) 144 | if initial_state_attention: 145 | attns = attention(initial_state) 146 | 147 | log_beam_probs, beam_path, beam_symbols = [], [], [] 148 | for i, inp in enumerate(decoder_inputs): 149 | if i > 0: 150 | variable_scope.get_variable_scope().reuse_variables() 151 | # If loop_function is set, we use it instead of decoder_inputs. 152 | if i == 0: 153 | #i=0时,输入时一个batch_szie=beam_size的tensor,且里面每个元素的值都是相同的,都是标志 154 | inp = tf.nn.embedding_lookup(embedding, tf.constant(1, dtype=tf.int32, shape=[beam_size])) 155 | 156 | if loop_function is not None and prev is not None: 157 | with variable_scope.variable_scope("loop_function", reuse=True): 158 | inp = loop_function(prev, i, log_beam_probs, beam_path, beam_symbols) 159 | # Merge input and previous attentions into one vector of the right size. 
160 | input_size = inp.get_shape().with_rank(2)[1] 161 | if input_size.value is None: 162 | raise ValueError("Could not infer input size from input: %s" % inp.name) 163 | inputs = [inp] + attns 164 | x = Linear(inputs, input_size, True)(inputs) 165 | 166 | # Run the RNN. 167 | cell_output, state = cell(x, state) 168 | # Run the attention mechanism. 169 | if i == 0 and initial_state_attention: 170 | with variable_scope.variable_scope(variable_scope.get_variable_scope(), reuse=True): 171 | attns = attention(state) 172 | else: 173 | attns = attention(state) 174 | 175 | with variable_scope.variable_scope("AttnOutputProjection"): 176 | inputs = [cell_output] + attns 177 | output = Linear(inputs, output_size, True)(inputs) 178 | if loop_function is not None: 179 | prev = output 180 | outputs.append(tf.argmax(nn_ops.xw_plus_b(output, output_projection[0], output_projection[1]), axis=1)) 181 | 182 | return outputs, state, tf.reshape(tf.concat(beam_path, 0), [-1, beam_size]), tf.reshape(tf.concat(beam_symbols, 0), 183 | [-1, beam_size]) 184 | 185 | def embedding_attention_decoder(decoder_inputs, 186 | initial_state, 187 | attention_states, 188 | cell, 189 | num_symbols, 190 | embedding_size, 191 | num_heads=1, 192 | output_size=None, 193 | output_projection=None, 194 | feed_previous=False, 195 | update_embedding_for_previous=True, 196 | dtype=None, 197 | scope=None, 198 | initial_state_attention=False, beam_search=True, beam_size=10): 199 | if output_size is None: 200 | output_size = cell.output_size 201 | if output_projection is not None: 202 | proj_biases = ops.convert_to_tensor(output_projection[1], dtype=dtype) 203 | proj_biases.get_shape().assert_is_compatible_with([num_symbols]) 204 | 205 | with variable_scope.variable_scope(scope or "embedding_attention_decoder", dtype=dtype) as scope: 206 | embedding = variable_scope.get_variable("embedding", [num_symbols, embedding_size]) 207 | emb_inp = [embedding_ops.embedding_lookup(embedding, i) for i in decoder_inputs] 208 | 
loop_function = _extract_beam_search(embedding, beam_size, num_symbols, embedding_size, output_projection) 209 | return beam_attention_decoder( 210 | emb_inp, initial_state, attention_states, cell, embedding, output_size=output_size, 211 | num_heads=num_heads, loop_function=loop_function, 212 | initial_state_attention=initial_state_attention, output_projection=output_projection, 213 | beam_size=beam_size) 214 | 215 | 216 | def embedding_attention_seq2seq(encoder_inputs, 217 | decoder_inputs, 218 | cell, 219 | num_encoder_symbols, 220 | num_decoder_symbols, 221 | embedding_size, 222 | num_heads=1, 223 | output_projection=None, 224 | feed_previous=False, 225 | dtype=None, 226 | scope=None, 227 | initial_state_attention=False, beam_search=True, beam_size=10): 228 | with variable_scope.variable_scope(scope or "embedding_attention_seq2seq", dtype=dtype) as scope: 229 | dtype = scope.dtype 230 | # Encoder. 231 | encoder_cell = copy.deepcopy(cell) 232 | encoder_cell = core_rnn_cell.EmbeddingWrapper(encoder_cell, embedding_classes=num_encoder_symbols, embedding_size=embedding_size) 233 | encoder_outputs, encoder_state = rnn.static_rnn(encoder_cell, encoder_inputs, dtype=dtype) 234 | 235 | # First calculate a concatenation of encoder outputs to put attention on. 236 | top_states = [array_ops.reshape(e, [-1, 1, cell.output_size]) for e in encoder_outputs] 237 | attention_states = array_ops.concat(top_states, 1) 238 | 239 | # Decoder. 
240 | output_size = None 241 | if output_projection is None: 242 | cell = core_rnn_cell.OutputProjectionWrapper(cell, num_decoder_symbols) 243 | output_size = num_decoder_symbols 244 | 245 | return embedding_attention_decoder( 246 | decoder_inputs, 247 | encoder_state, 248 | attention_states, 249 | cell, 250 | num_decoder_symbols, 251 | embedding_size, 252 | num_heads=num_heads, 253 | output_size=output_size, 254 | output_projection=output_projection, 255 | feed_previous=feed_previous, 256 | initial_state_attention=initial_state_attention, beam_search=beam_search, beam_size=beam_size) 257 | 258 | -------------------------------------------------------------------------------- /seq2seq_model.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | from seq2seq import embedding_attention_seq2seq 3 | class Seq2SeqModel(): 4 | 5 | def __init__(self, source_vocab_size, target_vocab_size, en_de_seq_len, hidden_size, num_layers, 6 | batch_size, learning_rate, num_samples=1024, 7 | forward_only=False, beam_search=True, beam_size=10): 8 | ''' 9 | 初始化并创建模型 10 | :param source_vocab_size:encoder输入的vocab size 11 | :param target_vocab_size: decoder输入的vocab size,这里跟上面一样 12 | :param en_de_seq_len: 源和目的序列最大长度 13 | :param hidden_size: RNN模型的隐藏层单元个数 14 | :param num_layers: RNN堆叠的层数 15 | :param batch_size: batch大小 16 | :param learning_rate: 学习率 17 | :param num_samples: 计算loss时做sampled softmax时的采样数 18 | :param forward_only: 预测时指定为真 19 | :param beam_search: 预测时是采用greedy search还是beam search 20 | :param beam_size: beam search的大小 21 | ''' 22 | self.source_vocab_size = source_vocab_size 23 | self.target_vocab_size = target_vocab_size 24 | self.en_de_seq_len = en_de_seq_len 25 | self.hidden_size = hidden_size 26 | self.num_layers = num_layers 27 | self.batch_size = batch_size 28 | self.learning_rate = tf.Variable(float(learning_rate), trainable=False) 29 | self.num_samples = num_samples 30 | self.forward_only = forward_only 31 | 
self.beam_search = beam_search 32 | self.beam_size = beam_size 33 | self.global_step = tf.Variable(0, trainable=False) 34 | 35 | output_projection = None 36 | softmax_loss_function = None 37 | # 定义采样loss函数,传入后面的sequence_loss_by_example函数 38 | if num_samples > 0 and num_samples < self.target_vocab_size: 39 | w = tf.get_variable('proj_w', [hidden_size, self.target_vocab_size]) 40 | w_t = tf.transpose(w) 41 | b = tf.get_variable('proj_b', [self.target_vocab_size]) 42 | output_projection = (w, b) 43 | #调用sampled_softmax_loss函数计算sample loss,这样可以节省计算时间 44 | def sample_loss(logits, labels): 45 | labels = tf.reshape(labels, [-1, 1]) 46 | return tf.nn.sampled_softmax_loss(w_t, b, labels=labels, inputs=logits, num_sampled=num_samples, num_classes=self.target_vocab_size) 47 | softmax_loss_function = sample_loss 48 | 49 | self.keep_drop = tf.placeholder(tf.float32) 50 | # 定义encoder和decoder阶段的多层dropout RNNCell 51 | def create_rnn_cell(): 52 | encoDecoCell = tf.contrib.rnn.BasicLSTMCell(hidden_size) 53 | encoDecoCell = tf.contrib.rnn.DropoutWrapper(encoDecoCell, input_keep_prob=1.0, output_keep_prob=self.keep_drop) 54 | return encoDecoCell 55 | encoCell = tf.contrib.rnn.MultiRNNCell([create_rnn_cell() for _ in range(num_layers)]) 56 | 57 | # 定义输入的placeholder,采用了列表的形式 58 | self.encoder_inputs = [] 59 | self.decoder_inputs = [] 60 | self.decoder_targets = [] 61 | self.target_weights = [] 62 | for i in range(en_de_seq_len[0]): 63 | self.encoder_inputs.append(tf.placeholder(tf.int32, shape=[None, ], name="encoder{0}".format(i))) 64 | for i in range(en_de_seq_len[1]): 65 | self.decoder_inputs.append(tf.placeholder(tf.int32, shape=[None, ], name="decoder{0}".format(i))) 66 | self.decoder_targets.append(tf.placeholder(tf.int32, shape=[None, ], name="target{0}".format(i))) 67 | self.target_weights.append(tf.placeholder(tf.float32, shape=[None, ], name="weight{0}".format(i))) 68 | 69 | # test模式,将上一时刻输出当做下一时刻输入传入 70 | if forward_only: 71 | if 
beam_search:#如果是beam_search的话,则调用自己写的embedding_attention_seq2seq函数,而不是legacy_seq2seq下面的 72 | self.beam_outputs, _, self.beam_path, self.beam_symbol = embedding_attention_seq2seq( 73 | self.encoder_inputs, self.decoder_inputs, encoCell, num_encoder_symbols=source_vocab_size, 74 | num_decoder_symbols=target_vocab_size, embedding_size=hidden_size, 75 | output_projection=output_projection, feed_previous=True) 76 | else: 77 | decoder_outputs, _ = tf.contrib.legacy_seq2seq.embedding_attention_seq2seq( 78 | self.encoder_inputs, self.decoder_inputs, encoCell, num_encoder_symbols=source_vocab_size, 79 | num_decoder_symbols=target_vocab_size, embedding_size=hidden_size, 80 | output_projection=output_projection, feed_previous=True) 81 | # 因为seq2seq模型中未指定output_projection,所以需要在输出之后自己进行output_projection 82 | if output_projection is not None: 83 | self.outputs = tf.matmul(decoder_outputs, output_projection[0]) + output_projection[1] 84 | else: 85 | # 因为不需要将output作为下一时刻的输入,所以不用output_projection 86 | decoder_outputs, _ = tf.contrib.legacy_seq2seq.embedding_attention_seq2seq( 87 | self.encoder_inputs, self.decoder_inputs, encoCell, num_encoder_symbols=source_vocab_size, 88 | num_decoder_symbols=target_vocab_size, embedding_size=hidden_size, output_projection=output_projection, 89 | feed_previous=False) 90 | self.loss = tf.contrib.legacy_seq2seq.sequence_loss( 91 | decoder_outputs, self.decoder_targets, self.target_weights, softmax_loss_function=softmax_loss_function) 92 | 93 | # Initialize the optimizer 94 | opt = tf.train.AdamOptimizer(learning_rate=self.learning_rate, beta1=0.9, beta2=0.999, epsilon=1e-08) 95 | self.optOp = opt.minimize(self.loss) 96 | 97 | self.saver = tf.train.Saver(tf.all_variables()) 98 | 99 | def step(self, session, encoder_inputs, decoder_inputs, decoder_targets, target_weights, go_token_id): 100 | #传入一个batch的数据,并训练性对应的模型 101 | # 构建sess.run时的feed_inpits 102 | feed_dict = {} 103 | if not self.forward_only: 104 | feed_dict[self.keep_drop] = 0.5 105 | for i in 
range(self.en_de_seq_len[0]): 106 | feed_dict[self.encoder_inputs[i].name] = encoder_inputs[i] 107 | for i in range(self.en_de_seq_len[1]): 108 | feed_dict[self.decoder_inputs[i].name] = decoder_inputs[i] 109 | feed_dict[self.decoder_targets[i].name] = decoder_targets[i] 110 | feed_dict[self.target_weights[i].name] = target_weights[i] 111 | run_ops = [self.optOp, self.loss] 112 | else: 113 | feed_dict[self.keep_drop] = 1.0 114 | for i in range(self.en_de_seq_len[0]): 115 | feed_dict[self.encoder_inputs[i].name] = encoder_inputs[i] 116 | feed_dict[self.decoder_inputs[0].name] = [go_token_id] 117 | if self.beam_search: 118 | run_ops = [self.beam_path, self.beam_symbol] 119 | else: 120 | run_ops = [self.outputs] 121 | 122 | outputs = session.run(run_ops, feed_dict) 123 | if not self.forward_only: 124 | return None, outputs[1] 125 | else: 126 | if self.beam_search: 127 | return outputs[0], outputs[1] --------------------------------------------------------------------------------