├── README.md
├── data
│   └── translation_fra_en.txt
├── lstm_predict.py
├── lstm_train.py
└── model
    ├── seq2seq.h5
    ├── source_words
    └── target_words

/README.md:
--------------------------------------------------------------------------------
# Seq2SeqTranslation
Translation model based on a sequence-to-sequence (seq2seq) model.
--------------------------------------------------------------------------------
/lstm_predict.py:
--------------------------------------------------------------------------------
#!/usr/bin/env python3
# coding: utf-8
# File: lstm_predict.py
# Author: lhy
# Date: 18-5-22

from keras.models import Model, load_model
from keras.layers import Input
import numpy as np

class Translator:
    '''Initialization'''
    def __init__(self):
        self.latent_dim = 256
        self.source_path = 'model/source_words'
        self.target_path = 'model/target_words'
        self.modelpath = 'model/seq2seq.h5'
        self.max_encoder_seq_length = 16
        self.max_decoder_seq_length = 59
        self.input_characters = open(self.source_path, encoding='utf-8').read().split('*')
        self.target_characters = open(self.target_path, encoding='utf-8').read().split('*')
        self.input_token_index = dict([(char, i) for i, char in enumerate(self.input_characters)])
        self.target_token_index = dict([(char, i) for i, char in enumerate(self.target_characters)])
        self.reverse_input_char_index = dict((i, char) for char, i in self.input_token_index.items())
        self.reverse_target_char_index = dict((i, char) for char, i in self.target_token_index.items())

    '''Load the trained model and rebuild the encoder/decoder inference models; no fitting happens here'''
    def load_model(self):
        model = load_model(self.modelpath)
        # Encoder inference model: reuse the trained encoder LSTM and expose its final states
        encoder_inputs = model.input[0]  # input_1
        encoder_outputs, state_h_enc, state_c_enc = model.layers[2].output  # lstm_1
        encoder_states = [state_h_enc, state_c_enc]
        encoder_model = Model(encoder_inputs, encoder_states)
        # Decoder inference model: reuse the trained decoder LSTM, fed one step at a time
        decoder_inputs = model.input[1]  # input_2
        decoder_state_input_h = Input(shape=(self.latent_dim,), name='input_3')
        decoder_state_input_c = Input(shape=(self.latent_dim,), name='input_4')
        decoder_states_inputs = [decoder_state_input_h, decoder_state_input_c]
        decoder_lstm = model.layers[3]
        decoder_outputs, state_h_dec, state_c_dec = decoder_lstm(decoder_inputs, initial_state=decoder_states_inputs)
        decoder_states = [state_h_dec, state_c_dec]
        decoder_dense = model.layers[4]
        decoder_outputs = decoder_dense(decoder_outputs)
        decoder_model = Model([decoder_inputs] + decoder_states_inputs, [decoder_outputs] + decoder_states)
        return encoder_model, decoder_model

    '''Greedy decoding: generate the translation one character at a time'''
    def decode_sequence(self, input_seq):
        encoder_model, decoder_model = self.load_model()
        # Encode the input sentence into the initial decoder states
        states_value = encoder_model.predict(input_seq)
        # Seed the target sequence with the '\t' start-of-sequence character
        target_seq = np.zeros((1, 1, len(self.target_characters)))
        target_seq[0, 0, self.target_token_index['\t']] = 1.
        stop_condition = False
        decoded_sentence = ''
        while not stop_condition:
            output_tokens, h, c = decoder_model.predict([target_seq] + states_value)
            # Greedily pick the most probable next character
            sampled_token_index = np.argmax(output_tokens[0, -1, :])
            sampled_char = self.reverse_target_char_index[sampled_token_index]
            decoded_sentence += sampled_char
            # Stop on the '\n' end-of-sequence character or when the output gets too long
            if sampled_char == '\n' or len(decoded_sentence) > self.max_decoder_seq_length:
                stop_condition = True
            # Feed the sampled character back in as the next decoder input
            target_seq = np.zeros((1, 1, len(self.target_characters)))
            target_seq[0, 0, sampled_token_index] = 1.
            states_value = [h, c]

        return decoded_sentence

    '''One-hot vector representation of a new sentence'''
    def encode_sentence(self, input_text):
        encode_input = np.zeros((1, self.max_encoder_seq_length, len(self.input_characters)), dtype='float32')
        for index, char in enumerate(input_text):
            encode_input[0, index, self.input_token_index[char]] = 1.
        return encode_input

# Test
def test():
    en = 'thank you'
    translator = Translator()
    input_seq = translator.encode_sentence(en)
    decoded_sentence = translator.decode_sequence(input_seq)
    print('-')
    print('Input sentence:', en)
    print('Decoded sentence:', decoded_sentence)

if __name__ == '__main__':
    test()
--------------------------------------------------------------------------------
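A minimal usage sketch (not part of the original repository) showing how the prediction class above can be called on other sentences. It assumes lstm_train.py has already been run from the repository root, so that model/seq2seq.h5, model/source_words and model/target_words exist, that every input character occurs in model/source_words, and that the sentence is at most max_encoder_seq_length (16) characters long; encode_sentence would otherwise raise a KeyError or IndexError.

# Hypothetical usage sketch -- not part of the original repository.
from lstm_predict import Translator

translator = Translator()
for sentence in ['thank you', 'i am happy']:
    input_seq = translator.encode_sentence(sentence)  # one-hot tensor of shape (1, 16, num_source_chars)
    print(sentence, '->', translator.decode_sequence(input_seq).strip())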
/lstm_train.py:
--------------------------------------------------------------------------------
#!/usr/bin/env python3
# coding: utf-8
# File: lstm_train.py
# Author: lhy
# Date: 18-5-22

from keras.models import Model
from keras.layers import Input, LSTM, Dense
import numpy as np

class Translator:
    def __init__(self):
        self.batch_size = 64  # Batch size for training.
        self.epochs = 10  # Number of epochs to train for.
        self.latent_dim = 256  # Latent dimensionality of the encoding space.
        self.num_samples = 10000  # Number of samples to train on.
        self.data_path = 'data/translation_fra_en.txt'
        self.inputpath = 'model/source_words'
        self.outputpath = 'model/target_words'
        self.modelpath = 'model/seq2seq.h5'
        self.num_encoder_tokens, self.num_decoder_tokens, self.encoder_input_data, \
        self.decoder_input_data, self.decoder_target_data = self.build_data()

    '''Build the training data'''
    def build_data(self):
        input_texts = []
        target_texts = []
        input_characters = set()
        target_characters = set()

        with open(self.data_path, 'r', encoding='utf-8') as f:
            lines = f.read().split('\n')
        # Only the first 10000 sentence pairs are used for this experiment
        for line in lines[: min(self.num_samples, len(lines) - 1)]:
            input_text, target_text = line.split('\t')
            # '\t' marks the start of a target sequence, '\n' marks its end
            target_text = '\t' + target_text + '\n'
            input_texts.append(input_text)
            target_texts.append(target_text)
            for char in input_text:
                if char not in input_characters:
                    input_characters.add(char)
            for char in target_text:
                if char not in target_characters:
                    target_characters.add(char)

        input_characters = sorted(list(input_characters))
        target_characters = sorted(list(target_characters))

        # Persist both vocabularies as '*'-separated files for the prediction script
        with open(self.inputpath, 'w', encoding='utf-8') as f:
            f.write("*".join(input_characters))

        with open(self.outputpath, 'w', encoding='utf-8') as f:
            f.write("*".join(target_characters))

        num_encoder_tokens = len(input_characters)
        num_decoder_tokens = len(target_characters)

        max_encoder_seq_length = max([len(txt) for txt in input_texts])
        max_decoder_seq_length = max([len(txt) for txt in target_texts])

        input_token_index = dict([(char, i) for i, char in enumerate(input_characters)])
        target_token_index = dict([(char, i) for i, char in enumerate(target_characters)])

        encoder_input_data = np.zeros((len(input_texts), max_encoder_seq_length, num_encoder_tokens), dtype='float32')
        decoder_input_data = np.zeros((len(input_texts), max_decoder_seq_length, num_decoder_tokens), dtype='float32')
        decoder_target_data = np.zeros((len(input_texts), max_decoder_seq_length, num_decoder_tokens), dtype='float32')

        for i, (input_text, target_text) in enumerate(zip(input_texts, target_texts)):
            for t, char in enumerate(input_text):
                encoder_input_data[i, t, input_token_index[char]] = 1.

            for t, char in enumerate(target_text):
                decoder_input_data[i, t, target_token_index[char]] = 1.
                if t > 0:
                    # decoder_target_data is ahead of decoder_input_data by one timestep (teacher forcing)
                    decoder_target_data[i, t - 1, target_token_index[char]] = 1.

        return num_encoder_tokens, num_decoder_tokens, encoder_input_data, decoder_input_data, decoder_target_data

    '''Build the model'''
    def build_model(self):
        # Encoder input layer
        encoder_inputs = Input(shape=(None, self.num_encoder_tokens))
        # Encoder LSTM layer
        encoder = LSTM(self.latent_dim, return_state=True)
        # Encode the source sequence; only the final hidden and cell states are kept
        encoder_outputs, state_h, state_c = encoder(encoder_inputs)
        encoder_states = [state_h, state_c]
        # Set up the decoder, using `encoder_states` as its initial state
        decoder_inputs = Input(shape=(None, self.num_decoder_tokens))
        decoder_lstm = LSTM(self.latent_dim, return_sequences=True, return_state=True)
        decoder_outputs, _, _ = decoder_lstm(decoder_inputs, initial_state=encoder_states)
        decoder_dense = Dense(self.num_decoder_tokens, activation='softmax')
        decoder_outputs = decoder_dense(decoder_outputs)

        model = Model([encoder_inputs, decoder_inputs], decoder_outputs)

        model.compile(optimizer='rmsprop', loss='categorical_crossentropy', metrics=['accuracy'])
        return model

    '''Train the model'''
    def train_model(self):
        model = self.build_model()
        model.fit([self.encoder_input_data, self.decoder_input_data], self.decoder_target_data,
                  batch_size=self.batch_size,
                  epochs=self.epochs,
                  validation_split=0.2,
                  )
        model.save(self.modelpath)

if __name__ == '__main__':
    translator = Translator()
    translator.train_model()
--------------------------------------------------------------------------------
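build_data above expects data/translation_fra_en.txt to contain one sentence pair per line, with the English source and the French target separated by a tab. A small sanity-check sketch (hypothetical, not in the original script, and assuming that data file is present) confirming the teacher-forcing layout it builds: the decoder targets are the decoder inputs shifted left by one timestep.

# Hypothetical check of the teacher-forcing offset -- not part of the original script.
import numpy as np
from lstm_train import Translator

translator = Translator()  # __init__ runs build_data() on data/translation_fra_en.txt
assert np.array_equal(translator.decoder_target_data[0, :-1],
                      translator.decoder_input_data[0, 1:])
print('decoder_target_data equals decoder_input_data shifted by one timestep')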
/model/seq2seq.h5:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/liuhuanyong/Seq2SeqTranslation/9f08aefd03f0cd29db6592ba73e062ea54f6dd0e/model/seq2seq.h5
--------------------------------------------------------------------------------
/model/source_words:
--------------------------------------------------------------------------------
 *!*$*%*&*'*,*-*.*0*1*2*3*4*5*6*7*8*9*:*?*A*B*C*D*E*F*G*H*I*J*K*L*M*N*O*P*Q*R*S*T*U*V*W*Y*a*b*c*d*e*f*g*h*i*j*k*l*m*n*o*p*q*r*s*t*u*v*w*x*y*z
--------------------------------------------------------------------------------
/model/target_words:
--------------------------------------------------------------------------------
	*
* *!*$*%*&*'*(*)*,*-*.*0*1*3*5*6*8*9*:*?*A*B*C*D*E*F*G*H*I*J*K*L*M*N*O*P*Q*R*S*T*U*V*Y*a*b*c*d*e*f*g*h*i*j*k*l*m*n*o*p*q*r*s*t*u*v*w*x*y*z* *«*»*À*Ç*É*Ê*à*â*ç*è*é*ê*ë*î*ï*ô*ù*û*œ* *’* 
--------------------------------------------------------------------------------
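The two vocabulary files above are written by lstm_train.py: the sorted character sets of the training data, joined with '*'. The tab and newline used as start- and end-of-sequence markers are themselves vocabulary entries, which is why model/target_words spans two lines. A short sketch (mirroring the loading code in lstm_predict.py, with utf-8 encoding assumed) of how they are read back into a character-to-index lookup:

# Sketch of the vocabulary round trip used by lstm_predict.Translator.__init__.
target_characters = open('model/target_words', encoding='utf-8').read().split('*')
target_token_index = {char: i for i, char in enumerate(target_characters)}
print(len(target_characters), 'target characters, starting with', target_characters[:3])  # ['\t', '\n', ' ']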