├── assets
│   ├── d_loss.png
│   ├── g_loss.png
│   └── w_distance.png
├── README.md
├── models
│   ├── __init__.py
│   ├── discriminator.py
│   └── generator.py
├── LICENSE
├── utils.py
├── .gitignore
├── logger.py
└── train.py
/assets/d_loss.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/keon/seq2seq-wgan/HEAD/assets/d_loss.png
--------------------------------------------------------------------------------
/assets/g_loss.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/keon/seq2seq-wgan/HEAD/assets/g_loss.png
--------------------------------------------------------------------------------
/assets/w_distance.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/keon/seq2seq-wgan/HEAD/assets/w_distance.png
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # [WIP] seq2seq-wgan
2 | Improved Training of Wasserstein GANs for Neural Machine Translation
3 | 
4 | 
5 | Based on the paper [Improved Training of Wasserstein GANs](https://arxiv.org/abs/1704.00028).
6 | 
--------------------------------------------------------------------------------
/models/__init__.py:
--------------------------------------------------------------------------------
1 | from .discriminator import Discriminator
2 | from .generator import Encoder, Decoder, Seq2Seq
3 | 
4 | __all__ = [
5 |     'Encoder',
6 |     'Decoder',
7 |     'Seq2Seq',
8 |     'Discriminator',
9 | ]
10 | 
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 | 
3 | Copyright (c) 2017 Keon
4 | 
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 | 
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 | 
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 | 
--------------------------------------------------------------------------------
/models/discriminator.py:
--------------------------------------------------------------------------------
1 | import math
2 | import torch
3 | import random
4 | from torch import nn
5 | from torch.autograd import Variable
6 | import torch.nn.functional as F
7 | 
8 | 
9 | class Discriminator(nn.Module):
10 |     def __init__(self, vocab_size, embed_size, hidden_size,
11 |                  n_layers=1, dropout=0.2):
12 |         super(Discriminator, self).__init__()
13 |         self.embed_size = embed_size
14 |         self.hidden_size = hidden_size
15 |         self.vocab_size = vocab_size
16 |         self.n_layers = n_layers
17 | 
18 |         self.embed = nn.Linear(vocab_size, embed_size, bias=False)
19 |         self.dropout = nn.Dropout(dropout)
20 |         self.gru = nn.GRU(hidden_size + embed_size, hidden_size,
21 |                           n_layers, dropout=dropout)
22 |         self.out = nn.Linear(hidden_size, 1)
23 | 
24 |     def forward(self, input, context):
25 |         """
26 |         input: I x B x Vocab (one-hot or softmax scores over the target vocabulary)
27 |         context: I x B x H (attention context produced by the generator's decoder)
28 |         returns: B x 1 critic scores
29 |         """
30 |         # Project the vocabulary-sized inputs into the embedding space
31 |         embedded = self.embed(input)  # (I,B,E)
32 |         embedded = self.dropout(embedded)
33 |         # Combine embedded tokens and attended context, run through the GRU
34 |         rnn_input = torch.cat([embedded, context], 2)
35 |         output, hidden = self.gru(rnn_input, None)
36 |         out = self.out(output[-1])  # [b, h] -> [b, 1]
37 |         return out
38 | 
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
1 | import re
2 | import spacy
3 | import torch
4 | from torchtext.data import Field, BucketIterator
5 | from torchtext.datasets import Multi30k
6 | 
7 | 
8 | def enable_gradients(model):
9 |     for p in model.parameters():
10 |         p.requires_grad = True
11 | 
12 | 
13 | def disable_gradients(model):
14 |     for p in model.parameters():
15 |         p.requires_grad = False
16 | 
17 | 
18 | def to_onehot(index, vocab_size):
19 |     batch_size, seq_len = index.size(0), index.size(1)
20 |     onehot = torch.FloatTensor(batch_size, seq_len, vocab_size).zero_()
21 |     onehot.scatter_(2, index.data.cpu().unsqueeze(2), 1)
22 |     return onehot
23 | 
24 | 
25 | def load_dataset(batch_size):
26 |     spacy_de = spacy.load('de')
27 |     spacy_en = spacy.load('en')
28 |     url = re.compile('(<url>.*</url>)')
29 | 
30 |     # word-level tokenizers (currently unused: the fields below are character-level)
31 |     def tokenize_de(text):
32 |         return [tok.text for tok in spacy_de.tokenizer(url.sub('@URL@', text))]
33 | 
34 |     def tokenize_en(text):
35 |         return [tok.text for tok in spacy_en.tokenizer(url.sub('@URL@', text))]
36 | 
37 |     # tokenize=list splits each sentence into characters
38 |     DE = Field(tokenize=list, include_lengths=True,
39 |                init_token='<sos>', eos_token='<eos>')
40 |     EN = Field(tokenize=list, include_lengths=True,
41 |                init_token='<sos>', eos_token='<eos>')
42 |     train, val, test = Multi30k.splits(exts=('.de', '.en'), fields=(DE, EN))
43 |     DE.build_vocab(train.src)
44 |     EN.build_vocab(train.trg)
45 |     train_iter, val_iter, test_iter = BucketIterator.splits(
46 |         (train, val, test), batch_size=batch_size, repeat=False)
47 |     return train_iter, val_iter, test_iter, DE, EN
48 | 
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 | 
6 | # C extensions
7 | *.so
8 | 
9 | # Distribution / packaging
10 | .Python
11 | env/
12 | build/
13 | develop-eggs/
14 | dist/
15 | downloads/
16 | eggs/
17 | .eggs/
18 | lib/
19 | lib64/
20 | parts/
21 | sdist/
22 | var/
23 | wheels/
24 | *.egg-info/
25 | .installed.cfg
26 | *.egg
27 | 
28 | # PyInstaller
29 | #  Usually these files are written by a python script from a template
30 | #  before PyInstaller builds the exe, so as to inject date/other infos into it.
31 | *.manifest
32 | *.spec
33 | 
34 | # Installer logs
35 | pip-log.txt
36 | pip-delete-this-directory.txt
37 | 
38 | # Unit test / coverage reports
39 | htmlcov/
40 | .tox/
41 | .coverage
42 | .coverage.*
43 | .cache
44 | nosetests.xml
45 | coverage.xml
46 | *.cover
47 | .hypothesis/
48 | 
49 | # Translations
50 | *.mo
51 | *.pot
52 | 
53 | # Django stuff:
54 | *.log
55 | local_settings.py
56 | 
57 | # Flask stuff:
58 | instance/
59 | .webassets-cache
60 | 
61 | # Scrapy stuff:
62 | .scrapy
63 | 
64 | # Sphinx documentation
65 | docs/_build/
66 | 
67 | # PyBuilder
68 | target/
69 | 
70 | # Jupyter Notebook
71 | .ipynb_checkpoints
72 | 
73 | # pyenv
74 | .python-version
75 | 
76 | # celery beat schedule file
77 | celerybeat-schedule
78 | 
79 | # SageMath parsed files
80 | *.sage.py
81 | 
82 | # dotenv
83 | .env
84 | 
85 | # virtualenv
86 | .venv
87 | venv/
88 | ENV/
89 | 
90 | # Spyder project settings
91 | .spyderproject
92 | .spyproject
93 | 
94 | # Rope project settings
95 | .ropeproject
96 | 
97 | # mkdocs documentation
98 | /site
99 | 
100 | # mypy
101 | .mypy_cache/
102 | 
103 | .data
104 | .save
105 | .tmp
106 | .samples
--------------------------------------------------------------------------------
/logger.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | from visdom import Visdom
4 | 
5 | 
6 | class VisdomWriter(object):
7 |     def __init__(self, title, xlabel='Epoch', ylabel='Loss'):
8 |         """Extended Visdom Writer"""
9 |         self.vis = Visdom()
10 |         assert self.vis.check_connection()
11 |         self.title = title
12 |         self.xlabel = xlabel
13 |         self.ylabel = ylabel
14 |         self.x = 0
15 |         self.win = None
16 | 
17 |     def update_text(self, text):
18 |         """Text memo (usually used to note hyperparameter configurations)"""
19 |         self.vis.text(text)
20 | 
21 |     def update(self, y):
22 |         """Append a point to the curve (X: step / Y: loss)"""
23 |         self.x += 1
24 |         if self.win is None:
25 |             self.win = self.vis.line(
26 |                 X=np.array([self.x]),
27 |                 Y=np.array([y]),
28 |                 opts=dict(
29 |                     title=self.title,
30 |                     xlabel=self.xlabel,
31 |                     ylabel=self.ylabel,
32 |                 ))
33 |         else:
34 |             self.vis.updateTrace(
35 |                 X=np.array([self.x]),
36 |                 Y=np.array([y]),
37 |                 win=self.win)
38 | 
39 | 
40 | def log_samples(file_path, samples, EN, is_output=True):
41 |     eos = EN.vocab.stoi['<eos>']
42 |     if is_output:
43 |         _, argmax = torch.max(samples, 2)
44 |         samples = argmax.cpu().data
45 |     samples = samples.t()
46 |     decoded_samples = []
47 |     for i in range(len(samples)):
48 |         decoded = ''.join([EN.vocab.itos[s] for s in samples[i]])
49 |         decoded_samples.append(decoded)
50 |     with open(file_path, 'a+') as f:
51 |         for sample in decoded_samples:
52 |             f.write(sample + '\n')
53 | 
--------------------------------------------------------------------------------
/models/generator.py:
--------------------------------------------------------------------------------
1 | import math
2 | import torch
3 | import random
4 | from torch import nn
5 | from torch.autograd import Variable
6 | import torch.nn.functional as F
7 | 
8 | 
9 | class Encoder(nn.Module):
10 |     def __init__(self, input_size, embed_size, hidden_size,
11 |                  n_layers=1, dropout=0.2):
12 |         super(Encoder, self).__init__()
13 |         self.input_size = input_size
14 |         self.hidden_size = hidden_size
15 |         self.embed_size = embed_size
16 |         self.embed = nn.Embedding(input_size, embed_size)
17 |         self.gru = nn.GRU(embed_size, hidden_size, n_layers,
18 |                           dropout=dropout, bidirectional=True)
19 | 
20 |     def forward(self, src, hidden=None):
21 |         # embed source token indices
22 |         embedded = self.embed(src)
23 |         outputs, hidden = self.gru(embedded, hidden)
24 |         # sum bidirectional outputs
25 |         outputs = (outputs[:, :, :self.hidden_size] +
26 |                    outputs[:, :, self.hidden_size:])
27 |         return outputs, hidden
28 | 
29 | 
30 | class Attention(nn.Module):
31 |     def __init__(self, hidden_size):
32 |         super(Attention, self).__init__()
33 |         self.hidden_size = hidden_size
34 |         self.attn = nn.Linear(self.hidden_size * 2, hidden_size)
35 |         self.v = nn.Parameter(torch.rand(hidden_size))
36 |         stdv = 1. / math.sqrt(self.v.size(0))
37 |         self.v.data.uniform_(-stdv, stdv)
38 | 
39 |     def forward(self, hidden, encoder_outputs):
40 |         timestep = encoder_outputs.size(0)
41 |         h = hidden.repeat(timestep, 1, 1).transpose(0, 1)
42 |         encoder_outputs = encoder_outputs.transpose(0, 1)  # [B*T*H]
43 |         attn_energies = self.score(h, encoder_outputs)
44 |         return F.softmax(attn_energies, dim=1).unsqueeze(1)
45 | 
46 |     def score(self, hidden, encoder_outputs):
47 |         # [B*T*2H]->[B*T*H]
48 |         energy = self.attn(torch.cat([hidden, encoder_outputs], 2))
49 |         energy = energy.transpose(1, 2)  # [B*H*T]
50 |         v = self.v.repeat(encoder_outputs.size(0), 1).unsqueeze(1)  # [B*1*H]
51 |         energy = torch.bmm(v, energy)  # [B*1*T]
52 |         return energy.squeeze(1)  # [B*T]
53 | 
54 | 
55 | class Decoder(nn.Module):
56 |     def __init__(self, embed_size, hidden_size, output_size,
57 |                  n_layers=1, dropout=0.2):
58 |         super(Decoder, self).__init__()
59 |         self.embed_size = embed_size
60 |         self.hidden_size = hidden_size
61 |         self.output_size = output_size
62 |         self.n_layers = n_layers
63 | 
64 |         self.embed = nn.Embedding(output_size, embed_size)
65 |         self.dropout = nn.Dropout(dropout)
66 |         self.attention = Attention(hidden_size)
67 |         self.gru = nn.GRU(hidden_size + embed_size, hidden_size,
68 |                           n_layers, dropout=dropout)
69 |         self.out = nn.Linear(hidden_size * 2, output_size)
70 | 
71 |     def forward(self, input, last_hidden, encoder_outputs):
72 |         # Get the embedding of the current input word (last output word)
73 |         embedded = self.embed(input).view(1, input.data.size(0), -1)  # (1,B,N)
74 |         embedded = self.dropout(embedded)
75 |         # Calculate attention weights and apply to encoder outputs
76 |         attn_weights = self.attention(last_hidden[-1], encoder_outputs)
77 |         context = attn_weights.bmm(encoder_outputs.transpose(0, 1))  # (B,1,N)
78 |         context = context.transpose(0, 1)  # (1,B,N)
79 |         # Combine embedded input word and attended context, run through RNN
80 |         rnn_input = torch.cat([embedded, context], 2)
81 |         output, hidden = self.gru(rnn_input, last_hidden)
82 |         output = output.squeeze(0)  # (1,B,N) -> (B,N)
83 |         context = context.squeeze(0)
84 |         output = self.out(torch.cat([output, context], 1))
85 |         output = F.log_softmax(output, dim=1)
86 |         return output, hidden, context
87 | 
88 | 
89 | class Seq2Seq(nn.Module):
90 |     def __init__(self, encoder, decoder):
91 |         super(Seq2Seq, self).__init__()
92 |         self.encoder = encoder
93 |         self.decoder = decoder
94 | 
95 |     def forward(self, src, trg, teacher_forcing_ratio=0.4):
96 |         batch_size = src.size(1)
97 |         max_len = trg.size(0)
98 |         vocab_size = self.decoder.output_size
99 |         outputs = Variable(torch.zeros(max_len, batch_size, vocab_size)).cuda()
100 |         contexts = Variable(torch.zeros(max_len, batch_size,
101 |                                         self.decoder.hidden_size)).cuda()
102 | 
103 |         encoder_output, encoder_hidden = self.encoder(src)
104 |         hidden = encoder_hidden[:self.decoder.n_layers]
105 |         output = Variable(trg.data[0, :]).cuda()
106 |         outputs[0] = to_onehot(output, vocab_size)
107 |         for t in range(1, len(trg)):
108 |             output, hidden, context = self.decoder(
109 |                 output, hidden, encoder_output)
110 |             outputs[t] = output
111 |             contexts[t] = context
112 |             is_teacher = random.random() < teacher_forcing_ratio
113 |             top1 = output.data.topk(1)[1].squeeze()
114 |             output = Variable(trg.data[t] if is_teacher else top1).cuda()
115 |         return outputs[1:], contexts[1:]
116 | 
117 | 
118 | def to_onehot(orig, vocab_size):
119 |     batch_size = orig.size(0)
120 |     onehot = torch.FloatTensor(batch_size, vocab_size).zero_()
121 |     onehot.scatter_(1, orig.data.cpu().unsqueeze(1), 1)
122 |     return onehot
123 | 
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import torch
3 | from torch import optim
4 | from torch.autograd import Variable, grad
5 | import torch.nn.functional as F
6 | from models import Encoder, Decoder, Seq2Seq, Discriminator
7 | from utils import load_dataset, to_onehot, enable_gradients, disable_gradients
8 | from logger import VisdomWriter, log_samples
9 | 
10 | 
11 | def parse_arguments():
12 |     p = argparse.ArgumentParser(description='Hyperparams')
13 |     p.add_argument('-epochs', type=int, default=100000,
14 |                    help='number of epochs for train')
15 |     p.add_argument('-batch_size', type=int, default=32,
16 |                    help='batch size for train')
17 |     p.add_argument('-lamb', type=float, default=10,
18 |                    help='gradient penalty coefficient lambda')
19 |     return p.parse_args()
20 | 
21 | 
22 | def grad_penalty(D, real, gen, context, lamb):
23 |     alpha = torch.rand(real.size()).cuda()
24 |     x_hat = alpha * real + ((1 - alpha) * gen).cuda()
25 |     x_hat = Variable(x_hat, requires_grad=True)
26 |     context = Variable(context)
27 |     d_hat = D(x_hat, context)
28 |     ones = torch.ones(d_hat.size()).cuda()
29 |     gradients = grad(outputs=d_hat, inputs=x_hat,
30 |                      grad_outputs=ones, create_graph=True,
31 |                      retain_graph=True, only_inputs=True)[0]
32 |     penalty = ((gradients.norm(2, dim=1) - 1) ** 2).mean() * lamb
33 |     return penalty
34 | 
35 | 
36 | def D_loss(D, G, src, trg, lamb, curriculum):
37 |     src_len = min(curriculum, len(src)-1) + 1
38 |     trg_len = min(curriculum, len(trg)-1) + 1
39 |     # critic score on generated target sequences
40 |     gen_trg, context = G(src[:src_len], trg[:trg_len])
41 |     d_gen = D(gen_trg, context)
42 |     # critic score on real targets, fed as one-hot vectors
43 |     trg = to_onehot(trg, D.vocab_size).type(torch.FloatTensor)[1:trg_len]
44 |     trg = Variable(trg.cuda())
45 |     d_real = D(trg, context)
46 |     # calculate gradient penalty
47 |     penalty = grad_penalty(D, trg.data, gen_trg.data, context.data, lamb)
48 |     loss = d_gen.mean() - d_real.mean() + penalty
49 |     return loss
50 | 
51 | 
52 | def G_loss(D, G, src, trg, curriculum):
53 |     src_len = min(curriculum, len(src)-1) + 1
54 |     trg_len = min(curriculum, len(trg)-1) + 1
55 |     gen_trg, context = G(src[:src_len], trg[:trg_len])
56 |     loss_g = D(gen_trg, context)
57 |     return -loss_g.mean()
58 | 
59 | 
60 | def evaluate(e, model, val_iter, vocab_size, DE, EN, curriculum):
61 |     model.eval()
62 |     pad = EN.vocab.stoi['<pad>']
63 |     total_loss = 0
64 |     for b, batch in enumerate(val_iter):
65 |         src, len_src = batch.src
66 |         trg, len_trg = batch.trg
67 |         src = Variable(src.data.cuda(), volatile=True)
68 |         trg = Variable(trg.data.cuda(), volatile=True)
69 |         src_len = min(curriculum, len(src)-1) + 1
70 |         trg_len = min(curriculum, len(trg)-1) + 1
71 |         output = model(src[:src_len], trg[:trg_len])[0]
72 |         loss = F.cross_entropy(output.view(-1, vocab_size),
73 |                                trg[1:trg_len].contiguous().view(-1),
74 |                                ignore_index=pad)
75 |         total_loss += loss.data[0]
76 |         log_samples('./.samples/%d-translation.txt' % e, output, EN)
77 |     return total_loss / len(val_iter)
78 | 
79 | 
80 | def main():
81 |     args = parse_arguments()
82 |     hidden_size = 512
83 |     embed_size = 256
84 |     assert torch.cuda.is_available()
85 | 
86 |     # visdom for plotting
87 |     vis_g = VisdomWriter("Generator Loss",
88 |                          xlabel='Iteration', ylabel='Loss')
89 |     vis_d = VisdomWriter("Negative Discriminator Loss",
90 |                          xlabel='Iteration', ylabel='Loss')
91 | 
92 |     print("[!] preparing dataset...")
93 |     train_iter, val_iter, test_iter, DE, EN = load_dataset(args.batch_size)
94 |     de_size, en_size = len(DE.vocab), len(EN.vocab)
95 |     print("de_vocab_size: %d en_vocab_size: %d" % (de_size, en_size))
96 | 
97 |     print("[!] Instantiating models...")
98 |     encoder = Encoder(de_size, embed_size, hidden_size,
99 |                       n_layers=2, dropout=0.5)
100 |     decoder = Decoder(embed_size, hidden_size, en_size,
101 |                       n_layers=1, dropout=0.5)
102 |     G = Seq2Seq(encoder, decoder).cuda()
103 |     D = Discriminator(en_size, embed_size, hidden_size).cuda()
104 |     optimizer_D = optim.Adam(D.parameters(), lr=2e-4, betas=(0.5, 0.9))
105 |     optimizer_G = optim.Adam(G.parameters(), lr=1e-4, betas=(0.5, 0.9))
106 |     # two time-scale update rule (TTUR): https://arxiv.org/abs/1706.08500
107 | 
108 |     # pretrained
109 |     # G.load_state_dict(torch.load("./.tmp/21.pt"))
110 | 
111 |     curriculum = 1
112 |     dis_loss = []
113 |     gen_loss = []
114 |     for e in range(1, args.epochs+1):
115 |         # Training
116 |         for b, batch in enumerate(train_iter):
117 |             src, len_src = batch.src
118 |             trg, len_trg = batch.trg
119 |             src, trg = src.cuda(), trg.cuda()
120 |             # (1) Update D network
121 |             enable_gradients(D)
122 |             disable_gradients(G)
123 |             G.eval()
124 |             D.train()
125 |             # clamp critic weights to a cube (original WGAN clipping; WGAN-GP's
126 |             # gradient penalty makes this unnecessary)
127 |             for p in D.parameters():
128 |                 p.data.clamp_(-0.01, 0.01)
129 |             D.zero_grad()
130 |             loss_d = D_loss(D, G, src, trg, args.lamb, curriculum)
131 |             loss_d.backward()
132 |             optimizer_D.step()
133 |             dis_loss.append(loss_d.data[0])
134 |             # (2) Update G network
135 |             if b % 10 == 0:
136 |                 enable_gradients(G)
137 |                 disable_gradients(D)
138 |                 D.eval()
139 |                 G.train()
140 |                 G.zero_grad()
141 |                 loss_g = G_loss(D, G, src, trg, curriculum)
142 |                 loss_g.backward()
143 |                 optimizer_G.step()
144 |                 gen_loss.append(loss_g.data[0])
145 |             # plot losses
146 |             if b % 10 == 0 and b > 1:
147 |                 vis_d.update(-loss_d.data[0])
148 |                 vis_g.update(loss_g.data[0])
149 |         if e % 10 == 0 and e > 1:
150 |             ce_loss = evaluate(e, G, val_iter, en_size, DE, EN, curriculum)
151 |             print(ce_loss)
152 |         if e % 100 == 0 and e > 1:
153 |             curriculum += 1
154 | 
155 |         # Validation
156 |         # disable_gradients(G)
157 |         # disable_gradients(D)
158 |         # loss_d, loss_g = 0, 0
159 |         # for b, batch in enumerate(val_iter):
160 |         #     src, len_src = batch.src
161 |         #     trg, len_trg = batch.trg
162 |         #     src, trg = src.cuda(), trg.cuda()
163 |         #     # (1) Validate D
164 |         #     loss_d += D_loss(D, G, src, trg, args.lamb, curriculum)
165 |         #     # (2) Validate G
166 |         #     loss_g += G_loss(D, G, src, trg, curriculum)
167 |         #     print("loss_d:", loss_d / len(val_iter),
168 |         #           "loss_g", loss_g / len(val_iter))
169 | 
170 |         # Save the model if the validation loss is the best we've seen so far.
171 |         # if not best_val_loss or val_loss < best_val_loss:
172 |         #     print("[!] saving model...")
173 |         #     if not os.path.isdir(".save"):
174 |         #         os.makedirs(".save")
175 |         #     torch.save(G.state_dict(), './.save/wseq2seq_g_%d.pt' % (i))
176 |         #     torch.save(D.state_dict(), './.save/wseq2seq_d_%d.pt' % (i))
177 |         #     best_val_loss = val_loss
178 | 
179 | 
180 | if __name__ == "__main__":
181 |     try:
182 |         main()
183 |     except KeyboardInterrupt as e:
184 |         print("[STOP]", e)
185 | 
--------------------------------------------------------------------------------
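Appendix: a minimal, self-contained sketch of the WGAN-GP critic objective that train.py implements in grad_penalty() and D_loss(). This is not a file from the repository; the toy fully-connected critic, tensor shapes, and values are illustrative assumptions, and it targets a recent CPU-only PyTorch API rather than the Variable-era code above. train.py applies the same idea to (seq_len, batch, vocab)-shaped one-hot/softmax sequences, additionally conditions the critic on the decoder's attention context, and samples alpha element-wise rather than once per example.

import torch
from torch import nn
from torch.autograd import grad

# toy critic standing in for models.Discriminator
critic = nn.Sequential(nn.Linear(10, 64), nn.ReLU(), nn.Linear(64, 1))

real = torch.randn(8, 10)   # batch of "real" samples
fake = torch.randn(8, 10)   # batch of generated samples
lamb = 10.0                 # same default as train.py's -lamb flag

# interpolate between real and generated samples (one alpha per example,
# as in the WGAN-GP paper)
alpha = torch.rand(8, 1)
x_hat = (alpha * real + (1 - alpha) * fake).requires_grad_(True)

# gradient penalty: push the critic's gradient norm at x_hat towards 1
d_hat = critic(x_hat)
gradients = grad(outputs=d_hat, inputs=x_hat,
                 grad_outputs=torch.ones_like(d_hat),
                 create_graph=True, retain_graph=True, only_inputs=True)[0]
penalty = lamb * ((gradients.norm(2, dim=1) - 1) ** 2).mean()

# critic loss, as in D_loss(): E[D(fake)] - E[D(real)] + penalty
d_loss = critic(fake).mean() - critic(real).mean() + penalty
print(float(d_loss))

In the actual training loop, loss_d.backward() and an Adam step on D's parameters follow, and the generator is updated only every 10 critic updates (see main() in train.py).
--------------------------------------------------------------------------------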