├── .idea └── vcs.xml ├── README.md ├── cgan ├── cgan.py └── out │ ├── 366b924f53bba0.svg │ └── GAN output [Epoch 80].jpg ├── ct-gan ├── ct_gan.py ├── models.py └── out │ ├── 367444be15895e.svg │ └── GAN output [Epoch 34].jpg ├── dcgan ├── dcgan.py └── out │ ├── 366cebae21d91c.svg │ ├── GAN output [Epoch 358].jpg │ └── real_samples.png ├── gan ├── gan.py └── out │ ├── 366b962ff767a0.svg │ └── GAN output [Epoch 95].jpg ├── infogan ├── infogan.py └── out │ ├── (Nothing is varied) GAN output [Epoch 14].jpg │ ├── (Only vary c1) GAN output [Epoch 14].jpg │ ├── (Only vary c2) GAN output [Epoch 14].jpg │ └── 367d344b5db086.svg ├── lsgan ├── lsgan.py └── out │ ├── 3677c4a9714c1c.svg │ └── GAN output [Epoch 11].jpg ├── pix2pix ├── datasets.py ├── download_dataset.sh ├── models.py ├── out │ ├── 3685e66f775102.svg │ ├── 3685e7054e8534.svg │ ├── GAN output [Epoch 216].jpg │ └── GAN output [Epoch 240].jpg └── pix2pix.py ├── ralsgan ├── out │ ├── 36844a331a7460.svg │ ├── GAN output [Iteration 31616].png │ └── GAN output [Iteration 31680].png ├── preprocess_cat_dataset.py ├── ralsgan.py └── setting_up_script.sh ├── srgan ├── datasets.py ├── models.py ├── out │ ├── 36896236305c82.svg │ ├── GAN output [Epoch 6] (1).jpg │ └── GAN output [Epoch 6].jpg └── srgan.py ├── wgan-gp ├── models.py ├── out │ ├── 36737d5d4d1ebe.svg │ ├── 3673cb5f59f630.svg │ ├── GAN output [Epoch 105].jpg │ └── GAN output [Epoch 52].jpg └── wgan_gp.py └── wgan ├── dataset_loader.py ├── out ├── bn_366d939cd31210.svg ├── bn_GAN output [Epoch 99].jpg ├── no_bn_366d9efab2beac.svg ├── no_bn_GAN output [Epoch 99] (1).jpg └── real_samples.png └── wgan.py /.idea/vcs.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | -------------------------------------------------------------------------------- /cgan/cgan.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Although I tried to implement this exactly as it was described in the paper, 3 | I couldn't get over the many problems such as mode collapse, diverging generator etc with the described network. 4 | I wasn't able to stabilize the training with SGD. I don't understand why a 5 | maxout layer (already nonlinear) has to be followed by a relu layer. I also had to use batchnorm for stability (instead of dropout). 6 | This paper brings the simple (yet effective, especially if you implemented on more 7 | complex networks such as pix2pix or any visual conditional generation system) 8 | idea of concatenating condition (e.g., labels) to the input to get conditional outputs, 9 | yet I think has fundamental flaws in designing the architecture. 
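(For reference, in this implementation z and the one-hot label y are embedded separately and their
hidden representations are concatenated before the final layers; see Generator.forward and
Discriminator.forward below.)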
10 | ''' 11 | import argparse 12 | 13 | import torch 14 | from torch import nn, optim 15 | from torch.autograd.variable import Variable 16 | 17 | from torchvision import transforms, datasets 18 | from torch.utils.data import DataLoader 19 | 20 | parser = argparse.ArgumentParser() 21 | parser.add_argument('--n_epochs', type=int, default=500, help='number of epochs of training') 22 | parser.add_argument('--batch_size', type=int, default=100, help='size of the batches') 23 | parser.add_argument('--lr', type=float, default=0.0002, help='adam: learning rate') 24 | parser.add_argument('--b1', type=float, default=0.5, help='adam: beta 1') 25 | parser.add_argument('--b2', type=float, default=0.999, help='adam: beta 2') 26 | parser.add_argument('--latent_dim', type=int, default=100, help='dimensionality of the latent space') 27 | parser.add_argument('--img_size', type=int, default=28, help='size of each image dimension') 28 | parser.add_argument('--channels', type=int, default=1, help='number of image channels') 29 | parser.add_argument('--n_classes', type=int, default=10, help='number of classes (e.g., digits 0 ..9, 10 classes on mnist)') 30 | parser.add_argument('--display_port', type=int, default=8097, help='where to run the visdom for visualization? useful if running multiple visdom tabs') 31 | parser.add_argument('--display_server', type=str, default="http://localhost", help='visdom server of the web display') 32 | parser.add_argument('--sample_interval', type=int, default=512, help='interval betwen image samples') 33 | opt = parser.parse_args() 34 | 35 | 36 | try: 37 | import visdom 38 | vis = visdom.Visdom(server=opt.display_server, port=opt.display_port, raise_exceptions=True) # Create vis env. 39 | except ImportError: 40 | vis = None 41 | else: 42 | vis.close(None) # Clear all figures. 43 | 44 | 45 | img_dims = (opt.channels, opt.img_size, opt.img_size) 46 | n_features = opt.channels * opt.img_size * opt.img_size 47 | 48 | def weights_init_xavier(m): 49 | classname = m.__class__.__name__ 50 | if classname.find('Conv') != -1: 51 | nn.init.xavier_normal_(m.weight.data, gain=0.02) 52 | elif classname.find('Linear') != -1: 53 | nn.init.xavier_normal_(m.weight.data, gain=0.02) 54 | elif classname.find('BatchNorm1d') != -1: 55 | nn.init.normal_(m.weight.data, 1.0, 0.02) 56 | nn.init.constant_(m.bias.data, 0.0) 57 | 58 | class Generator(nn.Module): 59 | def __init__(self): 60 | super(Generator, self).__init__() 61 | 62 | # Map z & y (noise and label) into the hidden layer. 63 | # TO DO: How to run this with a function defined here? 64 | self.z_map = nn.Sequential( 65 | nn.Linear(opt.latent_dim, 200), 66 | nn.BatchNorm1d(200), 67 | nn.ReLU(inplace=True), 68 | ) 69 | self.y_map = nn.Sequential( 70 | nn.Linear(opt.n_classes, 1000), 71 | nn.BatchNorm1d(1000), 72 | nn.ReLU(inplace=True), 73 | ) 74 | self.zy_map = nn.Sequential( 75 | nn.Linear(1200, 1200), 76 | nn.BatchNorm1d(1200), 77 | nn.ReLU(inplace=True), 78 | ) 79 | 80 | self.model = nn.Sequential( 81 | nn.Linear(1200, n_features), 82 | nn.Tanh() 83 | ) 84 | # Tanh > Image values are between [-1, 1] 85 | 86 | 87 | def forward(self, z, y): 88 | zh = self.z_map(z) 89 | yh = self.y_map(y) 90 | zy = torch.cat((zh, yh), dim=1) # Combine noise and labels. 
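        # Shapes: zh is (N, 200) and yh is (N, 1000), so zy is (N, 1200), matching zy_map's input size.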
91 | zyh = self.zy_map(zy) 92 | x = self.model(zyh) 93 | x = x.view(x.size(0), *img_dims) 94 | return x 95 | 96 | class Discriminator(nn.Module): 97 | def __init__(self): 98 | super(Discriminator, self).__init__() 99 | 100 | self.model = nn.Sequential( 101 | nn.Linear(240, 1), 102 | nn.Sigmoid() 103 | ) 104 | 105 | # Imitating a 3d array by combining second and third dimensions via multiplication for maxout. 106 | self.x_map = nn.Sequential(nn.Linear(n_features, 240 * 5)) 107 | self.y_map = nn.Sequential(nn.Linear(opt.n_classes, 50 * 5)) 108 | self.j_map = nn.Sequential(nn.Linear(240 + 50, 240 * 4)) 109 | 110 | def forward(self, x, y): 111 | # maxout for x 112 | x = x.view(-1, n_features) 113 | x = self.x_map(x) 114 | x, _ = x.view(-1, 240, 5).max(dim=2) # pytorch outputs max values and indices 115 | # .. and y 116 | y = y.view(-1, opt.n_classes) 117 | y = self.y_map(y) 118 | y, _ = y.view(-1, 50, 5).max(dim=2) 119 | # joint maxout layer 120 | jmx = torch.cat((x, y), dim=1) 121 | jmx = self.j_map(jmx) 122 | jmx, _ = jmx.view(-1, 240, 4).max(dim=2) 123 | 124 | prob = self.model(jmx) 125 | return prob 126 | 127 | 128 | def mnist_data(): 129 | compose = transforms.Compose( 130 | [transforms.ToTensor(), 131 | transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) 132 | ]) 133 | output_dir = './data/mnist' 134 | return datasets.MNIST(root=output_dir, train=True, 135 | transform=compose, download=True) 136 | 137 | mnist = mnist_data() 138 | batch_iterator = DataLoader(mnist, shuffle=True, batch_size=opt.batch_size) # List, NCHW format. 139 | 140 | cuda = torch.cuda.is_available() 141 | Tensor = torch.cuda.FloatTensor if cuda else torch.FloatTensor 142 | LongTensor = torch.cuda.LongTensor if cuda else torch.LongTensor 143 | gan_loss = nn.BCELoss() 144 | 145 | generator = Generator() 146 | discriminator = Discriminator() 147 | 148 | optimizer_D = optim.Adam(discriminator.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2)) 149 | optimizer_G = optim.Adam(generator.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2)) 150 | 151 | 152 | # Loss record. 153 | g_losses = [] 154 | d_losses = [] 155 | epochs = [] 156 | loss_legend = ['Discriminator', 'Generator'] 157 | 158 | if cuda: 159 | generator = generator.cuda() 160 | discriminator = discriminator.cuda() 161 | 162 | # Weight initialization. 163 | generator.apply(weights_init_xavier) 164 | discriminator.apply(weights_init_xavier) 165 | 166 | 167 | for epoch in range(opt.n_epochs): 168 | print('Epoch {}'.format(epoch)) 169 | for i, (batch, labels) in enumerate(batch_iterator): 170 | 171 | real = Variable(Tensor(batch.size(0), 1).fill_(1), requires_grad=False) 172 | fake = Variable(Tensor(batch.size(0), 1).fill_(0), requires_grad=False) 173 | 174 | labels_onehot = Variable(Tensor(batch.size(0), opt.n_classes).zero_()) 175 | labels_ = labels.type(LongTensor) # Note: MNIST dataset is given as a longtensor, and we use .type() to feed it into cuda (if cuda available). 176 | labels_ = labels_.view(batch.size(0), 1) 177 | labels_onehot = labels_onehot.scatter_(1, labels_, 1) # As of 0.4, we don't have one-hot built-in function yet in pytorch. 178 | 179 | imgs_real = Variable(batch.type(Tensor)) 180 | noise = Variable(Tensor(batch.size(0), opt.latent_dim).normal_(0, 1)) 181 | imgs_fake = generator(noise, labels_onehot) 182 | 183 | # == Discriminator update == # 184 | optimizer_D.zero_grad() 185 | 186 | # A small reminder: given classes c, prob. p, - c*logp - (1-c)*log(1-p) is the BCE/GAN loss. 
187 | # Putting D(x) as p, x=real's class as 1, (..and same for fake with c=0) we arrive to BCE loss. 188 | # This is intuitively how well the discriminator can distinguish btw real & fake. 189 | d_loss = gan_loss(discriminator(imgs_real, labels_onehot), real) + \ 190 | gan_loss(discriminator(imgs_fake, labels_onehot), fake) 191 | 192 | d_loss.backward() 193 | optimizer_D.step() 194 | 195 | 196 | # == Generator update == # 197 | noise = Variable(Tensor(batch.size(0), opt.latent_dim).normal_(0, 1)) 198 | imgs_fake = generator(noise, labels_onehot) 199 | 200 | optimizer_G.zero_grad() 201 | 202 | # Minimizing (1-log(d(g(noise))) is less stable than maximizing log(d(g)) [*]. 203 | # Since BCE loss is defined as a negative sum, maximizing [*] is == minimizing [*]'s negative. 204 | # Intuitively, how well does the G fool the D? 205 | g_loss = gan_loss(discriminator(imgs_fake, labels_onehot), real) 206 | 207 | g_loss.backward() 208 | optimizer_G.step() 209 | 210 | if vis: 211 | batches_done = epoch * len(batch_iterator) + i 212 | if batches_done % opt.sample_interval == 0: 213 | 214 | # Generate digits from 0, ... 9, 5 samples in each column. 215 | noise = Variable(Tensor(5*10, opt.latent_dim).normal_(0, 1)) 216 | 217 | labels_onehot = Variable(Tensor(5*10, opt.n_classes).zero_()) 218 | labels_ = torch.range(0, 9) 219 | labels_ = labels_.view(1, -1).repeat(5, 1).transpose(0, 1).contiguous().view(1, -1) # my version of matlab's repelem 220 | labels_ = labels_.type(LongTensor) 221 | labels_ = labels_.view(-1, 1) 222 | labels_onehot = labels_onehot.scatter_(1, labels_, 1) # As of 0.4, we don't have one-hot built-in function yet in pytorch. 223 | 224 | imgs_fake = Variable(generator(noise, labels_onehot)) 225 | 226 | # Keep a record of losses for plotting. 227 | epochs.append(epoch + i/len(batch_iterator)) 228 | g_losses.append(g_loss.item()) 229 | d_losses.append(d_loss.item()) 230 | 231 | # Display results on visdom page. 
232 | vis.line( 233 | X=torch.stack([Tensor(epochs)] * len(loss_legend), dim=1), 234 | Y=torch.stack((Tensor(d_losses), Tensor(g_losses)), dim=1), 235 | opts={ 236 | 'title': 'loss over time', 237 | 'legend': loss_legend, 238 | 'xlabel': 'epoch', 239 | 'ylabel': 'loss', 240 | 'width': 512, 241 | 'height': 512 242 | }, 243 | win=1) 244 | vis.images( 245 | imgs_fake.data[:50], 246 | nrow=5, win=2, 247 | opts={ 248 | 'title': 'GAN output [Epoch {}]'.format(epoch), 249 | 'width': 512, 250 | 'height': 512, 251 | } 252 | ) 253 | -------------------------------------------------------------------------------- /cgan/out/366b924f53bba0.svg: -------------------------------------------------------------------------------- 1 | 02040600.511.522.5DiscriminatorGeneratorloss over timeepochloss -------------------------------------------------------------------------------- /cgan/out/GAN output [Epoch 80].jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ozanciga/gans-with-pytorch/2071efd166935f0b4fb321227e94aa2ad1cfa273/cgan/out/GAN output [Epoch 80].jpg -------------------------------------------------------------------------------- /ct-gan/ct_gan.py: -------------------------------------------------------------------------------- 1 | ''' Improving the Improved Training of Wasserstein GANs: A Consistency Term and Its Dual Effect ''' 2 | import argparse 3 | 4 | import torch 5 | from torch import nn, optim 6 | from torch.autograd.variable import Variable 7 | 8 | from torchvision import transforms, datasets 9 | from torch.utils.data import DataLoader 10 | 11 | from models import Discriminator, Generator 12 | 13 | parser = argparse.ArgumentParser() 14 | parser.add_argument('--n_epochs', type=int, default=1000, help='number of epochs of training') 15 | parser.add_argument('--batch_size', type=int, default=64, help='size of the batches') 16 | parser.add_argument('--gamma', type=float, default=0.0003, help='adam: learning rate') 17 | parser.add_argument('--b1', type=float, default=0.9, help='adam: beta 1') 18 | parser.add_argument('--b2', type=float, default=0.999, help='adam: beta 2') 19 | parser.add_argument('--n_critic', type=int, default=5, help='number of critic iterations per generator iteration') 20 | parser.add_argument('--lambda_1', type=int, default=10, help='gradient penalty coefficient') 21 | parser.add_argument('--lambda_2', type=int, default=2, help='consistency term coefficient') 22 | parser.add_argument('--M', type=float, default=0.2, help='ct-gan hyperparameter') 23 | parser.add_argument('--img_size', type=int, default=32, help='size of each image dimension') 24 | parser.add_argument('--latent_dim', type=int, default=128, help='number of input units to generate an image') 25 | parser.add_argument('--channels', type=int, default=1, help='number of image channels') 26 | parser.add_argument('--display_port', type=int, default=8097, help='where to run the visdom for visualization? useful if running multiple visdom tabs') 27 | parser.add_argument('--display_server', type=str, default="http://localhost", help='visdom server of the web display') 28 | parser.add_argument('--sample_interval', type=int, default=256, help='interval betwen image samples') 29 | opt = parser.parse_args() 30 | 31 | try: 32 | import visdom 33 | vis = visdom.Visdom(server=opt.display_server, port=opt.display_port, raise_exceptions=True) # Create vis env. 34 | except ImportError: 35 | vis = None 36 | else: 37 | vis.close(None) # Clear all figures. 
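# Rough sketch of the critic objective implemented in the training loop below
# (notation loosely follows the paper):
#   L_D = E[D(x_tilde)] - E[D(x)]                                                  <- WGAN critic loss
#         + lambda_1 * E[(||grad_{x_hat} D(x_hat)||_2 - 1)^2]                      <- gradient penalty on interpolates
#         + lambda_2 * E[max(0, d(D(x'), D(x'')) + 0.1*d(D_(x'), D_(x'')) - M)]    <- consistency term
# where x' and x'' are two dropout-perturbed forward passes over the same real batch,
# D_ denotes the critic's second-to-last layer, and d is taken to be a squared L2 distance here.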
38 | 39 | def load_cifar10(img_size): 40 | compose = transforms.Compose( 41 | [transforms.Resize(img_size), 42 | transforms.ToTensor(), 43 | transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)), 44 | ]) 45 | output_dir = './data/cifar10' 46 | cifar = datasets.CIFAR10(root=output_dir, download=True, train=True, 47 | transform=compose) 48 | return cifar 49 | 50 | cifar = load_cifar10(opt.img_size) 51 | batch_iterator = DataLoader(cifar, shuffle=True, batch_size=opt.batch_size) # List, NCHW format. 52 | 53 | cuda = torch.cuda.is_available() 54 | Tensor = torch.cuda.FloatTensor if cuda else torch.FloatTensor 55 | gan_loss = nn.BCELoss() 56 | 57 | generator = Generator() 58 | discriminator = Discriminator() 59 | 60 | optimizer_D = optim.Adam(discriminator.parameters(), lr=opt.gamma, betas=(opt.b1, opt.b2)) 61 | optimizer_G = optim.Adam(generator.parameters(), lr=opt.gamma, betas=(opt.b1, opt.b2)) 62 | 63 | # Loss record. 64 | g_losses = [] 65 | d_losses = [] 66 | epochs = [] 67 | loss_legend = ['Discriminator', 'Generator'] 68 | 69 | if cuda: 70 | generator = generator.cuda() 71 | discriminator = discriminator.cuda() 72 | 73 | noise_fixed = Variable(Tensor(25, opt.latent_dim, 1, 1).normal_(0, 1), requires_grad=False) # To track the progress of the GAN. 74 | 75 | for epoch in range(opt.n_epochs): 76 | print('Epoch {}'.format(epoch)) 77 | for i, (batch, _) in enumerate(batch_iterator): 78 | # == Discriminator update == # 79 | for iter in range(opt.n_critic): 80 | # Sample real and fake images, using notation in paper. 81 | x = Variable(batch.type(Tensor)) 82 | noise = Variable(Tensor(batch.size(0), opt.latent_dim, 1, 1).normal_(0, 1)) 83 | x_tilde = Variable(generator(noise), requires_grad=True) 84 | 85 | epsilon = Variable(Tensor(batch.size(0), 1, 1, 1).uniform_(0, 1)) 86 | 87 | x_hat = epsilon*x + (1 - epsilon)*x_tilde 88 | x_hat = torch.autograd.Variable(x_hat, requires_grad=True) 89 | 90 | # Put the interpolated data through critic. 91 | dw_x = discriminator(x_hat) 92 | grad_x = torch.autograd.grad(outputs=dw_x, inputs=x_hat, 93 | grad_outputs=Variable(Tensor(batch.size(0), 1, 1, 1).fill_(1.0), requires_grad=False), 94 | create_graph=True, retain_graph=True, only_inputs=True) 95 | grad_x = grad_x[0].view(batch.size(0), -1) 96 | grad_x = grad_x.norm(p=2, dim=1) 97 | 98 | # Update critic. 99 | optimizer_D.zero_grad() 100 | 101 | # Standard WGAN loss. 102 | d_wgan_loss = torch.mean(discriminator(x_tilde)) - torch.mean(discriminator(x)) 103 | 104 | # WGAN-GP loss. 105 | d_wgp_loss = torch.mean((grad_x - 1)**2) 106 | 107 | ###### Consistency term. ###### 108 | dw_x1, dw_x1_i = discriminator(x, dropout=0.5, intermediate_output=True) # Perturb the input by applying dropout to hidden layers. 109 | dw_x2, dw_x2_i = discriminator(x, dropout=0.5, intermediate_output=True) 110 | # Using l2 norm as the distance metric d, referring to the official code (paper ambiguous on d). 111 | second_to_last_reg = ((dw_x1_i-dw_x2_i) ** 2).mean(dim=1).unsqueeze_(1).unsqueeze_(2).unsqueeze_(3) 112 | d_wct_loss = (dw_x1-dw_x2) ** 2 \ 113 | + 0.1 * second_to_last_reg \ 114 | - opt.M 115 | d_wct_loss, _ = torch.max(d_wct_loss, 0) # torch.max returns max, and the index of max 116 | 117 | # Combined loss. 
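            # i.e. L_D = WGAN loss + lambda_1 * gradient penalty + lambda_2 * consistency term.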
118 | d_loss = d_wgan_loss + opt.lambda_1*d_wgp_loss + opt.lambda_2*d_wct_loss 119 | 120 | d_loss.backward() 121 | optimizer_D.step() 122 | 123 | # == Generator update == # 124 | noise = Variable(Tensor(batch.size(0), opt.latent_dim, 1, 1).normal_(0, 1)) 125 | imgs_fake = generator(noise) 126 | 127 | optimizer_G.zero_grad() 128 | 129 | g_loss = -torch.mean(discriminator(imgs_fake)) 130 | 131 | g_loss.backward() 132 | optimizer_G.step() 133 | 134 | if vis: 135 | batches_done = epoch * len(batch_iterator) + i 136 | if batches_done % opt.sample_interval == 0: 137 | 138 | # Keep a record of losses for plotting. 139 | epochs.append(epoch + i/len(batch_iterator)) 140 | g_losses.append(g_loss.item()) 141 | d_losses.append(d_loss.item()) 142 | 143 | # Generate images for a given set of fixed noise 144 | # so we can track how the GAN learns. 145 | imgs_fake_fixed = generator(noise_fixed).detach().data 146 | imgs_fake_fixed = imgs_fake_fixed.add_(1).div_(2) 147 | 148 | # Display results on visdom page. 149 | vis.line( 150 | X=torch.stack([Tensor(epochs)] * len(loss_legend), dim=1), 151 | Y=torch.stack((Tensor(d_losses), Tensor(g_losses)), dim=1), 152 | opts={ 153 | 'title': 'loss over time', 154 | 'legend': loss_legend, 155 | 'xlabel': 'epoch', 156 | 'ylabel': 'loss', 157 | 'width': 512, 158 | 'height': 512 159 | }, 160 | win=1) 161 | vis.images( 162 | imgs_fake_fixed[:25], 163 | nrow=5, win=2, 164 | opts={ 165 | 'title': 'GAN output [Epoch {}]'.format(epoch), 166 | 'width': 512, 167 | 'height': 512, 168 | } 169 | ) -------------------------------------------------------------------------------- /ct-gan/models.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | import torch.nn.functional as F 3 | # from torch.legacy.nn import Identity 4 | 5 | # Residual network. 6 | # CT-GAN basically uses the same architecture WGAN-GP uses, 7 | # but with a different loss. 
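# Quick orientation: MeanPoolConv and ConvMeanPool do 2x downsampling by average-pooling
# before / after a same-size convolution, UpsampleConv does 2x upsampling via PixelShuffle
# before a convolution, and ResidualBlock wires these into pre-activation residual blocks.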
8 | class MeanPoolConv(nn.Module): 9 | def __init__(self, n_input, n_output, k_size, kaiming_init=True): 10 | super(MeanPoolConv, self).__init__() 11 | conv1 = nn.Conv2d(n_input, n_output, k_size, stride=1, padding=(k_size-1)//2, bias=True) 12 | if kaiming_init: 13 | nn.init.kaiming_uniform_(conv1.weight, mode='fan_in', nonlinearity='relu') 14 | self.model = nn.Sequential(conv1) 15 | def forward(self, x): 16 | out = (x[:,:,::2,::2] + x[:,:,1::2,::2] + x[:,:,::2,1::2] + x[:,:,1::2,1::2]) / 4.0 17 | out = self.model(out) 18 | return out 19 | 20 | class ConvMeanPool(nn.Module): 21 | def __init__(self, n_input, n_output, k_size, kaiming_init=True): 22 | super(ConvMeanPool, self).__init__() 23 | conv1 = nn.Conv2d(n_input, n_output, k_size, stride=1, padding=(k_size-1)//2, bias=True) 24 | 25 | if kaiming_init: 26 | nn.init.kaiming_uniform_(conv1.weight, mode='fan_in', nonlinearity='relu') 27 | 28 | self.model = nn.Sequential(conv1) 29 | def forward(self, x): 30 | out = self.model(x) 31 | out = (out[:,:,::2,::2] + out[:,:,1::2,::2] + out[:,:,::2,1::2] + out[:,:,1::2,1::2]) / 4.0 32 | return out 33 | 34 | class UpsampleConv(nn.Module): 35 | def __init__(self, n_input, n_output, k_size, kaiming_init=True): 36 | super(UpsampleConv, self).__init__() 37 | 38 | conv_layer = nn.Conv2d(n_input, n_output, k_size, stride=1, padding=(k_size-1)//2, bias=True) 39 | if kaiming_init: 40 | nn.init.kaiming_uniform_(conv_layer.weight, mode='fan_in', nonlinearity='relu') 41 | 42 | self.model = nn.Sequential( 43 | nn.PixelShuffle(2), 44 | conv_layer, 45 | ) 46 | 47 | def forward(self, x): 48 | x = x.repeat((1, 4, 1, 1)) # Weird concat of WGAN-GPs upsampling process. 49 | out = self.model(x) 50 | return out 51 | 52 | class ResidualBlock(nn.Module): 53 | def __init__(self, n_input, n_output, k_size, resample='up', bn=True, spatial_dim=None): 54 | super(ResidualBlock, self).__init__() 55 | 56 | self.resample = resample 57 | 58 | if resample == 'up': 59 | self.conv1 = UpsampleConv(n_input, n_output, k_size, kaiming_init=True) 60 | self.conv2 = nn.Conv2d(n_output, n_output, k_size, padding=(k_size-1)//2) 61 | self.conv_shortcut = UpsampleConv(n_input, n_output, k_size, kaiming_init=True) 62 | self.out_dim = n_output 63 | 64 | # Weight initialization. 65 | nn.init.kaiming_uniform_(self.conv2.weight, mode='fan_in', nonlinearity='relu') 66 | elif resample == 'down': 67 | self.conv1 = nn.Conv2d(n_input, n_input, k_size, padding=(k_size-1)//2) 68 | self.conv2 = ConvMeanPool(n_input, n_output, k_size, kaiming_init=True) 69 | self.conv_shortcut = ConvMeanPool(n_input, n_output, k_size, kaiming_init=True) 70 | self.out_dim = n_output 71 | self.ln_dims = [n_input, spatial_dim, spatial_dim] # Define the dimensions for layer normalization. 
72 | 73 | nn.init.kaiming_uniform_(self.conv1.weight, mode='fan_in', nonlinearity='relu') 74 | else: 75 | self.conv1 = nn.Conv2d(n_input, n_input, k_size, padding=(k_size-1)//2) 76 | self.conv2 = nn.Conv2d(n_input, n_input, k_size, padding=(k_size-1)//2) 77 | self.conv_shortcut = None # Identity 78 | self.out_dim = n_input 79 | self.ln_dims = [n_input, spatial_dim, spatial_dim] 80 | 81 | nn.init.kaiming_uniform_(self.conv1.weight, mode='fan_in', nonlinearity='relu') 82 | nn.init.kaiming_uniform_(self.conv2.weight, mode='fan_in', nonlinearity='relu') 83 | 84 | self.model = nn.Sequential( 85 | nn.BatchNorm2d(n_input) if bn else nn.LayerNorm(self.ln_dims), 86 | nn.ReLU(inplace=True), 87 | self.conv1, 88 | nn.BatchNorm2d(self.out_dim) if bn else nn.LayerNorm(self.ln_dims), 89 | nn.ReLU(inplace=True), 90 | self.conv2, 91 | ) 92 | 93 | def forward(self, x): 94 | if self.conv_shortcut is None: 95 | return x + self.model(x) 96 | else: 97 | return self.conv_shortcut(x) + self.model(x) 98 | 99 | class DiscBlock1(nn.Module): 100 | def __init__(self, n_output): 101 | super(DiscBlock1, self).__init__() 102 | 103 | self.conv1 = nn.Conv2d(3, n_output, 3, padding=(3-1)//2) 104 | self.conv2 = ConvMeanPool(n_output, n_output, 1, kaiming_init=True) 105 | self.conv_shortcut = MeanPoolConv(3, n_output, 1, kaiming_init=False) 106 | 107 | # Weight initialization: 108 | nn.init.kaiming_uniform_(self.conv1.weight, mode='fan_in', nonlinearity='relu') 109 | 110 | self.model = nn.Sequential( 111 | self.conv1, 112 | nn.ReLU(inplace=True), 113 | self.conv2 114 | ) 115 | 116 | def forward(self, x): 117 | return self.conv_shortcut(x) + self.model(x) 118 | 119 | class Generator(nn.Module): 120 | def __init__(self): 121 | super(Generator, self).__init__() 122 | 123 | self.model = nn.Sequential( # 128 x 1 x 1 124 | nn.ConvTranspose2d(128, 128, 4, 1, 0), # 128 x 4 x 4 125 | ResidualBlock(128, 128, 3, resample='up'), # 128 x 8 x 8 126 | ResidualBlock(128, 128, 3, resample='up'), # 128 x 16 x 16 127 | ResidualBlock(128, 128, 3, resample='up'), # 128 x 32 x 32 128 | nn.BatchNorm2d(128), 129 | nn.ReLU(inplace=True), 130 | nn.Conv2d(128, 3, 3, padding=(3-1)//2), # 3 x 32 x 32 (no kaiming init. here) 131 | nn.Tanh() 132 | ) 133 | 134 | def forward(self, z): 135 | img = self.model(z) 136 | return img 137 | 138 | class Discriminator(nn.Module): 139 | def __init__(self): 140 | super(Discriminator, self).__init__() 141 | n_output = 128 142 | ''' 143 | This is a parameter but since we experiment with a single size 144 | of 3 x 32 x 32 images, it is hardcoded here. 
145 | ''' 146 | 147 | self.DiscBlock1 = DiscBlock1(n_output) # 128 x 16 x 16 148 | self.block1 = nn.Sequential( 149 | ResidualBlock(n_output, n_output, 3, resample='down', bn=False, spatial_dim=16), # 128 x 8 x 8 150 | ) 151 | self.block2 = nn.Sequential( 152 | ResidualBlock(n_output, n_output, 3, resample=None, bn=False, spatial_dim=8), # 128 x 8 x 8 153 | ) 154 | self.block3 = nn.Sequential( 155 | ResidualBlock(n_output, n_output, 3, resample=None, bn=False, spatial_dim=8), # 128 x 8 x 8 156 | ) 157 | 158 | self.l1 = nn.Sequential(nn.Linear(128, 1)) # 128 x 1 159 | 160 | def forward(self, x, dropout=0.0, intermediate_output=False): 161 | # x = x.view(-1, 3, 32, 32) 162 | y = self.DiscBlock1(x) 163 | y = self.block1(y) 164 | y = F.dropout(y, training=True, p=dropout) 165 | y = self.block2(y) 166 | y = F.dropout(y, training=True, p=dropout) 167 | y = self.block3(y) 168 | y = F.dropout(y, training=True, p=dropout) 169 | y = F.relu(y) 170 | y = y.view(x.size(0), 128, -1) 171 | y = y.mean(dim=2) 172 | critic_value = self.l1(y).unsqueeze_(1).unsqueeze_(2) # or *.view(x.size(0), 128, 1, 1, 1) 173 | 174 | if intermediate_output: 175 | return critic_value, y # y is the D_(.), intermediate layer given in paper. 176 | 177 | return critic_value -------------------------------------------------------------------------------- /ct-gan/out/367444be15895e.svg: -------------------------------------------------------------------------------- 1 | 051015202530−150−100−50050100150200250DiscriminatorGeneratorloss over timeepochloss -------------------------------------------------------------------------------- /ct-gan/out/GAN output [Epoch 34].jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ozanciga/gans-with-pytorch/2071efd166935f0b4fb321227e94aa2ad1cfa273/ct-gan/out/GAN output [Epoch 34].jpg -------------------------------------------------------------------------------- /dcgan/dcgan.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | 4 | import torch 5 | from torch import nn, optim 6 | from torch.autograd.variable import Variable 7 | 8 | from torchvision import transforms, datasets 9 | from torch.utils.data import DataLoader 10 | from torchvision.utils import save_image 11 | 12 | parser = argparse.ArgumentParser() 13 | parser.add_argument('--n_epochs', type=int, default=1000, help='number of epochs of training') 14 | parser.add_argument('--batch_size', type=int, default=64, help='size of the batches') 15 | parser.add_argument('--lr', type=float, default=0.0001, help='adam: learning rate') 16 | parser.add_argument('--b1', type=float, default=0.5, help='adam: decay of first order momentum of gradient') 17 | parser.add_argument('--b2', type=float, default=0.999, help='adam: decay of first order momentum of gradient') 18 | parser.add_argument('--latent_dim', type=int, default=100, help='dimensionality of the latent space') 19 | parser.add_argument('--img_size', type=int, default=32, help='size of each image dimension') 20 | parser.add_argument('--channels', type=int, default=3, help='number of image channels') 21 | parser.add_argument('--display_port', type=int, default=8097, help='where to run the visdom for visualization? 
useful if running multiple visdom tabs') 22 | parser.add_argument('--display_server', type=str, default="http://localhost", help='visdom server of the web display') 23 | parser.add_argument('--sample_interval', type=int, default=64, help='interval betwen image samples') 24 | opt = parser.parse_args() 25 | 26 | # Just trying if these different initializations 27 | # matter that much. 28 | def weights_init_xavier(m): 29 | classname = m.__class__.__name__ 30 | if classname.find('Conv') != -1 or classname.find('Linear') != -1: 31 | m.weight.data.normal_(0.0, 0.02) 32 | elif classname.find('BatchNorm') != -1: 33 | m.weight.data.normal_(1.0, 0.02) 34 | m.bias.data.fill_(0) 35 | 36 | try: 37 | import visdom 38 | vis = visdom.Visdom(server=opt.display_server, port=opt.display_port, raise_exceptions=True) # Create vis env. 39 | except ImportError: 40 | vis = None 41 | else: 42 | vis.close(None) # Clear all figures. 43 | 44 | 45 | img_dims = (opt.channels, opt.img_size, opt.img_size) 46 | n_features = opt.channels * opt.img_size * opt.img_size 47 | 48 | # Model is taken from Fig. 1 of the DCGAN paper. 49 | # BN before ReLU? https://www.reddit.com/r/MachineLearning/comments/67gonq/d_batch_normalization_before_or_after_relu/ 50 | # I tried BN -> ReLU but got poor results. 51 | # Here I skip one layer of DCGAN because I use CIFAR10 (32x32) 52 | # instead of LSUN (64x64). LSUN is too large to load and train (40 GB for a single bedroom category) 53 | 54 | ''' 55 | UPDATE: Both disc. and generator made 56 | fully convolutional, unlike the original 57 | version of the paper. The code is commented 58 | out for reference. This is the standard practice 59 | in most of the networks today. 60 | ''' 61 | 62 | class Generator(nn.Module): 63 | def __init__(self): 64 | super(Generator, self).__init__() 65 | 66 | def convlayer(n_input, n_output, k_size=4, stride=2, padding=0): 67 | block = [ 68 | nn.ConvTranspose2d(n_input, n_output, kernel_size=k_size, stride=stride, padding=padding, bias=False), 69 | nn.ReLU(inplace=True), 70 | nn.BatchNorm2d(n_output), 71 | ] 72 | return block 73 | 74 | self.project = nn.Sequential( 75 | nn.Linear(opt.latent_dim, 256 * 4 * 4), 76 | nn.BatchNorm1d(256 * 4 * 4), 77 | nn.ReLU(inplace=True), 78 | ) 79 | 80 | # pytorch doesn't have padding='same' as keras does. 81 | # hence we have to pad manually to get the desired behavior. 82 | # from the manual: 83 | # o = (i−1)∗s−2p+k, i/o = in/out-put, k = kernel size, p = padding, s = stride. 84 | 85 | self.model = nn.Sequential( 86 | *convlayer(opt.latent_dim, 256, 4, 1, 0), # Fully connected layer via convolution. 87 | *convlayer(256, 128, 4, 2, 1), # 4->8, (4-1)*2-2p+4 = 8, p = 1 88 | *convlayer(128, 64, 4, 2, 1), # 8->16, p = 1 89 | nn.ConvTranspose2d(64, opt.channels, 4, 2, 1), 90 | nn.Tanh() 91 | ) 92 | # Tanh > Image values are between [-1, 1] 93 | 94 | def forward(self, z): 95 | # p = self.project(z) 96 | # p = p.view(-1, 256, 4, 4) # Project and reshape (notice that pytorch uses NCHW format) 97 | img = self.model(z) 98 | img = img.view(z.size(0), *img_dims) 99 | return img 100 | 101 | # Discriminator mirrors the generator, like and encoder-decoder in reverse. 102 | class Discriminator(nn.Module): 103 | def __init__(self): 104 | super(Discriminator, self).__init__() 105 | 106 | # Again, we must ensure the same padding. 107 | # From the manual: 108 | # o = (i+2p−d∗(kernel_size−1)−1)/s+1, d = dilation (default = 1). 
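        # Sanity check with the default 3 x 32 x 32 inputs: 3x32x32 -> 64x16x16 -> 128x8x8 -> 256x4x4 -> 1x1x1.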
109 | 110 | self.model = nn.Sequential( 111 | nn.Conv2d(opt.channels, 64, 4, 2, 1, bias=False), # 32-> 16, (32+2p-3-1)/2 + 1 = 16, p = 1 112 | nn.LeakyReLU(0.2, inplace=True), 113 | # nn.BatchNorm2d(64), # IS IT CRUCIAL TO SKIP BN AT FIRST LAYER OF DISCRIMINATOR? 114 | nn.Conv2d(64, 128, 4, 2, 1, bias=False), 115 | nn.LeakyReLU(0.2, inplace=True), 116 | nn.BatchNorm2d(128), 117 | nn.Conv2d(128, 256, 4, 2, 1, bias=False), 118 | nn.LeakyReLU(0.2, inplace=True), 119 | nn.BatchNorm2d(256), 120 | nn.Conv2d(256, 1, 4, 1, 0, bias=False), # FC with Conv. 121 | nn.Sigmoid() 122 | ) 123 | 124 | # we used 3 convolutional blocks, hence 2^3 = 8 times downsampling. 125 | semantic_dim = opt.img_size // (2**3) 126 | 127 | self.l1 = nn.Sequential( 128 | nn.Linear(256 * semantic_dim**2, 1), 129 | nn.Sigmoid() 130 | ) 131 | 132 | def forward(self, img): 133 | prob = self.model(img) 134 | # prob = self.l1(prob.view(prob.size(0), -1)) 135 | return prob 136 | 137 | 138 | # Beware: CIFAR10 download is slow. 139 | def custom_data(CIFAR_CLASS = 5): # Only working with dogs. 140 | compose = transforms.Compose( 141 | [transforms.Resize(opt.img_size), 142 | transforms.ToTensor(), 143 | transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)), 144 | ]) 145 | output_dir = './data/cifar10' 146 | cifar = datasets.CIFAR10(root=output_dir, download=True, train=True, 147 | transform=compose) 148 | 149 | # Our custom cifar with a single category. 150 | Tensor__ = torch.cuda.FloatTensor if torch.cuda.is_available() else torch.FloatTensor 151 | 152 | features = Tensor__(5000, 3, 32, 32) # 5k samples in each category. 153 | targets = Tensor__(5000, 1) 154 | 155 | # Go through CIFAR to pick only the desired class.. 156 | it = 0 157 | for i in range(cifar.__len__()): 158 | sample = cifar.__getitem__(i) 159 | if sample[1] == CIFAR_CLASS: 160 | features[it, ...] = sample[0] 161 | targets[it] = sample[1] 162 | it += 1 163 | return features, targets 164 | 165 | features, targets = custom_data(CIFAR_CLASS=1) # Cars. 166 | batch_iterator = DataLoader(torch.utils.data.TensorDataset(features, targets), shuffle=True, batch_size=opt.batch_size) # List, NCHW format. 167 | 168 | # Save a batch of real images for reference. 169 | os.makedirs('./out', exist_ok=True) 170 | save_image(next(iter(batch_iterator))[0][:25, ...], './out/real_samples.png', nrow=5, normalize=True) 171 | 172 | cuda = torch.cuda.is_available() 173 | Tensor = torch.cuda.FloatTensor if cuda else torch.FloatTensor 174 | gan_loss = nn.BCELoss() 175 | 176 | generator = Generator() 177 | discriminator = Discriminator() 178 | 179 | optimizer_D = optim.Adam(discriminator.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2)) 180 | optimizer_G = optim.Adam(generator.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2)) 181 | 182 | # Loss record. 183 | g_losses = [] 184 | d_losses = [] 185 | epochs = [] 186 | loss_legend = ['Discriminator', 'Generator'] 187 | 188 | if cuda: 189 | generator = generator.cuda() 190 | discriminator = discriminator.cuda() 191 | 192 | generator.apply(weights_init_xavier) 193 | discriminator.apply(weights_init_xavier) 194 | 195 | noise_fixed = Variable(Tensor(25, opt.latent_dim, 1, 1).normal_(0, 1), requires_grad=False) # To track the progress of the GAN. 
196 | 197 | for epoch in range(opt.n_epochs): 198 | print('Epoch {}'.format(epoch)) 199 | for i, (batch, _) in enumerate(batch_iterator): 200 | 201 | real = Variable(Tensor(batch.size(0), 1, 1, 1).fill_(1), requires_grad=False) 202 | fake = Variable(Tensor(batch.size(0), 1, 1, 1).fill_(0), requires_grad=False) 203 | 204 | imgs_real = Variable(batch.type(Tensor), requires_grad=False) 205 | noise = Variable(Tensor(batch.size(0), opt.latent_dim, 1, 1).normal_(0, 1), requires_grad=False) 206 | imgs_fake = Variable(generator(noise), requires_grad=False) 207 | 208 | 209 | # == Discriminator update == # 210 | optimizer_D.zero_grad() 211 | 212 | # A small reminder: given classes c, prob. p, - c*logp - (1-c)*log(1-p) is the BCE/GAN loss. 213 | # Putting D(x) as p, x=real's class as 1, (..and same for fake with c=0) we arrive to BCE loss. 214 | # This is intuitively how well the discriminator can distinguish btw real & fake. 215 | d_loss = gan_loss(discriminator(imgs_real), real) + \ 216 | gan_loss(discriminator(imgs_fake), fake) 217 | 218 | d_loss.backward() 219 | optimizer_D.step() 220 | 221 | 222 | # == Generator update == # 223 | noise = Variable(Tensor(batch.size(0), opt.latent_dim, 1, 1).normal_(0, 1), requires_grad=False) 224 | imgs_fake = generator(noise) 225 | 226 | optimizer_G.zero_grad() 227 | 228 | # Minimizing (1-log(d(g(noise))) is less stable than maximizing log(d(g)) [*]. 229 | # Since BCE loss is defined as a negative sum, maximizing [*] is == minimizing [*]'s negative. 230 | # Intuitively, how well does the G fool the D? 231 | g_loss = gan_loss(discriminator(imgs_fake), real) 232 | 233 | g_loss.backward() 234 | optimizer_G.step() 235 | 236 | if vis: 237 | batches_done = epoch * len(batch_iterator) + i 238 | if batches_done % opt.sample_interval == 0: 239 | 240 | # Keep a record of losses for plotting. 241 | epochs.append(epoch + i/len(batch_iterator)) 242 | g_losses.append(g_loss.item()) 243 | d_losses.append(d_loss.item()) 244 | 245 | # Generate images for a given set of fixed noise 246 | # so we can track how the GAN learns. 247 | imgs_fake_fixed = generator(noise_fixed).detach().data 248 | imgs_fake_fixed = imgs_fake_fixed.add_(1).div_(2) # To normalize and display on visdom. 249 | 250 | # Display results on visdom page. 
251 | vis.line( 252 | X=torch.stack([Tensor(epochs)] * len(loss_legend), dim=1), 253 | Y=torch.stack((Tensor(d_losses), Tensor(g_losses)), dim=1), 254 | opts={ 255 | 'title': 'loss over time', 256 | 'legend': loss_legend, 257 | 'xlabel': 'epoch', 258 | 'ylabel': 'loss', 259 | 'width': 512, 260 | 'height': 512 261 | }, 262 | win=1) 263 | vis.images( 264 | imgs_fake_fixed, 265 | nrow=5, win=2, 266 | opts={ 267 | 'title': 'GAN output [Epoch {}]'.format(epoch), 268 | 'width': 512, 269 | 'height': 512, 270 | } 271 | ) 272 | -------------------------------------------------------------------------------- /dcgan/out/GAN output [Epoch 358].jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ozanciga/gans-with-pytorch/2071efd166935f0b4fb321227e94aa2ad1cfa273/dcgan/out/GAN output [Epoch 358].jpg -------------------------------------------------------------------------------- /dcgan/out/real_samples.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ozanciga/gans-with-pytorch/2071efd166935f0b4fb321227e94aa2ad1cfa273/dcgan/out/real_samples.png -------------------------------------------------------------------------------- /gan/gan.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | import torch 4 | from torch import nn, optim 5 | from torch.autograd.variable import Variable 6 | 7 | from torchvision import transforms, datasets 8 | from torch.utils.data import DataLoader 9 | 10 | parser = argparse.ArgumentParser() 11 | parser.add_argument('--n_epochs', type=int, default=200, help='number of epochs of training') 12 | parser.add_argument('--batch_size', type=int, default=64, help='size of the batches') 13 | parser.add_argument('--lr', type=float, default=0.0002, help='adam: learning rate') 14 | parser.add_argument('--b1', type=float, default=0.5, help='adam: decay of first order momentum of gradient') 15 | parser.add_argument('--b2', type=float, default=0.999, help='adam: decay of first order momentum of gradient') 16 | parser.add_argument('--n_cpu', type=int, default=8, help='number of cpu threads to use during batch generation') 17 | parser.add_argument('--latent_dim', type=int, default=100, help='dimensionality of the latent space') 18 | parser.add_argument('--img_size', type=int, default=28, help='size of each image dimension') 19 | parser.add_argument('--channels', type=int, default=1, help='number of image channels') 20 | parser.add_argument('--display_port', type=int, default=8097, help='where to run the visdom for visualization? useful if running multiple visdom tabs') 21 | parser.add_argument('--display_server', type=str, default="http://localhost", help='visdom server of the web display') 22 | parser.add_argument('--sample_interval', type=int, default=512, help='interval betwen image samples') 23 | opt = parser.parse_args() 24 | 25 | 26 | # Set up the visdom, if exists. 27 | # If not, visualization of the training will not be possible. 28 | # If visdom exists, it should be also on. 29 | # Use 'python -m visdom.server' command to activate the server. 30 | try: 31 | import visdom 32 | vis = visdom.Visdom(server=opt.display_server, port=opt.display_port, raise_exceptions=True) # Create vis env. 33 | except ImportError: 34 | vis = None 35 | else: 36 | vis.close(None) # Clear all figures. 
37 | 38 | img_dims = (opt.channels, opt.img_size, opt.img_size) 39 | n_features = opt.channels * opt.img_size * opt.img_size 40 | 41 | # This does not directly follow the paper. I did not fully understand the 42 | # order of the architecture they describe in there. 43 | class Generator(nn.Module): 44 | def __init__(self): 45 | super(Generator, self).__init__() 46 | 47 | def ganlayer(n_input, n_output, dropout=True): 48 | pipeline = [nn.Linear(n_input, n_output)] 49 | pipeline.append(nn.LeakyReLU(0.2, inplace=True)) 50 | if dropout: 51 | pipeline.append(nn.Dropout(0.25)) 52 | return pipeline 53 | 54 | self.model = nn.Sequential( 55 | *ganlayer(opt.latent_dim, 128, dropout=False), 56 | *ganlayer(128, 256), 57 | *ganlayer(256, 512), 58 | *ganlayer(512, 1024), 59 | nn.Linear(1024, n_features), 60 | nn.Tanh() 61 | ) 62 | # Tanh > Image values are between [-1, 1] 63 | 64 | def forward(self, z): 65 | img = self.model(z) 66 | img = img.view(img.size(0), *img_dims) 67 | return img 68 | 69 | class Discriminator(nn.Module): 70 | def __init__(self): 71 | super(Discriminator, self).__init__() 72 | 73 | self.model = nn.Sequential( 74 | nn.Linear(n_features, 512), 75 | nn.LeakyReLU(0.2, inplace=True), 76 | nn.Linear(512, 256), 77 | nn.LeakyReLU(0.2, inplace=True), 78 | nn.Linear(256, 1), 79 | nn.Sigmoid() 80 | ) 81 | 82 | def forward(self, img): 83 | flatimg = img.view(img.size(0), -1) 84 | prob = self.model( flatimg ) 85 | return prob 86 | 87 | 88 | def mnist_data(): 89 | compose = transforms.Compose( 90 | [transforms.ToTensor(), 91 | transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) 92 | ]) 93 | output_dir = './data/mnist' 94 | return datasets.MNIST(root=output_dir, train=True, 95 | transform=compose, download=True) 96 | 97 | mnist = mnist_data() 98 | batch_iterator = DataLoader(mnist, shuffle=True, batch_size=opt.batch_size) # List, NCHW format. 99 | 100 | cuda = torch.cuda.is_available() 101 | Tensor = torch.cuda.FloatTensor if cuda else torch.FloatTensor 102 | gan_loss = nn.BCELoss() 103 | 104 | generator = Generator() 105 | discriminator = Discriminator() 106 | 107 | optimizer_D = optim.Adam(discriminator.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2)) 108 | optimizer_G = optim.Adam(generator.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2)) 109 | 110 | # Loss record. 111 | g_losses = [] 112 | d_losses = [] 113 | epochs = [] 114 | loss_legend = ['Discriminator', 'Generator'] 115 | 116 | if cuda: 117 | generator = generator.cuda() 118 | discriminator = discriminator.cuda() 119 | 120 | for epoch in range(opt.n_epochs): 121 | print('Epoch {}'.format(epoch)) 122 | for i, (batch, _) in enumerate(batch_iterator): 123 | 124 | real = Variable(Tensor(batch.size(0), 1).fill_(1), requires_grad=False) 125 | fake = Variable(Tensor(batch.size(0), 1).fill_(0), requires_grad=False) 126 | 127 | imgs_real = Variable(batch.type(Tensor)) 128 | noise = Variable(Tensor(batch.size(0), opt.latent_dim).normal_(0, 1)) 129 | imgs_fake = Variable(generator(noise), requires_grad=False) 130 | 131 | 132 | # == Discriminator update == # 133 | optimizer_D.zero_grad() 134 | 135 | # A small reminder: given classes c, prob. p, - c*logp - (1-c)*log(1-p) is the BCE/GAN loss. 136 | # Putting D(x) as p, x=real's class as 1, (..and same for fake with c=0) we arrive to BCE loss. 137 | # This is intuitively how well the discriminator can distinguish btw real & fake. 
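        # In symbols: L_D = -E[log D(x_real)] - E[log(1 - D(G(z)))], i.e. the sum of the two BCE terms below.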
138 | d_loss = gan_loss(discriminator(imgs_real), real) + \ 139 | gan_loss(discriminator(imgs_fake), fake) 140 | 141 | d_loss.backward() 142 | optimizer_D.step() 143 | 144 | 145 | # == Generator update == # 146 | noise = Variable(Tensor(batch.size(0), opt.latent_dim).normal_(0, 1)) 147 | imgs_fake = generator(noise) 148 | 149 | optimizer_G.zero_grad() 150 | 151 | # Minimizing (1-log(d(g(noise))) is less stable than maximizing log(d(g)) [*]. 152 | # Since BCE loss is defined as a negative sum, maximizing [*] is == minimizing [*]'s negative. 153 | # Intuitively, how well does the G fool the D? 154 | g_loss = gan_loss(discriminator(imgs_fake), real) 155 | 156 | g_loss.backward() 157 | optimizer_G.step() 158 | 159 | if vis: 160 | batches_done = epoch * len(batch_iterator) + i 161 | if batches_done % opt.sample_interval == 0: 162 | 163 | # Keep a record of losses for plotting. 164 | epochs.append(epoch + i/len(batch_iterator)) 165 | g_losses.append(g_loss.item()) 166 | d_losses.append(d_loss.item()) 167 | 168 | # Display results on visdom page. 169 | vis.line( 170 | X=torch.stack([Tensor(epochs)] * len(loss_legend), dim=1), 171 | Y=torch.stack((Tensor(d_losses), Tensor(g_losses)), dim=1), 172 | opts={ 173 | 'title': 'loss over time', 174 | 'legend': loss_legend, 175 | 'xlabel': 'epoch', 176 | 'ylabel': 'loss', 177 | 'width': 512, 178 | 'height': 512 179 | }, 180 | win=1) 181 | vis.images( 182 | imgs_fake.data[:25], 183 | nrow=5, win=2, 184 | opts={ 185 | 'title': 'GAN output [Epoch {}]'.format(epoch), 186 | 'width': 512, 187 | 'height': 512, 188 | } 189 | ) 190 | -------------------------------------------------------------------------------- /gan/out/366b962ff767a0.svg: -------------------------------------------------------------------------------- 1 | 02040608012345DiscriminatorGeneratorloss over timeepochloss -------------------------------------------------------------------------------- /gan/out/GAN output [Epoch 95].jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ozanciga/gans-with-pytorch/2071efd166935f0b4fb321227e94aa2ad1cfa273/gan/out/GAN output [Epoch 95].jpg -------------------------------------------------------------------------------- /infogan/out/(Nothing is varied) GAN output [Epoch 14].jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ozanciga/gans-with-pytorch/2071efd166935f0b4fb321227e94aa2ad1cfa273/infogan/out/(Nothing is varied) GAN output [Epoch 14].jpg -------------------------------------------------------------------------------- /infogan/out/(Only vary c1) GAN output [Epoch 14].jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ozanciga/gans-with-pytorch/2071efd166935f0b4fb321227e94aa2ad1cfa273/infogan/out/(Only vary c1) GAN output [Epoch 14].jpg -------------------------------------------------------------------------------- /infogan/out/(Only vary c2) GAN output [Epoch 14].jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ozanciga/gans-with-pytorch/2071efd166935f0b4fb321227e94aa2ad1cfa273/infogan/out/(Only vary c2) GAN output [Epoch 14].jpg -------------------------------------------------------------------------------- /lsgan/lsgan.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Using Eqn. 9 of the lsgan paper as the loss function. 
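With the 0-1 target labels used below this amounts to
  min_D 0.5*E[(D(x) - 1)^2] + 0.5*E[D(G(z))^2]   and   min_G 0.5*E[(D(G(z)) - 1)^2],
computed via nn.MSELoss in the training loop.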
3 | Figure 2 of paper's v1 is used as the gen/discr. combo. 4 | Note that there is a v3 (as of 2018, July). 5 | A significant design difference is the use of BN before ReLU 6 | unlike the regular GANs (ReLU > BN), (BN > ReLU). 7 | ''' 8 | 9 | import argparse 10 | 11 | import torch 12 | from torch import nn, optim 13 | from torch.autograd.variable import Variable 14 | 15 | from torchvision import transforms, datasets 16 | from torch.utils.data import DataLoader 17 | 18 | parser = argparse.ArgumentParser() 19 | parser.add_argument('--n_epochs', type=int, default=200, help='number of epochs of training') 20 | parser.add_argument('--batch_size', type=int, default=64, help='size of the batches') 21 | parser.add_argument('--lr', type=float, default=0.0001, help='adam: learning rate') 22 | parser.add_argument('--b1', type=float, default=0.5, help='adam: decay of first order momentum of gradient') 23 | parser.add_argument('--b2', type=float, default=0.999, help='adam: decay of first order momentum of gradient') 24 | parser.add_argument('--latent_dim', type=int, default=100, help='dimensionality of the latent space') 25 | parser.add_argument('--img_size', type=int, default=28, help='size of each image dimension') 26 | parser.add_argument('--channels', type=int, default=1, help='number of image channels') 27 | parser.add_argument('--display_port', type=int, default=8097, help='where to run the visdom for visualization? useful if running multiple visdom tabs') 28 | parser.add_argument('--display_server', type=str, default="http://localhost", help='visdom server of the web display') 29 | parser.add_argument('--sample_interval', type=int, default=64, help='interval betwen image samples') 30 | opt = parser.parse_args() 31 | 32 | def weights_init_normal(m): 33 | classname = m.__class__.__name__ 34 | if classname.find('Conv') != -1: 35 | torch.nn.init.normal_(m.weight.data, 0.0, 0.02) 36 | elif classname.find('BatchNorm') != -1: 37 | torch.nn.init.normal_(m.weight.data, 1.0, 0.02) 38 | torch.nn.init.constant_(m.bias.data, 0.0) 39 | 40 | 41 | try: 42 | import visdom 43 | vis = visdom.Visdom(server=opt.display_server, port=opt.display_port, raise_exceptions=True) # Create vis env. 44 | except ImportError: 45 | vis = None 46 | else: 47 | vis.close(None) # Clear all figures. 48 | 49 | 50 | class Generator(nn.Module): 51 | def __init__(self): 52 | super(Generator, self).__init__() 53 | 54 | def convlayer(n_input, n_output, k_size=5, stride=2, padding=0, output_padding=0): 55 | block = [ 56 | nn.ConvTranspose2d(n_input, n_output, kernel_size=k_size, stride=stride, padding=padding, bias=False, output_padding=output_padding), 57 | nn.BatchNorm2d(n_output), 58 | nn.ReLU(inplace=True), 59 | ] 60 | return block 61 | 62 | # MNIST (not constantinople.. 
sorry not LSUN) 63 | self.model = nn.Sequential( 64 | *convlayer(opt.latent_dim, 128, 7, 1, 0), # Fully connected, 128 x 7 x 7 65 | *convlayer(128, 64, 5, 2, 2, output_padding=1), # 64 x 14 x 14 66 | *convlayer(64, 32, 5, 2, 2, output_padding=1), # 32 x 28 x 28 67 | nn.ConvTranspose2d(32, opt.channels, 5, 1, 2), # 1 x 28 x 28 68 | nn.Tanh() 69 | ) 70 | 71 | def forward(self, z): 72 | img = self.model(z) 73 | return img 74 | 75 | 76 | class Discriminator(nn.Module): 77 | def __init__(self): 78 | super(Discriminator, self).__init__() 79 | 80 | def convlayer(n_input, n_output, k_size=4, stride=2, padding=0, normalize=True, dilation=1): 81 | block = [nn.Conv2d(n_input, n_output, kernel_size=k_size, stride=stride, padding=padding, bias=False, dilation=dilation),] 82 | if normalize: 83 | block.append(nn.BatchNorm2d(n_output)) 84 | block.append(nn.LeakyReLU(0.2, inplace=True)) 85 | return block 86 | 87 | self.conv_block = nn.Sequential( # 1 x 28 x 28 88 | *convlayer(opt.channels, 32, 5, 2, 2, normalize=False), # 32 x 14 x 14 89 | *convlayer(32, 64, 5, 2, 2), # 64 x 7 x 7 90 | ) 91 | 92 | self.fc_block = nn.Sequential( 93 | nn.Linear(64 * 7 * 7, 512), 94 | nn.BatchNorm1d(512), 95 | nn.LeakyReLU(0.2, inplace=True), 96 | nn.Linear(512, 1), 97 | ) 98 | 99 | def forward(self, img): 100 | conv_out = self.conv_block(img) 101 | conv_out = conv_out.view(img.size(0), 64 * 7 * 7) 102 | l2_value = self.fc_block(conv_out) 103 | l2_value = l2_value.unsqueeze_(dim=2).unsqueeze_(dim=3) 104 | return l2_value 105 | 106 | 107 | def mnist_data(): 108 | compose = transforms.Compose( 109 | [transforms.Resize(opt.img_size), 110 | transforms.ToTensor(), 111 | transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) 112 | ]) 113 | output_dir = './data/mnist' 114 | return datasets.MNIST(root=output_dir, train=True, 115 | transform=compose, download=True) 116 | 117 | mnist = mnist_data() 118 | batch_iterator = DataLoader(mnist, shuffle=True, batch_size=opt.batch_size) # List, NCHW format. 119 | 120 | 121 | cuda = torch.cuda.is_available() 122 | Tensor = torch.cuda.FloatTensor if cuda else torch.FloatTensor 123 | gan_loss = nn.MSELoss() # LSGAN's loss, a shortcut to torch.mean( (x-y) ** 2 ) 124 | 125 | generator = Generator() 126 | discriminator = Discriminator() 127 | 128 | optimizer_D = optim.Adam(discriminator.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2)) 129 | optimizer_G = optim.Adam(generator.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2)) 130 | 131 | # Loss record. 132 | g_losses = [] 133 | d_losses = [] 134 | epochs = [] 135 | loss_legend = ['Discriminator', 'Generator'] 136 | 137 | if cuda: 138 | generator = generator.cuda() 139 | discriminator = discriminator.cuda() 140 | 141 | generator.apply(weights_init_normal) 142 | discriminator.apply(weights_init_normal) 143 | 144 | noise_fixed = Variable(Tensor(25, opt.latent_dim, 1, 1).normal_(0, 1), requires_grad=False) # To track the progress of the GAN. 
145 | 146 | for epoch in range(opt.n_epochs): 147 | print('Epoch {}'.format(epoch)) 148 | for i, (batch, _) in enumerate(batch_iterator): 149 | 150 | real = Variable(Tensor(batch.size(0), 1, 1, 1).fill_(1.0), requires_grad=False) 151 | fake = Variable(Tensor(batch.size(0), 1, 1, 1).fill_(0.0), requires_grad=False) 152 | 153 | imgs_real = Variable(batch.type(Tensor), requires_grad=False) 154 | noise = Variable(Tensor(batch.size(0), opt.latent_dim, 1, 1).normal_(0, 1), requires_grad=False) 155 | imgs_fake = generator(noise) # Variable(generator(noise), requires_grad=False) 156 | 157 | 158 | # == Discriminator update == # 159 | optimizer_D.zero_grad() 160 | 161 | d_loss = 0.5*gan_loss(discriminator(imgs_real), real) + \ 162 | 0.5*gan_loss(discriminator(imgs_fake.data), fake) 163 | 164 | d_loss.backward() 165 | optimizer_D.step() 166 | 167 | 168 | # == Generator update == # 169 | noise = Variable(Tensor(batch.size(0), opt.latent_dim, 1, 1).normal_(0, 1), requires_grad=False) 170 | imgs_fake = generator(noise) 171 | 172 | optimizer_G.zero_grad() 173 | 174 | g_loss = 0.5*gan_loss(discriminator(imgs_fake), real) 175 | 176 | g_loss.backward() 177 | optimizer_G.step() 178 | 179 | if vis: 180 | batches_done = epoch * len(batch_iterator) + i 181 | if batches_done % opt.sample_interval == 0: 182 | 183 | # Keep a record of losses for plotting. 184 | epochs.append(epoch + i/len(batch_iterator)) 185 | g_losses.append(g_loss.item()) 186 | d_losses.append(d_loss.item()) 187 | 188 | # Generate images for a given set of fixed noise 189 | # so we can track how the GAN learns. 190 | imgs_fake_fixed = generator(noise_fixed).detach().data 191 | imgs_fake_fixed = imgs_fake_fixed.add_(1).div_(2) # To normalize and display on visdom. 192 | 193 | # Display results on visdom page. 
194 | vis.line( 195 | X=torch.stack([Tensor(epochs)] * len(loss_legend), dim=1), 196 | Y=torch.stack((Tensor(d_losses), Tensor(g_losses)), dim=1), 197 | opts={ 198 | 'title': 'loss over time', 199 | 'legend': loss_legend, 200 | 'xlabel': 'epoch', 201 | 'ylabel': 'loss', 202 | 'width': 512, 203 | 'height': 512 204 | }, 205 | win=1) 206 | vis.images( 207 | imgs_fake_fixed, 208 | nrow=5, win=2, 209 | opts={ 210 | 'title': 'GAN output [Epoch {}]'.format(epoch), 211 | 'width': 512, 212 | 'height': 512, 213 | } 214 | ) 215 | -------------------------------------------------------------------------------- /lsgan/out/3677c4a9714c1c.svg: -------------------------------------------------------------------------------- 1 | 024681000.10.20.30.4DiscriminatorGeneratorloss over timeepochloss -------------------------------------------------------------------------------- /lsgan/out/GAN output [Epoch 11].jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ozanciga/gans-with-pytorch/2071efd166935f0b4fb321227e94aa2ad1cfa273/lsgan/out/GAN output [Epoch 11].jpg -------------------------------------------------------------------------------- /pix2pix/datasets.py: -------------------------------------------------------------------------------- 1 | from PIL import Image 2 | from torch.utils import data 3 | import glob 4 | import torchvision.transforms as transforms 5 | 6 | class Dataset(data.Dataset): 7 | 'Characterizes a dataset for PyTorch' 8 | def __init__(self, data_folder, transform, stage): 9 | 'Initialization' 10 | self.images = glob.glob('{}/{}/*'.format(data_folder, stage)) 11 | self.transform = transforms.Compose(transform) 12 | 'Get image size' 13 | _, self.img_size = Image.open(self.images[0]).size 14 | 15 | def __len__(self): 16 | 'Denotes the total number of samples' 17 | return len(self.images) 18 | 19 | def __getitem__(self, index): 20 | 'Generates one sample of data' 21 | # Select sample 22 | 23 | # Load data and get label 24 | X = Image.open(self.images[index]) 25 | B = X.crop((0, 0, self.img_size, self.img_size)) # (left, upper, right, lower) 26 | A = X.crop((self.img_size, 0, self.img_size+self.img_size, self.img_size)) 27 | 28 | return self.transform(A), self.transform(B) 29 | -------------------------------------------------------------------------------- /pix2pix/download_dataset.sh: -------------------------------------------------------------------------------- 1 | FILE=$1 2 | URL=https://people.eecs.berkeley.edu/~tinghuiz/projects/pix2pix/datasets/$FILE.tar.gz 3 | TAR_FILE=./$FILE.tar.gz 4 | TARGET_DIR=./$FILE/ 5 | wget -N $URL -O $TAR_FILE 6 | mkdir $TARGET_DIR 7 | tar -zxvf $TAR_FILE 8 | rm $TAR_FILE -------------------------------------------------------------------------------- /pix2pix/models.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | import torch.nn.functional as F 4 | 5 | def weights_init_normal(m): 6 | classname = m.__class__.__name__ 7 | if classname.find('Conv') != -1: 8 | torch.nn.init.normal_(m.weight.data, 0.0, 0.02) 9 | elif classname.find('BatchNorm2d') != -1: 10 | torch.nn.init.normal_(m.weight.data, 1.0, 0.02) 11 | torch.nn.init.constant_(m.bias.data, 0.0) 12 | 13 | ''' 14 | Trying to make it compact, I'm afraid 15 | this u-net implementation became very 16 | messy. In future, simplicity before 17 | compactnesss. 
18 | ''' 19 | 20 | class Generator(nn.Module): 21 | def __init__(self, in_channels=3, out_channels=3): 22 | super(Generator, self).__init__() 23 | 24 | ##### Encoder ##### 25 | def convblock(n_input, n_output, k_size=4, stride=2, padding=1, normalize=True): 26 | block = [nn.Conv2d(n_input, n_output, kernel_size=k_size, stride=stride, padding=padding, bias=False)] 27 | if normalize: 28 | block.append(nn.BatchNorm2d(n_output)) 29 | block.append(nn.LeakyReLU(0.2, inplace=True)) 30 | return block 31 | 32 | g_layers = [64, 128, 256, 512, 512, 512, 512, 512] # Generator feature maps (ref: pix2pix paper) 33 | self.encoder = nn.ModuleList(convblock(in_channels, g_layers[0], normalize=False)) # 1st layer is not normalized. 34 | for iter in range(1, len(g_layers)): 35 | self.encoder += convblock(g_layers[iter-1], g_layers[iter]) 36 | # Ex.: Why use ModuleList instead of python lists? 37 | 38 | ##### Decoder ##### 39 | def convdblock(n_input, n_output, k_size=4, stride=2, padding=1, dropout=0.0): 40 | block = [ 41 | nn.ConvTranspose2d(n_input, n_output, kernel_size=k_size, stride=stride, padding=padding, bias=False), 42 | nn.BatchNorm2d(n_output) 43 | ] 44 | if dropout: # Mixing BN and Dropout is not recommended in general. 45 | block.append(nn.Dropout(dropout)) 46 | block.append(nn.ReLU(inplace=True)) 47 | return block 48 | 49 | d_layers = [512, 1024, 1024, 1024, 1024, 512, 256, 128] 50 | self.decoder = nn.ModuleList(convdblock(g_layers[-1], d_layers[0])) 51 | for iter in range(1, len(d_layers)-1): 52 | self.decoder += convdblock(d_layers[iter], d_layers[iter+1]//2) 53 | 54 | self.model = nn.Sequential( 55 | nn.ConvTranspose2d(d_layers[-1], out_channels, 4, 2, 1), 56 | nn.Tanh() 57 | ) 58 | 59 | 60 | def forward(self, z): 61 | # Encode. 62 | e = [self.encoder[1](self.encoder[0](z))] 63 | module_iter = 2 64 | for iter in range(1, 8): # we have 8 downsampling layers 65 | result = self.encoder[module_iter+0](e[iter-1]) 66 | result = self.encoder[module_iter+1](result) 67 | result = self.encoder[module_iter+2](result) 68 | e.append(result) 69 | module_iter += 3 70 | # Decode. 71 | # First d-layer. 72 | d1 = self.decoder[2](self.decoder[1](self.decoder[0](e[-1]))) 73 | d1 = torch.cat((d1, e[-2]), dim=1) 74 | d = [d1] 75 | module_iter = 3 76 | for iter in range(1, 7): # we have 7 upsampling layers 77 | result = self.decoder[module_iter+0](d[iter-1]) 78 | result = self.decoder[module_iter+1](result) 79 | result = self.decoder[module_iter+2](result) 80 | result = torch.cat((result, e[-(iter+2)]), dim=1) # Concating n-i^th layer with i^th. 81 | d.append(result) 82 | module_iter += 3 83 | 84 | # Pass the decoder output through a "flattening" layer to get the image. 
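    # (This final 4x4, stride-2 transposed convolution doubles the spatial size one last time,
    #  and the Tanh keeps the output in [-1, 1], matching the Normalize((0.5, ...), (0.5, ...))
    #  preprocessing applied to the training images.)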
85 | img = self.model(d[-1]) 86 | return img 87 | 88 | 89 | class Discriminator(nn.Module): 90 | def __init__(self, a_channels=3, b_channels=3): 91 | super(Discriminator, self).__init__() 92 | 93 | def convblock(n_input, n_output, k_size=4, stride=2, padding=1, normalize=True): 94 | block = [nn.Conv2d(n_input, n_output, kernel_size=k_size, stride=stride, padding=padding, bias=False)] 95 | if normalize: 96 | block.append(nn.BatchNorm2d(n_output)) 97 | block.append(nn.LeakyReLU(0.2, inplace=True)) 98 | return block 99 | 100 | self.model = nn.Sequential( 101 | *convblock(a_channels + b_channels, 64, normalize=False), 102 | *convblock(64, 128), 103 | *convblock(128, 256), 104 | *convblock(256, 512), 105 | ) 106 | 107 | self.l1 = nn.Linear(512 * 16 * 16, 1) 108 | 109 | def forward(self, img_A, img_B): 110 | img = torch.cat((img_A, img_B), dim=1) 111 | conv_out = self.model(img) 112 | conv_out = conv_out.view(img_A.size(0), -1) 113 | prob = self.l1(conv_out) 114 | prob = F.sigmoid(prob) 115 | return prob -------------------------------------------------------------------------------- /pix2pix/out/3685e66f775102.svg: -------------------------------------------------------------------------------- 1 | 05010015020001020304050DiscriminatorGeneratorloss over timeepochloss -------------------------------------------------------------------------------- /pix2pix/out/3685e7054e8534.svg: -------------------------------------------------------------------------------- 1 | 05010015020001020304050DiscriminatorGeneratorloss over timeepochloss -------------------------------------------------------------------------------- /pix2pix/out/GAN output [Epoch 216].jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ozanciga/gans-with-pytorch/2071efd166935f0b4fb321227e94aa2ad1cfa273/pix2pix/out/GAN output [Epoch 216].jpg -------------------------------------------------------------------------------- /pix2pix/out/GAN output [Epoch 240].jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ozanciga/gans-with-pytorch/2071efd166935f0b4fb321227e94aa2ad1cfa273/pix2pix/out/GAN output [Epoch 240].jpg -------------------------------------------------------------------------------- /pix2pix/pix2pix.py: -------------------------------------------------------------------------------- 1 | ''' 2 | To run, use 3 | bash download_dataset.sh cityscapes 4 | to download the data into the same 5 | directory as the pix2pix.py. 6 | ''' 7 | 8 | ''' 9 | pix2pix paper suggests using a noise 10 | vector along with the real image (to generate 11 | translated images). However, I noticed from 12 | separate experiments that the noise is simply 13 | filtered out, i.e. it does not provide any diversity 14 | for a given A -> B conversion, hence omitted 15 | here. In fact, there is a GAN called BicycleGAN which 16 | addresses the very same issue by employing 17 | autoencoders. 
18 | ''' 19 | 20 | import argparse 21 | 22 | import torch 23 | from torch import nn, optim 24 | from torch.autograd.variable import Variable 25 | 26 | from torchvision import transforms, datasets 27 | from torch.utils.data import DataLoader 28 | 29 | from models import Generator, Discriminator, weights_init_normal 30 | from datasets import Dataset 31 | 32 | parser = argparse.ArgumentParser() 33 | parser.add_argument('--n_epochs', type=int, default=1000, help='number of epochs of training') 34 | parser.add_argument('--batch_size', type=int, default=64, help='size of the batches') 35 | parser.add_argument('--lr', type=float, default=0.0002, help='adam: learning rate') 36 | parser.add_argument('--b1', type=float, default=0.5, help='adam: decay of first order momentum of gradient') 37 | parser.add_argument('--b2', type=float, default=0.999, help='adam: decay of first order momentum of gradient') 38 | parser.add_argument('--lambda_l1', type=int, default=100, help='penalty term for deviation from original images') 39 | parser.add_argument('--latent_dim', type=int, default=100, help='dimensionality of the latent space') 40 | parser.add_argument('--data_folder', type=str, default="cityscapes", help='where the training image pairs are located') 41 | parser.add_argument('--img_size', type=int, default=256, help='size of the input images') 42 | parser.add_argument('--channels', type=int, default=3, help='number of image channels') 43 | parser.add_argument('--display_port', type=int, default=8097, help='where to run the visdom for visualization? useful if running multiple visdom tabs') 44 | parser.add_argument('--display_server', type=str, default="http://localhost", help='visdom server of the web display') 45 | parser.add_argument('--sample_interval', type=int, default=64, help='interval betwen image samples') 46 | opt = parser.parse_args() 47 | 48 | 49 | try: 50 | import visdom 51 | vis = visdom.Visdom(server=opt.display_server, port=opt.display_port, raise_exceptions=True) # Create vis env. 52 | except ImportError: 53 | vis = None 54 | else: 55 | vis.close(None) # Clear all figures. 56 | 57 | # Cityscapes dataset loader. 58 | transform_image = [ 59 | transforms.ToTensor(), 60 | transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) 61 | ] 62 | data_train = Dataset(opt.data_folder, transform_image, stage='train') 63 | data_val = Dataset(opt.data_folder, transform_image, stage='val') 64 | 65 | params = {'batch_size': opt.batch_size, 'shuffle': True} 66 | dataloader_train = DataLoader(data_val, **params) 67 | 68 | params = {'batch_size': 5, 'shuffle': True} 69 | dataloader_val = DataLoader(data_val, **params) 70 | 71 | cuda = torch.cuda.is_available() 72 | Tensor = torch.cuda.FloatTensor if cuda else torch.FloatTensor 73 | gan_loss = nn.BCELoss() 74 | l1_loss = nn.L1Loss() 75 | 76 | generator = Generator() 77 | discriminator = Discriminator() 78 | 79 | optimizer_D = optim.Adam(discriminator.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2)) 80 | optimizer_G = optim.Adam(generator.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2)) 81 | 82 | # Loss record. 
83 | g_losses = [] 84 | d_losses = [] 85 | epochs = [] 86 | loss_legend = ['Discriminator', 'Generator'] 87 | 88 | if cuda: 89 | generator = generator.cuda() 90 | discriminator = discriminator.cuda() 91 | gan_loss = gan_loss.cuda() 92 | l1_loss = l1_loss.cuda() 93 | 94 | 95 | generator.apply(weights_init_normal) 96 | discriminator.apply(weights_init_normal) 97 | 98 | for epoch in range(opt.n_epochs): 99 | print('Epoch {}'.format(epoch)) 100 | for i, (batch_A, batch_B) in enumerate(dataloader_train): 101 | 102 | real = Variable(Tensor(batch_A.size(0), 1).fill_(1), requires_grad=False) 103 | fake = Variable(Tensor(batch_A.size(0), 1).fill_(0), requires_grad=False) 104 | 105 | imgs_real_A = Variable(batch_A.type(Tensor)) 106 | imgs_real_B = Variable(batch_B.type(Tensor)) 107 | 108 | # == Discriminator update == # 109 | optimizer_D.zero_grad() 110 | 111 | imgs_fake = Variable(generator(imgs_real_A.detach())) 112 | 113 | d_loss = gan_loss(discriminator(imgs_real_A, imgs_real_B), real) + gan_loss(discriminator(imgs_real_A, imgs_fake), fake) 114 | 115 | d_loss.backward() 116 | optimizer_D.step() 117 | 118 | # == Generator update == # 119 | imgs_fake = generator(imgs_real_A) 120 | 121 | optimizer_G.zero_grad() 122 | 123 | g_loss = gan_loss(discriminator(imgs_real_A, imgs_fake), real) + opt.lambda_l1 * l1_loss(imgs_fake, imgs_real_B) 124 | 125 | g_loss.backward() 126 | optimizer_G.step() 127 | 128 | if vis: 129 | batches_done = epoch * len(dataloader_train) + i 130 | if batches_done % opt.sample_interval == 0: 131 | 132 | # Keep a record of losses for plotting. 133 | epochs.append(epoch + i/len(dataloader_train)) 134 | g_losses.append(g_loss.item()) 135 | d_losses.append(d_loss.item()) 136 | 137 | 138 | 139 | # Generate 5 images from the validation set. 140 | batch_A, batch_B = next(iter(dataloader_val)) 141 | 142 | imgs_val_A = Variable(batch_A.type(Tensor)) 143 | imgs_val_B = Variable(batch_B.type(Tensor)) 144 | imgs_fake_B = generator(imgs_val_A).detach().data 145 | 146 | # For visualization purposes. 147 | imgs_val_A = imgs_val_A.add_(1).div_(2) 148 | imgs_fake_B = imgs_fake_B.add_(1).div_(2) 149 | fake_val = torch.cat((imgs_val_A, imgs_fake_B), dim=2) 150 | 151 | 152 | # Display results on visdom page. 
153 | vis.line( 154 | X=torch.stack([Tensor(epochs)] * len(loss_legend), dim=1), 155 | Y=torch.stack((Tensor(d_losses), Tensor(g_losses)), dim=1), 156 | opts={ 157 | 'title': 'loss over time', 158 | 'legend': loss_legend, 159 | 'xlabel': 'epoch', 160 | 'ylabel': 'loss', 161 | 'width': 512, 162 | 'height': 512 163 | }, 164 | win=1) 165 | vis.images( 166 | fake_val, 167 | nrow=5, win=2, 168 | opts={ 169 | 'title': 'GAN output [Epoch {}]'.format(epoch), 170 | 'width': 512, 171 | 'height': 512, 172 | } 173 | ) 174 | -------------------------------------------------------------------------------- /ralsgan/out/GAN output [Iteration 31616].png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ozanciga/gans-with-pytorch/2071efd166935f0b4fb321227e94aa2ad1cfa273/ralsgan/out/GAN output [Iteration 31616].png -------------------------------------------------------------------------------- /ralsgan/out/GAN output [Iteration 31680].png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ozanciga/gans-with-pytorch/2071efd166935f0b4fb321227e94aa2ad1cfa273/ralsgan/out/GAN output [Iteration 31680].png -------------------------------------------------------------------------------- /ralsgan/preprocess_cat_dataset.py: -------------------------------------------------------------------------------- 1 | ## Modified version of https://github.com/microe/angora-blue/blob/master/cascade_training/describe.py by Erik Hovland 2 | ##### I modified it to work with Python 3, changed the paths and made it output a folder for images bigger than 64x64 and a folder for images bigger than 128x128 3 | 4 | import cv2 5 | import glob 6 | import math 7 | import sys 8 | 9 | def rotateCoords(coords, center, angleRadians): 10 | # Positive y is down so reverse the angle, too. 11 | angleRadians = -angleRadians 12 | xs, ys = coords[::2], coords[1::2] 13 | newCoords = [] 14 | n = min(len(xs), len(ys)) 15 | i = 0 16 | centerX = center[0] 17 | centerY = center[1] 18 | cosAngle = math.cos(angleRadians) 19 | sinAngle = math.sin(angleRadians) 20 | while i < n: 21 | xOffset = xs[i] - centerX 22 | yOffset = ys[i] - centerY 23 | newX = xOffset * cosAngle - yOffset * sinAngle + centerX 24 | newY = xOffset * sinAngle + yOffset * cosAngle + centerY 25 | newCoords += [newX, newY] 26 | i += 1 27 | return newCoords 28 | 29 | def preprocessCatFace(coords, image): 30 | 31 | leftEyeX, leftEyeY = coords[0], coords[1] 32 | rightEyeX, rightEyeY = coords[2], coords[3] 33 | mouthX = coords[4] 34 | if leftEyeX > rightEyeX and leftEyeY < rightEyeY and \ 35 | mouthX > rightEyeX: 36 | # The "right eye" is in the second quadrant of the face, 37 | # while the "left eye" is in the fourth quadrant (from the 38 | # viewer's perspective.) Swap the eyes' labels in order to 39 | # simplify the rotation logic. 40 | leftEyeX, rightEyeX = rightEyeX, leftEyeX 41 | leftEyeY, rightEyeY = rightEyeY, leftEyeY 42 | 43 | eyesCenter = (0.5 * (leftEyeX + rightEyeX), 44 | 0.5 * (leftEyeY + rightEyeY)) 45 | 46 | eyesDeltaX = rightEyeX - leftEyeX 47 | eyesDeltaY = rightEyeY - leftEyeY 48 | eyesAngleRadians = math.atan2(eyesDeltaY, eyesDeltaX) 49 | eyesAngleDegrees = eyesAngleRadians * 180.0 / math.pi 50 | 51 | # Straighten the image and fill in gray for blank borders. 
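    # getRotationMatrix2D(center, angle, scale) returns the 2x3 affine matrix for a rotation
    # about eyesCenter (OpenCV expects the angle in degrees, hence the radians-to-degrees
    # conversion above); warpAffine applies it, and borderValue=(128, 128, 128) fills the
    # corners exposed by the rotation with neutral gray.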
52 | rotation = cv2.getRotationMatrix2D( 53 | eyesCenter, eyesAngleDegrees, 1.0) 54 | imageSize = image.shape[1::-1] 55 | straight = cv2.warpAffine(image, rotation, imageSize, 56 | borderValue=(128, 128, 128)) 57 | 58 | # Straighten the coordinates of the features. 59 | newCoords = rotateCoords( 60 | coords, eyesCenter, eyesAngleRadians) 61 | 62 | # Make the face as wide as the space between the ear bases. 63 | w = abs(newCoords[16] - newCoords[6]) 64 | # Make the face square. 65 | h = w 66 | # Put the center point between the eyes at (0.5, 0.4) in 67 | # proportion to the entire face. 68 | minX = eyesCenter[0] - w/2 69 | if minX < 0: 70 | w += minX 71 | minX = 0 72 | minY = eyesCenter[1] - h*2/5 73 | if minY < 0: 74 | h += minY 75 | minY = 0 76 | 77 | # Crop the face. 78 | crop = straight[int(minY):int(minY+h), int(minX):int(minX+w)] 79 | # Return the crop. 80 | return crop 81 | 82 | def describePositive(): 83 | output = open('log.txt', 'w') 84 | for imagePath in glob.glob('cat_dataset/*.jpg'): 85 | # Open the '.cat' annotation file associated with this 86 | # image. 87 | input = open('%s.cat' % imagePath, 'r') 88 | # Read the coordinates of the cat features from the 89 | # file. Discard the first number, which is the number 90 | # of features. 91 | coords = [int(i) for i in input.readline().split()[1:]] 92 | # Read the image. 93 | image = cv2.imread(imagePath) 94 | # Straighten and crop the cat face. 95 | crop = preprocessCatFace(coords, image) 96 | if crop is None: 97 | print >> sys.stderr, \ 98 | 'Failed to preprocess image at %s.' % \ 99 | imagePath 100 | continue 101 | # Save the crop to folders based on size 102 | h, w, colors = crop.shape 103 | if min(h,w) >= 64: 104 | Path1 = imagePath.replace("cat_dataset","cats_bigger_than_64x64") 105 | resized_crop = cv2.resize(crop, (64, 64)) 106 | cv2.imwrite(Path1, resized_crop) 107 | if min(h,w) >= 128: 108 | Path2 = imagePath.replace("cat_dataset","cats_bigger_than_128x128") 109 | resized_crop = cv2.resize(crop, (128, 128)) 110 | cv2.imwrite(Path2, resized_crop) 111 | if min(h,w) >= 256: 112 | Path2 = imagePath.replace("cat_dataset","cats_bigger_than_256x256") 113 | resized_crop = cv2.resize(crop, (256, 256)) 114 | cv2.imwrite(Path2, resized_crop) 115 | # Append the cropped face and its bounds to the 116 | # positive description. 117 | #h, w = crop.shape[:2] 118 | #print (cropPath, 1, 0, 0, w, h, file=output) 119 | 120 | 121 | def main(): 122 | describePositive() 123 | 124 | if __name__ == '__main__': 125 | main() -------------------------------------------------------------------------------- /ralsgan/ralsgan.py: -------------------------------------------------------------------------------- 1 | ''' 2 | To run this script, you must have the CATS dataset. 3 | I put the necessary code under this 4 | folder. Steps to follow (ref: @AlexiaJM's github) 5 | 1. Download: Cat Dataset (http://academictorrents.com/details/c501571c29d16d7f41d159d699d0e7fb37092cbd) 6 | (I used a version uploaded by @simoninithomas hence I have 7 | a slightly different version of cats .) 8 | 2. Run setting_up_script.sh in same folder as preprocess_cat_dataset.py 9 | and your CAT dataset (open and run manually) 10 | 3. mv cats_bigger_than_256x256 "cats/0" 11 | (Imagefolder class requires a subfolder under the given 12 | folder (indicating class) 13 | ''' 14 | 15 | 16 | ''' 17 | Relativistic GAN paper (https://arxiv.org/abs/1807.00734) 18 | is very inventive and experimental, hence comes with a 19 | bunch of different versions of the same underlying 20 | idea. 
I used Ralsgan (relativistic average least squares) 21 | for demonstration as it seems to give the most promising 22 | results for the high-definition outputs with a single shot 23 | network. 24 | A really interesting behavior I observe with this model is 25 | that although around 5k iterations I see some "catlike" 26 | images, the images look like noise before that for thousands 27 | of iterations, which is unlike the other GAN losses, where the 28 | quality improvement is more linear. Also note that the proper 29 | cat generations don't come before 10s of thousands of iterations. 30 | ''' 31 | 32 | import argparse 33 | import numpy 34 | 35 | import torch 36 | from torch import nn, optim 37 | from torch.autograd.variable import Variable 38 | 39 | from torchvision import transforms, datasets 40 | 41 | import os 42 | os.environ["CUDA_VISIBLE_DEVICES"] = "1" 43 | 44 | 45 | parser = argparse.ArgumentParser() 46 | parser.add_argument('--n_iters', type=int, default=100e3, help='number of iterations of training') 47 | parser.add_argument('--batch_size', type=int, default=48, help='size of the batches') 48 | parser.add_argument('--lr', type=float, default=0.0001, help='adam: learning rate') 49 | parser.add_argument('--b1', type=float, default=0.5, help='adam: decay of first order momentum of gradient') 50 | parser.add_argument('--b2', type=float, default=0.999, help='adam: decay of first order momentum of gradient') 51 | parser.add_argument('--latent_dim', type=int, default=128, help='dimensionality of the latent space') 52 | parser.add_argument('--input_folder', type=str, default="cats", help='source images folder') 53 | parser.add_argument('--img_size', type=int, default=256, help='size of each image dimension') 54 | parser.add_argument('--channels', type=int, default=3, help='number of image channels') 55 | parser.add_argument('--display_port', type=int, default=8097, help='where to run the visdom for visualization? useful if running multiple visdom tabs') 56 | parser.add_argument('--display_server', type=str, default="http://localhost", help='visdom server of the web display') 57 | parser.add_argument('--sample_interval', type=int, default=64, help='interval betwen image samples') 58 | opt = parser.parse_args() 59 | 60 | def weights_init_normal(m): 61 | classname = m.__class__.__name__ 62 | if classname.find('Conv') != -1: 63 | m.weight.data.normal_(0.0, 0.02) 64 | elif classname.find('BatchNorm') != -1: 65 | m.weight.data.normal_(1.0, 0.02) 66 | m.bias.data.fill_(0) 67 | 68 | 69 | try: 70 | import visdom 71 | vis = visdom.Visdom(server=opt.display_server, port=opt.display_port, raise_exceptions=True) # Create vis env. 72 | except ImportError: 73 | vis = None 74 | else: 75 | vis.close(None) # Clear all figures. 76 | 77 | 78 | img_dims = (opt.channels, opt.img_size, opt.img_size) 79 | n_features = opt.channels * opt.img_size * opt.img_size 80 | 81 | 82 | # Appendix D.4., DCGAN for 0. 83 | class Generator(nn.Module): 84 | def __init__(self): 85 | super(Generator, self).__init__() 86 | 87 | def convlayer(n_input, n_output, k_size=4, stride=2, padding=0): 88 | block = [ 89 | nn.ConvTranspose2d(n_input, n_output, kernel_size=k_size, stride=stride, padding=padding, bias=False), 90 | nn.BatchNorm2d(n_output), 91 | nn.ReLU(inplace=True), 92 | ] 93 | return block 94 | 95 | self.model = nn.Sequential( 96 | *convlayer(opt.latent_dim, 1024, 4, 1, 0), # Fully connected layer via convolution. 
97 | *convlayer(1024, 512, 4, 2, 1), 98 | *convlayer(512, 256, 4, 2, 1), 99 | *convlayer(256, 128, 4, 2, 1), 100 | *convlayer(128, 64, 4, 2, 1), 101 | *convlayer(64, 32, 4, 2, 1), 102 | nn.ConvTranspose2d(32, opt.channels, 4, 2, 1), 103 | nn.Tanh() 104 | ) 105 | ''' 106 | There is a slight error in v2 of the relativistic gan paper, where 107 | the architecture goes from 128>64>32 but then 64>3. 108 | ''' 109 | 110 | 111 | def forward(self, z): 112 | z = z.view(-1, opt.latent_dim, 1, 1) 113 | img = self.model(z) 114 | return img 115 | 116 | # PACGAN2. 117 | class Discriminator(nn.Module): 118 | def __init__(self): 119 | super(Discriminator, self).__init__() 120 | 121 | def convlayer(n_input, n_output, k_size=4, stride=2, padding=0, bn=False): 122 | block = [nn.Conv2d(n_input, n_output, kernel_size=k_size, stride=stride, padding=padding, bias=False)] 123 | if bn: 124 | block.append(nn.BatchNorm2d(n_output)) 125 | block.append(nn.LeakyReLU(0.2, inplace=True)) 126 | return block 127 | 128 | self.model = nn.Sequential( 129 | *convlayer(opt.channels * 2, 32, 4, 2, 1), 130 | *convlayer(32, 64, 4, 2, 1), 131 | *convlayer(64, 128, 4, 2, 1, bn=True), 132 | *convlayer(128, 256, 4, 2, 1, bn=True), 133 | *convlayer(256, 512, 4, 2, 1, bn=True), 134 | *convlayer(512, 1024, 4, 2, 1, bn=True), 135 | nn.Conv2d(1024, 1, 4, 1, 0, bias=False), # FC with Conv. 136 | ) 137 | 138 | def forward(self, imgs): 139 | critic_value = self.model(imgs) 140 | critic_value = critic_value.view(imgs.size(0), -1) 141 | return critic_value 142 | 143 | 144 | ''' 145 | Worst part of the data science: data (prep). 146 | I shamelessly copy these from @AlexiaJM's code. 147 | (https://github.com/AlexiaJM/RelativisticGAN) 148 | ''' 149 | 150 | transform = transforms.Compose([ 151 | transforms.Resize((opt.img_size, opt.img_size)), 152 | transforms.ToTensor(), 153 | transforms.Normalize(mean = [0.5, 0.5, 0.5], std = [0.5, 0.5, 0.5]) 154 | ]) 155 | data = datasets.ImageFolder(root=os.path.join(os.getcwd(), opt.input_folder), transform=transform) 156 | 157 | 158 | def generate_random_sample(): 159 | while True: 160 | random_indexes = numpy.random.choice(data.__len__(), size=opt.batch_size * 2, replace=False) 161 | batch = [data[i][0] for i in random_indexes] 162 | yield torch.stack(batch, 0) 163 | 164 | 165 | random_sample = generate_random_sample() 166 | 167 | def mse_loss(input, target): 168 | return torch.sum((input - target)**2) / input.data.nelement() 169 | 170 | cuda = torch.cuda.is_available() 171 | Tensor = torch.cuda.FloatTensor if cuda else torch.FloatTensor 172 | gan_loss = mse_loss 173 | 174 | generator = Generator() 175 | discriminator = Discriminator() 176 | 177 | optimizer_D = optim.Adam(discriminator.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2)) 178 | optimizer_G = optim.Adam(generator.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2)) 179 | 180 | # Loss record. 181 | g_losses = [] 182 | d_losses = [] 183 | epochs = [] 184 | loss_legend = ['Discriminator', 'Generator'] 185 | 186 | if cuda: 187 | generator = generator.cuda() 188 | discriminator = discriminator.cuda() 189 | 190 | generator.apply(weights_init_normal) 191 | discriminator.apply(weights_init_normal) 192 | 193 | noise_fixed = Variable(Tensor(25, opt.latent_dim).normal_(0, 1), requires_grad=False) 194 | 195 | for it in range(int(opt.n_iters)): 196 | print('Iter. 
{}'.format(it)) 197 | 198 | batch = random_sample.__next__() 199 | 200 | imgs_real = Variable(batch.type(Tensor)) 201 | imgs_real = torch.cat((imgs_real[0:opt.batch_size, ...], imgs_real[opt.batch_size:opt.batch_size * 2, ...]), 1) 202 | real = Variable(Tensor(batch.size(0)//2, 1).fill_(1.0), requires_grad=False) 203 | 204 | noise = Variable(Tensor(opt.batch_size * 2, opt.latent_dim).normal_(0, 1)) 205 | imgs_fake = generator(noise) 206 | imgs_fake = torch.cat((imgs_fake[0:opt.batch_size, ...], imgs_fake[opt.batch_size:opt.batch_size * 2, ...]), 1) 207 | 208 | # == Discriminator update == # 209 | optimizer_D.zero_grad() 210 | 211 | c_xr = discriminator(imgs_real) 212 | c_xf = discriminator(imgs_fake.detach()) 213 | 214 | d_loss = gan_loss(c_xr, torch.mean(c_xf) + real) + gan_loss(c_xf, torch.mean(c_xr) - real) 215 | 216 | d_loss.backward() 217 | optimizer_D.step() 218 | 219 | # == Generator update == # 220 | batch = random_sample.__next__() 221 | 222 | imgs_real = Variable(batch.type(Tensor)) 223 | imgs_real = torch.cat((imgs_real[0:opt.batch_size, ...], imgs_real[opt.batch_size:opt.batch_size * 2, ...]), 1) 224 | 225 | noise = Variable(Tensor(opt.batch_size * 2, opt.latent_dim).normal_(0, 1)) 226 | imgs_fake = generator(noise) 227 | imgs_fake = torch.cat((imgs_fake[0:opt.batch_size, ...], imgs_fake[opt.batch_size:opt.batch_size * 2, ...]), 1) 228 | 229 | c_xr = discriminator(imgs_real) 230 | c_xf = discriminator(imgs_fake) 231 | real = Variable(Tensor(batch.size(0)//2, 1).fill_(1.0), requires_grad=False) 232 | 233 | optimizer_G.zero_grad() 234 | 235 | g_loss = gan_loss(c_xf, torch.mean(c_xr) + real) + gan_loss(c_xr, torch.mean(c_xf) - real) 236 | 237 | g_loss.backward() 238 | optimizer_G.step() 239 | 240 | if vis: 241 | if it % opt.sample_interval == 0: 242 | 243 | # Keep a record of losses for plotting. 244 | epochs.append(it) 245 | g_losses.append(g_loss.item()) 246 | d_losses.append(d_loss.item()) 247 | 248 | # Generate images for a given set of fixed noise 249 | # so we can track how the GAN learns. 250 | imgs_fake_fixed = generator(noise_fixed).detach().data 251 | imgs_fake_fixed = imgs_fake_fixed.add_(1).div_(2) # To normalize and display on visdom. 252 | 253 | # Display results on visdom page. 
254 | vis.line( 255 | X=torch.stack([Tensor(epochs)] * len(loss_legend), dim=1), 256 | Y=torch.stack((Tensor(d_losses), Tensor(g_losses)), dim=1), 257 | opts={ 258 | 'title': 'loss over time', 259 | 'legend': loss_legend, 260 | 'xlabel': 'iteration', 261 | 'ylabel': 'loss', 262 | 'width': 512, 263 | 'height': 512 264 | }, 265 | win=1) 266 | vis.images( 267 | imgs_fake_fixed, 268 | nrow=5, win=2, 269 | opts={ 270 | 'title': 'GAN output [Iteration {}]'.format(it), 271 | 'width': 512, 272 | 'height': 512, 273 | } 274 | ) -------------------------------------------------------------------------------- /ralsgan/setting_up_script.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | # wget -nc http://www.simoninithomas.com/data/cats.zip 3 | unzip cats.zip -d cat_dataset 4 | # Move everything in one single directory 5 | mv cat_dataset/CAT_00/* cat_dataset 6 | rmdir cat_dataset/CAT_00 7 | mv cat_dataset/CAT_01/* cat_dataset 8 | rmdir cat_dataset/CAT_01 9 | mv cat_dataset/CAT_02/* cat_dataset 10 | rmdir cat_dataset/CAT_02 11 | mv cat_dataset/CAT_03/* cat_dataset 12 | rmdir cat_dataset/CAT_03 13 | mv cat_dataset/CAT_04/* cat_dataset 14 | rmdir cat_dataset/CAT_04 15 | mv cat_dataset/CAT_05/* cat_dataset 16 | rmdir cat_dataset/CAT_05 17 | mv cat_dataset/CAT_06/* cat_dataset 18 | rmdir cat_dataset/CAT_06 19 | ## Error correction 20 | rm cat_dataset/00000003_019.jpg.cat 21 | mv 00000003_015.jpg.cat cat_dataset/00000003_015.jpg.cat 22 | ## Removing outliers 23 | # Corrupted, drawings, badly cropped, inverted, impossible to tell it's a cat, blocked face 24 | cd cat_dataset 25 | rm 00000004_007.jpg 00000007_002.jpg 00000045_028.jpg 00000050_014.jpg 00000056_013.jpg 00000059_002.jpg 00000108_005.jpg 00000122_023.jpg 00000126_005.jpg 00000132_018.jpg 00000142_024.jpg 00000142_029.jpg 00000143_003.jpg 00000145_021.jpg 00000166_021.jpg 00000169_021.jpg 00000186_002.jpg 00000202_022.jpg 00000208_023.jpg 00000210_003.jpg 00000229_005.jpg 00000236_025.jpg 00000249_016.jpg 00000254_013.jpg 00000260_019.jpg 00000261_029.jpg 00000265_029.jpg 00000271_020.jpg 00000282_026.jpg 00000316_004.jpg 00000352_014.jpg 00000400_026.jpg 00000406_006.jpg 00000431_024.jpg 00000443_027.jpg 00000502_015.jpg 00000504_012.jpg 00000510_019.jpg 00000514_016.jpg 00000514_008.jpg 00000515_021.jpg 00000519_015.jpg 00000522_016.jpg 00000523_021.jpg 00000529_005.jpg 00000556_022.jpg 00000574_011.jpg 00000581_018.jpg 00000582_011.jpg 00000588_016.jpg 00000588_019.jpg 00000590_006.jpg 00000592_018.jpg 00000593_027.jpg 00000617_013.jpg 00000618_016.jpg 00000619_025.jpg 00000622_019.jpg 00000622_021.jpg 00000630_007.jpg 00000645_016.jpg 00000656_017.jpg 00000659_000.jpg 00000660_022.jpg 00000660_029.jpg 00000661_016.jpg 00000663_005.jpg 00000672_027.jpg 00000673_027.jpg 00000675_023.jpg 00000692_006.jpg 00000800_017.jpg 00000805_004.jpg 00000807_020.jpg 00000823_010.jpg 00000824_010.jpg 00000836_008.jpg 00000843_021.jpg 00000850_025.jpg 00000862_017.jpg 00000864_007.jpg 00000865_015.jpg 00000870_007.jpg 00000877_014.jpg 00000882_013.jpg 00000887_028.jpg 00000893_022.jpg 00000907_013.jpg 00000921_029.jpg 00000929_022.jpg 00000934_006.jpg 00000960_021.jpg 00000976_004.jpg 00000987_000.jpg 00000993_009.jpg 00001006_014.jpg 00001008_013.jpg 00001012_019.jpg 00001014_005.jpg 00001020_017.jpg 00001039_008.jpg 00001039_023.jpg 00001048_029.jpg 00001057_003.jpg 00001068_005.jpg 00001113_015.jpg 00001140_007.jpg 00001157_029.jpg 00001158_000.jpg 00001167_007.jpg 00001184_007.jpg 00001188_019.jpg 
00001204_027.jpg 00001205_022.jpg 00001219_005.jpg 00001243_010.jpg 00001261_005.jpg 00001270_028.jpg 00001274_006.jpg 00001293_015.jpg 00001312_021.jpg 00001365_026.jpg 00001372_006.jpg 00001379_018.jpg 00001388_024.jpg 00001389_026.jpg 00001418_028.jpg 00001425_012.jpg 00001431_001.jpg 00001456_018.jpg 00001458_003.jpg 00001468_019.jpg 00001475_009.jpg 00001487_020.jpg 26 | rm 00000004_007.jpg.cat 00000007_002.jpg.cat 00000045_028.jpg.cat 00000050_014.jpg.cat 00000056_013.jpg.cat 00000059_002.jpg.cat 00000108_005.jpg.cat 00000122_023.jpg.cat 00000126_005.jpg.cat 00000132_018.jpg.cat 00000142_024.jpg.cat 00000142_029.jpg.cat 00000143_003.jpg.cat 00000145_021.jpg.cat 00000166_021.jpg.cat 00000169_021.jpg.cat 00000186_002.jpg.cat 00000202_022.jpg.cat 00000208_023.jpg.cat 00000210_003.jpg.cat 00000229_005.jpg.cat 00000236_025.jpg.cat 00000249_016.jpg.cat 00000254_013.jpg.cat 00000260_019.jpg.cat 00000261_029.jpg.cat 00000265_029.jpg.cat 00000271_020.jpg.cat 00000282_026.jpg.cat 00000316_004.jpg.cat 00000352_014.jpg.cat 00000400_026.jpg.cat 00000406_006.jpg.cat 00000431_024.jpg.cat 00000443_027.jpg.cat 00000502_015.jpg.cat 00000504_012.jpg.cat 00000510_019.jpg.cat 00000514_016.jpg.cat 00000514_008.jpg.cat 00000515_021.jpg.cat 00000519_015.jpg.cat 00000522_016.jpg.cat 00000523_021.jpg.cat 00000529_005.jpg.cat 00000556_022.jpg.cat 00000574_011.jpg.cat 00000581_018.jpg.cat 00000582_011.jpg.cat 00000588_016.jpg.cat 00000588_019.jpg.cat 00000590_006.jpg.cat 00000592_018.jpg.cat 00000593_027.jpg.cat 00000617_013.jpg.cat 00000618_016.jpg.cat 00000619_025.jpg.cat 00000622_019.jpg.cat 00000622_021.jpg.cat 00000630_007.jpg.cat 00000645_016.jpg.cat 00000656_017.jpg.cat 00000659_000.jpg.cat 00000660_022.jpg.cat 00000660_029.jpg.cat 00000661_016.jpg.cat 00000663_005.jpg.cat 00000672_027.jpg.cat 00000673_027.jpg.cat 00000675_023.jpg.cat 00000692_006.jpg.cat 00000800_017.jpg.cat 00000805_004.jpg.cat 00000807_020.jpg.cat 00000823_010.jpg.cat 00000824_010.jpg.cat 00000836_008.jpg.cat 00000843_021.jpg.cat 00000850_025.jpg.cat 00000862_017.jpg.cat 00000864_007.jpg.cat 00000865_015.jpg.cat 00000870_007.jpg.cat 00000877_014.jpg.cat 00000882_013.jpg.cat 00000887_028.jpg.cat 00000893_022.jpg.cat 00000907_013.jpg.cat 00000921_029.jpg.cat 00000929_022.jpg.cat 00000934_006.jpg.cat 00000960_021.jpg.cat 00000976_004.jpg.cat 00000987_000.jpg.cat 00000993_009.jpg.cat 00001006_014.jpg.cat 00001008_013.jpg.cat 00001012_019.jpg.cat 00001014_005.jpg.cat 00001020_017.jpg.cat 00001039_008.jpg.cat 00001039_023.jpg.cat 00001048_029.jpg.cat 00001057_003.jpg.cat 00001068_005.jpg.cat 00001113_015.jpg.cat 00001140_007.jpg.cat 00001157_029.jpg.cat 00001158_000.jpg.cat 00001167_007.jpg.cat 00001184_007.jpg.cat 00001188_019.jpg.cat 00001204_027.jpg.cat 00001205_022.jpg.cat 00001219_005.jpg.cat 00001243_010.jpg.cat 00001261_005.jpg.cat 00001270_028.jpg.cat 00001274_006.jpg.cat 00001293_015.jpg.cat 00001312_021.jpg.cat 00001365_026.jpg.cat 00001372_006.jpg.cat 00001379_018.jpg.cat 00001388_024.jpg.cat 00001389_026.jpg.cat 00001418_028.jpg.cat 00001425_012.jpg.cat 00001431_001.jpg.cat 00001456_018.jpg.cat 00001458_003.jpg.cat 00001468_019.jpg.cat 00001475_009.jpg.cat 00001487_020.jpg.cat 27 | cd .. 
28 | ## Preprocessing and putting in folders for different image sizes 29 | mkdir cats_bigger_than_64x64 30 | mkdir cats_bigger_than_128x128 31 | mkdir cats_bigger_than_256x256 32 | python preprocess_cat_dataset.py 33 | ## Removing cat_dataset 34 | rm -r cat_dataset 35 | ## Move to your favorite place 36 | mv cats_bigger_than_256x256 cats -------------------------------------------------------------------------------- /srgan/datasets.py: -------------------------------------------------------------------------------- 1 | from PIL import Image 2 | from torch.utils import data 3 | import glob 4 | import torchvision.transforms as transforms 5 | import numpy as np 6 | 7 | class Dataset(data.Dataset): 8 | 'Characterizes a dataset for PyTorch' 9 | def __init__(self, data_folder, transform_lr, transform_hr, stage): 10 | 'Initialization' 11 | file_list = glob.glob('{}/*'.format(data_folder)) 12 | n = len(file_list) 13 | train_size = np.floor(n * 0.8).astype(np.int) 14 | self.images = file_list[:train_size] if stage is 'train' else file_list[train_size:] 15 | 16 | self.transform_lr = transforms.Compose(transform_lr) 17 | self.transform_hr = transforms.Compose(transform_hr) 18 | 19 | def __len__(self): 20 | 'Denotes the total number of samples' 21 | return len(self.images) 22 | 23 | def __getitem__(self, index): 24 | 'Generates one sample of data' 25 | # Select sample 26 | 27 | # Load data and get label 28 | hr = Image.open(self.images[index]) 29 | 30 | return self.transform_lr(hr), self.transform_hr(hr) -------------------------------------------------------------------------------- /srgan/models.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from torchvision.models import vgg19 4 | from torch.autograd import Variable 5 | 6 | 7 | def weights_init_normal(m): 8 | classname = m.__class__.__name__ 9 | if classname.find('Conv') != -1: 10 | torch.nn.init.normal_(m.weight.data, 0.0, 0.02) 11 | elif classname.find('BatchNorm2d') != -1: 12 | torch.nn.init.normal_(m.weight.data, 1.0, 0.02) 13 | torch.nn.init.constant_(m.bias.data, 0.0) 14 | 15 | class ResidualBlock(nn.Module): 16 | def __init__(self, n_output=64, k_size=3, stride=1, padding=1): 17 | super(ResidualBlock, self).__init__() 18 | self.model = nn.Sequential( 19 | nn.Conv2d(n_output, n_output, k_size, stride, padding), 20 | nn.BatchNorm2d(n_output), 21 | nn.PReLU(), 22 | nn.Conv2d(n_output, n_output, k_size, stride, padding), 23 | nn.BatchNorm2d(n_output), 24 | ) 25 | 26 | def forward(self, x): 27 | return x + self.model(x) 28 | 29 | class ShuffleBlock(nn.Module): 30 | def __init__(self, n_input, n_output, k_size=3, stride=1, padding=1): 31 | super(ShuffleBlock, self).__init__() 32 | self.model = nn.Sequential( 33 | nn.Conv2d(n_input, n_output, k_size, stride, padding), # N, 256, H, W 34 | nn.PixelShuffle(2), # N, 64, 2H, 2W 35 | nn.PReLU(), 36 | ) 37 | ''' 38 | Input: :math:`(N, C * upscale_factor^2, H, W)` 39 | Output: :math:`(N, C, H * upscale_factor, W * upscale_factor)` 40 | ''' 41 | 42 | def forward(self, x): 43 | return self.model(x) 44 | 45 | 46 | # n_fmap = number of feature maps, 47 | # B = number of cascaded residual blocks. 48 | class Generator(nn.Module): 49 | def __init__(self, n_input=3, n_output=3, n_fmap=64, B=16): 50 | super(Generator, self).__init__() 51 | 52 | self.l1 = nn.Sequential( 53 | nn.Conv2d(n_input, n_fmap, 9, 1, 4), 54 | nn.PReLU(), 55 | ) 56 | 57 | # A cascaded of B residual blocks. 
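        # Each ResidualBlock above is conv -> BN -> PReLU -> conv -> BN with an identity skip
        # (x + block(x)); chaining B = 16 of them follows the SRResNet/SRGAN generator of
        # Ledig et al. (2017). The blocks are collected in a plain Python list first and then
        # wrapped in nn.Sequential, which is what registers them as submodules.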
58 | self.R = [] 59 | for _ in range(B): 60 | self.R.append(ResidualBlock(n_fmap)) 61 | self.R = nn.Sequential(*self.R) 62 | 63 | self.l2 = nn.Sequential( 64 | nn.Conv2d(n_fmap, n_fmap, 3, 1, 1), 65 | nn.BatchNorm2d(n_fmap), 66 | ) 67 | 68 | self.px = nn.Sequential( 69 | ShuffleBlock(64, 256), 70 | ShuffleBlock(64, 256), 71 | ) 72 | 73 | self.conv_final = nn.Sequential( 74 | nn.Conv2d(64, n_output, 9, 1, 4), 75 | nn.Tanh(), 76 | ) 77 | 78 | 79 | def forward(self, img_in): 80 | out_1 = self.l1(img_in) 81 | out_2 = self.R(out_1) 82 | out_3 = out_1 + self.l2(out_2) 83 | out_4 = self.px(out_3) 84 | return self.conv_final(out_4) 85 | 86 | 87 | class Discriminator(nn.Module): 88 | def __init__(self, lr_channels=3): 89 | super(Discriminator, self).__init__() 90 | 91 | def convblock(n_input, n_output, k_size=3, stride=1, padding=1, bn=True): 92 | block = [nn.Conv2d(n_input, n_output, kernel_size=k_size, stride=stride, padding=padding, bias=False)] 93 | if bn: 94 | block.append(nn.BatchNorm2d(n_output)) 95 | block.append(nn.LeakyReLU(0.2, inplace=True)) 96 | return block 97 | 98 | self.conv = nn.Sequential( 99 | *convblock(lr_channels, 64, 3, 1, 1, bn=False), 100 | *convblock(64, 64, 3, 2, 1), 101 | *convblock(64, 128, 3, 1, 1), 102 | *convblock(128, 128, 3, 2, 1), 103 | *convblock(128, 256, 3, 1, 1), 104 | *convblock(256, 256, 3, 2, 1), 105 | *convblock(256, 512, 3, 1, 1), 106 | *convblock(512, 512, 3, 2, 1), 107 | ) 108 | 109 | self.fc = nn.Sequential( 110 | nn.Linear(512 * 16 * 16, 1024), 111 | nn.LeakyReLU(0.2, inplace=True), 112 | nn.Linear(1024, 1), 113 | nn.Sigmoid() 114 | ) 115 | 116 | def forward(self, img): 117 | out_1 = self.conv(img) 118 | out_1 = out_1.view(img.size(0), -1) 119 | out_2 = self.fc(out_1) 120 | return out_2 121 | 122 | 123 | class VGGFeatures(nn.Module): 124 | def __init__(self): 125 | super(VGGFeatures, self).__init__() 126 | model = vgg19(pretrained=True) 127 | 128 | children = list(model.features.children()) 129 | max_pool_indices = [index for index, m in enumerate(children) if isinstance(m, nn.MaxPool2d)] 130 | target_features = children[:max_pool_indices[4]] 131 | ''' 132 | We use vgg-5,4 which is the layer output after 5th conv 133 | and right before the 4th max pool. 134 | ''' 135 | self.features = nn.Sequential(*target_features) 136 | for p in self.features.parameters(): 137 | p.requires_grad = False 138 | 139 | ''' 140 | # VGG means and stdevs on pretrained imagenet 141 | mean = -1 + Variable(torch.Tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1)) 142 | std = 2*Variable(torch.Tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1)) 143 | 144 | # This is for cuda compatibility. 
145 | self.register_buffer('mean', mean) 146 | self.register_buffer('std', std) 147 | ''' 148 | 149 | def forward(self, input): 150 | # input = (input - self.mean) / self.std 151 | output = self.features(input) 152 | return output -------------------------------------------------------------------------------- /srgan/out/GAN output [Epoch 6] (1).jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ozanciga/gans-with-pytorch/2071efd166935f0b4fb321227e94aa2ad1cfa273/srgan/out/GAN output [Epoch 6] (1).jpg -------------------------------------------------------------------------------- /srgan/out/GAN output [Epoch 6].jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ozanciga/gans-with-pytorch/2071efd166935f0b4fb321227e94aa2ad1cfa273/srgan/out/GAN output [Epoch 6].jpg -------------------------------------------------------------------------------- /srgan/srgan.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Literally, the 90% of my time was spent 3 | looking for a HD dataset that I can use 4 | for this GAN. It turns out not a lot of 5 | such sets exist. If you happen to have 6 | a nice custom dataset of say size 1024x 7 | 1024, it would fit perfectly for this 8 | code. I settled for "celeba align" 9 | dataset that isn't particularly high 10 | definition, but the memory constraints 11 | forced my hand on this. It is accessible: 12 | https://drive.google.com/open?id=0B7EVK8r0v71pWEZsZE9oNnFzTm8 13 | ''' 14 | 15 | import argparse 16 | 17 | import torch 18 | from torch import nn, optim 19 | from torch.autograd.variable import Variable 20 | 21 | from torchvision import transforms 22 | from torch.utils.data import DataLoader 23 | 24 | from models import Generator, Discriminator, VGGFeatures, weights_init_normal 25 | from datasets import Dataset 26 | 27 | from PIL import Image 28 | 29 | parser = argparse.ArgumentParser() 30 | parser.add_argument('--n_epochs', type=int, default=10, help='number of epochs of training') 31 | parser.add_argument('--batch_size', type=int, default=16, help='size of the batches') 32 | parser.add_argument('--lr', type=float, default=0.0001, help='adam: learning rate') 33 | parser.add_argument('--b1', type=float, default=0.9, help='adam: decay of first order momentum of gradient') 34 | parser.add_argument('--b2', type=float, default=0.999, help='adam: decay of first order momentum of gradient') 35 | parser.add_argument('--img_size', type=int, default=256, help='high resolution image output size') 36 | parser.add_argument('--data_folder', type=str, default="img_align_celeba", help='where the training image pairs are located') 37 | parser.add_argument('--display_port', type=int, default=8097, help='where to run the visdom for visualization? useful if running multiple visdom tabs') 38 | parser.add_argument('--display_server', type=str, default="http://localhost", help='visdom server of the web display') 39 | parser.add_argument('--sample_interval', type=int, default=16, help='interval betwen image samples') 40 | opt = parser.parse_args() 41 | 42 | 43 | 44 | try: 45 | import visdom 46 | vis = visdom.Visdom(server=opt.display_server, port=opt.display_port, raise_exceptions=True) # Create vis env. 47 | except ImportError: 48 | vis = None 49 | else: 50 | vis.close(None) # Clear all figures. 51 | 52 | 53 | # Celeba dataset loader. 
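# Both pipelines below normalize with the ImageNet channel statistics
# (mean 0.485/0.456/0.406, std 0.229/0.224/0.225) because the content loss runs the images
# through an ImageNet-pretrained VGG19; the LR version is the same image bicubically resized
# to a quarter of the HR resolution.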
54 | transform_image_hr = [ 55 | transforms.Resize((opt.img_size, opt.img_size), Image.BICUBIC), 56 | transforms.ToTensor(), 57 | transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)) 58 | ] 59 | transform_image_lr = [ 60 | transforms.Resize((opt.img_size//4, opt.img_size//4), Image.BICUBIC), 61 | transforms.ToTensor(), 62 | transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)) 63 | ] 64 | data_train = Dataset(opt.data_folder, transform_image_lr, transform_image_hr, stage='train') 65 | data_val = Dataset(opt.data_folder, transform_image_lr, transform_image_hr, stage='val') 66 | 67 | params = {'batch_size': opt.batch_size, 'shuffle': True} 68 | dataloader_train = DataLoader(data_val, **params) 69 | 70 | params = {'batch_size': 5, 'shuffle': True} 71 | dataloader_val = DataLoader(data_val, **params) 72 | 73 | 74 | cuda = torch.cuda.is_available() 75 | Tensor = torch.cuda.FloatTensor if cuda else torch.FloatTensor 76 | 77 | gan_loss = nn.BCELoss() 78 | content_loss = nn.MSELoss() 79 | 80 | generator = Generator() 81 | discriminator = Discriminator() 82 | vgg = VGGFeatures() 83 | 84 | optimizer_D = optim.Adam(discriminator.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2)) 85 | optimizer_G = optim.Adam(generator.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2)) 86 | 87 | # Loss record. 88 | g_losses = [] 89 | d_losses = [] 90 | epochs = [] 91 | loss_legend = ['Discriminator', 'Generator'] 92 | 93 | if cuda: 94 | generator = generator.cuda() 95 | discriminator = discriminator.cuda() 96 | vgg = vgg.cuda() 97 | gan_loss = gan_loss.cuda() 98 | content_loss = content_loss.cuda() 99 | 100 | 101 | generator.apply(weights_init_normal) 102 | discriminator.apply(weights_init_normal) 103 | 104 | for epoch in range(opt.n_epochs): 105 | print('Epoch {}'.format(epoch)) 106 | for i, (batch_lr, batch_hr) in enumerate(dataloader_train): 107 | 108 | real = Variable(Tensor(batch_lr.size(0), 1).fill_(1), requires_grad=False) 109 | fake = Variable(Tensor(batch_lr.size(0), 1).fill_(0), requires_grad=False) 110 | 111 | imgs_real_lr = Variable(batch_lr.type(Tensor)) 112 | imgs_real_hr = Variable(batch_hr.type(Tensor)) 113 | 114 | # == Discriminator update == # 115 | optimizer_D.zero_grad() 116 | 117 | imgs_fake_hr = Variable(generator(imgs_real_lr.detach())) 118 | 119 | d_loss = gan_loss(discriminator(imgs_real_hr), real) + gan_loss(discriminator(imgs_fake_hr), fake) 120 | 121 | d_loss.backward() 122 | optimizer_D.step() 123 | 124 | # == Generator update == # 125 | imgs_fake_hr = generator(imgs_real_lr) 126 | 127 | optimizer_G.zero_grad() 128 | 129 | g_loss = (1/12.75) * content_loss(vgg(imgs_fake_hr), vgg(imgs_real_hr.detach())) + 1e-3 * gan_loss(discriminator(imgs_fake_hr), real) 130 | 131 | g_loss.backward() 132 | optimizer_G.step() 133 | 134 | if vis: 135 | batches_done = epoch * len(dataloader_train) + i 136 | if batches_done % opt.sample_interval == 0: 137 | 138 | # Keep a record of losses for plotting. 139 | epochs.append(epoch + i/len(dataloader_train)) 140 | g_losses.append(g_loss.item()) 141 | d_losses.append(d_loss.item()) 142 | 143 | 144 | 145 | # Generate 5 images from the validation set. 146 | batch_lr, batch_hr = next(iter(dataloader_val)) 147 | 148 | imgs_val_lr = Variable(batch_lr.type(Tensor)) 149 | imgs_val_hr = Variable(batch_hr.type(Tensor)) 150 | imgs_fake_hr = generator(imgs_val_lr).detach().data 151 | 152 | # For visualization purposes. 
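                # Undo the ImageNet normalization (x * std + mean) on the real LR/HR images and
                # min-max rescale the generated one so all three roughly land in [0, 1] for display;
                # the LR input is first bilinearly upsampled to the HR size so the three versions
                # can be stacked into a single grid.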
153 | imgs_val_lr = torch.nn.functional.upsample(imgs_val_lr, size=(imgs_fake_hr.size(2), imgs_fake_hr.size(3)), mode='bilinear') 154 | 155 | imgs_val_lr = imgs_val_lr.mul_(Tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1)).add_(Tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1)) 156 | imgs_val_hr = imgs_val_hr.mul_(Tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1)).add_(Tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1)) 157 | imgs_fake_hr = imgs_fake_hr.add_(torch.abs(torch.min(imgs_fake_hr))).div_(torch.max(imgs_fake_hr)-torch.min(imgs_fake_hr)) 158 | fake_val = torch.cat((imgs_val_lr, imgs_val_hr, imgs_fake_hr), dim=2) 159 | 160 | # Display results on visdom page. 161 | vis.line( 162 | X=torch.stack([Tensor(epochs)] * len(loss_legend), dim=1), 163 | Y=torch.stack((Tensor(d_losses), Tensor(g_losses)), dim=1), 164 | opts={ 165 | 'title': 'loss over time', 166 | 'legend': loss_legend, 167 | 'xlabel': 'epoch', 168 | 'ylabel': 'loss', 169 | 'width': 512, 170 | 'height': 512 171 | }, 172 | win=1) 173 | vis.images( 174 | fake_val, 175 | nrow=5, win=2, 176 | opts={ 177 | 'title': 'GAN output [Epoch {}]'.format(epoch), 178 | 'width': 512, 179 | 'height': 512, 180 | } 181 | ) -------------------------------------------------------------------------------- /wgan-gp/models.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | # from torch.legacy.nn import Identity 3 | 4 | # Residual network. 5 | # WGAN-GP paper defines a residual block with up & downsampling. 6 | # See the official implementation (given in the paper). 7 | # I use architectures described in the official implementation, 8 | # since I find it hard to deduce the blocks given here from the text alone. 9 | class MeanPoolConv(nn.Module): 10 | def __init__(self, n_input, n_output, k_size): 11 | super(MeanPoolConv, self).__init__() 12 | conv1 = nn.Conv2d(n_input, n_output, k_size, stride=1, padding=(k_size-1)//2, bias=True) 13 | self.model = nn.Sequential(conv1) 14 | def forward(self, x): 15 | out = (x[:,:,::2,::2] + x[:,:,1::2,::2] + x[:,:,::2,1::2] + x[:,:,1::2,1::2]) / 4.0 16 | out = self.model(out) 17 | return out 18 | 19 | class ConvMeanPool(nn.Module): 20 | def __init__(self, n_input, n_output, k_size): 21 | super(ConvMeanPool, self).__init__() 22 | conv1 = nn.Conv2d(n_input, n_output, k_size, stride=1, padding=(k_size-1)//2, bias=True) 23 | self.model = nn.Sequential(conv1) 24 | def forward(self, x): 25 | out = self.model(x) 26 | out = (out[:,:,::2,::2] + out[:,:,1::2,::2] + out[:,:,::2,1::2] + out[:,:,1::2,1::2]) / 4.0 27 | return out 28 | 29 | class UpsampleConv(nn.Module): 30 | def __init__(self, n_input, n_output, k_size): 31 | super(UpsampleConv, self).__init__() 32 | 33 | self.model = nn.Sequential( 34 | nn.PixelShuffle(2), 35 | nn.Conv2d(n_input, n_output, k_size, stride=1, padding=(k_size-1)//2, bias=True) 36 | ) 37 | def forward(self, x): 38 | x = x.repeat((1, 4, 1, 1)) # Weird concat of WGAN-GPs upsampling process. 
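        # PixelShuffle(2) maps (N, 4C, H, W) -> (N, C, 2H, 2W), so the input is tiled 4x along the
        # channel dimension to give it the channels the shuffle needs; the convolution that follows
        # then mixes the upsampled features. (The official WGAN-GP code uses the analogous
        # concat + depth_to_space trick for its 2x upsampling.)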
39 | out = self.model(x) 40 | return out 41 | 42 | class ResidualBlock(nn.Module): 43 | def __init__(self, n_input, n_output, k_size, resample='up', bn=True, spatial_dim=None): 44 | super(ResidualBlock, self).__init__() 45 | 46 | self.resample = resample 47 | 48 | if resample == 'up': 49 | self.conv1 = UpsampleConv(n_input, n_output, k_size) 50 | self.conv2 = nn.Conv2d(n_output, n_output, k_size, padding=(k_size-1)//2) 51 | self.conv_shortcut = UpsampleConv(n_input, n_output, k_size) 52 | self.out_dim = n_output 53 | elif resample == 'down': 54 | self.conv1 = nn.Conv2d(n_input, n_input, k_size, padding=(k_size-1)//2) 55 | self.conv2 = ConvMeanPool(n_input, n_output, k_size) 56 | self.conv_shortcut = ConvMeanPool(n_input, n_output, k_size) 57 | self.out_dim = n_output 58 | self.ln_dims = [n_input, spatial_dim, spatial_dim] # Define the dimensions for layer normalization. 59 | else: 60 | self.conv1 = nn.Conv2d(n_input, n_input, k_size, padding=(k_size-1)//2) 61 | self.conv2 = nn.Conv2d(n_input, n_input, k_size, padding=(k_size-1)//2) 62 | self.conv_shortcut = None # Identity 63 | self.out_dim = n_input 64 | self.ln_dims = [n_input, spatial_dim, spatial_dim] 65 | 66 | self.model = nn.Sequential( 67 | nn.BatchNorm2d(n_input) if bn else nn.LayerNorm(self.ln_dims), 68 | nn.ReLU(inplace=True), 69 | self.conv1, 70 | nn.BatchNorm2d(self.out_dim) if bn else nn.LayerNorm(self.ln_dims), 71 | nn.ReLU(inplace=True), 72 | self.conv2, 73 | ) 74 | 75 | def forward(self, x): 76 | if self.conv_shortcut is None: 77 | return x + self.model(x) 78 | else: 79 | return self.conv_shortcut(x) + self.model(x) 80 | 81 | class DiscBlock1(nn.Module): 82 | def __init__(self, n_output): 83 | super(DiscBlock1, self).__init__() 84 | 85 | self.conv1 = nn.Conv2d(3, n_output, 3, padding=(3-1)//2) 86 | self.conv2 = ConvMeanPool(n_output, n_output, 1) 87 | self.conv_shortcut = MeanPoolConv(3, n_output, 1) 88 | 89 | self.model = nn.Sequential( 90 | self.conv1, 91 | nn.ReLU(inplace=True), 92 | self.conv2 93 | ) 94 | 95 | def forward(self, x): 96 | return self.conv_shortcut(x) + self.model(x) 97 | 98 | class Generator(nn.Module): 99 | def __init__(self): 100 | super(Generator, self).__init__() 101 | 102 | self.model = nn.Sequential( # 128 x 1 x 1 103 | nn.ConvTranspose2d(128, 128, 4, 1, 0), # 128 x 4 x 4 104 | ResidualBlock(128, 128, 3, resample='up'), # 128 x 8 x 8 105 | ResidualBlock(128, 128, 3, resample='up'), # 128 x 16 x 16 106 | ResidualBlock(128, 128, 3, resample='up'), # 128 x 32 x 32 107 | nn.BatchNorm2d(128), 108 | nn.ReLU(inplace=True), 109 | nn.Conv2d(128, 3, 3, padding=(3-1)//2), # 3 x 32 x 32 110 | nn.Tanh() 111 | ) 112 | 113 | def forward(self, z): 114 | img = self.model(z) 115 | return img 116 | 117 | class Discriminator(nn.Module): 118 | def __init__(self): 119 | super(Discriminator, self).__init__() 120 | n_output = 128 121 | ''' 122 | This is a parameter but since we experiment with a single size 123 | of 3 x 32 x 32 images, it is hardcoded here. 
124 | ''' 125 | 126 | self.DiscBlock1 = DiscBlock1(n_output) # 128 x 16 x 16 127 | 128 | self.model = nn.Sequential( 129 | ResidualBlock(n_output, n_output, 3, resample='down', bn=False, spatial_dim=16), # 128 x 8 x 8 130 | ResidualBlock(n_output, n_output, 3, resample=None, bn=False, spatial_dim=8), # 128 x 8 x 8 131 | ResidualBlock(n_output, n_output, 3, resample=None, bn=False, spatial_dim=8), # 128 x 8 x 8 132 | nn.ReLU(inplace=True), 133 | ) 134 | self.l1 = nn.Sequential(nn.Linear(128, 1)) # 128 x 1 135 | 136 | def forward(self, x): 137 | # x = x.view(-1, 3, 32, 32) 138 | y = self.DiscBlock1(x) 139 | y = self.model(y) 140 | y = y.view(x.size(0), 128, -1) 141 | y = y.mean(dim=2) 142 | out = self.l1(y).unsqueeze_(1).unsqueeze_(2) # or *.view(x.size(0), 128, 1, 1, 1) 143 | return out 144 | -------------------------------------------------------------------------------- /wgan-gp/out/36737d5d4d1ebe.svg: -------------------------------------------------------------------------------- 1 | 01020304050−150−100−500DiscriminatorGeneratorloss over timeepochloss -------------------------------------------------------------------------------- /wgan-gp/out/GAN output [Epoch 105].jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ozanciga/gans-with-pytorch/2071efd166935f0b4fb321227e94aa2ad1cfa273/wgan-gp/out/GAN output [Epoch 105].jpg -------------------------------------------------------------------------------- /wgan-gp/out/GAN output [Epoch 52].jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ozanciga/gans-with-pytorch/2071efd166935f0b4fb321227e94aa2ad1cfa273/wgan-gp/out/GAN output [Epoch 52].jpg -------------------------------------------------------------------------------- /wgan-gp/wgan_gp.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | import torch 4 | from torch import nn, optim 5 | from torch.autograd.variable import Variable 6 | 7 | from torchvision import transforms, datasets 8 | from torch.utils.data import DataLoader 9 | 10 | from models import Discriminator, Generator 11 | 12 | parser = argparse.ArgumentParser() 13 | parser.add_argument('--n_epochs', type=int, default=200, help='number of epochs of training') 14 | parser.add_argument('--batch_size', type=int, default=64, help='size of the batches') 15 | parser.add_argument('--alpha', type=float, default=0.0001, help='adam: learning rate') 16 | parser.add_argument('--b1', type=float, default=0.5, help='adam: beta 1') 17 | parser.add_argument('--b2', type=float, default=0.9, help='adam: beta 2') 18 | parser.add_argument('--n_critic', type=int, default=5, help='number of critic iterations per generator iteration') 19 | parser.add_argument('--lambda_1', type=int, default=10, help='gradient penalty coefficient') 20 | parser.add_argument('--img_size', type=int, default=32, help='size of each image dimension') 21 | parser.add_argument('--channels', type=int, default=3, help='is image rgb (3) or grayscale (1) ?') 22 | parser.add_argument('--display_port', type=int, default=8097, help='where to run the visdom for visualization? 
useful if running multiple visdom tabs') 23 | parser.add_argument('--display_server', type=str, default="http://localhost", help='visdom server of the web display') 24 | parser.add_argument('--sample_interval', type=int, default=256, help='interval betwen image samples') 25 | opt = parser.parse_args() 26 | 27 | try: 28 | import visdom 29 | vis = visdom.Visdom(server=opt.display_server, port=opt.display_port, raise_exceptions=True) # Create vis env. 30 | except ImportError: 31 | vis = None 32 | else: 33 | vis.close(None) # Clear all figures. 34 | 35 | img_dims = (opt.channels, opt.img_size, opt.img_size) 36 | n_features = opt.channels * opt.img_size * opt.img_size 37 | 38 | # TODO: Use some initialization in the future. 39 | def init_weights(m): 40 | if type(m) == nn.ConvTranspose2d: 41 | torch.nn.init.kaiming_normal(m.weight, mode='fan_out', nonlinearity='relu') 42 | elif type(m) == nn.Conv2d: 43 | torch.nn.init.kaiming_uniform_(m.weight, mode='fan_in', nonlinearity='leaky_relu') 44 | 45 | def load_cifar10(img_size): 46 | compose = transforms.Compose( 47 | [transforms.Resize(img_size), 48 | transforms.ToTensor(), 49 | transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)), 50 | ]) 51 | output_dir = './data/cifar10' 52 | cifar = datasets.CIFAR10(root=output_dir, download=True, train=True, 53 | transform=compose) 54 | return cifar 55 | 56 | cifar = load_cifar10(opt.img_size) 57 | batch_iterator = DataLoader(cifar, shuffle=True, batch_size=opt.batch_size) # List, NCHW format. 58 | 59 | cuda = torch.cuda.is_available() 60 | Tensor = torch.cuda.FloatTensor if cuda else torch.FloatTensor 61 | gan_loss = nn.BCELoss() 62 | 63 | generator = Generator() 64 | discriminator = Discriminator() 65 | 66 | optimizer_D = optim.Adam(discriminator.parameters(), lr=opt.alpha, betas=(opt.b1, opt.b2)) 67 | optimizer_G = optim.Adam(generator.parameters(), lr=opt.alpha, betas=(opt.b1, opt.b2)) 68 | 69 | # Loss record. 70 | g_losses = [] 71 | d_losses = [] 72 | epochs = [] 73 | loss_legend = ['Discriminator', 'Generator'] 74 | 75 | if cuda: 76 | generator = generator.cuda() 77 | discriminator = discriminator.cuda() 78 | 79 | noise_fixed = Variable(Tensor(25, 128, 1, 1).normal_(0, 1), requires_grad=False) # To track the progress of the GAN. 80 | 81 | for epoch in range(opt.n_epochs): 82 | print('Epoch {}'.format(epoch)) 83 | for i, (batch, _) in enumerate(batch_iterator): 84 | # == Discriminator update == # 85 | for iter in range(opt.n_critic): 86 | # Sample real and fake images, using notation in paper. 87 | x = Variable(batch.type(Tensor)) 88 | noise = Variable(Tensor(batch.size(0), 128, 1, 1).normal_(0, 1)) 89 | x_tilde = Variable(generator(noise), requires_grad=True) 90 | 91 | epsilon = Variable(Tensor(batch.size(0), 1, 1, 1).uniform_(0, 1)) 92 | 93 | x_hat = epsilon*x + (1 - epsilon)*x_tilde 94 | x_hat = torch.autograd.Variable(x_hat, requires_grad=True) 95 | 96 | # Put the interpolated data through critic. 97 | dw_x = discriminator(x_hat) 98 | # A great exercise on learning how the autograd.grad works! 99 | grad_x = torch.autograd.grad(outputs=dw_x, inputs=x_hat, 100 | grad_outputs=Variable(Tensor(batch.size(0), 1, 1, 1).fill_(1.0), requires_grad=False), 101 | create_graph=True, retain_graph=True, only_inputs=True) 102 | grad_x = grad_x[0].view(batch.size(0), -1) 103 | grad_x = grad_x.norm(p=2, dim=1) # My naming is inaccurate, this is the 2-norm of grad(D_w(x_hat)) 104 | 105 | # Update discriminator (or critic, since we don't output probabilities anymore). 
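            # The loss below is the WGAN-GP critic objective (Gulrajani et al., 2017):
            #   L_D = E[D(x_tilde)] - E[D(x)] + lambda * E[(||grad_{x_hat} D(x_hat)||_2 - 1)^2],
            # with x_hat the random interpolate between a real and a generated sample computed
            # above, and lambda = opt.lambda_1 (10, as in the paper).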
106 | optimizer_D.zero_grad() 107 | 108 | # WGAN-GP loss, defined directly as a loss to minimize (unlike the maximized utility in the WGAN paper). 109 | d_loss = torch.mean(discriminator(x_tilde)) - torch.mean(discriminator(x)) + opt.lambda_1*torch.mean((grad_x - 1)**2) 110 | # d_loss = torch.mean(d_loss) # there's a reason why this shouldn't be done this way :) 111 | 112 | d_loss.backward() 113 | optimizer_D.step() 114 | 115 | # == Generator update == # 116 | noise = Variable(Tensor(batch.size(0), 128, 1, 1).normal_(0, 1)) 117 | imgs_fake = generator(noise) 118 | 119 | optimizer_G.zero_grad() 120 | 121 | g_loss = -torch.mean(discriminator(imgs_fake)) 122 | 123 | g_loss.backward() 124 | optimizer_G.step() 125 | 126 | if vis: 127 | batches_done = epoch * len(batch_iterator) + i 128 | if batches_done % opt.sample_interval == 0: 129 | 130 | # Keep a record of losses for plotting. 131 | epochs.append(epoch + i/len(batch_iterator)) 132 | g_losses.append(g_loss.item()) 133 | d_losses.append(d_loss.item()) 134 | 135 | # Generate images for a given set of fixed noise 136 | # so we can track how the GAN learns. 137 | imgs_fake_fixed = generator(noise_fixed).detach().data 138 | imgs_fake_fixed = imgs_fake_fixed.add_(1).div_(2) 139 | 140 | # Display results on visdom page. 141 | vis.line( 142 | X=torch.stack([Tensor(epochs)] * len(loss_legend), dim=1), 143 | Y=torch.stack((Tensor(d_losses), Tensor(g_losses)), dim=1), 144 | opts={ 145 | 'title': 'loss over time', 146 | 'legend': loss_legend, 147 | 'xlabel': 'epoch', 148 | 'ylabel': 'loss', 149 | 'width': 512, 150 | 'height': 512 151 | }, 152 | win=1) 153 | vis.images( 154 | imgs_fake_fixed[:25], 155 | nrow=5, win=2, 156 | opts={ 157 | 'title': 'GAN output [Epoch {}]'.format(epoch), 158 | 'width': 512, 159 | 'height': 512, 160 | } 161 | ) -------------------------------------------------------------------------------- /wgan/dataset_loader.py: -------------------------------------------------------------------------------- 1 | from torchvision import transforms, datasets 2 | 3 | def load_cifar10(img_size): 4 | compose = transforms.Compose( 5 | [transforms.Resize(img_size), 6 | transforms.ToTensor(), 7 | transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)), 8 | ]) 9 | output_dir = './data/cifar10' 10 | cifar = datasets.CIFAR10(root=output_dir, download=True, train=True, 11 | transform=compose) 12 | return cifar 13 | 14 | def load_mnist(img_size): 15 | compose = transforms.Compose( 16 | [transforms.Resize(img_size), 17 | transforms.ToTensor(), 18 | transforms.Normalize((0.5,), (0.5,)) # MNIST is single-channel, so one mean/std value. 19 | ]) 20 | output_dir = './data/mnist' 21 | return datasets.MNIST(root=output_dir, train=True, 22 | transform=compose, download=True) 23 | -------------------------------------------------------------------------------- /wgan/out/bn_GAN output [Epoch 99].jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ozanciga/gans-with-pytorch/2071efd166935f0b4fb321227e94aa2ad1cfa273/wgan/out/bn_GAN output [Epoch 99].jpg -------------------------------------------------------------------------------- /wgan/out/no_bn_366d9efab2beac.svg: -------------------------------------------------------------------------------- 1 | [SVG line plot: "loss over time" for Discriminator and Generator (no batch norm run); x-axis: epoch, y-axis: loss] -------------------------------------------------------------------------------- /wgan/out/no_bn_GAN output [Epoch 99] (1).jpg: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/ozanciga/gans-with-pytorch/2071efd166935f0b4fb321227e94aa2ad1cfa273/wgan/out/no_bn_GAN output [Epoch 99] (1).jpg -------------------------------------------------------------------------------- /wgan/out/real_samples.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ozanciga/gans-with-pytorch/2071efd166935f0b4fb321227e94aa2ad1cfa273/wgan/out/real_samples.png -------------------------------------------------------------------------------- /wgan/wgan.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Wasserstein DCGAN. 3 | 4 | Q: Does lower loss indicate better images, i.e., 5 | does convergence really mean anything? In regular GANs, it doesn't. 6 | A: Kind of. I would say having a steadily decreasing loss is nice, but it's not always a 1-1 mapping. 7 | (related: Many Paths to Equilibrium: GANs Do Not Need to Decrease a Divergence At Every Step) 8 | 9 | Q: Is it possible to train a stable network without batch normalization, as claimed? 10 | A: Sort of; batch norm still makes things way faster. Also, there is the practice of 11 | not using BN in the first layer of the discriminator, which I found not to matter. 12 | However, I noticed irregular behavior when BN is not used, such as sudden drops in error. 13 | ''' 14 | 15 | 16 | import argparse 17 | import os 18 | 19 | import torch 20 | from torch import nn, optim 21 | from torch.autograd.variable import Variable 22 | 23 | from torch.utils.data import DataLoader 24 | from torchvision.utils import save_image 25 | 26 | from dataset_loader import load_mnist, load_cifar10 27 | 28 | parser = argparse.ArgumentParser() 29 | parser.add_argument('--dataset', default='cifar10', help='cifar10 | mnist') 30 | parser.add_argument('--n_epochs', type=int, default=100, help='number of epochs of training') 31 | parser.add_argument('--batch_size', type=int, default=64, help='size of the batches') 32 | parser.add_argument('--lr', type=float, default=0.00005, help='rmsprop: learning rate') 33 | parser.add_argument('--c', type=float, default=0.01, help='clipping value for the critic weights') 34 | parser.add_argument('--n_critic', type=int, default=5, help='number of critic updates per generator update') 35 | parser.add_argument('--latent_dim', type=int, default=100, help='dimensionality of the latent space') 36 | parser.add_argument('--img_size', type=int, default=32, help='size of each image dimension') 37 | parser.add_argument('--channels', type=int, default=3, help='number of image channels') 38 | parser.add_argument('--display_port', type=int, default=8097, help='where to run the visdom for visualization? useful if running multiple visdom tabs') 39 | parser.add_argument('--display_server', type=str, default="http://localhost", help='visdom server of the web display') 40 | parser.add_argument('--sample_interval', type=int, default=256, help='interval between image samples') 41 | opt = parser.parse_args() 42 | 43 | 44 | try: 45 | import visdom 46 | vis = visdom.Visdom(server=opt.display_server, port=opt.display_port, raise_exceptions=True) # Create vis env. 47 | except ImportError: 48 | vis = None 49 | else: 50 | vis.close(None) # Clear all figures. 
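# (Recap of the training scheme implemented below, following the WGAN paper; comments only, no extra code.)
# The critic f_w is trained to maximize E_x[f_w(x)] - E_z[f_w(G(z))], an estimate of the
# Wasserstein-1 distance, and its weights are clamped to [-c, c] after every update as a crude
# Lipschitz constraint; the generator is then trained to minimize -E_z[f_w(G(z))].
# The DCGAN-style Generator and Discriminator classes below play the roles of G and f_w.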
51 | 52 | def weights_init(m): 53 | classname = m.__class__.__name__ 54 | if classname.find('Conv') != -1: 55 | m.weight.data.normal_(0.0, 0.02) 56 | elif classname.find('BatchNorm') != -1: 57 | m.weight.data.normal_(1.0, 0.02) 58 | m.bias.data.fill_(0) 59 | 60 | img_dims = (opt.channels, opt.img_size, opt.img_size) 61 | n_features = opt.channels * opt.img_size * opt.img_size 62 | 63 | # DCGAN network. 64 | class Generator(nn.Module): 65 | def __init__(self): 66 | super(Generator, self).__init__() 67 | 68 | def convblock(n_input, n_output, k_size=4, stride=2, padding=0, normalize=True): 69 | block = [nn.ConvTranspose2d(n_input, n_output, kernel_size=k_size, stride=stride, padding=padding, bias=False)] 70 | if normalize: 71 | block.append(nn.BatchNorm2d(n_output)) 72 | block.append(nn.ReLU(inplace=True)) 73 | return block 74 | 75 | self.project = nn.Sequential( 76 | nn.Linear(opt.latent_dim, 256 * 4 * 4), 77 | nn.BatchNorm1d(256 * 4 * 4), 78 | nn.ReLU(inplace=True), 79 | ) # Note: this projection is never used; forward() feeds z straight into the transposed-conv stack below. 80 | 81 | self.model = nn.Sequential( 82 | *convblock(opt.latent_dim, 256, 4, 1, 0), 83 | *convblock(256, 128, 4, 2, 1), 84 | *convblock(128, 64, 4, 2, 1), 85 | nn.ConvTranspose2d(64, opt.channels, 4, 2, 1), 86 | nn.Tanh() 87 | ) 88 | # Tanh > Image values are between [-1, 1] 89 | 90 | def forward(self, z): 91 | img = self.model(z) 92 | img = img.view(z.size(0), *img_dims) 93 | return img 94 | 95 | 96 | class Discriminator(nn.Module): 97 | def __init__(self): 98 | super(Discriminator, self).__init__() 99 | 100 | def convblock(n_input, n_output, kernel_size=4, stride=2, padding=1, normalize=True): 101 | block = [nn.Conv2d(n_input, n_output, kernel_size, stride, padding, bias=False)] 102 | if normalize: 103 | block.append(nn.BatchNorm2d(n_output)) 104 | block.append(nn.LeakyReLU(0.2, inplace=True)) 105 | return block 106 | 107 | self.model = nn.Sequential( 108 | *convblock(opt.channels, 64, 4, 2, 1, normalize=False), # 32 -> 16, (32+2p-3-1)/2 + 1 = 16, p = 1, apparently BN at the 1st layer is detrimental for WGANs. 109 | *convblock(64, 128, 4, 2, 1), 110 | *convblock(128, 256, 4, 2, 1), 111 | nn.Conv2d(256, 1, 4, 1, 0, bias=False), # FC with Conv. 112 | nn.Sigmoid() # NB: the WGAN paper's critic is unbounded (no sigmoid); this bounds the scores to (0, 1). 113 | ) 114 | 115 | def forward(self, img): 116 | prob = self.model(img) 117 | return prob 118 | 119 | 120 | assert (opt.dataset == 'cifar10' or opt.dataset == 'mnist'), 'Unknown dataset! Only cifar10 and mnist are supported.' 121 | 122 | if opt.dataset == 'cifar10': 123 | batch_iterator = DataLoader(load_cifar10(opt.img_size), shuffle=True, batch_size=opt.batch_size) # List, NCHW format. 124 | elif opt.dataset == 'mnist': 125 | batch_iterator = DataLoader(load_mnist(opt.img_size), shuffle=True, batch_size=opt.batch_size) # List, NCHW format. 126 | 127 | # Save a batch of real images for reference. 128 | os.makedirs('./out', exist_ok=True) 129 | save_image(next(iter(batch_iterator))[0][:25, ...], './out/real_samples.png', nrow=5, normalize=True) 130 | 131 | 132 | cuda = torch.cuda.is_available() 133 | Tensor = torch.cuda.FloatTensor if cuda else torch.FloatTensor 134 | gan_loss = nn.BCELoss() # Not used: WGAN trains on raw critic scores, not a cross-entropy loss. 135 | 136 | generator = Generator() 137 | discriminator = Discriminator() 138 | 139 | optimizer_D = optim.RMSprop(discriminator.parameters(), lr=opt.lr) 140 | optimizer_G = optim.RMSprop(generator.parameters(), lr=opt.lr) 141 | 142 | # Loss record. 
143 | g_losses = [] 144 | d_losses = [] 145 | epochs = [] 146 | loss_legend = ['Discriminator', 'Generator'] 147 | 148 | if cuda: 149 | generator = generator.cuda() 150 | discriminator = discriminator.cuda() 151 | 152 | generator.apply(weights_init) 153 | discriminator.apply(weights_init) 154 | 155 | noise_fixed = Variable(Tensor(25, opt.latent_dim, 1, 1).normal_(0, 1), requires_grad=False) # To track the progress of the GAN. 156 | 157 | for epoch in range(opt.n_epochs): 158 | print('Epoch {}'.format(epoch)) 159 | for i, (batch, _) in enumerate(batch_iterator): 160 | 161 | # == Discriminator update == # 162 | for _ in range(opt.n_critic): 163 | # Sample real and fake images. 164 | imgs_real = Variable(batch.type(Tensor)) 165 | noise = Variable(Tensor(batch.size(0), opt.latent_dim, 1, 1).normal_(0, 1)) 166 | imgs_fake = Variable(generator(noise), requires_grad=False) 167 | 168 | # Update discriminator (or critic, since we don't output probabilities anymore). 169 | optimizer_D.zero_grad() 170 | 171 | # WGAN utility; we ascend on it, so the loss is its negative. 172 | d_loss = -torch.mean(discriminator(imgs_real) - discriminator(imgs_fake)) 173 | 174 | d_loss.backward() 175 | optimizer_D.step() 176 | 177 | # Clip the critic (not the generator) weights to enforce the Lipschitz constraint. 178 | # This crude, detrimental fix is replaced by a gradient penalty in WGAN-GP a few months after this paper. 179 | for params in discriminator.parameters(): 180 | params.data.clamp_(-opt.c, +opt.c) 181 | 182 | # == Generator update == # 183 | noise = Variable(Tensor(batch.size(0), opt.latent_dim, 1, 1).normal_(0, 1)) 184 | imgs_fake = generator(noise) 185 | 186 | optimizer_G.zero_grad() 187 | 188 | g_loss = -torch.mean(discriminator(imgs_fake)) 189 | 190 | g_loss.backward() 191 | optimizer_G.step() 192 | 193 | if vis: 194 | batches_done = epoch * len(batch_iterator) + i 195 | if batches_done % opt.sample_interval == 0: 196 | 197 | # Keep a record of losses for plotting. 198 | epochs.append(epoch + i/len(batch_iterator)) 199 | g_losses.append(-g_loss.item()) # Negated: we plot the utilities that WGAN maximizes, not the minimized losses. 200 | d_losses.append(-d_loss.item()) 201 | 202 | # Generate images for a given set of fixed noise 203 | # so we can track how the GAN learns. 204 | imgs_fake_fixed = generator(noise_fixed).detach().data 205 | imgs_fake_fixed = imgs_fake_fixed.add_(1).div_(2) # To normalize and display on visdom. 206 | 207 | # Display results on visdom page. 208 | vis.line( 209 | X=torch.stack([Tensor(epochs)] * len(loss_legend), dim=1), 210 | Y=torch.stack((Tensor(d_losses), Tensor(g_losses)), dim=1), 211 | opts={ 212 | 'title': 'loss over time', 213 | 'legend': loss_legend, 214 | 'xlabel': 'epoch', 215 | 'ylabel': 'loss', 216 | 'width': 512, 217 | 'height': 512 218 | }, 219 | win=1) 220 | vis.images( 221 | imgs_fake_fixed, 222 | nrow=5, win=2, 223 | opts={ 224 | 'title': 'GAN output [Epoch {}]'.format(epoch), 225 | 'width': 512, 226 | 'height': 512, 227 | } 228 | ) --------------------------------------------------------------------------------
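Usage sketch for wgan.py (assumes a visdom server is reachable at the default http://localhost:8097; the flags are the ones defined in the script's argument parser):

    python wgan.py --dataset cifar10 --n_epochs 100
    python wgan.py --dataset mnist --channels 1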