├── .idea ├── dictionaries │ └── liuchong.xml ├── vcs.xml ├── modules.xml ├── misc.xml └── seq2seq_chatbot_new.iml ├── README.md ├── data └── dataset-cornell-length10-filter1-vocabSize40000.pkl ├── predict.py ├── train.py ├── data_helpers.py └── model.py /.idea/dictionaries/liuchong.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # seq2seq_chatbot_new 2 | 基于seq2seq模型的简单对话系统的tf实现,具有embedding、attention、beam_search等功能,数据集是Cornell Movie Dialogs 3 | -------------------------------------------------------------------------------- /data/dataset-cornell-length10-filter1-vocabSize40000.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lc222/seq2seq_chatbot_new/HEAD/data/dataset-cornell-length10-filter1-vocabSize40000.pkl -------------------------------------------------------------------------------- /.idea/vcs.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | -------------------------------------------------------------------------------- /.idea/modules.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | -------------------------------------------------------------------------------- /.idea/misc.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | ApexVCS 5 | 6 | 7 | -------------------------------------------------------------------------------- /.idea/seq2seq_chatbot_new.iml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 12 | -------------------------------------------------------------------------------- /predict.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | from data_helpers import loadDataset, getBatches, sentence2enco 3 | from model import Seq2SeqModel 4 | import sys 5 | import numpy as np 6 | 7 | 8 | tf.app.flags.DEFINE_integer('rnn_size', 1024, 'Number of hidden units in each layer') 9 | tf.app.flags.DEFINE_integer('num_layers', 2, 'Number of layers in each encoder and decoder') 10 | tf.app.flags.DEFINE_integer('embedding_size', 1024, 'Embedding dimensions of encoder and decoder inputs') 11 | 12 | tf.app.flags.DEFINE_float('learning_rate', 0.0001, 'Learning rate') 13 | tf.app.flags.DEFINE_integer('batch_size', 128, 'Batch size') 14 | tf.app.flags.DEFINE_integer('numEpochs', 30, 'Maximum # of training epochs') 15 | tf.app.flags.DEFINE_integer('steps_per_checkpoint', 100, 'Save model checkpoint every this iteration') 16 | tf.app.flags.DEFINE_string('model_dir', 'model/', 'Path to save model checkpoints') 17 | tf.app.flags.DEFINE_string('model_name', 'chatbot.ckpt', 'File name used for model checkpoints') 18 | FLAGS = tf.app.flags.FLAGS 19 | 20 | data_path = 'E:\PycharmProjects\seq2seq_chatbot\seq2seq_chatbot_new\data\dataset-cornell-length10-filter1-vocabSize40000.pkl' 21 | word2id, id2word, trainingSamples = loadDataset(data_path) 22 | 23 | def predict_ids_to_seq(predict_ids, id2word, beam_szie): 24 | ''' 25 | 将beam_search返回的结果转化为字符串 26 | :param predict_ids: 列表,长度为batch_size,每个元素都是decode_len*beam_size的数组 27 | :param id2word: vocab字典 28 | :return: 29 | ''' 30 | for single_predict in predict_ids: 31 | for i in range(beam_szie): 32 | predict_list = np.ndarray.tolist(single_predict[:, :, i]) 33 | predict_seq = [id2word[idx] for idx in predict_list[0]] 34 | print(" ".join(predict_seq)) 35 | 36 | with tf.Session() as sess: 37 | model = Seq2SeqModel(FLAGS.rnn_size, FLAGS.num_layers, FLAGS.embedding_size, FLAGS.learning_rate, word2id, 38 | mode='decode', use_attention=True, beam_search=True, beam_size=5, max_gradient_norm=5.0) 39 | ckpt = tf.train.get_checkpoint_state(FLAGS.model_dir) 40 | if ckpt and tf.train.checkpoint_exists(ckpt.model_checkpoint_path): 41 | print('Reloading model parameters..') 42 | model.saver.restore(sess, ckpt.model_checkpoint_path) 43 | else: 44 | raise ValueError('No such file:[{}]'.format(FLAGS.model_dir)) 45 | sys.stdout.write("> ") 46 | sys.stdout.flush() 47 | sentence = sys.stdin.readline() 48 | while sentence: 49 | batch = sentence2enco(sentence, word2id) 50 | predicted_ids = model.infer(sess, batch) 51 | # print(predicted_ids) 52 | predict_ids_to_seq(predicted_ids, id2word, 5) 53 | print("> ", "") 54 | sys.stdout.flush() 55 | sentence = sys.stdin.readline() 56 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | from data_helpers import loadDataset, getBatches, sentence2enco 3 | from model import Seq2SeqModel 4 | from tqdm import tqdm 5 | import math 6 | import os 7 | 8 | tf.app.flags.DEFINE_integer('rnn_size', 1024, 'Number of hidden units in each layer') 9 | tf.app.flags.DEFINE_integer('num_layers', 2, 'Number of layers in each encoder and decoder') 10 | tf.app.flags.DEFINE_integer('embedding_size', 1024, 'Embedding dimensions of encoder and decoder inputs') 11 | 12 | tf.app.flags.DEFINE_float('learning_rate', 0.0001, 'Learning rate') 13 | tf.app.flags.DEFINE_integer('batch_size', 128, 'Batch size') 14 | tf.app.flags.DEFINE_integer('numEpochs', 30, 'Maximum # of training epochs') 15 | tf.app.flags.DEFINE_integer('steps_per_checkpoint', 100, 'Save model checkpoint every this iteration') 16 | tf.app.flags.DEFINE_string('model_dir', 'model/', 'Path to save model checkpoints') 17 | tf.app.flags.DEFINE_string('model_name', 'chatbot.ckpt', 'File name used for model checkpoints') 18 | FLAGS = tf.app.flags.FLAGS 19 | 20 | data_path = 'E:\PycharmProjects\seq2seq_chatbot\seq2seq_chatbot_new\data\dataset-cornell-length10-filter1-vocabSize40000.pkl' 21 | word2id, id2word, trainingSamples = loadDataset(data_path) 22 | 23 | with tf.Session() as sess: 24 | model = Seq2SeqModel(FLAGS.rnn_size, FLAGS.num_layers, FLAGS.embedding_size, FLAGS.learning_rate, word2id, 25 | mode='train', use_attention=True, beam_search=False, beam_size=5, max_gradient_norm=5.0) 26 | ckpt = tf.train.get_checkpoint_state(FLAGS.model_dir) 27 | if ckpt and tf.train.checkpoint_exists(ckpt.model_checkpoint_path): 28 | print('Reloading model parameters..') 29 | model.restore(sess, ckpt.model_checkpoint_path) 30 | else: 31 | print('Created new model parameters..') 32 | sess.run(tf.global_variables_initializer()) 33 | current_step = 0 34 | summary_writer = tf.summary.FileWriter(FLAGS.model_dir, graph=sess.graph) 35 | for e in range(FLAGS.numEpochs): 36 | print("----- Epoch {}/{} -----".format(e + 1, FLAGS.numEpochs)) 37 | batches = getBatches(trainingSamples, FLAGS.batch_size) 38 | for nextBatch in tqdm(batches, desc="Training"): 39 | loss, summary = model.train(sess, nextBatch) 40 | current_step += 1 41 | if current_step % FLAGS.steps_per_checkpoint == 0: 42 | perplexity = math.exp(float(loss)) if loss < 300 else float('inf') 43 | tqdm.write("----- Step %d -- Loss %.2f -- Perplexity %.2f" % (current_step, loss, perplexity)) 44 | summary_writer.add_summary(summary, current_step) 45 | checkpoint_path = os.path.join(FLAGS.model_dir, FLAGS.model_name) 46 | model.saver.save(sess, checkpoint_path, global_step=current_step) -------------------------------------------------------------------------------- /data_helpers.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | import os 6 | import nltk 7 | import numpy as np 8 | import pickle 9 | import random 10 | 11 | padToken, goToken, eosToken, unknownToken = 0, 1, 2, 3 12 | 13 | class Batch: 14 | #batch类,里面包含了encoder输入,decoder输入,decoder标签,decoder样本长度mask 15 | def __init__(self): 16 | self.encoder_inputs = [] 17 | self.encoder_inputs_length = [] 18 | self.decoder_targets = [] 19 | self.decoder_targets_length = [] 20 | 21 | def loadDataset(filename): 22 | ''' 23 | 读取样本数据 24 | :param filename: 文件路径,是一个字典,包含word2id、id2word分别是单词与索引对应的字典和反序字典, 25 | trainingSamples样本数据,每一条都是QA对 26 | :return: word2id, id2word, trainingSamples 27 | ''' 28 | dataset_path = os.path.join(filename) 29 | print('Loading dataset from {}'.format(dataset_path)) 30 | with open(dataset_path, 'rb') as handle: 31 | data = pickle.load(handle) # Warning: If adding something here, also modifying saveDataset 32 | word2id = data['word2id'] 33 | id2word = data['id2word'] 34 | trainingSamples = data['trainingSamples'] 35 | return word2id, id2word, trainingSamples 36 | 37 | def createBatch(samples): 38 | ''' 39 | 根据给出的samples(就是一个batch的数据),进行padding并构造成placeholder所需要的数据形式 40 | :param samples: 一个batch的样本数据,列表,每个元素都是[question, answer]的形式,id 41 | :return: 处理完之后可以直接传入feed_dict的数据格式 42 | ''' 43 | batch = Batch() 44 | batch.encoder_inputs_length = [len(sample[0]) for sample in samples] 45 | batch.decoder_targets_length = [len(sample[1]) for sample in samples] 46 | 47 | max_source_length = max(batch.encoder_inputs_length) 48 | max_target_length = max(batch.decoder_targets_length) 49 | 50 | for sample in samples: 51 | #将source进行反序并PAD值本batch的最大长度 52 | source = list(reversed(sample[0])) 53 | pad = [padToken] * (max_source_length - len(source)) 54 | batch.encoder_inputs.append(pad + source) 55 | 56 | #将target进行PAD,并添加END符号 57 | target = sample[1] 58 | pad = [padToken] * (max_target_length - len(target)) 59 | batch.decoder_targets.append(target + pad) 60 | #batch.target_inputs.append([goToken] + target + pad[:-1]) 61 | 62 | return batch 63 | 64 | def getBatches(data, batch_size): 65 | ''' 66 | 根据读取出来的所有数据和batch_size将原始数据分成不同的小batch。对每个batch索引的样本调用createBatch函数进行处理 67 | :param data: loadDataset函数读取之后的trainingSamples,就是QA对的列表 68 | :param batch_size: batch大小 69 | :param en_de_seq_len: 列表,第一个元素表示source端序列的最大长度,第二个元素表示target端序列的最大长度 70 | :return: 列表,每个元素都是一个batch的样本数据,可直接传入feed_dict进行训练 71 | ''' 72 | #每个epoch之前都要进行样本的shuffle 73 | random.shuffle(data) 74 | batches = [] 75 | data_len = len(data) 76 | def genNextSamples(): 77 | for i in range(0, data_len, batch_size): 78 | yield data[i:min(i + batch_size, data_len)] 79 | 80 | for samples in genNextSamples(): 81 | batch = createBatch(samples) 82 | batches.append(batch) 83 | return batches 84 | 85 | def sentence2enco(sentence, word2id): 86 | ''' 87 | 测试的时候将用户输入的句子转化为可以直接feed进模型的数据,现将句子转化成id,然后调用createBatch处理 88 | :param sentence: 用户输入的句子 89 | :param word2id: 单词与id之间的对应关系字典 90 | :param en_de_seq_len: 列表,第一个元素表示source端序列的最大长度,第二个元素表示target端序列的最大长度 91 | :return: 处理之后的数据,可直接feed进模型进行预测 92 | ''' 93 | if sentence == '': 94 | return None 95 | #分词 96 | tokens = nltk.word_tokenize(sentence) 97 | if len(tokens) > 20: 98 | return None 99 | #将每个单词转化为id 100 | wordIds = [] 101 | for token in tokens: 102 | wordIds.append(word2id.get(token, unknownToken)) 103 | #调用createBatch构造batch 104 | batch = createBatch([[wordIds, []]]) 105 | return batch 106 | -------------------------------------------------------------------------------- /model.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | from tensorflow.python.util import nest 3 | 4 | 5 | class Seq2SeqModel(): 6 | def __init__(self, rnn_size, num_layers, embedding_size, learning_rate, word_to_idx, mode, use_attention, 7 | beam_search, beam_size, max_gradient_norm=5.0): 8 | self.learing_rate = learning_rate 9 | self.embedding_size = embedding_size 10 | self.rnn_size = rnn_size 11 | self.num_layers = num_layers 12 | self.word_to_idx = word_to_idx 13 | self.vocab_size = len(self.word_to_idx) 14 | self.mode = mode 15 | self.use_attention = use_attention 16 | self.beam_search = beam_search 17 | self.beam_size = beam_size 18 | self.max_gradient_norm = max_gradient_norm 19 | #执行模型构建部分的代码 20 | self.build_model() 21 | 22 | def _create_rnn_cell(self): 23 | def single_rnn_cell(): 24 | # 创建单个cell,这里需要注意的是一定要使用一个single_rnn_cell的函数,不然直接把cell放在MultiRNNCell 25 | # 的列表中最终模型会发生错误 26 | single_cell = tf.contrib.rnn.LSTMCell(self.rnn_size) 27 | #添加dropout 28 | cell = tf.contrib.rnn.DropoutWrapper(single_cell, output_keep_prob=self.keep_prob_placeholder) 29 | return cell 30 | #列表中每个元素都是调用single_rnn_cell函数 31 | cell = tf.contrib.rnn.MultiRNNCell([single_rnn_cell() for _ in range(self.num_layers)]) 32 | return cell 33 | 34 | def build_model(self): 35 | print('building model... ...') 36 | #=================================1, 定义模型的placeholder 37 | self.encoder_inputs = tf.placeholder(tf.int32, [None, None], name='encoder_inputs') 38 | self.encoder_inputs_length = tf.placeholder(tf.int32, [None], name='encoder_inputs_length') 39 | 40 | self.batch_size = tf.placeholder(tf.int32, [], name='batch_size') 41 | self.keep_prob_placeholder = tf.placeholder(tf.float32, name='keep_prob_placeholder') 42 | 43 | self.decoder_targets = tf.placeholder(tf.int32, [None, None], name='decoder_targets') 44 | self.decoder_targets_length = tf.placeholder(tf.int32, [None], name='decoder_targets_length') 45 | # 根据目标序列长度,选出其中最大值,然后使用该值构建序列长度的mask标志。用一个sequence_mask的例子来说明起作用 46 | # tf.sequence_mask([1, 3, 2], 5) 47 | # [[True, False, False, False, False], 48 | # [True, True, True, False, False], 49 | # [True, True, False, False, False]] 50 | self.max_target_sequence_length = tf.reduce_max(self.decoder_targets_length, name='max_target_len') 51 | self.mask = tf.sequence_mask(self.decoder_targets_length, self.max_target_sequence_length, dtype=tf.float32, name='masks') 52 | 53 | #=================================2, 定义模型的encoder部分 54 | with tf.variable_scope('encoder'): 55 | #创建LSTMCell,两层+dropout 56 | encoder_cell = self._create_rnn_cell() 57 | #构建embedding矩阵,encoder和decoder公用该词向量矩阵 58 | embedding = tf.get_variable('embedding', [self.vocab_size, self.embedding_size]) 59 | encoder_inputs_embedded = tf.nn.embedding_lookup(embedding, self.encoder_inputs) 60 | # 使用dynamic_rnn构建LSTM模型,将输入编码成隐层向量。 61 | # encoder_outputs用于attention,batch_size*encoder_inputs_length*rnn_size, 62 | # encoder_state用于decoder的初始化状态,batch_size*rnn_szie 63 | encoder_outputs, encoder_state = tf.nn.dynamic_rnn(encoder_cell, encoder_inputs_embedded, 64 | sequence_length=self.encoder_inputs_length, 65 | dtype=tf.float32) 66 | 67 | # =================================3, 定义模型的decoder部分 68 | with tf.variable_scope('decoder'): 69 | encoder_inputs_length = self.encoder_inputs_length 70 | if self.beam_search: 71 | # 如果使用beam_search,则需要将encoder的输出进行tile_batch,其实就是复制beam_size份。 72 | print("use beamsearch decoding..") 73 | encoder_outputs = tf.contrib.seq2seq.tile_batch(encoder_outputs, multiplier=self.beam_size) 74 | encoder_state = nest.map_structure(lambda s: tf.contrib.seq2seq.tile_batch(s, self.beam_size), encoder_state) 75 | encoder_inputs_length = tf.contrib.seq2seq.tile_batch(self.encoder_inputs_length, multiplier=self.beam_size) 76 | 77 | #定义要使用的attention机制。 78 | attention_mechanism = tf.contrib.seq2seq.BahdanauAttention(num_units=self.rnn_size, memory=encoder_outputs, 79 | memory_sequence_length=encoder_inputs_length) 80 | #attention_mechanism = tf.contrib.seq2seq.LuongAttention(num_units=self.rnn_size, memory=encoder_outputs, memory_sequence_length=encoder_inputs_length) 81 | # 定义decoder阶段要是用的LSTMCell,然后为其封装attention wrapper 82 | decoder_cell = self._create_rnn_cell() 83 | decoder_cell = tf.contrib.seq2seq.AttentionWrapper(cell=decoder_cell, attention_mechanism=attention_mechanism, 84 | attention_layer_size=self.rnn_size, name='Attention_Wrapper') 85 | #如果使用beam_seach则batch_size = self.batch_size * self.beam_size。因为之前已经复制过一次 86 | batch_size = self.batch_size if not self.beam_search else self.batch_size * self.beam_size 87 | #定义decoder阶段的初始化状态,直接使用encoder阶段的最后一个隐层状态进行赋值 88 | decoder_initial_state = decoder_cell.zero_state(batch_size=batch_size, dtype=tf.float32).clone(cell_state=encoder_state) 89 | output_layer = tf.layers.Dense(self.vocab_size, kernel_initializer=tf.truncated_normal_initializer(mean=0.0, stddev=0.1)) 90 | 91 | if self.mode == 'train': 92 | # 定义decoder阶段的输入,其实就是在decoder的target开始处添加一个,并删除结尾处的,并进行embedding。 93 | # decoder_inputs_embedded的shape为[batch_size, decoder_targets_length, embedding_size] 94 | ending = tf.strided_slice(self.decoder_targets, [0, 0], [self.batch_size, -1], [1, 1]) 95 | decoder_input = tf.concat([tf.fill([self.batch_size, 1], self.word_to_idx['']), ending], 1) 96 | decoder_inputs_embedded = tf.nn.embedding_lookup(embedding, decoder_input) 97 | #训练阶段,使用TrainingHelper+BasicDecoder的组合,这一般是固定的,当然也可以自己定义Helper类,实现自己的功能 98 | training_helper = tf.contrib.seq2seq.TrainingHelper(inputs=decoder_inputs_embedded, 99 | sequence_length=self.decoder_targets_length, 100 | time_major=False, name='training_helper') 101 | training_decoder = tf.contrib.seq2seq.BasicDecoder(cell=decoder_cell, helper=training_helper, 102 | initial_state=decoder_initial_state, output_layer=output_layer) 103 | #调用dynamic_decode进行解码,decoder_outputs是一个namedtuple,里面包含两项(rnn_outputs, sample_id) 104 | # rnn_output: [batch_size, decoder_targets_length, vocab_size],保存decode每个时刻每个单词的概率,可以用来计算loss 105 | # sample_id: [batch_size], tf.int32,保存最终的编码结果。可以表示最后的答案 106 | decoder_outputs, _, _ = tf.contrib.seq2seq.dynamic_decode(decoder=training_decoder, 107 | impute_finished=True, 108 | maximum_iterations=self.max_target_sequence_length) 109 | # 根据输出计算loss和梯度,并定义进行更新的AdamOptimizer和train_op 110 | self.decoder_logits_train = tf.identity(decoder_outputs.rnn_output) 111 | self.decoder_predict_train = tf.argmax(self.decoder_logits_train, axis=-1, name='decoder_pred_train') 112 | # 使用sequence_loss计算loss,这里需要传入之前定义的mask标志 113 | self.loss = tf.contrib.seq2seq.sequence_loss(logits=self.decoder_logits_train, 114 | targets=self.decoder_targets, weights=self.mask) 115 | 116 | # Training summary for the current batch_loss 117 | tf.summary.scalar('loss', self.loss) 118 | self.summary_op = tf.summary.merge_all() 119 | 120 | optimizer = tf.train.AdamOptimizer(self.learing_rate) 121 | trainable_params = tf.trainable_variables() 122 | gradients = tf.gradients(self.loss, trainable_params) 123 | clip_gradients, _ = tf.clip_by_global_norm(gradients, self.max_gradient_norm) 124 | self.train_op = optimizer.apply_gradients(zip(clip_gradients, trainable_params)) 125 | elif self.mode == 'decode': 126 | start_tokens = tf.ones([self.batch_size, ], tf.int32) * self.word_to_idx[''] 127 | end_token = self.word_to_idx[''] 128 | # decoder阶段根据是否使用beam_search决定不同的组合, 129 | # 如果使用则直接调用BeamSearchDecoder(里面已经实现了helper类) 130 | # 如果不使用则调用GreedyEmbeddingHelper+BasicDecoder的组合进行贪婪式解码 131 | if self.beam_search: 132 | inference_decoder = tf.contrib.seq2seq.BeamSearchDecoder(cell=decoder_cell, embedding=embedding, 133 | start_tokens=start_tokens, end_token=end_token, 134 | initial_state=decoder_initial_state, 135 | beam_width=self.beam_size, 136 | output_layer=output_layer) 137 | else: 138 | decoding_helper = tf.contrib.seq2seq.GreedyEmbeddingHelper(embedding=embedding, 139 | start_tokens=start_tokens, end_token=end_token) 140 | inference_decoder = tf.contrib.seq2seq.BasicDecoder(cell=decoder_cell, helper=decoding_helper, 141 | initial_state=decoder_initial_state, 142 | output_layer=output_layer) 143 | decoder_outputs, _, _ = tf.contrib.seq2seq.dynamic_decode(decoder=inference_decoder, 144 | maximum_iterations=10) 145 | # 调用dynamic_decode进行解码,decoder_outputs是一个namedtuple, 146 | # 对于不使用beam_search的时候,它里面包含两项(rnn_outputs, sample_id) 147 | # rnn_output: [batch_size, decoder_targets_length, vocab_size] 148 | # sample_id: [batch_size, decoder_targets_length], tf.int32 149 | 150 | # 对于使用beam_search的时候,它里面包含两项(predicted_ids, beam_search_decoder_output) 151 | # predicted_ids: [batch_size, decoder_targets_length, beam_size],保存输出结果 152 | # beam_search_decoder_output: BeamSearchDecoderOutput instance namedtuple(scores, predicted_ids, parent_ids) 153 | # 所以对应只需要返回predicted_ids或者sample_id即可翻译成最终的结果 154 | if self.beam_search: 155 | self.decoder_predict_decode = decoder_outputs.predicted_ids 156 | else: 157 | self.decoder_predict_decode = tf.expand_dims(decoder_outputs.sample_id, -1) 158 | # =================================4, 保存模型 159 | self.saver = tf.train.Saver(tf.global_variables()) 160 | 161 | def train(self, sess, batch): 162 | #对于训练阶段,需要执行self.train_op, self.loss, self.summary_op三个op,并传入相应的数据 163 | feed_dict = {self.encoder_inputs: batch.encoder_inputs, 164 | self.encoder_inputs_length: batch.encoder_inputs_length, 165 | self.decoder_targets: batch.decoder_targets, 166 | self.decoder_targets_length: batch.decoder_targets_length, 167 | self.keep_prob_placeholder: 0.5, 168 | self.batch_size: len(batch.encoder_inputs)} 169 | _, loss, summary = sess.run([self.train_op, self.loss, self.summary_op], feed_dict=feed_dict) 170 | return loss, summary 171 | 172 | def eval(self, sess, batch): 173 | # 对于eval阶段,不需要反向传播,所以只执行self.loss, self.summary_op两个op,并传入相应的数据 174 | feed_dict = {self.encoder_inputs: batch.encoder_inputs, 175 | self.encoder_inputs_length: batch.encoder_inputs_length, 176 | self.decoder_targets: batch.decoder_targets, 177 | self.decoder_targets_length: batch.decoder_targets_length, 178 | self.keep_prob_placeholder: 1.0, 179 | self.batch_size: len(batch.encoder_inputs)} 180 | loss, summary = sess.run([self.loss, self.summary_op], feed_dict=feed_dict) 181 | return loss, summary 182 | 183 | def infer(self, sess, batch): 184 | #infer阶段只需要运行最后的结果,不需要计算loss,所以feed_dict只需要传入encoder_input相应的数据即可 185 | feed_dict = {self.encoder_inputs: batch.encoder_inputs, 186 | self.encoder_inputs_length: batch.encoder_inputs_length, 187 | self.keep_prob_placeholder: 1.0, 188 | self.batch_size: len(batch.encoder_inputs)} 189 | predict = sess.run([self.decoder_predict_decode], feed_dict=feed_dict) 190 | return predict --------------------------------------------------------------------------------