├── .gitignore ├── _config.yml ├── dis_loss.png ├── gen_loss.png ├── experiment.pkl ├── loss.py ├── discriminator.py ├── README.md ├── rollout.py ├── target_lstm.py ├── data_iter.py ├── generator.py ├── main.py └── experiment_notebook.ipynb /.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__ 2 | -------------------------------------------------------------------------------- /_config.yml: -------------------------------------------------------------------------------- 1 | theme: jekyll-theme-cayman -------------------------------------------------------------------------------- /dis_loss.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/X-czh/SeqGAN-PyTorch/HEAD/dis_loss.png -------------------------------------------------------------------------------- /gen_loss.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/X-czh/SeqGAN-PyTorch/HEAD/gen_loss.png -------------------------------------------------------------------------------- /experiment.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/X-czh/SeqGAN-PyTorch/HEAD/experiment.pkl -------------------------------------------------------------------------------- /loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | class PGLoss(nn.Module): 6 | """ 7 | Pseudo-loss that gives corresponding policy gradients (on calling .backward()) 8 | for adversial training of Generator 9 | """ 10 | 11 | def __init__(self): 12 | super(PGLoss, self).__init__() 13 | 14 | def forward(self, pred, target, reward): 15 | """ 16 | Inputs: pred, target, reward 17 | - pred: (batch_size, seq_len), 18 | - target : (batch_size, seq_len), 19 | - reward : (batch_size, ), reward of each whole 
sentence 20 | """ 21 | one_hot = torch.zeros(pred.size(), dtype=torch.uint8) 22 | if pred.is_cuda: 23 | one_hot = one_hot.cuda() 24 | one_hot.scatter_(1, target.data.view(-1, 1), 1) 25 | loss = torch.masked_select(pred, one_hot) 26 | loss = loss * reward.contiguous().view(-1) 27 | loss = -torch.sum(loss) 28 | return loss 29 | -------------------------------------------------------------------------------- /discriminator.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | class Discriminator(nn.Module): 7 | """ 8 | A CNN for text classification. 9 | Uses an embedding layer, followed by a convolutional, max-pooling and softmax layer. 10 | Highway architecture based on the pooled feature maps is added. Dropout is adopted. 11 | """ 12 | 13 | def __init__(self, num_classes, vocab_size, embedding_dim, filter_sizes, num_filters, dropout_prob): 14 | super(Discriminator, self).__init__() 15 | self.embed = nn.Embedding(vocab_size, embedding_dim) 16 | self.convs = nn.ModuleList([ 17 | nn.Conv2d(1, num_f, (f_size, embedding_dim)) for f_size, num_f in zip(filter_sizes, num_filters) 18 | ]) 19 | self.highway = nn.Linear(sum(num_filters), sum(num_filters)) 20 | self.dropout = nn.Dropout(p = dropout_prob) 21 | self.fc = nn.Linear(sum(num_filters), num_classes) 22 | 23 | def forward(self, x): 24 | """ 25 | Inputs: x 26 | - x: (batch_size, seq_len) 27 | Outputs: out 28 | - out: (batch_size, num_classes) 29 | """ 30 | emb = self.embed(x).unsqueeze(1) # batch_size, 1 * seq_len * emb_dim 31 | convs = [F.relu(conv(emb)).squeeze(3) for conv in self.convs] # [batch_size * num_filter * seq_len] 32 | pools = [F.max_pool1d(conv, conv.size(2)).squeeze(2) for conv in convs] # [batch_size * num_filter] 33 | out = torch.cat(pools, 1) # batch_size * sum(num_filters) 34 | highway = self.highway(out) 35 | transform = F.sigmoid(highway) 36 | out = transform * F.relu(highway) + 
(1. - transform) * out # sets C = 1 - T 37 | out = F.log_softmax(self.fc(self.dropout(out)), dim=1) # batch * num_classes 38 | return out 39 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # SeqGAN-PyTorch 2 | An implementation of SeqGAN (Paper: [SeqGAN: Sequence Generative Adversarial Nets with Policy Gradient](https://arxiv.org/pdf/1609.05473.pdf)) in PyTorch. The code performs the experiment on synthetic data as described in the paper. 3 | 4 | ## Usage 5 | ``` 6 | $ python main.py 7 | ``` 8 | Please refer to ```main.py``` for supported arguments. You can also change model parameters there. 9 | 10 | ## Dependency 11 | * PyTorch 0.4.0+ (1.0 ready) 12 | * Python 3.5+ 13 | * CUDA 8.0+ & cuDNN (For GPU) 14 | * numpy 15 | 16 | ## Hacks and Observations 17 | - Using Adam for Generator and SGD for Discriminator 18 | - Discriminator should neither be trained too powerful (fail to provide useful feedback) nor too ill-performed (randomly guessing, unable to guide generation) 19 | - The GAN phase may not always lead to massive drops in NLL (sometimes very minimal or even increases NLL) 20 | 21 | ## Sample Learning Curve 22 | Learning curve of generator obtained after MLE training for 120 steps (1 epoch per round) followed by adversarial training for 150 rounds (1 epoch per round): 23 | 24 | ![alt tag](https://raw.githubusercontent.com/X-czh/SeqGAN-PyTorch/master/gen_loss.png) 25 | 26 | Learning curve of discriminator obtained after MLE training for 50 steps (3 epochs per step) followed by adversarial training for 150 rounds (9 epoch per round): 27 | 28 | ![alt tag](https://raw.githubusercontent.com/X-czh/SeqGAN-PyTorch/master/dis_loss.png) 29 | 30 | 31 | ## Acknowledgement 32 | This code is based on Zhao Zijian's [SeqGAN-PyTorch](https://github.com/ZiJianZhao/SeqGAN-PyTorch), Surag Nair's [SeqGAN](https://github.com/suragnair/seqGAN) and Lantao 
Yu's original [implementation](https://github.com/LantaoYu/SeqGAN) in Tensorflow. Many thanks to [Zhao Zijian](https://github.com/ZiJianZhao), [Surag Nair](https://github.com/suragnair) and [Lantao Yu](https://github.com/LantaoYu)! 33 | -------------------------------------------------------------------------------- /rollout.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import numpy as np 3 | 4 | import torch 5 | import torch.nn as nn 6 | import torch.optim as optim 7 | 8 | 9 | class Rollout(object): 10 | """ Rollout Policy """ 11 | 12 | def __init__(self, model, update_rate): 13 | self.ori_model = model 14 | self.own_model = copy.deepcopy(model) 15 | self.update_rate = update_rate 16 | 17 | def get_reward(self, x, num, discriminator): 18 | """ 19 | Inputs: x, num, discriminator 20 | - x: (batch_size, seq_len) input data 21 | - num: rollout number 22 | - discriminator: discrimanator model 23 | """ 24 | rewards = [] 25 | batch_size = x.size(0) 26 | seq_len = x.size(1) 27 | for i in range(num): 28 | for l in range(1, seq_len): 29 | data = x[:, 0:l] 30 | samples = self.own_model.sample(batch_size, seq_len, data) 31 | pred = discriminator(samples) 32 | pred = pred.cpu().data[:,1].numpy() 33 | if i == 0: 34 | rewards.append(pred) 35 | else: 36 | rewards[l-1] += pred 37 | 38 | # for the last token 39 | pred = discriminator(x) 40 | pred = pred.cpu().data[:, 1].numpy() 41 | if i == 0: 42 | rewards.append(pred) 43 | else: 44 | rewards[seq_len-1] += pred 45 | rewards = np.transpose(np.array(rewards)) / (1.0 * num) # batch_size * seq_len 46 | return rewards 47 | 48 | def update_params(self): 49 | dic = {} 50 | for name, param in self.ori_model.named_parameters(): 51 | dic[name] = param.data 52 | for name, param in self.own_model.named_parameters(): 53 | if name.startswith('emb'): 54 | param.data = dic[name] 55 | else: 56 | param.data = self.update_rate * param.data + (1 - self.update_rate) * dic[name] 57 | 
-------------------------------------------------------------------------------- /target_lstm.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | class TargetLSTM(nn.Module): 7 | """ Target LSTM """ 8 | 9 | def __init__(self, vocab_size, embedding_dim, hidden_dim, use_cuda): 10 | super(TargetLSTM, self).__init__() 11 | self.hidden_dim = hidden_dim 12 | self.use_cuda = use_cuda 13 | self.embed = nn.Embedding(vocab_size, embedding_dim) 14 | self.lstm = nn.LSTM(embedding_dim, hidden_dim, batch_first=True) 15 | self.fc = nn.Linear(hidden_dim, vocab_size) 16 | self.log_softmax = nn.LogSoftmax(dim=1) 17 | self.init_params() 18 | 19 | def forward(self, x): 20 | """ 21 | Embeds input and applies LSTM on the input sequence. 22 | 23 | Inputs: x 24 | - x: (batch_size, seq_len), sequence of tokens generated by generator 25 | Outputs: out 26 | - out: (batch_size, vocab_size), lstm output prediction 27 | """ 28 | self.lstm.flatten_parameters() 29 | h0, c0 = self.init_hidden(x.size(0)) 30 | emb = self.embed(x) # batch_size * seq_len * emb_dim 31 | out, _ = self.lstm(emb, (h0, c0)) # out: seq_len * batch_size * hidden_dim 32 | out = self.log_softmax(self.fc(out.contiguous().view(-1, self.hidden_dim))) # seq_len * batch_size * vocab_size 33 | return out 34 | 35 | def step(self, x, h, c): 36 | """ 37 | Embeds input and applies LSTM one token at a time (seq_len = 1). 
38 | 39 | Inputs: x, h, c 40 | - x: (batch_size, 1), sequence of tokens generated by generator 41 | - h: (1, batch_size, hidden_dim), lstm hidden state 42 | - c: (1, batch_size, hidden_dim), lstm cell state 43 | Outputs: out, h, c 44 | - out: (batch_size, 1, vocab_size), lstm output prediction 45 | - h: (1, batch_size, hidden_dim), lstm hidden state 46 | - c: (1, batch_size, hidden_dim), lstm cell state 47 | """ 48 | self.lstm.flatten_parameters() 49 | emb = self.embed(x) # batch_size * 1 * emb_dim 50 | out, (h, c) = self.lstm(emb, (h, c)) # out: batch_size * 1 * hidden_dim 51 | out = self.log_softmax(self.fc(out.contiguous().view(-1, self.hidden_dim))) # batch_size * vocab_size 52 | return out, h, c 53 | 54 | def init_hidden(self, batch_size): 55 | h = torch.zeros((1, batch_size, self.hidden_dim)) 56 | c = torch.zeros((1, batch_size, self.hidden_dim)) 57 | if self.use_cuda: 58 | h, c = h.cuda(), c.cuda() 59 | return h, c 60 | 61 | def init_params(self): 62 | for param in self.parameters(): 63 | param.data.normal_(0, 1) 64 | 65 | def sample(self, batch_size, seq_len): 66 | """ 67 | Samples the network and returns a batch of samples of length seq_len. 
68 | 69 | Outputs: out 70 | - out: (batch_size * seq_len) 71 | """ 72 | samples = [] 73 | h, c = self.init_hidden(batch_size) 74 | x = torch.zeros(batch_size, 1, dtype=torch.int64) 75 | if self.use_cuda: 76 | x = x.cuda() 77 | for _ in range(seq_len): 78 | out, h, c = self.step(x, h, c) 79 | prob = torch.exp(out) 80 | x = torch.multinomial(prob, 1) 81 | samples.append(x) 82 | out = torch.cat(samples, dim=1) # along the batch_size dimension 83 | return out 84 | -------------------------------------------------------------------------------- /data_iter.py: -------------------------------------------------------------------------------- 1 | import math 2 | import random 3 | import torch 4 | 5 | 6 | class GenDataIter: 7 | """ Toy data iter to load digits """ 8 | 9 | def __init__(self, data_file, batch_size): 10 | super(GenDataIter, self).__init__() 11 | self.batch_size = batch_size 12 | self.data_lis = self.read_file(data_file) 13 | self.data_num = len(self.data_lis) 14 | self.indices = range(self.data_num) 15 | self.num_batches = math.ceil(self.data_num / self.batch_size) 16 | self.idx = 0 17 | self.reset() 18 | 19 | def __len__(self): 20 | return self.num_batches 21 | 22 | def __iter__(self): 23 | return self 24 | 25 | def __next__(self): 26 | return self.next() 27 | 28 | def reset(self): 29 | self.idx = 0 30 | random.shuffle(self.data_lis) 31 | 32 | def next(self): 33 | if self.idx >= self.data_num: 34 | raise StopIteration 35 | index = self.indices[self.idx : self.idx + self.batch_size] 36 | d = [self.data_lis[i] for i in index] 37 | d = torch.tensor(d) 38 | 39 | # 0 is prepended to d as start symbol 40 | data = torch.cat([torch.zeros(len(index), 1, dtype=torch.int64), d], dim=1) 41 | target = torch.cat([d, torch.zeros(len(index), 1, dtype=torch.int64)], dim=1) 42 | 43 | self.idx += self.batch_size 44 | return data, target 45 | 46 | def read_file(self, data_file): 47 | with open(data_file, 'r') as f: 48 | lines = f.readlines() 49 | lis = [] 50 | for line in lines: 
51 | l = [int(s) for s in list(line.strip().split())] 52 | lis.append(l) 53 | return lis 54 | 55 | 56 | class DisDataIter: 57 | """ Toy data iter to load digits """ 58 | 59 | def __init__(self, real_data_file, fake_data_file, batch_size): 60 | super(DisDataIter, self).__init__() 61 | self.batch_size = batch_size 62 | real_data_lis = self.read_file(real_data_file) 63 | fake_data_lis = self.read_file(fake_data_file) 64 | self.data = real_data_lis + fake_data_lis 65 | self.labels = [1 for _ in range(len(real_data_lis))] +\ 66 | [0 for _ in range(len(fake_data_lis))] 67 | self.pairs = list(zip(self.data, self.labels)) 68 | self.data_num = len(self.pairs) 69 | self.indices = range(self.data_num) 70 | self.num_batches = math.ceil(self.data_num / self.batch_size) 71 | self.idx = 0 72 | self.reset() 73 | 74 | def __len__(self): 75 | return self.num_batches 76 | 77 | def __iter__(self): 78 | return self 79 | 80 | def __next__(self): 81 | return self.next() 82 | 83 | def reset(self): 84 | self.idx = 0 85 | random.shuffle(self.pairs) 86 | 87 | def next(self): 88 | if self.idx >= self.data_num: 89 | raise StopIteration 90 | index = self.indices[self.idx : self.idx + self.batch_size] 91 | pairs = [self.pairs[i] for i in index] 92 | data = [p[0] for p in pairs] 93 | label = [p[1] for p in pairs] 94 | data = torch.tensor(data) 95 | label = torch.tensor(label) 96 | self.idx += self.batch_size 97 | return data, label 98 | 99 | def read_file(self, data_file): 100 | with open(data_file, 'r') as f: 101 | lines = f.readlines() 102 | lis = [] 103 | for line in lines: 104 | l = [int(s) for s in list(line.strip().split())] 105 | lis.append(l) 106 | return lis 107 | -------------------------------------------------------------------------------- /generator.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | class Generator(nn.Module): 7 | """ Generator """ 8 | 9 | def 
__init__(self, vocab_size, embedding_dim, hidden_dim, use_cuda): 10 | super(Generator, self).__init__() 11 | self.hidden_dim = hidden_dim 12 | self.use_cuda = use_cuda 13 | self.embed = nn.Embedding(vocab_size, embedding_dim) 14 | self.lstm = nn.LSTM(embedding_dim, hidden_dim, batch_first=True) 15 | self.fc = nn.Linear(hidden_dim, vocab_size) 16 | self.log_softmax = nn.LogSoftmax(dim=1) 17 | self.init_params() 18 | 19 | def forward(self, x): 20 | """ 21 | Embeds input and applies LSTM on the input sequence. 22 | 23 | Inputs: x 24 | - x: (batch_size, seq_len), sequence of tokens generated by generator 25 | Outputs: out 26 | - out: (batch_size * seq_len, vocab_size), lstm output prediction 27 | """ 28 | self.lstm.flatten_parameters() 29 | h0, c0 = self.init_hidden(x.size(0)) 30 | emb = self.embed(x) # batch_size * seq_len * emb_dim 31 | out, _ = self.lstm(emb, (h0, c0)) # out: batch_size * seq_len * hidden_dim 32 | out = self.log_softmax(self.fc(out.contiguous().view(-1, self.hidden_dim))) # (batch_size*seq_len) * vocab_size 33 | return out 34 | 35 | def step(self, x, h, c): 36 | """ 37 | Embeds input and applies LSTM one token at a time (seq_len = 1). 
38 | 39 | Inputs: x, h, c 40 | - x: (batch_size, 1), sequence of tokens generated by generator 41 | - h: (1, batch_size, hidden_dim), lstm hidden state 42 | - c: (1, batch_size, hidden_dim), lstm cell state 43 | Outputs: out, h, c 44 | - out: (batch_size, vocab_size), lstm output prediction 45 | - h: (1, batch_size, hidden_dim), lstm hidden state 46 | - c: (1, batch_size, hidden_dim), lstm cell state 47 | """ 48 | self.lstm.flatten_parameters() 49 | emb = self.embed(x) # batch_size * 1 * emb_dim 50 | out, (h, c) = self.lstm(emb, (h, c)) # out: batch_size * 1 * hidden_dim 51 | out = self.log_softmax(self.fc(out.contiguous().view(-1, self.hidden_dim))) # batch_size * vocab_size 52 | return out, h, c 53 | 54 | def init_hidden(self, batch_size): 55 | h = torch.zeros(1, batch_size, self.hidden_dim) 56 | c = torch.zeros(1, batch_size, self.hidden_dim) 57 | if self.use_cuda: 58 | h, c = h.cuda(), c.cuda() 59 | return h, c 60 | 61 | def init_params(self): 62 | for param in self.parameters(): 63 | param.data.uniform_(-0.05, 0.05) 64 | 65 | def sample(self, batch_size, seq_len, x=None): 66 | """ 67 | Samples the network and returns a batch of samples of length seq_len. 
68 | 69 | Outputs: out 70 | - out: (batch_size * seq_len) 71 | """ 72 | samples = [] 73 | if x is None: 74 | h, c = self.init_hidden(batch_size) 75 | x = torch.zeros(batch_size, 1, dtype=torch.int64) 76 | if self.use_cuda: 77 | x = x.cuda() 78 | for _ in range(seq_len): 79 | out, h, c = self.step(x, h, c) 80 | prob = torch.exp(out) 81 | x = torch.multinomial(prob, 1) 82 | samples.append(x) 83 | else: 84 | h, c = self.init_hidden(x.size(0)) 85 | given_len = x.size(1) 86 | lis = x.chunk(x.size(1), dim=1) 87 | for i in range(given_len): 88 | out, h, c = self.step(lis[i], h, c) 89 | samples.append(lis[i]) 90 | prob = torch.exp(out) 91 | x = torch.multinomial(prob, 1) 92 | for _ in range(given_len, seq_len): 93 | samples.append(x) 94 | out, h, c = self.step(x, h, c) 95 | prob = torch.exp(out) 96 | x = torch.multinomial(prob, 1) 97 | out = torch.cat(samples, dim=1) # along the batch_size dimension 98 | return out 99 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import pickle as pkl 3 | 4 | import torch 5 | import torch.nn as nn 6 | import torch.optim as optim 7 | import torch.backends.cudnn as cudnn 8 | 9 | from data_iter import DisDataIter, GenDataIter 10 | from generator import Generator 11 | from discriminator import Discriminator 12 | from target_lstm import TargetLSTM 13 | from rollout import Rollout 14 | from loss import PGLoss 15 | 16 | 17 | # Arguemnts 18 | parser = argparse.ArgumentParser(description='SeqGAN') 19 | parser.add_argument('--hpc', action='store_true', default=False, 20 | help='set to hpc mode') 21 | parser.add_argument('--data_path', type=str, default='/scratch/zc807/seq_gan/', metavar='PATH', 22 | help='data path to save files (default: /scratch/zc807/seq_gan/)') 23 | parser.add_argument('--rounds', type=int, default=150, metavar='N', 24 | help='rounds of adversarial training (default: 150)') 25 | 
parser.add_argument('--g_pretrain_steps', type=int, default=120, metavar='N', 26 | help='steps of pre-training of generators (default: 120)') 27 | parser.add_argument('--d_pretrain_steps', type=int, default=50, metavar='N', 28 | help='steps of pre-training of discriminators (default: 50)') 29 | parser.add_argument('--g_steps', type=int, default=1, metavar='N', 30 | help='steps of generator updates in one round of adverarial training (default: 1)') 31 | parser.add_argument('--d_steps', type=int, default=3, metavar='N', 32 | help='steps of discriminator updates in one round of adverarial training (default: 3)') 33 | parser.add_argument('--gk_epochs', type=int, default=1, metavar='N', 34 | help='epochs of generator updates in one step of generate update (default: 1)') 35 | parser.add_argument('--dk_epochs', type=int, default=3, metavar='N', 36 | help='epochs of discriminator updates in one step of discriminator update (default: 3)') 37 | parser.add_argument('--update_rate', type=float, default=0.8, metavar='UR', 38 | help='update rate of roll-out model (default: 0.8)') 39 | parser.add_argument('--n_rollout', type=int, default=16, metavar='N', 40 | help='number of roll-out (default: 16)') 41 | parser.add_argument('--vocab_size', type=int, default=10, metavar='N', 42 | help='vocabulary size (default: 10)') 43 | parser.add_argument('--batch_size', type=int, default=64, metavar='N', 44 | help='batch size (default: 64)') 45 | parser.add_argument('--n_samples', type=int, default=6400, metavar='N', 46 | help='number of samples gerenated per time (default: 6400)') 47 | parser.add_argument('--gen_lr', type=float, default=1e-3, metavar='LR', 48 | help='learning rate of generator optimizer (default: 1e-3)') 49 | parser.add_argument('--dis_lr', type=float, default=1e-3, metavar='LR', 50 | help='learning rate of discriminator optimizer (default: 1e-3)') 51 | parser.add_argument('--no_cuda', action='store_true', default=False, 52 | help='disables CUDA training') 53 | 
parser.add_argument('--seed', type=int, default=1, metavar='S', 54 | help='random seed (default: 1)') 55 | 56 | 57 | # Files 58 | POSITIVE_FILE = 'real.data' 59 | NEGATIVE_FILE = 'gene.data' 60 | 61 | 62 | # Genrator Parameters 63 | g_embed_dim = 32 64 | g_hidden_dim = 32 65 | g_seq_len = 20 66 | 67 | 68 | # Discriminator Parameters 69 | d_num_class = 2 70 | d_embed_dim = 64 71 | d_filter_sizes = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 15, 20] 72 | d_num_filters = [100, 200, 200, 200, 200, 100, 100, 100, 100, 100, 160, 160] 73 | d_dropout_prob = 0.2 74 | 75 | 76 | def generate_samples(model, batch_size, generated_num, output_file): 77 | samples = [] 78 | for _ in range(int(generated_num / batch_size)): 79 | sample = model.sample(batch_size, g_seq_len).cpu().data.numpy().tolist() 80 | samples.extend(sample) 81 | with open(output_file, 'w') as fout: 82 | for sample in samples: 83 | string = ' '.join([str(s) for s in sample]) 84 | fout.write('{}\n'.format(string)) 85 | 86 | 87 | def train_generator_MLE(gen, data_iter, criterion, optimizer, epochs, 88 | gen_pretrain_train_loss, args): 89 | """ 90 | Train generator with MLE 91 | """ 92 | for epoch in range(epochs): 93 | total_loss = 0. 
94 | for data, target in data_iter: 95 | if args.cuda: 96 | data, target = data.cuda(), target.cuda() 97 | target = target.contiguous().view(-1) 98 | output = gen(data) 99 | loss = criterion(output, target) 100 | total_loss += loss.item() 101 | optimizer.zero_grad() 102 | loss.backward() 103 | optimizer.step() 104 | data_iter.reset() 105 | avg_loss = total_loss / len(data_iter) 106 | print("Epoch {}, train loss: {:.5f}".format(epoch, avg_loss)) 107 | gen_pretrain_train_loss.append(avg_loss) 108 | 109 | 110 | def train_generator_PG(gen, dis, rollout, pg_loss, optimizer, epochs, args): 111 | """ 112 | Train generator with the guidance of policy gradient 113 | """ 114 | for epoch in range(epochs): 115 | # construct the input to the genrator, add zeros before samples and delete the last column 116 | samples = generator.sample(args.batch_size, g_seq_len) 117 | zeros = torch.zeros(args.batch_size, 1, dtype=torch.int64) 118 | if samples.is_cuda: 119 | zeros = zeros.cuda() 120 | inputs = torch.cat([zeros, samples.data], dim = 1)[:, :-1].contiguous() 121 | targets = samples.data.contiguous().view((-1,)) 122 | 123 | # calculate the reward 124 | rewards = torch.tensor(rollout.get_reward(samples, args.n_rollout, dis)) 125 | if args.cuda: 126 | rewards = rewards.cuda() 127 | 128 | # update generator 129 | output = gen(inputs) 130 | loss = pg_loss(output, targets, rewards) 131 | optimizer.zero_grad() 132 | loss.backward() 133 | optimizer.step() 134 | 135 | 136 | def eval_generator(model, data_iter, criterion, args): 137 | """ 138 | Evaluate generator with NLL 139 | """ 140 | total_loss = 0. 
141 | with torch.no_grad(): 142 | for data, target in data_iter: 143 | if args.cuda: 144 | data, target = data.cuda(), target.cuda() 145 | target = target.contiguous().view(-1) 146 | pred = model(data) 147 | loss = criterion(pred, target) 148 | total_loss += loss.item() 149 | avg_loss = total_loss / len(data_iter) 150 | return avg_loss 151 | 152 | 153 | def train_discriminator(dis, gen, criterion, optimizer, epochs, 154 | dis_adversarial_train_loss, dis_adversarial_train_acc, args): 155 | """ 156 | Train discriminator 157 | """ 158 | generate_samples(gen, args.batch_size, args.n_samples, NEGATIVE_FILE) 159 | data_iter = DisDataIter(POSITIVE_FILE, NEGATIVE_FILE, args.batch_size) 160 | for epoch in range(epochs): 161 | correct = 0 162 | total_loss = 0. 163 | for data, target in data_iter: 164 | if args.cuda: 165 | data, target = data.cuda(), target.cuda() 166 | target = target.contiguous().view(-1) 167 | output = dis(data) 168 | pred = output.data.max(1)[1] 169 | correct += pred.eq(target.data).cpu().sum() 170 | loss = criterion(output, target) 171 | total_loss += loss.item() 172 | optimizer.zero_grad() 173 | loss.backward() 174 | optimizer.step() 175 | data_iter.reset() 176 | avg_loss = total_loss / len(data_iter) 177 | acc = correct.item() / data_iter.data_num 178 | print("Epoch {}, train loss: {:.5f}, train acc: {:.3f}".format(epoch, avg_loss, acc)) 179 | dis_adversarial_train_loss.append(avg_loss) 180 | dis_adversarial_train_acc.append(acc) 181 | 182 | 183 | def eval_discriminator(model, data_iter, criterion, args): 184 | """ 185 | Evaluate discriminator, dropout is enabled 186 | """ 187 | correct = 0 188 | total_loss = 0. 
189 | with torch.no_grad(): 190 | for data, target in data_iter: 191 | if args.cuda: 192 | data, target = data.cuda(), target.cuda() 193 | target = target.contiguous().view(-1) 194 | output = model(data) 195 | pred = output.data.max(1)[1] 196 | correct += pred.eq(target.data).cpu().sum() 197 | loss = criterion(output, target) 198 | total_loss += loss.item() 199 | avg_loss = total_loss / len(data_iter) 200 | acc = correct.item() / data_iter.data_num 201 | return avg_loss, acc 202 | 203 | 204 | def adversarial_train(gen, dis, rollout, pg_loss, nll_loss, gen_optimizer, dis_optimizer, 205 | dis_adversarial_train_loss, dis_adversarial_train_acc, args): 206 | """ 207 | Adversarially train generator and discriminator 208 | """ 209 | # train generator for g_steps 210 | print("#Train generator") 211 | for i in range(args.g_steps): 212 | print("##G-Step {}".format(i)) 213 | train_generator_PG(gen, dis, rollout, pg_loss, gen_optimizer, args.gk_epochs, args) 214 | 215 | # train discriminator for d_steps 216 | print("#Train discriminator") 217 | for i in range(args.d_steps): 218 | print("##D-Step {}".format(i)) 219 | train_discriminator(dis, gen, nll_loss, dis_optimizer, args.dk_epochs, 220 | dis_adversarial_train_loss, dis_adversarial_train_acc, args) 221 | 222 | # update roll-out model 223 | rollout.update_params() 224 | 225 | 226 | if __name__ == '__main__': 227 | # Parse arguments 228 | args = parser.parse_args() 229 | args.cuda = not args.no_cuda and torch.cuda.is_available() 230 | torch.manual_seed(args.seed) 231 | if args.cuda: 232 | torch.cuda.manual_seed(args.seed) 233 | if not args.hpc: 234 | args.data_path = '' 235 | POSITIVE_FILE = args.data_path + POSITIVE_FILE 236 | NEGATIVE_FILE = args.data_path + NEGATIVE_FILE 237 | 238 | # Set models, criteria, optimizers 239 | generator = Generator(args.vocab_size, g_embed_dim, g_hidden_dim, args.cuda) 240 | discriminator = Discriminator(d_num_class, args.vocab_size, d_embed_dim, d_filter_sizes, d_num_filters, d_dropout_prob) 
241 | target_lstm = TargetLSTM(args.vocab_size, g_embed_dim, g_hidden_dim, args.cuda) 242 | nll_loss = nn.NLLLoss() 243 | pg_loss = PGLoss() 244 | if args.cuda: 245 | generator = generator.cuda() 246 | discriminator = discriminator.cuda() 247 | target_lstm = target_lstm.cuda() 248 | nll_loss = nll_loss.cuda() 249 | pg_loss = pg_loss.cuda() 250 | cudnn.benchmark = True 251 | gen_optimizer = optim.Adam(params=generator.parameters(), lr=args.gen_lr) 252 | dis_optimizer = optim.SGD(params=discriminator.parameters(), lr=args.dis_lr) 253 | 254 | # Container of experiment data 255 | gen_pretrain_train_loss = [] 256 | gen_pretrain_eval_loss = [] 257 | dis_pretrain_train_loss = [] 258 | dis_pretrain_train_acc = [] 259 | dis_pretrain_eval_loss = [] 260 | dis_pretrain_eval_acc = [] 261 | gen_adversarial_eval_loss = [] 262 | dis_adversarial_train_loss = [] 263 | dis_adversarial_train_acc = [] 264 | dis_adversarial_eval_loss = [] 265 | dis_adversarial_eval_acc = [] 266 | 267 | # Generate toy data using target LSTM 268 | print('#####################################################') 269 | print('Generating data ...') 270 | print('#####################################################\n\n') 271 | generate_samples(target_lstm, args.batch_size, args.n_samples, POSITIVE_FILE) 272 | 273 | # Pre-train generator using MLE 274 | print('#####################################################') 275 | print('Start pre-training generator with MLE...') 276 | print('#####################################################\n') 277 | gen_data_iter = GenDataIter(POSITIVE_FILE, args.batch_size) 278 | for i in range(args.g_pretrain_steps): 279 | print("G-Step {}".format(i)) 280 | train_generator_MLE(generator, gen_data_iter, nll_loss, 281 | gen_optimizer, args.gk_epochs, gen_pretrain_train_loss, args) 282 | generate_samples(generator, args.batch_size, args.n_samples, NEGATIVE_FILE) 283 | eval_iter = GenDataIter(NEGATIVE_FILE, args.batch_size) 284 | gen_loss = eval_generator(target_lstm, eval_iter, 
nll_loss, args) 285 | gen_pretrain_eval_loss.append(gen_loss) 286 | print("eval loss: {:.5f}\n".format(gen_loss)) 287 | print('#####################################################\n\n') 288 | 289 | # Pre-train discriminator 290 | print('#####################################################') 291 | print('Start pre-training discriminator...') 292 | print('#####################################################\n') 293 | for i in range(args.d_pretrain_steps): 294 | print("D-Step {}".format(i)) 295 | train_discriminator(discriminator, generator, nll_loss, 296 | dis_optimizer, args.dk_epochs, dis_adversarial_train_loss, dis_adversarial_train_acc, args) 297 | generate_samples(generator, args.batch_size, args.n_samples, NEGATIVE_FILE) 298 | eval_iter = DisDataIter(POSITIVE_FILE, NEGATIVE_FILE, args.batch_size) 299 | dis_loss, dis_acc = eval_discriminator(discriminator, eval_iter, nll_loss, args) 300 | dis_pretrain_eval_loss.append(dis_loss) 301 | dis_pretrain_eval_acc.append(dis_acc) 302 | print("eval loss: {:.5f}, eval acc: {:.3f}\n".format(dis_loss, dis_acc)) 303 | print('#####################################################\n\n') 304 | 305 | # Adversarial training 306 | print('#####################################################') 307 | print('Start adversarial training...') 308 | print('#####################################################\n') 309 | rollout = Rollout(generator, args.update_rate) 310 | for i in range(args.rounds): 311 | print("Round {}".format(i)) 312 | adversarial_train(generator, discriminator, rollout, 313 | pg_loss, nll_loss, gen_optimizer, dis_optimizer, 314 | dis_adversarial_train_loss, dis_adversarial_train_acc, args) 315 | generate_samples(generator, args.batch_size, args.n_samples, NEGATIVE_FILE) 316 | gen_eval_iter = GenDataIter(NEGATIVE_FILE, args.batch_size) 317 | dis_eval_iter = DisDataIter(POSITIVE_FILE, NEGATIVE_FILE, args.batch_size) 318 | gen_loss = eval_generator(target_lstm, gen_eval_iter, nll_loss, args) 319 | 
gen_adversarial_eval_loss.append(gen_loss) 320 | dis_loss, dis_acc = eval_discriminator(discriminator, dis_eval_iter, nll_loss, args) 321 | dis_adversarial_eval_loss.append(dis_loss) 322 | dis_adversarial_eval_acc.append(dis_acc) 323 | print("gen eval loss: {:.5f}, dis eval loss: {:.5f}, dis eval acc: {:.3f}\n" 324 | .format(gen_loss, dis_loss, dis_acc)) 325 | 326 | # Save experiment data 327 | with open(args.data_path + 'experiment.pkl', 'wb') as f: 328 | pkl.dump( 329 | (gen_pretrain_train_loss, 330 | gen_pretrain_eval_loss, 331 | dis_pretrain_train_loss, 332 | dis_pretrain_train_acc, 333 | dis_pretrain_eval_loss, 334 | dis_pretrain_eval_acc, 335 | gen_adversarial_eval_loss, 336 | dis_adversarial_train_loss, 337 | dis_adversarial_train_acc, 338 | dis_adversarial_eval_loss, 339 | dis_adversarial_eval_acc), 340 | f, 341 | protocol=pkl.HIGHEST_PROTOCOL 342 | ) 343 | -------------------------------------------------------------------------------- /experiment_notebook.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import pickle as pkl\n", 10 | "import matplotlib.pyplot as plt\n", 11 | "%matplotlib inline" 12 | ] 13 | }, 14 | { 15 | "cell_type": "code", 16 | "execution_count": 2, 17 | "metadata": {}, 18 | "outputs": [], 19 | "source": [ 20 | "with open('experiment.pkl', 'rb') as f:\n", 21 | " log = pkl.load(f)" 22 | ] 23 | }, 24 | { 25 | "cell_type": "code", 26 | "execution_count": 3, 27 | "metadata": {}, 28 | "outputs": [], 29 | "source": [ 30 | "(gen_pretrain_train_loss,\n", 31 | " gen_pretrain_eval_loss,\n", 32 | " dis_pretrain_train_loss,\n", 33 | " dis_pretrain_train_acc,\n", 34 | " dis_pretrain_eval_loss,\n", 35 | " dis_pretrain_eval_acc,\n", 36 | " gen_adversarial_eval_loss,\n", 37 | " dis_adversarial_train_loss,\n", 38 | " dis_adversarial_train_acc,\n", 39 | " 
dis_adversarial_eval_loss,\n", 40 | " dis_adversarial_eval_acc) = log" 41 | ] 42 | }, 43 | { 44 | "cell_type": "markdown", 45 | "metadata": {}, 46 | "source": [ 47 | "# Training Settings" 48 | ] 49 | }, 50 | { 51 | "cell_type": "markdown", 52 | "metadata": {}, 53 | "source": [ 54 | "Same as the default values in the code" 55 | ] 56 | }, 57 | { 58 | "cell_type": "markdown", 59 | "metadata": {}, 60 | "source": [ 61 | "# Generator" 62 | ] 63 | }, 64 | { 65 | "cell_type": "code", 66 | "execution_count": 4, 67 | "metadata": {}, 68 | "outputs": [ 69 | { 70 | "data": { 71 | "text/plain": [ 72 | "Text(0,0.5,'Oracle NLL')" 73 | ] 74 | }, 75 | "execution_count": 4, 76 | "metadata": {}, 77 | "output_type": "execute_result" 78 | }, 79 | { 80 | "data": { 81 | "image/png": "\n", 82 | "text/plain": [ 83 | "" 84 | ] 85 | }, 86 | "metadata": {}, 87 | "output_type": "display_data" 88 | } 89 | ], 90 | "source": [ 91 | "x1 = range(len(gen_pretrain_eval_loss))\n", 92 | "plt.plot(x1, gen_pretrain_eval_loss)\n", 93 | "plt.xlabel('Epochs')\n", 94 | "plt.ylabel('Oracle NLL')" 95 | ] 96 | }, 97 | { 98 | "cell_type": "code", 99 | "execution_count": 5, 100 | "metadata": {}, 101 | "outputs": [ 102 | { 103 | "data": { 104 | "text/plain": [ 105 | "Text(0,0.5,'Oracle NLL')" 106 | ] 107 | }, 108 | "execution_count": 5, 109 | "metadata": {}, 110 | "output_type": "execute_result" 111 | }, 112 | { 113 | "data": { 114 | "image/png": "\n", 115 | "text/plain": [ 116 | "" 117 | ] 118 | }, 119 | "metadata": {}, 120 | "output_type": "display_data" 121 | } 122 | ], 123 | "source": [ 124 | "x2 = range(len(gen_adversarial_eval_loss))\n", 125 | "plt.plot(x2, gen_adversarial_eval_loss)\n", 126 | "plt.xlabel('Epochs')\n", 127 | "plt.ylabel('Oracle NLL')" 128 | ] 129 | }, 130 | { 131 | "cell_type": "code", 132 | "execution_count": 6, 133 | "metadata": {}, 134 | "outputs": [ 135 | { 136 | "data": { 137 | "image/png": "\n", 138 | "text/plain": [ 139 | "" 140 | ] 141 | }, 142 | "metadata": {}, 143 | "output_type":
"display_data" 144 | } 145 | ], 146 | "source": [ 147 | "x3 = range(len(gen_pretrain_eval_loss) + len(gen_adversarial_eval_loss))\n", 148 | "gen_loss = gen_pretrain_eval_loss + gen_adversarial_eval_loss\n", 149 | "plt.plot(x3, gen_loss)\n", 150 | "plt.xlabel('Epochs')\n", 151 | "plt.ylabel('Oracle NLL')\n", 152 | "plt.savefig('gen_loss.png',dpi=600)" 153 | ] 154 | }, 155 | { 156 | "cell_type": "markdown", 157 | "metadata": {}, 158 | "source": [ 159 | "# Discriminator" 160 | ] 161 | }, 162 | { 163 | "cell_type": "code", 164 | "execution_count": 7, 165 | "metadata": {}, 166 | "outputs": [ 167 | { 168 | "data": { 169 | "text/plain": [ 170 | "Text(0,0.5,'Oracle NLL')" 171 | ] 172 | }, 173 | "execution_count": 7, 174 | "metadata": {}, 175 | "output_type": "execute_result" 176 | }, 177 | { 178 | "data": { 179 | "image/png": "\n", 180 | "text/plain": [ 181 | "" 182 | ] 183 | }, 184 | "metadata": {}, 185 | "output_type": "display_data" 186 | } 187 | ], 188 | "source": [ 189 | "x4 = range(len(dis_pretrain_eval_loss))\n", 190 | "plt.plot(x4, dis_pretrain_eval_loss)\n", 191 | "plt.xlabel('Steps (3 epoch per step)')\n", 192 | "plt.ylabel('Oracle NLL')" 193 | ] 194 | }, 195 | { 196 | "cell_type": "code", 197 | "execution_count": 8, 198 | "metadata": {}, 199 | "outputs": [ 200 | { 201 | "data": { 202 | "text/plain": [ 203 | "Text(0,0.5,'Oracle NLL')" 204 | ] 205 | }, 206 | "execution_count": 8, 207 | "metadata": {}, 208 | "output_type": "execute_result" 209 | }, 210 | { 211 | "data": { 212 | "image/png": "\n", 213 | "text/plain": [ 214 | "" 215 | ] 216 | }, 217 | "metadata": {}, 218 | "output_type": "display_data" 219 | } 220 | ], 221 | "source": [ 222 | "x5 = range(len(dis_adversarial_eval_loss))\n", 223 | "plt.plot(x5, dis_adversarial_eval_loss)\n", 224 | "plt.xlabel('Rounds (9 epoch per round)')\n", 225 | "plt.ylabel('Oracle NLL')" 226 | ] 227 | }, 228 | { 229 | "cell_type": "code", 230 | "execution_count": 9, 231 | "metadata": {}, 232 | "outputs": [ 233 | { 234 | "data": { 
235 | "image/png": "\n", 236 | "text/plain": [ 237 | "" 238 | ] 239 | }, 240 | "metadata": {}, 241 | "output_type": "display_data" 242 | } 243 | ], 244 | "source": [ 245 | "x6 = range(len(dis_pretrain_eval_loss) + len(dis_adversarial_eval_loss))\n", 246 | "dis_loss = dis_pretrain_eval_loss + dis_adversarial_eval_loss\n", 247 | "plt.plot(x6, dis_loss)\n", 248 | "plt.xlabel('Steps/Rounds (3 epoch per step (x: 0-49), 9 epochs per round (x: 50-199))')\n", 249 | "plt.ylabel('Oracle NLL')\n", 250 | "plt.savefig('dis_loss.png',dpi=600)" 251 | ] 252 | } 253 | ], 254 | "metadata": { 255 | "kernelspec": { 256 | "display_name": "Python 3", 257 | "language": "python", 258 | "name": "python3" 259 | }, 260 | "language_info": { 261 | "codemirror_mode": { 262 | "name": "ipython", 263 | "version": 3 264 | }, 265 | "file_extension": ".py", 266 | "mimetype": "text/x-python", 267 | "name": "python", 268 | "nbconvert_exporter": "python", 269 | "pygments_lexer": "ipython3", 270 | "version": "3.5.2" 271 | } 272 | }, 273 | "nbformat": 4, 274 | "nbformat_minor": 2 275 | } 276 | --------------------------------------------------------------------------------