├── assets
│   ├── d_loss.png
│   ├── g_loss.png
│   └── w_distance.png
├── README.md
├── models
│   ├── __init__.py
│   ├── discriminator.py
│   └── generator.py
├── LICENSE
├── utils.py
├── .gitignore
├── logger.py
└── train.py
/assets/d_loss.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/keon/seq2seq-wgan/HEAD/assets/d_loss.png
--------------------------------------------------------------------------------
/assets/g_loss.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/keon/seq2seq-wgan/HEAD/assets/g_loss.png
--------------------------------------------------------------------------------
/assets/w_distance.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/keon/seq2seq-wgan/HEAD/assets/w_distance.png
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # [WIP] seq2seq-wgan
2 | Improved Training of Wasserstein GANs for Neural Machine Translation
3 |
4 |
5 | Based on the paper [Improved Training of Wasserstein GANs](https://arxiv.org/abs/1704.00028).
6 |
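7 | ## Training objective
8 |
9 | The generator is an attention seq2seq model (`models/generator.py`); the discriminator is a
10 | recurrent critic that scores a (soft one-hot) translation together with the generator's
11 | attention context (`models/discriminator.py`). The critic is trained with the WGAN-GP loss
12 | `E[D(fake)] - E[D(real)] + lambda * gradient penalty`, and the generator with `-E[D(fake)]`
13 | (see `D_loss`, `G_loss` and `grad_penalty` in `train.py`). Below is a minimal, self-contained
14 | sketch of the gradient-penalty term; it mirrors `grad_penalty` in `train.py`, but is written
15 | against the current PyTorch API (the repo code uses the older `Variable` API), with
16 | `lamb=10` as in the paper and in the `-lamb` default.
17 |
18 | ```python
19 | import torch
20 |
21 | def gradient_penalty(critic, real, fake, context, lamb=10):
22 |     # interpolate between real and generated samples
23 |     # (real, fake and context are assumed detached from the generator's graph)
24 |     alpha = torch.rand_like(real)
25 |     x_hat = alpha * real.detach() + (1 - alpha) * fake.detach()
26 |     x_hat.requires_grad_(True)
27 |     d_hat = critic(x_hat, context)
28 |     # gradient of the critic score w.r.t. the interpolated input
29 |     grads = torch.autograd.grad(outputs=d_hat, inputs=x_hat,
30 |                                 grad_outputs=torch.ones_like(d_hat),
31 |                                 create_graph=True, retain_graph=True,
32 |                                 only_inputs=True)[0]
33 |     # penalize deviation of the critic's gradient norm from 1
34 |     return lamb * ((grads.norm(2, dim=1) - 1) ** 2).mean()
35 | ```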
--------------------------------------------------------------------------------
/models/__init__.py:
--------------------------------------------------------------------------------
1 | from .discriminator import Discriminator
2 | from .generator import Encoder, Decoder, Seq2Seq
3 |
4 | __all__ = [
5 |     'Encoder',
6 |     'Decoder',
7 |     'Seq2Seq',
8 |     'Discriminator',
9 | ]
10 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2017 Keon
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/models/discriminator.py:
--------------------------------------------------------------------------------
1 | import math
2 | import torch
3 | import random
4 | from torch import nn
5 | from torch.autograd import Variable
6 | import torch.nn.functional as F
7 |
8 |
9 | class Discriminator(nn.Module):
10 | def __init__(self, vocab_size, embed_size, hidden_size,
11 | n_layers=1, dropout=0.2):
12 | super(Discriminator, self).__init__()
13 | self.embed_size = embed_size
14 | self.hidden_size = hidden_size
15 | self.vocab_size = vocab_size
16 | self.n_layers = n_layers
17 |
18 | self.embed = nn.Linear(vocab_size, embed_size, bias=False)
19 | self.dropout = nn.Dropout(dropout)
20 | self.gru = nn.GRU(hidden_size + embed_size, hidden_size,
21 | n_layers, dropout=dropout)
22 | self.out = nn.Linear(hidden_size, 1)
23 |
24 | def forward(self, input, context):
25 | """
26 |         input: I x B x Vocab (soft one-hot over the target vocabulary)
27 |         context: I x B x H   (attention context from the generator)
28 |         returns: B x 1 critic score
29 | """
30 | # Get the embedding of the current input word (last output word)
31 | embedded = self.embed(input) # (I,B,E)
32 | embedded = self.dropout(embedded)
33 | # Combine embedded input word and attended context, run through RNN
34 | rnn_input = torch.cat([embedded, context], 2)
35 | output, hidden = self.gru(rnn_input, None)
36 | out = self.out(output[-1]) # [b, h] -> [b, 1]
37 | return out
38 |
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
1 | import re
2 | import spacy
3 | import torch
4 | from torchtext.data import Field, BucketIterator
5 | from torchtext.datasets import Multi30k
6 |
7 |
8 | def enable_gradients(model):
9 | for p in model.parameters():
10 | p.requires_grad = True
11 |
12 |
13 | def disable_gradients(model):
14 | for p in model.parameters():
15 | p.requires_grad = False
16 |
17 |
18 | def to_onehot(index, vocab_size):
19 |     seq_len, batch_size = index.size(0), index.size(1)
20 |     onehot = torch.FloatTensor(seq_len, batch_size, vocab_size).zero_()
21 | onehot.scatter_(2, index.data.cpu().unsqueeze(2), 1)
22 | return onehot
23 |
24 |
25 | def load_dataset(batch_size):
26 | spacy_de = spacy.load('de')
27 | spacy_en = spacy.load('en')
28 |     url = re.compile('(<url>.*</url>)')
29 |
30 | def tokenize_de(text):
31 | return [tok.text for tok in spacy_de.tokenizer(url.sub('@URL@', text))]
32 |
33 | def tokenize_en(text):
34 | return [tok.text for tok in spacy_en.tokenizer(url.sub('@URL@', text))]
35 |
36 |     DE = Field(tokenize=list, include_lengths=True,
37 |                init_token='<sos>', eos_token='<eos>')
38 |     EN = Field(tokenize=list, include_lengths=True,
39 |                init_token='<sos>', eos_token='<eos>')
40 | train, val, test = Multi30k.splits(exts=('.de', '.en'), fields=(DE, EN))
41 | DE.build_vocab(train.src)
42 | EN.build_vocab(train.trg)
43 | train_iter, val_iter, test_iter = BucketIterator.splits(
44 | (train, val, test), batch_size=batch_size, repeat=False)
45 | return train_iter, val_iter, test_iter, DE, EN
46 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | env/
12 | build/
13 | develop-eggs/
14 | dist/
15 | downloads/
16 | eggs/
17 | .eggs/
18 | lib/
19 | lib64/
20 | parts/
21 | sdist/
22 | var/
23 | wheels/
24 | *.egg-info/
25 | .installed.cfg
26 | *.egg
27 |
28 | # PyInstaller
29 | # Usually these files are written by a python script from a template
30 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
31 | *.manifest
32 | *.spec
33 |
34 | # Installer logs
35 | pip-log.txt
36 | pip-delete-this-directory.txt
37 |
38 | # Unit test / coverage reports
39 | htmlcov/
40 | .tox/
41 | .coverage
42 | .coverage.*
43 | .cache
44 | nosetests.xml
45 | coverage.xml
46 | *.cover
47 | .hypothesis/
48 |
49 | # Translations
50 | *.mo
51 | *.pot
52 |
53 | # Django stuff:
54 | *.log
55 | local_settings.py
56 |
57 | # Flask stuff:
58 | instance/
59 | .webassets-cache
60 |
61 | # Scrapy stuff:
62 | .scrapy
63 |
64 | # Sphinx documentation
65 | docs/_build/
66 |
67 | # PyBuilder
68 | target/
69 |
70 | # Jupyter Notebook
71 | .ipynb_checkpoints
72 |
73 | # pyenv
74 | .python-version
75 |
76 | # celery beat schedule file
77 | celerybeat-schedule
78 |
79 | # SageMath parsed files
80 | *.sage.py
81 |
82 | # dotenv
83 | .env
84 |
85 | # virtualenv
86 | .venv
87 | venv/
88 | ENV/
89 |
90 | # Spyder project settings
91 | .spyderproject
92 | .spyproject
93 |
94 | # Rope project settings
95 | .ropeproject
96 |
97 | # mkdocs documentation
98 | /site
99 |
100 | # mypy
101 | .mypy_cache/
102 |
103 | .data
104 | .save
105 | .tmp
106 | .samples
107 |
--------------------------------------------------------------------------------
/logger.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | from visdom import Visdom
4 |
5 |
6 | class VisdomWriter(object):
7 | def __init__(self, title, xlabel='Epoch', ylabel='Loss'):
8 | """Extended Visdom Writer"""
9 | self.vis = Visdom()
10 | assert self.vis.check_connection()
11 | self.title = title
12 | self.xlabel = xlabel
13 | self.ylabel = ylabel
14 | self.x = 0
15 | self.win = None
16 |
17 | def update_text(self, text):
18 | """Text Memo (usually used to note hyperparameter-configurations)"""
19 | self.vis.text(text)
20 |
21 | def update(self, y):
22 | """Update loss (X: Step (Epoch) / Y: loss)"""
23 | self.x += 1
24 | if self.win is None:
25 | self.win = self.vis.line(
26 | X=np.array([self.x]),
27 | Y=np.array([y]),
28 | opts=dict(
29 | title=self.title,
30 | xlabel=self.xlabel,
31 | ylabel=self.ylabel,
32 | ))
33 | else:
34 | self.vis.updateTrace(
35 | X=np.array([self.x]),
36 | Y=np.array([y]),
37 | win=self.win)
38 |
39 |
40 | def log_samples(file_path, samples, EN, is_output=True):
41 |     eos = EN.vocab.stoi['<eos>']
42 | if is_output:
43 | _, argmax = torch.max(samples, 2)
44 | samples = argmax.cpu().data
45 | samples = samples.t()
46 | decoded_samples = []
47 | for i in range(len(samples)):
48 | decoded = ''.join([EN.vocab.itos[s] for s in samples[i]])
49 | decoded_samples.append(decoded)
50 | with open(file_path, 'a+') as f:
51 | for sample in decoded_samples:
52 | f.write(sample + '\n')
53 |
--------------------------------------------------------------------------------
/models/generator.py:
--------------------------------------------------------------------------------
1 | import math
2 | import torch
3 | import random
4 | from torch import nn
5 | from torch.autograd import Variable
6 | import torch.nn.functional as F
7 |
8 |
9 | class Encoder(nn.Module):
10 | def __init__(self, input_size, embed_size, hidden_size,
11 | n_layers=1, dropout=0.2):
12 | super(Encoder, self).__init__()
13 | self.input_size = input_size
14 | self.hidden_size = hidden_size
15 | self.embed_size = embed_size
16 | self.embed = nn.Embedding(input_size, embed_size)
17 | self.gru = nn.GRU(embed_size, hidden_size, n_layers,
18 | dropout=dropout, bidirectional=True)
19 |
20 | def forward(self, src, hidden=None):
21 |         # embed the source token indices
22 |         embedded = self.embed(src)
23 | outputs, hidden = self.gru(embedded, hidden)
24 | # sum bidirectional outputs
25 | outputs = (outputs[:, :, :self.hidden_size] +
26 | outputs[:, :, self.hidden_size:])
27 | return outputs, hidden
28 |
29 |
30 | class Attention(nn.Module):
31 | def __init__(self, hidden_size):
32 | super(Attention, self).__init__()
33 | self.hidden_size = hidden_size
34 | self.attn = nn.Linear(self.hidden_size * 2, hidden_size)
35 | self.v = nn.Parameter(torch.rand(hidden_size))
36 | stdv = 1. / math.sqrt(self.v.size(0))
37 | self.v.data.uniform_(-stdv, stdv)
38 |
39 | def forward(self, hidden, encoder_outputs):
40 | timestep = encoder_outputs.size(0)
41 | h = hidden.repeat(timestep, 1, 1).transpose(0, 1)
42 | encoder_outputs = encoder_outputs.transpose(0, 1) # [B*T*H]
43 | attn_energies = self.score(h, encoder_outputs)
44 | return F.softmax(attn_energies, dim=1).unsqueeze(1)
45 |
46 | def score(self, hidden, encoder_outputs):
47 | # [B*T*2H]->[B*T*H]
48 | energy = self.attn(torch.cat([hidden, encoder_outputs], 2))
49 | energy = energy.transpose(1, 2) # [B*H*T]
50 | v = self.v.repeat(encoder_outputs.size(0), 1).unsqueeze(1) # [B*1*H]
51 | energy = torch.bmm(v, energy) # [B*1*T]
52 | return energy.squeeze(1) # [B*T]
53 |
54 |
55 | class Decoder(nn.Module):
56 | def __init__(self, embed_size, hidden_size, output_size,
57 | n_layers=1, dropout=0.2):
58 | super(Decoder, self).__init__()
59 | self.embed_size = embed_size
60 | self.hidden_size = hidden_size
61 | self.output_size = output_size
62 | self.n_layers = n_layers
63 |
64 | self.embed = nn.Embedding(output_size, embed_size)
65 | self.dropout = nn.Dropout(dropout)
66 | self.attention = Attention(hidden_size)
67 | self.gru = nn.GRU(hidden_size + embed_size, hidden_size,
68 | n_layers, dropout=dropout)
69 | self.out = nn.Linear(hidden_size * 2, output_size)
70 |
71 | def forward(self, input, last_hidden, encoder_outputs):
72 | # Get the embedding of the current input word (last output word)
73 | embedded = self.embed(input).view(1, input.data.size(0), -1) # (1,B,N)
74 | embedded = self.dropout(embedded)
75 | # Calculate attention weights and apply to encoder outputs
76 | attn_weights = self.attention(last_hidden[-1], encoder_outputs)
77 | context = attn_weights.bmm(encoder_outputs.transpose(0, 1)) # (B,1,N)
78 | context = context.transpose(0, 1) # (1,B,N)
79 | # Combine embedded input word and attended context, run through RNN
80 | rnn_input = torch.cat([embedded, context], 2)
81 | output, hidden = self.gru(rnn_input, last_hidden)
82 | output = output.squeeze(0) # (1,B,N) -> (B,N)
83 | context = context.squeeze(0)
84 | output = self.out(torch.cat([output, context], 1))
85 | output = F.log_softmax(output, dim=1)
86 | return output, hidden, context
87 |
88 |
89 | class Seq2Seq(nn.Module):
90 | def __init__(self, encoder, decoder):
91 | super(Seq2Seq, self).__init__()
92 | self.encoder = encoder
93 | self.decoder = decoder
94 |
95 |     def forward(self, src, trg, teacher_forcing_ratio=0.4):
96 | batch_size = src.size(1)
97 | max_len = trg.size(0)
98 | vocab_size = self.decoder.output_size
99 | outputs = Variable(torch.zeros(max_len, batch_size, vocab_size)).cuda()
100 | contexts = Variable(torch.zeros(max_len, batch_size,
101 | self.decoder.hidden_size)).cuda()
102 |
103 | encoder_output, encoder_hidden = self.encoder(src)
104 | hidden = encoder_hidden[:self.decoder.n_layers]
105 | output = Variable(trg.data[0, :]).cuda()
106 | outputs[0] = to_onehot(output, vocab_size)
107 | for t in range(1, len(trg)):
108 | output, hidden, context = self.decoder(
109 | output, hidden, encoder_output)
110 | outputs[t] = output
111 | contexts[t] = context
112 | is_teacher = random.random() < teacher_forcing_ratio
113 | top1 = output.data.topk(1)[1].squeeze()
114 | output = Variable(trg.data[t] if is_teacher else top1).cuda()
115 | return outputs[1:], contexts[1:]
116 |
117 |
118 | def to_onehot(orig, vocab_size):
119 | batch_size = orig.size(0)
120 | onehot = torch.FloatTensor(batch_size, vocab_size).zero_()
121 | onehot.scatter_(1, orig.data.cpu().unsqueeze(1), 1)
122 | return onehot
123 |
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import torch
3 | from torch import optim
4 | from torch.autograd import Variable, grad
5 | import torch.nn.functional as F
6 | from models import Encoder, Decoder, Seq2Seq, Discriminator
7 | from utils import load_dataset, to_onehot, enable_gradients, disable_gradients
8 | from logger import VisdomWriter, log_samples
9 |
10 |
11 | def parse_arguments():
12 | p = argparse.ArgumentParser(description='Hyperparams')
13 |     p.add_argument('-epochs', type=int, default=100000,
14 |                    help='number of epochs for training')
15 |     p.add_argument('-batch_size', type=int, default=32,
16 |                    help='batch size for training')
17 |     p.add_argument('-lamb', type=float, default=10,
18 |                    help='gradient penalty coefficient (lambda)')
19 | return p.parse_args()
20 |
21 |
22 | def grad_penalty(D, real, gen, context, lamb):
23 | alpha = torch.rand(real.size()).cuda()
24 | x_hat = alpha * real + ((1 - alpha) * gen).cuda()
25 | x_hat = Variable(x_hat, requires_grad=True)
26 | context = Variable(context)
27 | d_hat = D(x_hat, context)
28 | ones = torch.ones(d_hat.size()).cuda()
29 | gradients = grad(outputs=d_hat, inputs=x_hat,
30 | grad_outputs=ones, create_graph=True,
31 | retain_graph=True, only_inputs=True)[0]
32 | penalty = ((gradients.norm(2, dim=1) - 1) ** 2).mean() * lamb
33 | return penalty
34 |
35 |
36 | def D_loss(D, G, src, trg, lamb, curriculum):
37 | src_len = min(curriculum, len(src)-1) + 1
38 |     trg_len = min(curriculum, len(trg)-1) + 1
39 | # with gen
40 | gen_trg, context = G(src[:src_len], trg[:trg_len])
41 | d_gen = D(gen_trg, context)
42 | # with real
43 | trg = to_onehot(trg, D.vocab_size).type(torch.FloatTensor)[1:trg_len]
44 | trg = Variable(trg.cuda())
45 | d_real = D(trg, context)
46 |     # calculate gradient penalty
47 | penalty = grad_penalty(D, trg.data, gen_trg.data, context.data, lamb)
48 | loss = d_gen.mean() - d_real.mean() + penalty
49 | return loss
50 |
51 |
52 | def G_loss(D, G, src, trg, curriculum):
53 | src_len = min(curriculum, len(src)-1) + 1
54 |     trg_len = min(curriculum, len(trg)-1) + 1
55 | gen_trg, context = G(src[:src_len], trg[:trg_len])
56 | loss_g = D(gen_trg, context)
57 | return -loss_g.mean()
58 |
59 |
60 | def evaluate(e, model, val_iter, vocab_size, DE, EN, curriculum):
61 | model.eval()
62 |     pad = EN.vocab.stoi['<pad>']
63 | total_loss = 0
64 | for b, batch in enumerate(val_iter):
65 | src, len_src = batch.src
66 | trg, len_trg = batch.trg
67 | src = Variable(src.data.cuda(), volatile=True)
68 | trg = Variable(trg.data.cuda(), volatile=True)
69 | src_len = min(curriculum, len(src)-1) + 1
70 |         trg_len = min(curriculum, len(trg)-1) + 1
71 | output = model(src[:src_len], trg[:trg_len])[0]
72 | loss = F.cross_entropy(output.view(-1, vocab_size),
73 | trg[1:trg_len].contiguous().view(-1),
74 | ignore_index=pad)
75 | total_loss += loss.data[0]
76 | log_samples('./.samples/%d-translation.txt' % e, output, EN)
77 | return total_loss / len(val_iter)
78 |
79 |
80 | def main():
81 | args = parse_arguments()
82 | hidden_size = 512
83 | embed_size = 256
84 | assert torch.cuda.is_available()
85 |
86 | # visdom for plotting
87 | vis_g = VisdomWriter("Generator Loss",
88 | xlabel='Iteration', ylabel='Loss')
89 | vis_d = VisdomWriter("Negative Discriminator Loss",
90 | xlabel='Iteration', ylabel='Loss')
91 |
92 | print("[!] preparing dataset...")
93 | train_iter, val_iter, test_iter, DE, EN = load_dataset(args.batch_size)
94 | de_size, en_size = len(DE.vocab), len(EN.vocab)
95 | print("de_vocab_size: %d en_vocab_size: %d" % (de_size, en_size))
96 |
97 | print("[!] Instantiating models...")
98 | encoder = Encoder(de_size, embed_size, hidden_size,
99 | n_layers=2, dropout=0.5)
100 | decoder = Decoder(embed_size, hidden_size, en_size,
101 | n_layers=1, dropout=0.5)
102 | G = Seq2Seq(encoder, decoder).cuda()
103 | D = Discriminator(en_size, embed_size, hidden_size).cuda()
104 | optimizer_D = optim.Adam(D.parameters(), lr=2e-4, betas=(0.5, 0.9))
105 | optimizer_G = optim.Adam(G.parameters(), lr=1e-4, betas=(0.5, 0.9))
106 | # TTUR paper https://arxiv.org/abs/1706.08500
107 |
108 | # pretrained
109 | # G.load_state_dict(torch.load("./.tmp/21.pt"))
110 |
111 | curriculum = 1
112 | dis_loss = []
113 | gen_loss = []
114 | for e in range(1, args.epochs+1):
115 | # Training
116 | for b, batch in enumerate(train_iter):
117 | src, len_src = batch.src
118 | trg, len_trg = batch.trg
119 | src, trg = src.cuda(), trg.cuda()
120 | # (1) Update D network
121 | enable_gradients(D)
122 | disable_gradients(G)
123 | G.eval()
124 | D.train()
125 | # clamp parameters to a cube
126 | for p in D.parameters():
127 | p.data.clamp_(-0.01, 0.01)
128 | D.zero_grad()
129 | loss_d = D_loss(D, G, src, trg, args.lamb, curriculum)
130 | loss_d.backward()
131 | optimizer_D.step()
132 | dis_loss.append(loss_d.data[0])
133 | # (2) Update G network
134 | if b % 10 == 0:
135 | enable_gradients(G)
136 | disable_gradients(D)
137 | D.eval()
138 | G.train()
139 | G.zero_grad()
140 | loss_g = G_loss(D, G, src, trg, curriculum)
141 | loss_g.backward()
142 | optimizer_G.step()
143 | gen_loss.append(loss_g.data[0])
144 | # plot losses
145 | if b % 10 == 0 and b > 1:
146 | vis_d.update(-loss_d.data[0])
147 | vis_g.update(loss_g.data[0])
148 | if e % 10 == 0 and e > 1:
149 | ce_loss = evaluate(e, G, val_iter, en_size, DE, EN, curriculum)
150 | print(ce_loss)
151 | if e % 100 == 0 and e > 1:
152 | curriculum += 1
153 |
154 | # Validation
155 | # disable_gradients(G)
156 | # disable_gradients(D)
157 | # loss_d, loss_g = 0, 0
158 | # for b, batch in enumerate(val_iter):
159 | # src, len_src = batch.src
160 | # trg, len_trg = batch.trg
161 | # src, trg = src.cuda(), trg.cuda()
162 | # # (1) Validate D
163 | # loss_d += D_loss(D, G, src, trg, args.lamb, curriculum)
164 | # # (2) Validate G
165 | # loss_g += G_loss(D, G, src, trg, curriculum)
166 | # print("loss_d:", loss_d / len(val_iter),
167 | # "loss_g", loss_g / len(val_iter))
168 |
169 | # Save the model if the validation loss is the best we've seen so far.
170 | # if not best_val_loss or val_loss < best_val_loss:
171 | # print("[!] saving model...")
172 | # if not os.path.isdir(".save"):
173 | # os.makedirs(".save")
174 | # torch.save(G.state_dict(), './.save/wseq2seq_g_%d.pt' % (i))
175 | # torch.save(D.state_dict(), './.save/wseq2seq_d_%d.pt' % (i))
176 | # best_val_loss = val_loss
177 |
178 |
179 | if __name__ == "__main__":
180 | try:
181 | main()
182 | except KeyboardInterrupt as e:
183 | print("[STOP]", e)
184 |
--------------------------------------------------------------------------------