├── README.md
├── __init__.py
├── chatbot.py
├── data
│   └── dataset-cornell-length10-filter1-vocabSize40000.pkl
├── data_utils.py
├── seq2seq.py
└── seq2seq_model.py
/README.md:
--------------------------------------------------------------------------------
================================================= Update ===========================================================

The trained model has been uploaded to Baidu Cloud; feel free to download it if you need it. As for training speed: on a CPU with 16 GB of RAM, training finishes in about one day.

Link: https://pan.baidu.com/s/1hrNxaSk  Password: d2sn

================================================= Divider, main text below ===============================================

This project is a simple TensorFlow implementation of a chatbot dialogue system based on the seq2seq model.

For a walkthrough of the code, see my Zhihu column article:

[Implementing a deep-learning dialogue system from scratch -- a simple chatbot implementation](https://zhuanlan.zhihu.com/p/32455898)

The code is based on DeepQA, with beam search and an attention mechanism added on top.

The final result looks like this:

(screenshot omitted)

(screenshot omitted)

Test output, i.e. the beam_size most probable replies to the user's input:

(screenshot omitted)

# Usage

1. Download the code (the data folder already contains the preprocessed dataset, so there is nothing extra to download).

2. Train the model: change the decode flag on line 34 of chatbot.py to False, then run the training.

(I will upload my trained model afterwards so everyone can use it.)

3. Once training is done (roughly one day, 30 epochs), change decode back to True

and you can start testing. Type whatever you want to ask and see what it replies.

One more thing to note: remember to update the absolute paths to the dataset and the final model file, otherwise you may get errors.

They appear in three places: lines 44, 57, and 82 of chatbot.py. That's it -- have fun!
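Flags defined with tf.app.flags can normally also be overridden on the command line, since tf.app.run() parses them from argv. This is standard TensorFlow 1.x behavior offered here as a convenience, not something this project documents:

```
python chatbot.py --decode=False                 # train
python chatbot.py --decode=True                  # chat interactively (the default)
python chatbot.py --decode=True --beam_size=10   # keep 10 candidate replies
```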
--------------------------------------------------------------------------------
/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lc222/seq2seq_chatbot/7a419e4e9587d9a87b4acb0e141789b329146fe1/__init__.py
--------------------------------------------------------------------------------
/chatbot.py:
--------------------------------------------------------------------------------
1 | """Most of the code comes from seq2seq tutorial. Binary for training conversation models and decoding from them.
2 |
3 | Running this program without --decode will tokenize it in a very basic way,
4 | and then start training a model saving checkpoints to --train_dir.
5 |
6 | Running with --decode starts an interactive loop so you can see how
7 | the current checkpoint performs
8 |
9 | See the following papers for more information on neural translation models.
10 | * http://arxiv.org/abs/1409.3215
11 | * http://arxiv.org/abs/1409.0473
12 | * http://arxiv.org/abs/1412.2007
13 | """
14 |
15 | import math
16 | import sys
17 | import time
18 | from data_utils import *
19 | from seq2seq_model import *
20 | from tqdm import tqdm
21 |
22 | tf.app.flags.DEFINE_float("learning_rate", 0.001, "Learning rate.")
23 | tf.app.flags.DEFINE_integer("batch_size", 256, "Batch size to use during training.")
24 | tf.app.flags.DEFINE_integer("numEpochs", 30, "Batch size to use during training.")
25 | tf.app.flags.DEFINE_integer("size", 512, "Size of each model layer.")
26 | tf.app.flags.DEFINE_integer("num_layers", 3, "Number of layers in the model.")
27 | tf.app.flags.DEFINE_integer("en_vocab_size", 40000, "English vocabulary size.")
28 | tf.app.flags.DEFINE_integer("en_de_seq_len", 20, "English vocabulary size.")
29 | tf.app.flags.DEFINE_integer("max_train_data_size", 0, "Limit on the size of training data (0: no limit).")
30 | tf.app.flags.DEFINE_integer("steps_per_checkpoint", 100, "How many training steps to do per checkpoint.")
31 | tf.app.flags.DEFINE_string("train_dir", './tmp', "How many training steps to do per checkpoint.")
32 | tf.app.flags.DEFINE_integer("beam_size", 5, "How many training steps to do per checkpoint.")
33 | tf.app.flags.DEFINE_boolean("beam_search", True, "Set to True for beam_search.")
34 | tf.app.flags.DEFINE_boolean("decode", True, "Set to True for interactive decoding.")
35 | FLAGS = tf.app.flags.FLAGS
36 |
37 | def create_model(session, forward_only, beam_search, beam_size = 5):
38 | """Create translation model and initialize or load parameters in session."""
39 | model = Seq2SeqModel(
40 | FLAGS.en_vocab_size, FLAGS.en_vocab_size, [10, 10],
41 | FLAGS.size, FLAGS.num_layers, FLAGS.batch_size,
42 | FLAGS.learning_rate, forward_only=forward_only, beam_search=beam_search, beam_size=beam_size)
43 | ckpt = tf.train.latest_checkpoint(FLAGS.train_dir)
44 | model_path = 'E:\PycharmProjects\Seq-to-Seq\seq2seq_chatbot\\tmp\chat_bot.ckpt-0'
45 | if forward_only:
46 | model.saver.restore(session, model_path)
47 | elif ckpt and tf.gfile.Exists(ckpt.model_checkpoint_path):
48 | print("Reading model parameters from %s" % ckpt.model_checkpoint_path)
49 | model.saver.restore(session, ckpt.model_checkpoint_path)
50 | else:
51 | print("Created model with fresh parameters.")
52 | session.run(tf.initialize_all_variables())
53 | return model
54 |
55 | def train():
56 | # prepare dataset
57 | data_path = 'E:\PycharmProjects\Seq-to-Seq\seq2seq_chatbot\data\dataset-cornell-length10-filter1-vocabSize40000.pkl'
58 | word2id, id2word, trainingSamples = loadDataset(data_path)
59 | with tf.Session() as sess:
60 | print("Creating %d layers of %d units." % (FLAGS.num_layers, FLAGS.size))
61 | model = create_model(sess, False, beam_search=False, beam_size=5)
62 | current_step = 0
63 | for e in range(FLAGS.numEpochs):
64 | print("----- Epoch {}/{} -----".format(e + 1, FLAGS.numEpochs))
65 | batches = getBatches(trainingSamples, FLAGS.batch_size, model.en_de_seq_len)
66 | for nextBatch in tqdm(batches, desc="Training"):
67 | _, step_loss = model.step(sess, nextBatch.encoderSeqs, nextBatch.decoderSeqs, nextBatch.targetSeqs,
68 | nextBatch.weights, goToken)
69 | current_step += 1
70 | if current_step % FLAGS.steps_per_checkpoint == 0:
71 | perplexity = math.exp(float(step_loss)) if step_loss < 300 else float('inf')
72 | tqdm.write("----- Step %d -- Loss %.2f -- Perplexity %.2f" % (current_step, step_loss, perplexity))
73 | checkpoint_path = os.path.join(FLAGS.train_dir, "chat_bot.ckpt")
74 | model.saver.save(sess, checkpoint_path, global_step=model.global_step)
75 |
76 | def decode():
77 | with tf.Session() as sess:
78 | beam_size = FLAGS.beam_size
79 | beam_search = FLAGS.beam_search
80 | model = create_model(sess, True, beam_search=beam_search, beam_size=beam_size)
81 | model.batch_size = 1
82 | data_path = 'E:\PycharmProjects\Seq-to-Seq\seq2seq_chatbot\data\dataset-cornell-length10-filter1-vocabSize40000.pkl'
83 | word2id, id2word, trainingSamples = loadDataset(data_path)
84 |
85 | if beam_search:
86 | sys.stdout.write("> ")
87 | sys.stdout.flush()
88 | sentence = sys.stdin.readline()
89 | while sentence:
90 | batch = sentence2enco(sentence, word2id, model.en_de_seq_len)
91 | beam_path, beam_symbol = model.step(sess, batch.encoderSeqs, batch.decoderSeqs, batch.targetSeqs,
92 | batch.weights, goToken)
93 | paths = [[] for _ in range(beam_size)]
94 | curr = [i for i in range(beam_size)]
95 | num_steps = len(beam_path)
96 | for i in range(num_steps-1, -1, -1):
97 | for kk in range(beam_size):
98 | paths[kk].append(beam_symbol[i][curr[kk]])
99 | curr[kk] = beam_path[i][curr[kk]]
100 | recos = set()
101 | print("Replies --------------------------------------->")
102 | for kk in range(beam_size):
103 | foutputs = [int(logit) for logit in paths[kk][::-1]]
104 | if eosToken in foutputs:
105 | foutputs = foutputs[:foutputs.index(eosToken)]
106 | rec = " ".join([tf.compat.as_str(id2word[output]) for output in foutputs if output in id2word])
107 | if rec not in recos:
108 | recos.add(rec)
109 | print(rec)
110 | print("> ", "")
111 | sys.stdout.flush()
112 | sentence = sys.stdin.readline()
113 | # else:
114 | # sys.stdout.write("> ")
115 | # sys.stdout.flush()
116 | # sentence = sys.stdin.readline()
117 | #
118 | # while sentence:
119 | # # Get token-ids for the input sentence.
120 | # token_ids = sentence_to_token_ids(tf.compat.as_bytes(sentence), vocab)
121 | # # Which bucket does it belong to?
122 | # bucket_id = min([b for b in xrange(len(_buckets))
123 | # if _buckets[b][0] > len(token_ids)])
124 | # # for loc in locs:
125 | # # Get a 1-element batch to feed the sentence to the model.
126 | # encoder_inputs, decoder_inputs, target_weights = model.get_batch(
127 | # {bucket_id: [(token_ids, [],)]}, bucket_id)
128 | #
129 | # _, _, output_logits = model.step(sess, encoder_inputs, decoder_inputs,
130 | # target_weights, bucket_id, True,beam_search)
131 | # # This is a greedy decoder - outputs are just argmaxes of output_logits.
132 | #
133 | # outputs = [int(np.argmax(logit, axis=1)) for logit in output_logits]
134 | # # If there is an EOS symbol in outputs, cut them at that point.
135 | # if EOS_ID in outputs:
136 | # # print outputs
137 | # outputs = outputs[:outputs.index(EOS_ID)]
138 | #
139 | # print(" ".join([tf.compat.as_str(rev_vocab[output]) for output in outputs]))
140 | # print("> ", "")
141 | # sys.stdout.flush()
142 | # sentence = sys.stdin.readline()
143 |
144 | def main(_):
145 | if FLAGS.decode:
146 | decode()
147 | else:
148 | train()
149 |
150 | if __name__ == "__main__":
151 | tf.app.run()
152 |
--------------------------------------------------------------------------------
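A note on the decode() loop above: lines 93-99 rebuild each beam hypothesis by walking the parent pointers backwards. A standalone sketch of that backtracking with a hypothetical toy beam (pure Python, no TensorFlow needed):

```python
# beam_symbol[t][k] is the word id chosen by beam slot k at time step t;
# beam_path[t][k] is the slot's parent at step t-1. Walking from the last
# step back to step 0 recovers each hypothesis.
def backtrack(beam_path, beam_symbol, beam_size):
    paths = [[] for _ in range(beam_size)]
    curr = list(range(beam_size))
    for t in range(len(beam_path) - 1, -1, -1):
        for k in range(beam_size):
            paths[k].append(beam_symbol[t][curr[k]])
            curr[k] = beam_path[t][curr[k]]
    return [p[::-1] for p in paths]  # reverse into left-to-right order

# Toy example (2 beams, 3 steps; the ids are made up):
print(backtrack(beam_path=[[0, 0], [0, 0], [0, 1]],
                beam_symbol=[[5, 7], [9, 4], [2, 8]],
                beam_size=2))
# -> [[5, 9, 2], [5, 4, 8]]
```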
/data/dataset-cornell-length10-filter1-vocabSize40000.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lc222/seq2seq_chatbot/7a419e4e9587d9a87b4acb0e141789b329146fe1/data/dataset-cornell-length10-filter1-vocabSize40000.pkl
--------------------------------------------------------------------------------
/data_utils.py:
--------------------------------------------------------------------------------
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
import nltk

import pickle
import random

padToken, goToken, eosToken, unknownToken = 0, 1, 2, 3

class Batch:
    # A batch holds the encoder inputs, decoder inputs, decoder targets, and
    # the weight mask derived from the decoder sample lengths.
    def __init__(self):
        self.encoderSeqs = []
        self.decoderSeqs = []
        self.targetSeqs = []
        self.weights = []

def loadDataset(filename):
    '''
    Load the sample data.
    :param filename: path to a pickled dict containing word2id and id2word (the
                     word-to-index dict and its inverse) and trainingSamples,
                     a list of QA pairs
    :return: word2id, id2word, trainingSamples
    '''
    dataset_path = os.path.join(filename)
    print('Loading dataset from {}'.format(dataset_path))
    with open(dataset_path, 'rb') as handle:
        data = pickle.load(handle)  # Warning: if adding something here, also modify saveDataset
        word2id = data['word2id']
        id2word = data['id2word']
        trainingSamples = data['trainingSamples']
    return word2id, id2word, trainingSamples

def createBatch(samples, en_de_seq_len):
    '''
    Pad the given samples (one batch of data) and arrange them into the form
    required by the placeholders.
    :param samples: one batch of samples; a list whose elements have the form
                    [question, answer], as word ids
    :param en_de_seq_len: list; the first element is the maximum source sequence
                          length, the second the maximum target sequence length
    :return: data in a format that can be fed directly into feed_dict
    '''
    batch = Batch()
    # The batch size is simply the number of samples.
    batchSize = len(samples)
    # Split each sample's question and answer into the corresponding fields.
    for i in range(batchSize):
        sample = samples[i]
        batch.encoderSeqs.append(list(reversed(sample[0])))  # Reversing the input improves model quality.
        batch.decoderSeqs.append([goToken] + sample[1] + [eosToken])  # Add the <go> and <eos> tokens.
        batch.targetSeqs.append(batch.decoderSeqs[-1][1:])  # Same as decoder, but shifted to the left (drop the <go>).
        # Pad each element to the given length and build the weight mask that flags the real sequence positions.
        batch.encoderSeqs[i] = [padToken] * (en_de_seq_len[0] - len(batch.encoderSeqs[i])) + batch.encoderSeqs[i]
        batch.weights.append([1.0] * len(batch.targetSeqs[i]) + [0.0] * (en_de_seq_len[1] - len(batch.targetSeqs[i])))
        batch.decoderSeqs[i] = batch.decoderSeqs[i] + [padToken] * (en_de_seq_len[1] - len(batch.decoderSeqs[i]))
        batch.targetSeqs[i] = batch.targetSeqs[i] + [padToken] * (en_de_seq_len[1] - len(batch.targetSeqs[i]))

    # -------- Reshape the data into time-major order, i.e. seq_len * batch_size --------
    encoderSeqsT = []  # Corrected orientation
    for i in range(en_de_seq_len[0]):
        encoderSeqT = []
        for j in range(batchSize):
            encoderSeqT.append(batch.encoderSeqs[j][i])
        encoderSeqsT.append(encoderSeqT)
    batch.encoderSeqs = encoderSeqsT

    decoderSeqsT = []
    targetSeqsT = []
    weightsT = []
    for i in range(en_de_seq_len[1]):
        decoderSeqT = []
        targetSeqT = []
        weightT = []
        for j in range(batchSize):
            decoderSeqT.append(batch.decoderSeqs[j][i])
            targetSeqT.append(batch.targetSeqs[j][i])
            weightT.append(batch.weights[j][i])
        decoderSeqsT.append(decoderSeqT)
        targetSeqsT.append(targetSeqT)
        weightsT.append(weightT)
    batch.decoderSeqs = decoderSeqsT
    batch.targetSeqs = targetSeqsT
    batch.weights = weightsT

    return batch

def getBatches(data, batch_size, en_de_seq_len):
    '''
    Split the full dataset into batches of batch_size, calling createBatch on
    the samples of each batch.
    :param data: trainingSamples as returned by loadDataset, i.e. the list of QA pairs
    :param batch_size: batch size
    :param en_de_seq_len: list; the first element is the maximum source sequence
                          length, the second the maximum target sequence length
    :return: a list of batches, each ready to be passed into feed_dict for training
    '''
    # Shuffle the samples before every epoch.
    random.shuffle(data)
    batches = []
    data_len = len(data)
    def genNextSamples():
        for i in range(0, data_len, batch_size):
            yield data[i:min(i + batch_size, data_len)]

    for samples in genNextSamples():
        batch = createBatch(samples, en_de_seq_len)
        batches.append(batch)
    return batches

def sentence2enco(sentence, word2id, en_de_seq_len):
    '''
    At test time, convert the user's input sentence into data that can be fed
    directly into the model: first map the sentence to ids, then call createBatch.
    :param sentence: the sentence typed by the user
    :param word2id: dict mapping words to ids
    :param en_de_seq_len: list; the first element is the maximum source sequence
                          length, the second the maximum target sequence length
    :return: the processed data, ready to be fed into the model for prediction
    '''
    if sentence == '':
        return None
    # Tokenize.
    tokens = nltk.word_tokenize(sentence)
    if len(tokens) > en_de_seq_len[0]:
        return None
    # Map each word to its id.
    wordIds = []
    for token in tokens:
        wordIds.append(word2id.get(token, unknownToken))
    # Build a batch via createBatch.
    batch = createBatch([[wordIds, []]], en_de_seq_len)
    return batch
--------------------------------------------------------------------------------
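To make the reversal, padding, and time-major transpose in createBatch concrete, here is a small worked example (the word ids are made up; only the special-token values come from the module):

```python
from data_utils import createBatch

# One sample: question [4, 8, 15], answer [16, 23], with en_de_seq_len = [5, 6].
batch = createBatch([[[4, 8, 15], [16, 23]]], en_de_seq_len=[5, 6])

# Encoder input is reversed, left-padded to length 5, then transposed to [time][batch]:
print(batch.encoderSeqs)  # [[0], [0], [15], [8], [4]]
# Decoder input gets <go>=1 and <eos>=2 plus right-padding to length 6:
print(batch.decoderSeqs)  # [[1], [16], [23], [2], [0], [0]]
# Targets are the decoder inputs shifted left by one position:
print(batch.targetSeqs)   # [[16], [23], [2], [0], [0], [0]]
# Weights mask out the padded target positions:
print(batch.weights)      # [[1.0], [1.0], [1.0], [0.0], [0.0], [0.0]]
```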
/seq2seq.py:
--------------------------------------------------------------------------------
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import tensorflow as tf
import copy

# We disable pylint because we need python3 compatibility.
from six.moves import xrange  # pylint: disable=redefined-builtin
from six.moves import zip  # pylint: disable=redefined-builtin

from tensorflow.contrib.rnn.python.ops import core_rnn_cell
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import embedding_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn_ops
from tensorflow.python.ops import rnn
from tensorflow.python.ops import rnn_cell_impl
from tensorflow.python.ops import variable_scope
from tensorflow.python.util import nest

Linear = rnn_cell_impl._Linear  # pylint: disable=protected-access,invalid-name

def _extract_beam_search(embedding, beam_size, num_symbols, embedding_size, output_projection=None):

    def loop_function(prev, i, log_beam_probs, beam_path, beam_symbols):
        if output_projection is not None:
            prev = nn_ops.xw_plus_b(prev, output_projection[0], output_projection[1])
        # Normalize the output distribution and take its log, so that multiplying
        # sequence probabilities turns into adding log-probabilities.
        probs = tf.log(tf.nn.softmax(prev))
        if i == 1:
            probs = tf.reshape(probs[0, :], [-1, num_symbols])
        if i > 1:
            # Add this step's log-probs to the running scores of the previous
            # sequences. There were beam_size sequences before this step and each
            # produces num_symbols candidates, hence the reshape below.
            probs = tf.reshape(probs + log_beam_probs[-1], [-1, beam_size * num_symbols])
        # Keep the beam_size highest-scoring sequences out of the
        # beam_size * num_symbols candidates.
        best_probs, indices = tf.nn.top_k(probs, beam_size)
        indices = tf.stop_gradient(tf.squeeze(tf.reshape(indices, [-1, 1])))
        best_probs = tf.stop_gradient(tf.reshape(best_probs, [-1, 1]))

        # The indices run over beam_size * num_symbols; recover the sequence and the word.
        symbols = indices % num_symbols  # Which word in vocabulary.
        beam_parent = indices // num_symbols  # Which hypothesis it came from.
        beam_symbols.append(symbols)
        beam_path.append(beam_parent)
        log_beam_probs.append(best_probs)

        # Embed the beam_size words picked by beam search to get their word vectors.
        emb_prev = embedding_ops.embedding_lookup(embedding, symbols)
        emb_prev = tf.reshape(emb_prev, [-1, embedding_size])
        return emb_prev

    return loop_function

def beam_attention_decoder(decoder_inputs,
                           initial_state,
                           attention_states,
                           cell,
                           embedding,
                           output_size=None,
                           num_heads=1,
                           loop_function=None,
                           dtype=None,
                           scope=None,
                           initial_state_attention=False, output_projection=None, beam_size=10):
    if not decoder_inputs:
        raise ValueError("Must provide at least 1 input to attention decoder.")
    if num_heads < 1:
        raise ValueError("With less than 1 heads, use a non-attention decoder.")
    if not attention_states.get_shape()[1:2].is_fully_defined():
        raise ValueError("Shape[1] and [2] of attention_states must be known: %s"
                         % attention_states.get_shape())
    if output_size is None:
        output_size = cell.output_size

    with variable_scope.variable_scope(scope or "attention_decoder", dtype=dtype) as scope:
        dtype = scope.dtype
        # batch_size = array_ops.shape(decoder_inputs[0])[0]  # Needed for reshaping.
        attn_length = attention_states.get_shape()[1].value
        if attn_length is None:
            attn_length = array_ops.shape(attention_states)[1]
        attn_size = attention_states.get_shape()[2].value

        # To calculate W1 * h_t we use a 1-by-1 convolution, need to reshape before.
        hidden = array_ops.reshape(attention_states, [-1, attn_length, 1, attn_size])
        hidden_features = []
        v = []
        attention_vec_size = attn_size  # Size of query vectors for attention.
        for a in xrange(num_heads):
            k = variable_scope.get_variable("AttnW_%d" % a, [1, 1, attn_size, attention_vec_size])
            hidden_features.append(nn_ops.conv2d(hidden, k, [1, 1, 1, 1], "SAME"))
            v.append(variable_scope.get_variable("AttnV_%d" % a, [attention_vec_size]))

        state = []
        # Tile the encoder's final hidden state beam_size times, because during
        # decoding the batch size equals beam_size. initial_state is a list with
        # one element per RNN layer, each an LSTMStateTuple holding the c and h
        # states, so tile c and h separately and then rebuild the LSTMStateTuple.
        for layers in initial_state:
            c = [layers.c] * beam_size
            h = [layers.h] * beam_size
            c = tf.concat(c, 0)
            h = tf.concat(h, 0)
            state.append(rnn_cell_impl.LSTMStateTuple(c, h))
        state = tuple(state)

        def attention(query):
            ds = []  # Results of attention reads will be stored here.
            if nest.is_sequence(query):  # If the query is a tuple, flatten it.
                query_list = nest.flatten(query)
                for q in query_list:  # Check that ndims == 2 if specified.
                    ndims = q.get_shape().ndims
                    if ndims:
                        assert ndims == 2
                query = array_ops.concat(query_list, 1)
            for a in xrange(num_heads):
                with variable_scope.variable_scope("Attention_%d" % a):
                    y = Linear(query, attention_vec_size, True)(query)
                    y = array_ops.reshape(y, [-1, 1, 1, attention_vec_size])
                    # Attention mask is a softmax of v^T * tanh(...).
                    s = math_ops.reduce_sum(v[a] * math_ops.tanh(hidden_features[a] + y), [2, 3])
                    a = nn_ops.softmax(s)
                    # Now calculate the attention-weighted vector d.
                    d = math_ops.reduce_sum(array_ops.reshape(a, [-1, attn_length, 1, 1]) * hidden, [1, 2])
                    ds.append(array_ops.reshape(d, [-1, attn_size]))
            return ds

        outputs = []
        prev = None
        # The attention vectors also need a batch dimension of beam_size.
        batch_attn_size = array_ops.stack([beam_size, attn_size])
        attns = [array_ops.zeros(batch_attn_size, dtype=dtype) for _ in xrange(num_heads)]
        for a in attns:  # Ensure the second shape of attention vectors is set.
            a.set_shape([None, attn_size])
        if initial_state_attention:
            attns = attention(initial_state)

        log_beam_probs, beam_path, beam_symbols = [], [], []
        for i, inp in enumerate(decoder_inputs):
            if i > 0:
                variable_scope.get_variable_scope().reuse_variables()
            # If loop_function is set, we use it instead of decoder_inputs.
            if i == 0:
                # At i == 0 the input is a tensor with batch_size == beam_size whose
                # elements all have the same value: the <go> token.
                inp = tf.nn.embedding_lookup(embedding, tf.constant(1, dtype=tf.int32, shape=[beam_size]))

            if loop_function is not None and prev is not None:
                with variable_scope.variable_scope("loop_function", reuse=True):
                    inp = loop_function(prev, i, log_beam_probs, beam_path, beam_symbols)
            # Merge input and previous attentions into one vector of the right size.
            input_size = inp.get_shape().with_rank(2)[1]
            if input_size.value is None:
                raise ValueError("Could not infer input size from input: %s" % inp.name)
            inputs = [inp] + attns
            x = Linear(inputs, input_size, True)(inputs)

            # Run the RNN.
            cell_output, state = cell(x, state)
            # Run the attention mechanism.
            if i == 0 and initial_state_attention:
                with variable_scope.variable_scope(variable_scope.get_variable_scope(), reuse=True):
                    attns = attention(state)
            else:
                attns = attention(state)

            with variable_scope.variable_scope("AttnOutputProjection"):
                inputs = [cell_output] + attns
                output = Linear(inputs, output_size, True)(inputs)
            if loop_function is not None:
                prev = output
            outputs.append(tf.argmax(nn_ops.xw_plus_b(output, output_projection[0], output_projection[1]), axis=1))

        return outputs, state, tf.reshape(tf.concat(beam_path, 0), [-1, beam_size]), tf.reshape(tf.concat(beam_symbols, 0), [-1, beam_size])

def embedding_attention_decoder(decoder_inputs,
                                initial_state,
                                attention_states,
                                cell,
                                num_symbols,
                                embedding_size,
                                num_heads=1,
                                output_size=None,
                                output_projection=None,
                                feed_previous=False,
                                update_embedding_for_previous=True,
                                dtype=None,
                                scope=None,
                                initial_state_attention=False, beam_search=True, beam_size=10):
    if output_size is None:
        output_size = cell.output_size
    if output_projection is not None:
        proj_biases = ops.convert_to_tensor(output_projection[1], dtype=dtype)
        proj_biases.get_shape().assert_is_compatible_with([num_symbols])

    with variable_scope.variable_scope(scope or "embedding_attention_decoder", dtype=dtype) as scope:
        embedding = variable_scope.get_variable("embedding", [num_symbols, embedding_size])
        emb_inp = [embedding_ops.embedding_lookup(embedding, i) for i in decoder_inputs]
        loop_function = _extract_beam_search(embedding, beam_size, num_symbols, embedding_size, output_projection)
        return beam_attention_decoder(
            emb_inp, initial_state, attention_states, cell, embedding, output_size=output_size,
            num_heads=num_heads, loop_function=loop_function,
            initial_state_attention=initial_state_attention, output_projection=output_projection,
            beam_size=beam_size)


def embedding_attention_seq2seq(encoder_inputs,
                                decoder_inputs,
                                cell,
                                num_encoder_symbols,
                                num_decoder_symbols,
                                embedding_size,
                                num_heads=1,
                                output_projection=None,
                                feed_previous=False,
                                dtype=None,
                                scope=None,
                                initial_state_attention=False, beam_search=True, beam_size=10):
    with variable_scope.variable_scope(scope or "embedding_attention_seq2seq", dtype=dtype) as scope:
        dtype = scope.dtype
        # Encoder.
        encoder_cell = copy.deepcopy(cell)
        encoder_cell = core_rnn_cell.EmbeddingWrapper(encoder_cell, embedding_classes=num_encoder_symbols, embedding_size=embedding_size)
        encoder_outputs, encoder_state = rnn.static_rnn(encoder_cell, encoder_inputs, dtype=dtype)

        # First calculate a concatenation of encoder outputs to put attention on.
        top_states = [array_ops.reshape(e, [-1, 1, cell.output_size]) for e in encoder_outputs]
        attention_states = array_ops.concat(top_states, 1)

        # Decoder.
        output_size = None
        if output_projection is None:
            cell = core_rnn_cell.OutputProjectionWrapper(cell, num_decoder_symbols)
            output_size = num_decoder_symbols

        return embedding_attention_decoder(
            decoder_inputs,
            encoder_state,
            attention_states,
            cell,
            num_decoder_symbols,
            embedding_size,
            num_heads=num_heads,
            output_size=output_size,
            output_projection=output_projection,
            feed_previous=feed_previous,
            initial_state_attention=initial_state_attention, beam_search=beam_search, beam_size=beam_size)
--------------------------------------------------------------------------------
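The score bookkeeping inside _extract_beam_search is easier to follow outside the graph. A minimal NumPy sketch of a single beam step with made-up probabilities (beam_size = 2, a 4-word vocabulary):

```python
import numpy as np

beam_size, num_symbols = 2, 4
log_beam_probs = np.array([[-0.1], [-0.7]])    # running scores of the 2 hypotheses, [beam, 1]
word_logprobs = np.log([[0.5, 0.2, 0.2, 0.1],  # next-word distribution for each hypothesis
                        [0.1, 0.6, 0.2, 0.1]])

# As in loop_function: add the running scores, flatten to beam_size * num_symbols
# candidates, and keep the top beam_size of them.
scores = (word_logprobs + log_beam_probs).reshape(-1)
top = np.argsort(scores)[::-1][:beam_size]
beam_parent = top // num_symbols  # which hypothesis each survivor extends
symbols = top % num_symbols       # which word it appends
print(beam_parent, symbols)       # -> [0 1] [0 1]
```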
/seq2seq_model.py:
--------------------------------------------------------------------------------
import tensorflow as tf
from seq2seq import embedding_attention_seq2seq

class Seq2SeqModel():

    def __init__(self, source_vocab_size, target_vocab_size, en_de_seq_len, hidden_size, num_layers,
                 batch_size, learning_rate, num_samples=1024,
                 forward_only=False, beam_search=True, beam_size=10):
        '''
        Initialize and build the model.
        :param source_vocab_size: vocab size of the encoder inputs
        :param target_vocab_size: vocab size of the decoder inputs (the same as above here)
        :param en_de_seq_len: maximum lengths of the source and target sequences
        :param hidden_size: number of hidden units per RNN layer
        :param num_layers: number of stacked RNN layers
        :param batch_size: batch size
        :param learning_rate: learning rate
        :param num_samples: number of samples drawn for sampled softmax when computing the loss
        :param forward_only: set to True for inference
        :param beam_search: whether inference uses beam search or greedy search
        :param beam_size: width of the beam search
        '''
        self.source_vocab_size = source_vocab_size
        self.target_vocab_size = target_vocab_size
        self.en_de_seq_len = en_de_seq_len
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.batch_size = batch_size
        self.learning_rate = tf.Variable(float(learning_rate), trainable=False)
        self.num_samples = num_samples
        self.forward_only = forward_only
        self.beam_search = beam_search
        self.beam_size = beam_size
        self.global_step = tf.Variable(0, trainable=False)

        output_projection = None
        softmax_loss_function = None
        # Define the sampled loss function that is passed to sequence_loss below.
        if num_samples > 0 and num_samples < self.target_vocab_size:
            w = tf.get_variable('proj_w', [hidden_size, self.target_vocab_size])
            w_t = tf.transpose(w)
            b = tf.get_variable('proj_b', [self.target_vocab_size])
            output_projection = (w, b)
            # Compute the loss with sampled_softmax_loss over a sample of the
            # vocabulary, which saves computation time.
            def sample_loss(logits, labels):
                labels = tf.reshape(labels, [-1, 1])
                return tf.nn.sampled_softmax_loss(w_t, b, labels=labels, inputs=logits, num_sampled=num_samples, num_classes=self.target_vocab_size)
            softmax_loss_function = sample_loss

        self.keep_drop = tf.placeholder(tf.float32)
        # Define the multi-layer dropout RNN cell used by both encoder and decoder.
        def create_rnn_cell():
            encoDecoCell = tf.contrib.rnn.BasicLSTMCell(hidden_size)
            encoDecoCell = tf.contrib.rnn.DropoutWrapper(encoDecoCell, input_keep_prob=1.0, output_keep_prob=self.keep_drop)
            return encoDecoCell
        encoCell = tf.contrib.rnn.MultiRNNCell([create_rnn_cell() for _ in range(num_layers)])

        # Define the input placeholders, one per time step (list form).
        self.encoder_inputs = []
        self.decoder_inputs = []
        self.decoder_targets = []
        self.target_weights = []
        for i in range(en_de_seq_len[0]):
            self.encoder_inputs.append(tf.placeholder(tf.int32, shape=[None, ], name="encoder{0}".format(i)))
        for i in range(en_de_seq_len[1]):
            self.decoder_inputs.append(tf.placeholder(tf.int32, shape=[None, ], name="decoder{0}".format(i)))
            self.decoder_targets.append(tf.placeholder(tf.int32, shape=[None, ], name="target{0}".format(i)))
            self.target_weights.append(tf.placeholder(tf.float32, shape=[None, ], name="weight{0}".format(i)))

        # Test mode: feed each step's output back in as the next step's input.
        if forward_only:
            if beam_search:
                # For beam search, call the embedding_attention_seq2seq defined in this
                # repo's seq2seq.py rather than the one under legacy_seq2seq.
                self.beam_outputs, _, self.beam_path, self.beam_symbol = embedding_attention_seq2seq(
                    self.encoder_inputs, self.decoder_inputs, encoCell, num_encoder_symbols=source_vocab_size,
                    num_decoder_symbols=target_vocab_size, embedding_size=hidden_size,
                    output_projection=output_projection, feed_previous=True)
            else:
                decoder_outputs, _ = tf.contrib.legacy_seq2seq.embedding_attention_seq2seq(
                    self.encoder_inputs, self.decoder_inputs, encoCell, num_encoder_symbols=source_vocab_size,
                    num_decoder_symbols=target_vocab_size, embedding_size=hidden_size,
                    output_projection=output_projection, feed_previous=True)
                # Because output_projection was passed to the seq2seq model, its outputs
                # have hidden_size dimensions, so apply the projection ourselves, per time step.
                if output_projection is not None:
                    self.outputs = [tf.matmul(o, output_projection[0]) + output_projection[1] for o in decoder_outputs]
        else:
            # The outputs are not fed back as inputs here, so no output_projection is applied to them.
            decoder_outputs, _ = tf.contrib.legacy_seq2seq.embedding_attention_seq2seq(
                self.encoder_inputs, self.decoder_inputs, encoCell, num_encoder_symbols=source_vocab_size,
                num_decoder_symbols=target_vocab_size, embedding_size=hidden_size, output_projection=output_projection,
                feed_previous=False)
            self.loss = tf.contrib.legacy_seq2seq.sequence_loss(
                decoder_outputs, self.decoder_targets, self.target_weights, softmax_loss_function=softmax_loss_function)

            # Initialize the optimizer
            opt = tf.train.AdamOptimizer(learning_rate=self.learning_rate, beta1=0.9, beta2=0.999, epsilon=1e-08)
            self.optOp = opt.minimize(self.loss)

        self.saver = tf.train.Saver(tf.global_variables())

    def step(self, session, encoder_inputs, decoder_inputs, decoder_targets, target_weights, go_token_id):
        # Feed one batch of data and run one training or inference step of the model.
        # Build the feed_dict for sess.run.
        feed_dict = {}
        if not self.forward_only:
            feed_dict[self.keep_drop] = 0.5
            for i in range(self.en_de_seq_len[0]):
                feed_dict[self.encoder_inputs[i].name] = encoder_inputs[i]
            for i in range(self.en_de_seq_len[1]):
                feed_dict[self.decoder_inputs[i].name] = decoder_inputs[i]
                feed_dict[self.decoder_targets[i].name] = decoder_targets[i]
                feed_dict[self.target_weights[i].name] = target_weights[i]
            run_ops = [self.optOp, self.loss]
        else:
            feed_dict[self.keep_drop] = 1.0
            for i in range(self.en_de_seq_len[0]):
                feed_dict[self.encoder_inputs[i].name] = encoder_inputs[i]
            feed_dict[self.decoder_inputs[0].name] = [go_token_id]
            if self.beam_search:
                run_ops = [self.beam_path, self.beam_symbol]
            else:
                run_ops = [self.outputs]

        outputs = session.run(run_ops, feed_dict)
        if not self.forward_only:
            return None, outputs[1]
        elif self.beam_search:
            return outputs[0], outputs[1]
        else:
            return outputs[0]
--------------------------------------------------------------------------------
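Finally, a hypothetical smoke test wiring the pieces together with toy sizes. This is a sketch rather than project code: it assumes the TensorFlow 1.x / tf.contrib environment the repo targets, and that data_utils' dependencies (e.g. nltk) are installed:

```python
import tensorflow as tf
from data_utils import createBatch, goToken
from seq2seq_model import Seq2SeqModel

# Tiny model: vocabulary of 100, sequence lengths [10, 10], 2 layers of 32 units.
model = Seq2SeqModel(source_vocab_size=100, target_vocab_size=100,
                     en_de_seq_len=[10, 10], hidden_size=32, num_layers=2,
                     batch_size=2, learning_rate=0.001, num_samples=20)
# Two fake QA pairs, already converted to word ids (ids 0-3 are reserved tokens).
batch = createBatch([[[4, 8, 15], [16, 23]], [[42, 7], [9]]], [10, 10])
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    _, loss = model.step(sess, batch.encoderSeqs, batch.decoderSeqs,
                         batch.targetSeqs, batch.weights, goToken)
    print("one training step, loss = %.3f" % loss)
```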