├── assets
│   ├── model.png
│   ├── test.png
│   ├── decoder.png
│   ├── encoder.png
│   └── seq2seq.png
├── README.md
├── train.py
├── predict.py
├── model.py
└── utils.py

/assets/model.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yanqiangmiffy/seq2seq-nmt/HEAD/assets/model.png
--------------------------------------------------------------------------------
/assets/test.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yanqiangmiffy/seq2seq-nmt/HEAD/assets/test.png
--------------------------------------------------------------------------------
/assets/decoder.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yanqiangmiffy/seq2seq-nmt/HEAD/assets/decoder.png
--------------------------------------------------------------------------------
/assets/encoder.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yanqiangmiffy/seq2seq-nmt/HEAD/assets/encoder.png
--------------------------------------------------------------------------------
/assets/seq2seq.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yanqiangmiffy/seq2seq-nmt/HEAD/assets/seq2seq.png
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# seq2seq-nmt

A Keras implementation of seq2seq for English-to-Chinese translation.

![](https://github.com/yanqiangmiffy/seq2seq-nmt/blob/master/assets/seq2seq.png)

## Model architecture

![](https://github.com/yanqiangmiffy/seq2seq-nmt/blob/master/assets/model.png)

## Inference model architecture

- encoder

![](https://github.com/yanqiangmiffy/seq2seq-nmt/blob/master/assets/encoder.png)

- decoder

![](https://github.com/yanqiangmiffy/seq2seq-nmt/blob/master/assets/decoder.png)

## Test results

```text
Is it all there?
全都在那裡嗎?

Is it too salty?
还有多余的盐吗?

Is she Japanese?
她是日本人嗎?

Is this a river?
這是一條河嗎?

Isn't that mine?
那是我的吗?

It is up to you.
由你來決定。
```
```text
python predict.py --eng_sent "It's a nice day."

今天天氣很好。
```

## References

https://github.com/pjgao/seq2seq_keras/blob/master/seq2seq_keras.ipynb

https://github.com/keras-team/keras/blob/master/examples/lstm_seq2seq.py

https://blog.csdn.net/PIPIXIU/article/details/81016974
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
from model import Seq2Seq
from utils import load_data
# from keras.utils import plot_model
import os
import numpy as np

# On Windows, Graphviz may need to be added to PATH for plot_model:
# os.environ["PATH"] += os.pathsep + 'E:/Program Files (x86)/Graphviz2.38/bin'

# Hyperparameters
file_path = 'data/cmn.txt'
n_units = 256
batch_size = 64
epoch = 200
num_samples = 10000

# Load the data
input_texts, target_texts, input_dict, target_dict, target_dict_reverse, \
output_length, input_feature_length, output_feature_length, \
encoder_input, decoder_input, decoder_output = load_data(file_path, num_samples)

seq2seq = Seq2Seq(input_feature_length, output_feature_length, n_units)
model_train, encoder_infer, decoder_infer = seq2seq.create_model()

# Visualize the model structures
# plot_model(to_file='assets/model.png', model=model_train, show_shapes=True)
# plot_model(to_file='assets/encoder.png', model=encoder_infer, show_shapes=True)
# plot_model(to_file='assets/decoder.png', model=decoder_infer, show_shapes=True)

model_train.compile(optimizer='rmsprop', loss='categorical_crossentropy')

model_train.summary()
encoder_infer.summary()
decoder_infer.summary()

# Train the model
model_train.fit([encoder_input, decoder_input], decoder_output,
                batch_size=batch_size, epochs=epoch, validation_split=0.2)

# Make sure the output directory exists before saving
os.makedirs("result", exist_ok=True)
model_train.save("result/model_train.h5")
encoder_infer.save("result/encoder_infer.h5")
decoder_infer.save("result/decoder_infer.h5")
--------------------------------------------------------------------------------
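The three tensors that `load_data` feeds to `model_train.fit` above implement teacher forcing: `decoder_output` is simply `decoder_input` shifted one time step to the left, so at every step the decoder is trained to predict the *next* character. A minimal sketch of that offset, using a hypothetical four-character vocabulary instead of the real dictionaries built from `cmn.txt`:

```python
import numpy as np

# Hypothetical stand-in for the target dictionary that load_data builds.
target_dict = {'\t': 0, 'A': 1, 'B': 2, '\n': 3}
sentence = '\tAB\n'  # start marker, two characters, end marker

decoder_input = np.zeros((1, len(sentence), len(target_dict)))
decoder_output = np.zeros((1, len(sentence), len(target_dict)))
for t, char in enumerate(sentence):
    decoder_input[0, t, target_dict[char]] = 1.0
    if t > 0:
        # The label at step t-1 is the character that appears at step t.
        decoder_output[0, t - 1, target_dict[char]] = 1.0

print(decoder_input[0].argmax(axis=-1))   # [0 1 2 3]  -> \t A B \n
print(decoder_output[0].argmax(axis=-1))  # [1 2 3 0]  -> A B \n (last row is all-zero padding)
```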
/predict.py:
--------------------------------------------------------------------------------
from keras.models import load_model
import numpy as np
from utils import load_data
import argparse

# Hyperparameters; num_samples must match train.py so that the
# character dictionaries rebuilt here line up with the trained model
file_path = 'data/cmn.txt'
n_units = 256
batch_size = 64
epoch = 20
num_samples = 10000

# Load the data (rebuilds the dictionaries and encoded arrays)
input_texts, target_texts, input_dict, target_dict, target_dict_reverse, \
output_length, input_feature_length, output_feature_length, \
encoder_input, decoder_input, decoder_output = load_data(file_path, num_samples)


def predict_chinese(source, encoder_inference, decoder_inference, n_steps, features):
    # First run the inference encoder to get the hidden states of the input sequence
    state = encoder_inference.predict(source)
    # The first character is '\t', the start-of-sequence marker
    predict_seq = np.zeros((1, 1, features))
    predict_seq[0, 0, target_dict['\t']] = 1

    output = ''
    # Decode step by step from the encoder's hidden states:
    # each iteration feeds the previously predicted character back in
    # to predict the next one, until the end marker is produced
    for i in range(n_steps):  # n_steps is the maximum sentence length
        # Feed the decoder the previous step's h, c states together with
        # the previously predicted character predict_seq
        yhat, h, c = decoder_inference.predict([predict_seq] + state)
        # Note: yhat is the output of the Dense layer, so it differs from h
        char_index = np.argmax(yhat[0, -1, :])
        char = target_dict_reverse[char_index]
        output += char
        state = [h, c]  # this step's states become the next initial states
        predict_seq = np.zeros((1, 1, features))
        predict_seq[0, 0, char_index] = 1
        if char == '\n':  # stop once the end marker is predicted
            break
    return output


encoder_infer = load_model("result/encoder_infer.h5")
decoder_infer = load_model("result/decoder_infer.h5")

# for i in range(1000, 1100):
#     test = encoder_input[i:i+1, :, :]  # i:i+1 keeps the array three-dimensional
#     out = predict_chinese(test, encoder_infer, decoder_infer, output_length, output_feature_length)
#     print(input_texts[i])
#     print(out)


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description="Please input an English sentence!")
    parser.add_argument('--eng_sent', '-e', required=True, help="an English sentence")
    args = parser.parse_args()

    seq = args.eng_sent
    input_length = max([len(i) for i in input_texts])
    encoder_input = np.zeros((1, input_length, input_feature_length))
    for char_index, char in enumerate(seq):
        encoder_input[0, char_index, input_dict[char]] = 1
    out = predict_chinese(encoder_input, encoder_infer, decoder_infer, output_length, output_feature_length)
    print(out)
--------------------------------------------------------------------------------
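Note that the `__main__` block above assumes every character of `--eng_sent` occurred in the first `num_samples` training pairs and that the sentence is no longer than the longest training sentence; otherwise the one-hot loop raises a `KeyError` or `IndexError`. A more defensive encoding helper might look like the sketch below (skipping unknown characters is one possible policy, not what the script currently does):

```python
import numpy as np

def encode_sentence(seq, input_dict, input_length, input_feature_length):
    """One-hot encode seq, ignoring characters the model never saw during
    training and truncating anything past the maximum training length."""
    encoder_input = np.zeros((1, input_length, input_feature_length))
    for char_index, char in enumerate(seq[:input_length]):
        index = input_dict.get(char)  # None for out-of-vocabulary characters
        if index is not None:
            encoder_input[0, char_index, index] = 1
    return encoder_input
```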
/model.py:
--------------------------------------------------------------------------------
from keras.layers import Input, LSTM, Dense
from keras.models import Model


class Seq2Seq(object):

    def __init__(self, n_input, n_output, n_units):
        self.n_input = n_input
        self.n_output = n_output
        self.n_units = n_units

    def create_model(self):
        # Training phase
        # encoder
        encoder_input = Input(shape=(None, self.n_input))
        # n_input is the dimension of each time step's input xt --
        # here, the number of English characters used for one-hot encoding
        encoder = LSTM(self.n_units, return_state=True)
        # n_units is the number of neurons in each LSTM gate;
        # return_state=True is needed to get the final states h, c
        _, encoder_h, encoder_c = encoder(encoder_input)
        encoder_state = [encoder_h, encoder_c]
        # Keep the encoder's final state as the decoder's initial state

        # decoder
        decoder_input = Input(shape=(None, self.n_output))
        # The decoder's input dimension is the number of Chinese characters
        decoder = LSTM(self.n_units, return_sequences=True, return_state=True)
        # Training compares the decoder's full output sequence against the
        # targets, so return_sequences must also be True
        decoder_output, _, _ = decoder(decoder_input, initial_state=encoder_state)
        # During training only the decoder's output sequence is used,
        # not the final states h, c
        decoder_dense = Dense(self.n_output, activation='softmax')
        decoder_output = decoder_dense(decoder_output)
        # The output sequence goes through a fully connected layer

        # The assembled training model
        model = Model([encoder_input, decoder_input], decoder_output)
        # The first argument holds the training model's inputs (both the
        # encoder's and the decoder's); the second holds its output
        # (the decoder's output)

        # Inference phase, used for prediction
        # Inference model - encoder
        encoder_infer = Model(encoder_input, encoder_state)

        # Inference model - decoder
        decoder_state_input_h = Input(shape=(self.n_units,))
        decoder_state_input_c = Input(shape=(self.n_units,))
        decoder_state_input = [decoder_state_input_h, decoder_state_input_c]  # previous step's states h, c

        decoder_infer_output, decoder_infer_state_h, decoder_infer_state_c = decoder(
            decoder_input, initial_state=decoder_state_input)
        decoder_infer_state = [decoder_infer_state_h, decoder_infer_state_c]  # states produced at this step
        decoder_infer_output = decoder_dense(decoder_infer_output)  # this step's output
        decoder_infer = Model([decoder_input] + decoder_state_input,
                              [decoder_infer_output] + decoder_infer_state)

        return model, encoder_infer, decoder_infer
--------------------------------------------------------------------------------
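Because `encoder_infer` and `decoder_infer` are built from the very same layer objects as the training model, their weights stay in sync without any copying; only the wiring differs. A quick shape check with toy sizes (73 and 100 are hypothetical vocabulary sizes, not the real ones produced by `load_data`):

```python
import numpy as np
from model import Seq2Seq

n_input, n_output, n_units = 73, 100, 256  # hypothetical dimensions
model, encoder_infer, decoder_infer = Seq2Seq(n_input, n_output, n_units).create_model()

# Encode a random 16-step one-hot "sentence".
source = np.zeros((1, 16, n_input))
source[0, np.arange(16), np.random.randint(n_input, size=16)] = 1
h, c = encoder_infer.predict(source)  # two (1, 256) state vectors

# One greedy decoding step: a start token plus the encoder states.
start = np.zeros((1, 1, n_output))
start[0, 0, 0] = 1
yhat, h, c = decoder_infer.predict([start, h, c])
print(yhat.shape)  # (1, 1, 100): a softmax distribution over the output vocabulary
```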
/utils.py:
--------------------------------------------------------------------------------
import pandas as pd
import numpy as np


def load_data(filepath, num_samples=10000):
    data = pd.read_table(filepath, header=None).iloc[:num_samples, :]
    data.columns = ['inputs', 'targets']
    # Wrap every target sentence with '\t' (start marker) and '\n' (end marker)
    data['targets'] = data['targets'].apply(lambda x: '\t' + x + '\n')

    input_texts = data['inputs'].values.tolist()
    target_texts = data['targets'].values.tolist()

    # Concatenating the unique sentences yields the character vocabularies
    input_characters = sorted(list(set(data.inputs.unique().sum())))
    targets_characters = sorted(list(set(data.targets.unique().sum())))
    # print(targets_characters)

    input_length = max([len(i) for i in input_texts])
    output_length = max([len(i) for i in target_texts])
    input_feature_length = len(input_characters)
    output_feature_length = len(targets_characters)

    n_samples = len(input_texts)  # may be smaller than num_samples if the file is short
    encoder_input = np.zeros((n_samples, input_length, input_feature_length))
    decoder_input = np.zeros((n_samples, output_length, output_feature_length))
    decoder_output = np.zeros((n_samples, output_length, output_feature_length))

    input_dict = {char: index for index, char in enumerate(input_characters)}
    input_dict_reverse = {index: char for index, char in enumerate(input_characters)}

    target_dict = {char: index for index, char in enumerate(targets_characters)}
    target_dict_reverse = {index: char for index, char in enumerate(targets_characters)}

    # One-hot encode the source sentences
    for seq_index, seq in enumerate(input_texts):
        for char_index, char in enumerate(seq):
            encoder_input[seq_index, char_index, input_dict[char]] = 1

    # One-hot encode the target sentences; decoder_output is decoder_input
    # shifted one step to the left, i.e. the teacher-forcing labels
    for seq_index, seq in enumerate(target_texts):
        for char_index, char in enumerate(seq):
            decoder_input[seq_index, char_index, target_dict[char]] = 1.0
            if char_index > 0:
                decoder_output[seq_index, char_index - 1, target_dict[char]] = 1.0

    # print(' '.join([input_dict_reverse[np.argmax(i)] for i in encoder_input[0] if max(i) != 0]))
    # print(' '.join([target_dict_reverse[np.argmax(i)] for i in decoder_output[0] if max(i) != 0]))
    # print(' '.join([target_dict_reverse[np.argmax(i)] for i in decoder_input[0] if max(i) != 0]))

    return input_texts, target_texts, input_dict, target_dict, target_dict_reverse, output_length, \
           input_feature_length, output_feature_length, encoder_input, decoder_input, decoder_output


if __name__ == '__main__':
    data_path = 'data/cmn.txt'
    load_data(data_path)
--------------------------------------------------------------------------------
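As a final sanity check, mirroring the commented-out prints in `load_data`: decoding the first encoded sample should reproduce the original sentence exactly. A short sketch, assuming `data/cmn.txt` is in place:

```python
import numpy as np
from utils import load_data

(input_texts, target_texts, input_dict, target_dict, target_dict_reverse,
 output_length, input_feature_length, output_feature_length,
 encoder_input, decoder_input, decoder_output) = load_data('data/cmn.txt', 10000)

input_dict_reverse = {index: char for char, index in input_dict.items()}

# All-zero rows are padding; every other row is a one-hot character.
decoded = ''.join(input_dict_reverse[row.argmax()]
                  for row in encoder_input[0] if row.max() > 0)
assert decoded == input_texts[0]
print(repr(decoded))
```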