├── .idea
│   ├── dictionaries
│   │   └── liuchong.xml
│   ├── vcs.xml
│   ├── modules.xml
│   ├── misc.xml
│   └── seq2seq_chatbot_new.iml
├── README.md
├── data
│   └── dataset-cornell-length10-filter1-vocabSize40000.pkl
├── predict.py
├── train.py
├── data_helpers.py
└── model.py
/README.md:
--------------------------------------------------------------------------------
# seq2seq_chatbot_new
A simple TensorFlow implementation of a seq2seq-based dialogue system, with embedding, attention, and beam search. The dataset is the Cornell Movie Dialogs corpus.
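Suggested usage: set `data_path` in `train.py` and `predict.py` to the local path of the dataset pickle, run `train.py` to train (checkpoints are written to `model/`), then run `predict.py` for an interactive beam-search chat.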
--------------------------------------------------------------------------------
/data/dataset-cornell-length10-filter1-vocabSize40000.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lc222/seq2seq_chatbot_new/HEAD/data/dataset-cornell-length10-filter1-vocabSize40000.pkl
--------------------------------------------------------------------------------
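
For reference, a minimal sketch (not part of the repository) of how to inspect this pickle, using the three keys that `data_helpers.loadDataset` reads (`word2id`, `id2word`, `trainingSamples`):

import pickle

# Sketch: peek inside the dataset pickle; the keys below match data_helpers.loadDataset.
with open('data/dataset-cornell-length10-filter1-vocabSize40000.pkl', 'rb') as handle:
    data = pickle.load(handle)

word2id = data['word2id']                   # word -> id dictionary
id2word = data['id2word']                   # id -> word dictionary
trainingSamples = data['trainingSamples']   # list of [question_ids, answer_ids] pairs

print(len(word2id), 'words in the vocabulary')
print(len(trainingSamples), 'question/answer pairs')
print(trainingSamples[0])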
/predict.py:
--------------------------------------------------------------------------------
import tensorflow as tf
from data_helpers import loadDataset, getBatches, sentence2enco
from model import Seq2SeqModel
import sys
import numpy as np


tf.app.flags.DEFINE_integer('rnn_size', 1024, 'Number of hidden units in each layer')
tf.app.flags.DEFINE_integer('num_layers', 2, 'Number of layers in each encoder and decoder')
tf.app.flags.DEFINE_integer('embedding_size', 1024, 'Embedding dimensions of encoder and decoder inputs')

tf.app.flags.DEFINE_float('learning_rate', 0.0001, 'Learning rate')
tf.app.flags.DEFINE_integer('batch_size', 128, 'Batch size')
tf.app.flags.DEFINE_integer('numEpochs', 30, 'Maximum # of training epochs')
tf.app.flags.DEFINE_integer('steps_per_checkpoint', 100, 'Save model checkpoint every this iteration')
tf.app.flags.DEFINE_string('model_dir', 'model/', 'Path to save model checkpoints')
tf.app.flags.DEFINE_string('model_name', 'chatbot.ckpt', 'File name used for model checkpoints')
FLAGS = tf.app.flags.FLAGS

data_path = r'E:\PycharmProjects\seq2seq_chatbot\seq2seq_chatbot_new\data\dataset-cornell-length10-filter1-vocabSize40000.pkl'
word2id, id2word, trainingSamples = loadDataset(data_path)

def predict_ids_to_seq(predict_ids, id2word, beam_size):
    '''
    Convert the ids returned by beam search into readable strings.
    :param predict_ids: list of length batch_size; each element is a decode_len * beam_size array
    :param id2word: vocabulary dict mapping ids back to words
    :return:
    '''
    for single_predict in predict_ids:
        for i in range(beam_size):
            predict_list = np.ndarray.tolist(single_predict[:, :, i])
            predict_seq = [id2word[idx] for idx in predict_list[0]]
            print(" ".join(predict_seq))

with tf.Session() as sess:
    model = Seq2SeqModel(FLAGS.rnn_size, FLAGS.num_layers, FLAGS.embedding_size, FLAGS.learning_rate, word2id,
                         mode='decode', use_attention=True, beam_search=True, beam_size=5, max_gradient_norm=5.0)
    ckpt = tf.train.get_checkpoint_state(FLAGS.model_dir)
    if ckpt and tf.train.checkpoint_exists(ckpt.model_checkpoint_path):
        print('Reloading model parameters..')
        model.saver.restore(sess, ckpt.model_checkpoint_path)
    else:
        raise ValueError('No such file:[{}]'.format(FLAGS.model_dir))
    sys.stdout.write("> ")
    sys.stdout.flush()
    sentence = sys.stdin.readline()
    while sentence:
        batch = sentence2enco(sentence, word2id)
        predicted_ids = model.infer(sess, batch)
        # print(predicted_ids)
        predict_ids_to_seq(predicted_ids, id2word, 5)
        print("> ", end="")
        sys.stdout.flush()
        sentence = sys.stdin.readline()

--------------------------------------------------------------------------------
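
The shape handling in `predict_ids_to_seq` is easy to misread, so here is a small self-contained sketch (toy ids and a toy vocabulary, not from the dataset) of the `[batch, decode_len, beam_size]` layout it expects from `model.infer`:

import numpy as np

id2word_toy = {0: '<pad>', 2: '<eos>', 4: 'hi', 5: 'there'}   # hypothetical ids, for illustration only
# one input sentence, decode_len = 3, beam_size = 2
fake_ids = np.array([[[4, 4],
                      [5, 2],
                      [2, 0]]])
for i in range(2):                            # loop over beams, as predict_ids_to_seq does
    beam = fake_ids[:, :, i].tolist()[0]      # id sequence for beam i of the first (only) sentence
    print(" ".join(id2word_toy[idx] for idx in beam))
# beam 0 -> "hi there <eos>", beam 1 -> "hi <eos> <pad>"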
/train.py:
--------------------------------------------------------------------------------
import tensorflow as tf
from data_helpers import loadDataset, getBatches, sentence2enco
from model import Seq2SeqModel
from tqdm import tqdm
import math
import os

tf.app.flags.DEFINE_integer('rnn_size', 1024, 'Number of hidden units in each layer')
tf.app.flags.DEFINE_integer('num_layers', 2, 'Number of layers in each encoder and decoder')
tf.app.flags.DEFINE_integer('embedding_size', 1024, 'Embedding dimensions of encoder and decoder inputs')

tf.app.flags.DEFINE_float('learning_rate', 0.0001, 'Learning rate')
tf.app.flags.DEFINE_integer('batch_size', 128, 'Batch size')
tf.app.flags.DEFINE_integer('numEpochs', 30, 'Maximum # of training epochs')
tf.app.flags.DEFINE_integer('steps_per_checkpoint', 100, 'Save model checkpoint every this iteration')
tf.app.flags.DEFINE_string('model_dir', 'model/', 'Path to save model checkpoints')
tf.app.flags.DEFINE_string('model_name', 'chatbot.ckpt', 'File name used for model checkpoints')
FLAGS = tf.app.flags.FLAGS

data_path = r'E:\PycharmProjects\seq2seq_chatbot\seq2seq_chatbot_new\data\dataset-cornell-length10-filter1-vocabSize40000.pkl'
word2id, id2word, trainingSamples = loadDataset(data_path)

with tf.Session() as sess:
    model = Seq2SeqModel(FLAGS.rnn_size, FLAGS.num_layers, FLAGS.embedding_size, FLAGS.learning_rate, word2id,
                         mode='train', use_attention=True, beam_search=False, beam_size=5, max_gradient_norm=5.0)
    ckpt = tf.train.get_checkpoint_state(FLAGS.model_dir)
    if ckpt and tf.train.checkpoint_exists(ckpt.model_checkpoint_path):
        print('Reloading model parameters..')
        model.saver.restore(sess, ckpt.model_checkpoint_path)
    else:
        print('Created new model parameters..')
        sess.run(tf.global_variables_initializer())
    current_step = 0
    summary_writer = tf.summary.FileWriter(FLAGS.model_dir, graph=sess.graph)
    for e in range(FLAGS.numEpochs):
        print("----- Epoch {}/{} -----".format(e + 1, FLAGS.numEpochs))
        batches = getBatches(trainingSamples, FLAGS.batch_size)
        for nextBatch in tqdm(batches, desc="Training"):
            loss, summary = model.train(sess, nextBatch)
            current_step += 1
            if current_step % FLAGS.steps_per_checkpoint == 0:
                perplexity = math.exp(float(loss)) if loss < 300 else float('inf')
                tqdm.write("----- Step %d -- Loss %.2f -- Perplexity %.2f" % (current_step, loss, perplexity))
                summary_writer.add_summary(summary, current_step)
                checkpoint_path = os.path.join(FLAGS.model_dir, FLAGS.model_name)
                model.saver.save(sess, checkpoint_path, global_step=current_step)
--------------------------------------------------------------------------------
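
train.py never calls `model.eval`, although the model exposes it. Below is a hedged sketch of how a held-out evaluation pass could look, meant to run inside the same `tf.Session` block as above; `validationSamples` is an assumed split, since the dataset pickle only provides `trainingSamples`:

# Hypothetical held-out evaluation loop; validationSamples is an assumed split,
# not something the dataset pickle provides.
eval_batches = getBatches(validationSamples, FLAGS.batch_size)
total_loss, n = 0.0, 0
for batch in eval_batches:
    loss, _ = model.eval(sess, batch)   # no backward pass; eval() feeds keep_prob = 1.0
    total_loss += float(loss)
    n += 1
print('eval loss: %.4f, perplexity: %.2f' % (total_loss / n, math.exp(total_loss / n)))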
/data_helpers.py:
--------------------------------------------------------------------------------
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
import nltk
import numpy as np
import pickle
import random

padToken, goToken, eosToken, unknownToken = 0, 1, 2, 3

class Batch:
    # Batch container: encoder inputs, encoder input lengths, decoder targets and their lengths
    def __init__(self):
        self.encoder_inputs = []
        self.encoder_inputs_length = []
        self.decoder_targets = []
        self.decoder_targets_length = []

def loadDataset(filename):
    '''
    Load the preprocessed dataset.
    :param filename: path to a pickled dict containing word2id and id2word (the word-to-index and
                     index-to-word dictionaries) and trainingSamples (a list of QA pairs)
    :return: word2id, id2word, trainingSamples
    '''
    dataset_path = os.path.join(filename)
    print('Loading dataset from {}'.format(dataset_path))
    with open(dataset_path, 'rb') as handle:
        data = pickle.load(handle)  # Warning: if adding something here, also modify saveDataset
        word2id = data['word2id']
        id2word = data['id2word']
        trainingSamples = data['trainingSamples']
    return word2id, id2word, trainingSamples

def createBatch(samples):
    '''
    Pad the given samples (one batch of data) and arrange them in the form the placeholders expect.
    :param samples: one batch of samples; a list whose elements are [question, answer] pairs of word ids
    :return: a Batch object that can be fed directly into feed_dict
    '''
    batch = Batch()
    batch.encoder_inputs_length = [len(sample[0]) for sample in samples]
    batch.decoder_targets_length = [len(sample[1]) for sample in samples]

    max_source_length = max(batch.encoder_inputs_length)
    max_target_length = max(batch.decoder_targets_length)

    for sample in samples:
        # Reverse the source sequence and left-pad it to the max source length of this batch
        source = list(reversed(sample[0]))
        pad = [padToken] * (max_source_length - len(source))
        batch.encoder_inputs.append(pad + source)

        # Right-pad the target to the max target length of this batch
        target = sample[1]
        pad = [padToken] * (max_target_length - len(target))
        batch.decoder_targets.append(target + pad)
        #batch.target_inputs.append([goToken] + target + pad[:-1])

    return batch

def getBatches(data, batch_size):
    '''
    Split the full dataset into batches of batch_size and run createBatch on each slice.
    :param data: trainingSamples as returned by loadDataset, i.e. the list of QA pairs
    :param batch_size: batch size
    :return: a list whose elements are batches ready to be fed into feed_dict for training
    '''
    # Shuffle the samples before every epoch
    random.shuffle(data)
    batches = []
    data_len = len(data)
    def genNextSamples():
        for i in range(0, data_len, batch_size):
            yield data[i:min(i + batch_size, data_len)]

    for samples in genNextSamples():
        batch = createBatch(samples)
        batches.append(batch)
    return batches

def sentence2enco(sentence, word2id):
    '''
    At test time, turn a user-typed sentence into data that can be fed straight into the model:
    convert the sentence to word ids, then call createBatch.
    :param sentence: the sentence typed by the user
    :param word2id: dictionary mapping words to ids
    :return: a batch that can be fed into the model for prediction
    '''
    if sentence == '':
        return None
    # Tokenize
    tokens = nltk.word_tokenize(sentence)
    if len(tokens) > 20:
        return None
    # Convert each word to its id
    wordIds = []
    for token in tokens:
        wordIds.append(word2id.get(token, unknownToken))
    # Build a single-sample batch with createBatch
    batch = createBatch([[wordIds, []]])
    return batch

--------------------------------------------------------------------------------
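
A quick way to see what `createBatch` produces (toy id sequences, no TensorFlow needed): the source side is reversed and left-padded, the target side is right-padded.

from data_helpers import createBatch

# Toy batch of two [question_ids, answer_ids] pairs
samples = [[[4, 5, 6], [7, 8]],
           [[9], [10, 11, 12]]]
batch = createBatch(samples)
print(batch.encoder_inputs)          # [[6, 5, 4], [0, 0, 9]]    reversed and left-padded
print(batch.encoder_inputs_length)   # [3, 1]
print(batch.decoder_targets)         # [[7, 8, 0], [10, 11, 12]] right-padded
print(batch.decoder_targets_length)  # [2, 3]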
/model.py:
--------------------------------------------------------------------------------
import tensorflow as tf
from tensorflow.python.util import nest


class Seq2SeqModel():
    def __init__(self, rnn_size, num_layers, embedding_size, learning_rate, word_to_idx, mode, use_attention,
                 beam_search, beam_size, max_gradient_norm=5.0):
        self.learning_rate = learning_rate
        self.embedding_size = embedding_size
        self.rnn_size = rnn_size
        self.num_layers = num_layers
        self.word_to_idx = word_to_idx
        self.vocab_size = len(self.word_to_idx)
        self.mode = mode
        self.use_attention = use_attention
        self.beam_search = beam_search
        self.beam_size = beam_size
        self.max_gradient_norm = max_gradient_norm
        # Build the computation graph
        self.build_model()

    def _create_rnn_cell(self):
        def single_rnn_cell():
            # Create a single cell. Note that a factory function like single_rnn_cell is required here;
            # putting the same cell object directly into the MultiRNNCell list breaks the model.
            single_cell = tf.contrib.rnn.LSTMCell(self.rnn_size)
            # Add dropout
            cell = tf.contrib.rnn.DropoutWrapper(single_cell, output_keep_prob=self.keep_prob_placeholder)
            return cell
        # Each element of the list is produced by a separate call to single_rnn_cell
        cell = tf.contrib.rnn.MultiRNNCell([single_rnn_cell() for _ in range(self.num_layers)])
        return cell

    def build_model(self):
        print('building model... ...')
        # =================================1. Define the model placeholders
        self.encoder_inputs = tf.placeholder(tf.int32, [None, None], name='encoder_inputs')
        self.encoder_inputs_length = tf.placeholder(tf.int32, [None], name='encoder_inputs_length')

        self.batch_size = tf.placeholder(tf.int32, [], name='batch_size')
        self.keep_prob_placeholder = tf.placeholder(tf.float32, name='keep_prob_placeholder')

        self.decoder_targets = tf.placeholder(tf.int32, [None, None], name='decoder_targets')
        self.decoder_targets_length = tf.placeholder(tf.int32, [None], name='decoder_targets_length')
        # Take the maximum of the target sequence lengths and build a sequence-length mask from it.
        # An example of what sequence_mask does:
        # tf.sequence_mask([1, 3, 2], 5)
        # [[True, False, False, False, False],
        #  [True, True, True, False, False],
        #  [True, True, False, False, False]]
        self.max_target_sequence_length = tf.reduce_max(self.decoder_targets_length, name='max_target_len')
        self.mask = tf.sequence_mask(self.decoder_targets_length, self.max_target_sequence_length, dtype=tf.float32, name='masks')

        # =================================2. Define the encoder
        with tf.variable_scope('encoder'):
            # Create the LSTM cells: two layers plus dropout
            encoder_cell = self._create_rnn_cell()
            # Build the embedding matrix; it is shared between the encoder and the decoder
            embedding = tf.get_variable('embedding', [self.vocab_size, self.embedding_size])
            encoder_inputs_embedded = tf.nn.embedding_lookup(embedding, self.encoder_inputs)
            # Run dynamic_rnn to encode the inputs into hidden vectors.
            # encoder_outputs is used for attention, shape batch_size*encoder_inputs_length*rnn_size;
            # encoder_state is used to initialize the decoder state, shape batch_size*rnn_size
            encoder_outputs, encoder_state = tf.nn.dynamic_rnn(encoder_cell, encoder_inputs_embedded,
                                                               sequence_length=self.encoder_inputs_length,
                                                               dtype=tf.float32)

        # =================================3. Define the decoder
        with tf.variable_scope('decoder'):
            encoder_inputs_length = self.encoder_inputs_length
            if self.beam_search:
                # When beam search is used, the encoder outputs must be tiled, i.e. copied beam_size times.
                print("use beamsearch decoding..")
                encoder_outputs = tf.contrib.seq2seq.tile_batch(encoder_outputs, multiplier=self.beam_size)
                encoder_state = nest.map_structure(lambda s: tf.contrib.seq2seq.tile_batch(s, self.beam_size), encoder_state)
                encoder_inputs_length = tf.contrib.seq2seq.tile_batch(self.encoder_inputs_length, multiplier=self.beam_size)

            # Define the attention mechanism to use.
            attention_mechanism = tf.contrib.seq2seq.BahdanauAttention(num_units=self.rnn_size, memory=encoder_outputs,
                                                                       memory_sequence_length=encoder_inputs_length)
            #attention_mechanism = tf.contrib.seq2seq.LuongAttention(num_units=self.rnn_size, memory=encoder_outputs, memory_sequence_length=encoder_inputs_length)
            # Define the decoder LSTM cell and wrap it with the attention wrapper
            decoder_cell = self._create_rnn_cell()
            decoder_cell = tf.contrib.seq2seq.AttentionWrapper(cell=decoder_cell, attention_mechanism=attention_mechanism,
                                                               attention_layer_size=self.rnn_size, name='Attention_Wrapper')
            # With beam search, batch_size = self.batch_size * self.beam_size, since the inputs were tiled above
            batch_size = self.batch_size if not self.beam_search else self.batch_size * self.beam_size
            # Initialize the decoder state directly from the final hidden state of the encoder
            decoder_initial_state = decoder_cell.zero_state(batch_size=batch_size, dtype=tf.float32).clone(cell_state=encoder_state)
            output_layer = tf.layers.Dense(self.vocab_size, kernel_initializer=tf.truncated_normal_initializer(mean=0.0, stddev=0.1))

            if self.mode == 'train':
                # Define the decoder inputs: prepend a <go> token to the targets, drop the final <eos>, then embed.
                # decoder_inputs_embedded has shape [batch_size, decoder_targets_length, embedding_size]
                ending = tf.strided_slice(self.decoder_targets, [0, 0], [self.batch_size, -1], [1, 1])
                decoder_input = tf.concat([tf.fill([self.batch_size, 1], self.word_to_idx['<go>']), ending], 1)
                decoder_inputs_embedded = tf.nn.embedding_lookup(embedding, decoder_input)
                # For training, use the standard TrainingHelper + BasicDecoder combination; a custom Helper class
                # could also be written if different behaviour is needed
                training_helper = tf.contrib.seq2seq.TrainingHelper(inputs=decoder_inputs_embedded,
                                                                    sequence_length=self.decoder_targets_length,
                                                                    time_major=False, name='training_helper')
                training_decoder = tf.contrib.seq2seq.BasicDecoder(cell=decoder_cell, helper=training_helper,
                                                                   initial_state=decoder_initial_state, output_layer=output_layer)
                # Call dynamic_decode; decoder_outputs is a namedtuple with two fields (rnn_outputs, sample_id)
                # rnn_output: [batch_size, decoder_targets_length, vocab_size], the per-step word scores, used to compute the loss
                # sample_id: [batch_size, decoder_targets_length], tf.int32, the decoded ids, i.e. the final answer
                decoder_outputs, _, _ = tf.contrib.seq2seq.dynamic_decode(decoder=training_decoder,
                                                                          impute_finished=True,
                                                                          maximum_iterations=self.max_target_sequence_length)
                # Compute the loss and gradients from the outputs, and define the AdamOptimizer and train_op
                self.decoder_logits_train = tf.identity(decoder_outputs.rnn_output)
                self.decoder_predict_train = tf.argmax(self.decoder_logits_train, axis=-1, name='decoder_pred_train')
                # Use sequence_loss to compute the loss; the mask defined above is passed in as the weights
                self.loss = tf.contrib.seq2seq.sequence_loss(logits=self.decoder_logits_train,
                                                             targets=self.decoder_targets, weights=self.mask)

                # Training summary for the current batch_loss
                tf.summary.scalar('loss', self.loss)
                self.summary_op = tf.summary.merge_all()

                optimizer = tf.train.AdamOptimizer(self.learning_rate)
                trainable_params = tf.trainable_variables()
                gradients = tf.gradients(self.loss, trainable_params)
                clip_gradients, _ = tf.clip_by_global_norm(gradients, self.max_gradient_norm)
                self.train_op = optimizer.apply_gradients(zip(clip_gradients, trainable_params))
            elif self.mode == 'decode':
                start_tokens = tf.ones([self.batch_size, ], tf.int32) * self.word_to_idx['<go>']
                end_token = self.word_to_idx['<eos>']
                # The decoding setup depends on whether beam search is used:
                # if it is, call BeamSearchDecoder directly (it already implements its own helper);
                # otherwise, use GreedyEmbeddingHelper + BasicDecoder for greedy decoding
                if self.beam_search:
                    inference_decoder = tf.contrib.seq2seq.BeamSearchDecoder(cell=decoder_cell, embedding=embedding,
                                                                             start_tokens=start_tokens, end_token=end_token,
                                                                             initial_state=decoder_initial_state,
                                                                             beam_width=self.beam_size,
                                                                             output_layer=output_layer)
                else:
                    decoding_helper = tf.contrib.seq2seq.GreedyEmbeddingHelper(embedding=embedding,
                                                                               start_tokens=start_tokens, end_token=end_token)
                    inference_decoder = tf.contrib.seq2seq.BasicDecoder(cell=decoder_cell, helper=decoding_helper,
                                                                        initial_state=decoder_initial_state,
                                                                        output_layer=output_layer)
                decoder_outputs, _, _ = tf.contrib.seq2seq.dynamic_decode(decoder=inference_decoder,
                                                                          maximum_iterations=10)
                # dynamic_decode returns decoder_outputs, a namedtuple whose contents depend on the decoder.
                # Without beam search it contains (rnn_outputs, sample_id):
                # rnn_output: [batch_size, decoder_targets_length, vocab_size]
                # sample_id: [batch_size, decoder_targets_length], tf.int32

                # With beam search it contains (predicted_ids, beam_search_decoder_output):
                # predicted_ids: [batch_size, decoder_targets_length, beam_size], the decoded ids
                # beam_search_decoder_output: a BeamSearchDecoderOutput namedtuple (scores, predicted_ids, parent_ids)
                # So returning predicted_ids or sample_id is enough to recover the final answer
                if self.beam_search:
                    self.decoder_predict_decode = decoder_outputs.predicted_ids
                else:
                    self.decoder_predict_decode = tf.expand_dims(decoder_outputs.sample_id, -1)
        # =================================4. Saver
        self.saver = tf.train.Saver(tf.global_variables())

    def train(self, sess, batch):
        # For training, run the three ops self.train_op, self.loss, self.summary_op and feed in the batch data
        feed_dict = {self.encoder_inputs: batch.encoder_inputs,
                     self.encoder_inputs_length: batch.encoder_inputs_length,
                     self.decoder_targets: batch.decoder_targets,
                     self.decoder_targets_length: batch.decoder_targets_length,
                     self.keep_prob_placeholder: 0.5,
                     self.batch_size: len(batch.encoder_inputs)}
        _, loss, summary = sess.run([self.train_op, self.loss, self.summary_op], feed_dict=feed_dict)
        return loss, summary

    def eval(self, sess, batch):
        # For evaluation there is no backward pass, so only self.loss and self.summary_op are run
        feed_dict = {self.encoder_inputs: batch.encoder_inputs,
                     self.encoder_inputs_length: batch.encoder_inputs_length,
                     self.decoder_targets: batch.decoder_targets,
                     self.decoder_targets_length: batch.decoder_targets_length,
                     self.keep_prob_placeholder: 1.0,
                     self.batch_size: len(batch.encoder_inputs)}
        loss, summary = sess.run([self.loss, self.summary_op], feed_dict=feed_dict)
        return loss, summary

    def infer(self, sess, batch):
        # For inference only the final prediction is needed; no loss is computed,
        # so the feed_dict only needs the encoder inputs
        feed_dict = {self.encoder_inputs: batch.encoder_inputs,
                     self.encoder_inputs_length: batch.encoder_inputs_length,
                     self.keep_prob_placeholder: 1.0,
                     self.batch_size: len(batch.encoder_inputs)}
        predict = sess.run([self.decoder_predict_decode], feed_dict=feed_dict)
        return predict
--------------------------------------------------------------------------------
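
As a sanity check on the training-decoder input construction in `build_model` (strided_slice drops the last target step, then a `<go>` id is prepended), the same transformation on plain Python lists, with assumed ids `<go>` = 1, `<eos>` = 2, `<pad>` = 0:

# Toy illustration; the token ids below are assumptions matching the constants in data_helpers.py.
goToken = 1
decoder_targets = [[7, 8, 2, 0],     # w7 w8 <eos> <pad>
                   [5, 2, 0, 0]]     # w5 <eos> <pad> <pad>
decoder_input = [[goToken] + row[:-1] for row in decoder_targets]
print(decoder_input)                 # [[1, 7, 8, 2], [1, 5, 2, 0]]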