├── README.md ├── loss.py ├── .gitignore ├── discriminator.py ├── rollout.py ├── target_lstm.py ├── generator.py ├── data_iter.py └── main.py /README.md: -------------------------------------------------------------------------------- 1 | # SeqGAN-PyTorch 2 | An implementation of SeqGAN in PyTorch, following the original TensorFlow implementation. 3 | 4 | 5 | ## Tested with: 6 | * **PyTorch v1 Stable** 7 | * Python 3.6 8 | * CUDA 8.0 or newer (for GPU) 9 | 10 | ## Origin 11 | The idea comes from the paper [SeqGAN: Sequence Generative Adversarial Nets with Policy Gradient](https://arxiv.org/pdf/1609.05473.pdf) 12 | 13 | The code is rewritten in PyTorch, with the structure largely following the [TensorFlow implementation](https://github.com/LantaoYu/SeqGAN) 14 | 15 | ## Running 16 | ``` 17 | $ python main.py 18 | ``` 19 | After running this script, the results are printed to the terminal. You can change the parameters in ```main.py```. 20 | -------------------------------------------------------------------------------- /loss.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | import torch 4 | import torch.nn as nn 5 | from torch.autograd import Variable 6 | class NLLLoss(nn.Module): 7 | """Self-Defined NLLLoss Function 8 | 9 | Args: 10 | weight: Tensor (num_class, ) 11 | """ 12 | def __init__(self, weight): 13 | super(NLLLoss, self).__init__() 14 | self.weight = weight 15 | 16 | def forward(self, prob, target): 17 | """ 18 | Args: 19 | prob: (N, C) 20 | target : (N, ) 21 | """ 22 | N = target.size(0) 23 | C = prob.size(1) 24 | weight = Variable(self.weight).view((1, -1)) 25 | weight = weight.expand(N, C) # (N, C) 26 | if prob.is_cuda: 27 | weight = weight.cuda() 28 | prob = weight * prob 29 | 30 | one_hot = torch.zeros((N, C)) 31 | if prob.is_cuda: 32 | one_hot = one_hot.cuda() 33 | one_hot.scatter_(1, target.data.view((-1,1)), 1) 34 | one_hot = one_hot.type(torch.ByteTensor) 35 | one_hot = Variable(one_hot) 36 | if prob.is_cuda: 37 | one_hot = one_hot.cuda() 38 | loss = torch.masked_select(prob, one_hot) 39 | return -torch.sum(loss) 40 | 41 | 42 | 43 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | *.data 2 | 3 | # Byte-compiled / optimized / DLL files 4 | __pycache__/ 5 | *.py[cod] 6 | *$py.class 7 | 8 | # C extensions 9 | *.so 10 | 11 | # Distribution / packaging 12 | .Python 13 | env/ 14 | build/ 15 | develop-eggs/ 16 | dist/ 17 | downloads/ 18 | eggs/ 19 | .eggs/ 20 | lib/ 21 | lib64/ 22 | parts/ 23 | sdist/ 24 | var/ 25 | wheels/ 26 | *.egg-info/ 27 | .installed.cfg 28 | *.egg 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | .hypothesis/ 50 | 51 | # Translations 52 | *.mo 53 | *.pot 54 | 55 | # Django stuff: 56 | *.log 57 | local_settings.py 58 | 59 | # Flask stuff: 60 | instance/ 61 | .webassets-cache 62 | 63 | # Scrapy stuff: 64 | .scrapy 65 | 66 | # Sphinx documentation 67 | docs/_build/ 68 | 69 | # PyBuilder 70 | target/ 71 | 72 | # Jupyter Notebook 73 | .ipynb_checkpoints 74 | 75 | # pyenv 76 | .python-version 77 | 78 | # celery beat schedule file 79 | celerybeat-schedule 80 | 81 | # SageMath parsed files 82 | *.sage.py 83 | 84 | # dotenv 85 | .env 86 | 87 | # virtualenv 88 | .venv 89 | venv/ 90 | ENV/ 91 | 92 | # Spyder project settings 93 | .spyderproject 94 | .spyproject 95 | 96 | # Rope project settings 97 | .ropeproject 98 | 99 | # mkdocs documentation 100 | /site 101 | 102 | # mypy 103 | .mypy_cache/ 104 | -------------------------------------------------------------------------------- /discriminator.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | import os 4 | import random 5 | 6 | import numpy as np 7 | 8 | import torch 9 | import torch.nn as nn 10 | import torch.nn.functional as F 11 | 12 | class Discriminator(nn.Module): 13 | """A CNN for text classification 14 | 15 | architecture: Embedding >> Convolution >> Max-pooling >> Softmax 16 | """ 17 | 18 | def __init__(self, num_classes, vocab_size, emb_dim, filter_sizes, num_filters, dropout): 19 | super(Discriminator, self).__init__() 20 | self.emb = nn.Embedding(vocab_size, emb_dim) 21 | self.convs = nn.ModuleList([ 22 | nn.Conv2d(1, n, (f, emb_dim)) for (n, f) in zip(num_filters, filter_sizes) 23 | ]) 24 | self.highway = nn.Linear(sum(num_filters), sum(num_filters)) 25 | self.dropout = nn.Dropout(p=dropout) 26 | self.lin = nn.Linear(sum(num_filters), num_classes) 27 | self.softmax = nn.LogSoftmax() 28 | self.init_parameters() 29 | 30 | def forward(self, x): 31 | """ 32 | Args: 33 | x: (batch_size * seq_len) 34 | """ 35 | emb = self.emb(x).unsqueeze(1) # batch_size * 1 * seq_len * emb_dim 36 | convs = [F.relu(conv(emb)).squeeze(3) for conv in self.convs] # [batch_size * num_filter * length] 37 | pools = [F.max_pool1d(conv, conv.size(2)).squeeze(2) for conv in convs] # [batch_size * num_filter] 38 | pred = torch.cat(pools, 1) # batch_size * num_filters_sum 39 | highway = self.highway(pred) 40 | pred = torch.sigmoid(highway) * F.relu(highway) + (1. 
- torch.sigmoid(highway)) * pred 41 | pred = self.softmax(self.lin(self.dropout(pred))) 42 | return pred 43 | 44 | def init_parameters(self): 45 | for param in self.parameters(): 46 | param.data.uniform_(-0.05, 0.05) 47 | -------------------------------------------------------------------------------- /rollout.py: -------------------------------------------------------------------------------- 1 | # -*- coding:utf-8 -*- 2 | 3 | import os 4 | import random 5 | import math 6 | import copy 7 | 8 | import tqdm 9 | 10 | import numpy as np 11 | 12 | import torch 13 | import torch.nn as nn 14 | import torch.optim as optim 15 | from torch.autograd import Variable 16 | 17 | class Rollout(object): 18 | """Roll-out policy""" 19 | def __init__(self, model, update_rate): 20 | self.ori_model = model 21 | self.own_model = copy.deepcopy(model) 22 | self.update_rate = update_rate 23 | 24 | def get_reward(self, x, num, discriminator): 25 | """ 26 | Args: 27 | x : (batch_size, seq_len) input data 28 | num : roll-out number 29 | discriminator : discrimanator model 30 | """ 31 | rewards = [] 32 | batch_size = x.size(0) 33 | seq_len = x.size(1) 34 | for i in range(num): 35 | for l in range(1, seq_len): 36 | data = x[:, 0:l] 37 | samples = self.own_model.sample(batch_size, seq_len, data) 38 | pred = discriminator(samples) 39 | pred = pred.cpu().data[:,1].numpy() 40 | if i == 0: 41 | rewards.append(pred) 42 | else: 43 | rewards[l-1] += pred 44 | 45 | # for the last token 46 | pred = discriminator(x) 47 | pred = pred.cpu().data[:, 1].numpy() 48 | if i == 0: 49 | rewards.append(pred) 50 | else: 51 | rewards[seq_len-1] += pred 52 | rewards = np.transpose(np.array(rewards)) / (1.0 * num) # batch_size * seq_len 53 | return rewards 54 | 55 | def update_params(self): 56 | dic = {} 57 | for name, param in self.ori_model.named_parameters(): 58 | dic[name] = param.data 59 | for name, param in self.own_model.named_parameters(): 60 | if name.startswith('emb'): 61 | param.data = dic[name] 62 | else: 63 | param.data = self.update_rate * param.data + (1 - self.update_rate) * dic[name] 64 | -------------------------------------------------------------------------------- /target_lstm.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | import os 4 | import random 5 | 6 | import numpy as np 7 | 8 | import torch 9 | import torch.nn as nn 10 | import torch.nn.functional as F 11 | from torch.autograd import Variable 12 | 13 | class TargetLSTM(nn.Module): 14 | """Target Lstm """ 15 | def __init__(self, num_emb, emb_dim, hidden_dim, use_cuda): 16 | super(TargetLSTM, self).__init__() 17 | self.num_emb = num_emb 18 | self.emb_dim = emb_dim 19 | self.hidden_dim = hidden_dim 20 | self.use_cuda = use_cuda 21 | self.emb = nn.Embedding(num_emb, emb_dim) 22 | self.lstm = nn.LSTM(emb_dim, hidden_dim, batch_first=True) 23 | self.lin = nn.Linear(hidden_dim, num_emb) 24 | self.softmax = nn.LogSoftmax() 25 | self.init_params() 26 | 27 | def forward(self, x): 28 | """ 29 | Args: 30 | x: (batch_size, seq_len), sequence of tokens generated by generator 31 | """ 32 | emb = self.emb(x) 33 | h0, c0 = self.init_hidden(x.size(0)) 34 | output, (h, c) = self.lstm(emb, (h0, c0)) 35 | pred = self.softmax(self.lin(output.contiguous().view(-1, self.hidden_dim))) 36 | return pred 37 | 38 | def step(self, x, h, c): 39 | """ 40 | Args: 41 | x: (batch_size, 1), sequence of tokens generated by generator 42 | h: (1, batch_size, hidden_dim), lstm hidden state 43 | c: (1, batch_size, hidden_dim), lstm 
cell state 44 | """ 45 | emb = self.emb(x) 46 | output, (h, c) = self.lstm(emb, (h, c)) 47 | pred = F.softmax(self.lin(output.view(-1, self.hidden_dim)), dim=1) 48 | return pred, h, c 49 | 50 | 51 | def init_hidden(self, batch_size): 52 | h = Variable(torch.zeros((1, batch_size, self.hidden_dim))) 53 | c = Variable(torch.zeros((1, batch_size, self.hidden_dim))) 54 | if self.use_cuda: 55 | h, c = h.cuda(), c.cuda() 56 | return h, c 57 | 58 | def init_params(self): 59 | for param in self.parameters(): 60 | param.data.normal_(0, 1) 61 | 62 | def sample(self, batch_size, seq_len): 63 | res = [] 64 | with torch.no_grad(): 65 | x = Variable(torch.zeros((batch_size, 1)).long()) 66 | if self.use_cuda: 67 | x = x.cuda() 68 | h, c = self.init_hidden(batch_size) 69 | samples = [] 70 | for i in range(seq_len): 71 | output, h, c = self.step(x, h, c) 72 | x = output.multinomial(1) 73 | samples.append(x) 74 | output = torch.cat(samples, dim=1) 75 | return output 76 | return None 77 | -------------------------------------------------------------------------------- /generator.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | import os 4 | import random 5 | 6 | import numpy as np 7 | 8 | import torch 9 | import torch.nn as nn 10 | import torch.nn.functional as F 11 | from torch.autograd import Variable 12 | 13 | class Generator(nn.Module): 14 | """Generator """ 15 | def __init__(self, num_emb, emb_dim, hidden_dim, use_cuda): 16 | super(Generator, self).__init__() 17 | self.num_emb = num_emb 18 | self.emb_dim = emb_dim 19 | self.hidden_dim = hidden_dim 20 | self.use_cuda = use_cuda 21 | self.emb = nn.Embedding(num_emb, emb_dim) 22 | self.lstm = nn.LSTM(emb_dim, hidden_dim, batch_first=True) 23 | self.lin = nn.Linear(hidden_dim, num_emb) 24 | self.softmax = nn.LogSoftmax() 25 | self.init_params() 26 | 27 | def forward(self, x): 28 | """ 29 | Args: 30 | x: (batch_size, seq_len), sequence of tokens generated by generator 31 | """ 32 | emb = self.emb(x) 33 | h0, c0 = self.init_hidden(x.size(0)) 34 | output, (h, c) = self.lstm(emb, (h0, c0)) 35 | pred = self.softmax(self.lin(output.contiguous().view(-1, self.hidden_dim))) 36 | return pred 37 | 38 | def step(self, x, h, c): 39 | """ 40 | Args: 41 | x: (batch_size, 1), sequence of tokens generated by generator 42 | h: (1, batch_size, hidden_dim), lstm hidden state 43 | c: (1, batch_size, hidden_dim), lstm cell state 44 | """ 45 | emb = self.emb(x) 46 | output, (h, c) = self.lstm(emb, (h, c)) 47 | pred = F.softmax(self.lin(output.view(-1, self.hidden_dim)), dim=1) 48 | return pred, h, c 49 | 50 | 51 | def init_hidden(self, batch_size): 52 | h = Variable(torch.zeros((1, batch_size, self.hidden_dim))) 53 | c = Variable(torch.zeros((1, batch_size, self.hidden_dim))) 54 | if self.use_cuda: 55 | h, c = h.cuda(), c.cuda() 56 | return h, c 57 | 58 | def init_params(self): 59 | for param in self.parameters(): 60 | param.data.uniform_(-0.05, 0.05) 61 | 62 | def sample(self, batch_size, seq_len, x=None): 63 | res = [] 64 | flag = False # whether sample from zero 65 | if x is None: 66 | flag = True 67 | if flag: 68 | x = Variable(torch.zeros((batch_size, 1)).long()) 69 | if self.use_cuda: 70 | x = x.cuda() 71 | h, c = self.init_hidden(batch_size) 72 | samples = [] 73 | if flag: 74 | for i in range(seq_len): 75 | output, h, c = self.step(x, h, c) 76 | x = output.multinomial(1) 77 | samples.append(x) 78 | else: 79 | given_len = x.size(1) 80 | lis = x.chunk(x.size(1), dim=1) 81 | for i in range(given_len): 82 
| output, h, c = self.step(lis[i], h, c) 83 | samples.append(lis[i]) 84 | x = output.multinomial(1) 85 | for i in range(given_len, seq_len): 86 | samples.append(x) 87 | output, h, c = self.step(x, h, c) 88 | x = output.multinomial(1) 89 | output = torch.cat(samples, dim=1) 90 | return output 91 | -------------------------------------------------------------------------------- /data_iter.py: -------------------------------------------------------------------------------- 1 | # -*- coding:utf-8 -*- 2 | 3 | import os 4 | import random 5 | import math 6 | 7 | import tqdm 8 | 9 | import numpy as np 10 | import torch 11 | class GenDataIter(object): 12 | """ Toy data iter to load digits""" 13 | def __init__(self, data_file, batch_size): 14 | super(GenDataIter, self).__init__() 15 | self.batch_size = batch_size 16 | self.data_lis = self.read_file(data_file) 17 | self.data_num = len(self.data_lis) 18 | self.indices = range(self.data_num) 19 | self.num_batches = int(math.ceil(float(self.data_num)/self.batch_size)) 20 | self.idx = 0 21 | 22 | def __len__(self): 23 | return self.num_batches 24 | 25 | def __iter__(self): 26 | return self 27 | 28 | def __next__(self): 29 | return self.next() 30 | 31 | def reset(self): 32 | self.idx = 0 33 | random.shuffle(self.data_lis) 34 | 35 | def next(self): 36 | if self.idx >= self.data_num: 37 | raise StopIteration 38 | index = self.indices[self.idx:self.idx+self.batch_size] 39 | d = [self.data_lis[i] for i in index] 40 | d = torch.LongTensor(np.asarray(d, dtype='int64')) 41 | data = torch.cat([torch.zeros(self.batch_size, 1).long(), d], dim=1) 42 | target = torch.cat([d, torch.zeros(self.batch_size, 1).long()], dim=1) 43 | self.idx += self.batch_size 44 | return data, target 45 | 46 | def read_file(self, data_file): 47 | with open(data_file, 'r') as f: 48 | lines = f.readlines() 49 | lis = [] 50 | for line in lines: 51 | l = line.strip().split(' ') 52 | l = [int(s) for s in l] 53 | lis.append(l) 54 | return lis 55 | 56 | class DisDataIter(object): 57 | """ Toy data iter to load digits""" 58 | def __init__(self, real_data_file, fake_data_file, batch_size): 59 | super(DisDataIter, self).__init__() 60 | self.batch_size = batch_size 61 | real_data_lis = self.read_file(real_data_file) 62 | fake_data_lis = self.read_file(fake_data_file) 63 | self.data = real_data_lis + fake_data_lis 64 | self.labels = [1 for _ in range(len(real_data_lis))] +\ 65 | [0 for _ in range(len(fake_data_lis))] 66 | self.pairs = list(zip(self.data, self.labels)) 67 | self.data_num = len(self.pairs) 68 | self.indices = range(self.data_num) 69 | self.num_batches = int(math.ceil(float(self.data_num)/self.batch_size)) 70 | self.idx = 0 71 | 72 | def __len__(self): 73 | return self.num_batches 74 | 75 | def __iter__(self): 76 | return self 77 | 78 | def __next__(self): 79 | return self.next() 80 | 81 | def reset(self): 82 | self.idx = 0 83 | random.shuffle(self.pairs) 84 | 85 | def next(self): 86 | if self.idx >= self.data_num: 87 | raise StopIteration 88 | index = self.indices[self.idx:self.idx+self.batch_size] 89 | pairs = [self.pairs[i] for i in index] 90 | data = [p[0] for p in pairs] 91 | label = [p[1] for p in pairs] 92 | data = torch.LongTensor(np.asarray(data, dtype='int64')) 93 | label = torch.LongTensor(np.asarray(label, dtype='int64')) 94 | self.idx += self.batch_size 95 | return data, label 96 | 97 | def read_file(self, data_file): 98 | with open(data_file, 'r') as f: 99 | lines = f.readlines() 100 | lis = [] 101 | for line in lines: 102 | l = line.strip().split(' ') 103 | l = [int(s) 
for s in l] 104 | lis.append(l) 105 | return lis 106 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | # -*- coding:utf-8 -*- 2 | 3 | import os 4 | import random 5 | import math 6 | 7 | import argparse 8 | import tqdm 9 | 10 | import numpy as np 11 | 12 | import torch 13 | import torch.nn as nn 14 | import torch.optim as optim 15 | from torch.autograd import Variable 16 | 17 | from generator import Generator 18 | from discriminator import Discriminator 19 | from target_lstm import TargetLSTM 20 | from rollout import Rollout 21 | from data_iter import GenDataIter, DisDataIter 22 | # ================== Parameter Definition ================= 23 | 24 | parser = argparse.ArgumentParser(description='Training Parameter') 25 | parser.add_argument('--cuda', action='store', default=None, type=int) 26 | opt = parser.parse_args() 27 | print(opt) 28 | 29 | # Basic Training Paramters 30 | SEED = 88 31 | BATCH_SIZE = 64 32 | TOTAL_BATCH = 200 33 | GENERATED_NUM = 10000 34 | POSITIVE_FILE = 'real.data' 35 | NEGATIVE_FILE = 'gene.data' 36 | EVAL_FILE = 'eval.data' 37 | VOCAB_SIZE = 5000 38 | PRE_EPOCH_NUM = 120 39 | 40 | if opt.cuda is not None and opt.cuda >= 0: 41 | torch.cuda.set_device(opt.cuda) 42 | opt.cuda = True 43 | 44 | # Genrator Parameters 45 | g_emb_dim = 32 46 | g_hidden_dim = 32 47 | g_sequence_len = 20 48 | 49 | # Discriminator Parameters 50 | d_emb_dim = 64 51 | d_filter_sizes = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 15, 20] 52 | d_num_filters = [100, 200, 200, 200, 200, 100, 100, 100, 100, 100, 160, 160] 53 | 54 | d_dropout = 0.75 55 | d_num_class = 2 56 | 57 | 58 | 59 | def generate_samples(model, batch_size, generated_num, output_file): 60 | samples = [] 61 | for _ in range(int(generated_num / batch_size)): 62 | sample = model.sample(batch_size, g_sequence_len).cpu().data.numpy().tolist() 63 | samples.extend(sample) 64 | with open(output_file, 'w') as fout: 65 | for sample in samples: 66 | string = ' '.join([str(s) for s in sample]) 67 | fout.write('%s\n' % string) 68 | 69 | def train_epoch(model, data_iter, criterion, optimizer): 70 | total_loss = 0. 71 | total_words = 0. 72 | for (data, target) in data_iter:#tqdm( 73 | #data_iter, mininterval=2, desc=' - Training', leave=False): 74 | data = Variable(data) 75 | target = Variable(target) 76 | if opt.cuda: 77 | data, target = data.cuda(), target.cuda() 78 | target = target.contiguous().view(-1) 79 | pred = model.forward(data) 80 | loss = criterion(pred, target) 81 | total_loss += loss.item() 82 | total_words += data.size(0) * data.size(1) 83 | optimizer.zero_grad() 84 | loss.backward() 85 | optimizer.step() 86 | data_iter.reset() 87 | return math.exp(total_loss / total_words) 88 | 89 | def eval_epoch(model, data_iter, criterion): 90 | total_loss = 0. 91 | total_words = 0. 
92 | with torch.no_grad(): 93 | for (data, target) in data_iter:#tqdm( 94 | #data_iter, mininterval=2, desc=' - Training', leave=False): 95 | data = Variable(data) 96 | target = Variable(target) 97 | if opt.cuda: 98 | data, target = data.cuda(), target.cuda() 99 | target = target.contiguous().view(-1) 100 | pred = model.forward(data) 101 | loss = criterion(pred, target) 102 | total_loss += loss.item() 103 | total_words += data.size(0) * data.size(1) 104 | data_iter.reset() 105 | 106 | assert total_words > 0 # guard against division by zero below 107 | return math.exp(total_loss / total_words) 108 | 109 | class GANLoss(nn.Module): 110 | """Reward-refined NLLLoss function for adversarial training of the Generator""" 111 | def __init__(self): 112 | super(GANLoss, self).__init__() 113 | 114 | def forward(self, prob, target, reward): 115 | """ 116 | Args: 117 | prob: (N, C), torch Variable 118 | target : (N, ), torch Variable 119 | reward : (N, ), torch Variable 120 | """ 121 | N = target.size(0) 122 | C = prob.size(1) 123 | one_hot = torch.zeros((N, C)) 124 | if prob.is_cuda: 125 | one_hot = one_hot.cuda() 126 | one_hot.scatter_(1, target.data.view((-1,1)), 1) 127 | one_hot = one_hot.type(torch.ByteTensor) 128 | one_hot = Variable(one_hot) 129 | if prob.is_cuda: 130 | one_hot = one_hot.cuda() 131 | loss = torch.masked_select(prob, one_hot) 132 | loss = loss * reward 133 | loss = -torch.sum(loss) 134 | return loss 135 | 136 | 137 | def main(): 138 | random.seed(SEED) 139 | np.random.seed(SEED) 140 | 141 | # Define Networks 142 | generator = Generator(VOCAB_SIZE, g_emb_dim, g_hidden_dim, opt.cuda) 143 | discriminator = Discriminator(d_num_class, VOCAB_SIZE, d_emb_dim, d_filter_sizes, d_num_filters, d_dropout) 144 | target_lstm = TargetLSTM(VOCAB_SIZE, g_emb_dim, g_hidden_dim, opt.cuda) 145 | if opt.cuda: 146 | generator = generator.cuda() 147 | discriminator = discriminator.cuda() 148 | target_lstm = target_lstm.cuda() 149 | # Generate toy data using target lstm 150 | print('Generating data ...') 151 | generate_samples(target_lstm, BATCH_SIZE, GENERATED_NUM, POSITIVE_FILE) 152 | 153 | # Load data from file 154 | gen_data_iter = GenDataIter(POSITIVE_FILE, BATCH_SIZE) 155 | 156 | # Pretrain Generator using MLE 157 | gen_criterion = nn.NLLLoss(reduction='sum') 158 | gen_optimizer = optim.Adam(generator.parameters()) 159 | if opt.cuda: 160 | gen_criterion = gen_criterion.cuda() 161 | print('Pretrain with MLE ...') 162 | for epoch in range(PRE_EPOCH_NUM): 163 | loss = train_epoch(generator, gen_data_iter, gen_criterion, gen_optimizer) 164 | print('Epoch [%d] Model Loss: %f' % (epoch, loss)) 165 | generate_samples(generator, BATCH_SIZE, GENERATED_NUM, EVAL_FILE) 166 | eval_iter = GenDataIter(EVAL_FILE, BATCH_SIZE) 167 | loss = eval_epoch(target_lstm, eval_iter, gen_criterion) 168 | print('Epoch [%d] True Loss: %f' % (epoch, loss)) 169 | 170 | # Pretrain Discriminator 171 | dis_criterion = nn.NLLLoss(reduction='sum') 172 | dis_optimizer = optim.Adam(discriminator.parameters()) 173 | if opt.cuda: 174 | dis_criterion = dis_criterion.cuda() 175 | print('Pretrain Discriminator ...') 176 | for epoch in range(5): 177 | generate_samples(generator, BATCH_SIZE, GENERATED_NUM, NEGATIVE_FILE) 178 | dis_data_iter = DisDataIter(POSITIVE_FILE, NEGATIVE_FILE, BATCH_SIZE) 179 | for _ in range(3): 180 | loss = train_epoch(discriminator, dis_data_iter, dis_criterion, dis_optimizer) 181 | print('Epoch [%d], loss: %f' % (epoch, loss)) 182 | # Adversarial Training 183 | rollout = Rollout(generator, 0.8) 184 |
print('#####################################################') 185 | print('Start Adversarial Training...\n') 186 | gen_gan_loss = GANLoss() 187 | gen_gan_optm = optim.Adam(generator.parameters()) 188 | if opt.cuda: 189 | gen_gan_loss = gen_gan_loss.cuda() 190 | gen_criterion = nn.NLLLoss(reduction='sum') 191 | if opt.cuda: 192 | gen_criterion = gen_criterion.cuda() 193 | dis_criterion = nn.NLLLoss(reduction='sum') 194 | dis_optimizer = optim.Adam(discriminator.parameters()) 195 | if opt.cuda: 196 | dis_criterion = dis_criterion.cuda() 197 | for total_batch in range(TOTAL_BATCH): 198 | ## Train the generator for one step 199 | for it in range(1): 200 | samples = generator.sample(BATCH_SIZE, g_sequence_len) 201 | # construct the input to the generator: prepend zeros to samples and drop the last column 202 | zeros = torch.zeros((BATCH_SIZE, 1)).type(torch.LongTensor) 203 | if samples.is_cuda: 204 | zeros = zeros.cuda() 205 | inputs = Variable(torch.cat([zeros, samples.data], dim = 1)[:, :-1].contiguous()) 206 | targets = Variable(samples.data).contiguous().view((-1,)) 207 | # calculate the reward 208 | rewards = rollout.get_reward(samples, 16, discriminator) 209 | rewards = Variable(torch.Tensor(rewards)) 210 | rewards = torch.exp(rewards).contiguous().view((-1,)) 211 | if opt.cuda: 212 | rewards = rewards.cuda() 213 | prob = generator.forward(inputs) 214 | loss = gen_gan_loss(prob, targets, rewards) 215 | gen_gan_optm.zero_grad() 216 | loss.backward() 217 | gen_gan_optm.step() 218 | 219 | if total_batch % 1 == 0 or total_batch == TOTAL_BATCH - 1: 220 | generate_samples(generator, BATCH_SIZE, GENERATED_NUM, EVAL_FILE) 221 | eval_iter = GenDataIter(EVAL_FILE, BATCH_SIZE) 222 | loss = eval_epoch(target_lstm, eval_iter, gen_criterion) 223 | print('Batch [%d] True Loss: %f' % (total_batch, loss)) 224 | rollout.update_params() 225 | 226 | for _ in range(4): 227 | generate_samples(generator, BATCH_SIZE, GENERATED_NUM, NEGATIVE_FILE) 228 | dis_data_iter = DisDataIter(POSITIVE_FILE, NEGATIVE_FILE, BATCH_SIZE) 229 | for _ in range(2): 230 | loss = train_epoch(discriminator, dis_data_iter, dis_criterion, dis_optimizer) 231 | if __name__ == '__main__': 232 | main() 233 | --------------------------------------------------------------------------------
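
The adversarial update in `main.py` can be hard to follow in the dump above. The snippet below is a minimal, hypothetical smoke test (not part of the repository) that wires `Generator`, `Discriminator`, and `Rollout` together for a single policy-gradient step on toy sizes, CPU only. All sizes are invented; the loss uses `gather`, which selects the same entries as the one-hot `masked_select` inside `GANLoss`, and it omits the `torch.exp` that `main.py` applies to the rewards.

```python
# Hypothetical smoke test: one adversarial generator step with tiny, assumed hyperparameters.
import torch
from torch import optim

from generator import Generator
from discriminator import Discriminator
from rollout import Rollout

VOCAB, SEQ_LEN, BATCH = 50, 10, 8

gen = Generator(VOCAB, emb_dim=16, hidden_dim=16, use_cuda=False)
dis = Discriminator(num_classes=2, vocab_size=VOCAB, emb_dim=16,
                    filter_sizes=[2, 3], num_filters=[8, 8], dropout=0.5)
rollout = Rollout(gen, update_rate=0.8)
optimizer = optim.Adam(gen.parameters())

# Sample sequences, then build the shifted input (start token 0) and flat targets.
samples = gen.sample(BATCH, SEQ_LEN)                      # (BATCH, SEQ_LEN) token ids
zeros = torch.zeros((BATCH, 1), dtype=torch.long)
inputs = torch.cat([zeros, samples], dim=1)[:, :-1]
targets = samples.contiguous().view(-1)

# Monte-Carlo rollout rewards: 2 completions per prefix, scored by the discriminator.
rewards = torch.Tensor(rollout.get_reward(samples, 2, dis)).contiguous().view(-1)

# REINFORCE-style loss: negative sum of reward-weighted log-probs of the sampled tokens.
log_probs = gen(inputs)                                   # (BATCH * SEQ_LEN, VOCAB)
picked = log_probs.gather(1, targets.unsqueeze(1)).squeeze(1)
loss = -(picked * rewards).sum()

optimizer.zero_grad()
loss.backward()
optimizer.step()
print('one adversarial generator step, loss = %.3f' % loss.item())
```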
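
The data iterators in `data_iter.py` expect plain text files with one sequence of space-separated integer token ids per line, which is exactly what `generate_samples` writes. The short, hypothetical example below (file name and values invented) shows how `GenDataIter` builds MLE input/target pairs by prepending a start token 0 to the input and appending a padding 0 to the target.

```python
# Hypothetical illustration of the data format consumed by GenDataIter; 'toy.data' is invented.
from data_iter import GenDataIter

with open('toy.data', 'w') as f:
    f.write('3 7 2 9\n' * 4)          # four toy sequences of length 4

it = GenDataIter('toy.data', batch_size=2)
data, target = next(it)
print(data[0])    # tensor([0, 3, 7, 2, 9])  -- input starts with token 0
print(target[0])  # tensor([3, 7, 2, 9, 0])  -- target is the sequence shifted left
```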