├── .gitignore
├── README.md
├── config
│   └── global_config.py
├── datasets
│   ├── data_loader.py
│   └── eng-fra.txt
├── models
│   └── models.py
├── train.py
└── utils
    └── model_utils.py

/.gitignore:
--------------------------------------------------------------------------------
.idea/
__pycache__/
checkpoints/

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# PyTorch Marvelous ChatBot

[Update]

It's 2019 now, and the previous model can no longer keep up with the state of the art. So we are moving towards the future: **a transformer-based chatbot**, which is much more accurate and flexible, and leaves a lot more room for imagination as a chatbot!!! More importantly, we have open-sourced the whole code here: http://manaai.cn/aicodes_detail3.html?id=36
Be sure to check it out if you are interested in chatbots and NLP!! **It's built with the newest TensorFlow 2.0 API!!**


> Aim to build a Marvelous ChatBot


# Synopsis

This is an open-source **ChatBot** project. I will keep updating this repo personally, with the aim of building an intelligent ChatBot, the next version of Jarvis.

This repo will keep building a **Marvelous ChatBot** based on PyTorch;
stars and PRs are welcome.

# Already Done

So far this repo has done the following:

* starting from the official tutorial, it moves on to develop a seq2seq chatbot and QA system;
* the whole project has been restructured, splitting the messy code into `data`, `model` and `train logic`;
* the model can be saved locally and reloaded from a previously saved directory, which the official tutorial lacks;
* just replace the dataset and you can train on your own data!
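
A minimal sketch of preparing your own data, assuming you keep the loader's file layout: `read_lang` opens `./datasets/<lang1>-<lang2>.txt` and splits every line on a tab into an (input, target) pair. The helpers below (`normalize_string`, `keep_pair`, `write_pair_file`) are hypothetical and not part of this repo. `normalize_string` copies the steps of `PairDataLoader.normalize_string`, and `keep_pair` is an assumed length-only stand-in for `filter_pair`, which otherwise drops every pair whose input does not start with one of the English `eng_prefixes`.

```python
import re
import unicodedata

MAX_LENGTH = 10  # mirrors MAX_LENGTH in config/global_config.py


def normalize_string(s):
    # Same steps as PairDataLoader.normalize_string: strip accents, lowercase,
    # separate . ! ? from words, and turn every other character run into a space
    s = ''.join(c for c in unicodedata.normalize('NFD', s)
                if unicodedata.category(c) != 'Mn')
    s = s.lower().strip()
    s = re.sub(r"([.!?])", r" \1", s)
    s = re.sub(r"[^a-zA-Z.!?]+", r" ", s)
    return s


def keep_pair(pair, max_length=MAX_LENGTH):
    # Hypothetical length-only filter: unlike filter_pair, it does not
    # require the input to start with one of the English eng_prefixes
    return all(len(side.split(' ')) < max_length for side in pair)


def write_pair_file(raw_pairs, path):
    # Write one "input<TAB>target" pair per line in UTF-8, the layout
    # read_lang() splits on; returns the number of pairs written
    kept = 0
    with open(path, 'w', encoding='utf-8') as f:
        for src, tgt in raw_pairs:
            pair = [normalize_string(src), normalize_string(tgt)]
            if keep_pair(pair):
                f.write(pair[0] + '\t' + pair[1] + '\n')
                kept += 1
    return kept
```

For example, after `write_pair_file(my_pairs, 'datasets/chat-chat.txt')` (a made-up file name), you would point the loader at it by changing the hard-coded `self._prepare_data('eng', 'fra')` to `self._prepare_data('chat', 'chat')` and swapping in a filter like `keep_pair`; otherwise the English-prefix filter may discard most of your pairs. `read_lang` normalizes each line again when it loads the file.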

Last but not least, this project may be maintained here or moved to another repo in the future, but we will
continue to implement a practical seq2seq-based project for building anything you want: **Translation Machine**,
**ChatBot**, **QA System**... anything you want.


# Requirements

```
PyTorch
python3
Ubuntu (any version)
CPU or GPU (both work)
```

# Usage

Before diving into this repo, you may want to glance at the overall structure, which has these parts:

* `config`: contains the config params that are global to this project; change a global param here;
* `datasets`: contains the data and the data loader; to use your own dataset, implement your own data loader, which needs only small changes to this one;
* `models`: contains the seq2seq model definitions;
* `utils`: this folder is very helpful; it contains code that takes care of annoying chores such as saving the model, loading a previous model, or reporting training time (Ctrl+C, i.e. KeyboardInterrupt, is handled in `train.py`).

Training the model is also straightforward, just type:
```
python3 train.py
```

# Contribute

PRs are welcome!!!! Let's build a ChatBot together!

# Contact

If you have any questions, you can find me on WeChat: `jintianiloveu`

--------------------------------------------------------------------------------
/config/global_config.py:
--------------------------------------------------------------------------------
# -*- coding: utf-8 -*-
# file: global_config.py
# author: JinTian
# time: 10/05/2017 4:51 PM
# Copyright 2017 JinTian. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ------------------------------------------------------------------------
import torch
import os


MODEL_PREFIX = 'seq2seq_translate'  # checkpoints are saved as <MODEL_PREFIX>-<epoch>.pth
CHECKPOINT_DIR = './checkpoints'
MAX_LENGTH = 10  # pairs are kept only if both sides have fewer words; also the attention length


use_cuda = torch.cuda.is_available()
teacher_forcing_ratio = 0.5  # probability of feeding the target token (not the prediction) to the decoder

--------------------------------------------------------------------------------
/datasets/data_loader.py:
--------------------------------------------------------------------------------
# -*- coding: utf-8 -*-
# file: data_loader.py
# author: JinTian
# time: 10/05/2017 6:27 PM
# Copyright 2017 JinTian. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ------------------------------------------------------------------------
"""
this file loads pair data for the seq2seq model.
each line of the raw file holds one tab-separated pair:
    sequenceA<TAB>sequenceB
    ....

"""
import random
import re
import unicodedata
from io import open

import torch
from torch.autograd import Variable

from config.global_config import *


class PairDataLoader(object):
    """
    this class loads the raw file and generates pair data.
    """

    def __init__(self):

        self.SOS_token = 0
        self.EOS_token = 1
        self.eng_prefixes = (
            "i am ", "i m ",
            "he is", "he s ",
            "she is", "she s",
            "you are", "you re ",
            "we are", "we re ",
            "they are", "they re "
        )

        self._prepare_data('eng', 'fra')

    class Lang(object):

        def __init__(self, name):
            self.name = name
            self.word2index = {}
            self.word2count = {}
            self.index2word = {0: "SOS", 1: "EOS"}
            self.n_words = 2  # Count SOS and EOS

        def add_sentence(self, sentence):
            for word in sentence.split(' '):
                self.add_word(word)

        def add_word(self, word):
            if word not in self.word2index:
                self.word2index[word] = self.n_words
                self.word2count[word] = 1
                self.index2word[self.n_words] = word
                self.n_words += 1
            else:
                self.word2count[word] += 1

    def filter_pair(self, p):
        # keep short pairs whose input sentence starts with one of eng_prefixes
        return len(p[0].split(' ')) < MAX_LENGTH and \
            len(p[1].split(' ')) < MAX_LENGTH and \
            p[0].startswith(self.eng_prefixes)

    def filter_pairs(self, pairs):
        return [pair for pair in pairs if self.filter_pair(pair)]

    @staticmethod
    def unicode_to_ascii(s):
        return ''.join(
            c for c in unicodedata.normalize('NFD', s)
            if unicodedata.category(c) != 'Mn'
        )

    def normalize_string(self, s):
        s = self.unicode_to_ascii(s).lower().strip()
        s = re.sub(r"([.!?])", r" \1", s)
        s = re.sub(r"[^a-zA-Z.!?]+", r" ", s)
        return s

    def read_lang(self, lang1, lang2, reverse=False):
        print("Reading lines...")
        # the with-block closes the file once it has been read
        with open('./datasets/%s-%s.txt' % (lang1, lang2), encoding='utf-8') as f:
            lines = f.read().strip().split('\n')
        pairs = [[self.normalize_string(s) for s in l.split('\t')] for l in lines]
        if reverse:
            pairs = [list(reversed(p)) for p in pairs]
            input_lang = self.Lang(lang2)
            output_lang = self.Lang(lang1)
        else:
            input_lang = self.Lang(lang1)
            output_lang = self.Lang(lang2)

        return input_lang, output_lang, pairs

    @staticmethod
    def indexes_from_sentence(lang, sentence):
        return [lang.word2index[word] for word in sentence.split(' ')]

    def variable_from_sentence(self, lang, sentence):
        indexes = self.indexes_from_sentence(lang, sentence)
        indexes.append(self.EOS_token)
        result = Variable(torch.LongTensor(indexes).view(-1, 1))
        if use_cuda:
            return result.cuda()
        else:
            return result

    def _prepare_data(self, lang1, lang2, reverse=False):
        input_lang, output_lang, pairs = self.read_lang(lang1, lang2, reverse)
        print("Read %s sentence pairs" % len(pairs))
        self.pairs = self.filter_pairs(pairs)
        print("Trimmed to %s sentence pairs" % len(self.pairs))
        print("Counting words...")
        for pair in self.pairs:
            input_lang.add_sentence(pair[0])
            output_lang.add_sentence(pair[1])
        self.input_lang = input_lang
        self.output_lang = output_lang
        print("Counted words:")
        print(input_lang.name, input_lang.n_words)
        print(output_lang.name, output_lang.n_words)

    def get_pair_variable(self):
        # draw ONE random pair so that the input and the target actually match
        pair = random.choice(self.pairs)
        input_variable = self.variable_from_sentence(self.input_lang, pair[0])
        target_variable = self.variable_from_sentence(self.output_lang, pair[1])
        return input_variable, target_variable

--------------------------------------------------------------------------------
/models/models.py:
--------------------------------------------------------------------------------
# -*- coding: utf-8 -*-
# file: models.py
# author: JinTian
# time: 10/05/2017 6:09 PM
# Copyright 2017 JinTian. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ------------------------------------------------------------------------
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable

from config.global_config import *


class EncoderRNN(nn.Module):
    def __init__(self, input_size, hidden_size, n_layers=1):
        super(EncoderRNN, self).__init__()
        self.n_layers = n_layers
        self.hidden_size = hidden_size

        self.embedding = nn.Embedding(input_size, hidden_size)
        self.gru = nn.GRU(hidden_size, hidden_size)

    def forward(self, inputs, hidden):
        embedded = self.embedding(inputs).view(1, 1, -1)
        output = embedded
        # n_layers > 1 re-applies the same GRU, as in the official tutorial
        for i in range(self.n_layers):
            output, hidden = self.gru(output, hidden)
        return output, hidden

    def init_hidden(self):
        result = Variable(torch.zeros(1, 1, self.hidden_size))
        if use_cuda:
            return result.cuda()
        else:
            return result


class DecoderRNN(nn.Module):
    def __init__(self, hidden_size, output_size, n_layers=1):
        super(DecoderRNN, self).__init__()
        self.n_layers = n_layers
        self.hidden_size = hidden_size

        self.embedding = nn.Embedding(output_size, hidden_size)
        self.gru = nn.GRU(hidden_size, hidden_size)
        self.out = nn.Linear(hidden_size, output_size)
        # explicit dim avoids the implicit-dimension deprecation warning
        self.softmax = nn.LogSoftmax(dim=1)

    def forward(self, inputs, hidden):
        output = self.embedding(inputs).view(1, 1, -1)
        for i in range(self.n_layers):
            output = F.relu(output)
            output, hidden = self.gru(output, hidden)
        output = self.softmax(self.out(output[0]))
        return output, hidden

    def init_hidden(self):
        result = Variable(torch.zeros(1, 1, self.hidden_size))
        if use_cuda:
            return result.cuda()
        else:
            return result


class AttnDecoderRNN(nn.Module):
    def __init__(self, hidden_size, output_size, n_layers=1, dropout_p=0.1, max_length=MAX_LENGTH):
        super(AttnDecoderRNN, self).__init__()
        self.hidden_size = hidden_size
        self.output_size = output_size
        self.n_layers = n_layers
        self.dropout_p = dropout_p
        self.max_length = max_length

        self.embedding = nn.Embedding(self.output_size, self.hidden_size)
        self.attn = nn.Linear(self.hidden_size * 2, self.max_length)
        self.attn_combine = nn.Linear(self.hidden_size * 2, self.hidden_size)
        self.dropout = nn.Dropout(self.dropout_p)
        self.gru = nn.GRU(self.hidden_size, self.hidden_size)
        self.out = nn.Linear(self.hidden_size, self.output_size)

    def forward(self, inputs, hidden, encoder_output, encoder_outputs):
        # encoder_output is unused; encoder_outputs holds all max_length encoder states
        embedded = self.embedding(inputs).view(1, 1, -1)
        embedded = self.dropout(embedded)

        # attention weights over the max_length encoder positions
        attn_weights = F.softmax(
            self.attn(torch.cat((embedded[0], hidden[0]), 1)), dim=1)
        attn_applied = torch.bmm(attn_weights.unsqueeze(0),
                                 encoder_outputs.unsqueeze(0))

        output = torch.cat((embedded[0], attn_applied[0]), 1)
        output = self.attn_combine(output).unsqueeze(0)

        for i in range(self.n_layers):
            output = F.relu(output)
            output, hidden = self.gru(output, hidden)

        output = F.log_softmax(self.out(output[0]), dim=1)
        return output, hidden, attn_weights

    def init_hidden(self):
        result = Variable(torch.zeros(1, 1, self.hidden_size))
        if use_cuda:
            return result.cuda()
        else:
            return result

--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
import random
import time

import torch
import torch.nn as nn
import torch.optim as optim
from torch.autograd import Variable

from models.models import *
from utils.model_utils import *
from datasets.data_loader import PairDataLoader


def train_model(data_loader, input_variable, target_variable, encoder, decoder, encoder_optimizer, decoder_optimizer,
                criterion, max_length=MAX_LENGTH):
    encoder_hidden = encoder.init_hidden()

    encoder_optimizer.zero_grad()
    decoder_optimizer.zero_grad()

    input_length = input_variable.size()[0]
    target_length = target_variable.size()[0]

    encoder_outputs = Variable(torch.zeros(max_length, encoder.hidden_size))
    encoder_outputs = encoder_outputs.cuda() if use_cuda else encoder_outputs

    loss = 0

    try:
        for ei in range(input_length):
            encoder_output, encoder_hidden = encoder(
                input_variable[ei], encoder_hidden)
            encoder_outputs[ei] = encoder_output[0][0]
    except KeyboardInterrupt:
        return  # interrupted: the caller receives None

    decoder_input = Variable(torch.LongTensor([[data_loader.SOS_token]]))
    decoder_input = decoder_input.cuda() if use_cuda else decoder_input

    decoder_hidden = encoder_hidden

    use_teacher_forcing = random.random() < teacher_forcing_ratio

    if use_teacher_forcing:
        # Teacher forcing: Feed the target as the next input
        try:
            for di in range(target_length):
                decoder_output, decoder_hidden, decoder_attention = decoder(
                    decoder_input, decoder_hidden, encoder_output, encoder_outputs)
                # decoder_output is (1, vocab) and target_variable[di] is (1,)
                loss += criterion(decoder_output, target_variable[di])
                decoder_input = target_variable[di]  # Teacher forcing
        except KeyboardInterrupt:
            return

    else:
        # Without teacher forcing: use its own predictions as the next input
        try:
            for di in range(target_length):
                decoder_output, decoder_hidden, decoder_attention = decoder(
                    decoder_input, decoder_hidden, encoder_output, encoder_outputs)
                topv, topi = decoder_output.data.topk(1)
                ni = topi[0][0].item()  # plain Python int of the predicted token

                decoder_input = Variable(torch.LongTensor([[ni]]))
                decoder_input = decoder_input.cuda() if use_cuda else decoder_input

                loss += criterion(decoder_output, target_variable[di])
                if ni == data_loader.EOS_token:
                    break
        except KeyboardInterrupt:
            return

    loss.backward()

    encoder_optimizer.step()
    decoder_optimizer.step()

    # loss.item() replaces loss.data[0], which is an error on 0-dim tensors in current PyTorch
    return loss.item() / target_length


def train(data_loader, encoder, decoder, n_epochs, print_every=100, save_every=1000, evaluate_every=100,
          learning_rate=0.01):
    start = time.time()
    print_loss_total = 0  # Reset every print_every

    encoder_optimizer = optim.SGD(encoder.parameters(), lr=learning_rate)
    decoder_optimizer = optim.SGD(decoder.parameters(), lr=learning_rate)
    criterion = nn.NLLLoss()

    encoder, decoder, start_epoch = load_previous_model(encoder, decoder, CHECKPOINT_DIR, MODEL_PREFIX)

    for epoch in range(start_epoch, n_epochs + 1):

        input_variable, target_variable = data_loader.get_pair_variable()

        try:
            loss = train_model(data_loader, input_variable, target_variable, encoder,
                               decoder, encoder_optimizer, decoder_optimizer, criterion)
        except KeyboardInterrupt:
            loss = None
        if loss is None:
            # train_model returns None on Ctrl+C: save a checkpoint and stop
            print('interrupted, saving checkpoint at epoch %d' % epoch)
            save_model(encoder, decoder, CHECKPOINT_DIR, MODEL_PREFIX, epoch)
            return
        print_loss_total += loss

        if epoch % print_every == 0:
            print_loss_avg = print_loss_total / print_every
            print_loss_total = 0
            print('%s (%d %d%%) %.4f' % (time_since(start, epoch / n_epochs),
                                         epoch, epoch / n_epochs * 100, print_loss_avg))

        if epoch % save_every == 0:
            save_model(encoder, decoder, CHECKPOINT_DIR, MODEL_PREFIX, epoch)

        if epoch % evaluate_every == 0:
            evaluate_randomly(data_loader, encoder, decoder, n=1)


def evaluate(data_loader, encoder, decoder, sentence, max_length=MAX_LENGTH):
    input_variable = data_loader.variable_from_sentence(data_loader.input_lang, sentence)
    input_length = input_variable.size()[0]
    encoder_hidden = encoder.init_hidden()

    encoder_outputs = Variable(torch.zeros(max_length, encoder.hidden_size))
    encoder_outputs = encoder_outputs.cuda() if use_cuda else encoder_outputs

    for ei in range(input_length):
        encoder_output, encoder_hidden = encoder(input_variable[ei],
                                                 encoder_hidden)
        encoder_outputs[ei] = encoder_outputs[ei] + encoder_output[0][0]

    decoder_input = Variable(torch.LongTensor([[data_loader.SOS_token]]))  # SOS
    decoder_input = decoder_input.cuda() if use_cuda else decoder_input

    decoder_hidden = encoder_hidden

    decoded_words = []
    decoder_attentions = torch.zeros(max_length, max_length)

    for di in range(max_length):
        decoder_output, decoder_hidden, decoder_attention = decoder(
            decoder_input, decoder_hidden, encoder_output, encoder_outputs)
        decoder_attentions[di] = decoder_attention.data
        topv, topi = decoder_output.data.topk(1)
        ni = topi[0][0].item()  # plain int, usable as an index2word key
        if ni == data_loader.EOS_token:
            decoded_words.append('<EOS>')
            break
        else:
            decoded_words.append(data_loader.output_lang.index2word[ni])

        decoder_input = Variable(torch.LongTensor([[ni]]))
        decoder_input = decoder_input.cuda() if use_cuda else decoder_input

    return decoded_words, decoder_attentions[:di + 1]


def evaluate_randomly(data_loader, encoder, decoder, n=10):
    # disable dropout and gradient tracking while sampling, then restore training mode
    encoder.eval()
    decoder.eval()
    with torch.no_grad():
        for i in range(n):
            pair = random.choice(data_loader.pairs)
            print('>', pair[0])
            print('=', pair[1])
            output_words, attentions = evaluate(data_loader, encoder, decoder, pair[0])
            output_sentence = ' '.join(output_words)
            print('<', output_sentence)
            print('')
    encoder.train()
    decoder.train()


def main():

    pair_data_loader = PairDataLoader()
    hidden_size = 256
    encoder1 = EncoderRNN(pair_data_loader.input_lang.n_words, hidden_size)
    attn_decoder1 = AttnDecoderRNN(hidden_size, pair_data_loader.output_lang.n_words,
                                   1, dropout_p=0.1)

    if use_cuda:
        encoder1 = encoder1.cuda()
        attn_decoder1 = attn_decoder1.cuda()
    print('start training...')
    train(pair_data_loader, encoder1, attn_decoder1, 75000)
    evaluate_randomly(pair_data_loader, encoder1, attn_decoder1)


if __name__ == '__main__':
    main()

--------------------------------------------------------------------------------
/utils/model_utils.py:
--------------------------------------------------------------------------------
# -*- coding: utf-8 -*-
# file: model_utils.py
# author: JinTian
# time: 10/05/2017 6:07 PM
# Copyright 2017 JinTian. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ------------------------------------------------------------------------
import torch
import os
import glob
import numpy as np
import time
import math


def load_previous_model(encoder, decoder, checkpoint_dir, model_prefix):
    """
    Generic PyTorch helper that reloads the latest checkpoint:
    it picks the file with the largest epoch in checkpoint_dir;
    for other models, just change the model load format.
    :param encoder: encoder module to restore
    :param decoder: decoder module to restore
    :param checkpoint_dir: directory holding <model_prefix>-<epoch>.pth files
    :param model_prefix: checkpoint file name prefix
    :return: (encoder, decoder, start_epoch); start_epoch is the epoch after the restored one, or 1
    """
    f_list = glob.glob(os.path.join(checkpoint_dir, model_prefix) + '-*.pth')
    start_epoch = 1
    if len(f_list) >= 1:
        epoch_list = [int(i.split('-')[-1].split('.')[0]) for i in f_list]
        last_checkpoint = f_list[np.argmax(epoch_list)]
        if os.path.exists(last_checkpoint):
            print('load from {}'.format(last_checkpoint))
            model_state_dict = torch.load(last_checkpoint, map_location=lambda storage, loc: storage)
            encoder.load_state_dict(model_state_dict['encoder'])
            decoder.load_state_dict(model_state_dict['decoder'])
            # resume after the saved epoch instead of repeating it
            start_epoch = int(np.max(epoch_list)) + 1
    return encoder, decoder, start_epoch


def save_model(encoder, decoder, checkpoint_dir, model_prefix, epoch, max_keep=5):
    """
    Generic PyTorch helper that saves the model
    to a file named after the prefix and the epoch.
    :param encoder: encoder module to save
    :param decoder: decoder module to save
    :param checkpoint_dir: directory for <model_prefix>-<epoch>.pth files (created if missing)
    :param model_prefix: checkpoint file name prefix
    :param epoch: epoch number appended to the file name
    :param max_keep: number of most recent checkpoints to keep
    :return: None
    """
    if not os.path.exists(checkpoint_dir):
        os.makedirs(checkpoint_dir)
    f_list = glob.glob(os.path.join(checkpoint_dir, model_prefix) + '-*.pth')
    if len(f_list) >= max_keep:
        # delete the OLDEST checkpoints so that, after saving, only max_keep remain
        epoch_list = [int(i.split('-')[-1].split('.')[0]) for i in f_list]
        to_delete = [f_list[i] for i in np.argsort(epoch_list)[:len(f_list) - max_keep + 1]]
        for f in to_delete:
            os.remove(f)
    name = model_prefix + '-{}.pth'.format(epoch)
    file_path = os.path.join(checkpoint_dir, name)
    model_dict = {
        'encoder': encoder.state_dict(),
        'decoder': decoder.state_dict()
    }
    torch.save(model_dict, file_path)


def as_minutes(s):
    m = math.floor(s / 60)
    s -= m * 60
    return '%dm %ds' % (m, s)


def time_since(since, percent):
    # elapsed time and estimated remaining time, given the fraction of work done
    now = time.time()
    s = now - since
    es = s / percent
    rs = es - s
    return '%s (- %s)' % (as_minutes(s), as_minutes(rs))

--------------------------------------------------------------------------------