├── assets
│   ├── model.png
│   ├── test.png
│   ├── decoder.png
│   ├── encoder.png
│   └── seq2seq.png
├── README.md
├── train.py
├── predict.py
├── model.py
└── utils.py
/assets/model.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yanqiangmiffy/seq2seq-nmt/HEAD/assets/model.png
--------------------------------------------------------------------------------
/assets/test.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yanqiangmiffy/seq2seq-nmt/HEAD/assets/test.png
--------------------------------------------------------------------------------
/assets/decoder.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yanqiangmiffy/seq2seq-nmt/HEAD/assets/decoder.png
--------------------------------------------------------------------------------
/assets/encoder.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yanqiangmiffy/seq2seq-nmt/HEAD/assets/encoder.png
--------------------------------------------------------------------------------
/assets/seq2seq.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yanqiangmiffy/seq2seq-nmt/HEAD/assets/seq2seq.png
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# seq2seq-nmt

A Keras implementation of a character-level seq2seq (LSTM encoder-decoder) model that translates English to Chinese, trained on the English-Chinese sentence pairs in `data/cmn.txt`.

## Model architecture

![model](assets/model.png)

## Inference model architecture

- encoder

![encoder](assets/encoder.png)

- decoder

![decoder](assets/decoder.png)

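At inference time the encoder runs once to produce its final states, and the decoder then generates one character per step, feeding each prediction back in as the next input until the stop marker `'\n'` appears. The sketch below condenses the loop implemented by `predict_chinese()` in `predict.py`; the name `greedy_decode` is illustrative, not part of the repo:

```python
import numpy as np

def greedy_decode(encoder_infer, decoder_infer, source,
                  target_dict, target_dict_reverse, max_len, n_output):
    state = encoder_infer.predict(source)     # final encoder states [h, c]
    seq = np.zeros((1, 1, n_output))
    seq[0, 0, target_dict['\t']] = 1          # '\t' is the start marker
    output = ''
    for _ in range(max_len):                  # max_len bounds the output length
        yhat, h, c = decoder_infer.predict([seq] + state)
        idx = int(np.argmax(yhat[0, -1, :]))  # greedy: take the most likely character
        char = target_dict_reverse[idx]
        if char == '\n':                      # stop marker reached
            break
        output += char
        state = [h, c]                        # current states seed the next step
        seq = np.zeros((1, 1, n_output))
        seq[0, 0, idx] = 1                    # feed the prediction back in
    return output
```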
## Test results
```text
Is it all there?
全都在那裡嗎?

Is it too salty?
还有多余的盐吗?

Is she Japanese?
她是日本人嗎?

Is this a river?
這是一條河嗎?

Isn't that mine?
那是我的吗?

It is up to you.
由你來決定。
```
```text
python predict.py --eng_sent "It's a nice day."

今天天氣很好。
```
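
Training uses the defaults hardcoded at the top of `train.py` (`data/cmn.txt`, 10,000 sentence pairs, 256 LSTM units, 200 epochs) and saves the training and inference models under `result/`:

```text
python train.py
```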

## References

- https://github.com/pjgao/seq2seq_keras/blob/master/seq2seq_keras.ipynb
- https://github.com/keras-team/keras/blob/master/examples/lstm_seq2seq.py
- https://blog.csdn.net/PIPIXIU/article/details/81016974
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
from model import Seq2Seq
from utils import load_data
# from keras.utils import plot_model
import os

# os.environ["PATH"] += os.pathsep + 'E:/Program Files (x86)/Graphviz2.38/bin'

# Hyperparameters
file_path = 'data/cmn.txt'
n_units = 256
batch_size = 64
epoch = 200
num_samples = 10000

# Load the data
input_texts, target_texts, input_dict, target_dict, target_dict_reverse, \
output_length, input_feature_length, output_feature_length, \
encoder_input, decoder_input, decoder_output = load_data(file_path, num_samples)

seq2seq = Seq2Seq(input_feature_length, output_feature_length, n_units)
model_train, encoder_infer, decoder_infer = seq2seq.create_model()


# Visualize the model structures (requires Graphviz; see the PATH line above)
# plot_model(to_file='assets/model.png', model=model_train, show_shapes=True)
# plot_model(to_file='assets/encoder.png', model=encoder_infer, show_shapes=True)
# plot_model(to_file='assets/decoder.png', model=decoder_infer, show_shapes=True)

model_train.compile(optimizer='rmsprop', loss='categorical_crossentropy')

model_train.summary()
encoder_infer.summary()
decoder_infer.summary()

# Train the model
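# Teacher forcing: decoder_input carries the target sequence ('\t' + text) and
# decoder_output the same sequence shifted one step ahead, so at each timestep
# the decoder is trained to predict the next character.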
model_train.fit([encoder_input, decoder_input], decoder_output,
                batch_size=batch_size, epochs=epoch, validation_split=0.2)

# Save the training model and the two inference models
os.makedirs("result", exist_ok=True)  # the output directory must exist before saving
model_train.save("result/model_train.h5")
encoder_infer.save("result/encoder_infer.h5")
decoder_infer.save("result/decoder_infer.h5")
--------------------------------------------------------------------------------
/predict.py:
--------------------------------------------------------------------------------
from keras.models import load_model
import numpy as np
from utils import load_data
import argparse

# Parameters (must match the values used in train.py)
file_path = 'data/cmn.txt'
num_samples = 10000

# Reload the data to rebuild the character dictionaries used during training
input_texts, target_texts, input_dict, target_dict, target_dict_reverse, \
output_length, input_feature_length, output_feature_length, \
encoder_input, decoder_input, decoder_output = load_data(file_path, num_samples)


def predict_chinese(source, encoder_inference, decoder_inference, n_steps, features):
    # First run the inference encoder to get the hidden states of the input sequence
    state = encoder_inference.predict(source)
    # The first character, '\t', is the start-of-sequence marker
    predict_seq = np.zeros((1, 1, features))
    predict_seq[0, 0, target_dict['\t']] = 1

    output = ''
    # Decode from the encoder's hidden states one step at a time:
    # each iteration feeds the previously predicted character back in as input,
    # until the stop marker is predicted
    for i in range(n_steps):  # n_steps is the maximum sentence length
        # Feed the decoder the previous step's h, c states and the last predicted character
        yhat, h, c = decoder_inference.predict([predict_seq] + state)
        # Note: yhat is the output of the Dense layer, so it differs from h
        char_index = np.argmax(yhat[0, -1, :])
        char = target_dict_reverse[char_index]
        output += char
        state = [h, c]  # the current states become the next step's initial states
        predict_seq = np.zeros((1, 1, features))
        predict_seq[0, 0, char_index] = 1
        if char == '\n':  # stop once the end-of-sequence marker is predicted
            break
    return output


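# Load the inference models saved by train.py (run train.py first so that
# result/encoder_infer.h5 and result/decoder_infer.h5 exist)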
encoder_infer = load_model("result/encoder_infer.h5")
decoder_infer = load_model("result/decoder_infer.h5")

# Batch test against the training data:
# for i in range(1000, 1100):
#     test = encoder_input[i:i+1, :, :]  # i:i+1 keeps the array three-dimensional
#     out = predict_chinese(test, encoder_infer, decoder_infer, output_length, output_feature_length)
#     print(input_texts[i])
#     print(out)


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description="Please input an English sentence!")
    parser.add_argument('--eng_sent', '-e', required=True, help="an English sentence")
    args = parser.parse_args()

    seq = args.eng_sent
    input_length = max(len(i) for i in input_texts)
    encoder_input = np.zeros((1, input_length, input_feature_length))
    # One-hot encode the sentence; truncate to input_length and skip characters
    # never seen during training (they would otherwise raise a KeyError)
    for char_index, char in enumerate(seq[:input_length]):
        if char in input_dict:
            encoder_input[0, char_index, input_dict[char]] = 1
    out = predict_chinese(encoder_input, encoder_infer, decoder_infer, output_length, output_feature_length)
    print(out)
--------------------------------------------------------------------------------
/model.py:
--------------------------------------------------------------------------------
from keras.layers import Input, LSTM, Dense
from keras.models import Model


class Seq2Seq(object):

    def __init__(self, n_input, n_output, n_units):
        self.n_input = n_input
        self.n_output = n_output
        self.n_units = n_units

    def create_model(self):
        # Training phase
        # encoder
        encoder_input = Input(shape=(None, self.n_input))
        # The encoder input dimension n_input is the dimension of each timestep's
        # input x_t, here the number of English characters used for one-hot encoding
        encoder = LSTM(self.n_units, return_state=True)
        # n_units is the number of neurons in each LSTM gate; return_state must be
        # True for the layer to return the final states h, c
        _, encoder_h, encoder_c = encoder(encoder_input)
        encoder_state = [encoder_h, encoder_c]
        # Keep the encoder's final states as the decoder's initial states

        # decoder
        decoder_input = Input(shape=(None, self.n_output))
        # The decoder input dimension is the number of Chinese characters
        decoder = LSTM(self.n_units, return_sequences=True, return_state=True)
        # Training compares the decoder's output sequence against the targets,
        # so return_sequences must also be True
        decoder_output, _, _ = decoder(decoder_input, initial_state=encoder_state)
        # During training only the decoder's output sequence is needed,
        # not its final states h, c
        decoder_dense = Dense(self.n_output, activation='softmax')
        decoder_output = decoder_dense(decoder_output)
        # The output sequence passes through a fully connected layer to get the result

        # The resulting training model
        model = Model([encoder_input, decoder_input], decoder_output)
        # The first argument is the training model's inputs (both the encoder's and
        # the decoder's); the second is the model's output (the decoder's output)

        # Inference phase, used for prediction
        # Inference model - encoder
        encoder_infer = Model(encoder_input, encoder_state)

        # Inference model - decoder
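        # Note: the inference decoder reuses the same `decoder` LSTM and
        # `decoder_dense` layers as the training model, so the weights learned
        # during training are shared with the inference graph automatically.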
        decoder_state_input_h = Input(shape=(self.n_units,))
        decoder_state_input_c = Input(shape=(self.n_units,))
        decoder_state_input = [decoder_state_input_h, decoder_state_input_c]  # the previous timestep's states h, c

        decoder_infer_output, decoder_infer_state_h, decoder_infer_state_c = decoder(
            decoder_input, initial_state=decoder_state_input)
        decoder_infer_state = [decoder_infer_state_h, decoder_infer_state_c]  # the states produced at the current timestep
        decoder_infer_output = decoder_dense(decoder_infer_output)  # the output at the current timestep
        decoder_infer = Model([decoder_input] + decoder_state_input,
                              [decoder_infer_output] + decoder_infer_state)

        return model, encoder_infer, decoder_infer
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
import pandas as pd
import numpy as np


def load_data(filepath, num_samples=10000):
    data = pd.read_table(filepath, header=None).iloc[:num_samples, :]
    data.columns = ['inputs', 'targets']
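    # Wrap each target sentence in '\t' (start-of-sequence) and '\n'
    # (end-of-sequence) markers; the decoder relies on these at inference time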
    data['targets'] = data['targets'].apply(lambda x: '\t' + x + '\n')

    input_texts = data['inputs'].values.tolist()
    target_texts = data['targets'].values.tolist()

    # Character vocabularies: summing the unique sentences concatenates them,
    # and the set of that string yields every distinct character
    input_characters = sorted(list(set(data.inputs.unique().sum())))
    targets_characters = sorted(list(set(data.targets.unique().sum())))
    # print(targets_characters)

    input_length = max(len(i) for i in input_texts)
    output_length = max(len(i) for i in target_texts)
    input_feature_length = len(input_characters)
    output_feature_length = len(targets_characters)

    # One-hot tensors; size by len(input_texts) rather than num_samples in case
    # the file holds fewer than num_samples lines
    n = len(input_texts)
    encoder_input = np.zeros((n, input_length, input_feature_length))
    decoder_input = np.zeros((n, output_length, output_feature_length))
    decoder_output = np.zeros((n, output_length, output_feature_length))

    # Character-to-index lookup tables and their inverses
    input_dict = {char: index for index, char in enumerate(input_characters)}
    input_dict_reverse = {index: char for index, char in enumerate(input_characters)}

    target_dict = {char: index for index, char in enumerate(targets_characters)}
    target_dict_reverse = {index: char for index, char in enumerate(targets_characters)}

    # One-hot encode the source sentences
    for seq_index, seq in enumerate(input_texts):
        for char_index, char in enumerate(seq):
            encoder_input[seq_index, char_index, input_dict[char]] = 1.0

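    # Teacher-forcing targets: decoder_output is decoder_input shifted one step
    # ahead, so position t-1 of the output holds the character at position t of
    # the input and the model learns to predict the next character.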
    for seq_index, seq in enumerate(target_texts):
        for char_index, char in enumerate(seq):
            decoder_input[seq_index, char_index, target_dict[char]] = 1.0
            if char_index > 0:
                decoder_output[seq_index, char_index - 1, target_dict[char]] = 1.0

    # print(' '.join([input_dict_reverse[np.argmax(i)] for i in encoder_input[0] if max(i) != 0]))
    # print(' '.join([target_dict_reverse[np.argmax(i)] for i in decoder_output[0] if max(i) != 0]))
    # print(' '.join([target_dict_reverse[np.argmax(i)] for i in decoder_input[0] if max(i) != 0]))

    return input_texts, target_texts, input_dict, target_dict, target_dict_reverse, output_length, \
        input_feature_length, output_feature_length, encoder_input, decoder_input, decoder_output

if __name__ == '__main__':
    data_path = 'data/cmn.txt'
    load_data(data_path)
--------------------------------------------------------------------------------