├── config.py ├── .gitignore ├── README.md ├── main.py ├── load.py ├── model.py ├── evaluate.py └── train.py /config.py: -------------------------------------------------------------------------------- 1 | MAX_LENGTH = 15 2 | teacher_forcing_ratio = 1.0 3 | save_dir = './save' 4 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | env/ 12 | build/ 13 | develop-eggs/ 14 | dist/ 15 | downloads/ 16 | eggs/ 17 | .eggs/ 18 | lib/ 19 | lib64/ 20 | parts/ 21 | sdist/ 22 | var/ 23 | wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | 28 | # PyInstaller 29 | # Usually these files are written by a python script from a template 30 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 31 | *.manifest 32 | *.spec 33 | 34 | # Installer logs 35 | pip-log.txt 36 | pip-delete-this-directory.txt 37 | 38 | # Unit test / coverage reports 39 | htmlcov/ 40 | .tox/ 41 | .coverage 42 | .coverage.* 43 | .cache 44 | nosetests.xml 45 | coverage.xml 46 | *.cover 47 | .hypothesis/ 48 | 49 | # Translations 50 | *.mo 51 | *.pot 52 | 53 | # Django stuff: 54 | *.log 55 | local_settings.py 56 | 57 | # Flask stuff: 58 | instance/ 59 | .webassets-cache 60 | 61 | # Scrapy stuff: 62 | .scrapy 63 | 64 | # Sphinx documentation 65 | docs/_build/ 66 | 67 | # PyBuilder 68 | target/ 69 | 70 | # Jupyter Notebook 71 | .ipynb_checkpoints 72 | 73 | # pyenv 74 | .python-version 75 | 76 | # celery beat schedule file 77 | celerybeat-schedule 78 | 79 | # SageMath parsed files 80 | *.sage.py 81 | 82 | # dotenv 83 | .env 84 | 85 | # virtualenv 86 | .venv 87 | venv/ 88 | ENV/ 89 | 90 | # Spyder project settings 91 | .spyderproject 92 | .spyproject 93 | 94 | # Rope project settings 95 | .ropeproject 96 | 97 | # mkdocs documentation 98 | /site 99 | 100 | # mypy 101 | .mypy_cache/ 102 | 103 | save 104 | data 105 | *.sh 106 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # pytorch-chatbot 2 | This is a pytorch seq2seq tutorial for [Formosa Speech Grand Challenge](https://fgc.stpi.narl.org.tw/activity/techai), which is modified from [pratical-pytorch seq2seq-translation-batched](https://github.com/spro/practical-pytorch/blob/master/seq2seq-translation/seq2seq-translation-batched.ipynb). 3 | [Tutorial](https://pytorch.org/tutorials/beginner/chatbot_tutorial.html) introducing this repo from pytorch official website, [Tutorial](https://fgc.stpi.narl.org.tw/activity/videoDetail/4b1141305df38a7c015e194f22f8015b) in Chinese. 4 | 5 | ## Update 6 | A new version is already implemented in branch "dev". 7 | 8 | ## Requirement 9 | * python 3.5+ 10 | * pytorch 0.4.0 11 | * tqdm 12 | 13 | ## Get started 14 | #### Clone the repository 15 | ``` 16 | git clone https://github.com/ywk991112/pytorch-chatbot 17 | ``` 18 | #### Corpus 19 | In the corpus file, the input-output sequence pairs should be in the adjacent lines. For example, 20 | ``` 21 | I'll see you next time. 22 | Sure. Bye. 23 | How are you? 24 | Better than ever. 25 | ``` 26 | The corpus files should be placed under a path like, 27 | ``` 28 | pytorch-chatbot/data/ 29 | ``` 30 | Otherwise, the corpus file will be tracked by git. 31 | #### Pretrained Model 32 | The pretrained model on [movie\_subtitles corpus](https://www.space.ntu.edu.tw/navigate/s/229EDD285D994B82B72CEDE5B5CA0CE0QQY) with an bidirectional rnn layer and 33 | hidden size 512 can be downloaded in [this link](https://www.space.ntu.edu.tw/navigate/s/D287C8C95A0B4877B8666A45D5D318C0QQY). 34 | The pretrained model file should be placed in directory as followed. 35 | ``` 36 | mkdir -p save/model/movie_subtitles/1-1_512 37 | mv 50000_backup_bidir_model.tar save/model/movie_subtitles/1-1_512 38 | ``` 39 | #### Training 40 | Run this command to start training, change the argument values in your own need. 41 | ``` 42 | python main.py -tr -la 1 -hi 512 -lr 0.0001 -it 50000 -b 64 -p 500 -s 1000 43 | ``` 44 | Continue training with saved model. 45 | ``` 46 | python main.py -tr -l -lr 0.0001 -it 50000 -b 64 -p 500 -s 1000 47 | ``` 48 | For more options, 49 | ``` 50 | python main.py -h 51 | ``` 52 | #### Testing 53 | Models will be saved in `pytorch-chatbot/save/model` while training, and this can be changed in `config.py`. 54 | Evaluate the saved model with input sequences in the corpus. 55 | ``` 56 | python main.py -te -c 57 | ``` 58 | Test the model with input sequence manually. 59 | ``` 60 | python main.py -te -c -i 61 | ``` 62 | Beam search with size k. 63 | ``` 64 | python main.py -te -c -be k [-i] 65 | ``` 66 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from train import trainIters 3 | from evaluate import runTest 4 | 5 | def parse(): 6 | parser = argparse.ArgumentParser(description='Attention Seq2Seq Chatbot') 7 | parser.add_argument('-tr', '--train', help='Train the model with corpus') 8 | parser.add_argument('-te', '--test', help='Test the saved model') 9 | parser.add_argument('-l', '--load', help='Load the model and train') 10 | parser.add_argument('-c', '--corpus', help='Test the saved model with vocabulary of the corpus') 11 | parser.add_argument('-r', '--reverse', action='store_true', help='Reverse the input sequence') 12 | parser.add_argument('-f', '--filter', action='store_true', help='Filter to small training data set') 13 | parser.add_argument('-i', '--input', action='store_true', help='Test the model by input the sentence') 14 | parser.add_argument('-it', '--iteration', type=int, default=10000, help='Train the model with it iterations') 15 | parser.add_argument('-p', '--print', type=int, default=100, help='Print every p iterations') 16 | parser.add_argument('-b', '--batch_size', type=int, default=64, help='Batch size') 17 | parser.add_argument('-la', '--layer', type=int, default=1, help='Number of layers in encoder and decoder') 18 | parser.add_argument('-hi', '--hidden', type=int, default=256, help='Hidden size in encoder and decoder') 19 | parser.add_argument('-be', '--beam', type=int, default=1, help='Hidden size in encoder and decoder') 20 | parser.add_argument('-s', '--save', type=int, default=500, help='Save every s iterations') 21 | parser.add_argument('-lr', '--learning_rate', type=float, default=0.01, help='Learning rate') 22 | parser.add_argument('-d', '--dropout', type=float, default=0.1, help='Dropout probability for rnn and dropout layers') 23 | 24 | args = parser.parse_args() 25 | return args 26 | 27 | def parseFilename(filename, test=False): 28 | filename = filename.split('/') 29 | dataType = filename[-1][:-4] # remove '.tar' 30 | parse = dataType.split('_') 31 | reverse = 'reverse' in parse 32 | layers, hidden = filename[-2].split('_') 33 | n_layers = int(layers.split('-')[0]) 34 | hidden_size = int(hidden) 35 | return n_layers, hidden_size, reverse 36 | 37 | def run(args): 38 | reverse, fil, n_iteration, print_every, save_every, learning_rate, \ 39 | n_layers, hidden_size, batch_size, beam_size, inp, dropout = \ 40 | args.reverse, args.filter, args.iteration, args.print, args.save, args.learning_rate, \ 41 | args.layer, args.hidden, args.batch_size, args.beam, args.input, args.dropout 42 | if args.train and not args.load: 43 | trainIters(args.train, reverse, n_iteration, learning_rate, batch_size, 44 | n_layers, hidden_size, print_every, save_every, dropout) 45 | elif args.load: 46 | n_layers, hidden_size, reverse = parseFilename(args.load) 47 | trainIters(args.train, reverse, n_iteration, learning_rate, batch_size, 48 | n_layers, hidden_size, print_every, save_every, dropout, loadFilename=args.load) 49 | elif args.test: 50 | n_layers, hidden_size, reverse = parseFilename(args.test, True) 51 | runTest(n_layers, hidden_size, reverse, args.test, beam_size, inp, args.corpus) 52 | 53 | 54 | if __name__ == '__main__': 55 | args = parse() 56 | run(args) 57 | -------------------------------------------------------------------------------- /load.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import re 3 | import os 4 | import unicodedata 5 | 6 | from config import MAX_LENGTH, save_dir 7 | 8 | SOS_token = 0 9 | EOS_token = 1 10 | PAD_token = 2 11 | 12 | class Voc: 13 | def __init__(self, name): 14 | self.name = name 15 | self.word2index = {} 16 | self.word2count = {} 17 | self.index2word = {0: "SOS", 1: "EOS", 2:"PAD"} 18 | self.n_words = 3 # Count SOS and EOS 19 | 20 | def addSentence(self, sentence): 21 | for word in sentence.split(' '): 22 | self.addWord(word) 23 | 24 | def addWord(self, word): 25 | if word not in self.word2index: 26 | self.word2index[word] = self.n_words 27 | self.word2count[word] = 1 28 | self.index2word[self.n_words] = word 29 | self.n_words += 1 30 | else: 31 | self.word2count[word] += 1 32 | 33 | # Turn a Unicode string to plain ASCII, thanks to 34 | # http://stackoverflow.com/a/518232/2809427 35 | def unicodeToAscii(s): 36 | return ''.join( 37 | c for c in unicodedata.normalize('NFD', s) 38 | if unicodedata.category(c) != 'Mn' 39 | ) 40 | 41 | # Lowercase, trim, and remove non-letter characters 42 | def normalizeString(s): 43 | s = unicodeToAscii(s.lower().strip()) 44 | s = re.sub(r"([.!?])", r" \1", s) 45 | s = re.sub(r"[^a-zA-Z.!?]+", r" ", s) 46 | s = re.sub(r"\s+", r" ", s).strip() 47 | return s 48 | 49 | def readVocs(corpus, corpus_name): 50 | print("Reading lines...") 51 | 52 | # combine every two lines into pairs and normalize 53 | with open(corpus) as f: 54 | content = f.readlines() 55 | # import gzip 56 | # content = gzip.open(corpus, 'rt') 57 | lines = [x.strip() for x in content] 58 | it = iter(lines) 59 | # pairs = [[normalizeString(x), normalizeString(next(it))] for x in it] 60 | pairs = [[x, next(it)] for x in it] 61 | 62 | voc = Voc(corpus_name) 63 | return voc, pairs 64 | 65 | def filterPair(p): 66 | # input sequences need to preserve the last word for EOS_token 67 | return len(p[0].split(' ')) < MAX_LENGTH and \ 68 | len(p[1].split(' ')) < MAX_LENGTH 69 | 70 | def filterPairs(pairs): 71 | return [pair for pair in pairs if filterPair(pair)] 72 | 73 | def prepareData(corpus, corpus_name): 74 | voc, pairs = readVocs(corpus, corpus_name) 75 | print("Read {!s} sentence pairs".format(len(pairs))) 76 | pairs = filterPairs(pairs) 77 | print("Trimmed to {!s} sentence pairs".format(len(pairs))) 78 | print("Counting words...") 79 | for pair in pairs: 80 | voc.addSentence(pair[0]) 81 | voc.addSentence(pair[1]) 82 | print("Counted words:", voc.n_words) 83 | directory = os.path.join(save_dir, 'training_data', corpus_name) 84 | if not os.path.exists(directory): 85 | os.makedirs(directory) 86 | torch.save(voc, os.path.join(directory, '{!s}.tar'.format('voc'))) 87 | torch.save(pairs, os.path.join(directory, '{!s}.tar'.format('pairs'))) 88 | return voc, pairs 89 | 90 | def loadPrepareData(corpus): 91 | corpus_name = corpus.split('/')[-1].split('.')[0] 92 | try: 93 | print("Start loading training data ...") 94 | voc = torch.load(os.path.join(save_dir, 'training_data', corpus_name, 'voc.tar')) 95 | pairs = torch.load(os.path.join(save_dir, 'training_data', corpus_name, 'pairs.tar')) 96 | except FileNotFoundError: 97 | print("Saved data not found, start preparing trianing data ...") 98 | voc, pairs = prepareData(corpus, corpus_name) 99 | return voc, pairs 100 | -------------------------------------------------------------------------------- /model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | USE_CUDA = torch.cuda.is_available() 6 | device = torch.device("cuda" if USE_CUDA else "cpu") 7 | 8 | class EncoderRNN(nn.Module): 9 | def __init__(self, input_size, hidden_size, embedding, n_layers=1, dropout=0): 10 | super(EncoderRNN, self).__init__() 11 | self.n_layers = n_layers 12 | self.hidden_size = hidden_size 13 | self.embedding = embedding 14 | 15 | self.gru = nn.GRU(hidden_size, hidden_size, n_layers, 16 | dropout=(0 if n_layers == 1 else dropout), bidirectional=True) 17 | 18 | def forward(self, input_seq, input_lengths, hidden=None): 19 | embedded = self.embedding(input_seq) 20 | packed = torch.nn.utils.rnn.pack_padded_sequence(embedded, input_lengths) 21 | outputs, hidden = self.gru(packed, hidden) # output: (seq_len, batch, hidden*n_dir) 22 | outputs, _ = torch.nn.utils.rnn.pad_packed_sequence(outputs) 23 | outputs = outputs[:, :, :self.hidden_size] + outputs[:, : ,self.hidden_size:] # Sum bidirectional outputs (1, batch, hidden) 24 | return outputs, hidden 25 | 26 | class Attn(nn.Module): 27 | def __init__(self, method, hidden_size): 28 | super(Attn, self).__init__() 29 | 30 | self.method = method 31 | self.hidden_size = hidden_size 32 | 33 | if self.method == 'general': 34 | self.attn = nn.Linear(self.hidden_size, hidden_size) 35 | 36 | elif self.method == 'concat': 37 | self.attn = nn.Linear(self.hidden_size * 2, hidden_size) 38 | self.v = nn.Parameter(torch.FloatTensor(1, hidden_size)) 39 | 40 | def forward(self, hidden, encoder_outputs): 41 | # hidden [1, 64, 512], encoder_outputs [14, 64, 512] 42 | max_len = encoder_outputs.size(0) 43 | batch_size = encoder_outputs.size(1) 44 | 45 | # Create variable to store attention energies 46 | attn_energies = torch.zeros(batch_size, max_len) # B x S 47 | attn_energies = attn_energies.to(device) 48 | 49 | # For each batch of encoder outputs 50 | for b in range(batch_size): 51 | # Calculate energy for each encoder output 52 | for i in range(max_len): 53 | attn_energies[b, i] = self.score(hidden[:, b], encoder_outputs[i, b].unsqueeze(0)) 54 | 55 | # Normalize energies to weights in range 0 to 1, resize to 1 x B x S 56 | return F.softmax(attn_energies, dim=1).unsqueeze(1) 57 | 58 | def score(self, hidden, encoder_output): 59 | # hidden [1, 512], encoder_output [1, 512] 60 | if self.method == 'dot': 61 | energy = hidden.squeeze(0).dot(encoder_output.squeeze(0)) 62 | return energy 63 | 64 | elif self.method == 'general': 65 | energy = self.attn(encoder_output) 66 | energy = hidden.squeeze(0).dot(energy.squeeze(0)) 67 | return energy 68 | 69 | elif self.method == 'concat': 70 | energy = self.attn(torch.cat((hidden, encoder_output), 1)) 71 | energy = self.v.squeeze(0).dot(energy.squeeze(0)) 72 | return energy 73 | 74 | class LuongAttnDecoderRNN(nn.Module): 75 | def __init__(self, attn_model, embedding, hidden_size, output_size, n_layers=1, dropout=0.1): 76 | super(LuongAttnDecoderRNN, self).__init__() 77 | 78 | # Keep for reference 79 | self.attn_model = attn_model 80 | self.hidden_size = hidden_size 81 | self.output_size = output_size 82 | self.n_layers = n_layers 83 | self.dropout = dropout 84 | 85 | # Define layers 86 | self.embedding = embedding 87 | self.embedding_dropout = nn.Dropout(dropout) 88 | self.gru = nn.GRU(hidden_size, hidden_size, n_layers, dropout=(0 if n_layers == 1 else dropout)) 89 | self.concat = nn.Linear(hidden_size * 2, hidden_size) 90 | self.out = nn.Linear(hidden_size, output_size) 91 | 92 | # Choose attention model 93 | if attn_model != 'none': 94 | self.attn = Attn(attn_model, hidden_size) 95 | 96 | def forward(self, input_seq, last_hidden, encoder_outputs): 97 | # Note: we run this one step at a time 98 | 99 | # Get the embedding of the current input word (last output word) 100 | embedded = self.embedding(input_seq) 101 | embedded = self.embedding_dropout(embedded) #[1, 64, 512] 102 | if(embedded.size(0) != 1): 103 | raise ValueError('Decoder input sequence length should be 1') 104 | 105 | # Get current hidden state from input word and last hidden state 106 | rnn_output, hidden = self.gru(embedded, last_hidden) 107 | 108 | # Calculate attention from current RNN state and all encoder outputs; 109 | # apply to encoder outputs to get weighted average 110 | attn_weights = self.attn(rnn_output, encoder_outputs) #[64, 1, 14] 111 | # encoder_outputs [14, 64, 512] 112 | context = attn_weights.bmm(encoder_outputs.transpose(0, 1)) #[64, 1, 512] 113 | 114 | # Attentional vector using the RNN hidden state and context vector 115 | # concatenated together (Luong eq. 5) 116 | rnn_output = rnn_output.squeeze(0) #[64, 512] 117 | context = context.squeeze(1) #[64, 512] 118 | concat_input = torch.cat((rnn_output, context), 1) #[64, 1024] 119 | concat_output = torch.tanh(self.concat(concat_input)) #[64, 512] 120 | 121 | # Finally predict next token (Luong eq. 6, without softmax) 122 | output = self.out(concat_output) #[64, output_size] 123 | 124 | # Return final output, hidden state, and attention weights (for visualization) 125 | return output, hidden, attn_weights 126 | -------------------------------------------------------------------------------- /evaluate.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import random 3 | from train import indexesFromSentence 4 | from load import SOS_token, EOS_token 5 | from load import MAX_LENGTH, loadPrepareData, Voc 6 | from model import * 7 | 8 | USE_CUDA = torch.cuda.is_available() 9 | device = torch.device("cuda" if USE_CUDA else "cpu") 10 | 11 | class Sentence: 12 | def __init__(self, decoder_hidden, last_idx=SOS_token, sentence_idxes=[], sentence_scores=[]): 13 | if(len(sentence_idxes) != len(sentence_scores)): 14 | raise ValueError("length of indexes and scores should be the same") 15 | self.decoder_hidden = decoder_hidden 16 | self.last_idx = last_idx 17 | self.sentence_idxes = sentence_idxes 18 | self.sentence_scores = sentence_scores 19 | 20 | def avgScore(self): 21 | if len(self.sentence_scores) == 0: 22 | raise ValueError("Calculate average score of sentence, but got no word") 23 | # return mean of sentence_score 24 | return sum(self.sentence_scores) / len(self.sentence_scores) 25 | 26 | def addTopk(self, topi, topv, decoder_hidden, beam_size, voc): 27 | topv = torch.log(topv) 28 | terminates, sentences = [], [] 29 | for i in range(beam_size): 30 | if topi[0][i] == EOS_token: 31 | terminates.append(([voc.index2word[idx.item()] for idx in self.sentence_idxes] + [''], 32 | self.avgScore())) # tuple(word_list, score_float 33 | continue 34 | idxes = self.sentence_idxes[:] # pass by value 35 | scores = self.sentence_scores[:] # pass by value 36 | idxes.append(topi[0][i]) 37 | scores.append(topv[0][i]) 38 | sentences.append(Sentence(decoder_hidden, topi[0][i], idxes, scores)) 39 | return terminates, sentences 40 | 41 | def toWordScore(self, voc): 42 | words = [] 43 | for i in range(len(self.sentence_idxes)): 44 | if self.sentence_idxes[i] == EOS_token: 45 | words.append('') 46 | else: 47 | words.append(voc.index2word[self.sentence_idxes[i].item()]) 48 | if self.sentence_idxes[-1] != EOS_token: 49 | words.append('') 50 | return (words, self.avgScore()) 51 | 52 | def beam_decode(decoder, decoder_hidden, encoder_outputs, voc, beam_size, max_length=MAX_LENGTH): 53 | terminal_sentences, prev_top_sentences, next_top_sentences = [], [], [] 54 | prev_top_sentences.append(Sentence(decoder_hidden)) 55 | for i in range(max_length): 56 | for sentence in prev_top_sentences: 57 | decoder_input = torch.LongTensor([[sentence.last_idx]]) 58 | decoder_input = decoder_input.to(device) 59 | 60 | decoder_hidden = sentence.decoder_hidden 61 | decoder_output, decoder_hidden, _ = decoder( 62 | decoder_input, decoder_hidden, encoder_outputs 63 | ) 64 | topv, topi = decoder_output.topk(beam_size) 65 | term, top = sentence.addTopk(topi, topv, decoder_hidden, beam_size, voc) 66 | terminal_sentences.extend(term) 67 | next_top_sentences.extend(top) 68 | 69 | next_top_sentences.sort(key=lambda s: s.avgScore(), reverse=True) 70 | prev_top_sentences = next_top_sentences[:beam_size] 71 | next_top_sentences = [] 72 | 73 | terminal_sentences += [sentence.toWordScore(voc) for sentence in prev_top_sentences] 74 | terminal_sentences.sort(key=lambda x: x[1], reverse=True) 75 | 76 | n = min(len(terminal_sentences), 15) 77 | return terminal_sentences[:n] 78 | 79 | def decode(decoder, decoder_hidden, encoder_outputs, voc, max_length=MAX_LENGTH): 80 | 81 | decoder_input = torch.LongTensor([[SOS_token]]) 82 | decoder_input = decoder_input.to(device) 83 | 84 | decoded_words = [] 85 | decoder_attentions = torch.zeros(max_length, max_length) #TODO: or (MAX_LEN+1, MAX_LEN+1) 86 | 87 | for di in range(max_length): 88 | decoder_output, decoder_hidden, decoder_attn = decoder( 89 | decoder_input, decoder_hidden, encoder_outputs 90 | ) 91 | _, topi = decoder_output.topk(3) 92 | ni = topi[0][0] 93 | if ni == EOS_token: 94 | decoded_words.append('') 95 | break 96 | else: 97 | decoded_words.append(voc.index2word[ni.item()]) 98 | 99 | decoder_input = torch.LongTensor([[ni]]) 100 | decoder_input = decoder_input.to(device) 101 | 102 | return decoded_words, decoder_attentions[:di + 1] 103 | 104 | 105 | def evaluate(encoder, decoder, voc, sentence, beam_size, max_length=MAX_LENGTH): 106 | indexes_batch = [indexesFromSentence(voc, sentence)] #[1, seq_len] 107 | lengths = [len(indexes) for indexes in indexes_batch] 108 | input_batch = torch.LongTensor(indexes_batch).transpose(0, 1) 109 | input_batch = input_batch.to(device) 110 | 111 | encoder_outputs, encoder_hidden = encoder(input_batch, lengths, None) 112 | 113 | decoder_hidden = encoder_hidden[:decoder.n_layers] 114 | 115 | if beam_size == 1: 116 | return decode(decoder, decoder_hidden, encoder_outputs, voc) 117 | else: 118 | return beam_decode(decoder, decoder_hidden, encoder_outputs, voc, beam_size) 119 | 120 | 121 | def evaluateRandomly(encoder, decoder, voc, pairs, reverse, beam_size, n=10): 122 | for _ in range(n): 123 | pair = random.choice(pairs) 124 | print("=============================================================") 125 | if reverse: 126 | print('>', " ".join(reversed(pair[0].split()))) 127 | else: 128 | print('>', pair[0]) 129 | if beam_size == 1: 130 | output_words, _ = evaluate(encoder, decoder, voc, pair[0], beam_size) 131 | output_sentence = ' '.join(output_words) 132 | print('<', output_sentence) 133 | else: 134 | output_words_list = evaluate(encoder, decoder, voc, pair[0], beam_size) 135 | for output_words, score in output_words_list: 136 | output_sentence = ' '.join(output_words) 137 | print("{:.3f} < {}".format(score, output_sentence)) 138 | 139 | def evaluateInput(encoder, decoder, voc, beam_size): 140 | pair = '' 141 | while(1): 142 | try: 143 | pair = input('> ') 144 | if pair == 'q': break 145 | if beam_size == 1: 146 | output_words, _ = evaluate(encoder, decoder, voc, pair, beam_size) 147 | output_sentence = ' '.join(output_words) 148 | print('<', output_sentence) 149 | else: 150 | output_words_list = evaluate(encoder, decoder, voc, pair, beam_size) 151 | for output_words, score in output_words_list: 152 | output_sentence = ' '.join(output_words) 153 | print("{:.3f} < {}".format(score, output_sentence)) 154 | except KeyError: 155 | print("Incorrect spelling.") 156 | 157 | 158 | def runTest(n_layers, hidden_size, reverse, modelFile, beam_size, inp, corpus): 159 | torch.set_grad_enabled(False) 160 | 161 | voc, pairs = loadPrepareData(corpus) 162 | embedding = nn.Embedding(voc.n_words, hidden_size) 163 | encoder = EncoderRNN(voc.n_words, hidden_size, embedding, n_layers) 164 | attn_model = 'dot' 165 | decoder = LuongAttnDecoderRNN(attn_model, embedding, hidden_size, voc.n_words, n_layers) 166 | 167 | checkpoint = torch.load(modelFile) 168 | encoder.load_state_dict(checkpoint['en']) 169 | decoder.load_state_dict(checkpoint['de']) 170 | 171 | # train mode set to false, effect only on dropout, batchNorm 172 | encoder.train(False); 173 | decoder.train(False); 174 | 175 | encoder = encoder.to(device) 176 | decoder = decoder.to(device) 177 | 178 | if inp: 179 | evaluateInput(encoder, decoder, voc, beam_size) 180 | else: 181 | evaluateRandomly(encoder, decoder, voc, pairs, reverse, beam_size, 20) 182 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from torch import optim 5 | import torch.backends.cudnn as cudnn 6 | 7 | import itertools 8 | import random 9 | import math 10 | import os 11 | from tqdm import tqdm 12 | from load import loadPrepareData 13 | from load import SOS_token, EOS_token, PAD_token 14 | from model import EncoderRNN, LuongAttnDecoderRNN 15 | from config import MAX_LENGTH, teacher_forcing_ratio, save_dir 16 | 17 | USE_CUDA = torch.cuda.is_available() 18 | device = torch.device("cuda" if USE_CUDA else "cpu") 19 | 20 | cudnn.benchmark = True 21 | ############################################# 22 | # generate file name for saving parameters 23 | ############################################# 24 | def filename(reverse, obj): 25 | filename = '' 26 | if reverse: 27 | filename += 'reverse_' 28 | filename += obj 29 | return filename 30 | 31 | 32 | ############################################# 33 | # Prepare Training Data 34 | ############################################# 35 | def indexesFromSentence(voc, sentence): 36 | return [voc.word2index[word] for word in sentence.split(' ')] + [EOS_token] 37 | 38 | # batch_first: true -> false, i.e. shape: seq_len * batch 39 | def zeroPadding(l, fillvalue=PAD_token): 40 | return list(itertools.zip_longest(*l, fillvalue=fillvalue)) 41 | 42 | def binaryMatrix(l, value=PAD_token): 43 | m = [] 44 | for i, seq in enumerate(l): 45 | m.append([]) 46 | for token in seq: 47 | if token == PAD_token: 48 | m[i].append(0) 49 | else: 50 | m[i].append(1) 51 | return m 52 | 53 | # convert to index, add EOS 54 | # return input pack_padded_sequence 55 | def inputVar(l, voc): 56 | indexes_batch = [indexesFromSentence(voc, sentence) for sentence in l] 57 | lengths = [len(indexes) for indexes in indexes_batch] 58 | padList = zeroPadding(indexes_batch) 59 | padVar = torch.LongTensor(padList) 60 | return padVar, lengths 61 | 62 | # convert to index, add EOS, zero padding 63 | # return output variable, mask, max length of the sentences in batch 64 | def outputVar(l, voc): 65 | indexes_batch = [indexesFromSentence(voc, sentence) for sentence in l] 66 | max_target_len = max([len(indexes) for indexes in indexes_batch]) 67 | padList = zeroPadding(indexes_batch) 68 | mask = binaryMatrix(padList) 69 | mask = torch.ByteTensor(mask) 70 | padVar = torch.LongTensor(padList) 71 | return padVar, mask, max_target_len 72 | 73 | # pair_batch is a list of (input, output) with length batch_size 74 | # sort list of (input, output) pairs by input length, reverse input 75 | # return input, lengths for pack_padded_sequence, output_variable, mask 76 | def batch2TrainData(voc, pair_batch, reverse): 77 | if reverse: 78 | pair_batch = [pair[::-1] for pair in pair_batch] 79 | pair_batch.sort(key=lambda x: len(x[0].split(" ")), reverse=True) 80 | input_batch, output_batch = [], [] 81 | for pair in pair_batch: 82 | input_batch.append(pair[0]) 83 | output_batch.append(pair[1]) 84 | inp, lengths = inputVar(input_batch, voc) 85 | output, mask, max_target_len = outputVar(output_batch, voc) 86 | return inp, lengths, output, mask, max_target_len 87 | 88 | ############################################# 89 | # Training 90 | ############################################# 91 | 92 | def train(input_variable, lengths, target_variable, mask, max_target_len, encoder, decoder, embedding, 93 | encoder_optimizer, decoder_optimizer, batch_size, max_length=MAX_LENGTH): 94 | 95 | encoder_optimizer.zero_grad() 96 | decoder_optimizer.zero_grad() 97 | 98 | input_variable = input_variable.to(device) 99 | target_variable = target_variable.to(device) 100 | mask = mask.to(device) 101 | 102 | loss = 0 103 | print_losses = [] 104 | n_totals = 0 105 | 106 | encoder_outputs, encoder_hidden = encoder(input_variable, lengths, None) 107 | 108 | decoder_input = torch.LongTensor([[SOS_token for _ in range(batch_size)]]) 109 | decoder_input = decoder_input.to(device) 110 | 111 | decoder_hidden = encoder_hidden[:decoder.n_layers] 112 | 113 | use_teacher_forcing = True if random.random() < teacher_forcing_ratio else False 114 | 115 | # Run through decoder one time step at a time 116 | if use_teacher_forcing: 117 | for t in range(max_target_len): 118 | decoder_output, decoder_hidden, _ = decoder( 119 | decoder_input, decoder_hidden, encoder_outputs 120 | ) 121 | decoder_input = target_variable[t].view(1, -1) # Next input is current target 122 | loss += F.cross_entropy(decoder_output, target_variable[t], ignore_index=EOS_token) 123 | else: 124 | for t in range(max_target_len): 125 | decoder_output, decoder_hidden, decoder_attn = decoder( 126 | decoder_input, decoder_hidden, encoder_outputs 127 | ) 128 | _, topi = decoder_output.topk(1) # [64, 1] 129 | 130 | decoder_input = torch.LongTensor([[topi[i][0] for i in range(batch_size)]]) 131 | decoder_input = decoder_input.to(device) 132 | loss += F.cross_entropy(decoder_output, target_variable[t], ignore_index=EOS_token) 133 | 134 | loss.backward() 135 | 136 | clip = 50.0 137 | _ = torch.nn.utils.clip_grad_norm_(encoder.parameters(), clip) 138 | _ = torch.nn.utils.clip_grad_norm_(decoder.parameters(), clip) 139 | 140 | encoder_optimizer.step() 141 | decoder_optimizer.step() 142 | 143 | return loss.item() / max_target_len 144 | 145 | 146 | def trainIters(corpus, reverse, n_iteration, learning_rate, batch_size, n_layers, hidden_size, 147 | print_every, save_every, dropout, loadFilename=None, attn_model='dot', decoder_learning_ratio=5.0): 148 | 149 | voc, pairs = loadPrepareData(corpus) 150 | 151 | # training data 152 | corpus_name = os.path.split(corpus)[-1].split('.')[0] 153 | training_batches = None 154 | try: 155 | training_batches = torch.load(os.path.join(save_dir, 'training_data', corpus_name, 156 | '{}_{}_{}.tar'.format(n_iteration, \ 157 | filename(reverse, 'training_batches'), \ 158 | batch_size))) 159 | except FileNotFoundError: 160 | print('Training pairs not found, generating ...') 161 | training_batches = [batch2TrainData(voc, [random.choice(pairs) for _ in range(batch_size)], reverse) 162 | for _ in range(n_iteration)] 163 | torch.save(training_batches, os.path.join(save_dir, 'training_data', corpus_name, 164 | '{}_{}_{}.tar'.format(n_iteration, \ 165 | filename(reverse, 'training_batches'), \ 166 | batch_size))) 167 | # model 168 | checkpoint = None 169 | print('Building encoder and decoder ...') 170 | embedding = nn.Embedding(voc.n_words, hidden_size) 171 | encoder = EncoderRNN(voc.n_words, hidden_size, embedding, n_layers, dropout) 172 | attn_model = 'dot' 173 | decoder = LuongAttnDecoderRNN(attn_model, embedding, hidden_size, voc.n_words, n_layers, dropout) 174 | if loadFilename: 175 | checkpoint = torch.load(loadFilename) 176 | encoder.load_state_dict(checkpoint['en']) 177 | decoder.load_state_dict(checkpoint['de']) 178 | # use cuda 179 | encoder = encoder.to(device) 180 | decoder = decoder.to(device) 181 | 182 | # optimizer 183 | print('Building optimizers ...') 184 | encoder_optimizer = optim.Adam(encoder.parameters(), lr=learning_rate) 185 | decoder_optimizer = optim.Adam(decoder.parameters(), lr=learning_rate * decoder_learning_ratio) 186 | if loadFilename: 187 | encoder_optimizer.load_state_dict(checkpoint['en_opt']) 188 | decoder_optimizer.load_state_dict(checkpoint['de_opt']) 189 | 190 | # initialize 191 | print('Initializing ...') 192 | start_iteration = 1 193 | perplexity = [] 194 | print_loss = 0 195 | if loadFilename: 196 | start_iteration = checkpoint['iteration'] + 1 197 | perplexity = checkpoint['plt'] 198 | 199 | for iteration in tqdm(range(start_iteration, n_iteration + 1)): 200 | training_batch = training_batches[iteration - 1] 201 | input_variable, lengths, target_variable, mask, max_target_len = training_batch 202 | 203 | loss = train(input_variable, lengths, target_variable, mask, max_target_len, encoder, 204 | decoder, embedding, encoder_optimizer, decoder_optimizer, batch_size) 205 | print_loss += loss 206 | perplexity.append(loss) 207 | 208 | if iteration % print_every == 0: 209 | print_loss_avg = math.exp(print_loss / print_every) 210 | print('%d %d%% %.4f' % (iteration, iteration / n_iteration * 100, print_loss_avg)) 211 | print_loss = 0 212 | 213 | if (iteration % save_every == 0): 214 | directory = os.path.join(save_dir, 'model', corpus_name, '{}-{}_{}'.format(n_layers, n_layers, hidden_size)) 215 | if not os.path.exists(directory): 216 | os.makedirs(directory) 217 | torch.save({ 218 | 'iteration': iteration, 219 | 'en': encoder.state_dict(), 220 | 'de': decoder.state_dict(), 221 | 'en_opt': encoder_optimizer.state_dict(), 222 | 'de_opt': decoder_optimizer.state_dict(), 223 | 'loss': loss, 224 | 'plt': perplexity 225 | }, os.path.join(directory, '{}_{}.tar'.format(iteration, filename(reverse, 'backup_bidir_model')))) 226 | --------------------------------------------------------------------------------