├── README.md
├── bleu.py
├── download.sh
├── logger.py
├── main.py
├── model
│   ├── Attention.py
│   ├── Decoder.py
│   ├── Encoder.py
│   ├── Seq2Seq.py
│   └── __init__.py
├── prepro.py
├── trainer.py
└── utils.py

/README.md:
--------------------------------------------------------------------------------
 1 | # Pytorch-Torchtext-Seq2Seq
 2 | [PyTorch](https://github.com/pytorch/pytorch)
 3 | implementation of [Neural Machine Translation by Jointly Learning to Align and Translate](https://arxiv.org/abs/1409.0473).
 4 | 
 5 | 
 6 | ### Prerequisites
 7 | * [Python 3.5+](https://www.continuum.io/downloads)
 8 | * [PyTorch 0.2.0](http://pytorch.org/)
 9 | * [Torchtext 0.2.1](https://github.com/pytorch/text)
10 | * [spaCy 2.0.5](https://spacy.io/)
11 | * [TensorFlow 1.3+](https://www.tensorflow.org/) (optional, for TensorBoard)
12 | 
13 | 
14 | ## Getting Started
15 | #### 1. Clone the repository
16 | ```bash
17 | $ git clone https://github.com/Mjkim88/Pytorch-Torchtext-Seq2Seq.git
18 | $ cd Pytorch-Torchtext-Seq2Seq
19 | ```
20 | 
21 | #### 2. Download the dataset
22 | ```bash
23 | $ bash download.sh
24 | ```
25 | This command downloads the Europarl v7 training data and the WMT dev sets to the `data/` folder.
26 | If you want to use other datasets, you can skip this step.
27 | 
28 | #### 3. Train the model
29 | ```bash
30 | $ python main.py --dataset 'europarl' --src_lang 'fr' --trg_lang 'en' --data_path './data' \
31 |                  --train_path './data/training/europarl-v7.fr-en' --val_path './data/dev/newstest2013' \
32 |                  --log log --sample sample
33 | ```
34 | The first time you run the command above, the data is preprocessed with Torchtext and the result is automatically saved to `data/`, so the same datasets are not preprocessed again on later runs.
35 | 
36 | #### (Optional) TensorBoard visualization
37 | ```bash
38 | $ tensorboard --logdir='./logs' --port=8888
39 | ```
40 | For TensorBoard visualization, open a new terminal, run the command above, and open `http://localhost:8888` in your web browser.
41 | 
--------------------------------------------------------------------------------
/bleu.py:
--------------------------------------------------------------------------------
 1 | """
 2 | Code borrowed from
 3 | https://github.com/MaximumEntropy/Seq2Seq-PyTorch
 4 | """
 5 | from collections import Counter
 6 | import math
 7 | import numpy as np
 8 | import subprocess
 9 | 
10 | 
11 | def bleu_stats(hypothesis, reference):
12 |     """Compute statistics for BLEU."""
13 |     stats = []
14 |     stats.append(len(hypothesis))
15 |     stats.append(len(reference))
16 |     for n in range(1, 5):
17 |         s_ngrams = Counter(
18 |             [tuple(hypothesis[i:i + n]) for i in range(len(hypothesis) + 1 - n)]
19 |         )
20 |         r_ngrams = Counter(
21 |             [tuple(reference[i:i + n]) for i in range(len(reference) + 1 - n)]
22 |         )
23 |         stats.append(max([sum((s_ngrams & r_ngrams).values()), 0]))
24 |         stats.append(max([len(hypothesis) + 1 - n, 0]))
25 |     return stats
26 | 
27 | 
28 | def bleu(stats):
29 |     """Compute BLEU given n-gram statistics."""
30 |     if len(list(filter(lambda x: x == 0, stats))) > 0:
31 |         return 0
32 |     (c, r) = stats[:2]
33 |     log_bleu_prec = sum(
34 |         [math.log(float(x) / y) for x, y in zip(stats[2::2], stats[3::2])]
35 |     ) / 4.
36 | return math.exp(min([0, 1 - float(r) / c]) + log_bleu_prec) 37 | 38 | 39 | def get_bleu(hypotheses, reference): 40 | """Get validation BLEU score for dev set.""" 41 | stats = np.array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]) 42 | for hyp, ref in zip(hypotheses, reference): 43 | stats += np.array(bleu_stats(hyp, ref)) 44 | return 100 * bleu(stats) -------------------------------------------------------------------------------- /download.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | mkdir -p data 3 | cd data 4 | 5 | wget http://statmt.org/wmt13/training-parallel-europarl-v7.tgz 6 | tar -xzvf training-parallel-europarl-v7.tgz 7 | rm training-parallel-europarl-v7.tgz 8 | 9 | wget http://statmt.org/wmt14/dev.tgz 10 | tar -xzvf dev.tgz 11 | rm dev.tgz 12 | 13 | cd .. 14 | -------------------------------------------------------------------------------- /logger.py: -------------------------------------------------------------------------------- 1 | from __future__ import division 2 | 3 | # Code referenced from https://gist.github.com/gyglim/1f8dfb1b5c82627ae3efcfbbadb9f514 4 | import tensorflow as tf 5 | import numpy as np 6 | import scipy.misc 7 | try: 8 | from StringIO import StringIO # Python 2.7 9 | except ImportError: 10 | from io import BytesIO # Python 3.x 11 | 12 | 13 | class Logger(object): 14 | 15 | def __init__(self, log_dir): 16 | """Create a summary writer logging to log_dir.""" 17 | self.writer = tf.summary.FileWriter(log_dir) 18 | 19 | def scalar_summary(self, tag, value, step): 20 | """Log a scalar variable.""" 21 | summary = tf.Summary(value=[tf.Summary.Value(tag=tag, simple_value=value)]) 22 | self.writer.add_summary(summary, step) 23 | 24 | def image_summary(self, tag, images, step): 25 | """Log a list of images.""" 26 | 27 | img_summaries = [] 28 | for i, img in enumerate(images): 29 | # Write the image to a string 30 | try: 31 | s = StringIO() 32 | except: 33 | s = BytesIO() 34 | scipy.misc.toimage(img).save(s, format="png") 35 | 36 | # Create an Image object 37 | img_sum = tf.Summary.Image(encoded_image_string=s.getvalue(), 38 | height=img.shape[0], 39 | width=img.shape[1]) 40 | # Create a Summary value 41 | img_summaries.append(tf.Summary.Value(tag='%s/%d' % (tag, i), image=img_sum)) 42 | 43 | # Create and write Summary 44 | summary = tf.Summary(value=img_summaries) 45 | self.writer.add_summary(summary, step) 46 | 47 | def histo_summary(self, tag, values, step, bins=1000): 48 | """Log a histogram of the tensor of values.""" 49 | 50 | # Create a histogram using numpy 51 | # print (values) 52 | 53 | counts, bin_edges = np.histogram(values, bins=bins) 54 | 55 | 56 | # Fill the fields of the histogram proto 57 | hist = tf.HistogramProto() 58 | hist.min = float(np.min(values)) 59 | hist.max = float(np.max(values)) 60 | hist.num = int(np.prod(values.shape)) 61 | hist.sum = float(np.sum(values)) 62 | hist.sum_squares = float(np.sum(values**2)) 63 | 64 | # Drop the start of the first bin 65 | bin_edges = bin_edges[1:] 66 | 67 | # Add bin edges and counts 68 | for edge in bin_edges: 69 | hist.bucket_limit.append(edge) 70 | for c in counts: 71 | hist.bucket.append(c) 72 | 73 | # Create and write Summary 74 | summary = tf.Summary(value=[tf.Summary.Value(tag=tag, histo=hist)]) 75 | self.writer.add_summary(summary, step) 76 | self.writer.flush() -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | 
import argparse 2 | import os 3 | import time 4 | from torch.backends import cudnn 5 | 6 | from prepro import * 7 | from trainer import * 8 | 9 | 10 | def main(args): 11 | cuda.set_device(int(args.gpu_num)) 12 | cudnn.benchmark = True 13 | 14 | # Load dataset 15 | train_file = os.path.join(args.data_path, "data_{}_{}_{}_{}.json".format(args.dataset, args.src_lang, 16 | args.trg_lang, args.max_len)) 17 | val_file = os.path.join(args.data_path, "data_dev_{}_{}_{}.json".format(args.src_lang, args.trg_lang, args.max_len)) 18 | 19 | start_time = time.time() 20 | if os.path.isfile(train_file) and os.path.isfile(val_file): 21 | print ("Loading data..") 22 | dp = DataPreprocessor() 23 | train_dataset, val_dataset, vocabs = dp.load_data(train_file, val_file) 24 | else: 25 | print ("Preprocessing data..") 26 | dp = DataPreprocessor() 27 | train_dataset, val_dataset, vocabs = dp.preprocess(args.train_path, args.val_path, train_file, val_file, 28 | args.src_lang, args.trg_lang, args.max_len) 29 | 30 | 31 | print ("Elapsed Time: %1.3f \n" %(time.time() - start_time)) 32 | 33 | print ("=========== Data Stat ===========") 34 | print ("Train: ", len(train_dataset)) 35 | print ("val: ", len(val_dataset)) 36 | print ("=================================") 37 | 38 | train_loader = data.BucketIterator(dataset=train_dataset, batch_size=args.batch_size, 39 | repeat=False, shuffle=True, sort_within_batch=True, 40 | sort_key=lambda x: len(x.src)) 41 | val_loader = data.BucketIterator(dataset=val_dataset, batch_size=args.batch_size, 42 | repeat=False, shuffle=True, sort_within_batch=True, 43 | sort_key=lambda x: len(x.src)) 44 | 45 | trainer = Trainer(train_loader, val_loader, vocabs, args) 46 | trainer.train() 47 | 48 | 49 | if __name__ == '__main__': 50 | 51 | parser = argparse.ArgumentParser() 52 | 53 | # Language setting 54 | parser.add_argument('--dataset', type=str, default='europarl') 55 | parser.add_argument('--src_lang', type=str, default='fr') 56 | parser.add_argument('--trg_lang', type=str, default='en') 57 | parser.add_argument('--max_len', type=int, default=50) 58 | 59 | # Model hyper-parameters 60 | parser.add_argument('--lr', type=float, default=0.0001) 61 | parser.add_argument('--grad_clip', type=float, default=2) 62 | parser.add_argument('--num_layer', type=int, default=2) 63 | parser.add_argument('--embed_dim', type=int, default=512) 64 | parser.add_argument('--hidden_dim', type=int, default=1024) 65 | 66 | # Training setting 67 | parser.add_argument('--batch_size', type=int, default=40) 68 | parser.add_argument('--num_epoch', type=int, default=100) 69 | 70 | # Path 71 | parser.add_argument('--data_path', type=str, default='./data/') 72 | parser.add_argument('--train_path', type=str, default='./data/training/europarl-v7.fr-en') 73 | parser.add_argument('--val_path', type=str, default='./data/dev/newstest2013') 74 | 75 | # Dir. 76 | parser.add_argument('--log', type=str, default='log') 77 | parser.add_argument('--sample', type=str, default='sample') 78 | 79 | # Misc. 
80 | parser.add_argument('--gpu_num', type=int, default=0) 81 | 82 | args = parser.parse_args() 83 | print (args) 84 | main(args) 85 | -------------------------------------------------------------------------------- /model/Attention.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.autograd import Variable 4 | import torch.nn.functional as F 5 | import numpy as np 6 | 7 | from utils import * 8 | 9 | 10 | class Attention(nn.Module): 11 | def __init__(self, hidden_dim): 12 | super(Attention, self).__init__() 13 | 14 | self.enc_h_in = nn.Linear(hidden_dim*2, hidden_dim) 15 | self.prev_s_in = nn.Linear(hidden_dim, hidden_dim) 16 | self.linear = nn.Linear(hidden_dim, 1) 17 | 18 | def forward(self, enc_h, prev_s): 19 | ''' 20 | enc_h : B x S x 2*H 21 | prev_s : B x 1 x H 22 | ''' 23 | seq_len = enc_h.size(1) 24 | 25 | enc_h_in = self.enc_h_in(enc_h) # B x S x H 26 | prev_s = self.prev_s_in(prev_s).unsqueeze(1) # B x 1 x H 27 | 28 | h = F.tanh(enc_h_in + prev_s.expand_as(enc_h_in)) # B x S x H 29 | h = self.linear(h) # B x S x 1 30 | 31 | alpha = F.softmax(h) 32 | ctx = torch.bmm(alpha.transpose(2,1), enc_h).squeeze(1) # B x 1 x 2*H 33 | 34 | return ctx 35 | -------------------------------------------------------------------------------- /model/Decoder.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.autograd import Variable 4 | import torch.nn.functional as F 5 | 6 | from utils import * 7 | from model import * 8 | 9 | class Decoder(nn.Module): 10 | def __init__(self, vocab_size, embed_dim, hidden_dim, max_len, trg_soi): 11 | super(Decoder, self).__init__() 12 | self.hidden_dim = hidden_dim 13 | self.max_len = max_len 14 | self.vocab_size = vocab_size 15 | self.trg_soi = trg_soi 16 | 17 | self.embed = nn.Embedding(vocab_size, embed_dim) 18 | self.attention = Attention(hidden_dim) 19 | self.decodercell = DecoderCell(embed_dim, hidden_dim) 20 | self.dec2word = nn.Linear(hidden_dim, vocab_size) 21 | 22 | 23 | def forward(self, enc_h, prev_s, target=None): 24 | ''' 25 | enc_h : B x S x 2*H 26 | prev_s : B x H 27 | ''' 28 | 29 | if target is not None: 30 | batch_size, target_len = target.size(0), target.size(1) 31 | 32 | dec_h = Variable(torch.zeros(batch_size, target_len, self.hidden_dim)) 33 | 34 | if torch.cuda.is_available(): 35 | dec_h = dec_h.cuda() 36 | 37 | target = self.embed(target) 38 | for i in range(target_len): 39 | ctx = self.attention(enc_h, prev_s) 40 | prev_s = self.decodercell(target[:, i], prev_s, ctx) 41 | dec_h[:,i,:] = prev_s.unsqueeze(1) 42 | 43 | outputs = self.dec2word(dec_h) 44 | 45 | else: 46 | batch_size = enc_h.size(0) 47 | target = Variable(torch.LongTensor([self.trg_soi] * batch_size), volatile=True).view(batch_size, 1) 48 | outputs = Variable(torch.zeros(batch_size, self.max_len, self.vocab_size)) 49 | 50 | if torch.cuda.is_available(): 51 | target = target.cuda() 52 | outputs = outputs.cuda() 53 | 54 | for i in range(self.max_len): 55 | target = self.embed(target).squeeze(1) 56 | ctx = self.attention(enc_h, prev_s) 57 | prev_s = self.decodercell(target, prev_s, ctx) 58 | output = self.dec2word(prev_s) 59 | outputs[:,i,:] = output 60 | target = output.topk(1)[1] 61 | 62 | return outputs 63 | 64 | 65 | class DecoderCell(nn.Module): 66 | def __init__(self, embed_dim, hidden_dim): 67 | super(DecoderCell, self).__init__() 68 | 69 | self.input_weights = nn.Linear(embed_dim, hidden_dim*2) 70 
| self.hidden_weights = nn.Linear(hidden_dim, hidden_dim*2) 71 | self.ctx_weights = nn.Linear(hidden_dim*2, hidden_dim*2) 72 | 73 | self.input_in = nn.Linear(embed_dim, hidden_dim) 74 | self.hidden_in = nn.Linear(hidden_dim, hidden_dim) 75 | self.ctx_in = nn.Linear(hidden_dim*2, hidden_dim) 76 | 77 | 78 | def forward(self, trg_word, prev_s, ctx): 79 | ''' 80 | trg_word : B x E 81 | prev_s : B x H 82 | ctx : B x 2*H 83 | ''' 84 | gates = self.input_weights(trg_word) + self.hidden_weights(prev_s) + self.ctx_weights(ctx) 85 | reset_gate, update_gate = gates.chunk(2,1) 86 | 87 | reset_gate = F.sigmoid(reset_gate) 88 | update_gate = F.sigmoid(update_gate) 89 | 90 | prev_s_tilde = self.input_in(trg_word) + self.hidden_in(prev_s) + self.ctx_in(ctx) 91 | prev_s_tilde = F.tanh(prev_s_tilde) 92 | 93 | prev_s = torch.mul((1-reset_gate), prev_s) + torch.mul(reset_gate, prev_s_tilde) 94 | return prev_s 95 | 96 | -------------------------------------------------------------------------------- /model/Encoder.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.autograd import Variable 4 | 5 | from utils import * 6 | 7 | 8 | class Encoder(nn.Module): 9 | def __init__(self, vocab_size, embed_dim, hidden_dim, num_layers=2): 10 | super(Encoder, self).__init__() 11 | self.num_layers = num_layers 12 | self.hidden_dim = hidden_dim 13 | 14 | self.embedding = nn.Embedding(vocab_size, embed_dim) 15 | self.gru = nn.GRU(embed_dim, self.hidden_dim, self.num_layers, batch_first=True, bidirectional=True, ) 16 | 17 | def forward(self, source, src_length=None, hidden=None): 18 | ''' 19 | source: B x T 20 | ''' 21 | batch_size = source.size(0) 22 | src_embed = self.embedding(source) 23 | 24 | if hidden is None: 25 | h_size = (self.num_layers *2, batch_size, self.hidden_dim) 26 | enc_h_0 = Variable(src_embed.data.new(*h_size).zero_(), requires_grad=False) 27 | 28 | if src_length is not None: 29 | src_embed = nn.utils.rnn.pack_padded_sequence(src_embed, src_length, batch_first=True) 30 | 31 | enc_h, enc_h_t = self.gru(src_embed, enc_h_0) 32 | 33 | if src_length is not None: 34 | enc_h, _ = nn.utils.rnn.pad_packed_sequence(enc_h, batch_first=True) 35 | 36 | return enc_h, enc_h_t 37 | -------------------------------------------------------------------------------- /model/Seq2Seq.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import numpy as np 5 | 6 | from utils import * 7 | from model import * 8 | 9 | class Seq2Seq(nn.Module): 10 | def __init__(self, src_nword, trg_nword, num_layer, embed_dim, hidden_dim, max_len, trg_soi): 11 | super(Seq2Seq, self).__init__() 12 | 13 | self.hidden_dim = hidden_dim 14 | self.trg_nword = trg_nword 15 | 16 | self.encoder = Encoder(src_nword, embed_dim, hidden_dim) 17 | self.linear = nn.Linear(hidden_dim, hidden_dim) 18 | self.decoder = Decoder(trg_nword, embed_dim, hidden_dim, max_len, trg_soi) 19 | 20 | 21 | def forward(self, source, src_length=None, target=None): 22 | batch_size = source.size(0) 23 | 24 | enc_h, enc_h_t = self.encoder(source, src_length) # B x S x 2*H / 2 x B x H 25 | 26 | dec_h0 = enc_h_t[-1] # B x H 27 | dec_h0 = F.tanh(self.linear(dec_h0)) # B x 1 x 2*H 28 | 29 | out = self.decoder(enc_h, dec_h0, target) # B x S x H 30 | out = F.log_softmax(out.contiguous().view(-1, self.trg_nword)) 31 | 32 | return out 33 | 
-------------------------------------------------------------------------------- /model/__init__.py: -------------------------------------------------------------------------------- 1 | from model.Encoder import * 2 | from model.Attention import * 3 | from model.Decoder import * 4 | from model.Seq2Seq import * -------------------------------------------------------------------------------- /prepro.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torchtext import data 3 | from torchtext import datasets 4 | import time 5 | import re 6 | import spacy 7 | import os 8 | from tqdm import tqdm 9 | 10 | SOS_WORD = '' 11 | EOS_WORD = '' 12 | PAD_WORD = '' 13 | 14 | 15 | class MaxlenTranslationDataset(data.Dataset): 16 | # Code modified from 17 | # https://github.com/pytorch/text/blob/master/torchtext/datasets/translation.py 18 | # to be able to control the max length of the source and target sentences 19 | 20 | def __init__(self, path, exts, fields, max_len=None, **kwargs): 21 | 22 | if not isinstance(fields[0], (tuple, list)): 23 | fields = [('src', fields[0]), ('trg', fields[1])] 24 | 25 | src_path, trg_path = tuple(os.path.expanduser(path + x) for x in exts) 26 | 27 | examples = [] 28 | with open(src_path) as src_file, open(trg_path) as trg_file: 29 | for src_line, trg_line in tqdm(zip(src_file, trg_file)): 30 | src_line, trg_line = src_line.split(' '), trg_line.split(' ') 31 | if max_len is not None: 32 | src_line = src_line[:max_len] 33 | src_line = str(' '.join(src_line)) 34 | trg_line = trg_line[:max_len] 35 | trg_line = str(' '.join(trg_line)) 36 | 37 | if src_line != '' and trg_line != '': 38 | examples.append(data.Example.fromlist( 39 | [src_line, trg_line], fields)) 40 | 41 | super(MaxlenTranslationDataset, self).__init__(examples, fields, **kwargs) 42 | 43 | 44 | class DataPreprocessor(object): 45 | def __init__(self): 46 | self.src_field, self.trg_field = self.generate_fields() 47 | 48 | def preprocess(self, train_path, val_path, train_file, val_file, src_lang, trg_lang, max_len=None): 49 | # Generating torchtext dataset class 50 | print ("Preprocessing train dataset...") 51 | train_dataset = self.generate_data(train_path, src_lang, trg_lang, max_len) 52 | 53 | print ("Saving train dataset...") 54 | self.save_data(train_file, train_dataset) 55 | 56 | print ("Preprocessing validation dataset...") 57 | val_dataset = self.generate_data(val_path, src_lang, trg_lang, max_len) 58 | 59 | print ("Saving validation dataset...") 60 | self.save_data(val_file, val_dataset) 61 | 62 | # Building field vocabulary 63 | self.src_field.build_vocab(train_dataset, max_size=30000) 64 | self.trg_field.build_vocab(train_dataset, max_size=30000) 65 | 66 | src_vocab, trg_vocab, src_inv_vocab, trg_inv_vocab = self.generate_vocabs() 67 | 68 | vocabs = {'src_vocab': src_vocab, 'trg_vocab':trg_vocab, 69 | 'src_inv_vocab':src_inv_vocab, 'trg_inv_vocab':trg_inv_vocab} 70 | 71 | return train_dataset, val_dataset, vocabs 72 | 73 | def load_data(self, train_file, val_file): 74 | 75 | # Loading saved data 76 | train_dataset = torch.load(train_file) 77 | train_examples = train_dataset['examples'] 78 | 79 | val_dataset = torch.load(val_file) 80 | val_examples = val_dataset['examples'] 81 | 82 | # Generating torchtext dataset class 83 | fields = [('src', self.src_field), ('trg', self.trg_field)] 84 | train_dataset = data.Dataset(fields=fields, examples=train_examples) 85 | val_dataset = data.Dataset(fields=fields, examples=val_examples) 86 | 87 | # Building 
field vocabulary 88 | self.src_field.build_vocab(train_dataset, max_size=30000) 89 | self.trg_field.build_vocab(train_dataset, max_size=30000) 90 | 91 | src_vocab, trg_vocab, src_inv_vocab, trg_inv_vocab = self.generate_vocabs() 92 | vocabs = {'src_vocab': src_vocab, 'trg_vocab':trg_vocab, 93 | 'src_inv_vocab':src_inv_vocab, 'trg_inv_vocab':trg_inv_vocab} 94 | 95 | return train_dataset, val_dataset, vocabs 96 | 97 | 98 | def save_data(self, data_file, dataset): 99 | 100 | examples = vars(dataset)['examples'] 101 | dataset = {'examples': examples} 102 | 103 | torch.save(dataset, data_file) 104 | 105 | def generate_fields(self): 106 | src_field = data.Field(tokenize=data.get_tokenizer('spacy'), 107 | init_token=SOS_WORD, 108 | eos_token=EOS_WORD, 109 | pad_token=PAD_WORD, 110 | include_lengths=True, 111 | batch_first=True) 112 | 113 | trg_field = data.Field(tokenize=data.get_tokenizer('spacy'), 114 | init_token=SOS_WORD, 115 | eos_token=EOS_WORD, 116 | pad_token=PAD_WORD, 117 | include_lengths=True, 118 | batch_first=True) 119 | 120 | return src_field, trg_field 121 | 122 | def generate_data(self, data_path, src_lang, trg_lang, max_len=None): 123 | exts = ('.'+src_lang, '.'+trg_lang) 124 | 125 | dataset = MaxlenTranslationDataset( 126 | path=data_path, 127 | exts=(exts), 128 | fields=(self.src_field, self.trg_field), 129 | max_len=max_len) 130 | 131 | return dataset 132 | 133 | def generate_vocabs(self): 134 | # Define string to index vocabs 135 | src_vocab = self.src_field.vocab.stoi 136 | trg_vocab = self.trg_field.vocab.stoi 137 | 138 | # Define index to string vocabs 139 | src_inv_vocab = self.src_field.vocab.itos 140 | trg_inv_vocab = self.trg_field.vocab.itos 141 | 142 | return src_vocab, trg_vocab, src_inv_vocab, trg_inv_vocab 143 | 144 | 145 | 146 | -------------------------------------------------------------------------------- /trainer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import cuda 3 | import torch.nn as nn 4 | import torch.optim as optim 5 | from torch.autograd import Variable 6 | from torch.optim.lr_scheduler import * 7 | 8 | import numpy as np 9 | import math 10 | import time 11 | import os 12 | 13 | from logger import Logger 14 | from tqdm import tqdm 15 | 16 | from prepro import * 17 | from utils import * 18 | from model.Seq2Seq import Seq2Seq 19 | from bleu import * 20 | 21 | 22 | class Trainer(object): 23 | def __init__(self, train_loader, val_loader, vocabs, args): 24 | 25 | # Language setting 26 | self.max_len = args.max_len 27 | 28 | # Data Loader 29 | self.train_loader = train_loader 30 | self.val_loader = val_loader 31 | 32 | # Path 33 | self.data_path = args.data_path 34 | self.sample_path = os.path.join('./samples/' + args.sample) 35 | self.log_path = os.path.join('./logs/' + args.log) 36 | 37 | if not os.path.exists(self.sample_path): os.makedirs(self.sample_path) 38 | if not os.path.exists(self.log_path): os.makedirs(self.log_path) 39 | 40 | # Hyper-parameters 41 | self.lr = args.lr 42 | self.grad_clip = args.grad_clip 43 | self.embed_dim = args.embed_dim 44 | self.hidden_dim = args.hidden_dim 45 | self.num_layer = args.num_layer 46 | 47 | # Training setting 48 | self.batch_size = args.batch_size 49 | self.num_epoch = args.num_epoch 50 | self.iter_per_epoch = len(train_loader) 51 | 52 | # Log 53 | self.logger = open(self.log_path+'/log.txt','w') 54 | self.sample = open(self.sample_path+'/sample.txt','w') 55 | self.tf_log = Logger(self.log_path) 56 | 57 | self.build_model(vocabs) 58 | 
59 | 60 | def build_model(self, vocabs): 61 | # build dictionaries 62 | self.src_vocab = vocabs['src_vocab'] 63 | self.trg_vocab = vocabs['trg_vocab'] 64 | self.src_inv_vocab = vocabs['src_inv_vocab'] 65 | self.trg_inv_vocab = vocabs['trg_inv_vocab'] 66 | self.trg_soi = self.trg_vocab[SOS_WORD] 67 | 68 | self.src_nword = len(self.src_vocab) 69 | self.trg_nword = len(self.trg_vocab) 70 | 71 | # build the model 72 | self.model = Seq2Seq(self.src_nword, self.trg_nword, self.num_layer, self.embed_dim, self.hidden_dim, 73 | self.max_len, self.trg_soi) 74 | 75 | # set the criterion and optimizer 76 | self.criterion = nn.NLLLoss() 77 | self.optimizer = optim.Adam(self.model.parameters(), lr=self.lr) 78 | self.scheduler = optim.lr_scheduler.StepLR(self.optimizer, 1, gamma=0.8) 79 | 80 | if torch.cuda.is_available(): 81 | self.model.cuda() 82 | 83 | print (self.model) 84 | print (self.criterion) 85 | print (self.optimizer) 86 | 87 | 88 | 89 | def train(self): 90 | self.best_bleu = .0 91 | 92 | for epoch in range(self.num_epoch): 93 | #self.scheduler.step() 94 | self.train_loss = AverageMeter() 95 | self.train_bleu = AverageMeter() 96 | start_time = time.time() 97 | 98 | for i, batch in enumerate(tqdm(self.train_loader)): 99 | self.model.train() 100 | 101 | src_input = batch.src[0]; src_length = batch.src[1] 102 | trg_input = batch.trg[0][:,:-1]; trg_output=batch.trg[0][:,1:]; trg_length = batch.trg[1] 103 | batch_size, trg_len = trg_input.size(0), trg_input.size(1) 104 | 105 | decoder_logit = self.model(src_input, src_length.tolist(), trg_input) 106 | pred = decoder_logit.view(batch_size, trg_len, -1) 107 | 108 | self.optimizer.zero_grad() 109 | loss = self.criterion(decoder_logit, trg_output.contiguous().view(-1)) 110 | loss.backward() 111 | torch.nn.utils.clip_grad_norm(self.model.parameters(), self.grad_clip) 112 | self.optimizer.step() 113 | 114 | # Compute BLEU score and Loss 115 | pred_sents = [] 116 | trg_sents = [] 117 | for j in range(batch_size): 118 | pred_sent = self.get_sentence(tensor2np(pred[j]).argmax(axis=-1), 'trg') 119 | trg_sent = self.get_sentence(tensor2np(trg_output[j]), 'trg') 120 | pred_sents.append(pred_sent) 121 | trg_sents.append(trg_sent) 122 | bleu_value = get_bleu(pred_sents, trg_sents) 123 | self.train_bleu.update(bleu_value, 1) 124 | self.train_loss.update(loss.data[0], batch_size) 125 | 126 | if i % 5000 == 0 and i != 0: 127 | self.print_train_result(epoch, i, start_time) 128 | self.print_sample(batch_size, epoch, i, src_input, trg_output, pred) 129 | self.eval(epoch, i) 130 | self.train_loss = AverageMeter() 131 | self.train_bleu = AverageMeter() 132 | start_time = time.time() 133 | 134 | # Logging tensorboard 135 | info = { 136 | 'epoch': epoch, 137 | 'train_iter': i, 138 | 'train_loss': self.train_loss.avg, 139 | 'train_bleu': self.train_bleu.avg 140 | } 141 | for tag, value in info.items(): 142 | self.tf_log.scalar_summary(tag, value, (epoch * self.iter_per_epoch)+i+1) 143 | 144 | self.print_train_result(epoch, i, start_time) 145 | self.print_sample(batch_size, epoch, i, src_input, trg_output, pred) 146 | self.eval(epoch, i) 147 | 148 | 149 | def eval(self, epoch, train_iter): 150 | self.model.eval() 151 | val_bleu = AverageMeter() 152 | start_time = time.time() 153 | 154 | for i, batch in enumerate(tqdm(self.val_loader)): 155 | src_input = batch.src[0]; src_length = batch.src[1] 156 | trg_input = batch.trg[0][:,:-1]; trg_output=batch.trg[0][:,1:]; trg_length = batch.trg[1] 157 | batch_size, trg_len = trg_input.size(0), trg_input.size(1) 158 | 159 | 
decoder_logit = self.model(src_input, src_length.tolist()) 160 | pred = decoder_logit.view(batch_size, self.max_len, -1) 161 | 162 | # Compute BLEU score 163 | pred_sents = [] 164 | trg_sents = [] 165 | for j in range(batch_size): 166 | pred_sent = self.get_sentence(tensor2np(pred[j]).argmax(axis=-1), 'trg') 167 | trg_sent = self.get_sentence(tensor2np(trg_output[j]), 'trg') 168 | pred_sents.append(pred_sent) 169 | trg_sents.append(trg_sent) 170 | bleu_value = get_bleu(pred_sents, trg_sents) 171 | val_bleu.update(bleu_value, 1) 172 | 173 | self.print_valid_result(epoch, train_iter, val_bleu.avg, start_time) 174 | self.print_sample(batch_size, epoch, train_iter, src_input, trg_output, pred) 175 | 176 | # Save model if bleu score is higher than the best 177 | if self.best_bleu < val_bleu.avg: 178 | self.best_bleu = val_bleu.avg 179 | checkpoint = { 180 | 'model': self.model, 181 | 'epoch': epoch 182 | } 183 | torch.save(checkpoint, self.log_path+'/Model_e%d_i%d_%.3f.pt' % (epoch, train_iter, val_bleu.avg)) 184 | 185 | # Logging tensorboard 186 | info = { 187 | 'epoch': epoch, 188 | 'train_iter': train_iter, 189 | 'train_loss': self.train_loss.avg, 190 | 'train_bleu': self.train_bleu.avg, 191 | 'bleu': val_bleu.avg 192 | } 193 | 194 | for tag, value in info.items(): 195 | self.tf_log.scalar_summary(tag, value, (epoch * self.iter_per_epoch)+train_iter+1) 196 | 197 | 198 | def get_sentence(self, sentence, side): 199 | def _eos_parsing(sentence): 200 | if EOS_WORD in sentence: 201 | return sentence[:sentence.index(EOS_WORD)+1] 202 | else: 203 | return sentence 204 | 205 | # index sentence to word sentence 206 | if side == 'trg': 207 | sentence = [self.trg_inv_vocab[x] for x in sentence] 208 | else: 209 | sentence = [self.src_inv_vocab[x] for x in sentence] 210 | 211 | return _eos_parsing(sentence) 212 | 213 | 214 | def print_train_result(self, epoch, train_iter, start_time): 215 | mode = ("================================= Train ====================================") 216 | print (mode, '\n') 217 | self.logger.write(mode+'\n') 218 | self.sample.write(mode+'\n') 219 | 220 | message = "Train epoch: %d iter: %d train loss: %1.3f train bleu: %1.3f elapsed: %1.3f " % ( 221 | epoch, train_iter, self.train_loss.avg, self.train_bleu.avg, time.time() - start_time) 222 | print (message, '\n\n') 223 | self.logger.write(message+'\n\n') 224 | 225 | 226 | def print_valid_result(self, epoch, train_iter, val_bleu, start_time): 227 | mode = ("================================= Validation ====================================") 228 | print (mode, '\n') 229 | self.logger.write(mode+'\n') 230 | self.sample.write(mode+'\n') 231 | 232 | message = "Train epoch: %d iter: %d train loss: %1.3f train_bleu: %1.3f val bleu score: %1.3f elapsed: %1.3f " % ( 233 | epoch, train_iter, self.train_loss.avg, self.train_bleu.avg, val_bleu, time.time() - start_time) 234 | print (message, '\n\n' ) 235 | self.logger.write(message+'\n\n') 236 | 237 | 238 | def print_sample(self, batch_size, epoch, train_iter, source, target, pred): 239 | 240 | def _write_and_print(message): 241 | for x in message: 242 | self.sample.write(x+'\n') 243 | print ((" ").join(message)) 244 | 245 | random_idx = randomChoice(batch_size) 246 | src_sample = self.get_sentence(tensor2np(source)[random_idx], 'src') 247 | trg_sample = self.get_sentence(tensor2np(target)[random_idx], 'trg') 248 | pred_sample = self.get_sentence(tensor2np(pred[random_idx]).argmax(axis=-1), 'trg') 249 | 250 | src_message = ["Source Sentence: ", (" ").join(src_sample), '\n'] 251 | 
trg_message = ["Target Sentence: ", (" ").join(trg_sample), '\n'] 252 | pred_message = ["Generated Sentence: ", (" ").join(pred_sample), '\n'] 253 | 254 | message = "Train epoch: %d iter: %d " % (epoch, train_iter) 255 | self.sample.write(message+'\n') 256 | _write_and_print(src_message) 257 | _write_and_print(trg_message) 258 | _write_and_print(pred_message) 259 | self.sample.write('\n\n\n') 260 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import random 2 | 3 | def tensor2np(tensor): 4 | return tensor.data.cpu().numpy() 5 | 6 | def randomChoice(batch_size): 7 | return random.randint(0, batch_size - 1) 8 | 9 | class AverageMeter(object): 10 | """ 11 | Computes and stores the average and current value 12 | Borrowed from ImageNet training in PyTorch project 13 | https://github.com/pytorch/examples/tree/master/imagenet 14 | """ 15 | def __init__(self): 16 | self.reset() 17 | 18 | def reset(self): 19 | self.val = 0 20 | self.avg = 0 21 | self.sum = 0 22 | self.count = 0 23 | 24 | def update(self, val, n=1): 25 | self.val = val 26 | self.sum += val * n 27 | self.count += n 28 | self.avg = self.sum / self.count 29 | --------------------------------------------------------------------------------